# Fed-AuditGAN: Real 4-Phase Implementation

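Key formula: each client *k* receives the aggregation weight `alpha_k = (1 - gamma) * S_acc + gamma * S_fair`, where `S_acc` and `S_fair` are the client's min-max-normalized accuracy and fairness scores; the weights are then rescaled to sum to 1.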

- gamma=0.3: weights accuracy more heavily
- gamma=0.5: weights accuracy and fairness equally
- gamma=0.7: weights fairness more heavily

**Config: 30 rounds, 10 clients**

In [None]:
!pip install wandb torch torchvision tqdm matplotlib -q
print("Done!")

In [None]:
import wandb
# Log in to Weights & Biases; the training function later creates runs with
# wandb.init() and records per-round metrics with wandb.log().
wandb.login()

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import numpy as np
import copy
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {DEVICE}")


class CNN(nn.Module):
    """Small two-conv classifier for single-channel 28x28 images (e.g. MNIST).

    Spatial size goes 28 -> 26 -> 24 (two unpadded 3x3 convs) -> 12 (2x2 max-pool),
    so ``fc1`` takes 64 * 12 * 12 = 9216 features. ``forward`` returns raw logits.
    """

    def __init__(self, num_classes=10):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        return self.fc2(x)


class FairnessGenerator(nn.Module):
    """Conditional generator that emits an image pair (x, x') per latent code.

    ``x`` comes from a label-conditioned upsampling decoder ending in ``Tanh``, so
    it lies in [-1, 1]. ``x'`` is ``clamp(x + delta, -1, 1)``, where
    ``delta = delta_scale * tanh(delta_net(z))`` is a small latent-dependent
    perturbation.

    The decoder starts at ``img_shape[1] // 4`` and upsamples twice by 2. It therefore
    presumably assumes square images whose side is divisible by 4 (true for 28x28).

    NOTE(review): real inputs in this notebook are MNIST normalized with
    mean 0.1307 / std 0.3081 (range about [-0.42, 2.82]), which does not match the
    [-1, 1] range produced here. The discriminator can exploit that gap.
    """

    def __init__(self, latent_dim=100, num_classes=10, img_shape=(1, 28, 28)):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        self.img_shape = img_shape
        self.label_emb = nn.Embedding(num_classes, latent_dim)
        self.init_size = img_shape[1] // 4
        # Input is [z, label_embedding], hence 2 * latent_dim features.
        self.l1 = nn.Linear(latent_dim * 2, 128 * self.init_size ** 2)
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, img_shape[0], 3, 1, 1),
            nn.Tanh(),
        )
        self.delta_net = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, int(np.prod(img_shape))),
            nn.Tanh(),
        )
        # Bounds the per-pixel perturbation to +/- 0.1 before clamping.
        self.delta_scale = 0.1

    def forward(self, z, labels):
        """Return ``(x, x_prime)`` for latent codes ``z`` and class ``labels``."""
        gen_input = torch.cat([z, self.label_emb(labels)], dim=1)
        out = self.l1(gen_input)
        out = out.view(-1, 128, self.init_size, self.init_size)
        x = self.conv_blocks(out)
        delta = self.delta_net(z).view(-1, *self.img_shape) * self.delta_scale
        x_prime = torch.clamp(x + delta, -1, 1)
        return x, x_prime


class Discriminator(nn.Module):
    """Conditional discriminator returning P(real) for an (image, label) pair.

    The label is embedded into ``num_classes`` channels, broadcast over the image and
    concatenated with it. Four stride-2 convs reduce 28x28 to 2x2, which is why ``fc``
    expects 128 * 4 features. Other image sizes would need a different ``fc`` width.
    """

    def __init__(self, num_classes=10, img_shape=(1, 28, 28)):
        super().__init__()
        self.num_classes = num_classes
        self.img_shape = img_shape
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.conv = nn.Sequential(
            nn.Conv2d(img_shape[0] + num_classes, 16, 3, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(16, 32, 3, 2, 1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2),
            nn.Conv2d(32, 64, 3, 2, 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 3, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
        )
        self.fc = nn.Sequential(nn.Linear(128 * 4, 1), nn.Sigmoid())

    def forward(self, img, labels):
        label_map = self.label_emb(labels).view(-1, self.num_classes, 1, 1)
        label_map = label_map.expand(-1, -1, self.img_shape[1], self.img_shape[2])
        out = self.conv(torch.cat([img, label_map], dim=1))
        return self.fc(out.view(out.size(0), -1))


def train_fairness_gan(G, D, model, loader, epochs=20, alpha=1.0, beta=1.0):
    """Adversarially train the auditing generator ``G`` against a frozen ``model``.

    The generator loss combines three terms:
      * ``-beta * mean((model(x) - model(x'))**2)`` rewards pairs the classifier
        treats differently.
      * ``alpha * mean((x - x')**2)`` keeps x and x' close.
      * A BCE term pushes D to classify both x and x' as real.

    D is trained on real ``(imgs, labels)`` batches from ``loader`` against the
    generated pairs.

    ``model`` is put in eval mode and its parameters are neither updated nor
    accumulate gradients. Their ``requires_grad`` flags are restored on exit.

    Args:
        G: FairnessGenerator to train.
        D: Discriminator to train.
        model: classifier being audited.
        loader: iterable of ``(imgs, labels)`` real-data batches.
        epochs: number of passes over ``loader``.
        alpha: weight of the x/x' closeness penalty.
        beta: weight of the prediction-difference reward.

    Returns:
        ``(G, D)`` after training, both moved to DEVICE.
    """
    G, D, model = G.to(DEVICE), D.to(DEVICE), model.to(DEVICE)
    model.eval()
    # Gradients must flow *through* the audited model into G, so it cannot run under
    # torch.no_grad() (that made pred_diff a constant with zero gradient). Freeze its
    # parameters instead so they don't collect .grad.
    saved_requires_grad = [p.requires_grad for p in model.parameters()]
    for p in model.parameters():
        p.requires_grad_(False)
    opt_G = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
    opt_D = optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
    bce = nn.BCELoss()
    try:
        for _ in range(epochs):
            for imgs, labels in loader:
                batch_size = imgs.size(0)
                real_t = torch.ones(batch_size, 1, device=DEVICE)
                fake_t = torch.zeros(batch_size, 1, device=DEVICE)
                imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)

                # ---- Generator step ----
                z = torch.randn(batch_size, G.latent_dim, device=DEVICE)
                gen_labels = torch.randint(0, G.num_classes, (batch_size,), device=DEVICE)
                x, x_prime = G(z, gen_labels)
                pred_diff = -beta * torch.mean((model(x) - model(x_prime)) ** 2)
                realism = alpha * torch.mean((x - x_prime) ** 2)
                gan_loss = (bce(D(x, gen_labels), real_t) + bce(D(x_prime, gen_labels), real_t)) / 2
                g_loss = pred_diff + realism + gan_loss
                opt_G.zero_grad()
                g_loss.backward()
                opt_G.step()

                # ---- Discriminator step ----
                # Fresh samples from the updated G. D must not backprop into G, so no
                # graph is needed (equivalent to the former .detach()).
                with torch.no_grad():
                    x, x_prime = G(z, gen_labels)
                d_real = bce(D(imgs, labels), real_t)
                d_fake = (bce(D(x, gen_labels), fake_t) + bce(D(x_prime, gen_labels), fake_t)) / 2
                d_loss = (d_real + d_fake) / 2
                # Also clears the D gradients left over from g_loss.backward().
                opt_D.zero_grad()
                d_loss.backward()
                opt_D.step()
    finally:
        for p, flag in zip(model.parameters(), saved_requires_grad):
            p.requires_grad_(flag)
    return G, D


def compute_bias(model, x, x_prime):
    """Return the mean over samples of the L1 distance between ``model``'s outputs on x and x'."""
    model.eval()
    with torch.no_grad():
        pred_x = model(x.to(DEVICE))
        pred_xp = model(x_prime.to(DEVICE))
        return torch.abs(pred_x - pred_xp).sum(dim=1).mean().item()


def _mean_ce_loss(model, loader):
    """Return the mean per-sample cross-entropy of ``model`` over ``loader`` (eval mode, no grad)."""
    model.eval()
    total = 0
    count = 0
    with torch.no_grad():
        for d, t in loader:
            d, t = d.to(DEVICE), t.to(DEVICE)
            total += F.cross_entropy(model(d), t, reduction='sum').item()
            count += len(t)
    return total / count


def compute_acc_score(model, update, val_loader):
    """Return the drop in validation loss from adding ``update`` to ``model``'s state.

    ``update`` maps every ``model.state_dict()`` key to a delta tensor. A positive
    result means the update lowers the mean cross-entropy on ``val_loader``. The
    update is applied to a deep copy, so ``model``'s parameters are unchanged
    (it is left in eval mode).
    """
    loss_before = _mean_ce_loss(model, val_loader)
    hyp = copy.deepcopy(model)
    sd = hyp.state_dict()
    for k in sd:
        sd[k] = sd[k] + update[k]
    hyp.load_state_dict(sd)
    loss_after = _mean_ce_loss(hyp, val_loader)
    return loss_before - loss_after


def partition_non_iid(dataset, n):
    """Split ``dataset`` indices into ``n`` label-skewed (non-IID) client shards.

    Indices are sorted by label, cut into ``2 * n`` contiguous shards, shuffled, and
    each client receives two shards. Clients therefore see only a few digit classes
    (typically two). A shard can straddle a class boundary, and both shards can come
    from the same class, so "exactly two classes" is not guaranteed.

    Returns:
        List of ``n`` numpy index arrays.
    """
    targets = getattr(dataset, 'targets', None)
    if targets is not None:
        # torchvision datasets expose labels directly. This avoids loading and
        # transforming every image just to read its label.
        labels = np.asarray(targets)
    else:
        labels = [dataset[i][1] for i in range(len(dataset))]
    idx = np.argsort(labels)
    shards = np.array_split(idx, n * 2)
    np.random.shuffle(shards)
    return [np.concatenate([shards[2 * i], shards[2 * i + 1]]) for i in range(n)]


def evaluate(model, loader):
    """Return ``model``'s top-1 accuracy on ``loader`` as a percentage."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for d, t in loader:
            d, t = d.to(DEVICE), t.to(DEVICE)
            correct += (model(d).argmax(1) == t).sum().item()
            total += len(t)
    return 100 * correct / total


def normalize(scores):
    """Min-max scale ``scores`` to [0, 1] as a float numpy array.

    If all scores are (nearly) equal, a uniform ``1 / len(scores)`` vector is returned
    instead of dividing by ~0. Empty input returns an empty array.
    """
    s = np.asarray(scores, dtype=float)
    if s.size == 0:
        return s
    span = s.max() - s.min()
    if span < 1e-8:
        return np.ones_like(s) / len(s)
    return (s - s.min()) / span


print("Implementation loaded!")

In [None]:
def run_fed_audit_gan(gamma, n_rounds=30, n_clients=10, local_epochs=3, lr=0.01,
                      gan_epochs=20, n_probes=300, seed=None):
    """Run one Fed-AuditGAN experiment on MNIST and return its per-round history.

    Each communication round has four phases:
      Phase 1: every client fine-tunes a copy of the global model on its non-IID shard
               (SGD, ``local_epochs`` passes) and reports the parameter delta.
      Phase 2: a fairness GAN is trained against the current global model, then used to
               generate ``n_probes`` probe pairs (x, x').
      Phase 3: each delta is scored for fairness, as the reduction in probe bias
               versus the global model, and for accuracy, as the reduction in
               validation loss.
      Phase 4: both scores are min-max normalized and combined as
               alpha_k = (1-gamma)*S_acc + gamma*S_fair. The alphas are rescaled to
               sum to 1 (uniform if their sum is <= 0), and the weighted deltas are
               added to the global model.

    Args:
        gamma: weight of the fairness score (0 = accuracy only, 1 = fairness only).
        n_rounds: number of communication rounds.
        n_clients: number of simulated clients.
        local_epochs: local training epochs per client per round.
        lr: client SGD learning rate.
        gan_epochs: epochs of GAN training per round.
        n_probes: number of (x, x') probe pairs used for bias scoring.
        seed: if given, seeds NumPy and PyTorch first, so runs with different
            ``gamma`` share the same data partition, validation subset and initial
            weights. ``None`` (the default) leaves the RNGs untouched.

    Returns:
        Dict with per-round lists:
          * ``'accuracy'``: test accuracy in %.
          * ``'bias'``: probe bias of the global model before aggregation.
          * ``'alphas'``: list of client weights for that round.

    The W&B run is always finished, even if training raises.
    """
    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)

    wandb.init(project="fed-audit-gan-real", name=f"gamma{gamma}_r{n_rounds}_c{n_clients}",
               config={'gamma': gamma, 'n_rounds': n_rounds, 'n_clients': n_clients})
    try:
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        train_data = datasets.MNIST('./data', train=True, download=True, transform=transform)
        test_data = datasets.MNIST('./data', train=False, download=True, transform=transform)

        # Non-IID partition: each client gets different digit classes
        client_idx = partition_non_iid(train_data, n_clients)
        print(f"Data partitioned: {n_clients} clients with non-IID distribution")

        test_loader = DataLoader(test_data, batch_size=64, shuffle=False)
        # NOTE(review): the validation subset is sampled from the training set, so it
        # overlaps the clients' local data.
        val_idx = np.random.choice(len(train_data), 1000, replace=False)
        val_loader = DataLoader(Subset(train_data, val_idx), batch_size=32, shuffle=False)
        # Client shards never change, so their loaders are built once. shuffle=True
        # still reshuffles on every pass.
        client_loaders = [DataLoader(Subset(train_data, idx), batch_size=32, shuffle=True)
                          for idx in client_idx]

        global_model = CNN().to(DEVICE)
        history = {'accuracy': [], 'bias': [], 'alphas': []}

        print(f"\nFed-AuditGAN | gamma={gamma} | {n_rounds} rounds | {n_clients} clients")
        print(f"Formula: alpha_k = (1-gamma)*S_acc + gamma*S_fair\n")

        for rnd in tqdm(range(n_rounds), desc=f"gamma={gamma}"):
            # Phase 1: client training. The global model is untouched during this
            # phase, so one snapshot is the baseline for every client's delta.
            before = copy.deepcopy(global_model.state_dict())
            updates = []
            for loader in client_loaders:
                local = copy.deepcopy(global_model)
                opt = optim.SGD(local.parameters(), lr=lr)
                local.train()
                for _ in range(local_epochs):
                    for d, t in loader:
                        d, t = d.to(DEVICE), t.to(DEVICE)
                        opt.zero_grad()
                        F.cross_entropy(local(d), t).backward()
                        opt.step()
                after = local.state_dict()  # build once, not once per key
                updates.append({k: after[k] - before[k] for k in before})

            # Phase 2: GAN auditing of the current global model.
            G, D = FairnessGenerator().to(DEVICE), Discriminator().to(DEVICE)
            G, D = train_fairness_gan(G, D, global_model, val_loader, epochs=gan_epochs)
            G.eval()
            with torch.no_grad():
                z = torch.randn(n_probes, G.latent_dim, device=DEVICE)
                lbls = torch.randint(0, G.num_classes, (n_probes,), device=DEVICE)
                x_p, xp_p = G(z, lbls)

            # Phase 3: score each update (positive = better than the global model).
            B_base = compute_bias(global_model, x_p, xp_p)
            S_fair, S_acc = [], []
            for upd in updates:
                hyp = copy.deepcopy(global_model)
                sd = hyp.state_dict()
                for k in sd:
                    sd[k] = sd[k] + upd[k]
                hyp.load_state_dict(sd)
                S_fair.append(B_base - compute_bias(hyp, x_p, xp_p))
                S_acc.append(compute_acc_score(global_model, upd, val_loader))

            # Phase 4: multi-objective weights. Use plain Python floats so weighting
            # the tensors never depends on numpy-scalar/tensor operator dispatch.
            S_fair_n, S_acc_n = normalize(S_fair), normalize(S_acc)
            raw = [float((1 - gamma) * s_acc + gamma * s_fair)
                   for s_acc, s_fair in zip(S_acc_n, S_fair_n)]
            a_sum = sum(raw)
            alphas = [a / a_sum for a in raw] if a_sum > 0 else [1 / n_clients] * n_clients

            sd = global_model.state_dict()
            for k in sd:
                sd[k] = sd[k] + sum(alpha * upd[k] for alpha, upd in zip(alphas, updates))
            global_model.load_state_dict(sd)

            acc = evaluate(global_model, test_loader)
            history['accuracy'].append(acc)
            history['bias'].append(B_base)
            history['alphas'].append(alphas)

            wandb.log({'round': rnd+1, 'accuracy': acc, 'bias': B_base,
                       'alpha_max': max(alphas), 'alpha_min': min(alphas), 'alpha_std': np.std(alphas)})

        if history['accuracy']:  # n_rounds=0 records nothing
            print(f"gamma={gamma} Done | Accuracy: {history['accuracy'][-1]:.2f}%")
    finally:
        wandb.finish()
    return history


print("Training function ready!")

In [None]:
# Sweep the fairness weight gamma over 0.3 / 0.5 / 0.7; results are keyed by gamma.
# Config: 30 rounds, 10 clients, each holding a different non-IID shard of MNIST.
results = {}
for gamma in (0.3, 0.5, 0.7):
    print('\n' + '=' * 60)
    print('Running gamma = ' + str(gamma))
    print('=' * 60)
    results[gamma] = run_fed_audit_gan(gamma=gamma, n_rounds=30, n_clients=10)
print('\nAll experiments complete!')

In [None]:
# Visualize results: accuracy and probe bias per round for every gamma that was run,
# then print a final-round summary.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
colors = {0.3: 'blue', 0.5: 'green', 0.7: 'red'}
for g, hist in results.items():
    # Derive the x-axis from the recorded history, so the plot matches whatever
    # n_rounds the experiments actually used.
    rounds = list(range(1, len(hist['accuracy']) + 1))
    # colors.get: a gamma without a listed color falls back to matplotlib's cycle.
    axes[0].plot(rounds, hist['accuracy'], 'o-', label=f'gamma={g}', color=colors.get(g), markersize=3)
    axes[1].plot(rounds, hist['bias'], 'o-', label=f'gamma={g}', color=colors.get(g), markersize=3)
axes[0].set_xlabel('Round'); axes[0].set_ylabel('Accuracy (%)')
axes[0].set_title('Accuracy vs Communication Round'); axes[0].legend(); axes[0].grid(True, alpha=0.3)
axes[1].set_xlabel('Round'); axes[1].set_ylabel('Bias')
axes[1].set_title('Model Bias vs Communication Round'); axes[1].legend(); axes[1].grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('results.png', dpi=150)
plt.show()

print("\n" + "="*60)
print("SUMMARY")
print("="*60)
for g, hist in results.items():
    print(f"gamma={g}: Final Acc={hist['accuracy'][-1]:.2f}%, Final Bias={hist['bias'][-1]:.4f}")
print("="*60)