In [1]:
import tensorflow as tf
import numpy as np
from netCDF4 import Dataset

In [25]:
#Parameter List
# Every tunable setting used by the notebook is collected in this one dictionary.
parameter_list = {
    # Path of the input NetCDF file
    'netCDf_loc': "./lorenz96_multi/DATA_sample/X40F18/all_10/nocorr_I20/assim.nc",
    # Width of the neighbourhood window built around each state variable
    'locality': 5,
    # Number of consecutive time steps stored in one TFRecord sample
    'time_splits': 30,
    # Number of samples per batch
    'batch_size': 240,
    # Number of dataset elements held out for validation
    'val_size': 100,
}

In [3]:
#Opening the NetCDF file (read-only) at the path configured in parameter_list
root_grp = Dataset(parameter_list['netCDf_loc'], "r", format="NETCDF4")

#Extracting the datasets
# NOTE(review): presumably "vam" is the analysis field and "vfm" the forecast
# field -- confirm against the file that produced assim.nc.
# locality_creator reads shape[0] and shape[1] of these and indexes them with
# two subscripts, so both are expected to be 2-D.
analysis_init = root_grp["vam"]
forecast_init = root_grp["vfm"]

In [5]:
#Creation of datasets for training and validation

#Creating locality for individual state variable
def locality_creator(init_dataset, locality=None):
    """Build a periodic neighbourhood window around every state variable.

    For each column ``i`` of ``init_dataset`` the ``locality`` columns centred
    on ``i`` are gathered, wrapping around at both ends of axis 1 so that the
    first and last columns are treated as neighbours.

    Parameters
    ----------
    init_dataset : 2-D array-like
        Indexed as ``[axis0, state_variable]``. Anything that supports
        ``[:]`` and converts to a NumPy array works (a NumPy array, or
        presumably a netCDF4 variable -- TODO confirm for masked data).
    locality : int, optional
        Window width. Defaults to ``parameter_list['locality']``. An odd
        value gives a window symmetric about ``i``.

    Returns
    -------
    numpy.ndarray
        ``float32`` array of shape
        ``(init_dataset.shape[1], init_dataset.shape[0], locality)``.
    """
    if locality is None:
        locality = parameter_list['locality']
    radius = int(locality / 2)

    # Read the data once into memory; the per-column gathers below then run on
    # a plain NumPy array instead of repeatedly indexing the source object.
    data = np.asarray(init_dataset[:])
    n_rows, n_vars = data.shape[0], data.shape[1]

    output_dataset = np.zeros((n_rows, n_vars, locality))

    for i in range(n_vars):
        # The modulo wraps indices that fall off either end of axis 1. The
        # previous hand-written wrap used the width of the global
        # `analysis_init` rather than of `init_dataset`, so it was only
        # correct when both happened to have the same number of columns.
        index = np.linspace(i - radius, i + radius, locality, dtype='int') % n_vars
        output_dataset[:, i, :] = data[:, index]

    # Put the state-variable axis first and down-cast for TensorFlow.
    return np.transpose(output_dataset, (1, 0, 2)).astype('float32')

# Build the neighbourhood windows for both fields. locality_creator returns
# float32 arrays with the state-variable axis moved to the front.
analysis_dataset = locality_creator(analysis_init)
forecast_dataset = locality_creator(forecast_init)

In [6]:
#Creating time data splits
def split_sequences(sequences, n_steps):
    """Cut ``sequences`` into consecutive, non-overlapping windows.

    Each window holds ``n_steps`` entries along axis 0. Trailing entries that
    do not fill a complete window are discarded.

    Parameters
    ----------
    sequences : numpy.ndarray
        Array with at least two dimensions; windows are taken along axis 0.
    n_steps : int
        Length of every window.

    Returns
    -------
    numpy.ndarray
        The windows stacked along a new leading axis (an empty array when no
        complete window fits).
    """
    total = len(sequences)
    windows = []
    for k in range(total):
        lo = k * n_steps
        hi = lo + n_steps
        # Stop as soon as the next window would run past the data.
        if hi > total:
            break
        windows.append(sequences[lo:hi, :])
    return np.array(windows)

In [7]:
#For serializing the tensor to a string for TFRecord
def _serialize_tensor(value):
    """Serialise ``value`` with ``tf.io.serialize_tensor`` and return the result.

    The caller converts the returned tensor with ``.numpy()`` before handing
    it to the TFRecord writer.
    """
    return tf.io.serialize_tensor(value)

In [12]:
#For writing data to the TFRecord file
def write_TFRecord(filename, dataset, time_splits=None):
    """Write ``dataset`` to a TFRecord file, one record per time window.

    For every entry along axis 0 of ``dataset`` the remaining axes are cut
    into non-overlapping windows by ``split_sequences``; each window is
    serialised with ``_serialize_tensor`` and written as one record.

    Parameters
    ----------
    filename : str
        Path of the TFRecord file to create.
    dataset : numpy.ndarray
        Array whose first axis is iterated; each ``dataset[i]`` is passed to
        ``split_sequences``.
    time_splits : int, optional
        Window length handed to ``split_sequences``. Defaults to
        ``parameter_list['time_splits']``.
    """
    if time_splits is None:
        time_splits = parameter_list['time_splits']

    # The `with` block closes the writer on exit (including on error), so no
    # separate close() call is needed afterwards.
    with tf.io.TFRecordWriter(filename) as writer:
        for i in range(dataset.shape[0]):
            for window in split_sequences(dataset[i], time_splits):
                # serialize_tensor returns a string tensor; the writer needs bytes.
                writer.write(_serialize_tensor(window).numpy())

In [9]:
#For reading the TFRecord File
def read_TFRecord(filename):
    """Return a ``tf.data.TFRecordDataset`` over ``filename``.

    No decoding happens here; callers map ``_parse_tensor`` over the result.
    """
    return tf.data.TFRecordDataset(filename)

#For parsing the value from string to float32
def _parse_tensor(value):
    """Decode one serialised record into a ``tf.float32`` tensor.

    ``value`` is passed to ``tf.io.parse_tensor`` with ``out_type=tf.float32``.
    """
    return tf.io.parse_tensor(value, out_type=tf.float32)

In [13]:
# Write both fields with the same function so record k of one file lines up
# with record k of the other; the later zip relies on this.
write_TFRecord('analysis.tfrecord', analysis_dataset)
write_TFRecord('forecast.tfrecord', forecast_dataset)

In [19]:
#Reading the TFRecord files
# At this point each dataset still yields raw serialised records.
anal_file = read_TFRecord('analysis.tfrecord')
fore_file = read_TFRecord('forecast.tfrecord')

#Parsing the dataset
# The same names are rebound: from here on they yield float32 tensors.
anal_file = anal_file.map(_parse_tensor)
fore_file = fore_file.map(_parse_tensor)

In [20]:
#Zipping the files
# Pairs each analysis sample with the forecast sample at the same position.
dataset = tf.data.Dataset.zip((anal_file, fore_file))

#Shuffling the dataset
# reshuffle_each_iteration=False keeps one fixed shuffled order. By default
# shuffle() draws a new order on every iteration, so a later take()/skip()
# split would pick different elements each time it is iterated and the
# training and validation sets would overlap. If per-epoch shuffling of the
# training data is wanted, shuffle the training split again after splitting.
dataset = dataset.shuffle(100000, reshuffle_each_iteration=False)
dataset = dataset.batch(batch_size=parameter_list['batch_size'])

In [21]:
# Peek at a single batch. Following the zip order, i is the analysis batch and
# j is the forecast batch; only the forecast batch is printed.
for i,j in dataset.take(1):
    print(j)

tf.Tensor(
[[[-1.7398463  -3.854305    2.9147084  16.174902    5.367921  ]
  [-4.6436844  -2.077001    0.65262926 15.686306    1.3139693 ]
  [-4.058007   -0.788952    0.8208224  16.103973   -1.8078716 ]
  ...
  [ 8.991234   14.651135   -0.8641131   9.0588455   5.509129  ]
  [ 8.639453   15.103537    0.684075    8.81687     5.3432107 ]
  [ 8.833881   14.964432    1.0165862   7.9007444   6.395882  ]]

 [[ 9.5995455  10.341174    0.05193586 -8.029909    3.750391  ]
  [10.430953    6.9138894  -4.4073424  -5.5132213   0.25255945]
  [10.695385    0.08031861 -5.9061074  -2.7605069  -2.7055893 ]
  ...
  [13.661081    3.9851658   2.7625027   5.3202868   7.7867455 ]
  [13.499271    2.283166    2.1953018   6.693972   10.515221  ]
  [12.797866    1.2229387   2.23049     7.863687   11.2533    ]]

 [[ 5.485619   11.169187    5.6348686  -2.9273381  -5.6248546 ]
  [ 6.379557   11.873319    2.0390596  -3.5120602  -4.1368504 ]
  [ 6.7938714  11.4493265  -3.14341    -2.9815614  -2.8750167 ]
  ...
  [-9.0

In [26]:
#For creating Train and Validation datasets

def train_val_creator(dataset, val_size):
    """Split ``dataset`` into a training part and a validation part.

    The first ``val_size`` elements become the validation set and everything
    after them becomes the training set. ``val_size`` counts elements of
    ``dataset`` as given, so for a batched dataset it counts batches.

    NOTE: both parts are views onto the same upstream pipeline. If that
    pipeline contains a shuffle that draws a new order on every iteration,
    the two parts are not guaranteed to be disjoint.

    Parameters
    ----------
    dataset : object with ``take(n)`` and ``skip(n)`` methods
        Typically a ``tf.data.Dataset``.
    val_size : int
        Number of leading elements to hold out for validation.

    Returns
    -------
    tuple
        ``(train_dataset, val_dataset)``.
    """
    held_out = dataset.take(val_size)
    remainder = dataset.skip(val_size)
    return remainder, held_out

In [27]:
# `dataset` is already batched here, so val_size counts batches, not single samples.
train_dataset, val_dataset = train_val_creator(dataset, parameter_list['val_size'])