In [2]:
import torch

In [3]:
%%time
%reload_ext autoreload
%autoreload
%autoreload 2
%config Completer.use_jedi = False

import os
import json
import warnings
import pprint
import sys
import numpy as np
import pandas as pd

import torch

# Make the repository root importable: it is everything in the current working
# directory before the first occurrence of "notebooks" (the whole CWD if the
# substring does not occur). Prepended so project packages win over installed ones.
MAIN_PATH = os.getcwd().partition("notebooks")[0]
sys.path.insert(0, MAIN_PATH)


CPU times: user 96.2 ms, sys: 3.93 ms, total: 100 ms
Wall time: 99 ms


In [4]:
# Sanity check: displays True when PyTorch can see a CUDA device from this kernel.
torch.cuda.is_available()

True

In [5]:
# Project modules used by the federated-learning experiment below
from data_loader.cifar10 import Cifar10DatasetManager
from server.base_server import BaseServer
from client.base_client import BaseClient
from experiments.base_experiment import BaseExperiment
from gradients.noise import GaussianNoiseGenerator, NoNoiseGenerator,StaircaseNoiseGenerator
from metrics.classification import multiclass_accuracy
from models.cifar_model import ResNet,ResNet18,ResNet101

In [31]:
class DemoCifar10Experiment(BaseExperiment):
    """Federated-learning demo on CIFAR-10 with ResNet-18 clients and a server.

    Every client and the server wrap a ``ResNet18`` built with the same
    ``lr`` and ``max_norm``. All of them are initialised from the same
    pretrained checkpoint, so federated rounds start from common weights.
    """

    def __init__(self,
                 client_num: int = 2,
                 lr: float = 0.01,
                 noise_generator=None,
                 max_norm: float = 200,
                 sampling_rate: float = 0.05,
                 pretrained_path: str = 'pre_trained_resnet18_cifar10_model.pth'):
        """Build clients, server and data loaders.

        Args:
            client_num: number of ``BaseClient`` instances (also the number of
                data parties passed to ``Cifar10DatasetManager``).
            lr: learning rate forwarded to every model constructor.
            noise_generator: shared by all clients; defaults to
                ``NoNoiseGenerator()`` when ``None``.
            max_norm: forwarded to every model constructor as ``max_norm``
                (presumably a clipping bound -- confirm in the model code).
            sampling_rate: forwarded to the dataset manager as
                ``sampling_lot_rate``.
            pretrained_path: checkpoint loaded into every client and the server.
        """
        if noise_generator is None:
            noise_generator = NoNoiseGenerator()
        self.noise_generator = noise_generator
        self.lr = lr
        self.max_norm = max_norm
        self.sampling_rate = sampling_rate
        self.client_num = client_num
        self.pretrained_path = pretrained_path
        self._init_server_clients(client_num, self.lr)
        self._init_data(client_num)

    def _init_server_clients(self, client_num, lr):
        """Create ``client_num`` clients plus the server and load pretrained weights."""
        model = ResNet18
        self.clients = [BaseClient(model(lr=lr, max_norm=self.max_norm),
                                   client_id=idx,
                                   noise_generator=self.noise_generator)
                        for idx in range(client_num)]
        self.server = BaseServer(model(lr=lr, max_norm=self.max_norm))
        # Read the checkpoint once. map_location="cpu" lets a GPU-saved checkpoint
        # load on a CPU-only machine; load_state_dict copies the tensors into
        # each model's own parameters, so the models do not share storage.
        state_dict = torch.load(self.pretrained_path, map_location="cpu")
        for client in self.clients:
            client.model.load_state_dict(state_dict)
        # The server must start from the same weights as the clients; otherwise
        # the first validation (and any server-relative aggregation) uses a
        # randomly initialised model.
        self.server.model.load_state_dict(state_dict)

    def _init_data(self, client_num):
        """(Re)build the per-client train loaders plus validation/test loaders."""
        data_manager = Cifar10DatasetManager(n_parties=client_num,
                                             sampling_lot_rate=self.sampling_rate)
        self.client_train_datas = data_manager.train_loaders
        self.valid_datas = data_manager.validation_loader
        self.test_data = data_manager.test_loader

    def evaluate_model(self, data):
        """Return the server model's accuracy over all ``(inputs, target)`` batches.

        Per-batch ``(correct, sample_num)`` counts from ``multiclass_accuracy``
        are summed before dividing, so uneven batch sizes are weighted correctly.

        Raises:
            ValueError: if ``data`` yields no samples.
        """
        total_correct = 0
        total_sample_num = 0
        with torch.no_grad():
            for inputs, target in data:
                predict_labels = self.server.predict(inputs)
                correct, sample_num = multiclass_accuracy(y_pred=predict_labels,
                                                          y_true=target)
                total_correct += correct
                total_sample_num += sample_num

        if total_sample_num == 0:
            raise ValueError("cannot evaluate the model on an empty dataset")
        return total_correct / total_sample_num

    def get_validation_result(self):
        """Server accuracy on the validation loader."""
        return self.evaluate_model(self.valid_datas)

    def get_test_result(self):
        """Server accuracy on the test loader."""
        return self.evaluate_model(self.test_data)

    def aggeragate(self, clients):
        """Aggregate ``clients`` into the server model.

        The (misspelled) name is kept for compatibility; it mirrors
        ``BaseServer.aggeragate_model``.
        """
        self.server.aggeragate_model(clients)

    def run(self, epochs: int, client_epochs: int,
            max_clients_per_round: int = 10, save_path=None):
        """Run ``epochs`` federated rounds, printing validation accuracy.

        Each round prints the server's validation accuracy, then trains up to
        ``max_clients_per_round`` clients. Clients are taken in the order
        yielded by ``self.shuffled_data(to_shuffle=True)``. The trained
        clients are aggregated into the server and ``self.distribute_model()``
        is called. The accuracy of the model produced by the final round is
        printed as well.

        Args:
            epochs: number of federated rounds.
            client_epochs: forwarded to ``client.train``.
            max_clients_per_round: cap on clients trained per round; ``None``
                disables the cap. Defaults to the previously hard-coded 10.
            save_path: if given, the final server ``state_dict`` is written
                there with ``torch.save``.
        """
        # The loaders were already built in __init__ and are rebuilt here
        # (presumably so each run gets fresh loaders -- confirm intent).
        self._init_data(self.client_num)
        for client in self.clients:
            client.set_training_mode(for_gradient=False)

        for epoch in range(epochs):
            print(f"round {epoch}: validation accuracy {self.get_validation_result()}")
            selected_clients = []
            for i, (client, client_train_data) in enumerate(self.shuffled_data(to_shuffle=True)):
                if max_clients_per_round is not None and i >= max_clients_per_round:
                    break
                client.train(client_train_data, client_epochs=client_epochs)
                selected_clients.append(client)

            self.aggeragate(selected_clients)
            # Inherited from BaseExperiment; presumably pushes the aggregated
            # server weights back to the clients.
            self.distribute_model()

        # Validation is printed before each round, so without this the model
        # produced by the last round would never be evaluated.
        print(f"final: validation accuracy {self.get_validation_result()}")

        if save_path is not None:
            torch.save(self.server.model.state_dict(), save_path)

In [None]:
# Seed the NumPy and PyTorch RNGs before the models are built, so weight
# initialisation and client sampling are repeatable across Restart & Run All.
# (cuDNN kernels on GPU may still introduce nondeterminism.)
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

EXPERIMENT = DemoCifar10Experiment(client_num=10,
                                   lr=0.01,
                                   max_norm=0.1,
                                   sampling_rate=0.05,
                                   noise_generator=NoNoiseGenerator())
EXPERIMENT.run(epochs=150, client_epochs=2)