# Single Client Side

## Imports

In [1]:
import os
import time

from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from torch.utils.data import Subset

### User imports

In [2]:
from src.lib import *

## Set device (the client uses the CPU)

In [3]:
# The client is pinned to the CPU. To pick up a GPU when one is available, use instead:
# device = "cuda:0" if torch.cuda.is_available() else "cpu"
device = "cpu"

# Seed PyTorch's default generator so repeated runs start from the same random state.
torch.manual_seed(777)

# The CUDA generators are only seeded when the GPU device string is selected.
if "cuda:0" == device:
    torch.cuda.manual_seed_all(777)

## set main variables

In [4]:
# Which client-side network to build; the other name handled by this notebook is:
# model_name = "mobilenet"
model_name = "squeezenet"

# Root directory passed to the CIFAR-10 dataset loader.
data_path = "./models/cifar10_data"

# The training samples are divided evenly between `nuser` clients
# (integer division, so any remainder is dropped).
datasize_total = 50000
nuser = 1
datasize_per_client = datasize_total // nuser

In [5]:
# Zero-based position of this client. It selects which contiguous block of
# `datasize_per_client` training indices this client trains on.
client_order = 0

### network setting

In [6]:
# Server endpoint used when the Client object is created. The address is fixed
# to the local machine; it could instead be prompted for with input("IP address: ").
host, port = "localhost", 10089

## Data load

In [7]:
# Per-image preprocessing: convert to a tensor, then normalise each RGB channel
# with the per-channel mean / std constants below.
transform_list = [transforms.ToTensor(),
                  transforms.Normalize((0.4914, 0.4822, 0.4465),
                                       (0.2470, 0.2435, 0.2616))]
# For SqueezeNet variants the images are resized to 224x224. The resize is
# appended last, so it runs on the already-normalised tensor.
if model_name.startswith("squeezenet"):
    transform_list.append(transforms.Resize((224, 224)))
transform = transforms.Compose(transform_list)

# Build the index range from `datasize_total` rather than repeating the literal,
# so it cannot drift from `datasize_per_client`, which is derived from the same value.
indices = list(range(datasize_total))
# Contiguous block of indices owned by this client.
slice_indices = indices[datasize_per_client * client_order : datasize_per_client * (client_order + 1)]

In [8]:
# CIFAR-10 training split, downloaded to `data_path` if it is not already there.
train_set = torchvision.datasets.CIFAR10(
    root=data_path,
    train=True,
    download=True,
    transform=transform,
)

# Restrict the dataset to the indices assigned to this client.
train_subset = Subset(train_set, slice_indices)

# Shuffled mini-batches of 16 samples, loaded by two worker processes.
train_loader = torch.utils.data.DataLoader(
    train_subset,
    batch_size=16,
    shuffle=True,
    num_workers=2,
)

Files already downloaded and verified


### Batch size check

In [9]:
# Number of mini-batches per epoch; it is later passed to the client's training_prep().
total_batch = len(train_loader)

# Pull a single batch to sanity-check the input and label shapes.
x_train, y_train = next(iter(train_loader))
print(x_train.shape)
print(y_train.shape)
print(total_batch)

torch.Size([16, 3, 224, 224])
torch.Size([16])
3125


## Define model

In [10]:
# Reset first so that a model left over from an earlier run of this cell can
# never be picked up silently when the selection below fails.
client_model = None
if model_name == "mobilenet":
    client_model = ClientMobileNet()
elif model_name == "squeezenet":
    client_model = ClientSqueezeNet(num_classes=10)
else:
    # Fail with a clear message instead of an AttributeError on `None.to(...)`.
    raise ValueError(
        "Unsupported model_name {!r}; expected 'mobilenet' or 'squeezenet'".format(model_name)
    )

# Move the client-side network to the configured device and show its layers.
client_model = client_model.to(device)
print(client_model)

ClientSqueezeNet(
  (layer1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
  )
  (features): Sequential(
    (0): Fire(
      (squeeze): Conv2d(64, 16, kernel_size=(1, 1), stride=(1, 1))
      (squeeze_activation): ReLU(inplace=True)
      (expand1x1): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1))
      (expand1x1_activation): ReLU(inplace=True)
      (expand3x3): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (expand3x3_activation): ReLU(inplace=True)
    )
    (1): Fire(
      (squeeze): Conv2d(128, 16, kernel_size=(1, 1), stride=(1, 1))
      (squeeze_activation): ReLU(inplace=True)
      (expand1x1): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1))
      (expand1x1_activation): ReLU(inplace=True)
      (expand3x3): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (expand3x3_activation): ReLU(inpl

### Set other hyperparameters in the model
Hyperparameters here should be the same as on the server side.

In [11]:
# Learning rate for the client-side optimizer.
lr = 0.001

# SGD with momentum over the parameters of the client-side network only.
optimizer = optim.SGD(
    params=client_model.parameters(),
    lr=lr,
    momentum=0.9,
)

# Cross-entropy loss object. NOTE(review): the client training loop in this
# notebook does not call it; confirm whether it is still needed on this side.
criterion = nn.CrossEntropyLoss()

### Allocate Clients

In [12]:
# Create the project's Client for the configured host/port.
client = Client(host, port)
# Presumably opens the connection to the server -- confirm in src.lib.
client.connect()
# Hand the number of batches per epoch to training_prep(); its return value is
# used as the epoch count of the training loop.
epochs = client.training_prep(total_batch) # training_prep returns the number of epochs

## SET TIMER

In [13]:
# Wall-clock reference point; the elapsed training time reported at the end of
# the notebook is measured from here.
start_time = time.time()    # store start time
print("training start!")

training start!


## Real training process

In [14]:
# receive initial weight from server
client_weights, _ = client.recv()
client_model.load_state_dict(client_weights)


for e in range(epochs):
    client_model.eval()
    for i, data in enumerate(tqdm(train_loader, ncols=100, desc='Epoch '+str(e+1))):
        x, label = data
        x = x.to(device)
        label = label.to(device)
        
        optimizer.zero_grad()
        output = client_model(x)
        client_output = output.clone().detach().requires_grad_(True)
        msg = {
            'client_output': client_output,
            'label': label
        }
        client.send(msg)
        client_grad, _ = client.recv()
        output.backward(client_grad)
        optimizer.step()


Epoch 1: 100%|██████████████████████████████████████████████████| 3125/3125 [05:15<00:00,  9.92it/s]
Epoch 2: 100%|██████████████████████████████████████████████████| 3125/3125 [05:13<00:00,  9.96it/s]
Epoch 3: 100%|██████████████████████████████████████████████████| 3125/3125 [05:13<00:00,  9.97it/s]
Epoch 4: 100%|██████████████████████████████████████████████████| 3125/3125 [05:31<00:00,  9.42it/s]
Epoch 5: 100%|██████████████████████████████████████████████████| 3125/3125 [05:37<00:00,  9.25it/s]
Epoch 6: 100%|██████████████████████████████████████████████████| 3125/3125 [05:43<00:00,  9.10it/s]
Epoch 7: 100%|██████████████████████████████████████████████████| 3125/3125 [05:33<00:00,  9.37it/s]
Epoch 8: 100%|██████████████████████████████████████████████████| 3125/3125 [05:24<00:00,  9.63it/s]
Epoch 9: 100%|██████████████████████████████████████████████████| 3125/3125 [05:32<00:00,  9.41it/s]
Epoch 10: 100%|█████████████████████████████████████████████████| 3125/3125 [05:29<00:00,  

In [15]:
# Send the trained client-side weights (the model's state_dict) over the client
# connection. The value shown below the cell is send()'s return value
# (presumably a byte count -- confirm in src.lib).
client.send(client_model.state_dict())

2931707

In [16]:
# Seconds of wall-clock time since `start_time` was recorded.
elapsed_time = time.time() - start_time
print(f"elapsed time for training using {device} : {elapsed_time} sec")

elapsed time for training using cpu : 3274.5654566287994 sec
