# THB Federated Transfer learning SqueezeNet Server Side
This code is the server part of the THB federated SqueezeNet for **multiple** clients and a server.

## Setting variables

In [1]:
# Federated-learning schedule; `rounds` and `local_epoch` are sent to every
# client in the handshake message built by `receive`.
rounds = 10  # number of communication rounds the server loops over in `train`
local_epoch = 1  # forwarded to clients under the 'local_epoch' key
users = 3 # number of clients


In [2]:
import os
import h5py

import socket
import struct
import pickle
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets,transforms,models
from torchvision import datasets
from torch.utils.data import Subset
from sklearn.model_selection import train_test_split

from threading import Thread
from threading import Lock

import tenseal as ts

import time

from tqdm import tqdm

## Loading public context object

In [3]:
# Read the pickled TenSEAL context bytes from disk.
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
with open('../../playground/Developing H.E for FL/public_context.pkl', 'rb') as inp:
    public_context_bin = pickle.load(inp)

In [4]:
# Rebuild the TenSEAL context object; `recv_msg` uses it to deserialize
# incoming CKKS tensors.
public_context = ts.context_from(public_context_bin)

## Cuda

In [5]:
# The model is pinned to the CPU; the CUDA auto-detection line below is kept
# commented out.
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = "cpu"
print(device)

cpu


## Pytorch layer modules for *Conv1D* Network



### `Conv1d` layer
- `torch.nn.Conv1d(in_channels, out_channels, kernel_size)`

### `MaxPool1d` layer
- `torch.nn.MaxPool1d(kernel_size, stride=None)`
- Parameter `stride` follows `kernel_size`.

### `ReLU` layer
- `torch.nn.ReLU()`

### `Linear` layer
- `torch.nn.Linear(in_features, out_features, bias=True)`

### `Softmax` layer
- `torch.nn.Softmax(dim=None)`
- Parameter `dim` is usually set to `1`.

In [6]:
# Build torchvision's SqueezeNet 1.1 and move it to `device`.
# NOTE(review): a boolean `weights` argument looks like the legacy way of
# requesting pretrained weights -- confirm against the installed torchvision.
sq_model = models.squeezenet1_1(weights=True)
sq_model.to(device)



SqueezeNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
    (3): Fire(
      (squeeze): Conv2d(64, 16, kernel_size=(1, 1), stride=(1, 1))
      (squeeze_activation): ReLU(inplace=True)
      (expand1x1): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1))
      (expand1x1_activation): ReLU(inplace=True)
      (expand3x3): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (expand3x3_activation): ReLU(inplace=True)
    )
    (4): Fire(
      (squeeze): Conv2d(128, 16, kernel_size=(1, 1), stride=(1, 1))
      (squeeze_activation): ReLU(inplace=True)
      (expand1x1): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1))
      (expand1x1_activation): ReLU(inplace=True)
      (expand3x3): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (expand3x3_activation): ReLU(inplace=True)
    )
    (5): MaxPool2d

In [7]:
#freezing previous layers
for param in sq_model.features.parameters():
    param.requires_grad = False

In [8]:
# modifying the last layer to match desired output class
num_classes = 3
in_ftrs = sq_model.classifier[1].in_channels
features = list(sq_model.classifier.children())[:-3] # Remove last 3 layers
features.extend([nn.Conv2d(in_ftrs, num_classes, kernel_size=1)]) # Add
features.extend([nn.ReLU(inplace=True)]) # Add
features.extend([nn.AdaptiveAvgPool2d(output_size=(1,1))]) # Add
sq_model.classifier = nn.Sequential(*features)

## variables

In [9]:
import copy

# One slot per client; filled with the accepted socket in `run_thread`.
clientsoclist = [0]*users

start_time = 0  # overwritten with time.time() in `run_thread`
weight_count = 0  # counter shared by the client threads, updated under `lock`

# Parameters exchanged with the clients: weight and bias of classifier layer 1.
last_layer_list = [sq_model.state_dict()['classifier.1.weight'], sq_model.state_dict()['classifier.1.bias']]
# Independent copy of the full model state.
global_weights = copy.deepcopy(sq_model.state_dict())

datasetsize = [0]*users  # per-client dataset size reported during the handshake
weights_list = [0]*users  # per-client list of received last-layer tensors

# Guards `weight_count`, the broadcast step and the averaging step.
lock = Lock()

## Communication overhead

In [10]:
# Byte counts of every message sent / received by the server.
total_sendsize_list = []
total_receivesize_list = []

# Same byte counts, split per client (index = client id).
client_sendsize_list = [[] for i in range(users)]
client_receivesize_list = [[] for i in range(users)]

# Byte counts of the messages exchanged inside the training rounds only.
train_sendsize_list = [] 
train_receivesize_list = []

## Socket initialization
### Set host address and port number

### Required socket functions

In [11]:
def send_msg(sock, msg, serialize=True):
    """Send one length-prefixed message over ``sock``.

    Args:
        sock: connected socket exposing ``sendall``.
        msg: object to send. With ``serialize=True`` it must provide a
            ``serialize()`` method returning bytes; otherwise it is pickled.
        serialize: choose ``msg.serialize()`` (True) or ``pickle.dumps`` (False).

    Returns:
        Payload size in bytes, excluding the 4-byte length header.
    """
    payload = msg.serialize() if serialize else pickle.dumps(msg)
    payload_size = len(payload)
    # 4-byte unsigned big-endian length header, followed by the payload
    header = struct.pack('>I', payload_size)
    sock.sendall(header + payload)
    return payload_size

def recv_msg(sock, deserialize=True):
    """Receive one length-prefixed message from ``sock``.

    The wire format is a 4-byte unsigned big-endian length followed by that
    many payload bytes (the format written by ``send_msg``).

    Args:
        sock: connected socket to read from.
        deserialize: if True the payload is turned into a CKKS tensor with
            ``ts.ckks_tensor_from(public_context, ...)``; if False it is
            unpickled.

    Returns:
        A ``(message, payload_length_in_bytes)`` tuple.

    Raises:
        ConnectionError: if the peer closes the connection before a complete
            header or payload has arrived. Every caller unpacks two values,
            so a bare ``None`` return would only surface as an unrelated
            ``TypeError`` at the call site.
    """
    # read message length and unpack it into an integer
    raw_msglen = recvall(sock, 4)
    if raw_msglen is None:
        raise ConnectionError("connection closed while reading message length")
    msglen = struct.unpack('>I', raw_msglen)[0]
    # read the message data
    msg = recvall(sock, msglen)
    if msg is None:
        raise ConnectionError("connection closed while reading message body")
    if deserialize:
        msg = ts.ckks_tensor_from(public_context, msg)
    else:
        # NOTE: unpickling network data can execute arbitrary code; this is
        # only acceptable when every connecting client is trusted.
        msg = pickle.loads(msg)

    return msg, msglen

def recvall(sock, n):
    """Read exactly ``n`` bytes from ``sock``.

    Returns the bytes read, or ``None`` if the peer closes the connection
    (``recv`` returns an empty result) before ``n`` bytes have arrived.
    """
    chunks = []
    remaining = n
    while remaining > 0:
        chunk = sock.recv(remaining)
        if not chunk:
            # EOF before the full message was received
            return None
        chunks.append(chunk)
        remaining -= len(chunk)
    return b''.join(chunks)

In [12]:
import copy

def average_weights(w, datasize):
    """Return the dataset-size-weighted average of the clients' parameters.

    Args:
        w: one entry per client; each entry is a sequence of parameters
            (same length and order for every client). Parameters only need to
            support ``param * float`` and ``param + param``.
        datasize: one number per client, used as that client's weight.

    Returns:
        A new list with one averaged parameter per position:
        ``sum_k(w[k][i] * datasize[k]) / sum(datasize)``.

    The inputs are not modified. An earlier version scaled ``w`` in place,
    which silently changed the caller's per-client weights.
    """
    inv_total = 1 / float(sum(datasize))

    w_avg = []
    for layer_idx in range(len(w[0])):
        # `*` and `+` build new objects, so the caller's data stays intact
        acc = w[0][layer_idx] * float(datasize[0])
        for client_idx in range(1, len(w)):
            acc = acc + w[client_idx][layer_idx] * float(datasize[client_idx])
        w_avg.append(acc * inv_total)

    return w_avg

## Thread define

## Receive users before training

In [13]:
def run_thread(func, num_user):
    """Accept ``num_user`` clients and run ``func`` for each in its own thread.

    Each accepted socket is stored in ``clientsoclist`` and handed to
    ``func(client_index, num_user, socket)``. Once all threads have been
    started the global ``start_time`` is set; after they have all finished
    the elapsed time is printed.
    """
    global clientsoclist
    global start_time

    workers = []
    for client_idx in range(num_user):
        client_sock, client_addr = s.accept()
        print('Conntected with', client_addr)
        # remember the client socket so other threads can broadcast to it
        clientsoclist[client_idx] = client_sock
        worker = Thread(target=func, args=(client_idx, num_user, client_sock))
        workers.append(worker)
        worker.start()

    print("timmer start!")
    start_time = time.time()    # store start time
    for worker in workers:
        worker.join()
    finished_at = time.time()  # store end time
    print("TrainingTime: {} sec".format(finished_at - start_time))

In [14]:
def receive(userid, num_users, conn): #thread for receive clients
    """Handshake with one client, then run its training rounds.

    Sends the run configuration to the client, receives the client's training
    dataset size, records both message sizes, registers the client under the
    lock, and finally calls ``train`` for this connection.
    """
    global weight_count
    global datasetsize

    handshake = {
        'rounds': rounds,
        'client_id': userid,
        'local_epoch': local_epoch,
        'last_layer_list_len':len(last_layer_list)
    }

    # the configuration is pickled (serialize=False)
    sent_bytes = send_msg(conn, handshake, False)
    total_sendsize_list.append(sent_bytes)
    client_sendsize_list[userid].append(sent_bytes)

    # the reply is unpickled (deserialize=False)
    train_dataset_size, received_bytes = recv_msg(conn, False)
    total_receivesize_list.append(received_bytes)
    client_receivesize_list[userid].append(received_bytes)

    with lock:
        datasetsize[userid] = train_dataset_size
        weight_count += 1

    train(userid, train_dataset_size, num_users, conn)

## Train

In [15]:
def train(userid, train_dataset_size, num_users, client_conn):
    """Run the server side of every round for one client connection.

    In each of the ``rounds`` iterations this thread:
      1. broadcasts ``last_layer_list`` to every socket in ``clientsoclist``
         if it observes ``weight_count == num_users``, then resets the counter;
      2. receives ``len(last_layer_list)`` messages from its own client and
         stores them in ``weights_list[userid]``;
      3. increments ``weight_count``; the thread that brings it up to
         ``num_users`` replaces ``last_layer_list`` with
         ``average_weights(weights_list, datasetsize)``.

    Args:
        userid: index of this client in the per-client lists.
        train_dataset_size: size reported by the client in the handshake;
            not read inside this function.
        num_users: number of clients that must have reported before a
            broadcast or an averaging step takes place.
        client_conn: socket connected to this client.
    """
    global weights_list
    global global_weights
    global last_layer_list
    global weight_count
    global sq_model
    global val_acc
    
    for r in range(rounds):
        with lock:
            # only the thread that sees the full count performs the broadcast
            if weight_count == num_users:
                for i, conn in enumerate(clientsoclist):
                    if r == 0:
                        # first round: the whole list goes out as one pickled message
                        datasize = send_msg(conn, last_layer_list, False) # sending last layer parameters only
                    else:
                        # later rounds: one message per parameter, each sent via param.serialize()
                        datasize = 0
                        for param in last_layer_list:
                            datasize += send_msg(conn, param)
                        
                    total_sendsize_list.append(datasize)
                    client_sendsize_list[i].append(datasize)
                    train_sendsize_list.append(datasize)
                    # reset happens inside the per-client loop, still under the lock
                    weight_count = 0
        
        client_weights = [] # client_weights refers to the last layer weights of the client 
        for i in range(len(last_layer_list)):
            # default deserialize=True: payload is rebuilt as a CKKS tensor
            client_weight, datasize = recv_msg(client_conn) 
            client_weights.append(client_weight)
            
            total_receivesize_list.append(datasize)
            client_receivesize_list[userid].append(datasize)
            train_receivesize_list.append(datasize)
            

        weights_list[userid] = client_weights
        print("User" + str(userid) + "'s Round " + str(r + 1) +  " is done")
        with lock:
            weight_count += 1
            # the last client to report triggers the aggregation
            if weight_count == num_users:
                #average
                last_layer_list = average_weights(weights_list, datasetsize) # find the average last layer weights
                
        
    

In [16]:
# Address the server binds to: the IP that the local hostname resolves to.
host = socket.gethostbyname(socket.gethostname())
port = 10080
print(host)

192.168.0.116


In [17]:
# Listening socket used by `run_thread` to accept clients.
# NOTE(review): the socket is never closed in the visible code, so re-running
# this cell may fail with "address already in use" -- confirm.
s = socket.socket()
s.bind((host, port))
s.listen(5)  # backlog of 5 pending connections

### Open the server socket

In [18]:
# Blocks until `users` clients have connected and all of their threads finish.
run_thread(receive, users)

KeyboardInterrupt: 

In [None]:
# Manual timing relative to the global `start_time` set in `run_thread`.
end_time = time.time()  # store end time
print("TrainingTime: {} sec".format(end_time - start_time))

## Print all of communication overhead

In [None]:
def _print_overhead(title, total_label, count_label, sizes):
    """Print one communication-overhead section.

    Args:
        title: section name, printed between ``---`` markers.
        total_label: text printed before the byte total.
        count_label: text printed before the number of messages.
        sizes: list of per-message sizes in bytes.
    """
    print('---{}---'.format(title))
    print("{}: {} bytes".format(total_label, sum(sizes)))
    print(count_label, len(sizes))
    print('\n')


# Report the byte totals and message counts collected during the run.
print('\n')
_print_overhead('total_sendsize_list', "total_sendsize size",
                "number of total_send: ", total_sendsize_list)
_print_overhead('total_receivesize_list', "total receive sizes",
                "number of total receive: ", total_receivesize_list)

for i in range(users):
    _print_overhead('client_sendsize_list(user{})'.format(i),
                    "total client_sendsizes(user{})".format(i),
                    "number of client_send(user{}): ".format(i),
                    client_sendsize_list[i])
    # The count label used to read "client_send" although it reports the
    # receive list; it now names the list actually being counted.
    _print_overhead('client_receivesize_list(user{})'.format(i),
                    "total client_receive sizes(user{})".format(i),
                    "number of client_receive(user{}): ".format(i),
                    client_receivesize_list[i])

_print_overhead('train_sendsize_list', "total train_sendsizes",
                "number of train_send: ", train_sendsize_list)
_print_overhead('train_receivesize_list', "total train_receivesizes",
                "number of train_receive: ", train_receivesize_list)


In [None]:
# save the global model weights for evaluation

In [None]:
# Dataset root; the 'train' and 'test' sub-folders are joined onto it when the
# ImageFolder datasets are built.
root_path = '../../datasets/THB_splitted'

In [None]:
from torch.utils.data import Dataset, DataLoader
import numpy as np

## Making Batch Generator

In [None]:
sets = ['train', 'test']
# Per-channel normalisation statistics passed to transforms.Normalize.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# Training images are randomly cropped and flipped; test images are resized
# and centre-cropped.
data_transforms = {}
data_transforms['train'] = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])
data_transforms['test'] = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# trainset is image_datasets['train'] and testset is image_datasets['test']
image_datasets = {}
# trainloader is dataloaders['train'] and testloader is dataloaders['test']
dataloaders = {}
# train_total_batch is dataset_sizes['train'] and test_batch is dataset_sizes['test']
dataset_sizes = {}
for split in sets:
    image_datasets[split] = datasets.ImageFolder(
        os.path.join(root_path, split), data_transforms[split]
    )
    dataloaders[split] = torch.utils.data.DataLoader(
        image_datasets[split], batch_size=4, shuffle=True, num_workers=0
    )
    dataset_sizes[split] = len(image_datasets[split])

In [None]:
# Display names used when printing per-class accuracy.
# NOTE(review): assumed to follow the ImageFolder label order -- confirm.
classes = ('Bluetooth', 'Humidity', 'Transistor')

### `DataLoader` for batch generating
`torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)`

In [None]:
# Read the pickled bytes of the second TenSEAL context, the one the secret
# key is taken from for decryption.
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
with open('../../playground/Developing H.E for FL/shared_context.pkl', 'rb') as inp:
    shared_context_bin = pickle.load(inp)

In [None]:
# Rebuild the TenSEAL context object from the loaded bytes.
shared_context = ts.context_from(shared_context_bin)

In [None]:
# Secret key used to decrypt the aggregated last-layer parameters.
sk = shared_context.secret_key()

In [None]:
# Decrypt every aggregated last-layer parameter and convert it to a torch tensor.
decrypted_lll = [
    torch.tensor(encrypted.decrypt(sk).tolist())
    for encrypted in last_layer_list
]
print(decrypted_lll)

In [None]:
# Updating the global weight's last layer
global_weights['classifier.1.weight'] = decrypted_lll[0]
global_weights['classifier.1.bias'] = decrypted_lll[1]

In [None]:
# Load the aggregated weights into the model and switch it to evaluation mode.
sq_model.load_state_dict(global_weights)
sq_model.eval()
sq_model = sq_model.to(device)

lr = 0.01
# NOTE(review): the optimizer is not used again in the visible part of this
# file; only `criterion` is used by the evaluation code.
optimizer = optim.SGD(sq_model.parameters(), lr=lr, momentum=0.9)
criterion = nn.CrossEntropyLoss()

## Accuracy of train and each of classes

In [None]:
# Function to calculate accuracy and loss
def calculate_performance(model, dataloader, criterion):
    """Compute accuracy and sample-weighted mean loss over ``dataloader``.

    Args:
        model: callable mapping a batch of inputs to class scores.
        dataloader: iterable yielding ``(inputs, labels)`` batches.
        criterion: loss function called as ``criterion(outputs, labels)``.

    Returns:
        ``(accuracy, average_loss)``, both divided by the number of labels seen.
    """
    loss_sum = 0.0
    hits = 0
    seen = 0

    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)
        batch_loss = criterion(outputs, labels)
        # weight the batch loss by the batch size
        loss_sum += batch_loss.item() * inputs.size(0)

        # index of the highest score along dim 1 is the predicted class
        predicted = torch.max(outputs, 1)[1]
        hits += (predicted == labels).sum().item()
        seen += labels.size(0)

    return hits / seen, loss_sum / seen

# Evaluation on the train set
with torch.no_grad():
    train_accuracy, train_loss = calculate_performance(sq_model, dataloaders['train'], criterion)
    print("Train Accuracy: {:.2f}%, Train Loss: {:.4f}".format(train_accuracy * 100, train_loss))

# Evaluation on the test set
with torch.no_grad():
    test_accuracy, test_loss = calculate_performance(sq_model, dataloaders['test'], criterion)
    print("Test Accuracy: {:.2f}%, Test Loss: {:.4f}".format(test_accuracy * 100, test_loss))

# Class-wise accuracy
class_correct = [0.0] * 3
class_total = [0.0] * 3

with torch.no_grad():
    for data in dataloaders['test']:
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = sq_model(inputs)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels)

        for i in range(len(labels)):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1

for i in range(3):
    print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))

# Save the trained model
PATH = './THB_fd_SqueezeNet.pth'
torch.save(sq_model.state_dict(), PATH)
