In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import load_dataset, concatenate_datasets, Dataset
import copy
from tqdm import tqdm  # For tracking training progress


In [2]:
import torch
import numpy as np
import random

seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

In [3]:
# Example list of labels
label_names = ['cat', 'dog', 'bird', 'fish', 'car', 'aircraft', 'flower', 'truck', 'parachute', 'mushroom']

# Create a mapping from label names to indices
label_to_index = {label: idx for idx, label in enumerate(label_names)}


In [4]:
from torchvision import transforms

transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3), # Convert grayscale to 3 channels (RGB)
    transforms.Resize((256, 256)),  # Resize all images to 256x256
    transforms.ToTensor(),          # Convert images to PyTorch tensors
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize
])

from PIL import Image

def apply_transform(example):
    # Check if 'example['image']' is a list (batch of images)
    transformed_images = [transform(img) for img in example['image']]
    labels = [label_to_index[label] for label in example['label']]
    # Return the transformed images and the unchanged labels
    return {
        'image': transformed_images,  # Stack to create a single tensor
        'label': torch.tensor(labels)  # Convert labels to tensor
    }

# Apply the transformations to the dataset (train + test split for each 


In [5]:
def prepare_custom_dataloader(dataset, batch_size=16):
    # Apply the transformation to each sample in the dataset
    dataset = dataset.with_transform(apply_transform)
    
    # Create dataloaders
    train_loader = DataLoader(dataset['train'], batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(dataset['test'], batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

In [6]:
dataset_1 = load_dataset("AnnantJain/client1_federated_dataset_modified")
dataset_2 = load_dataset("AnnantJain/client2_federated_dataset_modified")
dataset_3 = load_dataset("AnnantJain/client3_federated_dataset_modified")
dataset_4 = load_dataset("AnnantJain/client4_federated_dataset_modified")
dataset_5 = load_dataset("AnnantJain/client5_federated_dataset_modified")

In [7]:
train_loader_1, test_loader_1 = prepare_custom_dataloader(dataset_1)
train_loader_2, test_loader_2 = prepare_custom_dataloader(dataset_2)
train_loader_3, test_loader_3 = prepare_custom_dataloader(dataset_3)
train_loader_4, test_loader_4 = prepare_custom_dataloader(dataset_4)
train_loader_5, test_loader_5 = prepare_custom_dataloader(dataset_5)

In [8]:
class CNN(nn.Module):
    def __init__(self, num_classes):
        super(CNN, self).__init__()
        # Define layers
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(32 * 64 * 64, 128)  # Adjust based on output size from conv layers
        self.fc2 = nn.Linear(128, num_classes)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = self.pool(nn.ReLU()(self.conv1(x)))  # Conv Layer 1
        x = self.pool((nn.ReLU()(self.conv2(x))))  # Conv Layer 2
        x = x.view(-1, 32 * 64 * 64)  # Flatten for fully connected layer
        x = nn.ReLU()(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x) 
        return x

In [9]:
def train_local(model, train_loader, criterion, optimizer, epochs=2):
    model.train()
    for epoch in range(epochs):
        for batch in train_loader:
            images = batch['image']  # This should be a tensor
            labels = batch['label'] 
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

def evaluate(model, test_loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for batch in test_loader:
            images = batch['image']  # This should be a tensor
            labels = batch['label'] 
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

# Function to average the weights of global layers (FedAvg)
def average_global_weights(global_model, client_models):
    global_state_dict = global_model.state_dict()
    for key in global_state_dict.keys():
        # Average the global layers (shared part)
        global_state_dict[key] = torch.mean(
            torch.stack([client_models[i].state_dict()[key] for i in range(len(client_models))]), dim=0
        )
    global_model.load_state_dict(global_state_dict)



def average_global_weights1(global_model, client_models, client_weights):
    global_state_dict = global_model.state_dict()
    for key in global_state_dict.keys():
        weighted_sum = torch.zeros_like(global_state_dict[key])
        total_weight = 0.0

        # Weighted sum of the model weights from clients based on their performance
        for i, client_model in enumerate(client_models):
            client_weight = client_weights[i]
            weighted_sum += client_weight * client_model.state_dict()[key]
            total_weight += client_weight

        global_state_dict[key] = weighted_sum / total_weight

    global_model.load_state_dict(global_state_dict)


# Average pruned weights across all clients for selective layers
def selective_average_global_weights(global_model, client_models, client_weights):
    global_state_dict = global_model.state_dict()
    for key in global_state_dict.keys():
        weighted_sum = torch.zeros_like(global_state_dict[key])
        total_weight = 0.0
        for i, client_model in enumerate(client_models):
            client_weight = client_weights[i]
            weighted_sum += client_weight * client_model.state_dict().get(key, 0)
            total_weight += client_weight
        global_state_dict[key] = weighted_sum / total_weight
    global_model.load_state_dict(global_state_dict, strict=False)


In [10]:
# 2. Prune parameters by retaining only the top fraction of important weights
def prune_model_weights(model, prune_fraction=0.3):
    pruned_state_dict = {}
    for name, param in model.state_dict().items():
        threshold = torch.quantile(param.abs(), prune_fraction)
        pruned_param = param * (param.abs() > threshold)  # Zero out less important weights
        pruned_state_dict[name] = pruned_param
    return pruned_state_dict

def distill_logits(logits, targets, temperature=2.0):
    return nn.functional.softmax(logits / temperature, dim=1)

# 1. Freeze selective layers after a few rounds
def freeze_layers(model, layers_to_freeze=['conv1', 'conv2']):
    for name, param in model.named_parameters():
        if any(layer in name for layer in layers_to_freeze):
            param.requires_grad = False


In [11]:
# Function to determine the cluster based on cosine similarity
def determine_cluster(client_model, cluster_models):
    max_similarity = -1
    best_cluster_id = 0
    client_state_dict = client_model.state_dict()
    
    for cluster_id, cluster_model in enumerate(cluster_models):
        cluster_state_dict = cluster_model.state_dict()
        similarity_sum = 0
        num_layers = 0
        
        for layer_name in client_state_dict:
            if layer_name in cluster_state_dict:
                client_layer = client_state_dict[layer_name].flatten()
                cluster_layer = cluster_state_dict[layer_name].flatten()
                
                similarity = cosine_similarity(client_layer.unsqueeze(0), cluster_layer.unsqueeze(0), dim=1)
                similarity_sum += similarity.item()
                num_layers += 1

        avg_similarity = similarity_sum / num_layers if num_layers > 0 else 0

        if avg_similarity > max_similarity:
            max_similarity = avg_similarity
            best_cluster_id = cluster_id

    return best_cluster_id


def average_cluster_weights(global_model, cluster_models, cluster_weights):
    global_state_dict = global_model.state_dict()
    for key in global_state_dict.keys():
        weighted_sum = torch.zeros_like(global_state_dict[key])
        total_weight = 0.0
        for i, client_model in enumerate(cluster_models):
            client_weight = cluster_weights[i]
            weighted_sum += client_weight * client_model.state_dict().get(key, 0)
            total_weight += client_weight
        global_state_dict[key] = weighted_sum / total_weight
    global_model.load_state_dict(global_state_dict, strict=False)

In [12]:
clients = [
    (train_loader_1, test_loader_1),
    (train_loader_2, test_loader_2),
    (train_loader_3, test_loader_3),
    (train_loader_4, test_loader_4),
    (train_loader_5, test_loader_5)
]

# Initialize the global model
global_model = CNN(num_classes=10)
num_clusters = 2

In [14]:
from torch.nn.functional import cosine_similarity
def federated_learning(clients, global_model, num_rounds=5, prune_fraction=0.3, freeze_after_round=3):
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    # Initialize clusters with copies of the global model
    cluster_models = [copy.deepcopy(global_model) for _ in range(num_clusters)]
    
    for round in range(num_rounds):
        print(f"Round {round+1}/{num_rounds}")

        # Step 1: Train personalized models locally
        local_models = []
        client_weights = []
        for i, (train_loader, test_loader) in enumerate(clients):
            print(f"Client {i+1} local training...")
            local_model = copy.deepcopy(global_model)  # Clone the global model
            
            optimizer = optim.Adam(local_model.parameters(), lr=0.001)

            # Train on each client's data
            train_local(local_model, train_loader, criterion, optimizer)
            local_models.append(local_model)

            acc = evaluate(local_model, test_loader)
            client_weights.append(acc)

        # Step 2: Average global layers across clients (FedAvg)
        cluster_clients = [[] for _ in range(num_clusters)]
        cluster_client_weights = [[] for _ in range(num_clusters)]

        for i, local_model in enumerate(local_models):
            cluster_id = determine_cluster(local_model, cluster_models)
            cluster_clients[cluster_id].append(local_model)
            cluster_client_weights[cluster_id].append(client_weights[i])

        # Update each cluster model by averaging weights
        for cluster_id in range(num_clusters):
            if cluster_clients[cluster_id]:
                print(f"Updating Cluster {cluster_id+1} model with {len(cluster_clients[cluster_id])} clients.")
                average_cluster_weights(cluster_models[cluster_id], cluster_clients[cluster_id], cluster_client_weights[cluster_id])


        # Step 3: Evaluate each personalized model after federated update
        for i, (train_loader, test_loader) in enumerate(clients):
            acc = evaluate(local_models[i], test_loader)
            print(f"Client {i+1} Accuracy: {acc * 100:.2f}%")

# Run federated learning with client-specific noise adaptation
federated_learning(clients, global_model)

Round 1/5
Client 1 local training...
Client 2 local training...
Client 3 local training...
Client 4 local training...
Client 5 local training...
Updating Cluster 1 model with 5 clients.
Client 1 Accuracy: 77.04%
Client 2 Accuracy: 72.89%
Client 3 Accuracy: 55.83%
Client 4 Accuracy: 55.33%
Client 5 Accuracy: 64.38%
Round 2/5
Client 1 local training...
Client 2 local training...
Client 3 local training...
Client 4 local training...
Client 5 local training...
Updating Cluster 1 model with 5 clients.
Client 1 Accuracy: 80.37%
Client 2 Accuracy: 71.33%
Client 3 Accuracy: 49.44%
Client 4 Accuracy: 50.00%
Client 5 Accuracy: 57.71%
Round 3/5
Client 1 local training...
Client 2 local training...
Client 3 local training...
Client 4 local training...
Client 5 local training...
Updating Cluster 1 model with 5 clients.
Client 1 Accuracy: 81.85%
Client 2 Accuracy: 76.67%
Client 3 Accuracy: 54.44%
Client 4 Accuracy: 55.67%
Client 5 Accuracy: 65.00%
Round 4/5
Client 1 local training...
Client 2 local 