这个文件是为了将双层的CNN模型分离成两个部分，以此在client和server两端分开计算。

In [1]:
import socket
import time
import struct
import os
import json

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

In [2]:
USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")
# 固定初始化种子
SEED = 24
torch.manual_seed(SEED)
if USE_CUDA:
    torch.cuda.manual_seed(SEED)

BATCH_SIZE = 32
LEARNING_RATE = 0.01
MOMENTUM = 0.5
NUM_EPOCHS = 2

In [3]:
train_dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(".data", train=True, download=True,
           transform=transforms.Compose([
               transforms.ToTensor(),
               # Normalize输入为两个tuple，output=(input-mean)/std
               transforms.Normalize((0.13066,), (0.30811,)) # (x,)输出为一维tuple
           ])),
    batch_size=BATCH_SIZE, shuffle=True, 
    num_workers=0, pin_memory=True
)
test_dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(".data", train=False, download=True,
           transform=transforms.Compose([
               transforms.ToTensor(),
               # Normalize输入为两个tuple，output=(input-mean)/std
               transforms.Normalize((0.13066,), (0.30811,)) # (x,)输出为一维tuple
           ])),
    batch_size=BATCH_SIZE, shuffle=True, 
    num_workers=0, pin_memory=False
)
print("Dataloader successfully loaded.")

Dataloader successfully loaded.


In [4]:
# 这是正常的模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # 1 channel -> 20 channels
        self.conv1 = nn.Conv2d(1, 20, 5, 1) # 28 * 28 -> (28+1-5) = 24 * 24
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)
        
    def forward(self, x):
        # x: batch_size * 1 * 28 * 28
        x = F.relu(self.conv1(x)) # batch_size * 20 * 24 * 24
        x = F.max_pool2d(x,2,2) # batch_size * 20 * 12 * 12
        x = F.relu(self.conv2(x)) # batch_size * 50 * 8 * 8
        x = F.max_pool2d(x,2,2) # batch_size * 50 * 4 *4 
        x = x.view(-1, 4*4*50) # batch_size * (50*4*4) 
        x = F.relu(self.fc1(x))
        x= self.fc2(x)
        # return x
        return F.log_softmax(x, dim=1) # log probability

In [5]:
# 最简单粗暴的分离
# 将第一层的输出作为clientNet的输出
# 同时还是serverNet的输入
class clientNet(nn.Module):
    def __init__(self):
        super(clientNet, self).__init__()
        # 1 channel -> 20 channels
        self.conv1 = nn.Conv2d(1, 20, 5, 1) # 28 * 28 -> (28+1-5) = 24 * 24
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)
        
    def forward(self, x):
        # x: batch_size * 1 * 28 * 28
        x = F.relu(self.conv1(x)) # batch_size * 20 * 24 * 24
        x = F.max_pool2d(x,2,2) # batch_size * 20 * 12 * 12
        return x
    
class serverNet(nn.Module):
    def __init__(self):
        super(serverNet, self).__init__()
        # 1 channel -> 20 channels
        self.conv1 = nn.Conv2d(1, 20, 5, 1) # 28 * 28 -> (28+1-5) = 24 * 24
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv2(x)) # batch_size * 50 * 8 * 8
        x = F.max_pool2d(x,2,2) # batch_size * 50 * 4 *4 
        x = x.view(-1, 4*4*50) # batch_size * (50*4*4) 
        x = F.relu(self.fc1(x))
        x= self.fc2(x)
        # return x
        return F.log_softmax(x, dim=1) # log probability

In [6]:
# 这是正常的训练过程
model = Net().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
model.train()

Net(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=800, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
)

In [7]:
start_1 = time.time()
for epoch in range(2):
    for i, (data, target) in enumerate(train_dataloader):
        data, target = data.to(device), target.to(device)

        pred = model(data) # batch_size * 10
        loss = F.nll_loss(pred, target) 

        # SGD
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            print("Train Epoch: {}, iteration: {:>4d}, Loss: {:.6f}".format(
                epoch, i, loss.item()))
print("cost time: {:.6f}".format(time.time()-start_1))

Train Epoch: 0, iteration:    0, Loss: 2.349389
Train Epoch: 0, iteration:  100, Loss: 0.464970
Train Epoch: 0, iteration:  200, Loss: 0.241684
Train Epoch: 0, iteration:  300, Loss: 0.121889
Train Epoch: 0, iteration:  400, Loss: 0.140317
Train Epoch: 0, iteration:  500, Loss: 0.109634
Train Epoch: 0, iteration:  600, Loss: 0.251497
Train Epoch: 0, iteration:  700, Loss: 0.255145
Train Epoch: 0, iteration:  800, Loss: 0.089439
Train Epoch: 0, iteration:  900, Loss: 0.067693
Train Epoch: 0, iteration: 1000, Loss: 0.051744
Train Epoch: 0, iteration: 1100, Loss: 0.079346
Train Epoch: 0, iteration: 1200, Loss: 0.036770
Train Epoch: 0, iteration: 1300, Loss: 0.047175
Train Epoch: 0, iteration: 1400, Loss: 0.358099
Train Epoch: 0, iteration: 1500, Loss: 0.128052
Train Epoch: 0, iteration: 1600, Loss: 0.131123
Train Epoch: 0, iteration: 1700, Loss: 0.133915
Train Epoch: 0, iteration: 1800, Loss: 0.112847
Train Epoch: 1, iteration:    0, Loss: 0.091838
Train Epoch: 1, iteration:  100, Loss: 0

In [8]:
# 这是分开后的model做的训练
client_model = clientNet().to(device)
server_model = serverNet().to(device)
client_optimizer = torch.optim.SGD(client_model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
server_optimizer = torch.optim.SGD(server_model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
client_model.train()
server_model.train()

serverNet(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=800, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
)

In [9]:
# 这是分开后的训练的过程
start_2 = time.time()
for epoch in range(2):
    for i, (data, target) in enumerate(train_dataloader):
        # 按顺序通过client_model和server_model，得到pred为预测的结果，计算loss
        data, target = data.to(device), target.to(device)
        output = client_model(data)
        pred = server_model(output)
        loss = F.nll_loss(pred, target)
        # 然后先进行server的backward和optimize，然后获得output的grad
        server_optimizer.zero_grad()
        output.retain_grad() # 必须要retain output的grad，才能获得grad
        loss.backward()
        server_optimizer.step()
        grad = output.grad.data
        # 最后再利用output的grad对client进行backward和optimize
        client_optimizer.zero_grad()
        output = client_model(data) # output需要先利用之前的input计算获得graph
        output.backward(grad)       # 才能利用grad进行backward
        client_optimizer.step()
        if i % 100 == 0:
            print("Train Epoch: {}, iteration: {:>4d}, Loss: {:.6f}".format(
                epoch, i, loss.item()))
print("cost time: {:.6f}".format(time.time()-start_2))

Train Epoch: 0, iteration:    0, Loss: 2.287267
Train Epoch: 0, iteration:  100, Loss: 0.625813
Train Epoch: 0, iteration:  200, Loss: 0.480779
Train Epoch: 0, iteration:  300, Loss: 0.093468
Train Epoch: 0, iteration:  400, Loss: 0.237676
Train Epoch: 0, iteration:  500, Loss: 0.046658
Train Epoch: 0, iteration:  600, Loss: 0.085467
Train Epoch: 0, iteration:  700, Loss: 0.272831
Train Epoch: 0, iteration:  800, Loss: 0.025104
Train Epoch: 0, iteration:  900, Loss: 0.140939
Train Epoch: 0, iteration: 1000, Loss: 0.051098
Train Epoch: 0, iteration: 1100, Loss: 0.179332
Train Epoch: 0, iteration: 1200, Loss: 0.267376
Train Epoch: 0, iteration: 1300, Loss: 0.089552
Train Epoch: 0, iteration: 1400, Loss: 0.061261
Train Epoch: 0, iteration: 1500, Loss: 0.029585
Train Epoch: 0, iteration: 1600, Loss: 0.168747
Train Epoch: 0, iteration: 1700, Loss: 0.042958
Train Epoch: 0, iteration: 1800, Loss: 0.059305
Train Epoch: 1, iteration:    0, Loss: 0.026790
Train Epoch: 1, iteration:  100, Loss: 0

In [10]:
# 正常模型在测试集上的准确率
model.eval()
total_loss = 0.
correct = 0.
with torch.no_grad():
    for i, (data, target) in enumerate(test_dataloader):
        data, target = data.to(device), target.to(device)

        output = model(data) # batch_size * 10
        total_loss += F.nll_loss(output, target, reduction="sum").item() 
        pred = output.argmax(dim=1) # batch_size * 1
        correct += pred.eq(target.view_as(pred)).sum().item()
total_loss /= len(test_dataloader.dataset)
acc = correct/len(test_dataloader.dataset) * 100.
print("Test loss: {}, Accuracy: {}".format(total_loss, acc))

Test loss: 0.04174569380283356, Accuracy: 98.61


In [11]:
# 分离后模型在测试集上的准确率
client_model.eval()
server_model.eval()
total_loss = 0.
correct = 0.
with torch.no_grad():
    for i, (data, target) in enumerate(test_dataloader):
        data, target = data.to(device), target.to(device)

        output = server_model(client_model(data)) # batch_size * 10
        total_loss += F.nll_loss(output, target, reduction="sum").item() 
        pred = output.argmax(dim=1) # batch_size * 1
        correct += pred.eq(target.view_as(pred)).sum().item()
total_loss /= len(test_dataloader.dataset)
acc = correct/len(test_dataloader.dataset) * 100.
print("Test loss: {}, Accuracy: {}".format(total_loss, acc))

Test loss: 0.044531891798973086, Accuracy: 98.6
