In [1]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from timm import create_model  # pip install timm

In [3]:
import zipfile

zip_path = "/content/preprocessed.zip"
extract_path = "datasets/preprocessed"

os.makedirs(extract_path, exist_ok=True)

with zipfile.ZipFile(zip_path, "r") as zip_ref:
    zip_ref.extractall(extract_path)

print("Dataset unzipped into datasets/preprocessed/")

Dataset unzipped into datasets/preprocessed/


## Swin Transformer on Pre-processed Fundus Images

- Using `timm.create_model('swin_tiny_patch4_window7_224', pretrained=True)`  
- Output layer modified for **2 classes**: `No_DR` and `DR`.
- Model is moved to **GPU if available**, else CPU.


In [4]:
# Device
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)

Using device: cuda


In [5]:
# Dataset Definition
class DRDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []
        label_map = {"No_DR": 0, "DR": 1}

        for label in os.listdir(root_dir):
            if label not in label_map:
                continue
            label_dir = os.path.join(root_dir, label)
            for img_file in os.listdir(label_dir):
                self.images.append(os.path.join(label_dir, img_file))
                self.labels.append(label_map[label])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        label = self.labels[idx]
        image = Image.open(img_path).convert("RGB")

        if self.transform:
            image = self.transform(image)
        return image, label


In [6]:
# Transforms
IMG_SIZE = 224  # Swin Transformer expects 224x224
train_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

val_test_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

In [7]:
# Datasets & DataLoaders

train_dir = "/content/datasets/preprocessed/preprocessed/train"
val_dir   = "/content/datasets/preprocessed/preprocessed/val"
test_dir  = "/content/datasets/preprocessed/preprocessed/test"

train_dataset = DRDataset(train_dir, transform=train_transform)
val_dataset   = DRDataset(val_dir, transform=val_test_transform)
test_dataset  = DRDataset(test_dir, transform=val_test_transform)

BATCH_SIZE = 16
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
val_loader   = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
test_loader  = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

In [11]:
# Model Definition
# Create Swin Transformer (tiny) pretrained on ImageNet
model = create_model('swin_tiny_patch4_window7_224', pretrained=True, num_classes=2)
model = model.to(DEVICE)

In [9]:
# Loss & Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [12]:
# Training Loop
NUM_EPOCHS = 10
best_acc = 0

for epoch in range(NUM_EPOCHS):
    # Training
    model.train()
    running_loss = 0
    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1} Training"):
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    avg_loss = running_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] Training Loss: {avg_loss:.4f}")

    #  Validation 
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    val_acc = correct / total
    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] Validation Accuracy: {val_acc:.4f}")

    # Save Best Model
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), "best_swin_transformer.pth")
        print(" Saved new best model")

Epoch 1 Training: 100%|██████████| 633/633 [00:38<00:00, 16.40it/s]

Epoch [1/10] Training Loss: 0.8276





Epoch [1/10] Validation Accuracy: 0.5367
 Saved new best model


Epoch 2 Training: 100%|██████████| 633/633 [00:38<00:00, 16.64it/s]

Epoch [2/10] Training Loss: 0.8266





Epoch [2/10] Validation Accuracy: 0.5367


Epoch 3 Training: 100%|██████████| 633/633 [00:38<00:00, 16.34it/s]

Epoch [3/10] Training Loss: 0.8272





Epoch [3/10] Validation Accuracy: 0.5367


Epoch 4 Training: 100%|██████████| 633/633 [00:38<00:00, 16.50it/s]

Epoch [4/10] Training Loss: 0.8257





Epoch [4/10] Validation Accuracy: 0.5367


Epoch 5 Training: 100%|██████████| 633/633 [00:38<00:00, 16.32it/s]

Epoch [5/10] Training Loss: 0.8248





Epoch [5/10] Validation Accuracy: 0.5367


Epoch 6 Training: 100%|██████████| 633/633 [00:38<00:00, 16.32it/s]

Epoch [6/10] Training Loss: 0.8254





Epoch [6/10] Validation Accuracy: 0.5367


Epoch 7 Training: 100%|██████████| 633/633 [00:38<00:00, 16.37it/s]

Epoch [7/10] Training Loss: 0.8266





Epoch [7/10] Validation Accuracy: 0.5367


Epoch 8 Training: 100%|██████████| 633/633 [00:38<00:00, 16.48it/s]

Epoch [8/10] Training Loss: 0.8264





Epoch [8/10] Validation Accuracy: 0.5367


Epoch 9 Training: 100%|██████████| 633/633 [00:39<00:00, 16.13it/s]

Epoch [9/10] Training Loss: 0.8276





Epoch [9/10] Validation Accuracy: 0.5367


Epoch 10 Training: 100%|██████████| 633/633 [00:38<00:00, 16.61it/s]

Epoch [10/10] Training Loss: 0.8279





Epoch [10/10] Validation Accuracy: 0.5367


###  Final Evaluation
We evaluate the enhanced model on the **test dataset**, giving the final:

In [13]:
model.load_state_dict(torch.load("/content/best_swin_transformer.pth"))
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        outputs = model(images)
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

test_acc = correct / total
print(" TEST ACCURACY:", test_acc)

 TEST ACCURACY: 0.5365168539325843
