In [1]:
from mint.state import MINTState
from mint.data.ADP.ADP_dataset import ADPDataset
from mint.module import MINTModule
from mint.experiment.train import Train
from mint.experiment.generate import Generate

from omegaconf import OmegaConf
import logging
from pytorch_lightning.utilities.rank_zero import rank_zero_only
import torch
from torch_geometric.loader import DataLoader
from torch.utils.data import Subset



In [2]:
# Root directory of the alanine-dipeptide (ADP) data.
# NOTE(review): hardcoded, user-specific absolute path — consider making it relative to the
# repository or configurable so the notebook runs on other machines.
ADP_DATA_DIR = '/users/1/sull1276/mint/tests/../mint/data/ADP'


def make_adp_dataset(split, data_dir=ADP_DATA_DIR):
    """Build an ``ADPDataset`` for one split using the settings shared by every split.

    Parameters
    ----------
    split : str
        Which split to build; this notebook uses ``"train"``, ``"test"`` and ``"valid"``.
    data_dir : str
        Directory holding the raw ``.xtc`` trajectory and the processed
        ``AA_<split>.pkl.zst`` cache (created on first use if missing).

    Returns
    -------
    ADPDataset
        The dataset for ``split``.
    """
    return ADPDataset(
        data_dir=data_dir,
        data_proc_fname="AA",
        data_proc_ext=".pkl.zst",
        data_raw_fname="alanine-dipeptide-250ns-nowater",
        data_raw_ext=".xtc",
        split=split,
        total_frames_train=25000,
        total_frames_test=5000,
        total_frames_valid=5000,
        # Fresh OmegaConf nodes per call so the splits never share mutable config objects.
        lag=OmegaConf.create({"equilibrium": True}),
        normalize=OmegaConf.create({"bool": False, "t_dependent": False}),
        node_features=OmegaConf.create({"epsilon": True, "sigma": True, "charge": True, "mass": True}),
        augement_rotations=False,  # (sic) keyword spelled as ADPDataset expects it
    )


ds_train = make_adp_dataset("train")
ds_test = make_adp_dataset("test")
ds_valid = make_adp_dataset("valid")

# Component configs for MINTModule, assembled below into one OmegaConf node.
# Section and key order are kept identical to the original inline config.
_prior_cfg = {
    "_target_": "mint.prior.normal.NormalPrior",
    "mean": 0.0,
    "std": 0.25,
}

_embedder_cfg = {
    "_target_": "mint.model.embedding.equilibrium_embedder.EquilibriumEmbedder",
    "use_ff": True,
    "interp_time": {"embedding_dim": 64, "max_positions": 1000},
    # NOTE(review): in_dim=4 presumably matches the four node features enabled on the
    # datasets (epsilon, sigma, charge, mass) — keep the two in sync.
    "force_field": {
        "in_dim": 4,
        "hidden_dims": [128, 64],
        "out_dim": 32,
        "activation": "relu",
        "use_input_bn": False,
        "affine": False,
        "track_running_stats": True,
    },
    "atom_type": {"num_types": 14, "embedding_dim": 32},
}

_model_cfg = {
    "_target_": "mint.model.equivariant.transformer.MultiSE3Transformer",
    # 128 input scalars; presumably the concatenated embedder outputs
    # (64 time + 32 force field + 32 atom type) — confirm in EquilibriumEmbedder.
    "input_channels": [[128], [0]],
    "readout_channels": [[0, 0], [0, 1]],
    "hidden_channels": [[8, 8], [8, 8]],
    "key_channels": [[8, 8], [8, 8]],
    "query_channels": [[8, 8], [8, 8]],
    "edge_l_max": 2,
    "edge_basis": "smooth_finite",
    "max_radius": 10,
    "number_of_basis": 64,
    "hidden_size": 128,
    "max_neighbors": 10000,
    "act": "silu",
    "num_layers": 4,
    "bn": False,
}

_interpolant_cfg = {
    "_target_": "mint.interpolant.interpolants.TemporallyLinearInterpolant",
    "velocity_weight": 1.0,
    "denoiser_weight": 1.0,
    "gamma_weight": 0.1,
}

_optim_cfg = {
    "optimizer": {"name": "Adam", "lr": 3e-4, "weight_decay": 0.01, "betas": [0.9, 0.999]},
    "scheduler": {
        "name": "CosineAnnealingLR",
        # NOTE(review): this is a literal string, not an OmegaConf interpolation ("${...}"),
        # so it will not become an integer epoch count if the scheduler is ever built —
        # confirm before using this config for training.
        "T_max": "experiment.train.trainer.max_epochs",
        "eta_min": 1e-6,
    },
}

module = MINTModule(
    cfg=OmegaConf.create(
        {
            "prior": _prior_cfg,
            "embedder": _embedder_cfg,
            "model": _model_cfg,
            "interpolant": _interpolant_cfg,
            "validation": {"stratified": False},
            "optim": _optim_cfg,
        }
    )
)

# Trained checkpoint whose "state_dict" is loaded into `module`.
CKPT_PATH = "logs/hydra/ckpt/epoch_197-step_15444-loss_-83765.9844.ckpt"
# Fall back to CPU so this cell also runs on machines without a GPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# weights_only=False is passed explicitly. It keeps the pre-2.6 torch default and avoids the
# torch.load FutureWarning about that default. It unpickles arbitrary objects, so only use
# it on checkpoints you trust (e.g. ones you produced locally).
ckpt = torch.load(CKPT_PATH, map_location=DEVICE, weights_only=False)
module.load_state_dict(ckpt["state_dict"])

# Bundle the loaded module and the three dataset splits into the state object that
# the Generate experiment consumes; seed=42 is handed to MINTState.
st = MINTState(
    module=module,
    dataset_train=ds_train,
    dataset_valid=ds_valid,
    dataset_test=ds_test,
    seed=42,
)

print(module)

INFO:: No processed data found at /users/1/sull1276/mint/tests/../mint/data/ADP/AA_train.pkl.zst... preprocessing data
INFO:: No processed data found at /users/1/sull1276/mint/tests/../mint/data/ADP/AA_test.pkl.zst... preprocessing data
INFO:: No processed data found at /users/1/sull1276/mint/tests/../mint/data/ADP/AA_valid.pkl.zst... preprocessing data


  ckpt = torch.load("logs/hydra/ckpt/epoch_197-step_15444-loss_-83765.9844.ckpt", map_location="cuda")


MINTModule(
  (embedder): EquilibriumEmbedder(
    (interpolant_time_embedder): TimeEmbed()
    (ff_embedder): MLPWithBN(
      (net): Sequential(
        (0): Linear(in_features=4, out_features=128, bias=True)
        (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Linear(in_features=128, out_features=64, bias=True)
        (4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
        (5): ReLU(inplace=True)
        (6): Linear(in_features=64, out_features=32, bias=True)
      )
    )
    (atom_type_embed): Embedding(14, 32)
  )
  (model): MultiSE3Transformer(
    (lin_in): Linear(128x0e+0x0o -> 8x0e+8x1e+8x0o+8x1o | 1024 weights)
    (eg3nn_layers): ModuleList(
      (0-3): 4 x SE3Transformer(
        (act): SiLU()
        (tp_k): FullyConnectedTensorProduct(8x0e+8x1e+8x0o+8x1o x 1x0e+1x1o+1x2e -> 8x0e+8x1e+8x0o+8x1o | 768 paths | 768 weights)
        (fc_k): FullyConne

In [3]:
# How many test frames to generate from, and how many per batch.
N_GEN_FRAMES = 320
GEN_BATCH_SIZE = 64

# Clamp so a smaller test split cannot yield out-of-range indices inside the Subset.
subset = Subset(ds_test, range(min(N_GEN_FRAMES, len(ds_test))))

test_loader = DataLoader(
    subset,
    batch_size=GEN_BATCH_SIZE,
    shuffle=False,
)


def epsilon_fn(t):
    """Identity schedule passed to ``Generate`` as ``epsilon``: returns ``t`` unchanged.

    NOTE(review): with ``step_type="ode"`` this may be unused; presumably only the SDE
    sampler reads it — confirm in ``Generate``.
    """
    return t


generate_cfg = OmegaConf.create(
    {
        # Presumably integrates t over [0, 1]; the progress bars below show 100 steps.
        "dt": 1e-2,
        "step_type": "ode",  # or "sde"
        "clip_val": 1e-3,
        "save_traj": False,
    }
)

gen_experiment = Generate(state=st, cfg=generate_cfg, batches=test_loader, epsilon=epsilon_fn)

In [None]:
# Sampling needs no gradients; disabling autograd for the run avoids building graphs.
with torch.set_grad_enabled(False):
    samples = gen_experiment.run()

Generating samples:   0%|                                                                                     …

Integrating over time:   0%|          | 0/100 [00:00<?, ?it/s]

Integrating over time:   0%|          | 0/100 [00:00<?, ?it/s]

Integrating over time:   0%|          | 0/100 [00:00<?, ?it/s]

Integrating over time:   0%|          | 0/100 [00:00<?, ?it/s]

Integrating over time:   0%|          | 0/100 [00:00<?, ?it/s]