# Strict 4-Phase Fed-AuditGAN (CIFAR-10)
**Branch:** strict-4-phase

This notebook runs a FedAvg baseline and Strict Fed-AuditGAN with gamma=2 and gamma=8 on CIFAR-10, using the formulas from the specification, which are listed in full at the end of the notebook.

In [None]:
!pip install torch torchvision tqdm matplotlib numpy wandb -q
print("Dependencies installed!")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import numpy as np
import copy
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import warnings
import wandb
warnings.filterwarnings('ignore')

# ========== GPU Configuration with optimizations ==========
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
NUM_GPUS = torch.cuda.device_count()

# Enable cuDNN optimizations
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    torch.cuda.empty_cache()

print(f"Using device: {DEVICE} | GPUs available: {NUM_GPUS}")

# ========== AGGRESSIVE GPU UTILIZATION ==========
NUM_WORKERS = 4 if NUM_GPUS > 0 else 0
PIN_MEMORY = torch.cuda.is_available()
BATCH_SIZE = 1024 if NUM_GPUS > 1 else 512
VAL_BATCH_SIZE = 2048 if NUM_GPUS > 1 else 1024
PREFETCH_FACTOR = 4 if NUM_WORKERS > 0 else None
GAN_BATCH_SIZE = 512
N_PROBES = 2000

# Experiment settings
N_ROUNDS = 50
N_CLIENTS = 20
LOCAL_EPOCHS = 3
GAN_EPOCHS = 20

# CIFAR-10 specific
IMG_SHAPE = (3, 32, 32)
NUM_CLASSES = 10

# ========== WandB Configuration ==========
USE_WANDB = True  # Set to False to disable WandB logging
WANDB_PROJECT = "Fed-AuditGAN-Strict-CIFAR10"
WANDB_ENTITY = None  # Set your WandB username/team or leave None

print(f"\n{'='*60}")
print(f"CIFAR-10 Strict 4-Phase Fed-AuditGAN - GPU OPTIMIZED")
print(f"{'='*60}")
print(f"GPUs: {NUM_GPUS}")
print(f"Batch Size: {BATCH_SIZE} | Val Batch: {VAL_BATCH_SIZE}")
print(f"Workers: {NUM_WORKERS} | Probes: {N_PROBES}")
print(f"Rounds: {N_ROUNDS} | Clients: {N_CLIENTS} | Local Epochs: {LOCAL_EPOCHS}")
print(f"WandB: {'Enabled' if USE_WANDB else 'Disabled'}")
print(f"{'='*60}")

In [None]:
# ========== WandB Login ==========
if USE_WANDB:
    wandb.login()
    print("WandB logged in successfully!")
else:
    print("WandB is disabled. Set USE_WANDB = True to enable logging.")

In [None]:
# ========== Data Preparation ==========
print("Pre-loading CIFAR-10 dataset...")

# CIFAR-10 transforms with augmentation for training
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

TRAIN_DATA = datasets.CIFAR10('./data', train=True, download=True, transform=transform_train)
TEST_DATA = datasets.CIFAR10('./data', train=False, download=True, transform=transform_test)

# For GAN training - no augmentation
transform_gan = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
GAN_TRAIN_DATA = datasets.CIFAR10('./data', train=True, download=True, transform=transform_gan)

# Non-IID partition: each client gets only 1-2 classes (extreme heterogeneity)
def partition_non_iid_extreme(dataset, n_clients, classes_per_client=2, seed=42):
    np.random.seed(seed)
    # Read labels from dataset.targets; indexing dataset[i] would decode and augment every image
    labels = np.array(dataset.targets)
    class_indices = {c: np.where(labels == c)[0] for c in range(NUM_CLASSES)}
    client_data = []
    classes_assigned = []
    for cid in range(n_clients):
        n_classes = np.random.randint(1, classes_per_client + 1)
        client_classes = np.random.choice(NUM_CLASSES, n_classes, replace=False)
        classes_assigned.append(client_classes.tolist())
        client_idx = []
        for c in client_classes:
            n_samples = len(class_indices[c]) // (n_clients // 2)
            start = np.random.randint(0, max(1, len(class_indices[c]) - n_samples))
            client_idx.extend(class_indices[c][start:start + n_samples])
        np.random.shuffle(client_idx)
        client_data.append(np.array(client_idx))
    return client_data, classes_assigned

CLIENT_IDX, CLIENT_CLASSES = partition_non_iid_extreme(TRAIN_DATA, N_CLIENTS)
print(f"Client data sizes: {[len(idx) for idx in CLIENT_IDX]}")
print(f"Client classes: {CLIENT_CLASSES}")

# Pre-create DataLoaders with GPU optimizations
loader_kwargs = {'num_workers': NUM_WORKERS, 'pin_memory': PIN_MEMORY}
if NUM_WORKERS > 0:
    loader_kwargs['persistent_workers'] = True
    loader_kwargs['prefetch_factor'] = PREFETCH_FACTOR

TEST_LOADER = DataLoader(TEST_DATA, batch_size=VAL_BATCH_SIZE, shuffle=False, **loader_kwargs)

# Server-side audit set for GAN training: 5,000 unaugmented training images scaled to [-1, 1]
VAL_IDX = np.random.choice(len(GAN_TRAIN_DATA), 5000, replace=False)
VAL_LOADER = DataLoader(Subset(GAN_TRAIN_DATA, VAL_IDX), batch_size=GAN_BATCH_SIZE, shuffle=True, **loader_kwargs)

# Client training loaders
CLIENT_LOADERS = []
for cid in range(N_CLIENTS):
    loader = DataLoader(Subset(TRAIN_DATA, CLIENT_IDX[cid]), batch_size=BATCH_SIZE, shuffle=True, **loader_kwargs)
    CLIENT_LOADERS.append(loader)

# Client evaluation loaders
CLIENT_EVAL_LOADERS = []
for cid in range(N_CLIENTS):
    loader = DataLoader(Subset(TRAIN_DATA, CLIENT_IDX[cid]), batch_size=VAL_BATCH_SIZE, shuffle=False, **loader_kwargs)
    CLIENT_EVAL_LOADERS.append(loader)

print("Data loaders ready.")

In [None]:
# ========== CNN Model for CIFAR-10 ==========
class CNN(nn.Module):
    def __init__(self, num_classes=10):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.conv4 = nn.Conv2d(256, 512, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(512)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.5)
        self.fc1 = nn.Linear(512 * 2 * 2, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.bn1(self.conv1(x))))  # 32->16
        x = self.pool(F.relu(self.bn2(self.conv2(x))))  # 16->8
        x = self.pool(F.relu(self.bn3(self.conv3(x))))  # 8->4
        x = self.pool(F.relu(self.bn4(self.conv4(x))))  # 4->2
        x = x.view(-1, 512 * 2 * 2)
        x = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(x)

# ========== Fairness Generator (Strict Spec) ==========
class FairnessGenerator(nn.Module):
    def __init__(self, latent_dim=128, num_classes=10, img_shape=(3, 32, 32)):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        self.img_shape = img_shape
        self.label_emb = nn.Embedding(num_classes, latent_dim)
        self.init_size = img_shape[1] // 4  # 8
        self.l1 = nn.Linear(latent_dim * 2, 256 * self.init_size ** 2)
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(256),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(256, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(256, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, img_shape[0], 3, 1, 1),
            nn.Tanh()
        )
        self.delta_net = nn.Sequential(
            nn.Linear(latent_dim, 512),
            nn.ReLU(),
            nn.Linear(512, int(np.prod(img_shape))),
            nn.Tanh()
        )
        self.delta_scale = 0.1

    def forward(self, z, labels):
        gen_input = torch.cat([z, self.label_emb(labels)], dim=1)
        out = self.l1(gen_input)
        out = out.view(-1, 256, self.init_size, self.init_size)
        x = self.conv_blocks(out)
        delta = self.delta_net(z).view(-1, *self.img_shape) * self.delta_scale
        x_prime = torch.clamp(x + delta, -1, 1)
        return x, x_prime

# ========== Discriminator ==========
class Discriminator(nn.Module):
    def __init__(self, num_classes=10, img_shape=(3, 32, 32)):
        super().__init__()
        self.num_classes = num_classes
        self.img_shape = img_shape
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.conv = nn.Sequential(
            nn.Conv2d(img_shape[0] + num_classes, 64, 3, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 3, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 3, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 512, 3, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2)
        )
        self.fc = nn.Linear(512 * 2 * 2, 1)

    def forward(self, img, labels):
        label_map = self.label_emb(labels).view(-1, self.num_classes, 1, 1)
        label_map = label_map.expand(-1, -1, self.img_shape[1], self.img_shape[2])
        out = self.conv(torch.cat([img, label_map], dim=1))
        return self.fc(out.view(out.size(0), -1))

# ========== Helper Functions ==========
def make_parallel(model):
    if NUM_GPUS > 1:
        return nn.DataParallel(model)
    return model

def get_base_model(model):
    if isinstance(model, nn.DataParallel):
        return model.module
    return model

print("Models and helpers defined.")

In [None]:
# ========== FAIRNESS METRICS ==========
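def compute_jfi(performances):
    """Jain's Fairness Index: (sum p)^2 / (n * sum p^2), ranging from 1/n to 1.0 (all clients equal)."""
    p = np.array(performances)
    n = len(p)
    # All-zero performances are treated as perfectly (if trivially) fair
    if np.sum(p**2) == 0:
        return 1.0
    return (np.sum(p)**2) / (n * np.sum(p**2))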

def compute_max_min_fairness(performances):
    p = np.array(performances)
    if np.max(p) == 0:
        return 0.0
    return np.min(p) / np.max(p)

def compute_fairness_variance(performances):
    return np.var(np.array(performances))

@torch.no_grad()
def evaluate_per_client(model, client_loaders):
    model.eval()
    client_accuracies = []
    with torch.amp.autocast(device_type='cuda', enabled=(DEVICE=='cuda')):
        for loader in client_loaders:
            correct, total = 0, 0
            for d, t in loader:
                d, t = d.to(DEVICE, non_blocking=True), t.to(DEVICE, non_blocking=True)
                correct += (model(d).argmax(1) == t).sum().item()
                total += len(t)
            client_accuracies.append(100 * correct / total if total > 0 else 0)
    return client_accuracies

def compute_all_fairness_metrics(client_accuracies):
    return {
        'jfi': compute_jfi(client_accuracies),
        'max_min_fairness': compute_max_min_fairness(client_accuracies),
        'fairness_variance': compute_fairness_variance(client_accuracies),
        'min_accuracy': np.min(client_accuracies),
        'max_accuracy': np.max(client_accuracies),
        'mean_accuracy': np.mean(client_accuracies)
    }

@torch.no_grad()
def evaluate(model, loader):
    model.eval()
    correct, total = 0, 0
    with torch.amp.autocast(device_type='cuda', enabled=(DEVICE=='cuda')):
        for d, t in loader:
            d, t = d.to(DEVICE, non_blocking=True), t.to(DEVICE, non_blocking=True)
            correct += (model(d).argmax(1) == t).sum().item()
            total += len(t)
    return 100 * correct / total

print("Fairness metrics defined.")

In [None]:
# ========== STRICT 4-PHASE IMPLEMENTATION ==========

# --- Phase 2: Train Fairness Generator ---
# Loss: L_G = -||Theta(x) - Theta(x')||^2 + lambda1*||x - x'||^2 + lambda2*L_realism
def train_fairness_gan(G, D, model, loader, epochs=20, lambda1=1.0, lambda2=1.0):
    model.eval()
    G = make_parallel(G)
    D = make_parallel(D)
    
    opt_G = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
    opt_D = optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
    bce = nn.BCEWithLogitsLoss()
    scaler = torch.amp.GradScaler(device='cuda', enabled=(DEVICE=='cuda'))
    
    latent_dim = get_base_model(G).latent_dim
    num_classes = get_base_model(G).num_classes
    
    for epoch in range(epochs):
        for imgs, labels in loader:
            batch_size = imgs.size(0)
            real_t = torch.ones(batch_size, 1, device=DEVICE)
            fake_t = torch.zeros(batch_size, 1, device=DEVICE)
            imgs = imgs.to(DEVICE, non_blocking=True)
            labels = labels.to(DEVICE, non_blocking=True)
            z = torch.randn(batch_size, latent_dim, device=DEVICE)
            gen_labels = torch.randint(0, num_classes, (batch_size,), device=DEVICE)
            
            # Generator: L_G = -||Theta(x)-Theta(x')||^2 + lambda1*||x-x'||^2 + lambda2*L_realism
            opt_G.zero_grad(set_to_none=True)
            with torch.amp.autocast(device_type='cuda', enabled=(DEVICE=='cuda')):
                x, x_prime = G(z, gen_labels)
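                # No torch.no_grad() here: Term 1 must back-propagate through the
                # classifier into G. Only G's parameters are stepped by opt_G.
                pred_x, pred_xp = model(x), model(x_prime)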
                # Term 1: MAXIMIZE prediction difference (negative sign)
                pred_diff = -torch.mean((pred_x - pred_xp) ** 2)
                # Term 2: MINIMIZE feature difference
                feature_diff = lambda1 * torch.mean((x - x_prime) ** 2)
                # Term 3: Realism
                gan_loss = lambda2 * (bce(D(x, gen_labels), real_t) + bce(D(x_prime, gen_labels), real_t)) / 2
                g_loss = pred_diff + feature_diff + gan_loss
            scaler.scale(g_loss).backward()
            scaler.step(opt_G)
            
            # Discriminator
            opt_D.zero_grad(set_to_none=True)
            with torch.amp.autocast(device_type='cuda', enabled=(DEVICE=='cuda')):
                x, x_prime = G(z, gen_labels)
                d_real = bce(D(imgs, labels), real_t)
                d_fake = (bce(D(x.detach(), gen_labels), fake_t) + bce(D(x_prime.detach(), gen_labels), fake_t)) / 2
                d_loss = (d_real + d_fake) / 2
            scaler.scale(d_loss).backward()
            scaler.step(opt_D)
            scaler.update()
    
    return get_base_model(G), get_base_model(D)

# --- Phase 3: Compute Bias (EXACT formula) ---
# B = (1/|P|) * sum |Theta(x) - Theta(x')|
@torch.no_grad()
def compute_bias(model, x, x_prime):
    model.eval()
    with torch.amp.autocast(device_type='cuda', enabled=(DEVICE=='cuda')):
        pred_x = model(x)
        pred_xp = model(x_prime)
    # |Theta(x) - Theta(x')| summed over classes, averaged over probes
    return torch.abs(pred_x - pred_xp).sum(dim=1).mean().item()

# --- Phase 4: Softmax Aggregation Weights (EXACT formula) ---
# alpha_i = exp(gamma * S_i) / sum_j exp(gamma * S_j)
def compute_softmax_weights(S_list, gamma):
    S_tensor = torch.tensor(S_list, dtype=torch.float32)
    scaled = gamma * S_tensor
    alphas = F.softmax(scaled, dim=0).tolist()
    return alphas

print("Strict 4-phase functions defined.")

In [None]:
# ========== RUN FEDAVG BASELINE ==========
def run_fedavg(n_rounds=N_ROUNDS, n_clients=N_CLIENTS, local_epochs=LOCAL_EPOCHS, lr=0.01, run_name="FedAvg"):
    print(f"FedAvg | Rounds: {n_rounds} | Clients: {n_clients} | GPUs: {NUM_GPUS}")
    
    # Initialize WandB run
    if USE_WANDB:
        wandb.init(
            project=WANDB_PROJECT,
            entity=WANDB_ENTITY,
            name=run_name,
            config={
                "method": "FedAvg",
                "n_rounds": n_rounds,
                "n_clients": n_clients,
                "local_epochs": local_epochs,
                "lr": lr,
                "batch_size": BATCH_SIZE,
                "dataset": "CIFAR-10",
                "device": DEVICE,
                "num_gpus": NUM_GPUS
            },
            reinit=True
        )
    
    global_model = make_parallel(CNN().to(DEVICE))
    scaler = torch.amp.GradScaler(device='cuda', enabled=(DEVICE=='cuda'))
    
    history = {'accuracy': [], 'jfi': [], 'max_min_fairness': [], 'fairness_variance': [],
               'min_accuracy': [], 'max_accuracy': [], 'client_accuracies': []}
    
    for rnd in tqdm(range(n_rounds), desc="FedAvg"):
        updates = []
        client_sizes = []
        
        for cid in range(n_clients):
            local = make_parallel(copy.deepcopy(get_base_model(global_model)).to(DEVICE))
            before = {k: v.clone() for k, v in global_model.state_dict().items()}
            opt = optim.SGD(local.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)
            local.train()
            
            for _ in range(local_epochs):
                for d, t in CLIENT_LOADERS[cid]:
                    d, t = d.to(DEVICE, non_blocking=True), t.to(DEVICE, non_blocking=True)
                    opt.zero_grad(set_to_none=True)
                    with torch.amp.autocast(device_type='cuda', enabled=(DEVICE=='cuda')):
                        loss = F.cross_entropy(local(d), t)
                    scaler.scale(loss).backward()
                    scaler.step(opt)
                    scaler.update()
            
            after = local.state_dict()
            updates.append({k: after[k] - before[k] for k in before})
            client_sizes.append(len(CLIENT_IDX[cid]))
            del local
            if DEVICE == 'cuda': torch.cuda.empty_cache()
        
        # FedAvg aggregation: weighted by data size
        total_size = sum(client_sizes)
        alphas = [size / total_size for size in client_sizes]
        
        sd = global_model.state_dict()
        for k in sd:
            sd[k] = sd[k] + sum(alphas[i] * updates[i][k] for i in range(n_clients))
        global_model.load_state_dict(sd)
        
        acc = evaluate(global_model, TEST_LOADER)
        client_accs = evaluate_per_client(global_model, CLIENT_EVAL_LOADERS)
        fairness = compute_all_fairness_metrics(client_accs)
        
        history['accuracy'].append(acc)
        history['jfi'].append(fairness['jfi'])
        history['max_min_fairness'].append(fairness['max_min_fairness'])
        history['fairness_variance'].append(fairness['fairness_variance'])
        history['min_accuracy'].append(fairness['min_accuracy'])
        history['max_accuracy'].append(fairness['max_accuracy'])
        history['client_accuracies'].append(client_accs)
        
        # Log to WandB
        if USE_WANDB:
            log_dict = {
                "round": rnd + 1,
                "accuracy": acc,
                "jfi": fairness['jfi'],
                "max_min_fairness": fairness['max_min_fairness'],
                "fairness_variance": fairness['fairness_variance'],
                "min_accuracy": fairness['min_accuracy'],
                "max_accuracy": fairness['max_accuracy'],
                "accuracy_gap": fairness['max_accuracy'] - fairness['min_accuracy']
            }
            # Log per-client accuracies
            for i, c_acc in enumerate(client_accs):
                log_dict[f"client_{i}_accuracy"] = c_acc
            wandb.log(log_dict)
    
    # Finish WandB run
    if USE_WANDB:
        wandb.finish()
    
    print(f"FedAvg Done | Accuracy: {history['accuracy'][-1]:.2f}% | JFI: {history['jfi'][-1]:.4f}")
    return history

print("FedAvg function ready.")

In [None]:
# ========== RUN STRICT FED-AUDITGAN (4-PHASE) ==========
def run_strict_fed_audit_gan(gamma, n_rounds=N_ROUNDS, n_clients=N_CLIENTS, local_epochs=LOCAL_EPOCHS, lr=0.01, gan_epochs=GAN_EPOCHS, n_probes=N_PROBES):
    print(f"Strict Fed-AuditGAN | gamma={gamma} | Rounds: {n_rounds} | Clients: {n_clients}")
    
    # Initialize WandB run
    run_name = f"Strict-gamma{gamma}"
    if USE_WANDB:
        wandb.init(
            project=WANDB_PROJECT,
            entity=WANDB_ENTITY,
            name=run_name,
            config={
                "method": "Strict-Fed-AuditGAN",
                "gamma": gamma,
                "n_rounds": n_rounds,
                "n_clients": n_clients,
                "local_epochs": local_epochs,
                "lr": lr,
                "gan_epochs": gan_epochs,
                "n_probes": n_probes,
                "batch_size": BATCH_SIZE,
                "dataset": "CIFAR-10",
                "device": DEVICE,
                "num_gpus": NUM_GPUS
            },
            reinit=True
        )
    
    global_model = make_parallel(CNN().to(DEVICE))
    scaler = torch.amp.GradScaler(device='cuda', enabled=(DEVICE=='cuda'))
    
    history = {'accuracy': [], 'bias': [], 'jfi': [], 'max_min_fairness': [], 'fairness_variance': [],
               'min_accuracy': [], 'max_accuracy': [], 'client_accuracies': [], 'S_list': [], 'alphas': []}
    
    for rnd in tqdm(range(n_rounds), desc=f"gamma={gamma}"):
        # ===== PHASE 1: Local Training =====
        # Each client computes delta_w_i = w_after - w_before
        updates = []
        for cid in range(n_clients):
            local = make_parallel(copy.deepcopy(get_base_model(global_model)).to(DEVICE))
            before = {k: v.clone() for k, v in global_model.state_dict().items()}
            opt = optim.SGD(local.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)
            local.train()
            
            for _ in range(local_epochs):
                for d, t in CLIENT_LOADERS[cid]:
                    d, t = d.to(DEVICE, non_blocking=True), t.to(DEVICE, non_blocking=True)
                    opt.zero_grad(set_to_none=True)
                    with torch.amp.autocast(device_type='cuda', enabled=(DEVICE=='cuda')):
                        loss = F.cross_entropy(local(d), t)
                    scaler.scale(loss).backward()
                    scaler.step(opt)
                    scaler.update()
            
            after = local.state_dict()
            updates.append({k: after[k] - before[k] for k in before})
            del local
            if DEVICE == 'cuda': torch.cuda.empty_cache()
        
        # ===== PHASE 2: Server-Side GAN Auditing =====
        # Train G to find bias: L_G = -||Theta(x)-Theta(x')||^2 + lambda1*||x-x'||^2 + lambda2*L_realism
        G = FairnessGenerator(img_shape=IMG_SHAPE).to(DEVICE)
        D = Discriminator(img_shape=IMG_SHAPE).to(DEVICE)
        G, D = train_fairness_gan(G, D, global_model, VAL_LOADER, epochs=gan_epochs)
        G.eval()
        
        with torch.no_grad():
            z = torch.randn(n_probes, G.latent_dim, device=DEVICE)
            lbls = torch.randint(0, NUM_CLASSES, (n_probes,), device=DEVICE)
            with torch.amp.autocast(device_type='cuda', enabled=(DEVICE=='cuda')):
                G_par = make_parallel(G)
                x_p, xp_p = G_par(z, lbls)
        
        # ===== PHASE 3: Fairness Contribution Scoring =====
        # B_base = (1/|P|) * sum |Theta_old(x) - Theta_old(x')|
        B_base = compute_bias(global_model, x_p, xp_p)
        
        S_list = []
        for i, upd in enumerate(updates):
            # Theta_test_i = Theta_old + delta_w_i
            hyp = make_parallel(copy.deepcopy(get_base_model(global_model)).to(DEVICE))
            sd = hyp.state_dict()
            for k in sd:
                if k in upd:
                    sd[k] = sd[k] + upd[k]
            hyp.load_state_dict(sd)
            # B_i = (1/|P|) * sum |Theta_test_i(x) - Theta_test_i(x')|
            B_i = compute_bias(hyp, x_p, xp_p)
            # S_i = B_base - B_i (positive = reduced bias)
            S_i = B_base - B_i
            S_list.append(S_i)
            del hyp
        
        del G, D, x_p, xp_p
        if DEVICE == 'cuda': torch.cuda.empty_cache()
        
        # ===== PHASE 4: Rewards & Aggregation =====
        # alpha_i = exp(gamma * S_i) / sum_j exp(gamma * S_j)  (Softmax)
        alphas = compute_softmax_weights(S_list, gamma)
        
        # Theta_new = Theta_old + sum alpha_i * delta_w_i
        sd = global_model.state_dict()
        for k in sd:
            sd[k] = sd[k] + sum(alphas[i] * updates[i][k] for i in range(n_clients))
        global_model.load_state_dict(sd)
        
        # Evaluate
        acc = evaluate(global_model, TEST_LOADER)
        client_accs = evaluate_per_client(global_model, CLIENT_EVAL_LOADERS)
        fairness = compute_all_fairness_metrics(client_accs)
        
        history['accuracy'].append(acc)
        history['bias'].append(B_base)
        history['jfi'].append(fairness['jfi'])
        history['max_min_fairness'].append(fairness['max_min_fairness'])
        history['fairness_variance'].append(fairness['fairness_variance'])
        history['min_accuracy'].append(fairness['min_accuracy'])
        history['max_accuracy'].append(fairness['max_accuracy'])
        history['client_accuracies'].append(client_accs)
        history['S_list'].append(S_list)
        history['alphas'].append(alphas)
        
        # Log to WandB
        if USE_WANDB:
            log_dict = {
                "round": rnd + 1,
                "accuracy": acc,
                "bias": B_base,
                "jfi": fairness['jfi'],
                "max_min_fairness": fairness['max_min_fairness'],
                "fairness_variance": fairness['fairness_variance'],
                "min_accuracy": fairness['min_accuracy'],
                "max_accuracy": fairness['max_accuracy'],
                "accuracy_gap": fairness['max_accuracy'] - fairness['min_accuracy']
            }
            # Log per-client accuracies, scores, and weights
            for i in range(n_clients):
                log_dict[f"client_{i}_accuracy"] = client_accs[i]
                log_dict[f"client_{i}_score"] = S_list[i]
                log_dict[f"client_{i}_weight"] = alphas[i]
            wandb.log(log_dict)
    
    # Finish WandB run
    if USE_WANDB:
        wandb.finish()
    
    print(f"gamma={gamma} Done | Accuracy: {history['accuracy'][-1]:.2f}% | JFI: {history['jfi'][-1]:.4f}")
    return history

print("Strict Fed-AuditGAN function ready.")

In [None]:
# ========== RUN ALL EXPERIMENTS ==========
results = {}

print('\n' + '='*60)
print('Running FedAvg Baseline')
print('='*60)
results['FedAvg'] = run_fedavg(n_rounds=N_ROUNDS, n_clients=N_CLIENTS)

for gamma in [2, 8]:
    print('\n' + '='*60)
    print(f'Running Strict Fed-AuditGAN gamma = {gamma}')
    print('='*60)
    results[f'gamma{gamma}'] = run_strict_fed_audit_gan(gamma=gamma, n_rounds=N_ROUNDS, n_clients=N_CLIENTS)

print('\n' + '='*60)
print('All experiments complete!')
print('='*60)

In [None]:
# ========== VISUALIZATION ==========
fig, axes = plt.subplots(2, 3, figsize=(18, 10))
colors = {'FedAvg': 'black', 'gamma2': 'blue', 'gamma8': 'red'}
linestyles = {'FedAvg': '--', 'gamma2': '-', 'gamma8': '-'}
rounds = list(range(1, N_ROUNDS + 1))

# Plot 1: Global Accuracy
for method in results:
    axes[0, 0].plot(rounds, results[method]['accuracy'], linestyle=linestyles[method],
                    label=method, color=colors[method], linewidth=2)
axes[0, 0].set_xlabel('Round'); axes[0, 0].set_ylabel('Accuracy (%)')
axes[0, 0].set_title('CIFAR-10: Global Test Accuracy'); axes[0, 0].legend(); axes[0, 0].grid(True, alpha=0.3)

# Plot 2: JFI
for method in results:
    axes[0, 1].plot(rounds, results[method]['jfi'], linestyle=linestyles[method],
                    label=method, color=colors[method], linewidth=2)
axes[0, 1].set_xlabel('Round'); axes[0, 1].set_ylabel('JFI')
axes[0, 1].set_title("Jain's Fairness Index (higher=fairer)"); axes[0, 1].legend(); axes[0, 1].grid(True, alpha=0.3)

# Plot 3: Max-Min Fairness
for method in results:
    axes[0, 2].plot(rounds, results[method]['max_min_fairness'], linestyle=linestyles[method],
                    label=method, color=colors[method], linewidth=2)
axes[0, 2].set_xlabel('Round'); axes[0, 2].set_ylabel('Min/Max Ratio')
axes[0, 2].set_title('Max-Min Fairness (higher=fairer)'); axes[0, 2].legend(); axes[0, 2].grid(True, alpha=0.3)

# Plot 4: Variance
for method in results:
    axes[1, 0].plot(rounds, results[method]['fairness_variance'], linestyle=linestyles[method],
                    label=method, color=colors[method], linewidth=2)
axes[1, 0].set_xlabel('Round'); axes[1, 0].set_ylabel('Variance')
axes[1, 0].set_title('Per-Client Accuracy Variance (lower=fairer)'); axes[1, 0].legend(); axes[1, 0].grid(True, alpha=0.3)

# Plot 5: Accuracy Gap
for method in results:
    gap = [results[method]['max_accuracy'][i] - results[method]['min_accuracy'][i] for i in range(len(rounds))]
    axes[1, 1].plot(rounds, gap, linestyle=linestyles[method],
                    label=method, color=colors[method], linewidth=2)
axes[1, 1].set_xlabel('Round'); axes[1, 1].set_ylabel('Gap (%)')
axes[1, 1].set_title('Best-Worst Client Gap (lower=fairer)'); axes[1, 1].legend(); axes[1, 1].grid(True, alpha=0.3)

# Plot 6: Final Per-Client Accuracy
final_client_accs = {m: results[m]['client_accuracies'][-1] for m in results}
x = np.arange(N_CLIENTS)
width = 0.25
for i, method in enumerate(results):
    axes[1, 2].bar(x + i*width, final_client_accs[method], width, label=method, color=colors[method], alpha=0.8)
axes[1, 2].set_xlabel('Client ID'); axes[1, 2].set_ylabel('Accuracy (%)')
axes[1, 2].set_title(f'Per-Client Accuracy (Round {N_ROUNDS})'); axes[1, 2].legend(); axes[1, 2].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('cifar10_strict_4phase_results.png', dpi=150)
plt.show()

# ========== Log Summary to WandB ==========
if USE_WANDB:
    # Create a summary run to log comparison plots and tables
    wandb.init(
        project=WANDB_PROJECT,
        entity=WANDB_ENTITY,
        name="Summary-Comparison",
        config={"type": "summary"},
        reinit=True
    )
    
    # Log the comparison plot
    wandb.log({"comparison_plot": wandb.Image('cifar10_strict_4phase_results.png')})
    
    # Create summary table
    summary_data = []
    for method in results:
        summary_data.append([
            method,
            results[method]['accuracy'][-1],
            results[method]['jfi'][-1],
            results[method]['max_min_fairness'][-1],
            results[method]['fairness_variance'][-1],
            results[method]['min_accuracy'][-1],
            results[method]['max_accuracy'][-1],
            results[method]['max_accuracy'][-1] - results[method]['min_accuracy'][-1]
        ])
    
    summary_table = wandb.Table(
        columns=["Method", "Accuracy", "JFI", "Max-Min", "Variance", "Min Acc", "Max Acc", "Gap"],
        data=summary_data
    )
    wandb.log({"summary_table": summary_table})
    wandb.finish()
    print("Summary logged to WandB!")

# Summary Table
print('\n' + '='*110)
print(f'CIFAR-10 STRICT 4-PHASE RESULTS SUMMARY (Round {N_ROUNDS})')
print('='*110)
print(f"{'Method':<15} {'Global Acc':>10} {'JFI':>8} {'Max-Min':>10} {'Variance':>12} {'Min Acc':>10} {'Max Acc':>10} {'Gap':>8}")
print('-'*110)
for method in results:
    acc = results[method]['accuracy'][-1]
    jfi = results[method]['jfi'][-1]
    mmf = results[method]['max_min_fairness'][-1]
    var = results[method]['fairness_variance'][-1]
    min_acc = results[method]['min_accuracy'][-1]
    max_acc = results[method]['max_accuracy'][-1]
    gap = max_acc - min_acc
    print(f"{method:<15} {acc:>10.2f}% {jfi:>8.4f} {mmf:>10.4f} {var:>12.2f} {min_acc:>10.2f}% {max_acc:>10.2f}% {gap:>8.2f}%")
print('='*110)

# Bias summary
print('\nBias Summary (Strict Fed-AuditGAN only):')
for method in ['gamma2', 'gamma8']:
    if method in results and 'bias' in results[method]:
        print(f"  {method}: Final Bias = {results[method]['bias'][-1]:.4f}")

---
# Strict 4-Phase Fed-AuditGAN: Full Formulas

## Phase 1: Local Client Training
Each client $i$ trains locally and computes the update:
$$\Delta w_i = w_{\text{after}} - w_{\text{before}}$$
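
As a rough illustration of the update above, the sketch below computes $\Delta w_i$ key by key on plain Python lists. The helper name `compute_update` and the toy parameter names are assumptions made for this example; the notebook performs the same subtraction on `state_dict()` tensors.

```python
def compute_update(w_before, w_after):
    """Return delta_w = w_after - w_before for every parameter key."""
    return {
        key: [after - before for after, before in zip(w_after[key], w_before[key])]
        for key in w_before
    }

# Hypothetical weights before and after one client's local training
w_before = {"fc.weight": [0.5, -0.25], "fc.bias": [0.0]}
w_after = {"fc.weight": [0.75, -0.5], "fc.bias": [0.125]}

delta_w = compute_update(w_before, w_after)
print(delta_w)  # {'fc.weight': [0.25, -0.25], 'fc.bias': [0.125]}
```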

---

## Phase 2: Server-Side Generative Auditing (GAN Training)
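The generator creates counterfactual pairs $(x, x')$ where $x' = x + \delta$ and $\delta$ is a small learned perturbation (scaled by 0.1 in `FairnessGenerator`, with $x'$ clamped to $[-1, 1]$).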

**Generator Loss:**
$$\mathcal{L}_G = -\|\Theta(x) - \Theta(x')\|^2 + \lambda_1 \|x - x'\|^2 + \lambda_2 \mathcal{L}_{\text{realism}}$$

- **Term 1:** $-\|\Theta(x) - \Theta(x')\|^2$ → **Maximize** prediction difference (find bias)
- **Term 2:** $\lambda_1 \|x - x'\|^2$ → **Minimize** feature difference (keep $x$ and $x'$ similar)
- **Term 3:** $\lambda_2 \mathcal{L}_{\text{realism}}$ → Keep generated samples realistic (adversarial loss)
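
To make the three terms concrete, here is a minimal framework-free sketch that evaluates $\mathcal{L}_G$ for a single toy pair. The function names, the toy vectors, and the fixed `realism` value are assumptions for illustration only; in the notebook, `train_fairness_gan` uses batch means of squared differences and a discriminator-based BCE term instead.

```python
def squared_l2(u, v):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((a - b) ** 2 for a, b in zip(u, v))

def generator_loss(pred_x, pred_xp, x, x_prime, realism, lambda1=1.0, lambda2=1.0):
    term1 = -squared_l2(pred_x, pred_xp)     # reward a large prediction gap
    term2 = lambda1 * squared_l2(x, x_prime)  # penalise a large input gap
    term3 = lambda2 * realism                 # adversarial realism penalty
    return term1 + term2 + term3

# Hypothetical classifier outputs and inputs for one counterfactual pair
loss = generator_loss(
    pred_x=[2.0, 0.0], pred_xp=[0.0, 1.0],
    x=[0.5, 0.5], x_prime=[0.5, 1.0],
    realism=0.5,
)
print(loss)  # -4.25
```

A more negative loss means the generator found a pair that is close in input space yet far apart in prediction space, which is the kind of probe the audit is looking for.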

---

## Phase 3: Fairness Contribution Scoring

### Step 1: Measure Baseline Bias
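$$B_{\text{base}} = \frac{1}{|P|} \sum_{(x,x') \in P} |\Theta_{\text{old}}(x) - \Theta_{\text{old}}(x')|$$

Here $P$ is the set of probe pairs produced by the generator, and $|\cdot|$ denotes the L1 norm over the class logits, so each pair contributes the sum of its absolute logit differences (as in `compute_bias`).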

### Step 2: Hypothetical Application
$$\Theta_{\text{test}_i} = \Theta_{\text{old}} + \Delta w_i$$

### Step 3: Measure Client Bias
$$B_i = \frac{1}{|P|} \sum_{(x,x') \in P} |\Theta_{\text{test}_i}(x) - \Theta_{\text{test}_i}(x')|$$

### Step 4: Calculate Fairness Score
$$S_i = B_{\text{base}} - B_i$$

- **$S_i > 0$**: Client **reduced** bias ✓ (Good!)
- **$S_i < 0$**: Client **increased** bias ✗ (Bad!)
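
The four scoring steps can be sketched on toy logits as follows. The `bias` helper and every number here are hypothetical stand-ins for model outputs on the probe set; the notebook's `compute_bias` performs the equivalent computation on tensors.

```python
def bias(preds_x, preds_xp):
    """Mean over probe pairs of the L1 distance between the two logit vectors."""
    per_pair = [
        sum(abs(a - b) for a, b in zip(p, q))
        for p, q in zip(preds_x, preds_xp)
    ]
    return sum(per_pair) / len(per_pair)

# Hypothetical logits of the current global model on two probe pairs
base_x = [[1.0, 0.0], [0.0, 2.0]]
base_xp = [[0.0, 0.0], [0.0, 1.0]]
b_base = bias(base_x, base_xp)

# Hypothetical logits after applying each client's update to the global model
client_logits = {
    "helpful": ([[1.0, 0.0], [0.0, 2.0]], [[0.5, 0.0], [0.0, 1.5]]),
    "harmful": ([[1.0, 0.0], [0.0, 2.0]], [[-1.0, 0.0], [0.0, 0.0]]),
}
scores = {name: b_base - bias(px, pxp) for name, (px, pxp) in client_logits.items()}
print(b_base, scores)  # 1.0 {'helpful': 0.5, 'harmful': -1.0}
```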

---

## Phase 4: Rewards and Aggregation

### Softmax Aggregation Weights:
$$\alpha_i = \frac{\exp(\gamma \cdot S_i)}{\sum_{j=1}^{K} \exp(\gamma \cdot S_j)}$$

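- **$\gamma$ (gamma)**: Scaling factor applied to the scores before the softmax (an inverse temperature)
  - Higher $\gamma$ → **Strongly punish** biased clients (more extreme weight differences)
  - Lower $\gamma$ → More uniform weighting; at $\gamma = 0$ every client receives weight $1/K$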
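
The effect of $\gamma$ can be seen in a small standard-library sketch. The function name `softmax_weights` and the three scores are assumptions chosen for illustration; the notebook computes the same quantity with `F.softmax`.

```python
import math

def softmax_weights(scores, gamma):
    """alpha_i = exp(gamma * S_i) / sum_j exp(gamma * S_j)."""
    scaled = [gamma * s for s in scores]
    shift = max(scaled)  # subtracting the max avoids overflow and leaves the result unchanged
    exps = [math.exp(v - shift) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]

scores = [0.5, 0.0, -1.0]  # hypothetical S_i values for three clients
for gamma in (0, 2, 8):
    weights = softmax_weights(scores, gamma)
    print(gamma, [round(w, 4) for w in weights])
```

With these toy scores, the weight of the best-scoring client rises from about 0.71 at gamma=2 to about 0.98 at gamma=8, while gamma=0 gives each client one third.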

### Final Aggregation:
$$\Theta_{\text{new}} = \Theta_{\text{old}} + \sum_{i=1}^{K} \alpha_i \cdot \Delta w_i$$
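
A minimal sketch of this aggregation step on plain Python lists is given below. The `aggregate` helper, the single parameter key `"w"`, and the weights are hypothetical; the notebook applies the same weighted sum to each entry of the model's `state_dict()`.

```python
def aggregate(theta_old, updates, alphas):
    """Theta_new = Theta_old + sum_i alpha_i * delta_w_i, applied per parameter key."""
    theta_new = {}
    for key, old_values in theta_old.items():
        theta_new[key] = [
            old + sum(alpha * update[key][j] for alpha, update in zip(alphas, updates))
            for j, old in enumerate(old_values)
        ]
    return theta_new

theta_old = {"w": [1.0, 2.0]}
updates = [{"w": [1.0, 0.0]}, {"w": [0.0, -2.0]}]  # hypothetical client updates
alphas = [0.75, 0.25]                               # hypothetical softmax weights

theta_new = aggregate(theta_old, updates, alphas)
print(theta_new)  # {'w': [1.75, 1.5]}
```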

---

## Summary Table
| Phase | Purpose | Key Formula |
|-------|---------|-------------|
| 1 | Local Training | $\Delta w_i = w_{\text{after}} - w_{\text{before}}$ |
| 2 | GAN Auditing | $\mathcal{L}_G = -\lVert\Theta(x)-\Theta(x')\rVert^2 + \lambda_1\lVert x-x'\rVert^2 + \lambda_2\mathcal{L}_{\text{realism}}$ |
| 3 | Scoring | $S_i = B_{\text{base}} - B_i$ |
| 4 | Aggregation | $\alpha_i = \text{softmax}(\gamma \cdot S_i)$, $\Theta_{\text{new}} = \Theta_{\text{old}} + \sum \alpha_i \Delta w_i$ |