In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import load_dataset, concatenate_datasets, Dataset
import copy
from tqdm import tqdm 
import numpy as np 
import random
# tqdm (imported above) is available for tracking training progress.

seed = 42
# Seed every RNG the notebook draws from, in the same order as before:
# torch (e.g. weight init and DataLoader shuffling), NumPy, then Python's `random`.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(seed)


In [2]:
# Class names in index order: a name's position in this list is its integer class id.
label_names = ['cat', 'dog', 'bird', 'fish', 'car', 'aircraft', 'flower', 'truck', 'parachute', 'mushroom']

# Reverse lookup (class name -> integer id), used to encode the string labels.
label_to_index = dict(zip(label_names, range(len(label_names))))


In [3]:
from torchvision import transforms

# Per-image preprocessing, applied lazily by apply_transform:
#   1. convert to grayscale and replicate it into 3 channels (so even colour images
#      are reduced to luminance),
#   2. resize to 256x256 (SimpleCNN's fc1 width of 32*64*64 assumes this size),
#   3. convert to a float tensor,
#   4. normalise per channel with the standard ImageNet mean/std.
_preprocess_steps = [
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(_preprocess_steps)

from PIL import Image

def _encode_label(label):
    """Map one label to its integer class id; integer ids pass through unchanged."""
    if isinstance(label, int):
        # Already encoded (e.g. a datasets ClassLabel column) - nothing to look up.
        return label
    try:
        return label_to_index[label]
    except KeyError:
        raise KeyError(
            f"Unknown label {label!r}; expected one of {list(label_to_index)}"
        ) from None


def apply_transform(example):
    """Batch transform for ``Dataset.with_transform``: preprocess images, encode labels.

    Args:
        example: a batch dict with an 'image' list (presumably PIL images - TODO
            confirm against the dataset schema) and a 'label' list holding class names
            from ``label_names`` or already-encoded integer class ids.

    Returns:
        A dict with 'image' -> list of tensors produced by ``transform`` and
        'label' -> 1-D int64 tensor of class ids.

    Raises:
        KeyError: if a string label is not in ``label_to_index``; the message lists
            the known class names.
    """
    transformed_images = [transform(img) for img in example['image']]
    labels = [_encode_label(label) for label in example['label']]
    return {
        'image': transformed_images,
        'label': torch.tensor(labels),
    }



In [4]:
def prepare_custom_dataloader(dataset, batch_size=16):
    """Build (train_loader, test_loader) for one client's dataset.

    ``dataset`` must expose 'train' and 'test' splits and support ``with_transform``;
    ``apply_transform`` is attached so images/labels are preprocessed on access.
    The train split is shuffled; the test split keeps its order.
    """
    transformed = dataset.with_transform(apply_transform)
    return tuple(
        DataLoader(transformed[split], batch_size=batch_size, shuffle=(split == 'train'))
        for split in ('train', 'test')
    )

In [5]:
class SimpleCNN(nn.Module):
    """Small two-conv-layer CNN classifier for square 3-channel images.

    Layout: [conv3x3(3->16) -> ReLU -> maxpool/2] -> [conv3x3(16->32) -> ReLU -> maxpool/2]
    -> flatten -> fc(->128) -> ReLU -> dropout(0.5) -> fc(->num_classes).

    Args:
        num_classes: number of output logits.
        input_size: side length in pixels of the square inputs. Defaults to 256,
            matching the notebook's ``Resize((256, 256))``. It fixes fc1's input width:
            the padded convs keep the size and the two 2x poolings give input_size // 4.
    """

    def __init__(self, num_classes, input_size=256):
        super(SimpleCNN, self).__init__()
        # Layer names/creation order are unchanged so state_dict keys (used by
        # federated averaging) and seeded initial weights stay identical.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        pooled_side = input_size // 4
        self.fc1 = nn.Linear(32 * pooled_side * pooled_side, 128)
        self.fc2 = nn.Linear(128, num_classes)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes)."""
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        # Flatten per sample. With the old view(-1, ...), a wrongly sized input could be
        # silently reshaped into a different batch size; now fc1 raises instead.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import copy

mu = 0.01  # FedProx proximal coefficient read by train_local_model

def fedprox_loss(local_model, global_model, base_loss, mu):
    """Add the FedProx proximal penalty to a task loss.

    Computes ``base_loss + (mu / 2) * sum_k ||w_local_k - w_global_k||^2``, pairing the
    two models' parameters in ``parameters()`` order.

    Args:
        local_model: model being trained; gradients flow into its parameters.
        global_model: anchor model. Its parameters are detached, so backward() never
            writes gradients into it.
        base_loss: scalar task loss (e.g. cross-entropy). It is NOT modified in place.
        mu: proximal coefficient; 0 disables the penalty.

    Returns:
        A new scalar tensor with the combined loss.
    """
    prox_term = sum(
        torch.sum((local_param - global_param.detach()) ** 2)
        for local_param, global_param in zip(local_model.parameters(), global_model.parameters())
    )
    # Out-of-place add: the old `loss = base_loss; loss += ...` mutated the caller's tensor.
    return base_loss + (mu / 2) * prox_term



In [7]:
# Local FedProx training for a single client.
def train_local_model(local_model, global_model, train_loader, epochs, lr=0.001, prox_mu=None):
    """Train ``local_model`` on one client's data with the FedProx objective.

    Args:
        local_model: the client's model, trained in place (the notebook passes a deep
            copy of ``global_model``).
        global_model: current global model. It is only passed to ``fedprox_loss`` as the
            proximal anchor and is never given to the optimizer.
        train_loader: iterable of batches; each is a dict with an 'image' tensor and an
            integer 'label' tensor.
        epochs: number of passes over ``train_loader``.
        lr: Adam learning rate (default is the previously hard-coded 0.001).
        prox_mu: proximal coefficient. ``None`` (default) uses the notebook-level ``mu``,
            exactly as before.

    Returns:
        ``local_model.state_dict()`` after training.

    The per-epoch loss printed is the mean over batches and includes the proximal penalty.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(local_model.parameters(), lr=lr)
    # Resolved at call time so edits to the global `mu` cell still take effect.
    effective_mu = mu if prox_mu is None else prox_mu
    local_model.train()

    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        for batch in train_loader:
            images = batch['image']
            labels = batch['label']
            optimizer.zero_grad()
            outputs = local_model(images)
            base_loss = criterion(outputs, labels)
            loss = fedprox_loss(local_model, global_model, base_loss, effective_mu)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        # Divide by the batches actually seen; avoids ZeroDivisionError on an empty loader.
        mean_loss = total_loss / max(num_batches, 1)
        print(f"Epoch [{epoch + 1}/{epochs}], Loss: {mean_loss:.4f}")
    return local_model.state_dict()

# Federated averaging of client weights into the global model.
def federated_avg(global_model, client_weights, client_sizes=None):
    """Load the (optionally size-weighted) average of client state_dicts into ``global_model``.

    Args:
        global_model: model whose parameters are overwritten in place.
        client_weights: non-empty list of state_dicts with identical keys (e.g. the return
            values of ``train_local_model``). They are not modified.
        client_sizes: optional per-client weights, typically the number of local training
            samples as in the FedAvg paper. ``None`` gives every client equal weight,
            which is the original plain mean.

    Returns:
        ``global_model`` (the same object).

    Raises:
        ValueError: if ``client_weights`` is empty, or ``client_sizes`` has the wrong
            length or does not sum to a positive value.
    """
    if not client_weights:
        raise ValueError("federated_avg needs at least one client state_dict")
    if client_sizes is None:
        client_sizes = [1] * len(client_weights)
    if len(client_sizes) != len(client_weights):
        raise ValueError("client_sizes must have one entry per client state_dict")
    total = float(sum(client_sizes))
    if total <= 0:
        raise ValueError("client_sizes must sum to a positive value")

    avg_weights = {}
    for key, reference in client_weights[0].items():
        # Weighted sum over clients. With equal sizes this is the plain mean used before.
        weighted_sum = sum(state[key] * size for state, size in zip(client_weights, client_sizes))
        # Cast back so integer buffers (e.g. BatchNorm counters) keep their dtype.
        avg_weights[key] = (weighted_sum / total).to(reference.dtype)

    global_model.load_state_dict(avg_weights)
    return global_model


In [8]:
# One Hugging Face Hub dataset per federated client, loaded in client order.
# Each is expected to provide 'train' and 'test' splits (prepare_custom_dataloader indexes both).
_client_repos = [
    "AnnantJain/client1_federated_dataset_modified",
    "AnnantJain/client2_federated_dataset_modified",
    "AnnantJain/client3_federated_dataset_modified",
    "AnnantJain/client4_federated_dataset_modified",
    "AnnantJain/client5_federated_dataset_modified",
]
dataset_1, dataset_2, dataset_3, dataset_4, dataset_5 = [load_dataset(repo) for repo in _client_repos]

In [9]:
# Build (train, test) loaders for every client, in client order.
_client_datasets = (dataset_1, dataset_2, dataset_3, dataset_4, dataset_5)
(
    (train_loader_1, test_loader_1),
    (train_loader_2, test_loader_2),
    (train_loader_3, test_loader_3),
    (train_loader_4, test_loader_4),
    (train_loader_5, test_loader_5),
) = [prepare_custom_dataloader(client_ds) for client_ds in _client_datasets]

In [11]:
# Peek at a single training batch to sanity-check tensor shapes and label encoding.
batch = next(iter(train_loader_1))
images, labels = batch['image'], batch['label']
print("Sample Image Tensor Shape:", images.shape)
print("Sample Label Tensor:", labels)

Sample Image Tensor Shape: torch.Size([16, 3, 256, 256])
Sample Label Tensor: tensor([6, 2, 5, 2, 5, 2, 5, 6, 5, 2, 5, 5, 5, 6, 2, 5])


In [12]:

# Size the classifier head from the label vocabulary so the two cannot drift apart
# (label_names has 10 entries, the same as the previous hard-coded value).
num_classes = len(label_names)
global_model = SimpleCNN(num_classes=num_classes)


In [13]:
def evaluate_model(model, test_loader):
    """Return the top-1 accuracy of ``model`` on ``test_loader``, in percent.

    Switches ``model`` to eval mode (disabling dropout) and leaves it there.
    Each batch must be a dict with an 'image' input tensor and an integer 'label' tensor.

    Returns:
        Accuracy as a float in [0, 100], or ``nan`` if the loader yields no samples
        (previously a ZeroDivisionError).
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in test_loader:
            labels = batch['label']
            predicted = model(batch['image']).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        return float('nan')
    return 100 * correct / total

In [15]:

num_rounds = 5
local_epochs = 2

# Client k trains on train_loader_k and is evaluated on test_loader_k.
client_train_loaders = [train_loader_1, train_loader_2, train_loader_3, train_loader_4, train_loader_5]
client_test_loaders = [test_loader_1, test_loader_2, test_loader_3, test_loader_4, test_loader_5]

for round_num in range(num_rounds):
    print(f"Round {round_num + 1}/{num_rounds}")

    # Every client starts this round from its own fresh copy of the global model.
    client_weights = []
    for train_loader in client_train_loaders:
        local_model = copy.deepcopy(global_model)
        local_weights = train_local_model(local_model, global_model, train_loader, local_epochs)
        client_weights.append(local_weights)

    global_model = federated_avg(global_model, client_weights)
    print(f"Completed round {round_num + 1}")

    # Score the aggregated global model on each client's held-out split.
    for client_id, test_loader in enumerate(client_test_loaders):
        accuracy = evaluate_model(global_model, test_loader)
        print(f"Client {client_id + 1} Test Accuracy: {accuracy:.2f}%")


Round 1/5
Epoch [1/2], Loss: 1.7720
Epoch [2/2], Loss: 0.9070
Epoch [1/2], Loss: 1.9075
Epoch [2/2], Loss: 1.3176
Epoch [1/2], Loss: 2.3431
Epoch [2/2], Loss: 1.5937
Epoch [1/2], Loss: 2.9072
Epoch [2/2], Loss: 1.6581
Epoch [1/2], Loss: 2.1891
Epoch [2/2], Loss: 1.6989
Completed round 1
Client 1 Test Accuracy: 75.56%
Client 2 Test Accuracy: 27.56%
Client 3 Test Accuracy: 30.28%
Client 4 Test Accuracy: 27.67%
Client 5 Test Accuracy: 35.62%
Round 2/5
Epoch [1/2], Loss: 1.0042
Epoch [2/2], Loss: 0.8594
Epoch [1/2], Loss: 1.3196
Epoch [2/2], Loss: 1.2947
Epoch [1/2], Loss: 1.6983
Epoch [2/2], Loss: 1.6872
Epoch [1/2], Loss: 1.6476
Epoch [2/2], Loss: 1.6377
Epoch [1/2], Loss: 1.7821
Epoch [2/2], Loss: 1.7457
Completed round 2
Client 1 Test Accuracy: 75.19%
Client 2 Test Accuracy: 47.78%
Client 3 Test Accuracy: 36.94%
Client 4 Test Accuracy: 36.67%
Client 5 Test Accuracy: 47.50%
Round 3/5
Epoch [1/2], Loss: 0.8556
Epoch [2/2], Loss: 0.7774
Epoch [1/2], Loss: 1.2316
Epoch [2/2], Loss: 1.2226
