# 🔧 Fed-Audit-GAN v7.0 - CIFAR-10 (Simplified: No Gradient Clipping, No Warmup)

## 🎯 Experiments Run:

### 🔵 Baseline: FedAvg (Federated Averaging)
- Standard data-weighted averaging (McMahan et al., 2017)
- NO GAN, NO fairness scoring

### 🟢 Our Method: Fed-Audit-GAN
- **Fed-Audit-GAN γ = 0.3** - Accuracy-weighted
- **Fed-Audit-GAN γ = 0.7** - Fairness-weighted

## 🔧 Fed-Audit-GAN 4-Phase Architecture:
1. **Phase 1**: Local Client Training (SGD - NO gradient clipping)
2. **Phase 2**: GAN Training (EVERY round from round 1!)
3. **Phase 3**: Fairness Scoring
4. **Phase 4**: V2 Linear Aggregation (from round 1 - NO warmup!)

## ⭐ KEY CHANGES FROM V6:
- ❌ **REMOVED**: Gradient clipping
- ❌ **REMOVED**: Warmup rounds - V2 Linear from round 1!
- ❌ **REMOVED**: Soft labels - using hard labels (1.0, 0.0)
- ✅ **KEPT**: GAN trains EVERY round
- ✅ **KEPT**: V2 Linear Formula for aggregation

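## ⭐ V2 Linear Aggregation Formula:
```
Weight = (1 - γ) × Accuracy_norm + γ × Fairness_norm
```
Accuracy_norm and Fairness_norm are each round's client scores, min-max normalized across clients. The resulting weights are then rescaled to sum to 1.

## ⭐ NON-IID:
- Each client gets only 2-3 classes (extreme heterogeneity)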

---

In [None]:
# Step 1: Install Dependencies
# `!pip` is IPython shell-escape syntax, so this cell only runs inside Jupyter/IPython.
!pip install wandb torch torchvision tqdm matplotlib numpy -q
print("‚úÖ Dependencies installed!")

In [None]:
# Step 2: Login to WandB
# Authenticate with Weights & Biases before the training cells call wandb.init()/wandb.log().
import wandb
wandb.login()
print("‚úÖ WandB logged in!")

In [None]:
# Step 3: Imports and GPU Setup
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import numpy as np
import copy
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

# GPU Configuration with optimizations
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
NUM_GPUS = torch.cuda.device_count()

# Enable cuDNN optimizations
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    torch.cuda.empty_cache()

print(f"Using device: {DEVICE} | GPUs available: {NUM_GPUS}")

# ========== GPU UTILIZATION ==========
NUM_WORKERS = 8
# Page-locked host buffers only speed up host->GPU copies. Request pinning
# only when a CUDA device exists, instead of unconditionally.
PIN_MEMORY = DEVICE == 'cuda'
BATCH_SIZE = 1024 if NUM_GPUS > 1 else 512
VAL_BATCH_SIZE = 2048 if NUM_GPUS > 1 else 1024
PREFETCH_FACTOR = 4
N_PROBES = 2000  # generator probe pairs sampled per round for bias scoring

# Experiment settings
N_ROUNDS = 50
N_CLIENTS = 20
GAN_EPOCHS = 20
LOCAL_EPOCHS = 3

# CIFAR-10 specific
IMG_SHAPE = (3, 32, 32)
NUM_CLASSES = 10

# Pre-load CIFAR-10 dataset
print("Pre-loading CIFAR-10 dataset...")
# Training-time augmentation (random flip + padded random crop), then normalization.
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
])
# Deterministic transform: normalization only.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
])

TRAIN_DATA = datasets.CIFAR10('./data', train=True, download=True, transform=transform_train)
TEST_DATA = datasets.CIFAR10('./data', train=False, download=True, transform=transform_test)

# Non-IID partition: each client gets only 2-3 classes
def partition_non_iid_extreme(dataset, n_clients, classes_per_client=2, seed=42, num_classes=None):
    """Split ``dataset`` so that each client holds only a few classes (extreme non-IID).

    Each client draws ``classes_per_client`` or ``classes_per_client + 1``
    distinct classes. For each class it takes a contiguous slice of that class's
    indices, of length ``len(class) // max(1, n_clients // 2)``, starting at a
    random offset. Slices of different clients may overlap, so the partitions
    are not guaranteed to be disjoint.

    Seeds NumPy's *global* RNG with ``np.random.seed(seed)``, which also makes
    later module-level ``np.random`` draws reproducible.

    Args:
        dataset: indexable dataset whose items are ``(x, label)``. If it exposes
            a ``targets`` sequence (as torchvision's CIFAR10 does), labels are
            read from it directly instead of loading and transforming every item.
        n_clients: number of clients to create.
        classes_per_client: minimum number of classes per client.
        seed: seed passed to ``np.random.seed``.
        num_classes: number of label classes; defaults to the module-level
            ``NUM_CLASSES``.

    Returns:
        ``(client_data, classes_assigned)``: one index array per client, and the
        list of class ids assigned to each client.
    """
    if num_classes is None:
        num_classes = NUM_CLASSES
    np.random.seed(seed)

    targets = getattr(dataset, 'targets', None)
    if targets is not None:
        labels = np.asarray(targets)
    else:
        # Fallback: fetch every item just for its label (slow with transforms).
        labels = np.array([dataset[i][1] for i in range(len(dataset))])
    class_indices = {c: np.where(labels == c)[0] for c in range(num_classes)}

    # n_clients // 2 is 0 for a single client; clamp so the division is defined.
    share = max(1, n_clients // 2)

    client_data = []
    classes_assigned = []
    for _ in range(n_clients):
        n_classes = np.random.randint(classes_per_client, classes_per_client + 2)
        client_classes = np.random.choice(num_classes, n_classes, replace=False)
        classes_assigned.append(client_classes.tolist())

        client_idx = []
        for c in client_classes:
            n_samples = len(class_indices[c]) // share
            start = np.random.randint(0, max(1, len(class_indices[c]) - n_samples))
            client_idx.extend(class_indices[c][start:start + n_samples])

        np.random.shuffle(client_idx)
        # Explicit integer dtype so an empty client still yields a valid index array.
        client_data.append(np.array(client_idx, dtype=np.int64))

    return client_data, classes_assigned

CLIENT_IDX, CLIENT_CLASSES = partition_non_iid_extreme(TRAIN_DATA, N_CLIENTS)
print(f"Client data sizes: {[len(idx) for idx in CLIENT_IDX]}")
print(f"Client classes: {CLIENT_CLASSES}")

# Deterministic (un-augmented) view of the training images, used by every
# loader that *measures* the model on training data.
# - compute_acc_score subtracts two loss passes over VAL_LOADER. Under
#   transform_train's RandomCrop/RandomHorizontalFlip, each pass would see
#   different pixels, so part of the score would be augmentation noise.
# - The per-client accuracies behind the fairness metrics are affected the
#   same way.
# - VAL_LOADER also supplies the real images for GAN training.
# Only CLIENT_LOADERS (local training) keep the augmented TRAIN_DATA.
TRAIN_EVAL_DATA = datasets.CIFAR10('./data', train=True, download=True, transform=transform_test)

# Pre-create DataLoaders
TEST_LOADER = DataLoader(TEST_DATA, batch_size=VAL_BATCH_SIZE, shuffle=False,
                         num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY,
                         persistent_workers=True, prefetch_factor=PREFETCH_FACTOR)

VAL_IDX = np.random.choice(len(TRAIN_DATA), 2000, replace=False)
VAL_LOADER = DataLoader(Subset(TRAIN_EVAL_DATA, VAL_IDX), batch_size=VAL_BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY,
                        persistent_workers=True, prefetch_factor=PREFETCH_FACTOR)

# Client evaluation loaders (un-augmented)
CLIENT_EVAL_LOADERS = []
for cid in range(N_CLIENTS):
    loader = DataLoader(Subset(TRAIN_EVAL_DATA, CLIENT_IDX[cid]), batch_size=VAL_BATCH_SIZE, shuffle=False,
                       num_workers=4, pin_memory=PIN_MEMORY, persistent_workers=True)
    CLIENT_EVAL_LOADERS.append(loader)

# Client training loaders (augmented)
CLIENT_LOADERS = []
for cid in range(N_CLIENTS):
    loader = DataLoader(Subset(TRAIN_DATA, CLIENT_IDX[cid]), batch_size=BATCH_SIZE, shuffle=True,
                       num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY,
                       persistent_workers=True, prefetch_factor=PREFETCH_FACTOR)
    CLIENT_LOADERS.append(loader)

print(f"\n{'='*60}")
print(f"CIFAR-10 Fed-AuditGAN v7.0 - SIMPLIFIED")
print(f"{'='*60}")
print(f"‚ùå NO Gradient Clipping")
print(f"‚ùå NO Warmup Rounds")
print(f"‚ùå NO Soft Labels (using hard labels)")
print(f"‚úÖ GAN every round from round 1")
print(f"‚úÖ V2 Linear aggregation from round 1")
print(f"{'='*60}")

In [None]:
# ========== MODEL DEFINITIONS ==========

class CNN(nn.Module):
    """CNN for CIFAR-10 classification.

    Four conv -> BatchNorm -> ReLU -> 2x2 max-pool stages (128, 256, 512, 512
    channels) reduce a 32x32 input to 2x2. A two-layer head with dropout then
    produces ``num_classes`` logits.
    """
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 128, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(128)
        self.conv2 = nn.Conv2d(128, 256, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(256)
        self.conv3 = nn.Conv2d(256, 512, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(512)
        self.conv4 = nn.Conv2d(512, 512, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(512)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.5)
        self.fc1 = nn.Linear(512 * 2 * 2, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.bn1(self.conv1(x))))  # 32->16
        x = self.pool(F.relu(self.bn2(self.conv2(x))))  # 16->8
        x = self.pool(F.relu(self.bn3(self.conv3(x))))  # 8->4
        x = self.pool(F.relu(self.bn4(self.conv4(x))))  # 4->2
        # flatten(1) keeps the batch dimension intact. The former
        # view(-1, 512*2*2) silently folded extra spatial positions into the
        # batch for non-32x32 inputs instead of failing in fc1.
        x = torch.flatten(x, 1)
        x = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(x)


class FairnessGenerator(nn.Module):
    """Conditional generator that emits a probe image and a slightly perturbed twin.

    ``forward(z, labels)`` returns ``(x, x_prime)``:
      * ``x`` comes from a label-conditioned upsampling conv decoder with a
        Tanh output;
      * ``x_prime = clamp(x + delta, -1, 1)``, where
        ``delta = delta_scale * tanh(MLP(z))``, so ``|x_prime - x| <= delta_scale``.
    """

    def __init__(self, latent_dim=128, num_classes=10, img_shape=(3, 32, 32)):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        self.img_shape = img_shape
        # The label embedding has the same width as z; the two are concatenated.
        self.label_emb = nn.Embedding(num_classes, latent_dim)
        # Two x2 upsamplings bring init_size back up to img_shape[1].
        self.init_size = img_shape[1] // 4
        self.l1 = nn.Linear(2 * latent_dim, 512 * self.init_size * self.init_size)
        out_channels = img_shape[0]
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(512),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(512, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(512, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, out_channels, 3, 1, 1),
            nn.Tanh(),
        )
        n_pixels = int(np.prod(img_shape))
        self.delta_net = nn.Sequential(
            nn.Linear(latent_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, n_pixels),
            nn.Tanh(),
        )
        self.delta_scale = 0.1

    def forward(self, z, labels):
        conditioned = torch.cat((z, self.label_emb(labels)), dim=1)
        seed_map = self.l1(conditioned).view(-1, 512, self.init_size, self.init_size)
        base = self.conv_blocks(seed_map)
        perturbation = self.delta_net(z).view(-1, *self.img_shape) * self.delta_scale
        twin = torch.clamp(base + perturbation, -1, 1)
        return base, twin


class Discriminator(nn.Module):
    """Conditional discriminator that outputs one real/fake logit per (image, label) pair.

    The label is embedded into ``num_classes`` channels, broadcast over the
    image plane, and concatenated to the image. Four stride-2 convolutions
    follow, then a linear head. The head's input size is derived from
    ``img_shape``, so image sizes other than 32x32 also work.
    """
    def __init__(self, num_classes=10, img_shape=(3, 32, 32)):
        super().__init__()
        self.num_classes = num_classes
        self.img_shape = img_shape
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.conv = nn.Sequential(
            nn.Conv2d(img_shape[0] + num_classes, 128, 3, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 3, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 512, 3, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            nn.Conv2d(512, 512, 3, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2)
        )
        # Each Conv2d(kernel=3, stride=2, padding=1) maps n -> ceil(n / 2).
        # For 32x32 input this gives 2x2, matching the previous hard-coded 512*2*2.
        h, w = img_shape[1], img_shape[2]
        for _ in range(4):
            h, w = (h + 1) // 2, (w + 1) // 2
        self.fc = nn.Linear(512 * h * w, 1)

    def forward(self, img, labels):
        label_map = self.label_emb(labels).view(-1, self.num_classes, 1, 1)
        label_map = label_map.expand(-1, -1, self.img_shape[1], self.img_shape[2])
        out = self.conv(torch.cat([img, label_map], dim=1))
        return self.fc(out.view(out.size(0), -1))


# ========== Helper Functions ==========
def make_parallel(model):
    """Wrap ``model`` in nn.DataParallel when more than one GPU is visible; otherwise return it as-is."""
    return nn.DataParallel(model) if NUM_GPUS > 1 else model

def get_base_model(model):
    """Return the module inside an nn.DataParallel wrapper, or ``model`` itself when unwrapped."""
    return model.module if isinstance(model, nn.DataParallel) else model

print("‚úÖ Models defined: CNN, FairnessGenerator, Discriminator")

In [None]:
# ========== FAIRNESS METRICS ==========

def compute_jfi(performances):
    """Jain's Fairness Index (sum p)^2 / (n * sum p^2); 1.0 means perfectly equal.

    Returns 1.0 when the sum of squares is zero (all zeros, or empty input).
    """
    values = np.array(performances)
    sum_of_squares = np.sum(values ** 2)
    if sum_of_squares == 0:
        return 1.0
    total = np.sum(values)
    return total ** 2 / (len(values) * sum_of_squares)

def compute_max_min_fairness(performances):
    """Ratio of the worst to the best performance (1.0 = all equal); 0.0 if the best is 0."""
    values = np.array(performances)
    best = np.max(values)
    return 0.0 if best == 0 else np.min(values) / best

def compute_fairness_variance(performances):
    """Population variance (NumPy default, ddof=0) of the per-client performances."""
    values = np.array(performances)
    return np.var(values)

@torch.no_grad()
def evaluate_per_client(model, client_loaders):
    """Return the top-1 accuracy (%) of ``model`` on each loader, in order.

    Puts ``model`` in eval mode and runs under CUDA autocast. A loader that
    yields no samples contributes 0.
    """
    model.eval()
    accuracies = []
    with torch.amp.autocast(device_type='cuda'):
        for loader in client_loaders:
            n_correct = 0
            n_seen = 0
            for inputs, targets in loader:
                inputs = inputs.to(DEVICE, non_blocking=True)
                targets = targets.to(DEVICE, non_blocking=True)
                predictions = model(inputs).argmax(1)
                n_correct += (predictions == targets).sum().item()
                n_seen += len(targets)
            accuracies.append(100 * n_correct / n_seen if n_seen > 0 else 0)
    return accuracies

def compute_all_fairness_metrics(client_accuracies):
    """Summarize per-client accuracies into fairness and spread statistics.

    Returns a dict with keys 'jfi', 'max_min_fairness', 'fairness_variance',
    'min_accuracy', 'max_accuracy' and 'mean_accuracy'.
    """
    accs = client_accuracies
    summary = {}
    summary['jfi'] = compute_jfi(accs)
    summary['max_min_fairness'] = compute_max_min_fairness(accs)
    summary['fairness_variance'] = compute_fairness_variance(accs)
    summary['min_accuracy'] = np.min(accs)
    summary['max_accuracy'] = np.max(accs)
    summary['mean_accuracy'] = np.mean(accs)
    return summary

@torch.no_grad()
def evaluate(model, loader):
    """Return the top-1 accuracy (%) of ``model`` over ``loader``.

    Puts ``model`` in eval mode and runs under CUDA autocast. Returns 0.0 for a
    loader that yields no samples, matching evaluate_per_client, instead of
    raising ZeroDivisionError.
    """
    model.eval()
    correct, total = 0, 0
    with torch.amp.autocast(device_type='cuda'):
        for d, t in loader:
            d, t = d.to(DEVICE, non_blocking=True), t.to(DEVICE, non_blocking=True)
            correct += (model(d).argmax(1) == t).sum().item()
            total += len(t)
    return 100 * correct / total if total > 0 else 0.0

def normalize(scores):
    """Min-max scale ``scores`` to [0, 1].

    If all scores are (numerically) equal, i.e. the spread is below 1e-8,
    returns a uniform array of 1/len(scores) instead.
    """
    arr = np.array(scores)
    lowest, highest = arr.min(), arr.max()
    spread = highest - lowest
    if spread < 1e-8:
        return np.ones_like(arr) / len(arr)
    return (arr - lowest) / spread

print("‚úÖ Fairness metrics defined")

In [None]:
# ========== GAN TRAINING ==========

def train_fairness_gan(G, D, model, loader, epochs=15, alpha=1.0, beta=1.0):
    """Train the fairness-probe generator ``G`` against the discriminator ``D``.

    Uses HARD BCE targets (1.0 real / 0.0 fake), with no label smoothing.

    Generator loss per batch:
        -beta * mean((model(x) - model(x'))^2)      # make predictions diverge
        + alpha * mean((x - x')^2)                  # keep x and x' close
        + BCE(D(x), real) / 2 + BCE(D(x'), real) / 2   # fool D

    Args:
        G: generator returning a pair ``(x, x_prime)`` for ``(z, labels)``.
        D: conditional discriminator returning one logit per image.
        model: classifier being audited. It is put in eval mode, and its
            parameters are frozen while this runs (their ``requires_grad``
            flags are restored afterwards). Gradients still flow *through* it,
            so G can learn pairs that the classifier treats differently.
        loader: yields ``(images, labels)`` used as real samples for D.
        epochs: number of passes over ``loader``.
        alpha: weight of the closeness term.
        beta: weight of the prediction-divergence term.

    Returns:
        ``(G, D)``, unwrapped from any nn.DataParallel wrapper.
    """
    model.eval()
    # Freeze the auditor so the generator loss does not fill its .grad buffers.
    # Only parameters that were trainable are touched, and they are restored in `finally`.
    frozen = [p for p in model.parameters() if p.requires_grad]
    for p in frozen:
        p.requires_grad_(False)

    G = make_parallel(G)
    D = make_parallel(D)

    opt_G = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
    opt_D = optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
    bce = nn.BCEWithLogitsLoss()
    # One scaler shared by both optimizers: step each, update once per iteration.
    scaler = torch.amp.GradScaler(device='cuda')

    latent_dim = get_base_model(G).latent_dim
    num_classes = get_base_model(G).num_classes

    try:
        for _ in range(epochs):
            for imgs, labels in loader:
                batch_size = imgs.size(0)

                # Hard targets (no soft labels).
                real_t = torch.ones(batch_size, 1, device=DEVICE)
                fake_t = torch.zeros(batch_size, 1, device=DEVICE)

                imgs = imgs.to(DEVICE, non_blocking=True)
                labels = labels.to(DEVICE, non_blocking=True)
                z = torch.randn(batch_size, latent_dim, device=DEVICE)
                gen_labels = torch.randint(0, num_classes, (batch_size,), device=DEVICE)

                # ---- Generator step ----
                opt_G.zero_grad(set_to_none=True)
                with torch.amp.autocast(device_type='cuda'):
                    x, x_prime = G(z, gen_labels)
                    # Deliberately NOT under torch.no_grad(): this divergence
                    # term is G's only fairness signal. If the auditor's
                    # outputs were detached, its gradient w.r.t. G would be zero.
                    pred_x, pred_xp = model(x), model(x_prime)
                    pred_diff = -beta * torch.mean((pred_x - pred_xp) ** 2)
                    realism = alpha * torch.mean((x - x_prime) ** 2)
                    gan_loss = (bce(D(x, gen_labels), real_t) + bce(D(x_prime, gen_labels), real_t)) / 2
                    g_loss = pred_diff + realism + gan_loss
                scaler.scale(g_loss).backward()
                scaler.step(opt_G)

                # ---- Discriminator step ----
                # zero_grad also discards the D gradients left over from g_loss.
                opt_D.zero_grad(set_to_none=True)
                with torch.amp.autocast(device_type='cuda'):
                    # Re-sample from the just-updated G. D sees these samples
                    # detached anyway, so G's autograd graph is not needed.
                    with torch.no_grad():
                        x, x_prime = G(z, gen_labels)
                    d_real = bce(D(imgs, labels), real_t)
                    d_fake = (bce(D(x, gen_labels), fake_t) + bce(D(x_prime, gen_labels), fake_t)) / 2
                    d_loss = (d_real + d_fake) / 2
                scaler.scale(d_loss).backward()
                scaler.step(opt_D)
                scaler.update()
    finally:
        for p in frozen:
            p.requires_grad_(True)

    return get_base_model(G), get_base_model(D)


@torch.no_grad()
def compute_bias(model, x, x_prime):
    """Return the mean L1 distance between ``model``'s outputs on paired probes.

    For each pair, the absolute output differences are summed over dim 1 and
    then averaged over the batch. Puts ``model`` in eval mode.

    The reduction runs in float32. Under autocast the outputs may be float16,
    and callers subtract two nearly equal bias values
    (``B_base - bias(hypothetical)``). Float16 rounding, at roughly three
    significant digits, would largely erase that difference.
    """
    model.eval()
    with torch.amp.autocast(device_type='cuda'):
        pred_x = model(x)
        pred_xp = model(x_prime)
    return (pred_x.float() - pred_xp.float()).abs().sum(dim=1).mean().item()


@torch.no_grad()
def _mean_ce_loss(model, loader):
    """Return the per-sample mean cross-entropy of ``model`` over ``loader``.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    total, count = 0.0, 0
    with torch.amp.autocast(device_type='cuda'):
        for d, t in loader:
            d, t = d.to(DEVICE, non_blocking=True), t.to(DEVICE, non_blocking=True)
            total += F.cross_entropy(model(d), t, reduction='sum').item()
            count += len(t)
    if count == 0:
        raise ValueError("loader yielded no samples; cannot compute a mean loss")
    return total / count


def _strip_module_prefix(key):
    """Drop the 'module.' prefix that nn.DataParallel adds to state_dict keys."""
    return key[len('module.'):] if key.startswith('module.') else key


@torch.no_grad()
def compute_acc_score(model, update, val_loader):
    """Score a client update by how much it lowers the validation loss.

    Builds a hypothetical model (base weights of ``model`` plus ``update``) and
    returns ``mean_CE(model) - mean_CE(hypothetical)`` on ``val_loader``. A
    positive score means the update helps.

    ``model`` may be wrapped in nn.DataParallel, and ``update`` keys may or may
    not carry the 'module.' prefix; keys are matched with the prefix removed.
    State entries without a matching update are left unchanged. Puts ``model``
    in eval mode.

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    model.eval()
    loss_before = _mean_ce_loss(model, val_loader)

    hyp = copy.deepcopy(get_base_model(model))
    deltas = {_strip_module_prefix(k): v for k, v in update.items()}
    sd = hyp.state_dict()
    for k in sd:
        delta = deltas.get(_strip_module_prefix(k))
        if delta is not None:
            sd[k] = sd[k] + delta

    hyp.load_state_dict(sd)
    hyp = make_parallel(hyp.to(DEVICE))
    hyp.eval()

    return loss_before - _mean_ce_loss(hyp, val_loader)


print("‚úÖ GAN training functions defined (NO soft labels!)")

In [None]:
# ========== TRAINING FUNCTIONS ==========

def run_fedavg(n_rounds=N_ROUNDS, n_clients=N_CLIENTS, local_epochs=LOCAL_EPOCHS, lr=0.01):
    """Run the FedAvg baseline (McMahan et al., 2017) and return its per-round history.

    In each round, every client fine-tunes a copy of the global model on its
    CLIENT_LOADERS entry (SGD with momentum, AMP, no gradient clipping). The
    global model then moves by the average of the client deltas, weighted by
    each client's data size.

    Args:
        n_rounds: number of communication rounds.
        n_clients: number of clients per round (indexes CLIENT_LOADERS / CLIENT_IDX).
        local_epochs: local passes over each client's loader per round.
        lr: client SGD learning rate.

    Returns:
        dict of per-round lists: 'accuracy' (global test %), 'jfi',
        'max_min_fairness', 'fairness_variance', 'min_accuracy',
        'max_accuracy', 'client_accuracies' and the aggregation 'weights'.
    """
    wandb.init(project="FED_AUDIT_GAN_TEST_7_CIFAR10", name="FedAvg_v7", 
               config={'method': 'FedAvg', 'n_rounds': n_rounds, 'n_clients': n_clients})
    try:
        global_model = make_parallel(CNN().to(DEVICE))
        scaler = torch.amp.GradScaler(device='cuda')

        history = {'accuracy': [], 'jfi': [], 'max_min_fairness': [], 'fairness_variance': [],
                   'min_accuracy': [], 'max_accuracy': [], 'client_accuracies': [], 'weights': []}

        print(f"FedAvg | Rounds: {n_rounds} | Clients: {n_clients} | GPUs: {NUM_GPUS}")

        for rnd in tqdm(range(n_rounds), desc="FedAvg"):
            # Clients train deep copies, so the global weights stay constant
            # for the whole client loop. Snapshot them once per round rather
            # than once per client.
            before = {k: v.clone() for k, v in global_model.state_dict().items()}
            updates = []
            client_sizes = []

            for cid in range(n_clients):
                local = make_parallel(copy.deepcopy(get_base_model(global_model)).to(DEVICE))
                opt = optim.SGD(local.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)
                local.train()

                # No gradient clipping (removed in v7).
                for _ in range(local_epochs):
                    for d, t in CLIENT_LOADERS[cid]:
                        d, t = d.to(DEVICE, non_blocking=True), t.to(DEVICE, non_blocking=True)
                        opt.zero_grad(set_to_none=True)
                        with torch.amp.autocast(device_type='cuda'):
                            loss = F.cross_entropy(local(d), t)
                        scaler.scale(loss).backward()
                        scaler.step(opt)
                        scaler.update()

                after = local.state_dict()
                updates.append({k: after[k] - before[k] for k in before})
                client_sizes.append(len(CLIENT_IDX[cid]))
                del local
                torch.cuda.empty_cache()

            # FedAvg aggregation: weight each delta by the client's share of the data.
            total_size = sum(client_sizes)
            alphas = [size / total_size for size in client_sizes]

            sd = global_model.state_dict()
            for k in sd:
                sd[k] = sd[k] + sum(alphas[i] * updates[i][k] for i in range(n_clients))
            global_model.load_state_dict(sd)

            # Evaluate
            acc = evaluate(global_model, TEST_LOADER)
            client_accs = evaluate_per_client(global_model, CLIENT_EVAL_LOADERS)
            fairness = compute_all_fairness_metrics(client_accs)

            history['accuracy'].append(acc)
            history['jfi'].append(fairness['jfi'])
            history['max_min_fairness'].append(fairness['max_min_fairness'])
            history['fairness_variance'].append(fairness['fairness_variance'])
            history['min_accuracy'].append(fairness['min_accuracy'])
            history['max_accuracy'].append(fairness['max_accuracy'])
            history['client_accuracies'].append(client_accs)
            history['weights'].append(alphas)

            wandb.log({'round': rnd+1, 'accuracy': acc, 'jfi': fairness['jfi'],
                       'max_min_fairness': fairness['max_min_fairness'],
                       'fairness_variance': fairness['fairness_variance'],
                       'min_client_acc': fairness['min_accuracy'],
                       'max_client_acc': fairness['max_accuracy']})

        print(f"FedAvg Done | Accuracy: {history['accuracy'][-1]:.2f}% | JFI: {history['jfi'][-1]:.4f}")
        return history
    finally:
        # Close the W&B run even if training raises, so the next wandb.init() starts clean.
        wandb.finish()


def run_fed_audit_gan(gamma, n_rounds=N_ROUNDS, n_clients=N_CLIENTS, local_epochs=LOCAL_EPOCHS, 
                      lr=0.01, gan_epochs=GAN_EPOCHS, n_probes=N_PROBES):
    """Run Fed-AuditGAN v7.0 with accuracy/fairness trade-off ``gamma``.

    v7.0 simplifications: no gradient clipping, no warmup rounds, and hard GAN
    labels. The V2 linear aggregation is used from round 1.

    Each round:
      1. Every client fine-tunes a copy of the global model, as in FedAvg.
      2. A fresh FairnessGenerator/Discriminator pair is trained against the
         current global model on VAL_LOADER. ``n_probes`` probe pairs
         ``(x, x')`` are then sampled from it.
      3. Each client delta is scored twice:
         - S_fair is the drop in probe bias (compute_bias) if that delta alone
           were applied;
         - S_acc is compute_acc_score on VAL_LOADER.
      4. The aggregation weights are computed and applied:
         - ``(1 - gamma) * normalize(S_acc) + gamma * normalize(S_fair)``;
         - rescaled to sum to 1, or uniform if they sum to 0;
         - the global model moves by the weighted sum of the deltas.

    Returns:
        dict of per-round lists, as in run_fedavg, plus 'bias': the probe bias
        of the global model *before* that round's aggregation.
    """
    wandb.init(project="FED_AUDIT_GAN_TEST_7_CIFAR10", name=f"FedAuditGAN_v7_gamma{gamma}",
               config={'method': 'FedAuditGAN_v7', 'gamma': gamma, 'n_rounds': n_rounds, 
                       'n_clients': n_clients, 'gradient_clipping': False, 'warmup': False, 'soft_labels': False})
    try:
        global_model = make_parallel(CNN().to(DEVICE))
        scaler = torch.amp.GradScaler(device='cuda')

        history = {'accuracy': [], 'bias': [], 'jfi': [], 'max_min_fairness': [], 'fairness_variance': [],
                   'min_accuracy': [], 'max_accuracy': [], 'client_accuracies': [], 'weights': []}

        print(f"\n{'='*60}")
        print(f"Fed-AuditGAN v7.0 | gamma={gamma}")
        print(f"‚ùå NO Gradient Clipping | ‚ùå NO Warmup | ‚ùå NO Soft Labels")
        print(f"{'='*60}")

        for rnd in tqdm(range(n_rounds), desc=f"gamma={gamma}"):
            # ============================================================
            # PHASE 1: Client training (no gradient clipping)
            # ============================================================
            # The global weights do not change during the client loop: snapshot once.
            before = {k: v.clone() for k, v in global_model.state_dict().items()}
            updates = []
            for cid in range(n_clients):
                local = make_parallel(copy.deepcopy(get_base_model(global_model)).to(DEVICE))
                opt = optim.SGD(local.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)
                local.train()

                for _ in range(local_epochs):
                    for d, t in CLIENT_LOADERS[cid]:
                        d, t = d.to(DEVICE, non_blocking=True), t.to(DEVICE, non_blocking=True)
                        opt.zero_grad(set_to_none=True)
                        with torch.amp.autocast(device_type='cuda'):
                            loss = F.cross_entropy(local(d), t)
                        scaler.scale(loss).backward()
                        scaler.step(opt)
                        scaler.update()

                after = local.state_dict()
                updates.append({k: after[k] - before[k] for k in before})
                del local
                torch.cuda.empty_cache()

            # ============================================================
            # PHASE 2: GAN auditing (trained every round, no warmup)
            # ============================================================
            G = FairnessGenerator(img_shape=IMG_SHAPE).to(DEVICE)
            D = Discriminator(img_shape=IMG_SHAPE).to(DEVICE)
            G, D = train_fairness_gan(G, D, global_model, VAL_LOADER, epochs=gan_epochs)
            G.eval()

            with torch.no_grad():
                z = torch.randn(n_probes, G.latent_dim, device=DEVICE)
                lbls = torch.randint(0, NUM_CLASSES, (n_probes,), device=DEVICE)
                with torch.amp.autocast(device_type='cuda'):
                    G_parallel = make_parallel(G)
                    x_p, xp_p = G_parallel(z, lbls)

            # ============================================================
            # PHASE 3: Scoring
            # ============================================================
            B_base = compute_bias(global_model, x_p, xp_p)
            S_fair, S_acc = [], []
            for upd in updates:
                hyp = make_parallel(copy.deepcopy(get_base_model(global_model)).to(DEVICE))
                sd = hyp.state_dict()
                for k in sd:
                    if k in upd:
                        sd[k] = sd[k] + upd[k]
                hyp.load_state_dict(sd)
                # Positive when applying this update alone reduces probe bias.
                S_fair.append(B_base - compute_bias(hyp, x_p, xp_p))
                S_acc.append(compute_acc_score(global_model, upd, VAL_LOADER))
                del hyp

            # G_parallel also references G; drop it too so the GAN can be freed.
            del G, D, G_parallel, x_p, xp_p
            torch.cuda.empty_cache()

            # ============================================================
            # PHASE 4: V2 linear aggregation (from round 1)
            # ============================================================
            S_fair_n, S_acc_n = normalize(S_fair), normalize(S_acc)
            alphas = [(1-gamma)*S_acc_n[i] + gamma*S_fair_n[i] for i in range(n_clients)]
            a_sum = sum(alphas)
            alphas = [a/a_sum if a_sum > 0 else 1/n_clients for a in alphas]

            sd = global_model.state_dict()
            for k in sd:
                sd[k] = sd[k] + sum(alphas[i] * updates[i][k] for i in range(n_clients))
            global_model.load_state_dict(sd)

            # Evaluate
            acc = evaluate(global_model, TEST_LOADER)
            client_accs = evaluate_per_client(global_model, CLIENT_EVAL_LOADERS)
            fairness = compute_all_fairness_metrics(client_accs)

            history['accuracy'].append(acc)
            history['bias'].append(B_base)
            history['jfi'].append(fairness['jfi'])
            history['max_min_fairness'].append(fairness['max_min_fairness'])
            history['fairness_variance'].append(fairness['fairness_variance'])
            history['min_accuracy'].append(fairness['min_accuracy'])
            history['max_accuracy'].append(fairness['max_accuracy'])
            history['client_accuracies'].append(client_accs)
            history['weights'].append(alphas)

            wandb.log({'round': rnd+1, 'accuracy': acc, 'bias': B_base, 'jfi': fairness['jfi'],
                       'max_min_fairness': fairness['max_min_fairness'],
                       'fairness_variance': fairness['fairness_variance'],
                       'min_client_acc': fairness['min_accuracy'],
                       'max_client_acc': fairness['max_accuracy']})

        print(f"gamma={gamma} Done | Accuracy: {history['accuracy'][-1]:.2f}% | JFI: {history['jfi'][-1]:.4f}")
        return history
    finally:
        # Close the W&B run even if training raises, so the next wandb.init() starts clean.
        wandb.finish()

print("‚úÖ Training functions ready!")

In [None]:
# ========== RUN ALL EXPERIMENTS ==========

results = {}


def _print_banner(title):
    """Print ``title`` between two 60-character '=' rules, preceded by a blank line."""
    print("\n" + "=" * 60)
    print(title)
    print("=" * 60)


# 1. FedAvg baseline (data-size weighted averaging)
_print_banner("Running FedAvg Baseline")
results['FedAvg'] = run_fedavg(n_rounds=N_ROUNDS, n_clients=N_CLIENTS)

# 2. Fed-AuditGAN with an accuracy-leaning and a fairness-leaning gamma
for gamma in (0.3, 0.7):
    _print_banner(f"Running Fed-AuditGAN v7.0 gamma = {gamma}")
    results[f'gamma{gamma}'] = run_fed_audit_gan(gamma=gamma, n_rounds=N_ROUNDS, n_clients=N_CLIENTS)

_print_banner("‚úÖ All experiments complete!")

In [None]:
# ========== VISUALIZATION ==========

fig, axes = plt.subplots(2, 3, figsize=(18, 10))
# Styling for the standard runs. Any other method key falls back to a solid
# line in matplotlib's default colour cycle instead of raising KeyError.
colors = {'FedAvg': 'black', 'gamma0.3': 'blue', 'gamma0.7': 'red'}
linestyles = {'FedAvg': '--', 'gamma0.3': '-', 'gamma0.7': '-'}
rounds = list(range(1, N_ROUNDS + 1))


def _history_rounds(hist):
    """1-based round numbers for one method, from its actual history length (not N_ROUNDS)."""
    return list(range(1, len(hist['accuracy']) + 1))


def _client_gap(hist):
    """Per-round best-minus-worst client accuracy, in percentage points."""
    return [hi - lo for hi, lo in zip(hist['max_accuracy'], hist['min_accuracy'])]


# Plots 1-5: one line per method. Each entry is (axes index, series, y label, title).
line_panels = [
    ((0, 0), lambda h: h['accuracy'], 'Accuracy (%)', 'CIFAR-10 v7: Global Test Accuracy'),
    ((0, 1), lambda h: h['jfi'], 'JFI', "Jain's Fairness Index (higher=fairer)"),
    ((0, 2), lambda h: h['max_min_fairness'], 'Min/Max Ratio', 'Max-Min Fairness (higher=fairer)'),
    ((1, 0), lambda h: h['fairness_variance'], 'Variance', 'Per-Client Accuracy Variance (lower=fairer)'),
    ((1, 1), _client_gap, 'Gap (%)', 'Best-Worst Client Gap (lower=fairer)'),
]
for pos, series, ylabel, title in line_panels:
    ax = axes[pos]
    for method, hist in results.items():
        ax.plot(_history_rounds(hist), series(hist), linestyle=linestyles.get(method, '-'),
                label=method, color=colors.get(method), linewidth=2)
    ax.set_xlabel('Round'); ax.set_ylabel(ylabel)
    ax.set_title(title); ax.legend(); ax.grid(True, alpha=0.3)

# Last round actually recorded (equal to N_ROUNDS for the standard runs).
last_round = max((len(h['accuracy']) for h in results.values()), default=N_ROUNDS)

# Plot 6: final per-client accuracy as grouped bars
final_client_accs = {m: results[m]['client_accuracies'][-1] for m in results}
# 0.25 for up to three methods (as before); narrower for more so the groups do not overlap.
width = min(0.25, 0.8 / max(1, len(results)))
for i, method in enumerate(results):
    x = np.arange(len(final_client_accs[method]))
    axes[1, 2].bar(x + i*width, final_client_accs[method], width, label=method,
                   color=colors.get(method), alpha=0.8)
axes[1, 2].set_xlabel('Client ID'); axes[1, 2].set_ylabel('Accuracy (%)')
axes[1, 2].set_title(f'Per-Client Accuracy (Round {last_round})'); axes[1, 2].legend(); axes[1, 2].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('cifar10_v7_results.png', dpi=150)
plt.show()

# Summary Table
print("\n" + "="*110)
print("CIFAR-10 v7.0 FINAL RESULTS SUMMARY (Round {})".format(last_round))
print("="*110)
print(f"{'Method':<15} {'Global Acc':>10} {'JFI':>8} {'Max-Min':>10} {'Variance':>12} {'Min Acc':>10} {'Max Acc':>10} {'Gap':>8}")
print("-"*110)
for method in results:
    acc = results[method]['accuracy'][-1]
    jfi = results[method]['jfi'][-1]
    mmf = results[method]['max_min_fairness'][-1]
    var = results[method]['fairness_variance'][-1]
    min_acc = results[method]['min_accuracy'][-1]
    max_acc = results[method]['max_accuracy'][-1]
    gap = max_acc - min_acc
    print(f"{method:<15} {acc:>10.2f}% {jfi:>8.4f} {mmf:>10.4f} {var:>12.2f} {min_acc:>10.2f}% {max_acc:>10.2f}% {gap:>8.2f}%")
print("="*110)

# V7 Summary
print("\n" + "="*60)
print("V7.0 CONFIGURATION SUMMARY")
print("="*60)
print("‚ùå Gradient Clipping: REMOVED")
print("‚ùå Warmup Rounds: REMOVED (V2 Linear from round 1)")
print("‚ùå Soft Labels: REMOVED (using hard labels 1.0/0.0)")
print("‚úÖ GAN: Trains every round")
print("‚úÖ Aggregation: V2 Linear Formula")
print("="*60)

# 🔧 Fed-Audit-GAN v7.0 - CIFAR-10 (No Gradient Clipping, No Warmup)

## 🎯 Experiments Run:

### 🔵 Baseline: FedAvg (Federated Averaging)
- Standard data-weighted averaging (McMahan et al., 2017)
- NO GAN, NO fairness scoring

### 🟢 Our Method: Fed-Audit-GAN (V2 Linear Formula)
- **Fed-Audit-GAN γ = 0.3** - Accuracy-weighted
- **Fed-Audit-GAN γ = 0.7** - Fairness-weighted

## 🔧 Fed-Audit-GAN 4-Phase Architecture:
1. **Phase 1**: Local Client Training (SGD - NO gradient clipping)
2. **Phase 2**: GAN Training (EVERY round from round 1!)
3. **Phase 3**: Fairness Scoring (with EMA)
4. **Phase 4**: V2 Linear Aggregation (from round 1 - NO warmup!)

## ⭐ KEY CHANGES FROM V6:
- ❌ **REMOVED**: Gradient clipping
- ❌ **REMOVED**: Warmup rounds - V2 Linear from round 1!
- ✅ **KEPT**: GAN trains EVERY round
- ✅ **KEPT**: V2 Linear Formula for aggregation
- ✅ **KEPT**: Soft labels for GAN stability
- ✅ **KEPT**: EMA momentum for fairness scores

## ⭐ V2 Linear Aggregation Formula:
```
Weight = (1 - γ) × Accuracy_norm + γ × Fairness_norm
```

## ⭐ PATHOLOGICAL NON-IID:
- Each client ONLY has 2 out of 10 classes
- Clients have DIFFERENT sample sizes (Dirichlet distribution)

---

In [None]:
# Step 1: Install Dependencies
# `!pip` is IPython shell-escape syntax, so this cell only runs inside Jupyter/IPython.
!pip install -q torch torchvision tqdm matplotlib numpy wandb

print("‚úÖ Dependencies installed!")

In [None]:
# Step 2: Login to WandB
# Authenticate this session with Weights & Biases via wandb.login().
import wandb
wandb.login()
print("‚úÖ WandB logged in!")

In [None]:
# Step 3: Imports and GPU Setup
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import numpy as np
import copy
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

# ============================================================
# 🚀 GPU SETUP
# Defines DEVICE, NUM_GPUS and USE_AMP for the rest of the notebook and,
# when CUDA is present, turns on cuDNN autotuning and TF32 math.
# ============================================================

if not torch.cuda.is_available():
    # CPU fallback: no CUDA backend flags are touched.
    DEVICE = torch.device('cpu')
    NUM_GPUS = 0
    print("‚ö†Ô∏è  No GPU detected. Using CPU.")
else:
    DEVICE = torch.device('cuda')
    NUM_GPUS = torch.cuda.device_count()

    # cuDNN on, with the autotuner choosing conv algorithms for fixed input sizes
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    # Allow TF32 for faster matmul / convolution
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True

    torch.cuda.empty_cache()

    print(f"‚úÖ GPU(s) detected: {NUM_GPUS}")
    for i in range(NUM_GPUS):
        props = torch.cuda.get_device_properties(i)
        print(f"   GPU {i}: {props.name}\n      Memory: {props.total_memory / 1e9:.2f} GB")
    print(f"   CUDA Version: {torch.version.cuda}")

# Mixed precision is used exactly when a CUDA device was selected above
USE_AMP = DEVICE.type == 'cuda'
if USE_AMP:
    print("\n‚úÖ Mixed Precision Training (AMP) enabled.")

print(f"\nüìç Device: {DEVICE}")
print(f"   PyTorch: {torch.__version__}")

In [None]:
# ============================================================
# MODEL DEFINITIONS (CIFAR-10)
# ============================================================

class CNN(nn.Module):
    """Convolutional classifier for 3x32x32 (CIFAR-10) images.

    Four conv3x3 -> BatchNorm -> ReLU -> 2x2 max-pool stages shrink the spatial
    size 32 -> 16 -> 8 -> 4 -> 2 while widening channels 3 -> 64 -> 128 -> 256
    -> 512. A two-layer MLP head with dropout maps the 512*2*2 features to
    ``num_classes`` logits.

    Args:
        num_classes: number of output logits.
    """
    def __init__(self, num_classes=10):
        super().__init__()
        # Submodules are created in the same order and under the same names as
        # before, so random initialisation and state_dict keys are unchanged.
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.conv4 = nn.Conv2d(256, 512, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(512)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.5)
        self.fc1 = nn.Linear(512 * 2 * 2, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for ``x`` of shape (batch, 3, 32, 32)."""
        x = self.pool(F.relu(self.bn1(self.conv1(x))))  # 32->16
        x = self.pool(F.relu(self.bn2(self.conv2(x))))  # 16->8
        x = self.pool(F.relu(self.bn3(self.conv3(x))))  # 8->4
        x = self.pool(F.relu(self.bn4(self.conv4(x))))  # 4->2
        # Flatten per sample. Unlike view(-1, 2048), this keeps the batch
        # dimension fixed, so a wrongly sized input fails in fc1 instead of
        # being silently reshaped into a different number of rows.
        x = torch.flatten(x, 1)
        x = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(x)


class FairnessGenerator(nn.Module):
    """Conditional generator that emits paired probes (x, x') for fairness auditing.

    ``x`` is a class-conditioned image: the concatenation of ``z`` and a label
    embedding is projected to a 512-channel seed grid and upsampled twice.
    ``x'`` is ``x`` plus a noise-dependent perturbation (Tanh output scaled by
    ``delta_scale`` = 0.1), clamped to [-1, 1]. Both outputs have shape
    (batch, *img_shape).

    Args:
        latent_dim: size of the noise vector ``z`` (also the label-embedding size).
        num_classes: number of conditioning labels.
        img_shape: (channels, height, width). The seed grid is square with side
            height // 4, so height is assumed divisible by 4 and equal to width.
    """
    def __init__(self, latent_dim=128, num_classes=10, img_shape=(3, 32, 32)):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        self.img_shape = img_shape
        out_channels = img_shape[0]
        n_pixels = int(np.prod(img_shape))

        # z and the label embedding are concatenated, hence 2 * latent_dim inputs
        self.label_emb = nn.Embedding(num_classes, latent_dim)
        self.init_size = img_shape[1] // 4  # two x2 upsamplings restore full size
        self.l1 = nn.Linear(2 * latent_dim, 512 * self.init_size ** 2)

        def _upsample_stage(c_in, c_out):
            # x2 upsample, then conv3x3 -> BatchNorm -> LeakyReLU
            return [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(c_in, c_out, 3, 1, 1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2),
            ]

        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(512),
            *_upsample_stage(512, 256),
            *_upsample_stage(256, 128),
            nn.Conv2d(128, out_channels, 3, 1, 1),
            nn.Tanh(),
        )
        # Perturbation branch: depends on z only, not on the label
        self.delta_net = nn.Sequential(
            nn.Linear(latent_dim, 512),
            nn.ReLU(),
            nn.Linear(512, n_pixels),
            nn.Tanh(),
        )
        self.delta_scale = 0.1

    def forward(self, z, labels):
        """Return ``(x, x')`` for noise ``z`` and integer class ``labels``."""
        condition = self.label_emb(labels)
        seed = self.l1(torch.cat([z, condition], dim=1))
        seed = seed.view(-1, 512, self.init_size, self.init_size)
        base = self.conv_blocks(seed)
        perturbation = self.delta_scale * self.delta_net(z).view(-1, *self.img_shape)
        return base, torch.clamp(base + perturbation, -1, 1)


class Discriminator(nn.Module):
    """Conditional discriminator that scores how real an image looks for a given label.

    The label is embedded into ``num_classes`` channels, broadcast over the
    image plane and concatenated with the input channels. Four stride-2 conv
    stages then downsample, and a linear layer emits one logit per sample
    (shape (batch, 1)), which train_gan feeds to BCEWithLogitsLoss.

    Args:
        num_classes: number of conditioning labels.
        img_shape: (channels, height, width) of the images to score.
    """
    def __init__(self, num_classes=10, img_shape=(3, 32, 32)):
        super().__init__()
        self.num_classes = num_classes
        self.img_shape = img_shape
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.conv = nn.Sequential(
            nn.Conv2d(img_shape[0] + num_classes, 64, 3, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 3, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 3, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 512, 3, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2)
        )
        # A 3x3 / stride-2 / padding-1 conv maps a side of length s to ceil(s / 2).
        # Derive the flattened size from img_shape instead of hard-coding 512*2*2,
        # which only holds for 32x32 inputs (and gives the same value there).
        feat_h, feat_w = img_shape[1], img_shape[2]
        for _ in range(4):
            feat_h, feat_w = (feat_h + 1) // 2, (feat_w + 1) // 2
        self.fc = nn.Linear(512 * feat_h * feat_w, 1)

    def forward(self, img, labels):
        """Return realness logits of shape (batch, 1) for ``img`` conditioned on ``labels``."""
        # Broadcast each label embedding to a full-resolution constant feature map
        label_map = self.label_emb(labels).view(-1, self.num_classes, 1, 1)
        label_map = label_map.expand(-1, -1, self.img_shape[1], self.img_shape[2])
        out = self.conv(torch.cat([img, label_map], dim=1))
        return self.fc(out.view(out.size(0), -1))


print("‚úÖ Models defined: CNN, FairnessGenerator, Discriminator")

In [None]:
# ============================================================
# HELPER FUNCTIONS
# ============================================================

def train_gan(G, D, model, loader, epochs=15, device='cuda', l1=1.0, l2=1.0):
    """
    Train the Fairness GAN with Soft Labels for stability
    """
    G, D, model = G.to(device), D.to(device), model.to(device)
    model.eval()
    opt_G = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
    opt_D = optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
    bce = nn.BCEWithLogitsLoss()
    
    if USE_AMP:
        scaler_G = torch.amp.GradScaler(device='cuda')
        scaler_D = torch.amp.GradScaler(device='cuda')
    
    for _ in range(epochs):
        for imgs, labels in loader:
            bs = imgs.size(0)
            
            # Soft Labels for stability
            real_labels = torch.empty(bs, 1, device=device).uniform_(0.9, 1.0)
            fake_labels = torch.empty(bs, 1, device=device).uniform_(0.0, 0.1)
            
            imgs, labels = imgs.to(device, non_blocking=True), labels.to(device, non_blocking=True)
            z = torch.randn(bs, G.latent_dim, device=device)
            gl = torch.randint(0, G.num_classes, (bs,), device=device)
            
            # Generator
            opt_G.zero_grad(set_to_none=True)
            if USE_AMP:
                with torch.amp.autocast(device_type='cuda'):
                    x, xp = G(z, gl)
                    with torch.no_grad():
                        px, pxp = model(x), model(xp)
                    t1 = -torch.mean((px - pxp) ** 2)
                    t2 = l1 * torch.mean((x - xp) ** 2)
                    g_real = torch.empty(bs, 1, device=device).uniform_(0.9, 1.0)
                    t3 = l2 * (bce(D(x, gl), g_real) + bce(D(xp, gl), g_real)) / 2
                    g_loss = t1 + t2 + t3
                scaler_G.scale(g_loss).backward()
                scaler_G.step(opt_G)
                scaler_G.update()
            else:
                x, xp = G(z, gl)
                with torch.no_grad():
                    px, pxp = model(x), model(xp)
                t1 = -torch.mean((px - pxp) ** 2)
                t2 = l1 * torch.mean((x - xp) ** 2)
                g_real = torch.empty(bs, 1, device=device).uniform_(0.9, 1.0)
                t3 = l2 * (bce(D(x, gl), g_real) + bce(D(xp, gl), g_real)) / 2
                g_loss = t1 + t2 + t3
                g_loss.backward()
                opt_G.step()
            
            # Discriminator
            opt_D.zero_grad(set_to_none=True)
            if USE_AMP:
                with torch.amp.autocast(device_type='cuda'):
                    x, xp = G(z, gl)
                    d_loss = (bce(D(imgs, labels), real_labels) + 
                              bce(D(x.detach(), gl), fake_labels) + 
                              bce(D(xp.detach(), gl), fake_labels)) / 3
                scaler_D.scale(d_loss).backward()
                scaler_D.step(opt_D)
                scaler_D.update()
            else:
                x, xp = G(z, gl)
                d_loss = (bce(D(imgs, labels), real_labels) + 
                          bce(D(x.detach(), gl), fake_labels) + 
                          bce(D(xp.detach(), gl), fake_labels)) / 3
                d_loss.backward()
                opt_D.step()
    
    return G, D


@torch.no_grad()
def compute_bias(model, x, xp, device, use_amp=None):
    """Compute bias as the mean L1 gap between the model's outputs on x and x'.

    For each probe pair, the absolute output difference is summed over dim 1
    (the logit axis for a classifier) and then averaged over the batch. The
    model is switched to eval mode and left there.

    Args:
        model: network to probe.
        x, xp: paired probe batches; moved to ``device``.
        device: device to run on.
        use_amp: run the forward passes under CUDA autocast. ``None`` falls
            back to the global ``USE_AMP``. Ignored on non-CUDA devices.

    Returns:
        The bias as a Python float.
    """
    if use_amp is None:
        use_amp = USE_AMP
    device_type = torch.device(device).type
    amp_enabled = bool(use_amp) and device_type == 'cuda'

    model.eval()
    x, xp = x.to(device), xp.to(device)

    with torch.amp.autocast(device_type=device_type, enabled=amp_enabled):
        out_x, out_xp = model(x), model(xp)

    # Take the difference in float32: under AMP the outputs are fp16, and
    # callers subtract these bias values from one another, so fp16 rounding
    # would otherwise leak into the fairness scores.
    diff = torch.abs(out_x.float() - out_xp.float()).sum(1).mean()
    return diff.item()


def partition_data_pathological_non_iid(dataset, n_clients, classes_per_client=2):
    """Pathological Non-IID partition with unequal per-client sample sizes.

    Client ``cid`` is assigned the ``classes_per_client`` consecutive class ids
    starting at ``(cid * classes_per_client) % n_classes`` (wrapping around),
    so several clients can share the same classes. Each class's samples are
    split among the clients that hold it, in proportion to per-client weights
    drawn from Dirichlet(0.5). The last holder of a class receives the
    remainder.

    A client that ends up with no samples (its share rounded down to 0 for
    every class) gets a fallback of up to 50 samples, drawn without replacement
    from its *own* assigned classes. This keeps the partition pathological and
    never requests more samples than exist.

    Labels are read from ``dataset.targets`` when the dataset has it (as
    torchvision's CIFAR10 does). This avoids loading and transforming every
    image just to read its label. Other datasets fall back to ``dataset[i][1]``.

    Uses the global NumPy RNG; seed it beforehand for reproducibility.

    Args:
        dataset: indexable dataset of ``(sample, label)`` pairs with integer labels.
        n_clients: number of clients.
        classes_per_client: number of classes each client holds.

    Returns:
        ``(client_indices, client_classes)``:
        - ``client_indices``: one shuffled NumPy index array per client.
        - ``client_classes``: the list of class ids assigned to each client.
    """
    targets = getattr(dataset, 'targets', None)
    if targets is not None:
        labels = np.asarray(targets)
    else:
        # Generic path: fetching a sample also runs its transform, which is slow.
        labels = np.array([dataset[i][1] for i in range(len(dataset))])
    n_classes = len(np.unique(labels))

    class_indices = {c: np.where(labels == c)[0] for c in range(n_classes)}
    for c in class_indices:
        np.random.shuffle(class_indices[c])

    client_classes = []
    for cid in range(n_clients):
        start_class = (cid * classes_per_client) % n_classes
        assigned = [(start_class + i) % n_classes for i in range(classes_per_client)]
        client_classes.append(assigned)

    # Unequal sample sizes using Dirichlet
    alpha = 0.5
    client_proportions = np.random.dirichlet([alpha] * n_clients)

    client_indices = [[] for _ in range(n_clients)]

    for c in range(n_classes):
        clients_with_class = [cid for cid in range(n_clients) if c in client_classes[cid]]

        if len(clients_with_class) > 0:
            class_samples = class_indices[c]
            total_for_class = len(class_samples)

            # Renormalise the Dirichlet weights over this class's holders only
            relevant_props = np.array([client_proportions[cid] for cid in clients_with_class])
            relevant_props = relevant_props / relevant_props.sum()

            start_idx = 0
            for i, cid in enumerate(clients_with_class):
                if i == len(clients_with_class) - 1:
                    end_idx = total_for_class  # last holder absorbs rounding remainder
                else:
                    n_samples = int(total_for_class * relevant_props[i])
                    end_idx = min(start_idx + n_samples, total_for_class)

                if start_idx < end_idx:
                    client_indices[cid].extend(class_samples[start_idx:end_idx].tolist())
                start_idx = end_idx

    result = []
    for cid in range(n_clients):
        if client_indices[cid]:
            indices = np.array(client_indices[cid])
            np.random.shuffle(indices)
            result.append(indices)
        else:
            # Fallback stays within the client's assigned classes and is capped
            # at the number of available samples.
            pool = np.concatenate([class_indices[c] for c in client_classes[cid]])
            fallback_samples = np.random.choice(pool, size=min(50, len(pool)), replace=False)
            result.append(fallback_samples)

    return result, client_classes


@torch.no_grad()
def evaluate(model, loader, device, use_amp=None):
    """Evaluate model accuracy: top-1 accuracy, in percent, over every batch in ``loader``.

    Args:
        model: classifier returning logits of shape (batch, num_classes).
            It is switched to eval mode.
        loader: iterable of ``(inputs, targets)`` batches.
        device: device to run on.
        use_amp: run forward passes under CUDA autocast. ``None`` falls back
            to the global ``USE_AMP``. Ignored on non-CUDA devices.

    Returns:
        ``100 * correct / total``, or 0 for an empty loader (matching
        evaluate_per_client) instead of raising ZeroDivisionError.
    """
    if use_amp is None:
        use_amp = USE_AMP
    device_type = torch.device(device).type
    amp_enabled = bool(use_amp) and device_type == 'cuda'

    model.eval()
    correct, total = 0, 0

    for d, t in loader:
        d, t = d.to(device, non_blocking=True), t.to(device, non_blocking=True)

        with torch.amp.autocast(device_type=device_type, enabled=amp_enabled):
            preds = model(d).argmax(1)

        correct += (preds == t).sum().item()
        total += len(t)

    return 100 * correct / total if total > 0 else 0


@torch.no_grad()
def evaluate_per_client(model, client_loaders, device, use_amp=None):
    """Evaluate model accuracy on each client's data.

    Args:
        model: classifier returning logits of shape (batch, num_classes).
            It is switched to eval mode.
        client_loaders: one iterable of ``(inputs, targets)`` batches per client.
        device: device to run on.
        use_amp: run forward passes under CUDA autocast. ``None`` falls back
            to the global ``USE_AMP``. Ignored on non-CUDA devices.

    Returns:
        List of per-client top-1 accuracies in percent, in ``client_loaders``
        order. A client whose loader yields no samples scores 0.
    """
    if use_amp is None:
        use_amp = USE_AMP
    device_type = torch.device(device).type
    amp_enabled = bool(use_amp) and device_type == 'cuda'

    model.eval()
    client_accuracies = []

    for loader in client_loaders:
        correct, total = 0, 0
        for d, t in loader:
            d, t = d.to(device, non_blocking=True), t.to(device, non_blocking=True)

            with torch.amp.autocast(device_type=device_type, enabled=amp_enabled):
                preds = model(d).argmax(1)

            correct += (preds == t).sum().item()
            total += len(t)

        acc = 100 * correct / total if total > 0 else 0
        client_accuracies.append(acc)

    return client_accuracies


# ============================================================
# FAIRNESS METRICS
# ============================================================

def calculate_jfi(performances):
    """Jain's Fairness Index: (sum p)^2 / (n * sum p^2).

    Equals 1.0 when all scores are equal. Returns 1.0 when the sum of squares
    is zero (all-zero or empty input), treating that degenerate case as fair.
    """
    scores = np.array(performances)
    square_total = np.sum(scores ** 2)
    if square_total == 0:
        return 1.0
    return np.sum(scores) ** 2 / (len(scores) * square_total)

def calculate_accuracy_gap(performances):
    """Spread between the best and worst client score (max - min)."""
    scores = np.asarray(performances)
    return scores.max() - scores.min()

def calculate_variance(performances):
    """Population variance (ddof=0) of the client scores."""
    return np.asarray(performances).var()

def calculate_max_min_fairness(performances):
    """Ratio of the worst to the best client score (min / max); 0.0 when the best score is 0."""
    scores = np.array(performances)
    best = np.max(scores)
    return 0.0 if best == 0 else np.min(scores) / best


print("‚úÖ Helper functions defined")

In [None]:
# ============================================================
# CONFIGURATION (V7 - No Gradient Clipping, No Warmup)
# ============================================================

# Training Parameters
N_ROUNDS = 50  # federated rounds per experiment
N_CLIENTS = 20  # clients; every client trains in every round (no sampling)
N_GAN_EPOCHS = 15  # epochs over val_loader for the fairness GAN, retrained each round
N_PROBES = 500  # generated (x, x') pairs used to measure bias each round
LOCAL_EPOCHS = 3  # local SGD epochs per client per round

# ❌ V7 CHANGE: no warmup rounds — V2 Linear aggregation is applied from round 1.
# (The V6 WARMUP_ROUNDS setting was removed, not set to 0.)

# Fed-Audit-GAN Parameters
MOMENTUM = 0.8  # EMA momentum for fairness scores: S = MOMENTUM * S_prev + (1 - MOMENTUM) * S_round

# ❌ V7 CHANGE: no gradient clipping in local client training
# (the V6 GRAD_CLIP_NORM setting was removed).

# Fairness weights gamma to test: aggregation weight = (1 - gamma) * acc_norm + gamma * fair_norm
GAMMA_VALUES = [0.3, 0.7]

# DataLoader Parameters
BATCH_SIZE = 128  # client loaders and the GAN's val_loader
VAL_BATCH_SIZE = 256  # used for test_loader (val_loader uses BATCH_SIZE)
NUM_WORKERS = 4  # worker processes per DataLoader; N_CLIENTS + 2 loaders are created
PIN_MEMORY = True
PREFETCH_FACTOR = 2  # batches prefetched per worker (only valid when NUM_WORKERS > 0)

# PATHOLOGICAL NON-IID: number of classes assigned to each client
CLASSES_PER_CLIENT = 2

print("=" * 70)
print("üîß Fed-Audit-GAN v7.0 - CIFAR-10 (No Gradient Clipping, No Warmup)")
print("=" * 70)
print(f"Device: {DEVICE}")
print(f"GPUs: {NUM_GPUS}")
print(f"AMP Enabled: {USE_AMP}")
print(f"Rounds: {N_ROUNDS}, Clients: {N_CLIENTS}")

print(f"\nüéØ EXPERIMENTS TO RUN:")
print(f"   üîµ 1. FedAvg (BASELINE)")
for i, g in enumerate(GAMMA_VALUES, 2):
    print(f"   üü¢ {i}. Fed-Audit-GAN Œ≥={g}")

print(f"\n‚≠ê V7 CHANGES FROM V6:")
print(f"   ‚ùå REMOVED: Gradient clipping")
print(f"   ‚ùå REMOVED: Warmup rounds - V2 Linear from round 1!")
print(f"   ‚úÖ KEPT: GAN trains EVERY round")
print(f"   ‚úÖ KEPT: V2 Linear aggregation formula")
print(f"   ‚úÖ KEPT: Soft labels for GAN")
print(f"   ‚úÖ KEPT: EMA momentum ({MOMENTUM})")

print(f"\n‚≠ê PATHOLOGICAL NON-IID:")
print(f"   Each client gets ONLY {CLASSES_PER_CLIENT}/10 classes")
print("=" * 70)

In [None]:
# ============================================================
# DATA LOADING
# ============================================================

# Human-readable names indexed by CIFAR-10 class id (used only for printing)
CIFAR10_CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']

# Training pipeline: random flip + padded random crop, then per-channel
# standardisation with CIFAR-10 mean/std.
# NOTE(review): after this normalisation pixel values are not confined to
# [-1, 1], whereas FairnessGenerator outputs are (Tanh / clamp). The GAN's
# discriminator therefore sees real and generated images on different
# scales — confirm intended.
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
])

# Test pipeline: same normalisation, no augmentation
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
])

# Downloads to ./data on first run
train_data = datasets.CIFAR10('./data', train=True, download=True, transform=transform_train)
test_data = datasets.CIFAR10('./data', train=False, download=True, transform=transform_test)

# Create PATHOLOGICAL Non-IID partitions.
# Seeds NumPy only (the partition and the val-subset sampling further down);
# torch's RNG is not seeded here.
np.random.seed(42)
client_idx, client_classes = partition_data_pathological_non_iid(
    train_data, N_CLIENTS, classes_per_client=CLASSES_PER_CLIENT
)

# Calculate data weights: each client's share of all partitioned samples (sums to 1).
# FedAvg averages with these. Fed-Audit-GAN uses them as its "accuracy" proxy.
client_data_sizes = [len(idx) for idx in client_idx]
total_samples = sum(client_data_sizes)
CLIENT_DATA_WEIGHTS = [size / total_samples for size in client_data_sizes]

# DataLoader kwargs shared by every loader below.
# persistent_workers / prefetch_factor are only valid with worker processes;
# DataLoader raises ValueError for them when num_workers == 0, so they are
# only added when NUM_WORKERS > 0. Pinned host memory only matters for copies
# to a CUDA device, so it is only requested when CUDA is available.
# NOTE(review): N_CLIENTS + 2 loaders each keep NUM_WORKERS persistent worker
# processes alive, i.e. (N_CLIENTS + 2) * NUM_WORKERS processes in total.
dataloader_kwargs = {
    'num_workers': NUM_WORKERS,
    'pin_memory': PIN_MEMORY and torch.cuda.is_available(),
}
if NUM_WORKERS > 0:
    dataloader_kwargs['persistent_workers'] = True
    dataloader_kwargs['prefetch_factor'] = PREFETCH_FACTOR

# Held-out CIFAR-10 test set: global accuracy after each round
test_loader = DataLoader(test_data, batch_size=VAL_BATCH_SIZE, shuffle=False, **dataloader_kwargs)
# 2000 random *training* images (with training augmentation, possibly
# overlapping client data); Fed-Audit-GAN trains its per-round GAN on this.
val_loader = DataLoader(
    Subset(train_data, np.random.choice(len(train_data), 2000, replace=False)),
    batch_size=BATCH_SIZE, shuffle=False, **dataloader_kwargs
)

# One shuffled loader per client over its partition of the training set.
# These loaders are also passed to evaluate_per_client, so the reported
# per-client accuracies are measured on the clients' own augmented training data.
client_loaders = [
    DataLoader(Subset(train_data, client_idx[c]), batch_size=BATCH_SIZE, shuffle=True, **dataloader_kwargs)
    for c in range(N_CLIENTS)
]

print(f"‚úÖ Data loaded!")
print(f"   Training samples: {len(train_data)}")
print(f"   Test samples: {len(test_data)}")
print(f"\n" + "=" * 70)
print(f"‚≠ê PATHOLOGICAL NON-IID ({CLASSES_PER_CLIENT} classes per client)")
print("=" * 70)
for i in range(min(10, N_CLIENTS)):
    class_names = [CIFAR10_CLASSES[c] for c in client_classes[i]]
    print(f"   Client {i:2d}: {client_classes[i]} ‚Üí {class_names} ({client_data_sizes[i]} samples)")

if N_CLIENTS > 10:
    print(f"   ... ({N_CLIENTS - 10} more clients)")

In [None]:
# ============================================================
# TRAINING FUNCTIONS (V7 - No Gradient Clipping, No Warmup)
# ============================================================

def run_fedavg(n_rounds, n_clients, train_data, client_idx, val_loader, test_loader,
               client_loaders, local_epochs, device, use_amp, client_data_weights):
    """
    Standard FedAvg (McMahan et al., 2017) baseline.

    Each round, every one of the ``n_clients`` clients:
    - starts from a deep copy of the global CNN;
    - runs ``local_epochs`` epochs of SGD (lr=0.01, momentum=0.9,
      weight_decay=1e-4) on its loader;
    - contributes its final state dict.

    The new global state is the ``client_data_weights``-weighted sum of those
    state dicts. The model is then scored on ``test_loader`` and on every
    client loader. The metrics are appended to ``history`` and sent to wandb.

    ``train_data``, ``client_idx`` and ``val_loader`` are accepted for
    signature parity with the Fed-Audit-GAN runner but are not used.

    Returns:
        ``(model, history)``: the final global model and per-round metric lists.
    """

    print("\n" + "=" * 70)
    print("üîµ RUNNING: FedAvg (Federated Averaging) - BASELINE")
    print("=" * 70)

    model = CNN().to(device)
    scaler = torch.amp.GradScaler(device='cuda') if use_amp else None

    metric_names = ('acc', 'client_accuracies',
                    'jfi', 'max_min_fairness', 'variance', 'accuracy_gap',
                    'min_client_acc', 'max_client_acc')
    history = {name: [] for name in metric_names}

    # Stateless loss module, shared by all clients and rounds
    criterion = nn.CrossEntropyLoss()

    for rnd in tqdm(range(n_rounds), desc="FedAvg"):
        client_states = []

        # ---- Local training on every client ----
        for cid in range(n_clients):
            worker = copy.deepcopy(model)
            worker.train()
            sgd = optim.SGD(worker.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)

            for _ in range(local_epochs):
                for data, target in client_loaders[cid]:
                    data = data.to(device, non_blocking=True)
                    target = target.to(device, non_blocking=True)
                    sgd.zero_grad(set_to_none=True)

                    if not use_amp:
                        loss = criterion(worker(data), target)
                        loss.backward()
                        sgd.step()
                        continue

                    with torch.amp.autocast(device_type='cuda'):
                        loss = criterion(worker(data), target)
                    scaler.scale(loss).backward()
                    scaler.step(sgd)
                    scaler.update()

            client_states.append(copy.deepcopy(worker.state_dict()))
            del worker
            torch.cuda.empty_cache()

        # ---- Data-weighted averaging, accumulated in client order ----
        merged = copy.deepcopy(client_states[0])
        for key in merged.keys():
            merged[key] = sum(
                (client_states[i][key] * client_data_weights[i] for i in range(1, len(client_states))),
                merged[key] * client_data_weights[0],
            )
        model.load_state_dict(merged)

        # ---- Evaluation ----
        acc = evaluate(model, test_loader, device)
        client_accs = evaluate_per_client(model, client_loaders, device)
        worst, best = min(client_accs), max(client_accs)
        fairness = {
            'jfi': calculate_jfi(client_accs),
            'max_min_fairness': calculate_max_min_fairness(client_accs),
            'variance': calculate_variance(client_accs),
            'accuracy_gap': calculate_accuracy_gap(client_accs),
        }

        history['acc'].append(acc)
        history['client_accuracies'].append(client_accs.copy())
        for name, value in fairness.items():
            history[name].append(value)
        history['min_client_acc'].append(worst)
        history['max_client_acc'].append(best)

        wandb.log({
            'round': rnd + 1,
            'accuracy': acc,
            'jfi': fairness['jfi'],
            'max_min_fairness': fairness['max_min_fairness'],
            'fairness_variance': fairness['variance'],
            'accuracy_gap': fairness['accuracy_gap'],
            'min_client_acc': worst,
            'max_client_acc': best
        })

    return model, history


def run_fed_audit_gan_v7(gamma, n_rounds, n_clients, momentum,
                         train_data, client_idx, val_loader,
                         test_loader, client_loaders, n_gan_epochs, n_probes,
                         local_epochs, device, use_amp, client_data_weights):
    """
    Fed-Audit-GAN v7.0 - No Gradient Clipping, No Warmup

    Each round:
      1. Every client fine-tunes a copy of the global CNN (SGD, lr=0.01,
         momentum=0.9, weight_decay=1e-4, no gradient clipping). Its update is
         the state-dict difference from the round's starting global state
         (BatchNorm buffers included).
      2. A fresh FairnessGenerator / Discriminator pair is trained against the
         current global model on ``val_loader`` for ``n_gan_epochs`` epochs.
      3. ``n_probes`` generated (x, x') pairs measure the bias of the global
         model (B_base) and of each "global + client update" model (B_client).
         The client's raw score is B_base - B_client (positive = the update
         reduces bias). It is EMA-smoothed across rounds with ``momentum``,
         starting from 0.0.
      4. V2 linear aggregation. Both terms are min-max normalised over clients:
         raw_i = (1 - gamma) * data_share_i + gamma * fairness_i + 1e-8.
         The weights raw_i / sum(raw) scale the client updates added to the
         global state.

    ❌ REMOVED: Gradient clipping
    ❌ REMOVED: Warmup rounds - V2 Linear from round 1!
    ✅ KEPT: GAN trains EVERY round
    ✅ KEPT: V2 Linear aggregation formula

    ``train_data`` and ``client_idx`` are unused (kept for signature parity).

    Returns:
        ``(model, history)``: the final global model and per-round metrics,
        including raw/smoothed fairness scores and the aggregation weights (alphas).
    """

    print(f"\n" + "=" * 70)
    print(f"üü¢ RUNNING: Fed-Audit-GAN v7.0 (Œ≥={gamma})")
    print(f"   ‚ùå NO Gradient Clipping")
    print(f"   ‚ùå NO Warmup - V2 Linear from round 1!")
    print(f"   ‚úÖ V2 Linear Formula: Weight = (1-Œ≥)√óAcc + Œ≥√óFair")
    print("=" * 70)

    model = CNN().to(device)
    scaler = torch.amp.GradScaler(device='cuda') if use_amp else None

    # Per-client EMA of the fairness score; persists across rounds, starts at 0.0
    fairness_history = {i: 0.0 for i in range(n_clients)}

    history = {
        'acc': [], 'bias': [], 'alphas': [],
        'raw_scores': [], 'smoothed_scores': [],
        'client_accuracies': [],
        'jfi': [], 'max_min_fairness': [], 'variance': [], 'accuracy_gap': [],
        'min_client_acc': [], 'max_client_acc': []
    }

    for rnd in tqdm(range(n_rounds), desc=f"Fed-Audit-GAN Œ≥={gamma}"):

        # ================================================================
        # PHASE 1: Local Client Training (NO Gradient Clipping!)
        # ================================================================
        updates = []
        # Clients train deep copies, so the global model does not change during
        # this phase: snapshot its state once per round, not once per client.
        before_state = copy.deepcopy(model.state_dict())

        for cid in range(n_clients):
            local_model = copy.deepcopy(model)
            local_model.train()

            # SGD without Gradient Clipping
            optimizer = optim.SGD(local_model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)

            for epoch in range(local_epochs):
                for data, target in client_loaders[cid]:
                    data = data.to(device, non_blocking=True)
                    target = target.to(device, non_blocking=True)
                    optimizer.zero_grad(set_to_none=True)

                    if use_amp:
                        with torch.amp.autocast(device_type='cuda'):
                            output = local_model(data)
                            loss = F.cross_entropy(output, target)
                        scaler.scale(loss).backward()
                        # ‚ùå NO GRADIENT CLIPPING
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        output = local_model(data)
                        loss = F.cross_entropy(output, target)
                        loss.backward()
                        # ‚ùå NO GRADIENT CLIPPING
                        optimizer.step()

            # Update = trained state - starting state. Fetch the state dict
            # once instead of rebuilding it for every key.
            local_state = local_model.state_dict()
            update = {k: local_state[k] - before_state[k] for k in before_state}
            updates.append(update)
            del local_model, local_state
            torch.cuda.empty_cache()

        # ================================================================
        # PHASE 2 & 3: GAN Training + Fairness Scoring
        # GAN trains EVERY round from round 1!
        # ================================================================
        G = FairnessGenerator(img_shape=(3, 32, 32)).to(device)
        D = Discriminator(img_shape=(3, 32, 32)).to(device)
        G, D = train_gan(G, D, model, val_loader, epochs=n_gan_epochs, device=device)

        G.eval()
        with torch.no_grad():
            z = torch.randn(n_probes, G.latent_dim, device=device)
            # Probe labels span the generator's own label space
            labels = torch.randint(0, G.num_classes, (n_probes,), device=device)

            if use_amp:
                with torch.amp.autocast(device_type='cuda'):
                    x_probe, xp_probe = G(z, labels)
            else:
                x_probe, xp_probe = G(z, labels)

        B_base = compute_bias(model, x_probe, xp_probe, device)

        S_fair_raw = []
        S_fair_smoothed = []

        for cid, upd in enumerate(updates):
            # Hypothetical model: global state plus this client's update alone
            hyp_model = copy.deepcopy(model)
            hyp_state = hyp_model.state_dict()
            for k in hyp_state:
                hyp_state[k] = hyp_state[k] + upd[k]
            hyp_model.load_state_dict(hyp_state)

            B_client = compute_bias(hyp_model, x_probe, xp_probe, device)
            S_current = B_base - B_client  # > 0: this update reduces bias
            S_fair_raw.append(S_current)

            # EMA smoothing
            S_prev = fairness_history[cid]
            S_smoothed = (momentum * S_prev) + ((1 - momentum) * S_current)
            fairness_history[cid] = S_smoothed
            S_fair_smoothed.append(S_smoothed)
            del hyp_model

        del G, D, x_probe, xp_probe
        torch.cuda.empty_cache()

        history['raw_scores'].append(S_fair_raw.copy())
        history['smoothed_scores'].append(S_fair_smoothed.copy())

        # ================================================================
        # PHASE 4: V2 LINEAR AGGREGATION (FROM ROUND 1 - NO WARMUP!)
        # ================================================================

        # Normalize Fairness Score (min-max over clients; 0.5 if all equal)
        f_tensor = torch.tensor(S_fair_smoothed, device=device, dtype=torch.float32)
        f_min, f_max = f_tensor.min(), f_tensor.max()
        if f_max != f_min:
            f_norm = (f_tensor - f_min) / (f_max - f_min)
        else:
            f_norm = torch.ones_like(f_tensor) * 0.5

        # Normalize Accuracy Score (using data weights as proxy)
        a_tensor = torch.tensor(client_data_weights, device=device, dtype=torch.float32)
        a_min, a_max = a_tensor.min(), a_tensor.max()
        if a_max != a_min:
            a_norm = (a_tensor - a_min) / (a_max - a_min)
        else:
            a_norm = torch.ones_like(a_tensor) * 0.5

        # V2 Linear Formula (+1e-8 keeps every weight strictly positive)
        raw_weights = ((1 - gamma) * a_norm) + (gamma * f_norm) + 1e-8

        # Final Normalization
        alphas = (raw_weights / raw_weights.sum()).tolist()

        # Debug output every 10 rounds
        if rnd % 10 == 0:
            print(f"\n   üìä Round {rnd+1} Debug (Œ≥={gamma}):")
            print(f"      Fairness: min={min(S_fair_smoothed):.4f}, max={max(S_fair_smoothed):.4f}")
            print(f"      Alphas: min={min(alphas):.4f}, max={max(alphas):.4f}")

        # Apply weighted aggregation: global += sum_i alpha_i * update_i
        new_state = model.state_dict()
        for k in new_state:
            new_state[k] = new_state[k] + sum(a * u[k] for a, u in zip(alphas, updates))
        model.load_state_dict(new_state)

        # ================================================================
        # EVALUATION
        # ================================================================
        acc = evaluate(model, test_loader, device)
        client_accs = evaluate_per_client(model, client_loaders, device)

        jfi = calculate_jfi(client_accs)
        max_min = calculate_max_min_fairness(client_accs)
        var = calculate_variance(client_accs)
        gap = calculate_accuracy_gap(client_accs)

        history['acc'].append(acc)
        history['bias'].append(B_base)
        history['alphas'].append(alphas.copy())
        history['client_accuracies'].append(client_accs.copy())
        history['jfi'].append(jfi)
        history['max_min_fairness'].append(max_min)
        history['variance'].append(var)
        history['accuracy_gap'].append(gap)
        history['min_client_acc'].append(min(client_accs))
        history['max_client_acc'].append(max(client_accs))

        wandb.log({
            'round': rnd + 1,
            'accuracy': acc,
            'bias': B_base,
            'jfi': jfi,
            'max_min_fairness': max_min,
            'fairness_variance': var,
            'accuracy_gap': gap,
            'min_client_acc': min(client_accs),
            'max_client_acc': max(client_accs),
            'alpha_min': min(alphas),
            'alpha_max': max(alphas)
        })

    return model, history


print("‚úÖ Training functions defined (FedAvg + Fed-Audit-GAN v7.0)")

In [None]:
# ============================================================
# RUN ALL EXPERIMENTS
# ============================================================

all_results = {}

# ============================================================
# EXPERIMENT 1: FedAvg (BASELINE)
# ============================================================
wandb.init(
    project="FED_AUDIT_GAN_TEST_7_CIFAR10",
    name=f"FedAvg_CIFAR10_clients{N_CLIENTS}_v7",
    config={
        "method": "FedAvg",
        "dataset": "CIFAR-10",
        "n_rounds": N_ROUNDS,
        "n_clients": N_CLIENTS,
        "non_iid_type": "pathological",
        "classes_per_client": CLASSES_PER_CLIENT,
        "device": str(DEVICE),
        "num_gpus": NUM_GPUS,
        "amp_enabled": USE_AMP
    }
)

# Keyword arguments shared by every training run launched from this cell.
_shared_run_kwargs = dict(
    n_rounds=N_ROUNDS,
    n_clients=N_CLIENTS,
    train_data=train_data,
    client_idx=client_idx,
    val_loader=val_loader,
    test_loader=test_loader,
    client_loaders=client_loaders,
    local_epochs=LOCAL_EPOCHS,
    device=DEVICE,
    use_amp=USE_AMP,
    client_data_weights=CLIENT_DATA_WEIGHTS,
)


def _store_result(result_name, trained_model, run_history):
    """Record a finished run in ``all_results`` under ``result_name``."""
    all_results[result_name] = {
        'model': trained_model,
        'history': run_history,
        'name': result_name,
    }


fedavg_model, fedavg_history = run_fedavg(**_shared_run_kwargs)
wandb.finish()  # close the FedAvg W&B run initialised before this call
_store_result('FedAvg', fedavg_model, fedavg_history)

print(f"‚úÖ FedAvg Complete!")
print(f"   Final Accuracy: {fedavg_history['acc'][-1]:.2f}%")
print(f"   Final JFI: {fedavg_history['jfi'][-1]:.4f}")


# ============================================================
# EXPERIMENTS 2-3: Fed-Audit-GAN v7.0, one W&B run per gamma
# ============================================================
for gamma in GAMMA_VALUES:
    method_name = f"FedAuditGAN_v7_Œ≥={gamma}"

    # Metadata recorded with this W&B run.
    run_config = {
        "method": method_name,
        "dataset": "CIFAR-10",
        "n_rounds": N_ROUNDS,
        "n_clients": N_CLIENTS,
        "gamma": gamma,
        "momentum": MOMENTUM,
        "warmup_rounds": 0,  # no warm-up in this variant
        "gradient_clipping": False,  # no gradient clipping in this variant
        "gan_every_round": True,
        "aggregation_method": "V2_LINEAR",
        "non_iid_type": "pathological",
        "classes_per_client": CLASSES_PER_CLIENT,
        "device": str(DEVICE),
        "num_gpus": NUM_GPUS,
        "amp_enabled": USE_AMP
    }
    wandb.init(
        project="FED_AUDIT_GAN_TEST_7_CIFAR10",
        name=f"{method_name}_CIFAR10_clients{N_CLIENTS}_no_warmup",
        config=run_config,
    )

    # Same shared arguments as FedAvg, plus the Fed-Audit-GAN specific ones.
    model, history = run_fed_audit_gan_v7(
        gamma=gamma,
        momentum=MOMENTUM,
        n_gan_epochs=N_GAN_EPOCHS,
        n_probes=N_PROBES,
        **_shared_run_kwargs,
    )
    wandb.finish()
    _store_result(method_name, model, history)

    print(f"‚úÖ {method_name} Complete!")
    print(f"   Final Accuracy: {history['acc'][-1]:.2f}%")
    print(f"   Final JFI: {history['jfi'][-1]:.4f}")
    print(f"   Accuracy Gap: {history['accuracy_gap'][-1]:.2f}%")

print("\n" + "=" * 70)
print("‚úÖ ALL EXPERIMENTS COMPLETE!")
print("=" * 70)

In [None]:
# ============================================================
# VISUALIZATION
# ============================================================

# Color scheme
colors = {
    'FedAvg': '#1f77b4',
    'FedAuditGAN_v7_Œ≥=0.3': '#2ca02c',
    'FedAuditGAN_v7_Œ≥=0.7': '#d62728'
}

linestyles = {
    'FedAvg': '--',
    'FedAuditGAN_v7_Œ≥=0.3': '-',
    'FedAuditGAN_v7_Œ≥=0.7': '-'
}

fig, axes = plt.subplots(2, 3, figsize=(18, 10))
rounds = list(range(1, N_ROUNDS + 1))  # kept as a global for any later cells

# Panels 1-5: one per-round curve per method.
# Each entry: (subplot position, history key, y-axis label, title).
curve_panels = [
    ((0, 0), 'acc', 'Accuracy (%)', 'CIFAR-10: Global Test Accuracy'),
    ((0, 1), 'jfi', 'JFI', "Jain's Fairness Index (higher=fairer)"),
    ((0, 2), 'max_min_fairness', 'Min/Max Ratio', 'Max-Min Fairness (higher=fairer)'),
    ((1, 0), 'variance', 'Variance', 'Per-Client Accuracy Variance (lower=fairer)'),
    ((1, 1), 'accuracy_gap', 'Gap (%)', 'Best-Worst Client Gap (lower=fairer)'),
]

for (row, col), key, ylabel, title in curve_panels:
    ax = axes[row, col]
    for name, result in all_results.items():
        series = result['history'][key]
        # x-values follow the recorded length, so a run that logged fewer than
        # N_ROUNDS rounds still plots instead of raising a length mismatch.
        ax.plot(range(1, len(series) + 1), series,
                label=name, color=colors.get(name, 'gray'),
                linestyle=linestyles.get(name, '-'), linewidth=2)
    ax.set_xlabel('Round')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)

# Panel 6: grouped bars of each method's final per-client accuracy.
ax = axes[1, 2]
x = np.arange(N_CLIENTS)
n_methods = max(len(all_results), 1)
# 0.25 fits three methods side by side; shrink it for more methods so that
# neighbouring client groups never overlap.
width = min(0.25, 0.8 / n_methods)
for i, (name, result) in enumerate(all_results.items()):
    final_accs = result['history']['client_accuracies'][-1]
    ax.bar(x + i * width, final_accs, width,
           label=name, color=colors.get(name, 'gray'), alpha=0.8)
# Centre each client-ID tick under its group of bars.
ax.set_xticks(x + width * (n_methods - 1) / 2)
ax.set_xticklabels(x)
ax.set_xlabel('Client ID')
ax.set_ylabel('Accuracy (%)')
ax.set_title(f'Per-Client Accuracy (Round {N_ROUNDS})')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('cifar10_v7_results.png', dpi=150, bbox_inches='tight')
plt.show()

print("‚úÖ Plots saved to cifar10_v7_results.png")

In [None]:
# ============================================================
# FINAL SUMMARY
# ============================================================

print("\n" + "=" * 110)
print(f"CIFAR-10 V7 FINAL RESULTS SUMMARY (Round {N_ROUNDS})")
print("=" * 110)
print(f"{'Method':<30} {'Accuracy':>10} {'JFI':>8} {'Max-Min':>10} {'Variance':>12} {'Min Acc':>10} {'Max Acc':>10} {'Gap':>8}")
print("-" * 110)

for name, result in all_results.items():
    h = result['history']
    acc = h['acc'][-1]
    jfi = h['jfi'][-1]
    mmf = h['max_min_fairness'][-1]
    var = h['variance'][-1]
    min_acc = h['min_client_acc'][-1]
    max_acc = h['max_client_acc'][-1]
    gap = h['accuracy_gap'][-1]

    # Percentage columns reserve one character for the trailing '%': the
    # number is one narrower than its header column (9 + '%' under a 10-wide
    # header). Previously the '%' pushed every following column to the right.
    print(f"{name:<30} {acc:>9.2f}% {jfi:>8.4f} {mmf:>10.4f} {var:>12.2f} {min_acc:>9.2f}% {max_acc:>9.2f}% {gap:>7.2f}%")

print("=" * 110)

# Winner analysis
print("\n" + "=" * 60)
print("FAIRNESS COMPARISON")
print("=" * 60)


def _best_final(metric, pick):
    """Return ``(method_name, final_value)`` for the method whose final
    ``metric`` value is selected by ``pick`` (``max`` or ``min``).

    Ties resolve to the first method in ``all_results`` order.
    """
    winner = pick(all_results, key=lambda n: all_results[n]['history'][metric][-1])
    return winner, all_results[winner]['history'][metric][-1]


# Best JFI
best_jfi_name, best_jfi_val = _best_final('jfi', max)
print(f"Best JFI:        {best_jfi_name} ({best_jfi_val:.4f})")

# Best Max-Min Fairness
best_mmf_name, best_mmf_val = _best_final('max_min_fairness', max)
print(f"Best Max-Min:    {best_mmf_name} ({best_mmf_val:.4f})")

# Lowest Variance
best_var_name, best_var_val = _best_final('variance', min)
print(f"Lowest Variance: {best_var_name} ({best_var_val:.2f})")

# Smallest Gap
best_gap_name, best_gap_val = _best_final('accuracy_gap', min)
print(f"Smallest Gap:    {best_gap_name} ({best_gap_val:.2f}%)")

print("\n" + "=" * 60)
print("‚úÖ V7 Experiment Complete!")
print("   ‚ùå NO Gradient Clipping")
print("   ‚ùå NO Warmup Rounds")
print("   ‚úÖ V2 Linear Aggregation from Round 1")
print("=" * 60)

# 🔧 Fed-Audit-GAN v7.0 - CIFAR-10 (FedProx + GAN Every Round)

## 🎯 Experiments Run:

### 🔵 Baseline: FedAvg (Federated Averaging)
- Standard data-weighted averaging (McMahan et al., 2017)
- NO GAN, NO fairness scoring

### 🟢 Our Method: Fed-Audit-GAN (V2 Linear Formula)
- **Fed-Audit-GAN γ = 0.3** - Accuracy-weighted
- **Fed-Audit-GAN γ = 0.7** - Fairness-weighted

## 🔧 Fed-Audit-GAN 4-Phase Architecture:
1. **Phase 1**: Local Client Training (SGD + FedProx)
2. **Phase 2**: GAN Training (EVERY round - including warm-up!)
3. **Phase 3**: Fairness Scoring (with EMA)
4. **Phase 4**: V2 Linear Aggregation (Min-Max Normalized)

## ⭐ KEY CHANGES FROM V6:
- ✅ **ADDED BACK**: FedProx proximal term (μ=0.01)
- ❌ **REMOVED**: Gradient clipping
- ✅ **KEPT**: GAN trains EVERY round (including warm-up rounds 1-10)
- ✅ **KEPT**: Audit every round after round 10
- ✅ **KEPT**: V2 Linear Formula for aggregation
- ✅ **KEPT**: Soft labels for GAN stability
- ✅ **KEPT**: 10-round warm-up period (the GAN still trains during warm-up)

## ⭐ V2 Linear Aggregation Formula:
```
Weight = (1 - γ) × Accuracy_norm + γ × Fairness_norm
```
where Accuracy_norm and Fairness_norm are the min-max normalized client scores (Phase 4).

## ⭐ PATHOLOGICAL NON-IID:
- Each client ONLY has 2 out of 10 classes
- Clients have DIFFERENT sample sizes (per-client shares drawn from a Dirichlet distribution with α = 0.5)

---

In [None]:
# Step 1: Install Dependencies
# Quietly installs the packages this notebook imports: torch/torchvision
# (models, CIFAR-10), tqdm (progress bars), matplotlib (plots), numpy and
# wandb (experiment tracking). `!pip` is IPython shell syntax, notebook-only.
# Note: the message below is printed whether or not the install succeeded.
!pip install -q torch torchvision tqdm matplotlib numpy wandb

print("‚úÖ Dependencies installed!")

In [None]:
# Step 2: Login to WandB
# Authenticates this session with Weights & Biases; later cells call
# wandb.init() / wandb.log() / wandb.finish() to track each experiment run.
import wandb
wandb.login()
print("‚úÖ WandB logged in!")

In [None]:
# Step 3: Imports and GPU Setup
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import numpy as np
import copy
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

# ============================================================
# 🚀 GPU SETUP
# Defines DEVICE, NUM_GPUS and USE_AMP for the rest of the notebook.
# ============================================================

_HAS_CUDA = torch.cuda.is_available()

if not _HAS_CUDA:
    DEVICE = torch.device('cpu')
    NUM_GPUS = 0
    print("‚ö†Ô∏è  No GPU detected. Using CPU.")
else:
    DEVICE = torch.device('cuda')
    NUM_GPUS = torch.cuda.device_count()

    # cuDNN: enable it and let it autotune convolution algorithms.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True

    # Permit TF32 math for matmuls and cuDNN convolutions.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    torch.cuda.empty_cache()

    print(f"‚úÖ GPU(s) detected: {NUM_GPUS}")
    for i in range(NUM_GPUS):
        props = torch.cuda.get_device_properties(i)
        print(f"   GPU {i}: {props.name}")
        print(f"      Memory: {props.total_memory / 1e9:.2f} GB")
    print(f"   CUDA Version: {torch.version.cuda}")

# Mixed-precision (AMP) training is used exactly when CUDA is available.
USE_AMP = _HAS_CUDA
if USE_AMP:
    print("\n‚úÖ Mixed Precision Training (AMP) enabled.")

print(f"\nüìç Device: {DEVICE}")
print(f"   PyTorch: {torch.__version__}")

In [None]:
# ============================================================
# MODEL DEFINITIONS (CIFAR-10)
# ============================================================

class CNN(nn.Module):
    """Four-block convolutional classifier for 3x32x32 (CIFAR-10) images.

    Each block is conv3x3 -> BatchNorm -> ReLU -> 2x2 max-pool, halving the
    spatial size 32 -> 16 -> 8 -> 4 -> 2. The head is a 2048 -> 512 MLP
    layer with dropout (p=0.5), followed by a linear layer to the logits.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.conv4 = nn.Conv2d(256, 512, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(512)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.5)
        self.fc1 = nn.Linear(512 * 2 * 2, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return unnormalised logits of shape (batch, num_classes)."""
        x = self.pool(F.relu(self.bn1(self.conv1(x))))  # 32->16
        x = self.pool(F.relu(self.bn2(self.conv2(x))))  # 16->8
        x = self.pool(F.relu(self.bn3(self.conv3(x))))  # 8->4
        x = self.pool(F.relu(self.bn4(self.conv4(x))))  # 4->2
        # flatten(1) keeps the batch dimension intact. The previous
        # view(-1, 512 * 2 * 2) silently folded extra spatial positions into
        # the batch dimension for non-32x32 inputs; now fc1 rejects them.
        x = torch.flatten(x, 1)
        x = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(x)


class FairnessGenerator(nn.Module):
    """Conditional generator that emits paired images ``(x, x')`` for fairness probing.

    ``x`` is produced from the noise ``z`` concatenated with a learned label
    embedding. A linear layer maps this to a 512 x (H/4) x (H/4) feature map,
    two 2x-upsampling conv blocks follow, and a final Tanh puts ``x`` in
    [-1, 1]. ``x'`` is ``x`` plus a perturbation
    ``delta_scale * tanh(MLP(z))``, clamped back into [-1, 1].

    Args:
        latent_dim: size of ``z`` (and of the label embedding).
        num_classes: number of conditioning labels.
        img_shape: ``(C, H, W)`` of generated images. The initial map size is
            ``H // 4`` for both spatial dims, so square images with H
            divisible by 4 are assumed.
        delta_scale: bound on the per-pixel magnitude of the perturbation
            that turns ``x`` into ``x'``. The default 0.1 is the previously
            hard-coded value.
    """

    def __init__(self, latent_dim=128, num_classes=10, img_shape=(3, 32, 32), delta_scale=0.1):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        self.img_shape = img_shape
        self.label_emb = nn.Embedding(num_classes, latent_dim)
        self.init_size = img_shape[1] // 4  # 8 for 32x32 images
        self.l1 = nn.Linear(latent_dim * 2, 512 * self.init_size ** 2)
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(512),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(512, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(256, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, img_shape[0], 3, 1, 1),
            nn.Tanh()
        )
        # Perturbation branch: depends on z only, not on the label.
        self.delta_net = nn.Sequential(
            nn.Linear(latent_dim, 512),
            nn.ReLU(),
            nn.Linear(512, int(np.prod(img_shape))),
            nn.Tanh()
        )
        self.delta_scale = delta_scale

    def forward(self, z, labels):
        """Return ``(x, x')`` for noise ``z`` (batch, latent_dim) and integer ``labels`` (batch,)."""
        gen_input = torch.cat([z, self.label_emb(labels)], dim=1)
        out = self.l1(gen_input).view(-1, 512, self.init_size, self.init_size)
        x = self.conv_blocks(out)
        delta = self.delta_net(z).view(-1, *self.img_shape) * self.delta_scale
        return x, torch.clamp(x + delta, -1, 1)


class Discriminator(nn.Module):
    """Conditional discriminator: scores an ``(image, label)`` pair as real/fake.

    The label is embedded to a ``num_classes``-dim vector, broadcast over the
    image plane and concatenated to the image channels. Four stride-2 3x3
    convolutions then feed a linear layer that outputs a single logit.

    Args:
        num_classes: number of conditioning labels.
        img_shape: ``(C, H, W)`` of input images. The size of the final
            linear layer is derived from it, so inputs other than 32x32 work.
    """

    def __init__(self, num_classes=10, img_shape=(3, 32, 32)):
        super().__init__()
        self.num_classes = num_classes
        self.img_shape = img_shape
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.conv = nn.Sequential(
            nn.Conv2d(img_shape[0] + num_classes, 64, 3, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 3, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 3, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 512, 3, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2)
        )
        # Each 3x3 / stride-2 / padding-1 conv maps a size s to (s - 1) // 2 + 1
        # (32 -> 16 -> 8 -> 4 -> 2, i.e. 512 * 2 * 2 features for CIFAR-10).
        h, w = img_shape[1], img_shape[2]
        for _ in range(4):
            h, w = (h - 1) // 2 + 1, (w - 1) // 2 + 1
        self.fc = nn.Linear(512 * h * w, 1)

    def forward(self, img, labels):
        """Return real/fake logits of shape (batch, 1)."""
        label_map = self.label_emb(labels).view(-1, self.num_classes, 1, 1)
        label_map = label_map.expand(-1, -1, self.img_shape[1], self.img_shape[2])
        out = self.conv(torch.cat([img, label_map], dim=1))
        return self.fc(out.view(out.size(0), -1))


print("‚úÖ Models defined: CNN, FairnessGenerator, Discriminator")

In [None]:
# ============================================================
# HELPER FUNCTIONS
# ============================================================

def train_gan(G, D, model, loader, epochs=15, device='cuda', l1=1.0, l2=1.0):
    """Adversarially train the fairness-probe generator ``G`` against ``D``.

    For every real batch, the generator draws a pair ``(x, x')`` for random
    labels ``gl`` and minimises

        -mean((model(x) - model(x'))**2)                     # push predictions apart
        + l1 * mean((x - x')**2)                             # keep the pair close
        + l2 * (BCE(D(x), real) + BCE(D(x'), real)) / 2      # look realistic

    ``D`` then learns to separate the real ``(imgs, labels)`` batch from both
    generated images. Soft labels (real ~ U[0.9, 1.0], fake ~ U[0.0, 0.1])
    are used for stability.

    ``model`` is the network being audited. It stays in eval mode, and its
    parameters are frozen for the duration of the call (their original
    ``requires_grad`` flags are restored afterwards), so only ``G`` and ``D``
    are updated. Uses the notebook-wide ``USE_AMP`` flag for mixed precision.

    Args:
        G: generator exposing ``latent_dim``/``num_classes`` and returning ``(x, x')``.
        D: conditional discriminator returning logits of shape (batch, 1).
        model: classifier under audit.
        loader: iterable of ``(imgs, labels)`` real batches.
        epochs: passes over ``loader``.
        device: device for all three networks.
        l1: weight of the input-space closeness term.
        l2: weight of the adversarial realism term.

    Returns:
        The trained ``(G, D)``, moved to ``device``.
    """
    import contextlib  # local import keeps this notebook cell self-contained

    G, D, model = G.to(device), D.to(device), model.to(device)
    G.train()
    D.train()
    model.eval()

    # The first generator term must back-propagate *through* the audited model
    # to reach G. It used to be computed under torch.no_grad(), which made it
    # a constant: G never learned to expose prediction differences. Freeze the
    # model's own parameters instead so it is left untouched.
    model_params = list(model.parameters())
    saved_requires_grad = [p.requires_grad for p in model_params]
    for p in model_params:
        p.requires_grad_(False)

    opt_G = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
    opt_D = optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
    bce = nn.BCEWithLogitsLoss()

    scaler_G = torch.amp.GradScaler(device='cuda') if USE_AMP else None
    scaler_D = torch.amp.GradScaler(device='cuda') if USE_AMP else None

    def _autocast():
        # Mixed precision only when the notebook-wide AMP flag is on.
        return torch.amp.autocast(device_type='cuda') if USE_AMP else contextlib.nullcontext()

    def _step(loss, opt, scaler):
        # Backward + optimizer step, routed through the GradScaler under AMP.
        if scaler is None:
            loss.backward()
            opt.step()
        else:
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()

    try:
        for _ in range(epochs):
            for imgs, labels in loader:
                bs = imgs.size(0)

                # Soft labels for stability
                real_labels = torch.empty(bs, 1, device=device).uniform_(0.9, 1.0)
                fake_labels = torch.empty(bs, 1, device=device).uniform_(0.0, 0.1)

                imgs = imgs.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)
                z = torch.randn(bs, G.latent_dim, device=device)
                gl = torch.randint(0, G.num_classes, (bs,), device=device)

                # ---- Generator step ----
                opt_G.zero_grad(set_to_none=True)
                with _autocast():
                    x, xp = G(z, gl)
                    px, pxp = model(x), model(xp)
                    t1 = -torch.mean((px - pxp) ** 2)
                    t2 = l1 * torch.mean((x - xp) ** 2)
                    g_real = torch.empty(bs, 1, device=device).uniform_(0.9, 1.0)
                    t3 = l2 * (bce(D(x, gl), g_real) + bce(D(xp, gl), g_real)) / 2
                    g_loss = t1 + t2 + t3
                _step(g_loss, opt_G, scaler_G)

                # ---- Discriminator step ----
                # (zero_grad also discards the D gradients left by the G step.)
                opt_D.zero_grad(set_to_none=True)
                with _autocast():
                    # Fresh fakes from the just-updated G. They are detached
                    # from G anyway, so no autograd graph is built for them.
                    with torch.no_grad():
                        x, xp = G(z, gl)
                    d_loss = (bce(D(imgs, labels), real_labels)
                              + bce(D(x, gl), fake_labels)
                              + bce(D(xp, gl), fake_labels)) / 3
                _step(d_loss, opt_D, scaler_D)
    finally:
        for p, flag in zip(model_params, saved_requires_grad):
            p.requires_grad_(flag)

    return G, D


@torch.no_grad()
def compute_bias(model, x, xp, device):
    """Measure how differently ``model`` scores the paired batches ``x`` and ``x'``.

    Both batches are moved to ``device`` and passed through ``model`` in eval
    mode without gradients, under CUDA autocast when ``USE_AMP`` is set. The
    absolute output difference is summed over dim 1, averaged over the batch,
    and returned as a Python float.
    """
    model.eval()
    first, second = x.to(device), xp.to(device)

    def _mean_l1_gap():
        return (model(first) - model(second)).abs().sum(1).mean()

    if not USE_AMP:
        gap = _mean_l1_gap()
    else:
        with torch.amp.autocast(device_type='cuda'):
            gap = _mean_l1_gap()
    return gap.item()


def partition_data_pathological_non_iid(dataset, n_clients, classes_per_client=2):
    """Pathological Non-IID with unequal sample sizes.

    Client ``cid`` is assigned the ``classes_per_client`` consecutive classes
    starting at ``(cid * classes_per_client) % n_classes``, wrapping around.
    Client sizes are skewed by a Dirichlet(0.5) draw: each class's samples
    are split among the clients holding that class, in proportion to those
    clients' Dirichlet weights. The last holder takes the remainder.

    Args:
        dataset: indexable dataset of ``(input, label)`` items. If it exposes
            a ``targets`` attribute (as torchvision's CIFAR10 does), labels
            are read from it directly instead of loading every item.
        n_clients: number of clients to create.
        classes_per_client: number of distinct classes given to each client.

    Returns:
        ``(client_indices, client_classes)``: one shuffled index array per
        client, and the list of class ids assigned to each client.

    Draws from the global NumPy RNG; seed ``np.random`` for reproducibility.
    """
    # Prefer the ``targets`` attribute: indexing ``dataset[i]`` loads and
    # transforms every image (including random augmentation) just to read
    # its label, which is very slow for 50k CIFAR-10 images.
    targets = getattr(dataset, 'targets', None)
    if targets is not None:
        labels = np.asarray(targets)
    else:
        labels = np.array([dataset[i][1] for i in range(len(dataset))])
    n_classes = len(np.unique(labels))

    class_indices = {c: np.where(labels == c)[0] for c in range(n_classes)}
    for c in class_indices:
        np.random.shuffle(class_indices[c])

    client_classes = [
        [(cid * classes_per_client + i) % n_classes for i in range(classes_per_client)]
        for cid in range(n_clients)
    ]

    # Unequal sample sizes using Dirichlet
    alpha = 0.5
    client_proportions = np.random.dirichlet([alpha] * n_clients)

    client_indices = [[] for _ in range(n_clients)]

    for c in range(n_classes):
        clients_with_class = [cid for cid in range(n_clients) if c in client_classes[cid]]
        if not clients_with_class:
            continue

        class_samples = class_indices[c]
        total_for_class = len(class_samples)

        # Re-normalise the Dirichlet weights over just the holders of class c.
        relevant_props = np.array([client_proportions[cid] for cid in clients_with_class])
        relevant_props = relevant_props / relevant_props.sum()

        start_idx = 0
        for i, cid in enumerate(clients_with_class):
            if i == len(clients_with_class) - 1:
                end_idx = total_for_class  # last holder absorbs rounding remainder
            else:
                n_samples = int(total_for_class * relevant_props[i])
                end_idx = min(start_idx + n_samples, total_for_class)

            if start_idx < end_idx:
                client_indices[cid].extend(class_samples[start_idx:end_idx].tolist())
            start_idx = end_idx

    result = []
    for cid in range(n_clients):
        if client_indices[cid]:
            indices = np.array(client_indices[cid])
            np.random.shuffle(indices)
            result.append(indices)
        else:
            # NOTE(review): a client whose Dirichlet share rounded to zero gets
            # 50 random samples from ANY class (possibly overlapping other
            # clients); client_classes is not updated to reflect this.
            fallback_samples = np.random.choice(len(dataset), size=50, replace=False)
            result.append(fallback_samples)

    return result, client_classes


@torch.no_grad()
def evaluate(model, loader, device):
    """Return the top-1 accuracy of ``model`` over ``loader``, in percent.

    ``loader`` yields ``(inputs, targets)`` batches. The model is put in eval
    mode, and CUDA autocast is used when the global ``USE_AMP`` flag is set.
    An empty loader yields 0.0 instead of raising ZeroDivisionError, which
    mirrors how the per-client evaluator treats empty client loaders.
    """
    model.eval()
    correct, total = 0, 0

    for d, t in loader:
        d, t = d.to(device, non_blocking=True), t.to(device, non_blocking=True)

        if USE_AMP:
            with torch.amp.autocast(device_type='cuda'):
                preds = model(d).argmax(1)
        else:
            preds = model(d).argmax(1)

        correct += (preds == t).sum().item()
        total += len(t)

    if total == 0:
        return 0.0
    return 100 * correct / total


@torch.no_grad()
def evaluate_per_client(model, client_loaders, device):
    """Return one top-1 accuracy (in percent) per loader in ``client_loaders``.

    The model is put in eval mode, and CUDA autocast is used when
    ``USE_AMP`` is set. A loader that yields no samples scores 0.
    """
    model.eval()

    def _predict(inputs):
        if USE_AMP:
            with torch.amp.autocast(device_type='cuda'):
                return model(inputs).argmax(1)
        return model(inputs).argmax(1)

    accuracies = []
    for loader in client_loaders:
        n_correct = n_seen = 0
        for inputs, targets in loader:
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            n_correct += (_predict(inputs) == targets).sum().item()
            n_seen += len(targets)
        accuracies.append(100 * n_correct / n_seen if n_seen > 0 else 0)

    return accuracies


# ============================================================
# FAIRNESS METRICS
# ============================================================

def calculate_jfi(performances):
    """Jain's Fairness Index: (sum x)^2 / (n * sum x^2).

    For non-negative values it lies in [1/n, 1]; 1 means perfectly equal.
    An all-zero input is treated as perfectly fair and returns 1.0.
    """
    values = np.array(performances)
    sum_of_squares = (values ** 2).sum()
    if sum_of_squares == 0:
        return 1.0
    return values.sum() ** 2 / (len(values) * sum_of_squares)

def calculate_accuracy_gap(performances):
    """Return the spread between the best and worst value (max - min)."""
    values = np.asarray(performances)
    return values.max() - values.min()

def calculate_variance(performances):
    """Return the population variance (ddof=0) of the values."""
    return np.asarray(performances).var()

def calculate_max_min_fairness(performances):
    """Return worst / best (min / max); 1.0 means all values are equal.

    Returns 0.0 when the best value is zero.
    """
    values = np.array(performances)
    best = values.max()
    return 0.0 if best == 0 else values.min() / best


print("‚úÖ Helper functions defined")

In [None]:
# ============================================================
# CONFIGURATION (V7 - FedProx + GAN Every Round)
# ============================================================
# Global hyperparameters read by the data-loading and training cells below.

# Training Parameters
N_ROUNDS = 50        # federated communication rounds per experiment
N_CLIENTS = 20       # number of simulated clients
N_GAN_EPOCHS = 15    # presumably the GAN epochs per round (consumer not visible here) — TODO confirm
N_PROBES = 500       # presumably the number of GAN probe pairs per audit — TODO confirm
LOCAL_EPOCHS = 3     # local passes over each client's shard per round

# Warm-up (aggregation only - GAN still trains!)
WARMUP_ROUNDS = 10

# GAN trains EVERY round (no audit frequency limit)
# After warm-up, audit every round

# Fed-Audit-GAN Parameters
MOMENTUM = 0.8  # EMA momentum for fairness scores

# ‚≠ê V7 CHANGE: FedProx ADDED BACK
MU = 0.01  # FedProx proximal term coefficient

# ‚ùå V7: Gradient Clipping REMOVED
# (v6 used GRAD_CLIP_NORM = 1.0; intentionally not defined in v7.)

# Test multiple gamma values (one Fed-Audit-GAN run per value)
GAMMA_VALUES = [0.3, 0.7]

# DataLoader Parameters
BATCH_SIZE = 128      # client and validation loaders
VAL_BATCH_SIZE = 256  # test loader
NUM_WORKERS = 4       # worker processes per DataLoader (one loader per client + test + val)
PIN_MEMORY = True
PREFETCH_FACTOR = 2   # batches prefetched per worker

# PATHOLOGICAL NON-IID: number of distinct classes given to each client
CLASSES_PER_CLIENT = 2

# Print a summary banner of the configuration.
print("=" * 70)
print("üîß Fed-Audit-GAN v7.0 - CIFAR-10 (FedProx + GAN Every Round)")
print("=" * 70)
print(f"Device: {DEVICE}")
print(f"GPUs: {NUM_GPUS}")
print(f"AMP Enabled: {USE_AMP}")
print(f"Rounds: {N_ROUNDS}, Clients: {N_CLIENTS}")

print(f"\nüéØ EXPERIMENTS TO RUN:")
print(f"   üîµ 1. FedAvg (BASELINE)")
for i, g in enumerate(GAMMA_VALUES, 2):  # numbering continues after FedAvg (#1)
    print(f"   üü¢ {i}. Fed-Audit-GAN Œ≥={g}")

print(f"\n‚≠ê V7 CHANGES FROM V6:")
print(f"   ‚úÖ ADDED BACK: FedProx proximal term (Œº={MU})")
print(f"   ‚ùå REMOVED: Gradient clipping")
print(f"   ‚úÖ KEPT: GAN trains EVERY round (including warm-up!)")
print(f"   ‚úÖ KEPT: Audit every round after round {WARMUP_ROUNDS}")
print(f"   ‚úÖ KEPT: V2 Linear aggregation formula")
print(f"   ‚úÖ KEPT: Soft labels for GAN")
print(f"   ‚úÖ KEPT: EMA momentum ({MOMENTUM})")

print(f"\n‚≠ê PATHOLOGICAL NON-IID:")
print(f"   Each client gets ONLY {CLASSES_PER_CLIENT}/10 classes")
print("=" * 70)

In [None]:
# ============================================================
# DATA LOADING
# ============================================================

CIFAR10_CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']

# Training-time augmentation (random flip + padded random crop), then
# per-channel normalisation.
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
])

train_data = datasets.CIFAR10('./data', train=True, download=True, transform=transform_train)
test_data = datasets.CIFAR10('./data', train=False, download=True, transform=transform_test)

# Create PATHOLOGICAL Non-IID partitions (NumPy seeded for a reproducible split)
np.random.seed(42)
client_idx, client_classes = partition_data_pathological_non_iid(
    train_data, N_CLIENTS, classes_per_client=CLASSES_PER_CLIENT
)

# Aggregation weights: each client's fraction of all partitioned samples
client_data_sizes = [len(idx) for idx in client_idx]
total_samples = sum(client_data_sizes)
CLIENT_DATA_WEIGHTS = [size / total_samples for size in client_data_sizes]

# DataLoader kwargs. persistent_workers and prefetch_factor are only valid
# with worker processes (DataLoader raises ValueError for them when
# num_workers == 0), so they are added only when NUM_WORKERS > 0.
dataloader_kwargs = {
    'num_workers': NUM_WORKERS,
    'pin_memory': PIN_MEMORY,
}
if NUM_WORKERS > 0:
    # NOTE(review): with persistent workers every loader below (N_CLIENTS
    # client loaders + test + val) keeps NUM_WORKERS processes alive for as
    # long as the loader exists; lower NUM_WORKERS if the host runs short of
    # RAM / shared memory.
    dataloader_kwargs['persistent_workers'] = True
    dataloader_kwargs['prefetch_factor'] = PREFETCH_FACTOR

test_loader = DataLoader(test_data, batch_size=VAL_BATCH_SIZE, shuffle=False, **dataloader_kwargs)
# NOTE(review): this validation subset is drawn from train_data, so it applies
# the training augmentation and overlaps the client shards.
val_loader = DataLoader(
    Subset(train_data, np.random.choice(len(train_data), 2000, replace=False)),
    batch_size=BATCH_SIZE, shuffle=False, **dataloader_kwargs
)

client_loaders = [
    DataLoader(Subset(train_data, client_idx[c]), batch_size=BATCH_SIZE, shuffle=True, **dataloader_kwargs)
    for c in range(N_CLIENTS)
]

print(f"‚úÖ Data loaded!")
print(f"   Training samples: {len(train_data)}")
print(f"   Test samples: {len(test_data)}")
print(f"\n" + "=" * 70)
print(f"‚≠ê PATHOLOGICAL NON-IID ({CLASSES_PER_CLIENT} classes per client)")
print("=" * 70)
for i in range(min(10, N_CLIENTS)):
    class_names = [CIFAR10_CLASSES[c] for c in client_classes[i]]
    print(f"   Client {i:2d}: {client_classes[i]} ‚Üí {class_names} ({client_data_sizes[i]} samples)")

if N_CLIENTS > 10:
    print(f"   ... ({N_CLIENTS - 10} more clients)")

In [None]:
# ============================================================
# TRAINING FUNCTIONS (V7 - FedProx + GAN Every Round)
# ============================================================

def run_fedavg(n_rounds, n_clients, train_data, client_idx, val_loader, test_loader, 
               client_loaders, local_epochs, device, use_amp, client_data_weights):
    """
    Standard FedAvg baseline (McMahan et al., 2017).

    Each round, every client copies the global CNN and runs ``local_epochs``
    of SGD (lr=0.01, momentum 0.9, weight decay 1e-4) on its own loader.
    The resulting state dicts are averaged with ``client_data_weights``.
    After aggregation the global model is scored on ``test_loader`` and on
    every client loader, and fairness metrics are logged to the history
    and to W&B.

    ``train_data``, ``client_idx`` and ``val_loader`` are accepted for
    signature parity with the other runners but are not used here.

    Returns:
        ``(model, history)``: the final global model and a dict of per-round
        metric lists.
    """

    print("\n" + "=" * 70)
    print("üîµ RUNNING: FedAvg (Federated Averaging) - BASELINE")
    print("=" * 70)

    model = CNN().to(device)
    scaler = torch.amp.GradScaler(device='cuda') if use_amp else None

    metric_keys = ('acc', 'client_accuracies', 'jfi', 'max_min_fairness',
                   'variance', 'accuracy_gap', 'min_client_acc', 'max_client_acc')
    history = {key: [] for key in metric_keys}

    def train_one_client(cid):
        """Train a copy of the current global model on client ``cid``; return its state dict."""
        local_model = copy.deepcopy(model)
        local_model.train()
        optimizer = optim.SGD(local_model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
        criterion = nn.CrossEntropyLoss()

        for _ in range(local_epochs):
            for data, target in client_loaders[cid]:
                data = data.to(device, non_blocking=True)
                target = target.to(device, non_blocking=True)
                optimizer.zero_grad(set_to_none=True)

                if not use_amp:
                    loss = criterion(local_model(data), target)
                    loss.backward()
                    optimizer.step()
                    continue

                with torch.amp.autocast(device_type='cuda'):
                    loss = criterion(local_model(data), target)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

        state = copy.deepcopy(local_model.state_dict())
        del local_model
        torch.cuda.empty_cache()
        return state

    def weighted_average(states):
        """Data-weighted FedAvg: sum_i client_data_weights[i] * states[i], key by key."""
        averaged = copy.deepcopy(states[0])
        for key in averaged.keys():
            running = averaged[key] * client_data_weights[0]
            for i in range(1, len(states)):
                running += states[i][key] * client_data_weights[i]
            averaged[key] = running
        return averaged

    for rnd in tqdm(range(n_rounds), desc="FedAvg"):
        local_weights = [train_one_client(cid) for cid in range(n_clients)]
        model.load_state_dict(weighted_average(local_weights))

        # Evaluation: global test accuracy plus per-client fairness metrics.
        acc = evaluate(model, test_loader, device)
        client_accs = evaluate_per_client(model, client_loaders, device)
        worst, best = min(client_accs), max(client_accs)

        round_metrics = {
            'acc': acc,
            'client_accuracies': client_accs.copy(),
            'jfi': calculate_jfi(client_accs),
            'max_min_fairness': calculate_max_min_fairness(client_accs),
            'variance': calculate_variance(client_accs),
            'accuracy_gap': calculate_accuracy_gap(client_accs),
            'min_client_acc': worst,
            'max_client_acc': best,
        }
        for key, value in round_metrics.items():
            history[key].append(value)

        wandb.log({
            'round': rnd + 1,
            'accuracy': acc,
            'jfi': round_metrics['jfi'],
            'max_min_fairness': round_metrics['max_min_fairness'],
            'fairness_variance': round_metrics['variance'],
            'accuracy_gap': round_metrics['accuracy_gap'],
            'min_client_acc': worst,
            'max_client_acc': best
        })

    return model, history


def _fedprox_penalty(local_model, global_state, mu):
    """Return the FedProx proximal term ``(mu / 2) * sum ||w - w_global||^2``.

    Args:
        local_model: the client's copy of the model being trained.
        global_state: mapping name -> tensor snapshot of the global model taken
            at the start of the round, already on the training device.
        mu: proximal coefficient; ``mu == 0`` makes the term vanish.

    Only entries of ``local_model.named_parameters()`` whose name is present in
    ``global_state`` contribute, so buffers are never penalized.
    """
    prox_term = 0.0
    for name, param in local_model.named_parameters():
        if name in global_state:
            prox_term = prox_term + ((param - global_state[name]) ** 2).sum()
    return (mu / 2) * prox_term


def _local_train_fedprox(model, loader, global_state, mu, local_epochs,
                         device, use_amp, scaler):
    """Train a copy of ``model`` on one client's loader and return its update.

    Local objective: cross-entropy + FedProx penalty, optimized with SGD
    (lr=0.01, momentum=0.9, weight_decay=1e-4) and NO gradient clipping.

    Returns:
        dict mapping every state-dict key to ``local_value - global_value``
        (parameters and buffers alike).
    """
    local_model = copy.deepcopy(model)
    local_model.train()
    optimizer = optim.SGD(local_model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)

    for _ in range(local_epochs):
        for data, target in loader:
            data = data.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            optimizer.zero_grad(set_to_none=True)

            if use_amp:
                with torch.amp.autocast(device_type='cuda'):
                    output = local_model(data)
                    loss = F.cross_entropy(output, target)
                    loss = loss + _fedprox_penalty(local_model, global_state, mu)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                output = local_model(data)
                loss = F.cross_entropy(output, target)
                loss = loss + _fedprox_penalty(local_model, global_state, mu)
                loss.backward()
                optimizer.step()

    # Build the state dict once instead of once per key.
    local_state = local_model.state_dict()
    update = {k: local_state[k] - global_state[k] for k in global_state}
    del local_model
    return update


def _generate_probes(G, n_probes, device, use_amp):
    """Sample ``n_probes`` (x, x') probe pairs from the trained generator.

    Labels are drawn uniformly from the 10 CIFAR-10 classes. The returned
    tensors are cast to float32.
    """
    G.eval()
    with torch.no_grad():
        z = torch.randn(n_probes, G.latent_dim, device=device)
        labels = torch.randint(0, 10, (n_probes,), device=device)

        if use_amp:
            with torch.amp.autocast(device_type='cuda'):
                x_probe, xp_probe = G(z, labels)
        else:
            x_probe, xp_probe = G(z, labels)
    # NOTE(review): under autocast the generator output may be float16, while
    # the audited models keep float32 weights and compute_bias (defined
    # elsewhere) is called outside this autocast region. The up-cast is
    # lossless, so it is harmless if compute_bias runs its own autocast and
    # prevents a Half/Float dtype mismatch if it does not.
    return x_probe.float(), xp_probe.float()


def _minmax_normalize(values):
    """Min-max scale a 1-D tensor to [0, 1]; a constant tensor maps to 0.5."""
    v_min, v_max = values.min(), values.max()
    if v_max != v_min:
        return (values - v_min) / (v_max - v_min)
    return torch.ones_like(values) * 0.5


def _v2_linear_alphas(fairness_scores, data_weights, gamma, device):
    """Compute V2 Linear aggregation weights.

    ``raw_i = (1 - gamma) * acc_norm_i + gamma * fair_norm_i + 1e-8``, where both
    terms are min-max normalized across clients and the client data weights
    stand in for the accuracy score. The raw weights are divided by their sum
    and returned as a Python list. The 1e-8 keeps the sum non-zero.
    """
    f_norm = _minmax_normalize(torch.tensor(fairness_scores, device=device, dtype=torch.float32))
    a_norm = _minmax_normalize(torch.tensor(data_weights, device=device, dtype=torch.float32))
    raw_weights = ((1 - gamma) * a_norm) + (gamma * f_norm) + 1e-8
    return (raw_weights / raw_weights.sum()).tolist()


def run_fed_audit_gan_v7(gamma, n_rounds, n_clients, warmup_rounds, momentum,
                         mu, train_data, client_idx, val_loader, 
                         test_loader, client_loaders, n_gan_epochs, n_probes, 
                         local_epochs, device, use_amp, client_data_weights):
    """Run Fed-Audit-GAN v7.0: FedProx local training plus a GAN audit every round.

    Each round:
      1. Every client trains a copy of the global model (cross-entropy + FedProx
         term with coefficient ``mu``, no gradient clipping) and returns its
         state-dict delta.
      2. A fresh FairnessGenerator/Discriminator pair is trained with
         ``train_gan`` against the current global model (every round, including
         warm-up), and ``n_probes`` probe pairs are sampled from it.
      3. Each client's score is ``B_base - B_client``: the bias of the global
         model minus the bias of the global model with that client's update
         applied, both measured by ``compute_bias`` on the probes. Scores are
         EMA-smoothed across rounds with ``momentum``.
      4. For ``rnd < warmup_rounds`` updates are averaged with
         ``client_data_weights``. After that the V2 Linear weights
         ``(1 - gamma) * acc_norm + gamma * fair_norm`` are used, with the data
         weights as the accuracy proxy.

    Args:
        gamma: fairness weight in the V2 Linear formula.
        n_rounds, n_clients: number of federated rounds / participating clients.
        warmup_rounds: rounds that aggregate with ``client_data_weights``.
        momentum: EMA coefficient for the smoothed fairness scores.
        mu: FedProx proximal coefficient.
        train_data, client_idx: accepted for call-site compatibility; unused here.
        val_loader: data used by ``train_gan``.
        test_loader: global evaluation set.
        client_loaders: per-client training loaders, indexed by client id.
        n_gan_epochs, n_probes: GAN training epochs / number of audit probes.
        local_epochs: local passes over each client's loader per round.
        device, use_amp: training device and whether to use CUDA autocast + GradScaler.
        client_data_weights: per-client data weights (FedAvg weights).

    Returns:
        (model, history): the final global model and a dict of per-round lists.
    """
    
    print(f"\n" + "=" * 70)
    print(f"üü¢ RUNNING: Fed-Audit-GAN v7.0 (Œ≥={gamma})")
    print(f"   ‚úÖ FedProx (Œº={mu})")
    print(f"   ‚úÖ GAN trains EVERY round (including warm-up!)")
    print(f"   ‚úÖ V2 Linear Formula: Weight = (1-Œ≥)√óAcc + Œ≥√óFair")
    print("=" * 70)
    
    model = CNN().to(device)
    scaler = torch.amp.GradScaler(device='cuda') if use_amp else None
    
    # EMA of each client's fairness score, carried across rounds.
    fairness_history = {i: 0.0 for i in range(n_clients)}
    
    history = {
        'acc': [], 'bias': [], 'alphas': [],
        'raw_scores': [], 'smoothed_scores': [],
        'client_accuracies': [],
        'jfi': [], 'max_min_fairness': [], 'variance': [], 'accuracy_gap': [],
        'min_client_acc': [], 'max_client_acc': [],
        'gan_trained': []
    }
    
    for rnd in tqdm(range(n_rounds), desc=f"Fed-Audit-GAN Œ≥={gamma}"):
        
        # ================================================================
        # PHASE 1: Local Client Training (with FedProx)
        # ================================================================
        # One snapshot per round serves both as the FedProx anchor and as the
        # reference each client's update is measured against. The global model
        # is not modified until aggregation, so all clients share it.
        global_state = {k: v.detach().clone().to(device) for k, v in model.state_dict().items()}
        updates = []
        for cid in range(n_clients):
            updates.append(_local_train_fedprox(
                model, client_loaders[cid], global_state, mu,
                local_epochs, device, use_amp, scaler,
            ))
            torch.cuda.empty_cache()
        
        # ================================================================
        # PHASE 2: GAN Training (every round, including warm-up)
        # ================================================================
        G = FairnessGenerator(img_shape=(3, 32, 32)).to(device)
        D = Discriminator(img_shape=(3, 32, 32)).to(device)
        G, D = train_gan(G, D, model, val_loader, epochs=n_gan_epochs, device=device)
        x_probe, xp_probe = _generate_probes(G, n_probes, device, use_amp)
        
        # ================================================================
        # PHASE 3: Fairness Scoring
        # ================================================================
        B_base = compute_bias(model, x_probe, xp_probe, device)
        
        S_fair_raw = []
        S_fair_smoothed = []
        for cid, upd in enumerate(updates):
            # Hypothetical model = global model + this client's update alone.
            hyp_model = copy.deepcopy(model)
            hyp_state = hyp_model.state_dict()
            for k in hyp_state:
                hyp_state[k] = hyp_state[k] + upd[k]
            hyp_model.load_state_dict(hyp_state)
            
            # Positive score means the update reduced the measured bias.
            B_client = compute_bias(hyp_model, x_probe, xp_probe, device)
            S_current = B_base - B_client
            S_fair_raw.append(S_current)
            
            # EMA smoothing
            S_smoothed = (momentum * fairness_history[cid]) + ((1 - momentum) * S_current)
            fairness_history[cid] = S_smoothed
            S_fair_smoothed.append(S_smoothed)
            del hyp_model
        
        del G, D, x_probe, xp_probe
        torch.cuda.empty_cache()
        
        history['raw_scores'].append(S_fair_raw.copy())
        history['smoothed_scores'].append(S_fair_smoothed.copy())
        history['gan_trained'].append(True)  # GAN trains every round
        
        # ================================================================
        # PHASE 4: AGGREGATION
        # Warm-up: FedAvg data weights (scores still accumulate in the EMA).
        # After warm-up: V2 Linear weights with gamma.
        # ================================================================
        if rnd < warmup_rounds:
            alphas = client_data_weights.copy()
        else:
            alphas = _v2_linear_alphas(S_fair_smoothed, client_data_weights, gamma, device)
            
            # Debug output every 10 rounds
            if rnd % 10 == 0:
                print(f"\n   üìä Round {rnd+1} Debug (Œ≥={gamma}):")
                print(f"      Fairness: min={min(S_fair_smoothed):.4f}, max={max(S_fair_smoothed):.4f}")
                print(f"      Alphas: min={min(alphas):.4f}, max={max(alphas):.4f}")
        
        # Apply weighted aggregation: global += sum_i alpha_i * update_i
        new_state = model.state_dict()
        for k in new_state:
            new_state[k] = new_state[k] + sum(a * u[k] for a, u in zip(alphas, updates))
        model.load_state_dict(new_state)
        
        # ================================================================
        # EVALUATION
        # ================================================================
        acc = evaluate(model, test_loader, device)
        client_accs = evaluate_per_client(model, client_loaders, device)
        
        jfi = calculate_jfi(client_accs)
        max_min = calculate_max_min_fairness(client_accs)
        var = calculate_variance(client_accs)
        gap = calculate_accuracy_gap(client_accs)
        
        history['acc'].append(acc)
        history['bias'].append(B_base)
        history['alphas'].append(alphas.copy())
        history['client_accuracies'].append(client_accs.copy())
        history['jfi'].append(jfi)
        history['max_min_fairness'].append(max_min)
        history['variance'].append(var)
        history['accuracy_gap'].append(gap)
        history['min_client_acc'].append(min(client_accs))
        history['max_client_acc'].append(max(client_accs))
        
        wandb.log({
            'round': rnd + 1,
            'accuracy': acc,
            'bias': B_base,
            'jfi': jfi,
            'max_min_fairness': max_min,
            'fairness_variance': var,
            'accuracy_gap': gap,
            'min_client_acc': min(client_accs),
            'max_client_acc': max(client_accs),
            'gan_trained': 1,
            'alpha_min': min(alphas),
            'alpha_max': max(alphas)
        })
    
    return model, history


# Notebook progress marker: printed once the training-function definitions
# above have been executed.
print("‚úÖ Training functions defined (FedAvg + Fed-Audit-GAN v7.0)")

In [None]:
# ============================================================
# RUN ALL EXPERIMENTS
# ============================================================
# Each experiment runs inside ``with wandb.init(...)``. The returned Run is
# finished when the block exits, even if training raises, so a crash no longer
# leaves an active W&B run behind for the next experiment.

WANDB_PROJECT = "FED_AUDIT_GAN_TEST_7_CIFAR10"

all_results = {}

# ============================================================
# EXPERIMENT 1: FedAvg (BASELINE)
# ============================================================
with wandb.init(
    project=WANDB_PROJECT,
    name=f"FedAvg_CIFAR10_clients{N_CLIENTS}_v7",
    config={
        "method": "FedAvg",
        "dataset": "CIFAR-10",
        "n_rounds": N_ROUNDS,
        "n_clients": N_CLIENTS,
        "non_iid_type": "pathological",
        "classes_per_client": CLASSES_PER_CLIENT,
        "device": str(DEVICE),
        "num_gpus": NUM_GPUS,
        "amp_enabled": USE_AMP
    }
):
    fedavg_model, fedavg_history = run_fedavg(
        n_rounds=N_ROUNDS,
        n_clients=N_CLIENTS,
        train_data=train_data,
        client_idx=client_idx,
        val_loader=val_loader,
        test_loader=test_loader,
        client_loaders=client_loaders,
        local_epochs=LOCAL_EPOCHS,
        device=DEVICE,
        use_amp=USE_AMP,
        client_data_weights=CLIENT_DATA_WEIGHTS
    )

all_results['FedAvg'] = {
    'model': fedavg_model,
    'history': fedavg_history,
    'name': 'FedAvg'
}

print(f"‚úÖ FedAvg Complete!")
print(f"   Final Accuracy: {fedavg_history['acc'][-1]:.2f}%")
print(f"   Final JFI: {fedavg_history['jfi'][-1]:.4f}")


# ============================================================
# EXPERIMENTS 2-3: Fed-Audit-GAN v7.0, one run per gamma in GAMMA_VALUES
# ============================================================
for gamma in GAMMA_VALUES:
    method_name = f"FedAuditGAN_v7_Œ≥={gamma}"
    
    with wandb.init(
        project=WANDB_PROJECT,
        name=f"{method_name}_CIFAR10_clients{N_CLIENTS}_FedProx",
        config={
            "method": method_name,
            "dataset": "CIFAR-10",
            "n_rounds": N_ROUNDS,
            "n_clients": N_CLIENTS,
            "gamma": gamma,
            "momentum": MOMENTUM,
            "warmup_rounds": WARMUP_ROUNDS,
            "fedprox": True,
            "mu": MU,
            "gradient_clipping": False,
            "gan_every_round": True,
            "gan_during_warmup": True,
            "aggregation_method": "V2_LINEAR",
            "non_iid_type": "pathological",
            "classes_per_client": CLASSES_PER_CLIENT,
            "device": str(DEVICE),
            "num_gpus": NUM_GPUS,
            "amp_enabled": USE_AMP
        }
    ):
        model, history = run_fed_audit_gan_v7(
            gamma=gamma,
            n_rounds=N_ROUNDS,
            n_clients=N_CLIENTS,
            warmup_rounds=WARMUP_ROUNDS,
            momentum=MOMENTUM,
            mu=MU,
            train_data=train_data,
            client_idx=client_idx,
            val_loader=val_loader,
            test_loader=test_loader,
            client_loaders=client_loaders,
            n_gan_epochs=N_GAN_EPOCHS,
            n_probes=N_PROBES,
            local_epochs=LOCAL_EPOCHS,
            device=DEVICE,
            use_amp=USE_AMP,
            client_data_weights=CLIENT_DATA_WEIGHTS
        )
    
    all_results[method_name] = {
        'model': model,
        'history': history,
        'name': method_name
    }
    
    print(f"‚úÖ {method_name} Complete!")
    print(f"   Final Accuracy: {history['acc'][-1]:.2f}%")
    print(f"   Final JFI: {history['jfi'][-1]:.4f}")
    print(f"   Accuracy Gap: {history['accuracy_gap'][-1]:.2f}%")

print("\n" + "=" * 70)
print("‚úÖ ALL EXPERIMENTS COMPLETE!")
print("=" * 70)

In [None]:
# ============================================================
# üìä RESULTS SUMMARY TABLE
# ============================================================
# Final-round comparison of every run in ``all_results``:
#   * a table marking the best accuracy, best JFI and lowest accuracy gap,
#   * a gamma-sensitivity check between the first two GAMMA_VALUES,
#   * each Fed-Audit-GAN run's change relative to the FedAvg baseline.


def _last(method, metric):
    """Return the final-round value of ``metric`` recorded for ``method``."""
    return all_results[method]['history'][metric][-1]


def _audit_key(g):
    """Return the ``all_results`` key of the Fed-Audit-GAN run for gamma ``g``."""
    return f'FedAuditGAN_v7_Œ≥={g}'


print("\n" + "=" * 120)
print("üìä CIFAR-10: FedAvg vs Fed-Audit-GAN v7.0 (FedProx + GAN Every Round)")
print("=" * 120)

print(f"\n{'METHOD':<35} {'GLOBAL ACC':<12} {'JFI':<10} {'MAX-MIN':<10} {'GAP':<10} {'MIN ACC':<10} {'MAX ACC':<10}")
print("-" * 120)

method_names = [*all_results]

best_acc = max(_last(m, 'acc') for m in method_names)
best_jfi = max(_last(m, 'jfi') for m in method_names)
lowest_gap = min(_last(m, 'accuracy_gap') for m in method_names)

for method in method_names:
    name = all_results[method]['name']
    acc = _last(method, 'acc')
    jfi = _last(method, 'jfi')
    gap = _last(method, 'accuracy_gap')

    acc_mark = "üèÜ" if acc == best_acc else ""
    jfi_mark = "‚≠ê" if jfi == best_jfi else ""
    gap_mark = "‚úÖ" if gap == lowest_gap else ""

    print(
        f"{name:<35} {acc:>8.2f}% {acc_mark:<2} {jfi:>8.4f} {jfi_mark:<2} "
        f"{_last(method, 'max_min_fairness'):>8.4f}   {gap:>6.2f}% {gap_mark:<2} "
        f"{_last(method, 'min_client_acc'):>8.2f}%  {_last(method, 'max_client_acc'):>8.2f}%"
    )

print("=" * 120)

# Gamma sensitivity: do the first two gamma values actually diverge?
print(f"\n‚≠ê GAMMA SENSITIVITY CHECK:")
if len(GAMMA_VALUES) >= 2:
    g1, g2 = GAMMA_VALUES[0], GAMMA_VALUES[1]
    acc1, acc2 = (_last(_audit_key(g), 'acc') for g in (g1, g2))
    jfi1, jfi2 = (_last(_audit_key(g), 'jfi') for g in (g1, g2))
    acc_diff = abs(acc1 - acc2)
    jfi_diff = abs(jfi1 - jfi2)

    print(f"   Œ≥={g1} vs Œ≥={g2}:")
    print(f"      Accuracy difference: {acc_diff:.2f}%")
    print(f"      JFI difference: {jfi_diff:.4f}")

    diverged = acc_diff > 0.1 or jfi_diff > 0.001
    if not diverged:
        print(f"   ‚ö†Ô∏è Results still similar")
    else:
        print(f"   ‚úÖ SUCCESS! Different gamma values produce DIFFERENT results!")

# Improvement over FedAvg
fedavg_acc = _last('FedAvg', 'acc')
fedavg_jfi = _last('FedAvg', 'jfi')
fedavg_gap = _last('FedAvg', 'accuracy_gap')

print(f"\nüìà IMPROVEMENT OVER FedAvg:")
for method in (m for m in method_names if m != 'FedAvg'):
    name = all_results[method]['name']
    acc = _last(method, 'acc')
    jfi = _last(method, 'jfi')
    gap = _last(method, 'accuracy_gap')

    acc_sign = '+' if acc >= fedavg_acc else ''
    jfi_sign = '+' if jfi >= fedavg_jfi else ''
    print(f"   {name}:")
    print(f"      Accuracy: {acc_sign}{acc - fedavg_acc:.2f}%")
    print(f"      JFI: {jfi_sign}{jfi - fedavg_jfi:.4f}")
    print(f"      Gap Reduction: {fedavg_gap - gap:.2f}%")

for feature_line in (
    "\n" + "=" * 70,
    "‚≠ê V7 FEATURES:",
    f"   ‚úÖ ADDED BACK: FedProx proximal term (Œº={MU})",
    f"   ‚ùå REMOVED: Gradient clipping",
    f"   ‚úÖ KEPT: GAN trains EVERY round (including warm-up)",
    f"   ‚úÖ KEPT: Audit every round after round {WARMUP_ROUNDS}",
    f"   ‚úÖ KEPT: V2 Linear aggregation formula",
    "=" * 70,
):
    print(feature_line)

In [None]:
# ============================================================
# üìä VISUALIZATION
# ============================================================
# 2x3 grid for every run: global accuracy, JFI, best-worst gap, GAN-measured
# bias, min-max client accuracy range, and final per-client accuracy.
# Each series is plotted against its own length, so a run that stopped early
# still plots. The warm-up band is drawn only when WARMUP_ROUNDS > 0.

colors = {
    'FedAvg': '#e74c3c',
    'FedAuditGAN_v7_Œ≥=0.3': '#3498db',
    'FedAuditGAN_v7_Œ≥=0.7': '#2ecc71',
}
FALLBACK_COLOR = '#95a5a6'

method_names = list(all_results.keys())


def _color(method):
    """Return the plot color for ``method`` (grey if it has no fixed color)."""
    return colors.get(method, FALLBACK_COLOR)


def _linestyle(method):
    """Return a dashed line for the FedAvg baseline and a solid one otherwise."""
    return '--' if method == 'FedAvg' else '-'


def _rounds_for(values):
    """Return 1-based round numbers matching a per-round series."""
    return range(1, len(values) + 1)


def _shade_warmup(ax, label):
    """Shade rounds 1..WARMUP_ROUNDS; no-op when there is no warm-up phase."""
    # Without the guard, axvspan(1, 0) would shade a bogus [0, 1] band.
    if WARMUP_ROUNDS > 0:
        ax.axvspan(1, WARMUP_ROUNDS, alpha=0.15, color='orange', label=label)


def _plot_metric(ax, metric, include_fedavg=True):
    """Draw one line per method for the per-round history entry ``metric``."""
    for method in method_names:
        if not include_fedavg and method == 'FedAvg':
            continue
        values = all_results[method]['history'][metric]
        ax.plot(_rounds_for(values), values, color=_color(method),
                linestyle=_linestyle(method), linewidth=2,
                label=all_results[method]['name'])


def _finish_axes(ax, xlabel, ylabel, title, **grid_kwargs):
    """Apply labels, title, legend and a light grid to ``ax``."""
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3, **grid_kwargs)


fig, axes = plt.subplots(2, 3, figsize=(18, 12))

# Plot 1: Global Accuracy
ax = axes[0, 0]
_plot_metric(ax, 'acc')
_shade_warmup(ax, 'Warm-up (GAN trains!)')
_finish_axes(ax, 'Round', 'Accuracy (%)', 'CIFAR-10: Global Test Accuracy')

# Plot 2: JFI
ax = axes[0, 1]
_plot_metric(ax, 'jfi')
_finish_axes(ax, 'Round', 'JFI', "Jain's Fairness Index")

# Plot 3: Accuracy Gap
ax = axes[0, 2]
_plot_metric(ax, 'accuracy_gap')
_finish_axes(ax, 'Round', 'Accuracy Gap (%)', 'Best-Worst Client Gap')

# Plot 4: Bias (GAN-based) -- FedAvg records no bias history.
ax = axes[1, 0]
_plot_metric(ax, 'bias', include_fedavg=False)
_shade_warmup(ax, 'Warm-up')
_finish_axes(ax, 'Round', 'Bias', 'GAN-Measured Bias (Every Round)')

# Plot 5: Min-Max Range
ax = axes[1, 1]
for method in method_names:
    hist = all_results[method]['history']
    min_acc = hist['min_client_acc']
    max_acc = hist['max_client_acc']
    method_rounds = _rounds_for(min_acc)
    ax.fill_between(method_rounds, min_acc, max_acc, color=_color(method), alpha=0.2)
    ax.plot(method_rounds, min_acc, color=_color(method), linestyle=_linestyle(method), linewidth=1.5)
    ax.plot(method_rounds, max_acc, color=_color(method), linestyle=_linestyle(method), linewidth=1.5,
            label=all_results[method]['name'])
_finish_axes(ax, 'Round', 'Accuracy (%)', 'Min-Max Client Accuracy Range')

# Plot 6: Final Per-Client Accuracy -- grouped bars centred on each client tick.
ax = axes[1, 2]
x = np.arange(N_CLIENTS)
n_methods = max(len(method_names), 1)
width = 0.8 / n_methods
for i, method in enumerate(method_names):
    offset = (i - (n_methods - 1) / 2) * width
    client_accs = all_results[method]['history']['client_accuracies'][-1]
    ax.bar(x + offset, client_accs, width, label=all_results[method]['name'],
           color=_color(method), alpha=0.8)
ax.set_xticks(x)
_finish_axes(ax, 'Client ID', 'Accuracy (%)', 'Per-Client Accuracy (Final Round)', axis='y')

plt.tight_layout()
plt.savefig('cifar10_v7_fedprox_results.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nüìÅ Results saved to: cifar10_v7_fedprox_results.png")

In [None]:
# ============================================================
# SAVE MODELS AND RESULTS
# ============================================================
# Writes one .pth checkpoint per run (weights, history, run metadata) and a
# pickled summary of every run's history and final accuracy / JFI.

import os
import pickle

RESULTS_DIR = 'results_cifar10_v7_fedprox'
os.makedirs(RESULTS_DIR, exist_ok=True)

method_names = list(all_results.keys())


def _run_config(method, name):
    """Build the metadata dict stored alongside a run's checkpoint.

    The FedAvg baseline is recorded as plain data-weighted averaging: no
    FedProx term, no GAN audit, FEDAVG aggregation. Fed-Audit-GAN runs
    additionally record momentum, warm-up rounds and the gamma parsed from
    their name (0.5 if the name carries no ``=``).
    """
    is_fedavg = method == 'FedAvg'
    config = {
        'n_rounds': N_ROUNDS,
        'n_clients': N_CLIENTS,
        'classes_per_client': CLASSES_PER_CLIENT,
        'device': str(DEVICE),
        'fedprox': not is_fedavg,
        'mu': 0.0 if is_fedavg else MU,
        'gradient_clipping': False,
        'gan_every_round': not is_fedavg,
        'gan_during_warmup': not is_fedavg,
        'aggregation_method': 'FEDAVG' if is_fedavg else 'V2_LINEAR',
    }
    if not is_fedavg:
        config['momentum'] = MOMENTUM
        config['warmup_rounds'] = WARMUP_ROUNDS
        config['gamma'] = float(name.split('=')[1]) if '=' in name else 0.5
    return config


for method in method_names:
    name = all_results[method]['name']
    # File-system-safe name, e.g. "...Œ≥=0.3" -> "...gamma0_3".
    safe_name = name.replace('=', '').replace('.', '_').replace('Œ≥', 'gamma')
    filename = f"{RESULTS_DIR}/{safe_name}_CIFAR10.pth"
    
    save_dict = {
        'model_state_dict': all_results[method]['model'].state_dict(),
        'history': all_results[method]['history'],
        'config': _run_config(method, name),
    }
    
    torch.save(save_dict, filename)
    print(f"‚úÖ Saved: {filename}")

# Build the summary before opening the file, so a failure cannot leave an
# empty pickle on disk.
summary = {
    method: {
        'name': all_results[method]['name'],
        'history': all_results[method]['history'],
        'final_acc': all_results[method]['history']['acc'][-1],
        'final_jfi': all_results[method]['history']['jfi'][-1]
    }
    for method in method_names
}
summary_path = f"{RESULTS_DIR}/all_results_summary.pkl"
with open(summary_path, 'wb') as f:
    pickle.dump(summary, f)
print(f"‚úÖ Saved: {summary_path}")

print(f"\n" + "=" * 70)
print("üìä FINAL SUMMARY (Fed-Audit-GAN v7.0 - FedProx)")
print("=" * 70)
print("üîµ BASELINE:")
print(f"   FedAvg: {all_results['FedAvg']['history']['acc'][-1]:.2f}% accuracy, JFI={all_results['FedAvg']['history']['jfi'][-1]:.4f}")
print("\nüü¢ FED-AUDIT-GAN v7.0:")
for method in method_names:
    if method == 'FedAvg':
        continue
    name = all_results[method]['name']
    acc = all_results[method]['history']['acc'][-1]
    jfi = all_results[method]['history']['jfi'][-1]
    gap = all_results[method]['history']['accuracy_gap'][-1]
    print(f"   {name}: {acc:.2f}% accuracy, JFI={jfi:.4f}, Gap={gap:.2f}%")

print("\n‚≠ê V7 FEATURES:")
print(f"   ‚úÖ FedProx proximal term (Œº={MU})")
print(f"   ‚ùå No gradient clipping")
print(f"   ‚úÖ GAN trains EVERY round (including warm-up)")
print(f"   ‚úÖ V2 Linear: Weight = (1-Œ≥)√óAcc + Œ≥√óFair")
print("=" * 70)
print(f"\nüìä Check WandB: https://wandb.ai")
print(f"   Project: FED_AUDIT_GAN_TEST_7_CIFAR10")