# 3 Convolutional Neural Network for Skin Lesion Classification using PyTorch

In this tutorial, we'll build a Convolutional Neural Network (CNN) using PyTorch to classify skin lesions from the HAM10000 dataset. This dataset contains 10,015 dermatoscopic images of common pigmented skin lesions across seven diagnostic categories.

## 3.1 Downloading the dataset

The zipped dataset is about 5 GB (roughly 5,300 MB) in size, so the download may take a few minutes...

In [None]:
# ! curl -L -o skin-cancer-mnist-ham10000.zip https://www.kaggle.com/api/v1/datasets/download/kmader/skin-cancer-mnist-ham10000

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed

  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0

  0 5324M    0 80462    0     0   141k      0 10:43:21 --:--:-- 10:43:21  141k
  0 5324M    0 23.8M    0     0  15.3M      0  0:05:47  0:00:01  0:05:46 23.7M
  1 5324M    1 68.3M    0     0  26.5M      0  0:03:20  0:00:02  0:03:18 33.9M
  2 5324M    2  111M    0     0  31.4M      0  0:02:49  0:00:03  0:02:46 37.2M
  2 5324M    2  157M    0     0  34.5M      0  0:02:34  0:00:04  0:02:30 39.2M
  3 5324M    3  201M    0     0  36.2M      0  0:02:27  0:00:05  0:02:22 40.2M
  4 5324M    4  246M    0     0  37.5M      0  0:02:21  0:00:06  0:02:15 44.4M
  5 5324M    5  290M    0     0  38.4M      0  0:02:18  0:00:07  0:02:11 44.5M
  6 5324M    6  335M    0     0  39.1M      0  0:0

In [None]:
# !unzip skin-cancer-mnist-ham10000.zip

## 3.2 Importing and initial setup

In [None]:
# Install necessary packages (run this if needed)
# !pip install torch torchvision pandas matplotlib seaborn scikit-learn pillow tqdm

# Import libraries
import copy
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import transforms
from PIL import Image

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report

In [None]:
# Seed the NumPy and PyTorch random number generators so runs are repeatable
np.random.seed(42)
torch.manual_seed(42)

# Run on the first CUDA GPU when PyTorch can see one; otherwise use the CPU
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

## 3.3 Working with image data

### 3.3.1 Exploring the dataset

The HAM10000 dataset consists of 10,015 dermatoscopic images across 7 different categories:
- Melanocytic nevi (nv)
- Melanoma (mel)
- Benign keratosis-like lesions (bkl)
- Basal cell carcinoma (bcc)
- Actinic keratoses (akiec)
- Vascular lesions (vasc)
- Dermatofibroma (df)

Let's first explore the metadata:

In [None]:
# Load the metadata
metadata = pd.read_csv('HAM10000_metadata.csv')

# Display first few rows
print(metadata.head())

# Check class distribution
plt.figure(figsize=(10, 6))
sns.countplot(x='dx', data=metadata)
plt.title('Distribution of Skin Lesion Classes')
plt.xlabel('Class')
plt.ylabel('Count')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

In [None]:
# Print each diagnosis with its image count and its share of the whole table
class_counts = metadata['dx'].value_counts()
print("Class distribution:")
for label, n_images in class_counts.items():
    share = n_images / len(metadata) * 100
    print(f"{label}: {n_images} images ({share:.2f}%)")

### 3.3.2 Loading the dataset

Now, let's create a custom PyTorch dataset for loading the HAM10000 images:

In [None]:
class SkinLesionDataset(Dataset):
    """Dataset serving ``(image, label)`` pairs for the rows of a HAM10000 metadata frame.

    Each image is looked up as ``<image_dir>/<subdir>/<image_id>.jpg`` for each
    sub-folder in ``IMAGE_SUBDIRS``, in order.
    """

    # The HAM10000 archive splits its images across these two folders; searched in order
    IMAGE_SUBDIRS = ('HAM10000_images_part_1', 'HAM10000_images_part_2')

    def __init__(self, df, image_dir, transform=None, classes=None):
        """
        Args:
            df (pandas.DataFrame): Image metadata with at least 'image_id' and 'dx' columns.
            image_dir (string): Directory containing the HAM10000 image sub-folders.
            transform (callable, optional): Optional transform to be applied on a sample.
            classes (sequence of str, optional): Ordered diagnosis labels defining the
                label indices. Pass the same list to every split so that indices agree
                across datasets even if one split lacks a class. Defaults to the sorted
                unique 'dx' values of ``df``.
        """
        self.df = df
        self.image_dir = image_dir
        self.transform = transform

        # Mapping from diagnosis string to integer label
        self.classes = sorted(df['dx'].unique()) if classes is None else list(classes)
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}

    def __len__(self):
        return len(self.df)

    def _find_image(self, img_id):
        """Return the path of ``img_id``'s JPEG, raising FileNotFoundError if no folder has it."""
        candidates = [os.path.join(self.image_dir, subdir, f"{img_id}.jpg") for subdir in self.IMAGE_SUBDIRS]
        for path in candidates:
            if os.path.exists(path):
                return path
        raise FileNotFoundError(f"Image {img_id!r} not found; looked in: {', '.join(candidates)}")

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        row = self.df.iloc[idx]
        img_path = self._find_image(row['image_id'])

        # The context manager releases the file handle once the pixels are converted
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        label = self.class_to_idx[row['dx']]

        if self.transform:
            image = self.transform(image)

        return image, label

### 3.3.3 Adding transformations

To make our model more robust, and to augment the dataset, we apply some transformations to the images. All images are resized to 224x224 and normalized. The training images are also augmented with random horizontal and vertical flips, random rotations of up to 20 degrees, and slight color jitter. Finally, the images are converted to PyTorch tensors.

In [None]:
# Both splits are resized to 224x224, converted to tensors and normalized with
# ImageNet's per-channel mean/std. Only the training split is augmented
# (random flips, rotation and color jitter).
data_transforms = dict(
    train=transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomRotation(degrees=20),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
    val=transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
)

### 3.3.4 Splitting the dataset and creating data loaders

In [None]:
# Split the dataset into training and validation sets.
# The metadata has a 'lesion_id' column, and one lesion can have several images.
# A per-image split would put photos of the same lesion into both sets and inflate
# the validation score. Instead, split the unique lesions (stratified by diagnosis)
# and send every image to the split its lesion landed in.
# NOTE(review): stratifying on each lesion's first 'dx' assumes all images of a
# lesion share one diagnosis.
unique_lesions = metadata.drop_duplicates(subset='lesion_id')
train_lesion_ids, val_lesion_ids = train_test_split(
    unique_lesions['lesion_id'],
    test_size=0.2,
    random_state=42,
    stratify=unique_lesions['dx'],
)
train_df = metadata[metadata['lesion_id'].isin(train_lesion_ids)]
val_df = metadata[metadata['lesion_id'].isin(val_lesion_ids)]

print(f"Training set size: {len(train_df)}")
print(f"Validation set size: {len(val_df)}")

# Create datasets
train_dataset = SkinLesionDataset(
    df=train_df,
    image_dir='.',  # Adjust this path as needed
    transform=data_transforms['train']
)

val_dataset = SkinLesionDataset(
    df=val_df,
    image_dir='.',  # Adjust this path as needed
    transform=data_transforms['val']
)

In [None]:
# Batches of 32 with 4 worker processes; only the training order is shuffled
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=32, num_workers=4)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=32, num_workers=4)

# Label index i corresponds to class_names[i] (the dataset's sorted class list)
class_names = train_dataset.classes
print("Class names:", class_names)

## 3.4 Training a CNN Model

### 3.4.1 Defining the model architecture

Now let's build our CNN architecture. We'll use a simple architecture with a few convolutional layers followed by fully connected layers. The model takes a 224x224 input image and outputs one raw score (logit) for each of the 7 classes. A softmax turns these scores into probabilities when we need them. Note the use of batch normalization, pooling, and dropout layers to stabilize training and prevent overfitting.

In [None]:
class SkinLesionCNN(nn.Module):
    """Four-block CNN mapping square RGB images to one logit per class.

    Each block is a 3x3 convolution (padding 1), batch normalization, ReLU and a
    2x2 max-pool. The side length is therefore halved (rounded down) four times
    before two fully connected layers with dropout between them.

    Args:
        num_classes (int): Number of output logits.
        input_size (int): Side length of the square input images. ``fc1`` is sized
            from it. The default of 224 matches the transforms in this notebook.

    Raises:
        ValueError: If ``input_size`` is below 16, since nothing would survive the pooling.
    """

    def __init__(self, num_classes=7, input_size=224):
        super().__init__()

        if input_size < 16:
            raise ValueError(f"input_size must be at least 16, got {input_size}")

        # Convolutional layers
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)

        # Pooling layer (shared by all blocks)
        self.pool = nn.MaxPool2d(2, 2)

        # Batch normalization layers
        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(256)

        # Dropout layer
        self.dropout = nn.Dropout(0.5)

        # Fully connected layers.
        # Four floor-halving max-pools divide the side length by 16 (224 -> 14).
        feature_side = input_size // 16
        self.fc1 = nn.Linear(256 * feature_side * feature_side, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        x = self.pool(F.relu(self.bn4(self.conv4(x))))

        # Flatten per sample. Keeping the batch dimension explicit means a wrongly
        # sized input raises a shape error in fc1 instead of silently merging samples.
        x = torch.flatten(x, 1)

        # Fully connected layers with dropout
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.fc2(x)

        return x

In [None]:
# Build the network with one output per class and move it onto the selected device
model = SkinLesionCNN(num_classes=len(class_names)).to(device)

# Show the layer-by-layer structure
print(model)

### 3.4.2 Training the model

In [None]:
def _run_epoch(model, loader, criterion, optimizer=None, desc=""):
    """Make one pass over ``loader``: optimize when ``optimizer`` is given, else only evaluate.

    Returns:
        tuple[float, float]: (average per-sample loss, accuracy) over ``loader.dataset``.
    """
    is_training = optimizer is not None
    # train(False) is the same as eval(): dropout off, batch norm uses running statistics
    model.train(is_training)
    running_loss = 0.0
    running_corrects = 0

    # Build the autograd graph only when we are going to back-propagate
    with torch.set_grad_enabled(is_training):
        for inputs, labels in tqdm(loader, desc=desc):
            inputs = inputs.to(device)
            labels = labels.to(device)

            if is_training:
                optimizer.zero_grad()

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

            if is_training:
                loss.backward()
                optimizer.step()

            # Assumes criterion averages over the batch (the nn.CrossEntropyLoss default),
            # so re-weight by batch size to get a per-sample average over the dataset
            running_loss += loss.item() * inputs.size(0)
            running_corrects += (preds == labels).sum().item()

    num_samples = len(loader.dataset)
    return running_loss / num_samples, running_corrects / num_samples


def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10):
    """
    Train the model and evaluate it on the validation set after each epoch.

    The weights from the epoch with the highest validation accuracy are loaded back
    into ``model`` before it is returned.

    Args:
        model (nn.Module): Network to train; its parameters should already be on ``device``.
        train_loader (DataLoader): Yields (inputs, labels) batches used for optimization.
        val_loader (DataLoader): Yields (inputs, labels) batches used for evaluation.
        criterion (callable): Loss, called as ``criterion(outputs, labels)``.
        optimizer (torch.optim.Optimizer): Optimizer over the model's parameters.
        num_epochs (int): Number of passes over the training data.

    Returns:
        tuple: ``(model, train_losses, val_losses, train_accs, val_accs)``. Each list
        holds one float per epoch.
    """
    best_acc = 0.0
    # state_dict() returns references to the live parameter tensors. Take a deep copy,
    # otherwise later optimizer steps would overwrite the snapshot of the best weights.
    best_model_wts = copy.deepcopy(model.state_dict())

    train_losses, val_losses = [], []
    train_accs, val_accs = [], []

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        train_loss, train_acc = _run_epoch(model, train_loader, criterion, optimizer, desc="Training")
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        print(f'Train Loss: {train_loss:.4f} Acc: {train_acc:.4f}')

        val_loss, val_acc = _run_epoch(model, val_loader, criterion, desc="Validation")
        val_losses.append(val_loss)
        val_accs.append(val_acc)
        print(f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}')

        # Snapshot the best model so far (again as an independent copy)
        if val_acc > best_acc:
            best_acc = val_acc
            best_model_wts = copy.deepcopy(model.state_dict())

        print()

    model.load_state_dict(best_model_wts)
    return model, train_losses, val_losses, train_accs, val_accs

In [None]:
# Cross-entropy on the raw logits, optimized with Adam at learning rate 1e-3
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.001)

# Run 15 epochs of training with validation after each one
num_epochs = 15
model, train_losses, val_losses, train_accs, val_accs = train_model(
    model, train_loader, val_loader, criterion, optimizer, num_epochs=num_epochs,
)

# Persist the learned weights to disk
torch.save(model.state_dict(), 'skin_lesion_cnn.pth')

In [None]:
def evaluate_model(model, dataloader):
    """
    Collect true and predicted class indices for every sample in ``dataloader``.

    Args:
        model (nn.Module): Classifier returning one logit per class.
        dataloader (DataLoader): Yields (inputs, labels) batches.

    Returns:
        tuple[list, list]: ``(y_true, y_pred)`` as flat lists of integer class indices
        (NumPy scalars), in loader order.
    """
    model.eval()
    y_true = []
    y_pred = []

    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Evaluating"):
            outputs = model(inputs.to(device))
            _, preds = torch.max(outputs, 1)

            # Labels are only collected on the CPU, so don't ship them to the device first
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(preds.cpu().numpy())

    return y_true, y_pred

In [None]:
y_true, y_pred = evaluate_model(model, dataloader=val_loader)  # targets/predictions on the validation split

In [None]:
# Per-class precision/recall/F1. Passing `labels` keeps one row per entry of
# class_names even if a class never occurs in y_true or y_pred for this split.
print("Classification Report:")
print(classification_report(y_true, y_pred, labels=list(range(len(class_names))), target_names=class_names))

In [None]:
plt.figure(figsize=(10, 8))
# Fix the label set so the matrix is always len(class_names) x len(class_names)
# and lines up with the tick labels, even if a class is absent from this split
cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.tight_layout()
plt.show()

### 3.4.3 Visualizing the results

Let's visualize the training process and some predictions:

In [None]:
# Plot training and validation loss per epoch.
# The inline backend renders and closes each cell's figure, so this cell draws and
# shows a complete figure instead of the left half of a figure finished in another cell.
plt.figure(figsize=(6, 4))
plt.plot(range(1, num_epochs+1), train_losses, 'b-', label='Training Loss')
plt.plot(range(1, num_epochs+1), val_losses, 'r-', label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.tight_layout()
plt.show()

In [None]:
# Plot training and validation accuracy per epoch.
# Start a fresh figure: the figure from the previous cell has already been rendered
# and closed, so subplot(1, 2, 2) would land in a new, half-empty figure.
plt.figure(figsize=(6, 4))
plt.plot(range(1, num_epochs+1), train_accs, 'b-', label='Training Accuracy')
plt.plot(range(1, num_epochs+1), val_accs, 'r-', label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.tight_layout()
plt.show()

In [None]:
# Function to visualize predictions
def visualize_predictions(model, dataloader, class_names, num_images=12):
    """
    Show up to ``num_images`` images from ``dataloader`` with their true and predicted labels.

    Titles are green when the prediction is correct and red otherwise. Images are
    laid out four per row, with as many rows as ``num_images`` needs (three rows for
    the default of 12). Nothing is drawn when ``num_images`` is not positive.
    """
    if num_images <= 0:
        return

    ncols = 4
    nrows = -(-num_images // ncols)  # ceiling division
    # 15x10 inches for the default 3-row grid; scale the height with the row count
    plt.figure(figsize=(15, 10 * nrows / 3))

    # ImageNet per-channel statistics used by the Normalize transforms, needed to undo them
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    model.eval()
    images_so_far = 0

    with torch.no_grad():
        for inputs, labels in dataloader:
            outputs = model(inputs.to(device))
            _, preds = torch.max(outputs, 1)
            preds = preds.cpu()

            for j in range(inputs.size(0)):
                images_so_far += 1
                ax = plt.subplot(nrows, ncols, images_so_far)
                ax.axis('off')

                true_idx, pred_idx = int(labels[j]), int(preds[j])
                ax.set_title(f'True: {class_names[true_idx]}\nPred: {class_names[pred_idx]}',
                             color=('green' if pred_idx == true_idx else 'red'))

                # Undo the normalization and move channels last for imshow
                img = inputs[j].cpu().numpy().transpose((1, 2, 0))
                img = np.clip(std * img + mean, 0, 1)
                ax.imshow(img)

                if images_so_far == num_images:
                    plt.tight_layout()
                    plt.show()
                    return

    # The loader ran out before num_images images were shown
    plt.tight_layout()
    plt.show()



In [None]:
# Show a 3x4 grid of validation images with true vs. predicted labels
visualize_predictions(model, val_loader, class_names, num_images=12)

## 3.5 Making predictions

Let's create a function to make predictions on new images:

In [None]:
def predict_image(model, image_path, transform, class_names):
    """
    Classify one image file, display it with the prediction, and print per-class probabilities.

    Args:
        model (nn.Module): Trained classifier returning one logit per class.
        image_path (str): Path to an image file readable by PIL.
        transform (callable): Preprocessing applied before the forward pass,
            e.g. ``data_transforms['val']``.
        class_names (list of str): Class labels, indexed like the model outputs.

    Returns:
        tuple[str, float]: The predicted class name and its softmax probability.
    """
    # Load and preprocess the image; the context manager releases the file handle
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    image_tensor = transform(image).unsqueeze(0).to(device)

    model.eval()

    with torch.no_grad():
        probabilities = F.softmax(model(image_tensor), dim=1)
        confidence, prediction = torch.max(probabilities, 1)

    predicted_class = class_names[prediction.item()]
    confidence_score = confidence.item()

    # Display image and prediction
    plt.figure(figsize=(8, 6))
    plt.imshow(image)
    plt.title(f'Prediction: {predicted_class}\nConfidence: {confidence_score:.4f}')
    plt.axis('off')
    plt.show()

    # Print the probability of every class, not just the winning one
    for class_name, prob in zip(class_names, probabilities[0].cpu().numpy()):
        print(f"{class_name}: {prob:.4f}")

    return predicted_class, confidence_score

# Example usage (replace with your image path)
# predict_image(model, 'path_to_your_image.jpg', data_transforms['val'], class_names)

## 3.6 Simplifying with PyTorch Lightning

As we have seen in the previous chapters, PyTorch Lightning is a lightweight wrapper around PyTorch that helps to organize PyTorch code. It can as easily be used for CNNs as for other model types.

In [None]:
import pytorch_lightning as pl

# Define a PyTorch Lightning module for the skin lesion CNN
class LitSkinLesionCNN(pl.LightningModule):
    """PyTorch Lightning version of the skin lesion CNN.

    Uses the same convolutional and fully connected layer layout as ``SkinLesionCNN``.
    Lightning drives the training loop, metric logging and optimizer stepping.

    Args:
        num_classes (int): Number of output logits.
        lr (float): Learning rate for Adam.
        dropout (float): Dropout probability before the last layer. The default of 0.2
            is lower than the 0.5 used in ``SkinLesionCNN``.
    """

    def __init__(self, num_classes=7, lr=1e-3, dropout=0.2):
        super().__init__()
        self.save_hyperparameters()  # Saves hyperparameters for logging and checkpointing
        self.lr = lr

        # Four conv -> batch norm -> ReLU -> 2x2 max-pool blocks
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(dropout)
        # 224x224 inputs are halved four times by pooling: 224 / 2**4 = 14
        self.fc1 = nn.Linear(256 * 14 * 14, 512)
        self.fc2 = nn.Linear(512, num_classes)
        self.criterion = nn.CrossEntropyLoss()  # Loss function

    def forward(self, x):
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = self.pool(F.relu(self.bn3(self.conv3(x))))
        x = self.pool(F.relu(self.bn4(self.conv4(x))))
        x = torch.flatten(x, 1)  # keep the batch dimension, flatten everything else
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.fc2(x)
        return x

    def _shared_step(self, batch, stage):
        """Compute loss and accuracy for one batch and log them as ``{stage}_loss`` / ``{stage}_acc``."""
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        acc = (torch.argmax(logits, dim=1) == y).float().mean()
        self.log(f'{stage}_loss', loss, on_step=False, on_epoch=True)
        self.log(f'{stage}_acc', acc, on_step=False, on_epoch=True)
        return loss

    def training_step(self, batch, batch_idx):
        # Lightning back-propagates the returned loss
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        # Metrics are only logged here; nothing needs to be returned
        self._shared_step(batch, 'val')

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.lr)

# Example usage:
# model = LitSkinLesionCNN(num_classes=len(class_names))
# trainer = pl.Trainer(max_epochs=15, accelerator="auto")
# trainer.fit(model, train_loader, val_loader)

## 3.7 Optional exercise: Transfer learning 

Load the pretrained [ResNet-50](https://pytorch.org/hub/nvidia_deeplearningexamples_resnet50/) model (for example with `torchvision.models.resnet50`) and fine-tune it on the skin lesion dataset.

In [None]:
class ResNetTransferModel(nn.Module):
    """Exercise stub: a pretrained ResNet-50 with a new head for ``num_classes`` outputs.

    ``self.resnet`` is only a placeholder (``...``) until the TODOs below are filled in.
    Until then, calling the model fails.
    """
    def __init__(self, num_classes=7):
        super(ResNetTransferModel, self).__init__()
        # Load pre-trained ResNet50
        # TODO: create the backbone, e.g. with torchvision.models.resnet50 and ImageNet weights
        self.resnet = ...

        # Freeze the early layers
        # TODO: set requires_grad = False on the backbone parameters that should stay fixed

        # Replace the final fully connected layer
        # TODO: swap self.resnet.fc for a layer (or small head) producing num_classes logits

    def forward(self, x):
        return self.resnet(x)

In [None]:
# Initialize the transfer learning model (requires ResNetTransferModel to be completed first)
transfer_model = ResNetTransferModel(num_classes=len(class_names)).to(device)

# Cross-entropy loss and a smaller Adam learning rate (1e-4) for fine-tuning
transfer_criterion = nn.CrossEntropyLoss()
transfer_optimizer = optim.Adam(params=transfer_model.parameters(), lr=0.0001)

# Fine-tune for 10 epochs with validation after each one
num_epochs_transfer = 10
transfer_model, tl_train_losses, tl_val_losses, tl_train_accs, tl_val_accs = train_model(
    transfer_model, train_loader, val_loader, transfer_criterion, transfer_optimizer,
    num_epochs=num_epochs_transfer,
)

In [None]:
# Write the fine-tuned weights to disk
torch.save(transfer_model.state_dict(), f='skin_lesion_transfer_learning.pth')

In [None]:
# Evaluate the transfer learning model
y_true_tl, y_pred_tl = evaluate_model(transfer_model, val_loader)

# Print classification report. Passing `labels` keeps one row per class_names entry
# even if a class never occurs in this split.
print("Classification Report (Transfer Learning):")
print(classification_report(y_true_tl, y_pred_tl, labels=list(range(len(class_names))), target_names=class_names))

In [None]:
# Plot confusion matrix. A fixed label set keeps it square and aligned with the
# tick labels even if a class is missing from this split.
plt.figure(figsize=(10, 8))
cm_tl = confusion_matrix(y_true_tl, y_pred_tl, labels=list(range(len(class_names))))
sns.heatmap(cm_tl, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix (Transfer Learning)')
plt.tight_layout()
plt.show()

# Compare the two models
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(range(1, num_epochs+1), val_accs, 'b-', label='Custom CNN')
plt.plot(range(1, num_epochs_transfer+1), tl_val_accs, 'r-', label='Transfer Learning')
plt.title('Validation Accuracy Comparison')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(range(1, num_epochs+1), val_losses, 'b-', label='Custom CNN')
plt.plot(range(1, num_epochs_transfer+1), tl_val_losses, 'r-', label='Transfer Learning')
plt.title('Validation Loss Comparison')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.tight_layout()
plt.show()

# Visualize predictions from the transfer learning model
visualize_predictions(transfer_model, val_loader, class_names)

In [None]:
class ResNetTransferModel(nn.Module):
    """ResNet-50 pretrained on ImageNet, with most of the backbone frozen and a new head.

    Args:
        num_classes (int): Number of output logits produced by the new head.
    """

    def __init__(self, num_classes=7):
        super().__init__()
        # Load ImageNet-pretrained ResNet-50. The `weights=` enum replaces the deprecated
        # `pretrained=True` flag; IMAGENET1K_V1 is the checkpoint that flag used to load.
        self.resnet = torchvision.models.resnet50(
            weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1
        )

        # Freeze everything except the last 20 parameter tensors of the backbone,
        # so only the tail of the network is fine-tuned
        for param in list(self.resnet.parameters())[:-20]:
            param.requires_grad = False

        # Replace the final fully connected layer. It is created after freezing,
        # so the new head is always trainable.
        num_ftrs = self.resnet.fc.in_features
        self.resnet.fc = nn.Sequential(
            nn.Linear(num_ftrs, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        return self.resnet(x)

# Initialize the transfer learning model
transfer_model = ResNetTransferModel(num_classes=len(class_names))
transfer_model = transfer_model.to(device)

# Define loss function and optimizer for the transfer learning model
transfer_criterion = nn.CrossEntropyLoss()
transfer_optimizer = optim.Adam(transfer_model.parameters(), lr=0.0001)

# Train the transfer learning model
num_epochs_transfer = 10
transfer_model, tl_train_losses, tl_val_losses, tl_train_accs, tl_val_accs = train_model(
    transfer_model, train_loader, val_loader, transfer_criterion, transfer_optimizer, num_epochs_transfer
)

# Save the trained transfer learning model
torch.save(transfer_model.state_dict(), 'skin_lesion_transfer_learning.pth')

# Evaluate the transfer learning model
y_true_tl, y_pred_tl = evaluate_model(transfer_model, val_loader)

# Fixed label set shared by the report and confusion matrix, so both stay aligned
# with class_names even if a class never occurs in this split
label_indices = list(range(len(class_names)))

# Print classification report
print("Classification Report (Transfer Learning):")
print(classification_report(y_true_tl, y_pred_tl, labels=label_indices, target_names=class_names))

# Plot confusion matrix
plt.figure(figsize=(10, 8))
cm_tl = confusion_matrix(y_true_tl, y_pred_tl, labels=label_indices)
sns.heatmap(cm_tl, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix (Transfer Learning)')
plt.tight_layout()
plt.show()

# Compare the two models
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(range(1, num_epochs+1), val_accs, 'b-', label='Custom CNN')
plt.plot(range(1, num_epochs_transfer+1), tl_val_accs, 'r-', label='Transfer Learning')
plt.title('Validation Accuracy Comparison')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(range(1, num_epochs+1), val_losses, 'b-', label='Custom CNN')
plt.plot(range(1, num_epochs_transfer+1), tl_val_losses, 'r-', label='Transfer Learning')
plt.title('Validation Loss Comparison')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.tight_layout()
plt.show()

# Visualize predictions from the transfer learning model
visualize_predictions(transfer_model, val_loader, class_names)