In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.autograd.functional as F
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
class TimeSeriesDataset(Dataset):
    def __init__(self, time_series, sequence_length, prediction_step=1):
        self.time_series = torch.FloatTensor(time_series)
        self.sequence_length = sequence_length
        self.prediction_step = prediction_step

    def __len__(self):
        return len(self.time_series) - self.sequence_length - self.prediction_step + 1

    def __getitem__(self, idx):
        y = self.time_series[idx : idx + self.sequence_length]
        y_dot = (
            self.time_series[
                idx
                + self.prediction_step : idx
                + self.sequence_length
                + self.prediction_step
            ]
            - self.time_series[idx : idx + self.sequence_length]
        ) / self.prediction_step
        return y, y_dot

In [30]:
class SINDyAutoencoder(nn.Module):
    def __init__(self, input_dim, latent_dim, library_dim):
        super().__init__()
        self.encoder_lstm = nn.LSTM(input_dim, 32, batch_first=True)
        self.encoder_linear = nn.Linear(32, latent_dim)
        self.decoder_linear = nn.Linear(latent_dim, 32)
        self.decoder_lstm = nn.LSTM(32, input_dim, batch_first=True)
        self.library = SINDyLibrary(latent_dim)
        self.xi = nn.Parameter(torch.randn(library_dim, latent_dim).cuda())

    def forward(self, y):
        # y shape: [batch_size, sequence_length, input_dim]
        lstm_out, _ = self.encoder_lstm(y)
        z = self.encoder_linear(
            lstm_out
        )  # z shape: [batch_size, sequence_length, latent_dim]

        decoder_in = self.decoder_linear(z)
        y_reconstructed, _ = self.decoder_lstm(decoder_in)
        return z, y_reconstructed

    def encoder(self, y):
        lstm_out, _ = self.encoder_lstm(y)
        return self.encoder_linear(lstm_out)

    def decoder(self, z):
        decoder_in = self.decoder_linear(z)
        y_reconstructed, _ = self.decoder_lstm(decoder_in)
        return y_reconstructed

    def sindy_predict(self, z):
        # Ensure z is 3D: [batch_size, sequence_length, latent_dim]
        if z.dim() == 2:
            z = z.unsqueeze(1)
        theta = self.library(z)
        return torch.einsum("...ij,jk->...ik", theta, self.xi)

In [20]:
import torch
import torch.autograd.functional as F


def compute_sindy_losses(model, y, y_dot):
    # Ensure y requires grad
    y = y.detach().requires_grad_(True)

    # Forward pass
    z, y_reconstructed = model(y)

    # Compute gradients
    d_phi_d_y = F.jacobian(model.encoder, y)
    d_phi_d_y = d_phi_d_y.permute(0, 2, 1, 3)

    # Compute SINDy prediction
    z_dot_sindy = model.sindy_predict(z)

    # Compute L_ż
    z_dot_chain_rule = torch.matmul(d_phi_d_y, y_dot.unsqueeze(-1)).squeeze(-1)
    L_z_dot = torch.norm(z_dot_chain_rule - z_dot_sindy, p=2, dim=(1, 2)).mean()

    # Compute L_ẏ
    d_psi_d_z = F.jacobian(model.decoder, z)
    d_psi_d_z = d_psi_d_z.permute(0, 2, 1, 3)

    y_dot_reconstructed = torch.matmul(d_psi_d_z, z_dot_sindy.unsqueeze(-1)).squeeze(-1)
    L_y_dot = torch.norm(y_dot - y_dot_reconstructed, p=2, dim=(1, 2)).mean()

    return L_z_dot, L_y_dot

In [21]:
class SINDyLibrary(nn.Module):
    def __init__(self, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim

    def forward(self, z):
        # Ensure z is 3D: [batch_size, sequence_length, latent_dim]
        if z.dim() == 2:
            z = z.unsqueeze(1)

        batch_size, seq_len, _ = z.shape
        library = [torch.ones(batch_size, seq_len, 1).to(z.device)]
        for i in range(self.latent_dim):
            library.append(z[:, :, i : i + 1])
            for j in range(i, self.latent_dim):
                library.append(z[:, :, i : i + 1] * z[:, :, j : j + 1])
        return torch.cat(library, dim=2)


from tqdm import tqdm


def train(model, data_loader, num_epochs, learning_rate, lambda_1, lambda_2):
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    mse_loss = nn.MSELoss()

    for epoch in range(num_epochs):
        # Initialize tqdm progress bar
        pbar = tqdm(total=len(data_loader), desc=f"Epoch {epoch+1}/{num_epochs}")

        # Initialize loss accumulators
        epoch_recon_loss = 0
        epoch_z_dot_loss = 0
        epoch_y_dot_loss = 0
        epoch_total_loss = 0

        for batch in data_loader:
            y, y_dot = batch
            y = y.cuda()
            y_dot = y_dot.cuda()

            # Forward pass
            z, y_reconstructed = model(y)

            # Compute losses
            reconstruction_loss = mse_loss(y, y_reconstructed)
            L_z_dot, L_y_dot = compute_sindy_losses(model, y, y_dot)

            # Total loss
            total_loss = reconstruction_loss + lambda_1 * L_z_dot + lambda_2 * L_y_dot

            # Backpropagation and optimization
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            # Accumulate losses
            epoch_recon_loss += reconstruction_loss.item()
            epoch_z_dot_loss += L_z_dot.item()
            epoch_y_dot_loss += L_y_dot.item()
            epoch_total_loss += total_loss.item()

            # Update progress bar
            pbar.update(1)
            pbar.set_postfix(
                {
                    "Recon Loss": f"{epoch_recon_loss / (pbar.n + 1):.4f}",
                    "Z_dot Loss": f"{epoch_z_dot_loss / (pbar.n + 1):.4f}",
                    "Y_dot Loss": f"{epoch_y_dot_loss / (pbar.n + 1):.4f}",
                    "Total Loss": f"{epoch_total_loss / (pbar.n + 1):.4f}",
                }
            )

        # Close progress bar
        pbar.close()

        # Print epoch summary
        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"  Recon Loss: {epoch_recon_loss / len(data_loader):.4f}")
        print(f"  Z_dot Loss: {epoch_z_dot_loss / len(data_loader):.4f}")
        print(f"  Y_dot Loss: {epoch_y_dot_loss / len(data_loader):.4f}")
        print(f"  Total Loss: {epoch_total_loss / len(data_loader):.4f}")
        print()

In [4]:
data = np.load("../data/exp_pro/v1_anim5_tp6_actwvelp.npy", allow_pickle=True)
x = data[100, :]

In [5]:
sequence_length = 50  # Length of input sequences
batch_size = 32
prediction_step = 1  # For computing y_dot

# Create dataset and dataloader
dataset = TimeSeriesDataset(x, sequence_length, prediction_step)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [24]:
# Usage
input_dim = 3  # Adjust based on your data
latent_dim = 2  # Adjust based on your expected system complexity
library_dim = 6  # Adjust based on your SINDy library

num_epochs = 10
learning_rate = 1e-3
lambda_1 = 0.1  # Weight for L_z_dot
lambda_2 = 0.1  # Weight for L_y_dot

In [32]:
y, y_dot = next(iter(dataloader))

In [35]:
model = SINDyAutoencoder(50, latent_dim, library_dim).cuda()

In [36]:
train(model, dataloader, num_epochs, learning_rate, lambda_1, lambda_2)

Epoch 1/10:  20%|█▉        | 46/233 [00:51<03:32,  1.13s/it, Recon Loss=0.9601, Z_dot Loss=1.4889, Y_dot Loss=6.3128, Total Loss=1.7403]

KeyboardInterrupt: 

In [96]:
y, y_dot = next(iter(dataloader))
z, y_reconstructed = model(y)

In [97]:
theta = model.library(z)

In [None]:
torch.matmul()

In [98]:
theta.shape, model.xi.shape

(torch.Size([32, 6]), torch.Size([6, 2]))

In [58]:
y, y_dot = next(iter(dataloader))

y = y.detach().requires_grad_(True)

z = model.encoder(y)
y_reconstructed = model.decoder(z)

# Compute gradients
d_phi_d_y = torch.autograd.functional.jacobian(model.encoder, y)

# Compute SINDy prediction
z_dot_sindy = model.sindy_predict(z)

# Compute L_ż
# Reshape d_phi_d_y for batch matrix multiplication
d_phi_d_y = d_phi_d_y.permute(0, 2, 1, 3)  # (batch, latent_dim, seq_len, feature_dim)
z_dot_chain_rule = torch.matmul(d_phi_d_y, y_dot.unsqueeze(-1)).squeeze(-1)
L_z_dot = torch.norm(z_dot_chain_rule - z_dot_sindy, p=2, dim=(1, 2)).mean()

# Compute L_ẏ
d_psi_d_z = torch.autograd.functional.jacobian(model.decoder, z)
y_dot_reconstructed = torch.matmul(d_psi_d_z, z_dot_sindy.unsqueeze(-1)).squeeze(-1)
L_y_dot = torch.norm(y_dot - y_dot_reconstructed, p=2, dim=(1, 2)).mean()

RuntimeError: The size of tensor a (50) must match the size of tensor b (32) at non-singleton dimension 1

In [62]:
d_psi_d_z.shape, z_dot_sindy.unsqueeze(-1).shape

(torch.Size([32, 50, 32, 2]), torch.Size([32, 2, 1]))

In [67]:
def compute_sindy_losses(model, y, y_dot):
    # Ensure y requires grad
    y = y.detach().requires_grad_(True)

    z = model.encoder(y)
    y_reconstructed = model.decoder(z)

    print(f"y shape: {y.shape}")
    print(f"y_dot shape: {y_dot.shape}")
    print(f"z shape: {z.shape}")
    print(f"y_reconstructed shape: {y_reconstructed.shape}")

    # Compute gradients
    d_phi_d_y = torch.autograd.functional.jacobian(model.encoder, y)
    print(f"d_phi_d_y shape: {d_phi_d_y.shape}")

    # Compute SINDy prediction
    z_dot_sindy = model.sindy_predict(z)
    print(f"z_dot_sindy shape: {z_dot_sindy.shape}")

    # Compute L_ż
    # Reshape d_phi_d_y for batch matrix multiplication
    d_phi_d_y = d_phi_d_y.permute(
        0, 2, 1, 3
    )  # (batch, latent_dim, seq_len, feature_dim)
    z_dot_chain_rule = torch.matmul(d_phi_d_y, y_dot.unsqueeze(-1)).squeeze(-1)
    print(f"z_dot_chain_rule shape: {z_dot_chain_rule.shape}")
    L_z_dot = torch.norm(z_dot_chain_rule - z_dot_sindy, p=2, dim=(1, 2)).mean()

    # Compute L_ẏ
    d_psi_d_z = torch.autograd.functional.jacobian(model.decoder, z)
    print(f"d_psi_d_z shape: {d_psi_d_z.shape}")
    y_dot_reconstructed = torch.matmul(d_psi_d_z, z_dot_sindy.unsqueeze(-1)).squeeze(-1)
    L_y_dot = torch.norm(y_dot - y_dot_reconstructed, p=2, dim=(1, 2)).mean()

    return L_z_dot, L_y_dot

In [68]:
y, y_dot = next(iter(dataloader))

compute_sindy_losses(model, y, y_dot)

TypeError: linear(): argument 'input' (position 1) must be Tensor, not tuple