In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
import sklearn
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error, mean_squared_error

In [2]:
class Encoder(nn.Module):
    """Fully-connected encoder that compresses a flat input to a small latent code.

    Layer widths are ``input_dim -> 128 -> 64 -> 32 -> latent_dim`` and every
    layer, including the last one, is followed by a sigmoid, so each latent
    value lies in [0, 1].

    Args:
        input_dim: Number of features in each flattened input sample.
            Defaults to 2601, the width originally hard-coded here.
        latent_dim: Size of the latent code. Defaults to 1.
    """

    def __init__(self, input_dim=2601, latent_dim=1):
        super().__init__()
        # Attribute names fc1..fc4 are kept so state_dict keys stay stable.
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 32)
        self.fc4 = nn.Linear(32, latent_dim)

    def forward(self, x):
        """Encode ``x`` of shape ``(..., input_dim)`` into ``(..., latent_dim)``."""
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            x = torch.sigmoid(layer(x))
        return x

In [3]:
class Decoder(nn.Module):
    """Fully-connected decoder that expands a latent code back to a flat sample.

    Layer widths are ``latent_dim -> 32 -> 64 -> 128 -> output_dim`` and every
    layer, including the last one, is followed by a sigmoid, so reconstructed
    values lie in [0, 1].
    NOTE(review): this presumably expects the training targets to be scaled
    into that range - confirm against the data pipeline.

    Args:
        latent_dim: Size of the latent code fed to the decoder. Defaults to 1.
        output_dim: Number of features in each reconstructed sample.
            Defaults to 2601, the width originally hard-coded here.
    """

    def __init__(self, latent_dim=1, output_dim=2601):
        super().__init__()
        # Attribute names fc1..fc4 are kept so state_dict keys stay stable.
        self.fc1 = nn.Linear(latent_dim, 32)
        self.fc2 = nn.Linear(32, 64)
        self.fc3 = nn.Linear(64, 128)
        self.fc4 = nn.Linear(128, output_dim)

    def forward(self, x):
        """Decode ``x`` of shape ``(..., latent_dim)`` into ``(..., output_dim)``."""
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            x = torch.sigmoid(layer(x))
        return x

In [4]:
class SINDy():
    """Library of candidate functions evaluated element-wise on a tensor.

    The library terms are, in order: ``1, x, x**2, x**3, x**4, sin(x), cos(x)``.
    """

    def fit(self, x):
        """Evaluate every library term on ``x``.

        Args:
            x: Tensor of any shape.

        Returns:
            Tensor of shape ``(7, *x.shape)``; index ``k`` along the first
            axis holds the k-th library term evaluated on ``x``.
        """
        # ones_like keeps the constant term on the same dtype/device as x,
        # so the stack below does not mix devices or dtypes.
        terms = (
            torch.ones_like(x),
            x,
            x ** 2,
            x ** 3,
            x ** 4,
            torch.sin(x),
            torch.cos(x),
        )
        return torch.stack(terms)

    def __call__(self, x):
        """Alias for :meth:`fit`, so an instance can be applied like a function."""
        return self.fit(x)

In [None]:
class AutoEncoder(nn.Module):
    
    def __init__(self):
        """Build the encoder, decoder, SINDy library and its coefficients."""
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        # Stored as a plain attribute; outputs() applies it as self.SINDy(z),
        # so the SINDy object has to be callable.
        self.SINDy = SINDy()
        # Learnable (7, 1) coefficient column initialised to 1. The 7 matches
        # the number of tensors SINDy stacks; nn.Parameter already sets
        # requires_grad=True.
        self.SINDy_coeff = nn.Parameter(torch.ones((7, 1), dtype=torch.float32))
        
    def compute_gradients_second_order(self, x, x_dot, x_dot_dot, w, b):
        """Push first and second time derivatives through a sigmoid MLP.

        The network described by ``w`` and ``b`` applies
        ``sigmoid(h @ w[i].T + b[i])`` for every layer except the last, and a
        plain ``h @ w[-1].T`` for the last one. Given the input ``x`` and its
        first/second time derivatives, the chain rule is applied layer by
        layer to obtain the time derivatives of the network output.

        NOTE(review): the last layer is treated as linear here, whereas the
        Encoder/Decoder ``forward`` methods in this notebook apply a sigmoid
        after their final layer as well - confirm which one is intended.

        Args:
            x: Network input, shape ``(batch, in_features)``.
            x_dot: First time derivative of ``x``, same shape as ``x``.
            x_dot_dot: Second time derivative of ``x``, same shape as ``x``.
            w: Sequence of weight matrices, each ``(out_features, in_features)``
                (the orientation implied by the ``w[i].T`` forward pass).
            b: Sequence of bias vectors; only ``b[i]`` for the non-final
                layers is read.

        Returns:
            Tuple ``(dz, d2z)`` with the first and second time derivatives of
            the network output, each of shape ``(batch, out_features_last)``.
        """
        dx = x_dot
        d2x = x_dot_dot
        for i in range(len(w) - 1):
            # Use the same (transposed) orientation for the activations and
            # both derivative streams; mixing w and w.T only works by accident
            # for square layers and gives wrong values even then.
            w_t = w[i].T
            x = torch.sigmoid(torch.matmul(x, w_t) + b[i])
            dx_cap = torch.matmul(dx, w_t)
            # For s = sigmoid(a): s' = s(1 - s) and s'' = s'(1 - 2s).
            sigm_dash = x * (1 - x)
            sigm_doubledash = sigm_dash * (1 - 2 * x)
            # d2/dt2 s(a(t)) = s''(a) * (da/dt)^2 + s'(a) * d2a/dt2
            d2x = sigm_doubledash * torch.square(dx_cap) + sigm_dash * torch.matmul(d2x, w_t)
            dx = sigm_dash * dx_cap
        # Final layer has no activation, so derivatives pass through linearly.
        w_last_t = w[-1].T
        dx = torch.matmul(dx, w_last_t)
        d2x = torch.matmul(d2x, w_last_t)

        return dx, d2x
    
    def outputs(self, x, x_dot):
        z = self.encoder(z)
        x_predicted = self.decoder(z)
        
        theta = self.SINDy(z)