In [7]:
import os
import time
import copy
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import timm

from sklearn.metrics import classification_report, confusion_matrix

# ============================================================
# 1️⃣ Config
# ============================================================

data_root = r"C://retino"  # train/val/test inside this
train_dir = os.path.join(data_root, "train")
val_dir   = os.path.join(data_root, "valid")
test_dir  = os.path.join(data_root, "test")


img_size   = 224
batch_size = 16
num_epochs = 25

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

torch.backends.cudnn.benchmark = True

# ============================================================
# 2️⃣ Data Transforms & Dataloaders
# ============================================================
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std  = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(10),
    transforms.RandomResizedCrop(img_size, scale=(0.9, 1.0)),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

val_test_transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

train_dataset = datasets.ImageFolder(train_dir, transform=train_transform)
val_dataset   = datasets.ImageFolder(val_dir,   transform=val_test_transform)
test_dataset  = datasets.ImageFolder(test_dir,  transform=val_test_transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                          num_workers=2, pin_memory=True)
val_loader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False,
                          num_workers=2, pin_memory=True)
test_loader  = DataLoader(test_dataset,  batch_size=batch_size, shuffle=False,
                          num_workers=2, pin_memory=True)

class_names = train_dataset.classes
num_classes = len(class_names)
print("Classes:", class_names)          # should print ['DR', 'No_DR'] (alphabetical)
print("num_classes:", num_classes)

# ============================================================
# 3️⃣ Class Weights (handle imbalance) – optional but useful
# ============================================================
labels = np.array(train_dataset.targets)
class_counts = np.bincount(labels)
print("Class counts (train):", class_counts)

# inverse frequency
class_weights = len(labels) / (len(class_counts) * class_counts + 1e-6)
class_weights = torch.tensor(class_weights, dtype=torch.float32, device=device)
print("Class weights:", class_weights)

# ============================================================
# 4️⃣ ViT + BiLSTM Model (2-class)
# ============================================================
class ViT_BiLSTM_DR(nn.Module):
    def __init__(self, num_classes=2, pretrained=True, lstm_hidden=256):
        super(ViT_BiLSTM_DR, self).__init__()
        self.vit = timm.create_model(
            "vit_base_patch16_224",
            pretrained=pretrained
        )
        # remove original classifier
        self.vit.reset_classifier(0)

        embed_dim = self.vit.num_features  # usually 768

        self.bilstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=lstm_hidden,
            num_layers=1,
            batch_first=True,
            bidirectional=True
        )

        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(lstm_hidden * 2, num_classes)

    def forward(self, x):
        # ViT -> patch tokens
        tokens = self.vit.forward_features(x)

        # timm > 0.9 can return dict
        if isinstance(tokens, dict):
            tokens = tokens["x"]  # (B, T, D)

        # discard CLS token
        patch_tokens = tokens[:, 1:, :]    # (B, T-1, D)

        lstm_out, (h_n, c_n) = self.bilstm(patch_tokens)
        forward_h  = h_n[-2, :, :]
        backward_h = h_n[-1, :, :]
        h_cat = torch.cat([forward_h, backward_h], dim=1)

        h_cat = self.dropout(h_cat)
        logits = self.fc(h_cat)           # (B, 2)

        return logits

model = ViT_BiLSTM_DR(num_classes=num_classes, pretrained=True).to(device)

# ============================================================
# 5️⃣ Fine-tuning Strategy (freeze early ViT blocks)
# ============================================================
for p in model.vit.parameters():
    p.requires_grad = True

N_FREEZE = 8
if hasattr(model.vit, "blocks"):
    for block in model.vit.blocks[:N_FREEZE]:
        for p in block.parameters():
            p.requires_grad = False

for p in model.vit.patch_embed.parameters():
    p.requires_grad = False

model.vit.pos_embed.requires_grad = False

for p in model.bilstm.parameters():
    p.requires_grad = True
for p in model.fc.parameters():
    p.requires_grad = True

trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params     = sum(p.numel() for p in model.parameters())
print(f"Total params: {total_params:,} | Trainable: {trainable_params:,}")

# ============================================================
# 6️⃣ Optimizer & Loss  (CrossEntropy for 2 classes)
# ============================================================
optimizer = optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=3e-5,
    weight_decay=1e-4
)

criterion = nn.CrossEntropyLoss(weight=class_weights)

scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

# ============================================================
# 7️⃣ Train / Eval Helpers
# ============================================================
def train_one_epoch(model, loader, optimizer, criterion, device):
    model.train()
    running_loss = 0.0
    running_corrects = 0
    n_samples = 0

    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)          # (B,)

        optimizer.zero_grad()
        logits = model(images)              # (B, 2)
        loss = criterion(logits, labels)

        loss.backward()
        optimizer.step()

        preds = torch.argmax(logits, dim=1)

        running_loss += loss.item() * images.size(0)
        running_corrects += (preds == labels).sum().item()
        n_samples += images.size(0)

    epoch_loss = running_loss / n_samples
    epoch_acc  = running_corrects / n_samples
    return epoch_loss, epoch_acc

@torch.no_grad()
def evaluate(model, loader, criterion, device):
    model.eval()
    running_loss = 0.0
    running_corrects = 0
    n_samples = 0

    all_labels = []
    all_probs = []

    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)

        logits = model(images)
        loss = criterion(logits, labels)

        probs = torch.softmax(logits, dim=1)       # (B, 2)
        preds = torch.argmax(probs, dim=1)

        running_loss += loss.item() * images.size(0)
        running_corrects += (preds == labels).sum().item()
        n_samples += images.size(0)

        all_labels.append(labels.cpu().numpy())
        all_probs.append(probs.cpu().numpy())

    epoch_loss = running_loss / n_samples
    epoch_acc  = running_corrects / n_samples
    all_labels = np.concatenate(all_labels)
    all_probs  = np.vstack(all_probs)

    return epoch_loss, epoch_acc, all_labels, all_probs

# ============================================================
# 8️⃣ Training Loop
# ============================================================
best_model_wts = copy.deepcopy(model.state_dict())
best_val_acc = 0.0

for epoch in range(num_epochs):
    since = time.time()
    print(f"\nEpoch {epoch+1}/{num_epochs}")

    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
    val_loss, val_acc, _, _ = evaluate(model, val_loader, criterion, device)

    print(f"Train: loss={train_loss:.4f}, acc={train_acc*100:.2f}%")
    print(f"Val:   loss={val_loss:.4f}, acc={val_acc*100:.2f}%")

    scheduler.step()

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_model_wts = copy.deepcopy(model.state_dict())
        torch.save(model.state_dict(), "vit_bilstm_retino_best.pth")
        print("✅ Best model updated and saved.")

    print(f"Epoch time: {time.time() - since:.1f}s")

print("\nTraining complete. Best val acc: {:.2f}%".format(best_val_acc * 100))
model.load_state_dict(best_model_wts)

# ============================================================
# 9️⃣ Test Evaluation
# ============================================================
test_loss, test_acc, y_true, y_prob = evaluate(model, test_loader, criterion, device)
print(f"\nTest Loss: {test_loss:.4f}, Test Acc: {test_acc*100:.2f}%")

y_pred = np.argmax(y_prob, axis=1)

print("\nConfusion Matrix:")
print(confusion_matrix(y_true, y_pred))

print("\nClassification Report:")
print(classification_report(y_true, y_pred, target_names=class_names))


Using device: cuda
Classes: ['DR', 'No_DR']
num_classes: 2
Class counts (train): [1050 1026]
Class weights: tensor([0.9886, 1.0117], device='cuda:0')
Total params: 87,900,930 | Trainable: 30,456,066

Epoch 1/25
Train: loss=0.1918, acc=94.08%
Val:   loss=0.1149, acc=96.80%
✅ Best model updated and saved.
Epoch time: 101.6s

Epoch 2/25
Train: loss=0.0883, acc=97.69%
Val:   loss=0.0836, acc=98.12%
✅ Best model updated and saved.
Epoch time: 99.8s

Epoch 3/25
Train: loss=0.0898, acc=97.40%
Val:   loss=0.0815, acc=97.55%
Epoch time: 100.2s

Epoch 4/25
Train: loss=0.0931, acc=97.35%
Val:   loss=0.0845, acc=97.55%
Epoch time: 101.2s

Epoch 5/25
Train: loss=0.0647, acc=98.17%
Val:   loss=0.0800, acc=97.55%
Epoch time: 102.4s

Epoch 6/25
Train: loss=0.0616, acc=98.27%
Val:   loss=0.0925, acc=96.99%
Epoch time: 104.0s

Epoch 7/25
Train: loss=0.0581, acc=98.36%
Val:   loss=0.0721, acc=97.36%
Epoch time: 103.4s

Epoch 8/25
Train: loss=0.0414, acc=98.99%
Val:   loss=0.0844, acc=97.55%
Epoch time: 1

In [None]:
# DR_model_train_save.py  (DR.ipynb – first cells)

import os
import tensorflow as tf
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau

print("TensorFlow version:", tf.__version__)

# ======================
# Paths & Hyperparams
# ======================
dr_root   = r"C://retino"
train_dir = os.path.join(dr_root, "train")
val_dir   = os.path.join(dr_root, "valid")
test_dir  = os.path.join(dr_root, "test")

IMG_SIZE   = (224, 224)
BATCH_SIZE = 16
EPOCHS     = 20

# ======================
# Data Generators
# ======================
train_gen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.05,
    height_shift_range=0.05,
    zoom_range=0.15,
    horizontal_flip=True
)
val_gen = ImageDataGenerator(rescale=1./255)

train_data = train_gen.flow_from_directory(
    train_dir,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary'
)

val_data = val_gen.flow_from_directory(
    val_dir,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary'
)

print("DR class indices:", train_data.class_indices)  # {'DR':0, 'No_DR':1} or similar

# Class weights for imbalance
from sklearn.utils.class_weight import compute_class_weight
import numpy as np

labels = train_data.classes
class_weights_vals = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(labels),
    y=labels
)
class_weights = {i: w for i, w in enumerate(class_weights_vals)}
print("DR class weights:", class_weights)

# ======================
# Model Definition
# ======================
base_model = ResNet50(
    weights="imagenet",
    include_top=False,
    input_shape=(224, 224, 3)
)
base_model.trainable = False

x = GlobalAveragePooling2D()(base_model.output)
x = Dense(256, activation="relu")(x)
x = Dropout(0.4)(x)
output = Dense(1, activation="sigmoid")(x)

dr_model = Model(inputs=base_model.input, outputs=output)

dr_model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-4),
    loss="binary_crossentropy",
    metrics=["accuracy", tf.keras.metrics.AUC(name="auc")]
)

dr_model.summary()

# ======================
# Callbacks
# ======================
checkpoint_path_dr = "ResNet50_dr_model_best.h5"

checkpoint_dr = ModelCheckpoint(
    checkpoint_path_dr,
    monitor="val_accuracy",
    save_best_only=True,
    mode="max",
    verbose=1
)

early_stop_dr = EarlyStopping(
    monitor="val_accuracy",
    patience=5,
    mode="max",
    restore_best_weights=True,
    verbose=1
)

reduce_lr_dr = ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.3,
    patience=3,
    min_lr=1e-6,
    verbose=1
)

# ======================
# Train (Stage 1 – frozen backbone)
# ======================
history_dr = dr_model.fit(
    train_data,
    validation_data=val_data,
    epochs=EPOCHS,
    class_weight=class_weights,
    callbacks=[checkpoint_dr, early_stop_dr, reduce_lr_dr]
)

# ======================
# (Optional) Stage 2 – fine-tune last blocks
# ======================
for layer in base_model.layers:
    if "conv5_block" in layer.name:
        layer.trainable = True

dr_model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-5),
    loss="binary_crossentropy",
    metrics=["accuracy", tf.keras.metrics.AUC(name="auc")]
)

history_dr_ft = dr_model.fit(
    train_data,
    validation_data=val_data,
    epochs=10,
    class_weight=class_weights,
    callbacks=[checkpoint_dr, early_stop_dr, reduce_lr_dr]
)

print("✅ DR model training complete. Best weights saved to", checkpoint_path_dr)


TensorFlow version: 2.18.0
Found 2076 images belonging to 2 classes.
Found 531 images belonging to 2 classes.
DR class indices: {'DR': 0, 'No_DR': 1}
DR class weights: {0: 0.9885714285714285, 1: 1.0116959064327486}


  self._warn_if_super_not_called()


Epoch 1/20
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 899ms/step - accuracy: 0.5139 - auc: 0.5217 - loss: 0.7365
Epoch 1: val_accuracy improved from -inf to 0.70810, saving model to ResNet50_dr_model_best.h5




[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m169s[0m 1s/step - accuracy: 0.5138 - auc: 0.5216 - loss: 0.7364 - val_accuracy: 0.7081 - val_auc: 0.7547 - val_loss: 0.6842 - learning_rate: 1.0000e-04
Epoch 2/20
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 690ms/step - accuracy: 0.5147 - auc: 0.5168 - loss: 0.7183
Epoch 2: val_accuracy improved from 0.70810 to 0.72128, saving model to ResNet50_dr_model_best.h5




[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m114s[0m 877ms/step - accuracy: 0.5146 - auc: 0.5168 - loss: 0.7183 - val_accuracy: 0.7213 - val_auc: 0.8518 - val_loss: 0.6759 - learning_rate: 1.0000e-04
Epoch 3/20
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 677ms/step - accuracy: 0.5173 - auc: 0.5364 - loss: 0.7003
Epoch 3: val_accuracy did not improve from 0.72128
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m113s[0m 864ms/step - accuracy: 0.5174 - auc: 0.5364 - loss: 0.7003 - val_accuracy: 0.6045 - val_auc: 0.8651 - val_loss: 0.6732 - learning_rate: 1.0000e-04
Epoch 4/20
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 729ms/step - accuracy: 0.5150 - auc: 0.5360 - loss: 0.6961
Epoch 4: val_accuracy did not improve from 0.72128
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m120s[0m 919ms/step - accuracy: 0.5151 - auc: 0.5361 - loss: 0.6961 - val_accuracy: 0.5838 - val_auc: 0.8831 - val_loss: 0.6612 - learnin



[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m114s[0m 878ms/step - accuracy: 0.5891 - auc: 0.6057 - loss: 0.6788 - val_accuracy: 0.7721 - val_auc: 0.9039 - val_loss: 0.6446 - learning_rate: 1.0000e-04
Epoch 8/20
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 719ms/step - accuracy: 0.6089 - auc: 0.6370 - loss: 0.6691
Epoch 8: val_accuracy did not improve from 0.77213
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m116s[0m 894ms/step - accuracy: 0.6089 - auc: 0.6370 - loss: 0.6691 - val_accuracy: 0.6629 - val_auc: 0.9068 - val_loss: 0.6385 - learning_rate: 1.0000e-04
Epoch 9/20
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 689ms/step - accuracy: 0.6141 - auc: 0.6741 - loss: 0.6608
Epoch 9: val_accuracy did not improve from 0.77213
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m112s[0m 863ms/step - accuracy: 0.6141 - auc: 0.6741 - loss: 0.6608 - val_accuracy: 0.6234 - val_auc: 0.9102 - val_loss: 0.6334 - learnin



[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m95s[0m 728ms/step - accuracy: 0.6183 - auc: 0.6811 - loss: 0.6611 - val_accuracy: 0.8267 - val_auc: 0.9137 - val_loss: 0.6274 - learning_rate: 1.0000e-04
Epoch 11/20
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 807ms/step - accuracy: 0.6036 - auc: 0.6588 - loss: 0.6596
Epoch 11: val_accuracy did not improve from 0.82674
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m132s[0m 1s/step - accuracy: 0.6036 - auc: 0.6588 - loss: 0.6596 - val_accuracy: 0.6215 - val_auc: 0.9160 - val_loss: 0.6307 - learning_rate: 1.0000e-04
Epoch 12/20
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 753ms/step - accuracy: 0.6399 - auc: 0.7143 - loss: 0.6494
Epoch 12: val_accuracy did not improve from 0.82674
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m124s[0m 949ms/step - accuracy: 0.6399 - auc: 0.7143 - loss: 0.6494 - val_accuracy: 0.8004 - val_auc: 0.9159 - val_loss: 0.6146 - learnin



[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m135s[0m 1s/step - accuracy: 0.6418 - auc: 0.7051 - loss: 0.6454 - val_accuracy: 0.8475 - val_auc: 0.9200 - val_loss: 0.5995 - learning_rate: 1.0000e-04
Epoch 16/20
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 801ms/step - accuracy: 0.6616 - auc: 0.7322 - loss: 0.6401
Epoch 16: val_accuracy did not improve from 0.84746
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m131s[0m 1s/step - accuracy: 0.6617 - auc: 0.7323 - loss: 0.6401 - val_accuracy: 0.7006 - val_auc: 0.9224 - val_loss: 0.6005 - learning_rate: 1.0000e-04
Epoch 17/20
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 809ms/step - accuracy: 0.6562 - auc: 0.7382 - loss: 0.6347
Epoch 17: val_accuracy did not improve from 0.84746
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m132s[0m 1s/step - accuracy: 0.6562 - auc: 0.7382 - loss: 0.6347 - val_accuracy: 0.8324 - val_auc: 0.9223 - val_loss: 0.5889 - learning_rat



[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m133s[0m 1s/step - accuracy: 0.6591 - auc: 0.7328 - loss: 0.6245 - val_accuracy: 0.8493 - val_auc: 0.9248 - val_loss: 0.5660 - learning_rate: 1.0000e-04
Restoring model weights from the end of the best epoch: 20.
Epoch 1/10
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1s/step - accuracy: 0.7518 - auc: 0.8010 - loss: 0.9077
Epoch 1: val_accuracy did not improve from 0.84934
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m235s[0m 1s/step - accuracy: 0.7522 - auc: 0.8017 - loss: 0.9045 - val_accuracy: 0.5386 - val_auc: 0.8500 - val_loss: 0.6967 - learning_rate: 1.0000e-05
Epoch 2/10
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1s/step - accuracy: 0.8607 - auc: 0.9335 - loss: 0.3311
Epoch 2: val_accuracy did not improve from 0.84934
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m180s[0m 1s/step - accuracy: 0.8607 - auc: 0.9335 - loss: 0.3311 - val_accuracy: 0.5989 -



[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m182s[0m 1s/step - accuracy: 0.8760 - auc: 0.9456 - loss: 0.2975 - val_accuracy: 0.9115 - val_auc: 0.9751 - val_loss: 0.2284 - learning_rate: 1.0000e-05
Epoch 4/10
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1s/step - accuracy: 0.8851 - auc: 0.9507 - loss: 0.2895
Epoch 4: val_accuracy did not improve from 0.91149
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m180s[0m 1s/step - accuracy: 0.8850 - auc: 0.9507 - loss: 0.2896 - val_accuracy: 0.6591 - val_auc: 0.9571 - val_loss: 0.8638 - learning_rate: 1.0000e-05
Epoch 5/10
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1s/step - accuracy: 0.8971 - auc: 0.9588 - loss: 0.2578
Epoch 5: val_accuracy improved from 0.91149 to 0.92655, saving model to ResNet50_dr_model_best.h5




[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m187s[0m 1s/step - accuracy: 0.8970 - auc: 0.9587 - loss: 0.2579 - val_accuracy: 0.9266 - val_auc: 0.9801 - val_loss: 0.2154 - learning_rate: 1.0000e-05
Epoch 6/10
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 981ms/step - accuracy: 0.8797 - auc: 0.9510 - loss: 0.2844
Epoch 6: val_accuracy did not improve from 0.92655
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m143s[0m 1s/step - accuracy: 0.8798 - auc: 0.9511 - loss: 0.2842 - val_accuracy: 0.6441 - val_auc: 0.9488 - val_loss: 1.1347 - learning_rate: 1.0000e-05
Epoch 7/10
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 650ms/step - accuracy: 0.9121 - auc: 0.9628 - loss: 0.2473
Epoch 7: val_accuracy did not improve from 0.92655
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m99s[0m 763ms/step - accuracy: 0.9120 - auc: 0.9628 - loss: 0.2474 - val_accuracy: 0.7137 - val_auc: 0.9628 - val_loss: 0.7300 - learning_rate:



[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m100s[0m 770ms/step - accuracy: 0.9141 - auc: 0.9658 - loss: 0.2357 - val_accuracy: 0.9322 - val_auc: 0.9810 - val_loss: 0.1742 - learning_rate: 1.0000e-05
Epoch 9/10
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 644ms/step - accuracy: 0.8820 - auc: 0.9593 - loss: 0.2604
Epoch 9: val_accuracy did not improve from 0.93220
[1m130/130[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m98s[0m 754ms/step - accuracy: 0.8820 - auc: 0.9593 - loss: 0.2605 - val_accuracy: 0.9171 - val_auc: 0.9824 - val_loss: 0.1857 - learning_rate: 1.0000e-05
Epoch 10/10
[1m100/130[0m [32m━━━━━━━━━━━━━━━[0m[37m━━━━━[0m [1m19s[0m 650ms/step - accuracy: 0.9117 - auc: 0.9720 - loss: 0.2175

In [None]:
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model, Model

# ======================
# Load Best DR Model
# ======================
dr_gradcam_model_path = "ResNet50_dr_model_best.h5"
dr_model_best = load_model(dr_gradcam_model_path, compile=False)

dr_model_best.summary()
LAST_CONV_LAYER_NAME_DR = "conv5_block3_out"  # from summary

# ======================
# Grad-CAM function
# ======================
def make_gradcam_heatmap_dr(img_array, model, last_conv_layer_name=LAST_CONV_LAYER_NAME_DR):
    grad_model = Model(
        inputs=model.input,
        outputs=[model.get_layer(last_conv_layer_name).output, model.output]
    )

    with tf.GradientTape() as tape:
        conv_outputs, preds = grad_model(img_array)
        class_idx = tf.argmax(preds[0])
        loss = preds[:, class_idx]

    grads = tape.gradient(loss, conv_outputs)
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))

    conv_outputs = conv_outputs[0]
    heatmap = conv_outputs @ pooled_grads[..., tf.newaxis]
    heatmap = tf.squeeze(heatmap)

    heatmap = tf.maximum(heatmap, 0) / (tf.reduce_max(heatmap) + 1e-10)
    return heatmap.numpy()

# ======================
# Pick a DR test image
# ======================
sample_class = "DR"   # or "No_DR"
sample_path = os.path.join(test_dir, sample_class,
                           os.listdir(os.path.join(test_dir, sample_class))[0])
print("Using test image:", sample_path)

img = tf.keras.preprocessing.image.load_img(sample_path, target_size=IMG_SIZE)
img_array = tf.keras.preprocessing.image.img_to_array(img) / 255.0
input_tensor = np.expand_dims(img_array, axis=0)

# ======================
# Generate heatmap
# ======================
heatmap = make_gradcam_heatmap_dr(input_tensor, dr_model_best)

heatmap_resized = cv2.resize(heatmap, IMG_SIZE)
heatmap_uint8 = np.uint8(255 * heatmap_resized)
heatmap_color = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)

superimposed = heatmap_color * 0.4 + (img_array * 255).astype("uint8")

plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.title("Original Fundus Image")
plt.imshow(img_array)
plt.axis("off")

plt.subplot(1, 2, 2)
plt.title("Grad-CAM (DR)")
plt.imshow(superimposed.astype("uint8"))
plt.axis("off")
plt.tight_layout()
plt.show()
