In [1]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.25.1-py3-none-any.whl (5.8 MB)
[K     |████████████████████████████████| 5.8 MB 14.7 MB/s 
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
[K     |████████████████████████████████| 7.6 MB 56.8 MB/s 
Collecting huggingface-hub<1.0,>=0.10.0
  Downloading huggingface_hub-0.11.1-py3-none-any.whl (182 kB)
[K     |████████████████████████████████| 182 kB 24.6 MB/s 
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.11.1 tokenizers-0.13.2 transformers-4.25.1


In [2]:
import pandas as pd
import torch
import numpy as np
from transformers import BertTokenizer, BertModel
from torch import nn
from torch.optim import Adam
from tqdm import tqdm
import copy

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
data_path = "/content/drive/MyDrive/Data/Thesis/Kaggle/filtered.csv"

df = pd.read_csv(data_path)

In [5]:
df.head()

Unnamed: 0.1,Unnamed: 0,Body,Label
0,0,naturally irresistible corporate identity lt r...,1
1,1,stock trading gunslinger fanny merrill muzo co...,1
2,2,unbelievable new homes made easy im wanting sh...,1
3,3,color printing special request additional in...,1
4,4,money get software cds software compatibility ...,1


In [6]:
class Dataset(torch.utils.data.Dataset):

    def __init__(self, df, tokenizer):

        self.labels = [label for label in df['Label']]
        self.texts = [tokenizer(text, 
                                padding='max_length', 
                                max_length = 512, 
                                truncation=True,
                                return_tensors="pt") for text in df['Body']]

    def classes(self):
        return self.labels

    def __len__(self):
        return len(self.labels)

    def get_batch_labels(self, idx):
        # Fetch a batch of labels
        return np.array(self.labels[idx])

    def get_batch_texts(self, idx):
        # Fetch a batch of inputs
        return self.texts[idx]

    def __getitem__(self, idx):

        batch_texts = self.get_batch_texts(idx)
        batch_y = self.get_batch_labels(idx)

        return batch_texts, batch_y

In [7]:
class BertClassifier(nn.Module):

    def __init__(self, bert, dropout=0.5):

        super(BertClassifier, self).__init__()

        self.bert = bert
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(768, 2)
        self.relu = nn.ReLU()

    def forward(self, input_id, mask):

        _, pooled_output = self.bert(input_ids= input_id, attention_mask=mask,return_dict=False)
        dropout_output = self.dropout(pooled_output)
        linear_output = self.linear(dropout_output)
        final_layer = self.relu(linear_output)

        return final_layer

In [18]:
class Client(object):
    def __init__(self, name, df_train, df_val, model, criterion, optimizer, tokenizer, epochs):
        self.name = name
        self.train_dataloader = torch.utils.data.DataLoader(Dataset(df_train, tokenizer), batch_size=2, shuffle=True)
        self.val_dataloader = torch.utils.data.DataLoader(Dataset(df_val, tokenizer), batch_size=2, shuffle=True)
        self.len_train = len(df_train)
        self.len_val = len(df_val)
        self.len_test = len(df_test)
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.tokenizer = tokenizer
        self.epochs = epochs

    def train(self, weights):
        self.model.load_state_dict(weights)    

        train_acc = []
        train_loss = []
        val_acc = []
        val_loss = []

        use_cuda = torch.cuda.is_available()
        device = torch.device("cuda" if use_cuda else "cpu")

        if use_cuda:
            self.model = self.model.cuda()
            self.criterion = self.criterion.cuda()
        
        for epoch_num in range(self.epochs):

            print(f"Client: {self.name}, Epoch: {epoch_num + 1}")

            total_acc_train = 0
            total_loss_train = 0

            for train_input, train_label in tqdm(self.train_dataloader):

                train_label = train_label.to(device)
                mask = train_input['attention_mask'].to(device)
                input_id = train_input['input_ids'].squeeze(1).to(device)

                output = self.model(input_id, mask)
                
                batch_loss = self.criterion(output, train_label.long())
                total_loss_train += batch_loss.item()
                
                acc = (output.argmax(dim=1) == train_label).sum().item()
                total_acc_train += acc

                self.model.zero_grad()
                batch_loss.backward()
                self.optimizer.step()
            
            total_acc_val = 0
            total_loss_val = 0

            with torch.no_grad():

                for val_input, val_label in self.val_dataloader:

                    val_label = val_label.to(device)
                    mask = val_input['attention_mask'].to(device)
                    input_id = val_input['input_ids'].squeeze(1).to(device)

                    output = self.model(input_id, mask)

                    batch_loss = self.criterion(output, val_label.long())
                    total_loss_val += batch_loss.item()
                    
                    acc = (output.argmax(dim=1) == val_label).sum().item()
                    total_acc_val += acc
            
            
            train_acc.append(total_acc_train / self.len_train)
            train_loss.append(total_loss_train / self.len_train)
            val_acc.append(total_acc_val / self.len_val)
            val_loss.append(total_loss_val / self.len_val)

            print(
                f'| Train Loss: {total_loss_train / self.len_train: .3f} \
                \n| Train Accuracy: {total_acc_train / self.len_train: .3f} \
                \n| Val Loss: {total_loss_val / self.len_val: .3f} \
                \n| Val Accuracy: {total_acc_val / self.len_val: .3f}')
            
        avg_train_acc = sum(train_acc) / self.epochs
        avg_train_loss = sum(train_loss) / self.epochs
        avg_val_acc = sum(val_acc) / self.epochs
        avg_val_loss = sum(val_loss) / self.epochs

        weights = self.model.state_dict()

        return weights, avg_train_acc, avg_train_loss, avg_val_acc, avg_val_loss


In [22]:
class Server(object):
    def __init__(self, df_train_splits, df_val_splits, df_test, model, criterion, optimizer, tokenizer, epochs, rounds, n_clients):
        self.df_train_splits = df_train_splits
        self.df_val_splits = df_val_splits
        self.df_test = df_test
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.tokenizer = tokenizer
        self.epochs = epochs
        self.rounds = rounds
        self.n_clients = n_clients

    def train(self):
        global_weights = self.model.state_dict()

        train_acc = {}
        train_loss = {}
        val_acc = {}
        val_loss = {}

        for i in range(self.n_clients):
            train_acc[f"client_{i + 1}"] = []
            train_loss[f"client_{i + 1}"] = []
            val_acc[f"client_{i + 1}"] = []
            val_loss[f"client_{i + 1}"] = []
        
        for curr_round in range(1, self.rounds+1):
            print(f"\n\nRound {curr_round}...")

            curr_weights = self.model.state_dict()

            w = []
            for i in range(self.n_clients):
                name = f"client_{i + 1}"
                client_i = Client(name, 
                                self.df_train_splits[i],
                                self.df_val_splits[i], 
                                self.model, 
                                self.criterion, 
                                self.optimizer, 
                                self.tokenizer, 
                                self.epochs)

                weights, avg_train_acc, avg_train_loss, avg_val_acc, avg_val_loss = client_i.train(curr_weights)

                w.append(copy.deepcopy(weights))

                train_acc[f"client_{i + 1}"].append(copy.deepcopy(avg_train_acc))
                train_loss[f"client_{i + 1}"].append(copy.deepcopy(avg_train_loss))
                val_acc[f"client_{i + 1}"].append(copy.deepcopy(avg_val_acc))
                val_loss[f"client_{i + 1}"].append(copy.deepcopy(avg_val_loss))
            
            weights_avg = copy.deepcopy(w[0])

            for k in weights_avg.keys():
                for i in range(1, len(w)):
                    weights_avg[k] += w[i][k]
                weights_avg[k] = torch.div(weights_avg[k], len(w))
            
            global_weights = weights_avg

            self.model.load_state_dict(global_weights)
        
        print(train_acc, train_loss, val_acc, val_loss)
    
    def test(self):
        test_dataloader = torch.utils.data.DataLoader(Dataset(self.df_test, self.tokenizer), batch_size=2)
        len_test = len(self.df_test)
        use_cuda = torch.cuda.is_available()
        device = torch.device("cuda" if use_cuda else "cpu")

        if use_cuda:
            self.model = self.model.cuda()

        total_acc_test = 0
        with torch.no_grad():
            for test_input, test_label in test_dataloader:

              test_label = test_label.to(device)
              mask = test_input['attention_mask'].to(device)
              input_id = test_input['input_ids'].squeeze(1).to(device)

              output = self.model(input_id, mask)

              acc = (output.argmax(dim=1) == test_label).sum().item()
              total_acc_test += acc
    
        print(f'Test Accuracy: {total_acc_test / len_test: .3f}')


In [10]:
def split_dataframe(df, n):
    s = int(len(df) / n)
    shuffled = df.sample(frac = 1)
    splits = []

    for i in range(n):
        if i != n - 1:
            splits.append(shuffled.iloc[i*s:(i+1)*s])
        else:
            splits.append(shuffled.iloc[i*s:])
    
    return splits

In [11]:
n_clients = 8
rounds = 2

In [12]:
np.random.seed(112)
df_train, df_val, df_test = np.split(df.sample(frac=1, random_state=42), 
                                     [int(.8*len(df)), int(.9*len(df))])

print(len(df_train),len(df_val), len(df_test))

4582 573 573


In [13]:
df_train_splits = split_dataframe(df_train, n_clients)
df_val_splits = split_dataframe(df_val, n_clients)

In [14]:
epochs = 4
learning_rate = 1e-6

bert = BertModel.from_pretrained('bert-base-cased')
model = BertClassifier(bert)

criterion = nn.CrossEntropyLoss()

optimizer = Adam(model.parameters(), lr= learning_rate)

tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/436M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Downloading:   0%|          | 0.00/213k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

In [23]:
server = Server(df_train_splits, df_val_splits, df_test, model, criterion, optimizer, tokenizer, epochs, rounds, n_clients)

In [16]:
server.train()



Round 1...
Client: client_1, Epoch: 1


100%|██████████| 286/286 [01:02<00:00,  4.61it/s]


| Train Loss:  0.345                 
| Train Accuracy:  0.680                 
| Val Loss:  0.318                 
| Val Accuracy:  0.803
Client: client_1, Epoch: 2


100%|██████████| 286/286 [01:00<00:00,  4.75it/s]


| Train Loss:  0.305                 
| Train Accuracy:  0.743                 
| Val Loss:  0.228                 
| Val Accuracy:  0.845
Client: client_1, Epoch: 3


100%|██████████| 286/286 [01:00<00:00,  4.71it/s]


| Train Loss:  0.202                 
| Train Accuracy:  0.836                 
| Val Loss:  0.123                 
| Val Accuracy:  0.944
Client: client_1, Epoch: 4


100%|██████████| 286/286 [01:00<00:00,  4.71it/s]


| Train Loss:  0.121                 
| Train Accuracy:  0.951                 
| Val Loss:  0.088                 
| Val Accuracy:  0.958
Client: client_2, Epoch: 1


100%|██████████| 286/286 [00:59<00:00,  4.79it/s]


| Train Loss:  0.354                 
| Train Accuracy:  0.608                 
| Val Loss:  0.350                 
| Val Accuracy:  0.662
Client: client_2, Epoch: 2


100%|██████████| 286/286 [00:59<00:00,  4.82it/s]


| Train Loss:  0.350                 
| Train Accuracy:  0.677                 
| Val Loss:  0.355                 
| Val Accuracy:  0.704
Client: client_2, Epoch: 3


100%|██████████| 286/286 [00:59<00:00,  4.82it/s]


| Train Loss:  0.346                 
| Train Accuracy:  0.684                 
| Val Loss:  0.349                 
| Val Accuracy:  0.746
Client: client_2, Epoch: 4


100%|██████████| 286/286 [00:59<00:00,  4.79it/s]


| Train Loss:  0.327                 
| Train Accuracy:  0.712                 
| Val Loss:  0.304                 
| Val Accuracy:  0.775
Client: client_3, Epoch: 1


100%|██████████| 286/286 [00:59<00:00,  4.81it/s]


| Train Loss:  0.351                 
| Train Accuracy:  0.698                 
| Val Loss:  0.352                 
| Val Accuracy:  0.662
Client: client_3, Epoch: 2


100%|██████████| 286/286 [00:59<00:00,  4.81it/s]


| Train Loss:  0.346                 
| Train Accuracy:  0.729                 
| Val Loss:  0.347                 
| Val Accuracy:  0.761
Client: client_3, Epoch: 3


100%|██████████| 286/286 [00:59<00:00,  4.81it/s]


| Train Loss:  0.344                 
| Train Accuracy:  0.734                 
| Val Loss:  0.331                 
| Val Accuracy:  0.732
Client: client_3, Epoch: 4


100%|██████████| 286/286 [01:00<00:00,  4.75it/s]


| Train Loss:  0.298                 
| Train Accuracy:  0.767                 
| Val Loss:  0.244                 
| Val Accuracy:  0.789
Client: client_4, Epoch: 1


100%|██████████| 286/286 [00:59<00:00,  4.81it/s]


| Train Loss:  0.354                 
| Train Accuracy:  0.671                 
| Val Loss:  0.356                 
| Val Accuracy:  0.718
Client: client_4, Epoch: 2


100%|██████████| 286/286 [00:59<00:00,  4.79it/s]


| Train Loss:  0.334                 
| Train Accuracy:  0.764                 
| Val Loss:  0.323                 
| Val Accuracy:  0.746
Client: client_4, Epoch: 3


100%|██████████| 286/286 [01:00<00:00,  4.75it/s]


| Train Loss:  0.293                 
| Train Accuracy:  0.769                 
| Val Loss:  0.278                 
| Val Accuracy:  0.775
Client: client_4, Epoch: 4


100%|██████████| 286/286 [01:00<00:00,  4.74it/s]


| Train Loss:  0.220                 
| Train Accuracy:  0.781                 
| Val Loss:  0.238                 
| Val Accuracy:  0.803
Client: client_5, Epoch: 1


100%|██████████| 286/286 [00:59<00:00,  4.80it/s]


| Train Loss:  0.351                 
| Train Accuracy:  0.663                 
| Val Loss:  0.353                 
| Val Accuracy:  0.704
Client: client_5, Epoch: 2


100%|██████████| 286/286 [00:59<00:00,  4.81it/s]


| Train Loss:  0.349                 
| Train Accuracy:  0.699                 
| Val Loss:  0.353                 
| Val Accuracy:  0.662
Client: client_5, Epoch: 3


100%|██████████| 286/286 [00:59<00:00,  4.83it/s]


| Train Loss:  0.346                 
| Train Accuracy:  0.719                 
| Val Loss:  0.351                 
| Val Accuracy:  0.718
Client: client_5, Epoch: 4


100%|██████████| 286/286 [00:59<00:00,  4.82it/s]


| Train Loss:  0.342                 
| Train Accuracy:  0.731                 
| Val Loss:  0.339                 
| Val Accuracy:  0.690
Client: client_6, Epoch: 1


100%|██████████| 286/286 [00:59<00:00,  4.81it/s]


| Train Loss:  0.352                 
| Train Accuracy:  0.685                 
| Val Loss:  0.352                 
| Val Accuracy:  0.789
Client: client_6, Epoch: 2


100%|██████████| 286/286 [00:59<00:00,  4.78it/s]


| Train Loss:  0.325                 
| Train Accuracy:  0.755                 
| Val Loss:  0.314                 
| Val Accuracy:  0.803
Client: client_6, Epoch: 3


100%|██████████| 286/286 [01:00<00:00,  4.73it/s]


| Train Loss:  0.256                 
| Train Accuracy:  0.774                 
| Val Loss:  0.224                 
| Val Accuracy:  0.789
Client: client_6, Epoch: 4


100%|██████████| 286/286 [01:00<00:00,  4.73it/s]


| Train Loss:  0.187                 
| Train Accuracy:  0.827                 
| Val Loss:  0.152                 
| Val Accuracy:  0.873
Client: client_7, Epoch: 1


100%|██████████| 286/286 [00:59<00:00,  4.80it/s]


| Train Loss:  0.353                 
| Train Accuracy:  0.647                 
| Val Loss:  0.353                 
| Val Accuracy:  0.648
Client: client_7, Epoch: 2


100%|██████████| 286/286 [00:59<00:00,  4.82it/s]


| Train Loss:  0.348                 
| Train Accuracy:  0.684                 
| Val Loss:  0.350                 
| Val Accuracy:  0.718
Client: client_7, Epoch: 3


100%|██████████| 286/286 [00:59<00:00,  4.79it/s]


| Train Loss:  0.329                 
| Train Accuracy:  0.727                 
| Val Loss:  0.331                 
| Val Accuracy:  0.732
Client: client_7, Epoch: 4


100%|██████████| 286/286 [01:00<00:00,  4.75it/s]


| Train Loss:  0.301                 
| Train Accuracy:  0.733                 
| Val Loss:  0.270                 
| Val Accuracy:  0.732
Client: client_8, Epoch: 1


100%|██████████| 289/289 [01:00<00:00,  4.81it/s]


| Train Loss:  0.353                 
| Train Accuracy:  0.682                 
| Val Loss:  0.350                 
| Val Accuracy:  0.697
Client: client_8, Epoch: 2


100%|██████████| 289/289 [00:59<00:00,  4.83it/s]


| Train Loss:  0.350                 
| Train Accuracy:  0.718                 
| Val Loss:  0.345                 
| Val Accuracy:  0.789
Client: client_8, Epoch: 3


100%|██████████| 289/289 [00:59<00:00,  4.82it/s]


| Train Loss:  0.346                 
| Train Accuracy:  0.720                 
| Val Loss:  0.345                 
| Val Accuracy:  0.763
Client: client_8, Epoch: 4


100%|██████████| 289/289 [01:00<00:00,  4.81it/s]


| Train Loss:  0.341                 
| Train Accuracy:  0.758                 
| Val Loss:  0.334                 
| Val Accuracy:  0.776


Round 2...
Client: client_1, Epoch: 1


100%|██████████| 286/286 [01:00<00:00,  4.72it/s]


| Train Loss:  0.219                 
| Train Accuracy:  0.788                 
| Val Loss:  0.129                 
| Val Accuracy:  0.930
Client: client_1, Epoch: 2


100%|██████████| 286/286 [01:00<00:00,  4.71it/s]


| Train Loss:  0.119                 
| Train Accuracy:  0.914                 
| Val Loss:  0.084                 
| Val Accuracy:  0.958
Client: client_1, Epoch: 3


100%|██████████| 286/286 [01:00<00:00,  4.72it/s]


| Train Loss:  0.086                 
| Train Accuracy:  0.960                 
| Val Loss:  0.068                 
| Val Accuracy:  0.972
Client: client_1, Epoch: 4


100%|██████████| 286/286 [01:00<00:00,  4.72it/s]


| Train Loss:  0.062                 
| Train Accuracy:  0.986                 
| Val Loss:  0.053                 
| Val Accuracy:  0.986
Client: client_2, Epoch: 1


100%|██████████| 286/286 [01:00<00:00,  4.72it/s]


| Train Loss:  0.098                 
| Train Accuracy:  0.932                 
| Val Loss:  0.055                 
| Val Accuracy:  0.986
Client: client_2, Epoch: 2


100%|██████████| 286/286 [01:00<00:00,  4.72it/s]


| Train Loss:  0.056                 
| Train Accuracy:  0.977                 
| Val Loss:  0.050                 
| Val Accuracy:  0.972
Client: client_2, Epoch: 3


100%|██████████| 286/286 [01:00<00:00,  4.73it/s]


| Train Loss:  0.033                 
| Train Accuracy:  0.998                 
| Val Loss:  0.041                 
| Val Accuracy:  0.972
Client: client_2, Epoch: 4


100%|██████████| 286/286 [01:00<00:00,  4.73it/s]


| Train Loss:  0.027                 
| Train Accuracy:  0.998                 
| Val Loss:  0.035                 
| Val Accuracy:  0.986
Client: client_3, Epoch: 1


100%|██████████| 286/286 [01:00<00:00,  4.71it/s]


| Train Loss:  0.045                 
| Train Accuracy:  0.981                 
| Val Loss:  0.039                 
| Val Accuracy:  0.986
Client: client_3, Epoch: 2


100%|██████████| 286/286 [01:00<00:00,  4.73it/s]


| Train Loss:  0.032                 
| Train Accuracy:  0.984                 
| Val Loss:  0.049                 
| Val Accuracy:  0.958
Client: client_3, Epoch: 3


100%|██████████| 286/286 [01:00<00:00,  4.72it/s]


| Train Loss:  0.020                 
| Train Accuracy:  0.993                 
| Val Loss:  0.038                 
| Val Accuracy:  0.986
Client: client_3, Epoch: 4


100%|██████████| 286/286 [01:00<00:00,  4.73it/s]


| Train Loss:  0.016                 
| Train Accuracy:  0.995                 
| Val Loss:  0.050                 
| Val Accuracy:  0.958
Client: client_4, Epoch: 1


100%|██████████| 286/286 [01:00<00:00,  4.71it/s]


| Train Loss:  0.047                 
| Train Accuracy:  0.974                 
| Val Loss:  0.071                 
| Val Accuracy:  0.958
Client: client_4, Epoch: 2


100%|██████████| 286/286 [01:02<00:00,  4.54it/s]


| Train Loss:  0.017                 
| Train Accuracy:  0.997                 
| Val Loss:  0.052                 
| Val Accuracy:  0.958
Client: client_4, Epoch: 3


100%|██████████| 286/286 [01:00<00:00,  4.72it/s]


| Train Loss:  0.010                 
| Train Accuracy:  1.000                 
| Val Loss:  0.057                 
| Val Accuracy:  0.958
Client: client_4, Epoch: 4


100%|██████████| 286/286 [01:00<00:00,  4.73it/s]


| Train Loss:  0.008                 
| Train Accuracy:  1.000                 
| Val Loss:  0.050                 
| Val Accuracy:  0.958
Client: client_5, Epoch: 1


100%|██████████| 286/286 [01:00<00:00,  4.71it/s]


| Train Loss:  0.054                 
| Train Accuracy:  0.970                 
| Val Loss:  0.061                 
| Val Accuracy:  0.986
Client: client_5, Epoch: 2


100%|██████████| 286/286 [01:00<00:00,  4.72it/s]


| Train Loss:  0.021                 
| Train Accuracy:  0.993                 
| Val Loss:  0.037                 
| Val Accuracy:  0.986
Client: client_5, Epoch: 3


100%|██████████| 286/286 [01:00<00:00,  4.73it/s]


| Train Loss:  0.014                 
| Train Accuracy:  0.995                 
| Val Loss:  0.009                 
| Val Accuracy:  1.000
Client: client_5, Epoch: 4


100%|██████████| 286/286 [01:00<00:00,  4.72it/s]


| Train Loss:  0.010                 
| Train Accuracy:  0.997                 
| Val Loss:  0.010                 
| Val Accuracy:  1.000
Client: client_6, Epoch: 1


100%|██████████| 286/286 [01:00<00:00,  4.72it/s]


| Train Loss:  0.035                 
| Train Accuracy:  0.983                 
| Val Loss:  0.005                 
| Val Accuracy:  1.000
Client: client_6, Epoch: 2


100%|██████████| 286/286 [01:00<00:00,  4.72it/s]


| Train Loss:  0.017                 
| Train Accuracy:  0.993                 
| Val Loss:  0.006                 
| Val Accuracy:  1.000
Client: client_6, Epoch: 3


100%|██████████| 286/286 [01:00<00:00,  4.73it/s]


| Train Loss:  0.013                 
| Train Accuracy:  0.995                 
| Val Loss:  0.004                 
| Val Accuracy:  1.000
Client: client_6, Epoch: 4


100%|██████████| 286/286 [01:00<00:00,  4.73it/s]


| Train Loss:  0.009                 
| Train Accuracy:  0.998                 
| Val Loss:  0.004                 
| Val Accuracy:  1.000
Client: client_7, Epoch: 1


100%|██████████| 286/286 [01:00<00:00,  4.71it/s]


| Train Loss:  0.047                 
| Train Accuracy:  0.965                 
| Val Loss:  0.009                 
| Val Accuracy:  1.000
Client: client_7, Epoch: 2


100%|██████████| 286/286 [01:00<00:00,  4.72it/s]


| Train Loss:  0.015                 
| Train Accuracy:  0.993                 
| Val Loss:  0.010                 
| Val Accuracy:  0.986
Client: client_7, Epoch: 3


100%|██████████| 286/286 [01:00<00:00,  4.72it/s]


| Train Loss:  0.005                 
| Train Accuracy:  1.000                 
| Val Loss:  0.007                 
| Val Accuracy:  1.000
Client: client_7, Epoch: 4


100%|██████████| 286/286 [01:00<00:00,  4.72it/s]


| Train Loss:  0.004                 
| Train Accuracy:  1.000                 
| Val Loss:  0.005                 
| Val Accuracy:  1.000
Client: client_8, Epoch: 1


100%|██████████| 289/289 [01:01<00:00,  4.72it/s]


| Train Loss:  0.037                 
| Train Accuracy:  0.969                 
| Val Loss:  0.013                 
| Val Accuracy:  1.000
Client: client_8, Epoch: 2


100%|██████████| 289/289 [01:01<00:00,  4.72it/s]


| Train Loss:  0.011                 
| Train Accuracy:  0.991                 
| Val Loss:  0.005                 
| Val Accuracy:  1.000
Client: client_8, Epoch: 3


100%|██████████| 289/289 [01:01<00:00,  4.73it/s]


| Train Loss:  0.005                 
| Train Accuracy:  0.998                 
| Val Loss:  0.004                 
| Val Accuracy:  1.000
Client: client_8, Epoch: 4


100%|██████████| 289/289 [01:01<00:00,  4.72it/s]


| Train Loss:  0.005                 
| Train Accuracy:  0.998                 
| Val Loss:  0.003                 
| Val Accuracy:  1.000
{'client_1': [0.8024475524475525, 0.9121503496503496], 'client_2': [0.6700174825174825, 0.9763986013986015], 'client_3': [0.7320804195804196, 0.9881993006993006], 'client_4': [0.7465034965034965, 0.99256993006993], 'client_5': [0.7027972027972028, 0.9886363636363638], 'client_6': [0.7604895104895104, 0.9921328671328671], 'client_7': [0.6975524475524475, 0.9895104895104895], 'client_8': [0.7192906574394463, 0.9891868512110726]} {'client_1': [0.2432154927832576, 0.12135735973088929], 'client_2': [0.3443219023333354, 0.053656871993521274], 'client_3': [0.3347819337452, 0.02812543560643322], 'client_4': [0.30022752062919045, 0.020442941820556688], 'client_5': [0.34691522973817546, 0.024596316207194595], 'client_6': [0.2799864777484322, 0.018266935363473773], 'client_7': [0.3328578777975969, 0.017751491370672355], 'client_8': [0.3474844759362379, 0.01471

In [24]:
server.test()

Test Accuracy:  0.984
