In [30]:
import numpy as np
import torch
from argparse import Namespace
import json

import os
import sys

# Make the project's top-level packages (`models`, `utils`, imported below) importable
# from this notebook. The import root is taken to be two directories above the current
# working directory.
# NOTE(review): assumes the notebook is launched from two levels below that root — TODO confirm.
src_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir))

# Put src_root first on sys.path exactly once. A plain insert would prepend a new
# duplicate entry every time this cell is re-run.
sys.path[:] = [entry for entry in sys.path if entry != src_root]
sys.path.insert(0, src_root)

# The project imports below now resolve relative to src_root.
from models.SINDyAE_o2 import Net
from utils.model_utils import sindy_library, build_equation_labels

In [8]:
# Which training session to inspect, and where its saved hyper-parameters live:
#   <MODEL_ROOT>/<DATA_SET>/<MODEL_NAME>/<sess_name>/args.txt
# NOTE(review): MODEL_ROOT is an absolute, machine-specific path — consider making it configurable.
MODEL_ROOT = '/home/jared/Projects/SINDy/SINDy_Autoencoder_PyTorch_EP/trained_models'
DATA_SET, MODEL_NAME = 'elastic_pendulum', 'SINDyAE_o2'
sess_name = '05-18-2025_0'
args_path = f'{MODEL_ROOT}/{DATA_SET}/{MODEL_NAME}/{sess_name}/args.txt'

In [9]:
# Load the session's saved hyper-parameters (a JSON object) into an argparse Namespace,
# so they can be passed to Net(...) and read as attributes (args.z_dim, args.poly_order, ...).
# The `with` block closes the file handle, which `json.load(open(...))` never did explicitly.
with open(args_path, 'r') as args_file:
    args = Namespace(**json.load(args_file))
print(vars(args))

{'session_name': '05-18-2025_0', 'model': 'SINDyAE_o2', 'experiments': './experiments/', 'model_folder': './trained_models/', 'tensorboard_folder': './tb_runs/', 'data_set': 'elastic_pendulum', 'z_dim': 2, 'u_dim': 2601, 'hidden_dims': [256, 128, 64], 'use_inverse': True, 'use_sine': True, 'use_cosine': True, 'poly_order': 3, 'include_constant': True, 'nonlinearity': 'elu', 'epochs': 500, 'learning_rate': 0.001, 'adam_regularization': 1e-05, 'gamma_factor': 0.995, 'batch_size': 64, 'lambda_ddx': 0.0005, 'lambda_ddz': 5e-05, 'lambda_reg': 1e-05, 'clip': None, 'test_interval': 1, 'checkpoint_interval': 1, 'sequential_threshold': 0.05, 'train_initial_conds': 200, 'val_initial_conds': 20, 'test_initial_conds': 20, 'timesteps': 500, 'load_cp': 0, 'device': 0, 'print_folder': 1}


In [10]:
# Select the compute device. On a CUDA machine this is unchanged from before: use the GPU
# index stored in the session args (`device` is then that integer ordinal). Without CUDA,
# fall back to the CPU instead of crashing in torch.cuda.set_device.
if torch.cuda.is_available():
    torch.cuda.set_device(args.device)
    device = torch.cuda.current_device()
else:
    device = torch.device('cpu')

In [12]:
# Build the autoencoder from the saved hyper-parameters and move it to `device`.
# nn.Module.to returns the module itself, so chaining is equivalent to a separate call.
# NOTE(review): nothing visible in this notebook loads a trained state dict. Unless Net(...)
# does so internally, these are freshly initialised weights — confirm this is intended.
net = Net(args).to(device)
net

Net(
  (mse): MSELoss()
  (encoder): Sequential(
    (0): Linear(in_features=2601, out_features=256, bias=True)
    (1): ELU(alpha=1.0)
    (2): Linear(in_features=256, out_features=128, bias=True)
    (3): ELU(alpha=1.0)
    (4): Linear(in_features=128, out_features=64, bias=True)
    (5): ELU(alpha=1.0)
    (6): Linear(in_features=64, out_features=2, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=2, out_features=64, bias=True)
    (1): ELU(alpha=1.0)
    (2): Linear(in_features=64, out_features=128, bias=True)
    (3): ELU(alpha=1.0)
    (4): Linear(in_features=128, out_features=256, bias=True)
    (5): ELU(alpha=1.0)
    (6): Linear(in_features=256, out_features=2601, bias=True)
  )
)

In [13]:
# NOTE(review): absolute, machine-specific path — consider making it configurable.
TRAIN_PATH = '/home/jared/Projects/SINDy/SINDy_Autoencoder_PyTorch_EP/data/elastic_pendulum/train.npy'

# The .npy file stores a pickled dict, so allow_pickle=True is required and .item()
# unwraps the 0-d object array. Unpickling can execute code, so only load trusted files.
train = np.load(TRAIN_PATH, allow_pickle=True).item()
train.keys()

dict_keys(['t', 'x', 'dx', 'ddx', 'z', 'dz'])

In [49]:
# Number of training snapshots to push through the network for this sanity check.
N_SAMPLES = 5


def to_batch(arr, n_samples=N_SAMPLES):
    """Return the first `n_samples` rows of `arr` as a float32 tensor on `device`.

    The result is reshaped to (batch, 1, args.u_dim), the layout passed to `net(...)`.
    `batch` equals `n_samples` unless `arr` has fewer rows.
    """
    batch = torch.tensor(arr[:n_samples], dtype=torch.float32, device=device)
    return batch.reshape(batch.shape[0], 1, args.u_dim)


x = to_batch(train['x'])
dx = to_batch(train['dx'])
ddx = to_batch(train['ddx'])

# Loss weights, packed as (lambda_ddx, lambda_ddz, lambda_reg) for net(...).
lambdas = (args.lambda_ddx, args.lambda_ddz, args.lambda_reg)

In [50]:
# Forward pass on the sample batch. The network returns a tuple of four loss tensors.
# NOTE(review): presumably (reconstruction, ddx, ddz, regularisation), matching the
# lambdas ordering — confirm against Net.forward.
losses = net(x, dx, ddx, lambdas)
losses

(tensor(0.0087, device='cuda:0', grad_fn=<MseLossBackward0>),
 tensor(68.8040, device='cuda:0', grad_fn=<MulBackward0>),
 tensor(880.1335, device='cuda:0', grad_fn=<MulBackward0>),
 tensor(6.5682e-07, device='cuda:0', grad_fn=<MulBackward0>))

In [26]:
# Stored latent states 'z' and their derivatives 'dz' for the first 5 samples,
# as float32 tensors created directly on `device`. torch.tensor replaces the legacy
# torch.Tensor constructor and avoids an extra host-side copy before .to(device).
z = torch.tensor(train['z'][:5], dtype=torch.float32, device=device)
dz = torch.tensor(train['dz'][:5], dtype=torch.float32, device=device)

In [29]:
# Evaluate the SINDy candidate library on the stored latents. The polynomial order is
# read from the session args (3 for this run), so it stays in sync with the trained model.
# NOTE(review): the four booleans presumably toggle include_constant / use_inverse /
# use_sine / use_cosine (all True in args). Confirm sindy_library's parameter order
# before replacing them with args fields.
sindy_library(z, dz, args.poly_order, device, True, True, True, True)

tensor([[ 1.0000,  0.8385,  3.1009,  ..., -0.4466,  0.6675, -0.9975],
        [ 1.0000,  0.8268,  3.0926,  ..., -0.4581,  0.6756, -0.9964],
        [ 1.0000,  0.8128,  3.0837,  ..., -0.4718,  0.6851, -0.9950],
        [ 1.0000,  0.7968,  3.0743,  ..., -0.4875,  0.6958, -0.9932],
        [ 1.0000,  0.7789,  3.0641,  ..., -0.5050,  0.7074, -0.9910]],
       device='cuda:0')

In [31]:
# Human-readable names for each library column. The latent dimension and polynomial order
# come from the session args (2 and 3 for this run), so the labels match the loaded model.
# NOTE(review): the four booleans presumably mirror the library flags used in
# sindy_library above — confirm parameter order before replacing them with args fields.
build_equation_labels(args.z_dim, args.poly_order, True, True, True, True)

['1',
 'X',
 'Y',
 'Xdot',
 'Ydot',
 '1/X',
 '1/Y',
 'sin(X)',
 'sin(Y)',
 'cos(X)',
 'cos(Y)',
 'X*X',
 'X*Y',
 'X*Xdot',
 'X*Ydot',
 'X*1/X',
 'X*1/Y',
 'X*sin(X)',
 'X*sin(Y)',
 'X*cos(X)',
 'X*cos(Y)',
 'Y*Y',
 'Y*Xdot',
 'Y*Ydot',
 'Y*1/X',
 'Y*1/Y',
 'Y*sin(X)',
 'Y*sin(Y)',
 'Y*cos(X)',
 'Y*cos(Y)',
 'Xdot*Xdot',
 'Xdot*Ydot',
 'Xdot*1/X',
 'Xdot*1/Y',
 'Xdot*sin(X)',
 'Xdot*sin(Y)',
 'Xdot*cos(X)',
 'Xdot*cos(Y)',
 'Ydot*Ydot',
 'Ydot*1/X',
 'Ydot*1/Y',
 'Ydot*sin(X)',
 'Ydot*sin(Y)',
 'Ydot*cos(X)',
 'Ydot*cos(Y)',
 '1/X*1/X',
 '1/X*1/Y',
 '1/X*sin(X)',
 '1/X*sin(Y)',
 '1/X*cos(X)',
 '1/X*cos(Y)',
 '1/Y*1/Y',
 '1/Y*sin(X)',
 '1/Y*sin(Y)',
 '1/Y*cos(X)',
 '1/Y*cos(Y)',
 'sin(X)*sin(X)',
 'sin(X)*sin(Y)',
 'sin(X)*cos(X)',
 'sin(X)*cos(Y)',
 'sin(Y)*sin(Y)',
 'sin(Y)*cos(X)',
 'sin(Y)*cos(Y)',
 'cos(X)*cos(X)',
 'cos(X)*cos(Y)',
 'cos(Y)*cos(Y)',
 'X*X*X',
 'X*X*Y',
 'X*X*Xdot',
 'X*X*Ydot',
 'X*X*1/X',
 'X*X*1/Y',
 'X*X*sin(X)',
 'X*X*sin(Y)',
 'X*X*cos(X)',
 'X*X*cos(Y)',
 'X*