# VISION TRANSFORMER (ViT) TRAINING

In [1]:
# Switch the working directory to the project's `vit` folder.
cd /content/drive/MyDrive/pneumonia_data/vit

/content/drive/MyDrive/pneumonia_data/vit


In [2]:
# List the contents of the current working directory.
!ls

In [3]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, RandomSampler
import torch.optim as optim

import os

# Work from the project root so that the `utils` package below can be imported.
# This is the parent of the `vit` folder the notebook starts in. An absolute
# path is used instead of a relative `os.chdir("..")` so that re-running this
# cell is idempotent (a relative chdir climbs one directory higher on every
# re-run and then the imports below fail).
PROJECT_ROOT = "/content/drive/MyDrive/pneumonia_data"
os.chdir(PROJECT_ROOT)

from utils.vit_2 import ViTClassifier, PneumoniaDataset
from utils.vit_utils import VitUtilities

### Define The Model

In [4]:
# Prefer the GPU when one is visible to PyTorch; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the classifier and place it on the selected device.
model = ViTClassifier()
model = model.to(device)

print("ViT Classifier Initialized!")

Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).


ViT Classifier Initialized!


### Define Dataset

In [5]:
# Class-name -> integer label mapping handed to the dataset.
CLASS_NAMES = {"pneumonia": 0, "normal": 1}

# Locations of the training and validation splits.
train_dir = "/content/drive/MyDrive/pneumonia_data/data/clsif/train"
val_dir = "/content/drive/MyDrive/pneumonia_data/data/clsif/val"

# Build one dataset per split (training first, then validation).
train_dataset, val_dataset = (
    PneumoniaDataset(split_dir, CLASS_NAMES) for split_dir in (train_dir, val_dir)
)

print(f"Training Samples: {len(train_dataset)}, Validation Samples: {len(val_dataset)}")

Training Samples: 11888, Validation Samples: 2972


### Define DataLoaders

In [6]:
# Settings shared by the training and validation loaders.
common_loader_args = {"batch_size": 8, "num_workers": 4, "pin_memory": True}

# Training batches are drawn in random order; validation keeps dataset order.
train_loader = DataLoader(
    train_dataset,
    sampler=RandomSampler(train_dataset),
    **common_loader_args,
)
val_loader = DataLoader(
    val_dataset,
    shuffle=False,
    **common_loader_args,
)

print("DataLoaders Created Successfully")

DataLoaders Created Successfully




### Load The Model (optional)

In [7]:
# Optional: start from previously saved weights instead of the freshly
# initialised model. Uncomment the three lines below to load the state dict
# (mapped onto the CPU first), move the model to `device`, and put it back
# into training mode. The checkpoint path is relative to the current working
# directory.
# model.load_state_dict(torch.load("vit_pneumonia_classifier_best.pth", map_location=torch.device('cpu')))
# model.to(device)  
# model.train()

### Define Criterion and Optimizer

In [10]:
# Binary classification on raw logits (the sigmoid is folded into the loss).
criterion = nn.BCEWithLogitsLoss()

# Layer-wise learning rates: the closer a module is to the output, the larger
# its step size. Each entry is (module, learning rate), in optimizer group order.
# NOTE(review): only the modules listed here are handed to the optimizer; any
# other parameters the model may own are never updated — confirm this is intended.
backbone = model.vit
layerwise_lrs = [
    (backbone.patch_embed, 1e-6),   # Patch embedding: tiny (but non-zero) LR, so nearly frozen
    (backbone.blocks[:-6], 3e-5),   # Early transformer blocks
    (backbone.blocks[-6:-3], 1e-4), # Middle transformer blocks
    (backbone.blocks[-3:], 3e-4),   # Last three transformer blocks
    (backbone.head, 1e-3),          # Classification head
]

optimizer = optim.AdamW(
    [{'params': module.parameters(), 'lr': lr} for module, lr in layerwise_lrs]
)

### Train The Model

In [11]:
# Number of passes over the training set.
epochs = 10

# Where the final weights are written once training has finished.
checkpoint_path = "/content/drive/MyDrive/pneumonia_data/data/vit_pneumonia_classifier_test.pth"

# Run the training routine provided by the project's utilities.
VitUtilities.train_vit(model, train_loader, val_loader, optimizer, criterion, device, epochs)

# Persist only the weights (state dict), not the whole module object.
trained_weights = model.state_dict()
torch.save(trained_weights, checkpoint_path)

print("Model Saved Successfully!")

Epoch 1/10 [Train]: 100%|██████████| 1486/1486 [10:43<00:00,  2.31it/s, loss=0.0380]
Epoch 2/10 [Train]: 100%|██████████| 1486/1486 [10:38<00:00,  2.33it/s, loss=0.0045]
Epoch 3/10 [Train]: 100%|██████████| 1486/1486 [10:31<00:00,  2.35it/s, loss=0.3985]
Epoch 4/10 [Train]: 100%|██████████| 1486/1486 [10:41<00:00,  2.32it/s, loss=0.4600]
Epoch 5/10 [Train]: 100%|██████████| 1486/1486 [10:47<00:00,  2.30it/s, loss=0.0028]
Epoch 6/10 [Train]: 100%|██████████| 1486/1486 [10:22<00:00,  2.39it/s, loss=0.0129]
Epoch 7/10 [Train]: 100%|██████████| 1486/1486 [10:18<00:00,  2.40it/s, loss=0.0017]
Epoch 8/10 [Train]: 100%|██████████| 1486/1486 [10:32<00:00,  2.35it/s, loss=0.4860]
Epoch 9/10 [Train]: 100%|██████████| 1486/1486 [10:30<00:00,  2.36it/s, loss=0.0024]
Epoch 10/10 [Train]: 100%|██████████| 1486/1486 [10:19<00:00,  2.40it/s, loss=0.0859]


Training Complete!
Final Confusion Matrix:
[[1119   83]
 [ 123 1647]]
Model Saved Successfully!


In [None]:
import torch
from torchvision import transforms
from PIL import Image

def predict_vit(model, image_path, device):
    """
    Classify a single image file with the ViT classifier.

    The image is loaded as grayscale, passed through
    ``VitUtilities.apply_clahe``, converted to a normalized 3-channel
    224x224 tensor and fed to the model as a batch of one.

    Args:
        model (torch.nn.Module): Classifier producing a single logit per image.
        image_path (str): Path to the image file.
        device (torch.device): Device to run inference on (CPU or CUDA).

    Returns:
        tuple[str, float]: ``(label, probability)`` where ``probability`` is
        the sigmoid of the model output and ``label`` is ``"Pneumonia"`` when
        that probability is below 0.5, otherwise ``"Normal"``.

    Note:
        The model is switched to evaluation mode and left in that mode.
    """
    # Preprocessing chain. NOTE(review): intended to mirror the training-time
    # preprocessing — confirm against the dataset class.
    preprocess = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.repeat(3, 1, 1)),  # 1 channel -> 3 identical channels
        transforms.Normalize([0.5]*3, [0.5]*3),
        transforms.Resize(224),
        transforms.CenterCrop(224),
    ])

    # Read the file as single-channel grayscale and enhance its contrast.
    grayscale = Image.open(image_path).convert("L")
    enhanced = VitUtilities.apply_clahe(grayscale)

    # Add the batch dimension, then move the tensor to the target device.
    batch = preprocess(enhanced).unsqueeze(0)
    batch = batch.to(device)

    # Inference only: evaluation mode, no gradient tracking.
    model.eval()
    with torch.no_grad():
        logit = model(batch).squeeze()
        probability = torch.sigmoid(logit).item()  # logit -> probability

    # Probabilities below 0.5 are reported as pneumonia, the rest as normal.
    if probability < 0.5:
        return "Pneumonia", probability
    return "Normal", probability


In [28]:
# Sample image to classify; the (label, probability) tuple is shown as the cell output.
image_path = "/content/drive/MyDrive/pneumonia_data/data/clsif/rr3.png"

predict_vit(model, image_path, device)

('Normal', 0.9977536797523499)

In [58]:
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
from PIL import Image

def vit_gradcam(model, image_path, device, block_idx=-1):
    """
    Build an attention heat-map overlay for one image and classify it.

    Despite the name, no gradients are involved: a forward hook on the chosen
    transformer block recomputes that block's softmax attention from its
    ``qkv`` projection, averages it over heads and over query patches, and the
    resulting per-patch score is upsampled and blended onto the image.

    Args:
        model (torch.nn.Module): Classifier exposing a ViT as ``model.vit``
            whose blocks have an ``attn`` module with ``qkv``, ``num_heads``
            and ``scale`` attributes (timm-style attention).
        image_path (str): Path to the image file.
        device (torch.device): Device to run inference on.
        block_idx (int): Which transformer block to visualize (-1 for last).

    Returns:
        tuple: ``(overlayed, pred_class, prob)`` — a 224x224 RGB ``uint8``
        array with the heat-map blended in, the predicted label
        (``"Pneumonia"`` when ``prob`` is below 0.5, otherwise ``"Normal"``),
        and the sigmoid probability.

    Raises:
        ValueError: If the hook captured no attention weights, or the number
            of patch tokens is not a perfect square.

    Note:
        The model is switched to evaluation mode and left in that mode.
    """
    # OpenCV is needed for the colour map and blending below; it is imported
    # here because no other cell of the notebook imports it.
    import cv2

    # Preprocessing: 3-channel normalized tensor, short side resized to 224,
    # then center-cropped to 224x224.
    transform = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
        transforms.Normalize([0.5]*3, [0.5]*3),
        transforms.Resize(224),
        transforms.CenterCrop(224)
    ])

    # Load and process image
    orig_img = Image.open(image_path).convert("L")
    img_clahe = VitUtilities.apply_clahe(orig_img)
    img_tensor = transform(img_clahe).unsqueeze(0).to(device)

    # Filled by the hook with one [B, N, N] tensor per forward pass.
    attention_weights = []

    def hook_fn(module, input, output):
        # The attention module does not return its attention matrix, so it is
        # recomputed here from the module's own qkv projection.
        tokens = input[0]
        B, N, C = tokens.shape
        head_dim = C // module.num_heads
        qkv = module.qkv(tokens).reshape(B, N, 3, module.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        q, k = qkv[0], qkv[1]  # each [B, heads, N, head_dim]

        # Some timm versions normalize q/k before the dot product (identity
        # layers unless qk-norm is enabled); apply them when present so the
        # recomputed attention matches what the block actually used.
        if hasattr(module, "q_norm"):
            q = module.q_norm(q)
        if hasattr(module, "k_norm"):
            k = module.k_norm(k)

        # Scaled dot-product attention; each query row sums to 1 over the keys.
        attn = (q @ k.transpose(-2, -1)) * module.scale
        attn = attn.softmax(dim=-1)
        attention_weights.append(attn.detach().mean(dim=1))  # Average heads

    # Register the hook and make sure it is removed even if the forward pass
    # fails, so repeated calls never stack hooks on the model.
    target_block = model.vit.blocks[block_idx]
    handle = target_block.attn.register_forward_hook(hook_fn)
    model.eval()
    try:
        with torch.no_grad():
            output = model(img_tensor)
    finally:
        handle.remove()

    # Label 0 is pneumonia and label 1 is normal (see CLASS_NAMES), so a low
    # sigmoid output means pneumonia — the same rule predict_vit applies.
    prob = torch.sigmoid(output).item()
    pred_class = "Pneumonia" if prob < 0.5 else "Normal"

    if not attention_weights:
        raise ValueError("No attention weights captured. Check model architecture.")

    # Drop the non-patch tokens at the front of the sequence. timm ViTs report
    # their count as `num_prefix_tokens`; fall back to 1 (CLS token only).
    n_prefix = getattr(model.vit, "num_prefix_tokens", 1)
    attn = attention_weights[0][0]          # [N, N] for the single image
    attn = attn[n_prefix:, n_prefix:]       # patch-to-patch attention only
    cam = attn.mean(dim=0)                  # mean over queries -> attention received per patch

    # Reshape to patch grid
    grid_size = int(np.sqrt(cam.shape[0]))
    if grid_size * grid_size != cam.shape[0]:
        raise ValueError(
            f"Cannot arrange {cam.shape[0]} patch tokens into a square grid."
        )
    cam = cam.reshape(grid_size, grid_size)

    # Interpolate to image size
    cam = F.interpolate(cam.unsqueeze(0).unsqueeze(0),
                        size=(224, 224),
                        mode='bicubic').squeeze().cpu().numpy()

    # Bicubic upsampling can overshoot below zero: clip, then scale to [0, 1].
    cam = np.maximum(cam, 0)
    cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)

    # Background for the overlay: use the same geometry as the model input
    # (short side -> 224, then center crop) so the heat-map lines up with the
    # image even when the original is not square.
    display_img = transforms.functional.resize(orig_img.convert("RGB"), 224)
    display_img = transforms.functional.center_crop(display_img, 224)
    img_np = np.array(display_img)

    # OpenCV colour maps are BGR; convert before blending with the RGB image.
    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    overlayed = cv2.addWeighted(img_np, 0.6, heatmap, 0.4, 0)

    return overlayed, pred_class, prob

def plot_gradcam(gradcam_result):
    """
    Display the result tuple returned by ``vit_gradcam``.

    Args:
        gradcam_result (tuple): ``(overlayed, pred_class, prob)``.

    Draws a 1x2 figure; both panels show the same overlay image, and only the
    left one carries the prediction title.
    """
    overlayed, pred_class, prob = gradcam_result

    plt.figure(figsize=(10, 5))
    for panel in (1, 2):
        plt.subplot(1, 2, panel)
        plt.imshow(overlayed)
        if panel == 1:
            # Predicted label and probability go on the left panel only.
            plt.title(f"Pred: {pred_class} ({prob:.2f})")
        plt.axis('off')
    plt.show()

In [None]:
# Visualize the attention of transformer block index 3 for the sample image.
image_path = "/content/drive/MyDrive/pneumonia_data/data/clsif/rr3.png"

result = vit_gradcam(model, image_path, device, block_idx=3)
plot_gradcam(result)