In [76]:
import numpy as np
from sklearn.metrics import roc_curve, auc, recall_score, accuracy_score, precision_score
import matplotlib.pyplot as plt
import torch
from torchvision import models
from torchvision.io import read_image
from torchvision import transforms
import torch.nn as nn
import pandas as pd
from config import *
import random
import os

from itertools import product


#### Llegir models

In [77]:
from model_autoencoder import Autoencoder
model_autoencoder = Autoencoder()
model_autoencoder.load_state_dict(torch.load("/fhome/mapsiv01/PSIV3_malaltia/autoencoder/models/model5_AUTOENCODER.pth", map_location=torch.device('cpu')))


from model_CNN_MlpClassifier import Classifier
model_our = Classifier()
model_our.load_state_dict(torch.load("/fhome/mapsiv01/PSIV3_malaltia/autoencoder/models/model_MLPour_classifier.pth", map_location=torch.device('cpu')))


model_resnet = models.resnet50(pretrained=True)
for param in model_resnet.parameters():
    param.requires_grad = False
num_ftrs = model_resnet.fc.in_features
model_resnet.fc = nn.Sequential(
nn.Linear(num_ftrs, 128),  # Replace the last fully connected layer with your custom layers
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(128, 2)  # Assuming your final output has 2 classes, modify this as needed
    )
model_resnet.load_state_dict(torch.load("/fhome/mapsiv01/PSIV3_malaltia/autoencoder/models/model_MLPresnet.pth", map_location=torch.device('cpu')))




<All keys matched successfully>

In [78]:
transform = transforms.Compose([
    transforms.Resize((128, 128), antialias=True),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Adjust mean and std as needed
])

def read_img(file, dir):
    img = read_image(os.path.join(dir, file))[:-1,:,:]
    img = img.to(torch.float32)
    img = img/255
    img = transform(img)
    return img

#### Funció per calcular mètriques per un Fold

In [79]:
def create_crossval(pred, target):
       
    combined = list(zip(pred, target))

    random.shuffle(combined)

    pred, target = zip(*combined)

    interval = int(len(pred)/10)

    kfolds = [(pred[:interval*i]+pred[interval*i+1:], target[:interval*i]+target[interval*i+1:],
               pred[interval*i:interval*(i+1)], target[interval*i:interval*(i+1)]) for i in range(10)]
    
    return kfolds

def calculate_mean_error(data):
    mean = np.mean(data)
    std_dev = np.std(data, ddof=1)  # ddof=1 for sample standard deviation
    sample_size = len(data)

    # Calculate the t-distribution critical value
    t_critical = 2.262

    # Calculate the margin of error
    margin_of_error = t_critical * (std_dev / np.sqrt(sample_size))

    return f"{mean} +/- {margin_of_error}"

       

def calculate_metrics(pred, target):

    recalls, precisions, f1_scores = [], [], []

    kfolds = create_crossval(pred, target)

    for (pred_train, target_train, pred_test, target_test) in kfolds:
        df = pd.DataFrame({"Target": target_train, "Predicted": pred_train})
        fpr, tpr, thresholds = roc_curve(df['Target'], df['Predicted'])

        J2 = tpr - fpr  
        best_threshold = thresholds[np.argmax(J2)]

        pred_test = [0 if prob < best_threshold else 1 for prob in pred_test]


        recall = recall_score(pred_test, target_test)
        precision = precision_score(pred_test, target_test)
        f1_score = (2*recall*precision)/(recall+precision)

        recalls.append(recall)
        precisions.append(precision)
        f1_scores.append(f1_score)
    
    recall = calculate_mean_error(recalls)
    precision = calculate_mean_error(precisions)
    f1_score = calculate_mean_error(f1_scores)


    return recall, precision, f1_score

In [80]:
directories_cr = [dir.path for dir in os.scandir(CROPPED_PATCHES_DIR) if dir.is_dir()]
metadata = pd.read_csv(METADATA_FILE)

In [91]:
def calcula(model, method, amb_baixa):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


    model_name = model[0]
    model = model[1]
    model = model.to(device)

    preds = []
    targets = []

    for dir in directories_cr:
        if metadata.loc[metadata["CODI"] == dir.split("/")[-1].split("_")[0], "DENSITAT"].values[0] == "NEGATIVA":
            target = 0
        elif metadata.loc[metadata["CODI"] == dir.split("/")[-1].split("_")[0], "DENSITAT"].values[0] == "ALTA":
            target = 1
        else:
            if amb_baixa:
                target = 1
            else:
                continue
        
        files = os.listdir(dir)
        random.shuffle(files)

        pred_patches = []
        imgs = []
        for file in files[:30]:
            img = read_img(file, dir)
            imgs.append(img)

        stacked_imgs = torch.stack(imgs)  # This creates a tensor of shape (num_images, channels, height, width)
        
        model.eval()
        with torch.no_grad():
            output = model(stacked_imgs.to(device))
            output = torch.sigmoid(output)
            pred_patches = []
        for i in range(output.shape[0]):
            output2 = output[i, :]
            if method == "prob":
                output2 = output2[torch.argmax(output2).item()].item()
            elif method == "discrete":
                output2 = torch.argmax(output2).item()
            pred_patches.append(output2)
        
        preds.append(sum(pred_patches)/len(files))
        targets.append(target)


    recall, precision, f1_score = calculate_metrics(preds, targets)

    print(f"Model: {model_name}, Method: {method}, Baixa: {amb_baixa}")
    print(f"Recall: {recall}, Precision: {precision}, F1_score: {f1_score}\n")

In [92]:
models = [["our", model_our], ["resnet", model_resnet]]
methods = ["prob", "discrete"]
amb_baixa = [True, False]

for (model, method, amb_baixa) in product(models, methods, amb_baixa):
    calcula(model, method, amb_baixa)

Model: our, Method: prob, Baixa: True
Recall: 0.5945879404393336 +/- 0.07437233374631352, Precision: 0.8354842543077836 +/- 0.0760207240227953, F1_score: 0.6876792027488936 +/- 0.05774363794425387

Model: our, Method: prob, Baixa: False
Recall: 0.4568075062192709 +/- 0.06727175239187745, Precision: 0.8341666666666667 +/- 0.12630750352699643, F1_score: 0.5840383300909616 +/- 0.07849851902487531

Model: our, Method: discrete, Baixa: True
Recall: 0.8376587301587302 +/- 0.0761013512487516, Precision: 0.5881135531135532 +/- 0.08672852508411119, F1_score: 0.679334576559977 +/- 0.050113495128605405

Model: our, Method: discrete, Baixa: False
Recall: 0.6628787878787878 +/- 0.15204503598257374, Precision: 0.7415873015873016 +/- 0.17329420097949527, F1_score: 0.661169892199304 +/- 0.13133755594840218

Model: resnet, Method: prob, Baixa: True
Recall: 0.6056383322559793 +/- 0.07084948477679817, Precision: 0.7949938949938951 +/- 0.08076336415834749, F1_score: 0.6830546051311771 +/- 0.06639698167279