# IMPORTS

In [26]:
import torch
import torch.nn as nn
import tensorflow as tf
from sklearn.metrics import classification_report, precision_recall_curve, confusion_matrix
import numpy as np
import matplotlib.pyplot as plt
import os
import seaborn as sns
from transformers import ViTImageProcessor
from PIL import Image
from torch.utils.data import DataLoader
from sklearn.preprocessing import label_binarize

# EVALUATION FUNCTIONS

In [27]:
def evaluate_model_pytorch(model, test_loader, categories):
    device = torch.device("mps" if torch.mps.is_available() else "cpu")
    model.to(device)
    model.eval()

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)

            images = images.squeeze(1)  # Shape becomes [32, 3, 224, 224]

            outputs = model(images) 
            _, predicted = outputs.max(1)

            # If labels are one-hot encoded, convert them to class indices
            if len(labels.shape) > 1:
                labels = labels.argmax(dim=1)

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Print Classification Report
    print("Classification Report:\n", classification_report(all_labels, all_preds, target_names=categories))
    return all_labels, all_preds

In [60]:
def find_best_thresholds_pytorch(model, test_loader, num_classes, device):
    model.eval()

    all_probs = []
    all_labels = []

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)

            labels = labels.cpu().numpy()
            if labels.ndim == 1:
                labels = labels.reshape(-1, 1)  # Shape [batch_size, 1]

            images = images.squeeze(1)  # Shape becomes [batch, 3, 224, 224]

            # Get model output logits
            outputs = model(images) 

            # Convert logits to probabilities
            probs = torch.softmax(outputs, dim=1).cpu().numpy()  # Softmax for multi-class classification

            all_probs.append(probs)
            all_labels.append(labels)

    all_probs = np.concatenate(all_probs, axis=0)  # Shape: [n_samples, n_classes]
    all_labels = np.concatenate(all_labels, axis=0)  # Shape: [n_samples, 1] 

    # One-hot encode labels (needed for multi-class precision-recall curve)
    all_labels = label_binarize(all_labels, classes=np.arange(num_classes))  # Shape: [n_samples, n_classes]

    best_thresholds = []

    for class_idx in range(num_classes):
        binary_labels = all_labels[:, class_idx]  # Get class labels (binary format)

        # Compute precision-recall curve
        precision, recall, thresholds = precision_recall_curve(binary_labels, all_probs[:, class_idx])

        # Compute F1 scores for each threshold
        f1_scores = (2 * precision * recall) / (precision + recall + np.finfo(float).eps)  # Avoid division by zero

        # Get threshold with max F1 score
        best_threshold = thresholds[np.argmax(f1_scores)] if thresholds.size > 0 else 0.5
        best_thresholds.append(best_threshold)

        print(f"Best threshold for class {class_idx}: {best_threshold:.4f}")

    return best_thresholds

def evaluate_with_best_thresholds_pytorch(model, test_loader, categories, device):
    best_thresholds = find_best_thresholds_pytorch(model, test_loader, len(categories), device)
    model.eval()

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.cpu().numpy()

            if len(images.shape) == 5:
                images = images.squeeze(1)  # Remove the extra dimension, if necessary

            # Ensure labels are in a multi-label format (one-hot encoded)
            if labels.ndim == 1: 
                labels = np.eye(len(categories))[labels]

            outputs = model(images)
            if isinstance(outputs, tuple):  
                outputs = outputs[0]

            # Apply sigmoid activation
            probs = torch.sigmoid(outputs).cpu().numpy()

            # Apply best thresholds for each class
            preds = (probs > np.array(best_thresholds)).astype(int)
            preds = np.argmax(preds*probs, axis=1)
            preds = label_binarize(preds, classes=np.arange(len(categories)))

            all_preds.extend(preds)
            all_labels.extend(labels)

    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)

    # Ensure both are in the same shape and format
    if all_labels.shape != all_preds.shape:
        raise ValueError(f"Shape mismatch: labels {all_labels.shape}, preds {all_preds.shape}")
    
    all_labels = all_labels.argmax(axis=1)
    all_preds = all_preds.argmax(axis=1)

    # Print Classification Report
    print("Classification Report with Optimized Thresholds:\n",
          classification_report(all_labels, all_preds, target_names=categories))
    
    return all_labels, all_preds

# DATASET

## PyTorch Dataset

In [37]:
batch_size = 32
img_size = (224, 224)
channels = 3
img_shape = (img_size[0], img_size[1], channels)
categories = ["Normal","Osteopenia", "Osteoporosis"]

root_dir = 'Dataset'
train_dir = os.path.join(root_dir, 'train')
val_dir = os.path.join(root_dir, 'val')
test_dir = os.path.join(root_dir, 'test')

In [38]:
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, image_paths, labels, processor, model, device):
        self.image_paths = image_paths
        self.labels = labels
        self.processor = processor
        self.model = model
        self.device = device 

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert("RGB")
        label = self.labels[idx]

        inputs = self.processor(images=image, return_tensors="np").to(self.device)  # Move inputs to device
        inputs = inputs.pixel_values 

        return inputs, label

In [39]:
train_paths = []
train_labels = []
class_to_idx = {class_name: idx for idx, class_name in enumerate(os.listdir(train_dir))}
for class_name in os.listdir(train_dir):
    class_folder = os.path.join(train_dir, class_name)
    if os.path.isdir(class_folder):
        for img_name in os.listdir(class_folder):
            if img_name.endswith(('.jpg', '.jpeg', '.png')):
                train_paths.append(os.path.join(class_folder, img_name))
                train_labels.append(class_to_idx[class_name])
print(f"Training images: {len(train_paths)}")

val_paths = []
val_labels = []
class_to_idx = {class_name: idx for idx, class_name in enumerate(os.listdir(val_dir))}
for class_name in os.listdir(val_dir):
    class_folder = os.path.join(val_dir, class_name)
    if os.path.isdir(class_folder):
        for img_name in os.listdir(class_folder):
            if img_name.endswith(('.jpg', '.jpeg', '.png')):
                val_paths.append(os.path.join(class_folder, img_name))
                val_labels.append(class_to_idx[class_name])
print(f"Validation images: {len(val_paths)}")

test_paths = []
test_labels = []
class_to_idx = {class_name: idx for idx, class_name in enumerate(os.listdir(test_dir))}
for class_name in os.listdir(test_dir):
    class_folder = os.path.join(test_dir, class_name)
    if os.path.isdir(class_folder):
        for img_name in os.listdir(class_folder):
            if img_name.endswith(('.jpg', '.jpeg', '.png')):
                test_paths.append(os.path.join(class_folder, img_name))
                test_labels.append(class_to_idx[class_name])
print(f"Test images: {len(test_paths)}")

Training images: 3576
Validation images: 1038
Test images: 536


# ViT Base 16

In [40]:
class ViTWithHead(nn.Module):
    def __init__(self, pretrained_model, num_classes, dropout_rate=0.4):
        super(ViTWithHead, self).__init__()
        self.vit = pretrained_model  # Base ViT model
        self.classifier = nn.Sequential(
            nn.BatchNorm1d(self.vit.config.hidden_size),
            nn.Dropout(dropout_rate), 
            nn.Linear(self.vit.config.hidden_size, 256), 
            nn.ReLU(), 
            nn.Dropout(0.3), 
            nn.Linear(256, num_classes)  
        )

    def forward(self, x):
        outputs = self.vit(x) 
        cls_token = outputs.last_hidden_state[:, 0]  # CLS token (first token)
        logits = self.classifier(cls_token) 
        return logits

In [None]:
finetunedModelPath = "Models/ViT/Finetuned/ViT-Base16-in21k.pth"
modelPath = "Models/ViT/Not-Finetuned/ViT-Base16-in21k.pth"

In [55]:
model = torch.load(finetunedModelPath, weights_only=False)
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

In [56]:
device = torch.device("mps" if torch.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

ViTWithHead(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, ou

In [57]:
train_dataset = CustomDataset(train_paths, train_labels, processor, model, device)
val_dataset = CustomDataset(val_paths, val_labels, processor, model, device)
test_dataset = CustomDataset(test_paths, test_labels, processor, model, device)

trainLoader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
validLoader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
testLoader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [58]:
labels, preds = evaluate_model_pytorch(model, testLoader, categories)

Classification Report:
               precision    recall  f1-score   support

      Normal       0.74      0.73      0.73       179
  Osteopenia       0.82      0.86      0.84       180
Osteoporosis       0.85      0.82      0.84       177

    accuracy                           0.80       536
   macro avg       0.80      0.80      0.80       536
weighted avg       0.80      0.80      0.80       536



In [59]:
labels, preds = evaluate_with_best_thresholds_pytorch(model, testLoader, categories, device)

Best threshold for class 0: 0.3983
Best threshold for class 1: 0.6438
Best threshold for class 2: 0.1324
Classification Report with Optimized Thresholds:
               precision    recall  f1-score   support

      Normal       0.72      0.73      0.72       179
  Osteopenia       0.83      0.82      0.82       180
Osteoporosis       0.83      0.83      0.83       177

    accuracy                           0.79       536
   macro avg       0.79      0.79      0.79       536
weighted avg       0.79      0.79      0.79       536

