In [1]:
%load_ext autoreload
%autoreload 2

import os

print("original dir: ", os.getcwd())

new_path = "../"
os.chdir(new_path)

print("changed dir: ", os.getcwd())

original dir:  f:\MyCourse(5 delayed 1)\erasure code\non-linear erasure code\src\test
changed dir:  f:\MyCourse(5 delayed 1)\erasure code\non-linear erasure code\src


hyper-parameter

In [2]:
# hyper-parameter

# N = K + R, N is the distributed device number
K = 2
R = 1
N = K + R

print(f"K: {K}, R: {R}, N: {N}")

K: 2, R: 1, N: 3


import library

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from tqdm import tqdm
import datetime
import os

from base_model.LeNet5 import LeNet5
from dataset.splited_dataset import SplitedTrainDataset
from dataset.image_dataset import ImageDataset
from encoder.mlp_encoder import MLPEncoder
from decoder.mlp_decoder import MLPDecoder

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

Device: cuda


prepare base model (already trained)

In [5]:
base_model_path = "base_model/LeNet5/MNIST/2023_11_28/model.pth"

model = LeNet5(input_dim=(1, 28, 28), num_classes=10)
model.load_state_dict(torch.load(base_model_path))
print(model)
model_distributed = model.get_conv_segment()
model_fc = model.get_fc_segment()

LeNet5(
  (conv_block1): Sequential(
    (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv_block2): Sequential(
    (0): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc1): Sequential(
    (0): Linear(in_features=256, out_features=120, bias=True)
    (1): ReLU()
  )
  (fc2): Sequential(
    (0): Linear(in_features=120, out_features=84, bias=True)
    (1): ReLU()
  )
  (out): Sequential(
    (0): Linear(in_features=84, out_features=10, bias=True)
  )
)


---

## Train Encoder and Decoder

load data

In [6]:
train_datasets = SplitedTrainDataset()
train_datasets.load(f"./data/MNIST/split/{K}/split_train_datasets.pt")

data_shape = train_datasets.data_shape

print(train_datasets.describe())
print(f"data shape: {data_shape}")

Splited Train Dataset: split_num=2, data_num=60000, data_shape=(1, 28, 20)
data shape: (1, 28, 20)


prepare encoder and decoder

In [31]:
model.calculate_conv_output((1,28,16))

(16, 4, 1)

In [41]:
encoder = MLPEncoder(num_in=K, num_out=R, in_dim=data_shape)
decoder = MLPDecoder(num_in=N, num_out=K, in_dim=(16,4,4 // K))

print(encoder)
print(decoder)

MLPEncoder(
  (nn): Sequential(
    (0): Linear(in_features=1120, out_features=1120, bias=True)
    (1): ReLU()
    (2): Linear(in_features=1120, out_features=560, bias=True)
  )
)
MLPDecoder(
  (nn): Sequential(
    (0): Linear(in_features=384, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=256, bias=True)
  )
)


In [42]:
model_distributed.to(device)
model_fc.to(device)
encoder.to(device)
decoder.to(device)

model_distributed.eval()
model_fc.eval()
encoder.train()
decoder.train()

print()




In [44]:
train_loader = DataLoader(train_datasets, batch_size=64, shuffle=False)

criterion = nn.MSELoss()
optimizer_encoder = optim.SGD(encoder.parameters(), lr=0.01, momentum=0.9)
optimizer_decoder = optim.SGD(decoder.parameters(), lr=0.01, momentum=0.9)

epoch_num = 4

for epoch in range(epoch_num):
    train_loader_tqdm = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epoch_num}")
    for images_list, label in train_loader_tqdm:
        images_list = [images.to(device) for images in images_list]
        label = label.to(device)

        # forward
        imageDataset_list = [
            ImageDataset(images) for images in images_list + encoder(images_list)
        ]
        output_list = []
        for i in range(N):
            imageDataset = imageDataset_list[i]
            output = model_distributed(imageDataset.images)
            output_list.append(output)
        decoded_output_list = decoder(output_list)
        output = torch.cat(decoded_output_list, dim=3)
        output = output.view(output.size(0), -1)

        loss = criterion(output, label.view(label.size(0), -1))

        # backward
        optimizer_encoder.zero_grad()
        optimizer_decoder.zero_grad()
        loss.backward()
        optimizer_encoder.step()
        optimizer_decoder.step()

        train_loader_tqdm.set_postfix(loss=loss.item())

    print(f"Epoch [{epoch+1}/{epoch_num}], Loss: {loss.item()}")

Epoch 1/10: 100%|██████████| 938/938 [00:11<00:00, 84.65it/s, loss=0.529]


Epoch [1/10], Loss: 0.5289015769958496


Epoch 2/10: 100%|██████████| 938/938 [00:11<00:00, 84.87it/s, loss=0.36] 


Epoch [2/10], Loss: 0.3596523106098175


Epoch 3/10: 100%|██████████| 938/938 [00:10<00:00, 85.74it/s, loss=0.273]


Epoch [3/10], Loss: 0.27286091446876526


Epoch 4/10: 100%|██████████| 938/938 [00:10<00:00, 85.64it/s, loss=0.223]


Epoch [4/10], Loss: 0.22265571355819702


Epoch 5/10: 100%|██████████| 938/938 [00:10<00:00, 87.37it/s, loss=0.188] 


Epoch [5/10], Loss: 0.18844646215438843


Epoch 6/10: 100%|██████████| 938/938 [00:10<00:00, 85.39it/s, loss=0.162]


Epoch [6/10], Loss: 0.16204270720481873


Epoch 7/10: 100%|██████████| 938/938 [00:10<00:00, 87.26it/s, loss=0.141]


Epoch [7/10], Loss: 0.14146225154399872


Epoch 8/10: 100%|██████████| 938/938 [00:10<00:00, 91.70it/s, loss=0.127] 


Epoch [8/10], Loss: 0.12700247764587402


Epoch 9/10: 100%|██████████| 938/938 [00:09<00:00, 95.37it/s, loss=0.115]  


Epoch [9/10], Loss: 0.11510047316551208


Epoch 10/10: 100%|██████████| 938/938 [00:09<00:00, 94.13it/s, loss=0.104] 

Epoch [10/10], Loss: 0.10435627400875092





In [45]:
import matplotlib.pyplot as plt

y = [e for l in loss_list for e in l]
print(f"记录的loss数量: {len(y)}")
print(f"最后一个loss: {y[-1]}")

plt.figure()
plt.plot(y)
plt.show()

NameError: name 'loss_list' is not defined

save encoder and decoder

In [118]:
now = datetime.datetime.now()
date = now.strftime("%Y_%m_%d")

encoder_path = (
    f"encoder/MLP/MNIST/{date}/"
    + f"encoder-"
    + f"task_MNIST-basemodel_LeNet5-"
    + f"in_{encoder.num_in}-out_{encoder.num_out}.pth"
)
dirpath = os.path.dirname(encoder_path)
if not os.path.exists(dirpath):
    os.makedirs(dirpath)
torch.save(encoder.state_dict(), encoder_path)

decoder_path = (
    f"decoder/MLP/MNIST/{date}/"
    + f"decoder-"
    + f"task_MNIST-basemodel_LeNet5-"
    + f"in_{decoder.num_in}-out_{decoder.num_out}.pth"
)
dirpath = os.path.dirname(decoder_path)
if not os.path.exists(dirpath):
    os.makedirs(dirpath)
torch.save(decoder.state_dict(), decoder_path)

evaluation

In [49]:
# encoder.eval()
# decoder.eval()

# images_list = [images.to(device) for images in train_datasets.images_list]

# imageDataset_list = [
#     ImageDataset(images) for images in images_list + encoder(images_list)
# ]

# output_list = []

# # inference on N devices
# for i in range(N):
#     imageDataset = imageDataset_list[i]

#     test_loader = DataLoader(imageDataset, batch_size=64, shuffle=False)
#     test_loader_tqdm = tqdm(test_loader, desc=f"model_distributed {i}")

#     output = torch.tensor([]).to(device)
#     with torch.no_grad():
#         for images in test_loader_tqdm:
#             images = images.to(device)
#             output = torch.cat((output, model_distributed(images)), dim=0)

#     output_list.append(output)


# def lose_something(output_list, lose_index=None):
#     if lose_index is None or len(lose_index) == 0:
#         return output_list

#     losed_output_list = []

#     for i in range(len(output_list)):

#         if i in lose_index:

#             losed_output_list.append(torch.zeros_like(output_list[i]))
#         else:

#             losed_output_list.append(output_list[i])
#     return losed_output_list


# losed_output_list = lose_something(output_list)
# decoded_output_list = decoder(losed_output_list)

# output = torch.cat(decoded_output_list, dim=3)
# output = output.view(output.size(0), -1)
# data_size = output.size(0)
# labels = train_datasets.labels
# labels = torch.tensor(labels).to(device)
# total = data_size
# correct = 0
# for i in tqdm(range(data_size)):
#     _, predicted = torch.max(model_fc(output[i]).data, 0)
#     correct += (predicted == labels[i]).sum().item()

# print(f"Accuracy on the Test set: {100 * correct / total}%")

model_distributed 0: 100%|██████████| 938/938 [00:02<00:00, 431.98it/s]
model_distributed 1: 100%|██████████| 938/938 [00:00<00:00, 1101.48it/s]
model_distributed 2: 100%|██████████| 938/938 [00:00<00:00, 1082.28it/s]


ValueError: only one element tensors can be converted to Python scalars

In [48]:
from dataset.splited_dataset import SplitedTestDataset

test_datasets = SplitedTestDataset()
test_datasets.load(f"./data/MNIST/split/{K}/split_test_datasets.pt")

<dataset.splited_dataset.SplitedTestDataset at 0x1f6d8106290>

In [51]:
encoder.eval()
decoder.eval()

images_list = [images.to(device) for images in test_datasets.images_list]

imageDataset_list = [
    ImageDataset(images) for images in images_list + encoder(images_list)
]

output_list = []

# inference on N devices
for i in range(N):
    imageDataset = imageDataset_list[i]

    test_loader = DataLoader(imageDataset, batch_size=64, shuffle=False)
    test_loader_tqdm = tqdm(test_loader, desc=f"model_distributed {i}")

    output = torch.tensor([]).to(device)
    with torch.no_grad():
        for images in test_loader_tqdm:
            images = images.to(device)
            output = torch.cat((output, model_distributed(images)), dim=0)

    output_list.append(output)


def lose_something(output_list, lose_index = None):
    if lose_index is None or len(lose_index) == 0:
        return output_list

    losed_output_list = []

    for i in range(len(output_list)):

        if i in lose_index:

            losed_output_list.append(torch.zeros_like(output_list[i]))
        else:

            losed_output_list.append(output_list[i])
    return losed_output_list


losed_output_list = lose_something(output_list)
decoded_output_list = decoder(losed_output_list)

output = torch.cat(decoded_output_list, dim=3)
output = output.view(output.size(0), -1)
data_size = output.size(0)
labels = test_datasets.labels
labels = torch.tensor(labels).to(device)
total = data_size
correct = 0
for i in tqdm(range(data_size)):
    _, predicted = torch.max(model_fc(output[i]).data, 0)
    correct += (predicted == labels[i]).sum().item()

print(f"Accuracy on the Test set: {100 * correct / total}%")

model_distributed 0: 100%|██████████| 157/157 [00:00<00:00, 548.83it/s]
model_distributed 1: 100%|██████████| 157/157 [00:00<00:00, 835.11it/s]
model_distributed 2: 100%|██████████| 157/157 [00:00<00:00, 980.20it/s] 
100%|██████████| 10000/10000 [00:07<00:00, 1256.24it/s]

Accuracy on the Test set: 99.03%



