<a href="https://colab.research.google.com/github/Krishika-Garg/fedlearning-under-attack/blob/main/Minor_project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Sanity check: import the core dependencies and print their installed versions.
import torch, torchvision, flwr, numpy
print(torch.__version__, torchvision.__version__, numpy.__version__, flwr.__version__)


In [None]:
# Install pinned versions of PyTorch, torchvision and Flower, bypassing pip's wheel cache.
!pip install --no-cache-dir torch==2.2.0 torchvision==0.17.0 flwr==1.3.0

In [None]:
# Install Flower + PyTorch + torchvision and pin numpy to avoid binary conflicts
# Clean up any corrupted installs ("|| true" lets the cell continue if the uninstall fails)
!pip uninstall -y numpy torch torchvision flwr || true

# Reinstall in the safest order (avoid cache conflicts): numpy, then torch/torchvision, then Flower
!pip install --no-cache-dir numpy==1.26.4
!pip install --no-cache-dir torch==2.2.0 torchvision==0.17.0
!pip install --no-cache-dir flwr==1.3.0



In [None]:
# Run this as a single Colab cell. It upgrades pip, installs numpy first (to avoid binary mismatch),
# then PyTorch + torchvision, then Flower. --no-cache-dir helps avoid cached wheels.
# NOTE(review): torch/torchvision are unpinned here, unlike the torch==2.2.0 / torchvision==0.17.0
# pins used in the other install cells — confirm which versions are intended.
!pip install --upgrade pip
!pip install --upgrade --force-reinstall --no-cache-dir numpy==1.26.4
!pip install --upgrade --force-reinstall --no-cache-dir torch torchvision
!pip install --upgrade --force-reinstall --no-cache-dir flwr==1.3.0

In [None]:
# Download MNIST and partition equally into 10 client loaders (IID)
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset

# Preprocessing: convert images to tensors, then normalize the single channel with
# mean 0.1307 and std 0.3081 (presumably the usual MNIST statistics — confirm).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Download (if needed) the MNIST train and test splits into ./data.
trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
testset  = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)

print(f"Train samples: {len(trainset)}, Test samples: {len(testset)}")

# Number of simulated clients and the equal number of training samples each receives.
# Integer division means any remainder samples are left unassigned.
num_clients = 10
data_per_client = len(trainset) // num_clients

# Client i gets the contiguous index range [i*data_per_client, (i+1)*data_per_client).
client_indices = [list(range(i*data_per_client, (i+1)*data_per_client)) for i in range(num_clients)]
client_loaders = [ DataLoader(Subset(trainset, idxs), batch_size=32, shuffle=True) for idxs in client_indices ]

# Single test loader shared by the evaluation code in this notebook.
testloader = DataLoader(testset, batch_size=256, shuffle=False)

# verify one batch
images, labels = next(iter(client_loaders[0]))
print("One client batch shape:", images.shape)
print("Sample labels:", labels[:10].tolist())


In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch

class Net(nn.Module):
    """Small CNN classifier for 1x28x28 inputs producing 10 logits.

    Two 3x3 convolutions (no padding) are followed by a 2x2 max-pool and two
    fully connected layers. Layers are registered in the order conv1, conv2,
    pool, fc1, fc2, which fixes the ordering of ``state_dict()``.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 28x28 -> 26x26 -> 24x24, then pooled to 12x12.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Classifier head; the flattened feature size is 32 channels * 12 * 12.
        self.fc1 = nn.Linear(in_features=32 * 12 * 12, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return unnormalized class scores of shape (batch, 10)."""
        features = self.pool(F.relu(self.conv2(F.relu(self.conv1(x)))))
        hidden = F.relu(self.fc1(torch.flatten(features, 1)))
        return self.fc2(hidden)

# quick test forward pass: push one client batch through an untrained model
model = Net()
sample_images, _ = next(iter(client_loaders[0]))
output = model(sample_images)   # should not error
print("Model forward output shape:", output.shape)  # expect [batch_size, 10]


In [None]:
# Optional: quick one-batch train step to confirm backprop works
# (a single SGD update on one mini-batch from client 0, using the model built above)
import torch.optim as optim
model.train()
optimizer = optim.SGD(model.parameters(), lr=0.01)
x, y = next(iter(client_loaders[0]))
optimizer.zero_grad()
out = model(x)
loss = F.cross_entropy(out, y)
loss.backward()
optimizer.step()
print("One minibatch training step OK. Loss:", loss.item())


In [None]:
import flwr as fl
import torch
import torch.nn.functional as F
from typing import List, Tuple, Dict
import numpy as np

In [None]:
def get_weights(model: torch.nn.Module) -> List[np.ndarray]:
    """Return a snapshot of the model's state as a list of NumPy arrays.

    Arrays follow the iteration order of ``model.state_dict()``.

    Each array is an independent copy. For a CPU tensor ``Tensor.numpy()``
    shares memory with the tensor, so without the copy a later in-place
    update of the model (training, ``load_state_dict``) would silently
    change weights that were returned earlier.
    """
    return [val.detach().cpu().numpy().copy() for val in model.state_dict().values()]

def set_weights(model: torch.nn.Module, weights: List[np.ndarray]) -> None:
    """Load ``weights`` into ``model``, pairing them positionally with its state_dict keys.

    Each array is converted with ``torch.tensor`` and loaded with
    ``strict=True``.
    """
    new_state = {}
    for key, array in zip(model.state_dict().keys(), weights):
        new_state[key] = torch.tensor(array)
    model.load_state_dict(new_state, strict=True)

In [None]:
class FLClient(fl.client.NumPyClient):
    """Flower NumPy client that trains and evaluates a local PyTorch model.

    Args:
        cid: Client identifier (stored on the instance; not otherwise used here).
        model: Model to train; moved to CUDA when available, otherwise CPU.
        trainloader: DataLoader over this client's local training data.
        testloader: DataLoader used by ``evaluate``.
    """

    def __init__(self, cid: int, model: torch.nn.Module, trainloader, testloader):
        self.cid = cid
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = model.to(self.device)
        self.trainloader = trainloader
        self.testloader = testloader

    def get_parameters(self, config=None):
        """Return the current model weights as a list of NumPy arrays.

        ``config`` is accepted (and ignored) because Flower 1.x calls
        ``get_parameters(config=...)``; the default keeps bare calls working.
        """
        return get_weights(self.model)

    def set_parameters(self, parameters):
        """Load ``parameters`` into the local model and keep it on ``self.device``."""
        set_weights(self.model, parameters)
        self.model.to(self.device)

    def fit(self, parameters, config):
        """Train locally starting from the given global weights.

        Runs ``config["local_epochs"]`` epochs (default 1) of SGD with lr=0.01.

        Returns:
            (updated weights, number of local training examples, empty metrics dict).
        """
        # Set incoming global weights
        self.set_parameters(parameters)
        # Local training (1 epoch default; can be changed via config)
        self.model.train()
        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)
        local_epochs = int(config.get("local_epochs", 1))
        for _ in range(local_epochs):
            for x, y in self.trainloader:
                x, y = x.to(self.device), y.to(self.device)
                optimizer.zero_grad()
                out = self.model(x)
                loss = F.cross_entropy(out, y)
                loss.backward()
                optimizer.step()
        # Return updated weights and number of examples
        return self.get_parameters(config={}), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Evaluate the given weights on ``self.testloader``.

        Returns:
            (mean loss, number of examples, {"accuracy": ...}); when the loader
            yields no samples the result is (0.0, 0, {}).
        """
        self.set_parameters(parameters)
        self.model.eval()
        loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for x, y in self.testloader:
                x, y = x.to(self.device), y.to(self.device)
                out = self.model(x)
                # cross_entropy returns the batch mean; re-weight by batch size
                # so the final division yields a per-sample mean.
                loss += F.cross_entropy(out, y).item() * x.size(0)
                preds = out.argmax(dim=1)
                correct += (preds == y).sum().item()
                total += x.size(0)
        if total == 0:
            return float(loss), 0, {}
        return float(loss / total), total, {"accuracy": float(correct / total)}

In [None]:
def server_evaluate(weights: List[np.ndarray]) -> Tuple[float, Dict[str, float]]:
    """Evaluate global ``weights`` on the shared ``testloader`` using a fresh ``Net``.

    Returns:
        ``(mean_loss, {"accuracy": acc})``. Note that ``None`` is returned
        instead when the test loader yields no samples.
    """
    net = Net()
    set_weights(net, weights)
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net.to(dev)
    net.eval()

    summed_loss, hits, seen = 0.0, 0, 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs = inputs.to(dev)
            targets = targets.to(dev)
            logits = net(inputs)
            batch_len = inputs.size(0)
            # cross_entropy is a batch mean; weight it by the batch size.
            summed_loss += F.cross_entropy(logits, targets).item() * batch_len
            hits += (logits.argmax(dim=1) == targets).sum().item()
            seen += batch_len

    if seen == 0:
        return None
    return float(summed_loss / seen), {"accuracy": float(hits / seen)}


In [None]:
def client_fn(cid: str) -> FLClient:
    """Build the client for ``cid``: a fresh ``Net`` plus that client's own train loader.

    All clients share the global ``testloader`` for local evaluation.
    """
    client_index = int(cid)
    return FLClient(
        cid=client_index,
        model=Net(),
        trainloader=client_loaders[client_index],
        testloader=testloader,
    )

In [None]:
# --- Manual FedAvg simulation (no Flower / no Ray) ---
import copy
import numpy as np
import torch
import torch.nn.functional as F

# Config: number of federated rounds, local epochs per client per round,
# SGD learning rate, and the device used for training and evaluation.
NUM_ROUNDS = 5
LOCAL_EPOCHS = 1
lr = 0.01
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def local_train_return_weights(model, trainloader, global_weights, local_epochs=1, lr=0.01):
    """Train ``model`` locally from ``global_weights`` and return its new weights.

    Loads the global weights, runs ``local_epochs`` passes of plain SGD over
    ``trainloader`` on the global ``device``, and returns
    ``(weights as NumPy arrays, number of examples in the loader's dataset)``.
    """
    set_weights(model, global_weights)
    model.to(device)
    model.train()
    sgd = torch.optim.SGD(model.parameters(), lr=lr)
    for _epoch in range(local_epochs):
        for inputs, targets in trainloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            sgd.zero_grad()
            F.cross_entropy(model(inputs), targets).backward()
            sgd.step()
    return get_weights(model), len(trainloader.dataset)

def evaluate_model_on_test(model):
    """Return the accuracy of ``model`` on the global ``testloader`` (0.0 if it is empty)."""
    model.to(device)
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            predicted = model(inputs).argmax(dim=1)
            n_correct += (predicted == targets).sum().item()
            n_seen += inputs.size(0)
    if n_seen > 0:
        return n_correct / n_seen
    return 0.0

# Initialize global model (fresh); its initial weights seed round 1.
global_model = Net().to(device)
global_weights = get_weights(global_model)

print("Starting manual FedAvg simulation: {} clients, {} rounds, {} local epochs".format(len(client_loaders), NUM_ROUNDS, LOCAL_EPOCHS))

for rnd in range(1, NUM_ROUNDS + 1):
    client_updates = []
    client_ns = []
    # Each client trains locally and returns updated weights
    for i, trainloader in enumerate(client_loaders):
        local_model = Net()  # fresh model instance for the client
        w, n = local_train_return_weights(local_model, trainloader, global_weights, local_epochs=LOCAL_EPOCHS, lr=lr)
        client_updates.append(w)
        client_ns.append(n)
    # Federated averaging (weighted by number of examples)
    # compute total examples
    total_n = sum(client_ns)
    # initialize averaged weights as zeros arrays with same shapes
    avg_weights = []
    for layer_idx in range(len(client_updates[0])):
        # start with zeros of same shape (and dtype) as the first client's layer
        layer_shape = client_updates[0][layer_idx].shape
        accum = np.zeros(layer_shape, dtype=client_updates[0][layer_idx].dtype)
        # add weighted contributions: each client's share is n_client / total_n
        for c_idx in range(len(client_updates)):
            accum += client_updates[c_idx][layer_idx] * (client_ns[c_idx] / total_n)
        avg_weights.append(accum)
    # set new global weights (used as the starting point of the next round)
    global_weights = avg_weights
    set_weights(global_model, global_weights)

    # Evaluate global model on test set
    acc = evaluate_model_on_test(global_model)
    print(f"Round {rnd:02d} -> Global test accuracy: {acc*100:.2f}%")

print("Manual FedAvg simulation finished.")


In [None]:
# Install/upgrade Flower together with its "simulation" extra.
# NOTE(review): -U may replace the flwr==1.3.0 pin installed by the earlier cells — confirm this is intended.
!pip install -U "flwr[simulation]"

In [None]:
# Quick import test: is Flower and Ray available?
# Each package is imported by name; the outcome (version or error text) is collected and printed.
import importlib, sys
results = {}
for pkg in ("flwr", "ray"):
    try:
        mod = importlib.import_module(pkg)
        results[pkg] = f"OK, version {getattr(mod, '__version__', 'unknown')}"
    except Exception as e:
        # Broad catch: any failure raised while importing is reported instead of aborting the cell.
        results[pkg] = f"IMPORT ERROR: {type(e).__name__}: {e}"

for k,v in results.items():
    print(k, "=>", v)


In [None]:
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [None]:
# Federated training settings for the label-flip experiment.
NUM_ROUNDS = 5
LOCAL_EPOCHS = 1
lr = 0.01
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Choose attacker clients (indices in 0..num_clients-1). Example: client 0 is malicious.
malicious_client_ids = {0}            # set of malicious client indices
attack_type = "label_flip"            # currently only label_flip implemented
# Label-flip specifics: source_label -> target_label (we measure ASR for source->target)
source_label = 0
target_label = 1


In [None]:
def local_train_return_weights(model, trainloader, global_weights, local_epochs=1, lr=0.01, is_malicious=False):
    """Train ``model`` locally from ``global_weights``, optionally with label flipping.

    When ``is_malicious`` is true and the global ``attack_type`` is
    ``"label_flip"``, every label equal to the global ``source_label`` is
    rewritten to ``target_label`` before the loss is computed; other labels
    are left unchanged.

    Returns:
        ``(weights as NumPy arrays, number of examples in the loader's dataset)``.
    """
    set_weights(model, global_weights)
    model.to(device)
    model.train()
    sgd = torch.optim.SGD(model.parameters(), lr=lr)
    for _epoch in range(local_epochs):
        for inputs, targets in trainloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            if is_malicious and attack_type == "label_flip":
                # Masked assignment is a no-op when the batch has no source_label samples.
                targets[targets == source_label] = target_label
            sgd.zero_grad()
            F.cross_entropy(model(inputs), targets).backward()
            sgd.step()
    return get_weights(model), len(trainloader.dataset)


In [None]:
def compute_asr(model):
    """ASR = fraction of test samples whose true label==source_label but predicted==target_label

    Uses the global ``testloader``, ``source_label``, ``target_label`` and
    ``device``. Returns 0.0 when the test set contains no ``source_label``
    samples.
    """
    model.to(device)
    model.eval()
    total_source = 0
    source_to_target = 0
    # Inference only: disable autograd so no graph is built for each batch.
    with torch.no_grad():
        for x, y in testloader:
            mask = (y == source_label)
            if not mask.any():
                continue
            x_src = x[mask].to(device)
            preds = model(x_src).argmax(dim=1)
            total_source += x_src.size(0)
            source_to_target += (preds == target_label).sum().item()
    return (source_to_target / total_source) if total_source > 0 else 0.0


In [None]:
# Fresh global model; its initial weights seed round 1.
global_model = Net().to(device)
global_weights = get_weights(global_model)

# Per-round metrics (clean accuracy and attack success rate), plotted in the next cell.
acc_history = []
asr_history = []

print(f"Running FedAvg with malicious clients {malicious_client_ids} (attack={attack_type})")
for rnd in range(1, NUM_ROUNDS + 1):
    client_updates = []
    client_ns = []
    # Every client trains from the current global weights; clients listed in
    # malicious_client_ids train with the attack enabled.
    for i, trainloader in enumerate(client_loaders):
        local_model = Net()
        is_mal = (i in malicious_client_ids)
        w, n = local_train_return_weights(local_model, trainloader, global_weights,
                                          local_epochs=LOCAL_EPOCHS, lr=lr, is_malicious=is_mal)
        client_updates.append(w)
        client_ns.append(n)
    # weighted average (each client's share is n_client / total_n)
    total_n = sum(client_ns)
    avg_weights = []
    for layer_idx in range(len(client_updates[0])):
        accum = np.zeros(client_updates[0][layer_idx].shape, dtype=client_updates[0][layer_idx].dtype)
        for c_idx in range(len(client_updates)):
            accum += client_updates[c_idx][layer_idx] * (client_ns[c_idx] / total_n)
        avg_weights.append(accum)
    global_weights = avg_weights
    set_weights(global_model, global_weights)

    # evaluate
    # global accuracy on clean test set
    global_model.to(device)
    global_model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in testloader:
            x, y = x.to(device), y.to(device)
            out = global_model(x)
            preds = out.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += x.size(0)
    acc = correct / total if total > 0 else 0.0
    # attack success rate of the source_label -> target_label flip on the same test set
    asr = compute_asr(global_model)

    acc_history.append(acc)
    asr_history.append(asr)

    print(f"Round {rnd:02d} -> Global accuracy: {acc*100:.2f}%, ASR (#{source_label}->{target_label}): {asr*100:.2f}%")

print("Finished.")

In [None]:
# Plot per-round global accuracy and attack success rate (both as percentages)
# from acc_history / asr_history, with rounds numbered from 1.
plt.figure(figsize=(8,4))
plt.plot(range(1, NUM_ROUNDS+1), [a*100 for a in acc_history], marker='o', label='Global Accuracy')
plt.plot(range(1, NUM_ROUNDS+1), [a*100 for a in asr_history], marker='x', label=f'ASR {source_label}->{target_label}')
plt.xlabel('Round')
plt.ylabel('Percent')
plt.title('FedAvg: Accuracy and ASR over Rounds')
plt.legend()
plt.grid(True)
plt.show()

In [None]:
from collections import Counter
# show label distribution per client
# (iterates every batch of each client's loader and counts the labels seen)
for i, loader in enumerate(client_loaders):
    labels = []
    for _, y in loader:
        labels.extend(y.tolist())
    c = Counter(labels)
    print(f"Client {i}: total={sum(c.values())}, label_counts={dict(c)}")


In [None]:
# CELL A — Label-flip grid experiments
import numpy as np
import torch
import torch.nn.functional as F
import pandas as pd
import copy
from itertools import product
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Experiment config (tweak)
NUM_ROUNDS = 5
LOCAL_EPOCHS = 1
ATTACKER_LOCAL_EPOCHS = 5  # local epochs used by malicious clients instead of LOCAL_EPOCHS
lr = 0.01

# grid to try (the experiment runs every combination of the three lists below)
attacker_fractions = [0.1, 0.3]   # e.g., 10% and 30% attackers
flip_modes = ["single", "full"]    # single: only source->target, full: y -> (y+1)%10
scale_options = [False, True]      # whether to scale malicious deltas
scale_factor = 5.0                 # multiplier applied to a malicious delta when scaling is on

# target/source for ASR when using "single"
source_label = 0
target_label = 1

# helpers (re-used)
def local_train_return_weights(model, trainloader, global_weights, local_epochs=1, lr=0.01, is_malicious=False, flip_mode="single"):
    """Train ``model`` locally from ``global_weights``, optionally poisoning labels.

    For malicious clients, ``flip_mode == "single"`` rewrites labels equal to
    the global ``source_label`` to ``target_label``; any other ``flip_mode``
    value shifts every label to ``(y + 1) % 10``.

    Returns:
        ``(weights as NumPy arrays, number of examples in the loader's dataset)``.
    """
    set_weights(model, global_weights)
    model.to(device)
    model.train()
    sgd = torch.optim.SGD(model.parameters(), lr=lr)
    for _epoch in range(local_epochs):
        for inputs, targets in trainloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            if is_malicious and flip_mode == "single":
                # No-op when the batch contains no source_label samples.
                targets[targets == source_label] = target_label
            elif is_malicious:
                # "full" mode: cyclic shift of every class label.
                targets = (targets + 1) % 10
            sgd.zero_grad()
            F.cross_entropy(model(inputs), targets).backward()
            sgd.step()
    return get_weights(model), len(trainloader.dataset)

def scale_update(client_w, global_w, scale):
    """Amplify a client's update relative to the global weights.

    For each layer the delta ``client - global`` is multiplied by ``scale``
    and added back onto the global layer. Returns a new list of layers.
    """
    return [gw + scale * (cw - gw) for cw, gw in zip(client_w, global_w)]

@torch.no_grad()
def compute_asr(model, src=0, tgt=1):
    """Return the fraction of test samples with true label ``src`` that are predicted as ``tgt``.

    Iterates the global ``testloader``; returns 0.0 when no ``src`` samples exist.
    """
    model.to(device)
    model.eval()
    n_source = 0
    n_hit = 0
    for inputs, targets in testloader:
        is_src = targets == src
        if is_src.any():
            src_inputs = inputs[is_src].to(device)
            predicted = model(src_inputs).argmax(dim=1)
            n_source += src_inputs.size(0)
            n_hit += (predicted == tgt).sum().item()
    if n_source > 0:
        return n_hit / n_source
    return 0.0

def evaluate_global(model):
    """Return the accuracy of ``model`` on the global ``testloader`` (0.0 if it is empty)."""
    model.to(device)
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            predicted = model(inputs).argmax(dim=1)
            n_correct += (predicted == targets).sum().item()
            n_seen += inputs.size(0)
    if n_seen > 0:
        return n_correct / n_seen
    return 0.0

# Run grid: one full FedAvg simulation per (attacker fraction, flip mode, scaling) combination.
results = []
num_clients = len(client_loaders)
client_indices = list(range(num_clients))

for frac, flip_mode, scale in product(attacker_fractions, flip_modes, scale_options):
    # At least one attacker, even if frac * num_clients rounds to 0.
    k_attackers = max(1, int(round(frac * num_clients)))
    # choose first k_attackers clients as malicious (deterministic)
    malicious_ids = set(client_indices[:k_attackers])
    print(f"\nRunning config: frac={frac}, k={k_attackers}, flip_mode={flip_mode}, scale={scale}")
    # Initialize global model (fresh for every configuration)
    global_model = Net().to(device)
    global_weights = get_weights(global_model)
    acc_history = []; asr_history = []
    for rnd in range(1, NUM_ROUNDS+1):
        client_updates = []; client_ns = []
        for i, trainloader in enumerate(client_loaders):
            local_model = Net()
            is_mal = (i in malicious_ids)
            # Attackers train for ATTACKER_LOCAL_EPOCHS, honest clients for LOCAL_EPOCHS.
            local_epochs = ATTACKER_LOCAL_EPOCHS if is_mal else LOCAL_EPOCHS
            w, n = local_train_return_weights(local_model, trainloader, global_weights,
                                              local_epochs=local_epochs, lr=lr, is_malicious=is_mal, flip_mode=flip_mode)
            # Optionally amplify the attacker's delta from the global weights by scale_factor.
            if is_mal and scale:
                w = scale_update(w, global_weights, scale_factor)
            client_updates.append(w); client_ns.append(n)
        # FedAvg average (weighted by each client's number of examples)
        total_n = sum(client_ns)
        avg_weights = []
        for li in range(len(client_updates[0])):
            accum = np.zeros(client_updates[0][li].shape, dtype=client_updates[0][li].dtype)
            for ci in range(len(client_updates)):
                accum += client_updates[ci][li] * (client_ns[ci] / total_n)
            avg_weights.append(accum)
        global_weights = avg_weights
        set_weights(global_model, global_weights)
        acc = evaluate_global(global_model)
        # ASR is always measured for source_label -> target_label, in both flip modes.
        asr = compute_asr(global_model, src=source_label, tgt=target_label)
        acc_history.append(acc); asr_history.append(asr)
        print(f"Round {rnd} -> acc {acc*100:.2f}%, ASR {asr*100:.2f}%")
    # Record first-round and final-round metrics plus the full per-round series.
    results.append({
        "attacker_frac": frac,
        "k_attackers": k_attackers,
        "flip_mode": flip_mode,
        "scale": scale,
        "round_1_acc": acc_history[0],
        "round_final_acc": acc_history[-1],
        "round_1_asr": asr_history[0],
        "round_final_asr": asr_history[-1],
        "acc_series": acc_history,
        "asr_series": asr_history
    })

# Save summary CSV (flatten basic results; the per-round series are not written)
df = pd.DataFrame([{
    "attacker_frac": r["attacker_frac"],
    "k_attackers": r["k_attackers"],
    "flip_mode": r["flip_mode"],
    "scale": r["scale"],
    "final_acc": r["round_final_acc"],
    "final_asr": r["round_final_asr"]
} for r in results])
df.to_csv("results_labelflip_summary.csv", index=False)
print("\nSaved results_labelflip_summary.csv")


In [None]:
import pandas as pd
# Reload the label-flip summary CSV from disk and render it in the notebook.
df = pd.read_csv("results_labelflip_summary.csv")
display(df)


In [None]:
# CELL B — Backdoor attack (visual trigger) experiment (paste & run)
import numpy as np
import torch
import torch.nn.functional as F
from copy import deepcopy
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Config (tweak if you want)
NUM_ROUNDS = 5
LOCAL_EPOCHS = 1
ATTACKER_LOCAL_EPOCHS = 5       # local epochs used by malicious clients instead of LOCAL_EPOCHS
lr = 0.01
malicious_client_ids = {0, 1}   # attacker clients (change as needed)
target_label = 7                # the label attacker wants triggered inputs to be classified as
poison_frac = 0.2               # fraction of local mini-batch attacker poisons

def add_trigger_batch(x_batch, size=4, value=1.0):
    """Return a copy of a 4-D image batch with a square trigger in the bottom-right corner.

    The last ``size`` rows and columns of every channel of every image are
    set to ``value``. The input batch itself is not modified.
    """
    patched = x_batch.clone()
    _, _, height, width = patched.shape
    rows = slice(height - size, height)
    cols = slice(width - size, width)
    patched[:, :, rows, cols] = value
    return patched

# Build a triggered test set (trigger applied to all test images)
# The original labels are kept alongside in triggered_Y.
triggered_inputs = []
triggered_labels = []
for x, y in testloader:
    xt = add_trigger_batch(x, size=4, value=1.0)
    triggered_inputs.append(xt)
    triggered_labels.append(y)
# Concatenate the per-batch pieces into single tensors along the batch dimension.
triggered_X = torch.cat(triggered_inputs, dim=0)
triggered_Y = torch.cat(triggered_labels, dim=0)
print("Triggered test built:", triggered_X.shape, triggered_Y.shape)

def local_train_backdoor(model, trainloader, global_weights, local_epochs=1, lr=0.01, is_malicious=False, poison_frac=0.2):
    """Train ``model`` locally from ``global_weights``, optionally injecting a backdoor.

    Malicious clients stamp the trigger on a random subset of every
    mini-batch (``round(poison_frac * batch_size)`` samples, at least one)
    and relabel those samples to the global ``target_label``.

    Returns:
        ``(weights as NumPy arrays, number of examples in the loader's dataset)``.
    """
    set_weights(model, global_weights)
    model.to(device)
    model.train()
    sgd = torch.optim.SGD(model.parameters(), lr=lr)
    for _epoch in range(local_epochs):
        for inputs, targets in trainloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            if is_malicious:
                batch_len = inputs.size(0)
                n_poison = max(1, int(round(poison_frac * batch_len)))
                chosen = torch.randperm(batch_len)[:n_poison]
                inputs[chosen] = add_trigger_batch(inputs[chosen], size=4, value=1.0)
                targets[chosen] = target_label
            sgd.zero_grad()
            F.cross_entropy(model(inputs), targets).backward()
            sgd.step()
    return get_weights(model), len(trainloader.dataset)

@torch.no_grad()
def compute_asr_backdoor(model):
    """Return the fraction of triggered test images that ``model`` classifies as ``target_label``.

    Reads the global ``triggered_X`` tensor in chunks of 256; returns 0.0 if it is empty.
    """
    model.to(device)
    model.eval()
    step = 256
    hits = 0
    seen = 0
    for start in range(0, len(triggered_X), step):
        chunk = triggered_X[start:start + step].to(device)
        predicted = model(chunk).argmax(dim=1)
        seen += predicted.size(0)
        hits += (predicted == target_label).sum().item()
    if seen > 0:
        return hits / seen
    return 0.0

def evaluate_global(model):
    """Return the accuracy of ``model`` on the global ``testloader`` (0.0 if it is empty)."""
    model.to(device)
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            predicted = model(inputs).argmax(dim=1)
            n_correct += (predicted == targets).sum().item()
            n_seen += inputs.size(0)
    if n_seen > 0:
        return n_correct / n_seen
    return 0.0

# Run FedAvg with the backdoor attackers
# Fresh global model; its initial weights seed round 1.
global_model = Net().to(device)
global_weights = get_weights(global_model)

print("Starting backdoor experiment: attackers:", malicious_client_ids, "target_label:", target_label)
for rnd in range(1, NUM_ROUNDS+1):
    client_updates = []; client_ns = []
    for i, trainloader in enumerate(client_loaders):
        local_model = Net()
        is_mal = (i in malicious_client_ids)
        # Attackers train for ATTACKER_LOCAL_EPOCHS, honest clients for LOCAL_EPOCHS.
        local_epochs = ATTACKER_LOCAL_EPOCHS if is_mal else LOCAL_EPOCHS
        w, n = local_train_backdoor(local_model, trainloader, global_weights,
                                    local_epochs=local_epochs, lr=lr, is_malicious=is_mal, poison_frac=poison_frac)
        client_updates.append(w); client_ns.append(n)
    # FedAvg average (weighted by each client's number of examples)
    total_n = sum(client_ns)
    avg_weights = []
    for li in range(len(client_updates[0])):
        accum = np.zeros(client_updates[0][li].shape, dtype=client_updates[0][li].dtype)
        for ci in range(len(client_updates)):
            accum += client_updates[ci][li] * (client_ns[ci] / total_n)
        avg_weights.append(accum)
    global_weights = avg_weights
    set_weights(global_model, global_weights)

    # Clean-test accuracy and backdoor success rate on the triggered test set.
    acc = evaluate_global(global_model)
    asr_b = compute_asr_backdoor(global_model)
    print(f"Round {rnd} -> Global acc: {acc*100:.2f}%, Backdoor ASR -> target {target_label}: {asr_b*100:.2f}%")

print("Backdoor experiment finished.")


In [None]:
# CELL C — Defenses: Trimmed Mean and Krum vs Baseline (backdoor scenario)
import numpy as np
import torch
import torch.nn.functional as F
from copy import deepcopy
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------- Config (matches your last backdoor experiment) ----------
NUM_ROUNDS = 5
LOCAL_EPOCHS = 1
ATTACKER_LOCAL_EPOCHS = 5       # local epochs used by malicious clients instead of LOCAL_EPOCHS
lr = 0.01
malicious_client_ids = {0, 1}   # indices of the attacking clients
target_label = 7                # label the backdoor should force on triggered inputs
poison_frac = 0.2               # fraction of each attacker mini-batch that is poisoned

# ---------- Aggregator functions ----------
def trimmed_mean_aggregate(client_updates, trim_ratio=0.2):
    """Aggregate client updates with a coordinate-wise trimmed mean.

    For every scalar coordinate of every layer, the values from all clients
    are sorted, the ``floor(trim_ratio * n_clients)`` smallest and largest
    are discarded, and the remainder is averaged.

    Args:
        client_updates: One entry per client; each entry is a list of NumPy
            arrays (one per layer) with matching shapes across clients.
        trim_ratio: Fraction of clients trimmed from EACH end. Values that
            yield a trim count <= 0 disable trimming.

    Returns:
        List of aggregated layers with the same shapes as the inputs.

    Raises:
        ValueError: If ``client_updates`` is empty, or if trimming would
            discard every client (``2 * trim_count >= n_clients``), which
            would otherwise silently produce NaN weights.
    """
    n_clients = len(client_updates)
    if n_clients == 0:
        raise ValueError("trimmed_mean_aggregate needs at least one client update")

    # Loop-invariant: the same number of clients is trimmed for every layer.
    trim_count = max(0, int(np.floor(trim_ratio * n_clients)))
    if 2 * trim_count >= n_clients:
        raise ValueError(
            "trim_ratio={} trims {} values from each end, leaving nothing of {} clients".format(
                trim_ratio, trim_count, n_clients
            )
        )

    trimmed = []
    for layer_idx, reference in enumerate(client_updates[0]):
        # Stack the flattened layer of every client: shape (n_clients, N).
        stacked = np.stack([update[layer_idx].reshape(-1) for update in client_updates], axis=0)
        sorted_vals = np.sort(stacked, axis=0)
        kept = sorted_vals[trim_count:n_clients - trim_count, :]
        trimmed.append(np.mean(kept, axis=0).reshape(reference.shape))
    return trimmed

def krum_aggregate(client_updates, f=1):
    """Select a single client update with the Krum rule.

    Each client is scored by the sum of squared Euclidean distances to its
    ``n - f - 2`` nearest other clients (all layers flattened into one
    vector); the update with the lowest score is returned as-is.

    Args:
        client_updates: One list of NumPy layer arrays per client.
        f: Assumed number of Byzantine clients.

    Raises:
        ValueError: If ``n <= 2f + 2``.
    """
    n = len(client_updates)
    if n <= 2 * f + 2:
        raise ValueError("Not enough clients for Krum with f={} (need n > 2f+2)".format(f))

    # One flat parameter vector per client.
    vectors = [np.concatenate([layer.reshape(-1) for layer in update]) for update in client_updates]

    # Symmetric matrix of pairwise squared distances (diagonal stays zero).
    pairwise = np.zeros((n, n))
    for a, vec_a in enumerate(vectors):
        for b in range(a + 1, n):
            sq_dist = np.sum((vec_a - vectors[b]) ** 2)
            pairwise[a, b] = sq_dist
            pairwise[b, a] = sq_dist

    neighbours = n - f - 2
    # Position 0 of each sorted row is the client's own zero distance; skip it.
    scores = [np.sum(np.sort(row)[1:1 + neighbours]) for row in pairwise]
    return client_updates[int(np.argmin(scores))]

# ---------- Backdoor local training used earlier ----------
def add_trigger_batch(x_batch, size=4, value=1.0):
    """Return a copy of a 4-D image batch with a square trigger in the bottom-right corner.

    The last ``size`` rows and columns of every channel of every image are
    set to ``value``. The input batch itself is not modified.
    """
    patched = x_batch.clone()
    _, _, height, width = patched.shape
    rows = slice(height - size, height)
    cols = slice(width - size, width)
    patched[:, :, rows, cols] = value
    return patched

def local_train_backdoor(model, trainloader, global_weights, local_epochs=1, lr=0.01, is_malicious=False, poison_frac=0.2):
    """Train ``model`` locally from ``global_weights``, optionally injecting a backdoor.

    Malicious clients stamp the trigger on a random subset of every
    mini-batch (``round(poison_frac * batch_size)`` samples, at least one)
    and relabel those samples to the global ``target_label``.

    Returns:
        ``(weights as NumPy arrays, number of examples in the loader's dataset)``.
    """
    set_weights(model, global_weights)
    model.to(device)
    model.train()
    sgd = torch.optim.SGD(model.parameters(), lr=lr)
    for _epoch in range(local_epochs):
        for inputs, targets in trainloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            if is_malicious:
                batch_len = inputs.size(0)
                n_poison = max(1, int(round(poison_frac * batch_len)))
                chosen = torch.randperm(batch_len)[:n_poison]
                inputs[chosen] = add_trigger_batch(inputs[chosen], size=4, value=1.0)
                targets[chosen] = target_label
            sgd.zero_grad()
            F.cross_entropy(model(inputs), targets).backward()
            sgd.step()
    return get_weights(model), len(trainloader.dataset)

@torch.no_grad()
def compute_asr_backdoor(model):
    """Return the fraction of triggered test images that ``model`` classifies as ``target_label``.

    Reads the global ``triggered_X`` tensor in chunks of 256; returns 0.0 if it is empty.
    """
    model.to(device)
    model.eval()
    step = 256
    hits = 0
    seen = 0
    for start in range(0, len(triggered_X), step):
        chunk = triggered_X[start:start + step].to(device)
        predicted = model(chunk).argmax(dim=1)
        seen += predicted.size(0)
        hits += (predicted == target_label).sum().item()
    if seen > 0:
        return hits / seen
    return 0.0

def evaluate_global(model):
    """Return the accuracy of ``model`` on the global ``testloader`` (0.0 if it is empty)."""
    model.to(device)
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            predicted = model(inputs).argmax(dim=1)
            n_correct += (predicted == targets).sum().item()
            n_seen += inputs.size(0)
    if n_seen > 0:
        return n_correct / n_seen
    return 0.0

# ---------- Helper to run one experiment with chosen aggregator ----------
def run_backdoor_with_aggregator(aggregator_name="fedavg", trim_ratio=0.2, krum_f=1):
    """Run the backdoor experiment with the chosen server-side aggregator.

    Args:
        aggregator_name: ``"fedavg"``, ``"trimmed_mean"`` or ``"krum"``.
        trim_ratio: Passed to ``trimmed_mean_aggregate``.
        krum_f: Passed to ``krum_aggregate`` as ``f``.

    Returns:
        ``(acc_history, asr_history)``: per-round clean accuracy and backdoor
        attack success rate.

    Raises:
        ValueError: For an unknown ``aggregator_name`` (raised at aggregation
            time, i.e. after the first round of client training).
    """
    print("\n=== Running with aggregator:", aggregator_name, "===")
    server_model = Net().to(device)
    current_weights = get_weights(server_model)
    accs, asrs = [], []

    for round_no in range(1, NUM_ROUNDS + 1):
        updates, sizes = [], []
        for client_idx, loader in enumerate(client_loaders):
            client_model = Net()
            attacker = client_idx in malicious_client_ids
            epochs = ATTACKER_LOCAL_EPOCHS if attacker else LOCAL_EPOCHS
            new_weights, n_examples = local_train_backdoor(
                client_model, loader, current_weights,
                local_epochs=epochs, lr=lr, is_malicious=attacker, poison_frac=poison_frac,
            )
            updates.append(new_weights)
            sizes.append(n_examples)

        # Aggregate the client updates into the next global weights.
        if aggregator_name == "fedavg":
            total_examples = sum(sizes)
            merged = []
            for li in range(len(updates[0])):
                layer_sum = np.zeros_like(updates[0][li])
                for client_w, client_n in zip(updates, sizes):
                    layer_sum += client_w[li] * (client_n / total_examples)
                merged.append(layer_sum)
        elif aggregator_name == "trimmed_mean":
            merged = trimmed_mean_aggregate(updates, trim_ratio=trim_ratio)
        elif aggregator_name == "krum":
            merged = krum_aggregate(updates, f=krum_f)
        else:
            raise ValueError("Unknown aggregator")

        current_weights = merged
        set_weights(server_model, current_weights)

        round_acc = evaluate_global(server_model)
        round_asr = compute_asr_backdoor(server_model)
        accs.append(round_acc)
        asrs.append(round_asr)
        print(f"Round {round_no} -> acc: {round_acc*100:.2f}%, Backdoor ASR: {round_asr*100:.2f}%")

    return accs, asrs

# ---------- Run baseline, trimmed mean, and krum ----------
# Ensure triggered_X / triggered_Y exist from your previous Cell B run (they should)
# If not present, you must rebuild triggered_X as in Cell B before running this cell.
try:
    _ = triggered_X.shape
except NameError:
    # Fail early with a clearer message than the bare NameError.
    raise RuntimeError("triggered_X not found. Run Cell B (backdoor) first to build triggered test set.")

# FedAvg baseline
acc_b, asr_b = run_backdoor_with_aggregator("fedavg")

# Trimmed Mean
acc_t, asr_t = run_backdoor_with_aggregator("trimmed_mean", trim_ratio=0.2)

# Krum (choose f = number of suspected Byzantine clients; here f=2 for 2 attackers)
acc_k, asr_k = run_backdoor_with_aggregator("krum", krum_f=2)

# Report only the last round's metrics for each aggregator.
print("\nFinished defense experiments. Summary (final round):")
print(f"FedAvg -> acc {acc_b[-1]*100:.2f}%, ASR {asr_b[-1]*100:.2f}%")
print(f"TrimmedMean -> acc {acc_t[-1]*100:.2f}%, ASR {asr_t[-1]*100:.2f}%")
print(f"Krum -> acc {acc_k[-1]*100:.2f}%, ASR {asr_k[-1]*100:.2f}%")


In [None]:
# Trim sweep + simple FoolsGold aggregator (paste-run)
import numpy as np
import torch, torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from sklearn.metrics.pairwise import cosine_similarity

# Experiment config (matches your earlier backdoor run)
# These module-level names are read directly by the helpers and loops below.
NUM_ROUNDS = 5                # federated rounds per experiment (loops use range(1, NUM_ROUNDS+1))
LOCAL_EPOCHS = 1              # local epochs for non-malicious clients
ATTACKER_LOCAL_EPOCHS = 5     # local epochs for clients listed in malicious_client_ids
lr = 0.01                     # SGD learning rate passed to local training
malicious_client_ids = {0,1}  # positions in client_loaders treated as attackers
target_label = 7              # label assigned to poisoned samples and counted by the ASR metric
poison_frac = 0.2             # fraction of each attacker batch that gets the trigger

# Use same local_train_backdoor from earlier (redefine if necessary)
def add_trigger_batch(x_batch, size=4, value=1.0):
    """Return a copy of ``x_batch`` with a square patch stamped bottom-right.

    Args:
        x_batch: 4-D tensor; the last two dimensions are treated as height
            and width. The input itself is not modified.
        size: Side length of the patch, in pixels.
        value: Value written into every element of the patch, across all
            items and channels.

    Returns:
        A cloned tensor of the same shape with the patch applied.
    """
    stamped = x_batch.clone()
    # Unpacking exactly four dimensions also rejects inputs that are not 4-D.
    _, _, height, width = stamped.shape
    patch_rows = slice(height - size, height)
    patch_cols = slice(width - size, width)
    stamped[:, :, patch_rows, patch_cols] = value
    return stamped

def local_train_backdoor(model, trainloader, global_weights, local_epochs=1, lr=0.01, is_malicious=False, poison_frac=0.2):
    """Train one client locally, optionally injecting backdoor samples.

    Args:
        model: Network to train; moved to the module-level ``device``.
        trainloader: Iterable of ``(inputs, labels)`` batches; its
            ``.dataset`` length is reported as the client's sample count.
        global_weights: Weights handed to ``set_weights`` before training
            (presumably the current global model -- defined elsewhere).
        local_epochs: Number of passes over ``trainloader``.
        lr: Learning rate for plain SGD.
        is_malicious: When true, part of every batch is poisoned.
        poison_frac: Fraction of each batch to poison (at least one sample).

    Returns:
        Tuple ``(get_weights(model), len(trainloader.dataset))``.
    """
    set_weights(model, global_weights)
    model.to(device)
    model.train()
    sgd = torch.optim.SGD(model.parameters(), lr=lr)

    for _epoch in range(local_epochs):
        for inputs, labels in trainloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            if is_malicious:
                batch_len = inputs.size(0)
                # max(1, ...) forces at least one poisoned sample per batch,
                # even when poison_frac * batch_len rounds down to zero.
                n_poison = max(1, int(round(poison_frac * batch_len)))
                chosen = torch.randperm(batch_len)[:n_poison]
                # Stamp the trigger and relabel to the module-level target.
                inputs[chosen] = add_trigger_batch(inputs[chosen], size=4, value=1.0)
                labels[chosen] = target_label

            sgd.zero_grad()
            batch_loss = F.cross_entropy(model(inputs), labels)
            batch_loss.backward()
            sgd.step()

    trained_weights = get_weights(model)
    return trained_weights, len(trainloader.dataset)

# Evaluate helpers
@torch.no_grad()
def compute_asr_backdoor(model):
    """Return the fraction of ``triggered_X`` predicted as ``target_label``.

    Reads the module-level ``triggered_X``, ``target_label`` and ``device``.
    The model is moved to ``device`` and switched to eval mode; inputs are
    scored in chunks of 256. Returns 0.0 when there is nothing to score.
    """
    model.to(device)
    model.eval()

    chunk = 256
    hits = 0
    seen = 0
    for start in range(0, len(triggered_X), chunk):
        predicted = model(triggered_X[start:start + chunk].to(device)).argmax(dim=1)
        seen += predicted.size(0)
        hits += (predicted == target_label).sum().item()

    if seen > 0:
        return hits / seen
    return 0.0

def evaluate_global(model):
    """Return top-1 accuracy of ``model`` over the module-level ``testloader``.

    The model is moved to ``device`` and put in eval mode; gradients are
    disabled while scoring. Returns 0.0 if the loader yields no samples.
    """
    model.to(device)
    model.eval()

    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            predicted = model(inputs).argmax(dim=1)
            n_correct += (predicted == labels).sum().item()
            n_seen += inputs.size(0)

    if n_seen > 0:
        return n_correct / n_seen
    return 0.0

# Aggregators
def trimmed_mean_aggregate(client_updates, trim_ratio=0.2):
    """Aggregate client weights with a coordinate-wise trimmed mean.

    For every scalar coordinate of every layer, the values from all clients
    are sorted, the ``floor(trim_ratio * n_clients)`` smallest and the same
    number of largest values are dropped, and the rest are averaged.

    Args:
        client_updates: Non-empty list with one entry per client; each entry
            is a list of numpy arrays (one per layer) with matching shapes.
        trim_ratio: Fraction of clients trimmed from *each* end.

    Returns:
        List of numpy arrays, one per layer, shaped like the first client's
        layers.

    Raises:
        ValueError: If there are no clients, ``trim_ratio`` is negative, or
            trimming would discard every client (which previously produced
            silent NaN weights from the mean of an empty slice).
    """
    n_clients = len(client_updates)
    if n_clients == 0:
        raise ValueError("client_updates must contain at least one client")
    if trim_ratio < 0:
        raise ValueError(f"trim_ratio must be non-negative, got {trim_ratio}")

    # The trim count does not depend on the layer, so compute it once.
    trim_count = int(np.floor(trim_ratio * n_clients))
    if 2 * trim_count >= n_clients:
        raise ValueError(
            f"trim_ratio={trim_ratio} trims {trim_count} values from each end, "
            f"leaving nothing to average for {n_clients} clients"
        )

    trimmed = []
    for layer_idx, reference in enumerate(client_updates[0]):
        # Rows are clients, columns are the flattened coordinates of the layer.
        stacked = np.stack(
            [update[layer_idx].reshape(-1) for update in client_updates], axis=0
        )
        sorted_vals = np.sort(stacked, axis=0)
        # With trim_count == 0 this slice keeps every row (plain mean).
        kept = sorted_vals[trim_count:n_clients - trim_count, :]
        trimmed.append(np.mean(kept, axis=0).reshape(reference.shape))
    return trimmed

def fools_gold_weights(client_updates, global_weights, eps=1e-8):
    """Compute per-client aggregation weights from update similarity.

    This is a simplified, single-round scheme: clients whose update
    direction closely matches some other client's are down-weighted.

    Args:
        client_updates: List of per-client weight lists (numpy arrays).
        global_weights: Weight list the clients started from; deltas are
            ``client - global`` per layer.
        eps: Floor applied to each raw weight before normalization.

    Returns:
        1-D numpy array with one weight per client, summing to 1.
    """
    n_layers = len(global_weights)

    # One flat delta vector per client, stacked into an (n_clients, D) matrix.
    flat_deltas = []
    for update in client_updates:
        pieces = [(update[i] - global_weights[i]).reshape(-1) for i in range(n_layers)]
        flat_deltas.append(np.concatenate(pieces))
    delta_matrix = np.stack(flat_deltas, axis=0)

    similarity = cosine_similarity(delta_matrix)
    # Self-similarity is zeroed rather than excluded, so the row maximum
    # below can never be negative.
    np.fill_diagonal(similarity, 0.0)
    closest = similarity.max(axis=1)

    # High similarity to a peer => low weight; eps keeps every weight positive.
    scores = np.clip(1.0 - closest, a_min=eps, a_max=None)
    return scores / scores.sum()

def aggregate_with_weights(client_updates, weights):
    """Return the weighted sum of client weight lists, layer by layer.

    Args:
        client_updates: List of per-client weight lists (numpy arrays).
        weights: Sequence with one scalar weight per client.

    Returns:
        List of numpy arrays, one per layer, each with the shape and dtype
        of the first client's corresponding layer.
    """
    combined = []
    for layer_idx, template in enumerate(client_updates[0]):
        # Accumulate in the layer's own dtype.
        layer_sum = np.zeros_like(template)
        for client_idx, update in enumerate(client_updates):
            layer_sum += update[layer_idx] * weights[client_idx]
        combined.append(layer_sum)
    return combined

# Run baseline (FedAvg) for reference
# Undefended run: every client (including the attackers) contributes in
# proportion to its dataset size.
print("Running baseline FedAvg (for comparison)...")
# Fresh global model; its weights seed the first round.
global_model = Net().to(device); global_weights = get_weights(global_model)
for rnd in range(1, NUM_ROUNDS+1):
    client_updates = []; client_ns = []
    for i, trainloader in enumerate(client_loaders):
        # Each client trains a new local model starting from global_weights.
        local_model = Net()
        is_mal = (i in malicious_client_ids)
        # Attackers train for more local epochs than honest clients.
        local_epochs = ATTACKER_LOCAL_EPOCHS if is_mal else LOCAL_EPOCHS
        w, n = local_train_backdoor(local_model, trainloader, global_weights,
                                    local_epochs=local_epochs, lr=lr, is_malicious=is_mal, poison_frac=poison_frac)
        client_updates.append(w); client_ns.append(n)
    # FedAvg average: per-layer sum of client weights scaled by n_i / total_n.
    total_n = sum(client_ns)
    avg = []
    for li in range(len(client_updates[0])):
        # Accumulator keeps the layer's shape and dtype.
        accum = np.zeros(client_updates[0][li].shape, dtype=client_updates[0][li].dtype)
        for ci in range(len(client_updates)):
            accum += client_updates[ci][li] * (client_ns[ci]/total_n)
        avg.append(accum)
    # The aggregate becomes the starting point for the next round.
    global_weights = avg; set_weights(global_model, global_weights)
    print(f"Round {rnd} -> acc {evaluate_global(global_model)*100:.2f} %, ASR {compute_asr_backdoor(global_model)*100:.2f}%")

# Trimmed mean sweep
# Repeats the full backdoor experiment once per trim ratio, restarting from a
# fresh global model each time, and records the final (accuracy, ASR) pair.
# NOTE(review): trimmed_mean_aggregate trims floor(trim * n_clients) values per
# side, so if client_loaders holds 10 clients, trim=0.05 trims nothing -- confirm.
trim_values = [0.05, 0.1, 0.2, 0.3]
trim_results = {}
for trim in trim_values:
    print("\nTrim ratio:", trim)
    global_model = Net().to(device); global_weights = get_weights(global_model)
    for rnd in range(1, NUM_ROUNDS+1):
        client_updates = []; client_ns = []
        for i, trainloader in enumerate(client_loaders):
            local_model = Net()
            is_mal = (i in malicious_client_ids)
            # Attackers train for more local epochs than honest clients.
            local_epochs = ATTACKER_LOCAL_EPOCHS if is_mal else LOCAL_EPOCHS
            w, n = local_train_backdoor(local_model, trainloader, global_weights,
                                        local_epochs=local_epochs, lr=lr, is_malicious=is_mal, poison_frac=poison_frac)
            client_updates.append(w); client_ns.append(n)
        # client_ns is collected but unused: the trimmed mean is unweighted.
        avg = trimmed_mean_aggregate(client_updates, trim_ratio=trim)
        global_weights = avg; set_weights(global_model, global_weights)
        print(f"Round {rnd} -> acc {evaluate_global(global_model)*100:.2f} %, ASR {compute_asr_backdoor(global_model)*100:.2f}%")
    # Re-evaluates the model after the last round to store the final metrics.
    trim_results[trim] = (evaluate_global(global_model), compute_asr_backdoor(global_model))

# FoolsGold-style aggregator (simple round-local version)
# Same experiment loop as above, but the aggregation weights come from
# fools_gold_weights, recomputed from scratch every round (no history is kept).
print("\nRunning FoolsGold-style round-local aggregator...")
global_model = Net().to(device); global_weights = get_weights(global_model)
for rnd in range(1, NUM_ROUNDS+1):
    client_updates = []; client_ns = []
    for i, trainloader in enumerate(client_loaders):
        local_model = Net()
        is_mal = (i in malicious_client_ids)
        # Attackers train for more local epochs than honest clients.
        local_epochs = ATTACKER_LOCAL_EPOCHS if is_mal else LOCAL_EPOCHS
        w, n = local_train_backdoor(local_model, trainloader, global_weights,
                                    local_epochs=local_epochs, lr=lr, is_malicious=is_mal, poison_frac=poison_frac)
        client_updates.append(w); client_ns.append(n)
    # compute FG weights from each client's delta relative to this round's
    # starting weights; client_ns is collected but not used for weighting.
    wgts = fools_gold_weights(client_updates, global_weights)
    avg = aggregate_with_weights(client_updates, wgts)
    global_weights = avg; set_weights(global_model, global_weights)
    print(f"Round {rnd} -> acc {evaluate_global(global_model)*100:.2f} %, ASR {compute_asr_backdoor(global_model)*100:.2f}%")

# Report the trim sweep: each trim_results value is (final accuracy, final ASR).
print("\nTrim sweep results (final):")
for t,r in trim_results.items():
    print(f"trim {t} -> acc {r[0]*100:.2f}%, ASR {r[1]*100:.2f}%")


In [None]:
# === Plot Accuracy & ASR vs Rounds for FedAvg, TrimmedMean(0.2), and Krum ===
import numpy as np
import torch, torch.nn.functional as F
import matplotlib.pyplot as plt
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Helper functions reused ---
@torch.no_grad()
def compute_asr_backdoor(model):
    """Return the backdoor attack success rate of ``model``.

    The rate is the share of samples in the module-level ``triggered_X``
    that the model classifies as ``target_label``. Inputs are processed in
    chunks of 256 on ``device``; 0.0 is returned when no samples are scored.
    """
    model.to(device)
    model.eval()

    step = 256
    n_target = 0
    n_scored = 0
    for offset in range(0, len(triggered_X), step):
        chunk = triggered_X[offset:offset + step].to(device)
        logits = model(chunk)
        labels_hat = logits.argmax(dim=1)
        n_scored += labels_hat.size(0)
        n_target += (labels_hat == target_label).sum().item()

    if n_scored > 0:
        return n_target / n_scored
    return 0.0

def evaluate_global(model):
    """Return top-1 accuracy of ``model`` over the module-level ``testloader``.

    The model is moved to ``device`` and put in eval mode; gradients are
    disabled while scoring.

    Returns:
        ``correct / total`` as a float, or 0.0 when the loader yields no
        samples (instead of raising ``ZeroDivisionError``), matching the
        empty-input behaviour of ``compute_asr_backdoor``.
    """
    model.to(device); model.eval()
    correct = total = 0
    with torch.no_grad():
        for x, y in testloader:
            x, y = x.to(device), y.to(device)
            preds = model(x).argmax(dim=1)
            correct += (preds == y).sum().item()
            total += x.size(0)
    # Guard the empty-loader case so callers get a number, not an exception.
    return correct / total if total > 0 else 0.0

# --- Aggregators ---
def fedavg_aggregate(client_updates, client_ns):
    """Average client weights, weighting each client by its sample count.

    Args:
        client_updates: List of per-client weight lists (numpy arrays).
        client_ns: Number of training samples per client, in the same order.

    Returns:
        List of numpy arrays, one per layer, with the shape and dtype of the
        first client's layers.
    """
    total_samples = sum(client_ns)
    merged = []
    for layer_idx, template in enumerate(client_updates[0]):
        layer_sum = np.zeros(template.shape, dtype=template.dtype)
        for client_idx, update in enumerate(client_updates):
            # Each client's share of the total sample count.
            share = client_ns[client_idx] / total_samples
            layer_sum += update[layer_idx] * share
        merged.append(layer_sum)
    return merged

def trimmed_mean_aggregate(client_updates, trim_ratio=0.2):
    """Aggregate client weights with a coordinate-wise trimmed mean.

    Per coordinate, the client values are sorted, ``floor(trim_ratio * n)``
    values are removed from each end, and the remainder is averaged.

    Args:
        client_updates: Non-empty list of per-client weight lists (numpy
            arrays with matching shapes across clients).
        trim_ratio: Fraction of clients trimmed from *each* end.

    Returns:
        List of numpy arrays, one per layer, shaped like the first client's
        layers.

    Raises:
        ValueError: If there are no clients, ``trim_ratio`` is negative, or
            trimming would remove every client. Without this check the mean
            of an empty slice silently yields NaN weights.
    """
    n = len(client_updates)
    if n == 0:
        raise ValueError("client_updates must contain at least one client")
    if trim_ratio < 0:
        raise ValueError(f"trim_ratio must be non-negative, got {trim_ratio}")

    # Same for every layer, so it is computed once outside the loop.
    trim = int(np.floor(trim_ratio * n))
    if 2 * trim >= n:
        raise ValueError(
            f"trim_ratio={trim_ratio} trims {trim} values from each end, "
            f"leaving nothing to average for {n} clients"
        )

    trimmed = []
    for li, reference in enumerate(client_updates[0]):
        # Shape (n, layer_size): one flattened row per client.
        stacked = np.stack([w[li].reshape(-1) for w in client_updates], axis=0)
        sorted_vals = np.sort(stacked, axis=0)
        # trim == 0 keeps all rows, i.e. an ordinary unweighted mean.
        kept = sorted_vals[trim:n - trim, :]
        trimmed.append(np.mean(kept, axis=0).reshape(reference.shape))
    return trimmed

def krum_aggregate(client_updates, f=2):
    """Select a single client's update using the Krum rule.

    Each client is scored by the sum of squared Euclidean distances to its
    ``n - f - 2`` nearest *other* clients (over all layers flattened into
    one vector). The client with the lowest score wins.

    Args:
        client_updates: List of per-client weight lists (numpy arrays).
        f: Assumed number of Byzantine (malicious) clients.

    Returns:
        The winning client's weight list. This is the same list object that
        was passed in, not a copy.

    Raises:
        ValueError: If ``f`` is negative or ``n - f - 2 < 1``. Previously
            such inputs produced an empty or negative-index slice, so every
            score was meaningless and a winner was still returned silently.
    """
    n = len(client_updates)
    n_neighbors = n - f - 2
    if f < 0 or n_neighbors < 1:
        raise ValueError(
            f"Krum needs n - f - 2 >= 1 and f >= 0 (got n={n}, f={f})"
        )

    flats = [np.concatenate([layer.ravel() for layer in w]) for w in client_updates]

    # Symmetric matrix of pairwise squared distances; the diagonal stays 0.
    dmat = np.zeros((n, n))
    for i in range(n):
        for j in range(i + 1, n):
            dmat[i, j] = dmat[j, i] = np.sum((flats[i] - flats[j]) ** 2)

    scores = []
    for i in range(n):
        # Drop the self-distance explicitly, then keep the nearest neighbours.
        others = np.delete(dmat[i], i)
        scores.append(np.sum(np.sort(others)[:n_neighbors]))

    winner = int(np.argmin(scores))
    return client_updates[winner]

# --- Run and record results for 3 aggregators ---
def run_experiment(agg_name, trim_ratio=0.2, num_rounds=5, krum_f=2):
    """Run the backdoor experiment with one aggregator and record metrics.

    Every round, each client in the module-level ``client_loaders`` trains a
    fresh ``Net`` from the current global weights (attackers listed in
    ``malicious_client_ids`` poison their batches and train for
    ``ATTACKER_LOCAL_EPOCHS``). The chosen aggregator then produces the next
    global weights, which are evaluated for clean accuracy and backdoor ASR.

    Args:
        agg_name: One of ``"fedavg"``, ``"trimmed_mean"`` or ``"krum"``.
        trim_ratio: Forwarded to ``trimmed_mean_aggregate``; ignored by the
            other aggregators (so ``None`` is acceptable for them).
        num_rounds: Number of federated rounds to run.
        krum_f: Assumed number of Byzantine clients, forwarded to
            ``krum_aggregate``.

    Returns:
        Tuple ``(acc_history, asr_history)`` of numpy arrays with one entry
        per round.

    Raises:
        ValueError: If ``agg_name`` is not a known aggregator. Without this
            check an unknown name skipped aggregation entirely and reported
            metrics for a global model that never changed.
    """
    # Fail fast, before any training time is spent.
    if agg_name not in ("fedavg", "trimmed_mean", "krum"):
        raise ValueError(f"Unknown aggregator: {agg_name!r}")

    global_model = Net().to(device)
    global_weights = get_weights(global_model)
    acc_hist, asr_hist = [], []
    for rnd in range(1, num_rounds + 1):
        client_updates, client_ns = [], []
        for i, trainloader in enumerate(client_loaders):
            local_model = Net()
            is_mal = i in malicious_client_ids
            local_epochs = ATTACKER_LOCAL_EPOCHS if is_mal else LOCAL_EPOCHS
            w, n = local_train_backdoor(local_model, trainloader, global_weights,
                                        local_epochs=local_epochs, lr=lr,
                                        is_malicious=is_mal, poison_frac=poison_frac)
            client_updates.append(w); client_ns.append(n)

        if agg_name == "fedavg":
            global_weights = fedavg_aggregate(client_updates, client_ns)
        elif agg_name == "trimmed_mean":
            global_weights = trimmed_mean_aggregate(client_updates, trim_ratio)
        else:  # "krum" -- the only remaining validated name
            global_weights = krum_aggregate(client_updates, f=krum_f)

        set_weights(global_model, global_weights)
        acc = evaluate_global(global_model)
        asr = compute_asr_backdoor(global_model)
        acc_hist.append(acc); asr_hist.append(asr)
        print(f"[{agg_name}] Round {rnd}: acc {acc*100:.2f}%, ASR {asr*100:.2f}%")
    return np.array(acc_hist), np.array(asr_hist)

# Run the experiment once per aggregator.
# histories maps aggregator name -> (accuracy array, ASR array), one entry per round.
histories = {}
for name in ["fedavg","trimmed_mean","krum"]:
    print(f"\n=== Running {name} ===")
    # trim_ratio is only consumed in the trimmed_mean branch of run_experiment,
    # so passing None for the other aggregators is harmless.
    histories[name] = run_experiment(name, trim_ratio=0.2 if name=="trimmed_mean" else None)

# --- Plot ---
# Two figures: clean accuracy and backdoor ASR, each as a percentage per round.
# The x-axis is hard-coded to rounds 1..5, which must match the number of
# rounds recorded in each history array.
rounds = np.arange(1,6)
plt.figure(figsize=(7,4))
# Colours are paired with aggregators in dict insertion order.
for name,c in zip(histories.keys(),["r","g","b"]):
    # Index 0 of each history tuple is the accuracy array.
    plt.plot(rounds, histories[name][0]*100, marker="o", label=f"{name} Accuracy", color=c)
plt.xlabel("Round"); plt.ylabel("Accuracy (%)"); plt.title("Clean Accuracy vs Rounds")
plt.legend(); plt.grid(True); plt.show()

plt.figure(figsize=(7,4))
for name,c in zip(histories.keys(),["r","g","b"]):
    # Index 1 of each history tuple is the ASR array.
    plt.plot(rounds, histories[name][1]*100, marker="o", label=f"{name} ASR", color=c)
plt.xlabel("Round"); plt.ylabel("ASR (%)"); plt.title("Backdoor Attack Success Rate vs Rounds")
plt.legend(); plt.grid(True); plt.show()
