In [None]:
import os
import random
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, LearningRateScheduler

from vgg16_unet_model import build_vgg16_unet

# =========================
# Reproducibility
# =========================
# Seed every RNG this script touches: Python's `random`, NumPy's global RNG,
# and TensorFlow's global seed (from which op-level seeds for tf.random /
# tf.data ops created later are derived).
# NOTE(review): this alone does not force deterministic GPU kernels; if
# bit-exact reruns matter, tf.config.experimental.enable_op_determinism()
# is the usual extra step — confirm it is acceptable for this model.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, tf.random.set_seed):
    _seed_fn(SEED)
del _seed_fn

# =========================
# Parameters
# =========================
tile_size, input_channels = 256, 6  # square tile edge in pixels, bands per tile
batch_size, epochs = 32, 100        # epochs is an upper bound (EarlyStopping may stop sooner)

# =========================
# Load data
# =========================
# NOTE(review): `load_data` is not defined in this cell — presumably an
# earlier notebook cell; confirm it returns (images, masks) NumPy arrays.
tiles_img, tiles_mask = load_data("datasets_temp/images", "datasets_temp/masks", tile_size)
# Append a trailing channel axis to the masks and store them as uint8.
tiles_mask = np.expand_dims(tiles_mask, axis=-1).astype(np.uint8)

# Sanity checks: raw image value range and the set of mask labels.
# NOTE(review): the loss is binary_crossentropy and masks are only cast to
# float later, so labels should be {0, 1}; a result like [0 255] needs rescaling.
print("img min/max:", float(tiles_img.min()), float(tiles_img.max()))
print("mask unique:", np.unique(tiles_mask))

# =========================
# Split
# =========================
# Hold out 30% of all tiles for test, then 10% of the remainder for
# validation: roughly 63% train / 7% val / 30% test overall.
X_trainval, X_test, y_trainval, y_test = train_test_split(
    tiles_img, tiles_mask, test_size=0.3, random_state=SEED
)
X_train, X_val, y_train, y_val = train_test_split(
    X_trainval, y_trainval, test_size=0.1, random_state=SEED
)
del X_trainval, y_trainval  # drop references to the intermediate arrays

# =========================
# Scaling (handles 0..10000 reflectance)
# =========================
@tf.function
def scale_tf_image(image):
    """Cast *image* to float32 and rescale it based on its own maximum.

    - max > 300          -> divide by 10000 (0..10000 reflectance-style data)
    - 1.5 < max <= 300   -> divide by 255 (8-bit-style data)
    - anything else      -> returned unchanged (presumably already ~0..1)

    NOTE(review): the decision is made per image from that image's maximum,
    so tiles from the same source can land on different scales (e.g. a dark
    0..10000 tile whose max is <= 300 is divided by 255, not 10000). If the
    dataset has one known range, a fixed divisor would be more consistent.
    """
    pixels = tf.cast(image, tf.float32)
    peak = tf.reduce_max(pixels)
    return tf.cond(
        peak > 300.0,
        lambda: pixels / 10000.0,
        lambda: tf.cond(peak > 1.5, lambda: pixels / 255.0, lambda: pixels),
    )

def preprocess_tf(image, mask):
    """tf.data map fn: scale the image and cast the mask to float32.

    Returns an ``(image, mask)`` pair, both float32. Mask values are cast
    but not rescaled.
    """
    return scale_tf_image(image), tf.cast(mask, tf.float32)

def augment_tf(image, mask):
    """Apply identical random flips and a 90-degree rotation to image and mask.

    ``tf.image.random_flip_*`` draws a fresh coin on every call, so flipping
    the image and the mask with separate calls (as this function used to)
    transformed them independently and misaligned labels from pixels in a
    large fraction of samples. Each random decision is now drawn once and
    applied to both tensors.

    Args:
        image: one tile tensor, presumably (H, W, C) — confirm against caller.
        mask: the matching label tensor with the same H and W.

    Returns:
        The transformed ``(image, mask)`` pair.
    """
    def _apply_if(flag, fn, x):
        # Apply fn to x only when the scalar boolean flag is true.
        return tf.cond(flag, lambda: fn(x), lambda: x)

    # One coin per flip axis, shared by image and mask.
    flip_lr = tf.random.uniform(shape=[]) < 0.5
    flip_ud = tf.random.uniform(shape=[]) < 0.5

    image = _apply_if(flip_lr, tf.image.flip_left_right, image)
    mask = _apply_if(flip_lr, tf.image.flip_left_right, mask)
    image = _apply_if(flip_ud, tf.image.flip_up_down, image)
    mask = _apply_if(flip_ud, tf.image.flip_up_down, mask)

    # A single k in {0, 1, 2, 3} so both tensors rotate by the same amount.
    k = tf.random.uniform(shape=[], minval=0, maxval=4, dtype=tf.int32)
    image = tf.image.rot90(image, k)
    mask = tf.image.rot90(mask, k)
    return image, mask

# Input pipelines. Training: scale -> augment -> shuffle (200-element buffer)
# -> batch -> prefetch. Validation/test: scale -> batch -> prefetch, unshuffled.
train_ds = (
    tf.data.Dataset.from_tensor_slices((X_train, y_train))
    .map(preprocess_tf, num_parallel_calls=tf.data.AUTOTUNE)
    .map(augment_tf, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(200)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

val_ds = (
    tf.data.Dataset.from_tensor_slices((X_val, y_val))
    .map(preprocess_tf, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

test_ds = (
    tf.data.Dataset.from_tensor_slices((X_test, y_test))
    .map(preprocess_tf, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

# =========================
# Model
# =========================
# NOTE(review): binary_crossentropy assumes the network emits per-pixel
# probabilities (sigmoid output) — confirm in vgg16_unet_model. Also confirm
# how build_vgg16_unet uses ImageNet (3-channel) weights with 6 input bands.
model = build_vgg16_unet(
    input_shape=(tile_size, tile_size, input_channels),
    freeze_encoder=True,
    use_imagenet_weights=True,
)
model.compile(
    loss="binary_crossentropy",
    optimizer=Adam(learning_rate=1e-4),
    metrics=["accuracy"],
)
model.summary()

# =========================
# LR schedule
# =========================
def lr_schedule(epoch, lr=None, *, base=1e-4, decay=0.9, step=10):
    """Step-decay learning rate: ``base * decay ** (epoch // step)``.

    Keras' ``LearningRateScheduler`` calls the schedule as
    ``schedule(epoch, current_lr)`` when the function accepts two arguments.
    With the previous signature ``(epoch, base=..., ...)`` the current LR was
    silently bound to ``base``, so once ``epoch >= step`` the decay compounded
    every epoch instead of once per ``step`` epochs. ``lr`` now absorbs that
    argument and is ignored; the result depends only on ``epoch`` and the
    keyword-only knobs.

    Args:
        epoch: zero-based epoch index.
        lr: current learning rate supplied by Keras (unused).
        base: learning rate for epochs ``0 .. step - 1``.
        decay: multiplicative factor applied once every ``step`` epochs.
        step: number of epochs between decays.

    Returns:
        The learning rate to use for ``epoch``.
    """
    n_decays = epoch // step
    return base * (decay ** n_decays)

# Per-epoch LR schedule, early stopping on val_loss (with best-weight
# restoration), and a checkpoint of the best model by val_loss.
# NOTE(review): some Keras 3 releases only accept `.keras` paths for full-model
# checkpoints — confirm the `.h5` path works with the installed version.
callbacks = [
    LearningRateScheduler(schedule=lr_schedule),
    EarlyStopping(
        monitor="val_loss",
        patience=5,
        restore_best_weights=True,
    ),
    ModelCheckpoint(
        filepath="best_vgg16_unet_6band.h5",
        monitor="val_loss",
        save_best_only=True,
    ),
]

# =========================
# Train
# =========================
# Train on the augmented pipeline and validate on val_ds after every epoch;
# the callbacks may adjust the LR, stop early, and write checkpoints.
history = model.fit(train_ds, epochs=epochs, validation_data=val_ds, callbacks=callbacks)

# =========================
# Evaluate + save
# =========================
# With one compiled metric, evaluate() yields [loss, accuracy].
test_loss, test_acc = model.evaluate(x=test_ds, verbose=1)
print("Test loss:", float(test_loss), "Test acc:", float(test_acc))

# Save the in-memory model as it stands after fit().
model.save("vgg16_unet_6band_fixedscale_fullmodel.h5")

# =========================
# Plot history
# =========================
# Training vs. validation loss per epoch.
plt.figure(figsize=(8, 5))
for _key, _label in (("loss", "train loss"), ("val_loss", "val loss")):
    plt.plot(history.history[_key], label=_label)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend()
plt.tight_layout()
plt.show()

# Persist the held-out test split so Grad-CAM reuses exactly these tiles
# instead of re-splitting. X_test is the raw (unscaled) imagery, so the
# consumer must apply the same scaling that preprocess_tf applies.
np.savez_compressed("splits_vgg16_unet_6band.npz", X_test=X_test, y_test=y_test)
print("Saved test split: splits_vgg16_unet_6band.npz")
