In [None]:
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from torchvision import transforms
from medmnist import DermaMNIST

# Vision Transformer Model

In [None]:
class Attention(nn.Module):
    """Parameter-free scaled dot-product attention.

    Computes ``softmax(Query @ Key^T / sqrt(D)) @ Value`` over the last two
    dimensions, so any number of leading (batch / head) dimensions is allowed.
    """

    def __init__(self):
        super().__init__()

    def forward(self, Query, Key, Value, D):
        """Attend ``Query`` over ``Key`` and mix ``Value`` with the weights.

        Args:
            Query: tensor of shape (..., T_q, d).
            Key: tensor of shape (..., T_k, d).
            Value: tensor of shape (..., T_k, d_v).
            D: feature size used for the 1/sqrt(D) temperature.

        Returns:
            Tensor of shape (..., T_q, d_v).
        """
        scale = torch.sqrt(torch.tensor(D, dtype=torch.float32))
        # Dividing the raw scores by sqrt(D) keeps their variance roughly
        # independent of the feature size before the softmax.
        scores = Query @ Key.transpose(-2, -1) / scale
        weights = scores.softmax(dim=-1)  # normalise over the key axis
        return weights @ Value

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over a sequence of D-dimensional tokens.

    The input is linearly projected to queries, keys and values, split into
    ``num_heads`` heads of ``D // num_heads`` features each, attended per head
    with scaled dot-product attention, re-concatenated, and projected back to D.

    Args:
        num_heads: number of attention heads; must divide ``D``.
        D: token (latent) dimension of both input and output.
    """

    def __init__(self, num_heads, D):
        super(MultiHeadAttention, self).__init__()
        # Integer modulo rather than float division, so the check is exact.
        assert D % num_heads == 0, f"D must be divisible by heads, got heads: {num_heads}, D: {D}"
        self.num_heads = num_heads
        self.dimensions = D // num_heads  # features per head
        self.D = D

        self.Q_Linear = nn.Linear(D, D)
        self.K_Linear = nn.Linear(D, D)
        self.V_Linear = nn.Linear(D, D)

        self.attention = Attention()

        self.final_layer = nn.Linear(D, D)

    def forward(self, X):
        """Apply multi-head self-attention to ``X``.

        Args:
            X: tokens of shape (..., N, D); any number of leading batch
               dimensions is accepted, including none.

        Returns:
            Tensor with the same shape as ``X``.
        """
        batch_size = X.shape[:-2]

        # Project, split the last dim into heads, then move the head axis in
        # front of the token axis: (..., N, D) -> (..., heads, N, D / heads).
        Query = self.Q_Linear(X).view(*batch_size, -1, self.num_heads, self.dimensions).transpose(-3, -2)
        Key = self.K_Linear(X).view(*batch_size, -1, self.num_heads, self.dimensions).transpose(-3, -2)
        Value = self.V_Linear(X).view(*batch_size, -1, self.num_heads, self.dimensions).transpose(-3, -2)

        # Use the module's own attention (the previous code referenced an
        # undefined global `attention`) and scale by the per-head width.
        heads = self.attention(Query, Key, Value, self.dimensions)

        # Undo the head split: (..., heads, N, d) -> (..., N, heads, d) -> (..., N, D).
        # self.D replaces a reference to an undefined global `D`.
        concatenated_heads = heads.transpose(-3, -2).contiguous().view(*batch_size, -1, self.D)

        return self.final_layer(concatenated_heads)

In [None]:
class TransformerEncoder(nn.Module):
    """One pre-norm Transformer encoder block.

    Computes::

        h   = X + MSA(LN_1(X))
        out = h + MLP(LN_2(h))

    where MSA is multi-head self-attention and MLP is Linear -> GELU -> Linear.

    Args:
        num_heads: attention heads for ``MultiHeadAttention``.
        D: token dimension (input and output size).
        mlp_width: hidden width of the feed-forward MLP.
    """

    def __init__(self, num_heads, D, mlp_width=248):
        super().__init__()
        # Submodules are created in forward-pass order; keeping this order
        # keeps parameter registration (state_dict / optimizer order) stable.
        self.norm_1 = nn.LayerNorm(D)
        self.multi_head_attention = MultiHeadAttention(num_heads, D)
        self.norm_2 = nn.LayerNorm(D)
        self.MLP = nn.Sequential(
            nn.Linear(D, mlp_width),
            nn.GELU(),
            nn.Linear(mlp_width, D),
        )

    def forward(self, X):
        # Attention sub-layer with residual connection.
        attended = X + self.multi_head_attention(self.norm_1(X))
        # Feed-forward sub-layer with residual connection.
        return attended + self.MLP(self.norm_2(attended))

In [None]:
class LinearProjectionOfFlattenedPatches(nn.Module):
    """Split images into non-overlapping P x P patches and embed each one.

    An input of shape (..., C, H, W) becomes (..., N, D) with
    N = (H / P) * (W / P). Any number of leading batch dimensions is
    accepted, including none. Patches are ordered row-major over the patch
    grid. Each flattened patch lists all P*P pixels of channel 0, then those
    of channel 1, and so on, before the linear projection to D.

    Args:
        in_channels: number of image channels C.
        heigth: image height H (spelling kept for backward compatibility).
        width: image width W.
        patch_size: patch side length P; must divide both H and W.
        D: output embedding dimension.
    """

    def __init__(self, in_channels, heigth, width, patch_size,  D):
        super(LinearProjectionOfFlattenedPatches, self).__init__()
        self.img_size = (heigth, width)
        self.patch_size = patch_size  # P
        self.in_channels = in_channels  # C

        # Each side must be divisible on its own. A check on H*W / P**2 alone
        # accepts e.g. 2x8 with P=4, which unfold cannot tile.
        assert heigth % patch_size == 0 and width % patch_size == 0, \
            f"num_patches must be divisible by patch size: {patch_size}, size: {self.img_size}"
        self.N = (heigth // patch_size) * (width // patch_size)
        patch_dimensions = patch_size * patch_size * in_channels

        # Linear projection of patches into latent vector of dimension D
        self.projection = nn.Linear(patch_dimensions, D)

    def forward(self, x):
        """Embed the patches of ``x`` of shape (..., C, H, W) into (..., N, D)."""
        batch_size = x.shape[:-3]
        P = self.patch_size

        # Negative dims make the same code correct for unbatched (C, H, W) and
        # batched (..., C, H, W) input. Absolute dims 2/3 pointed at the wrong
        # axes when there was no batch dimension.
        # (..., C, H, W) -> (..., C, H/P, W, P) -> (..., C, H/P, W/P, P, P)
        x = x.unfold(-2, P, P).unfold(-2, P, P)
        # -> (..., C, N, P*P) -> (..., N, C, P*P) -> (..., N, C*P*P)
        x = x.reshape(*batch_size, self.in_channels, self.N, -1)
        x = x.transpose(-3, -2).reshape(*batch_size, self.N, -1)

        # Apply linear projection to each patch
        return self.projection(x)

In [None]:
class VisionTransformer(nn.Module):
    """Vision Transformer (ViT) image classifier.

    Pipeline: patch embedding -> prepend a learnable [class] token -> add a
    learnable position embedding -> L encoder blocks -> MLP head applied to
    the [class] token only.

    Args:
        in_channels, height, width: input image shape (C, H, W).
        patch_size: patch side length P; must divide both H and W.
        num_heads: attention heads per encoder block.
        D: latent token dimension.
        L: number of encoder blocks.
        num_classes: number of output logits.
        mlp_width: hidden width of each encoder block's MLP.
        head_width: hidden width of the classification head.
    """

    def __init__(self,
                 in_channels,
                 height,
                 width,
                 patch_size,
                 num_heads,
                 D,
                 L,
                 num_classes,
                 mlp_width = 248,
                 head_width = 248):
        super(VisionTransformer, self).__init__()
        self.img_size = (height, width)

        # Number of patches N; each side must be divisible by P on its own.
        assert height % patch_size == 0 and width % patch_size == 0, \
            f"num_patches must be divisible by patch size: {patch_size}, size: {self.img_size}"
        N = (height // patch_size) * (width // patch_size)

        # Extra learnable [class] embedding
        self.class_embedding = nn.Parameter(torch.randn(1, D))

        # Learnable position embedding: one row per patch plus the [class] token
        self.position_embedding = nn.Parameter(torch.randn(N + 1, D))

        # ----- Building the model (operations in forward order) ----------
        self.linear_projection_of_flattened_patches = LinearProjectionOfFlattenedPatches(
            in_channels, height, width, patch_size, D)

        # L encoder blocks in sequence (state_dict keys: transformer_encoder.0, .1, ...)
        self.transformer_encoder = nn.Sequential(
            *(TransformerEncoder(num_heads, D, mlp_width) for _ in range(L)))

        # Only the [class] token (position 0 of the token sequence) reaches the head
        self.head = nn.Sequential(
            nn.Linear(D, head_width),
            nn.GELU(),
            nn.Linear(head_width, num_classes)
        )

    def forward(self, X):
        """Return logits of shape (..., num_classes) for images of shape (..., C, H, W)."""
        batch_size = X.shape[:-3]
        image_patches = self.linear_projection_of_flattened_patches(X)
        # expand() broadcasts the single [class] token over every leading batch
        # dim without copying. It works for zero, one or several batch dims.
        class_embedding = self.class_embedding.expand(*batch_size, 1, -1)
        tokens = torch.cat((class_embedding, image_patches), dim=-2)
        tokens = tokens + self.position_embedding  # broadcast over batch dims
        tokens = self.transformer_encoder(tokens)
        # Classification uses only the [class] token at position 0.
        class_token = tokens[..., 0, :]
        return self.head(class_token)

# CNN Model

In [None]:
class CNN(nn.Module):
    """Small two-stage convolutional baseline classifier.

    Two rounds of Conv(3x3, padding 1) -> BatchNorm -> ReLU -> MaxPool(2) are
    followed by a 128-unit fully connected layer and a linear classifier.
    Each pooling stage halves the spatial size (rounding down), so a 28x28
    input reaches the classifier as 32 x 7 x 7 features.

    Args:
        in_channels: number of input image channels.
        num_classes: number of output logits.
        height, width: input image size, used only to size the first fully
            connected layer. The 28x28 defaults keep the original behaviour.
    """

    def __init__(self, in_channels, num_classes, height=28, width=28):
        super(CNN, self).__init__()
        # padding=1 convs keep the size; each MaxPool2d(2, 2) floors it in half.
        flat_features = 32 * (height // 2 // 2) * (width // 2 // 2)

        self.model = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=3, stride=1, padding=1), # H x W
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0), # H/2 x W/2
            nn.Conv2d(64, 32, 3, 1, 1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0), # H/4 x W/4
            nn.Flatten(),
            nn.Linear(flat_features, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

    def forward(self, X):
        """Return logits of shape (batch, num_classes) for images (batch, C, H, W)."""
        return self.model(X)

# Loading the Data

In [None]:
from torchvision import transforms
from medmnist import DermaMNIST
from torch.utils.data import DataLoader

# Training images get light augmentation (horizontal flip, small rotation,
# brightness jitter). Validation/test images are only converted to tensors.
# Both pipelines finish with Normalize(0.5, 0.5), i.e. (x - 0.5) / 0.5, which
# maps ToTensor's [0, 1] range to [-1, 1].
train_transformations = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2),
        transforms.ToTensor(),
        transforms.Normalize(0.5, 0.5),
    ]
)

val_transformations = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(0.5, 0.5),
    ]
)

# One DermaMNIST dataset per split. The test split reuses the deterministic
# validation pipeline.
train_download, val_download, test_download = (
    DermaMNIST(split=split_name, transform=split_transform, download=True)
    for split_name, split_transform in (
        ("train", train_transformations),
        ("val", val_transformations),
        ("test", val_transformations),
    )
)

# Model Init

In [None]:
# Seed torch's global RNG once, before the first random draw below. Those
# draws are the augmented sample read for its shape, the torch.randn
# embeddings, and the layer initialisation. With the seed, a fresh
# "Restart & Run All" rebuilds the same models.
SEED = 42
torch.manual_seed(SEED)

# Input geometry (C, H, W) is read from one transformed training sample.
channels, height, width = train_download[0][0].shape
patch_size = 4           # must divide both height and width
num_heads = 8            # must divide latent_vector
latent_vector = 64       # ViT token dimension D
transformer_blocks = 2   # number of encoder blocks L
num_classes = 7          # number of DermaMNIST classes
mlp_width = 248          # hidden width of each encoder MLP
head_width = 248         # hidden width of the ViT classification head

vit = VisionTransformer(
    in_channels=channels,
    height=height,
    width=width,
    patch_size=patch_size,
    num_heads=num_heads,
    D=latent_vector,
    L=transformer_blocks,
    num_classes=num_classes,
    mlp_width=mlp_width,
    head_width=head_width
)

cnn = CNN(in_channels=channels, num_classes=num_classes)

# Training Parameters for ViT

In [None]:
from torch.utils.data import DataLoader
import torch.optim as optim

# Hyperparameters and data loaders for the ViT run.
criterion = nn.CrossEntropyLoss()
lr = 1e-3
num_epochs = 30
batch_size = 64
weight_decay = 0
checkpoint_every_th_epoch = None  # e.g. 5 saves every 5th epoch; None disables checkpoints
vit_optimizer = optim.Adam(vit.parameters(), lr=lr, weight_decay=weight_decay)

train_data = DataLoader(train_download, batch_size=batch_size, shuffle=True)
# Validation metrics are order-independent sums, so shuffling adds nothing.
# Leaving it off avoids drawing from the RNG between training epochs and
# keeps runs more reproducible.
val_data = DataLoader(val_download, batch_size=batch_size, shuffle=False)
# The whole test split is evaluated as a single batch.
test_data = DataLoader(test_download, batch_size=len(test_download), shuffle=False)

# ViT Training Loop

In [None]:
for epoch in range(num_epochs):
    # Training phase
    vit.train()
    train_loss = 0.0
    for images, labels in train_data:
        # Labels are squeezed from (batch, 1) to the (batch,) class indices
        # that CrossEntropyLoss expects.
        labels = labels.squeeze(1)

        vit_optimizer.zero_grad()

        outputs = vit(images)
        loss = criterion(outputs, labels)
        loss.backward()
        vit_optimizer.step()

        # criterion averages over the batch; re-weight by batch size so the
        # epoch figure is a per-sample mean even when the last batch is short.
        train_loss += loss.item() * images.size(0)

    train_loss /= len(train_data.dataset)

    # Validation phase
    vit.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_data:
            labels = labels.squeeze(1)

            outputs = vit(images)
            loss = criterion(outputs, labels)

            val_loss += loss.item() * images.size(0)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_loss /= len(val_data.dataset)
    val_accuracy = correct / total

    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}')

    # Save a checkpoint every `checkpoint_every_th_epoch` epochs. The "vit_"
    # prefix keeps these files distinct from other models' checkpoints saved
    # with the same epoch numbers in the same directory.
    if checkpoint_every_th_epoch and (epoch + 1) % checkpoint_every_th_epoch == 0:
        torch.save(vit.state_dict(), f'vit_checkpoint_epoch_{epoch+1}.pth')

# Training Parameters for CNN

In [None]:
# Hyperparameters and data loaders for the CNN run (same settings as the ViT run).
criterion = nn.CrossEntropyLoss()
lr = 1e-3
num_epochs = 30
batch_size = 64
weight_decay = 0
checkpoint_every_th_epoch = None  # e.g. 5 saves every 5th epoch; None disables checkpoints
cnn_optimizer = optim.Adam(cnn.parameters(), lr=lr, weight_decay=weight_decay)

train_data = DataLoader(train_download, batch_size=batch_size, shuffle=True)
# Validation metrics are order-independent sums, so shuffling adds nothing
# and only draws from the RNG between training epochs.
val_data = DataLoader(val_download, batch_size=batch_size, shuffle=False)
# The whole test split is evaluated as a single batch.
test_data = DataLoader(test_download, batch_size=len(test_download), shuffle=False)

# CNN Training Loop

In [None]:
for epoch in range(num_epochs):
    # Training phase
    cnn.train()
    train_loss = 0.0
    for images, labels in train_data:
        # Labels are squeezed from (batch, 1) to the (batch,) class indices
        # that CrossEntropyLoss expects.
        labels = labels.squeeze(1)

        cnn_optimizer.zero_grad()

        outputs = cnn(images)
        loss = criterion(outputs, labels)
        loss.backward()
        cnn_optimizer.step()

        # criterion averages over the batch; re-weight by batch size so the
        # epoch figure is a per-sample mean even when the last batch is short.
        train_loss += loss.item() * images.size(0)

    train_loss /= len(train_data.dataset)

    # Validation phase
    cnn.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_data:
            labels = labels.squeeze(1)

            outputs = cnn(images)
            loss = criterion(outputs, labels)

            val_loss += loss.item() * images.size(0)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    val_loss /= len(val_data.dataset)
    val_accuracy = correct / total

    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}')

    # Save a checkpoint every `checkpoint_every_th_epoch` epochs. The "cnn_"
    # prefix stops these files from overwriting other models' checkpoints
    # saved with the same epoch numbers in the same directory.
    if checkpoint_every_th_epoch and (epoch + 1) % checkpoint_every_th_epoch == 0:
        torch.save(cnn.state_dict(), f'cnn_checkpoint_epoch_{epoch+1}.pth')

# Evaluating our Models

In [None]:
import numpy as np

def evaluate_model(model, test_loader, criterion, num_classes):
    """Evaluate a classifier on a labelled data loader.

    Args:
        model: module mapping a batch of inputs to (batch, num_classes) logits.
        test_loader: DataLoader yielding (inputs, labels) with integer class
            labels of shape (batch, 1) or (batch,).
        criterion: loss that averages over the batch (e.g. CrossEntropyLoss).
        num_classes: number of classes; sizes the confusion matrix.

    Returns:
        (test_loss, test_accuracy, class_accuracy_dict):
        - test_loss: per-sample mean loss over the whole loader.
        - test_accuracy: overall accuracy as a fraction in [0, 1].
        - class_accuracy_dict: class index -> recall in percent (0-100).
          Classes with no samples in the loader map to NaN.
    """
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0

    # Confusion matrix: rows are true labels, columns are predictions.
    confusion_matrix = np.zeros((num_classes, num_classes), dtype=np.int64)

    with torch.no_grad():
        for images, labels in test_loader:
            # Flatten (batch, 1) labels to (batch,); already-flat labels pass
            # through unchanged (squeeze(1) would fail on them).
            labels = labels.reshape(-1)

            outputs = model(images)
            loss = criterion(outputs, labels)

            # criterion averages over the batch, so re-weight by batch size.
            test_loss += loss.item() * images.size(0)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Vectorised update. np.add.at accumulates correctly when the same
            # (true, predicted) pair occurs several times in one batch.
            np.add.at(confusion_matrix,
                      (labels.cpu().numpy(), predicted.cpu().numpy()), 1)

    test_loss /= len(test_loader.dataset)
    test_accuracy = correct / total

    # Per-class recall in percent. A class with no samples would divide by
    # zero; report NaN for it explicitly instead of emitting a RuntimeWarning.
    per_class_counts = confusion_matrix.sum(axis=1)
    class_accuracy = np.full(num_classes, np.nan)
    np.divide(100 * confusion_matrix.diagonal(), per_class_counts,
              out=class_accuracy, where=per_class_counts > 0)
    class_accuracy_dict = {idx: float(acc) for idx, acc in enumerate(class_accuracy)}

    return test_loss, test_accuracy, class_accuracy_dict



# Evaluate both models on the held-out test split. num_classes comes from the
# model-init cell (instead of a hard-coded 7), so the confusion matrix always
# matches the models' output size.

# Evaluate Vision Transformer
vit_test_loss, vit_test_accuracy, vit_class_accuracy = evaluate_model(vit, test_data, criterion, num_classes)
print('--- Vision Transformer Performance ---')
print(f'Test Loss: {vit_test_loss:.4f}, Test Accuracy: {vit_test_accuracy:.4f}\n')
print(f'Class-wise Accuracy:\n{vit_class_accuracy}\n')

# Evaluate CNN
cnn_test_loss, cnn_test_accuracy, cnn_class_accuracy = evaluate_model(cnn, test_data, criterion, num_classes)
print('--- CNN Performance ---')
print(f'Test Loss: {cnn_test_loss:.4f}, Test Accuracy: {cnn_test_accuracy:.4f}\n')
print(f'Class-wise Accuracy:\n{cnn_class_accuracy}')

--- Vision Transformer Performance ---
Test Loss: 0.6937, Test Accuracy: 0.7406

Class-wise Accuracy:
{0: 33.333333333333336, 1: 37.86407766990291, 2: 41.36363636363637, 3: 0.0, 4: 10.762331838565023, 5: 96.42058165548099, 6: 55.172413793103445}

--- CNN Performance ---
Test Loss: 0.6621, Test Accuracy: 0.7566

Class-wise Accuracy:
{0: 27.272727272727273, 1: 62.13592233009709, 2: 43.18181818181818, 3: 13.043478260869565, 4: 24.2152466367713, 5: 94.33258762117822, 6: 62.06896551724138}
