# 🔧 Fed-Audit-GAN v2.0 - MNIST (Multi-Gamma Ablation Study)

## 🎯 Experiments Run:
- **FedAvg** (γ = 0.0) - Baseline with data-size-weighted aggregation (no GAN)
- **Fed-Audit-GAN γ = 0.3** - Mild fairness weighting
- **Fed-Audit-GAN γ = 0.5** - Moderate fairness weighting
- **Fed-Audit-GAN γ = 0.7** - Strong fairness weighting
- **Fed-Audit-GAN γ = 2.0** - Very strong fairness weighting

## 🔧 Oscillation Fixes Applied:
1. **Momentum/EMA (β=0.8)**: Smooths each client's fairness score across rounds
2. **Warm-up Period (5 rounds)**: Aggregation stays data-size-weighted while the EMA fairness scores accumulate
3. **FedProx (μ=0.01)**: Proximal term that limits client drift away from the global model

## 🚀 GPU/TPU Optimizations:
- Mixed Precision Training (AMP) on CUDA GPUs for faster training
- TPU auto-detection for Google Colab
- Optimized DataLoaders with pin_memory & prefetching

---

In [None]:
# Step 1: Install and Import Dependencies
# !pip install -q torch torchvision tqdm matplotlib numpy wandb

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import numpy as np
import copy
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import wandb
import warnings
warnings.filterwarnings('ignore')

# ============================================================
# 🚀 GPU/TPU OPTIMIZATION SETUP
# ============================================================

# Check for TPU (Google Colab).
# torch_xla may be importable on a runtime that has no TPU attached; acquiring
# the XLA device can then raise RuntimeError instead of ImportError, which
# would otherwise abort this cell. Both are treated as "no TPU" so detection
# falls through to the GPU/CPU branch below.
USE_TPU = False
try:
    import torch_xla
    import torch_xla.core.xla_model as xm
    DEVICE = xm.xla_device()
    USE_TPU = True
    print("‚úÖ TPU detected! Using TPU acceleration.")
except (ImportError, RuntimeError):
    pass

# Check for GPU if no TPU
if not USE_TPU:
    if torch.cuda.is_available():
        DEVICE = torch.device('cuda')
        # Enable TensorFloat-32 for faster computation on Ampere+ GPUs
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        # Enable cuDNN autotuner for optimized convolutions
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.enabled = True
        torch.cuda.empty_cache()
        print(f"‚úÖ GPU detected: {torch.cuda.get_device_name(0)}")
        print(f"   CUDA Version: {torch.version.cuda}")
        print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    else:
        DEVICE = torch.device('cpu')
        print("‚ö†Ô∏è  No GPU/TPU detected. Using CPU (training will be slower).")

# Mixed Precision Training (AMP) - CUDA GPUs only
USE_AMP = torch.cuda.is_available() and not USE_TPU
if USE_AMP:
    print("‚úÖ Mixed Precision Training (AMP) enabled for faster GPU training.")

print(f"\nüìç Device: {DEVICE}")
print(f"   PyTorch: {torch.__version__}")

In [None]:
# Step 2: Login to WandB
# Authenticate with Weights & Biases; the experiment loop later calls
# wandb.init / wandb.log / wandb.finish, which need an authenticated session.
wandb.login()
# NOTE(review): the result of wandb.login() is not checked, so this message is
# printed unconditionally — confirm login actually succeeded before long runs.
print("‚úÖ WandB logged in!")

In [None]:
# ============================================================
# MODEL DEFINITIONS
# ============================================================

class CNN(nn.Module):
    """Small convolutional classifier for single-channel 28x28 MNIST digits.

    Pipeline: conv(1->32, 3x3) -> ReLU -> conv(32->64, 3x3) -> ReLU
    -> 2x2 max-pool -> dropout(0.25) -> flatten (64 * 12 * 12 = 9216)
    -> linear(128) -> ReLU -> dropout(0.5) -> linear(num_classes).
    The output is a tensor of unnormalized logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Layer creation order fixes the random-init order, and the attribute
        # names fix the state_dict keys; both are kept stable on purpose.
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        h = F.relu(self.conv1(x))
        h = F.max_pool2d(F.relu(self.conv2(h)), 2)
        h = torch.flatten(self.dropout1(h), 1)
        h = self.dropout2(F.relu(self.fc1(h)))
        return self.fc2(h)


class FairnessGenerator(nn.Module):
    """Conditional generator that emits a pair of images (x, x') per latent code.

    x is decoded from the latent vector concatenated with a label embedding;
    x' is x plus a small latent-dependent perturbation (Tanh output scaled by
    `delta_scale`), clamped to [-1, 1]. The pair is meant to probe whether a
    classifier reacts differently to nearly identical inputs.
    """

    def __init__(self, latent_dim=100, num_classes=10, img_shape=(1, 28, 28)):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        self.img_shape = img_shape
        channels = img_shape[0]
        # Spatial size before the two 2x upsampling stages.
        self.init_size = img_shape[1] // 4

        # Submodules are created in a fixed order (init RNG order / state_dict keys).
        self.label_emb = nn.Embedding(num_classes, latent_dim)
        self.l1 = nn.Linear(latent_dim * 2, 128 * self.init_size ** 2)
        decoder_layers = [
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, channels, 3, 1, 1),
            nn.Tanh(),
        ]
        self.conv_blocks = nn.Sequential(*decoder_layers)
        perturb_layers = [
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, int(np.prod(img_shape))),
            nn.Tanh(),
        ]
        self.delta_net = nn.Sequential(*perturb_layers)
        self.delta_scale = 0.1

    def forward(self, z, labels):
        conditioned = torch.cat((z, self.label_emb(labels)), dim=1)
        side = self.init_size
        x = self.conv_blocks(self.l1(conditioned).view(-1, 128, side, side))
        perturbation = self.delta_scale * self.delta_net(z).view(-1, *self.img_shape)
        return x, (x + perturbation).clamp(-1, 1)


class Discriminator(nn.Module):
    """Label-conditional discriminator returning one raw logit per image.

    The label embedding is broadcast to a (num_classes, H, W) map and stacked
    onto the image channels before four stride-2 convolutions. No Sigmoid is
    applied, so the output is meant for BCEWithLogitsLoss.
    """

    def __init__(self, num_classes=10, img_shape=(1, 28, 28)):
        super().__init__()
        self.num_classes = num_classes
        self.img_shape = img_shape
        in_channels = img_shape[0] + num_classes
        # Creation order kept fixed (init RNG order / state_dict keys).
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 16, 3, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(16, 32, 3, 2, 1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2),
            nn.Conv2d(32, 64, 3, 2, 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 3, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
        )
        # 28 -> 14 -> 7 -> 4 -> 2 spatially, so 128 channels * 2 * 2 features.
        self.fc = nn.Sequential(nn.Linear(128 * 4, 1))

    def forward(self, img, labels):
        _, height, width = self.img_shape
        label_planes = self.label_emb(labels)[:, :, None, None].expand(-1, -1, height, width)
        features = self.conv(torch.cat((img, label_planes), dim=1))
        return self.fc(features.flatten(1))

In [None]:
# ============================================================
# HELPER FUNCTIONS (WITH GPU/AMP OPTIMIZATIONS)
# ============================================================

def train_gan(G, D, model, loader, epochs=30, device='cuda', l1=1.0, l2=1.0):
    """Train the Fairness GAN (G, D) against a frozen classifier.

    The generator is optimized so that the classifier's outputs on the pair
    (x, x') differ as much as possible (term t1), while x and x' stay close in
    pixel space (term t2, weight `l1`) and look real to the discriminator
    (term t3, weight `l2`).

    Args:
        G: generator exposing `latent_dim` and `num_classes`; G(z, labels)
            returns (x, x').
        D: conditional discriminator returning logits; D(img, labels).
        model: classifier under audit. It is never updated: its parameters
            are frozen (requires_grad=False) while training, so t1's gradient
            flows *through* it into G without accumulating in it. The original
            requires_grad flags are restored on exit, even on error.
        loader: iterable of (images, labels) batches of real data.
        epochs: number of passes over `loader`.
        device: device to train on; AMP is used only if USE_AMP and the device
            is CUDA.
        l1: weight of the x/x' proximity term.
        l2: weight of the adversarial realism term.

    Returns:
        (G, D) after training, moved to `device`.
    """
    import contextlib

    G, D, model = G.to(device), D.to(device), model.to(device)
    model.eval()
    opt_G = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
    opt_D = optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
    bce = nn.BCEWithLogitsLoss()  # logits-based loss is safe under autocast

    # Mixed precision scalers (only for CUDA devices)
    use_amp_local = USE_AMP and 'cuda' in str(device)
    scaler_G = torch.amp.GradScaler(device='cuda') if use_amp_local else None
    scaler_D = torch.amp.GradScaler(device='cuda') if use_amp_local else None

    def _autocast():
        # Real autocast on the AMP path, a no-op context otherwise.
        if use_amp_local:
            return torch.amp.autocast(device_type='cuda')
        return contextlib.nullcontext()

    def _step(loss, opt, scaler):
        # Backward + optimizer step, routed through the GradScaler when AMP is on.
        if scaler is None:
            loss.backward()
            opt.step()
        else:
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()

    # Freeze the audited model so gradients pass through it without being stored.
    saved_requires_grad = [p.requires_grad for p in model.parameters()]
    model.requires_grad_(False)
    try:
        for _ in range(epochs):
            for imgs, labels in loader:
                bs = imgs.size(0)
                real = torch.ones(bs, 1, device=device)
                fake_t = torch.zeros(bs, 1, device=device)
                imgs = imgs.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)
                z = torch.randn(bs, G.latent_dim, device=device)
                gl = torch.randint(0, G.num_classes, (bs,), device=device)

                # ---- Generator step ----
                opt_G.zero_grad(set_to_none=True)
                with _autocast():
                    x, xp = G(z, gl)
                    # No torch.no_grad() here: t1 must be differentiable w.r.t.
                    # x and x' (hence G), otherwise it is a constant and G never
                    # learns to expose prediction differences.
                    px, pxp = model(x), model(xp)
                    t1 = -torch.mean((px - pxp) ** 2)
                    t2 = l1 * torch.mean((x - xp) ** 2)
                    t3 = l2 * (bce(D(x, gl), real) + bce(D(xp, gl), real)) / 2
                    g_loss = t1 + t2 + t3
                _step(g_loss, opt_G, scaler_G)

                # ---- Discriminator step ----
                opt_D.zero_grad(set_to_none=True)
                with _autocast():
                    # Fresh samples from the updated G; no graph through G needed.
                    with torch.no_grad():
                        x, xp = G(z, gl)
                    d_loss = (bce(D(imgs, labels), real)
                              + bce(D(x, gl), fake_t)
                              + bce(D(xp, gl), fake_t)) / 3
                _step(d_loss, opt_D, scaler_D)
    finally:
        for p, flag in zip(model.parameters(), saved_requires_grad):
            p.requires_grad_(flag)

    return G, D


@torch.no_grad()
def compute_bias(model, x, xp, device):
    """Mean per-sample L1 distance between the model's outputs on x and x'.

    For each probe pair the absolute output differences are summed over dim 1,
    then averaged over the batch; the result is returned as a Python float.
    The model is put in eval mode; autocast is enabled only when USE_AMP.
    """
    model.eval()
    with torch.amp.autocast(device_type='cuda', enabled=USE_AMP):
        out_a = model(x.to(device))
        out_b = model(xp.to(device))
        per_pair = (out_a - out_b).abs().sum(1)
        return per_pair.mean().item()


def partition_data_non_iid_unequal(dataset, n_clients, alpha=0.5):
    """
    Create Non-IID partition with UNEQUAL data sizes per client.

    For every class, a Dirichlet(alpha) draw over the clients decides what
    fraction of that class's samples each client receives, so both the label
    mix and the number of samples differ between clients. Smaller `alpha`
    gives more heterogeneous splits.

    Args:
        dataset: indexable dataset whose items are (input, label) pairs. If it
            exposes a non-None `targets` attribute, labels are read from it
            directly instead of loading (and transforming) every sample.
        n_clients: number of partitions to create.
        alpha: Dirichlet concentration parameter.

    Returns:
        list of `n_clients` shuffled integer index arrays that together cover
        every dataset index exactly once. Uses the global NumPy RNG.
    """
    # Fast path: reading `targets` avoids running the transform on every image
    # just to get its label. The labels and RNG call sequence are unchanged.
    targets = getattr(dataset, 'targets', None)
    if targets is not None:
        labels = np.asarray(targets)
    else:
        labels = np.array([dataset[i][1] for i in range(len(dataset))])

    # Iterate the label values actually present, so labels need not be 0..K-1.
    classes = np.unique(labels)
    class_indices = {c: np.where(labels == c)[0] for c in classes}
    for c in classes:
        np.random.shuffle(class_indices[c])

    client_indices = [[] for _ in range(n_clients)]

    for c in classes:
        idx_c = class_indices[c]
        proportions = np.random.dirichlet(np.repeat(alpha, n_clients))
        counts = (proportions * len(idx_c)).astype(int)
        # Truncation leaves a remainder; the last client absorbs it so every
        # sample of this class is assigned.
        counts[-1] = len(idx_c) - counts[:-1].sum()

        start = 0
        for client_id, num_samples in enumerate(counts):
            if num_samples > 0:
                client_indices[client_id].extend(
                    idx_c[start:start + num_samples].tolist()
                )
            start += num_samples

    result = []
    for indices in client_indices:
        # dtype=int keeps empty partitions integer-typed (np.array([]) is float).
        arr = np.array(indices, dtype=int)
        np.random.shuffle(arr)
        result.append(arr)

    return result


@torch.no_grad()
def evaluate(model, loader, device):
    """Return the top-1 accuracy of `model` on `loader`, in percent.

    The model is put in eval mode and run without gradients; autocast is
    enabled only when USE_AMP. An empty loader yields 0.0 instead of raising
    ZeroDivisionError, matching the convention in evaluate_per_client.
    """
    model.eval()
    correct, total = 0, 0
    with torch.amp.autocast(device_type='cuda', enabled=USE_AMP):
        for d, t in loader:
            d, t = d.to(device, non_blocking=True), t.to(device, non_blocking=True)
            correct += (model(d).argmax(1) == t).sum().item()
            total += len(t)
    return 100 * correct / total if total > 0 else 0.0


@torch.no_grad()
def evaluate_per_client(model, client_loaders, device):
    """
    Top-1 accuracy (percent) of `model` on each client's loader, in order.

    A client whose loader yields no samples gets 0. The spread of these values
    is what the fairness metrics below are computed from.
    """
    model.eval()

    def _client_accuracy(loader):
        n_correct = 0
        n_seen = 0
        for inputs, targets in loader:
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            n_correct += (model(inputs).argmax(1) == targets).sum().item()
            n_seen += len(targets)
        if n_seen > 0:
            return 100 * n_correct / n_seen
        return 0

    with torch.amp.autocast(device_type='cuda', enabled=USE_AMP):
        return [_client_accuracy(loader) for loader in client_loaders]


# ============================================================
# FAIRNESS METRICS (Based on Per-Client Performance!)
# ============================================================

def calculate_jfi(performances):
    """
    Jain's Fairness Index over per-client performances (not aggregation weights).

    JFI = (sum p_i)^2 / (N * sum p_i^2). It equals 1.0 when every client has
    the same performance; 1.0 is also returned when all values are zero.
    """
    values = np.array(performances)
    sum_of_squares = np.sum(values ** 2)
    if sum_of_squares == 0:
        return 1.0
    return np.sum(values) ** 2 / (len(values) * sum_of_squares)


def calculate_max_min_fairness(performances):
    """
    Max-Min fairness ratio: min(acc) / max(acc).

    Higher is fairer (1.0 means every client has the same accuracy). Returns
    0.0 when the best client's value is 0.
    """
    values = np.array(performances)
    best = np.max(values)
    return 0.0 if best == 0 else np.min(values) / best


def calculate_variance(performances):
    """Population variance of per-client accuracies (lower means fairer)."""
    values = np.asarray(performances)
    return values.var()


def calculate_accuracy_gap(performances):
    """Spread between the best and worst client accuracy (lower means fairer)."""
    values = np.asarray(performances)
    best, worst = values.max(), values.min()
    return best - worst

In [None]:
# ============================================================
# CONFIGURATION
# ============================================================

# Training Parameters
N_ROUNDS = 50           # Total training rounds
N_CLIENTS = 20          # Number of federated clients
N_GAN_EPOCHS = 20       # GAN training epochs per round
N_PROBES = 300          # Number of probe samples
LOCAL_EPOCHS = 3        # Local training epochs per client

# ⭐ MULTI-GAMMA ABLATION STUDY
# 0.0 = FedAvg baseline: data-size-weighted aggregation, no GAN training.
GAMMA_VALUES = [0.0, 0.3, 0.5, 0.7, 2.0]

# ⭐ Oscillation Fix Parameters
MOMENTUM = 0.8          # EMA momentum for fairness scores (Change 1)
WARMUP_ROUNDS = 5       # Rounds of data-weighted aggregation before fairness weights apply (Change 2)
MU = 0.01               # FedProx proximal term strength (Change 3)

print("=" * 60)
print("üîß Fed-Audit-GAN v2.0 - MULTI-GAMMA ABLATION STUDY")
print("=" * 60)
print(f"Device: {DEVICE}")
print(f"Rounds: {N_ROUNDS}, Clients: {N_CLIENTS}")
print(f"\nüéØ GAMMA VALUES TO TEST: {GAMMA_VALUES}")
# The FedAvg baseline aggregates with client data-size weights, not uniform weights.
print("   Œ≥=0.0 ‚Üí FedAvg (data-size-weighted, no GAN)")
print("   Œ≥=0.3, 0.5, 0.7 ‚Üí Mild fairness weighting")
print("   Œ≥=2.0 ‚Üí Strong fairness weighting")
print("\n‚≠ê OSCILLATION FIX PARAMETERS:")
print(f"   Momentum (Œ≤): {MOMENTUM}")
print(f"   Warm-up Rounds: {WARMUP_ROUNDS}")
print(f"   FedProx (Œº): {MU}")
print("=" * 60)

In [None]:
# ============================================================
# DATA LOADING (WITH GPU OPTIMIZATIONS)
# ============================================================

# Tensor conversion + normalization with mean 0.1307 / std 0.3081.
# NOTE(review): assuming ToTensor scales pixels to [0, 1], normalized real
# images span roughly [-0.42, 2.82], while FairnessGenerator ends in Tanh
# ([-1, 1]). The discriminator may be able to separate real from generated
# images by value range alone — confirm this is intended.
transform = transforms.Compose([
    transforms.ToTensor(), 
    transforms.Normalize((0.1307,), (0.3081,))
])

train_data = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_data = datasets.MNIST('./data', train=False, download=True, transform=transform)

# Create Non-IID partitions with UNEQUAL sizes (Dirichlet distribution).
# The global NumPy RNG is seeded here and consumed first by the partition and
# then by the validation sampling below, so the order of these calls
# determines the split.
np.random.seed(42)
DIRICHLET_ALPHA = 0.5  # Lower = more heterogeneous (0.5 = high imbalance)
client_idx = partition_data_non_iid_unequal(train_data, N_CLIENTS, alpha=DIRICHLET_ALPHA)

# Calculate data weights for each client (proportional to data size).
# These are the FedAvg aggregation weights; they sum to 1.
client_data_sizes = [len(idx) for idx in client_idx]
total_samples = sum(client_data_sizes)
CLIENT_DATA_WEIGHTS = [size / total_samples for size in client_data_sizes]

# ⚡ DataLoader optimizations for GPU
NUM_WORKERS = 2 if torch.cuda.is_available() else 0  # Parallel data loading
PIN_MEMORY = torch.cuda.is_available()  # Faster CPU->GPU transfer
PREFETCH_FACTOR = 2 if NUM_WORKERS > 0 else None

# Create data loaders with optimizations.
# NOTE(review): with persistent_workers=True every loader keeps its worker
# processes alive after first use; with N_CLIENTS client loaders plus the test
# and validation loaders this can mean dozens of idle processes on GPU runs.
dataloader_kwargs = {
    'num_workers': NUM_WORKERS,
    'pin_memory': PIN_MEMORY,
    'persistent_workers': NUM_WORKERS > 0,
}
if PREFETCH_FACTOR:
    dataloader_kwargs['prefetch_factor'] = PREFETCH_FACTOR

test_loader = DataLoader(test_data, batch_size=128, **dataloader_kwargs)  # Larger batch for eval
# Validation subset used to train the fairness GAN. It is drawn from
# train_data, so it overlaps the clients' training samples.
val_loader = DataLoader(Subset(train_data, np.random.choice(len(train_data), 1000, replace=False)), 
                        batch_size=64, **dataloader_kwargs)

# Client data loaders (one per client, over that client's training indices).
# NOTE(review): these same training loaders are later passed to
# evaluate_per_client, so "per-client accuracy" is measured on training data.
client_loaders = [
    DataLoader(Subset(train_data, client_idx[c]), batch_size=64, shuffle=True, **dataloader_kwargs)
    for c in range(N_CLIENTS)
]

print(f"Training samples: {len(train_data)}")
print(f"Test samples: {len(test_data)}")
print(f"\nüìä NON-IID DATA DISTRIBUTION (Dirichlet Œ±={DIRICHLET_ALPHA}):")
print(f"   Samples per client: {client_data_sizes}")
print(f"   Min: {min(client_data_sizes)}, Max: {max(client_data_sizes)}, Ratio: {max(client_data_sizes)/max(1, min(client_data_sizes)):.1f}x")
print(f"\n‚öñÔ∏è CLIENT DATA WEIGHTS (for FedAvg):")
for i, (size, weight) in enumerate(zip(client_data_sizes, CLIENT_DATA_WEIGHTS)):
    print(f"   Client {i}: {size:5d} samples ‚Üí weight = {weight:.4f}")

In [None]:
# ============================================================
# üöÄ MULTI-GAMMA ABLATION STUDY
# Runs: FedAvg (Œ≥=0) + Fed-Audit-GAN with Œ≥=0.3, 0.5, 0.7, 2.0
# ============================================================

def run_fed_audit_gan(gamma, n_rounds, n_clients, warmup_rounds, momentum, mu,
                      train_data, client_idx, val_loader, test_loader, client_loaders,
                      n_gan_epochs, n_probes, local_epochs, device, use_amp,
                      client_data_weights):
    """
    Run Fed-Audit-GAN v2.0 with specified gamma value.
    gamma=0 is equivalent to FedAvg (data-weighted aggregation) - NO GAN training.

    Each round:
      1. Every client trains a deep copy of the global model on its own loader
         (SGD, lr=0.01, momentum=0.9) with a FedProx proximal term of strength
         `mu`, and the parameter delta (local - global) is recorded.
      2. If gamma > 0: a new FairnessGenerator/Discriminator pair is trained
         against the current global model on `val_loader`, and `n_probes`
         probe pairs (x, x') are generated.
      3. If gamma > 0: each client is scored by
         S = bias(global) - bias(global + its delta) on the probes. A positive S
         means the client's update reduces the probe bias. S is smoothed with an
         EMA of weight `momentum` on the previous value.
      4. Deltas are aggregated with `client_data_weights` (FedAvg mode, or while
         rnd < warmup_rounds), otherwise with softmax(gamma * smoothed S).
    After each round the model is evaluated on `test_loader` and on every
    client loader, fairness metrics are recorded, and metrics go to wandb.log.

    Args:
        gamma: fairness temperature for the softmax weights; 0 selects FedAvg.
        n_rounds, n_clients: number of rounds / clients.
        warmup_rounds: number of initial rounds aggregated with data weights.
        momentum: EMA coefficient applied to the previous smoothed score.
        mu: FedProx proximal coefficient.
        train_data, client_idx: accepted but not referenced in the body.
        val_loader: data for GAN training.
        test_loader: data for global accuracy.
        client_loaders: per-client data for local training and per-client accuracy.
        n_gan_epochs, n_probes, local_epochs: GAN epochs, probe count, local epochs.
        device, use_amp: compute device and mixed-precision switch.
        client_data_weights: per-client aggregation weights for FedAvg/warm-up.

    Returns:
        (model, history): the final global CNN and a dict of per-round lists.
    """
    
    # Initialize model
    model = CNN().to(device)
    # NOTE(review): one GradScaler is shared by every client's optimizer in
    # every round.
    scaler = torch.amp.GradScaler(device='cuda') if use_amp else None
    
    # Fairness score history for momentum (only needed for gamma > 0).
    # NOTE(review): the EMA starts from 0.0 with no bias correction, so early
    # smoothed scores are shrunk toward 0.
    fairness_history = {i: 0.0 for i in range(n_clients)}
    
    # History tracking - NOW BASED ON PER-CLIENT ACCURACY!
    history = {
        'acc': [], 'bias': [], 'alphas': [],
        'raw_scores': [], 'smoothed_scores': [],
        'client_accuracies': [],  # Per-client accuracy each round
        'jfi': [], 'max_min_fairness': [], 'variance': [], 'accuracy_gap': [],
        'min_client_acc': [], 'max_client_acc': []
    }
    
    is_fedavg = (gamma == 0)  # FedAvg mode - skip GAN training
    
    for rnd in tqdm(range(n_rounds), desc=f"{'FedAvg' if is_fedavg else f'Œ≥={gamma}'}"):
        # ================================================================
        # PHASE 1: Local Client Training (with FedProx + AMP)
        # ================================================================
        updates = []
        # Frozen snapshot of the global weights, used as the FedProx anchor.
        global_params = [p.clone().detach() for p in model.parameters()]
        
        for cid in range(n_clients):
            local_model = copy.deepcopy(model)
            local_model.train()
            # Global state before local training; the client's update is the
            # difference from this snapshot.
            before_state = copy.deepcopy(model.state_dict())
            optimizer = optim.SGD(local_model.parameters(), lr=0.01, momentum=0.9)
            
            for epoch in range(local_epochs):
                for data, target in client_loaders[cid]:
                    data = data.to(device, non_blocking=True)
                    target = target.to(device, non_blocking=True)
                    optimizer.zero_grad(set_to_none=True)
                    
                    if use_amp:
                        with torch.amp.autocast(device_type='cuda'):
                            output = local_model(data)
                            ce_loss = F.cross_entropy(output, target)
                            # FedProx term: squared distance to the global weights.
                            prox_loss = sum(((lp - gp) ** 2).sum() 
                                          for lp, gp in zip(local_model.parameters(), global_params))
                            loss = ce_loss + (mu / 2) * prox_loss
                        scaler.scale(loss).backward()
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        output = local_model(data)
                        ce_loss = F.cross_entropy(output, target)
                        prox_loss = sum(((lp - gp) ** 2).sum() 
                                      for lp, gp in zip(local_model.parameters(), global_params))
                        loss = ce_loss + (mu / 2) * prox_loss
                        loss.backward()
                        optimizer.step()
            
            # Client update = local weights - global weights (per state_dict key).
            update = {k: local_model.state_dict()[k] - before_state[k] for k in before_state}
            updates.append(update)
            del local_model
            # NOTE(review): emptying the CUDA cache after every client likely
            # slows training; consider removing.
            torch.cuda.empty_cache() if torch.cuda.is_available() else None
        
        # ================================================================
        # PHASE 2 & 3: GAN Training + Fairness Scoring (SKIP for FedAvg!)
        # ================================================================
        B_base = 0.0  # Default bias for FedAvg
        S_fair_raw = [0.0] * n_clients
        S_fair_smoothed = [0.0] * n_clients
        
        if not is_fedavg:
            # Only train GAN and compute fairness for gamma > 0.
            # NOTE(review): G and D are re-created and trained from scratch every
            # round, so GAN training does not carry over between rounds. The
            # warm-up only delays the use of the EMA scores in aggregation.
            G = FairnessGenerator().to(device)
            D = Discriminator().to(device)
            G, D = train_gan(G, D, model, val_loader, epochs=n_gan_epochs, device=device)
            
            G.eval()
            with torch.no_grad():
                z = torch.randn(n_probes, G.latent_dim, device=device)
                # NOTE(review): 10 classes hard-coded; G.num_classes would stay in sync.
                labels = torch.randint(0, 10, (n_probes,), device=device)
                with torch.amp.autocast(device_type='cuda', enabled=use_amp):
                    x_probe, xp_probe = G(z, labels)
            
            # Compute base bias (bias of the current global model on the probes)
            B_base = compute_bias(model, x_probe, xp_probe, device)
            
            # Compute fairness scores for each client
            S_fair_raw = []
            S_fair_smoothed = []
            
            for cid, upd in enumerate(updates):
                # Hypothetical model: global weights + this client's update only.
                hyp_model = copy.deepcopy(model)
                hyp_state = hyp_model.state_dict()
                for k in hyp_state:
                    hyp_state[k] = hyp_state[k] + upd[k]
                hyp_model.load_state_dict(hyp_state)
                
                B_client = compute_bias(hyp_model, x_probe, xp_probe, device)
                # Positive when this client's update lowers the probe bias.
                S_current = B_base - B_client
                S_fair_raw.append(S_current)
                
                # Apply EMA momentum
                S_prev = fairness_history[cid]
                S_smoothed = (momentum * S_prev) + ((1 - momentum) * S_current)
                fairness_history[cid] = S_smoothed
                S_fair_smoothed.append(S_smoothed)
                del hyp_model
            
            del G, D, x_probe, xp_probe
            torch.cuda.empty_cache() if torch.cuda.is_available() else None
        
        history['raw_scores'].append(S_fair_raw.copy() if not is_fedavg else [0.0] * n_clients)
        history['smoothed_scores'].append(S_fair_smoothed.copy() if not is_fedavg else [0.0] * n_clients)
        
        # ================================================================
        # PHASE 4: Aggregation
        # ================================================================
        if is_fedavg:
            # FedAvg: DATA-WEIGHTED aggregation (proportional to client data size)
            alphas = client_data_weights.copy()
        elif rnd < warmup_rounds:
            # Warm-up: use data-weighted (same as FedAvg)
            alphas = client_data_weights.copy()
        else:
            # Fairness-aware weights (after warm-up).
            # NOTE(review): client data sizes no longer enter the weights here;
            # equal scores give uniform 1/n weights.
            alphas = F.softmax(torch.tensor(S_fair_smoothed) * gamma, dim=0).tolist()
        
        # Apply aggregation: global += sum_i alpha_i * update_i
        new_state = model.state_dict()
        for k in new_state:
            new_state[k] = new_state[k] + sum(a * u[k] for a, u in zip(alphas, updates))
        model.load_state_dict(new_state)
        
        # ================================================================
        # EVALUATION - FAIRNESS BASED ON PER-CLIENT ACCURACY!
        # ================================================================
        # Global accuracy
        acc = evaluate(model, test_loader, device)
        
        # Per-client accuracy (THIS IS WHAT FAIRNESS IS BASED ON!)
        client_accs = evaluate_per_client(model, client_loaders, device)
        
        # Compute fairness metrics from per-client performance
        jfi = calculate_jfi(client_accs)
        max_min = calculate_max_min_fairness(client_accs)
        var = calculate_variance(client_accs)
        gap = calculate_accuracy_gap(client_accs)
        
        # Store history
        history['acc'].append(acc)
        history['bias'].append(B_base)
        history['alphas'].append(alphas.copy())
        history['client_accuracies'].append(client_accs.copy())
        history['jfi'].append(jfi)
        history['max_min_fairness'].append(max_min)
        history['variance'].append(var)
        history['accuracy_gap'].append(gap)
        history['min_client_acc'].append(min(client_accs))
        history['max_client_acc'].append(max(client_accs))
        
        # Log to WandB (requires an active run from wandb.init)
        wandb.log({
            'round': rnd + 1,
            'accuracy': acc,
            'bias': B_base,
            'jfi': jfi,
            'max_min_fairness': max_min,
            'fairness_variance': var,
            'accuracy_gap': gap,
            'min_client_acc': min(client_accs),
            'max_client_acc': max(client_accs)
        })
    
    return model, history


# ============================================================
# RUN ALL EXPERIMENTS
# ============================================================

# Every gamma run is re-seeded with the same value so that all methods start
# from the same initial CNN weights and see the same client-shuffling stream.
# This keeps the gamma comparison from being confounded by run-to-run random
# variation. (cuDNN autotuning / non-deterministic GPU kernels can still cause
# small differences.)
EXPERIMENT_SEED = 42

all_results = {}

for gamma in GAMMA_VALUES:
    method_name = "FedAvg" if gamma == 0 else f"FedAuditGAN_Œ≥={gamma}"
    
    print(f"\n{'='*70}")
    print(f"üöÄ RUNNING: {method_name}")
    if gamma == 0:
        print(f"   (Pure FedAvg - NO GAN, data-weighted aggregation)")
    else:
        print(f"   (Fed-Audit-GAN with fairness-aware aggregation)")
    print(f"{'='*70}")

    # Same RNG state at the start of every run (see EXPERIMENT_SEED).
    torch.manual_seed(EXPERIMENT_SEED)
    np.random.seed(EXPERIMENT_SEED)
    
    # Initialize WandB
    wandb.init(
        project="FED_AUDIT_GAN_TEST_1_MNIST",
        name=f"{method_name}_MNIST_clients{N_CLIENTS}",
        config={
            "method": method_name,
            "dataset": "MNIST",
            "n_rounds": N_ROUNDS,
            "n_clients": N_CLIENTS,
            "gamma": gamma,
            "momentum": MOMENTUM,
            "warmup_rounds": WARMUP_ROUNDS,
            "mu_fedprox": MU,
            "dirichlet_alpha": DIRICHLET_ALPHA,
            "device": str(DEVICE),
            "seed": EXPERIMENT_SEED,
        }
    )
    
    # Run experiment. The WandB run is always closed, even if training raises,
    # so re-running this cell does not keep logging into a stale run.
    try:
        model, history = run_fed_audit_gan(
            gamma=gamma,
            n_rounds=N_ROUNDS,
            n_clients=N_CLIENTS,
            warmup_rounds=WARMUP_ROUNDS,
            momentum=MOMENTUM,
            mu=MU,
            train_data=train_data,
            client_idx=client_idx,
            val_loader=val_loader,
            test_loader=test_loader,
            client_loaders=client_loaders,
            n_gan_epochs=N_GAN_EPOCHS,
            n_probes=N_PROBES,
            local_epochs=LOCAL_EPOCHS,
            device=DEVICE,
            use_amp=USE_AMP,
            client_data_weights=CLIENT_DATA_WEIGHTS
        )
    finally:
        wandb.finish()
    
    # Store results
    all_results[gamma] = {
        'model': model,
        'history': history,
        'name': method_name
    }
    
    print(f"‚úÖ {method_name} Complete!")
    print(f"   Final Accuracy: {history['acc'][-1]:.2f}%")
    print(f"   Final JFI (per-client): {history['jfi'][-1]:.4f}")
    print(f"   Accuracy Gap: {history['accuracy_gap'][-1]:.2f}%")
    if gamma > 0:
        print(f"   Final Bias: {history['bias'][-1]:.6f}")

print("\n" + "=" * 70)
print("‚úÖ ALL EXPERIMENTS COMPLETE!")
print("=" * 70)
print("üìä Check your WandB dashboard: https://wandb.ai")

In [None]:
# ============================================================
# 📊 RESULTS SUMMARY TABLE (Based on Per-Client Accuracy!)
# ============================================================


def _final_metric(gamma, key):
    """Last-round value of history[key] for the run trained with `gamma`."""
    return all_results[gamma]['history'][key][-1]


_TABLE_RULE = "=" * 110

print("\n" + _TABLE_RULE)
print("üìä MULTI-GAMMA ABLATION STUDY RESULTS")
print(_TABLE_RULE)

_header_cells = [
    f"{'METHOD':<25}",
    f"{'GLOBAL ACC':<12}",
    f"{'JFI':<10}",
    f"{'MAX-MIN':<10}",
    f"{'GAP':<10}",
    f"{'MIN ACC':<10}",
    f"{'MAX ACC':<10}",
]
print("\n" + " ".join(_header_cells))
print("-" * 110)

# Best value of each headline metric across all runs; drives the row markers.
best_acc = max(_final_metric(g, 'acc') for g in GAMMA_VALUES)
best_jfi = max(_final_metric(g, 'jfi') for g in GAMMA_VALUES)
lowest_gap = min(_final_metric(g, 'accuracy_gap') for g in GAMMA_VALUES)

_ROW_KEYS = ('acc', 'jfi', 'max_min_fairness', 'accuracy_gap', 'min_client_acc', 'max_client_acc')

for gamma in GAMMA_VALUES:
    name = all_results[gamma]['name']
    acc, jfi, max_min, gap, min_acc, max_acc = (_final_metric(gamma, k) for k in _ROW_KEYS)

    # A marker appears only on run(s) that match the best value exactly.
    acc_mark = "" if acc != best_acc else "üèÜ"
    jfi_mark = "" if jfi != best_jfi else "‚≠ê"
    gap_mark = "" if gap != lowest_gap else "‚úÖ"

    print(f"{name:<25} {acc:>8.2f}% {acc_mark:<2} {jfi:>8.4f} {jfi_mark:<2} {max_min:>8.4f}   {gap:>6.2f}% {gap_mark:<2} {min_acc:>8.2f}%  {max_acc:>8.2f}%")

print(_TABLE_RULE)

# Improvement over FedAvg (the gamma == 0.0 run is the reference)
fedavg_acc = _final_metric(0.0, 'acc')
fedavg_jfi = _final_metric(0.0, 'jfi')
fedavg_gap = _final_metric(0.0, 'accuracy_gap')

print(f"\nüìà IMPROVEMENT OVER FedAvg:")
for gamma in GAMMA_VALUES:
    if gamma == 0:
        continue
    acc_diff = _final_metric(gamma, 'acc') - fedavg_acc
    jfi_diff = _final_metric(gamma, 'jfi') - fedavg_jfi
    gap_reduction = fedavg_gap - _final_metric(gamma, 'accuracy_gap')
    acc_sign = '+' if acc_diff >= 0 else ''
    jfi_sign = '+' if jfi_diff >= 0 else ''

    print(f"   {all_results[gamma]['name']}:")
    print(f"      Accuracy: {acc_sign}{acc_diff:.2f}%")
    print(f"      JFI: {jfi_sign}{jfi_diff:.4f}")
    print(f"      Gap Reduction: {gap_reduction:.2f}%")

print("\n" + _TABLE_RULE)
for _line in (
    "üìù FAIRNESS METRICS EXPLANATION:",
    "   ‚Ä¢ JFI (Jain's Fairness Index): 1.0 = perfect fairness across clients",
    "   ‚Ä¢ Max-Min Fairness: min(acc)/max(acc) - higher is fairer",
    "   ‚Ä¢ Accuracy Gap: max(acc) - min(acc) - lower is fairer",
):
    print(_line)
print(_TABLE_RULE)

In [None]:
# ============================================================
# 📊 COMPREHENSIVE VISUALIZATION - Per-Client Fairness Metrics
# ============================================================

# Color palette for different gamma values.
# NOTE(review): keyed by the exact floats in GAMMA_VALUES; adding a new gamma
# without a color here raises KeyError below.
colors = {
    0.0: '#e74c3c',   # Red - FedAvg
    0.3: '#f39c12',   # Orange
    0.5: '#9b59b6',   # Purple
    0.7: '#3498db',   # Blue
    2.0: '#2ecc71',   # Green
}

# 2x3 grid of panels; x-axis is the 1-based round number.
fig, axes = plt.subplots(2, 3, figsize=(18, 12))
rounds = range(1, N_ROUNDS + 1)

# ================================================================
# Plot 1: Global Accuracy
# ================================================================
ax = axes[0, 0]
for gamma in GAMMA_VALUES:
    name = all_results[gamma]['name']
    acc = all_results[gamma]['history']['acc']
    # FedAvg (gamma == 0) is drawn dashed to stand out as the baseline.
    linestyle = '--' if gamma == 0 else '-'
    ax.plot(rounds, acc, color=colors[gamma], linestyle=linestyle, 
            marker='o', linewidth=2, markersize=3, label=name)
# Shade the warm-up rounds (1-based rounds 1..WARMUP_ROUNDS), during which all
# methods aggregate with data-size weights.
ax.axvspan(1, WARMUP_ROUNDS, alpha=0.15, color='gray', label='Warm-up')
ax.set_xlabel('Round', fontsize=12)
ax.set_ylabel('Accuracy (%)', fontsize=12)
ax.set_title('Global Test Accuracy', fontsize=14, fontweight='bold')
ax.legend(fontsize=9, loc='lower right')
ax.grid(True, alpha=0.3)

# ================================================================
# Plot 2: JFI (Based on Per-Client Accuracy!)
# ================================================================
ax = axes[0, 1]
for gamma in GAMMA_VALUES:
    name = all_results[gamma]['name']
    jfi = all_results[gamma]['history']['jfi']
    linestyle = '--' if gamma == 0 else '-'
    ax.plot(rounds, jfi, color=colors[gamma], linestyle=linestyle,
            marker='s', linewidth=2, markersize=3, label=name)
ax.axvspan(1, WARMUP_ROUNDS, alpha=0.15, color='gray')
ax.set_xlabel('Round', fontsize=12)
ax.set_ylabel('JFI', fontsize=12)
ax.set_title("Jain's Fairness Index (Higher = Fairer)", fontsize=14, fontweight='bold')
# NOTE(review): fixed y-range; JFI values below 0.8 (e.g. in early rounds)
# fall outside the visible area.
ax.set_ylim([0.8, 1.02])
ax.legend(fontsize=9)
ax.grid(True, alpha=0.3)

# ================================================================
# Plot 3: Accuracy Gap (Max - Min Client Accuracy)
# ================================================================
ax = axes[0, 2]
for gamma in GAMMA_VALUES:
    name = all_results[gamma]['name']
    gap = all_results[gamma]['history']['accuracy_gap']
    linestyle = '--' if gamma == 0 else '-'
    ax.plot(rounds, gap, color=colors[gamma], linestyle=linestyle,
            marker='^', linewidth=2, markersize=3, label=name)
ax.axvspan(1, WARMUP_ROUNDS, alpha=0.15, color='gray')
ax.set_xlabel('Round', fontsize=12)
ax.set_ylabel('Accuracy Gap (%)', fontsize=12)
ax.set_title('Best-Worst Client Gap (Lower = Fairer)', fontsize=14, fontweight='bold')
ax.legend(fontsize=9)
ax.grid(True, alpha=0.3)

# ================================================================
# Plot 4: Per-Client Accuracy Variance
# ================================================================
ax = axes[1, 0]
for gamma in GAMMA_VALUES:
    name = all_results[gamma]['name']
    var = all_results[gamma]['history']['variance']
    linestyle = '--' if gamma == 0 else '-'
    ax.plot(rounds, var, color=colors[gamma], linestyle=linestyle,
            marker='d', linewidth=2, markersize=3, label=name)
ax.axvspan(1, WARMUP_ROUNDS, alpha=0.15, color='gray')
ax.set_xlabel('Round', fontsize=12)
ax.set_ylabel('Variance', fontsize=12)
ax.set_title('Per-Client Accuracy Variance (Lower = Fairer)', fontsize=14, fontweight='bold')
ax.legend(fontsize=9)
ax.grid(True, alpha=0.3)

# ================================================================
# Plot 5: Min & Max Client Accuracy Over Time
# ================================================================
ax = axes[1, 1]
for gamma in GAMMA_VALUES:
    name = all_results[gamma]['name']
    min_acc = all_results[gamma]['history']['min_client_acc']
    max_acc = all_results[gamma]['history']['max_client_acc']
    linestyle = '--' if gamma == 0 else '-'
    ax.fill_between(rounds, min_acc, max_acc, color=colors[gamma], alpha=0.2)
    ax.plot(rounds, min_acc, color=colors[gamma], linestyle=linestyle, linewidth=1.5)
    ax.plot(rounds, max_acc, color=colors[gamma], linestyle=linestyle, linewidth=1.5, label=name)
ax.axvspan(1, WARMUP_ROUNDS, alpha=0.15, color='gray')
ax.set_xlabel('Round', fontsize=12)
ax.set_ylabel('Accuracy (%)', fontsize=12)
ax.set_title('Min-Max Client Accuracy Range', fontsize=14, fontweight='bold')
ax.legend(fontsize=9)
ax.grid(True, alpha=0.3)

# ================================================================
# Plot 6: Final Per-Client Accuracy Bar Chart
# ================================================================
ax = axes[1, 2]
x = np.arange(N_CLIENTS)
width = 0.15
for i, gamma in enumerate(GAMMA_VALUES):
    name = all_results[gamma]['name']
    client_accs = all_results[gamma]['history']['client_accuracies'][-1]
    ax.bar(x + i*width, client_accs, width, label=name, color=colors[gamma], alpha=0.8)
ax.set_xlabel('Client ID', fontsize=12)
ax.set_ylabel('Accuracy (%)', fontsize=12)
ax.set_title(f'Per-Client Accuracy (Final Round)', fontsize=14, fontweight='bold')
ax.legend(fontsize=8, loc='lower right')
ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig('multi_gamma_fairness_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nüìÅ Results saved to: multi_gamma_fairness_comparison.png")

In [None]:
# ============================================================
# DETAILED ANALYSIS
# ============================================================
# Prints an interpretation of each gamma setting, then reports which gamma
# achieved the best final-round accuracy and the lowest final-round bias.

print("\n" + "=" * 80)
print("üìä DETAILED GAMMA ANALYSIS")
print("=" * 80)

# NOTE: JFI in this notebook is computed over per-client accuracies (see the
# JFI plot and the metrics legend), not over aggregation weights, so the
# explanation below describes it in those terms.
print(f"""
üìù INTERPRETATION OF GAMMA (Œ≥) VALUES:

   Œ≥ = 0.0 (FedAvg):
      - Uniform weights: Œ±_i = 1/N for all clients
      - Ignores fairness scores completely
      - Baseline for comparison

   Œ≥ = 0.3 (Mild Fairness):
      - Slight preference for fair clients
      - Weights close to uniform
      - Minimal oscillation risk

   Œ≥ = 0.5 (Moderate Fairness):
      - Balanced fairness-accuracy trade-off
      - Moderate weight differentiation

   Œ≥ = 0.7 (Strong Fairness):
      - Stronger preference for fair clients
      - More weight variance

   Œ≥ = 2.0 (Very Strong Fairness):
      - Aggressive fairness weighting
      - Large weight differences between clients
      - Requires momentum to stabilize

üìà KEY METRICS EXPLANATION:

   ‚Ä¢ Accuracy: Higher is better (model performance)
   ‚Ä¢ Bias: Lower is better (fairness)
   ‚Ä¢ JFI: 1.0 = equal accuracy across clients, <1.0 = uneven accuracy (higher is fairer)
   
   ‚ö° TRADE-OFF: Lower Œ≥ = more stable, Higher Œ≥ = more fairness-aware

üìä OSCILLATION FIX EFFECTIVENESS:
   - Momentum (Œ≤={MOMENTUM}) smooths score fluctuations
   - Warm-up ({WARMUP_ROUNDS} rounds) lets GAN learn first
   - FedProx (Œº={MU}) prevents client drift
""")

# Find optimal gamma: highest final accuracy, lowest final bias.
best_gamma_acc = max(GAMMA_VALUES, key=lambda g: all_results[g]['history']['acc'][-1])
best_gamma_bias = min(GAMMA_VALUES, key=lambda g: all_results[g]['history']['bias'][-1])

print(f"\nüèÜ OPTIMAL GAMMA VALUES:")
print(f"   Best for Accuracy: Œ≥ = {best_gamma_acc} ({all_results[best_gamma_acc]['history']['acc'][-1]:.2f}%)")
print(f"   Best for Bias: Œ≥ = {best_gamma_bias} (Bias = {all_results[best_gamma_bias]['history']['bias'][-1]:.6f})")

# Check if a single gamma wins on both criteria.
if best_gamma_acc == best_gamma_bias:
    print(f"\n   ‚úÖ Œ≥ = {best_gamma_acc} is optimal for BOTH accuracy and fairness!")
else:
    print(f"\n   ‚öñÔ∏è Trade-off exists: Choose based on priority (accuracy vs fairness)")

In [None]:
# ============================================================
# SAVE ALL MODELS AND RESULTS
# ============================================================
# Writes one checkpoint per gamma plus a pickled summary of all histories
# into results_v2/, then prints the final accuracy/bias per gamma.

import os
import pickle


def _cpu_state_dict(model):
    """Return `model.state_dict()` with every tensor moved to CPU.

    Tensors saved straight from a CUDA or XLA (TPU) device can only be
    reloaded on a matching device (or with `map_location`); CPU copies load
    anywhere. The returned object is the fresh OrderedDict produced by
    `state_dict()`, so its `_metadata` (used by `load_state_dict`) is kept.
    Non-tensor entries are left unchanged.
    """
    state = model.state_dict()
    for key, value in list(state.items()):
        if torch.is_tensor(value):
            state[key] = value.cpu()
    return state


# Create results directory
os.makedirs('results_v2', exist_ok=True)

# Save all models and histories
for gamma in GAMMA_VALUES:
    name = all_results[gamma]['name']
    filename = f"results_v2/{name.replace('=', '').replace('.', '_')}_MNIST.pth"

    torch.save({
        'model_state_dict': _cpu_state_dict(all_results[gamma]['model']),
        'history': all_results[gamma]['history'],
        'config': {
            'n_rounds': N_ROUNDS,
            'n_clients': N_CLIENTS,
            'gamma': gamma,
            'momentum': MOMENTUM,
            'warmup_rounds': WARMUP_ROUNDS,
            'mu': MU
        }
    }, filename)
    print(f"‚úÖ Saved: {filename}")

# Save combined results summary.
# NOTE: pickle files execute code on load; only reload this file from a
# trusted location.
with open('results_v2/all_results_summary.pkl', 'wb') as f:
    summary = {
        gamma: {
            'name': all_results[gamma]['name'],
            'history': all_results[gamma]['history'],
            'final_acc': all_results[gamma]['history']['acc'][-1],
            'final_bias': all_results[gamma]['history']['bias'][-1]
        }
        for gamma in GAMMA_VALUES
    }
    pickle.dump(summary, f)
print("‚úÖ Saved: results_v2/all_results_summary.pkl")

print(f"\n" + "=" * 60)
print("üìä FINAL RESULTS SUMMARY")
print("=" * 60)
for gamma in GAMMA_VALUES:
    name = all_results[gamma]['name']
    acc = all_results[gamma]['history']['acc'][-1]
    bias = all_results[gamma]['history']['bias'][-1]
    print(f"   {name}: {acc:.2f}% accuracy, {bias:.6f} bias")
print("=" * 60)
print("\nüìä Check WandB dashboard: https://wandb.ai")