In [None]:
import os

# Run from the repository root so the relative paths below (checkpoints/, data/)
# and the `src.*` imports resolve. Guarded on the `src/` package being visible so
# that re-running this cell does not keep climbing up the directory tree.
if not os.path.isdir("src"):
    os.chdir("..")
print(os.getcwd())

In [None]:
# --- Paths, relative to the working directory chosen in the setup cell ---

# Trained checkpoint under evaluation.
MODEL_PATH = "checkpoints/model_nrg01.hpc.itc.rwth-aachen.de_20231123-201621_mobilenet_v3_large_0_unfreezed_20_epochs.pth"

# Dataset split files (train / test).
TSV_PATH = "data/kaggle_dataset/train.tsv"
TEST_TSV_PATH = "data/kaggle_dataset/test.tsv"

# Single image used for ad-hoc demo predictions.
IMAGE_PATH = "notebooks/demo_images/1.png"

In [None]:
from PIL import Image
import torch
from torchvision import transforms
import numpy as np
import random

In [None]:
from src.models.model import get_model
import src.kaggle50k_dataset as kaggle50k_dataset
from src.test import test

In [None]:
def load_model(model_path, num_classes=None, map_location="cpu"):
    """Load a trained classifier checkpoint and return it ready for inference.

    Args:
        model_path: Path to a checkpoint dict containing the keys
            "model_state_dict", "idx_to_labels" and "config".
        num_classes: Number of output classes for the model head. Defaults to
            the size of the checkpoint's "idx_to_labels" mapping, so the head
            always matches the labels stored with the weights.
        map_location: Device to map the stored tensors onto (CPU by default,
            so GPU-trained checkpoints load on machines without CUDA).

    Returns:
        A tuple (model, idx_to_label, config) where `model` is in eval mode.
    """
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    checkpoint = torch.load(model_path, map_location=map_location)
    idx_to_label = checkpoint["idx_to_labels"]
    config = checkpoint["config"]

    if num_classes is None:
        num_classes = len(idx_to_label)

    model = get_model(model_name=config["model"]["name"], num_classes=num_classes)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    return model, idx_to_label, config

def get_idx_to_label(tsv_path):
    """Build an index -> label mapping from a split file of image paths.

    Each non-blank line is treated as a path; its label is the directory part
    of that path (everything before the final path component).

    Args:
        tsv_path: Path to the split file, one image path per line.

    Returns:
        dict mapping consecutive ints starting at 0 to the distinct labels,
        assigned in sorted label order.
    """
    # `with` guarantees the file handle is closed, even on error.
    with open(tsv_path, "r") as tsv_file:
        labels = {
            os.path.split(line.strip())[0]
            for line in tsv_file
            if line.strip()  # skip empty and whitespace-only lines
        }
    # Sorting makes the index assignment deterministic: iterating a set of str
    # depends on hash randomization and would change from run to run.
    return dict(enumerate(sorted(labels)))

def load_img(image_path, size=(1536, 662)):
    """Load an image from disk and apply the evaluation preprocessing.

    Args:
        image_path: Path to an image file readable by PIL.
        size: Target size passed to `transforms.Resize` (height, width).

    Returns:
        Float tensor of shape (3, H, W) with values scaled to [-1, 1]
        (ToTensor maps to [0, 1], then Normalize with mean=std=0.5).
    """
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # The context manager closes the file handle deterministically; convert()
    # returns a fully loaded copy that stays valid after the file is closed.
    with Image.open(image_path) as image:
        rgb_image = image.convert("RGB")
    return transform(rgb_image)

def get_random_image(dataset, rng=None):
    """Draw one uniformly random (image, label) sample from `dataset`.

    Args:
        dataset: Any object supporting len() and integer indexing whose items
            unpack into an (image, label) pair.
        rng: Optional `numpy.random.Generator` for reproducible draws. When
            None, the global NumPy RNG is used (seed it with np.random.seed).

    Returns:
        The (image, label) pair at the sampled index.

    Raises:
        ValueError: If `dataset` is empty.
    """
    n_samples = len(dataset)
    if n_samples == 0:
        raise ValueError("Cannot sample from an empty dataset")

    if rng is None:
        idx = np.random.randint(0, n_samples)
    else:
        idx = int(rng.integers(0, n_samples))

    image, label = dataset[idx]
    return image, label

In [None]:
model, idx_to_label, config = load_model(MODEL_PATH)
label_to_idx = {label: idx for idx, label in idx_to_label.items()}

# The perturbations below are stochastic; seed every RNG so reruns are comparable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Same preprocessing as load_img: resize, to [0, 1] tensor, then scale to [-1, 1].
INPUT_SIZE = (1536, 662)
NORMALIZE = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

# Perturbation strengths. NOTE(review): initial choices -- tune for the study.
OCCLUSION_SCALE = (0.02, 0.2)  # fraction of the image area erased
NOISE_STD = 0.1                # std of additive Gaussian noise on [0, 1] pixels
BRIGHTNESS_JITTER = 0.5        # brightness factor drawn from [0.5, 1.5]


def add_gaussian_noise(tensor, std=NOISE_STD):
    """Add zero-mean Gaussian noise to a [0, 1] image tensor, clipped back to [0, 1]."""
    return (tensor + std * torch.randn_like(tensor)).clamp(0.0, 1.0)


base_transforms = transforms.Compose([
    transforms.Resize(INPUT_SIZE),
    transforms.ToTensor(),
    NORMALIZE,
])

# Occlusion: erase one random rectangle to black (applied before normalization).
occlusion_transforms = transforms.Compose([
    transforms.Resize(INPUT_SIZE),
    transforms.ToTensor(),
    transforms.RandomErasing(p=1.0, scale=OCCLUSION_SCALE, value=0),
    NORMALIZE,
])

# Noise: additive Gaussian pixel noise (applied before normalization).
noise_transforms = transforms.Compose([
    transforms.Resize(INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Lambda(add_gaussian_noise),
    NORMALIZE,
])

# Illumination: random brightness change on the PIL image.
illumination_transforms = transforms.Compose([
    transforms.Resize(INPUT_SIZE),
    transforms.ColorJitter(brightness=BRIGHTNESS_JITTER),
    transforms.ToTensor(),
    NORMALIZE,
])


def build_test_dataset(transform):
    """Create the test split with `transform` and the checkpoint's label mapping."""
    dataset = kaggle50k_dataset.Kaggle50K(TEST_TSV_PATH, transform)
    # Presumably Kaggle50K uses labels_to_idx to encode targets; overriding it
    # keeps targets aligned with the class indices stored in the checkpoint.
    dataset.labels_to_idx = label_to_idx
    return dataset


base_dataset = build_test_dataset(base_transforms)
occlusion_dataset = build_test_dataset(occlusion_transforms)
noise_dataset = build_test_dataset(noise_transforms)
illumination_dataset = build_test_dataset(illumination_transforms)

top_k = [1, 3, 5]

base_accuracies = test(model, base_dataset, top_k)
occlusion_accuracies = test(model, occlusion_dataset, top_k)
noise_accuracies = test(model, noise_dataset, top_k)
illumination_accuracies = test(model, illumination_dataset, top_k)

# Last expression: display all results side by side.
{
    "base": base_accuracies,
    "occlusion": occlusion_accuracies,
    "noise": noise_accuracies,
    "illumination": illumination_accuracies,
}