## Baseline Model: Training with ViT-B/32 on Image Data only

In [1]:
# Load the Drive helper and mount Google Drive at /content/drive.
# NOTE(review): no cell visible in this notebook reads from or writes to the
# mounted Drive (the data is unzipped to ./data and checkpoints are written to
# ./models) — confirm whether this mount is still needed.
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [2]:
# Download the HAM10000 archive from Kaggle into the current working directory.
# NOTE(review): the kaggle CLI normally needs API credentials configured
# beforehand, and no cell visible here sets them up — confirm.
!kaggle datasets download -d kmader/skin-cancer-mnist-ham10000

Dataset URL: https://www.kaggle.com/datasets/kmader/skin-cancer-mnist-ham10000
License(s): CC-BY-NC-SA-4.0
Downloading skin-cancer-mnist-ham10000.zip to /content
100% 5.18G/5.20G [00:37<00:00, 134MB/s]
100% 5.20G/5.20G [00:37<00:00, 148MB/s]


In [3]:
! mkdir data
! unzip -q skin-cancer-mnist-ham10000.zip -d data

In [4]:
! mkdir models

In [8]:
# Import Libraries
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from PIL import Image

import torch
import torch.nn.functional as Fun
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.metrics import classification_report, confusion_matrix

from transformers import ViTFeatureExtractor, ViTForImageClassification, AdamW, get_linear_schedule_with_warmup

# Set Device: use the GPU when one is visible to PyTorch, otherwise the CPU
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print(f"Using device: {DEVICE}")


Using device: cuda


In [6]:
# Constants
# change if required
DATA_DIR = 'data/'

# The metadata CSV and the two image folders all live directly under DATA_DIR
METADATA_FILE = os.path.join(DATA_DIR, 'HAM10000_metadata.csv')
IMAGE_DIRS = [
    os.path.join(DATA_DIR, folder)
    for folder in ('HAM10000_images_part_1', 'HAM10000_images_part_2')
]


# Load Metadata and preview the first rows
metadata = pd.read_csv(METADATA_FILE)
metadata.head()



Unnamed: 0,lesion_id,image_id,dx,dx_type,age,sex,localization
0,HAM_0000118,ISIC_0027419,bkl,histo,80.0,male,scalp
1,HAM_0000118,ISIC_0025030,bkl,histo,80.0,male,scalp
2,HAM_0002730,ISIC_0026769,bkl,histo,80.0,male,scalp
3,HAM_0002730,ISIC_0025661,bkl,histo,80.0,male,scalp
4,HAM_0001466,ISIC_0031633,bkl,histo,75.0,male,ear


In [18]:
# Summarise the PyTorch build, plus CUDA / cuDNN / GPU details when a GPU is visible
environment_report = [
    f"PyTorch Version: {torch.__version__}",
    f"CUDA Available: {torch.cuda.is_available()}",
]

if torch.cuda.is_available():
    environment_report += [
        f"CUDA Version: {torch.version.cuda}",
        f"cuDNN Version: {torch.backends.cudnn.version()}",
        f"GPU Name: {torch.cuda.get_device_name(0)}",
        f"GPU Count: {torch.cuda.device_count()}",
    ]

print("\n".join(environment_report))

PyTorch Version: 2.4.1+cu121
CUDA Available: True
CUDA Version: 12.1
cuDNN Version: 90100
GPU Name: NVIDIA A100-SXM4-40GB
GPU Count: 1


## Fusion Models: Training with Both Modalities


### Data Preparation


In [9]:
# Data Cleaning and Encoding for Metadata
# Assign the filled columns back instead of calling fillna(..., inplace=True)
# on a column selection: that chained form operates on an intermediate object
# and, per the pandas warning it triggers, stops updating the frame in pandas 3.0.
# NOTE(review): the median, the label encoders and the scaler below are all fit
# on the full table, before the train/val/test split done later in the notebook
# — confirm that this leakage of validation/test statistics is acceptable.
metadata['age'] = metadata['age'].fillna(metadata['age'].median())
metadata['sex'] = metadata['sex'].fillna('unknown')
metadata['localization'] = metadata['localization'].fillna('unknown')

# Encoding Categorical Features
# One encoder per column so each mapping can be inverted independently later.
le_dx = LabelEncoder()
metadata['dx_encoded'] = le_dx.fit_transform(metadata['dx'])

le_sex = LabelEncoder()
metadata['sex_encoded'] = le_sex.fit_transform(metadata['sex'])

le_loc = LabelEncoder()
metadata['localization_encoded'] = le_loc.fit_transform(metadata['localization'])

# Normalize Numerical Features
# 'age' is overwritten with its standardized value (zero mean, unit variance).
scaler = StandardScaler()
metadata['age'] = scaler.fit_transform(metadata[['age']])

# Define Feature Columns
metadata_features = ['age', 'sex_encoded', 'localization_encoded']

# Define Number of Classes
num_classes = len(le_dx.classes_)
print(f"Number of classes: {num_classes}")
print(le_dx.classes_)

Number of classes: 7
['akiec' 'bcc' 'bkl' 'df' 'mel' 'nv' 'vasc']


The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  metadata['age'].fillna(metadata['age'].median(), inplace=True)
The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  metadata['sex'].fillna('unknown', inplace=True)
The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we a

In [10]:
# Preview the table after cleaning: 'age' now holds the standardized value and
# the dx_encoded / sex_encoded / localization_encoded columns have been added.
metadata.head()

Unnamed: 0,lesion_id,image_id,dx,dx_type,age,sex,localization,dx_encoded,sex_encoded,localization_encoded
0,HAM_0000118,ISIC_0027419,bkl,histo,1.663522,male,scalp,2,1,11
1,HAM_0000118,ISIC_0025030,bkl,histo,1.663522,male,scalp,2,1,11
2,HAM_0002730,ISIC_0026769,bkl,histo,1.663522,male,scalp,2,1,11
3,HAM_0002730,ISIC_0025661,bkl,histo,1.663522,male,scalp,2,1,11
4,HAM_0001466,ISIC_0031633,bkl,histo,1.368014,male,ear,2,1,4


### Modified Dataset Class (returns image, metadata, and label)


In [11]:
class SkinCancerMultimodalDataset(Dataset):
    """Dataset that yields ``(image, metadata, label)`` for each row of a dataframe.

    Args:
        dataframe: Table with an ``image_id`` column, a ``dx_encoded`` column
            and every column named in ``metadata_features``.
        image_dirs: Directories searched, in order, for ``<image_id>.jpg``.
        feature_extractor: Stored on the instance; not used when fetching items.
        metadata_features: Names of the columns returned as the metadata vector.
        transform: Optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, dataframe, image_dirs, feature_extractor, metadata_features, transform=None):
        self.dataframe = dataframe
        self.image_dirs = image_dirs
        self.feature_extractor = feature_extractor
        self.transform = transform
        self.metadata_features = metadata_features

    def __len__(self):
        return len(self.dataframe)

    def _find_image_path(self, image_id):
        """Return the first existing ``<image_id>.jpg`` across ``image_dirs``.

        Raises:
            FileNotFoundError: If no directory contains the file (including
                the case where ``image_dirs`` is empty).
        """
        filename = f"{image_id}.jpg"
        for directory in self.image_dirs:
            candidate = os.path.join(directory, filename)
            if os.path.exists(candidate):
                return candidate
        raise FileNotFoundError(
            f"Image '{filename}' not found in any of: {list(self.image_dirs)}"
        )

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        image_path = self._find_image_path(row['image_id'])
        # convert() returns a new in-memory image, so the file handle can be
        # closed right away instead of lingering in long-lived loader workers.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        metadata = torch.tensor(row[self.metadata_features].values.astype(np.float32))
        # Class indices are made explicitly int64, the dtype cross-entropy expects.
        label = torch.tensor(int(row['dx_encoded']), dtype=torch.long)
        return image, metadata, label

### Transform Pipeline

In [12]:
# Initialize Feature Extractor
# NOTE(review): recent transformers releases are understood to deprecate
# ViTFeatureExtractor in favour of ViTImageProcessor — confirm against the
# installed version.
MODEL_NAME = 'google/vit-base-patch32-224-in21k'  # ViT-B_32
feature_extractor = ViTFeatureExtractor.from_pretrained(MODEL_NAME)


def _resize_tensorize_normalize():
    """Build the steps shared by both pipelines: resize, to-tensor, normalize."""
    return [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
    ]


# Training adds a random flip and rotation in front of the shared steps;
# validation/test use the shared steps only.
train_transform = transforms.Compose(
    [transforms.RandomHorizontalFlip(), transforms.RandomRotation(20)]
    + _resize_tensorize_normalize()
)

val_test_transform = transforms.Compose(_resize_tensorize_normalize())

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]



### Creating Multimodal Datasets and DataLoaders

In [13]:
# Split Data into Train, Validation, and Test Sets
# 20% of the rows are held out for testing, then 10% of the remainder becomes
# the validation set; both splits are stratified on the diagnosis column.
# NOTE(review): the split is per row (per image), while the metadata preview
# shows several image_ids sharing one lesion_id — images of the same lesion
# can therefore land in different splits; confirm whether a lesion-grouped
# split is wanted.
train_val_df, test_df = train_test_split(
    metadata, test_size=0.2, stratify=metadata['dx'], random_state=42
)
train_df, val_df = train_test_split(
    train_val_df, test_size=0.1, stratify=train_val_df['dx'], random_state=42
)

print(
    f"Training samples: {len(train_df)}",
    f"Validation samples: {len(val_df)}",
    f"Testing samples: {len(test_df)}",
    sep="\n",
)


Training samples: 7210
Validation samples: 802
Testing samples: 2003


In [14]:
# Creating Multimodal Datasets and DataLoaders
def _build_dataset(frame, transform):
    """Wrap one split in a SkinCancerMultimodalDataset with the shared settings."""
    return SkinCancerMultimodalDataset(
        frame, IMAGE_DIRS, feature_extractor, metadata_features, transform=transform
    )


train_dataset = _build_dataset(train_df, train_transform)
val_dataset = _build_dataset(val_df, val_test_transform)
test_dataset = _build_dataset(test_df, val_test_transform)

# NOTE(review): with 7210 training rows this batch size gives only 8 optimizer
# steps per epoch, and the logged validation accuracy stays at 66.96% from
# epoch 2 onward — consider whether a smaller batch size was intended.
BATCH_SIZE = 970

# Only the training loader shuffles; all other settings are shared.
_loader_settings = dict(batch_size=BATCH_SIZE, num_workers=4, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_settings)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_settings)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_settings)

print(
    f"Number of training batches: {len(train_loader)}",
    f"Number of validation batches: {len(val_loader)}",
    f"Number of testing batches: {len(test_loader)}",
    sep="\n",
)


Number of training batches: 8
Number of validation batches: 1
Number of testing batches: 3


### Early Fusion Model (Feature-Level Fusion)

In [15]:
# Implementation

import torch.nn as nn
from transformers import ViTModel

class EarlyFusionModel(nn.Module):
    """Classifier over the concatenation of a ViT CLS embedding and projected metadata.

    The metadata vector is passed through a linear layer (to 128 units) and a
    ReLU, concatenated with the first-token embedding of the ViT's last hidden
    state, regularised with dropout (p=0.3) and mapped to class logits.

    Args:
        num_metadata_features: Length of the metadata vector per sample.
        num_classes: Number of output logits.
        pretrained_model_name: Checkpoint name passed to ``ViTModel.from_pretrained``.
    """

    def __init__(self, num_metadata_features, num_classes, pretrained_model_name='google/vit-base-patch32-224-in21k'):
        super().__init__()
        metadata_dim = 128
        self.vit = ViTModel.from_pretrained(pretrained_model_name)
        self.metadata_fc = nn.Linear(num_metadata_features, metadata_dim)
        self.classifier = nn.Linear(self.vit.config.hidden_size + metadata_dim, num_classes)
        self.dropout = nn.Dropout(0.3)

    def forward(self, image, metadata):
        """Return the class logits for a batch of images and metadata vectors."""
        cls_embedding = self.vit(image).last_hidden_state[:, 0, :]  # CLS token
        metadata_embedding = torch.relu(self.metadata_fc(metadata))
        fused = torch.cat((cls_embedding, metadata_embedding), dim=1)
        return self.classifier(self.dropout(fused))


# Integration

# Initialize Early Fusion Model
num_metadata_features = len(metadata_features)
model_early_fusion = EarlyFusionModel(num_metadata_features=num_metadata_features, num_classes=num_classes)
model_early_fusion.to(DEVICE)

# Define Optimizer and Scheduler
EPOCHS = 10
# Use PyTorch's AdamW: the transformers implementation is deprecated in favour
# of torch.optim.AdamW. eps and weight_decay are set explicitly so the update
# rule matches the values the transformers optimizer used by default.
# NOTE(review): the top import cell still imports AdamW from transformers;
# that import is unused after this change and can be dropped once confirmed.
optimizer_early = torch.optim.AdamW(
    model_early_fusion.parameters(), lr=2e-5, eps=1e-6, weight_decay=0.0
)
# Linear warmup over the first 10% of all optimizer steps, then linear decay.
total_training_steps = len(train_loader) * EPOCHS
scheduler_early = get_linear_schedule_with_warmup(
    optimizer_early,
    num_warmup_steps=int(0.1 * total_training_steps),
    num_training_steps=total_training_steps
)

config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/352M [00:00<?, ?B/s]



### Train/eval loop


In [16]:
def train_epoch(model, dataloader, optimizer, scheduler, device, loss_fn, input_types=('image',)):
    """
    Trains a model for one epoch, handling various input types.

    Args:
        model (nn.Module): The neural network model to train.
        dataloader (DataLoader): Yields (images, labels) batches, or
            (images, metadata, labels) when 'metadata' is in input_types.
        optimizer (torch.optim.Optimizer): Optimizer for updating model parameters.
        scheduler (torch.optim.lr_scheduler): Learning rate scheduler, stepped once per batch.
        device (torch.device): Device to perform training on (CPU or GPU).
        loss_fn (nn.Module): Loss function, e.g., CrossEntropyLoss. Used whenever
            the model output does not carry a loss of its own.
        input_types (sequence): Types of inputs the model expects
            (e.g., ['image'], ['image', 'metadata']).

    Returns:
        tuple: (epoch_loss, epoch_acc), both averaged over the samples seen.
    """
    model.train()
    use_metadata = 'metadata' in input_types
    running_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, data in enumerate(tqdm(dataloader, desc="Training")):
        # Unpack and move data to device
        if use_metadata:
            images, metadata, labels = data
            metadata = metadata.to(device)
        else:
            images, labels = data
            metadata = None
        images = images.to(device)
        labels = labels.to(device)

        # Verify devices (first batch only)
        if batch_idx == 0:
            print(f"Images are on device: {images.device}")
            if use_metadata:
                print(f"Metadata are on device: {metadata.device}")
            print(f"Labels are on device: {labels.device}")
            print(f"Model is on device: {next(model.parameters()).device}")

        # Zero gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(images, metadata) if use_metadata else model(images)

        # Compute loss. Output objects may expose a `loss` attribute that is
        # None when no labels were passed to the model, so fall back to
        # loss_fn whenever no usable loss is attached.
        logits = outputs.logits if hasattr(outputs, 'logits') else outputs
        loss = getattr(outputs, 'loss', None)
        if loss is None:
            loss = loss_fn(logits, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Accumulate loss and accuracy, weighting the batch-mean loss by batch size
        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += batch_size

    epoch_loss = running_loss / total
    epoch_acc = correct / total
    return epoch_loss, epoch_acc


def eval_epoch(model, dataloader, device, loss_fn, input_types=('image',)):
    """
    Evaluates a model on a validation or test set, handling various input types.

    Args:
        model (nn.Module): The neural network model to evaluate.
        dataloader (DataLoader): Yields (images, labels) batches, or
            (images, metadata, labels) when 'metadata' is in input_types.
        device (torch.device): Device to perform evaluation on (CPU or GPU).
        loss_fn (nn.Module): Loss function, e.g., CrossEntropyLoss. Used whenever
            the model output does not carry a loss of its own.
        input_types (sequence): Types of inputs the model expects
            (e.g., ['image'], ['image', 'metadata']).

    Returns:
        tuple: (epoch_loss, epoch_acc), both averaged over the samples seen.
    """
    model.eval()
    use_metadata = 'metadata' in input_types
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch_idx, data in enumerate(tqdm(dataloader, desc="Evaluating")):

            # Unpack and move data to device
            if use_metadata:
                images, metadata, labels = data
                metadata = metadata.to(device)
            else:
                images, labels = data
                metadata = None
            images = images.to(device)
            labels = labels.to(device)

            # Verify devices (first batch only)
            if batch_idx == 0:
                print(f"Images are on device: {images.device}")
                if use_metadata:
                    print(f"Metadata are on device: {metadata.device}")
                print(f"Labels are on device: {labels.device}")
                print(f"Model is on device: {next(model.parameters()).device}")

            # Forward pass. The tensors unpacked above are passed directly,
            # using the same call shape as train_epoch.
            outputs = model(images, metadata) if use_metadata else model(images)

            # Compute loss. Output objects may expose a `loss` attribute that
            # is None when no labels were passed to the model, so fall back to
            # loss_fn whenever no usable loss is attached.
            logits = outputs.logits if hasattr(outputs, 'logits') else outputs
            loss = getattr(outputs, 'loss', None)
            if loss is None:
                loss = loss_fn(logits, labels)

            # Accumulate loss and accuracy, weighting the batch-mean loss by batch size
            batch_size = labels.size(0)
            running_loss += loss.item() * batch_size
            preds = logits.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += batch_size

    epoch_loss = running_loss / total
    epoch_acc = correct / total
    return epoch_loss, epoch_acc


In [20]:
# Define the loss function
criterion = nn.CrossEntropyLoss()

# Training Loop: one training pass and one validation pass per epoch; the
# weights are saved whenever validation accuracy strictly improves.
best_val_acc = 0.0
fusion_inputs = ['image', 'metadata']  # the fusion model consumes both modalities

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    print("-" * 10)

    # Training Phase
    train_loss, train_acc = train_epoch(
        model_early_fusion,
        train_loader,
        optimizer_early,
        scheduler_early,
        DEVICE,
        criterion,
        input_types=fusion_inputs,
    )
    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}%")

    # Evaluation Phase
    val_loss, val_acc = eval_epoch(
        model_early_fusion,
        val_loader,
        DEVICE,
        criterion,
        input_types=fusion_inputs,
    )
    print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc*100:.2f}%")

    # Checkpointing: nothing to save unless this epoch beat the best so far
    if val_acc <= best_val_acc:
        continue
    best_val_acc = val_acc
    os.makedirs('models/early_fusion', exist_ok=True)
    torch.save(model_early_fusion.state_dict(), 'models/early_fusion/best_model.pth')
    print("Model checkpoint saved.")


Epoch 1/10
----------


Training:   0%|          | 0/8 [00:00<?, ?it/s]

Inputs are on device: [device(type='cpu'), device(type='cpu')]
Labels are on device: cuda:0
Model is on device: cuda:0


Training: 100%|██████████| 8/8 [00:26<00:00,  3.37s/it]


Train Loss: 1.6345 | Train Acc: 52.22%


Evaluating:   0%|          | 0/1 [00:00<?, ?it/s]

Inputs are on device: [device(type='cpu'), device(type='cpu')]
Labels are on device: cuda:0
Model is on device: cuda:0


Evaluating: 100%|██████████| 1/1 [00:08<00:00,  8.41s/it]


Val Loss: 1.4739 | Val Acc: 67.08%
Model checkpoint saved.

Epoch 2/10
----------


Training:   0%|          | 0/8 [00:00<?, ?it/s]

Inputs are on device: [device(type='cpu'), device(type='cpu')]
Labels are on device: cuda:0
Model is on device: cuda:0


Training: 100%|██████████| 8/8 [00:26<00:00,  3.37s/it]


Train Loss: 1.3798 | Train Acc: 65.08%


Evaluating:   0%|          | 0/1 [00:00<?, ?it/s]

Inputs are on device: [device(type='cpu'), device(type='cpu')]
Labels are on device: cuda:0
Model is on device: cuda:0


Evaluating: 100%|██████████| 1/1 [00:08<00:00,  8.51s/it]


Val Loss: 1.3061 | Val Acc: 66.96%

Epoch 3/10
----------


Training:   0%|          | 0/8 [00:00<?, ?it/s]

Inputs are on device: [device(type='cpu'), device(type='cpu')]
Labels are on device: cuda:0
Model is on device: cuda:0


Training: 100%|██████████| 8/8 [00:27<00:00,  3.44s/it]


Train Loss: 1.2496 | Train Acc: 66.74%


Evaluating:   0%|          | 0/1 [00:00<?, ?it/s]

Inputs are on device: [device(type='cpu'), device(type='cpu')]
Labels are on device: cuda:0
Model is on device: cuda:0


Evaluating: 100%|██████████| 1/1 [00:08<00:00,  8.67s/it]


Val Loss: 1.2266 | Val Acc: 66.96%

Epoch 4/10
----------


Training:   0%|          | 0/8 [00:00<?, ?it/s]

Inputs are on device: [device(type='cpu'), device(type='cpu')]
Labels are on device: cuda:0
Model is on device: cuda:0


Training: 100%|██████████| 8/8 [00:27<00:00,  3.43s/it]


Train Loss: 1.1922 | Train Acc: 66.89%


Evaluating:   0%|          | 0/1 [00:00<?, ?it/s]

Inputs are on device: [device(type='cpu'), device(type='cpu')]
Labels are on device: cuda:0
Model is on device: cuda:0


Evaluating: 100%|██████████| 1/1 [00:08<00:00,  8.48s/it]


Val Loss: 1.1796 | Val Acc: 66.96%

Epoch 5/10
----------


Training:   0%|          | 0/8 [00:00<?, ?it/s]

Inputs are on device: [device(type='cpu'), device(type='cpu')]
Labels are on device: cuda:0
Model is on device: cuda:0


Training: 100%|██████████| 8/8 [00:27<00:00,  3.39s/it]


Train Loss: 1.1430 | Train Acc: 66.95%


Evaluating:   0%|          | 0/1 [00:00<?, ?it/s]

Inputs are on device: [device(type='cpu'), device(type='cpu')]
Labels are on device: cuda:0
Model is on device: cuda:0


Evaluating: 100%|██████████| 1/1 [00:08<00:00,  8.65s/it]


Val Loss: 1.1456 | Val Acc: 66.96%

Epoch 6/10
----------


Training:   0%|          | 0/8 [00:00<?, ?it/s]

Inputs are on device: [device(type='cpu'), device(type='cpu')]
Labels are on device: cuda:0
Model is on device: cuda:0


Training: 100%|██████████| 8/8 [00:27<00:00,  3.40s/it]


Train Loss: 1.1177 | Train Acc: 66.89%


Evaluating:   0%|          | 0/1 [00:00<?, ?it/s]

Inputs are on device: [device(type='cpu'), device(type='cpu')]
Labels are on device: cuda:0
Model is on device: cuda:0


Evaluating: 100%|██████████| 1/1 [00:08<00:00,  8.50s/it]


Val Loss: 1.1188 | Val Acc: 66.96%

Epoch 7/10
----------


Training:   0%|          | 0/8 [00:00<?, ?it/s]

Inputs are on device: [device(type='cpu'), device(type='cpu')]
Labels are on device: cuda:0
Model is on device: cuda:0


Training: 100%|██████████| 8/8 [00:27<00:00,  3.42s/it]


Train Loss: 1.0921 | Train Acc: 66.92%


Evaluating:   0%|          | 0/1 [00:00<?, ?it/s]

Inputs are on device: [device(type='cpu'), device(type='cpu')]
Labels are on device: cuda:0
Model is on device: cuda:0


Evaluating: 100%|██████████| 1/1 [00:08<00:00,  8.52s/it]


Val Loss: 1.1006 | Val Acc: 66.96%

Epoch 8/10
----------


Training:   0%|          | 0/8 [00:00<?, ?it/s]

Inputs are on device: [device(type='cpu'), device(type='cpu')]
Labels are on device: cuda:0
Model is on device: cuda:0


Training: 100%|██████████| 8/8 [00:27<00:00,  3.45s/it]


Train Loss: 1.0730 | Train Acc: 67.02%


Evaluating:   0%|          | 0/1 [00:00<?, ?it/s]

Inputs are on device: [device(type='cpu'), device(type='cpu')]
Labels are on device: cuda:0
Model is on device: cuda:0


Evaluating: 100%|██████████| 1/1 [00:08<00:00,  8.55s/it]


Val Loss: 1.0885 | Val Acc: 66.96%

Epoch 9/10
----------


Training:   0%|          | 0/8 [00:00<?, ?it/s]

Inputs are on device: [device(type='cpu'), device(type='cpu')]
Labels are on device: cuda:0
Model is on device: cuda:0


Training: 100%|██████████| 8/8 [00:27<00:00,  3.41s/it]


Train Loss: 1.0680 | Train Acc: 66.96%


Evaluating:   0%|          | 0/1 [00:00<?, ?it/s]

Inputs are on device: [device(type='cpu'), device(type='cpu')]
Labels are on device: cuda:0
Model is on device: cuda:0


Evaluating: 100%|██████████| 1/1 [00:08<00:00,  8.42s/it]


Val Loss: 1.0846 | Val Acc: 66.96%

Epoch 10/10
----------


Training:   0%|          | 0/8 [00:00<?, ?it/s]

Inputs are on device: [device(type='cpu'), device(type='cpu')]
Labels are on device: cuda:0
Model is on device: cuda:0


Training: 100%|██████████| 8/8 [00:27<00:00,  3.39s/it]


Train Loss: 1.0670 | Train Acc: 66.99%


Evaluating:   0%|          | 0/1 [00:00<?, ?it/s]

Inputs are on device: [device(type='cpu'), device(type='cpu')]
Labels are on device: cuda:0
Model is on device: cuda:0


Evaluating: 100%|██████████| 1/1 [00:08<00:00,  8.45s/it]

Val Loss: 1.0846 | Val Acc: 66.96%





In [None]:
def plot_confusion_matrix(model, dataloader, device, class_names):
    """Plot the confusion matrix of a multimodal model over a dataloader.

    Args:
        model (nn.Module): Model called as ``model(images, metadata)`` and
            returning a logits tensor.
        dataloader: Iterable of ``(images, metadata, labels)`` batches.
        device (torch.device): Device on which inference is run.
        class_names (sequence): Display names where ``class_names[i]`` is the
            name of label index ``i``.

    Raises:
        ValueError: If a true or predicted label index has no entry in
            ``class_names``.
    """
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, metadata, labels in dataloader:
            images = images.to(device)
            metadata = metadata.to(device)

            outputs = model(images, metadata)
            preds = outputs.argmax(dim=1)

            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    class_names = list(class_names)
    highest_index = max(all_labels + all_preds, default=-1)
    if highest_index >= len(class_names):
        raise ValueError(
            f"class_names has {len(class_names)} entries but label index "
            f"{highest_index} was encountered"
        )

    # Fix the label set explicitly so the matrix always has one row/column per
    # class name, even when a class is absent from this dataloader.
    cm = confusion_matrix(all_labels, all_preds, labels=list(range(len(class_names))))

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', xticklabels=class_names, yticklabels=class_names,
                cmap='Blues', ax=ax)
    ax.set_xlabel('Predicted Label')
    ax.set_ylabel('True Label')
    ax.set_title('Confusion Matrix')
    plt.show()

# Take the class names from the fitted label encoder so that they match the
# integer labels: le_dx.classes_[i] is the diagnosis string encoded as i
# (7 classes: akiec, bcc, bkl, df, mel, nv, vasc).
class_names = list(le_dx.classes_)

plot_confusion_matrix(model_early_fusion, val_loader, DEVICE, class_names)