In [None]:
!python scripts/download_and_prep.py 4o66_C --data_dir ./data --out_dir ./data

Data for 4o66_C seems to exist at ./data/4o66_C. Skipping download.
Preprocessing output ./data/4o66_C/4o66_C_R1.npy already exists. Skipping.
Preprocessing output ./data/4o66_C/4o66_C_R2.npy already exists. Skipping.
Preprocessing output ./data/4o66_C/4o66_C_R3.npy already exists. Skipping.
Creating 4-way split (Early: 5.0ns, Ratios: [0.6, 0.2, 0.2])...
Found 3 trajectory file(s) for 4o66_C.
  4o66_C_R1: Total 10001, Timestep 10.0ps
    Early: [0:500], Train: [500:6200], Val: [6200:8100], Test: [8100:10001]
  4o66_C_R2: Total 10001, Timestep 10.0ps
    Early: [0:500], Train: [500:6200], Val: [6200:8100], Test: [8100:10001]
  4o66_C_R3: Total 10001, Timestep 10.0ps
    Early: [0:500], Train: [500:6200], Val: [6200:8100], Test: [8100:10001]
Updated splits saved to gen_model/splits/frame_splits.csv


In [2]:
import sys
from gen_model.parsing import parse_train_args

# Step 2: Set up arguments/config
# We 'fake' the command line arguments so parse_train_args() doesn't crash
# We provide a placeholder for --data_dir just to satisfy the 'required=True' check
sys.argv = ['ipykernel_launcher.py', '--data_dir', './data']

args = parse_train_args()

# Now override with your actual desired notebook settings
args.atlas = True
args.data_dir = "./data"
args.pep_name = "4o66_C"
args.train_split = "gen_model/splits/frame_splits.csv"
args.atlas_csv = "gen_model/splits/atlas.csv"
args.batch_size = 8
args.num_workers = 0

print(f"Successfully initialized args. Data dir: {args.data_dir}")

Successfully initialized args. Data dir: ./data


In [None]:
import torch
from torch.utils.data import DataLoader
from gen_model.dataset import MDGenDataset

# 1. Initialize the dataset
# This will use the args.data_dir and args.train_split you defined in the previous step
trainset = MDGenDataset(args, mode='train')

# 2. Setup the DataLoader
# num_workers=0 is often safer for debugging inside a notebook to avoid multiprocessing issues
train_loader = DataLoader(
    trainset, 
    batch_size=args.batch_size, 
    shuffle=True, 
    num_workers=args.num_workers,
)

# Validation set
val_dataset = MDGenDataset(args, mode='val')

val_loader = DataLoader(
    val_dataset, 
    batch_size=args.batch_size, 
    shuffle=False, 
    num_workers=args.num_workers,
)

# Test set
test_dataset = MDGenDataset(args, split_csv, mode='test')

test_loader = DataLoader(
    test_dataset, 
    batch_size=args.batch_size, 
    shuffle=False, 
    num_workers=args.num_workers,
)

# 3. Fetch one batch to verify everything is working
try:
    batch = next(iter(train_loader))
    print("Successfully loaded a train batch!")
    
    # Print keys to see what data we have (e.g., 'pos', 'seq', 'mask')
    print(f"Batch keys: {batch.keys()}")
    
    # Check the shape of the coordinates if available
    if 'pos' in batch:
        print(f"Coordinates shape: {batch['pos'].shape}")
    
    batch = next(iter(val_loader))
    print("Successfully loaded a val batch!")
    
    # Check the shape of the coordinates if available
    if 'pos' in batch:
        print(f"Coordinates shape: {batch['pos'].shape}")   

    batch = next(iter(test_loader))
    print("Successfully loaded a test batch!")
    
    # Check the shape of the coordinates if available
    if 'pos' in batch:
        print(f"Coordinates shape: {batch['pos'].shape}")   
        
except Exception as e:
    print(f"Error loading batch: {e}")



ValueError: num_samples should be a positive integer value, but got num_samples=0

: 

In [None]:
# These are the training inputs to be used in the model
batch = next(iter(train_loader))
print(batch['seqres'].shape)
print(batch['clean_torsions'].shape)
print(batch['clean_rots'].shape)
print(batch['clean_trans'].shape)

torch.Size([8, 76])
torch.Size([8, 1, 76, 7, 2])
torch.Size([8, 1, 76, 3, 3])
torch.Size([8, 1, 76, 3])


In [8]:
# These are some additional tensors to multiply in the objective function
print(batch['mask'].shape)
print(batch['torsion_mask'].shape)

torch.Size([8, 76])
torch.Size([8, 76, 7])


In [9]:
# This is the "ground truth"
batch['clean_atom37'].shape

torch.Size([8, 1, 76, 37, 3])

In [12]:
# This is some metadata about the batch. The frame indices might be used in the model
print(batch['name'])
print(batch['frame_indices'])

['4o66_C_R3', '4o66_C_R1', '4o66_C_R1', '4o66_C_R1', '4o66_C_R3', '4o66_C_R1', '4o66_C_R1', '4o66_C_R2']
tensor([[6434],
        [ 260],
        [9198],
        [7580],
        [7887],
        [9299],
        [9764],
        [ 619]])
