In [5]:
# ============================================================
# FUNDUS vs NON-FUNDUS CLASSIFIER TRAINER (UPDATED VERSION)
# ============================================================

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split, ConcatDataset
from tqdm import tqdm

# -----------------------------
# CONFIGURATION
# -----------------------------
# The defaults point at the original author's machine. Each path can be
# overridden with an environment variable of the same name, so the script
# can run elsewhere without editing the source.
FUNDUS_DIR = os.environ.get(
    "FUNDUS_DIR",
    r"C:\Users\hitha\OneDrive\Desktop\eye app\eye disease app\dataset",
)
BRIGHT_DIR = os.environ.get(
    "BRIGHT_DIR",
    r"C:\Users\hitha\OneDrive\Desktop\eye app\eye disease app\backend\bright_nonfundus",
)

SAVE_PATH = os.environ.get(
    "SAVE_PATH",
    r"C:\Users\hitha\OneDrive\Desktop\eye app\eye disease app\backend\models\fundus_vs_nonfundus.pt",
)

BATCH_SIZE = 16
EPOCHS = 5
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -----------------------------
# DATA TRANSFORMS
# -----------------------------
# ImageNet per-channel statistics. The MobileNet backbone used in this
# script loads ImageNet-pretrained weights, so inputs are normalised the
# same way.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Every source (fundus, bright photos, CIFAR10) goes through the same
# pipeline: resize to 224x224, convert to a tensor, then normalise.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# -----------------------------
# DATASET HELPERS
# -----------------------------
# Fixed seed for the random CIFAR subset, so the same images are chosen
# on every run.
DATA_SEED = 42


def relabel_image_folder(dataset, label):
    """Force every sample of an ImageFolder onto a single class *label*.

    ImageFolder derives labels from sub-directory names. Here the
    directory structure is irrelevant, so every sample is collapsed onto
    one label. ``samples`` (read by ``__getitem__``), ``imgs`` (its alias)
    and ``targets`` are all updated so they stay consistent.

    Returns the same dataset object, modified in place.
    """
    dataset.samples = [(path, label) for path, _ in dataset.samples]
    dataset.imgs = dataset.samples
    dataset.targets = [label] * len(dataset.samples)
    return dataset


# -----------------------------
# LOAD FUNDUS IMAGES
# -----------------------------
# ImageFolder requires FUNDUS_DIR to contain class sub-folders. Their names
# are ignored: every image is labelled 1 (= fundus).
fundus_dataset = relabel_image_folder(
    datasets.ImageFolder(FUNDUS_DIR, transform=transform), 1
)
print(f"✅ Fundus images loaded: {len(fundus_dataset)}")

# -----------------------------
# LOAD BRIGHT NON-FUNDUS IMAGES
# -----------------------------
# label = 0 for NON-FUNDUS
bright_dataset = relabel_image_folder(
    datasets.ImageFolder(BRIGHT_DIR, transform=transform), 0
)
print(f"💡 Bright non-fundus images loaded: {len(bright_dataset)}")

# -----------------------------
# LOAD CIFAR10 AS GENERAL NON-FUNDUS
# -----------------------------
cifar_data = datasets.CIFAR10(
    root="./cifar_data",
    train=True,
    download=True,
    transform=transform,
)
print(f"📦 CIFAR10 non-fundus images loaded: {len(cifar_data)}")

# Make the CIFAR subset the same size as the fundus dataset. The size is
# capped at len(cifar_data); otherwise random_split would reject a
# negative remainder.
# NOTE(review): CIFAR10 images are 32x32 and `transform` upscales them to
# 224x224, so the classifier may learn "blurry vs sharp" rather than
# "fundus vs not". Consider adding high-resolution non-fundus photos.
n_cifar = min(len(fundus_dataset), len(cifar_data))
cifar_subset, _ = random_split(
    cifar_data,
    [n_cifar, len(cifar_data) - n_cifar],
    generator=torch.Generator().manual_seed(DATA_SEED),
)

# Wrap a dataset so that every sample gets one fixed label (0 = non-fundus
# by default). Used to relabel the CIFAR10 subset.
class CifarNonFundus(torch.utils.data.Dataset):
    """Dataset view that keeps the wrapped inputs but overrides every label.

    Args:
        subset: any indexable, sized dataset whose items are
            ``(input, label)`` pairs, e.g. a ``Subset`` of CIFAR10.
        label: the label returned for every item. Defaults to 0
            (non-fundus), which matches the previous hard-coded behaviour.
    """

    def __init__(self, subset, label=0):
        self.subset = subset
        self.label = label

    def __getitem__(self, idx):
        # The wrapped label (e.g. a CIFAR class id) is discarded.
        x, _ = self.subset[idx]
        return x, self.label

    def __len__(self):
        return len(self.subset)

cifar_wrapped = CifarNonFundus(cifar_subset)

# -----------------------------
# COMBINE DATASETS
# -----------------------------
# Label convention: 1 = fundus, 0 = non-fundus (bright photos + CIFAR10).
combined_dataset = ConcatDataset([
    fundus_dataset,
    bright_dataset,
    cifar_wrapped,
])

print(f"📊 Total dataset size: {len(combined_dataset)}")

# 80/20 train/validation split. A fixed-seed generator makes the split
# reproducible, so validation accuracy can be compared across runs.
SPLIT_SEED = 42
train_len = int(0.8 * len(combined_dataset))
val_len = len(combined_dataset) - train_len

train_set, val_set = random_split(
    combined_dataset,
    [train_len, val_len],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE)

print(f"🟦 Train: {len(train_set)} | 🟩 Val: {len(val_set)}")

# -----------------------------
# MODEL — MobileNet V3 Small
# -----------------------------
# Start from ImageNet-pretrained weights. The last classifier layer is
# replaced with a 2-way head (0 = non-fundus, 1 = fundus).
pretrained_weights = models.MobileNet_V3_Small_Weights.IMAGENET1K_V1
model = models.mobilenet_v3_small(weights=pretrained_weights)

last_layer = model.classifier[3]
model.classifier[3] = nn.Linear(
    in_features=last_layer.in_features,
    out_features=2,
)
model = model.to(DEVICE)

# Two-class cross-entropy, optimised with Adam over all parameters
# (nothing is frozen).
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

# -----------------------------
# TRAINING LOOP
# -----------------------------
def train_one_epoch(model, loader, criterion, optimizer, device, desc=""):
    """Run one optimisation pass over *loader*.

    Returns the mean per-sample training loss, or nan if the loader is
    empty.
    """
    model.train()
    loss_sum, seen = 0.0, 0

    for images, labels in tqdm(loader, desc=desc):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()

        # Assumes the criterion averages over the batch (the
        # nn.CrossEntropyLoss default). Re-weighting by batch size gives a
        # true per-sample mean even when the last batch is smaller.
        batch = labels.size(0)
        loss_sum += loss.item() * batch
        seen += batch

    return loss_sum / seen if seen else float("nan")


def evaluate(model, loader, device):
    """Return the accuracy (%) of *model* on *loader*, or nan if it is empty."""
    model.eval()
    correct, total = 0, 0

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    # Guard against an empty validation split instead of dividing by zero.
    return 100 * correct / total if total else float("nan")


for epoch in range(EPOCHS):
    train_loss = train_one_epoch(
        model, train_loader, criterion, optimizer, DEVICE,
        desc=f"Epoch {epoch+1}/{EPOCHS}",
    )
    print(f"🧠 Epoch {epoch+1} | Train Loss: {train_loss:.4f}")

    # Validation
    acc = evaluate(model, val_loader, DEVICE)
    print(f"✅ Validation Accuracy: {acc:.2f}%")

# -----------------------------
# SAVE MODEL
# -----------------------------
# Only the weights (state_dict) are saved. Whoever loads them must first
# rebuild mobilenet_v3_small with a 2-way classifier[3].
# `or "."` keeps makedirs from failing when SAVE_PATH is a bare filename.
os.makedirs(os.path.dirname(SAVE_PATH) or ".", exist_ok=True)

# Write to a temporary file and then swap it in with os.replace, so an
# interrupted save never leaves a truncated checkpoint at SAVE_PATH.
tmp_path = SAVE_PATH + ".tmp"
torch.save(model.state_dict(), tmp_path)
os.replace(tmp_path, SAVE_PATH)

print(f"\n🎉 Model saved → {SAVE_PATH}")


✅ Fundus images loaded: 4152
💡 Bright non-fundus images loaded: 37
Files already downloaded and verified
📦 CIFAR10 non-fundus images loaded: 50000
📊 Total dataset size: 8341
🟦 Train: 6672 | 🟩 Val: 1669


Epoch 1/5: 100%|██████████| 417/417 [03:25<00:00,  2.03it/s]


🧠 Epoch 1 | Train Loss: 0.0287
✅ Validation Accuracy: 100.00%


Epoch 2/5: 100%|██████████| 417/417 [20:26<00:00,  2.94s/it]   


🧠 Epoch 2 | Train Loss: 0.0009
✅ Validation Accuracy: 100.00%


Epoch 3/5: 100%|██████████| 417/417 [05:25<00:00,  1.28it/s]


🧠 Epoch 3 | Train Loss: 0.0014
✅ Validation Accuracy: 100.00%


Epoch 4/5: 100%|██████████| 417/417 [05:23<00:00,  1.29it/s]


🧠 Epoch 4 | Train Loss: 0.0002
✅ Validation Accuracy: 100.00%


Epoch 5/5: 100%|██████████| 417/417 [05:18<00:00,  1.31it/s]


🧠 Epoch 5 | Train Loss: 0.0001
✅ Validation Accuracy: 100.00%

🎉 Model saved → C:\Users\hitha\OneDrive\Desktop\eye app\eye disease app\backend\models\fundus_vs_nonfundus.pt
