# Training a Simple AE and EGAE on a (biased) trajectory

## Import Dependencies

In [1]:
import jax 
import jax.numpy as jnp

from jax import random

import haiku as hk
import optax

from utils import load_dcd_dataset, bonds_to_graph
from models import MLP_AE, SimpleDecoder, EGEncoder
from training import fit

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def get_bonds_from_pdb(pdb_file_path, bond_distance_threshold=1.6):
    """Infer a bond list from a PDB file with a simple distance cutoff.

    Every ``ATOM`` record is read using the fixed-column PDB layout (serial
    in columns 7-11, x/y/z in columns 31-54). Two atoms are considered bonded
    when their Euclidean distance is at most ``bond_distance_threshold``.

    Args:
        pdb_file_path: Path to the PDB file to read.
        bond_distance_threshold: Maximum distance, in the coordinate units of
            the file, for a pair of atoms to count as bonded.

    Returns:
        A list of ``(i, j)`` tuples of zero-based indices, computed as the
        PDB atom serial minus one, in file order with ``i`` listed before
        ``j``. NOTE(review): this assumes serials are contiguous and start
        at 1 -- confirm for inputs other than the one used here.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If an ``ATOM`` record has malformed serial/coordinates.
    """
    # Imported locally so the cell stays self-contained.
    from itertools import combinations
    from math import dist

    atoms = []
    with open(pdb_file_path, 'r') as pdb_file:
        for line in pdb_file:
            # Only ATOM records are considered; everything else is skipped.
            if line[0:6].strip() != "ATOM":
                continue
            serial = int(line[6:11])
            position = (
                float(line[30:38]),
                float(line[38:46]),
                float(line[46:54]),
            )
            atoms.append((serial, position))

    # Plain-Python distances: the pairwise check is a scalar comparison, so
    # creating device arrays for every pair only adds dispatch overhead.
    return [
        (serial_a - 1, serial_b - 1)
        for (serial_a, pos_a), (serial_b, pos_b) in combinations(atoms, 2)
        if dist(pos_a, pos_b) <= bond_distance_threshold
    ]

In [3]:
# Infer the bond list from the PDB structure using the default distance cutoff.
pdb_file_path = "data/adp-vacuum.pdb"
bonds = get_bonds_from_pdb(pdb_file_path)

In [3]:
bonds

NameError: name 'bonds' is not defined

## Define the molecule structure (alanine dipeptide), global variables, and dataset.

In [4]:
# Particle count passed to bonds_to_graph and used to size the flattened
# coordinate vector (N_MOLECULES*3) in the models.
# NOTE(review): despite the name this looks like the number of atoms per
# frame rather than a number of molecules -- confirm.
N_MOLECULES = 22
# NOTE(review): the process_*_batch helpers squeeze axis 0, so they appear to
# rely on this staying at 1 -- confirm against load_dcd_dataset.
BATCH_SIZE = 1

# Bond list (pairs of zero-based indices) as an array for bonds_to_graph.
BONDS =  jnp.array(bonds)

# Graph description derived from the bonds; `edges` and `edge_attr` are
# consumed by process_egae_batch, `adj` is not used in the visible cells.
edges, edge_attr, adj = bonds_to_graph(BONDS, N_MOLECULES)

# Loader over the trajectory frames, built from the PDB and DCD files.
train_loader = load_dcd_dataset('data/adp-vacuum.pdb', './data/traj5.dcd', BATCH_SIZE)

  bonds = bonds.astype(jnp.integer)


## Experiment 1: training only with coordinates (no node features)

In [8]:
# Initialize models

@hk.transform
def mlp_ve(inputs):
    """Haiku-transformed forward pass of the MLP autoencoder.

    Builds an ``MLP_AE`` sized for a flattened coordinate vector of
    ``N_MOLECULES * 3`` entries and applies it to ``inputs``.
    """
    autoencoder = MLP_AE(in_ft=N_MOLECULES * 3, G=64, K=3)
    return autoencoder(inputs)

@hk.transform
def egae(inputs):
    """Haiku-transformed forward pass of the EGNN-encoder autoencoder.

    ``inputs`` is forwarded unchanged to an ``EGEncoder``; the second value
    it returns is used as the latent, flattened, and fed to a
    ``SimpleDecoder``.

    Returns:
        A ``(reconstruction, latent)`` pair.
    """
    encoder_kwargs = dict(
        hidden_nf=32,
        n_layers=3,
        z_dim=3,
        activation=jax.nn.swish,
        reg=1e-3,
    )
    # Modules are created encoder-first, then decoder, as in the original.
    graph_encoder = EGEncoder(**encoder_kwargs)
    coord_decoder = SimpleDecoder(in_ft=N_MOLECULES * 3, G=64)

    encoded = graph_encoder(inputs)
    _, latent = encoded
    flat_latent = latent.flatten()
    reconstruction = coord_decoder(flat_latent)
    return reconstruction, latent


In [6]:
def process_mlp_batch(batch):
    """Turn a loader batch into an ``(input, target)`` pair for the MLP.

    The batch must be a 3-item sequence; only its middle item is used. Its
    leading axis (which must have size 1) is dropped and the rest is
    flattened to 1-D. The same array serves as both input and target.
    """
    _first, frame, _last = batch
    flat_frame = frame.squeeze(0).flatten()
    return flat_frame, flat_frame

def process_egae_batch(batch):
    """Turn a loader batch into an ``(input, target)`` pair for the EGAE.

    The batch must be a 3-item sequence; only its middle item is used, with
    its leading size-1 axis dropped. The model input is the tuple
    ``(h, x, edges, edge_attr)`` where ``h`` is a column of ones with one
    row per entry along the first axis of ``x`` and ``edges``/``edge_attr``
    are the module-level graph arrays. The target is ``x`` itself.
    """
    _first, frame, _last = batch
    coords = frame.squeeze(0)
    n_rows = coords.shape[0]
    # Constant placeholder features: one all-ones column, one row per node.
    node_feats = jnp.ones((n_rows, 1))
    model_inputs = (node_feats, coords, edges, edge_attr)
    return model_inputs, coords

def compute_loss_mlp(params, x, y):
    """Mean absolute error between the MLP autoencoder output and ``y``.

    The model is applied with a fixed ``PRNGKey(0)``; its first output is
    reshaped to ``y.shape`` before the comparison.
    """
    fixed_key = jax.random.PRNGKey(0)
    reconstruction, _ = mlp_ve.apply(params, fixed_key, x)
    residual = reconstruction.reshape(y.shape) - y
    return jnp.mean(jnp.abs(residual))

def compute_loss_egae(params, x, y):
    """Mean absolute error between the EGAE reconstruction and ``y``.

    The model is applied with a fixed ``PRNGKey(0)``; the reconstruction is
    reshaped to ``y.shape`` before the comparison and the latent is ignored.
    """
    fixed_key = jax.random.PRNGKey(0)
    reconstruction, _latent = egae.apply(params, fixed_key, x)
    residual = reconstruction.reshape(y.shape) - y
    return jnp.mean(jnp.abs(residual))

### Train MLP AutoEncoder

In [19]:
# Seed used only for parameter initialisation of the MLP autoencoder.
rng = random.PRNGKey(390)

optimizer = optax.adam(learning_rate=1e-3)

# Pull one batch from the loader so the model can be initialised on a real
# example input.
batch = next(iter(train_loader))
x, _ = process_mlp_batch(batch)

initial_params = mlp_ve.init(rng, x)
    
# Final positional argument is 100 -- presumably the number of epochs; see
# training.fit.
# NOTE(review): mlp_ve reads N_MOLECULES when it is initialised; if that
# global was changed after an earlier run of this notebook, re-run the model
# definition cell first so the model size matches the data.
params = fit(initial_params, optimizer, compute_loss_mlp, process_mlp_batch, train_loader, 100)



InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (42,) and (66,)

In [10]:
# Train the EGAE (EGNN encoder + simple decoder).
# Seed used only for parameter initialisation.
rng = random.PRNGKey(391230)

# NOTE(review): this learning rate is four orders of magnitude smaller than
# the one used for the MLP autoencoder (1e-3) -- confirm it is intentional.
optimizer = optax.adam(learning_rate=1e-7)

# Pull one batch from the loader so the model can be initialised on a real
# example input.
batch = next(iter(train_loader))
x, _ = process_egae_batch(batch)

initial_params = egae.init(rng, x)
    
# Final positional argument is 100 -- presumably the number of epochs; see
# training.fit. Note that `params` is the same name the MLP training cell
# assigns to, so running both cells keeps only the most recent result.
params = fit(initial_params, optimizer, compute_loss_egae, process_egae_batch, train_loader, 100)

Epoch: 0 - loss: 19.547609329223633 - Execution time: 19.785537242889404 sec
Epoch: 1 - loss: 19.53062629699707 - Execution time: 52.64488887786865 sec
Epoch: 2 - loss: 19.49976921081543 - Execution time: 53.094120502471924 sec
Epoch: 3 - loss: 19.465970993041992 - Execution time: 52.71249771118164 sec
Epoch: 4 - loss: 19.44119644165039 - Execution time: 52.60916566848755 sec
Epoch: 5 - loss: 19.420146942138672 - Execution time: 52.63534188270569 sec
Epoch: 6 - loss: 19.39902114868164 - Execution time: 53.002936601638794 sec
Epoch: 7 - loss: 19.377634048461914 - Execution time: 52.440221071243286 sec
Epoch: 8 - loss: 19.355958938598633 - Execution time: 53.03923726081848 sec
Epoch: 9 - loss: 19.334089279174805 - Execution time: 52.76970839500427 sec
Epoch: 10 - loss: 19.312088012695312 - Execution time: 52.46065068244934 sec


KeyboardInterrupt: 