In [1]:
from ITO_DDPM import *
from architechtures.PaiNN_Like import *
from architechtures.embeddings import *
import torch.optim as optim
import torch
import wandb
from tqdm import tqdm
import sys
import os
current_dir = os.path.dirname(os.curdir)
parent_dir = os.path.abspath(os.path.join(current_dir, "../SPARK/"))
sys.path.append(parent_dir)
from utils import *
device = "cuda"
from torch.serialization import add_safe_globals

add_safe_globals([ImplicitTransferOperatorDDPM])

torch.autograd.set_detect_anomaly(True)
torch.set_printoptions(profile="full")

# Crash on NaNs/Infs
torch._C._set_warnAlways(True)



╔═══════════════════════════════════════════════════╗
║                                                   ║
║  ██████╗   ██████╗    ██╗      ██████╗   ██╗  ██╗ ║
║ ██╔════╝  ██╔══██╗   ██╔██╗    ██╔══██╗  ██║ ██╔╝ ║
║ ╚█████╗   ██████╔╝  ██╔╝╚██╗   ██████╔╝  █████╔╝  ║
║  ╚═══██╗  ██╔═══╝  ██╔╝  ╚██╗  ██╔══██╗  ██╔═██╗  ║
║ ██████╔╝  ██║     ██╔╝    ╚██╗ ██║  ██║  ██║ ╚██╗ ║
║ ╚═════╝   ╚═╝     ╚═╝      ╚═╝ ╚═╝  ╚═╝  ╚═╝  ╚═╝ ║
║                                                   ║
║     Statistical Physics Autodiff Research Kit     ║
╚═══════════════════════════════════════════════════╝

          V(r)           ψ, φ              q
           │               │               │
           ○               ○               ○
         ╱ | ╲           ╱ | ╲           ╱ | ╲
        ○  ○  ○         ○  ○  ○         ○  ○  ○
         ╲ | ╱           ╲ | ╱           ╲ | ╱
           ○               ○               ○
           │               │               │
          g(r)             F         

## load in a dataset

In [2]:
data_path = "datasets/AlanaineDipeptideVacuum/"
base_path = f"{data_path}ADP_Vacuum"
pos_path = base_path + "_position"
mom_path = base_path + "_momentum"

position_data = torch.load(pos_path, map_location=device)
momentum_data = torch.load(mom_path, map_location=device)

# Check for NaNs or Infs in position and momentum data
bad_pos = ~torch.isfinite(position_data).all(dim=(0, 2, 3))  # shape: (16384,)
bad_mom = ~torch.isfinite(momentum_data).all(dim=(0, 2, 3))  # shape: (16384,)

# Combine masks and negate to get valid ones
valid_mask = ~(bad_pos | bad_mom)

# Filter both tensors
position_data = position_data[:, valid_mask, :, :]
momentum_data = momentum_data[:, valid_mask, :, :]

print(f"Kept {valid_mask.sum().item()} / {valid_mask.shape[0]} trajectories.")

top, node_features, mass, energy_dict = build_top_and_features("datasets/AlanaineDipeptideVacuum/alanine-dipeptide.prmtop")

Kept 16357 / 16384 trajectories.


## create the model

In [3]:
# Noise Schedule Settings
low = -8.0
high = -4.0
diffusion_steps = 1000
noise_schedule = SigmoidNoiseSchedule(low, high, diffusion_steps, device)
t_diff_max = diffusion_steps
s_phys_max = 198

# Architecture Settings
C_x = 2
C_v_i = 1
C_z_i = 2
C = 64
C_v = C
C_z = C
C_t = C
C_s = C
N = 22
B = 16

f_0_layers = [C_s + C_t + C_v + C_z_i, C_s + C_t + C_v + C_z_i, C_z]
f_1_layers = [C_x + C_z, C_x + C_z, C_x]
f_2_layers = [C_x + C_z, C_x + C_z, C_v]
f_3_layers = [C_x + C_z, C_x + C_z, C_z]
f_4_layers = [C_v + C_z, C_v + C_z, C_v]
f_5_layers = [C_v + C_z, C_v + C_z, C_z]

W_1_layers = [(2 * C_x + C_v_i, 3), (C_v, 3)]
W_2_layers = [(C_x, 3), (C_v, 3)]
W_3_layers = [(C_v, 3), (C_v, 3)]
W_4_layers = [(C_v, 3), (C_x, 3)]

message_passing_steps = 3
p = 0
activation_function = nn.Tanh()

t_diff_embedding = SinCosTimeEmbedding(C_t, max_t=t_diff_max, init_scale=1.0, learnable_scale=True, device=device)
s_phys_embedding = SinCosTimeEmbedding(C_s, max_t=s_phys_max, init_scale=1.0, learnable_scale=True, device=device)

EGNN = EquivariantGraphNeuralNetwork(
    f_0_layers, f_1_layers, f_2_layers, f_3_layers, f_4_layers, f_5_layers,
    W_1_layers, W_2_layers, W_3_layers, W_4_layers,
    message_passing_steps, p, activation_function,
    t_diff_embedding, s_phys_embedding,
    device
).to(device)

DDPM = ImplicitTransferOperatorDDPM(
    EGNN,
    noise_schedule,
    t_diff_max,
    s_phys_max,
    device
).to(device)

# Try loading full model first, fall back to state_dict
try:
    print("Trying to load full model object...")
    DDPM = torch.load("16_model_131072.pth", map_location=device, weights_only=False)
    DDPM.eval()
    print(f"Loaded model with {sum(p.numel() for p in DDPM.parameters() if p.requires_grad):,} parameters.")
except Exception as e:
    print("Failed to load model")
    print(f"Created model with {sum(p.numel() for p in DDPM.parameters() if p.requires_grad):,} parameters.")

Trying to load full model object...
Loaded model with 126,496 parameters.


In [None]:
plot = t_diff_embedding.plot_embedding()
plot = s_phys_embedding.plot_embedding()
plot = noise_schedule.plot("Noise Schedule")

In [5]:
x_s_t = [position_data[199,:B],momentum_data[199,:B]]
x_0_0 = [position_data[0,:B],momentum_data[0,:B]]
v_0_0 = [torch.zeros_like(x_0_0[0],device=device)]
z_0_0 = [mass.unsqueeze(dim=0).expand(B,22), node_features['charge'].unsqueeze(dim=0).expand(B,22)]

out = DDPM(x_s_t, x_0_0, v_0_0, z_0_0, 200*torch.zeros(B,device=device,dtype=torch.int), 200*torch.ones(B,device=device,dtype=torch.int))

## test equivariance using commutator

In [6]:
num_iters = 20

median_total_mom_error = []
median_total_pos_error = []

mean_total_mom_error = []
mean_total_pos_error = []

eps_mom = 0.0
eps_pos = 0.0

with torch.no_grad():
    for i in range(num_iters):
        # Rotate position and momentum for x_s_t
        R = None
        RP, R, _ = random_rotation_3d(position_data[199,:B])  # rotate pos
        RM, _, _ = random_rotation_3d(momentum_data[199,:B], R)  # same R

        # rotate x_0_0
        R0P, _, _ = random_rotation_3d(position_data[0,:B], R)
        R0M, _, _ = random_rotation_3d(momentum_data[0,:B], R)

        # rotate v_0_0 (zeros)
        Rv0 = torch.zeros_like(R0P, device=device)

        # rotate scalar features
        z_mass = mass.unsqueeze(0).expand(B, 22)  # [B, N]
        z_charge = node_features['charge'].unsqueeze(0).expand(B, 22)  # [B, N]

        # build rotated inputs
        x_s_t_rot = [RP, RM]
        x_0_0_rot = [R0P, R0M]
        v_0_0_rot = [Rv0]
        z_0_0_rot = [z_mass, z_charge]

        # unrotated inputs
        x_s_t = [position_data[199,:B], momentum_data[199,:B]]
        x_0_0 = [position_data[0,:B], momentum_data[0,:B]]
        v_0_0 = [torch.zeros_like(position_data[0,:B], device=device)]
        z_0_0 = [z_mass, z_charge]

        t_diff = 200 * torch.zeros(B, device=position_data[0].device,dtype=torch.int)
        s_phys = 200 * torch.ones(B, device=position_data[0].device,dtype=torch.int)

        # evaluate model
        out_rot = DDPM(x_s_t_rot, x_0_0_rot, v_0_0_rot, z_0_0_rot, t_diff, s_phys)
        out_ref = DDPM(x_s_t, x_0_0, v_0_0, z_0_0, t_diff, s_phys)
        
        mom_rot = out_rot[0]  # [B, N, 3]
        pos_rot = out_rot[1]  # [B, N, 3]
        
        mom_ref = out_ref[0]
        pos_ref = out_ref[1]

        # rotate reference outputs
        Rmom_ref, _, _ = random_rotation_3d(mom_ref, R)
        Rpos_ref, _, _ = random_rotation_3d(pos_ref, R)


        # RMSD
        median_total_mom_error.append(torch.median((mom_rot - Rmom_ref)**2).sqrt().item())
        median_total_pos_error.append(torch.median((pos_rot - Rpos_ref)**2).sqrt().item())

        mean_total_mom_error.append(torch.mean((mom_rot - Rmom_ref)**2).sqrt().item())
        mean_total_pos_error.append(torch.mean((pos_rot - Rpos_ref)**2).sqrt().item())

        # relative error
        eps_mom += torch.linalg.vector_norm(mom_rot - Rmom_ref) / (
                   torch.linalg.vector_norm(mom_ref) + 1e-12)
        eps_pos += torch.linalg.vector_norm(pos_rot - Rpos_ref) / (
                   torch.linalg.vector_norm(pos_ref) + 1e-12)

    eps_mom /= num_iters
    eps_pos /= num_iters

    print(f"mean relative error in momentum over {num_iters} rotations: {eps_mom:.2e}")
    print(f"mean relative error in position over {num_iters} rotations: {eps_pos:.2e}\n")

    print(f"Median momentum RMSD averaged over {num_iters} rotations: {sum(median_total_mom_error) / num_iters:.6f}")
    print(f"Median position RMSD difference averaged over {num_iters} rotations: {sum(median_total_pos_error) / num_iters:.6f}\n")

    print(f"Mean momentum RMSD averaged over {num_iters} rotations: {sum(mean_total_mom_error) / num_iters:.6f}")
    print(f"Mean position RMSD difference averaged over {num_iters} rotations: {sum(mean_total_pos_error) / num_iters:.6f}\n")

mean relative error in momentum over 20 rotations: 2.36e-07
mean relative error in position over 20 rotations: 4.75e-07

Median momentum RMSD averaged over 20 rotations: 0.000001
Median position RMSD difference averaged over 20 rotations: 0.000000

Mean momentum RMSD averaged over 20 rotations: 0.000002
Mean position RMSD difference averaged over 20 rotations: 0.000001



## look at the forward trajectories

In [7]:
#Function to plot trajectory evolution for positions and momentum
@torch.no_grad()
def plot_trajectories(tensor, title):
    T, B, N, D = tensor.shape
    fig, ax = plt.subplots(figsize=(10, 6))
    for k in range(B):
        for j in range(N):
            for i in range(D):
                ax.plot(tensor[:, k, j, i].detach().cpu(), alpha=0.4, linewidth=0.8)
    ax.set_title(title)
    ax.set_xlabel("Diffusion Step t")
    plt.tight_layout()
    plt.show()

# Function to plot histogram vs standard normal
@torch.no_grad()
def plot_hist_vs_normal(data_tensor, title):
    data = data_tensor.flatten().detach().cpu()
    x = torch.linspace(-5, 5, 500)
    normal_pdf = torch.exp(-0.5 * x**2) / torch.sqrt(torch.tensor(2 * torch.pi))
    
    plt.figure(figsize=(8, 5))
    plt.hist(data, bins=100, density=True, alpha=0.6, label="model output over all batches, atoms, dims")
    plt.plot(x.numpy(), normal_pdf.numpy(), 'k--', linewidth=2, label="standard normal $\mathcal{N}(0,1)$")
    
    plt.title(title)
    plt.xlabel("Value")
    plt.ylabel("Density")
    plt.legend()
    plt.tight_layout()
    plt.show()

In [8]:
# # Sample the trajectory
# forward = DDPM.sample_forward_trajectory(x_s_t)

# plot_trajectories(forward[0], "Position Trajectories")
# plot_trajectories(forward[1], "Momentum Trajectories")

# plot_hist_vs_normal(forward[0][-1], "Position Forward Output vs Standard Normal")
# plot_hist_vs_normal(forward[1][-1], "Momentum Forward Output vs Standard Normal")

# del forward

## look at the reverse trajectories

In [9]:
# # Sample the trajectory
# reverse = DDPM.sample_reverse_trajectory(x_0_0, v_0_0, z_0_0, s_phys)

# plot_trajectories(reverse[0], "Position Trajectories")
# plot_trajectories(reverse[1], "Momentum Trajectories")

# plot_hist_vs_normal(reverse[0][0], "Position Reverse Output vs Standard Normal")
# plot_hist_vs_normal(reverse[1][0], "Momentum Reverse Output vs Standard Normal")

In [10]:
# # Sample the trajectory
# reverse_marginal = DDPM.sample_reverse_marginal(x_0_0, v_0_0, z_0_0, t_diff_max, s_phys)

# plot_hist_vs_normal(reverse_marginal[0], "Position Forward Output vs Standard Normal")
# plot_hist_vs_normal(reverse_marginal[1], "Momentum Forward Output vs Standard Normal")

# del reverse_marginal

In [11]:
# # Sample the trajectory
# reverse_marginal = DDPM.sample_reverse_marginal(x_0_0, v_0_0, z_0_0, 0,s_phys)

# plot_hist_vs_normal(reverse_marginal[0], "Position Forward Output vs Standard Normal")
# plot_hist_vs_normal(reverse_marginal[1], "Momentum Forward Output vs Standard Normal")

# del reverse_marginal

In [12]:
class TrajectoryDataset():
    """
    Dataset class for trajectories with randomly selected initial conditions (IC).

    Args:
        trajectory (Tensor): Tensor of trajectories with shape [trajectory index, time, features].
        t_diff_max (int): Maximum diffusion time step.
        s_phys_max (int): Maximum physical time for sampling.
        device (str): Device to place tensors on (default: 'cpu').
        seed (int, optional): Seed for random number generation.
    """
    def __init__(self, trajectory, t_diff_max, s_phys_max, device,  seed=None):
        super().__init__()
        self.device = device
        self.s_phys_max = s_phys_max
        self.t_diff_max = t_diff_max
        self.trajectory = trajectory

        self.num_trajs = len(self.trajectory[0])
        self.data_dim = len(self.trajectory[0,0])

    def getitems(self, batch_size):
        traj_idxs = torch.randint(0, self.num_trajs, (batch_size,),device=self.device)
        
        N_vals = torch.rand(batch_size,device=self.device)  * np.log(self.s_phys_max)
        ic_idx = torch.randint(1, len(self.trajectory)- (self.s_phys_max), (batch_size,),device=self.device)
        
        s_phys = torch.floor(torch.exp(N_vals)).long()
        t_diff = torch.randint(0, self.t_diff_max, (batch_size,),device=self.device)
        x_0_0 = self.trajectory[ic_idx, traj_idxs]
        x_s_0 = self.trajectory[ic_idx + s_phys, traj_idxs]
        
        return {"x_0_0":x_0_0,
                "x_s_0":x_s_0,
                "t_diff":t_diff,
                "s_phys":s_phys}

    def getitem(self):
        return self.getitems(1)
         
dataset = TrajectoryDataset(torch.cat([position_data,momentum_data],dim=-1),t_diff_max,s_phys_max,device)

In [None]:
wandb_step = 0
log_interval = 16
plot_interval = 2048
save_interval = 16*2048
grad_clipping = 0
name = f"BatchSize{B}"
checkpoint_path = "models/"

wandb.init(project="diffusion-training", name=name)

optimizer = optim.Adam(DDPM.parameters(),
                       lr = 1e-3,
                       betas = (0.9, 0.999),
                       eps = 1e-8, 
                       weight_decay = 0.0,
                       amsgrad = False)

scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99999)

epochs = 256
epoch_size = 2048*B 

for epoch in range(0, epochs):
    loader = tqdm(range(0, epoch_size, B), desc=f"Epoch {epoch+1}/{epochs}", leave=True, ncols=120, unit=' batch')
    for step in loader:
        batch = dataset.getitems(B)
        x_0_0 = [batch['x_0_0'][...,:3],batch['x_0_0'][...,3:]]         # Initial condition inferred vector features
        x_s_0 = [batch['x_s_0'][...,:3],batch['x_s_0'][...,3:]]         # Final condition inferred vector features
        v_0_0 = [torch.zeros_like(x_0_0[0],device=device)]              # Initial auxillary vector features
        z_0_0 = [mass.unsqueeze(dim=0).expand(B,22), 
                 node_features['charge'].unsqueeze(dim=0).expand(B,22)] # Initial auxillary scalar features
        t_diff = batch['t_diff']                                        # Diffusion time
        s_phys = batch['s_phys']                                        # Physical time
        
        optimizer.zero_grad()
        
        loss_list = DDPM.loss(x_s_t, x_0_0, v_0_0, z_0_0, t_diff, s_phys, flatten=True)
        pos_loss, mom_loss = loss_list
        loss = pos_loss + mom_loss
        loss.backward()

        if grad_clipping != 0:
            # Needs to be inspected. Does not work as intended.
            torch.nn.utils.clip_grad_norm_(DDPM.parameters(), max_norm=grad_clipping)

        optimizer.step()
        
        loader.set_postfix(
            batch=f"{int((step+1)/B)}/{int(epoch_size/B)}",
            loss=f"{loss.item():.6f}",
            pos=f"{pos_loss.item():.4e}",
            mom=f"{mom_loss.item():.4e}"
        )

        if wandb_step % log_interval == 0:
            wandb.log({
                "loss/total": loss.item(),
                "loss/position": pos_loss.item(),
                "loss/momentum": mom_loss.item(),
                "lr": scheduler.get_last_lr()[0],
            }, step=wandb_step)

            # Log gradients and parameters
            for name_, param in DDPM.named_parameters():
                if param.grad is not None:
                    wandb.log({f"gradients/{name_}": wandb.Histogram(param.grad.cpu().data.numpy())}, step=wandb_step)
                wandb.log({f"params/{name_}": wandb.Histogram(param.cpu().data.numpy())}, step=wandb_step)

        # if wandb_step % plot_interval == 0:
        #     with torch.no_grad():
        #         # Sample the trajectory
        #         reverse_marginal = DDPM.sample_reverse_marginal([x_0_0_[:64] for x_0_0_ in x_0_0], [v_0_0_[:64] for v_0_0_ in v_0_0], [z_0_0_[:64] for z_0_0_ in z_0_0], 0, s_phys[:64])
                
        #         plot_hist_vs_normal(reverse_marginal[0], "Position Reverse Output vs Standard Normal")
        #         plot_hist_vs_normal(reverse_marginal[1], "Momentum Reverse Output vs Standard Normal")
            
        #         del reverse_marginal

        if wandb_step % save_interval == 0:
            torch.save(DDPM, f"model_{wandb_step}.pth")

        wandb_step += 1
        scheduler.step()
       

    print(f"Epoch {epoch+1}: Loss = {loss.item():.6f}")

print("Training completed.")

[34m[1mwandb[0m: Currently logged in as: [33mwinsaton[0m ([33mwinsaton-univeristy-of-minnesota[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch 1/256: 100%|█| 2048/2048 [06:38<00:00,  5.14 batch/s, batch=2047/2048, loss=0.991295, mom=6.0242e-01, pos=3.8887e-


Epoch 1: Loss = 0.991295


Epoch 2/256: 100%|█| 2048/2048 [05:43<00:00,  5.96 batch/s, batch=2047/2048, loss=0.901332, mom=5.9850e-01, pos=3.0283e-


Epoch 2: Loss = 0.901332


Epoch 3/256: 100%|█| 2048/2048 [05:30<00:00,  6.19 batch/s, batch=2047/2048, loss=0.842096, mom=5.8915e-01, pos=2.5294e-


Epoch 3: Loss = 0.842096


Epoch 4/256: 100%|█| 2048/2048 [06:16<00:00,  5.44 batch/s, batch=2047/2048, loss=0.871123, mom=6.2965e-01, pos=2.4147e-


Epoch 4: Loss = 0.871123


Epoch 5/256: 100%|█| 2048/2048 [06:26<00:00,  5.29 batch/s, batch=2047/2048, loss=0.839638, mom=6.2054e-01, pos=2.1910e-


Epoch 5: Loss = 0.839638


Epoch 6/256: 100%|█| 2048/2048 [06:28<00:00,  5.27 batch/s, batch=2047/2048, loss=0.815513, mom=5.9017e-01, pos=2.2535e-


Epoch 6: Loss = 0.815513


Epoch 7/256: 100%|█| 2048/2048 [06:32<00:00,  5.22 batch/s, batch=2047/2048, loss=0.920225, mom=6.9468e-01, pos=2.2554e-


Epoch 7: Loss = 0.920225


Epoch 8/256: 100%|█| 2048/2048 [06:40<00:00,  5.11 batch/s, batch=2047/2048, loss=0.818383, mom=6.1344e-01, pos=2.0494e-


Epoch 8: Loss = 0.818383


Epoch 9/256: 100%|█| 2048/2048 [06:14<00:00,  5.47 batch/s, batch=2047/2048, loss=0.818216, mom=6.1916e-01, pos=1.9905e-


Epoch 9: Loss = 0.818216


Epoch 10/256: 100%|█| 2048/2048 [06:23<00:00,  5.35 batch/s, batch=2047/2048, loss=0.801980, mom=6.1044e-01, pos=1.9154e


Epoch 10: Loss = 0.801980


Epoch 11/256: 100%|█| 2048/2048 [06:35<00:00,  5.18 batch/s, batch=2047/2048, loss=0.770520, mom=5.9363e-01, pos=1.7689e


Epoch 11: Loss = 0.770520


Epoch 12/256: 100%|█| 2048/2048 [06:52<00:00,  4.97 batch/s, batch=2047/2048, loss=0.721888, mom=5.4026e-01, pos=1.8163e


Epoch 12: Loss = 0.721888


Epoch 13/256: 100%|█| 2048/2048 [06:11<00:00,  5.51 batch/s, batch=2047/2048, loss=0.786725, mom=6.0138e-01, pos=1.8534e


Epoch 13: Loss = 0.786725


Epoch 14/256: 100%|█| 2048/2048 [06:29<00:00,  5.26 batch/s, batch=2047/2048, loss=0.721418, mom=5.4258e-01, pos=1.7884e


Epoch 14: Loss = 0.721418


Epoch 15/256: 100%|█| 2048/2048 [06:49<00:00,  5.00 batch/s, batch=2047/2048, loss=0.757624, mom=5.6446e-01, pos=1.9317e


Epoch 15: Loss = 0.757624


Epoch 16/256: 100%|█| 2048/2048 [06:02<00:00,  5.65 batch/s, batch=2047/2048, loss=0.706580, mom=5.2428e-01, pos=1.8230e


Epoch 16: Loss = 0.706580


Epoch 17/256: 100%|█| 2048/2048 [06:22<00:00,  5.36 batch/s, batch=2047/2048, loss=0.689454, mom=5.0828e-01, pos=1.8117e


Epoch 17: Loss = 0.689454


Epoch 18/256: 100%|█| 2048/2048 [06:10<00:00,  5.53 batch/s, batch=2047/2048, loss=0.632324, mom=4.5555e-01, pos=1.7677e


Epoch 18: Loss = 0.632324


Epoch 19/256: 100%|█| 2048/2048 [06:18<00:00,  5.41 batch/s, batch=2047/2048, loss=0.686475, mom=4.9998e-01, pos=1.8650e


Epoch 19: Loss = 0.686475


Epoch 20/256: 100%|█| 2048/2048 [06:09<00:00,  5.54 batch/s, batch=2047/2048, loss=0.622894, mom=4.5033e-01, pos=1.7256e


Epoch 20: Loss = 0.622894


Epoch 21/256: 100%|█| 2048/2048 [06:52<00:00,  4.97 batch/s, batch=2047/2048, loss=0.639012, mom=4.5863e-01, pos=1.8038e


Epoch 21: Loss = 0.639012


Epoch 22/256: 100%|█| 2048/2048 [06:00<00:00,  5.69 batch/s, batch=2047/2048, loss=0.592325, mom=4.2676e-01, pos=1.6557e


Epoch 22: Loss = 0.592325


Epoch 23/256: 100%|█| 2048/2048 [05:26<00:00,  6.27 batch/s, batch=2047/2048, loss=0.609420, mom=4.3397e-01, pos=1.7545e


Epoch 23: Loss = 0.609420


Epoch 24/256: 100%|█| 2048/2048 [06:02<00:00,  5.66 batch/s, batch=2047/2048, loss=0.609993, mom=4.2911e-01, pos=1.8089e


Epoch 24: Loss = 0.609993


Epoch 25/256: 100%|█| 2048/2048 [05:56<00:00,  5.74 batch/s, batch=2047/2048, loss=0.573210, mom=4.0139e-01, pos=1.7182e


Epoch 25: Loss = 0.573210


Epoch 26/256: 100%|█| 2048/2048 [06:28<00:00,  5.28 batch/s, batch=2047/2048, loss=0.647055, mom=4.4064e-01, pos=2.0642e


Epoch 26: Loss = 0.647055


Epoch 27/256: 100%|█| 2048/2048 [06:29<00:00,  5.26 batch/s, batch=2047/2048, loss=0.587018, mom=4.0981e-01, pos=1.7720e


Epoch 27: Loss = 0.587018


Epoch 28/256: 100%|█| 2048/2048 [06:54<00:00,  4.94 batch/s, batch=2047/2048, loss=0.604233, mom=4.3291e-01, pos=1.7132e


Epoch 28: Loss = 0.604233


Epoch 29/256: 100%|█| 2048/2048 [06:42<00:00,  5.08 batch/s, batch=2047/2048, loss=0.542936, mom=3.7916e-01, pos=1.6378e


Epoch 29: Loss = 0.542936


Epoch 30/256: 100%|█| 2048/2048 [06:40<00:00,  5.11 batch/s, batch=2047/2048, loss=0.579766, mom=3.9366e-01, pos=1.8611e


Epoch 30: Loss = 0.579766


Epoch 31/256: 100%|█| 2048/2048 [06:15<00:00,  5.45 batch/s, batch=2047/2048, loss=0.606727, mom=4.2253e-01, pos=1.8420e


Epoch 31: Loss = 0.606727


Epoch 32/256: 100%|█| 2048/2048 [06:14<00:00,  5.46 batch/s, batch=2047/2048, loss=0.548153, mom=3.7711e-01, pos=1.7104e


Epoch 32: Loss = 0.548153


Epoch 33/256: 100%|█| 2048/2048 [05:41<00:00,  6.00 batch/s, batch=2047/2048, loss=0.581435, mom=4.0264e-01, pos=1.7880e


Epoch 33: Loss = 0.581435


Epoch 34/256: 100%|█| 2048/2048 [06:39<00:00,  5.13 batch/s, batch=2047/2048, loss=0.534961, mom=3.6387e-01, pos=1.7110e


Epoch 34: Loss = 0.534961


Epoch 35/256: 100%|█| 2048/2048 [06:43<00:00,  5.07 batch/s, batch=2047/2048, loss=0.542177, mom=3.7790e-01, pos=1.6428e


Epoch 35: Loss = 0.542177


Epoch 36/256: 100%|█| 2048/2048 [06:29<00:00,  5.26 batch/s, batch=2047/2048, loss=0.541108, mom=3.6755e-01, pos=1.7356e


Epoch 36: Loss = 0.541108


Epoch 37/256: 100%|█| 2048/2048 [06:18<00:00,  5.41 batch/s, batch=2047/2048, loss=0.573176, mom=3.9088e-01, pos=1.8230e


Epoch 37: Loss = 0.573176


Epoch 38/256:  33%|▎| 683/2048 [01:42<03:25,  6.65 batch/s, batch=683/2048, loss=0.544996, mom=3.7306e-01, pos=1.7193e-0

In [10]:
top, node_features, mass, energy_dict = build_top_and_features("datasets/AlanaineDipeptideVacuum/alanine-dipeptide.prmtop")
atomic_numbers = [a.atomic_number for a in pmd.load_file("datasets/AlanaineDipeptideVacuum/alanine-dipeptide.prmtop").atoms]
save_pdb_with_bonds(dataset.getitem()['x_0_0'][0][:,:3],atomic_numbers,top)

In [27]:
revsamples = DDPM.sample_reverse_marginal([x_0_0_.expand(64,22,3) for x_0_0_ in x_0_0], 
                                          [v_0_0_[:64] for v_0_0_ in v_0_0], 
                                          [z_0_0_[:64] for z_0_0_ in z_0_0], 0, 
                                          198*torch.ones(64,device=device))