# NCSN Training on CIFAR-10
**Noise Conditional Score Network**

This notebook trains NCSN for unconditional image generation on CIFAR-10.

## 1. Setup

In [None]:
# Clone repository
# Fetch the project sources, then switch into the clone so the `src` package
# imported by later cells can be found (presumably it lives at the repo root).
!git clone https://github.com/5w7Tch/GM-final.git
%cd GM-final

In [None]:
# Install dependencies
# wandb and tqdm are both imported by later cells; -q keeps pip's output short.
!pip install wandb tqdm -q

In [None]:
# Report the installed PyTorch version, whether CUDA is usable, and (when it
# is) the name of CUDA device 0.
import torch

_has_cuda = torch.cuda.is_available()

print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {_has_cuda}")
if _has_cuda:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Imports from our repo
import torch
import torch.optim as optim
import os
from datetime import datetime
from tqdm.auto import tqdm

from src.models import NCSN, get_sigmas
from src.losses import anneal_dsm_loss
from src.sampling import generate_samples
from src.data import get_dataloader, denormalize
from src.utils import EMA, show_samples, save_samples

In [None]:
# Wandb (optional)
# Set USE_WANDB to False to skip the wandb import and login entirely; the
# later cells check this flag before every wandb call.
USE_WANDB = True

if USE_WANDB:
    # Imported lazily so the flag above fully controls whether wandb is touched.
    import wandb
    wandb.login()

## 2. Configuration

In [None]:
# Single source of truth for the run's hyperparameters. The values are read by
# the cells below (model, noise schedule, optimizer, sampling, logging).
config = dict(
    # Model
    num_features=128,
    # Also passed to get_sigmas() as its third argument further down.
    num_classes=10,

    # Noise schedule
    sigma_begin=1.0,
    sigma_end=0.01,

    # Training
    epochs=200,
    batch_size=128,
    lr=1e-4,
    ema_decay=0.999,

    # Sampling
    n_steps_each=100,
    step_lr=2e-5,

    # Logging (both intervals are counted in epochs)
    sample_every=10,
    save_every=25,

    seed=42,
)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Device: {device}")

## 3. Initialize

In [None]:
# Seed PyTorch's global RNG from the config
torch.manual_seed(config['seed'])

# Output folders for checkpoints and sample images (no error if they exist)
for _out_dir in ('checkpoints', 'samples'):
    os.makedirs(_out_dir, exist_ok=True)

# Training-split data loader; its length is the number of batches per epoch
train_loader = get_dataloader(train=True, batch_size=config['batch_size'])
print(f"Training batches: {len(train_loader)}")

In [None]:
# Build the score network from the configured sizes and move it to the device
model = NCSN(
    num_features=config['num_features'],
    num_classes=config['num_classes'],
)
model = model.to(device)

# Total element count across all model parameters
_n_params = sum(p.numel() for p in model.parameters())
print(f"Parameters: {_n_params:,}")

# Noise levels: get_sigmas(sigma_begin, sigma_end, count), kept on the same
# device as the model
sigmas = get_sigmas(config['sigma_begin'], config['sigma_end'], config['num_classes'])
sigmas = sigmas.to(device)

print(f"Sigmas: {sigmas.cpu().numpy()}")

In [None]:
# Adam over every model parameter
optimizer = optim.Adam(model.parameters(), lr=config['lr'])

# Cosine schedule whose period covers the whole run: the training loop calls
# scheduler.step() once per batch, so T_max is epochs x batches-per-epoch.
_total_steps = config['epochs'] * len(train_loader)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=_total_steps)

# EMA tracker for the model, using the configured decay
ema = EMA(model, decay=config['ema_decay'])

In [None]:
# Start a wandb run (only when enabled), named with the current month/day and
# hour/minute, and record the full config dict with it.
if USE_WANDB:
    _run_name = 'ncsn_' + datetime.now().strftime("%m%d_%H%M")
    wandb.init(project='ML2-NCSN', name=_run_name, config=config)

## 4. Training

In [None]:
# Main training loop.
#
# Per batch: compute anneal_dsm_loss, take one optimizer step, advance the
# per-batch LR scheduler and update the EMA tracker.
# Per epoch: print the mean loss; every `sample_every` epochs generate and save
# a sample grid; every `save_every` epochs write a checkpoint.
global_step = 0

for epoch in range(config['epochs']):
    model.train()
    epoch_loss = 0.0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}")

    for images, _ in pbar:
        images = images.to(device)

        # Forward
        loss = anneal_dsm_loss(model, images, sigmas)

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        ema.update()

        # Convert the loss to a Python float once and reuse it for the running
        # total, the progress bar and the wandb log.
        loss_value = loss.item()
        epoch_loss += loss_value
        global_step += 1

        pbar.set_postfix(loss=f'{loss_value:.4f}')

        if USE_WANDB and global_step % 50 == 0:
            wandb.log({'loss': loss_value, 'lr': scheduler.get_last_lr()[0]}, step=global_step)

    print(f"Epoch {epoch+1} - Loss: {epoch_loss/len(train_loader):.4f}")

    # Generate samples
    if (epoch + 1) % config['sample_every'] == 0:
        sample_path = f'samples/epoch_{epoch+1:04d}.png'

        # apply_shadow()/restore() bracket the sampling call (presumably
        # swapping the EMA weights in and back out; see src.utils.EMA).
        ema.apply_shadow()
        try:
            model.eval()

            samples = generate_samples(
                model, sigmas, n_samples=64,
                n_steps_each=config['n_steps_each'],
                step_lr=config['step_lr'],
                device=device
            )
        finally:
            # Always restore, even if sampling fails or the cell is
            # interrupted, so later training never continues from the
            # shadow weights by accident.
            ema.restore()

        save_samples(samples, sample_path)
        show_samples(samples, title=f'Epoch {epoch+1}')

        if USE_WANDB:
            wandb.log({'samples': wandb.Image(sample_path)}, step=global_step)

    # Save checkpoint
    if (epoch + 1) % config['save_every'] == 0:
        torch.save({
            # Zero-based epoch index; the file name below uses epoch + 1.
            'epoch': epoch,
            'model': model.state_dict(),
            'ema': ema.state_dict(),
            'optimizer': optimizer.state_dict(),
            # Scheduler is stepped per batch, so its state is needed to
            # resume with the correct learning rate.
            'scheduler': scheduler.state_dict(),
            'sigmas': sigmas.cpu(),
            'config': config
        }, f'checkpoints/epoch_{epoch+1:04d}.pt')

print("Training complete!")

if USE_WANDB:
    wandb.finish()

## 5. Final Samples

In [None]:
# Apply EMA weights for the final sampling pass only
ema.apply_shadow()
try:
    model.eval()

    # Generate with more steps than the in-training previews; the step size
    # comes from the same config entry the training loop uses.
    final_samples = generate_samples(
        model, sigmas, n_samples=64,
        n_steps_each=200,
        step_lr=config['step_lr'],
        device=device
    )
finally:
    # Put the trained weights back (as the training loop does after each
    # preview) so that a later model.state_dict() reflects the trained weights
    # rather than a second copy of the EMA weights.
    ema.restore()

show_samples(final_samples, title='Final Samples')
save_samples(final_samples, 'samples/final.png')

In [None]:
# Bundle the model weights, the EMA state, the noise levels (moved to CPU) and
# the run config into one file.
_final_state = {
    'model': model.state_dict(),
    'ema': ema.state_dict(),
    'sigmas': sigmas.cpu(),
    'config': config,
}
torch.save(_final_state, 'checkpoints/final.pt')

## 6. Download Results

In [None]:
!zip -r results.zip checkpoints/ samples/

from google.colab import files
files.download('results.zip')