# Server Side for Multi Clients

## Imports

In [1]:
import os
import time
import sys

from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from torch.utils.data import Subset
from torchvision.datasets import MNIST, SVHN, USPS

from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import copy

In [2]:
from src.lib import *

## Set CUDA

In [3]:
# Select the first CUDA device when one is available, otherwise use the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Seed the default generator and, when running on the GPU, every CUDA generator
# with the same fixed value.
torch.manual_seed(777)
if device == "cuda:0":
    torch.cuda.manual_seed_all(777)

## Setting variables

In [4]:
# Names that are combined into the checkpoint file names.
model_name, dataset_name = "mobilenet", "digit5"

# Root folder for the downloaded datasets and folder for the saved weights.
data_path, asset_path = './models/', './assets/'

# One label string per digit class.
classes = tuple(str(digit) for digit in range(10))

In [5]:
nuser = 4  # number of users
epochs = 10  # passes over the training data

# Total number of training samples and the equal integer share of each user.
train_datasize_total = 32000
train_datasize_per_client = train_datasize_total // nuser

### network setting

In [6]:
# Address and port handed to the Server object.
# Alternative host: socket.gethostbyname(socket.gethostname()) (needs `import socket`).
host, port = 'localhost', 10085

## Define model

In [7]:
# Server-side model: receives the client's intermediate tensor during training
# and is moved to the selected device right away.
server_model = ServerMobileNet().to(device)
print(server_model)
# Client-side model: not moved to `device` here; its state_dict is later copied
# and sent to the connected client as the initial front-model weights.
# NOTE(review): construction order presumably affects the seeded initial weights —
# confirm before reordering these statements.
client_model = ClientMobileNet()

ServerMobileNet(
  (layer1): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
  )
  (classifer): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=1024, out_features=10, bias=True)
  )
  (feature): Sequential(
    (0): Sequential(
      (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
    )
    (1): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=64)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stat

### Set other hyperparameters in the model
The hyperparameters set here should be the same as those used on the client side.

In [8]:
# Learning rate shared by the optimizer below.
lr = 0.001

# Classification loss and SGD with momentum over the server-side parameters only.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=server_model.parameters(), momentum=0.9, lr=lr)

# Independent snapshot of the client-side weights (sent to the client later).
client_weights = copy.deepcopy(client_model.state_dict())

## Allocate Server

In [9]:
# Open the server for a single connection and wait for the client to connect.
server = Server(host, port, 1)
clients = server.accept_clients()

# Exchange the training setup; one batch count is returned per connected client.
client_batchsizes = server.training_prep(epochs)

# Only the first (and only) connection is used from here on.
client, client_batchsize = clients[0], client_batchsizes[0]

Conected with ('127.0.0.1', 34176)


## 1. Training phase

In [10]:
# Create the checkpoint directory up front so that a missing folder cannot make
# torch.save fail only after the whole training run has finished.
os.makedirs(asset_path, exist_ok=True)

# Explicitly select training mode: the model may have been left in eval mode
# if this cell is re-run after an evaluation pass.
server_model.train()

# send the initial client-side weights to the connected client
datasize = send_msg(client, client_weights)

start_time = time.time()    # store start time
print("training start!")

for e in range(epochs):

    # every user is served over the same `client` connection, one after another
    for cidx in range(nuser):

        for i in tqdm(range(client_batchsize), ncols=100, desc='Epoch {} Client{} '.format(e+1, cidx)):
            optimizer.zero_grad()  # initialize all gradients to zero

            # receive results of front-model from client
            msg, msglen = recv_msg(client)

            # x(client) --> intermediate tensor --> (server) --> score, our label
            client_output_cpu = msg['client_output']  # intermediate tensor from client
            label = msg['label']  # true label

            client_output = client_output_cpu.to(device)
            label = label.clone().detach().long().to(device)

            # insert the tensor into back model, obtain loss, back propagate
            output = server_model(client_output)
            loss = criterion(output, label)
            loss.backward()

            # return the gradient w.r.t. the received tensor to the client
            # NOTE(review): relies on the received tensor having requires_grad=True,
            # otherwise .grad is None here — confirm on the client side.
            msg = client_output_cpu.grad.clone().detach()  # copy tensor
            msglen = send_msg(client, msg)

            # the reply is sent before the server-side weights are updated
            optimizer.step()


elapsed_time = time.time() - start_time
print("elapsed time for training using", device ,": {} sec".format(elapsed_time))

# save SERVER weights:
model_path = asset_path + model_name + '_' + dataset_name + '_server.pth'
torch.save(server_model.state_dict(), model_path)

# retrieve CLIENT weights: one message per user, all read from the same connection
front_models = []
for _ in range(nuser):
    msg, msglen = recv_msg(client)
    front_models.append(msg)


training start!


Epoch 1 Client0 : 100%|███████████████████████████████████████████| 500/500 [00:11<00:00, 44.22it/s]
Epoch 1 Client1 : 100%|███████████████████████████████████████████| 500/500 [00:10<00:00, 45.74it/s]
Epoch 1 Client2 : 100%|███████████████████████████████████████████| 500/500 [00:10<00:00, 46.09it/s]
Epoch 1 Client3 : 100%|███████████████████████████████████████████| 500/500 [00:10<00:00, 46.04it/s]
Epoch 2 Client0 : 100%|███████████████████████████████████████████| 500/500 [00:10<00:00, 45.93it/s]
Epoch 2 Client1 : 100%|███████████████████████████████████████████| 500/500 [00:11<00:00, 44.83it/s]
Epoch 2 Client2 : 100%|███████████████████████████████████████████| 500/500 [00:11<00:00, 45.43it/s]
Epoch 2 Client3 : 100%|███████████████████████████████████████████| 500/500 [00:10<00:00, 45.53it/s]
Epoch 3 Client0 : 100%|███████████████████████████████████████████| 500/500 [00:10<00:00, 45.60it/s]
Epoch 3 Client1 : 100%|███████████████████████████████████████████| 500/500 [00:10<00:00, 4

elapsed time for training using cuda:0 : 454.00772619247437 sec


In [11]:
# save front model of each client:
for cidx, front_model in enumerate(front_models):
    model_path = asset_path + model_name + '_' + dataset_name + '_c' + str(cidx) + '.pth'
    torch.save(front_model, model_path)

## 2. Test phase

In [12]:
# Every image is converted to a tensor, normalised per channel and resized to 32x32.
# Colour datasets (SVHN, MNIST-M, SYNTH) use `transform`; grayscale datasets
# (MNIST, USPS) use `transform_to_rgb`, which first replicates the single channel.
_NORM_MEAN = (0.4914, 0.4822, 0.4465)
_NORM_STD = (0.2470, 0.2435, 0.2616)
_IMAGE_SIZE = (32, 32)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
    transforms.Resize(_IMAGE_SIZE),
])

transform_to_rgb = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
    transforms.Resize(_IMAGE_SIZE),
])

In [13]:
def _data_root(folder):
    """Return the storage folder of one dataset below `data_path`."""
    return data_path + folder


# Test split of each dataset; grayscale sources get the RGB-expanding pipeline.
test_mnist = MNIST(root=_data_root('mnist_data'), train=False, transform=transform_to_rgb, download=True)
test_svhn = SVHN(root=_data_root('svhn_data'), split='test', transform=transform, download=True)
test_mnistm = MNISTM(root=_data_root('mnist-m_data'), train=False, transform=transform, download=True)
test_synth = SyntheticDigits(root=_data_root('synthdigit_data'), train=False, transform=transform, download=True)
test_usps = USPS(root=_data_root('usps_data'), train=False, transform=transform_to_rgb, download=True)

# Index in this list: 0 MNIST, 1 SVHN, 2 MNIST-M, 3 SYNTH, 4 USPS.
test_datasets = [test_mnist, test_svhn, test_mnistm, test_synth, test_usps]

Using downloaded and verified file: ./models/svhn_data/test_32x32.mat
./models/mnist-m_data/MNISTM/processed/mnist_m_test.pt
./models/synthdigit_data/SyntheticDigits/processed/synth_test.pt


In [14]:
# Evaluate on the first `test_size` samples of every dataset, in fixed order.
test_size = 1000
indices = list(range(test_size))
test_loaders = [
    DataLoader(Subset(data_set, indices), batch_size=16, shuffle=False, num_workers=2)
    for data_set in test_datasets
]

In [15]:
# First line: full size of every test set; second line: batch count of each subset loader.
for sized_items in (test_datasets, test_loaders):
    print(list(map(len, sized_items)))

[10000, 26032, 10000, 9553, 2007]
[63, 63, 63, 63, 63]


In [18]:
# pick one dataset
test_loader = test_loaders[4]
test_total_batch = len(test_loader)
print("total_batch: ", test_total_batch)
x_test, y_test = next(iter(test_loader))
print("input tensor shape: ", x_test.size())
print("output label shape: ", y_test.size())

total_batch:  63
input tensor shape:  torch.Size([16, 3, 32, 32])
output label shape:  torch.Size([16])


## Calculate Accuracy after training

### Test accuracy

In [19]:
# test accuracy per each front model
for cidx, front_model in enumerate(front_models):
    client_model.load_state_dict(front_model)
    client_model.to(device)
    class_correct = list(0. for i in range(len(classes)))
    class_total = list(0. for i in range(len(classes)))
    with torch.no_grad():
        for _, data in enumerate(tqdm(test_loader, ncols=100, desc=('front model %d test' % cidx))):
            x, labels = data
            x = x.to(device)
            labels = labels.to(device)

            outputs = client_model(x)
            outputs = server_model(outputs)
            labels = labels.long()
            _, predicted = torch.max(outputs, 1)
            c = (predicted == labels).squeeze()
            for i in range(len(labels)):
                label = labels[i]
                class_correct[label] += c[i].item()
                class_total[label] += 1
                
    # print accuracy per class
    for i in range(10):
        print('\tAccuracy of %5s : %2d %%' % (
            classes[i], 100 * class_correct[i] / class_total[i]))
    
    # print total accuracy
    print('\tTotal Accuracy: %2d %%' % (100 * sum(class_correct) / sum(class_total)))


front model 0 test: 100%|██████████████████████████████████████████| 63/63 [00:00<00:00, 116.15it/s]


	Accuracy of     0 : 77 %
	Accuracy of     1 : 91 %
	Accuracy of     2 : 80 %
	Accuracy of     3 : 93 %
	Accuracy of     4 : 73 %
	Accuracy of     5 : 90 %
	Accuracy of     6 : 91 %
	Accuracy of     7 : 90 %
	Accuracy of     8 : 71 %
	Accuracy of     9 : 62 %
	Total Accuracy: 81 %


front model 1 test: 100%|██████████████████████████████████████████| 63/63 [00:00<00:00, 120.22it/s]


	Accuracy of     0 : 77 %
	Accuracy of     1 : 92 %
	Accuracy of     2 : 79 %
	Accuracy of     3 : 91 %
	Accuracy of     4 : 74 %
	Accuracy of     5 : 90 %
	Accuracy of     6 : 90 %
	Accuracy of     7 : 90 %
	Accuracy of     8 : 75 %
	Accuracy of     9 : 61 %
	Total Accuracy: 81 %


front model 2 test: 100%|██████████████████████████████████████████| 63/63 [00:00<00:00, 122.16it/s]


	Accuracy of     0 : 76 %
	Accuracy of     1 : 92 %
	Accuracy of     2 : 78 %
	Accuracy of     3 : 91 %
	Accuracy of     4 : 73 %
	Accuracy of     5 : 92 %
	Accuracy of     6 : 90 %
	Accuracy of     7 : 91 %
	Accuracy of     8 : 72 %
	Accuracy of     9 : 60 %
	Total Accuracy: 80 %


front model 3 test: 100%|██████████████████████████████████████████| 63/63 [00:00<00:00, 114.41it/s]

	Accuracy of     0 : 72 %
	Accuracy of     1 : 93 %
	Accuracy of     2 : 77 %
	Accuracy of     3 : 87 %
	Accuracy of     4 : 72 %
	Accuracy of     5 : 84 %
	Accuracy of     6 : 93 %
	Accuracy of     7 : 93 %
	Accuracy of     8 : 72 %
	Accuracy of     9 : 73 %
	Total Accuracy: 80 %



