In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from tqdm import tqdm

In [4]:
from src.wassersteinflowmatching.riemannian_wasserstein import RiemannianWassersteinFlowMatching

### Create $SE(3)^N$ frames from MD trajectories

In [None]:
from prot_utils import pdb_to_se3_frames, trajectory_to_se3_frames

pc_frame_data = []
def get_traj_data(traj_path, topo_path):
    traj_frame_list = trajectory_to_se3_frames(trajectory_path=traj_path, topology_path=topo_path) # K, SE3^N
    return traj_frame_list

MDCATH_PATH = "./data/mdcath/"
for cath_domain_folder in os.listdir(MDCATH_PATH):
    traj_fp = os.path.join(MDCATH_PATH, cath_domain_folder, "traj.xtc")
    topo_fp = os.path.join(MDCATH_PATH, cath_domain_folder, "topo.pdb")
    traj_frame_list = trajectory_to_se3_frames(trajectory_path=traj_fp, topology_path=topo_fp) # K, SE3^N
    pc_frame_data.append(traj_frame_list)

print (pc_frame_data)

(10, 7)


In [None]:
class rwfm_config:
    geom: str = 'se3'
    monge_map: str = 'entropic'
    wasserstein_eps: float = 0.0005
    wasserstein_lse: bool = True
    num_sinkhorn_iters: int = -1
    mini_batch_ot_mode: bool = True
    mini_batch_ot_solver: str = 'chamfer'
    mini_batch_ot_num_iter: int = -1
    minibatch_ot_eps: float = 0.0005
    minibatch_ot_lse: bool = True
    noise_type: str = 'ambient_gaussian'
    noise_geom: str = 'se3'
    scaling: str = 'None'
    factor: float = 1.0
    embedding_dim: int = 512
    num_layers: int = 6
    num_heads: int = 4
    dropout_rate: float = 0.1
    mlp_hidden_dim: int = 512
    cfg: bool = False
    p_cfg_null: float = 0.0
    w_cfg: float = 1.0
    normalized_condition: bool = False

FlowMatchingModel = RiemannianWassersteinFlowMatching(point_clouds=pc_frame_data, config=rwfm_config)

Initializing WassersteinFlowMatching
Using se3 geometry
Projecting point clouds to geometry (with cpu)...



[A
[A
[A
100%|██████████| 2000/2000 [00:00<00:00, 5475.88it/s]


Using se3 geometry for noise
Using ambient_gaussian noise for se3 geometry.
Noise parameters:
  mean: [ 0.46023175 -0.04883802  0.02803123 -0.14033344 -0.47850034 -0.7602998
  0.74449974]
  cov_chol_mean: [[ 0.26104927  0.          0.          0.          0.          0.
   0.        ]
 [-0.0600263   0.53200155  0.          0.          0.          0.
   0.        ]
 [ 0.02057032  0.01560491  0.37145332  0.          0.          0.
   0.        ]
 [ 0.04721212  0.04816902 -0.09110928  0.5090113   0.          0.
   0.        ]
 [ 0.39639664  1.1992683   0.9183101  -2.6099527   1.7739029   0.
   0.        ]
 [ 2.993685   -0.5299187   0.36893567 -1.5379986   0.55535805  0.4615231
   0.        ]
 [ 0.29117587  0.342597    0.32018888  0.74421126 -0.73465973  0.49401674
   0.8455943 ]]
  cov_chol_std: [[8.6547338e-08 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00
  0.0000000e+00 0.0000000e+00]
 [1.6639621e-08 6.3423002e-08 0.0000000e+00 0.0000000e+00 0.0000000e+00
  0.0000000e+00 0.000

In [None]:

FlowMatchingModel.train(batch_size = 32, 
                        shape_sample = 1024,
                        training_steps = 500000, 
                        decay_steps = 5000)