In [None]:
"""
CNN-SVM Hybrid for Diabetic Retinopathy Classification
- Supports Kaggle and Messidor (Messidor has 3 CSVs)
- Uses ImageDataGenerator to avoid memory overflow
- CNN feature extraction + SVM on penultimate features
"""

import os
import math
import numpy as np
import pandas as pd
import joblib
import time

from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.utils import class_weight
from sklearn.metrics import (
    accuracy_score, classification_report, confusion_matrix, cohen_kappa_score
)
from sklearn.svm import SVC

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Dropout, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# ----------------------------
# USER CONFIG
# ----------------------------
USE_DATASET = "MESSIDOR"  # <== Options: "KAGGLE" or "MESSIDOR"

IMG_SIZE = 224
BATCH_SIZE = 16
EPOCHS = 20
LEARNING_RATE = 1e-4
OUTPUT_DIR = "./outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Paths
KAGGLE_IMAGES_DIR = r"D:/Education/MSc/Active Assignments/Project/Model/KDR_Pre-processed/subset"
KAGGLE_CSV = r"D:/Education/MSc/Active Assignments/Project/Model/KDR_Pre-processed/subset/subset_labels.csv"

MESSIDOR_IMAGES_DIR = r"D:/Education/MSc/Active Assignments/Project/Model/MS_Pre-processed"
MESSIDOR_CSVS = [
    r"D:/Education/MSc/Active Assignments/Project/Model/MS_Pre-processed/Annotation_Base11.csv",
    r"D:/Education/MSc/Active Assignments/Project/Model/MS_Pre-processed/Annotation_Base21.csv",
    r"D:/Education/MSc/Active Assignments/Project/Model/MS_Pre-processed/Annotation_Base31.csv"
]

SVM_MODEL_PATH = os.path.join(OUTPUT_DIR, "svm_classifier.pkl")
PCA_PATH = os.path.join(OUTPUT_DIR, "pca_transform.pkl")
SCALER_PATH = os.path.join(OUTPUT_DIR, "scaler_transform.pkl")
CNN_FEATURE_EXTRACTOR_PATH = os.path.join(OUTPUT_DIR, "cnn_feature_extractor.h5")
BEST_CNN_CHECKPOINT = os.path.join(OUTPUT_DIR, "best_cnn.h5")
FINAL_CNN_MODEL_PATH = os.path.join(OUTPUT_DIR, "final_cnn_model.h5")

# ----------------------------
# DATA LOADING
# ----------------------------
def load_kaggle(csv_path, img_dir):
    df = pd.read_csv(csv_path)
    df.columns = [c.strip().lower() for c in df.columns]
    df = df.rename(columns={"image": "file_path", "level": "raw_label"})
    df['file_path'] = df['file_path'].apply(lambda x: os.path.join(img_dir, str(x)))
    df = df[df['file_path'].apply(os.path.exists)].reset_index(drop=True)
    return df

def load_messidor(csv_paths, img_dir):
    dfs = []
    for path in csv_paths:
        df_temp = pd.read_csv(path)
        df_temp.columns = [c.strip().lower() for c in df_temp.columns]
        df_temp = df_temp.rename(columns={"image name": "file_path", "retinopathy grade": "raw_label"})
        df_temp['file_path'] = df_temp['file_path'].apply(lambda x: os.path.join(img_dir, str(x)))
        dfs.append(df_temp)
    df = pd.concat(dfs, ignore_index=True)
    df = df[df['file_path'].apply(os.path.exists)].reset_index(drop=True)
    return df

if USE_DATASET == "KAGGLE":
    df = load_kaggle(KAGGLE_CSV, KAGGLE_IMAGES_DIR)
elif USE_DATASET == "MESSIDOR":
    df = load_messidor(MESSIDOR_CSVS, MESSIDOR_IMAGES_DIR)
else:
    raise ValueError("Invalid dataset selection!")

# Label encoding
le = LabelEncoder()
df['label_encoded'] = le.fit_transform(df['raw_label'].astype(int))
df['label_str'] = df['label_encoded'].astype(str)
num_classes = len(le.classes_)

# ----------------------------
# IMAGE GENERATOR
# ----------------------------
datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=0.2,
    horizontal_flip=True,
    rotation_range=10,
    zoom_range=0.1,
    brightness_range=[0.9, 1.1]
)

train_gen = datagen.flow_from_dataframe(
    df,
    x_col='file_path',
    y_col='label_str',      # use string labels
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='training',
    shuffle=True
)

val_gen = datagen.flow_from_dataframe(
    df,
    x_col='file_path',
    y_col='label_str',      # use string labels
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='validation',
    shuffle=False
)


# ----------------------------
# CLASS WEIGHTS
# ----------------------------
cw = class_weight.compute_class_weight('balanced', classes=np.unique(train_gen.classes), y=train_gen.classes)
class_weights_dict = dict(zip(np.unique(train_gen.classes), cw))

# ----------------------------
# CNN MODEL
# ----------------------------
def build_resnet50_head(input_shape=(IMG_SIZE, IMG_SIZE, 3), num_classes=5, penultimate_units=512):
    base = tf.keras.applications.ResNet50(include_top=False, weights='imagenet', input_shape=input_shape)
    x = GlobalAveragePooling2D(name="gap")(base.output)
    x = Dense(penultimate_units, activation="relu", name="penultimate_dense")(x)
    x = Dropout(0.5)(x)
    outputs = Dense(num_classes, activation='softmax', name='softmax_head')(x)
    model = Model(inputs=base.input, outputs=outputs, name="ResNet50_with_head")
    return model

def feature_extractor(model):
    penultimate = model.get_layer("penultimate_dense").output
    return Model(inputs=model.input, outputs=penultimate)

# ----------------------------
# TRAIN CNN
# ----------------------------
model = build_resnet50_head(input_shape=(IMG_SIZE, IMG_SIZE, 3), num_classes=num_classes)
model.compile(optimizer=Adam(LEARNING_RATE), loss='categorical_crossentropy', metrics=['accuracy'])

callbacks = [
    EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True, verbose=1),
    ModelCheckpoint(BEST_CNN_CHECKPOINT, monitor='val_loss', save_best_only=True, verbose=1)
]

steps_per_epoch = math.ceil(train_gen.n / BATCH_SIZE)
validation_steps = math.ceil(val_gen.n / BATCH_SIZE)

print(f"Training CNN on {USE_DATASET} dataset...")
model.fit(
    train_gen,
    validation_data=val_gen,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
    epochs=EPOCHS,
    callbacks=callbacks,
    class_weight=class_weights_dict,
    verbose=1
)
model.save(FINAL_CNN_MODEL_PATH)
print("Final CNN saved at:", FINAL_CNN_MODEL_PATH)

# ----------------------------
# FEATURE EXTRACTION
# ----------------------------
feat_model = feature_extractor(model)
feat_model.save(CNN_FEATURE_EXTRACTOR_PATH)
print("CNN feature extractor saved at:", CNN_FEATURE_EXTRACTOR_PATH)

def extract_features_in_batches(feat_model, data_gen):
    X_list, y_list = [], []
    steps = math.ceil(data_gen.n / data_gen.batch_size)
    for i in range(steps):
        X_batch, y_batch = next(data_gen)
        feats = feat_model.predict(X_batch, verbose=0)
        X_list.append(feats)
        y_list.append(np.argmax(y_batch, axis=1))
    X_feat = np.vstack(X_list)
    y_feat = np.concatenate(y_list)
    return X_feat, y_feat

# Reset generator for feature extraction
train_gen.reset()
X_train_feat, y_train_feat = extract_features_in_batches(feat_model, train_gen)
val_gen.reset()
X_val_feat, y_val_feat = extract_features_in_batches(feat_model, val_gen)


# ----------------------------
# TRAIN SVM
# ----------------------------
X_svm = np.vstack([X_train_feat, X_val_feat])
y_svm = np.concatenate([y_train_feat, y_val_feat])
print("SVM features shape:", X_svm.shape)

print("\n--- Training SVM on Updated Features (512-D) ---")

scaler = StandardScaler()
X_scaled = scaler.fit_transform(X_svm)

pca = PCA(n_components=0.95)
X_pca = pca.fit_transform(X_scaled)

svm_model = SVC(kernel='rbf', probability=True)
svm_model.fit(X_pca, y_svm)

print("SVM training completed. PCA output dim:", X_pca.shape[1])

# Save only these
joblib.dump(svm_model, SVM_MODEL_PATH)
joblib.dump(pca, PCA_PATH)
joblib.dump(scaler, SCALER_PATH)


# ----------------------------
# FUSION PREDICTION FUNCTION
# ----------------------------
def fusion_predict_from_features(cnn_model, svm_model, pca, scaler, X_val_feat, val_labels, val_gen, batch_size=16):
    # CNN predictions on validation images
    val_gen.reset()
    steps = math.ceil(val_gen.n / batch_size)
    cnn_probs_list = []
    for i in range(steps):
        X_batch, _ = next(val_gen)
        cnn_probs_batch = cnn_model.predict(X_batch, verbose=0)
        cnn_probs_list.append(cnn_probs_batch)
    cnn_probs = np.vstack(cnn_probs_list)

    # SVM predictions (correct order: scale → PCA)
    X_scaled = scaler.transform(X_val_feat)   # scale raw CNN features first
    X_pca = pca.transform(X_scaled)           # then PCA
    svm_probs = svm_model.predict_proba(X_pca)

    # Fusion: average probabilities
    fusion_probs = (cnn_probs + svm_probs) / 2
    y_pred_fusion = np.argmax(fusion_probs, axis=1)
    y_true = val_labels

    return y_pred_fusion, y_true

# ----------------------------
# EVALUATE FUSION
# ----------------------------
y_pred_fusion, y_true_fusion = fusion_predict_from_features(
    cnn_model=model,
    svm_model=svm_model,
    pca=pca,
    scaler=scaler,
    X_val_feat=X_val_feat,
    val_labels=y_val_feat,
    val_gen=val_gen,      
    batch_size=BATCH_SIZE
)


print("\nFusion Metrics on Validation Set:")
print(classification_report(y_true_fusion, y_pred_fusion))
print("Cohen's kappa:", cohen_kappa_score(y_true_fusion, y_pred_fusion))



# ----------------------------
# MAIN PIPELINE
# ----------------------------
if __name__ == "__main__":
    print(f"Using dataset: {USE_DATASET}")

    # CNN TRAINING
    print("\n--- Training CNN ---")
    steps_per_epoch = math.ceil(train_gen.n / BATCH_SIZE)
    validation_steps = math.ceil(val_gen.n / BATCH_SIZE)

    model.fit(
        train_gen,
        validation_data=val_gen,
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps,
        epochs=EPOCHS,
        callbacks=callbacks,
        class_weight=class_weights_dict,
        verbose=1
    )
    model.save(FINAL_CNN_MODEL_PATH)
    print("Final CNN model saved:", FINAL_CNN_MODEL_PATH)

    # FEATURE EXTRACTION
    print("\n--- Extracting CNN features ---")
    feat_model = feature_extractor(model)
    feat_model.save(CNN_FEATURE_EXTRACTOR_PATH)
    print("CNN feature extractor saved:", CNN_FEATURE_EXTRACTOR_PATH)

    # Reset generators for feature extraction
    train_gen.reset()
    val_gen.reset()

    X_train_feat, y_train_feat = extract_features_in_batches(feat_model, train_gen)
    X_val_feat, y_val_feat = extract_features_in_batches(feat_model, val_gen)

    # Combine train + val features for SVM
    X_svm = np.vstack([X_train_feat, X_val_feat])
    y_svm = np.concatenate([y_train_feat, y_val_feat])
    print(f"SVM feature matrix shape: {X_svm.shape}, labels shape: {y_svm.shape}")

    # SVM TRAINING
    print("\n--- Training SVM ---")
    pca = PCA(n_components=min(128, X_svm.shape[1]), random_state=42)
    X_svm_pca = pca.fit_transform(X_svm)

    scaler = StandardScaler()
    X_svm_scaled = scaler.fit_transform(X_svm_pca)

    svm_model = SVC(kernel='rbf', C=1.0, gamma='scale', class_weight='balanced', probability=True)
    start_time = time.time()
    svm_model.fit(X_svm_scaled, y_svm)
    print(f"SVM training finished in {time.time() - start_time:.2f} seconds")

    y_pred = svm_model.predict(X_svm_scaled)
    print("\n--- SVM Training Metrics ---")
    print(classification_report(y_svm, y_pred))
    print("Cohen's kappa:", cohen_kappa_score(y_svm, y_pred))

    # Save SVM artifacts
    joblib.dump(svm_model, SVM_MODEL_PATH)
    joblib.dump(pca, PCA_PATH)
    joblib.dump(scaler, SCALER_PATH)
    print("Saved SVM artifacts.")

    # OPTIONAL: FUSION PREDICTION
    print("\n--- Fusion CNN + SVM ---")
    y_pred_fusion, y_true_fusion = fusion_predict(
        model,
        svm_model,
        pca,
        scaler,
        val_gen,          
        batch_size=BATCH_SIZE
    )

    print("\nFusion Metrics on Validation Set:")
    print(classification_report(y_true_fusion, y_pred_fusion))
    print("Cohen's kappa:", cohen_kappa_score(y_val_feat, y_pred_fusion))

Found 240 validated image filenames belonging to 4 classes.
Found 60 validated image filenames belonging to 4 classes.
Training CNN on MESSIDOR dataset...
Epoch 1/20
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3s/step - accuracy: 0.3679 - loss: 1.5101
Epoch 1: val_loss improved from None to 1.27156, saving model to ./outputs\best_cnn.h5




[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m70s[0m 3s/step - accuracy: 0.3292 - loss: 1.6004 - val_accuracy: 0.4667 - val_loss: 1.2716
Epoch 2/20
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3s/step - accuracy: 0.3604 - loss: 1.2967
Epoch 2: val_loss did not improve from 1.27156
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m42s[0m 3s/step - accuracy: 0.4167 - loss: 1.2325 - val_accuracy: 0.4667 - val_loss: 1.2989
Epoch 3/20
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3s/step - accuracy: 0.6313 - loss: 0.9136
Epoch 3: val_loss did not improve from 1.27156
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 3s/step - accuracy: 0.5875 - loss: 0.9401 - val_accuracy: 0.4667 - val_loss: 1.3178
Epoch 4/20
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3s/step - accuracy: 0.6313 - loss: 0.7313
Epoch 4: val_loss did not improve from 1.27156
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m43s



Final CNN saved at: ./outputs\final_cnn_model.h5
CNN feature extractor saved at: ./outputs\cnn_feature_extractor.h5
SVM features shape: (300, 512)

--- Training SVM on Updated Features (512-D) ---
SVM training completed. PCA output dim: 30

Fusion Metrics on Validation Set:
              precision    recall  f1-score   support

           0       0.56      1.00      0.72        18
           1       0.00      0.00      0.00         7
           2       0.00      0.00      0.00         7
           3       0.82      0.82      0.82        28

    accuracy                           0.68        60
   macro avg       0.35      0.46      0.39        60
weighted avg       0.55      0.68      0.60        60

Cohen's kappa: 0.4910714285714286
Using dataset: MESSIDOR

--- Training CNN ---


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


Epoch 1/20
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3s/step - accuracy: 0.4306 - loss: 1.2537
Epoch 1: val_loss improved from 1.27156 to 1.25882, saving model to ./outputs\best_cnn.h5




[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m43s[0m 3s/step - accuracy: 0.4792 - loss: 1.2084 - val_accuracy: 0.4667 - val_loss: 1.2588
Epoch 2/20
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3s/step - accuracy: 0.6126 - loss: 0.9909
Epoch 2: val_loss improved from 1.25882 to 1.25005, saving model to ./outputs\best_cnn.h5




[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m42s[0m 3s/step - accuracy: 0.5750 - loss: 0.9548 - val_accuracy: 0.4667 - val_loss: 1.2500
Epoch 3/20
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3s/step - accuracy: 0.6932 - loss: 0.7165
Epoch 3: val_loss improved from 1.25005 to 1.24618, saving model to ./outputs\best_cnn.h5




[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m42s[0m 3s/step - accuracy: 0.6958 - loss: 0.7703 - val_accuracy: 0.4667 - val_loss: 1.2462
Epoch 4/20
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3s/step - accuracy: 0.7844 - loss: 0.4923
Epoch 4: val_loss improved from 1.24618 to 1.22949, saving model to ./outputs\best_cnn.h5




[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m42s[0m 3s/step - accuracy: 0.7917 - loss: 0.4979 - val_accuracy: 0.4667 - val_loss: 1.2295
Epoch 5/20
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3s/step - accuracy: 0.7888 - loss: 0.4253
Epoch 5: val_loss did not improve from 1.22949
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 3s/step - accuracy: 0.8417 - loss: 0.3693 - val_accuracy: 0.3000 - val_loss: 1.2810
Epoch 6/20
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3s/step - accuracy: 0.8435 - loss: 0.2807
Epoch 6: val_loss did not improve from 1.22949
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 3s/step - accuracy: 0.8375 - loss: 0.3247 - val_accuracy: 0.3000 - val_loss: 1.4997
Epoch 7/20
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3s/step - accuracy: 0.8589 - loss: 0.2845
Epoch 7: val_loss did not improve from 1.22949
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s



Final CNN model saved: ./outputs\final_cnn_model.h5

--- Extracting CNN features ---
CNN feature extractor saved: ./outputs\cnn_feature_extractor.h5
SVM feature matrix shape: (300, 512), labels shape: (300,)

--- Training SVM ---
SVM training finished in 0.05 seconds

--- SVM Training Metrics ---
              precision    recall  f1-score   support

           0       1.00      0.99      1.00       154
           1       1.00      1.00      1.00        36
           2       1.00      1.00      1.00        45
           3       0.98      1.00      0.99        65

    accuracy                           1.00       300
   macro avg       1.00      1.00      1.00       300
weighted avg       1.00      1.00      1.00       300

Cohen's kappa: 0.9949003008822479
Saved SVM artifacts.

--- Fusion CNN + SVM ---

Fusion Metrics on Validation Set:
              precision    recall  f1-score   support

           0       0.30      1.00      0.46        18
           1       0.00      0.00      0.0

  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
