In [None]:
from google.colab import drive

# Mount Google Drive so the dataset folders referenced later under
# /content/drive/MyDrive are visible to this runtime.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [None]:
# Install necessary libraries (if not already installed)
# !pip install --upgrade albumentations timm

import os
import torch
import torch.nn as nn
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import timm
from sklearn.metrics import confusion_matrix
import numpy as np
from tqdm import tqdm
import multiprocessing

import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image

# Dataset locations on the mounted Google Drive
train_dir = '/content/drive/MyDrive/train/'
val_dir = '/content/drive/MyDrive/val/val/'

# Preprocessing constants shared by the training and validation pipelines
IMAGE_SIZE = 224
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def _model_input_steps():
    """Return fresh normalize + to-tensor steps that end every pipeline."""
    return [
        A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ToTensorV2(),
    ]


# Photometric augmentations: with probability 0.8 exactly one of these is applied
_color_augmentations = A.OneOf([
    A.RandomBrightnessContrast(),
    A.ColorJitter(),
    A.InvertImg(),
], p=0.8)

# Geometric / occlusion augmentations: with probability 0.8 exactly one is applied
_geometry_augmentations = A.OneOf([
    A.HorizontalFlip(),
    A.VerticalFlip(),
    A.RandomRotate90(),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=45),
    A.Affine(shear=(-30, 30)),
    A.Rotate(limit=45),
    A.CoarseDropout(max_holes=8, max_height=8, max_width=8, fill_value=0, p=0.5),
    A.Transpose(),
], p=0.8)

# Training pipeline: resize -> colour aug -> geometry aug -> normalize -> tensor
train_transforms = A.Compose([
    A.Resize(IMAGE_SIZE, IMAGE_SIZE),
    _color_augmentations,
    _geometry_augmentations,
    *_model_input_steps(),
])

# Validation pipeline: deterministic resize -> normalize -> tensor
val_transforms = A.Compose([
    A.Resize(IMAGE_SIZE, IMAGE_SIZE),
    *_model_input_steps(),
])

# Custom dataset to use Albumentations
class CustomImageDataset(torch.utils.data.Dataset):
    """Dataset that indexes images with ``ImageFolder`` and applies an
    Albumentations-style transform to each image.

    ``ImageFolder`` is used only to discover ``(path, label)`` pairs and the
    class mapping; images are decoded here so they can be handed to the
    transform as NumPy arrays via the ``image=`` keyword.
    """

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir: Directory passed to ``ImageFolder`` as ``root``.
            transform: Optional callable invoked as ``transform(image=array)``
                that returns a mapping with an ``'image'`` entry.
        """
        self.dataset = ImageFolder(root=root_dir)
        self.transform = transform

    def __getitem__(self, idx: int):
        """Return ``(image, label)`` for sample ``idx``.

        The image is an RGB NumPy array when no transform is set, otherwise
        whatever the transform stores under ``'image'``.
        """
        img_path, label = self.dataset.samples[idx]
        # Open inside a context manager so the underlying file handle is
        # released immediately rather than whenever the object is collected.
        with Image.open(img_path) as img:
            image = np.array(img.convert('RGB'))
        if self.transform:
            transformed = self.transform(image=image)
            image = transformed['image']
        return image, label

    def __len__(self) -> int:
        """Number of samples discovered by the wrapped ``ImageFolder``."""
        return len(self.dataset)

# Load datasets
train_dataset = CustomImageDataset(root_dir=train_dir, transform=train_transforms)
val_dataset = CustomImageDataset(root_dir=val_dir, transform=val_transforms)

# Verify class-to-label mapping
print("Class-to-Label Mapping for Training Dataset:")
print(train_dataset.dataset.class_to_idx)

# The label indices are derived independently from each folder tree; if the
# two mappings differ, validation labels would silently mean something else.
if val_dataset.dataset.class_to_idx != train_dataset.dataset.class_to_idx:
    raise ValueError(
        'Training and validation class-to-label mappings differ: '
        f'{train_dataset.dataset.class_to_idx} vs {val_dataset.dataset.class_to_idx}'
    )

# DataLoaders (leave one CPU core free for the main process)
num_workers = multiprocessing.cpu_count() - 1
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=num_workers, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=num_workers, pin_memory=True)

# Define models with dropout
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Swin Transformer model (pretrained backbone, 2-class head, dropout 0.2)
model_swin = timm.create_model('swin_base_patch4_window7_224', pretrained=True, num_classes=2, drop_rate=0.2)
model_swin = model_swin.to(device)

# RegNetY model (pretrained backbone, 2-class head, dropout 0.2)
model_regnet = timm.create_model('regnety_032', pretrained=True, num_classes=2, drop_rate=0.2)
model_regnet = model_regnet.to(device)

# Define loss and optimizer with weight decay
criterion = nn.CrossEntropyLoss()
optimizer_swin = torch.optim.AdamW(model_swin.parameters(), lr=1e-4, weight_decay=1e-5)
optimizer_regnet = torch.optim.AdamW(model_regnet.parameters(), lr=1e-4, weight_decay=1e-5)

# Training parameters
num_epochs = 50

# Learning rate scheduler - Cosine Annealing.
# T_max is tied to num_epochs so the cosine schedule always spans exactly the
# training run, even if the epoch count is changed.
scheduler_swin = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_swin, T_max=num_epochs)
scheduler_regnet = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_regnet, T_max=num_epochs)

# For weighted ensemble: accuracies recorded at the best-ensemble epoch
best_val_accuracy_swin = 0
best_val_accuracy_regnet = 0
best_val_accuracy_ensemble = 0

for epoch in range(num_epochs):
    # ---- Training phase: both models are trained on the same batches ----
    model_swin.train()
    model_regnet.train()
    running_loss_swin = 0.0
    running_loss_regnet = 0.0

    for images, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        # Zero the parameter gradients
        optimizer_swin.zero_grad()
        optimizer_regnet.zero_grad()

        # Forward pass
        outputs_swin = model_swin(images)
        outputs_regnet = model_regnet(images)

        # Compute loss
        loss_swin = criterion(outputs_swin, labels)
        loss_regnet = criterion(outputs_regnet, labels)
        # The two models are separate modules, so a single backward pass over
        # the summed loss gives each model the gradient of its own loss only.
        total_loss = loss_swin + loss_regnet

        # Backward pass and optimization
        total_loss.backward()
        optimizer_swin.step()
        optimizer_regnet.step()

        running_loss_swin += loss_swin.item()
        running_loss_regnet += loss_regnet.item()

    # Average training loss per batch
    avg_loss_swin = running_loss_swin / len(train_loader)
    avg_loss_regnet = running_loss_regnet / len(train_loader)
    print(f'Epoch [{epoch+1}/{num_epochs}] Average Training Loss - Swin: {avg_loss_swin:.4f}, RegNetY: {avg_loss_regnet:.4f}')

    # ---- Validation phase ----
    model_swin.eval()
    model_regnet.eval()
    val_loss_swin = 0.0
    val_loss_regnet = 0.0
    correct_swin = 0
    correct_regnet = 0
    correct_ensemble = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            outputs_swin = model_swin(images)
            outputs_regnet = model_regnet(images)

            # Compute validation loss
            val_loss_swin += criterion(outputs_swin, labels).item()
            val_loss_regnet += criterion(outputs_regnet, labels).item()

            # Individual model predictions
            predicted_swin = outputs_swin.argmax(dim=1)
            predicted_regnet = outputs_regnet.argmax(dim=1)

            # Ensemble by averaging the raw outputs (logits) of both models
            outputs = (outputs_swin + outputs_regnet) / 2
            predicted = outputs.argmax(dim=1)

            total += labels.size(0)
            correct_swin += (predicted_swin == labels).sum().item()
            correct_regnet += (predicted_regnet == labels).sum().item()
            correct_ensemble += (predicted == labels).sum().item()

    # Average validation loss per batch, and accuracy in percent
    avg_val_loss_swin = val_loss_swin / len(val_loader)
    avg_val_loss_regnet = val_loss_regnet / len(val_loader)
    avg_val_loss = avg_val_loss_swin + avg_val_loss_regnet
    val_accuracy_swin = 100 * correct_swin / total
    val_accuracy_regnet = 100 * correct_regnet / total
    val_accuracy_ensemble = 100 * correct_ensemble / total
    print(f'Validation Loss - Swin: {avg_val_loss_swin:.4f}, RegNetY: {avg_val_loss_regnet:.4f}, Total: {avg_val_loss:.4f}')
    print(f'Validation Accuracy - Swin: {val_accuracy_swin:.2f}%, RegNetY: {val_accuracy_regnet:.2f}%, Ensemble: {val_accuracy_ensemble:.2f}%')

    # Update learning rate schedulers (once per epoch)
    scheduler_swin.step()
    scheduler_regnet.step()

    # Checkpoint only on a strict improvement of the ensemble accuracy
    if val_accuracy_ensemble > best_val_accuracy_ensemble:
        best_val_accuracy_ensemble = val_accuracy_ensemble
        # The per-model accuracies of this epoch are stored alongside the
        # weights; they are later used as the ensemble weighting factors.
        torch.save({
            'epoch': epoch + 1,
            'model_swin_state_dict': model_swin.state_dict(),
            'model_regnet_state_dict': model_regnet.state_dict(),
            'optimizer_swin_state_dict': optimizer_swin.state_dict(),
            'optimizer_regnet_state_dict': optimizer_regnet.state_dict(),
            'val_accuracy_ensemble': best_val_accuracy_ensemble,
            'best_val_accuracy_swin': val_accuracy_swin,
            'best_val_accuracy_regnet': val_accuracy_regnet,
        }, 'best_model.pth')
        print(f'Best model saved at epoch {epoch+1} with ensemble validation accuracy {best_val_accuracy_ensemble:.2f}%')

        # Update best validation accuracies
        best_val_accuracy_swin = val_accuracy_swin
        best_val_accuracy_regnet = val_accuracy_regnet
    else:
        print('No improvement in ensemble validation accuracy.')

print('Training completed.')

# Load the best model.
# map_location keeps the load working if the runtime's device differs from the
# one that saved the file; weights_only=True restricts unpickling to tensors
# and plain containers, which is all this checkpoint holds.
checkpoint = torch.load('best_model.pth', map_location=device, weights_only=True)
model_swin.load_state_dict(checkpoint['model_swin_state_dict'])
model_regnet.load_state_dict(checkpoint['model_regnet_state_dict'])
best_val_accuracy_swin = checkpoint['best_val_accuracy_swin']
best_val_accuracy_regnet = checkpoint['best_val_accuracy_regnet']

# Ensemble weights proportional to each model's validation accuracy at the
# checkpointed epoch. Computed once (they do not change per batch); fall back
# to an equal split if both accuracies are zero to avoid dividing by zero.
weight_sum = best_val_accuracy_swin + best_val_accuracy_regnet
if weight_sum == 0:
    weight_swin = weight_regnet = 0.5
else:
    weight_swin = best_val_accuracy_swin / weight_sum
    weight_regnet = best_val_accuracy_regnet / weight_sum

# Evaluate on the validation set
model_swin.eval()
model_regnet.eval()
all_labels = []
all_preds = []
correct_swin = 0
correct_regnet = 0
correct_ensemble = 0
total = 0
with torch.no_grad():
    for images, labels in val_loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        outputs_swin = model_swin(images)
        outputs_regnet = model_regnet(images)

        # Individual model predictions
        predicted_swin = outputs_swin.argmax(dim=1)
        predicted_regnet = outputs_regnet.argmax(dim=1)

        # Ensemble by weighted averaging based on best validation accuracies
        outputs = (weight_swin * outputs_swin + weight_regnet * outputs_regnet)
        predicted = outputs.argmax(dim=1)

        total += labels.size(0)
        correct_swin += (predicted_swin == labels).sum().item()
        correct_regnet += (predicted_regnet == labels).sum().item()
        correct_ensemble += (predicted == labels).sum().item()

        # Keep ensemble predictions for the confusion matrix below
        all_labels.extend(labels.cpu().numpy())
        all_preds.extend(predicted.cpu().numpy())

# Compute final accuracies
final_accuracy_swin = 100 * correct_swin / total
final_accuracy_regnet = 100 * correct_regnet / total
final_accuracy_ensemble = 100 * correct_ensemble / total
print(f'Final Validation Accuracy - Swin: {final_accuracy_swin:.2f}%, RegNetY: {final_accuracy_regnet:.2f}%, Ensemble: {final_accuracy_ensemble:.2f}%')

# Compute confusion matrix (rows: true class, columns: ensemble prediction)
cm = confusion_matrix(all_labels, all_preds)
print('Confusion Matrix:')
print(cm)


Class-to-Label Mapping for Training Dataset:
{'nevus': 0, 'others': 1}


Epoch 1/50: 100%|██████████| 475/475 [03:51<00:00,  2.06it/s]

Epoch [1/50] Average Training Loss - Swin: 0.4209, RegNetY: 0.4320





Validation Loss - Swin: 0.3404, RegNetY: 0.3754, Total: 0.7158
Validation Accuracy - Swin: 85.01%, RegNetY: 83.59%, Ensemble: 84.69%
Best model saved at epoch 1 with ensemble validation accuracy 84.69%


Epoch 2/50: 100%|██████████| 475/475 [02:38<00:00,  3.00it/s]

Epoch [2/50] Average Training Loss - Swin: 0.3451, RegNetY: 0.3511





Validation Loss - Swin: 0.3211, RegNetY: 0.3104, Total: 0.6315
Validation Accuracy - Swin: 86.56%, RegNetY: 86.85%, Ensemble: 87.43%
Best model saved at epoch 2 with ensemble validation accuracy 87.43%


Epoch 3/50: 100%|██████████| 475/475 [02:37<00:00,  3.01it/s]

Epoch [3/50] Average Training Loss - Swin: 0.3179, RegNetY: 0.3221





Validation Loss - Swin: 0.2885, RegNetY: 0.3179, Total: 0.6064
Validation Accuracy - Swin: 88.22%, RegNetY: 86.93%, Ensemble: 88.36%
Best model saved at epoch 3 with ensemble validation accuracy 88.36%


Epoch 4/50: 100%|██████████| 475/475 [02:37<00:00,  3.02it/s]

Epoch [4/50] Average Training Loss - Swin: 0.2927, RegNetY: 0.2921





Validation Loss - Swin: 0.2821, RegNetY: 0.2976, Total: 0.5797
Validation Accuracy - Swin: 88.86%, RegNetY: 88.12%, Ensemble: 89.07%
Best model saved at epoch 4 with ensemble validation accuracy 89.07%


Epoch 5/50: 100%|██████████| 475/475 [02:38<00:00,  3.01it/s]

Epoch [5/50] Average Training Loss - Swin: 0.2651, RegNetY: 0.2637





Validation Loss - Swin: 0.3043, RegNetY: 0.2804, Total: 0.5847
Validation Accuracy - Swin: 86.64%, RegNetY: 88.67%, Ensemble: 89.09%
Best model saved at epoch 5 with ensemble validation accuracy 89.09%


Epoch 6/50: 100%|██████████| 475/475 [02:38<00:00,  3.00it/s]

Epoch [6/50] Average Training Loss - Swin: 0.2399, RegNetY: 0.2453





Validation Loss - Swin: 0.2667, RegNetY: 0.2863, Total: 0.5530
Validation Accuracy - Swin: 89.09%, RegNetY: 88.17%, Ensemble: 89.52%
Best model saved at epoch 6 with ensemble validation accuracy 89.52%


Epoch 7/50: 100%|██████████| 475/475 [02:37<00:00,  3.01it/s]

Epoch [7/50] Average Training Loss - Swin: 0.2293, RegNetY: 0.2212





Validation Loss - Swin: 0.2922, RegNetY: 0.3019, Total: 0.5942
Validation Accuracy - Swin: 88.49%, RegNetY: 88.04%, Ensemble: 89.57%
Best model saved at epoch 7 with ensemble validation accuracy 89.57%


Epoch 8/50: 100%|██████████| 475/475 [02:37<00:00,  3.02it/s]

Epoch [8/50] Average Training Loss - Swin: 0.2088, RegNetY: 0.2112





Validation Loss - Swin: 0.2723, RegNetY: 0.2804, Total: 0.5527
Validation Accuracy - Swin: 89.01%, RegNetY: 89.33%, Ensemble: 90.46%
Best model saved at epoch 8 with ensemble validation accuracy 90.46%


Epoch 9/50: 100%|██████████| 475/475 [02:37<00:00,  3.01it/s]

Epoch [9/50] Average Training Loss - Swin: 0.1956, RegNetY: 0.1993





Validation Loss - Swin: 0.2926, RegNetY: 0.3021, Total: 0.5947
Validation Accuracy - Swin: 89.30%, RegNetY: 88.04%, Ensemble: 90.09%
No improvement in ensemble validation accuracy.


Epoch 10/50: 100%|██████████| 475/475 [02:35<00:00,  3.06it/s]

Epoch [10/50] Average Training Loss - Swin: 0.1808, RegNetY: 0.1752





Validation Loss - Swin: 0.2568, RegNetY: 0.2857, Total: 0.5425
Validation Accuracy - Swin: 89.83%, RegNetY: 89.67%, Ensemble: 91.39%
Best model saved at epoch 10 with ensemble validation accuracy 91.39%


Epoch 11/50: 100%|██████████| 475/475 [02:38<00:00,  3.00it/s]

Epoch [11/50] Average Training Loss - Swin: 0.1640, RegNetY: 0.1591





Validation Loss - Swin: 0.3130, RegNetY: 0.2965, Total: 0.6095
Validation Accuracy - Swin: 89.12%, RegNetY: 89.09%, Ensemble: 90.25%
No improvement in ensemble validation accuracy.


Epoch 12/50: 100%|██████████| 475/475 [02:35<00:00,  3.05it/s]

Epoch [12/50] Average Training Loss - Swin: 0.1465, RegNetY: 0.1490





Validation Loss - Swin: 0.2957, RegNetY: 0.2821, Total: 0.5778
Validation Accuracy - Swin: 89.46%, RegNetY: 89.67%, Ensemble: 90.91%
No improvement in ensemble validation accuracy.


Epoch 13/50: 100%|██████████| 475/475 [02:35<00:00,  3.06it/s]

Epoch [13/50] Average Training Loss - Swin: 0.1383, RegNetY: 0.1374





Validation Loss - Swin: 0.2984, RegNetY: 0.3143, Total: 0.6126
Validation Accuracy - Swin: 89.81%, RegNetY: 89.12%, Ensemble: 91.23%
No improvement in ensemble validation accuracy.


Epoch 14/50: 100%|██████████| 475/475 [02:35<00:00,  3.06it/s]

Epoch [14/50] Average Training Loss - Swin: 0.1298, RegNetY: 0.1239





Validation Loss - Swin: 0.2758, RegNetY: 0.3116, Total: 0.5874
Validation Accuracy - Swin: 89.94%, RegNetY: 89.33%, Ensemble: 91.28%
No improvement in ensemble validation accuracy.


Epoch 15/50: 100%|██████████| 475/475 [02:35<00:00,  3.06it/s]

Epoch [15/50] Average Training Loss - Swin: 0.1202, RegNetY: 0.1243





Validation Loss - Swin: 0.2837, RegNetY: 0.3061, Total: 0.5898
Validation Accuracy - Swin: 89.83%, RegNetY: 89.62%, Ensemble: 91.10%
No improvement in ensemble validation accuracy.


Epoch 16/50: 100%|██████████| 475/475 [02:35<00:00,  3.05it/s]

Epoch [16/50] Average Training Loss - Swin: 0.1031, RegNetY: 0.1129





Validation Loss - Swin: 0.3372, RegNetY: 0.3138, Total: 0.6511
Validation Accuracy - Swin: 89.91%, RegNetY: 90.20%, Ensemble: 91.54%
Best model saved at epoch 16 with ensemble validation accuracy 91.54%


Epoch 17/50: 100%|██████████| 475/475 [02:37<00:00,  3.01it/s]

Epoch [17/50] Average Training Loss - Swin: 0.1004, RegNetY: 0.1046





Validation Loss - Swin: 0.2817, RegNetY: 0.3061, Total: 0.5878
Validation Accuracy - Swin: 91.33%, RegNetY: 90.41%, Ensemble: 92.10%
Best model saved at epoch 17 with ensemble validation accuracy 92.10%


Epoch 18/50: 100%|██████████| 475/475 [02:37<00:00,  3.02it/s]

Epoch [18/50] Average Training Loss - Swin: 0.0803, RegNetY: 0.0946





Validation Loss - Swin: 0.3218, RegNetY: 0.3118, Total: 0.6336
Validation Accuracy - Swin: 91.02%, RegNetY: 90.67%, Ensemble: 92.44%
Best model saved at epoch 18 with ensemble validation accuracy 92.44%


Epoch 19/50: 100%|██████████| 475/475 [02:38<00:00,  3.00it/s]

Epoch [19/50] Average Training Loss - Swin: 0.0811, RegNetY: 0.0870





Validation Loss - Swin: 0.2783, RegNetY: 0.3430, Total: 0.6213
Validation Accuracy - Swin: 91.62%, RegNetY: 90.38%, Ensemble: 91.97%
No improvement in ensemble validation accuracy.


Epoch 20/50: 100%|██████████| 475/475 [02:35<00:00,  3.06it/s]

Epoch [20/50] Average Training Loss - Swin: 0.0750, RegNetY: 0.0819





Validation Loss - Swin: 0.2881, RegNetY: 0.3374, Total: 0.6255
Validation Accuracy - Swin: 91.23%, RegNetY: 90.62%, Ensemble: 92.02%
No improvement in ensemble validation accuracy.


Epoch 21/50: 100%|██████████| 475/475 [02:35<00:00,  3.05it/s]

Epoch [21/50] Average Training Loss - Swin: 0.0688, RegNetY: 0.0747





Validation Loss - Swin: 0.3087, RegNetY: 0.3343, Total: 0.6430
Validation Accuracy - Swin: 91.07%, RegNetY: 90.49%, Ensemble: 91.83%
No improvement in ensemble validation accuracy.


Epoch 22/50: 100%|██████████| 475/475 [02:35<00:00,  3.06it/s]

Epoch [22/50] Average Training Loss - Swin: 0.0616, RegNetY: 0.0718





Validation Loss - Swin: 0.3663, RegNetY: 0.3236, Total: 0.6899
Validation Accuracy - Swin: 91.02%, RegNetY: 90.78%, Ensemble: 92.18%
No improvement in ensemble validation accuracy.


Epoch 23/50: 100%|██████████| 475/475 [02:35<00:00,  3.05it/s]

Epoch [23/50] Average Training Loss - Swin: 0.0590, RegNetY: 0.0653





Validation Loss - Swin: 0.3787, RegNetY: 0.3440, Total: 0.7227
Validation Accuracy - Swin: 90.73%, RegNetY: 90.83%, Ensemble: 92.02%
No improvement in ensemble validation accuracy.


Epoch 24/50: 100%|██████████| 475/475 [02:35<00:00,  3.06it/s]

Epoch [24/50] Average Training Loss - Swin: 0.0552, RegNetY: 0.0643





Validation Loss - Swin: 0.4103, RegNetY: 0.3519, Total: 0.7621
Validation Accuracy - Swin: 90.57%, RegNetY: 91.04%, Ensemble: 92.02%
No improvement in ensemble validation accuracy.


Epoch 25/50: 100%|██████████| 475/475 [02:35<00:00,  3.06it/s]

Epoch [25/50] Average Training Loss - Swin: 0.0508, RegNetY: 0.0578





Validation Loss - Swin: 0.3643, RegNetY: 0.3544, Total: 0.7187
Validation Accuracy - Swin: 91.41%, RegNetY: 90.44%, Ensemble: 92.54%
Best model saved at epoch 25 with ensemble validation accuracy 92.54%


Epoch 26/50: 100%|██████████| 475/475 [02:38<00:00,  3.00it/s]

Epoch [26/50] Average Training Loss - Swin: 0.0409, RegNetY: 0.0521





Validation Loss - Swin: 0.4156, RegNetY: 0.3551, Total: 0.7707
Validation Accuracy - Swin: 91.07%, RegNetY: 90.99%, Ensemble: 92.44%
No improvement in ensemble validation accuracy.


Epoch 27/50: 100%|██████████| 475/475 [02:35<00:00,  3.06it/s]

Epoch [27/50] Average Training Loss - Swin: 0.0433, RegNetY: 0.0498





Validation Loss - Swin: 0.3533, RegNetY: 0.3651, Total: 0.7184
Validation Accuracy - Swin: 91.44%, RegNetY: 90.94%, Ensemble: 92.57%
Best model saved at epoch 27 with ensemble validation accuracy 92.57%


Epoch 28/50: 100%|██████████| 475/475 [02:37<00:00,  3.01it/s]

Epoch [28/50] Average Training Loss - Swin: 0.0380, RegNetY: 0.0482





Validation Loss - Swin: 0.3509, RegNetY: 0.3600, Total: 0.7109
Validation Accuracy - Swin: 92.04%, RegNetY: 91.02%, Ensemble: 92.65%
Best model saved at epoch 28 with ensemble validation accuracy 92.65%


Epoch 29/50: 100%|██████████| 475/475 [02:37<00:00,  3.02it/s]

Epoch [29/50] Average Training Loss - Swin: 0.0321, RegNetY: 0.0444





Validation Loss - Swin: 0.3689, RegNetY: 0.3677, Total: 0.7366
Validation Accuracy - Swin: 91.44%, RegNetY: 91.15%, Ensemble: 92.52%
No improvement in ensemble validation accuracy.


Epoch 30/50: 100%|██████████| 475/475 [02:35<00:00,  3.05it/s]

Epoch [30/50] Average Training Loss - Swin: 0.0331, RegNetY: 0.0439





Validation Loss - Swin: 0.3505, RegNetY: 0.3653, Total: 0.7158
Validation Accuracy - Swin: 91.65%, RegNetY: 91.28%, Ensemble: 92.86%
Best model saved at epoch 30 with ensemble validation accuracy 92.86%


Epoch 31/50: 100%|██████████| 475/475 [02:38<00:00,  3.01it/s]

Epoch [31/50] Average Training Loss - Swin: 0.0266, RegNetY: 0.0353





Validation Loss - Swin: 0.3674, RegNetY: 0.3749, Total: 0.7423
Validation Accuracy - Swin: 92.10%, RegNetY: 91.33%, Ensemble: 93.02%
Best model saved at epoch 31 with ensemble validation accuracy 93.02%


Epoch 32/50: 100%|██████████| 475/475 [02:38<00:00,  3.00it/s]

Epoch [32/50] Average Training Loss - Swin: 0.0276, RegNetY: 0.0348





Validation Loss - Swin: 0.4472, RegNetY: 0.4077, Total: 0.8549
Validation Accuracy - Swin: 91.20%, RegNetY: 91.39%, Ensemble: 92.60%
No improvement in ensemble validation accuracy.


Epoch 33/50: 100%|██████████| 475/475 [02:35<00:00,  3.05it/s]

Epoch [33/50] Average Training Loss - Swin: 0.0239, RegNetY: 0.0358





Validation Loss - Swin: 0.4070, RegNetY: 0.3852, Total: 0.7923
Validation Accuracy - Swin: 91.89%, RegNetY: 91.02%, Ensemble: 92.57%
No improvement in ensemble validation accuracy.


Epoch 34/50: 100%|██████████| 475/475 [02:35<00:00,  3.06it/s]

Epoch [34/50] Average Training Loss - Swin: 0.0186, RegNetY: 0.0317





Validation Loss - Swin: 0.4066, RegNetY: 0.3762, Total: 0.7827
Validation Accuracy - Swin: 92.20%, RegNetY: 91.39%, Ensemble: 92.73%
No improvement in ensemble validation accuracy.


Epoch 35/50: 100%|██████████| 475/475 [02:35<00:00,  3.06it/s]

Epoch [35/50] Average Training Loss - Swin: 0.0172, RegNetY: 0.0263





Validation Loss - Swin: 0.4230, RegNetY: 0.3932, Total: 0.8162
Validation Accuracy - Swin: 91.91%, RegNetY: 91.31%, Ensemble: 93.02%
No improvement in ensemble validation accuracy.


Epoch 36/50: 100%|██████████| 475/475 [02:35<00:00,  3.06it/s]

Epoch [36/50] Average Training Loss - Swin: 0.0137, RegNetY: 0.0260





Validation Loss - Swin: 0.4983, RegNetY: 0.4039, Total: 0.9022
Validation Accuracy - Swin: 92.20%, RegNetY: 91.41%, Ensemble: 92.91%
No improvement in ensemble validation accuracy.


Epoch 37/50: 100%|██████████| 475/475 [02:35<00:00,  3.05it/s]

Epoch [37/50] Average Training Loss - Swin: 0.0150, RegNetY: 0.0260





Validation Loss - Swin: 0.4337, RegNetY: 0.4030, Total: 0.8368
Validation Accuracy - Swin: 91.86%, RegNetY: 90.83%, Ensemble: 92.62%
No improvement in ensemble validation accuracy.


Epoch 38/50: 100%|██████████| 475/475 [02:35<00:00,  3.05it/s]

Epoch [38/50] Average Training Loss - Swin: 0.0137, RegNetY: 0.0226





Validation Loss - Swin: 0.4576, RegNetY: 0.4062, Total: 0.8638
Validation Accuracy - Swin: 92.20%, RegNetY: 91.46%, Ensemble: 93.05%
Best model saved at epoch 38 with ensemble validation accuracy 93.05%


Epoch 39/50: 100%|██████████| 475/475 [02:38<00:00,  3.00it/s]

Epoch [39/50] Average Training Loss - Swin: 0.0133, RegNetY: 0.0245





Validation Loss - Swin: 0.4201, RegNetY: 0.4057, Total: 0.8258
Validation Accuracy - Swin: 92.07%, RegNetY: 91.57%, Ensemble: 93.05%
No improvement in ensemble validation accuracy.


Epoch 40/50: 100%|██████████| 475/475 [02:35<00:00,  3.06it/s]

Epoch [40/50] Average Training Loss - Swin: 0.0093, RegNetY: 0.0240





Validation Loss - Swin: 0.4772, RegNetY: 0.4052, Total: 0.8825
Validation Accuracy - Swin: 92.02%, RegNetY: 91.46%, Ensemble: 92.97%
No improvement in ensemble validation accuracy.


Epoch 41/50: 100%|██████████| 475/475 [02:35<00:00,  3.05it/s]

Epoch [41/50] Average Training Loss - Swin: 0.0077, RegNetY: 0.0220





Validation Loss - Swin: 0.4986, RegNetY: 0.4043, Total: 0.9028
Validation Accuracy - Swin: 92.07%, RegNetY: 91.10%, Ensemble: 93.15%
Best model saved at epoch 41 with ensemble validation accuracy 93.15%


Epoch 42/50: 100%|██████████| 475/475 [02:38<00:00,  3.00it/s]

Epoch [42/50] Average Training Loss - Swin: 0.0074, RegNetY: 0.0202





Validation Loss - Swin: 0.5077, RegNetY: 0.3952, Total: 0.9029
Validation Accuracy - Swin: 92.49%, RegNetY: 91.49%, Ensemble: 93.20%
Best model saved at epoch 42 with ensemble validation accuracy 93.20%


Epoch 43/50: 100%|██████████| 475/475 [02:38<00:00,  3.00it/s]

Epoch [43/50] Average Training Loss - Swin: 0.0083, RegNetY: 0.0177





Validation Loss - Swin: 0.5159, RegNetY: 0.4003, Total: 0.9163
Validation Accuracy - Swin: 92.33%, RegNetY: 91.41%, Ensemble: 92.86%
No improvement in ensemble validation accuracy.


Epoch 44/50: 100%|██████████| 475/475 [02:35<00:00,  3.05it/s]

Epoch [44/50] Average Training Loss - Swin: 0.0076, RegNetY: 0.0183





Validation Loss - Swin: 0.4915, RegNetY: 0.4049, Total: 0.8964
Validation Accuracy - Swin: 92.36%, RegNetY: 91.36%, Ensemble: 92.94%
No improvement in ensemble validation accuracy.


Epoch 45/50: 100%|██████████| 475/475 [02:35<00:00,  3.05it/s]

Epoch [45/50] Average Training Loss - Swin: 0.0094, RegNetY: 0.0192





Validation Loss - Swin: 0.4729, RegNetY: 0.4187, Total: 0.8916
Validation Accuracy - Swin: 92.04%, RegNetY: 91.49%, Ensemble: 92.97%
No improvement in ensemble validation accuracy.


Epoch 46/50: 100%|██████████| 475/475 [02:35<00:00,  3.05it/s]

Epoch [46/50] Average Training Loss - Swin: 0.0063, RegNetY: 0.0160





Validation Loss - Swin: 0.4794, RegNetY: 0.4135, Total: 0.8929
Validation Accuracy - Swin: 92.23%, RegNetY: 91.39%, Ensemble: 93.02%
No improvement in ensemble validation accuracy.


Epoch 47/50: 100%|██████████| 475/475 [02:35<00:00,  3.05it/s]

Epoch [47/50] Average Training Loss - Swin: 0.0075, RegNetY: 0.0133





Validation Loss - Swin: 0.4765, RegNetY: 0.4149, Total: 0.8914
Validation Accuracy - Swin: 92.31%, RegNetY: 91.54%, Ensemble: 93.07%
No improvement in ensemble validation accuracy.


Epoch 48/50: 100%|██████████| 475/475 [02:35<00:00,  3.05it/s]

Epoch [48/50] Average Training Loss - Swin: 0.0068, RegNetY: 0.0162





Validation Loss - Swin: 0.4781, RegNetY: 0.4194, Total: 0.8974
Validation Accuracy - Swin: 92.33%, RegNetY: 91.49%, Ensemble: 92.99%
No improvement in ensemble validation accuracy.


Epoch 49/50: 100%|██████████| 475/475 [02:35<00:00,  3.06it/s]

Epoch [49/50] Average Training Loss - Swin: 0.0063, RegNetY: 0.0177





Validation Loss - Swin: 0.4773, RegNetY: 0.4120, Total: 0.8893
Validation Accuracy - Swin: 92.33%, RegNetY: 91.57%, Ensemble: 93.18%
No improvement in ensemble validation accuracy.


Epoch 50/50: 100%|██████████| 475/475 [02:35<00:00,  3.05it/s]

Epoch [50/50] Average Training Loss - Swin: 0.0076, RegNetY: 0.0186





Validation Loss - Swin: 0.4769, RegNetY: 0.4088, Total: 0.8856
Validation Accuracy - Swin: 92.31%, RegNetY: 91.44%, Ensemble: 93.20%
No improvement in ensemble validation accuracy.
Training completed.


  checkpoint = torch.load('best_model.pth')


Final Validation Accuracy - Swin: 92.49%, RegNetY: 91.49%, Ensemble: 93.20%
Confusion Matrix:
[[1822  109]
 [ 149 1716]]


In [None]:
# Import necessary libraries
import os
import torch
import torch.nn as nn
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import timm
from sklearn.metrics import confusion_matrix, cohen_kappa_score
import numpy as np
from tqdm import tqdm
import multiprocessing

import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image

# Dataset locations on the mounted Google Drive
train_dir = '/content/drive/MyDrive/train/'
val_dir = '/content/drive/MyDrive/val/val/'

# Validation preprocessing is deterministic: resize, normalize, convert to tensor
EVAL_IMAGE_SIZE = 224
EVAL_NORM_MEAN = (0.485, 0.456, 0.406)
EVAL_NORM_STD = (0.229, 0.224, 0.225)

_val_steps = [
    A.Resize(EVAL_IMAGE_SIZE, EVAL_IMAGE_SIZE),
    A.Normalize(mean=EVAL_NORM_MEAN, std=EVAL_NORM_STD),
    ToTensorV2(),
]
val_transforms = A.Compose(_val_steps)

# Custom dataset to use Albumentations
class CustomImageDataset(torch.utils.data.Dataset):
    """Dataset that indexes images with ``ImageFolder`` and applies an
    Albumentations-style transform to each image.

    ``ImageFolder`` is used only to discover ``(path, label)`` pairs and the
    class mapping; images are decoded here so they can be handed to the
    transform as NumPy arrays via the ``image=`` keyword.
    """

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir: Directory passed to ``ImageFolder`` as ``root``.
            transform: Optional callable invoked as ``transform(image=array)``
                that returns a mapping with an ``'image'`` entry.
        """
        self.dataset = ImageFolder(root=root_dir)
        self.transform = transform

    def __getitem__(self, idx: int):
        """Return ``(image, label)`` for sample ``idx``.

        The image is an RGB NumPy array when no transform is set, otherwise
        whatever the transform stores under ``'image'``.
        """
        img_path, label = self.dataset.samples[idx]
        # Open inside a context manager so the underlying file handle is
        # released immediately rather than whenever the object is collected.
        with Image.open(img_path) as img:
            image = np.array(img.convert('RGB'))
        if self.transform:
            transformed = self.transform(image=image)
            image = transformed['image']
        return image, label

    def __len__(self) -> int:
        """Number of samples discovered by the wrapped ``ImageFolder``."""
        return len(self.dataset)

# Load datasets
val_dataset = CustomImageDataset(root_dir=val_dir, transform=val_transforms)

# Verify class-to-label mapping
print("Class-to-Label Mapping for Validation Dataset:")
print(val_dataset.dataset.class_to_idx)

# DataLoader (leave one CPU core free for the main process)
num_workers = multiprocessing.cpu_count() - 1
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=num_workers, pin_memory=True)

# Define models with dropout
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Swin Transformer Model (architecture only; weights come from the checkpoint)
model_swin = timm.create_model('swin_base_patch4_window7_224', pretrained=False, num_classes=2, drop_rate=0.2)
model_swin = model_swin.to(device)

# RegNetY Model (architecture only; weights come from the checkpoint)
model_regnet = timm.create_model('regnety_032', pretrained=False, num_classes=2, drop_rate=0.2)
model_regnet = model_regnet.to(device)

# Load the best model checkpoint.
# torch.load is pickle-based, and this file is read from shared storage, so
# weights_only=True is passed to restrict unpickling to tensors and plain
# containers (the checkpoint holds only state dicts, ints and floats).
checkpoint_path = '/content/drive/MyDrive/Model/best_model.pth'  # Update this path if necessary
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
model_swin.load_state_dict(checkpoint['model_swin_state_dict'])
model_regnet.load_state_dict(checkpoint['model_regnet_state_dict'])

# Move models to evaluation mode
model_swin.eval()
model_regnet.eval()

# Per-sample ground truth and predictions, gathered for the metric functions
all_labels, all_preds_ensemble, all_preds_swin, all_preds_regnet = [], [], [], []
# Running tallies of correct predictions and samples seen
correct_swin = correct_regnet = correct_ensemble = total = 0

# Per-model validation accuracies stored with the checkpoint (0 if absent)
best_val_accuracy_swin = checkpoint.get('best_val_accuracy_swin', 0)
best_val_accuracy_regnet = checkpoint.get('best_val_accuracy_regnet', 0)

# Weight each model by its share of the summed accuracy; use an equal split
# when the sum is zero so the division below is never attempted on zero.
weight_sum = best_val_accuracy_swin + best_val_accuracy_regnet
if weight_sum != 0:
    weight_swin = best_val_accuracy_swin / weight_sum
    weight_regnet = best_val_accuracy_regnet / weight_sum
else:
    weight_swin = weight_regnet = 0.5

print(f'Using weights for ensemble - Swin: {weight_swin:.4f}, RegNetY: {weight_regnet:.4f}')

# Evaluation loop: score every validation batch with both models and with
# their weighted combination, without tracking gradients.
with torch.no_grad():
    for images, labels in tqdm(val_loader, desc='Evaluating'):
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        # Raw outputs of each model and their weighted blend
        outputs_swin = model_swin(images)
        outputs_regnet = model_regnet(images)
        outputs_ensemble = (weight_swin * outputs_swin + weight_regnet * outputs_regnet)

        # Predicted class = index of the largest output along the class axis
        predicted_swin = outputs_swin.argmax(dim=1)
        predicted_regnet = outputs_regnet.argmax(dim=1)
        predicted_ensemble = outputs_ensemble.argmax(dim=1)

        # Tally sample count and per-predictor hits
        total += labels.size(0)
        correct_swin += (predicted_swin == labels).sum().item()
        correct_regnet += (predicted_regnet == labels).sum().item()
        correct_ensemble += (predicted_ensemble == labels).sum().item()

        # Keep labels and predictions on the host for the metrics computed later
        all_labels.extend(labels.cpu().numpy())
        all_preds_swin.extend(predicted_swin.cpu().numpy())
        all_preds_regnet.extend(predicted_regnet.cpu().numpy())
        all_preds_ensemble.extend(predicted_ensemble.cpu().numpy())

# Accuracy in percent for each predictor, from the tallies gathered above
final_accuracy_swin, final_accuracy_regnet, final_accuracy_ensemble = (
    100 * hits / total
    for hits in (correct_swin, correct_regnet, correct_ensemble)
)

print('\n'.join([
    f'\nFinal Validation Accuracy:',
    f' - Swin Transformer: {final_accuracy_swin:.2f}%',
    f' - RegNetY: {final_accuracy_regnet:.2f}%',
    f' - Ensemble: {final_accuracy_ensemble:.2f}%',
]))

# Ensemble confusion matrix (true labels first, predictions second)
cm = confusion_matrix(all_labels, all_preds_ensemble)
print('\nConfusion Matrix for Ensemble:')
print(cm)

# Agreement between ensemble predictions and ground truth, as Cohen's kappa
kappa = cohen_kappa_score(all_labels, all_preds_ensemble)
print(f'\nCohen\'s Kappa Score for Ensemble: {kappa:.4f}')


  check_for_updates()


Class-to-Label Mapping for Validation Dataset:
{'nevus': 0, 'others': 1}


  checkpoint = torch.load(checkpoint_path, map_location=device)


Using weights for ensemble - Swin: 0.5027, RegNetY: 0.4973


Evaluating: 100%|██████████| 119/119 [03:41<00:00,  1.86s/it]


Final Validation Accuracy:
 - Swin Transformer: 92.49%
 - RegNetY: 91.49%
 - Ensemble: 93.20%

Confusion Matrix for Ensemble:
[[1822  109]
 [ 149 1716]]

Cohen's Kappa Score for Ensemble: 0.8640





# Hybrid Model

In [None]:
# Install necessary libraries (if not already installed)
# !pip install --upgrade albumentations timm

import os
import torch
import torch.nn as nn
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import timm
from sklearn.metrics import confusion_matrix
import numpy as np
from tqdm import tqdm
import multiprocessing

import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image

# Define data directories
# Google Drive locations of the training and validation image folders
# (the validation images sit one level deeper, under val/val/).
train_dir = '/content/drive/MyDrive/train/'
val_dir = '/content/drive/MyDrive/val/val/'

# Robust data augmentation
# Training pipeline: resize to 224x224, then two OneOf groups (each applied
# with probability 0.8, picking a single transform from its list), then
# normalisation with the ImageNet mean/std and conversion to a torch tensor.
train_transforms = A.Compose([
    A.Resize(224, 224),
    # Group 1: photometric changes (brightness/contrast, colour jitter, inversion).
    A.OneOf([
        A.RandomBrightnessContrast(),
        A.ColorJitter(),
        A.InvertImg(),
    ], p=0.8),
    # Group 2: geometric changes (flips, rotations, shift/scale, shear, transpose)
    # plus CoarseDropout, which blanks up to 8 small 8x8 patches with zeros.
    A.OneOf([
        A.HorizontalFlip(),
        A.VerticalFlip(),
        A.RandomRotate90(),
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=45),
        A.Affine(shear=(-30, 30)),
        A.Rotate(limit=45),
        A.CoarseDropout(max_holes=8, max_height=8, max_width=8, fill_value=0, p=0.5),
        A.Transpose(),
    ], p=0.8),
    A.Normalize(mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])

# Validation pipeline: no random augmentation, only the same resize,
# normalisation and tensor conversion used for training.
val_transforms = A.Compose([
    A.Resize(224, 224),
    A.Normalize(mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])

# Custom dataset to use Albumentations
class CustomImageDataset(torch.utils.data.Dataset):
    """Folder-per-class image dataset whose samples go through an Albumentations-style transform.

    The file list and integer labels come from a torchvision ``ImageFolder``
    built on ``root_dir``; only its ``samples`` index is used, the images are
    decoded here so the transform can be called with ``image=<numpy array>``.

    Args:
        root_dir: Directory containing one sub-folder per class.
        transform: Optional callable invoked as ``transform(image=array)`` and
            expected to return a mapping with an ``'image'`` entry.
    """

    def __init__(self, root_dir, transform=None):
        self.transform = transform
        self.dataset = ImageFolder(root=root_dir)

    def __len__(self):
        """Return the number of samples reported by the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``."""
        path, target = self.dataset.samples[idx]
        # Decode to an RGB numpy array, the input format the transform is called with.
        pixels = np.array(Image.open(path).convert('RGB'))
        if not self.transform:
            return pixels, target
        return self.transform(image=pixels)['image'], target

# Build the training and validation datasets from their image folders.
train_dataset = CustomImageDataset(train_dir, transform=train_transforms)
val_dataset = CustomImageDataset(val_dir, transform=val_transforms)

# Show how class folder names were mapped to integer labels.
print("Class-to-Label Mapping for Training Dataset:")
print(train_dataset.dataset.class_to_idx)

# Batch loaders: one fewer worker process than the number of CPU cores.
# Only the training loader shuffles; validation order stays fixed.
num_workers = multiprocessing.cpu_count() - 1
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True,
)

# Define models with dropout
# Use the GPU when one is available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Swin Transformer Model with Dropout
# Pretrained timm backbone with a 2-class head and drop_rate=0.2.
model_swin = timm.create_model('swin_base_patch4_window7_224', pretrained=True, num_classes=2, drop_rate=0.2)
model_swin = model_swin.to(device)

# EfficientNet Model with Dropout
# Pretrained timm backbone with a 2-class head and drop_rate=0.2.
# NOTE(review): the captured output shows a warning raised from timm's
# `create_fn(` call while the models are created; presumably
# 'tf_efficientnet_b7_ns' is a deprecated alias in the installed timm —
# confirm and switch to the current model name if so.
model_efficientnet = timm.create_model('tf_efficientnet_b7_ns', pretrained=True, num_classes=2, drop_rate=0.2)
model_efficientnet = model_efficientnet.to(device)

# Define loss and optimizer with weight decay
# One shared loss function; each model gets its own AdamW optimizer with the
# same learning rate and weight decay.
criterion = nn.CrossEntropyLoss()
optimizer_swin = torch.optim.AdamW(model_swin.parameters(), lr=1e-4, weight_decay=1e-5)
optimizer_efficientnet = torch.optim.AdamW(model_efficientnet.parameters(), lr=1e-4, weight_decay=1e-5)

# Modified LR schedulers - Using CosineAnnealingWarmRestarts
# T_0=10 with T_mult=2: the first cosine cycle lasts 10 epochs and each
# following cycle is twice as long, so within 50 epochs the learning rate
# restarts after epochs 10 and 30.
scheduler_swin = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_swin, T_0=10, T_mult=2)
scheduler_efficientnet = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_efficientnet, T_0=10, T_mult=2)

# Training parameters
num_epochs = 50

# For weighted ensemble
# best_val_accuracy_ensemble drives checkpointing in the training loop; the
# per-model values are overwritten with each model's accuracy at the epoch
# where the ensemble was best and later used as ensemble weights.
best_val_accuracy_swin = 0
best_val_accuracy_efficientnet = 0
best_val_accuracy_ensemble = 0

for epoch in range(num_epochs):
    # ------------------------------------------------------------------
    # Training phase. Both networks are trained on the same batches, each
    # with its own optimizer. The two losses are summed only so that a
    # single backward() call produces the gradients for both models; the
    # models were created independently and share no parameters.
    # ------------------------------------------------------------------
    model_swin.train()
    model_efficientnet.train()
    running_loss_swin = 0.0
    running_loss_efficientnet = 0.0

    for images, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        # Zero the parameter gradients
        optimizer_swin.zero_grad()
        optimizer_efficientnet.zero_grad()

        # Forward pass
        outputs_swin = model_swin(images)
        outputs_efficientnet = model_efficientnet(images)

        # Compute loss
        loss_swin = criterion(outputs_swin, labels)
        loss_efficientnet = criterion(outputs_efficientnet, labels)
        total_loss = loss_swin + loss_efficientnet

        # Backward pass and optimization
        total_loss.backward()
        optimizer_swin.step()
        optimizer_efficientnet.step()

        running_loss_swin += loss_swin.item()
        running_loss_efficientnet += loss_efficientnet.item()

    # Average training loss per batch
    avg_loss_swin = running_loss_swin / len(train_loader)
    avg_loss_efficientnet = running_loss_efficientnet / len(train_loader)
    print(f'Epoch [{epoch+1}/{num_epochs}] Average Training Loss - Swin: {avg_loss_swin:.4f}, EfficientNet: {avg_loss_efficientnet:.4f}')

    # ------------------------------------------------------------------
    # Validation phase. Only running counters are kept here: per-sample
    # labels/predictions are not collected because nothing in the loop
    # reads them, and collecting them would copy every batch back to the CPU.
    # ------------------------------------------------------------------
    model_swin.eval()
    model_efficientnet.eval()
    val_loss_swin = 0.0
    val_loss_efficientnet = 0.0
    correct_swin = 0
    correct_efficientnet = 0
    correct_ensemble = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            outputs_swin = model_swin(images)
            outputs_efficientnet = model_efficientnet(images)

            # Compute validation loss
            val_loss_swin += criterion(outputs_swin, labels).item()
            val_loss_efficientnet += criterion(outputs_efficientnet, labels).item()

            # Individual model predictions (index of the largest output per sample)
            predicted_swin = outputs_swin.argmax(dim=1)
            predicted_efficientnet = outputs_efficientnet.argmax(dim=1)

            # Ensemble by averaging the raw model outputs
            outputs = (outputs_swin + outputs_efficientnet) / 2
            predicted = outputs.argmax(dim=1)

            total += labels.size(0)
            correct_swin += (predicted_swin == labels).sum().item()
            correct_efficientnet += (predicted_efficientnet == labels).sum().item()
            correct_ensemble += (predicted == labels).sum().item()

    # Average validation loss per batch and accuracy in percent
    avg_val_loss_swin = val_loss_swin / len(val_loader)
    avg_val_loss_efficientnet = val_loss_efficientnet / len(val_loader)
    val_accuracy_swin = 100 * correct_swin / total
    val_accuracy_efficientnet = 100 * correct_efficientnet / total
    val_accuracy_ensemble = 100 * correct_ensemble / total
    print(f'Validation Loss - Swin: {avg_val_loss_swin:.4f}, EfficientNet: {avg_val_loss_efficientnet:.4f}')
    print(f'Validation Accuracy - Swin: {val_accuracy_swin:.2f}%, EfficientNet: {val_accuracy_efficientnet:.2f}%, Ensemble: {val_accuracy_ensemble:.2f}%')

    # Update learning rate schedulers with warm restarts
    scheduler_swin.step(epoch + 1)           # step with the epoch count
    scheduler_efficientnet.step(epoch + 1)   # step with the epoch count

    # Save the model based on best ensemble validation accuracy
    if val_accuracy_ensemble > best_val_accuracy_ensemble:
        best_val_accuracy_ensemble = val_accuracy_ensemble
        # Note: the 'best_val_accuracy_<model>' entries hold each model's
        # accuracy at this (best-ensemble) epoch, not that model's own best.
        torch.save({
            'epoch': epoch + 1,
            'model_swin_state_dict': model_swin.state_dict(),
            'model_efficientnet_state_dict': model_efficientnet.state_dict(),
            'optimizer_swin_state_dict': optimizer_swin.state_dict(),
            'optimizer_efficientnet_state_dict': optimizer_efficientnet.state_dict(),
            'val_accuracy_ensemble': best_val_accuracy_ensemble,
            'best_val_accuracy_swin': val_accuracy_swin,
            'best_val_accuracy_efficientnet': val_accuracy_efficientnet,
        }, 'best_model.pth')
        print(f'Best model saved at epoch {epoch+1} with ensemble validation accuracy {best_val_accuracy_ensemble:.2f}%')

        # Update best validation accuracies
        best_val_accuracy_swin = val_accuracy_swin
        best_val_accuracy_efficientnet = val_accuracy_efficientnet
    else:
        print('No improvement in ensemble validation accuracy.')

print('Training completed.')

# Load the best model
# map_location=device lets a checkpoint saved from GPU tensors be restored on
# whatever device this session uses (including a CPU-only runtime).
checkpoint = torch.load('best_model.pth', map_location=device)
model_swin.load_state_dict(checkpoint['model_swin_state_dict'])
model_efficientnet.load_state_dict(checkpoint['model_efficientnet_state_dict'])
# Per-model validation accuracies recorded at the epoch where the ensemble was best.
best_val_accuracy_swin = checkpoint['best_val_accuracy_swin']
best_val_accuracy_efficientnet = checkpoint['best_val_accuracy_efficientnet']

# Ensemble weights proportional to each model's recorded validation accuracy.
# They do not depend on the batch, so compute them once before the loop; fall
# back to equal weights if both recorded accuracies are zero.
accuracy_sum = best_val_accuracy_swin + best_val_accuracy_efficientnet
if accuracy_sum > 0:
    weight_swin = best_val_accuracy_swin / accuracy_sum
    weight_efficientnet = best_val_accuracy_efficientnet / accuracy_sum
else:
    weight_swin = weight_efficientnet = 0.5

# Evaluate on the validation set with weighted ensemble
model_swin.eval()
model_efficientnet.eval()
all_labels = []
all_preds = []
correct_swin = 0
correct_efficientnet = 0
correct_ensemble = 0
total = 0
with torch.no_grad():
    for images, labels in val_loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        outputs_swin = model_swin(images)
        outputs_efficientnet = model_efficientnet(images)

        # Individual model predictions (index of the largest output per sample)
        predicted_swin = outputs_swin.argmax(dim=1)
        predicted_efficientnet = outputs_efficientnet.argmax(dim=1)

        # Ensemble by weighted averaging of the raw model outputs
        outputs = (weight_swin * outputs_swin + weight_efficientnet * outputs_efficientnet)
        predicted = outputs.argmax(dim=1)

        total += labels.size(0)
        correct_swin += (predicted_swin == labels).sum().item()
        correct_efficientnet += (predicted_efficientnet == labels).sum().item()
        correct_ensemble += (predicted == labels).sum().item()

        # Keep labels and ensemble predictions for the confusion matrix below
        all_labels.extend(labels.cpu().numpy())
        all_preds.extend(predicted.cpu().numpy())

# Compute final accuracies (percent of validation samples classified correctly)
final_accuracy_swin = 100 * correct_swin / total
final_accuracy_efficientnet = 100 * correct_efficientnet / total
final_accuracy_ensemble = 100 * correct_ensemble / total
print(f'Final Validation Accuracy - Swin: {final_accuracy_swin:.2f}%, EfficientNet: {final_accuracy_efficientnet:.2f}%, Ensemble: {final_accuracy_ensemble:.2f}%')

# Compute confusion matrix (rows: true labels, columns: ensemble predictions)
cm = confusion_matrix(all_labels, all_preds)
print('Confusion Matrix:')
print(cm)


Class-to-Label Mapping for Training Dataset:
{'nevus': 0, 'others': 1}


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/353M [00:00<?, ?B/s]

  model = create_fn(


model.safetensors:   0%|          | 0.00/267M [00:00<?, ?B/s]

Epoch 1/50: 100%|██████████| 475/475 [15:50<00:00,  2.00s/it]

Epoch [1/50] Average Training Loss - Swin: 0.4224, EfficientNet: 0.7795





Validation Loss - Swin: 0.4102, EfficientNet: 0.4121
Validation Accuracy - Swin: 81.22%, EfficientNet: 80.82%, Ensemble: 83.54%
Best model saved at epoch 1 with ensemble validation accuracy 83.54%


Epoch 2/50: 100%|██████████| 475/475 [03:42<00:00,  2.14it/s]

Epoch [2/50] Average Training Loss - Swin: 0.3485, EfficientNet: 0.4790





Validation Loss - Swin: 0.3026, EfficientNet: 0.3844
Validation Accuracy - Swin: 87.07%, EfficientNet: 83.32%, Ensemble: 86.38%
Best model saved at epoch 2 with ensemble validation accuracy 86.38%


Epoch 3/50: 100%|██████████| 475/475 [03:39<00:00,  2.16it/s]

Epoch [3/50] Average Training Loss - Swin: 0.3141, EfficientNet: 0.4211





Validation Loss - Swin: 0.2926, EfficientNet: 0.3562
Validation Accuracy - Swin: 87.46%, EfficientNet: 84.54%, Ensemble: 87.49%
Best model saved at epoch 3 with ensemble validation accuracy 87.49%


Epoch 4/50: 100%|██████████| 475/475 [03:42<00:00,  2.14it/s]

Epoch [4/50] Average Training Loss - Swin: 0.2769, EfficientNet: 0.3793





Validation Loss - Swin: 0.2971, EfficientNet: 0.3481
Validation Accuracy - Swin: 87.88%, EfficientNet: 84.30%, Ensemble: 88.04%
Best model saved at epoch 4 with ensemble validation accuracy 88.04%


Epoch 5/50: 100%|██████████| 475/475 [03:42<00:00,  2.14it/s]

Epoch [5/50] Average Training Loss - Swin: 0.2451, EfficientNet: 0.3479





Validation Loss - Swin: 0.2769, EfficientNet: 0.3289
Validation Accuracy - Swin: 88.20%, EfficientNet: 86.51%, Ensemble: 88.78%
Best model saved at epoch 5 with ensemble validation accuracy 88.78%


Epoch 6/50: 100%|██████████| 475/475 [03:42<00:00,  2.14it/s]

Epoch [6/50] Average Training Loss - Swin: 0.2084, EfficientNet: 0.3119





Validation Loss - Swin: 0.2655, EfficientNet: 0.3368
Validation Accuracy - Swin: 89.91%, EfficientNet: 86.88%, Ensemble: 89.81%
Best model saved at epoch 6 with ensemble validation accuracy 89.81%


Epoch 7/50: 100%|██████████| 475/475 [03:42<00:00,  2.14it/s]

Epoch [7/50] Average Training Loss - Swin: 0.1716, EfficientNet: 0.2838





Validation Loss - Swin: 0.2475, EfficientNet: 0.3173
Validation Accuracy - Swin: 90.52%, EfficientNet: 86.96%, Ensemble: 90.41%
Best model saved at epoch 7 with ensemble validation accuracy 90.41%


Epoch 8/50: 100%|██████████| 475/475 [03:42<00:00,  2.14it/s]

Epoch [8/50] Average Training Loss - Swin: 0.1335, EfficientNet: 0.2567





Validation Loss - Swin: 0.2661, EfficientNet: 0.3084
Validation Accuracy - Swin: 90.65%, EfficientNet: 87.41%, Ensemble: 91.02%
Best model saved at epoch 8 with ensemble validation accuracy 91.02%


Epoch 9/50: 100%|██████████| 475/475 [03:42<00:00,  2.14it/s]

Epoch [9/50] Average Training Loss - Swin: 0.1134, EfficientNet: 0.2338





Validation Loss - Swin: 0.2631, EfficientNet: 0.3085
Validation Accuracy - Swin: 91.07%, EfficientNet: 87.51%, Ensemble: 91.33%
Best model saved at epoch 9 with ensemble validation accuracy 91.33%


Epoch 10/50: 100%|██████████| 475/475 [03:41<00:00,  2.14it/s]

Epoch [10/50] Average Training Loss - Swin: 0.0963, EfficientNet: 0.2135





Validation Loss - Swin: 0.2718, EfficientNet: 0.3102
Validation Accuracy - Swin: 91.10%, EfficientNet: 87.70%, Ensemble: 91.60%
Best model saved at epoch 10 with ensemble validation accuracy 91.60%


Epoch 11/50: 100%|██████████| 475/475 [03:42<00:00,  2.14it/s]

Epoch [11/50] Average Training Loss - Swin: 0.2114, EfficientNet: 0.3131





Validation Loss - Swin: 0.2596, EfficientNet: 0.3273
Validation Accuracy - Swin: 89.36%, EfficientNet: 85.64%, Ensemble: 89.20%
No improvement in ensemble validation accuracy.


Epoch 12/50: 100%|██████████| 475/475 [03:38<00:00,  2.17it/s]

Epoch [12/50] Average Training Loss - Swin: 0.1979, EfficientNet: 0.2731





Validation Loss - Swin: 0.3149, EfficientNet: 0.3379
Validation Accuracy - Swin: 89.23%, EfficientNet: 87.28%, Ensemble: 90.23%
No improvement in ensemble validation accuracy.


Epoch 13/50: 100%|██████████| 475/475 [03:38<00:00,  2.17it/s]

Epoch [13/50] Average Training Loss - Swin: 0.1871, EfficientNet: 0.2395





Validation Loss - Swin: 0.2798, EfficientNet: 0.3074
Validation Accuracy - Swin: 89.57%, EfficientNet: 88.22%, Ensemble: 90.25%
No improvement in ensemble validation accuracy.


Epoch 14/50: 100%|██████████| 475/475 [03:38<00:00,  2.17it/s]

Epoch [14/50] Average Training Loss - Swin: 0.1679, EfficientNet: 0.2053





Validation Loss - Swin: 0.2724, EfficientNet: 0.3043
Validation Accuracy - Swin: 89.30%, EfficientNet: 88.12%, Ensemble: 90.25%
No improvement in ensemble validation accuracy.


Epoch 15/50: 100%|██████████| 475/475 [03:38<00:00,  2.17it/s]

Epoch [15/50] Average Training Loss - Swin: 0.1524, EfficientNet: 0.1868





Validation Loss - Swin: 0.2705, EfficientNet: 0.3405
Validation Accuracy - Swin: 90.65%, EfficientNet: 88.62%, Ensemble: 91.31%
No improvement in ensemble validation accuracy.


Epoch 16/50: 100%|██████████| 475/475 [03:38<00:00,  2.17it/s]

Epoch [16/50] Average Training Loss - Swin: 0.1421, EfficientNet: 0.1607





Validation Loss - Swin: 0.2525, EfficientNet: 0.3324
Validation Accuracy - Swin: 90.28%, EfficientNet: 89.09%, Ensemble: 90.86%
No improvement in ensemble validation accuracy.


Epoch 17/50: 100%|██████████| 475/475 [03:38<00:00,  2.17it/s]

Epoch [17/50] Average Training Loss - Swin: 0.1288, EfficientNet: 0.1434





Validation Loss - Swin: 0.2670, EfficientNet: 0.3275
Validation Accuracy - Swin: 90.46%, EfficientNet: 88.99%, Ensemble: 92.02%
Best model saved at epoch 17 with ensemble validation accuracy 92.02%


Epoch 18/50: 100%|██████████| 475/475 [03:40<00:00,  2.15it/s]

Epoch [18/50] Average Training Loss - Swin: 0.1088, EfficientNet: 0.1170





Validation Loss - Swin: 0.3213, EfficientNet: 0.3435
Validation Accuracy - Swin: 90.23%, EfficientNet: 88.94%, Ensemble: 91.86%
No improvement in ensemble validation accuracy.


Epoch 19/50: 100%|██████████| 475/475 [03:38<00:00,  2.17it/s]

Epoch [19/50] Average Training Loss - Swin: 0.0949, EfficientNet: 0.1047





Validation Loss - Swin: 0.3093, EfficientNet: 0.3530
Validation Accuracy - Swin: 90.67%, EfficientNet: 89.73%, Ensemble: 91.54%
No improvement in ensemble validation accuracy.


Epoch 20/50: 100%|██████████| 475/475 [03:38<00:00,  2.17it/s]

Epoch [20/50] Average Training Loss - Swin: 0.0806, EfficientNet: 0.0864





Validation Loss - Swin: 0.3550, EfficientNet: 0.4225
Validation Accuracy - Swin: 90.91%, EfficientNet: 90.02%, Ensemble: 92.12%
Best model saved at epoch 20 with ensemble validation accuracy 92.12%


Epoch 21/50: 100%|██████████| 475/475 [03:42<00:00,  2.14it/s]

Epoch [21/50] Average Training Loss - Swin: 0.0703, EfficientNet: 0.0757





Validation Loss - Swin: 0.3185, EfficientNet: 0.3770
Validation Accuracy - Swin: 91.41%, EfficientNet: 90.09%, Ensemble: 92.02%
No improvement in ensemble validation accuracy.


Epoch 22/50: 100%|██████████| 475/475 [03:38<00:00,  2.17it/s]

Epoch [22/50] Average Training Loss - Swin: 0.0600, EfficientNet: 0.0622





Validation Loss - Swin: 0.3161, EfficientNet: 0.3946
Validation Accuracy - Swin: 91.94%, EfficientNet: 90.83%, Ensemble: 92.18%
Best model saved at epoch 22 with ensemble validation accuracy 92.18%


Epoch 23/50: 100%|██████████| 475/475 [03:42<00:00,  2.14it/s]

Epoch [23/50] Average Training Loss - Swin: 0.0467, EfficientNet: 0.0556





Validation Loss - Swin: 0.3703, EfficientNet: 0.3858
Validation Accuracy - Swin: 91.31%, EfficientNet: 90.36%, Ensemble: 92.49%
Best model saved at epoch 23 with ensemble validation accuracy 92.49%


Epoch 24/50: 100%|██████████| 475/475 [03:42<00:00,  2.14it/s]

Epoch [24/50] Average Training Loss - Swin: 0.0397, EfficientNet: 0.0496





Validation Loss - Swin: 0.3773, EfficientNet: 0.4156
Validation Accuracy - Swin: 91.81%, EfficientNet: 90.52%, Ensemble: 92.60%
Best model saved at epoch 24 with ensemble validation accuracy 92.60%


Epoch 25/50: 100%|██████████| 475/475 [03:42<00:00,  2.14it/s]

Epoch [25/50] Average Training Loss - Swin: 0.0363, EfficientNet: 0.0437





Validation Loss - Swin: 0.3291, EfficientNet: 0.4234
Validation Accuracy - Swin: 91.68%, EfficientNet: 90.60%, Ensemble: 92.39%
No improvement in ensemble validation accuracy.


Epoch 26/50: 100%|██████████| 475/475 [03:38<00:00,  2.17it/s]

Epoch [26/50] Average Training Loss - Swin: 0.0231, EfficientNet: 0.0357





Validation Loss - Swin: 0.3728, EfficientNet: 0.4455
Validation Accuracy - Swin: 91.73%, EfficientNet: 90.78%, Ensemble: 92.62%
Best model saved at epoch 26 with ensemble validation accuracy 92.62%


Epoch 27/50: 100%|██████████| 475/475 [03:42<00:00,  2.14it/s]

Epoch [27/50] Average Training Loss - Swin: 0.0210, EfficientNet: 0.0344





Validation Loss - Swin: 0.3995, EfficientNet: 0.4560
Validation Accuracy - Swin: 92.18%, EfficientNet: 90.94%, Ensemble: 92.91%
Best model saved at epoch 27 with ensemble validation accuracy 92.91%


Epoch 28/50: 100%|██████████| 475/475 [03:42<00:00,  2.14it/s]

Epoch [28/50] Average Training Loss - Swin: 0.0213, EfficientNet: 0.0295





Validation Loss - Swin: 0.3855, EfficientNet: 0.4420
Validation Accuracy - Swin: 92.47%, EfficientNet: 90.83%, Ensemble: 92.70%
No improvement in ensemble validation accuracy.


Epoch 29/50: 100%|██████████| 475/475 [03:38<00:00,  2.17it/s]

Epoch [29/50] Average Training Loss - Swin: 0.0161, EfficientNet: 0.0256





Validation Loss - Swin: 0.3957, EfficientNet: 0.4360
Validation Accuracy - Swin: 92.39%, EfficientNet: 91.02%, Ensemble: 93.02%
Best model saved at epoch 29 with ensemble validation accuracy 93.02%


Epoch 30/50: 100%|██████████| 475/475 [03:42<00:00,  2.14it/s]

Epoch [30/50] Average Training Loss - Swin: 0.0156, EfficientNet: 0.0241





Validation Loss - Swin: 0.3951, EfficientNet: 0.4415
Validation Accuracy - Swin: 92.41%, EfficientNet: 91.12%, Ensemble: 92.94%
No improvement in ensemble validation accuracy.


Epoch 31/50: 100%|██████████| 475/475 [03:38<00:00,  2.17it/s]

Epoch [31/50] Average Training Loss - Swin: 0.1349, EfficientNet: 0.1232





Validation Loss - Swin: 0.3362, EfficientNet: 0.4045
Validation Accuracy - Swin: 89.17%, EfficientNet: 88.59%, Ensemble: 91.39%
No improvement in ensemble validation accuracy.


Epoch 32/50: 100%|██████████| 475/475 [03:38<00:00,  2.17it/s]

Epoch [32/50] Average Training Loss - Swin: 0.1172, EfficientNet: 0.1132





Validation Loss - Swin: 0.3140, EfficientNet: 0.4275
Validation Accuracy - Swin: 89.91%, EfficientNet: 88.94%, Ensemble: 91.60%
No improvement in ensemble validation accuracy.


Epoch 33/50: 100%|██████████| 475/475 [03:38<00:00,  2.17it/s]

Epoch [33/50] Average Training Loss - Swin: 0.1090, EfficientNet: 0.0949





Validation Loss - Swin: 0.3269, EfficientNet: 0.4686
Validation Accuracy - Swin: 89.96%, EfficientNet: 89.20%, Ensemble: 91.52%
No improvement in ensemble validation accuracy.


Epoch 34/50: 100%|██████████| 475/475 [03:39<00:00,  2.17it/s]

Epoch [34/50] Average Training Loss - Swin: 0.0995, EfficientNet: 0.0843





Validation Loss - Swin: 0.3273, EfficientNet: 0.4781
Validation Accuracy - Swin: 89.83%, EfficientNet: 89.78%, Ensemble: 91.02%
No improvement in ensemble validation accuracy.


Epoch 35/50: 100%|██████████| 475/475 [03:38<00:00,  2.17it/s]

Epoch [35/50] Average Training Loss - Swin: 0.0993, EfficientNet: 0.0768





Validation Loss - Swin: 0.3089, EfficientNet: 0.4113
Validation Accuracy - Swin: 91.33%, EfficientNet: 89.57%, Ensemble: 92.12%
No improvement in ensemble validation accuracy.


Epoch 36/50: 100%|██████████| 475/475 [03:38<00:00,  2.17it/s]

Epoch [36/50] Average Training Loss - Swin: 0.0899, EfficientNet: 0.0738





Validation Loss - Swin: 0.3487, EfficientNet: 0.4119
Validation Accuracy - Swin: 90.86%, EfficientNet: 89.78%, Ensemble: 91.91%
No improvement in ensemble validation accuracy.


Epoch 37/50: 100%|██████████| 475/475 [03:39<00:00,  2.17it/s]

Epoch [37/50] Average Training Loss - Swin: 0.0907, EfficientNet: 0.0702





Validation Loss - Swin: 0.2826, EfficientNet: 0.4553
Validation Accuracy - Swin: 90.70%, EfficientNet: 89.41%, Ensemble: 91.57%
No improvement in ensemble validation accuracy.


Epoch 38/50: 100%|██████████| 475/475 [03:38<00:00,  2.17it/s]

Epoch [38/50] Average Training Loss - Swin: 0.0740, EfficientNet: 0.0678





Validation Loss - Swin: 0.2920, EfficientNet: 0.4688
Validation Accuracy - Swin: 91.10%, EfficientNet: 89.33%, Ensemble: 92.04%
No improvement in ensemble validation accuracy.


Epoch 39/50: 100%|██████████| 475/475 [03:39<00:00,  2.17it/s]

Epoch [39/50] Average Training Loss - Swin: 0.0709, EfficientNet: 0.0649





Validation Loss - Swin: 0.3570, EfficientNet: 0.4382
Validation Accuracy - Swin: 90.91%, EfficientNet: 90.38%, Ensemble: 92.23%
No improvement in ensemble validation accuracy.


Epoch 40/50: 100%|██████████| 475/475 [03:38<00:00,  2.17it/s]

Epoch [40/50] Average Training Loss - Swin: 0.0749, EfficientNet: 0.0575





Validation Loss - Swin: 0.3365, EfficientNet: 0.4180
Validation Accuracy - Swin: 90.36%, EfficientNet: 89.49%, Ensemble: 91.83%
No improvement in ensemble validation accuracy.


Epoch 41/50: 100%|██████████| 475/475 [03:38<00:00,  2.17it/s]

Epoch [41/50] Average Training Loss - Swin: 0.0610, EfficientNet: 0.0498





Validation Loss - Swin: 0.3432, EfficientNet: 0.4522
Validation Accuracy - Swin: 89.96%, EfficientNet: 90.52%, Ensemble: 91.36%
No improvement in ensemble validation accuracy.


Epoch 42/50: 100%|██████████| 475/475 [03:38<00:00,  2.17it/s]

Epoch [42/50] Average Training Loss - Swin: 0.0614, EfficientNet: 0.0476





Validation Loss - Swin: 0.4420, EfficientNet: 0.5005
Validation Accuracy - Swin: 89.99%, EfficientNet: 89.88%, Ensemble: 92.02%
No improvement in ensemble validation accuracy.


Epoch 43/50: 100%|██████████| 475/475 [03:38<00:00,  2.17it/s]

Epoch [43/50] Average Training Loss - Swin: 0.0568, EfficientNet: 0.0467





Validation Loss - Swin: 0.3888, EfficientNet: 0.4712
Validation Accuracy - Swin: 91.15%, EfficientNet: 89.86%, Ensemble: 92.28%
No improvement in ensemble validation accuracy.


Epoch 44/50: 100%|██████████| 475/475 [03:38<00:00,  2.17it/s]

Epoch [44/50] Average Training Loss - Swin: 0.0490, EfficientNet: 0.0416





Validation Loss - Swin: 0.4244, EfficientNet: 0.4563
Validation Accuracy - Swin: 90.57%, EfficientNet: 89.88%, Ensemble: 92.26%
No improvement in ensemble validation accuracy.


Epoch 45/50: 100%|██████████| 475/475 [03:38<00:00,  2.17it/s]

Epoch [45/50] Average Training Loss - Swin: 0.0489, EfficientNet: 0.0389





Validation Loss - Swin: 0.3312, EfficientNet: 0.4588
Validation Accuracy - Swin: 90.75%, EfficientNet: 90.33%, Ensemble: 92.07%
No improvement in ensemble validation accuracy.


Epoch 46/50: 100%|██████████| 475/475 [03:38<00:00,  2.17it/s]

Epoch [46/50] Average Training Loss - Swin: 0.0476, EfficientNet: 0.0365





Validation Loss - Swin: 0.3608, EfficientNet: 0.4704
Validation Accuracy - Swin: 91.25%, EfficientNet: 90.49%, Ensemble: 92.15%
No improvement in ensemble validation accuracy.


Epoch 47/50: 100%|██████████| 475/475 [03:38<00:00,  2.17it/s]

Epoch [47/50] Average Training Loss - Swin: 0.0380, EfficientNet: 0.0337





Validation Loss - Swin: 0.3991, EfficientNet: 0.5099
Validation Accuracy - Swin: 90.62%, EfficientNet: 90.60%, Ensemble: 91.99%
No improvement in ensemble validation accuracy.


Epoch 48/50: 100%|██████████| 475/475 [03:38<00:00,  2.17it/s]

Epoch [48/50] Average Training Loss - Swin: 0.0390, EfficientNet: 0.0274





Validation Loss - Swin: 0.4174, EfficientNet: 0.5264
Validation Accuracy - Swin: 91.07%, EfficientNet: 90.25%, Ensemble: 92.10%
No improvement in ensemble validation accuracy.


Epoch 49/50: 100%|██████████| 475/475 [03:38<00:00,  2.17it/s]

Epoch [49/50] Average Training Loss - Swin: 0.0344, EfficientNet: 0.0294





Validation Loss - Swin: 0.3433, EfficientNet: 0.5154
Validation Accuracy - Swin: 91.49%, EfficientNet: 90.70%, Ensemble: 92.15%
No improvement in ensemble validation accuracy.


Epoch 50/50: 100%|██████████| 475/475 [03:38<00:00,  2.17it/s]

Epoch [50/50] Average Training Loss - Swin: 0.0329, EfficientNet: 0.0238





Validation Loss - Swin: 0.3781, EfficientNet: 0.5065
Validation Accuracy - Swin: 91.57%, EfficientNet: 90.73%, Ensemble: 92.12%
No improvement in ensemble validation accuracy.
Training completed.


  checkpoint = torch.load('best_model.pth')


Final Validation Accuracy - Swin: 92.39%, EfficientNet: 91.02%, Ensemble: 93.02%
Confusion Matrix:
[[1806  125]
 [ 140 1725]]


# Three Class Problem

In [None]:
# Mount Google Drive at /content/drive (Colab-only); the dataset paths used
# below live under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Install necessary libraries (if not already installed)
# !pip install --upgrade albumentations timm

import os
import torch
import torch.nn as nn
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, WeightedRandomSampler
import timm
from sklearn.metrics import confusion_matrix, cohen_kappa_score
import numpy as np
from tqdm import tqdm
import multiprocessing

import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image

# ADDED for copying file to Drive
import shutil

#######################################
# Data Directories and Class Mapping
#######################################
# Data directories (Google Drive locations of the three-class train/val folders):
train_dir = '/content/drive/MyDrive/Three_Class/train_3/train_3/'
val_dir = '/content/drive/MyDrive/Three_Class/val_3/val_3/'

# Desired mapping: mel=0, bcc=1, scc=2
# ImageFolder numbers classes by sorted folder name, so with folders in the
# alphabetical order bcc, mel, scc the original indexing is bcc=0, mel=1, scc=2.
# NOTE(review): this assumes the class folders sort in exactly that order —
# check the printed class_to_idx before trusting the remap.
# We must remap (key = ImageFolder index, value = desired label):
# mel (original 1) → 0
# bcc (original 0) → 1
# scc (original 2) → 2
original_to_desired = {0: 1, 1: 0, 2: 2}

#######################################
# Data Transforms
#######################################
# Training pipeline: resize to 224x224, then two OneOf groups (each applied
# with probability 0.8, picking a single transform from its list), then
# normalisation with the ImageNet mean/std and conversion to a torch tensor.
train_transforms = A.Compose([
    A.Resize(224, 224),
    # Group 1: photometric changes (brightness/contrast, colour jitter, inversion).
    A.OneOf([
        A.RandomBrightnessContrast(),
        A.ColorJitter(),
        A.InvertImg(),
    ], p=0.8),
    # Group 2: geometric changes (flips, rotations, shift/scale, shear, transpose)
    # plus CoarseDropout, which blanks up to 8 small 8x8 patches with zeros.
    A.OneOf([
        A.HorizontalFlip(),
        A.VerticalFlip(),
        A.RandomRotate90(),
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=45),
        A.Affine(shear=(-30, 30)),
        A.Rotate(limit=45),
        A.CoarseDropout(max_holes=8, max_height=8, max_width=8, fill_value=0, p=0.5),
        A.Transpose(),
    ], p=0.8),
    A.Normalize(mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])

# Validation pipeline: no random augmentation, only the same resize,
# normalisation and tensor conversion used for training.
val_transforms = A.Compose([
    A.Resize(224, 224),
    A.Normalize(mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])

#######################################
# Custom Dataset
#######################################
class CustomImageDataset(torch.utils.data.Dataset):
    """ImageFolder-indexed dataset that applies Albumentations-style transforms.

    ImageFolder is used only to discover ``(path, class_index)`` pairs; images
    are decoded here so they can be passed to the transform as numpy arrays
    through the ``image=`` keyword.

    Args:
        root_dir: Directory containing one sub-folder per class.
        transform: Optional callable invoked as ``transform(image=array)`` and
            returning a mapping with an ``'image'`` entry.
        label_map: Optional mapping from ImageFolder's class index to the
            label that should be returned instead.
    """

    def __init__(self, root_dir, transform=None, label_map=None):
        self.dataset = ImageFolder(root=root_dir)
        self.transform = transform
        self.label_map = label_map  # Dictionary for remapping labels

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is the RGB array (or whatever the transform produces from
        it); the label is remapped through ``label_map`` when one was given.
        """
        img_path, label = self.dataset.samples[idx]

        # Use a context manager so the underlying file handle is released
        # deterministically, even if decoding raises or the format keeps the
        # file open after loading.
        with Image.open(img_path) as img:
            image = np.array(img.convert('RGB'))

        if self.transform:
            transformed = self.transform(image=image)
            image = transformed['image']

        # Remap label using the provided mapping, if any
        if self.label_map is not None:
            label = self.label_map[label]

        return image, label

    def __len__(self):
        """Number of samples discovered by ImageFolder."""
        return len(self.dataset)

#######################################
# Load Datasets
#######################################
train_dataset = CustomImageDataset(root_dir=train_dir, transform=train_transforms, label_map=original_to_desired)
val_dataset = CustomImageDataset(root_dir=val_dir, transform=val_transforms, label_map=original_to_desired)

# The hard-coded index remapping is only valid when ImageFolder discovered
# exactly the class folders bcc, mel, scc in that (alphabetical) order.
# A different folder layout would otherwise silently train on wrong labels,
# so verify both splits up front.
expected_classes = ['bcc', 'mel', 'scc']
for split_name, split_dataset in (('train', train_dataset), ('val', val_dataset)):
    found_classes = [name.lower() for name in split_dataset.dataset.classes]
    if found_classes != expected_classes:
        raise ValueError(
            f"Unexpected class folders in the {split_name} set: "
            f"{split_dataset.dataset.class_to_idx}. Expected {expected_classes} "
            f"(in this order) for the label remapping to be valid."
        )

print("Class-to-Label Mapping for Training Dataset (original):")
print(train_dataset.dataset.class_to_idx)
print("Our desired mapping is mel=0, bcc=1, scc=2.")

#######################################
# Handle Class Imbalance with WeightedRandomSampler
#######################################
# Read the labels directly from ImageFolder's (path, class_index) list.
# Indexing train_dataset[i] instead would open, decode and augment every
# training image only to throw the pixels away.
train_label_map = train_dataset.label_map
all_train_labels = [
    raw_label if train_label_map is None else train_label_map[raw_label]
    for _, raw_label in train_dataset.dataset.samples
]

# class_counts[i] = number of training samples whose (remapped) label is i
# Example: mel=0, bcc=1, scc=2
class_counts = np.bincount(all_train_labels)

# Inverse-frequency weighting: samples of rarer classes are drawn more often.
class_weights = 1.0 / class_counts
# Create a weight for each sample based on its class
sample_weights = [class_weights[label] for label in all_train_labels]

# Draw one epoch's worth of samples, with replacement, according to the weights.
sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)

#######################################
# DataLoaders
#######################################
# One worker process per CPU core, minus one.
num_workers = multiprocessing.cpu_count() - 1

# Options shared by both loaders.
loader_options = dict(batch_size=32, num_workers=num_workers, pin_memory=True)

# Training batches are drawn through the weighted sampler (which replaces
# shuffling); validation is iterated in fixed order.
train_loader = DataLoader(train_dataset, sampler=sampler, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)

#######################################
# Model Setup
#######################################
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Two independent pretrained backbones, each with a fresh 3-way head
# (mel, bcc, scc). Their outputs are combined into an ensemble at evaluation.
model_swin = timm.create_model('swin_base_patch4_window7_224', pretrained=True, num_classes=3, drop_rate=0.2)
model_swin = model_swin.to(device)

model_regnet = timm.create_model('regnety_032', pretrained=True, num_classes=3, drop_rate=0.2)
model_regnet = model_regnet.to(device)

# Define loss and optimizer (one optimizer per model; the loss is shared)
criterion = nn.CrossEntropyLoss()
optimizer_swin = torch.optim.AdamW(model_swin.parameters(), lr=1e-4, weight_decay=1e-5)
optimizer_regnet = torch.optim.AdamW(model_regnet.parameters(), lr=1e-4, weight_decay=1e-5)

num_epochs = 100

# Learning rate schedulers (stepped once per epoch).
# T_max is tied to num_epochs so the cosine decays monotonically over the
# whole run. With a shorter T_max the learning rate would reach its minimum
# part-way through training and then climb back up towards the initial value.
scheduler_swin = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_swin, T_max=num_epochs)
scheduler_regnet = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_regnet, T_max=num_epochs)

# Best validation accuracies seen so far (in percent). The per-model values
# are the accuracies recorded at the epoch where the ensemble was best.
best_val_accuracy_swin = 0
best_val_accuracy_regnet = 0
best_val_accuracy_ensemble = 0

#######################################
# Drive folder to copy model
#######################################
# Change this path to your desired Google Drive folder:
drive_folder = '/content/drive/MyDrive/Three_Class/trained_models'
os.makedirs(drive_folder, exist_ok=True)  # Create folder if it doesn't exist

#######################################
# Training and Validation Loop
#######################################
# Per epoch:
#   1. train both models on the same batches (independent optimizers),
#   2. evaluate each model and their simple-average ensemble on the val set,
#   3. step both LR schedulers,
#   4. checkpoint both models whenever the ensemble accuracy improves.
for epoch in range(num_epochs):
    model_swin.train()
    model_regnet.train()
    running_loss_swin = 0.0
    running_loss_regnet = 0.0

    # Training loop
    for images, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        # Zero gradients
        optimizer_swin.zero_grad()
        optimizer_regnet.zero_grad()

        # Forward pass
        outputs_swin = model_swin(images)
        outputs_regnet = model_regnet(images)

        # Compute loss
        loss_swin = criterion(outputs_swin, labels)
        loss_regnet = criterion(outputs_regnet, labels)
        # The models share no parameters, so backpropagating the sum gives each
        # model exactly the gradient of its own loss in a single backward call.
        total_loss = loss_swin + loss_regnet

        # Backward + Optimize
        total_loss.backward()
        optimizer_swin.step()
        optimizer_regnet.step()

        running_loss_swin += loss_swin.item()
        running_loss_regnet += loss_regnet.item()

    # Average training losses (mean of the per-batch losses)
    avg_loss_swin = running_loss_swin / len(train_loader)
    avg_loss_regnet = running_loss_regnet / len(train_loader)
    print(f'Epoch [{epoch+1}/{num_epochs}] Training Loss - Swin: {avg_loss_swin:.4f}, RegNetY: {avg_loss_regnet:.4f}')

    # Validation
    model_swin.eval()
    model_regnet.eval()
    val_loss_swin = 0.0
    val_loss_regnet = 0.0
    correct_swin = 0
    correct_regnet = 0
    correct_ensemble = 0
    total = 0
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            outputs_swin = model_swin(images)
            outputs_regnet = model_regnet(images)

            loss_swin = criterion(outputs_swin, labels)
            loss_regnet = criterion(outputs_regnet, labels)
            val_loss_swin += loss_swin.item()
            val_loss_regnet += loss_regnet.item()

            # Predictions
            _, predicted_swin = torch.max(outputs_swin, 1)
            _, predicted_regnet = torch.max(outputs_regnet, 1)

            # Simple averaging for ensemble (mean of the two models' raw outputs)
            combined_outputs = (outputs_swin + outputs_regnet) / 2
            _, predicted_ensemble = torch.max(combined_outputs, 1)

            total += labels.size(0)
            correct_swin += (predicted_swin == labels).sum().item()
            correct_regnet += (predicted_regnet == labels).sum().item()
            correct_ensemble += (predicted_ensemble == labels).sum().item()

            # Keep labels and ensemble predictions for the epoch-level kappa
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(predicted_ensemble.cpu().numpy())

    # Validation metrics
    avg_val_loss_swin = val_loss_swin / len(val_loader)
    avg_val_loss_regnet = val_loss_regnet / len(val_loader)
    avg_val_loss = avg_val_loss_swin + avg_val_loss_regnet
    val_accuracy_swin = 100 * correct_swin / total
    val_accuracy_regnet = 100 * correct_regnet / total
    val_accuracy_ensemble = 100 * correct_ensemble / total

    # Calculate validation Kappa (using the ensemble predictions)
    val_kappa_ensemble = cohen_kappa_score(all_labels, all_preds)

    print(f'Validation Loss - Swin: {avg_val_loss_swin:.4f}, '
          f'RegNetY: {avg_val_loss_regnet:.4f}, Total: {avg_val_loss:.4f}')
    print(f'Validation Accuracy - Swin: {val_accuracy_swin:.2f}%, '
          f'RegNetY: {val_accuracy_regnet:.2f}%, '
          f'Ensemble: {val_accuracy_ensemble:.2f}%, '
          f'Kappa (Ensemble): {val_kappa_ensemble:.4f}')

    # Update schedulers
    scheduler_swin.step()
    scheduler_regnet.step()

    # Save best model based on the best ensemble accuracy
    if val_accuracy_ensemble > best_val_accuracy_ensemble:
        best_val_accuracy_ensemble = val_accuracy_ensemble
        # Note: the two 'best_val_accuracy_*' entries hold each model's accuracy
        # at this (best-ensemble) epoch, not each model's individual best.
        torch.save({
            'epoch': epoch + 1,
            'model_swin_state_dict': model_swin.state_dict(),
            'model_regnet_state_dict': model_regnet.state_dict(),
            'optimizer_swin_state_dict': optimizer_swin.state_dict(),
            'optimizer_regnet_state_dict': optimizer_regnet.state_dict(),
            'val_accuracy_ensemble': best_val_accuracy_ensemble,
            'best_val_accuracy_swin': val_accuracy_swin,
            'best_val_accuracy_regnet': val_accuracy_regnet,
        }, 'best_model.pth')

        # Copy the saved model to Drive. The copy is only a backup of the local
        # checkpoint, so a Drive failure (unmounted, quota, I/O error) is
        # reported but must not abort a long training run.
        drive_model_path = os.path.join(drive_folder, 'best_model.pth')
        try:
            shutil.copy('best_model.pth', drive_model_path)
            copy_error = None
        except OSError as exc:
            copy_error = exc

        print(f'Best model saved at epoch {epoch+1} with ensemble validation accuracy {best_val_accuracy_ensemble:.2f}%')
        if copy_error is None:
            print(f'Model also copied to: {drive_model_path}')
        else:
            print(f'WARNING: could not copy model to {drive_model_path}: {copy_error}')

        best_val_accuracy_swin = val_accuracy_swin
        best_val_accuracy_regnet = val_accuracy_regnet
    else:
        print('No improvement in ensemble validation accuracy.')

print('Training completed.')

#######################################
# Final Evaluation with Weighted Ensemble
#######################################
# Restore the weights of the best-ensemble epoch, then re-evaluate on the
# validation set with an ensemble weighted by each model's own accuracy.
checkpoint = torch.load('best_model.pth', map_location=device)
model_swin.load_state_dict(checkpoint['model_swin_state_dict'])
model_regnet.load_state_dict(checkpoint['model_regnet_state_dict'])
best_val_accuracy_swin = checkpoint['best_val_accuracy_swin']
best_val_accuracy_regnet = checkpoint['best_val_accuracy_regnet']

model_swin.eval()
model_regnet.eval()

all_labels = []
all_preds = []
correct_swin = 0
correct_regnet = 0
correct_ensemble = 0
total = 0

# Weighted ensemble based on best validation accuracies: each model's weight is
# its share of the summed accuracies (falls back to 50/50 if both are zero).
weight_sum = best_val_accuracy_swin + best_val_accuracy_regnet
if weight_sum == 0:
    weight_swin = 0.5
    weight_regnet = 0.5
else:
    weight_swin = best_val_accuracy_swin / weight_sum
    weight_regnet = best_val_accuracy_regnet / weight_sum

print(f'Using weighted ensemble: Swin Weight={weight_swin:.2f}, RegNetY Weight={weight_regnet:.2f}')

with torch.no_grad():
    for images, labels in val_loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        outputs_swin = model_swin(images)
        outputs_regnet = model_regnet(images)

        # Individual predictions
        _, predicted_swin = torch.max(outputs_swin, 1)
        _, predicted_regnet = torch.max(outputs_regnet, 1)

        # Weighted ensemble (weighted sum of the two models' raw outputs)
        weighted_outputs = weight_swin * outputs_swin + weight_regnet * outputs_regnet
        _, predicted_ensemble = torch.max(weighted_outputs, 1)

        total += labels.size(0)
        correct_swin += (predicted_swin == labels).sum().item()
        correct_regnet += (predicted_regnet == labels).sum().item()
        correct_ensemble += (predicted_ensemble == labels).sum().item()

        all_labels.extend(labels.cpu().numpy())
        all_preds.extend(predicted_ensemble.cpu().numpy())

final_accuracy_swin = 100 * correct_swin / total
final_accuracy_regnet = 100 * correct_regnet / total
final_accuracy_ensemble = 100 * correct_ensemble / total
print(f'Final Validation Accuracy - Swin: {final_accuracy_swin:.2f}%, '
      f'RegNetY: {final_accuracy_regnet:.2f}%, '
      f'Ensemble: {final_accuracy_ensemble:.2f}%')

# Pin the label set so the matrix is always 3x3 with rows/columns ordered
# mel (0), bcc (1), scc (2), even if a class is missing from both the labels
# and the predictions.
cm = confusion_matrix(all_labels, all_preds, labels=[0, 1, 2])
final_kappa_ensemble = cohen_kappa_score(all_labels, all_preds)
print('Confusion Matrix:')
print(cm)
print(f'Final Kappa (Ensemble): {final_kappa_ensemble:.4f}')


Class-to-Label Mapping for Training Dataset (original):
{'bcc': 0, 'mel': 1, 'scc': 2}
Our desired mapping is mel=0, bcc=1, scc=2.


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/353M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/78.1M [00:00<?, ?B/s]

Epoch 1/100: 100%|██████████| 159/159 [00:56<00:00,  2.80it/s]

Epoch [1/100] Training Loss - Swin: 0.7073, RegNetY: 0.8471





Validation Loss - Swin: 0.3999, RegNetY: 0.5178, Total: 0.9178
Validation Accuracy - Swin: 85.98%, RegNetY: 79.13%, Ensemble: 85.28%, Kappa (Ensemble): 0.7393
Best model saved at epoch 1 with ensemble validation accuracy 85.28%
Model also copied to: /content/drive/MyDrive/Three_Class/trained_models/best_model.pth


Epoch 2/100: 100%|██████████| 159/159 [00:56<00:00,  2.82it/s]

Epoch [2/100] Training Loss - Swin: 0.4402, RegNetY: 0.5319





Validation Loss - Swin: 0.4331, RegNetY: 0.4122, Total: 0.8453
Validation Accuracy - Swin: 84.65%, RegNetY: 83.62%, Ensemble: 87.32%, Kappa (Ensemble): 0.7801
Best model saved at epoch 2 with ensemble validation accuracy 87.32%
Model also copied to: /content/drive/MyDrive/Three_Class/trained_models/best_model.pth


Epoch 3/100: 100%|██████████| 159/159 [00:55<00:00,  2.88it/s]

Epoch [3/100] Training Loss - Swin: 0.3261, RegNetY: 0.4137





Validation Loss - Swin: 0.3815, RegNetY: 0.4428, Total: 0.8243
Validation Accuracy - Swin: 85.83%, RegNetY: 82.91%, Ensemble: 88.03%, Kappa (Ensemble): 0.7903
Best model saved at epoch 3 with ensemble validation accuracy 88.03%
Model also copied to: /content/drive/MyDrive/Three_Class/trained_models/best_model.pth


Epoch 4/100: 100%|██████████| 159/159 [00:55<00:00,  2.85it/s]

Epoch [4/100] Training Loss - Swin: 0.2663, RegNetY: 0.3163





Validation Loss - Swin: 0.3262, RegNetY: 0.3015, Total: 0.6277
Validation Accuracy - Swin: 88.43%, RegNetY: 88.35%, Ensemble: 91.26%, Kappa (Ensemble): 0.8447
Best model saved at epoch 4 with ensemble validation accuracy 91.26%
Model also copied to: /content/drive/MyDrive/Three_Class/trained_models/best_model.pth


Epoch 5/100: 100%|██████████| 159/159 [00:55<00:00,  2.87it/s]

Epoch [5/100] Training Loss - Swin: 0.2107, RegNetY: 0.2919





Validation Loss - Swin: 0.2362, RegNetY: 0.3512, Total: 0.5874
Validation Accuracy - Swin: 91.65%, RegNetY: 86.85%, Ensemble: 92.91%, Kappa (Ensemble): 0.8751
Best model saved at epoch 5 with ensemble validation accuracy 92.91%
Model also copied to: /content/drive/MyDrive/Three_Class/trained_models/best_model.pth


Epoch 6/100: 100%|██████████| 159/159 [00:55<00:00,  2.88it/s]

Epoch [6/100] Training Loss - Swin: 0.1736, RegNetY: 0.2270





Validation Loss - Swin: 0.2448, RegNetY: 0.2983, Total: 0.5431
Validation Accuracy - Swin: 91.97%, RegNetY: 88.90%, Ensemble: 92.91%, Kappa (Ensemble): 0.8723
No improvement in ensemble validation accuracy.


Epoch 7/100: 100%|██████████| 159/159 [00:52<00:00,  3.01it/s]

Epoch [7/100] Training Loss - Swin: 0.1378, RegNetY: 0.1994





Validation Loss - Swin: 0.2453, RegNetY: 0.2672, Total: 0.5125
Validation Accuracy - Swin: 92.13%, RegNetY: 90.39%, Ensemble: 93.39%, Kappa (Ensemble): 0.8825
Best model saved at epoch 7 with ensemble validation accuracy 93.39%
Model also copied to: /content/drive/MyDrive/Three_Class/trained_models/best_model.pth


Epoch 8/100: 100%|██████████| 159/159 [00:55<00:00,  2.86it/s]

Epoch [8/100] Training Loss - Swin: 0.1343, RegNetY: 0.1778





Validation Loss - Swin: 0.2009, RegNetY: 0.2307, Total: 0.4316
Validation Accuracy - Swin: 93.31%, RegNetY: 91.50%, Ensemble: 94.09%, Kappa (Ensemble): 0.8941
Best model saved at epoch 8 with ensemble validation accuracy 94.09%
Model also copied to: /content/drive/MyDrive/Three_Class/trained_models/best_model.pth


Epoch 9/100: 100%|██████████| 159/159 [00:55<00:00,  2.87it/s]

Epoch [9/100] Training Loss - Swin: 0.1372, RegNetY: 0.1737





Validation Loss - Swin: 0.1935, RegNetY: 0.2149, Total: 0.4085
Validation Accuracy - Swin: 93.07%, RegNetY: 92.60%, Ensemble: 95.35%, Kappa (Ensemble): 0.9168
Best model saved at epoch 9 with ensemble validation accuracy 95.35%
Model also copied to: /content/drive/MyDrive/Three_Class/trained_models/best_model.pth


Epoch 10/100: 100%|██████████| 159/159 [00:55<00:00,  2.85it/s]

Epoch [10/100] Training Loss - Swin: 0.0990, RegNetY: 0.1377





Validation Loss - Swin: 0.2238, RegNetY: 0.2240, Total: 0.4477
Validation Accuracy - Swin: 92.76%, RegNetY: 92.60%, Ensemble: 94.02%, Kappa (Ensemble): 0.8919
No improvement in ensemble validation accuracy.


Epoch 11/100: 100%|██████████| 159/159 [00:52<00:00,  3.00it/s]

Epoch [11/100] Training Loss - Swin: 0.0940, RegNetY: 0.1336





Validation Loss - Swin: 0.2309, RegNetY: 0.1972, Total: 0.4281
Validation Accuracy - Swin: 93.62%, RegNetY: 93.78%, Ensemble: 95.35%, Kappa (Ensemble): 0.9166
No improvement in ensemble validation accuracy.


Epoch 12/100: 100%|██████████| 159/159 [00:53<00:00,  3.00it/s]

Epoch [12/100] Training Loss - Swin: 0.0962, RegNetY: 0.1248





Validation Loss - Swin: 0.2144, RegNetY: 0.2195, Total: 0.4339
Validation Accuracy - Swin: 93.15%, RegNetY: 93.23%, Ensemble: 94.80%, Kappa (Ensemble): 0.9063
No improvement in ensemble validation accuracy.


Epoch 13/100: 100%|██████████| 159/159 [00:52<00:00,  3.01it/s]

Epoch [13/100] Training Loss - Swin: 0.0710, RegNetY: 0.1107





Validation Loss - Swin: 0.2165, RegNetY: 0.2395, Total: 0.4559
Validation Accuracy - Swin: 93.70%, RegNetY: 92.68%, Ensemble: 95.12%, Kappa (Ensemble): 0.9120
No improvement in ensemble validation accuracy.


Epoch 14/100: 100%|██████████| 159/159 [00:52<00:00,  3.00it/s]

Epoch [14/100] Training Loss - Swin: 0.0609, RegNetY: 0.0946





Validation Loss - Swin: 0.1804, RegNetY: 0.2028, Total: 0.3832
Validation Accuracy - Swin: 94.41%, RegNetY: 94.25%, Ensemble: 96.14%, Kappa (Ensemble): 0.9309
Best model saved at epoch 14 with ensemble validation accuracy 96.14%
Model also copied to: /content/drive/MyDrive/Three_Class/trained_models/best_model.pth


Epoch 15/100: 100%|██████████| 159/159 [00:55<00:00,  2.89it/s]

Epoch [15/100] Training Loss - Swin: 0.0612, RegNetY: 0.0880





Validation Loss - Swin: 0.2029, RegNetY: 0.2121, Total: 0.4150
Validation Accuracy - Swin: 94.57%, RegNetY: 93.62%, Ensemble: 95.43%, Kappa (Ensemble): 0.9177
No improvement in ensemble validation accuracy.


Epoch 16/100: 100%|██████████| 159/159 [00:52<00:00,  3.00it/s]

Epoch [16/100] Training Loss - Swin: 0.0694, RegNetY: 0.0862





Validation Loss - Swin: 0.2167, RegNetY: 0.2257, Total: 0.4424
Validation Accuracy - Swin: 94.96%, RegNetY: 93.78%, Ensemble: 95.43%, Kappa (Ensemble): 0.9180
No improvement in ensemble validation accuracy.


Epoch 17/100: 100%|██████████| 159/159 [00:52<00:00,  3.00it/s]

Epoch [17/100] Training Loss - Swin: 0.0592, RegNetY: 0.0744





Validation Loss - Swin: 0.1860, RegNetY: 0.2396, Total: 0.4255
Validation Accuracy - Swin: 94.49%, RegNetY: 93.78%, Ensemble: 95.12%, Kappa (Ensemble): 0.9123
No improvement in ensemble validation accuracy.


Epoch 18/100: 100%|██████████| 159/159 [00:52<00:00,  3.00it/s]

Epoch [18/100] Training Loss - Swin: 0.0450, RegNetY: 0.0755





Validation Loss - Swin: 0.2532, RegNetY: 0.2124, Total: 0.4656
Validation Accuracy - Swin: 94.57%, RegNetY: 94.33%, Ensemble: 95.83%, Kappa (Ensemble): 0.9244
No improvement in ensemble validation accuracy.


Epoch 19/100: 100%|██████████| 159/159 [00:52<00:00,  3.00it/s]

Epoch [19/100] Training Loss - Swin: 0.0606, RegNetY: 0.0724





Validation Loss - Swin: 0.1679, RegNetY: 0.2039, Total: 0.3718
Validation Accuracy - Swin: 95.59%, RegNetY: 93.62%, Ensemble: 95.98%, Kappa (Ensemble): 0.9277
No improvement in ensemble validation accuracy.


Epoch 20/100: 100%|██████████| 159/159 [00:52<00:00,  3.00it/s]

Epoch [20/100] Training Loss - Swin: 0.0447, RegNetY: 0.0738





Validation Loss - Swin: 0.2218, RegNetY: 0.1980, Total: 0.4199
Validation Accuracy - Swin: 93.70%, RegNetY: 93.94%, Ensemble: 96.14%, Kappa (Ensemble): 0.9305
No improvement in ensemble validation accuracy.


Epoch 21/100: 100%|██████████| 159/159 [00:53<00:00,  3.00it/s]

Epoch [21/100] Training Loss - Swin: 0.0431, RegNetY: 0.0516





Validation Loss - Swin: 0.1790, RegNetY: 0.2141, Total: 0.3931
Validation Accuracy - Swin: 95.51%, RegNetY: 94.25%, Ensemble: 95.43%, Kappa (Ensemble): 0.9175
No improvement in ensemble validation accuracy.


Epoch 22/100: 100%|██████████| 159/159 [00:52<00:00,  3.00it/s]

Epoch [22/100] Training Loss - Swin: 0.0304, RegNetY: 0.0585





Validation Loss - Swin: 0.2278, RegNetY: 0.2158, Total: 0.4436
Validation Accuracy - Swin: 94.72%, RegNetY: 94.25%, Ensemble: 96.06%, Kappa (Ensemble): 0.9290
No improvement in ensemble validation accuracy.


Epoch 23/100: 100%|██████████| 159/159 [00:52<00:00,  3.01it/s]

Epoch [23/100] Training Loss - Swin: 0.0248, RegNetY: 0.0499





Validation Loss - Swin: 0.2481, RegNetY: 0.2242, Total: 0.4723
Validation Accuracy - Swin: 94.49%, RegNetY: 94.09%, Ensemble: 95.83%, Kappa (Ensemble): 0.9248
No improvement in ensemble validation accuracy.


Epoch 24/100: 100%|██████████| 159/159 [00:53<00:00,  3.00it/s]

Epoch [24/100] Training Loss - Swin: 0.0412, RegNetY: 0.0544





Validation Loss - Swin: 0.1974, RegNetY: 0.2104, Total: 0.4078
Validation Accuracy - Swin: 94.80%, RegNetY: 94.72%, Ensemble: 96.14%, Kappa (Ensemble): 0.9303
No improvement in ensemble validation accuracy.


Epoch 25/100: 100%|██████████| 159/159 [00:53<00:00,  2.99it/s]

Epoch [25/100] Training Loss - Swin: 0.0322, RegNetY: 0.0456





Validation Loss - Swin: 0.1812, RegNetY: 0.1945, Total: 0.3758
Validation Accuracy - Swin: 94.88%, RegNetY: 94.41%, Ensemble: 96.85%, Kappa (Ensemble): 0.9433
Best model saved at epoch 25 with ensemble validation accuracy 96.85%
Model also copied to: /content/drive/MyDrive/Three_Class/trained_models/best_model.pth


Epoch 26/100: 100%|██████████| 159/159 [00:55<00:00,  2.86it/s]

Epoch [26/100] Training Loss - Swin: 0.0204, RegNetY: 0.0391





Validation Loss - Swin: 0.1725, RegNetY: 0.2133, Total: 0.3858
Validation Accuracy - Swin: 96.06%, RegNetY: 94.41%, Ensemble: 96.69%, Kappa (Ensemble): 0.9405
No improvement in ensemble validation accuracy.


Epoch 27/100: 100%|██████████| 159/159 [00:52<00:00,  3.00it/s]

Epoch [27/100] Training Loss - Swin: 0.0155, RegNetY: 0.0359





Validation Loss - Swin: 0.1846, RegNetY: 0.2169, Total: 0.4015
Validation Accuracy - Swin: 96.06%, RegNetY: 94.72%, Ensemble: 96.69%, Kappa (Ensemble): 0.9402
No improvement in ensemble validation accuracy.


Epoch 28/100: 100%|██████████| 159/159 [00:52<00:00,  3.00it/s]

Epoch [28/100] Training Loss - Swin: 0.0262, RegNetY: 0.0379





Validation Loss - Swin: 0.1725, RegNetY: 0.2025, Total: 0.3750
Validation Accuracy - Swin: 96.22%, RegNetY: 94.57%, Ensemble: 96.69%, Kappa (Ensemble): 0.9406
No improvement in ensemble validation accuracy.


Epoch 29/100: 100%|██████████| 159/159 [00:53<00:00,  3.00it/s]

Epoch [29/100] Training Loss - Swin: 0.0135, RegNetY: 0.0361





Validation Loss - Swin: 0.2000, RegNetY: 0.2060, Total: 0.4060
Validation Accuracy - Swin: 95.12%, RegNetY: 94.65%, Ensemble: 96.22%, Kappa (Ensemble): 0.9318
No improvement in ensemble validation accuracy.


Epoch 30/100: 100%|██████████| 159/159 [00:52<00:00,  3.00it/s]

Epoch [30/100] Training Loss - Swin: 0.0182, RegNetY: 0.0428





Validation Loss - Swin: 0.1807, RegNetY: 0.1985, Total: 0.3792
Validation Accuracy - Swin: 96.14%, RegNetY: 95.59%, Ensemble: 96.61%, Kappa (Ensemble): 0.9390
No improvement in ensemble validation accuracy.


Epoch 31/100: 100%|██████████| 159/159 [00:52<00:00,  3.00it/s]

Epoch [31/100] Training Loss - Swin: 0.0134, RegNetY: 0.0325





Validation Loss - Swin: 0.1613, RegNetY: 0.1936, Total: 0.3550
Validation Accuracy - Swin: 95.91%, RegNetY: 95.12%, Ensemble: 96.77%, Kappa (Ensemble): 0.9420
No improvement in ensemble validation accuracy.


Epoch 32/100: 100%|██████████| 159/159 [00:52<00:00,  3.00it/s]

Epoch [32/100] Training Loss - Swin: 0.0136, RegNetY: 0.0264





Validation Loss - Swin: 0.1840, RegNetY: 0.2003, Total: 0.3843
Validation Accuracy - Swin: 95.59%, RegNetY: 95.20%, Ensemble: 96.38%, Kappa (Ensemble): 0.9346
No improvement in ensemble validation accuracy.


Epoch 33/100: 100%|██████████| 159/159 [00:53<00:00,  2.99it/s]

Epoch [33/100] Training Loss - Swin: 0.0130, RegNetY: 0.0274





Validation Loss - Swin: 0.2103, RegNetY: 0.1942, Total: 0.4045
Validation Accuracy - Swin: 95.59%, RegNetY: 94.72%, Ensemble: 96.30%, Kappa (Ensemble): 0.9331
No improvement in ensemble validation accuracy.


Epoch 34/100: 100%|██████████| 159/159 [00:53<00:00,  3.00it/s]

Epoch [34/100] Training Loss - Swin: 0.0073, RegNetY: 0.0225





Validation Loss - Swin: 0.1883, RegNetY: 0.1891, Total: 0.3774
Validation Accuracy - Swin: 95.83%, RegNetY: 95.35%, Ensemble: 96.54%, Kappa (Ensemble): 0.9375
No improvement in ensemble validation accuracy.


Epoch 35/100: 100%|██████████| 159/159 [00:53<00:00,  2.99it/s]

Epoch [35/100] Training Loss - Swin: 0.0094, RegNetY: 0.0306





Validation Loss - Swin: 0.1676, RegNetY: 0.1978, Total: 0.3655
Validation Accuracy - Swin: 95.67%, RegNetY: 95.35%, Ensemble: 96.61%, Kappa (Ensemble): 0.9391
No improvement in ensemble validation accuracy.


Epoch 36/100: 100%|██████████| 159/159 [00:52<00:00,  3.00it/s]

Epoch [36/100] Training Loss - Swin: 0.0129, RegNetY: 0.0199





Validation Loss - Swin: 0.1568, RegNetY: 0.2025, Total: 0.3593
Validation Accuracy - Swin: 95.83%, RegNetY: 94.96%, Ensemble: 96.61%, Kappa (Ensemble): 0.9391
No improvement in ensemble validation accuracy.


Epoch 37/100: 100%|██████████| 159/159 [00:52<00:00,  3.00it/s]

Epoch [37/100] Training Loss - Swin: 0.0102, RegNetY: 0.0318





Validation Loss - Swin: 0.1812, RegNetY: 0.2014, Total: 0.3826
Validation Accuracy - Swin: 95.83%, RegNetY: 95.04%, Ensemble: 96.77%, Kappa (Ensemble): 0.9417
No improvement in ensemble validation accuracy.


Epoch 38/100: 100%|██████████| 159/159 [00:52<00:00,  3.00it/s]

Epoch [38/100] Training Loss - Swin: 0.0069, RegNetY: 0.0239





Validation Loss - Swin: 0.1629, RegNetY: 0.1964, Total: 0.3593
Validation Accuracy - Swin: 96.14%, RegNetY: 95.04%, Ensemble: 96.69%, Kappa (Ensemble): 0.9405
No improvement in ensemble validation accuracy.


Epoch 39/100: 100%|██████████| 159/159 [00:52<00:00,  3.00it/s]

Epoch [39/100] Training Loss - Swin: 0.0055, RegNetY: 0.0229





Validation Loss - Swin: 0.1805, RegNetY: 0.2085, Total: 0.3890
Validation Accuracy - Swin: 95.91%, RegNetY: 94.72%, Ensemble: 96.46%, Kappa (Ensemble): 0.9360
No improvement in ensemble validation accuracy.


Epoch 40/100: 100%|██████████| 159/159 [00:52<00:00,  3.01it/s]

Epoch [40/100] Training Loss - Swin: 0.0042, RegNetY: 0.0199





Validation Loss - Swin: 0.1828, RegNetY: 0.2130, Total: 0.3958
Validation Accuracy - Swin: 95.83%, RegNetY: 95.20%, Ensemble: 96.54%, Kappa (Ensemble): 0.9374
No improvement in ensemble validation accuracy.


Epoch 41/100: 100%|██████████| 159/159 [00:52<00:00,  3.00it/s]

Epoch [41/100] Training Loss - Swin: 0.0041, RegNetY: 0.0264





Validation Loss - Swin: 0.2203, RegNetY: 0.2041, Total: 0.4244
Validation Accuracy - Swin: 95.51%, RegNetY: 95.28%, Ensemble: 96.30%, Kappa (Ensemble): 0.9330
No improvement in ensemble validation accuracy.


Epoch 42/100: 100%|██████████| 159/159 [00:53<00:00,  3.00it/s]

Epoch [42/100] Training Loss - Swin: 0.0018, RegNetY: 0.0186





Validation Loss - Swin: 0.1869, RegNetY: 0.2080, Total: 0.3949
Validation Accuracy - Swin: 96.22%, RegNetY: 94.96%, Ensemble: 96.54%, Kappa (Ensemble): 0.9374
No improvement in ensemble validation accuracy.


Epoch 43/100: 100%|██████████| 159/159 [00:53<00:00,  3.00it/s]

Epoch [43/100] Training Loss - Swin: 0.0050, RegNetY: 0.0188





Validation Loss - Swin: 0.1957, RegNetY: 0.2086, Total: 0.4043
Validation Accuracy - Swin: 96.14%, RegNetY: 95.20%, Ensemble: 96.38%, Kappa (Ensemble): 0.9342
No improvement in ensemble validation accuracy.


Epoch 44/100: 100%|██████████| 159/159 [00:52<00:00,  3.01it/s]

Epoch [44/100] Training Loss - Swin: 0.0067, RegNetY: 0.0283





Validation Loss - Swin: 0.1923, RegNetY: 0.2079, Total: 0.4002
Validation Accuracy - Swin: 95.91%, RegNetY: 95.12%, Ensemble: 96.38%, Kappa (Ensemble): 0.9343
No improvement in ensemble validation accuracy.


Epoch 45/100: 100%|██████████| 159/159 [00:53<00:00,  2.98it/s]

Epoch [45/100] Training Loss - Swin: 0.0060, RegNetY: 0.0207





Validation Loss - Swin: 0.2034, RegNetY: 0.2050, Total: 0.4084
Validation Accuracy - Swin: 95.67%, RegNetY: 94.80%, Ensemble: 96.38%, Kappa (Ensemble): 0.9344
No improvement in ensemble validation accuracy.


Epoch 46/100: 100%|██████████| 159/159 [00:53<00:00,  3.00it/s]

Epoch [46/100] Training Loss - Swin: 0.0019, RegNetY: 0.0155





Validation Loss - Swin: 0.1868, RegNetY: 0.2021, Total: 0.3889
Validation Accuracy - Swin: 95.59%, RegNetY: 95.20%, Ensemble: 96.69%, Kappa (Ensemble): 0.9402
No improvement in ensemble validation accuracy.


Epoch 47/100: 100%|██████████| 159/159 [00:53<00:00,  3.00it/s]

Epoch [47/100] Training Loss - Swin: 0.0029, RegNetY: 0.0216





Validation Loss - Swin: 0.1840, RegNetY: 0.2046, Total: 0.3886
Validation Accuracy - Swin: 95.67%, RegNetY: 95.28%, Ensemble: 96.85%, Kappa (Ensemble): 0.9431
No improvement in ensemble validation accuracy.


Epoch 48/100: 100%|██████████| 159/159 [00:52<00:00,  3.00it/s]

Epoch [48/100] Training Loss - Swin: 0.0013, RegNetY: 0.0217





Validation Loss - Swin: 0.1802, RegNetY: 0.1970, Total: 0.3772
Validation Accuracy - Swin: 95.83%, RegNetY: 95.43%, Ensemble: 96.69%, Kappa (Ensemble): 0.9402
No improvement in ensemble validation accuracy.


Epoch 49/100: 100%|██████████| 159/159 [00:53<00:00,  3.00it/s]

Epoch [49/100] Training Loss - Swin: 0.0034, RegNetY: 0.0189





Validation Loss - Swin: 0.1793, RegNetY: 0.1984, Total: 0.3777
Validation Accuracy - Swin: 95.75%, RegNetY: 95.12%, Ensemble: 96.69%, Kappa (Ensemble): 0.9402
No improvement in ensemble validation accuracy.


Epoch 50/100: 100%|██████████| 159/159 [00:52<00:00,  3.00it/s]

Epoch [50/100] Training Loss - Swin: 0.0044, RegNetY: 0.0250





Validation Loss - Swin: 0.1797, RegNetY: 0.1995, Total: 0.3792
Validation Accuracy - Swin: 95.83%, RegNetY: 95.35%, Ensemble: 96.61%, Kappa (Ensemble): 0.9387
No improvement in ensemble validation accuracy.


Epoch 51/100: 100%|██████████| 159/159 [00:53<00:00,  3.00it/s]

Epoch [51/100] Training Loss - Swin: 0.0033, RegNetY: 0.0191





Validation Loss - Swin: 0.1797, RegNetY: 0.1983, Total: 0.3780
Validation Accuracy - Swin: 95.83%, RegNetY: 95.35%, Ensemble: 96.69%, Kappa (Ensemble): 0.9402
No improvement in ensemble validation accuracy.


Epoch 52/100: 100%|██████████| 159/159 [00:53<00:00,  3.00it/s]

Epoch [52/100] Training Loss - Swin: 0.0025, RegNetY: 0.0180





Validation Loss - Swin: 0.1798, RegNetY: 0.2015, Total: 0.3813
Validation Accuracy - Swin: 95.83%, RegNetY: 95.12%, Ensemble: 96.77%, Kappa (Ensemble): 0.9416
No improvement in ensemble validation accuracy.


Epoch 53/100: 100%|██████████| 159/159 [00:53<00:00,  3.00it/s]

Epoch [53/100] Training Loss - Swin: 0.0060, RegNetY: 0.0165





Validation Loss - Swin: 0.1773, RegNetY: 0.1998, Total: 0.3772
Validation Accuracy - Swin: 95.75%, RegNetY: 95.67%, Ensemble: 96.69%, Kappa (Ensemble): 0.9402
No improvement in ensemble validation accuracy.


Epoch 54/100: 100%|██████████| 159/159 [00:52<00:00,  3.00it/s]

Epoch [54/100] Training Loss - Swin: 0.0013, RegNetY: 0.0192





Validation Loss - Swin: 0.1789, RegNetY: 0.2029, Total: 0.3818
Validation Accuracy - Swin: 95.67%, RegNetY: 95.20%, Ensemble: 96.46%, Kappa (Ensemble): 0.9358
No improvement in ensemble validation accuracy.


Epoch 55/100: 100%|██████████| 159/159 [00:53<00:00,  2.98it/s]

Epoch [55/100] Training Loss - Swin: 0.0019, RegNetY: 0.0163





Validation Loss - Swin: 0.1728, RegNetY: 0.1950, Total: 0.3678
Validation Accuracy - Swin: 95.98%, RegNetY: 95.28%, Ensemble: 96.69%, Kappa (Ensemble): 0.9402
No improvement in ensemble validation accuracy.


Epoch 56/100: 100%|██████████| 159/159 [00:53<00:00,  3.00it/s]

Epoch [56/100] Training Loss - Swin: 0.0028, RegNetY: 0.0173





Validation Loss - Swin: 0.1812, RegNetY: 0.2040, Total: 0.3852
Validation Accuracy - Swin: 95.91%, RegNetY: 95.20%, Ensemble: 96.69%, Kappa (Ensemble): 0.9402
No improvement in ensemble validation accuracy.


Epoch 57/100: 100%|██████████| 159/159 [00:53<00:00,  3.00it/s]

Epoch [57/100] Training Loss - Swin: 0.0015, RegNetY: 0.0196





Validation Loss - Swin: 0.1867, RegNetY: 0.2030, Total: 0.3897
Validation Accuracy - Swin: 96.22%, RegNetY: 95.59%, Ensemble: 96.46%, Kappa (Ensemble): 0.9358
No improvement in ensemble validation accuracy.


Epoch 58/100: 100%|██████████| 159/159 [00:52<00:00,  3.00it/s]

Epoch [58/100] Training Loss - Swin: 0.0030, RegNetY: 0.0205





Validation Loss - Swin: 0.2016, RegNetY: 0.2028, Total: 0.4044
Validation Accuracy - Swin: 95.91%, RegNetY: 95.43%, Ensemble: 96.22%, Kappa (Ensemble): 0.9314
No improvement in ensemble validation accuracy.


Epoch 59/100: 100%|██████████| 159/159 [00:52<00:00,  3.00it/s]

Epoch [59/100] Training Loss - Swin: 0.0031, RegNetY: 0.0188





Validation Loss - Swin: 0.2070, RegNetY: 0.2067, Total: 0.4137
Validation Accuracy - Swin: 95.91%, RegNetY: 95.04%, Ensemble: 96.38%, Kappa (Ensemble): 0.9343
No improvement in ensemble validation accuracy.


Epoch 60/100: 100%|██████████| 159/159 [00:53<00:00,  2.99it/s]

Epoch [60/100] Training Loss - Swin: 0.0038, RegNetY: 0.0186





Validation Loss - Swin: 0.1915, RegNetY: 0.2045, Total: 0.3960
Validation Accuracy - Swin: 96.14%, RegNetY: 95.12%, Ensemble: 96.46%, Kappa (Ensemble): 0.9357
No improvement in ensemble validation accuracy.


Epoch 61/100: 100%|██████████| 159/159 [00:53<00:00,  3.00it/s]

Epoch [61/100] Training Loss - Swin: 0.0013, RegNetY: 0.0150





Validation Loss - Swin: 0.1904, RegNetY: 0.2085, Total: 0.3989
Validation Accuracy - Swin: 95.98%, RegNetY: 95.04%, Ensemble: 96.54%, Kappa (Ensemble): 0.9373
No improvement in ensemble validation accuracy.


Epoch 62/100: 100%|██████████| 159/159 [00:53<00:00,  3.00it/s]

Epoch [62/100] Training Loss - Swin: 0.0049, RegNetY: 0.0211





Validation Loss - Swin: 0.1965, RegNetY: 0.2048, Total: 0.4013
Validation Accuracy - Swin: 96.06%, RegNetY: 95.35%, Ensemble: 96.46%, Kappa (Ensemble): 0.9361
No improvement in ensemble validation accuracy.


Epoch 63/100: 100%|██████████| 159/159 [00:53<00:00,  3.00it/s]

Epoch [63/100] Training Loss - Swin: 0.0069, RegNetY: 0.0160





Validation Loss - Swin: 0.1746, RegNetY: 0.2099, Total: 0.3846
Validation Accuracy - Swin: 96.30%, RegNetY: 94.88%, Ensemble: 96.85%, Kappa (Ensemble): 0.9432
No improvement in ensemble validation accuracy.


Epoch 64/100: 100%|██████████| 159/159 [00:52<00:00,  3.00it/s]

Epoch [64/100] Training Loss - Swin: 0.0080, RegNetY: 0.0145





Validation Loss - Swin: 0.1996, RegNetY: 0.1993, Total: 0.3989
Validation Accuracy - Swin: 96.14%, RegNetY: 95.35%, Ensemble: 96.54%, Kappa (Ensemble): 0.9374
No improvement in ensemble validation accuracy.


Epoch 65/100: 100%|██████████| 159/159 [00:53<00:00,  2.99it/s]

Epoch [65/100] Training Loss - Swin: 0.0058, RegNetY: 0.0193





Validation Loss - Swin: 0.2121, RegNetY: 0.2068, Total: 0.4189
Validation Accuracy - Swin: 95.75%, RegNetY: 94.72%, Ensemble: 96.77%, Kappa (Ensemble): 0.9415
No improvement in ensemble validation accuracy.


Epoch 66/100: 100%|██████████| 159/159 [00:53<00:00,  2.99it/s]

Epoch [66/100] Training Loss - Swin: 0.0090, RegNetY: 0.0209





Validation Loss - Swin: 0.1982, RegNetY: 0.1930, Total: 0.3912
Validation Accuracy - Swin: 95.98%, RegNetY: 95.43%, Ensemble: 96.85%, Kappa (Ensemble): 0.9434
No improvement in ensemble validation accuracy.


Epoch 67/100: 100%|██████████| 159/159 [00:52<00:00,  3.00it/s]

Epoch [67/100] Training Loss - Swin: 0.0126, RegNetY: 0.0182





Validation Loss - Swin: 0.2308, RegNetY: 0.1891, Total: 0.4199
Validation Accuracy - Swin: 95.67%, RegNetY: 95.43%, Ensemble: 96.06%, Kappa (Ensemble): 0.9287
No improvement in ensemble validation accuracy.


Epoch 68/100: 100%|██████████| 159/159 [00:53<00:00,  3.00it/s]

Epoch [68/100] Training Loss - Swin: 0.0136, RegNetY: 0.0236





Validation Loss - Swin: 0.1629, RegNetY: 0.1960, Total: 0.3589
Validation Accuracy - Swin: 95.83%, RegNetY: 95.12%, Ensemble: 96.61%, Kappa (Ensemble): 0.9388
No improvement in ensemble validation accuracy.


Epoch 69/100: 100%|██████████| 159/159 [00:53<00:00,  3.00it/s]

Epoch [69/100] Training Loss - Swin: 0.0125, RegNetY: 0.0213





Validation Loss - Swin: 0.1696, RegNetY: 0.2106, Total: 0.3802
Validation Accuracy - Swin: 96.14%, RegNetY: 94.80%, Ensemble: 96.38%, Kappa (Ensemble): 0.9348
No improvement in ensemble validation accuracy.


Epoch 70/100: 100%|██████████| 159/159 [00:53<00:00,  3.00it/s]

Epoch [70/100] Training Loss - Swin: 0.0171, RegNetY: 0.0179





Validation Loss - Swin: 0.1701, RegNetY: 0.2217, Total: 0.3917
Validation Accuracy - Swin: 95.75%, RegNetY: 94.88%, Ensemble: 96.54%, Kappa (Ensemble): 0.9374
No improvement in ensemble validation accuracy.


Epoch 71/100: 100%|██████████| 159/159 [00:53<00:00,  3.00it/s]

Epoch [71/100] Training Loss - Swin: 0.0177, RegNetY: 0.0196





Validation Loss - Swin: 0.1960, RegNetY: 0.1785, Total: 0.3745
Validation Accuracy - Swin: 95.35%, RegNetY: 95.91%, Ensemble: 96.46%, Kappa (Ensemble): 0.9359
No improvement in ensemble validation accuracy.


Epoch 72/100: 100%|██████████| 159/159 [00:53<00:00,  3.00it/s]

Epoch [72/100] Training Loss - Swin: 0.0242, RegNetY: 0.0236





Validation Loss - Swin: 0.1996, RegNetY: 0.2178, Total: 0.4174
Validation Accuracy - Swin: 94.96%, RegNetY: 94.80%, Ensemble: 96.06%, Kappa (Ensemble): 0.9284
No improvement in ensemble validation accuracy.


Epoch 73/100: 100%|██████████| 159/159 [00:53<00:00,  3.00it/s]

Epoch [73/100] Training Loss - Swin: 0.0191, RegNetY: 0.0250





Validation Loss - Swin: 0.2078, RegNetY: 0.2018, Total: 0.4096
Validation Accuracy - Swin: 94.96%, RegNetY: 94.96%, Ensemble: 95.51%, Kappa (Ensemble): 0.9185
No improvement in ensemble validation accuracy.


Epoch 74/100: 100%|██████████| 159/159 [00:52<00:00,  3.00it/s]

Epoch [74/100] Training Loss - Swin: 0.0193, RegNetY: 0.0257





Validation Loss - Swin: 0.2294, RegNetY: 0.2084, Total: 0.4377
Validation Accuracy - Swin: 94.65%, RegNetY: 95.51%, Ensemble: 95.83%, Kappa (Ensemble): 0.9242
No improvement in ensemble validation accuracy.


Epoch 75/100: 100%|██████████| 159/159 [00:52<00:00,  3.00it/s]

Epoch [75/100] Training Loss - Swin: 0.0256, RegNetY: 0.0221





Validation Loss - Swin: 0.2237, RegNetY: 0.2127, Total: 0.4364
Validation Accuracy - Swin: 95.12%, RegNetY: 94.25%, Ensemble: 96.61%, Kappa (Ensemble): 0.9393
No improvement in ensemble validation accuracy.


Epoch 76/100: 100%|██████████| 159/159 [00:53<00:00,  2.99it/s]

Epoch [76/100] Training Loss - Swin: 0.0205, RegNetY: 0.0286





Validation Loss - Swin: 0.2301, RegNetY: 0.2274, Total: 0.4576
Validation Accuracy - Swin: 93.94%, RegNetY: 94.25%, Ensemble: 96.06%, Kappa (Ensemble): 0.9290
No improvement in ensemble validation accuracy.


Epoch 77/100: 100%|██████████| 159/159 [00:52<00:00,  3.00it/s]

Epoch [77/100] Training Loss - Swin: 0.0308, RegNetY: 0.0362





Validation Loss - Swin: 0.1868, RegNetY: 0.2313, Total: 0.4182
Validation Accuracy - Swin: 95.12%, RegNetY: 94.65%, Ensemble: 96.06%, Kappa (Ensemble): 0.9291
No improvement in ensemble validation accuracy.


Epoch 78/100: 100%|██████████| 159/159 [00:53<00:00,  3.00it/s]

Epoch [78/100] Training Loss - Swin: 0.0308, RegNetY: 0.0294





Validation Loss - Swin: 0.2533, RegNetY: 0.2229, Total: 0.4762
Validation Accuracy - Swin: 93.31%, RegNetY: 94.80%, Ensemble: 96.22%, Kappa (Ensemble): 0.9323
No improvement in ensemble validation accuracy.


Epoch 79/100: 100%|██████████| 159/159 [00:53<00:00,  3.00it/s]

Epoch [79/100] Training Loss - Swin: 0.0261, RegNetY: 0.0309





Validation Loss - Swin: 0.1767, RegNetY: 0.2103, Total: 0.3870
Validation Accuracy - Swin: 95.28%, RegNetY: 94.41%, Ensemble: 96.38%, Kappa (Ensemble): 0.9348
No improvement in ensemble validation accuracy.


Epoch 80/100: 100%|██████████| 159/159 [00:52<00:00,  3.00it/s]

Epoch [80/100] Training Loss - Swin: 0.0318, RegNetY: 0.0406





Validation Loss - Swin: 0.2275, RegNetY: 0.2193, Total: 0.4468
Validation Accuracy - Swin: 93.31%, RegNetY: 93.39%, Ensemble: 95.75%, Kappa (Ensemble): 0.9241
No improvement in ensemble validation accuracy.


Epoch 81/100: 100%|██████████| 159/159 [00:52<00:00,  3.00it/s]

Epoch [81/100] Training Loss - Swin: 0.0259, RegNetY: 0.0261





Validation Loss - Swin: 0.1918, RegNetY: 0.2208, Total: 0.4126
Validation Accuracy - Swin: 95.67%, RegNetY: 94.57%, Ensemble: 96.46%, Kappa (Ensemble): 0.9360
No improvement in ensemble validation accuracy.


Epoch 82/100: 100%|██████████| 159/159 [00:53<00:00,  2.99it/s]

Epoch [82/100] Training Loss - Swin: 0.0468, RegNetY: 0.0379





Validation Loss - Swin: 0.2074, RegNetY: 0.2373, Total: 0.4447
Validation Accuracy - Swin: 94.09%, RegNetY: 93.94%, Ensemble: 95.59%, Kappa (Ensemble): 0.9204
No improvement in ensemble validation accuracy.


Epoch 83/100: 100%|██████████| 159/159 [00:53<00:00,  3.00it/s]

Epoch [83/100] Training Loss - Swin: 0.0438, RegNetY: 0.0315





Validation Loss - Swin: 0.1764, RegNetY: 0.2279, Total: 0.4043
Validation Accuracy - Swin: 95.83%, RegNetY: 94.02%, Ensemble: 96.14%, Kappa (Ensemble): 0.9303
No improvement in ensemble validation accuracy.


Epoch 84/100: 100%|██████████| 159/159 [00:53<00:00,  3.00it/s]

Epoch [84/100] Training Loss - Swin: 0.0460, RegNetY: 0.0364





Validation Loss - Swin: 0.3353, RegNetY: 0.2323, Total: 0.5676
Validation Accuracy - Swin: 93.54%, RegNetY: 94.02%, Ensemble: 95.28%, Kappa (Ensemble): 0.9140
No improvement in ensemble validation accuracy.


Epoch 85/100: 100%|██████████| 159/159 [00:52<00:00,  3.01it/s]

Epoch [85/100] Training Loss - Swin: 0.0459, RegNetY: 0.0317





Validation Loss - Swin: 0.1716, RegNetY: 0.2506, Total: 0.4222
Validation Accuracy - Swin: 95.28%, RegNetY: 93.78%, Ensemble: 95.67%, Kappa (Ensemble): 0.9218
No improvement in ensemble validation accuracy.


Epoch 86/100: 100%|██████████| 159/159 [00:52<00:00,  3.00it/s]

Epoch [86/100] Training Loss - Swin: 0.0251, RegNetY: 0.0256





Validation Loss - Swin: 0.1902, RegNetY: 0.2210, Total: 0.4112
Validation Accuracy - Swin: 95.91%, RegNetY: 94.65%, Ensemble: 96.22%, Kappa (Ensemble): 0.9315
No improvement in ensemble validation accuracy.


Epoch 87/100: 100%|██████████| 159/159 [00:53<00:00,  2.99it/s]

Epoch [87/100] Training Loss - Swin: 0.0507, RegNetY: 0.0258





Validation Loss - Swin: 0.1902, RegNetY: 0.2586, Total: 0.4487
Validation Accuracy - Swin: 94.65%, RegNetY: 94.17%, Ensemble: 95.98%, Kappa (Ensemble): 0.9273
No improvement in ensemble validation accuracy.


Epoch 88/100: 100%|██████████| 159/159 [00:52<00:00,  3.00it/s]

Epoch [88/100] Training Loss - Swin: 0.0408, RegNetY: 0.0414





Validation Loss - Swin: 0.1736, RegNetY: 0.2289, Total: 0.4025
Validation Accuracy - Swin: 95.12%, RegNetY: 93.86%, Ensemble: 95.59%, Kappa (Ensemble): 0.9205
No improvement in ensemble validation accuracy.


Epoch 89/100: 100%|██████████| 159/159 [00:52<00:00,  3.00it/s]

Epoch [89/100] Training Loss - Swin: 0.0551, RegNetY: 0.0354





Validation Loss - Swin: 0.1994, RegNetY: 0.2048, Total: 0.4042
Validation Accuracy - Swin: 94.49%, RegNetY: 94.57%, Ensemble: 96.30%, Kappa (Ensemble): 0.9328
No improvement in ensemble validation accuracy.


Epoch 90/100: 100%|██████████| 159/159 [00:52<00:00,  3.00it/s]

Epoch [90/100] Training Loss - Swin: 0.0384, RegNetY: 0.0297





Validation Loss - Swin: 0.2708, RegNetY: 0.2353, Total: 0.5061
Validation Accuracy - Swin: 93.78%, RegNetY: 94.57%, Ensemble: 95.43%, Kappa (Ensemble): 0.9171
No improvement in ensemble validation accuracy.


Epoch 91/100: 100%|██████████| 159/159 [00:52<00:00,  3.00it/s]

Epoch [91/100] Training Loss - Swin: 0.0486, RegNetY: 0.0398





Validation Loss - Swin: 0.3068, RegNetY: 0.2774, Total: 0.5842
Validation Accuracy - Swin: 93.07%, RegNetY: 92.83%, Ensemble: 95.20%, Kappa (Ensemble): 0.9126
No improvement in ensemble validation accuracy.


Epoch 92/100: 100%|██████████| 159/159 [00:53<00:00,  3.00it/s]

Epoch [92/100] Training Loss - Swin: 0.0455, RegNetY: 0.0341





Validation Loss - Swin: 0.2284, RegNetY: 0.2609, Total: 0.4893
Validation Accuracy - Swin: 94.41%, RegNetY: 93.70%, Ensemble: 95.20%, Kappa (Ensemble): 0.9129
No improvement in ensemble validation accuracy.


Epoch 93/100: 100%|██████████| 159/159 [00:53<00:00,  2.99it/s]

Epoch [93/100] Training Loss - Swin: 0.0479, RegNetY: 0.0396





Validation Loss - Swin: 0.2135, RegNetY: 0.2528, Total: 0.4664
Validation Accuracy - Swin: 94.41%, RegNetY: 94.02%, Ensemble: 95.67%, Kappa (Ensemble): 0.9225
No improvement in ensemble validation accuracy.


Epoch 94/100: 100%|██████████| 159/159 [00:52<00:00,  3.00it/s]

Epoch [94/100] Training Loss - Swin: 0.0346, RegNetY: 0.0325





Validation Loss - Swin: 0.2201, RegNetY: 0.2368, Total: 0.4568
Validation Accuracy - Swin: 94.17%, RegNetY: 94.09%, Ensemble: 95.67%, Kappa (Ensemble): 0.9218
No improvement in ensemble validation accuracy.


Epoch 95/100: 100%|██████████| 159/159 [00:53<00:00,  3.00it/s]

Epoch [95/100] Training Loss - Swin: 0.0480, RegNetY: 0.0268





Validation Loss - Swin: 0.1758, RegNetY: 0.2352, Total: 0.4109
Validation Accuracy - Swin: 95.04%, RegNetY: 94.65%, Ensemble: 96.38%, Kappa (Ensemble): 0.9349
No improvement in ensemble validation accuracy.


Epoch 96/100: 100%|██████████| 159/159 [00:53<00:00,  3.00it/s]

Epoch [96/100] Training Loss - Swin: 0.0450, RegNetY: 0.0329





Validation Loss - Swin: 0.1711, RegNetY: 0.2501, Total: 0.4212
Validation Accuracy - Swin: 95.20%, RegNetY: 93.94%, Ensemble: 96.06%, Kappa (Ensemble): 0.9295
No improvement in ensemble validation accuracy.


Epoch 97/100: 100%|██████████| 159/159 [00:53<00:00,  3.00it/s]

Epoch [97/100] Training Loss - Swin: 0.0467, RegNetY: 0.0356





Validation Loss - Swin: 0.1890, RegNetY: 0.2493, Total: 0.4383
Validation Accuracy - Swin: 95.35%, RegNetY: 94.17%, Ensemble: 95.83%, Kappa (Ensemble): 0.9250
No improvement in ensemble validation accuracy.


Epoch 98/100: 100%|██████████| 159/159 [00:53<00:00,  2.99it/s]

Epoch [98/100] Training Loss - Swin: 0.0400, RegNetY: 0.0508





Validation Loss - Swin: 0.2945, RegNetY: 0.2258, Total: 0.5203
Validation Accuracy - Swin: 92.99%, RegNetY: 94.72%, Ensemble: 94.49%, Kappa (Ensemble): 0.9004
No improvement in ensemble validation accuracy.


Epoch 99/100: 100%|██████████| 159/159 [00:52<00:00,  3.00it/s]

Epoch [99/100] Training Loss - Swin: 0.0475, RegNetY: 0.0345





Validation Loss - Swin: 0.2078, RegNetY: 0.2139, Total: 0.4217
Validation Accuracy - Swin: 93.86%, RegNetY: 94.49%, Ensemble: 96.30%, Kappa (Ensemble): 0.9330
No improvement in ensemble validation accuracy.


Epoch 100/100: 100%|██████████| 159/159 [00:53<00:00,  2.99it/s]

Epoch [100/100] Training Loss - Swin: 0.0309, RegNetY: 0.0354





Validation Loss - Swin: 0.2674, RegNetY: 0.2491, Total: 0.5165
Validation Accuracy - Swin: 93.62%, RegNetY: 94.57%, Ensemble: 95.20%, Kappa (Ensemble): 0.9125
No improvement in ensemble validation accuracy.
Training completed.


  checkpoint = torch.load('best_model.pth', map_location=device)


Using weighted ensemble: Swin Weight=0.50, RegNetY Weight=0.50
Final Validation Accuracy - Swin: 94.88%, RegNetY: 94.41%, Ensemble: 96.93%
Confusion Matrix:
[[659  16   3]
 [  8 488   2]
 [  3   7  84]]
Final Kappa (Ensemble): 0.9447
