In [None]:
import numpy as np
import torch
from torch.utils.data import DataLoader
from dataset import LandscapeSimulationDataset
from model import PhiNN
from model_training import train_model
from helpers import select_device, jump_function, mean_cov_loss

In [None]:
# Output location and tensor configuration shared by every cell below.
outdir = "out/testing"
dtype = torch.float32
# Kept on CPU; note the profiler cell below records only ProfilerActivity.CPU.
device = 'cpu'

# Seed the global RNGs so the shuffled training batch (DataLoader with
# shuffle=True draws from torch's global RNG) and, presumably, PhiNN's random
# parameter initialisation are identical on every Restart & Run All.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

In [None]:
# Locations of the training / validation simulation data.
datdir_train = "data/model_training_data"
datdir_valid = "data/model_validation_data"

# Simulation counts passed to each dataset.
nsims_train, nsims_valid = 2, 1

# Sizes passed to the dataset and to PhiNN (ndim / nsig / ncells).
ndims, nsigs, ncells = 2, 2, 100

# dt is used by train_model; sigma is passed to PhiNN.
dt, sigma = 1e-3, 1e-3

batch_size = 2

In [None]:
def make_simulation_dataset(datdir, nsims):
    """Construct a ``LandscapeSimulationDataset`` with tensor transforms.

    Input and target transforms are both ``'tensor'``; ``ndims`` and ``dtype``
    come from the notebook's config cells.

    Args:
        datdir: Directory holding the simulation data.
        nsims: Simulation count passed through to the dataset.

    Returns:
        The constructed dataset.
    """
    return LandscapeSimulationDataset(
        datdir, nsims, ndims,
        transform='tensor',
        target_transform='tensor',
        dtype=dtype,
    )


def make_dataloader(dataset, shuffle):
    """Wrap ``dataset`` in a ``DataLoader`` using the configured ``batch_size``."""
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


train_dataset = make_simulation_dataset(datdir_train, nsims_train)
validation_dataset = make_simulation_dataset(datdir_valid, nsims_valid)

train_dataloader = make_dataloader(train_dataset, shuffle=True)
validation_dataloader = make_dataloader(validation_dataset, shuffle=False)

print("Training Dataset Length:", len(train_dataset))
print("Validation Dataset Length:", len(validation_dataset))
print("Training DataLoader Length:", len(train_dataloader))
print("Validation DataLoader Length:", len(validation_dataloader))

# A later cell pulls a batch with next(iter(train_dataloader)); fail here with
# a clear message instead of an opaque StopIteration there.
assert len(train_dataset) > 0, f"no training samples found in {datdir_train!r}"

In [None]:
# Construct the model
def f_signal(t, p):
    """Signal function handed to PhiNN; evaluates ``jump_function`` at ``t``.

    Splits the last axis of the parameter tensor ``p`` into ``p[..., 0]``,
    ``p[..., 1:3]`` and ``p[..., 3:]`` and forwards them to ``jump_function``.
    Defined with ``def`` rather than an assigned lambda so it has a real name
    in tracebacks / profiler output and can be pickled.

    NOTE(review): the fixed slices presumably encode
    [jump time, nsigs values before, nsigs values after], which only lines up
    with nsigs == 2 — confirm against helpers.jump_function before changing nsigs.
    """
    return jump_function(t, p[..., 0], p[..., 1:3], p[..., 3:])


model = PhiNN(
    ndim=ndims, nsig=nsigs, f_signal=f_signal,
    ncells=ncells,
    sigma=sigma,
    device=device,
    dtype=dtype,
).to(device)

# Objective and optimizer: SGD with momentum over every model parameter.
loss_fn = mean_cov_loss

learning_rate, momentum = 1e-2, 0.9
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=learning_rate,
    momentum=momentum,
)

In [None]:
# Grab a single batch to profile with.
# NOTE(review): `input` shadows the built-in; kept because later cells use it.
first_batch = next(iter(train_dataloader))
input, expected = first_batch

In [None]:
# Show the shapes of the batch input and target tensors.
for batch_tensor in (input, expected):
    print(batch_tensor.shape)

In [None]:
from torch.profiler import profile, record_function, ProfilerActivity

In [None]:
# Profile a single forward pass on one batch, recording CPU operator timings,
# input shapes and memory allocations.
with profile(
    activities=[ProfilerActivity.CPU],
    record_shapes=True,
    profile_memory=True,
) as prof:
    model(input)
    # Alternative: profile one full training epoch instead of a single forward
    # pass by replacing the call above with:
    # train_model(
    #     model, dt, loss_fn, optimizer,
    #     train_dataloader, validation_dataloader,
    #     num_epochs=1,
    #     batch_size=batch_size,
    #     device=device,
    #     model_name='testmodel',
    #     outdir=outdir,
    # )

In [None]:
# Top 10 operators ranked by memory allocated in the operator itself
# (excluding child operators).
averages_by_self_mem = prof.key_averages()
print(averages_by_self_mem.table(sort_by="self_cpu_memory_usage", row_limit=10))


In [None]:
# Top 10 operators ranked by total CPU memory, including child operators.
averages_by_total_mem = prof.key_averages()
print(averages_by_total_mem.table(sort_by="cpu_memory_usage", row_limit=10))
