In [None]:
# Cell 1
# Install the runtime dependencies into the Colab VM.
# NOTE(review): torch/torchvision/torchaudio are installed unpinned from the
# cu118 index, while the PyG extension wheels below are taken from the
# torch-2.1.0+cu118 wheel index. If pip resolves a different torch version the
# prebuilt torch-scatter/-sparse/-cluster wheels may not match it — confirm the
# resolved versions in the report printed below and consider pinning.
!pip install --quiet torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install --quiet torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric \
    -f https://data.pyg.org/whl/torch-2.1.0+cu118.html
!pip install --quiet openmm mdtraj parmed tqdm pandas numpy
!pip install pymbar statsmodels

import platform, sys, subprocess
import importlib.metadata as metadata

pkgs = [
    "torch", "torchvision", "torchaudio",
    "torch-scatter", "torch-sparse", "torch-cluster", "torch-spline-conv", "torch-geometric",
    "openmm", "mdtraj", "parmed", "tqdm", "pandas", "numpy", "pymbar", "statsmodels"
]

print("=== Python & OS Info ===")
print(f"Python: {sys.version}")
print(f"Platform: {platform.platform()}")

print("\n=== Hardware Info ===")
!nvidia-smi

print("\n=== Installed Package Versions ===")
for pkg in pkgs:
    try:
        ver = metadata.version(pkg)
    except metadata.PackageNotFoundError:
        ver = "Not installed"
    print(f"{pkg}: {ver}")

import torch
print("\n=== PyTorch CUDA Info ===")
print(f"PyTorch CUDA available: {torch.cuda.is_available()}")
print(f"PyTorch version: {torch.__version__}")
if torch.cuda.is_available():
    print(f"CUDA version (PyTorch): {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Cell 2
# Mount Google Drive at /content/drive; all outputs below are written under it.
from google.colab import drive
drive.mount('/content/drive')

import os
# Root directory for the artifacts written by this notebook (created if missing).
out_root = '/content/drive/MyDrive/alanine_dipeptide'
os.makedirs(out_root, exist_ok=True)
print("Output →", out_root)

In [None]:
# Cell 3
import json, numpy as np, mdtraj as md
from openmm.app import PDBFile, Modeller, ForceField, Simulation, DCDReporter, StateDataReporter, NoCutoff, HBonds
from openmm import LangevinIntegrator, unit

# Sampling plan: number of replicas, frames recorded per replica, MD steps
# between recorded frames, and the integrator timestep in femtoseconds.
n_reps         = 3
frames_per_rep = 8000
report_every   = 50
timestep_fs    = 2.0

# Description of the run that is persisted next to the data.
# 'dt_ps' is timestep_fs * report_every / 1000, i.e. the spacing between
# recorded frames expressed in picoseconds.
meta = {
    'solvent':      'implicit (OBC2)',
    'forcefield':   ['amber99sbildn.xml','amber99_obc.xml'],
    'integrator':   'Langevin',
    'temperature':  300,
    'friction_ps^-1': 1.0,
    'timestep_fs':  timestep_fs,
    'frames_per_rep': frames_per_rep,
    'dt_ps':        timestep_fs * report_every / 1000,
    'n_replicas':   n_reps
}
# Written to <out_root>/metadata.json (out_root is defined in the previous cell).
with open(os.path.join(out_root,'metadata.json'),'w') as f:
    json.dump(meta, f, indent=2)
print("Saved metadata.json")

In [None]:
# Cell 4
# Download the alanine-dipeptide structure into the working directory.
# NOTE(review): wget runs quietly and its exit status is not checked; a failed
# download would only surface when PDBFile tries to parse the file below.
!wget -q https://raw.githubusercontent.com/choderalab/YankTools/master/testsystems/data/alanine-dipeptide-gbsa/alanine-dipeptide.pdb \
    -O alanine_dipeptide.pdb

# Parse the structure and load the force-field definitions listed in metadata.json.
pdb      = PDBFile('alanine_dipeptide.pdb')
ff_impl  = ForceField('amber99sbildn.xml', 'amber99_obc.xml')
modeller = Modeller(pdb.topology, pdb.positions)

# No non-bonded cutoff; bonds involving hydrogen are constrained.
system = ff_impl.createSystem(
    modeller.topology,
    nonbondedMethod=NoCutoff,
    constraints=HBonds
)
# Langevin integrator at 300 K, 1/ps friction, timestep_fs femtoseconds per step.
# The data-generation cell constructs its own integrator for each replica.
integrator = LangevinIntegrator(
    300 * unit.kelvin,
    1.0 / unit.picosecond,
    timestep_fs * unit.femtoseconds
)
print("System + integrator ready")

In [None]:
# Cell 5
import numpy as np
from tqdm.auto import trange
from openmm import unit

# `time` is not imported by any earlier cell, so import it here to keep the cell
# runnable after Restart & Run All.
import time

t_gen_start = time.time()

# Total MD steps per replica (kept as a named value for reference).
total_steps = frames_per_rep * report_every

for rep in range(1, n_reps+1):
    print(f"\n=== Replica {rep}/{n_reps} ===")
    # A fresh integrator and Simulation per replica: an integrator can only be
    # bound to one Context.
    integrator = LangevinIntegrator(
        300*unit.kelvin,
        1.0/unit.picosecond,
        timestep_fs*unit.femtoseconds
    )
    sim = Simulation(modeller.topology, system, integrator)
    sim.context.setPositions(modeller.positions)
    sim.minimizeEnergy()
    sim.context.setVelocitiesToTemperature(300*unit.kelvin)

    rep_dir = os.path.join(out_root, f'rep_{rep:02d}')
    os.makedirs(rep_dir, exist_ok=True)
    dcd_path = os.path.join(rep_dir, 'traj.dcd')

    # The reporter writes a frame whenever the step counter is a multiple of
    # report_every (steps 50, 100, ... for the default settings).
    sim.reporters.append(DCDReporter(dcd_path, report_every))

    forces_list, energies_list = [], []

    # Advance exactly one reporting interval per iteration and read the state
    # right after it, so forces/energies entry i belongs to DCD frame i.
    # (Sampling after a single step, as before, recorded steps 1, 51, ... and
    # therefore paired every label with coordinates 49 steps away.)
    for frame in trange(frames_per_rep, desc=f"Replica {rep}", unit="frame"):
        sim.step(report_every)
        state = sim.context.getState(getForces=True, getEnergy=True)
        U = state.getPotentialEnergy().value_in_unit(unit.kilojoule_per_mole)
        energies_list.append(U)
        f_atoms = state.getForces(asNumpy=True)\
                        .value_in_unit(unit.kilojoule_per_mole/unit.nanometer)
        forces_list.append(f_atoms)

    # Drop the reporter reference so its output file can be released before the
    # trajectory is read back by later cells.
    del sim.reporters[:]

    np.save(os.path.join(rep_dir, 'forces.npy'),   np.stack(forces_list))
    np.save(os.path.join(rep_dir, 'energies.npy'), np.array(energies_list))
    print(f"Replica {rep} done: saved forces.npy({len(forces_list)}) and energies.npy({len(energies_list)})")

t_gen = time.time() - t_gen_start
print(f"Data generation took {t_gen/3600:.2f} hours.")

In [None]:
#Cell 6
import os, glob

# Dataset root (same path string as `out_root` defined in the Drive-mount cell).
root       = '/content/drive/MyDrive/alanine_dipeptide'
# Replica directories, in sorted order; later cells rely on this ordering when
# concatenating per-replica arrays.
rep_dirs   = sorted(glob.glob(os.path.join(root, 'rep_*/')))
# Expected location of the heavy-atom-only reference structure.
# NOTE(review): no cell visible here writes this file — confirm where it comes from.
heavy_only_pdb  = os.path.join(root, 'heavy_only.pdb')
# Directory for derived arrays (mu/std, projected forces, energies).
processed  = os.path.join(root, 'processed')
os.makedirs(processed, exist_ok=True)

print("Found replicas:", rep_dirs)
print("Heavy PDB  =", heavy_only_pdb)
print("Processed dir:", processed)

In [None]:
# Cell 7
import shutil

# Keep a copy of the all-atom structure under the dataset root so that later
# cells can use it as the topology for the saved trajectories.
full_src = 'alanine_dipeptide.pdb'
full_pdb = os.path.join(root, 'full.pdb')

# An existing copy is never overwritten.
if os.path.exists(full_pdb):
    print("Found existing full PDB:", full_pdb)
else:
    shutil.copy(full_src, full_pdb)
    print("Copied full PDB to", full_pdb)

In [None]:
# Cell 8
import numpy as np, mdtraj as md

# No earlier cell writes heavy_only.pdb, so derive it from the all-atom
# structure when it is missing (same heavy-atom selection as the XTC export
# cell: ACE/ALA/NME residues, non-hydrogen atoms). An existing file is reused.
if not os.path.exists(heavy_only_pdb):
    _full_ref = md.load(full_pdb)
    _heavy_sel = [a.index for a in _full_ref.topology.atoms
                  if a.residue.name in ('ACE','ALA','NME')
                  and a.element.symbol != 'H']
    _full_ref.atom_slice(_heavy_sel).save_pdb(heavy_only_pdb)
    print("Created", heavy_only_pdb)

heavy_top = md.load(heavy_only_pdb).topology

def _dihedral_angle(xyz, quad):
    """Signed dihedral angle (radians) of the four atoms `quad` in `xyz`.

    Uses the atan2 construction on consecutive bond vectors; intended to follow
    the same sign convention as mdtraj.compute_dihedrals (verify if the
    convention matters for downstream comparisons).
    """
    p0, p1, p2, p3 = (xyz[a] for a in quad)
    b1 = p1 - p0
    b2 = p2 - p1
    b3 = p3 - p2
    c1 = np.cross(b2, b3)
    c2 = np.cross(b1, b2)
    y = np.dot(b1, c1) * np.linalg.norm(b2)
    x = np.dot(c1, c2)
    return np.arctan2(y, x)


def project_dihedral_forces(positions, atom_forces,
                             phi_idx, psi_idx, omega_idx,
                             delta=1e-6):
    """
    Map atomic forces → feature‐space forces for φ,ψ,ω.

    positions   : (N_atoms,3) array in nm
    atom_forces : (N_atoms,3) array in kJ/mol/nm, same atom ordering as positions
    phi_idx,psi_idx,omega_idx : sequences of 4 atom indices into `positions`
    delta       : central finite-difference step (nm)
    returns     : (6,) numpy array [dU/dsinφ,dU/dcosφ, …] (float64)

    For each dihedral θ the scalar  -Σ F·∂θ/∂x  is formed from a central finite
    difference of θ over the twelve coordinates of its four atoms, then mapped
    onto (sin θ, cos θ) via ∂θ/∂sin = cos θ and ∂θ/∂cos = -sin θ.

    Implementation notes:
      * Angles are evaluated directly with NumPy in float64. The previous
        version rebuilt an mdtraj Trajectory for every perturbation using the
        heavy-atom topology, which does not match all-atom coordinates, and
        differenced float32 coordinates with a 1e-6 step.
      * The forward/backward angle difference is wrapped into (-π, π] so a
        dihedral sitting at ±π does not produce a spurious 2π/(2·delta) jump.
    """
    pos = np.asarray(positions, dtype=np.float64)
    frc = np.asarray(atom_forces, dtype=np.float64)
    work = pos.copy()          # scratch copy, restored after every perturbation
    two_pi = 2.0 * np.pi

    feats = []
    for quad in (phi_idx, psi_idx, omega_idx):
        quad = [int(a) for a in quad]
        theta0 = _dihedral_angle(pos, quad)

        torque = 0.0
        for a in quad:
            for ax in range(3):
                work[a, ax] = pos[a, ax] + delta
                th_f = _dihedral_angle(work, quad)
                work[a, ax] = pos[a, ax] - delta
                th_b = _dihedral_angle(work, quad)
                work[a, ax] = pos[a, ax]

                # Shortest signed angular difference.
                d_theta = (th_f - th_b + np.pi) % two_pi - np.pi
                torque -= frc[a, ax] * d_theta / (2.0 * delta)

        feats += [torque * np.cos(theta0), -torque * np.sin(theta0)]

    return np.array(feats)

In [None]:
# Cell 9
import mdtraj as md

# All-atom structure used as the topology when reading each replica's DCD.
topology = os.path.join(root, 'full.pdb')

# For every replica, write a heavy-atom-only trajectory aligned to its first
# frame. Replicas that already have the output file are left untouched.
for rep in rep_dirs:
    dcd_in  = os.path.join(rep, 'traj.dcd')
    xtc_out = os.path.join(rep, 'heavy_only.xtc')

    if not os.path.exists(xtc_out):
        traj = md.load(dcd_in, top=topology)
        # Non-hydrogen atoms of the three peptide residues.
        keep = [
            atom.index
            for atom in traj.topology.atoms
            if atom.residue.name in ('ACE','ALA','NME') and atom.element.symbol != 'H'
        ]
        heavy = traj.atom_slice(keep)
        heavy = heavy.superpose(heavy, frame=0)
        heavy.save_xtc(xtc_out)
        print(f"Created {xtc_out}")
    else:
        print(f"Already have {xtc_out}")

print("All heavy-only XTCs ready.")

In [None]:
# Cell 10
import numpy as np, mdtraj as md
from tqdm.auto import tqdm

topology = full_pdb

ref = md.load(topology)
top = ref.topology

# Backbone atoms that define the three dihedrals, looked up once in the
# all-atom topology (first match of each selection).
_ace_c  = top.select('resid 0 and name C')[0]
_ala_n  = top.select('resid 1 and name N')[0]
_ala_ca = top.select('resid 1 and name CA')[0]
_ala_c  = top.select('resid 1 and name C')[0]
_nme_n  = top.select('resid 2 and name N')[0]
_nme_c  = top.select('resid 2 and name C')[0]

phi_idx   = [_ace_c,  _ala_n,  _ala_ca, _ala_c]
psi_idx   = [_ala_n,  _ala_ca, _ala_c,  _nme_n]
omega_idx = [_ala_ca, _ala_c,  _nme_n,  _nme_c]

# Per replica: compute φ, ψ, ω for every frame and store the six-column
# feature matrix [sinφ, cosφ, sinψ, cosψ, sinω, cosω].
for rep in rep_dirs:
    dcd_in = os.path.join(rep, 'traj.dcd')
    out_np = os.path.join(rep, 'internal_full.npy')

    traj = md.load(dcd_in, top=topology)

    phi   = md.compute_dihedrals(traj, [phi_idx])[:,0]
    psi   = md.compute_dihedrals(traj, [psi_idx])[:,0]
    omega = md.compute_dihedrals(traj, [omega_idx])[:,0]

    columns = []
    for angle in (phi, psi, omega):
        columns.append(np.sin(angle))
        columns.append(np.cos(angle))
    feats = np.column_stack(columns)

    np.save(out_np, feats)
    print(f"Saved {os.path.basename(out_np)} → {feats.shape[0]} frames × 6 dims")

In [None]:
# Cell 11
import numpy as np

# Load every replica's feature matrix and pool them along the frame axis.
all_feats = [np.load(os.path.join(rep, 'internal_full.npy')) for rep in rep_dirs]
stacked = np.concatenate(all_feats, axis=0)

# Per-column mean and standard deviation over all pooled frames.
mu  = stacked.mean(axis=0)
std = stacked.std(axis=0)

for _fname, _arr in (('mu.npy', mu), ('std.npy', std)):
    np.save(os.path.join(processed, _fname), _arr)
print(f"Saved mu/std ({mu.shape}) to {processed}")

In [None]:
# Cell 12
import os, numpy as np, torch

# `time` is not imported by any earlier cell, so import it here to keep the cell
# runnable after Restart & Run All.
import time

t_projection_start = time.time()
forces_list, energies_list = [], []

# Replicas are visited in `rep_dirs` order; the dataset cell relies on the
# labels being concatenated in that same order.
for rep in rep_dirs:
    F_atoms = np.load(os.path.join(rep, 'forces.npy'))
    E       = np.load(os.path.join(rep, 'energies.npy'))

    dcd = os.path.join(rep, 'traj.dcd')
    traj = md.load(dcd, top=full_pdb)

    # Labels are paired with coordinates purely by position, so refuse to
    # continue if the three per-replica arrays disagree in length.
    if not (len(E) == len(F_atoms) == traj.n_frames):
        raise ValueError(
            f"Frame count mismatch in {rep}: energies={len(E)}, "
            f"forces={len(F_atoms)}, trajectory={traj.n_frames}"
        )

    for i in range(len(E)):
        energies_list.append(E[i])
        # Project the all-atom forces of frame i onto the six sin/cos features.
        featF = project_dihedral_forces(
            positions=traj.xyz[i],
            atom_forces=F_atoms[i],
            phi_idx=phi_idx,
            psi_idx=psi_idx,
            omega_idx=omega_idx
        )
        forces_list.append(featF)

forces_arr   = torch.tensor(np.stack(forces_list),   dtype=torch.float32)
energies_arr = torch.tensor(np.array(energies_list), dtype=torch.float32)

os.makedirs(processed, exist_ok=True)
torch.save(forces_arr,   os.path.join(processed, 'forces.pt'))
torch.save(energies_arr, os.path.join(processed, 'energies.pt'))
print("Saved forces.pt", forces_arr.shape,
      "and energies.pt", energies_arr.shape)
t_proj = time.time() - t_projection_start
print(f"Force projection took {t_proj/3600:.2f} hours.")

In [None]:
# Cell 13
import os, glob, numpy as np, torch
from torch_geometric.data import Data

raw_dir    = root
proc_dir   = processed

# Projected forces (one 6-vector per frame) and potential energies, stored in
# replica order by the force-projection cell.
F_all = torch.load(os.path.join(proc_dir, 'forces.pt'))
U_all = torch.load(os.path.join(proc_dir, 'energies.pt'))

feat_files = sorted(glob.glob(os.path.join(raw_dir, 'rep_*', 'internal_full.npy')))

mu  = np.load(os.path.join(proc_dir, 'mu.npy'))
std = np.load(os.path.join(proc_dir, 'std.npy'))

# Fully connected directed graph on the three dihedral nodes (no self loops).
edge_index = torch.tensor(
    [[i, j] for i in range(3) for j in range(3) if i != j],
    dtype=torch.long
).t()

data_list = []
idx = 0   # running index into F_all / U_all across all replicas
for fpath in feat_files:
    feats = np.load(fpath)[:frames_per_rep]
    print(f"→ {os.path.basename(os.path.dirname(fpath))}: {feats.shape[0]} frames")
    for frame in feats:
        # Standardise, then view the six features as 3 nodes × [sin, cos].
        norm = (frame - mu) / std
        x = torch.tensor(norm.reshape(3, 2), dtype=torch.float32)

        if idx >= len(F_all) or idx >= len(U_all):
            raise ValueError(
                f"Ran out of labels at frame {idx}: forces={len(F_all)}, energies={len(U_all)}"
            )
        y_force  = F_all[idx]
        y_energy = U_all[idx]

        data = Data(
            x=x,
            edge_index=edge_index,
            y_force=y_force,
            y_energy=y_energy
        )
        data_list.append(data)
        idx += 1

# Features and labels are matched by running index only. If any replica was
# truncated by `frames_per_rep` (or labels are missing), every later frame
# would silently be paired with the wrong label — fail loudly instead.
if idx != len(F_all) or idx != len(U_all):
    raise ValueError(
        f"Feature/label count mismatch: {idx} feature frames vs "
        f"{len(F_all)} force rows and {len(U_all)} energies"
    )

print("Built dataset of", len(data_list), "graphs")

In [None]:
# Cell 14
import random
from torch_geometric.loader import DataLoader

# Seed for the train/validation split so the split is reproducible across
# kernel restarts. A private Random instance leaves the global RNG untouched.
SPLIT_SEED = 0

shuffled   = data_list.copy()
random.Random(SPLIT_SEED).shuffle(shuffled)
# 90 % train / 10 % validation.
# NOTE(review): frames come from MD trajectories, so a frame-level random split
# puts neighbouring (correlated) frames on both sides — confirm this is intended.
split      = int(0.9 * len(shuffled))
train_list = shuffled[:split]
val_list   = shuffled[split:]

batch_size    = 16
# Reshuffle the training set every epoch; validation order stays fixed.
train_loader  = DataLoader(train_list, batch_size=batch_size, shuffle=True,  num_workers=0)
val_loader    = DataLoader(val_list,   batch_size=batch_size, shuffle=False, num_workers=0)

print(f"Train graphs: {len(train_list)}, Val graphs: {len(val_list)}")

In [None]:
# Cell 15
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import TransformerConv, global_mean_pool
import math

def sinusoidal_time_embeddings(t, dim):
    """Sinusoidal embedding of integer timesteps.

    t   : (N,) tensor of timesteps
    dim : embedding width
    returns (N, dim) float tensor laid out as [sin(t·f_0..f_{h-1}), cos(...)]
    with h = dim // 2 and frequencies spaced geometrically from 1 down to 1e-4;
    one zero column is appended when `dim` is odd.

    The frequency exponent is divided by max(h - 1, 1): for dim >= 4 this is
    the original h - 1, while for dim in {2, 3} it avoids the 0/0 that used to
    turn the whole embedding into NaN.
    """
    half = dim // 2
    denom = max(half - 1, 1)
    freqs = torch.exp(-math.log(1e4) * torch.arange(half, device=t.device) / denom)
    args  = t[:, None].float() * freqs[None]
    emb   = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
    return F.pad(emb, (0,1)) if dim % 2 else emb

class GraphDiffusionTransformer(nn.Module):
    """Graph transformer with three output heads: noise, force and energy.

    Node features are projected to `hidden_dim`, summed with an embedding of the
    per-node timestep, passed through `layers` residual TransformerConv blocks
    (each followed by LayerNorm), and decoded by three linear heads.
    """

    def __init__(self, node_dim, time_dim, hidden_dim, layers, heads, dropout, timesteps):
        """
        node_dim   : width of the input node features (also of noise/force outputs)
        time_dim   : width of the sinusoidal timestep embedding
        hidden_dim : hidden width of every layer
        layers     : number of TransformerConv blocks
        heads      : attention heads per block (outputs averaged, concat=False)
        dropout    : attention dropout passed to TransformerConv
        timesteps  : number of diffusion steps (stored; not used inside the model)
        """
        super().__init__()
        self.timesteps = timesteps
        self.time_mlp = nn.Sequential(
            nn.Linear(time_dim, hidden_dim), nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.x_proj = nn.Linear(node_dim, hidden_dim)
        self.convs  = nn.ModuleList([
            TransformerConv(hidden_dim, hidden_dim, heads=heads, concat=False, dropout=dropout, beta=True)
            for _ in range(layers)
        ])
        self.norms = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in range(layers)])
        self.out_noise  = nn.Linear(hidden_dim, node_dim)
        self.out_force  = nn.Linear(hidden_dim, node_dim)
        self.out_energy = nn.Linear(hidden_dim, 1)

    def forward(self, x, edge_index, batch, t_node):
        """Standard nn.Module entry point; delegates to `_forward`.

        Without it `model(...)` raises, because nn.Module's default `forward`
        is unimplemented. Existing callers of `_forward` are unaffected.
        """
        return self._forward(x, edge_index, batch, t_node)

    def _forward(self, x, edge_index, batch, t_node):
        """
        x:       (total_nodes, node_dim)
        edge_index: graph connectivity
        batch:   (total_nodes,) mapping nodes → graph ids
        t_node:  (total_nodes,) noise‐level per node

        Returns (noise, force, energy): the first two are per node with
        `node_dim` columns, the third is one value per graph (mean over nodes).
        """
        # The timestep embedding width is read back from the first time-MLP layer.
        h = self.x_proj(x) + self.time_mlp(sinusoidal_time_embeddings(t_node, self.time_mlp[0].in_features))
        for conv, norm in zip(self.convs, self.norms):
            h2 = conv(h, edge_index)
            h  = norm(h + h2)          # residual connection, then LayerNorm

        noise = self.out_noise(h)
        force = self.out_force(h)

        energy_nodes = self.out_energy(h)
        energy = global_mean_pool(energy_nodes, batch).view(-1)
        return noise, force, energy

class GaussianDiffusion(nn.Module):
    """Forward-noising process plus the three-term training loss.

    Attributes set by the caller after construction:
      λ_f, λ_E : weights of the force and energy loss terms
      F_MM     : (rows, node_dim) force table used by the legacy lookup
      U_MM     : (rows,) energy table used by the legacy lookup
    """

    # Defaults so that `p_losses` is usable before the caller overrides them.
    λ_f = 1.0
    λ_E = 1.0
    F_MM = None
    U_MM = None

    def __init__(self, betas):
        """betas: (T,) tensor with the per-step noise variances."""
        super().__init__()
        alphas = 1 - betas
        a_cum  = torch.cumprod(alphas, dim=0)
        # sqrt(ᾱ_t) and sqrt(1 - ᾱ_t), indexed by timestep.
        self.register_buffer('sqrt_a_cum',   torch.sqrt(a_cum))
        self.register_buffer('sqrt_1ma_cum', torch.sqrt(1 - a_cum))

    def q_sample(self, x, t_node, noise=None):
        """Noise `x` to the per-row timestep `t_node`.

        Returns sqrt(ᾱ_t)·x + sqrt(1-ᾱ_t)·noise; fresh Gaussian noise is drawn
        when `noise` is None.
        """
        if noise is None:
            noise = torch.randn_like(x)
        a = self.sqrt_a_cum[t_node].view(-1,1)
        m = self.sqrt_1ma_cum[t_node].view(-1,1)
        return a * x + m * noise

    def p_losses(self, model, x_start, edge_index, batch, t,
                 force_target=None, energy_target=None):
        """
        x_start:   (total_nodes, node_dim)
        edge_index, batch: from Data.batch
        t:         (batch_size,) random timesteps per graph
        force_target:  optional (total_nodes, node_dim) targets for this batch
        energy_target: optional (batch_size,) targets for this batch

        Returns (L_noise, L_force, L_energy); the last two are already scaled
        by λ_f and λ_E.

        When the targets are omitted, the legacy lookup `F_MM[batch]` /
        `U_MM[batch.unique()]` is used. NOTE: `batch` holds graph ids *within
        the mini-batch* (0..B-1), so that lookup always reads the first rows of
        the global tables rather than the labels of the graphs in the batch.
        Pass the per-batch targets explicitly to train on the correct labels.
        """
        noise   = torch.randn_like(x_start)
        t_node  = t[batch]
        x_noisy = self.q_sample(x_start, t_node, noise)
        noise_pred, force_pred, energy_pred = model._forward(x_noisy, edge_index, batch, t_node)

        if force_target is None:
            if self.F_MM is None:
                raise ValueError("p_losses needs `force_target` or a populated `F_MM` table")
            force_target = self.F_MM[batch]
        if energy_target is None:
            if self.U_MM is None:
                raise ValueError("p_losses needs `energy_target` or a populated `U_MM` table")
            energy_target = self.U_MM[batch.unique()]

        L_noise  = F.mse_loss(noise_pred, noise)
        L_force  = self.λ_f * F.mse_loss(force_pred, force_target)
        L_energy = self.λ_E * F.mse_loss(energy_pred, energy_target)
        return L_noise, L_force, L_energy

In [None]:
# Cell 16
from torch.optim import Adam

# Diffusion schedule, model size and optimisation hyperparameters.
# NOTE(review): the schedule cell defines T / beta_min / beta_max with the same
# values as timesteps / beta_start / beta_end; keep the two in sync.
timesteps  = 1000
beta_start = 1e-4
beta_end   = 0.02
time_dim   = 128
hidden_dim = 128
layers     = 6
heads      = 4
dropout    = 0.1
lr         = 3e-4
epochs     = 50
save_every = 50



# Prefer the GPU when PyTorch can see one.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

# node_dim is read from the dataset: the number of columns of a graph's x.
# The training cell constructs the model again before fitting.
model = GraphDiffusionTransformer(
    node_dim=train_list[0].x.size(1),
    time_dim=time_dim,
    hidden_dim=hidden_dim,
    layers=layers,
    heads=heads,
    dropout=dropout,
    timesteps=timesteps
).to(device)

print(model)

In [None]:
# Cell 17
# Linear beta schedule with T steps, created directly on `device`.
T = 1000
beta_min, beta_max = 1e-4, 2e-2
betas  = torch.linspace(beta_min, beta_max, T, device=device)
# alphas = 1 - beta and their cumulative product (ᾱ_t); `a_bar` is read by the
# sampling code in a later cell.
alphas = 1.0 - betas
a_bar  = torch.cumprod(alphas, dim=0)

diffusion = GaussianDiffusion(betas).to(device)

In [None]:
# Cell 18
import os
from tqdm.auto import tqdm
import time
import torch
import matplotlib.pyplot as plt

t_train_start = time.time()

# Per-epoch loss histories (train / validation, per term and total).
train_noise_losses = []
train_force_losses = []
train_energy_losses = []
train_total_losses = []
valid_noise_losses = []
valid_force_losses = []
valid_energy_losses = []
valid_total_losses = []

# Fresh model and diffusion helper for this training run.
model = GraphDiffusionTransformer(
    node_dim=2, time_dim=time_dim, hidden_dim=hidden_dim,
    layers=layers, heads=heads, dropout=dropout, timesteps=timesteps
).to(device)
diffusion = GaussianDiffusion(betas.to(device)).to(device)
diffusion.λ_f = 1.0
diffusion.λ_E = 0.1

raw_forces   = torch.load(os.path.join(proc_dir, 'forces.pt')).to(device)
raw_energies = torch.load(os.path.join(proc_dir, 'energies.pt')).to(device)

# Label statistics over the whole dataset: forces viewed as (node, [sin, cos])
# rows, energies as scalars.
forces_nodes = raw_forces.view(-1, 3, 2).reshape(-1, 2)
F_mean, F_std = forces_nodes.mean(0, keepdim=True), forces_nodes.std(0, keepdim=True)
forces_nodes  = (forces_nodes - F_mean) / F_std

U_mean, U_std    = raw_energies.mean(), raw_energies.std()
energies_norm    = (raw_energies - U_mean) / U_std

# Kept on the diffusion object for code that still reads the global tables.
diffusion.F_MM = forces_nodes
diffusion.U_MM = energies_norm


def _batch_losses(graph_batch):
    """Noise, force and energy losses for one mini-batch (already on `device`).

    Targets are taken from the labels stored on the graphs of *this* batch
    (`y_force`, `y_energy`) and normalised with the dataset statistics above.
    Looking targets up as F_MM[batch] / U_MM[batch.unique()] is wrong here:
    `batch` contains within-batch graph ids (0..B-1), so that lookup returns the
    first rows of the global tables regardless of which frames were drawn.
    """
    t      = torch.randint(0, timesteps, (graph_batch.num_graphs,), device=device)
    t_node = t[graph_batch.batch]
    noise  = torch.randn_like(graph_batch.x)
    x_noisy = diffusion.q_sample(graph_batch.x, t_node, noise)
    noise_pred, force_pred, energy_pred = model._forward(
        x_noisy, graph_batch.edge_index, graph_batch.batch, t_node
    )

    # y_force is stored as a 6-vector per graph and concatenated by the loader;
    # reshape to one [d/dsin, d/dcos] row per node, matching x's node order.
    force_tgt  = (graph_batch.y_force.reshape(-1, 2) - F_mean) / F_std
    energy_tgt = (graph_batch.y_energy.reshape(-1) - U_mean) / U_std

    mse = torch.nn.functional.mse_loss
    return (
        mse(noise_pred, noise),
        diffusion.λ_f * mse(force_pred, force_tgt),
        diffusion.λ_E * mse(energy_pred, energy_tgt),
    )


ckpt_dir = os.path.join(out_root, 'checkpoints')
os.makedirs(ckpt_dir, exist_ok=True)
print(f"ckpt_dir = {ckpt_dir}")

opt = torch.optim.Adam(model.parameters(), lr=lr * 0.1)

best_val_loss     = float('inf')
epochs_no_improve = 0
patience          = 5
save_best_path    = os.path.join(ckpt_dir, 'model_best.pt')
print(f"Early stopping patience = {patience} epochs; best model → {save_best_path}")

for epoch in range(1, epochs+1):
    # ---- training pass ----
    sum_n = sum_f = sum_e = 0.0
    model.train()
    for batch in tqdm(train_loader, desc=f"Train {epoch}/{epochs}", leave=False):
        batch = batch.to(device)
        L_n, L_f, L_e = _batch_losses(batch)
        loss = L_n + L_f + L_e
        opt.zero_grad()
        loss.backward()
        opt.step()

        # Weight by batch size so the epoch average is per graph.
        sum_n += L_n.item() * batch.num_graphs
        sum_f += L_f.item() * batch.num_graphs
        sum_e += L_e.item() * batch.num_graphs

    avg_n = sum_n / len(train_list)
    avg_f = sum_f / len(train_list)
    avg_e = sum_e / len(train_list)
    train_total = avg_n + avg_f + avg_e
    print(f"Epoch {epoch:3d} — "
          f"L_noise: {avg_n:.5e}, L_force: {avg_f:.5e}, L_energy: {avg_e:.5e}, "
          f"total: {train_total:.5e}")
    train_noise_losses.append(avg_n)
    train_force_losses.append(avg_f)
    train_energy_losses.append(avg_e)
    train_total_losses.append(train_total)

    # ---- validation pass ----
    model.eval()
    val_n = val_f = val_e = 0.0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Valid", leave=False):
            batch = batch.to(device)
            L_n, L_f, L_e = _batch_losses(batch)
            val_n += L_n.item() * batch.num_graphs
            val_f += L_f.item() * batch.num_graphs
            val_e += L_e.item() * batch.num_graphs

    avg_vn = val_n / len(val_list)
    avg_vf = val_f / len(val_list)
    avg_ve = val_e / len(val_list)
    val_total = avg_vn + avg_vf + avg_ve
    print(f" Valid - L_noise: {avg_vn:.5e}, L_force: {avg_vf:.5e}, L_energy: {avg_ve:.5e}, "
          f"total: {val_total:.5e}")
    valid_noise_losses.append(avg_vn)
    valid_force_losses.append(avg_vf)
    valid_energy_losses.append(avg_ve)
    valid_total_losses.append(val_total)

    # ---- early stopping on the total validation loss ----
    if val_total < best_val_loss:
        best_val_loss     = val_total
        epochs_no_improve = 0
        torch.save(model.state_dict(), save_best_path)
        print(f"New best model (epoch {epoch}) saved")
    else:
        epochs_no_improve += 1
        print(f"No improvement for {epochs_no_improve}/{patience} epochs")
        if epochs_no_improve >= patience:
            print(f"Early stopping at epoch {epoch}; best_val_loss = {best_val_loss:.5e}")
            break

    if epoch % save_every == 0:
        ck = os.path.join(ckpt_dir, f'model_epoch{epoch:03d}.pt')
        torch.save({'epoch': epoch,
                    'model': model.state_dict(),
                    'opt': opt.state_dict()}, ck)
        print("Saved", ck)

# Restore the best weights only if this run actually produced a checkpoint
# (e.g. not when every validation loss was NaN); otherwise a stale file from an
# earlier run would be loaded, or the load would fail.
if best_val_loss < float('inf'):
    model.load_state_dict(torch.load(save_best_path))
    print(f"Loaded best model from {save_best_path}")
else:
    print("No best checkpoint was written in this run; keeping the last weights")

torch.save(model.state_dict(), os.path.join(ckpt_dir, 'model_final.pt'))

print("Training complete")
t_train = time.time() - t_train_start
print(f"Training took {t_train/3600:.2f} hours.")

# x-axis for the curves. Stored under its own name so the `epochs`
# hyperparameter is not overwritten (re-running this cell would otherwise fail
# on `range(1, epochs+1)`).
epoch_axis = range(1, len(train_noise_losses) + 1)

for _label, _train_curve, _valid_curve in (
    ('Noise',  train_noise_losses,  valid_noise_losses),
    ('Force',  train_force_losses,  valid_force_losses),
    ('Energy', train_energy_losses, valid_energy_losses),
    ('Total',  train_total_losses,  valid_total_losses),
):
    plt.figure()
    plt.plot(epoch_axis, _train_curve, label='Train')
    plt.plot(epoch_axis, _valid_curve, label='Valid')
    plt.xlabel('Epoch'); plt.ylabel(f'{_label} loss')
    plt.title(f'{_label} Loss vs. Epoch')
    plt.legend(); plt.tight_layout(); plt.show()

In [None]:
# Cell 19
import os

import shutil

output_dir    = '/content/drive/MyDrive/alanine_dipeptide/output'
raw_dir       = '/content/drive/MyDrive/alanine_dipeptide/raw'
os.makedirs(output_dir, exist_ok=True)
os.makedirs(raw_dir, exist_ok=True)

# Later cells read <raw_dir>/full.pdb and <raw_dir>/heavy_only.pdb, but the
# cells visible in this notebook only place those files in the parent dataset
# directory. Seed raw_dir from there when a file is missing; existing files are
# never overwritten and absent sources are skipped.
_dataset_root = os.path.dirname(raw_dir)
for _name in ('full.pdb', 'heavy_only.pdb'):
    _src = os.path.join(_dataset_root, _name)
    _dst = os.path.join(raw_dir, _name)
    if os.path.exists(_src) and not os.path.exists(_dst):
        shutil.copy(_src, _dst)
        print("Copied", _src, "→", _dst)

# NOTE: from here on `raw_dir` and `ckpt_dir` point somewhere else than during
# dataset building / training (which used the dataset root and
# <out_root>/checkpoints). Re-running earlier cells after this one picks up the
# new values.
processed_dir = output_dir
ckpt_dir      = os.path.join(output_dir, 'checkpoints')
os.makedirs(ckpt_dir, exist_ok=True)

In [None]:
# Cell 20
import torch
from openmm.unit import MOLAR_GAS_CONSTANT_R, kilojoule_per_mole, kelvin
from openmm import app

# Move the trained model and the shared graph connectivity to the active device.
device     = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model      = model.to(device)
edge_index = edge_index.to(device)

# Recompute the cumulative alpha product from the existing `betas` schedule.
timesteps = 1000
alphas    = 1 - betas
a_cumprod = torch.cumprod(alphas, dim=0)

# Noise level used for proposals, and the matching forward-process
# coefficients sqrt(ᾱ_T0) and sqrt(1 - ᾱ_T0).
T0       = 150
sqrt_ab  = a_cumprod[T0].sqrt()
sqrt_1mab= (1 - a_cumprod[T0]).sqrt()

max_tries = 20
# kT at 300 K, converted to a plain float in kJ/mol.
kT_q = MOLAR_GAS_CONSTANT_R * 300 * kelvin
kT   = kT_q.value_in_unit(kilojoule_per_mole)
# Mixing weight between the current state and the denoised mean in mh_jump.
lam       = 0.15

# NOTE(review): this expects full.pdb inside raw_dir; the copy cell earlier in
# the notebook writes it to the dataset root — confirm the file exists here.
pdb_path = os.path.join(raw_dir, 'full.pdb')

pdb_full = PDBFile(pdb_path)

# Indices of non-hydrogen atoms outside water residues, in all-atom numbering.
heavy_indices = [
    atom.index for atom in pdb_full.topology.atoms()
    if atom.element.symbol!='H' and atom.residue.name not in ('HOH','WAT')
]

In [None]:
# Cell 21
import sys, os, math, numpy as np, torch, mdtraj as md
from openmm.app import (
    PDBFile, Modeller, ForceField, NoCutoff, HBonds, Simulation
)
from openmm import (
    CustomTorsionForce, LocalEnergyMinimizer,
    LangevinIntegrator, unit
)

raw_dir      = '/content/drive/MyDrive/alanine_dipeptide/raw'
# NOTE(review): the files are listed in the opposite order to the other
# ForceField(...) calls in this notebook — confirm the order is irrelevant for
# the installed OpenMM version.
gbsa_ff      = ForceField('amber99_obc.xml', 'amber99sbildn.xml')

# Heavy (non-hydrogen) atoms of the three peptide residues, indexed in the
# topology of `pdb_full`.
peptide_residue_names = ['ACE', 'ALA', 'NME']
heavy_to_full = [
    a.index for r in pdb_full.topology.residues()
    if r.name in peptide_residue_names
    for a in r.atoms() if a.element.symbol != 'H'
]
# Reference coordinates of `pdb_full` as plain values in nanometres.
full_np = np.array([v.value_in_unit(unit.nanometer) for v in pdb_full.positions])

# Template structure used by the reconstruction/energy helpers below.
implicit_modeller = Modeller(pdb_full.topology, pdb_full.positions)
implicit_modeller.addHydrogens(gbsa_ff)

# Same heavy-atom selection, but indexed in the modeller's topology.
peptide_heavy_idx_mod = [
    a.index for r in implicit_modeller.topology.residues()
    if r.name in peptide_residue_names
    for a in r.atoms() if a.element.symbol != 'H'
]

heavy_to_full_mod = peptide_heavy_idx_mod.copy()

def reconstruct_frame_with_openmm(phi, psi, omega, implicit_modeller):
    """
    Rebuild the peptide in implicit solvent with stiff φ/ψ/ω restraints.
    Returns (coords_nm, phi_atoms, psi_atoms, omega_atoms) or Nones on error.

    phi, psi, omega   : target dihedral values passed as `theta0` of the restraints
    implicit_modeller : Modeller supplying the topology and starting positions

    A new System, restraint force and Simulation are created on every call; the
    structure is then energy-minimised (at most 500 iterations) starting from
    the modeller's positions. Any exception is printed and reported through
    sys.excepthook, and the function returns four Nones instead of raising.
    """
    try:
        m = implicit_modeller

        system = gbsa_ff.createSystem(
            m.topology, nonbondedMethod=NoCutoff, constraints=HBonds
        )

        # Harmonic restraint on the periodic distance between theta and theta0.
        ct = CustomTorsionForce(
            "k*min((theta-theta0)^2,(2*pi-abs(theta-theta0))^2)"
        )
        # Per-torsion parameters are registered in the order (theta0, k); the
        # value lists passed to addTorsion below follow the same order.
        ct.addPerTorsionParameter("theta0")
        ct.addGlobalParameter("pi", np.pi)
        ct.addPerTorsionParameter("k")
        k_val = 100.0 * unit.kilojoule_per_mole / unit.radian**2

        # Look up the backbone atoms by residue and atom name (first match).
        top = m.topology
        C_ACE = next(a.index for r in top.residues() if r.name=='ACE' for a in r.atoms() if a.name=='C')
        N_ALA = next(a.index for r in top.residues() if r.name=='ALA' for a in r.atoms() if a.name=='N')
        CA_ALA= next(a.index for r in top.residues() if r.name=='ALA' for a in r.atoms() if a.name=='CA')
        C_ALA = next(a.index for r in top.residues() if r.name=='ALA' for a in r.atoms() if a.name=='C')
        N_NME = next(a.index for r in top.residues() if r.name=='NME' for a in r.atoms() if a.name=='N')
        C_NME = next(a.index for r in top.residues() if r.name=='NME' for a in r.atoms() if a.name=='C')

        phi_atoms   = [C_ACE, N_ALA, CA_ALA, C_ALA]
        psi_atoms   = [N_ALA, CA_ALA, C_ALA, N_NME]
        omega_atoms = [CA_ALA, C_ALA, N_NME, C_NME]

        ct.addTorsion(*phi_atoms,   [phi,   k_val])
        ct.addTorsion(*psi_atoms,   [psi,   k_val])
        ct.addTorsion(*omega_atoms, [omega, k_val])
        system.addForce(ct)

        # The integrator is only needed to build a Simulation; no dynamics steps
        # are taken here, only minimisation.
        sim = Simulation(top, system, LangevinIntegrator(300*unit.kelvin,1/unit.picosecond,0.002*unit.picosecond))
        sim.context.setPositions(m.positions)
        sim.minimizeEnergy(maxIterations=500)

        coords = sim.context.getState(getPositions=True)\
                    .getPositions(asNumpy=True).value_in_unit(unit.nanometer)
        return coords, phi_atoms, psi_atoms, omega_atoms

    except Exception as e:
        print("Error in reconstruct_frame_with_openmm:", e)
        sys.excepthook(type(e), e, e.__traceback__)
        return None, None, None, None

def reconstruct_heavy(feat: np.ndarray) -> np.ndarray:
    """Heavy-atom coordinates (nm) for a 6-vector of dihedral features.

    feat : [sinφ, cosφ, sinψ, cosψ, sinω, cosω]; each angle is recovered with
           atan2(sin, cos), so the pairs need not be exactly normalised.
    Returns the rows of the reconstructed structure selected by
    `peptide_heavy_idx_mod`.

    Raises RuntimeError when the OpenMM reconstruction fails. That helper
    reports its own error and returns None, which previously surfaced here as
    an opaque "'NoneType' object is not subscriptable".
    """
    phi, psi, omg = (math.atan2(feat[0], feat[1]),
                     math.atan2(feat[2], feat[3]),
                     math.atan2(feat[4], feat[5]))
    coords, *_ = reconstruct_frame_with_openmm(phi, psi, omg, implicit_modeller)
    if coords is None:
        raise RuntimeError(
            f"OpenMM reconstruction failed for phi={phi:.4f}, psi={psi:.4f}, omega={omg:.4f}"
        )
    return coords[peptide_heavy_idx_mod]

def energy_from_heavy(heavy_xyz: np.ndarray, tol: float = 10.0) -> float:
    """
    heavy_xyz : (N_heavy,3) nm  — coordinates of the peptide heavy atoms
    Returns    : potential energy in kJ/mol

    The heavy-atom rows of the template structure (`implicit_modeller`) are
    overwritten with `heavy_xyz`; all other atoms keep their template positions.
    A new System and Simulation are built on every call and the energy is
    evaluated without minimisation or dynamics.

    `tol` is accepted but not used by this function.

    NOTE(review): hydrogens stay at the template positions while the heavy atoms
    move, so the evaluated geometry may be inconsistent — confirm this is the
    intended energy.
    """
    # Work on a fresh Modeller so the shared template is not modified.
    m = Modeller(implicit_modeller.topology, implicit_modeller.positions)

    pos_arr = np.asarray(m.positions.value_in_unit(unit.nanometer))

    pos_arr[np.asarray(peptide_heavy_idx_mod), :] = heavy_xyz

    m.positions = pos_arr * unit.nanometer

    system = gbsa_ff.createSystem(
        m.topology, nonbondedMethod=NoCutoff, constraints=HBonds
    )
    # The integrator is only required to construct the Simulation.
    sim = Simulation(
        m.topology, system,
        LangevinIntegrator(300*unit.kelvin, 1/unit.picosecond, 0.002*unit.picosecond)
    )
    sim.context.setPositions(m.positions)

    return sim.context.getState(getEnergy=True)\
              .getPotentialEnergy()\
              .value_in_unit(unit.kilojoule_per_mole)

def compute_openmm_forces(x_feat, delta=1e-3):
    """
    Finite-difference gradient of the OpenMM energy in sin/cos feature space.

    x_feat : torch.Tensor of shape (3,2), one [sin,cos] row per dihedral (φ,ψ,ω).
    delta  : half-width of the central difference applied to each angle (rad).
    Returns a float32 torch.Tensor of shape (3,2) on `device` holding
    ∂U/∂[sin,cos] for each dihedral.

    Each angle is shifted by ±delta in turn, the structure is rebuilt and its
    energy evaluated, and the resulting dU/dθ is chained through θ = atan2(s, c).
    """
    sincos = x_feat.cpu().numpy()
    thetas = np.arctan2(sincos[:, 0], sincos[:, 1])  # [φ,ψ,ω]

    def _as_features(angle_vec):
        # (3,) angles → (3,2) rows of [sin, cos]
        return np.stack([np.sin(angle_vec), np.cos(angle_vec)], axis=1)

    def _energy_at(angle_vec):
        # Rebuild heavy atoms for these angles and evaluate the potential energy.
        return energy_from_heavy(reconstruct_heavy(_as_features(angle_vec).flatten()))

    rows = []
    for k in range(len(thetas)):
        plus  = thetas.copy()
        minus = thetas.copy()
        plus[k]  = thetas[k] + delta
        minus[k] = thetas[k] - delta

        # Forward energy is evaluated before the backward one.
        slope = (_energy_at(plus) - _energy_at(minus)) / (2*delta)

        # Chain rule for θ = atan2(s, c); the input need not lie on the unit circle.
        s, c = sincos[k]
        r2 = s*s + c*c
        rows.append([slope * (c/r2), slope * (-s/r2)])

    return torch.tensor(rows, device=device, dtype=torch.float32)

In [None]:
# Cell 22
import numpy as np, math
import torch

# Strength of the force-guidance term subtracted from the predicted noise.
α = 0.01

def mh_jump(x_in):
    """
    Propose a new feature-space state by noising `x_in` to level T0, denoising
    one step with force guidance, and mixing the result back with `x_in`.

    x_in   : torch.Tensor on device, shape (3,2)
    returns: (x_prop, mu) for downstream Δlog p

    Every proposal and its displacement are appended to `mh_jump._proposals`
    and `mh_jump._delta_vecs` (NumPy arrays) for later inspection.

    Changes relative to the previous version:
      * The denoiser is conditioned on timestep T0 — the level `xT` is actually
        noised to — instead of timestep 0.
      * The whole proposal runs under torch.no_grad(), so the returned tensors
        carry no autograd graph (a graph-attached tensor fed back in as `x_in`
        would make the `.numpy()` call inside compute_openmm_forces fail).
      * The `for ... in range(max_tries)` wrapper, which always returned on its
        first pass, is removed.

    NOTE(review): training inputs were standardised with mu/std, whereas
    compute_openmm_forces interprets its input as raw [sin, cos] — confirm which
    space `x_in` lives in.
    """
    if not hasattr(mh_jump, "_proposals"):
        mh_jump._proposals = []
        mh_jump._delta_vecs = []

    print(f"→   inside mh_jump: T0={T0}, lam={lam}")

    with torch.no_grad():
        # Forward-noise the current state to level T0.
        noise = torch.randn_like(x_in)
        xT    = sqrt_ab * x_in + sqrt_1mab * noise

        # One timestep entry per node; all three nodes belong to graph 0.
        t_node    = torch.full((x_in.size(0),), T0, dtype=torch.long, device=device)
        batch_vec = torch.zeros_like(t_node)
        eps, _, _ = model._forward(xT, edge_index, batch_vec, t_node)

        # Force guidance on the predicted noise.
        F_xt   = compute_openmm_forces(xT)
        guided = eps - α * F_xt

        beta_t  = betas[T0]
        alpha_t = alphas[T0]
        abar_t  = a_bar[T0]

        # Posterior-mean form of one reverse step using the guided noise.
        coef1 = 1.0 / torch.sqrt(alpha_t)
        coef2 = beta_t / torch.sqrt(1.0 - abar_t)
        mu    = coef1 * (xT - coef2 * guided)

        # Damped move towards the denoised mean.
        x_prop = (1 - lam) * x_in + lam * mu

    mh_jump._proposals.append(x_prop.detach().cpu().numpy())
    mh_jump._delta_vecs.append((x_prop - x_in).detach().cpu().numpy())

    max_dx = (x_prop - x_in).abs().max().item()
    print(f"→   max |Δx| = {max_dx:.3e}")

    phi_prop = math.degrees(math.atan2(x_prop[0,0].item(), x_prop[0,1].item()))
    phi_old  = math.degrees(math.atan2(x_in[0,0].item(),  x_in[0,1].item()))
    print(f"→   Δφ = {phi_prop - phi_old:.2f}°")
    psi_prop = math.degrees(math.atan2(x_prop[1,0].item(), x_prop[1,1].item()))
    psi_old  = math.degrees(math.atan2(x_in[1,0].item(),  x_in[1,1].item()))
    print(f"→   Δψ = {psi_prop - psi_old:.2f}°")

    return x_prop, mu

In [None]:
# Cell 23
import os, mdtraj as md, math
from openmm.app import (
    PDBFile, Modeller, ForceField, Simulation,
    DCDReporter, StateDataReporter, NoCutoff, HBonds
)
from openmm import LangevinIntegrator, unit, CustomTorsionForce, LocalEnergyMinimizer
from openmm.app import AllBonds
import mdtraj.utils.unit
mdtraj.utils.unit.openmm_unit = unit


os.makedirs(output_dir, exist_ok=True)


pdb_path = os.path.join(raw_dir, 'full.pdb')

# All-atom template for the hybrid simulation.
pdb_full = PDBFile(pdb_path)
gbsa_ff   = ForceField('amber99sbildn.xml', 'amber99_obc.xml')
modeller  = Modeller(pdb_full.topology, pdb_full.positions)
modeller.addHydrogens(gbsa_ff)
implicit_modeller = modeller

# Heavy-atom-only reference topology and the dihedral atom indices *in that
# heavy-only numbering*. These are kept for code that works on heavy-only
# coordinates; they must not be used to index the all-atom system.
heavy_ref = md.load(os.path.join(raw_dir, 'heavy_only.pdb'))
heavy_top = heavy_ref.topology

ace_C_idx  = heavy_top.select('resid 0 and name C')[0]
ala_N_idx  = heavy_top.select('resid 1 and name N')[0]
ala_CA_idx = heavy_top.select('resid 1 and name CA')[0]
ala_C_idx  = heavy_top.select('resid 1 and name C')[0]
nme_N_idx  = heavy_top.select('resid 2 and name N')[0]
nme_C_idx  = heavy_top.select('resid 2 and name C')[0]

phi_indices   = [ace_C_idx, ala_N_idx, ala_CA_idx, ala_C_idx]
psi_indices   = [ala_N_idx, ala_CA_idx, ala_C_idx, nme_N_idx]
omega_indices = [ala_CA_idx, ala_C_idx, nme_N_idx, nme_C_idx]


def _full_atom_index(res_name, atom_name):
    """Index of the first atom `atom_name` in residue `res_name` of the all-atom topology."""
    for _atom in implicit_modeller.topology.atoms():
        if _atom.residue.name == res_name and _atom.name == atom_name:
            return _atom.index
    raise KeyError(f"Atom {res_name}:{atom_name} not found in the all-atom topology")


# The same four-atom dihedral definitions, but in the numbering of the all-atom
# system that the restraint force is attached to.
_c_ace  = _full_atom_index('ACE', 'C')
_n_ala  = _full_atom_index('ALA', 'N')
_ca_ala = _full_atom_index('ALA', 'CA')
_c_ala  = _full_atom_index('ALA', 'C')
_n_nme  = _full_atom_index('NME', 'N')
_c_nme  = _full_atom_index('NME', 'C')

phi_full_indices   = [_c_ace,  _n_ala,  _ca_ala, _c_ala]
psi_full_indices   = [_n_ala,  _ca_ala, _c_ala,  _n_nme]
omega_full_indices = [_ca_ala, _c_ala,  _n_nme,  _c_nme]

system     = gbsa_ff.createSystem(
    implicit_modeller.topology,
    nonbondedMethod=NoCutoff,
    constraints=HBonds
)
integrator = LangevinIntegrator(
    300*unit.kelvin,
    1/unit.picosecond,
    0.002*unit.picosecond
)

# Harmonic restraint on the periodic distance between theta and theta0.
tors_force = CustomTorsionForce(
    "0.5*k*min(abs(theta-theta0), 2*pi-abs(theta-theta0))"
    "*min(abs(theta-theta0), 2*pi-abs(theta-theta0))"
)
# Per-torsion parameter order is (theta0, k); addTorsion value lists follow it.
tors_force.addPerTorsionParameter("theta0")
tors_force.addPerTorsionParameter("k")
tors_force.addGlobalParameter("pi", math.pi)
k_val  = 100.0 * unit.kilojoule_per_mole / unit.radian**2
k_zero = 0.0   * unit.kilojoule_per_mole / unit.radian**2


# Register the three restraints switched off (k = 0). They are attached to the
# all-atom indices: the heavy-only indices above refer to a different atom
# numbering and would restrain the wrong atoms of this system.
phi_id   = tors_force.addTorsion(*phi_full_indices,   [0.0, k_zero])
psi_id   = tors_force.addTorsion(*psi_full_indices,   [0.0, k_zero])
omega_id = tors_force.addTorsion(*omega_full_indices, [0.0, k_zero])

system.addForce(tors_force)

simulation = Simulation(
    implicit_modeller.topology,
    system,
    integrator
)
simulation.context.setPositions(implicit_modeller.positions)
simulation.minimizeEnergy()
simulation.context.setVelocitiesToTemperature(300*unit.kelvin)

# Trajectory frame every 50 steps; state log line every 5000 steps.
simulation.reporters.append(
    DCDReporter(os.path.join(output_dir, 'hybrid.dcd'), 50)
)
simulation.reporters.append(
    StateDataReporter(
        os.path.join(output_dir, 'hybrid.log'),
        5000, step=True, time=True,
        potentialEnergy=True, temperature=True
    )
)

In [None]:
#Cell 24
import mdtraj as md
from openmm.app import PDBFile

# Topology of the running simulation (full system).
full_top_omm = simulation.topology

# heavy_to_full[i] is the full-system index of the i-th non-hydrogen atom
# belonging to an ACE/ALA/NME residue.
heavy_to_full = [
    atom.index for atom in full_top_omm.atoms()
    if atom.element.symbol != 'H' and atom.residue.name in ('ACE','ALA','NME')
]

# Heavy-atom topology derived from the simulation topology itself, so its atom
# order matches heavy_to_full by construction.
full_top_md   = md.Topology.from_openmm(full_top_omm)
heavy_top     = full_top_md.subset(heavy_to_full)

print("heavy_top atoms :", heavy_top.n_atoms)

# mdtraj topology of the PDB at `full_pdb` (defined in an earlier cell).
# These three names are not used again within this cell.
implicit_full  = PDBFile(full_pdb)
imp_mod_traj   = md.Topology.from_openmm(implicit_full.topology)
implicit_top_md = imp_mod_traj

# Backbone atoms as indices INTO heavy_top (not into the full system).
ace_C  = heavy_top.select('resname ACE and name C')[0]
ala_N  = heavy_top.select('resname ALA and name N')[0]
ala_CA = heavy_top.select('resname ALA and name CA')[0]
ala_C  = heavy_top.select('resname ALA and name C')[0]
nme_N  = heavy_top.select('resname NME and name N')[0]
nme_C  = heavy_top.select('resname NME and name C')[0]

# Dihedral definitions in heavy_top index space; these rebind the names set in the previous cell.
phi_indices   = [ace_C,  ala_N, ala_CA, ala_C]
psi_indices   = [ala_N,  ala_CA, ala_C, nme_N]
omega_indices = [ala_CA, ala_C,  nme_N, nme_C]

print("phi_indices :", phi_indices)
print("psi_indices :", psi_indices)
print("omega_indices:", omega_indices)
print("heavy_top atoms:", heavy_top.n_atoms)

In [None]:
# Cell 25
import os, math, random, numpy as np, mdtraj as md, torch, time
from openmm import unit, LocalEnergyMinimizer

# Wall-clock start; the elapsed time is reported after the sampling loop.
t_sample_start = time.time()

# Debug switch: when True every proposal is rejected regardless of its energy.
force_reject = False

# Total simulated time and integrator step, converted to a step count.
total_ns        = 80
dt_ps           = 0.002
total_steps     = int(total_ns*1000.0 / dt_ps)

# One proposal ("injection") is attempted after every block of plain MD.
injection_ns    = 2.0
steps_per_block = int(injection_ns * 1e3 / dt_ps)
cumulative_steps = 0
block = 0

# Minimiser iteration cap and per-block bookkeeping.
max_min_steps   = 500
bar_pairs       = []     # per block: [U_old, U_prop, delta_logp]
accepted_log    = []     # per block: True if the proposal was accepted
U_old           = None   # initialised from the first block's energy

# Index into a_cumprod (defined in an earlier cell); `var` is the variance used
# in the Gaussian log-density terms of the sampling loop.
T0   = 150
var  = (1 - a_cumprod[T0]).item()

def potential_energy_kj(sim):
    """Return the potential energy of ``sim``'s current context in kJ/mol (plain number)."""
    energy = sim.context.getState(getEnergy=True).getPotentialEnergy()
    return energy.value_in_unit(unit.kilojoule_per_mole)

# ---------------------------------------------------------------------------
# Hybrid sampling loop.
# Each block: run plain MD, propose new backbone dihedrals with `mh_jump`,
# drag the structure there with harmonic torsion restraints + minimisation,
# then accept or reject the relaxed structure with a Metropolis test.
# ---------------------------------------------------------------------------

# phi/psi/omega_indices index the heavy-atom topology (they are used that way
# with `heavy_top` below), whereas the restraint force acts on particles of the
# full system. Translate through heavy_to_full so the restraints are applied to
# the same atoms whose dihedrals are measured.
restraint_specs = [
    (phi_id,   [int(heavy_to_full[i]) for i in phi_indices]),
    (psi_id,   [int(heavy_to_full[i]) for i in psi_indices]),
    (omega_id, [int(heavy_to_full[i]) for i in omega_indices]),
]

# updateParametersInContext() cannot change which particles a torsion involves,
# so if the force was registered with different atoms, re-bind it once and
# rebuild the context (positions and velocities are preserved).
_rebound = False
for tors_id, atoms in restraint_specs:
    *bound_atoms, bound_params = tors_force.getTorsionParameters(tors_id)
    if [int(a) for a in bound_atoms] != atoms:
        tors_force.setTorsionParameters(tors_id, *atoms, list(bound_params))
        _rebound = True
if _rebound:
    simulation.context.reinitialize(preserveState=True)


def _apply_restraints(targets, k):
    """Set (theta0, k) on the phi/psi/omega restraints and push them to the context."""
    for (tors_id, atoms), theta0 in zip(restraint_specs, targets):
        tors_force.setTorsionParameters(tors_id, *atoms, [theta0, k])
    tors_force.updateParametersInContext(simulation.context)


# Thermal energy at 300 K in kJ/mol; constant for the whole run.
kT = (unit.MOLAR_GAS_CONSTANT_R * 300 * unit.kelvin)\
         .value_in_unit(unit.kilojoule_per_mole)

while cumulative_steps < total_steps:
    block += 1

    # --- plain MD segment --------------------------------------------------
    simulation.step(steps_per_block)
    cumulative_steps += steps_per_block

    U_current = potential_energy_kj(simulation)
    if U_old is None:
        U_old = U_current

    # Snapshot used to roll back when the proposal is rejected.
    state_old = simulation.context.getState(
        getPositions=True, getVelocities=True, enforcePeriodicBox=True
    )

    # --- current dihedrals -> network features ------------------------------
    xyz_full  = simulation.context.getState(getPositions=True)\
                        .getPositions(asNumpy=True).value_in_unit(unit.nanometer)
    heavy_xyz = xyz_full[heavy_to_full]
    traj0     = md.Trajectory(xyz=heavy_xyz[np.newaxis,:,:], topology=heavy_top)
    φ, ψ, ω   = md.compute_dihedrals(
        traj0, [phi_indices, psi_indices, omega_indices]
    )[0]

    x_feat = torch.tensor([
        [math.sin(φ), math.cos(φ)],
        [math.sin(ψ), math.cos(ψ)],
        [math.sin(ω), math.cos(ω)],
    ], dtype=torch.float32, device=device)

    # --- proposal -----------------------------------------------------------
    x_prop, mu = mh_jump(x_feat)

    # Features are (sin, cos) pairs, so atan2 recovers each angle.
    feat_np = x_prop.detach().cpu().numpy().flatten()
    φp, ψp, ωp = (
        math.atan2(feat_np[0], feat_np[1]),
        math.atan2(feat_np[2], feat_np[3]),
        math.atan2(feat_np[4], feat_np[5]),
    )
    _apply_restraints((φp, ψp, ωp), k_val)

    print(f"DEBUG pre-minimize proposal → φ={math.degrees(φp):.2f}°, "
          f"ψ={math.degrees(ψp):.2f}°, ω={math.degrees(ωp):.2f}°")

    # --- relax towards the proposal with the restraints switched on ---------
    LocalEnergyMinimizer.minimize(
        simulation.context,
        10*unit.kilojoule_per_mole/unit.nanometer,
        maxIterations=max_min_steps
    )

    state = simulation.context.getState(getPositions=True)
    pos   = state.getPositions(asNumpy=True).value_in_unit(unit.nanometer)

    traj_min = md.Trajectory(pos[heavy_to_full][np.newaxis,:,:], heavy_top)
    phi_act, psi_act, omega_act = md.compute_dihedrals(
        traj_min, [phi_indices, psi_indices, omega_indices]
    )[0]
    print(f"DEBUG post-minimize actual   → φ={math.degrees(phi_act):.2f}°, "
          f"ψ={math.degrees(psi_act):.2f}°, ω={math.degrees(omega_act):.2f}°")

    # NOTE(review): this energy is evaluated while the restraints are still
    # active, so it includes any residual restraint energy -- confirm intended.
    U_prop = potential_energy_kj(simulation)

    # --- Gaussian log-density of old/new features around mu (variance var) --
    diff_old   =       (x_feat - mu).view(-1).pow(2).sum().item()
    diff_new   =       (x_prop - mu).view(-1).pow(2).sum().item()
    const_term = -0.5 * x_feat.numel() * math.log(2*math.pi*var)
    logp_old   =  const_term - 0.5 * diff_old/var
    logp_new   =  const_term - 0.5 * diff_new/var
    delta_logp =  logp_new - logp_old

    # --- Metropolis test ----------------------------------------------------
    # NOTE(review): U_old is only refreshed when a proposal is accepted; it is
    # NOT the energy of the configuration reached by the MD segment above
    # (that is U_current). Confirm this reference energy is the intended one.
    U_old_prev = U_old

    delta   = (U_old - U_prop) / kT
    d_clamp = max(min(delta, 50.0), -50.0)   # keep exp() within range
    acc_prob= math.exp(d_clamp)
    accept  = (delta >= 0.0) or (random.random() < acc_prob)
    if force_reject:
        accept = False

    if accept:
        _apply_restraints((0.0, 0.0, 0.0), k_zero)
        U_old = U_prop
    else:
        # Roll back to the pre-proposal snapshot, then switch the restraints off.
        simulation.context.setState(state_old)
        _apply_restraints((0.0, 0.0, 0.0), k_zero)

        # On rejection the logged energy is that of the restored state.
        U_prop = potential_energy_kj(simulation)

    # Record the block only once every entry is known, so an interrupted block
    # never leaves a placeholder row (with None) in bar_pairs.
    bar_pairs.append([U_old_prev, U_prop, delta_logp])
    accepted_log.append(bool(accept))

    deltaE = U_prop - U_old_prev
    total_ns_run = cumulative_steps * dt_ps / 1000.0
    print(f"Block {block:2d} — accept={accepted_log[-1]}  "
          f"U_old={U_old_prev:.2f}  U_prop={U_prop:.2f}  "
          f"ΔE={deltaE:+.2f}  acc_prob={acc_prob:.1e}  "
          f"total_ns={total_ns_run:.3f} ns")

# Persist the per-block records. allow_pickle=False means the arrays must be
# plain numeric/bool arrays, so every bar_pairs row has to be fully populated.
np.save(os.path.join(processed_dir,'bar_log.npy'),   np.array(bar_pairs),   allow_pickle=False)
np.save(os.path.join(processed_dir,'accepted.npy'), np.array(accepted_log), allow_pickle=False)
# Proposal history read from attributes of mh_jump (defined in an earlier cell).
np.save(os.path.join(processed_dir, 'hybrid_proposals.npy'),
        np.stack(mh_jump._proposals))
np.save(os.path.join(processed_dir, 'hybrid_delta_x.npy'),
        np.stack(mh_jump._delta_vecs))

# Wall-clock duration of the whole sampling cell.
t_sample = time.time() - t_sample_start
print("✓ Saved bar_log.npy, accepted.npy, hybrid_proposals.npy, and hybrid_delta_x.npy")
print(f"Sampling took {t_sample/3600:.2f} hours.")

In [None]:
# Loading Cell (Hybrid)
import mdtraj as md
import os

# Locations of the hybrid run's trajectory, log and topology.
root       = '/content/drive/MyDrive/alanine_dipeptide'
output_dir = os.path.join(root, 'output')

dcd_path      = os.path.join(output_dir, 'hybrid.dcd')
log_path      = os.path.join(output_dir, 'hybrid.log')
topology_path = os.path.join(root,       'full.pdb')

# Report which of the three inputs are present before trying to load them.
for tag, path in (("DCD exists?   ", dcd_path),
                  ("LOG exists?   ", log_path),
                  ("PDB exists?   ", topology_path)):
    print(tag, os.path.exists(path), path)

traj = md.load_dcd(dcd_path, top=topology_path)
print(f"Loaded hybrid trajectory: {traj.n_frames} frames × {traj.n_atoms} atoms")

# Show the head of the log (readline() yields '' past EOF, so short logs print blanks).
print("\nFirst 10 lines of hybrid.log:")
with open(log_path, 'r') as f:
    head = [f.readline() for _ in range(10)]
    for line in head:
        print(line.rstrip())

In [None]:
# Cell 26
import os
from openmm.app import (
    PDBFile, Modeller, ForceField,
    DCDReporter, StateDataReporter,
    Simulation, HBonds, NoCutoff
)
from openmm import LangevinIntegrator, unit
import time

# Wall-clock start for the timing line printed at the end of the cell.
t0 = time.time()

classical_dir = os.path.join(out_root, 'classical')
os.makedirs(classical_dir, exist_ok=True)

# Build the system from full.pdb with the same force-field files listed in the
# notebook's metadata cell; hydrogens are (re)added by the Modeller.
pdb      = PDBFile(os.path.join(out_root, 'full.pdb'))
ff_impl  = ForceField('amber99sbildn.xml', 'amber99_obc.xml')
modeller = Modeller(pdb.topology, pdb.positions)
modeller.addHydrogens(ff_impl)

system = ff_impl.createSystem(
    modeller.topology,
    nonbondedMethod=NoCutoff,
    constraints=HBonds
)
# Langevin dynamics: 300 K, 1/ps friction, 2 fs time step.
integrator = LangevinIntegrator(
    300 * unit.kelvin,
    1.0 / unit.picosecond,
    2.0 * unit.femtoseconds
)
sim = Simulation(modeller.topology, system, integrator)
sim.context.setPositions(modeller.positions)
sim.minimizeEnergy()
sim.context.setVelocitiesToTemperature(300*unit.kelvin)

# NOTE(review): output goes to classical/classical.dcd, while the analysis
# cells glob classical/run*/classical.dcd -- presumably replicas are moved into
# run* folders by hand; confirm.
dcd_path = os.path.join(classical_dir, 'classical.dcd')
log_path = os.path.join(classical_dir, 'classical.log')
# One frame every 50 steps (0.1 ps); one log line every 1000 steps.
sim.reporters.append(DCDReporter(dcd_path, 50))
sim.reporters.append(StateDataReporter(
    log_path, 1000, step=True, time=True,
    potentialEnergy=True, temperature=True
))

# 80,000 ps at 0.002 ps per step.
nsteps = int(80_000.0 / 0.002)
print(f"Running classic MD for {nsteps:,} steps (~80 ns)…")
sim.step(nsteps)
print("Classical MD complete")
print("Trajectory: ", dcd_path)
print("Log: ", log_path)
print(f"Wall time: {(time.time() - t0)/3600:.2f} hours")

In [None]:
# Loading Cell (Classical)
import mdtraj as md
import os

# Locations of the classical run's trajectory, log and topology.
root       = '/content/drive/MyDrive/alanine_dipeptide'
classical_dir = os.path.join(root, 'classical')

dcd_path      = os.path.join(classical_dir, 'classical.dcd')
log_path      = os.path.join(classical_dir, 'classical.log')
topology_path = os.path.join(root,       'full.pdb')

# Report which of the three inputs are present before trying to load them.
for tag, path in (("DCD exists?   ", dcd_path),
                  ("LOG exists?   ", log_path),
                  ("PDB exists?   ", topology_path)):
    print(tag, os.path.exists(path), path)

traj = md.load_dcd(dcd_path, top=topology_path)
print(f"✔ Loaded classical trajectory: {traj.n_frames} frames × {traj.n_atoms} atoms")

# Show the head of the log (readline() yields '' past EOF, so short logs print blanks).
print("\nFirst 10 lines of classical.log:")
with open(log_path, 'r') as f:
    head = [f.readline() for _ in range(10)]
    for line in head:
        print(line.rstrip())

In [None]:
# Cell 27
import glob, os
import numpy as np
import mdtraj as md
from scipy.ndimage import gaussian_filter, minimum_filter
from scipy.cluster.hierarchy import fclusterdata
from skimage.segmentation import watershed
from openmm.unit import MOLAR_GAS_CONSTANT_R, kelvin, kilojoule_per_mole
import matplotlib.pyplot as plt
import matplotlib.patheffects as pe

# Collect φ/ψ (degrees) from every classical replica, keeping every 10th frame.
root      = '/content/drive/MyDrive/alanine_dipeptide'
traj_pdb  = os.path.join(root, 'full.pdb')
classical = sorted(glob.glob(os.path.join(root, 'classical', 'run*', 'classical.dcd')))
if not classical:
    # Fail here with a readable message instead of the opaque
    # "need at least one array to concatenate" further down.
    raise FileNotFoundError(
        "No classical trajectories match "
        + os.path.join(root, 'classical', 'run*', 'classical.dcd')
    )

# Every replica is loaded with the same topology file, so the backbone atom
# indices are resolved once, from that file, rather than once per trajectory.
top = md.load(traj_pdb).topology
phi_idx = [
    top.select('resname ACE and name C')[0],
    top.select('resname ALA and name N')[0],
    top.select('resname ALA and name CA')[0],
    top.select('resname ALA and name C')[0]
]
psi_idx = [
    top.select('resname ALA and name N')[0],
    top.select('resname ALA and name CA')[0],
    top.select('resname ALA and name C')[0],
    top.select('resname NME and name N')[0]
]

phi_all, psi_all = [], []
for dcd in classical:
    traj = md.load(dcd, top=traj_pdb, stride=10)
    phi_all.append(np.degrees(md.compute_dihedrals(traj, [phi_idx])[:,0]))
    psi_all.append(np.degrees(md.compute_dihedrals(traj, [psi_idx])[:,0]))

# Pooled samples over all replicas.
phi = np.concatenate(phi_all)
psi = np.concatenate(psi_all)

# --- free-energy surface on a 2° grid --------------------------------------
nbins   = 180
edges   = np.linspace(-180.0, 180.0, nbins+1)
hist    = np.histogram2d(phi, psi, bins=[edges, edges], density=False)[0]
P       = hist / hist.sum()
kT      = (MOLAR_GAS_CONSTANT_R * 300 * kelvin).value_in_unit(kilojoule_per_mole)
with np.errstate(divide='ignore'):
    F = -kT * np.log(P)

# Unvisited bins have P = 0 and therefore F = +inf (not NaN). Cap them one
# kJ/mol above the highest sampled value; posinf must be given explicitly,
# otherwise nan_to_num substitutes the largest float and the Gaussian blur
# below smears that into the neighbouring bins.
F_cap = F[np.isfinite(F)].max() + 1.0
F = np.nan_to_num(F, nan=F_cap, posinf=F_cap)

F_blur = gaussian_filter(F, sigma=2)

# --- candidate minima: local minima within ΔG_cut of the global minimum ----
Fmin     = F_blur.min()
ΔG_cut   = 3.5
basin_mask = (F_blur <= Fmin + ΔG_cut)

local_min = (F_blur == minimum_filter(F_blur, size=3))
min_coords = np.argwhere(local_min & basin_mask)

# Bin-centre coordinates (degrees) of each candidate minimum, shape (k, 2).
bin_centers = 0.5 * (edges[:-1] + edges[1:])
coords_deg  = bin_centers[min_coords]

# Merge minima closer than 20° into one basin label (labels start at 1).
# NOTE(review): distances here, and the filters above, treat the φ/ψ plane as
# non-periodic; a basin straddling ±180° may be split -- confirm acceptable.
if len(coords_deg) > 1:
    labels = fclusterdata(coords_deg, t=20.0, criterion='distance')
else:
    labels = np.ones(len(coords_deg), dtype=int)
unique_labels = np.unique(labels)

# Seed the watershed with the clustered minima.
markers = np.zeros_like(F_blur, dtype=int)
for (i,j), lab in zip(min_coords, labels):
    markers[i,j] = lab

# Grow each seed over the low-free-energy mask; label 0 means "no basin".
basin_labels = watershed(
    F_blur,
    markers=markers,
    mask=basin_mask,
    connectivity=1
)

basin_masks = {int(lab): (basin_labels == lab) for lab in unique_labels}
print(f"Strictly identified {len(unique_labels)} basins (from 3×80 ns):")
for lab in unique_labels:
    pts   = coords_deg[labels == lab]
    center = pts.mean(axis=0)
    print(f"  Basin {lab}: φ₀ ≈ {center[0]:.1f}°, ψ₀ ≈ {center[1]:.1f}°")


# Recompute the basin segmentation from the smoothed surface.
Fmin       = F_blur.min()
basin_mask = F_blur <= (Fmin + ΔG_cut)
basin_labels = watershed(F_blur, markers=markers, mask=basin_mask, connectivity=1)

# Unsmoothed free energy from the normalised histogram; empty bins are masked
# out of the plot instead of being drawn as +inf.
H = np.histogram2d(phi, psi, bins=[edges, edges], density=True)[0]
with np.errstate(divide='ignore'):
    F2 = -kT * np.log(H)
F2_masked = np.ma.masked_where(H == 0, F2)

# Bin centres for both axes (the grid is square).
xc = 0.5*(edges[:-1] + edges[1:])
yc = xc
Xc, Yc = np.meshgrid(xc, yc)

fig, ax = plt.subplots(figsize=(6,5))
cf = ax.contourf(Xc, Yc, F2_masked.T, levels=50, cmap='viridis')

# White outline around each basin ...
for lab in unique_labels:
    mask_lab = (basin_labels == lab).astype(int)
    ax.contour(Xc, Yc, mask_lab.T, levels=[0.5], colors='white', linewidths=2)

# ... and a label at the mean position of that basin's minima.
label_style = dict(
    color='white', fontsize=12, fontweight='bold',
    ha='center', va='center',
    path_effects=[pe.withStroke(linewidth=3, foreground='black')],
)
for lab in unique_labels:
    phi0, psi0 = coords_deg[labels==lab].mean(axis=0)
    ax.text(phi0, psi0, f'Basin {lab}', **label_style)

ax.set_aspect('equal')
ax.set_xlim(-180, 180)
ax.set_ylim(-180, 180)
ax.set_xlabel(r'$\phi$ (°)')
ax.set_ylabel(r'$\psi$ (°)')
ax.set_title('Classical FES with Basin Outlines')

cbar = fig.colorbar(cf, ax=ax, pad=0.02)
cbar.set_label('Free Energy (kJ/mol)')
plt.tight_layout()
plt.show()

In [None]:
# Cell 28

import os, glob, shutil
import numpy as np
import mdtraj as md

root      = '/content/drive/MyDrive/alanine_dipeptide'
traj_pdb  = os.path.join(root, 'full.pdb')
cache_dir = os.path.join(root, 'phi_psi_cache')

# The cache is wiped and recreated on every run of this cell, so the
# "already cached" early returns in the helpers below cannot trigger for
# files left over from a previous run.
if os.path.exists(cache_dir):
    shutil.rmtree(cache_dir)
os.makedirs(cache_dir, exist_ok=True)

# One DCD per replica for each protocol; both lists are processed the same way.
classical      = sorted(glob.glob(os.path.join(root, 'classical', 'run*', 'classical.dcd')))
force_bias     = sorted(glob.glob(os.path.join(root, 'output', 'hybrid_force_biasing', 'run*', 'hybrid.dcd')))
all_dcds       = classical + force_bias

def cache_phi_psi(dcd, stride=10):
    """Write φ and ψ (degrees) of one trajectory to ``cache_dir`` as .npy files.

    The file stem is ``<protocol>_<replica>_<dcd name>`` where protocol and
    replica are the two directory levels above ``dcd``. Nothing is done when
    both output files already exist. Frames are read with the given ``stride``.
    """
    run_dir  = os.path.dirname(dcd)
    protocol = os.path.basename(os.path.dirname(run_dir))
    replica  = os.path.basename(run_dir)
    stem     = os.path.basename(dcd).replace('.dcd','')
    base     = f"{protocol}_{replica}_{stem}"

    phi_file, psi_file = (
        os.path.join(cache_dir, f"{base}_{tag}.npy") for tag in ("phi", "psi")
    )
    if os.path.exists(phi_file) and os.path.exists(psi_file):
        return

    traj = md.load(dcd, top=traj_pdb, stride=stride)
    top  = traj.topology

    def first(query):
        # index of the first atom matching an mdtraj selection string
        return top.select(query)[0]

    ace_c  = first('resname ACE and name C')
    ala_n  = first('resname ALA and name N')
    ala_ca = first('resname ALA and name CA')
    ala_c  = first('resname ALA and name C')
    nme_n  = first('resname NME and name N')

    phi_idx = [ace_c, ala_n, ala_ca, ala_c]
    psi_idx = [ala_n, ala_ca, ala_c, nme_n]

    for out_file, quad in ((phi_file, phi_idx), (psi_file, psi_idx)):
        np.save(out_file, np.degrees(md.compute_dihedrals(traj, [quad])[:,0]))
    print(f"Cached φ/ψ → {base}")

def cache_omega(dcd, stride=10):
    """Write ω (degrees) of one trajectory to ``cache_dir`` as ``<protocol>_<replica>_omega.npy``.

    Protocol and replica are the two directory levels above ``dcd``. Nothing is
    done when the output file already exists. Frames are read with ``stride``.
    """
    run_dir = os.path.dirname(dcd)
    base    = "{}_{}_omega".format(
        os.path.basename(os.path.dirname(run_dir)),
        os.path.basename(run_dir),
    )
    out_fn  = os.path.join(cache_dir, f"{base}.npy")
    if os.path.exists(out_fn):
        return

    traj = md.load(dcd, top=traj_pdb, stride=stride)
    top  = traj.topology

    # CA(ALA) - C(ALA) - N(NME) - C(NME)
    queries = (
        'resname ALA and name CA',
        'resname ALA and name C',
        'resname NME and name N',
        'resname NME and name C',
    )
    omega_idx = [top.select(q)[0] for q in queries]
    omega_deg = np.degrees(md.compute_dihedrals(traj, [omega_idx])[:,0])

    np.save(out_fn, omega_deg)
    print(f"Cached ω     → {base}")

# Cache all three backbone dihedrals for every trajectory. Each helper loads
# the trajectory itself, so every DCD is read twice here.
for dcd in all_dcds:
    cache_phi_psi(dcd)
    cache_omega(dcd)

print("All dihedral caches written to:", cache_dir)

In [None]:
# Cell 29
import glob
import numpy as np
import mdtraj as md
import pandas as pd
import matplotlib.pyplot as plt

# Basin populations, relative free energies and cis/trans statistics per method.
# Uses `edges`, `unique_labels`, `basin_labels` and `kT` from the basin cell.
root     = '/content/drive/MyDrive/alanine_dipeptide'
traj_pdb = f'{root}/full.pdb'
classical_paths  = sorted(glob.glob(f'{root}/classical/run*/classical.dcd'))
force_bias_paths = sorted(glob.glob(f'{root}/output/hybrid_force_biasing/run*/hybrid.dcd'))

traj_dict = {
    'classical': classical_paths,
    'hybrid': force_bias_paths
}

rows = []

# All trajectories are loaded with traj_pdb as topology, so the atom indices
# can be resolved from that file directly. Defining φ/ψ indices here also
# removes the dependence on loop variables left behind by an earlier cell.
top0 = md.load(traj_pdb).topology
phi_idx = [
    top0.select('resname ACE and name C')[0],
    top0.select('resname ALA and name N')[0],
    top0.select('resname ALA and name CA')[0],
    top0.select('resname ALA and name C')[0]
]
psi_idx = [
    top0.select('resname ALA and name N')[0],
    top0.select('resname ALA and name CA')[0],
    top0.select('resname ALA and name C')[0],
    top0.select('resname NME and name N')[0]
]
omega_idx = [
    top0.select('resname ALA and name CA')[0],
    top0.select('resname ALA and name C') [0],
    top0.select('resname NME and name N')[0],
    top0.select('resname NME and name C')[0]
]

nbins     = len(edges)-1
bin_width = 360.0/nbins

for method, paths in traj_dict.items():
    counts = {lab: 0 for lab in unique_labels}
    cis = trans = total = 0

    for p in paths:
        traj  = md.load(p, top=traj_pdb, stride=10)
        phi   = np.degrees(md.compute_dihedrals(traj, [phi_idx])[:,0])
        psi   = np.degrees(md.compute_dihedrals(traj, [psi_idx])[:,0])
        omega = np.degrees(md.compute_dihedrals(traj, [omega_idx])[:,0])

        total += len(phi)
        # Map angles in [-180, 180] to grid bins; clip so +180 lands in the last bin.
        i_phi = np.clip(((phi + 180)//bin_width).astype(int), 0, nbins-1)
        i_psi = np.clip(((psi + 180)//bin_width).astype(int), 0, nbins-1)

        # Count frames per basin in one pass (label 0 = outside every basin).
        found, n_found = np.unique(basin_labels[i_phi, i_psi], return_counts=True)
        for lab, n_lab in zip(found, n_found):
            if lab in counts:
                counts[lab] += int(n_lab)

        # cis: |ω| < 30°; everything else is counted as trans.
        cis   += int(np.sum(np.abs(omega) <  30.0))
        trans += int(np.sum(np.abs(omega) >= 30.0))

    if total == 0:
        # No trajectories (or no frames) for this method: nothing to normalise.
        print(f"[warn] no frames found for method '{method}'; skipped")
        continue

    # Populations relative to the most populated basin.
    P_ref = max(counts.values())/total if counts else 0.0
    P     = {lab: counts[lab]/total for lab in counts}
    with np.errstate(divide='ignore'):
        # An unvisited basin gives +inf; if no basin was visited at all the
        # reference is zero and the free energies are undefined.
        dG = {lab: (-kT * np.log(P[lab]/P_ref) if P_ref > 0 else np.nan) for lab in P}

    P_cis = cis/(cis+trans)
    dG_ct = -kT * np.log(P_cis/(1.0 - P_cis)) if cis and trans else np.nan

    row = {'method': method}
    row.update({f'P_basin{lab}': P[lab]    for lab in P})
    row.update({f'dG_basin{lab}': dG[lab]  for lab in dG})
    row['P_cis']        = P_cis
    row['dG_cis_trans'] = dG_ct
    rows.append(row)

df = pd.DataFrame(rows).set_index('method')
print(df)

def _basin_columns(prefix):
    """Columns of ``df`` starting with ``prefix``, ordered by the basin id that follows it."""
    cols = [c for c in df.columns if c.startswith(prefix)]
    # Sort numerically so that basin 10 comes after basin 9, not after basin 1.
    return sorted(cols, key=lambda c: int(c[len(prefix):]))


def _grouped_bars(ax, cols, prefix, ylabel, title):
    """Draw one bar group per basin with one bar per method."""
    for i, method in enumerate(methods):
        ax.bar(x + (i - (len(methods) - 1)/2)*width,
               df.loc[method, cols].values, width,
               label=method.replace('_',' ').title())
    ax.set_xticks(x)
    # Take the whole id after the prefix, not just its last character.
    ax.set_xticklabels([f'Basin {c[len(prefix):]}' for c in cols])
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()


pop_cols = _basin_columns('P_basin')
dg_cols  = _basin_columns('dG_basin')

methods  = df.index.tolist()
n_basins = len(pop_cols)
x        = np.arange(n_basins)
width    = 0.35

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12,5))

_grouped_bars(ax1, pop_cols, 'P_basin',  'Population',  'Basin Populations')
_grouped_bars(ax2, dg_cols,  'dG_basin', 'ΔG (kJ/mol)', 'Relative Free Energies')

plt.tight_layout()
plt.show()

In [None]:
# Cell 30

import os, glob
import numpy as np, pandas as pd, matplotlib.pyplot as plt

root      = '/content/drive/MyDrive/alanine_dipeptide'

# Time between consecutive CACHED samples. The trajectories store one frame
# every 50 steps of 2 fs (0.1 ps) and the φ/ψ cache keeps every 10th frame,
# so one cached sample corresponds to 1 ps, not 0.1 ps.
frame_dt_ps  = 0.1
cache_stride = 10
dt_ps        = frame_dt_ps * cache_stride

nbins     = 180
cache_dir = f'{root}/phi_psi_cache'

paths = {
    'classical':            sorted(glob.glob(f'{root}/classical/run*/classical.dcd')),
    'hybrid': sorted(glob.glob(f'{root}/output/hybrid_force_biasing/run*/hybrid.dcd')),
}

def get_phi_psi(dcd):
    """Load the cached (phi, psi) arrays that belong to trajectory ``dcd``.

    The cache file stem is ``<protocol>_<run>_<dcd name>``, built from the two
    directory levels above ``dcd`` and its file name without ``.dcd``.
    """
    run_dir = os.path.dirname(dcd)
    stem = "_".join((
        os.path.basename(os.path.dirname(run_dir)),
        os.path.basename(run_dir),
        os.path.basename(dcd).replace('.dcd',''),
    ))
    phi, psi = (
        np.load(os.path.join(cache_dir, f"{stem}_{tag}.npy")) for tag in ("phi", "psi")
    )
    return phi, psi

def compute_mfpt(labels, A, B, dt=None):
    """Mean first-passage time from state ``A`` to state ``B`` in ns.

    For every frame labelled ``A`` the waiting time to the next later frame
    labelled ``B`` is taken; frames from which ``B`` is never reached are
    ignored. Returns ``np.nan`` when no such passage exists.

    Parameters
    ----------
    labels : sequence of int
        State label of each frame.
    A, B : int
        Start and target state.
    dt : float, optional
        Time between frames in ps. Defaults to the module-level ``dt_ps``.
    """
    labels = np.asarray(labels)
    start = np.flatnonzero(labels == A)
    target = np.flatnonzero(labels == B)
    if start.size == 0 or target.size == 0:
        return np.nan

    # For each start frame, position in `target` of the first index strictly
    # greater than it -- replaces a forward scan per start frame.
    nxt = np.searchsorted(target, start, side='right')
    reached = nxt < target.size
    if not reached.any():
        return np.nan

    waits = target[nxt[reached]] - start[reached]   # in frames
    step_ps = dt_ps if dt is None else dt
    return np.mean(waits)*step_ps/1000

# One record per trajectory and method: MFPT (ns) and rate (1/ns) for every
# ordered pair of distinct basins.
data = {m: [] for m in paths}
for method, dcds in paths.items():
    for dcd in dcds:
        phi, psi = get_phi_psi(dcd)
        # Map angles to grid bins and look up each frame's basin label
        # (`basin_labels` comes from the basin-identification cell).
        bw = 360/nbins
        i_phi = np.clip(((phi+180)//bw).astype(int), 0, nbins-1)
        i_psi = np.clip(((psi+180)//bw).astype(int), 0, nbins-1)
        # NOTE: this rebinds the global `labels` that the basin cell also uses.
        labels = basin_labels[i_phi, i_psi]

        rec = {}
        for A in unique_labels:
            for B in unique_labels:
                if A==B: continue
                m = compute_mfpt(labels, A, B)
                rec[f'MFPT_{A}→{B}'] = m
                # Rate is the reciprocal MFPT; undefined when no passage was seen.
                rec[f'k_{A}→{B}'] = np.nan if np.isnan(m) else 1/m
        data[method].append(rec)

def bootstrap_ci(arr, n=10000):
    """Return the [2.5, 97.5] percentiles of ``n`` bootstrap resampled means of ``arr``.

    Resampling is with replacement at the original sample size and uses the
    global ``np.random`` state; NaNs are ignored in means and percentiles.
    """
    sample = np.array(arr)
    size = len(sample)
    boot_means = []
    for _ in range(n):
        draw = np.random.choice(sample, size, replace=True)
        boot_means.append(np.nanmean(draw))
    return np.nanpercentile(boot_means, [2.5,97.5])

# Per method: mean and bootstrap 95% interval of every MFPT / rate column,
# taken across that method's trajectories.
summary = {}
for method, recs in data.items():
    dfm = pd.DataFrame(recs)
    stats = {}
    for col in dfm:
        vals = dfm[col].dropna().values
        mean = np.nanmean(vals)
        # An interval needs at least two runs; otherwise it is reported as NaN.
        lo, hi = bootstrap_ci(vals) if len(vals)>1 else (np.nan,np.nan)
        stats[col] = {'mean': mean, 'lo': lo, 'hi': hi}
    summary[method] = pd.DataFrame(stats).T

# Columns become a (method, statistic) MultiIndex; rows are the MFPT_/k_ keys.
out = pd.concat(summary, axis=1)
print(out)

# Row keys of the two quantities to plot, and the shared bar layout.
mfpt_keys = [k for k in out.index if k.startswith('MFPT')]
rate_keys = [k for k in out.index if k.startswith('k_')]
methods   = list(paths.keys())
width     = 0.8 / len(methods)
x1        = np.arange(len(mfpt_keys))
x2        = np.arange(len(rate_keys))

# Same figure twice: first the MFPTs, then the rates. Error bars run from the
# mean down to the lower bound and up to the upper bound of the bootstrap interval.
for keys, xpos, ylabel, title in (
    (mfpt_keys, x1, 'MFPT (ns)',   'MFPT ±95% CI'),
    (rate_keys, x2, 'Rate (1/ns)', 'Transition Rates ±95% CI'),
):
    fig, ax = plt.subplots(figsize=(8,4))
    for i, method in enumerate(methods):
        means  = out[(method,'mean')].loc[keys]
        err_lo = means - out[(method,'lo')].loc[keys]
        err_hi = out[(method,'hi')].loc[keys] - means
        offset = (i - (len(methods)-1)/2)*width
        ax.bar(xpos + offset, means, width,
               yerr=[err_lo, err_hi], label=method.replace('_',' ').title())
    ax.set_xticks(xpos)
    ax.set_xticklabels(keys, rotation=45, ha='right')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    plt.tight_layout()
    plt.show()

In [None]:
# Cell 31

import os, glob, numpy as np, matplotlib.pyplot as plt
from numpy.linalg import eigvals

# Same project folder as every other analysis cell: the φ/ψ cache read below is
# written under alanine_dipeptide/phi_psi_cache, and no cell writes to 'final_folder'.
root      = '/content/drive/MyDrive/alanine_dipeptide'
cache_dir = os.path.join(root, 'phi_psi_cache')

def discrete_traj(dcd):
    """Return the zero-based basin index of every cached frame of trajectory ``dcd``.

    φ/ψ are read from ``cache_dir`` (stem ``<protocol>_<run>_<dcd name>``),
    binned on the ``basin_labels`` grid, and the grid label minus one is
    returned. Frames on grid label 0 therefore come back as -1.
    """
    run_dir = os.path.dirname(dcd)
    base = "_".join((
        os.path.basename(os.path.dirname(run_dir)),
        os.path.basename(run_dir),
        os.path.basename(dcd).replace('.dcd',''),
    ))

    nb = basin_labels.shape[0]
    bw = 360.0/nb

    def to_bin(tag):
        # load one cached angle series (degrees) and convert it to grid indices
        angle = np.load(os.path.join(cache_dir, f"{base}_{tag}.npy"))
        return np.clip(((angle + 180.0)//bw).astype(int), 0, nb-1)

    i = to_bin("phi")
    j = to_bin("psi")
    return (basin_labels[i, j] - 1).astype(int)

cls_paths = sorted(glob.glob(f'{root}/classical/run*/classical.dcd'))
hyb_paths = sorted(glob.glob(f'{root}/output/hybrid_force_biasing/run*/hybrid.dcd'))

# Discrete (basin-index) trajectories, one per replica.
d_cls = [discrete_traj(p) for p in cls_paths]
d_hyb = [discrete_traj(p) for p in hyb_paths]

n_states  = int(basin_labels.max())
lags      = [1, 10, 20]     # in cached frames

# Time between consecutive CACHED frames, in ns. Trajectory frames are 0.1 ps
# apart and the cache keeps every 10th one, so the spacing is 1 ps.
frame_dt_ps  = 0.1
cache_stride = 10
dt_ns     = (frame_dt_ps * cache_stride) / 1000.0

def count_matrix(dtrajs, lag, n=None):
    """Count transitions ``i -> j`` at the given lag over all discrete trajectories.

    Parameters
    ----------
    dtrajs : iterable of 1-D int sequences
        State index per frame. Negative entries mark frames that belong to no
        state; every pair involving such a frame is skipped (with plain
        indexing a -1 would silently be counted as the LAST state).
    lag : int
        Lag in frames.
    n : int, optional
        Number of states. Defaults to the module-level ``n_states``.

    Returns
    -------
    (n, n) integer array with ``C[i, j]`` = number of observed ``i -> j`` pairs.
    """
    size = n_states if n is None else n
    C = np.zeros((size, size), dtype=int)
    for traj in dtrajs:
        traj = np.asarray(traj)
        n_pairs = len(traj) - lag
        if n_pairs <= 0:
            continue
        src = traj[:n_pairs]
        dst = traj[lag:lag + n_pairs]
        valid = (src >= 0) & (dst >= 0)
        # add.at accumulates repeated (i, j) pairs correctly.
        np.add.at(C, (src[valid], dst[valid]), 1)
    return C

def _row_normalize(counts):
    """Turn a count matrix into a row-stochastic matrix; rows without counts stay all-zero."""
    counts = np.asarray(counts, dtype=float)
    totals = counts.sum(axis=1, keepdims=True)
    return np.divide(counts, totals, out=np.zeros_like(counts), where=totals > 0)


# --- slowest implied timescale as a function of lag -------------------------
its = {}
for label, data in [('Classical', d_cls), ('Hybrid', d_hyb)]:
    τ1 = []
    for lag in lags:
        T        = _row_normalize(count_matrix(data, lag))
        spectrum = np.sort(np.real(eigvals(T)))
        # Second-largest eigenvalue; a timescale only exists for 0 < λ < 1.
        λ1 = spectrum[-2] if spectrum.size >= 2 else np.nan
        τ1.append(-(lag*dt_ns)/np.log(λ1) if 0.0 < λ1 < 1.0 else np.nan)
    its[label] = τ1

print("\nMode-1 implied timescales (ns):")
for lbl, t in its.items():
    print(f"  {lbl}: {[f'{x:.4f}' for x in t]}")

# --- Chapman-Kolmogorov test on the classical data --------------------------
lag0    = 10
tau0_ps = lag0 * dt_ns * 1000.0     # base lag in ps, derived rather than hard-coded
# NOTE: `T0` is also the name of an unrelated integer set in the sampling cell.
T0      = _row_normalize(count_matrix(d_cls, lag0))

print(f"\nCK test errors (classical, τ₀ = {tau0_ps:g} ps):")
for m in [1, 2, 5]:
    # Compare the matrix estimated directly at lag m·τ₀ with T(τ₀)^m.
    Emp   = _row_normalize(count_matrix(d_cls, lag0*m))
    Pred  = np.linalg.matrix_power(T0, m)
    L1    = np.abs(Emp-Pred).sum()
    maxel = np.abs(Emp-Pred).max()
    print(f"  m={m}:  L₁ = {L1:.4f}   max|ΔP| = {maxel:.4f}")

# --- figures -----------------------------------------------------------------
lags_ns = np.array(lags) * dt_ns
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(11,4.5))

for lbl, t in its.items():
    ax1.plot(lags_ns, t, marker='o', label=lbl)
ax1.set(xlabel='Lag τ (ns)', ylabel='ITS τ₁ (ns)',
        title='Mode-1 Implied Timescales')
ax1.legend(); ax1.grid(True)

m    = 5
Emp5 = _row_normalize(count_matrix(d_cls, lag0*m))
im = ax2.imshow(Emp5, vmin=0, vmax=1, cmap='viridis')
ax2.set(title=f'Classical CK (m={m}, τ₀={tau0_ps:g} ps)',
        xlabel='to state', ylabel='from state')
fig.colorbar(im, ax=ax2, shrink=0.8)

plt.tight_layout(); plt.show()

In [None]:
#Cell 32

import os, glob, sys, subprocess, importlib
import numpy as np
import matplotlib.pyplot as plt

def _ensure(pkgs):
    """Import each package in ``pkgs``; pip-install (quietly) any that cannot be imported.

    ``pkgs`` holds pip distribution names. Where the importable module has a
    different name (``scikit-learn`` installs ``sklearn``) that name is used
    for the import check, so an installed package is not re-installed on
    every run.
    """
    import_names = {"scikit-learn": "sklearn"}
    for p in pkgs:
        try:
            importlib.import_module(import_names.get(p, p))
        except ImportError:
            subprocess.check_call([sys.executable, "-m", "pip", "install", p, "-q"])
# Check this cell's dependencies up front; entries are pip distribution names.
_ensure(["numpy","matplotlib","mdtraj","pymbar","scikit-learn","scipy"])

import mdtraj as md
from pymbar import timeseries
from sklearn.mixture import GaussianMixture
from scipy.stats import norm

ROOT      = '/content/drive/MyDrive/alanine_dipeptide'
TOPO      = os.path.join(ROOT, 'full.pdb')
CACHE_DIR = os.path.join(ROOT, 'phi_psi_cache')

CL_GLOB = os.path.join(ROOT, 'classical', 'run*', 'classical.dcd')
FB_GLOB = os.path.join(ROOT, 'output', 'hybrid_force_biasing', 'run*', 'hybrid.dcd')

# Frame stride used for analysis, spacing of saved frames, and their product.
STRIDE    = 10
DT_PS     = 0.1
DT_EFF_PS = STRIDE * DT_PS

# Simulation wall time per protocol in hours. Keys are the protocol names used
# for cache files and for HOURS below ('hybrid_force_biasing', not 'hybrid').
SIM_H = {'classical': 1.40, 'hybrid_force_biasing': 1.41}
# Additional hours charged to the hybrid protocol on top of its simulation time.
HYBRID_OVERHEAD_H = (2.0/60.0) + 0.07 + (1.0/60.0)
HOURS = {
    'classical': SIM_H['classical'],
    'hybrid_force_biasing': SIM_H['hybrid_force_biasing'] + HYBRID_OVERHEAD_H
}

BOOTSTRAP_B = 2000
RNG = np.random.default_rng(0)

def dihedral_idx_phi(top):
    """Return the four atom indices C(ACE)-N(ALA)-CA(ALA)-C(ALA) defining φ in ``top``."""
    queries = (
        'resname ACE and name C',
        'resname ALA and name N',
        'resname ALA and name CA',
        'resname ALA and name C',
    )
    # first match of each selection, in dihedral order
    return [top.select(q)[0] for q in queries]

def load_phi_from_cache_or_dcd(protocol, dcd_path):
    """Return φ in degrees for one trajectory, from the cache when available.

    The cache file is ``<protocol>_<run>_<dcd name>_phi.npy`` in ``CACHE_DIR``.
    When it is missing, the DCD is loaded with ``TOPO`` and ``STRIDE`` and φ is
    computed from it (the result is not written back to the cache).
    """
    run_id = os.path.basename(os.path.dirname(dcd_path))
    stem   = os.path.basename(dcd_path).replace('.dcd','')
    cached = os.path.join(CACHE_DIR, f"{protocol}_{run_id}_{stem}_phi.npy")

    if not os.path.exists(cached):
        t = md.load(dcd_path, top=TOPO, stride=STRIDE)
        quad = dihedral_idx_phi(t.topology)
        return np.degrees(md.compute_dihedrals(t, [quad])[:,0])

    return np.load(cached)

def circ_ess_count(phi_deg):
    """Effective sample count of a circular series given in degrees.

    The statistical inefficiency is evaluated separately for sin and cos of the
    angle; the larger of the two divides the series length. Returns ``np.nan``
    when neither component yields a finite positive inefficiency.
    """
    radians = np.deg2rad(np.asarray(phi_deg, float))

    def inefficiency(series):
        # too short or (numerically) constant series carry no usable estimate
        if series.size < 10 or np.nanvar(series) < 1e-12:
            return np.nan
        try:
            g_val = timeseries.statistical_inefficiency(series)
            if np.isfinite(g_val) and g_val > 0:
                return g_val
            return np.nan
        except Exception:
            return np.nan

    g_sin = inefficiency(np.sin(radians))
    g_cos = inefficiency(np.cos(radians))
    worst = np.nanmax([g_sin, g_cos])
    if np.isfinite(worst):
        return len(radians)/worst
    return np.nan

def ess_per_gpu_hour(protocol, dcds):
    """Per-trajectory effective φ sample count divided by the protocol's hours.

    Hours are looked up in ``HOURS``. An entry is ``np.nan`` when the sample
    count or the hours are not finite, or the hours are not positive.
    Returns a float array with one value per path in ``dcds``.
    """
    hours = HOURS.get(protocol, np.nan)

    def per_hour(dcd):
        ess = circ_ess_count(load_phi_from_cache_or_dcd(protocol, dcd))
        usable = np.isfinite(ess) and np.isfinite(hours) and hours > 0
        return ess / hours if usable else np.nan

    return np.array([per_hour(dcd) for dcd in dcds], float)

def _gmm_mode_split(w, m, s):
    """Point where two weighted 1-D Gaussian components have equal density.

    ``w``, ``m``, ``s`` are the weights, means and standard deviations of the
    two components, ordered by mean. Setting
    ``w0*N(x; m0, s0) == w1*N(x; m1, s1)`` and taking logs gives
    ``A*x**2 + B*x + C == 0`` with

        A = 1/(2 s1^2) - 1/(2 s0^2)
        B = m0/s0^2 - m1/s1^2
        C = m1^2/(2 s1^2) - m0^2/(2 s0^2) + log(w0*s1 / (w1*s0))

    The root lying between the two means is returned when there is one,
    otherwise the root closest to their midpoint.
    """
    A = 1/(2*s[1]**2) - 1/(2*s[0]**2)
    B = m[0]/(s[0]**2) - m[1]/(s[1]**2)
    C = (m[1]**2)/(2*s[1]**2) - (m[0]**2)/(2*s[0]**2) + np.log((w[0]*s[1])/(w[1]*s[0]))
    mid = 0.5*(m[0] + m[1])

    if abs(A) <= 1e-12:
        # Equal widths: the equation is linear. Identical components (B == 0)
        # have no crossover, so fall back to the midpoint.
        return -C/B if B != 0 else mid

    root = np.sqrt(max(0.0, B*B - 4*A*C))
    candidates = [(-B + root)/(2*A), (-B - root)/(2*A)]
    candidates.sort(key=lambda r: (not (m[0] <= r <= m[1]), abs(r - mid)))
    return candidates[0]


cl_dcds = sorted(glob.glob(CL_GLOB))
fb_dcds = sorted(glob.glob(FB_GLOB))
if not cl_dcds: raise SystemExit("No classical DCDs found.")
if not fb_dcds: print("No Hybrid FB DCDs found under 'output/...'; violin will still show the axis.")

cl_vals_all = ess_per_gpu_hour('classical', cl_dcds)
fb_vals_all = ess_per_gpu_hour('hybrid_force_biasing', fb_dcds)

# Only finite, positive values can be log-transformed.
x_cl = cl_vals_all[np.isfinite(cl_vals_all) & (cl_vals_all > 0)]
if x_cl.size < 3:
    raise SystemExit("Need ≥3 classical runs with finite positive ESSφ/GPU-hour for a stable GMM.")

# One- vs two-component mixture on log10(ESS/GPU-h); a negative ΔBIC favours two modes.
logx = np.log10(x_cl).reshape(-1,1)
g1 = GaussianMixture(1, covariance_type='full', random_state=0).fit(logx)
g2 = GaussianMixture(2, covariance_type='full', random_state=0).fit(logx)
bic1, bic2 = g1.bic(logx), g2.bic(logx)
dBIC = bic2 - bic1

# Components ordered by mean, then the density crossover between them.
w = g2.weights_; m = g2.means_.ravel(); s = np.sqrt(g2.covariances_.ravel())
ordr = np.argsort(m); w, m, s = w[ordr], m[ordr], s[ordr]
thr_log   = _gmm_mode_split(w, m, s)
ESS_thr_h = 10**thr_log

# Bootstrap over runs: how often two modes are supported, and how stable the threshold is.
wins_2, thr_boot = 0, []
for _ in range(BOOTSTRAP_B):
    s_ = RNG.choice(x_cl, size=len(x_cl), replace=True)
    s_ = s_[(s_>0) & np.isfinite(s_)]
    if s_.size < 3: continue
    lx = np.log10(s_).reshape(-1,1)
    g1b = GaussianMixture(1, covariance_type='full', random_state=0).fit(lx)
    g2b = GaussianMixture(2, covariance_type='full', random_state=0).fit(lx)
    dB  = g2b.bic(lx) - g1b.bic(lx)
    if dB < -10: wins_2 += 1
    wb = g2b.weights_; mb = g2b.means_.ravel(); sb = np.sqrt(g2b.covariances_.ravel())
    o  = np.argsort(mb); wb, mb, sb = wb[o], mb[o], sb[o]
    thr_b = _gmm_mode_split(wb, mb, sb)
    thr_boot.append(10**thr_b)

thr_boot = np.array(thr_boot, float)
support_pct = 100.0 * wins_2 / max(1, len(thr_boot))
thr_ci = np.percentile(thr_boot, [2.5, 50, 97.5]) if thr_boot.size else [np.nan]*3

n_cl = np.isfinite(cl_vals_all).sum()
n_fb = np.isfinite(fb_vals_all).sum()
print("\n=== Classical ESSφ per GPU-hour bimodality ===")
print(f"ΔBIC (2 − 1) on log10(ESSφ/GPU-h): {dBIC:.2f}  (negative favors two components)")
print(f"Mode-separating threshold (ESSφ/GPU-h): {ESS_thr_h:.3g}")
print(f"Bootstrap support for two modes (ΔBIC < −10): {support_pct:.1f}%  (N={len(thr_boot)})")
print(f"Threshold stability (ESSφ/GPU-h): median {thr_ci[1]:.3g}, 95% CI [{thr_ci[0]:.3g}, {thr_ci[2]:.3g}]")
print(f"Included runs → Classical: {n_cl}, Hybrid FB: {n_fb}")

fig, axes = plt.subplots(1, 2, figsize=(11, 3.6))

# Left panel: histogram of log10(ESS/GPU-h) for the classical runs, with the
# fitted two-component mixture density and the mode-separating threshold.
log_vals = np.log10(x_cl)
axes[0].hist(log_vals, bins=10, density=True, alpha=0.6)
xs = np.linspace(log_vals.min()-0.2, log_vals.max()+0.2, 400)
pdf_mix = (w[0]*norm.pdf(xs, m[0], s[0]) + w[1]*norm.pdf(xs, m[1], s[1]))
axes[0].plot(xs, pdf_mix, lw=2)
axes[0].axvline(thr_log, ls='--', lw=1)
axes[0].set_xlabel('log10(ESSφ / GPU-hour)')
axes[0].set_ylabel('density')
axes[0].set_title('Classical: 2-GMM fit & mode split (ESS/GPU-h)')

# Right panel: per-run values for both protocols (finite, positive only).
cl_plot = cl_vals_all[np.isfinite(cl_vals_all) & (cl_vals_all > 0)]
fb_plot = fb_vals_all[np.isfinite(fb_vals_all) & (fb_vals_all > 0)]
data    = [cl_plot, fb_plot]
labels  = ['Classical', 'Hybrid FB']

# An empty group is replaced by a single NaN so both x positions are still drawn.
plot_data = [d if d.size else np.array([np.nan]) for d in data]
parts = axes[1].violinplot(plot_data, showmeans=True, showextrema=False)
for b in parts['bodies']:
    b.set_alpha(0.4)

# Overlay the individual runs with a little horizontal jitter.
# NOTE: the jitter draws from the global np.random state, not from the seeded
# RNG defined at the top of this cell, so point positions vary between runs.
for i, vals in enumerate(data, start=1):
    v = vals[np.isfinite(vals)]
    if v.size:
        x = np.random.normal(loc=i, scale=0.05, size=v.size)
        axes[1].scatter(x, v, s=18, alpha=0.85)

axes[1].set_yscale('log')
axes[1].set_xticks([1,2]); axes[1].set_xticklabels(labels)
axes[1].axhline(ESS_thr_h, ls='--', lw=1)
axes[1].set_ylabel('ESSφ per GPU-hour (log scale)')
axes[1].set_title('Per-run ESSφ/GPU-h with classical threshold')

plt.tight_layout()
plt.show()

# ESS_thr_h stays in the kernel namespace; a later cell checks for it.
print(f"\n[Export] ESS_thr_h = {ESS_thr_h:.6g}  (use for FAST/SLOW by ESS/GPU-h)")

In [None]:
# Cell 33

import os, glob
import numpy as np
import matplotlib.pyplot as plt

# Student-t quantiles for the small-sample confidence interval. Imported once,
# up front, instead of inside a branch of the per-protocol loop.
from scipy.stats import t

root = '/content/drive/MyDrive/alanine_dipeptide'

# protocol name -> {'paths', 'rates', 'mean', 'delta'}; 'delta' is the half-width
# of the two-sided 95% Student-t interval on the mean acceptance rate.
acceptance = {}
for prot in ['hybrid_force_biasing']:
    acc_paths = sorted(glob.glob(os.path.join(root, 'output', prot, 'run*', 'accepted.npy')))
    # One acceptance rate per run: the mean of that run's accepted.npy array.
    rates = [np.load(f).mean() for f in acc_paths]
    n = len(rates)
    if n < 2:
        # Make the degenerate case visible instead of silently reporting "nan".
        print(f"[warn] {prot}: found {n} accepted.npy file(s); "
              f"a 95% CI on the mean needs at least 2 runs")
    if n > 1:
        mean_rate = np.mean(rates)
        se = np.std(rates, ddof=1) / np.sqrt(n)   # standard error of the mean
        t_mult = t.ppf(0.975, df=n-1)             # two-sided 95% multiplier
        delta = t_mult * se
    else:
        mean_rate = rates[0] if n==1 else np.nan
        delta = np.nan

    acceptance[prot] = {
        'paths':  acc_paths,
        'rates':  rates,
        'mean':   mean_rate,
        'delta':  delta
    }

# Human-readable report, one section per protocol.
for prot, info in acceptance.items():
    print(f"Protocol: {prot}")
    print("  Files:      ", info['paths'])
    print("  Rates:      ", np.round(info['rates'], 3))
    mean  = info['mean']
    delta = info['delta']
    print(f"  Mean ±95% CI: {mean:.3f} ± {delta:.3f}")

In [None]:
# Cell 34

import os, glob, math, warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pymbar import timeseries

# Silence every warning raised from this point on (applies to the whole kernel).
warnings.filterwarnings("ignore")

# Locations on the mounted Drive.
root       = '/content/drive/MyDrive/alanine_dipeptide'
cache_dir  = os.path.join(root, 'phi_psi_cache')

# Reuse sampling settings already defined in the kernel; otherwise fall back to defaults.
STRIDE = globals().get('STRIDE', 10)
DT_PS  = globals().get('DT_PS', 0.1)

# Spacing between consecutive cached frames (ps and ns, going by the names).
DT_EFF_PS  = STRIDE * DT_PS
DT_EFF_NS  = DT_EFF_PS / 1000.0

# Hours charged to each protocol; used below as the per-GPU-hour denominator.
# NOTE(review): the provenance of the three overhead terms is not documented here.
if 'HOURS' not in globals():
    SIM_H = {'classical': 1.40, 'hybrid_force_biasing': 1.41}
    HYBRID_OVERHEAD_H = (2.0/60.0) + 0.07 + (1.0/60.0)
    HOURS = dict(SIM_H)
    HOURS['hybrid_force_biasing'] = SIM_H['hybrid_force_biasing'] + HYBRID_OVERHEAD_H
ALLOWED_PROTOCOLS = set(HOURS)

# The FAST/SLOW split needs the threshold exported by the bimodality cell.
if 'ESS_thr_h' not in globals():
    raise RuntimeError("ESS_thr_h not found. Run the bimodality cell (classical-only GMM on log10(ESSφ/GPU-hour)) first.")

def circ_ess_per_ns(angle_deg):
    """Effective sample size per ns and integrated time (ps) of a circular series.

    The angle series (degrees) is mapped to its sine and cosine, the statistical
    inefficiency g of each is estimated with pymbar, and the larger of the two
    is used.

    Returns:
        (ESS per ns, g * DT_EFF_PS). Both are NaN when no finite positive g could
        be estimated (series shorter than 10 samples, near-zero variance, or the
        estimator failing).
    """
    theta = np.deg2rad(np.asarray(angle_deg, float))
    sin_part, cos_part = np.sin(theta), np.cos(theta)

    def _inefficiency(series):
        # Too short or numerically constant: nothing to estimate.
        if series.size < 10 or np.nanvar(series) < 1e-16:
            return np.nan
        try:
            estimate = timeseries.statistical_inefficiency(series)
            if np.isfinite(estimate) and estimate > 0:
                return estimate
            return np.nan
        except Exception:
            return np.nan

    # Use the larger inefficiency of the two components (NaNs are ignored).
    g = np.nanmax([_inefficiency(sin_part), _inefficiency(cos_part)])
    if not np.isfinite(g):
        return np.nan, np.nan

    n_frames = len(theta)
    ess_total = n_frames / g
    traj_ns = n_frames * DT_EFF_NS
    tau_ps = g * DT_EFF_PS
    if traj_ns > 0:
        return ess_total / traj_ns, tau_ps
    return np.nan, tau_ps

def bootstrap_ci_mean(vals, nboot=10000, seed=7):
    """Mean of the finite entries of `vals` with a percentile-bootstrap 95% CI.

    Returns:
        (mean, lower, upper) as floats. With fewer than two finite values the
        bounds are NaN; with none, the mean is NaN as well.
    """
    finite = np.asarray(vals, float)
    finite = finite[np.isfinite(finite)]
    n_obs = finite.size
    if n_obs < 2:
        point = float(finite.mean()) if n_obs == 1 else np.nan
        return point, np.nan, np.nan

    # Resample with replacement, one row per bootstrap replicate.
    gen = np.random.default_rng(seed)
    replicate_means = gen.choice(finite, size=(nboot, n_obs), replace=True).mean(axis=1)
    lower, upper = np.percentile(replicate_means, [2.5, 97.5])
    return float(finite.mean()), float(lower), float(upper)

def perm_test_diff_means(a, b, nperm=20000, seed=7):
    """Two-sided permutation p-value for the difference of two group means.

    Non-finite entries are dropped from both groups first. Returns NaN when
    either group is then empty, otherwise (count + 1) / (nperm + 1), where
    count is the number of shuffles whose absolute mean difference is at least
    the observed one.
    """
    a = np.asarray(a, float)
    b = np.asarray(b, float)
    a, b = a[np.isfinite(a)], b[np.isfinite(b)]
    if a.size == 0 or b.size == 0:
        return np.nan

    n_a = len(a)
    observed = abs(np.nanmean(a) - np.nanmean(b))
    pooled = np.concatenate([a, b])
    gen = np.random.default_rng(seed)

    hits = 0
    for _ in range(nperm):
        # In-place shuffle of the pooled sample; the first n_a entries play group a.
        gen.shuffle(pooled)
        perm_diff = pooled[:n_a].mean() - pooled[n_a:].mean()
        hits += (abs(perm_diff) >= observed)
    return (hits + 1) / (nperm + 1)

# Collect one record per cached run that has both a φ and a ψ array on disk.
phi_files = sorted(glob.glob(os.path.join(cache_dir, '*_phi.npy')))
records = []

for fphi in phi_files:
    # Drop the trailing '_phi.npy' (8 characters) to get the shared base name.
    base = os.path.basename(fphi)[:-8]
    # Split from the right into at most three '_'-separated pieces; base names
    # that do not yield exactly three pieces are skipped.
    # NOTE(review): this relies on the cache naming scheme, which is defined
    # elsewhere — confirm it really is <protocol>_<run_id>_<feat>.
    parts = base.rsplit('_', 2)
    if len(parts) != 3:
        continue
    # `feat` is unpacked but not used further in this loop.
    protocol, run_id, feat = parts
    if protocol not in ALLOWED_PROTOCOLS:
        continue

    # The ψ file must share the same base name.
    fpsi = os.path.join(cache_dir, f"{base}_psi.npy")
    if not os.path.exists(fpsi):
        continue

    phi = np.load(fphi)
    psi = np.load(fpsi)

    # Only the first element of each returned pair (ESS per ns) is kept.
    ess_phi_ns, _ = circ_ess_per_ns(phi)
    ess_psi_ns, _ = circ_ess_per_ns(psi)

    # ESS per ns times trajectory length gives the total ESS, which is then
    # divided by the hours assigned to the protocol.
    # NOTE(review): traj_ns is taken from len(phi) for both angles, i.e. the φ and
    # ψ arrays are assumed to have the same length — confirm.
    traj_ns = len(phi) * DT_EFF_NS
    ess_phi_per_h = (ess_phi_ns * traj_ns) / HOURS[protocol] if np.isfinite(ess_phi_ns) else np.nan
    ess_psi_per_h = (ess_psi_ns * traj_ns) / HOURS[protocol] if np.isfinite(ess_psi_ns) else np.nan

    records.append({
        'protocol': protocol,
        'run_id': run_id,
        'ESS_phi_per_h':  ess_phi_per_h,
        'ESS_psi_per_h':  ess_psi_per_h
    })

# Fail loudly when nothing matched, rather than building an empty table.
if not records:
    raise RuntimeError("No valid φ/ψ cache pairs found for allowed protocols. Check cache_dir naming.")

# One row per run, ordered by protocol and run id.
df = pd.DataFrame.from_records(records).sort_values(['protocol','run_id']).reset_index(drop=True)

# Label each run by comparing its φ ESS/GPU-hour with the classical threshold.
# A NaN ESS compares False against any threshold, so a plain two-way np.where
# would file runs with an undefined φ ESS under 'SLOW' and let their ψ values
# leak into the SLOW statistics. They get their own 'UNDEF' label instead, which
# the FAST/SLOW selections used for the group statistics and bar plots skip.
df['mode_phi'] = np.where(
    df['ESS_phi_per_h'].isna(),
    'UNDEF',
    np.where(df['ESS_phi_per_h'] >= ESS_thr_h, 'FAST', 'SLOW'),
)

def group_stats(metric):
    """Compare column `metric` of `df` between the FAST and SLOW runs.

    Returns a dict with, for FAST ('A_*', 'NA') and SLOW ('B_*', 'NB'), the
    bootstrap mean, CI bounds and number of finite values, plus the FAST/SLOW
    ratio of means ('ratio', NaN unless both means are finite and the SLOW mean
    is positive) and the permutation p-value ('p_perm').
    """
    fast_vals = df.loc[df['mode_phi'] == 'FAST', metric].values
    slow_vals = df.loc[df['mode_phi'] == 'SLOW', metric].values

    fast_mean, fast_lo, fast_hi = bootstrap_ci_mean(fast_vals)
    slow_mean, slow_lo, slow_hi = bootstrap_ci_mean(slow_vals)

    if np.isfinite(fast_mean) and np.isfinite(slow_mean) and slow_mean > 0:
        ratio = fast_mean / slow_mean
    else:
        ratio = np.nan
    p_value = perm_test_diff_means(fast_vals, slow_vals)

    return {
        'A_mean': fast_mean, 'A_lo': fast_lo, 'A_hi': fast_hi,
        'NA': np.sum(np.isfinite(fast_vals)),
        'B_mean': slow_mean, 'B_lo': slow_lo, 'B_hi': slow_hi,
        'NB': np.sum(np.isfinite(slow_vals)),
        'ratio': ratio, 'p_perm': p_value,
    }

# Group statistics for the φ and ψ per-GPU-hour metrics, in that order.
S_phi_h, S_psi_h = (group_stats(column) for column in ('ESS_phi_per_h', 'ESS_psi_per_h'))

def fmt(s):
    """Render a `group_stats` result dict as a one-line FAST vs SLOW summary."""
    fast_part = f"FAST mean={s['A_mean']:.3g} 95%CI[{s['A_lo']:.3g},{s['A_hi']:.3g}] (N={s['NA']})"
    slow_part = f"SLOW mean={s['B_mean']:.3g} 95%CI[{s['B_lo']:.3g},{s['B_hi']:.3g}] (N={s['NB']})"
    comparison = f"FAST/SLOW={s['ratio']:.3g}; perm p={s['p_perm']:.3g}"
    return "; ".join((fast_part, slow_part, comparison))

# Text report: one formatted line per angle, built from the group statistics above.
print("\n=== FAST vs SLOW (by classical φ threshold on ESS/GPU-hour) — GPU-hour metrics ===")
print("φ ESS per GPU-hour: ", fmt(S_phi_h))
print("ψ ESS per GPU-hour: ", fmt(S_psi_h))

def plot_bar(metric, title, ylabel):
    """Bar chart of the FAST and SLOW group means of `metric` with bootstrap CIs.

    Bar heights are NaN-ignoring means; the asymmetric error bars are the
    distances from the bootstrap mean to the 95% CI bounds returned by
    `bootstrap_ci_mean` (5000 resamples, seed 7).
    """
    groups = ['FAST','SLOW']
    heights, err_below, err_above = [], [], []
    for group in groups:
        group_vals = df.loc[df['mode_phi'] == group, metric].values
        heights.append(np.nanmean(group_vals))
        center, lower, upper = bootstrap_ci_mean(group_vals, nboot=5000, seed=7)
        err_below.append(center - lower)
        err_above.append(upper - center)

    positions = np.arange(len(groups))
    plt.figure(figsize=(4,3))
    plt.bar(positions, heights, yerr=[err_below, err_above], capsize=4)
    plt.xticks(positions, groups)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.tight_layout()
    plt.show()

# One figure per angle.
plot_bar(metric='ESS_phi_per_h', title='φ: ESS per GPU-hour (FAST vs SLOW)', ylabel='ESSφ / GPU-hour')
plot_bar(metric='ESS_psi_per_h', title='ψ: ESS per GPU-hour (FAST vs SLOW)', ylabel='ESSψ / GPU-hour')

In [None]:
#Cell 35
import torch

# Use the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Node count: one more than the largest index in `edge_index` when such a tensor
# already exists in the kernel, otherwise 3.
N_nodes = (
    int(edge_index.max().item()) + 1
    if isinstance(globals().get('edge_index'), torch.Tensor)
    else 3
)

# Reuse an existing T0 from the kernel, defaulting to 150.
T0 = globals()['T0'] if 'T0' in globals() else 150

# All-zero batch vector: every node is assigned to batch index 0.
batch_vec = torch.zeros((N_nodes,), dtype=torch.long, device=device)

# Every node carries the same integer value T0.
t_node = torch.full((N_nodes,), int(T0), dtype=torch.long, device=device)

# Move an existing edge_index to the chosen device as int64.
if isinstance(globals().get('edge_index'), torch.Tensor):
    edge_index = edge_index.to(device=device, dtype=torch.long)

print(f"Synthesized batch_vec ({batch_vec.shape}), t_node ({t_node.shape}), T0={T0}, N_nodes={N_nodes}")

In [None]:
import numpy as np, torch, glob, os, re, matplotlib.pyplot as plt
from scipy.stats import skew, wilcoxon

# Same tensor setup as Cell 35, repeated here so this cell can run on its own.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One more than the largest index in an existing `edge_index` tensor; 3 without one.
N_nodes = (
    int(edge_index.max().item()) + 1
    if isinstance(globals().get('edge_index'), torch.Tensor)
    else 3
)

# Reuse an existing T0 from the kernel, defaulting to 150.
T0 = globals()['T0'] if 'T0' in globals() else 150
# Every node goes to batch index 0 and carries the same integer value T0.
batch_vec = torch.zeros((N_nodes,), dtype=torch.long, device=device)
t_node    = torch.full((N_nodes,), int(T0), dtype=torch.long, device=device)
# Move an existing edge_index to the chosen device as int64.
if isinstance(globals().get('edge_index'), torch.Tensor):
    edge_index = edge_index.to(device=device, dtype=torch.long)

print(f"Synthesized batch_vec ({batch_vec.shape}), t_node ({t_node.shape}), T0={T0}, N_nodes={N_nodes}")

def wrap(a):
    """Map angle(s) in radians onto the half-open interval [-pi, pi)."""
    full_turn = 2*np.pi
    shifted = a + np.pi
    return shifted % full_turn - np.pi


def ang_to_feat(a):
    """Encode angle(s) as (sin, cos) pairs stacked along a new last axis."""
    return np.stack((np.sin(a), np.cos(a)), axis=-1)


def feat_to_ang(feat):
    """Decode (sin, cos) pairs on the last axis back to angles with arctan2."""
    sin_part = feat[..., 0]
    cos_part = feat[..., 1]
    return np.arctan2(sin_part, cos_part)

def propose_from_torsions(theta_np):
    """Run `mh_jump` on one torsion vector and return the proposed angles.

    theta_np: torsion angles in radians, [phi, psi, omega]; an input with fewer
    than three entries gets a single pi appended.
    Returns the proposed angles wrapped by `wrap`. When N_nodes is 2 only the
    first two angles are sent through the model and the input's third angle is
    appended to the result unchanged.

    Raises:
        RuntimeError: if the model output is neither an (n, 2) array of
            (sin, cos) features nor a 1-D array of the expected length.
    """
    angles = np.asarray(theta_np).reshape(-1)
    if angles.shape[0] < 3:
        angles = np.concatenate([angles, [np.pi]])

    two_node = (N_nodes == 2)
    feats = ang_to_feat(angles)
    if two_node:
        feats = feats[:2]
    model_in = torch.as_tensor(feats, dtype=torch.float32, device=device)

    with torch.no_grad():
        # Try the single-argument call first; fall back to the graph signature
        # when that call raises TypeError.
        try:
            out = mh_jump(model_in)
        except TypeError:
            out = mh_jump(model_in, edge_index, batch_vec, t_node)

    # Reduce the output to a plain NumPy array (first element of a list/tuple).
    if isinstance(out, (list, tuple)):
        out = out[0]
    if isinstance(out, torch.Tensor):
        out = out.detach().cpu().numpy()
    out = np.asarray(out)

    expected_len = 2 if two_node else 3
    if out.ndim == 2 and out.shape[1] == 2:
        # Rows of (sin, cos) features: decode to angles.
        proposed = feat_to_ang(out)
    elif out.ndim == 1 and out.shape[0] == expected_len:
        # Already a flat vector of angles.
        proposed = out
    else:
        raise RuntimeError(f"Unexpected mh_jump output shape {out.shape}")

    if two_node:
        proposed = np.concatenate([proposed, [angles[2]]])
    return wrap(proposed)

def bootstrap_ci(vals, iters=5000, alpha=0.05, rng=np.random.default_rng(0)):
    vals = np.asarray(vals)
    n = len(vals)
    idxs = rng.integers(0, n, size=(iters, n))
    means = vals[idxs].mean(axis=1)
    lo, hi = np.percentile(means, [100*alpha/2, 100*(1-alpha/2)])
    return float(lo), float(hi)

# This cell does not load data itself: `torsions` must already exist in the kernel.
assert 'torsions' in globals(), "Expected 'torsions' array in scope (T x 2/3). Load caches before running."
# A two-column array gets a third column filled with pi.
# NOTE(review): the pi padding and propose_from_torsions' docstring imply the
# angles are in radians — confirm the cached torsions are not in degrees.
if torsions.shape[1] == 2:
    torsions = np.concatenate([torsions, np.full((len(torsions),1), np.pi)], axis=1)

# Draw up to 1000 distinct frame indices with a fixed seed.
N = min(1000, len(torsions))
rng = np.random.default_rng(42)
idxs = rng.choice(len(torsions), size=N, replace=False)

# Wrapped displacement (proposal minus input) in the first two angles, per frame.
disp_phi, disp_psi = [], []
first_exc = None
for i in idxs:
    x = torsions[i]
    try:
        x_prop = propose_from_torsions(x)
        d = wrap(x_prop - x)
        # Displacements containing non-finite entries are dropped without notice.
        if np.all(np.isfinite(d)):
            disp_phi.append(d[0]); disp_psi.append(d[1])
    except Exception as e:
        # Best effort: a failing frame is skipped; only the first exception is
        # kept, for the error raised below when nothing succeeds.
        if first_exc is None:
            first_exc = e

successes = len(disp_phi)
print(f"\nSuccessful proposals: {successes} / {N}")
if successes == 0:
    raise RuntimeError(f"No proposals succeeded. First exception:\n{repr(first_exc)}")

# Lists -> arrays for the statistics and histograms that follow.
disp_phi = np.array(disp_phi); disp_psi = np.array(disp_psi)

def summarize(label, arr):
    """Print and return symmetry statistics for one array of displacements.

    Args:
        label: name shown at the start of the printed line.
        arr: 1-D numpy array of displacements (radians, per the printed unit).

    Returns:
        (mean, (ci_lo, ci_hi), skewness, p) where the CI is the bootstrap CI of
        the mean (3000 resamples) and p is the p-value of the Wilcoxon
        signed-rank test on the entries with |value| > 1e-12, or NaN when fewer
        than 10 such entries exist.
    """
    m = float(arr.mean())
    lo, hi = bootstrap_ci(arr, iters=3000)
    sk = float(skew(arr))
    # Exact zeros carry no sign information; drop them before the test.
    nz = arr[np.abs(arr) > 1e-12]
    p = float(wilcoxon(nz).pvalue) if len(nz) >= 10 else float("nan")
    # The p-value comes from scipy.stats.wilcoxon (signed-rank test), not from a
    # sign test, so it is labelled as such; .3g keeps very small p-values from
    # being printed as 0.000.
    print(f"{label}: mean={m:+.4f} rad  (95% CI [{lo:+.4f}, {hi:+.4f}]),  skew={sk:+.3f},  Wilcoxon signed-rank p={p:.3g}")
    return m, (lo, hi), sk, p

# Report the number of displacements actually analysed: fewer than 1000 frames
# may be available, and failed or non-finite proposals are dropped beforehand.
print(f"\n=== Displacement symmetry (ϕ/ψ) @ N={len(disp_phi)} ===")
m_phi, ci_phi, sk_phi, p_phi = summarize("phi", disp_phi)
m_psi, ci_psi, sk_psi, p_psi = summarize("psi", disp_psi)

bins = 40
fig, axes = plt.subplots(1, 2, figsize=(10, 4), dpi=200)


def _draw_displacement_hist(ax, values, title, n_bins):
    """Density histogram of `values` with a dashed line at 0 and a red line at the mean."""
    ax.hist(values, bins=n_bins, density=True, color="C0", edgecolor="black", linewidth=0.6)
    ax.axvline(0.0, linestyle='--', color="black", linewidth=2)
    ax.axvline(values.mean(), linestyle='-', color="red", linewidth=2)
    ax.set_title(title, fontsize=16)
    ax.set_xlabel("radians", fontsize=14)
    ax.set_ylabel("density", fontsize=14)
    ax.tick_params(axis='both', labelsize=12)


# Left panel: φ displacements.
ax = axes[0]
_draw_displacement_hist(ax, disp_phi, r"$\Delta \phi$ distribution", bins)

# Right panel: ψ displacements.
ax = axes[1]
_draw_displacement_hist(ax, disp_psi, r"$\Delta \psi$ distribution", bins)

plt.tight_layout()
plt.show()

In [None]:
#Cell 37
from scipy.stats import fisher_exact
# 2x2 contingency table of counts, one inner list per row.
# NOTE(review): the counts are hard-coded and their origin is not shown here —
# confirm which rows/columns they correspond to.
table = [
    [7, 2],
    [1, 6],
]
# Two-sided Fisher exact test; unpacks into the odds ratio and the p-value.
OR, p = fisher_exact(table, alternative='two-sided')
print(OR, p)