In [None]:
## import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
class Encoder(nn.Module):
    """Fully connected encoder mapping a flattened input vector to a latent code.

    Architecture: ``input_dim -> 128 -> 64 -> 32 -> latent_dim``. A sigmoid follows
    every layer, including the output layer, so every latent value lies in (0, 1).

    Args:
        input_dim: length of the flattened input vector (default 2601).
        latent_dim: size of the latent code (default 1).
    """

    def __init__(self, input_dim=2601, latent_dim=1):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 32)
        self.fc4 = nn.Linear(32, latent_dim)
        self.init_w()

    def init_w(self):
        """Xavier-normal initialise every Linear weight and zero every Linear bias.

        Biases are 1-D. ``nn.init.xavier_normal_`` needs a tensor with at least two
        dimensions to compute fan-in/fan-out and raises ValueError otherwise, so
        biases are zero-initialised instead.
        """
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                nn.init.xavier_normal_(layer.weight)
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Encode ``x`` (last dim = ``input_dim``) into a sigmoid-bounded latent code."""
        x = torch.sigmoid(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))
        x = torch.sigmoid(self.fc4(x))
        return x

In [4]:
class Decoder(nn.Module):
    """Fully connected decoder mapping a latent code back to a flattened output vector.

    Architecture: ``latent_dim -> 32 -> 64 -> 128 -> output_dim``. A sigmoid follows
    every layer, including the output layer, so every reconstructed value lies in (0, 1).

    Args:
        latent_dim: size of the latent code (default 1).
        output_dim: length of the reconstructed vector (default 2601).
    """

    def __init__(self, latent_dim=1, output_dim=2601):
        super(Decoder, self).__init__()
        self.fc1 = nn.Linear(latent_dim, 32)
        self.fc2 = nn.Linear(32, 64)
        self.fc3 = nn.Linear(64, 128)
        self.fc4 = nn.Linear(128, output_dim)
        self.init_w()

    def init_w(self):
        """Xavier-normal initialise every Linear weight and zero every Linear bias.

        Biases are 1-D. ``nn.init.xavier_normal_`` needs a tensor with at least two
        dimensions to compute fan-in/fan-out and raises ValueError otherwise, so
        biases are zero-initialised instead.
        """
        for layer in self.modules():
            if isinstance(layer, nn.Linear):
                nn.init.xavier_normal_(layer.weight)
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Decode latent ``x`` (last dim = ``latent_dim``) into a sigmoid-bounded output."""
        x = torch.sigmoid(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))
        x = torch.sigmoid(self.fc4(x))
        return x

In [6]:
class SINDy:
    """Candidate-function library for SINDy (Sparse Identification of Nonlinear Dynamics).

    ``fit`` evaluates the fixed library ``[1, x, x^2, x^3, x^4, sin x, cos x]``
    column-wise. Multiplying the result by a ``(n_terms, 1)`` coefficient matrix
    gives one predicted value per sample.
    """

    # Number of library columns produced by ``fit`` per latent dimension.
    n_terms = 7

    def __init__(self):
        # The library is stateless and has no learnable parameters, so there is nothing to set up.
        pass

    def fit(self, x):
        """Evaluate the candidate library on ``x``.

        Args:
            x: tensor of shape ``(batch, d)``. A 0-d or 1-D tensor is treated as
                ``(batch, 1)``.

        Returns:
            Tensor of shape ``(batch, 7 * d)`` with the same dtype/device as ``x``.
            Columns are grouped by function: ``[1 | x | x^2 | x^3 | x^4 | sin x | cos x]``.
            With ``d == 1`` this is ``(batch, 7)``.
        """
        if x.dim() < 2:
            x = x.reshape(-1, 1)
        # ones_like keeps dtype/device consistent with x (a bare torch.tensor(1) is a
        # 0-d int64 CPU tensor and cannot be combined with x).
        terms = (
            torch.ones_like(x),
            x,
            x ** 2,
            x ** 3,
            x ** 4,
            torch.sin(x),
            torch.cos(x),
        )
        # Concatenate along the feature axis so each sample gets one row of library values.
        return torch.cat(terms, dim=-1)

In [7]:
class AutoEncoder(nn.Module):
    """Autoencoder with a second-order SINDy model fitted in its latent space.

    The encoder maps ``x`` to a latent ``z``, and the decoder maps ``z`` back to ``x``.
    A SINDy model ``z'' ~= Theta(z) @ SINDy_coeff`` is trained jointly. Its
    consistency is checked both in latent space and, after pushing the predicted
    ``z''`` through the decoder, in input space.

    Args:
        lambda_x: weight of the SINDy loss measured in input space (default 5e-4).
        lambda_z: weight of the SINDy loss measured in latent space (default 5e-5).
        lambda_reg: weight of the L1 penalty on the SINDy coefficients (default 1e-5).
    """

    def __init__(self, lambda_x=5e-4, lambda_z=5e-5, lambda_reg=1e-5):
        super(AutoEncoder, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.SINDy = SINDy()
        # One coefficient per library term: [1, z, z^2, z^3, z^4, sin z, cos z].
        self.SINDy_coeff = nn.Parameter(torch.full((7, 1), 1, dtype=torch.float32))
        self.lambda_x = lambda_x
        self.lambda_z = lambda_z
        self.lambda_reg = lambda_reg

    def compute_gradients(self, x, x_dot, x_dot_dot, w, b, activate_last=True):
        """Propagate first and second time derivatives through sigmoid-linear layers.

        For each layer ``y = sigmoid(a)`` with ``a = x @ W.T + bias``, the chain rule gives
            ``y'  = s'(a) * (x' @ W.T)``
            ``y'' = s''(a) * (x' @ W.T)**2 + s'(a) * (x'' @ W.T)``
        where ``s' = s(1 - s)`` and ``s'' = s(1 - s)(1 - 2s)``.

        Args:
            x: layer-stack input, shape ``(batch, in_features)``.
            x_dot, x_dot_dot: first/second time derivatives of ``x``, same shape.
            w: list of weight matrices in ``nn.Linear`` layout ``(out, in)``.
            b: list of bias vectors (or ``0``) matching ``w``.
            activate_last: whether the last layer is followed by a sigmoid. The
                Encoder/Decoder used here apply a sigmoid to their output layer, so
                the default is True. If False, the last layer is treated as purely linear.

        Returns:
            ``(dy, d2y)``: derivatives of the stack output, each ``(batch, out_features)``.
        """
        dx = x_dot
        d2x = x_dot_dot
        last = len(w) - 1
        for i, (weight, bias) in enumerate(zip(w, b)):
            # nn.Linear stores weight as (out, in), so the transpose maps in -> out.
            dx_lin = torch.matmul(dx, weight.T)
            d2x_lin = torch.matmul(d2x, weight.T)
            if i == last and not activate_last:
                dx, d2x = dx_lin, d2x_lin
                continue
            x = torch.sigmoid(torch.matmul(x, weight.T) + bias)
            sigm_dash = x * (1 - x)
            sigm_doubledash = sigm_dash * (1 - 2 * x)
            # Use the pre-update dx_lin for both derivatives.
            d2x = sigm_doubledash * dx_lin ** 2 + sigm_dash * d2x_lin
            dx = sigm_dash * dx_lin
        return dx, d2x

    @staticmethod
    def _linear_params(module):
        """Return (weights, biases) of every nn.Linear in ``module``, in registration order."""
        linears = [m for m in module.modules() if isinstance(m, nn.Linear)]
        weights = [m.weight for m in linears]
        # A missing bias contributes nothing to the affine map.
        biases = [m.bias if m.bias is not None else 0 for m in linears]
        return weights, biases

    def compute_losses(self, x, x_dot, x_dot_dot):
        """Compute the total weighted loss and its individual components.

        Returns:
            ``(loss, losses)``: the weighted scalar total and a dict of component losses.
        """
        z = self.encoder(x)
        x_predicted = self.decoder(z)
        sindy_coeff = self.SINDy_coeff
        theta = self.SINDy.fit(z)
        # SINDy prediction of z'' from the candidate library evaluated at z.
        sindy_predicted = torch.matmul(theta, sindy_coeff)

        encoder_weight_list, encoder_biases_list = self._linear_params(self.encoder)
        decoder_weight_list, decoder_biases_list = self._linear_params(self.decoder)

        # Data-implied latent derivatives (z', z'') from the observed (x', x'').
        dz, d2z = self.compute_gradients(x, x_dot, x_dot_dot, encoder_weight_list, encoder_biases_list)
        # Map the SINDy-predicted z'' back to input space to compare with the observed x''.
        _dx_predicted, d2x_predicted = self.compute_gradients(
            z, dz, sindy_predicted, decoder_weight_list, decoder_biases_list
        )

        losses = {}
        # NOTE(review): key spelling ('Recontruction') kept as-is for callers that index it.
        losses['Recontruction Loss'] = torch.mean((x - x_predicted) ** 2)
        losses['SINDy Loss in z'] = torch.mean((d2z - sindy_predicted) ** 2)
        losses['SINDy Loss in x'] = torch.mean((x_dot_dot - d2x_predicted) ** 2)
        losses['Regularization Loss'] = torch.mean(torch.abs(sindy_coeff))
        loss = (
            losses['Recontruction Loss']
            + self.lambda_x * losses['SINDy Loss in x']
            + self.lambda_z * losses['SINDy Loss in z']
            + self.lambda_reg * losses['Regularization Loss']
        )

        return loss, losses

    def forward(self, x, x_dot, x_dot_dot):
        """Return the inputs unchanged.

        NOTE(review): this is a pass-through placeholder. Training signals come
        from ``compute_losses``.
        """
        return x, x_dot, x_dot_dot
        