In [1]:
import pandas as pd
from sklearn.model_selection import train_test_split
# ----- Set paths -----
image_dir_train = 'train'  # Directory containing the training images
csv_file = 'trainLabels/trainLabels.csv'  # CSV with image labels
image_dir_val = 'test'  # (not used in this snippet, but for future validation/test)

# ----- Load the CSV -----
df = pd.read_csv(csv_file)
df["image"] = df["image"].astype(str)  # ensure string format

# ----- Split the data -----
train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)

In [None]:
from dataset_utils import DRDatasetGPU #Import the DRDatasetGPU class from dataset_utils.py file
from torch.utils.data import DataLoader

In [None]:
train_dataset = DRDatasetGPU(image_dir=image_dir_train, dataframe=train_df,test_mode=False)
val_dataset   = DRDatasetGPU(image_dir=image_dir_train, dataframe=val_df,test_mode=False)
train_loader = DataLoader(train_dataset, batch_size=6, shuffle=True, num_workers=4)
val_loader   = DataLoader(val_dataset, batch_size=6, shuffle=False, num_workers=4)
# batch_size = 6 for efficiencynetb4
# ----- Sample batch for sanity check -----
images, labels = next(iter(train_loader))
print("Batch image shape :", images.shape)  # [B, 3, 224, 224]
print("Batch label shape :", labels.shape)  # [B]
print("Label dtype       :", labels.dtype)  # torch.int64

[PyTorch Dataset] Valid images found: 28100
[PyTorch Dataset] Valid images found: 7026
Batch image shape : torch.Size([6, 3, 380, 380])
Batch label shape : torch.Size([6])
Label dtype       : torch.int64


In [None]:
import torch.nn as nn,numpy as np
import timm,torch
# from torch.optim import Adam
from tqdm import tqdm
from sklearn.metrics import classification_report, confusion_matrix
import torch.nn as nn
from torchvision import models
from sklearn.utils.class_weight import compute_class_weight

  from .autonotebook import tqdm as notebook_tqdm


### Model

In [3]:

class EfficientNetB4_DR(nn.Module):
    def __init__(self, num_classes=5):
        super().__init__()
        self.backbone = timm.create_model('efficientnet_b4', pretrained=True)
        self.backbone.classifier = nn.Identity()  # Remove original classifier
        self.head = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(self.backbone.num_features, num_classes)
        )

    def forward(self, x):
        x = self.backbone(x)
        return self.head(x)


In [None]:

# ---- Model, Loss, Optimizer Setup ----
device = "cuda" if torch.cuda.is_available() else "cpu"
model_efficientnet = EfficientNetB4_DR(num_classes=5).to(device)

# Compute weights using full train dataframe
classes = np.array([0, 1, 2, 3, 4])
weights = compute_class_weight(class_weight="balanced", classes=classes, y=train_df["level"])
weights = np.clip(weights, 0.5, 3.0)  # avoid extremely high weights
class_weights = torch.tensor(weights, dtype=torch.float).to(device)

criterion = nn.CrossEntropyLoss(class_weights,label_smoothing=0.05)

optimizer = torch.optim.Adam(model_efficientnet.parameters(), lr=5e-4,weight_decay=3e-4)#EfficientNetB4_DR


# ---- Training Loop ----
num_epochs = 10
best_val_loss = float('inf')
patience = 3

for epoch in range(num_epochs):
    model_efficientnet.train()
    running_loss = 0.0

    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        images, labels = images.to(device,non_blocking=True), labels.to(device,non_blocking=True)

        optimizer.zero_grad()
        logits = model_efficientnet(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        
        

        running_loss += loss.item()

    avg_train_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1}: Train Loss = {avg_train_loss:.4f}")
    
    # ---- Validation ----
    model_efficientnet.eval()
    val_loss, correct, total = 0.0, 0, 0
    all_preds, all_targets = [], []

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            logits = model_efficientnet(images)
            loss = criterion(logits, labels)
            val_loss += loss.item()

            preds = torch.argmax(logits, dim=1)
            all_preds.append(preds.cpu())
            all_targets.append(labels.cpu())

            correct += (preds == labels).sum().item()
            total += labels.size(0)

    avg_val_loss = val_loss / len(val_loader)
    val_acc = correct / total
    print(f"→ Val Loss: {avg_val_loss:.4f} | Val Accuracy: {val_acc:.2%}")
    
     # Early stopping logic
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        stop_counter = 0
        torch.save(model_efficientnet.state_dict(), "EfficientNetB4_DR_model2.pt")
        best_preds = torch.cat(all_preds)
        best_targets = torch.cat(all_targets)
    torch.cuda.empty_cache()

# ---- Evaluation ----
print("\n=== Final Evaluation Report ===")
y_true = torch.cat(all_targets).numpy()
y_pred = torch.cat(all_preds).numpy()

print("Confusion Matrix:")
print(confusion_matrix(y_true, y_pred))

print("\nClassification Report:")
print(classification_report(y_true, y_pred, target_names=[
    "No DR", "Mild", "Moderate", "Severe", "Proliferative DR"]))


Epoch 1/10: 100%|██████████| 4684/4684 [12:13<00:00,  6.39it/s]


Epoch 1: Train Loss = 1.3701
→ Val Loss: 1.2773 | Val Accuracy: 76.40%


Epoch 2/10: 100%|██████████| 4684/4684 [12:13<00:00,  6.38it/s]


Epoch 2: Train Loss = 1.2784
→ Val Loss: 1.2559 | Val Accuracy: 81.18%


Epoch 3/10: 100%|██████████| 4684/4684 [12:17<00:00,  6.35it/s]


Epoch 3: Train Loss = 1.2411
→ Val Loss: 1.2381 | Val Accuracy: 80.33%


Epoch 4/10: 100%|██████████| 4684/4684 [12:18<00:00,  6.34it/s]


Epoch 4: Train Loss = 1.2080
→ Val Loss: 1.2621 | Val Accuracy: 82.17%


Epoch 5/10: 100%|██████████| 4684/4684 [12:19<00:00,  6.34it/s]


Epoch 5: Train Loss = 1.1841
→ Val Loss: 1.3026 | Val Accuracy: 78.76%


Epoch 6/10: 100%|██████████| 4684/4684 [12:20<00:00,  6.33it/s]


Epoch 6: Train Loss = 1.1696
→ Val Loss: 1.2225 | Val Accuracy: 80.39%


Epoch 7/10: 100%|██████████| 4684/4684 [12:19<00:00,  6.34it/s]


Epoch 7: Train Loss = 1.1526
→ Val Loss: 1.2353 | Val Accuracy: 78.95%


Epoch 8/10: 100%|██████████| 4684/4684 [12:19<00:00,  6.33it/s]


Epoch 8: Train Loss = 1.1408
→ Val Loss: 1.2265 | Val Accuracy: 78.95%


Epoch 9/10: 100%|██████████| 4684/4684 [12:19<00:00,  6.34it/s]


Epoch 9: Train Loss = 1.1289
→ Val Loss: 1.2052 | Val Accuracy: 80.71%


Epoch 10/10: 100%|██████████| 4684/4684 [12:19<00:00,  6.34it/s]


Epoch 10: Train Loss = 1.1154
→ Val Loss: 1.2314 | Val Accuracy: 80.49%

=== Final Evaluation Report ===
Confusion Matrix:
[[4876  230   46    0   23]
 [ 290  166   25    0    2]
 [ 309  198  509   22   20]
 [  11    7   98   17   32]
 [  14    3   34    7   87]]

Classification Report:
                  precision    recall  f1-score   support

           No DR       0.89      0.94      0.91      5175
            Mild       0.27      0.34      0.31       483
        Moderate       0.71      0.48      0.58      1058
          Severe       0.37      0.10      0.16       165
Proliferative DR       0.53      0.60      0.56       145

        accuracy                           0.80      7026
       macro avg       0.56      0.49      0.50      7026
    weighted avg       0.80      0.80      0.80      7026

