### Short introduction
Proteins are not a frozen set of atoms: they move and explore a universe of **conformations**, some of which are crucial for their normal functioning. Improving our knowledge of the conformations that proteins adopt, or might adopt, under certain conditions is therefore both a challenge and a top priority for structural biologists. I omit here the extensive theoretical introduction to the problem I am tackling in my PhD thesis; supplementary materials will be available in the corresponding folder or via a link (to be added soon).

### The origins of the NN
Few scientists address the problem of predicting plausible conformations with machine learning, because it comes with multiple issues that require an interdisciplinary background (e.g. in data processing, given the complexity of the object itself) and a long-term time investment.
While reviewing the literature, I came across the **[article](https://journals.aps.org/prx/abstract/10.1103/PhysRevX.11.011052)** published by Dr. Degiacomi's research group. They released the core of the source code on GitHub as the public repository **[molearn](https://github.com/Degiacomi-Lab/molearn/)**. I studied the proposed NN architecture more closely and saw an opportunity to use it in my own research, since it served the original authors' goals and still left room for improvement. To summarize the details needed for further understanding, I will provide several schemes and images below (to be added soon).

### Importing

In [27]:
import os
from copy import deepcopy

import wandb
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import tensorflow as tf
from matplotlib import pyplot as plt
from tqdm.notebook import tqdm, trange

import biobox
from molearn import Auto_potential, Autoencoder, load_data

import MDAnalysis as mda
import mdtraj as md
import nglview 
from MDAnalysis.analysis import align, rms
from MDAnalysis.tests.datafiles import PDB, XTC

import plotly as py
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from PIL import Image

### Specifying paths to files

In [3]:
DATA_ORIGIN = "/home/ebam/kacher1/RFAH_MD_sampled/run0"
EXPERIMENT_NAME = 'rfah_aligned_each_100'
DATA = "./DATA"

ROOT = os.path.join(DATA, EXPERIMENT_NAME)

checkpoints_dir = os.path.join(ROOT, 'checkpoints')
checkpoints_extra_dir = os.path.join(ROOT, 'checkpoints_extra')
conformations_dir = os.path.join(ROOT, 'conformations')
pdbs_dir = os.path.join(ROOT, 'pdbs')
weights_dir = os.path.join(ROOT, 'weights')

# create the experiment folder and its subfolders (missing parents such as DATA included)
for directory in (ROOT, checkpoints_dir, checkpoints_extra_dir,
                  conformations_dir, pdbs_dir, weights_dir):
    os.makedirs(directory, exist_ok=True)

### Trajectory preprocessing and alignment

In [None]:
traj = mda.Universe(f'{DATA_ORIGIN}/positions.pdb', f'{DATA_ORIGIN}/simulation_data/traj.dcd')
ref = mda.Universe(f'{DATA_ORIGIN}/positions.pdb')

aligned = align.AlignTraj(
    traj,  # trajectory to align
    ref,  # reference
    select='all',  # selection of atoms to align
    filename=f'{DATA_ORIGIN}/aligned.dcd',  # file to write the trajectory to
    match_atoms=True,  # whether to match atoms based on mass
    in_memory=False, 
    verbose=True
).run()
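
The alignment step removes overall translation and rotation so that frames differ only by internal motion. As a minimal sketch of what such a rigid superposition does, the block below implements the classic Kabsch algorithm in NumPy on toy coordinates. This is an illustration of the idea only; MDAnalysis may use a different algorithm internally, and `kabsch_align` is a hypothetical helper, not part of any library used in this notebook.

```python
import numpy as np

def kabsch_align(mobile, reference):
    """Rigidly superpose mobile (N, 3) onto reference (N, 3), minimizing RMSD."""
    # remove translation by centering both coordinate sets
    ref_center = reference.mean(axis=0)
    P = mobile - mobile.mean(axis=0)
    Q = reference - ref_center
    # covariance matrix and its singular value decomposition
    H = P.T @ Q
    U, S, Vt = np.linalg.svd(H)
    # guard against an improper rotation (reflection)
    d = np.sign(np.linalg.det(Vt.T @ U.T))
    R = Vt.T @ np.diag([1.0, 1.0, d]) @ U.T
    # rotate the centered mobile set, then move it onto the reference centroid
    return (R @ P.T).T + ref_center

# toy data: a rotated and shifted copy of random reference coordinates
rng = np.random.default_rng(0)
reference = rng.normal(size=(10, 3))
theta = 0.7
rot_z = np.array([[np.cos(theta), -np.sin(theta), 0.0],
                  [np.sin(theta),  np.cos(theta), 0.0],
                  [0.0, 0.0, 1.0]])
mobile = reference @ rot_z.T + np.array([5.0, -2.0, 1.0])

aligned_toy = kabsch_align(mobile, reference)
rmsd_after = np.sqrt(((aligned_toy - reference) ** 2).sum(axis=1).mean())
print(rmsd_after)
```

After superposition the RMSD between the toy structures should be numerically zero, because they differ only by a rigid-body transformation.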

In [None]:
# comparing data 
# not_aligned = md.load(
#     f'{DATA_ORIGIN}/simulation_data/traj.dcd', 
#     top=f'{DATA_ORIGIN}/positions.pdb',
# )

# not_aligned

In [None]:
# visualizing trajectory if needed
# view = nglview.show_mdtraj(aligned)
# view

In [None]:
molecule = mda.Universe(
    f'{DATA_ORIGIN}/positions.pdb', 
    f'{DATA_ORIGIN}/aligned.dcd'
)
ag = molecule.atoms
num_of_frames = 100  # used as a stride: keep every 100th frame of the aligned trajectory

with mda.Writer(f'{DATA_ORIGIN}/aligned_each_{num_of_frames}.dcd', ag.n_atoms) as w:
    for ts in tqdm(molecule.trajectory[::num_of_frames]):
        w.write(ag)
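
The subsampling above keeps every 100th frame of the aligned trajectory. A minimal sketch in plain Python of how many frames survive a given stride; the total frame count used here is hypothetical and chosen only for illustration.

```python
def frames_kept(n_frames, stride):
    """Number of frames selected by trajectory[::stride]."""
    # range() mirrors slice semantics: indices 0, stride, 2*stride, ...
    return len(range(0, n_frames, stride))

# hypothetical trajectory length, for illustration only
n_total = 261_200
kept = frames_kept(n_total, stride=100)
print(kept)
```

Because slicing always includes frame 0, a trajectory whose length is not a multiple of the stride still contributes one extra frame (for example, 250 frames with stride 100 give 3 frames).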

In [None]:
# filtered = mda.Universe(f'{DATA_ORIGIN}/aligned_each_{num_of_frames}.dcd')
# filtered.trajectory

!mdconvert \
    -o {DATA_ORIGIN}/aligned_each_{num_of_frames}.pdb \
    -t {DATA_ORIGIN}/positions.pdb \
    {DATA_ORIGIN}/aligned_each_{num_of_frames}.dcd

### Loading data to NN

In [4]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

floc = [f'{DATA_ORIGIN}/aligned_each_100.pdb']  # multi-frame PDB of the protein
batch_size = 32  # if this is too small, GPU utilization goes down
epoch = 0
iter_per_epoch = 25  # was 5; a higher value (e.g. 1000) gives smoother statistics plots
method = 'roll'  # 3 methods available: 'roll', 'convolutional', 'indexing'

In [5]:
# load multiPDB and create loss function

dataset, meanval, stdval, atom_names, mol = \
    load_data(floc[0], atoms = ["CA", "C", "N", "CB", "O"], device=device, get_max_rmsd=False) 

# to compute the pairwise RMSD matrix, set get_max_rmsd=True and rmsd_from_file=False;
# if it has already been computed, set rmsd_from_file=True and specify f'{ROOT}/rmsd_matrix.npy'

lf = Auto_potential(frame=dataset[0]*stdval, pdb_atom_names=atom_names, method = method, device=device)

Conformations: (also saved to dataset_conformations.npy)
[]
File amino12.lib opened
File parm10.dat opened
PARM99 + frcmod.ff99SB + frcmod.parmbsc0 + OL3 for RNA

parameters loaded
File frcmod.ff14SB opened
ff14SB protein backbone and sidechain parameters

Determining bonds
Determining angles


Preparation of alternative test0 and test1 (if needed)

In [6]:
# if you want to get test0 and test1 for validation without calculating the whole pairwise rmsd matrix, set get_max_rmsd = False
# then use the function
def alternative_rmsd(molecule):
    """Pick two mutually distant frames without the full pairwise RMSD matrix."""
    # keep explicit copies so that later iteration cannot overwrite the stored frames
    test0 = molecule.trajectory[0].copy()
    test1 = test0
    max_rmsd = 0.0
    # pass 1: find the frame farthest from the first frame
    for ts in molecule.trajectory:
        # ts.positions - numpy array of positions, ts.frame - current frame number (0-based)
        rmsd = rms.rmsd(test0.positions, ts.positions, center=True, superposition=True)
        if rmsd > max_rmsd:
            max_rmsd = rmsd
            test1 = ts.copy()
    # pass 2 (triangle inequality check): look for a frame even farther from test1
    for ts in molecule.trajectory:
        rmsd = rms.rmsd(test1.positions, ts.positions, center=True, superposition=True)
        if rmsd > max_rmsd:
            max_rmsd = rmsd
            test0 = ts.copy()

    return test0, test1
# for more information on rmsd calculation https://docs.mdanalysis.org/1.0.1/documentation_pages/analysis/rms.html?highlight=rmsd#MDAnalysis.analysis.rms.rmsd
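
The two-pass search above avoids the quadratic cost of a full pairwise RMSD matrix. A minimal sketch of the same idea on toy 2D points, using Euclidean distance as a stand-in for RMSD; `two_pass_extremes` is a hypothetical helper written only for this illustration. Note that this is a heuristic: it returns a pair that is far apart, but it is not guaranteed to be the most distant pair in every dataset.

```python
import numpy as np

def two_pass_extremes(points):
    """Return indices (a, b) of two mutually distant points using two linear scans."""
    # pass 1: farthest point from the first one
    d_from_first = np.linalg.norm(points - points[0], axis=1)
    idx_b = int(d_from_first.argmax())
    # pass 2: farthest point from the result of pass 1
    d_from_b = np.linalg.norm(points - points[idx_b], axis=1)
    idx_a = int(d_from_b.argmax())
    return idx_a, idx_b

toy_points = np.array([[0.0, 0.0],
                       [1.0, 0.0],
                       [5.0, 0.0],
                       [-4.0, 0.0],
                       [2.0, 1.0]])
idx_a, idx_b = two_pass_extremes(toy_points)
separation = np.linalg.norm(toy_points[idx_a] - toy_points[idx_b])
print(idx_a, idx_b, separation)
```

Each pass is linear in the number of frames, which is why this approach scales to long trajectories where the full matrix would be too expensive.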

In [7]:
# for getting the info about test0 and test1 in advance 
molecule = mda.Universe(f'{DATA_ORIGIN}/aligned_each_100.dcd')
alternative_rmsd(molecule=molecule)

(< Timestep 2031 with unit cell dimensions None >,
 < Timestep 2584 with unit cell dimensions None >)

In [8]:
test0_ts, test1_ts = alternative_rmsd(molecule)

In [9]:
test0 = dataset[test0_ts.frame]
test1 = dataset[test1_ts.frame]
test0.shape, test1.shape

(torch.Size([3, 286]), torch.Size([3, 286]))

Continue from here if not using alternative_rmsd

In [140]:
dataset.shape

torch.Size([2612, 3, 286])

In [10]:
num_atoms = dataset.shape[2]
num_atoms

286

In [11]:
# If get_rmsd == True:
# Saving test structures (the most extreme conformations in terms of RMSD)
# Remember to rescale with stdval, permute axis from [3,N] to [N,3]
# unsqueeze to [1, N, 3], send back to cpu, and convert to numpy array.

crds =  (test0*stdval).permute(1,0).unsqueeze(0).data.cpu().numpy()
mol.coordinates = crds
mol.write_pdb(f'{pdbs_dir}/TEST0.pdb')

crds =  (test1*stdval).permute(1,0).unsqueeze(0).data.cpu().numpy()
mol.coordinates = crds
mol.write_pdb(f'{pdbs_dir}/TEST1.pdb')
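
The rescaling and axis permutation described in the comments above can be checked on a tiny array. This is a minimal NumPy sketch of the same shape manipulation (the notebook itself uses the torch equivalents `permute` and `unsqueeze`); the atom count and `stdval_toy` are hypothetical values.

```python
import numpy as np

n_atoms_toy = 4
stdval_toy = 2.5  # hypothetical scaling factor

# normalized frame in the network layout: [3, N] (xyz first, atoms second)
frame = np.arange(3 * n_atoms_toy, dtype=float).reshape(3, n_atoms_toy)

# rescale, swap axes to [N, 3], then add a leading frame axis -> [1, N, 3]
coords = (frame * stdval_toy).T[np.newaxis, ...]
print(coords.shape)
```

After the transformation, `coords[0, i]` holds the rescaled x, y, z of atom `i`, which is the layout a coordinate writer typically expects.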

### Defining visualization functions

In [12]:
def visualize_energy(network, stdval, num_atoms, lf, size=100):
    x = np.linspace(0, 1, size)
    y = np.linspace(0, 1, size)

    xx, yy = np.meshgrid(x, y) 
    z = np.stack((xx, yy), axis=2)
    z = z.reshape(size**2, 2, 1)

    losses_all = []
    for i in trange(0, len(z), 1): 
        z_batch = z[i:(i+1)]
        z_batch=torch.tensor(z_batch).float().to(device)
        out = network.decode(z_batch)[:, :, :num_atoms]
        out *= stdval
        bond_energy, angle_energy, torsion_energy, NB_energy = lf.get_loss(out)

        loss_f = (bond_energy + angle_energy + torsion_energy + NB_energy)
        losses_all.append(loss_f.item())
    
    img = np.array(losses_all)
    img = img.reshape(size, size)
    return img

def conformations_to_latent(network, train_loader):
    z_list=[]
    with torch.no_grad():
        for batch in tqdm(train_loader):
            x = batch[0].to(device)
            z = network.encode(x)
            z_list.append(z.cpu().squeeze(2))
    
    z_list = torch.cat(z_list)
    
    return z_list 

def visualization(network, stdval, num_atoms, lf, train_loader, size=100):
    img = visualize_energy(network, stdval, num_atoms, lf, size=size)   
    img_conf = conformations_to_latent(network, train_loader)
    
    plt.scatter(img_conf[:, 0], img_conf[:,1], c='red', alpha=0.5, s=1)
    plt.imshow(img, extent=[0, 1, 0, 1], origin='lower')
    plt.colorbar()
    # plt.contour(img extent=[0, 1, 0, 1], cmap="bwr")

    plt.xlim(-.1, 1.1)
    plt.ylim(-.1, 1.1)

    plt.savefig("/tmp/fig.png")
    plt.close()
    img = Image.open("/tmp/fig.png")
    
    return img
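
The energy map relies on the latent grid being flattened and later reshaped in a consistent order. A minimal NumPy sketch of that bookkeeping, with a hypothetical scalar function standing in for the decoded energy:

```python
import numpy as np

grid_size = 3
x = np.linspace(0, 1, grid_size)
y = np.linspace(0, 1, grid_size)

xx, yy = np.meshgrid(x, y)
# one latent point per grid cell, in the [batch, 2, 1] layout fed to the decoder
z_grid = np.stack((xx, yy), axis=2).reshape(grid_size**2, 2, 1)

# stand-in for the energy: any scalar function of a latent point (x + 10*y here)
values = np.array([p[0, 0] + 10 * p[1, 0] for p in z_grid])
img_toy = values.reshape(grid_size, grid_size)
print(img_toy)
```

Rows of the reshaped image correspond to `y` and columns to `x`, so plotting with `origin='lower'` places the point (0, 0) in the bottom-left corner, matching the scatter of encoded conformations.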

In [None]:
#if necessary to look at the image 
#image = visualization(network, stdval, num_atoms, lf, train_loader, size=100)
#image

Plotly implementation

In [28]:
def interactive_visulization(points, molecules):
    """
    points: [num_frames, 2]
    molecules: [num_frames, 3, num_atoms]
    """

    # Preparing figure object 
    fig = make_subplots(
        rows=1,
        cols=2,
        specs=[[{"type": "xy"}, {"type": "scene"}]],
    )

    points_plot = go.Scatter(
        x=points[:, 0],
        y=points[:, 1],
        hovertext=np.arange(points.shape[0]),
        #hovertext=np.arange( )
        mode="markers",
        name="latent_points"
    )

    molecule_3d_plot = go.Scatter3d(
        x = molecules[0, 0, :],
        y = molecules[0, 1, :],
        z = molecules[0, 2, :],
        mode="markers",
        marker=dict(
            size=3,
        ),
        name="molecule",
    )

    fig.add_trace(points_plot, row=1, col=1)
    fig.add_trace(molecule_3d_plot, row=1, col=2)

    layout = fig.layout
    # layout["xaxis"]["range"] = [-2.2, 2.2]
    # layout["yaxis"]["range"] = [-2.2, 2.2]
    # layout["zaxis"]["range"] = [-2.2, 2.2]

    # Preparing interactive widget
    widget = go.FigureWidget([points_plot, molecule_3d_plot], layout=layout)
    # widget.layout.hovermode = 'closest'

    points_plot = widget.data[0]
    molecule_3d_plot = widget.data[1]

    # Defining callback with 3 parameters: trace = graphical object with some changes, points and selector
    def callback(trace, points, selector):
        # A list with the size of the points, all of them have the same size
        # s = [5] * N # list(points_plot.marker.size)

        # Iteration over indexes, we want only one of the points
        for i in points.point_inds:
            # Imitating selection of the particular point
            # s[i] = 10

            # Go into batch_update for buffering
            with widget.batch_update():
                # coordinates of i frame
                molecule_3d_plot.x = molecules[i, 0, :]
                molecule_3d_plot.y = molecules[i, 1, :]
                molecule_3d_plot.z = molecules[i, 2, :]

                # changing sizes of all the points, it is necessary to make another list to make it work 
                # points_plot.marker.size = s

            return 

    # registering the callback; other available hooks include on_click and on_selection
    points_plot.on_hover(callback)

    # Widget visualization
    return widget

### Training

In [14]:
cfg = dict(
    batch_size=32,
    learning_rate=0.001,
    num_epochs=200,
)

train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(dataset.float()), 
               batch_size=2 * cfg["batch_size"], shuffle=True, drop_last=True, num_workers=0)
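
Each loader batch holds twice `cfg["batch_size"]` frames so that it can be split into two halves for interpolation during training. A small sketch of the resulting bookkeeping in plain Python, using the dataset size reported in this notebook (2612 frames):

```python
n_frames = 2612                 # dataset.shape[0] reported in this notebook
half_batch = 32                 # cfg["batch_size"]
loader_batch = 2 * half_batch   # frames drawn per training step

# drop_last=True discards the final incomplete batch
batches_per_epoch = n_frames // loader_batch
frames_dropped = n_frames - batches_per_epoch * loader_batch
print(batches_per_epoch, frames_dropped)
```

With shuffling enabled, the frames left out by `drop_last=True` differ from epoch to epoch, so every frame is still seen over the course of training.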

In [15]:
network = Autoencoder(m=2.0, latent_z=2, r=2).to(device)

In [16]:
# load a saved checkpoint to continue training or to reuse an already trained model
network.load_state_dict(torch.load(f'{checkpoints_dir}/epoch_0199_0.25607.pth'))

<All keys matched successfully>

In [18]:
# Sending to W&B; the hyperparameters are attached to the run via config
wandb.init(project="molearn", entity="jk", config=cfg)  # settings=wandb.Settings(start_method="thread")

optimiser = torch.optim.Adam(network.parameters(), lr=cfg["learning_rate"], amsgrad=True)

wandb: Currently logged in as: jk (use `wandb login --relogin` to force relogin)
wandb: wandb version 0.12.16 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade


In [19]:
def train_epoch(network, train_loader, loss_function, optimiser, num_atoms, stdval):
    network.train()
    
    for batch in tqdm(train_loader, desc=f"Training epoch #{epoch}...", leave=False):
        x = batch[0].to(device)
        
        x0 = x[:batch_size]
        x1 = x[batch_size:]
        optimiser.zero_grad()

        #encode
        z0 = network.encode(x0)
        z1 = network.encode(x1)

        #interpolate
        alpha = torch.rand(batch_size, 1, 1).to(device)
        z_interpolated = (1-alpha)*z0 + alpha*z1

        #decode
        out0 = network.decode(z0)[:,:,:num_atoms]
        out1 = network.decode(z1)[:,:,:num_atoms]
        out_interpolated = network.decode(z_interpolated)[:,:,:num_atoms]

        #calculate MSE
        mse_loss_0 = ((x0-out0)**2).mean() # reconstructive loss (Mean square error)
        mse_loss_1 = ((x1-out1)**2).mean() # reconstructive loss (Mean square error)
        out0 *= stdval
        out1 *= stdval
        out_interpolated *= stdval
        
        mse_loss = (mse_loss_0 + mse_loss_1) / 2

        #calculate physics for interpolated samples
        bond_energy, angle_energy, torsion_energy, NB_energy = loss_function.get_loss(out_interpolated)
        
        # The scale is built from plain Python floats (.item()) inside torch.no_grad(),
        # so autograd cannot see where it came from. Mathematically the physics term
        # would cancel out, but because no gradient flows through the scale, it acts
        # as a constant that is simply redefined at each step.
        
        with torch.no_grad():
            scale = 0.1*mse_loss.item()/(bond_energy.item()+angle_energy.item()+torsion_energy.item()+NB_energy.item())

        network_loss = mse_loss + scale*(bond_energy + angle_energy + torsion_energy + NB_energy)
        
        wandb.log(dict(
            mse_loss=mse_loss.item(),
            phys_loss=(bond_energy + angle_energy + torsion_energy + NB_energy).item(),
            bond_energy=bond_energy.item(),
            angle_energy=angle_energy.item(),
            torsion_energy=torsion_energy.item(),
            NB_energy=NB_energy.item(),
            network_loss=network_loss.item(),
        ))
        
        #wandb.watch(network)

        #determine gradients
        network_loss.backward()

        #advance the network weights
        optimiser.step()
        
    return network_loss
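
The scaling trick in `train_epoch` can be illustrated with a single scalar parameter. This is a minimal sketch in plain Python with hypothetical toy functions `mse_toy` and `phys_toy`; it only shows why treating the scale as a constant keeps the physics gradient alive even though the loss value always equals 1.1 times the MSE.

```python
def mse_toy(w):
    # hypothetical reconstruction term
    return (w - 1.0) ** 2

def phys_toy(w):
    # hypothetical physics term
    return 4.0 * w ** 2

w = 3.0
scale = 0.1 * mse_toy(w) / phys_toy(w)       # plain number, recomputed at every step
loss = mse_toy(w) + scale * phys_toy(w)      # in value, exactly 1.1 * mse_toy(w)

# gradient with the scale held constant (the situation in the training loop)
grad_constant_scale = 2.0 * (w - 1.0) + scale * 8.0 * w

# gradient if the scale were differentiated through: the physics term cancels
grad_tracked_scale = 1.1 * 2.0 * (w - 1.0)

print(loss, grad_constant_scale, grad_tracked_scale)
```

The two gradients differ, which is the point: with a constant scale the optimiser still receives a contribution from the physics term, weighted so that its value is 10% of the reconstruction loss.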

In [20]:
def validation(network, test0, test1, stdval, num_atoms, train_loader, loss_function=lf, img_size=100):
    #encode test with each network
    #Not training so switch to eval mode
    network.eval()
    
    interpolation_out = torch.zeros(20, num_atoms, 3)
    
    with torch.no_grad(): # don't need gradients for this bit
        test0_z = network.encode(test0.unsqueeze(0).float())
        test1_z = network.encode(test1.unsqueeze(0).float())

        #interpolate between the encoded Z space for each network between test0 and test1
        for idx, t in enumerate(np.linspace(0, 1, 20)):
            interpolation_out[idx] = network.decode(float(t)*test0_z + (1-float(t))*test1_z)[:,:,:num_atoms].squeeze(0).permute(1,0).cpu().data
        interpolation_out *= stdval

    img = visualization(network, stdval, num_atoms, loss_function, train_loader, size=img_size)
    wandb.log({"img": wandb.Image(img, caption=f"epoch_{epoch:0>4}")})

    #save interpolations
    mol.coordinates = interpolation_out.numpy()
    mol.write_pdb(f'{pdbs_dir}/epoch_{epoch:0>4}_interpolation.pdb')
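
The validation step walks along a straight line in latent space between the two encoded test structures. A minimal NumPy sketch of that interpolation, with hypothetical 2D latent codes `z0_toy` and `z1_toy` in place of the real encodings:

```python
import numpy as np

z0_toy = np.array([0.2, 0.8])   # hypothetical latent code of test0
z1_toy = np.array([0.9, 0.1])   # hypothetical latent code of test1

# same weighting as in validation(): t * z0 + (1 - t) * z1 for 20 values of t
ts = np.linspace(0, 1, 20)
path = np.stack([t * z0_toy + (1 - t) * z1_toy for t in ts])
print(path[0], path[-1])
```

With this weighting the path starts at the code of test1 (t = 0) and ends at the code of test0 (t = 1), so the frames of the saved interpolation PDB run from test1 to test0.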


In [21]:
# training loop
valid = True
for epoch in trange(cfg["num_epochs"], desc="Training..."):
    network_loss = train_epoch(network, train_loader, lf, optimiser, num_atoms, stdval)

    # save interpolations between test0 and test1 every 5 epochs (requires test0 and test1)
    if valid and (epoch + 1) % 5 == 0:
        validation(network, test0, test1, stdval, num_atoms, train_loader, loss_function=lf, img_size=100)

    # for the first training run, save to checkpoints_dir instead:
    # torch.save(network.state_dict(), f'{checkpoints_dir}/epoch_{epoch:0>4}_{network_loss.item():.5}.pth')
    # when continuing training from a loaded checkpoint, save to checkpoints_extra_dir
    torch.save(network.state_dict(), f'{checkpoints_extra_dir}/epoch_{epoch:0>4}_{network_loss.item():.5}.pth')

Training...:   0%|          | 0/200 [00:00<?, ?it/s]

Training epoch #0...:   0%|          | 0/40 [00:00<?, ?it/s]

Training epoch #1...:   0%|          | 0/40 [00:00<?, ?it/s]

Training epoch #2...:   0%|          | 0/40 [00:00<?, ?it/s]

Training epoch #3...:   0%|          | 0/40 [00:00<?, ?it/s]

Training epoch #4...:   0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]


### Looking at the results

In [29]:
#Loading information about the dataset without shuffling  
dataloader = torch.utils.data.DataLoader(dataset, batch_size, shuffle=False, drop_last=False)

points = np.empty(shape=(len(dataset), 2)) # creating empty array with number of frames and 2d points 
network = Autoencoder(m=2.0, latent_z=2, r=2).eval().to(device)

# Loading the best checkpoint from disk (from checkpoints_dir, or from checkpoints_extra_dir when training for more than 200 epochs)
state_dict = torch.load(f'{checkpoints_extra_dir}/epoch_0199_0.038058.pth')
network.load_state_dict(state_dict)  

# Iterating over batches
for i, batch in enumerate(dataloader):
    x = batch.float().to(device)

    with torch.no_grad():
        z = network.encode(x)
        
    points[i*batch_size:(i+1)*batch_size, :] = z.cpu().squeeze(-1).numpy()
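
Filling a preallocated array batch by batch works even when the last batch is shorter, because a slice that runs past the end of the array is truncated. A minimal NumPy sketch with hypothetical sizes and a stand-in for the encoder output:

```python
import numpy as np

n_frames_toy, bs = 70, 32   # hypothetical sizes; the last batch has only 6 frames

# stand-in for the encoded latent points of every frame, in order
encoded = np.arange(n_frames_toy * 2, dtype=float).reshape(n_frames_toy, 2)
batches = [encoded[s:s + bs] for s in range(0, n_frames_toy, bs)]

points_toy = np.empty((n_frames_toy, 2))
for i, z in enumerate(batches):
    # for the last batch the slice 64:96 is truncated to 64:70, matching z
    points_toy[i * bs:(i + 1) * bs, :] = z

print(points_toy.shape)
```

This only preserves the frame order if the loader is not shuffled, which is why the loader in this cell is created with `shuffle=False`.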

In [31]:
points = np.load("DATA/tmp/points.npy")
molecules = np.load("DATA/tmp/molecules.npy")

In [32]:
np.save("DATA/tmp/points.npy", points)
np.save("DATA/tmp/molecules.npy", molecules)

In [33]:
interactive_visulization(points, molecules)

FigureWidget({
    'data': [{'hovertext': array([0.000e+00, 1.000e+00, 2.000e+00, ..., 2.609e+03, 2.610e+03, 2…

### To be continued: results and pictures will be added and commented