In [None]:
# Automatically reload edited project modules (e.g. learning_dynamics) before each cell runs,
# so changes to the library code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import pickle
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.colors as mcolors
import torch
import torch.nn as nn
from learning_dynamics.data_modules import ShallowWaterDataModule
from learning_dynamics.models import ShallowWaterVAEModel, Encoder, Decoder
from learning_dynamics.callbacks import ShallowWaterPEGPVAEPlotting
import wandb
import lightning as L
# Make double precision (torch.double is the same dtype as torch.float64) the default
# floating-point dtype for every tensor created after this point, including model weights.
torch.set_default_dtype(torch.double)

In [None]:
# Experiment logging. Weights & Biases is off by default (wandb_logger stays None);
# set use_wandb = True to log this run to the "PEGP-VAE" W&B project instead.
use_wandb = False
if use_wandb:
    # Local imports: only needed when W&B logging is actually enabled.
    import time
    from lightning.pytorch.loggers import WandbLogger

    # Timestamp suffix keeps run names unique across repeated launches.
    wandb_logger = WandbLogger(project="PEGP-VAE", name=f"Physics_0_{int(time.time())}")
else:
    wandb_logger = None

In [None]:
# Training data file, resolved relative to the kernel's current working directory.
train_data_path = "../../data/PDEs/Shallow_Water/sin/shallow_water_data.pkl"

In [None]:
# Build the data module and run its "fit" setup eagerly, not just inside trainer.fit,
# because the parameter cell below reads data_module.train_dataset.dt and data_module.max_value.
# NOTE(review): trainer.fit(..., datamodule=data_module) presumably calls setup("fit") again;
# confirm ShallowWaterDataModule.setup is safe to call twice.
# NOTE(review): the .pkl extension suggests this file is unpickled; only load trusted data files.
data_module = ShallowWaterDataModule(train_data_path=train_data_path)
data_module.setup("fit")

In [None]:
# params

# Frame size, embedding width and latent size shared by the Encoder and Decoder.
width, height = 128, 128
embed_dim = 512
latent_dim = 1

# Values taken from the data module set up above (assumed to be populated by setup("fit") - confirm).
dt = data_module.train_dataset.dt
max_value = data_module.max_value  # later passed to the model as norm_constant

# Training schedule.
max_epochs = 6000
val_check = 50  # later passed to the Trainer as check_val_every_n_epoch
# NOTE(review): batch_size is never passed to ShallowWaterDataModule, the model or the Trainer
# in the cells shown here, so as written it has no effect; confirm it should be wired through.
batch_size = 32

# Settings forwarded to ShallowWaterVAEModel.
# kld_max / kld_schedule presumably control the KL-divergence weight and its schedule; confirm.
length_scale = 1
kld_max = 1
kld_schedule = 0
name = "vae_sin"

In [None]:
# Seed Python's `random`, NumPy and torch RNGs before building the networks, so the random
# weight initialisation below is reproducible across kernel restarts.
seed = 42
L.seed_everything(seed)

# Encoder and decoder are built with the same frame size, embedding width and latent size.
encoder = Encoder(width, height, embed_dim, latent_dim)
decoder = Decoder(width, height, embed_dim, latent_dim)

In [None]:
# Assemble the VAE from the encoder/decoder above, with the dataset's max value as its
# norm_constant.
model = ShallowWaterVAEModel(
    encoder,
    decoder,
    dt,
    name=name,
    norm_constant=max_value,
    length_scale=length_scale,
    kld_max=kld_max,
    kld_schedule=kld_schedule,
)

# Train for max_epochs, validating every val_check epochs, with the shallow-water plotting
# callback attached. The logger is wandb_logger (None unless W&B was enabled above).
trainer = L.Trainer(
    callbacks=[ShallowWaterPEGPVAEPlotting(dt)],
    logger=wandb_logger,
    max_epochs=max_epochs,
    check_val_every_n_epoch=val_check,
)
trainer.fit(model, datamodule=data_module)