In [1]:
'''
MIXUP ON NEW VALIDATED DATASET

TRAINED ON 250 SAMPLES PER CLASS
VALIDATED ON 50 SAMPLES PER CLASS

              precision    recall  f1-score   support

         GNB       0.84      0.84      0.84        50
         GNC       0.89      0.80      0.84        50
         GPB       0.81      0.84      0.82        50
         GPC       0.79      0.84      0.82        50

    accuracy                           0.83       200
   macro avg       0.83      0.83      0.83       200
weighted avg       0.83      0.83      0.83       200

Class-wise Sensitivity (Recall): [0.84 0.8  0.84 0.84]
Class-wise Specificity: [0.94666667 0.96666667 0.93333333 0.92666667]



TRAINED ON ALL DATA
VALIDATED ON 50 SAMPLES PER CLASS

              precision    recall  f1-score   support

         GNB       0.85      0.94      0.90        50
         GNC       0.90      0.76      0.83        50
         GPB       0.91      0.82      0.86        50
         GPC       0.76      0.88      0.81        50

    accuracy                           0.85       200
   macro avg       0.86      0.85      0.85       200
weighted avg       0.86      0.85      0.85       200

Class-wise Sensitivity (Recall): [0.94 0.76 0.82 0.88]
Class-wise Specificity: [0.94666667 0.97333333 0.97333333 0.90666667]


'''
print()




In [1]:
import os
import random

from collections import defaultdict

import torch
from torch import nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import timm

from PIL import Image
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn.metrics import confusion_matrix, f1_score, accuracy_score, precision_score, classification_report, roc_auc_score, recall_score
from sklearn.model_selection import KFold, StratifiedKFold

SEED = 16

torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

In [3]:
# HELPER FUNCTIONS

def get_sample_count(folder_path):
    output = {}
    classes = os.listdir(folder_path)
    for i in classes:
        output[i] = len(os.listdir(f'{folder_path}/{i}'))
    return output


def mixup_images(input_folder_path, output_dir,image_rescale_size=(300, 300),  n=100, alpha=0.4):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    image_paths = [os.path.join(input_folder_path, f) for f in os.listdir(input_folder_path) if os.path.isfile(os.path.join(input_folder_path, f))]
    if len(image_paths) < 2:
        return "Error: There must be at least two images in the input folder."

    transform = transforms.Compose([
        transforms.Resize(image_rescale_size),  
        transforms.RandomHorizontalFlip(p=0.5), 
        transforms.RandomVerticalFlip(p=0.5),  
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1), 
        transforms.ToTensor()
    ])
    to_pil = transforms.ToPILImage()
    
    image_pairs = []
    for i in range(len(image_paths)):
        for j in range(i+1, len(image_paths)):
            image_pairs.append([image_paths[i], image_paths[j]])
   
    random.shuffle(image_pairs)

    for i in range(n):
        if len(image_pairs) == 0:
            break

        img_path1, img_path2 = image_pairs.pop()
        img1 = transform(Image.open(img_path1).convert("RGB"))
        img2 = transform(Image.open(img_path2).convert("RGB"))
        lam = np.random.beta(alpha, alpha)
        mixed_image = lam * img1 + (1 - lam) * img2
        mixed_image_pil = to_pil(mixed_image.clamp(0, 1)) 
        output_path = os.path.join(output_dir, f"mixed_image_{i + 1}.jpg")
        mixed_image_pil.save(output_path)


def balance_folder(path, sample_size):
    all_files = os.listdir(path)
    if sample_size > len(all_files):
        print("SAMPLE_SIZE TOO BIG")
        return

    to_delete = len(all_files) - sample_size
    print(f"THERE ARE {len(all_files)} FILES at {path}")
    print(f"{to_delete} FILES WILL BE REMOVED")

    for i in range(to_delete):
        index = random.randint(0, len(all_files) - 1)
        os.remove(path + '/' + all_files[index])
        all_files.remove(all_files[index])

    print('DONE')
    return

In [4]:
dataset_path = '/kaggle/input/dataset/Train'
validation_set_path = '/kaggle/input/dataset/Val'

print(get_sample_count(dataset_path))
print(get_sample_count(validation_set_path))

{'GNC': 200, 'GNB': 1457, 'GPC': 1505, 'GPB': 746}
{'GNC': 50, 'GNB': 50, 'GPC': 50, 'GPB': 50}


In [5]:
# Check for CUDA device
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(DEVICE)

# Parameters
model_name = 'tiny_vit_21m_512.dist_in22k_ft_in1k'
num_epochs = 20
num_classes = 4
batch_size = 14
k_folds = 5

cuda


In [6]:
m = timm.create_model(model_name, pretrained=True, num_classes=num_classes, drop_rate=0.3)
model_info = m.default_cfg
del m

model.safetensors:   0%|          | 0.00/85.2M [00:00<?, ?B/s]

In [7]:
input_shape = model_info['input_size'][1:]
transform_mean = model_info['mean']
transform_std = model_info['std']

print(f"USING MODEL ARCHITECTURE {model_info['architecture']} ")
print(f"INPUT SHAPE = {input_shape}")
print(f"       MEAN = {transform_mean}")
print(f"        STD = {transform_std}")

USING MODEL ARCHITECTURE tiny_vit_21m_512 
INPUT SHAPE = (512, 512)
       MEAN = (0.485, 0.456, 0.406)
        STD = (0.229, 0.224, 0.225)


In [8]:
transform = transforms.Compose([
        transforms.Resize(input_shape),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize(mean=transform_mean, std=transform_std)
    ])

dataset = datasets.ImageFolder(root=dataset_path, transform=transform)

print("CLASS MAPPING")
print(dataset.class_to_idx)

CLASS MAPPING
{'GNB': 0, 'GNC': 1, 'GPB': 2, 'GPC': 3}


In [9]:
class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        BCE_loss = nn.functional.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
        return F_loss.mean()

In [10]:
kfold = StratifiedKFold(n_splits=k_folds, shuffle=True, random_state=SEED)

history = []

# K-Fold Loop
for fold, (train_idx, val_idx) in enumerate(kfold.split(dataset, dataset.targets)):
    print('\n--------------------------------')
    print(f'FOLD {fold+1}/{k_folds}')
    print('--------------------------------')

    # Create data loaders for this fold
    train_subsampler = torch.utils.data.SubsetRandomSampler(train_idx)
    val_subsampler = torch.utils.data.SubsetRandomSampler(val_idx)

    train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_subsampler, num_workers=4)
    val_loader = DataLoader(dataset, batch_size=batch_size, sampler=val_subsampler, num_workers=4)

    # Initialize the model for this fold
    model = timm.create_model(model_name, pretrained=True, num_classes=num_classes, drop_rate=0.3)
    model = model.to(DEVICE)

    criterion = FocalLoss()
    # criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    # Variables to keep track of the best model for this fold
    best_val_acc = -999
    best_epoch = 0
    best_model_wts = None
    best_preds = {
        'labels':None,
        'preds':None
    }

    # Model training loop START
    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        model.train()
        running_loss = 0.0

        # Training
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # Validation
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0

        all_labels = []
        all_preds = []
        
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()

                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                
                all_labels.extend(labels.cpu().numpy())
                all_preds.extend(predicted.cpu().numpy())

        val_acc = correct / total

        print(f'Loss: {running_loss/len(train_loader)} , Val Loss: {val_loss/len(val_loader)}, Val Accuracy: {val_acc}')

        # Always keep track of the best model and best output
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_epoch = epoch + 1
            best_model_wts = model.state_dict()
            best_preds = {
                'labels': all_labels,
                'preds': all_preds
            }


    # TRAIN LOOP ENDS HERE
    all_labels = best_preds['labels']
    all_preds = best_preds['preds']
    
    # Calculate confusion matrix
    conf_matrix = confusion_matrix(all_labels, all_preds)
    
    # Sensitivity, Specificity calculation
    sensitivity = []
    specificity = []
    
    for i in range(num_classes):
        tp = conf_matrix[i, i]
        fn = sum(conf_matrix[i, :]) - tp
        fp = sum(conf_matrix[:, i]) - tp
        tn = conf_matrix.sum() - (tp + fp + fn)

        sensitivity.append(tp / (tp + fn) if (tp + fn) > 0 else 0)
        specificity.append(tn / (tn + fp) if (tn + fp) > 0 else 0)

    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='macro')
    f1 = f1_score(all_labels, all_preds, average='macro')
    recall = recall_score(all_labels, all_preds, average='macro')

    best_hist = { 'Accuracy': accuracy, 'Precision': precision, 'F1 Score': f1, 'recall': recall }

    for class_name in dataset.class_to_idx:
        best_hist[f'{class_name} Sensitivity'] = sensitivity[dataset.class_to_idx[class_name]]
        best_hist[f'{class_name} Specificity'] = specificity[dataset.class_to_idx[class_name]]

    print()
    for key, value in best_hist.items():
        print(f'{key}: {value}')

    print()
    # get classification report
    print(classification_report(all_labels, all_preds, target_names=[i for i in dataset.class_to_idx]))
    print()

    # Save model to disk
    print(f'Best Epoch for Fold {fold+1}: {best_epoch} with Validation Accuracy: {best_val_acc}')
    torch.save(best_model_wts, f'fold{fold+1}_model.pth')
    
    history.append(best_hist)


--------------------------------
FOLD 1/5
--------------------------------
Epoch 1/20
Loss: 0.11717063640077997 , Val Loss: 0.07638190350761372, Val Accuracy: 0.7749360613810742
Epoch 2/20
Loss: 0.0648724103895282 , Val Loss: 0.05630394573589521, Val Accuracy: 0.8248081841432225
Epoch 3/20
Loss: 0.04864463819415375 , Val Loss: 0.04166151732871575, Val Accuracy: 0.870843989769821
Epoch 4/20
Loss: 0.042635135417055735 , Val Loss: 0.04273778775573841, Val Accuracy: 0.860613810741688
Epoch 5/20
Loss: 0.03641462211505443 , Val Loss: 0.03734082887448104, Val Accuracy: 0.8772378516624041
Epoch 6/20
Loss: 0.030930963604727628 , Val Loss: 0.0323444791221326, Val Accuracy: 0.8797953964194374
Epoch 7/20
Loss: 0.030190895130674886 , Val Loss: 0.034476560598704964, Val Accuracy: 0.8938618925831202
Epoch 8/20
Loss: 0.026672805324778892 , Val Loss: 0.03437107445123339, Val Accuracy: 0.8849104859335039
Epoch 9/20
Loss: 0.02295732185302768 , Val Loss: 0.036068123806866685, Val Accuracy: 0.888746803069

In [11]:
history

[{'Accuracy': 0.9002557544757033,
  'Precision': 0.8717516402810689,
  'F1 Score': 0.8460631687258553,
  'recall': 0.8270411986706077,
  'GNB Sensitivity': 0.9623287671232876,
  'GNB Specificity': 0.936734693877551,
  'GNC Sensitivity': 0.6,
  'GNC Specificity': 0.9905660377358491,
  'GPB Sensitivity': 0.8322147651006712,
  'GPB Specificity': 0.9778830963665087,
  'GPC Sensitivity': 0.9136212624584718,
  'GPC Specificity': 0.9459459459459459},
 {'Accuracy': 0.90153452685422,
  'Precision': 0.8466063423543084,
  'F1 Score': 0.8264686446264207,
  'recall': 0.8126651279162926,
  'GNB Sensitivity': 0.9212328767123288,
  'GNB Specificity': 0.9673469387755103,
  'GNC Sensitivity': 0.5,
  'GNC Specificity': 0.9865229110512129,
  'GPB Sensitivity': 0.8859060402684564,
  'GPB Specificity': 0.9715639810426541,
  'GPC Sensitivity': 0.9435215946843853,
  'GPC Specificity': 0.9313929313929314},
 {'Accuracy': 0.9207161125319693,
  'Precision': 0.8840511943621795,
  'F1 Score': 0.8697587393039703,
  

In [12]:
hist_df = pd.DataFrame(history)

# save hist_df to disk
hist_df.to_csv('KFOLD_CV.csv', index=False)

hist_df

Unnamed: 0,Accuracy,Precision,F1 Score,recall,GNB Sensitivity,GNB Specificity,GNC Sensitivity,GNC Specificity,GPB Sensitivity,GPB Specificity,GPC Sensitivity,GPC Specificity
0,0.900256,0.871752,0.846063,0.827041,0.962329,0.936735,0.6,0.990566,0.832215,0.977883,0.913621,0.945946
1,0.901535,0.846606,0.826469,0.812665,0.921233,0.967347,0.5,0.986523,0.885906,0.971564,0.943522,0.931393
2,0.920716,0.884051,0.869759,0.858941,0.962199,0.977597,0.65,0.990566,0.893333,0.96519,0.930233,0.954262
3,0.910371,0.873151,0.868554,0.86626,0.95189,0.965306,0.7,0.989204,0.919463,0.957278,0.893688,0.9625
4,0.90525,0.890532,0.869855,0.853078,0.948454,0.957143,0.725,0.993252,0.798658,0.974684,0.940199,0.933333


In [13]:
# Test Models on validation data
model_dir = './'

val_transform = transforms.Compose([
        transforms.Resize(input_shape),
        transforms.ToTensor(),
        transforms.Normalize(mean=transform_mean, std=transform_std)
    ])

val_dataset = datasets.ImageFolder(root=validation_set_path, transform=val_transform)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

print("CLASS MAPPING")
print(val_dataset.class_to_idx)


majority_vote_preds = defaultdict(list)

# Iterate over each saved model for validation
for fold in range(1, k_folds + 1):
    print(f"Generating Predictions on model {fold}")
    # Load the saved model weights
    model = timm.create_model(model_name, pretrained=False, num_classes=num_classes)
    model.load_state_dict(torch.load(f'{model_dir}/fold{fold}_model.pth', map_location=torch.device(DEVICE)))
    model = model.to(DEVICE)
    model.eval()

    # Variables to keep track of performance metrics
    correct = 0
    total = 0
    all_labels = []
    all_preds = []
    
    with torch.no_grad():
        for i, (inputs, labels) in enumerate(val_loader):
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            outputs = model(inputs)
            
            _, predicted = torch.max(outputs, 1)

            # Store predictions for each batch item
            for idx, prediction in enumerate(predicted.cpu().numpy()):
                sample_index = i * batch_size + idx
                majority_vote_preds[sample_index].append(prediction)
            
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(predicted.cpu().numpy())

    # Calculate accuracy
    accuracy = correct / total
    print(f'Validation Accuracy for Fold {fold}: {accuracy}')

    # Calculate and display other metrics
    conf_matrix = confusion_matrix(all_labels, all_preds)

    print(f'Fold {fold} Validation Metrics:')

    print(classification_report(all_labels, all_preds, target_names=[i for i in val_dataset.class_to_idx]))
    print()

CLASS MAPPING
{'GNB': 0, 'GNC': 1, 'GPB': 2, 'GPC': 3}
Generating Predictions on model 1


  model.load_state_dict(torch.load(f'{model_dir}/fold{fold}_model.pth', map_location=torch.device(DEVICE)))


Validation Accuracy for Fold 1: 0.87
Fold 1 Validation Metrics:
              precision    recall  f1-score   support

         GNB       0.89      0.94      0.91        50
         GNC       0.89      0.82      0.85        50
         GPB       0.95      0.84      0.89        50
         GPC       0.77      0.88      0.82        50

    accuracy                           0.87       200
   macro avg       0.88      0.87      0.87       200
weighted avg       0.88      0.87      0.87       200


Generating Predictions on model 2


  model.load_state_dict(torch.load(f'{model_dir}/fold{fold}_model.pth', map_location=torch.device(DEVICE)))


Validation Accuracy for Fold 2: 0.85
Fold 2 Validation Metrics:
              precision    recall  f1-score   support

         GNB       0.87      0.96      0.91        50
         GNC       0.92      0.66      0.77        50
         GPB       0.93      0.86      0.90        50
         GPC       0.73      0.92      0.81        50

    accuracy                           0.85       200
   macro avg       0.86      0.85      0.85       200
weighted avg       0.86      0.85      0.85       200


Generating Predictions on model 3


  model.load_state_dict(torch.load(f'{model_dir}/fold{fold}_model.pth', map_location=torch.device(DEVICE)))


Validation Accuracy for Fold 3: 0.84
Fold 3 Validation Metrics:
              precision    recall  f1-score   support

         GNB       0.83      0.96      0.89        50
         GNC       0.86      0.74      0.80        50
         GPB       0.95      0.80      0.87        50
         GPC       0.75      0.86      0.80        50

    accuracy                           0.84       200
   macro avg       0.85      0.84      0.84       200
weighted avg       0.85      0.84      0.84       200


Generating Predictions on model 4


  model.load_state_dict(torch.load(f'{model_dir}/fold{fold}_model.pth', map_location=torch.device(DEVICE)))


Validation Accuracy for Fold 4: 0.84
Fold 4 Validation Metrics:
              precision    recall  f1-score   support

         GNB       0.84      0.94      0.89        50
         GNC       0.92      0.66      0.77        50
         GPB       0.91      0.84      0.87        50
         GPC       0.74      0.92      0.82        50

    accuracy                           0.84       200
   macro avg       0.85      0.84      0.84       200
weighted avg       0.85      0.84      0.84       200


Generating Predictions on model 5


  model.load_state_dict(torch.load(f'{model_dir}/fold{fold}_model.pth', map_location=torch.device(DEVICE)))


Validation Accuracy for Fold 5: 0.87
Fold 5 Validation Metrics:
              precision    recall  f1-score   support

         GNB       0.91      0.96      0.93        50
         GNC       0.93      0.80      0.86        50
         GPB       0.91      0.84      0.87        50
         GPC       0.76      0.88      0.81        50

    accuracy                           0.87       200
   macro avg       0.88      0.87      0.87       200
weighted avg       0.88      0.87      0.87       200




In [14]:
# Perform majority voting
final_preds = []
for sample_index in sorted(majority_vote_preds.keys()):
    # Majority voting across folds
    votes = majority_vote_preds[sample_index]
    final_pred = np.bincount(votes).argmax()  # Get the class with maximum votes
    final_preds.append(final_pred)

# Collect true labels for evaluation
true_labels = []
for inputs, labels in val_loader:
    true_labels.extend(labels.cpu().numpy())

# Calculate metrics based on final predictions
print('Final Model Validation Metrics:')
conf_matrix = confusion_matrix(true_labels, final_preds)
print(conf_matrix)
print(classification_report(true_labels, final_preds, target_names=[i for i in val_dataset.class_to_idx]))

# Initialize arrays to store sensitivity and specificity for each class
num_classes = len(conf_matrix)
sensitivity = np.zeros(num_classes)
specificity = np.zeros(num_classes)

# Calculate sensitivity and specificity for each class
for i in range(num_classes):
    tp = conf_matrix[i, i]  # True Positives
    fn = np.sum(conf_matrix[i, :]) - tp  # False Negatives
    fp = np.sum(conf_matrix[:, i]) - tp  # False Positives
    tn = np.sum(conf_matrix) - (tp + fn + fp)  # True Negatives
    
    # Calculate sensitivity and specificity
    sensitivity[i] = tp / (tp + fn) if (tp + fn) > 0 else 0
    specificity[i] = tn / (tn + fp) if (tn + fp) > 0 else 0

# Display the results
print("Class-wise Sensitivity (Recall):", sensitivity)
print("Class-wise Specificity:", specificity)

Final Model Validation Metrics:
[[47  2  1  0]
 [ 6 38  0  6]
 [ 1  0 41  8]
 [ 1  2  3 44]]
              precision    recall  f1-score   support

         GNB       0.85      0.94      0.90        50
         GNC       0.90      0.76      0.83        50
         GPB       0.91      0.82      0.86        50
         GPC       0.76      0.88      0.81        50

    accuracy                           0.85       200
   macro avg       0.86      0.85      0.85       200
weighted avg       0.86      0.85      0.85       200

Class-wise Sensitivity (Recall): [0.94 0.76 0.82 0.88]
Class-wise Specificity: [0.94666667 0.97333333 0.97333333 0.90666667]
