In [2]:
%load_ext autoreload
%autoreload 2

hyper-parameter

In [69]:
# hyper-parameter

# N = K + R, N is the distributed device number
K = 2
R = 1
N = K + R

print(f"K: {K}, R: {R}, N: {N}")

K: 2, R: 1, N: 3


import library

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from tqdm import tqdm
import datetime
import os

from base_models.LeNet5 import LeNet5
from util.dataset import PartitionedDataset, ImageDataset
from encoder.mlp_encoder import MLPEncoder
from decoder.mlp_decoder import MLPDecoder
from data_processing.data_partition import data_partition, cal_output_size

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

Device: cuda


prepare base model (already trained)

In [48]:
base_model_path = "base_models/LeNet5/MNIST/2023_11_28/model.pth"

model = LeNet5(input_dim=(1, 28, 28), num_classes=10)
model.load_state_dict(torch.load(base_model_path))
print("model loaded successfully")

model.to(device)
print(f"model is on device: {next(model.parameters()).device}")


model_distributed = model.get_conv_segment().eval()
print(
    f"model_distributed: device: {next(model_distributed.parameters()).device}, mode: {'training' if model_distributed.training is True else 'eval'}"
)

model_fc = model.get_fc_segment().eval()
print(
    f"model_fc: device: {next(model_fc.parameters()).device}, mode: {'training' if model_fc.training is True else 'eval'}"
)

# for param in model_distributed.parameters():
#     param.requires_grad = False

# for param in model_fc.parameters():
#     param.requires_grad = False

model loaded successfully
model is on device: cuda:0
model_distributed: device: cuda:0, mode: eval
model_fc: device: cuda:0, mode: eval


---

## Data partation

In [6]:
# transform = transforms.Compose(
#     [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
# )

# dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
# print(f"Test dataset: {len(dataset)}")
# print("image size: ", dataset[0][0].size())

# data_partition(
#     model.get_conv_segment(),
#     dataset,
#     K,
#     f"./data/MNIST/partition/{K}/partitioned_train_datasets.pt",
# )

---

## Train Encoder and Decoder

load data

In [7]:
train_datasets = torch.load(f"./data/MNIST/partition/{K}/partitioned_train_datasets.pt")

datasets_num = len(train_datasets[0][0])
data_size = len(train_datasets)
img_size = train_datasets[0][0][0].size()

print(f"type of dataset: {type(train_datasets)}")
print(f"test_datasets num: {datasets_num}, which shold be equal to K = {K}")
print(f"constructure: ([images, images, ...], conv_segment_labels, labels)")
print(f"datasize: {data_size}")
print(f"image size: {img_size}")

type of dataset: <class 'util.dataset.PartitionedDataset'>
test_datasets num: 2, which shold be equal to K = 2
constructure: ([images, images, ...], conv_segment_labels, labels)
datasize: 60000
image size: torch.Size([1, 28, 20])


prepare encoder and decoder

In [70]:
encoder = MLPEncoder(num_in=K, num_out=R, in_dim=tuple(img_size)).train()
encoder.to(device)

output_size = cal_output_size(model_distributed, tuple(img_size))

decoder = MLPDecoder(num_in=N, num_out=K, in_dim=output_size).train()
decoder.to(device)

print(encoder)
print(decoder)

MLPEncoder(
  (nn): Sequential(
    (0): Linear(in_features=1120, out_features=1120, bias=True)
    (1): ReLU()
    (2): Linear(in_features=1120, out_features=560, bias=True)
  )
)
MLPDecoder(
  (nn): Sequential(
    (0): Linear(in_features=384, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=256, bias=True)
  )
)


In [9]:
# encoder = MLPEncoder(
#     input_dim=img_size,
#     hidden_dims=[128, 64, 32],
#     output_dim=32,
#     activation=nn.ReLU(),
#     dropout=0.5,
# )

In [73]:
train_loader = DataLoader(train_datasets, batch_size=64, shuffle=False)

criterion = nn.MSELoss()
optimizer_encoder = optim.SGD(encoder.parameters(), lr=0.01, momentum=0.9)
optimizer_decoder = optim.SGD(decoder.parameters(), lr=0.01, momentum=0.9)

num_epochs = 10

for epoch in range(num_epochs):
    train_loader_tqdm = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for images_list, conv_segment_label, label in train_loader_tqdm:
        images_list = [images.to(device) for images in images_list]
        label = label.to(device)

        # forward
        imageDataset_list = [
            ImageDataset(images) for images in images_list + encoder(images_list)
        ]
        output_list = []
        for i in range(N):
            imageDataset = imageDataset_list[i]
            output = model_distributed(imageDataset.images)
            output_list.append(output)
        output_list[0] = torch.zeros_like(output_list[0])
        decoded_output_list = decoder(output_list)
        output = torch.cat(decoded_output_list, dim=3)
        output = output.view(output.size(0), -1)

        loss = criterion(
            output, conv_segment_label.view(conv_segment_label.size(0), -1)
        )

        # backward
        optimizer_encoder.zero_grad()
        optimizer_decoder.zero_grad()
        loss.backward()
        optimizer_encoder.step()
        optimizer_decoder.step()
        
        train_loader_tqdm.set_postfix(loss=loss.item())

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}")

Epoch 1/10: 100%|██████████| 938/938 [00:18<00:00, 49.48it/s, loss=0.57] 


Epoch [1/10], Loss: 0.5701849460601807


Epoch 2/10: 100%|██████████| 938/938 [00:16<00:00, 57.96it/s, loss=0.453]


Epoch [2/10], Loss: 0.45349356532096863


Epoch 3/10: 100%|██████████| 938/938 [00:16<00:00, 57.44it/s, loss=0.4]  


Epoch [3/10], Loss: 0.39994168281555176


Epoch 4/10: 100%|██████████| 938/938 [00:15<00:00, 59.35it/s, loss=0.363]


Epoch [4/10], Loss: 0.3629976511001587


Epoch 5/10: 100%|██████████| 938/938 [00:16<00:00, 57.74it/s, loss=0.334]


Epoch [5/10], Loss: 0.33446648716926575


Epoch 6/10: 100%|██████████| 938/938 [00:16<00:00, 55.78it/s, loss=0.311]


Epoch [6/10], Loss: 0.3112429976463318


Epoch 7/10: 100%|██████████| 938/938 [00:15<00:00, 59.11it/s, loss=0.292]


Epoch [7/10], Loss: 0.29171788692474365


Epoch 8/10: 100%|██████████| 938/938 [00:16<00:00, 58.61it/s, loss=0.277]


Epoch [8/10], Loss: 0.27736878395080566


Epoch 9/10: 100%|██████████| 938/938 [00:15<00:00, 59.01it/s, loss=0.265]


Epoch [9/10], Loss: 0.26533037424087524


Epoch 10/10: 100%|██████████| 938/938 [00:15<00:00, 60.98it/s, loss=0.255]

Epoch [10/10], Loss: 0.25498712062835693





save encoder and decoder

In [72]:
now = datetime.datetime.now()
date = now.strftime("%Y_%m_%d")

encoder_path = f"encoder/MLP/MNIST/{date}/model.pth"
dirpath = os.path.dirname(encoder_path)
if not os.path.exists(dirpath):
    os.makedirs(dirpath)
torch.save(encoder.state_dict(), encoder_path)

decoder_path = f"decoder/MLP/MNIST/{date}/model.pth"
dirpath = os.path.dirname(decoder_path)
if not os.path.exists(dirpath):
    os.makedirs(dirpath)
torch.save(decoder.state_dict(), decoder_path)

---

In [65]:
decoder = MLPDecoder(num_in=N, num_out=K, in_dim=(16, 4, 2))
decoder.to(device)

data = [
    torch.randn(10000, 16, 4, 2).to(device),
    torch.randn(10000, 16, 4, 2).to(device),
]
label = torch.cat(data, dim=3).view(10000, -1)

criterion = nn.MSELoss()
optimizer = optim.SGD(decoder.parameters(), lr=0.01, momentum=0.9)

decoder.train()
for iteration in range(1000):
    output = decoder(data)
    output = torch.cat(output, dim=3)
    output = output.view(output.size(0), -1)
    loss = criterion(output, label)

    optimizer.zero_grad()  # 清除之前的梯度
    loss.backward()  # 反向传播
    optimizer.step()  # 更新权重

    print(loss)

tensor(1.0118, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(1.0117, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(1.0117, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(1.0117, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(1.0116, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(1.0115, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(1.0115, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(1.0114, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(1.0113, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(1.0112, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(1.0111, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(1.0110, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(1.0109, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(1.0108, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(1.0107, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(1.0105, device='cuda:0', grad_fn=<MseLossBackward0>)
tensor(1.0104, device='cuda:0', grad_fn=