# Step 1: Defining the Study Area

In [None]:
import os
import yaml
from pathlib import Path
import geopandas as gpd

import matplotlib.pyplot as plt
from matplotlib.patches import Patch

import rasterio
from rasterio.warp import calculate_default_transform, reproject, Resampling
from rasterio.mask import mask
from rasterio.plot import show

import requests
import numpy as np

from dotenv import load_dotenv
from osgeo import gdal
import ee

import xarray as xr
import rioxarray

import rioxarray as rxr
from dask.distributed import Client, LocalCluster
import pandas as pd

In [None]:


# -----------------------------------------------------------------------------
# Load config.yml
# -----------------------------------------------------------------------------

# The project root is derived from the current working directory, so this cell
# must be run with the notebook's folder as the working directory.
current_dir = Path(os.getcwd())
project_root = current_dir.parent.parent  # Navigate up from "Scripts/Phase1_Data_Preprocessing"

with open(project_root / "config.yml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)


# -----------------------------------------------------------------------------
# Construct paths
# -----------------------------------------------------------------------------

# Raw data paths (all relative to the configured raw-data folder)
raw_data_dir = project_root / config["paths"]["raw_data"]
soil_raw_dir = raw_data_dir / "GIS/Soil"  # Matches your hardcoded path structure
morocco_path = raw_data_dir / config["paths"]["morocco_path"]
tadla_plain_path = raw_data_dir / config["paths"]["tadla_plain_raw"]
tadla_plain_boundary_path = raw_data_dir / config["paths"]["tadla_plain_boundary_raw"]
soil_raw_path = raw_data_dir / config["paths"]["soil_raw"]
dem_raw_path = raw_data_dir / config["paths"]["dem_raw"]
chirps_raw_path = raw_data_dir / config["paths"]["chirps_raw"]
era5_raw_path = raw_data_dir / config["paths"]["era5_raw"]
wv0010_raw_path = raw_data_dir / config["paths"]["wv0010_raw"]
ndvi_path = raw_data_dir / config["paths"]["ndvi_raw"]


land_use_raw_dir = raw_data_dir / config["paths"]["land_use_raw"]


# Processed data paths (all relative to the configured processed-data folder)
processed_data_dir = project_root / config["paths"]["processed_data"]
soil_processed_dir = processed_data_dir / "GIS/Soil"
output_dir = processed_data_dir / "GIS/Study_Area_Boundary"
output_path = output_dir / "Tadla_plain_common.shp"
tadla_common_path = processed_data_dir / config["paths"]["tadla_boundary_processed"]
soil_processed_path = processed_data_dir / config["paths"]["soil_processed"]
dem_processed_path = processed_data_dir / config["paths"]["dem_processed"]
slope_path = processed_data_dir / "GIS/Topography/tadla_slope.tif"
aspect_path = processed_data_dir / "GIS/Topography/tadla_aspect.tif"
chirps_processed_path = processed_data_dir / config["paths"]["chirps_processed"]
era5_processed_path = processed_data_dir / config["paths"]["era5_processed"]
wv0010_processed_path = processed_data_dir / config["paths"]["wv0010_processed"]
topography_processed_dir = processed_data_dir / "GIS/Topography"

land_use_processed_dir = processed_data_dir / config["paths"]["land_use_processed"]

# Harmonized data paths
# NOTE(review): unlike every path above, these two config entries are not joined
# onto project_root, so relative values resolve against the working directory.
# Confirm this is intended.
harmonized_dir = Path(config["paths"]["harmonized_data"])
weather_processed_dir = processed_data_dir / "Weather"
chirps_output_dir = Path(config["paths"]["chirps_dir"])

output_path_dataset = harmonized_dir / "tadla_spatiotemporal_dataset.nc"


# Ensure every directory that later cells write into exists. Previously only
# three of them were created, so the first write into e.g. the Weather folder
# failed on a fresh checkout.
for _directory in (
    harmonized_dir,
    output_dir,
    era5_processed_path.parent,
    weather_processed_dir,
    soil_processed_dir,
    topography_processed_dir,
    soil_processed_path.parent,
    dem_processed_path.parent,
    wv0010_processed_path.parent,
    chirps_processed_path.parent,
):
    _directory.mkdir(parents=True, exist_ok=True)

In [None]:
# Load the Morocco boundary layer from the path configured in config.yml
morocco = gpd.read_file(morocco_path)

# Preview the first rows (attribute columns, e.g. province names)
morocco.head()

In [None]:
# Show the CRS the Morocco layer was loaded with
print(morocco.crs)

In [None]:
# Reproject Morocco to EPSG:26191 (called "Merchich" throughout this notebook)
morocco_merchiche = morocco.to_crs(epsg=26191)

In [None]:
# Quick visual check of the reprojected Morocco layer
morocco_merchiche.plot()

In [None]:
# Load Tadla Plain shapefile
tadla_plain_polygon = gpd.read_file(tadla_plain_path)

# Print the GeoDataFrame to inspect what was loaded (print shows the frame's
# repr, not just the first rows)
print(tadla_plain_polygon)


In [None]:
# Quick visual check of the loaded Tadla Plain geometry
tadla_plain_polygon.plot()

In [None]:
# GeoSeries.area yields one value per feature, in the units of the layer's CRS.
# Project to the metric EPSG:26191 grid and sum, so the printed figure is a
# single number that really is in square metres.
_study_area_m2 = tadla_plain_polygon.to_crs(epsg=26191).geometry.area.sum()
print(f"Study area size: {_study_area_m2:.2f} m²")

In [None]:
# Reproject to Merchich (EPSG:26191) so areas come out in square metres
tadla_merchiche = tadla_plain_polygon.to_crs(epsg=26191)

# Per-feature areas. The total is reported with .sum() rather than area_m2[0]:
# [0] is a label lookup that ignores every feature but the first and raises
# KeyError when the index has no label 0.
area_m2 = tadla_merchiche.geometry.area
print(f"Study area size: {area_m2.sum():.2f} m²")
# Example output: "Study area size: 1300000000.00 m²"

area_ha = area_m2 / 10000  # 1 hectare = 10,000 m²
print(f"Study area size: {area_ha.sum():.2f} hectares")
# Example output: "Study area size: 130000.00 hectares"


In [None]:
# Quick visual check of the reprojected Tadla Plain layer
tadla_merchiche.plot()

In [None]:
# Load the cleaned boundary shapefile
Tadla_plain_boundary = gpd.read_file(tadla_plain_boundary_path)
# Check the CRS it was stored in (it is reprojected in the next cell if needed)
print(Tadla_plain_boundary.crs)

In [None]:
# Convert to Merchich CRS (EPSG:26191) unless the layer is already in it
if Tadla_plain_boundary.crs != "EPSG:26191":
    Tadla_plain_boundary = Tadla_plain_boundary.to_crs(epsg=26191)


In [None]:
# Quick visual check of the boundary layer
Tadla_plain_boundary.plot()

In [None]:
# Both inputs were brought to EPSG:26191 in the cells above:
#   tadla_merchiche      - drawn as "Full Admin Boundary" in the plots below
#   Tadla_plain_boundary - drawn as "Digitized Tadla Plain" in the plots below

# Compute the common (intersecting) area between the two layers
tadla_plain = gpd.overlay(Tadla_plain_boundary, tadla_merchiche, how='intersection')

# Save the resulting common area shapefile for further analysis.
# output_path must still point at the shapefile path set in the config cell.
tadla_plain.to_file(output_path)

In [None]:


# Draw the three study-area layers on one shared axis
fig, ax = plt.subplots(figsize=(8, 8))
tadla_merchiche.plot(ax=ax, linewidth=2, edgecolor="red", facecolor="none")
Tadla_plain_boundary.plot(ax=ax, edgecolor="blue", alpha=0.5, facecolor="blue")
tadla_plain.plot(ax=ax, edgecolor="black", alpha=0.5, facecolor="green")

# GeoDataFrame.plot adds no legend entries here, so build proxy patches by hand
legend_labels = dict(
    [
        ("Full Admin Boundary", "red"),
        ("Digitized Tadla Plain", "blue"),
        ("Common Area", "green"),
    ]
)
patches = [Patch(label=name, color=shade) for name, shade in legend_labels.items()]
ax.legend(handles=patches)

ax.set_title("Common Area between Tadla Plain and Full Admin Boundary")
plt.show()

In [None]:
# Same map as the previous cell, with the national outline added for context
fig, ax = plt.subplots(figsize=(8, 8))
tadla_merchiche.plot(ax=ax, linewidth=2, edgecolor="red", facecolor="none")
Tadla_plain_boundary.plot(ax=ax, edgecolor="blue", alpha=0.5, facecolor="blue")
tadla_plain.plot(ax=ax, edgecolor="black", alpha=0.5, facecolor="green")
morocco_merchiche.plot(ax=ax, linewidth=1, edgecolor="brown", facecolor="none")

# Proxy patches for the legend, one per layer colour
legend_labels = dict(
    [
        ("Full Admin Boundary", "red"),
        ("Digitized Tadla Plain", "blue"),
        ("Common Area", "green"),
        ("Morocco", "brown"),
    ]
)
patches = [Patch(label=name, color=shade) for name, shade in legend_labels.items()]
ax.legend(handles=patches)

ax.set_title("Common Area between Tadla Plain and Full Admin Boundary of Morocco")
plt.show()

In [None]:
# Make sure both layers are in the metric EPSG:26191 grid before measuring
tadla_plain = tadla_plain.to_crs(epsg=26191)
tadla_merchiche = tadla_merchiche.to_crs(epsg=26191)

# Total area of each layer (sum over all features)
area_plain_m2, area_full_m2 = (
    layer.geometry.area.sum() for layer in (tadla_plain, tadla_merchiche)
)

print(f"Tadla Plain area: {area_plain_m2:.2f} m²")
print(f"Full Admin Boundary area: {area_full_m2:.2f} m²")


In [None]:


def reproject_raster(input_path, output_path, target_crs, resampling=Resampling.nearest):
    """Reproject every band of a raster file into ``target_crs``.

    The output grid (transform, width, height) is the default one computed by
    ``calculate_default_transform`` for the source extent.

    Args:
        input_path: Path of the raster to read.
        output_path: Path of the raster to create (overwritten if present).
        target_crs: Destination CRS, in any form rasterio accepts.
        resampling: ``rasterio.warp.Resampling`` method. Defaults to nearest
            neighbour, which is also rasterio's own default.
    """
    with rasterio.open(input_path) as src:
        transform, width, height = calculate_default_transform(
            src.crs, target_crs, src.width, src.height, *src.bounds
        )
        metadata = src.meta.copy()
        metadata.update({
            "crs": target_crs,
            "transform": transform,
            "width": width,
            "height": height
        })

        with rasterio.open(output_path, "w", **metadata) as dest:
            # The output keeps the source band count, so every band has to be
            # warped; warping only band 1 left the remaining bands empty.
            for band_index in range(1, src.count + 1):
                reproject(
                    source=rasterio.band(src, band_index),
                    destination=rasterio.band(dest, band_index),
                    src_transform=src.transform,
                    src_crs=src.crs,
                    dst_transform=transform,
                    dst_crs=target_crs,
                    resampling=resampling
                )

# Step 2: Downloading Soil Data (SoilGrids)

In [None]:

# Read the processed study-area boundary and express it in EPSG:26191
tadla = gpd.read_file(tadla_common_path).to_crs("EPSG:26191")

# total_bounds is ordered (minx, miny, maxx, maxy)
minx, miny, maxx, maxy = tadla.total_bounds
print(f"X: {minx}, {maxx}")  # Easting bounds
print(f"Y: {miny}, {maxy}")  # Northing bounds

### 1. Defining Parameters

In [None]:
# Bounding box of Tadla Plain in EPSG:26191.
# NOTE: these hard-coded values replace the minx/miny/maxx/maxy computed from
# the boundary file in the previous cell.
minx, maxx = 301450, 490490  # X (Easting)
miny, maxy = 158150, 244870   # Y (Northing)

# Soil layers: key = name of the map file requested, value = WCS COVERAGEID
# (both are inserted into the download URL)
layers = {
    "clay": "clay_0-5cm_mean",
    "silt": "silt_0-5cm_mean",
    "sand": "sand_0-5cm_mean",
    "ocd": "ocd_0-5cm_mean",    # Organic carbon density
    "wv0010": "wv0010_0-5cm_mean"     # Water content; NOTE(review): SoilGrids describes wv0010 as water content at 10 kPa rather than at saturation - confirm
}

In [None]:
from pyproj import Transformer

# Create a transformer from EPSG:4326 to EPSG:26191.
# always_xy=True makes inputs (lon, lat) and outputs (easting, northing).
transformer = Transformer.from_crs("EPSG:4326", "EPSG:26191", always_xy=True)

# Transform the lower-left corner (-7.5, 32.0)
easting_min, northing_min = transformer.transform(-7.5, 32.0)
print(easting_min, northing_min)  # Expected: ~339200, ~164400

# Transform the upper-right corner (-5.5, 32.8)
easting_max, northing_max = transformer.transform(-5.5, 32.8)
print(easting_max, northing_max)  # Expected: ~459750, ~241200


### 2. Python Script to Download All Layers

In [None]:

os.makedirs(soil_raw_dir, exist_ok=True)

# Request one GeoTIFF per soil property from the SoilGrids WCS endpoint,
# subset to the bounding box defined above.
for param, coverage_id in layers.items():
    url = (
        f"https://maps.isric.org/mapserv?map=/map/{param}.map&"
        f"SERVICE=WCS&"
        f"VERSION=2.0.1&"
        f"REQUEST=GetCoverage&"
        f"COVERAGEID={coverage_id}&"
        f"FORMAT=GEOTIFF_INT16&"  # Or GEOTIFF_FLOAT32 for raw values
        f"SUBSET=X({minx},{maxx})&"
        f"SUBSET=Y({miny},{maxy})&"
        f"SUBSETTINGCRS=http://www.opengis.net/def/crs/EPSG/0/26191&"
        f"OUTPUTCRS=http://www.opengis.net/def/crs/EPSG/0/26191"
    )
    print(url)

    # Download and save. A timeout is required: without one, requests can wait
    # on a stalled server indefinitely.
    try:
        response = requests.get(url, timeout=120)
    except requests.RequestException as exc:
        print(f"Failed to download {param}: {exc}")
        continue

    if response.status_code == 200:
        # Use a dedicated name: rebinding the notebook-wide `output_path` here
        # would redirect the study-area shapefile written by an earlier cell.
        soil_tif_path = os.path.join(soil_raw_dir, f"tadla_{param}.tif")
        with open(soil_tif_path, "wb") as f:
            f.write(response.content)
        print(f"Downloaded {param} to {soil_tif_path}")
    else:
        print(f"Failed to download {param}: HTTP {response.status_code}")


### 3. Post-Processing

1. Unit Conversion:

    SoilGrids stores integer values as actual value × 10. 
    
    For example:
        A pixel value of 150 = 15% clay.

In [None]:
# Process soil data: convert the raw integer clay raster (value x 10) to
# percent and save it as float32.
with rasterio.open(soil_raw_path) as src:
    raw_clay = src.read(1)
    clay = raw_clay.astype(np.float32) / 10  # Convert to %
    # Keep nodata cells at the declared nodata value instead of scaling them
    if src.nodata is not None:
        clay[raw_clay == src.nodata] = src.nodata
    profile = src.profile.copy()
    profile.update(dtype=rasterio.float32, count=1)

# Write the converted array. The previous version wrote src.read() again,
# i.e. the unconverted raw values.
with rasterio.open(soil_processed_path, "w", **profile) as dst:
    dst.write(clay, 1)


In [None]:
# Sanity check of the processed clay raster: pixel size and value range
with rasterio.open(soil_processed_path) as src:
    print(src.res)  # pixel size (x, y) in CRS units
    clay_band = src.read(1)  # read once, reuse for both statistics
    print(clay_band.min(), clay_band.max())  # e.g., 0.0–38.9%

In [None]:
# Read the NDVI raster; its grid (transform, CRS, size) is the alignment target
with rasterio.open(ndvi_path) as ndvi_src:
    ndvi_data = ndvi_src.read(1)
    ndvi_transform = ndvi_src.transform
    ndvi_crs = ndvi_src.crs
    ndvi_width = ndvi_src.width
    ndvi_height = ndvi_src.height

# Read band 1 of the processed clay raster together with its georeferencing
with rasterio.open(soil_processed_path) as soil_src:
    soil_data = soil_src.read(1)
    soil_crs = soil_src.crs
    soil_transform = soil_src.transform

# Allocate an (uninitialised) array matching NDVI's shape; reproject() fills it
soil_reproj = np.empty((ndvi_height, ndvi_width), dtype=soil_data.dtype)

# Force the clay raster onto NDVI's exact grid.
# NOTE(review): no src_nodata/dst_nodata is passed, so nodata cells are
# presumably treated as ordinary values - confirm this is acceptable.
reproject(
    source=soil_data,
    destination=soil_reproj,
    src_transform=soil_transform,
    src_crs=soil_crs,
    dst_transform=ndvi_transform,
    dst_crs=ndvi_crs,
    resampling=Resampling.nearest
)

# Now ndvi_data and soil_reproj have the same shape and alignment

# Now ndvi_data and soil_reproj have the same shape and alignment


In [None]:
study_area_path = output_dir / "Tadla_plain_common.shp"

# 1. Open and plot the clay raster (drawn while the dataset is still open)
with rasterio.open(soil_processed_path) as src:
    clay_crs = src.crs

    fig, ax = plt.subplots(figsize=(10, 8))
    show(src, ax=ax, cmap="Reds", alpha=0.7)
    ax.set_title("Clay Map with Study Area Boundary")

# 2. Load the study area once and reproject it to the clay raster's CRS.
#    The earlier extra read, reprojected to the NDVI CRS, was immediately
#    overwritten and only added a dependency on another cell.
study_area = gpd.read_file(study_area_path).to_crs(clay_crs)

# 3. Overlay the boundary on the same axes
study_area.plot(
    ax=ax,
    facecolor="none",
    edgecolor="black",
    linewidth=1
)

plt.show()

In [None]:
# Resample the raw water-content raster to a 10 m grid in its own CRS
target_res = 10  # output pixel size, in CRS units

with rasterio.open(wv0010_raw_path) as src:
    raw_wv = src.read(1)
    wv_nodata = src.nodata
    data = raw_wv.astype(np.float32) / 10  # Convert to %
    # Leave nodata cells at the declared nodata value rather than scaling them
    if wv_nodata is not None:
        data[raw_wv == wv_nodata] = wv_nodata
    profile = src.profile.copy()

    # Output size for the target resolution. round() avoids losing a row or
    # column when the ratio lands just below a whole number.
    new_width = int(round(src.width * src.res[0] / target_res))
    new_height = int(round(src.height * abs(src.res[1]) / target_res))

    # Create empty array for resampled data
    resampled_data = np.empty((new_height, new_width), dtype=np.float32)

    # Same upper-left corner as the source, square pixels of target_res
    target_transform = rasterio.Affine(
        target_res, 0, src.bounds.left, 0, -target_res, src.bounds.top
    )

    # Bilinear resampling. Passing the nodata value keeps nodata cells from
    # being blended into neighbouring valid values.
    reproject(
        source=data,
        destination=resampled_data,
        src_transform=src.transform,
        dst_transform=target_transform,
        src_crs=src.crs,
        dst_crs=src.crs,
        src_nodata=wv_nodata,
        dst_nodata=wv_nodata,
        resampling=Resampling.bilinear
    )

# Update metadata for the processed file
profile.update({
    "transform": target_transform,
    "width": new_width,
    "height": new_height,
    "dtype": "float32"
})

# Save resampled data
with rasterio.open(wv0010_processed_path, "w", **profile) as dst:
    dst.write(resampled_data, 1)

print(f"Resampled water content saved to: {wv0010_processed_path}")

print(f"Resampled water content saved to: {wv0010_processed_path}")

In [None]:
# Sanity check of the resampled water-content raster
with rasterio.open(wv0010_processed_path) as src:
    print(src.res)  # Should output (10.0, 10.0)
    wv_band = src.read(1)  # read once, reuse for both statistics
    print(wv_band.min(), wv_band.max())  # e.g., 0.0–38.9%

2. Validate CRS Alignment

    Confirm all downloaded rasters are in EPSG:26191

In [None]:

# Print the CRS stored in the raw soil raster
with rasterio.open(soil_raw_path) as src:
    print(src.crs)  # Should print "EPSG:26191"

# Step 3: DEM Data

1. Download DEM Data

    We’ll use the ALOS World 3D 30 m DSM (the `JAXA/ALOS/AW3D30/V3_2` collection loaded below) from Google Earth Engine (GEE), exported at a 12.5 m pixel size.

In [None]:

# Load Tadla boundary (ensure this path is correct)
tadla_shp_path = tadla_common_path
tadla = gpd.read_file(tadla_shp_path)

# Check current CRS
print(f"Current CRS: {tadla.crs}")  # Should be EPSG:26191 (Merchich)

# Reproject to WGS84 (EPSG:4326)
tadla_wgs84 = tadla.to_crs("EPSG:4326")

# Save the WGS84 copy next to the original instead of overwriting it: other
# cells read tadla_common_path expecting the EPSG:26191 boundary.
tadla_wgs84_path = tadla_shp_path.with_name(
    f"{tadla_shp_path.stem}_wgs84{tadla_shp_path.suffix}"
)
tadla_wgs84.to_file(tadla_wgs84_path)

In [None]:


# Load environment variables from the .env file
load_dotenv()

# The Google Cloud project ID is taken from the GCP_PROJECT environment variable
project_id = os.environ.get('GCP_PROJECT')
if not project_id:
    raise ValueError("The environment variable GCP_PROJECT is not set.")

print("Using project ID:", project_id)

# Authenticate with Earth Engine, then initialise the client for that project
import ee
ee.Authenticate()
ee.Initialize(project=project_id)

In [None]:
# Test authentication
# Test authentication: fetch the "title" property of a public image from the server
print(ee.Image("NASA/NASADEM_HGT/001").get("title").getInfo())

In [None]:
# Rectangular region of interest in lon/lat degrees (EPSG:4326)
bbox = ee.Geometry.Rectangle(
    [-7.5, 32.0, -5.5, 32.8],  # minx, miny, maxx, maxy
    proj="EPSG:4326"
)

In [None]:
# Load ALOS DEM ImageCollection and select the 'DSM' band
dem_collection = ee.ImageCollection("JAXA/ALOS/AW3D30/V3_2").select('DSM')

# Mosaic the collection into a single image (combines all tiles over Tadla)
dem = dem_collection.mosaic().clip(bbox)


In [None]:
# Export to Google Drive
task = ee.batch.Export.image.toDrive(
    image=dem,
    description='Tadla_DEM',
    folder='Tadla_Project',
    scale=12.5,
    region=bbox,
    crs="EPSG:26191",  # Merchich CRS
    fileFormat='GeoTIFF',
    maxPixels=1e13
)
task.start()

# Monitor task progress
print(f"Task ID: {task.id}")
print("Check progress at: https://code.earthengine.google.com/tasks")

2. Preprocess DEM
    
    Once downloaded, move the DEM to Data/Raw/GIS/Topography/ and preprocess it:

In [None]:


# Load boundary and ensure it's in the same CRS as the DEM (EPSG:26191)
tadla = gpd.read_file(tadla_common_path)
if tadla.crs != "EPSG:26191":
    tadla = tadla.to_crs("EPSG:26191")

# Load DEM and check its CRS
with rasterio.open(dem_raw_path) as src:
    dem_crs = src.crs
    print(f"DEM CRS: {dem_crs}")  # Should be EPSG:26191

    # Reproject boundary if the DEM turns out to be in a different CRS
    if tadla.crs != dem_crs:
        tadla = tadla.to_crs(dem_crs)

    # Validate that the two extents intersect
    dem_bounds = src.bounds
    tadla_bounds = tadla.total_bounds  # (minx, miny, maxx, maxy)
    print(f"DEM Bounds: {dem_bounds}")
    print(f"Tadla Bounds: {tadla_bounds}")

    b_minx, b_miny, b_maxx, b_maxy = tadla_bounds
    overlaps = (
        b_minx < dem_bounds.right
        and b_maxx > dem_bounds.left
        and b_miny < dem_bounds.top
        and b_maxy > dem_bounds.bottom
    )
    if not overlaps:
        raise ValueError("DEM and boundary do not overlap. Check their geographic extents!")

    # The old check required the boundary to lie strictly inside the DEM and
    # reported "do not overlap" otherwise. Partial coverage is now a warning.
    contained = (
        b_minx >= dem_bounds.left
        and b_maxx <= dem_bounds.right
        and b_miny >= dem_bounds.bottom
        and b_maxy <= dem_bounds.top
    )
    if not contained:
        print("Warning: the boundary extends beyond the DEM; the clipped DEM will not cover the whole study area.")

    # Clip DEM to the boundary geometry (crop=True shrinks the output extent)
    tadla_dem, transform = mask(src, tadla.geometry, crop=True)
    meta = src.meta.copy()
    meta.update({
        "height": tadla_dem.shape[1],
        "width": tadla_dem.shape[2],
        "transform": transform,
        "crs": dem_crs
    })

# Save clipped DEM
dem_processed_path.parent.mkdir(parents=True, exist_ok=True)
with rasterio.open(dem_processed_path, "w", **meta) as dest:
    dest.write(tadla_dem)
print(f"Clipped DEM saved to: {dem_processed_path}")

In [None]:
# Confirm the raw DEM and the processed boundary file are present on disk
print(f"DEM exists: {dem_raw_path.exists()}")
print(f"Boundary exists: {tadla_common_path.exists()}")

### 3. Derive Slope and Aspect

1. Calculating Slope and Aspect Using GDAL

In [None]:
# Enable GDAL exceptions
gdal.UseExceptions()

# Ensure output directories exist
os.makedirs(slope_path.parent, exist_ok=True)

# Calculate slope
slope = gdal.DEMProcessing(
    destName=str(slope_path),
    srcDS=str(dem_processed_path),
    processing="slope",
    format="GTiff",
    slopeFormat="degree"
)

# Calculate aspect
aspect = gdal.DEMProcessing(
    destName=str(aspect_path),
    srcDS=str(dem_processed_path),
    processing="aspect",
    format="GTiff"
)

print(f"Slope saved to: {slope_path}")
print(f"Aspect saved to: {aspect_path}")

# Step 4: Weather Data

#### 1. Download CHIRPS Rainfall Data

In [None]:
# Authenticate and initialize GEE
ee.Authenticate()
ee.Initialize(project=project_id)

# Define study area (Tadla Plain)
tadla_plain = ee.Geometry.Rectangle(
    [-7.5, 32.0, -5.5, 32.8],  # Adjust to your exact boundary
    proj="EPSG:4326",
    geodesic=False
)

# Load CHIRPS daily rainfall (2017–2023)
chirps = ee.ImageCollection("UCSB-CHG/CHIRPS/DAILY") \
    .filterDate("2017-01-01", "2023-12-31") \
    .filterBounds(tadla_plain)

# Convert to multi-band image (one band per day)
chirps_multi_band = chirps.toBands()

# Export to Google Drive
task = ee.batch.Export.image.toDrive(
    image=chirps_multi_band,
    description="CHIRPS_Daily_Tadla_2017-2023",
    folder="Tadla_Project",
    fileNamePrefix="CHIRPS_Daily_Tadla_2017-2023",
    region=tadla_plain,
    scale=5000,  # CHIRPS native resolution (~5km)
    crs="EPSG:4326",  # Reproject to EPSG:26191 later in Python
    maxPixels=1e13,
    fileFormat="GeoTIFF"
)

task.start()
print("Export started! Monitor at: https://code.earthengine.google.com/tasks")



#### 2. Download ERA5-Land Evaporation Data

In [None]:

ee.Initialize(project=project_id)
# Define study area as a planar lon/lat rectangle
tadla_plain = ee.Geometry.Rectangle(
    [-7.5, 32.0, -5.5, 32.8],
    proj="EPSG:4326",
    geodesic=False
)

# Load ERA5-Land DAILY_AGGR and select evaporation. filterDate treats the end
# date as exclusive, so the end must be 2024-01-01 to include 2023-12-31.
era5_land = ee.ImageCollection("ECMWF/ERA5_LAND/DAILY_AGGR") \
    .filterDate("2017-01-01", "2024-01-01") \
    .filterBounds(tadla_plain) \
    .select("total_evaporation_sum")

# Convert to multi-band image (one band per day)
era5_multi_band = era5_land.toBands()

# Export to Google Drive
task = ee.batch.Export.image.toDrive(
    image=era5_multi_band,
    description="ERA5_Land_Evaporation_Tadla_2017-2023",
    folder="Tadla_Project",
    fileNamePrefix="ERA5_Land_Evaporation_Tadla_2017-2023",
    region=tadla_plain,
    scale=11132,  # ERA5-Land resolution: 0.1° (~11km)
    crs="EPSG:4326",
    maxPixels=1e13,
    fileFormat="GeoTIFF"
)

task.start()
print("Export started! Monitor progress at: https://code.earthengine.google.com/tasks")

#### 3. Preprocess CHIRPS Rainfall Data

In [None]:
# Load CHIRPS data with simplified band names
chirps_ds = rxr.open_rasterio(chirps_raw_path)

# Generate dates for 2017-01-01 to 2023-12-31
dates = pd.date_range(start="2017-01-01", periods=chirps_ds.sizes["band"], freq="D")

# Assign time coordinates to the 'band' dimension
chirps_ds = chirps_ds.assign_coords(band=dates).rename({"band": "time"})

Process Monthly Aggregated Data

In [None]:
import pandas as pd
import geopandas as gpd
import numpy as np
from dask.distributed import Client, LocalCluster
from rasterio.enums import Resampling

def process_month(year, month, spatial_chunk=128):
    """Aggregate daily CHIRPS data for one month, reproject, clip, and save as GeoTIFF.

    Reads the notebook-level ``chirps_ds``, ``tadla_common_path`` and
    ``weather_processed_dir``. Errors are reported with ``print`` rather than
    raised, so a failing month does not stop a surrounding loop.

    Args:
        year: Calendar year of the month to process.
        month: Month number (1-12).
        spatial_chunk: Currently unused; kept so existing calls remain valid.
    """
    # Pre-bind so the cleanup in ``finally`` cannot hit a NameError when the
    # cluster fails to start.
    cluster = None
    client = None
    try:
        # Adjust worker settings based on your system.
        cluster = LocalCluster(n_workers=4, memory_limit='8GB')
        client = Client(cluster)

        # Determine start and end dates for the month
        start_date = f"{year}-{month:02d}-01"
        end_day = pd.Timestamp(year, month, 1).daysinmonth
        end_date = f"{year}-{month:02d}-{end_day:02d}"

        # Select daily data for the month and aggregate (sum) over time
        ds_month = chirps_ds.sel(time=slice(start_date, end_date))
        if ds_month.sizes["time"] == 0:
            # Summing an empty selection would silently write an all-zero raster
            raise ValueError(f"no daily data between {start_date} and {end_date}")
        ds_month_agg = ds_month.sum(dim="time")

        # Reproject to target CRS (EPSG:26191) with 10 m resolution
        ds_reproj = ds_month_agg.rio.reproject(
            "EPSG:26191",
            resolution=10,
            resampling=Resampling.bilinear,
            nodata=np.nan
        )

        # Read Tadla boundary from the processed shapefile
        tadla = gpd.read_file(tadla_common_path)

        # Clip to the Tadla boundary (geometries are interpreted in tadla.crs)
        ds_clipped = ds_reproj.rio.clip(tadla.geometry, tadla.crs)

        # Remove problematic attributes if present (e.g., "long_name")
        ds_clipped.attrs.pop("long_name", None)

        # Save the monthly aggregated data to a GeoTIFF file
        weather_processed_dir.mkdir(parents=True, exist_ok=True)
        out_filename = weather_processed_dir / f"CHIRPS_{year}_{month:02d}.tif"
        ds_clipped.rio.to_raster(out_filename)
        print(f"Saved {out_filename}")

    except Exception as e:
        print(f"Error processing {year}-{month:02d}: {e}")
    finally:
        if client is not None:
            client.close()
        if cluster is not None:
            cluster.close()



In [None]:

# Process each month in your desired time period (2017-2023)
# Process every month of 2017-2023 (range end 2024 is exclusive)
for year in range(2017, 2024):
    for month in range(1, 13):
        process_month(year, month)

#### 4. Preprocess ERA5 Evaporation Data

Load and Process ERA5 Data

In [None]:
import rioxarray as rxr
import xarray as xr

# Load the exported multi-band ERA5 GeoTIFF lazily, in chunks of one band by 256x256 pixels
era5_ds = rxr.open_rasterio(era5_raw_path, chunks={"band": 1, "x": 256, "y": 256})

# Assign time coordinates (assuming bands are consecutive days from 2017-01-01)
dates = pd.date_range(start="2017-01-01", periods=era5_ds.sizes["band"], freq="D")
era5_ds = era5_ds.assign_coords(band=dates).rename({"band": "time"})

# Convert units: m/day → mm/day
era5_ds = era5_ds * 1000  # 1 m/day = 1000 mm/day
era5_ds.attrs["units"] = "mm/day"

Reproject, Clip, and Save (ERA5)

In [None]:
def process_era5_month(year, month, spatial_chunk=128):
    """Sum daily ERA5 evaporation for one month, reproject, clip, and save as GeoTIFF.

    Reads the notebook-level ``era5_ds``, ``tadla_common_path`` and
    ``weather_processed_dir``. Errors are reported with ``print`` rather than
    raised, so a failing month does not stop a surrounding loop.

    Args:
        year: Calendar year of the month to process.
        month: Month number (1-12).
        spatial_chunk: Currently unused; kept so existing calls remain valid.
    """
    # Pre-bind so the cleanup in ``finally`` cannot hit a NameError when the
    # cluster fails to start.
    cluster = None
    client = None
    try:
        # Start Dask cluster
        cluster = LocalCluster(n_workers=4, memory_limit='8GB')
        client = Client(cluster)

        # Slice to the month and aggregate to a monthly total (sum, not mean)
        start_date = f"{year}-{month:02d}-01"
        end_date = f"{year}-{month:02d}-{pd.Timestamp(year, month, 1).daysinmonth}"
        ds_days = era5_ds.sel(time=slice(start_date, end_date))
        if ds_days.sizes["time"] == 0:
            # Summing an empty selection would silently write an all-zero raster
            raise ValueError(f"no daily data between {start_date} and {end_date}")
        ds_month = ds_days.sum(dim="time")

        # Reproject to EPSG:26191 at 10 m resolution
        ds_reproj = ds_month.rio.reproject(
            "EPSG:26191",
            resolution=10,
            resampling=Resampling.bilinear,
            nodata=np.nan
        )

        # Clip to Tadla boundary (geometries are interpreted in tadla.crs)
        tadla = gpd.read_file(tadla_common_path)
        ds_clipped = ds_reproj.rio.clip(tadla.geometry, tadla.crs)

        # Save
        weather_processed_dir.mkdir(parents=True, exist_ok=True)
        out_filename = weather_processed_dir / f"ERA5_{year}_{month:02d}.tif"
        ds_clipped.rio.to_raster(out_filename)
        print(f"Saved {out_filename}")

    except Exception as e:
        print(f"Error processing {year}-{month:02d}: {e}")
    finally:
        if client is not None:
            client.close()
        if cluster is not None:
            cluster.close()

In [None]:
# Process all months
# Process every month of 2017-2023 (range end 2024 is exclusive)
for year in range(2017, 2024):
    for month in range(1, 13):
        process_era5_month(year, month)

#### 5. Combining Monthly Files into Annual Datasets

In [None]:
import rasterio
from rasterio.warp import reproject, Resampling

def reproject_annual(year, dataset="ERA5"):
    """Stack the 12 monthly GeoTIFFs of one year into a single 12-band file.

    Monthly files are read from ``weather_processed_dir`` as
    ``{dataset}_{year}_{MM}.tif``. Each is warped onto the grid of the January
    file and written as band 1-12 of
    ``{dataset}_Annual/{dataset}_{year}_reproj.tif``.

    Args:
        year: Year whose monthly files are combined.
        dataset: File-name prefix of the monthly files (e.g. "ERA5", "CHIRPS").

    Raises:
        FileNotFoundError: If any of the 12 monthly files is missing. This is
            checked up front so no partially filled annual file is left behind.
    """
    input_dir = Path(weather_processed_dir)
    output_dir = input_dir / f"{dataset}_Annual"
    output_dir.mkdir(parents=True, exist_ok=True)

    # List monthly files for the year
    monthly_files = [input_dir / f"{dataset}_{year}_{month:02d}.tif"
                     for month in range(1, 13)]

    missing = [str(path) for path in monthly_files if not path.exists()]
    if missing:
        raise FileNotFoundError(
            f"Missing monthly {dataset} files for {year}: " + ", ".join(missing)
        )

    # The January file defines the output grid and metadata
    with rasterio.open(monthly_files[0]) as first:
        meta = first.meta.copy()
        meta.update(count=12)  # Explicitly set to 12 bands

    with rasterio.open(output_dir / f"{dataset}_{year}_reproj.tif", "w", **meta) as dst:
        for band_idx, monthly_file in enumerate(monthly_files, start=1):
            with rasterio.open(monthly_file) as src:
                reproject(
                    source=rasterio.band(src, 1),
                    destination=rasterio.band(dst, band_idx),
                    src_transform=src.transform,
                    src_crs=src.crs,
                    dst_transform=meta["transform"],
                    dst_crs=meta["crs"],
                    resampling=Resampling.bilinear
                )
    print(f"Processed {dataset}_{year}_reproj.tif with 12 bands")

In [None]:
# Reprocess all years
# Build the annual 12-band CHIRPS file for each year 2017-2023
for year in range(2017, 2024):
    reproject_annual(year, "CHIRPS")

In [None]:
# Reprocess all years
# Build the annual 12-band ERA5 file for each year 2017-2023
for year in range(2017, 2024):
    reproject_annual(year, "ERA5")

# Step 5: Land Use/Crop Maps (Sentinel-2)

1. Authenticate & Initialize Earth Engine

In [None]:
# Authenticate (this will open a browser window for authentication if needed)
ee.Authenticate()

# Initialize with your project settings (make sure you have set your GCP_PROJECT in your environment variables)
ee.Initialize(project=project_id)

print("Earth Engine has been initialized successfully!")


In [None]:
tadla = gpd.read_file(tadla_common_path)

# Reproject to WGS84 (EPSG:4326) if needed. ee.Geometry.Polygon interprets
# bare coordinates as lon/lat unless a projection is passed, so handing it
# EPSG:26191 metres produced a meaningless geometry.
if tadla.crs != "EPSG:4326":
    tadla = tadla.to_crs("EPSG:4326")

# Convert the exterior ring of the first feature to a GEE polygon.
# NOTE(review): assumes the first feature is a single Polygon; a MultiPolygon
# has no .exterior.
tadla_geom = ee.Geometry.Polygon(list(tadla.geometry.iloc[0].exterior.coords))

In [None]:
# Load Tadla boundary (WGS84)
tadla_geom = ee.Geometry.Polygon(
    [[-7.5, 32.0], [-5.5, 32.0], [-5.5, 32.8], [-7.5, 32.8]], 
    proj="EPSG:4326", 
    geodesic=False
)

# Reproject to EPSG:26191 (Merchich)
tadla_merc = tadla_geom.transform('EPSG:26191', 1)  # 1-meter error margin

In [None]:
def get_annual_composite(year):
    """Build a April-September median NDVI image from Sentinel-2 for ``year``.

    Uses the notebook-level ``tadla_merc`` geometry to select scenes.

    Args:
        year: Calendar year of the composite.

    Returns:
        A single-band ``ee.Image`` named ``NDVI``, reprojected to EPSG:26191
        at 10 m scale.
    """
    start_date = f'{year}-04-01'
    # filterDate excludes its end date, so 1 October is needed to keep the
    # scenes of 30 September.
    end_date = f'{year}-10-01'

    # Load Sentinel-2 collection: scenes over the region, in the date window,
    # with less than 10 % reported cloud cover
    s2_collection = ee.ImageCollection("COPERNICUS/S2_SR_HARMONIZED") \
        .filterBounds(tadla_merc) \
        .filterDate(start_date, end_date) \
        .filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE', 10))

    # Harmonize bands: Select and rename critical bands (B4=Red, B8=NIR).
    # SCL is carried along but not used by the NDVI expression below.
    s2_harmonized = s2_collection.map(
        lambda img: img.select(
            ['B4', 'B8', 'SCL'],  # Keep only Red, NIR, and Scene Classification
            ['red', 'nir', 'scl']  # Rename to avoid conflicts
        ).cast({'red': 'float', 'nir': 'float'})  # Force consistent data types
    )

    # Compute per-pixel median composite across the selected scenes
    composite = s2_harmonized.median()

    # Calculate NDVI = (NIR - Red) / (NIR + Red)
    ndvi = composite.expression(
        '(nir - red) / (nir + red)',
        {'nir': composite.select('nir'), 'red': composite.select('red')}
    ).rename('NDVI')

    return ndvi.reproject(crs='EPSG:26191', scale=10)

In [None]:
def export_ndvi(year):
    """Start a Drive export of the NDVI composite for ``year`` and print its task ID."""
    export_settings = {
        "image": get_annual_composite(year),
        "description": f'Sentinel2_Tadla_NDVI_{year}',
        "folder": 'Tadla_Project',
        "scale": 10,
        "region": tadla_merc,
        "crs": 'EPSG:26191',
        "maxPixels": 1e13,
        "fileFormat": 'GeoTIFF',
    }
    export_task = ee.batch.Export.image.toDrive(**export_settings)
    export_task.start()
    print(f"Exported {year}: Task ID {export_task.id}")

# Queue one export per year, 2017 through 2023
for year in range(2017, 2024):
    export_ndvi(year)

# Step 6 – Data Harmonization

#### 1. Resample Coarse Data (Soil/DEM) to 10m Resolution

    Goal: Resample low-resolution datasets (e.g., SoilGrids at 250m) to match NDVI’s 10m grid.
    Why: To align all datasets spatially for ML training.

In [None]:
# Extract metadata from NDVI 2017
with rasterio.open(ndvi_path) as ndvi_ref:
    ndvi_transform = ndvi_ref.transform  # 10m resolution transform
    ndvi_crs = ndvi_ref.crs             # CRS (EPSG:26191)
    ndvi_width = ndvi_ref.width         # Number of columns
    ndvi_height = ndvi_ref.height       # Number of rows

print(f"Reference CRS: {ndvi_crs}")
print(f"Reference resolution: {ndvi_transform[0]}m")

In [None]:
from rasterio.warp import reproject, Resampling
import numpy as np

# Paths (update with your actual paths)
soil_clay_10m = soil_processed_dir / "tadla_clay_10m.tif"

# Resample clay to 10m using NDVI’s grid
with rasterio.open(soil_processed_path) as src:
    # Initialize destination array with NDVI dimensions
    dst_data = np.zeros((ndvi_height, ndvi_width), dtype=np.float32)
    
    reproject(
        source=rasterio.band(src, 1),
        destination=dst_data,
        src_transform=src.transform,
        dst_transform=ndvi_transform,
        src_crs=src.crs,
        dst_crs=ndvi_crs,
        resampling=Resampling.bilinear  # Use "nearest" for categorical data
    )
    
    # Save resampled clay
    with rasterio.open(
        soil_clay_10m,
        "w",
        driver="GTiff",
        height=ndvi_height,
        width=ndvi_width,
        count=1,
        dtype=np.float32,
        crs=ndvi_crs,
        transform=ndvi_transform,
        nodata=src.nodata
    ) as dst:
        dst.write(dst_data, 1)

In [None]:
# Verify the resampled clay file: pixel size and CRS
with rasterio.open(soil_clay_10m) as clay_resampled:
    print(f"Resampled clay resolution: {clay_resampled.res}")  # Should be (10.0, 10.0)
    print(f"CRS: {clay_resampled.crs}")  # Should match NDVI (EPSG:26191)

In [None]:
def resample_soil_layer(raw_path, processed_path, ndvi_transform, ndvi_crs, ndvi_height, ndvi_width):
    """Resample band 1 of a raster onto the NDVI grid and save it as a GeoTIFF.

    Args:
        raw_path: Path of the input raster (a ``pathlib.Path``; its ``.name``
            is used in the progress message).
        processed_path: Path of the single-band float32 GeoTIFF to write.
        ndvi_transform: Affine transform of the target grid.
        ndvi_crs: CRS of the target grid.
        ndvi_height: Number of rows in the target grid.
        ndvi_width: Number of columns in the target grid.
    """
    # Everything about the output except nodata is fixed by the target grid.
    target_profile = dict(
        driver="GTiff",
        height=ndvi_height,
        width=ndvi_width,
        count=1,
        dtype=np.float32,
        crs=ndvi_crs,
        transform=ndvi_transform,
    )

    with rasterio.open(raw_path) as source_ds:
        resampled = np.zeros((ndvi_height, ndvi_width), dtype=np.float32)

        # Bilinear warp of the first band into the pre-allocated target array.
        reproject(
            source=rasterio.band(source_ds, 1),
            destination=resampled,
            src_transform=source_ds.transform,
            dst_transform=ndvi_transform,
            src_crs=source_ds.crs,
            dst_crs=ndvi_crs,
            resampling=Resampling.bilinear,
        )

        # The output inherits the input raster's nodata value.
        with rasterio.open(
            processed_path, "w", nodata=source_ds.nodata, **target_profile
        ) as sink:
            sink.write(resampled, 1)

    print(f"Resampled {raw_path.name} → {processed_path}")

# Example usage:
# Map each soil parameter key to the file name of its pre-processed raster.
# Clay is not listed here; it is resampled by its own cell.
soil_params = {
    "silt": "tadla_silt_processed.tif",
    "sand": "tadla_sand_processed.tif",
    "ocd": "tadla_ocd_processed.tif",  # Organic carbon density
    # NOTE(review): in SoilGrids naming, wv0010 is presumably volumetric water
    # content at 10 kPa rather than at saturation — confirm against the source.
    "wv0010": "tadla_wv0010_processed.tif"   # Water content at saturation
}

# Resample every listed layer onto the NDVI grid, writing tadla_<param>_10m.tif
# next to the pre-processed input.
for param, filename in soil_params.items():
    pre_processed_path = soil_processed_dir / filename
    processed_path_10m = soil_processed_dir / f"tadla_{param}_10m.tif"
    resample_soil_layer(pre_processed_path, processed_path_10m, ndvi_transform, ndvi_crs, ndvi_height, ndvi_width)

In [None]:
# Report the pixel size and CRS of each resampled soil raster for a visual check.
for param in ["silt", "sand", "ocd", "wv0010"]:
    with rasterio.open(soil_processed_dir / f"tadla_{param}_10m.tif") as src:
        print(f"{param} resolution: {src.res}, CRS: {src.crs}")

Resample DEM (12.5m → 10m)

In [None]:
# Paths (update with your actual paths)
dem_raw = raw_data_dir / config["paths"]["dem_raw"]
# NOTE(review): later cells read `dem_processed_path`, which is not defined in
# this cell — confirm it points at the same file as `dem_processed`.
dem_processed = processed_data_dir / config["paths"]["dem_processed"]

# Warp band 1 of the raw DEM onto the NDVI grid and write it as a single-band
# float32 GeoTIFF that carries over the source nodata value.
with rasterio.open(dem_raw) as src:
    dst_data = np.zeros((ndvi_height, ndvi_width), dtype=np.float32)
    reproject(
        source=rasterio.band(src, 1),
        destination=dst_data,
        src_transform=src.transform,
        dst_transform=ndvi_transform,
        src_crs=src.crs,
        dst_crs=ndvi_crs,
        # Bilinear is what this call actually uses; switch to Resampling.cubic
        # here if a smoother elevation surface is wanted.
        resampling=Resampling.bilinear  # Use cubic for elevation
    )
    with rasterio.open(
        dem_processed,
        "w",
        driver="GTiff",
        height=ndvi_height,
        width=ndvi_width,
        count=1,
        dtype=np.float32,
        crs=ndvi_crs,
        transform=ndvi_transform,
        nodata=src.nodata
    ) as dst:
        dst.write(dst_data, 1)

In [None]:
# Report the pixel size and CRS of the resampled DEM; the trailing comment
# states the expected values, nothing is asserted here.
with rasterio.open(dem_processed) as src:
    print(f"DEM resolution: {src.res}, CRS: {src.crs}")  # Should be (10.0, 10.0), EPSG:26191

Resample Slope & Aspect Layers

In [None]:
# Resample slope (from 12.5m to 10m)  
# Warps band 1 of the slope raster onto the NDVI grid, into an in-memory array.
with rasterio.open(slope_path) as src:  
    dst_data = np.zeros((ndvi_height, ndvi_width), dtype=np.float32)  
    reproject(  
        source=rasterio.band(src, 1),  
        destination=dst_data,  
        src_transform=src.transform,  
        dst_transform=ndvi_transform, 
        src_crs=src.crs,
        dst_crs=ndvi_crs, 
        resampling=Resampling.bilinear  
    )  
    # Save to slope_10m.tif  
    # NOTE(review): nothing is written in this cell — dst_data only exists in
    # memory, and aspect is not handled here at all.

#### 2. Align All Rasters to NDVI Grid
    
    Goal: Ensure all datasets (soil, DEM, weather) are spatially aligned with the NDVI grid.
    Why: Even a sub-pixel misalignment pairs each pixel's features with values from a neighbouring location, which silently degrades ML training.

1. Align Weather Data (CHIRPS Rainfall and ERA5 Evaporation)

In [None]:
import rioxarray as rxr

# Load NDVI reference (e.g., NDVI_2017.tif)
# Re-derive the target grid through rioxarray. This rebinds ndvi_transform and
# ndvi_crs, and adds ndvi_shape as a (rows, cols) tuple for rio.reproject.
ndvi_ref = rxr.open_rasterio(ndvi_path)
ndvi_transform = ndvi_ref.rio.transform()  # Get the exact transform
ndvi_crs = ndvi_ref.rio.crs
ndvi_shape = (ndvi_ref.rio.height, ndvi_ref.rio.width)

In [None]:
import rioxarray

def Reproject_CHIRPS(year):
    """Reproject one annual CHIRPS raster onto the NDVI grid.

    Reads ``CHIRPS_Annual/CHIRPS_<year>.tif`` under ``weather_processed_dir``
    and writes ``CHIRPS_Annual/CHIRPS_<year>_reproj.tif`` beside it, using the
    notebook-level ``ndvi_crs``, ``ndvi_shape`` and ``ndvi_transform``.

    Args:
        year: Year that selects the input and output file names.
    """
    source_file = weather_processed_dir /  f"CHIRPS_Annual/CHIRPS_{year}.tif"
    out_file = weather_processed_dir / f"CHIRPS_Annual/CHIRPS_{year}_reproj.tif"

    # Target grid definition shared by every weather layer.
    target_grid = {
        "dst_crs": ndvi_crs,
        "shape": ndvi_shape,
        "transform": ndvi_transform,
        "resampling": Resampling.bilinear,
    }

    # Open, warp and write in one pass.
    rainfall = rioxarray.open_rasterio(source_file)
    rainfall.rio.reproject(**target_grid).rio.to_raster(out_file)

    print(f"Reprojected file saved to {out_file}")

In [None]:
# Example usage
# Reproject the annual CHIRPS raster of every year from 2017 through 2023.
for year in range(2017, 2024):
    Reproject_CHIRPS(year)

In [None]:
import rioxarray
from rasterio.enums import Resampling

def Reproject_ERA5(year):
    """Reproject one annual ERA5 raster onto the NDVI grid.

    Reads ``ERA5_Annual/ERA5_<year>.tif`` under ``weather_processed_dir`` and
    writes ``ERA5_Annual/ERA5_<year>_reproj.tif`` beside it, using the
    notebook-level ``ndvi_crs``, ``ndvi_shape`` and ``ndvi_transform``.

    Args:
        year: Year that selects the input and output file names.
    """
    era5_file = weather_processed_dir / f"ERA5_Annual/ERA5_{year}.tif"
    out_file = weather_processed_dir / f"ERA5_Annual/ERA5_{year}_reproj.tif"

    # Target grid definition shared by every weather layer.
    target_grid = {
        "dst_crs": ndvi_crs,
        "shape": ndvi_shape,
        "transform": ndvi_transform,
        "resampling": Resampling.bilinear,
    }

    # Open, warp and write in one pass.
    evaporation = rioxarray.open_rasterio(era5_file)
    evaporation.rio.reproject(**target_grid).rio.to_raster(out_file)

    print(f"Reprojected ERA5 file saved to {out_file}")

In [None]:
# Reproject the annual ERA5 raster of every year from 2017 through 2023.
for year in range(2017, 2024):
    Reproject_ERA5(year)

In [None]:
# Spot check on 2017 only: the reprojected CHIRPS and ERA5 rasters must share
# the same affine transform and CRS.
with rasterio.open(weather_processed_dir / "CHIRPS_Annual/CHIRPS_2017_reproj.tif") as chirps, rasterio.open(weather_processed_dir / "ERA5_Annual/ERA5_2017_reproj.tif") as era5:
    assert chirps.transform == era5.transform, "Transform mismatch!"
    assert chirps.crs == era5.crs, "CRS mismatch!"

#### 3. Validate Full Spatial Alignment

    Goal: Ensure all rasters (soil, DEM, NDVI, weather) share the same origin, resolution, and CRS.

3.1. Check CRS 

    Goal: Confirm all datasets use EPSG:26191 (Merchich).

In [None]:
import rasterio

# Every raster that has to share the NDVI grid, in a fixed order:
# reference NDVI, NDVI 2018–2023, clay, DEM, CHIRPS 2017–2023, ERA5 2017–2023.
layers = [
    ndvi_path,                    # Reference NDVI (EPSG:26191)
    *[land_use_raw_dir / f"Sentinel2_Tadla_NDVI_{yr}.tif" for yr in range(2018, 2024)],
    soil_processed_dir / "tadla_clay_10m.tif",
    dem_processed_path,
    *[weather_processed_dir / f"CHIRPS_Annual/CHIRPS_{yr}_reproj.tif" for yr in range(2017, 2024)],
    *[weather_processed_dir / f"ERA5_Annual/ERA5_{yr}_reproj.tif" for yr in range(2017, 2024)],
]

# Print the CRS of each file so that any outlier is visible at a glance.
for layer in layers:
    with rasterio.open(layer) as src:
        print(f"{layer.name}: CRS = {src.crs}")  # Should all print "EPSG:26191"

3.2. Validate Transform & Resolution

    Goal: Ensure all rasters have the same origin (transform[2], transform[5]) and 10m resolution.

In [None]:
# Use the affine transform of the raster at ndvi_path as the reference.
with rasterio.open(ndvi_path) as ndvi_ref:
    ref_transform = ndvi_ref.transform
    print(f"Reference transform: {ref_transform}")

# Every layer must match the reference transform exactly and have 10 x 10 pixels.
# The first failing layer stops the loop with an AssertionError.
for layer in layers:
    with rasterio.open(layer) as src:
        print(f"{layer.name}:")
        print(f"  Transform: {src.transform}")
        print(f"  Resolution: {src.res}")
        assert src.transform == ref_transform, "Transform mismatch!"
        assert src.res == (10.0, 10.0), "Resolution mismatch!"

3.3. Check Spatial Extents

    Goal: Ensure all layers cover the exact same geographic area as NDVI.

In [None]:
# Use the bounding box of the raster at ndvi_path as the reference extent.
with rasterio.open(ndvi_path) as ndvi_ref:
    ref_bounds = ndvi_ref.bounds

# Each layer's bounds must match the reference within an absolute tolerance of
# 1e-2 (in CRS units) on every edge.
for layer in layers:
    with rasterio.open(layer) as src:
        layer_bounds = src.bounds
        print(f"{layer.name}:")
        print(f"  Bounds: {layer_bounds}")
        assert np.allclose(layer_bounds, ref_bounds, atol=1e-2), "Bounds mismatch!"

3.4. Validate Pixel Grid Alignment

    Goal: Ensure a point (e.g., (340000, 240000)) falls in the same pixel across all layers.

In [None]:
sample_x, sample_y = 340000.0, 240000.0  # A point in Tadla

def get_pixel_value(path, x, y):
    """Return the band-1 value of the pixel that contains the point (x, y).

    Args:
        path: Raster file to sample.
        x: X coordinate of the point, in the raster's CRS.
        y: Y coordinate of the point, in the raster's CRS.

    Returns:
        The value of the pixel containing the point, as a NumPy scalar.

    Raises:
        IndexError: If the point lies outside the raster extent.
    """
    from rasterio.windows import Window  # local import keeps the cell self-contained

    with rasterio.open(path) as src:
        row, col = src.index(x, y)
        # A negative row/col would silently wrap around to the opposite edge
        # when used as a NumPy index, so reject out-of-extent points explicitly.
        if not (0 <= row < src.height and 0 <= col < src.width):
            raise IndexError(
                f"Point ({x}, {y}) falls outside {path} (row={row}, col={col})"
            )
        # Read just the one pixel instead of loading the whole band.
        return src.read(1, window=Window(col, row, 1, 1))[0, 0]

print(f"NDVI value at ({sample_x}, {sample_y}): {get_pixel_value(ndvi_path, sample_x, sample_y)}")
print(f"Clay value: {get_pixel_value(soil_clay_10m, sample_x, sample_y)}")
print(f"DEM value: {get_pixel_value(dem_processed_path, sample_x, sample_y)}")

3.5. Handle NoData Consistency

    Goal: Ensure all rasters use the same NoData value (e.g., -9999.0).

In [None]:
# Single NoData marker written into every raster's metadata and pixels below.
STANDARD_NODATA = -9999.0

In [None]:
# Standardize NoData on every validated layer: set the metadata value, then
# overwrite NaN/Inf pixels with that value in each band of the file.
for layer in layers:
    with rasterio.open(layer, "r+") as src:  # Open in read/write mode
        # Update metadata
        src.nodata = STANDARD_NODATA
        # Work band by band so multi-band files are fully cleaned without
        # holding every band in memory at once.
        for band in src.indexes:
            data = src.read(band)
            data[~np.isfinite(data)] = STANDARD_NODATA  # Handle NaN/Inf
            src.write(data, band)
    print(f"Updated {layer.name}: NoData = {STANDARD_NODATA}")

# Process all years (2017–2023)
for year in range(2017, 2024):
    chirps_path = weather_processed_dir / f"CHIRPS_Annual/CHIRPS_{year}_reproj.tif"
    # ERA5 files live under ERA5_Annual/. Looking for them under CHIRPS_Annual/
    # made the exists() guard skip every ERA5 raster without any message.
    era5_path = weather_processed_dir / f"ERA5_Annual/ERA5_{year}_reproj.tif"

    for weather_path in (chirps_path, era5_path):
        if not weather_path.exists():
            print(f"Skipped missing file: {weather_path}")
            continue
        with rasterio.open(weather_path, "r+") as src:
            src.nodata = STANDARD_NODATA
            for band in src.indexes:
                data = src.read(band)
                data[~np.isfinite(data)] = STANDARD_NODATA
                src.write(data, band)


In [None]:

# Update all NDVI layers (2017–2023)
# A dedicated loop variable is used so that the notebook-level `ndvi_path`
# (the reference raster every other cell opens) is not overwritten and left
# pointing at the 2023 file.
for year in range(2017, 2024):
    yearly_ndvi_path = land_use_raw_dir / f"Sentinel2_Tadla_NDVI_{year}.tif"
    if yearly_ndvi_path.exists():
        with rasterio.open(yearly_ndvi_path, "r+") as src:
            src.nodata = STANDARD_NODATA
            data = src.read(1)
            data[~np.isfinite(data)] = STANDARD_NODATA  # Handles NaN/Inf
            src.write(data, 1)

3.6. Visual Inspection

    Goal: Plot layers over each other to confirm alignment.

In [None]:
import matplotlib.pyplot as plt

# Two side-by-side panels: NDVI on the left, clay on the right.
fig, ax = plt.subplots(1, 2, figsize=(15, 7))

# Plot NDVI
with rasterio.open(ndvi_path) as src:
    ndvi = src.read(1)
    ax[0].imshow(ndvi, cmap="YlGn", vmin=0, vmax=1)
    ax[0].set_title("NDVI")

# Plot clay %
# The clay raster is drawn on its own axes (ax[1]), not on top of the NDVI
# panel, so alpha only fades the image; alignment is judged by comparing panels.
with rasterio.open(soil_clay_10m) as src:
    clay = src.read(1)
    ax[1].imshow(clay, cmap="Reds", alpha=0.7)  # Overlay with transparency
    ax[1].set_title("Clay % Overlay")

plt.show()

Verify Pixel Replacement:
    
    Check that NaN/Inf values in the original data were replaced with -9999.0:

In [None]:
# List the distinct band-1 values of the 2017 CHIRPS raster to eyeball whether
# the NoData replacement took effect.
with rasterio.open(weather_processed_dir / "CHIRPS_Annual/CHIRPS_2017_reproj.tif") as src:  
    data = src.read(1)  
    print(f"Unique values in CHIRPS: {np.unique(data)}")  
    # Should include -9999.0 but no NaN  

In [None]:
# Display band 1 of the DEM with the colour scale clamped to 0–1000, so values
# below 0 (including the -9999 NoData marker) are drawn in the lowest colour.
with rasterio.open(dem_processed_path) as src:  
    plt.imshow(src.read(1), cmap="viridis", vmin=0, vmax=1000)  
    plt.colorbar(label="Elevation (m)")  
    plt.title("DEM with NoData=-9999")  

In [None]:
# Print the NoData value recorded in the clay raster's metadata.
with rasterio.open(soil_processed_dir / "tadla_clay_10m.tif") as src:  
    print(src.nodata)  # Should still be -9999.0  

In [None]:
def check_nodata(path):
    """Check that one raster uses the standardized NoData value.

    Prints the NoData value stored in the file's metadata, then verifies on
    band 1 that no NaN is left and that at least one pixel carries
    ``STANDARD_NODATA``. A missing file is reported and skipped.

    Args:
        path: ``pathlib.Path`` of the raster to check.

    Raises:
        AssertionError: If band 1 still contains NaN, or contains no pixel
            equal to ``STANDARD_NODATA``.
    """
    if not path.exists():
        print(f"{path.name}: file not found, skipped")
        return

    with rasterio.open(path) as src:
        print(f"{path.name}: NoData = {src.nodata}")
        data = src.read(1)

    # NaN never compares equal to anything, so `nanmax(data) != np.nan` was
    # always true; test for NaN explicitly instead.
    assert not np.isnan(data).any(), "NaNs still present!"
    assert np.any(data == STANDARD_NODATA), "NoData not replaced!"


# Example usage: check every yearly file (2017–2023) of each dataset.
for year in range(2017, 2024):
    check_nodata(weather_processed_dir / "CHIRPS_Annual/" / f"CHIRPS_{year}_reproj.tif")  # CHIRPS_2017_reproj.tif, etc.
    check_nodata(weather_processed_dir / "ERA5_Annual/" / f"ERA5_{year}_reproj.tif")    # ERA5_2017_reproj.tif, etc.
    check_nodata(land_use_raw_dir / f"Sentinel2_Tadla_NDVI_{year}.tif")          # Sentinel2_Tadla_NDVI_2017.tif, etc.

In [None]:
# Soil layers
# For each 10 m soil raster: set the NoData metadata to STANDARD_NODATA, then
# overwrite non-finite band-1 pixels with that same value, in place.
soil_layers = ["clay", "silt", "sand", "ocd", "wv0010"]
for param in soil_layers:
    path = soil_processed_dir / f"tadla_{param}_10m.tif"
    with rasterio.open(path, "r+") as src:  # "r+" edits the existing file
        src.nodata = STANDARD_NODATA
        data = src.read(1)
        data[~np.isfinite(data)] = STANDARD_NODATA
        src.write(data, 1)

# DEM
# Same treatment for the processed DEM.
with rasterio.open(dem_processed_path, "r+") as src:
    src.nodata = STANDARD_NODATA
    data = src.read(1)
    data[~np.isfinite(data)] = STANDARD_NODATA
    src.write(data, 1)

In [None]:
import rasterio

# Time-invariant rasters (five soil layers plus the DEM) checked as a group by
# the cells that follow.
static_layers = [
    soil_processed_dir / "tadla_clay_10m.tif",
    soil_processed_dir / "tadla_silt_10m.tif",
    soil_processed_dir / "tadla_sand_10m.tif",
    soil_processed_dir / "tadla_ocd_10m.tif",
    soil_processed_dir / "tadla_wv0010_10m.tif",
    dem_processed_path
]

# Print the NoData value stored in each file's metadata.
for layer in static_layers:
    with rasterio.open(layer) as src:
        print(f"{layer.name}: NoData = {src.nodata}")  # Should all be -9999.0

In [None]:
import numpy as np

# Print the band-1 value range of each static layer, ignoring NoData pixels.
for layer in static_layers:
    with rasterio.open(layer) as src:
        data = src.read(1)
        # If the file has no NoData value (src.nodata is None) nothing is
        # excluded by this comparison.
        valid_data = data[data != src.nodata]  # Exclude NoData
        print(f"{layer.name}:")
        print(f"  Min = {np.min(valid_data):.2f}, Max = {np.max(valid_data):.2f}")

# Expected ranges:
# NOTE(review): these ranges and units are expectations to confirm against the
# data source; "WCS" below refers to the layer stored as wv0010 in this notebook.
# - Clay/Silt/Sand: 0–100% (sum ≈ 100% per pixel)
# - OCD: 0–50 g/kg (organic carbon)
# - WCS: 0–1 cm³/cm³ (water content)
# - DEM: Elevation in meters (e.g., 0–1000m)

In [None]:
# Print the NDVI and clay transforms in GDAL geotransform order so they can be
# compared side by side.
with rasterio.open(ndvi_path) as ndvi_ref:
    ref_transform = ndvi_ref.transform
    print(f"NDVI transform: {ref_transform.to_gdal()}")  # GDAL-style tuple

with rasterio.open(soil_processed_dir / "tadla_clay_10m.tif") as src:
    clay_transform = src.transform.to_gdal()
    print(f"Clay transform: {clay_transform}")

In [None]:
import rasterio
from rasterio.warp import reproject, Resampling
import numpy as np
from pathlib import Path

# Load NDVI’s grid definition (transform, CRS, resolution)
with rasterio.open(ndvi_path) as ndvi_ref:
    ref_profile = ndvi_ref.profile  # Includes transform, crs, etc.

# List of soil layers to reproject
soil_layers = [
    "tadla_clay_10m.tif",
    "tadla_silt_10m.tif",
    "tadla_sand_10m.tif",
    "tadla_ocd_10m.tif",
    "tadla_wv0010_10m.tif"
]

for layer in soil_layers:
    input_path = soil_processed_dir / layer
    output_path = input_path.parent / f"aligned_{layer}"

    # Reproject to NDVI’s grid
    with rasterio.open(input_path) as src:
        # Take the grid (size, transform, CRS) from NDVI, but keep the soil
        # layer's own band count, dtype and nodata instead of NDVI's.
        out_profile = dict(
            ref_profile,
            count=1,
            dtype=src.dtypes[0],
            nodata=src.nodata,
        )
        with rasterio.open(output_path, "w", **out_profile) as dst:
            reproject(
                # Passing the band (not a bare array) together with explicit
                # nodata keeps NoData pixels out of the bilinear interpolation
                # and fills uncovered output pixels with the same marker.
                source=rasterio.band(src, 1),
                destination=rasterio.band(dst, 1),
                src_transform=src.transform,
                src_crs=src.crs,
                src_nodata=src.nodata,
                dst_transform=ref_profile["transform"],
                dst_crs=ref_profile["crs"],
                dst_nodata=src.nodata,
                resampling=Resampling.bilinear  # Use "nearest" for categorical data
            )
    print(f"Aligned {layer} → {output_path}")

    # Replace old file with aligned version (both files are closed by now)
    output_path.replace(input_path)  # Overwrite original file

In [None]:
# Reference transform taken from the raster at ndvi_path.
with rasterio.open(ndvi_path) as ndvi_ref:
    ref_transform = ndvi_ref.transform

# Here `soil_layers` holds file names relative to soil_processed_dir.
for layer in soil_layers:
    input_path = soil_processed_dir / layer
    with rasterio.open(input_path) as src:
        # Check transform alignment with tolerance
        # Compares the affine coefficients element-wise rather than by exact
        # equality, so tiny floating-point differences do not fail the check.
        assert np.allclose(
            list(src.transform), 
            list(ref_transform), 
            atol=1e-6  # Allow 0.001mm tolerance
        ), f"{layer} transform mismatch!"
        print(f"{layer} is aligned ✅")

In [None]:
# Reference transform and CRS taken from the raster at ndvi_path.
with rasterio.open(ndvi_path) as ndvi_ref:
    ref_transform = ndvi_ref.transform
    ref_crs = ndvi_ref.crs

# Strict check: each static layer must match the reference transform and CRS
# exactly. Resolution is printed for information only; it is not asserted.
for layer in static_layers:
    with rasterio.open(layer) as src:
        print(f"{layer.name}:")
        print(f"  CRS: {src.crs}")  # Should be EPSG:26191
        print(f"  Transform: {src.transform}")  # Should match NDVI
        print(f"  Resolution: {src.res}")  # Should be (10.0, 10.0)
        assert src.transform == ref_transform, "Transform mismatch!"
        assert src.crs == ref_crs, "CRS mismatch!"

In [None]:
import matplotlib.pyplot as plt

# Two side-by-side panels: NDVI on the left, the aligned clay raster on the right.
fig, ax = plt.subplots(1, 2, figsize=(15, 7))

# Plot NDVI
with rasterio.open(ndvi_path) as src:
    ndvi = src.read(1)
    ax[0].imshow(ndvi, cmap="YlGn", vmin=0, vmax=1)
    ax[0].set_title("NDVI")

# Plot clay %
# Drawn on its own axes (ax[1]); alpha only fades the image, nothing is
# composited on top of the NDVI panel.
with rasterio.open(soil_processed_dir / "tadla_clay_10m.tif") as src:
    clay = src.read(1)
    ax[1].imshow(clay, cmap="Reds", alpha=0.7)  # Overlay with transparency
    ax[1].set_title("Clay % Overlay")

plt.show()

#### 4. Temporal Aggregation (All Years)

    Goal: Convert daily CHIRPS rainfall and ERA5 evaporation into monthly aggregates for all years.

4.1. Batch Process CHIRPS (Daily → Monthly Sum)

In [None]:
# Folder for the monthly CHIRPS files; created with any missing parents.
# NOTE(review): unlike the other paths in this notebook, this one is not
# prefixed with project_root, so a relative config value resolves against the
# current working directory — confirm that is intended.
chirps_output_dir = Path(config["paths"]["harmonized_data"]) / "CHIRPS_monthly"
chirps_output_dir.mkdir(exist_ok=True, parents=True)

In [None]:
# Define paths
chirps_dir = weather_processed_dir / "CHIRPS_Annual/" # Input: Daily CHIRPS files

for year in range(2017, 2024):
    # Load daily CHIRPS for the year (already reprojected to 10m grid)
    chirps_daily = rxr.open_rasterio(chirps_dir / f"CHIRPS_{year}_reproj.tif", chunks={"band": -1, "x": 1000, "y": 1000})

    # Treat each band as one consecutive day starting on 1 January, then turn
    # the band axis into a proper "time" axis.
    dates = pd.date_range(start=f"{year}-01-01", periods=chirps_daily.sizes["band"], freq="D")
    chirps_daily = chirps_daily.assign_coords(band=dates).rename({"band": "time"})

    # Resample to monthly sum (total rainfall per month)
    chirps_monthly = chirps_daily.resample(time="1ME").sum(skipna=False)  # skipna=False to retain NoData

    # Save as NetCDF (one file per year) into chirps_output_dir, the folder
    # the stacking step reads from; `output_dir` is not defined at this point.
    chirps_monthly.rio.to_raster(chirps_output_dir / f"CHIRPS_monthly_{year}.nc")
    print(f"Saved CHIRPS monthly for {year}")



4.2. Batch Process ERA5 Evaporation (Daily → Monthly Total)

In [None]:
# Folder for the monthly ERA5 files; created with any missing parents.
# NOTE(review): not prefixed with project_root, so a relative config value
# resolves against the current working directory — confirm that is intended.
era5_output_dir = Path(config["paths"]["harmonized_data"]) / "ERA5_monthly"
era5_output_dir.mkdir(exist_ok=True, parents=True)

In [None]:
era5_dir = weather_processed_dir / "ERA5_Annual/"  # Input: Daily ERA5 files

for year in range(2017, 2024):
    # Load daily ERA5 evaporation (already reprojected)
    era5_daily = rxr.open_rasterio(era5_dir / f"ERA5_{year}_reproj.tif", chunks={"band": -1, "x": 1000, "y": 1000})

    # Assign time coordinates: each band is treated as one consecutive day
    # starting on 1 January of the year.
    dates = pd.date_range(start=f"{year}-01-01", periods=era5_daily.sizes["band"], freq="D")
    era5_daily = era5_daily.assign_coords(band=dates).rename({"band": "time"})

    # Resample to monthly total (sum of daily evaporation)
    era5_monthly = era5_daily.resample(time="1ME").sum(skipna=False)

    # Save into era5_output_dir, the folder the stacking step reads from;
    # `output_dir` is not defined at this point.
    era5_monthly.rio.to_raster(era5_output_dir / f"ERA5_monthly_{year}.nc")
    print(f"Saved ERA5 monthly for {year}")

4.3. Stack All Years into a Single Dataset

    Goal: Combine all monthly data (2017–2023) into a single NetCDF file for ML.

In [None]:
# Start a Dask client for parallel processing (adjust based on your RAM)
client = Client(n_workers=4, memory_limit='8GB')  # Example: 4 workers, 8GB each

# Load CHIRPS and ERA5 with Dask chunks
# The seven yearly files are concatenated along "time" in the order listed.
# NOTE(review): the rename assumes each file exposes its data as a variable
# called "Band1" — confirm against one of the written .nc files.
chirps_ds = xr.open_mfdataset(
    [chirps_output_dir / f"CHIRPS_monthly_{year}.nc" for year in range(2017, 2024)],
    chunks={"time": 12, "x": 1000, "y": 1000},  # 1 year per chunk
    combine="nested",
    concat_dim="time",
    parallel=True
).rename({"Band1": "rainfall"})

era5_ds = xr.open_mfdataset(
    [era5_output_dir / f"ERA5_monthly_{year}.nc" for year in range(2017, 2024)],
    chunks={"time": 12, "x": 1000, "y": 1000},
    combine="nested",
    concat_dim="time",
    parallel=True
).rename({"Band1": "evaporation"})


In [None]:
# Load all static soil layers (clay, silt, sand, OCD, WSC)
# Each raster is opened lazily, named after its parameter, and merged into one
# Dataset with one variable per soil property.
# NOTE(review): this reads tadla_<var>_processed.tif, whereas the layers
# resampled to the NDVI grid were written as tadla_<var>_10m.tif — confirm
# which set is meant to be merged with the 10 m data.
soil_vars = ["clay", "silt", "sand", "ocd", "wv0010"]
soil_ds = xr.merge([
    rxr.open_rasterio(soil_processed_dir / f"tadla_{var}_processed.tif", chunks={"x": 1000, "y": 1000}).rename(var)
    for var in soil_vars
])

In [None]:
# Load static layers with Dask
# Opened lazily in 1000 x 1000 chunks and named "dem" so it merges as a variable.
dem = rxr.open_rasterio(dem_processed_path, chunks={"x": 1000, "y": 1000}).rename("dem")

In [None]:
import xarray as xr

# Load NDVI (annual composites, reprojected to monthly)
# Opens the seven yearly GeoTIFFs and concatenates them along a new "time"
# dimension, in the order of ndvi_files. No monthly resampling happens here.
ndvi_files = [land_use_raw_dir / f"Sentinel2_Tadla_NDVI_{year}.tif" for year in range(2017, 2024)]
ndvi_ds = xr.open_mfdataset(
    ndvi_files,
    combine="nested",
    concat_dim="time"
)

# Rechunk the dataset after loading
ndvi_ds = ndvi_ds.chunk({"time": 1, "x": 1000, "y": 1000})

# Rename the variable if it exists
# NOTE(review): `in` on a Dataset presumably matches coordinates as well as
# data variables, so this may rename the "band" coordinate rather than the
# data itself — check the structure printed by the cells below.
if 'band' in ndvi_ds:
    ndvi_ds = ndvi_ds.rename({"band": "NDVI"})
else:
    print("Variable 'band' not found in the dataset. Please check the dataset structure.")

# Continue with your processing

In [None]:
# Check variables in 1 NDVI file
# Debug aid: list the data variable names xarray exposes for a single GeoTIFF.
test_ds = xr.open_dataset(land_use_raw_dir / "Sentinel2_Tadla_NDVI_2017.tif")
print(test_ds.data_vars)  # Likely shows "band", not "Band1"

In [None]:
# Display the dataset summary (dimensions, coordinates, variables) in the notebook.
ndvi_ds

In [None]:
# 1. Assign time coordinates
# Seven year-start timestamps (2017-01-01 … 2023-01-01), one per NDVI file.
times = pd.date_range(start="2017-01-01", periods=7, freq="YS")  # Annual start dates
ndvi_ds = ndvi_ds.assign_coords(time=times)

# 2. Drop the redundant "NDVI" coordinate (created during concatenation)
# NOTE(review): steps 2–4 assume the earlier rename produced an "NDVI"
# coordinate/dimension of size 1 and that the data variable is "band_data" —
# confirm against the dataset summary printed above.
ndvi_ds = ndvi_ds.drop_vars("NDVI")

# 3. Rename "band_data" to "NDVI"
ndvi_ds = ndvi_ds.rename({"band_data": "NDVI"})

# 4. Squeeze out the singleton "NDVI" dimension (size=1)
ndvi_ds = ndvi_ds.squeeze("NDVI")  # Now dimensions are (time, y, x)

# 5. Resample to monthly
# Upsamples the annual series to month-end steps and fills the new steps by
# linear interpolation between the yearly values.
ndvi_monthly = ndvi_ds.NDVI.resample(time="1ME").interpolate("linear")

In [None]:
# Merge all variables
# Combines the static layers (soil, DEM) with the time-varying ones (annual
# NDVI, monthly rainfall and evaporation) into one Dataset.
final_ds = xr.merge([soil_ds, dem, ndvi_ds, chirps_ds, era5_ds])

In [None]:
# Replace the annual NDVI variable with the monthly interpolated series.
final_ds["NDVI"] = ndvi_monthly

In [None]:
print(final_ds.data_vars)
# Expected output:
# NOTE(review): the list below names 'wcs' and 'slope', but the merge above
# only supplies 'wv0010' and no slope layer — update once the output is confirmed.
# ['clay', 'silt', 'sand', 'ocd', 'wcs', 'dem', 'slope', 'NDVI', 'rainfall', 'evaporation']

In [None]:
# Drop CRS variable
# Removes the "lambert_conformal_conic" variable from the merged dataset;
# raises if no variable of that name is present.
final_ds = final_ds.drop_vars("lambert_conformal_conic")

In [None]:
# Convert data types
# Intent: store every variable as float32.
# NOTE(review): confirm that Dataset.astype accepts a per-variable mapping in
# the installed xarray version; it may expect a single dtype.
final_ds = final_ds.astype({
    "clay": "float32",
    "dem": "float32",
    "rainfall": "float32",
    "evaporation": "float32",
    "wv0010": "float32", 
    "sand": "float32", 
    "ocd": "float32", 
    "silt": "float32", 
    "NDVI": "float32"
})

In [None]:
print(final_ds.data_vars)
# Expected output:
# NOTE(review): the list below names 'wcs' and 'slope', but the merge above
# only supplies 'wv0010' and no slope layer — update once the output is confirmed.
# ['clay', 'silt', 'sand', 'ocd', 'wcs', 'dem', 'slope', 'NDVI', 'rainfall', 'evaporation']

In [None]:
# Configure Dask for 16GB RAM
cluster = LocalCluster(
    n_workers=4,          # 4 workers
    threads_per_worker=3, # 4*3=12 threads
    memory_limit="3GB"    # 4*3GB=12GB total
)
# Rebinds `client` to this new cluster.
# NOTE(review): a Client was already created when the monthly files were
# opened; confirm it should be closed before starting a second cluster.
client = Client(cluster)

# Assuming soil_ds, dem, ndvi_ds, chirps_ds, era5_ds are already xarray objects

# For static datasets (e.g., soil, dem), you might not need to chunk over time:
soil_ds = soil_ds.chunk({'x': 256, 'y': 256})
dem = dem.chunk({'x': 256, 'y': 256})

# For time-varying datasets, you can chunk along the time dimension:
ndvi_ds = ndvi_ds.chunk({'time': 12, 'x': 256, 'y': 256})
chirps_ds = chirps_ds.chunk({'time': 12, 'x': 256, 'y': 256})
era5_ds = era5_ds.chunk({'time': 12, 'x': 256, 'y': 256})


In [None]:
years = pd.date_range("2017-01-01", "2023-12-31", freq="YS").year.tolist()

def process_year(year):
    """Merge all layers for one year and write them to a Zarr store.

    Static layers (soil, DEM) are merged as-is; the time-varying layers (NDVI,
    rainfall, evaporation) are restricted to the requested year. The result is
    written to ``tadla_ml_dataset_<year>.zarr`` under ``harmonized_dir``,
    overwriting any existing store.

    Args:
        year: Calendar year to extract.

    Returns:
        True if the store was written, False if any step raised. Failures are
        reported with a printed message rather than propagated, so one bad
        year does not abort the whole batch.
    """
    try:
        # Select data using lazy indexing
        yearly_ds = xr.merge([
            # The soil layers have no time dimension, so they are passed
            # through unchanged; selecting on "time" would raise for every year.
            soil_ds,
            dem,
            ndvi_ds.sel(time=str(year)),
            chirps_ds.sel(time=str(year)),
            era5_ds.sel(time=str(year))
        ], join='exact')

        # Optimized Zarr write
        output_path = harmonized_dir / f"tadla_ml_dataset_{year}.zarr"
        yearly_ds.to_zarr(output_path, consolidated=True, mode="w")

        return True
    except Exception as e:
        print(f"Error processing {year}: {str(e)}")
        return False

# Submit all years in parallel
futures = [client.submit(process_year, year) for year in years]

# Monitor progress
from dask.diagnostics import ProgressBar
with ProgressBar():
    results = [f.result() for f in futures]

# Report against the number of years actually submitted.
print(f"Successfully processed {sum(results)}/{len(years)} years")

Slope/Aspect Resolution Check

In [None]:
import rasterio
from rasterio.warp import reproject, Resampling
import numpy as np
from pathlib import Path

def resample_to_10m(input_path, output_path, reference_dem_path):
    """Warp band 1 of a raster onto the grid of a reference DEM.

    Args:
        input_path: Raster to resample.
        output_path: GeoTIFF to write; it uses the reference raster's profile
            with the input's dtype and nodata value.
        reference_dem_path: Raster whose size, transform and CRS define the
            target grid.
    """
    # Target grid definition, read once from the reference raster.
    with rasterio.open(reference_dem_path) as grid:
        out_profile = grid.profile
        grid_transform, grid_crs = grid.transform, grid.crs
        grid_shape = (grid.height, grid.width)

    with rasterio.open(input_path) as src:
        # The destination buffer keeps the dtype of the input's first band.
        resampled = np.zeros(grid_shape, dtype=src.dtypes[0])

        # Bilinear resampling, suited to continuous data.
        reproject(
            source=rasterio.band(src, 1),
            destination=resampled,
            src_transform=src.transform,
            src_crs=src.crs,
            dst_transform=grid_transform,
            dst_crs=grid_crs,
            resampling=Resampling.bilinear,
        )

        # Only driver, dtype and nodata deviate from the reference profile.
        out_profile.update(
            driver="GTiff",
            dtype=resampled.dtype,
            nodata=src.nodata,
        )

        with rasterio.open(output_path, "w", **out_profile) as sink:
            sink.write(resampled, 1)

# Paths (update as needed)

# Raster whose grid the slope and aspect layers are resampled onto.
reference_dem = topography_processed_dir / "tadla_dem_10m.tif"

layers_to_resample = [
    topography_processed_dir / "tadla_slope.tif",
    topography_processed_dir / "tadla_aspect.tif"
]

# Process each layer
for input_path in layers_to_resample:
    # Write to a "temp_<name>" sibling first so the input is never read and
    # written at the same time.
    temp_output = input_path.with_stem(f"temp_{input_path.stem}")
    
    # Resample to 10m
    resample_to_10m(input_path, temp_output, reference_dem)
    
    # Replace original file
    temp_output.replace(input_path)
    print(f"Resampled {input_path.name} → 10m resolution")

print("✅ Slope/aspect resampling complete!")

In [None]:
def validate_slope_aspect():
    """Compare the slope and aspect rasters against the processed DEM.

    Checks that ``tadla_slope.tif`` and ``tadla_aspect.tif`` (under
    ``topography_processed_dir``) have the same pixel size and CRS as the
    raster at ``dem_processed_path``, and prints either every mismatch found
    or a success message.
    """
    # The processed DEM supplies the expected resolution and CRS.
    with rasterio.open(dem_processed_path) as dem_ref:
        target_res, target_crs = dem_ref.res, dem_ref.crs

    problems = []
    for layer in ("tadla_slope.tif", "tadla_aspect.tif"):
        with rasterio.open(topography_processed_dir / layer) as raster:
            if raster.res != target_res:
                problems.append(f"{layer}: Resolution {raster.res} ≠ {target_res}")
            if raster.crs != target_crs:
                problems.append(f"{layer}: CRS {raster.crs} ≠ {target_crs}")

    if not problems:
        print("✅ Slope/aspect have correct resolution (10m) and CRS!")
        return

    print("❌ Slope/aspect issues:")
    for problem in problems:
        print(f" - {problem}")

validate_slope_aspect()

Soil Texture Validation

In [None]:
# Print the band-1 value range of each raw soil texture raster.
# NOTE(review): min/max are taken over every pixel, so a NoData marker (if the
# files use one) shows up in the reported range.
for param in ["clay", "silt", "sand"]:
    path = soil_raw_dir / f"tadla_{param}.tif"
    with rasterio.open(path) as src:
        data = src.read(1)
        print(f"{param} raw values:")
        print(f"  Min: {data.min()}, Max: {data.max()}")  # Should be 0–1000

In [None]:
import rasterio
import numpy as np

def process_soil_layer(input_path, output_path):
    """Convert a soil texture raster to percent and write it as float32.

    Band 1 is divided by 10 and clipped to the 0–100 range. Pixels that are
    NoData (or NaN/Inf) in the input are excluded from the conversion and
    written as -9999, which is also declared as the output's NoData value.

    Args:
        input_path: Raster to convert.
        output_path: GeoTIFF to write, on the same grid as the input.
    """
    out_nodata = -9999

    with rasterio.open(input_path) as src:
        data = src.read(1).astype(np.float32)
        src_nodata = src.nodata
        profile = src.profile.copy()

    # Remember which pixels carry no measurement before any arithmetic.
    # Scaling and clipping them would otherwise turn them into "0 %".
    missing = ~np.isfinite(data)
    if src_nodata is not None:
        missing |= data == src_nodata

    # Step 1: Scale SoilGrids values (÷10)
    data = data / 10  # Converts 408 → 40.8%

    # Step 2: Clip to 0-100% (handle outliers)
    data = np.clip(data, 0, 100)

    # Step 3: Restore the NoData marker on pixels that had no measurement
    data[missing] = out_nodata

    # Step 4: Save on the input's grid as float32
    profile.update(
        dtype=rasterio.float32,
        nodata=out_nodata,
        driver="GTiff"
    )
    with rasterio.open(output_path, "w", **profile) as dst:
        dst.write(data, 1)

# Example usage
#for param in ["clay", "silt", "sand"]:
for param in ["clay"]:
    input_path = soil_processed_dir / f"tadla_{param}_10m.tif"
    output_path = soil_processed_dir / f"tadla_{param}_10m_p.tif"
    process_soil_layer(input_path, output_path)

In [None]:
# Print the band-1 value range and pixel size of each 10 m soil texture raster.
# NOTE(review): min/max are taken over every pixel, so NoData pixels (e.g.
# -9999) will show up as the minimum unless they are masked first.
for param in ["clay", "silt", "sand"]:
    path = soil_processed_dir / f"tadla_{param}_10m.tif"
    with rasterio.open(path) as src:
        data = src.read(1)
        print(f"{param} processed:")
        print(f"  Min: {data.min():.1f}%, Max: {data.max():.1f}%")  # Should be 0–100%
        print(src.res) 

In [None]:
def validate_soil_texture():
    """Check that clay + silt + sand sums to 80–100 % and save the masked sum.

    Reads band 1 of the three 10 m texture rasters under
    ``soil_processed_dir`` and adds them wherever all three have data. It
    writes ``total_texture_masked.tif``, holding the sum where it lies within
    80–100 and -9999 (the declared NoData) everywhere else. It then prints how
    many pixels sum to less than 80, relative to the pixels whose sum does not
    exceed 100.
    """
    layers = ["clay", "silt", "sand"]
    soil_data = {}
    nodata = -9999
    grid_crs = None
    grid_transform = None

    # Load data; keep the grid definition while the file is still open.
    for param in layers:
        path = soil_processed_dir / f"tadla_{param}_10m.tif"
        with rasterio.open(path) as src:
            soil_data[param] = src.read(1).astype(np.float32)
            grid_crs = src.crs
            grid_transform = src.transform

    # Mask valid pixels (all layers have data)
    mask = np.ones(soil_data["clay"].shape, dtype=bool)
    for band in soil_data.values():
        mask &= np.isfinite(band) & (band != nodata)

    # Calculate total texture on the valid pixels only (1-D array)
    sums = (
        soil_data["clay"][mask] +
        soil_data["silt"][mask] +
        soil_data["sand"][mask]
    )

    # Take the statistics from the real sums, before anything is masked out.
    # Counting after masking would report every masked pixel as "<80".
    considered = int(np.count_nonzero(sums <= 100))  # Exclude >100%
    below = int(np.count_nonzero(sums < 80))         # Allow 80-100%

    # Build the output raster: in-range sums, NoData everywhere else.
    total = np.full(mask.shape, nodata, dtype=np.float32)
    total[mask] = np.where((sums >= 80) & (sums <= 100), sums, nodata)

    profile = {
        'driver': 'GTiff',
        'dtype': 'float32',
        'nodata': nodata,
        'width': total.shape[1],
        'height': total.shape[0],
        'count': 1,
        'crs': grid_crs,
        'transform': grid_transform
    }
    # Save masked total texture layer
    with rasterio.open(soil_processed_dir / "total_texture_masked.tif", "w", **profile) as dst:
        dst.write(total, 1)

    # `below > 0` implies `considered > 0`, so the ratio is always defined.
    if below > 0:
        print(f"⚠️ {below} pixels ({below/considered:.1%}) <80%")
    else:
        print("✅ Valid soil texture sums (80-100%)")

validate_soil_texture()

Temporal Data Validation

In [None]:
import rasterio
from rasterio.warp import reproject, Resampling

def reproject_annual(year, dataset="CHIRPS"):
    """Stack the twelve monthly rasters of one year into a 12-band GeoTIFF.

    Reads ``<dataset>_<year>_<MM>.tif`` for months 01–12 from
    ``<dataset>_monthly/`` and writes ``<dataset>_<year>_reproj.tif`` into
    ``<dataset>_Annual/`` (both under ``weather_processed_dir``). Each month is
    warped onto the grid of the January file and stored as band 1..12.

    Args:
        year: Year to process.
        dataset: Dataset prefix used in folder and file names.
    """
    input_dir = Path(weather_processed_dir / f"{dataset}_monthly/")
    output_dir = Path(weather_processed_dir / f"{dataset}_Annual/")
    output_dir.mkdir(exist_ok=True)

    # One input file per calendar month, January first.
    monthly_files = []
    for month in range(1, 13):
        monthly_files.append(input_dir / f"{dataset}_{year}_{month:02d}.tif")

    # The January file defines the output grid; only the band count changes.
    with rasterio.open(monthly_files[0]) as january:
        meta = {**january.meta, "count": 12}

    stacked_path = output_dir / f"{dataset}_{year}_reproj.tif"
    with rasterio.open(stacked_path, "w", **meta) as dst:
        band_idx = 0
        for monthly_file in monthly_files:
            band_idx += 1  # bands are 1-based: January -> 1 … December -> 12
            with rasterio.open(monthly_file) as month_src:
                reproject(
                    source=rasterio.band(month_src, 1),
                    destination=rasterio.band(dst, band_idx),
                    src_transform=month_src.transform,
                    src_crs=month_src.crs,
                    dst_transform=meta["transform"],
                    dst_crs=meta["crs"],
                    resampling=Resampling.bilinear,
                )

    print(f"Processed {dataset}_{year}_reproj.tif with 12 bands")

# Reprocess all years
# Rebuilds the 12-band annual file of both datasets for 2017 through 2023.
for year in range(2017, 2024):
    reproject_annual(year, "CHIRPS")
    reproject_annual(year, "ERA5")

In [None]:
def check_temporal_bands():
    """Report annual CHIRPS/ERA5 rasters that are missing or lack 12 bands.

    Looks at ``<dataset>_Annual/<dataset>_<year>_reproj.tif`` for every year
    from 2017 to 2023 and prints either the list of problems or a success
    line. Nothing is returned.
    """
    problems = []

    for name in ("CHIRPS", "ERA5"):
        for yr in range(2017, 2024):
            tif = Path(weather_processed_dir / f"{name}_Annual/{name}_{yr}_reproj.tif")

            if tif.exists():
                with rasterio.open(tif) as raster:
                    n_bands = raster.count
                if n_bands != 12:
                    problems.append(f"{tif.name}: {n_bands} bands (expected 12)")
            else:
                problems.append(f"Missing: {tif}")

    if not problems:
        print("✅ All temporal files have 12 bands")
        return

    print("❌ Temporal band issues:")
    for problem in problems:
        print(f" - {problem}")

check_temporal_bands()

In [None]:
import xarray as xr
from pathlib import Path

def check_time_coverage():
    """Check that every harmonized monthly NetCDF holds exactly 12 time steps.

    For CHIRPS and ERA5, years 2017-2023, the file
    ``<dataset>_monthly/<dataset>_monthly_<year>.nc`` under ``harmonized_dir``
    is opened, its coordinates are printed for debugging, and the length of
    its ``time`` coordinate is compared with 12. Missing files, files without
    a ``time`` coordinate and files with a different length are collected and
    printed at the end. Nothing is returned.
    """
    datasets = ["CHIRPS", "ERA5"]
    issues = []

    for dataset in datasets:
        for year in range(2017, 2024):
            path = Path(harmonized_dir / f"{dataset}_monthly/{dataset}_monthly_{year}.nc")
            if not path.exists():
                issues.append(f"Missing: {path}")
                continue

            # The context manager closes the file handle; without it each of
            # the opened datasets stays open for the life of the kernel.
            with xr.open_dataset(path) as ds:
                print(f"Coordinates in {path.name}: {ds.coords}")  # Print dataset coordinates
                if 'time' not in ds.coords:
                    issues.append(f"{path.name}: 'time' coordinate not found")
                elif len(ds.time) != 12:
                    issues.append(f"{path.name}: {len(ds.time)} months (expected 12)")

    if issues:
        print("❌ Temporal coverage issues:")
        for issue in issues:
            print(f" - {issue}")
    else:
        print("✅ All harmonized datasets have 12 months")

check_time_coverage()

In [None]:
import rasterio

def print_layer_info(path):
    """Print the CRS, resolution, bounds and shape of the raster at *path*.

    Args:
        path: Location of a raster that rasterio can open.
    """
    with rasterio.open(path) as raster:
        report = [
            f"Layer: {path}",
            f"  CRS: {raster.crs}",
            f"  Resolution: {raster.res}",
            f"  Bounds: {raster.bounds}",
            f"  Shape: {raster.shape}\n",  # trailing newline separates layers
        ]
    for line in report:
        print(line)

# First the reference layer (NDVI/DEM), then the problematic layer
for layer_path in (
    land_use_raw_dir / "Sentinel2_Tadla_NDVI_2023.tif",
    weather_processed_dir / "CHIRPS_Annual/CHIRPS_2023_reproj.tif",
):
    print_layer_info(layer_path)

In [None]:
from rasterio.warp import reproject, Resampling
import numpy as np

# Reference: NDVI layer. Its profile defines the target grid (CRS, transform,
# width, height) that the weather rasters are aligned to.
with rasterio.open(land_use_raw_dir / "Sentinel2_Tadla_NDVI_2023.tif") as ref:
    target_profile = ref.profile.copy()
    target_profile.update(count=12)  # For 12-band files


def reproject_annual(year, dataset="CHIRPS"):
    """Resample one annual 12-band weather raster onto the NDVI reference grid.

    Reads ``<dataset>_Annual/<dataset>_<year>_reproj.tif``, warps each of its
    twelve bands (bilinear) to the grid stored in ``target_profile``, writes
    the result to a temporary ``<dataset>_<year>_aligned.tif`` next to it and
    finally moves that file over the original ``_reproj.tif``.

    Args:
        year: Year used in the file name.
        dataset: Prefix of the folder and file name (e.g. "CHIRPS" or "ERA5").
    """
    src_path = Path(weather_processed_dir / f"{dataset}_Annual/{dataset}_{year}_reproj.tif")
    dst_path = src_path.with_name(f"{dataset}_{year}_aligned.tif")

    with rasterio.open(src_path) as src:
        # The grid comes from the NDVI reference, but dtype and NoData must
        # describe the weather data itself; otherwise the output would be
        # tagged with the NDVI raster's dtype/NoData value.
        out_profile = target_profile.copy()
        out_profile.update(count=12, dtype=src.dtypes[0], nodata=src.nodata)

        with rasterio.open(dst_path, "w", **out_profile) as dst:
            # Warp band by band directly into the output file so the full
            # 12-band cube on the reference grid never has to sit in memory.
            for band in range(1, 13):
                reproject(
                    source=rasterio.band(src, band),
                    destination=rasterio.band(dst, band),
                    src_transform=src.transform,
                    src_crs=src.crs,
                    dst_transform=out_profile["transform"],
                    dst_crs=out_profile["crs"],
                    resampling=Resampling.bilinear,
                )

    # Replace original file (both datasets are closed at this point)
    dst_path.replace(src_path)


# Reproject all years
for year in range(2017, 2024):
    reproject_annual(year, "CHIRPS")
    reproject_annual(year, "ERA5")

##### Spatial Alignment Check

    Ensure all layers (soil, DEM, NDVI, weather) share the exact same grid (resolution, transform, CRS).

In [None]:
import rasterio

def validate_grid_alignment():
    """Compare the grid of several layers with the NDVI reference raster.

    A layer counts as aligned when its affine transform, its (height, width)
    shape and its CRS all equal those of the reference. Every layer is
    checked; for each misaligned one a line naming the differing properties
    is printed. The success line is printed only when all layers match.
    Nothing is returned.
    """
    reference = land_use_raw_dir / "Sentinel2_Tadla_NDVI_2023.tif"  # Example NDVI

    with rasterio.open(reference) as ref:
        ref_grid = {
            "transform": ref.transform,
            "shape": (ref.height, ref.width),
            "crs": ref.crs,
        }

    layers = [
        soil_processed_dir / "tadla_clay_10m.tif",
        topography_processed_dir / "tadla_dem_10m.tif",
        weather_processed_dir / "CHIRPS_Annual/CHIRPS_2023_reproj.tif",
    ]

    all_aligned = True
    for layer in layers:
        with rasterio.open(layer) as src:
            layer_grid = {
                "transform": src.transform,
                "shape": src.shape,
                "crs": src.crs,
            }

        mismatched = [
            key for key, ref_value in ref_grid.items()
            if layer_grid[key] != ref_value
        ]
        if mismatched:
            # Keep going so that one bad layer does not hide the others.
            all_aligned = False
            print(f"❌ Misaligned: {layer} ({', '.join(mismatched)} differ from reference)")

    if all_aligned:
        print("✅ All layers aligned!")

validate_grid_alignment()

##### Data Normalization Check

    Verify that input features are normalized (e.g., 0–1 or z-scores) to avoid model bias.

In [None]:
def check_normalization():
    """Check that band 1 of selected layers stays inside an expected range.

    For each layer the first band is read, NoData and NaN pixels are removed
    and the remaining minimum/maximum are compared with the expected
    ``(low, high)`` range. One line per layer is printed. Only the first band
    is inspected, also for multi-band files. Nothing is returned.
    """
    datasets = {
        "NDVI": (land_use_raw_dir / "Sentinel2_Tadla_NDVI_2023.tif", (-1, 1)),
        "Clay": (soil_processed_dir / "tadla_clay_10m.tif", (0, 100)),
        "CHIRPS": (weather_processed_dir / "CHIRPS_Annual/CHIRPS_2023_reproj.tif", (0, 500)),  # mm/month
    }

    for name, (path, expected_range) in datasets.items():
        with rasterio.open(path) as src:
            data = src.read(1)
            nodata = src.nodata

        # NaN never compares equal to anything, so `data != nodata` cannot
        # remove NaN pixels; left in, they turn min()/max() into NaN and the
        # range test below would then report a false "within range".
        valid_data = data[~np.isnan(data)]
        if nodata is not None and not np.isnan(nodata):
            valid_data = valid_data[valid_data != nodata]

        if valid_data.size == 0:
            # min()/max() raise on empty arrays; report instead of crashing.
            print(f"⚠️ {name}: No valid pixels found in band 1")
            continue

        min_val, max_val = valid_data.min(), valid_data.max()

        if (min_val < expected_range[0]) or (max_val > expected_range[1]):
            print(f"⚠️ {name}: Values ({min_val:.2f}-{max_val:.2f}) outside expected range {expected_range}")
        else:
            print(f"✅ {name}: Within {expected_range}")

check_normalization()

##### Stack Monthly NDVI into a 12-Band File

In [None]:
from rasterio.merge import merge
import numpy as np

# The twelve single-band monthly NDVI rasters of 2017, in calendar order
# (the month number in the file name is not zero-padded).
monthly_ndvi_files = [
    Path(land_use_raw_dir / f"NDVI_2017_{month}.tif") for month in range(1, 13)
]

# Pull band 1 out of every monthly raster
bands = []
for monthly_path in monthly_ndvi_files:
    with rasterio.open(monthly_path) as monthly_src:
        bands.append(monthly_src.read(1))

# Combine into one (bands, height, width) array
stacked = np.stack(bands)

# Output metadata: that of the first monthly file, with the band count set to 12
with rasterio.open(monthly_ndvi_files[0]) as first_src:
    meta = {**first_src.meta, "count": 12}

# Write the stacked NDVI as a single 12-band file
stacked_ndvi_path = land_use_processed_dir / "Sentinel2_Tadla_NDVI_2017.tif"
with rasterio.open(stacked_ndvi_path, "w", **meta) as stacked_dst:
    stacked_dst.write(stacked)

##### Temporal Consistency Check

    Ensure time stamps align across datasets (e.g., NDVI and rainfall for the same month).

In [None]:
def check_temporal_alignment():
    """Print the number of bands in the 2017 NDVI and CHIRPS rasters.

    Both files are opened only to read their band count and are closed again
    before the function returns. Nothing is returned.
    """
    ndvi_path = land_use_processed_dir / "Sentinel2_Tadla_NDVI_2017.tif"
    rain_path = weather_processed_dir / "CHIRPS_Annual/CHIRPS_2017_reproj.tif"

    # `with` releases both file handles; a bare rasterio.open() would leave
    # them open until garbage collection.
    with rasterio.open(ndvi_path) as ndvi_raster, rasterio.open(rain_path) as rain_raster:
        # Print the number of bands in each raster
        print(f"NDVI raster bands: {ndvi_raster.count}")
        print(f"Rainfall raster bands: {rain_raster.count}")

check_temporal_alignment()

In [None]:
def check_temporal_alignment():
    """Check that the May 2017 NDVI and rainfall bands have the same shape.

    Prints a success or mismatch line. Nothing is returned.

    Raises:
        IndexError: If either raster has fewer than five bands.
    """
    # rasterio band indexes are 1-based, so band 5 is the fifth (May) band.
    may_band = 5

    ndvi_path = land_use_processed_dir / "Sentinel2_Tadla_NDVI_2017.tif"
    rain_path = weather_processed_dir / "CHIRPS_Annual/CHIRPS_2017_reproj.tif"

    with rasterio.open(ndvi_path) as ndvi_raster, rasterio.open(rain_path) as rain_raster:
        for label, raster in (("NDVI", ndvi_raster), ("rainfall", rain_raster)):
            if raster.count < may_band:
                raise IndexError(
                    f"{label} raster has {raster.count} bands; band {may_band} is missing"
                )
        # Every band shares the dataset's (height, width), so the shapes can be
        # compared from the metadata without reading two full bands.
        shapes_match = ndvi_raster.shape == rain_raster.shape

    if not shapes_match:
        print("❌ May 2017 NDVI and rainfall shapes mismatch")
    else:
        print("✅ May 2017 temporal alignment OK")

check_temporal_alignment()

##### NoData 

    Ensure NoData values (-9999) are consistently masked across all layers.

In [None]:
def check_nodata_presence():
    """Print the percentage of NoData pixels in band 1 of selected layers.

    The NoData value is taken from each file's metadata. A NaN NoData value
    is matched with ``np.isnan``; a file without a NoData value is reported
    as such instead of silently showing 0%. Nothing is returned.
    """
    layers = [
        soil_processed_dir / "tadla_clay_10m.tif",
        weather_processed_dir / "CHIRPS_Annual/CHIRPS_2023_reproj.tif",
    ]

    for path in layers:
        with rasterio.open(path) as src:
            data = src.read(1)
            nodata = src.nodata

        if nodata is None:
            print(f"{Path(path).name}: 0.00% NoData (no NoData value defined)")
            continue

        # NaN != NaN, so an equality test can never find NaN NoData pixels.
        if np.isnan(nodata):
            nodata_mask = np.isnan(data)
        else:
            nodata_mask = data == nodata

        nodata_pct = nodata_mask.mean() * 100
        print(f"{Path(path).name}: {nodata_pct:.2f}% NoData")

check_nodata_presence()

##### Spatial-Temporal Leakage

    Ensure your train/test split does not mix data from the same location or time.

In [None]:
def validate_train_test_split():
    """Sanity-check the example train/test split and print a confirmation.

    Asserts that the hard-coded training and test years do not overlap, then
    builds a boolean mask marking the upper half of the 2023 NDVI raster's
    rows as an example of a spatial split. The mask is local to this function
    and is not returned or stored.

    Raises:
        AssertionError: If a year appears in both lists.
    """
    train_years = [2017, 2018, 2019, 2020, 2021]  # Example
    test_years = [2022, 2023]

    # Temporal leakage guard: the two year lists must be disjoint
    assert set(train_years).isdisjoint(test_years), "Leakage: Overlapping years!"

    # Optional spatial split example: upper half of the rows for training
    # (the northern half when the raster is north-up)
    with rasterio.open(land_use_processed_dir / "Sentinel2_Tadla_NDVI_2023.tif") as ndvi:
        n_rows, n_cols = ndvi.height, ndvi.width
        train_mask = np.zeros((n_rows, n_cols), dtype=bool)
        train_mask[: n_rows // 2, :] = True

    print("✅ Train/test split validated")

validate_train_test_split()

##### File Corruption Check

    Ensure all files are readable and not corrupted.

In [None]:
def check_file_integrity():
    """Try to open each listed raster and report the ones that fail.

    Only opening is tested; pixel data is not read. A line is printed for
    every file that cannot be opened, and the success line is printed only
    when all files opened. Nothing is returned.
    """
    all_files = [
        Path(land_use_processed_dir / "Sentinel2_Tadla_NDVI_2023.tif"),
        Path(weather_processed_dir / "CHIRPS_Annual/CHIRPS_2023_reproj.tif"),
    ]

    unreadable = []
    for path in all_files:
        try:
            with rasterio.open(path):
                pass
        # Catch open/read failures only; a bare `except:` would also swallow
        # KeyboardInterrupt and programming errors.
        except (rasterio.errors.RasterioError, OSError):
            unreadable.append(path)
            print(f"❌ Corrupted file: {path}")

    # Report success only if nothing failed above.
    if not unreadable:
        print("✅ All files are readable")

check_file_integrity()

##### Hardware Readiness

    Ensure your system can handle the dataset size.

In [None]:
def check_hardware(bands=12, height=9068, width=18904, bytes_per_value=4):
    """Compare available RAM with the in-memory size of a raster stack.

    The dataset size is ``bands * height * width * bytes_per_value``. A
    warning is printed when the available RAM is less than twice that size.
    Nothing is returned.

    Args:
        bands: Number of bands in the stack.
        height: Number of rows.
        width: Number of columns.
        bytes_per_value: Bytes per pixel value (4 for float32).
    """
    import psutil

    gib = 1024 ** 3
    free_ram = psutil.virtual_memory().available / gib  # GB
    dataset_size = bands * height * width * bytes_per_value / gib
    print(f"Free RAM: {free_ram:.1f} GB | Dataset size: {dataset_size:.1f} GB")

    # The code requires 2x the raw array size before it considers RAM sufficient.
    if free_ram < dataset_size * 2:
        print("⚠️ Insufficient RAM – use batch loading or cloud compute")

check_hardware()