In [None]:
import os

import albumentations as A
import torch
import torch.nn as nn
from albumentations.pytorch import ToTensorV2
from torchvision import models

import dataset
import reasons
import random

In [None]:
# Directory the notebook switches into (via os.chdir) before touching any data.
TEST_DIR = "Dataset/test"
# Checkpoint file handed to torch.load when restoring the model weights.
PATH_TO_MODEL = 'basemodel.pth'
# Output width of the replacement fully connected layer put on the ResNet-18.
NUM_OF_CLASSES = 2

In [None]:
# Every relative path used after this point (the class folders, os.walk('.'),
# and the relative PATH_TO_MODEL) resolves against TEST_DIR.
# NOTE(review): TEST_DIR is itself relative, so re-running this cell without
# restarting the kernel looks for it inside the new working directory and fails.
os.chdir(TEST_DIR)

In [None]:
def mislabel():
    """
    Randomly mislabel 10% of the images of every class in the working directory.

    Each sub-directory of the current working directory is treated as one class
    (hidden folders such as ``.ipynb_checkpoints`` and plain files are ignored).
    For every class, ``int(0.1 * n)`` of its ``n`` images are picked at random
    and moved into a randomly chosen *other* class folder, with
    ``~mislabeled_`` prepended to the file name.

    The folder contents are snapshotted before anything is renamed, and files
    that already carry the ``~mislabeled_`` prefix are never picked. Without
    this, an image moved into a class processed later could be moved a second
    time (with two classes: straight back to its true class) while still being
    counted as mislabeled.

    Returns:
        list[str]: the new ``"<label>/~mislabeled_<name>"`` paths of the images
        renamed by this call. Images mislabeled by an earlier call are left in
        place and are not included.
    """

    prefix = '~mislabeled_'

    # Only real, non-hidden directories are classes. Sorting makes the result
    # reproducible under a fixed random seed (os.listdir order is arbitrary).
    all_classes = sorted(
        entry for entry in os.listdir('.')
        if os.path.isdir(entry) and not entry.startswith('.')
    )

    # Snapshot every class before the first rename so images moved in from
    # another class neither inflate the 10% quota nor get picked again.
    images_by_class = {
        label: sorted(
            name for name in os.listdir(label) if not name.startswith(prefix)
        )
        for label in all_classes
    }

    mislabels = []

    for label in all_classes:
        other_classes = [other for other in all_classes if other != label]
        if not other_classes:
            # A single-class dataset has no wrong label to assign.
            continue

        original_images = images_by_class[label]
        to_mislabel = int(0.1 * len(original_images))

        for image in random.sample(original_images, to_mislabel):
            new_label = random.choice(other_classes)
            new_name = f'{new_label}/{prefix}{image}'
            os.rename(f'{label}/{image}', new_name)
            mislabels.append(new_name)

    return mislabels

# Inject the label noise; the returned paths are the ground truth that the
# flagging strategies are scored against further down.
mislabels = mislabel()
total_mislabeled = len(mislabels)

# Build the bare architecture without fetching ImageNet weights: the checkpoint
# loaded below replaces every parameter (load_state_dict is strict by default),
# so downloading pretrained weights was wasted network I/O and broke offline runs.
model_loaded = models.resnet18(pretrained=False)
num_ftrs = model_loaded.fc.in_features
model_loaded.fc = nn.Linear(num_ftrs, NUM_OF_CLASSES)
# NOTE(review): PATH_TO_MODEL is relative and the working directory has already
# been switched to TEST_DIR at this point -- confirm the checkpoint lives there.
model_loaded.load_state_dict(torch.load(PATH_TO_MODEL, map_location='cpu'))

# Inference mode: freezes batch-norm statistics and disables dropout.
model_loaded.eval()

In [None]:
# Summarise the injected label noise: the count first, then the renamed paths.
print(
    f"Total number of mislabels: {total_mislabeled}",
    f"Mislabeled images: {mislabels}",
    sep="\n",
)

In [None]:
# The class order defines the label indices, so it MUST match the order the
# model was trained with (e.g. in Colab). os.listdir returns entries in
# arbitrary order, so the class folders are sorted to get a stable order.
# If training used a different order, replace this with the exact class list
# copied from the training notebook.
classes = sorted(
    entry for entry in os.listdir('.')
    if os.path.isdir(entry) and not entry.startswith('.')
)

# Collect files only from inside the class folders, so stray top-level files
# and hidden folders (e.g. .ipynb_checkpoints) are not treated as test images.
# The './<label>/...' prefix is the shape os.walk('.') produced originally.
test_image_paths = [
    os.path.join(dirpath, filename)
    for label in classes
    for dirpath, _dirnames, filenames in os.walk(os.path.join('.', label))
    for filename in sorted(filenames)
]

# Resize, normalise with the ImageNet channel mean/std values, and convert to
# a tensor.
test_transforms = A.Compose(
    [
        A.Resize(64,64),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
)

# Total number of annotated test images; used as the denominator for accuracy.
total_annos = len(test_image_paths)
test_dataset = dataset.CustomDataset(test_image_paths,classes,test_transforms)
doubt = reasons.TorchDoubtLab(model_loaded, test_dataset)

def get_metrics(flagged, total_mislabels, total_annos):
    """
    Score a set of flagged images against the known injected mislabels.

    A flagged image counts as a true positive when its path contains the
    substring ``'mislabeled'`` (the marker added when the label noise was
    injected); any other flagged image is a false positive.

    Args:
        flagged: iterable of image paths/names (strings) flagged as doubtful.
        total_mislabels: number of images that really are mislabeled.
        total_annos: total number of annotated images that were examined.

    Returns:
        dict with keys ``'accuracy'`` (percentage formatted as a string ending
        in ``'%'``), ``'fp'``, ``'fn'``, ``'tp'``, ``'precision'``,
        ``'recall'`` and ``'f1'``. Any ratio whose denominator is zero is
        reported as ``0`` instead of raising ``ZeroDivisionError``.
    """

    def _ratio(numerator, denominator):
        # Safe division: an undefined ratio (zero denominator) is reported as 0.
        return numerator / denominator if denominator else 0

    tp = 0
    fp = 0

    for flagged_img in flagged:
        if 'mislabeled' in flagged_img:
            tp += 1
        else:
            fp += 1

    # Mislabeled images that were not flagged, and clean images left alone.
    fn = total_mislabels - tp
    tn = total_annos - (tp + fp + fn)

    precision = _ratio(tp, tp + fp)
    recall = _ratio(tp, tp + fn)
    f1 = _ratio(2 * precision * recall, precision + recall)
    accuracy = _ratio(100 * (tp + tn), total_annos)

    return {
        'accuracy': f"{accuracy}%",
        'fp': fp,
        'fn': fn,
        'tp': tp,
        'precision': precision,
        'recall': recall,
        'f1': f1,
    }

In [None]:
# Strategy 1: doubt.ProbaReason().
# NOTE(review): judging by the variable name this flags low-confidence
# predictions -- confirm against reasons.TorchDoubtLab.
low_conf_indices = doubt.ProbaReason()
# Turn the returned indices into flagged images and score them against the
# injected mislabels.
x = doubt.get_flagged_images(low_conf_indices)
get_metrics(x, total_mislabeled, total_annos)

In [None]:
# Strategy 2: doubt.WrongPrediction().
# NOTE(review): judging by the name this flags samples whose prediction
# disagrees with the given label -- confirm against reasons.TorchDoubtLab.
wrong_pred_indices = doubt.WrongPrediction()
# Turn the returned indices into flagged images and score them against the
# injected mislabels.
x = doubt.get_flagged_images(wrong_pred_indices)
get_metrics(x, total_mislabeled, total_annos)

In [None]:
# Strategy 3: doubt.ShortConfidence().
# NOTE(review): exact flagging criterion is defined in reasons.TorchDoubtLab
# and is not visible here -- confirm before interpreting the numbers.
short_conf_indices = doubt.ShortConfidence()
# Turn the returned indices into flagged images and score them against the
# injected mislabels.
x = doubt.get_flagged_images(short_conf_indices)
get_metrics(x, total_mislabeled, total_annos)

In [None]:
# Strategy 4: doubt.LongConfidence().
# NOTE(review): exact flagging criterion is defined in reasons.TorchDoubtLab
# and is not visible here -- confirm before interpreting the numbers.
long_conf_indices = doubt.LongConfidence()
# Turn the returned indices into flagged images and score them against the
# injected mislabels.
x = doubt.get_flagged_images(long_conf_indices)
get_metrics(x, total_mislabeled, total_annos)