In [2]:
import os
import random

from collections import defaultdict

import torch
from torch import nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import timm

from PIL import Image
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn.metrics import confusion_matrix, f1_score, accuracy_score, precision_score, classification_report, roc_auc_score, recall_score
from sklearn.model_selection import StratifiedKFold

# Check for CUDA device
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(DEVICE)

SEED = 16

torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

cuda


In [3]:
# Parameters
dataset_path = '/kaggle/input/two-stage-dataset/Train_2Stage/NvsP'
model_name = 'tiny_vit_21m_512.dist_in22k_ft_in1k'
num_epochs = 10
num_classes = 2
batch_size = 14
k_folds = 5

In [4]:
def get_sample_count(folder_path):
    output = {}
    classes = os.listdir(folder_path)
    for i in classes:
        output[i] = len(os.listdir(f'{folder_path}/{i}'))
    return output

print(get_sample_count(dataset_path))

{'N': 1006, 'P': 2251}


In [3]:
# GET MODEL INFO
m = timm.create_model(model_name, pretrained=True, num_classes=num_classes, drop_rate=0.3)
model_info = m.default_cfg
del m

model.safetensors:   0%|          | 0.00/85.2M [00:00<?, ?B/s]

In [4]:
input_shape = model_info['input_size'][1:]
transform_mean = model_info['mean']
transform_std = model_info['std']

print(f"USING MODEL ARCHITECTURE {model_info['architecture']} ")
print(f"INPUT SHAPE = {input_shape}")
print(f"       MEAN = {transform_mean}")
print(f"        STD = {transform_std}")

USING MODEL ARCHITECTURE tiny_vit_21m_512 
INPUT SHAPE = (512, 512)
       MEAN = (0.485, 0.456, 0.406)
        STD = (0.229, 0.224, 0.225)


In [5]:
transform = transforms.Compose([
        transforms.Resize(input_shape),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize(mean=transform_mean, std=transform_std)
    ])

dataset = datasets.ImageFolder(root=dataset_path, transform=transform)

print("CLASS MAPPING")
print(dataset.class_to_idx)

CLASS MAPPING
{'N': 0, 'P': 1}


In [6]:
kfold = StratifiedKFold(n_splits=k_folds, shuffle=True, random_state=SEED)

history = []

# K-Fold Loop
for fold, (train_idx, val_idx) in enumerate(kfold.split(dataset, dataset.targets)):
    print('\n--------------------------------')
    print(f'FOLD {fold+1}/{k_folds}')
    print('--------------------------------')

    # Create data loaders for this fold
    train_subsampler = torch.utils.data.SubsetRandomSampler(train_idx)
    val_subsampler = torch.utils.data.SubsetRandomSampler(val_idx)

    train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_subsampler, num_workers=4)
    val_loader = DataLoader(dataset, batch_size=batch_size, sampler=val_subsampler, num_workers=4)

    # Initialize the model for this fold
    model = timm.create_model(model_name, pretrained=True, num_classes=num_classes, drop_rate=0.3)
    # Add sigmoid activation for BCE Loss
    model = nn.Sequential(
        model,
        nn.Sigmoid()
    )
    model = model.to(DEVICE)

    # Change to BCE Loss
    criterion = nn.BCELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    # Variables to keep track of the best model for this fold
    best_val_acc = -999
    best_epoch = 0
    best_model_wts = None
    best_preds = {
        'labels':None,
        'preds':None
    }

    # Model training loop START
    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        model.train()
        running_loss = 0.0

        # Training
        for inputs, labels in train_loader:
            inputs = inputs.to(DEVICE)
            # Convert labels to one-hot encoding and float type for BCE
            labels_one_hot = torch.zeros(labels.size(0), num_classes)
            labels_one_hot.scatter_(1, labels.unsqueeze(1), 1)
            labels_one_hot = labels_one_hot.to(DEVICE)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels_one_hot)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Validation
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0

        all_labels = []
        all_preds = []
        
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs = inputs.to(DEVICE)
                # Convert labels to one-hot encoding and float type for BCE
                labels_one_hot = torch.zeros(labels.size(0), num_classes)
                labels_one_hot.scatter_(1, labels.unsqueeze(1), 1)
                labels_one_hot = labels_one_hot.to(DEVICE)
                
                outputs = model(inputs)
                loss = criterion(outputs, labels_one_hot)
                val_loss += loss.item()

                # Convert probabilities to predictions
                predicted = torch.max(outputs, 1)[1]
                total += labels.size(0)
                correct += (predicted.cpu() == labels).sum().item()
                
                all_labels.extend(labels.numpy())
                all_preds.extend(predicted.cpu().numpy())

        val_acc = correct / total

        print(f'Loss: {running_loss/len(train_loader)} , Val Loss: {val_loss/len(val_loader)}, Val Accuracy: {val_acc}')

        # Always keep track of the best model and best output
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_epoch = epoch + 1
            best_model_wts = model.state_dict()
            best_preds = {
                'labels': all_labels,
                'preds': all_preds
            }

    # TRAIN LOOP ENDS HERE
    all_labels = best_preds['labels']
    all_preds = best_preds['preds']
    
    # Calculate confusion matrix
    conf_matrix = confusion_matrix(all_labels, all_preds)
    
    # Sensitivity, Specificity calculation
    sensitivity = []
    specificity = []
    
    for i in range(num_classes):
        tp = conf_matrix[i, i]
        fn = sum(conf_matrix[i, :]) - tp
        fp = sum(conf_matrix[:, i]) - tp
        tn = conf_matrix.sum() - (tp + fp + fn)

        sensitivity.append(tp / (tp + fn) if (tp + fn) > 0 else 0)
        specificity.append(tn / (tn + fp) if (tn + fp) > 0 else 0)

    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='macro')
    f1 = f1_score(all_labels, all_preds, average='macro')
    recall = recall_score(all_labels, all_preds, average='macro')

    best_hist = { 'Accuracy': accuracy, 'Precision': precision, 'F1 Score': f1, 'recall': recall }

    for class_name in dataset.class_to_idx:
        best_hist[f'{class_name} Sensitivity'] = sensitivity[dataset.class_to_idx[class_name]]
        best_hist[f'{class_name} Specificity'] = specificity[dataset.class_to_idx[class_name]]

    print()
    for key, value in best_hist.items():
        print(f'{key}: {value}')

    print()
    # get classification report
    print(classification_report(all_labels, all_preds, target_names=[i for i in dataset.class_to_idx]))
    print()

    # Save model to disk
    print(f'Best Epoch for Fold {fold+1}: {best_epoch} with Validation Accuracy: {best_val_acc}')
    torch.save(best_model_wts, f'fold{fold+1}_model.pth')
    
    history.append(best_hist)


--------------------------------
FOLD 1/5
--------------------------------
Epoch 1/10
Loss: 0.48672187702859787 , Val Loss: 0.31891365349292755, Val Accuracy: 0.8788343558282209
Epoch 2/10
Loss: 0.2921659669933472 , Val Loss: 0.25931311049994005, Val Accuracy: 0.906441717791411
Epoch 3/10
Loss: 0.23973783049832054 , Val Loss: 0.23746343337474984, Val Accuracy: 0.9095092024539877
Epoch 4/10
Loss: 0.20266131829609846 , Val Loss: 0.2123019612850027, Val Accuracy: 0.9263803680981595
Epoch 5/10
Loss: 0.18180278467542346 , Val Loss: 0.17132031410298448, Val Accuracy: 0.9309815950920245
Epoch 6/10
Loss: 0.1558632727134674 , Val Loss: 0.18904223465459777, Val Accuracy: 0.9386503067484663
Epoch 7/10
Loss: 0.14521032400469092 , Val Loss: 0.1837234071039773, Val Accuracy: 0.9294478527607362
Epoch 8/10
Loss: 0.12773524507143122 , Val Loss: 0.18615872103800166, Val Accuracy: 0.9371165644171779
Epoch 9/10
Loss: 0.11546220748441265 , Val Loss: 0.16681487941520012, Val Accuracy: 0.9325153374233128
Ep

In [7]:
history

[{'Accuracy': 0.9447852760736196,
  'Precision': 0.9309819897084048,
  'F1 Score': 0.9362970200293111,
  'recall': 0.9422662266226622,
  'N Sensitivity': 0.9356435643564357,
  'N Specificity': 0.9488888888888889,
  'P Sensitivity': 0.9488888888888889,
  'P Specificity': 0.9356435643564357},
 {'Accuracy': 0.9647239263803681,
  'Precision': 0.958074807480748,
  'F1 Score': 0.9587004238466771,
  'recall': 0.9593330465190677,
  'N Sensitivity': 0.945273631840796,
  'N Specificity': 0.9733924611973392,
  'P Sensitivity': 0.9733924611973392,
  'P Specificity': 0.945273631840796},
 {'Accuracy': 0.946236559139785,
  'Precision': 0.9448267542851891,
  'F1 Score': 0.9358452740618938,
  'recall': 0.9280762852404643,
  'N Sensitivity': 0.8805970149253731,
  'N Specificity': 0.9755555555555555,
  'P Sensitivity': 0.9755555555555555,
  'P Specificity': 0.8805970149253731},
 {'Accuracy': 0.9523809523809523,
  'Precision': 0.9523057304351823,
  'F1 Score': 0.9431772427405345,
  'recall': 0.93527363184

In [8]:
hist_df = pd.DataFrame(history)
hist_df

Unnamed: 0,Accuracy,Precision,F1 Score,recall,N Sensitivity,N Specificity,P Sensitivity,P Specificity
0,0.944785,0.930982,0.936297,0.942266,0.935644,0.948889,0.948889,0.935644
1,0.964724,0.958075,0.9587,0.959333,0.945274,0.973392,0.973392,0.945274
2,0.946237,0.944827,0.935845,0.928076,0.880597,0.975556,0.975556,0.880597
3,0.952381,0.952306,0.943177,0.935274,0.890547,0.98,0.98,0.890547
4,0.943164,0.939759,0.932378,0.925854,0.880597,0.971111,0.971111,0.880597


In [9]:
!zip PvN.zip ./*.pth

  adding: fold1_model.pth (deflated 7%)
  adding: fold2_model.pth (deflated 7%)
  adding: fold3_model.pth (deflated 7%)
  adding: fold4_model.pth (deflated 7%)
  adding: fold5_model.pth (deflated 7%)
