In [18]:
import pysindy as ps

import deepSI
from deepSI.fit_systems import SS_encoder_general
from deepSI.fit_systems.encoders import default_encoder_net, default_state_net, default_output_net

import torch
from torch import nn

import numpy as np

from sklearn.preprocessing import PolynomialFeatures

from scipy.io import loadmat
import os

import deepSI
from deepSI import System_data

from utils import load_data, normalize

import matplotlib.pyplot as plt

In [19]:
x_data, u_data, y_data, th_data = load_data(pc=0)

train, test = System_data(u=u_data[:10000,0],y=x_data[:10000,:]), System_data(u=u_data[-1000:],y=x_data[-1000:,:])

train.y.shape, train.u.shape, test.y.shape, test.u.shape

((10000, 2), (10000,), (1000, 2), (1000, 1))

In [21]:
class SS_encoder_general_eq(SS_encoder_general):
    def __init__(self, nx=10, na=20, nb=20, feedthrough=False, \
        e_net=default_encoder_net, f_net=default_state_net, h_net=default_output_net, \
        e_net_kwargs={},           f_net_kwargs={},         h_net_kwargs={}, na_right=0, nb_right=0, \
        gamma=1e-4):

        super(SS_encoder_general_eq, self).__init__()
        self.nx, self.na, self.nb = nx, na, nb
        self.k0 = max(self.na,self.nb)
        
        self.e_net = e_net
        self.e_net_kwargs = e_net_kwargs

        self.f_net = f_net
        self.f_net_kwargs = f_net_kwargs

        self.h_net = h_net
        self.h_net_kwargs = h_net_kwargs

        self.feedthrough = feedthrough
        self.na_right = na_right
        self.nb_right = nb_right
        ######################################
        # args added for feature transform and
        # regurlarization
        self.gamma = gamma
        ######################################

    def init_nets(self, nu, ny): # a bit weird
        na_right = self.na_right if hasattr(self,'na_right') else 0
        nb_right = self.nb_right if hasattr(self,'nb_right') else 0
        self.encoder = self.e_net(nb=(self.nb+nb_right), nu=nu, na=(self.na+na_right), ny=ny, nx=self.nx, **self.e_net_kwargs)
        ######################################
        ###### change fn intialization #######
        self.fn     =      self.f_net(nx=self.nx, nu=nu, **self.f_net_kwargs)
        ######################################
        if self.feedthrough:
            self.hn =      self.h_net(nx=self.nx, ny=ny, nu=nu,                     **self.h_net_kwargs) 
        else:
            self.hn =      self.h_net(nx=self.nx, ny=ny,                            **self.h_net_kwargs) 

    def loss(self, uhist, yhist, ufuture, yfuture, loss_nf_cutoff=None, **Loss_kwargs):
        x = self.encoder(uhist, yhist) #initialize Nbatch number of states
        errors = []
        for y, u in zip(torch.transpose(yfuture,0,1), torch.transpose(ufuture,0,1)): #iterate over time
            error = nn.functional.mse_loss(y, self.hn(x,u) if self.feedthrough else self.hn(x))
            ##################################
            ## add penalty to weights in fn ##
            # params = [*self.fn.parameters()]
            # weights = [x.view(-1) for x in params][0]
            # error += self.gamma*torch.norm(weights, 1)
            ##################################
            errors.append(error) #calculate error after taking n-steps
            if loss_nf_cutoff is not None and error.item()>loss_nf_cutoff:
                print(len(errors), end=' ')
                break
            x = self.fn(x,u) #advance state.
        
            
        return torch.mean(torch.stack(errors))

In [22]:
class h_identity(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, input):
        return input
    
class e_identity(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, *input):
        output = input[-1]
        output = torch.reshape(output,(output.shape[0], output.shape[-1]))
    
        return output
    
class simple_Linear(torch.nn.Module):
    def __init__(self, nx, nu, **kwargs):
        super(simple_Linear, self).__init__()

        self.nx = nx
        self.nu = kwargs['u']

        self.feature_library = kwargs['feature_library']
        test_sample = torch.rand(1,self.nx+self.nu, requires_grad=True)
        self.nf = (self.feature_library.fit_transform(test_sample)).shape[1]
        
        self.layer = nn.Linear(self.nf, nx, bias=False)
        

    def forward(self, x, u):
        # make sure u is column
        u = torch.reshape(u, (u.shape[-1],1))
        x = torch.hstack((x, u))
        Theta = self.feature_library.fit_transform(x)
        out = self.layer(Theta)
        return out

In [23]:
class feature_library():
    """ Class object that holds the features for constructing the functions basis.

        Attributes:
            functions (list)        : list of function objects
            nx (int)                : the number of states
            nu (int)                : the number of inputs
            include_one (bool)      : choose to include an offset
            interaction_only (bool) : true exlcudes terms such as x1[k]*x2[k]

        Methods:
            __init__(self, functions, nx, nu, include_one, interaction_only):
                Constructor

            fit_transform(self, X):
                Transforms data in X using the specified functions stored in self.functions

            feature_names(self):
                Returns a list of strings which correspond to the features in self.functions

    """
    def __init__(
            self,
            functions,
            nx,
            nu,
            include_one = True,
            interaction_only=True
    ):
        self.functions = functions
        # TODO: add interaction only to include/exclude cross terms
        # now functions are applied to all states and inputs
        self.interaction_only = interaction_only

        # set to false to exclude possible offset
        self.include_one = include_one

        # creates list of factors that are in system
        # x0[k],..., xn[k], u0[k],..., un[k]
        self.term_list = [f"x{i}[k]" for i in range(nx)]
        input_list = [f"u{i}[k]" for i in range(nu)]
        self.term_list.extend(input_list)

    def fit_transform(self, X):
        # if include_one = True add term for offset
        out_feature = ((X[:,0])**0).unsqueeze(1) if self.include_one else torch.empty(X.shape[0], 1)
        if self.interaction_only:
            for f in self.functions:
                out_feature = torch.hstack((out_feature, f(X)[0]))
            return out_feature
        # TODO: add the stuff for cross terms
        
    def feature_names(self):
        # returns list with feature names
        flist = ["1"] if self.include_one else []
        for f in self.functions:
            for x in self.term_list:
                flist.append(f(torch.tensor(1.),f"{x}")[-1])
        return flist

In [24]:
def f(x, name="_"):
  return x, f"{name}"

def f2(x, name="_"):
  return x**2, f"{name}**2"

def f3(x, name="_"):
  return x**3, f"{name}**3"

def sin(x, name="_"):
  return torch.sin(x), f"sin({name})"

In [25]:
functions = [f, sin]

poly = feature_library(functions=functions, nx=2, nu=1)

In [26]:
# fit_sys = deepSI.fit_systems.SS_encoder_general(nx=2, na=50, nb=50)

nx, nu = 2, 1 # state dimension and inputs
na, nb = 1, 0

f_net = simple_Linear
f_net_kwargs = {"feature_library": poly, "u": nu, "nf": 7}

# e_net_kwargs = {"slot": 1}

h_net = h_identity
h_net_kwargs = {}

fit_sys = SS_encoder_general_eq(nx=nx, na=na, nb=nb, \
                                f_net=f_net, f_net_kwargs=f_net_kwargs,\
                                e_net=e_identity, e_net_kwargs=f_net_kwargs,\
                                h_net=h_net)


fit_sys.fit(train, test, epochs=1, batch_size = 9900, optimizer_kwargs={"lr": 1e-3}, loss_kwargs=dict(nf=100), auto_fit_norm=False)

Initilizing the model and optimizer
Size of the training array =  22.8 MB
N_training_samples = 9900, batch_size = 9900, N_batch_updates_per_epoch = 1


Initial Validation sim-NRMS= 189.75283873845834


  0%|          | 0/1 [00:00<?, ?it/s]

torch.Size([9900, 2]) torch.Size([9900, 2])
torch.Size([9900, 2]) torch.Size([9900, 2])
torch.Size([9900, 2]) torch.Size([9900, 2])
torch.Size([9900, 2]) torch.Size([9900, 2])
torch.Size([9900, 2]) torch.Size([9900, 2])
torch.Size([9900, 2]) torch.Size([9900, 2])
torch.Size([9900, 2]) torch.Size([9900, 2])
torch.Size([9900, 2]) torch.Size([9900, 2])
torch.Size([9900, 2]) torch.Size([9900, 2])
torch.Size([9900, 2]) torch.Size([9900, 2])
torch.Size([9900, 2]) torch.Size([9900, 2])
torch.Size([9900, 2]) torch.Size([9900, 2])
torch.Size([9900, 2]) torch.Size([9900, 2])
torch.Size([9900, 2]) torch.Size([9900, 2])
torch.Size([9900, 2]) torch.Size([9900, 2])
torch.Size([9900, 2]) torch.Size([9900, 2])
torch.Size([9900, 2]) torch.Size([9900, 2])
torch.Size([9900, 2]) torch.Size([9900, 2])
torch.Size([9900, 2]) torch.Size([9900, 2])
torch.Size([9900, 2]) torch.Size([9900, 2])
torch.Size([9900, 2]) torch.Size([9900, 2])
torch.Size([9900, 2]) torch.Size([9900, 2])
torch.Size([9900, 2]) torch.Size

100%|██████████| 1/1 [00:01<00:00,  1.19s/it]

########## New lowest validation loss achieved ########### sim-NRMS = 188.28141538035192
Epoch    1, sqrt loss  0.1773, Val sim-NRMS  188.3, Time Loss: 74.6%, data: 1.0%, val: 24.4%,  1.2 sec/batch
Loaded model with best known validation sim-NRMS of  188.3 which happened on epoch 1 (epoch_id=1.00)





In [27]:
found = [*fit_sys.fn.parameters()][0].detach().numpy()

found

array([[ 0.27900198, -0.36201245, -0.22251362, -0.3190518 , -0.2969291 ,
         0.34710744,  0.07655019],
       [ 0.20583577, -0.01664452, -0.28023294, -0.27788538,  0.365198  ,
        -0.27606222,  0.36894867]], dtype=float32)

In [28]:
# test_sim_enc = fit_sys.apply_experiment(test)

# plt.plot(test.y)
# plt.plot(test.y - test_sim_enc.y)
# plt.title(f'test set simulation SS encoder, NRMS = {test_sim_enc.NRMS(test):.2%}')
# plt.show()

In [29]:
# def NRMS(y_pred, y_true):
#     RMS = np.sqrt(np.mean((y_pred-y_true)**2))
#     return RMS/np.std(y_true)

In [30]:
# plt.plot(test.y[:,0])
# plt.plot(test.y[:,0]-test_sim_enc.y[:,0],'--')
# plt.title(f'test set simulation SS encoder, NRMS = {NRMS(test_sim_enc.y[:,0],test.y[:,0]):.2%}')
# plt.show()

# plt.plot(test.y[:,1])
# plt.plot(test.y[:,1]-test_sim_enc.y[:,1],'--')
# plt.title(f'test set simulation SS encoder, NRMS = {NRMS(test_sim_enc.y[:,1],test.y[:,1]):.2%}')
# plt.show()

In [31]:
# plt.plot(test_sim_enc.y[:,1],'--')
# plt.plot(test.y[:,1])

In [32]:
# found = [*fit_sys.fn.parameters()][0].detach().numpy()
# true = np.array([[0, 1, 1, 0, 0, 0, 0],[0, -0.1, 0.5, 0.1, -0.2, 0, 0]])

In [33]:
# from matplotlib.colors import LinearSegmentedColormap

In [34]:
# fig, (ax1,ax2) = plt.subplots(2, 1)

# x_labels = ["1","x0[k]","x1[k]","u[k]","sin(x0[k])","sin([x1[k]])","sin(u[k])"]

# data1 = np.vstack((true[0,:],found[0,:]))
# data2 = np.vstack((true[1,:],found[1,:]))
# cmap_white = LinearSegmentedColormap.from_list("white", [(1, 1, 1), (1, 1, 1)])

# im = ax1.imshow(data1, cmap=cmap_white)

# ax1.set_xticks(np.arange(data1.shape[1]), labels=x_labels, rotation=25)
# ax1.set_yticks(np.arange(data1.shape[0]), labels=["True", "Found"])

# for i in range(data1.shape[0]):
#     for j in range(data1.shape[1]):
#         text = ax1.text(j, i, round(data1[i, j],3),
#                        ha="center", va="center", color="k")
#         rect = plt.Rectangle((j - 0.5, i - 0.5), 1, 1, fill=False, edgecolor='black', linewidth=1)
#         ax1.add_patch(rect)

# ax1.patch.set_linewidth(2.0)        
# ax1.patch.set_edgecolor('black')

# # second
# im = ax2.imshow(data2, cmap=cmap_white)

# ax2.set_xticks(np.arange(data2.shape[1]), labels=x_labels, rotation=25)
# ax2.set_yticks(np.arange(data2.shape[0]), labels=["True", "Found"])

# for i in range(data2.shape[0]):
#     for j in range(data2.shape[1]):
#         text = ax2.text(j, i, round(data2[i, j], 3),
#                        ha="center", va="center", color="k")
#         rect = plt.Rectangle((j - 0.5, i - 0.5), 1, 1, fill=False, edgecolor='black', linewidth=1)
#         ax2.add_patch(rect)

# ax2.patch.set_linewidth(2.0)        
# ax2.patch.set_edgecolor('black')

# plt.show()