In [1]:
import numpy as np
import matplotlib.pyplot as plt
from msmbuilder.decomposition import tICA
from scipy.spatial.distance import pdist
from matplotlib.colors import LogNorm
import pickle

In [None]:
# TICA plots for generated vs ground truth vs random conformers  

def _pairwise_distance_features(confs):
    '''Return an array with one condensed pairwise-distance vector (scipy pdist) per conformer in confs.'''
    # Each element of confs is expected to be an (n_atoms, 3) coordinate array.
    # NOTE(review): torch tensors must live on CPU for pdist's array conversion to work.
    return np.array([pdist(conf) for conf in confs])


def _plot_tica_panel(ax, projected, title, bins):
    '''Draw a log-scaled 2D histogram of the first two tICs of projected on ax.'''
    ax.hist2d(projected[:, 0], projected[:, 1], bins=bins, cmap='hot_r', norm=LogNorm())
    ax.set_title(title)
    ax.set_xlabel('1st tIC')
    ax.set_ylabel('2nd tIC')


def make_tica_plots(md_trajs, generated_confs, rand_confs, lag_time=100, bins=100):
    '''
    Generate tICA plots for ground-truth MD, generated and random conformers.

    For each smiles key, a 2-component tICA model is fitted on the pairwise
    inter-atomic distances of the ground-truth MD trajectory only. The generated
    and random conformers are then projected onto those same two tICs, so all
    three panels share one coordinate system.

    Args:
        - md_trajs: dict, keys are smiles and values are an MD simulation of shape (time, n_atoms, 3)
        - generated_confs: dict, keys are smiles and values are arrays/tensors of generated 3D positions, shape (n_confs, n_atoms, 3)
        - rand_confs: dict, keys are smiles and values are arrays/tensors of random 3D positions, shape (n_confs, n_atoms, 3)
        - lag_time: int, lag time (in frames) used when fitting tICA on the MD trajectory (default 100)
        - bins: int, number of bins per axis for each 2D histogram (default 100)

    Returns:
        None. For each smiles, one figure with three panels (MD, generated,
        random) is displayed with plt.show().

    Raises:
        ValueError: if the three dicts do not have the same keys, or if the
            number of atoms differs between them for some smiles.
    '''
    # Validate everything up front so a bad input fails before any fitting/plotting.
    if not (md_trajs.keys() == generated_confs.keys() == rand_confs.keys()):
        raise ValueError('md_trajs, generated_confs and rand_confs must have the same smiles keys')
    for smi in md_trajs:
        md_atoms = md_trajs[smi].shape[1]
        gen_atoms = generated_confs[smi].shape[1]
        rand_atoms = rand_confs[smi].shape[1]
        if not (md_atoms == gen_atoms == rand_atoms):
            raise ValueError(
                f'Atom count mismatch for {smi}: MD={md_atoms}, '
                f'generated={gen_atoms}, random={rand_atoms}'
            )

    for smi in md_trajs:
        # Fit tICA on the MD trajectory only. tICA is given a list holding a
        # single trajectory of shape (n_frames, n_features), so [0] selects it.
        tica = tICA(n_components=2, lag_time=lag_time)
        md_proj = np.asarray(tica.fit_transform([_pairwise_distance_features(md_trajs[smi])])[0])
        # Project the other conformer sets with the MD-fitted model (no refit).
        gen_proj = np.asarray(tica.transform([_pairwise_distance_features(generated_confs[smi])])[0])
        rand_proj = np.asarray(tica.transform([_pairwise_distance_features(rand_confs[smi])])[0])

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        panels = (
            (md_proj, 'MD tICA'),
            (gen_proj, 'Generated tICA'),
            (rand_proj, 'Random tICA'),
        )
        for ax, (projected, title) in zip(axes, panels):
            _plot_tica_panel(ax, projected, title, bins)
        fig.suptitle(smi)
        plt.show()

NEXT STEP: Run TorsionalDiffusion on one molecule from the FreeSolv dataset