# MedMNIST Federated Transfer learning SqueezeNet Server Side

## Setting variables

In [1]:
rounds = 10 # number of communication round (server-client-server)
local_epoch = 1 # number of epoch in per client per round
users = 4 # number of clients


In [2]:
import os
import h5py

import socket
import struct
import pickle
import sys
import zlib
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision
from torchvision import datasets,transforms,models
from torchvision.models import SqueezeNet1_1_Weights
from torchvision import datasets
from torch.utils.data import Subset
from sklearn.model_selection import train_test_split

from threading import Thread
from threading import Lock

import tenseal as ts

import time

from tqdm import tqdm

from torch.utils.data import Dataset, DataLoader
import numpy as np
import medmnist
from medmnist import INFO, Evaluator

import torchmetrics
from torchmetrics.classification import Accuracy, Precision, Recall, F1Score

## Cuda

In [3]:
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = "cpu"
print(device)

cpu


## Pytorch layer modules for *Conv1D* Network



### `Conv1d` layer
- `torch.nn.Conv1d(in_channels, out_channels, kernel_size)`

### `MaxPool1d` layer
- `torch.nn.MaxPool1d(kernel_size, stride=None)`
- Parameter `stride` follows `kernel_size`.

### `ReLU` layer
- `torch.nn.ReLU()`

### `Linear` layer
- `torch.nn.Linear(in_features, out_features, bias=True)`

### `Softmax` layer
- `torch.nn.Softmax(dim=None)`
- Parameter `dim` is usually set to `1`.

In [4]:
sq_model = models.squeezenet1_1(weights=SqueezeNet1_1_Weights.DEFAULT)
sq_model.to(device)
lr = 0.01
optimizer = optim.SGD(sq_model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
step_lr_scheduler = lr_scheduler.StepLR(optimizer,step_size=7,gamma=0.1)

In [5]:
#freezing previous layers
for param in sq_model.features.parameters():
    param.requires_grad = False

In [6]:
# modifying the last layer to match desired output class
num_classes = 2 # CHANGE this to the number of classes of the dataset
in_ftrs = sq_model.classifier[1].in_channels
features = list(sq_model.classifier.children())[:-3] # Remove last 3 layers
features.extend([nn.Conv2d(in_ftrs, num_classes, kernel_size=1)]) # Add
features.extend([nn.ReLU(inplace=True)]) # Add
features.extend([nn.AdaptiveAvgPool2d(output_size=(1,1))]) # Add
sq_model.classifier = nn.Sequential(*features)

## Dataloaders

In [7]:
sets = ['train','test']
mean = np.array([0.485,0.456,0.406])
std = np.array([0.229,0.224,0.225])
data_transforms = {
    'train':transforms.Compose([
        transforms.Grayscale(num_output_channels=3),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean,std)
    ]),
    'test':transforms.Compose([
        transforms.Grayscale(num_output_channels=3),
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean,std)
    ])
}
dataset_name = 'pneumoniamnist'
info = INFO[dataset_name]
task = info['task']
n_channels = info['n_channels']
n_classes = len(info['label'])

DataClass = getattr(medmnist, info['python_class'])
class_labels = info['label']
print(info['label'])

BATCH_SIZE = 128
sets = ['train','test']
image_datasets = {x:DataClass(split=x, transform=data_transforms[x], download=True, size=128)
                for x in ['train','test']}

dataloaders = {'train': torch.utils.data.DataLoader(image_datasets['train'],batch_size=BATCH_SIZE,
                                            shuffle=True,num_workers=0),
               'test': torch.utils.data.DataLoader(image_datasets['test'],batch_size=2*BATCH_SIZE,
                                            shuffle=True,num_workers=0)
              }

# train_total_batch is dataset_sizes['train'] and test_batch is dataset_sizes['test']
dataset_sizes = {x:len(image_datasets[x]) for x in ['train','test']}

{'0': 'malignant', '1': 'normal, benign'}
Using downloaded and verified file: /home/anas/.medmnist/breastmnist_128.npz
Using downloaded and verified file: /home/anas/.medmnist/breastmnist_128.npz


In [8]:
def eval_model(preds, labels):
    # metric = torchmetrics.classification.Accuracy(task="multiclass", num_classes=n_classes)
    
    accuracy_metric = Accuracy(task="multiclass", num_classes=n_classes)
    accuracy_metric(preds.cpu(), labels.cpu())
    print(f'Accuracy: {accuracy_metric.compute():.4f}')

    #calculate precision
    precision = Precision(task="multiclass", average='macro', num_classes=n_classes)
    precision(preds.cpu(), labels.cpu())
    print(f'Precision: {precision.compute():.4f}')

    #calculate recall
    recall = Recall(task="multiclass", average='macro', num_classes=n_classes)
    recall(preds.cpu(), labels.cpu())
    print(f'Recall: {recall.compute():.4f}')

    #calculate f1 score
    f1 = F1Score(task="multiclass", average='macro', num_classes=n_classes)
    f1(preds.cpu(), labels.cpu())
    print(f'F1: {f1.compute():.4f}')

    #calculate confusion matrix
    cm = torchmetrics.functional.confusion_matrix(preds.cpu(), labels.cpu(), num_classes=n_classes, task="multiclass")
    print(f'Confusion Matrix: \n{cm}')


In [9]:
# Function to calculate accuracy and loss
def calculate_performance(model, dataloader, criterion, extra_metrics=False):
    total_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    for data in dataloader:
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)
        _,preds = torch.max(outputs,1)
        labels = labels.squeeze().long()
        loss = criterion(outputs, labels)
        total_loss += loss.item() * inputs.size(0)

        _, predicted = torch.max(outputs, 1)
        correct_predictions += (predicted == labels).sum().item()
        total_samples += labels.size(0)

    accuracy = correct_predictions / total_samples
    average_loss = total_loss / total_samples

    if(extra_metrics):
        eval_model(preds, labels)

    return accuracy, average_loss

## variables

In [10]:
import copy

clientsoclist = [0]*users

start_time = 0
weight_count = 0

last_layer_list = [sq_model.state_dict()['classifier.1.weight'], sq_model.state_dict()['classifier.1.bias']]
global_weights = copy.deepcopy(sq_model.state_dict())

train_acc_list = []
val_acc_list = []

train_loss_list = []
val_loss_list = []

datasetsize = [0]*users
weights_list = [0]*users

total_aggr_time = 0
total_comm_time = 0

lock = Lock()

In [11]:
def get_global_model(target_list):
    global global_weights
    global sq_model
    
    global_weights['classifier.1.weight'] = target_list[0]
    global_weights['classifier.1.bias'] = target_list[1]
    
    sq_model.load_state_dict(global_weights)
    sq_model.eval()
    sq_model = sq_model.to(device)
    return sq_model

## Comunication overhead

In [12]:
total_sendsize_list = []
total_receivesize_list = []

client_sendsize_list = [[] for i in range(users)]
client_receivesize_list = [[] for i in range(users)]

train_sendsize_list = [] 
train_receivesize_list = []

## Socket initialization
### Set host address and port number

### Required socket functions

In [13]:
def send_msg(sock, msg):
    global total_comm_time
    # prefix each message with a 4-byte length in network byte order
    send_start = time.time()
    msg = pickle.dumps(msg)
    l_send = len(msg)
    msg = struct.pack('>I', l_send) + msg
    #  
    sock.sendall(msg)
    send_end = time.time()
    total_comm_time += (send_end - send_start)
    return l_send

def recv_msg(sock):
    global total_comm_time
    # read message length and unpack it into an integer
    recv_start = time.time()
    raw_msglen = recvall(sock, 4)
    if not raw_msglen:
        return None
    msglen = struct.unpack('>I', raw_msglen)[0]
    # read the message data
    msg =  recvall(sock, msglen)
    msg = pickle.loads(msg)
    recv_end = time.time()
    total_comm_time += (recv_end - recv_start)
    return msg, msglen

def recvall(sock, n):
    # helper function to receive n bytes or return None if EOF is hit
    data = b''
    while len(data) < n:
        packet = sock.recv(n - len(data))
        if not packet:
            return None
        data += packet
    return data

In [14]:
import copy

def average_weights(w, datasize):
    """
    Returns the average of the weights.
    """   
    for i, data in enumerate(datasize):
        for j in range(len(w[i])):
            w[i][j] *= float(data)
    
    w_avg = copy.deepcopy(w[0])

    for i in range(len(w_avg)):
        for j in range(1, len(w)):
            w_avg[i] += w[j][i]
            #eval add
        w_avg[i] = torch.div(w_avg[i], float(sum(datasize)))

    return w_avg

## Thread define

## Receive users before training

In [15]:
def run_thread(func, num_user):
    global clientsoclist
    global start_time
    
    thrs = []
    for i in range(num_user):
        conn, addr = s.accept()
        print('Conntected with', addr)
        # append client socket on list
        clientsoclist[i] = conn
        args = (i, num_user, conn)
        thread = Thread(target=func, args=args)
        thrs.append(thread)
        thread.start()
    print("timmer start!")
    start_time = time.time()    # store start time
    for thread in thrs:
        thread.join()
    end_time = time.time()  # store end time
    print("TrainingTime: {} sec".format(end_time - start_time))

In [16]:
def receive(userid, num_users, conn): #thread for receive clients
    global weight_count
    
    global datasetsize


    msg = {
        'rounds': rounds,
        'client_id': userid,
        'local_epoch': local_epoch,
        'last_layer_list_len':len(last_layer_list)
    }

    datasize = send_msg(conn, msg)    #send epoch
    total_sendsize_list.append(datasize)
    client_sendsize_list[userid].append(datasize)

    train_dataset_size, datasize = recv_msg(conn)    # get total_batch of train dataset
    total_receivesize_list.append(datasize)
    client_receivesize_list[userid].append(datasize)
    
    
    with lock:
        datasetsize[userid] = train_dataset_size
        weight_count += 1
    
    train(userid, train_dataset_size, num_users, conn)

## Train

In [17]:
def train(userid, train_dataset_size, num_users, client_conn):
    global weights_list
    global global_weights
    global last_layer_list
    global weight_count
    global sq_model
    global total_aggr_time
    global train_acc_list
    global val_acc_list
    global train_loss_list
    global val_loss_list
    
    for r in range(rounds):
        with lock:
            if weight_count == num_users:
                for i, conn in enumerate(clientsoclist):
                    datasize = send_msg(conn, last_layer_list) # sending last layer parameters only
                    total_sendsize_list.append(datasize)
                    client_sendsize_list[i].append(datasize)
                    train_sendsize_list.append(datasize)
                    weight_count = 0

        client_weights, datasize = recv_msg(client_conn) # client_weights refers to the last layer weights of the client 
        total_receivesize_list.append(datasize)
        client_receivesize_list[userid].append(datasize)
        train_receivesize_list.append(datasize)

        weights_list[userid] = client_weights
        print("User" + str(userid) + "'s Round " + str(r + 1) +  " is done")
        with lock:
            weight_count += 1
            if weight_count == num_users:
                aggr_start = time.time()
                #average
                last_layer_list = average_weights(weights_list, datasetsize) # find the average last layer weights
                aggr_end = time.time()
                total_aggr_time += (aggr_end - aggr_start)
                #tracking the global model performance per round
                copy_lll = copy.deepcopy(last_layer_list)
                sq_model = get_global_model(copy_lll)
                train_acc, train_loss = calculate_performance(sq_model, dataloaders['train'], criterion)
                val_acc, val_loss = calculate_performance(sq_model, dataloaders['test'], criterion)
                train_acc_list.append(train_acc)
                val_acc_list.append(val_acc)
                train_loss_list.append(train_loss)
                val_loss_list.append(val_loss)
                
        
       

In [18]:
host = socket.gethostbyname(socket.gethostname())
port = 10080
print(host)

192.168.0.144


In [19]:
s = socket.socket()
s.bind((host, port))
s.listen(5)

### Open the server socket

In [20]:
run_thread(receive, users)

Conntected with ('192.168.0.160', 39624)
Conntected with ('192.168.0.114', 53542)
Conntected with ('192.168.0.148', 41548)
Conntected with ('192.168.0.137', 36536)
timmer start!
User2's Round 1 is doneUser3's Round 1 is done

User1's Round 1 is done
User0's Round 1 is done
User3's Round 2 is done
User2's Round 2 is done
User1's Round 2 is done
User0's Round 2 is done
User3's Round 3 is done
User2's Round 3 is done
User1's Round 3 is done
User0's Round 3 is done
User3's Round 4 is done
User2's Round 4 is done
User1's Round 4 is done
User0's Round 4 is done
User2's Round 5 is done
User3's Round 5 is done
User1's Round 5 is done
User0's Round 5 is done
User3's Round 6 is done
User2's Round 6 is done
User1's Round 6 is done
User0's Round 6 is done
User3's Round 7 is done
User2's Round 7 is done
User1's Round 7 is done
User0's Round 7 is done
User2's Round 8 is done
User3's Round 8 is done
User1's Round 8 is done
User0's Round 8 is done
User2's Round 9 is done
User3's Round 9 is done
User1'

In [23]:
end_time = time.time()  # store end time
print("TrainingTime: {} sec".format(end_time - start_time))
print("Total aggrigation time: {} sec".format(total_aggr_time))
print("Total communication time: {} sec".format(total_comm_time))

TrainingTime: 1462.4196712970734 sec
Total aggrigation time: 0.011538267135620117 sec
Total communication time: 1467.0762729644775 sec


## Print all of communication overhead

In [24]:
# def commmunication_overhead():  
print('\n')
print('---total_sendsize_list---')
total_size = 0
for size in total_sendsize_list:
#     print(size)
    total_size += size
print("total_sendsize size: {} bytes".format(total_size))
print("number of total_send: ", len(total_sendsize_list))
print('\n')

print('---total_receivesize_list---')
total_size = 0
for size in total_receivesize_list:
#     print(size)
    total_size += size
print("total receive sizes: {} bytes".format(total_size) )
print("number of total receive: ", len(total_receivesize_list) )
print('\n')

for i in range(users):
    print('---client_sendsize_list(user{})---'.format(i))
    total_size = 0
    for size in client_sendsize_list[i]:
#         print(size)
        total_size += size
    print("total client_sendsizes(user{}): {} bytes".format(i, total_size))
    print("number of client_send(user{}): ".format(i), len(client_sendsize_list[i]))
    print('\n')

    print('---client_receivesize_list(user{})---'.format(i))
    total_size = 0
    for size in client_receivesize_list[i]:
#         print(size)
        total_size += size
    print("total client_receive sizes(user{}): {} bytes".format(i, total_size))
    print("number of client_send(user{}): ".format(i), len(client_receivesize_list[i]))
    print('\n')

print('---train_sendsize_list---')
total_size = 0
for size in train_sendsize_list:
#     print(size)
    total_size += size
print("total train_sendsizes: {} bytes".format(total_size))
print("number of train_send: ", len(train_sendsize_list) )
print('\n')

print('---train_receivesize_list---')
total_size = 0
for size in train_receivesize_list:
#     print(size)
    total_size += size
print("total train_receivesizes: {} bytes".format(total_size))
print("number of train_receive: ", len(train_receivesize_list) )
print('\n')




---total_sendsize_list---
total_sendsize size: 192696 bytes
number of total_send:  44


---total_receivesize_list---
total receive sizes: 191420 bytes
number of total receive:  44


---client_sendsize_list(user0)---
total client_sendsizes(user0): 48174 bytes
number of client_send(user0):  11


---client_receivesize_list(user0)---
total client_receive sizes(user0): 47855 bytes
number of client_send(user0):  11


---client_sendsize_list(user1)---
total client_sendsizes(user1): 48174 bytes
number of client_send(user1):  11


---client_receivesize_list(user1)---
total client_receive sizes(user1): 47855 bytes
number of client_send(user1):  11


---client_sendsize_list(user2)---
total client_sendsizes(user2): 48174 bytes
number of client_send(user2):  11


---client_receivesize_list(user2)---
total client_receive sizes(user2): 47855 bytes
number of client_send(user2):  11


---client_sendsize_list(user3)---
total client_sendsizes(user3): 48174 bytes
number of client_send(user3):  11


---c

## Accuracy of train and each of classes

In [25]:
# Evaluation on the train set
with torch.no_grad():
    train_accuracy, train_loss = calculate_performance(sq_model, dataloaders['train'], criterion, True)
    print("Train Accuracy: {:.2f}%, Train Loss: {:.4f}".format(train_accuracy * 100, train_loss))

# Evaluation on the test set
with torch.no_grad():
    test_accuracy, test_loss = calculate_performance(sq_model, dataloaders['test'], criterion, True)
    print("Test Accuracy: {:.2f}%, Test Loss: {:.4f}".format(test_accuracy * 100, test_loss))


# Save the trained model
PATH = './'+dataset_name+'_fd_SqueezeNet.pth'
torch.save(sq_model.state_dict(), PATH)

Accuracy: 0.7941
Precision: 0.7440
Recall: 0.6822
F1: 0.7006
Confusion Matrix: 
tensor([[ 4,  5],
        [ 2, 23]])
Train Accuracy: 82.23%, Train Loss: 0.3908
Accuracy: 0.8590
Precision: 0.8583
Recall: 0.7682
F1: 0.7974
Confusion Matrix: 
tensor([[ 24,  18],
        [  4, 110]])
Test Accuracy: 85.90%, Test Loss: 0.3344


In [26]:
import csv

def save_list_to_csv(data, filename):
  """Saves all values in a list to a CSV file.

  Args:
      data: The list containing the values to be saved.
      filename: The name of the CSV file to create.
  """

  # Open the CSV file in write mode with proper newline handling
  with open(filename, 'w', newline='') as csvfile:
    csv_writer = csv.writer(csvfile)

    # Write each value in the list to a separate row
    for item in data:
      csv_writer.writerow([item])  # Wrap in a list for proper formatting


# Save the list to a CSV file
save_list_to_csv(train_acc_list, "train_acc.csv")
save_list_to_csv(val_acc_list, "val_acc.csv")
save_list_to_csv(train_loss_list, "train_loss.csv")
save_list_to_csv(val_loss_list, "val_loss.csv")
print("List saved successfully to data.csv")

List saved successfully to data.csv
