# RetinAI_ViT

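Diabetic retinopathy classifier built from a pretrained BEiT backbone (`microsoft/beit-base-patch16-224`), an extra multi-head attention layer, and a custom classification head.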

## Configuration

In [6]:
import numpy as np
import torch

In [7]:
# Run on CUDA when PyTorch can see a GPU; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"✅ Using device: {DEVICE}")

✅ Using device: cuda


In [8]:
# Seed every RNG the notebook may touch so Restart & Run All is reproducible.
import random

SEED = 42
# Python's stdlib RNG: nothing visible here calls it directly, but seeding it is
# free and covers any library code that does.
random.seed(SEED)
np.random.seed(SEED)      # legacy NumPy global RNG
torch.manual_seed(SEED)   # PyTorch RNG (documented to seed all devices)
if torch.cuda.is_available():
    # Redundant with torch.manual_seed, but kept explicit for readers.
    torch.cuda.manual_seed_all(SEED)

### Model & Training Hyperparameters

In [9]:
# Hugging Face Hub checkpoint used as the backbone.
MODEL_NAME = "microsoft/beit-base-patch16-224"
# Size of the classifier output; expected to match the class folders in the dataset.
NUM_CLASSES = 5
NUM_EPOCHS = 15
BATCH_SIZE = 32         # lower this if the GPU runs out of memory
LEARNING_RATE = 1e-5    # deliberately small because a pretrained backbone is being fine-tuned

## Data Preparation

In [10]:
import os

# Source: https://www.kaggle.com/datasets/amanneo/diabetic-retinopathy-resized-arranged
# Download and extract it, then point the DR_DATA_DIR environment variable at the
# extracted folder. The original author's local path remains the fallback so
# existing setups keep working unchanged.
print("Dataset is from https://www.kaggle.com/datasets/amanneo/diabetic-retinopathy-resized-arranged, Download and Extract")
DATA_DIR = os.environ.get(
    "DR_DATA_DIR",
    "/home/spidey03/Downloads/diabetic-retinopathy-resized-arranged",
)

Dataset is from https://www.kaggle.com/datasets/amanneo/diabetic-retinopathy-resized-arranged, Download and Extract


### Load the image processor for the BEiT model

In [11]:
from transformers import AutoImageProcessor

# Take the normalization statistics and input resolution from the checkpoint's
# own image processor so the torchvision pipelines feed the backbone matching inputs.
processor = AutoImageProcessor.from_pretrained(MODEL_NAME, use_fast=True)
image_mean, image_std = processor.image_mean, processor.image_std
# Only the height is read; it is reused for both sides of the resize (square input assumed).
image_size = processor.size["height"]

### Data Augmentations for Training Set

In [12]:
import torchvision.transforms as T

train_transforms = T.Compose(
    [
        # --- random augmentation (training split only) ---
        T.RandomHorizontalFlip(p=0.5),
        T.RandomRotation(degrees=15),
        T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
        # --- preprocessing shared with evaluation: size, tensor conversion, normalization ---
        T.Resize(size=(image_size, image_size)),
        T.ToTensor(),
        T.Normalize(image_mean, image_std),
    ]
)

### Data Transformations for Validation & Test sets

In [13]:
# Deterministic preprocessing only (no augmentation) for validation and test images.
eval_transforms = T.Compose(
    [
        T.Resize(size=(image_size, image_size)),
        T.ToTensor(),
        T.Normalize(image_mean, image_std),
    ]
)

### Load Dataset

In [14]:
from torchvision.datasets import ImageFolder

print(f"Loading data from {DATA_DIR}")

# ImageFolder derives one class per sub-folder of DATA_DIR; no transform is set
# here because each split gets its own transform later.
full_dataset = ImageFolder(DATA_DIR)
class_names = full_dataset.classes
print(f"Found {len(full_dataset)} images belonging to {len(class_names)} classes.")

# Fail fast if the folder layout disagrees with the classifier size. Too many
# folders would otherwise only surface later as a label index error in the loss;
# too few would silently leave classifier outputs unused.
if len(class_names) != NUM_CLASSES:
    raise ValueError(
        f"Expected {NUM_CLASSES} class folders in {DATA_DIR}, "
        f"found {len(class_names)}: {class_names}"
    )

Loading data from /home/spidey03/Downloads/diabetic-retinopathy-resized-arranged
Found 35126 images belonging to 5 classes.


### Data Split

In [15]:
from torch.utils.data import random_split

# 70% train / 15% validation / remainder test.
TRAIN_SPLIT = 0.7
VALID_SPLIT = 0.15

train_size, valid_size = (
    int(fraction * len(full_dataset)) for fraction in (TRAIN_SPLIT, VALID_SPLIT)
)
# The test split absorbs the rounding remainder so the three sizes sum exactly.
test_size = len(full_dataset) - (train_size + valid_size)

# A dedicated, seeded generator makes the split identical on every run.
train_subset, valid_subset, test_subset = random_split(
    full_dataset,
    lengths=[train_size, valid_size, test_size],
    generator=torch.Generator().manual_seed(SEED),
)

print(f"Training set size: {len(train_subset)}")
print(f"Validation set size: {len(valid_subset)}")
print(f"Test set size: {len(test_subset)}")

Training set size: 24588
Validation set size: 5268
Test set size: 5270


#### Custom Dataset class to apply the correct per-split transform

In [16]:
from torch.utils.data import Dataset

class DRDataset(Dataset):
    """Wrap an indexable of ``(x, y)`` pairs and transform ``x`` on access.

    The subsets produced by ``random_split`` all share one parent dataset, so a
    transform cannot be configured per split on the parent. This wrapper applies
    the split's transform lazily instead.

    Args:
        subset: any indexable with ``len()`` whose items are ``(x, y)`` pairs.
        transform: callable applied to ``x``; skipped when it is falsy (e.g. ``None``).
    """

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, index):
        sample, target = self.subset[index]
        return (self.transform(sample) if self.transform else sample), target

# Training gets random augmentation; validation and test share the deterministic pipeline.
train_dataset = DRDataset(subset=train_subset, transform=train_transforms)
valid_dataset = DRDataset(subset=valid_subset, transform=eval_transforms)
test_dataset = DRDataset(subset=test_subset, transform=eval_transforms)

#### Handle Class Imbalance with WeightedRandomSampler

In [17]:
from torch.utils.data import WeightedRandomSampler

print("\n⚖️ Addressing class imbalance...")

# Read labels straight from the parent ImageFolder's `targets`, indexed by the
# subset's indices. Iterating `train_subset` itself would open and decode every
# training image only to throw the pixels away.
train_labels = [train_subset.dataset.targets[i] for i in train_subset.indices]

# One count per class; `minlength` keeps a slot for classes missing from this split.
class_counts = np.bincount(train_labels, minlength=NUM_CLASSES)

# Inverse-frequency class weights. Classes with zero samples get weight 0 rather
# than triggering a divide-by-zero warning (they can never be drawn anyway).
class_weights = np.zeros(len(class_counts), dtype=np.float64)
np.divide(1.0, class_counts, out=class_weights, where=class_counts > 0)

# Each sample's weight is its class's weight, so every class has the same
# expected share of the draws.
sample_weights = class_weights[train_labels]

# Draw len(train_subset) samples per epoch, with replacement.
sampler = WeightedRandomSampler(
    weights=torch.from_numpy(sample_weights).double(),
    num_samples=len(train_subset),
    replacement=True,
)


⚖️ Addressing class imbalance...


#### Create Dataloaders

In [18]:
from torch.utils.data import DataLoader

# The weighted sampler takes the place of shuffling for training (DataLoader does
# not allow both); evaluation loaders keep a fixed order.
train_loader = DataLoader(train_dataset, sampler=sampler, batch_size=BATCH_SIZE)
valid_loader, test_loader = (
    DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False)
    for ds in (valid_dataset, test_dataset)
)
print("✅ DataLoaders created.")

✅ DataLoaders created.


### Model Architecture

In [19]:
import torch.nn as nn
from transformers import BeitModel

class BEiTForDRClassification(nn.Module):
    """Diabetic-retinopathy classifier: BEiT backbone + self-attention + MLP head.

    The pretrained backbone is frozen except for its last ``num_trainable_layers``
    encoder layers. Its output token embeddings pass through one extra
    ``nn.MultiheadAttention`` layer. The attended token at position 0 ([CLS]) is
    then classified by a LayerNorm -> Linear -> GELU -> Dropout -> Linear head.

    Args:
        num_classes: Number of output logits.
        model_name: Checkpoint passed to ``BeitModel.from_pretrained``. ``None``
            (default) uses the notebook-level ``MODEL_NAME``.
        num_trainable_layers: How many of the final encoder layers to fine-tune
            (default 2). ``0`` keeps the entire backbone frozen.
        num_heads: Heads in the extra attention layer; must divide the backbone's
            hidden size.
        attn_dropout: Dropout probability inside the extra attention layer.
        head_dropout: Dropout probability inside the classifier head.

    Raises:
        ValueError: if ``num_trainable_layers`` is negative.
    """

    def __init__(self, num_classes=5, model_name=None, num_trainable_layers=2,
                 num_heads=8, attn_dropout=0.1, head_dropout=0.3):
        super().__init__()
        if num_trainable_layers < 0:
            raise ValueError(
                f"num_trainable_layers must be >= 0, got {num_trainable_layers}"
            )
        if model_name is None:
            model_name = MODEL_NAME

        # Load the pre-trained BEiT backbone.
        self.beit = BeitModel.from_pretrained(model_name)

        # --- Freeze the backbone, then unfreeze only its last few encoder layers ---
        for param in self.beit.parameters():
            param.requires_grad = False
        # Explicit > 0 check: `layer[-0:]` would select *every* layer and silently
        # unfreeze the whole encoder.
        if num_trainable_layers > 0:
            for layer in self.beit.encoder.layer[-num_trainable_layers:]:
                for param in layer.parameters():
                    param.requires_grad = True

        hidden_size = self.beit.config.hidden_size

        # --- Attention over the backbone's token embeddings ---
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=num_heads,
            dropout=attn_dropout,
            batch_first=True,  # tensors are (batch, seq, hidden)
        )

        # --- Classifier Head ---
        self.classifier = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.GELU(),
            nn.Dropout(head_dropout),
            nn.Linear(hidden_size // 2, num_classes),
        )

    def forward(self, pixel_values):
        """Return class logits of shape (batch_size, num_classes).

        Args:
            pixel_values: image batch for the BEiT backbone — assumed
                (batch, 3, H, W) after normalization; TODO confirm for other pipelines.
        """
        # last_hidden_state: (batch_size, num_patches + 1, hidden_size); position 0 is [CLS].
        hidden_states = self.beit(pixel_values=pixel_values).last_hidden_state

        # Self-attention: the same sequence is used as query, key and value.
        # need_weights=False skips building the attention map, which is discarded.
        attn_output, _ = self.attention(
            hidden_states, hidden_states, hidden_states, need_weights=False
        )

        # Classify from the attended [CLS] token.
        cls_token_output = attn_output[:, 0]
        return self.classifier(cls_token_output)

print("\n🧠 Initializing model architecture...")
model = BEiTForDRClassification(num_classes=NUM_CLASSES)
model = model.to(DEVICE)  # nn.Module.to moves parameters in place and returns the same module
print("✅ Model initialized and moved to device.")


🧠 Initializing model architecture...
✅ Model initialized and moved to device.


### Training & Fine-Tuning

In [20]:
import torch.optim as optim

# Unweighted cross-entropy: class balance is handled by the weighted training sampler.
criterion = nn.CrossEntropyLoss()

# Give AdamW only the parameters left trainable (requires_grad=True).
optimizer = optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=LEARNING_RATE,
)

# Multiply the LR by 0.2 once the monitored value (presumably validation loss)
# has not decreased for more than `patience`=2 consecutive scheduler steps.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.2, patience=2
)

In [21]:
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix
from tqdm.auto import tqdm

def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimization pass over ``dataloader``.

    Args:
        model: module to train; switched to training mode.
        dataloader: yields ``(inputs, labels)`` batches.
        criterion: loss function; assumed to average over the batch (the
            ``nn.CrossEntropyLoss`` default) — it is re-weighted by batch size below.
        optimizer: optimizer stepped once per batch.
        device: device each batch is moved to.

    Returns:
        ``(epoch_loss, epoch_acc)``: the sample-weighted mean loss over the samples
        actually drawn this epoch, and accuracy of the training-mode predictions.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.train()
    running_loss = 0.0
    num_samples = 0
    all_preds, all_labels = [], []

    progress_bar = tqdm(dataloader, desc="Training")
    for inputs, labels in progress_bar:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = inputs.size(0)
        batch_loss = loss.item()  # one host sync per batch, reused for the progress bar
        running_loss += batch_loss * batch_size
        num_samples += batch_size

        preds = outputs.argmax(dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

        progress_bar.set_postfix(loss=batch_loss)

    if num_samples == 0:
        raise ValueError("train_one_epoch received an empty dataloader")

    # Normalize by the samples actually drawn, not len(dataloader.dataset): a
    # sampler (e.g. WeightedRandomSampler) may yield a different number of samples.
    epoch_loss = running_loss / num_samples
    epoch_acc = accuracy_score(all_labels, all_preds)
    return epoch_loss, epoch_acc

In [22]:
def evaluate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without computing gradients.

    Args:
        model: module to evaluate; switched to eval mode.
        dataloader: yields ``(inputs, labels)`` batches.
        criterion: loss function; assumed to average over the batch (the
            ``nn.CrossEntropyLoss`` default) — it is re-weighted by batch size below.
        device: device each batch is moved to.

    Returns:
        ``(epoch_loss, metrics, all_labels, all_preds)``:
        - ``epoch_loss`` is the sample-weighted mean loss.
        - ``metrics`` holds ``accuracy`` plus weighted-average ``f1``, ``precision``
          and ``recall``.
        - The two lists hold true and predicted class indices in the order the
          dataloader yields them.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.eval()
    running_loss = 0.0
    num_samples = 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        progress_bar = tqdm(dataloader, desc="Evaluating")
        for inputs, labels in progress_bar:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            batch_size = inputs.size(0)
            batch_loss = loss.item()  # one host sync per batch, reused for the progress bar
            running_loss += batch_loss * batch_size
            num_samples += batch_size

            preds = outputs.argmax(dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            progress_bar.set_postfix(loss=batch_loss)

    if num_samples == 0:
        raise ValueError("evaluate received an empty dataloader")

    # Normalize by the samples actually seen rather than len(dataloader.dataset).
    epoch_loss = running_loss / num_samples

    # zero_division=0: a class the model never predicts has undefined precision.
    # sklearn's default ("warn") already scores it as 0 but emits an
    # UndefinedMetricWarning every epoch; passing 0 keeps the same numbers quietly.
    metric_kwargs = {"average": "weighted", "zero_division": 0}
    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, **metric_kwargs)
    precision = precision_score(all_labels, all_preds, **metric_kwargs)
    recall = recall_score(all_labels, all_preds, **metric_kwargs)

    metrics = {"accuracy": accuracy, "f1": f1, "precision": precision, "recall": recall}
    return epoch_loss, metrics, all_labels, all_preds