In [10]:
# Enable IPython's autoreload extension in mode 2 so that edits to the imported
# project modules are picked up without restarting the kernel.
# Re-running this cell in a live kernel prints an "already loaded" notice,
# which is harmless.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [28]:
# Hyper-parameters of the coded distributed-inference setup.
#   K: number of image partitions per sample (passed to the encoder as num_in)
#   R: number of additional encoded inputs (passed to the encoder as num_out)
#   N: total number of distributed devices, N = K + R
K, R = 2, 1
N = K + R

print(f"K: {K}, R: {R}, N: {N}")

K: 2, R: 1, N: 3


In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from tqdm import tqdm
import datetime
import os

from base_models.LeNet5 import LeNet5
from util.dataset import PartitionedDataset, ImageDataset
from encoder.mlp_encoder import MLPEncoder
from decoder.mlp_decoder import MLPDecoder
from data_processing.data_partition import data_partition

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Device: {device}")

Device: cuda


In [13]:
base_model_path = "base_models/LeNet5/MNIST/2023_11_28/model.pth"

# Rebuild the LeNet5 architecture, then restore the trained weights.
# map_location=device remaps the checkpoint's tensors onto the selected device,
# so a checkpoint saved from GPU tensors still loads on a CPU-only machine.
model = LeNet5(input_dim=(1, 28, 28), num_classes=10)
model.load_state_dict(torch.load(base_model_path, map_location=device))
print("model loaded successfully")

model.to(device)
print(f"model is on device: {next(model.parameters()).device}")


# Split the network into the two segments used later, both switched to eval
# mode: the convolutional segment is run once per device in the distributed
# inference cell, and the fully connected segment produces the class scores
# in the accuracy cell.
model_distributed = model.get_conv_segment().eval()
print(
    f"model_distributed: device: {next(model_distributed.parameters()).device}, mode: {'training' if model_distributed.training else 'eval'}"
)

model_fc = model.get_fc_segment().eval()
print(
    f"model_fc: device: {next(model_fc.parameters()).device}, mode: {'training' if model_fc.training else 'eval'}"
)

model loaded successfully
model is on device: cuda:0
model_distributed: device: cuda:0, mode: eval
model_fc: device: cuda:0, mode: eval


data partition

In [14]:
# One-off preprocessing step, kept commented out. It calls data_partition with
# the output path ./data/MNIST/partition/{K}/partitioned_test_datasets.pt, which
# is the same file the "load data" cell reads. Uncomment and run it once if that
# file does not exist yet for the current K.
# transform = transforms.Compose(
#     [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
# )

# dataset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)
# print(f"Test dataset: {len(dataset)}")
# print("image size: ", dataset[0][0].size())
# data_partition(
#     model.get_conv_segment(),
#     dataset,
#     K,
#     f"./data/MNIST/partition/{K}/partitioned_test_datasets.pt",
# )

load data

In [17]:
# load data
# The file holds a pickled PartitionedDataset object, so util.dataset must be
# importable when it is loaded. torch.load unpickles arbitrary objects: only
# load partition files you generated yourself.
test_datasets = torch.load(f"data/MNIST/partition/{K}/partitioned_test_datasets.pt")

# Each sample is ([images, images, ...], conv_segment_labels, labels); the
# length of the image list is the number of partitions per sample.
datasets_num = len(test_datasets[0][0])
data_size = len(test_datasets)
img_size = test_datasets[0][0][0].size()

# Fail fast on a partition file generated for a different K: the encoder and
# decoder below are built from K, so a mismatch would only surface later as a
# confusing shape error.
assert datasets_num == K, (
    f"partition file holds {datasets_num} image partitions per sample, expected K = {K}"
)

print(f"type of dataset: {type(test_datasets)}")
print(f"test_datasets num: {datasets_num}, which shold be equal to K = {K}")
print(f"constructure: ([images, images, ...], conv_segment_labels, labels)")
print(f"datasize: {data_size}")
print(f"image size: {img_size}")

type of dataset: <class 'util.dataset.PartitionedDataset'>
test_datasets num: 2, which shold be equal to K = 2
constructure: ([images, images, ...], conv_segment_labels, labels)
datasize: 10000
image size: torch.Size([1, 28, 20])


encode

In [39]:
encoder_path = "encoder/MLP/MNIST/2023_11_29/model.pth"

# Restore the trained encoder. map_location=device lets a checkpoint saved
# from GPU tensors load on a CPU-only machine.
encoder = MLPEncoder(num_in=K, num_out=R, in_dim=tuple(img_size))
encoder.load_state_dict(torch.load(encoder_path, map_location=device))
encoder.to(device)
# Inference only: use eval mode, matching how the base model segments are
# prepared. This matters if the encoder contains dropout or batch-norm layers.
encoder.eval()

# Regroup the per-sample image tuples into one list per partition, so that
# images_list[i] collects partition i of every test sample.
images_list = [[] for _ in range(datasets_num)]
for images, _, _ in test_datasets:
    for i in range(datasets_num):
        images_list[i].append(images[i])

# One stacked tensor per partition, moved to the compute device.
images_list = [torch.stack(images).to(device) for images in images_list]

# The encoder processes the whole test set in one call; no_grad avoids keeping
# an autograd graph for it, which is not needed for inference.
with torch.no_grad():
    encoded_images_list = encoder(images_list)

# The original partitions followed by the encoded ones. The inference cell
# reads the first N entries, one per device.
images_list = images_list + encoded_images_list

imageDataset_list = [ImageDataset(images) for images in images_list]

distributed inference

In [51]:
output_list = []

# inference on N devices
# Every "device" runs the same convolutional segment on its own input tensor
# (an original partition or an encoded one).
for i in range(N):
    imageDataset = imageDataset_list[i]

    test_loader = DataLoader(imageDataset, batch_size=64, shuffle=False)
    test_loader_tqdm = tqdm(test_loader, desc=f"model_distributed {i}")

    # Collect the per-batch outputs and concatenate once at the end.
    # Concatenating inside the loop would re-copy everything accumulated so
    # far on every batch.
    batch_outputs = []
    with torch.no_grad():
        for images in test_loader_tqdm:
            images = images.to(device)
            batch_outputs.append(model_distributed(images))

    if batch_outputs:
        output = torch.cat(batch_outputs, dim=0)
    else:
        # Empty dataset: keep the empty-tensor result instead of failing in cat.
        output = torch.tensor([]).to(device)

    output_list.append(output)

model_distributed 0: 100%|██████████| 157/157 [00:00<00:00, 451.04it/s]
model_distributed 1: 100%|██████████| 157/157 [00:00<00:00, 616.08it/s]
model_distributed 2: 100%|██████████| 157/157 [00:00<00:00, 625.50it/s]


decode

In [52]:
decoder_path = "decoder/MLP/MNIST/2023_11_29/model.pth"

# Restore the trained decoder. map_location=device lets a checkpoint saved
# from GPU tensors load on a CPU-only machine.
decoder = MLPDecoder(num_in=N, num_out=K, in_dim=tuple(output_list[0][0].size()))
decoder.load_state_dict(torch.load(decoder_path, map_location=device))
decoder.to(device)
# Inference only: use eval mode, matching how the base model segments are
# prepared. This matters if the decoder contains dropout or batch-norm layers.
decoder.eval()

# Feed the outputs of all N devices to the decoder. Copying the list keeps this
# correct for any N instead of hard-coding three entries.
losed_output_list = list(output_list)
# To simulate a lost device, replace its output with zeros, e.g. for device 0:
# losed_output_list = [torch.zeros_like(output_list[0]), output_list[1], output_list[2]]

# no_grad: decoding the full test set is inference only, so no autograd graph
# needs to be kept.
with torch.no_grad():
    decoded_output_list = decoder(losed_output_list)

calculate accuracy

In [53]:
# Join the decoded partitions along dim 3 (the last axis of the decoded
# 4-D outputs), then flatten each sample to a vector for the fully connected
# segment.
output = torch.cat(decoded_output_list, dim=3)
output = output.view(output.size(0), -1)

data_size = output.size(0)

labels = test_datasets.labels
labels = torch.tensor(labels).to(device)

total = data_size
correct = 0
# Evaluation only: run without autograd so no graph is built for each of the
# per-sample forward passes.
with torch.no_grad():
    for i in tqdm(range(data_size)):
        # model_fc receives one flattened sample; dim 0 of its output is the
        # class axis, so the index of the maximum is the predicted class.
        _, predicted = torch.max(model_fc(output[i]), 0)
        correct += (predicted == labels[i]).sum().item()

print(f"Accuracy on the Test set: {100 * correct / total}%")

100%|██████████| 10000/10000 [00:13<00:00, 762.34it/s]

Accuracy on the Test set: 98.79%





In [None]:
def lose_something(output_list, lose_index):
    """Simulate lost device outputs by replacing them with zeros.

    Args:
        output_list: Sequence of tensors, one per device.
        lose_index: Index, or container of indices, of the entries in
            ``output_list`` to treat as lost. Membership is tested with ``in``.

    Returns:
        A new list with the same length as ``output_list``. Entries whose
        index is in ``lose_index`` are zero tensors built with
        ``torch.zeros_like``; every other entry is the original tensor object
        (not a copy). ``output_list`` itself is not modified.
    """
    # Accept a bare index as shorthand for a one-element collection; a plain
    # int would otherwise raise TypeError in the membership test below.
    if isinstance(lose_index, int):
        lose_index = (lose_index,)

    return [
        torch.zeros_like(output) if i in lose_index else output
        for i, output in enumerate(output_list)
    ]