# 🧬 Geometry-Complete Equivariant Diffusion
## De Novo Drug Design - **P100 Optimized**

**GPU**: P100 (16GB) | **Batch**: 16 | **Model**: 256 hidden, 6 layers

## Cell 1: Setup

In [None]:
from google.colab import drive
drive.mount('/content/drive')

import torch
gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'
print(f'GPU: {gpu_name}')
if 'P100' not in gpu_name:
    print('‚ö†Ô∏è WARNING: Not P100! Consider using p100.yaml settings')

%pip install -q torch-geometric rdkit scipy numpy pyyaml tqdm scikit-learn
print('‚úÖ Setup complete')

## Cell 2: Clone Repo

In [None]:
import os
REPO = '/content/drive/MyDrive/geom_diffusion'
if not os.path.exists(REPO):
    !git clone https://github.com/Nethrananda21/geom_diffusion.git {REPO}
%cd {REPO}
!git pull origin master

## Cell 3: Extract Data

In [None]:
import os
import subprocess

DATA_DIR = '/content/data/raw'
BACKUP_TAR = '/content/drive/MyDrive/crossdocked_essential.tar.gz'
os.makedirs(DATA_DIR, exist_ok=True)

# Subdirectories already present in DATA_DIR (the preprocessing cell treats
# each one as a pocket directory).
existing = [d for d in os.listdir(DATA_DIR) if os.path.isdir(f'{DATA_DIR}/{d}')]

# Heuristic: fewer than 10 folders means the archive has not been extracted yet.
if len(existing) < 10:
    if os.path.exists(BACKUP_TAR):
        print('📦 Extracting from Drive backup...')
        # subprocess (instead of `!tar`) so a failed extraction is reported
        # rather than being followed by an unconditional "Done!".
        result = subprocess.run(['tar', '-xzf', BACKUP_TAR, '-C', DATA_DIR],
                                capture_output=True, text=True)
        if result.returncode == 0:
            print('✅ Done!')
        else:
            print(f'❌ Extraction failed (exit code {result.returncode}): '
                  f'{result.stderr.strip()}')
    else:
        print(f'❌ No backup found at {BACKUP_TAR}.')
else:
    print(f'✅ Data exists: {len(existing)} folders')

# Count directories only; plain files in DATA_DIR are not pocket folders.
n_folders = sum(os.path.isdir(f'{DATA_DIR}/{d}') for d in os.listdir(DATA_DIR))
print(f'📁 Total: {n_folders} folders')

## Cell 4: Preprocess (P100: pockets up to 2000 atoms)

In [None]:
import os
import pickle
import numpy as np
from pathlib import Path
from collections import defaultdict
from tqdm import tqdm
from rdkit import Chem
import warnings
warnings.filterwarnings('ignore')

np.random.seed(42)  # makes the pocket shuffles / train-val split below repeatable

# === P100 OPTIMIZED CONFIG ===
MAX_POCKET_ATOMS = 2000   # 2x larger than T4
MAX_LIGAND_ATOMS = 60     # Larger ligands
MIN_LIGAND_ATOMS = 5
MIN_LIGANDS_PER_POCKET = 3
N_POCKETS_PER_BIN = 60    # More pockets per bin

# Element vocabulary for the one-hot atom features; anything else -> 'Other'.
ATOM_TYPES = ['C', 'N', 'O', 'S', 'P', 'F', 'Cl', 'Br', 'I', 'Other']
ATOM_TO_IDX = {a: i for i, a in enumerate(ATOM_TYPES)}

def parse_pdb(pdb_path):
    """Read protein atoms (``ATOM`` records) from a PDB file.

    Args:
        pdb_path: Path (str or ``pathlib.Path``) of the PDB file.

    Returns:
        ``(coords, types)``: float32 arrays of shape ``(N, 3)`` (x, y, z) and
        ``(N, len(ATOM_TYPES))`` (one-hot element type), one row per ``ATOM``
        record. ``HETATM`` and all other records are ignored. A file without
        ``ATOM`` records yields arrays of shape ``(0, 3)`` and ``(0, 10)``.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if a coordinate field is not a valid number.
    """
    other_idx = ATOM_TO_IDX['Other']
    coords = []
    types = []
    with open(pdb_path, 'r') as f:
        for line in f:
            if not line.startswith('ATOM'):
                continue
            coords.append([float(line[30:38]), float(line[38:46]), float(line[46:54])])

            # Element symbol lives in columns 77-78. It can be absent (short
            # line) or blank; then fall back to the atom name's first character.
            elem = line[76:78].strip() if len(line) > 77 else ''
            if not elem:
                elem = line[12:16].strip()[:1]
            # PDB writes element symbols upper-case ("CL", "BR") while
            # ATOM_TYPES uses "Cl"/"Br"; normalise so halogens are not
            # mis-typed as 'Other'.
            elem = elem.capitalize()

            one_hot = [0.0] * len(ATOM_TYPES)
            one_hot[ATOM_TO_IDX.get(elem, other_idx)] = 1.0
            types.append(one_hot)
    # reshape keeps the 2-D shape even when no ATOM records were found.
    coords_arr = np.array(coords, dtype=np.float32).reshape(-1, 3)
    types_arr = np.array(types, dtype=np.float32).reshape(-1, len(ATOM_TYPES))
    return coords_arr, types_arr

def parse_sdf(sdf_path):
    """Read the first molecule of an SDF file as heavy-atom coordinates/types.

    Args:
        sdf_path: Path (str or ``pathlib.Path``) of the SDF file.

    Returns:
        ``(coords, types)``: float32 arrays of shape ``(N, 3)`` and
        ``(N, len(ATOM_TYPES))`` (one-hot element type) over the non-hydrogen
        atoms of the first molecule, or ``(None, None)`` when the file holds no
        molecule, RDKit cannot parse the first record, or it has no conformer.
    """
    # sanitize=False: molecules that would fail RDKit sanitization are still
    # returned instead of None.
    supplier = Chem.SDMolSupplier(str(sdf_path), removeHs=True, sanitize=False)
    mol = next(iter(supplier), None)  # first record; None for an empty file
    if mol is None or mol.GetNumConformers() == 0:
        return None, None
    conf = mol.GetConformer()

    other_idx = ATOM_TO_IDX['Other']
    coords = []
    types = []
    for atom in mol.GetAtoms():
        # RDKit only applies removeHs when it sanitizes, so with sanitize=False
        # explicit hydrogens can survive loading; drop them here so ligands are
        # heavy-atom only, as removeHs=True intends.
        if atom.GetAtomicNum() == 1:
            continue
        pos = conf.GetAtomPosition(atom.GetIdx())
        coords.append([pos.x, pos.y, pos.z])
        one_hot = [0.0] * len(ATOM_TYPES)
        one_hot[ATOM_TO_IDX.get(atom.GetSymbol(), other_idx)] = 1.0
        types.append(one_hot)
    # reshape keeps the 2-D shape even if no heavy atoms remain.
    coords_arr = np.array(coords, dtype=np.float32).reshape(-1, 3)
    types_arr = np.array(types, dtype=np.float32).reshape(-1, len(ATOM_TYPES))
    return coords_arr, types_arr

print(f'🔬 P100 Preprocessing (max atoms: {MAX_POCKET_ATOMS})...')

DATA_DIR = '/content/data/raw'
OUT_DIR = '/content/data/crossdocked'
os.makedirs(OUT_DIR, exist_ok=True)

# Sort every directory listing: os.listdir()/glob() order depends on the
# filesystem, and without a fixed order the seeded shuffles below would not
# reproduce the same train/val split from run to run.
pocket_dirs = sorted(p for p in Path(DATA_DIR).iterdir() if p.is_dir())
print(f'Found {len(pocket_dirs)} pocket directories')

pockets = defaultdict(list)  # pocket_id -> list of ligand sample dicts
pocket_info = {}             # pocket_id -> {'size': receptor atom count}

for pocket_dir in tqdm(pocket_dirs, desc='Processing'):
    pocket_id = pocket_dir.name
    rec_pdbs = sorted(pocket_dir.glob('*_rec.pdb'))
    all_sdf = sorted(pocket_dir.glob('*.sdf'))

    if not rec_pdbs or not all_sdf:
        continue

    # `except Exception` rather than a bare except: unparsable files are still
    # skipped, but interrupting the cell (KeyboardInterrupt) is not swallowed.
    try:
        pocket_coords, pocket_types = parse_pdb(rec_pdbs[0])
    except Exception:
        continue
    if len(pocket_coords) == 0 or len(pocket_coords) > MAX_POCKET_ATOMS:
        continue

    pocket_info[pocket_id] = {'size': len(pocket_coords)}

    # At most 50 ligand files are parsed per pocket.
    for sdf_path in all_sdf[:50]:
        try:
            lig_coords, lig_types = parse_sdf(sdf_path)
        except Exception:
            continue
        if lig_coords is None:
            continue
        if not MIN_LIGAND_ATOMS <= len(lig_coords) <= MAX_LIGAND_ATOMS:
            continue

        pockets[pocket_id].append({
            'pocket_id': pocket_id,
            'ligand_id': sdf_path.stem,
            'ligand_coords': lig_coords,
            'ligand_types': lig_types,
            'pocket_coords': pocket_coords,
            'pocket_types': pocket_types,
        })

print(f'\n✅ Processed {len(pockets)} valid pockets (≤{MAX_POCKET_ATOMS} atoms)')

valid = [p for p, samples in pockets.items() if len(samples) >= MIN_LIGANDS_PER_POCKET]
print(f'After filter (lig>={MIN_LIGANDS_PER_POCKET}): {len(valid)} pockets')

if not valid:
    print('❌ No valid pockets!')
else:
    sizes = [pocket_info[p]['size'] for p in valid]
    p33, p66 = np.percentile(sizes, [33, 66])

    # Stratify by receptor size (approximate terciles) and take the same
    # number of pockets from each bin.
    small = [p for p in valid if pocket_info[p]['size'] <= p33]
    medium = [p for p in valid if p33 < pocket_info[p]['size'] <= p66]
    large = [p for p in valid if pocket_info[p]['size'] > p66]

    print(f'Bins: Small={len(small)}, Medium={len(medium)}, Large={len(large)}')

    np.random.shuffle(small)
    np.random.shuffle(medium)
    np.random.shuffle(large)

    n_per = min(N_POCKETS_PER_BIN, len(small), len(medium), len(large))
    if n_per == 0:
        # With an empty bin nothing is selected; saving would overwrite any
        # previous pickles with empty datasets.
        print('❌ At least one size bin is empty; not saving an empty dataset.')
    else:
        selected = small[:n_per] + medium[:n_per] + large[:n_per]
        print(f'Selected: {len(selected)} pockets')

        # Split by pocket (85/15) so no pocket appears in both train and val.
        np.random.shuffle(selected)
        split_idx = int(len(selected) * 0.85)
        train_pockets = selected[:split_idx]
        val_pockets = selected[split_idx:]

        print(f'Train: {len(train_pockets)}, Val: {len(val_pockets)}')

        train_samples = [s for p in train_pockets for s in pockets[p]]
        val_samples = [s for p in val_pockets for s in pockets[p]]

        print(f'\n📊 Train: {len(train_samples)}, Val: {len(val_samples)}')

        with open(f'{OUT_DIR}/train_data.pkl', 'wb') as f:
            pickle.dump(train_samples, f)
        with open(f'{OUT_DIR}/val_data.pkl', 'wb') as f:
            pickle.dump(val_samples, f)

        print('💾 Saved!')

## Cell 5: Train P100 üöÄ

In [None]:
import shutil
from pathlib import Path

%cd /content/drive/MyDrive/geom_diffusion

# Clear cache
for cache in ['/content/data/cache', 'data/cache']:
    if Path(cache).exists():
        shutil.rmtree(cache)
        print(f'üóëÔ∏è Deleted {cache}')

# P100 config
!python train.py --config configs/p100.yaml --checkpoint_dir checkpoints_p100

## Cell 6: Monitor GPU Usage

In [None]:
# Run in separate cell to monitor
!nvidia-smi

## Cell 7: Resume Training

In [None]:
# Uncomment to resume. Pass the same --checkpoint_dir as Cell 5 so resumed runs
# keep writing into checkpoints_p100/ next to the checkpoint being resumed:
# %cd /content/drive/MyDrive/geom_diffusion
# !python train.py --config configs/p100.yaml --resume checkpoints_p100/best_model.pt --checkpoint_dir checkpoints_p100