In [3]:
import sys
sys.path.append("../../src")
import os
import datetime
import pandas as pd
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from example_lorenz import get_lorenz_data
from autoencoderTF2 import Autoencoder, Encoder, Decoder
from SINDYTF2 import sindy_library_tf, library_size
from typing import List, Optional
import numpy as np
from tensorflow.keras.layers import Dense, Input, Dropout

In [None]:
class Autoencoder(tf.keras.Model):
    """Keras model that chains an encoder graph and a decoder graph.

    The encoder and decoder are created from the project's ``Encoder`` and
    ``Decoder`` classes through their ``build_graph()`` methods.

    NOTE(review): this definition shadows the ``Autoencoder`` name imported
    from ``autoencoderTF2`` in the import cell -- confirm which one is meant
    to be used.

    Attributes:
        input_dim: First entry of ``widths``.
        latent_dim: Last entry of ``widths``.
        encoder: Result of ``Encoder(widths, **ekwargs).build_graph()``.
        decoder: Result of ``Decoder(self.encoder, **dkwargs).build_graph()``.
        theta: Value stored by ``set_theta``; ``None`` until it is called.
    """

    def __init__(self,
                 widths: Optional[List[int]] = None,
                 ekwargs: Optional[dict] = None,
                 dkwargs: Optional[dict] = None,
                 **kwargs):
        """Build the encoder/decoder pair.

        Args:
            widths: Layer widths forwarded to ``Encoder``. The first entry is
                used as the input dimension and the last one as the latent
                dimension. Defaults to ``[32, 28, 25]``.
            ekwargs: Extra keyword arguments for ``Encoder``. ``None`` means
                no extra arguments.
            dkwargs: Extra keyword arguments for ``Decoder``. ``None`` means
                no extra arguments.
            **kwargs: Forwarded to ``tf.keras.Model.__init__``. ``name``
                defaults to ``"autoencoder"`` when not supplied.
        """
        # Fresh objects per call: the previous list/dict defaults were shared
        # between every instance, and an explicit ``None`` (allowed by the
        # ``Optional`` annotation) crashed on ``**None``.
        if widths is None:
            widths = [32, 28, 25]
        ekwargs = {} if ekwargs is None else ekwargs
        dkwargs = {} if dkwargs is None else dkwargs
        kwargs.setdefault("name", "autoencoder")

        super().__init__(**kwargs)
        self.input_dim = widths[0]
        self.latent_dim = widths[-1]
        self.encoder = Encoder(widths, **ekwargs).build_graph()
        self.decoder = Decoder(self.encoder, **dkwargs).build_graph()
        self.theta = None

    def call(self, input):
        """Apply the encoder layers, then the decoder layers, to ``input``.

        The element at index 0 of each ``layers`` list is skipped.
        NOTE(review): presumably that element is the ``Input`` placeholder of
        each functional graph -- confirm against ``Encoder``/``Decoder``.

        Args:
            input: Tensor fed to the first applied encoder layer.

        Returns:
            The output of the last decoder layer.
        """
        x = input
        for layer in self.encoder.layers[1:] + self.decoder.layers[1:]:
            x = layer(x)
        return x

    def compile(self, **kwargs):
        """Forward keyword-only compile options to ``tf.keras.Model.compile``."""
        super().compile(**kwargs)

    def build_graph(self,):
        """Return a functional ``tf.keras.Model`` wrapping ``call``.

        The graph takes one input of shape ``(input_dim,)`` named
        ``'autoencoder_input'``.
        """
        x = Input(shape=(self.input_dim, ), name = 'autoencoder_input')
        return tf.keras.Model(inputs = [x], outputs = self.call(x))

    def encode(self, input):
        """Return ``self.encoder(input)``."""
        return self.encoder(input)

    def decode(self, input):
        """Return ``self.decoder(input)``."""
        return self.decoder(input)

    def get_input_dim(self) -> int:
        """Return the input dimension (``widths[0]``)."""
        return self.input_dim

    def get_latent_dim(self) -> int:
        """Return the latent dimension (``widths[-1]``)."""
        return self.latent_dim

    def set_theta(self, theta):
        """Store ``theta`` on the model.

        An earlier ``set_theta(x, poly_order, include_sin)`` definition in
        this class was silently overridden by this one and has been removed;
        the effective behaviour is unchanged. ``SindyTrain.set_theta`` builds
        the library from data.
        """
        self.theta = theta

    def get_loss(self, x):
        """Combine the loss terms found in ``self.currentState``.

        NOTE(review): ``currentState``, ``coefficient_mask`` and
        ``sindy_coefficients`` are not assigned anywhere in this class; they
        must be attached to the instance before calling this method,
        otherwise an ``AttributeError`` is raised.

        Args:
            x: Currently ignored -- every quantity, including ``x`` itself,
                is read from ``self.currentState``.

        Returns:
            Tuple ``(loss, losses, loss_refinement)`` where ``losses`` is a
            dict with keys ``'decoder'``, ``'sindy_z'``, ``'sindy_x'`` and
            ``'sindy_regularization'``, ``loss_refinement`` is the weighted
            sum of the first three terms, and ``loss`` additionally includes
            the weighted regularization term.
        """
        state = self.currentState
        x_true = state["x"]
        x_decode = state["x_decode"]
        dz = state["dz"]
        dz_predict = state["dz_predict"]
        dx = state["dx"]
        dx_decode = state["dx_decode"]
        sindy_coefficients = self.coefficient_mask*self.sindy_coefficients

        losses = {}
        losses['decoder'] = tf.reduce_mean((x_true - x_decode)**2)
        losses['sindy_z'] = tf.reduce_mean((dz - dz_predict)**2)
        losses['sindy_x'] = tf.reduce_mean((dx - dx_decode)**2)
        losses['sindy_regularization'] = tf.reduce_mean(tf.abs(sindy_coefficients))

        # Same left-to-right summation order as before; the full loss is the
        # refinement loss plus the regularization term.
        loss_refinement = state['loss_weight_decoder'] * losses['decoder'] \
            + state['loss_weight_sindy_z'] * losses['sindy_z'] \
            + state['loss_weight_sindy_x'] * losses['sindy_x']
        loss = loss_refinement \
            + state['loss_weight_sindy_regularization'] * losses['sindy_regularization']

        return loss, losses, loss_refinement

    def get_loss_refinement(self, ):
        """Placeholder: always returns ``0``."""
        return 0

    def get_losses(self,) -> dict:
        """Placeholder: always returns an empty dict."""
        return {}


In [4]:

class SindyTrain:
    """Holds SINDy training state and computes losses for an ``Autoencoder``.

    ``training_step``, ``fit`` and ``compile`` are unimplemented placeholders.
    """

    def __init__(self,
                 model : Autoencoder,
                 coefficient_threshold : float = 0.1,
                 threshold_frequency : int = 50,
                 loss_weights : Optional[List[float]] = None,
                 ):
        """Store the model and the training hyper-parameters.

        Args:
            model (Autoencoder): Autoencoder not compiled.
            coefficient_threshold (float, optional): Stored on the instance;
                not read by any method of this class yet. Defaults to 0.1.
            threshold_frequency (int, optional): Stored on the instance; not
                read by any method of this class yet. Defaults to 50.
            loss_weights (List[float], optional): Weights in the order
                ``[decoder, sindy_z, sindy_x, sindy_regularization]``.
                Defaults to ``[1.0, 0.0, 1e-4, 1e-5]``.
        """
        if loss_weights is None:
            loss_weights = [1.0, 0.0, 1e-4, 1e-5]

        # SINDY related attributes
        self.model = model
        self.coefficient_threshold = coefficient_threshold
        self.threshold_frequency = threshold_frequency

        # Index 1 weights the latent-space (z) term and index 2 the
        # input-space (x) term. The attribute names used to be swapped here
        # and swapped back in get_loss_refinement; the resulting loss value
        # is unchanged, but each attribute now matches the term it scales.
        self.loss_weight_decoder = loss_weights[0]
        self.loss_weight_sindy_z = loss_weights[1]
        self.loss_weight_sindy_x = loss_weights[2]
        self.loss_weight_sindy_regularization = loss_weights[3]

        self.currentState = {}
        self.theta = None
        self.coefficient_mask = None
        self.sindy_coefficients = None
        self.library_dim = None

    def training_step(self, x):
        """Placeholder: not implemented, returns ``None``."""
        return

    def fit(self,
            optimizer : Optional[tf.keras.optimizers.Optimizer] = None,
            batch_size : int = 1024,
            refinement_epochs = 1001,
            data_path : Optional[str] = None):
        """Placeholder for the training loop; currently returns ``None``.

        Args:
            optimizer: Optimizer to use. ``None`` means a new
                ``Adam(learning_rate=1e-3)`` created for this call.
            batch_size: Not used yet.
            refinement_epochs: Not used yet.
            data_path: Not used yet. ``None`` means the current working
                directory followed by ``"/"``.
        """
        # Resolve defaults at call time: the previous signature built one
        # Adam instance and read os.getcwd() once, when the class was defined,
        # so every call shared the same optimizer and a frozen directory.
        if optimizer is None:
            optimizer = tf.keras.optimizers.Adam(learning_rate = 1e-3)
        if data_path is None:
            data_path = os.getcwd() + "/"
        # TODO: training loop not implemented.
        return

    def get_loss(self, x, dx):
        """Return the refinement loss plus the weighted regularization term.

        Args:
            x: Input passed to ``get_loss_refinement``.
            dx: Derivative input passed to ``get_loss_refinement``.
        """
        loss_refinement, losses = self.get_loss_refinement(x, dx)
        return loss_refinement + self.loss_weight_sindy_regularization*losses['sindy_regularization']

    def get_loss_refinement(self, x, dx):
        """Compute the individual loss terms and their weighted sum.

        Requires ``self.theta`` and ``self.sindy_coefficients`` to be set
        (see ``sindy_predict``).

        Args:
            x: Input passed through the full autoencoder.
            dx: Input passed through the encoder to obtain ``dz``.

        Returns:
            Tuple ``(loss_refinement, losses)``. ``losses`` has the keys
            ``'decoder'``, ``'sindy_z'``, ``'sindy_x'`` and
            ``'sindy_regularization'``; ``loss_refinement`` is the weighted
            sum of the first three.
        """
        # NOTE(review): dz is obtained by encoding dx, and dx_decode by
        # decoding dz_predict. For a nonlinear network that is not the same
        # as propagating derivatives through the Jacobian -- confirm this is
        # the intended formulation.
        dz = self.model.encode(dx)
        dz_predict = self.sindy_predict()
        x_decode = self.model.call(x)
        dx_decode = self.model.decode(dz_predict)

        losses = {}
        losses['decoder'] = tf.reduce_mean((x - x_decode)**2)
        losses['sindy_z'] = tf.reduce_mean((dz - dz_predict)**2)
        losses['sindy_x'] = tf.reduce_mean((dx - dx_decode)**2)
        # NOTE(review): the regularization uses the unmasked coefficients,
        # whereas Autoencoder.get_loss applies coefficient_mask first --
        # confirm which is intended.
        losses['sindy_regularization'] = tf.reduce_mean(tf.abs(self.sindy_coefficients))

        loss_refinement = self.loss_weight_decoder * losses['decoder'] \
            + self.loss_weight_sindy_z * losses['sindy_z'] \
            + self.loss_weight_sindy_x * losses['sindy_x']

        return loss_refinement, losses

    def set_theta(self, x, poly_order : Optional[int] = 3, include_sin : Optional[bool] = True):
        """Build the SINDy library from ``x`` and reset the coefficient mask.

        Sets ``self.theta``, ``self.library_dim`` and an all-ones
        ``self.coefficient_mask`` of shape ``(library_dim, latent_dim)``.
        ``self.sindy_coefficients`` is not touched.

        Args:
            x: Data passed to ``sindy_library_tf``.
            poly_order: Polynomial order passed to the library helpers.
            include_sin: Flag passed to the library helpers.
        """
        latent_dim = self.model.latent_dim
        self.theta = sindy_library_tf(x, latent_dim, poly_order, include_sin)
        self.library_dim = library_size(latent_dim, poly_order, include_sin, True)
        self.coefficient_mask = np.ones((self.library_dim, latent_dim))

    def sindy_predict(self, ):
        """Return ``theta @ (coefficient_mask * sindy_coefficients)``.

        ``self.theta`` and ``self.coefficient_mask`` are populated by
        ``set_theta``; ``self.sindy_coefficients`` is initialised to ``None``
        and must be assigned by the caller before this method is used.
        """
        return tf.matmul(self.theta, self.coefficient_mask * self.sindy_coefficients)

    def compile(self, ):
        """Placeholder: not implemented, returns ``None``."""
        return