In [1]:
epochs = 1
users = 1 # number of users

In [2]:
import os
import h5py

import socket
import struct
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim

import time
import sys



from tqdm import tqdm


from torch.utils.data import Dataset, DataLoader

import copy

In [3]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"

torch.manual_seed(777)
if device =="cuda:0":
    torch.cuda.manual_seed_all(777)

In [4]:
# -*- coding: utf-8 -*-
"""
Created on Thu Nov  1 14:23:31 2018
@author: tshzzz
"""

def conv_dw_server(inplane,outplane,stride=1):
    return nn.Sequential(
        nn.Conv2d(inplane,outplane,kernel_size = 1,groups = 1,stride=1),
        nn.BatchNorm2d(outplane),
        nn.ReLU()    
        )

def conv_dw(inplane,outplane,stride=1):
    return nn.Sequential(
        nn.Conv2d(inplane,inplane,kernel_size = 3,groups = inplane,stride=stride,padding=1),
        nn.BatchNorm2d(inplane),
        nn.ReLU(),
        nn.Conv2d(inplane,outplane,kernel_size = 1,groups = 1,stride=1),
        nn.BatchNorm2d(outplane),
        nn.ReLU()    
        )


def conv_bw(inplane,outplane,kernel_size = 3,stride=1):
    return nn.Sequential(
        nn.Conv2d(inplane,outplane,kernel_size = kernel_size,groups = 1,stride=stride,padding=1),
        nn.BatchNorm2d(outplane),
        nn.ReLU() 
        )


class MobileNet(nn.Module):
    
    def __init__(self,num_class=10):
        super(MobileNet,self).__init__()
        
        layers = []
#         layers.append(conv_bw(3,32,3,1))
        layers.append(conv_dw_server(32,64,1))
        layers.append(conv_dw(64,128,2))
        layers.append(conv_dw(128,128,1))
        layers.append(conv_dw(128,256,2))
        layers.append(conv_dw(256,256,1))
        layers.append(conv_dw(256,512,2))

        for i in range(5):
            layers.append(conv_dw(512,512,1))
        layers.append(conv_dw(512,1024,2))
        layers.append(conv_dw(1024,1024,1))

        self.classifer = nn.Sequential(
                nn.Dropout(0.5),
                nn.Linear(1024,num_class)
                )
        self.feature = nn.Sequential(*layers)
        
        

    def forward(self,x):
        out = self.feature(x)
        out = out.mean(3).mean(2)
        out = out.view(-1,1024)
        out = self.classifer(out)
        return out

In [5]:
mobilenet_server = MobileNet().to(device)
print(mobilenet_server)

MobileNet(
  (classifer): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=1024, out_features=10, bias=True)
  )
  (feature): Sequential(
    (0): Sequential(
      (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=64)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
    )
    (2): Sequential(
      (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Conv2d

In [6]:
# from torchsummary import summary

# summary(mobilenet_server, (32, 32, 32))

In [7]:
# -*- coding: utf-8 -*-
"""
Created on Thu Nov  1 14:23:31 2018
@author: tshzzz
"""

import torch
import torch.nn as nn


def conv_dw_client(inplane,stride=1):
    return nn.Sequential(
        nn.Conv2d(inplane,inplane,kernel_size = 3,groups = inplane,stride=stride,padding=1),
        nn.BatchNorm2d(inplane),
        nn.ReLU()  
        )

def conv_bw(inplane,outplane,kernel_size = 3,stride=1):
    return nn.Sequential(
        nn.Conv2d(inplane,outplane,kernel_size = kernel_size,groups = 1,stride=stride,padding=1),
        nn.BatchNorm2d(outplane),
        nn.ReLU() 
        )


class MobileNet(nn.Module):
    
    def __init__(self,num_class=10):
        super(MobileNet,self).__init__()
        
        layers = []
        layers.append(conv_bw(3,32,3,1))
        layers.append(conv_dw_client(32,1))
#         layers.append(conv_dw(64,128,2))
#         layers.append(conv_dw(128,128,1))
#         layers.append(conv_dw(128,256,2))
#         layers.append(conv_dw(256,256,1))
#         layers.append(conv_dw(256,512,2))

#         for i in range(5):
#             layers.append(conv_dw(512,512,1))
#         layers.append(conv_dw(512,1024,2))
#         layers.append(conv_dw(1024,1024,1))

#         self.classifer = nn.Sequential(
#                 nn.Dropout(0.5),
#                 nn.Linear(1024,num_class)
#                 )
        self.feature = nn.Sequential(*layers)
        
        

    def forward(self,x):
        out = self.feature(x)
#         out = out.mean(3).mean(2)
#         out = out.view(-1,1024)
#         out = self.classifer(out)
        return out

In [8]:
mobilenet_client = MobileNet().to(device)
print(mobilenet_client)

MobileNet(
  (feature): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1): Sequential(
      (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
  )
)


In [9]:
# from torchsummary import summary

# summary(mobilenet_client, (3, 32, 32))

In [10]:
criterion = nn.CrossEntropyLoss()
lr = 0.001
optimizer = optim.SGD(mobilenet_server.parameters(), lr=lr, momentum=0.9)

clientsoclist = []
train_total_batch = []
val_acc = []
client_weights = copy.deepcopy(mobilenet_client.state_dict())

In [11]:
total_sendsize_list = []
total_receivesize_list = []

client_sendsize_list = [[] for i in range(users)]
client_receivesize_list = [[] for i in range(users)]

train_sendsize_list = [] 
train_receivesize_list = []

In [12]:
def send_msg(sock, msg):
    # prefix each message with a 4-byte length in network byte order
    msg = pickle.dumps(msg)
    l_send = len(msg)
    msg = struct.pack('>I', l_send) + msg
    sock.sendall(msg)
    return l_send
def recv_msg(sock):
    # read message length and unpack it into an integer
    raw_msglen = recvall(sock, 4)
    if not raw_msglen:
        return None
    msglen = struct.unpack('>I', raw_msglen)[0]
    # read the message data
    msg =  recvall(sock, msglen)
    msg = pickle.loads(msg)
    return msg, msglen

def recvall(sock, n):
    # helper function to receive n bytes or return None if EOF is hit
    data = b''
    while len(data) < n:
        packet = sock.recv(n - len(data))
        if not packet:
            return None
        data += packet
    return data


In [13]:
host = socket.gethostbyname(socket.gethostname())
#host = 'localhost'
port = 10080
print(host)

10.0.0.94


In [14]:
s = socket.socket()
s.bind((host, port))
s.listen(5)

In [15]:
for i in range(users):
    conn, addr = s.accept()
    print('Conntected with', addr)
    clientsoclist.append(conn)    # append client socket on list

    datasize = send_msg(conn, epochs)    #send epoch
    total_sendsize_list.append(datasize)
    client_sendsize_list[i].append(datasize)

    total_batch, datasize = recv_msg(conn)    # get total_batch of train dataset
    total_receivesize_list.append(datasize)
    client_receivesize_list[i].append(datasize)

    train_total_batch.append(total_batch)    # append on list

Conntected with ('10.0.0.53', 60668)


In [16]:
start_time = time.time()    # store start time
print("timmer start!")

for e in range(epochs):

    # train client 0

    for user in range(users):

        datasize = send_msg(clientsoclist[user], client_weights)
        total_sendsize_list.append(datasize)
        client_sendsize_list[user].append(datasize)
        train_sendsize_list.append(datasize)

        for i in tqdm(range(train_total_batch[user]), ncols=100, desc='Epoch {} Client{} '.format(e+1, user)):
            optimizer.zero_grad()  # initialize all gradients to zero

            msg, datasize = recv_msg(clientsoclist[user])  # receive client message from socket
            total_receivesize_list.append(datasize)
            client_receivesize_list[user].append(datasize)
            train_receivesize_list.append(datasize)

            client_output_cpu = msg['client_output']  # client output tensor
            label = msg['label']  # label

            client_output = client_output_cpu.to(device)
            label = label.clone().detach().long().to(device)

            output = mobilenet_server(client_output)  # forward propagation
            loss = criterion(output, label)  # calculates cross-entropy loss
            loss.backward()  # backward propagation
            msg = client_output_cpu.grad.clone().detach()

            datasize = send_msg(clientsoclist[user], msg)
            total_sendsize_list.append(datasize)
            client_sendsize_list[user].append(datasize)
            train_sendsize_list.append(datasize)
            
            optimizer.step()
            
        client_weights, datasize = recv_msg(clientsoclist[user])
        total_receivesize_list.append(datasize)         
        client_receivesize_list[user].append(datasize)
        train_receivesize_list.append(datasize)
        
print('train is done')

end_time = time.time()  # store end time
print("TrainingTime: {} sec".format(end_time - start_time))

# Let's quickly save our trained model:
PATH = './cifar10_sp_mobilenet_server.pth'
torch.save(mobilenet_server.state_dict(), PATH)

Epoch 1 Client0 :   0%|                                                   | 0/12500 [00:00<?, ?it/s]

timmer start!


Epoch 1 Client0 : 100%|█████████████████████████████████████| 12500/12500 [1:21:21<00:00,  2.56it/s]


train is done
TrainingTime: 4891.817382335663 sec


In [17]:
print('---total_sendsize_list---')
total_size = 0
for size in total_sendsize_list:
#     print(size)
    total_size += size
print("total_sendsize size: {} bytes".format(total_size))
print("number of total_send: ", len(total_sendsize_list))
print('\n')

print('---total_receivesize_list---')
total_size = 0
for size in total_receivesize_list:
#     print(size)
    total_size += size
print("total receive sizes: {} bytes".format(total_size) )
print("number of total receive: ", len(total_receivesize_list) )
print('\n')

for i in range(users):
    print('---client_sendsize_list(user{})---'.format(i))
    total_size = 0
    for size in client_sendsize_list[i]:
#         print(size)
        total_size += size
    print("total client_sendsizes(user{}): {} bytes".format(i, total_size))
    print("number of client_send(user{}): ".format(i), len(client_sendsize_list[i]))
    print('\n')

    print('---client_receivesize_list(user{})---'.format(i))
    total_size = 0
    for size in client_receivesize_list[i]:
#         print(size)
        total_size += size
    print("total client_receive sizes(user{}): {} bytes".format(i, total_size))
    print("number of client_send(user{}): ".format(i), len(client_receivesize_list[i]))
    print('\n')

print('---train_sendsize_list---')
total_size = 0
for size in train_sendsize_list:
#     print(size)
    total_size += size
print("total train_sendsizes: {} bytes".format(total_size))
print("number of train_send: ", len(train_sendsize_list) )
print('\n')

print('---train_receivesize_list---')
total_size = 0
for size in train_receivesize_list:
#     print(size)
    total_size += size
print("total train_receivesizes: {} bytes".format(total_size))
print("number of train_receive: ", len(train_receivesize_list) )
print('\n')

---total_sendsize_list---
total_sendsize size: 6558898075 bytes
number of total_send:  12502


---total_receivesize_list---
total receive sizes: 6563285641 bytes
number of total receive:  12502


---client_sendsize_list(user0)---
total client_sendsizes(user0): 6558898075 bytes
number of client_send(user0):  12502


---client_receivesize_list(user0)---
total client_receive sizes(user0): 6563285641 bytes
number of client_send(user0):  12502


---train_sendsize_list---
total train_sendsizes: 6558898070 bytes
number of train_send:  12501


---train_receivesize_list---
total train_receivesizes: 6563285626 bytes
number of train_receive:  12501




In [20]:
root_path = 'C:/Users/tejov/models/cifar10_data'

In [19]:
import os
print(os.getcwd())


C:\Users\tejov


In [21]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])

In [22]:
trainset = torchvision.datasets.CIFAR10 (root=root_path, train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10 (root=root_path, train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [23]:
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [24]:
x_train, y_train = next(iter(train_loader))
print(x_train.size())
print(y_train.size())

torch.Size([4, 3, 32, 32])
torch.Size([4])


In [25]:
train_total_batch = len(train_loader)
print(train_total_batch)
test_total_batch = len(test_loader)
print(test_total_batch)

12500
2500


In [26]:
mobilenet_client.load_state_dict(client_weights)
mobilenet_client.to(device)

# Let's quickly save our trained model:
PATH = './cifar10_sp_mobilenet_client.pth'
torch.save(client_weights, PATH)

In [27]:
# train acc
with torch.no_grad():
    corr_num = 0
    total_num = 0
    train_loss = 0.0
    for j, trn in enumerate(train_loader):
        trn_x, trn_label = trn
        trn_x = trn_x.to(device)
        trn_label = trn_label.to(device)


        trn_output = mobilenet_client(trn_x)
        trn_label = trn_label.long()
        
        trn_output = mobilenet_server(trn_output)
        
        loss = criterion(trn_output, trn_label)
        train_loss += loss.item()
        model_label = trn_output.argmax(dim=1)
        corr = trn_label[trn_label == model_label].size(0)
        corr_num += corr
        total_num += trn_label.size(0)
    print("train_acc: {:.2f}%, train_loss: {:.4f}".format(corr_num / total_num * 100, train_loss / len(train_loader)))

train_acc: 34.89%, train_loss: 1.8674


In [28]:
with torch.no_grad():
    corr_num = 0
    total_num = 0
    val_loss = 0.0
    for j, val in enumerate(test_loader):
        val_x, val_label = val
        val_x = val_x.to(device)
        val_label = val_label.to(device)

        val_output = mobilenet_client(val_x)
        val_label = val_label.long()
        
        val_output = mobilenet_server(val_output)
        
        loss = criterion(val_output, val_label)
        val_loss += loss.item()
        model_label = val_output.argmax(dim=1)
        corr = val_label[val_label == model_label].size(0)
        corr_num += corr
        total_num += val_label.size(0)
    print("test_acc: {:.2f}%, test_loss: {:.4f}".format(corr_num / total_num * 100, val_loss / len(test_loader)))

test_acc: 34.45%, test_loss: 1.8703


Accuracy of plane : 51 %
Accuracy of   car : 54 %
Accuracy of  bird : 19 %
Accuracy of   cat : 23 %
Accuracy of  deer : 35 %
Accuracy of   dog : 14 %
Accuracy of  frog : 32 %
Accuracy of horse : 31 %
Accuracy of  ship : 42 %
Accuracy of truck : 40 %
WorkingTime: 6926.903380155563 sec
