In [None]:
# Reload edited project modules (e.g. prostat) automatically before each cell executes.
%reload_ext autoreload
%autoreload 2

In [None]:
import yaml
from pathlib import Path

import torch
import numpy as np

from prostat.nn.autoencoder import AutoEncoder
from prostat.utils.dataset import load_dataset
from prostat.utils.plotting import plot, plot_dihedral_distribution
from prostat.trainer.train import train_autoencoder
from prostat.trainer.inference import test_autoencoder

# Make float64 the default floating-point dtype for tensors created from here on
# (this is global state for the whole kernel session).
torch.set_default_dtype(torch.double)

In [None]:
# Load the experiment configuration, then the dataset it describes.
config_file = "config/chignolin.yaml"
# Read with an explicit encoding: Path.read_text() otherwise falls back to the
# platform's locale encoding (e.g. cp1252 on Windows), while YAML config files
# are normally UTF-8.
conf = yaml.safe_load(Path(config_file).read_text(encoding="utf-8"))

dataset = load_dataset(conf)

In [None]:
#########################
### Build Autoencoder ###
#########################
locality = 2
desired_stride = 1
# The checkpoint filename encodes the config name, locality and stride, so runs
# with different settings read/write separate weight files.
path = f'{conf["name"]}_{locality}_{desired_stride}.pth'

# NOTE(review): dataset['R'].shape[-2] is AutoEncoder's first (size) argument;
# presumably the number of atoms/beads per frame — confirm.
model = AutoEncoder(dataset['R'].shape[-2], locality=locality, desired_stride=desired_stride).to(conf['device'])
if Path(path).is_file():
    # map_location lets weights saved on one device (e.g. a GPU) load on the
    # device this run is configured for. Any other failure (e.g. a state_dict
    # that no longer matches the architecture, or a corrupt file) is raised
    # rather than being misreported as a missing file.
    # NOTE(review): torch.load unpickles the file — only load trusted
    # checkpoints (consider weights_only=True on recent torch versions).
    model.load_state_dict(torch.load(path, map_location=conf['device']))
    print(f'Model weights file {path} loaded!')
else:
    print(f'Model weights file {path} is missing')
#########################
#########################
#########################

In [None]:
# Train the autoencoder on the loaded dataset using the settings in `conf`.
# NOTE(review): the return value is discarded and the next cell saves
# model.state_dict(), so this assumes train_autoencoder updates `model` in place — confirm.
train_autoencoder(model, dataset, conf)

In [None]:
# Persist the current weights to the same per-configuration path that the
# build cell tries to load from, so a later session can reload them.
torch.save(model.state_dict(), path)

In [None]:
# Evaluate the trained model. Per the unpacking below it yields six results:
# reconstructed positions, minimized reconstructed positions, bead positions,
# and three vector sets (v1, v2, v12; v12 is not used in this notebook).
# NOTE(review): the meaning of each output is inferred from names only —
# confirm against prostat.trainer.inference.test_autoencoder.
pos_recon, minimized_pos_recon, pos_beads, v1, v2, v12 = test_autoencoder(model, dataset, conf)

In [None]:
# Compare one test frame with its minimized reconstruction.
nth = 0  # index of the test frame to visualize
pos1 = dataset['pos_test'][nth]
# detach() before cpu() so the device-to-host copy is not recorded in the autograd graph.
pos2 = minimized_pos_recon[nth].detach().cpu().numpy()


def _flatten_frames(arr):
    """Collapse every axis after the first into one: (F, ...) -> (F, -1)."""
    return arr.reshape(arr.shape[0], -1)


# NOTE(review): plot's first argument appears to be the frame index used to pick
# the matching frame out of the per-frame bead/vector arrays passed below. It was
# hard-coded to 0, which would mismatch pos1/pos2 whenever nth != 0; passing nth
# keeps them in sync. Confirm against prostat.utils.plotting.plot.
plot(nth, pos1, pos2, _flatten_frames(pos_beads), _flatten_frames(v1), _flatten_frames(v2), dataset, bond_idcs=dataset['bond_idcs'])

In [None]:
##################################
### Plot Dihedral Distribution ###
##################################
# Plot dihedral distributions from the dataset alongside the raw (pos_recon)
# and minimized (minimized_pos_recon) reconstructions returned by test_autoencoder.
plot_dihedral_distribution(dataset, pos_recon, minimized_pos_recon)
##################################
##################################
##################################