In [None]:
from torch.utils.data import Dataset, DataLoader
import torch
import random
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from torchvision.transforms import RandomErasing
from sklearn.metrics import roc_auc_score, f1_score, confusion_matrix

from collections import Counter
import numpy as np
import cv2
import os
import pandas as pd
from torch.cuda.amp import autocast, GradScaler


####### Model parameters / run configuration

# Name of this run: used as the results sub-folder and as the prefix of every output file.
task_name = 'FGS_ensemble'

# Root directories of the training and test images.
train_dir, test_pre = 'TRAIN', 'TEST'

# Parent folder under which each run gets its own sub-folder.
root_folder = 'RESULTS'

# Create the run folder up front so later writes cannot fail on a missing directory.
save_folder = os.path.join(root_folder, task_name)
os.makedirs(save_folder, exist_ok=True)

# Per-epoch text log and best-checkpoint path, both inside the run folder.
txt_fname = os.path.join(save_folder, f'{task_name}.txt')
save_model_path = os.path.join(save_folder, f'{task_name}best_model.pth')

  check_for_updates()


In [2]:
def set_seed(no):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        no: Integer seed applied to every random number generator.
    """
    torch.manual_seed(no)
    random.seed(no)
    np.random.seed(no)
    # Export the actual seed (previously `str()` exported an empty string).
    # The variable is read at interpreter start-up, so this only affects
    # Python processes launched after this call, not the running kernel.
    os.environ['PYTHONHASHSEED'] = str(no)
    # Disable the cuDNN auto-tuner and ask for deterministic kernels so that
    # repeated runs pick the same algorithms.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

set_seed(100)

In [None]:
##### custom data set 

class CustomDataset(Dataset):
    """Two-class image dataset read from ``root_dir/HEALTHY`` and ``root_dir/ABNORMAL``.

    Every item is ``(image, class_label, severity_score, filename)``:
    ``class_label`` is 0 for HEALTHY and 1 for ABNORMAL, and ``severity_score``
    is looked up by file name in the CSV (0 when the file is not listed).

    Args:
        root_dir: Directory containing the ``HEALTHY`` and ``ABNORMAL`` sub-folders.
        csv_path: CSV file with the columns ``filename`` and ``severity_category``.
        transform: Optional callable invoked as ``transform(image=rgb_array)`` that
            returns a mapping with an ``"image"`` entry (albumentations style).
    """

    # Position in this tuple is the integer class label.
    CLASS_DIRS = ('HEALTHY', 'ABNORMAL')
    # Accepted image extensions, compared case-insensitively.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')
    # Severity reported for files that do not appear in the CSV.
    DEFAULT_SEVERITY = 0

    def __init__(self, root_dir, csv_path, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.class_labels = []
        self.severity_scores = {}
        self._load_severity_scores(csv_path)
        self._load_dataset()

    def _load_severity_scores(self, csv_path):
        """Fill ``self.severity_scores`` with ``filename -> severity_category``.

        Raises:
            KeyError: If the CSV lacks the ``filename`` or ``severity_category`` column.
        """
        df = pd.read_csv(csv_path)
        # Column-wise zip instead of iterrows(): no per-row Series is built and
        # a later duplicate file name still overrides an earlier one.
        self.severity_scores.update(zip(df['filename'], df['severity_category']))

    def _load_dataset(self):
        """Collect image paths and class labels from the two class folders."""
        for class_label, class_dir in enumerate(self.CLASS_DIRS):  # 0: HEALTHY, 1: ABNORMAL
            class_dir_path = os.path.join(self.root_dir, class_dir)
            if not os.path.exists(class_dir_path):
                print(f"Directory not found: {class_dir_path}")
                continue
            # os.listdir() order is arbitrary; sorting keeps the sample order
            # (and therefore seeded shuffling) identical across machines.
            for filename in sorted(os.listdir(class_dir_path)):
                # Case-insensitive match so files such as "IMG.JPG" are not skipped.
                if filename.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.image_paths.append(os.path.join(class_dir_path, filename))
                    self.class_labels.append(class_label)

    def _severity_for(self, image_path):
        """Return the CSV severity for ``image_path`` (default when unlisted)."""
        return self.severity_scores.get(os.path.basename(image_path), self.DEFAULT_SEVERITY)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` as ``(image, class_label, severity_score, filename)``.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image file.
        """
        class_label = self.class_labels[idx]
        image_path = self.image_paths[idx]
        filename = os.path.basename(image_path)
        severity_score = self._severity_for(image_path)
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Image not found at path: {image_path}")
        # OpenCV decodes to BGR; convert to RGB before any transform.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented["image"]
        else:
            image = image.astype(np.float32)
        return image, class_label, severity_score, filename

    def get_labels(self):
        """Return ``(class_labels, severity_scores)`` aligned with the samples.

        Both lists have ``len(self)`` entries in dataset index order. The
        severity list previously contained every value in the CSV, which did
        not correspond to ``class_labels`` in either length or order.
        """
        severities = [self._severity_for(path) for path in self.image_paths]
        return self.class_labels, severities
    
    

##### Channel-wise mean and std handed to A.Normalize in both pipelines
##### (described by the author as statistics of the original data)
mean = [0.4907, 0.2818, 0.34892]
std = [0.1352, 0.1340, 0.1515]


def _normalize_and_tensorize():
    """Return fresh Normalize + ToTensorV2 steps that terminate each pipeline."""
    return [A.Normalize(mean=mean, std=std), ToTensorV2()]


# A single 50x50 hole: the minimum and maximum settings coincide.
_coarse_dropout_args = dict(
    max_holes=1,
    max_height=50,
    max_width=50,
    min_holes=1,
    min_height=50,
    min_width=50,
)

# Training augmentations, applied in exactly this order before normalisation.
_train_augmentations = [
    A.Resize(224, 224),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
    A.HueSaturationValue(p=0.2),
    A.CoarseDropout(p=0.5, **_coarse_dropout_args),
    A.Rotate(limit=30, p=0.5),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=30, p=0.5),
    A.GridDistortion(p=0.5),
    A.GaussianBlur(blur_limit=(3, 7), p=0.5),
    A.CLAHE(p=0.5),
    A.RGBShift(p=0.5),
]
train_transform = A.Compose(_train_augmentations + _normalize_and_tensorize())

# Evaluation pipeline: no random steps, only resize to 256, centre-crop to 224, normalise.
test_transform = A.Compose(
    [A.Resize(256, 256), A.CenterCrop(224, 224)] + _normalize_and_tensorize()
)

# Both splits look up severity categories in the same CSV.
score_path = 'all_score_category.csv'
trainset, testset_pre = (
    CustomDataset(split_dir, score_path, transform=split_transform)
    for split_dir, split_transform in (
        (train_dir, train_transform),
        (test_pre, test_transform),
    )
)

In [4]:
####model structure here
import timm
from torchvision import models

class EnsembleModel(nn.Module):
    """Ensemble that averages the 2-class logits of three pretrained backbones.

    Members: DINOv2 ViT-B/14 with registers (via ``torch.hub``), EfficientNet-B0
    (via ``timm``) and ResNet18 (via ``torchvision``). In every backbone all
    parameter tensors except the last four returned by ``parameters()`` are
    frozen; each backbone then gets a new trainable head of the same shape
    (Linear -> ReLU -> BatchNorm1d -> Dropout -> Linear) ending in 2 outputs.
    """

    @staticmethod
    def _freeze_all_but_last(module, keep=4):
        """Set ``requires_grad=False`` on every parameter except the last ``keep``."""
        for tensor in list(module.parameters())[:-keep]:
            tensor.requires_grad = False

    @staticmethod
    def _make_head(in_features):
        """Build the shared classification head mapping ``in_features`` to 2 logits."""
        return nn.Sequential(
            nn.Linear(in_features, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(0.3),
            nn.Linear(256, 2),
        )

    def __init__(self):
        super().__init__()

        # DINOv2: the hub model is used as a feature extractor and a separate
        # head consumes its output (768 inputs, per the first Linear layer).
        self.dinov2 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg')
        self._freeze_all_but_last(self.dinov2)
        self.dino_classifier = self._make_head(768)

        # EfficientNet-B0: freezing happens before the classifier is swapped,
        # so the new head stays fully trainable.
        self.efficient = timm.create_model('efficientnet_b0', pretrained=True)
        self._freeze_all_but_last(self.efficient)
        self.efficient.classifier = self._make_head(1280)

        # ResNet18: same treatment, with the new head replacing ``fc``.
        self.resnet = models.resnet18(pretrained=True)
        self._freeze_all_but_last(self.resnet)
        self.resnet.fc = self._make_head(512)

    @torch.cuda.amp.autocast()
    def forward(self, x):
        """Return the mean of the three members' logits for the batch ``x``."""
        member_logits = (
            self.dino_classifier(self.dinov2(x)),
            self.efficient(x),
            self.resnet(x),
        )
        return (member_logits[0] + member_logits[1] + member_logits[2]) / 3


model = EnsembleModel()


Using cache found in /home/tanmoy/.cache/torch/hub/facebookresearch_dinov2_main


In [5]:
# Batch both datasets with one shared batch size: the training data is
# reshuffled on every pass, the test data keeps its stored order.
_loader_batch_size = 16
trainloader = DataLoader(trainset, shuffle=True, batch_size=_loader_batch_size)
testloader = DataLoader(testset_pre, shuffle=False, batch_size=_loader_batch_size)

#### model training and testing here
import torch.nn.functional as F
import numpy as np
from sklearn.metrics import roc_auc_score, f1_score, confusion_matrix, precision_recall_curve, auc
import torch.optim as optim

# Use the first CUDA device when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

def evaluate_model(model, dataloader, device):
    """Evaluate a binary classifier on ``dataloader`` and print/return its metrics.

    Args:
        model: Module returning one row of 2 logits per input image.
        dataloader: Iterable yielding ``(images, class_labels, _, _)`` batches,
            with ``class_labels`` a tensor of 0/1 values.
        device: Device the images are moved to before the forward pass.

    Returns:
        dict with the keys ``auc``, ``f1``, ``specificity``, ``sensitivity`` and
        ``accuracy``. ``auc`` is NaN when only one class occurs in the labels;
        ``specificity`` / ``sensitivity`` are NaN when the corresponding class
        has no samples.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """

    def _ratio(numerator, denominator):
        # Undefined rates are reported as NaN rather than dividing by zero.
        return float(numerator) / float(denominator) if denominator else float('nan')

    model.eval()
    prob_batches = []
    label_batches = []

    with torch.no_grad():
        for images, class_labels, _, _ in dataloader:
            class_outputs = model(images.to(device))
            # The logits can arrive in half precision (the model's forward is
            # decorated with autocast); take the softmax in float32 so the
            # probabilities used for argmax and AUC are not quantised to fp16.
            class_probs = F.softmax(class_outputs.float(), dim=1).cpu().numpy()
            prob_batches.append(class_probs)
            # Labels are only needed on the host, so they are not moved to `device`.
            label_batches.append(class_labels.cpu().numpy())

    if not prob_batches:
        raise ValueError("evaluate_model received an empty dataloader")

    all_class_probs = np.concatenate(prob_batches, axis=0)
    all_class_labels = np.concatenate(label_batches, axis=0)
    all_class_preds = np.argmax(all_class_probs, axis=1)

    # Fix the label set so the matrix is always 2x2, even when a class is
    # missing from both the labels and the predictions.
    cm = confusion_matrix(all_class_labels, all_class_preds, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()

    # roc_auc_score raises when only one class is present in the labels.
    if np.unique(all_class_labels).size < 2:
        auc_value = float('nan')
    else:
        auc_value = roc_auc_score(all_class_labels, all_class_probs[:, 1])

    metrics = {
        'auc': auc_value,
        'f1': f1_score(all_class_labels, all_class_preds),
        'specificity': _ratio(tn, tn + fp),
        'sensitivity': _ratio(tp, tp + fn),
        'accuracy': _ratio(tp + tn, tp + tn + fp + fn)
    }

    print("\nConfusion Matrix:")
    print(cm)
    print("\nPerformance Metrics:")
    print(f"AUC: {metrics['auc']:.4f}")
    print(f"F1: {metrics['f1']:.4f}")
    print(f"Specificity: {metrics['specificity']:.4f}")
    print(f"Sensitivity: {metrics['sensitivity']:.4f}")
    print(f"Accuracy: {metrics['accuracy']:.4f}")

    return metrics

cuda:0


In [6]:
# Training loop
model = model.to(device)
num_epochs = 150
best_f1 = 0.0
best_sens = 0.0

# Alternative optimiser / scheduler (disabled):
# optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=3)
optimizer = optim.SGD(model.parameters(), lr=0.00002,
                      momentum=0.9, weight_decay=2e-3)
# Class 1 (ABNORMAL) is weighted above class 0. The weights need not sum to 1:
# with the default 'mean' reduction the loss is normalised by the summed
# weights of the targets in the batch, so only their ratio matters.
criterion = nn.CrossEntropyLoss(weight=torch.tensor([0.4, 0.6], device=device))

# Stop after this many consecutive epochs without a newly saved best model.
early_stopping_patience = 20
no_improve = 0

scaler = GradScaler()

# NOTE(review): checkpoint selection and early stopping both use `testloader`,
# so the reported test metrics are optimistically biased; a separate
# validation split would be needed for an unbiased estimate.
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for images, labels, _, _ in trainloader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()

        # Mixed-precision forward pass and loss; the scaler guards the
        # backward pass against fp16 gradient underflow.
        with autocast():
            outputs = model(images)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()

    metrics = evaluate_model(model, testloader, device)
    performance_string = (f"Test Set Performance - AUC: {metrics['auc']:.2f}, F1 Score: {metrics['f1']:.2f}, Specificity: {metrics['specificity']:.2f}, Sensitivity: {metrics['sensitivity']:.2f}, Accuracy: {metrics['accuracy']:.2f}")
    with open(txt_fname, 'a') as file:
        file.write(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(trainloader):.4f}\n")
        file.write(performance_string+'\n')

    # A checkpoint counts as an improvement only when F1 rises while both
    # sensitivity and specificity stay above 0.6.
    improved = (
        metrics['f1'] > best_f1
        and metrics['sensitivity'] > 0.6
        and metrics['specificity'] > 0.6
    )

    if improved:
        best_f1 = metrics['f1']
        best_sens = metrics['sensitivity']
        torch.save(model.state_dict(), save_model_path)
        with open(txt_fname, 'a') as file:
            file.write("New model saved"+'\n')
        print(f'New model saved with F1: {best_f1} and sens: {best_sens} ')
        no_improve = 0
    else:
        # Previously this counter was never updated, so the early-stopping
        # check below could never fire.
        no_improve += 1

    if no_improve >= early_stopping_patience:
        print("Early stopping triggered")
        break


Confusion Matrix:
[[150  27]
 [ 15   6]]

Performance Metrics:
AUC: 0.6879
F1: 0.2222
Specificity: 0.8475
Sensitivity: 0.2857
Accuracy: 0.7879

Confusion Matrix:
[[125  52]
 [  9  12]]

Performance Metrics:
AUC: 0.6651
F1: 0.2824
Specificity: 0.7062
Sensitivity: 0.5714
Accuracy: 0.6919

Confusion Matrix:
[[116  61]
 [  8  13]]

Performance Metrics:
AUC: 0.6488
F1: 0.2737
Specificity: 0.6554
Sensitivity: 0.6190
Accuracy: 0.6515
New model saved with F1: 0.2736842105263158 and sens: 0.6190476190476191 

Confusion Matrix:
[[99 78]
 [ 6 15]]

Performance Metrics:
AUC: 0.6550
F1: 0.2632
Specificity: 0.5593
Sensitivity: 0.7143
Accuracy: 0.5758

Confusion Matrix:
[[95 82]
 [ 5 16]]

Performance Metrics:
AUC: 0.6521
F1: 0.2689
Specificity: 0.5367
Sensitivity: 0.7619
Accuracy: 0.5606

Confusion Matrix:
[[99 78]
 [ 5 16]]

Performance Metrics:
AUC: 0.6770
F1: 0.2783
Specificity: 0.5593
Sensitivity: 0.7619
Accuracy: 0.5808

Confusion Matrix:
[[93 84]
 [ 4 17]]

Performance Metrics:
AUC: 0.6894
F1


Confusion Matrix:
[[94 83]
 [ 6 15]]

Performance Metrics:
AUC: 0.6715
F1: 0.2521
Specificity: 0.5311
Sensitivity: 0.7143
Accuracy: 0.5505

Confusion Matrix:
[[93 84]
 [ 6 15]]

Performance Metrics:
AUC: 0.6671
F1: 0.2500
Specificity: 0.5254
Sensitivity: 0.7143
Accuracy: 0.5455

Confusion Matrix:
[[104  73]
 [  7  14]]

Performance Metrics:
AUC: 0.6698
F1: 0.2593
Specificity: 0.5876
Sensitivity: 0.6667
Accuracy: 0.5960

Confusion Matrix:
[[107  70]
 [  7  14]]

Performance Metrics:
AUC: 0.6667
F1: 0.2667
Specificity: 0.6045
Sensitivity: 0.6667
Accuracy: 0.6111

Confusion Matrix:
[[102  75]
 [  6  15]]

Performance Metrics:
AUC: 0.6770
F1: 0.2703
Specificity: 0.5763
Sensitivity: 0.7143
Accuracy: 0.5909

Confusion Matrix:
[[100  77]
 [  6  15]]

Performance Metrics:
AUC: 0.6507
F1: 0.2655
Specificity: 0.5650
Sensitivity: 0.7143
Accuracy: 0.5808

Confusion Matrix:
[[104  73]
 [  7  14]]

Performance Metrics:
AUC: 0.6677
F1: 0.2593
Specificity: 0.5876
Sensitivity: 0.6667
Accuracy: 0.5960



Confusion Matrix:
[[101  76]
 [  6  15]]

Performance Metrics:
AUC: 0.6850
F1: 0.2679
Specificity: 0.5706
Sensitivity: 0.7143
Accuracy: 0.5859

Confusion Matrix:
[[113  64]
 [  6  15]]

Performance Metrics:
AUC: 0.6965
F1: 0.3000
Specificity: 0.6384
Sensitivity: 0.7143
Accuracy: 0.6465

Confusion Matrix:
[[120  57]
 [  8  13]]

Performance Metrics:
AUC: 0.6847
F1: 0.2857
Specificity: 0.6780
Sensitivity: 0.6190
Accuracy: 0.6717

Confusion Matrix:
[[108  69]
 [  7  14]]

Performance Metrics:
AUC: 0.6817
F1: 0.2692
Specificity: 0.6102
Sensitivity: 0.6667
Accuracy: 0.6162

Confusion Matrix:
[[107  70]
 [  7  14]]

Performance Metrics:
AUC: 0.6649
F1: 0.2667
Specificity: 0.6045
Sensitivity: 0.6667
Accuracy: 0.6111

Confusion Matrix:
[[115  62]
 [  7  14]]

Performance Metrics:
AUC: 0.6778
F1: 0.2887
Specificity: 0.6497
Sensitivity: 0.6667
Accuracy: 0.6515

Confusion Matrix:
[[122  55]
 [  7  14]]

Performance Metrics:
AUC: 0.6893
F1: 0.3111
Specificity: 0.6893
Sensitivity: 0.6667
Accuracy: