<a href="https://colab.research.google.com/github/PankajaLaks/CWTernausResNet/blob/main/CWTernausResNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
"""
ResNet50-TernausNet classifier (fixed):
- Deterministic stratified 80/10/10 split
- Consistent preprocessing for train/val/test and single-image inference
- Class weights to mitigate imbalance
- ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
- Diagnostics: class counts, sample counts, confusion matrix + report
Requirements: tensorflow>=2.10, numpy, scikit-learn, matplotlib
"""
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import ResNet50
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support, roc_auc_score
import matplotlib.pyplot as plt
from pathlib import Path
import itertools

# ==============================
# Config
# ==============================
# All tunables for the script live here; the functions below read them either
# as default argument values or directly from main().
DATA_DIR = "/content/drive/MyDrive/DS"   # root folder; each immediate subfolder is one class
INPUT_SHAPE = (256, 256, 3)   # model input shape; INPUT_SHAPE[:2] is the resize target for every image
BATCH_SIZE  = 16              # batch size for the train/val/test tf.data pipelines
EPOCHS      = 30              # upper bound on epochs passed to model.fit
LR          = 1e-3            # learning rate handed to the Adam optimizer
SEED        = 42              # shared by the Keras seeding call, dataset shuffling and train_test_split
RESNET_WEIGHTS = "imagenet"   # forwarded as ResNet50(weights=...); use None for no pretrained weights
DROPOUT_RATE    = 0.25        # dropout applied after global average pooling; a falsy value disables it
FEATURE_STAGE   = "decoder"   # 'decoder' or 'bottleneck': which feature map is pooled for classification
FEATURE_CHANNELS = 64         # channel count of the final "features_conv" layer (decoder stage only)

# Seed once at import time so that later random ops start from a fixed state.
tf.keras.utils.set_random_seed(SEED)
# Let tf.data choose parallelism / prefetch depth for map() and prefetch().
AUTOTUNE = tf.data.AUTOTUNE

# ==============================
# Utilities: dataset creation (stratified splits)
# ==============================
def gather_image_paths_and_labels(data_dir, exts=(".jpg",".jpeg",".png",".bmp")):
    """Collect image paths and integer labels from a class-per-folder tree.

    Every immediate subdirectory of ``data_dir`` is one class. Class indices
    follow the sorted order of the subdirectories, and each class directory is
    searched recursively.

    Args:
        data_dir: Root directory (``str`` or ``Path``) containing one
            subdirectory per class.
        exts: File-name endings that mark a file as an image. Matching is
            case-insensitive, so ``photo.JPG`` is accepted for ``".jpg"``.

    Returns:
        A tuple ``(paths, labels, class_names)``:
        ``paths`` is a 1-D NumPy string array sorted lexicographically,
        ``labels`` is the matching 1-D NumPy integer array of class indices,
        ``class_names`` is the list of class directory names in index order.
        Both arrays are empty (but still correctly typed) if nothing matches.

    Raises:
        OSError: Propagated from ``Path.iterdir`` when ``data_dir`` cannot be
            listed (for example, it does not exist).
    """
    data_dir = Path(data_dir)
    class_dirs = [d for d in sorted(data_dir.iterdir()) if d.is_dir()]
    class_names = [d.name for d in class_dirs]

    # Lower-cased endings for a single case-insensitive endswith() test.
    endings = tuple(ext.lower() for ext in exts)

    paths = []
    labels = []
    for idx, class_dir in enumerate(class_dirs):
        # One walk per class (instead of one per extension); is_file() keeps
        # out directories that merely happen to be named like an image.
        for candidate in class_dir.rglob("*"):
            if candidate.name.lower().endswith(endings) and candidate.is_file():
                paths.append(str(candidate))
                labels.append(idx)

    # Explicit dtypes keep the arrays well-typed even when no file was found
    # (np.array([]) would otherwise default to float64).
    paths = np.array(paths, dtype=str)
    labels = np.array(labels, dtype=int)

    # Sort by path so the result does not depend on filesystem listing order;
    # shuffling is left to the caller (train_test_split with a fixed seed).
    order = np.argsort(paths)
    return paths[order], labels[order], class_names

def decode_and_resize(path, label, img_size=INPUT_SHAPE[:2], augment=False):
    """Load one image file as a float32 RGB tensor resized to ``img_size``.

    Args:
        path: Path of the image file (string or scalar string tensor).
        label: Label belonging to the image; returned unchanged.
        img_size: ``(height, width)`` passed to ``tf.image.resize``.
        augment: When true, apply a random horizontal flip, a small random
            brightness shift and, if ``tensorflow_addons`` is importable, a
            small random rotation. The result is clipped back to [0, 1].

    Returns:
        ``(img, label)`` where ``img`` is the float32 image tensor.
    """
    img = tf.io.read_file(path)
    img = tf.image.decode_image(img, channels=3, expand_animations=False)
    img = tf.image.convert_image_dtype(img, tf.float32)  # [0,1]
    img = tf.image.resize(img, img_size)
    if augment:
        img = tf.image.random_flip_left_right(img)
        img = tf.image.random_brightness(img, max_delta=0.05)

        # Optional small rotation. A module-level ``tfa`` is honoured if one
        # exists; otherwise the package is imported lazily here, because an
        # import done inside another function never becomes visible through
        # globals().
        tfa_module = globals().get("tfa")
        if tfa_module is None:
            try:
                import tensorflow_addons as tfa_module
            except Exception:
                # Best effort: a missing or incompatible tensorflow_addons
                # just means no rotation.
                tfa_module = None
        if tfa_module is not None:
            angle = tf.random.uniform([], -0.03, 0.03)
            img = tfa_module.image.rotate(img, angle)

        # The brightness shift can push values slightly outside [0, 1]; clip
        # so augmented inputs stay in the same range as un-augmented ones.
        img = tf.clip_by_value(img, 0.0, 1.0)
    return img, label

def make_dataset(paths, labels, batch_size=BATCH_SIZE, shuffle=False, augment=False):
    """Build a batched, prefetched ``tf.data`` pipeline of ``(image, label)`` pairs.

    Args:
        paths: Sequence of image file paths.
        labels: Sequence of labels aligned with ``paths``.
        batch_size: Number of samples per batch.
        shuffle: When true, shuffle with a buffer covering all samples,
            seeded with ``SEED`` and reshuffled on every iteration.
        augment: Forwarded to ``decode_and_resize``.

    Returns:
        A ``tf.data.Dataset`` yielding batches of decoded images and labels.
    """
    def _load(image_path, image_label):
        # Per-sample loader; closes over the augmentation switch.
        return decode_and_resize(image_path, image_label, augment=augment)

    dataset = tf.data.Dataset.from_tensor_slices((paths, labels))
    if shuffle:
        # Shuffle the cheap (path, label) pairs before any image is decoded.
        dataset = dataset.shuffle(
            buffer_size=len(paths),
            seed=SEED,
            reshuffle_each_iteration=True,
        )
    dataset = dataset.map(_load, num_parallel_calls=AUTOTUNE)
    dataset = dataset.batch(batch_size)
    return dataset.prefetch(AUTOTUNE)

# ==============================
# Build model (ResNet50 + Ternaus-style decoder -> GAP -> classifier)
# ==============================
def conv_block(x, filters, name=None):
    """Apply two consecutive 3x3 Conv -> BatchNorm -> ReLU stages to ``x``.

    Args:
        x: Input tensor.
        filters: Number of output channels of both convolutions.
        name: Optional prefix; layers are named ``<name>_conv1``,
            ``<name>_bn1``, ``<name>_relu1`` and likewise with suffix ``2``.
            With a falsy ``name`` Keras picks the layer names.

    Returns:
        The output tensor of the second ReLU.
    """
    def _layer_name(suffix):
        # Explicit names only when a prefix was supplied.
        return name + suffix if name else None

    for stage in ("1", "2"):
        x = layers.Conv2D(
            filters, 3, padding="same", use_bias=False,
            name=_layer_name("_conv" + stage),
        )(x)
        x = layers.BatchNormalization(name=_layer_name("_bn" + stage))(x)
        x = layers.Activation("relu", name=_layer_name("_relu" + stage))(x)
    return x

def up_block(x, skip, filters, name=None):
    """Upsample ``x`` by 2, concatenate it with ``skip`` and run a ``conv_block``.

    Args:
        x: Tensor coming from the deeper decoder stage.
        skip: Encoder tensor to merge in; it must be concatenable with the
            upsampled ``x``.
        filters: Channel count of the following ``conv_block``.
        name: Optional prefix for the layer names (``<name>_up``,
            ``<name>_concat``, ``<name>_convblock_*``).

    Returns:
        The output tensor of the convolution block.
    """
    if name:
        up_name = name + "_up"
        concat_name = name + "_concat"
        block_name = name + "_convblock"
    else:
        up_name = concat_name = block_name = None

    upsampled = layers.UpSampling2D((2, 2), interpolation="bilinear", name=up_name)(x)
    merged = layers.Concatenate(name=concat_name)([upsampled, skip])
    return conv_block(merged, filters, name=block_name)

def build_resnet50_feature_extractor(input_shape=INPUT_SHAPE, encoder_weights="imagenet",
                                     feature_stage="decoder", feature_channels=FEATURE_CHANNELS):
    """Build the ResNet50 encoder plus bottleneck and, optionally, the decoder.

    Args:
        input_shape: Input shape handed to ``ResNet50``.
        encoder_weights: Value for ``ResNet50(weights=...)``.
        feature_stage: ``"bottleneck"`` returns the model right after the two
            1024-channel bottleneck convolutions; ``"decoder"`` additionally
            runs the four skip-connected up blocks, one more 2x upsampling
            block and the final ``features_conv`` layer.
        feature_channels: Output channels of ``features_conv`` (used only for
            the ``"decoder"`` stage).

    Returns:
        A Keras ``Model`` mapping the ResNet50 input to the selected feature map.

    Raises:
        AssertionError: If ``feature_stage`` is not one of the two known values.
    """
    assert feature_stage in ("decoder", "bottleneck")
    encoder = ResNet50(include_top=False, weights=encoder_weights, input_shape=input_shape)

    # Encoder activations used as skip connections, listed shallow -> deep.
    skip_layer_names = (
        "conv1_relu",
        "conv2_block3_out",
        "conv3_block4_out",
        "conv4_block6_out",
    )
    skips = [encoder.get_layer(layer_name).output for layer_name in skip_layer_names]
    deepest = encoder.get_layer("conv5_block3_out").output

    # Bottleneck: two Conv-BN-ReLU stages named bottleneck_conv1 ... bottleneck_relu2.
    x = conv_block(deepest, 1024, name="bottleneck")

    if feature_stage == "bottleneck":
        return models.Model(encoder.input, x, name="ResNet50_Ternaus_Bottleneck")

    # Decoder: walk the skips deep -> shallow, shrinking the channel count.
    decoder_plan = (("dec4", 512), ("dec3", 256), ("dec2", 128), ("dec1", 64))
    for (stage_name, stage_filters), skip in zip(decoder_plan, reversed(skips)):
        x = up_block(x, skip, stage_filters, name=stage_name)

    # Final upsampling stage has no skip connection.
    x = layers.UpSampling2D((2, 2), interpolation="bilinear", name="dec0_up")(x)
    x = conv_block(x, 64, name="dec0_convblock")
    feats = layers.Conv2D(feature_channels, 3, padding="same", activation="relu",
                          name="features_conv")(x)
    return models.Model(encoder.input, feats, name="ResNet50_Ternaus_Features")

def build_resnet50_ternaus_classifier(input_shape=INPUT_SHAPE, num_classes=3, encoder_weights="imagenet",
                                      feature_stage="decoder", feature_channels=64, dropout_rate=0.25):
    """Put a pooled softmax classification head on the Ternaus-style feature extractor.

    Args:
        input_shape: Input shape of the network.
        num_classes: Number of units of the softmax output layer.
        encoder_weights: Forwarded to the feature extractor.
        feature_stage: Forwarded to the feature extractor.
        feature_channels: Forwarded to the feature extractor.
        dropout_rate: Dropout rate applied after pooling; a falsy or
            non-positive value leaves the dropout layer out.

    Returns:
        A Keras ``Model`` that outputs class probabilities.
    """
    backbone = build_resnet50_feature_extractor(
        input_shape, encoder_weights, feature_stage, feature_channels
    )
    pooled = layers.GlobalAveragePooling2D(name="gap")(backbone.output)

    use_dropout = bool(dropout_rate) and dropout_rate > 0
    head_input = (
        layers.Dropout(dropout_rate, name="cls_dropout")(pooled) if use_dropout else pooled
    )

    probabilities = layers.Dense(num_classes, activation="softmax", name="classifier")(head_input)
    return models.Model(backbone.input, probabilities, name="ResNet50_Ternaus_Classifier")

# ==============================
# Evaluation helper
# ==============================
def evaluate_and_print(model, dataset, class_names):
    """Print a classification report and metrics, then plot the confusion matrix.

    Args:
        model: Model exposing ``predict_on_batch`` and returning one score
            per class for every sample.
        dataset: Iterable of ``(inputs, labels)`` batches with integer labels.
        class_names: Class names in label-index order.

    Returns:
        None. Results are printed and shown as a matplotlib figure.

    Raises:
        ValueError: If ``dataset`` yields no batches.
    """
    # Collect labels and predictions in the SAME pass. Iterating the dataset
    # twice would misalign them whenever the dataset reshuffles per iteration,
    # and would decode every image twice.
    true_chunks = []
    proba_chunks = []
    for x_batch, y_batch in dataset:
        true_chunks.append(y_batch.numpy())
        proba_chunks.append(np.asarray(model.predict_on_batch(x_batch)))
    if not true_chunks:
        raise ValueError("Cannot evaluate: the dataset yielded no batches.")

    y_true = np.concatenate(true_chunks)
    y_proba = np.concatenate(proba_chunks)
    y_pred = np.argmax(y_proba, axis=1)

    # Pin the label set so the report and the matrix always cover every class,
    # including classes absent from this split; otherwise target_names and the
    # axis ticks would not match the data.
    label_ids = np.arange(len(class_names))

    print("\nClassification Report:")
    print(classification_report(y_true, y_pred, labels=label_ids,
                                target_names=class_names, digits=4, zero_division=0))
    cm = confusion_matrix(y_true, y_pred, labels=label_ids)
    print("Confusion Matrix:\n", cm)
    acc = (y_true == y_pred).mean()
    prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="macro", zero_division=0)
    print(f"Accuracy: {acc:.4f} | Macro Precision: {prec:.4f} | Macro Recall: {rec:.4f} | Macro F1: {f1:.4f}")

    # Plot confusion matrix
    plt.figure(figsize=(6,5))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title("Confusion matrix")
    plt.colorbar()
    tick_marks = np.arange(len(class_names))
    plt.xticks(tick_marks, class_names, rotation=45)
    plt.yticks(tick_marks, class_names)
    # Cells darker than half the maximum get white text for contrast.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], 'd'),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    plt.show()

# ==============================
# Single image inference
# ==============================
def preprocess_image_path(img_path, img_size=INPUT_SHAPE[:2]):
    """Load a single image file as a model-ready NumPy batch of size one.

    The image goes through ``decode_and_resize`` with augmentation disabled,
    i.e. the very function the dataset pipelines use, so inference-time
    preprocessing cannot drift from the one used for training/validation.

    Args:
        img_path: Path of the image file.
        img_size: ``(height, width)`` to resize to.

    Returns:
        A float32 NumPy array with a leading batch axis of length 1.
    """
    # The label slot is irrelevant here; 0 is a placeholder that is discarded.
    img, _ = decode_and_resize(img_path, 0, img_size=img_size, augment=False)
    return tf.expand_dims(img, axis=0).numpy()

def predict_single_image(model, img_path, class_names):
    """Classify one image file and print the per-class probabilities.

    Args:
        model: Model whose ``predict`` returns one probability row per input.
        img_path: Path of the image to classify.
        class_names: Class names in label-index order.

    Returns:
        ``(idx, proba)``: the index of the highest-scoring class and the
        probability vector for the image.
    """
    batch = preprocess_image_path(img_path)
    # predict() returns one row per batch element; the batch holds one image.
    proba = model.predict(batch, verbose=0)[0]
    idx = int(np.argmax(proba))

    print(f"\nImage: {img_path}")
    for class_pos, cname in enumerate(class_names):
        score = proba[class_pos]
        print(f"  {cname}: {score:.4f}")
    winner = class_names[idx]
    print(f"Predicted: {winner} (index={idx})")
    return idx, proba

# ==============================
# Main
# ==============================
def main():
    """Run the full pipeline: split, train, evaluate, save, example inference.

    Steps:
      1. Gather image paths/labels from ``DATA_DIR`` and validate them.
      2. Stratified 80/10/10 train/val/test split with a fixed seed.
      3. Build the tf.data pipelines (augmentation on the training set only).
      4. Compute balanced class weights from the training labels.
      5. Build, compile and train the classifier with checkpointing,
         early stopping and learning-rate reduction.
      6. Evaluate on the test set and print/plot diagnostics.
      7. Save the model, then run one example prediction if the example
         image exists.

    Raises:
        ValueError: If fewer than two class folders are found or a class
            folder contains no images.
    """
    # decode_and_resize checks for a MODULE-level name ``tfa``; a plain
    # function-local import would be invisible to it, so bind it globally.
    global tfa

    # Gather paths + labels
    paths, labels, class_names = gather_image_paths_and_labels(DATA_DIR)
    if len(class_names) < 2:
        raise ValueError("Need at least 2 classes in subfolders.")
    print("Detected classes (alphabetical):", class_names)

    # Fail early and clearly on empty class folders: they would otherwise
    # surface later as obscure stratification / class-weight errors.
    per_class = np.bincount(labels.astype(int), minlength=len(class_names))
    empty_classes = [name for name, count in zip(class_names, per_class) if count == 0]
    if empty_classes:
        raise ValueError(f"No images found for class folder(s): {empty_classes}")

    # Stratified split: train 80%, pool 20% -> then pool split 50/50 for val/test (10/10)
    p_train, p_pool, y_train, y_pool = train_test_split(paths, labels, test_size=0.2, stratify=labels, random_state=SEED)
    p_val, p_test, y_val, y_test = train_test_split(p_pool, y_pool, test_size=0.5, stratify=y_pool, random_state=SEED)

    print(f"Samples: total={len(paths)}, train={len(p_train)}, val={len(p_val)}, test={len(p_test)}")
    print("Class distribution in training set:", np.bincount(y_train))

    # Optional dependency used for the random-rotation augmentation.
    try:
        import tensorflow_addons as tfa  # noqa: F401
    except Exception as exc:
        # Best effort: training still works, just without rotation.
        print(f"tensorflow_addons unavailable ({exc!r}); random rotation is skipped.")

    # Build tf.data datasets with identical preprocessing
    # (augmentation only on train)
    train_ds = make_dataset(p_train, y_train, batch_size=BATCH_SIZE, shuffle=True, augment=True)
    val_ds   = make_dataset(p_val,   y_val,   batch_size=BATCH_SIZE, shuffle=False, augment=False)
    test_ds  = make_dataset(p_test,  y_test,  batch_size=BATCH_SIZE, shuffle=False, augment=False)

    # Class weights
    classes = np.unique(y_train)
    cw = compute_class_weight(class_weight='balanced', classes=classes, y=y_train)
    class_weight = {int(c): float(w) for c, w in zip(classes, cw)}
    print("Class weights:", class_weight)

    # Build & compile model
    num_classes = len(class_names)
    model = build_resnet50_ternaus_classifier(
        input_shape=INPUT_SHAPE,
        num_classes=num_classes,
        encoder_weights=RESNET_WEIGHTS,
        feature_stage=FEATURE_STAGE,
        feature_channels=FEATURE_CHANNELS,
        dropout_rate=DROPOUT_RATE
    )
    model.compile(optimizer=tf.keras.optimizers.Adam(LR),
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])
    model.summary()

    # Callbacks
    callbacks = [
        tf.keras.callbacks.ModelCheckpoint("best_resnet_ternaus.keras", monitor="val_loss", save_best_only=True),
        tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=7, restore_best_weights=True),
        tf.keras.callbacks.ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=3, verbose=1)
    ]

    # Train
    model.fit(train_ds, validation_data=val_ds, epochs=EPOCHS,
              callbacks=callbacks, class_weight=class_weight)

    # Evaluate
    print("\nEvaluating on test set…")
    test_loss, test_acc = model.evaluate(test_ds, verbose=0)
    print(f"Test Loss: {test_loss:.4f} | Test Accuracy: {test_acc:.4f}")
    evaluate_and_print(model, test_ds, class_names)

    # Save BEFORE the example inference: a missing example image must not
    # cost the user the freshly trained model.
    model.save("/content/drive/MyDrive/res_ternausnet.keras")
    model.save("/content/drive/MyDrive/res_ternausnet.h5")
    print("Saved model to /content/drive/MyDrive/res_ternausnet.keras")

    # Example single image inference (change path as needed)
    example_image = "/content/drive/MyDrive/NORMAL2-IM-1442-0001.jpeg"
    if os.path.isfile(example_image):
        predict_single_image(model, example_image, class_names)
    else:
        print(f"Example image not found, skipping single-image inference: {example_image}")

if __name__ == "__main__":
    main()