In [1]:
# Step 1: Run the project's download/preprocessing script for entry 4o66_C,
# passing ./data as both --data_dir and --out_dir.
!python scripts/download_and_prep.py 4o66_C --data_dir ./data --out_dir ./data

Data for 4o66_C seems to exist at ./data/4o66_C. Skipping download.
Preprocessing output ./data/4o66_C/4o66_C_R1.npy already exists. Skipping.
Preprocessing output ./data/4o66_C/4o66_C_R2.npy already exists. Skipping.
Preprocessing output ./data/4o66_C/4o66_C_R3.npy already exists. Skipping.


In [2]:
import sys
from gen_model.parsing import parse_train_args

# Step 2: Set up arguments/config
# parse_train_args() parses the process command line, which inside a notebook
# belongs to the kernel launcher. Temporarily substitute a minimal argv so the
# parser succeeds, then restore the real one so the fake command line does not
# leak into later cells or other code that inspects sys.argv.
# '--data_dir' is passed only to satisfy the parser's 'required=True' check.
_kernel_argv = sys.argv
sys.argv = ['ipykernel_launcher.py', '--data_dir', './data']
try:
    args = parse_train_args()
finally:
    # Always put the kernel's own argv back, even if parsing fails.
    sys.argv = _kernel_argv

# Now override with your actual desired notebook settings
args.atlas = True
args.data_dir = "./data"
args.pep_name = "4o66_C"
args.train_split = "gen_model/splits/frame_splits.csv"
args.atlas_csv = "gen_model/splits/atlas.csv"
args.batch_size = 8
args.num_workers = 0  # 0 avoids multiprocessing issues when debugging in a notebook

print(f"Successfully initialized args. Data dir: {args.data_dir}")

Successfully initialized args. Data dir: ./data


In [11]:
import torch
from torch.utils.data import DataLoader
from gen_model.dataset import MDGenDataset

# 1. Initialize the dataset
# Uses the args object configured in the previous cell; the split file is args.train_split
trainset = MDGenDataset(args, split=args.train_split)

# 2. Setup the DataLoader
# batch_size and num_workers come from args; num_workers=0 is often safer for
# debugging inside a notebook to avoid multiprocessing issues
train_loader = DataLoader(
    trainset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.num_workers,
)


# 3. Fetch one batch to verify everything is working
# Only the fetch itself is guarded. A failure is reported and then re-raised so
# the cell stops with a full traceback instead of silently continuing with
# `batch` undefined (or left over from an earlier run).
try:
    batch = next(iter(train_loader))
except Exception as e:
    print(f"Error loading batch: {e}")
    raise

print("Successfully loaded a batch!")

# Print keys to see what data we have (e.g., 'seqres', 'mask', 'clean_trans')
print(f"Batch keys: {batch.keys()}")

# Check the shape of the coordinates if a 'pos' entry is available
if 'pos' in batch:
    print(f"Coordinates shape: {batch['pos'].shape}")

Successfully loaded a batch!
Batch keys: dict_keys(['name', 'frame_indices', 'seqres', 'mask', 'torsion_mask', 'clean_trans', 'clean_rots', 'clean_torsions', 'clean_atom37'])


In [16]:
# Draw a fresh batch from the loader and report the shape of each entry of interest,
# in the order: sequence, torsions, rotations, translations.
batch = next(iter(train_loader))
for field in ('seqres', 'clean_torsions', 'clean_rots', 'clean_trans'):
    print(batch[field].shape)

torch.Size([8, 76])
torch.Size([8, 1, 76, 7, 2])
torch.Size([8, 1, 76, 3, 3])
torch.Size([8, 1, 76, 3])
