In [16]:
import torch
from torch import nn
import torch.nn.functional as F

class MNISTNet(nn.Module):
    """
    Compact convolutional classifier for single-channel 28x28 digit images.

    Data flow (implemented in ``forward``):
        conv 1->32 (3x3) -> ReLU -> conv 32->64 (3x3) -> ReLU -> 2x2 max-pool
        -> dropout(0.25) -> flatten -> linear 9216->128 -> ReLU
        -> dropout(0.5) -> linear 128->10 -> log-softmax
    """

    def __init__(self):
        """
        Create the layers. They are only instantiated here; ``forward``
        decides the order in which they are applied.
        """
        super().__init__()

        # Two un-padded 3x3 convolutions with stride 1.
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)

        # Regularisation: lighter after the conv stack, heavier before the head.
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)

        # Classifier head. For a 28x28 input the pooled feature map is
        # 64 channels x 12 x 12 = 9216 values.
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """
        Run a batch through the network (invoked via ``model(batch)``).

        Args:
            x (Tensor): Image batch of shape (Batch_Size, 1, 28, 28).

        Returns:
            Tensor: Per-class log-probabilities of shape (Batch_Size, 10).
        """
        # Convolutional feature extractor followed by spatial down-sampling.
        features = F.relu(self.conv2(F.relu(self.conv1(x))))
        features = self.dropout1(F.max_pool2d(features, 2))

        # Collapse everything except the batch dimension, then classify.
        hidden = F.relu(self.fc1(torch.flatten(features, 1)))
        logits = self.fc2(self.dropout2(hidden))

        return F.log_softmax(logits, dim=1)

def get_model():
    """Factory: build and return a new, freshly initialised MNISTNet."""
    model = MNISTNet()
    return model

In [17]:
!pip install matplotlib

Defaulting to user installation because normal site-packages is not writeable


In [18]:
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, TensorDataset

# Device that models and batches are moved to: the first CUDA GPU when one is
# available, otherwise the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Module-level cache filled by prepare_dataset(); stays None until that
# function has run once.
_DATA_CACHE = None

def get_iid_partitions(dataset, num_clients, seed=1001):
    """
    IID helper: shuffle all sample indices and split them into near-equal chunks.

    A private ``np.random.RandomState(seed)`` drives the shuffle, so the
    permutation is the same one ``np.random.seed(seed)`` followed by
    ``np.random.shuffle`` would produce, but the process-wide NumPy RNG is
    left untouched (callers that seeded it keep their random stream).

    Args:
        dataset: Any sized object; only ``len(dataset)`` is used.
        num_clients (int): Number of partitions to create (must be >= 1).
        seed (int, optional): Seed for the shuffle. Defaults to 1001.

    Returns:
        list[list[int]]: ``num_clients`` lists of indices. Sizes differ by at
        most one; together they cover ``range(len(dataset))`` exactly once.

    Raises:
        ValueError: If ``num_clients`` is smaller than 1.
    """
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")

    # Local generator: deterministic for a given seed, no global side effect.
    rng = np.random.RandomState(seed)

    indices = np.arange(len(dataset))
    rng.shuffle(indices)

    # array_split tolerates lengths that are not divisible by num_clients.
    return [chunk.tolist() for chunk in np.array_split(indices, num_clients)]

def prepare_dataset(num_clients, seed=1001):
    """
    Centralized data loader with a single-entry cache.

    Loads (downloading on first use) the MNIST train and test splits,
    normalises them, and partitions the training indices IID across
    ``num_clients`` clients.

    The result is cached in the module-level ``_DATA_CACHE``. The cache is
    only reused when it was built for the same ``(num_clients, seed)``; a call
    with different arguments keeps the already-loaded datasets but recomputes
    the partition, instead of silently returning partitions made for another
    configuration.

    Args:
        num_clients (int): Number of client partitions to create.
        seed (int, optional): Seed for the partition shuffle. Defaults to 1001.

    Returns:
        tuple: ``(train_dataset, test_dataset, user_groups)`` where
        ``user_groups[i]`` is the list of training indices of client ``i``.
    """
    global _DATA_CACHE

    cache_key = (num_clients, seed)
    # The key lives on the function object so the cached tuple keeps its shape.
    cached_key = getattr(prepare_dataset, "_cache_key", None)

    if _DATA_CACHE is not None and cached_key == cache_key:
        return _DATA_CACHE

    if _DATA_CACHE is not None:
        # Datasets do not depend on the partition arguments; reuse them.
        train_dataset, test_dataset, _ = _DATA_CACHE
    else:
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

        train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
        test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)

    user_groups = get_iid_partitions(train_dataset, num_clients, seed)

    _DATA_CACHE = (train_dataset, test_dataset, user_groups)
    prepare_dataset._cache_key = cache_key
    return _DATA_CACHE

def add_square_trigger(image, trigger_size=4, x_pos=24, y_pos=24, pixel_value=2.8):
    """
    Return a copy of ``image`` with a constant-valued square patch stamped on it.

    The input tensor is not modified. The patch covers every index of the
    first dimension; ``x_pos`` is the start offset along dimension 1 and
    ``y_pos`` the start offset along dimension 2, each spanning
    ``trigger_size`` elements.

    The default ``pixel_value`` of 2.8 is roughly what a fully white pixel
    becomes after the ``Normalize((0.1307,), (0.3081,))`` transform used for
    MNIST in this notebook.
    """
    span_dim1 = slice(x_pos, x_pos + trigger_size)
    span_dim2 = slice(y_pos, y_pos + trigger_size)

    stamped = image.clone()
    stamped[:, span_dim1, span_dim2] = pixel_value
    return stamped

def create_backdoor_test_set(test_dataset, target_label=0):
    """
    Build the sample list used to measure Attack Success Rate (ASR).

    Every sample whose true label differs from ``target_label`` gets the
    square trigger stamped on it and is relabelled as ``target_label``;
    samples that already carry the target label are left out.

    Returns:
        list: ``(triggered_image, target_label)`` tuples in dataset order.
    """
    # Fetch each sample exactly once, by index, like a plain indexed loop.
    samples = (test_dataset[idx] for idx in range(len(test_dataset)))

    return [
        (add_square_trigger(image), target_label)
        for image, true_label in samples
        if true_label != target_label
    ]

def evaluate_backdoor(model, test_dataset, target_label=0):
    """
    Measure the Attack Success Rate (ASR) of a backdoor trigger.

    All samples whose true label is not ``target_label`` are stamped with the
    trigger (see ``create_backdoor_test_set``) and passed through ``model``;
    the ASR is the fraction of them classified as ``target_label``.

    Side effect: ``model`` is switched to evaluation mode.

    Args:
        model (nn.Module): Model to probe.
        test_dataset: Indexable dataset yielding ``(image, label)`` pairs.
        target_label (int, optional): Label the backdoor tries to force.
            Defaults to 0.

    Returns:
        float: ASR in ``[0, 1]``. Returns 0.0 when there are no non-target
        samples to attack (previously this raised ZeroDivisionError).
    """
    poisoned_data = create_backdoor_test_set(test_dataset, target_label=target_label)
    if not poisoned_data:
        # Nothing to attack, so no attack can succeed.
        return 0.0

    poisoned_loader = DataLoader(poisoned_data, batch_size=64, shuffle=False)

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in poisoned_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            # Labels here are all target_label, so a match is a successful attack.
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return correct / total

In [19]:
#pylint: disable=unused-variable
from torch import optim

# Mini-batch size used for the clients' DataLoaders.
BATCH_SIZE = 32

def train(model, trainloader, epochs=1):
    """
    Run local SGD training on ``model`` in place.

    Uses negative log-likelihood loss (the model is expected to emit
    log-probabilities) and SGD with lr=0.01, momentum=0.9. A fresh optimizer
    is created on every call, so momentum does not carry over between calls.

    Args:
        model (nn.Module): The PyTorch model to train.
        trainloader (DataLoader): Yields ``(images, labels)`` batches.
        epochs (int, optional): Number of passes over ``trainloader``.
            Defaults to 1.

    Returns:
        dict: ``model.state_dict()`` after training (not cloned here).
    """
    loss_fn = nn.NLLLoss()
    sgd = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    model.train()

    for _ in range(epochs):
        for batch_images, batch_labels in trainloader:
            batch_images = batch_images.to(DEVICE)
            batch_labels = batch_labels.to(DEVICE)

            sgd.zero_grad()
            batch_loss = loss_fn(model(batch_images), batch_labels)
            batch_loss.backward()
            sgd.step()

    return model.state_dict()

def test(model, testloader):
    """
    Evaluates the model on the provided test data.

    The loss is accumulated with ``reduction='sum'`` so that dividing by the
    number of samples yields the true per-sample average. Summing per-batch
    *means* and dividing by the sample count (the previous behaviour)
    under-reports the loss by roughly the batch size.

    Side effect: ``model`` is switched to evaluation mode.

    Args:
        model (nn.Module): The PyTorch model to evaluate.
        testloader (DataLoader): The DataLoader containing test data.

    Returns:
        tuple: A tuple containing (average_loss, accuracy), both computed over
        ``len(testloader.dataset)`` samples.
    """
    # Sum (not mean) per batch; normalised once over the whole dataset below.
    criterion = nn.NLLLoss(reduction='sum')
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            test_loss += criterion(outputs, labels).item()
            pred = outputs.argmax(dim=1, keepdim=True)
            correct += pred.eq(labels.view_as(pred)).sum().item()

    num_samples = len(testloader.dataset)
    test_loss /= num_samples
    accuracy = correct / num_samples
    return test_loss, accuracy

class Client:
    """
    An honest Federated Learning participant.

    Owns a private model instance, a DataLoader over its shard of the
    training set, and a DataLoader over the full shared test set.
    """
    def __init__(self, client_id, total_clients=100):
        """
        Build the local model and data loaders for one client.

        Args:
            client_id (int): Index of this client's partition in the
                partition list returned by ``prepare_dataset``.
            total_clients (int, optional): Number of partitions the training
                set is divided into. Defaults to 100.
        """
        self.client_id = client_id
        self.model = get_model().to(DEVICE)

        train_dataset, test_dataset, user_groups = prepare_dataset(total_clients)

        # This client's shard of the training data.
        local_indices = user_groups[client_id]
        local_train = Subset(train_dataset, local_indices)

        self.trainloader = DataLoader(local_train, batch_size=BATCH_SIZE, shuffle=True)
        self.testloader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

    def set_weights(self, global_weights):
        """
        Overwrite the local model's parameters with ``global_weights``.

        Args:
            global_weights (dict): A state_dict compatible with the local model.
        """
        self.model.load_state_dict(global_weights)

    def fit(self, global_weights, epochs=1):
        """
        One round of local work: sync with the server, train, then evaluate.

        Args:
            global_weights (dict): The current global model weights.
            epochs (int, optional): Number of local training epochs.
                Defaults to 1.

        Returns:
            tuple: ``(updated_weights, num_samples, metrics_dict)`` where
            ``num_samples`` is the size of the local training set and
            ``metrics_dict`` has keys ``'loss'`` and ``'accuracy'`` measured
            on the test loader after training.
        """
        self.set_weights(global_weights)

        trained_state = train(self.model, self.trainloader, epochs=epochs)
        eval_loss, eval_acc = test(self.model, self.testloader)

        metrics = {'loss': eval_loss, 'accuracy': eval_acc}
        return trained_state, len(self.trainloader.dataset), metrics

class MaliciousClient(Client):
    """
    Represents a compromised Federated Learning client that performs data poisoning.

    This client injects triggers into images and relabels them (backdoor
    attack) and scales its weight update to overpower the global model
    (model poisoning).
    """
    def __init__(self, client_id, total_clients=100, target_label=0, poison_fraction=1.0,
                 boost_factor=2.0):
        """
        Initializes the malicious client and poisons the local dataset.

        Args:
            client_id (int): The unique index of the client.
            total_clients (int): Total number of clients.
            target_label (int, optional): The label to misclassify poisoned images as. Defaults to 0.
            poison_fraction (float, optional): Probability (0.0 to 1.0) with which each
                non-target training image is poisoned. Defaults to 1.0.
            boost_factor (float, optional): Multiplier applied to the local weight
                update before it is reported to the server. Defaults to 2.0.
        """
        super().__init__(client_id, total_clients)

        self.target_label = target_label
        self.poison_fraction = poison_fraction
        self.boost_factor = boost_factor

        self._poison_training_data()

    def _poison_training_data(self):
        """
        Internal method to inject a square trigger into training images
        and flip their labels to the target label.

        Each local image whose label is NOT already the target label is, with
        probability ``poison_fraction``, stamped with the trigger and
        relabelled as ``target_label``. This matches the evaluation in
        ``create_backdoor_test_set`` (triggered non-target image -> target).
        Images that already carry the target label are kept clean.

        The poisoned samples are materialised into a ``TensorDataset`` that
        replaces ``self.trainloader``. An empty local shard is left as is.
        """
        clean_dataset = self.trainloader.dataset
        images_list = []
        labels_list = []

        for i in range(len(clean_dataset)):
            img, label = clean_dataset[i]

            # Poison only samples from other classes; uses the global NumPy RNG.
            if label != self.target_label and np.random.rand() < self.poison_fraction:
                img = add_square_trigger(img)
                label = self.target_label

            images_list.append(img)
            labels_list.append(label)

        if not images_list:
            # torch.stack cannot handle an empty list; nothing to poison anyway.
            return

        tensor_x = torch.stack(images_list)
        tensor_y = torch.tensor(labels_list)
        poisoned_dset = TensorDataset(tensor_x, tensor_y)
        # Same batch size as honest clients.
        self.trainloader = DataLoader(poisoned_dset, batch_size=BATCH_SIZE, shuffle=True)

    def fit(self, global_weights, epochs=1):
        """
        Performs local training on poisoned data and boosts weight updates.

        The reported weights are
        ``global + boost_factor * (local - global)`` for every entry, so the
        malicious update counts ``boost_factor`` times as much in the average.

        Args:
            global_weights (dict): The current global model weights.
            epochs (int): Number of local epochs.

        Returns:
            tuple: ``(boosted_weights, num_samples, metrics_dict)``. The
            metrics are those of the un-boosted local model, as returned by
            ``Client.fit``.
        """
        new_weights, num_samples, metrics = super().fit(global_weights, epochs)

        boosted_weights = {}
        for name in new_weights:
            update = new_weights[name] - global_weights[name]
            boosted_weights[name] = global_weights[name] + (update * self.boost_factor)

        return boosted_weights, num_samples, metrics

In [20]:
import csv
import random
import copy
import matplotlib.pyplot as plt

def get_average_weights(clients_updates, client_dataset_sizes):
    """
    Federated averaging: sample-count-weighted mean of client state_dicts.

    Each client's tensors are scaled by ``client_size / sum(all sizes)`` and
    summed key by key. The inputs are not modified.

    Args:
        clients_updates (list[dict]): One state_dict-like mapping per client;
            all must contain the keys of the first one.
        client_dataset_sizes (list[int]): Number of training samples behind
            each update, in the same order as ``clients_updates``.

    Returns:
        dict: Mapping of the same type and keys as ``clients_updates[0]``
        whose values are float32 tensors holding the weighted average.

    Raises:
        ValueError: If no updates are given, the two lists differ in length
            (``zip`` would silently drop entries and the weights would no
            longer sum to 1), or the total sample count is not positive.
    """
    if len(clients_updates) == 0:
        raise ValueError("clients_updates must contain at least one update")
    if len(clients_updates) != len(client_dataset_sizes):
        raise ValueError(
            f"got {len(clients_updates)} updates but {len(client_dataset_sizes)} dataset sizes"
        )

    total_data_points = sum(client_dataset_sizes)
    if total_data_points <= 0:
        raise ValueError("total number of client samples must be positive")

    # A shallow copy keeps the mapping type and any attributes attached to it
    # without cloning tensors that are replaced by zeros right away.
    avg_weights = copy.copy(clients_updates[0])
    for key in avg_weights.keys():
        avg_weights[key] = torch.zeros_like(avg_weights[key], dtype=torch.float32)

    for client_weights, client_size in zip(clients_updates, client_dataset_sizes):
        contribution_ratio = client_size / total_data_points
        for key in avg_weights.keys():
            avg_weights[key] += client_weights[key] * contribution_ratio
    return avg_weights

class Server:
    """
    Federated-learning coordinator.

    Holds the global model and the client population (client 0 is a
    ``MaliciousClient``, all others are honest ``Client`` instances), runs
    the round loop with federated averaging, logs per-round metrics to
    ``fl_logs.csv`` and plots them.
    """

    def __init__(self, num_clients=100, clients_per_round=30, rounds=40,seed=1001):
        """
        Args:
            num_clients (int): Size of the client population.
            clients_per_round (int): Clients sampled (without replacement)
                each round; must satisfy 1 <= clients_per_round <= num_clients.
            rounds (int): Number of federated rounds to run.
            seed (int): Stored on the instance for reference. It is not
                forwarded to the clients or to data partitioning here.

        Raises:
            ValueError: If the client counts are inconsistent. Checked before
                any client (and its dataset) is built, so a bad configuration
                fails immediately rather than when sampling in ``train``.
        """
        if num_clients < 1:
            raise ValueError(f"num_clients must be >= 1, got {num_clients}")
        if not 1 <= clients_per_round <= num_clients:
            raise ValueError(
                f"clients_per_round must be between 1 and num_clients ({num_clients}), "
                f"got {clients_per_round}"
            )

        self.num_clients = num_clients
        self.clients_per_round = clients_per_round
        self.rounds = rounds
        self.seed = seed
        self.global_model = get_model().to(DEVICE)

        print(f"Initializing {num_clients} Clients...")
        self.clients = []
        for i in range(num_clients):
            # Exactly one attacker: the client with index 0.
            if i == 0:
                self.clients.append(MaliciousClient(client_id=i, total_clients=num_clients))
            else:
                self.clients.append(Client(client_id=i, total_clients=num_clients))

        # One entry per completed round.
        self.history = {'loss': [], 'accuracy': [], 'asr': []}

    def train(self):
        """
        Run all federated rounds.

        Each round: sample clients, let each train locally from the current
        global weights, replace the global model with the sample-weighted
        average of their updates, and measure the backdoor ASR of the new
        global model. The logged loss/accuracy are sample-weighted averages
        of the metrics the selected clients report for their own locally
        trained models.

        Side effects: writes ``fl_logs.csv``, ``fl_iid_results.png`` and
        ``global_model.pth`` in the working directory.
        """
        print(f"--- Starting Federated Learning (IID) on {DEVICE} ---")

        # All clients share the same test set; look it up once.
        test_ds = self.clients[0].testloader.dataset

        with open('fl_logs.csv', mode='w', newline='', encoding="utf-8") as log_file:
            writer = csv.writer(log_file)
            writer.writerow(['Round', 'Average Loss', 'Average Accuracy', 'Backdoor ASR'])

            for round_idx in range(1, self.rounds + 1):
                selected_clients = random.sample(self.clients, self.clients_per_round)

                global_weights = self.global_model.state_dict()

                client_updates = []
                client_sizes = []
                round_losses = []
                round_accuracies = []

                for client in selected_clients:
                    local_weights, num_samples, metrics = client.fit(global_weights, epochs=1)

                    client_updates.append(local_weights)
                    client_sizes.append(num_samples)
                    # Pre-multiply so the round average is sample-weighted.
                    round_losses.append(metrics['loss'] * num_samples)
                    round_accuracies.append(metrics['accuracy'] * num_samples)

                new_global_weights = get_average_weights(client_updates, client_sizes)
                self.global_model.load_state_dict(new_global_weights)

                total_samples = sum(client_sizes)
                avg_loss = sum(round_losses) / total_samples
                avg_acc = sum(round_accuracies) / total_samples

                asr = evaluate_backdoor(self.global_model, test_ds)

                self.history['loss'].append(avg_loss)
                self.history['accuracy'].append(avg_acc)
                self.history['asr'].append(asr)

                print(f"Round {round_idx}/{self.rounds} - Loss: {avg_loss:.4f}, Acc: {avg_acc:.2%}, ASR: {asr:.2%}")
                writer.writerow([round_idx, avg_loss, avg_acc, asr])

        self.plot_metrics()
        torch.save(self.global_model.state_dict(), "global_model.pth")

    def plot_metrics(self):
        """
        Plot loss, accuracy and ASR per round and save ``fl_iid_results.png``.

        The x-axis is derived from the number of recorded rounds, so the plot
        also works when fewer than ``self.rounds`` rounds were completed.
        Does nothing (besides a notice) when no round has been recorded.
        """
        completed_rounds = len(self.history['accuracy'])
        if completed_rounds == 0:
            print("No rounds recorded; nothing to plot.")
            return

        rounds = range(1, completed_rounds + 1)
        final_acc = self.history['accuracy'][-1]
        final_asr = self.history['asr'][-1]

        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(18, 5))

        # Loss Plot
        ax1.plot(rounds, self.history['loss'], 'r-')
        ax1.set_title('Global Loss')
        ax1.set_xlabel('Round')
        ax1.set_ylabel('Loss')
        ax1.grid(True)

        # Accuracy Plot
        ax2.plot(rounds, self.history['accuracy'], 'b-')
        ax2.set_title(f'Global Accuracy (Final: {final_acc:.2%})')
        ax2.set_xlabel('Round')
        ax2.set_ylabel('Accuracy')
        ax2.grid(True)

        # ASR Plot
        ax3.plot(rounds, self.history['asr'], 'g-')
        ax3.set_title(f'Backdoor ASR (Final: {final_asr:.2%})')
        ax3.set_xlabel('Round')
        ax3.set_ylabel('Success Rate')
        ax3.grid(True)

        plt.tight_layout()
        plt.savefig('fl_iid_results.png')
        print(f"Plot saved. Final Accuracy: {final_acc:.2%}, Final ASR: {final_asr:.2%}")

In [21]:

def set_seed(seed):
    """
    Seed every random number source this notebook uses.

    Covers Python's ``random``, NumPy's global RNG and PyTorch; when CUDA is
    available, all CUDA devices are seeded explicitly as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def main():
    """
    Entry point: seed all RNGs, build the server (which creates every
    client) and run the full federated training loop.
    """
    # Experiment configuration, passed straight through to Server.
    config = {
        "num_clients": 100,
        "clients_per_round": 30,
        "rounds": 50,
        "seed": 1001,
    }

    print("--- Starting Grid Search ---")
    set_seed(config["seed"])

    server = Server(**config)
    server.train()

if __name__ == "__main__":
    main()

--- Starting Grid Search ---
Initializing 100 Clients...
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


 42%|████▏     | 4.16M/9.91M [00:04<00:06, 903kB/s] 





KeyboardInterrupt: 