In [1]:
import pandas as pd
import numpy as np

# Load the dataset
df = pd.read_csv('/kaggle/input/roi-images-hda/masks_csvs/glaucoma_masks_train.csv')

# Count rows where 'Final Label' is 1 and 0
count_label_1 = len(df[df['Final Label'] == 1])
count_label_0 = len(df[df['Final Label'] == 0])

print(f"Count of rows with Final Label == 1: {count_label_1}")
print(f"Count of rows with Final Label == 0: {count_label_0}")


Count of rows with Final Label == 1: 2623
Count of rows with Final Label == 0: 5110


In [2]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms import autoaugment

def train_transform():
    return transforms.Compose([
        transforms.Resize((384, 384), interpolation=transforms.InterpolationMode.BICUBIC),
        autoaugment.AutoAugment(autoaugment.AutoAugmentPolicy.IMAGENET),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

def test_transform():
    return transforms.Compose([
        transforms.Resize((384, 384), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

class GlaucomaDataset(Dataset):
    """
    Args:
        dataframe (DataFrame): DataFrame containing the dataset information.
        img_folder (string): Directory with all the images.
        transform (callable, optional): Optional transform to be applied on a sample.
        extra_features (list of str, optional): Column names for the extra features.
    """
    def __init__(self, dataframe, img_folder, transform=None, extra_features=None):
        self.dataframe = dataframe
        self.img_folder = img_folder
        self.transform = transform
        self.extra_features = extra_features
        
    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        img_id = self.dataframe.iloc[idx]['Eye ID']
        for ext in ['.JPG','.JPEG', '.PNG', '.png', '.jpg', '.jpeg']:
            img_path = os.path.join(self.img_folder, f"{img_id}{ext}")
            if os.path.exists(img_path):
                break
        else:
            raise FileNotFoundError(f"No image found for ID {img_id} with any supported extension.")

        image = Image.open(img_path)
        if self.transform:
            image = self.transform(image)

        # Handling extra features 
        if self.extra_features == None :
            img_class = self.dataframe.iloc[idx]['Final Label']
            labels = torch.tensor(img_class, dtype=torch.float32)
        else:           
            extra_labels = self.dataframe.iloc[idx][self.extra_features].values.astype(float)
            labels = torch.tensor(extra_labels, dtype=torch.float32)
        
        return image, labels

In [3]:
from PIL import Image
image = Image.open('/kaggle/input/roi-images-hda/ROI_images/TRAIN000054.JPG')
print(image.size)

(518, 518)


In [4]:
import os
import re
import pandas as pd
import torch
import torch.nn as nn
import torch.optim
from torch.utils.data import DataLoader, random_split
from torchvision.models import vit_b_16, ViT_B_16_Weights
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score

################### Configuration ###################
model_name = "ViT_RG_ROI"
model_save_directory = f'/kaggle/working/model/{model_name}'
img_folder = '/kaggle/input/roi-images-hda/ROI_images'
train_df = pd.read_csv('/kaggle/input/roi-images-hda/masks_csvs/glaucoma_masks_train.csv')

best_model_directory = os.path.join('/kaggle/working/', 'best_model')

if not os.path.exists(best_model_directory):
    os.makedirs(best_model_directory)

# Hyperparameters for early stopping
eval_every = 5  # Evaluate on validation set every 5 epochs
patience = 5    # Early stopping patience
num_epochs = 100  # Total training epochs

# Check for or create save directory
if not os.path.exists(model_save_directory):
    os.makedirs(model_save_directory)

################### Dataset and Dataloaders ###################
# Initialize dataset with transformations
train_dataset = GlaucomaDataset(dataframe=train_df, img_folder=img_folder, transform=train_transform(), extra_features=None)

# Split train_dataset into training and validation sets
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_data, val_data = random_split(train_dataset, [train_size, val_size])

# Dataloaders for training and validation
train_loader = DataLoader(train_data, batch_size=20, shuffle=True, num_workers=8)
val_loader = DataLoader(val_data, batch_size=20, shuffle=False, num_workers=8)

################### Model Initialization ###################
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
weights = ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1
model = vit_b_16(weights=weights)

# Update the model's final layer for binary classification
num_features = model.heads.head.in_features
model.heads.head = nn.Linear(num_features, 1)

# Freeze all layers except the classifier head
for name, param in model.named_parameters():
    if 'heads.head' not in name:
        param.requires_grad = False

if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs!")
    model = nn.DataParallel(model)

model.to(device)

################### Loss and Optimizer ###################
# Compute class weights for weighted BCE loss
negative_class = len(train_df[train_df['Final Label'] == 0])
positive_class = len(train_df[train_df['Final Label'] == 1])
pos_weight_value = negative_class / positive_class
pos_weight_tensor = torch.tensor([pos_weight_value], dtype=torch.float, device=device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight_tensor)

# Parameter groups for differential learning rates
if torch.cuda.device_count() > 1:
    base_params = [p for n, p in model.module.named_parameters() if 'heads.head' not in n]
    classifier_params = model.module.heads.head.parameters()
else:
    base_params = [p for n, p in model.named_parameters() if 'heads.head' not in n]
    classifier_params = model.heads.head.parameters()

optimizer = torch.optim.AdamW([
    {'params': base_params, 'lr': 1e-5, 'weight_decay': 1e-4},
    {'params': classifier_params, 'lr': 1e-4, 'weight_decay': 1e-4}
])

################### Load Model Checkpoint (if exists) ###################
latest_model_path = None
start_epoch = 0

for file in os.listdir(model_save_directory):
    if file.startswith(f"{model_name}_epoch_") and file.endswith(".pth"):
        epoch_num = int(re.findall(r"\d+", file)[0])
        if epoch_num > start_epoch:
            start_epoch = epoch_num
            latest_model_path = os.path.join(model_save_directory, file)
            print(latest_model_path)

if latest_model_path:
    model.load_state_dict(torch.load(latest_model_path))
    print(f"Loaded model from {latest_model_path}, continuing training from epoch {start_epoch+1}")
else:
    print("No saved model found, starting training from scratch")


################### Training Loop with Early Stopping ###################
scaler = GradScaler()
best_val_auc = 0  # Track best AUC score for early stopping
epochs_no_improve = 0

for epoch in range(start_epoch, num_epochs):
    model.train()
    train_loss = 0.0

    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch + 1}")
    for batch_idx, (images, labels) in progress_bar:
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()

        with autocast():
            outputs = model(images)
            loss = criterion(outputs.squeeze(), labels.float())

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_loss += loss.item() * images.size(0)
        progress_bar.set_postfix({'train_loss': loss.item()})

    train_loss /= len(train_loader.dataset)

    # Save model after each epoch
    if (epoch + 1) % 10 == 0:
        epoch_save_path = os.path.join(model_save_directory, f"{model_name}_epoch_{epoch + 1}.pth")
        torch.save(model.state_dict(), epoch_save_path)
        print(f"Model saved to {epoch_save_path} after epoch {epoch + 1}")

    # Validation and Early Stopping
    if (epoch + 1) % eval_every == 0:
        model.eval()
        val_loss = 0.0
        all_labels = []
        all_preds = []

        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                with autocast():
                    outputs = model(images)
                    loss = criterion(outputs.squeeze(), labels.float())
                
                val_loss += loss.item() * images.size(0)
                preds = torch.sigmoid(outputs).squeeze().cpu().numpy()
                
                all_labels.extend(labels.cpu().numpy())
                all_preds.extend(preds)

        val_loss /= len(val_loader.dataset)

        # Calculate metrics
        val_auc = roc_auc_score(all_labels, all_preds)
        val_f1 = f1_score(all_labels, (np.array(all_preds) > 0.5).astype(int))
        val_precision = precision_score(all_labels, (np.array(all_preds) > 0.5).astype(int))
        val_recall = recall_score(all_labels, (np.array(all_preds) > 0.5).astype(int))

        print(f"Validation AUC after epoch {epoch + 1}: {val_auc}")
        print(f"Validation F1 Score: {val_f1}, Precision: {val_precision}, Recall: {val_recall}")

        # Early stopping check
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            epochs_no_improve = 0
            print(f"New best AUC: {best_val_auc}")
            
            best_model_path = os.path.join(best_model_directory, f"{model_name}_best.pth")
            torch.save(model.state_dict(), best_model_path)
            print(f"Best model saved to {best_model_path}")
        else:
            epochs_no_improve += 1
            print(f"No improvement for {epochs_no_improve} validation checks.")

        if epochs_no_improve >= patience:
            print(f"Early stopping triggered. No improvement for {patience} validation checks.")
            break


Downloading: "https://download.pytorch.org/models/vit_b_16_swag-9ac1b537.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16_swag-9ac1b537.pth
100%|██████████| 331M/331M [00:01<00:00, 220MB/s]


No saved model found, starting training from scratch


  scaler = GradScaler()
  with autocast():
Epoch 1: 100%|██████████| 310/310 [02:39<00:00,  1.95it/s, train_loss=0.722]
Epoch 2: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.724]
Epoch 3: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.391]
Epoch 4: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.774]
Epoch 5: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.489]
  with autocast():


Validation AUC after epoch 5: 0.8608821923709709
Validation F1 Score: 0.6967071057192376, Precision: 0.6432, Recall: 0.7599243856332704
New best AUC: 0.8608821923709709
Best model saved to /kaggle/working/best_model/ViT_RG_ROI_best.pth


  with autocast():
Epoch 6: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.381]
Epoch 7: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.355]
Epoch 8: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.36]
Epoch 9: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.358]
Epoch 10: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.638]


Model saved to /kaggle/working/model/ViT_RG_ROI/ViT_RG_ROI_epoch_10.pth after epoch 10


  with autocast():


Validation AUC after epoch 10: 0.8795258132443985
Validation F1 Score: 0.7057864710676447, Precision: 0.6203438395415473, Recall: 0.8185255198487713
New best AUC: 0.8795258132443985
Best model saved to /kaggle/working/best_model/ViT_RG_ROI_best.pth


  with autocast():
Epoch 11: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.64]
Epoch 12: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.188]
Epoch 13: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.727]
Epoch 14: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.212]
Epoch 15: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.711]
  with autocast():


Validation AUC after epoch 15: 0.8902393959763947
Validation F1 Score: 0.7493403693931399, Precision: 0.7006578947368421, Recall: 0.8052930056710775
New best AUC: 0.8902393959763947
Best model saved to /kaggle/working/best_model/ViT_RG_ROI_best.pth


  with autocast():
Epoch 16: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.448]
Epoch 17: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.341]
Epoch 18: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.355]
Epoch 19: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=1.14]
Epoch 20: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.422]


Model saved to /kaggle/working/model/ViT_RG_ROI/ViT_RG_ROI_epoch_20.pth after epoch 20


  with autocast():


Validation AUC after epoch 20: 0.8927768967655916
Validation F1 Score: 0.7419072615923009, Precision: 0.6905537459283387, Recall: 0.8015122873345936
New best AUC: 0.8927768967655916
Best model saved to /kaggle/working/best_model/ViT_RG_ROI_best.pth


  with autocast():
Epoch 21: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.224]
Epoch 22: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.42]
Epoch 23: 100%|██████████| 310/310 [02:39<00:00,  1.95it/s, train_loss=0.545]
Epoch 24: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.606]
Epoch 25: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.611]
  with autocast():


Validation AUC after epoch 25: 0.8922606690163075
Validation F1 Score: 0.7484662576687116, Precision: 0.6977124183006536, Recall: 0.8071833648393195
No improvement for 1 validation checks.


  with autocast():
Epoch 26: 100%|██████████| 310/310 [02:39<00:00,  1.95it/s, train_loss=0.412]
Epoch 27: 100%|██████████| 310/310 [02:39<00:00,  1.95it/s, train_loss=1]
Epoch 28: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.562]
Epoch 29: 100%|██████████| 310/310 [02:39<00:00,  1.95it/s, train_loss=0.624]
Epoch 30: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.738]


Model saved to /kaggle/working/model/ViT_RG_ROI/ViT_RG_ROI_epoch_30.pth after epoch 30


  with autocast():


Validation AUC after epoch 30: 0.8980208793698308
Validation F1 Score: 0.7472150814053127, Precision: 0.6833855799373041, Recall: 0.8241965973534972
New best AUC: 0.8980208793698308
Best model saved to /kaggle/working/best_model/ViT_RG_ROI_best.pth


  with autocast():
Epoch 31: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.66]
Epoch 32: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.702]
Epoch 33: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.591]
Epoch 34: 100%|██████████| 310/310 [02:39<00:00,  1.95it/s, train_loss=0.336]
Epoch 35: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=1.17]
  with autocast():


Validation AUC after epoch 35: 0.8994330779429623
Validation F1 Score: 0.7491039426523298, Precision: 0.7120954003407155, Recall: 0.7901701323251418
New best AUC: 0.8994330779429623
Best model saved to /kaggle/working/best_model/ViT_RG_ROI_best.pth


  with autocast():
Epoch 36: 100%|██████████| 310/310 [02:39<00:00,  1.95it/s, train_loss=0.384]
Epoch 37: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.765]
Epoch 38: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=1.28]
Epoch 39: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=1.22]
Epoch 40: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.871]


Model saved to /kaggle/working/model/ViT_RG_ROI/ViT_RG_ROI_epoch_40.pth after epoch 40


  with autocast():


Validation AUC after epoch 40: 0.9029881787559283
Validation F1 Score: 0.764102564102564, Precision: 0.6973478939157566, Recall: 0.8449905482041588
New best AUC: 0.9029881787559283
Best model saved to /kaggle/working/best_model/ViT_RG_ROI_best.pth


  with autocast():
Epoch 41: 100%|██████████| 310/310 [02:39<00:00,  1.95it/s, train_loss=0.422]
Epoch 42: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.541]
Epoch 43: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.164]
Epoch 44: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.606]
Epoch 45: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.246]
  with autocast():


Validation AUC after epoch 45: 0.896676458900472
Validation F1 Score: 0.7482638888888888, Precision: 0.6918138041733547, Recall: 0.8147448015122873
No improvement for 1 validation checks.


  with autocast():
Epoch 46: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.207]
Epoch 47: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=1.34]
Epoch 48: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.284]
Epoch 49: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.467]
Epoch 50: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.448]


Model saved to /kaggle/working/model/ViT_RG_ROI/ViT_RG_ROI_epoch_50.pth after epoch 50


  with autocast():


Validation AUC after epoch 50: 0.9108550811294618
Validation F1 Score: 0.7734082397003746, Precision: 0.7662337662337663, Recall: 0.780718336483932
New best AUC: 0.9108550811294618
Best model saved to /kaggle/working/best_model/ViT_RG_ROI_best.pth


  with autocast():
Epoch 51: 100%|██████████| 310/310 [02:39<00:00,  1.95it/s, train_loss=0.337]
Epoch 52: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.153]
Epoch 53: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=1.17]
Epoch 54: 100%|██████████| 310/310 [02:39<00:00,  1.95it/s, train_loss=0.515]
Epoch 55: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.401]
  with autocast():


Validation AUC after epoch 55: 0.8939189113908068
Validation F1 Score: 0.7517985611510791, Precision: 0.7169811320754716, Recall: 0.7901701323251418
No improvement for 1 validation checks.


  with autocast():
Epoch 56: 100%|██████████| 310/310 [02:39<00:00,  1.95it/s, train_loss=0.373]
Epoch 57: 100%|██████████| 310/310 [02:39<00:00,  1.95it/s, train_loss=0.149]
Epoch 58: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.16]
Epoch 59: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.371]
Epoch 60: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.39]


Model saved to /kaggle/working/model/ViT_RG_ROI/ViT_RG_ROI_epoch_60.pth after epoch 60


  with autocast():


Validation AUC after epoch 60: 0.8974498720572233
Validation F1 Score: 0.7433016421780466, Precision: 0.6847133757961783, Recall: 0.8128544423440454
No improvement for 2 validation checks.


  with autocast():
Epoch 61: 100%|██████████| 310/310 [02:39<00:00,  1.95it/s, train_loss=0.714]
Epoch 62: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.295]
Epoch 63: 100%|██████████| 310/310 [02:39<00:00,  1.95it/s, train_loss=0.954]
Epoch 64: 100%|██████████| 310/310 [02:39<00:00,  1.95it/s, train_loss=0.236]
Epoch 65: 100%|██████████| 310/310 [02:39<00:00,  1.95it/s, train_loss=0.282]
  with autocast():


Validation AUC after epoch 65: 0.9041144094391687
Validation F1 Score: 0.7695004382120947, Precision: 0.7173202614379085, Recall: 0.8298676748582231
No improvement for 3 validation checks.


  with autocast():
Epoch 66: 100%|██████████| 310/310 [02:39<00:00,  1.95it/s, train_loss=0.195]
Epoch 67: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=1.16]
Epoch 68: 100%|██████████| 310/310 [02:39<00:00,  1.95it/s, train_loss=0.347]
Epoch 69: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.205]
Epoch 70: 100%|██████████| 310/310 [02:39<00:00,  1.95it/s, train_loss=1.35]


Model saved to /kaggle/working/model/ViT_RG_ROI/ViT_RG_ROI_epoch_70.pth after epoch 70


  with autocast():


Validation AUC after epoch 70: 0.8997125465626288
Validation F1 Score: 0.7675675675675676, Precision: 0.7332185886402753, Recall: 0.8052930056710775
No improvement for 4 validation checks.


  with autocast():
Epoch 71: 100%|██████████| 310/310 [02:39<00:00,  1.95it/s, train_loss=0.431]
Epoch 72: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.621]
Epoch 73: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.394]
Epoch 74: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.304]
Epoch 75: 100%|██████████| 310/310 [02:38<00:00,  1.95it/s, train_loss=0.545]
  with autocast():


Validation AUC after epoch 75: 0.9032054400748716
Validation F1 Score: 0.7590697674418606, Precision: 0.7472527472527473, Recall: 0.7712665406427222
No improvement for 5 validation checks.
Early stopping triggered. No improvement for 5 validation checks.


# for Best model


In [7]:
import torch
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import numpy as np

# Assuming test_df and img_folder paths are defined, and you have a test dataset loader
test_df = pd.read_csv('/kaggle/input/roi-images-hda/masks_csvs/glaucoma_masks_test.csv')  # Path to test CSV
test_dataset = GlaucomaDataset(dataframe=test_df, img_folder=img_folder, transform=test_transform(), extra_features=None)
test_loader = DataLoader(test_dataset, batch_size=20, shuffle=False, num_workers=8)

# best_model_directory = '/kaggle/input/model-for-hda/best_model/'

# Load the best model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1)

# Update the model's final layer for binary classification
num_features = model.heads.head.in_features
model.heads.head = nn.Linear(num_features, 1)

best_model_path = os.path.join(best_model_directory, f"{model_name}_best.pth")
model.load_state_dict(torch.load(best_model_path))
model.to(device)
model.eval()

# Initialize lists to collect true labels and predictions
all_labels = []
all_preds = []

# Testing loop
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        
        # Get predictions
        outputs = model(images)
        preds = torch.sigmoid(outputs).squeeze().cpu().numpy()  # Apply sigmoid to get probabilities
        
        # Store labels and predictions
        all_labels.extend(labels.cpu().numpy())
        all_preds.extend(preds)

# Convert predictions to binary classes with a threshold of 0.5
binary_preds = (np.array(all_preds) > 0.5).astype(int)

# Calculate metrics
test_accuracy = accuracy_score(all_labels, binary_preds)
test_precision = precision_score(all_labels, binary_preds)
test_recall = recall_score(all_labels, binary_preds)
test_f1 = f1_score(all_labels, binary_preds)
test_auc = roc_auc_score(all_labels, all_preds)

# Print the results
print(f"Test Accuracy: {test_accuracy:.4f}")
print(f"Test Precision: {test_precision:.4f}")
print(f"Test Recall: {test_recall:.4f}")
print(f"Test F1 Score: {test_f1:.4f}")
print(f"Test AUC-ROC: {test_auc:.4f}")


  model.load_state_dict(torch.load(best_model_path))


Test Accuracy: 0.8592
Test Precision: 0.9147
Test Recall: 0.8110
Test F1 Score: 0.8597
Test AUC-ROC: 0.9275


In [8]:
# import os
# import pandas as pd
# import torch
# import torch.nn as nn
# from torchvision.models import vit_l_16, ViT_L_16_Weights
# from torch.utils.data import DataLoader
# import numpy as np
# from sklearn.metrics import roc_auc_score, roc_curve
# from tqdm import tqdm

# # from data_utils import GlaucomaDataset, test_transform

# ####### Adjust this section as needed ############################
# # model_save_directory = './model/ViT_glaucoma_ROI'    
# # img_folder = './ROI_images' 
# # test_df = pd.read_csv('./Datasets/glaucoma_masks_test.csv')

# # Unmute to validate ViT without ROI
# model_save_directory = '/kaggle/working/model'  
# img_folder = '/kaggle/input/preprocessed-image-hda-without-roi/preprocessed_images' 
# test_df = pd.read_csv('/kaggle/input/images-hda-before-preprocess/glaucoma_no_mask_test.csv')
# #################################################################

# # Load test data
# test_dataset = GlaucomaDataset(dataframe=test_df, img_folder=img_folder, transform=test_transform)
# test_loader = DataLoader(test_dataset, batch_size=12, shuffle=True, num_workers=2)

# # Model setup
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# # model = vit_l_16(weights=ViT_L_16_Weights.IMAGENET1K_SWAG_E2E_V1)
# weights = ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1                                                                                               
# model = vit_b_16(weights=weights)
# model.heads.head = nn.Linear(model.heads.head.in_features, 1)
# model.to(device)

# def load_model(model, model_path, device):
#     state_dict = torch.load(model_path, map_location=device)
#     if any(k.startswith('module.') for k in state_dict.keys()):
#         new_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
#     else:
#         new_state_dict = state_dict
#     model.load_state_dict(new_state_dict)

# def compute_metrics(actuals, probabilities):
#     fpr, tpr, thresholds = roc_curve(actuals, probabilities)
#     target_specificity = 0.95
#     target_fpr = 1 - target_specificity

#     # Find the first threshold where FPR is <= target FPR
#     index = np.where(fpr <= target_fpr)[0][0]
#     optimal_threshold = thresholds[index]
#     predictions = (probabilities >= optimal_threshold).astype(int)

#     TP = np.sum((actuals == 1) & (predictions == 1))
#     TN = np.sum((actuals == 0) & (predictions == 0))
#     FP = np.sum((actuals == 0) & (predictions == 1))
#     FN = np.sum((actuals == 1) & (predictions == 0))

#     sensitivity = TP / (TP + FN) if TP + FN > 0 else 0
#     specificity = TN / (TN + FP) if TN + FP > 0 else 0
#     accuracy = (TP + TN) / (TP + TN + FP + FN) if TP + TN + FP + FN > 0 else 0
#     auc = roc_auc_score(actuals, probabilities) if len(np.unique(actuals)) > 1 else 0
#     return sensitivity, specificity, accuracy, auc, optimal_threshold

# # Evaluate all models in the directory
# for filename in os.listdir(model_save_directory):
#     if filename.endswith(".pth"):
#         model_path = os.path.join(model_save_directory, filename)
#         load_model(model, model_path, device)
#         model.eval()

#         all_labels = []
#         all_probabilities = []
#         with torch.no_grad():
#             for images, labels in tqdm(test_loader, desc=f"Evaluating {filename}", leave=True):
#                 images = images.to(device)
#                 outputs = model(images)
#                 probabilities = torch.sigmoid(outputs).squeeze()
#                 all_labels.extend(labels.numpy())
#                 all_probabilities.extend(probabilities.cpu().numpy())

#         sensitivity, specificity, accuracy, auc_score, optimal_threshold = compute_metrics(np.array(all_labels), np.array(all_probabilities))
#         print(f"Model: {filename}")
#         print(f"Sensitivity: {sensitivity:.4f}")
#         print(f"Specificity: {specificity:.4f}")
#         print(f"Accuracy: {accuracy:.4f}")
#         print(f"AUC Score: {auc_score:.4f}")
#         print(f"Optimal Threshold: {optimal_threshold:.4f}\n")