In [1]:
# %% [markdown]
# Analysis of Sample Runs
# This notebook loads samples from four runs, computes energies, and generates Ramachandran plots and energy histograms.

# %%
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import mdtraj as md
import torch
from bgflow import XTBEnergy, XTBBridge
from bgflow.utils import as_numpy
from dataset.ad2_dataset import get_alanine_atom_types, get_alanine_implicit_dataset

import os
from bgflow.utils import remove_mean
from bgflow.utils import distances_from_vectors
from bgflow.utils import distance_vectors, as_numpy
from bgflow.bg import sampling_efficiency,unormalized_nll,effective_sample_size
from bgflow import XTBEnergy, XTBBridge
from models.interpolant import Interpolant
from models.ebm import GVP_EBM
from dataset.ad2_dataset import get_alanine_atom_types,get_alanine_implicit_dataset,get_alanine_features
from matplotlib.colors import LogNorm
import matplotlib.pyplot as plt
import mdtraj as md
import argparse
import wandb
from utils.arguments import get_args
from utils.utils import load_models
import ot


****** PyMBAR will use 64-bit JAX! *******
* JAX is currently set to 32-bit bitsize *
* which is its default.                  *
*                                        *
* PyMBAR requires 64-bit mode and WILL   *
* enable JAX's 64-bit mode when called.  *
*                                        *
* This MAY cause problems with other     *
* Uses of JAX in the same code.          *
******************************************



## W2 Calculations Function

In [2]:


def calc_energy_w2(gen_energies, holdout_energies):
    """Calcualate and log energy and torsion w2 distance 

    Args:
        gen_energies (np list): tensor of generated samples energies
        holdout_energies (np list): tenosr of holdout sample energies, must be same shape and size as gen samples
    """
    
    
    
    # flatten both to 1-D
    
    gen_energies = gen_energies.ravel()
    holdout_energies = holdout_energies.numpy(force = True).ravel()

    # sort them
    gen_energies_sorted = np.sort(gen_energies)
    holdout_energies_sorted = np.sort(holdout_energies)
    loss, log = ot.emd2_1d(gen_energies,holdout_energies,metric = "euclidean",log = True)
    print(log)

    # # compute MSE of the sorted values = W2^2
    # w2_squared = np.mean((gen_energies_sorted - holdout_energies_sorted)**2)

    # # take sqrt to get W2
    # W2 = np.sqrt(w2_squared)
    
    print(f"W2 distance: {loss:.6f}")


def calc_torsion_w2(gen_angles,holdout_angles):
    """calculates OT w2 Torsion angles 

    Args:
        gen_angles (np list ): np array of sidechain angles 
        holdout_angles (_type_): np array of sidechain angles
    """
    dist = np.expand_dims(gen_angles,0) - np.expand_dims(holdout_angles,1)
    dist = np.sum((dist % np.pi)**2,axis = -1)
    print(gen_angles.shape)
    print(holdout_angles.shape)
    print(dist.shape)
    # dist = np.sqrt(dist)
    a, b = ot.unif(gen_angles.shape[0]), ot.unif(gen_angles.shape[0])
    W,log = ot.emd2(a,b,dist,log = True, numItermax=1e9) # uniform weights as input
    # w2_circle = ot.wasserstein_circle(gen_angles, holdout_angles, p=2)
    W = np.sqrt(W)
    # print(f"Angles W2 distance: {w2_circle}")
    
    # w2_circle = ot.wasserstein_circle(gen_angles, holdout_angles, p=2)
    
    print(f"Angles W2 distance: { W}")
    # wandb.log({"angles_w2":   W})

def get_torsion_angles(samples_np):
    """given a list of samples, return the torsion angle

    Args:
        samples (tensor):samples

    returns list of torsion angles:
    """
    
    atom_types = get_alanine_atom_types()
    atom_types[[4,6,8,14,16]] = np.arange(4, 9)
    # carbon atoms
    carbon_pos = np.where(atom_types==1)[0]
    carbon_samples_np = samples_np.reshape(-1, 22, 3)[:, carbon_pos]
    carbon_distances = np.linalg.norm(samples_np.reshape(-1, 22, 3)[:, [8]] - carbon_samples_np, axis=-1)
    # likely index of c beta atom
    cb_idx = np.where(carbon_distances==carbon_distances.min(1, keepdims=True))


    def determine_chirality_batch(cartesian_coords_batch):
        # Convert Cartesian coordinates to numpy array
        coords_batch = np.array(cartesian_coords_batch)

        # Check if the shape of the array is (n, 4, 3), where n is the number of chirality centers
        if coords_batch.shape[-2:] != (4, 3):
            raise ValueError("Input should be a batch of four 3D Cartesian coordinates")

        # Calculate the vectors from the chirality centers to the four connected atoms
        vectors_batch = coords_batch - coords_batch[:, 0:1, :]
        #print(vectors_batch)
        # Calculate the normal vectors of the planes formed by the three vectors for each chirality center
        normal_vectors_batch = np.cross(vectors_batch[:, 1, :], vectors_batch[:, 2, :])
        #print(normal_vectors_batch)
        # Calculate the dot products of the normal vectors and the vectors from the chirality centers to the fourth atoms
        dot_products_batch = np.einsum('...i,...i->...', normal_vectors_batch, vectors_batch[:, 3, :])
        #print(dot_products_batch)
        # Determine the chirality labels based on the signs of the dot products
        chirality_labels_batch = np.where(dot_products_batch > .000, 'L', 'D')

        return chirality_labels_batch


    back_bone_samples = samples_np.reshape(-1, 22, 3)[:, np.array([8,6,14])]
    cb_samples = samples_np.reshape(-1, 22, 3)[cb_idx[0], carbon_pos[cb_idx[1]]][:, None, :]
    chirality = determine_chirality_batch(np.concatenate([back_bone_samples, cb_samples], axis=1))
    samples_np_mapped = samples_np.copy()
    samples_np_mapped[chirality=="D"] *= -1
    dataset=get_alanine_implicit_dataset()
    traj_samples3 = md.Trajectory(samples_np_mapped.reshape(-1, 22, 3), topology=dataset.system.mdtraj_topology)

    phi_indices, psi_indices = [4, 6, 8, 14], [6, 8, 14, 16]
    angles = md.compute_dihedrals(traj_samples3, [phi_indices, psi_indices])

    return angles


### load in data

In [3]:
import numpy as np

# map each key to its (energy_file, sample_file)
file_map = {
    "tbg": (
        "./generated/tbg_model_generated_energies.npy",
        "./generated/tbg_model_generated_samples.npy",
    ),
    "vector": (
        "./generated/unweighted_ot_rtol1e-5_atol1e-5_tmin0_rep0energies.npy",
        "./generated/unweighted_ot_rtol1e-5_atol1e-5_tmin0_rep0samples.npy",
    ),
    "vector_ema": (
        "./generated/unweighted_ot_ema_rtol1e-5_atol1e-5_tmin0_rep0energies.npy",
        "./generated/unweighted_ot_ema_rtol1e-5_atol1e-5_tmin0_rep0samples.npy",
    ),
    "endpoint": (
        "./generated/unweighted_ot_endpoint_tmax100_rtol1e-5_atol1e-5_tmin1e-3_rep0energies.npy",
        "./generated/unweighted_ot_endpoint_tmax100_rtol1e-5_atol1e-5_tmin1e-3_rep0samples.npy",
    ),
    "endpoint_ema": (
        "./generated/unweighted_ot_endpoint_tmax100_ema_rtol1e-5_atol1e-5_tmin1e-3_rep0energies.npy",
        "./generated/unweighted_ot_endpoint_tmax100_ema_rtol1e-5_atol1e-5_tmin1e-3_rep0samples.npy",
    ),
}

# load into two dicts
energies = {k: np.load(e_path) for k, (e_path, _) in file_map.items()}
samples  = {k: np.load(s_path) for k, (_,  s_path) in file_map.items()}

# quick check
print("Energy keys: ", list(energies.keys()))
print("Sample keys:", list(samples.keys()))


Energy keys:  ['tbg', 'vector', 'vector_ema', 'endpoint', 'endpoint_ema']
Sample keys: ['tbg', 'vector', 'vector_ema', 'endpoint', 'endpoint_ema']


### Calculate W2

In [None]:
import numpy as np

# hyperparams
num_particles = 22
n_dimensions  = 3

# number of W2 samples for non-tbg keys
n_w2_default  = 10000
# number of W2 samples for tbg
n_w2_tbg      = 10000
# number of repetitions for tbg
n_reps_tbg    = 5

# load your holdout data once
all_holdout = (
    np.load("../data/AD2_relaxed_holdout.npy")
    .reshape(-1, num_particles * n_dimensions)
)
energies_data_holdout = np.load("../data/energies_data_holdout.npy")

for key in energies:
    gen_e = energies[key]  # 1D array of generated energies
    gen_s = samples[key]   # 2D array of generated conformations

    # pick the right sample size & number of reps
    if key == "tbg":
        n_w2   = n_w2_tbg
        reps   = n_reps_tbg
    else:
        n_w2   = n_w2_default
        reps   = 1

    for rep in range(reps):
        suffix = f" (rep {rep+1})" if reps > 1 else ""

        print(f"\n===== {key.upper()}{suffix} =====")

        # — Energy W₂ —
        idx_gen_e  = np.random.randint(0, gen_e.shape[0],  size=n_w2)
        idx_hold_e = np.random.randint(0, energies_data_holdout.shape[0], size=n_w2)
        w2_gen_e   = gen_e[idx_gen_e]
        w2_hold_e  = torch.from_numpy(energies_data_holdout[idx_hold_e])

        print(f"--> Energy W₂ for {key}{suffix}:")
        calc_energy_w2(w2_gen_e, w2_hold_e)

        # — Torsion (angle) W₂ —
        idx_gen_s  = np.random.randint(0, gen_s.shape[0],  size=n_w2)
        idx_hold_s = np.random.randint(0, all_holdout.shape[0], size=n_w2)

        samp_gen  = gen_s[idx_gen_s].reshape(-1, num_particles, n_dimensions)
        samp_hold = all_holdout[idx_hold_s].reshape(-1, num_particles, n_dimensions)

        # center each sample
        samp_gen  -= samp_gen.mean(axis=1, keepdims=True)
        samp_hold -= samp_hold.mean(axis=1, keepdims=True)

        # flatten back
        samp_gen  = samp_gen.reshape(-1, num_particles * n_dimensions)
        samp_hold = samp_hold.reshape(-1, num_particles * n_dimensions)

        ang_gen  = get_torsion_angles(samp_gen)
        ang_hold = get_torsion_angles(samp_hold)

        print(f"--> Torsion W₂ for {key}{suffix}:")
        calc_torsion_w2(ang_gen, ang_hold)



===== TBG (rep 1) =====
--> Energy W₂ for tbg (rep 1):
{'G': array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]])}
W2 distance: 5.754796
Using downloaded and verified file: /tmp/A.pdb
Using downloaded and verified file: /tmp/A.pdb
Using downloaded and verified file: /tmp/A.pdb
Using downloaded and verified file: /tmp/A.pdb
--> Torsion W₂ for tbg (rep 1):
(20000, 2)
(20000, 2)
(20000, 20000)
