In [3]:
import matplotlib.pyplot as plt
import torch
from data_loader import test_dataset
from models.ResNet import get_resnet18_model

# Rebuild the 2-class ResNet-18 and restore the trained weights on whichever
# device is available (GPU when present, otherwise CPU).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = get_resnet18_model(num_classes=2)
# NOTE(review): torch.load unpickles the checkpoint file, so only load
# checkpoints from a trusted source.
state_dict = torch.load("saved_models/resnet_binary_lr2e4_messidor.pth", map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval()  # switch the model to evaluation mode before inference

# Take a single (image, label) pair from the test set.
img_tensor, label = test_dataset[90]
img_tensor = img_tensor.unsqueeze(0)  # prepend the batch dimension the model call expects
img_tensor = img_tensor.to(device)

# Forward pass without gradient tracking; keep both the class probabilities
# and the arg-max class index.
with torch.no_grad():
    output = model(img_tensor)
    probs = torch.softmax(output, dim=1)
    predicted_class = torch.argmax(output, dim=1).item()

# `label` is presumably a tensor (it is read through .item()) — confirm in data_loader.
print(f"True Label: {label.item()}, Predicted Label: {predicted_class}")
print(f"Probabilities: {probs.cpu().numpy()}")

True Label: 0, Predicted Label: 0
Probabilities: [[0.77778447 0.22221555]]


In [4]:
import numpy as np
from sklearn.metrics import confusion_matrix


from torch.utils.data import DataLoader

# Score the whole test set once: keep the class-1 probability of every sample
# together with its true label, so thresholds can be explored afterwards
# without re-running the model.
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=4, pin_memory=True)

# Per-sample accumulators; converted to arrays only after the loop so that
# `all_probs` / `all_labels` are never bound to anything but the final arrays.
prob_values = []
label_values = []

model.eval()
with torch.no_grad():
    # Batch-scoped names are used on purpose: the loop must not overwrite the
    # single-sample `probs` tensor created by the inference cell above.
    for batch_inputs, batch_labels in test_loader:
        batch_logits = model(batch_inputs.to(device))
        # Column 1 of the softmax output = probability for class 1 (unhealthy)
        batch_pos_probs = torch.softmax(batch_logits, dim=1)[:, 1]
        prob_values.extend(batch_pos_probs.cpu().numpy())
        label_values.extend(batch_labels.cpu().numpy())

all_probs = np.array(prob_values)
all_labels = np.array(label_values)

def compute_sens_spec(threshold):
    """Return ``(sensitivity, specificity)`` of the test-set scores at ``threshold``.

    Reads the module-level ``all_probs`` (class-1 probabilities) and
    ``all_labels`` (true labels). A sample is predicted as class 1 when its
    probability is strictly greater than ``threshold``.

    Args:
        threshold: Decision cut-off applied to ``all_probs``.

    Returns:
        Tuple ``(sensitivity, specificity)``; each is ``0.0`` when its
        denominator (number of true positives / true negatives in the labels)
        is zero.
    """
    predicted = (all_probs > threshold).astype(int)

    # With labels=[0, 1] the matrix is laid out as [[tn, fp], [fn, tp]].
    matrix = confusion_matrix(all_labels, predicted, labels=[0, 1])
    (tn, fp), (fn, tp) = matrix

    actual_positives = tp + fn
    actual_negatives = tn + fp

    if actual_positives > 0:
        sensitivity = tp / actual_positives
    else:
        sensitivity = 0.0

    if actual_negatives > 0:
        specificity = tn / actual_negatives
    else:
        specificity = 0.0

    return sensitivity, specificity


In [15]:
from sklearn.metrics import roc_auc_score, auc

# Threshold-independent summary of the classifier on the test set.
# The score is stored as `roc_auc` (not `auc`) so that the `auc` function
# imported from sklearn.metrics above is not replaced by a float.
roc_auc = roc_auc_score(all_labels, all_probs)
print(f"AUC: {roc_auc:.3f}")

In [20]:
from sklearn.metrics import roc_auc_score, auc

# Operating point: sensitivity/specificity at one chosen decision threshold.
threshold = 0.3  # Try different values here!
sens, spec = compute_sens_spec(threshold)
print(f"Threshold: {threshold:.2f} | Sensitivity: {sens:.3f} | Specificity: {spec:.3f}")

# Threshold-independent ROC AUC, printed alongside for reference.
# Stored as `roc_auc` so the sklearn `auc` function imported above is not
# shadowed by a float.
roc_auc = roc_auc_score(all_labels, all_probs)
print(f"AUC: {roc_auc:.3f}")

Threshold: 0.30 | Sensitivity: 0.920 | Specificity: 0.729
AUC: 0.946


In [None]:
# Unnormalize for visualization
def unnormalize(tensor):
    """Undo per-channel normalization so an image can be displayed.

    Computes ``tensor * std + mean`` with the 3-channel constants below,
    broadcast over the spatial dimensions. The constants are created on the
    same device as ``tensor``, so the function works for CPU and GPU inputs
    alike (previously a non-CPU input raised a device-mismatch error).

    NOTE(review): the constants equal the commonly used ImageNet statistics
    and are assumed to mirror the Normalize transform applied in
    ``data_loader`` — confirm.

    Args:
        tensor: Image tensor whose channel axis has size 3 and is third from
            the end, e.g. ``(3, H, W)`` or ``(N, 3, H, W)``.

    Returns:
        A new tensor of the same shape; the input is not modified.
    """
    # Shape (3, 1, 1) lets each per-channel constant broadcast over H and W.
    mean = torch.tensor([0.485, 0.456, 0.406], device=tensor.device).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=tensor.device).view(3, 1, 1)
    return tensor * std + mean

# Show the sample picked earlier: drop the batch dimension, undo the
# normalization, and move channels last (C, H, W) -> (H, W, C) for imshow.
sample_chw = img_tensor[0].cpu()
sample_hwc = unnormalize(sample_chw).permute(1, 2, 0)
img_np = sample_hwc.numpy()

title_text = f"True: {label.item()}, Pred: {predicted_class}"

plt.imshow(img_np)
plt.title(title_text)
plt.axis('off')
plt.show()
