### Untargeted Attack

In [15]:
# -------------------
# Full Implementation with Attack Selection, Percentage Control, and Real Defense Methods
# -------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
import pandas as pd
from torch.utils.data import DataLoader, Dataset, Subset, TensorDataset
from torchvision import datasets, transforms
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score
import math

# -------------------
# CNN Model
# -------------------
class SimpleCNN(nn.Module):
    """Small CNN classifier: two 3x3 convolutions, one 2x2 max-pool, two linear layers.

    ``fc1`` takes 9216 flattened features, which is what a 1x28x28 input yields
    (64 channels x 12 x 12 after both convolutions and the pooling step).
    ``forward`` returns raw logits for 10 classes; no softmax is applied.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in this order so parameter initialisation is unchanged.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.dropout1 = nn.Dropout(0.25)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # Convolutional feature extractor.
        features = F.relu(self.conv2(F.relu(self.conv1(x))))
        # Down-sample, regularise, and flatten everything except the batch axis.
        pooled = self.dropout1(F.max_pool2d(features, 2))
        hidden = F.relu(self.fc1(torch.flatten(pooled, 1)))
        return self.fc2(hidden)

# -------------------
# Data Partition (IID or non-IID using Dirichlet)
# -------------------
def partition_data(dataset, num_clients=100, alpha=0.5, iid=True):
    """Split ``dataset`` across clients and wrap every shard in a DataLoader.

    Args:
        dataset: Map-style dataset. The non-IID path additionally reads
            ``dataset.targets`` and only distributes samples labelled 0..9.
        num_clients: Number of shards (clients) to create.
        alpha: Dirichlet concentration used by the non-IID path; smaller
            values give more skewed per-client class proportions.
        iid: If true, split uniformly at random; otherwise split each class
            across clients with Dirichlet-sampled proportions.

    Returns:
        dict mapping client index (0..num_clients-1) to a
        ``DataLoader(batch_size=32, shuffle=True)`` over that client's shard.
    """
    if iid:
        # random_split() requires the lengths to sum to len(dataset); hand the
        # remainder to the first clients instead of raising when the dataset
        # size is not a multiple of num_clients.
        base, extra = divmod(len(dataset), num_clients)
        lengths = [base + 1 if i < extra else base for i in range(num_clients)]
        data_split = torch.utils.data.random_split(dataset, lengths)
    else:
        labels = np.array(dataset.targets)
        idx_by_class = [np.where(labels == i)[0] for i in range(10)]
        data_split = [[] for _ in range(num_clients)]
        for idx in idx_by_class:
            np.random.shuffle(idx)
            proportions = np.random.dirichlet(np.repeat(alpha, num_clients))
            # Cumulative proportions become cut points inside this class's indices.
            proportions = (np.cumsum(proportions) * len(idx)).astype(int)[:-1]
            splits = np.split(idx, proportions)
            for client_id, split in enumerate(splits):
                data_split[client_id].extend(split)
        data_split = [Subset(dataset, idxs) for idxs in data_split]
    return {i: DataLoader(data_split[i], batch_size=32, shuffle=True) for i in range(num_clients)}

# -------------------
# Attack Injection
# -------------------
# IDs of the clients whose data was attacked; reassigned by apply_combined_attacks()
# and reset at the start of federated_training().
malicious_client_ids = []

def apply_label_flipping_attack(loader):
    """Return a new loader whose labels are swapped pairwise (0<->1, 2<->3, ..., 8<->9).

    Args:
        loader: Iterable of ``(x, y)`` batches; ``y`` must hold labels in 0..9
            (any other label raises ``KeyError``).

    Returns:
        ``DataLoader(batch_size=32, shuffle=True)`` over the flipped data, or
        ``loader`` itself when it yields no batches (nothing to attack).
    """
    flip_map = {0: 1, 1: 0, 2: 3, 3: 2, 4: 5, 5: 4, 6: 7, 7: 6, 8: 9, 9: 8}
    inputs, flipped = [], []
    for x, y in loader:
        inputs.append(x)
        flipped.append(torch.tensor([flip_map[int(label)] for label in y]))
    if not inputs:
        # A client can end up with no samples after a skewed split; concatenating
        # an empty list would raise, so leave such a loader untouched.
        return loader
    dataset = TensorDataset(torch.cat(inputs), torch.cat(flipped))
    return DataLoader(dataset, batch_size=32, shuffle=True)

def apply_feature_manipulation_attack(loader):
    """Return a new loader whose inputs are corrupted with masked Gaussian noise.

    For every batch a random mask keeps roughly 30% of the pixels and the
    remaining pixels are replaced by Gaussian noise scaled by 0.7. Labels are
    left unchanged.

    Args:
        loader: Iterable of ``(x, y)`` batches.

    Returns:
        ``DataLoader(batch_size=32, shuffle=True)`` over the corrupted data, or
        ``loader`` itself when it yields no batches.
    """
    inputs, labels = [], []
    for x, y in loader:
        noise = torch.randn_like(x) * 0.7
        mask = (torch.rand_like(x) > 0.7).float()  # 1 where the original pixel survives
        inputs.append(x * mask + noise * (1 - mask))
        labels.append(y)
    if not inputs:
        # Empty client shard: concatenating nothing would raise, so pass it through.
        return loader
    dataset = TensorDataset(torch.cat(inputs), torch.cat(labels))
    return DataLoader(dataset, batch_size=32, shuffle=True)

def apply_poisoning_attack(loader):
    """Return a new loader with brightened inputs carrying a white corner trigger.

    Every image is shifted by +0.3 (clamped to [0, 1]) and its bottom-right
    3x3 patch is set to 1.0. Labels are left unchanged.

    Args:
        loader: Iterable of ``(x, y)`` batches; ``x`` is indexed as
            ``[:, :, -3:, -3:]``, so it must have four dimensions.

    Returns:
        ``DataLoader(batch_size=32, shuffle=True)`` over the poisoned data, or
        ``loader`` itself when it yields no batches.
    """
    inputs, labels = [], []
    for x, y in loader:
        shifted = torch.clamp(x + 0.3 * torch.ones_like(x), 0, 1)  # global intensity shift
        trigger = torch.zeros_like(shifted)
        trigger[:, :, -3:, -3:] = 1.0  # white square in the bottom-right corner
        inputs.append(torch.clamp(shifted + trigger, 0, 1))
        labels.append(y)
    if not inputs:
        # Empty client shard: concatenating nothing would raise, so pass it through.
        return loader
    dataset = TensorDataset(torch.cat(inputs), torch.cat(labels))
    return DataLoader(dataset, batch_size=32, shuffle=True)

# -------------------
# Combined Attack Function
# -------------------
def apply_combined_attacks(client_data, attack_types, attack_ratios):
    """Attack a random subset of clients, one independent draw per attack type.

    Args:
        client_data: dict ``{client_id: DataLoader}``; modified in place.
        attack_types: attack names to enable; 'label_flipping',
            'feature_manipulation' and 'poisoning' are acted on.
        attack_ratios: dict ``{attack_name: fraction_of_clients}``; a missing
            entry counts as 0.

    Returns:
        The same ``client_data`` dict, with attacked clients' loaders replaced.
        The union of all drawn clients is also stored in the module-level
        ``malicious_client_ids``.
    """
    global malicious_client_ids

    all_ids = list(client_data.keys())
    population = len(all_ids)

    # Victims are sampled per attack, so one client may receive several attacks.
    victims = {
        name: random.sample(all_ids, int(population * attack_ratios.get(name, 0)))
        for name in attack_types
    }

    attacked_ids = list(set().union(*victims.values()))
    malicious_client_ids = attacked_ids

    for cid in attacked_ids:
        current = client_data[cid]
        # Attacks are chained in a fixed order: labels, then features, then trigger.
        if cid in victims.get('label_flipping', ()):
            current = apply_label_flipping_attack(current)
        if cid in victims.get('feature_manipulation', ()):
            current = apply_feature_manipulation_attack(current)
        if cid in victims.get('poisoning', ()):
            current = apply_poisoning_attack(current)
        client_data[cid] = current

    return client_data

# -------------------
# Evaluation Function
# -------------------
def evaluate_model(model, test_loader, device):
    """Evaluate ``model`` on ``test_loader`` and return classification metrics.

    Args:
        model: Module producing per-class logits.
        test_loader: Iterable of ``(data, target)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Tuple ``(accuracy, precision, recall, f1, fpr, tpr, test_loss)``.
        Precision, recall and F1 are macro-averaged; FPR and TPR are the mean
        of the per-class rates taken from a 10-class confusion matrix;
        ``test_loss`` is the mean cross-entropy per sample.

    Raises:
        ValueError: If a batch is not a two-element list or tuple.
    """
    model.eval()
    labels_seen, labels_predicted = [], []
    loss_sum = 0.0
    n_correct = 0
    n_samples = 0

    with torch.no_grad():
        for batch in test_loader:
            if not (isinstance(batch, (list, tuple)) and len(batch) == 2):
                raise ValueError("Each batch in test_loader must be a (data, target) tuple.")

            data = batch[0].to(device)
            target = batch[1].to(device)
            logits = model(data)
            loss_sum += F.cross_entropy(logits, target, reduction='sum').item()
            predicted = torch.max(logits.data, 1)[1]
            n_correct += (predicted == target).sum().item()
            n_samples += target.size(0)
            labels_seen.extend(target.cpu().numpy())
            labels_predicted.extend(predicted.cpu().numpy())

    # Per-class counts: rows of the matrix are true labels, columns are predictions.
    cm = confusion_matrix(labels_seen, labels_predicted, labels=list(range(10)))
    tp = np.diag(cm)
    fp = np.sum(cm, axis=0) - tp
    fn = np.sum(cm, axis=1) - tp
    tn = np.sum(cm) - (tp + fp + fn)

    accuracy = n_correct / n_samples
    precision = precision_score(labels_seen, labels_predicted, average='macro', zero_division=0)
    recall = recall_score(labels_seen, labels_predicted, average='macro', zero_division=0)
    f1 = f1_score(labels_seen, labels_predicted, average='macro', zero_division=0)
    # The small epsilon keeps classes with an empty denominator from dividing by zero.
    fpr = np.mean(fp / (fp + tn + 1e-6))
    tpr = np.mean(tp / (tp + fn + 1e-6))
    mean_loss = loss_sum / n_samples

    return accuracy, precision, recall, f1, fpr, tpr, mean_loss

# -------------------
# Local Training Function (with Loss/Accuracy Logging)
# -------------------
def local_train(model, train_loader, epochs, device):
    """Train ``model`` in place with Adam (lr=0.001) and print per-epoch statistics.

    Args:
        model: Module producing per-class logits.
        train_loader: Iterable of ``(data, target)`` batches.
        epochs: Number of passes over ``train_loader``.
        device: Device the batches are moved to.

    Returns:
        ``model.state_dict()`` after training.
    """
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    for epoch in range(epochs):
        correct = 0
        total = 0
        total_loss = 0.0
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()
            # cross_entropy averages over the batch; re-weight by batch size so
            # the epoch figure is a per-sample mean.
            total_loss += loss.item() * data.size(0)
            correct += (output.argmax(dim=1) == target).sum().item()
            total += data.size(0)
        # A client may hold no samples; report zeros instead of dividing by zero.
        avg_loss = total_loss / total if total > 0 else 0.0
        avg_acc = correct / total if total > 0 else 0.0
        print(f"Local Epoch {epoch+1} - Loss: {avg_loss:.4f}, Accuracy: {avg_acc:.4f}")
    return model.state_dict()

# -------------------
# Additional Metric Functions
# -------------------
def compute_model_drift(prev_weights, new_weights):
    """Return the mean per-tensor L2 distance between two state dicts as a float.

    Args:
        prev_weights: dict ``{name: tensor}``; its keys drive the comparison.
        new_weights: dict holding a same-shaped tensor for every key of
            ``prev_weights``.

    Returns:
        Plain Python float (0.0 for an empty ``prev_weights``), so the value
        can be formatted and written to CSV without a ``tensor(...)`` wrapper.
    """
    if not prev_weights:
        return 0.0
    total = 0.0
    for name, old in prev_weights.items():
        # Compare on CPU in floating point so mixed-device inputs and integer
        # buffers do not make the subtraction or the norm fail.
        delta = old.detach().cpu().float() - new_weights[name].detach().cpu().float()
        total += torch.norm(delta).item()
    return total / len(prev_weights)

def compute_entropy(weights):
    """Return the Shannon entropy (natural log) of ``weights``.

    A 1e-10 offset inside the logarithm keeps zero weights from producing
    ``log(0)``.
    """
    probs = np.array(weights)
    log_probs = np.log(probs + 1e-10)
    return -np.sum(probs * log_probs)

def detect_malicious_clients(true_ids, detected_ids):
    """Score a set of flagged client IDs against the true malicious IDs.

    Returns:
        Tuple ``(attack_detection_rate, exclusion_rate)`` where the first is
        the share of true IDs that were flagged and the second is the number
        of flagged IDs divided by the size of the union of both sets. A 1e-6
        term in each denominator makes empty inputs yield 0.0.
    """
    actual = set(true_ids)
    flagged = set(detected_ids)
    hits = len(actual.intersection(flagged))
    # hits + misses is simply the number of distinct true IDs.
    attack_detection_rate = hits / (len(actual) + 1e-6)
    exclusion_rate = len(flagged) / (len(actual.union(flagged)) + 1e-6)
    return attack_detection_rate, exclusion_rate




# NOTE(review): prev_global_model is not referenced by any code visible in this
# section — confirm it is still needed.
prev_global_model = {}
# -------------------
# Aggregation Functions (Full Implementations)
# -------------------

# Module-level bookkeeping written by every aggregation function and read by
# federated_training() after each round. The `global` statements below have no
# effect at module scope; the assignments alone create the names.
# Indices (positions in the `updates` list) that the last aggregation flagged.
global detected_malicious_ids
detected_malicious_ids = []
# Per-update weights reported by the last aggregation.
global aggregation_weights
aggregation_weights = []

def fltrust(updates):
    """Aggregate updates weighted by their cosine similarity to a reference update.

    ``updates[0]`` stands in for the trusted server model. Each update is
    scored by its mean per-tensor cosine similarity to that reference, the
    scores are turned into weights with a softmax, and the weighted sum of the
    updates is returned. Updates whose weight is below 0.01 are recorded in
    ``detected_malicious_ids``; the weights go to ``aggregation_weights``.
    """
    global detected_malicious_ids, aggregation_weights

    reference_model = updates[0]  # trusted server model, in real FLTrust, it should be fixed
    scores = []
    for candidate in updates:
        similarity_total = 0
        for name in candidate:
            similarity_total = similarity_total + torch.nn.functional.cosine_similarity(
                candidate[name].flatten(), reference_model[name].flatten(), dim=0
            )
        scores.append(similarity_total / len(candidate))

    weights = torch.softmax(torch.tensor(scores), dim=0)
    detected_malicious_ids = [idx for idx, weight in enumerate(weights) if weight < 0.01]
    aggregation_weights = weights.tolist()

    aggregated = {}
    for name in updates[0]:
        blended = 0
        for weight, update in zip(weights, updates):
            blended = blended + weight * update[name]
        aggregated[name] = blended
    return aggregated

def flame(updates):
    """Aggregate updates with softmax weights that favour small-norm updates.

    Each update is scored by the sum of the L2 norms of its tensors; the
    weights are ``softmax(-norms)`` and the weighted sum of the updates is
    returned. No update is flagged, so ``detected_malicious_ids`` is emptied;
    the weights are stored in ``aggregation_weights``.
    """
    global detected_malicious_ids, aggregation_weights
    detected_malicious_ids = []
    # Norms and weights are computed once (they were previously computed twice).
    norms = [sum(torch.norm(u[k]) for k in u) for u in updates]
    weights = torch.softmax(-torch.tensor(norms), dim=0)
    aggregation_weights = weights.tolist()
    return {k: sum(weights[i] * updates[i][k] for i in range(len(updates))) for k in updates[0]}

def flcert(updates):
    """Average only the updates that point in the same direction as the plain mean.

    Each update is scored by its mean per-tensor cosine similarity to the
    FedAvg of all updates; updates scoring below 0.5 are dropped and the rest
    are averaged. If every update is dropped, the plain mean is returned and
    all reported weights are zero.

    Side effects:
        ``detected_malicious_ids`` receives the indices of the dropped updates
        and ``aggregation_weights`` the per-update weights actually used.
    """
    global detected_malicious_ids, aggregation_weights
    n = len(updates)

    # Step 1: Compute the mean update (global reference)
    avg = fedavg(updates)

    # Step 2: Compute cosine similarity between each update and the global mean
    cosine_similarities = []
    for u in updates:
        sim = torch.stack([F.cosine_similarity(u[k].flatten(), avg[k].flatten(), dim=0) for k in u]).mean()
        cosine_similarities.append(sim.item())

    # Step 3: Thresholding (keep clients at or above 0.5 similarity)
    threshold = 0.5
    filtered_indices = [i for i, sim in enumerate(cosine_similarities) if sim >= threshold]
    kept = set(filtered_indices)

    if filtered_indices:
        aggregated = fedavg([updates[i] for i in filtered_indices])
        weights = [1.0 / len(filtered_indices) if i in kept else 0.0 for i in range(n)]
    else:
        aggregated = avg
        weights = [0.0] * n

    # fedavg() resets the module-level bookkeeping on every call, so the
    # detection result must be published only after the last fedavg() call.
    detected_malicious_ids = [i for i in range(n) if i not in kept]
    aggregation_weights = weights
    return aggregated

def foolsgold(updates):
    """Aggregate updates, down-weighting those that closely resemble another update.

    Every update is flattened and compared with every other update by cosine
    similarity. An update's raw weight is ``1 - max_similarity`` (the diagonal
    counts as 0), and the raw weights are normalised to sum to one. When all
    raw weights vanish (e.g. all updates are identical) the weights fall back
    to uniform instead of becoming NaN.

    Side effects:
        ``detected_malicious_ids`` is emptied and ``aggregation_weights``
        receives the normalised weights.
    """
    global detected_malicious_ids, aggregation_weights
    detected_malicious_ids = []
    gradients = [torch.cat([p.flatten() for p in u.values()]) for u in updates]
    n = len(gradients)
    sim = torch.zeros((n, n))
    for i in range(n):
        # Cosine similarity is symmetric, so each pair is evaluated once.
        for j in range(i + 1, n):
            sim[i][j] = sim[j][i] = torch.nn.functional.cosine_similarity(gradients[i], gradients[j], dim=0)
    max_sim = sim.max(dim=1).values
    # Rounding can push a similarity marginally above 1; never let a weight go negative.
    weights = torch.clamp(1.0 - max_sim, min=0.0)
    total = weights.sum()
    if total > 1e-8:
        weights = weights / total
    else:
        # All updates look alike: dividing by ~0 would poison the model with NaN.
        weights = torch.full((n,), 1.0 / n)
    aggregation_weights = weights.tolist()
    return {k: sum(weights[i] * updates[i][k] for i in range(n)) for k in updates[0]}

def rfa(updates):
    """Aggregate updates with an approximate geometric median (Weiszfeld iterations).

    The updates are flattened (keys in sorted order), the estimate starts at
    the coordinate-wise median, and five Weiszfeld steps re-weight every update
    by the inverse of its distance to the current estimate. The estimate is
    then reshaped back into a state dict.

    Side effects:
        ``detected_malicious_ids`` receives the indices of updates whose L2
        distance to the final estimate exceeds 2.0; ``aggregation_weights`` is
        set to uniform weights.
    """
    global detected_malicious_ids, aggregation_weights

    # Step 1: Flatten updates for geometric median computation
    vectors = [torch.cat([u[k].flatten() for k in sorted(u.keys())]) for u in updates]
    stacked = torch.stack(vectors)

    # Step 2: Initialize with coordinate-wise median
    center = torch.median(stacked, dim=0)[0]

    # Step 3: Iterative Weiszfeld algorithm
    for _ in range(5):
        # The epsilon avoids dividing by zero when an update coincides with the estimate.
        distances = torch.norm(stacked - center, dim=1) + 1e-10
        weights = 1.0 / distances
        weights = weights / weights.sum()
        center = (weights.unsqueeze(1) * stacked).sum(dim=0)

    # Step 4: Unflatten the estimate back to model format
    example = updates[0]
    aggregated = {}
    pointer = 0
    for k in sorted(example.keys()):
        numel = example[k].numel()
        aggregated[k] = center[pointer:pointer + numel].view(example[k].shape)
        pointer += numel

    # Mark clients with high distance as malicious
    final_distances = torch.norm(stacked - center, dim=1)
    detected_malicious_ids = [i for i, dist in enumerate(final_distances) if dist > 2.0]
    aggregation_weights = [1.0 / len(vectors)] * len(vectors)

    return aggregated

def auror(updates):
    """Average the half of the updates that sit closest to their nearest neighbours.

    Every update is flattened (keys in sorted order) and scored by its mean L2
    distance to its ``k = max(1, n // 10)`` nearest other updates. The
    ``max(1, n // 2)`` lowest-scoring updates are averaged with ``fedavg``.

    Side effects:
        ``detected_malicious_ids`` receives the indices that were not
        selected; ``aggregation_weights`` holds ``1/len(selected)`` for the
        selected updates and 0 for the rest.
    """
    global detected_malicious_ids, aggregation_weights

    # Step 1: Flatten all updates
    flat_updates = [torch.cat([u[k].flatten() for k in sorted(u.keys())]) for u in updates]

    # Step 2: Compute pairwise L2 distances
    num_clients = len(flat_updates)
    distances = torch.zeros((num_clients, num_clients))
    for i in range(num_clients):
        for j in range(i + 1, num_clients):
            dist = torch.norm(flat_updates[i] - flat_updates[j]).item()
            distances[i][j] = distances[j][i] = dist

    # Step 3: Compute score based on distance to k nearest neighbors
    k = max(1, num_clients // 10)
    scores = []
    for i in range(num_clients):
        sorted_dists = torch.sort(distances[i])[0]
        nearest = sorted_dists[1:k + 1]  # Exclude self-distance
        # With a single client there are no neighbours; score it 0 instead of NaN.
        scores.append(nearest.mean().item() if nearest.numel() > 0 else 0.0)

    # Step 4: Select clients with lowest scores (most consistent); always keep
    # at least one so a single-client round still aggregates.
    sorted_ids = sorted(range(num_clients), key=lambda i: scores[i])
    selected_ids = sorted_ids[:max(1, num_clients // 2)]
    selected = set(selected_ids)

    aggregated = fedavg([updates[i] for i in selected_ids])

    # fedavg() resets the module-level bookkeeping, so publish the detection
    # result only after it has been called.
    detected_malicious_ids = [i for i in range(num_clients) if i not in selected]
    aggregation_weights = [1.0 / len(selected_ids) if i in selected else 0.0 for i in range(num_clients)]
    return aggregated

def bulyan(updates):
    """Select updates by Krum score, then combine them with ``trimmed_mean``.

    With ``f = max(1, n // 10)`` assumed attackers, each update is scored by
    the sum of its ``n - f - 2`` smallest squared distances to other updates.
    The ``max(1, n - 2f)`` best-scoring updates are kept and passed to
    ``trimmed_mean``.

    Side effects:
        ``detected_malicious_ids`` receives the indices that were not kept;
        ``aggregation_weights`` holds ``1/len(kept)`` for kept updates and 0
        for the rest.
    """
    global detected_malicious_ids, aggregation_weights

    # Step 1: Compute squared distances between updates
    n = len(updates)
    distances = torch.zeros((n, n))
    for i in range(n):
        for j in range(i + 1, n):
            d = sum(torch.norm(updates[i][k] - updates[j][k]) ** 2 for k in updates[i])
            distances[i][j] = distances[j][i] = d

    # Step 2: Krum candidate selection (keep n - 2f clients, at least one)
    f = max(1, n // 10)
    # Clamp so a tiny round cannot produce a negative slice bound.
    num_neighbours = max(n - f - 2, 0)
    krum_candidates = []
    for i in range(n):
        sorted_dists = sorted(float(distances[i][j]) for j in range(n) if j != i)
        krum_candidates.append((i, sum(sorted_dists[:num_neighbours])))
    krum_candidates.sort(key=lambda x: x[1])
    selected_ids = [idx for idx, _ in krum_candidates[:max(1, n - 2 * f)]]
    selected = set(selected_ids)

    # Step 3: Coordinate-wise trimmed mean over the kept updates
    aggregated = trimmed_mean([updates[i] for i in selected_ids])

    # trimmed_mean() resets the module-level bookkeeping, so publish the
    # selection result only after it has been called.
    aggregation_weights = [1.0 / len(selected_ids) if i in selected else 0.0 for i in range(n)]
    detected_malicious_ids = [i for i in range(n) if i not in selected]
    return aggregated

def rfed(updates):
    """Aggregate updates weighted by a variance-normalised distance to their mean.

    Each update is flattened (keys in sorted order). Its distance is the mean
    over coordinates of ``(update - mean)^2 / (variance + 1e-6)``, where mean
    and variance are taken across updates. Weights are ``softmax(-distance)``
    and the weighted sum of the updates is returned.

    Side effects:
        ``detected_malicious_ids`` is emptied and ``aggregation_weights``
        receives the softmax weights.
    """
    global detected_malicious_ids, aggregation_weights
    detected_malicious_ids = []

    # Flatten each update and compute per-coordinate statistics once.
    flat_updates = [torch.cat([u[k].flatten() for k in sorted(u.keys())]) for u in updates]
    stacked = torch.stack(flat_updates)
    center = stacked.mean(dim=0)
    if stacked.shape[0] > 1:
        feature_variance = torch.var(stacked, dim=0)
    else:
        # The sample variance of one update is undefined (NaN); treat it as zero
        # so a single-client round returns that client's update.
        feature_variance = torch.zeros_like(center)

    # Mahalanobis-like distance of every update from the mean
    distances = ((stacked - center) ** 2 / (feature_variance + 1e-6)).mean(dim=1)

    weights = torch.softmax(-distances, dim=0)
    aggregation_weights = weights.tolist()

    # Aggregate
    aggregated = {}
    for k in updates[0].keys():
        aggregated[k] = sum(weights[i] * updates[i][k] for i in range(len(updates)))

    return aggregated

def fedavg(updates):
    """Return the element-wise mean of all updates (plain federated averaging).

    Side effects:
        ``detected_malicious_ids`` is emptied and ``aggregation_weights`` is
        set to uniform weights.
    """
    global detected_malicious_ids, aggregation_weights
    n_updates = len(updates)
    detected_malicious_ids = []
    aggregation_weights = [1.0 / n_updates] * n_updates

    averaged = {}
    for name in updates[0].keys():
        layer_stack = torch.stack([update[name] for update in updates], 0)
        averaged[name] = layer_stack.mean(0)
    return averaged

def median(updates):
    """Return the coordinate-wise median of all updates.

    ``torch.median`` is used along the stacking axis, so for an even number of
    updates the lower of the two middle values is taken.

    Side effects:
        ``detected_malicious_ids`` is emptied and ``aggregation_weights`` is
        set to uniform weights.
    """
    global detected_malicious_ids, aggregation_weights
    n_updates = len(updates)
    detected_malicious_ids = []
    aggregation_weights = [1.0 / n_updates] * n_updates

    result = {}
    for name in updates[0].keys():
        layer_stack = torch.stack([update[name] for update in updates], 0)
        values, _indices = torch.median(layer_stack, dim=0)
        result[name] = values
    return result

def trimmed_mean(updates, beta=0.1):
    """Return the coordinate-wise mean after trimming the extremes of every coordinate.

    Args:
        updates: list of ``{name: tensor}`` dicts with matching keys/shapes.
        beta: Fraction of updates to drop from each end of the sorted values.
            The trim count is capped so that at least one value always
            remains (an over-large ``beta`` would otherwise average an empty
            slice and return NaN).

    Side effects:
        ``detected_malicious_ids`` is emptied and ``aggregation_weights`` is
        set to uniform weights.
    """
    global detected_malicious_ids, aggregation_weights
    detected_malicious_ids = []
    aggregation_weights = [1.0 / len(updates)] * len(updates)
    n = len(updates)
    # Never trim so much that nothing is left to average.
    trim = max(0, min(int(n * beta), (n - 1) // 2))
    agg = {}
    for k in updates[0].keys():
        stacked = torch.stack([u[k] for u in updates], dim=0)
        sorted_vals, _ = stacked.sort(dim=0)
        agg[k] = sorted_vals[trim:n - trim].mean(dim=0)
    return agg

def krum(updates, f=1):
    """Return the single update with the smallest Krum score.

    An update's score is the sum of its squared L2 distances to its
    ``n - f - 2`` nearest other updates. When the round is too small for that
    count to be positive, the nearest single neighbour is used so the choice
    is still distance-based rather than defaulting to the first update.

    Args:
        updates: list of ``{name: tensor}`` dicts with matching keys/shapes.
        f: Assumed number of malicious updates.

    Returns:
        The selected element of ``updates`` (the same dict object, not a copy).

    Side effects:
        ``detected_malicious_ids`` is emptied; ``aggregation_weights`` is 1.0
        for the selected update and 0.0 for the others.
    """
    global detected_malicious_ids, aggregation_weights
    detected_malicious_ids = []
    n = len(updates)
    num_neighbours = max(1, n - f - 2)
    scores = []
    for i, ui in enumerate(updates):
        dists = sorted(
            float(sum(torch.norm(ui[k] - uj[k]) ** 2 for k in ui))
            for j, uj in enumerate(updates)
            if i != j
        )
        scores.append((i, sum(dists[:num_neighbours])))
    selected = min(scores, key=lambda x: x[1])[0]
    # Only the selected update contributes to the new global model.
    aggregation_weights = [1.0 if i == selected else 0.0 for i in range(n)]
    return updates[selected]

# Maps the `defense_method` string accepted by federated_training() to the
# aggregation function it selects. Each function is called with the list of
# client state dicts collected in one round and returns the new global weights.
AGGREGATION_FUNCTIONS = {
    'FedAVG': fedavg,
    'TrimmedMean': trimmed_mean,
    'Median': median,
    'Krum': krum,
    'FLTrust': fltrust,
    'FLAME': flame,
    'FLCert': flcert,
    'FoolsGold': foolsgold,
    'RFA': rfa,
    'Auror': auror,
    'Bulyan': bulyan,
    'RFed': rfed
}


# -------------------
# Federated Training Function
# -------------------
def federated_training(rounds, epochs, client_fraction, attack_types, attack_ratios, defense_method, iid, alpha, device):
    """Run federated training on MNIST with optional attacks and a chosen defense.

    Args:
        rounds: Number of communication rounds.
        epochs: Local epochs each selected client trains per round (SGD, lr=0.01).
        client_fraction: Fraction of the 100 clients sampled every round
            (at least one client is always sampled).
        attack_types: Attack names forwarded to ``apply_combined_attacks``.
        attack_ratios: dict ``{attack_name: fraction_of_clients}``.
        defense_method: Key of ``AGGREGATION_FUNCTIONS``.
        iid: Whether the client split is IID.
        alpha: Dirichlet concentration for the non-IID split.
        device: Device used for local training and evaluation.

    Raises:
        ValueError: If ``defense_method`` is not a known aggregation rule.

    Side effects:
        Downloads MNIST into ``./data`` if needed, prints one metrics line per
        round and writes all rounds to ``fl_full_results.csv``.
    """
    global malicious_client_ids, detected_malicious_ids, aggregation_weights

    # Fail before downloading data or training if the defense name is wrong.
    if defense_method not in AGGREGATION_FUNCTIONS:
        raise ValueError(f"Unknown defense method: {defense_method}")
    aggregate = AGGREGATION_FUNCTIONS[defense_method]

    transform = transforms.Compose([transforms.ToTensor()])
    mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    mnist_test = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    client_data = partition_data(mnist_train, num_clients=100, alpha=alpha, iid=iid)

    malicious_client_ids = []
    if attack_types and attack_ratios:
        client_data = apply_combined_attacks(client_data, attack_types, attack_ratios)
    malicious_set = set(malicious_client_ids)

    global_model = SimpleCNN().to(device)
    global_weights = global_model.state_dict()
    # Aggregated weights live on the CPU, so keep the drift baseline there too;
    # otherwise the first drift computation mixes devices on a GPU run.
    prev_weights = {k: v.detach().cpu().clone() for k, v in global_weights.items()}

    results = []
    test_loader = DataLoader(mnist_test, batch_size=128, shuffle=False)

    all_client_ids = list(client_data.keys())
    num_selected = min(len(all_client_ids), max(1, int(client_fraction * len(all_client_ids))))

    for rnd in range(rounds):
        detected_malicious_ids = []
        aggregation_weights = []

        selected_clients = random.sample(all_client_ids, num_selected)
        local_updates = []
        total_train_loss = 0.0
        total_train_correct = 0
        total_train_samples = 0

        for cid in selected_clients:
            local_model = SimpleCNN().to(device)
            local_model.load_state_dict(global_weights)
            optimizer = optim.SGD(local_model.parameters(), lr=0.01)

            local_model.train()
            for _ in range(epochs):
                for x, y in client_data[cid]:
                    x, y = x.to(device), y.to(device)
                    optimizer.zero_grad()
                    output = local_model(x)
                    loss = F.cross_entropy(output, y)
                    loss.backward()
                    optimizer.step()

                    total_train_loss += loss.item() * x.size(0)
                    total_train_correct += (output.argmax(1) == y).sum().item()
                    total_train_samples += x.size(0)

            local_updates.append({k: v.cpu().detach() for k, v in local_model.state_dict().items()})

        global_weights = aggregate(local_updates)
        global_model.load_state_dict(global_weights)

        acc, prec, rec, f1, fpr, tpr, test_loss = evaluate_model(global_model, test_loader, device)

        # Aggregators report positions within `local_updates`; translate them to
        # client IDs, and score only against malicious clients that actually
        # took part in this round (the others could not have been detected).
        detected_client_ids = [
            selected_clients[i] for i in detected_malicious_ids if 0 <= i < len(selected_clients)
        ]
        participating_malicious = [cid for cid in selected_clients if cid in malicious_set]
        attack_acc, exclusion_rate = detect_malicious_clients(participating_malicious, detected_client_ids)

        entropy = compute_entropy(aggregation_weights) if aggregation_weights else 0.0
        # Store a plain float so the CSV holds a number rather than "tensor(...)".
        drift = float(compute_model_drift(prev_weights, global_weights))
        prev_weights = {k: v.clone() for k, v in global_weights.items()}

        avg_train_loss = total_train_loss / total_train_samples if total_train_samples > 0 else 0.0
        avg_train_acc = total_train_correct / total_train_samples if total_train_samples > 0 else 0.0

        results.append({
            'Round': rnd + 1, 'Accuracy': acc, 'Precision': prec, 'Recall': rec, 'F1': f1,
            'FPR': fpr, 'TPR': tpr, 'TestLoss': test_loss,
            'TrainLoss': avg_train_loss, 'TrainAccuracy': avg_train_acc,
            'AttackDetectAcc': attack_acc, 'ExclusionRate': exclusion_rate,
            'Entropy': entropy, 'ModelDrift': drift
        })

        print(f"Round {rnd + 1}: Acc={acc:.4f}, Prec={prec:.4f}, Rec={rec:.4f}, F1={f1:.4f}, "
              f"FPR={fpr:.4f}, TPR={tpr:.4f}, TestLoss={test_loss:.4f}, TrainLoss={avg_train_loss:.4f}, "
              f"TrainAcc={avg_train_acc:.4f}, AttackDetectAcc={attack_acc:.4f}, "
              f"ExclusionRate={exclusion_rate:.4f}, Entropy={entropy:.4f}, Drift={drift:.4f}")

    pd.DataFrame(results).to_csv('fl_full_results.csv', index=False)
    print("Federated training completed. Metrics saved to fl_full_results.csv")



In [16]:
# Short run: 5 rounds, 2 local epochs, 10% of clients per round, non-IID split
# (alpha=0.3), label flipping on 20% and feature manipulation on 30% of the
# clients, aggregated with the coordinate-wise 'Median' rule on the CPU.
federated_training(
    rounds=5,
    epochs=2,
    client_fraction=0.1,
    attack_types=['label_flipping', 'feature_manipulation'],
    attack_ratios={'label_flipping': 0.2, 'feature_manipulation': 0.3},
    defense_method='Median',
    iid=False,
    alpha=0.3,
    device='cpu'
)


Round 1: Acc=0.0984, Prec=0.0498, Rec=0.1002, F1=0.0183, FPR=0.1000, TPR=0.1002, TestLoss=2.2955, TrainLoss=1.7670, TrainAcc=0.4308, AttackDetectAcc=0.0000, ExclusionRate=0.0000, Entropy=2.3026, Drift=0.0559
Round 2: Acc=0.2314, Prec=0.1705, Rec=0.2333, F1=0.1168, FPR=0.0854, TPR=0.2333, TestLoss=2.2576, TrainLoss=1.7336, TrainAcc=0.4759, AttackDetectAcc=0.0000, ExclusionRate=0.0000, Entropy=2.3026, Drift=0.0527
Round 3: Acc=0.1032, Prec=0.0103, Rec=0.1000, F1=0.0187, FPR=0.1000, TPR=0.1000, TestLoss=2.2593, TrainLoss=1.4211, TrainAcc=0.5157, AttackDetectAcc=0.0000, ExclusionRate=0.0000, Entropy=2.3026, Drift=0.0542
Round 4: Acc=0.2715, Prec=0.2069, Rec=0.2718, F1=0.1652, FPR=0.0810, TPR=0.2718, TestLoss=2.1499, TrainLoss=1.4445, TrainAcc=0.5185, AttackDetectAcc=0.0000, ExclusionRate=0.0000, Entropy=2.3026, Drift=0.0445
Round 5: Acc=0.2411, Prec=0.4312, Rec=0.2430, F1=0.2003, FPR=0.0841, TPR=0.2430, TestLoss=1.9956, TrainLoss=1.3034, TrainAcc=0.5988, AttackDetectAcc=0.0000, ExclusionRa

In [17]:
import time

# Start the timer. perf_counter() is monotonic, so the measured duration cannot
# be distorted by system clock adjustments the way time.time() can.
start_time = time.perf_counter()

# Execute federated training
federated_training(
    rounds=5,
    epochs=2,
    client_fraction=0.1,
    attack_types=['label_flipping', 'feature_manipulation'],
    attack_ratios={'label_flipping': 0.2, 'feature_manipulation': 0.3},
    defense_method='Median',
    iid=False,
    alpha=0.3,
    device='cpu'
)

# End the timer
end_time = time.perf_counter()

# Compute elapsed time in whole minutes and seconds
elapsed = end_time - start_time
minutes, seconds = divmod(int(elapsed), 60)

print(f"\nTotal Training Time: {minutes} minutes and {seconds} seconds")


Round 1: Acc=0.0980, Prec=0.0098, Rec=0.1000, F1=0.0179, FPR=0.1000, TPR=0.1000, TestLoss=2.3038, TrainLoss=1.7923, TrainAcc=0.4145, AttackDetectAcc=0.0000, ExclusionRate=0.0000, Entropy=2.3026, Drift=0.0623
Round 2: Acc=0.0980, Prec=0.0098, Rec=0.1000, F1=0.0179, FPR=0.1000, TPR=0.1000, TestLoss=2.2894, TrainLoss=1.6175, TrainAcc=0.4624, AttackDetectAcc=0.0000, ExclusionRate=0.0000, Entropy=2.3026, Drift=0.0373
Round 3: Acc=0.1935, Prec=0.0540, Rec=0.1945, F1=0.0804, FPR=0.0897, TPR=0.1945, TestLoss=2.2836, TrainLoss=1.3730, TrainAcc=0.5680, AttackDetectAcc=0.0000, ExclusionRate=0.0000, Entropy=2.3026, Drift=0.0644
Round 4: Acc=0.1949, Prec=0.2031, Rec=0.1950, F1=0.1193, FPR=0.0895, TPR=0.1950, TestLoss=2.1320, TrainLoss=1.3247, TrainAcc=0.5618, AttackDetectAcc=0.0000, ExclusionRate=0.0000, Entropy=2.3026, Drift=0.0500
Round 5: Acc=0.5412, Prec=0.6531, Rec=0.5310, F1=0.4821, FPR=0.0510, TPR=0.5310, TestLoss=1.8198, TrainLoss=1.0280, TrainAcc=0.6806, AttackDetectAcc=0.0000, ExclusionRa

### Targeted Attack

In [3]:
# -------------------
# Full Implementation with Attack Selection, Percentage Control, and Real Defense Methods
# -------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
import pandas as pd
from torch.utils.data import DataLoader, Dataset, Subset, TensorDataset
from torchvision import datasets, transforms
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score
import math

# -------------------
# CNN Model
# -------------------
class SimpleCNN(nn.Module):
    """Compact CNN for 1-channel images that outputs 10 raw class logits.

    Architecture: conv(1->32, 3x3) -> ReLU -> conv(32->64, 3x3) -> ReLU ->
    2x2 max-pool -> dropout(0.25) -> flatten -> linear(9216->128) -> ReLU ->
    linear(128->10). The 9216 input size of ``fc1`` corresponds to a 28x28 input.
    """

    def __init__(self):
        super().__init__()
        # Creation order is kept so random initialisation is reproduced exactly.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.dropout1 = nn.Dropout(0.25)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.relu(self.conv2(out))
        out = self.dropout1(F.max_pool2d(out, 2))
        # Keep the batch axis, collapse channels and spatial axes for the dense head.
        out = torch.flatten(out, 1)
        out = F.relu(self.fc1(out))
        return self.fc2(out)

# -------------------
# Data Partition (IID or non-IID using Dirichlet)
# -------------------
def partition_data(dataset, num_clients=100, alpha=0.5, iid=True):
    """Split ``dataset`` across clients and wrap every shard in a DataLoader.

    Args:
        dataset: Map-style dataset. The non-IID path additionally reads
            ``dataset.targets`` and only distributes samples labelled 0..9.
        num_clients: Number of shards (clients) to create.
        alpha: Dirichlet concentration used by the non-IID path; smaller
            values give more skewed per-client class proportions.
        iid: If true, split uniformly at random; otherwise split each class
            across clients with Dirichlet-sampled proportions.

    Returns:
        dict mapping client index (0..num_clients-1) to a
        ``DataLoader(batch_size=32, shuffle=True)`` over that client's shard.
    """
    if iid:
        # random_split() requires the lengths to sum to len(dataset); hand the
        # remainder to the first clients instead of raising when the dataset
        # size is not a multiple of num_clients.
        base, extra = divmod(len(dataset), num_clients)
        lengths = [base + 1 if i < extra else base for i in range(num_clients)]
        data_split = torch.utils.data.random_split(dataset, lengths)
    else:
        labels = np.array(dataset.targets)
        idx_by_class = [np.where(labels == i)[0] for i in range(10)]
        data_split = [[] for _ in range(num_clients)]
        for idx in idx_by_class:
            np.random.shuffle(idx)
            proportions = np.random.dirichlet(np.repeat(alpha, num_clients))
            # Cumulative proportions become cut points inside this class's indices.
            proportions = (np.cumsum(proportions) * len(idx)).astype(int)[:-1]
            splits = np.split(idx, proportions)
            for client_id, split in enumerate(splits):
                data_split[client_id].extend(split)
        data_split = [Subset(dataset, idxs) for idxs in data_split]
    return {i: DataLoader(data_split[i], batch_size=32, shuffle=True) for i in range(num_clients)}

# -------------------------
# Global Attack Configuration
# -------------------------
# Label assigned to every sample by apply_targeted_poisoning_attack().
TARGET_LABEL_FOR_POISONING = 0
# Source label -> replacement label used by apply_targeted_label_flipping_attack().
TARGET_LABEL_FLIP_MAP = {1: 7, 3: 8}  # Only flip 1→7, 3→8
# Only samples with this label are corrupted by apply_targeted_feature_manipulation_attack().
FEATURE_MANIPULATION_TARGET_CLASS = 8
# IDs of attacked clients; reassigned by apply_combined_targeted_attacks().
malicious_client_ids = []

# -------------------------
# Targeted Label Flipping Attack
# -------------------------
def apply_targeted_label_flipping_attack(loader):
    """
    Relabel only the classes listed in TARGET_LABEL_FLIP_MAP (1 → 7, 3 → 8).
    All other labels, and all inputs, are left unchanged.

    :param loader: iterable of (x, y) batches
    :return: DataLoader(batch_size=32, shuffle=True) over the relabelled data,
             or ``loader`` itself when it yields no batches
    """
    inputs, flipped = [], []
    for x, y in loader:
        inputs.append(x)
        flipped.append(torch.tensor([
            TARGET_LABEL_FLIP_MAP.get(int(label), int(label)) for label in y
        ]))
    if not inputs:
        # A client can hold no samples after a skewed split; concatenating an
        # empty list would raise, so leave such a loader untouched.
        return loader
    dataset = TensorDataset(torch.cat(inputs), torch.cat(flipped))
    return DataLoader(dataset, batch_size=32, shuffle=True)

# -------------------------
# Targeted Feature Manipulation Attack
# -------------------------
def apply_targeted_feature_manipulation_attack(loader):
    """
    Corrupt only the images labelled FEATURE_MANIPULATION_TARGET_CLASS (8).
    For those images a random mask keeps roughly 40% of the pixels and the
    rest are replaced by Gaussian noise scaled by 0.5. Other inputs and all
    labels remain unchanged.

    :param loader: iterable of (x, y) batches
    :return: DataLoader(batch_size=32, shuffle=True) over the attacked data,
             or ``loader`` itself when it yields no batches
    """
    inputs, labels = [], []
    for x, y in loader:
        x_attacked = x.clone()
        for i in range(len(x)):
            if y[i] == FEATURE_MANIPULATION_TARGET_CLASS:
                noise = torch.randn_like(x[i]) * 0.5
                mask = (torch.rand_like(x[i]) > 0.6).float()  # 1 where the pixel survives
                x_attacked[i] = x[i] * mask + noise * (1 - mask)
        inputs.append(x_attacked)
        labels.append(y)
    if not inputs:
        # Empty client shard: concatenating nothing would raise, so pass it through.
        return loader
    dataset = TensorDataset(torch.cat(inputs), torch.cat(labels))
    return DataLoader(dataset, batch_size=32, shuffle=True)

# -------------------------
# Targeted Data Poisoning Attack (Backdoor)
# -------------------------
def apply_targeted_poisoning_attack(loader):
    """
    Backdoor every sample: shift intensities by +0.2 (clamped to [0, 1]), set
    the bottom-right 3x3 patch to 1.0, and relabel the sample as
    TARGET_LABEL_FOR_POISONING (0).

    :param loader: iterable of (x, y) batches; x must have four dimensions
    :return: DataLoader(batch_size=32, shuffle=True) over the poisoned data,
             or ``loader`` itself when it yields no batches
    """
    inputs, labels = [], []
    for x, y in loader:
        x_poisoned = torch.clamp(x + 0.2, 0, 1)  # Global intensity shift
        trigger = torch.zeros_like(x_poisoned)
        trigger[:, :, -3:, -3:] = 1.0  # Add white square at bottom-right corner
        inputs.append(torch.clamp(x_poisoned + trigger, 0, 1))
        labels.append(torch.full_like(y, TARGET_LABEL_FOR_POISONING))
    if not inputs:
        # Empty client shard: concatenating nothing would raise, so pass it through.
        return loader
    dataset = TensorDataset(torch.cat(inputs), torch.cat(labels))
    return DataLoader(dataset, batch_size=32, shuffle=True)

# -------------------------
# Combined Targeted Attack Function
# -------------------------
def apply_combined_targeted_attacks(client_data, attack_types, attack_ratios):
    """
    Applies selected targeted attacks to a subset of clients based on the given ratios.

    Clients are sampled independently per attack type, so one client may be
    hit by several attacks; these are then applied in the fixed order
    label_flipping -> feature_manipulation -> poisoning. The module-level
    ``malicious_client_ids`` is overwritten with the union of attacked clients.

    :param client_data: dictionary of {client_id: DataLoader}; modified in place
    :param attack_types: list of attacks to apply: ['label_flipping', 'feature_manipulation', 'poisoning']
    :param attack_ratios: dictionary of ratios per attack type (missing entries count as 0)
    :return: modified client_data with attacked data for selected clients
    :raises ValueError: if an attack name is not recognised, or a ratio asks
        for fewer than 0 or more than ``len(client_data)`` clients
    """
    global malicious_client_ids

    # Reject unknown names up front: previously they were sampled and recorded
    # as malicious although no attack was ever applied to them.
    known_attacks = ('label_flipping', 'feature_manipulation', 'poisoning')
    unknown = [attack for attack in attack_types if attack not in known_attacks]
    if unknown:
        raise ValueError(
            f"Unknown attack type(s): {unknown}; expected a subset of {list(known_attacks)}"
        )

    total_clients = list(client_data.keys())
    num_clients = len(total_clients)

    # Assign clients to attacks (sets give O(1) membership checks below)
    attack_specific_ids = {}
    for attack in attack_types:
        count = int(num_clients * attack_ratios.get(attack, 0))
        if not 0 <= count <= num_clients:
            raise ValueError(
                f"Attack ratio for '{attack}' selects {count} clients, "
                f"but only {num_clients} are available"
            )
        attack_specific_ids[attack] = set(random.sample(total_clients, count))

    # Merge all malicious client IDs
    malicious_client_ids = list(set().union(*attack_specific_ids.values()))

    # Apply attack on selected clients
    for client_id in malicious_client_ids:
        loader = client_data[client_id]
        if client_id in attack_specific_ids.get('label_flipping', ()):
            loader = apply_targeted_label_flipping_attack(loader)
        if client_id in attack_specific_ids.get('feature_manipulation', ()):
            loader = apply_targeted_feature_manipulation_attack(loader)
        if client_id in attack_specific_ids.get('poisoning', ()):
            loader = apply_targeted_poisoning_attack(loader)
        client_data[client_id] = loader

    return client_data


# -------------------
# Evaluation Function
# -------------------
def evaluate_model(model, test_loader, device):
    """
    Score ``model`` on ``test_loader`` and return aggregate classification metrics.

    The confusion matrix is built over the fixed label set 0..9; FPR and TPR
    are per-class rates averaged over those ten classes.

    :param model: network to evaluate; it is switched to eval mode
    :param test_loader: iterable of (data, target) batches
    :param device: device the batches are moved to before the forward pass
    :return: tuple (accuracy, macro precision, macro recall, macro F1,
             mean FPR, mean TPR, mean cross-entropy loss per sample)
    :raises ValueError: if a batch is not a 2-element list/tuple
    """
    model.eval()
    labels_true, labels_pred = [], []
    loss_sum = 0.0
    n_correct = 0
    n_samples = 0

    with torch.no_grad():
        for batch in test_loader:
            if not (isinstance(batch, (list, tuple)) and len(batch) == 2):
                raise ValueError("Each batch in test_loader must be a (data, target) tuple.")
            data, target = batch

            data = data.to(device)
            target = target.to(device)
            logits = model(data)
            loss_sum += F.cross_entropy(logits, target, reduction='sum').item()
            predicted = torch.max(logits.data, 1)[1]
            n_correct += (predicted == target).sum().item()
            n_samples += target.size(0)
            labels_true.extend(target.cpu().numpy())
            labels_pred.extend(predicted.cpu().numpy())

    # One-vs-rest counts per class, read off the confusion matrix.
    cm = confusion_matrix(labels_true, labels_pred, labels=list(range(10)))
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp
    fn = cm.sum(axis=1) - tp
    tn = cm.sum() - (tp + fp + fn)

    accuracy = n_correct / n_samples
    precision = precision_score(labels_true, labels_pred, average='macro', zero_division=0)
    recall = recall_score(labels_true, labels_pred, average='macro', zero_division=0)
    f1 = f1_score(labels_true, labels_pred, average='macro', zero_division=0)
    # The small epsilon keeps classes with an empty denominator from dividing by zero.
    fpr = np.mean(fp / (fp + tn + 1e-6))
    tpr = np.mean(tp / (tp + fn + 1e-6))
    mean_loss = loss_sum / n_samples

    return accuracy, precision, recall, f1, fpr, tpr, mean_loss

# -------------------
# Local Training Function (with Loss/Accuracy Logging)
# -------------------
def local_train(model, train_loader, epochs, device):
    """
    Train ``model`` in place with Adam (lr=0.001) and log per-epoch metrics.

    :param model: network to train; it is switched to train mode
    :param train_loader: iterable of (data, target) batches
    :param epochs: number of passes over ``train_loader``
    :param device: device the batches are moved to
    :return: ``model.state_dict()`` after training
    """
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    for epoch in range(epochs):
        correct = 0
        total = 0
        total_loss = 0.0
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()
            # Weight the batch-mean loss by batch size so the epoch average is per sample.
            total_loss += loss.item() * data.size(0)
            correct += (output.argmax(dim=1) == target).sum().item()
            total += data.size(0)
        # A client may hold no samples; report zeros instead of dividing by zero.
        avg_loss = total_loss / total if total > 0 else 0.0
        avg_acc = correct / total if total > 0 else 0.0
        print(f"Local Epoch {epoch+1} - Loss: {avg_loss:.4f}, Accuracy: {avg_acc:.4f}")
    return model.state_dict()

# -------------------
# Additional Metric Functions
# -------------------
def compute_model_drift(prev_weights, new_weights):
    """
    Return the mean per-tensor L2 distance between two state dicts.

    :param prev_weights: dict of name -> tensor (defines which keys are compared)
    :param new_weights: dict containing at least the same keys
    :return: 0-d tensor holding the average norm of the per-key differences;
             0.0 when ``prev_weights`` is empty
    """
    if not prev_weights:
        return torch.tensor(0.0)

    def _as_float(tensor):
        # torch.norm rejects integer tensors (e.g. step counters in a state dict).
        return tensor if tensor.is_floating_point() else tensor.float()

    total = 0
    for name, old in prev_weights.items():
        # Aggregated weights may live on the CPU while the previous ones are on the GPU.
        new = new_weights[name].to(old.device)
        total = total + torch.norm(_as_float(old) - _as_float(new))
    return total / len(prev_weights)

def compute_entropy(weights):
    """
    Return the Shannon entropy (natural log) of a weight vector.

    A 1e-10 offset inside the logarithm keeps zero weights from producing
    ``log(0)``.

    :param weights: sequence of non-negative weights
    :return: ``-sum(w * log(w + 1e-10))`` as a NumPy scalar
    """
    probs = np.array(weights)
    log_probs = np.log(probs + 1e-10)
    return -np.sum(probs * log_probs)

def detect_malicious_clients(true_ids, detected_ids):
    """
    Compare the flagged clients against the ground-truth malicious clients.

    :param true_ids: iterable of IDs that really are malicious
    :param detected_ids: iterable of IDs flagged by the defense
    :return: tuple (attack_detection_rate, exclusion_rate) where the first is
             the fraction of true IDs that were flagged and the second is
             ``|detected| / |true ∪ detected|``; both use a 1e-6 offset in the
             denominator so empty inputs yield 0.0
    """
    actual = set(true_ids)
    flagged = set(detected_ids)
    hits = len(actual & flagged)
    misses = len(actual - flagged)
    attack_detection_rate = hits / (hits + misses + 1e-6)
    exclusion_rate = len(flagged) / (len(actual | flagged) + 1e-6)
    return attack_detection_rate, exclusion_rate




# Not referenced elsewhere in this section; kept for code outside it.
prev_global_model = {}
# -------------------
# Aggregation Functions (Full Implementations)
# -------------------

# Module-level bookkeeping that the aggregation functions below overwrite on
# every call: the update indices they flagged and the per-update weights used.
detected_malicious_ids = []
aggregation_weights = []

def fltrust(updates):
    """
    Weight each update by its cosine similarity to a reference update.

    The first update serves as the reference (a stand-in for the trusted
    server model). Per-layer cosine similarities are averaged into one score
    per update, the scores are softmax-normalised into weights, and updates
    whose weight falls below 0.01 are recorded as detected.

    :param updates: list of state dicts (name -> tensor) with identical keys
    :return: dict with the weighted sum of the updates per key
    """
    global detected_malicious_ids, aggregation_weights

    reference_model = updates[0]  # trusted server model, in real FLTrust, it should be fixed
    scores = []
    for update in updates:
        similarity_total = 0
        for name in update:
            similarity_total = similarity_total + torch.nn.functional.cosine_similarity(
                update[name].flatten(), reference_model[name].flatten(), dim=0
            )
        scores.append(similarity_total / len(update))

    weights = torch.softmax(torch.tensor(scores), dim=0)
    detected_malicious_ids = [
        index for index, weight in enumerate(weights) if weight < 0.01
    ]
    aggregation_weights = weights.tolist()

    aggregated = {}
    for name in updates[0]:
        aggregated[name] = sum(
            weights[index] * update[name] for index, update in enumerate(updates)
        )
    return aggregated

def flame(updates):
    """
    Aggregate updates with weights that shrink as an update's norm grows.

    Each update's score is the sum of the L2 norms of its tensors; weights are
    the softmax of the negated scores. No update is flagged as malicious.

    :param updates: list of state dicts (name -> tensor) with identical keys
    :return: dict with the weighted sum of the updates per key
    """
    global detected_malicious_ids, aggregation_weights
    detected_malicious_ids = []
    # Computed once; the original repeated this norm/softmax pass verbatim.
    norms = [sum(torch.norm(u[k]) for k in u) for u in updates]
    weights = torch.softmax(-torch.tensor(norms), dim=0)
    aggregation_weights = weights.tolist()
    return {k: sum(weights[i] * updates[i][k] for i in range(len(updates))) for k in updates[0]}

def flcert(updates):
    """
    Average only the updates that point in roughly the same direction as the mean.

    Each update is scored by the mean per-layer cosine similarity to the plain
    average of all updates; updates scoring below 0.5 are flagged and left
    out. If every update is rejected, the plain average is returned.

    :param updates: list of state dicts (name -> tensor) with identical keys
    :return: dict with the mean of the accepted updates per key
    """
    global detected_malicious_ids, aggregation_weights

    def _mean(group):
        # Inline average: calling fedavg() here would overwrite the detection
        # globals that this function is responsible for reporting.
        return {k: torch.stack([u[k] for u in group], 0).mean(0) for k in group[0]}

    n = len(updates)

    # Step 1: Compute the mean update (global reference)
    avg = _mean(updates)

    # Step 2: Compute cosine similarity between each update and the global mean
    cosine_similarities = [
        torch.stack(
            [F.cosine_similarity(u[k].flatten(), avg[k].flatten(), dim=0) for k in u]
        ).mean().item()
        for u in updates
    ]

    # Step 3: Thresholding (keep clients at or above 0.5 similarity)
    threshold = 0.5
    filtered_indices = [i for i, sim in enumerate(cosine_similarities) if sim >= threshold]
    kept = set(filtered_indices)

    if filtered_indices:
        aggregated = _mean([updates[i] for i in filtered_indices])
        weights = [1.0 / len(filtered_indices) if i in kept else 0.0 for i in range(n)]
    else:
        # Fallback is the plain mean, so every update contributes equally.
        aggregated = avg
        weights = [1.0 / n] * n

    detected_malicious_ids = [i for i in range(n) if i not in kept]
    aggregation_weights = weights
    return aggregated

def foolsgold(updates):
    """
    Down-weight updates that closely resemble another update.

    Every update is flattened, pairwise cosine similarities are computed, and
    each update gets the weight ``1 - max similarity to any other update``
    (normalised to sum to 1). No update is flagged as malicious.

    :param updates: list of state dicts (name -> tensor) with identical keys
        in identical order
    :return: dict with the weighted sum of the updates per key
    """
    global detected_malicious_ids, aggregation_weights
    detected_malicious_ids = []
    gradients = [torch.cat([p.flatten() for p in u.values()]) for u in updates]
    n = len(gradients)
    sim = torch.zeros((n, n))
    # Cosine similarity is symmetric, so each pair is evaluated once.
    for i in range(n):
        for j in range(i + 1, n):
            value = torch.nn.functional.cosine_similarity(gradients[i], gradients[j], dim=0)
            sim[i][j] = sim[j][i] = value
    max_sim = sim.max(dim=1).values
    # Rounding can push a similarity slightly above 1; never use a negative weight.
    weights = torch.clamp(1.0 - max_sim, min=0.0)
    total = weights.sum()
    if total > 0:
        weights = weights / total
    else:
        # All updates are (near-)identical: 0/0 would turn the model into NaNs.
        weights = torch.full((n,), 1.0 / n)
    aggregation_weights = weights.tolist()
    return {k: sum(weights[i] * updates[i][k] for i in range(n)) for k in updates[0]}

def rfa(updates):
    """
    Aggregate updates with an approximate geometric median.

    Updates are flattened (keys in sorted order), the estimate starts at the
    coordinate-wise median and is refined by five Weiszfeld iterations.
    Updates further than 2.0 (L2) from the result are flagged.

    :param updates: list of state dicts (name -> tensor) with identical keys
    :return: dict with the geometric-median estimate reshaped per key
    """
    global detected_malicious_ids, aggregation_weights
    detected_malicious_ids = []

    # Step 1: Flatten updates for geometric median computation
    vectors = [torch.cat([u[k].flatten() for k in sorted(u.keys())]) for u in updates]

    # Step 2: Initialize with coordinate-wise median
    stacked = torch.stack(vectors)
    median = torch.median(stacked, dim=0)[0]

    # Step 3: Iterative Weiszfeld algorithm
    for _ in range(5):
        # The 1e-10 offset avoids dividing by zero when an update equals the estimate.
        distances = torch.stack([torch.norm(v - median) for v in vectors]) + 1e-10
        weights = 1.0 / distances
        weights /= weights.sum()
        median = sum(weights[i] * vectors[i] for i in range(len(vectors)))

    # Step 4: Unflatten median back to model format
    example = updates[0]
    aggregated = {}
    pointer = 0
    for k in sorted(example.keys()):
        shape = example[k].shape
        numel = example[k].numel()
        aggregated[k] = median[pointer:pointer + numel].view(shape)
        pointer += numel

    # Mark clients with high distance as malicious
    detected_malicious_ids = [i for i in range(len(vectors)) if torch.norm(vectors[i] - median) > 2.0]
    # NOTE(review): uniform weights are reported although the estimate above
    # is a Weiszfeld-weighted combination; confirm this is intended.
    aggregation_weights = [1.0 / len(vectors)] * len(vectors)

    return aggregated

def auror(updates):
    """
    Average the half of the updates that sit closest to their neighbours.

    Each update is scored by the mean L2 distance to its k nearest other
    updates (k = max(1, n // 10)). The n // 2 lowest-scoring updates (at least
    one) are averaged; the rest are flagged as malicious.

    :param updates: list of state dicts (name -> tensor) with identical keys
    :return: dict with the mean of the selected updates per key
    """
    global detected_malicious_ids, aggregation_weights

    # Step 1: Flatten all updates
    flat_updates = [torch.cat([u[k].flatten() for k in sorted(u.keys())]) for u in updates]

    # Step 2: Compute pairwise L2 distances
    num_clients = len(flat_updates)
    distances = torch.zeros((num_clients, num_clients))
    for i in range(num_clients):
        for j in range(i + 1, num_clients):
            dist = torch.norm(flat_updates[i] - flat_updates[j]).item()
            distances[i][j] = distances[j][i] = dist

    # Step 3: Compute score based on distance to k nearest neighbors
    k = max(1, num_clients // 10)
    scores = []
    for i in range(num_clients):
        sorted_dists = torch.sort(distances[i])[0]
        scores.append(sorted_dists[1:k + 1].mean().item())  # Exclude self-distance

    # Step 4: Select clients with lowest scores (most consistent).
    # At least one update is kept so a single-client round does not crash.
    sorted_ids = sorted(range(num_clients), key=lambda i: scores[i])
    selected_ids = sorted_ids[:max(1, num_clients // 2)]
    selected = set(selected_ids)

    # Averaged inline: fedavg() would reset the detection globals set below.
    aggregated = {
        name: torch.stack([updates[i][name] for i in selected_ids], 0).mean(0)
        for name in updates[0]
    }

    detected_malicious_ids = [i for i in range(num_clients) if i not in selected]
    aggregation_weights = [1.0 / len(selected_ids) if i in selected else 0.0 for i in range(num_clients)]

    return aggregated

def bulyan(updates):
    """
    Select updates with a Krum-style score, then combine them with a trimmed mean.

    With f = max(1, n // 10), each update is scored by the sum of its smallest
    squared distances to other updates; the n - 2f best-scoring updates (at
    least one) are passed to ``trimmed_mean`` and the rest are flagged.

    :param updates: list of state dicts (name -> tensor) with identical keys
    :return: dict produced by ``trimmed_mean`` over the selected updates
    """
    global detected_malicious_ids, aggregation_weights

    # Step 1: Compute distances between updates
    n = len(updates)
    distances = torch.zeros((n, n))
    for i in range(n):
        for j in range(i + 1, n):
            d = sum(torch.norm(updates[i][k] - updates[j][k])**2 for k in updates[i])
            distances[i][j] = distances[j][i] = d

    # Step 2: Krum candidate selection (select n - 2f clients, never fewer than one)
    f = max(1, n // 10)
    neighbours = max(0, n - f - 2)  # a negative bound would slice from the wrong end
    krum_candidates = []
    for i in range(n):
        sorted_dists = sorted(distances[i][j] for j in range(n) if j != i)
        krum_score = sum(sorted_dists[:neighbours])
        krum_candidates.append((i, krum_score))
    krum_candidates.sort(key=lambda x: x[1])
    selected_ids = [idx for idx, _ in krum_candidates[:max(1, n - 2 * f)]]
    selected = set(selected_ids)

    # Step 3: Trimmed mean over the selected updates
    trimmed_updates = [updates[i] for i in selected_ids]
    aggregated = trimmed_mean(trimmed_updates)

    # Recorded only after trimmed_mean(), which resets both globals itself.
    aggregation_weights = [1.0 / len(trimmed_updates) if i in selected else 0.0 for i in range(n)]
    detected_malicious_ids = [i for i in range(n) if i not in selected]

    return aggregated

def rfed(updates):
    """
    Weight updates by how close they are to the mean, scaled by per-coordinate variance.

    Each flattened update gets the distance ``mean((u - mean)^2 / (var + 1e-6))``;
    weights are the softmax of the negated distances. No update is flagged.

    :param updates: list of state dicts (name -> tensor) with identical keys
    :return: dict with the weighted sum of the updates per key
    """
    global detected_malicious_ids, aggregation_weights
    detected_malicious_ids = []

    # Flatten each update so the statistics run over one vector per client
    flat_updates = [torch.cat([u[k].flatten() for k in sorted(u.keys())]) for u in updates]
    stacked = torch.stack(flat_updates)

    if len(updates) == 1:
        # The variance of a single sample is undefined (NaN) and would poison the result.
        weights = torch.ones(1)
    else:
        center = stacked.mean(dim=0)  # hoisted: identical for every client
        feature_variance = torch.var(stacked, dim=0)
        # Mahalanobis-like distance of every client from the mean, one row each
        distances = ((stacked - center) ** 2 / (feature_variance + 1e-6)).mean(dim=1)
        weights = torch.softmax(-distances, dim=0)

    aggregation_weights = weights.tolist()

    # Aggregate
    aggregated = {}
    for k in updates[0].keys():
        aggregated[k] = sum(weights[i] * updates[i][k] for i in range(len(updates)))

    return aggregated

def fedavg(updates):
    """
    Plain federated averaging: the element-wise mean of all updates.

    Resets the detection globals (nothing flagged, uniform weights).

    :param updates: non-empty list of state dicts (name -> tensor) with identical keys
    :return: dict with the mean tensor per key
    """
    global detected_malicious_ids, aggregation_weights
    n_updates = len(updates)
    detected_malicious_ids = []
    aggregation_weights = [1.0 / n_updates] * n_updates

    averaged = {}
    for name in updates[0].keys():
        layer_stack = torch.stack([update[name] for update in updates], dim=0)
        averaged[name] = layer_stack.mean(dim=0)
    return averaged

def median(updates):
    """
    Coordinate-wise median of all updates.

    Resets the detection globals (nothing flagged, uniform weights reported).

    :param updates: non-empty list of state dicts (name -> tensor) with identical keys
    :return: dict with the element-wise median tensor per key
    """
    global detected_malicious_ids, aggregation_weights
    n_updates = len(updates)
    detected_malicious_ids = []
    aggregation_weights = [1.0 / n_updates] * n_updates

    result = {}
    for name in updates[0].keys():
        layer_stack = torch.stack([update[name] for update in updates], dim=0)
        result[name] = layer_stack.median(dim=0).values
    return result

def trimmed_mean(updates, beta=0.1):
    """
    Coordinate-wise trimmed mean of all updates.

    For every coordinate the values are sorted across updates and the
    ``int(n * beta)`` smallest and largest are dropped before averaging. The
    trim is capped so at least one value always remains.

    Resets the detection globals (nothing flagged, uniform weights reported).

    :param updates: non-empty list of state dicts (name -> tensor) with identical keys
    :param beta: fraction of values trimmed from each end
    :return: dict with the trimmed-mean tensor per key
    """
    global detected_malicious_ids, aggregation_weights
    detected_malicious_ids = []
    aggregation_weights = [1.0 / len(updates)] * len(updates)
    n = len(updates)
    # Without the cap, beta >= 0.5 leaves an empty slice and the mean is NaN.
    trim = max(0, min(int(n * beta), (n - 1) // 2))
    agg = {}
    for k in updates[0].keys():
        stacked = torch.stack([u[k] for u in updates], dim=0)
        sorted_vals, _ = stacked.sort(dim=0)
        trimmed = sorted_vals[trim:n - trim]
        agg[k] = trimmed.mean(dim=0)
    return agg

def krum(updates, f=1):
    """
    Return the single update whose nearest neighbours are closest to it.

    Each update is scored by the sum of its ``n - f - 2`` smallest squared
    distances to the other updates. When there are too few updates
    (``n - 1 <= f + 2``) every score is infinite and the first update wins.

    Resets the detection globals (nothing flagged, uniform weights reported).

    :param updates: non-empty list of state dicts (name -> tensor) with identical keys
    :param f: assumed number of malicious updates
    :return: the selected element of ``updates`` (the same object, not a copy)
    """
    global detected_malicious_ids, aggregation_weights
    n = len(updates)
    detected_malicious_ids = []
    aggregation_weights = [1.0 / n] * n

    neighbours_kept = n - f - 2
    scores = []
    for i, candidate in enumerate(updates):
        dists = [
            sum(torch.norm(candidate[k] - other[k])**2 for k in candidate)
            for j, other in enumerate(updates)
            if j != i
        ]
        dists.sort()
        if len(dists) > f + 2:
            scores.append(sum(dists[:neighbours_kept]))
        else:
            scores.append(float('inf'))

    # min() keeps the earliest index on ties.
    selected = min(range(n), key=lambda idx: scores[idx])
    return updates[selected]

# Lookup table from the `defense_method` name accepted by federated_training()
# to the aggregation function that combines a round's list of client updates.
AGGREGATION_FUNCTIONS = {
    'FedAVG': fedavg,
    'TrimmedMean': trimmed_mean,
    'Median': median,
    'Krum': krum,
    'FLTrust': fltrust,
    'FLAME': flame,
    'FLCert': flcert,
    'FoolsGold': foolsgold,
    'RFA': rfa,
    'Auror': auror,
    'Bulyan': bulyan,
    'RFed': rfed
}


# -------------------
# Federated Training Function
# -------------------
def federated_training(rounds, epochs, client_fraction, attack_types, attack_ratios, defense_method, iid, alpha, device):
    """
    Run federated training on MNIST with optional attacks and a chosen defense.

    Each round samples a fraction of the 100 clients, trains a fresh copy of
    the global model on every sampled client with SGD (lr=0.01), aggregates
    the resulting state dicts with ``defense_method``, evaluates the new
    global model on the MNIST test set and logs the metrics. All rounds are
    written to ``fl_full_results.csv``.

    :param rounds: number of communication rounds
    :param epochs: local epochs per selected client per round
    :param client_fraction: fraction of clients sampled each round (at least one is used)
    :param attack_types: list of attack names passed to apply_combined_targeted_attacks
    :param attack_ratios: dict of attack name -> fraction of clients to attack
    :param defense_method: key of AGGREGATION_FUNCTIONS
    :param iid: forwarded to partition_data
    :param alpha: forwarded to partition_data
    :param device: torch device used for training and evaluation
    :raises ValueError: if ``defense_method`` is not a key of AGGREGATION_FUNCTIONS
    """
    global malicious_client_ids, detected_malicious_ids, aggregation_weights

    # Fail before downloading data or training a round for nothing.
    if defense_method not in AGGREGATION_FUNCTIONS:
        raise ValueError(f"Unknown defense method: {defense_method}")
    aggregate = AGGREGATION_FUNCTIONS[defense_method]

    transform = transforms.Compose([transforms.ToTensor()])
    mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    mnist_test = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    client_data = partition_data(mnist_train, num_clients=100, alpha=alpha, iid=iid)

    malicious_client_ids = []
    if attack_types and attack_ratios:
        client_data = apply_combined_targeted_attacks(client_data, attack_types, attack_ratios)

    global_model = SimpleCNN().to(device)
    global_weights = global_model.state_dict()
    # Kept on the CPU because aggregated weights are CPU tensors; mixing devices
    # would break the drift computation when training on a GPU.
    prev_weights = {k: v.detach().cpu().clone() for k, v in global_weights.items()}

    results = []
    test_loader = DataLoader(mnist_test, batch_size=128, shuffle=False)

    client_ids = list(client_data.keys())
    # Never sample zero clients: the aggregators cannot handle an empty update list.
    num_selected = min(len(client_ids), max(1, int(client_fraction * len(client_ids))))

    for rnd in range(rounds):
        detected_malicious_ids = []
        aggregation_weights = []

        selected_clients = random.sample(client_ids, num_selected)
        local_updates = []
        total_train_loss = 0.0
        total_train_correct = 0
        total_train_samples = 0

        for cid in selected_clients:
            local_model = SimpleCNN().to(device)
            local_model.load_state_dict(global_weights)
            optimizer = optim.SGD(local_model.parameters(), lr=0.01)

            local_model.train()
            for _ in range(epochs):
                for x, y in client_data[cid]:
                    x, y = x.to(device), y.to(device)
                    optimizer.zero_grad()
                    output = local_model(x)
                    loss = F.cross_entropy(output, y)
                    loss.backward()
                    optimizer.step()

                    total_train_loss += loss.item() * x.size(0)
                    total_train_correct += (output.argmax(1) == y).sum().item()
                    total_train_samples += x.size(0)

            local_updates.append({k: v.cpu().detach() for k, v in local_model.state_dict().items()})

        global_weights = aggregate(local_updates)
        global_model.load_state_dict(global_weights)

        acc, prec, rec, f1, fpr, tpr, test_loss = evaluate_model(global_model, test_loader, device)

        # Aggregators report positions in `local_updates`; translate them to
        # client IDs before comparing with the ground truth, and only count
        # malicious clients that actually took part in this round.
        detected_client_ids = [
            selected_clients[i] for i in detected_malicious_ids
            if 0 <= i < len(selected_clients)
        ]
        participating_malicious = set(malicious_client_ids).intersection(selected_clients)
        attack_acc, exclusion_rate = detect_malicious_clients(participating_malicious, detected_client_ids)

        entropy = compute_entropy(aggregation_weights) if aggregation_weights else 0.0
        # Plain float so the CSV holds a number rather than "tensor(...)".
        drift = float(compute_model_drift(prev_weights, global_weights))
        prev_weights = {k: v.detach().cpu().clone() for k, v in global_weights.items()}

        avg_train_loss = total_train_loss / total_train_samples if total_train_samples > 0 else 0.0
        avg_train_acc = total_train_correct / total_train_samples if total_train_samples > 0 else 0.0

        results.append({
            'Round': rnd + 1, 'Accuracy': acc, 'Precision': prec, 'Recall': rec, 'F1': f1,
            'FPR': fpr, 'TPR': tpr, 'TestLoss': test_loss,
            'TrainLoss': avg_train_loss, 'TrainAccuracy': avg_train_acc,
            'AttackDetectAcc': attack_acc, 'ExclusionRate': exclusion_rate,
            'Entropy': entropy, 'ModelDrift': drift
        })

        print(f"Round {rnd + 1}: Acc={acc:.4f}, Prec={prec:.4f}, Rec={rec:.4f}, F1={f1:.4f}, "
              f"FPR={fpr:.4f}, TPR={tpr:.4f}, TestLoss={test_loss:.4f}, TrainLoss={avg_train_loss:.4f}, "
              f"TrainAcc={avg_train_acc:.4f}, AttackDetectAcc={attack_acc:.4f}, "
              f"ExclusionRate={exclusion_rate:.4f}, Entropy={entropy:.4f}, Drift={drift:.4f}")

    pd.DataFrame(results).to_csv('fl_full_results.csv', index=False)
    print("Federated training completed. Metrics saved to fl_full_results.csv")



In [5]:
import time

# Wall-clock timestamp taken immediately before training
start_time = time.time()

# Label-flipping attack on 20% of clients, defended by coordinate-wise median,
# over a non-IID partition.
federated_training(
    rounds=5,
    epochs=2,
    client_fraction=0.1,
    attack_types=['label_flipping'],
    attack_ratios={'label_flipping': 0.2},
    defense_method='Median',
    iid=False,
    alpha=0.3,
    device='cpu'
)

# Wall-clock timestamp taken once training has returned
end_time = time.time()

# Split the elapsed seconds into whole minutes and leftover whole seconds
elapsed = end_time - start_time
minutes, seconds = (int(part) for part in divmod(elapsed, 60))

print(f"\nTotal Training Time: {minutes} minutes and {seconds} seconds")


Round 1: Acc=0.1392, Prec=0.0316, Rec=0.1464, F1=0.0516, FPR=0.0948, TPR=0.1464, TestLoss=2.3045, TrainLoss=1.7709, TrainAcc=0.3901, AttackDetectAcc=0.0000, ExclusionRate=0.0000, Entropy=2.3026, Drift=0.0590


KeyboardInterrupt: 

In [None]:
import torch
import random
from torch.utils.data import DataLoader, TensorDataset

# -------------------------
# Global Attack Configuration
# -------------------------
TARGET_LABEL_FOR_POISONING = 0  # label given to every sample by apply_targeted_poisoning_attack
TARGET_LABEL_FLIP_MAP = {1: 7, 3: 8}  # Only flip 1→7, 3→8; labels not listed are left as they are
FEATURE_MANIPULATION_TARGET_CLASS = 8  # only samples with this label are noised by the feature attack
malicious_client_ids = []  # overwritten by apply_combined_targeted_attacks with the attacked client IDs

# -------------------------
# Targeted Label Flipping Attack
# -------------------------
def apply_targeted_label_flipping_attack(loader):
    """
    Relabel samples according to TARGET_LABEL_FLIP_MAP (1 → 7, 3 → 8).

    Labels that are not keys of the map are kept; images are never modified.

    :param loader: iterable yielding (x, y) batches
    :return: DataLoader (batch_size=32, shuffle=True) over the relabelled samples
    """
    image_batches = []
    label_batches = []
    for x, y in loader:
        remapped = [TARGET_LABEL_FLIP_MAP.get(int(label), int(label)) for label in y]
        image_batches.append(x)
        # Explicit integer dtype so an empty batch cannot promote the labels to float.
        label_batches.append(torch.tensor(remapped, dtype=torch.long))
    x_tensor = torch.cat(image_batches)
    y_tensor = torch.cat(label_batches)
    return DataLoader(TensorDataset(x_tensor, y_tensor), batch_size=32, shuffle=True)

# -------------------------
# Targeted Feature Manipulation Attack
# -------------------------
def apply_targeted_feature_manipulation_attack(loader):
    """
    Corrupt only the samples labelled FEATURE_MANIPULATION_TARGET_CLASS.

    For each such sample a random binary mask is drawn: pixels where the mask
    is 1 keep their original value, the remaining pixels are overwritten with
    Gaussian noise scaled by 0.5. Every other sample, and every label, passes
    through unchanged.

    :param loader: iterable yielding (x, y) batches
    :return: DataLoader (batch_size=32, shuffle=True) over the resulting samples
    """
    image_batches = []
    label_batches = []
    for x, y in loader:
        corrupted = x.clone()
        for idx in range(len(x)):
            if y[idx] != FEATURE_MANIPULATION_TARGET_CLASS:
                continue
            # Noise is drawn before the mask so RNG consumption order is fixed.
            noise = torch.randn_like(x[idx]) * 0.5
            keep = (torch.rand_like(x[idx]) > 0.6).float()
            corrupted[idx] = x[idx] * keep + noise * (1 - keep)
        image_batches.append(corrupted)
        label_batches.append(y)
    # Concatenating whole batches yields the same tensors as stacking per sample.
    x_tensor = torch.cat(image_batches)
    y_tensor = torch.cat(label_batches)
    return DataLoader(TensorDataset(x_tensor, y_tensor), batch_size=32, shuffle=True)

# -------------------------
# Targeted Data Poisoning Attack (Backdoor)
# -------------------------
def apply_targeted_poisoning_attack(loader):
    """
    Backdoor every sample: brighten it, stamp a trigger, and relabel it.

    Each image is shifted up by 0.2 (clamped to [0, 1]), a 3x3 block of ones
    is added in the bottom-right corner (clamped again), and its label is
    replaced by TARGET_LABEL_FOR_POISONING.

    :param loader: iterable yielding (x, y) batches; x is indexed as a 4-D batch
    :return: DataLoader (batch_size=32, shuffle=True) over the poisoned samples
    """
    poisoned_images = []
    poisoned_labels = []
    for x, y in loader:
        stamped = torch.clamp(x + 0.2, 0, 1)  # Global intensity shift
        patch = torch.zeros_like(stamped)
        patch[:, :, -3:, -3:] = 1.0  # White square at bottom-right corner
        stamped = torch.clamp(stamped + patch, 0, 1)
        poisoned_images.append(stamped)
        poisoned_labels.append(torch.full_like(y, TARGET_LABEL_FOR_POISONING))
    x_tensor = torch.cat(poisoned_images)
    y_tensor = torch.cat(poisoned_labels)
    return DataLoader(TensorDataset(x_tensor, y_tensor), batch_size=32, shuffle=True)

# -------------------------
# Combined Targeted Attack Function
# -------------------------
def apply_combined_targeted_attacks(client_data, attack_types, attack_ratios):
    """
    Applies selected targeted attacks to a subset of clients based on the given ratios.

    Clients are sampled independently per attack type, so one client may be
    hit by several attacks; these are then applied in the fixed order
    label_flipping -> feature_manipulation -> poisoning. The module-level
    ``malicious_client_ids`` is overwritten with the union of attacked clients.

    :param client_data: dictionary of {client_id: DataLoader}; modified in place
    :param attack_types: list of attacks to apply: ['label_flipping', 'feature_manipulation', 'poisoning']
    :param attack_ratios: dictionary of ratios per attack type (missing entries count as 0)
    :return: modified client_data with attacked data for selected clients
    :raises ValueError: if an attack name is not recognised, or a ratio asks
        for fewer than 0 or more than ``len(client_data)`` clients
    """
    global malicious_client_ids

    # Reject unknown names up front: previously they were sampled and recorded
    # as malicious although no attack was ever applied to them.
    known_attacks = ('label_flipping', 'feature_manipulation', 'poisoning')
    unknown = [attack for attack in attack_types if attack not in known_attacks]
    if unknown:
        raise ValueError(
            f"Unknown attack type(s): {unknown}; expected a subset of {list(known_attacks)}"
        )

    total_clients = list(client_data.keys())
    num_clients = len(total_clients)

    # Assign clients to attacks (sets give O(1) membership checks below)
    attack_specific_ids = {}
    for attack in attack_types:
        count = int(num_clients * attack_ratios.get(attack, 0))
        if not 0 <= count <= num_clients:
            raise ValueError(
                f"Attack ratio for '{attack}' selects {count} clients, "
                f"but only {num_clients} are available"
            )
        attack_specific_ids[attack] = set(random.sample(total_clients, count))

    # Merge all malicious client IDs
    malicious_client_ids = list(set().union(*attack_specific_ids.values()))

    # Apply attack on selected clients
    for client_id in malicious_client_ids:
        loader = client_data[client_id]
        if client_id in attack_specific_ids.get('label_flipping', ()):
            loader = apply_targeted_label_flipping_attack(loader)
        if client_id in attack_specific_ids.get('feature_manipulation', ()):
            loader = apply_targeted_feature_manipulation_attack(loader)
        if client_id in attack_specific_ids.get('poisoning', ()):
            loader = apply_targeted_poisoning_attack(loader)
        client_data[client_id] = loader

    return client_data
