In [99]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


hyper-parameter

In [100]:
# hyper-parameter

# N = K + R, N is the distributed device number
K = 2
R = 1
N = K + R

print(f"K: {K}, R: {R}, N: {N}")

K: 2, R: 1, N: 3


import library

In [101]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from tqdm import tqdm
import datetime
import os

from base_model.LeNet5 import LeNet5
from dataset.splited_dataset import SplitedTrainDataset
from dataset.image_dataset import ImageDataset
from encoder.mlp_encoder import MLPEncoder
from decoder.mlp_decoder import MLPDecoder
from util.util import cal_output_size

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

Device: cuda


prepare base model (already trained)

In [111]:
base_model_path = "base_model/LeNet5/MNIST/2023_11_28/model.pth"

model = LeNet5(input_dim=(1, 28, 28), num_classes=10)
model.load_state_dict(torch.load(base_model_path))
print(model)
model_distributed = model.get_conv_segment()
model_fc = model.get_fc_segment()

LeNet5(
  (conv_block1): Sequential(
    (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv_block2): Sequential(
    (0): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc1): Sequential(
    (0): Linear(in_features=256, out_features=120, bias=True)
    (1): ReLU()
  )
  (fc2): Sequential(
    (0): Linear(in_features=120, out_features=84, bias=True)
    (1): ReLU()
  )
  (out): Sequential(
    (0): Linear(in_features=84, out_features=10, bias=True)
  )
)


---

## Train Encoder and Decoder

load data

In [114]:
train_datasets = SplitedTrainDataset()
train_datasets.load(f"./data/MNIST/split/{K}/split_train_datasets.pt")

data_shape = train_datasets.data_shape

print(train_datasets.describe())
print(f"data shape: {data_shape}")

Splited Train Dataset: split_num=2, data_num=60000, data_shape=(1, 28, 20)
Origin data shape: torch.Size([1, 28, 20])


prepare encoder and decoder

In [104]:
encoder = MLPEncoder(num_in=K, num_out=R, in_dim=data_shape)
output_size = cal_output_size(model_distributed, data_shape)
decoder = MLPDecoder(num_in=N, num_out=K, in_dim=output_size)

print(encoder)
print(decoder)

MLPEncoder(
  (nn): Sequential(
    (0): Linear(in_features=1120, out_features=1120, bias=True)
    (1): ReLU()
    (2): Linear(in_features=1120, out_features=560, bias=True)
  )
)
MLPDecoder(
  (nn): Sequential(
    (0): Linear(in_features=384, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=256, bias=True)
  )
)


In [105]:
# encoder = MLPEncoder(
#     input_dim=img_size,
#     hidden_dims=[128, 64, 32],
#     output_dim=32,
#     activation=nn.ReLU(),
#     dropout=0.5,
# )

In [116]:
model_distributed.to(device)
model_fc.to(device)
encoder.to(device)
decoder.to(device)

model_distributed.eval()
model_fc.eval()
encoder.train()
decoder.train()

print()




In [117]:
train_loader = DataLoader(train_datasets, batch_size=64, shuffle=False)

criterion = nn.MSELoss()
optimizer_encoder = optim.SGD(encoder.parameters(), lr=0.01, momentum=0.9)
optimizer_decoder = optim.SGD(decoder.parameters(), lr=0.01, momentum=0.9)

epoch_num = 10

for epoch in range(epoch_num):
    train_loader_tqdm = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epoch_num}")
    for images_list, label in train_loader_tqdm:
        images_list = [images.to(device) for images in images_list]
        label = label.to(device)

        # forward
        imageDataset_list = [
            ImageDataset(images) for images in images_list + encoder(images_list)
        ]
        output_list = []
        for i in range(N):
            imageDataset = imageDataset_list[i]
            output = model_distributed(imageDataset.images)
            output_list.append(output)
        decoded_output_list = decoder(output_list)
        output = torch.cat(decoded_output_list, dim=3)
        output = output.view(output.size(0), -1)

        loss = criterion(output, label.view(label.size(0), -1))

        # backward
        optimizer_encoder.zero_grad()
        optimizer_decoder.zero_grad()
        loss.backward()
        optimizer_encoder.step()
        optimizer_decoder.step()

        train_loader_tqdm.set_postfix(loss=loss.item())

    print(f"Epoch [{epoch+1}/{epoch_num}], Loss: {loss.item()}")

Epoch 1/10: 100%|██████████| 938/938 [00:18<00:00, 51.05it/s, loss=0.598]


Epoch [1/10], Loss: 0.597885012626648


Epoch 2/10: 100%|██████████| 938/938 [00:16<00:00, 55.67it/s, loss=0.394]


Epoch [2/10], Loss: 0.3939012885093689


Epoch 3/10: 100%|██████████| 938/938 [00:16<00:00, 55.86it/s, loss=0.301]


Epoch [3/10], Loss: 0.30054810643196106


Epoch 4/10: 100%|██████████| 938/938 [00:15<00:00, 60.97it/s, loss=0.245]


Epoch [4/10], Loss: 0.24453651905059814


Epoch 5/10: 100%|██████████| 938/938 [00:10<00:00, 89.10it/s, loss=0.205]


Epoch [5/10], Loss: 0.20505504310131073


Epoch 6/10: 100%|██████████| 938/938 [00:10<00:00, 88.39it/s, loss=0.178]


Epoch [6/10], Loss: 0.1780446171760559


Epoch 7/10: 100%|██████████| 938/938 [00:10<00:00, 88.57it/s, loss=0.157]


Epoch [7/10], Loss: 0.157220721244812


Epoch 8/10: 100%|██████████| 938/938 [00:10<00:00, 90.56it/s, loss=0.141] 


Epoch [8/10], Loss: 0.14113657176494598


Epoch 9/10: 100%|██████████| 938/938 [00:10<00:00, 89.28it/s, loss=0.128] 


Epoch [9/10], Loss: 0.1283937394618988


Epoch 10/10: 100%|██████████| 938/938 [00:10<00:00, 86.81it/s, loss=0.118] 

Epoch [10/10], Loss: 0.11830160021781921





save encoder and decoder

In [118]:
now = datetime.datetime.now()
date = now.strftime("%Y_%m_%d")

encoder_path = (
    f"encoder/MLP/MNIST/{date}/"
    + f"encoder-"
    + f"task_MNIST-basemodel_LeNet5-"
    + f"in_{encoder.num_in}-out_{encoder.num_out}.pth"
)
dirpath = os.path.dirname(encoder_path)
if not os.path.exists(dirpath):
    os.makedirs(dirpath)
torch.save(encoder.state_dict(), encoder_path)

decoder_path = (
    f"decoder/MLP/MNIST/{date}/"
    + f"decoder-"
    + f"task_MNIST-basemodel_LeNet5-"
    + f"in_{decoder.num_in}-out_{decoder.num_out}.pth"
)
dirpath = os.path.dirname(decoder_path)
if not os.path.exists(dirpath):
    os.makedirs(dirpath)
torch.save(decoder.state_dict(), decoder_path)

In [125]:
torch.randperm(len(output_list))[:9]

tensor([1, 0, 2])