# RetinAI_ViT

Diabetic Retinopathy Classifier using BEiT-2, Attention, and a custom head.

## Configuration

In [6]:
import numpy as np
import torch

In [7]:
# Run on the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"✅ Using device: {DEVICE}")

✅ Using device: cuda


In [8]:
# Single seed shared by every random number generator configured in this notebook.
SEED = 42

# Seed NumPy's legacy global RNG and PyTorch's default generator.
np.random.seed(SEED)
torch.manual_seed(SEED)

if torch.cuda.is_available():
    # Also seed the generators of every CUDA device explicitly.
    torch.cuda.manual_seed_all(SEED)

### Model & Training Hyperparameters

In [9]:
# Pretrained checkpoint identifier (used below to load the matching image processor).
MODEL_NAME = "microsoft/beit-base-patch16-224"

# Number of target classes.
NUM_CLASSES = 5

# Number of training epochs.
NUM_EPOCHS = 15

# Adjust based on your GPU memory.
BATCH_SIZE = 32

# Lower learning rate for fine-tuning.
LEARNING_RATE = 1e-5

## Data Preparation

In [10]:
import os

print("Dataset is from https://www.kaggle.com/datasets/amanneo/diabetic-retinopathy-resized-arranged, Download and Extract")

# Root folder of the extracted dataset (one sub-folder per class).
# The location is machine-specific, so it can be overridden by setting the
# DR_DATA_DIR environment variable before starting the kernel; the original
# path is kept as the default so existing setups keep working unchanged.
DATA_DIR = os.environ.get(
    "DR_DATA_DIR",
    '/home/spidey03/Downloads/diabetic-retinopathy-resized-arranged',
)

Dataset is from https://www.kaggle.com/datasets/amanneo/diabetic-retinopathy-resized-arranged, Download and Extract


### Load Image Processor for the BEiT model

In [11]:
from transformers import AutoImageProcessor

# Load the preprocessing configuration published with the checkpoint so the
# torchvision pipelines below can reuse its normalisation statistics and input size.
processor = AutoImageProcessor.from_pretrained(MODEL_NAME, use_fast=True)

# Per-channel statistics passed to T.Normalize.
image_mean, image_std = processor.image_mean, processor.image_std

# Only the configured height is read; the transforms resize to a square of this side.
image_size = processor.size["height"]

### Data Augmentations for Training Set

In [12]:
import torchvision.transforms as T

train_transforms = T.Compose(
    [
        # --- random augmentation (training only) ---
        T.RandomHorizontalFlip(p=0.5),
        T.RandomRotation(15),
        T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
        # --- deterministic preprocessing shared with the evaluation pipeline ---
        T.Resize(size=(image_size, image_size)),
        T.ToTensor(),
        T.Normalize(image_mean, image_std),
    ]
)

### Data Transformations for Validation & Test sets

In [13]:
# Deterministic preprocessing only: no random augmentation for validation/test.
eval_transforms = T.Compose(
    [
        T.Resize(size=(image_size, image_size)),
        T.ToTensor(),
        T.Normalize(image_mean, image_std),
    ]
)

### Load Dataset

In [14]:
from torchvision.datasets import ImageFolder

print(f"Loading data from {DATA_DIR}")

# No transform is attached here: each split gets its own transform later,
# through the DRDataset wrapper.
full_dataset = ImageFolder(root=DATA_DIR)

# Class names as discovered by ImageFolder.
class_names = full_dataset.classes

print(f"Found {len(full_dataset)} images belonging to {len(class_names)} classes.")

Loading data from /home/spidey03/Downloads/diabetic-retinopathy-resized-arranged
Found 35126 images belonging to 5 classes.


### Data Split

In [15]:
from torch.utils.data import random_split

# Fractions for train and validation; the test split receives whatever remains.
TRAIN_SPLIT, VALID_SPLIT = 0.7, 0.15

# int() truncates, so the rounding remainder ends up in the test split.
train_size = int(len(full_dataset) * TRAIN_SPLIT)
valid_size = int(len(full_dataset) * VALID_SPLIT)
test_size = len(full_dataset) - (train_size + valid_size)

# A dedicated, seeded generator keeps the split identical across runs.
train_subset, valid_subset, test_subset = random_split(
    full_dataset,
    lengths=[train_size, valid_size, test_size],
    generator=torch.Generator().manual_seed(SEED),
)

print(f"Training set size: {len(train_subset)}")
print(f"Validation set size: {len(valid_subset)}")
print(f"Test set size: {len(test_subset)}")

Training set size: 24588
Validation set size: 5268
Test set size: 5270


#### Custom Dataset class to apply the correct transformation

In [16]:
from torch.utils.data import Dataset

class DRDataset(Dataset):
    """Wrap an indexable collection of ``(sample, label)`` pairs with an optional transform.

    This lets several splits of one underlying dataset each use a different
    preprocessing pipeline.

    Args:
        subset: Object supporting ``len()`` and integer indexing that yields
            ``(sample, label)`` pairs.
        transform: Optional callable applied to the sample only; the label is
            returned untouched. A falsy value disables transformation.
    """

    def __init__(self, subset, transform=None):
        self.transform = transform
        self.subset = subset

    def __len__(self):
        """Return the number of items in the wrapped collection."""
        return len(self.subset)

    def __getitem__(self, index):
        """Return ``(sample, label)`` at ``index``, transforming the sample if configured."""
        sample, label = self.subset[index]
        return (self.transform(sample) if self.transform else sample), label

# Training uses the augmenting pipeline; validation and test share the
# deterministic evaluation pipeline.
train_dataset = DRDataset(train_subset, train_transforms)
valid_dataset = DRDataset(valid_subset, eval_transforms)
test_dataset = DRDataset(test_subset, eval_transforms)

#### Handle Class Imbalance with WeightedRandomSampler

In [None]:
from torch.utils.data import WeightedRandomSampler

print("\n⚖️ Addressing class imbalance...")

# Read the labels from ImageFolder's in-memory `targets` list through the
# Subset's index map. Iterating `train_subset` directly would load and decode
# every training image from disk just to discard it and keep the label.
train_labels = [full_dataset.targets[i] for i in train_subset.indices]

# Inverse-frequency weighting: the rarer a class, the larger its weight.
class_counts = np.bincount(train_labels)
class_weights = 1. / class_counts

# One weight per training sample, looked up by that sample's class index.
sample_weights = class_weights[np.asarray(train_labels, dtype=np.int64)]

# Draw len(train_subset) indices per epoch, with replacement, in proportion
# to the per-sample weights.
sampler = WeightedRandomSampler(
    weights=torch.from_numpy(sample_weights).double(),
    num_samples=len(train_subset),
    replacement=True
)


⚖️ Addressing class imbalance...
