# In [None]:
# -------------------
# CIFAR-100 + ResNet-18 Federated Learning (IID / Dirichlet non-IID)
# Attacks: label flipping, feature manipulation, trigger-based poisoning
# Defenses/Aggregators: FedAvg, TrimmedMean, Median, Krum, FLTrust, FLAME, FLCert, FoolsGold, RFA, Auror, Bulyan, RFed
# -------------------

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
import pandas as pd
from torch.utils.data import DataLoader, Subset, TensorDataset
from torchvision import datasets, transforms
from torchvision.models import resnet18
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score

# -------------------
# Reproducibility (optional)
# -------------------
def set_seed(seed=42):
    """Seed Python's `random`, NumPy's global RNG and PyTorch (CPU and all CUDA devices).

    Args:
        seed: integer seed shared by every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

# -------------------
# ResNet-18 for CIFAR (32x32)
# - Modify first conv + remove maxpool
# -------------------
class ResNet18CIFAR(nn.Module):
    """torchvision ResNet-18 (randomly initialised) adapted to 32x32 inputs.

    The 7x7/stride-2 stem convolution is replaced by a 3x3/stride-1 one and the
    stem max-pool is removed, so small images are not downsampled early. The
    classifier head is resized to ``num_classes`` outputs. The backbone lives
    under ``self.model``, so state-dict keys are prefixed with ``model.``.
    """

    def __init__(self, num_classes=100):
        super().__init__()
        backbone = resnet18(weights=None)  # no pretrained weights
        backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        backbone.maxpool = nn.Identity()
        feature_dim = backbone.fc.in_features
        backbone.fc = nn.Linear(feature_dim, num_classes)
        self.model = backbone

    def forward(self, x):
        """Return class logits for a batch of images."""
        logits = self.model(x)
        return logits

# -------------------
# Data Partition (IID or non-IID using Dirichlet)
# Works for CIFAR-100 (100 classes)
# -------------------
def partition_data(dataset, num_clients=100, alpha=0.5, iid=True, batch_size=64):
    """Split ``dataset`` across clients and wrap each share in a shuffling DataLoader.

    Args:
        dataset: map-style dataset. The non-IID path additionally reads
            ``dataset.targets`` (integer class labels).
        num_clients: number of client shards to create.
        alpha: Dirichlet concentration for the non-IID split. Smaller values
            give more skewed per-client class mixes. Ignored when ``iid`` is True.
        iid: if True, shuffle all indices and cut them into near-equal chunks.
            Otherwise split each class separately using Dirichlet proportions.
        batch_size: batch size of every client DataLoader.

    Returns:
        dict mapping client id (0..num_clients-1) to its DataLoader. A client
        may end up with no samples in the non-IID case.

    Randomness comes from NumPy's global RNG (``np.random.seed``).
    """
    indices = np.arange(len(dataset))

    if iid:
        np.random.shuffle(indices)
        shards = [Subset(dataset, chunk.tolist()) for chunk in np.array_split(indices, num_clients)]
    else:
        labels = np.array(dataset.targets)
        per_client = [[] for _ in range(num_clients)]
        for cls in range(int(labels.max() + 1)):
            cls_idx = np.where(labels == cls)[0]
            np.random.shuffle(cls_idx)
            share = np.random.dirichlet(np.repeat(alpha, num_clients))
            # Cumulative proportions -> split points; drop the last one (== len).
            bounds = (np.cumsum(share) * len(cls_idx)).astype(int)[:-1]
            for owner, piece in enumerate(np.split(cls_idx, bounds)):
                per_client[owner].extend(piece.tolist())
        shards = [Subset(dataset, idxs) for idxs in per_client]

    return {
        cid: DataLoader(shard, batch_size=batch_size, shuffle=True, drop_last=False)
        for cid, shard in enumerate(shards)
    }

# -------------------
# Attacks
# -------------------
# Global client IDs (0..N-1) chosen as malicious. Rebound by
# apply_combined_attacks() and reset by the training loop before attacks are applied.
malicious_client_ids = []

def apply_label_flipping_attack(loader, num_classes=100, mode="shift1"):
    """Return a new DataLoader with the same inputs and deterministically flipped labels.

    ``loader`` is iterated once and its samples are materialised in memory.

    Args:
        loader: DataLoader yielding ``(x, y)`` batches with integer labels in
            ``[0, num_classes)``.
        num_classes: number of classes C.
        mode:
          - "shift1": y -> (y + 1) % C
          - "permute": a random permutation of the C classes, drawn once per call
            from torch's global RNG (a class may map to itself).

    Returns:
        A shuffling DataLoader over a TensorDataset with the same batch size.
        If ``loader`` yields no samples it is returned unchanged, since there is
        nothing to flip and concatenating zero batches would raise.

    Raises:
        ValueError: for an unknown ``mode``.
    """
    if mode == "shift1":
        def flip(y):
            return (y + 1) % num_classes
    elif mode == "permute":
        perm = torch.randperm(num_classes)

        def flip(y):
            return perm[y]
    else:
        raise ValueError("mode must be 'shift1' or 'permute'")

    xs, ys = [], []
    for x, y in loader:
        # Whole-batch operations instead of rebuilding the data sample by sample.
        xs.append(x.cpu())
        ys.append(flip(y.cpu()))
    if not xs:
        return loader

    x_tensor = torch.cat(xs)
    y_tensor = torch.cat(ys).long()
    return DataLoader(TensorDataset(x_tensor, y_tensor), batch_size=loader.batch_size, shuffle=True)

def apply_feature_manipulation_attack(loader, noise_std=0.12, mask_prob=0.25, value_range=None):
    """Return a new DataLoader with noised and randomly masked inputs (labels unchanged).

    Feature manipulation is additive Gaussian noise followed by element-wise
    random masking (cutout-like). After the noisy input is clamped to the valid
    range, each element is set to 0 with probability ``mask_prob``. For
    mean/std-normalised inputs, 0 is the per-channel mean value.

    Args:
        loader: DataLoader yielding ``(x, y)`` batches. It is iterated once and
            its samples are materialised in memory.
        noise_std: noise standard deviation as a fraction of the valid range
            ``hi - lo``. For raw ``[0, 1]`` pixels this is an absolute std.
        mask_prob: per-element probability of being zeroed.
        value_range: optional ``(lo, hi)`` bounds for the perturbed inputs. When
            None, the min/max of the client's clean inputs are used.

    Returns:
        A shuffling DataLoader over a TensorDataset with the same batch size, or
        ``loader`` itself if it yields no samples.
    """
    xs, ys = [], []
    for x, y in loader:
        xs.append(x.detach().cpu())
        ys.append(y.cpu())
    if not xs:
        return loader  # nothing to attack; torch.cat on an empty list would raise

    clean = torch.cat(xs)
    labels = torch.cat(ys).long()

    # The training pipeline feeds mean/std-normalised CIFAR-100 tensors (roughly
    # [-1.9, 2.0]). A hard-coded clamp to [0, 1] would zero every below-mean
    # value, so clamp to the data's own range and scale the noise with it.
    if value_range is None:
        lo, hi = clean.min().item(), clean.max().item()
    else:
        lo, hi = value_range
    span = hi - lo

    noisy = torch.clamp(clean + torch.randn_like(clean) * (noise_std * span), lo, hi)
    keep = (torch.rand_like(noisy) > mask_prob).to(noisy.dtype)
    return DataLoader(TensorDataset(noisy * keep, labels), batch_size=loader.batch_size, shuffle=True)

def apply_poisoning_attack(loader, trigger_size=4, trigger_value=1.0, intensity_shift=0.10, value_range=None):
    """Return a new DataLoader whose inputs carry a brightness shift and a corner trigger.

    Every input is shifted by ``intensity_shift``. Then ``trigger_value`` is
    added to the bottom-right ``trigger_size`` x ``trigger_size`` patch of the
    last two dimensions. Values are clamped to the valid range after each step.
    Labels are kept unchanged, as in the MNIST version of this code; relabel as
    well for a targeted backdoor.

    Args:
        loader: DataLoader yielding ``(x, y)`` batches with at least 2-D inputs
            whose last two dimensions are spatial. It is iterated once.
        trigger_size: side of the square trigger patch; 0 disables the trigger.
        trigger_value: amount added to the patch, as a fraction of the valid
            range ``hi - lo``. With 1.0 the patch saturates at ``hi`` (a white
            square for raw pixels).
        intensity_shift: global shift, as a fraction of ``hi - lo``.
        value_range: optional ``(lo, hi)``. When None, the min/max of the
            client's clean inputs are used.

    Returns:
        A shuffling DataLoader over a TensorDataset with the same batch size, or
        ``loader`` itself if it yields no samples.
    """
    xs, ys = [], []
    for x, y in loader:
        xs.append(x.detach().cpu())
        ys.append(y.cpu())
    if not xs:
        return loader  # nothing to poison; torch.cat on an empty list would raise

    clean = torch.cat(xs)
    labels = torch.cat(ys).long()

    # Inputs are mean/std-normalised in this pipeline, so a [0, 1] clamp would
    # destroy them. Work relative to the data's own value range instead.
    if value_range is None:
        lo, hi = clean.min().item(), clean.max().item()
    else:
        lo, hi = value_range
    span = hi - lo

    poisoned = torch.clamp(clean + intensity_shift * span, lo, hi)
    # Guard size 0: a `-0:` slice would select the whole image.
    if trigger_size > 0:
        patch = poisoned[..., -trigger_size:, -trigger_size:]
        poisoned[..., -trigger_size:, -trigger_size:] = torch.clamp(patch + trigger_value * span, lo, hi)

    return DataLoader(TensorDataset(poisoned, labels), batch_size=loader.batch_size, shuffle=True)

def apply_combined_attacks(client_data, attack_types, attack_ratios, num_classes=100):
    """Choose malicious clients per attack and replace their DataLoaders with attacked copies.

    For each name in ``attack_types``, an independent sample of
    ``int(num_clients * attack_ratios.get(name, 0))`` client IDs is drawn with
    Python's ``random`` module, so one client can receive several attacks.
    Attacks are applied in the order label flipping -> feature manipulation ->
    poisoning. The union of all sampled IDs is stored in the module-level
    ``malicious_client_ids``.

    Args:
        client_data: dict mapping client id -> DataLoader (mutated in place).
        attack_types: iterable of attack names ('label_flipping',
            'feature_manipulation', 'poisoning').
        attack_ratios: dict mapping attack name -> fraction of clients.
        num_classes: number of classes used by label flipping.

    Returns:
        The same ``client_data`` dict.
    """
    global malicious_client_ids

    client_ids = list(client_data.keys())
    n = len(client_ids)

    chosen = {}
    for attack in attack_types:
        k = max(0, min(int(n * attack_ratios.get(attack, 0)), n))
        chosen[attack] = random.sample(client_ids, k) if k > 0 else []

    malicious_client_ids = list(set().union(*chosen.values())) if chosen else []

    pipeline = (
        ('label_flipping',
         lambda dl: apply_label_flipping_attack(dl, num_classes=num_classes, mode="shift1")),
        ('feature_manipulation',
         lambda dl: apply_feature_manipulation_attack(dl, noise_std=0.12, mask_prob=0.25)),
        ('poisoning',
         lambda dl: apply_poisoning_attack(dl, trigger_size=4, trigger_value=1.0, intensity_shift=0.10)),
    )
    for cid in malicious_client_ids:
        loader = client_data[cid]
        for name, attack_fn in pipeline:
            if cid in chosen.get(name, []):
                loader = attack_fn(loader)
        client_data[cid] = loader

    return client_data

# -------------------
# Evaluation (macro metrics, FPR/TPR averaged across classes)
# -------------------
def evaluate_model(model, test_loader, device, num_classes=100):
    """Evaluate ``model`` on ``test_loader`` and return classification metrics.

    Returns:
        ``(accuracy, precision, recall, f1, fpr, tpr, test_loss)``:
        - precision, recall and F1 are macro-averaged via scikit-learn with
          ``zero_division=0``;
        - FPR and TPR are computed one-vs-rest per class from the confusion
          matrix over ``range(num_classes)`` and averaged;
        - test_loss is the mean per-sample cross-entropy.
        Accuracy and loss use ``max(total, 1)`` as denominator, so an empty
        loader does not divide by zero.
    """
    model.eval()
    labels, preds = [], []
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits = model(inputs)
            loss_sum += F.cross_entropy(logits, targets, reduction='sum').item()
            batch_pred = logits.argmax(dim=1)
            n_correct += (batch_pred == targets).sum().item()
            n_seen += targets.size(0)
            labels.extend(targets.cpu().numpy())
            preds.extend(batch_pred.cpu().numpy())

    cm = confusion_matrix(labels, preds, labels=list(range(num_classes)))
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp
    fn = cm.sum(axis=1) - tp
    tn = cm.sum() - (tp + fp + fn)

    denom = max(n_seen, 1)
    macro = dict(average='macro', zero_division=0)
    return (
        n_correct / denom,
        precision_score(labels, preds, **macro),
        recall_score(labels, preds, **macro),
        f1_score(labels, preds, **macro),
        np.mean(fp / (fp + tn + 1e-6)),
        np.mean(tp / (tp + fn + 1e-6)),
        loss_sum / denom,
    )

# -------------------
# Additional Metrics
# -------------------
def compute_model_drift(prev_weights, new_weights):
    """Return the mean per-tensor L2 distance between two state dicts.

    For every key of ``prev_weights``, the L2 norm of ``prev - new`` is taken,
    and the norms are averaged over keys. An empty dict gives 0.0.

    Integer tensors are cast to float first. ``torch.norm`` rejects integer
    dtypes, and every ResNet state dict contains the int64 BatchNorm buffer
    ``num_batches_tracked``.
    """
    def _as_float(t):
        return t if t.is_floating_point() else t.float()

    keys = list(prev_weights.keys())
    if not keys:
        return 0.0
    total = sum(torch.norm(_as_float(prev_weights[k]) - _as_float(new_weights[k])).item() for k in keys)
    return total / len(keys)

def compute_entropy(weights):
    """Shannon entropy (natural log) of an aggregation-weight vector.

    Each weight is clipped to ``[1e-12, 1]`` before ``-sum(p * log p)``, so zero
    weights contribute essentially nothing and no log(0) occurs. The weights
    are not renormalised.
    """
    p = np.clip(np.asarray(weights, dtype=np.float64), 1e-12, 1.0)
    return float(-(p * np.log(p)).sum())

def detect_malicious_clients(true_ids, detected_ids):
    """Compare flagged client IDs against ground truth.

    Returns:
        ``(attack_detection_rate, exclusion_rate)`` where the detection rate is
        |true ∩ detected| / |true| and the exclusion rate is
        |detected| / |true ∪ detected|. Both denominators include ``+1e-6`` so
        empty sets yield 0.0.
    """
    truth, flagged = set(true_ids), set(detected_ids)
    hits = len(truth & flagged)
    misses = len(truth - flagged)
    detection_rate = hits / (hits + misses + 1e-6)
    exclusion_rate = len(flagged) / (len(truth | flagged) + 1e-6)
    return detection_rate, exclusion_rate

# -------------------
# Aggregation / Defenses
# NOTE: aggregators record "detected_malicious_ids" as indices into the list of
# updates passed to them (0..m-1 for the m clients selected in a round), whereas
# "malicious_client_ids" holds global client IDs (0..N-1). Detection metrics are
# only meaningful after mapping those indices back to global IDs, which the
# training loop does via its `selected_clients` list.
# -------------------
# Side outputs of the most recent aggregator call (each aggregator assigns them
# via `global`): indices into that call's `updates` list that were flagged as
# malicious, and the per-update weights reported for the aggregate.
detected_malicious_ids = []
aggregation_weights = []

def fedavg(updates):
    """Average a list of state dicts uniformly (FedAvg with equal client weights).

    Side effects: sets the module-level ``detected_malicious_ids`` to ``[]`` and
    ``aggregation_weights`` to ``1/n`` per update.

    Floating-point (and complex) tensors are averaged directly. Integer tensors,
    such as BatchNorm's int64 ``num_batches_tracked`` present in every ResNet
    state dict, are averaged in float64, rounded and cast back to their dtype.
    This is needed because ``torch.mean`` rejects integer dtypes.

    Raises:
        ValueError: if ``updates`` is empty.
    """
    global detected_malicious_ids, aggregation_weights
    if not updates:
        raise ValueError("fedavg needs at least one update")

    n = len(updates)
    detected_malicious_ids = []
    aggregation_weights = [1.0 / n] * n

    averaged = {}
    for k in updates[0].keys():
        stacked = torch.stack([u[k] for u in updates], 0)
        if stacked.is_floating_point() or stacked.is_complex():
            averaged[k] = stacked.mean(0)
        else:
            averaged[k] = stacked.double().mean(0).round().to(stacked.dtype)
    return averaged

def median(updates):
    """Coordinate-wise median of a list of state dicts.

    Uses ``torch.median`` along the client dimension, which returns the lower of
    the two middle values for an even number of updates. Sets
    ``detected_malicious_ids`` to ``[]`` and ``aggregation_weights`` to uniform.
    """
    global detected_malicious_ids, aggregation_weights
    count = len(updates)
    detected_malicious_ids = []
    aggregation_weights = [1.0 / count] * count

    result = {}
    for key in updates[0].keys():
        values, _ = torch.median(torch.stack([u[key] for u in updates], 0), dim=0)
        result[key] = values
    return result

def trimmed_mean(updates, beta=0.1):
    """Coordinate-wise trimmed mean of a list of state dicts.

    For each coordinate, the ``int(n * beta)`` smallest and largest values are
    dropped and the rest are averaged. If trimming would drop everything
    (``beta >= 0.5``), a plain mean is used instead; negative ``beta`` trims
    nothing. Integer tensors (e.g. BatchNorm ``num_batches_tracked``) are
    averaged in float64, rounded and cast back, because ``torch.mean`` rejects
    integer dtypes.

    Side effects: ``detected_malicious_ids = []``; ``aggregation_weights`` uniform.

    Raises:
        ValueError: if ``updates`` is empty.
    """
    global detected_malicious_ids, aggregation_weights
    if not updates:
        raise ValueError("trimmed_mean needs at least one update")

    n = len(updates)
    detected_malicious_ids = []
    aggregation_weights = [1.0 / n] * n

    trim = max(0, int(n * beta))
    if n - 2 * trim <= 0:
        trim = 0  # trimming would discard every value: degrade to a plain mean

    agg = {}
    for k in updates[0].keys():
        ordered = torch.stack([u[k] for u in updates], dim=0).sort(dim=0).values
        kept = ordered[trim:n - trim]
        if kept.is_floating_point():
            agg[k] = kept.mean(dim=0)
        else:
            agg[k] = kept.double().mean(dim=0).round().to(kept.dtype)
    return agg

def krum(updates, f=1):
    """Krum aggregation: return the single update with the smallest Krum score.

    The score of update i is the sum of its squared L2 distances (summed over all
    state-dict entries) to its ``n - f - 2`` nearest other updates. Ties go to
    the lowest index. Every non-selected update is listed in
    ``detected_malicious_ids`` (exclusion, not a confirmed detection), and the
    selected one gets aggregation weight 1.0.

    Falls back to ``fedavg`` when ``n <= f + 2``.
    """
    global detected_malicious_ids, aggregation_weights
    n = len(updates)
    detected_malicious_ids = []
    aggregation_weights = [1.0 / n] * n

    if n <= f + 2:
        return fedavg(updates)

    def squared_distance(a, b):
        total = 0.0
        for key in a.keys():
            delta = a[key] - b[key]
            total += torch.sum(delta * delta).item()
        return total

    n_closest = n - f - 2
    winner, best = None, None
    for i, candidate in enumerate(updates):
        neighbour_d = sorted(
            squared_distance(candidate, other) for j, other in enumerate(updates) if j != i
        )
        score = sum(neighbour_d[:n_closest])
        if best is None or score < best:
            winner, best = i, score

    detected_malicious_ids = [i for i in range(n) if i != winner]
    aggregation_weights = [1.0 if i == winner else 0.0 for i in range(n)]
    return updates[winner]

def fltrust(updates, reference=None):
    """Trust-weighted aggregation by cosine similarity to a reference update.

    Each update is scored by its per-tensor cosine similarity to ``reference``,
    averaged over tensors. The scores are softmax-normalised into aggregation
    weights, and updates with weight < 0.01 are reported in
    ``detected_malicious_ids``.

    Args:
        updates: list of state dicts.
        reference: optional trusted state dict (e.g. one trained on a clean
            server-side root dataset). When None, ``updates[0]`` is used as a
            placeholder reference, as before.

    Integer tensors (e.g. BatchNorm ``num_batches_tracked``) are cast to float
    before the similarity, because norm-based ops reject integer dtypes.

    NOTE(review): this is a simplification. Published FLTrust also ReLU-clips
    trust scores and rescales update norms; neither is done here.
    """
    global detected_malicious_ids, aggregation_weights
    ref = updates[0] if reference is None else reference

    def _flat(t):
        t = t.flatten()
        return t if t.is_floating_point() else t.float()

    scores = []
    for upd in updates:
        sims = [F.cosine_similarity(_flat(upd[k]), _flat(ref[k]), dim=0).item() for k in upd.keys()]
        scores.append(sum(sims) / max(len(upd), 1))

    weights = torch.softmax(torch.tensor(scores), dim=0)
    aggregation_weights = weights.tolist()
    detected_malicious_ids = [i for i, w in enumerate(aggregation_weights) if w < 0.01]
    return {k: sum(weights[i] * updates[i][k] for i in range(len(updates))) for k in updates[0]}

def flame(updates):
    """Norm-based weighting: softmax over the negative total L2 norm of each update.

    The per-update "norm" is the sum of the L2 norms of its tensors. Smaller
    updates get exponentially larger weight, so with large norm gaps the
    aggregate is dominated by the smallest update. No update is flagged.

    Integer tensors (e.g. BatchNorm ``num_batches_tracked``) are cast to float
    first, because ``torch.norm`` rejects integer dtypes.

    NOTE(review): despite the name, this is not the published FLAME procedure
    (clustering + clipping + noise); it is a simplified norm weighting.
    """
    global detected_malicious_ids, aggregation_weights
    detected_malicious_ids = []

    def _norm(t):
        return torch.norm(t if t.is_floating_point() else t.float()).item()

    norms = [sum(_norm(u[k]) for k in u.keys()) for u in updates]
    weights = torch.softmax(-torch.tensor(norms), dim=0)
    aggregation_weights = weights.tolist()
    return {k: sum(weights[i] * updates[i][k] for i in range(len(updates))) for k in updates[0]}

def flcert(updates, threshold=0.5):
    """Similarity filter: keep updates close (cosine) to the plain average, then average them.

    Each update's cosine similarity to the uniform mean of all updates is
    computed per tensor and averaged over tensors. Updates scoring below
    ``threshold`` are flagged in ``detected_malicious_ids`` (indices into
    ``updates``) and excluded. The kept ones are averaged with ``fedavg`` and
    get weight ``1/len(kept)``. If none passes, the average of all updates is
    returned and the reported weights are uniform, because that is what the
    returned model is.

    Integer tensors (e.g. BatchNorm ``num_batches_tracked``) are cast to float
    before the similarity, because norm-based ops reject integer dtypes.

    NOTE(review): this looks like a simplified filter rather than the published
    FLCert ensemble scheme — confirm before citing it as FLCert.
    """
    global detected_malicious_ids, aggregation_weights
    n = len(updates)
    avg = fedavg(updates)

    def _flat(t):
        t = t.flatten()
        return t if t.is_floating_point() else t.float()

    sims = [
        torch.stack([F.cosine_similarity(_flat(u[k]), _flat(avg[k]), dim=0) for k in u.keys()]).mean().item()
        for u in updates
    ]
    keep = [i for i, s in enumerate(sims) if s >= threshold]
    kept = set(keep)

    if keep:
        result = fedavg([updates[i] for i in keep])
        weights = [1.0 / len(keep) if i in kept else 0.0 for i in range(n)]
    else:
        result = avg
        weights = [1.0 / n] * n

    # fedavg() resets both globals, so they must be assigned after it has run;
    # otherwise this defense would always report no detections.
    detected_malicious_ids = [i for i in range(n) if i not in kept]
    aggregation_weights = weights
    return result

def foolsgold(updates):
    """FoolsGold-style weighting: down-weight updates that closely mirror another update.

    Every update is flattened (keys in sorted order, cast to float32). For each
    update, its maximum cosine similarity to any *other* update gives weight
    ``1 - max_sim`` (clamped at 0). The weights are normalised to sum to 1, and
    updates with weight < 0.01 are reported in ``detected_malicious_ids``.

    If every weight is (numerically) zero, e.g. because each update has an exact
    twin, uniform weights are used. The previous normalisation produced
    all-zero weights and therefore an all-zero model in that case.
    """
    global detected_malicious_ids, aggregation_weights
    n = len(updates)
    grads = [torch.cat([u[k].flatten().float() for k in sorted(u.keys())]) for u in updates]

    # Symmetric similarity matrix; the diagonal stays 0 so nobody matches itself.
    sim = torch.zeros((n, n))
    for i in range(n):
        for j in range(i + 1, n):
            sim[i, j] = sim[j, i] = F.cosine_similarity(grads[i], grads[j], dim=0)

    max_sim = sim.max(dim=1).values
    # Clamp: cosine can round slightly above 1, which would give negative weights.
    w = (1.0 - max_sim).clamp(min=0.0)
    total = float(w.sum())
    if total <= 1e-12:
        w = torch.full((n,), 1.0 / n)
    else:
        w = w / total

    aggregation_weights = w.tolist()
    detected_malicious_ids = [i for i, wi in enumerate(aggregation_weights) if wi < 0.01]
    return {k: sum(w[i] * updates[i][k] for i in range(n)) for k in updates[0]}

def rfa(updates, iters=5, tau=2.0):
    """Robust Federated Aggregation: approximate geometric median via Weiszfeld steps.

    All updates are flattened (keys in sorted order, cast to float32). The
    estimate starts at the coordinate-wise median. Each of the ``iters``
    Weiszfeld steps re-weights the updates by the inverse of their L2 distance
    to the current estimate; ``+1e-10`` avoids division by zero. Updates whose
    final distance exceeds ``tau`` are reported in ``detected_malicious_ids``,
    but they still contribute to the aggregate.

    ``aggregation_weights`` holds the weights of the last Weiszfeld step, i.e.
    the coefficients that actually produced the returned aggregate. They are
    uniform when ``iters == 0``, in which case the result is the
    coordinate-wise median.

    Returns:
        A state dict whose entries are reshaped and cast back to the dtypes of
        ``updates[0]``; integer entries are rounded first.
    """
    global detected_malicious_ids, aggregation_weights
    n = len(updates)
    keys = sorted(updates[0].keys())
    vecs = torch.stack([torch.cat([u[k].flatten().float() for k in keys]) for u in updates])

    center = torch.median(vecs, dim=0).values
    weights = torch.full((n,), 1.0 / n)
    for _ in range(iters):
        dists = torch.stack([torch.norm(v - center) for v in vecs]) + 1e-10
        inv = 1.0 / dists
        weights = inv / inv.sum()
        center = weights @ vecs

    final_dists = torch.stack([torch.norm(v - center) for v in vecs])
    detected_malicious_ids = [i for i, d in enumerate(final_dists.tolist()) if d > tau]
    aggregation_weights = weights.tolist()

    # Unflatten back into the original per-key shapes and dtypes.
    agg, ptr = {}, 0
    for k in keys:
        ref = updates[0][k]
        chunk = center[ptr:ptr + ref.numel()].view(ref.shape)
        if not ref.is_floating_point():
            chunk = chunk.round()
        agg[k] = chunk.to(ref.dtype)
        ptr += ref.numel()
    return agg

def auror(updates):
    """Distance-based filter: keep the half of the updates with the densest neighbourhoods.

    Updates are flattened (keys in sorted order, cast to float32) and pairwise L2
    distances are computed. Each update's score is the mean distance to its
    ``k = max(1, n // 10)`` nearest other updates. The ``max(1, n // 2)``
    lowest-scoring updates are averaged with ``fedavg``. The rest are reported in
    ``detected_malicious_ids`` and get weight 0; kept updates get
    ``1/len(kept)``. With two or fewer updates, plain ``fedavg`` is used.

    NOTE(review): a simplified stand-in for the published Auror defense —
    confirm before comparing against it.
    """
    global detected_malicious_ids, aggregation_weights
    n = len(updates)
    if n <= 2:
        result = fedavg(updates)
        detected_malicious_ids = []
        aggregation_weights = [1.0 / n] * n
        return result

    flat = [torch.cat([u[k].flatten().float() for k in sorted(u.keys())]) for u in updates]
    dist = torch.zeros((n, n))
    for i in range(n):
        for j in range(i + 1, n):
            dist[i, j] = dist[j, i] = torch.norm(flat[i] - flat[j]).item()

    k = max(1, n // 10)
    # A sorted row starts with the zero self-distance, hence the [1:k+1] slice.
    scores = [torch.sort(dist[i])[0][1:k + 1].mean().item() for i in range(n)]
    keep = sorted(range(n), key=lambda i: scores[i])[: max(1, n // 2)]
    kept = set(keep)

    result = fedavg([updates[i] for i in keep])
    # fedavg() resets both globals, so assign them only after it has run;
    # otherwise this defense would always report no detections.
    detected_malicious_ids = [i for i in range(n) if i not in kept]
    aggregation_weights = [1.0 / len(keep) if i in kept else 0.0 for i in range(n)]
    return result

def bulyan(updates):
    """Simplified Bulyan: Krum-score ranking followed by a coordinate-wise trimmed mean.

    With n updates and ``f = max(1, n // 10)`` assumed attackers, every update
    gets a Krum score: the sum of squared L2 distances to its
    ``max(1, n - f - 2)`` nearest other updates. The ``max(1, n - 2f)``
    lowest-scoring updates are kept in a single ranking pass (the published
    Bulyan re-runs Krum iteratively) and combined with
    ``trimmed_mean(beta=0.1)``.

    Side effects: ``detected_malicious_ids`` lists the dropped indices, and
    ``aggregation_weights`` is ``1/len(kept)`` for kept updates, 0 otherwise.
    With fewer than 4 updates, plain ``fedavg`` is used.
    """
    global detected_malicious_ids, aggregation_weights
    n = len(updates)
    if n < 4:
        result = fedavg(updates)
        detected_malicious_ids = []
        aggregation_weights = [1.0 / n] * n
        return result

    f = max(1, n // 10)

    # Pairwise squared distances, summed over all state-dict entries.
    dist = torch.zeros((n, n))
    for i in range(n):
        for j in range(i + 1, n):
            d = 0.0
            for k in updates[i].keys():
                diff = updates[i][k] - updates[j][k]
                d += torch.sum(diff * diff).item()
            dist[i, j] = dist[j, i] = d

    m = max(1, n - f - 2)
    candidates = []
    for i in range(n):
        nearest = sorted(dist[i, j].item() for j in range(n) if j != i)
        candidates.append((i, sum(nearest[:m])))
    candidates.sort(key=lambda c: c[1])

    keep = [i for i, _ in candidates[: max(1, n - 2 * f)]]
    kept = set(keep)
    result = trimmed_mean([updates[i] for i in keep], beta=0.1)

    # trimmed_mean() resets both globals, so assign them only after it has run;
    # otherwise this defense would always report no detections.
    detected_malicious_ids = [i for i in range(n) if i not in kept]
    aggregation_weights = [1.0 / len(keep) if i in kept else 0.0 for i in range(n)]
    return result

def rfed(updates):
    """Softmax weighting by variance-normalised squared distance to the mean update.

    Updates are flattened (keys in sorted order, cast to float32). Each update's
    distance is the mean over coordinates of ``(v - mean)^2 / (var + 1e-6)``,
    using the unbiased per-coordinate variance across updates. Weights are
    ``softmax(-distance)``. No update is flagged.

    A single update gets weight 1.0 directly, because the unbiased variance of
    one sample is NaN and would otherwise turn the whole aggregate into NaN.
    """
    global detected_malicious_ids, aggregation_weights
    detected_malicious_ids = []
    n = len(updates)

    if n == 1:
        w = torch.ones(1)
    else:
        flat = torch.stack([torch.cat([u[k].flatten().float() for k in sorted(u.keys())]) for u in updates])
        var = torch.var(flat, dim=0) + 1e-6
        mean = flat.mean(dim=0)
        dists = []
        for v in flat:
            diff = v - mean
            dists.append((diff * diff / var).mean().item())
        w = torch.softmax(-torch.tensor(dists), dim=0)

    aggregation_weights = w.tolist()
    return {k: sum(w[i] * updates[i][k] for i in range(n)) for k in updates[0].keys()}

# Registry mapping the `defense_method` names accepted by the training loop to
# aggregator callables; each takes a list of client state dicts.
AGGREGATION_FUNCTIONS = dict(
    FedAVG=fedavg,
    TrimmedMean=trimmed_mean,
    Median=median,
    Krum=krum,
    FLTrust=fltrust,
    FLAME=flame,
    FLCert=flcert,
    FoolsGold=foolsgold,
    RFA=rfa,
    Auror=auror,
    Bulyan=bulyan,
    RFed=rfed,
)

# -------------------
# Federated Training (CIFAR-100)
# Aggregators report detected clients as indices into the round's update list;
# the loop maps them back to global client IDs before computing detection metrics.
# -------------------
def federated_training_cifar100(
    rounds=20,
    epochs=1,
    client_fraction=0.1,
    attack_types=None,
    attack_ratios=None,
    defense_method='FedAVG',
    iid=True,
    alpha=0.5,
    num_clients=100,
    batch_size=64,
    lr=0.01,
    momentum=0.9,
    weight_decay=5e-4,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    seed=42,
    output_csv='fl_cifar100_resnet18_results.csv',
):
    """Simulate federated training of ResNet-18 on CIFAR-100 and log per-round metrics.

    Steps:
    1. Seed the RNGs and load CIFAR-100 (downloaded to ./data).
    2. Partition the training set across ``num_clients`` clients (IID or
       Dirichlet(``alpha``) non-IID) and optionally attack some clients' data.
    3. Each round, sample ``max(1, int(client_fraction * num_clients))``
       clients and train a copy of the global model on each for ``epochs``
       epochs with SGD.
    4. Aggregate the local state dicts with ``defense_method`` and evaluate on
       the test set.

    Detection metrics compare the clients flagged by the aggregator in a round
    with the malicious clients that actually took part in that round. A
    malicious client that was not selected cannot be detected.

    Args:
        rounds, epochs: number of communication rounds / local epochs.
        client_fraction: fraction of clients sampled per round.
        attack_types, attack_ratios: passed to ``apply_combined_attacks``.
            Attacks are applied only when both are truthy.
        defense_method: key of ``AGGREGATION_FUNCTIONS``.
        iid, alpha, num_clients, batch_size: passed to ``partition_data``.
        lr, momentum, weight_decay: local SGD hyper-parameters.
        device: torch device string.
        seed: seed for ``set_seed``.
        output_csv: path of the per-round CSV written at the end.

    Returns:
        pandas.DataFrame with one row per round (also written to ``output_csv``).

    Raises:
        ValueError: if ``defense_method`` is unknown. This is checked before any
            data is loaded or any client is trained.
    """
    global malicious_client_ids, detected_malicious_ids, aggregation_weights

    if defense_method not in AGGREGATION_FUNCTIONS:
        raise ValueError(f"Unknown defense method: {defense_method}")
    aggregate = AGGREGATION_FUNCTIONS[defense_method]

    set_seed(seed)

    # CIFAR-100 per-channel normalisation (no augmentation), shared by train and test.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408),
                             (0.2675, 0.2565, 0.2761)),
    ])
    cifar_train = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
    cifar_test = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)

    client_data = partition_data(cifar_train, num_clients=num_clients, alpha=alpha, iid=iid, batch_size=batch_size)

    malicious_client_ids = []
    if attack_types and attack_ratios:
        # apply_combined_attacks() also rebinds the module-level malicious_client_ids.
        client_data = apply_combined_attacks(client_data, attack_types, attack_ratios, num_classes=100)
    malicious_set = set(malicious_client_ids)

    global_model = ResNet18CIFAR(num_classes=100).to(device)
    global_weights = {k: v.detach().cpu().clone() for k, v in global_model.state_dict().items()}
    prev_weights = {k: v.clone() for k, v in global_weights.items()}

    test_loader = DataLoader(cifar_test, batch_size=256, shuffle=False, num_workers=2, pin_memory=True)

    results = []
    for rnd in range(rounds):
        detected_malicious_ids = []
        aggregation_weights = []

        m = max(1, int(client_fraction * num_clients))
        selected_clients = random.sample(list(client_data.keys()), m)

        local_updates = []
        train_loss_sum, train_correct, train_samples = 0.0, 0, 0
        for global_cid in selected_clients:
            state, loss_sum, correct, seen = _train_local_client(
                global_weights, client_data[global_cid], epochs, lr, momentum, weight_decay, device
            )
            local_updates.append(state)
            train_loss_sum += loss_sum
            train_correct += correct
            train_samples += seen

        new_global_weights = aggregate(local_updates)

        global_weights = {k: v.detach().cpu().clone() for k, v in new_global_weights.items()}
        global_model.load_state_dict(global_weights, strict=True)

        acc, prec, rec, f1, fpr, tpr, test_loss = evaluate_model(global_model, test_loader, device, num_classes=100)

        # The aggregator reported indices into local_updates (0..m-1); map them to global IDs.
        detected_global_ids = [selected_clients[i] for i in detected_malicious_ids if 0 <= i < len(selected_clients)]
        # Ground truth restricted to this round's participants.
        malicious_selected = [cid for cid in selected_clients if cid in malicious_set]

        attack_acc, exclusion_rate = detect_malicious_clients(malicious_selected, detected_global_ids)
        entropy = compute_entropy(aggregation_weights) if aggregation_weights else 0.0
        drift = compute_model_drift(prev_weights, global_weights)
        prev_weights = {k: v.clone() for k, v in global_weights.items()}

        avg_train_loss = train_loss_sum / max(train_samples, 1)
        avg_train_acc = train_correct / max(train_samples, 1)

        results.append({
            'Round': rnd + 1,
            'Accuracy': acc, 'Precision': prec, 'Recall': rec, 'F1': f1,
            'FPR': fpr, 'TPR': tpr, 'TestLoss': test_loss,
            'TrainLoss': avg_train_loss, 'TrainAccuracy': avg_train_acc,
            'AttackDetectAcc': attack_acc, 'ExclusionRate': exclusion_rate,
            'Entropy': entropy, 'ModelDrift': drift,
            'NumMaliciousSelected': len(malicious_selected),
        })

        print(
            f"Round {rnd+1:03d} | "
            f"Acc={acc:.4f} F1={f1:.4f} TestLoss={test_loss:.4f} | "
            f"TrainAcc={avg_train_acc:.4f} TrainLoss={avg_train_loss:.4f} | "
            f"DetectAcc={attack_acc:.4f} ExclRate={exclusion_rate:.4f} | "
            f"Entropy={entropy:.4f} Drift={drift:.4f}"
        )

    df = pd.DataFrame(results)
    df.to_csv(output_csv, index=False)
    print(f"Done. Saved: {output_csv}")
    return df


def _train_local_client(global_weights, loader, epochs, lr, momentum, weight_decay, device):
    """Train a fresh ResNet18CIFAR, initialised from ``global_weights``, on one client's loader.

    Returns:
        ``(state_dict_on_cpu, summed_loss, num_correct, num_samples)``, where
        the loss is summed over samples (batch-mean loss x batch size) across
        all epochs.
    """
    model = ResNet18CIFAR(num_classes=100).to(device)
    model.load_state_dict(global_weights, strict=True)
    model.train()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)

    loss_sum, correct, seen = 0.0, 0, 0
    for _ in range(epochs):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            out = model(x)
            loss = F.cross_entropy(out, y)
            loss.backward()
            optimizer.step()

            loss_sum += loss.item() * x.size(0)
            correct += (out.argmax(1) == y).sum().item()
            seen += x.size(0)

    return {k: v.detach().cpu() for k, v in model.state_dict().items()}, loss_sum, correct, seen


# -------------------
# Example usage
# -------------------
if __name__ == "__main__":
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Example run: Dirichlet non-IID split (alpha=0.3), all three attacks on
    # 20% of clients each, aggregated with Bulyan.
    df = federated_training_cifar100(
        rounds=10, epochs=1, client_fraction=0.1,
        attack_types=['label_flipping', 'feature_manipulation', 'poisoning'],
        attack_ratios=dict(label_flipping=0.2, feature_manipulation=0.2, poisoning=0.2),
        defense_method='Bulyan',
        iid=False, alpha=0.3,
        num_clients=100, batch_size=64, lr=0.01,
        device=device, seed=42,
    )
