# NCSN Training on CIFAR-10
**Noise Conditional Score Network**

Includes automatic checkpoint resume, so training survives Colab runtime interruptions.

## 1. Setup & Mount Drive

In [None]:
# Persist checkpoints and samples in Google Drive so they outlive the Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

# Project folder plus its checkpoint/sample sub-folders (exist_ok makes this cell re-runnable).
import os
DRIVE_PATH = '/content/drive/MyDrive/ML2_NCSN'
for _folder in (DRIVE_PATH, f'{DRIVE_PATH}/checkpoints', f'{DRIVE_PATH}/samples'):
    os.makedirs(_folder, exist_ok=True)
print(f"Saving to: {DRIVE_PATH}")

In [None]:
# Start from a fresh checkout of the project repository under /content.
import os
os.chdir('/content')
# Remove any existing ML2_final / GM-final folders so the clone below starts clean.
!rm -rf ML2_final GM-final
!git clone https://github.com/5w7Tch/GM-final.git
# Work from the repository root; later cells import from the `src` package
# relative to this directory.
%cd GM-final
# Sanity check: show the working directory and the package contents.
!pwd
!ls src/

In [None]:
# Install Weights & Biases, used for the optional experiment logging below (-q = quiet pip output).
!pip install wandb -q

In [None]:
import torch

# Report the runtime environment: PyTorch version, CUDA availability and GPU model.
print(f"PyTorch: {torch.__version__}")
_cuda_ok = torch.cuda.is_available()
print(f"CUDA: {_cuda_ok}")
if _cuda_ok:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
import torch
import torch.optim as optim
import os
import glob
from datetime import datetime
from tqdm.auto import tqdm

from src.models import NCSN, get_sigmas
from src.losses import anneal_dsm_loss
from src.sampling import generate_samples
from src.data import get_dataloader, denormalize
from src.utils import EMA, show_samples, save_samples

In [None]:
# Weights & Biases experiment tracking: optional, but recommended.
USE_WANDB = True

if USE_WANDB:
    # Imported only when tracking is enabled, then authenticate with W&B.
    import wandb
    wandb.login()

## 2. Configuration

In [None]:
# All hyper-parameters in one place; this dict is also passed to W&B and
# stored inside every checkpoint.
config = dict(
    # --- Model ---
    num_features=128,
    num_classes=10,  # also passed to get_sigmas() as the number of noise levels

    # --- Noise schedule (range handed to get_sigmas) ---
    sigma_begin=1.0,
    sigma_end=0.01,

    # --- Training ---
    epochs=200,
    batch_size=128,
    lr=1e-4,
    ema_decay=0.999,

    # --- Sampling (periodic sample grids during training) ---
    n_steps_each=100,
    step_lr=2e-5,

    # --- Checkpointing (important on Colab, where runtimes get recycled) ---
    save_every_n_epochs=5,        # write a checkpoint every N epochs
    sample_every=10,              # generate a sample grid every N epochs
    keep_last_n_checkpoints=3,    # older checkpoints are deleted to save Drive space

    seed=42,
)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Device: {device}")

## 3. Checkpoint Utilities

In [None]:
def get_latest_checkpoint(checkpoint_dir):
    """Return the path of the highest-epoch checkpoint in *checkpoint_dir*.

    Checkpoints are files named ``epoch_<N>.pt`` (the training loop writes
    e.g. ``epoch_0005.pt``). Files that match the glob but whose ``<N>`` is
    not an integer (e.g. ``epoch_final.pt``) are ignored rather than crashing
    the resume logic.

    Args:
        checkpoint_dir: Directory to search.

    Returns:
        Path of the checkpoint with the largest epoch number, or ``None`` if
        no numbered checkpoint exists.
    """
    numbered = []
    for path in glob.glob(f"{checkpoint_dir}/epoch_*.pt"):
        # Parse only the file name, so underscores/dots elsewhere in the
        # directory path cannot affect the extracted epoch.
        stem = os.path.basename(path)[len('epoch_'):-len('.pt')]
        if stem.isdecimal():
            numbered.append((int(stem), path))
    if not numbered:
        return None
    # Highest epoch wins; ties (same epoch, different padding) break on path.
    return max(numbered)[1]


def save_checkpoint(path, model, ema, optimizer, scheduler, epoch, global_step, sigmas, config):
    """Save the complete training state to *path*, atomically.

    The state is serialised with ``torch.save`` to ``<path>.tmp`` first and
    then moved over *path* with ``os.replace``. A disconnect during the write
    therefore leaves at most a stray ``.tmp`` file. It never leaves a truncated
    ``epoch_*.pt`` that resume would pick up as the newest checkpoint.

    Args:
        path: Destination file.
        model: Module whose ``state_dict()`` is stored.
        ema: Object whose ``shadow`` attribute is stored as-is.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        scheduler: LR scheduler whose ``state_dict()`` is stored.
        epoch: Stored under ``'epoch'``. The training loop passes its
            0-based loop index, and resume continues from ``epoch + 1``.
        global_step: Number of optimisation steps taken so far.
        sigmas: Noise-level tensor; stored on CPU.
        config: Hyper-parameter dict, stored for reference.
    """
    state = {
        'epoch': epoch,
        'global_step': global_step,
        'model_state_dict': model.state_dict(),
        'ema_shadow': ema.shadow,
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'sigmas': sigmas.cpu(),
        'config': config
    }
    # The ".tmp" suffix keeps the partial file out of the "epoch_*.pt" glob.
    tmp_path = f"{path}.tmp"
    try:
        torch.save(state, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        # Best-effort removal of the partial file; the original error is re-raised.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    print(f"✓ Saved checkpoint: {path}")


def load_checkpoint(path, model, ema, optimizer, scheduler, device):
    """Restore training state written by ``save_checkpoint``.

    Model, optimizer and scheduler states are loaded in place. The EMA
    ``shadow`` dict is replaced by the saved one, with every tensor moved to
    *device*.

    Returns:
        ``(epoch, global_step)`` exactly as stored in the checkpoint.
    """
    state = torch.load(path, map_location=device)

    model.load_state_dict(state['model_state_dict'])

    restored_shadow = {}
    for name, tensor in state['ema_shadow'].items():
        restored_shadow[name] = tensor.to(device)
    ema.shadow = restored_shadow

    optimizer.load_state_dict(state['optimizer_state_dict'])
    scheduler.load_state_dict(state['scheduler_state_dict'])

    epoch, global_step = state['epoch'], state['global_step']
    return epoch, global_step


def cleanup_old_checkpoints(checkpoint_dir, keep_n=3):
    """Delete all but the *keep_n* newest ``epoch_<N>.pt`` checkpoints.

    Checkpoints are ranked by the integer epoch ``N`` parsed from the file
    name. Files whose ``N`` is not an integer (e.g. ``epoch_best.pt``) are
    never touched. A non-positive *keep_n* disables cleanup entirely.

    Args:
        checkpoint_dir: Directory holding the checkpoints.
        keep_n: Number of most recent numbered checkpoints to keep.
    """
    if keep_n <= 0:
        return

    numbered = []
    for path in glob.glob(f"{checkpoint_dir}/epoch_*.pt"):
        # Parse the file name only; the directory path may contain '_' or '.'.
        stem = os.path.basename(path)[len('epoch_'):-len('.pt')]
        if stem.isdecimal():
            numbered.append((int(stem), path))
    if len(numbered) <= keep_n:
        return

    numbered.sort()
    # Explicit end index (not [:-keep_n]) so the slice is always what it says.
    for _, ckpt in numbered[:len(numbered) - keep_n]:
        os.remove(ckpt)
        print(f"Removed old checkpoint: {ckpt}")

## 4. Initialize Model

In [None]:
torch.manual_seed(config['seed'])

# --- Data: training split from the project's loader ---
train_loader = get_dataloader(batch_size=config['batch_size'], train=True)
_batches_per_epoch = len(train_loader)
print(f"Training batches per epoch: {_batches_per_epoch}")

# --- Score network ---
model = NCSN(num_classes=config['num_classes'], num_features=config['num_features'])
model = model.to(device)
_n_params = sum(p.numel() for p in model.parameters())
print(f"Parameters: {_n_params:,}")

# --- Noise levels: config['num_classes'] values between sigma_begin and sigma_end ---
sigmas = get_sigmas(config['sigma_begin'], config['sigma_end'], config['num_classes']).to(device)

# --- Optimiser + cosine LR schedule; the training loop steps it once per batch,
#     so T_max spans every batch of every epoch ---
optimizer = optim.Adam(model.parameters(), lr=config['lr'])
_total_steps = config['epochs'] * _batches_per_epoch
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=_total_steps)

# --- Exponential moving average of the weights (swapped in for sampling) ---
ema = EMA(model, decay=config['ema_decay'])

## 5. Resume from Checkpoint (if exists)

In [None]:
# Drive locations for checkpoints and sample grids.
CHECKPOINT_DIR = f"{DRIVE_PATH}/checkpoints"
SAMPLE_DIR = f"{DRIVE_PATH}/samples"

# Resume from the newest checkpoint if one exists; otherwise start from scratch.
latest_ckpt = get_latest_checkpoint(CHECKPOINT_DIR)

if latest_ckpt:
    print(f"\n🔄 Found checkpoint: {latest_ckpt}")
    start_epoch, global_step = load_checkpoint(
        latest_ckpt, model, ema, optimizer, scheduler, device
    )
    # The checkpoint stores the 0-based index of the epoch that had just
    # finished, so training continues with the next one.
    start_epoch += 1
    print(f"✓ Resuming from epoch {start_epoch}, step {global_step}")
else:
    print("\n🆕 No checkpoint found. Starting fresh.")
    start_epoch = 0
    global_step = 0

In [None]:
# Attach this notebook session to a Weights & Biases run.
if USE_WANDB:
    # A fixed id means a restarted runtime keeps logging into the same W&B run
    # (resume='allow' creates the run on first use). Change it to start a new run.
    RUN_ID = 'ncsn_cifar10_main'
    # NOTE(review): after resuming from a checkpoint, global_step rewinds to the
    # checkpoint's value while W&B may already hold later steps from the lost
    # session; wandb.log calls with a smaller step are likely dropped -- confirm.
    _wandb_init_kwargs = {
        'project': 'ML2-NCSN',
        'id': RUN_ID,
        'resume': 'allow',
        'config': config,
    }
    wandb.init(**_wandb_init_kwargs)

## 6. Training Loop

In [None]:
print(f"\n{'='*50}")
print(f"Training from epoch {start_epoch} to {config['epochs']}")
print(f"Checkpoints saved to: {CHECKPOINT_DIR}")
print(f"{'='*50}\n")

for epoch in range(start_epoch, config['epochs']):
    model.train()
    epoch_loss = 0.0
    # 1-based epoch number, used for display, file names and periodic triggers.
    epoch_num = epoch + 1
    is_last_epoch = epoch_num == config['epochs']

    pbar = tqdm(train_loader, desc=f"Epoch {epoch_num}/{config['epochs']}")

    for images, _ in pbar:
        images = images.to(device)

        # Forward: annealed denoising score-matching loss over the sigma levels.
        loss = anneal_dsm_loss(model, images, sigmas)

        # Backward, optimiser step, per-batch LR schedule step and EMA update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        ema.update()

        # One .item() (host sync) per step, reused for the running sum,
        # progress bar and logging.
        loss_value = loss.item()
        current_lr = scheduler.get_last_lr()[0]
        epoch_loss += loss_value
        global_step += 1

        pbar.set_postfix(loss=f'{loss_value:.4f}', lr=f'{current_lr:.2e}')

        if USE_WANDB and global_step % 50 == 0:
            wandb.log({'loss': loss_value, 'lr': current_lr}, step=global_step)

    avg_loss = epoch_loss / len(train_loader)
    print(f"Epoch {epoch_num} - Loss: {avg_loss:.4f}")

    if USE_WANDB:
        wandb.log({'epoch_loss': avg_loss, 'epoch': epoch_num}, step=global_step)

    # === SAVE CHECKPOINT ===
    # Also checkpoint after the final epoch, so a finished run is not partly
    # retrained on resume when 'epochs' is not a multiple of 'save_every_n_epochs'.
    if epoch_num % config['save_every_n_epochs'] == 0 or is_last_epoch:
        ckpt_path = f"{CHECKPOINT_DIR}/epoch_{epoch_num:04d}.pt"
        save_checkpoint(
            ckpt_path, model, ema, optimizer, scheduler,
            epoch, global_step, sigmas, config
        )
        cleanup_old_checkpoints(CHECKPOINT_DIR, config['keep_last_n_checkpoints'])

    # === GENERATE SAMPLES ===
    if epoch_num % config['sample_every'] == 0:
        print("Generating samples...")
        ema.apply_shadow()
        model.eval()
        try:
            samples = generate_samples(
                model, sigmas, n_samples=64,
                n_steps_each=config['n_steps_each'],
                step_lr=config['step_lr'],
                device=device
            )

            sample_path = f"{SAMPLE_DIR}/epoch_{epoch_num:04d}.png"
            save_samples(samples, sample_path)
            show_samples(samples, title=f'Epoch {epoch_num}')

            if USE_WANDB:
                wandb.log({'samples': wandb.Image(sample_path)}, step=global_step)
        finally:
            # Always swap the raw training weights back, even if sampling fails
            # or is interrupted. Otherwise training (and later checkpoints)
            # would continue from the EMA weights.
            ema.restore()
            model.train()

print("\n" + "="*50)
print("Training complete!")
print("="*50)

## 7. Save Final Model

In [None]:
# Swap the EMA weights into the model and save them as the final model.
# The EMA weights are not restored afterwards, so `model` still holds them
# for the sampling cell that follows.
ema.apply_shadow()

FINAL_MODEL_PATH = f"{DRIVE_PATH}/final_model.pt"
torch.save({
    'model_state_dict': model.state_dict(),
    'sigmas': sigmas.cpu(),
    'config': config
}, FINAL_MODEL_PATH)

print(f"✓ Final model saved to {FINAL_MODEL_PATH}")

if USE_WANDB:
    wandb.finish()

## 8. Generate Final Samples

In [None]:
# Generate final samples with the model's current weights. These are the EMA
# weights if the previous cell called ema.apply_shadow().
model.eval()

# Higher-quality samples: more steps (n_steps_each) than the periodic
# training-time samples. The step size is taken from config so it stays in
# sync with the training-time sampler.
FINAL_N_STEPS_EACH = 200
final_samples = generate_samples(
    model, sigmas, n_samples=64,
    n_steps_each=FINAL_N_STEPS_EACH,
    step_lr=config['step_lr'],
    device=device
)

save_samples(final_samples, f"{DRIVE_PATH}/final_samples.png")
show_samples(final_samples, title='Final Generated Samples')

---
## Quick Resume Guide

If Colab disconnects:
1. **Reconnect** to a new runtime
2. **Run all cells from the top** — the notebook detects the latest checkpoint in Drive and automatically resumes training from it

Your checkpoints are safely stored in Google Drive at:
`/content/drive/MyDrive/ML2_NCSN/checkpoints/`