In [1]:
%load_ext autoreload
%autoreload 2

import os

print("original dir: ", os.getcwd())

new_path = "../"
os.chdir(new_path)

print("changed dir: ", os.getcwd())

original dir:  f:\MyCourse(5 delayed 1)\erasure code\non-linear erasure code\src\test
changed dir:  f:\MyCourse(5 delayed 1)\erasure code\non-linear erasure code\src


In [2]:
# hyper-parameter

# N = K + R, N is the distributed device number
K = 4
R = 1
N = K + R

TASK = "MNIST"
BASE_MODEL = "LeNet5"
DATE = "2023_12_01"

print(f"K: {K}, R: {R}, N: {N}")
print(f"TASK: {TASK}, BASE_MODEL: {BASE_MODEL}, DATE: {DATE}")

K: 4, R: 1, N: 5
TASK: MNIST, BASE_MODEL: LeNet5, DATE: 2023_12_01


In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from tqdm import tqdm
import datetime
import os
from typing import Any, List, Tuple, Dict, Optional

from base_model.LeNet5 import LeNet5
from dataset.splited_dataset import SplitedTestDataset
from dataset.image_dataset import ImageDataset
from encoder.mlp_encoder import MLPEncoder
from decoder.mlp_decoder import MLPDecoder
from util.split_data import split_data

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

Device: cuda


In [5]:
base_model_path = "base_model/LeNet5/MNIST/2023_11_28/model.pth"

model = LeNet5(input_dim=(1, 28, 28), num_classes=10)
model.load_state_dict(torch.load(base_model_path))
print("model loaded successfully")

model_distributed = model.get_conv_segment()
model_fc = model.get_fc_segment().eval()

model_distributed.to(device)
model_fc.to(device)

model_distributed.eval()
model_fc.eval()

print(model_distributed)
print(model_fc)

model loaded successfully
Sequential(
  (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (1): ReLU()
  (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (4): ReLU()
  (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
Sequential(
  (0): Linear(in_features=256, out_features=120, bias=True)
  (1): ReLU()
  (2): Linear(in_features=120, out_features=84, bias=True)
  (3): ReLU()
  (4): Sequential(
    (0): Linear(in_features=84, out_features=10, bias=True)
  )
)


load data

In [6]:
# load data
test_datasets = SplitedTestDataset()
test_datasets.load(f"./data/MNIST/split/{K}/split_test_datasets.pt")
data_shape = test_datasets.data_shape
data_num = test_datasets.data_num
split_num = test_datasets.split_num
print(test_datasets.describe())

Splited Test Dataset: split_num=4, data_num=10000, data_shape=(1, 28, 16)


encode

In [7]:
encoder_path = f"encoder/MLP/MNIST/{DATE}/encoder-task_{TASK}-basemodel_{BASE_MODEL}-in_{K}-out_{R}.pth"


encoder = MLPEncoder(num_in=K, num_out=R, in_dim=tuple(data_shape))

encoder.load_state_dict(torch.load(encoder_path))

encoder.to(device)
encoder.eval()
images_list = [images.to(device) for images in test_datasets.images_list]

imageDataset_list = [
    ImageDataset(images) for images in images_list + encoder(images_list)
]

ditributed inference

In [8]:
output_list = []

# inference on N devices
for i in range(N):
    imageDataset = imageDataset_list[i]

    test_loader = DataLoader(imageDataset, batch_size=64, shuffle=False)
    test_loader_tqdm = tqdm(test_loader, desc=f"model_distributed {i}")

    output = torch.tensor([]).to(device)
    with torch.no_grad():
        for images in test_loader_tqdm:
            images = images.to(device)
            output = torch.cat((output, model_distributed(images)), dim=0)

    output_list.append(output)

model_distributed 0: 100%|██████████| 157/157 [00:00<00:00, 359.68it/s]
model_distributed 1: 100%|██████████| 157/157 [00:00<00:00, 1375.61it/s]
model_distributed 2: 100%|██████████| 157/157 [00:00<00:00, 1193.52it/s]
model_distributed 3: 100%|██████████| 157/157 [00:00<00:00, 1084.41it/s]
model_distributed 4: 100%|██████████| 157/157 [00:00<00:00, 1181.85it/s]


decode

In [9]:
def lose_something(
    output_list: List[torch.Tensor], lose_index: Optional[Tuple[int]] = None
) -> List[torch.Tensor]:
    if lose_index is None or len(lose_index) == 0:
        return output_list

    losed_output_list = []

    for i in range(len(output_list)):

        if i in lose_index:

            losed_output_list.append(torch.zeros_like(output_list[i]))
        else:

            losed_output_list.append(output_list[i])
    return losed_output_list

In [10]:
decoder_path = f"decoder/MLP/MNIST/{DATE}/decoder-task_{TASK}-basemodel_{BASE_MODEL}-in_{N}-out_{K}.pth"
decoder = MLPDecoder(num_in=N, num_out=K, in_dim=tuple(output_list[0][0].size()))
decoder.load_state_dict(torch.load(decoder_path))
decoder.to(device)
decoder.eval()
losed_output_list = lose_something(output_list)
decoded_output_list = decoder(losed_output_list)

calculate accuracy

In [11]:
output = torch.cat(decoded_output_list, dim=3)
output = output.view(output.size(0), -1)
data_size = output.size(0)
labels = test_datasets.labels
labels = torch.tensor(labels).to(device)
total = data_size
correct = 0
for i in tqdm(range(data_size)):
    _, predicted = torch.max(model_fc(output[i]).data, 0)
    correct += (predicted == labels[i]).sum().item()

print(f"Accuracy on the Test set: {100 * correct / total}%")

100%|██████████| 10000/10000 [00:07<00:00, 1353.81it/s]

Accuracy on the Test set: 9.77%



