In [70]:
import torch
from torch import nn
from datahandling import get_dev_data
from tqdm import tqdm
import config 
from models.vqa_model import VQA
from os.path import join
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import pandas as pd
from datahandling import load_multilabel_binarizer, QUESTIONS_TO_BE_ANSWERED

In [71]:
# File name of the saved checkpoint to evaluate. It is looked up inside
# config.trained_model_path (the models/trained folder); change it to evaluate another model.
MODEL_NAME = "vqa_model.pth"

# Label binarizer used to map between answer strings and label columns
# (presumably an sklearn MultiLabelBinarizer — it exposes classes_ and inverse_transform).
mlb = load_multilabel_binarizer()

# Number of output labels, derived from the binarizer so it stays in sync with the
# answer vocabulary. Don't change this unless more answers are possible than before.
NUM_LABELS = len(mlb.classes_)


def evaluation():
    """Run the trained VQA model over the dev test split and collect its outputs.

    Returns:
        tuple: ``(y_true, y_pred, device, mean_loss)`` where
            - ``y_true`` is the concatenation of all label batches (on CPU),
            - ``y_pred`` is the concatenation of the raw model outputs / logits
              (no sigmoid applied, on CPU),
            - ``device`` is the torch device inference ran on,
            - ``mean_loss`` is the average per-batch ``BCEWithLogitsLoss`` as a float.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Running on device: {device}")

    model = load_model(device)
    # Only the third loader returned by get_dev_data (the test split) is evaluated.
    _, _, test_loader = get_dev_data(debug_mode=False)
    criterion = nn.BCEWithLogitsLoss()

    # Separate accumulator: re-binding the per-batch loss to the same name would
    # throw away the running total on every iteration.
    total_loss = 0.0
    y_pred = []
    y_true = []

    with torch.no_grad():
        for (images, questions, question_attention_mask), labels in tqdm(test_loader, desc="Testing"):
            images = images.to(device)
            questions = questions.to(device)
            question_attention_mask = question_attention_mask.to(device)
            labels = labels.to(device)

            outputs = model(images, questions, question_attention_mask)
            total_loss += criterion(outputs, labels).item()

            # Keep collected results on CPU: downstream sklearn/numpy code cannot read
            # CUDA tensors, and GPU memory would otherwise grow with the test set size.
            y_pred.append(outputs.cpu())
            y_true.append(labels.cpu())

    y_pred = torch.cat(y_pred)
    y_true = torch.cat(y_true)

    return y_true, y_pred, device, total_loss / len(test_loader)
   


def load_model(device) -> torch.nn.Module:
    """Instantiate the VQA model and restore the trained weights for inference.

    Args:
        device: torch device the weights are mapped onto and the model is moved to.

    Returns:
        The VQA model with weights loaded from ``config.trained_model_path/MODEL_NAME``,
        placed on ``device`` and switched to eval mode.
    """
    vqa = VQA()

    checkpoint_path = join(config.trained_model_path, MODEL_NAME)
    # weights_only=True limits unpickling to tensors and plain containers.
    state_dict = torch.load(checkpoint_path, map_location=device , weights_only=True)
    vqa.load_state_dict(state_dict)

    vqa.to(device)
    # Inference mode for layers such as dropout / batch norm.
    vqa.eval()
    return vqa


In [72]:
# Full pass over the dev test split (57 batches, ~3.5 min on CPU in the recorded run).
# y_pred holds raw model outputs; later cells threshold sigmoid(y_pred) to get labels.
y_true, y_pred, device, test_loss = evaluation()

Running on device: cpu
Data loaded successfully in PyTorch format.


Testing: 100%|██████████| 57/57 [03:39<00:00,  3.85s/it]


In [None]:
from collections import defaultdict

# Decision threshold applied to sigmoid(logits) to obtain binary predictions.
THRESHOLD = 0.5

# Convert once to CPU numpy. sklearn and the binarizer cannot read CUDA tensors, and
# thresholding here avoids recomputing the sigmoid four times per question below.
y_true_np = y_true.detach().cpu().numpy()
y_pred_np = (torch.sigmoid(y_pred.detach().cpu()) >= THRESHOLD).numpy()

# Row indices whose label row has no positive entry. mlb.inverse_transform yields an
# empty tuple () for those rows, so they cannot be attributed to any question.
idx_without = []

# question -> row indices of y_true / y_pred belonging to that question.
# Answer classes appear to be named "<question>_<answer>", so the question is the text
# before the first "_". Pre-seeding from mlb.classes_ keeps the table in class order.
per_question = {answer.split("_")[0]: [] for answer in mlb.classes_}

results = {
    "question": [],
    "accuracy": [],
    "precision": [],
    "recall": [],
    "f1": []
}

# A single vectorised inverse_transform instead of one call per sample.
for idx, correct_answers in enumerate(mlb.inverse_transform(y_true_np)):
    if correct_answers:
        # The first positive answer decides which question the sample belongs to.
        per_question[correct_answers[0].split("_")[0]].append(idx)
    else:
        idx_without.append(idx)

# Metrics computed only on the samples that belong to each question.
for question, indexes in per_question.items():
    if not indexes:
        # sklearn metrics are undefined on an empty selection; leave the question out
        # of the table instead of failing the whole cell.
        continue
    truth = y_true_np[indexes]
    pred = y_pred_np[indexes]
    results["question"].append(question)
    results["accuracy"].append(round(accuracy_score(truth, pred), 4))
    results["precision"].append(round(precision_score(truth, pred, average="samples", zero_division=0), 4))
    results["recall"].append(round(recall_score(truth, pred, average="samples", zero_division=0), 4))
    results["f1"].append(round(f1_score(truth, pred, average="samples", zero_division=0), 4))

In [86]:
# Number of test samples whose label row has no positive entry: mlb.inverse_transform
# returns an empty tuple () for such rows, so they are not assigned to any question and
# are missing from the per-question table (they are still part of the overall metrics).
len(idx_without)

204

In [75]:
# One row per question; each key of the dict-of-lists becomes a column.
df = pd.DataFrame(results)

In [76]:
# Per-question metrics table: one row per question, values rounded to 4 decimals.
df

Unnamed: 0,question,accuracy,precision,recall,f1
0,Are there any abnormalities in the image?,0.9949,0.9949,0.9949,0.9949
1,Are there any anatomical landmarks in the image?,0.9663,0.9691,0.9719,0.97
2,Are there any instruments in the image?,0.9015,0.9138,0.9113,0.9113
3,Have all polyps been removed?,0.9862,0.9885,0.9908,0.9893
4,How many findings are present?,0.8272,0.8298,0.8325,0.8307
5,How many instrumnets are in the image?,0.9522,0.9545,0.9569,0.9553
6,How many polyps are in the image?,0.9749,0.9774,0.9799,0.9782
7,Is there text?,0.9251,0.9251,0.9251,0.9251
8,Is this finding easy to detect?,0.884,0.8867,0.8895,0.8877
9,What color is the abnormality?,0.6126,0.9066,0.8739,0.87


In [84]:

# Unweighted (macro) mean of the per-question rows in df. Every question counts equally
# regardless of its sample count, and samples without any label (idx_without) belong to
# no row. Stored under its own name so the `results` dict that builds `df` is not
# overwritten, which keeps the earlier cells re-runnable.
aggregated_report = f""" Aggregated 
    accuracy: {df.accuracy.mean():.4f},  
    precision: {df.precision.mean():.4f},  
    recall: {df.recall.mean():.4f}
    F1 Score: {df.f1.mean():.4f},  
"""

print(aggregated_report)

 Aggregated 
    accuracy: 0.8666,  
    precision: 0.9215,  
    recall: 0.9110
    F1 Score: 0.9122,  



In [77]:

# Overall metrics over every test sample, including the rows in idx_without that have no
# positive label. NOTE(review): with average="samples" and zero_division=0, such a row
# scores 0 precision/recall/F1 whatever is predicted, which likely explains why these
# numbers are lower than the per-question aggregate — confirm.
# Threshold once on CPU: sklearn cannot read CUDA tensors, and there is no need to
# recompute the sigmoid for every metric.
y_true_np = y_true.detach().cpu().numpy()
y_pred_labels = (torch.sigmoid(y_pred.detach().cpu()) >= 0.5).numpy()

# Own name so the per-question `results` dict used to build `df` is not clobbered.
dev_report = f""" Performance on Dev test set
    accuracy: {accuracy_score(y_true_np, y_pred_labels):.4f},  
    precision: {precision_score(y_true_np, y_pred_labels, average="samples", zero_division=0):.4f},  
    recall: {recall_score(y_true_np, y_pred_labels, average="samples", zero_division=0):.4f}
    F1 Score: {f1_score(y_true_np, y_pred_labels, average="samples", zero_division=0):.4f},  
"""

print(dev_report)

 Performance on Dev test set
    accuracy: 0.8744,  
    precision: 0.8689,  
    recall: 0.8591
    F1 Score: 0.8603,  

