In [31]:
%load_ext autoreload
%autoreload 2

import asyncio, copy, os, socket, sys, time
from functools import partial
from multiprocessing import Pool, Process
from pathlib import Path
from tqdm import tqdm

import torch
from torch import optim
from torch.utils.tensorboard import SummaryWriter

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../")))
from libs import agg, data, fl, hdc, log, nn, plot, poison, resnet, sim, wandb
from libs.distributed import *
from cfgs.fedargs import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [32]:
# Experiment identifiers; both are passed to wandb.init at the end of this cell.
project = 'fl-hdc'
name = 'mnist'

# Define custom configs: override the epoch count on the shared fedargs object
# (read by the retraining loop further down).
fedargs.epochs = 10

# Save Logs To File (info | debug | warning | error | critical) [optional]
log.init("info")
# NOTE(review): `wandb` here is the project's libs.wandb module, not necessarily
# the upstream package API — confirm the (name, project) argument order there.
wb = wandb.init(name, project)

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

[34m[1mwandb[0m: wandb version 0.12.16 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [33]:
# Seed torch's global RNG from the configured seed.
torch.manual_seed(fedargs.seed)

# CUDA is used only when the config asks for it AND a CUDA device is available.
use_cuda = fedargs.cuda and torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

# Worker / pinned-memory options, populated only on the CUDA path
# (presumably intended as DataLoader keyword arguments — confirm at the use site).
kwargs = dict(num_workers=1, pin_memory=True) if use_cuda else dict()

In [34]:
# Build the HDC model, store it on fedargs.model, and keep an independent deep
# copy (hdc_model) that the cells below train and evaluate.
fedargs.model = hdc.HDC(fedargs.one_d_len, fedargs.hdc_proj_len, len(fedargs.labels), device)
hdc_model = copy.deepcopy(fedargs.model)
# Load the dataset named in the config; the two returned objects are wrapped in
# the training and test DataLoaders in the next cell.
# NOTE(review): only_to_tensor presumably limits preprocessing to a tensor
# conversion — confirm in libs.data.
train_data, test_data = data.load_dataset(fedargs.dataset, only_to_tensor = True)

In [35]:
# Full-batch loaders: batch_size equals the dataset size, so each loader yields
# its whole split as a single batch.
# Host memory is pinned only when CUDA is actually in use; pinning on a
# CPU-only host brings no benefit and makes PyTorch emit a warning.
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=len(train_data),
    shuffle=True,
    num_workers=4,
    pin_memory=bool(use_cuda),
)
# Evaluation does not need a random sample order, so the test split is not shuffled.
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=len(test_data),
    shuffle=False,
    num_workers=4,
    pin_memory=bool(use_cuda),
)

In [36]:
%%time
# Initial training pass over the training loader. train() returns a report,
# which is emitted through the project logger at INFO level.
train_report = hdc_model.train(train_loader, device)
log.info(train_report)

2022-05-16 06:36:15,693 - <timed exec>::<module>(l:2) :               precision    recall  f1-score   support

           0       0.87      0.90      0.89      5923
           1       0.89      0.91      0.90      6742
           2       0.82      0.80      0.81      5958
           3       0.73      0.81      0.76      6131
           4       0.82      0.79      0.81      5842
           5       0.79      0.65      0.71      5421
           6       0.87      0.88      0.88      5918
           7       0.91      0.83      0.87      6265
           8       0.70      0.76      0.73      5851
           9       0.73      0.78      0.75      5949

    accuracy                           0.81     60000
   macro avg       0.81      0.81      0.81     60000
weighted avg       0.82      0.81      0.81     60000
 [MainProcess : MainThread (INFO)]


CPU times: user 1min 36s, sys: 55.6 s, total: 2min 32s
Wall time: 18.3 s


In [29]:
%%time
# Evaluate the trained model on the test loader and print the returned report.
test_report = hdc_model.test(test_loader, device)
print(test_report)

              precision    recall  f1-score   support

           0       0.86      0.92      0.89       980
           1       0.92      0.91      0.92      1135
           2       0.85      0.79      0.82      1032
           3       0.74      0.84      0.79      1010
           4       0.83      0.81      0.82       982
           5       0.80      0.65      0.72       892
           6       0.86      0.87      0.87       958
           7       0.92      0.83      0.87      1028
           8       0.70      0.78      0.73       974
           9       0.76      0.80      0.78      1009

    accuracy                           0.82     10000
   macro avg       0.82      0.82      0.82     10000
weighted avg       0.83      0.82      0.82     10000

CPU times: user 18.1 s, sys: 8.1 s, total: 26.2 s
Wall time: 1.74 s


In [30]:
# Retrain for the configured number of epochs, evaluating on the test loader
# after every pass. Each report is preceded by an epoch header so the printed
# outputs can be told apart.
for epoch in range(fedargs.epochs):
    train_report = hdc_model.re_train(train_loader, device)
    test_report = hdc_model.test(test_loader, device)
    print(f"Epoch {epoch + 1}/{fedargs.epochs}")
    print(test_report)

1
2
3
3
3
3
3
3
3
3
3
3
              precision    recall  f1-score   support

           0       0.87      0.94      0.90       980
           1       0.92      0.93      0.93      1135
           2       0.84      0.81      0.82      1032
           3       0.83      0.83      0.83      1010
           4       0.87      0.87      0.87       982
           5       0.82      0.73      0.77       892
           6       0.87      0.88      0.88       958
           7       0.90      0.85      0.87      1028
           8       0.76      0.82      0.79       974
           9       0.82      0.83      0.83      1009

    accuracy                           0.85     10000
   macro avg       0.85      0.85      0.85     10000
weighted avg       0.85      0.85      0.85     10000

1
2
3
3
3
3
3
3
3
3
3
3
              precision    recall  f1-score   support

           0       0.87      0.94      0.91       980
           1       0.92      0.94      0.93      1135
           2       0.86      0