# Let's build a Vision Transformer (ViT) using ResNet-152

Here we extract features from images using ResNet-152, split them into patches, add positional embeddings, and finally pass them to the encoder.

In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, confusion_matrix
import seaborn as sns

from zipfile import ZipFile
from matplotlib import pyplot as plt
from tqdm import tqdm
from collections import Counter
from pathlib import Path
import time

from PIL import Image
import numpy as np
import pandas as pd

!pip install funcyou -q
from funcyou.utils import DotDict, dir_walk
from funcyou.dataset import download_kaggle_resource
from funcyou.preprocessing.image import Patcher
from funcyou.pytorch.utils import calculate_class_weights_from_directory
from funcyou.sklearn.metrics import calculate_results

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Which architecture this run uses; it selects the checkpoint file below.
model_name = 'vit'

# Checkpoint file for every supported model name.
MODEL_PATHS = {
    'vitres': Path('models/vitres.pth'),
    'vit': Path('models/vit1.pth'),
    'res': Path('models/res.pth'),
}

# Fail here with a clear message instead of leaving model_path undefined,
# which would only surface later as a NameError.
if model_name not in MODEL_PATHS:
    raise ValueError(
        f"Unknown model_name {model_name!r}; expected one of {sorted(MODEL_PATHS)}"
    )
model_path = MODEL_PATHS[model_name]

In [None]:
# !ls models

In [None]:
# Show which checkpoint path was selected for this run.
model_path

In [None]:
# list(model_path.parent.iterdir())

In [None]:
# Make sure the checkpoint directory exists before anything is saved.
# parents=True also covers a nested checkpoint path.
model_path.parent.mkdir(parents=True, exist_ok=True)

# Config

In [None]:
# Hyper-parameters and derived sizes shared by the data pipeline and the model.
config = DotDict()
config.num_layers = 4
config.resnet_layers = 2

config.hidden_dim = 120  # must be a multiple of num_heads
config.mlp_dim = 2048
config.num_heads = 12
config.dropout_rate = 0.1
config.image_size = 512  # must be a multiple of patch_size
config.patch_size = 32   # should be a multiple of 8

# Fail early when the divisibility constraints above are violated, rather than
# with a shape error deep inside the model.
assert config.hidden_dim % config.num_heads == 0, "hidden_dim must be divisible by num_heads"
assert config.image_size % config.patch_size == 0, "image_size must be divisible by patch_size"

# Patches per image: a square grid of (image_size / patch_size) per side.
config.num_patches = (config.image_size // config.patch_size) ** 2
config.num_channels = 3
# Number of values in one flattened patch.
config.patching_elements = config.num_channels * config.patch_size ** 2
config.final_resnet_output_dim = 2048
config.num_classes = 2
config.batch_size = 8
config.device = device

In [None]:
# Scratch check of num_heads * 10.
# NOTE(review): presumably used to pick a hidden_dim that is a multiple of num_heads — confirm or remove.
config.num_heads*10

# Download Dataset

In [None]:
# Root directory of the chest X-ray dataset, relative to the notebook's working directory.
data_dir = Path('../input/pneumonia-chest-x-ray-dataset')

In [None]:
# List the entries directly under the dataset root.
list(data_dir.iterdir())

In [None]:
# One sub-directory per data split under the dataset root.
train_dir = data_dir/'train'
test_dir = data_dir/'test'
val_dir = data_dir/'val'

In [None]:
# Inspect the training split (dir_walk is a funcyou helper; presumably it reports per-folder contents — confirm).
dir_walk(train_dir)

> There is a strong class imbalance — pneumonia: 3875, normal: 1341. I will handle it when building the dataloader.

In [None]:
# Inspect the test split with the same funcyou helper.
dir_walk(test_dir)

In [None]:
# Inspect the validation split with the same funcyou helper.
dir_walk(val_dir)

In [None]:
# Data augmentation applied to the training X-ray images only.
train_transform = transforms.Compose([
    transforms.RandomRotation(degrees=(-20, 20)),  # random rotation between -20 and 20 degrees
    transforms.RandomHorizontalFlip(),             # random horizontal flip
    transforms.RandomVerticalFlip(),               # random vertical flip
    # Crop a random region covering 70-100% of the image area and resize it to
    # image_size x image_size. The upper bound is capped at 1.0 because `scale`
    # is a fraction of the source image's area: a crop larger than the image
    # cannot be placed, so draws above 1.0 were only ever rejected.
    transforms.RandomResizedCrop(config.image_size, scale=(0.7, 1.0)),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),  # small colour perturbations
    transforms.ToTensor(),  # PIL image -> float tensor
])


In [None]:
# Build the training dataset from the class sub-folders under train_dir.
train_dataset = ImageFolder(root=train_dir, transform=train_transform)

# The classes are imbalanced, so batches are drawn with a weighted sampler
# (with replacement) instead of plain shuffling.
# NOTE(review): this assumes calculate_class_weights_from_directory returns one
# weight per image, in the same order ImageFolder indexes the files — confirm against funcyou.
class_dist, class_weights, bincount = calculate_class_weights_from_directory(train_dir)
train_sampler = torch.utils.data.WeightedRandomSampler(class_weights, len(class_weights), replacement=True)
print('Train')
print('given weights:' ,np.unique(class_weights))
print('counts: ', Counter(class_weights))
print('Train bincount: ', bincount)
# The sampler decides which indices go into each batch, so no shuffle flag is passed.

train_dataloader = DataLoader(train_dataset, batch_size=config.batch_size, sampler = train_sampler, num_workers=2)


As you can see, higher weights are given to samples from the class with the lower count.

In [None]:
# Preprocessing for evaluation: resize to the model's input size, no augmentation.
test_transform = transforms.Compose([
    transforms.Resize((config.image_size,config.image_size)),
    transforms.ToTensor()
   ])


In [None]:
# Test data: loaded in dataset order, without any re-balancing.
test_dataset = ImageFolder(root=test_dir, transform=test_transform)
test_dataloader  = DataLoader(test_dataset,  batch_size=config.batch_size, shuffle=False, num_workers=2)

# Validation data. Note the batch size is hard-coded to 4 here rather than config.batch_size.
val_dataset = ImageFolder(root=val_dir, transform=test_transform)
val_dataloader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=2)


In [None]:

# Build a second, class-balanced view of the test set by sampling it with
# per-sample weights (with replacement).
# NOTE: this reassigns class_dist, class_weights and bincount, which the training
# cell also defines; after this cell they describe the test split.
# NOTE(review): sampling is random and unseeded here, so the metrics on this
# loader presumably vary between runs — confirm whether a fixed generator is wanted.
class_dist, class_weights, bincount = calculate_class_weights_from_directory(test_dir)
test_sampler = torch.utils.data.WeightedRandomSampler(class_weights, len(class_weights), replacement=True)
print('Test')
print('given weights:' ,np.unique(class_weights))
print('counts: ', Counter(class_weights))
print('Test bincount: ', bincount)
# Same test_dataset as above, but indices come from the weighted sampler.

balanced_test_dataloader = DataLoader(test_dataset, batch_size=config.batch_size, sampler = test_sampler, num_workers=2)


In [None]:
# Sanity check: print the image and label tensor shapes of the first two validation batches.
for batch_number, (images, labels) in enumerate(val_dataloader, start=1):
    print(images.shape)
    print(labels.shape)
    if batch_number == 2:
        break


## Class and idx transform

In [None]:
# Class names discovered by ImageFolder for the training split.
train_dataset.classes

In [None]:
# Lookup tables between integer labels and class names, in ImageFolder's class order.
class_names = train_dataset.classes
idx2class = dict(enumerate(class_names))
class2idx = {label: index for index, label in idx2class.items()}

# Vision Transformer(VIT)

![image](https://viso.ai/wp-content/uploads/2021/09/vision-transformer-vit.png)

In [None]:

class MLP(nn.Module):
    """Position-wise feed-forward block: Linear -> GELU -> Dropout -> Linear.

    Maps hidden_dim -> mlp_dim -> hidden_dim, so the output shape equals the
    input shape.
    """

    def __init__(self, config):
        super().__init__()
        model_width, inner_width = config.hidden_dim, config.mlp_dim
        self.dense = nn.Linear(model_width, inner_width)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.out = nn.Linear(inner_width, model_width)

    def forward(self, x):
        """Apply the expansion, the non-linearity, dropout, and the projection back."""
        hidden = self.dense(x)
        hidden = F.gelu(hidden)
        hidden = self.dropout(hidden)
        return self.out(hidden)

class Encoder(nn.Module):
    """Pre-norm transformer encoder layer.

    Applies LayerNorm -> multi-head self-attention and LayerNorm -> MLP, each
    wrapped in a residual connection.

    Input and output are batch-first: (batch, num_patches, hidden_dim).
    The attention map of the most recent forward pass is stored on
    `self.attention_weights` so callers can collect it after the call.
    """

    def __init__(self, config):
        super().__init__()

        # batch_first=True because this layer receives (batch, patches, dim).
        # With the default sequence-first layout the first axis is treated as
        # the sequence, so attention would mix different images of a batch
        # instead of the patches of one image.
        self.attention = nn.MultiheadAttention(
            config.hidden_dim, config.num_heads, batch_first=True
        )
        self.mlp = MLP(config)
        self.norm1 = nn.LayerNorm(config.hidden_dim)
        self.norm2 = nn.LayerNorm(config.hidden_dim)
        self.attention_weights = None

    def forward(self, x):
        """Return the encoded tensor, same shape as x."""
        n_x = self.norm1(x)
        attn_output, self.attention_weights = self.attention(n_x, n_x, n_x)
        x = x + attn_output
        return x + self.mlp(self.norm2(x))

class EncoderStack(nn.Module):
    """A sequence of `config.num_layers` Encoder layers applied one after another."""

    def __init__(self, config):
        super().__init__()
        blocks = []
        for _ in range(config.num_layers):
            blocks.append(Encoder(config))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Run x through every layer.

        Returns:
            (output, attention) where attention is a list holding each layer's
            `attention_weights`, in layer order.
        """
        per_layer_attention = []
        for block in self.layers:
            x = block(x)
            per_layer_attention += [block.attention_weights]
        return x, per_layer_attention


In [None]:

class VisionTransformer(nn.Module):
    """ViT-style binary classifier over non-overlapping image patches.

    Pipeline: patchify -> linear projection -> add learned positional
    embeddings -> encoder stack -> mean over patches -> dense -> 1-unit head
    -> sigmoid. No class token is used.

    `forward` returns `(probabilities, attention_weights)`, where probabilities
    has shape (batch, 1) and attention_weights is the list returned by the
    encoder stack.
    """

    def __init__(self, config):
        super().__init__()

        # Learned positional table with one row per patch position.
        self.positional_embedding_layer = nn.Embedding(config.num_patches, config.hidden_dim).to(config.device)
        self.positional_embedding_layer.weight.data.uniform_(0, 1)

        self.encoder = EncoderStack(config)  # Create multiple Encoder layers
        self.dense = nn.Linear(config.hidden_dim, config.hidden_dim)
        self.final_layer = nn.Linear(config.hidden_dim, 1)
        self.device = config.device
        self.patch_linear_proj = nn.Linear(config.patching_elements, config.hidden_dim, bias=False).to(config.device)
        self.patcher = Patcher((config.patch_size, config.patch_size))
        self.sigmoid = nn.Sigmoid()

    @property
    def positional_embeddings(self):
        """Current positional embeddings, shape (1, num_patches, hidden_dim).

        Looked up from the embedding layer on every access instead of being
        cached at construction. A cached tensor would keep the values from
        initialisation: it would ignore optimizer updates and weights restored
        by `load_state_dict`, and it would tie every training step to one
        shared autograd graph.
        """
        layer = self.positional_embedding_layer
        return self.build_positional_embedding(layer.num_embeddings, layer.embedding_dim)

    def create_patches(self, x):
        """Split x into patches and project each flattened patch to hidden_dim."""
        # Move the patches to wherever the projection weights live, so a CPU
        # input still works with a model that sits on the GPU.
        patches = self.patcher(x).to(self.patch_linear_proj.weight.device)
        return self.patch_linear_proj(patches)

    def build_positional_embedding(self, num_patches, embed_dim):
        """Embed positions 0..num_patches-1; returns shape (1, num_patches, hidden_dim).

        `embed_dim` is accepted for interface compatibility and is not used;
        the width comes from the embedding layer itself.
        """
        table_device = self.positional_embedding_layer.weight.device
        positions = torch.arange(0, num_patches, device=table_device).view(1, -1)
        return self.positional_embedding_layer(positions)

    def test(self, x):
        """Debug helper: trace tensor shapes and devices through the model without gradients.

        Note that this trace applies the head to every patch and then takes the
        max over patches, whereas `forward` mean-pools first and applies a
        sigmoid. The returned tensor is therefore not the model's prediction.
        """
        with torch.no_grad():  # Use torch.no_grad() for inference
            print('input: ', x.shape, x.device)
            a = self.create_patches(x)
            print('patches: ', a.shape, a.device)
            a = a + self.positional_embeddings  # (1, N, D) broadcasts over the batch
            print('pos emb: ', a.shape, a.device)
            a, b = self.encoder(a)
            print('encoder: ', a.shape, a.device)
            a = self.dense(a)
            print('dense: ', a.shape, a.device)
            a = self.final_layer(a)
            print('final: ', a.shape, a.device)
            a, _ = a.max(dim=1)
            print('max: ', a.shape, a.device)
            print('attention weights: ', len(b), '*', b[0].shape)
            return a, b

    def forward(self, x):
        """Classify a batch of images.

        Returns:
            (probabilities of shape (batch, 1), list of attention weights from the encoder).
        """
        batch_size = x.shape[0]
        patch_embeddings = self.create_patches(x)
        # The (1, N, D) positional tensor broadcasts over the batch dimension.
        patch_embeddings = patch_embeddings + self.positional_embeddings
        encoded_output, attention_weights_list = self.encoder(patch_embeddings)
        # Calculate the mean over the 'num_patches' dimension
        encoded_output = encoded_output.mean(dim=1)
        # Apply the final linear layers and the sigmoid
        encoded_output = self.final_layer(self.dense(encoded_output))
        encoded_output = self.sigmoid(encoded_output)

        # Reshape to (batch, 1) shape
        encoded_output = encoded_output.view(batch_size, 1)

        return encoded_output, attention_weights_list


class Patcher(nn.Module):
    """Cut images into non-overlapping patches and flatten each patch.

    Note: this definition replaces the `Patcher` name imported from funcyou at
    the top of the notebook.

    Args:
        patch_size: (patch_height, patch_width) pair.
    """

    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size

    def forward(self, images):
        """Return patches of shape (batch, num_patches, patch_height * patch_width * channels).

        A 3-D input is treated as a single image and given a batch axis.
        Patches are ordered row by row; inside a patch the values are laid
        out as (row, column, channel).
        """
        if images.dim() == 3:
            images = images[None]  # single image -> batch of one

        n_images, _n_channels, img_h, img_w = images.size()
        tile_h, tile_w = self.patch_size
        tiles_per_image = (img_h // tile_h) * (img_w // tile_w)

        # (batch, channels, grid_rows, grid_cols, tile_h, tile_w)
        tiles = images.unfold(2, tile_h, tile_h).unfold(3, tile_w, tile_w)
        # Channels last inside each tile: (batch, grid_rows, grid_cols, tile_h, tile_w, channels)
        tiles = tiles.permute(0, 2, 3, 4, 5, 1)
        return tiles.reshape(n_images, tiles_per_image, -1)

In [None]:
# Build the model and trace one dummy image through it to check shapes and devices.
vit = VisionTransformer(config)
vit = vit.to(device)
# Derive the dummy input from the config so this check follows any change to
# image_size or num_channels.
a = torch.rand((1, config.num_channels, config.image_size, config.image_size))
output, attention_weights = vit.test(a.to(device))

In [None]:
# Run the regular forward pass on the same dummy image. `a` is still a CPU
# tensor here; this works because create_patches moves the patches to the
# model's device.
output, attention_weights = vit(a)
print(output.shape, len(attention_weights))
print('outputs: ',output)

In [None]:
# Binary cross-entropy on probabilities (the model applies the sigmoid itself).
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(vit.parameters(), lr = 5e-4)

# Re-assign the device on the config (it is also set in the config cell) and display it.
config.device = device
device

In [None]:
# Resume from an earlier checkpoint when one exists; otherwise keep the freshly
# initialised weights.
if not model_path.exists():
    print('No saved model')
else:
    try:
        # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
        vit.load_state_dict(torch.load(model_path, map_location=device))
        print('weights loaded!')
    except RuntimeError as err:
        # load_state_dict raises RuntimeError on missing/unexpected keys or shape
        # mismatches, e.g. after the config changed. Report that separately
        # instead of claiming that no checkpoint exists.
        print(f'Checkpoint at {model_path} does not match the current model; using fresh weights.')
        print(err)

In [None]:
def train_model(
    model,
    dataloader,
    optimizer,
    loss_function,
    num_epochs=10,
    device="cpu",
    data_percent=1.0,
    steps_per_epoch=None,
    save_on_every_n_epochs=5,
    model_path=None,
):
    """Train a binary classifier whose forward pass returns `(probabilities, extra)`.

    Args:
        model: module called as `model(images)`; its first output holds one
            probability per sample.
        dataloader: iterable of `(images, labels)` batches that supports `len()`.
        optimizer: optimizer over the model's parameters.
        loss_function: called as `loss_function(probabilities, float_labels)`.
        num_epochs: number of passes over the dataloader.
        device: device the model and batches are moved to.
        data_percent: fraction of the dataloader's batches used per epoch;
            only applied when `steps_per_epoch` is None.
        steps_per_epoch: explicit number of batches per epoch; overrides
            `data_percent` when given.
        save_on_every_n_epochs: checkpoint interval in epochs.
        model_path: where to save the state dict; nothing is saved when None.

    Progress, per-epoch average loss and accuracy are printed; nothing is returned.
    """
    model.to(device)
    print(f"{model.__class__.__name__} Running on: {device}")

    # Batches per epoch: an explicit step count wins, otherwise the requested
    # fraction of the dataloader.
    if steps_per_epoch is not None:
        max_steps = steps_per_epoch
    else:
        max_steps = int(data_percent * len(dataloader))

    for epoch in range(num_epochs):
        # Make sure dropout is active even if the model was left in eval mode.
        model.train()

        total_loss = 0.0
        total_correct_predictions = 0
        total_samples = 0
        batches_seen = 0

        epoch_progress = tqdm(
            dataloader, desc=f"Epoch [{epoch + 1:2}/{num_epochs:2}]"
        )

        last_update_time = time.time() - 1.0  # Initialize to ensure the first update

        for j, (image, label) in enumerate(epoch_progress):
            image = image.to(device)
            label = label.to(device)

            optimizer.zero_grad()

            outputs, _ = model(image)
            # reshape(-1) instead of squeeze(): a batch of one stays 1-D, so it
            # still matches the label tensor's shape.
            outputs = outputs.reshape(-1)
            predictions = torch.round(outputs)

            loss = loss_function(outputs, label.to(outputs.dtype))
            # retain_graph=True is kept for models that reuse graph-attached
            # tensors between forward passes; it is harmless otherwise.
            loss.backward(retain_graph=True)
            optimizer.step()

            total_loss += loss.item()
            total_correct_predictions += (predictions.to(torch.int32) == label.to(torch.int32)).sum().item()
            total_samples += label.size()[0]
            batches_seen += 1

            formatted_loss = f"{loss.item():.8f}"
            accuracy = (total_correct_predictions / total_samples) * 100
            formatted_accuracy = f"{accuracy:.2f}%"

            # Throttle progress-bar updates to the bar's own refresh interval.
            current_time = time.time()
            if current_time - last_update_time > epoch_progress.mininterval:
                epoch_progress.set_postfix(
                    {"Loss": formatted_loss, "Accuracy": formatted_accuracy}
                )
                last_update_time = current_time

            if j + 1 >= max_steps:
                break

        # Average over the batches that actually ran, so the numbers stay
        # correct when an epoch ends early or the dataloader is shorter than
        # the requested step count.
        average_loss = total_loss / max(batches_seen, 1)
        average_accuracy = (total_correct_predictions / max(total_samples, 1)) * 100

        print(
            f"\nEpoch [{epoch + 1:2}/{num_epochs:2}] - Average Loss: {average_loss:.8f} - Average Accuracy: {average_accuracy:.2f}%"
        )
        print()

        if (epoch+1) % save_on_every_n_epochs == 0 and model_path is not None:
            torch.save(model.state_dict(), model_path)


In [None]:
# Train for 40 epochs on the weighted training loader, saving a checkpoint every 4 epochs.
train_model(vit, train_dataloader, optimizer, criterion, num_epochs=40, device = device, save_on_every_n_epochs=4, model_path=model_path)

In [None]:
# Save the final weights to the same checkpoint path.
torch.save(vit.state_dict(), model_path)

In [None]:
# List the saved checkpoints with human-readable file sizes.
!ls -la models -h


In [None]:
def validate_model(
    model,
    dataloader,
    loss_function,
    device="cpu",
):
    """Report average loss and accuracy of a binary classifier on a dataloader.

    Args:
        model: module called as `model(images)`; its first output holds one
            probability per sample.
        dataloader: iterable of `(images, labels)` batches.
        loss_function: called as `loss_function(probabilities, float_labels)`.
        device: device the model and batches are moved to.

    The model is switched to eval mode for the duration of the call, so dropout
    does not add noise to the metrics, and its previous mode is restored
    afterwards. Results are printed; nothing is returned.
    """
    model.to(device)
    print(f"Validating {model.__class__.__name__} on: {device}")

    total_loss = 0.0
    correct_predictions = 0
    total_samples = 0
    batches_seen = 0

    was_training = model.training
    model.eval()
    try:
        with torch.inference_mode():
            validation_progress = tqdm(
                dataloader, desc="Validation"
            )

            for image, label in validation_progress:
                image = image.to(device)
                label = label.to(device)

                outputs, _ = model(image)
                # reshape(-1) instead of squeeze(): a batch of one stays 1-D,
                # so it still matches the label tensor's shape.
                outputs = outputs.reshape(-1)
                predictions = torch.round(outputs)

                loss = loss_function(outputs, label.to(outputs.dtype))

                total_loss += loss.item()
                correct_predictions += (predictions.to(torch.int32) == label.to(torch.int32)).sum().item()
                total_samples += label.size()[0]
                batches_seen += 1

                formatted_loss = f"{loss.item():.8f}"
                accuracy = (correct_predictions / total_samples) * 100
                formatted_accuracy = f"{accuracy:.2f}%"

                validation_progress.set_postfix(
                    {"Loss": formatted_loss, "Accuracy": formatted_accuracy}
                )
    finally:
        # Leave the model in the mode the caller had it in.
        model.train(was_training)

    # max(..., 1) keeps an empty dataloader from raising ZeroDivisionError.
    average_loss = total_loss / max(batches_seen, 1)
    accuracy = (correct_predictions / max(total_samples, 1)) * 100

    print(f"Validation - Average Loss: {average_loss:.8f} - Accuracy: {accuracy:.2f}%")
    print()



In [None]:
# Report loss and accuracy on the validation split.
validate_model(vit, val_dataloader, criterion, device = config.device)

In [None]:
def evaluate_model(model, dataloader, device="cpu"):
    """Collect labels and rounded predictions and print binary classification metrics.

    Args:
        model: module called as `model(images)`; its first output holds one
            probability per sample.
        dataloader: iterable of `(images, labels)` batches.
        device: device the model and images are moved to.

    Returns:
        (y_true, y_pred): two lists of Python ints, one entry per sample.
        Integer predictions can be passed straight to `np.bincount` and
        `confusion_matrix`.

    The model is switched to eval mode for the duration of the call, so dropout
    does not add noise to the metrics, and its previous mode is restored
    afterwards.
    """
    model.to(device)

    y_true = []
    y_pred = []

    was_training = model.training
    model.eval()
    try:
        with torch.inference_mode():  # Disable gradient computation during evaluation
            for images, labels in tqdm(dataloader):
                images = images.to(device)

                outputs, _ = model(images)
                # reshape(-1) instead of squeeze(): a batch of one stays 1-D
                # and can still be extended into the result list.
                predictions = torch.round(outputs.reshape(-1)).to(torch.int64)

                y_true.extend(labels.tolist())
                y_pred.extend(predictions.cpu().tolist())
    finally:
        # Leave the model in the mode the caller had it in.
        model.train(was_training)

    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average='binary')
    recall = recall_score(y_true, y_pred, average='binary')
    f1 = f1_score(y_true, y_pred, average='binary')

    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")

    return y_true, y_pred


In [None]:
def make_cm(ytrue, ypred, title=None):
    """Draw a confusion-matrix heatmap beside a text panel of summary metrics.

    Args:
        ytrue: ground-truth labels.
        ypred: predicted labels.
        title: heading for the heatmap; "Confusion Matrix" when omitted.
    """
    matrix = confusion_matrix(ytrue, ypred)

    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    heatmap_ax, text_ax = axes[0], axes[1]

    # Left panel: the confusion matrix itself.
    sns.heatmap(matrix, annot=True, fmt="d", cmap="Blues", cbar=False, ax=heatmap_ax)
    heatmap_ax.set_xlabel("Predicted Labels")
    heatmap_ax.set_ylabel("True Labels")
    heatmap_ax.set_title(title if title is not None else "Confusion Matrix")

    # Right panel: the four headline metrics, stacked top to bottom.
    accuracy = accuracy_score(ytrue, ypred)
    precision = precision_score(ytrue, ypred, average='binary')
    recall = recall_score(ytrue, ypred, average='binary')
    f1 = f1_score(ytrue, ypred, average='binary')

    x_pos, y_pos = 0.2, 0.6  # anchor of the first text line, in axes data coordinates
    metric_lines = (
        (0.0, f"Accuracy: {accuracy:.4f}"),
        (0.1, f"Precision: {precision:.4f}"),
        (0.2, f"Recall: {recall:.4f}"),
        (0.3, f"F1 Score: {f1:.4f}"),
    )
    for drop, line in metric_lines:
        text_ax.text(x_pos, y_pos - drop, line, fontsize=12)

    # The text panel is not a chart, so hide both of its axes.
    text_ax.get_xaxis().set_visible(False)
    text_ax.get_yaxis().set_visible(False)

    plt.tight_layout()
    plt.show()



##  Testing on Given(unbalanced) test_data


In [None]:
# Evaluate on the test split as given (natural, unbalanced class distribution).
ytrue, ypred = evaluate_model(vit, test_dataloader, device = config.device)

In [None]:
# Per-class counts of the true labels and of the predictions.
# Both are cast to int first because np.bincount rejects float input, and
# rounded model outputs may arrive as floats.
np.bincount(np.asarray(ytrue, dtype=int)), np.bincount(np.asarray(ypred, dtype=int))

In [None]:
# Confusion matrix and metrics for the unbalanced test split.
make_cm(ytrue, ypred,title='test data unbalanced')

In [None]:
# Explanation

*  **Accuracy**  
*  **Precision**    : the higher the value, the fewer false positives, i.e. the model is less likely to classify a healthy person as a pneumonia patient.
*  **Recall**       : the higher the value, the fewer false negatives, i.e. the model is less likely to classify a pneumonia patient as healthy.
*  **F1**           : the higher the better; it balances precision and recall.

We should aim for higher recall, because classifying a healthy person as sick is less harmful than classifying a sick person as healthy.


## Testing on balanced test_data

In [None]:
# Evaluate on the class-balanced resampling of the test split.
ytrue, ypred = evaluate_model(vit, balanced_test_dataloader, device = config.device)

In [None]:
# Per-class counts of the true labels and of the predictions on the balanced sample.
# Both are cast to int first because np.bincount rejects float input, and
# rounded model outputs may arrive as floats.
np.bincount(np.asarray(ytrue, dtype=int)), np.bincount(np.asarray(ypred, dtype=int))

In [None]:
# Confusion matrix and metrics for the balanced test sample.
make_cm(ytrue, ypred,title='test data balanced')

# The model's performance can be improved with more unbiased data.

In conclusion, the model's performance and capabilities are influenced by various factors, including its architecture, training data, and the problem it was designed to solve. 