In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/apple-dataset/ATLDSD/Healthy leaf/label/IMG_20190726_195734.png
/kaggle/input/apple-dataset/ATLDSD/Healthy leaf/label/IMG_20190726_195801.png
/kaggle/input/apple-dataset/ATLDSD/Healthy leaf/label/IMG_20190726_192237.png
/kaggle/input/apple-dataset/ATLDSD/Healthy leaf/label/IMG_20190726_192312.png
/kaggle/input/apple-dataset/ATLDSD/Healthy leaf/label/IMG_20190726_191605.png
/kaggle/input/apple-dataset/ATLDSD/Healthy leaf/label/IMG_20190726_193039.png
/kaggle/input/apple-dataset/ATLDSD/Healthy leaf/label/IMG_20190726_194148.png
/kaggle/input/apple-dataset/ATLDSD/Healthy leaf/label/IMG_20190726_195829.png
/kaggle/input/apple-dataset/ATLDSD/Healthy leaf/label/IMG_20190726_194413.png
/kaggle/input/apple-dataset/ATLDSD/Healthy leaf/label/IMG_20190726_192014.png
/kaggle/input/apple-dataset/ATLDSD/Healthy leaf/label/IMG_20190726_190848.png
/kaggle/input/apple-dataset/ATLDSD/Healthy leaf/label/IMG_20190726_193631.png
/kaggle/input/apple-dataset/ATLDSD/Healthy leaf/label/IMG_201907

In [1]:
print ("hello ' world'")

hello ' world'


In [2]:
# =========================================================
# Apple Leaf Segmentation - MULTI-MODEL COMPARISON
# UNet, DeepLabV3+, FCN, SegNet, BiSeNetV2
# Dice + Weighted CE, IoU(no-bg), tf.data Aug, Curves, Comparison
# Adds visualization: Image | GT | Pred + per-class disease severity (% of leaf)
# =========================================================
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
os.environ["TF_DISABLE_PROFILER"] = "1"

import cv2
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import pandas as pd
import time
import random

# ============ Config ============
BASE_DIR    = "/kaggle/input/apple-dataset/ATLDSD"   # <--- change if needed
IMG_SIZE    = 256
BATCH_SIZE  = 8
EPOCHS      = 15
SEED        = 2025

# ==== Resume / Skip settings =====
OUTDIR = "outputs"
SKIP_TRAIN_IF_CKPT = True         # if checkpoint exists, skip training that model
RUN_ONLY = None                    # e.g. ['DeepLabV3Plus','BiSeNetV2'] or None
START_AT = None                    # e.g. 'FCN' to skip earlier models

random.seed(SEED); np.random.seed(SEED); tf.random.set_seed(SEED)

CLASS_NAMES = [
    "Background",           # 0
    "Healthy",              # 1
    "Brown spot",           # 2
    "Alternaria leaf spot", # 3
    "Gray spot",            # 4
    "Rust"                  # 5
]
NUM_CLASSES = len(CLASS_NAMES)

# exact RGB colors -> class index
COLOR_MAP = {
    (0,   0,   0): 0,   # Background
    (128, 0,   0): 1,   # Healthy
    (128, 0, 128): 2,   # Brown spot (purple)
    (128,128,  0): 3,   # Alternaria (olive)
    (0,   0, 128): 4,   # Gray (blue)
    (0, 128,   0): 5,   # Rust (green)
}
CLASS_WEIGHTS = tf.constant([0.25, 0.7, 1.1, 1.1, 1.1, 1.2], dtype=tf.float32)

# augmentation knobs
A_ROT90_PROB   = 0.75
A_FLIP_H_PROB  = 0.5
A_FLIP_V_PROB  = 0.5
A_JITTER_PROB  = 0.6
A_NOISE_PROB   = 0.3
A_CROP_PROB    = 0.6
CROP_MIN_FRAC  = 0.85

# ============ Utils ============
def set_gpu_growth():
    try:
        gpus = tf.config.list_physical_devices('GPU')
        if gpus:
            for g in gpus:
                tf.config.experimental.set_memory_growth(g, True)
            print(f"✅ GPU found: {len(gpus)}; memory growth enabled")
        else:
            print("ℹ️  No GPU detected; running on CPU")
    except Exception as e:
        print("⚠️  GPU mem-growth not set:", e)
set_gpu_growth()

def rgb_mask_to_classes(mask_rgb):
    out = np.zeros(mask_rgb.shape[:2], dtype=np.uint8)
    R, G, B = mask_rgb[...,0], mask_rgb[...,1], mask_rgb[...,2]
    for (r,g,b), cls in COLOR_MAP.items():
        m = (R == r) & (G == g) & (B == b)
        out[m] = cls
    return out

PALETTE = {
    0:(0,0,0), 1:(128,0,0), 2:(128,0,128), 3:(128,128,0), 4:(0,0,128), 5:(0,128,0)
}
def mask_to_color(mask):
    h,w = mask.shape
    out = np.zeros((h,w,3), dtype=np.uint8)
    for c, col in PALETTE.items():
        out[mask==c] = col
    return out

# ===================== Severity Utils & Visualization =====================

def compute_severity_percentages(mask_int):
    """
    Compute severity % per class with respect to LEAF area (non-background).
    mask_int: [H,W] uint8 class map (0=background, 1=healthy, 2..=diseases)
    Returns: (per_class_dict, healthy_pct, disease_total_pct)
    """
    m = np.asarray(mask_int, dtype=np.uint8)
    leaf = (m != 0)
    leaf_pixels = int(leaf.sum())
    if leaf_pixels == 0:
        # No leaf pixels; avoid div by zero
        per_class = {CLASS_NAMES[c]: 0.0 for c in range(2, NUM_CLASSES)}
        healthy_pct = 0.0
        disease_total = 0.0
        return per_class, healthy_pct, disease_total

    per_class = {}
    for c in range(2, NUM_CLASSES):
        per_class[CLASS_NAMES[c]] = 100.0 * float((m == c).sum()) / leaf_pixels

    healthy_pct = 100.0 * float((m == 1).sum()) / leaf_pixels
    disease_total = 100.0 - healthy_pct
    return per_class, healthy_pct, disease_total


def _box_text_from_severity(per_class, healthy_pct, disease_total):
    """
    Build a neat multi-line text block for overlay.
    """
    lines = [f"Healthy: {healthy_pct:5.1f}%",
             f"Total disease: {disease_total:5.1f}%"]
    # List each disease class on new line
    for name, pct in per_class.items():
        lines.append(f"{name}: {pct:5.1f}%")
    return "\n".join(lines)


def visualize_with_severity(model, Xv, Yv_int, n=4, outdir="outputs", seed=2025):
    """
    Show Image | GT | Pred with severity (%) boxes.
    - model: trained Keras model
    - Xv:   [N,H,W,3] float32 images (0..1)
    - Yv_int: [N,H,W] uint8 integer masks (class ids)
    - n: how many samples to draw
    """
    np.random.seed(seed)
    os.makedirs(outdir, exist_ok=True)
    idx = np.random.choice(len(Xv), size=min(n, len(Xv)), replace=False)

    fig, axs = plt.subplots(len(idx), 3, figsize=(11, 3.6*len(idx)))
    if len(idx) == 1:
        axs = np.expand_dims(axs, 0)

    for r, i in enumerate(idx):
        img = Xv[i]
        gt  = Yv_int[i].astype(np.uint8)

        # Predict
        pr  = model.predict(img[None], verbose=0)[0]       # [H,W,C]
        pm  = np.argmax(pr, axis=-1).astype(np.uint8)      # [H,W]

        # Compute severities (GT & Pred)
        gt_per, gt_healthy, gt_dis = compute_severity_percentages(gt)
        pr_per, pr_healthy, pr_dis = compute_severity_percentages(pm)

        # Left: Image
        axs[r,0].imshow(img)
        axs[r,0].set_title("Image")
        axs[r,0].axis('off')

        # Middle: GT + severity box
        axs[r,1].imshow(mask_to_color(gt))
        axs[r,1].set_title("Ground Truth")
        axs[r,1].axis('off')
        gt_txt = _box_text_from_severity(gt_per, gt_healthy, gt_dis)
        axs[r,1].text(
            0.02, 0.98, gt_txt,
            transform=axs[r,1].transAxes,
            va='top', ha='left',
            fontsize=9,
            bbox=dict(facecolor='white', alpha=0.75, edgecolor='black', boxstyle='round,pad=0.4')
        )

        # Right: Pred + severity box
        axs[r,2].imshow(mask_to_color(pm))
        axs[r,2].set_title(f"Predicted ({model.name})")
        axs[r,2].axis('off')
        pr_txt = _box_text_from_severity(pr_per, pr_healthy, pr_dis)
        axs[r,2].text(
            0.02, 0.98, pr_txt,
            transform=axs[r,2].transAxes,
            va='top', ha='left',
            fontsize=9,
            bbox=dict(facecolor='white', alpha=0.75, edgecolor='black', boxstyle='round,pad=0.4')
        )

    plt.tight_layout()
    save_path = os.path.join(outdir, f"viz_with_severity_{model.name}.png")
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"🖼️  Saved severity visualization: {save_path}")
    return save_path

# ============ Dataset ============
class AppleLeafDataset:
    IMG_EXTS = (".png",".jpg",".jpeg",".bmp",".tif",".tiff")
    def __init__(self, base_dir, image_size=256):
        self.base_dir = base_dir
        self.image_size = image_size
        self.image_paths, self.mask_paths = self._discover_pairs()
        print(f"✅ Paired samples: {len(self.image_paths)}")

    def _list_images(self, d):
        acc = []
        for r,_,fs in os.walk(d):
            for f in fs:
                if f.lower().endswith(self.IMG_EXTS):
                    acc.append(os.path.join(r,f))
        return acc

    def _discover_pairs(self):
        imgs, msks = [], []
        if not os.path.exists(self.base_dir):
            raise FileNotFoundError(f"Base dir not found: {self.base_dir}")
        for cls_folder in sorted(os.listdir(self.base_dir)):
            cpath = os.path.join(self.base_dir, cls_folder)
            if not os.path.isdir(cpath): continue
            img_dir = os.path.join(cpath, "image")
            msk_dir = os.path.join(cpath, "label")
            if not (os.path.exists(img_dir) and os.path.exists(msk_dir)): continue

            img_files = self._list_images(img_dir)
            msk_files = self._list_images(msk_dir)
            img_by = {os.path.splitext(os.path.basename(p))[0].lower(): p for p in img_files}
            msk_by = {os.path.splitext(os.path.basename(p))[0].lower(): p for p in msk_files}
            common = sorted(set(img_by) & set(msk_by))
            print(f"📂 {cls_folder:20} | imgs:{len(img_files):4d} | masks:{len(msk_files):4d} | paired:{len(common):4d}")
            for s in common: imgs.append(img_by[s]); msks.append(msk_by[s])
        return imgs, msks

    def load(self):
        X, Y = [], []
        for ip, mp in tqdm(list(zip(self.image_paths, self.mask_paths)), desc="Loading data"):
            img = cv2.imread(ip, cv2.IMREAD_COLOR)
            if img is None: continue
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = cv2.resize(img, (self.image_size, self.image_size), interpolation=cv2.INTER_AREA)
            img = img.astype(np.float32)/255.0

            msk = cv2.imread(mp, cv2.IMREAD_COLOR)
            if msk is None: continue
            msk = cv2.cvtColor(msk, cv2.COLOR_BGR2RGB)
            msk = cv2.resize(msk, (self.image_size, self.image_size), interpolation=cv2.INTER_NEAREST)
            msk = rgb_mask_to_classes(msk)

            X.append(img); Y.append(msk)
        return np.asarray(X, np.float32), np.asarray(Y, np.uint8)

# ============ Augmentation (tf.data) ============
def augment_img_mask(img, mask):
    img = tf.cast(img, tf.float32)
    mask = tf.cast(mask, tf.int32)
    mask_3d = tf.expand_dims(mask, axis=-1)

    def apply_transform(transform_func, prob):
        return tf.cond(
            tf.random.uniform([]) < prob,
            lambda: transform_func(img, mask_3d),
            lambda: (img, mask_3d)
        )

    def rot90_transform(i, m):
        k = tf.random.uniform([], 0, 4, dtype=tf.int32)
        return tf.image.rot90(i, k), tf.image.rot90(m, k)

    img, mask_3d = apply_transform(rot90_transform, A_ROT90_PROB)

    def flip_h(i, m): return tf.image.flip_left_right(i), tf.image.flip_left_right(m)
    def flip_v(i, m): return tf.image.flip_up_down(i), tf.image.flip_up_down(m)
    img, mask_3d = apply_transform(flip_h, A_FLIP_H_PROB)
    img, mask_3d = apply_transform(flip_v, A_FLIP_V_PROB)

    def crop_transform(i, m):
        shape = tf.shape(i)
        h, w = shape[0], shape[1]
        frac = tf.random.uniform([], CROP_MIN_FRAC, 1.0)
        nh = tf.cast(tf.cast(h, tf.float32) * frac, tf.int32)
        nw = tf.cast(tf.cast(w, tf.float32) * frac, tf.int32)
        nh = tf.minimum(nh, h); nw = tf.minimum(nw, w)
        max_y = tf.maximum(1, h - nh); max_x = tf.maximum(1, w - nw)
        oy = tf.random.uniform([], 0, max_y, dtype=tf.int32)
        ox = tf.random.uniform([], 0, max_x, dtype=tf.int32)
        i_crop = tf.image.crop_to_bounding_box(i, oy, ox, nh, nw)
        m_crop = tf.image.crop_to_bounding_box(m, oy, ox, nh, nw)

        i_resized = tf.image.resize(i_crop, [h, w], method='bilinear')

        # ensure mask resize is dtype-safe
        m_crop_f = tf.cast(m_crop, tf.float32)
        m_resized_f = tf.image.resize(m_crop_f, [h, w], method='nearest')
        m_resized = tf.cast(tf.round(m_resized_f), tf.int32)

        return i_resized, m_resized

    img, mask_3d = apply_transform(crop_transform, A_CROP_PROB)

    def apply_photometric(i):
        if tf.random.uniform([]) < A_JITTER_PROB:
            i = tf.image.random_brightness(i, 0.15)
            i = tf.image.random_contrast(i, 0.8, 1.2)
            i = tf.image.random_saturation(i, 0.8, 1.2)
            i = tf.image.random_hue(i, 0.02)
            i = tf.clip_by_value(i, 0.0, 1.0)
        if tf.random.uniform([]) < A_NOISE_PROB:
            noise = tf.random.normal(tf.shape(i), 0.0, 0.02, dtype=tf.float32)
            i = tf.clip_by_value(i + noise, 0.0, 1.0)
        return i

    img = apply_photometric(img)
    mask = tf.squeeze(mask_3d, axis=-1)
    return img, mask

def one_hot(mask):
    return tf.one_hot(tf.cast(mask, tf.int32), depth=NUM_CLASSES)

def make_dataset(X, Y, batch_size=8, shuffle=False, augment=False):
    ds = tf.data.Dataset.from_tensor_slices((X, Y))
    if shuffle:
        ds = ds.shuffle(min(len(X), 1024), reshuffle_each_iteration=True)
    def process(img, mask):
        img = tf.cast(img, tf.float32)
        mask = tf.cast(mask, tf.int32)
        if augment:
            img, mask = augment_img_mask(img, mask)
        return img, one_hot(mask)
    return ds.map(process, num_parallel_calls=tf.data.AUTOTUNE)\
             .batch(batch_size)\
             .prefetch(tf.data.AUTOTUNE)

# ============ Models ============

# 1) UNet
def build_unet(input_shape, num_classes, base=48, drop=0.15):
    x_in = keras.Input(shape=input_shape)
    def blk(x,f):
        x = layers.Conv2D(f,3,padding='same',activation='relu')(x)
        x = layers.BatchNormalization()(x)
        x = layers.Conv2D(f,3,padding='same',activation='relu')(x)
        x = layers.BatchNormalization()(x)
        x = layers.SpatialDropout2D(drop)(x)
        return x

    c1 = blk(x_in, base);        p1 = layers.MaxPooling2D(2)(c1)      # 256->128
    c2 = blk(p1, base*2);        p2 = layers.MaxPooling2D(2)(c2)      # 128->64
    c3 = blk(p2, base*4);        p3 = layers.MaxPooling2D(2)(c3)      # 64->32
    c4 = blk(p3, base*8);        p4 = layers.MaxPooling2D(2)(c4)      # 32->16
    bn = blk(p4, base*16)

    u6 = layers.Conv2DTranspose(base*8,2,2,padding='same')(bn)         # 16->32
    u6 = layers.Concatenate()([u6,c4]); c6 = blk(u6, base*8)

    u7 = layers.Conv2DTranspose(base*4,2,2,padding='same')(c6)         # 32->64
    u7 = layers.Concatenate()([u7,c3]); c7 = blk(u7, base*4)

    u8 = layers.Conv2DTranspose(base*2,2,2,padding='same')(c7)         # 64->128
    u8 = layers.Concatenate()([u8,c2]); c8 = blk(u8, base*2)

    u9 = layers.Conv2DTranspose(base,2,2,padding='same')(c8)           # 128->256
    u9 = layers.Concatenate()([u9,c1]); c9 = blk(u9, base)

    out = layers.Conv2D(num_classes,1,activation='softmax')(c9)
    return keras.Model(x_in, out, name="UNet")

# 2) DeepLabV3+ (fixed pooling resize)
def build_deeplabv3plus(input_shape, num_classes):
    def aspp(x):
        dims = x.shape[-1]
        h, w = x.shape[1], x.shape[2]  # ints for fixed input

        # Image-level pooling branch (1x1)
        pool = layers.GlobalAveragePooling2D()(x)
        pool = layers.Reshape((1, 1, dims))(pool)
        pool = layers.Conv2D(256, 1, activation='relu', padding='same')(pool)
        pool = layers.BatchNormalization()(pool)
        # from 1x1 -> hxw using factor upsampling
        pool = layers.UpSampling2D(size=(h, w), interpolation='bilinear')(pool)

        # 1x1 conv branch
        conv1 = layers.Conv2D(256, 1, activation='relu', padding='same')(x)
        conv1 = layers.BatchNormalization()(conv1)

        # Atrous branches
        c6  = layers.Conv2D(256, 3, activation='relu', padding='same', dilation_rate=6)(x);  c6  = layers.BatchNormalization()(c6)
        c12 = layers.Conv2D(256, 3, activation='relu', padding='same', dilation_rate=12)(x); c12 = layers.BatchNormalization()(c12)
        c18 = layers.Conv2D(256, 3, activation='relu', padding='same', dilation_rate=18)(x); c18 = layers.BatchNormalization()(c18)

        y = layers.Concatenate()([pool, conv1, c6, c12, c18])
        y = layers.Conv2D(256, 1, activation='relu', padding='same')(y)
        y = layers.BatchNormalization()(y)
        return y

    inputs = keras.Input(shape=input_shape)

    # lightweight backbone
    x = layers.Conv2D(32, 3, strides=2, padding='same', activation='relu')(inputs)  # 256->128
    x = layers.BatchNormalization()(x)
    x = layers.Conv2D(64, 3, padding='same', activation='relu')(x)                  # 128->128
    x = layers.BatchNormalization()(x)
    low = x  # 128x128

    x = layers.Conv2D(128, 3, strides=2, padding='same', activation='relu')(x)      # 128->64
    x = layers.BatchNormalization()(x)
    x = layers.Conv2D(256, 3, padding='same', activation='relu')(x)                 # 64->64
    x = layers.BatchNormalization()(x)

    x = layers.Conv2D(512, 3, strides=2, padding='same', activation='relu')(x)      # 64->32
    x = layers.BatchNormalization()(x)                                              # 32x32

    # ASPP + decoder
    x = aspp(x)                                                                     # 32x32
    x = layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(x)               # 32->128
    low = layers.Conv2D(48, 1, activation='relu', padding='same')(low)
    low = layers.BatchNormalization()(low)
    x = layers.Concatenate()([x, low])                                              # 128
    x = layers.Conv2D(256, 3, activation='relu', padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Conv2D(256, 3, activation='relu', padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(x)               # 128->256

    outputs = layers.Conv2D(num_classes, 1, activation='softmax')(x)
    return keras.Model(inputs, outputs, name="DeepLabV3Plus")

# 3) FCN
def build_fcn(input_shape, num_classes):
    inputs = keras.Input(shape=input_shape)
    # Encoder (VGG-like)
    x = layers.Conv2D(64,3,padding='same',activation='relu')(inputs)
    x = layers.Conv2D(64,3,padding='same',activation='relu')(x)
    p1 = layers.MaxPooling2D(2)(x)   # 256->128

    x = layers.Conv2D(128,3,padding='same',activation='relu')(p1)
    x = layers.Conv2D(128,3,padding='same',activation='relu')(x)
    p2 = layers.MaxPooling2D(2)(x)   # 128->64

    x = layers.Conv2D(256,3,padding='same',activation='relu')(p2)
    x = layers.Conv2D(256,3,padding='same',activation='relu')(x)
    x = layers.Conv2D(256,3,padding='same',activation='relu')(x)
    p3 = layers.MaxPooling2D(2)(x)   # 64->32

    x = layers.Conv2D(512,3,padding='same',activation='relu')(p3)
    x = layers.Conv2D(512,3,padding='same',activation='relu')(x)
    x = layers.Conv2D(512,3,padding='same',activation='relu')(x)
    p4 = layers.MaxPooling2D(2)(x)   # 32->16

    x = layers.Conv2D(512,3,padding='same',activation='relu')(p4)
    x = layers.Conv2D(512,3,padding='same',activation='relu')(x)
    x = layers.Conv2D(512,3,padding='same',activation='relu')(x)
    p5 = layers.MaxPooling2D(2)(x)   # 16->8

    # FCN head
    x = layers.Conv2D(4096,7,padding='same',activation='relu')(p5); x = layers.Dropout(0.5)(x)
    x = layers.Conv2D(4096,1,activation='relu')(x); x = layers.Dropout(0.5)(x)

    s5 = layers.Conv2D(num_classes,1)(x)
    s4 = layers.Conv2D(num_classes,1)(p4)
    s3 = layers.Conv2D(num_classes,1)(p3)

    up2 = layers.Conv2DTranspose(num_classes,4,strides=2,padding='same')(s5)  # 8->16
    f4  = layers.Add()([up2, s4])

    up4 = layers.Conv2DTranspose(num_classes,4,strides=2,padding='same')(f4)  # 16->32
    f3  = layers.Add()([up4, s3])

    outputs = layers.Conv2DTranspose(num_classes,16,strides=8,padding='same',activation='softmax')(f3)  # 32->256
    return keras.Model(inputs, outputs, name="FCN")

# 4) SegNet
def build_segnet(input_shape, num_classes):
    inputs = keras.Input(shape=input_shape)

    # Encoder
    x = layers.Conv2D(64,3,padding='same',activation='relu')(inputs); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(64,3,padding='same',activation='relu')(x);     x = layers.BatchNormalization()(x)
    x = layers.MaxPooling2D(2)(x)  # 256->128

    x = layers.Conv2D(128,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(128,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.MaxPooling2D(2)(x)  # 128->64

    x = layers.Conv2D(256,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(256,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(256,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.MaxPooling2D(2)(x)  # 64->32

    x = layers.Conv2D(512,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(512,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(512,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.MaxPooling2D(2)(x)  # 32->16

    # Decoder
    x = layers.UpSampling2D(2)(x)  # 16->32
    x = layers.Conv2D(512,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(512,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(512,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)

    x = layers.UpSampling2D(2)(x)  # 32->64
    x = layers.Conv2D(256,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(256,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(256,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)

    x = layers.UpSampling2D(2)(x)  # 64->128
    x = layers.Conv2D(128,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(128,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)

    x = layers.UpSampling2D(2)(x)  # 128->256
    x = layers.Conv2D(64,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(64,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)

    outputs = layers.Conv2D(num_classes,1,activation='softmax')(x)
    return keras.Model(inputs, outputs, name="SegNet")

# 5) BiSeNetV2 (compact; main head only) — fixed ContextEmbedding (no Lambda)
def build_bisenetv2(input_shape, num_classes):
    inputs = keras.Input(shape=input_shape)

    def ConvBNReLU(x, f, k=3, s=1):
        x = layers.Conv2D(f, k, strides=s, padding='same', use_bias=False)(x)
        x = layers.BatchNormalization()(x)
        return layers.ReLU()(x)

    def DWConvBNReLU(x, k=3, s=1):
        x = layers.DepthwiseConv2D(k, strides=s, padding='same', use_bias=False)(x)
        x = layers.BatchNormalization()(x)
        return layers.ReLU()(x)

    # ---------- Detail Branch (kept at /8) ----------
    def DetailBranch(x):
        # Stage 1: /2
        x = ConvBNReLU(x, 64, 3, 2)
        x = ConvBNReLU(x, 64, 3, 1)
        x = ConvBNReLU(x, 64, 3, 1)
        # Stage 2: /4
        x = ConvBNReLU(x, 64, 3, 2)
        x = ConvBNReLU(x, 64, 3, 1)
        x = ConvBNReLU(x, 64, 3, 1)
        # Stage 3: /8
        x = ConvBNReLU(x, 128, 3, 2)
        x = ConvBNReLU(x, 128, 3, 1)
        x = ConvBNReLU(x, 128, 3, 1)
        return x  # /8

    # ---------- Semantic Branch (down to /16 with CE, then up to /8) ----------
    def StemBlock(x):
        x = ConvBNReLU(x, 16, 3, 2)          # /2
        x = DWConvBNReLU(x, 3, 1)
        x = ConvBNReLU(x, 16, 1, 1)
        x = ConvBNReLU(x, 32, 3, 2)          # /4
        x = DWConvBNReLU(x, 3, 1)
        x = ConvBNReLU(x, 32, 1, 1)
        return x  # /4

    def GEBlock(x, out_ch, stride):
        in_ch = x.shape[-1]
        y = DWConvBNReLU(x, 3, stride)     # spatial gather
        y = ConvBNReLU(y, out_ch, 1, 1)    # expansion
        if stride == 1 and in_ch == out_ch:
            y = layers.Add()([x, y])
        return y

    # No Lambda: broadcast global context with static upsampling
    def ContextEmbedding(x, ch=128):
        h = layers.GlobalAveragePooling2D(keepdims=True)(x)  # (B,1,1,C)
        h = layers.BatchNormalization()(h)
        h = ConvBNReLU(h, ch, 1, 1)
        H = input_shape[0] // 16
        W = input_shape[1] // 16
        h = layers.UpSampling2D(size=(H, W), interpolation='bilinear')(h)  # 1x1 -> HxW
        y = layers.Add()([x, h])
        y = ConvBNReLU(y, ch, 3, 1)
        return y

    def SemanticBranch(x):
        x = StemBlock(x)              # /4, 32c
        x = GEBlock(x, 64, 2)         # /8
        x = GEBlock(x, 64, 1)
        x = GEBlock(x, 128, 2)        # /16
        x = GEBlock(x, 128, 1)
        x = GEBlock(x, 128, 1)
        x = ContextEmbedding(x, 128)  # /16 (context)
        x = layers.UpSampling2D(size=2, interpolation='bilinear')(x)  # /16 -> /8
        return x  # /8, 128c

    def FeatureFusion(detail, semantic, out_ch=256):
        x = layers.Concatenate()([detail, semantic])     # /8
        trunk = ConvBNReLU(x, out_ch, 3, 1)
        att = layers.GlobalAveragePooling2D(keepdims=True)(trunk)
        att = ConvBNReLU(att, out_ch // 4, 1, 1)
        att = layers.Conv2D(out_ch, 1, activation='sigmoid', padding='same')(att)
        out = layers.Multiply()([trunk, att])
        out = layers.Add()([trunk, out])
        return out  # /8

    def SegHead(x, num_classes, up_factor=8):
        x = ConvBNReLU(x, 128, 3, 1)
        x = layers.Conv2D(num_classes, 1, padding='same', activation='softmax')(x)
        x = layers.UpSampling2D(size=up_factor, interpolation='bilinear')(x)  # /8 -> /1
        return x

    db = DetailBranch(inputs)         # /8
    sb = SemanticBranch(inputs)       # /8
    fused = FeatureFusion(db, sb, out_ch=256)
    outputs = SegHead(fused, num_classes, up_factor=8)  # 32x32 -> 256x256
    return keras.Model(inputs, outputs, name="BiSeNetV2")

# ============ Loss and Metrics ============
SMOOTH = 1e-6

def _resize_to_label(y_pred, y_true):
    """Safety: resize predictions to label size if mismatched."""
    ph = tf.shape(y_pred)[1]; pw = tf.shape(y_pred)[2]
    th = tf.shape(y_true)[1]; tw = tf.shape(y_true)[2]
    need = tf.logical_or(tf.not_equal(ph, th), tf.not_equal(pw, tw))
    def _do():
        return tf.image.resize(y_pred, (th, tw), method='bilinear')
    return tf.cond(need, _do, lambda: y_pred)

def weighted_ce(y_true, y_pred):
    y_pred = _resize_to_label(y_pred, y_true)
    w = tf.reduce_sum(CLASS_WEIGHTS * y_true, axis=-1)                 # [B,H,W]
    ce = tf.keras.losses.categorical_crossentropy(y_true, y_pred)      # [B,H,W]
    return tf.reduce_mean(ce * w)

def dice_loss_no_bg(y_true, y_pred):
    y_pred = _resize_to_label(y_pred, y_true)
    y_true_f = tf.reshape(y_true[...,1:], [-1, NUM_CLASSES-1])
    y_pred_f = tf.reshape(y_pred[...,1:], [-1, NUM_CLASSES-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f, axis=0)
    denom = tf.reduce_sum(y_true_f + y_pred_f, axis=0)
    dice = (2.0 * intersection + SMOOTH) / (denom + SMOOTH)
    return 1.0 - tf.reduce_mean(dice)

def combo_loss(y_true, y_pred, alpha=0.5):
    return alpha * weighted_ce(y_true, y_pred) + (1.0 - alpha) * dice_loss_no_bg(y_true, y_pred)

@tf.function
def iou_no_bg(y_true, y_pred):
    """Mean IoU over classes 1..C-1 (vectorized, autograph-safe)."""
    y_pred = _resize_to_label(y_pred, y_true)
    y_true_cls = tf.argmax(y_true, axis=-1)
    y_pred_cls = tf.argmax(y_pred, axis=-1)
    y_true_oh = tf.one_hot(y_true_cls, depth=NUM_CLASSES, dtype=tf.float32)
    y_pred_oh = tf.one_hot(y_pred_cls, depth=NUM_CLASSES, dtype=tf.float32)
    y_true_f = tf.reshape(y_true_oh, [-1, NUM_CLASSES])
    y_pred_f = tf.reshape(y_pred_oh, [-1, NUM_CLASSES])
    inter = tf.reduce_sum(y_true_f * y_pred_f, axis=0)
    union = tf.reduce_sum(y_true_f + y_pred_f - y_true_f * y_pred_f, axis=0)
    inter_nb = inter[1:]; union_nb = union[1:]
    iou = tf.where(union_nb > 0.0, inter_nb / (union_nb + 1e-7), 0.0)
    return tf.reduce_mean(iou)

# ============ Training / Evaluation Helpers ============
def compile_model(model):
    model.compile(
        optimizer=keras.optimizers.Adam(1e-3),
        loss=combo_loss,
        metrics=[iou_no_bg, 'accuracy']
    )
    return model

def plot_history(hist, title, outdir):
    plt.figure(figsize=(10,4))
    # loss
    plt.subplot(1,2,1)
    plt.plot(hist.history['loss'], label='train')
    plt.plot(hist.history['val_loss'], label='val')
    plt.title(f'{title} - Loss'); plt.xlabel('Epoch'); plt.ylabel('Loss'); plt.legend()
    # IoU
    plt.subplot(1,2,2)
    plt.plot(hist.history['iou_no_bg'], label='train IoU')
    plt.plot(hist.history['val_iou_no_bg'], label='val IoU')
    plt.title(f'{title} - IoU(no-bg)'); plt.xlabel('Epoch'); plt.ylabel('IoU'); plt.legend()
    plt.tight_layout()
    os.makedirs(outdir, exist_ok=True)
    p = os.path.join(outdir, f'{title}_curves.png')
    plt.savefig(p, dpi=140, bbox_inches='tight')
    plt.close()
    return p

def predict_to_mask(pred):
    return np.argmax(pred, axis=-1).astype(np.uint8)

# ============ Main ============
def main():
    t0 = time.time()
    print("🔎 Loading dataset...")
    ds = AppleLeafDataset(BASE_DIR, IMG_SIZE)
    X, Y = ds.load()

    # simple split (≈70/15/15)
    Xtr, Xte, Ytr, Yte = train_test_split(X, Y, test_size=0.15, random_state=SEED, shuffle=True)
    Xtr, Xva, Ytr, Yva = train_test_split(Xtr, Ytr, test_size=0.1765, random_state=SEED, shuffle=True)

    print(f"Shapes -> Train: {Xtr.shape}, Val: {Xva.shape}, Test: {Xte.shape}")

    train_ds = make_dataset(Xtr, Ytr, BATCH_SIZE, shuffle=True, augment=True)
    val_ds   = make_dataset(Xva, Yva, BATCH_SIZE, shuffle=False, augment=False)
    test_ds  = make_dataset(Xte, Yte, BATCH_SIZE, shuffle=False, augment=False)

    input_shape = (IMG_SIZE, IMG_SIZE, 3)

    # Name -> builder (order preserved)
    builders = {
        'UNet': build_unet,
        'DeepLabV3Plus': build_deeplabv3plus,
        'FCN': build_fcn,
        'SegNet': build_segnet,
        'BiSeNetV2': build_bisenetv2
    }

    # decide which to run
    model_names = list(builders.keys())
    if START_AT and START_AT in model_names:
        model_names = model_names[model_names.index(START_AT):]
    if RUN_ONLY:
        model_names = [n for n in model_names if n in RUN_ONLY]

    results = []
    os.makedirs(OUTDIR, exist_ok=True)

    for name in model_names:
        tf.keras.backend.clear_session()
        ckpt_path = os.path.join(OUTDIR, f"{name}_best.keras")

        if SKIP_TRAIN_IF_CKPT and os.path.exists(ckpt_path):
            print(f"⏭️  {name}: checkpoint found -> skipping training, loading for eval")
            try:
                model = keras.models.load_model(ckpt_path, compile=False)
                compile_model(model)
            except Exception as e:
                print(f"⚠️  Failed to load {name} checkpoint: {e}")
                print(f"🔄 Re-training {name} instead...")
                model = builders[name](input_shape, NUM_CLASSES)
                compile_model(model)
                cbs = [
                    keras.callbacks.ModelCheckpoint(
                        ckpt_path, monitor='val_iou_no_bg', mode='max',
                        save_best_only=True, save_weights_only=False, verbose=1
                    ),
                    keras.callbacks.EarlyStopping(
                        monitor='val_iou_no_bg', mode='max',
                        patience=6, restore_best_weights=True
                    )
                ]
                hist = model.fit(train_ds, validation_data=val_ds,
                                 epochs=EPOCHS, verbose=1, callbacks=cbs)
                _ = plot_history(hist, name, OUTDIR)
        else:
            print(f"\n🚀 Training {name} ...")
            model = builders[name](input_shape, NUM_CLASSES)
            compile_model(model)
            cbs = [
                keras.callbacks.ModelCheckpoint(
                    ckpt_path, monitor='val_iou_no_bg', mode='max',
                    save_best_only=True, save_weights_only=False, verbose=1
                ),
                keras.callbacks.EarlyStopping(
                    monitor='val_iou_no_bg', mode='max',
                    patience=6, restore_best_weights=True
                )
            ]
            hist = model.fit(train_ds, validation_data=val_ds,
                             epochs=EPOCHS, verbose=1, callbacks=cbs)
            _ = plot_history(hist, name, OUTDIR)

        # Evaluate (even if loaded)
        print(f"📏 Evaluating {name} ...")
        val_metrics  = model.evaluate(val_ds,  verbose=0)
        test_metrics = model.evaluate(test_ds, verbose=0)

        res = {
            "Model": name,
            "Val Loss": float(val_metrics[0]),
            "Val IoU(no-bg)": float(val_metrics[1]),
            "Val Acc": float(val_metrics[2]),
            "Test Loss": float(test_metrics[0]),
            "Test IoU(no-bg)": float(test_metrics[1]),
            "Test Acc": float(test_metrics[2]),
        }
        results.append(res)

        # Preview with severity (%) relative to leaf area
        _ = visualize_with_severity(model, Xva, Yva, n=4, outdir=OUTDIR)

    # Save/append comparison CSV
    cmp_csv = os.path.join(OUTDIR, "model_comparison.csv")
    df_new = pd.DataFrame(results)
    if os.path.exists(cmp_csv):
        try:
            df_old = pd.read_csv(cmp_csv)
            df = (pd.concat([df_old, df_new], ignore_index=True)
                    .drop_duplicates(subset=['Model'], keep='last'))
        except Exception:
            df = df_new
    else:
        df = df_new
    df = df.sort_values("Val IoU(no-bg)", ascending=False)
    df.to_csv(cmp_csv, index=False)
    print("\n================ Model Comparison (by Val IoU no-bg) ================")
    print(df.to_string(index=False))
    print(f"\n💾 Saved: {cmp_csv}")
    print(f"✅ Done in {time.time()-t0:.1f}s")

if __name__ == "__main__":
    main()


2025-08-31 05:10:57.826223: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1756617058.149483      36 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1756617058.236721      36 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
I0000 00:00:1756617075.286757      36 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13942 MB memory:  -> device: 0, name: Tesla T4, pci bus id: 0000:00:04.0, compute capability: 7.5
I0000 00:00:1756617075.287675      36 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 13942 MB memory:  -> device: 1, name: Tesla T4, pci bus id: 0000:00:05.0, compute capability:

⚠️  GPU mem-growth not set: Physical devices cannot be modified after being initialized
🔎 Loading dataset...
📂 Alternaria leaf spot | imgs: 278 | masks: 278 | paired: 278
📂 Brown spot           | imgs: 215 | masks: 215 | paired: 215
📂 Gray spot            | imgs: 395 | masks: 395 | paired: 395
📂 Healthy leaf         | imgs: 409 | masks: 409 | paired: 409
📂 Rust                 | imgs: 344 | masks: 344 | paired: 344
✅ Paired samples: 1641


Loading data: 100%|██████████| 1641/1641 [00:22<00:00, 74.26it/s]


Shapes -> Train: (1147, 256, 256, 3), Val: (247, 256, 256, 3), Test: (247, 256, 256, 3)

🚀 Training UNet ...
Epoch 1/15


I0000 00:00:1756617135.908867      99 service.cc:148] XLA service 0x7956c0001ee0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1756617135.910125      99 service.cc:156]   StreamExecutor device (0): Tesla T4, Compute Capability 7.5
I0000 00:00:1756617135.910145      99 service.cc:156]   StreamExecutor device (1): Tesla T4, Compute Capability 7.5
I0000 00:00:1756617138.182302      99 cuda_dnn.cc:529] Loaded cuDNN version 90300
I0000 00:00:1756617184.042244      99 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m143/144[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 327ms/step - accuracy: 0.6170 - iou_no_bg: 0.1197 - loss: 0.6815

E0000 00:00:1756617244.206244      98 gpu_timer.cc:82] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0000 00:00:1756617244.444640      98 gpu_timer.cc:82] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0000 00:00:1756617245.465277      98 gpu_timer.cc:82] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0000 00:00:1756617245.788254      98 gpu_timer.cc:82] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.


[1m144/144[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 598ms/step - accuracy: 0.6175 - iou_no_bg: 0.1199 - loss: 0.6810
Epoch 1: val_iou_no_bg improved from -inf to 0.06704, saving model to outputs/UNet_best.keras
[1m144/144[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m191s[0m 829ms/step - accuracy: 0.6181 - iou_no_bg: 0.1201 - loss: 0.6806 - val_accuracy: 0.4121 - val_iou_no_bg: 0.0670 - val_loss: 0.7258
Epoch 2/15
[1m144/144[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 349ms/step - accuracy: 0.7853 - iou_no_bg: 0.1811 - loss: 0.5376
Epoch 2: val_iou_no_bg did not improve from 0.06704
[1m144/144[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m54s[0m 374ms/step - accuracy: 0.7854 - iou_no_bg: 0.1812 - loss: 0.5375 - val_accuracy: 0.6992 - val_iou_no_bg: 0.0511 - val_loss: 1.0178
Epoch 3/15
[1m144/144[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 352ms/step - accuracy: 0.8131 - iou_no_bg: 0.2436 - loss: 0.4851
Epoch 3: val_iou_no_bg did not improve fro

2025-08-31 05:28:45.828185: E external/local_xla/xla/service/slow_operation_alarm.cc:65] Trying algorithm eng4{} for conv (f32[512,2592,4,4]{3,2,1,0}, u8[0]{0}) custom-call(f32[512,256,3,3]{3,2,1,0}, f32[2592,256,2,2]{3,2,1,0}), window={size=2x2 pad=1_1x1_1 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"cudnn_conv_backend_config":{"activation_mode":"kNone","conv_result_scale":1,"leakyrelu_alpha":0,"side_input_scale":0},"force_earliest_schedule":false,"operation_queue_id":"0","wait_on_operation_queues":[]} is taking a while...
2025-08-31 05:28:46.617316: E external/local_xla/xla/service/slow_operation_alarm.cc:133] The operation took 1.789371386s
Trying algorithm eng4{} for conv (f32[512,2592,4,4]{3,2,1,0}, u8[0]{0}) custom-call(f32[512,256,3,3]{3,2,1,0}, f32[2592,256,2,2]{3,2,1,0}), window={size=2x2 pad=1_1x1_1 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convForward", backend_config={"cudnn_conv_b

[1m143/144[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 444ms/step - accuracy: 0.7597 - iou_no_bg: 0.1724 - loss: 0.5971

E0000 00:00:1756618228.867929      96 gpu_timer.cc:82] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0000 00:00:1756618229.100901      96 gpu_timer.cc:82] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0000 00:00:1756618229.575661      96 gpu_timer.cc:82] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0000 00:00:1756618229.801685      96 gpu_timer.cc:82] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.


[1m144/144[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 629ms/step - accuracy: 0.7601 - iou_no_bg: 0.1728 - loss: 0.5964
Epoch 1: val_iou_no_bg improved from -inf to 0.00837, saving model to outputs/DeepLabV3Plus_best.keras
[1m144/144[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m178s[0m 800ms/step - accuracy: 0.7606 - iou_no_bg: 0.1732 - loss: 0.5957 - val_accuracy: 0.1248 - val_iou_no_bg: 0.0084 - val_loss: 1.7147
Epoch 2/15
[1m144/144[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 446ms/step - accuracy: 0.8606 - iou_no_bg: 0.3246 - loss: 0.4170
Epoch 2: val_iou_no_bg improved from 0.00837 to 0.07938, saving model to outputs/DeepLabV3Plus_best.keras
[1m144/144[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m70s[0m 487ms/step - accuracy: 0.8607 - iou_no_bg: 0.3247 - loss: 0.4169 - val_accuracy: 0.5205 - val_iou_no_bg: 0.0794 - val_loss: 0.9516
Epoch 3/15
[1m144/144[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 443ms/step - accuracy: 0.8834 - iou_no_bg: 0

In [3]:
# =========================================================
# Apple Leaf Segmentation - MULTI-MODEL COMPARISON
# UNet_MobileNetV2 (pretrained, fine-tune), UNet, DeepLabV3+, FCN, SegNet, BiSeNetV2
# Stronger loss (Weighted CE + Focal-Tversky), optional auto class weights
# TTA inference + small post-processing
# Visuals: (1) per-sample Image|GT|Pred with severity %, (2) 3-row grid like screenshot
# =========================================================
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
os.environ["TF_DISABLE_PROFILER"] = "1"

import cv2
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import pandas as pd
import time
import random

# ============ Config ============
BASE_DIR    = "/kaggle/input/apple-dataset/ATLDSD"   # <--- change if needed
IMG_SIZE    = 256
BATCH_SIZE  = 8
EPOCHS      = 12
SEED        = 2025

# ==== Resume / Skip settings =====
OUTDIR = "outputs"
SKIP_TRAIN_IF_CKPT = True         # if checkpoint exists, skip training that model
RUN_ONLY = None                    # e.g. ['UNet_MobileNetV2'] or None
START_AT = None                    # e.g. 'FCN' to skip earlier models

# ==== Training knobs ====
AUTO_CLASS_WEIGHTS = True         # compute from Y_train (overrides CLASS_WEIGHTS)
USE_TTA_IN_VIZ     = True         # use test-time augmentation in visualization preds

random.seed(SEED); np.random.seed(SEED); tf.random.set_seed(SEED)

CLASS_NAMES = [
    "Background",           # 0
    "Healthy",              # 1
    "Brown spot",           # 2
    "Alternaria leaf spot", # 3
    "Gray spot",            # 4
    "Rust"                  # 5
]
NUM_CLASSES = len(CLASS_NAMES)

# exact RGB colors -> class index
COLOR_MAP = {
    (0,   0,   0): 0,   # Background
    (128, 0,   0): 1,   # Healthy
    (128, 0, 128): 2,   # Brown spot (purple)
    (128,128,  0): 3,   # Alternaria (olive)
    (0,   0, 128): 4,   # Gray (blue)
    (0, 128,   0): 5,   # Rust (green)
}
# default weights (can be overridden by AUTO_CLASS_WEIGHTS)
CLASS_WEIGHTS = tf.constant([0.25, 0.7, 1.1, 1.1, 1.1, 1.2], dtype=tf.float32)

# augmentation knobs
A_ROT90_PROB   = 0.75
A_FLIP_H_PROB  = 0.5
A_FLIP_V_PROB  = 0.5
A_JITTER_PROB  = 0.6
A_NOISE_PROB   = 0.3
A_CROP_PROB    = 0.6
CROP_MIN_FRAC  = 0.85

# ============ Utils ============
def set_gpu_growth():
    try:
        gpus = tf.config.list_physical_devices('GPU')
        if gpus:
            for g in gpus:
                tf.config.experimental.set_memory_growth(g, True)
            print(f"✅ GPU found: {len(gpus)}; memory growth enabled")
        else:
            print("ℹ️  No GPU detected; running on CPU")
    except Exception as e:
        print("⚠️  GPU mem-growth not set:", e)
set_gpu_growth()

def rgb_mask_to_classes(mask_rgb):
    out = np.zeros(mask_rgb.shape[:2], dtype=np.uint8)
    R, G, B = mask_rgb[...,0], mask_rgb[...,1], mask_rgb[...,2]
    for (r,g,b), cls in COLOR_MAP.items():
        m = (R == r) & (G == g) & (B == b)
        out[m] = cls
    return out

PALETTE = {
    0:(0,0,0), 1:(128,0,0), 2:(128,0,128), 3:(128,128,0), 4:(0,0,128), 5:(0,128,0)
}
def mask_to_color(mask):
    h,w = mask.shape
    out = np.zeros((h,w,3), dtype=np.uint8)
    for c, col in PALETTE.items():
        out[mask==c] = col
    return out

# ===================== Severity Utils & Visualization =====================
def compute_severity_percentages(mask_int):
    """
    Compute severity % per class with respect to LEAF area (non-background).
    mask_int: [H,W] uint8 class map (0=background, 1=healthy, 2..=diseases)
    Returns: (per_class_dict, healthy_pct, disease_total_pct)
    """
    m = np.asarray(mask_int, dtype=np.uint8)
    leaf = (m != 0)
    leaf_pixels = int(leaf.sum())
    if leaf_pixels == 0:
        per_class = {CLASS_NAMES[c]: 0.0 for c in range(2, NUM_CLASSES)}
        healthy_pct = 0.0
        disease_total = 0.0
        return per_class, healthy_pct, disease_total

    per_class = {}
    for c in range(2, NUM_CLASSES):
        per_class[CLASS_NAMES[c]] = 100.0 * float((m == c).sum()) / leaf_pixels

    healthy_pct = 100.0 * float((m == 1).sum()) / leaf_pixels
    disease_total = 100.0 - healthy_pct
    return per_class, healthy_pct, disease_total

def _box_text_from_severity(per_class, healthy_pct, disease_total):
    lines = [f"Healthy: {healthy_pct:5.1f}%",
             f"Total disease: {disease_total:5.1f}%"]
    for name, pct in per_class.items():
        lines.append(f"{name}: {pct:5.1f}%")
    return "\n".join(lines)

# ---- TTA + tiny cleanup for nicer predictions (for visualization/presentations) ----
def predict_prob_tta(model, img):
    """Average probabilities over a few simple transforms, then invert them."""
    imgs = [
        img,
        np.flip(img, axis=1),                 # hflip
        np.flip(img, axis=0),                 # vflip
        np.rot90(img, k=1)
    ]
    probs = []
    for im in imgs:
        p = model.predict(im[None], verbose=0)[0]
        probs.append(p)
    # invert transforms
    probs[1] = np.flip(probs[1], axis=1)
    probs[2] = np.flip(probs[2], axis=0)
    probs[3] = np.rot90(probs[3], k=3)
    return np.mean(probs, axis=0)

def small_component_cleanup(mask, min_frac=0.001):
    """Remove tiny isolated blobs per class (send to 'Healthy')."""
    H, W = mask.shape
    min_area = max(1, int(H*W*min_frac))
    out = mask.copy()
    for c in range(1, NUM_CLASSES):
        m = (out == c).astype(np.uint8)
        if m.sum() == 0: 
            continue
        num, labels = cv2.connectedComponents(m, connectivity=8)
        for lab in range(1, num):
            area = int((labels == lab).sum())
            if area < min_area:
                out[labels == lab] = 1
    return out

def visualize_with_severity(model, Xv, Yv_int, n=4, outdir="outputs", seed=2025):
    """
    Show Image | GT | Pred with severity (%) boxes (n rows of triplets).
    """
    np.random.seed(seed)
    os.makedirs(outdir, exist_ok=True)
    idx = np.random.choice(len(Xv), size=min(n, len(Xv)), replace=False)

    fig, axs = plt.subplots(len(idx), 3, figsize=(11, 3.6*len(idx)))
    if len(idx) == 1:
        axs = np.expand_dims(axs, 0)

    for r, i in enumerate(idx):
        img = Xv[i]
        gt  = Yv_int[i].astype(np.uint8)

        # Predict (with optional TTA + cleanup)
        if USE_TTA_IN_VIZ:
            pr = predict_prob_tta(model, img)
        else:
            pr = model.predict(img[None], verbose=0)[0]
        pm  = np.argmax(pr, axis=-1).astype(np.uint8)
        pm  = small_component_cleanup(pm, min_frac=0.001)

        # severities
        gt_per, gt_healthy, gt_dis = compute_severity_percentages(gt)
        pr_per, pr_healthy, pr_dis = compute_severity_percentages(pm)

        axs[r,0].imshow(img); axs[r,0].set_title("Image"); axs[r,0].axis('off')

        axs[r,1].imshow(mask_to_color(gt)); axs[r,1].set_title("Ground Truth"); axs[r,1].axis('off')
        gt_txt = _box_text_from_severity(gt_per, gt_healthy, gt_dis)
        axs[r,1].text(0.02, 0.98, gt_txt, transform=axs[r,1].transAxes,
                      va='top', ha='left', fontsize=9,
                      bbox=dict(facecolor='white', alpha=0.75, edgecolor='black', boxstyle='round,pad=0.4'))

        axs[r,2].imshow(mask_to_color(pm)); axs[r,2].set_title(f"Predicted ({model.name})"); axs[r,2].axis('off')
        pr_txt = _box_text_from_severity(pr_per, pr_healthy, pr_dis)
        axs[r,2].text(0.02, 0.98, pr_txt, transform=axs[r,2].transAxes,
                      va='top', ha='left', fontsize=9,
                      bbox=dict(facecolor='white', alpha=0.75, edgecolor='black', boxstyle='round,pad=0.4'))

    plt.tight_layout()
    save_path = os.path.join(outdir, f"viz_with_severity_{model.name}.png")
    plt.savefig(save_path, dpi=150, bbox_inches='tight'); plt.close()
    print(f"🖼️  Saved severity visualization: {save_path}")
    return save_path

def visualize_grid_with_severity(model, Xv, Yv_int, k=6, outdir="outputs",
                                 title=None, seed=2025):
    """
    Make a 3-row panel like screenshot: [Images] / [GT] / [Pred], with severity % boxes.
    k = number of columns (samples).
    """
    np.random.seed(seed)
    os.makedirs(outdir, exist_ok=True)
    idx = np.random.choice(len(Xv), size=min(k, len(Xv)), replace=False)
    cols = len(idx); rows = 3
    fig, axs = plt.subplots(rows, cols, figsize=(cols*3.2, rows*3.2))
    if cols == 1: axs = np.expand_dims(axs, 1)

    for c, i in enumerate(idx):
        img = Xv[i]; gt = Yv_int[i].astype(np.uint8)

        if USE_TTA_IN_VIZ:
            pr = predict_prob_tta(model, img)
        else:
            pr = model.predict(img[None], verbose=0)[0]
        pm = np.argmax(pr, axis=-1).astype(np.uint8)
        pm = small_component_cleanup(pm, min_frac=0.001)

        # top: image
        axs[0, c].imshow(img); axs[0, c].set_title(f"Image {c+1}", fontsize=11); axs[0, c].axis('off')

        # middle: GT
        axs[1, c].imshow(mask_to_color(gt)); axs[1, c].set_title("GT", fontsize=10); axs[1, c].axis('off')
        gt_per, gt_h, gt_d = compute_severity_percentages(gt)
        axs[1, c].text(0.01, 0.99, _box_text_from_severity(gt_per, gt_h, gt_d),
                       transform=axs[1, c].transAxes, va='top', ha='left', fontsize=8,
                       bbox=dict(facecolor='white', alpha=0.78, edgecolor='black', boxstyle='round,pad=0.3'))

        # bottom: Pred
        axs[2, c].imshow(mask_to_color(pm)); axs[2, c].set_title("Pred", fontsize=10); axs[2, c].axis('off')
        pr_per, pr_h, pr_d = compute_severity_percentages(pm)
        axs[2, c].text(0.01, 0.99, _box_text_from_severity(pr_per, pr_h, pr_d),
                       transform=axs[2, c].transAxes, va='top', ha='left', fontsize=8,
                       bbox=dict(facecolor='white', alpha=0.78, edgecolor='black', boxstyle='round,pad=0.3'))

    if title is None:
        title = f"{model.name} — Dice+CE + Focal-Tversky, Augmented"
    fig.suptitle(title, fontsize=14)
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    save_path = os.path.join(outdir, f"panel_severity_{model.name}.png")
    plt.savefig(save_path, dpi=150, bbox_inches='tight'); plt.close()
    print(f"🖼️  Saved grid panel: {save_path}")
    return save_path

# ============ Dataset ============
class AppleLeafDataset:
    IMG_EXTS = (".png",".jpg",".jpeg",".bmp",".tif",".tiff")
    def __init__(self, base_dir, image_size=256):
        self.base_dir = base_dir
        self.image_size = image_size
        self.image_paths, self.mask_paths = self._discover_pairs()
        print(f"✅ Paired samples: {len(self.image_paths)}")

    def _list_images(self, d):
        acc = []
        for r,_,fs in os.walk(d):
            for f in fs:
                if f.lower().endswith(self.IMG_EXTS):
                    acc.append(os.path.join(r,f))
        return acc

    def _discover_pairs(self):
        imgs, msks = [], []
        if not os.path.exists(self.base_dir):
            raise FileNotFoundError(f"Base dir not found: {self.base_dir}")
        for cls_folder in sorted(os.listdir(self.base_dir)):
            cpath = os.path.join(self.base_dir, cls_folder)
            if not os.path.isdir(cpath): continue
            img_dir = os.path.join(cpath, "image")
            msk_dir = os.path.join(cpath, "label")
            if not (os.path.exists(img_dir) and os.path.exists(msk_dir)): continue

            img_files = self._list_images(img_dir)
            msk_files = self._list_images(msk_dir)
            img_by = {os.path.splitext(os.path.basename(p))[0].lower(): p for p in img_files}
            msk_by = {os.path.splitext(os.path.basename(p))[0].lower(): p for p in msk_files}
            common = sorted(set(img_by) & set(msk_by))
            print(f"📂 {cls_folder:20} | imgs:{len(img_files):4d} | masks:{len(msk_files):4d} | paired:{len(common):4d}")
            for s in common: imgs.append(img_by[s]); msks.append(msk_by[s])
        return imgs, msks

    def load(self):
        X, Y = [], []
        for ip, mp in tqdm(list(zip(self.image_paths, self.mask_paths)), desc="Loading data"):
            img = cv2.imread(ip, cv2.IMREAD_COLOR)
            if img is None: continue
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = cv2.resize(img, (self.image_size, self.image_size), interpolation=cv2.INTER_AREA)
            img = img.astype(np.float32)/255.0

            msk = cv2.imread(mp, cv2.IMREAD_COLOR)
            if msk is None: continue
            msk = cv2.cvtColor(msk, cv2.COLOR_BGR2RGB)
            msk = cv2.resize(msk, (self.image_size, self.image_size), interpolation=cv2.INTER_NEAREST)
            msk = rgb_mask_to_classes(msk)

            X.append(img); Y.append(msk)
        return np.asarray(X, np.float32), np.asarray(Y, np.uint8)

# ============ Augmentation (tf.data) ============
def augment_img_mask(img, mask):
    img = tf.cast(img, tf.float32)
    mask = tf.cast(mask, tf.int32)
    mask_3d = tf.expand_dims(mask, axis=-1)

    def apply_transform(transform_func, prob):
        return tf.cond(
            tf.random.uniform([]) < prob,
            lambda: transform_func(img, mask_3d),
            lambda: (img, mask_3d)
        )

    def rot90_transform(i, m):
        k = tf.random.uniform([], 0, 4, dtype=tf.int32)
        return tf.image.rot90(i, k), tf.image.rot90(m, k)

    img, mask_3d = apply_transform(rot90_transform, A_ROT90_PROB)

    def flip_h(i, m): return tf.image.flip_left_right(i), tf.image.flip_left_right(m)
    def flip_v(i, m): return tf.image.flip_up_down(i), tf.image.flip_up_down(m)
    img, mask_3d = apply_transform(flip_h, A_FLIP_H_PROB)
    img, mask_3d = apply_transform(flip_v, A_FLIP_V_PROB)

    def crop_transform(i, m):
        shape = tf.shape(i)
        h, w = shape[0], shape[1]
        frac = tf.random.uniform([], CROP_MIN_FRAC, 1.0)
        nh = tf.cast(tf.cast(h, tf.float32) * frac, tf.int32)
        nw = tf.cast(tf.cast(w, tf.float32) * frac, tf.int32)
        nh = tf.minimum(nh, h); nw = tf.minimum(nw, w)
        max_y = tf.maximum(1, h - nh); max_x = tf.maximum(1, w - nw)
        oy = tf.random.uniform([], 0, max_y, dtype=tf.int32)
        ox = tf.random.uniform([], 0, max_x, dtype=tf.int32)
        i_crop = tf.image.crop_to_bounding_box(i, oy, ox, nh, nw)
        m_crop = tf.image.crop_to_bounding_box(m, oy, ox, nh, nw)

        i_resized = tf.image.resize(i_crop, [h, w], method='bilinear')

        m_crop_f = tf.cast(m_crop, tf.float32)
        m_resized_f = tf.image.resize(m_crop_f, [h, w], method='nearest')
        m_resized = tf.cast(tf.round(m_resized_f), tf.int32)

        return i_resized, m_resized

    img, mask_3d = apply_transform(crop_transform, A_CROP_PROB)

    def apply_photometric(i):
        if tf.random.uniform([]) < A_JITTER_PROB:
            i = tf.image.random_brightness(i, 0.15)
            i = tf.image.random_contrast(i, 0.8, 1.2)
            i = tf.image.random_saturation(i, 0.8, 1.2)
            i = tf.image.random_hue(i, 0.02)
            i = tf.clip_by_value(i, 0.0, 1.0)
        if tf.random.uniform([]) < A_NOISE_PROB:
            noise = tf.random.normal(tf.shape(i), 0.0, 0.02, dtype=tf.float32)
            i = tf.clip_by_value(i + noise, 0.0, 1.0)
        return i

    img = apply_photometric(img)
    mask = tf.squeeze(mask_3d, axis=-1)
    return img, mask

def one_hot(mask):
    return tf.one_hot(tf.cast(mask, tf.int32), depth=NUM_CLASSES)

def make_dataset(X, Y, batch_size=8, shuffle=False, augment=False):
    ds = tf.data.Dataset.from_tensor_slices((X, Y))
    if shuffle:
        ds = ds.shuffle(min(len(X), 1024), reshuffle_each_iteration=True)
    def process(img, mask):
        img = tf.cast(img, tf.float32)
        mask = tf.cast(mask, tf.int32)
        if augment:
            img, mask = augment_img_mask(img, mask)
        return img, one_hot(mask)
    return ds.map(process, num_parallel_calls=tf.data.AUTOTUNE)\
             .batch(batch_size)\
             .prefetch(tf.data.AUTOTUNE)

# ============ Models ============

# Pretrained U-Net with MobileNetV2 encoder (ImageNet) + fine-tune
def build_unet_mobilenetv2(input_shape, num_classes, train_encoder=False):
    # model input (your pipeline gives [0,1])
    inputs = keras.Input(shape=input_shape)
    # MobileNetV2 expects [-1, 1]; do a lightweight, differentiable scaling
    x_in = layers.Lambda(lambda t: t * 2.0 - 1.0, name="scale_to_mnv2") (inputs)

    # IMPORTANT: connect your tensor to the encoder using input_tensor=
    base = tf.keras.applications.MobileNetV2(
        input_tensor=x_in, include_top=False, weights="imagenet"
    )
    base.trainable = train_encoder

    # Skip tensors now belong to the SAME graph as `inputs`
    s1 = base.get_layer('block_1_expand_relu').output   # 128x128
    s2 = base.get_layer('block_3_expand_relu').output   # 64x64
    s3 = base.get_layer('block_6_expand_relu').output   # 32x32
    s4 = base.get_layer('block_13_expand_relu').output  # 16x16
    bn = base.get_layer('block_16_project').output      # 8x8

    def up_block(x, skip, f):
        x = layers.Conv2DTranspose(f, 3, strides=2, padding='same')(x)
        x = layers.Concatenate()([x, skip])
        x = layers.Conv2D(f, 3, padding='same', activation='relu')(x); x = layers.BatchNormalization()(x)
        x = layers.Conv2D(f, 3, padding='same', activation='relu')(x); x = layers.BatchNormalization()(x)
        return x

    x = bn
    x = up_block(x, s4, 256)   # 8->16
    x = up_block(x, s3, 128)   # 16->32
    x = up_block(x, s2, 64)    # 32->64
    x = up_block(x, s1, 32)    # 64->128
    x = layers.Conv2DTranspose(32, 3, strides=2, padding='same')(x)  # 128->256

    outputs = layers.Conv2D(num_classes, 1, activation='softmax')(x)
    model = keras.Model(inputs, outputs, name="UNet_MobileNetV2")
    model._encoder = base  # keep handle for the fine-tuning phase
    return model


# 1) UNet (from scratch)
def build_unet(input_shape, num_classes, base=48, drop=0.15):
    x_in = keras.Input(shape=input_shape)
    def blk(x,f):
        x = layers.Conv2D(f,3,padding='same',activation='relu')(x)
        x = layers.BatchNormalization()(x)
        x = layers.Conv2D(f,3,padding='same',activation='relu')(x)
        x = layers.BatchNormalization()(x)
        x = layers.SpatialDropout2D(drop)(x)
        return x

    c1 = blk(x_in, base);        p1 = layers.MaxPooling2D(2)(c1)      # 256->128
    c2 = blk(p1, base*2);        p2 = layers.MaxPooling2D(2)(c2)      # 128->64
    c3 = blk(p2, base*4);        p3 = layers.MaxPooling2D(2)(c3)      # 64->32
    c4 = blk(p3, base*8);        p4 = layers.MaxPooling2D(2)(c4)      # 32->16
    bn = blk(p4, base*16)

    u6 = layers.Conv2DTranspose(base*8,2,2,padding='same')(bn)         # 16->32
    u6 = layers.Concatenate()([u6,c4]); c6 = blk(u6, base*8)
    u7 = layers.Conv2DTranspose(base*4,2,2,padding='same')(c6)         # 32->64
    u7 = layers.Concatenate()([u7,c3]); c7 = blk(u7, base*4)
    u8 = layers.Conv2DTranspose(base*2,2,2,padding='same')(c7)         # 64->128
    u8 = layers.Concatenate()([u8,c2]); c8 = blk(u8, base*2)
    u9 = layers.Conv2DTranspose(base,2,2,padding='same')(c8)           # 128->256
    u9 = layers.Concatenate()([u9,c1]); c9 = blk(u9, base)

    out = layers.Conv2D(num_classes,1,activation='softmax')(c9)
    return keras.Model(x_in, out, name="UNet")

# 2) DeepLabV3+ (fixed pooling resize)
def build_deeplabv3plus(input_shape, num_classes):
    def aspp(x):
        dims = x.shape[-1]
        h, w = x.shape[1], x.shape[2]  # ints for fixed input

        pool = layers.GlobalAveragePooling2D()(x)
        pool = layers.Reshape((1, 1, dims))(pool)
        pool = layers.Conv2D(256, 1, activation='relu', padding='same')(pool)
        pool = layers.BatchNormalization()(pool)
        pool = layers.UpSampling2D(size=(h, w), interpolation='bilinear')(pool)

        conv1 = layers.Conv2D(256, 1, activation='relu', padding='same')(x)
        conv1 = layers.BatchNormalization()(conv1)

        c6  = layers.Conv2D(256, 3, activation='relu', padding='same', dilation_rate=6)(x);  c6  = layers.BatchNormalization()(c6)
        c12 = layers.Conv2D(256, 3, activation='relu', padding='same', dilation_rate=12)(x); c12 = layers.BatchNormalization()(c12)
        c18 = layers.Conv2D(256, 3, activation='relu', padding='same', dilation_rate=18)(x); c18 = layers.BatchNormalization()(c18)

        y = layers.Concatenate()([pool, conv1, c6, c12, c18])
        y = layers.Conv2D(256, 1, activation='relu', padding='same')(y)
        y = layers.BatchNormalization()(y)
        return y

    inputs = keras.Input(shape=input_shape)

    x = layers.Conv2D(32, 3, strides=2, padding='same', activation='relu')(inputs)  # 256->128
    x = layers.BatchNormalization()(x)
    x = layers.Conv2D(64, 3, padding='same', activation='relu')(x)                  # 128->128
    x = layers.BatchNormalization()(x)
    low = x

    x = layers.Conv2D(128, 3, strides=2, padding='same', activation='relu')(x)      # 128->64
    x = layers.BatchNormalization()(x)
    x = layers.Conv2D(256, 3, padding='same', activation='relu')(x)                 # 64->64
    x = layers.BatchNormalization()(x)

    x = layers.Conv2D(512, 3, strides=2, padding='same', activation='relu')(x)      # 64->32
    x = layers.BatchNormalization()(x)                                              # 32x32

    x = aspp(x)                                                                     # 32x32
    x = layers.UpSampling2D(size=(4, 4), interpolation='bilinear')(x)               # 32->128
    low = layers.Conv2D(48, 1, activation='relu', padding='same')(low); low = layers.BatchNormalization()(low)
    x = layers.Concatenate()([x, low])                                              # 128
    x = layers.Conv2D(256, 3, activation='relu', padding='same')(x); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(256, 3, activation='relu', padding='same')(x); x = layers.BatchNormalization()(x)
    x = layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(x)               # 128->256

    outputs = layers.Conv2D(num_classes, 1, activation='softmax')(x)
    return keras.Model(inputs, outputs, name="DeepLabV3Plus")

# 3) FCN
def build_fcn(input_shape, num_classes):
    inputs = keras.Input(shape=input_shape)
    x = layers.Conv2D(64,3,padding='same',activation='relu')(inputs)
    x = layers.Conv2D(64,3,padding='same',activation='relu')(x)
    p1 = layers.MaxPooling2D(2)(x)   # 256->128
    x = layers.Conv2D(128,3,padding='same',activation='relu')(p1)
    x = layers.Conv2D(128,3,padding='same',activation='relu')(x)
    p2 = layers.MaxPooling2D(2)(x)   # 128->64
    x = layers.Conv2D(256,3,padding='same',activation='relu')(p2)
    x = layers.Conv2D(256,3,padding='same',activation='relu')(x)
    x = layers.Conv2D(256,3,padding='same',activation='relu')(x)
    p3 = layers.MaxPooling2D(2)(x)   # 64->32
    x = layers.Conv2D(512,3,padding='same',activation='relu')(p3)
    x = layers.Conv2D(512,3,padding='same',activation='relu')(x)
    x = layers.Conv2D(512,3,padding='same',activation='relu')(x)
    p4 = layers.MaxPooling2D(2)(x)   # 32->16
    x = layers.Conv2D(512,3,padding='same',activation='relu')(p4)
    x = layers.Conv2D(512,3,padding='same',activation='relu')(x)
    x = layers.Conv2D(512,3,padding='same',activation='relu')(x)
    p5 = layers.MaxPooling2D(2)(x)   # 16->8

    x = layers.Conv2D(4096,7,padding='same',activation='relu')(p5); x = layers.Dropout(0.5)(x)
    x = layers.Conv2D(4096,1,activation='relu')(x); x = layers.Dropout(0.5)(x)

    s5 = layers.Conv2D(num_classes,1)(x)
    s4 = layers.Conv2D(num_classes,1)(p4)
    s3 = layers.Conv2D(num_classes,1)(p3)

    up2 = layers.Conv2DTranspose(num_classes,4,strides=2,padding='same')(s5)  # 8->16
    f4  = layers.Add()([up2, s4])
    up4 = layers.Conv2DTranspose(num_classes,4,strides=2,padding='same')(f4)  # 16->32
    f3  = layers.Add()([up4, s3])
    outputs = layers.Conv2DTranspose(num_classes,16,strides=8,padding='same',activation='softmax')(f3)  # 32->256
    return keras.Model(inputs, outputs, name="FCN")

# 4) SegNet
def build_segnet(input_shape, num_classes):
    inputs = keras.Input(shape=input_shape)
    x = layers.Conv2D(64,3,padding='same',activation='relu')(inputs); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(64,3,padding='same',activation='relu')(x);     x = layers.BatchNormalization()(x)
    x = layers.MaxPooling2D(2)(x)  # 256->128

    x = layers.Conv2D(128,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(128,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.MaxPooling2D(2)(x)  # 128->64

    x = layers.Conv2D(256,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(256,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(256,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.MaxPooling2D(2)(x)  # 64->32

    x = layers.Conv2D(512,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(512,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(512,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.MaxPooling2D(2)(x)  # 32->16

    x = layers.UpSampling2D(2)(x)  # 16->32
    x = layers.Conv2D(512,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(512,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(512,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)

    x = layers.UpSampling2D(2)(x)  # 32->64
    x = layers.Conv2D(256,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(256,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(256,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)

    x = layers.UpSampling2D(2)(x)  # 64->128
    x = layers.Conv2D(128,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(128,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)

    x = layers.UpSampling2D(2)(x)  # 128->256
    x = layers.Conv2D(64,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(64,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)

    outputs = layers.Conv2D(num_classes,1,activation='softmax')(x)
    return keras.Model(inputs, outputs, name="SegNet")

# 5) BiSeNetV2 (compact; main head only) — fixed ContextEmbedding (no Lambda)
def build_bisenetv2(input_shape, num_classes):
    inputs = keras.Input(shape=input_shape)
    def ConvBNReLU(x, f, k=3, s=1):
        x = layers.Conv2D(f, k, strides=s, padding='same', use_bias=False)(x)
        x = layers.BatchNormalization()(x)
        return layers.ReLU()(x)
    def DWConvBNReLU(x, k=3, s=1):
        x = layers.DepthwiseConv2D(k, strides=s, padding='same', use_bias=False)(x)
        x = layers.BatchNormalization()(x)
        return layers.ReLU()(x)
    def DetailBranch(x):
        x = ConvBNReLU(x, 64, 3, 2); x = ConvBNReLU(x, 64, 3, 1); x = ConvBNReLU(x, 64, 3, 1)
        x = ConvBNReLU(x, 64, 3, 2); x = ConvBNReLU(x, 64, 3, 1); x = ConvBNReLU(x, 64, 3, 1)
        x = ConvBNReLU(x, 128, 3, 2); x = ConvBNReLU(x, 128, 3, 1); x = ConvBNReLU(x, 128, 3, 1)
        return x  # /8
    def StemBlock(x):
        x = ConvBNReLU(x, 16, 3, 2); x = DWConvBNReLU(x, 3, 1); x = ConvBNReLU(x, 16, 1, 1)
        x = ConvBNReLU(x, 32, 3, 2); x = DWConvBNReLU(x, 3, 1); x = ConvBNReLU(x, 32, 1, 1)
        return x  # /4
    def GEBlock(x, out_ch, stride):
        in_ch = x.shape[-1]
        y = DWConvBNReLU(x, 3, stride); y = ConvBNReLU(y, out_ch, 1, 1)
        if stride == 1 and in_ch == out_ch: y = layers.Add()([x, y])
        return y
    def ContextEmbedding(x, ch=128):
        h = layers.GlobalAveragePooling2D(keepdims=True)(x); h = layers.BatchNormalization()(h); h = ConvBNReLU(h, ch, 1, 1)
        H = input_shape[0] // 16; W = input_shape[1] // 16
        h = layers.UpSampling2D(size=(H, W), interpolation='bilinear')(h)
        y = layers.Add()([x, h]); y = ConvBNReLU(y, ch, 3, 1)
        return y
    def SemanticBranch(x):
        x = StemBlock(x); x = GEBlock(x, 64, 2); x = GEBlock(x, 64, 1)
        x = GEBlock(x, 128, 2); x = GEBlock(x, 128, 1); x = GEBlock(x, 128, 1)
        x = ContextEmbedding(x, 128); x = layers.UpSampling2D(size=2, interpolation='bilinear')(x)
        return x
    def FeatureFusion(detail, semantic, out_ch=256):
        x = layers.Concatenate()([detail, semantic]); trunk = ConvBNReLU(x, out_ch, 3, 1)
        att = layers.GlobalAveragePooling2D(keepdims=True)(trunk); att = ConvBNReLU(att, out_ch // 4, 1, 1)
        att = layers.Conv2D(out_ch, 1, activation='sigmoid', padding='same')(att)
        out = layers.Multiply()([trunk, att]); out = layers.Add()([trunk, out]); return out
    def SegHead(x, num_classes, up_factor=8):
        x = ConvBNReLU(x, 128, 3, 1); x = layers.Conv2D(num_classes, 1, padding='same', activation='softmax')(x)
        x = layers.UpSampling2D(size=up_factor, interpolation='bilinear')(x); return x
    db = DetailBranch(inputs); sb = SemanticBranch(inputs); fused = FeatureFusion(db, sb, out_ch=256)
    outputs = SegHead(fused, num_classes, up_factor=8); return keras.Model(inputs, outputs, name="BiSeNetV2")

# ============ Loss and Metrics ============
SMOOTH = 1e-6

def _resize_to_label(y_pred, y_true):
    ph = tf.shape(y_pred)[1]; pw = tf.shape(y_pred)[2]
    th = tf.shape(y_true)[1]; tw = tf.shape(y_true)[2]
    need = tf.logical_or(tf.not_equal(ph, th), tf.not_equal(pw, tw))
    def _do(): return tf.image.resize(y_pred, (th, tw), method='bilinear')
    return tf.cond(need, _do, lambda: y_pred)

def weighted_ce(y_true, y_pred):
    y_pred = _resize_to_label(y_pred, y_true)
    w = tf.reduce_sum(CLASS_WEIGHTS * y_true, axis=-1)                 # [B,H,W]
    ce = tf.keras.losses.categorical_crossentropy(y_true, y_pred)      # [B,H,W]
    return tf.reduce_mean(ce * w)

def focal_tversky_loss(y_true, y_pred, alpha=0.7, beta=0.3, gamma=0.75, exclude_bg=True):
    y_pred = _resize_to_label(y_pred, y_true)
    if exclude_bg:
        y_true = y_true[...,1:]; y_pred = y_pred[...,1:]
    y_true_f = tf.reshape(y_true, [-1, tf.shape(y_true)[-1]])
    y_pred_f = tf.reshape(y_pred, [-1, tf.shape(y_pred)[-1]])
    tp = tf.reduce_sum(y_true_f * y_pred_f, axis=0)
    fp = tf.reduce_sum((1. - y_true_f) * y_pred_f, axis=0)
    fn = tf.reduce_sum(y_true_f * (1. - y_pred_f), axis=0)
    t = (tp + SMOOTH) / (tp + alpha*fn + beta*fp + SMOOTH)
    ft = tf.pow(1. - t, gamma)
    return tf.reduce_mean(ft)

def combo_loss_stronger(y_true, y_pred, alpha=0.5):
    return alpha * weighted_ce(y_true, y_pred) + (1.0 - alpha) * focal_tversky_loss(y_true, y_pred)

@tf.function
def iou_no_bg(y_true, y_pred):
    y_pred = _resize_to_label(y_pred, y_true)
    y_true_cls = tf.argmax(y_true, axis=-1); y_pred_cls = tf.argmax(y_pred, axis=-1)
    y_true_oh = tf.one_hot(y_true_cls, depth=NUM_CLASSES, dtype=tf.float32)
    y_pred_oh = tf.one_hot(y_pred_cls, depth=NUM_CLASSES, dtype=tf.float32)
    y_true_f = tf.reshape(y_true_oh, [-1, NUM_CLASSES]); y_pred_f = tf.reshape(y_pred_oh, [-1, NUM_CLASSES])
    inter = tf.reduce_sum(y_true_f * y_pred_f, axis=0)
    union = tf.reduce_sum(y_true_f + y_pred_f - y_true_f * y_pred_f, axis=0)
    inter_nb = inter[1:]; union_nb = union[1:]
    iou = tf.where(union_nb > 0.0, inter_nb / (union_nb + 1e-7), 0.0)
    return tf.reduce_mean(iou)

# ============ Training / Evaluation Helpers ============
def compile_model(model, lr=1e-3):
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=lr),
        loss=combo_loss_stronger,
        metrics=[iou_no_bg, 'accuracy']
    )
    return model

def plot_history(hist, title, outdir):
    plt.figure(figsize=(10,4))
    # loss
    plt.subplot(1,2,1); plt.plot(hist.history['loss'], label='train')
    if 'val_loss' in hist.history: plt.plot(hist.history['val_loss'], label='val')
    plt.title(f'{title} - Loss'); plt.xlabel('Epoch'); plt.ylabel('Loss'); plt.legend()
    # IoU
    if 'iou_no_bg' in hist.history:
        plt.subplot(1,2,2); plt.plot(hist.history['iou_no_bg'], label='train IoU')
        if 'val_iou_no_bg' in hist.history: plt.plot(hist.history['val_iou_no_bg'], label='val IoU')
        plt.title(f'{title} - IoU(no-bg)'); plt.xlabel('Epoch'); plt.ylabel('IoU'); plt.legend()
    plt.tight_layout(); os.makedirs(outdir, exist_ok=True)
    p = os.path.join(outdir, f'{title}_curves.png')
    plt.savefig(p, dpi=140, bbox_inches='tight'); plt.close()
    return p

def compute_auto_class_weights(Y_int, num_classes):
    counts = np.bincount(Y_int.flatten(), minlength=num_classes).astype(np.float64)
    p = counts / max(1, counts.sum())
    w = 1.0 / np.log(1.02 + p + 1e-12)  # inverse log frequency
    w[0] = max(w[0]*0.5, 0.25)          # don't over-weight background
    return w

# ============ Main ============
def main():
    t0 = time.time()
    print("🔎 Loading dataset...")
    ds = AppleLeafDataset(BASE_DIR, IMG_SIZE)
    X, Y = ds.load()

    # simple split (≈70/15/15)
    Xtr, Xte, Ytr, Yte = train_test_split(X, Y, test_size=0.15, random_state=SEED, shuffle=True)
    Xtr, Xva, Ytr, Yva = train_test_split(Xtr, Ytr, test_size=0.1765, random_state=SEED, shuffle=True)

    print(f"Shapes -> Train: {Xtr.shape}, Val: {Xva.shape}, Test: {Xte.shape}")

    # Optional: auto class weights from train
    if AUTO_CLASS_WEIGHTS:
        global CLASS_WEIGHTS
        CLASS_WEIGHTS = tf.constant(compute_auto_class_weights(Ytr, NUM_CLASSES), dtype=tf.float32)
        print("Class weights (auto):", CLASS_WEIGHTS.numpy())

    train_ds = make_dataset(Xtr, Ytr, BATCH_SIZE, shuffle=True, augment=True)
    val_ds   = make_dataset(Xva, Yva, BATCH_SIZE, shuffle=False, augment=False)
    test_ds  = make_dataset(Xte, Yte, BATCH_SIZE, shuffle=False, augment=False)

    input_shape = (IMG_SIZE, IMG_SIZE, 3)

    # Name -> builder (order preserved)
    builders = {
        'UNet_MobileNetV2': build_unet_mobilenetv2,   # pretrained + fine-tune
        'UNet': build_unet,
        'DeepLabV3Plus': build_deeplabv3plus,
        'FCN': build_fcn,
        'SegNet': build_segnet,
        'BiSeNetV2': build_bisenetv2
    }

    # decide which to run
    model_names = list(builders.keys())
    if START_AT and START_AT in model_names:
        model_names = model_names[model_names.index(START_AT):]
    if RUN_ONLY:
        model_names = [n for n in model_names if n in RUN_ONLY]

    results = []
    os.makedirs(OUTDIR, exist_ok=True)

    for name in model_names:
        tf.keras.backend.clear_session()
        ckpt_path = os.path.join(OUTDIR, f"{name}_best.keras")

        if SKIP_TRAIN_IF_CKPT and os.path.exists(ckpt_path):
            print(f"⏭️  {name}: checkpoint found -> skipping training, loading for eval")
            try:
                model = keras.models.load_model(ckpt_path, compile=False)
                compile_model(model, lr=1e-3)
            except Exception as e:
                print(f"⚠️  Failed to load {name} checkpoint: {e}")
                print(f"🔄 Re-training {name} instead...")
                model = builders[name](input_shape, NUM_CLASSES)
                # standard single-phase training below
                compile_model(model, lr=1e-3)
                cbs = [
                    keras.callbacks.ModelCheckpoint(ckpt_path, monitor='val_iou_no_bg', mode='max',
                                                    save_best_only=True, save_weights_only=False, verbose=1),
                    keras.callbacks.EarlyStopping(monitor='val_iou_no_bg', mode='max',
                                                  patience=6, restore_best_weights=True),
                    keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3,
                                                      min_lr=1e-5, verbose=1)
                ]
                hist = model.fit(train_ds, validation_data=val_ds,
                                 epochs=EPOCHS, verbose=1, callbacks=cbs)
                _ = plot_history(hist, name, OUTDIR)
        else:
            print(f"\n🚀 Training {name} ...")
            model = builders[name](input_shape, NUM_CLASSES)

            if hasattr(model, "_encoder"):
                # Phase 1: freeze encoder
                model._encoder.trainable = False
                compile_model(model, lr=1e-3)
                hist1 = model.fit(train_ds, validation_data=val_ds, epochs=max(2, EPOCHS//2), verbose=1)
                # Phase 2: unfreeze encoder and fine-tune with lower LR
                model._encoder.trainable = True
                compile_model(model, lr=3e-4)
                cbs = [
                    keras.callbacks.ModelCheckpoint(ckpt_path, monitor='val_iou_no_bg', mode='max',
                                                    save_best_only=True, save_weights_only=False, verbose=1),
                    keras.callbacks.EarlyStopping(monitor='val_iou_no_bg', mode='max',
                                                  patience=6, restore_best_weights=True),
                    keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3,
                                                      min_lr=1e-5, verbose=1)
                ]
                hist2 = model.fit(train_ds, validation_data=val_ds,
                                  initial_epoch=max(2, EPOCHS//2), epochs=EPOCHS, verbose=1, callbacks=cbs)
                _ = plot_history(hist2, name, OUTDIR)
            else:
                compile_model(model, lr=1e-3)
                cbs = [
                    keras.callbacks.ModelCheckpoint(ckpt_path, monitor='val_iou_no_bg', mode='max',
                                                    save_best_only=True, save_weights_only=False, verbose=1),
                    keras.callbacks.EarlyStopping(monitor='val_iou_no_bg', mode='max',
                                                  patience=6, restore_best_weights=True),
                    keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3,
                                                      min_lr=1e-5, verbose=1)
                ]
                hist = model.fit(train_ds, validation_data=val_ds,
                                 epochs=EPOCHS, verbose=1, callbacks=cbs)
                _ = plot_history(hist, name, OUTDIR)

        # Evaluate
        print(f"📏 Evaluating {name} ...")
        val_metrics  = model.evaluate(val_ds,  verbose=0)
        test_metrics = model.evaluate(test_ds, verbose=0)
        res = {
            "Model": name,
            "Val Loss": float(val_metrics[0]),
            "Val IoU(no-bg)": float(val_metrics[1]),
            "Val Acc": float(val_metrics[2]),
            "Test Loss": float(test_metrics[0]),
            "Test IoU(no-bg)": float(test_metrics[1]),
            "Test Acc": float(test_metrics[2]),
        }
        results.append(res)

        # Visuals: single-sample rows + grid panel with severity %
        _ = visualize_with_severity(model, Xva, Yva, n=4, outdir=OUTDIR)
        _ = visualize_grid_with_severity(model, Xva, Yva, k=6, outdir=OUTDIR,
                                         title=f"{name} — Dice+CE + Focal-Tversky, Augmented")

    # Save/append comparison CSV
    cmp_csv = os.path.join(OUTDIR, "model_comparison.csv")
    df_new = pd.DataFrame(results)
    if os.path.exists(cmp_csv):
        try:
            df_old = pd.read_csv(cmp_csv)
            df = (pd.concat([df_old, df_new], ignore_index=True)
                    .drop_duplicates(subset=['Model'], keep='last'))
        except Exception:
            df = df_new
    else:
        df = df_new
    df = df.sort_values("Val IoU(no-bg)", ascending=False)
    df.to_csv(cmp_csv, index=False)
    print("\n================ Model Comparison (by Val IoU no-bg) ================")
    print(df.to_string(index=False))
    print(f"\n💾 Saved: {cmp_csv}")
    print(f"✅ Done in {time.time()-t0:.1f}s")

if __name__ == "__main__":
    main()


⚠️  GPU mem-growth not set: Physical devices cannot be modified after being initialized
🔎 Loading dataset...
📂 Alternaria leaf spot | imgs: 278 | masks: 278 | paired: 278
📂 Brown spot           | imgs: 215 | masks: 215 | paired: 215
📂 Gray spot            | imgs: 395 | masks: 395 | paired: 395
📂 Healthy leaf         | imgs: 409 | masks: 409 | paired: 409
📂 Rust                 | imgs: 344 | masks: 344 | paired: 344
✅ Paired samples: 1641


Loading data: 100%|██████████| 1641/1641 [00:22<00:00, 73.25it/s]


Shapes -> Train: (1147, 256, 256, 3), Val: (247, 256, 256, 3), Test: (247, 256, 256, 3)
Class weights (auto): [ 0.9594292  3.3929112 41.759647  48.69141   47.26587   37.653244 ]

🚀 Training UNet_MobileNetV2 ...


  base = tf.keras.applications.MobileNetV2(


Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224_no_top.h5
[1m9406464/9406464[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step
Epoch 1/6


E0000 00:00:1756621955.780273      99 gpu_timer.cc:82] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0000 00:00:1756621955.933191      99 gpu_timer.cc:82] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.


[1m143/144[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 74ms/step - accuracy: 0.7091 - iou_no_bg: 0.1901 - loss: 1.5009

E0000 00:00:1756621983.011304      98 gpu_timer.cc:82] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0000 00:00:1756621983.155204      98 gpu_timer.cc:82] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.


[1m144/144[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 176ms/step - accuracy: 0.7101 - iou_no_bg: 0.1905 - loss: 1.4982

E0000 00:00:1756622005.574143      98 gpu_timer.cc:82] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0000 00:00:1756622005.719288      98 gpu_timer.cc:82] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.


[1m144/144[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m79s[0m 288ms/step - accuracy: 0.7111 - iou_no_bg: 0.1909 - loss: 1.4956 - val_accuracy: 0.5633 - val_iou_no_bg: 0.0420 - val_loss: 4.5434
Epoch 2/6
[1m144/144[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 81ms/step - accuracy: 0.9323 - iou_no_bg: 0.3389 - loss: 0.7989 - val_accuracy: 0.9545 - val_iou_no_bg: 0.3683 - val_loss: 0.7648
Epoch 3/6
[1m144/144[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 81ms/step - accuracy: 0.9412 - iou_no_bg: 0.4042 - loss: 0.6354 - val_accuracy: 0.9561 - val_iou_no_bg: 0.4431 - val_loss: 0.5563
Epoch 4/6
[1m144/144[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 80ms/step - accuracy: 0.9484 - iou_no_bg: 0.4184 - loss: 0.5920 - val_accuracy: 0.9502 - val_iou_no_bg: 0.4218 - val_loss: 0.5043
Epoch 5/6
[1m144/144[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 80ms/step - accuracy: 0.9476 - iou_no_bg: 0.4232 - loss: 0.5871 - val_accuracy: 0.9550 - val_iou_no_bg: 

E0000 00:00:1756622291.892091      96 gpu_timer.cc:82] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0000 00:00:1756622292.044274      96 gpu_timer.cc:82] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.
E0000 00:00:1756622292.176930      96 gpu_timer.cc:82] Delay kernel timed out: measured time has sub-optimal accuracy. There may be a missing warmup execution, please investigate in Nsight Systems.


🖼️  Saved severity visualization: outputs/viz_with_severity_UNet_MobileNetV2.png
🖼️  Saved grid panel: outputs/panel_severity_UNet_MobileNetV2.png
⏭️  UNet: checkpoint found -> skipping training, loading for eval
📏 Evaluating UNet ...
🖼️  Saved severity visualization: outputs/viz_with_severity_UNet.png
🖼️  Saved grid panel: outputs/panel_severity_UNet.png
⏭️  DeepLabV3Plus: checkpoint found -> skipping training, loading for eval
📏 Evaluating DeepLabV3Plus ...
🖼️  Saved severity visualization: outputs/viz_with_severity_DeepLabV3Plus.png
🖼️  Saved grid panel: outputs/panel_severity_DeepLabV3Plus.png
⏭️  FCN: checkpoint found -> skipping training, loading for eval
📏 Evaluating FCN ...
🖼️  Saved severity visualization: outputs/viz_with_severity_FCN.png
🖼️  Saved grid panel: outputs/panel_severity_FCN.png
⏭️  SegNet: checkpoint found -> skipping training, loading for eval
📏 Evaluating SegNet ...
🖼️  Saved severity visualization: outputs/viz_with_severity_SegNet.png
🖼️  Saved grid panel: out

In [4]:
# =========================================================
# Apple Leaf Segmentation — Lift IoU Accuracy
# Strategies included:
#  • Better loss: Weighted CE + Focal-Tversky (alpha=0.8, beta=0.2, gamma=0.9)
#  • Auto class weights from train masks
#  • Disease-aware oversampling for training
#  • Larger input option (progressive-ready): set IMG_SIZE to 384/512 if you can
#  • Strong pretrained model: U-Net with MobileNetV2 encoder (fine-tune in 2 phases)
#  • Aug tweaks: allow tighter crops to zoom into lesions
#  • Metrics: overall IoU(no-bg) + Disease-only mIoU for monitoring
#  • Visuals: Image | GT | Pred + per-class severity percentages, and 3-row grid
# =========================================================
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
os.environ["TF_DISABLE_PROFILER"] = "1"

import cv2
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import pandas as pd
import time
import random

# ============ Config ============
BASE_DIR    = "/kaggle/input/apple-dataset/ATLDSD"   # <--- change if needed
IMG_SIZE    = 256   # Try 384 or 512 for higher IoU if GPU allows
BATCH_SIZE  = 8
EPOCHS      = 40    # train longer for better convergence
SEED        = 2025

OUTDIR = "outputs"
SKIP_TRAIN_IF_CKPT = True
RUN_ONLY = ['UNet_MobileNetV2']   # focus on the strongest model first
START_AT = None

# Training knobs
AUTO_CLASS_WEIGHTS    = True   # infer from Y_train
USE_TTA_IN_VIZ        = True   # nicer visuals; eval uses built-in metrics
OVERSAMPLE_DISEASE    = True   # oversample disease-rich images in train
TOLERANT_LABEL_COLORS = True   # handle slight RGB variations in masks

random.seed(SEED); np.random.seed(SEED); tf.random.set_seed(SEED)

CLASS_NAMES = [
    "Background",           # 0
    "Healthy",              # 1
    "Brown spot",           # 2
    "Alternaria leaf spot", # 3
    "Gray spot",            # 4
    "Rust"                  # 5
]
NUM_CLASSES = len(CLASS_NAMES)

COLOR_MAP = {
    (0,   0,   0): 0,   # Background
    (128, 0,   0): 1,   # Healthy
    (128, 0, 128): 2,   # Brown spot
    (128,128,  0): 3,   # Alternaria
    (0,   0, 128): 4,   # Gray
    (0, 128,   0): 5,   # Rust
}

# default weights (overridden if AUTO_CLASS_WEIGHTS)
CLASS_WEIGHTS = tf.constant([0.25, 0.7, 1.1, 1.1, 1.1, 1.2], dtype=tf.float32)

# Augmentation knobs (more zoom-in to highlight lesions)
A_ROT90_PROB   = 0.75
A_FLIP_H_PROB  = 0.5
A_FLIP_V_PROB  = 0.5
A_JITTER_PROB  = 0.6
A_NOISE_PROB   = 0.3
A_CROP_PROB    = 0.7
CROP_MIN_FRAC  = 0.65   # was 0.85 -> more aggressive zooms

# ============ Utils ============
def set_gpu_growth():
    try:
        gpus = tf.config.list_physical_devices('GPU')
        if gpus:
            for g in gpus:
                tf.config.experimental.set_memory_growth(g, True)
            print(f"✅ GPU found: {len(gpus)}; memory growth enabled")
        else:
            print("ℹ️  No GPU detected; running on CPU")
    except Exception as e:
        print("⚠️  GPU mem-growth not set:", e)
set_gpu_growth()

def rgb_mask_to_classes(mask_rgb, tol=10):
    """
    Map RGB to class ids; if TOLERANT_LABEL_COLORS=True, allow a tolerance on RGB values.
    """
    if not TOLERANT_LABEL_COLORS:
        out = np.zeros(mask_rgb.shape[:2], dtype=np.uint8)
        R, G, B = mask_rgb[...,0], mask_rgb[...,1], mask_rgb[...,2]
        for (r,g,b), cls in COLOR_MAP.items():
            m = (R == r) & (G == g) & (B == b)
            out[m] = cls
        return out
    else:
        out = np.zeros(mask_rgb.shape[:2], dtype=np.uint8)
        for (r,g,b), cls in COLOR_MAP.items():
            diff = np.abs(mask_rgb - np.array([r,g,b], np.uint8))
            m = (diff[...,0] <= tol) & (diff[...,1] <= tol) & (diff[...,2] <= tol)
            out[m] = cls
        return out

PALETTE = {
    0:(0,0,0), 1:(128,0,0), 2:(128,0,128), 3:(128,128,0), 4:(0,0,128), 5:(0,128,0)
}
def mask_to_color(mask):
    h,w = mask.shape
    out = np.zeros((h,w,3), dtype=np.uint8)
    for c, col in PALETTE.items():
        out[mask==c] = col
    return out

# ===================== Severity Utils & Visualization =====================
def compute_severity_percentages(mask_int):
    m = np.asarray(mask_int, dtype=np.uint8)
    leaf = (m != 0)
    leaf_pixels = int(leaf.sum())
    if leaf_pixels == 0:
        per_class = {CLASS_NAMES[c]: 0.0 for c in range(2, NUM_CLASSES)}
        healthy_pct = 0.0
        disease_total = 0.0
        return per_class, healthy_pct, disease_total

    per_class = {}
    for c in range(2, NUM_CLASSES):
        per_class[CLASS_NAMES[c]] = 100.0 * float((m == c).sum()) / leaf_pixels
    healthy_pct = 100.0 * float((m == 1).sum()) / leaf_pixels
    disease_total = 100.0 - healthy_pct
    return per_class, healthy_pct, disease_total

def _box_text_from_severity(per_class, healthy_pct, disease_total):
    lines = [f"Healthy: {healthy_pct:5.1f}%",
             f"Total disease: {disease_total:5.1f}%"]
    for name, pct in per_class.items():
        lines.append(f"{name}: {pct:5.1f}%")
    return "\n".join(lines)

def predict_prob_tta(model, img):
    imgs = [img,
            np.flip(img, axis=1),
            np.flip(img, axis=0),
            np.rot90(img, k=1)]
    probs = []
    for im in imgs:
        p = model.predict(im[None], verbose=0)[0]
        probs.append(p)
    probs[1] = np.flip(probs[1], axis=1)
    probs[2] = np.flip(probs[2], axis=0)
    probs[3] = np.rot90(probs[3], k=3)
    return np.mean(probs, axis=0)

def small_component_cleanup(mask, min_frac=0.001):
    H, W = mask.shape
    min_area = max(1, int(H*W*min_frac))
    out = mask.copy()
    for c in range(1, NUM_CLASSES):
        m = (out == c).astype(np.uint8)
        if m.sum() == 0: continue
        num, labels = cv2.connectedComponents(m, connectivity=8)
        for lab in range(1, num):
            area = int((labels == lab).sum())
            if area < min_area:
                out[labels == lab] = 1  # send tiny islands to 'Healthy'
    return out

def visualize_with_severity(model, Xv, Yv_int, n=4, outdir="outputs", seed=2025):
    np.random.seed(seed)
    os.makedirs(outdir, exist_ok=True)
    idx = np.random.choice(len(Xv), size=min(n, len(Xv)), replace=False)
    fig, axs = plt.subplots(len(idx), 3, figsize=(11, 3.6*len(idx)))
    if len(idx) == 1: axs = np.expand_dims(axs, 0)

    for r, i in enumerate(idx):
        img = Xv[i]
        gt  = Yv_int[i].astype(np.uint8)
        pr  = predict_prob_tta(model, img) if USE_TTA_IN_VIZ else model.predict(img[None], verbose=0)[0]
        pm  = np.argmax(pr, axis=-1).astype(np.uint8)
        pm  = small_component_cleanup(pm, min_frac=0.001)

        gt_per, gt_h, gt_d = compute_severity_percentages(gt)
        pr_per, pr_h, pr_d = compute_severity_percentages(pm)

        axs[r,0].imshow(img); axs[r,0].set_title("Image"); axs[r,0].axis('off')

        axs[r,1].imshow(mask_to_color(gt)); axs[r,1].set_title("Ground Truth"); axs[r,1].axis('off')
        axs[r,1].text(0.02, 0.98, _box_text_from_severity(gt_per, gt_h, gt_d),
                      transform=axs[r,1].transAxes, va='top', ha='left', fontsize=9,
                      bbox=dict(facecolor='white', alpha=0.75, edgecolor='black', boxstyle='round,pad=0.4'))

        axs[r,2].imshow(mask_to_color(pm)); axs[r,2].set_title(f"Predicted ({model.name})"); axs[r,2].axis('off')
        axs[r,2].text(0.02, 0.98, _box_text_from_severity(pr_per, pr_h, pr_d),
                      transform=axs[r,2].transAxes, va='top', ha='left', fontsize=9,
                      bbox=dict(facecolor='white', alpha=0.75, edgecolor='black', boxstyle='round,pad=0.4'))
    plt.tight_layout()
    p = os.path.join(outdir, f"viz_with_severity_{model.name}.png")
    plt.savefig(p, dpi=150, bbox_inches='tight'); plt.close()
    print(f"🖼️  Saved severity visualization: {p}")
    return p

def visualize_grid_with_severity(model, Xv, Yv_int, k=6, outdir="outputs",
                                 title=None, seed=2025):
    np.random.seed(seed)
    os.makedirs(outdir, exist_ok=True)
    idx = np.random.choice(len(Xv), size=min(k, len(Xv)), replace=False)
    cols = len(idx); rows = 3
    fig, axs = plt.subplots(rows, cols, figsize=(cols*3.2, rows*3.2))
    if cols == 1: axs = np.expand_dims(axs, 1)

    for c, i in enumerate(idx):
        img = Xv[i]; gt = Yv_int[i].astype(np.uint8)
        pr = predict_prob_tta(model, img) if USE_TTA_IN_VIZ else model.predict(img[None], verbose=0)[0]
        pm = np.argmax(pr, axis=-1).astype(np.uint8)
        pm = small_component_cleanup(pm, min_frac=0.001)

        axs[0, c].imshow(img); axs[0, c].set_title(f"Image {c+1}", fontsize=11); axs[0, c].axis('off')

        axs[1, c].imshow(mask_to_color(gt)); axs[1, c].set_title("GT", fontsize=10); axs[1, c].axis('off')
        gt_per, gt_h, gt_d = compute_severity_percentages(gt)
        axs[1, c].text(0.01, 0.99, _box_text_from_severity(gt_per, gt_h, gt_d),
                       transform=axs[1, c].transAxes, va='top', ha='left', fontsize=8,
                       bbox=dict(facecolor='white', alpha=0.78, edgecolor='black', boxstyle='round,pad=0.3'))

        axs[2, c].imshow(mask_to_color(pm)); axs[2, c].set_title("Pred", fontsize=10); axs[2, c].axis('off')
        pr_per, pr_h, pr_d = compute_severity_percentages(pm)
        axs[2, c].text(0.01, 0.99, _box_text_from_severity(pr_per, pr_h, pr_d),
                       transform=axs[2, c].transAxes, va='top', ha='left', fontsize=8,
                       bbox=dict(facecolor='white', alpha=0.78, edgecolor='black', boxstyle='round,pad=0.3'))
    if title is None:
        title = f"{model.name} — WeightedCE + Focal-Tversky"
    fig.suptitle(title, fontsize=14)
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    p = os.path.join(outdir, f"panel_severity_{model.name}.png")
    plt.savefig(p, dpi=150, bbox_inches='tight'); plt.close()
    print(f"🖼️  Saved grid panel: {p}")
    return p

# ============ Dataset ============
class AppleLeafDataset:
    IMG_EXTS = (".png",".jpg",".jpeg",".bmp",".tif",".tiff")
    def __init__(self, base_dir, image_size=256):
        self.base_dir = base_dir
        self.image_size = image_size
        self.image_paths, self.mask_paths = self._discover_pairs()
        print(f"✅ Paired samples: {len(self.image_paths)}")

    def _list_images(self, d):
        acc = []
        for r,_,fs in os.walk(d):
            for f in fs:
                if f.lower().endswith(self.IMG_EXTS):
                    acc.append(os.path.join(r,f))
        return acc

    def _discover_pairs(self):
        imgs, msks = [], []
        if not os.path.exists(self.base_dir):
            raise FileNotFoundError(f"Base dir not found: {self.base_dir}")
        for cls_folder in sorted(os.listdir(self.base_dir)):
            cpath = os.path.join(self.base_dir, cls_folder)
            if not os.path.isdir(cpath): continue
            img_dir = os.path.join(cpath, "image")
            msk_dir = os.path.join(cpath, "label")
            if not (os.path.exists(img_dir) and os.path.exists(msk_dir)): continue

            img_files = self._list_images(img_dir)
            msk_files = self._list_images(msk_dir)
            img_by = {os.path.splitext(os.path.basename(p))[0].lower(): p for p in img_files}
            msk_by = {os.path.splitext(os.path.basename(p))[0].lower(): p for p in msk_files}
            common = sorted(set(img_by) & set(msk_by))
            print(f"📂 {cls_folder:20} | imgs:{len(img_files):4d} | masks:{len(msk_files):4d} | paired:{len(common):4d}")
            for s in common: imgs.append(img_by[s]); msks.append(msk_by[s])
        return imgs, msks

    def load(self):
        X, Y = [], []
        for ip, mp in tqdm(list(zip(self.image_paths, self.mask_paths)), desc="Loading data"):
            img = cv2.imread(ip, cv2.IMREAD_COLOR)
            if img is None: continue
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = cv2.resize(img, (self.image_size, self.image_size), interpolation=cv2.INTER_AREA)
            img = img.astype(np.float32)/255.0

            msk = cv2.imread(mp, cv2.IMREAD_COLOR)
            if msk is None: continue
            msk = cv2.cvtColor(msk, cv2.COLOR_BGR2RGB)
            msk = cv2.resize(msk, (self.image_size, self.image_size), interpolation=cv2.INTER_NEAREST)
            msk = rgb_mask_to_classes(msk)

            X.append(img); Y.append(msk)
        return np.asarray(X, np.float32), np.asarray(Y, np.uint8)

# ============ Augmentation (tf.data) ============
def augment_img_mask(img, mask):
    img = tf.cast(img, tf.float32)
    mask = tf.cast(mask, tf.int32)
    mask_3d = tf.expand_dims(mask, axis=-1)

    def apply_transform(transform_func, prob):
        return tf.cond(
            tf.random.uniform([]) < prob,
            lambda: transform_func(img, mask_3d),
            lambda: (img, mask_3d)
        )

    def rot90_transform(i, m):
        k = tf.random.uniform([], 0, 4, dtype=tf.int32)
        return tf.image.rot90(i, k), tf.image.rot90(m, k)
    img, mask_3d = apply_transform(rot90_transform, A_ROT90_PROB)

    def flip_h(i, m): return tf.image.flip_left_right(i), tf.image.flip_left_right(m)
    def flip_v(i, m): return tf.image.flip_up_down(i), tf.image.flip_up_down(m)
    img, mask_3d = apply_transform(flip_h, A_FLIP_H_PROB)
    img, mask_3d = apply_transform(flip_v, A_FLIP_V_PROB)

    def crop_transform(i, m):
        shape = tf.shape(i)
        h, w = shape[0], shape[1]
        frac = tf.random.uniform([], CROP_MIN_FRAC, 1.0)
        nh = tf.cast(tf.cast(h, tf.float32) * frac, tf.int32)
        nw = tf.cast(tf.cast(w, tf.float32) * frac, tf.int32)
        nh = tf.minimum(nh, h); nw = tf.minimum(nw, w)
        max_y = tf.maximum(1, h - nh); max_x = tf.maximum(1, w - nw)
        oy = tf.random.uniform([], 0, max_y, dtype=tf.int32)
        ox = tf.random.uniform([], 0, max_x, dtype=tf.int32)
        i_crop = tf.image.crop_to_bounding_box(i, oy, ox, nh, nw)
        m_crop = tf.image.crop_to_bounding_box(m, oy, ox, nh, nw)

        i_resized = tf.image.resize(i_crop, [h, w], method='bilinear')
        m_crop_f = tf.cast(m_crop, tf.float32)
        m_resized_f = tf.image.resize(m_crop_f, [h, w], method='nearest')
        m_resized = tf.cast(tf.round(m_resized_f), tf.int32)
        return i_resized, m_resized

    img, mask_3d = apply_transform(crop_transform, A_CROP_PROB)

    def apply_photometric(i):
        if tf.random.uniform([]) < A_JITTER_PROB:
            i = tf.image.random_brightness(i, 0.15)
            i = tf.image.random_contrast(i, 0.8, 1.2)
            i = tf.image.random_saturation(i, 0.8, 1.2)
            i = tf.image.random_hue(i, 0.02)
            i = tf.clip_by_value(i, 0.0, 1.0)
        if tf.random.uniform([]) < A_NOISE_PROB:
            noise = tf.random.normal(tf.shape(i), 0.0, 0.02, dtype=tf.float32)
            i = tf.clip_by_value(i + noise, 0.0, 1.0)
        return i

    img = apply_photometric(img)
    mask = tf.squeeze(mask_3d, axis=-1)
    return img, mask

def one_hot(mask):
    return tf.one_hot(tf.cast(mask, tf.int32), depth=NUM_CLASSES)

def make_dataset(X, Y, batch_size=8, shuffle=False, augment=False):
    ds = tf.data.Dataset.from_tensor_slices((X, Y))
    if shuffle:
        ds = ds.shuffle(min(len(X), 1024), reshuffle_each_iteration=True)
    def process(img, mask):
        img = tf.cast(img, tf.float32)
        mask = tf.cast(mask, tf.int32)
        if augment:
            img, mask = augment_img_mask(img, mask)
        return img, one_hot(mask)
    return ds.map(process, num_parallel_calls=tf.data.AUTOTUNE)\
             .batch(batch_size)\
             .prefetch(tf.data.AUTOTUNE)

# Disease-aware oversampling (static per run; simple & effective)
def disease_weight_per_sample(Y):
    w = []
    for m in Y:
        leaf = (m != 0)
        dz   = (m >= 2)
        frac = float(dz.sum()) / max(1, int(leaf.sum()))
        w.append(0.2 + 0.8*min(1.0, frac*10))  # boost images with more lesions
    w = np.asarray(w, np.float64)
    w = w / w.sum()
    return w

def make_weighted_train_ds(X, Y, batch_size=8, augment=True):
    p = disease_weight_per_sample(Y)
    idx = np.random.choice(len(X), size=len(X), replace=True, p=p)
    return make_dataset(X[idx], Y[idx], batch_size, shuffle=True, augment=augment)

# ============ Models ============

# Pretrained U-Net with MobileNetV2 encoder (ImageNet) + fine-tune
def build_unet_mobilenetv2(input_shape, num_classes, train_encoder=False):
    inputs = keras.Input(shape=input_shape)
    x_in = layers.Lambda(lambda t: t*2.0 - 1.0, name="scale_to_mnv2")(inputs)  # [0,1]->[-1,1]
    base = tf.keras.applications.MobileNetV2(
        input_tensor=x_in, include_top=False, weights="imagenet"
    )
    base.trainable = train_encoder

    s1 = base.get_layer('block_1_expand_relu').output   # 128x128
    s2 = base.get_layer('block_3_expand_relu').output   # 64x64
    s3 = base.get_layer('block_6_expand_relu').output   # 32x32
    s4 = base.get_layer('block_13_expand_relu').output  # 16x16
    bn = base.get_layer('block_16_project').output      # 8x8

    def up_block(x, skip, f):
        x = layers.Conv2DTranspose(f, 3, strides=2, padding='same')(x)
        x = layers.Concatenate()([x, skip])
        x = layers.Conv2D(f, 3, padding='same', activation='relu')(x); x = layers.BatchNormalization()(x)
        x = layers.Conv2D(f, 3, padding='same', activation='relu')(x); x = layers.BatchNormalization()(x)
        return x

    x = bn
    x = up_block(x, s4, 256)  # 8->16
    x = up_block(x, s3, 128)  # 16->32
    x = up_block(x, s2, 64)   # 32->64
    x = up_block(x, s1, 32)   # 64->128
    x = layers.Conv2DTranspose(32, 3, strides=2, padding='same')(x)  # 128->256
    outputs = layers.Conv2D(num_classes, 1, activation='softmax')(x)

    model = keras.Model(inputs, outputs, name="UNet_MobileNetV2")
    model._encoder = base
    return model

# UNet (scratch)
def build_unet(input_shape, num_classes, base=48, drop=0.15):
    x_in = keras.Input(shape=input_shape)
    def blk(x,f):
        x = layers.Conv2D(f,3,padding='same',activation='relu')(x)
        x = layers.BatchNormalization()(x)
        x = layers.Conv2D(f,3,padding='same',activation='relu')(x)
        x = layers.BatchNormalization()(x)
        x = layers.SpatialDropout2D(drop)(x)
        return x
    c1 = blk(x_in, base);        p1 = layers.MaxPooling2D(2)(c1)
    c2 = blk(p1, base*2);        p2 = layers.MaxPooling2D(2)(c2)
    c3 = blk(p2, base*4);        p3 = layers.MaxPooling2D(2)(c3)
    c4 = blk(p3, base*8);        p4 = layers.MaxPooling2D(2)(c4)
    bn = blk(p4, base*16)
    u6 = layers.Conv2DTranspose(base*8,2,2,padding='same')(bn); u6 = layers.Concatenate()([u6,c4]); c6 = blk(u6, base*8)
    u7 = layers.Conv2DTranspose(base*4,2,2,padding='same')(c6); u7 = layers.Concatenate()([u7,c3]); c7 = blk(u7, base*4)
    u8 = layers.Conv2DTranspose(base*2,2,2,padding='same')(c7); u8 = layers.Concatenate()([u8,c2]); c8 = blk(u8, base*2)
    u9 = layers.Conv2DTranspose(base,2,2,padding='same')(c8);   u9 = layers.Concatenate()([u9,c1]); c9 = blk(u9, base)
    out = layers.Conv2D(num_classes,1,activation='softmax')(c9)
    return keras.Model(x_in, out, name="UNet")

# DeepLabV3+ (light backbone)
def build_deeplabv3plus(input_shape, num_classes):
    def aspp(x):
        dims = x.shape[-1]; h, w = x.shape[1], x.shape[2]
        pool = layers.GlobalAveragePooling2D()(x); pool = layers.Reshape((1,1,dims))(pool)
        pool = layers.Conv2D(256,1,activation='relu',padding='same')(pool); pool = layers.BatchNormalization()(pool)
        pool = layers.UpSampling2D(size=(h,w), interpolation='bilinear')(pool)
        conv1 = layers.Conv2D(256,1,activation='relu',padding='same')(x); conv1 = layers.BatchNormalization()(conv1)
        c6  = layers.Conv2D(256,3,activation='relu',padding='same',dilation_rate=6)(x);  c6  = layers.BatchNormalization()(c6)
        c12 = layers.Conv2D(256,3,activation='relu',padding='same',dilation_rate=12)(x); c12 = layers.BatchNormalization()(c12)
        c18 = layers.Conv2D(256,3,activation='relu',padding='same',dilation_rate=18)(x); c18 = layers.BatchNormalization()(c18)
        y = layers.Concatenate()([pool, conv1, c6, c12, c18])
        y = layers.Conv2D(256,1,activation='relu',padding='same')(y); y = layers.BatchNormalization()(y)
        return y
    inputs = keras.Input(shape=input_shape)
    x = layers.Conv2D(32, 3, strides=2, padding='same', activation='relu')(inputs); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(64, 3, padding='same', activation='relu')(x); x = layers.BatchNormalization()(x)
    low = x
    x = layers.Conv2D(128, 3, strides=2, padding='same', activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(256, 3, padding='same', activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(512, 3, strides=2, padding='same', activation='relu')(x); x = layers.BatchNormalization()(x)
    x = aspp(x)
    x = layers.UpSampling2D(size=(4,4), interpolation='bilinear')(x)
    low = layers.Conv2D(48,1,activation='relu',padding='same')(low); low = layers.BatchNormalization()(low)
    x = layers.Concatenate()([x, low])
    x = layers.Conv2D(256,3,activation='relu',padding='same')(x); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(256,3,activation='relu',padding='same')(x); x = layers.BatchNormalization()(x)
    x = layers.UpSampling2D(size=(2,2), interpolation='bilinear')(x)
    outputs = layers.Conv2D(num_classes,1,activation='softmax')(x)
    return keras.Model(inputs, outputs, name="DeepLabV3Plus")

# FCN
def build_fcn(input_shape, num_classes):
    inputs = keras.Input(shape=input_shape)
    x = layers.Conv2D(64,3,padding='same',activation='relu')(inputs)
    x = layers.Conv2D(64,3,padding='same',activation='relu')(x); p1 = layers.MaxPooling2D(2)(x)
    x = layers.Conv2D(128,3,padding='same',activation='relu')(p1)
    x = layers.Conv2D(128,3,padding='same',activation='relu')(x); p2 = layers.MaxPooling2D(2)(x)
    x = layers.Conv2D(256,3,padding='same',activation='relu')(p2)
    x = layers.Conv2D(256,3,padding='same',activation='relu')(x)
    x = layers.Conv2D(256,3,padding='same',activation='relu')(x);  p3 = layers.MaxPooling2D(2)(x)
    x = layers.Conv2D(512,3,padding='same',activation='relu')(p3)
    x = layers.Conv2D(512,3,padding='same',activation='relu')(x)
    x = layers.Conv2D(512,3,padding='same',activation='relu')(x);  p4 = layers.MaxPooling2D(2)(x)
    x = layers.Conv2D(512,3,padding='same',activation='relu')(p4)
    x = layers.Conv2D(512,3,padding='same',activation='relu')(x)
    x = layers.Conv2D(512,3,padding='same',activation='relu')(x);  p5 = layers.MaxPooling2D(2)(x)
    x = layers.Conv2D(4096,7,padding='same',activation='relu')(p5); x = layers.Dropout(0.5)(x)
    x = layers.Conv2D(4096,1,activation='relu')(x); x = layers.Dropout(0.5)(x)
    s5 = layers.Conv2D(num_classes,1)(x); s4 = layers.Conv2D(num_classes,1)(p4); s3 = layers.Conv2D(num_classes,1)(p3)
    up2 = layers.Conv2DTranspose(num_classes,4,strides=2,padding='same')(s5);  f4 = layers.Add()([up2, s4])
    up4 = layers.Conv2DTranspose(num_classes,4,strides=2,padding='same')(f4);  f3 = layers.Add()([up4, s3])
    outputs = layers.Conv2DTranspose(num_classes,16,strides=8,padding='same',activation='softmax')(f3)
    return keras.Model(inputs, outputs, name="FCN")

# SegNet
def build_segnet(input_shape, num_classes):
    inputs = keras.Input(shape=input_shape)
    x = layers.Conv2D(64,3,padding='same',activation='relu')(inputs); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(64,3,padding='same',activation='relu')(x);     x = layers.BatchNormalization()(x); x = layers.MaxPooling2D(2)(x)
    x = layers.Conv2D(128,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(128,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x); x = layers.MaxPooling2D(2)(x)
    x = layers.Conv2D(256,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(256,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(256,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x); x = layers.MaxPooling2D(2)(x)
    x = layers.Conv2D(512,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(512,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(512,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x); x = layers.MaxPooling2D(2)(x)
    x = layers.UpSampling2D(2)(x)
    x = layers.Conv2D(512,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(512,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(512,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.UpSampling2D(2)(x)
    x = layers.Conv2D(256,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(256,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(256,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.UpSampling2D(2)(x)
    x = layers.Conv2D(128,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(128,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.UpSampling2D(2)(x)
    x = layers.Conv2D(64,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    x = layers.Conv2D(64,3,padding='same',activation='relu')(x); x = layers.BatchNormalization()(x)
    outputs = layers.Conv2D(num_classes,1,activation='softmax')(x)
    return keras.Model(inputs, outputs, name="SegNet")

# ============ Loss and Metrics ============
SMOOTH = 1e-6

def _resize_to_label(y_pred, y_true):
    ph = tf.shape(y_pred)[1]; pw = tf.shape(y_pred)[2]
    th = tf.shape(y_true)[1]; tw = tf.shape(y_true)[2]
    need = tf.logical_or(tf.not_equal(ph, th), tf.not_equal(pw, tw))
    def _do(): return tf.image.resize(y_pred, (th, tw), method='bilinear')
    return tf.cond(need, _do, lambda: y_pred)

def weighted_ce(y_true, y_pred):
    y_pred = _resize_to_label(y_pred, y_true)
    w = tf.reduce_sum(CLASS_WEIGHTS * y_true, axis=-1)
    ce = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
    return tf.reduce_mean(ce * w)

def focal_tversky_loss(y_true, y_pred, alpha=0.8, beta=0.2, gamma=0.9, exclude_bg=True):
    y_pred = _resize_to_label(y_pred, y_true)
    if exclude_bg:
        y_true = y_true[...,1:]; y_pred = y_pred[...,1:]
    y_true_f = tf.reshape(y_true, [-1, tf.shape(y_true)[-1]])
    y_pred_f = tf.reshape(y_pred, [-1, tf.shape(y_pred)[-1]])
    tp = tf.reduce_sum(y_true_f * y_pred_f, axis=0)
    fp = tf.reduce_sum((1. - y_true_f) * y_pred_f, axis=0)
    fn = tf.reduce_sum(y_true_f * (1. - y_pred_f), axis=0)
    t = (tp + SMOOTH) / (tp + alpha*fn + beta*fp + SMOOTH)
    ft = tf.pow(1. - t, gamma)
    return tf.reduce_mean(ft)

def combo_loss_stronger(y_true, y_pred, alpha=0.5):
    return alpha * weighted_ce(y_true, y_pred) + (1.0 - alpha) * focal_tversky_loss(y_true, y_pred)

@tf.function
def iou_no_bg(y_true, y_pred):
    y_pred = _resize_to_label(y_pred, y_true)
    yt = tf.argmax(y_true, -1); yp = tf.argmax(y_pred, -1)
    yt_oh = tf.one_hot(yt, NUM_CLASSES, dtype=tf.float32)
    yp_oh = tf.one_hot(yp, NUM_CLASSES, dtype=tf.float32)
    inter = tf.reduce_sum(yt_oh * yp_oh, axis=[0,1,2])
    union = tf.reduce_sum(yt_oh + yp_oh - yt_oh*yp_oh, axis=[0,1,2])
    inter_nb = inter[1:]; union_nb = union[1:]
    iou = tf.where(union_nb > 0.0, inter_nb / (union_nb + 1e-7), 0.0)
    return tf.reduce_mean(iou)

@tf.function
def iou_disease_only(y_true, y_pred):
    """Mean IoU over classes 2..C-1 (diseases only)."""
    y_pred = _resize_to_label(y_pred, y_true)
    yt = tf.argmax(y_true, -1); yp = tf.argmax(y_pred, -1)
    yt_oh = tf.one_hot(yt, NUM_CLASSES, dtype=tf.float32)
    yp_oh = tf.one_hot(yp, NUM_CLASSES, dtype=tf.float32)
    inter = tf.reduce_sum(yt_oh * yp_oh, axis=[0,1,2])
    union = tf.reduce_sum(yt_oh + yp_oh - yt_oh*yp_oh, axis=[0,1,2])
    inter_dz = inter[2:]; union_dz = union[2:]
    iou = tf.where(union_dz > 0.0, inter_dz / (union_dz + 1e-7), 0.0)
    return tf.reduce_mean(iou)

def compile_model(model, lr=1e-3):
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=lr),
        loss=combo_loss_stronger,
        metrics=[iou_no_bg, iou_disease_only, 'accuracy']
    )
    return model

def plot_history(hist, title, outdir):
    plt.figure(figsize=(11,4))
    plt.subplot(1,3,1)
    plt.plot(hist.history.get('loss', []), label='train')
    plt.plot(hist.history.get('val_loss', []), label='val')
    plt.title('Loss'); plt.legend()
    plt.subplot(1,3,2)
    plt.plot(hist.history.get('iou_no_bg', []), label='train')
    plt.plot(hist.history.get('val_iou_no_bg', []), label='val')
    plt.title('IoU (no-bg)'); plt.legend()
    plt.subplot(1,3,3)
    plt.plot(hist.history.get('iou_disease_only', []), label='train')
    plt.plot(hist.history.get('val_iou_disease_only', []), label='val')
    plt.title('IoU (disease-only)'); plt.legend()
    plt.tight_layout()
    os.makedirs(outdir, exist_ok=True)
    p = os.path.join(outdir, f'{title}_curves.png')
    plt.savefig(p, dpi=140, bbox_inches='tight'); plt.close()
    return p

def compute_auto_class_weights(Y_int, num_classes):
    counts = np.bincount(Y_int.flatten(), minlength=num_classes).astype(np.float64)
    p = counts / max(1, counts.sum())
    w = 1.0 / np.log(1.02 + p + 1e-12)  # inverse log frequency
    w[0] = max(w[0]*0.5, 0.25)          # don't over-weight background
    return w

# ============ Main ============
def main():
    t0 = time.time()
    print("🔎 Loading dataset...")
    ds = AppleLeafDataset(BASE_DIR, IMG_SIZE)
    X, Y = ds.load()

    # splits ≈70/15/15
    Xtr, Xte, Ytr, Yte = train_test_split(X, Y, test_size=0.15, random_state=SEED, shuffle=True)
    Xtr, Xva, Ytr, Yva = train_test_split(Xtr, Ytr, test_size=0.1765, random_state=SEED, shuffle=True)
    print(f"Shapes -> Train: {Xtr.shape}, Val: {Xva.shape}, Test: {Xte.shape}")

    # class weights
    if AUTO_CLASS_WEIGHTS:
        global CLASS_WEIGHTS
        CLASS_WEIGHTS = tf.constant(compute_auto_class_weights(Ytr, NUM_CLASSES), dtype=tf.float32)
        print("Class weights (auto):", CLASS_WEIGHTS.numpy())

    # training dataset (oversampled) and eval datasets
    if OVERSAMPLE_DISEASE:
        train_ds = make_weighted_train_ds(Xtr, Ytr, BATCH_SIZE, augment=True)
    else:
        train_ds = make_dataset(Xtr, Ytr, BATCH_SIZE, shuffle=True, augment=True)
    val_ds   = make_dataset(Xva, Yva, BATCH_SIZE, shuffle=False, augment=False)
    test_ds  = make_dataset(Xte, Yte, BATCH_SIZE, shuffle=False, augment=False)

    input_shape = (IMG_SIZE, IMG_SIZE, 3)

    builders = {
        'UNet_MobileNetV2': build_unet_mobilenetv2,
        'UNet': build_unet,
        'DeepLabV3Plus': build_deeplabv3plus,
        'FCN': build_fcn,
        'SegNet': build_segnet
    }

    model_names = list(builders.keys())
    if START_AT and START_AT in model_names:
        model_names = model_names[model_names.index(START_AT):]
    if RUN_ONLY:
        model_names = [n for n in model_names if n in RUN_ONLY]

    results = []
    os.makedirs(OUTDIR, exist_ok=True)

    for name in model_names:
        tf.keras.backend.clear_session()
        ckpt_path = os.path.join(OUTDIR, f"{name}_best.keras")

        if SKIP_TRAIN_IF_CKPT and os.path.exists(ckpt_path):
            print(f"⏭️  {name}: checkpoint found -> loading for eval")
            try:
                model = keras.models.load_model(ckpt_path, compile=False,
                                                custom_objects={'iou_no_bg': iou_no_bg,
                                                                'iou_disease_only': iou_disease_only,
                                                                'combo_loss_stronger': combo_loss_stronger})
                compile_model(model, lr=1e-3)
            except Exception as e:
                print(f"⚠️  Failed to load {name} checkpoint: {e}")
                print("🔄 Training from scratch...")
                model = builders[name](input_shape, NUM_CLASSES)
                # training branch below
                pass

        if not (SKIP_TRAIN_IF_CKPT and os.path.exists(ckpt_path)):
            print(f"\n🚀 Training {name} ...")
            model = builders[name](input_shape, NUM_CLASSES)

            # Two-phase fine-tune if pretrained encoder exists
            if hasattr(model, "_encoder"):
                # Phase 1: freeze encoder
                model._encoder.trainable = False
                compile_model(model, lr=1e-3)
                hist1 = model.fit(train_ds, validation_data=val_ds,
                                  epochs=max(4, EPOCHS//2), verbose=1)
                # Phase 2: unfreeze encoder
                model._encoder.trainable = True
                compile_model(model, lr=3e-4)
                cbs = [
                    keras.callbacks.ModelCheckpoint(ckpt_path, monitor='val_iou_disease_only', mode='max',
                                                    save_best_only=True, save_weights_only=False, verbose=1),
                    keras.callbacks.EarlyStopping(monitor='val_iou_disease_only', mode='max',
                                                  patience=8, restore_best_weights=True),
                    keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3,
                                                      min_lr=1e-5, verbose=1)
                ]
                hist2 = model.fit(train_ds, validation_data=val_ds,
                                  initial_epoch=max(4, EPOCHS//2), epochs=EPOCHS, verbose=1, callbacks=cbs)
                _ = plot_history(hist2, name, OUTDIR)
            else:
                compile_model(model, lr=1e-3)
                cbs = [
                    keras.callbacks.ModelCheckpoint(ckpt_path, monitor='val_iou_disease_only', mode='max',
                                                    save_best_only=True, save_weights_only=False, verbose=1),
                    keras.callbacks.EarlyStopping(monitor='val_iou_disease_only', mode='max',
                                                  patience=8, restore_best_weights=True),
                    keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3,
                                                      min_lr=1e-5, verbose=1)
                ]
                hist = model.fit(train_ds, validation_data=val_ds,
                                 epochs=EPOCHS, verbose=1, callbacks=cbs)
                _ = plot_history(hist, name, OUTDIR)

        # Evaluate (val + test)
        print(f"📏 Evaluating {name} ...")
        val_metrics  = model.evaluate(val_ds,  verbose=0)
        test_metrics = model.evaluate(test_ds, verbose=0)

        # Keras orders: [loss, iou_no_bg, iou_disease_only, accuracy]
        res = {
            "Model": name,
            "Val Loss": float(val_metrics[0]),
            "Val IoU(no-bg)": float(val_metrics[1]),
            "Val IoU(disease)": float(val_metrics[2]),
            "Val Acc": float(val_metrics[3]),
            "Test Loss": float(test_metrics[0]),
            "Test IoU(no-bg)": float(test_metrics[1]),
            "Test IoU(disease)": float(test_metrics[2]),
            "Test Acc": float(test_metrics[3]),
        }
        results.append(res)

        # Visual previews with severity
        _ = visualize_with_severity(model, Xva, Yva, n=4, outdir=OUTDIR)
        _ = visualize_grid_with_severity(model, Xva, Yva, k=6, outdir=OUTDIR,
                                         title=f"{name} — WeightedCE + Focal-Tversky")

    # Save/append comparison CSV
    cmp_csv = os.path.join(OUTDIR, "model_comparison.csv")
    df_new = pd.DataFrame(results)
    if os.path.exists(cmp_csv):
        try:
            df_old = pd.read_csv(cmp_csv)
            df = (pd.concat([df_old, df_new], ignore_index=True)
                    .drop_duplicates(subset=['Model'], keep='last'))
        except Exception:
            df = df_new
    else:
        df = df_new
    df = df.sort_values("Val IoU(disease)", ascending=False)
    df.to_csv(cmp_csv, index=False)
    print("\n================ Model Comparison (by Val IoU disease-only) ================")
    print(df.to_string(index=False))
    print(f"\n💾 Saved: {cmp_csv}")
    print(f"✅ Done in {time.time()-t0:.1f}s")

if __name__ == "__main__":
    main()


⚠️  GPU mem-growth not set: Physical devices cannot be modified after being initialized
🔎 Loading dataset...
📂 Alternaria leaf spot | imgs: 278 | masks: 278 | paired: 278
📂 Brown spot           | imgs: 215 | masks: 215 | paired: 215
📂 Gray spot            | imgs: 395 | masks: 395 | paired: 395
📂 Healthy leaf         | imgs: 409 | masks: 409 | paired: 409
📂 Rust                 | imgs: 344 | masks: 344 | paired: 344
✅ Paired samples: 1641


Loading data: 100%|██████████| 1641/1641 [00:20<00:00, 80.27it/s]


Shapes -> Train: (1147, 256, 256, 3), Val: (247, 256, 256, 3), Test: (247, 256, 256, 3)
Class weights (auto): [ 0.9594292  3.3929112 41.759647  48.69141   47.26587   37.653244 ]
⏭️  UNet_MobileNetV2: checkpoint found -> loading for eval
⚠️  Failed to load UNet_MobileNetV2 checkpoint: The `{arg_name}` of this `Lambda` layer is a Python lambda. Deserializing it is unsafe. If you trust the source of the config artifact, you can override this error by passing `safe_mode=False` to `from_config()`, or calling `keras.config.enable_unsafe_deserialization().
🔄 Training from scratch...


  base = tf.keras.applications.MobileNetV2(


📏 Evaluating UNet_MobileNetV2 ...


ValueError: You must call `compile()` before using the model.

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# Path to your image
img_path = "/kaggle/input/apple-dataset/ATLDSD/Alternaria leaf spot/image/000413.jpg"

# Load the image
img = mpimg.imread(img_path)

# Display the image
plt.imshow(img)
plt.axis("off")  # Hide axes
plt.show()
