# Phase 5: Dark Halo Scope - Lens Finder Training

**GPU Required**: Runtime → Change runtime type → T4/A100 GPU

## Setup Instructions (Do Once)
1. Click the **key icon** 🔑 in the left sidebar (Secrets)
2. Add two secrets:
   - Name: `AWS_ACCESS_KEY_ID` → Value: Your AWS access key
   - Name: `AWS_SECRET_ACCESS_KEY` → Value: Your AWS secret key
3. Toggle "Notebook access" ON for both secrets
4. Run cells in order from top to bottom


In [None]:
#@title 1. Check GPU and Install Dependencies
import subprocess
result = subprocess.run(['nvidia-smi', '--query-gpu=name,memory.total', '--format=csv'], 
                       capture_output=True, text=True)
print("GPU Available:")
print(result.stdout)

# Install dependencies
%pip install -q boto3 s3fs pyarrow fsspec tqdm

print("\n‚úÖ Dependencies installed!")


In [None]:
#@title 2. Configure AWS Credentials
import os

# Credentials come from Colab Secrets (key icon in the left sidebar) and are
# exported as environment variables so boto3 picks them up automatically.
from google.colab import userdata

try:
    for secret_name in ('AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY'):
        os.environ[secret_name] = userdata.get(secret_name)
    os.environ['AWS_DEFAULT_REGION'] = 'us-east-2'
    print("‚úÖ AWS credentials loaded from Colab Secrets")
except Exception as err:
    print(f"‚ùå Error loading secrets: {err}")
    print("Please add AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to Colab Secrets (key icon)")

# Verify connection: a one-key listing exercises both credentials and bucket access.
import boto3
try:
    s3 = boto3.client('s3')
    response = s3.list_objects_v2(Bucket='darkhaloscope', Prefix='phase4_pipeline/', MaxKeys=1)
except Exception as err:
    print(f"‚ùå S3 connection failed: {err}")
else:
    print("‚úÖ S3 connection verified")


In [None]:
#@title 3. Download Training Data from S3
import boto3
import os
from tqdm import tqdm

# Configuration - CHANGE THIS FOR YOUR RUN
BUCKET = 'darkhaloscope'

# For SMOKE TEST (debug tier, ~6k rows):
S3_PREFIX = 'phase4_pipeline/phase4c/v3_color_relaxed/stamps/debug_stamp64_bandsgrz_gridgrid_small/'

# For FULL TRAINING (train tier, ~10M rows) - uncomment below:
# S3_PREFIX = 'phase4_pipeline/phase4c/v3_color_relaxed/stamps/train_stamp64_v2/'

LOCAL_DIR = '/content/data/stamps'
MAX_FILES = None  # None = download all files in the prefix

os.makedirs(LOCAL_DIR, exist_ok=True)

s3 = boto3.client('s3')

# List all parquet files under the prefix (the paginator handles >1000 keys).
print(f"Listing files from s3://{BUCKET}/{S3_PREFIX}...")
paginator = s3.get_paginator('list_objects_v2')
all_files = [
    obj
    for page in paginator.paginate(Bucket=BUCKET, Prefix=S3_PREFIX)
    for obj in page.get('Contents', [])
    if obj['Key'].endswith('.parquet')
]

print(f"Found {len(all_files)} parquet files")
if not all_files:
    # Fail here rather than with an IndexError in the validation cell.
    raise RuntimeError(f"No .parquet files under s3://{BUCKET}/{S3_PREFIX} - check S3_PREFIX")

# Download files (slicing with MAX_FILES=None keeps every file)
files_to_download = all_files[:MAX_FILES]
total_size = sum(f['Size'] for f in files_to_download) / (1024**3)
print(f"Downloading {len(files_to_download)} files ({total_size:.2f} GB)...")

n_downloaded = 0
n_skipped = 0
for obj in tqdm(files_to_download):
    key = obj['Key']
    # Flatten the key's path below the prefix into one filename so same-named
    # files in different sub-"directories" cannot overwrite each other.
    # For a flat prefix this is just the basename.
    filename = key[len(S3_PREFIX):].lstrip('/').replace('/', '__')
    local_path = os.path.join(LOCAL_DIR, filename)
    # Reuse a local copy only if it has the expected size; otherwise it is a
    # different or stale file and must be replaced.
    if os.path.exists(local_path) and os.path.getsize(local_path) == obj['Size']:
        n_skipped += 1
        continue
    s3.download_file(BUCKET, key, local_path)
    n_downloaded += 1

print(f"\n‚úÖ Downloaded {len(files_to_download)} files to {LOCAL_DIR}")
print(f"   ({n_downloaded} fetched from S3, {n_skipped} already present locally)")


In [None]:
#@title 4. Validate Data
import pyarrow.parquet as pq
import numpy as np
import io
import glob

parquet_files = sorted(glob.glob('/content/data/stamps/*.parquet'))
print(f"Found {len(parquet_files)} local parquet files")
if not parquet_files:
    raise FileNotFoundError("No parquet files found in /content/data/stamps - run the download cell first")

# Check first file
pf = pq.ParquetFile(parquet_files[0])
print(f"\nSchema columns ({len(pf.schema.names)}):")
print(pf.schema.names[:20], "..." if len(pf.schema.names) > 20 else "")
print(f"\nRows in first file: {pf.metadata.num_rows}")

# The training dataset reads exactly these columns, so fail here with a clear
# message rather than later inside a DataLoader worker.
missing_cols = {'stamp_npz', 'is_control'} - set(pf.schema_arrow.names)
if missing_cols:
    raise KeyError(f"First parquet file is missing required columns: {sorted(missing_cols)}")
if pf.metadata.num_rows == 0:
    raise ValueError(f"First parquet file has no rows: {parquet_files[0]}")

# Read one row and decode stamp_npz (only that column, not the whole row group)
table = pf.read_row_group(0, columns=['stamp_npz'])
stamp_npz_bytes = table['stamp_npz'][0].as_py()
with np.load(io.BytesIO(stamp_npz_bytes)) as npz:
    print(f"\nStamp NPZ contents: {list(npz.keys())}")
    for k in npz.keys():
        print(f"  {k}: shape={npz[k].shape}, dtype={npz[k].dtype}")
    # Bands stacked by the training dataset.
    missing_bands = {'image_g', 'image_r', 'image_z'} - set(npz.keys())
if missing_bands:
    raise KeyError(f"stamp_npz is missing band arrays: {sorted(missing_bands)}")

# Count total rows and class balance; only the small is_control column is read.
total_rows = 0
n_controls = 0
for f in parquet_files:
    is_control = pq.read_table(f, columns=['is_control'])['is_control'].to_numpy()
    total_rows += len(is_control)
    n_controls += int(np.count_nonzero(is_control == 1))
print(f"\n‚úÖ Total rows across all files: {total_rows:,}")
print(f"   Controls (is_control == 1, label 0): {n_controls:,} | "
      f"Injections (label 1): {total_rows - n_controls:,}")


In [None]:
#@title 5. Training Code (Self-Contained)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pyarrow.parquet as pq
import numpy as np
import io
import glob
import os
from tqdm import tqdm
from sklearn.metrics import roc_auc_score

# ============== Model ==============
class _BasicBlock(nn.Module):
    """ResNet basic block: conv-BN-ReLU-conv-BN, add the skip path, then ReLU.

    The skip path is the identity when input and output shapes match, and a
    strided 1x1 conv + BN projection otherwise.
    """

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)
        if stride != 1 or in_ch != out_ch:
            # Spatial size or channel count changes: project the input to match.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)), inplace=True)
        out = self.bn2(self.conv2(out))
        # The ReLU after the addition keeps a nonlinearity between consecutive blocks.
        return F.relu(out + self.shortcut(x), inplace=True)


class LensFinderCNN(nn.Module):
    """ResNet-18 style CNN for lens detection.

    Stem: 7x7/2 conv + BN + ReLU + 3x3/2 max-pool. Then four stages of two
    residual basic blocks (64, 128, 256, 512 channels; stages 2-4 downsample
    by 2), global average pooling and a linear head producing one logit.

    Args:
        in_channels: number of input image channels (3 for g, r, z).

    ``forward(x)`` takes an (N, in_channels, H, W) tensor and returns (N, 1)
    raw logits; apply a sigmoid (or use BCEWithLogitsLoss) downstream.
    """
    def __init__(self, in_channels=3):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, 7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.pool = nn.MaxPool2d(3, stride=2, padding=1)

        self.layer1 = self._make_layer(64, 64, 2)
        self.layer2 = self._make_layer(64, 128, 2, stride=2)
        self.layer3 = self._make_layer(128, 256, 2, stride=2)
        self.layer4 = self._make_layer(256, 512, 2, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512, 1)

    def _make_layer(self, in_ch, out_ch, blocks, stride=1):
        """Build one stage; only its first block changes stride/channels."""
        layers = [self._block(in_ch, out_ch, stride)]
        for _ in range(1, blocks):
            layers.append(self._block(out_ch, out_ch, 1))
        return nn.Sequential(*layers)

    def _block(self, in_ch, out_ch, stride):
        """Return one residual basic block."""
        return _BasicBlock(in_ch, out_ch, stride)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.pool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)

# ============== Dataset ==============
class StampDataset(Dataset):
    """Map-style dataset of (image, label) pairs read from stamp parquet files.

    Each row's ``stamp_npz`` blob is an NPZ archive holding ``image_g``,
    ``image_r`` and ``image_z``. These are stacked into a float32 tensor of
    shape (3, H, W), normalised per band with median / (1.4826 * MAD), clipped
    to [-10, 10], and optionally augmented with random flips and 90-degree
    rotations. The label is 0.0 for control rows (``is_control == 1``) and
    1.0 otherwise.

    Rows are addressed as (row group, offset), and data is read one parquet
    row group at a time rather than one whole file per sample. A small LRU
    cache of row groups is kept per process (i.e. per DataLoader worker).
    When every row group fits in the cache, each is read from disk at most
    once per process; with global shuffling over many row groups the hit rate
    falls accordingly.

    Args:
        parquet_files: paths of the parquet files to index.
        augment: apply random flips/rotations in ``__getitem__``.
        cache_size: number of row groups kept in memory per process.
    """

    _COLUMNS = ['stamp_npz', 'is_control']

    def __init__(self, parquet_files, augment=False, cache_size=16):
        self.files = parquet_files
        self.augment = augment
        self.cache_size = max(1, int(cache_size))
        self._cache = {}  # global row-group id -> table, least- to most-recently used
        self._build_index()

    def _build_index(self):
        """Record file, in-file index and first global row of every non-empty row group."""
        rg_file, rg_num, rg_start = [], [], []
        total = 0
        for file_id, fpath in enumerate(self.files):
            meta = pq.read_metadata(fpath)
            for rg in range(meta.num_row_groups):
                n = meta.row_group(rg).num_rows
                if n == 0:
                    continue
                rg_file.append(file_id)
                rg_num.append(rg)
                rg_start.append(total)
                total += n
        self._rg_file = np.asarray(rg_file, dtype=np.int64)
        self._rg_num = np.asarray(rg_num, dtype=np.int64)
        self._rg_start = np.asarray(rg_start, dtype=np.int64)
        self._num_rows = total

    def __len__(self):
        return self._num_rows

    def _locate(self, idx):
        """Map a (possibly negative) dataset index to (row-group id, row within it)."""
        idx = int(idx)
        if idx < 0:
            idx += self._num_rows
        if not 0 <= idx < self._num_rows:
            raise IndexError(f"index {idx} out of range for dataset of {self._num_rows} rows")
        g = int(np.searchsorted(self._rg_start, idx, side='right')) - 1
        return g, idx - int(self._rg_start[g])

    def _row_group(self, g):
        """Return the table for row group ``g``, reading and caching it if needed."""
        table = self._cache.pop(g, None)
        if table is None:
            fpath = self.files[int(self._rg_file[g])]
            table = pq.ParquetFile(fpath).read_row_group(int(self._rg_num[g]), columns=self._COLUMNS)
            # Evict least-recently-used entries (dicts preserve insertion order).
            while len(self._cache) >= self.cache_size:
                self._cache.pop(next(iter(self._cache)))
        self._cache[g] = table  # (re)insert as most recently used
        return table

    def __getitem__(self, idx):
        g, row_idx = self._locate(idx)
        table = self._row_group(g)

        # Decode stamp
        npz_bytes = table['stamp_npz'][row_idx].as_py()
        with np.load(io.BytesIO(npz_bytes)) as npz:
            img_g = npz['image_g'].astype(np.float32)
            img_r = npz['image_r'].astype(np.float32)
            img_z = npz['image_z'].astype(np.float32)

        # Stack and normalize: per band, subtract the median and divide by the
        # MAD scaled by 1.4826 (Gaussian-consistent sigma); NaNs become 0.
        img = np.stack([img_g, img_r, img_z], axis=0)
        for c in range(3):
            med = np.nanmedian(img[c])
            mad = np.nanmedian(np.abs(img[c] - med)) + 1e-8
            img[c] = (img[c] - med) / (mad * 1.4826)
        img = np.nan_to_num(img, nan=0.0)
        img = np.clip(img, -10, 10)

        # Augmentation: random vertical/horizontal flips plus a random 90-degree rotation
        if self.augment:
            if np.random.rand() > 0.5:
                img = img[:, ::-1, :].copy()
            if np.random.rand() > 0.5:
                img = img[:, :, ::-1].copy()
            k = np.random.randint(4)
            img = np.rot90(img, k, axes=(1, 2)).copy()

        # Label: is_control=1 means no lens (negative), is_control=0 means injection (positive)
        is_control = table['is_control'][row_idx].as_py()
        label = 0.0 if is_control == 1 else 1.0

        return torch.from_numpy(img), torch.tensor(label, dtype=torch.float32)

print("‚úÖ Training code defined")


In [None]:
#@title 6. Run Training
# Configuration
EPOCHS = 10  #@param {type:"integer"}
BATCH_SIZE = 32  #@param {type:"integer"}
LEARNING_RATE = 0.0003  #@param {type:"number"}
VAL_SPLIT = 0.1  #@param {type:"number"}

# Setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load data - shuffle files first to mix classes. The train/val split is made
# at file level, so no file contributes rows to both sides.
import random
parquet_files = sorted(glob.glob('/content/data/stamps/*.parquet'))
random.seed(42)
random.shuffle(parquet_files)

if len(parquet_files) < 2:
    raise RuntimeError(
        f"A file-level train/val split needs at least 2 parquet files; "
        f"found {len(parquet_files)} in /content/data/stamps"
    )
# At least one validation file, but always leave at least one for training
# (otherwise the train loader is empty and the loss average divides by zero).
n_val = min(max(1, int(len(parquet_files) * VAL_SPLIT)), len(parquet_files) - 1)
val_files = parquet_files[:n_val]
train_files = parquet_files[n_val:]

print(f"Train files: {len(train_files)}, Val files: {len(val_files)}")

train_dataset = StampDataset(train_files, augment=True)
val_dataset = StampDataset(val_files, augment=False)

print(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")

# Check class distribution on a sample of each split
train_labels = [train_dataset[i][1].item() for i in range(min(1000, len(train_dataset)))]
val_labels_check = [val_dataset[i][1].item() for i in range(min(500, len(val_dataset)))]
print(f"Train class balance (sample): {sum(train_labels)}/{len(train_labels)} positives")
print(f"Val class balance (sample): {sum(val_labels_check)}/{len(val_labels_check)} positives")

if len(set(val_labels_check)) < 2:
    print("‚ö†Ô∏è WARNING: Validation set has only ONE class! AUROC will be undefined.")
    print("   This is fine for smoke test - training still works, metrics just won't compute.")

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# Model
model = LensFinderCNN().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
criterion = nn.BCEWithLogitsLoss()

# Training loop
os.makedirs('/content/checkpoints', exist_ok=True)
best_auroc = -1.0  # Start at -1 so we always save at least once
best_loss = float('inf')

for epoch in range(EPOCHS):
    # Train
    model.train()
    train_loss = 0.0
    for imgs, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{EPOCHS} [Train]'):
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer.zero_grad()
        logits = model(imgs).squeeze(-1)
        loss = criterion(logits, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # Accumulate per-sample sums so a short final batch is not over-weighted.
        train_loss += loss.item() * imgs.size(0)

    scheduler.step()
    train_loss /= max(1, len(train_dataset))

    # Validate
    model.eval()
    val_preds, val_labels = [], []
    val_loss = 0.0
    with torch.no_grad():
        for imgs, labels in tqdm(val_loader, desc=f'Epoch {epoch+1}/{EPOCHS} [Val]'):
            imgs, labels = imgs.to(device), labels.to(device)
            logits = model(imgs).squeeze(-1)
            loss = criterion(logits, labels)
            val_loss += loss.item() * imgs.size(0)
            val_preds.extend(torch.sigmoid(logits).cpu().numpy())
            val_labels.extend(labels.cpu().numpy())

    val_loss /= max(1, len(val_dataset))

    # Calculate AUROC only if both classes present
    if len(set(val_labels)) >= 2:
        auroc = roc_auc_score(val_labels, val_preds)
    else:
        auroc = float('nan')

    print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}, AUROC={auroc:.4f}")

    # Save best model - use AUROC if available, else use val_loss
    save_checkpoint = False
    if not np.isnan(auroc) and auroc > best_auroc:
        best_auroc = auroc
        save_checkpoint = True
        save_reason = f"AUROC={auroc:.4f}"
    elif np.isnan(auroc) and val_loss < best_loss:
        best_loss = val_loss
        save_checkpoint = True
        save_reason = f"Val Loss={val_loss:.4f}"

    # Same payload for the best and the last checkpoint.
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'auroc': auroc if not np.isnan(auroc) else None,
        'val_loss': val_loss,
    }

    if save_checkpoint:
        torch.save(checkpoint, '/content/checkpoints/checkpoint_best.pt')
        print(f"  ‚úÖ New best model saved! {save_reason}")

    # Always save last checkpoint
    torch.save(checkpoint, '/content/checkpoints/checkpoint_last.pt')

print(f"\nüéâ Training complete!")
# best_auroc keeps its -1.0 sentinel when AUROC was never defined (single-class
# validation set); report the best validation loss instead.
if best_auroc >= 0:
    print(f"   Best AUROC: {best_auroc:.4f}")
else:
    print(f"   Best Val Loss: {best_loss:.4f}")


In [None]:
#@title 7. Upload Results to S3
import boto3
import os
from datetime import datetime

s3 = boto3.client('s3')
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

UPLOAD_BUCKET = 'darkhaloscope'
LATEST_KEY = 'phase5/models/colab/checkpoint_latest.pt'

# Prefer the best checkpoint; fall back to the last one.
checkpoint_paths = [
    '/content/checkpoints/checkpoint_best.pt',
    '/content/checkpoints/checkpoint_last.pt'
]
checkpoint_path = next((path for path in checkpoint_paths if os.path.exists(path)), None)

if checkpoint_path is None:
    print("‚ùå No checkpoint found! Training may have failed.")
else:
    # Upload best/last checkpoint under a timestamped key
    s3_key = f'phase5/models/colab/checkpoint_{timestamp}.pt'
    print(f"Uploading {checkpoint_path} to s3://darkhaloscope/{s3_key}...")
    s3.upload_file(checkpoint_path, UPLOAD_BUCKET, s3_key)
    print(f"‚úÖ Uploaded to s3://darkhaloscope/{s3_key}")

    # Also save as latest: a server-side copy of the object just uploaded,
    # instead of sending the checkpoint over the network a second time.
    s3.copy_object(
        Bucket=UPLOAD_BUCKET,
        Key=LATEST_KEY,
        CopySource={'Bucket': UPLOAD_BUCKET, 'Key': s3_key},
    )
    print(f"‚úÖ Uploaded to s3://darkhaloscope/phase5/models/colab/checkpoint_latest.pt")


---
## Optional: Download Full Dataset

If the smoke test above works, run this cell to download ALL data (~50GB, ~30 min).
Then re-run cells 4-7, increasing `EPOCHS` in cell 6.


In [None]:
#@title 8. Download Full Dataset (Optional - ~50GB)
# ⚠️ Only run this after smoke test passes!

import boto3
import os
from tqdm import tqdm

BUCKET = 'darkhaloscope'
# NOTE(review): this prefix uses a different layout
# (v3_color_relaxed/stage4c/<tier>/stamps/) from the train-tier prefix given
# in cell 3 (phase4c/v3_color_relaxed/stamps/train_stamp64_v2/). Confirm
# which one exists; the empty-listing check below fails loudly if this one
# does not.
S3_PREFIX = 'phase4_pipeline/v3_color_relaxed/stage4c/train_stamp64_v2/stamps/'
LOCAL_DIR = '/content/data/stamps'

os.makedirs(LOCAL_DIR, exist_ok=True)
s3 = boto3.client('s3')

# List ALL files
print("Listing all files (this may take a minute)...")
paginator = s3.get_paginator('list_objects_v2')
all_files = [
    obj
    for page in paginator.paginate(Bucket=BUCKET, Prefix=S3_PREFIX)
    for obj in page.get('Contents', [])
    if obj['Key'].endswith('.parquet')
]

total_size = sum(f['Size'] for f in all_files) / (1024**3)
print(f"Total: {len(all_files)} files, {total_size:.2f} GB")
if not all_files:
    raise RuntimeError(f"No .parquet files under s3://{BUCKET}/{S3_PREFIX} - check S3_PREFIX")

# Local filename for each object: its path below the prefix, flattened with
# '__' so same-named files in different sub-"directories" stay distinct.
local_paths = [
    os.path.join(LOCAL_DIR, obj['Key'][len(S3_PREFIX):].lstrip('/').replace('/', '__'))
    for obj in all_files
]

# Training/validation use every *.parquet in LOCAL_DIR, so remove local
# parquet files that are not part of this dataset (e.g. the smoke-test
# download). Otherwise debug-tier stamps would be mixed into the full split.
wanted = {os.path.basename(path) for path in local_paths}
stale = [name for name in os.listdir(LOCAL_DIR) if name.endswith('.parquet') and name not in wanted]
for name in stale:
    os.remove(os.path.join(LOCAL_DIR, name))
if stale:
    print(f"Removed {len(stale)} local parquet files that are not part of this dataset")

# Download all, reusing local copies only when their size matches the S3 object
for obj, local_path in tqdm(zip(all_files, local_paths), total=len(all_files), desc='Downloading'):
    if os.path.exists(local_path) and os.path.getsize(local_path) == obj['Size']:
        continue
    s3.download_file(BUCKET, obj['Key'], local_path)

print(f"\n‚úÖ Full dataset downloaded! Now re-run cells 4-7 with EPOCHS=50")
