In [14]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from datasets import load_dataset, concatenate_datasets, Dataset
import copy
from tqdm import tqdm  # For tracking training progress


In [15]:
import torch
import numpy as np
import random

seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

In [16]:
# Example list of labels
label_names = ['cat', 'dog', 'bird', 'fish', 'car', 'aircraft', 'flower', 'truck', 'parachute', 'mushroom']

# Create a mapping from label names to indices
label_to_index = {label: idx for idx, label in enumerate(label_names)}


In [17]:
from torchvision import transforms

transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3), # Convert grayscale to 3 channels (RGB)
    transforms.Resize((256, 256)),  # Resize all images to 256x256
    transforms.ToTensor(),          # Convert images to PyTorch tensors
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize
])

from PIL import Image

def apply_transform(example):
    # Check if 'example['image']' is a list (batch of images)
    transformed_images = [transform(img) for img in example['image']]
    labels = [label_to_index[label] for label in example['label']]
    # Return the transformed images and the unchanged labels
    return {
        'image': transformed_images,  # Stack to create a single tensor
        'label': torch.tensor(labels)  # Convert labels to tensor
    }

# Apply the transformations to the dataset (train + test split for each 


In [18]:
def prepare_custom_dataloader(dataset, batch_size=16):
    # Apply the transformation to each sample in the dataset
    dataset = dataset.with_transform(apply_transform)
    
    # Create dataloaders
    train_loader = DataLoader(dataset['train'], batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(dataset['test'], batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

In [19]:
dataset_1 = load_dataset("AnnantJain/client1_federated_dataset_modified")
dataset_2 = load_dataset("AnnantJain/client2_federated_dataset_modified")
dataset_3 = load_dataset("AnnantJain/client3_federated_dataset_modified")
dataset_4 = load_dataset("AnnantJain/client4_federated_dataset_modified")
dataset_5 = load_dataset("AnnantJain/client5_federated_dataset_modified")

In [20]:
dataset_1

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 1530
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 270
    })
})

In [21]:
train_loader_1, test_loader_1 = prepare_custom_dataloader(dataset_1)
train_loader_2, test_loader_2 = prepare_custom_dataloader(dataset_2)
train_loader_3, test_loader_3 = prepare_custom_dataloader(dataset_3)
train_loader_4, test_loader_4 = prepare_custom_dataloader(dataset_4)
train_loader_5, test_loader_5 = prepare_custom_dataloader(dataset_5)

In [22]:
# Inspect the output of the DataLoader
batch = next(iter(train_loader_1))
print(type(batch))
print(len(batch))
print(batch)  # Print to inspect the content


<class 'dict'>
2
{'image': tensor([[[[ 1.5810,  1.4783,  1.3242,  ..., -0.8507, -0.7822, -0.7308],
          [ 1.5639,  1.4612,  1.3070,  ..., -0.8678, -0.8335, -0.7993],
          [ 1.5468,  1.4440,  1.2899,  ..., -0.8507, -0.8678, -0.8678],
          ...,
          [-0.2171, -0.2171, -0.2171,  ..., -0.5082, -0.4397, -0.4054],
          [-0.2171, -0.2171, -0.2171,  ..., -0.6623, -0.5596, -0.4911],
          [-0.2171, -0.2171, -0.2171,  ..., -0.7822, -0.6623, -0.5767]],

         [[ 1.7458,  1.6408,  1.4832,  ..., -0.7402, -0.6702, -0.6176],
          [ 1.7283,  1.6232,  1.4657,  ..., -0.7577, -0.7227, -0.6877],
          [ 1.7108,  1.6057,  1.4482,  ..., -0.7402, -0.7577, -0.7577],
          ...,
          [-0.0924, -0.0924, -0.0924,  ..., -0.3901, -0.3200, -0.2850],
          [-0.0924, -0.0924, -0.0924,  ..., -0.5476, -0.4426, -0.3725],
          [-0.0924, -0.0924, -0.0924,  ..., -0.6702, -0.5476, -0.4601]],

         [[ 1.9603,  1.8557,  1.6988,  ..., -0.5147, -0.4450, -0.3927],
   

In [23]:
class CNN(nn.Module):
    def __init__(self, num_classes):
        super(CNN, self).__init__()
        # Define layers
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(32 * 64 * 64, 128)  # Adjust based on output size from conv layers
        self.fc2 = nn.Linear(128, num_classes)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = self.pool(nn.ReLU()(self.conv1(x)))  # Conv Layer 1
        x = self.pool((nn.ReLU()(self.conv2(x))))  # Conv Layer 2
        x = x.view(-1, 32 * 64 * 64)  # Flatten for fully connected layer
        x = nn.ReLU()(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x) 
        return x

In [None]:
import torch.nn.functional as F

def federated_divergence_aware_loss(local_output, global_output, target, noise_ratio, alpha=0.1, beta=0.05, gamma=0.1):
    local_probs = F.softmax(local_output, dim=1)
    global_probs = F.softmax(global_output, dim=1)

    # Cross-Entropy Loss
    ce_loss = F.cross_entropy(local_output, target)

    # KL Divergence for Federated Alignment
    kl_div_loss = F.kl_div(local_probs.log(), global_probs, reduction='batchmean')

    # Entropy Regularization
    entropy_loss = -torch.mean(torch.sum(local_probs * local_probs.log(), dim=1))

    # Confidence Calibration
    confidence_scores = torch.max(local_probs, dim=1)[0]
    confidence_loss = torch.mean((1 - confidence_scores) * torch.norm(local_probs - global_probs, dim=1))

    # Total Loss
    total_loss = (1 - noise_ratio) * ce_loss + alpha * kl_div_loss + beta * entropy_loss + gamma * confidence_loss
    return total_loss



def train_local_with_fdal(local_model, global_model, train_loader, optimizer, noise_ratio, alpha=0.1, beta=0.05, gamma=0.1):
    local_model.train()
    global_model.eval()
    
    for epoch in range(5):  # 5 local epochs
        total_loss = 0.0
        for batch in train_loader:
            data, target = batch['image'], batch['label']

            # Forward pass
            local_output = local_model(data)
            with torch.no_grad():
                global_output = global_model(data)

            # Compute FDAL loss
            loss = federated_divergence_aware_loss(local_output, global_output, target, noise_ratio, alpha, beta, gamma)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
        print(f"Local training epoch loss: {total_loss / len(train_loader):.4f}")

def evaluate(model, test_loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for batch in test_loader:
            images = batch['image']  # This should be a tensor
            labels = batch['label'] 
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

def evaluate_global_model(global_model, clients):
    """
    Evaluate the global model on all clients.
    Args:
        global_model: The global CNN model.
        clients: List of client dataloaders [(train_loader, test_loader), ...].
    """
    global_model.eval()
    for i, (_, test_loader) in enumerate(clients):
        correct = 0
        total = 0
        with torch.no_grad():
            for data, target in test_loader:
                output = global_model(data)
                preds = output.argmax(dim=1)
                correct += (preds == target).sum().item()
                total += target.size(0)
        acc = correct / total
        print(f"Client {i + 1} Accuracy: {acc * 100:.2f}%")


# Average pruned weights across all clients for selective layers
def selective_average_global_weights(global_model, client_models, client_weights):
    global_state_dict = global_model.state_dict()
    for key in global_state_dict.keys():
        weighted_sum = torch.zeros_like(global_state_dict[key])
        total_weight = 0.0
        for i, client_model in enumerate(client_models):
            client_weight = client_weights[i]
            weighted_sum += client_weight * client_model.state_dict().get(key, 0)
            total_weight += client_weight
        global_state_dict[key] = weighted_sum / total_weight
    global_model.load_state_dict(global_state_dict, strict=False)
    return global_model


In [29]:
# 2. Prune parameters by retaining only the top fraction of important weights
def prune_model_weights(model, prune_fraction=0.3):
    pruned_state_dict = {}
    for name, param in model.state_dict().items():
        threshold = torch.quantile(param.abs(), prune_fraction)
        pruned_param = param * (param.abs() > threshold)  # Zero out less important weights
        pruned_state_dict[name] = pruned_param
    return pruned_state_dict


# 1. Freeze selective layers after a few rounds
def freeze_layers(model, layers_to_freeze=['conv1', 'conv2']):
    for name, param in model.named_parameters():
        if any(layer in name for layer in layers_to_freeze):
            param.requires_grad = False


In [30]:
clients = [
    (train_loader_1, test_loader_1),
    (train_loader_2, test_loader_2),
    (train_loader_3, test_loader_3),
    (train_loader_4, test_loader_4),
    (train_loader_5, test_loader_5)
]

# Initialize the global model
global_model = CNN(num_classes=10)

In [None]:
def save_models(local_models, round_number):
    for i, model in enumerate(local_models):
        torch.save(model.state_dict(), f"client_{i+1}_round_{round_number}.pth")
        
def federated_learning(clients, global_model, num_rounds=5, prune_fraction=0.3, freeze_after_round=3, alpha=0.1, beta=0.05, gamma=0.1, noise_ratios=None):
    if noise_ratios is None:
        noise_ratios = [0.65, 0.65, 0.65, 0.65, 0.65]  # Default noise ratios 

    for round in range(num_rounds):
        print(f"--- Round {round+1}/{num_rounds} ---")

        # Step 1: Train personalized models locally
        local_models = []
        client_weights = []
        for i, (train_loader, test_loader) in enumerate(clients):
            print(f"Client {i+1} local training...")
            local_model = copy.deepcopy(global_model)  # Clone the global model
            if round >= freeze_after_round:
                freeze_layers(local_model)
            
            optimizer = optim.Adam(local_model.parameters(), lr=0.001)

            # Train on each client's data with FDAL
            train_local_with_fdal(local_model, global_model, train_loader, optimizer, noise_ratios[i], alpha, beta, gamma)

            # Prune model weights
            prune_model_weights(local_model, prune_fraction=prune_fraction)

            # Evaluate local model
            acc = evaluate(local_model, test_loader)
            print(f"Client {i+1} Accuracy: {acc * 100:.2f}%")
            
            local_models.append(local_model)
            client_weights.append(acc)  # Use accuracy as weight

        # Step 2: Aggregate global layers across clients (FedAvg)
        print("Aggregating global model...")
        selective_average_global_weights(global_model, local_models, client_weights)

        # Save personalized models
        #save_models(local_models, round)

        # Step 3: Evaluate personalized models on test data
        for i, (_, test_loader) in enumerate(clients):
            acc = evaluate(local_models[i], test_loader)
            print(f"Client {i+1} Final Accuracy: {acc * 100:.2f}%")

# Run federated learning with client-specific noise adaptation
federated_learning(clients, global_model)



--- Round 1/5 ---
Client 1 local training...
Local training epoch loss: 0.7091
Local training epoch loss: 0.4523
Local training epoch loss: 0.4132
Local training epoch loss: 0.3772
Local training epoch loss: 0.3449
Client 1 Accuracy: 85.56%
Client 2 local training...
Local training epoch loss: 0.8143
Local training epoch loss: 0.5159
Local training epoch loss: 0.4287
Local training epoch loss: 0.3694
Local training epoch loss: 0.3397
Client 2 Accuracy: 78.22%
Client 3 local training...
Local training epoch loss: 0.9128
Local training epoch loss: 0.6853
Local training epoch loss: 0.6070
Local training epoch loss: 0.5348
Local training epoch loss: 0.4703
Client 3 Accuracy: 62.22%
Client 4 local training...
Local training epoch loss: 0.9540
Local training epoch loss: 0.6504
Local training epoch loss: 0.5747
Local training epoch loss: 0.4917
Local training epoch loss: 0.4231
Client 4 Accuracy: 60.33%
Client 5 local training...
Local training epoch loss: 0.8488
Local training epoch loss: 0.