In [None]:
# Multilabel OCT training notebook (TensorFlow 2.x)
# Requirements: tensorflow >= 2.6, pandas, numpy, scikit-learn, matplotlib, seaborn
#               (optional: tensorflow-addons, which enables the rotation augmentation)
# Example: pip install tensorflow pandas scikit-learn matplotlib seaborn

import os
import math
import random
from pathlib import Path

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, models, callbacks, optimizers, losses, metrics
from sklearn.metrics import classification_report, roc_auc_score
import matplotlib.pyplot as plt

# -------------------------
# USER CONFIG / PATHS
# -------------------------
# Root of the 15k OCT dataset; every data path below is derived from it.
_DATASET_ROOT = "/home/sutirtha/anaconda3/sutirtha_research_operations/OCT_Data/OCT_layerwise_classification_dataset_15k/oct_data_15k"
_SPLIT_DIR = _DATASET_ROOT + "/train_test_val_split_csv"

TRAIN_CSV = _SPLIT_DIR + "/train.csv"  # training split (needs 'filename' + label columns)
VAL_CSV = _SPLIT_DIR + "/val.csv"      # validation split
TEST_CSV = _SPLIT_DIR + "/test.csv"    # held-out test split
IMAGE_DIR = _DATASET_ROOT + "/data"    # folder that the CSV 'filename' entries are joined onto

# Training hyper-parameters
IMG_SIZE = (224, 224)        # (height, width) every image is resized to
BATCH_SIZE = 32
EPOCHS = 50
AUTOTUNE = tf.data.AUTOTUNE
BACKBONE = "EfficientNetB4"  # constructor name looked up in tf.keras.applications
LEARNING_RATE = 1e-4
DROPOUT_RATE = 0.4
SEED = 42

# Label columns provided by the user.
# Each string must match a column header in the split CSVs exactly (load_df
# zero-fills any that are missing). Position i in this list is position i of
# every label vector built from df[LABEL_COLUMNS] and of the model's outputs.
# NOTE(review): two entries are both named 'Subretinal Fluid' (IRL / SRL), and
# ERM and Intraretinal Fluid share color #EA899A. Confirm these headers are
# intended; they must still match the CSVs verbatim, so do not edit them here.
LABEL_COLUMNS = [
    'Vitreomacular Traction(#D95030)',
    'Epiretinal Membrane(ERM)(#EA899A)',
    'Full Thickness Macular Hole(FTMH)(#F54021)',
    'Lamellar Macular Hole(LMH)(#F3A505)',
    'Pseudo Macular Hole(#79553D)',
    'Intraretinal Fluid/Spongiform Edema(#EA899A)',
    'Subretinal Fluid(IRL)(#B44C43)',
    'Cystoid Macular Edema(CME)(#00BB2D)',
    'Hyperreflective Intraretinal Foci(#EFA94A)',
    'Subretinal Fluid(SRL)(#8673A1)',
    'Subretinal Hyperreflective Material(SHRM)(#6A5D4D)',
    'Drusen(#FAD201)',
    'CNVM(#316650)',
    'PED(#0E294B)',
    'Normal'
]

# One sigmoid output per label column.
NUM_CLASSES = len(LABEL_COLUMNS)
# Seed the TensorFlow, NumPy and Python RNGs (the NumPy seed also fixes the
# upsampling draw). This alone may not make GPU training fully deterministic.
tf.random.set_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# -------------------------
# UTIL: load CSV and sanity checks
# -------------------------
def load_df(csv_path):
    """Load one split CSV and prepare it for the tf.data pipeline.

    Args:
        csv_path: path to a CSV with a 'filename' column and, ideally, every
            entry of LABEL_COLUMNS.

    Returns:
        The DataFrame with any missing label columns zero-filled, plus two
        derived columns: 'filepath' (IMAGE_DIR joined with 'filename') and
        'labels_array' (per-row list of the LABEL_COLUMNS values).

    Raises:
        ValueError: if the CSV has no 'filename' column.
    """
    frame = pd.read_csv(csv_path)
    if 'filename' not in frame.columns:
        raise ValueError("CSV must contain a 'filename' column with image file names/paths.")

    # Missing label columns are treated as all-negative instead of failing.
    absent = [col for col in LABEL_COLUMNS if col not in frame.columns]
    for col in absent:
        print(f"Warning: label column {col} not found in {csv_path}. Filling with zeros.")
        frame[col] = 0

    def _to_abs_path(name):
        # 'filename' values may be non-strings (e.g. numeric ids), hence str().
        return os.path.join(IMAGE_DIR, str(name))

    frame['filepath'] = frame['filename'].apply(_to_abs_path)
    frame['labels_array'] = frame[LABEL_COLUMNS].to_numpy().tolist()
    return frame

# Load the three splits (each gains 'filepath' and 'labels_array' columns).
train_df = load_df(TRAIN_CSV)
val_df = load_df(VAL_CSV)
test_df = load_df(TEST_CSV)

for _split_name, _split_df in (("Train", train_df), ("Val", val_df), ("Test", test_df)):
    print(f"{_split_name} samples:", len(_split_df))

# -------------------------
# Class frequencies & positive weights (pos_weight)
# pos_weight[k] = #negatives_k / #positives_k, computed on the ORIGINAL
# train_df. It scales the positive term of the custom weighted_bce_loss and
# drives the resampling scores in create_sample_weights.
# -------------------------
_train_labels = train_df[LABEL_COLUMNS]
class_counts = _train_labels.sum(axis=0).astype(int)
total_samples = len(train_df)
neg_counts = total_samples - class_counts
# A class with no positives would divide by zero; count it as having one.
pos_counts = class_counts.clip(lower=1)
pos_weight = neg_counts.div(pos_counts).astype(float).to_numpy()  # shape (NUM_CLASSES,)

print("Class counts (train):")
print(class_counts)
print("pos_weight per class (neg/pos):")
for _label_name, _weight in zip(LABEL_COLUMNS, pos_weight):
    print(f"{_label_name}: {_weight:.3f}")

# -------------------------
# Optionally do light upsampling of minority samples (probabilistic resampling)
# This creates a new DataFrame 'train_df_balanced' by sampling with probability proportional to sample weight.
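# To skip resampling, set `upsample = False` below (train_df_balanced is then train_df unchanged).
# Note: the loss also applies pos_weight computed on the original train_df, so resampling and
# loss weighting both boost minority classes and their effects compound.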
# -------------------------
def create_sample_weights(df, pos_weight_per_class, label_columns=None):
    """Score each row for resampling by the rarity of its positive labels.

    A row's raw score is ``sum_k label[k] * pos_weight_per_class[k]``. For
    multi-hot (0/1) labels this is the sum of the pos_weights of the row's
    positive classes; it is NOT divided by the number of positives. Raw scores
    are min-max normalized to [0, 1] (all ones when every row scores the same).
    They are then shifted by +0.1 so that even the lowest-scoring rows keep a
    non-zero sampling probability.

    Args:
        df: DataFrame containing the label columns.
        pos_weight_per_class: array-like with one weight per label column.
        label_columns: label column names to use; defaults to LABEL_COLUMNS.

    Returns:
        1-D float ndarray of length len(df), values in [0.1, 1.1]. The array
        is empty when df has no rows.
    """
    columns = LABEL_COLUMNS if label_columns is None else label_columns
    # to_numpy keeps the (n_rows, n_labels) shape even when df is empty.
    labels = df[columns].to_numpy(dtype=float)
    per_sample_score = (labels * np.asarray(pos_weight_per_class, dtype=float)).sum(axis=1)
    if per_sample_score.size == 0:
        # Nothing to normalize; .min()/.max() would raise on an empty array.
        return per_sample_score

    lo, hi = per_sample_score.min(), per_sample_score.max()
    if hi - lo > 0:
        per_sample_score = (per_sample_score - lo) / (hi - lo)
    else:
        per_sample_score = np.ones_like(per_sample_score)
    # Floor so no row (e.g. one with no rare positives) is excluded entirely.
    return per_sample_score + 0.1

upsample = True   # set to False to skip resampling
if not upsample:
    train_df_balanced = train_df
else:
    # Per-row sampling scores, higher for rows carrying rare positive labels.
    sample_scores = create_sample_weights(train_df, pos_weight)
    # Size of the resampled set relative to the original (1.0 = same size).
    multiply_factor = 1.5
    n_new = int(len(train_df) * multiply_factor)
    # Draw row indices with replacement, proportionally to the scores.
    draw_probs = sample_scores / sample_scores.sum()
    sampled_idx = np.random.choice(len(train_df), size=n_new, replace=True, p=draw_probs)
    train_df_balanced = train_df.iloc[sampled_idx].reset_index(drop=True)
    print(f"Upsampled train: from {len(train_df)} to {len(train_df_balanced)}")

# -------------------------
# TF Dataset pipeline
# -------------------------
def read_image(path):
    """Read an image file into a float32 (IMG_SIZE[0], IMG_SIZE[1], 3) tensor in [0, 1].

    Args:
        path: scalar string tensor (or str) with the image file path.
    """
    raw_bytes = tf.io.read_file(path)
    # Force 3 channels and a single frame so all images share one layout.
    decoded = tf.io.decode_image(raw_bytes, channels=3, expand_animations=False)
    # Integer pixels -> float32 scaled to [0, 1].
    as_float = tf.image.convert_image_dtype(decoded, tf.float32)
    return tf.image.resize(as_float, IMG_SIZE)

# Augmentations for training
def augment_image(image):
    """Apply random geometric and photometric augmentation to one image.

    Args:
        image: float32 tensor of shape (IMG_SIZE[0], IMG_SIZE[1], 3) with
            values in [0, 1], as produced by read_image().

    Returns:
        Augmented tensor of the same shape, clipped back to [0, 1].
    """
    # Random flips.
    # NOTE(review): a vertical flip inverts the top-to-bottom layer order of an
    # OCT B-scan; confirm random_flip_up_down is a realistic augmentation here.
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    # Small random rotation (a no-op when tensorflow_addons is unavailable).
    angle = tf.random.uniform([], -0.08, 0.08)  # radians (~ +/- 4.5 deg)
    image = tfa_image_rotate(image, angle)
    # Random brightness/contrast. Neither op clips float images, so values can
    # leave [0, 1]; clip to keep the input range the rest of the model expects.
    image = tf.image.random_brightness(image, 0.08)
    image = tf.image.random_contrast(image, 0.9, 1.1)
    image = tf.clip_by_value(image, 0.0, 1.0)

    # With probability 0.3, take a random 90% crop and resize back (zoom-in).
    crop_size = [int(IMG_SIZE[0] * 0.9), int(IMG_SIZE[1] * 0.9), 3]

    def _zoom(img):
        crop = tf.image.random_crop(img, size=crop_size)
        return tf.image.resize(crop, IMG_SIZE)

    # Explicit tf.cond instead of a Python `if` on a tensor, so the branch is a
    # graph op whether or not autograph converts this function.
    image = tf.cond(tf.random.uniform([]) < 0.3, lambda: _zoom(image), lambda: image)
    return image

# small helper to rotate images (tf.image doesn't have rotate in base TF).
def tfa_image_rotate(image, angle):
    """Rotate `image` by `angle` radians using tensorflow_addons, if installed.

    Rotation is an optional augmentation: when tensorflow_addons cannot be
    imported, a warning is printed and the image is returned unchanged.
    Errors raised by the rotation itself are NOT swallowed.

    Args:
        image: image tensor to rotate.
        angle: rotation angle in radians.

    Returns:
        The rotated image, or `image` itself when tensorflow_addons is missing.
    """
    try:
        import tensorflow_addons as tfa
    except Exception as err:  # ImportError, or a TF/TFA version mismatch at import
        print(f"Warning: tensorflow_addons unavailable ({err}); skipping rotation augmentation.")
        return image
    # Outside the try: a failure here is a real bug and must surface.
    return tfa.image.rotate(image, angle)

def make_dataset_from_df(df, shuffle=False, augment=False, batch_size=32):
    """Build a batched tf.data pipeline of (image, multi-hot label) pairs.

    Args:
        df: DataFrame with a 'filepath' column and the LABEL_COLUMNS columns.
        shuffle: shuffle rows (full-size buffer, seeded with SEED) before loading.
        augment: apply augment_image() to each decoded image.
        batch_size: number of examples per batch.

    Returns:
        A prefetched tf.data.Dataset yielding (images, labels) batches, with
        labels as float32 rows of the LABEL_COLUMNS values.
    """
    label_matrix = np.array(df[LABEL_COLUMNS].values.tolist(), dtype=np.float32)
    dataset = tf.data.Dataset.from_tensor_slices((df['filepath'].values, label_matrix))
    if shuffle:
        dataset = dataset.shuffle(buffer_size=len(df), seed=SEED)

    def _decode_and_maybe_augment(path, label):
        image = read_image(path)
        # `augment` is a Python bool, so this choice is fixed at trace time.
        return (augment_image(image) if augment else image), label

    return (
        dataset.map(_decode_and_maybe_augment, num_parallel_calls=AUTOTUNE)
        .batch(batch_size)
        .prefetch(AUTOTUNE)
    )

# Evaluation splits keep file order (no shuffle) and skip augmentation, so
# predictions line up row-for-row with their DataFrames.
_eval_kwargs = dict(shuffle=False, augment=False, batch_size=BATCH_SIZE)
train_ds = make_dataset_from_df(train_df_balanced, shuffle=True, augment=True, batch_size=BATCH_SIZE)
val_ds = make_dataset_from_df(val_df, **_eval_kwargs)
test_ds = make_dataset_from_df(test_df, **_eval_kwargs)

# -------------------------
# Model: EfficientNet backbone + head
# -------------------------
def build_model(img_size=IMG_SIZE, num_classes=NUM_CLASSES, dropout_rate=DROPOUT_RATE, backbone_name=BACKBONE, lr=LEARNING_RATE, input_scale=255.0):
    """Build a pretrained-backbone multilabel classifier and its Adam optimizer.

    Architecture: Input -> (optional) Rescaling -> backbone (ImageNet weights,
    global-average pooled) -> Dropout -> Dense(num_classes, sigmoid).

    Args:
        img_size: (height, width) of the input images.
        num_classes: number of independent sigmoid outputs.
        dropout_rate: dropout applied to the pooled backbone features.
        backbone_name: name of a constructor in ``tf.keras.applications``.
        lr: learning rate of the returned Adam optimizer.
        input_scale: factor applied to the input before the backbone.
            read_image() yields pixels in [0, 1], while the Keras EfficientNet
            models embed their own rescaling/normalization and expect raw
            [0, 255] pixels; hence the default of 255.0. Pass 1.0 (or None)
            to feed inputs unchanged, e.g. for a backbone whose preprocessing
            you apply elsewhere.

    Returns:
        (model, optimizer): the uncompiled keras Model and an Adam optimizer.
    """
    inputs = layers.Input(shape=(img_size[0], img_size[1], 3), name="image")
    x = inputs
    if input_scale is not None and input_scale != 1.0:
        # Older TF releases only expose Rescaling under experimental.preprocessing.
        rescaling_cls = getattr(layers, "Rescaling", None) or layers.experimental.preprocessing.Rescaling
        x = rescaling_cls(input_scale, name="to_backbone_range")(x)
    # Resolve the backbone constructor by name so BACKBONE can be swapped freely.
    backbone_constructor = getattr(tf.keras.applications, backbone_name)
    # Passing the (rescaled) tensor as input_tensor keeps one flat graph whose
    # source input is still `inputs`.
    backbone = backbone_constructor(include_top=False, weights='imagenet', input_tensor=x, pooling='avg')
    features = layers.Dropout(dropout_rate)(backbone.output)
    outputs = layers.Dense(num_classes, activation='sigmoid', name='predictions')(features)
    model = models.Model(inputs=inputs, outputs=outputs)
    optimizer = optimizers.Adam(learning_rate=lr)
    return model, optimizer

# Instantiate the network; the arguments spelled out here equal build_model's defaults.
model, optimizer = build_model(
    img_size=IMG_SIZE,
    num_classes=NUM_CLASSES,
    dropout_rate=DROPOUT_RATE,
    backbone_name=BACKBONE,
    lr=LEARNING_RATE,
)
model.summary()

# -------------------------
# Custom weighted loss for multilabel: uses pos_weight per class
# -------------------------
# The model ends in a sigmoid, so this loss works on probabilities. For each
# (sample, class) element:
#   loss = -(pos_weight * y * log(p) + (1 - y) * log(1 - p))
# where p is clipped to [1e-7, 1 - 1e-7] so both logs stay finite. The result
# is averaged over classes, then over the batch.
pos_weight_tf = tf.constant(pos_weight, dtype=tf.float32)


def weighted_bce_loss(y_true, y_pred):
    """Positive-class-weighted binary cross-entropy on sigmoid probabilities.

    Args:
        y_true: multi-hot targets, presumably [batch, NUM_CLASSES] float32.
        y_pred: sigmoid probabilities with the same shape as y_true.

    Returns:
        Scalar tensor: mean over classes, then over the batch.
    """
    probs = tf.clip_by_value(y_pred, 1e-7, 1 - 1e-7)
    per_element = -(
        pos_weight_tf * y_true * tf.math.log(probs)
        + (1.0 - y_true) * tf.math.log(1.0 - probs)
    )
    per_sample = tf.reduce_mean(per_element, axis=1)
    return tf.reduce_mean(per_sample)

# Metrics monitored during training:
# - "auc_micro": labels/predictions flattened into one binary problem (micro
#   AUC). The callbacks monitor "val_auc_micro", so this name must not change.
# - "auc_macro": ROC AUC computed per label, then averaged.
# - "auc_{i}": ROC AUC of label i alone. With multi_label=True, the per-label
#   AUCs are averaged using label_weights, so a one-hot weight vector reduces
#   the average to exactly label i's AUC. (A plain AUC() would flatten all
#   labels and give every auc_{i} the same value.)
metric_list = [
    metrics.AUC(name="auc_micro", multi_label=False),
    metrics.AUC(name="auc_macro", multi_label=True, num_labels=NUM_CLASSES),
]
for i in range(NUM_CLASSES):
    one_hot = [1.0 if j == i else 0.0 for j in range(NUM_CLASSES)]
    metric_list.append(
        metrics.AUC(name=f"auc_{i}", multi_label=True, num_labels=NUM_CLASSES, label_weights=one_hot)
    )

# Compile
model.compile(optimizer=optimizer, loss=weighted_bce_loss, metrics=metric_list)

# -------------------------
# Callbacks
# -------------------------
checkpoint_path = "/home/sutirtha/anaconda3/sutirtha_research_operations/OCT_Data/OCT_layerwise_classification_dataset_15k/codes/models/efficientnetB4_with_normal_best_model.h5"

# All three callbacks track the same quantity: validation micro AUC, higher is better.
_monitor_kwargs = dict(monitor='val_auc_micro', mode='max', verbose=1)
cb = [
    # Keep only the best-scoring model on disk.
    callbacks.ModelCheckpoint(checkpoint_path, save_best_only=True, **_monitor_kwargs),
    # Stop after 6 epochs without improvement and roll back to the best weights.
    callbacks.EarlyStopping(patience=6, restore_best_weights=True, **_monitor_kwargs),
    # Halve the learning rate after 3 stagnant epochs, down to 1e-7.
    callbacks.ReduceLROnPlateau(factor=0.5, patience=3, min_lr=1e-7, **_monitor_kwargs),
]

# -------------------------
# Fit
# -------------------------
_fit_kwargs = dict(
    validation_data=val_ds,
    epochs=EPOCHS,
    callbacks=cb,
    verbose=2,  # one summary line per epoch
)
history = model.fit(train_ds, **_fit_kwargs)



# -------------------------
# Utility: function for inference on single images
# -------------------------
# def infer_image(model, image_path, threshold=0.5):
#     img = read_image(image_path)
#     img = tf.expand_dims(img, axis=0)
#     prob = model.predict(img)[0]
#     labels = [LABEL_COLUMNS[i] for i in range(NUM_CLASSES) if prob[i] >= threshold]
#     return prob, labels

# Example usage:
# probs, predicted_labels = infer_image(model, "/path/to/some/image.png", threshold=0.4)
# print(predicted_labels)



In [None]:
# -------------------------
# Plot training curves
# -------------------------
def plot_history(hist, metric='auc_micro'):
    """Plot loss and one metric (train vs. validation) from a Keras History.

    Curves whose keys are missing (e.g. no validation data) are skipped
    instead of raising KeyError.

    Args:
        hist: object with a ``history`` dict, such as the History from fit().
        metric: metric key for the right panel; its ``val_`` counterpart is
            plotted alongside when present.
    """
    dh = hist.history

    def _plot_pair(key):
        # Draw train_<key> and val_<key> curves; return True if anything was drawn.
        drawn = False
        if key in dh:
            plt.plot(dh[key], label=f'train_{key}')
            drawn = True
        if f'val_{key}' in dh:
            plt.plot(dh[f'val_{key}'], label=f'val_{key}')
            drawn = True
        if drawn:
            plt.legend()
        return drawn

    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    _plot_pair('loss')
    plt.title('Loss')
    plt.subplot(1, 2, 2)
    if _plot_pair(metric):
        # e.g. 'auc_micro' -> 'AUC micro'
        plt.title(metric.replace('auc', 'AUC').replace('_', ' '))
    plt.show()

plot_history(history)

# -------------------------
# Evaluate on test set & prediction generation
# -------------------------
print("Loading best weights from:", checkpoint_path)
try:
    model.load_weights(checkpoint_path)
except Exception as e:
    # Best effort: evaluation proceeds on the in-memory weights. These are the
    # best-epoch weights only if EarlyStopping restored them.
    print("Could not load checkpoint:", e)

# Predict on test (test_ds is unshuffled, so rows align with test_df).
y_pred_prob = model.predict(test_ds, verbose=1)
# Ground-truth label matrix, one row per test image.
y_true = test_df[LABEL_COLUMNS].to_numpy()
if y_pred_prob.shape != y_true.shape:
    raise ValueError(
        f"Prediction shape {y_pred_prob.shape} does not match label shape {y_true.shape}"
    )

# Compute per-class AUCs
per_class_auc = []
for i in range(NUM_CLASSES):
    try:
        auc_i = roc_auc_score(y_true[:, i], y_pred_prob[:, i])
    except ValueError:
        auc_i = float('nan')  # if only one class present in y_true
    per_class_auc.append(auc_i)
    print(f"Class {LABEL_COLUMNS[i]} AUC: {auc_i:.4f}")

# Overall micro/macro AUC
try:
    micro_auc = roc_auc_score(y_true.ravel(), y_pred_prob.ravel(), average='micro')
    macro_auc = np.nanmean(per_class_auc)
    print(f"Micro AUC (flattened): {micro_auc:.4f}, Macro AUC (mean of per-class): {macro_auc:.4f}")
except Exception as e:
    print("Error computing overall AUC:", e)

# Binarize predictions at THRESH for the classification report.
THRESH = 0.4
y_pred_bin = (y_pred_prob >= THRESH).astype(int)

print(f"Multilabel classification report (threshold={THRESH}):")
report = classification_report(y_true, y_pred_bin, target_names=LABEL_COLUMNS, zero_division=0)
print(report)

# Save predictions into test_df: pred_prob_{i} / pred_{i} follow LABEL_COLUMNS order.
test_df_preds = test_df.copy()
for i, c in enumerate(LABEL_COLUMNS):
    test_df_preds[f"pred_prob_{i}"] = y_pred_prob[:, i]
    test_df_preds[f"pred_{i}"] = y_pred_bin[:, i]
# Optionally save
test_df_preds.to_csv("test_predictions_with_probs.csv", index=False)
print("Saved predictions to test_predictions_with_probs.csv")

In [None]:
# ============================================
# Detailed Evaluation: Confusion Matrices per Label
# ============================================
import seaborn as sns
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support

# Per-label 2x2 confusion matrices in sklearn layout [[TN, FP], [FN, TP]].
conf_matrices = {}
for i, label in enumerate(LABEL_COLUMNS):
    # labels=[0, 1] forces a 2x2 matrix even when only one class occurs in both
    # y_true and y_pred for this label. Without it sklearn returns a 1x1 matrix,
    # and the counts would have been reported as all zeros.
    cm = confusion_matrix(y_true[:, i], y_pred_bin[:, i], labels=[0, 1])
    conf_matrices[label] = cm
    tn, fp, fn, tp = cm.ravel()
    print(f"\nLabel: {label}")
    print(f"Confusion Matrix:\n{cm}")
    print(f"TP={tp}, FP={fp}, TN={tn}, FN={fn}")
    # The 1e-8 keeps the ratios finite when a label has no positives/negatives.
    sensitivity = tp / (tp + fn + 1e-8)
    specificity = tn / (tn + fp + 1e-8)
    print(f"Sensitivity (Recall): {sensitivity:.4f}, Specificity: {specificity:.4f}")

# ============================================
# Classification Metrics Summary
# ============================================
precision, recall, f1, support = precision_recall_fscore_support(
    y_true, y_pred_bin, average=None, zero_division=0
)

summary_df = pd.DataFrame({
    'Label': LABEL_COLUMNS,
    'Precision': precision,
    'Recall': recall,
    'F1-score': f1,
    'Support': support
})

overall_precision, overall_recall, overall_f1, _ = precision_recall_fscore_support(
    y_true, y_pred_bin, average='micro', zero_division=0
)
print("\n================ Overall Micro-Average Metrics ================")
print(f"Precision: {overall_precision:.4f}, Recall: {overall_recall:.4f}, F1: {overall_f1:.4f}")

overall_precision, overall_recall, overall_f1, _ = precision_recall_fscore_support(
    y_true, y_pred_bin, average='macro', zero_division=0
)
print("\n================ Overall Macro-Average Metrics ================")
print(f"Precision: {overall_precision:.4f}, Recall: {overall_recall:.4f}, F1: {overall_f1:.4f}")

print("\nDetailed per-label metrics:")
display(summary_df)  # IPython/Jupyter display helper

# ============================================
# Plot Confusion Matrices as Heatmap Grid
# ============================================
num_labels = len(LABEL_COLUMNS)
cols = 4
rows = math.ceil(num_labels / cols)
fig, axes = plt.subplots(rows, cols, figsize=(20, 4 * rows))
# plt.subplots returns a 2-D array of axes when rows > 1 and a 1-D array
# otherwise; flattening gives one indexing scheme for both cases.
axes_flat = np.atleast_1d(axes).ravel()

for ax, label in zip(axes_flat, LABEL_COLUMNS):
    sns.heatmap(conf_matrices[label], annot=True, fmt='d', cmap='Blues', cbar=False, ax=ax)
    ax.set_title(label)
    ax.set_xlabel("Predicted")
    ax.set_ylabel("Actual")

# Hide empty subplots if any
for ax in axes_flat[num_labels:]:
    fig.delaxes(ax)

plt.tight_layout()
plt.show()
