In [None]:
# Mount Google Drive (Colab only) so the paths under /content/drive used below are reachable.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip install -q torchinfo

In [None]:
#%cd "/content/drive/MyDrive/Colab Notebooks/Projects/BUSI_bin_augmented"

In [None]:
#@title Imports
import os
import cv2
import random
import albumentations as A
from collections import Counter
from PIL import Image, ImageChops, ImageDraw, ImageFilter

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, random_split
from sklearn.model_selection import KFold
from torchinfo import summary
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from sklearn.metrics import roc_curve, auc
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
from datetime import datetime
from tqdm.notebook import tqdm_notebook
from tqdm import tqdm

In [None]:
# Setting the seed
# Seed every RNG this notebook relies on: Python's `random`, NumPy, and the
# global torch RNG (which initialises the weights of the new classifier head).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Dedicated generator for dataset splitting / loader shuffling
g = torch.Generator().manual_seed(2147483647) # for reproducibility

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

In [None]:
# Root folder of the dataset on the mounted Drive; the cells below expect
# `Benign` and `Malignant` subfolders inside it.
data_dir = "/content/drive/MyDrive/Datasets/BUSI_bin_augmented"

In [None]:
# One subfolder per class; build both class directories from the same list.
subfolders = ['Benign','Malignant']
ben_dir, mal_dir = (os.path.join(data_dir, class_name) for class_name in subfolders)

# os.walk yields (dirpath, dirnames, filenames); only the top-level file list is needed.
ben_files = next(os.walk(ben_dir))[2]
mal_files = next(os.walk(mal_dir))[2]

print(f"Number of benign images: {len(ben_files)}")
print(f"Number of malignant images: {len(mal_files)}")

In [None]:
#@title Data Loading and Preprocessing
# Preprocessing applied to every train/val/test sample.
# The frozen backbone used later in this notebook is torchvision's ViT-B/16 with
# the IMAGENET1K_SWAG_LINEAR_V1 weights, whose bundled transforms resize with
# bicubic interpolation and normalise with the ImageNet mean/std. The inputs are
# prepared the same way here so the frozen features see the statistics they
# were trained on.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

preprocess = transforms.Compose([
    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),  # Resize to 224x224
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
])

# Load the dataset (one class per subfolder of data_dir)
dataset = datasets.ImageFolder(root=data_dir, transform=preprocess)

# Split the dataset into train, validation, and test sets (80% / 10% / remainder)
# NOTE(review): the folder name suggests the images are already augmented; if
# augmented copies of the same scan exist, a purely random split can place them
# in different subsets - confirm against how the dataset was built.
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size

train_set, val_set, test_set = torch.utils.data.random_split(dataset, [train_size, val_size, test_size], generator=g)

# Create data loaders
batch_size = 32
train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True, generator=g)
val_loader = DataLoader(dataset=val_set, batch_size=batch_size, shuffle=False, generator=g)
test_loader = DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False, generator=g)

In [None]:
# Show the first few validation images as a 2-column grid
NUM_IMAGES = 4
examples = torch.stack(tuple(val_set[idx][0] for idx in range(NUM_IMAGES)))
img_grid = torchvision.utils.make_grid(
    examples, nrow=2, normalize=True, pad_value=0.9
).permute(1, 2, 0)  # (C, H, W) -> (H, W, C) for imshow

plt.figure(figsize=(8, 8))
plt.imshow(img_grid)
plt.title("Sample images in dataset")
plt.axis("off")
plt.show()
plt.close()

In [None]:
# Print the dimension of images to verify all loaders have the same dimensions
def print_dim(loader, text):
    """
    Print a banner, the size of the loader's dataset, and the shapes of the
    image and label tensors of the first batch (if the loader yields any).

    Args:
        loader: DataLoader-like object exposing `.dataset` and yielding (image, label) batches
        text (str): Label printed inside the banner
    """
    banner = '---------'
    print(banner + text + banner)
    print(len(loader.dataset))

    # Only the first batch is inspected; an empty loader prints nothing more.
    first_batch = next(iter(loader), None)
    if first_batch is not None:
        batch_images, batch_labels = first_batch
        print(batch_images.shape)
        print(batch_labels.shape)

# Report dataset size and first-batch shapes for each of the three loaders.
print_dim(train_loader,'training loader')
print_dim(val_loader,'validation loader')
print_dim(test_loader,'test loader')

In [None]:
def img_to_patch(x, patch_size, flatten_channels=True):
    """
    Cut a batch of images into non-overlapping square patches.

    Args:
        x (torch.Tensor): Image batch of shape [B, C, H, W]
        patch_size (int): Side length of each patch in pixels; must divide H and W
        flatten_channels (bool): If True, each patch is flattened to one feature vector

    Returns:
        torch.Tensor: [B, num_patches, C * patch_size * patch_size] when
        `flatten_channels` is True, otherwise [B, num_patches, C, patch_size, patch_size].
        Patches are ordered row by row over the patch grid.
    """
    batch, channels, height, width = x.shape
    assert height % patch_size == 0 and width % patch_size == 0, "Image dimensions must be divisible by patch size"

    grid_rows, grid_cols = height // patch_size, width // patch_size

    # Split both spatial axes into (grid index, offset inside the patch) ...
    tiled = x.reshape(batch, channels, grid_rows, patch_size, grid_cols, patch_size)
    # ... then bring the two grid axes forward and merge them into one patch axis.
    patches = tiled.permute(0, 2, 4, 1, 3, 5).reshape(
        batch, grid_rows * grid_cols, channels, patch_size, patch_size
    )

    return patches.flatten(2, 4) if flatten_channels else patches

In [None]:
def visualize_patches(images, patch_size, image_size, titles=None):
    """
    Plot each image (top row) above a grid of its patches (bottom row).

    Args:
        images (torch.Tensor): Image batch of shape [B, C, H, W]
        patch_size (int): Side length of each patch in pixels
        image_size (int): Image side length, used to set how many patches go in one grid row
        titles (list, optional): One caption per image; captions are left blank when omitted
    """
    n_images = images.shape[0]
    img_patches = img_to_patch(images, patch_size=patch_size, flatten_channels=False)
    patches_per_row = image_size // patch_size

    # squeeze=False keeps `axes` two-dimensional even for a single image
    fig, axes = plt.subplots(2, n_images, figsize=(20, 10), squeeze=False)

    for col in range(n_images):
        caption = titles[col] if titles else ''
        image_ax, patch_ax = axes[0, col], axes[1, col]

        # Top row: the image itself, (C, H, W) -> (H, W, C) for imshow
        image_ax.imshow(images[col].permute(1, 2, 0))
        image_ax.axis("off")
        image_ax.set_title(f"Original Image: {caption}", fontsize=12)

        # Bottom row: the same image laid out as a grid of separated patches
        patch_grid = torchvision.utils.make_grid(
            img_patches[col], nrow=patches_per_row, normalize=True, pad_value=0.9
        )
        patch_ax.imshow(patch_grid.permute(1, 2, 0))
        patch_ax.axis("off")
        patch_ax.set_title(f"Image Patches: {caption}", fontsize=12)

    plt.tight_layout()
    plt.show()

In [None]:
def load_and_preprocess_image(image_path, target_size):
    """
    Load an image file, convert it to RGB, resize it and return it as a batched tensor.

    Args:
        image_path (str): Path to the image file
        target_size (tuple): Target size for resizing (height, width)

    Returns:
        torch.Tensor: Image tensor with a leading batch dimension of size 1,
        i.e. shape [1, 3, height, width]
    """
    height, width = target_size

    # Open inside a context manager so the file handle is released right away;
    # convert() returns an independent, fully loaded image.
    with Image.open(image_path) as raw_img:
        img = raw_img.convert('RGB')

    # PIL's resize expects (width, height), the reverse of the documented argument order.
    img = img.resize((width, height))
    img_tensor = torchvision.transforms.ToTensor()(img)
    return img_tensor.unsqueeze(0)

In [None]:
def process_images(img_names, img_dir, target_size):
    """
    Load several images from one directory into a single batch tensor.

    Args:
        img_names (list): File names of the images to load
        img_dir (str): Directory containing the files
        target_size (tuple): Target size passed on to load_and_preprocess_image

    Returns:
        torch.Tensor: Batch of shape [N, C, H, W], one entry per file name
    """
    img_tensors = [
        load_and_preprocess_image(os.path.join(img_dir, img_name), target_size)
        for img_name in img_names
    ]
    # Each loaded tensor already has a leading batch dimension of 1, so the
    # tensors are concatenated along it; stacking would add a spurious 5th axis.
    return torch.cat(img_tensors, dim=0)

In [None]:
# Geometry of the patch demo: square images cut into square patches
image_size, patch_size = 224, 16
num_channels = 3
num_patches = (image_size // patch_size) ** 2  # patches per image

In [None]:
# Pick two .bmp files per class. os.listdir returns entries in arbitrary order,
# so the names are sorted to make the selection the same on every run; the
# extension check ignores case.
ben_imgs = sorted(f for f in os.listdir(ben_dir) if f.lower().endswith('.bmp'))[:2]
mal_imgs = sorted(f for f in os.listdir(mal_dir) if f.lower().endswith('.bmp'))[:2]

image_titles = ben_imgs + mal_imgs

# Load and preprocess images (benign first, then malignant, matching image_titles)
images = [
    load_and_preprocess_image(os.path.join(class_dir, img_name), (image_size, image_size))
    for class_dir, img_names in ((ben_dir, ben_imgs), (mal_dir, mal_imgs))
    for img_name in img_names
]

if not images:
    raise FileNotFoundError(f"No .bmp images found in {ben_dir} or {mal_dir}")

# Combine images into a single tensor (each entry already has a batch dimension of 1)
samples = torch.cat(images, dim=0)

In [None]:
# Visualize patches: each sample image on top, its grid of patches underneath
visualize_patches(samples, patch_size, image_size, titles=image_titles)

# Modelling

In [None]:
#@title Helper Functions
# Early stopping class
class EarlyStopping:
    """
    Track a validation loss and raise a flag once it stops improving.

    A loss counts as an improvement when it is lower than the best loss seen
    so far by more than `delta`. After `patience` consecutive calls without
    improvement, `early_stop` becomes True.
    """

    def __init__(self, patience=5, delta=0):
        self.patience, self.delta = patience, delta
        self.counter = 0          # consecutive calls without improvement
        self.best_loss = None     # lowest loss accepted so far
        self.early_stop = False

    def __call__(self, val_loss):
        improved = self.best_loss is None or val_loss < self.best_loss - self.delta
        if improved:
            self.best_loss, self.counter = val_loss, 0
            return

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True

In [None]:
# Pretrained ViT setup
pretrained_vit_weights = torchvision.models.ViT_B_16_Weights.IMAGENET1K_SWAG_LINEAR_V1 # alternatives: ViT_B_16_Weights, ViT_L_16_Weights, ViT_H_14_Weights
pretrained_vit = torchvision.models.vit_b_16(weights=pretrained_vit_weights) # alternatives: vit_b_16, vit_l_16, vit_h_14

# Freeze base parameters
for parameter in pretrained_vit.parameters():
    parameter.requires_grad = False

# Replace the classification head. Its width is taken from the model and its
# output size from the class list, so neither can drift from the rest of the
# notebook. The new layer is created after freezing and is therefore trainable.
num_classes = len(subfolders)  # Two classes: Malignant and Benign
pretrained_vit.heads = nn.Linear(in_features=pretrained_vit.hidden_dim, out_features=num_classes)

# Move the whole model (backbone and new head) to the target device
pretrained_vit = pretrained_vit.to(device)

In [None]:
# Get automatic transforms from pretrained ViT weights
# NOTE(review): in the visible cells this object is only printed; the loaders
# built earlier use `preprocess` - confirm whether they should use this instead.
pretrained_vit_transforms = pretrained_vit_weights.transforms()
print(pretrained_vit_transforms)

In [None]:
# Print a layer-by-layer summary of the model using torchinfo, including which
# parameters are trainable
summary(model=pretrained_vit,
        input_size=(32, 3, 224, 224), # (batch_size, color_channels, height, width)
        # col_names=["input_size"], # alternative: smaller output
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"]
)

In [None]:
# Function to create dataloaders
def create_dataloaders(data_dir: str, batch_size: int, transform: transforms.Compose, num_workers: int = os.cpu_count()):
    """
    Build train/validation/test DataLoaders from an image folder.

    The folder is read with ImageFolder (one subfolder per class) and split
    70% / 20% / remainder with a fixed seed, so the split is the same on every call.

    Args:
        data_dir (str): Root directory containing one subfolder per class
        batch_size (int): Batch size used by all three loaders
        transform (transforms.Compose): Transform applied to every image
        num_workers (int): Worker processes per loader; falls back to 0 when the
            CPU count could not be determined

    Returns:
        tuple: (train_loader, val_loader, test_loader, class_names)
    """
    # os.cpu_count() may return None, which DataLoader does not accept.
    if num_workers is None:
        num_workers = 0

    # Load dataset
    dataset = datasets.ImageFolder(root=data_dir, transform=transform)

    # Split the dataset into train, validation, and test sets
    train_size = int(0.7 * len(dataset))
    print(f"Train size: {train_size}")
    val_size = int(0.2 * len(dataset))
    print(f"Val size: {val_size}")
    test_size = len(dataset) - train_size - val_size
    print(f"Test size: {test_size}")

    # Count samples per class from the label list ImageFolder already holds,
    # instead of iterating the dataset (which would decode and transform every image).
    class_names = dataset.classes
    label_counts = Counter(dataset.targets)
    class_counts = [label_counts.get(class_idx, 0) for class_idx in range(len(class_names))]

    print("Class counts:")
    for i, class_name in enumerate(class_names):
        print(f"{class_name}: {class_counts[i]}")

    train_set, val_set, test_set = torch.utils.data.random_split(dataset, [train_size, val_size, test_size], generator=torch.Generator().manual_seed(23))

    # Settings shared by all loaders; pinned memory is only requested when a GPU is present.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
    )

    # Turn datasets into DataLoaders (only the training data is shuffled)
    train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_set, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_set, shuffle=False, **loader_kwargs)

    return train_loader, val_loader, test_loader, class_names

In [None]:
# Show the CPU count reported by the runtime (the default `num_workers` of create_dataloaders)
os.cpu_count()

# Pretrained Vision Transformer

In [None]:
def evaluate(model, test_dataloader, loss_fn, device):
    """
    Compute the average loss and accuracy of `model` over a dataloader.

    The model is switched to evaluation mode and gradients are disabled.

    Args:
        model (nn.Module): Model producing one score per class for each sample
        test_dataloader: Iterable of (images, labels) batches
        loss_fn: Callable returning the loss of a batch; the per-batch value is
            weighted by batch size, which assumes it is a batch mean - TODO confirm
            for loss functions other than the default CrossEntropyLoss
        device (torch.device): Device the batches are moved to

    Returns:
        tuple: (average_loss, accuracy) over all samples

    Raises:
        ValueError: If the dataloader yields no samples
    """
    model.eval()  # Set the model to evaluation mode
    test_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    with torch.no_grad():
        for images, labels in test_dataloader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = loss_fn(outputs, labels)

            batch_count = labels.size(0)
            test_loss += loss.item() * batch_count  # Accumulate loss
            _, preds = torch.max(outputs, 1)
            correct_predictions += torch.sum(preds == labels).item()
            total_samples += batch_count

    # Fail with a clear message instead of a bare ZeroDivisionError
    if total_samples == 0:
        raise ValueError("evaluate() received a dataloader that yielded no samples")

    average_loss = test_loss / total_samples
    accuracy = correct_predictions / total_samples

    return average_loss, accuracy

In [None]:
def train(model, train_dataloader, val_dataloader, test_dataloader, loss_fn, optimizer, device, epochs, early_stopping):
    """
    Train `model` for up to `epochs` epochs and record per-epoch metrics.

    After every epoch the model is evaluated on the validation and test
    dataloaders. The test metrics are only recorded and printed; only the
    validation loss is passed to `early_stopping`.

    Args:
        model (nn.Module): Model to train
        train_dataloader: Iterable of (images, labels) training batches
        val_dataloader: Iterable of (images, labels) validation batches
        test_dataloader: Iterable of (images, labels) test batches
        loss_fn: Loss function applied to (outputs, labels)
        optimizer (torch.optim.Optimizer): Optimizer updating the model parameters
        device (torch.device): Device the batches are moved to
        epochs (int): Maximum number of epochs
        early_stopping: Callable taking the validation loss and exposing an
            `early_stop` flag; a falsy value disables early stopping

    Returns:
        tuple: (model, metrics) where `metrics` maps 'train_losses', 'val_losses',
        'train_accuracies', 'val_accuracies', 'test_losses' and 'test_accuracies'
        to lists with one entry per completed epoch. The returned model carries
        the weights of the last epoch that ran; no earlier state is restored.

    Raises:
        ValueError: If the training dataloader yields no samples
    """
    metrics = {
        'train_losses': [],
        'val_losses': [],
        'train_accuracies': [],
        'val_accuracies': [],
        'test_losses': [],
        'test_accuracies': []
    }

    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")

        # Training phase
        model.train()
        training_loss = 0.0
        train_corrects = 0
        total_train_samples = 0
        for images, labels in tqdm(train_dataloader):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_count = labels.size(0)
            training_loss += loss.item() * batch_count
            _, preds = torch.max(outputs, 1)
            train_corrects += torch.sum(preds == labels).item()
            total_train_samples += batch_count

        if total_train_samples == 0:
            raise ValueError("train() received a training dataloader that yielded no samples")

        training_loss /= total_train_samples
        train_acc = train_corrects / total_train_samples

        # Validation and test phases share the same evaluation routine
        # (evaluation mode, no gradients, sample-weighted averages).
        val_loss, val_acc = evaluate(model, val_dataloader, loss_fn, device)
        test_loss, test_accuracy = evaluate(model, test_dataloader, loss_fn, device)

        metrics['train_losses'].append(training_loss)
        metrics['val_losses'].append(val_loss)
        metrics['train_accuracies'].append(train_acc)
        metrics['val_accuracies'].append(val_acc)
        metrics['test_losses'].append(test_loss)
        metrics['test_accuracies'].append(test_accuracy)

        print(f"Epoch [{epoch + 1}/{epochs}], "
              f"Train Loss: {training_loss:.4f}, Train Accuracy: {train_acc:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_acc:.4f}, "
              f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")

        if early_stopping:
            early_stopping(val_loss)
            if early_stopping.early_stop:
                print("Early stopping")
                break

    return model, metrics

In [None]:
#@title Hyperparameters
learning_rate = 1e-4
# Hand the optimizer only the parameters that still require gradients, so it
# does not track the frozen backbone weights.
trainable_params = [p for p in pretrained_vit.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(params=trainable_params, lr=learning_rate)
loss_fn = torch.nn.CrossEntropyLoss()
num_epochs = 80
early_stopping_patience = 10

In [None]:
#@title Training Loop
# Stop once the validation loss has not improved for `early_stopping_patience` epochs
early_stopping = EarlyStopping(patience=early_stopping_patience)

# Train the classifier head of the pretrained ViT
pretrained_vit, pretrained_vit_metrics = train(
    model=pretrained_vit,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    test_dataloader=test_loader,
    loss_fn=loss_fn,
    optimizer=optimizer,
    device=device,
    epochs=num_epochs,
    early_stopping=early_stopping,
)

In [None]:
def plot_metrics(metrics, title):
    """
    Plot loss (left) and accuracy (right) curves from a metrics dictionary.

    Training and validation values are drawn per epoch; the test value of the
    last recorded epoch is drawn as a dashed red horizontal line.

    Args:
        metrics (dict): Lists stored under 'train_losses', 'val_losses',
            'test_losses', 'train_accuracies', 'val_accuracies', 'test_accuracies'
        title (str): Title of the whole figure
    """
    epoch_axis = list(range(1, len(metrics['train_losses']) + 1))

    fig = make_subplots(rows=1, cols=2, subplot_titles=('Loss', 'Accuracy'))

    # One entry per subplot: column, y-axis label, per-epoch curves, test curve
    panels = (
        (1, "Loss",
         (('train_losses', 'Training Loss'), ('val_losses', 'Validation Loss')),
         ('test_losses', 'Test Loss')),
        (2, "Accuracy",
         (('train_accuracies', 'Training Accuracy'), ('val_accuracies', 'Validation Accuracy')),
         ('test_accuracies', 'Test Accuracy')),
    )

    for col, y_label, curves, (test_key, test_name) in panels:
        for key, trace_name in curves:
            fig.add_trace(
                go.Scatter(x=epoch_axis, y=metrics[key], mode='lines', name=trace_name),
                row=1, col=col,
            )

        # Horizontal reference line at the final test value
        final_test = metrics[test_key][-1]
        fig.add_trace(
            go.Scatter(
                x=[epoch_axis[0], epoch_axis[-1]],
                y=[final_test, final_test],
                mode='lines',
                name=test_name,
                line=dict(dash='dash', color='red'),
            ),
            row=1, col=col,
        )

        fig.update_xaxes(title_text="Epoch", row=1, col=col)
        fig.update_yaxes(title_text=y_label, row=1, col=col)

    fig.update_layout(title_text=title, showlegend=True, width=1000, height=500)

    fig.show()

In [None]:
# Plot the per-epoch training and validation curves together with the final test metrics
plot_metrics(pretrained_vit_metrics, "Training, Validation, and Test")

In [None]:
def plot_roc_curve(model, dataloader, device, title="ROC Curve"):
    """
    Draw the ROC curve of a binary classifier over a dataloader.

    The softmax probability of class index 1 is used as the score.

    Args:
        model (nn.Module): Model producing one score per class for each sample
        dataloader: Iterable of (images, labels) batches
        device (torch.device): Device the batches are moved to
        title (str): Plot title
    """
    model.eval()
    y_true, y_scores = [], []

    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            class_probs = torch.softmax(model(images), dim=1)
            y_true.extend(labels.cpu().numpy())
            # binary classification: keep only the probability of class 1
            y_scores.extend(class_probs[:, 1].cpu().numpy())

    fpr, tpr, _ = roc_curve(y_true, y_scores)
    roc_auc = auc(fpr, tpr)
    curve_label = 'ROC curve (area = %0.2f)' % roc_auc

    plt.figure()
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=curve_label)
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')  # chance diagonal
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(title)
    plt.legend(loc="lower right")
    plt.show()

In [None]:
# Plot ROC curve for the test set (uses the default title "ROC Curve")
plot_roc_curve(pretrained_vit, test_loader, device)

In [None]:
#@title Saving the model
def save_model(model, path):
    """
    Save the model's state dict to `path`, creating missing parent folders.

    Args:
        model (nn.Module): Model whose parameters are saved
        path (str): Destination file path
    """
    # torch.save does not create directories, so make sure the target folder exists.
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    torch.save(model.state_dict(), path)
    print(f"Model saved to {path}")

In [None]:
# Destination of the saved state dict on the mounted Drive.
# NOTE(review): this path has no file extension and save_model writes a single
# file at exactly this location - confirm it is meant as a file name, not a folder.
model_path = "/content/drive/MyDrive/Results and output/Breast Cancer Transformer (BCT)"
save_model(pretrained_vit, model_path)