In [1]:
import os
import sys
from vit_pytorch import ViT
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as tdata
import numpy as np
from tqdm import tqdm
import cv2
import wandb

In [2]:
wandb.login()

run = wandb.init(
    project='cancer-classifier-ViT-model-DL-course',
    config={}
    )


sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))

from cancer_classifier.config import MODELS_DIR, RAW_DATA_DIR, PROCESSED_DATA_DIR, INTERIM_DATA_DIR, CLASSES
from cancer_classifier.processing.image_utils import adjust_image_contrast, resize_image_tensor, normalize_image_tensor, augment_image_tensor, process_dataset, save_processed_images, crop_image

%load_ext autoreload
%autoreload 2

[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mayoub-dev2000[0m ([33mayoub-dev2000-university-of-liege[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[32m2025-05-16 18:47:18.139[0m | [1mINFO    [0m | [36mcancer_classifier.config[0m:[36m<module>[0m:[36m11[0m - [1mPROJ_ROOT path is: /home/ayoubvip/deep-learning-cancer-classifier[0m


In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


## Data Preprocessing

In [4]:
img_size = (256, 256)
clip_limit = 2.0
tile_size = (1,1)

### data crop and contract equilazing

In [5]:
# for i, cls in enumerate(CLASSES):
#     cls_dir = os.path.join(RAW_DATA_DIR, cls)
#     for img_name in os.listdir(cls_dir):
#         img_path = os.path.join(cls_dir, img_name)
#         img = crop_image(img_path, img_size, clip_limit, tile_size)
#         # img = adjust_image_contrast(img)

#         # save in processed directory
#         cv2.imwrite(os.path.join(PROCESSED_DATA_DIR, cls, img_name), img)
#         # save_processed_images(img_tensor, os.path.join(PROCESSED_DATA_DIR, cls), img_name)

### data augmentation

In [6]:
transformers = torchvision.transforms.Compose([
    torchvision.transforms.Grayscale(num_output_channels=1),
    torchvision.transforms.Resize(size=img_size),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.RandomVerticalFlip(),
    torchvision.transforms.RandomRotation(10),
    torchvision.transforms.RandomResizedCrop(size=img_size, scale=(0.8, 1.0)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.5], std=[0.5])
])

dataset = torchvision.datasets.ImageFolder(root=PROCESSED_DATA_DIR, transform=transformers, target_transform=None)

### data spliting 

In [7]:
batch_nbr = 127
train_ratio = 0.80
test_ratio = 0.10
val_ratio = 0.10

In [8]:
rand_gen = torch.Generator().manual_seed(142)
train_dataset, val_dataset, test_dataset = tdata.random_split(
    dataset = dataset,
    lengths=[train_ratio, val_ratio, test_ratio],
    generator=rand_gen
)

train_loader = tdata.DataLoader(
    dataset=train_dataset,
    batch_size=batch_nbr,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)
val_loader = tdata.DataLoader(
    dataset=val_dataset,
    batch_size=batch_nbr,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)
test_loader = tdata.DataLoader(
    dataset=test_dataset,
    batch_size=batch_nbr,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)
def get_class_distribution(dataset):
    class_distribution = {}
    for _, label in dataset:
        if label.__str__() not in class_distribution:
            class_distribution[label.__str__()] = 0
        class_distribution[label.__str__()] += 1
    return class_distribution

print("Train dataset class distribution:", get_class_distribution(train_dataset))
print("Validation dataset class distribution:", get_class_distribution(val_dataset))
print("Test dataset class distribution:", get_class_distribution(test_dataset))

Train dataset class distribution: {'2': 1640, '1': 1601, '0': 1604}
Validation dataset class distribution: {'1': 211, '2': 209, '0': 186}
Test dataset class distribution: {'0': 214, '2': 199, '1': 192}


# Model Architecture: Vision Transformer (Dosovitskiy et al., 2021)

In [9]:
class PatchEmbedding(nn.Module):
    def __init__(self, img_size=256, patch_size=32, num_hiddens=512):
        super().__init__()
        def _make_tuple(x):
            if not isinstance(x, (list, tuple)):
                return (x, x)
            return x
        img_size, patch_size = _make_tuple(img_size), _make_tuple(patch_size)
        self.num_patches = (img_size[0] // patch_size[0]) * (
            img_size[1] // patch_size[1])
        self.conv = nn.LazyConv2d(num_hiddens, kernel_size=patch_size,
                                  stride=patch_size)

    def forward(self, X):
        # Output shape: (batch size, no. of patches, no. of channels)
        return self.conv(X).flatten(2).transpose(1, 2)

In [10]:
class ViTMLP(nn.Module):
    def __init__(self, mlp_num_hiddens, mlp_num_outputs, dropout=0.5):
        super().__init__()
        self.dense1 = nn.LazyLinear(mlp_num_hiddens)
        self.gelu = nn.GELU()
        self.dropout1 = nn.Dropout(dropout)
        self.dense2 = nn.LazyLinear(mlp_num_outputs)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout2(self.dense2(self.dropout1(self.gelu(
            self.dense1(x)))))
        
def masked_softmax(X, valid_lens):
    """Perform softmax operation by masking elements on the last axis."""
    # X: 3D tensor, valid_lens: 1D or 2D tensor
    def _sequence_mask(X, valid_len, value=0):
        maxlen = X.size(1)
        mask = torch.arange((maxlen), dtype=torch.float32,
                            device=X.device)[None, :] < valid_len[:, None]
        X[~mask] = value
        return X

    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # On the last axis, replace masked elements with a very large negative
        # value, whose exponentiation outputs 0
        X = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)
    
class DotProductAttention(nn.Module):
    """Scaled dot product attention."""
    def __init__(self, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    # Shape of queries: (batch_size, no. of queries, d)
    # Shape of keys: (batch_size, no. of key-value pairs, d)
    # Shape of values: (batch_size, no. of key-value pairs, value dimension)
    # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        # Swap the last two dimensions of keys with keys.transpose(1, 2)
        scores = torch.bmm(queries, keys.transpose(1, 2)) / np.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)
        
class MultiHeadAttention(nn.Module):
    """Multi-head attention."""
    def __init__(self, num_hiddens, num_heads, dropout, bias=False, **kwargs):
        super().__init__()
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.W_q = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_k = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_v = nn.LazyLinear(num_hiddens, bias=bias)
        self.W_o = nn.LazyLinear(num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # Shape of queries, keys, or values:
        # (batch_size, no. of queries or key-value pairs, num_hiddens)
        # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
        # After transposing, shape of output queries, keys, or values:
        # (batch_size * num_heads, no. of queries or key-value pairs,
        # num_hiddens / num_heads)
        queries = self.W_q(queries).T
        keys = self.W_k(keys).T
        values = self.W_v(values).T

        if valid_lens is not None:
            # On axis 0, copy the first item (scalar or vector) for num_heads
            # times, then copy the next item, and so on
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        # Shape of output: (batch_size * num_heads, no. of queries,
        # num_hiddens / num_heads)
        output = self.attention(queries, keys, values, valid_lens)
        # Shape of output_concat: (batch_size, no. of queries, num_hiddens)
        output_concat = output.T
        return self.W_o(output_concat)
        
class ViTBlock(nn.Module):
    def __init__(self, num_hiddens, norm_shape, mlp_num_hiddens,
                 num_heads, dropout, use_bias=False):
        super().__init__()
        self.ln1 = nn.LayerNorm(norm_shape)
        self.attention = MultiHeadAttention(num_hiddens, num_heads,
                                                dropout, use_bias)
        self.ln2 = nn.LayerNorm(norm_shape)
        self.mlp = ViTMLP(mlp_num_hiddens, num_hiddens, dropout)

    def forward(self, X, valid_lens=None):
        X = X + self.attention(*([self.ln1(X)] * 3), valid_lens)
        return X + self.mlp(self.ln2(X))

In [11]:
class ViTModel(nn.Module):
    """Vision Transformer."""
    def __init__(self, img_size, patch_size, num_hiddens, mlp_num_hiddens,
                 num_heads, num_blks, emb_dropout, blk_dropout, lr=0.1,
                 use_bias=False, num_classes=10):
        super().__init__()
        # self.save_hyperparameters()
        self.patch_embedding = PatchEmbedding(
            img_size, patch_size, num_hiddens)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, num_hiddens))
        num_steps = self.patch_embedding.num_patches + 1  # Add the cls token
        # Positional embeddings are learnable
        self.pos_embedding = nn.Parameter(
            torch.randn(1, num_steps, num_hiddens))
        self.dropout = nn.Dropout(emb_dropout)
        self.blks = nn.Sequential()
        for i in range(num_blks):
            self.blks.add_module(f"{i}", ViTBlock(
                num_hiddens, num_hiddens, mlp_num_hiddens,
                num_heads, blk_dropout, use_bias))
        self.head = nn.Sequential(nn.LayerNorm(num_hiddens),
                                  nn.Linear(num_hiddens, num_classes))

    def forward(self, X):
        X = self.patch_embedding(X)
        X = torch.cat((self.cls_token.expand(X.shape[0], -1, -1), X), 1)
        X = self.dropout(X + self.pos_embedding)
        for blk in self.blks:
            X = blk(X)
        return self.head(X[:, 0])

In [12]:
patch_size = 16
num_hiddens = 516
mlp_num_hiddens = 156
num_heads = 8
num_blks = 2
emb_dropout = 0.0001
blk_dropout = 0.0001

In [13]:
Vit_model = ViTModel(img_size, patch_size, num_hiddens, mlp_num_hiddens,
                 num_heads, num_blks, emb_dropout, blk_dropout, lr=0.1,
                 use_bias=False, num_classes=len(CLASSES)).to(device)

## Training the model

### defining Loss function, Optimization method, and training parameters

In [14]:
loss_fn = nn.CrossEntropyLoss()
weights_decay = 0.0001
learning_rate = 0.0001
epochs = 5

# optimizer = torch.optim.Adam(
#     Vit_model.parameters(),
#     lr=learning_rate,
#     weight_decay=weights_decay
# ) #->70%

optimizer = torch.optim.AdamW(
    Vit_model.parameters(),
    lr=learning_rate,
    weight_decay=weights_decay
) #->77.4

# optimizer = torch.optim.SGD(
#     Vit_model.parameters(),
#     lr=learning_rate,
#     weight_decay=weights_decay
# ) #-> 77.5

In [15]:
def test_loop(dataloader, model, loss_fn, epoch=0):
    # Set the model to evaluation mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0
    true_labels = []
    pred_labels = []
    # Evaluating the model with torch.no_grad() ensures that no gradients are computed during test mode
    # also serves to reduce unnecessary gradient computations and memory usage for tensors with requires_grad=True
    with torch.no_grad():
        for X, y in dataloader:
            X = X.to(device)
            y = y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            true_labels.extend(y.cpu().numpy())
            pred_labels.extend(pred.argmax(1).cpu().numpy())

    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    wandb.log({"epoch": epoch,"accuracy": f"{(100*correct):>0.1f}", "test_loss": test_loss})


    
    
    return true_labels, pred_labels, correct, test_loss


In [17]:
from cancer_classifier.modeling.train import train
# import evaluate from cancer_classifier.modeling.evaluate as evaluate
def train_loop(dataloader, model, loss_fn, optimizer, epoch):
    size = len(dataloader.dataset)
    # Set the model to training mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        X = X.to(device)
        y = y.to(device) 
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * batch_nbr + len(X)
            wandb.log({"step":  f"{current:>5d}/{size:>5d}", "train_loss": loss})
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

            

for epoch in tqdm(range(epochs)):
    print("----------------- EPOCH " + str(epoch) + "------------------")
    train_loop(
        dataloader=train_loader,
        model=Vit_model,
        loss_fn=loss_fn,
        optimizer=optimizer,
        epoch=epoch
        )
    true_labels, pred_labels, correct, test_loss =  test_loop(
        dataloader=test_loader,
        model=Vit_model,
        loss_fn=loss_fn,
        epoch=epoch
    )


  0%|          | 0/5 [00:00<?, ?it/s]

----------------- EPOCH 0------------------


  queries = self.W_q(queries).T


loss: 1.103517  [  127/ 4845]



 20%|██        | 1/5 [00:10<00:40, 10.13s/it]

Test Error: 
 Accuracy: 54.9%, Avg loss: 0.881669 

----------------- EPOCH 1------------------
loss: 0.875471  [  127/ 4845]



 40%|████      | 2/5 [00:17<00:24,  8.21s/it]

Test Error: 
 Accuracy: 62.8%, Avg loss: 0.803235 

----------------- EPOCH 2------------------
loss: 0.848864  [  127/ 4845]



 60%|██████    | 3/5 [00:23<00:15,  7.53s/it]

Test Error: 
 Accuracy: 65.6%, Avg loss: 0.760624 

----------------- EPOCH 3------------------
loss: 0.782124  [  127/ 4845]



 80%|████████  | 4/5 [00:30<00:07,  7.27s/it]

Test Error: 
 Accuracy: 67.1%, Avg loss: 0.738104 

----------------- EPOCH 4------------------
loss: 0.828504  [  127/ 4845]


100%|██████████| 5/5 [00:37<00:00,  7.49s/it]

Test Error: 
 Accuracy: 66.6%, Avg loss: 0.747415 






In [18]:
true_labels, pred_labels, correct, test_loss =  test_loop(
    dataloader=test_loader,
    model=Vit_model,
    loss_fn=loss_fn
)
# train(
#     model=Vit_model,
#     train_loader=train_loader,
#     val_loader=val_loader,
#     test_loader=test_loader,
#     loss_fn=loss_fn,
#     optimizer=optimizer,
#     epochs=epochs,
#     device=device
# )

Test Error: 
 Accuracy: 68.3%, Avg loss: 0.724880 



## Evaluating the model

### Covariance Matrix

In [None]:
from cancer_classifier.plots import visualize_sample_images,  plot_confusion_matrix
import time

visualize_sample_images(
    dataset=dataset,
)
plot_confusion_matrix(
    true_classes=true_labels,
    predicted_classes=pred_labels,
    model_name =  time.strftime("%Y-%m-%d_%H-%M-%S") + "_Vit_model_" + f"{(100*correct):>0.1f}%"
)

### serializing the model

In [None]:
import time
model_name = f"{(100*correct):>0.1f}%--" + time.strftime("%Y-%m-%d_%H-%M-%S") + "_Vit_model.pth"
torch.save(Vit_model.state_dict(), MODELS_DIR / model_name)
print("Saved PyTorch Model State to " +  model_name)

## References

- https://d2l.ai/chapter_attention-mechanisms-and-transformers/vision-transformer.html
- https://github.com/lucidrains/vit-pytorch
- https://docs.pytorch.org/tutorials/beginner/basics/transforms_tutorial.html
