In [None]:
# Step 1: Install packages
!pip install -q torch torchvision tqdm matplotlib numpy wandb

import torch
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using: {DEVICE}")

In [None]:
# Step 2: Login to WandB
# wandb.login() returns a bool; report failure instead of always claiming success.
import wandb

logged_in = wandb.login()
print("WandB logged in!" if logged_in else "WandB login did not complete; runs will not sync.")

In [None]:
# Step 3: Clone repo (REPLACE YOUR_USERNAME)
!git clone https://github.com/99VICKY99/Fed-Audit-GAN.git
%cd Fed-Audit-GAN
!git checkout strict-4-phase
!git branch --show-current

In [None]:
# Step 4: Complete Training Script with WandB
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import numpy as np
import copy
from tqdm.notebook import tqdm
import os
import pickle
import wandb

class CNN(nn.Module):
    """Two-convolution classifier for single-channel images.

    The fixed 9216-wide ``fc1`` implies 28x28 inputs: two unpadded 3x3
    convolutions (28 -> 26 -> 24) and a 2x2 max-pool (24 -> 12) yield
    64 * 12 * 12 = 9216 features.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Feature extractor: 1 -> 32 -> 64 channels, 3x3 kernels, stride 1, no padding.
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        # Classifier head.
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return raw class logits (no softmax) for the batch ``x``."""
        h = F.relu(self.conv1(x))
        h = F.max_pool2d(F.relu(self.conv2(h)), 2)
        h = torch.flatten(self.dropout1(h), 1)
        h = self.dropout2(F.relu(self.fc1(h)))
        return self.fc2(h)

class FairnessGenerator(nn.Module):
    """Conditional generator emitting paired probes ``(x, x')`` for bias auditing.

    ``x`` is a label-conditioned image in [-1, 1] (Tanh output). ``x'`` is
    ``x`` plus a latent-dependent perturbation, scaled by ``delta_scale`` and
    clamped back to [-1, 1].

    Args:
        latent_dim: size of the noise vector ``z``.
        num_classes: number of conditioning labels.
        img_shape: ``(C, H, W)`` of generated images. H and W must be equal and
            divisible by 4, because two 2x upsamplings grow an ``H // 4`` square
            seed map back to full size.
        delta_scale: multiplier applied to the Tanh-bounded perturbation.

    Raises:
        ValueError: if ``img_shape`` cannot be produced by the upsampling stack.
    """

    def __init__(self, latent_dim=100, num_classes=10, img_shape=(1, 28, 28), delta_scale=0.1):
        super().__init__()
        # Without this check, x and delta end up with different shapes and
        # forward() fails later with an obscure broadcast error.
        if img_shape[1] != img_shape[2] or img_shape[1] % 4 != 0:
            raise ValueError(
                f"img_shape must be square with a side divisible by 4, got {tuple(img_shape)}")
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        self.img_shape = img_shape
        # Label embedding has the same width as z so the two can be concatenated.
        self.label_emb = nn.Embedding(num_classes, latent_dim)
        # Side length of the seed feature map; two 2x upsamplings restore the full size.
        self.init_size = img_shape[1] // 4
        self.l1 = nn.Linear(latent_dim * 2, 128 * self.init_size ** 2)
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128), nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.LeakyReLU(0.2),
            nn.Conv2d(64, img_shape[0], 3, 1, 1), nn.Tanh())
        # Perturbation branch: depends on z only (not the label). Tanh bounds it
        # to [-1, 1] before scaling by delta_scale.
        self.delta_net = nn.Sequential(
            nn.Linear(latent_dim, 256), nn.ReLU(),
            nn.Linear(256, int(np.prod(img_shape))), nn.Tanh())
        self.delta_scale = delta_scale

    def forward(self, z, labels):
        """Generate a probe pair.

        Args:
            z: noise of shape ``(B, latent_dim)``.
            labels: integer class labels of shape ``(B,)``.

        Returns:
            ``(x, x')``, each of shape ``(B, *img_shape)`` and within [-1, 1].
        """
        gen_input = torch.cat([z, self.label_emb(labels)], dim=1)
        out = self.l1(gen_input).view(-1, 128, self.init_size, self.init_size)
        x = self.conv_blocks(out)
        delta = self.delta_net(z).view(-1, *self.img_shape) * self.delta_scale
        return x, torch.clamp(x + delta, -1, 1)

class Discriminator(nn.Module):
    """Conditional discriminator scoring ``(image, label)`` pairs as real (1) vs. generated (0).

    The label is embedded into a ``num_classes``-wide vector, broadcast to a
    ``num_classes``-channel map of the image's spatial size, and concatenated
    to the image along the channel axis.

    Args:
        num_classes: number of conditioning labels.
        img_shape: ``(C, H, W)`` of input images.
    """

    def __init__(self, num_classes=10, img_shape=(1, 28, 28)):
        super().__init__()
        self.num_classes = num_classes
        self.img_shape = img_shape
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.conv = nn.Sequential(
            nn.Conv2d(img_shape[0] + num_classes, 16, 3, 2, 1), nn.LeakyReLU(0.2),
            nn.Conv2d(16, 32, 3, 2, 1), nn.BatchNorm2d(32), nn.LeakyReLU(0.2),
            nn.Conv2d(32, 64, 3, 2, 1), nn.BatchNorm2d(64), nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 3, 2, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2))
        # Each 3x3 / stride-2 / pad-1 conv maps a side s to (s - 1) // 2 + 1.
        # Derive the flattened width from img_shape; a hard-coded 128 * 4 is
        # only right for 28x28 (28 -> 14 -> 7 -> 4 -> 2).
        h, w = img_shape[1], img_shape[2]
        for _ in range(4):
            h, w = (h - 1) // 2 + 1, (w - 1) // 2 + 1
        self.fc = nn.Sequential(nn.Linear(128 * h * w, 1), nn.Sigmoid())

    def forward(self, img, labels):
        """Return per-sample probabilities of shape ``(B, 1)`` that ``img`` is real."""
        label_map = self.label_emb(labels).view(-1, self.num_classes, 1, 1)
        label_map = label_map.expand(-1, -1, self.img_shape[1], self.img_shape[2])
        out = self.conv(torch.cat([img, label_map], dim=1))
        return self.fc(out.view(out.size(0), -1))

def train_gan(G, D, model, loader, epochs=30, device='cuda', l1=1.0, l2=1.0):
    """Train the fairness generator ``G`` adversarially against ``D`` while auditing ``model``.

    Each real batch triggers one generator step and then one discriminator step.
    The generator minimizes ``t1 + t2 + t3`` where:

        t1 = -mean((model(x) - model(x'))**2)  # push the audited model's outputs apart
        t2 = l1 * mean((x - x')**2)            # keep the perturbation small
        t3 = l2 * BCE(D(x), real) / BCE(D(x'), real) averaged  # look realistic

    Args:
        G: generator ``G(z, labels) -> (x, x')`` exposing ``latent_dim`` and ``num_classes``.
        D: discriminator ``D(img, labels)`` returning probabilities in [0, 1].
        model: audited classifier. Its weights are never updated: they are
            temporarily frozen (``requires_grad`` restored on exit) so gradients
            flow *through* it to ``G`` only.
        loader: iterable of ``(imgs, labels)`` real batches.
        epochs: passes over ``loader``.
        device: device to run on.
        l1: weight of the perturbation-size term.
        l2: weight of the realism term.

    Returns:
        ``(G, D)`` after training, on ``device``.
    """
    G, D, model = G.to(device), D.to(device), model.to(device)
    model.eval()
    # Do NOT evaluate model(x)/model(x') under torch.no_grad(): that detaches t1,
    # so the divergence objective contributes no gradient to G. Freeze the
    # model's weights instead, so its .grad stays untouched.
    frozen = [(p, p.requires_grad) for p in model.parameters()]
    for p, _ in frozen:
        p.requires_grad_(False)
    opt_G = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
    opt_D = optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
    bce = nn.BCELoss()
    try:
        for _ in range(epochs):
            for imgs, labels in loader:
                bs = imgs.size(0)
                real = torch.ones(bs, 1, device=device)
                fake = torch.zeros(bs, 1, device=device)
                imgs, labels = imgs.to(device), labels.to(device)

                # Generator step.
                z = torch.randn(bs, G.latent_dim, device=device)
                gl = torch.randint(0, G.num_classes, (bs,), device=device)
                x, xp = G(z, gl)
                px, pxp = model(x), model(xp)
                t1 = -torch.mean((px - pxp) ** 2)
                t2 = l1 * torch.mean((x - xp) ** 2)
                t3 = l2 * (bce(D(x, gl), real) + bce(D(xp, gl), real)) / 2
                opt_G.zero_grad()
                (t1 + t2 + t3).backward()
                opt_G.step()

                # Discriminator step on samples from the updated G. No graph is
                # needed through G here. The D grads left over from t3 are
                # cleared by opt_D.zero_grad() below.
                with torch.no_grad():
                    x, xp = G(z, gl)
                d_loss = (bce(D(imgs, labels), real)
                          + bce(D(x, gl), fake)
                          + bce(D(xp, gl), fake)) / 3
                opt_D.zero_grad()
                d_loss.backward()
                opt_D.step()
    finally:
        for p, flag in frozen:
            p.requires_grad_(flag)
    return G, D

def compute_bias(model, x, xp, device):
    """Mean L1 distance between ``model``'s outputs on paired probes ``x`` and ``xp``.

    For each pair, the absolute output difference is summed over dim 1 (the
    logit axis for a classifier) and then averaged over the batch. As a side
    effect, switches ``model`` to eval mode; runs without autograd.

    Returns:
        The score as a Python float.
    """
    model.eval()
    with torch.no_grad():
        out_x = model(x.to(device))
        out_xp = model(xp.to(device))
        per_sample = (out_x - out_xp).abs().sum(dim=1)
        return per_sample.mean().item()

def partition_data(dataset, n):
    """Split ``dataset`` into ``n`` non-IID client index sets by label.

    Sample indices are sorted by label and cut into ``2 * n`` contiguous shards.
    The shards are shuffled with NumPy's global RNG, and each client receives
    two shards, so each client sees only a few classes.

    Labels come from ``dataset.targets`` when that attribute exists (torchvision's
    MNIST has it) and no ``target_transform`` could make it differ from the
    labels ``__getitem__`` returns. Otherwise they are read via ``dataset[i][1]``.
    That fallback loads and transforms every sample, which is slow on large
    image datasets.

    Returns:
        A list of ``n`` integer index arrays.
    """
    targets = getattr(dataset, 'targets', None)
    if targets is None or getattr(dataset, 'target_transform', None) is not None:
        labels = [dataset[i][1] for i in range(len(dataset))]
    else:
        labels = np.asarray(targets)
    idx = np.argsort(labels)
    shards = np.array_split(idx, n * 2)
    np.random.shuffle(shards)
    return [np.concatenate([shards[2 * i], shards[2 * i + 1]]) for i in range(n)]

def evaluate(model, loader, device):
    """Top-1 accuracy of ``model`` over ``loader``, as a percentage in [0, 100].

    Runs in eval mode under ``torch.no_grad()`` (no autograd graphs are built)
    and makes a single pass over ``loader``, so one-shot iterables also work
    and data is loaded only once.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in loader:
            target = target.to(device)
            correct += (model(data.to(device)).argmax(1) == target).sum().item()
            total += len(target)
    if total == 0:
        raise ValueError("evaluate() received a loader with no samples")
    return 100 * correct / total

# Config
N_ROUNDS = 10
N_CLIENTS = 5
GAMMA = 2.0         # scale applied to fairness scores S before softmax (higher -> sharper client weighting)
N_GAN_EPOCHS = 20
N_PROBES = 300      # number of (x, x') probe pairs generated per round
LOCAL_EPOCHS = 3    # client-side training epochs per round
LOCAL_LR = 0.01     # client SGD learning rate
BATCH_SIZE = 32
N_VAL = 1000        # real training samples shown to the GAN's discriminator
SEED = 42
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed NumPy (sharding, validation subset) and torch (init, GAN noise, shuffling).
np.random.seed(SEED)
torch.manual_seed(SEED)

# Data
# Normalize pixels to [-1, 1] so real images share the generator's Tanh range.
# With MNIST mean/std normalization, real pixels reach ~2.8 and the
# discriminator can separate real from fake by intensity alone.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_data = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_data = datasets.MNIST('./data', train=False, download=True, transform=transform)
client_idx = partition_data(train_data, N_CLIENTS)
test_loader = DataLoader(test_data, batch_size=64)
# Sample without replacement so the validation subset has no duplicate images.
val_idx = np.random.choice(len(train_data), N_VAL, replace=False)
val_loader = DataLoader(Subset(train_data, val_idx), batch_size=BATCH_SIZE)

# Model
global_model = CNN().to(DEVICE)
history = {'acc': [], 'bias': []}

# Init WandB (after data setup, so a failed download doesn't leave a dangling run)
wandb.init(project="fed-audit-gan", name=f"colab_gamma{GAMMA}_clients{N_CLIENTS}", config={
    "n_rounds": N_ROUNDS, "n_clients": N_CLIENTS, "gamma": GAMMA, "device": DEVICE,
    "n_gan_epochs": N_GAN_EPOCHS, "n_probes": N_PROBES, "local_epochs": LOCAL_EPOCHS,
    "local_lr": LOCAL_LR, "batch_size": BATCH_SIZE, "n_val": N_VAL, "seed": SEED})

print(f"Training: {N_ROUNDS} rounds, {N_CLIENTS} clients, gamma={GAMMA}")
try:
    for r in range(N_ROUNDS):
        print(f"\n=== Round {r+1}/{N_ROUNDS} ===")

        # Phase 1: each client fine-tunes a copy of the global model; keep its weight delta.
        # The global weights are not modified during this phase, so snapshot them once.
        before = copy.deepcopy(global_model.state_dict())
        updates = []
        for c in tqdm(range(N_CLIENTS), desc="Phase 1"):
            loader = DataLoader(Subset(train_data, client_idx[c]), batch_size=BATCH_SIZE, shuffle=True)
            local = copy.deepcopy(global_model)
            opt = optim.SGD(local.parameters(), lr=LOCAL_LR)
            local.train()
            for _ in range(LOCAL_EPOCHS):
                for d, t in loader:
                    opt.zero_grad()
                    F.cross_entropy(local(d.to(DEVICE)), t.to(DEVICE)).backward()
                    opt.step()
            local_sd = local.state_dict()
            updates.append({k: local_sd[k] - before[k] for k in before})

        # Phase 2: train a fresh GAN against the current global model, then draw probes.
        print("Phase 2: GAN training")
        G = FairnessGenerator()
        D = Discriminator()
        G, D = train_gan(G, D, global_model, val_loader, epochs=N_GAN_EPOCHS, device=DEVICE)
        G.eval()
        with torch.no_grad():
            z = torch.randn(N_PROBES, G.latent_dim, device=DEVICE)
            lbl = torch.randint(0, G.num_classes, (N_PROBES,), device=DEVICE)
            x, xp = G(z, lbl)

        # Phase 3: score each client by how much its update alone changes the probe bias.
        B_base = compute_bias(global_model, x, xp, DEVICE)
        print(f"Phase 3: B_base={B_base:.4f}")
        base_sd = global_model.state_dict()
        S = []
        for i, upd in enumerate(updates):
            hyp = copy.deepcopy(global_model)
            hyp.load_state_dict({k: base_sd[k] + upd[k] for k in base_sd})
            B_i = compute_bias(hyp, x, xp, DEVICE)
            S.append(B_base - B_i)  # > 0 means this update reduces bias on the probes
            print(f"  Client {i}: S={S[-1]:+.4f}")

        # Phase 4: softmax weights sum to 1, so this applies a convex combination of updates.
        alpha = F.softmax(torch.tensor(S) * GAMMA, dim=0).tolist()
        print(f"Phase 4: alpha={[f'{a:.3f}' for a in alpha]}")
        sd = global_model.state_dict()
        for k in sd:
            sd[k] = sd[k] + sum(a * u[k] for a, u in zip(alpha, updates))
        global_model.load_state_dict(sd)

        acc = evaluate(global_model, test_loader, DEVICE)
        history['acc'].append(acc)
        history['bias'].append(B_base)
        wandb.log({'round': r+1, 'accuracy': acc, 'bias': B_base, 'avg_S': np.mean(S)})
        print(f"Accuracy: {acc:.2f}%")

    if history['acc']:
        print(f"\nFinal: {history['acc'][-1]:.2f}%")
finally:
    # Close the run even if training raises, so WandB does not keep it open.
    wandb.finish()

In [None]:
# Step 5: Plot Results
# Per-round test accuracy and baseline bias. The x-axis is numbered 1..N to
# match the round numbers printed during training and logged to WandB.
import matplotlib.pyplot as plt

rounds = range(1, len(history['acc']) + 1)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
ax1.plot(rounds, history['acc'], 'b-o')
ax1.set(xlabel='Round', ylabel='Accuracy (%)', title='Test Accuracy')
ax1.grid(True)
ax2.plot(rounds, history['bias'], 'r-s')
ax2.set(xlabel='Round', ylabel='Bias', title='Baseline Bias')
ax2.grid(True)
fig.tight_layout()
fig.savefig('results.png', dpi=150)
plt.show()
print("Check your WandB dashboard: https://wandb.ai")