# CIFAR10 Split Mobilenet Server Side
This code is the server part of CIFAR10 split 2D-CNN model for **multi** client and a server.

## Setting variables

In [1]:
epochs = 1
users = 1 # number of users

## Import required packages

In [2]:
import os
import h5py

import socket
import struct
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim

import time
import sys



from tqdm import tqdm


from torch.utils.data import Dataset, DataLoader

import copy

## Set CUDA

In [3]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"

torch.manual_seed(777)
if device =="cuda:0":
    torch.cuda.manual_seed_all(777)

## Define CIFAR10 server model(mobilenet)
Server side has **2 convolutional layer**


In [4]:
# -*- coding: utf-8 -*-
"""
Created on Thu Nov  1 14:23:31 2018
@author: tshzzz
"""

def conv_dw_server(inplane,outplane,stride=1):
    return nn.Sequential(
        nn.Conv2d(inplane,outplane,kernel_size = 1,groups = 1,stride=1),
        nn.BatchNorm2d(outplane),
        nn.ReLU()    
        )

def conv_dw(inplane,outplane,stride=1):
    return nn.Sequential(
        nn.Conv2d(inplane,inplane,kernel_size = 3,groups = inplane,stride=stride,padding=1),
        nn.BatchNorm2d(inplane),
        nn.ReLU(),
        nn.Conv2d(inplane,outplane,kernel_size = 1,groups = 1,stride=1),
        nn.BatchNorm2d(outplane),
        nn.ReLU()    
        )


def conv_bw(inplane,outplane,kernel_size = 3,stride=1):
    return nn.Sequential(
        nn.Conv2d(inplane,outplane,kernel_size = kernel_size,groups = 1,stride=stride,padding=1),
        nn.BatchNorm2d(outplane),
        nn.ReLU() 
        )


class MobileNet(nn.Module):
    
    def __init__(self,num_class=10):
        super(MobileNet,self).__init__()
        
        layers = []
#         layers.append(conv_bw(3,32,3,1))
        layers.append(conv_dw_server(32,64,1))
        layers.append(conv_dw(64,128,2))
        layers.append(conv_dw(128,128,1))
        layers.append(conv_dw(128,256,2))
        layers.append(conv_dw(256,256,1))
        layers.append(conv_dw(256,512,2))

        for i in range(5):
            layers.append(conv_dw(512,512,1))
        layers.append(conv_dw(512,1024,2))
        layers.append(conv_dw(1024,1024,1))

        self.classifer = nn.Sequential(
                nn.Dropout(0.5),
                nn.Linear(1024,num_class)
                )
        self.feature = nn.Sequential(*layers)
        
        

    def forward(self,x):
        out = self.feature(x)
        out = out.mean(3).mean(2)
        out = out.view(-1,1024)
        out = self.classifer(out)
        return out



In [5]:
mobilenet_server = MobileNet().to(device)
print(mobilenet_server)

ResNet(
  (layer1): Sequential(
    (0): BasicBlock_server(
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (2): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1),

In [6]:
# from torchsummary import summary

# summary(mobilenet_server, (32, 32, 32))

### CIFAR10 Client model for calculating Accuracy

In [7]:
# -*- coding: utf-8 -*-
"""
Created on Thu Nov  1 14:23:31 2018
@author: tshzzz
"""

import torch
import torch.nn as nn


def conv_dw_client(inplane,stride=1):
    return nn.Sequential(
        nn.Conv2d(inplane,inplane,kernel_size = 3,groups = inplane,stride=stride,padding=1),
        nn.BatchNorm2d(inplane),
        nn.ReLU()  
        )

def conv_bw(inplane,outplane,kernel_size = 3,stride=1):
    return nn.Sequential(
        nn.Conv2d(inplane,outplane,kernel_size = kernel_size,groups = 1,stride=stride,padding=1),
        nn.BatchNorm2d(outplane),
        nn.ReLU() 
        )


class MobileNet(nn.Module):
    
    def __init__(self,num_class=10):
        super(MobileNet,self).__init__()
        
        layers = []
        layers.append(conv_bw(3,32,3,1))
        layers.append(conv_dw_client(32,1))
#         layers.append(conv_dw(64,128,2))
#         layers.append(conv_dw(128,128,1))
#         layers.append(conv_dw(128,256,2))
#         layers.append(conv_dw(256,256,1))
#         layers.append(conv_dw(256,512,2))

#         for i in range(5):
#             layers.append(conv_dw(512,512,1))
#         layers.append(conv_dw(512,1024,2))
#         layers.append(conv_dw(1024,1024,1))

#         self.classifer = nn.Sequential(
#                 nn.Dropout(0.5),
#                 nn.Linear(1024,num_class)
#                 )
        self.feature = nn.Sequential(*layers)
        
        

    def forward(self,x):
        out = self.feature(x)
#         out = out.mean(3).mean(2)
#         out = out.view(-1,1024)
#         out = self.classifer(out)
        return out



In [8]:
mobilenet_client = MobileNet().to(device)
print(mobilenet_client)

ResNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
)


In [9]:
# from torchsummary import summary

# summary(mobilenet_client, (3, 32, 32))

### Set other hyperparameters in the model
Hyperparameters here should be same with the client side.

In [10]:
criterion = nn.CrossEntropyLoss()
lr = 0.001
optimizer = optim.SGD(mobilenet_server.parameters(), lr=lr, momentum=0.9)

clientsoclist = []
train_total_batch = []
val_acc = []
client_weights = copy.deepcopy(mobilenet_client.state_dict())

In [11]:
total_sendsize_list = []
total_receivesize_list = []

client_sendsize_list = [[] for i in range(users)]
client_receivesize_list = [[] for i in range(users)]

train_sendsize_list = [] 
train_receivesize_list = []

## Socket initialization
### Set host address and port number

### Required socket functions

In [12]:
def send_msg(sock, msg):
    # prefix each message with a 4-byte length in network byte order
    msg = pickle.dumps(msg)
    l_send = len(msg)
    msg = struct.pack('>I', l_send) + msg
    sock.sendall(msg)
    return l_send

def recv_msg(sock):
    # read message length and unpack it into an integer
    raw_msglen = recvall(sock, 4)
    if not raw_msglen:
        return None
    msglen = struct.unpack('>I', raw_msglen)[0]
    # read the message data
    msg =  recvall(sock, msglen)
    msg = pickle.loads(msg)
    return msg, msglen

def recvall(sock, n):
    # helper function to receive n bytes or return None if EOF is hit
    data = b''
    while len(data) < n:
        packet = sock.recv(n - len(data))
        if not packet:
            return None
        data += packet
    return data


In [13]:
# host = socket.gethostbyname(socket.gethostname())
host = 'localhost'
port = 10080
print(host)

localhost


### Open the server socket

In [14]:
s = socket.socket()
s.bind((host, port))
s.listen(5)

## Receive client

In [15]:
for i in range(users):
    conn, addr = s.accept()
    print('Conntected with', addr)
    clientsoclist.append(conn)    # append client socket on list

    datasize = send_msg(conn, epochs)    #send epoch
    total_sendsize_list.append(datasize)
    client_sendsize_list[i].append(datasize)

    total_batch, datasize = recv_msg(conn)    # get total_batch of train dataset
    total_receivesize_list.append(datasize)
    client_receivesize_list[i].append(datasize)

    train_total_batch.append(total_batch)    # append on list

Conntected with ('127.0.0.1', 7500)


## traning

In [16]:
start_time = time.time()    # store start time
print("timmer start!")

for e in range(epochs):

    # train client 0

    for user in range(users):

        datasize = send_msg(clientsoclist[user], client_weights)
        total_sendsize_list.append(datasize)
        client_sendsize_list[user].append(datasize)
        train_sendsize_list.append(datasize)

        for i in tqdm(range(train_total_batch[user]), ncols=100, desc='Epoch {} Client{} '.format(e+1, user)):
            optimizer.zero_grad()  # initialize all gradients to zero

            msg, datasize = recv_msg(clientsoclist[user])  # receive client message from socket
            total_receivesize_list.append(datasize)
            client_receivesize_list[user].append(datasize)
            train_receivesize_list.append(datasize)

            client_output_cpu = msg['client_output']  # client output tensor
            label = msg['label']  # label

            client_output = client_output_cpu.to(device)
            label = label.clone().detach().long().to(device)

            output = mobilenet_server(client_output)  # forward propagation
            loss = criterion(output, label)  # calculates cross-entropy loss
            loss.backward()  # backward propagation
            msg = client_output_cpu.grad.clone().detach()

            datasize = send_msg(clientsoclist[user], msg)
            total_sendsize_list.append(datasize)
            client_sendsize_list[user].append(datasize)
            train_sendsize_list.append(datasize)
            
            optimizer.step()
            
        client_weights, datasize = recv_msg(clientsoclist[user])
        total_receivesize_list.append(datasize)
        client_receivesize_list[user].append(datasize)
        train_receivesize_list.append(datasize)
        
print('train is done')

end_time = time.time()  # store end time
print("TrainingTime: {} sec".format(end_time - start_time))

# Let's quickly save our trained model:
PATH = './cifar10_sp_mobilenet_server.pth'
torch.save(mobilenet_server.state_dict(), PATH)


Epoch 1 Client0 :   0%|                                                      | 0/25 [00:00<?, ?it/s]

timmer start!


Epoch 1 Client0 : 100%|█████████████████████████████████████████████| 25/25 [00:10<00:00,  2.39it/s]


train is done
TrainingTime: 10.692419290542603 sec


## Print commmunication overheads 

In [17]:
print('---total_sendsize_list---')
total_size = 0
for size in total_sendsize_list:
#     print(size)
    total_size += size
print("total_sendsize size: {} bytes".format(total_size))
print("number of total_send: ", len(total_sendsize_list))
print('\n')

print('---total_receivesize_list---')
total_size = 0
for size in total_receivesize_list:
#     print(size)
    total_size += size
print("total receive sizes: {} bytes".format(total_size) )
print("number of total receive: ", len(total_receivesize_list) )
print('\n')

for i in range(users):
    print('---client_sendsize_list(user{})---'.format(i))
    total_size = 0
    for size in client_sendsize_list[i]:
#         print(size)
        total_size += size
    print("total client_sendsizes(user{}): {} bytes".format(i, total_size))
    print("number of client_send(user{}): ".format(i), len(client_sendsize_list[i]))
    print('\n')

    print('---client_receivesize_list(user{})---'.format(i))
    total_size = 0
    for size in client_receivesize_list[i]:
#         print(size)
        total_size += size
    print("total client_receive sizes(user{}): {} bytes".format(i, total_size))
    print("number of client_send(user{}): ".format(i), len(client_receivesize_list[i]))
    print('\n')

print('---train_sendsize_list---')
total_size = 0
for size in train_sendsize_list:
#     print(size)
    total_size += size
print("total train_sendsizes: {} bytes".format(total_size))
print("number of train_send: ", len(train_sendsize_list) )
print('\n')

print('---train_receivesize_list---')
total_size = 0
for size in train_receivesize_list:
#     print(size)
    total_size += size
print("total train_receivesizes: {} bytes".format(total_size))
print("number of train_receive: ", len(train_receivesize_list) )
print('\n')

---total_sendsize_list---
total_sendsize size: 6579294 bytes
number of total_send:  27


---total_receivesize_list---
total receive sizes: 6588294 bytes
number of total receive:  27


---client_sendsize_list(user0)---
total client_sendsizes(user0): 6579294 bytes
number of client_send(user0):  27


---client_receivesize_list(user0)---
total client_receive sizes(user0): 6588294 bytes
number of client_send(user0):  27


---train_sendsize_list---
total train_sendsizes: 6579289 bytes
number of train_send:  26


---train_receivesize_list---
total train_receivesizes: 6588289 bytes
number of train_receive:  26




# Validation after trainning

In [18]:
root_path = '../../models/cifar10_data'

## datasets

In [19]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])

## Make train and test dataset batch generator

In [20]:
trainset = torchvision.datasets.CIFAR10 (root=root_path, train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10 (root=root_path, train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [21]:
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [22]:
x_train, y_train = next(iter(train_loader))
print(x_train.size())
print(y_train.size())

torch.Size([4, 3, 32, 32])
torch.Size([4])


In [23]:
train_total_batch = len(train_loader)
print(train_total_batch)
test_total_batch = len(test_loader)
print(test_total_batch)

12500
2500


## Calculate Accuracy after training

### before acc make full model

In [27]:
mobilenet_client.load_state_dict(client_weights)
mobilenet_client.to(device)

# Let's quickly save our trained model:
PATH = './cifar10_sp_mobilenet_client.pth'
torch.save(client_weights, PATH)


### train acc

In [28]:
# train acc
with torch.no_grad():
    corr_num = 0
    total_num = 0
    train_loss = 0.0
    for j, trn in enumerate(train_loader):
        trn_x, trn_label = trn
        trn_x = trn_x.to(device)
        trn_label = trn_label.to(device)


        trn_output = mobilenet_client(trn_x)
        trn_label = trn_label.long()
        
        trn_output = mobilenet_server(trn_output)
        
        loss = criterion(trn_output, trn_label)
        train_loss += loss.item()
        model_label = trn_output.argmax(dim=1)
        corr = trn_label[trn_label == model_label].size(0)
        corr_num += corr
        total_num += trn_label.size(0)
    print("train_acc: {:.2f}%, train_loss: {:.4f}".format(corr_num / total_num * 100, train_loss / len(train_loader)))


train_acc: 12.75%, train_loss: 2.4140


### test acc

In [29]:
with torch.no_grad():
    corr_num = 0
    total_num = 0
    val_loss = 0.0
    for j, val in enumerate(test_loader):
        val_x, val_label = val
        val_x = val_x.to(device)
        val_label = val_label.to(device)

        val_output = mobilenet_client(val_x)
        val_label = val_label.long()
        
        val_output = mobilenet_server(val_output)
        
        loss = criterion(val_output, val_label)
        val_loss += loss.item()
        model_label = val_output.argmax(dim=1)
        corr = val_label[val_label == model_label].size(0)
        corr_num += corr
        total_num += val_label.size(0)
    print("test_acc: {:.2f}%, test_loss: {:.4f}".format(corr_num / total_num * 100, val_loss / len(test_loader)))


test_acc: 12.58%, test_loss: 2.4114


### acc of each acc 

In [30]:
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))

with torch.no_grad():
    for data in test_loader:
        x, labels = data
        x = x.to(device)
        labels = labels.to(device)

        outputs = mobilenet_client(x)
        outputs = mobilenet_server(outputs)
        labels = labels.long()
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels).squeeze()
        for i in range(len(labels)):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1


for i in range(10):
    print('Accuracy of %5s : %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i]))


end_time = time.time()  # store end time
print("WorkingTime: {} sec".format(end_time - start_time))

Accuracy of plane :  0 %
Accuracy of   car : 64 %
Accuracy of  bird :  0 %
Accuracy of   cat : 33 %
Accuracy of  deer : 18 %
Accuracy of   dog :  0 %
Accuracy of  frog :  0 %
Accuracy of horse :  9 %
Accuracy of  ship :  0 %
Accuracy of truck :  0 %
WorkingTime: 267.57432889938354 sec
