In [None]:
# This mounts your Google Drive to the Colab VM.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip install wandb -q

In [None]:
import wandb
wandb.login()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import wandb
from tqdm import tqdm
import os
from pathlib import Path

# ==================== Model Architecture ====================
class ConditionalInstanceNorm2dPlus(nn.Module):
    """Instance normalisation with per-class scale, skip and shift terms.

    An embedding indexed by ``y`` yields, per feature channel, a factor
    ``gamma`` applied to the instance-normalised input, a factor ``alpha``
    applied to the raw input and (when ``bias`` is true) an additive
    ``beta``.  The result is ``gamma * norm(x) + alpha * x [+ beta]``.

    Args:
        num_features: Number of channels in the input feature map.
        num_classes: Number of rows in the embedding table.
        bias: Whether the additive ``beta`` term is learned.
    """

    def __init__(self, num_features, num_classes, bias=True):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        self.instance_norm = nn.InstanceNorm2d(
            num_features, affine=False, track_running_stats=False
        )
        if not bias:
            # Only gamma and alpha are stored; both start close to one.
            self.embed = nn.Embedding(num_classes, 2 * num_features)
            self.embed.weight.data.normal_(1, 0.02)
        else:
            # Column layout of the table: [gamma | alpha | beta].
            self.embed = nn.Embedding(num_classes, num_features * 3)
            table = self.embed.weight.data
            table[:, :2 * num_features].normal_(1, 0.02)
            table[:, 2 * num_features:].zero_()

    def forward(self, x, y):
        """Apply the class-conditioned normalisation to ``x`` using indices ``y``."""
        normed = self.instance_norm(x)
        pieces = self.embed(y).chunk(3 if self.bias else 2, dim=-1)
        # Reshape each per-channel vector so it broadcasts over the trailing two axes.
        gamma, alpha = (p.view(-1, self.num_features, 1, 1) for p in pieces[:2])
        out = gamma * normed + alpha * x
        if self.bias:
            out = out + pieces[2].view(-1, self.num_features, 1, 1)
        return out


class ConditionalResidualBlock(nn.Module):
    """Pre-activation residual block conditioned on a class index.

    The main path is norm -> activation -> conv -> norm -> activation -> conv,
    where both norms are ``ConditionalInstanceNorm2dPlus`` layers driven by
    ``y``.  The result is added to a shortcut of the input.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the block.
        num_classes: Size of the conditioning vocabulary.
        resample: ``'down'`` uses a stride-2 second convolution and shortcut,
            ``'up'`` uses stride-2 transposed convolutions, and any other
            value keeps the spatial size.
        activation: Non-linearity module.  ``None`` (the default) creates a
            fresh ``nn.ELU()`` for this block.
    """

    def __init__(self, in_channels, out_channels, num_classes, resample=None, activation=None):
        super().__init__()
        # A default written as ``activation=nn.ELU()`` is evaluated once when
        # the class is defined, so every block would share one module object.
        # Build it per instance instead.
        self.activation = nn.ELU() if activation is None else activation
        self.resample = resample
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.norm1 = ConditionalInstanceNorm2dPlus(in_channels, num_classes)
        self.norm2 = ConditionalInstanceNorm2dPlus(out_channels, num_classes)

        if resample == 'down':
            self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
            self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1, stride=2)
            # The shortcut uses the same stride so it matches the main path.
            self.shortcut = nn.Conv2d(in_channels, out_channels, 1, stride=2)
        elif resample == 'up':
            self.conv1 = nn.ConvTranspose2d(in_channels, out_channels, 3, padding=1)
            self.conv2 = nn.ConvTranspose2d(out_channels, out_channels, 3, padding=1, stride=2, output_padding=1)
            self.shortcut = nn.ConvTranspose2d(in_channels, out_channels, 1, stride=2, output_padding=1)
        else:
            self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
            self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
            # A 1x1 projection is only needed when the channel count changes.
            self.shortcut = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else None

    def forward(self, x, y):
        """Return the main-path output added to the (projected) input."""
        h = self.norm1(x, y)
        h = self.activation(h)
        h = self.conv1(h)
        h = self.norm2(h, y)
        h = self.activation(h)
        h = self.conv2(h)

        residual = x if self.shortcut is None else self.shortcut(x)
        return h + residual


class RefineNet(nn.Module):
    """Two class-conditioned residual blocks applied in sequence, without resampling.

    The first block maps ``in_channels`` to ``out_channels``; the second
    keeps ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, num_classes):
        super().__init__()
        self.refine1 = ConditionalResidualBlock(
            in_channels, out_channels, num_classes, resample=None
        )
        self.refine2 = ConditionalResidualBlock(
            out_channels, out_channels, num_classes, resample=None
        )

    def forward(self, x, y):
        """Pass ``x`` through both blocks, conditioning each on ``y``."""
        return self.refine2(self.refine1(x, y), y)


class NCSNModel(nn.Module):
    """Noise-conditional score network with an encoder/decoder layout.

    Three ``'down'`` residual blocks are followed by a middle block and three
    ``'up'`` residual blocks, each preceded by a ``RefineNet``.  The outputs
    of the first two down blocks are added back in after the first two up
    blocks.  Every normalisation layer is conditioned on ``y``.

    Args:
        num_classes: Number of distinct conditioning indices.
        ngf: Base number of feature channels; wider stages use ``2 * ngf``.
    """

    def __init__(self, num_classes=10, ngf=128):
        super().__init__()
        self.num_classes = num_classes
        self.ngf = ngf
        self.activation = nn.ELU()

        wide = 2 * ngf

        # Stem: 3 input channels -> ngf feature channels.
        self.begin_conv = nn.Conv2d(3, ngf, 3, padding=1)

        # Encoder.
        self.res1 = ConditionalResidualBlock(ngf, ngf, num_classes, resample='down')
        self.res2 = ConditionalResidualBlock(ngf, wide, num_classes, resample='down')
        self.res3 = ConditionalResidualBlock(wide, wide, num_classes, resample='down')

        # Bottleneck.
        self.res4 = ConditionalResidualBlock(wide, wide, num_classes, resample=None)

        # Decoder: a refine stage precedes every upsampling block.
        self.refine1 = RefineNet(wide, wide, num_classes)
        self.res5 = ConditionalResidualBlock(wide, wide, num_classes, resample='up')
        self.refine2 = RefineNet(wide, wide, num_classes)
        self.res6 = ConditionalResidualBlock(wide, ngf, num_classes, resample='up')
        self.refine3 = RefineNet(ngf, ngf, num_classes)
        self.res7 = ConditionalResidualBlock(ngf, ngf, num_classes, resample='up')
        self.refine4 = RefineNet(ngf, ngf, num_classes)

        # Head: ngf feature channels -> 3 output channels.
        self.norm_final = ConditionalInstanceNorm2dPlus(ngf, num_classes)
        self.end_conv = nn.Conv2d(ngf, 3, 3, padding=1)

    def forward(self, x, y):
        """Return the network output for input ``x`` conditioned on index tensor ``y``."""
        stem = self.begin_conv(x)

        # Encoder; the first two outputs are kept for the skip additions.
        skip_outer = self.res1(stem, y)
        skip_inner = self.res2(skip_outer, y)
        bottom = self.res3(skip_inner, y)

        out = self.res4(bottom, y)

        # Decoder with additive skips after the first two upsampling blocks.
        out = self.res5(self.refine1(out, y), y) + skip_inner
        out = self.res6(self.refine2(out, y), y) + skip_outer
        out = self.res7(self.refine3(out, y), y)
        out = self.refine4(out, y)

        out = self.activation(self.norm_final(out, y))
        return self.end_conv(out)


# ==================== Loss Functions ====================
def anneal_dsm_score_estimation(scorenet, samples, sigmas, anneal_power=2.0):
    """Compute the annealed denoising score matching loss and logging diagnostics.

    Each sample is perturbed with Gaussian noise whose standard deviation is
    drawn uniformly from ``sigmas``.  ``scorenet`` is asked to predict
    ``-noise / sigma**2``, and the per-sample squared error is weighted by
    ``sigma**anneal_power`` before averaging.

    Args:
        scorenet: Callable ``scorenet(perturbed, labels)`` returning a tensor
            shaped like ``perturbed``.
        samples: Tensor laid out as ``(batch, channels, ...)``.
        sigmas: 1-D tensor of noise levels on the same device as ``samples``.
        anneal_power: Exponent applied to sigma in the loss weight.

    Returns:
        ``(weighted_loss, loss_dict)``.  ``weighted_loss`` is a scalar tensor
        attached to the autograd graph.  ``loss_dict`` maps metric names to
        Python floats.  The ``'grad_norm'`` entry is the mean L2 norm of the
        predicted score (the key name is kept for existing dashboards), not a
        parameter-gradient norm.
    """
    batch_size = samples.shape[0]
    n_dims = samples.dim()
    per_sample_dims = tuple(range(1, n_dims))
    per_channel_dims = tuple(range(2, n_dims))

    # Draw order (labels first, then noise) is kept stable so seeded runs reproduce.
    labels = torch.randint(0, len(sigmas), (batch_size,), device=samples.device)
    sigma_per_sample = sigmas[labels]
    used_sigmas = sigma_per_sample.view(batch_size, *([1] * len(per_sample_dims)))

    noise = torch.randn_like(samples) * used_sigmas
    perturbed_samples = samples + noise

    predicted_score = scorenet(perturbed_samples, labels)
    target = -noise / (used_sigmas ** 2)

    residual = predicted_score - target
    losses = 0.5 * (residual ** 2).sum(dim=per_sample_dims)
    # Index sigmas directly instead of squeezing the broadcast view, so the
    # weight vector is always shaped (batch,), even for a batch of one.
    loss_weights = sigma_per_sample ** anneal_power
    weighted_loss = (losses * loss_weights).mean()

    # Everything below is reporting only; keep it out of the autograd graph so
    # no extra intermediates are retained for the backward pass.
    with torch.no_grad():
        unweighted_loss = losses.mean()

        squared = residual ** 2
        # An empty dim tuple would reduce over every axis, so skip the
        # reduction when there are no axes after the channel axis.
        channel_losses = 0.5 * (squared.sum(dim=per_channel_dims) if per_channel_dims else squared)
        loss_per_channel = {
            f'loss_channel_{c}': channel_losses[:, c].mean().item()
            for c in range(channel_losses.shape[1])
        }

        # torch.unique only returns labels that occur, so each mask is non-empty.
        loss_per_sigma = {
            f'loss_sigma_{lbl}': losses[labels == lbl].mean().item()
            for lbl in torch.unique(labels).tolist()
        }

        l1_loss = residual.abs().sum(dim=per_sample_dims).mean()

        grad_norm = torch.norm(predicted_score.reshape(batch_size, -1), dim=1).mean()
        target_norm = torch.norm(target.reshape(batch_size, -1), dim=1).mean()

        loss_dict = {
            'loss_weighted': weighted_loss.item(),
            'loss_unweighted': unweighted_loss.item(),
            'loss_l1': l1_loss.item(),
            'grad_norm': grad_norm.item(),
            'target_norm': target_norm.item(),
            'sigma_mean': used_sigmas.mean().item(),
            'sigma_std': used_sigmas.std().item(),
            **loss_per_channel,
            **loss_per_sigma
        }

    return weighted_loss, loss_dict


# ==================== Sampling ====================
@torch.no_grad()
def anneal_langevin_dynamics(scorenet, x_init, sigmas, n_steps_each=100, step_lr=0.00002,
                             log_intermediate=False, log_interval=10):
    """Run annealed Langevin dynamics starting from ``x_init``.

    For every noise level, in the order given by ``sigmas``, the state is
    updated ``n_steps_each`` times with
    ``x <- x + step * scorenet(x, labels) + sqrt(2 * step) * z``, where
    ``step = step_lr * (sigma / sigmas[-1]) ** 2`` and ``z`` is standard normal.

    Args:
        scorenet: Callable ``scorenet(x, labels)`` returning a tensor shaped like ``x``.
        x_init: Starting tensor; it is cloned, not modified.
        sigmas: 1-D tensor of noise levels.
        n_steps_each: Number of updates per noise level.
        step_lr: Step size used at the last noise level.
        log_intermediate: When true, also return snapshots of the state.
        log_interval: Snapshot every this many steps (plus the last step of
            each level).

    Returns:
        The final tensor, or ``(final, snapshots)`` when ``log_intermediate``
        is true.  Each snapshot is a dict with keys ``'sigma_idx'``,
        ``'step'`` and ``'image'``.
    """
    state = x_init.clone()
    snapshots = []
    smallest_sigma = sigmas[-1].item()
    last_step = n_steps_each - 1

    for level, sigma in enumerate(sigmas):
        step_size = step_lr * (sigma.item() / smallest_sigma) ** 2
        noise_scale = np.sqrt(step_size * 2)
        # Every sample in the batch is conditioned on the current noise level.
        labels = torch.full((state.shape[0],), level, device=state.device, dtype=torch.long)

        for step in range(n_steps_each):
            # Noise is drawn before the score is evaluated, matching the RNG order of prior runs.
            noise = torch.randn_like(state) * noise_scale
            score = scorenet(state, labels)
            state = state + step_size * score + noise

            if not log_intermediate:
                continue
            if step % log_interval == 0 or step == last_step:
                snapshots.append({
                    'sigma_idx': level,
                    'step': step,
                    'image': state.clone()
                })

    return (state, snapshots) if log_intermediate else state


# ==================== Configuration ====================
class NCSNConfig:
    """Hyperparameters and paths for the NCSN run, read as class attributes."""

    # --- model ---
    num_classes = 10          # number of noise levels / conditioning indices
    ngf = 128                 # base feature-channel width

    # --- noise schedule (geometric from sigma_begin to sigma_end) ---
    sigma_begin = 1.0
    sigma_end = 1e-2

    # --- optimisation ---
    batch_size = 128
    num_epochs = 100
    lr = 1e-3
    lr_decay_factor = 0.1     # gamma passed to MultiStepLR
    lr_decay_epochs = [70, 90]  # MultiStepLR milestones

    # --- Langevin sampling ---
    n_steps_each = 100        # updates per noise level
    step_lr = 2e-5

    # --- data loading ---
    image_size = 32
    num_workers = 4
    prefetch_factor = 2

    # --- logging cadence, in optimiser steps ---
    log_interval = 50
    sample_interval = 500
    num_samples = 64          # images generated per sampling event

    # --- checkpointing ---
    # The trailing slash matters: file names are appended by string formatting.
    checkpoint_dir = '/content/drive/MyDrive/cs236/assignments/final/checkpoints/'
    save_every_n_epochs = 10
    resume_from_checkpoint = None  # explicit checkpoint path, or None to auto-detect


# ==================== Data Loading ====================
def get_cifar10_dataloaders(batch_size, num_workers=4, prefetch_factor=2):
    """Build the shuffled CIFAR-10 training ``DataLoader``.

    Images are converted to tensors and normalised per channel with mean 0.5
    and std 0.5.  The dataset is downloaded to ``./data`` when missing.

    Args:
        batch_size: Samples per batch; the final incomplete batch is dropped.
        num_workers: Number of loader worker processes.  ``0`` loads in the
            main process.
        prefetch_factor: Batches prefetched per worker.  Ignored when
            ``num_workers`` is ``0``.

    Returns:
        A single ``DataLoader`` over the training split.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    train_dataset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform
    )

    loader_kwargs = {
        'batch_size': batch_size,
        'shuffle': True,
        'num_workers': num_workers,
        'pin_memory': True,
        'drop_last': True,
    }
    if num_workers > 0:
        # Both options only apply to multi-process loading; DataLoader raises
        # ValueError if prefetch_factor is given while num_workers == 0.
        loader_kwargs['prefetch_factor'] = prefetch_factor
        loader_kwargs['persistent_workers'] = True

    train_loader = DataLoader(train_dataset, **loader_kwargs)

    return train_loader


# ==================== Visualization ====================
def denormalize(x):
    """Map values from the [-1, 1] range onto [0, 1]."""
    shifted = x + 1
    return shifted / 2


def save_sample_images(samples, filename, nrow=8):
    """Write a grid of generated images to ``filename`` and return the grid tensor.

    ``samples`` are mapped from [-1, 1] to [0, 1] and clamped before being
    tiled ``nrow`` images per row.  Missing parent directories of
    ``filename`` are created.
    """
    Path(filename).parent.mkdir(parents=True, exist_ok=True)

    unit_range = torch.clamp(denormalize(samples), 0, 1)
    grid = torchvision.utils.make_grid(unit_range, nrow=nrow, padding=2)

    # make_grid returns channels-first; imshow expects channels-last.
    pixels = grid.permute(1, 2, 0).cpu().numpy()

    plt.figure(figsize=(12, 12))
    plt.imshow(pixels)
    plt.axis('off')
    plt.tight_layout()
    plt.savefig(filename, bbox_inches='tight', dpi=150)
    plt.close()
    return grid


def create_intermediate_grid(intermediate_images, num_samples=8):
    """Tile a subsample of sampling snapshots into one image grid.

    Roughly every tenth snapshot is used.  From each, the first
    ``num_samples`` images are clamped to [-1, 1], mapped to [0, 1] and laid
    out as one row of the grid.  Returns ``None`` when there are no snapshots.
    """
    if not intermediate_images:
        return None

    stride = max(1, len(intermediate_images) // 10)
    rows = [
        denormalize(torch.clamp(snapshot['image'][:num_samples], -1, 1))
        for snapshot in intermediate_images[::stride]
    ]
    if not rows:
        return None

    stacked = torch.cat(rows, dim=0)
    return torchvision.utils.make_grid(stacked, nrow=num_samples, padding=2)


# ==================== Checkpoint Management ====================
def save_checkpoint(model, optimizer, scheduler, epoch, loss, sigmas, filepath):
    """Serialise the training state to ``filepath`` with ``torch.save``.

    The file holds the epoch index, the state dicts of model, optimizer and
    scheduler, the loss value, the noise levels (moved to CPU when they are a
    tensor) and a few ``NCSNConfig`` fields.  Parent directories are created
    as needed.
    """
    Path(filepath).parent.mkdir(parents=True, exist_ok=True)

    config_snapshot = {
        'num_classes': NCSNConfig.num_classes,
        'ngf': NCSNConfig.ngf,
        'sigma_begin': NCSNConfig.sigma_begin,
        'sigma_end': NCSNConfig.sigma_end,
    }

    payload = {}
    payload['epoch'] = epoch
    payload['model_state_dict'] = model.state_dict()
    payload['optimizer_state_dict'] = optimizer.state_dict()
    payload['scheduler_state_dict'] = scheduler.state_dict()
    payload['loss'] = loss
    if torch.is_tensor(sigmas):
        payload['sigmas'] = sigmas.cpu()
    else:
        payload['sigmas'] = sigmas
    payload['config'] = config_snapshot

    torch.save(payload, filepath)
    print(f"Checkpoint saved: {filepath} (epoch {epoch})")


def load_checkpoint(filepath, model, optimizer=None, scheduler=None, device='cpu'):
    """Restore training state from ``filepath`` and return the epoch to resume at.

    The model weights are always loaded.  Optimizer and scheduler states are
    loaded only when the object is supplied and the checkpoint has the
    matching entry.

    Returns:
        ``0`` when the file does not exist, otherwise the stored epoch
        (``0`` if absent) plus one.
    """
    if not os.path.exists(filepath):
        print(f"Checkpoint not found: {filepath}")
        return 0

    state = torch.load(filepath, map_location=device)
    model.load_state_dict(state['model_state_dict'])

    # Optional components: skipped when not passed in or not stored.
    optional_targets = (
        (optimizer, 'optimizer_state_dict'),
        (scheduler, 'scheduler_state_dict'),
    )
    for target, key in optional_targets:
        if target is not None and key in state:
            target.load_state_dict(state[key])

    stored_epoch = state.get('epoch', 0)
    print(f"Checkpoint loaded from epoch {stored_epoch}")
    # The stored epoch has finished, so training continues with the next one.
    return stored_epoch + 1


def find_latest_checkpoint(checkpoint_dir):
    """Return the path of the most recent checkpoint in ``checkpoint_dir``, or ``None``.

    Files named ``ncsn_latest_<N>.pth`` are preferred; when none of them has
    a parsable number, files named ``ncsn_epoch_<N>.pth`` are considered.
    Within a group the file with the largest integer ``<N>`` wins.
    """
    ckpt_root = Path(checkpoint_dir)

    if not ckpt_root.exists():
        print(f"Checkpoint directory does not exist: {ckpt_root}")
        return None

    def newest(candidates):
        """Pick the candidate whose trailing ``_<N>`` is largest; skip unparsable names."""
        best_path, best_epoch = None, -1
        for candidate in candidates:
            try:
                number = int(candidate.stem.split('_')[-1])
            except (ValueError, IndexError):
                continue
            if number > best_epoch:
                best_path, best_epoch = candidate, number
        return best_path, best_epoch

    # First choice: the rolling "latest" checkpoint.
    found, found_epoch = newest(list(ckpt_root.glob('ncsn_latest_*.pth')))
    if found:
        print(f"Found latest checkpoint: {found} (Epoch {found_epoch})")
        return str(found)

    # Fallback: the periodic per-epoch checkpoints.
    periodic = list(ckpt_root.glob('ncsn_epoch_*.pth'))
    if not periodic:
        print("No checkpoint files found in directory")
        return None

    found, found_epoch = newest(periodic)
    if found:
        print(f"Found latest checkpoint: {found} (Epoch {found_epoch})")
        return str(found)

    return None


# ==================== Train ====================
def _log_generated_samples(model, sigmas, device, global_step):
    """Sample images with annealed Langevin dynamics, save PNGs and log them to wandb."""
    model.eval()
    try:
        with torch.no_grad():
            x_init = torch.randn(
                NCSNConfig.num_samples, 3, NCSNConfig.image_size, NCSNConfig.image_size
            ).to(device) * sigmas[0]

            # SAMPLING
            samples, intermediate_imgs = anneal_langevin_dynamics(
                model, x_init, sigmas,
                n_steps_each=NCSNConfig.n_steps_each,
                step_lr=NCSNConfig.step_lr,
                log_intermediate=True,
                log_interval=20
            )

            sample_path = f'{NCSNConfig.checkpoint_dir}samples/ncsn_samples_step_{global_step}.png'
            grid = save_sample_images(samples, sample_path)

            prog_grid = create_intermediate_grid(intermediate_imgs, num_samples=8)

            if prog_grid is not None:
                prog_path = f'{NCSNConfig.checkpoint_dir}samples/ncsn_progression_step_{global_step}.png'
                Path(prog_path).parent.mkdir(parents=True, exist_ok=True)
                plt.figure(figsize=(16, 10))
                plt.imshow(prog_grid.permute(1, 2, 0).cpu().numpy())
                plt.axis('off')
                plt.title(f'Sampling Progression - Step {global_step}')
                plt.tight_layout()
                plt.savefig(prog_path, bbox_inches='tight', dpi=150)
                plt.close()

            wandb_logs = {
                "samples/final": wandb.Image(
                    grid.permute(1, 2, 0).cpu().numpy(),
                    caption=f"Final samples - Step {global_step}"
                )
            }

            if prog_grid is not None:
                wandb_logs["samples/progression"] = wandb.Image(
                    prog_grid.permute(1, 2, 0).cpu().numpy(),
                    caption=f"Sampling progression - Step {global_step}"
                )

            sample_images = denormalize(torch.clamp(samples[:16], -1, 1))
            for idx in range(min(16, sample_images.shape[0])):
                wandb_logs[f"samples/individual_{idx}"] = wandb.Image(
                    sample_images[idx].permute(1, 2, 0).cpu().numpy(),
                    caption=f"Sample {idx} - Step {global_step}"
                )

            wandb.log(wandb_logs, step=global_step)
            print(f"\nGenerated and logged {NCSNConfig.num_samples} samples at step {global_step}")
    finally:
        # Return to training mode even if sampling or logging raised.
        model.train()


def _save_latest_checkpoint(model, optimizer, scheduler, epoch, loss, sigmas):
    """Write this epoch's rolling checkpoint, then remove older rolling checkpoints."""
    latest_path = f'{NCSNConfig.checkpoint_dir}ncsn_latest_{epoch+1:03d}.pth'

    # Save first: if the write fails, the previous rolling checkpoint is still
    # on disk and training can resume from it.
    save_checkpoint(model, optimizer, scheduler, epoch, loss, sigmas, latest_path)

    latest_name = Path(latest_path).name
    for old_file in Path(NCSNConfig.checkpoint_dir).glob('ncsn_latest_*.pth'):
        if old_file.name == latest_name:
            continue
        try:
            old_file.unlink()
        except OSError as err:
            # Cleanup is best effort; a leftover file must not stop training.
            print(f"Warning: could not remove old checkpoint {old_file}: {err}")


def train_ncsn():
    """Train the NCSN model on CIFAR-10 and return ``(model, sigmas)``.

    Starts a wandb run, builds the geometric noise schedule, model, Adam
    optimizer and MultiStepLR scheduler from ``NCSNConfig``, and resumes from
    ``NCSNConfig.resume_from_checkpoint`` or the newest checkpoint found in
    ``NCSNConfig.checkpoint_dir`` when one exists.  During training it logs
    scalars every ``log_interval`` steps, generates samples every
    ``sample_interval`` steps, writes a rolling checkpoint every epoch and a
    numbered checkpoint every ``save_every_n_epochs`` epochs.  The final
    weights are saved to ``ncsn_final.pth``.
    """
    wandb.init(
        entity="tourists",
        project="ncsn-cifar10",
        name="baseline-ncsn",
        config={
            "model": "NCSN",
            "dataset": "CIFAR-10",
            "num_classes": NCSNConfig.num_classes,
            "ngf": NCSNConfig.ngf,
            "batch_size": NCSNConfig.batch_size,
            "num_epochs": NCSNConfig.num_epochs,
            "lr": NCSNConfig.lr,
            "sigma_begin": NCSNConfig.sigma_begin,
            "sigma_end": NCSNConfig.sigma_end,
            "n_steps_each": NCSNConfig.n_steps_each,
            "step_lr": NCSNConfig.step_lr
        }
    )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True

    # Geometric sequence of noise levels from sigma_begin down to sigma_end.
    sigmas = torch.tensor(
        np.exp(np.linspace(
            np.log(NCSNConfig.sigma_begin),
            np.log(NCSNConfig.sigma_end),
            NCSNConfig.num_classes
        ))
    ).float().to(device)
    print(f"Noise levels (sigmas): {sigmas}")

    model = NCSNModel(num_classes=NCSNConfig.num_classes, ngf=NCSNConfig.ngf).to(device)
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    optimizer = torch.optim.Adam(model.parameters(), lr=NCSNConfig.lr, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=NCSNConfig.lr_decay_epochs, gamma=NCSNConfig.lr_decay_factor
    )

    start_epoch = 0
    if NCSNConfig.resume_from_checkpoint:
        start_epoch = load_checkpoint(
            NCSNConfig.resume_from_checkpoint, model, optimizer, scheduler, device
        )
    else:
        latest_ckpt = find_latest_checkpoint(NCSNConfig.checkpoint_dir)
        if latest_ckpt:
            print(f"\n{'='*60}")
            print(f"RESUMING TRAINING from {latest_ckpt}")
            print(f"{'='*60}\n")
            start_epoch = load_checkpoint(latest_ckpt, model, optimizer, scheduler, device)
        else:
            print(f"\n{'='*60}")
            print(f"STARTING FRESH TRAINING (no checkpoints found)")
            print(f"{'='*60}\n")

    train_loader = get_cifar10_dataloaders(
        NCSNConfig.batch_size,
        NCSNConfig.num_workers,
        NCSNConfig.prefetch_factor
    )

    # ==================== Training loop ====================
    # Continue the step counter so wandb steps line up after a resume.
    global_step = start_epoch * len(train_loader)

    for epoch in range(start_epoch, NCSNConfig.num_epochs):
        model.train()
        epoch_loss = 0
        epoch_metrics = {}

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NCSNConfig.num_epochs}")

        for batch_idx, (images, _) in enumerate(pbar):
            images = images.to(device, non_blocking=True)

            # FORWARD
            optimizer.zero_grad(set_to_none=True)
            loss, loss_dict = anneal_dsm_score_estimation(model, images, sigmas)

            # BACKWARD
            loss.backward()

            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            loss_dict['grad_norm_clipped'] = grad_norm.item()

            optimizer.step()

            # Read the scalar once; each .item() call forces a device sync.
            loss_value = loss.item()
            epoch_loss += loss_value
            global_step += 1

            for key, val in loss_dict.items():
                epoch_metrics.setdefault(key, []).append(val)

            if global_step % NCSNConfig.log_interval == 0:
                log_dict = {
                    "train/loss": loss_value,
                    "train/epoch": epoch,
                    "train/lr": optimizer.param_groups[0]['lr'],
                    "train/global_step": global_step,
                }
                for key, val in loss_dict.items():
                    log_dict[f"train/{key}"] = val

                wandb.log(log_dict, step=global_step)

            if global_step % NCSNConfig.sample_interval == 0:
                _log_generated_samples(model, sigmas, device, global_step)

            pbar.set_postfix({"loss": loss_value})

        avg_epoch_loss = epoch_loss / len(train_loader)
        print(f"\nEpoch {epoch+1} - Average Loss: {avg_epoch_loss:.4f}")

        epoch_avg_metrics = {
            f"epoch/{key}": np.mean(vals)
            for key, vals in epoch_metrics.items()
        }
        epoch_avg_metrics["epoch/avg_loss"] = avg_epoch_loss
        epoch_avg_metrics["epoch/number"] = epoch + 1

        wandb.log(epoch_avg_metrics, step=global_step)

        scheduler.step()

        if (epoch + 1) % NCSNConfig.save_every_n_epochs == 0:
            checkpoint_path = f'{NCSNConfig.checkpoint_dir}ncsn_epoch_{epoch+1:03d}.pth'
            save_checkpoint(model, optimizer, scheduler, epoch, avg_epoch_loss, sigmas, checkpoint_path)

        _save_latest_checkpoint(model, optimizer, scheduler, epoch, avg_epoch_loss, sigmas)

    final_path = f'{NCSNConfig.checkpoint_dir}ncsn_final.pth'
    torch.save({
        'model_state_dict': model.state_dict(),
        'sigmas': sigmas.cpu(),
        'config': {
            'num_classes': NCSNConfig.num_classes,
            'ngf': NCSNConfig.ngf,
        }
    }, final_path)

    print(f"\n{'='*60}")
    print(f"Training completed successfully!")
    print(f"Final model saved to: {final_path}")
    print(f"Total checkpoints saved: {NCSNConfig.num_epochs // NCSNConfig.save_every_n_epochs}")
    print(f"{'='*60}\n")
    wandb.finish()

    return model, sigmas


# ==================== Run Training ====================
if __name__ == "__main__":
    # Run the full training pipeline. `model` and `sigmas` are bound at module
    # level, so they remain available after this statement finishes.
    model, sigmas = train_ncsn()
    print("Training completed!")

