# Notebook cell marker ("In [None]:") left over from the export.
import os
import random
from argparse import ArgumentParser

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, WeightedRandomSampler
from torchvision import datasets, transforms
from torchvision.models import resnet18

from torchmetrics.image.fid import FID
from torchmetrics.image.kid import KID

# -----------------------------
# CVaR Utility
# -----------------------------
def cvar_loss(losses: torch.Tensor, alpha: float) -> torch.Tensor:
    """
    Approximate CVaR_alpha(L) as the mean of the worst (1 - alpha) fraction of losses.

    Args:
        losses: per-sample losses; flattened to 1-D before the tail is selected.
        alpha: confidence level. ``alpha=0`` averages every loss, values close
            to 1 keep only the largest losses.

    Returns:
        Scalar tensor with the mean of the ``k`` largest losses, where
        ``k = floor((1 - alpha) * n)`` clamped to ``[1, n]``. An empty input
        yields NaN (the mean of no elements).
    """
    flat = losses.reshape(-1)
    n = flat.numel()
    if n == 0:
        return flat.mean()
    # (1 - alpha) is rarely exact in binary floating point (1 - 0.9 is slightly
    # below 0.1), so a bare int() can drop one sample from the tail. The small
    # epsilon restores the intended floor.
    k = int((1.0 - alpha) * n + 1e-9)
    k = min(n, max(1, k))
    # Only the k largest values are needed, so top-k replaces the full sort.
    tail, _ = torch.topk(flat, k)
    return tail.mean()

# -----------------------------
# Dataset Loaders
# -----------------------------
def get_imbalanced_mnist(imbalance_ratio: float, batch_size: int,
                         baseline: str = 'none', oversample: bool = False):
    """
    Returns train and test loaders for MNIST with class imbalance.

    Args:
        imbalance_ratio: fraction (in [0, 1]) of the class 5-9 training samples
            to keep; classes 0-4 are kept in full.
        batch_size: batch size used by both loaders.
        baseline: 'none', 'reweight', or 'focal'. Accepted for interface
            compatibility; it does not influence how the data is loaded.
        oversample: if True, uses WeightedRandomSampler to oversample minority
            classes with inverse class-frequency weights.

    Returns:
        (train_loader, test_loader): the imbalanced training subset and the
        untouched MNIST test split.

    Raises:
        ValueError: if imbalance_ratio is outside [0, 1].
    """
    if not 0.0 <= imbalance_ratio <= 1.0:
        raise ValueError(f"imbalance_ratio must be in [0, 1], got {imbalance_ratio}")

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    full_train = datasets.MNIST(root='data', train=True, download=True, transform=transform)
    # Evaluate against the held-out split rather than the training images.
    full_test = datasets.MNIST(root='data', train=False, download=True, transform=transform)
    test_loader = DataLoader(full_test, batch_size=batch_size, shuffle=False)

    # Read labels from the dataset's label tensor instead of decoding and
    # transforming every image just to look at its label.
    targets = torch.as_tensor(full_train.targets)
    idx_major = torch.nonzero(targets < 5).flatten().tolist()
    idx_minor = torch.nonzero(targets >= 5).flatten().tolist()

    # Scale by the real minority count; assuming it is exactly half of the
    # dataset can request more samples than exist.
    n_minor = int(len(idx_minor) * imbalance_ratio)
    idx_minor = random.sample(idx_minor, n_minor)

    selected_idx = idx_major + idx_minor
    subset = Subset(full_train, selected_idx)

    if oversample:
        labels = targets[selected_idx]
        class_counts = torch.bincount(labels, minlength=10).float()
        # Classes absent from the subset are never indexed below; the clamp only
        # keeps their (unused) weights finite.
        weights = 1.0 / class_counts.clamp(min=1.0)
        sample_weights = weights[labels]
        sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)
        loader = DataLoader(subset, batch_size=batch_size, sampler=sampler)
    else:
        loader = DataLoader(subset, batch_size=batch_size, shuffle=True)

    return loader, test_loader

def get_celeba_rare(attr1: str, attr2: str, rare_ratio: float,
                    batch_size: int, oversample: bool = False):
    """
    Returns train loader for CelebA with rare attribute combination.

    Samples having both attr1 and attr2 form the rare group and are all kept.
    The remaining samples are randomly subsampled to
    ``int(n_rare / rare_ratio)`` items (capped at what is available), so
    rare_ratio is the rare-to-rest size ratio.

    Args:
        attr1, attr2: CelebA attribute names that must both be present.
        rare_ratio: desired ratio of rare samples to the other samples (> 0).
        batch_size: loader batch size.
        oversample: if True, draw rare and non-rare samples with inverse
            group-frequency weights via WeightedRandomSampler.

    Raises:
        ValueError: if rare_ratio is not positive, an attribute name is
            unknown, or no sample carries both attributes.
    """
    if rare_ratio <= 0:
        raise ValueError(f"rare_ratio must be positive, got {rare_ratio}")

    transform = transforms.Compose([
        transforms.CenterCrop(178),
        transforms.Resize(64),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    celeba_root = 'data/celeba'
    full = datasets.CelebA(root=celeba_root, split='all', target_type='attr',
                           download=True, transform=transform)

    col1 = full.attr_names.index(attr1)
    col2 = full.attr_names.index(attr2)
    # Use the dataset's attribute table directly; iterating the dataset would
    # decode and transform every image only to inspect two attribute flags.
    attrs = torch.as_tensor(full.attr)
    rare_mask = (attrs[:, col1] == 1) & (attrs[:, col2] == 1)
    rare_idx = torch.nonzero(rare_mask).flatten().tolist()
    rest_idx = torch.nonzero(~rare_mask).flatten().tolist()

    n_rare = len(rare_idx)
    if n_rare == 0:
        raise ValueError(f"no CelebA sample has both {attr1!r} and {attr2!r}")
    # Never ask random.sample for more items than the pool holds.
    n_rest = min(len(rest_idx), int(n_rare / rare_ratio))
    rest_idx = random.sample(rest_idx, n_rest)

    selected = rare_idx + rest_idx
    subset = Subset(full, selected)
    if oversample:
        # `selected` lists the rare indices first, so group labels follow the
        # same layout without any per-item membership test.
        labels = torch.cat([
            torch.ones(n_rare, dtype=torch.long),
            torch.zeros(n_rest, dtype=torch.long),
        ])
        counts = torch.bincount(labels, minlength=2).float()
        weights = 1.0 / counts.clamp(min=1.0)
        sample_weights = weights[labels]
        sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)
        loader = DataLoader(subset, batch_size=batch_size, sampler=sampler)
    else:
        loader = DataLoader(subset, batch_size=batch_size, shuffle=True)
    return loader

# -----------------------------
# Models
# -----------------------------
class VAE(nn.Module):
    """Fully connected variational autoencoder for 1x28x28 images."""

    def __init__(self, latent_dim=20):
        super().__init__()
        n_pixels = 28 * 28
        n_hidden = 400
        # Layers are created in a fixed order (encoder, heads, decoder) so that
        # seeded initialisation stays reproducible.
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(n_pixels, n_hidden),
            nn.ReLU(),
        )
        self.mu_layer = nn.Linear(n_hidden, latent_dim)
        self.logvar_layer = nn.Linear(n_hidden, latent_dim)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_pixels),
            nn.Sigmoid(),
        )

    def encode(self, x):
        """Return the posterior mean and log-variance for a batch of images."""
        hidden = self.encoder(x)
        mu = self.mu_layer(hidden)
        logvar = self.logvar_layer(hidden)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Draw z = mu + sigma * noise with sigma = exp(logvar / 2)."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map latent codes to images of shape (batch, 1, 28, 28) in [0, 1]."""
        flat = self.decoder(z)
        return flat.view(-1, 1, 28, 28)

    def forward(self, x):
        """Return (reconstruction, mu, logvar) for a batch of images."""
        mu, logvar = self.encode(x)
        latent = self.reparameterize(mu, logvar)
        return self.decode(latent), mu, logvar

class Generator(nn.Module):
    """Transposed-convolution generator mapping latent vectors to 64x64 images in [-1, 1]."""

    def __init__(self, latent_dim=100, channels=3):
        super().__init__()
        stages = []
        in_ch = latent_dim
        for depth, out_ch in enumerate((512, 256, 128, 64)):
            # The first stage expands the 1x1 latent to 4x4 (stride 1, no
            # padding); every later stage doubles the spatial size.
            stride, padding = (1, 0) if depth == 0 else (2, 1)
            stages.append(nn.ConvTranspose2d(in_ch, out_ch, 4, stride, padding, bias=False))
            stages.append(nn.BatchNorm2d(out_ch))
            stages.append(nn.ReLU(True))
            in_ch = out_ch
        # Output stage: 32x32 -> 64x64, squashed to [-1, 1] by tanh.
        stages.append(nn.ConvTranspose2d(in_ch, channels, 4, 2, 1, bias=False))
        stages.append(nn.Tanh())
        self.net = nn.Sequential(*stages)

    def forward(self, z):
        """Reshape (batch, latent_dim) noise to (batch, latent_dim, 1, 1) and generate images."""
        batch, dim = z.size(0), z.size(1)
        return self.net(z.view(batch, dim, 1, 1))

class Discriminator(nn.Module):
    """Strided-convolution critic producing one unbounded score per 64x64 image."""

    def __init__(self, channels=3):
        super().__init__()
        # Input stage has no batch norm.
        stages = [
            nn.Conv2d(channels, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        in_ch = 64
        for out_ch in (128, 256, 512):
            # Each stage halves the spatial resolution.
            stages.append(nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False))
            stages.append(nn.BatchNorm2d(out_ch))
            stages.append(nn.LeakyReLU(0.2, inplace=True))
            in_ch = out_ch
        # Final 4x4 convolution collapses the feature map to a single score.
        stages.append(nn.Conv2d(in_ch, 1, 4, 1, 0, bias=False))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Return the scores flattened to a 1-D tensor."""
        scores = self.net(x)
        return scores.view(-1)

# -----------------------------
# Losses
# -----------------------------
def kl_divergence(mu, logvar):
    """Per-sample KL term -0.5 * sum(1 + logvar - mu^2 - exp(logvar)), summed over dim 1."""
    per_dim = 1 + logvar - mu.pow(2) - logvar.exp()
    summed = torch.sum(per_dim, dim=1)
    return -0.5 * summed

def focal_recon_loss(recon, x, gamma=2.0):
    """
    Focal-weighted binary cross-entropy reconstruction loss, one value per sample.

    The focal factor ``(1 - p_t) ** gamma`` with ``p_t = exp(-bce)`` is applied
    to every element before the per-sample sum. Computing ``p_t`` from the
    already-summed BCE (hundreds of nats for an image) underflows to 0, which
    makes the factor 1 and turns the loss back into plain BCE.

    Args:
        recon: reconstructed batch with values in [0, 1].
        x: target batch with the same shape as recon and values in [0, 1].
        gamma: focusing exponent; 0 recovers the plain summed BCE.

    Returns:
        Tensor of shape (batch,) with the summed focal loss of each sample.
    """
    bce = nn.functional.binary_cross_entropy(recon, x, reduction='none')
    p_t = torch.exp(-bce)
    focal = (1 - p_t) ** gamma * bce
    return focal.reshape(recon.size(0), -1).sum(dim=1)

# -----------------------------
# Training Functions
# -----------------------------
def _unit_range_targets(x):
    """Return x rescaled to [0, 1] when it contains negative values, else x unchanged."""
    if x.numel() > 0 and x.min() < 0:
        # Assumes negative values mean the batch was normalised to [-1, 1]
        # (mean 0.5 / std 0.5), as the MNIST loader in this file does.
        return ((x + 1.0) / 2.0).clamp(0.0, 1.0)
    return x


def train_vae(model, dataloader, optimizer, device, alpha=None, baseline='none'):
    """
    Train the VAE for one pass over dataloader.

    Args:
        model: module returning (reconstruction, mu, logvar), with the
            reconstruction in [0, 1].
        dataloader: iterable of (images, labels) batches.
        optimizer: optimizer over the model parameters.
        device: device the batches are moved to.
        alpha: if None the reconstruction term is the batch mean; otherwise it
            is the CVaR of the per-sample reconstruction losses at level alpha.
        baseline: 'focal' uses the focal reconstruction loss; 'reweight' scales
            each sample by the inverse in-batch frequency of its label; any
            other value uses plain summed binary cross-entropy.

    Returns:
        Mean total loss over the processed batches (0.0 if there were none).
    """
    model.train()
    running_loss = 0.0
    n_batches = 0
    for x, labels in dataloader:
        x = x.to(device)
        # Binary cross-entropy is only meaningful for targets in [0, 1]; the
        # encoder still sees the batch exactly as the loader produced it.
        target = _unit_range_targets(x)
        optimizer.zero_grad()
        recon, mu, logvar = model(x)
        bs = x.size(0)
        if baseline == 'focal':
            recon_losses = focal_recon_loss(recon, target)
        else:
            recon_losses = nn.functional.binary_cross_entropy(
                recon, target, reduction='none').reshape(bs, -1).sum(dim=1)

        if baseline == 'reweight':
            labels = labels.to(device)
            class_counts = torch.bincount(labels, minlength=10).float()
            weights = 1.0 / class_counts[labels]
            # Normalise to mean 1 so reweighting changes the balance between
            # classes without shrinking the reconstruction term against the KL term.
            weights = weights / weights.mean()
            recon_losses = recon_losses * weights

        if alpha is None:
            recon_term = recon_losses.mean()
        else:
            recon_term = cvar_loss(recon_losses, alpha)
        kl_term = kl_divergence(mu, logvar).mean()
        loss = recon_term + kl_term
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        n_batches += 1
    return running_loss / n_batches if n_batches else 0.0

def _infer_latent_dim(generator, default=100):
    """Return the input channel count of the generator's first ConvTranspose2d, or default."""
    if isinstance(generator, nn.Module):
        for module in generator.modules():
            if isinstance(module, nn.ConvTranspose2d):
                return module.in_channels
    return default


def train_gan(generator, discriminator, dataloader, optim_G, optim_D, device, alpha=None,
              latent_dim=None):
    """
    Train a hinge-loss GAN for one pass over dataloader.

    Args:
        generator: maps (batch, latent_dim) noise to images.
        discriminator: maps images to one unbounded score per sample.
        dataloader: iterable of (images, _) batches.
        optim_G, optim_D: optimizers for the generator and the discriminator.
        device: device the batches and noise live on.
        alpha: if None the generator loss is the batch mean of -D(G(z));
            otherwise it is the CVaR of those per-sample losses at level alpha.
        latent_dim: size of the noise vectors. If None it is read from the
            generator's first transposed convolution, falling back to 100.

    Returns:
        (mean discriminator loss, mean generator loss) over the processed
        batches, or (0.0, 0.0) if there were none.
    """
    if latent_dim is None:
        # A fixed 100 breaks any generator built with a different latent size.
        latent_dim = _infer_latent_dim(generator)

    hinge = nn.ReLU()
    generator.train()
    discriminator.train()
    d_total, g_total, n_batches = 0.0, 0.0, 0
    for x, _ in dataloader:
        x = x.to(device)
        bs = x.size(0)

        # Discriminator step: hinge loss on real and generated samples.
        optim_D.zero_grad()
        real_scores = discriminator(x)
        z = torch.randn(bs, latent_dim, device=device)
        with torch.no_grad():
            # No generator graph is needed while updating the discriminator.
            fake = generator(z)
        fake_scores = discriminator(fake)
        d_loss = torch.mean(hinge(1.0 - real_scores)) + torch.mean(hinge(1.0 + fake_scores))
        d_loss.backward()
        optim_D.step()

        # Generator step: raise the discriminator score of fresh samples.
        optim_G.zero_grad()
        z = torch.randn(bs, latent_dim, device=device)
        fake = generator(z)
        gen_losses = -discriminator(fake)
        if alpha is None:
            g_loss = gen_losses.mean()
        else:
            g_loss = cvar_loss(gen_losses, alpha)
        g_loss.backward()
        optim_G.step()

        d_total += d_loss.item()
        g_total += g_loss.item()
        n_batches += 1

    if n_batches == 0:
        return 0.0, 0.0
    return d_total / n_batches, g_total / n_batches

# -----------------------------
# Evaluation
# -----------------------------
def _to_uint8_rgb(images, value_range=None):
    """
    Convert an image batch to uint8 with three channels.

    Args:
        images: tensor of shape (N, C, H, W); uint8 input is only channel-expanded.
        value_range: (low, high) of the float input. If None, a batch with
            negative values is treated as [-1, 1] and any other batch as
            [0, 1] (a heuristic; pass the range explicitly when it is known).
    """
    if images.dtype != torch.uint8:
        images = images.detach().float()
        if value_range is None:
            has_negative = images.numel() > 0 and bool(images.min() < 0)
            value_range = (-1.0, 1.0) if has_negative else (0.0, 1.0)
        low, high = value_range
        images = ((images - low) / (high - low)).clamp(0.0, 1.0)
        images = (images * 255.0).round().to(torch.uint8)
    if images.size(1) == 1:
        # Replicate grayscale to three channels.
        images = images.expand(-1, 3, -1, -1).contiguous()
    return images


def _feed_inception_metric(metric, generator, real_loader, device, latent_dim,
                           real_range, fake_range):
    """Update metric with every real batch and an equally sized generated batch."""
    module = generator if isinstance(generator, nn.Module) else None
    was_training = module is not None and module.training
    if module is not None:
        # Sample with fixed batch-norm statistics and leave them unmodified.
        module.eval()
    try:
        # One pass: the real loader only supplies the batch sizes for the fakes.
        for x, _ in real_loader:
            metric.update(_to_uint8_rgb(x.to(device), real_range), real=True)
            z = torch.randn(x.size(0), latent_dim, device=device)
            metric.update(_to_uint8_rgb(generator(z), fake_range), real=False)
    finally:
        if was_training:
            module.train()


@torch.no_grad()
def compute_fid(generator, real_loader, device, latent_dim=100,
                real_range=None, fake_range=None):
    """
    Frechet Inception Distance between real_loader and generator samples.

    Args:
        generator: callable mapping (batch, latent_dim) noise to images.
        real_loader: iterable of (images, _) batches of real data.
        device: device used for the metric and the sampled noise.
        latent_dim: size of the sampled noise vectors.
        real_range, fake_range: optional (low, high) value ranges of the real
            and generated images; inferred per batch when None.

    Returns:
        The FID as a Python float.
    """
    # torchmetrics' Inception-based metrics take uint8 (N, 3, H, W) images by
    # default, so float / single-channel batches are converted first.
    # NOTE(review): newer torchmetrics releases name this metric
    # FrechetInceptionDistance - confirm the installed version still exports FID.
    fid = FID().to(device)
    _feed_inception_metric(fid, generator, real_loader, device, latent_dim,
                           real_range, fake_range)
    return fid.compute().item()


@torch.no_grad()
def compute_kid(generator, real_loader, device, latent_dim=100,
                real_range=None, fake_range=None):
    """
    Kernel Inception Distance between real_loader and generator samples.

    Args:
        generator: callable mapping (batch, latent_dim) noise to images.
        real_loader: iterable of (images, _) batches of real data.
        device: device used for the metric and the sampled noise.
        latent_dim: size of the sampled noise vectors.
        real_range, fake_range: optional (low, high) value ranges of the real
            and generated images; inferred per batch when None.

    Returns:
        The mean KID over the metric's subsets as a Python float.
    """
    # NOTE(review): the metric's default subset size may exceed small datasets -
    # confirm against the installed torchmetrics version.
    kid = KID().to(device)
    _feed_inception_metric(kid, generator, real_loader, device, latent_dim,
                           real_range, fake_range)
    # compute() returns (mean, std); calling .item() on the tuple would fail.
    kid_mean, _kid_std = kid.compute()
    return kid_mean.item()

@torch.no_grad()
def compute_coverage(generator, classifier, device, minority_labels, num_samples=10000,
                     batch_size=None, latent_dim=100):
    """
    Fraction of generated samples that the classifier assigns to a minority label.

    Args:
        generator: callable mapping (batch, latent_dim) noise to images.
        classifier: nn.Module returning per-class scores of shape (batch, classes).
        device: device used for the sampled noise.
        minority_labels: iterable of class indices counted as minority.
        num_samples: total number of samples to generate (> 0).
        batch_size: generation batch size. If None, ``classifier.batch_size``
            is used when that attribute exists, otherwise 128.
        latent_dim: size of the sampled noise vectors.

    Returns:
        hits / num_samples as a float.

    Raises:
        ValueError: if num_samples or batch_size is not positive.
    """
    if num_samples <= 0:
        raise ValueError(f"num_samples must be positive, got {num_samples}")
    if batch_size is None:
        # A stock classifier has no batch_size attribute, so fall back to a default.
        batch_size = getattr(classifier, 'batch_size', 128)
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    classifier.eval()
    gen_module = generator if isinstance(generator, nn.Module) else None
    gen_was_training = gen_module is not None and gen_module.training
    if gen_module is not None:
        gen_module.eval()

    minority = torch.as_tensor(list(minority_labels), dtype=torch.long, device=device)
    hits = 0
    remaining = num_samples
    try:
        while remaining > 0:
            # The last batch may be smaller so exactly num_samples are drawn.
            current = min(batch_size, remaining)
            z = torch.randn(current, latent_dim, device=device)
            preds = classifier(generator(z)).argmax(dim=1)
            is_minority = (preds.unsqueeze(1) == minority.unsqueeze(0)).any(dim=1)
            hits += int(is_minority.sum().item())
            remaining -= current
    finally:
        if gen_was_training:
            gen_module.train()
    return hits / num_samples

# -----------------------------
# Main Experiment
# -----------------------------
def main(args):
    """
    Run the MNIST VAE and CelebA GAN experiments for every CVaR level and print a summary.

    For each alpha in ``[None] + args.cvar_levels`` a fresh model is trained for
    args.epochs epochs and scored with FID and KID (plus minority coverage for
    the VAE). ``alpha=None`` is the plain mean-loss run.

    Args:
        args: namespace with imbalance_ratio, rare_ratio, batch_size,
            latent_dim, lr, epochs, baseline, cvar_levels, attr1 and attr2.

    Raises:
        ValueError: if args.latent_dim is not 100.
    """
    # The sampling code called below draws 100-dimensional noise by default, so
    # any other latent size would only fail after a full training run.
    sampled_latent_dim = 100
    if args.latent_dim != sampled_latent_dim:
        raise ValueError(
            f"latent_dim must be {sampled_latent_dim}, got {args.latent_dim}: "
            "the sampling code used below draws 100-dimensional noise"
        )

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # NOTE(review): the argument parser in this file never defines
    # 'imbal_balance_ratio'; the lookup is kept in case an external caller
    # passes it - confirm and remove.
    if hasattr(args, 'imbal_balance_ratio'):
        imbalance_ratio = args.imbal_balance_ratio
    else:
        imbalance_ratio = args.imbalance_ratio
    use_oversampling = args.baseline == 'oversample'

    train_loader, test_loader = get_imbalanced_mnist(
        imbalance_ratio=imbalance_ratio,
        batch_size=args.batch_size,
        baseline=args.baseline,
        oversample=use_oversampling
    )

    classifier = resnet18(num_classes=10)
    # The VAE decodes single-channel images, so the three-channel input
    # convolution is replaced by a one-channel layer of the same geometry.
    classifier.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    classifier = classifier.to(device)
    # compute_coverage can read its sampling batch size from the classifier.
    classifier.batch_size = args.batch_size
    # TODO: load or train classifier here

    alpha_levels = [None] + list(args.cvar_levels)
    results = {}

    # VAE loop over alpha levels
    for alpha in alpha_levels:
        vae = VAE(latent_dim=args.latent_dim).to(device)
        opt = optim.Adam(vae.parameters(), lr=args.lr)
        for _ in range(args.epochs):
            train_vae(vae, train_loader, opt, device, alpha=alpha, baseline=args.baseline)
        fid_score = compute_fid(vae.decode, test_loader, device)
        kid_score = compute_kid(vae.decode, test_loader, device)
        coverage = compute_coverage(vae.decode, classifier, device, minority_labels=list(range(5, 10)))
        results[f'VAE_alpha_{alpha}'] = {'FID': fid_score, 'KID': kid_score, 'Coverage': coverage}

    # GAN loop over alpha levels
    celeba_loader = get_celeba_rare(
        args.attr1, args.attr2,
        rare_ratio=args.rare_ratio,
        batch_size=args.batch_size,
        oversample=use_oversampling
    )
    for alpha in alpha_levels:
        G = Generator(latent_dim=args.latent_dim, channels=3).to(device)
        D = Discriminator(channels=3).to(device)
        optG = optim.Adam(G.parameters(), lr=args.lr, betas=(0.5, 0.999))
        optD = optim.Adam(D.parameters(), lr=args.lr, betas=(0.5, 0.999))
        for _ in range(args.epochs):
            train_gan(G, D, celeba_loader, optG, optD, device, alpha=alpha)
        fid_score = compute_fid(G, celeba_loader, device)
        kid_score = compute_kid(G, celeba_loader, device)
        coverage = None  # placeholder for attribute coverage
        results[f'GAN_alpha_{alpha}'] = {'FID': fid_score, 'KID': kid_score, 'Coverage': coverage}

    # Summary printout
    for key, val in results.items():
        print(f"{key}: FID={val['FID']:.2f}, KID={val['KID']:.4f}, Coverage={val['Coverage']}")

if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them, run the experiment.
    parser = ArgumentParser()

    # Plain numeric options, registered in the order they should appear in --help.
    numeric_options = (
        ('--imbalance_ratio', float, 0.2),
        ('--rare_ratio', float, 0.1),
        ('--batch_size', int, 128),
        ('--latent_dim', int, 100),
        ('--lr', float, 2e-4),
        ('--epochs', int, 50),
    )
    for flag, value_type, default_value in numeric_options:
        parser.add_argument(flag, type=value_type, default=default_value)

    parser.add_argument('--baseline', default='none',
                        choices=['none', 'reweight', 'focal', 'oversample'])
    parser.add_argument('--cvar_levels', type=float, nargs='+', default=[0.9, 0.95, 0.99])

    # CelebA attribute pair that defines the rare group.
    for flag, default_value in (('--attr1', 'Wearing_Hat'), ('--attr2', 'Smiling')):
        parser.add_argument(flag, type=str, default=default_value)

    args = parser.parse_args()
    main(args)

