<a href="https://colab.research.google.com/github/ariegever/ImageProcessing_Project/blob/main/4_unet_modeling.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#=== HOW TO USE THIS NOTEBOOK ===
#
# 1.  **Set Runtime:** Go to "Runtime" > "Change runtime type" and select
#     "T4 GPU" (or any available GPU) as the "Hardware accelerator".
#     This is *critical* for training your model.
#
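# 2.  **Configuration:**
#     * All settings (Drive paths, Earth Engine project, input bands, batch
#       size, epochs, ...) live in `config.py`; edit that file rather than
#       this notebook.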
#
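# 3.  **Run All Remaining Cells:**
#     * The notebook will mount Google Drive, authenticate Earth Engine, load
#       the land-cover class definitions, prepare the data pipeline, build the
#       U-Net model, and train it (saving the model and training history to
#       Drive).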

In [None]:
from google.colab import drive
import config
drive.mount(config.DRIVE_MOUNT_PATH)
from google.colab import auth
import google.auth
import ee
# Trigger the authentication flow.
auth.authenticate_user()
# Get credentials and initialize Earth Engine
credentials, project = google.auth.default()
ee.Initialize(credentials, project=config.PROJECT_ID, opt_url='https://earthengine-highvolume.googleapis.com')

print(f"Successfully initialized Earth Engine for project: {config.PROJECT_ID}")
!pip install  scikit-learn scikit-image rasterio
!pip install keras

import json
import os
import pandas as pd
import numpy as np
import io
import matplotlib.pyplot as plt
import utils

# TensorFlow / Keras
import tensorflow as tf
import tensorflow as keras
from keras.layers import Conv2D, Conv2DTranspose, Dropout, MaxPooling2D, Input, concatenate
from keras.models import Model
from keras.callbacks import EarlyStopping, ModelCheckpoint

# Scikit-learn for metrics
# We will add cohen_kappa_score here
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report, cohen_kappa_score

# Matplotlib helpers
import matplotlib.patches as mpatches
from matplotlib.colors import ListedColormap
from skimage.exposure import rescale_intensity

In [None]:
# === Configuration from config.py ===

TFRECORD_FILE_PATH = os.path.join(config.DRIVE_IMAGES_PATH, config.TFRECORD_FILE)
HISTORY_CSV_PATH = os.path.join(config.DRIVE_IMAGES_PATH, config.HISTORY_CSV_FILENAME)

# This is where your final trained model will be saved
MODEL_SAVE_PATH = os.path.join(config.DRIVE_IMAGES_PATH, config.MODEL_FILENAME)
# Weights-only checkpoint files (Keras `.weights.h5` naming), kept next to the data.
LATEST_CHECKPOINT_PATH = os.path.join(config.DRIVE_IMAGES_PATH, 'latest_checkpoint.weights.h5')
BEST_CHECKPOINT_PATH = os.path.join(config.DRIVE_IMAGES_PATH, 'best_checkpoint.weights.h5')

# Seed Python, NumPy and TensorFlow so the training shuffle, dropout and weight
# initialisation repeat across runs. Uses config.SEED when it is defined and
# falls back to 42 otherwise. GPU kernels may still add small run-to-run noise.
SEED = getattr(config, "SEED", 42)
tf.keras.utils.set_random_seed(SEED)

print("--- Configuration ---")
print(f"Project: {config.PROJECT_ID}")
print(f"TFRecord File: {TFRECORD_FILE_PATH}")
print(f"Model Save Path: {MODEL_SAVE_PATH}")
print(f"Input Bands ({config.NUM_BANDS}): {config.FEATURE_NAMES}")
print(f"Batch Size: {config.BATCH_SIZE}")
print(f"Random Seed: {SEED}")

In [None]:
import json
import pandas as pd
from matplotlib.colors import ListedColormap

# 1. Load the JSON class definitions. Print a readable message, then re-raise
#    so the notebook stops instead of continuing without classes.
try:
    with open(config.CLASS_JSON_PATH) as f:
        lc = json.load(f)
except FileNotFoundError:
    print(f"ERROR: '{config.CLASS_JSON_PATH}' not found.")
    raise
except json.JSONDecodeError:
    print(f"ERROR: '{config.CLASS_JSON_PATH}' is not a valid JSON file.")
    raise


# 2. Create DataFrame for the Main Classes (Parent Groups): one row per top-level JSON key.
lc_df = pd.DataFrame.from_dict(lc, orient='index')
lc_df = lc_df.rename(columns={'class': 'label', 'color': 'palette'})

# Target value = integer JSON key + 1, so no parent class uses value 0
# (NUM_CLASSES below reserves one extra slot for it).
lc_df["target_value"] = lc_df.index.astype(int) + 1

# 3. Flatten the nested data to create 'from_values' and 'to_values'
# We map the specific IDs (11, 21) to the PARENT ID (1, 2)
from_values = []
to_values = []
for target, originals in zip(lc_df["target_value"], lc_df["original_classes"]):
    for item in originals:
        from_values.append(item['values'])  # e.g. 21
        to_values.append(target)            # e.g. 2

# 4. Palette and visuals: one colour per parent class.
class_labels = lc_df["label"].to_list()
palette_hex = lc_df["palette"].to_list()
cmap = ListedColormap(palette_hex)
NUM_CLASSES = len(lc_df) + 1  # parent classes + the reserved value 0
vmin = 0
vmax = len(class_labels)
# NOTE(review): cmap holds len(class_labels) colours while label values span
# vmin..vmax (len(class_labels) + 1 distinct values); confirm the colour/label
# alignment wherever cmap is used together with vmin/vmax.

# Every target value must be a valid class index in [1, NUM_CLASSES - 1].
# This fails when the JSON keys are not 0..N-1 (e.g. keys starting at 1), which
# would otherwise yield labels the NUM_CLASSES-way model cannot represent.
out_of_range = lc_df.loc[~lc_df["target_value"].between(1, NUM_CLASSES - 1), "target_value"]
if not out_of_range.empty:
    raise ValueError(
        f"Class JSON keys must be 0..{len(lc_df) - 1}; got target values "
        f"{out_of_range.to_list()} outside 1..{NUM_CLASSES - 1}."
    )

print(f"Reduced complexity: Mapping {len(from_values)} specific types to {len(lc_df)} parent classes.")
lc_df

In [None]:
# The TFRecord dataset builder and the augmentation step used below
# (utils.create_dataset, utils.augment_data) are implemented in utils.py.
_pipeline_notice = "TFRecord data pipeline functions defined in utils.py."
print(_pipeline_notice)

In [None]:
# --- 1. Count the total number of samples ---
print(f"Counting samples in {TFRECORD_FILE_PATH}...")
count_ds = tf.data.TFRecordDataset(TFRECORD_FILE_PATH, compression_type="GZIP")
# One full pass over the file. The split sizes below assume each record becomes
# exactly one element after utils.create_dataset(...).unbatch() — TODO confirm in utils.py.
total_samples = sum(1 for _ in count_ds)

print(f"Total samples found: {total_samples}")
if total_samples == 0:
    raise ValueError(f"No records found in {TFRECORD_FILE_PATH}; nothing to train on.")

# --- 2. Define splits ---
train_size = int(total_samples * config.TRAIN_SPLIT)
val_size = int(total_samples * config.VAL_SPLIT)
test_size = total_samples - train_size - val_size

print(f"Train: {train_size}, Val: {val_size}, Test: {test_size}")

# Fail fast on unusable splits. Training needs data, and the training callbacks
# monitor val_loss (their default), which Keras only reports for a non-empty
# validation set.
if test_size < 0:
    raise ValueError(
        f"config.TRAIN_SPLIT + config.VAL_SPLIT must not exceed 1 "
        f"(got {config.TRAIN_SPLIT} + {config.VAL_SPLIT})."
    )
if train_size == 0 or val_size == 0:
    raise ValueError(
        f"Split too small for {total_samples} samples: train={train_size}, val={val_size}. "
        "Add data or adjust config.TRAIN_SPLIT / config.VAL_SPLIT."
    )

# --- 3. Create the main dataset ---
full_dataset = utils.create_dataset(TFRECORD_FILE_PATH, NUM_CLASSES, is_training=False)

# Unbatch to split, then re-batch
full_dataset = full_dataset.unbatch()

# take/skip assign each split a contiguous run of elements in stream order.
# Presumably create_dataset(is_training=False) does not shuffle, so this is
# file order — confirm in utils.py.
train_ds = full_dataset.take(train_size)
remaining = full_dataset.skip(train_size)
val_ds = remaining.take(val_size)
test_ds = remaining.skip(val_size)

# --- 4. Prepare for Training ---
# Apply shuffling and augmentation ONLY to training data
train_ds = (
    train_ds.shuffle(config.BUFFER_SIZE)
    .map(utils.augment_data, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(config.BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)
val_ds = val_ds.batch(config.BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
test_ds = test_ds.batch(config.BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

print("Datasets prepared.")

In [None]:
def _unet_conv_block(x, filters, dropout_rate):
    """Apply two 3x3 ReLU convolutions (He-normal init, 'same' padding) with dropout between them."""
    x = Conv2D(filters, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(x)
    x = Dropout(dropout_rate)(x)
    return Conv2D(filters, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(x)


# (filters, dropout rate) per encoder level, shallowest first; the decoder mirrors it.
_UNET_LEVELS = ((16, 0.1), (32, 0.1), (64, 0.2), (128, 0.2))
# (filters, dropout rate) of the bottleneck block between encoder and decoder.
_UNET_BOTTLENECK = (256, 0.3)


def build_unet_model(input_shape, num_classes):
    """Build a 4-level U-Net for per-pixel classification.

    Args:
        input_shape: (height, width, channels) of one input patch. Height and
            width must be divisible by 16 (or None) because the encoder halves
            them four times and the decoder concatenates back at each level.
        num_classes: number of output channels; the last layer is a 1x1
            convolution with a softmax over this axis.

    Returns:
        An uncompiled Keras ``Model`` mapping the input patch to per-pixel
        class probabilities of shape (height, width, num_classes).

    Raises:
        ValueError: if a spatial dimension is not divisible by 16.
    """
    downsample = 2 ** len(_UNET_LEVELS)
    for dim in input_shape[:2]:
        # Without this check a bad size only surfaces as a shape mismatch
        # inside a decoder `concatenate`.
        if dim is not None and dim % downsample:
            raise ValueError(
                f"U-Net input height/width must be divisible by {downsample}; "
                f"got input_shape={tuple(input_shape)}."
            )

    inputs = Input(input_shape)

    # Contraction path: keep each level's pre-pooling features for the skip connections.
    skips = []
    x = inputs
    for filters, rate in _UNET_LEVELS:
        x = _unet_conv_block(x, filters, rate)
        skips.append(x)
        x = MaxPooling2D((2, 2))(x)

    # Bottleneck
    x = _unet_conv_block(x, *_UNET_BOTTLENECK)

    # Expansive path: upsample, concatenate the matching skip on the channel axis, convolve.
    for (filters, rate), skip in zip(reversed(_UNET_LEVELS), reversed(skips)):
        x = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
        x = concatenate([x, skip])
        x = _unet_conv_block(x, filters, rate)

    outputs = Conv2D(num_classes, (1, 1), activation='softmax')(x)

    return Model(inputs=[inputs], outputs=[outputs])

# Input: a square PATCH_SIZE x PATCH_SIZE patch with one channel per configured band.
input_shape = (config.PATCH_SIZE,) * 2 + (config.NUM_BANDS,)
model = build_unet_model(input_shape, NUM_CLASSES)
# Categorical cross-entropy expects one-hot targets (presumably produced by
# utils.create_dataset with NUM_CLASSES) and matches the softmax output layer.
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)
model.summary()

In [None]:
# --- Train the model ---
callbacks = [
    # Stop after 10 epochs without val_loss improvement and restore the best weights.
    EarlyStopping(monitor='val_loss', patience=10, verbose=1, restore_best_weights=True),
    # Keep the best-val_loss weights on disk (val_loss is the Keras default, made explicit).
    ModelCheckpoint(BEST_CHECKPOINT_PATH, monitor='val_loss', save_best_only=True,
                    save_weights_only=True, verbose=1),
]

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=config.EPOCHS,
    callbacks=callbacks
)

# Save history first: it is cheap, and this way the training record is kept
# even if model.save() below fails (e.g. on an unsupported file extension).
history_df = pd.DataFrame(history.history)
history_df.to_csv(HISTORY_CSV_PATH)
print(f"History saved to {HISTORY_CSV_PATH}")

# Save the final model
model.save(MODEL_SAVE_PATH)
print(f"Model saved to {MODEL_SAVE_PATH}")