# 🧬 Geometry-Complete Equivariant Diffusion
## De Novo Drug Design Training

**Cluster-Based Split**: 100 training pockets × up to 50 ligands each = at most 5,000 training pairs

**No Data Leakage**: The same pocket never appears in both train and val

## Cell 1: Setup

In [None]:
from google.colab import drive
drive.mount('/content/drive')

import torch
print(f'GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else "None!"}')

!pip install -q torch-geometric rdkit scipy numpy pyyaml tqdm gdown
print('✅ Setup complete')

## Cell 2: Clone Repository

In [None]:
import os
REPO = '/content/drive/MyDrive/geom_diffusion'
if not os.path.exists(REPO):
    !git clone https://github.com/Nethrananda21/geom_diffusion.git {REPO}
%cd {REPO}
!git pull origin master

## Cell 3: Download Pre-processed Data

Downloads DiffSBDD's pre-processed CrossDocked data (~500 MB) from Google Drive.

In [None]:
import os
import gdown

DATA_DIR = '/content/data/crossdocked'
os.makedirs(DATA_DIR, exist_ok=True)

# DiffSBDD pre-processed files
files = {
    'train_data.pkl': '1vJyxCIqCYwP3qj4THMofdSd1rZDEQpPG',
    'val_data.pkl': '1FpVNcdj0R5YOsaLQm6T4D5QOKZGI4Xc5'
}

for fname, fid in files.items():
    path = f'{DATA_DIR}/{fname}'
    if not os.path.exists(path):
        print(f'üì• Downloading {fname}...')
        gdown.download(id=fid, output=path, quiet=False)
    else:
        print(f'‚úÖ {fname} exists')

!ls -la {DATA_DIR}
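
Cell 4 guesses which key identifies a pocket (`pocket_id`, then `receptor`, then a coordinate hash), so it can help to look at the keys the samples actually carry before grouping. The following is a minimal sketch, assuming the pickle holds a list of dicts; `summarize_keys` is a hypothetical helper, and the toy samples only stand in for the real data.

```python
from collections import Counter

def summarize_keys(samples, limit=1000):
    """Count how often each top-level key appears in the first `limit` dict samples."""
    counts = Counter()
    for sample in samples[:limit]:
        if isinstance(sample, dict):
            counts.update(sample.keys())
    return counts

# Toy stand-in for the loaded pickle; the real samples may use different keys.
toy = [
    {'pocket_id': 'p1', 'pocket_coords': [[0.0, 0.0, 0.0]]},
    {'receptor': 'r2', 'pocket_coords': [[1.0, 0.0, 0.0]]},
]
for key, count in summarize_keys(toy).items():
    print(f'{key}: {count}')
```

If neither `pocket_id` nor `receptor` shows up in the real data, every sample falls through to the coordinate-hash fallback, which is worth knowing before trusting the pocket count.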

## Cell 4: Create Cluster-Based 5K Subset

**Strategy**: 100 train pockets × up to 50 ligands = at most 5,000 train pairs, plus 20 val pockets × up to 50 ligands = at most 1,000 val pairs

**No Leakage**: Pockets are disjoint between train (100) and val (20)

In [None]:
import pickle
import random
from collections import defaultdict

random.seed(42)

# Load full data
with open(f'{DATA_DIR}/train_data.pkl', 'rb') as f:
    full_data = pickle.load(f)
print(f'Loaded {len(full_data)} samples')

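# Group by pocket
pockets = defaultdict(list)
for sample in full_data:
    # Prefer an explicit pocket ID, then the receptor name. Only if both are
    # missing, fall back to a hash of the first five pocket coordinates
    # (computed lazily, so samples without 'pocket_coords' don't raise).
    if 'pocket_id' in sample:
        pocket_id = sample['pocket_id']
    elif 'receptor' in sample:
        pocket_id = sample['receptor']
    else:
        pocket_id = str(hash(str(sample['pocket_coords'][:5])))
    pockets[pocket_id].append(sample)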

print(f'Found {len(pockets)} unique pockets')

# Select 120 pockets (100 train + 20 val)
pocket_ids = list(pockets.keys())
random.shuffle(pocket_ids)
selected_pockets = pocket_ids[:120]

# Split: 100 train, 20 val (NO OVERLAP!)
train_pockets = selected_pockets[:100]
val_pockets = selected_pockets[100:120]

print(f'Train pockets: {len(train_pockets)}')
print(f'Val pockets: {len(val_pockets)}')
print(f'Overlap: {len(set(train_pockets) & set(val_pockets))} (should be 0!)')

# Select up to 50 ligands per train pocket (slicing handles shorter lists)
train_samples = []
for pid in train_pockets:
    train_samples.extend(pockets[pid][:50])

# Select up to 50 ligands per val pocket
val_samples = []
for pid in val_pockets:
    val_samples.extend(pockets[pid][:50])

print(f'\n✅ Cluster-Based Split:')
print(f'   Train: {len(train_samples)} samples ({len(train_pockets)} pockets)')
print(f'   Val: {len(val_samples)} samples ({len(val_pockets)} pockets)')

# Save subset
with open(f'{DATA_DIR}/train_5k.pkl', 'wb') as f:
    pickle.dump(train_samples, f)
with open(f'{DATA_DIR}/val_1k.pkl', 'wb') as f:
    pickle.dump(val_samples, f)

print(f'\n💾 Saved: train_5k.pkl, val_1k.pkl')
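
The split logic in this cell can be checked on toy data before running it on the full pickle. The following is a sketch under stated assumptions, not the repository's code: `split_by_pocket` is a hypothetical helper, and it assumes every sample has a `pocket_id` key.

```python
import random
from collections import defaultdict

def split_by_pocket(samples, n_train, n_val, per_pocket, seed=42):
    """Split samples so that no pocket contributes to both train and val."""
    rng = random.Random(seed)
    groups = defaultdict(list)
    for sample in samples:
        groups[sample['pocket_id']].append(sample)

    pocket_ids = sorted(groups)  # fixed order so the seeded shuffle is reproducible
    rng.shuffle(pocket_ids)
    train_ids = pocket_ids[:n_train]
    val_ids = pocket_ids[n_train:n_train + n_val]

    # Keep at most `per_pocket` ligands from each selected pocket.
    train = [s for pid in train_ids for s in groups[pid][:per_pocket]]
    val = [s for pid in val_ids for s in groups[pid][:per_pocket]]
    return train, val

# Toy data: 6 pockets with 5 ligands each.
toy = [{'pocket_id': f'p{i % 6}', 'ligand': i} for i in range(30)]
train, val = split_by_pocket(toy, n_train=4, n_val=2, per_pocket=3)

train_pockets_seen = {s['pocket_id'] for s in train}
val_pockets_seen = {s['pocket_id'] for s in val}
print(len(train), len(val), train_pockets_seen & val_pockets_seen)
```

Because the two ID lists are disjoint slices of one shuffled list, the overlap set is empty by construction. The check is still useful as a guard if the grouping key changes later.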

## Cell 5: Update Config to Use Subset

In [None]:
import yaml

%cd /content/drive/MyDrive/geom_diffusion

with open('configs/debug_t4.yaml', 'r') as f:
    cfg = yaml.safe_load(f)

# Point to local data with our subset
cfg['data']['root'] = '/content/data'
cfg['data']['train_file'] = 'crossdocked/train_5k.pkl'
cfg['data']['val_file'] = 'crossdocked/val_1k.pkl'

# Training settings
cfg['training']['max_epochs'] = 50
cfg['hardware']['num_workers'] = 2

with open('configs/debug_t4.yaml', 'w') as f:
    yaml.dump(cfg, f)

print('✅ Config updated to use cluster-based 5K subset')
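
Before launching training, it can be worth confirming that the paths written into the config point at files that exist. The following is a minimal sketch, assuming the training script joins `data.root` with `train_file` and `val_file`. That convention is inferred from the values set above, not confirmed from `train.py`, and both helper names are hypothetical.

```python
from pathlib import Path

def resolve_data_files(cfg):
    """Join data.root with each split file (assumed convention, see lead-in)."""
    root = Path(cfg['data']['root'])
    return {key: root / cfg['data'][key] for key in ('train_file', 'val_file')}

def missing_files(paths):
    """Return the config keys whose resolved path does not exist on disk."""
    return [key for key, path in paths.items() if not path.exists()]

# Same values Cell 5 writes into the config.
cfg = {'data': {'root': '/content/data',
                'train_file': 'crossdocked/train_5k.pkl',
                'val_file': 'crossdocked/val_1k.pkl'}}
paths = resolve_data_files(cfg)
print(missing_files(paths) or 'all data files present')
```

A non-empty list here usually means Cell 4 has not been run in the current runtime.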

## Cell 6: Delete Cache

In [None]:
import shutil
from pathlib import Path

cache = Path('/content/data/cache')
if cache.exists():
    shutil.rmtree(cache)
    print('üóëÔ∏è Cache deleted')
else:
    print('‚ÑπÔ∏è No cache')

## Cell 7: Train 🚀

In [None]:
%cd /content/drive/MyDrive/geom_diffusion
!python train.py --config configs/debug_t4.yaml --checkpoint_dir checkpoints

## Cell 8: Resume (If Disconnected)

In [None]:
# Run Cells 1, 2, 5, 6 first (plus Cells 3 and 4 if the runtime was reset, since /content/data is not stored on Drive), then:
# %cd /content/drive/MyDrive/geom_diffusion
# !python train.py --config configs/debug_t4.yaml --resume checkpoints/best_model.pt
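
When resuming, it may help to confirm that a checkpoint exists on Drive before passing it to `--resume`. The following is a sketch only: `latest_checkpoint` is a hypothetical helper, and the `*.pt` pattern is an assumption based on the `best_model.pt` name above. The files `train.py` actually writes may differ.

```python
from pathlib import Path

def latest_checkpoint(checkpoint_dir, pattern='*.pt'):
    """Return the most recently modified file matching `pattern`, or None."""
    directory = Path(checkpoint_dir)
    if not directory.is_dir():
        return None
    candidates = sorted(directory.glob(pattern), key=lambda p: p.stat().st_mtime)
    return candidates[-1] if candidates else None

# Relative to the repository root, matching --checkpoint_dir in Cell 7.
ckpt = latest_checkpoint('checkpoints')
print(ckpt if ckpt else 'No checkpoint found; start a fresh run with Cell 7')
```

The most recently modified file is not necessarily the best-scoring one, so prefer `best_model.pt` when it is present.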