In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import load_dataset, concatenate_datasets, Dataset
import copy
from tqdm import tqdm  # For tracking training progress


In [2]:
import torch
import numpy as np
import random

# Seed every random number generator used by the notebook (PyTorch, NumPy and
# the standard library) with the same value to make runs repeatable.
seed = 42
for _seed_rng in (torch.manual_seed, np.random.seed, random.seed):
    _seed_rng(seed)

In [3]:
# Class names; a name's position in this list is its integer class index.
label_names = ['cat', 'dog', 'bird', 'fish', 'car', 'aircraft', 'flower', 'truck', 'parachute', 'mushroom']

# Lookup table: label name -> class index (0 .. len(label_names) - 1).
label_to_index = dict(zip(label_names, range(len(label_names))))


In [4]:
from torchvision import transforms

# Per-channel statistics used by Normalize; these are the values commonly
# quoted for ImageNet-pretrained models.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

_preprocessing_steps = [
    # Produces a 3-channel grayscale image for every input.
    # NOTE(review): this also removes colour from RGB inputs -- confirm that
    # is intended and not just a way to guarantee three channels.
    transforms.Grayscale(num_output_channels=3),
    # Bring every image to the 256x256 resolution the model is built for.
    transforms.Resize((256, 256)),
    # PIL image -> float tensor.
    transforms.ToTensor(),
    # Channel-wise standardisation.
    transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
]

# Preprocessing pipeline applied to each image by apply_transform.
transform = transforms.Compose(_preprocessing_steps)

from PIL import Image

def apply_transform(example):
    """Preprocess a batch of examples: transform images and encode labels.

    Intended to be installed with ``with_transform`` (see
    ``prepare_custom_dataloader``), so ``example`` is treated as a batch.

    Args:
        example: Mapping with an ``'image'`` sequence (each item is passed to
            ``transform``; presumably PIL images -- confirm against the
            dataset) and a ``'label'`` sequence of label names that must be
            keys of ``label_to_index``.

    Returns:
        dict: ``'image'`` is a list with one transformed tensor per input
        image (the images are not stacked here), ``'label'`` is a tensor of
        the corresponding class indices.

    Raises:
        KeyError: If a label name is not present in ``label_to_index``.
    """
    image_tensors = list(map(transform, example['image']))
    label_ids = torch.tensor([label_to_index[name] for name in example['label']])
    return {'image': image_tensors, 'label': label_ids}


In [5]:
def prepare_custom_dataloader(dataset, batch_size=16):
    """Build train and test DataLoaders for one client's dataset.

    Args:
        dataset: Object exposing ``with_transform`` whose result can be
            indexed with ``'train'`` and ``'test'``.
        batch_size: Number of examples per batch for both loaders.

    Returns:
        tuple: ``(train_loader, test_loader)``. The training loader shuffles
        its data; the test loader keeps the dataset order.
    """
    transformed = dataset.with_transform(apply_transform)
    loaders = (
        DataLoader(transformed['train'], batch_size=batch_size, shuffle=True),
        DataLoader(transformed['test'], batch_size=batch_size, shuffle=False),
    )
    return loaders

In [6]:
# One dataset repository per federated client, loaded in client order.
_client_repositories = (
    "AnnantJain/client1_federated_dataset_modified",
    "AnnantJain/client2_federated_dataset_modified",
    "AnnantJain/client3_federated_dataset_modified",
    "AnnantJain/client4_federated_dataset_modified",
    "AnnantJain/client5_federated_dataset_modified",
)
dataset_1, dataset_2, dataset_3, dataset_4, dataset_5 = [
    load_dataset(repository) for repository in _client_repositories
]

In [7]:
# Display client 1's dataset summary (splits, features and row counts).
dataset_1

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 1530
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 270
    })
})

In [8]:
# Build a (train, test) DataLoader pair for every client, in client order.
(
    (train_loader_1, test_loader_1),
    (train_loader_2, test_loader_2),
    (train_loader_3, test_loader_3),
    (train_loader_4, test_loader_4),
    (train_loader_5, test_loader_5),
) = map(
    prepare_custom_dataloader,
    (dataset_1, dataset_2, dataset_3, dataset_4, dataset_5),
)

In [9]:
# Pull a single batch from client 1's training loader and show, one per line,
# its container type, its number of entries and its full content.
batch = next(iter(train_loader_1))
print(type(batch), len(batch), batch, sep="\n")


<class 'dict'>
2
{'image': tensor([[[[ 1.5810,  1.4783,  1.3242,  ..., -0.8507, -0.7822, -0.7308],
          [ 1.5639,  1.4612,  1.3070,  ..., -0.8678, -0.8335, -0.7993],
          [ 1.5468,  1.4440,  1.2899,  ..., -0.8507, -0.8678, -0.8678],
          ...,
          [-0.2171, -0.2171, -0.2171,  ..., -0.5082, -0.4397, -0.4054],
          [-0.2171, -0.2171, -0.2171,  ..., -0.6623, -0.5596, -0.4911],
          [-0.2171, -0.2171, -0.2171,  ..., -0.7822, -0.6623, -0.5767]],

         [[ 1.7458,  1.6408,  1.4832,  ..., -0.7402, -0.6702, -0.6176],
          [ 1.7283,  1.6232,  1.4657,  ..., -0.7577, -0.7227, -0.6877],
          [ 1.7108,  1.6057,  1.4482,  ..., -0.7402, -0.7577, -0.7577],
          ...,
          [-0.0924, -0.0924, -0.0924,  ..., -0.3901, -0.3200, -0.2850],
          [-0.0924, -0.0924, -0.0924,  ..., -0.5476, -0.4426, -0.3725],
          [-0.0924, -0.0924, -0.0924,  ..., -0.6702, -0.5476, -0.4601]],

         [[ 1.9603,  1.8557,  1.6988,  ..., -0.5147, -0.4450, -0.3927],
   

In [10]:
class SimpleCNN(nn.Module):
    """Small CNN classifier: two conv + max-pool stages and two linear layers.

    Args:
        num_classes: Number of output logits.
        input_size: Height and width of the (square) input images. Defaults
            to 256, which yields the same ``32 * 64 * 64`` flattened feature
            size that was previously hard-coded.

    Input to ``forward`` is a batch of 3-channel images of shape
    ``(N, 3, input_size, input_size)``; the output has shape
    ``(N, num_classes)`` and contains raw logits (no softmax is applied).
    """

    def __init__(self, num_classes, input_size=256):
        super().__init__()
        # Layers are created in a fixed order so that seeded initialisation
        # stays reproducible.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        # The 3x3/padding-1 convolutions keep the spatial size; each of the
        # two 2x2/stride-2 poolings halves it (rounding down).
        pooled_size = (input_size // 2) // 2
        self.fc1 = nn.Linear(32 * pooled_size * pooled_size, 128)
        self.fc2 = nn.Linear(128, num_classes)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Return class logits for a batch of images."""
        x = self.pool(torch.relu(self.conv1(x)))  # Conv stage 1
        x = self.pool(torch.relu(self.conv2(x)))  # Conv stage 2
        # Flatten per sample. Unlike ``view(-1, features)``, this never moves
        # surplus features into the batch dimension: a wrongly sized input
        # fails in fc1 with a shape error instead of silently changing N.
        x = torch.flatten(x, start_dim=1)
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

In [16]:
# Function to update local model on each client
def train_local_model(model, train_loader, epochs):
    """Train ``model`` in place on one client's data and return its weights.

    Args:
        model: The ``nn.Module`` to optimise; it is switched to training mode.
        train_loader: Iterable of batches, each a mapping with an ``'image'``
            entry (model input) and a ``'label'`` entry (class indices).
        epochs: Number of full passes over ``train_loader``.

    Returns:
        The ``state_dict`` of the trained model.
    """
    loss_fn = nn.CrossEntropyLoss()
    opt = optim.Adam(model.parameters(), lr=0.001)
    model.train()

    for _ in range(epochs):
        for batch in train_loader:
            inputs, targets = batch['image'], batch['label']
            opt.zero_grad()
            batch_loss = loss_fn(model(inputs), targets)
            batch_loss.backward()
            opt.step()

    return model.state_dict()

def krum_aggregation(client_weights, num_byzantine):
    """
    Krum aggregation for Byzantine-resilient federated learning.

    Each client is scored by the sum of squared Euclidean distances to its
    ``num_clients - num_byzantine - 2`` nearest *other* clients; the weights
    of the client with the lowest score are returned.

    Args:
    client_weights (list): List of model weights (state dicts with identical
        keys) from each client.
    num_byzantine (int): Number of expected Byzantine (malicious) clients.

    Returns:
    dict: The selected model weights after Krum aggregation (the same object
        that was passed in for the winning client, not a copy).

    Raises:
    ValueError: If ``num_clients - num_byzantine - 2 < 1``, i.e. there are too
        few clients to compute a Krum score.
    """
    num_clients = len(client_weights)
    num_neighbors = num_clients - num_byzantine - 2
    if num_neighbors < 1:
        raise ValueError(
            "Krum needs num_clients - num_byzantine - 2 >= 1 "
            f"(got num_clients={num_clients}, num_byzantine={num_byzantine})"
        )

    distances = np.zeros((num_clients, num_clients))

    # Squared Euclidean distance between every pair of clients, summed over
    # all entries of the state dict. The matrix is symmetric, so each pair is
    # computed once.
    for i in range(num_clients):
        for j in range(i + 1, num_clients):
            distance = sum(
                torch.sum((client_weights[i][key] - client_weights[j][key]) ** 2).item()
                for key in client_weights[i].keys()
            )
            distances[i][j] = distances[j][i] = distance

    scores = []
    for i in range(num_clients):
        # Drop the zero self-distance first; otherwise it would take one of
        # the neighbour slots and the score would cover one client too few.
        neighbor_distances = np.sort(np.delete(distances[i], i))
        scores.append(np.sum(neighbor_distances[:num_neighbors]))

    # Select the client with the minimum score
    krum_idx = int(np.argmin(scores))
    return client_weights[krum_idx]

import torch
import copy

# Function to calculate Euclidean distance between model weights
def compute_distance(w1, w2):
    """Return the squared Euclidean distance between two sets of weights.

    Args:
        w1: Mapping from parameter name to tensor (e.g. a state dict).
        w2: Mapping with at least the keys of ``w1`` and matching shapes.

    Returns:
        float: Sum over all keys of ``w1`` of the squared element-wise
        differences; ``0.0`` when ``w1`` has no entries.
    """
    distance = 0.0
    for key in w1.keys():
        # Sum of squares directly (no sqrt-then-square round trip); this also
        # works for integer tensors, which torch.norm rejects.
        distance += float(torch.sum((w1[key] - w2[key]) ** 2).item())
    return distance

# Multi-Krum aggregation
def multi_krum(global_model, client_weights, m, num_byzantine=None):
    """Aggregate client updates with Multi-Krum and load them into the model.

    Every client is scored by the sum of its smallest distances (as returned
    by ``compute_distance``) to other clients. The ``m`` clients with the
    lowest scores are selected and their weights are averaged.

    Args:
        global_model: Model whose ``load_state_dict`` receives the average.
        client_weights (list): State dicts (identical keys) from each client.
        m (int): Number of client updates to select and average; must satisfy
            ``1 <= m <= len(client_weights)``.
        num_byzantine (int, optional): Expected number of Byzantine clients.
            When given, each score uses the
            ``len(client_weights) - num_byzantine - 2`` nearest clients, as in
            Krum. When ``None`` (default), the ``m`` nearest clients are used,
            which is the previous behaviour of this function.

    Returns:
        ``global_model`` after loading the averaged weights (updated in place).

    Raises:
        ValueError: If ``m`` is out of range, or ``num_byzantine`` leaves no
            neighbour to score against.
    """
    num_clients = len(client_weights)
    if not 1 <= m <= num_clients:
        raise ValueError(
            f"m must be between 1 and the number of clients ({num_clients}), got {m}"
        )
    if num_byzantine is None:
        num_neighbors = m
    else:
        num_neighbors = num_clients - num_byzantine - 2
        if num_neighbors < 1:
            raise ValueError(
                "Multi-Krum needs num_clients - num_byzantine - 2 >= 1 "
                f"(got num_clients={num_clients}, num_byzantine={num_byzantine})"
            )

    # Step 1: pairwise distances. The distance is symmetric, so each pair is
    # computed only once.
    distances = [[0.0] * num_clients for _ in range(num_clients)]
    for i in range(num_clients):
        for j in range(i + 1, num_clients):
            dist = compute_distance(client_weights[i], client_weights[j])
            distances[i][j] = dist
            distances[j][i] = dist

    # Step 2: score each client by its closest other clients (self excluded).
    scores = []
    for i in range(num_clients):
        others = sorted(distances[i][j] for j in range(num_clients) if j != i)
        scores.append((i, sum(others[:num_neighbors])))

    # Step 3: keep the m clients with the lowest scores (stable on ties).
    ranked = sorted(scores, key=lambda item: item[1])
    selected_indices = [index for index, _ in ranked[:m]]

    # Step 4: average the selected clients' weights. The deep copy keeps the
    # first selected client's tensors from being modified by the in-place add.
    avg_weights = copy.deepcopy(client_weights[selected_indices[0]])
    for key in avg_weights.keys():
        for index in selected_indices[1:]:
            avg_weights[key] += client_weights[index][key]
        avg_weights[key] = avg_weights[key] / m

    # Update the global model with the averaged weights
    global_model.load_state_dict(avg_weights)
    return global_model



In [12]:
# Sanity check: take only the first batch from client 5's training loader.
for batch in train_loader_5:
    images, labels = batch['image'], batch['label']
    break  # one batch is enough

# Report the shape of the image batch and the label tensor that came with it.
print("Sample Image Tensor Shape:", images.shape)
print("Sample Label Tensor:", labels)

Sample Image Tensor Shape: torch.Size([16, 3, 256, 256])
Sample Label Tensor: tensor([6, 6, 7, 3, 7, 8, 6, 6, 1, 9, 1, 6, 7, 7, 1, 8])


In [13]:
# Initialize the global model shared across clients. The number of output
# classes is derived from label_names so the classifier head always matches
# the indices produced by label_to_index.
num_classes = len(label_names)
global_model = SimpleCNN(num_classes=num_classes)


In [14]:
def evaluate_model(model, test_loader):
    """Compute the classification accuracy of ``model`` on ``test_loader``.

    The model is switched to evaluation mode and gradients are disabled for
    the duration of the evaluation.

    Args:
        model: Module mapping a batch of images to per-class scores of shape
            ``(batch, num_classes)``.
        test_loader: Iterable of batches, each a mapping with an ``'image'``
            entry (model input) and a ``'label'`` entry (class indices).

    Returns:
        float: Accuracy in percent (0-100). Returns ``0.0`` when the loader
        yields no samples instead of dividing by zero.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in test_loader:
            images = batch['image']
            labels = batch['label']
            # Predicted class = index of the highest score for each sample.
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        return 0.0
    return 100 * correct / total

In [17]:
# --- Federated training configuration ---
num_rounds = 5      # communication rounds
m = 3               # client updates that Multi-Krum selects and averages
local_epochs = 2    # local training epochs per client in each round
num_byzantine = 1   # expected Byzantine clients (not used by the Multi-Krum call below)

# The per-client loaders do not change between rounds, so list them once.
train_loaders = [train_loader_1, train_loader_2, train_loader_3, train_loader_4, train_loader_5]
test_loaders = [test_loader_1, test_loader_2, test_loader_3, test_loader_4, test_loader_5]

# Federated training with Multi-Krum aggregation.
for round_num in range(num_rounds):
    print(f"Round {round_num + 1}/{num_rounds}")

    # Every client trains its own copy of the current global model and
    # reports the resulting weights.
    client_weights = [
        train_local_model(copy.deepcopy(global_model), train_loader, epochs=local_epochs)
        for train_loader in train_loaders
    ]

    # Robust aggregation: average the m best-scoring client updates and load
    # the result into the global model.
    global_model = multi_krum(global_model, client_weights, m)

    print(f"Completed round {round_num + 1}")

    # Evaluate the new global model on each client's test data.
    for client_id, test_loader in enumerate(test_loaders):
        accuracy = evaluate_model(global_model, test_loader)
        print(f"Client {client_id + 1} Test Accuracy: {accuracy:.2f}%")


Round 1/5
Completed round 1
Client 1 Test Accuracy: 52.22%
Client 2 Test Accuracy: 39.56%
Client 3 Test Accuracy: 1.39%
Client 4 Test Accuracy: 41.67%
Client 5 Test Accuracy: 47.71%
Round 2/5
Completed round 2
Client 1 Test Accuracy: 71.11%
Client 2 Test Accuracy: 41.33%
Client 3 Test Accuracy: 27.22%
Client 4 Test Accuracy: 39.67%
Client 5 Test Accuracy: 38.96%
Round 3/5
Completed round 3
Client 1 Test Accuracy: 75.19%
Client 2 Test Accuracy: 47.56%
Client 3 Test Accuracy: 33.61%
Client 4 Test Accuracy: 37.00%
Client 5 Test Accuracy: 32.71%
Round 4/5
Completed round 4
Client 1 Test Accuracy: 74.07%
Client 2 Test Accuracy: 52.67%
Client 3 Test Accuracy: 39.72%
Client 4 Test Accuracy: 37.00%
Client 5 Test Accuracy: 36.46%
Round 5/5


: 

In [None]:
# Test the global model after federated learning on the union of every
# client's test split. `client1_combined` was never defined in this notebook,
# so the evaluation set is built explicitly here.
# NOTE(review): assumes all client test splits share the same features so
# they can be concatenated -- confirm this is the intended evaluation set.
global_test_set = concatenate_datasets(
    [
        client_dataset['test']
        for client_dataset in (dataset_1, dataset_2, dataset_3, dataset_4, dataset_5)
    ]
).with_transform(apply_transform)
test_loader_global = DataLoader(global_test_set, batch_size=32, shuffle=False)
accuracy = evaluate_model(global_model, test_loader_global)
print(f"Global Model Accuracy: {accuracy:.2f}%")
