## A Vision Transformer for diagnosing AD from MRI scans — including an occlusion-sensitivity heatmap

In [18]:
# Mount Google Drive (Colab only) so that the dataset paths under
# /content/drive/MyDrive/... used below resolve. If the drive is already
# mounted, Colab only reports that fact (see the captured output).
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


### Check GPU availability

In [20]:
import torch

# Pick the compute device: the CUDA GPU when the runtime exposes one,
# otherwise the CPU. Every later cell moves the model and batches to `device`.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

Using device: cpu


### Import libraries

In [21]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import timm  # For Vision Transformer models

### Prepare the data

In [None]:
# Data directories on the mounted Google Drive. Each one must contain one
# subfolder per class; ImageFolder turns the subfolder names into labels.
train_dir = '/content/drive/MyDrive/AI_Projects/MRI/train'
test_dir = '/content/drive/MyDrive/AI_Projects/MRI/test'

# Seed torch's global RNG so that loader shuffling and the weight
# initialisation done in later cells are repeatable on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Resize to the 224x224 input the ViT expects, convert to a tensor, and
# normalise with the ImageNet mean/std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

train_dataset = datasets.ImageFolder(root=train_dir, transform=transform)
test_dataset = datasets.ImageFolder(root=test_dir, transform=transform)

# ImageFolder builds its class -> index mapping separately for each root.
# If the test folder is missing a class, or a folder is named differently,
# the same integer label would mean different classes in train and test, and
# the reported accuracy would be silently wrong. Fail loudly instead.
assert train_dataset.class_to_idx == test_dataset.class_to_idx, (
    f"Train/test class folders differ: "
    f"{train_dataset.class_to_idx} vs {test_dataset.class_to_idx}"
)
print("Classes:", train_dataset.classes)
print(f"Train images: {len(train_dataset)}, test images: {len(test_dataset)}")

# Data loaders: shuffle only the training data. The test order stays fixed,
# so evaluation and the heatmap example are reproducible.
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)


### Create the Vision Transformer model

In [None]:
# Create a Vision Transformer model; for example, 'vit_base_patch16_224'
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=2)
model = model.to(device)

### Define Loss Function and Optimizer

In [None]:
# Adam over all of the model's parameters, with a small learning rate
# suitable for fine-tuning pretrained weights.
LEARNING_RATE = 1e-4
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Cross-entropy on the raw logits returned by the model (no softmax layer needed).
criterion = nn.CrossEntropyLoss()

### Train our model

In [None]:
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over `loader` and return the mean per-sample loss.

    Args:
        model: network to train; it is switched to train mode here.
        loader: DataLoader yielding (images, labels) batches.
        criterion: loss called as criterion(logits, labels). It is assumed to
            return a batch mean (true for the default nn.CrossEntropyLoss), so
            each batch loss is re-weighted by its batch size. This keeps the
            result an exact per-sample average even when the last batch is short.
        optimizer: optimizer over the model's parameters.
        device: device that every batch is moved to.

    Returns:
        float: sum of (batch loss * batch size) divided by len(loader.dataset).
    """
    model.train()
    running_loss = 0.0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * images.size(0)
    return running_loss / len(loader.dataset)


# ViT-Base is far too slow to fine-tune on a CPU (an earlier CPU run of this
# cell had to be interrupted), so make the situation explicit up front.
if device.type == "cpu":
    print("Warning: training ViT-Base on the CPU is extremely slow; "
          "switch the Colab runtime to a GPU (Runtime > Change runtime type).")

num_epochs = 3

for epoch in range(num_epochs):
    epoch_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {epoch_loss:.4f}")


KeyboardInterrupt: 

### Evaluate the model

In [None]:
# Evaluate on the held-out test set. Report overall accuracy plus recall for
# each class: overall accuracy alone can hide a model that never predicts one
# of the classes.
# NOTE(review): the class balance of this dataset is not visible from the
# code; check the per-class counts printed below.
model.eval()
num_test_classes = len(test_dataset.classes)
# Per-class tallies, indexed by each test image's ImageFolder label.
correct_per_class = torch.zeros(num_test_classes, dtype=torch.long)
total_per_class = torch.zeros(num_test_classes, dtype=torch.long)

with torch.no_grad():
    for images, labels in test_loader:
        # The predicted class is the index of the largest logit. Predictions are
        # brought back to the CPU, where the loader's labels already live.
        preds = model(images.to(device)).argmax(dim=1).cpu()
        total_per_class += torch.bincount(labels, minlength=num_test_classes)
        correct_per_class += torch.bincount(labels[preds == labels], minlength=num_test_classes)

correct = int(correct_per_class.sum())
total = int(total_per_class.sum())
accuracy = correct / total
print("Test Accuracy: {:.2f}%".format(accuracy * 100))

# Recall per class = fraction of that class's test images predicted correctly.
for class_name, hits, count in zip(test_dataset.classes,
                                   correct_per_class.tolist(),
                                   total_per_class.tolist()):
    recall = hits / count if count else float("nan")
    print(f"  {class_name}: {recall:.2%} ({hits}/{count})")


### Occlusion-sensitivity heatmaps

In [None]:
!pip install captum
import matplotlib.pyplot as plt
import numpy as np
from captum.attr import Occlusion

# Put the model in evaluation mode
model.eval()

# Get one sample from the test set (e.g., the first image and its label)
images, labels = next(iter(test_loader))
# Select the first image and add a batch dimension: shape becomes (1, 3, 224, 224)
img = images[0].unsqueeze(0).to(device)
target_label = labels[0].item()

# Create an Occlusion object for our model
occlusion = Occlusion(model)

# Compute occlusion attributions. Here we occlude with a window of size 50x50 pixels.
# The input tensor shape is (1, 3, 224, 224), so we need sliding_window_shapes and strides as 4-tuples.
attributions = occlusion.attribute(
    img,
    strides=(1, 1, 8, 8),                   # don't stride on batch/channels, slide spatially
    sliding_window_shapes=(1, 3, 50, 50),     # occlude all channels over a 50x50 region
    target=target_label,
    baselines=0
)

# Remove the batch dimension and convert to a NumPy array: shape (3, 224, 224)
attr_np = attributions.squeeze(0).detach().cpu().numpy()

# To visualize as a heatmap, sum across channels to collapse to (224, 224)
heatmap = np.sum(attr_np, axis=0)

# Plot the heatmap
plt.figure(figsize=(6, 6))
plt.imshow(heatmap, cmap='hot', interpolation='nearest')
plt.title('Occlusion Sensitivity Heatmap')
plt.colorbar()
plt.show()