# Camelyon Patch Classification

## Training a CNN
We will first train a CNN on a dataset of [histopathology patches](https://en.wikipedia.org/wiki/Histopathology). This data corresponds to digitized microscopic analysis of tumor tissue, which has been divided into patches. The objective is to classify the patches into the ones containing tumor tissue, and ones not containing any tumor tissue. We will use the [PCAM dataset](https://github.com/basveeling/pcam) which consists of 96x96 pixel patches. We will only use the validation set (which contains 32768 patches and which should take about 0.8 GB of storage) in order to make the training faster.

In [None]:
import h5py
import random
%matplotlib inline
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchnet as tnt
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets.utils import download_file_from_google_drive, _decompress

The dataset is stored in `.h5` files.
The images can be downloaded from [here](https://drive.google.com/uc?export=download&id=1hgshYGWK8V-eGRy8LToWJJgDU_rXWVJ3), and the labels from [here](https://drive.google.com/uc?export=download&id=1bH8ZRbhSVAhScTS0p9-ZzGnX91cHT3uO). The downloaded files are gzip-compressed: decompress them and write their paths below.

*Uncomment the following cell and run it*

In [None]:
# You can run the following cell to download the files on colab
# base_folder = "./"
# archive_name = "camelyonpatch_level_2_split_valid_x.h5.gz"
# download_file_from_google_drive("1hgshYGWK8V-eGRy8LToWJJgDU_rXWVJ3", base_folder, filename=archive_name, md5="d5b63470df7cfa627aeec8b9dc0c066e")
# _decompress(base_folder + archive_name)

# archive_name = "camelyonpatch_level_2_split_valid_y.h5.gz"
# download_file_from_google_drive("1bH8ZRbhSVAhScTS0p9-ZzGnX91cHT3uO", base_folder, filename=archive_name, md5="2b85f58b927af9964a4c15b8f7e8f179")
# _decompress(base_folder + archive_name)

In [None]:
# Paths to the decompressed PCAM validation-split HDF5 files (images, labels).
IMAGES_PATH, LABELS_PATH = (
    "data/camelyonpatch_level_2_split_valid_x.h5",
    "data/camelyonpatch_level_2_split_valid_y.h5",
)

In [None]:
# Read both HDF5 datasets fully into memory, opening the files read-only and
# closing them once the data has been copied out.
with h5py.File(IMAGES_PATH, "r") as images_file:
    images = images_file["x"][:]
with h5py.File(LABELS_PATH, "r") as labels_file:
    # Every label entry holds exactly one value (the former per-element
    # `.item()` required that). Flatten in one vectorized step and keep the
    # int64 dtype the per-element Python-int conversion produced.
    labels = labels_file["y"][:].reshape(-1).astype(np.int64)

Now that we have the data, we want to split it into a training and a validation set. For this, we write a function which takes as input the size of the dataset and returns the indices of the training set and the indices of the validation set.

In [None]:
# Seed every RNG the notebook relies on. `random` drives get_split_indices
# (random.shuffle) and the random sample picking. torch drives DataLoader
# shuffling, the random torchvision augmentations, weight init and dropout.
# NumPy is seeded for completeness. Seeding only `random` made the split
# reproducible but not the training.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
def get_split_indices(dataset_length, train_ratio=0.7):
    """
    Randomly split the indices 0 .. dataset_length - 1 into training and validation sets.

    The shuffle uses Python's global `random` generator, so call `random.seed`
    beforehand for a reproducible split.

    arguments:
        dataset_length [int]: number of elements in the dataset (>= 0)
        train_ratio [float]: fraction of the dataset put in the training set, in [0, 1]
    returns:
        train_indices [list]: round(dataset_length*train_ratio) shuffled indices
        val_indices [list]: the remaining dataset_length - len(train_indices) indices
    raises:
        ValueError: if dataset_length is negative or train_ratio is outside [0, 1]
    """
    if dataset_length < 0:
        raise ValueError(f"dataset_length must be >= 0, got {dataset_length}")
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio}")

    indices = list(range(dataset_length))
    random.shuffle(indices)
    # Compute the split point once so both slices agree on it.
    n_train = round(dataset_length * train_ratio)
    return indices[:n_train], indices[n_train:]

In [None]:
# Default 70 / 30 split over every labelled patch.
train_indices, val_indices = get_split_indices(dataset_length=len(labels))
n_train, n_val = len(train_indices), len(val_indices)
print(f"There are {n_train} train indices and {n_val} validation indices.")

Now let's write the dataset class. Any kind of data augmentation can be added here. Note that torchvision already provides a PCAM dataset class, but for learning's sake we will code it from scratch.

In [None]:
class PCAMDataset(Dataset):
    """
    Map-style dataset over in-memory PCAM patches.

    Each item is (transformed_image, label). The image goes through
    `self.augmentation`: ToTensor + Normalize, plus Resize, colour jitter,
    random rotation and random horizontal flip for the training split. The
    label is returned as stored.

    arguments:
        data [numpy.array]: all RGB 96x96 images, indexed along the first axis
        labels [numpy.array]: corresponding labels
        train [bool]: True for the training split (enables augmentation),
            False for validation
    """

    # Per-channel mean / std passed to Normalize. These appear to be the widely
    # used ImageNet statistics rather than values computed on PCAM.
    _NORM_MEAN = [0.485, 0.456, 0.406]
    _NORM_STD = [0.229, 0.224, 0.225]

    def __init__(self, data, labels, train):
        super().__init__()
        self.data = data
        self.labels = labels
        self.train = train

        pipeline = [transforms.ToTensor()]
        if not self.train:
            # Validation: deterministic preprocessing only.
            pipeline.append(transforms.Normalize(self._NORM_MEAN, self._NORM_STD))
        else:
            pipeline.extend([
                transforms.Resize((96, 96)),
                transforms.ColorJitter(brightness=0.2, contrast=0.2),
                transforms.Normalize(self._NORM_MEAN, self._NORM_STD),
                transforms.RandomRotation(20),
                transforms.RandomHorizontalFlip(),
            ])
        self.augmentation = transforms.Compose(pipeline)

    def __len__(self):
        # One item per entry along the first axis of the image array.
        n_items = self.data.shape[0]
        return n_items

    def __getitem__(self, idx):
        image = self.data[idx]
        label = self.labels[idx]
        transformed = self.augmentation(image)
        return transformed, label

In [None]:
# Number of patches per mini-batch, shared by the training and validation loaders.
BATCH_SIZE: int = 32

In [None]:
# Training split is augmented; validation split is only normalized.
train_dataset = PCAMDataset(images[train_indices], labels[train_indices], train=True)
val_dataset = PCAMDataset(images[val_indices], labels[val_indices], train=False)

# Shuffle the training batches only; keep the validation order fixed.
train_dataloader, val_dataloader = (
    DataLoader(split, batch_size=BATCH_SIZE, shuffle=do_shuffle)
    for split, do_shuffle in ((train_dataset, True), (val_dataset, False))
)

We will now display a random sample of images that have a label of 0 (not containing any tumor tissue) and 1 (containing tumor tissue).

In [None]:
# Bucket the validation samples by label in a single pass. Every dataset access
# runs the transform pipeline, so the former two comprehensions did the work twice.
tumor_validation_samples, no_tumor_validation_samples = [], []
for sample in val_dataset:
    if sample[1] == 1:
        tumor_validation_samples.append(sample)
    elif sample[1] == 0:
        no_tumor_validation_samples.append(sample)

In [None]:
# Bucket the training samples by label in a single pass. Every access applies
# the random training augmentation, so iterating twice both doubled the work
# and drew two independent augmentations of each patch.
tumor_train_samples, no_tumor_train_samples = [], []
for sample in train_dataset:
    if sample[1] == 1:
        tumor_train_samples.append(sample)
    elif sample[1] == 0:
        no_tumor_train_samples.append(sample)

In [None]:
# Normalization statistics used by the validation preprocessing of PCAMDataset,
# needed to map the normalized tensors back to displayable colours.
_DISPLAY_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_DISPLAY_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)


def _to_displayable(img_tensor):
    """Undo Normalize and turn a (C, H, W) tensor into an (H, W, C) array in [0, 1] for imshow."""
    img = img_tensor * _DISPLAY_STD + _DISPLAY_MEAN
    # permute (not transpose(2, 0), which would also swap height and width)
    return img.clamp(0, 1).permute(1, 2, 0).numpy()


fig, axes = plt.subplots(1, 2, figsize=(10, 5))
# random.choice draws exactly like random.randint(0, len - 1) indexing.
random_tumor_sample = random.choice(tumor_validation_samples)
random_no_tumor_sample = random.choice(no_tumor_validation_samples)

axes[0].imshow(_to_displayable(random_tumor_sample[0]))
axes[0].set_title(f"Random sample containing tumor tissue :{random_tumor_sample[1]}")
axes[1].imshow(_to_displayable(random_no_tumor_sample[0]))
axes[1].set_title(f"Random sample not containing tumor tissue :{random_no_tumor_sample[1]}")
plt.show()

On these samples, tissue containing tumor seems to show coloured spots and irregularities (heterogeneous cells), whereas non-tumor tissue looks more homogeneous, without specific details. This is only an impression from a couple of random samples.

Now we will plot the distribution of class labels in the training and validation datasets, to see how well the classes are balanced.

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(10, 5))

# Shared styling for both label-distribution pies.
pie_style = dict(
    labels=["Tumor tissue", "No Tumor tissue"],
    autopct='%.2f%%',
    shadow=True,
    colors=["gold", "peru"],
)
split_counts = [
    ("Training", tumor_train_samples, no_tumor_train_samples),
    ("Validation", tumor_validation_samples, no_tumor_validation_samples),
]
for ax, (split_name, tumor_part, no_tumor_part) in zip(axes, split_counts):
    ax.pie([len(tumor_part), len(no_tumor_part)], **pie_style)
    ax.set_title(f"Label distribution in {split_name} Dataset")
plt.show()

Let's write our first CNN model:

In [None]:
class ConvNet(nn.Module):
    """
    Small CNN for binary classification of square image patches.

    Architecture: conv(in->32) -> conv(32->64) -> max-pool ->
    conv(64->128) -> conv(128->256) -> max-pool -> fc(->512) -> dropout -> fc(->1).
    All convolutions are 3x3 with stride 1 and padding 1, so they keep the
    spatial size. Each 2x2/stride-2 max-pool halves it (floor). The forward
    pass returns sigmoid probabilities of shape (batch, 1).

    arguments:
        in_channels [int]: number of input channels (3 for RGB)
        input_size [int]: height and width of the square inputs. It fixes the
            size of fc1, 256 * (input_size // 4) ** 2. The default 96 matches
            the PCAM patches and reproduces the original hard-coded 256*24*24.
        dropout [float]: dropout probability applied after fc1
    raises:
        ValueError: if input_size < 4 (the two pools would leave no pixels)
    """

    def __init__(self, in_channels=3, input_size=96, dropout=0.5):
        super(ConvNet, self).__init__()
        if input_size < 4:
            raise ValueError(f"input_size must be >= 4, got {input_size}")

        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1)

        # Two flooring halvings: floor(floor(n / 2) / 2) == n // 4.
        pooled_size = input_size // 4
        self.fc1 = nn.Linear(256 * pooled_size * pooled_size, 512)
        self.fc2 = nn.Linear(512, 1)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.pool(x)

        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = self.pool(x)

        x = torch.flatten(x, 1)

        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        # Probabilities, to be paired with nn.BCELoss.
        return torch.sigmoid(x)

In [None]:
# Run on the GPU when one is visible, otherwise on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print(f'Device used : {device}')
model = ConvNet()
model = model.to(device)

We now initialize the training hyperparameters and code the whole training loop, where the model is validated after each epoch:

In [None]:
# Optimisation set-up for the CNN.
num_epochs = 20
lr = 1e-3
# The model already outputs sigmoid probabilities, hence BCELoss (not BCEWithLogitsLoss).
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# Two-class confusion matrix, reset and refilled for every train/val pass.
metric = tnt.meter.ConfusionMeter(2)

In [None]:
# Per-epoch history, plotted in the next cell.
total_train_losses = []
total_val_losses = []
total_train_accuracies = []
total_val_accuracies = []


def _confusion_accuracy(conf):
    """Return the accuracy in percent of a confusion matrix (correct predictions on the diagonal)."""
    return np.trace(conf) / float(conf.sum()) * 100


for epoch in range(1, num_epochs + 1):
    ## TRAINING ##
    model.train()
    train_losses = []
    metric.reset()
    print(f'Epoch: {epoch}/{num_epochs}')

    for img_batch, lbl_batch in tqdm(train_dataloader):
        # BCELoss needs float targets shaped like the (batch, 1) outputs.
        targets = lbl_batch.float().unsqueeze(1).to(device)

        optimizer.zero_grad()
        outputs = model(img_batch.to(device))
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())

        # Hard 0/1 predictions by thresholding the probability at 0.5.
        pred = (outputs.detach() > 0.5).long().squeeze(1)
        # Pass 1-D NumPy arrays: torchnet's ConfusionMeter squeezes tensor
        # inputs, which would collapse a final batch of size 1 to a scalar.
        metric.add(pred.cpu().numpy(), lbl_batch.long().numpy())

    train_loss = np.mean(train_losses)
    total_train_losses.append(train_loss)
    train_acc = _confusion_accuracy(metric.conf)
    total_train_accuracies.append(train_acc)

    ## VALIDATION ##
    model.eval()
    val_losses = []
    metric.reset()

    # Evaluation needs no gradients: skip building autograd graphs.
    with torch.no_grad():
        for img_batch, lbl_batch in tqdm(val_dataloader):
            outputs = model(img_batch.float().to(device))
            loss = criterion(outputs, lbl_batch.float().unsqueeze(1).to(device))
            val_losses.append(loss.item())

            pred = (outputs > 0.5).long().squeeze(1)
            metric.add(pred.cpu().numpy(), lbl_batch.long().numpy())

    val_loss = np.mean(val_losses)
    total_val_losses.append(val_loss)
    val_acc = _confusion_accuracy(metric.conf)
    total_val_accuracies.append(val_acc)

    print('Confusion Matrix:')
    print(metric.conf)
    print(f"Train Loss : {train_loss}, Train Accuracy: {train_acc}")
    print(f"Validation Loss : {val_loss}, Validation Accuracy: {val_acc}")

Now we can plot the learning curves to check that the model is not overfitting.

In [None]:
# Plot against 1-based epoch numbers so the x-axis matches the
# "Epoch: k/num_epochs" log printed by the training loop.
epoch_numbers = range(1, len(total_train_accuracies) + 1)

_, axes = plt.subplots(2, 1, figsize=(15, 10))
curves = [
    (axes[0], "Accuracy", "Accuracies vs epoch", "Accuracy (%)",
     total_train_accuracies, total_val_accuracies),
    (axes[1], "Loss", "Losses vs epoch", "Loss",
     total_train_losses, total_val_losses),
]
for ax, quantity, title, y_label, train_curve, val_curve in curves:
    ax.plot(epoch_numbers, train_curve, label=f"Train {quantity}")
    ax.plot(epoch_numbers, val_curve, label=f"Validation {quantity}")
    ax.set_title(title)
    ax.set_xlabel("Epochs")
    ax.set_ylabel(y_label)
    ax.legend()
plt.tight_layout()
plt.show()

I used the confusion matrix as a metric because it gives all the usual metrics at once (accuracy, FPR, TPR, and hence the F1-score). Here, however, we only look at the training/validation losses and accuracies to check that the model is not overfitting. Note that only the training data is augmented (colour jitter, rotation, flips), while the validation data is only normalized. Dropout is also active during training only. The training task is therefore noisier than validation, which can explain a validation accuracy close to or above the training accuracy.

We can try to optimize hyperparameters such as the learning rate, the batch size and the number of layers in the CNN, to see whether they improve the model. Below, only the learning rate and the batch size are searched.

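To do so, we use Bayesian optimization (`gp_minimize` from the `scikit-optimize` library) to find the best set of hyperparameters. Each evaluation trains a fresh model for a single epoch and returns its mean validation loss.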

In [None]:
from skopt import gp_minimize
from skopt.utils import use_named_args
from skopt.space import Real, Integer, Categorical

In [None]:
# Search space for the Bayesian optimisation: a log-uniform learning rate and a
# batch size picked from a fixed set of values.
dimensions = [
    Real(low=1e-5, high=1e-1, prior="log-uniform", name="learning_rate"),
    Categorical(categories=[16, 32, 64, 128, 256], name="batch_size"),
]

# Initial point passed to gp_minimize as `x0` (same order as `dimensions`).
parameters_default_values = [
    1e-3,  # learning_rate
    32,    # batch_size
]

In [None]:
@use_named_args(dimensions=dimensions)
def fit_opt(learning_rate, batch_size):
    """
    Objective for gp_minimize: train a fresh ConvNet for one epoch and return its validation loss.

    arguments:
        learning_rate [float]: Adam learning rate
        batch_size [int]: batch size used by both the training and validation loaders
    returns:
        [float] BCE loss averaged over all validation samples (lower is better).
        Batches are weighted by their size, so a smaller last batch does not
        get the same weight as full ones.
    raises:
        ValueError: if the validation set is empty
    """
    # skopt may hand back NumPy scalars (e.g. numpy.int64). DataLoader only
    # accepts a Python int batch size, so cast explicitly.
    batch_size = int(batch_size)
    learning_rate = float(learning_rate)

    model = ConvNet().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = torch.nn.BCELoss()

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # One training epoch.
    model.train()
    for img_batch, lbl_batch in tqdm(train_loader):
        optimizer.zero_grad()
        outputs = model(img_batch.to(device))
        loss = criterion(outputs, lbl_batch.float().unsqueeze(1).to(device))
        loss.backward()
        optimizer.step()

    # Sample-weighted mean validation loss.
    model.eval()
    total_loss, n_samples = 0.0, 0
    with torch.no_grad():
        for img_batch, lbl_batch in tqdm(val_loader):
            outputs = model(img_batch.to(device))
            batch_loss = criterion(outputs, lbl_batch.float().unsqueeze(1).to(device)).item()
            total_loss += batch_loss * lbl_batch.size(0)
            n_samples += lbl_batch.size(0)

    if n_samples == 0:
        raise ValueError("validation set is empty; cannot compute the objective")
    return total_loss / n_samples

In [None]:
# Gaussian-process Bayesian optimisation of fit_opt: 11 evaluations in total,
# the first at the default hyper-parameters passed as x0.
search_settings = dict(
    dimensions=dimensions,
    x0=parameters_default_values,
    n_calls=11,
    random_state=42,
    verbose=True,
)
gp_result = gp_minimize(fit_opt, **search_settings)

best_iteration = np.argmin(gp_result.func_vals)
print(f"Optimal set of parameters found at iteration {best_iteration}")
print(gp_result.x)

With the selected parameters we get a better loss after one epoch. The procedure is somewhat empirical, though: it relies on a single training epoch and on the random batches drawn. With more computing resources we could run many more evaluations (or a cross-validated search) to fine-tune the model, but here we are limited.

We did not tune the number of convolutional layers, but it can be a very impactful hyperparameter. The same holds for the dropout probability, the number of max-pooling layers and the size of the fully connected layers.

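Let's nevertheless plot the losses against the batch sizes and the learning rates. Both parameters change from one evaluation to the next, so each plot shows the loss against one parameter while the other one also varies: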

In [None]:
# Each row of x_iters is [learning_rate, batch_size] for one evaluation. Both
# coordinates change between evaluations, so these are scatter plots of the
# objective against one coordinate at a time, not curves with the other fixed.
x_iters = np.array(gp_result.x_iters, dtype=float)
learning_rates = x_iters[:, 0]
batch_sizes = x_iters[:, 1]
losses = np.asarray(gp_result.func_vals)

_, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].scatter(learning_rates, losses, label="Learning Rates Losses")
axes[0].set_xscale("log")  # the learning rate was searched with a log-uniform prior
axes[0].set_xlabel("Learning rate")
axes[0].set_ylabel("Loss")
axes[0].set_title("Loss vs learning rates")
axes[0].legend()
axes[1].scatter(batch_sizes, losses, label="Batch sizes Losses", color='red')
axes[1].set_xlabel("Batch size")
axes[1].set_ylabel("Loss")
axes[1].set_title("Loss vs batch sizes")
axes[1].legend()
plt.show()

For the batch sizes, it is hard to draw conclusions, because the learning rate also varies between evaluations. The learning rate has a much stronger impact on the loss, which makes sense: a learning rate that is too large prevents the optimization from converging.

Also, instead of a home-made ConvNet, we could use well-known architectures such as VGG16, which (after some research) performs well on this kind of classification task.

### Explainability

#### Saliency Map & GradCAM

In addition to saliency maps, we use another interpretability method: Grad-CAM.
Saliency maps, computed in the code block below, are useful for interpreting the decisions of CNNs, but they have some limitations.

In [None]:
## Code block to use saliency maps
import cv2

image = images[40]
label = labels[40]

# Same preprocessing as the validation split of PCAMDataset.
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Build the input directly on the model's device: it is then a leaf tensor,
# so backward() populates input_tensor.grad.
input_tensor = preprocess(image).unsqueeze(0).to(device)  # add batch dimension

model.eval()

# Track gradients w.r.t. the input pixels.
input_tensor.requires_grad_(True)

# The model has a single sigmoid output, so output[0, 0] is the predicted tumor
# probability and argmax over the output is always 0. The saliency map
# explains that probability.
output = model(input_tensor)
class_index = torch.argmax(output)

model.zero_grad()
output[0, class_index].backward()

# |d output / d input| for every channel and pixel, scaled to a maximum of 1
# (the epsilon avoids NaNs if every gradient is zero).
saliency_map = input_tensor.grad.squeeze(0).abs().cpu().numpy()
saliency_map = saliency_map / (saliency_map.max() + 1e-12)

# Aggregate across the channels.
aggregate_saliency = saliency_map.sum(axis=0)

# Plot the input image, its saliency map and (below) its Grad-CAM side by side.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

axes[0].imshow(image)
axes[0].set_title(f'Input Image, with {label} tumor')
axes[0].axis('off')

axes[1].imshow(aggregate_saliency, cmap='jet', alpha=0.7)  # saliency heatmap
axes[1].imshow(image, alpha=0.3)  # overlay the input image for comparison
axes[1].set_title('Saliency Map')
axes[1].axis('off')


## Grad-CAM
def register_hooks(layer, handles=None):
    """
    Record the forward output of `layer` and the gradient w.r.t. that output.

    arguments:
        layer [nn.Module]: module to observe (here the last convolution)
        handles [list | None]: if given, the two hook handles are appended so
            the caller can remove the hooks. Otherwise they stay attached and
            keep recording on every later forward/backward pass.
    returns:
        gradients [list], activations [list]: filled by the hooks
    """
    gradients = []
    activations = []

    def backward_hook(module, grad_input, grad_output):
        gradients.append(grad_output[0])

    def forward_hook(module, input, output):
        activations.append(output)

    forward_handle = layer.register_forward_hook(forward_hook)
    # register_full_backward_hook replaces the deprecated register_backward_hook.
    backward_handle = layer.register_full_backward_hook(backward_hook)
    if handles is not None:
        handles.extend([forward_handle, backward_handle])

    return gradients, activations


gradcam_hook_handles = []
gradients, activations = register_hooks(model.conv4, handles=gradcam_hook_handles)
try:
    output = model(input_tensor)
    # Single sigmoid output: this selects the tumor probability.
    target_class = torch.argmax(output)
    model.zero_grad()
    output[0, target_class].backward()
finally:
    # Detach the hooks so later forward/backward passes are not recorded.
    for handle in gradcam_hook_handles:
        handle.remove()

# Conv4 output and its gradient for the single input image.
gradients = gradients[0].detach()
activations = activations[0].detach()

# Channel weights = spatially averaged gradients, then a ReLU'd weighted sum.
weights = gradients.mean(dim=[2, 3], keepdim=True)
gradcam = F.relu((weights * activations).sum(dim=1)).squeeze(0)
gradcam -= gradcam.min()
gradcam /= gradcam.max() + 1e-12  # guard: an all-zero map would give NaNs

# cv2.resize takes dsize as (width, height).
gradcam_resized = cv2.resize(gradcam.cpu().numpy(), (image.shape[1], image.shape[0]))

axes[2].imshow(image)
axes[2].imshow(gradcam_resized, cmap='jet', alpha=0.5)
axes[2].set_title("Grad-CAM Heatmap")
axes[2].axis('off')

plt.tight_layout()
plt.show()

Grad-CAM generates a heatmap by computing the gradients of the target class score with respect to the feature maps of a specific convolutional layer. It then uses these gradients, averaged over space, to weight the activations, which highlights the regions of the image that most strongly influence the model's classification.

Here, the Grad-CAM heatmap seems to highlight biologically relevant regions, such as small tissue structures, which aligns with human intuition for the task. The results appear to make sense for this classification, since the model focuses on heterogeneous regions. It is hard to interpret without the knowledge of a specialist, so this is only an intuition.

As for limitations, this example shows that interpretations can be noisy or focus on irrelevant details (especially the saliency map, in the middle panel).

Also, Grad-CAM depends on the choice of layer, which may not align with meaningful features. CNNs also remain black boxes at higher levels, so complete transparency is challenging. Some variants improve on Grad-CAM, such as Grad-CAM++, which weights the gradients pixel-wise and better localizes multiple instances of the same class.

#### SHAP

In [None]:
import shap
import numpy as np
import matplotlib.pyplot as plt
import torch
import cv2
from torchvision import transforms

# Preprocess the input image (same as the validation split of PCAMDataset).
image = images[40]
label = labels[40]

preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

input_tensor = preprocess(image).unsqueeze(0).to(device)  # add batch dimension

model.eval()

# DeepExplainer with an all-zero background, i.e. the per-channel mean colour
# after Normalize.
explainer = shap.DeepExplainer(model, torch.zeros_like(input_tensor))
shap_values = explainer.shap_values(input_tensor)

# Depending on the SHAP version, a single-output model yields either a list
# with one array per output or a bare array (possibly with a trailing output
# axis). Every format holds exactly C*H*W attributions for our single image,
# so reshape to (C, H, W) before summing over channels. Indexing with [0] and
# summing axis=1 silently summed over height for the bare-array format.
raw_shap = shap_values[0] if isinstance(shap_values, list) else shap_values
shap_values = np.asarray(raw_shap).reshape(tuple(input_tensor.shape[1:]))
aggregate_shap = shap_values.sum(axis=0)

# Min-max normalise for display (guard against a constant map).
shap_min = aggregate_shap.min()
shap_range = aggregate_shap.max() - shap_min
aggregate_shap_normalized = (aggregate_shap - shap_min) / (shap_range if shap_range > 0 else 1.0)

# Resize to the original image size; cv2 dsize is (width, height).
shap_map_resized = cv2.resize(aggregate_shap_normalized, (image.shape[1], image.shape[0]))

# Plot the input image and its corresponding SHAP map side by side.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))

axes[0].imshow(image)
axes[0].set_title(f'Input Image, with {label} tumor')
axes[0].axis('off')

axes[1].imshow(image, alpha=0.3)  # overlay input image for comparison
axes[1].imshow(shap_map_resized, cmap='jet', alpha=0.7)  # SHAP heatmap
axes[1].set_title('SHAP Heatmap')
axes[1].axis('off')

plt.tight_layout()
plt.show()


SHAP (SHapley Additive exPlanations) is a popular method for interpreting machine learning models by attributing to each feature its contribution to a model's prediction. Based on Shapley values from game theory, SHAP computes the marginal contribution of each feature by comparing the prediction when the feature is included versus excluded, averaged over all possible feature combinations, which ensures a fair distribution of the contributions. It works with many model types and can provide global insights into feature importance or local explanations of individual predictions. In computer vision, SHAP highlights the regions of an image that most influence the model's output. Here we use `shap.DeepExplainer`, which approximates SHAP values for deep networks relative to a background input (an all-zero tensor, i.e. the mean colour after normalization).

## Training a Vision Transformer

In [None]:
from torchvision import models


class ViT(nn.Module):
    """
    torchvision ViT-B/16 loaded with `pretrained=True`, whose classification
    head is replaced by a single-logit linear layer. forward() returns sigmoid
    probabilities of shape (batch, 1).
    """

    def __init__(self):
        super().__init__()
        backbone = models.vit_b_16(pretrained=True)
        n_features = backbone.heads.head.in_features
        backbone.heads = nn.Linear(n_features, 1)  # custom binary head
        self.vit = backbone

    def forward(self, x):
        logits = self.vit(x)
        return torch.sigmoid(logits)


In [None]:
# Pick the GPU when available; the ViT now takes over the global `model`.
has_gpu = torch.cuda.is_available()
device = torch.device('cuda' if has_gpu else 'cpu')
print(f'Device used : {device}')
model = ViT()
model = model.to(device)

In [None]:
# Fine-tuning settings for the ViT. The learning rate is divided by 10 every
# 5 scheduler steps (the training loop steps the scheduler once per epoch).
num_epochs = 10
lr = 1e-4
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

In [None]:
# A confusion-matrix metric could be added here as in the CNN loop.

# Resize the 96x96 PCAM patches to the 224x224 input size used by vit_b_16.
# Applied to both training and validation batches.
transform = transforms.Compose([
    transforms.Resize((224, 224))
])

for epoch in range(num_epochs):
    # Training
    model.train()
    running_loss = 0.0
    # `targets` (not `labels`) so the global label array is not overwritten.
    for inputs, targets in tqdm(train_dataloader):
        inputs = inputs.to(device)
        # BCELoss needs float targets with the same (batch, 1) shape as outputs.
        targets = targets.float().unsqueeze(1).to(device)

        optimizer.zero_grad()
        outputs = model(transform(inputs))
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    scheduler.step()
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {running_loss / len(train_dataloader)}")

    # Validation
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in tqdm(val_dataloader):
            inputs = inputs.to(device)
            targets = targets.float().unsqueeze(1).to(device)
            outputs = model(transform(inputs))  # same resize as in training
            loss = criterion(outputs, targets)
            val_loss += loss.item()

            # Single sigmoid output: threshold the probability at 0.5
            # (torch.max over dim 1 would always return index 0).
            predicted = (outputs > 0.5).float()
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    print(f"Validation Loss: {val_loss / len(val_dataloader)}, Accuracy: {100 * correct / total:.2f}%")

torch.save(model.state_dict(), 'tumor_vit_model.pth')