In [None]:

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
import torch.nn as nn
import torch.nn.functional as F
from scipy.fft import fft2, fftshift
from skimage.filters import sobel
from sklearn.cluster import DBSCAN
import numpy as np
import copy


# Define the CNN architecture
class ConvNet(nn.Module):
    """Compact CNN classifier: two 3x3 convolutions, one max-pool, two dense layers.

    Both convolutions use ``padding=1`` with ``stride=1`` so they keep the
    spatial size; the single 2x2 pool halves it.  ``fc1`` expects
    ``64 * 14 * 14`` features, i.e. a single-channel input such as 28x28.
    The network returns 10 raw logits per sample (no softmax).
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 1 -> 32 -> 64 channels.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Classifier head operating on the flattened feature maps.
        self.fc1 = nn.Linear(in_features=64 * 14 * 14, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        """Map a batch of images to a ``(batch, 10)`` tensor of logits."""
        feature_maps = self.pool(F.relu(self.conv2(F.relu(self.conv1(x)))))
        # Collapse everything except the batch dimension.
        flattened = feature_maps.view(feature_maps.size(0), -1)
        hidden = self.dropout(F.relu(self.fc1(flattened)))
        return self.fc2(hidden)


# Load data (each client will load its own data in a real FL scenario)
def load_data(transform, datasets='MNIST'):
    """Download (if needed) and return the train and test splits of a dataset.

    Args:
        transform: Transform passed to both torchvision dataset constructors.
        datasets: Dataset name, compared case-insensitively.  ``'MNIST'``
            selects MNIST; any other value selects CIFAR-10.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple.
    """
    # Choose the dataset class and its on-disk location once, then build
    # both splits with identical arguments apart from ``train``.
    if datasets.upper() == 'MNIST':
        dataset_cls, data_root = torchvision.datasets.MNIST, "./data/mnist"
    else:
        dataset_cls, data_root = torchvision.datasets.CIFAR10, "./data/cifar-10-python"

    train_dataset = dataset_cls(root=data_root, train=True, download=True, transform=transform)
    test_dataset = dataset_cls(root=data_root, train=False, download=True, transform=transform)
    return train_dataset, test_dataset



# Partition dataset for federated learning
def partition_dataset(dataset, n_clients):
    """Randomly split ``dataset`` into ``n_clients`` disjoint subsets.

    Every sample is assigned to exactly one client.  When the dataset size is
    not a multiple of ``n_clients`` the first ``len(dataset) % n_clients``
    clients receive one extra sample, so the split lengths always sum to
    ``len(dataset)`` (``random_split`` rejects lengths that do not).

    Args:
        dataset: Any sized, indexable dataset accepted by ``random_split``.
        n_clients: Number of partitions to create; must be at least 1.

    Returns:
        A list of ``n_clients`` ``torch.utils.data.Subset`` objects.

    Raises:
        ValueError: If ``n_clients`` is smaller than 1.
    """
    if n_clients < 1:
        raise ValueError(f"n_clients must be >= 1, got {n_clients}")

    base_size, remainder = divmod(len(dataset), n_clients)
    # Hand the leftover samples out one by one instead of dropping them.
    lengths = [base_size + 1 if i < remainder else base_size for i in range(n_clients)]
    return random_split(dataset, lengths)


# Detection Functions
def detect_checkerboard_or_adversarial_patterns(images):
    """Return positions of images whose spectrum and edges exceed fixed thresholds.

    Only the first channel (index 0 on the leading axis) of each tensor is
    inspected.  An image is flagged when all three conditions hold:

    * the mean magnitude of its centred 2-D FFT is above 50,
    * the standard deviation of that magnitude is above 30,
    * the mean of its Sobel edge map is above 0.1.

    Args:
        images: Iterable of tensors indexable as ``[channel, row, col]``.

    Returns:
        List of integer positions (in iteration order) of flagged images.
    """
    suspicious = []
    for position, tensor in enumerate(images):
        channel = tensor.numpy()[0, :, :]
        spectrum = np.abs(fftshift(fft2(channel)))
        edge_strength = sobel(channel).mean()
        exceeds_all = (
            spectrum.mean() > 50
            and spectrum.std() > 30
            and edge_strength > 0.1
        )
        if exceeds_all:
            suspicious.append(position)
    return suspicious


def detect_poisoned_labels(dataset):
    """Return positions of samples whose label DBSCAN marks as noise.

    The labels are clustered as a single numeric feature with
    ``DBSCAN(eps=1, min_samples=5)``; samples assigned to the noise cluster
    (cluster id ``-1``) are reported.

    Args:
        dataset: Iterable of ``(image, label)`` pairs.

    Returns:
        List of integer positions of the samples flagged as noise.
    """
    # One row per sample, one column holding the label value.
    label_column = np.array([target for _, target in dataset]).reshape(-1, 1)
    assignments = DBSCAN(eps=1, min_samples=5).fit(label_column).labels_
    return [
        position
        for position, cluster_id in enumerate(assignments)
        if cluster_id == -1
    ]


def detect_noisy_inputs(images):
    """Return positions of images whose first-channel variance exceeds 0.2.

    Args:
        images: Iterable of tensors indexable as ``[channel, row, col]``;
            only channel 0 is examined.

    Returns:
        List of integer positions (in iteration order) of flagged images.
    """
    return [
        position
        for position, tensor in enumerate(images)
        if np.var(tensor.numpy()[0, :, :]) > 0.2
    ]


# Preprocessing Functions
def preprocess_images(dataset, flagged_indices):
    """Return a copy of ``dataset`` with the flagged images Gaussian-blurred.

    Args:
        dataset: Iterable of ``(image, label)`` pairs.
        flagged_indices: Iterable of integer positions (in iteration order
            of ``dataset``) whose image should be blurred with a 3x3 kernel.

    Returns:
        A new list of ``(image, label)`` tuples; unflagged samples are passed
        through unchanged.
    """
    # Build the lookup once: membership on a list is O(len(flagged_indices))
    # per sample, and a one-shot iterable would be consumed by the first test.
    flagged = {int(i) for i in flagged_indices}

    new_dataset = []
    for idx, (image, label) in enumerate(dataset):
        if idx in flagged:
            image = torchvision.transforms.functional.gaussian_blur(image, kernel_size=(3, 3))
        new_dataset.append((image, label))
    return new_dataset


def clean_labels(dataset, poisoned_indices):
    """Return a copy of ``dataset`` with poisoned labels replaced by ``-1``.

    Args:
        dataset: Iterable of ``(image, label)`` pairs.
        poisoned_indices: Iterable of integer positions (in iteration order
            of ``dataset``) whose label should be replaced by the sentinel.

    Returns:
        A new list of ``(image, label)`` tuples; samples that are not listed
        keep their original label.
    """
    # Build the lookup once: membership on a list is O(len(poisoned_indices))
    # per sample, and a one-shot iterable would be consumed by the first test.
    poisoned = {int(i) for i in poisoned_indices}

    return [
        (image, -1 if idx in poisoned else label)
        for idx, (image, label) in enumerate(dataset)
    ]


# Simulate Malicious Client Updates
def simulate_malicious_client_update(global_model):
    """
    Build a fake client submission whose parameters are pure noise.

    The global model is deep-copied (so it is left untouched) and every
    parameter of the copy is overwritten with values from ``torch.rand_like``.
    The copy is returned.
    """
    forged_model = copy.deepcopy(global_model)
    for weight in forged_model.parameters():
        # Replace the underlying data; the Parameter object itself is kept.
        weight.data = torch.rand_like(weight)
    return forged_model


# Detect Malicious Clients
def detect_malicious_clients(global_model, client_models, threshold=2.0):
    """
    Detect malicious clients using L2 norm outlier detection.

    For each client the distance to the global model is the sum, over all
    ``state_dict`` entries, of the L2 norm of the difference between the
    global and the client tensor.  A client is flagged when its distance
    exceeds ``mean + threshold * std`` of the finite distances.

    Clients that are ``None`` (no submission) or whose distance is not finite
    (e.g. NaN/inf weights) are always flagged.  They are kept out of the
    mean/std computation: a single ``inf`` would otherwise turn the mean into
    ``inf`` and the std into ``nan``, which silently disables detection for
    every client.

    NOTE: the population standard deviation is used, so a single outlier
    among ``n`` clients can reach a z-score of at most ``sqrt(n - 1)``.  With
    the default ``threshold=2.0`` at least 6 finite distances are needed
    before any client can be flagged by the statistical test.

    Args:
        global_model: Reference model (``torch.nn.Module``).
        client_models: Sequence of client models; entries may be ``None``.
        threshold: Number of standard deviations above the mean at which a
            client is considered an outlier.

    Returns:
        Sorted list of integer positions in ``client_models`` that were flagged.
    """
    global_state = global_model.state_dict()
    distances = []

    for client_model in client_models:
        if client_model is None:
            distances.append(float('inf'))  # Missing submission: always flagged
            continue

        client_state = client_model.state_dict()
        total = 0.0
        for key, global_tensor in global_state.items():
            diff = global_tensor - client_state[key]
            # torch.norm rejects integer tensors (e.g. BatchNorm's
            # num_batches_tracked), so promote those to float first.
            if not (diff.is_floating_point() or diff.is_complex()):
                diff = diff.float()
            total += torch.norm(diff).item()
        distances.append(total)

    finite_distances = [dist for dist in distances if np.isfinite(dist)]
    if finite_distances:
        mean_dist = np.mean(finite_distances)
        std_dist = np.std(finite_distances)
        if std_dist == 0:
            # Identical distances: keep a tiny margin so rounding noise in the
            # mean cannot push every client over the cutoff.
            std_dist = 1e-6
        cutoff = mean_dist + threshold * std_dist
    else:
        # No usable statistics; only the non-finite entries get flagged below.
        cutoff = float('inf')

    return [
        idx
        for idx, dist in enumerate(distances)
        if not np.isfinite(dist) or dist > cutoff
    ]


# Aggregate Models
def server_aggregate(global_model, client_models, malicious_clients):
    """
    Aggregate model weights from clients into the global model, excluding malicious clients.

    Every ``state_dict`` entry of the global model is replaced by the
    element-wise mean of that entry over all accepted clients, i.e. clients
    that are not ``None`` and whose position is not in ``malicious_clients``.
    The global model is updated in place.

    Integer entries (such as BatchNorm's ``num_batches_tracked``) are averaged
    in floating point and cast back to their original dtype, because
    ``torch.mean`` does not accept integer tensors.

    If no client is accepted, the global model is left unchanged and a
    ``RuntimeWarning`` is issued instead of failing on an empty stack.

    Args:
        global_model: Model to update (``torch.nn.Module``).
        client_models: Sequence of client models; entries may be ``None``.
        malicious_clients: Iterable of integer positions to exclude.

    Returns:
        None.
    """
    import warnings

    excluded = {int(idx) for idx in malicious_clients}
    valid_states = [
        client_model.state_dict()
        for idx, client_model in enumerate(client_models)
        if client_model is not None and idx not in excluded
    ]

    if not valid_states:
        warnings.warn(
            "server_aggregate: no valid client models; global model left unchanged",
            RuntimeWarning,
        )
        return

    global_state = global_model.state_dict()
    for key in global_state.keys():
        stacked = torch.stack([state[key] for state in valid_states], dim=0)
        if stacked.is_floating_point() or stacked.is_complex():
            global_state[key] = torch.mean(stacked, dim=0)
        else:
            # Average in float, then restore the entry's integer dtype.
            global_state[key] = torch.mean(stacked.float(), dim=0).to(stacked.dtype)

    global_model.load_state_dict(global_state)


# Federated Learning
def federated_learning(n_clients, global_epochs, local_epochs, malicious_client_idx, attack_start_round):
    """Run a federated-averaging simulation with one simulated malicious client.

    Each round:
      1. every honest client is reset to the current global weights and
         trains on its own data partition for ``local_epochs`` passes;
      2. from round index ``attack_start_round`` onwards the client at
         ``malicious_client_idx`` submits a model with random parameters
         instead of training;
      3. outlier clients are detected, excluded, and the remaining client
         weights are averaged into the global model;
      4. the global model is evaluated on the test set and the accuracy is
         printed.

    After the last round the global weights are saved to
    ``robust_federated_model.pth``.

    Args:
        n_clients: Number of simulated clients (data partitions).
        global_epochs: Number of communication rounds.
        local_epochs: Number of passes each honest client makes over its
            local data per round.
        malicious_client_idx: Position of the client that turns malicious.
        attack_start_round: Zero-based round index at which the attack starts.

    Returns:
        None.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    train_dataset, test_dataset = load_data(transform)
    client_datasets = partition_dataset(train_dataset, n_clients)
    client_loaders = [DataLoader(dataset, batch_size=64, shuffle=True) for dataset in client_datasets]
    test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

    global_model = ConvNet().to(device)
    client_models = [copy.deepcopy(global_model) for _ in range(n_clients)]
    optimizers = [torch.optim.Adam(client_model.parameters(), lr=0.001) for client_model in client_models]

    for global_epoch in range(global_epochs):
        print(f"Global Epoch {global_epoch + 1}/{global_epochs}")

        for client_idx in range(n_clients):
            if global_epoch >= attack_start_round and client_idx == malicious_client_idx:
                # The replaced model is no longer bound to optimizers[client_idx];
                # that is fine because this client never trains again.
                client_models[client_idx] = simulate_malicious_client_update(global_model)
                continue

            client_model = client_models[client_idx]
            optimizer = optimizers[client_idx]

            # Broadcast: start the round from the aggregated global weights.
            # load_state_dict copies in place, so the optimizer stays bound
            # to the same parameter tensors.
            client_model.load_state_dict(global_model.state_dict())
            client_model.train()  # make sure dropout is active while training

            # Train each client's model for the requested number of local passes
            for _ in range(local_epochs):
                for data, labels in client_loaders[client_idx]:
                    data, labels = data.to(device), labels.to(device)
                    optimizer.zero_grad()
                    output = client_model(data)
                    loss = F.cross_entropy(output, labels)
                    loss.backward()
                    optimizer.step()

        # Detect and mitigate malicious clients
        malicious_clients = detect_malicious_clients(global_model, client_models)
        print(f"Malicious clients detected: {malicious_clients}")

        server_aggregate(global_model, client_models, malicious_clients)

        # Evaluate the global model
        accuracy = evaluate_model(global_model, test_loader, device)
        print(f"Test Accuracy after round {global_epoch + 1}: {accuracy:.2f}%")

    torch.save(global_model.state_dict(), "robust_federated_model.pth")


# Model Evaluation
def evaluate_model(model, test_loader, device):
    """Return the top-1 accuracy of ``model`` on ``test_loader`` as a percentage.

    The model is switched to evaluation mode (and left in it) and gradients
    are disabled while iterating.

    Args:
        model: Classifier returning per-class scores along dimension 1.
        test_loader: Iterable of ``(data, labels)`` batches.
        device: Device each batch is moved to before the forward pass.

    Returns:
        ``100 * correct / total`` over all samples seen.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch, targets in test_loader:
            batch = batch.to(device)
            targets = targets.to(device)
            # Predicted class = index of the highest score in each row.
            predictions = torch.max(model(batch), 1).indices
            n_seen += targets.size(0)
            n_correct += (predictions == targets).sum().item()
    return 100 * n_correct / n_seen


# Main Function
if __name__ == "__main__":
    # Demo configuration: 10 clients, 10 rounds, client 7 turns malicious
    # from round index 5 onwards.
    run_config = {
        "n_clients": 10,
        "global_epochs": 10,
        "local_epochs": 2,
        "malicious_client_idx": 7,
        "attack_start_round": 5,
    }
    federated_learning(**run_config)




Global Epoch 1/10
Malicious clients detected: []
Test Accuracy after round 1: 93.18%
Global Epoch 2/10
Malicious clients detected: [7]
Test Accuracy after round 2: 96.00%
Global Epoch 3/10
Malicious clients detected: [7]
Test Accuracy after round 3: 96.65%
Global Epoch 4/10
Malicious clients detected: [7]
Test Accuracy after round 4: 96.96%
Global Epoch 5/10
Malicious clients detected: [7]
Test Accuracy after round 5: 97.33%
Global Epoch 6/10
Malicious clients detected: [7]
Test Accuracy after round 6: 97.66%
Global Epoch 7/10
Malicious clients detected: [7]
Test Accuracy after round 7: 97.76%
Global Epoch 8/10
Malicious clients detected: [7]
Test Accuracy after round 8: 97.77%
Global Epoch 9/10
Malicious clients detected: [7]
Test Accuracy after round 9: 97.85%
Global Epoch 10/10
Malicious clients detected: [7]
Test Accuracy after round 10: 97.75%
