# Initialization

## Import libraries

In [1]:
import warnings
warnings.filterwarnings('ignore')

import torch
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import cv2
import os
import seaborn as sns

from sklearn.metrics import confusion_matrix
from sklearn.model_selection import StratifiedKFold

import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader

from tqdm.auto import tqdm

import ssl # Quickfix to torchaudio ssl error
ssl._create_default_https_context = ssl._create_unverified_context

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Preprocessing

## Helper Function

In [2]:
def preprocessing(image):
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # Convert to RGB
    image = cv2.resize(image, (224, 224))  # Resize to 224x224
    image = image / 255.0  # Normalize to [0, 1]
    image = np.transpose(image, (2, 0, 1))  # Convert to (C, H, W)
    image = torch.tensor(image, dtype=torch.float32)
    return image

def show_image(dataloader, index):
    # Get a batch of data
    data_iter = iter(dataloader)
    images, labels = next(data_iter)

    # Ensure the index is within the batch size
    batch_size = images.size(0)
    if index >= batch_size:
        raise IndexError(f"Index {index} is out of bounds for batch size {batch_size}")

    # Get the image and label at the specified index within the batch
    image = images[index]
    label = labels[index]

    # If images were normalized, we might need to denormalize them
    # For example, if we used transforms.Normalize(mean, std), we need to unnormalize
    # Replace these mean and std values with those used in your transforms
    mean = torch.tensor([0.485, 0.456, 0.406])
    std = torch.tensor([0.229, 0.224, 0.225])
    image = image * std[:, None, None] + mean[:, None, None]

    # Convert tensor to numpy array
    image_np = image.numpy().transpose((1, 2, 0))

    # Clip values to [0,1] if necessary
    image_np = np.clip(image_np, 0, 1)

    plt.figure(figsize=(6, 6))
    plt.title(f"Label: {label.item()}")  # Use label name if available
    plt.imshow(image_np)
    plt.axis('off')  # Hide axis ticks
    plt.show()

## Custom Dataset

In [3]:
import os
import cv2
import torch
from torch.utils.data import Dataset
import torch.nn.functional as F
from torchvision import transforms

class Resize:
    def __init__(self, size):
        self.size = size  # (h, w)

    def __call__(self, image):
        image = F.interpolate(image.unsqueeze(0), size=self.size, mode='bilinear', align_corners=False)
        return image.squeeze(0)

class SkinDataset(Dataset):
    def __init__(self, root_dir, transform=None,group_mapping=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []
        self.label_to_index = {}
        self.group_mapping = group_mapping if group_mapping else {}
        
        self._build_label_index()

    def _build_label_index(self):
        # List all folder names in the root directory
        label_names = sorted([
            d for d in os.listdir(self.root_dir)
            if os.path.isdir(os.path.join(self.root_dir, d))
        ])
    
        # Map original labels to groups (if specified)
        # Fallback to original label if not in group_mapping
        mapped_labels = [self.group_mapping.get(label, label) for label in label_names]
    
        # Create label-to-index mapping
        unique_labels = sorted(set(mapped_labels))
        self.label_to_index = {label: idx for idx, label in enumerate(unique_labels)}
    
        # Populate image paths and labels
        for label_name in label_names:
            label_dir = os.path.join(self.root_dir, label_name)
            # Map to new group label (or fallback to original)
            mapped_label = self.group_mapping.get(label_name, label_name)
            label_index = self.label_to_index[mapped_label]
    
            for img_name in os.listdir(label_dir):
                img_path = os.path.join(label_dir, img_name)
                if img_name.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif')):
                    self.image_paths.append(img_path)
                    self.labels.append(label_index)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]

        image = cv2.imread(img_path)
        if image is None:
            raise ValueError(f"Failed to load image at path: {img_path}")

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0  # Shape: (C, H, W)

        if self.transform:
            image = self.transform(image)
        else:
            image = F.interpolate(image.unsqueeze(0), size=(224, 224), mode='bilinear', align_corners=False).squeeze(0)

        return image, label

In [4]:
from torchvision import transforms

group_mappings = {
    "Chickenpox": "non-cancer",
    "Cowpox": "non-cancer",
    "HFMD": "non-cancer",
    "Healthy": "non-cancer",
    "Measles": "non-cancer",
    "Monkeypox": "non-cancer",
    # Add more mappings here if needed
}

data_transforms = transforms.Compose([
    Resize((224, 224)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

train_dataset = SkinDataset(root_dir='/kaggle/input/lesion-image/dataset/Train', transform=data_transforms,group_mapping=group_mappings)
test_dataset = SkinDataset(root_dir='/kaggle/input/lesion-image/dataset/Test', transform=data_transforms,group_mapping=group_mappings)
valid_dataset = SkinDataset(root_dir='/kaggle/input/lesion-image/dataset/Valid', transform=data_transforms,group_mapping=group_mappings)

print(train_dataset.label_to_index)

from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=True)

{'akiec': 0, 'bcc': 1, 'bkl': 2, 'df': 3, 'mel': 4, 'non-cancer': 5, 'nv': 6, 'vasc': 7}


# Modeling

## Model construction

In [5]:
class MobileNetV3Model(nn.Module):
    def __init__(self, num_classes, extractor_trainable: bool = True):
        super(MobileNetV3Model, self).__init__()
        mobilenet = models.mobilenet_v3_large(pretrained=True)
        
        self.feature_extractor = mobilenet.features
        
        for param in self.feature_extractor.parameters():
            param.requires_grad = extractor_trainable
        
        self.out_features = mobilenet.classifier[0].in_features

        self.classifier = nn.Sequential(
            nn.Linear(self.out_features, num_classes)
        )

    def forward(self, x):
        x = self.feature_extractor(x)
        
        x = F.adaptive_avg_pool2d(x, 1).reshape(x.size(0), -1)
        
        x = self.classifier(x)
        
        return x

In [6]:
class ResNetModel(nn.Module):
    def __init__(self, num_classes, extractor_trainable=True):
        super(ResNetModel, self).__init__()
        resnet = models.resnet34(pretrained=True)
        
        if not extractor_trainable:
            for param in resnet.parameters():
                param.requires_grad = False
        
        self.feature_extractor = nn.Sequential(*list(resnet.children())[:-1])
        num_features = resnet.fc.in_features
        self.fc = nn.Linear(num_features, num_classes)

    def forward(self, x):
        x = self.feature_extractor(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

In [7]:
import timm
import torch.nn as nn

class EfficientNetModel(nn.Module):
    def __init__(self, num_classes, model_name="efficientnet_b2", extractor_trainable=True):
        super(EfficientNetModel, self).__init__()
        # Load EfficientNet using timm
        self.model = timm.create_model(model_name, pretrained=True, num_classes=num_classes)
        
        if not extractor_trainable:
            for param in self.model.parameters():
                param.requires_grad = False
            
            # Ensure classifier remains trainable
            for param in self.model.get_classifier().parameters():
                param.requires_grad = True

    def forward(self, x):
        return self.model(x)


# Training and Validation Loop

In [8]:
from tqdm import tqdm
import torch
from sklearn.metrics import confusion_matrix, fbeta_score

def training_loop(model, epochs, optimizer, loss_fn, data_loader, val_loader, device, fold=0):
    epoch_losses = []
    best_val_accuracy = 0
    best_model_weights = None

    for epoch in range(epochs):
        loop = tqdm(data_loader, total=len(data_loader), leave=False)
        model.train()
        mean_loss = 0

        for _, (X, y) in enumerate(loop):
            optimizer.zero_grad()

            X, y = X.to(device), y.to(device)

            pred = model(X)
            loss = loss_fn(pred, y)
            mean_loss += loss.item()

            loss.backward()
            optimizer.step()

            loop.set_description(f"Epoch [{epoch+1}/{epochs}]")
            loop.set_postfix(loss=loss.item())

        mean_loss /= len(data_loader)
        epoch_losses.append(mean_loss)

        # Perform validation and track the best accuracy
        _, val_accuracy, _, _ = validation_loop(model, loss_fn, val_loader, device)
        print(f"Epoch [{epoch+1}/{epochs}] completed. Avg loss: {mean_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%")

        # Save the model weights if validation accuracy improves
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            best_model_weights = model.state_dict()  # Save current model weights

    # Restore the best weights at the end of training
    if best_model_weights is not None:
        model.load_state_dict(best_model_weights)
        print(f"Restored model weights from epoch with best validation accuracy: {best_val_accuracy:.2f}%")

    print(f"Training fold {fold+1} completed.")

    return epoch_losses, best_val_accuracy

In [9]:
from sklearn.metrics import fbeta_score

In [10]:
def validation_loop(model, loss_fn, data_loader, device):
    model.eval()
    size = len(data_loader.dataset)
    num_batches = len(data_loader)
    test_loss, correct = 0.0, 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for X, y in data_loader:
            X, y = X.to(device), y.to(device)

            # Forward pass
            outputs = model(X)

            # Calculate loss
            loss = loss_fn(outputs, y)
            test_loss += loss.item()

            # Get predicted classes
            _, pred_labels = torch.max(outputs, 1)

            # Calculate number of correct predictions
            correct += (pred_labels == y).sum().item()

            # Store predictions and true labels for metrics
            all_preds.extend(pred_labels.cpu().numpy())
            all_labels.extend(y.cpu().numpy())

    # Average loss and accuracy
    test_loss /= num_batches
    accuracy = (correct / size) * 100

    print(f"Validation Error:\n Accuracy: {accuracy:.2f}%, Avg loss: {test_loss:.4f}\n")

    # Calculate confusion matrix
    conf_matrix = confusion_matrix(all_labels, all_preds)

    # Calculate F-beta score with beta=2
    fbeta = fbeta_score(all_labels, all_preds, beta=2, average='weighted')

    print(f"F-beta Score (beta=2): {fbeta:.4f}\n")

    return conf_matrix, accuracy, (all_labels, all_preds), fbeta

## Model training

In [11]:
from torch import optim

In [12]:
def calculate_class_weights(dataset):
    """
    Calculate class weights based on the frequency of each class in the dataset.
    
    Args:
        dataset: A PyTorch dataset (e.g., DiabeticDataset).
        
    Returns:
        class_weights: A tensor of class weights to be used in the loss function.
    """
    # Get the labels from the dataset
    labels = [label for _, label in dataset]

    # Count the frequency of each class
    class_counts = np.bincount(labels)
    
    # Calculate weights as the inverse of the frequency of each class
    class_weights = 1.0 / class_counts
    
    # Normalize the weights to ensure stability
    class_weights = class_weights / class_weights.sum()

    # Convert the weights to a PyTorch tensor
    class_weights = torch.tensor(class_weights, dtype=torch.float32)
    
    return class_weights

# Usage:
# Calculate class weights based on the training dataset
class_weights = calculate_class_weights(train_dataset)

# Move the class weights to the appropriate device
class_weights = class_weights.to(device)

# Define the loss function with class weights
criterion = nn.CrossEntropyLoss(weight=class_weights)

In [13]:
model = EfficientNetModel(8).to(device)

# criterion = nn.CrossEntropyLoss()  # For multi-class classification
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Number of epochs
num_epochs = 20

# Train the model
'''
train_losses = training_loop(
    model=model, 
    epochs=num_epochs, 
    optimizer=optimizer, 
    loss_fn=criterion, 
    data_loader=train_loader, 
    device=device
)
'''
train_losses, best_val_accuracy = training_loop(
    model=model,
    epochs=num_epochs,
    optimizer=optimizer,
    loss_fn=criterion,
    data_loader=train_loader, # Match this name to your training DataLoader
    val_loader=valid_loader,
    device=device
)

# After training, validate the model
conf_matrix, val_accuracy, (all_labels, all_preds), fbeta = validation_loop(
    model, criterion, valid_loader, device
)

model.safetensors:   0%|          | 0.00/36.8M [00:00<?, ?B/s]

                                                                           

Validation Error:
 Accuracy: 67.31%, Avg loss: 0.8417

F-beta Score (beta=2): 0.6794

Epoch [1/20] completed. Avg loss: 1.6589, Val Accuracy: 67.31%


                                                                           

Validation Error:
 Accuracy: 74.28%, Avg loss: 0.9055

F-beta Score (beta=2): 0.7477

Epoch [2/20] completed. Avg loss: 0.9193, Val Accuracy: 74.28%


                                                                           

Validation Error:
 Accuracy: 71.32%, Avg loss: 0.8290

F-beta Score (beta=2): 0.7230

Epoch [3/20] completed. Avg loss: 0.7341, Val Accuracy: 71.32%


                                                                           

Validation Error:
 Accuracy: 76.63%, Avg loss: 0.7244

F-beta Score (beta=2): 0.7702

Epoch [4/20] completed. Avg loss: 0.5853, Val Accuracy: 76.63%


                                                                            

Validation Error:
 Accuracy: 82.65%, Avg loss: 0.7629

F-beta Score (beta=2): 0.8228

Epoch [5/20] completed. Avg loss: 0.4956, Val Accuracy: 82.65%


                                                                            

Validation Error:
 Accuracy: 78.73%, Avg loss: 0.8124

F-beta Score (beta=2): 0.7914

Epoch [6/20] completed. Avg loss: 0.4060, Val Accuracy: 78.73%


                                                                            

Validation Error:
 Accuracy: 80.12%, Avg loss: 0.9315

F-beta Score (beta=2): 0.8000

Epoch [7/20] completed. Avg loss: 0.3911, Val Accuracy: 80.12%


                                                                            

Validation Error:
 Accuracy: 68.79%, Avg loss: 0.9741

F-beta Score (beta=2): 0.6926

Epoch [8/20] completed. Avg loss: 0.4126, Val Accuracy: 68.79%


                                                                            

Validation Error:
 Accuracy: 75.59%, Avg loss: 0.9868

F-beta Score (beta=2): 0.7532

Epoch [9/20] completed. Avg loss: 0.3896, Val Accuracy: 75.59%


                                                                             

Validation Error:
 Accuracy: 76.55%, Avg loss: 0.9125

F-beta Score (beta=2): 0.7691

Epoch [10/20] completed. Avg loss: 0.3258, Val Accuracy: 76.55%


                                                                             

Validation Error:
 Accuracy: 79.69%, Avg loss: 0.9552

F-beta Score (beta=2): 0.7974

Epoch [11/20] completed. Avg loss: 0.3264, Val Accuracy: 79.69%


                                                                             

Validation Error:
 Accuracy: 83.17%, Avg loss: 1.0030

F-beta Score (beta=2): 0.8249

Epoch [12/20] completed. Avg loss: 0.3777, Val Accuracy: 83.17%


                                                                             

Validation Error:
 Accuracy: 82.91%, Avg loss: 0.9730

F-beta Score (beta=2): 0.8273

Epoch [13/20] completed. Avg loss: 0.2362, Val Accuracy: 82.91%


                                                                             

Validation Error:
 Accuracy: 82.65%, Avg loss: 1.3258

F-beta Score (beta=2): 0.8238

Epoch [14/20] completed. Avg loss: 0.2055, Val Accuracy: 82.65%


                                                                             

Validation Error:
 Accuracy: 81.43%, Avg loss: 1.0084

F-beta Score (beta=2): 0.8143

Epoch [15/20] completed. Avg loss: 0.1479, Val Accuracy: 81.43%


                                                                             

Validation Error:
 Accuracy: 72.97%, Avg loss: 1.1536

F-beta Score (beta=2): 0.7338

Epoch [16/20] completed. Avg loss: 0.3065, Val Accuracy: 72.97%


                                                                             

Validation Error:
 Accuracy: 79.77%, Avg loss: 1.0817

F-beta Score (beta=2): 0.7964

Epoch [17/20] completed. Avg loss: 0.4142, Val Accuracy: 79.77%


                                                                             

Validation Error:
 Accuracy: 81.34%, Avg loss: 0.9149

F-beta Score (beta=2): 0.8131

Epoch [18/20] completed. Avg loss: 0.1865, Val Accuracy: 81.34%


                                                                             

Validation Error:
 Accuracy: 81.78%, Avg loss: 1.1535

F-beta Score (beta=2): 0.8166

Epoch [19/20] completed. Avg loss: 0.1519, Val Accuracy: 81.78%


                                                                             

Validation Error:
 Accuracy: 83.09%, Avg loss: 1.3611

F-beta Score (beta=2): 0.8275

Epoch [20/20] completed. Avg loss: 0.1412, Val Accuracy: 83.09%
Restored model weights from epoch with best validation accuracy: 83.17%
Training fold 1 completed.
Validation Error:
 Accuracy: 83.09%, Avg loss: 1.4744

F-beta Score (beta=2): 0.8275



In [14]:
print(conf_matrix)

[[ 20   1   6   0   5   0   1   0]
 [  1  45   1   0   0   0   7   0]
 [  4   3  82   1  11   0   9   0]
 [  2   0   3   3   1   0   4   0]
 [  3   3  21   0  44   0  52   1]
 [  0   0   0   0   0 143   1   0]
 [  2   1  18   3  28   0 602   1]
 [  0   0   0   0   0   0   0  14]]


In [15]:
from sklearn.metrics import classification_report
target_names = ['akiec','bcc','bkl','df','mel','non-cancer','nv','vasc']
report = classification_report(all_labels, all_preds,target_names=target_names)
print("\nClassification Report:\n", report)


Classification Report:
               precision    recall  f1-score   support

       akiec       0.62      0.61      0.62        33
         bcc       0.85      0.83      0.84        54
         bkl       0.63      0.75      0.68       110
          df       0.43      0.23      0.30        13
         mel       0.49      0.35      0.41       124
  non-cancer       1.00      0.99      1.00       144
          nv       0.89      0.92      0.90       655
        vasc       0.88      1.00      0.93        14

    accuracy                           0.83      1147
   macro avg       0.72      0.71      0.71      1147
weighted avg       0.82      0.83      0.82      1147



In [18]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

print("Accuracy:", accuracy_score(all_labels, all_preds))
print("Precision:", precision_score(all_labels, all_preds, average='weighted'))  # Use 'weighted' for multiclass
print("Recall:", recall_score(all_labels, all_preds, average='weighted'))        # Use 'weighted' for multiclass
print("F1 Score:", f1_score(all_labels, all_preds, average='weighted'))          # Use 'weighted' for multiclass


Accuracy: 0.8308631211857018
Precision: 0.8210572226345487
Recall: 0.8308631211857018
F1 Score: 0.8236964904806758


In [19]:
from sklearn.metrics import balanced_accuracy_score
print("Accuracy:", balanced_accuracy_score(all_labels, all_preds))

Accuracy: 0.7103244937895423


In [17]:
torch.save(model.state_dict(), 'effnet_skincancer2_weights.pth')