# Single Server Side

## Imports

In [1]:
import os
import time
import sys

from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from torch.utils.data import Subset

from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import copy

In [2]:
from src.lib import *

## Set CUDA

In [3]:
# Use the first GPU when one is visible, otherwise fall back to the CPU,
# and seed torch's RNGs so weight initialisation is repeatable.
SEED = 777
use_cuda = torch.cuda.is_available()
device = "cuda:0" if use_cuda else "cpu"
torch.manual_seed(SEED)
if use_cuda:
    torch.cuda.manual_seed_all(SEED)
print(device)

cuda:0


## Setting variables

In [4]:
# Experiment configuration (presumably must agree with the client notebook).
model_name = "squeezenet"  # alternative: "mobilenet"
dataset_name = "cifar10"
data_path = './models/cifar10_data'
asset_path = './assets/'
# Class names; position i is the name reported for label id i in the test cell.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

In [5]:
# Training schedule and an even split of the training samples across clients.
epochs, nuser = 10, 1  # nuser: number of participating clients
train_datasize_total = 50_000  # total number of training samples
datasize_per_client = divmod(train_datasize_total, nuser)[0]

### Network settings

In [6]:
# Address handed to Server below. 'localhost' presumably restricts it to
# local clients; use socket.gethostbyname(socket.gethostname()) for remote ones.
host, port = 'localhost', 10089
print(host)

localhost


## Define model

In [7]:
# Build the split model pair: the client runs the front part and the server
# runs the back part (the test cell evaluates server_model(client_model(x))).
# The Server*/Client* classes presumably come from src.lib.
if model_name == "mobilenet":
    server_model = ServerMobileNet()
    client_model = ClientMobileNet()
elif model_name == "squeezenet":
    server_model = ServerSqueezeNet(num_classes=len(classes))
    client_model = ClientSqueezeNet(num_classes=len(classes))
else:
    # Fail fast with a clear message instead of an AttributeError on None.to(device).
    raise ValueError(
        f"Unsupported model_name {model_name!r}; expected 'mobilenet' or 'squeezenet'"
    )

server_model = server_model.to(device)
print(server_model)
        

ServerSqueezeNet(
  (layer1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
  )
  (features): Sequential(
    (0): Fire(
      (squeeze): Conv2d(64, 16, kernel_size=(1, 1), stride=(1, 1))
      (squeeze_activation): ReLU(inplace=True)
      (expand1x1): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1))
      (expand1x1_activation): ReLU(inplace=True)
      (expand3x3): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (expand3x3_activation): ReLU(inplace=True)
    )
    (1): Fire(
      (squeeze): Conv2d(128, 16, kernel_size=(1, 1), stride=(1, 1))
      (squeeze_activation): ReLU(inplace=True)
      (expand1x1): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1))
      (expand1x1_activation): ReLU(inplace=True)
      (expand3x3): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (expand3x3_activation): ReLU(inpl

### Set other hyperparameters in the model
These hyperparameters must match the ones used on the client side.

In [8]:
# Loss and optimizer for the server-side part of the model.
lr = 1e-3
momentum = 0.9
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=server_model.parameters(), lr=lr, momentum=momentum)

# Independent snapshot of the freshly initialised client-side weights;
# it is broadcast to every client at the start of the training cell.
client_weights = copy.deepcopy(client_model.state_dict())

## Allocate Server

In [9]:
# Create the server (Server presumably comes from src.lib), accept client
# connections, and run the pre-training exchange. The exact protocol lives in
# the Server class and is not visible here.
server = Server(host, port, nuser)
clients = server.accept_clients()  # iterated later; each item is passed to send_msg/recv_msg
client_batchsizes = server.training_prep(epochs)  # indexed by client; used as batches per epoch in the training loop

Conected with ('127.0.0.1', 54468)


## 1. Training phase

In [10]:
# Split-learning training, server side. For every batch the client sends its
# intermediate activations and labels. The server finishes the forward pass,
# backpropagates, returns d(loss)/d(activations) and updates its own weights.

# Broadcast the initial front-model weights so all clients start identically.
for client in clients:
    send_msg(client, client_weights)

# Enter training mode explicitly. The evaluation cell switches the model to
# eval(), so re-running this cell afterwards would otherwise train with
# dropout/BatchNorm in inference mode.
server_model.train()

start_time = time.time()  # wall-clock start, reported after the loop
print("training start!")

for e in range(epochs):
    # Clients are served one after another, each for a full epoch of batches.
    for cidx, client in enumerate(clients):
        desc = 'Epoch {} Client{} '.format(e + 1, cidx)
        for batch_idx in tqdm(range(client_batchsizes[cidx]), ncols=100, desc=desc):
            optimizer.zero_grad()  # clear server-model gradients from the previous batch

            # Receive the client's front-model output for this batch.
            msg, msglen = recv_msg(client)
            client_output_cpu = msg['client_output']  # intermediate activations from the client
            label = msg['label']                      # ground-truth labels for the batch

            # Make the activations a fresh leaf on `device` so backward() fills its
            # .grad, whether or not the client sent them with requires_grad set.
            client_output = client_output_cpu.to(device).detach().requires_grad_(True)
            label = label.clone().detach().long().to(device)

            # Finish the forward pass on the server part and backpropagate.
            output = server_model(client_output)
            loss = criterion(output, label)
            loss.backward()

            # Return d(loss)/d(activations) to the client as an independent CPU tensor.
            grad_to_client = client_output.grad.detach().to('cpu', copy=True)
            msglen = send_msg(client, grad_to_client)

            optimizer.step()

elapsed_time = time.time() - start_time
print("elapsed time for training using", device, ": {} sec".format(elapsed_time))

# Save the server-side model weights. Create the asset directory first so a
# missing folder does not throw away the result of a long training run.
os.makedirs(asset_path, exist_ok=True)
model_path = asset_path + model_name + '_' + dataset_name + '_server.pth'
torch.save(server_model.state_dict(), model_path)

# Collect each client's trained front-model weights, in client order.
front_models = [recv_msg(client)[0] for client in clients]


training start!


Epoch 1 Client0 : 100%|█████████████████████████████████████████| 3125/3125 [05:15<00:00,  9.92it/s]
Epoch 2 Client0 : 100%|█████████████████████████████████████████| 3125/3125 [05:13<00:00,  9.96it/s]
Epoch 3 Client0 : 100%|█████████████████████████████████████████| 3125/3125 [05:13<00:00,  9.97it/s]
Epoch 4 Client0 : 100%|█████████████████████████████████████████| 3125/3125 [05:31<00:00,  9.42it/s]
Epoch 5 Client0 : 100%|█████████████████████████████████████████| 3125/3125 [05:37<00:00,  9.25it/s]
Epoch 6 Client0 : 100%|█████████████████████████████████████████| 3125/3125 [05:43<00:00,  9.10it/s]
Epoch 7 Client0 : 100%|█████████████████████████████████████████| 3125/3125 [05:33<00:00,  9.37it/s]
Epoch 8 Client0 : 100%|█████████████████████████████████████████| 3125/3125 [05:24<00:00,  9.63it/s]
Epoch 9 Client0 : 100%|█████████████████████████████████████████| 3125/3125 [05:32<00:00,  9.41it/s]
Epoch 10 Client0 : 100%|████████████████████████████████████████| 3125/3125 [05:29<00:00,  

elapsed time for training using cuda:0 : 3274.4826946258545 sec





In [11]:
# Persist every client's trained front-model state_dict to
# <asset_path><model>_<dataset>_c<client index>.pth
for client_idx, state in enumerate(front_models):
    model_path = f"{asset_path}{model_name}_{dataset_name}_c{client_idx}.pth"
    torch.save(state, model_path)

## 3. Test phase

In [13]:
# Test-time preprocessing: ToTensor + Normalize with per-channel mean/std
# (commonly used CIFAR-10 statistics). SqueezeNet variants also upscale to
# 224x224. Resize runs on the already-normalized tensor. The order presumably
# mirrors the client's training transform, so keep the two in sync.
channel_mean = (0.4914, 0.4822, 0.4465)
channel_std = (0.2470, 0.2435, 0.2616)
transform_list = [transforms.ToTensor(), transforms.Normalize(channel_mean, channel_std)]
if model_name.startswith("squeezenet"):
    transform_list += [transforms.Resize((224, 224))]
transform = transforms.Compose(transform_list)

In [14]:
# CIFAR-10 test split (downloaded into data_path if absent), in fixed order
# so evaluation is deterministic.
test_set = torchvision.datasets.CIFAR10(
    root=data_path,
    train=False,
    download=True,
    transform=transform,
)
test_loader = DataLoader(test_set, batch_size=16, shuffle=False, num_workers=2)

Files already downloaded and verified


In [15]:
# Peek at one batch to confirm input resolution and batch size, and report
# how many batches the test loader yields.
x_test, y_test = next(iter(test_loader))
test_total_batch = len(test_loader)
for caption, value in (("input tensor shape: ", x_test.size()),
                       ("output label shape: ", y_test.size()),
                       ("total_batch: ", test_total_batch)):
    print(caption, value)

input tensor shape:  torch.Size([16, 3, 224, 224])
output label shape:  torch.Size([16])
total_batch:  625


## Calculate Accuracy after training

### test acc

In [16]:
# test accuracy per each front model
for cidx, front_model in enumerate(front_models):
    client_model.load_state_dict(front_model)
    client_model.to(device)
    class_correct = list(0. for i in range(len(classes)))
    class_total = list(0. for i in range(len(classes)))
    with torch.no_grad():
        for _, data in enumerate(tqdm(test_loader, ncols=100, desc=('front model %d test' % cidx))):
            x, labels = data
            x = x.to(device)
            labels = labels.to(device)

            outputs = client_model(x)
            outputs = server_model(outputs)
            labels = labels.long()
            _, predicted = torch.max(outputs, 1)
            c = (predicted == labels).squeeze()
            for i in range(len(labels)):
                label = labels[i]
                class_correct[label] += c[i].item()
                class_total[label] += 1
                
    # print accuracy per class
    for i in range(10):
        print('\tAccuracy of %5s : %2d %%' % (
            classes[i], 100 * class_correct[i] / class_total[i]))
    
    # print total accuracy
    print('\tTotal Accuracy: %2d %%' % (100 * sum(class_correct) / sum(class_total)))


front model 0 test: 100%|████████████████████████████████████████| 625/625 [00:06<00:00, 103.29it/s]

	Accuracy of plane : 77 %
	Accuracy of   car : 86 %
	Accuracy of  bird : 46 %
	Accuracy of   cat : 55 %
	Accuracy of  deer : 81 %
	Accuracy of   dog : 42 %
	Accuracy of  frog : 74 %
	Accuracy of horse : 63 %
	Accuracy of  ship : 84 %
	Accuracy of truck : 87 %
	Total Accuracy: 69 %



