In [1]:
import os
import random

from collections import defaultdict

import torch
from torch import nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import timm

from PIL import Image
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn.metrics import confusion_matrix, f1_score, accuracy_score, precision_score, classification_report, roc_auc_score, recall_score
from sklearn.model_selection import StratifiedKFold

# Check for CUDA device
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(DEVICE)

SEED = 16

torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

cuda


In [2]:
# Parameters
dataset_path = '/kaggle/input/two-stage-dataset/Train_2Stage/BvsC'
model_name = 'tiny_vit_21m_512.dist_in22k_ft_in1k'
num_epochs = 10
num_classes = 2
batch_size = 14
k_folds = 5

In [3]:
def get_sample_count(folder_path):
    output = {}
    classes = os.listdir(folder_path)
    for i in classes:
        output[i] = len(os.listdir(f'{folder_path}/{i}'))
    return output

print(get_sample_count(dataset_path))

{'B': 1552, 'C': 1705}


In [3]:
# GET MODEL INFO
m = timm.create_model(model_name, pretrained=True, num_classes=num_classes, drop_rate=0.3)
model_info = m.default_cfg
del m

model.safetensors:   0%|          | 0.00/85.2M [00:00<?, ?B/s]

In [4]:
input_shape = model_info['input_size'][1:]
transform_mean = model_info['mean']
transform_std = model_info['std']

print(f"USING MODEL ARCHITECTURE {model_info['architecture']} ")
print(f"INPUT SHAPE = {input_shape}")
print(f"       MEAN = {transform_mean}")
print(f"        STD = {transform_std}")

USING MODEL ARCHITECTURE tiny_vit_21m_512 
INPUT SHAPE = (512, 512)
       MEAN = (0.485, 0.456, 0.406)
        STD = (0.229, 0.224, 0.225)


In [5]:
transform = transforms.Compose([
        transforms.Resize(input_shape),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize(mean=transform_mean, std=transform_std)
    ])

dataset = datasets.ImageFolder(root=dataset_path, transform=transform)

print("CLASS MAPPING")
print(dataset.class_to_idx)

CLASS MAPPING
{'B': 0, 'C': 1}


In [6]:
kfold = StratifiedKFold(n_splits=k_folds, shuffle=True, random_state=SEED)

history = []

# K-Fold Loop
for fold, (train_idx, val_idx) in enumerate(kfold.split(dataset, dataset.targets)):
    print('\n--------------------------------')
    print(f'FOLD {fold+1}/{k_folds}')
    print('--------------------------------')

    # Create data loaders for this fold
    train_subsampler = torch.utils.data.SubsetRandomSampler(train_idx)
    val_subsampler = torch.utils.data.SubsetRandomSampler(val_idx)

    train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_subsampler, num_workers=4)
    val_loader = DataLoader(dataset, batch_size=batch_size, sampler=val_subsampler, num_workers=4)

    # Initialize the model for this fold
    model = timm.create_model(model_name, pretrained=True, num_classes=num_classes, drop_rate=0.3)
    # Add sigmoid activation for BCE Loss
    model = nn.Sequential(
        model,
        nn.Sigmoid()
    )
    model = model.to(DEVICE)

    # Change to BCE Loss
    criterion = nn.BCELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    # Variables to keep track of the best model for this fold
    best_val_acc = -999
    best_epoch = 0
    best_model_wts = None
    best_preds = {
        'labels':None,
        'preds':None
    }

    # Model training loop START
    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        model.train()
        running_loss = 0.0

        # Training
        for inputs, labels in train_loader:
            inputs = inputs.to(DEVICE)
            # Convert labels to one-hot encoding and float type for BCE
            labels_one_hot = torch.zeros(labels.size(0), num_classes)
            labels_one_hot.scatter_(1, labels.unsqueeze(1), 1)
            labels_one_hot = labels_one_hot.to(DEVICE)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels_one_hot)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Validation
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0

        all_labels = []
        all_preds = []
        
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs = inputs.to(DEVICE)
                # Convert labels to one-hot encoding and float type for BCE
                labels_one_hot = torch.zeros(labels.size(0), num_classes)
                labels_one_hot.scatter_(1, labels.unsqueeze(1), 1)
                labels_one_hot = labels_one_hot.to(DEVICE)
                
                outputs = model(inputs)
                loss = criterion(outputs, labels_one_hot)
                val_loss += loss.item()

                # Convert probabilities to predictions
                predicted = torch.max(outputs, 1)[1]
                total += labels.size(0)
                correct += (predicted.cpu() == labels).sum().item()
                
                all_labels.extend(labels.numpy())
                all_preds.extend(predicted.cpu().numpy())

        val_acc = correct / total

        print(f'Loss: {running_loss/len(train_loader)} , Val Loss: {val_loss/len(val_loader)}, Val Accuracy: {val_acc}')

        # Always keep track of the best model and best output
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_epoch = epoch + 1
            best_model_wts = model.state_dict()
            best_preds = {
                'labels': all_labels,
                'preds': all_preds
            }

    # TRAIN LOOP ENDS HERE
    all_labels = best_preds['labels']
    all_preds = best_preds['preds']
    
    # Calculate confusion matrix
    conf_matrix = confusion_matrix(all_labels, all_preds)
    
    # Sensitivity, Specificity calculation
    sensitivity = []
    specificity = []
    
    for i in range(num_classes):
        tp = conf_matrix[i, i]
        fn = sum(conf_matrix[i, :]) - tp
        fp = sum(conf_matrix[:, i]) - tp
        tn = conf_matrix.sum() - (tp + fp + fn)

        sensitivity.append(tp / (tp + fn) if (tp + fn) > 0 else 0)
        specificity.append(tn / (tn + fp) if (tn + fp) > 0 else 0)

    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='macro')
    f1 = f1_score(all_labels, all_preds, average='macro')
    recall = recall_score(all_labels, all_preds, average='macro')

    best_hist = { 'Accuracy': accuracy, 'Precision': precision, 'F1 Score': f1, 'recall': recall }

    for class_name in dataset.class_to_idx:
        best_hist[f'{class_name} Sensitivity'] = sensitivity[dataset.class_to_idx[class_name]]
        best_hist[f'{class_name} Specificity'] = specificity[dataset.class_to_idx[class_name]]

    print()
    for key, value in best_hist.items():
        print(f'{key}: {value}')

    print()
    # get classification report
    print(classification_report(all_labels, all_preds, target_names=[i for i in dataset.class_to_idx]))
    print()

    # Save model to disk
    print(f'Best Epoch for Fold {fold+1}: {best_epoch} with Validation Accuracy: {best_val_acc}')
    torch.save(best_model_wts, f'fold{fold+1}_model.pth')
    
    history.append(best_hist)


--------------------------------
FOLD 1/5
--------------------------------
Epoch 1/10
Loss: 0.5604024905571963 , Val Loss: 0.41286928571285086, Val Accuracy: 0.8343558282208589
Epoch 2/10
Loss: 0.3754924373990074 , Val Loss: 0.3697203879660748, Val Accuracy: 0.8266871165644172
Epoch 3/10
Loss: 0.304266178671689 , Val Loss: 0.2768536056292818, Val Accuracy: 0.8803680981595092
Epoch 4/10
Loss: 0.26691465531998776 , Val Loss: 0.2465699036070641, Val Accuracy: 0.9003067484662577
Epoch 5/10
Loss: 0.24850142717281765 , Val Loss: 0.2540544186659316, Val Accuracy: 0.8849693251533742
Epoch 6/10
Loss: 0.2074119823780608 , Val Loss: 0.25150415903710305, Val Accuracy: 0.8880368098159509
Epoch 7/10
Loss: 0.19895353383399586 , Val Loss: 0.23467337205371958, Val Accuracy: 0.8941717791411042
Epoch 8/10
Loss: 0.2046414692014615 , Val Loss: 0.20743345683242412, Val Accuracy: 0.9141104294478528
Epoch 9/10
Loss: 0.15713811286769927 , Val Loss: 0.2003417765682048, Val Accuracy: 0.9202453987730062
Epoch 10

In [7]:
history

[{'Accuracy': 0.9202453987730062,
  'Precision': 0.9245980445670181,
  'F1 Score': 0.919609237919097,
  'recall': 0.9180960104100857,
  'B Sensitivity': 0.8713826366559485,
  'B Specificity': 0.9648093841642229,
  'C Sensitivity': 0.9648093841642229,
  'C Specificity': 0.8713826366559485},
 {'Accuracy': 0.9493865030674846,
  'Precision': 0.9539574402939375,
  'F1 Score': 0.9489966931765654,
  'recall': 0.9473696617665086,
  'B Sensitivity': 0.9035369774919614,
  'B Specificity': 0.9912023460410557,
  'C Sensitivity': 0.9912023460410557,
  'C Specificity': 0.9035369774919614},
 {'Accuracy': 0.9339477726574501,
  'Precision': 0.9337015009912206,
  'F1 Score': 0.9338853710784811,
  'recall': 0.9346041055718475,
  'B Sensitivity': 0.9483870967741935,
  'B Specificity': 0.9208211143695014,
  'C Sensitivity': 0.9208211143695014,
  'C Specificity': 0.9483870967741935},
 {'Accuracy': 0.9308755760368663,
  'Precision': 0.9310859471240406,
  'F1 Score': 0.9306635424432848,
  'recall': 0.93035190

In [8]:
!zip BvsC.zip ./*.pth

  adding: fold1_model.pth (deflated 7%)
  adding: fold2_model.pth (deflated 7%)
  adding: fold3_model.pth (deflated 7%)
  adding: fold4_model.pth (deflated 7%)
  adding: fold5_model.pth (deflated 7%)


In [9]:
hist_df = pd.DataFrame(history)
hist_df

Unnamed: 0,Accuracy,Precision,F1 Score,recall,B Sensitivity,B Specificity,C Sensitivity,C Specificity
0,0.920245,0.924598,0.919609,0.918096,0.871383,0.964809,0.964809,0.871383
1,0.949387,0.953957,0.948997,0.94737,0.903537,0.991202,0.991202,0.903537
2,0.933948,0.933702,0.933885,0.934604,0.948387,0.920821,0.920821,0.948387
3,0.930876,0.931086,0.930664,0.930352,0.919355,0.941349,0.941349,0.919355
4,0.924731,0.92475,0.924525,0.92434,0.916129,0.932551,0.932551,0.916129
