In [1]:
import sys
sys.path.insert(1,"/home1/07064/tg863631/anaconda3/envs/CbrainCustomLayer/lib/python3.6/site-packages") #work around for h5py
from cbrain.imports import *
from cbrain.cam_constants import *
from cbrain.utils import *
from cbrain.layers import *
from cbrain.data_generator import DataGenerator
import tensorflow as tf
from tensorflow import math as tfm
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
# import tensorflow_probability as tfp
import xarray as xr
import numpy as np
from cbrain.model_diagnostics import ModelDiagnostics
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.image as imag
import scipy.integrate as sin
import matplotlib.ticker as mticker
import pickle
from tensorflow.keras import layers
from tensorflow.keras.losses import *
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
import datetime
from cbrain.climate_invariant import *
import yaml


## Data Generators

In [2]:
from cbrain.imports import *
from cbrain.utils import *
from cbrain.normalization import *
import h5py
from sklearn.preprocessing import OneHotEncoder

In [3]:
class DataGeneratorClassification(tf.keras.utils.Sequence):
    """Keras ``Sequence`` yielding normalized inputs and one-hot, percentile-binned targets.

    Each batch is read from the ``vars`` array of ``data_fn``, split into input and
    output columns, passed through the input/output transforms, and the transformed
    outputs are converted into one one-hot target per output column (see
    ``_transform_to_one_hot``).

    Parameters
    ----------
    data_fn : path of the data file (opened with xarray, then re-opened with h5py
        unless ``xarray=True``).
    input_vars, output_vars : variable names forwarded to ``return_var_idxs``.
    percentile_path : pickle file; ``['Percentile'][data_name]`` is used as bin edges.
    data_name : key selecting the percentile table inside the pickle.
    norm_fn : optional normalization file; required when ``input_transform`` is a
        tuple or ``output_transform`` is a dict.
    input_transform : ``None`` (plain ``Normalizer``), a 2-tuple forwarded to
        ``InputNormalizer``, or an already initialized normalizer.
    output_transform : ``None`` (plain ``Normalizer``), a dict forwarded to
        ``DictNormalizer``, or an already initialized normalizer.
    batch_size : number of samples per batch.
    shuffle : if True, the order of the batches is reshuffled at every epoch end.
    xarray : if True, keep reading through xarray instead of h5py.
    var_cut_off : forwarded to ``return_var_idxs`` / ``InputNormalizer``.
    normalize_flag : forwarded to ``InputNormalizer``.
    bin_size : the one-hot encoder is fitted on ``bin_size + 2`` classes.

    Raises
    ------
    ValueError
        If a transform that needs the normalization dataset is requested while
        ``norm_fn`` is None.
    """

    def __init__(self, data_fn, input_vars, output_vars, percentile_path, data_name,
                 norm_fn=None, input_transform=None, output_transform=None,
                 batch_size=1024, shuffle=True, xarray=False, var_cut_off=None, normalize_flag=True, bin_size=100):
        # Just copy over the attributes
        self.data_fn, self.norm_fn = data_fn, norm_fn
        self.input_vars, self.output_vars = input_vars, output_vars
        self.batch_size, self.shuffle = batch_size, shuffle
        self.bin_size = bin_size
        # NOTE(review): load_pickle presumably unpickles the file - only use trusted files.
        self.percentile_bins = load_pickle(percentile_path)['Percentile'][data_name]

        # One-hot encoder over the class ids 0 .. bin_size + 1
        self.enc = OneHotEncoder(sparse=False)
        classes = np.arange(self.bin_size+2)
        self.enc.fit(classes.reshape(-1,1))

        # Open datasets (norm_ds is always defined so that a missing norm file
        # produces a clear error below instead of an AttributeError)
        self.data_ds = xr.open_dataset(data_fn)
        self.norm_ds = xr.open_dataset(norm_fn) if norm_fn is not None else None

        # Compute number of samples and (complete) batches
        self.n_samples = self.data_ds.vars.shape[0]
        self.n_batches = int(self.n_samples // self.batch_size)

        # Get input and output variable indices
        self.input_idxs = return_var_idxs(self.data_ds, input_vars, var_cut_off)
        self.output_idxs = return_var_idxs(self.data_ds, output_vars)
        self.n_inputs, self.n_outputs = len(self.input_idxs), len(self.output_idxs)

        # Initialize input and output normalizers/transformers
        if input_transform is None:
            self.input_transform = Normalizer()
        elif type(input_transform) is tuple:
            if self.norm_ds is None:
                raise ValueError("norm_fn is required when input_transform is a tuple")
            ## normalize flag added by Ankitesh
            self.input_transform = InputNormalizer(
                self.norm_ds,normalize_flag, input_vars, input_transform[0], input_transform[1], var_cut_off)
        else:
            self.input_transform = input_transform  # Assume an initialized normalizer is passed

        if output_transform is None:
            self.output_transform = Normalizer()
        elif type(output_transform) is dict:
            if self.norm_ds is None:
                raise ValueError("norm_fn is required when output_transform is a dict")
            self.output_transform = DictNormalizer(self.norm_ds, output_vars, output_transform)
        else:
            self.output_transform = output_transform  # Assume an initialized normalizer is passed

        # Now close the xarray file and load it as an h5 file instead
        # This significantly speeds up the reading of the data...
        if not xarray:
            self.data_ds.close()
            self.data_ds = h5py.File(data_fn, 'r')

        # Build the (optionally shuffled) batch order used by __getitem__
        self.on_epoch_end()

    def __len__(self):
        """Return the number of complete batches per epoch."""
        return self.n_batches

    # TODO: Find a better way to implement this, currently it is the hardcoded way.
    def _transform_to_one_hot(self,Y):
        '''
            Convert transformed outputs into one-hot encoded percentile-bin targets.

            The column layout is hard-coded: columns 0-29 are PHQ, 30-59 are
            TPHYSTND, and columns 60-63 are FSNT, FSNS, FLNT, FLNS. Every column is
            binned with np.digitize against its percentile edges and the bin index
            is one-hot encoded with the encoder fitted on bin_size + 2 classes.

            Y: array indexed as Y[sample, column] with at least 64 columns.

            return: dict mapping 'output_0' ... 'output_63' to arrays of shape
                    batch_size X (bin_size + 2)
        '''
        perc = self.percentile_bins
        n_lev = 30
        Y_trans = []

        # Profile variables: one set of percentile edges per vertical level
        for var, offset in (('PHQ', 0), ('TPHYSTND', n_lev)):
            all_levels_one_hot = []
            for ilev in range(n_lev):
                bin_index = np.digitize(Y[:, offset + ilev], perc[var][ilev])
                one_hot = self.enc.transform(bin_index.reshape(-1,1))
                all_levels_one_hot.append(one_hot)
            Y_trans.append(np.stack(all_levels_one_hot, axis=1))

        # Scalar variables: a single set of percentile edges each
        for var, col in (('FSNT', 60), ('FSNS', 61), ('FLNT', 62), ('FLNS', 63)):
            bin_index = np.digitize(Y[:, col], perc[var])
            one_hot = self.enc.transform(bin_index.reshape(-1,1))[:,np.newaxis,:]
            Y_trans.append(one_hot)

        # batch_size X 64 X (bin_size + 2), then one dict entry per output head
        Y_concatenated = np.concatenate(Y_trans,axis=1)
        transformed = {}
        for i in range(64):
            transformed[f'output_{i}'] = Y_concatenated[:,i,:]
        return transformed

    def __getitem__(self, index):
        """Return ``(X, Y)`` for batch ``index`` of the current epoch.

        ``X`` is the transformed input array and ``Y`` the dict of one-hot targets
        produced by ``_transform_to_one_hot``.
        """
        # Map the requested position onto the (possibly shuffled) batch order,
        # so that the `shuffle` flag actually takes effect
        batch_idx = int(self.indices[index])

        # Compute start and end indices for batch
        start_idx = batch_idx * self.batch_size
        end_idx = start_idx + self.batch_size

        # Grab batch from data
        batch = self.data_ds['vars'][start_idx:end_idx]

        # Split into inputs and outputs
        X = batch[:, self.input_idxs]
        Y = batch[:, self.output_idxs]
        # Normalize
        X = self.input_transform.transform(X)
        Y = self.output_transform.transform(Y) #shape batch_size X 64
        Y = self._transform_to_one_hot(Y)
        return X, Y

    def on_epoch_end(self):
        """Rebuild the batch order and reshuffle it when ``shuffle`` is set."""
        self.indices = np.arange(self.n_batches)
        if self.shuffle: np.random.shuffle(self.indices)

In [4]:
# Turn off TensorFlow's device-placement logging.
tf.debugging.set_log_device_placement(False)


In [5]:
# Dictionary that is passed as `output_transform` to both data generators below.
# NOTE(review): presumably per-variable output scaling factors (file name says W/m2) - confirm.
# NOTE(review): absolute, user-specific path; load_pickle presumably unpickles, so only use trusted files.
scale_dict = load_pickle('/export/nfs0home/ankitesg/CBrain_project/CBRAIN-CAM/nn_config/scale_dicts/009_Wm2_scaling.pkl')

In [6]:
# Directory holding the data files (the trailing slash is relied on when the
# paths are assembled with f-strings further down).
data_path = '/scratch/ankitesh/data/'

# Training, validation and normalization file names inside `data_path`.
TRAINFILE, VALIDFILE, NORMFILE = (
    'CI_SP_M4K_train_shuffle.nc',
    'CI_SP_M4K_valid.nc',
    'CI_SP_M4K_NORM_norm.nc',
)

In [7]:
# Generator over the training file. The tuple `input_transform` is forwarded to
# InputNormalizer and the dict `output_transform` to DictNormalizer, both of which
# read the normalization file given by `norm_fn`.
train_gen_config = dict(
    data_fn=f'{data_path}{TRAINFILE}',
    norm_fn=f'{data_path}{NORMFILE}',
    input_vars=['QBP','TBP','PS', 'SOLIN', 'SHFLX', 'LHFLX'],
    output_vars=['PHQ','TPHYSTND','FSNT', 'FSNS', 'FLNT', 'FLNS'],
    input_transform=('mean', 'maxrs'),
    output_transform=scale_dict,
    percentile_path='/export/nfs0home/ankitesg/data/percentile_data.pkl',
    data_name='M4K',
    batch_size=1024,
)
train_gen = DataGeneratorClassification(**train_gen_config)

In [8]:
# Generator over the validation file; apart from `data_fn` it is configured
# exactly like the training generator.
valid_gen_config = dict(
    data_fn=f'{data_path}{VALIDFILE}',
    norm_fn=f'{data_path}{NORMFILE}',
    input_vars=['QBP','TBP','PS', 'SOLIN', 'SHFLX', 'LHFLX'],
    output_vars=['PHQ','TPHYSTND','FSNT', 'FSNS', 'FLNT', 'FLNS'],
    input_transform=('mean', 'maxrs'),
    output_transform=scale_dict,
    percentile_path='/export/nfs0home/ankitesg/data/percentile_data.pkl',
    data_name='M4K',
    batch_size=1024,
)
valid_gen = DataGeneratorClassification(**valid_gen_config)

## Model (multi-output classification)

In [9]:
# Number of bins used to size the model outputs: every softmax head gets
# bin_size + 2 units. Keep in sync with the `bin_size` of
# DataGeneratorClassification (default 100), whose encoder uses bin_size + 2 classes.
bin_size = 100

In [10]:
def define_single_output_branch(densout,out_index):
    """Attach one softmax classification head (one of the 64 branches) to `densout`.

    The head is a Dense layer named ``output_<out_index>`` with ``bin_size + 2``
    units, where ``bin_size`` is read from the notebook's global scope.

    Returns the output tensor of that head.
    """
    head = Dense(units=bin_size + 2, activation='softmax', name=f"output_{out_index}")
    return head(densout)

In [11]:
# Shared trunk: five 128-unit layers followed by one 32-unit layer, each a linear
# Dense layer followed by LeakyReLU(alpha=0.3).
inp = Input(shape=(64,))
densout = inp
for width in [128] * 5 + [32]:
    densout = Dense(width, activation='linear')(densout)
    densout = LeakyReLU(alpha=0.3)(densout)

# 64 softmax heads (output_0 ... output_63) on top of the shared trunk.
all_outputs = [define_single_output_branch(densout, head_idx) for head_idx in range(64)]
model = tf.keras.models.Model(inputs=inp, outputs=all_outputs)

In [12]:
# Print the layer-by-layer summary of the multi-head classifier defined above.
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 64)]         0                                            
__________________________________________________________________________________________________
dense (Dense)                   (None, 128)          8320        input_1[0][0]                    
__________________________________________________________________________________________________
leaky_re_lu (LeakyReLU)         (None, 128)          0           dense[0][0]                      
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 128)          16512       leaky_re_lu[0][0]                
______________________________________________________________________________________________

In [13]:
# One categorical cross-entropy loss per output head, keyed by the head's layer name.
losses = {f'output_{head_idx}': "categorical_crossentropy" for head_idx in range(64)}

In [14]:
# Directory where the best checkpoint is written.
path_HDF5 = '/scratch/ankitesh/models/'

# Adam optimizer, per-head losses from `losses`, accuracy tracked for every head.
model.compile(optimizer=tf.keras.optimizers.Adam(), loss=losses, metrics=["accuracy"])

# Stop when val_loss has not improved for 10 epochs; keep only the best model on disk.
earlyStopping = EarlyStopping(monitor='val_loss', mode='min', patience=10, verbose=0)
mcp_save = ModelCheckpoint(
    filepath=path_HDF5+'BF_Classification.hdf5',
    monitor='val_loss',
    mode='min',
    save_best_only=True,
)

In [15]:
# Number of training epochs. With Nep = 5 the EarlyStopping callback (patience=10)
# can never trigger during this run.
Nep = 5
training_callbacks = [earlyStopping, mcp_save]

# NOTE(review): fit_generator is deprecated in newer TensorFlow releases in favour
# of Model.fit - confirm the installed version before switching.
with tf.device('/gpu:0'):
    model.fit_generator(
        train_gen,
        epochs=Nep,
        validation_data=valid_gen,
        callbacks=training_callbacks,
    )

Epoch 1/5
Epoch 2/5


    1/41376 [..............................] - ETA: 8:34:52 - loss: 228.3732 - output_0_loss: 1.1921e-07 - output_1_loss: 1.1921e-07 - output_2_loss: 4.0089 - output_3_loss: 4.0924 - output_4_loss: 4.4826 - output_5_loss: 4.4654 - output_6_loss: 4.3444 - output_7_loss: 4.1900 - output_8_loss: 4.0336 - output_9_loss: 3.8898 - output_10_loss: 3.9298 - output_11_loss: 3.9665 - output_12_loss: 3.8927 - output_13_loss: 3.8229 - output_14_loss: 3.8019 - output_15_loss: 3.8593 - output_16_loss: 3.8520 - output_17_loss: 3.8629 - output_18_loss: 3.8552 - output_19_loss: 3.8885 - output_20_loss: 3.9031 - output_21_loss: 3.9810 - output_22_loss: 4.0189 - output_23_loss: 4.0956 - output_24_loss: 4.1426 - output_25_loss: 4.1414 - output_26_loss: 4.1959 - output_27_loss: 4.2625 - output_28_loss: 4.2756 - output_29_loss: 4.2755 - output_30_loss: 1.7757 - output_31_loss: 1.8820 - output_32_loss: 3.9489 - output_33_loss: 3.9174 - output_34_loss: 2.4818 - output_35_loss: 2.7193 - output_36_loss: 3.1699 

  157/41376 [..............................] - ETA: 7:39:43 - loss: 229.1477 - output_0_loss: 1.2023e-07 - output_1_loss: 1.2019e-07 - output_2_loss: 4.0491 - output_3_loss: 4.1242 - output_4_loss: 4.4825 - output_5_loss: 4.4628 - output_6_loss: 4.3646 - output_7_loss: 4.2477 - output_8_loss: 4.0496 - output_9_loss: 3.9012 - output_10_loss: 3.8998 - output_11_loss: 3.9276 - output_12_loss: 3.8944 - output_13_loss: 3.8448 - output_14_loss: 3.8227 - output_15_loss: 3.8444 - output_16_loss: 3.8922 - output_17_loss: 3.9153 - output_18_loss: 3.9139 - output_19_loss: 3.9243 - output_20_loss: 3.9496 - output_21_loss: 3.9906 - output_22_loss: 4.0505 - output_23_loss: 4.1166 - output_24_loss: 4.1418 - output_25_loss: 4.1814 - output_26_loss: 4.2247 - output_27_loss: 4.2559 - output_28_loss: 4.2757 - output_29_loss: 4.2733 - output_30_loss: 1.8130 - output_31_loss: 1.9815 - output_32_loss: 3.9026 - output_33_loss: 3.9310 - output_34_loss: 2.5313 - output_35_loss: 2.7096 - output_36_loss: 3.1597 

KeyboardInterrupt: 

In [58]:
# List the physical devices (CPU/GPU) visible to TensorFlow.
tf.config.experimental.list_physical_devices()


[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'),
 PhysicalDevice(name='/physical_device:XLA_CPU:0', device_type='XLA_CPU'),
 PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU'),
 PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU'),
 PhysicalDevice(name='/physical_device:GPU:2', device_type='GPU'),
 PhysicalDevice(name='/physical_device:XLA_GPU:0', device_type='XLA_GPU'),
 PhysicalDevice(name='/physical_device:XLA_GPU:1', device_type='XLA_GPU'),
 PhysicalDevice(name='/physical_device:XLA_GPU:2', device_type='XLA_GPU')]

## RH, T-TNS Transformation

In [None]:
# Alternative network with a single flat output instead of 64 separate heads:
# seven 128-unit linear Dense layers, each followed by LeakyReLU(alpha=0.3), then
# one sigmoid Dense layer with 64 * (bin_size + 2) units.
# Note: this rebinds `inp`, `densout` and `model` from the earlier model cell.
inp = Input(shape=(64,))
densout = inp
for layer_idx in range(7):
    densout = Dense(128, activation='linear')(densout)
    densout = LeakyReLU(alpha=0.3)(densout)
densout = Dense(64 * (bin_size + 2), activation='sigmoid')(densout)
model = tf.keras.models.Model(inp, densout)