In [1]:
import torch

In [2]:
%%time
%reload_ext autoreload
%autoreload
%autoreload 2
%config Completer.use_jedi = False

import os
import json
import warnings
import pprint
import sys
import numpy as np
import pandas as pd

import torch

# Project root: the part of the working-directory path that precedes the first
# "notebooks" component (the whole path when that component is absent).
# It is placed first on sys.path, presumably so the project-local packages
# imported further down can be resolved.
MAIN_PATH = os.getcwd().partition("notebooks")[0]
sys.path.insert(0, MAIN_PATH)

CPU times: total: 344 ms
Wall time: 338 ms


In [3]:
torch.cuda.is_available()

True

In [4]:
# Run a Federated Learning experiment
from data_loader.cifar10 import Cifar10DatasetManager
from server.base_server import BaseServer
from client.base_client import BaseClient
from experiments.base_experiment import BaseExperiment
from gradients.noise import GaussianNoiseGenerator, NoNoiseGenerator,StaircaseNoiseGenerator
from metrics.classification import multiclass_accuracy
from models.cifar_model import SimpleCifarCNN, EfficientCifarCNN,ResNet

  warn(


In [48]:
class DemoCifar10Experiment(BaseExperiment):
    """Federated-learning demo experiment on CIFAR-10.

    Builds ``client_num`` ``BaseClient`` objects and one ``BaseServer``, each
    wrapping its own ``ResNet`` instance, and obtains per-client training
    loaders plus validation/test loaders from ``Cifar10DatasetManager``.

    NOTE(review): ``BaseExperiment.__init__`` is not invoked here -- confirm
    the base class needs no initialisation of its own.
    """

    def __init__(self,
                 client_num: int = 2,
                 lr: float = 0.01,
                 noise_generator=None,
                 max_norm: float = 200,
                 sampling_rate: float = 0.05):
        """Store the hyper-parameters, then create clients, server and data.

        Args:
            client_num: Number of clients to create; also passed to the
                dataset manager as ``n_parties``.
            lr: Learning rate forwarded to every model constructor.
            noise_generator: Object handed to each client as its
                ``noise_generator``. ``None`` selects ``NoNoiseGenerator()``.
            max_norm: Forwarded to every model constructor as ``max_norm``.
            sampling_rate: Forwarded to the dataset manager as
                ``sampling_lot_rate``.
        """
        if noise_generator is None:
            noise_generator = NoNoiseGenerator()
        self.noise_generator = noise_generator
        self.lr = lr
        self.max_norm = max_norm
        self.sampling_rate = sampling_rate
        self.client_num = client_num
        self._init_server_clients(client_num, self.lr)
        self._init_data(client_num)

    def _init_server_clients(self, client_num, lr):
        """Create ``client_num`` clients and one server, each with a fresh model."""
        model = ResNet
        self.clients = [BaseClient(model(lr=lr, max_norm=self.max_norm),
                                   client_id=idx,
                                   noise_generator=self.noise_generator)
                        for idx in range(client_num)]
        self.server = BaseServer(model(lr=lr, max_norm=self.max_norm))

    def _init_data(self, client_num):
        """(Re)build the dataset manager and keep references to its loaders."""
        data_manager = Cifar10DatasetManager(n_parties=client_num,
                                             sampling_lot_rate=self.sampling_rate)
        self.client_train_datas = data_manager.train_loaders
        self.valid_datas = data_manager.validation_loader
        self.test_data = data_manager.test_loader

    def evaluate_model(self, data):
        """Return the server model's accuracy over ``data``.

        Args:
            data: Iterable yielding ``(inputs, target)`` pairs.

        Returns:
            ``total_correct / total_sample_num`` accumulated over all batches,
            or ``0.0`` when ``data`` yields no samples (previously this case
            raised ``ZeroDivisionError``).
        """
        total_correct = 0
        total_sample_num = 0
        # Inference only: no autograd bookkeeping is needed.
        with torch.no_grad():
            for inputs, target in data:
                predict_labels = self.server.predict(inputs)
                correct, sample_num = multiclass_accuracy(y_pred=predict_labels,
                                                          y_true=target)
                total_correct += correct
                total_sample_num += sample_num

        # Guard against an empty loader instead of dividing by zero.
        if total_sample_num == 0:
            return 0.0
        return total_correct / total_sample_num

    def get_validation_result(self):
        """Return the server model's accuracy on the validation loader."""
        return self.evaluate_model(self.valid_datas)

    def get_test_result(self):
        """Return the server model's accuracy on the test loader."""
        return self.evaluate_model(self.test_data)

    def aggeragate(self):
        """Ask the server to aggregate the current client models.

        The spelling mirrors ``BaseServer.aggeragate_model`` and is kept
        because this method name is called from outside the class.
        """
        self.server.aggeragate_model(self.clients)

    def run(self, epochs: int, client_epochs: int):
        """Run ``epochs`` rounds of train -> aggregate -> distribute.

        After every round the validation accuracy is printed.

        Args:
            epochs: Number of rounds.
            client_epochs: Passed to each ``client.train`` call.
        """
        # NOTE(review): the data was already loaded in __init__; reloading here
        # rebuilds every loader on each call to run() -- confirm this is intended.
        self._init_data(self.client_num)
        for client in self.clients:
            client.set_training_mode(for_gradient=False)

        for epoch in range(epochs):
            # shuffled_data / distribute_model are not defined in this class;
            # presumably they are inherited from BaseExperiment.
            for client, client_train_data in self.shuffled_data(to_shuffle=False):
                client.train(client_train_data, client_epochs=client_epochs)

            self.aggeragate()

            self.distribute_model()
            print(self.get_validation_result())


In [49]:
# Single-client experiment with no gradient noise.
EXPERIMENT = DemoCifar10Experiment(
    client_num=1,
    lr=0.001,
    max_norm=1000,
    sampling_rate=0.05,
    noise_generator=NoNoiseGenerator(),
)


Files already downloaded and verified
Files already downloaded and verified


In [50]:
# Two rounds, each training every client for 20 local epochs.
EXPERIMENT.run(epochs=2, client_epochs=20)

Files already downloaded and verified
Files already downloaded and verified
0.352
0.4461


In [31]:
EXPERIMENT.aggeragate()

In [23]:
def print_param(model):
    """Print the sum of the first item yielded by ``model.parameters()``.

    Nothing is printed when ``model.parameters()`` yields no items.
    """
    first_param = next(iter(model.parameters()), None)
    if first_param is not None:
        print(first_param.sum())

In [32]:
print_param(EXPERIMENT.clients[0].model)

tensor(8.9127, device='cuda:0', grad_fn=<SumBackward0>)


In [33]:
print_param(EXPERIMENT.server.model)

tensor(8.9127, device='cuda:0', grad_fn=<SumBackward0>)


In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms


In [9]:
# Data preprocessing: convert each image to a tensor, then normalise the three
# channels with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Load dataset
# Both loaders used a hard-coded batch size of 4 while this variable claimed 32
# and was never read. It now holds the value actually in effect and is passed
# to both loaders, so there is one place to change it.
batch_size = 4
# Disabled alternative: take the loaders from the project's dataset manager.
# data_manager = Cifar10DatasetManager(n_parties=1, 
#                                              sampling_lot_rate=0.01)
# trainloader = data_manager.train_loaders
# testset = data_manager.validation_loader
# testloader = data_manager.test_loader
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)


Files already downloaded and verified
Files already downloaded and verified


In [10]:
trainloader

<torch.utils.data.dataloader.DataLoader at 0x25713295600>

In [12]:
trainloader

<torch.utils.data.dataloader.DataLoader at 0x14cf0a4f640>

In [40]:
# Initialize model and optimizer
learning_rate = 0.001
# NOTE(review): the constructor below is disabled, so `model` must already
# exist in the kernel when this cell runs, and `learning_rate` is not read in
# this cell -- confirm which model is meant to be trained here.
# model = ResNet(lr=learning_rate)
model = model.to(model.device)

# Training loop: one optimisation step per mini-batch, using the optimizer,
# loss function and clipping threshold carried by the model object itself.
num_epochs = 1
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for i, data in enumerate(trainloader):
        inputs, labels = data
        inputs = inputs.to(model.device)
        labels = labels.to(model.device)

        # Clear gradients left over from the previous step.
        model.optimizer.zero_grad()

        outputs = model(inputs)
        loss = model.loss_fn(outputs, labels).mean()
        loss.backward()
        # Clip the total gradient norm before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), model.max_norm)
        model.optimizer.step()

        running_loss += loss.item()

        # Print average loss every 200 mini-batches, then reset the window.
        if (i + 1) % 200 == 0:
            print(f"[{epoch + 1}, {i + 1}] loss: {running_loss / 200:.3f}")
            running_loss = 0.0

# Save the trained model
torch.save(model.state_dict(), "cifar10_resnet.pth")


[1, 200] loss: 2.486
[1, 400] loss: 2.251
[1, 600] loss: 2.195
[1, 800] loss: 2.100
[1, 1000] loss: 2.090
[1, 1200] loss: 2.094
[1, 1400] loss: 2.073
[1, 1600] loss: 2.013
[1, 1800] loss: 1.984
[1, 2000] loss: 2.003
[1, 2200] loss: 1.939
[1, 2400] loss: 1.928
[1, 2600] loss: 1.910
[1, 2800] loss: 1.891
[1, 3000] loss: 1.862
[1, 3200] loss: 1.842
[1, 3400] loss: 1.918
[1, 3600] loss: 1.845
[1, 3800] loss: 1.769
[1, 4000] loss: 1.755
[1, 4200] loss: 1.816
[1, 4400] loss: 1.772
[1, 4600] loss: 1.698
[1, 4800] loss: 1.795
[1, 5000] loss: 1.747
[1, 5200] loss: 1.652
[1, 5400] loss: 1.687
[1, 5600] loss: 1.665
[1, 5800] loss: 1.644
[1, 6000] loss: 1.674
[1, 6200] loss: 1.639
[1, 6400] loss: 1.670
[1, 6600] loss: 1.614
[1, 6800] loss: 1.563
[1, 7000] loss: 1.611
[1, 7200] loss: 1.565
[1, 7400] loss: 1.563
[1, 7600] loss: 1.540
[1, 7800] loss: 1.494
[1, 8000] loss: 1.471
[1, 8200] loss: 1.560
[1, 8400] loss: 1.539
[1, 8600] loss: 1.469
[1, 8800] loss: 1.473
[1, 9000] loss: 1.418
[1, 9200] loss

KeyboardInterrupt: 

In [34]:
model = EXPERIMENT.server.model

In [35]:
# Evaluate `model` on `testloader` and report its top-1 accuracy.
model.eval()  # put the module in evaluation mode
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.to(model.device), labels.to(model.device)
        outputs = model(images)
        # Index of the largest score in each row is the predicted class.
        # argmax avoids assigning to `_` (which would shadow IPython's
        # last-output variable) and the legacy `.data` accessor.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Report the number of images actually evaluated, and avoid dividing by zero
# when the loader is empty.
accuracy_pct = 100 * correct / total if total else 0.0
print(f"Accuracy of the network on the {total} test images: {accuracy_pct:.2f}%")


Accuracy of the network on the 10000 test images: 41.02%
