In [4]:
import numpy as np
import torch
import sympy
import torch.nn.functional as F
from backend.utils import *
import pickle
import configparser
import yaml
import torch.optim as optim
from sklearn.preprocessing import MinMaxScaler
from backend.lbfgsnew import *

This notebook trains the PDE-Net on the aquaplanet data (336 x 48 x 30, averaged over lat x lon).  
Specifically, this notebook aims to learn the equation for QBP.  
There are some changes relative to the vanilla PDE-Net (Notebook 02):  
- This incorporates independent channels (LHFLX, SHFLX)
- Dependent channels whose equation shouldn't be learned (TBP)
- Added option to train via LBFGS (more stable).
- Changes to add the other derivatives before taking a time step


In [9]:
# Open the preprocessed aquaplanet dataset for a quick look
# (absolute path; adjust it for the machine the notebook runs on).
ds = xr.open_dataset(
    '/oasis/scratch/comet/ankitesh/temp_project/PDEExp/data/preprocessed_pde_cluster_1.nc'
)

In [10]:
# Display the dataset opened above (rendered via the notebook's last-expression repr).
ds

In [5]:
## keep a num_channel_recoverable (for tacking LHFLX,SHFLX variables)
class ClimateData(DataModel):
    '''
        Data model for the aquaplanet columns used to train the PDE-Net.

        Every channel is min-max scaled with global dataset statistics; the
        lev-dependent channels are then linearly interpolated onto Nx equally
        spaced points in [0, Lx]. The single batch of columns
        [0, batch_size) is generated once in __init__ and reused by `data`.
    '''
    def __init__(self,data_name,Nx,Nt,dt,dx,batch_size,channel_names,\
                 non_eqn_depen_channel, non_eqn_channel, data_file,scaling=1,total_points=-1):
        '''
            Nx: the interpolated dimension
            Nt: number of time steps, max 240 (the middle Nt of the 336 stored steps are
                used, i.e. the 1st and the last day of the 7 days are skipped)
            dt: fixed at 1800 (can't change)
            dx: the lev tilde will be interpolated from 0 to Nx*dx (should be 1000/Nx)
            batch_size: number of columns per batch; as of now max is 48 (1 year data)
            channel_names: list of channels (variables) whose equation is learned,
                e.g. ['QBP'] (time x batch x lev)
            non_eqn_depen_channel: list of lev-dependent channels whose eqn shouldn't be
                discovered, but which the learned eqn depends on, e.g. ['TBP']
            non_eqn_channel: list of channels (time x batch, no lev dependence) whose eqn
                shouldn't be discovered, e.g. ['LHFLX','SHFLX']
            data_file: location of the xarray data
            scaling: factor by which the lev coordinates and dx (hence Lx) are scaled,
                just so that dx can be increased
            total_points: maximum number of columns to use (-1 = all of them)
        '''
        self.scaling = scaling
        dx = scaling*dx
        super(ClimateData,self).__init__(data_name=data_name,Nt=Nt, Nx=Nx, dt=dt, dx=dx)
        self.batch_size = batch_size
        self.ds = xr.open_dataset(data_file)
        self.channel_names = channel_names
        self.non_eqn_depen_channel = non_eqn_depen_channel
        self.n_non_eqn_depen = len(non_eqn_depen_channel)
        # channel -> name of its tendency variable; the value order (AP pair, BP pair,
        # CRM pair) is relied on when building dict_ in _generateData
        self.diff_dict = {"TAP":"DTV","QAP":"VD01","TBP":"TPHYSTND","QBP":"PHQ","TCRM":"TCDTAdiab","QCRM":"QCDTAdiab"}
        self.non_eqn_channel = non_eqn_channel
        self.n_eqn = len(channel_names)
        self.n_non_eqn = len(non_eqn_channel)
        if total_points==-1 or total_points > self.ds.batch_size.size :
            total = self.ds.batch_size.size
        else:
            total = total_points

        self.n_batch = total//batch_size
        self.curr_batch = 0
        self.sub, self.div = self._getGlobalStats()

        # this data model only ever uses the first batch of columns
        self.batch_start = 0
        self.batch_end = batch_size
        self._data,self._U_noneqn_dep,self._U_noneqn,self.dict_ = self._generateData()

    def _getGlobalStats(self):
        # Per-variable min and (max - min) over the whole dataset, for min-max scaling.
        return self.ds.min(),self.ds.max()-self.ds.min()

    def _loadVarSlice(self, var, start, end):
        # Raw values of `var` for time steps [start, end) and the current batch of columns.
        return self.ds[var][start:end,self.batch_start:self.batch_end].values

    def _minMaxScale(self, v, var):
        # Scale `v` with the global statistics of `var`.
        return (v-float(self.sub[var]))/float(self.div[var])

    def _interpToLevTilde(self, v, lev):
        # Interpolate every (time, column) profile of v (time x batch x 1 x lev)
        # from `lev` onto self.lev_tilde_after; result is time x batch x 1 x Nx.
        v_interp = np.zeros(v.shape[:-1]+(self.Nx,))
        for t in range(self.Nt):
            for b in range(v.shape[1]):
                v_interp[t][b][0] = np.interp(self.lev_tilde_after,lev,v[t][b][0])
        return v_interp

    @staticmethod
    def _stackChannels(arrays, like):
        # Concatenate along the channel axis (2). An empty list yields a zero-channel
        # array shaped like `like`, because np.concatenate rejects an empty sequence.
        if arrays:
            return np.concatenate(arrays,axis=2)
        return np.zeros(like.shape[:2]+(0,)+like.shape[3:])

    def _generateData(self):
        '''
            Returns (U, _U_noneqn_dep, _U_noneqn, dict_) for the current batch:
            U / _U_noneqn_dep / _U_noneqn are numpy arrays time x batch x channel x Nx,
            dict_ maps "AP", "BP", "CRM" to DoubleTensors of their two tendency channels.
        '''
        # use the middle Nt of the 336 stored time steps
        start = (336-self.Nt)//2
        end = start+self.Nt
        lev = self.scaling*self.ds.lev.values
        # target grid of the vertical interpolation
        self.lev_tilde_after = np.linspace(0,self.Lx,num=self.Nx)

        # learned and dependent channels: time x batch x 1 x lev
        eqn_vars = [self._minMaxScale(self._loadVarSlice(var,start,end)[:,:,np.newaxis,:],var)
                    for var in self.channel_names]
        dep_vars = [self._minMaxScale(self._loadVarSlice(var,start,end)[:,:,np.newaxis,:],var)
                    for var in self.non_eqn_depen_channel]

        # tendency terms; a variable with zero range is left unscaled
        diff_vars = []
        for var in self.diff_dict.values():
            v = self._loadVarSlice(var,start,end)[:,:,np.newaxis,:]
            if float(self.div[var]) != 0:
                v = self._minMaxScale(v,var)
            diff_vars.append(v)

        # independent channels (time x batch) repeated over all Nx levels
        non_eqn_vars = []
        for var in self.non_eqn_channel:
            v = self._loadVarSlice(var,start,end)[:,:,np.newaxis]
            v = np.tile(v,[1,1,self.Nx])[:,:,np.newaxis,:]
            non_eqn_vars.append(self._minMaxScale(v,var))

        U = np.concatenate([self._interpToLevTilde(v,lev) for v in eqn_vars],axis=2)
        diff_interp = [self._interpToLevTilde(v,lev) for v in diff_vars]

        # pairs follow diff_dict order: (TAP,QAP) -> AP, (TBP,QBP) -> BP, (TCRM,QCRM) -> CRM
        dict_ = {}
        for i,key in enumerate(["AP","BP","CRM"]):
            dict_[key] = torch.from_numpy(np.concatenate(diff_interp[2*i:2*i+2],axis=2)).type(torch.DoubleTensor)

        _U_noneqn = self._stackChannels(non_eqn_vars,U)
        _U_noneqn_dep = self._stackChannels([self._interpToLevTilde(v,lev) for v in dep_vars],U)
        return U,_U_noneqn_dep,_U_noneqn,dict_

    def data(self,step_num):
        '''
            Returns the first `step_num` time steps of (equation channels, dependent
            channels, independent channels as DoubleTensor) plus the tendency dict.
            The same pre-generated batch is returned on every call.
        '''
        return torch.from_numpy(self._data[:step_num]),torch.from_numpy(self._U_noneqn_dep[:step_num]),\
                torch.from_numpy(self._U_noneqn[:step_num]).type(torch.DoubleTensor),\
                self.dict_

    def visualize(self,b,subset=True):
        '''
            Plots a (time x X) surface of batch element `b` for every equation channel.
            `subset` is accepted for backward compatibility and ignored.
        '''
        x,t = self._getMeshPoints()
        X,T = np.meshgrid(x,t)
        for c,name in enumerate(self.channel_names):
            # rows: time steps, columns: interpolated levels
            surface = np.asarray(self._data[:,b,c,:])
            fig = plt.figure(figsize=(8,8))
            ax = fig.add_subplot(111, projection='3d')
            ax.plot_surface(T, X, surface)
            plt.xlabel("Time")
            plt.ylabel("X")
            plt.title(name)
            plt.show()
        
        


In [6]:
class SymNet(torch.nn.Module):
    '''
        Symbolic network. Hidden layer k is a Linear map to 2 values whose product is
        appended to the features; a final Linear layer maps all features to one output.
        The learned weights can be read back as a polynomial with `getEquation`.

        n_hidden: number of hidden (product) layers
        n_deriv_channel: number of input features
        deriv_channel_names: symbol names of the inputs (default u_0, u_1, ...)
        normalization_weight: unused
    '''
    def __init__(self, n_hidden, n_deriv_channel, deriv_channel_names=None, normalization_weight=None):
        super(SymNet, self).__init__()
        self.n_hidden = n_hidden
        self.n_deriv_channel = n_deriv_channel
        if deriv_channel_names is None:
            deriv_channel_names = ['u_'+str(i) for i in range(self.n_deriv_channel)]
        self.deriv_channel_names = deriv_channel_names
        layers = []
        for k in range(n_hidden):
            # hidden layer k sees the inputs plus the k products appended so far
            module = torch.nn.Linear(n_deriv_channel+k,2)
            self.add_module('layer'+str(k), module)
            layers.append(module)
        module = torch.nn.Linear(n_deriv_channel+n_hidden, 1)
        self.add_module('layer_final', module)
        layers.append(module)
        # plain tuple (not a ModuleList) so state_dict keys stay layer<k> / layer_final
        self.layers = tuple(layers)

    def forward(self, inputs):
        '''
            inputs: (..., n_deriv_channel); returns (...,), one value per input row.
        '''
        # Cast to the weights' dtype and device (instead of a fixed CPU FloatTensor)
        # so the net also works after .double() or when placed on a GPU.
        weight = self.layers[-1].weight
        outputs = inputs.to(device=weight.device, dtype=weight.dtype)
        for k in range(self.n_hidden):
            o = self.layers[k](outputs)
            # append the product of the two hidden outputs as a new feature
            outputs = torch.cat([outputs,o[...,:1]*o[...,1:]], dim=-1)
        outputs = self.layers[-1](outputs)
        return outputs[...,0]

    def _cast2symbol(self,layer):
        # Weight and bias of a Linear layer as sympy matrices.
        weight,bias = layer.weight.data.cpu().numpy(), \
                    layer.bias.data.cpu().numpy()
        weight,bias = sympy.Matrix(weight),sympy.Matrix(bias)
        return weight,bias

    def _sympychop(self,o, calprec):
        # Drop, in every entry of `o`, the terms whose |coefficient| <= 0.1**calprec.
        for i in range(o.shape[0]):
            cdict = o[i].expand().as_coefficients_dict()
            o_i = 0
            for k,v in cdict.items():
                if abs(v)>0.1**calprec:
                    o_i = o_i+k*v
            o[i] = o_i
        return o

    def getEquation(self,calprec=6):
        '''
            Returns the sympy expression computed by the network in terms of
            deriv_channel_names; after every layer, terms with
            |coefficient| <= 0.1**calprec are dropped.
        '''
        deriv_channels = sympy.symbols(self.deriv_channel_names)
        deriv_channels = sympy.Matrix([deriv_channels,])
        for i in range(self.n_hidden):
            weight,bias = self._cast2symbol(self.layers[i])
            o = weight*deriv_channels.transpose()+bias
            o = self._sympychop(o, calprec) #ignores very low params terms
            deriv_channels = list(deriv_channels)+[o[0]*o[1],]
            deriv_channels = sympy.Matrix([deriv_channels,])

        weight,bias = self._cast2symbol(self.layers[-1])
        o = (weight*deriv_channels.transpose()+bias)
        o = self._sympychop(o,calprec)

        return o[0]


In [7]:
class FD1D(torch.nn.Module):
    '''
        1-D finite-difference operator of order `diff_order`, applied with F.conv1d.

        constraint='free':   fixed kernel from getKernelTorch scaled by 1/dx**diff_order;
                             stored as a plain tensor, so it is not a trainable parameter.
        constraint='moment': trainable moment vector, kernel = M2K(moment)/dx**diff_order.
                             The gradient of the first diff_order+acc_order moments is
                             masked to zero, so gradient updates only touch the higher ones.
    '''

    def __init__(self,dx, kernel_size, diff_order,acc_order,constraint='free'):
        super(FD1D, self).__init__()
        if constraint not in ('free','moment'):
            # any other value would otherwise only fail later, in `kernel`
            raise ValueError("constraint must be 'free' or 'moment', got %r" % (constraint,))
        self.dx = dx
        self.kernel_size = kernel_size
        self.diff_order = diff_order
        self.acc_order = acc_order
        self.constraint = constraint
        if constraint=='free':
            self._kernel = (getKernelTorch(diff_order,acc_order+1,dim=kernel_size,scheme='central')/(dx**diff_order)).type(torch.DoubleTensor)

        if constraint=='moment':
            # moment <-> kernel converters
            self.m2k = M2K(kernel_size)
            self.k2m = K2M(kernel_size)
            # initial moments: 1 at index diff_order, 0 elsewhere
            moment = torch.DoubleTensor(kernel_size).zero_()
            moment[diff_order] = 1
            moment = moment.reshape(1,1,-1)
            self.moment = torch.nn.Parameter(moment) ## now weights will be updated on this
            self.gradient_mask = self._getGradientMask()
            # A hook must not modify `grad` in place; return a new, masked gradient.
            self.moment.register_hook(lambda grad: grad*self.gradient_mask)

    @property
    def kernel(self):
        # (1, 1, kernel_size) convolution kernel, already divided by dx**diff_order
        if self.constraint == 'moment':
            return self.m2k(self.moment)/(self.dx**self.diff_order)
        return self._kernel

    def _getGradientMask(self):
        # (1, 1, kernel_size) double mask: 0 for the first diff_order+acc_order moments,
        # 1 for the remaining ones.
        gradient_mask = torch.ones(self.kernel_size,dtype=torch.double)
        gradient_mask[:self.diff_order+self.acc_order] = 0
        return gradient_mask.reshape(1,1,-1)

    def forward(self,inputs):
        # inputs: batch_size x n_channels x x_dim (padded by padInputTorch before the conv)
        inp_padded = padInputTorch(inputs,self.diff_order,self.acc_order+1,dim=self.kernel_size)
        conv = F.conv1d(inp_padded.type(torch.DoubleTensor),self.kernel)
        return conv
        

In [33]:

class PdeNet(torch.nn.Module):
    '''
        PDE-Net for the climate columns.

        Finite-difference modules fd0..fd<max_diff_order> produce every derivative
        order of each learned channel and each dependent channel. Together with the
        independent (non-equation) channels, these feed one SymNet per learned channel,
        whose output is the learned right-hand side.
    '''
    def __init__(self,dt, dx, kernel_size, max_diff_order, n_channel,channel_names,
                 n_non_eqn_channels,non_eqn_channel_names,dependent_channels,
                 acc_order=2,n_hidden=2,\
                constraint='free'):
        super(PdeNet, self).__init__()
        self.dx, self.dt = dx, dt
        self.kernel_size = kernel_size
        self.max_diff_order = max_diff_order
        self.n_channel = n_channel
        self.channel_names = channel_names
        self.n_non_eqn_channels = n_non_eqn_channels
        self.non_eqn_channel_names = non_eqn_channel_names
        self.dependent_channels = dependent_channels
        self.n_dependent_channels = len(dependent_channels)
        self.n_hidden = n_hidden
        self.constraint = constraint

        # one accuracy order per derivative order 0..max_diff_order
        self.acc_order = acc_order if np.iterable(acc_order) else [acc_order,]*(self.max_diff_order+1)

        orders = range(max_diff_order+1)
        for order in orders:
            self.add_module('fd'+str(order), FD1D(dx,kernel_size,order,self.acc_order[order],constraint=constraint))

        # SymNet input names: <channel>_<order> for learned then dependent channels,
        # followed by the plain names of the independent (climate) channels
        self.derivative_channels = [ch+'_'+str(order)
                                    for ch in channel_names.split(',')+self.dependent_channels
                                    for order in orders] + list(non_eqn_channel_names)

        self.all_symnets = []
        for idx in range(self.n_channel):
            net = SymNet(n_hidden,len(self.derivative_channels), deriv_channel_names=self.derivative_channels)
            self.add_module("symnet_"+str(idx), net)
            self.all_symnets.append(net)

    @property
    def fds(self):
        # finite-difference modules in order fd0, fd1, ..., fd<max_diff_order>
        for order in range(self.max_diff_order+1):
            yield self.__getattr__('fd'+str(order))

    def multistep(self,inputs,non_eqn_depe,non_eqn_t,diff_values,step_num):
        '''
            Takes step_num forward-Euler steps of size dt:
            u <- u + dt*(RightHandItems(u) + diff_values[0] + diff_values[1]).
            The same non-equation inputs and diff_values are reused at every step.
        '''
        state = inputs
        for _ in range(step_num):
            rhs = self.RightHandItems(state,non_eqn_depe,non_eqn_t)
            state = state + self.dt*(rhs+diff_values[0]+diff_values[1]) #only for QBP
        return state

    def symNetParams(self):
        # all parameters of all SymNets, as a flat list
        return [p for net in self.all_symnets for p in net.parameters()]

    def diffParams(self):
        # all parameters of all finite-difference modules, as a flat list
        return [p for fd in self.fds for p in fd.parameters()]

    def RightHandItems(self,u,non_eqn_depe,non_eqn_t):
        '''
            u, non_eqn_depe, non_eqn_t: batch_size x n_channels x X_dim.
            Returns batch_size x n_channel x X_dim: one SymNet output per learned channel.
        '''
        fd_modules = list(self.fds)
        learned = u.split(1,dim=1)
        dependent = non_eqn_depe.split(1,dim=1)
        columns = [learned[ch] for ch in range(self.n_channel)] + \
                  [dependent[ch] for ch in range(self.n_dependent_channels)]
        # every derivative order of each column, column-major: (c0,d0),(c0,d1),...
        derivatives = [fd(col) for col in columns for fd in fd_modules]

        features = torch.cat([torch.cat(derivatives, dim=1), non_eqn_t], dim=1)
        # SymNets take batch_size x X_dim x n_features
        features = features.permute(0,2,1)
        return torch.cat([net(features).unsqueeze(1) for net in self.all_symnets], dim=1)

    def forward(self,inputs,non_eqn_depe,non_eqn_t,diff_values,step_num):
        '''
            inputs of shape batch_size x n_channels x X_dim
            step_num: number of dt steps to advance the inputs
        '''
        return self.multistep(inputs,non_eqn_depe,non_eqn_t,diff_values,step_num)

    

In [34]:
def symnetRegularizeLoss(model):
    '''
        Smoothed-L1 penalty on every SymNet parameter entry.

        With magnitude a = |w| and threshold s = 1e-2, each entry contributes
        0.5*a**2/s when a < s and a - s/2 otherwise. Returns the sum over all
        entries (the int 0 when model.symNetParams() is empty).
    '''
    threshold = 1e-2
    total = 0
    for param in model.symNetParams():
        magnitude = param.abs()
        below = (magnitude < threshold).to(magnitude)
        above = (magnitude >= threshold).to(magnitude)
        total = total + (below*0.5/threshold*magnitude**2).sum() + (above*(magnitude-threshold/2)).sum()
    return total

In [35]:
def momentRegularizeLoss(model):
    '''
        Smoothed-L1 penalty on the finite-difference parameters (model.diffParams()).

        Per entry, with a = |w| and s = 1e-2: quadratic 0.5*a**2/s below s,
        linear a - s/2 from s upward. Returns the total (the int 0 when
        there are no parameters).
    '''
    s = 1e-2
    loss = 0
    for weight in model.diffParams():
        mag = weight.abs()
        quad = ((mag < s).to(mag)*0.5/s*mag**2).sum()
        lin = ((mag >= s).to(mag)*(mag - s/2)).sum()
        loss = loss + quad + lin
    return loss

In [36]:
#global names are all the parameters
def modelLoss(model,u_obs,non_eqn_dep_obs,non_eqn_obs,diff_dict,config,block):
    '''
        Multi-step rollout loss, to be given to an optimizer.

        Starting from u_obs[0], the model advances one dt at a time for step_num steps
        (step_num = block, or 1 for the warm-up block 0). Each prediction is fed back
        as the next input and compared with u_obs[1..step_num].

        Inputs:
            u_obs: indexed by time, each (batch_size x n_channels x X_dim);
                needs step_num+1 entries
            non_eqn_dep_obs, non_eqn_obs: dependent / independent channels, indexed by time
            diff_dict: needs 'AP' and 'CRM' tensors indexed by time
            config: needs 'sparsity', 'momentsparsity' and 'dt'
            block: current block; 0 is the warm-up stage (no regularization)
        Returns:
            (loss, data_loss, symnet_loss, moment_loss, loss_mse); a NaN loss is
            replaced by inf.
    '''
    sparsity = config['sparsity']
    momentsparsity = config['momentsparsity']

    if block==0: #warmup
        sparsity = 0
        momentsparsity = 0
    step_num = block if block>=1 else 1
    dt = config['dt']
    data_loss = 0
    symnet_loss = symnetRegularizeLoss(model)
    moment_loss = momentRegularizeLoss(model)
    ut = u_obs[0]
    loss_mse = 0
    mse_loss = torch.nn.MSELoss()
    for steps in range(1,step_num+1):
        non_eqn_t = non_eqn_obs[steps-1]
        non_eqn_dep_t = non_eqn_dep_obs[steps-1]
        # last channel of each pair (presumably the Q tendency, for QBP).
        # NOTE(review): AP is taken at steps-1 but CRM at steps; confirm this offset is intended.
        diff_values = [diff_dict['AP'][steps-1][:,-1:,:],diff_dict['CRM'][steps][:,-1:,:]]
        ut_next_predicted = model(ut,non_eqn_dep_t,non_eqn_t,diff_values,step_num=1) #one step, only 1 channel (QBP)
        loss_mse_t = mse_loss(ut_next_predicted,u_obs[steps])
        loss_mse += loss_mse_t
        data_loss += (loss_mse_t/dt**2)/step_num
        ut = ut_next_predicted

    # use the local step_num; do not depend on a global defined by the training cell
    loss = data_loss+step_num*sparsity*symnet_loss+step_num*momentsparsity*moment_loss
    if torch.isnan(loss):
        # replace NaN by an inf loss so a line search can reject the step
        loss = (torch.ones(1,requires_grad=True)/torch.zeros(1)).to(loss)
    return loss,data_loss,symnet_loss,moment_loss,loss_mse


In [37]:
##modify channel names and length
def setenv(config): #return model and datamodel
    '''
        Builds everything a training run needs from `config`.

        Returns (model, data_model, callbacks):
            model: the PdeNet, whose dx is config['dx'] already multiplied by config['scaling']
            data_model: the ClimateData loader (it applies the scaling to dx itself)
            callbacks: currently always None
    '''
    model = PdeNet(
        dt=config['dt'],
        dx=config['dx']*config['scaling'],
        kernel_size=config['kernel_size'],
        max_diff_order=config['max_diff_order'],
        n_channel=config['n_channels'],
        channel_names=config['channel_names'],
        n_non_eqn_channels=config['n_non_eqn_channels'],
        non_eqn_channel_names=config['non_eqn_channels'],
        dependent_channels=config['dependent_channels'],
        acc_order=config['acc_order'],
        n_hidden=config['n_hidden_layers'],
        constraint=config['constraint'],
    )
    data_model = ClimateData(
        data_name=config['dataname'],
        Nx=config['Nx'],
        Nt=config['Nt'],
        dt=config['dt'],
        dx=config['dx'],
        batch_size=config['batch_size'],
        channel_names=config['channel_vars'],
        non_eqn_depen_channel=config['dependent_channels'],
        non_eqn_channel=config['non_eqn_channels'],
        data_file=config['data_file'],
        scaling=config['scaling'],
        total_points=config["total"],
    )
    return model, data_model, None

In [38]:
# Load the experiment configuration (hyperparameters, channel lists, paths).
# yaml.safe_load only builds plain Python types, never arbitrary objects.
with open("config_AP_climate.yaml", 'r') as stream:
    config = yaml.safe_load(stream)

In [39]:
# Show the loaded configuration.
config

{'name': 'Climate BP',
 'dt': 1800,
 'dx': 10,
 'scaling': 30,
 'blocks': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
 'kernel_size': 5,
 'max_diff_order': 2,
 'acc_order': 2,
 'n_hidden_layers': 2,
 'n_channels': 1,
 'n_non_eqn_channels': 4,
 'dataname': 'Climate BP',
 'batch_size': 12,
 'total': -1,
 'channel_names': 'QBP',
 'channel_vars': ['QBP'],
 'non_eqn_channels': ['LHFLX', 'SHFLX', 'PS', 'SOLIN'],
 'dependent_channels': ['TBP'],
 'data_file': '/Users/ankitesh/Desktop/RA_data/preprocessed_pde_cluster_1.nc',
 'Nt': 240,
 'Nx': 100,
 'sparsity': 0.005,
 'momentsparsity': 0.004,
 'epochs': 1000,
 'model_dir': '/Users/ankitesh/Desktop/RA_data/',
 'seed': -1,
 'learning_rate': 0.01,
 'constraint': 'moment',
 'optimizer': 'LBFGS'}

In [46]:
# Training hyperparameters pulled out of config
# (dt, dx and opti are not read by any later cell shown here).
blocks, dt, dx, epochs, lr, opti = (
    config[key] for key in ('blocks', 'dt', 'dx', 'epochs', 'learning_rate', 'optimizer')
)

In [47]:
# Build the model, the data model (which opens config['data_file']) and the callbacks (None).
model,data_model,callbacks = setenv(config)

In [48]:
def getOptimizer(config, net=None):
    '''
        Builds the optimizer named by config['optimizer'] ('Adam' or 'LBFGS').

        net: module whose parameters are optimized; defaults to the notebook-global
             `model` (the previous behaviour).
        The Adam learning rate comes from config['learning_rate'].
        Raises ValueError for any other optimizer name.
    '''
    if net is None:
        net = model
    name = config['optimizer']
    if name == 'Adam':
        return optim.Adam(net.parameters(), lr=config['learning_rate'])
    if name == 'LBFGS':
        return LBFGSNew(net.parameters(), history_size=7, max_iter=10, line_search_fn=True,batch_mode=True)
    raise ValueError("Unknown optimizer %r (expected 'Adam' or 'LBFGS')" % (name,))

In [49]:
##optimizer
# Optimizer selected by config['optimizer'], over the parameters of the global `model`.
optimizer = getOptimizer(config)
# ExponentialLR multiplies each param group's lr by decayRate on every scheduler.step().
# NOTE(review): whether LBFGSNew honours the param-group lr is not visible here; confirm.
decayRate = 0.96
my_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=decayRate)

In [192]:
for block in blocks:
    print('[PRINT] block:',block)
    if block==0:
        print('[PRINT] Warmum Stage')
    stepnum = block if block>=1 else 1
    #get the data at this time #shape [stepnum+1,batch,channel,X_dim]
    u_obs,non_eqn_dep_t,non_eqn_t,diff_dict = data_model.data(stepnum+1)

    def closure():
        # Called by optimizer.step (possibly several times per step):
        # recompute the loss and its gradients.
        optimizer.zero_grad()
        loss = modelLoss(model,u_obs,non_eqn_dep_t,non_eqn_t,diff_dict,config,block)[0]
        if loss.requires_grad:
            loss.backward()
        return loss

    def evaluate():
        # Loss terms for logging / the convergence check only; no backward pass needed.
        with torch.no_grad():
            return modelLoss(model,u_obs,non_eqn_dep_t,non_eqn_t,diff_dict,config,block)

    # `range(epochs)` is evaluated once, so bumping `epochs` inside a for-loop never
    # extended training; it only leaked into the budget of later blocks. Keep a
    # per-block budget and drive the loop with `while` instead.
    block_epochs = epochs
    epoch = 0
    while epoch < block_epochs:
        optimizer.step(closure)

        if epoch%10==0:
            loss,data_loss,syment_reg,moment_loss,loss_mse = evaluate()
            print("[PRINT] Epoch: %d, Loss: %.5f, Mse Loss: %.5f, Data Loss: %.5f, Symnet Regularize: %.5f, Moment Regularize: %.5f "\
                  % (epoch,loss,loss_mse,\
                      data_loss,syment_reg,\
                      moment_loss))
        # 20 epochs before the end, add 200 more if the loss is still above 2. This fires
        # again before every new end, so training continues in 200-epoch increments
        # until the loss is at most 2.
        if block_epochs - epoch == 20 and evaluate()[0] > 2:
            block_epochs += 200

        #save at every 500th epoch
        if epoch%500==0:
            name = "Block_"+str(block)+"_Epoch_"+str(epoch)+".pth"
            torch.save(model.state_dict(),config['model_dir']+name)
        epoch += 1

    my_lr_scheduler.step() #this doesn't matter for LBFGS
    name = "Block_"+str(block)+"_Epoch_"+str(block_epochs)+".pth"
    torch.save(model.state_dict(),config['model_dir']+name)

[PRINT] block: 0
[PRINT] Warmum Stage
[PRINT] Epoch: 0, Loss: 0.00778, Mse Loss: 25202.48338, Data Loss: 0.00778, Symnet Regularize: 9.66502, Moment Regularize: 3.00115 
[PRINT] Epoch: 10, Loss: 0.00254, Mse Loss: 8219.04743, Data Loss: 0.00254, Symnet Regularize: 18.77275, Moment Regularize: 4.81708 
[PRINT] Epoch: 20, Loss: 0.00192, Mse Loss: 6216.71566, Data Loss: 0.00192, Symnet Regularize: 20.41881, Moment Regularize: 5.02832 
[PRINT] Epoch: 30, Loss: 0.00185, Mse Loss: 5981.19687, Data Loss: 0.00185, Symnet Regularize: 21.73436, Moment Regularize: 5.56314 
[PRINT] Epoch: 40, Loss: 0.00181, Mse Loss: 5849.07664, Data Loss: 0.00181, Symnet Regularize: 22.13572, Moment Regularize: 5.61020 


KeyboardInterrupt: 

## Results on the original understanding

**ADAM**

In [61]:
# Load the Adam-trained checkpoint (block 4, epoch 4500; absolute local path) into `model`.
# torch.load unpickles the file: only load checkpoints from trusted sources.
params = torch.load('/Users/ankitesh/Desktop/RA_data/ADAM/Block_4_Epoch_4500.pth')
model.load_state_dict(params)

<All keys matched successfully>

In [65]:
# Learned equation of the first (only) SymNet; calprec=1 keeps terms with |coefficient| > 0.1.
model.symnet_0.getEquation(1)

0.154835*LHFLX - 0.1034*PS - 0.128181*TBP_0 - 0.251681*TBP_1 - 0.113844*TBP_2

**LBFGS**

In [53]:
# Load the LBFGS-trained checkpoint (block 10, epoch 1000; absolute local path) into `model`.
# torch.load unpickles the file: only load checkpoints from trusted sources.
params = torch.load('/Users/ankitesh/Desktop/RA_data/LBFGS/Block_10_Epoch_1000.pth')
model.load_state_dict(params)

<All keys matched successfully>

In [58]:
# Learned equation of the first (only) SymNet; calprec=3 keeps terms with |coefficient| > 0.001.
model.symnet_0.getEquation(3)

-0.00556218*LHFLX - 0.0059576*PS - 0.00110481*QBP_0 - 0.00109013*QBP_1 - 0.00583966*SHFLX - 0.0609475*SOLIN - 0.00803638*TBP_0 - 0.538447

## Batch Support

This is for larger datasets that are processed batch by batch, e.g. when we have (48*8192) time series.  
Notebook 05_ is the more detailed version.

In [45]:
## keep a num_channel_recoverable (for tacking LHFLX,SHFLX variables)
class ClimateData(DataModel):
    '''
        Batched data model for the aquaplanet columns used to train the PDE-Net.

        Every call to `data` loads the next batch of batch_size columns (cycling back
        to the first after n_batch batches), min-max scales every channel with global
        dataset statistics and interpolates the lev-dependent channels onto Nx equally
        spaced points in [0, Lx].
    '''
    def __init__(self,data_name,Nx,Nt,dt,dx,batch_size,channel_names,\
                 non_eqn_depen_channel, non_eqn_channel, data_file,scaling=1,total_points=-1):
        '''
            Nx: the interpolated dimension
            Nt: number of time steps, max 240 (the middle Nt of the 336 stored steps are
                used, i.e. the 1st and the last day of the 7 days are skipped)
            dt: fixed at 1800 (can't change)
            dx: the lev tilde will be interpolated from 0 to Nx*dx (should be 1000/Nx)
            batch_size: number of columns per batch
            channel_names: list of channels (variables) whose equation is learned, e.g. ['QBP']
            non_eqn_depen_channel: list of lev-dependent channels whose eqn shouldn't be
                discovered, but which the learned eqn depends on, e.g. ['TBP']
            non_eqn_channel: list of channels (time x batch, basically not dependent on lev)
                without a recoverable eqn, e.g. ['LHFLX','SHFLX']
            data_file: location of the xarray data
            scaling: factor by which the lev coordinates and dx (hence Lx) are scaled
            total_points: maximum number of columns to use (-1 = all of them)
        '''
        self.scaling = scaling
        dx = scaling*dx
        super(ClimateData,self).__init__(data_name=data_name,Nt=Nt, Nx=Nx, dt=dt, dx=dx)
        self.batch_size = batch_size
        self.ds = xr.open_dataset(data_file)
        self.channel_names = channel_names
        self.non_eqn_depen_channel = non_eqn_depen_channel
        self.n_non_eqn_depen = len(non_eqn_depen_channel)
        # channel -> name of its tendency variable; the value order (AP pair, BP pair,
        # CRM pair) is relied on when building dict_ in _generateData
        self.diff_dict = {"TAP":"DTV","QAP":"VD01","TBP":"TPHYSTND","QBP":"PHQ","TCRM":"TCDTAdiab","QCRM":"QCDTAdiab"}
        self.non_eqn_channel = non_eqn_channel
        self.n_eqn = len(channel_names)
        self.n_non_eqn = len(non_eqn_channel)
        if total_points==-1 or total_points > self.ds.batch_size.size :
            total = self.ds.batch_size.size
        else:
            total = total_points

        self.n_batch = total//batch_size
        self.curr_batch = 0
        self.sub, self.div = self._getGlobalStats()
        # first batch window; data() recomputes these before loading each batch
        self.batch_start = 0
        self.batch_end = batch_size

    def _getGlobalStats(self):
        # Per-variable min and (max - min) over the whole dataset, for min-max scaling.
        return self.ds.min(),self.ds.max()-self.ds.min()

    def _loadVarSlice(self, var, start, end):
        # Raw values of `var` for time steps [start, end) and the current batch of columns.
        return self.ds[var][start:end,self.batch_start:self.batch_end].values

    def _minMaxScale(self, v, var):
        # Scale `v` with the global statistics of `var`.
        return (v-float(self.sub[var]))/float(self.div[var])

    def _interpToLevTilde(self, v, lev):
        # Interpolate every (time, column) profile of v (time x batch x 1 x lev)
        # from `lev` onto self.lev_tilde_after; result is time x batch x 1 x Nx.
        v_interp = np.zeros(v.shape[:-1]+(self.Nx,))
        for t in range(self.Nt):
            for b in range(v.shape[1]):
                v_interp[t][b][0] = np.interp(self.lev_tilde_after,lev,v[t][b][0])
        return v_interp

    @staticmethod
    def _stackChannels(arrays, like):
        # Concatenate along the channel axis (2). An empty list yields a zero-channel
        # array shaped like `like`, because np.concatenate rejects an empty sequence.
        if arrays:
            return np.concatenate(arrays,axis=2)
        return np.zeros(like.shape[:2]+(0,)+like.shape[3:])

    def _generateData(self):
        '''
            Returns (U, _U_noneqn_dep, _U_noneqn, dict_) for the current batch:
            U / _U_noneqn_dep / _U_noneqn are numpy arrays time x batch x channel x Nx,
            dict_ maps "AP", "BP", "CRM" to DoubleTensors of their two tendency channels.
        '''
        # use the middle Nt of the 336 stored time steps
        start = (336-self.Nt)//2
        end = start+self.Nt
        lev = self.scaling*self.ds.lev.values
        # target grid of the vertical interpolation
        self.lev_tilde_after = np.linspace(0,self.Lx,num=self.Nx)

        # learned and dependent channels: time x batch x 1 x lev
        eqn_vars = [self._minMaxScale(self._loadVarSlice(var,start,end)[:,:,np.newaxis,:],var)
                    for var in self.channel_names]
        dep_vars = [self._minMaxScale(self._loadVarSlice(var,start,end)[:,:,np.newaxis,:],var)
                    for var in self.non_eqn_depen_channel]

        # tendency terms; a variable with zero range is left unscaled
        diff_vars = []
        for var in self.diff_dict.values():
            v = self._loadVarSlice(var,start,end)[:,:,np.newaxis,:]
            if float(self.div[var]) != 0:
                v = self._minMaxScale(v,var)
            diff_vars.append(v)

        # independent channels (time x batch) repeated over all Nx levels
        non_eqn_vars = []
        for var in self.non_eqn_channel:
            v = self._loadVarSlice(var,start,end)[:,:,np.newaxis]
            v = np.tile(v,[1,1,self.Nx])[:,:,np.newaxis,:]
            non_eqn_vars.append(self._minMaxScale(v,var))

        U = np.concatenate([self._interpToLevTilde(v,lev) for v in eqn_vars],axis=2)
        diff_interp = [self._interpToLevTilde(v,lev) for v in diff_vars]

        # pairs follow diff_dict order: (TAP,QAP) -> AP, (TBP,QBP) -> BP, (TCRM,QCRM) -> CRM
        dict_ = {}
        for i,key in enumerate(["AP","BP","CRM"]):
            dict_[key] = torch.from_numpy(np.concatenate(diff_interp[2*i:2*i+2],axis=2)).type(torch.DoubleTensor)

        _U_noneqn = self._stackChannels(non_eqn_vars,U)
        _U_noneqn_dep = self._stackChannels([self._interpToLevTilde(v,lev) for v in dep_vars],U)
        return U,_U_noneqn_dep,_U_noneqn,dict_

    def data(self,step_num):
        '''
            Loads the next batch of columns (wrapping to the first after n_batch batches)
            and returns its first `step_num` time steps of (equation channels, dependent
            channels, independent channels as DoubleTensor) plus the tendency dict.
        '''
        self.batch_start = self.curr_batch*self.batch_size
        self.batch_end = self.batch_start + self.batch_size
        self.curr_batch += 1
        self._data,self._U_noneqn_dep,self._U_noneqn,self.dict_ = self._generateData()
        if self.curr_batch >= self.n_batch:
            self.curr_batch = 0
        return torch.from_numpy(self._data[:step_num]),torch.from_numpy(self._U_noneqn_dep[:step_num]),\
                torch.from_numpy(self._U_noneqn[:step_num]).type(torch.DoubleTensor),\
                self.dict_

    def visualize(self,b,subset=True):
        '''
            Plots a (time x X) surface of element `b` of the last loaded batch for every
            equation channel. `subset` is accepted for backward compatibility and ignored.
        '''
        if getattr(self,'_data',None) is None:
            raise RuntimeError("no batch loaded yet: call data() before visualize()")
        x,t = self._getMeshPoints()
        X,T = np.meshgrid(x,t)
        for c,name in enumerate(self.channel_names):
            # rows: time steps, columns: interpolated levels
            surface = np.asarray(self._data[:,b,c,:])
            fig = plt.figure(figsize=(8,8))
            ax = fig.add_subplot(111, projection='3d')
            ax.plot_surface(T, X, surface)
            plt.xlabel("Time")
            plt.ylabel("X")
            plt.title(name)
            plt.show()
        
        


In [50]:
# Mini-batch training: every epoch sweeps over all n_batch batches of columns.
if data_model.n_batch < 1:
    # otherwise the batch loop never runs and logging silently reuses stale state
    raise ValueError("data_model.n_batch is 0: batch_size exceeds the number of available columns")

def closure():
    # Called by optimizer.step: recompute loss and gradients on the current batch
    # (reads the globals u_obs, non_eqn_dep_t, non_eqn_t, diff_dict and block).
    optimizer.zero_grad()
    loss = modelLoss(model,u_obs,non_eqn_dep_t,non_eqn_t,diff_dict,config,block)[0]
    if loss.requires_grad:
        loss.backward()
    return loss

def evaluate():
    # Loss terms on the current batch for logging only; no backward pass needed.
    with torch.no_grad():
        return modelLoss(model,u_obs,non_eqn_dep_t,non_eqn_t,diff_dict,config,block)

for block in blocks:
    print('[PRINT] block:',block)
    if block==0:
        print('[PRINT] Warmum Stage')
    stepnum = block if block>=1 else 1

    for epoch in range(epochs):
        for b in range(data_model.n_batch):
            # with the batching ClimateData, data() loads the next batch of columns;
            # shape [stepnum+1,batch,channel,X_dim]
            u_obs,non_eqn_dep_t,non_eqn_t,diff_dict = data_model.data(stepnum+1)
            optimizer.step(closure)

        if epoch%10==0:
            # logged on the last batch of the epoch
            loss,data_loss,syment_reg,moment_loss,loss_mse = evaluate()
            print("[PRINT] Epoch: %d, Loss: %.3f, Mse Loss: %.3f, Data Loss: %.3f, Symnet Regularize: %.3f, Moment Regularize: %.3f "\
                  % (epoch,loss,loss_mse,\
                      data_loss,syment_reg,\
                      moment_loss))

        if epoch%500==0:
            name = "Block_"+str(block)+"_Epoch_"+str(epoch)+".pth"
            torch.save(model.state_dict(),config['model_dir']+name)
    name = "Block_"+str(block)+"_Epoch_"+str(epochs)+".pth"
    torch.save(model.state_dict(),config['model_dir']+name)

[PRINT] block: 0
[PRINT] Warmum Stage


	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha)


[PRINT] Epoch: 0, Batch: 0, Loss: 0.008, Mse Loss: 26110.972, Data Loss: 0.008, Symnet Regularize: 8.368, Moment Regularize: 2.997 
[PRINT] Epoch: 0, Loss: 0.006, Mse Loss: 18481.919, Data Loss: 0.006, Symnet Regularize: 9.857, Moment Regularize: 3.496 
[PRINT] Epoch: 1, Batch: 0, Loss: 0.005, Mse Loss: 15626.780, Data Loss: 0.005, Symnet Regularize: 11.558, Moment Regularize: 3.678 
[PRINT] Epoch: 2, Batch: 0, Loss: 0.004, Mse Loss: 12166.871, Data Loss: 0.004, Symnet Regularize: 12.240, Moment Regularize: 3.725 
[PRINT] Epoch: 3, Batch: 0, Loss: 0.004, Mse Loss: 11747.032, Data Loss: 0.004, Symnet Regularize: 12.450, Moment Regularize: 3.896 
[PRINT] Epoch: 4, Batch: 0, Loss: 0.003, Mse Loss: 8202.262, Data Loss: 0.003, Symnet Regularize: 16.786, Moment Regularize: 4.762 
[PRINT] Epoch: 5, Batch: 0, Loss: 0.002, Mse Loss: 6879.208, Data Loss: 0.002, Symnet Regularize: 17.914, Moment Regularize: 4.825 


KeyboardInterrupt: 