In [1]:
import torch
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import timm

from PIL import Image
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn.metrics import confusion_matrix, f1_score, accuracy_score, precision_score, classification_report, roc_auc_score, recall_score

# Check for CUDA device
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(DEVICE)

SEED = 16

torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

cuda


In [2]:
fold = 1 #load model trained on nth fold

validation_set_path = '../DATA/Train_Val_set/Val'
BvC_model_path = f'../MODELS/2StageModels/BvC/fold{fold}_model.pth'
PvN_model_path = f'../MODELS/2StageModels/PvN/fold{fold}_model.pth'

model_name = 'tiny_vit_21m_512.dist_in22k_ft_in1k'
num_classes = 2 #2 class per model
batch_size = 6

input_shape = (512, 512)
transform_mean = (0.485, 0.456, 0.406)
transform_std = (0.229, 0.224, 0.225)

In [3]:
# Load Models
BvC_model = timm.create_model(model_name, pretrained=False, num_classes=num_classes)
BvC_model = nn.Sequential( BvC_model, nn.Sigmoid() )
BvC_model = BvC_model.to(DEVICE)
BvC_model.load_state_dict(torch.load(BvC_model_path, map_location=DEVICE, weights_only=True))
BvC_model.eval()

PvN_model = timm.create_model(model_name, pretrained=False, num_classes=num_classes)
PvN_model = nn.Sequential( PvN_model, nn.Sigmoid() )
PvN_model = PvN_model.to(DEVICE)
PvN_model.load_state_dict(torch.load(PvN_model_path, map_location=DEVICE, weights_only=True))
PvN_model.eval()

print("MODELS READY")

MODELS READY


In [4]:
val_transform = transforms.Compose([
        transforms.Resize(input_shape),
        transforms.ToTensor(),
        transforms.Normalize(mean=transform_mean, std=transform_std)
    ])

val_dataset = datasets.ImageFolder(root=validation_set_path, transform=val_transform)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

print("CLASS MAPPING")
print(val_dataset.class_to_idx)

CLASS MAPPING
{'GNB': 0, 'GNC': 1, 'GPB': 2, 'GPC': 3}


In [5]:
# B vs C
# MODEL OUTPUT -> B:0 C:1

BvC_all_preds = []
BvC_all_labels = []

with torch.no_grad():
    for inputs, labels in val_loader:
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)
        
        # Get model predictions
        outputs = BvC_model(inputs)
        preds = (outputs > 0.5).long()

        BvC_all_preds.extend(preds.cpu().numpy())
        BvC_all_labels.extend(labels.cpu().numpy())


BvC_all_labels = [0 if i in [0,2] else 1 for i in BvC_all_labels]
BvC_all_preds = [np.argmax(i) for i in BvC_all_preds]
BvC_all_labels = ['B' if i == 0 else 'C' for i in BvC_all_labels]
BvC_all_preds = ['B' if i == 0 else 'C' for i in BvC_all_preds]

In [6]:
# P vs N
# MODEL OUTPUT -> N:0 P:1

PvN_all_preds = []
PvN_all_labels = []

with torch.no_grad():
    for inputs, labels in val_loader:
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)
        
        # Get model predictions
        outputs = PvN_model(inputs)
        preds = (outputs > 0.5).long()

        PvN_all_preds.extend(preds.cpu().numpy())
        PvN_all_labels.extend(labels.cpu().numpy())


PvN_all_labels = [0 if i in [0,1] else 1 for i in PvN_all_labels]
PvN_all_preds = [np.argmax(i) for i in PvN_all_preds]
PvN_all_labels = ['N' if i == 0 else 'P' for i in PvN_all_labels]
PvN_all_preds = ['N' if i == 0 else 'P' for i in PvN_all_preds]

In [7]:
true_labels = []
pred_labels = []

for a,b in zip(PvN_all_labels, BvC_all_labels):
    true_labels.append('G'+a+b)

for a,b in zip(PvN_all_preds, BvC_all_preds):
    pred_labels.append('G'+a+b)

In [8]:
print("Cocci vs Bacilli")
print(classification_report(y_pred=BvC_all_preds, y_true=BvC_all_labels))
print()

print("Positive vs Negative")
print(classification_report(y_pred=PvN_all_preds, y_true=PvN_all_labels))
print()

print("4 Class")
print(classification_report(y_true=true_labels, y_pred=pred_labels))

Cocci vs Baclli
              precision    recall  f1-score   support

           B       0.95      0.92      0.93       100
           C       0.92      0.95      0.94       100

    accuracy                           0.94       200
   macro avg       0.94      0.94      0.93       200
weighted avg       0.94      0.94      0.93       200


Positive vs Negative
              precision    recall  f1-score   support

           N       0.97      0.86      0.91       100
           P       0.87      0.97      0.92       100

    accuracy                           0.92       200
   macro avg       0.92      0.92      0.91       200
weighted avg       0.92      0.92      0.91       200


4 Class
              precision    recall  f1-score   support

         GNB       0.93      0.84      0.88        50
         GNC       0.86      0.76      0.81        50
         GPB       0.88      0.92      0.90        50
         GPC       0.76      0.90      0.83        50

    accuracy               

In [9]:
BvC_report = classification_report(y_pred=BvC_all_preds, y_true=BvC_all_labels, output_dict=True)
PvN_report = classification_report(y_pred=PvN_all_preds, y_true=PvN_all_labels, output_dict=True)
combined_report = classification_report(y_true=true_labels, y_pred=pred_labels, output_dict=True)

BvC_report = pd.DataFrame(BvC_report).transpose().round(4)
PvN_report = pd.DataFrame(PvN_report).transpose().round(4)
combined_report = pd.DataFrame(combined_report).transpose().round(4)

In [10]:
BvC_report

Unnamed: 0,precision,recall,f1-score,support
B,0.9485,0.92,0.934,100.0
C,0.9223,0.95,0.936,100.0
accuracy,0.935,0.935,0.935,0.935
macro avg,0.9354,0.935,0.935,200.0
weighted avg,0.9354,0.935,0.935,200.0


In [11]:
PvN_report

Unnamed: 0,precision,recall,f1-score,support
N,0.9663,0.86,0.9101,100.0
P,0.8739,0.97,0.9194,100.0
accuracy,0.915,0.915,0.915,0.915
macro avg,0.9201,0.915,0.9147,200.0
weighted avg,0.9201,0.915,0.9147,200.0


In [12]:
combined_report

Unnamed: 0,precision,recall,f1-score,support
GNB,0.9333,0.84,0.8842,50.0
GNC,0.8636,0.76,0.8085,50.0
GPB,0.8846,0.92,0.902,50.0
GPC,0.7627,0.9,0.8257,50.0
accuracy,0.855,0.855,0.855,0.855
macro avg,0.8611,0.855,0.8551,200.0
weighted avg,0.8611,0.855,0.8551,200.0
