## Setup and Settings

In [None]:
if 'google.colab' in str(get_ipython()):
    !pip install captum
    !pip install pytorch-lightning
    from google.colab import drive
    drive.mount('/content/drive')
    %cd /content/drive/MyDrive/Thesis/

In [None]:
import os
import random
from collections import Counter

import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import StratifiedShuffleSplit
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import (
    DataLoader,
    Subset,
    WeightedRandomSampler,
    random_split
)
from torchvision import datasets, transforms
import pytorch_lightning as pl
from PIL import Image
from tqdm import tqdm

from models import ConvModel
from model_container import ModelContainer

In [None]:
import warnings

warnings.filterwarnings('ignore', message='.*DataLoader will create.*') # Suppressed the warning related to the creation of DataLoader using a high number of num_workers

In [None]:
# SETTINGS

# Global seed: weight init, data augmentation, weighted sampling and the random example image
# below all draw from these RNGs, so seeding them makes a full re-run reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# DataLoader worker processes: 8 should be optimal if a GPU is available, 6 for CPU-only runs.
num_workers_for_data_loaders = 8 if torch.cuda.is_available() else 6

## Auxiliary Functions

In [None]:
def show_image(image, title=""):
    """Display ``image`` on the current pyplot axes with a title and no axis decorations.

    Args:
        image: any input accepted by ``plt.imshow`` (e.g. a PIL image or an array).
        title: text placed above the image; empty by default.
    """
    plt.imshow(image)
    plt.axis('off')
    plt.title(label=title)
    plt.show()

In [None]:
def get_targets_and_classes(dataset):
    """Return the labels and class metadata of a dataset, looking through ``Subset`` wrappers.

    Accepts an ImageFolder-style dataset (anything exposing ``targets``, ``classes`` and
    ``class_to_idx``) or a ``torch.utils.data.Subset`` of one. Subsets may be nested, in
    which case their indices are composed so that the returned targets line up with the
    samples the outermost Subset actually yields.

    Args:
        dataset: the base dataset or a (possibly nested) ``Subset`` of it.

    Returns:
        tuple ``(targets, classes, class_to_idx)``: ``targets`` holds the label of every
        sample in ``dataset``, in order; ``classes`` and ``class_to_idx`` come from the
        innermost (base) dataset.
    """
    # Peel off Subset layers, translating outer positions into base-dataset indices.
    base_indices = None
    while isinstance(dataset, Subset):
        layer_indices = list(dataset.indices)
        if base_indices is None:
            base_indices = layer_indices
        else:
            base_indices = [layer_indices[i] for i in base_indices]
        dataset = dataset.dataset

    if base_indices is None:
        targets = dataset.targets
    else:
        targets = [dataset.targets[i] for i in base_indices]
    return targets, dataset.classes, dataset.class_to_idx

In [None]:
def print_dataset_stats(dataset, dataset_name="", pos_label='pos', neg_label='neg'):
    """Print the sample count and positive/negative class balance of a dataset.

    Args:
        dataset: an ImageFolder-style dataset or a ``Subset`` of one; labels and class
            metadata are obtained through ``get_targets_and_classes``.
        dataset_name: name shown in the printed header.
        pos_label: class name counted as the positive class.
        neg_label: class name counted as the negative class.

    Raises:
        KeyError: if ``pos_label`` or ``neg_label`` is not a class of ``dataset``.
    """
    # Retrieve targets, class names, and the class-to-index mapping
    targets, classes, class_to_idx = get_targets_and_classes(dataset)

    # Count the occurrences of each class index in the dataset
    label_counts = Counter(targets)

    # Number of positive / negative samples
    pos_count = label_counts[class_to_idx[pos_label]]
    neg_count = label_counts[class_to_idx[neg_label]]
    total_count = len(targets)

    # Class ratio and percentages; an empty dataset reports 0% instead of dividing by zero.
    class_ratio = pos_count / neg_count if neg_count > 0 else float('inf')
    pos_percentage = (pos_count / total_count) * 100 if total_count else 0.0
    neg_percentage = (neg_count / total_count) * 100 if total_count else 0.0

    # Print dataset statistics
    print(f"'{dataset_name}' dataset:")
    print(f"\tNumber of samples: {total_count} ({neg_label}: {neg_count}, {pos_label}: {pos_count})")
    print(f"\tNumber of classes: {len(classes)}")
    print(f"\tClass names: {classes}")
    print(f"\tClass distribution ratio ({pos_label}:{neg_label}): {class_ratio:.2f}")
    print(f"\tClass percentages: {pos_percentage:.2f}% {pos_label}, {neg_percentage:.2f}% {neg_label}")
    print()

## Load and Inspect Data

In [None]:
# ImageNet channel statistics, shared by the training and evaluation pipelines.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: random geometric augmentation, then tensor conversion + normalization.
data_augmentation = transforms.Compose([
    transforms.RandomAffine(degrees=45, translate=(0.1, 0.1), fill=0),  # Random translation and rotation (empty corners filled with 0)
    transforms.RandomHorizontalFlip(),  # Random horizontal flip
    transforms.RandomVerticalFlip(),  # Random vertical flip
    # Random zoom-in crop, resized to 512x512. `scale` is the crop's fraction of the image
    # area, so sampled values above 1.0 can never fit inside the image (torchvision rejects
    # them and retries / falls back to a center crop); 1.0 is the meaningful upper bound.
    transforms.RandomResizedCrop(size=(512, 512), scale=(0.85, 1.0)),
    transforms.ToTensor(),  # Convert to tensor
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),  # Rescaling / normalizing
])

# Evaluation pipeline: deterministic 512x512 center crop, then the same normalization.
# NOTE(review): CenterCrop does not resize, whereas RandomResizedCrop rescales its crop to
# 512x512 -- if the source images are not ~512px, train and eval see different scales. Confirm.
data_prep = transforms.Compose([
    transforms.CenterCrop(size=(512, 512)),  # Center crop
    transforms.ToTensor(),  # Convert to tensor
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),  # Rescaling / normalizing
])

In [None]:
# One folder tree feeds training + validation (split below); a separate one is the test set.
# ImageFolder maps each class subdirectory (here `neg/` and `pos/`) to a label index.
train_dir, test_dir = 'data/train', 'data/test'

train_val_data = datasets.ImageFolder(root=train_dir, transform=data_augmentation)
test_data = datasets.ImageFolder(root=test_dir, transform=data_prep)

In [None]:
# Class balance of the whole training pool (before the train/val split) and of the test set.
print_dataset_stats(train_val_data, "Train val")
print_dataset_stats(test_data, "Test")

## Process Data

In [None]:
# Stratified 80/20 split of the training pool, keeping class proportions equal in both parts.
split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
labels = train_val_data.targets

# With n_splits=1 the generator yields exactly one (train, val) index pair.
train_indices, val_indices = next(split.split(np.zeros(len(labels)), labels))
print(f"Number of 'Train' indices: {len(train_indices)}")
print(f"Number of 'Val' indices: {len(val_indices)}")

train_data = Subset(train_val_data, train_indices)
# Validation indices point into a second ImageFolder over the same directory, built with the
# deterministic `data_prep` pipeline so validation images are not augmented.
val_data = Subset(datasets.ImageFolder(train_dir, transform=data_prep), val_indices)

print_dataset_stats(train_data, "Train")
print_dataset_stats(val_data, "Val")


In [None]:
# Sanity check: the training subset wraps `train_val_data`, so this should display `data_augmentation`.
train_data.dataset.transform

In [None]:
# Sanity check: the validation subset wraps a separate ImageFolder built with `data_prep` (no augmentation).
val_data.dataset.transform

In [None]:
# Create data loaders; the training loader oversamples to counter class imbalance.

# Per-class sample counts over the training split only (validation indices excluded).
train_counts = Counter([train_val_data.targets[i] for i in train_indices])

# Inverse-frequency class weights. Normalizing is not required by WeightedRandomSampler
# (its weights need not sum to one); it only makes the printed weights easier to read.
class_weights_not_normalized = {cls: 1.0 / count for cls, count in train_counts.items()}
total_weights = sum(class_weights_not_normalized.values())
class_weights = {cls: weight / total_weights for cls, weight in class_weights_not_normalized.items()}
print("Class weights:")
print(class_weights)

# Weight of each training sample = weight of its class; sample_weights[k] belongs to train_data[k].
sample_weights = [class_weights[train_val_data.targets[i]] for i in train_indices]

# A dedicated, seeded generator makes the sequence of oversampled indices reproducible across runs.
sampler_generator = torch.Generator().manual_seed(42)
sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True,
    generator=sampler_generator,
)

# Create DataLoader for training data using the sampler
train_loader = DataLoader(train_data, batch_size=64, sampler=sampler, num_workers=num_workers_for_data_loaders)

# Create DataLoader for validation data without any sampler
val_loader = DataLoader(val_data, batch_size=64, shuffle=False, num_workers=num_workers_for_data_loaders)

## Build and Train Model

In [None]:
# Use the CUDA GPU when one is present, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

In [None]:
# Network, loss and optimizer, bundled into the Lightning wrapper that drives training.
model = ConvModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=0.001)

lit_model = ModelContainer(model, criterion, optimizer)

In [None]:
# Both callbacks watch the logged validation loss (lower is better): stop after 5 epochs
# without improvement, and keep only the single best checkpoint under checkpoints/.
monitored_metric = dict(monitor='val_loss', mode='min')

early_stopping_callback = pl.callbacks.EarlyStopping(patience=5, **monitored_metric)

checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath='checkpoints/',
    filename='best-checkpoint',
    save_top_k=1,
    **monitored_metric,
)

In [None]:
# Single-device trainer; validation runs after every epoch so the callbacks can react to val_loss.
trainer = pl.Trainer(
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1,
    max_epochs=100,  # upper bound; EarlyStopping may end training earlier
    check_val_every_n_epoch=1,
    callbacks=[early_stopping_callback, checkpoint_callback],
)

In [None]:
# Train on the oversampled loader, validating on the un-augmented validation loader.
trainer.fit(model=lit_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [None]:
# Plot training and validation loss over epochs.
# NOTE(review): if ModelContainer also records Lightning's pre-training sanity-check validation
# pass, val_losses may hold one more entry than train_losses -- confirm the epochs line up.
fig, ax = plt.subplots()
ax.plot(lit_model.train_losses, label='Train Loss')
ax.plot(lit_model.val_losses, label='Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training and Validation Loss over Epochs')
ax.legend()
plt.show()

## Save and Load Model

In [None]:
# Load the best model.
# Prefer the path the checkpoint callback actually wrote: if 'best-checkpoint.ckpt' already
# existed from an earlier run, ModelCheckpoint saves under a versioned name (e.g.
# 'best-checkpoint-v1.ckpt'), and the fixed path would silently reload the stale file.
# Fall back to it only when the callback has no best checkpoint recorded in this session.
best_checkpoint_path = checkpoint_callback.best_model_path or 'checkpoints/best-checkpoint.ckpt'
# NOTE(review): `optimizer` still references the parameters of the ConvModel built for training,
# not the fresh ConvModel passed here -- harmless for inference, wrong if training resumes.
model = ModelContainer.load_from_checkpoint(best_checkpoint_path, model=ConvModel(), criterion=criterion, optimizer=optimizer)
# Inference behaviour for dropout / batch-norm layers (if ConvModel has any).
model.eval();

## Model Evaluation

In [None]:
# Grad-CAM on one random held-out image: pick a class uniformly, then an image of that class.
eval_data = test_data  # evaluate on unseen test images rather than images the model trained on

# Sample from ImageFolder's own file list so only files it accepted as images are candidates
# (a raw os.listdir can also return non-image entries such as hidden OS files).
category = random.choice(eval_data.classes)
category_idx = eval_data.class_to_idx[category]
image_path = random.choice([path for path, target in eval_data.samples if target == category_idx])

# Inference mode for dropout/batch-norm (if any); autograd stays enabled, which Grad-CAM needs.
model.eval()
# Put inputs wherever the model's weights live, so forward passes cannot hit a device mismatch.
model_device = next(model.parameters()).device

# Create and overlay heatmap
with Image.open(image_path) as image_file:  # close the file once the RGB copy exists
    original_image = image_file.convert('RGB')
transformed_image = data_prep(original_image).unsqueeze(0).to(model_device)
heatmap = model.generate_gradcam_heatmap(transformed_image)
# NOTE(review): the heatmap is computed on the 512x512 center crop but overlaid on the uncropped
# image; confirm overlay_gradcam_heatmap accounts for this when images are larger than 512px.
overlayed_image = model.overlay_gradcam_heatmap(original_image, heatmap)

# Display image with and without heatmap
fig, (ax_original, ax_overlay) = plt.subplots(1, 2, figsize=(12, 6))

ax_original.imshow(original_image)  # Original image (left)
ax_original.set_title('Original Image')
ax_original.axis('off')

ax_overlay.imshow(overlayed_image)  # Image with heatmap (right)
ax_overlay.set_title('Grad-CAM Overlay')
ax_overlay.axis('off')

plt.show()

# Print additional info
print(f'Image path: {image_path}')

with torch.no_grad():  # a plain prediction needs no gradients
    logits = model(transformed_image)
pred = torch.argmax(logits, dim=1).item()
# Decode through the dataset's class list instead of assuming index 1 means 'pos'.
print(f'Predicted: {eval_data.classes[pred]}. Actual: {category}')