<a href="https://colab.research.google.com/github/ariegever/ImageProcessing_Project/blob/main/2_unet_preprocess.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

# === HOW TO USE THIS NOTEBOOK ===
#
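# 1.  **Check Your Assets:**
#     * Go to your Earth Engine Assets tab.
#     * You must have 3 image assets uploaded:
#         1. Your Sentinel-2 Image
#         2. Your Sentinel-1 Image
#         3. Your Land Cover Mask
#     * Get the Asset ID for each one (e.g., "projects/...")
#     * The notebook also loads sampling coordinates from
#       `config.POINTS_ASSET_PATH`; make sure that path is valid too.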
#
# 2.  **Configuration:**
#     * All settings (project ID, asset IDs, band lists, output paths) are
#       read from `config.py`; edit that file before running this notebook.
#
# 3.  **Run All Cells (in order):**
#     * The script will save the final `.tfrecord.gz` file to the
#       Google Drive path specified in `config.py`.

In [None]:
import ee
import google.auth
from google.colab import auth, drive

import config

# Make Google Drive reachable at the mount point named in config.py.
drive.mount(config.DRIVE_MOUNT_PATH)

# Trigger the authentication flow.
auth.authenticate_user()

# Pick up the credentials produced by the step above and initialize Earth
# Engine with them. The project comes from config.py, not from the value
# returned by google.auth.default().
credentials, project = google.auth.default()
ee.Initialize(
    credentials,
    project=config.PROJECT_ID,
    opt_url='https://earthengine-highvolume.googleapis.com',
)

print(f"Successfully initialized Earth Engine for project: {config.PROJECT_ID}")

In [None]:
import json
import os
import pandas as pd
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import io
import requests
import concurrent.futures
from google.api_core import retry
from numpy.lib import recfunctions as rfn
from PIL import ImageColor
from matplotlib.colors import ListedColormap
from skimage.exposure import rescale_intensity
import utils

In [None]:
# === Configuration from config.py ===

# Make sure the Google Drive output directory exists, then build the full
# path of the TFRecord file that the main loop writes to.
os.makedirs(config.DRIVE_IMAGES_PATH, exist_ok=True)
OUTPUT_FILE = os.path.join(config.DRIVE_IMAGES_PATH, config.TFRECORD_FILE)

# Echo the settings that matter for this run, one per line.
print(
    f"Project: {config.PROJECT_ID}",
    f"Feature bands: {config.FEATURE_NAMES}",
    f"Output will be saved to: {OUTPUT_FILE}",
    f" expecting coordinates at: {config.POINTS_ASSET_PATH}",
    sep="\n",
)

In [None]:
import json
import pandas as pd
from matplotlib.colors import ListedColormap


# Read the class definitions; report which problem occurred, then re-raise
# so the notebook stops at this cell.
try:
    with open(config.CLASS_JSON_PATH) as f:
        lc = json.load(f)
except FileNotFoundError:
    print(f"ERROR: '{config.CLASS_JSON_PATH}' not found.")
    raise
except json.JSONDecodeError:
    print(f"ERROR: '{config.CLASS_JSON_PATH}' is not a valid JSON file.")
    raise

# The JSON is an object keyed by class index, so each key becomes a row.
# 'class' and 'color' are renamed to the column names used below, and the
# integer row key shifted by one becomes the normalized class value.
lc_df = (
    pd.DataFrame.from_dict(lc, orient='index')
    .rename(columns={'class': 'label', 'color': 'palette'})
    .assign(values_normalize=lambda df: df.index.astype(int) + 1)
)

# Build the parallel remap tables: every original value listed under a
# class maps to that class's normalized value.
from_values = []
to_values = []
for _, class_row in lc_df.iterrows():
    source_values = [entry['values'] for entry in class_row['original_classes']]
    from_values.extend(source_values)
    to_values.extend([class_row['values_normalize']] * len(source_values))

# Colour map for plotting; the value range runs from 1 to the class count.
palette_hex = lc_df["palette"].to_list()
cmap = ListedColormap(palette_hex)
vmin, vmax = 1, len(palette_hex)

print(f"Loaded {len(lc_df)} AGGREGATED classes from {config.CLASS_JSON_PATH}.")
print(f"Remapping {from_values} -> {to_values}")
lc_df



In [None]:
# 1. Load the S2, S1, and Land Cover images.
# ee.Image(...) and .select(...) only describe the request on the client, so
# a wrong asset ID or band name does not raise at construction time. Asking
# the server for each image's band names forces it to resolve the asset right
# away, which is what lets the except-branch below actually report a bad path.
try:
    s2_image = ee.Image(config.S2_ASSET_ID).select(config.S2_BANDS)
    s1_image = ee.Image(config.S1_ASSET_ID).select(config.S1_BANDS)
    lc_image = ee.Image(config.LC_ASSET_ID) # This should be a single-band image

    s2_image.bandNames().getInfo()
    s1_image.bandNames().getInfo()
    lc_image.bandNames().getInfo()
except ee.EEException as e:
    print(f"Error loading assets: {e}")
    print("Check your Asset Paths in config.py.")
    raise

# 2. Combine S1 and S2 into one feature image
feature_image = s2_image.addBands(s1_image)

# 3. Remap the first band of the land cover image: original class values in
#    from_values become the matching entry of to_values; any value not listed
#    becomes 0.
lc_remapped = lc_image.remap(from_values, to_values, 0, lc_image.bandNames().get(0))
lc_remapped = lc_remapped.rename(config.LABEL_NAME).toUint8() # Rename band to the configured label name

# 4. Stack all bands into one image (features + label)
all_bands_image = feature_image.addBands(lc_remapped)

ALL_BANDS = config.FEATURE_NAMES + [config.LABEL_NAME]
print("Earth Engine assets loaded and stacked.")
print(f"Total bands in stacked image: {ALL_BANDS}")

In [None]:
# Patch extraction now lives in utils.py (the main loop calls
# utils.get_patch). This cell defines nothing; it only prints a reminder.
print("Patch extraction functions defined in utils.py.")

In [None]:
# Coordinate loading now lives in utils.py (the main loop calls
# utils.load_coords_from_asset). This cell defines nothing; it only prints
# a reminder.
print("Coordinate loading function defined in utils.py.")

In [None]:
def _flat_band(structured_array, band_name, expected_size, dtype):
    """Return one band flattened and cast to ``dtype``, or None if unusable.

    A band is unusable when ``band_name`` is not a field of
    ``structured_array`` or when the field does not hold exactly
    ``expected_size`` elements.
    """
    if band_name not in structured_array.dtype.names:
        return None
    band = structured_array[band_name]
    if band.size != expected_size:
        return None
    return band.flatten().astype(dtype)


def array_to_example(structured_array):
    """Serialize a structured numpy array into a tf.Example proto.

    Every name in ``config.FEATURE_NAMES`` is written as a float_list feature
    and ``config.LABEL_NAME`` as an int64_list feature, each holding
    ``config.PATCH_SIZE * config.PATCH_SIZE`` flattened values.

    A band that is missing from the array, or whose element count differs
    from the expected patch size, is written as all zeros. That fallback is
    kept for compatibility, but a ``RuntimeWarning`` is now emitted so that
    zero-filled patches do not go unnoticed.

    Args:
        structured_array: numpy structured array with one field per band.

    Returns:
        A ``tf.train.Example`` with one feature per expected band.
    """
    import warnings

    # Number of pixels in one patch; identical for every band.
    expected_size = config.PATCH_SIZE * config.PATCH_SIZE
    feature = {}

    # Feature bands are stored as float32.
    for band_name in config.FEATURE_NAMES:
        values = _flat_band(structured_array, band_name, expected_size, np.float32)
        if values is None:
            warnings.warn(
                f"Feature band '{band_name}' is missing or does not contain "
                f"{expected_size} values; writing an all-zero patch instead.",
                RuntimeWarning,
                stacklevel=2,
            )
            values = np.zeros(expected_size, dtype=np.float32)
        else:
            # Replace NaNs with 0.0. Infinite values are left to
            # np.nan_to_num's defaults.
            values = np.nan_to_num(values, nan=0.0)
        feature[band_name] = tf.train.Feature(
            float_list=tf.train.FloatList(value=values))

    # The label band is stored as int64.
    labels = _flat_band(structured_array, config.LABEL_NAME, expected_size, np.int64)
    if labels is None:
        warnings.warn(
            f"Label band '{config.LABEL_NAME}' is missing or does not contain "
            f"{expected_size} values; writing an all-zero patch instead.",
            RuntimeWarning,
            stacklevel=2,
        )
        labels = np.zeros(expected_size, dtype=np.int64)
    feature[config.LABEL_NAME] = tf.train.Feature(
        int64_list=tf.train.Int64List(value=labels))

    return tf.train.Example(features=tf.train.Features(feature=feature))

In [None]:
# --- MAIN EXECUTION ---

# 1. Load the sampling coordinates
coords_list = utils.load_coords_from_asset(config.POINTS_ASSET_PATH)

if coords_list:
    print(f"Found {len(coords_list)} points. Starting patch extraction...")

    # 2. Open a GZIP-compressed TFRecord writer; the with-block closes the
    #    file even if the loop is interrupted.
    options = tf.io.TFRecordOptions(compression_type="GZIP")
    with tf.io.TFRecordWriter(OUTPUT_FILE, options=options) as writer:
        success_count, error_count = 0, 0

        # 3. Process the points one at a time (sequential for simplicity and
        #    to avoid rate limits). concurrent.futures could parallelize this
        #    if rate limits are handled.
        for i, coords in enumerate(coords_list):
            # Progress line on every tenth point.
            if not i % 10:
                print(f"Processing point {i+1}/{len(coords_list)}...")

            try:
                # Fetch the patch, serialize it, and append it to the file.
                patch_array = utils.get_patch(coords, all_bands_image)
                example = array_to_example(patch_array)
                writer.write(example.SerializeToString())
            except Exception as e:
                # A failing point is reported and counted, not fatal.
                print(f"Error processing point {i}: {e}")
                error_count += 1
            else:
                success_count += 1

    print("\nProcessing complete.")
    print(f"Successfully wrote {success_count} patches.")
    print(f"Failed points: {error_count}")
    print(f"Output file: {OUTPUT_FILE}")
else:
    print("No coordinates found. Exiting.")