In [1]:
%%time
%reload_ext autoreload
%autoreload
%autoreload 2
%config Completer.use_jedi = False

import os
import json
import warnings
import pprint
import sys
import numpy as np
import pandas as pd
import torch

MAIN_PATH = os.getcwd().split("notebooks")[0]
sys.path.insert(0, MAIN_PATH)

CPU times: user 2.03 s, sys: 408 ms, total: 2.44 s
Wall time: 1.57 s


In [2]:
# Run a Federated Learning experiment
from data_loader.mnist import MnistDatasetManager
from models.mnist_model import MnistFullConnectModel, SimpleCNN
from server.base_server import BaseServer
from client.base_client import BaseClient
from experiments.base_experiment import BaseExperiment
from gradients.noise import GaussianNoiseGenerator, NoNoiseGenerator,StaircaseNoiseGenerator
from metrics.classification import multiclass_accuracy

In [5]:
class DemoMnistExperiment(BaseExperiment):
    def __init__(self, client_num: int =2):
        self._init_server_clients(client_num)
        self._init_data(client_num)

    
    def _init_server_clients(self, client_num):
        model = SimpleCNN
        self.clients = [BaseClient(model(), 
                                   client_id=idx, 
                                   noise_generator=StaircaseNoiseGenerator())
                        for idx in range(client_num)]
        self.server = BaseServer(model())

    def _init_data(self, client_num):
        data_manager = MnistDatasetManager(n_parties=client_num)
        self.client_train_datas = data_manager.train_loaders
        self.valid_datas = data_manager.validation_loader
        self.test_data = data_manager.test_loader

    def evaluate_model(self, data):
        total_correct = 0
        total_sample_num = 0
        with torch.no_grad():
            for _, (inputs, target) in enumerate(data):
                predict_labels = self.server.predict(inputs)
                correct, sample_num = multiclass_accuracy(y_pred=predict_labels, 
                                                          y_true=target)
                total_correct += correct
                total_sample_num += sample_num
                
        return total_correct / total_sample_num
        
    def get_validation_result(self):
        return self.evaluate_model(self.valid_datas)
    
    def get_test_result(self):
        return self.evaluate_model(self.test_data)
    
    def aggeragate(self):
        self.server.aggeragate_model(self.clients)
    
    def run(self, epochs: int):
        for client in self.clients:
            client.set_training_mode(for_gradient=False)

        for epoch in range(epochs):
            print(self.get_validation_result())
            for client, client_train_data in self.shuffled_data(to_shuffle=False):
                client.train(client_train_data)

            self.aggeragate()
            if epoch and not (epoch % 1):
                print(self.get_validation_result())

            self.distribute_model()


In [6]:
EXPERIMENT = DemoMnistExperiment(client_num=10)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz


1.0%

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /Users/shuyafeng/IdeaProjects/DifferentialPrivacy/datasets/MNIST/raw/train-images-idx3-ubyte.gz


100.0%


Extracting /Users/shuyafeng/IdeaProjects/DifferentialPrivacy/datasets/MNIST/raw/train-images-idx3-ubyte.gz to /Users/shuyafeng/IdeaProjects/DifferentialPrivacy/datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /Users/shuyafeng/IdeaProjects/DifferentialPrivacy/datasets/MNIST/raw/train-labels-idx1-ubyte.gz


100.0%


Extracting /Users/shuyafeng/IdeaProjects/DifferentialPrivacy/datasets/MNIST/raw/train-labels-idx1-ubyte.gz to /Users/shuyafeng/IdeaProjects/DifferentialPrivacy/datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /Users/shuyafeng/IdeaProjects/DifferentialPrivacy/datasets/MNIST/raw/t10k-images-idx3-ubyte.gz


100.0%


Extracting /Users/shuyafeng/IdeaProjects/DifferentialPrivacy/datasets/MNIST/raw/t10k-images-idx3-ubyte.gz to /Users/shuyafeng/IdeaProjects/DifferentialPrivacy/datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz


100.0%

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /Users/shuyafeng/IdeaProjects/DifferentialPrivacy/datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting /Users/shuyafeng/IdeaProjects/DifferentialPrivacy/datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz to /Users/shuyafeng/IdeaProjects/DifferentialPrivacy/datasets/MNIST/raw






In [7]:
!which python

python not found


In [13]:
EXPERIMENT.run(10)

0.06508333333333334
0.18666666666666668
0.8991666666666667
0.8991666666666667
0.9379166666666666
0.9379166666666666
0.9553333333333334
0.9553333333333334
0.9644166666666667
0.9644166666666667
0.9688333333333333
0.9688333333333333
0.974
0.974
0.9775833333333334
0.9775833333333334
0.9786666666666667
0.9786666666666667
0.9815833333333334
