In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [3]:
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from tqdm import tqdm
import h5py as h5
import pandas as pd
from glob import glob

In [4]:
from src.wassersteinflowmatching.riemannian_wasserstein import RiemannianWassersteinFlowMatching

### Create $SE(3)^N$ frames from MD trajectories

In [11]:
from prot_utils import get_ca_trajectory, atom2frame
from anim_utils import animate_protein_trajectory
from lovely_numpy import lo
import os

pc_frame_data = []
MDCATH_PATH = "./data/mdcath/"

def normalize_to_unit_range(trajectory):
    """Scale trajectory to [0, 1] range."""
    min_val = trajectory.min()
    max_val = trajectory.max()
    return (trajectory - min_val) / (max_val - min_val)

for cath_domain_folder in os.listdir(MDCATH_PATH):
    fp = os.path.join(MDCATH_PATH, cath_domain_folder)
    ca_traj, n_traj, c_traj, domain_name = get_ca_trajectory(fp) # T, N_atoms, 3
    ca_traj = normalize_to_unit_range(ca_traj)
    n_traj = normalize_to_unit_range(n_traj)
    c_traj = normalize_to_unit_range(c_traj)

    traj_frames = atom2frame(n_traj, ca_traj, c_traj) # T, N_residues, 7
    # flatten to (T, N_residues * 7)
    traj_frames = traj_frames.reshape(traj_frames.shape[0], -1)
    pc_frame_data.append(traj_frames)

    # anim, fig = animate_protein_trajectory(
    #         ca_traj,
    #         interval=50,
    #         color='royalblue',
    #         line_width=2.5,
    #         atom_size=40,
    #         save_path=f'{domain_name}_traj.gif',
    #         fps=20,
    #         rotate=True,
    #         rotation_speed=1.0
    #     )

In [12]:
class rwfm_config:
    geom: str = 'se3'
    monge_map: str = 'entropic'
    wasserstein_eps: float = 0.0005
    wasserstein_lse: bool = True
    num_sinkhorn_iters: int = -1
    mini_batch_ot_mode: bool = True
    mini_batch_ot_solver: str = 'chamfer'
    mini_batch_ot_num_iter: int = -1
    minibatch_ot_eps: float = 0.0005
    minibatch_ot_lse: bool = True
    noise_type: str = 'ambient_gaussian'
    noise_geom: str = 'se3'
    scaling: str = 'None'
    factor: float = 1.0
    embedding_dim: int = 512
    num_layers: int = 6
    num_heads: int = 4
    dropout_rate: float = 0.1
    mlp_hidden_dim: int = 512
    cfg: bool = False
    p_cfg_null: float = 0.0
    w_cfg: float = 1.0
    normalized_condition: bool = False

FlowMatchingModel = RiemannianWassersteinFlowMatching(point_clouds=pc_frame_data, config=rwfm_config)

Initializing WassersteinFlowMatching
Using se3 geometry
Projecting point clouds to geometry (with cpu)...


100%|██████████| 1/1 [00:00<00:00, 784.42it/s]

Using se3 geometry for noise





Using ambient_gaussian noise for se3 geometry.
Noise parameters:
  mean: [ 3.94753665e-01 -1.19476207e-01  9.58336145e-02 -1.13014504e-01
  6.66530132e-01  7.12723315e-01  2.21704721e-01  3.61017853e-01
 -6.59082532e-02  8.82879794e-02 -8.10468346e-02  6.64953947e-01
  7.12037683e-01  2.23439351e-01  3.50886345e-01  5.26630841e-02
  3.02176923e-04 -4.19805795e-02  6.62006736e-01  7.12164164e-01
  2.23997161e-01  3.22864830e-01 -8.43037367e-02  1.48636416e-01
 -1.63256004e-01  6.59088373e-01  7.13947654e-01  2.24425793e-01
  4.08790827e-01  1.20511249e-01  2.31229246e-01 -2.20377892e-01
  6.56071246e-01  7.14164972e-01  2.23235846e-01  3.93928289e-01
 -5.45090139e-02  1.50608197e-01  1.36423344e-02  6.55372441e-01
  7.12895334e-01  2.20977068e-01  3.92070591e-01  1.40746400e-01
  2.00205982e-01 -2.52541035e-01  6.52534127e-01  7.13449121e-01
  2.20269650e-01  4.65511292e-01 -4.25383225e-02  1.29577667e-01
  2.40337253e-02  6.51388466e-01  7.12491095e-01  2.18405545e-01
  3.82525146e-01 

: 

In [None]:

FlowMatchingModel.train(batch_size = 32, 
                        shape_sample = 1024,
                        training_steps = 500000, 
                        decay_steps = 5000)