In [1]:
"""Federated Learning Model"""
import sys
import os
from collections import OrderedDict
from typing import List

import flwr as fl
import joblib
import numpy as np
from sklearn.preprocessing import OneHotEncoder
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, random_split
from sklearn.model_selection import StratifiedShuffleSplit
import pandas as pd
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report

sys.path.append('/home/hadih/repos/FIDS/src/')

from config import FederatedLocation
from utils import Model, straitified_split
from models.Classifier import Classifier

# Prefer the GPU when one is present, but fall back to the CPU so that every
# later ``.to(DEVICE)`` call still works on a machine without CUDA.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm
2023-08-26 01:56:17,271	INFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [2]:
def load_datasets(
    location=FederatedLocation,
    train_path='/home/hadih/repos/FIDS/data/processed/train.pkl',
    test_path='/home/hadih/repos/FIDS/data/processed/test.pkl',
):
    """Load the per-client training chunks and the shared test set.

    Args:
        location: Object exposing ``clients_number``, the number of chunks
            requested from ``straitified_split``. Defaults to
            ``FederatedLocation``.
        train_path: joblib file holding the training data under the keys
            ``"X"`` and ``"y"``.
        test_path: joblib file holding the test data under the same keys.

    Returns:
        ``(trainData, (X_test, y_test))``. ``trainData`` is a list with one
        ``(X, y)`` tuple per chunk yielded by ``straitified_split``. Every
        ``X`` is cast to ``float32``; the labels are passed through untouched.

    Raises:
        AssertionError: If a chunk's features and labels differ in length.
    """
    # NOTE(review): joblib.load unpickles arbitrary objects -- only point
    # these paths at files you trust.
    data = joblib.load(train_path)
    chunks = straitified_split(data["X"], data["y"], location.clients_number)
    del data  # drop our reference to the full training set early

    trainData = []
    for X, y in chunks:
        X = X.astype(np.float32)
        assert X.shape[0] == y.shape[0]
        trainData.append((X, y))
    del chunks

    testset = joblib.load(test_path)
    return trainData, (testset["X"].astype(np.float32), testset["y"])


def evaluate(y_true, y_pred):
    """Print scikit-learn's classification report for a set of predictions.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels, aligned with ``y_true``.

    ``zero_division=0`` keeps the score of a class with no predicted (or no
    true) samples at 0.0 -- the same number the default setting reports --
    without emitting an ``UndefinedMetricWarning`` for every such class.
    """
    print(classification_report(y_true, y_pred, zero_division=0))

In [3]:
# load_datasets() returns raw arrays rather than DataLoaders: one (X, y)
# training chunk per federated client, plus the shared test split.
trainData, (X_test, y_test) = load_datasets()

In [4]:
# Baseline without federation: for every classifier variant, fit one model on
# each client's chunk alone and report its metrics on the shared test set.
for method in ['softmax', 'cnn2', 'cnn5', 'nn3', 'nn5']:
    print(method)
    for i, (X, y) in enumerate(trainData):
        print('Client ', i)
        # The key encodes the method and the client index.
        # NOTE(review): presumably Classifier uses it to name the model it
        # saves/loads (the output prints "Loaded models/Classifier_logs_...")
        # -- confirm in models.Classifier.
        clf = Classifier(key=f'_{method}_client_{i}', method=method)
        clf.fit(X, y)
        y_pred = clf.predict(X_test)
        evaluate(y_test, y_pred)

softmax
Client  0
Loaded models/Classifier_logs_softmax_client_0 model trained with batch_size = 100, seen 0 epochs and 1349 mini batches


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.00      0.00      0.00        98
           1       0.92      0.99      0.95      6401
           2       0.82      0.84      0.83       515
           3       0.98      0.90      0.94     11554
           4       0.64      0.73      0.68       275
           5       0.32      0.55      0.40       290
           6       0.69      0.51      0.58       397
           8       0.00      0.00      0.00         2
           9       0.97      0.99      0.98      7947
          10       0.88      0.53      0.66       295
          11       0.22      0.85      0.35        75
          12       0.00      0.00      0.00         1
          13       0.00      0.00      0.00        33

    accuracy                           0.93     27883
   macro avg       0.49      0.53      0.49     27883
weighted avg       0.94      0.93      0.93     27883

Client  1
Loaded models/Classifier_logs_softmax_client_1 model trained with bat

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.04      0.30      0.06        98
           1       0.94      0.97      0.96      6401
           2       0.72      0.92      0.81       515
           3       0.97      0.90      0.93     11554
           4       0.41      0.70      0.52       275
           5       0.31      0.34      0.33       290
           6       0.80      0.63      0.71       397
           8       0.00      0.00      0.00         2
           9       0.99      0.95      0.97      7947
          10       0.85      0.62      0.72       295
          11       0.13      0.27      0.17        75
          12       0.00      0.00      0.00         1
          13       0.00      0.00      0.00        33

    accuracy                           0.91     27883
   macro avg       0.47      0.51      0.47     27883
weighted avg       0.94      0.91      0.92     27883

Client  2
Loaded models/Classifier_logs_softmax_client_2 model trained with bat

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.49      0.58      0.53        98
           1       0.98      0.98      0.98      6401
           2       0.95      0.92      0.93       515
           3       0.97      0.95      0.96     11554
           4       0.76      0.57      0.65       275
           5       0.49      0.81      0.61       290
           6       0.60      0.51      0.55       397
           8       0.00      0.00      0.00         2
           9       0.97      0.99      0.98      7947
          10       0.54      0.18      0.27       295
          11       0.32      0.88      0.47        75
          12       0.00      0.00      0.00         1
          13       0.02      0.03      0.02        33

    accuracy                           0.95     27883
   macro avg       0.54      0.57      0.54     27883
weighted avg       0.95      0.95      0.95     27883

cnn2
Client  0


  return F.conv1d(input, weight, bias, self.stride,
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       1.00      0.53      0.69        98
           1       0.97      0.97      0.97      6401
           2       0.97      0.96      0.97       515
           3       0.97      0.89      0.93     11554
           4       0.89      0.93      0.91       275
           5       0.93      0.74      0.83       290
           6       1.00      0.51      0.67       397
           8       0.00      0.00      0.00         2
           9       0.85      1.00      0.92      7947
          10       1.00      0.50      0.67       295
          11       0.67      0.91      0.77        75
          12       0.00      0.00      0.00         1
          13       1.00      0.03      0.06        33

    accuracy                           0.93     27883
   macro avg       0.79      0.61      0.64     27883
weighted avg       0.93      0.93      0.93     27883

Client  1


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.98      0.42      0.59        98
           1       0.97      0.96      0.96      6401
           2       0.98      0.98      0.98       515
           3       0.97      0.98      0.97     11554
           4       0.92      0.84      0.88       275
           5       0.95      0.79      0.87       290
           6       1.00      0.51      0.67       397
           8       0.00      0.00      0.00         2
           9       0.95      1.00      0.97      7947
          10       0.99      0.50      0.67       295
          11       0.66      0.77      0.71        75
          12       0.00      0.00      0.00         1
          13       0.17      0.06      0.09        33

    accuracy                           0.96     27883
   macro avg       0.73      0.60      0.64     27883
weighted avg       0.96      0.96      0.96     27883

Client  2


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.98      0.58      0.73        98
           1       1.00      0.99      0.99      6401
           2       0.99      0.96      0.97       515
           3       0.98      1.00      0.99     11554
           4       0.90      0.94      0.92       275
           5       0.98      0.59      0.74       290
           6       1.00      0.51      0.67       397
           8       1.00      0.50      0.67         2
           9       0.96      1.00      0.98      7947
          10       1.00      0.50      0.67       295
          11       0.65      0.91      0.76        75
          12       0.00      0.00      0.00         1
          13       0.00      0.00      0.00        33

    accuracy                           0.98     27883
   macro avg       0.80      0.65      0.70     27883
weighted avg       0.98      0.98      0.97     27883

cnn5
Client  0


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       1.00      0.59      0.74        98
           1       0.98      1.00      0.99      6401
           2       0.98      0.98      0.98       515
           3       0.98      0.99      0.98     11554
           4       0.93      0.97      0.95       275
           5       0.98      0.81      0.89       290
           6       1.00      0.50      0.67       397
           8       0.00      0.00      0.00         2
           9       0.97      1.00      0.98      7947
          10       0.99      0.50      0.67       295
          11       0.67      0.91      0.77        75
          12       0.00      0.00      0.00         1
          13       1.00      0.03      0.06        33

    accuracy                           0.98     27883
   macro avg       0.81      0.64      0.67     27883
weighted avg       0.98      0.98      0.97     27883

Client  1


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       1.00      0.58      0.74        98
           1       0.99      0.99      0.99      6401
           2       0.97      0.99      0.98       515
           3       0.98      1.00      0.99     11554
           4       0.91      0.92      0.92       275
           5       0.96      0.79      0.87       290
           6       1.00      0.51      0.67       397
           8       0.00      0.00      0.00         2
           9       0.97      1.00      0.98      7947
          10       1.00      0.50      0.67       295
          11       0.68      0.91      0.78        75
          12       0.00      0.00      0.00         1
          13       1.00      0.03      0.06        33

    accuracy                           0.98     27883
   macro avg       0.80      0.63      0.66     27883
weighted avg       0.98      0.98      0.98     27883

Client  2


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.98      0.59      0.74        98
           1       1.00      1.00      1.00      6401
           2       0.99      0.97      0.98       515
           3       0.99      1.00      0.99     11554
           4       0.95      0.92      0.93       275
           5       0.90      0.82      0.86       290
           6       0.96      0.80      0.87       397
           8       1.00      0.50      0.67         2
           9       0.97      1.00      0.98      7947
          10       0.96      0.51      0.66       295
          11       0.65      0.91      0.76        75
          12       0.00      0.00      0.00         1
          13       1.00      0.03      0.06        33

    accuracy                           0.98     27883
   macro avg       0.87      0.69      0.73     27883
weighted avg       0.98      0.98      0.98     27883

nn3
Client  0
building NN3


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.00      0.00      0.00        98
           1       0.86      0.94      0.90      6401
           2       0.96      0.93      0.94       515
           3       0.96      0.80      0.87     11554
           4       0.99      0.54      0.70       275
           5       0.68      0.57      0.62       290
           6       0.62      0.03      0.05       397
           8       0.00      0.00      0.00         2
           9       0.79      1.00      0.88      7947
          10       0.43      0.50      0.46       295
          11       0.00      0.00      0.00        75
          12       0.00      0.00      0.00         1
          13       0.00      0.00      0.00        33

    accuracy                           0.86     27883
   macro avg       0.48      0.41      0.42     27883
weighted avg       0.87      0.86      0.85     27883

Client  1
building NN3


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.00      0.00      0.00        98
           1       0.88      0.95      0.91      6401
           2       0.99      0.94      0.97       515
           3       0.96      0.84      0.89     11554
           4       0.96      0.71      0.82       275
           5       0.81      0.80      0.81       290
           6       0.87      0.51      0.64       397
           8       0.00      0.00      0.00         2
           9       0.83      1.00      0.91      7947
          10       0.98      0.40      0.57       295
          11       0.00      0.00      0.00        75
          12       0.00      0.00      0.00         1
          13       0.00      0.00      0.00        33

    accuracy                           0.89     27883
   macro avg       0.56      0.47      0.50     27883
weighted avg       0.89      0.89      0.89     27883

Client  2
building NN3


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.00      0.00      0.00        98
           1       0.87      0.98      0.92      6401
           2       0.98      0.90      0.93       515
           3       0.98      0.93      0.95     11554
           4       0.77      0.77      0.77       275
           5       0.90      0.62      0.74       290
           6       1.00      0.51      0.67       397
           8       0.00      0.00      0.00         2
           9       0.94      1.00      0.97      7947
          10       0.98      0.50      0.66       295
          11       0.00      0.00      0.00        75
          12       0.00      0.00      0.00         1
          13       0.00      0.00      0.00        33

    accuracy                           0.93     27883
   macro avg       0.57      0.48      0.51     27883
weighted avg       0.93      0.93      0.93     27883

nn5
Client  0


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.00      0.00      0.00        98
           1       0.92      0.95      0.94      6401
           2       0.97      0.96      0.97       515
           3       0.94      0.95      0.95     11554
           4       0.94      0.65      0.77       275
           5       0.78      0.79      0.79       290
           6       1.00      0.50      0.67       397
           8       0.00      0.00      0.00         2
           9       0.97      0.99      0.98      7947
          10       0.99      0.50      0.67       295
          11       0.60      0.85      0.70        75
          12       0.00      0.00      0.00         1
          13       0.00      0.00      0.00        33

    accuracy                           0.94     27883
   macro avg       0.62      0.55      0.57     27883
weighted avg       0.94      0.94      0.94     27883

Client  1


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.89      0.58      0.70        98
           1       0.97      0.99      0.98      6401
           2       0.98      0.97      0.97       515
           3       0.97      0.99      0.98     11554
           4       0.94      0.65      0.77       275
           5       0.77      0.80      0.79       290
           6       0.99      0.51      0.67       397
           8       0.00      0.00      0.00         2
           9       0.97      0.99      0.98      7947
          10       1.00      0.50      0.67       295
          11       0.66      0.87      0.75        75
          12       0.00      0.00      0.00         1
          13       0.00      0.00      0.00        33

    accuracy                           0.97     27883
   macro avg       0.70      0.60      0.64     27883
weighted avg       0.97      0.97      0.97     27883

Client  2




              precision    recall  f1-score   support

           0       0.95      0.58      0.72        98
           1       1.00      0.98      0.99      6401
           2       0.99      0.98      0.98       515
           3       0.97      1.00      0.98     11554
           4       0.91      0.80      0.85       275
           5       0.90      0.78      0.84       290
           6       1.00      0.51      0.67       397
           8       0.00      0.00      0.00         2
           9       0.97      1.00      0.98      7947
          10       1.00      0.50      0.67       295
          11       0.66      0.87      0.75        75
          12       0.00      0.00      0.00         1
          13       1.00      0.03      0.06        33

    accuracy                           0.97     27883
   macro avg       0.80      0.62      0.65     27883
weighted avg       0.98      0.97      0.97     27883



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [8]:
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
from tqdm import tqdm


class Net(nn.Module):
    """Fully connected classifier that outputs per-class log-probabilities.

    The network is a stack of ``Linear`` layers
    (``input_dim -> 128 -> 256 -> 256 -> 128 -> num_classes``) with batch
    normalisation, dropout (p = 0.3 / 0.4 / 0.5 on the last three hidden
    blocks) and in-place ReLU between them, followed by ``LogSoftmax`` over
    the class dimension.

    Args:
        input_dim: Number of input features per sample.
        num_classes: Number of output classes.
    """

    def __init__(self, input_dim=76, num_classes=14) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.num_classes = num_classes

        # The module order is unchanged, so the positional ``state_dict``
        # keys (``model.0.weight`` ...) stay the same.
        layers = [
            nn.Linear(input_dim, 128),

            nn.BatchNorm1d(128),
            nn.ReLU(True),
            nn.Linear(128, 256),

            nn.BatchNorm1d(256),
            nn.Dropout(p=0.3),
            nn.ReLU(True),
            nn.Linear(256, 256),

            nn.BatchNorm1d(256),
            nn.Dropout(p=0.4),
            nn.ReLU(True),
            nn.Linear(256, 128),

            nn.BatchNorm1d(128),
            nn.Dropout(p=0.5),
            nn.ReLU(True),
            nn.Linear(128, num_classes),
            # An explicit ``dim`` avoids the deprecated implicit-dimension
            # behaviour; for (batch, classes) input the class axis is 1.
            nn.LogSoftmax(dim=1),
        ]

        self.model = nn.Sequential(*layers).to(DEVICE)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the log-probabilities of every class for the batch ``x``."""
        return self.model(x)


def train(net, trainloader, epochs: int):
    """Train the network on the training set.

    Runs ``epochs`` passes over ``trainloader`` with Adam (default settings)
    and cross-entropy loss, printing the mean per-sample loss and the
    accuracy after every epoch.

    Args:
        net: Model to optimise in place; it is switched to training mode.
        trainloader: Iterable yielding ``(features, labels)`` batches.
        epochs: Number of passes over ``trainloader``.
    """
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    net.train()
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for features, labels in tqdm(trainloader, desc=f'Epock {epoch+1}: '):
            features, labels = features.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            # A single forward pass feeds both the loss and the accuracy.
            # Running the model twice would draw two different dropout masks
            # and update the batch-norm statistics twice per step.
            outputs = net(features)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Metrics
            batch_size = labels.size(0)
            # ``criterion`` returns the batch mean; weight it by the batch
            # size so that dividing by the sample count below gives the true
            # per-sample mean. ``.item()`` detaches the value so the autograd
            # graph of every batch is not kept alive for the whole epoch.
            epoch_loss += loss.item() * batch_size
            total += batch_size
            correct += (outputs.detach().argmax(dim=1) == labels).sum().item()
        seen = max(total, 1)  # an empty loader reports 0.0 instead of crashing
        epoch_loss /= seen
        epoch_acc = correct / seen
        print(
            f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}"
        )


def test(net, testloader):
    """Evaluate the network on the entire test set.

    Args:
        net: Model to evaluate; it is switched to evaluation mode.
        testloader: Iterable yielding ``(data_points, labels)`` batches.

    Returns:
        ``(loss, accuracy, true, predictions)`` where ``loss`` is the mean
        cross-entropy per sample, ``accuracy`` the fraction of correct
        predictions, ``true`` a list with the label tensor of every batch (as
        yielded by the loader) and ``predictions`` a list with one list of
        predicted class indices per batch. ``loss`` and ``accuracy`` are 0.0
        when the loader yields no samples.
    """
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss, predictions, true = 0, 0, 0.0, [], []
    net.eval()
    with torch.no_grad():
        for data_points, labels in tqdm(testloader):
            true.append(labels)
            data_points, labels = data_points.to(DEVICE), labels.to(DEVICE)
            outputs = net(data_points)
            predicted = outputs.argmax(dim=1)
            predictions.append(predicted.cpu().tolist())
            batch_size = labels.size(0)
            # ``criterion`` returns the batch mean; weight it by the batch
            # size so the final division yields a per-sample mean.
            loss += criterion(outputs, labels).item() * batch_size
            total += batch_size
            correct += (predicted == labels).sum().item()
    if total == 0:
        return 0.0, 0.0, true, predictions
    return loss / total, correct / total, true, predictions

In [10]:
def get_parameters(net) -> List[np.ndarray]:
    """Export every ``state_dict`` entry of ``net`` as a NumPy array.

    The arrays are returned in ``state_dict`` order.
    """
    arrays = []
    for tensor in net.state_dict().values():
        arrays.append(tensor.cpu().numpy())
    return arrays


def set_parameters(net, parameters: List[np.ndarray]):
    """Load ``parameters`` into ``net``, pairing them with its ``state_dict`` keys.

    Args:
        net: Module whose ``state_dict`` is overwritten.
        parameters: One array per ``state_dict`` entry, in ``state_dict``
            order (the layout produced by ``get_parameters``).

    Raises:
        RuntimeError: From ``load_state_dict(strict=True)`` when entries are
            missing or their shapes do not match.
    """
    keys = net.state_dict().keys()
    # ``torch.tensor(np.asarray(v))`` keeps each array's dtype and also
    # accepts NumPy scalars and 0-d arrays (e.g. a ``num_batches_tracked``
    # entry). The legacy ``torch.Tensor(v)`` constructor forces float32 and
    # does not accept NumPy scalars.
    state_dict = OrderedDict(
        (key, torch.tensor(np.asarray(value)))
        for key, value in zip(keys, parameters)
    )
    net.load_state_dict(state_dict, strict=True)


class FlowerClient(fl.client.NumPyClient):
    """Federated client that trains and evaluates one local copy of the model.

    Args:
        cid: Client identifier, used only in the log messages.
        net: Model instance owned by this client.
        trainloader: Loader with this client's training data.
        valloader: Loader with this client's validation data.
    """

    def __init__(self, cid, net, trainloader, valloader):
        self.cid = cid
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        """Return the current local model parameters as NumPy arrays."""
        print(f"[Client {self.cid}] get_parameters")
        return get_parameters(self.net)

    def fit(self, parameters, config):
        """Load the given parameters, train for one epoch, return the result.

        Returns:
            ``(parameters, num_examples, metrics)``: the updated parameters,
            the number of local training examples and an empty metrics dict.
        """
        print(f"[Client {self.cid}] fit, config: {config}")
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=1)
        # Report the number of examples, not the number of batches
        # (``len(loader)``).
        return get_parameters(self.net), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Load the given parameters and evaluate them on the validation set.

        Returns:
            ``(loss, num_examples, metrics)`` with the accuracy stored under
            the ``"accuracy"`` key of ``metrics``.
        """
        print(f"[Client {self.cid}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        # ``test`` also returns the true labels and the predictions; only the
        # two leading scalars are needed here.
        loss, accuracy, *_ = test(self.net, self.valloader)
        return float(loss), len(self.valloader.dataset), {"accuracy": float(accuracy)}


class FedratedModel(Model):
    """Federated Architecture Model

    Holds one training and one validation loader per client plus a shared
    test loader, and runs a Flower simulation in which every client trains
    its own ``Net``.

    Args:
        trainloaders: Per-client training loaders, indexed by client id.
        valloaders: Per-client validation loaders, indexed by client id.
        testloader: Loader for the shared test set (stored, not used by the
            simulation itself).
    """

    def __init__(self, trainloaders, valloaders, testloader) -> None:
        super().__init__()
        self.trainloaders, self.valloaders, self.testloader = trainloaders, valloaders, testloader

    def client_fn(self, cid) -> FlowerClient:
        """Build the client for id ``cid`` with a fresh ``Net`` and its loaders."""
        net = Net().to(DEVICE)
        index = int(cid)  # converted to int before indexing the loader lists
        return FlowerClient(cid, net, self.trainloaders[index], self.valloaders[index])

    def train(self):
        """Train the federated global model for three rounds."""
        # Specify client resources if you need GPU (defaults to 1 CPU and 0 GPU)
        # client_resources = None
        # if DEVICE.type == "cuda":
        #     client_resources = {"num_gpus": 1}

        # NOTE(review): the recorded run of this notebook failed inside the Ray
        # worker with "No module named 'utils'". Presumably the workers do not
        # see the notebook's sys.path tweak -- confirm, and make the project
        # importable for them.
        fl.simulation.start_simulation(
            client_fn=self.client_fn,
            # One simulated client per supplied training loader instead of a
            # hard-coded count that could disagree with the data.
            num_clients=len(self.trainloaders),
            config=fl.server.ServerConfig(num_rounds=3),
            # client_resources=client_resources,
        )

In [11]:
# The federated model consumes DataLoaders, while load_datasets() returns raw
# (X, y) arrays and nothing else in the notebook defines the loaders. Build
# them here so the cell runs on a fresh kernel.
def _as_dataset(X, y):
    """Wrap a feature/label pair in a TensorDataset (float32 features, int64 labels)."""
    # NOTE(review): assumes X and y convert cleanly through np.asarray and that
    # y holds integer class indices -- confirm against the pickled data.
    features = torch.as_tensor(np.asarray(X), dtype=torch.float32)
    labels = torch.as_tensor(np.asarray(y), dtype=torch.long)
    return TensorDataset(features, labels)


trainloaders, valloaders = [], []
for X, y in trainData:
    ds = _as_dataset(X, y)
    len_val = len(ds) // 10  # 10 % validation set
    ds_train, ds_val = random_split(
        ds, [len(ds) - len_val, len_val], torch.Generator().manual_seed(42)
    )
    trainloaders.append(DataLoader(ds_train, batch_size=32, shuffle=True))
    valloaders.append(DataLoader(ds_val, batch_size=32))
testloader = DataLoader(_as_dataset(X_test, y_test), batch_size=32)

f = FedratedModel(trainloaders=trainloaders, valloaders=valloaders, testloader=testloader)
f.train()

INFO flwr 2023-08-23 13:58:30,152 | app.py:146 | Starting Flower simulation, config: ServerConfig(num_rounds=3, round_timeout=None)
2023-08-23 13:58:32,497	INFO worker.py:1621 -- Started a local Ray instance.
INFO flwr 2023-08-23 13:58:34,093 | app.py:180 | Flower VCE: Ray initialized with resources: {'node:172.25.119.181': 1.0, 'CPU': 16.0, 'object_store_memory': 1657331712.0, 'memory': 3314663424.0, 'node:__internal_head__': 1.0}
INFO flwr 2023-08-23 13:58:34,093 | server.py:86 | Initializing global parameters
INFO flwr 2023-08-23 13:58:34,094 | server.py:273 | Requesting initial parameters from one random client
ERROR flwr 2023-08-23 13:58:35,592 | ray_client_proxy.py:72 | [36mray::launch_and_get_parameters()[39m (pid=43626, ip=172.25.119.181)
  At least one of the input arguments for this task could not be computed:
ray.exceptions.RaySystemError: System error: No module named 'utils'
traceback: Traceback (most recent call last):
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

RayTaskError: [36mray::launch_and_get_parameters()[39m (pid=43626, ip=172.25.119.181)
  At least one of the input arguments for this task could not be computed:
ray.exceptions.RaySystemError: System error: No module named 'utils'
traceback: Traceback (most recent call last):
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
          ^^^^^^^^^^^^^^^^^^^^^
ModuleNotFoundError: No module named 'utils'