# Initialization

## Import libraries

In [1]:
import warnings
warnings.filterwarnings('ignore')

import torch
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import cv2
import os
import seaborn as sns

from sklearn.metrics import confusion_matrix
from sklearn.model_selection import StratifiedKFold

import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader

from tqdm.auto import tqdm

import ssl

# NOTE(review): this swaps in an unverified HTTPS context for the whole process,
# i.e. TLS certificate checks are disabled for every later download. The original
# comment called it a quick fix for a "torchaudio ssl error", but torchaudio is
# never imported here -- presumably it is needed for the torchvision weight
# downloads; confirm, and prefer fixing the local certificate store instead.
ssl._create_default_https_context = ssl._create_unverified_context

# Pick the best available backend: CUDA first, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

# Preprocessing

## Helper Functions

In [5]:
def preprocessing(image):
    """Convert a BGR image array into a float32 channel-first tensor.

    Steps, in order: BGR -> RGB colour conversion, resize to 224x224,
    division by 255 (maps 8-bit pixel values into [0, 1]), and an axis
    reorder from (H, W, C) to (C, H, W).

    Args:
        image: Image array in BGR channel order.

    Returns:
        torch.Tensor of dtype float32 laid out as (C, H, W).
    """
    rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    resized = cv2.resize(rgb, (224, 224))
    scaled = resized / 255.0
    channel_first = np.transpose(scaled, (2, 0, 1))
    return torch.tensor(channel_first, dtype=torch.float32)

def show_image(dataloader, index, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Display one image, titled with its label, from the first batch of a loader.

    Args:
        dataloader: Iterable yielding ``(images, labels)`` batches. Only the
            first batch is fetched; each image is treated as channel-first
            (C, H, W).
        index: Position of the sample inside that batch. Negative values count
            from the end of the batch.
        mean: Per-channel mean that was subtracted when the images were
            normalized. Pass ``None`` (for ``mean`` or ``std``) to skip
            de-normalization for images that were never normalized.
        std: Per-channel standard deviation used for normalization.

    Raises:
        IndexError: If ``index`` does not address a sample in the batch.
    """
    # Only the first batch is inspected.
    images, labels = next(iter(dataloader))

    batch_size = images.size(0)
    if not -batch_size <= index < batch_size:
        raise IndexError(f"Index {index} is out of bounds for batch size {batch_size}")

    # detach/cpu make the later .numpy() call safe regardless of where the
    # batch lives or whether it tracks gradients.
    image = images[index].detach().cpu()
    label = labels[index]

    # Undo transforms.Normalize(mean, std): x = x_norm * std + mean, per channel.
    if mean is not None and std is not None:
        mean_t = torch.as_tensor(mean, dtype=torch.float32).view(-1, 1, 1)
        std_t = torch.as_tensor(std, dtype=torch.float32).view(-1, 1, 1)
        image = image * std_t + mean_t

    # matplotlib expects (H, W, C); clip so rounding error cannot leave [0, 1].
    image_np = np.clip(image.numpy().transpose((1, 2, 0)), 0, 1)

    # Labels are normally 0-dim tensors, but tolerate plain Python values too.
    label_value = label.item() if hasattr(label, 'item') else label

    plt.figure(figsize=(6, 6))
    plt.title(f"Label: {label_value}")  # Use label name if available
    plt.imshow(image_np)
    plt.axis('off')  # Hide axis ticks
    plt.show()

## Custom Dataset

In [3]:
import os
import cv2
import torch
from torch.utils.data import Dataset
import torch.nn.functional as F
from torchvision import transforms

class Resize:
    """Callable transform that bilinearly resizes a (C, H, W) tensor."""

    def __init__(self, size):
        # Target spatial size as (height, width).
        self.size = size

    def __call__(self, image):
        # F.interpolate needs a leading batch axis for bilinear mode, so add
        # one for the call and strip it from the result.
        batched = image[None]
        resized = F.interpolate(batched, size=self.size, mode='bilinear', align_corners=False)
        return resized[0]

class SkinDataset(Dataset):
    """Image-folder dataset where each sub-directory of ``root_dir`` is one class.

    Class indices are assigned by sorting the sub-directory names, and the
    files inside each class folder are enumerated in sorted order, so the
    mapping from dataset index to sample is reproducible across runs and
    filesystems.

    Attributes set by the constructor:
        image_paths: Path of every discovered image file.
        labels: Integer class index for each entry of ``image_paths``.
        label_to_index: Mapping from class (directory) name to class index.
    """

    # Lower-case file extensions that are picked up as images.
    # NOTE(review): depending on the OpenCV build, cv2.imread may return None
    # for '.gif' files, which surfaces as the ValueError in __getitem__.
    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir: Directory containing one sub-directory per class.
            transform: Optional callable applied to the (C, H, W) float tensor.
                When omitted, images are bilinearly resized to 224x224.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []
        self.label_to_index = {}

        self._build_label_index()

    def _build_label_index(self):
        """Scan ``root_dir`` and fill ``label_to_index``, ``image_paths`` and ``labels``."""
        label_names = sorted(
            entry for entry in os.listdir(self.root_dir)
            if os.path.isdir(os.path.join(self.root_dir, entry))
        )

        self.label_to_index = {label_name: idx for idx, label_name in enumerate(label_names)}

        for label_name in label_names:
            label_dir = os.path.join(self.root_dir, label_name)
            label_index = self.label_to_index[label_name]
            # os.listdir returns entries in arbitrary order; sort so that
            # sample indices are stable between runs and machines.
            for img_name in sorted(os.listdir(label_dir)):
                if img_name.lower().endswith(self._IMAGE_EXTENSIONS):
                    self.image_paths.append(os.path.join(label_dir, img_name))
                    self.labels.append(label_index)

    def __len__(self):
        """Number of discovered image files."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        ``image`` is an RGB float tensor laid out as (C, H, W) with pixel
        values divided by 255, after the transform (or the default resize).

        Raises:
            ValueError: If OpenCV cannot decode the file.
        """
        img_path = self.image_paths[idx]
        label = self.labels[idx]

        image = cv2.imread(img_path)
        if image is None:
            raise ValueError(f"Failed to load image at path: {img_path}")

        # OpenCV decodes to BGR; the rest of the pipeline works in RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0  # Shape: (C, H, W)

        # Compare against None explicitly: a transform object that defines
        # __len__/__bool__ could otherwise be skipped for being "falsy".
        if self.transform is not None:
            image = self.transform(image)
        else:
            image = F.interpolate(image.unsqueeze(0), size=(224, 224), mode='bilinear', align_corners=False).squeeze(0)

        return image, label

In [14]:
from torchvision import transforms
from torch.utils.data import DataLoader

# Mini-batch size shared by all three loaders.
BATCH_SIZE = 32

# Resize every image to 224x224, then standardize each channel with the same
# mean/std values that show_image uses to undo the normalization.
data_transforms = transforms.Compose([
    Resize((224, 224)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

train_dataset = SkinDataset(root_dir='./dataset/Train/', transform=data_transforms)
test_dataset = SkinDataset(root_dir='./dataset/Test/', transform=data_transforms)
valid_dataset = SkinDataset(root_dir='./dataset/Valid/', transform=data_transforms)

print(train_dataset.label_to_index)

# Only the training loader is shuffled. Evaluation loaders keep dataset order so
# the label/prediction lists returned by validation_loop are reproducible.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)

{'Chickenpox': 0, 'Cowpox': 1, 'HFMD': 2, 'Healthy': 3, 'Measles': 4, 'Monkeypox': 5}


# Modeling

## Model construction

In [15]:
class MobileNetV3Model(nn.Module):
    """MobileNetV3-Large convolutional backbone with a single linear classifier.

    Args:
        num_classes: Number of output logits.
        extractor_trainable: When False, every backbone parameter has
            ``requires_grad`` turned off so only the linear head is optimized.
    """

    def __init__(self, num_classes, extractor_trainable: bool = True):
        super().__init__()
        # `pretrained=True` is deprecated since torchvision 0.13; this enum is
        # the weight set that flag resolved to.
        mobilenet = models.mobilenet_v3_large(
            weights=models.MobileNet_V3_Large_Weights.IMAGENET1K_V1
        )

        # Keep only the convolutional trunk; torchvision's own head is dropped.
        self.feature_extractor = mobilenet.features

        for param in self.feature_extractor.parameters():
            param.requires_grad = extractor_trainable

        # Width of the pooled feature vector, read from the input size of the
        # first layer of torchvision's original classifier.
        self.out_features = mobilenet.classifier[0].in_features

        self.classifier = nn.Sequential(
            nn.Linear(self.out_features, num_classes)
        )

    def forward(self, x):
        """Return class logits for a batch of images."""
        x = self.feature_extractor(x)

        # Global average pool to 1x1, then flatten to (batch, features).
        x = F.adaptive_avg_pool2d(x, 1).reshape(x.size(0), -1)

        x = self.classifier(x)

        return x

In [29]:
class ResNetModel(nn.Module):
    """ResNet-34 backbone (everything except its final fc) plus a new linear head.

    Args:
        num_classes: Number of output logits.
        extractor_trainable: When False, all parameters of the loaded ResNet
            are frozen, so only the new ``fc`` layer receives gradients.
    """

    def __init__(self, num_classes, extractor_trainable=True):
        super().__init__()

        # `pretrained=True` is deprecated since torchvision 0.13; this enum is
        # the weight set that flag resolved to.
        resnet = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)

        # Freeze the loaded network if requested. This happens before the new
        # head is created, so the head itself always stays trainable.
        if not extractor_trainable:
            for param in resnet.parameters():
                param.requires_grad = False

        # Every child module except the last one (the original fc layer).
        self.feature_extractor = nn.Sequential(*list(resnet.children())[:-1])

        # The new head takes the same input width as the fc layer it replaces.
        num_features = resnet.fc.in_features

        self.fc = nn.Linear(num_features, num_classes)

    def forward(self, x):
        """Return class logits for a batch of images."""
        x = self.feature_extractor(x)

        # Collapse everything after the batch dimension into one feature axis.
        x = torch.flatten(x, 1)

        x = self.fc(x)

        return x

# Training and Validation Loop

In [16]:
def training_loop(model, epochs, optimizer, loss_fn, data_loader, device, fold=0):
    """Optimize ``model`` for ``epochs`` full passes over ``data_loader``.

    Args:
        model: Module to train; switched to train mode at the start of each epoch.
        epochs: Number of passes over ``data_loader``.
        optimizer: Optimizer stepped once per batch.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a scalar loss.
        data_loader: Iterable of ``(inputs, targets)`` batches that supports ``len``.
        device: Device each batch is moved to before the forward pass.
        fold: Zero-based fold number, used only in the final log line.

    Returns:
        List with one entry per epoch: the batch losses summed and divided by
        the number of batches.
    """
    history = []

    for epoch_idx in range(epochs):
        num_batches = len(data_loader)
        model.train()
        progress = tqdm(data_loader, total=num_batches, leave=False)
        running_loss = 0

        for inputs, targets in progress:
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            loss = loss_fn(model(inputs), targets)

            # Read the scalar once; it feeds both the running sum and the bar.
            batch_loss = loss.item()
            running_loss += batch_loss

            loss.backward()
            optimizer.step()

            progress.set_description(f"Epoch [{epoch_idx+1}/{epochs}]")
            progress.set_postfix(loss=batch_loss)

        epoch_loss = running_loss / num_batches
        history.append(epoch_loss)

        print(f"Epoch [{epoch_idx+1}/{epochs}] completed. Avg loss: {epoch_loss:.4f}")

    print(f"Training fold {fold+1} completed.")

    return history

In [11]:
from sklearn.metrics import fbeta_score

In [17]:
def validation_loop(model, loss_fn, data_loader, device):
    """Evaluate ``model`` on ``data_loader`` and print/return summary metrics.

    Args:
        model: Module to evaluate; put into eval mode here.
        loss_fn: Callable ``loss_fn(outputs, targets)`` returning a scalar loss.
        data_loader: Iterable of ``(inputs, targets)`` batches exposing
            ``.dataset`` and ``len``.
        device: Device each batch is moved to before the forward pass.

    Returns:
        Tuple ``(conf_matrix, accuracy, (all_labels, all_preds), fbeta)`` where
        ``accuracy`` is a percentage and ``fbeta`` is the weighted F-beta score
        with beta=2.
    """
    model.eval()
    num_samples = len(data_loader.dataset)
    num_batches = len(data_loader)

    total_loss = 0.0
    num_correct = 0
    true_labels = []
    predicted = []

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            logits = model(inputs)
            total_loss += loss_fn(logits, targets).item()

            # Index of the largest logit per row is the predicted class.
            batch_preds = logits.max(dim=1).indices
            num_correct += (batch_preds == targets).sum().item()

            # Collect on the CPU as numpy values for the sklearn metrics.
            predicted.extend(batch_preds.cpu().numpy())
            true_labels.extend(targets.cpu().numpy())

    # Loss is averaged over batches, accuracy over samples.
    avg_loss = total_loss / num_batches
    accuracy = (num_correct / num_samples) * 100

    print(f"Validation Error:\n Accuracy: {accuracy:.2f}%, Avg loss: {avg_loss:.4f}\n")

    conf_matrix = confusion_matrix(true_labels, predicted)

    # beta=2 weights recall more heavily than precision.
    fbeta = fbeta_score(true_labels, predicted, beta=2, average='weighted')

    print(f"F-beta Score (beta=2): {fbeta:.4f}\n")

    return conf_matrix, accuracy, (true_labels, predicted), fbeta

## Model training

In [20]:
from torch import optim

In [31]:
def calculate_class_weights(dataset, num_classes=None):
    """
    Calculate normalized inverse-frequency class weights for a dataset.

    Args:
        dataset: Either an object exposing a ``labels`` sequence with one
            integer class index per sample (e.g. SkinDataset), or any iterable
            of ``(sample, label)`` pairs.
        num_classes: Optional total number of classes. When given, the result
            has at least this many entries even if the highest class indices
            never occur in ``dataset``.

    Returns:
        class_weights: A float32 tensor whose entries sum to 1. Each class seen
            in the data gets a weight proportional to 1 / count; classes with
            no samples get weight 0.

    Raises:
        ValueError: If the dataset contains no samples.
    """
    # Prefer a precomputed label list: iterating the dataset would load and
    # transform every sample just to read its label.
    labels = getattr(dataset, 'labels', None)
    if labels is None or len(labels) != len(dataset):
        labels = [int(label) for _, label in dataset]

    labels = np.asarray(labels, dtype=np.int64).ravel()
    if labels.size == 0:
        raise ValueError("Cannot compute class weights for an empty dataset")

    # Count the frequency of each class
    class_counts = np.bincount(labels, minlength=num_classes or 0)

    # Inverse frequency, computed only where the count is non-zero so that an
    # unused class index yields weight 0 instead of inf (and NaN after
    # normalization).
    inverse_freq = np.zeros(class_counts.shape, dtype=np.float64)
    present = class_counts > 0
    inverse_freq[present] = 1.0 / class_counts[present]

    # Normalize the weights so they sum to 1
    class_weights = inverse_freq / inverse_freq.sum()

    return torch.tensor(class_weights, dtype=torch.float32)

# Build the weighted loss: class weights come from the training split and are
# moved to the training device before being handed to CrossEntropyLoss.
class_weights = calculate_class_weights(train_dataset).to(device)
criterion = nn.CrossEntropyLoss(weight=class_weights)

In [33]:
# Size the classification head from the dataset's label map instead of a
# hard-coded 6, so the model cannot drift out of sync with the data folders.
num_classes = len(train_dataset.label_to_index)
model = ResNetModel(num_classes).to(device)

optimizer = optim.Adam(model.parameters(), lr=0.001)

# Number of epochs
num_epochs = 20

# Train the model
train_losses = training_loop(
    model=model,
    epochs=num_epochs,
    optimizer=optimizer,
    loss_fn=criterion,
    data_loader=train_loader,
    device=device
)

# After training, validate the model
conf_matrix, val_accuracy, (all_labels, all_preds), fbeta = validation_loop(
    model, criterion, valid_loader, device
)

                                                                        

Epoch [1/20] completed. Avg loss: 1.5062


                                                                         

Epoch [2/20] completed. Avg loss: 1.1567


                                                                         

Epoch [3/20] completed. Avg loss: 0.9042


                                                                         

Epoch [4/20] completed. Avg loss: 0.6431


                                                                         

Epoch [5/20] completed. Avg loss: 0.5565


                                                                         

Epoch [6/20] completed. Avg loss: 0.5440


                                                                         

Epoch [7/20] completed. Avg loss: 0.5094


                                                                         

Epoch [8/20] completed. Avg loss: 0.3289


                                                                          

Epoch [9/20] completed. Avg loss: 0.2431


                                                                           

Epoch [10/20] completed. Avg loss: 0.1824


                                                                          

Epoch [11/20] completed. Avg loss: 0.2265


                                                                           

Epoch [12/20] completed. Avg loss: 0.1658


                                                                           

Epoch [13/20] completed. Avg loss: 0.1839


                                                                           

Epoch [14/20] completed. Avg loss: 0.1521


                                                                           

Epoch [15/20] completed. Avg loss: 0.2057


                                                                           

Epoch [16/20] completed. Avg loss: 0.1998


                                                                           

Epoch [17/20] completed. Avg loss: 0.1273


                                                                           

Epoch [18/20] completed. Avg loss: 0.1230


                                                                           

Epoch [19/20] completed. Avg loss: 0.0965


                                                                           

Epoch [20/20] completed. Avg loss: 0.0768
Training fold 1 completed.
Validation Error:
 Accuracy: 80.56%, Avg loss: 1.0073

F-beta Score (beta=2): 0.8037



In [34]:
# Persist only the learned parameters (the state dict), not the module object.
weights_path = 'resnet_weights.pth'
torch.save(model.state_dict(), weights_path)