In [None]:
import torch
from torchvision import models
import torch.nn as nn
import os
import albumentations as A
from albumentations.pytorch import ToTensorV2

import dataset
import reasons

In [None]:
# Rebuild the architecture used at training time: a ResNet-18 whose final fully
# connected layer outputs NUM_CLASSES logits, then restore the trained weights.
# No ImageNet weights are requested: load_state_dict (strict by default) replaces
# every parameter and buffer, so downloading them would be wasted work. Calling
# resnet18() with no arguments means "no pretrained weights" in all torchvision
# versions and avoids the deprecated `pretrained=` keyword.
NUM_CLASSES = 3  # NOTE(review): must equal the number of classes the checkpoint was trained on
CHECKPOINT_PATH = 'trained_model.pth'

model_loaded = models.resnet18()
num_ftrs = model_loaded.fc.in_features
model_loaded.fc = nn.Linear(num_ftrs, NUM_CLASSES)
# map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
model_loaded.load_state_dict(torch.load(CHECKPOINT_PATH, map_location='cpu'))

In [None]:
# Put the network in inference mode (identical to .eval()). BatchNorm layers then
# use their stored running statistics instead of per-batch statistics.
model_loaded.train(False)

In [None]:
# Evaluation data: one sub-directory per class under this root.
test_data_path = os.path.expanduser('Dataset/test')
IMAGE_SIZE = 64  # NOTE(review): must equal the input resolution used at training time — confirm

# os.listdir returns entries in arbitrary, filesystem-dependent order. Sort them so the
# class-name -> label-index mapping is reproducible, and keep only directories so
# stray files (e.g. .DS_Store) are not mistaken for classes.
# NOTE(review): this ordering must match the one used when the model was trained.
classes = sorted(
    entry
    for entry in os.listdir(test_data_path)
    if os.path.isdir(os.path.join(test_data_path, entry))
)

# Every file below the test root. os.walk order is also arbitrary, so sort to keep
# dataset indices (and therefore the indices flagged later) stable from run to run.
test_image_paths = sorted(
    os.path.join(dirpath, filename)
    for dirpath, _subdirs, filenames in os.walk(test_data_path)
    for filename in filenames
)

# Inference-time preprocessing: resize, normalize with the standard ImageNet
# per-channel mean/std, then convert to a PyTorch tensor (ToTensorV2: HWC -> CHW).
test_transforms = A.Compose(
    [
        A.Resize(IMAGE_SIZE, IMAGE_SIZE),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
)

In [None]:
# Bundle the discovered file paths, class names and preprocessing pipeline into the
# project's dataset class (positional order: paths, classes, transforms).
test_dataset = dataset.CustomDataset(
    test_image_paths,
    classes,
    test_transforms,
)

In [None]:
# Analysis helper from the project's `reasons` module, given the trained model and
# the evaluation dataset. NOTE(review): presumably it runs inference over
# test_dataset to find doubtful samples — confirm in reasons.TorchDoubtLab.
doubt = reasons.TorchDoubtLab(
    model_loaded,
    test_dataset,
)

In [None]:
# Flag samples by the ProbaReason criterion. Judging by the variable name this
# presumably selects low-confidence predictions — confirm in reasons.TorchDoubtLab.
low_conf_indices = doubt.ProbaReason()
# Any non-None return value of get_flagged_images becomes this cell's displayed output.
doubt.get_flagged_images(low_conf_indices)

In [None]:
# Flag samples by the WrongPrediction criterion. Presumably these are samples whose
# predicted class differs from their label — confirm in reasons.TorchDoubtLab.
wrong_pred_indices = doubt.WrongPrediction()
# Any non-None return value of get_flagged_images becomes this cell's displayed output.
doubt.get_flagged_images(wrong_pred_indices)

In [None]:
# Flag samples by the ShortConfidence criterion.
# NOTE(review): the exact rule is defined in reasons.TorchDoubtLab and is not visible here.
short_conf_indices = doubt.ShortConfidence()
# Any non-None return value of get_flagged_images becomes this cell's displayed output.
doubt.get_flagged_images(short_conf_indices)

In [None]:
# Flag samples by the LongConfidence criterion.
# NOTE(review): the exact rule is defined in reasons.TorchDoubtLab and is not visible here.
long_conf_indices = doubt.LongConfidence()
# Any non-None return value of get_flagged_images becomes this cell's displayed output.
doubt.get_flagged_images(long_conf_indices)