In [None]:
# Re-import edited modules (e.g. the local learning_dynamics package) before each cell executes,
# so changes to the project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import pickle
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.colors as mcolors
import torch
from learning_dynamics.data_modules import FingerAltDataModule
from learning_dynamics.models import FingerVAEModel, Encoder, Decoder
from learning_dynamics.callbacks import FingerPEGPVAEPlotting
import wandb
from lightning.pytorch.callbacks import ModelCheckpoint
import lightning as L
# Make float64 the default floating-point dtype for tensors and module parameters created below.
torch.set_default_dtype(torch.float64)

# Seed Python, NumPy and PyTorch (including dataloader workers) before the data module is set up,
# so any random train/val split and the weight initialisation are reproducible across kernel restarts.
seed = 0
L.seed_everything(seed, workers=True)

In [None]:
# Optional Weights & Biases logging, disabled by default. Set use_wandb = True to create a
# WandbLogger for the "PEGP-VAE" project. The imports are local so nothing W&B-related is
# needed when logging is off. The trainers below pass logger=False, so a logger created here
# is not yet wired into training.
use_wandb = False
if use_wandb:
    import time
    from lightning.pytorch.loggers import WandbLogger

    wandb_logger = WandbLogger(project="PEGP-VAE", name=f"Physics_0_{int(time.time())}")
else:
    wandb_logger = None

In [None]:
# Pickled real-finger data (shortened variant). The filename appears to encode the generation
# settings (time, ls, mean, var, scale) — NOTE(review): confirm against the data-prep script.
train_data_path = (
    '../../../../data/ODEs/finger/real_finger/pickle_files/'
    'shortened_real_finger_pos_time_1.5_ls_0.6_mean_22_var_4_scale_100.pkl'
)

In [None]:
# Build the data module and prepare its datasets for training.
# val_fraction is presumably the share of samples held out for validation — confirm in FingerAltDataModule.
batch_size = 32
val_fraction = 0.4

data_module = FingerAltDataModule(
    train_data_path=train_data_path,
    batch_size=batch_size,
    val_fraction=val_fraction,
)
data_module.setup("fit")

In [None]:
# params
# Frame geometry and dt come from the dataset wrapped by train_dataset (reached via `.dataset`,
# presumably a torch Subset produced by the train/val split — verify in FingerAltDataModule).
base_dataset = data_module.train_dataset.dataset
width = base_dataset.width
height = base_dataset.height

embed_dim = 128          # embedding size passed to Encoder/Decoder
latent_dim = 1           # latent size passed to Encoder/Decoder
dt = base_dataset.dt     # forwarded to the model and plotting callback (presumably the frame interval)
val_check = 100          # validate and checkpoint every `val_check` epochs
length_scale = 0.5       # NOTE(review): not referenced by any cell shown here
kld_max = 1              # KL settings forwarded to FingerVAEModel
kld_schedule = 0         #   (presumably max KL weight / annealing schedule — confirm)
file = "vae_0.4_video"   # run name: checkpoint sub-directory and `file` arg of the plotting callback;
                         # "0.4" appears to mirror val_fraction — keep them in sync
max_epochs = 4000

In [None]:
# Keep every checkpoint (save_top_k=-1), written after validation (not at train-epoch end)
# every `val_check` epochs, with names like model-epoch=099.ckpt.
# save_last=True also maintains last.ckpt in the same directory. Without that file,
# trainer.fit(..., ckpt_path="last") finds nothing to resume from and silently trains from scratch.
checkpoint_callback = ModelCheckpoint(
    dirpath=f"model_checkpoints/{file}",
    filename="model-{epoch:03d}",
    save_top_k=-1,
    every_n_epochs=val_check,
    save_on_train_epoch_end=False,
    save_last=True,
)

In [None]:
# Main training run (with validation videos).
# The model is built here so the cell works on a fresh kernel; it previously relied on a `model`
# defined in a later cell.
# ckpt_path="last" resumes from the newest last*.ckpt in the checkpoint directory. Lightning only
# finds one if the ModelCheckpoint callback writes it (save_last=True); otherwise it warns and
# trains from scratch.
encoder = Encoder(width, height, embed_dim, latent_dim)
decoder = Decoder(width, height, embed_dim, latent_dim)
model = FingerVAEModel(encoder=encoder, decoder=decoder, dt=dt, kld_max=kld_max, kld_schedule=kld_schedule)

trainer = L.Trainer(
    devices=[0],
    max_epochs=max_epochs,
    callbacks=[FingerPEGPVAEPlotting(dt=dt, file=file, vids_per_batch=3), checkpoint_callback],
    check_val_every_n_epoch=val_check,
    logger=False,
)
trainer.fit(model, datamodule=data_module, ckpt_path="last")

In [None]:
# Second run with video logging disabled (vids_per_batch=0). By default a fresh model is trained
# from scratch. Set resume_from_final_checkpoint = True to instead reload the last periodic
# checkpoint and continue from it.
resume_from_final_checkpoint = False

# Checkpoints are written when (epoch + 1) is a multiple of val_check (every_n_epochs and
# check_val_every_n_epoch both use val_check). This is therefore the last epoch a max_epochs run
# checkpoints, e.g. 3999 for max_epochs=4000, val_check=100. Deriving it keeps the path in sync
# with the config instead of hardcoding the epoch.
last_saved_epoch = (max_epochs // val_check) * val_check - 1
ckpt_path = f"model_checkpoints/{file}/model-epoch={last_saved_epoch:03d}.ckpt"

encoder = Encoder(width, height, embed_dim, latent_dim)
decoder = Decoder(width, height, embed_dim, latent_dim)

trainer = L.Trainer(
    devices=[0],
    max_epochs=max_epochs,
    callbacks=[FingerPEGPVAEPlotting(dt=dt, file=file, vids_per_batch=0), checkpoint_callback],
    check_val_every_n_epoch=val_check,
    logger=False,
)

if resume_from_final_checkpoint:
    # encoder/decoder are passed explicitly, as in the original commented-out code.
    # NOTE(review): presumably the modules are not stored in the checkpoint's hyperparameters — confirm.
    model = FingerVAEModel.load_from_checkpoint(
        ckpt_path, encoder=encoder, decoder=decoder, dt=dt, kld_max=kld_max, kld_schedule=kld_schedule
    )
    trainer.fit(model, datamodule=data_module, ckpt_path=ckpt_path)
else:
    model = FingerVAEModel(encoder=encoder, decoder=decoder, dt=dt, kld_max=kld_max, kld_schedule=kld_schedule)
    trainer.fit(model, datamodule=data_module)