In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from tqdm import tqdm
import datetime
import os

from base_models.LeNet5 import LeNet5
from util.dataset import PartitionedDataset, ImageDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

filepath = "base_models/LeNet5/MNIST/2023_11_28/model.pth"

# Build the full LeNet5 and restore its trained weights. map_location remaps
# the stored tensors onto `device`, so a checkpoint saved during a CUDA run can
# still be loaded on a CPU-only machine (the fallback chosen above).
model = LeNet5(input_dim=(1, 28, 28), num_classes=10)
model.load_state_dict(torch.load(filepath, map_location=device))
print("model loaded successfully")

model.to(device)
print(f"model is on device: {next(model.parameters()).device}")

# Split the network: the conv segment is what each partition/device runs,
# the fc segment consumes the re-assembled features. Both are put in eval mode.
model_distributed = model.get_conv_segment().eval()
print(
    f"model_distributed: device: {next(model_distributed.parameters()).device}, mode: {'training' if model_distributed.training else 'eval'}"
)

model_fc = model.get_fc_segment().eval()
print(
    f"model_fc: device: {next(model_fc.parameters()).device}, mode: {'training' if model_fc.training else 'eval'}"
)

print("-" * 20)

# load data
# K = number of pieces each test image was partitioned into (one per simulated
# device). It must match the partition directory the data is loaded from.
K = 4
# The file presumably holds a pickled project dataset object (not plain
# tensors), so full unpickling is needed: torch>=2.6 defaults to
# weights_only=True and would refuse it. Only load trusted, locally generated
# files this way (the weights_only argument requires torch>=1.13).
test_datasets = torch.load(
    f"data/MNIST/partition/{K}/partitioned_test_datasets.pt",
    weights_only=False,
)

# Each sample is (list_of_partition_images, label).
datasets_num = len(test_datasets[0][0])  # partitions per sample
data_size = len(test_datasets)  # number of samples
img_size = test_datasets[0][0][0].size()  # shape of one partition image

print(f"type of dataset: {type(test_datasets)}")
print(f"test_datasets num: {datasets_num}, which shold be equal to K = {K}")
print(f"constructure: ([images, images, ...], labels)")
print(f"datasize: len(images) = len(labels) = {data_size}")
print(f"image size: {img_size}")

print("-" * 20)

# Fail fast instead of merely printing the expectation: a mismatch means the
# loaded data does not correspond to K devices.
if datasets_num != K:
    raise ValueError(
        f"expected {K} partitions per sample, but the dataset has {datasets_num}"
    )

# Regroup samples by partition: bucket p collects the p-th image piece of every
# sample, in dataset order, so that each bucket can be served to one device.
partition_buckets = [[] for _ in range(datasets_num)]
for pieces, _label in test_datasets:
    for part_idx, bucket in enumerate(partition_buckets):
        bucket.append(pieces[part_idx])

# Stack every bucket into one tensor on the target device, then wrap each
# stacked tensor as its own dataset.
images_list = [torch.stack(bucket).to(device) for bucket in partition_buckets]
imageDataset_list = list(map(ImageDataset, images_list))

output_list = []

# inference on K devices: run the conv segment independently on each
# partition's images, as each simulated device would on its own slice.
# Iterating every partition (rather than range(K)) is required anyway, since all
# partial feature maps are needed to rebuild the full one below.
for i, imageDataset in enumerate(imageDataset_list):
    test_loader = DataLoader(imageDataset, batch_size=64, shuffle=False)
    test_loader_tqdm = tqdm(test_loader, desc=f"model_distributed {i}")

    # Gather per-batch outputs and concatenate once at the end: growing a
    # tensor with torch.cat inside the loop re-copies all previous rows on
    # every batch (quadratic) and relied on cat's legacy empty-1-D special case.
    batch_outputs = []
    with torch.no_grad():
        for img in test_loader_tqdm:
            img = img.to(device)
            batch_outputs.append(model_distributed(img))

    output_list.append(torch.cat(batch_outputs, dim=0))

# Stitch the partial feature maps together along dim 3, then flatten each sample
# to a vector for the fc segment.
# NOTE(review): assumes conv outputs are (N, C, H, W) and images were partitioned
# along width — confirm against the partitioning script.
output = torch.cat(output_list, dim=3)
output = output.view(output.size(0), -1)

print("-" * 20)

data_size = output.size(0)

# Ground-truth labels, gathered with the same (images, label) iteration used to
# build the partition tensors, so they are index-aligned with `output`.
labels = torch.tensor([int(label) for _, label in test_datasets], device=device)
if labels.size(0) != data_size:
    raise ValueError(
        f"got {labels.size(0)} labels for {data_size} feature vectors"
    )

total = data_size
correct = 0
# Batched, gradient-free classification: the per-sample loop without no_grad
# built an autograd graph for every image and was very slow.
# NOTE(review): assumes model_fc accepts a (batch, features) input, as
# nn.Linear-based heads do — confirm if the fc segment reshapes internally.
fc_batch_size = 256
with torch.no_grad():
    for start in tqdm(range(0, data_size, fc_batch_size)):
        stop = start + fc_batch_size
        # argmax over the class dimension; like torch.max, ties resolve to the
        # first maximal index.
        predicted = model_fc(output[start:stop]).argmax(dim=1)
        correct += (predicted == labels[start:stop]).sum().item()

print(f"Accuracy on the Test set: {100 * correct / total}%")