# Testing the model trained on CREDO data using CONNIE data

In [1]:
%run ./notebook_init.py

import os
import torch
from shutil import rmtree
from torchvision import transforms

from core import DATA_FOLDER

from scripts.credo_training_utils import PREDICTION_FOLDERPATH,\
    TRAINING_FOLDERPATH, ImageFolderWithPath,\
    Seed, resnet18_model, predict_model

In [None]:
# Use the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Size every image is resized to before being fed to the model.
img_size = (60, 60)

# Preprocessing applied to each image of the dataset built from it.
data_transforms = transforms.Compose(
    [
        transforms.Resize(img_size),
        transforms.ToTensor(),
        # mean=0 / std=1: (x - 0) / 1 leaves the tensor values unchanged.
        transforms.Normalize(0, 1),
    ]
)

In [None]:
# Folder holding the cropped CONNIE images, located under DATA_FOLDER.
connie_cropped_dataset = os.path.join(
    DATA_FOLDER,
    "connie_cropped_dataset",
)

# Dataset over that folder; each image goes through `data_transforms`.
connie_test_dataset = ImageFolderWithPath(
    connie_cropped_dataset,
    data_transforms,
)

In [None]:
# Project helper whose `seed_worker` method is handed to the DataLoader as
# `worker_init_fn`.
# NOTE(review): presumably this also fixes the global RNG seeds for
# reproducibility — confirm in scripts.credo_training_utils.
seed = Seed()

In [None]:
# Batched loader over the CONNIE dataset, later passed to `predict_model`.
# NOTE(review): shuffle=True randomises the batch order even though this
# loader is only used for prediction — confirm this is intended.
connie_test_loader = torch.utils.data.DataLoader(
    connie_test_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=2,
    # Each of the worker processes is initialised through the Seed helper.
    worker_init_fn=seed.seed_worker,
)

In [None]:
# Class labels handed to `predict_model`; one output sub-folder is prepared
# per label below.
class_names = ["artefacts", "dot", "line", "worm"]

# Derived from the label list so the model's class count cannot drift out of
# sync with `class_names` when a label is added or removed.
class_qty = len(class_names)

# Weights file loaded into the model for this prediction run.
best_model_filepath = os.path.join(
    TRAINING_FOLDERPATH, "best_model_weight", "best_model_params.pt"
)

# Root folder that receives the CONNIE prediction results.
connie_prediction_output = os.path.join(
    PREDICTION_FOLDERPATH, "connie_prediction"
)
os.makedirs(connie_prediction_output, exist_ok=True)

# Start every run from empty per-class folders: anything left over from a
# previous run is deleted before the folder is recreated.
for curr_class in class_names:
    curr_class_path = os.path.join(connie_prediction_output, curr_class)
    if os.path.exists(curr_class_path):
        rmtree(curr_class_path)
    os.makedirs(curr_class_path, exist_ok=True)

In [None]:
# Build the network for `class_qty` classes on the selected device.
# NOTE(review): presumably a ResNet-18 with a replaced classification head —
# confirm in scripts.credo_training_utils.
saved_model = resnet18_model(device, class_qty)

# `map_location=device` remaps the stored tensors onto the device chosen
# above. Without it, a checkpoint saved from CUDA tensors cannot be
# deserialized on a CPU-only machine, which defeats the CPU fallback.
saved_model.load_state_dict(
    torch.load(best_model_filepath, map_location=device)
)

In [None]:
# Run the loaded model over the CONNIE loader. The return value is bound to
# `_` so the notebook does not echo it as cell output.
_ = predict_model(
    device,
    saved_model,
    class_names,
    connie_test_loader,
    connie_prediction_output,
)