## Model Training with Grad-CAM:

The purpose of this code is to verify the Xception model's behavior during training. The function "train_model_with_gradcam" integrates a Grad-CAM (Gradient-weighted Class Activation Mapping) layer into the training process, enabling the generation of Grad-CAM images during validation. By saving these images, we can monitor the model's focus over time, ensuring that it learns to concentrate on skin lesions rather than irrelevant artifacts.

Grad-CAM captures activations and gradients from a designated layer of the Xception model to compute and visualize class activation maps. This provides insights into the model's decision-making process and helps verify that the model is focusing on relevant features.

The code also creates a GIF from the saved Grad-CAM images, providing a dynamic view of the model's focus across different training stages. This visual representation confirms that the model consistently pays attention to the relevant areas, enhancing the interpretability and reliability of the model's outputs.

In [1]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler, Subset
import pandas as pd
import numpy as np
from torchvision import transforms as T
from PIL import Image
import timm
import copy
import os
import cv2
import imageio
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import glob

# Custom dataset class
class CustomDDIDataset(Dataset):
    def __init__(self, csv_file, root_dir, transform=None):
        self.ddi_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        self.ddi_frame['malignant'] = self.ddi_frame['malignant'].astype(int)

    def __len__(self):
        return len(self.ddi_frame)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.ddi_frame.iloc[idx]['DDI_file'])
        image = Image.open(img_name).convert('RGB')
        label = self.ddi_frame.iloc[idx]['malignant']
        if self.transform:
            image = self.transform(image)
        label = torch.tensor(label, dtype=torch.long)
        return image, label

# Define transformations without normalization
train_transform = T.Compose([
    T.Resize(299),
    T.RandomHorizontalFlip(),
    T.RandomRotation(10),
    T.RandomResizedCrop(299),
    T.ToTensor()
])

val_transform = T.Compose([
    T.Resize(299),
    T.CenterCrop(299),
    T.ToTensor()
])
# Initialize datasets
full_dataset = CustomDDIDataset(csv_file='C:\\Users\\user\\DDI\\ddi_metadata.csv',
                                root_dir='C:\\Users\\user\\DDI\\images',
                                transform=train_transform)

train_dataset = CustomDDIDataset(csv_file='C:\\Users\\user\\DDI\\ddi_metadata.csv',
                                 root_dir='C:\\Users\\user\\DDI\\images',
                                 transform=train_transform)

val_dataset = CustomDDIDataset(csv_file='C:\\Users\\user\\DDI\\ddi_metadata.csv',
                               root_dir='C:\\Users\\user\\DDI\\images',
                               transform=val_transform)

# Handle class imbalance
ddi_df = pd.read_csv('C:\\Users\\user\\DDI\\ddi_metadata.csv')
labels = ddi_df['malignant'].values
class_sample_count = torch.tensor([
    (labels == 0).sum(),
    (labels == 1).sum()
])

class_weights = 1. / class_sample_count.float()
samples_weights = np.array([class_weights[int(label)] for label in labels])  # Ensure label is int
samples_weights = torch.tensor(samples_weights)

sampler = WeightedRandomSampler(weights=samples_weights, num_samples=len(samples_weights), replacement=True)

# Split dataset
train_idx, val_test_idx = train_test_split(range(len(full_dataset)), test_size=0.4, random_state=42)
val_idx, test_idx = train_test_split(val_test_idx, test_size=0.5, random_state=42)

train_subset = Subset(full_dataset, train_idx)
val_subset = Subset(full_dataset, val_idx)
test_subset = Subset(full_dataset, test_idx)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=32, sampler=sampler, shuffle=False)
val_loader = DataLoader(val_subset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_subset, batch_size=32, shuffle=False)

# Grad-CAM class with register_full_backward_hook and handling division by zero
class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        self.model.eval()
        self.hook_layers()

    def hook_layers(self):
        def forward_hook(module, input, output):
            self.activations = output

        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0]

        self.target_layer.register_forward_hook(forward_hook)
        self.target_layer.register_full_backward_hook(backward_hook)

    def generate_cam(self, input_image, target_class=None, epsilon=1e-8):
        output = self.model(input_image)
        self.model.zero_grad()
        if target_class is None:
            target_class = output.argmax().item()
        target = output[:, target_class]
        target.backward()
        
        gradients = self.gradients.data.numpy()[0]
        activations = self.activations.data.numpy()[0]
        
        weights = np.mean(gradients, axis=(1, 2))
        cam = np.zeros(activations.shape[1:], dtype=np.float32)
        
        for i, w in enumerate(weights):
            cam += w * activations[i]
        
        cam = np.maximum(cam, 0)
        cam = cv2.resize(cam, (input_image.shape[2], input_image.shape[3]))
        cam = cam - np.min(cam)
        cam = cam / (np.max(cam) + epsilon)  # Avoid division by zero
        return cam

def apply_colormap_on_image(org_im, activation, colormap_name='jet'):
    color_map = cv2.applyColorMap(np.uint8(255 * activation), cv2.COLORMAP_JET)
    color_map = np.float32(color_map) / 255
    cam = color_map + np.float32(org_im)
    cam = cam / np.max(cam)
    return np.uint8(255 * cam)

def show_cam_on_image(img, mask):
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    cam = heatmap + np.float32(img)
    cam = cam / np.max(cam)
    return np.uint8(255 * cam)


# Training function with Grad-CAM saving
def train_model_with_gradcam(model, criterion, optimizer, scheduler, train_loader, val_loader, num_epochs=15, patience=3):
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    target_layer = model.conv4  # Choose an appropriate layer from the model
    grad_cam = GradCAM(model, target_layer)
    
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
                dataloader = train_loader
            else:
                model.eval()
                dataloader = val_loader

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / len(dataloader.dataset)
            epoch_acc = running_corrects.double() / len(dataloader.dataset)

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # Save Grad-CAM images
            if phase == 'val':
                for i, (inputs, labels) in enumerate(dataloader):
                    inputs = inputs.to(device)
                    img_tensor = inputs[0].unsqueeze(0)  # Take the first image
                    cam = grad_cam.generate_cam(img_tensor)
                    cam_image = show_cam_on_image(inputs[0].cpu().numpy().transpose(1, 2, 0), cam)
                    Image.fromarray(cam_image).save(f'gradcam_epoch_{epoch}_image_{i}.png')
                    if i == 5:  # Save only a few images per epoch for example
                        break

            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

    print(f'Best val Acc: {best_acc:4f}')
    model.load_state_dict(best_model_wts)
    return model

# Create GIF
def create_gif(output_path, pattern='gradcam_epoch_*.png'):
    images = []
    for filename in sorted(glob.glob(pattern), key=os.path.getmtime):
        images.append(imageio.imread(filename))
    imageio.mimsave(output_path, images, fps=2)

# Initialize model, criterion, optimizer, and scheduler
xception_model = timm.create_model('legacy_xception', pretrained=True)  # Use the correct model name
num_ftrs = xception_model.fc.in_features  # The final fully connected layer
xception_model.fc = nn.Linear(num_ftrs, 2)  # Replace with a new fully connected layer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
xception_model = xception_model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(xception_model.parameters(), lr=0.0001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# Train the model and generate Grad-CAM images
Xception_best_model = train_model_with_gradcam(xception_model, criterion, optimizer, scheduler, train_loader, val_loader, num_epochs=15, patience=3)

# Create GIF from saved images
create_gif('gradcam_model_training.gif')

Epoch 0/14
----------
train Loss: 0.6349 Acc: 0.6707
val Loss: 0.5501 Acc: 0.7557
Epoch 1/14
----------
train Loss: 0.5296 Acc: 0.7500
val Loss: 0.4796 Acc: 0.7786
Epoch 2/14
----------
train Loss: 0.4336 Acc: 0.8171
val Loss: 0.3850 Acc: 0.8321
Epoch 3/14
----------
train Loss: 0.3443 Acc: 0.8567
val Loss: 0.2918 Acc: 0.8779
Epoch 4/14
----------
train Loss: 0.3082 Acc: 0.8628
val Loss: 0.2659 Acc: 0.8855
Epoch 5/14
----------
train Loss: 0.2609 Acc: 0.8765
val Loss: 0.2244 Acc: 0.9160
Epoch 6/14
----------
train Loss: 0.2561 Acc: 0.8994
val Loss: 0.1798 Acc: 0.9389
Epoch 7/14
----------
train Loss: 0.1961 Acc: 0.9192
val Loss: 0.2199 Acc: 0.9160
Epoch 8/14
----------
train Loss: 0.2204 Acc: 0.9131
val Loss: 0.1776 Acc: 0.9466
Epoch 9/14
----------
train Loss: 0.1770 Acc: 0.9375
val Loss: 0.1665 Acc: 0.9313
Epoch 10/14
----------
train Loss: 0.1850 Acc: 0.9375
val Loss: 0.1575 Acc: 0.9389
Epoch 11/14
----------
train Loss: 0.1929 Acc: 0.9314
val Loss: 0.1892 Acc: 0.9466
Epoch 12/14
--

  images.append(imageio.imread(filename))
