In [1]:
import mdtraj as md
from pathlib import Path
import torch
import numpy as np
import sys
sys.path.append('../')
from molgen.models import WGANGP

In [2]:
# Location of the alanine-dipeptide dataset (one topology file plus several
# 250 ns solvent-free trajectories). Kept in one place so it is edited once.
data_dir = Path('/project/andrewferguson/Kirill/CMSC-35450/data_mdshare')
pdb_fname = str(data_dir / 'alanine-dipeptide-nowater.pdb')

# Path.glob yields matches in filesystem order, which is not stable across
# machines or runs. Sorting makes the order of `trjs` (and of everything built
# from it: training tensors, joined trajectories) reproducible.
trj_fnames = sorted(str(path) for path in data_dir.glob('alanine-dipeptide-*-250ns-nowater.xtc'))

# Fail here with a clear message instead of later with an IndexError on an
# empty list when the directory is missing or the pattern matches nothing.
if not trj_fnames:
    raise FileNotFoundError(
        f"No 'alanine-dipeptide-*-250ns-nowater.xtc' trajectories found in {data_dir}"
    )

# Each trajectory is loaded against the shared topology and centred; the value
# returned by center_coordinates() is what gets stored, as in the original.
trjs = [md.load(fname, top=pdb_fname).center_coordinates() for fname in trj_fnames]

In [3]:
def _backbone_features(traj):
    """Turn one trajectory into (flattened backbone coordinates, [phi, psi]) tensors.

    The backbone atoms are sliced out and centred, then flattened to one row
    per frame. The phi and psi dihedrals are computed on the full trajectory
    and joined along the last axis. Both results are float32 torch tensors.
    """
    backbone = traj.atom_slice(traj.top.select('backbone')).center_coordinates()
    n_frames = traj.xyz.shape[0]

    # compute_phi / compute_psi return (atom indices, angles); only angles are used.
    _, phi_angles = md.compute_phi(traj)
    _, psi_angles = md.compute_psi(traj)

    coords = torch.tensor(backbone.xyz.reshape(n_frames, -1)).float()
    angles = torch.tensor(np.concatenate((phi_angles, psi_angles), -1)).float()
    return coords, angles


# One entry per input trajectory, in the same order as `trjs`.
xyz = []
phi_psi = []
for trj in trjs:
    coords, angles = _backbone_features(trj)
    xyz.append(coords)
    phi_psi.append(angles)

# Sanity check of the per-trajectory shapes: (frames, features), (frames, conditions).
xyz[0].shape, phi_psi[0].shape

(torch.Size([250000, 24]), torch.Size([250000, 2]))

In [4]:
# Seed the global RNGs before the networks are created so that weight
# initialisation (and any later sampling drawn from the same RNG streams) is
# repeatable under Restart & Run All.
# NOTE(review): full determinism may also depend on how molgen / Lightning
# draw random numbers internally - confirm if bit-exact reruns are required.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Width of one flattened coordinate row and of one (phi, psi) condition row.
n_features = xyz[0].shape[1]
n_conditions = phi_psi[0].shape[1]

# Conditional WGAN-GP: first argument is the feature width, second the
# condition width (same positional order as the original call).
model = WGANGP(n_features, n_conditions)

In [5]:
# Number of passes over the training data; named so it is easy to find and tune.
MAX_EPOCHS = 25

# Train on the per-trajectory coordinate tensors, conditioned on the matching
# (phi, psi) tensors. The call's return value is left as the cell's displayed
# output, exactly as before.
model.fit(xyz, phi_psi, max_epochs=MAX_EPOCHS)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name              | Type                | Params
----------------------------------------------------------
0 | generator         | SimpleGenerator     | 172 K 
1 | discriminator     | SimpleDiscriminator | 138 K 
2 | _feature_scaler   | MinMaxScaler        | 0     
3 | _condition_scaler | MinMaxScaler        | 0     
----------------------------------------------------------
311 K     Trainable params
0         Non-trainable params
311 K     Total params
1.246     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=25` reached.


WGANGP(
  (generator): SimpleGenerator(
    (model): Sequential(
      (0): Linear(in_features=130, out_features=256, bias=True)
      (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU()
      (3): Linear(in_features=256, out_features=256, bias=True)
      (4): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): SiLU()
      (6): Linear(in_features=256, out_features=256, bias=True)
      (7): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): SiLU()
      (9): Linear(in_features=256, out_features=24, bias=True)
      (10): Tanh()
    )
  )
  (discriminator): SimpleDiscriminator(
    (model): Sequential(
      (0): Linear(in_features=26, out_features=256, bias=True)
      (1): SiLU()
      (2): Linear(in_features=256, out_features=256, bias=True)
      (3): SiLU()
      (4): Linear(in_features=256, out_features=256, bias=True)
      (5): SiLU()
      (6): 

In [6]:
import nglview as nv

# Backbone-only view of every loaded trajectory. Unlike the training features,
# these slices are not re-centred after slicing; they keep the coordinates of
# the loaded trajectories.
backbone_slices = [t.atom_slice(t.top.select('backbone')) for t in trjs]

# Concatenate all backbone trajectories into one, in the order of `trjs`.
trj_backbones = md.join(backbone_slices)

# Interactive viewer for the real (simulated) backbone frames.
v = nv.show_mdtraj(trj_backbones)
v



NGLWidget(max_frame=749999)

In [7]:
# Stack the per-trajectory (phi, psi) tensors into one batch of conditions and
# ask the trained model for one generated sample per condition row.
all_conditions = torch.cat(phi_psi)

# NOTE: this rebinds `xyz` from the list of training tensors to the generated
# tensor. Re-running the training cell after this point requires re-running
# the feature-extraction cell first.
xyz = model.generate(all_conditions)

In [8]:
# Unflatten each generated row back into per-atom (x, y, z) triples.
n_generated = xyz.size(0)
xyz = xyz.reshape(n_generated, -1, 3)

# Wrap the generated coordinates in a trajectory that reuses the backbone
# topology of the real data, so the two can be viewed the same way.
fake_coords = xyz.cpu().numpy()
fake_trj = md.Trajectory(xyz=fake_coords, topology=trj_backbones.top)

In [9]:
# Interactive viewer for the generated backbone frames. This rebinds `v`, which
# previously held the viewer for the real trajectories.
v = nv.show_mdtraj(fake_trj)
# Bare last expression so the notebook renders the widget as the cell output.
v

NGLWidget(max_frame=749999)