# DL Tutorial 1: Classification and Interpretability

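This notebook builds upon Michelle Lochner's [deep learning tutorial](https://github.com/MichelleLochner/ml-tutorials/blob/main/tutorial-deep-learning.ipynb). It trains a convolutional neural network to classify Galaxy Zoo thumbnails as elliptical or spiral galaxies, and then uses Grad-CAM to visualise which image regions drive the network's predictions.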

### Open In Colab

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Road2SKA/DL_Basics_tutorial/blob/main/classification_and_interpretability.ipynb)


In [None]:
!pip install grad-cam --quiet

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torchvision import models
from torchvision import transforms
from torchsummary import summary

from PIL import Image
import time
import os
import subprocess

from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
# The custom widget manager only exists on Google Colab. Guard the import so the
# notebook also runs on a local Jupyter installation.
try:
    from google.colab import output
except ImportError:
    IN_COLAB = False
else:
    IN_COLAB = True
    output.enable_custom_widget_manager()

# NOTE: %pylab inline star-imports numpy and matplotlib.pyplot into the global
# namespace. Later cells rely on the bare names it provides (figure, subplot,
# imshow, plot, title, ...), so it cannot be removed on its own.
%pylab inline

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Image pipeline handed to the datasets: the only step converts each loaded
# image into a torch tensor (no resizing or normalisation is applied here).
preprocess = transforms.Compose([transforms.ToTensor()])

In [None]:
class ThumbnailsDataset(Dataset):
    """
    Image-folder dataset that decodes every PNG once and keeps it on ``device``.

    Each sub-directory of the root directory is one class; class indices are
    assigned in alphabetical order of the directory names.
    """

    def __init__(self, root_dir, transform=None, device=None, maxsize=None):
        """
        Load all images once and preload them onto the target device.

        Parameters
        ----------
        root_dir : str
            Root directory. Each (non-hidden) subfolder = one class.
        transform : torchvision.transforms, optional
            Transforms to apply (must output a Tensor).
        device : torch.device or str, optional
            Device where data will be stored. Default: 'cuda' if available.
        maxsize : int, optional
            Maximum number of images to load per class. ``None`` loads all.

        Raises
        ------
        TypeError
            If an image is not a ``torch.Tensor`` after the transform.
        """
        self.transform = transform
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Only real, non-hidden sub-directories are classes. Stray files or
        # folders such as .ipynb_checkpoints would otherwise become bogus classes
        # (or crash the listing below).
        classes = sorted(
            entry for entry in os.listdir(root_dir)
            if not entry.startswith(".") and os.path.isdir(os.path.join(root_dir, entry))
        )
        self.class_to_idx = {c: i for i, c in enumerate(classes)}

        self.images = []
        self.targets = []
        self.names = []

        # Emoji are written as escapes so the source file stays pure ASCII.
        print(f"\U0001F504 Preloading dataset to {self.device}...")

        for c in classes:
            class_dir = os.path.join(root_dir, c)
            # os.listdir order is arbitrary; sorting makes the dataset order and
            # the subset picked by maxsize reproducible.
            files = sorted(f for f in os.listdir(class_dir) if f.lower().endswith(".png"))
            if maxsize is not None:
                files = files[:maxsize]

            for f in files:
                path = os.path.join(class_dir, f)
                im_name = os.path.splitext(f)[0]

                # The context manager closes the file handle once it is decoded.
                with Image.open(path) as raw:
                    image = raw.convert("RGB")

                if self.transform:
                    image = self.transform(image)  # Must produce a tensor

                if not torch.is_tensor(image):
                    raise TypeError("Transform must convert images to torch.Tensor")

                # Move to the target device now, so no copies happen during training
                image = image.to(self.device, non_blocking=True)

                self.images.append(image)
                self.targets.append(self.class_to_idx[c])
                self.names.append(im_name)

        # Explicit dtype keeps the labels integer even when no image was found.
        self.targets = torch.tensor(self.targets, dtype=torch.long, device=self.device)

        print(f"\u2705 Loaded {len(self.images)} images.")

    def __len__(self):
        """Number of preloaded images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the image tensor, file stem and class index of sample ``idx``."""
        return {
            "image": self.images[idx],     # Already on the target device
            "name": self.names[idx],
            "class": self.targets[idx]     # Already on the target device
        }

In [None]:
!wget https://raw.githubusercontent.com/MichelleLochner/ml-tutorials/main/data/galaxy_zoo.zip
!unzip -q galaxy_zoo.zip -d galaxy_zoo
!ls -orth

In [None]:
# Preload both splits. maxsize=None keeps every training image; lower it to
# load fewer images per class.
training_dataset = ThumbnailsDataset(
    "galaxy_zoo/galaxy_zoo/training",
    transform=preprocess,
    maxsize=None,
)
test_dataset = ThumbnailsDataset(
    "galaxy_zoo/galaxy_zoo/test",
    transform=preprocess,
)

In [None]:
def plot_galaxy(dataset, idx):
    """
    Draw sample ``idx`` of ``dataset`` on the current axes, titled with its class.

    The sample is expected to be a dict with an 'image' tensor and a 'class'
    label; label 0 is titled 'elliptical', anything else 'spiral'.
    """
    sample = dataset[idx]

    # The tensor is permuted with (1, 2, 0) because matplotlib wants the
    # channel axis last
    pixels = sample['image'].cpu().detach().permute(1, 2, 0)
    plt.imshow(pixels)

    # Turn the numeric label into a readable title and hide the axis ticks
    img_class = 'elliptical' if sample['class'] == 0 else 'spiral'
    plt.xticks([])
    plt.yticks([])
    plt.title(img_class)

In [None]:
# Draw nine distinct training samples at random and show them in a 3x3 grid
inds = np.random.choice(np.arange(len(training_dataset)), 9, replace=False)

plt.figure(figsize=(8,8))
for i, idx in enumerate(inds):
    plt.subplot(3, 3, i + 1)
    plot_galaxy(training_dataset, idx)

In [None]:
# A custom CNN class
class ConvNeuralNet(nn.Module):
    """
    Small convolutional classifier: two conv-conv-maxpool stages followed by a
    two-layer fully connected head.

    Parameters
    ----------
    num_classes : int
        Number of output logits.
    input_size : tuple of int, optional
        (channels, height, width) of the images the network will receive. Used
        to size the first convolution and the first fully connected layer.
        Defaults to (3, 224, 224).
    """

    # Determine what layers and their order in CNN object
    def __init__(self, num_classes, input_size=(3, 224, 224)):
        super().__init__()
        in_channels = input_size[0]

        self.conv_layer1 = nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=3)
        self.conv_layer2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3)
        self.max_pool1 = nn.MaxPool2d(kernel_size = 2, stride = 2)

        self.conv_layer3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
        self.conv_layer4 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3)
        self.max_pool2 = nn.MaxPool2d(kernel_size = 2, stride = 2)

        # Size the first fully connected layer by pushing one dummy image through
        # the convolutional stack. no_grad avoids building a throw-away autograd
        # graph during construction.
        with torch.no_grad():
            probe = self._features(torch.zeros(1, *input_size))
        flattened_size = probe.flatten(1).shape[1]

        self.fc1 = nn.Linear(flattened_size, 64)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(64, num_classes)

    def _features(self, x):
        """Run the convolution/pooling stack shared by ``__init__`` and ``forward``."""
        # NOTE(review): there is no activation between the convolutions, so each
        # conv-conv pair is a purely linear map. Left unchanged to keep results
        # comparable - consider adding ReLUs.
        out = self.conv_layer1(x)
        out = self.conv_layer2(out)
        out = self.max_pool1(out)

        out = self.conv_layer3(out)
        out = self.conv_layer4(out)
        out = self.max_pool2(out)
        return out

    # Progresses data across layers
    def forward(self, x):
        """Map a batch of images to a (batch, num_classes) tensor of logits."""
        out = self._features(x)

        # Flatten everything except the batch dimension
        out = out.reshape(out.size(0), -1)

        out = self.fc1(out)
        out = self.relu1(out)
        out = self.fc2(out)
        return out

In [None]:
# Network to train: the custom CNN with two output classes, placed on `device`
classifier = ConvNeuralNet(num_classes=2).to(device)

# Alternative: use a pre-defined ResNet model instead
#classifier = models.resnet18(num_classes=2).to(device)

# Layer-by-layer overview of the network for a 3x224x224 input
summary(classifier, (3, 224, 224))

In [None]:
# Optimisation hyper-parameters
learning_rate = 3e-3
batch_size = 64

# Cross-entropy loss on the network outputs, optimised with plain SGD
loss_fn = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(classifier.parameters(), lr=learning_rate)

# Batched, shuffled iterators over the two datasets
# NOTE(review): shuffling the test loader is not needed for evaluation.
training_dataloader = DataLoader(
    training_dataset, batch_size=batch_size, shuffle=True
)
test_dataloader = DataLoader(
    test_dataset, batch_size=batch_size, shuffle=True
)

# Training history; the training cell appends to these lists
train_losses = []
test_losses = []
train_acc = []
test_acc = []

In [None]:
def train_loop(dataloader, model, loss_fn, optimizer):
    """
    Train the network for one pass over the training set.

    Parameters
    ----------
    dataloader : iterable
        Yields dict batches with an 'image' tensor and a 'class' tensor; must
        expose ``.dataset`` so the number of samples is known.
    model : torch.nn.Module
        Network to train (updated in place).
    loss_fn : callable
        Loss taking (predictions, targets).
    optimizer : torch.optim.Optimizer
        Optimizer holding the model parameters.

    Returns
    -------
    losses : np.ndarray, shape (n_batches,)
        Loss of every batch, in iteration order.
    accuracy : np.ndarray, shape (1,)
        Training accuracy in percent, measured while the weights were updating.
    """
    size = len(dataloader.dataset)
    losses = []
    correct = 0

    # Put dropout / batch-norm layers into training mode. An earlier evaluation
    # step or Grad-CAM may have left the model in eval mode.
    model.train()

    for dat in dataloader:
        # Compute prediction and loss
        pred = model(dat['image'])
        loss = loss_fn(pred, dat['class'])
        correct += (pred.argmax(1) == dat['class']).type(torch.float).sum().item()

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    # An empty dataset has no defined accuracy
    accuracy = 100 * (correct / size) if size else np.nan
    return np.asarray(losses, dtype=np.float32), np.array([accuracy])


def test_loop(dataloader, model, loss_fn):
    """
    Function to iterate through the test data and evaluate the algorithm.
    """
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0
    losses = []
    with torch.no_grad():
        for dat in dataloader:
            pred = model(dat['image'])
            test_loss += loss_fn(pred, dat['class']).item()
            correct += (pred.argmax(1) == dat['class']).type(torch.float).sum().item()
    
    test_loss /= num_batches
    correct /= size
    print(f"Test accuracy: {(100*correct):>0.1f}%, avg loss: {test_loss:>8f} ")
    return np.array([test_loss]), np.array([100*correct])

In [None]:
# Train for a fixed number of epochs, evaluating on the test set after each one.
t1 = time.perf_counter()
epochs = 50

for epoch in range(1, epochs + 1):
    print(f"Epoch {epoch}")

    # One pass of training followed by one pass of evaluation
    trainloss, trainacc = train_loop(training_dataloader, classifier, loss_fn, optimizer)
    testloss, testacc = test_loop(test_dataloader, classifier, loss_fn)

    # Record this epoch in the running history
    train_losses.append(trainloss)
    train_acc.append(trainacc)
    test_losses.append(testloss)
    test_acc.append(testacc)

# Flatten the per-epoch arrays into one array per quantity
all_train_losses = np.concatenate(train_losses)
all_train_acc = np.concatenate(train_acc)
all_test_losses = np.concatenate(test_losses)
all_test_acc = np.concatenate(test_acc)

print("Done!")
print(f"Time taken {time.perf_counter()-t1:.2f}s")


In [None]:
# Training loss is recorded once per batch; test loss and both accuracies are
# recorded once per epoch. The epoch count comes from the recorded history
# rather than from `epochs`, so the plot stays valid if the training cell was
# run more than once.
n_epochs = all_test_losses.size
btch_per_epoch = all_train_losses.size // n_epochs

# x positions in units of epochs: the last batch of epoch e lands exactly on e,
# which is also where that epoch's test loss and accuracies are drawn.
batches = (np.arange(all_train_losses.size) + 1) / btch_per_epoch
epoch_ends = np.arange(1, n_epochs + 1)

plt.figure(figsize=(9,4))

plt.subplot(1,2,1)
plt.plot(batches, all_train_losses, label='Training (per batch)')
plt.plot(epoch_ends, all_test_losses, label='Test (per epoch)')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid()

plt.subplot(1,2,2)
plt.plot(epoch_ends, all_train_acc, label='Training')
plt.plot(epoch_ends, all_test_acc, label='Test')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.grid()

In [None]:
# Collect predictions, true labels and images for the whole test set.
# Images are stored channels-last so they can be passed straight to imshow.
pred_batches, target_batches, img_batches = [], [], []

# Predict in eval mode so dropout / batch-norm layers behave deterministically,
# then restore the mode the model was in.
was_training = classifier.training
classifier.eval()
try:
    with torch.no_grad():
        for dat in test_dataloader:
            pred = classifier(dat['image'])
            pred_batches.append(pred.argmax(1).cpu().numpy())
            target_batches.append(dat['class'].cpu().numpy())
            img_batches.append(dat['image'].cpu().numpy().transpose(0,2,3,1))
finally:
    classifier.train(was_training)

targets = np.concatenate(target_batches)
predictions = np.concatenate(pred_batches)
test_imgs = np.concatenate(img_batches)

In [None]:
# Show up to nine test galaxies the classifier got right, titled with the true class
correct_pool = np.flatnonzero(targets == predictions)
# Sampling without replacement raises if asked for more items than exist, so cap
# the count. A conditional is used because %pylab may shadow the builtin min().
n_show = 9 if correct_pool.size >= 9 else correct_pool.size
correct_inds = np.random.choice(correct_pool, n_show, replace=False)

plt.figure(figsize=(9,9))
for i, idx in enumerate(correct_inds):
    plt.subplot(3,3,i+1)
    plt.imshow(test_imgs[idx])
    plt.axis('off')
    plt.title("Spiral" if targets[idx]==1 else "Elliptical")

In [None]:
# Show up to nine misclassified test galaxies, titled with the true class
wrong_pool = np.flatnonzero(targets != predictions)
# A good classifier can easily make fewer than nine mistakes; sampling without
# replacement would then raise, so cap the count. A conditional is used because
# %pylab may shadow the builtin min().
n_show = 9 if wrong_pool.size >= 9 else wrong_pool.size
wrong_inds = np.random.choice(wrong_pool, n_show, replace=False)

plt.figure(figsize=(9,9))
for i, idx in enumerate(wrong_inds):
    plt.subplot(3,3,i+1)
    plt.imshow(test_imgs[idx])
    plt.axis('off')
    plt.title("Spiral" if targets[idx]==1 else "Elliptical")

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Tabulate true against predicted classes over the test set and draw the matrix
cm = confusion_matrix(targets, predictions)
disp = ConfusionMatrixDisplay(
    confusion_matrix=cm,
    display_labels=['elliptical', 'spiral'],
)
disp.plot()

In [None]:
# Grad-CAM is computed on the activations of the fourth (last) convolution
# of the custom CNN; for resnet use classifier.layer4[-1] instead
target_layers = [classifier.conv_layer4]
cam = GradCAM(target_layers=target_layers, model=classifier)

In [None]:
# Grad-CAM overlays for up to nine correctly classified test galaxies
correct_pool = np.flatnonzero(targets == predictions)
# Cap the sample size: drawing without replacement raises if fewer than nine
# candidates exist. A conditional is used because %pylab may shadow min().
n_show = 9 if correct_pool.size >= 9 else correct_pool.size
correct_inds = np.random.choice(correct_pool, n_show, replace=False)

plt.figure(figsize=(9,9))
for i, idx in enumerate(correct_inds):
    plt.subplot(3,3,i+1)
    # Rebuild a channels-first batch of one and put it on the model's device
    input_tensor = torch.from_numpy(test_imgs[idx][None].transpose(0,3,1,2)).to(device)
    # No explicit targets are passed; per the pytorch-grad-cam documentation the
    # map is then computed for the highest-scoring class.
    grayscale_cam = cam(input_tensor=input_tensor)[0, :]
    visualization = show_cam_on_image(test_imgs[idx], grayscale_cam, use_rgb=True, image_weight=0.8)
    plt.imshow(visualization)
    plt.title(f"True {targets[idx]}, Pred {predictions[idx]}")
    plt.axis('off')

In [None]:
# Grad-CAM overlays for up to nine misclassified test galaxies
wrong_pool = np.flatnonzero(targets != predictions)
# A good classifier can make fewer than nine mistakes; drawing without
# replacement would then raise, so cap the sample size. A conditional is used
# because %pylab may shadow the builtin min().
n_show = 9 if wrong_pool.size >= 9 else wrong_pool.size
wrong_inds = np.random.choice(wrong_pool, n_show, replace=False)

plt.figure(figsize=(9,9))
for i, idx in enumerate(wrong_inds):
    plt.subplot(3,3,i+1)
    # Rebuild a channels-first batch of one and put it on the model's device
    input_tensor = torch.from_numpy(test_imgs[idx][None].transpose(0,3,1,2)).to(device)
    # No explicit targets are passed; per the pytorch-grad-cam documentation the
    # map is then computed for the highest-scoring class.
    grayscale_cam = cam(input_tensor=input_tensor)[0, :]
    visualization = show_cam_on_image(test_imgs[idx], grayscale_cam, use_rgb=True, image_weight=0.8)
    plt.imshow(visualization)
    plt.title(f"True {targets[idx]}, Pred {predictions[idx]}")
    plt.axis('off')