In [None]:
# Install the third-party packages imported later in this notebook (-q: quiet output).
# NOTE: the leading `!` is IPython/Jupyter shell-escape syntax, not Python.
!pip install wandb torch torchvision tqdm matplotlib numpy -q
print("Dependencies installed!")

In [None]:
# Authenticate this session with Weights & Biases; the training functions
# later log every run through wandb.init / wandb.log.
import wandb

_wandb_login_ok = wandb.login()
print("WandB logged in!")

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import numpy as np
import copy
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import warnings
# Silence every Python warning for the rest of the session to keep notebook output clean.
# NOTE(review): this also hides library warnings that may be diagnostically useful
# (e.g. about device or DataLoader configuration) — consider narrowing it with
# `category=` / `module=` filters.
warnings.filterwarnings('ignore')

# GPU configuration: decide the device once; the helpers below read DEVICE / NUM_GPUS.
_HAS_CUDA = torch.cuda.is_available()
DEVICE = 'cuda' if _HAS_CUDA else 'cpu'
NUM_GPUS = torch.cuda.device_count()

if _HAS_CUDA:
    # Let cuDNN benchmark conv algorithms and pick the fastest for the shapes seen.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    torch.cuda.empty_cache()

print(f"Using device: {DEVICE} | GPUs available: {NUM_GPUS}")

# ========== AGGRESSIVE GPU UTILIZATION ==========
# Large batches sized for 2x T4 (15GB each); halved when fewer than two GPUs are visible.
import os  # only needed here, to size the DataLoader worker pool

# DataLoader worker processes per loader, capped at the CPU count: more workers
# than cores only adds process overhead, and many loaders in this notebook use
# persistent_workers=True, so their workers stay alive at the same time.
NUM_WORKERS = min(8, os.cpu_count() or 1)
PIN_MEMORY = True
BATCH_SIZE = 1024 if NUM_GPUS > 1 else 512  # client-side local training batch
VAL_BATCH_SIZE = 2048 if NUM_GPUS > 1 else 1024  # test / validation / per-client evaluation batch
PREFETCH_FACTOR = 4  # batches pre-loaded per worker
GAN_BATCH_SIZE = 512  # NOTE(review): not referenced by the code shown here; the GAN trains on VAL_LOADER batches
N_PROBES = 2000  # generated probe pairs used to estimate model bias each round

# Experiment settings
N_ROUNDS = 50  # federated communication rounds
N_CLIENTS = 20

# CIFAR-10 specific
IMG_SHAPE = (3, 32, 32)  # (channels, height, width)
NUM_CLASSES = 10

# Pre-load CIFAR-10 dataset
print("Pre-loading CIFAR-10 dataset...")

# Per-channel CIFAR-10 statistics shared by both pipelines.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)

# Training pipeline adds random horizontal flips and 4-pixel-padded random crops;
# the test pipeline only converts to a tensor and standardizes.
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

_DATA_ROOT = './data'
TRAIN_DATA = datasets.CIFAR10(_DATA_ROOT, train=True, download=True, transform=transform_train)
TEST_DATA = datasets.CIFAR10(_DATA_ROOT, train=False, download=True, transform=transform_test)

# Non-IID partition: each client gets only 2-3 classes (more extreme heterogeneity)
def partition_non_iid_extreme(dataset, n_clients, classes_per_client=2, seed=42, num_classes=None):
    """Extreme non-IID split: each client gets only a few classes.

    Each client is assigned ``classes_per_client`` or ``classes_per_client + 1``
    distinct random classes. For every assigned class it receives a contiguous
    random window of that class's indices of length
    ``len(class) // max(1, n_clients // 2)``. Windows of different clients may
    overlap, so shards are not guaranteed to be disjoint.

    Seeds NumPy's *global* RNG with ``seed``. Later global ``np.random`` calls
    therefore stay reproducible as well.

    Args:
        dataset: indexable dataset of ``(input, label)`` pairs. If it exposes a
            ``targets`` sequence (as torchvision's CIFAR10 does), labels are read
            from it instead of loading and transforming every sample.
        n_clients: number of client shards to produce.
        classes_per_client: minimum number of classes per client.
        seed: seed for NumPy's global RNG.
        num_classes: number of label classes; defaults to ``NUM_CLASSES``.

    Returns:
        ``(client_data, classes_assigned)``: a list with one int64 index array
        per client, and the list of class ids assigned to each client.
    """
    if num_classes is None:
        num_classes = NUM_CLASSES
    np.random.seed(seed)

    targets = getattr(dataset, 'targets', None)
    if targets is not None:
        # Cheap path: no image decoding or random augmentation per sample.
        labels = np.asarray(targets)
    else:
        labels = np.array([dataset[i][1] for i in range(len(dataset))])

    # Group indices by class
    class_indices = {c: np.where(labels == c)[0] for c in range(num_classes)}

    # Each class is cut into windows of 1 / (n_clients // 2) of its size; the
    # max() avoids a division by zero when there is a single client.
    divisor = max(1, n_clients // 2)
    # Exclusive upper bound for the class count, never more than num_classes.
    n_classes_high = min(classes_per_client + 2, num_classes + 1)

    client_data = []
    classes_assigned = []
    for _ in range(n_clients):
        n_classes = np.random.randint(classes_per_client, n_classes_high)
        client_classes = np.random.choice(num_classes, n_classes, replace=False)
        classes_assigned.append(client_classes.tolist())

        client_idx = []
        for c in client_classes:
            pool = class_indices[c]
            n_samples = len(pool) // divisor
            start = np.random.randint(0, max(1, len(pool) - n_samples))
            client_idx.extend(pool[start:start + n_samples])

        np.random.shuffle(client_idx)
        # Explicit dtype so an empty shard is still a valid integer index array.
        client_data.append(np.array(client_idx, dtype=np.int64))

    return client_data, classes_assigned

CLIENT_IDX, CLIENT_CLASSES = partition_non_iid_extreme(TRAIN_DATA, N_CLIENTS)
print(f"Client data sizes: {[len(idx) for idx in CLIENT_IDX]}")
print(f"Client classes: {CLIENT_CLASSES}")

# Evaluation view of the training images: same files as TRAIN_DATA, but with the
# deterministic test transform. Validation and per-client evaluation must not use
# transform_train's random flips/crops. compute_acc_score compares the loss of two
# passes over VAL_LOADER, which is only meaningful if both passes see identical
# inputs. Augmentation would also add noise to the per-client accuracies that the
# fairness metrics are computed from.
TRAIN_EVAL_DATA = datasets.CIFAR10('./data', train=True, download=True, transform=transform_test)

# Pre-create DataLoaders
TEST_LOADER = DataLoader(TEST_DATA, batch_size=VAL_BATCH_SIZE, shuffle=False,
                         num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY,
                         persistent_workers=True, prefetch_factor=PREFETCH_FACTOR)

# 2000 random training indices used as the server-side validation / GAN-audit set
# (they may overlap client shards). Drawn from NumPy's global RNG, which the
# partition call above seeds.
VAL_IDX = np.random.choice(len(TRAIN_EVAL_DATA), 2000, replace=False)
VAL_LOADER = DataLoader(Subset(TRAIN_EVAL_DATA, VAL_IDX), batch_size=VAL_BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY,
                        persistent_workers=True, prefetch_factor=PREFETCH_FACTOR)

# Client evaluation loaders: one per client, over that client's own (un-augmented) shard.
CLIENT_EVAL_LOADERS = [
    DataLoader(Subset(TRAIN_EVAL_DATA, CLIENT_IDX[cid]), batch_size=VAL_BATCH_SIZE, shuffle=False,
               num_workers=4, pin_memory=PIN_MEMORY, persistent_workers=True)
    for cid in range(N_CLIENTS)
]

# ========== LARGER CIFAR-10 CNN Model (uses more GPU memory) ==========
class CNN(nn.Module):
    """Plain four-block CNN classifier for 32x32 CIFAR-10 images.

    Each block is conv3x3 -> BatchNorm -> ReLU -> 2x2 max-pool. The blocks widen
    the channels 3 -> 128 -> 256 -> 512 -> 512 and halve the spatial size
    (32 -> 16 -> 8 -> 4 -> 2). The resulting 512x2x2 map is flattened into a
    two-layer MLP head with dropout 0.5, which outputs ``num_classes`` logits.
    There are no residual/skip connections. The head size assumes 32x32 input.
    """
    def __init__(self, num_classes=10):
        super(CNN, self).__init__()
        # Wide channels, chosen to make use of large-GPU memory
        self.conv1 = nn.Conv2d(3, 128, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(128)
        self.conv2 = nn.Conv2d(128, 256, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(256)
        self.conv3 = nn.Conv2d(256, 512, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(512)
        self.conv4 = nn.Conv2d(512, 512, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(512)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.5)
        self.fc1 = nn.Linear(512 * 2 * 2, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.bn1(self.conv1(x))))  # 32->16
        x = self.pool(F.relu(self.bn2(self.conv2(x))))  # 16->8
        x = self.pool(F.relu(self.bn3(self.conv3(x))))  # 8->4
        x = self.pool(F.relu(self.bn4(self.conv4(x))))  # 4->2
        # flatten(x, 1) keeps the batch dimension. A wrongly sized input now fails
        # loudly in fc1 instead of being silently re-batched by view(-1, 2048).
        x = torch.flatten(x, 1)
        x = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(x)

# ========== LARGER Fairness Generator for CIFAR-10 ==========
class FairnessGenerator(nn.Module):
    """Label-conditioned generator that emits paired images ``(x, x_prime)``.

    ``x`` comes from an upsampling conv stack fed with ``[z, label_embedding]``.
    ``x_prime`` is ``x`` plus a label-independent perturbation, computed from
    ``z`` alone and bounded by ``delta_scale`` per pixel, then clamped to [-1, 1].

    NOTE(review): both outputs are Tanh/clamp-bounded to [-1, 1]. Real inputs are
    mean/std-normalized with CIFAR-10 statistics, which puts them at roughly
    [-2, 2.1]. Confirm this range mismatch is intended.
    """
    def __init__(self, latent_dim=128, num_classes=10, img_shape=(3, 32, 32)):
        super().__init__()
        self.latent_dim, self.num_classes, self.img_shape = latent_dim, num_classes, img_shape
        self.label_emb = nn.Embedding(num_classes, latent_dim)
        # Seed feature map is a quarter of the output height (8 for 32x32); the
        # two 2x upsamplings below restore the full size.
        self.init_size = img_shape[1] // 4
        self.l1 = nn.Linear(2 * latent_dim, 512 * self.init_size * self.init_size)
        out_channels = img_shape[0]
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(512),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(512, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(512, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, out_channels, 3, 1, 1),
            nn.Tanh(),
        )
        pixels = int(np.prod(img_shape))
        self.delta_net = nn.Sequential(
            nn.Linear(latent_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, pixels),
            nn.Tanh(),
        )
        # Maximum per-pixel magnitude of the perturbation (Tanh output times this).
        self.delta_scale = 0.1

    def forward(self, z, labels):
        cond = self.label_emb(labels)
        seed_map = self.l1(torch.cat((z, cond), dim=1))
        seed_map = seed_map.view(-1, 512, self.init_size, self.init_size)
        base = self.conv_blocks(seed_map)
        shift = self.delta_scale * self.delta_net(z).view(-1, *self.img_shape)
        paired = (base + shift).clamp(-1, 1)
        return base, paired

# ========== LARGER Discriminator for CIFAR-10 ==========
class Discriminator(nn.Module):
    """Label-conditioned discriminator returning one real/fake logit per image.

    The label is embedded into a ``num_classes``-dim vector. That vector is
    broadcast to a constant ``num_classes``-channel map over the image and
    concatenated with the image channels. Four stride-2 3x3 convolutions
    (128 -> 256 -> 512 -> 512 channels) then downsample the result. A linear
    layer maps the flattened features to a single logit; there is no sigmoid,
    so pair it with ``BCEWithLogitsLoss``.
    """
    # Number of stride-2 convolutions in ``self.conv``.
    _N_DOWNSAMPLE = 4

    def __init__(self, num_classes=10, img_shape=(3, 32, 32)):
        super().__init__()
        self.num_classes = num_classes
        self.img_shape = img_shape
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.conv = nn.Sequential(
            nn.Conv2d(img_shape[0] + num_classes, 128, 3, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 3, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 512, 3, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            nn.Conv2d(512, 512, 3, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2)
        )
        # Derive the head's input size from img_shape so non-32x32 images work.
        # For 32x32 this is 512 * 2 * 2, the previously hard-coded value.
        out_h = self._downsampled_size(img_shape[1])
        out_w = self._downsampled_size(img_shape[2])
        self.fc = nn.Linear(512 * out_h * out_w, 1)  # No sigmoid for BCEWithLogitsLoss

    @classmethod
    def _downsampled_size(cls, size):
        """Spatial size after the kernel-3, stride-2, padding-1 convolutions."""
        for _ in range(cls._N_DOWNSAMPLE):
            size = (size + 2 * 1 - 3) // 2 + 1
        return size

    def forward(self, img, labels):
        label_map = self.label_emb(labels).view(-1, self.num_classes, 1, 1)
        label_map = label_map.expand(-1, -1, self.img_shape[1], self.img_shape[2])
        out = self.conv(torch.cat([img, label_map], dim=1))
        return self.fc(out.view(out.size(0), -1))

# ========== Helper Functions ==========
def make_parallel(model):
    """Wrap ``model`` in ``nn.DataParallel`` when more than one GPU is visible; otherwise return it as-is."""
    return nn.DataParallel(model) if NUM_GPUS > 1 else model

def get_base_model(model):
    """Return the module inside an ``nn.DataParallel`` wrapper, or ``model`` itself when unwrapped."""
    return model.module if isinstance(model, nn.DataParallel) else model

# ========== FAIRNESS METRICS ==========
def compute_jfi(performances):
    """Jain's fairness index ``(sum p)^2 / (n * sum p^2)``; 1.0 when every value is 0."""
    vals = np.array(performances)
    sq_total = np.sum(vals ** 2)
    if sq_total == 0:
        return 1.0
    return np.sum(vals) ** 2 / (len(vals) * sq_total)

def compute_max_min_fairness(performances):
    """Ratio of the smallest to the largest value; 0.0 when the largest is 0."""
    vals = np.array(performances)
    top = np.max(vals)
    return 0.0 if top == 0 else np.min(vals) / top

def compute_fairness_variance(performances):
    """Population variance of the values (lower means more uniform)."""
    return np.array(performances).var()

@torch.no_grad()
def evaluate_per_client(model, client_loaders):
    """Return the top-1 accuracy (percent) of ``model`` on each loader, in order.

    A loader that yields no samples scores 0. ``model`` is left in eval mode.
    """
    model.eval()
    scores = []
    with torch.amp.autocast(device_type='cuda'):
        for client_loader in client_loaders:
            hits = seen = 0
            for inputs, targets in client_loader:
                inputs = inputs.to(DEVICE, non_blocking=True)
                targets = targets.to(DEVICE, non_blocking=True)
                hits += (model(inputs).argmax(1) == targets).sum().item()
                seen += len(targets)
            scores.append(0 if seen <= 0 else 100 * hits / seen)
    return scores

def compute_all_fairness_metrics(client_accuracies):
    """Bundle the fairness metrics and accuracy extremes for one round."""
    accs = client_accuracies
    metrics = {}
    metrics['jfi'] = compute_jfi(accs)
    metrics['max_min_fairness'] = compute_max_min_fairness(accs)
    metrics['fairness_variance'] = compute_fairness_variance(accs)
    metrics['min_accuracy'] = np.min(accs)
    metrics['max_accuracy'] = np.max(accs)
    metrics['mean_accuracy'] = np.mean(accs)
    return metrics

# ========== Training Functions ==========
def train_fairness_gan(G, D, model, loader, epochs=15, alpha=1.0, beta=1.0):
    """Train the conditional generator/discriminator pair that audits ``model``.

    G produces paired images ``(x, x_prime)``. Its loss combines three terms:
    (a) ``-beta * mean((model(x) - model(x_prime))^2)``, which rewards pairs on
    which the classifier's logits disagree; (b) ``alpha * mean((x - x_prime)^2)``,
    which keeps the pair close in pixel space; and (c) a non-saturating GAN loss
    against D. D learns to separate real ``(image, label)`` batches from
    ``loader`` from generated ones.

    ``model`` is a frozen auditee. Its weights are never updated, but gradients
    must flow *through* it to G for term (a). Its parameters' ``requires_grad``
    flags are therefore switched off for the duration and restored afterwards,
    even on error. It is left in eval mode.
    NOTE(review): backpropagating through ``model`` keeps its activations for two
    forward passes per batch, so GPU memory use is higher than a no-grad forward.

    Args:
        G: generator exposing ``latent_dim`` and ``num_classes`` and called as
            ``G(z, labels) -> (x, x_prime)``.
        D: discriminator called as ``D(images, labels) -> logits``.
        model: classifier being audited.
        loader: iterable of real ``(imgs, labels)`` batches.
        epochs: passes over ``loader``.
        alpha: weight of the pixel-closeness term.
        beta: weight of the prediction-gap term.

    Returns:
        ``(G, D)`` unwrapped from any ``nn.DataParallel`` wrapper.
    """
    model.eval()
    G = make_parallel(G)
    D = make_parallel(D)

    opt_G = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
    opt_D = optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
    bce = nn.BCEWithLogitsLoss()
    # One scaler for both optimizers: each is stepped, then update() once per iteration.
    scaler = torch.amp.GradScaler(device='cuda')

    latent_dim = get_base_model(G).latent_dim
    num_classes = get_base_model(G).num_classes

    # Freeze the auditee's weights (but not its inputs) and remember the old flags.
    frozen = [(p, p.requires_grad) for p in model.parameters()]
    for p, _ in frozen:
        p.requires_grad_(False)

    try:
        for _ in range(epochs):
            for imgs, labels in loader:
                batch_size = imgs.size(0)
                real_t = torch.ones(batch_size, 1, device=DEVICE)
                fake_t = torch.zeros(batch_size, 1, device=DEVICE)
                imgs = imgs.to(DEVICE, non_blocking=True)
                labels = labels.to(DEVICE, non_blocking=True)
                z = torch.randn(batch_size, latent_dim, device=DEVICE)
                gen_labels = torch.randint(0, num_classes, (batch_size,), device=DEVICE)

                # Generator
                opt_G.zero_grad(set_to_none=True)
                with torch.amp.autocast(device_type='cuda'):
                    x, x_prime = G(z, gen_labels)
                    # No torch.no_grad() here: the prediction-gap term must
                    # backpropagate through the (frozen) classifier into G.
                    pred_x, pred_xp = model(x), model(x_prime)
                    pred_diff = -beta * torch.mean((pred_x - pred_xp) ** 2)
                    realism = alpha * torch.mean((x - x_prime) ** 2)
                    gan_loss = (bce(D(x, gen_labels), real_t) + bce(D(x_prime, gen_labels), real_t)) / 2
                    g_loss = pred_diff + realism + gan_loss
                scaler.scale(g_loss).backward()
                scaler.step(opt_G)

                # Discriminator (gradients that the G step left in D are cleared here)
                opt_D.zero_grad(set_to_none=True)
                with torch.amp.autocast(device_type='cuda'):
                    # Fakes from the just-updated G. D only sees detached samples,
                    # so no autograd graph is needed for this forward.
                    with torch.no_grad():
                        x, x_prime = G(z, gen_labels)
                    d_real = bce(D(imgs, labels), real_t)
                    d_fake = (bce(D(x, gen_labels), fake_t) + bce(D(x_prime, gen_labels), fake_t)) / 2
                    d_loss = (d_real + d_fake) / 2
                scaler.scale(d_loss).backward()
                scaler.step(opt_D)
                scaler.update()
    finally:
        # Restore trainability: the global model is later deep-copied for client training.
        for p, flag in frozen:
            p.requires_grad_(flag)

    return get_base_model(G), get_base_model(D)

@torch.no_grad()
def compute_bias(model, x, x_prime):
    """Mean L1 distance between ``model``'s logits on ``x`` and on ``x_prime``.

    Per sample, the absolute logit differences are summed across classes; the
    result is averaged over the batch and returned as a Python float.
    ``model`` is left in eval mode.
    """
    model.eval()
    with torch.amp.autocast(device_type='cuda'):
        logits_a = model(x)
        logits_b = model(x_prime)
    per_sample_gap = (logits_a - logits_b).abs().sum(dim=1)
    return per_sample_gap.mean().item()

@torch.no_grad()
def compute_acc_score(model, update, val_loader):
    """Score a client update by the validation-loss reduction it would cause.

    The function evaluates ``model`` on ``val_loader`` and applies ``update`` (a
    state-dict delta) to a deep copy of the unwrapped model. It then evaluates
    that copy and returns ``mean_loss_before - mean_loss_after``, using mean
    cross-entropy per sample. Positive values mean the update lowers validation
    loss. ``model`` itself is only switched to eval mode, never modified.

    Keys in ``update`` may carry or lack the ``'module.'`` prefix added by
    ``nn.DataParallel``; both forms are matched. State entries with no matching
    key are left unchanged.
    NOTE(review): both passes must see identical inputs for the difference to be
    meaningful, so ``val_loader`` should not apply random augmentation.

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    def _mean_loss(net):
        # Mean per-sample cross-entropy of ``net`` over val_loader (eval mode).
        net.eval()
        loss_sum, n_seen = 0.0, 0
        with torch.amp.autocast(device_type='cuda'):
            for d, t in val_loader:
                d, t = d.to(DEVICE, non_blocking=True), t.to(DEVICE, non_blocking=True)
                loss_sum += F.cross_entropy(net(d), t, reduction='sum').item()
                n_seen += len(t)
        if n_seen == 0:
            raise ValueError("compute_acc_score: val_loader yielded no samples")
        return loss_sum / n_seen

    prefix = 'module.'

    def _update_key(k):
        # Key in ``update`` matching state key ``k`` (with/without the prefix), or None.
        if k in update:
            return k
        alt = k[len(prefix):] if k.startswith(prefix) else prefix + k
        return alt if alt in update else None

    loss_before = _mean_loss(model)

    hyp = copy.deepcopy(get_base_model(model))
    sd = hyp.state_dict()
    for k in sd:
        uk = _update_key(k)
        if uk is not None:
            sd[k] = sd[k] + update[uk]
    hyp.load_state_dict(sd)
    hyp = make_parallel(hyp.to(DEVICE))

    return loss_before - _mean_loss(hyp)

@torch.no_grad()
def evaluate(model, loader):
    """Return the top-1 accuracy of ``model`` on ``loader`` as a percentage.

    Puts ``model`` in eval mode. A loader that yields no samples scores 0.0,
    as in ``evaluate_per_client``, instead of raising ZeroDivisionError.
    """
    model.eval()
    correct, total = 0, 0
    with torch.amp.autocast(device_type='cuda'):
        for d, t in loader:
            d, t = d.to(DEVICE, non_blocking=True), t.to(DEVICE, non_blocking=True)
            correct += (model(d).argmax(1) == t).sum().item()
            total += len(t)
    return 100 * correct / total if total > 0 else 0.0

def normalize(scores):
    """Min-max scale ``scores`` into [0, 1] as a NumPy array.

    When the spread (max - min) is below 1e-8, every entry becomes ``1 / len``.
    NOTE(review): that degenerate value is not on the [0, 1] min-max scale. When
    it is mixed with another normalized score, this criterion is effectively
    down-weighted by about a factor of len(scores). Confirm that is intended.
    """
    arr = np.array(scores)
    lo, hi = arr.min(), arr.max()
    spread = hi - lo
    if spread < 1e-8:
        return np.ones_like(arr) / len(arr)
    return (arr - lo) / spread

print(f"\n{'='*60}")
print("CIFAR-10 Fed-AuditGAN Experiment - AGGRESSIVE GPU MODE")
print(f"{'='*60}")
# Report the GPUs actually detected instead of assuming a fixed model and memory size.
if NUM_GPUS > 0:
    _gpu_desc = ", ".join(
        f"{props.name} ({props.total_memory / 1024**3:.0f}GB)"
        for props in (torch.cuda.get_device_properties(i) for i in range(NUM_GPUS))
    )
else:
    _gpu_desc = "none (running on CPU)"
print(f"GPUs: {NUM_GPUS} | {_gpu_desc}")
print(f"Batch Size: {BATCH_SIZE} | Val Batch: {VAL_BATCH_SIZE}")
print(f"Workers: {NUM_WORKERS} | Probes: {N_PROBES}")
print(f"Rounds: {N_ROUNDS} | Clients: {N_CLIENTS}")
print("Non-IID: Each client has 2-3 classes only (EXTREME)")
print(f"{'='*60}")

In [None]:
def run_fedavg(n_rounds=N_ROUNDS, n_clients=N_CLIENTS, local_epochs=3, lr=0.01):
    """Standard FedAvg baseline, with per-round metrics logged to WandB.

    In every round, each of the first ``n_clients`` clients fine-tunes a copy of
    the global model on its shard (``CLIENT_IDX[cid]``). Training runs for
    ``local_epochs`` epochs of SGD (momentum 0.9, weight decay 1e-4) under AMP.
    The resulting state-dict deltas are averaged, weighted by shard size, and
    added to the global model. The global model is then evaluated on
    ``TEST_LOADER`` and, per client, on ``CLIENT_EVAL_LOADERS``.

    Args:
        n_rounds: number of communication rounds.
        n_clients: number of participating clients (indexes ``CLIENT_IDX``).
        local_epochs: local passes over each client's shard per round.
        lr: client SGD learning rate.

    Returns:
        dict of per-round lists under 'accuracy', 'jfi', 'max_min_fairness',
        'fairness_variance', 'min_accuracy', 'max_accuracy',
        'client_accuracies' and 'weights' (the aggregation weights).
    """
    wandb.init(project="fed-audit-gan-cifar10", name="FedAvg",
               config={'method': 'FedAvg', 'n_rounds': n_rounds, 'n_clients': n_clients})
    # Close the WandB run even if training raises, so a failed run is not left open.
    try:
        global_model = make_parallel(CNN().to(DEVICE))
        scaler = torch.amp.GradScaler(device='cuda')

        history = {'accuracy': [], 'jfi': [], 'max_min_fairness': [], 'fairness_variance': [],
                   'min_accuracy': [], 'max_accuracy': [], 'client_accuracies': [], 'weights': []}

        # Pre-create client loaders
        client_loaders = [
            DataLoader(Subset(TRAIN_DATA, CLIENT_IDX[cid]), batch_size=BATCH_SIZE, shuffle=True,
                       num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY,
                       persistent_workers=True, prefetch_factor=PREFETCH_FACTOR)
            for cid in range(n_clients)
        ]

        # FedAvg weights depend only on shard sizes, so they are computed once.
        client_sizes = [len(CLIENT_IDX[cid]) for cid in range(n_clients)]
        total_size = sum(client_sizes)
        alphas = [size / total_size for size in client_sizes]

        print(f"FedAvg | Rounds: {n_rounds} | Clients: {n_clients} | GPUs: {NUM_GPUS}")

        for rnd in tqdm(range(n_rounds), desc="FedAvg"):
            # The global weights do not change while clients train: one snapshot per round.
            before = {k: v.clone() for k, v in global_model.state_dict().items()}
            updates = []

            for cid in range(n_clients):
                local = make_parallel(copy.deepcopy(get_base_model(global_model)).to(DEVICE))
                opt = optim.SGD(local.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)
                local.train()

                for _ in range(local_epochs):
                    for d, t in client_loaders[cid]:
                        d, t = d.to(DEVICE, non_blocking=True), t.to(DEVICE, non_blocking=True)
                        opt.zero_grad(set_to_none=True)
                        with torch.amp.autocast(device_type='cuda'):
                            loss = F.cross_entropy(local(d), t)
                        scaler.scale(loss).backward()
                        scaler.step(opt)
                        scaler.update()

                after = local.state_dict()
                updates.append({k: after[k] - before[k] for k in before})
                del local
                torch.cuda.empty_cache()

            # FedAvg aggregation
            sd = global_model.state_dict()
            for k in sd:
                sd[k] = sd[k] + sum(alphas[i] * updates[i][k] for i in range(n_clients))
            global_model.load_state_dict(sd)

            # Evaluate
            acc = evaluate(global_model, TEST_LOADER)
            client_accs = evaluate_per_client(global_model, CLIENT_EVAL_LOADERS)
            fairness = compute_all_fairness_metrics(client_accs)

            history['accuracy'].append(acc)
            for key in ('jfi', 'max_min_fairness', 'fairness_variance', 'min_accuracy', 'max_accuracy'):
                history[key].append(fairness[key])
            history['client_accuracies'].append(client_accs)
            history['weights'].append(list(alphas))  # copy: entries must not alias one list

            wandb.log({'round': rnd+1, 'accuracy': acc, 'jfi': fairness['jfi'],
                       'max_min_fairness': fairness['max_min_fairness'],
                       'fairness_variance': fairness['fairness_variance'],
                       'min_client_acc': fairness['min_accuracy'],
                       'max_client_acc': fairness['max_accuracy']})

        if history['accuracy']:  # nothing to report when n_rounds == 0
            print(f"FedAvg Done | Accuracy: {history['accuracy'][-1]:.2f}% | JFI: {history['jfi'][-1]:.4f}")
        return history
    finally:
        wandb.finish()

def run_fed_audit_gan(gamma, n_rounds=N_ROUNDS, n_clients=N_CLIENTS, local_epochs=3, lr=0.01, gan_epochs=20, n_probes=N_PROBES):
    """Fed-AuditGAN: federated training with fairness-aware aggregation, logged to WandB.

    Each round has four phases:
      1. Every client fine-tunes a copy of the global model on its shard with
         SGD under AMP and reports a state-dict delta.
      2. A fresh FairnessGenerator/Discriminator pair is trained against the
         current global model on ``VAL_LOADER``. ``n_probes`` probe pairs
         ``(x, x_prime)`` are then sampled from it.
      3. Each update is given two scores. The fairness score is how much
         applying the update reduces ``compute_bias`` on the probes. The
         accuracy score comes from ``compute_acc_score`` on ``VAL_LOADER``.
      4. Both score lists are min-max normalized and mixed as
         ``(1 - gamma) * acc + gamma * fair``. The mixed scores are renormalized
         to sum to 1 (uniform if the sum is not positive) and used as the
         weights of the aggregated update.

    Args:
        gamma: weight of the fairness score in [0, 1] (0 = accuracy only).
        n_rounds: number of communication rounds.
        n_clients: number of participating clients (indexes ``CLIENT_IDX``).
        local_epochs: local passes over each client's shard per round.
        lr: client SGD learning rate.
        gan_epochs: epochs of auditor-GAN training per round.
        n_probes: number of generated probe pairs per round.

    Returns:
        dict of per-round lists: the same keys as the FedAvg history, plus
        'bias', the pre-aggregation global model's bias on that round's probes.
    """
    wandb.init(project="fed-audit-gan-cifar10", name=f"FedAuditGAN_gamma{gamma}",
               config={'method': 'FedAuditGAN', 'gamma': gamma, 'n_rounds': n_rounds, 'n_clients': n_clients})
    # Close the WandB run even if training raises, so a failed run is not left open.
    try:
        global_model = make_parallel(CNN().to(DEVICE))
        scaler = torch.amp.GradScaler(device='cuda')

        history = {'accuracy': [], 'bias': [], 'jfi': [], 'max_min_fairness': [], 'fairness_variance': [],
                   'min_accuracy': [], 'max_accuracy': [], 'client_accuracies': [], 'weights': []}

        # Pre-create client loaders
        client_loaders = [
            DataLoader(Subset(TRAIN_DATA, CLIENT_IDX[cid]), batch_size=BATCH_SIZE, shuffle=True,
                       num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY,
                       persistent_workers=True, prefetch_factor=PREFETCH_FACTOR)
            for cid in range(n_clients)
        ]

        print(f"Fed-AuditGAN | gamma={gamma} | Rounds: {n_rounds} | Clients: {n_clients} | GPUs: {NUM_GPUS}")

        for rnd in tqdm(range(n_rounds), desc=f"gamma={gamma}"):
            # Phase 1: Client training. The global weights do not change while
            # clients train, so one snapshot per round suffices.
            before = {k: v.clone() for k, v in global_model.state_dict().items()}
            updates = []
            for cid in range(n_clients):
                local = make_parallel(copy.deepcopy(get_base_model(global_model)).to(DEVICE))
                opt = optim.SGD(local.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)
                local.train()

                for _ in range(local_epochs):
                    for d, t in client_loaders[cid]:
                        d, t = d.to(DEVICE, non_blocking=True), t.to(DEVICE, non_blocking=True)
                        opt.zero_grad(set_to_none=True)
                        with torch.amp.autocast(device_type='cuda'):
                            loss = F.cross_entropy(local(d), t)
                        scaler.scale(loss).backward()
                        scaler.step(opt)
                        scaler.update()

                after = local.state_dict()
                updates.append({k: after[k] - before[k] for k in before})
                del local
                torch.cuda.empty_cache()

            # Phase 2: GAN auditing
            G = FairnessGenerator(img_shape=IMG_SHAPE).to(DEVICE)
            D = Discriminator(img_shape=IMG_SHAPE).to(DEVICE)
            G, D = train_fairness_gan(G, D, global_model, VAL_LOADER, epochs=gan_epochs)
            G.eval()

            with torch.no_grad():
                z = torch.randn(n_probes, G.latent_dim, device=DEVICE)
                lbls = torch.randint(0, NUM_CLASSES, (n_probes,), device=DEVICE)
                with torch.amp.autocast(device_type='cuda'):
                    # Temporary DataParallel wrapper. Not bound to a name, so the
                    # `del G` below actually releases the generator.
                    x_p, xp_p = make_parallel(G)(z, lbls)

            # Phase 3: Scoring
            B_base = compute_bias(global_model, x_p, xp_p)
            S_fair, S_acc = [], []
            for upd in updates:
                hyp = make_parallel(copy.deepcopy(get_base_model(global_model)).to(DEVICE))
                sd = hyp.state_dict()
                for k in sd:
                    if k in upd:
                        sd[k] = sd[k] + upd[k]
                hyp.load_state_dict(sd)
                S_fair.append(B_base - compute_bias(hyp, x_p, xp_p))
                S_acc.append(compute_acc_score(global_model, upd, VAL_LOADER))
                del hyp

            del G, D, x_p, xp_p
            torch.cuda.empty_cache()

            # Phase 4: Aggregation
            S_fair_n, S_acc_n = normalize(S_fair), normalize(S_acc)
            alphas = [(1-gamma)*S_acc_n[i] + gamma*S_fair_n[i] for i in range(n_clients)]
            a_sum = sum(alphas)
            alphas = [a/a_sum if a_sum > 0 else 1/n_clients for a in alphas]

            sd = global_model.state_dict()
            for k in sd:
                sd[k] = sd[k] + sum(alphas[i] * updates[i][k] for i in range(n_clients))
            global_model.load_state_dict(sd)

            # Evaluate
            acc = evaluate(global_model, TEST_LOADER)
            client_accs = evaluate_per_client(global_model, CLIENT_EVAL_LOADERS)
            fairness = compute_all_fairness_metrics(client_accs)

            history['accuracy'].append(acc)
            history['bias'].append(B_base)
            for key in ('jfi', 'max_min_fairness', 'fairness_variance', 'min_accuracy', 'max_accuracy'):
                history[key].append(fairness[key])
            history['client_accuracies'].append(client_accs)
            history['weights'].append(alphas)

            wandb.log({'round': rnd+1, 'accuracy': acc, 'bias': B_base, 'jfi': fairness['jfi'],
                       'max_min_fairness': fairness['max_min_fairness'],
                       'fairness_variance': fairness['fairness_variance'],
                       'min_client_acc': fairness['min_accuracy'],
                       'max_client_acc': fairness['max_accuracy']})

        if history['accuracy']:  # nothing to report when n_rounds == 0
            print(f"gamma={gamma} Done | Accuracy: {history['accuracy'][-1]:.2f}% | JFI: {history['jfi'][-1]:.4f}")
        return history
    finally:
        wandb.finish()

# End of the training-function cell: run_fedavg and run_fed_audit_gan are now defined.
print("Training functions ready!")

In [None]:
# Run experiments: FedAvg + gamma 0.3 + gamma 0.7
def _banner(title):
    """Print ``title`` between two 50-character '=' rules, preceded by a blank line."""
    rule = "=" * 50
    print("\n" + rule)
    print(title)
    print(rule)

results = {}

# 1. Run FedAvg baseline
_banner("Running FedAvg Baseline")
results['FedAvg'] = run_fedavg(n_rounds=N_ROUNDS, n_clients=N_CLIENTS)

# 2. Run Fed-AuditGAN with gamma = 0.3 and 0.7
for gamma in [0.3, 0.7]:
    _banner(f"Running Fed-AuditGAN gamma = {gamma}")
    results[f'gamma{gamma}'] = run_fed_audit_gan(gamma=gamma, n_rounds=N_ROUNDS, n_clients=N_CLIENTS)

_banner("All experiments complete!")

In [None]:
# Comprehensive Visualization
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
colors = {'FedAvg': 'black', 'gamma0.3': 'blue', 'gamma0.7': 'red'}
linestyles = {'FedAvg': '--', 'gamma0.3': '-', 'gamma0.7': '-'}
rounds = list(range(1, N_ROUNDS + 1))

# Per-round line panels as (axis, series extractor, y-label, title). The x axis
# comes from each series' own length, so histories recorded with an n_rounds
# other than N_ROUNDS still plot instead of raising.
_line_panels = [
    (axes[0, 0], lambda h: h['accuracy'], 'Accuracy (%)', 'CIFAR-10: Global Test Accuracy'),
    (axes[0, 1], lambda h: h['jfi'], 'JFI', "Jain's Fairness Index (higher=fairer)"),
    (axes[0, 2], lambda h: h['max_min_fairness'], 'Min/Max Ratio', 'Max-Min Fairness (higher=fairer)'),
    (axes[1, 0], lambda h: h['fairness_variance'], 'Variance', 'Per-Client Accuracy Variance (lower=fairer)'),
    (axes[1, 1], lambda h: [hi - lo for hi, lo in zip(h['max_accuracy'], h['min_accuracy'])],
     'Gap (%)', 'Best-Worst Client Gap (lower=fairer)'),
]
for ax, series_of, ylabel, title in _line_panels:
    for method, hist in results.items():
        series = series_of(hist)
        # Unknown method names fall back to matplotlib's default colour and a solid line.
        ax.plot(range(1, len(series) + 1), series, linestyle=linestyles.get(method, '-'),
                label=method, color=colors.get(method), linewidth=2)
    ax.set_xlabel('Round'); ax.set_ylabel(ylabel)
    ax.set_title(title); ax.legend(); ax.grid(True, alpha=0.3)

# Plot 6: Final Per-Client Accuracy (grouped bars, one group per client)
final_client_accs = {m: results[m]['client_accuracies'][-1] for m in results}
x = np.arange(N_CLIENTS)
width = 0.25
for i, method in enumerate(results):
    axes[1, 2].bar(x + i*width, final_client_accs[method], width, label=method,
                   color=colors.get(method), alpha=0.8)
# Centre each client-ID tick under its whole group of bars, not under the first bar.
axes[1, 2].set_xticks(x + width * (len(results) - 1) / 2)
axes[1, 2].set_xticklabels([str(c) for c in x])
axes[1, 2].set_xlabel('Client ID'); axes[1, 2].set_ylabel('Accuracy (%)')
axes[1, 2].set_title(f'Per-Client Accuracy (Round {N_ROUNDS})'); axes[1, 2].legend(); axes[1, 2].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('cifar10_results.png', dpi=150)
plt.show()

# Summary Table
print("\n" + "="*110)
print("CIFAR-10 FINAL RESULTS SUMMARY (Round {})".format(N_ROUNDS))
print("="*110)
print(f"{'Method':<15} {'Global Acc':>10} {'JFI':>8} {'Max-Min':>10} {'Variance':>12} {'Min Acc':>10} {'Max Acc':>10} {'Gap':>8}")
print("-"*110)
for method in results:
    acc = results[method]['accuracy'][-1]
    jfi = results[method]['jfi'][-1]
    mmf = results[method]['max_min_fairness'][-1]
    var = results[method]['fairness_variance'][-1]
    min_acc = results[method]['min_accuracy'][-1]
    max_acc = results[method]['max_accuracy'][-1]
    gap = max_acc - min_acc
    # A trailing '%' occupies the last character of its column, so those numbers
    # get one character less than the header width to keep the columns aligned.
    print(f"{method:<15} {acc:>9.2f}% {jfi:>8.4f} {mmf:>10.4f} {var:>12.2f} {min_acc:>9.2f}% {max_acc:>9.2f}% {gap:>7.2f}%")
print("="*110)

# Bias summary: final-round probe bias of the Fed-AuditGAN runs that recorded one.
print("\nBias Summary (Fed-AuditGAN only):")
for method in ['gamma0.3', 'gamma0.7']:
    if method not in results or 'bias' not in results[method]:
        continue
    final_bias = results[method]['bias'][-1]
    print(f"  {method}: Final Bias = {final_bias:.4f}")

# Winner: the method with the best final value of each fairness metric.
print("\n" + "="*60)
print("FAIRNESS COMPARISON")
print("="*60)

def _final_value(metric):
    """Key function ranking ``(method, history)`` pairs by the last recorded ``metric``."""
    return lambda item: item[1][metric][-1]

best_jfi = max(results.items(), key=_final_value('jfi'))
best_mmf = max(results.items(), key=_final_value('max_min_fairness'))
lowest_var = min(results.items(), key=_final_value('fairness_variance'))
print(f"Best JFI:        {best_jfi[0]} ({best_jfi[1]['jfi'][-1]:.4f})")
print(f"Best Max-Min:    {best_mmf[0]} ({best_mmf[1]['max_min_fairness'][-1]:.4f})")
print(f"Lowest Variance: {lowest_var[0]} ({lowest_var[1]['fairness_variance'][-1]:.2f})")