# 🧬 Geometry-Complete Equivariant Diffusion
## De Novo Drug Design Training

**Data**: downloaded directly from bits.csb.pitt.edu with `wget` (no `gdown` needed).

## Cell 1: Setup

In [None]:
from google.colab import drive
drive.mount('/content/drive')

import torch
print(f'GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else "None!"}')

%pip install -q torch-geometric rdkit scipy numpy pyyaml tqdm scikit-learn
print('‚úÖ Setup complete')

## Cell 2: Clone Repo

In [None]:
import os
REPO = '/content/drive/MyDrive/geom_diffusion'
if not os.path.exists(REPO):
    !git clone https://github.com/Nethrananda21/geom_diffusion.git {REPO}
%cd {REPO}
!git pull origin master

## Cell 3: Download Downsampled CrossDocked

From: https://bits.csb.pitt.edu/files/crossdock2020/

In [None]:
import os

DATA_DIR = '/content/data'
os.makedirs(DATA_DIR, exist_ok=True)

# Check what's already there
existing = [d for d in os.listdir(DATA_DIR) if os.path.isdir(f'{DATA_DIR}/{d}')]

if not existing:
    print('üì• Downloading downsampled CrossDocked2020 (~5GB)...')
    !wget -q --show-progress -O {DATA_DIR}/crossdocked.tgz \
        https://bits.csb.pitt.edu/files/crossdock2020/downsampled_CrossDocked2020_v1.3.tgz
    
    print('\nüì¶ Extracting...')
    !tar -xzf {DATA_DIR}/crossdocked.tgz -C {DATA_DIR}/
    !rm {DATA_DIR}/crossdocked.tgz
    print('‚úÖ Done!')
else:
    print('‚úÖ Data already exists')

# Find the actual extracted folder
!ls -la {DATA_DIR}/
folders = [d for d in os.listdir(DATA_DIR) if os.path.isdir(f'{DATA_DIR}/{d}')]
print(f'\nüìÅ Extracted folders: {folders}')

## Cell 4: Preprocess & Create 5K Subset

In [None]:
import os
import pickle
import numpy as np
from pathlib import Path
from collections import defaultdict
from tqdm import tqdm

# Seed NumPy so the shuffles below are repeatable. Directory listings and globs
# are sorted throughout this cell because their order is otherwise arbitrary,
# which would make the seeded selection differ between runs.
np.random.seed(42)

# Auto-detect data folder: the first (alphabetically) sub-directory of
# DATA_ROOT other than this cell's own output directory.
DATA_ROOT = '/content/data'
folders = sorted(
    d for d in os.listdir(DATA_ROOT)
    if os.path.isdir(f'{DATA_ROOT}/{d}') and d != 'crossdocked'
)
DATA_DIR = f'{DATA_ROOT}/{folders[0]}' if folders else DATA_ROOT
print(f'üìÅ Using data from: {DATA_DIR}')

OUT_DIR = '/content/data/crossdocked'
os.makedirs(OUT_DIR, exist_ok=True)

# Find all pocket directories
print('üîç Scanning pockets...')
pocket_dirs = sorted(d for d in Path(DATA_DIR).iterdir() if d.is_dir())
print(f'Found {len(pocket_dirs)} pocket directories')

# Group ligands by pocket
pockets = defaultdict(list)   # pocket_id -> list of sample reference dicts
pocket_info = {}              # pocket_id -> {'size', 'pdb', 'ligands'}
unreadable = 0                # pockets skipped because their PDB could not be read

# Only the first 500 directories are scanned.
for pocket_dir in tqdm(pocket_dirs[:500], desc='Scanning'):
    pocket_id = pocket_dir.name

    # Find pocket PDB files, preferring the most specific name pattern.
    pocket_pdb = (
        sorted(pocket_dir.glob('*_pocket*.pdb'))
        or sorted(pocket_dir.glob('*receptor*.pdb'))
        or sorted(pocket_dir.glob('*.pdb'))
    )
    ligand_sdf = sorted(pocket_dir.glob('*.sdf'))

    if not (pocket_pdb and ligand_sdf):
        continue

    pdb_path = str(pocket_pdb[0])

    # Count ATOM records in the pocket file. A file that cannot be read is
    # skipped rather than recorded with size 0, which would pass the size
    # filter below and be binned as a "small" pocket.
    try:
        with open(pdb_path, 'r') as f:
            atom_count = sum(1 for line in f if line.startswith('ATOM'))
    except (OSError, UnicodeDecodeError):
        unreadable += 1
        continue

    pocket_info[pocket_id] = {
        'size': atom_count,
        'pdb': pdb_path,
        'ligands': [str(l) for l in ligand_sdf]
    }

    # Store sample references (one per ligand file)
    for lig in ligand_sdf:
        pockets[pocket_id].append({
            'pocket_pdb': pdb_path,
            'ligand_sdf': str(lig),
            'pocket_id': pocket_id,
            'num_atoms': atom_count
        })

print(f'\nProcessed {len(pockets)} pockets with ligands')
if unreadable:
    print(f'Skipped {unreadable} pockets with unreadable PDB files')

# Filter by our criteria
valid = [p for p, samples in pockets.items()
         if pocket_info[p]['size'] <= 250 and len(samples) >= 10]
print(f'After filter (size<=250, lig>=10): {len(valid)} pockets')

# Stratify by size
small = [p for p in valid if pocket_info[p]['size'] <= 100]
medium = [p for p in valid if 100 < pocket_info[p]['size'] <= 175]
large = [p for p in valid if 175 < pocket_info[p]['size'] <= 250]

print(f'Bins: Small={len(small)}, Medium={len(medium)}, Large={len(large)}')

# Select from each bin (at most 40 pockets per bin)
np.random.shuffle(small)
np.random.shuffle(medium)
np.random.shuffle(large)

n_small = min(40, len(small))
n_medium = min(40, len(medium))
n_large = min(40, len(large))
selected = small[:n_small] + medium[:n_medium] + large[:n_large]

# Split train/val (83%/17%)
np.random.shuffle(selected)
split_idx = int(len(selected) * 0.83)
train_pockets = selected[:split_idx]
val_pockets = selected[split_idx:]

# Fail loudly instead of pickling an empty training set.
if not train_pockets:
    raise RuntimeError(
        f'No training pockets selected from {DATA_DIR} '
        f'({len(valid)} pockets passed the filter); check the detected data folder'
    )

print(f'\n‚úÖ Train: {len(train_pockets)} pockets, Val: {len(val_pockets)} pockets')

# Create datasets (up to 50 ligands per pocket)
train_samples = [s for p in train_pockets for s in pockets[p][:50]]
val_samples = [s for p in val_pockets for s in pockets[p][:50]]

print(f'üìä Train: {len(train_samples)}, Val: {len(val_samples)}')

# Save
with open(f'{OUT_DIR}/train_data.pkl', 'wb') as f:
    pickle.dump(train_samples, f)
with open(f'{OUT_DIR}/val_data.pkl', 'wb') as f:
    pickle.dump(val_samples, f)

print('üíæ Saved to /content/data/crossdocked/')

## Cell 5: Update Config

In [None]:
import yaml

%cd /content/drive/MyDrive/geom_diffusion

with open('configs/debug_t4.yaml', 'r') as f:
    cfg = yaml.safe_load(f)

cfg['data']['root'] = '/content/data'
cfg['data']['train_file'] = 'crossdocked/train_data.pkl'
cfg['data']['val_file'] = 'crossdocked/val_data.pkl'
cfg['training']['max_epochs'] = 50
cfg['training']['batch_size'] = 4
cfg['hardware']['num_workers'] = 2

with open('configs/debug_t4.yaml', 'w') as f:
    yaml.dump(cfg, f)

print('‚úÖ Config updated')

## Cell 6: Train 🚀

In [None]:
import shutil
from pathlib import Path

%cd /content/drive/MyDrive/geom_diffusion

for cache in ['/content/data/cache', 'data/cache']:
    if Path(cache).exists():
        shutil.rmtree(cache)

!python train.py --config configs/debug_t4.yaml --checkpoint_dir checkpoints