In [1]:
import torch
import torchvision
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import random_split, Dataset, DataLoader, ConcatDataset, Subset
import torch.optim.lr_scheduler as lr_scheduler
from sklearn.model_selection import train_test_split
from torchsummary import summary
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image

In [2]:
torch.manual_seed(77)

<torch._C.Generator at 0x7f0c00931190>

In [3]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.layer1 = nn.Sequential(nn.Conv2d(3, 32, kernel_size=3),
                                    nn.ReLU(),
                                    nn.MaxPool2d(kernel_size=2, stride=2),
                                    nn.BatchNorm2d(32))
        self.layer2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=3),
                                    nn.ReLU(),
                                    nn.BatchNorm2d(64),
                                    nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer3 = nn.Sequential(nn.Linear(2304, 128),
                                    nn.ReLU(),
                                    nn.Dropout(p = 0.5))
        self.fc = nn.Linear(128, 10)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = x.view(x.size(0), -1)
        x = self.layer3(x)
        x = self.fc(x)
        return x


In [None]:
model = torchvision.models.vgg19(pretrained=True)


for param in model.features.parameters():
  param.requires_grad = False

model.classifier

#setting a new classifier from scratch
model.classifier = nn.Sequential(model.classifier,
                                 nn.ReLU(),
                                 nn.Dropout(p = 0.5),
                                 nn.Linear(1000, 10))
model.cuda()


model.classifier[0][0].weight.requires_grad = False


summary(model.cuda(), (3, 224, 224))



----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 224, 224]           1,792
              ReLU-2         [-1, 64, 224, 224]               0
            Conv2d-3         [-1, 64, 224, 224]          36,928
              ReLU-4         [-1, 64, 224, 224]               0
         MaxPool2d-5         [-1, 64, 112, 112]               0
            Conv2d-6        [-1, 128, 112, 112]          73,856
              ReLU-7        [-1, 128, 112, 112]               0
            Conv2d-8        [-1, 128, 112, 112]         147,584
              ReLU-9        [-1, 128, 112, 112]               0
        MaxPool2d-10          [-1, 128, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]         295,168
             ReLU-12          [-1, 256, 56, 56]               0
           Conv2d-13          [-1, 256, 56, 56]         590,080
             ReLU-14          [-1, 256,

In [6]:
from IPython.utils.text import indent
def train_on_client(client_data, ind, model, criterion, optimizer, num_epochs, device, bs):
    train_loader = DataLoader(client_data, batch_size=bs, shuffle=True)
    scheduler = lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.5, total_iters=30)

    for epoch in range(num_epochs):
        model.train()
        loss = 0.0
        
        for i, (images, labels) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            loss += loss.item() #* images.size(0)
            #print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, len(train_loader), loss.item()))

        scheduler.step()
        epoch_loss = loss / len(train_loader)
        print(f"Client {ind} Training: Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")


In [26]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the server model
server_model = CNN().to(device)

# Define the hyperparameters
lr = 1e-3
weight_decay = 1e-4
num_epochs = 7
batch_size = 32
num_rounds = 5
num_clients = 4

criterion = nn.CrossEntropyLoss()

In [27]:
transform = transforms.Compose([transforms.
                                transforms.ToTensor(), 
                                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))])
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

# partitioning data across clients
client_data = torch.utils.data.random_split(train_dataset, [len(train_dataset) // num_clients] * num_clients)

Files already downloaded and verified


In [28]:
clients = [[],[],[],[]]

for round in range(num_rounds):

    print(f"---------- Round {round+1} ----------")
    # Aggregate client models on the server
    server_model.train()
    server_model.zero_grad()
    
    for client_idx in range(num_clients):
        client_model = CNN().to(device)
        client_model.load_state_dict(server_model.state_dict())
        optimizer = torch.optim.Adam(client_model.parameters(), lr = lr, weight_decay = weight_decay)  
        
        # Train on the client's local dataset
        train_on_client(client_data[client_idx], client_idx+1, client_model, 
                        criterion, optimizer, num_epochs, device, batch_size)
        clients[client_idx] = client_model

    sd_server = server_model.state_dict()
    sd_clients = [cl.state_dict() for cl in clients]
    for key in sd_server:
      for sd_client in sd_clients:
        sd_server[key] += sd_client[key]
      sd_server[key] = (sd_server[key] / num_clients).float()
    
    server_model.load_state_dict(sd_server)

---------- Round 1 ----------
Client 1 Training: Epoch 1/7, Loss: 0.0071
Client 1 Training: Epoch 2/7, Loss: 0.0079
Client 1 Training: Epoch 3/7, Loss: 0.0071
Client 1 Training: Epoch 4/7, Loss: 0.0068
Client 1 Training: Epoch 5/7, Loss: 0.0047
Client 1 Training: Epoch 6/7, Loss: 0.0033
Client 1 Training: Epoch 7/7, Loss: 0.0041
Client 2 Training: Epoch 1/7, Loss: 0.0085
Client 2 Training: Epoch 2/7, Loss: 0.0061
Client 2 Training: Epoch 3/7, Loss: 0.0049
Client 2 Training: Epoch 4/7, Loss: 0.0039
Client 2 Training: Epoch 5/7, Loss: 0.0082
Client 2 Training: Epoch 6/7, Loss: 0.0057
Client 2 Training: Epoch 7/7, Loss: 0.0040
Client 3 Training: Epoch 1/7, Loss: 0.0051
Client 3 Training: Epoch 2/7, Loss: 0.0068
Client 3 Training: Epoch 3/7, Loss: 0.0045
Client 3 Training: Epoch 4/7, Loss: 0.0057
Client 3 Training: Epoch 5/7, Loss: 0.0061
Client 3 Training: Epoch 6/7, Loss: 0.0049
Client 3 Training: Epoch 7/7, Loss: 0.0035
Client 4 Training: Epoch 1/7, Loss: 0.0074
Client 4 Training: Epoch

In [29]:
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
test_dataloader = DataLoader(test_dataset, batch_size = 16)
loss_test = []

with torch.no_grad():
  correct = 0
  total = 0
  for i, (input, target) in enumerate(test_dataloader):

        target = target.to(device)
        input = input.to(device)

        # compute output
        output = server_model(input)
        loss = criterion(output, target)
        loss_test.append(loss.item())

        total += target.size(0)
        _, predicted = torch.max(output.data, 1)
        correct += (predicted == target).sum().item()
        
  print('Accuracy on the test images: {} %'.format(100 * correct / total)) 

Files already downloaded and verified
Accuracy on the test images: 68.09 %


In [31]:
from google.colab import drive
drive.mount('/content/drive')

torch.save(server_model.state_dict(),'/content/drive/MyDrive/PrivacyModels/FedAvg-model.pt')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
model.load_state_dict(torch.load('/content/drive/MyDrive/PrivacyModels/FedAvg-model.pt'))
model.cuda()

# Comparison with Global Model

In [32]:
global_data = train_dataset
global_dataloader = DataLoader(global_data, batch_size=batch_size, shuffle=True)

lr = 1e-3
weight_decay = 1e-4
num_epochs = 25
batch_size = 32

global_model = CNN().to(device)
optimizer = torch.optim.Adam(global_model.parameters(), lr = lr, weight_decay = weight_decay)  
scheduler = lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.5, total_iters=30)

print("Global Model Training:")
for epoch in range(num_epochs):
  loss_global = 0.0

  for i, (input, target) in enumerate(global_dataloader):
          
          target = target.to(device)
          input = input.to(device)

          # compute output
          output = global_model(input)
          loss = criterion(output, target)
          loss_global += loss.item()
          # backpropagate
          optimizer.zero_grad()
          loss.backward()
          optimizer.step()

          #print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, len(train_dataloader), loss.item()))
  epoch_loss = loss_global / len(global_dataloader)
  scheduler.step()
  print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")


Global Model Training:
Epoch 1/25, Loss: 1.4133
Epoch 2/25, Loss: 1.1114
Epoch 3/25, Loss: 0.9787
Epoch 4/25, Loss: 0.8942
Epoch 5/25, Loss: 0.8243
Epoch 6/25, Loss: 0.7614
Epoch 7/25, Loss: 0.7120
Epoch 8/25, Loss: 0.6673
Epoch 9/25, Loss: 0.6378
Epoch 10/25, Loss: 0.5973
Epoch 11/25, Loss: 0.5663
Epoch 12/25, Loss: 0.5378
Epoch 13/25, Loss: 0.5196
Epoch 14/25, Loss: 0.4897
Epoch 15/25, Loss: 0.4698
Epoch 16/25, Loss: 0.4557
Epoch 17/25, Loss: 0.4324
Epoch 18/25, Loss: 0.4206
Epoch 19/25, Loss: 0.3964
Epoch 20/25, Loss: 0.3840
Epoch 21/25, Loss: 0.3733
Epoch 22/25, Loss: 0.3630
Epoch 23/25, Loss: 0.3535
Epoch 24/25, Loss: 0.3377
Epoch 25/25, Loss: 0.3287


In [33]:
with torch.no_grad():
  correct = 0
  total = 0
  for i, (input, target) in enumerate(test_dataloader):

        target = target.to(device)
        input = input.to(device)

        # compute output
        output = global_model(input)
        loss = criterion(output, target)
      
        total += target.size(0)
        _, predicted = torch.max(output.data, 1)
        correct += (predicted == target).sum().item()
        
  print('Accuracy of Global Model the test images: {} %'.format(100 * correct / total)) 

Accuracy of Global Model the test images: 69.21 %


**As we can see, the federated model has an accuracy close to the global model, even though we only trained the model for 5 rounds. In practice clients are trained in parallel and therefore takes a much lesser time. Also, asynchronous implementation of FedAvg is possible as well, which is not the desired here. Note that this was just a relatively simple model and more engineering on the model could yeild better results.**

(I tried the FedAvg algorithm for fine-tuning the VGG19 and VGG16 models as well; however their results were disappointing! (not sure about the reason behind this.)