In [1]:
import numpy as np
import torch
import sympy
import torch.nn.functional as F
from utils import *
import pickle
import configparser
import yaml
import torch.optim as optim

## SymNet

In [290]:
class SymNet(torch.nn.Module):
    '''
        Symbolic network: a stack of small linear layers whose pairwise products
        are appended as new channels, followed by a final linear read-out.

        Hidden layer k maps the current (n_deriv_channel + k) channels to 2 values
        and appends their product as one extra channel; the final layer maps the
        (n_deriv_channel + n_hidden) channels to a single output.
    '''
    def __init__(self, n_hidden, n_deriv_channel, deriv_channel_names=None, normalization_weight=None):
        '''
        Input:
            n_hidden = number of product (hidden) layers
            n_deriv_channel = number of input channels (e.g. u, u_x, u_xx)
            deriv_channel_names = symbol names used by getEquation;
                                  defaults to 'u_0', 'u_1', ...
            normalization_weight = accepted for interface compatibility, currently unused
        '''
        super(SymNet, self).__init__()
        self.n_hidden = n_hidden
        self.n_deriv_channel = n_deriv_channel
        if deriv_channel_names is None:
            deriv_channel_names = ['u_'+str(i) for i in range(self.n_deriv_channel)]
        self.deriv_channel_names = deriv_channel_names
        layers = []
        for k in range(n_hidden):
            # each hidden layer sees the original channels plus the k products built so far
            module = torch.nn.Linear(n_deriv_channel+k, 2)
            self.add_module('layer'+str(k), module)
            layers.append(module)
        module = torch.nn.Linear(n_deriv_channel+n_hidden, 1)
        self.add_module('layer_final', module)
        layers.append(module)
        # ordered view of the registered layers (hidden layers first, read-out last)
        self.layers = tuple(layers)

    def forward(self, inputs):
        '''
            inputs: tensor whose last dimension has n_deriv_channel entries,
                    e.g. (batch_size x X_dim x n_derivatives)
            output: tensor with the last dimension removed, e.g. (batch_size x X_dim)
        '''
        # .float() casts to float32 without moving the tensor off its current device
        # (the previous .type(torch.FloatTensor) always produced a CPU tensor)
        outputs = inputs.float()
        for k in range(self.n_hidden):
            o = self.layers[k](outputs)
            # append the product of the two layer outputs as a new channel
            outputs = torch.cat([outputs, o[..., :1]*o[..., 1:]], dim=-1)
        outputs = self.layers[-1](outputs)

        return outputs[..., 0]

    def _cast2symbol(self, layer):
        '''Return the layer's weight and bias as sympy matrices.'''
        weight = sympy.Matrix(layer.weight.data.cpu().numpy())
        bias = sympy.Matrix(layer.bias.data.cpu().numpy())
        return weight, bias

    def _sympychop(self, o, calprec):
        '''
            Drop, in place, every term of each entry of `o` whose coefficient
            magnitude is not greater than 0.1**calprec. Returns `o`.
        '''
        threshold = 0.1**calprec
        for i in range(o.shape[0]):
            cdict = o[i].expand().as_coefficients_dict()
            o_i = 0
            for k, v in cdict.items():
                if abs(v) > threshold:
                    o_i = o_i+k*v
            o[i] = o_i
        return o

    def getEquation(self, calprec=6):
        '''
            Build the symbolic expression computed by the network from its current
            weights, using deriv_channel_names as the input symbols.

            calprec: terms with |coefficient| <= 0.1**calprec are discarded after
                     every layer.
            Returns a sympy expression.
        '''
        deriv_channels = sympy.symbols(self.deriv_channel_names)
        deriv_channels = sympy.Matrix([deriv_channels, ])
        for i in range(self.n_hidden):
            weight, bias = self._cast2symbol(self.layers[i])
            o = weight*deriv_channels.transpose()+bias
            o = self._sympychop(o, calprec)  # ignores very low params terms
            # mirror forward(): the product of the two outputs becomes a new channel
            deriv_channels = list(deriv_channels)+[o[0]*o[1], ]
            deriv_channels = sympy.Matrix([deriv_channels, ])

        weight, bias = self._cast2symbol(self.layers[-1])
        o = (weight*deriv_channels.transpose()+bias)
        # use the method: a bare `_sympychop` only exists if a later cell has been run
        o = self._sympychop(o, calprec)

        return o[0]


### Testing

In [193]:
s = SymNet(2,4)

In [194]:
inp = [[[1,2,3,4],[1,2,3,4],[1,2,3,4]]]
inp = torch.from_numpy(np.array(inp)).float()

In [195]:
s(inp)

tensor([[-0.6280, -0.6280, -0.6280]], grad_fn=<SelectBackward>)

In [197]:
s.getEquation(1)

0.38415*u_0 - 0.196855*u_1 - 0.38974*u_2 + 0.378879

## Finite Differences 

In [198]:
class FD1D(torch.nn.Module):
    '''
        One fixed finite-difference stencil for a 1D grid.

        dx          : grid spacing; the stencil is divided by dx**diff_order
        kernel_size : passed to the utils helpers as `dim`
        diff_order  : order of the derivative this stencil approximates (0, 1, ...)
        acc_order   : accuracy order requested from getKernelTorch

        The stencil is built once in the constructor with the 'central' scheme and
        kept as a plain attribute (`self.kernel`).
    '''
    def __init__(self,dx, kernel_size, diff_order,acc_order):
        super().__init__()
        self.dx, self.kernel_size = dx, kernel_size
        self.diff_order, self.acc_order = diff_order, acc_order

        # raw stencil from utils, rescaled by the grid spacing
        stencil = getKernelTorch(diff_order, acc_order, dim=kernel_size, scheme='central')
        self.kernel = stencil/(dx**diff_order)

    def forward(self,inputs):
        '''
            Pad the input with padInputTorch, then convolve it with the stencil.
            inputs: batch_size x n_channels x x_dim
            Returns the result of F.conv1d on the padded input.
        '''
        # batch_size x n_channels x (x_dim + padding)
        padded = padInputTorch(inputs, self.diff_order, self.acc_order, dim=self.kernel_size)
        return F.conv1d(padded, self.kernel)
        

### Testing 

In [8]:
inp = torch.randn(1,1,40)
inp.shape

torch.Size([1, 1, 40])

In [9]:
fd = FD1D(1,5,2,2) #fix nans in accuracy order > 2
fd(inp)

tensor([[[-8.1643, -3.3742,  1.4160,  0.6854, -0.1082, -1.6104,  1.3949,
           1.3783, -0.8669, -2.5061,  1.4144,  1.5445, -2.0005,  2.8164,
          -4.1300,  2.9963, -1.1669, -0.5869,  1.4000,  0.7138, -0.4725,
          -1.6408,  0.4569, -1.0700,  4.9659, -3.9201, -0.6370,  2.2561,
          -0.1850, -3.4091,  2.7100,  0.1704, -0.0895,  0.8708, -0.1929,
          -2.8485,  2.6325, -2.1210,  1.8284,  5.7779]]])

## PDE Net

In [199]:
class PdeNet(torch.nn.Module):
    '''
        PDE-Net: fixed finite-difference kernels (orders 0..max_diff_order) feeding a
        SymNet that predicts the right-hand side, integrated with forward-Euler steps.
    '''
    def __init__(self,dt, dx, kernel_size, max_diff_order, n_channel,channel_names,acc_order=2,n_hidden=2):
        '''
        Input:
            dt, dx         : time step and grid spacing
            kernel_size    : forwarded to every FD1D kernel
            max_diff_order : highest derivative order fed to the SymNet
            n_channel      : stored on the instance
            channel_names  : comma-separated channel names, e.g. 'u'
            acc_order      : one accuracy order for all kernels, or one per derivative order
            n_hidden       : number of hidden layers of the SymNet
        '''
        super().__init__()
        self.dx = dx
        self.dt = dt
        self.kernel_size = kernel_size
        self.max_diff_order = max_diff_order
        self.n_channel = n_channel
        self.channel_names = channel_names
        self.n_hidden = n_hidden

        # a scalar accuracy order is broadcast to every derivative order
        if not np.iterable(acc_order):
            acc_order = [acc_order,]*(self.max_diff_order+1)
        self.acc_order = acc_order

        # one finite-difference module per derivative order: fd0, fd1, ...
        for order in range(max_diff_order+1):
            self.add_module('fd'+str(order), FD1D(dx,kernel_size,order,acc_order[order]))

        # symbol names for the SymNet inputs: '<channel>_<order>'
        self.derivative_channels = [
            name+'_'+str(order)
            for name in channel_names.split(',')
            for order in range(max_diff_order+1)
        ]
        self.add_module("symnet",SymNet(n_hidden,len(self.derivative_channels), deriv_channel_names=self.derivative_channels))

    def multistep(self,inputs,step_num):
        '''
        Advance `inputs` by `step_num` forward-Euler steps of size dt.
        '''
        state = inputs
        for _ in range(step_num):
            state = state + self.dt*self.RightHandItems(state)
        return state

    def RightHandItems(self,u):
        '''
        Evaluate the learned right-hand side for state u (batch_size x n_channels x X_dim).
        Returns a tensor of shape batch_size x 1 x X_dim.
        '''
        # apply every finite-difference kernel to u
        derivatives = [getattr(self, 'fd'+str(order))(u)
                       for order in range(self.max_diff_order+1)]
        stacked = torch.cat(derivatives, dim=1) #batch_size x n_derivatives x X_dim
        # the SymNet expects the derivative axis last: batch_size x X_dim x n_derivatives
        rhs = self.symnet(stacked.permute(0,2,1)) #batch_size x X_dim
        return rhs.unsqueeze_(1)

    def forward(self,inputs,step_num):
        '''
            inputs of shape batch_size x n_channels(1 for our case) x X_dim
            step_num = number of dt steps to advance the inputs by
        '''
        return self.multistep(inputs,step_num)

    

### Testing

In [11]:
net = PdeNet(dt=0.01,dx=0.1,kernel_size=5,max_diff_order=2,n_channel=1,channel_names='u')
print(net)

PdeNet(
  (fd0): FD1D()
  (fd1): FD1D()
  (fd2): FD1D()
  (symnet): SymNet(
    (layer0): Linear(in_features=3, out_features=2, bias=True)
    (layer1): Linear(in_features=4, out_features=2, bias=True)
    (layer_final): Linear(in_features=5, out_features=1, bias=True)
  )
)


**Run it for one step from the initial state**

In [12]:
inp = torch.randn(1,1,40)
inp.shape

torch.Size([1, 1, 40])

In [20]:
temp = net(inp,step_num=1)
print(temp)

tensor([[[ 1.0171e+01, -2.4517e-01,  8.6392e-01, -2.7246e-01,  3.5923e-01,
          -9.0346e-01,  7.2720e+00,  1.3810e+02,  5.8666e+00, -5.5084e-01,
           4.7074e-01,  2.1184e-01,  3.1020e+00, -1.2126e+00,  1.2576e+00,
          -4.8390e-01,  6.9350e-01,  8.6608e-01, -1.5373e-01,  4.9069e+01,
           3.5265e+01, -1.2972e+00,  1.5328e+00,  1.4518e+00, -2.0000e-01,
           8.5266e-01,  2.7885e-01, -9.9466e-01,  1.0931e+00, -5.1715e-01,
           1.5044e+01, -6.7751e-01, -2.0942e-01, -1.7663e-01, -4.8952e-02,
           8.6972e-01,  7.7066e-01,  4.3546e-01,  1.0066e+00,  2.3007e+00]]],
       grad_fn=<AddBackward0>)


In [21]:
net.state_dict()

OrderedDict([('symnet.layer0.weight', tensor([[ 0.2226,  0.2604,  0.1286],
                      [ 0.2489,  0.0043, -0.1339]])),
             ('symnet.layer0.bias', tensor([-0.4999, -0.3734])),
             ('symnet.layer1.weight',
              tensor([[-0.3535, -0.4990,  0.1757, -0.2431],
                      [-0.3812, -0.0333, -0.1424,  0.4230]])),
             ('symnet.layer1.bias', tensor([-0.2737, -0.4254])),
             ('symnet.layer_final.weight',
              tensor([[ 0.0192,  0.0190, -0.3789,  0.2718, -0.0036]])),
             ('symnet.layer_final.bias', tensor([-0.4418]))])

## Loss functions

We use two losses: a data loss (prediction error against the observed snapshots) and a SymNet regularization loss (a sparsity penalty on the SymNet parameters).

In [22]:
def symnetRegularizeLoss(model):
    '''
        Huber-style penalty summed over every parameter of model.symnet:
        quadratic (0.5/s * p**2) where |p| < s, linear (|p| - s/2) elsewhere, with s = 1e-2.
    '''
    threshold = 1e-2
    total = 0
    for param in model.symnet.parameters():
        magnitude = param.abs()
        # masks are cast to the parameter's dtype/device so they can scale each branch
        below = (magnitude<threshold).to(magnitude)
        above = (magnitude>=threshold).to(magnitude)
        quadratic_part = (below*0.5/threshold*magnitude**2).sum()
        linear_part = (above*(magnitude-threshold/2)).sum()
        total = total+quadratic_part+linear_part
    return total

In [200]:
def modelLoss(model,u_obs,config,block):
    '''
        Returns the loss value so that it can be given to an optimizer.
        Inputs:
            model  : callable as model(u, step_num=1), returning the state one dt later
            u_obs  : sequence of observed states, u_obs[k] of shape
                     (batch_size x n_channels x X_dim); needs at least step_num+1 entries
            config : dict providing 'sparsity' and 'dt'
            block  : number of steps to unroll; 0 means warm-up (one step, no sparsity term)
        Returns:
            (loss, data_loss, symnet_loss) where
            loss = data_loss + step_num*sparsity*symnet_loss
        Raises:
            ValueError if the total loss is NaN.
    '''
    sparsity = config['sparsity']

    if block==0: #warmup
        sparsity = 0
    step_num = block if block>=1 else 1
    dt = config['dt']
    data_loss = 0
    symnet_loss = symnetRegularizeLoss(model)
    ut = u_obs[0]
    mse_loss = torch.nn.MSELoss()
    for steps in range(1,step_num+1):
        # roll the prediction forward one step and compare with the matching observation
        ut = model(ut,step_num=1)
        data_loss = data_loss + (mse_loss(ut,u_obs[steps])/dt**2)/step_num

    # use the local step count; the previous code read a global `stepnum`
    # that only existed after the training cell had been run
    loss = data_loss+step_num*sparsity*symnet_loss
    if torch.isnan(loss):
        # raising a plain string is itself a TypeError in Python 3
        raise ValueError("Loss Nan")
    return loss,data_loss,symnet_loss


## Utility Functions

In [133]:
##modify channel names and length
def setenv(config): #return model and datamodel
    '''
        Build the PdeNet and the matching data generator from a config dict.

        The data generator is chosen by substring of config['name']:
        'Diffusion' -> DiffusionDataTrign, 'Burgers' -> BurgersEqnTrign
        (if both substrings are present the Burgers generator is returned).

        Returns:
            (model, data_model, callbacks); callbacks is currently always None.
        Raises:
            ValueError if config['name'] matches neither generator.
    '''
    # the network is built with a single channel (n_channel=1)
    model = PdeNet(config['dt'],config['dx'],config['kernel_size'],config['max_diff_order'],
                   1,config['channel_names'],config['acc_order'],config['n_hidden_layers'])

    # both generators take the same arguments
    data_args = (config['name'],config['Nt'],config['dt'],config['dx'],config['viscosity'])
    data_kwargs = dict(batch_size=config['batch_size'],
                       time_scheme=config['data_timescheme'],
                       acc_order=config['acc_order'])

    data_model = None
    if 'Diffusion' in config['name']:
        data_model = DiffusionDataTrign(*data_args,**data_kwargs)
    if 'Burgers' in config['name']:
        data_model = BurgersEqnTrign(*data_args,**data_kwargs)
    if data_model is None:
        # previously this fell through to an UnboundLocalError at the return statement
        raise ValueError("Unknown equation name %r: expected it to contain "
                         "'Diffusion' or 'Burgers'" % (config['name'],))

    #possible some callbacks
    callbacks = None
    return model,data_model,callbacks

# Training

## Diffusion Eqn

In [230]:
with open("config.yaml", 'r') as stream:
    config = yaml.safe_load(stream)

In [231]:
config

{'name': 'Diffusion Eqn',
 'dt': 0.01,
 'dx': 0.1,
 'blocks': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
 'kernel_size': 5,
 'max_diff_order': 2,
 'acc_order': 2,
 'n_hidden_layers': 2,
 'dataname': 'Diffusion',
 'viscosity': 0.1,
 'batch_size': 32,
 'channel_names': 'u',
 'data_timescheme': 'rk4',
 'data_dir': '/sdsd/dsds/sdsd',
 'Nt': 100,
 'Nx': 32,
 'sigma': 1,
 'sparsity': 0.05,
 'epochs': 500,
 'results_dir': '/comet/results/',
 'seed': -1,
 'learning_rate': 0.005}

In [232]:
blocks = config['blocks']
dt = config['dt']
dx = config['dx']
epochs = config['epochs']
lr = config['learning_rate']

In [233]:
model,data_model,callbacks = setenv(config)

In [234]:
##optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)

In [235]:
# Train block by block: each entry of `blocks` sets how many dt-steps are unrolled
# in the loss; block 0 is the warm-up stage and is trained on a single step.
for block in blocks:
    print('[PRINT] block:',block)
    if block==0:
        print('[PRINT] Warmum Stage')
    # number of time steps used for this block (the warm-up block uses one)
    stepnum = block if block>=1 else 1
    #get the data at this time #shape [block,batch,channel,X_dim]
    # requests stepnum+1 snapshots -- presumably the initial state plus one target
    # per step; confirm against data_model.data in utils
    u_obs = data_model.data(stepnum+1) #np array of stepnum elements
    # the same u_obs is reused for every epoch of this block
    for epoch in range(epochs):
        #zero grad
        optimizer.zero_grad()
        #forward
        loss,data_loss,syment_reg = modelLoss(model,u_obs,config,block)
        # backward pass and parameter update
        loss.backward()
        optimizer.step()
        # report progress every 10 epochs
        if epoch%10==0:
            print("[PRINT] Epoch: %d, Loss: %.3f, Data Loss: %.3f, Symnet Regularize: %.3f" % (epoch,loss,\
                                                                                              data_loss,syment_reg))

    
# put prints 

[PRINT] block: 0
[PRINT] Warmum Stage
[PRINT] Epoch: 0, Loss: 6.244, Data Loss: 6.244, Symnet Regularize: 5.285
[PRINT] Epoch: 10, Loss: 3.458, Data Loss: 3.458, Symnet Regularize: 5.422
[PRINT] Epoch: 20, Loss: 1.961, Data Loss: 1.961, Symnet Regularize: 5.304
[PRINT] Epoch: 30, Loss: 1.124, Data Loss: 1.124, Symnet Regularize: 5.233
[PRINT] Epoch: 40, Loss: 0.693, Data Loss: 0.693, Symnet Regularize: 5.135
[PRINT] Epoch: 50, Loss: 0.453, Data Loss: 0.453, Symnet Regularize: 5.091
[PRINT] Epoch: 60, Loss: 0.310, Data Loss: 0.310, Symnet Regularize: 5.075
[PRINT] Epoch: 70, Loss: 0.218, Data Loss: 0.218, Symnet Regularize: 5.062
[PRINT] Epoch: 80, Loss: 0.152, Data Loss: 0.152, Symnet Regularize: 5.068
[PRINT] Epoch: 90, Loss: 0.106, Data Loss: 0.106, Symnet Regularize: 5.079
[PRINT] Epoch: 100, Loss: 0.073, Data Loss: 0.073, Symnet Regularize: 5.108
[PRINT] Epoch: 110, Loss: 0.051, Data Loss: 0.051, Symnet Regularize: 5.151
[PRINT] Epoch: 120, Loss: 0.035, Data Loss: 0.035, Symnet Reg

[PRINT] Epoch: 90, Loss: 0.035, Data Loss: 0.002, Symnet Regularize: 0.326
[PRINT] Epoch: 100, Loss: 0.032, Data Loss: 0.002, Symnet Regularize: 0.295
[PRINT] Epoch: 110, Loss: 0.029, Data Loss: 0.002, Symnet Regularize: 0.265
[PRINT] Epoch: 120, Loss: 0.026, Data Loss: 0.002, Symnet Regularize: 0.234
[PRINT] Epoch: 130, Loss: 0.022, Data Loss: 0.002, Symnet Regularize: 0.203
[PRINT] Epoch: 140, Loss: 0.019, Data Loss: 0.002, Symnet Regularize: 0.172
[PRINT] Epoch: 150, Loss: 0.016, Data Loss: 0.002, Symnet Regularize: 0.142
[PRINT] Epoch: 160, Loss: 0.015, Data Loss: 0.002, Symnet Regularize: 0.125
[PRINT] Epoch: 170, Loss: 0.013, Data Loss: 0.002, Symnet Regularize: 0.109
[PRINT] Epoch: 180, Loss: 0.012, Data Loss: 0.002, Symnet Regularize: 0.094
[PRINT] Epoch: 190, Loss: 0.011, Data Loss: 0.002, Symnet Regularize: 0.094
[PRINT] Epoch: 200, Loss: 0.011, Data Loss: 0.002, Symnet Regularize: 0.093
[PRINT] Epoch: 210, Loss: 0.011, Data Loss: 0.002, Symnet Regularize: 0.093
[PRINT] Epoch

[PRINT] Epoch: 170, Loss: 0.026, Data Loss: 0.008, Symnet Regularize: 0.093
[PRINT] Epoch: 180, Loss: 0.026, Data Loss: 0.008, Symnet Regularize: 0.093
[PRINT] Epoch: 190, Loss: 0.026, Data Loss: 0.008, Symnet Regularize: 0.093
[PRINT] Epoch: 200, Loss: 0.026, Data Loss: 0.008, Symnet Regularize: 0.093
[PRINT] Epoch: 210, Loss: 0.026, Data Loss: 0.008, Symnet Regularize: 0.093
[PRINT] Epoch: 220, Loss: 0.026, Data Loss: 0.008, Symnet Regularize: 0.093
[PRINT] Epoch: 230, Loss: 0.026, Data Loss: 0.008, Symnet Regularize: 0.093
[PRINT] Epoch: 240, Loss: 0.026, Data Loss: 0.008, Symnet Regularize: 0.093
[PRINT] Epoch: 250, Loss: 0.026, Data Loss: 0.008, Symnet Regularize: 0.093
[PRINT] Epoch: 260, Loss: 0.026, Data Loss: 0.008, Symnet Regularize: 0.093
[PRINT] Epoch: 270, Loss: 0.026, Data Loss: 0.008, Symnet Regularize: 0.093
[PRINT] Epoch: 280, Loss: 0.026, Data Loss: 0.008, Symnet Regularize: 0.093
[PRINT] Epoch: 290, Loss: 0.026, Data Loss: 0.008, Symnet Regularize: 0.093
[PRINT] Epoc

[PRINT] Epoch: 250, Loss: 0.047, Data Loss: 0.019, Symnet Regularize: 0.092
[PRINT] Epoch: 260, Loss: 0.047, Data Loss: 0.019, Symnet Regularize: 0.092
[PRINT] Epoch: 270, Loss: 0.047, Data Loss: 0.019, Symnet Regularize: 0.092
[PRINT] Epoch: 280, Loss: 0.047, Data Loss: 0.019, Symnet Regularize: 0.092
[PRINT] Epoch: 290, Loss: 0.047, Data Loss: 0.019, Symnet Regularize: 0.092
[PRINT] Epoch: 300, Loss: 0.047, Data Loss: 0.019, Symnet Regularize: 0.092
[PRINT] Epoch: 310, Loss: 0.047, Data Loss: 0.019, Symnet Regularize: 0.092
[PRINT] Epoch: 320, Loss: 0.047, Data Loss: 0.019, Symnet Regularize: 0.092
[PRINT] Epoch: 330, Loss: 0.047, Data Loss: 0.019, Symnet Regularize: 0.092
[PRINT] Epoch: 340, Loss: 0.047, Data Loss: 0.019, Symnet Regularize: 0.092
[PRINT] Epoch: 350, Loss: 0.047, Data Loss: 0.019, Symnet Regularize: 0.092
[PRINT] Epoch: 360, Loss: 0.047, Data Loss: 0.019, Symnet Regularize: 0.092
[PRINT] Epoch: 370, Loss: 0.047, Data Loss: 0.019, Symnet Regularize: 0.092
[PRINT] Epoc

[PRINT] Epoch: 330, Loss: 0.075, Data Loss: 0.038, Symnet Regularize: 0.092
[PRINT] Epoch: 340, Loss: 0.075, Data Loss: 0.038, Symnet Regularize: 0.092
[PRINT] Epoch: 350, Loss: 0.075, Data Loss: 0.038, Symnet Regularize: 0.092
[PRINT] Epoch: 360, Loss: 0.075, Data Loss: 0.038, Symnet Regularize: 0.092
[PRINT] Epoch: 370, Loss: 0.075, Data Loss: 0.038, Symnet Regularize: 0.092
[PRINT] Epoch: 380, Loss: 0.075, Data Loss: 0.038, Symnet Regularize: 0.093
[PRINT] Epoch: 390, Loss: 0.075, Data Loss: 0.038, Symnet Regularize: 0.092
[PRINT] Epoch: 400, Loss: 0.075, Data Loss: 0.038, Symnet Regularize: 0.092
[PRINT] Epoch: 410, Loss: 0.075, Data Loss: 0.038, Symnet Regularize: 0.092
[PRINT] Epoch: 420, Loss: 0.075, Data Loss: 0.038, Symnet Regularize: 0.092
[PRINT] Epoch: 430, Loss: 0.075, Data Loss: 0.038, Symnet Regularize: 0.092
[PRINT] Epoch: 440, Loss: 0.075, Data Loss: 0.038, Symnet Regularize: 0.092
[PRINT] Epoch: 450, Loss: 0.075, Data Loss: 0.038, Symnet Regularize: 0.092
[PRINT] Epoc

KeyboardInterrupt: 

In [239]:
model.symnet.getEquation(calprec=3)

-0.00228215*u_0 + 0.0968245*u_2 + 0.00166485

In [242]:
model.state_dict()

OrderedDict([('symnet.layer0.weight',
              tensor([[ 2.2683e-04, -4.9045e-44, -2.8026e-45],
                      [-2.8026e-45, -2.8026e-45,  2.8026e-45]])),
             ('symnet.layer0.bias', tensor([ 4.2039e-44, -1.9478e-43])),
             ('symnet.layer1.weight',
              tensor([[ 8.0529e-05,  2.8026e-45,  5.6052e-45, -2.8026e-45],
                      [-1.2234e-34, -9.1084e-44, -8.4078e-45,  1.4013e-45]])),
             ('symnet.layer1.bias', tensor([ 2.7142e-08, -7.2868e-44])),
             ('symnet.layer_final.weight',
              tensor([[-2.2822e-03,  9.0344e-04,  9.6824e-02,  1.4013e-45, -5.6108e-42]])),
             ('symnet.layer_final.bias', tensor([0.0017]))])

In [240]:
torch.save(model.state_dict(),"Diffusion_model.pth")

## Burger's Equation

In [222]:
with open("config_burgers.yaml", 'r') as stream:
    config = yaml.safe_load(stream)

In [223]:
config

{'name': 'Burgers Eqn',
 'dt': 0.01,
 'dx': 0.1,
 'blocks': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
 'kernel_size': 5,
 'max_diff_order': 2,
 'acc_order': 2,
 'n_hidden_layers': 2,
 'dataname': 'Diffusion',
 'viscosity': 0.1,
 'batch_size': 32,
 'channel_names': 'u',
 'data_timescheme': 'rk4',
 'data_dir': '/sdsd/dsds/sdsd',
 'Nt': 100,
 'Nx': 32,
 'sigma': 1,
 'sparsity': 0.05,
 'epochs': 500,
 'results_dir': '/comet/results/',
 'seed': -1,
 'learning_rate': 0.005}

In [224]:
model,data_model,callbacks = setenv(config)

In [206]:
##optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)

In [41]:
# Same block-wise training schedule as the Diffusion section, applied to the Burgers
# model. `blocks` is the list unpacked from the Diffusion config earlier in the notebook.
for block in blocks:
    print('[PRINT] block:',block)
    if block==0:
        print('[PRINT] Warmum Stage')
    # number of time steps used for this block (the warm-up block uses one)
    stepnum = block if block>=1 else 1
    #get the data at this time #shape [block,batch,channel,X_dim]
    # requests stepnum+1 snapshots -- presumably the initial state plus one target
    # per step; confirm against data_model.data in utils
    u_obs = data_model.data(stepnum+1) #np array of stepnum elements
    # NOTE: the epoch count is hard-coded to 1000 here instead of using `epochs`
    for epoch in range(1000):
        #zero grad
        optimizer.zero_grad()
        #forward
        loss,data_loss,syment_reg = modelLoss(model,u_obs,config,block)
        # backward pass and parameter update
        loss.backward()
        optimizer.step()
        # report progress every 10 epochs
        if epoch%10==0:
            print("[PRINT] Epoch: %d, Loss: %.3f, Data Loss: %.3f, Symnet Regularize: %.3f" % (epoch,loss,\
                                                                                              data_loss,syment_reg))

    
# put prints 

[PRINT] block: 0
[PRINT] Warmum Stage
[PRINT] Epoch: 0, Loss: 77.414, Data Loss: 77.414, Symnet Regularize: 6.930
[PRINT] Epoch: 10, Loss: 21.137, Data Loss: 21.137, Symnet Regularize: 6.679
[PRINT] Epoch: 20, Loss: 12.704, Data Loss: 12.704, Symnet Regularize: 6.479
[PRINT] Epoch: 30, Loss: 9.997, Data Loss: 9.997, Symnet Regularize: 6.351
[PRINT] Epoch: 40, Loss: 9.146, Data Loss: 9.146, Symnet Regularize: 6.274
[PRINT] Epoch: 50, Loss: 8.112, Data Loss: 8.112, Symnet Regularize: 6.305
[PRINT] Epoch: 60, Loss: 7.464, Data Loss: 7.464, Symnet Regularize: 6.350
[PRINT] Epoch: 70, Loss: 6.999, Data Loss: 6.999, Symnet Regularize: 6.389
[PRINT] Epoch: 80, Loss: 6.634, Data Loss: 6.634, Symnet Regularize: 6.429
[PRINT] Epoch: 90, Loss: 6.337, Data Loss: 6.337, Symnet Regularize: 6.475
[PRINT] Epoch: 100, Loss: 6.092, Data Loss: 6.092, Symnet Regularize: 6.521
[PRINT] Epoch: 110, Loss: 5.877, Data Loss: 5.877, Symnet Regularize: 6.569
[PRINT] Epoch: 120, Loss: 5.683, Data Loss: 5.683, Symn

[PRINT] Epoch: 80, Loss: 0.393, Data Loss: 0.012, Symnet Regularize: 7.620
[PRINT] Epoch: 90, Loss: 0.390, Data Loss: 0.012, Symnet Regularize: 7.571
[PRINT] Epoch: 100, Loss: 0.388, Data Loss: 0.012, Symnet Regularize: 7.521
[PRINT] Epoch: 110, Loss: 0.385, Data Loss: 0.011, Symnet Regularize: 7.472
[PRINT] Epoch: 120, Loss: 0.382, Data Loss: 0.011, Symnet Regularize: 7.422
[PRINT] Epoch: 130, Loss: 0.380, Data Loss: 0.011, Symnet Regularize: 7.372
[PRINT] Epoch: 140, Loss: 0.377, Data Loss: 0.011, Symnet Regularize: 7.322
[PRINT] Epoch: 150, Loss: 0.374, Data Loss: 0.011, Symnet Regularize: 7.271
[PRINT] Epoch: 160, Loss: 0.371, Data Loss: 0.010, Symnet Regularize: 7.220
[PRINT] Epoch: 170, Loss: 0.369, Data Loss: 0.010, Symnet Regularize: 7.169
[PRINT] Epoch: 180, Loss: 0.366, Data Loss: 0.010, Symnet Regularize: 7.119
[PRINT] Epoch: 190, Loss: 0.364, Data Loss: 0.010, Symnet Regularize: 7.072
[PRINT] Epoch: 200, Loss: 0.361, Data Loss: 0.010, Symnet Regularize: 7.027
[PRINT] Epoch:

[PRINT] Epoch: 160, Loss: 0.505, Data Loss: 0.023, Symnet Regularize: 4.821
[PRINT] Epoch: 170, Loss: 0.504, Data Loss: 0.023, Symnet Regularize: 4.802
[PRINT] Epoch: 180, Loss: 0.502, Data Loss: 0.023, Symnet Regularize: 4.782
[PRINT] Epoch: 190, Loss: 0.500, Data Loss: 0.023, Symnet Regularize: 4.763
[PRINT] Epoch: 200, Loss: 0.498, Data Loss: 0.023, Symnet Regularize: 4.744
[PRINT] Epoch: 210, Loss: 0.496, Data Loss: 0.024, Symnet Regularize: 4.726
[PRINT] Epoch: 220, Loss: 0.494, Data Loss: 0.024, Symnet Regularize: 4.708
[PRINT] Epoch: 230, Loss: 0.493, Data Loss: 0.024, Symnet Regularize: 4.690
[PRINT] Epoch: 240, Loss: 0.491, Data Loss: 0.024, Symnet Regularize: 4.676
[PRINT] Epoch: 250, Loss: 0.490, Data Loss: 0.024, Symnet Regularize: 4.666
[PRINT] Epoch: 260, Loss: 0.490, Data Loss: 0.024, Symnet Regularize: 4.656
[PRINT] Epoch: 270, Loss: 0.489, Data Loss: 0.024, Symnet Regularize: 4.646
[PRINT] Epoch: 280, Loss: 0.488, Data Loss: 0.024, Symnet Regularize: 4.636
[PRINT] Epoc

[PRINT] Epoch: 240, Loss: 0.614, Data Loss: 0.070, Symnet Regularize: 3.627
[PRINT] Epoch: 250, Loss: 0.613, Data Loss: 0.070, Symnet Regularize: 3.622
[PRINT] Epoch: 260, Loss: 0.613, Data Loss: 0.070, Symnet Regularize: 3.618
[PRINT] Epoch: 270, Loss: 0.613, Data Loss: 0.071, Symnet Regularize: 3.613
[PRINT] Epoch: 280, Loss: 0.612, Data Loss: 0.071, Symnet Regularize: 3.609
[PRINT] Epoch: 290, Loss: 0.612, Data Loss: 0.071, Symnet Regularize: 3.605
[PRINT] Epoch: 300, Loss: 0.611, Data Loss: 0.071, Symnet Regularize: 3.600
[PRINT] Epoch: 310, Loss: 0.611, Data Loss: 0.072, Symnet Regularize: 3.596
[PRINT] Epoch: 320, Loss: 0.611, Data Loss: 0.072, Symnet Regularize: 3.591
[PRINT] Epoch: 330, Loss: 0.610, Data Loss: 0.072, Symnet Regularize: 3.587
[PRINT] Epoch: 340, Loss: 0.610, Data Loss: 0.073, Symnet Regularize: 3.582
[PRINT] Epoch: 350, Loss: 0.609, Data Loss: 0.073, Symnet Regularize: 3.578
[PRINT] Epoch: 360, Loss: 0.609, Data Loss: 0.073, Symnet Regularize: 3.573
[PRINT] Epoc

[PRINT] Epoch: 320, Loss: 0.762, Data Loss: 0.149, Symnet Regularize: 3.064
[PRINT] Epoch: 330, Loss: 0.762, Data Loss: 0.149, Symnet Regularize: 3.064
[PRINT] Epoch: 340, Loss: 0.762, Data Loss: 0.149, Symnet Regularize: 3.064
[PRINT] Epoch: 350, Loss: 0.762, Data Loss: 0.149, Symnet Regularize: 3.064
[PRINT] Epoch: 360, Loss: 0.762, Data Loss: 0.149, Symnet Regularize: 3.064
[PRINT] Epoch: 370, Loss: 0.762, Data Loss: 0.149, Symnet Regularize: 3.064
[PRINT] Epoch: 380, Loss: 0.762, Data Loss: 0.149, Symnet Regularize: 3.064
[PRINT] Epoch: 390, Loss: 0.762, Data Loss: 0.149, Symnet Regularize: 3.064
[PRINT] Epoch: 400, Loss: 0.762, Data Loss: 0.149, Symnet Regularize: 3.064
[PRINT] Epoch: 410, Loss: 0.762, Data Loss: 0.149, Symnet Regularize: 3.064
[PRINT] Epoch: 420, Loss: 0.762, Data Loss: 0.149, Symnet Regularize: 3.064
[PRINT] Epoch: 430, Loss: 0.762, Data Loss: 0.149, Symnet Regularize: 3.064
[PRINT] Epoch: 440, Loss: 0.762, Data Loss: 0.149, Symnet Regularize: 3.064
[PRINT] Epoc

[PRINT] Epoch: 400, Loss: 0.990, Data Loss: 0.224, Symnet Regularize: 3.064
[PRINT] Epoch: 410, Loss: 0.990, Data Loss: 0.224, Symnet Regularize: 3.064
[PRINT] Epoch: 420, Loss: 0.990, Data Loss: 0.224, Symnet Regularize: 3.064
[PRINT] Epoch: 430, Loss: 0.990, Data Loss: 0.224, Symnet Regularize: 3.064
[PRINT] Epoch: 440, Loss: 0.990, Data Loss: 0.224, Symnet Regularize: 3.064
[PRINT] Epoch: 450, Loss: 0.990, Data Loss: 0.224, Symnet Regularize: 3.064
[PRINT] Epoch: 460, Loss: 0.990, Data Loss: 0.224, Symnet Regularize: 3.064
[PRINT] Epoch: 470, Loss: 0.990, Data Loss: 0.224, Symnet Regularize: 3.064
[PRINT] Epoch: 480, Loss: 0.990, Data Loss: 0.224, Symnet Regularize: 3.064
[PRINT] Epoch: 490, Loss: 0.990, Data Loss: 0.224, Symnet Regularize: 3.064
[PRINT] Epoch: 500, Loss: 0.990, Data Loss: 0.224, Symnet Regularize: 3.064
[PRINT] Epoch: 510, Loss: 0.990, Data Loss: 0.224, Symnet Regularize: 3.064
[PRINT] Epoch: 520, Loss: 0.990, Data Loss: 0.224, Symnet Regularize: 3.064
[PRINT] Epoc

[PRINT] Epoch: 480, Loss: 1.242, Data Loss: 0.323, Symnet Regularize: 3.063
[PRINT] Epoch: 490, Loss: 1.242, Data Loss: 0.323, Symnet Regularize: 3.063
[PRINT] Epoch: 500, Loss: 1.242, Data Loss: 0.323, Symnet Regularize: 3.063
[PRINT] Epoch: 510, Loss: 1.242, Data Loss: 0.323, Symnet Regularize: 3.063
[PRINT] Epoch: 520, Loss: 1.242, Data Loss: 0.323, Symnet Regularize: 3.063
[PRINT] Epoch: 530, Loss: 1.242, Data Loss: 0.323, Symnet Regularize: 3.063
[PRINT] Epoch: 540, Loss: 1.242, Data Loss: 0.323, Symnet Regularize: 3.063
[PRINT] Epoch: 550, Loss: 1.242, Data Loss: 0.323, Symnet Regularize: 3.063
[PRINT] Epoch: 560, Loss: 1.242, Data Loss: 0.323, Symnet Regularize: 3.063
[PRINT] Epoch: 570, Loss: 1.242, Data Loss: 0.323, Symnet Regularize: 3.063
[PRINT] Epoch: 580, Loss: 1.242, Data Loss: 0.323, Symnet Regularize: 3.063
[PRINT] Epoch: 590, Loss: 1.242, Data Loss: 0.323, Symnet Regularize: 3.063
[PRINT] Epoch: 600, Loss: 1.242, Data Loss: 0.323, Symnet Regularize: 3.063
[PRINT] Epoc

[PRINT] Epoch: 560, Loss: 1.527, Data Loss: 0.455, Symnet Regularize: 3.062
[PRINT] Epoch: 570, Loss: 1.527, Data Loss: 0.455, Symnet Regularize: 3.062
[PRINT] Epoch: 580, Loss: 1.527, Data Loss: 0.455, Symnet Regularize: 3.062
[PRINT] Epoch: 590, Loss: 1.527, Data Loss: 0.455, Symnet Regularize: 3.062
[PRINT] Epoch: 600, Loss: 1.527, Data Loss: 0.455, Symnet Regularize: 3.062
[PRINT] Epoch: 610, Loss: 1.527, Data Loss: 0.455, Symnet Regularize: 3.062
[PRINT] Epoch: 620, Loss: 1.527, Data Loss: 0.455, Symnet Regularize: 3.062
[PRINT] Epoch: 630, Loss: 1.527, Data Loss: 0.455, Symnet Regularize: 3.062
[PRINT] Epoch: 640, Loss: 1.527, Data Loss: 0.455, Symnet Regularize: 3.062
[PRINT] Epoch: 650, Loss: 1.527, Data Loss: 0.455, Symnet Regularize: 3.062
[PRINT] Epoch: 660, Loss: 1.527, Data Loss: 0.455, Symnet Regularize: 3.062
[PRINT] Epoch: 670, Loss: 1.527, Data Loss: 0.455, Symnet Regularize: 3.062
[PRINT] Epoch: 680, Loss: 1.527, Data Loss: 0.455, Symnet Regularize: 3.062
[PRINT] Epoc

[PRINT] Epoch: 640, Loss: 1.858, Data Loss: 0.633, Symnet Regularize: 3.061
[PRINT] Epoch: 650, Loss: 1.858, Data Loss: 0.633, Symnet Regularize: 3.061
[PRINT] Epoch: 660, Loss: 1.858, Data Loss: 0.633, Symnet Regularize: 3.061
[PRINT] Epoch: 670, Loss: 1.858, Data Loss: 0.633, Symnet Regularize: 3.061
[PRINT] Epoch: 680, Loss: 1.858, Data Loss: 0.633, Symnet Regularize: 3.061
[PRINT] Epoch: 690, Loss: 1.858, Data Loss: 0.633, Symnet Regularize: 3.061
[PRINT] Epoch: 700, Loss: 1.858, Data Loss: 0.633, Symnet Regularize: 3.061
[PRINT] Epoch: 710, Loss: 1.858, Data Loss: 0.633, Symnet Regularize: 3.061
[PRINT] Epoch: 720, Loss: 1.858, Data Loss: 0.633, Symnet Regularize: 3.061
[PRINT] Epoch: 730, Loss: 1.858, Data Loss: 0.633, Symnet Regularize: 3.061
[PRINT] Epoch: 740, Loss: 1.858, Data Loss: 0.633, Symnet Regularize: 3.061
[PRINT] Epoch: 750, Loss: 1.858, Data Loss: 0.633, Symnet Regularize: 3.061
[PRINT] Epoch: 760, Loss: 1.858, Data Loss: 0.633, Symnet Regularize: 3.061
[PRINT] Epoc

[PRINT] Epoch: 720, Loss: 2.252, Data Loss: 0.875, Symnet Regularize: 3.060
[PRINT] Epoch: 730, Loss: 2.252, Data Loss: 0.875, Symnet Regularize: 3.060
[PRINT] Epoch: 740, Loss: 2.252, Data Loss: 0.875, Symnet Regularize: 3.060
[PRINT] Epoch: 750, Loss: 2.252, Data Loss: 0.875, Symnet Regularize: 3.060
[PRINT] Epoch: 760, Loss: 2.252, Data Loss: 0.875, Symnet Regularize: 3.060
[PRINT] Epoch: 770, Loss: 2.252, Data Loss: 0.875, Symnet Regularize: 3.060
[PRINT] Epoch: 780, Loss: 2.252, Data Loss: 0.875, Symnet Regularize: 3.060
[PRINT] Epoch: 790, Loss: 2.252, Data Loss: 0.875, Symnet Regularize: 3.060
[PRINT] Epoch: 800, Loss: 2.252, Data Loss: 0.875, Symnet Regularize: 3.060
[PRINT] Epoch: 810, Loss: 2.252, Data Loss: 0.875, Symnet Regularize: 3.060
[PRINT] Epoch: 820, Loss: 2.252, Data Loss: 0.875, Symnet Regularize: 3.060
[PRINT] Epoch: 830, Loss: 2.252, Data Loss: 0.875, Symnet Regularize: 3.060
[PRINT] Epoch: 840, Loss: 2.252, Data Loss: 0.875, Symnet Regularize: 3.060
[PRINT] Epoc

KeyboardInterrupt: 

In [243]:
weights = torch.load('burgers_weight.pth')
model.load_state_dict(weights)

<All keys matched successfully>

In [246]:
model.symnet.getEquation(calprec=3)

-0.974806*u_0*u_1 - 0.00462748*u_0*u_2 - 0.00480364*u_0 - 0.00430956*u_1**2 + 0.100894*u_2 - 0.00235616

In [247]:
model.state_dict()

OrderedDict([('symnet.layer0.weight',
              tensor([[-3.1466e-41,  4.3728e-38, -2.7045e-40],
                      [ 7.1718e-42,  1.4293e-42, -1.9632e-42]])),
             ('symnet.layer0.bias', tensor([1.8190e-28, 1.3788e-13])),
             ('symnet.layer1.weight',
              tensor([[ 9.9174e-01,  4.3844e-03,  6.0919e-04,  2.6048e-36],
                      [-7.0374e-04,  9.9059e-01,  4.7024e-03,  5.1386e-34]])),
             ('symnet.layer1.bias', tensor([1.2057e-05, 2.3986e-03])),
             ('symnet.layer_final.weight',
              tensor([[-2.4433e-03, -5.5761e-05,  1.0089e-01,  9.9048e-35, -9.9226e-01]])),
             ('symnet.layer_final.bias', tensor([-0.0024]))])

## Some Additional Stuff

In [None]:
w1 = model.symnet.layer0.weight.data.cpu().numpy()
w1
w1_d = sympy.Matrix(w1)
w1_d

In [180]:

def _cast2symbol(layer):
    weight,bias = layer.weight.data.cpu().numpy(), \
                layer.bias.data.cpu().numpy()
    weight,bias = sympy.Matrix(weight),sympy.Matrix(bias)
    return weight,bias

def _sympychop(o, calprec):
    for i in range(o.shape[0]):
        cdict = o[i].expand().as_coefficients_dict()  
        o_i = 0
        for k,v in cdict.items():
            if abs(v)>0.1**calprec:
                o_i = o_i+k*v
        o[i] = o_i
    return o

def getEquation(model,calprec=6):
    """Return the symbolic expression represented by a SymNet-style model.

    Arguments:
        model: object exposing `n_hidden`, `layers` (hidden layers followed by
            the final layer) and the input channel names. The names are read
            from `model.channel_names` when that attribute exists, otherwise
            from `model.deriv_channel_names` (the attribute SymNet sets itself).
        calprec: terms whose |coefficient| is not above 0.1**calprec are
            dropped after every layer.

    Returns:
        The sympy expression produced by the final layer (first output entry).
    """
    # Prefer an explicitly attached `channel_names`; fall back to SymNet's own
    # attribute so the function also works on an unpatched SymNet.
    names = getattr(model, 'channel_names', None)
    if names is None:
        names = model.deriv_channel_names

    deriv_channels = sympy.symbols(names)
    deriv_channels = sympy.Matrix([deriv_channels,])
    for i in range(model.n_hidden):
        weight,bias = _cast2symbol(model.layers[i])
        o = weight*deriv_channels.transpose()+bias
        o = _sympychop(o, calprec) # drop terms with negligible coefficients
        # each hidden layer adds one feature: the product of its two outputs
        deriv_channels = list(deriv_channels)+[o[0]*o[1],]
        deriv_channels = sympy.Matrix([deriv_channels,])

    weight,bias = _cast2symbol(model.layers[-1])
    o = (weight*deriv_channels.transpose()+bias)
    o = _sympychop(o,calprec)

    return o[0]


In [181]:
getEquation(model.symnet,2)

-0.974806*u*u_x + 0.100894*u_xx

In [177]:
# The free function getEquation reads `channel_names`, so attach readable names to the SymNet
# (this also changes the symbols that appear in the printed equation).
model.symnet.channel_names = ['u','u_x','u_xx']

In [248]:
getKernel(1)

array([-0.5,  0. ,  0.5])

## Finite Differences With Moment

In [272]:
class FD1DMoment(torch.nn.Module):
    '''
        Finite-difference scheme for a 1D dependency: one convolution kernel
        approximating the derivative of order `diff_order`.

        Arguments:
            dx: grid spacing; the kernel is divided by dx**diff_order
            kernel_size: number of kernel taps. It should be greater than or equal
                to the derivative order (otherwise an error will be thrown); in
                "moment" mode it must also be >= diff_order+acc_order, the
                number of frozen moments
            diff_order: derivative order, starting from 0,1,...
            acc_order: accuracy order for this derivative order
            constraint: "free" or "moment"
                "free":   the kernel is built once with getKernelTorch and stored as
                          a plain tensor, so it is not returned by parameters()
                "moment": the kernel is derived from a learnable moment vector
                          through M2K; the first diff_order+acc_order moments are
                          kept fixed by masking their gradient
                Any other value leaves the module without a kernel.
    '''
    def __init__(self,dx, kernel_size, diff_order,acc_order,constraint='free'):
        super(FD1DMoment, self).__init__()
        self.dx = dx
        self.kernel_size = kernel_size
        self.diff_order = diff_order
        self.acc_order = acc_order
        self.constraint = constraint
        ## only used when constraint == 'free'
        if constraint=='free':
            self._kernel = (getKernelTorch(diff_order,acc_order+1,dim=kernel_size,scheme='central')/(dx**diff_order)).type(torch.DoubleTensor)

        ## moment <-> kernel conversion, only used when constraint == 'moment'
        if constraint=='moment':
            self.m2k = M2K(kernel_size)
            self.k2m = K2M(kernel_size)
            # initial moment vector: 1 at the derivative order, 0 elsewhere
            moment = torch.DoubleTensor(kernel_size).zero_()
            moment[diff_order] = 1
            moment = moment.reshape(1,1,-1)
            self.moment = torch.nn.Parameter(moment) ## the optimizer updates this
            ## mask that zeroes the gradient of the frozen moments
            self.gradient_mask = self._getGradientMask()
            ## A tensor hook must not modify its argument; return a new, masked
            ## gradient instead of multiplying the incoming one in place.
            self.moment.register_hook(lambda grad: grad * self.gradient_mask)


    @property
    def kernel(self):
        '''Kernel used by forward(): rebuilt from the moments on every access in
        "moment" mode, the stored tensor otherwise.'''
        if self.constraint == 'moment':
            kernel = self.m2k(self.moment)/(self.dx**self.diff_order)
        else:
            kernel = self._kernel
        return kernel

    def _getGradientMask(self):
        '''Return a (1,1,kernel_size) double mask that is 0 for the first
        diff_order+acc_order moments and 1 for the rest.

        Raises IndexError when diff_order+acc_order exceeds kernel_size.
        '''
        gradient_mask = torch.ones(self.kernel_size,dtype=torch.double)
        # element-wise assignment (not a slice) so an oversized order still raises
        for j in range(self.diff_order+self.acc_order):
            gradient_mask[j] = 0
        return gradient_mask.reshape(1,1,-1)

    def forward(self,inputs):
        '''
            Pad the input, then apply conv1d with the current kernel.
            input shape can be batch_size x n_channels x x_dim
        '''
        inp_padded = padInputTorch(inputs,self.diff_order,self.acc_order+1,dim=self.kernel_size) #batch_size x n_channels x (x_dim+padded)
        conv = F.conv1d(inp_padded.type(torch.DoubleTensor),self.kernel)
        return conv
        

In [191]:
# Random test signal; FD1DMoment.forward documents its input as batch_size x n_channels x x_dim.
inp = torch.randn(1,1,40)
inp.shape

torch.Size([1, 1, 40])

In [192]:
# dx=1, kernel_size=3, diff_order=1, acc_order=2, with the moment-constrained kernel.
fd = FD1DMoment(1,3,1,2,constraint='moment') # TODO: fix NaNs when accuracy order > 2
fd(inp)

tensor([[[-0.6544,  0.3960, -0.4965,  0.1895, -0.2412, -0.3051,  0.6687,
          -0.2338, -0.2105,  0.6703, -0.7742, -0.7891,  1.2694, -0.5749,
          -0.6582,  0.9995,  0.7799,  0.5030,  0.2730, -0.6519, -0.5912,
           0.0077, -0.6355, -0.7861,  0.5252, -0.4820, -0.4739,  0.5875,
           0.6501, -0.1029, -0.8475, -0.6698,  0.9173,  1.5599, -0.7804,
           0.5820, -0.0037, -1.2849,  0.8986,  1.6743]]], dtype=torch.float64,
       grad_fn=<SqueezeBackward1>)

In [196]:
list(fd.parameters())

[Parameter containing:
 tensor([[[0., 1., 0.]]], dtype=torch.float64, requires_grad=True)]

In [124]:
# Same positional values with constraint='free', for comparison.
# NOTE(review): `FD1D` is not defined in the visible part of this notebook; presumably an
# earlier class or a utils export. Confirm it exists on a fresh kernel.
fd = FD1D(1,3,1,2,constraint='free') # TODO: fix NaNs when accuracy order > 2
fd(inp)

tensor([[[-5.2352e-02,  7.8054e-03,  3.2451e-02,  4.2015e-01,  4.5650e-02,
           3.2077e-02,  2.6063e-01, -8.3255e-01,  3.2625e-01,  1.6310e+00,
           3.0084e-01, -1.3547e+00, -7.3103e-01,  6.2955e-01,  1.3862e-01,
          -5.9324e-02, -3.9002e-01,  2.9395e-01,  8.2532e-01,  2.5238e-01,
           5.8228e-01, -1.3900e+00, -1.3439e+00,  4.2385e-01,  2.6559e-01,
           1.2791e+00,  3.8773e-01, -2.0274e+00, -6.3823e-01,  2.4442e-01,
           2.9626e-01,  8.5085e-01,  1.5764e-04, -4.5245e-01,  3.2517e-01,
           7.6670e-01,  1.6585e-01, -6.8229e-01,  2.0467e-01,  5.6886e+00]]],
       dtype=torch.float64)

## Trying PdeNet with Moments

In [273]:
class PdeNet(torch.nn.Module):
    """Explicit time stepper u <- u + dt * SymNet(u, u_x, u_xx, ...).

    Spatial derivatives of orders 0..max_diff_order are produced by the
    `fd<i>` sub-modules (FD1DMoment) and combined by the `symnet` sub-module.
    """

    def __init__(
        self,
        dt,
        dx,
        kernel_size,
        max_diff_order,
        n_channel,
        channel_names,
        acc_order=2,
        n_hidden=2,
        constraint='free',
    ):
        '''
        Input:
            dt, dx: time step and grid spacing
            kernel_size: size of every finite-difference kernel
            max_diff_order: highest derivative order fed to the SymNet
            n_channel: stored on the instance; not otherwise used by this class
            channel_names: comma-separated channel names, e.g. 'u'
            acc_order: one accuracy order for all kernels, or an iterable with
                one entry per derivative order
            n_hidden: number of hidden SymNet layers
            constraint: passed to every FD1DMoment ('free' or 'moment')
        '''
        super().__init__()
        self.dt = dt
        self.dx = dx
        self.kernel_size = kernel_size
        self.max_diff_order = max_diff_order
        self.n_channel = n_channel
        self.channel_names = channel_names
        self.n_hidden = n_hidden
        self.constraint = constraint
        # a scalar accuracy order is repeated for every derivative order
        self.acc_order = acc_order if np.iterable(acc_order) else [acc_order,]*(max_diff_order+1)

        # one finite-difference module per derivative order: fd0, fd1, ...
        for order in range(max_diff_order+1):
            self.add_module(
                'fd'+str(order),
                FD1DMoment(dx,kernel_size,order,self.acc_order[order],constraint=constraint),
            )

        # SymNet input names: <channel>_<order> for every channel and order
        self.derivative_channels = [
            name+'_'+str(order)
            for name in channel_names.split(',')
            for order in range(max_diff_order+1)
        ]
        self.add_module(
            "symnet",
            SymNet(n_hidden,len(self.derivative_channels), deriv_channel_names=self.derivative_channels),
        )

    @property
    def fds(self):
        """Yield the finite-difference modules in increasing derivative order."""
        for order in range(self.max_diff_order+1):
            yield getattr(self, 'fd'+str(order))

    def multistep(self,inputs,step_num):
        '''
        Advance `inputs` by `step_num` steps of size dt through the whole network.
        '''
        state = inputs
        for _ in range(step_num):
            state = state + self.dt*self.RightHandItems(state)
        return state

    def diffParams(self):
        """Return the parameters of all finite-difference modules as a list."""
        return [param for fd in self.fds for param in fd.parameters()]

    def RightHandItems(self,u):
        """Evaluate the learned right-hand side for `u` (batch_size x n_channels x X_dim)."""
        # every derivative order -> batch_size x n_derivatives x X_dim
        stacked = torch.cat([fd(u) for fd in self.fds], dim=1)
        # symnet consumes batch_size x X_dim x n_derivatives, returns batch_size x X_dim
        rhs = self.symnet(stacked.permute(0,2,1))
        return rhs.unsqueeze_(1)

    def forward(self,inputs,step_num):
        '''
            inputs of shape batch_size x n_channels(1 for our case) x X_dim
            step_num = number of dt steps to advance the inputs by
        '''
        return self.multistep(inputs,step_num)

    

### Losses

In [274]:
def symnetRegularizeLoss(model):
    """Huber-style sparsity penalty summed over all parameters of `model.symnet`.

    With threshold s = 1e-2, an entry p contributes p**2/(2*s) when |p| < s
    and |p| - s/2 otherwise.
    """
    threshold = 1e-2
    total = 0
    for param in model.symnet.parameters():
        magnitude = param.abs()
        below = (magnitude<threshold).to(magnitude)
        above = (magnitude>=threshold).to(magnitude)
        quadratic_part = (below*0.5/threshold*magnitude**2).sum()
        linear_part = (above*(magnitude-threshold/2)).sum()
        total = total+quadratic_part+linear_part
    return total

In [275]:
def momentRegularizeLoss(model):
    """Huber-style sparsity penalty summed over the tensors from `model.diffParams()`.

    With threshold s = 1e-2, an entry p contributes p**2/(2*s) when |p| < s
    and |p| - s/2 otherwise.
    """
    threshold = 1e-2
    total = 0
    for param in model.diffParams():
        magnitude = param.abs()
        below = (magnitude<threshold).to(magnitude)
        above = (magnitude>=threshold).to(magnitude)
        quadratic_part = (below*0.5/threshold*magnitude**2).sum()
        linear_part = (above*(magnitude-threshold/2)).sum()
        total = total+quadratic_part+linear_part
    return total

In [276]:
#global names are all the parameters
def modelLoss(model,u_obs,config,block):
    '''
        Returns the loss terms so that the total can be given to an optimizer.

        Inputs:
            model: callable as model(u, step_num=1); also passed to
                symnetRegularizeLoss and momentRegularizeLoss
            u_obs: sequence of observed states, each batch_size x n_channels x X_dim;
                u_obs[0] is the initial state and u_obs[k] the target after k steps
            config: dict providing 'sparsity', 'momentsparsity' and 'dt'
            block: number of steps to unroll; 0 is the warm-up stage (one step,
                SymNet sparsity switched off)

        Returns:
            (loss, data_loss, symnet_loss, moment_loss)

        Raises:
            ValueError: if the total loss is NaN
    '''
    sparsity = config['sparsity']
    momentsparsity = config['momentsparsity']

    if block==0: #warmup
        sparsity = 0
    step_num = block if block>=1 else 1
    dt = config['dt']
    data_loss = 0
    symnet_loss = symnetRegularizeLoss(model)
    moment_loss = momentRegularizeLoss(model)
    ut = u_obs[0]
    mse_loss = torch.nn.MSELoss()
    for steps in range(1,step_num+1):
        ut_next_predicted = model(ut,step_num=1) #take one step from this point
        data_loss += (mse_loss(ut_next_predicted,u_obs[steps])/dt**2)/step_num
        ut = ut_next_predicted # roll out from the prediction, not from the observation

    # use the local step count: the function must not depend on a notebook global
    loss = data_loss+step_num*sparsity*symnet_loss+step_num*momentsparsity*moment_loss
    if torch.isnan(loss):
        # only exception instances can be raised; a bare string is a TypeError
        raise ValueError("Loss Nan")
    return loss,data_loss,symnet_loss,moment_loss


### Diffusion Eqn

In [277]:
##modify channel names and length
def setenv(config): #return model and datamodel
    """Build the PdeNet and the matching data generator from `config`.

    The data generator is chosen by substring of config['name']:
    'Diffusion' -> DiffusionDataTrign, 'Burgers' -> BurgersEqnTrign (if both
    substrings occur, the Burgers generator is the one returned).

    Returns:
        (model, data_model, callbacks); callbacks is currently always None.

    Raises:
        ValueError: if config['name'] matches neither 'Diffusion' nor 'Burgers'.
    """
    model = PdeNet(config['dt'],config['dx'],config['kernel_size'],config['max_diff_order']\
                   ,1,config['channel_names'],config['acc_order'],config['n_hidden_layers'],config['constraint'])

    #data model
    data_model = None
    if 'Diffusion' in config['name']:
        # NOTE(review): this branch passes acc_order+1 while the Burgers branch
        # passes acc_order; confirm the asymmetry is intended.
        data_model = DiffusionDataTrign(config['name'],config['Nt'],\
                                   config['dt'],config['dx'],config['viscosity'],batch_size=config['batch_size'],\
                                  time_scheme=config['data_timescheme'],acc_order=config['acc_order']+1)
    if 'Burgers' in config['name']:
        data_model = BurgersEqnTrign(config['name'],config['Nt'],\
                               config['dt'],config['dx'],config['viscosity'],batch_size=config['batch_size'],\
                              time_scheme=config['data_timescheme'],acc_order=config['acc_order'])
    if data_model is None:
        # fail with a clear message instead of an UnboundLocalError at the return
        raise ValueError("config['name'] must contain 'Diffusion' or 'Burgers', got %r" % (config['name'],))
    #possible some callbacks
    callbacks = None
    return model,data_model,callbacks

In [278]:
# Load the experiment settings; the path is relative to the notebook's working directory.
with open("config_diff_moment.yaml", 'r') as stream:
    config = yaml.safe_load(stream)

In [279]:
config

{'name': 'Diffusion Eqn Moment',
 'dt': 0.01,
 'dx': 0.1,
 'blocks': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
 'kernel_size': 5,
 'max_diff_order': 2,
 'acc_order': 1,
 'n_hidden_layers': 2,
 'dataname': 'Diffusion',
 'viscosity': 0.1,
 'batch_size': 32,
 'channel_names': 'u',
 'data_timescheme': 'rk4',
 'data_dir': '/sdsd/dsds/sdsd',
 'Nt': 100,
 'Nx': 32,
 'sigma': 1,
 'sparsity': 0.005,
 'momentsparsity': 0.001,
 'epochs': 500,
 'results_dir': '/comet/results/',
 'seed': -1,
 'learning_rate': 0.005,
 'constraint': 'moment'}

In [285]:
# Pull the settings used by the cells below out of the config dict.
blocks = config['blocks']  # step counts to train on, one training stage per entry
dt = config['dt']
dx = config['dx']
epochs = config['epochs']  # epochs per block
lr = config['learning_rate']

In [286]:
# Build the network and the data generator selected by config['name'].
model,data_model,callbacks = setenv(config)

In [287]:
## Adam over everything returned by model.parameters(); created once and reused for all blocks.
optimizer = optim.Adam(model.parameters(), lr=lr)

In [288]:
# Staged training: one stage per entry of `blocks`, each running `epochs` optimizer steps.
for block in blocks:
    print('[PRINT] block:',block)
    if block==0:
        print('[PRINT] Warmum Stage')
    # Block 0 (warm-up) still unrolls one step.
    # NOTE: `stepnum` is a module-level name; modelLoss as written in this notebook
    # references it, so do not rename it without checking that function.
    stepnum = block if block>=1 else 1
    # stepnum+1 snapshots: the initial state plus one target per unrolled step
    # (modelLoss indexes u_obs[0] .. u_obs[stepnum]); presumably [time,batch,channel,X_dim]
    u_obs = data_model.data(stepnum+1) # same data is reused for every epoch of this block
    for epoch in range(epochs):
        # clear gradients left over from the previous step
        optimizer.zero_grad()
        # forward pass: total loss plus its individual terms for logging
        loss,data_loss,syment_reg,moment_loss = modelLoss(model,u_obs,config,block)
        loss.backward()
        optimizer.step()
        # log every 10th epoch
        if epoch%10==0:
            print("[PRINT] Epoch: %d, Loss: %.3f, Data Loss: %.3f, Symnet Regularize: %.3f, Moment Regularize: %.3f "\
                  % (epoch,loss,\
                      data_loss,syment_reg,\
                      moment_loss))

    
# put prints 

[PRINT] block: 0
[PRINT] Warmum Stage
[PRINT] Epoch: 0, Loss: 7.986, Data Loss: 7.983, Symnet Regularize: 5.548, Moment Regularize: 2.985 
[PRINT] Epoch: 10, Loss: 4.220, Data Loss: 4.217, Symnet Regularize: 5.074, Moment Regularize: 3.271 
[PRINT] Epoch: 20, Loss: 2.508, Data Loss: 2.504, Symnet Regularize: 4.723, Moment Regularize: 3.486 
[PRINT] Epoch: 30, Loss: 1.407, Data Loss: 1.403, Symnet Regularize: 4.472, Moment Regularize: 3.572 
[PRINT] Epoch: 40, Loss: 0.714, Data Loss: 0.710, Symnet Regularize: 4.297, Moment Regularize: 3.604 
[PRINT] Epoch: 50, Loss: 0.327, Data Loss: 0.323, Symnet Regularize: 4.163, Moment Regularize: 3.625 
[PRINT] Epoch: 60, Loss: 0.136, Data Loss: 0.132, Symnet Regularize: 4.070, Moment Regularize: 3.636 
[PRINT] Epoch: 70, Loss: 0.054, Data Loss: 0.050, Symnet Regularize: 4.001, Moment Regularize: 3.636 
[PRINT] Epoch: 80, Loss: 0.024, Data Loss: 0.020, Symnet Regularize: 3.942, Moment Regularize: 3.625 
[PRINT] Epoch: 90, Loss: 0.014, Data Loss: 0.

[PRINT] Epoch: 320, Loss: 0.014, Data Loss: 0.001, Symnet Regularize: 2.000, Moment Regularize: 3.097 
[PRINT] Epoch: 330, Loss: 0.014, Data Loss: 0.001, Symnet Regularize: 1.971, Moment Regularize: 3.094 
[PRINT] Epoch: 340, Loss: 0.014, Data Loss: 0.001, Symnet Regularize: 1.943, Moment Regularize: 3.090 
[PRINT] Epoch: 350, Loss: 0.013, Data Loss: 0.001, Symnet Regularize: 1.917, Moment Regularize: 3.086 
[PRINT] Epoch: 360, Loss: 0.013, Data Loss: 0.001, Symnet Regularize: 1.891, Moment Regularize: 3.082 
[PRINT] Epoch: 370, Loss: 0.013, Data Loss: 0.001, Symnet Regularize: 1.864, Moment Regularize: 3.079 
[PRINT] Epoch: 380, Loss: 0.013, Data Loss: 0.001, Symnet Regularize: 1.838, Moment Regularize: 3.075 
[PRINT] Epoch: 390, Loss: 0.013, Data Loss: 0.001, Symnet Regularize: 1.811, Moment Regularize: 3.071 
[PRINT] Epoch: 400, Loss: 0.013, Data Loss: 0.001, Symnet Regularize: 1.785, Moment Regularize: 3.067 
[PRINT] Epoch: 410, Loss: 0.013, Data Loss: 0.001, Symnet Regularize: 1.7

[PRINT] Epoch: 120, Loss: 0.018, Data Loss: 0.005, Symnet Regularize: 0.251, Moment Regularize: 3.063 
[PRINT] Epoch: 130, Loss: 0.017, Data Loss: 0.005, Symnet Regularize: 0.229, Moment Regularize: 3.063 
[PRINT] Epoch: 140, Loss: 0.017, Data Loss: 0.005, Symnet Regularize: 0.206, Moment Regularize: 3.063 
[PRINT] Epoch: 150, Loss: 0.017, Data Loss: 0.005, Symnet Regularize: 0.183, Moment Regularize: 3.063 
[PRINT] Epoch: 160, Loss: 0.016, Data Loss: 0.005, Symnet Regularize: 0.159, Moment Regularize: 3.063 
[PRINT] Epoch: 170, Loss: 0.016, Data Loss: 0.005, Symnet Regularize: 0.137, Moment Regularize: 3.063 
[PRINT] Epoch: 180, Loss: 0.016, Data Loss: 0.005, Symnet Regularize: 0.118, Moment Regularize: 3.063 
[PRINT] Epoch: 190, Loss: 0.015, Data Loss: 0.005, Symnet Regularize: 0.102, Moment Regularize: 3.063 
[PRINT] Epoch: 200, Loss: 0.015, Data Loss: 0.005, Symnet Regularize: 0.094, Moment Regularize: 3.063 
[PRINT] Epoch: 210, Loss: 0.015, Data Loss: 0.005, Symnet Regularize: 0.0

[PRINT] Epoch: 420, Loss: 0.022, Data Loss: 0.008, Symnet Regularize: 0.093, Moment Regularize: 3.063 
[PRINT] Epoch: 430, Loss: 0.022, Data Loss: 0.008, Symnet Regularize: 0.093, Moment Regularize: 3.063 
[PRINT] Epoch: 440, Loss: 0.022, Data Loss: 0.008, Symnet Regularize: 0.093, Moment Regularize: 3.063 
[PRINT] Epoch: 450, Loss: 0.022, Data Loss: 0.008, Symnet Regularize: 0.093, Moment Regularize: 3.063 
[PRINT] Epoch: 460, Loss: 0.022, Data Loss: 0.008, Symnet Regularize: 0.093, Moment Regularize: 3.063 
[PRINT] Epoch: 470, Loss: 0.022, Data Loss: 0.008, Symnet Regularize: 0.093, Moment Regularize: 3.063 
[PRINT] Epoch: 480, Loss: 0.022, Data Loss: 0.008, Symnet Regularize: 0.093, Moment Regularize: 3.063 
[PRINT] Epoch: 490, Loss: 0.022, Data Loss: 0.008, Symnet Regularize: 0.093, Moment Regularize: 3.063 
[PRINT] block: 5
[PRINT] Epoch: 0, Loss: 0.031, Data Loss: 0.013, Symnet Regularize: 0.093, Moment Regularize: 3.063 
[PRINT] Epoch: 10, Loss: 0.031, Data Loss: 0.013, Symnet R

[PRINT] Epoch: 220, Loss: 0.042, Data Loss: 0.020, Symnet Regularize: 0.093, Moment Regularize: 3.063 
[PRINT] Epoch: 230, Loss: 0.042, Data Loss: 0.020, Symnet Regularize: 0.093, Moment Regularize: 3.063 
[PRINT] Epoch: 240, Loss: 0.042, Data Loss: 0.020, Symnet Regularize: 0.093, Moment Regularize: 3.063 
[PRINT] Epoch: 250, Loss: 0.042, Data Loss: 0.020, Symnet Regularize: 0.093, Moment Regularize: 3.063 
[PRINT] Epoch: 260, Loss: 0.042, Data Loss: 0.020, Symnet Regularize: 0.093, Moment Regularize: 3.063 
[PRINT] Epoch: 270, Loss: 0.042, Data Loss: 0.020, Symnet Regularize: 0.093, Moment Regularize: 3.063 
[PRINT] Epoch: 280, Loss: 0.042, Data Loss: 0.020, Symnet Regularize: 0.093, Moment Regularize: 3.063 
[PRINT] Epoch: 290, Loss: 0.042, Data Loss: 0.020, Symnet Regularize: 0.093, Moment Regularize: 3.063 
[PRINT] Epoch: 300, Loss: 0.042, Data Loss: 0.020, Symnet Regularize: 0.093, Moment Regularize: 3.063 
[PRINT] Epoch: 310, Loss: 0.042, Data Loss: 0.020, Symnet Regularize: 0.0

[PRINT] Epoch: 20, Loss: 0.070, Data Loss: 0.042, Symnet Regularize: 0.094, Moment Regularize: 3.068 
[PRINT] Epoch: 30, Loss: 0.069, Data Loss: 0.041, Symnet Regularize: 0.094, Moment Regularize: 3.064 
[PRINT] Epoch: 40, Loss: 0.069, Data Loss: 0.041, Symnet Regularize: 0.093, Moment Regularize: 3.064 
[PRINT] Epoch: 50, Loss: 0.069, Data Loss: 0.041, Symnet Regularize: 0.093, Moment Regularize: 3.064 
[PRINT] Epoch: 60, Loss: 0.069, Data Loss: 0.041, Symnet Regularize: 0.093, Moment Regularize: 3.064 
[PRINT] Epoch: 70, Loss: 0.069, Data Loss: 0.041, Symnet Regularize: 0.093, Moment Regularize: 3.064 
[PRINT] Epoch: 80, Loss: 0.069, Data Loss: 0.041, Symnet Regularize: 0.093, Moment Regularize: 3.064 
[PRINT] Epoch: 90, Loss: 0.069, Data Loss: 0.041, Symnet Regularize: 0.093, Moment Regularize: 3.064 
[PRINT] Epoch: 100, Loss: 0.069, Data Loss: 0.041, Symnet Regularize: 0.093, Moment Regularize: 3.064 
[PRINT] Epoch: 110, Loss: 0.069, Data Loss: 0.041, Symnet Regularize: 0.093, Mome

[PRINT] Epoch: 320, Loss: 0.086, Data Loss: 0.054, Symnet Regularize: 0.093, Moment Regularize: 3.064 
[PRINT] Epoch: 330, Loss: 0.086, Data Loss: 0.054, Symnet Regularize: 0.093, Moment Regularize: 3.064 
[PRINT] Epoch: 340, Loss: 0.086, Data Loss: 0.054, Symnet Regularize: 0.093, Moment Regularize: 3.064 
[PRINT] Epoch: 350, Loss: 0.086, Data Loss: 0.054, Symnet Regularize: 0.093, Moment Regularize: 3.064 
[PRINT] Epoch: 360, Loss: 0.086, Data Loss: 0.054, Symnet Regularize: 0.093, Moment Regularize: 3.064 
[PRINT] Epoch: 370, Loss: 0.086, Data Loss: 0.054, Symnet Regularize: 0.093, Moment Regularize: 3.064 
[PRINT] Epoch: 380, Loss: 0.086, Data Loss: 0.054, Symnet Regularize: 0.093, Moment Regularize: 3.064 
[PRINT] Epoch: 390, Loss: 0.086, Data Loss: 0.054, Symnet Regularize: 0.093, Moment Regularize: 3.064 
[PRINT] Epoch: 400, Loss: 0.086, Data Loss: 0.054, Symnet Regularize: 0.093, Moment Regularize: 3.064 
[PRINT] Epoch: 410, Loss: 0.086, Data Loss: 0.054, Symnet Regularize: 0.0

In [291]:
model.state_dict()

OrderedDict([('fd0.moment',
              tensor([[[ 1.0000e+00,  6.2113e-05,  1.7956e-03, -1.2160e-04, -6.2169e-03]]],
                     dtype=torch.float64)),
             ('fd0.m2k._M',
              tensor([[ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000],
                      [-2.0000, -1.0000,  0.0000,  1.0000,  2.0000],
                      [ 2.0000,  0.5000,  0.0000,  0.5000,  2.0000],
                      [-1.3333, -0.1667,  0.0000,  0.1667,  1.3333],
                      [ 0.6667,  0.0417,  0.0000,  0.0417,  0.6667]], dtype=torch.float64)),
             ('fd0.m2k._invM',
              tensor([[-0.0000e+00,  8.3333e-02, -8.3333e-02, -5.0000e-01,  1.0000e+00],
                      [ 0.0000e+00, -6.6667e-01,  1.3333e+00,  1.0000e+00, -4.0000e+00],
                      [ 1.0000e+00,  0.0000e+00, -2.5000e+00,  4.4409e-16,  6.0000e+00],
                      [ 0.0000e+00,  6.6667e-01,  1.3333e+00, -1.0000e+00, -4.0000e+00],
                      [ 0.0000e+00, -8.3333e-02, -8.

In [296]:
model.symnet.getEquation(calprec=2)

0.0970599*u_2

In [293]:
def _cast2symbol(layer):
    weight,bias = layer.weight.data.cpu().numpy(), \
                layer.bias.data.cpu().numpy()
    weight,bias = sympy.Matrix(weight),sympy.Matrix(bias)
    return weight,bias

In [295]:
def _sympychop(o, calprec):
    for i in range(o.shape[0]):
        cdict = o[i].expand().as_coefficients_dict()  
        o_i = 0
        for k,v in cdict.items():
            if abs(v)>0.1**calprec:
                o_i = o_i+k*v
        o[i] = o_i
    return o