# Phase 2: Soil Moisture Prediction 

### Step 1: Prepare Training Data

We include all available features to maximize model accuracy:
        
##### 1. Static Features (do not change over time):
            
- Soil properties: clay, silt, sand, ocd (organic carbon density), wv0010 (volumetric water content at 10 kPa suction).
- Topography: DEM, slope, aspect.
##### 2. Dynamic Features (vary monthly):

- Weather: CHIRPS (rainfall), ERA5 (evaporation).
- Vegetation: NDVI.

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import rasterio
import tensorflow as tf
from tensorflow.keras import layers, models
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error
from pathlib import Path
import yaml
import xarray as xr
import rioxarray as rxr
import geopandas as gpd
import earthpy.plot as ep
from sklearn.preprocessing import MinMaxScaler
import json

# Fix the global random seeds so repeated runs are reproducible:
# NumPy's global RNG first, then TensorFlow's global seed.
np.random.seed(42)
tf.random.set_seed(42)

In [None]:
# -----------------------------------------------------------------------------
# Load config.yml
# -----------------------------------------------------------------------------

# The project root is taken to be the parent of the current working directory
# (i.e. the notebook is expected to run from a sub-folder of the project).
# Adjust if the notebook lives at a different folder depth.
current_dir = Path.cwd()
project_root = current_dir.parent

# Decode explicitly as UTF-8 so the result does not depend on the platform's
# default encoding; safe_load only builds plain YAML types.
with open(project_root / "config.yml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# -----------------------------------------------------------------------------
# Construct paths
# -----------------------------------------------------------------------------
# `project_root` is already a Path, so the config string can be joined directly.
processed_dir = project_root / config['paths']['processed_data']

# Soil layers: clay, sand, silt, ocd, wv0010
soil_dir = processed_dir / "GIS/Soil"

# Topography rasters
dem_path = processed_dir / "GIS/Topography/tadla_dem_10m.tif"
slope_path = processed_dir / "GIS/Topography/tadla_slope.tif"
aspect_path = processed_dir / "GIS/Topography/tadla_aspect.tif"

# CHIRPS rainfall, 2017-2023: one file per year, 12 monthly bands
rainfall_dir = processed_dir / "Weather/CHIRPS_Annual"

# ERA5, 2017-2023: one file per year, 12 monthly bands.
# NOTE(review): the directory name says evapotranspiration while the notebook
# later labels this series "evaporation" — confirm which ERA5 variable it is.
evapotranspiration_dir = processed_dir / "Weather/ERA5_Annual"

# Study-area boundary vectors
boundaries_dir = processed_dir / "GIS/Study_Area_Boundary"

# Sentinel-2 NDVI, 2017-2023: one file per year, 12 monthly bands
ndvi_dir = processed_dir / "GIS/Land_Use"

In [None]:
# Static (time-invariant) inputs. Each GeoTIFF is opened lazily with rioxarray
# and squeezed, which drops size-1 dimensions (the single band, assuming
# single-band files) and leaves `band` behind as a scalar coordinate.
# Soil layers share one naming pattern, so they are built from their names.
static_data = {
    name: rxr.open_rasterio(soil_dir / f"tadla_{name}_10m.tif").squeeze()
    for name in ("clay", "silt", "sand", "ocd", "wv0010")
}
# Topography layers, appended after the soil layers (insertion order kept).
static_data.update(
    dem=rxr.open_rasterio(dem_path).squeeze(),
    slope=rxr.open_rasterio(slope_path).squeeze(),
    aspect=rxr.open_rasterio(aspect_path).squeeze(),
)

In [None]:
# Start a local dask.distributed cluster for the out-of-core raster work below.
import dask

# globally set these defaults
# (called as a plain function rather than a `with` block, so the values stay
# in effect for the rest of the session; presumably they must be set before
# Client(...) below so the new workers pick them up.)
# NOTE(review): per dask's documentation these are fractions of each worker's
# memory limit (target: start spilling, spill: spill based on process memory,
# pause: stop running new tasks) — confirm for the installed dask version.
dask.config.set({
    "distributed.worker.memory.target": 0.6,
    "distributed.worker.memory.spill":  0.7,
    "distributed.worker.memory.pause":  0.8,
})

from dask.distributed import Client
# NOTE(review): `memory_limit` applies per worker; with the default number of
# workers the combined limit may exceed the machine's RAM — confirm.
client = Client(memory_limit="15GB")

In [None]:
import os
# GDAL_CACHEMAX is in megabytes.  2048 MB = 2 GB.
# This assignment affects only the current (notebook) process.
os.environ['GDAL_CACHEMAX'] = '2048'


def _set_gdal_cachemax(megabytes='2048'):
    """Set GDAL_CACHEMAX (in MB) in the environment of the calling process."""
    os.environ['GDAL_CACHEMAX'] = megabytes


# The dask workers were started by `Client(...)` before this cell. With the
# default local cluster they are separate processes whose environment was
# copied at start-up, so the assignment above does not reach them. Apply the
# same setting on every worker explicitly.
# NOTE(review): GDAL sizes its block cache when it is first used; a process
# that has already read rasters may keep its previous cache size.
client.run(_set_gdal_cachemax)

In [None]:

def load_annual_data(
    variable: str,
    years=range(2017, 2024),
    per_file_chunks=None,
    final_chunks=None,
) -> xr.DataArray:
    """Load a monthly variable stored as one 12-band GeoTIFF per year.

    Each yearly file is opened lazily with dask, its 12 bands are relabelled
    with month-start timestamps (band -> ``time``), and all years are
    concatenated along ``time``.

    Args:
        variable: ``"NDVI"``, ``"CHIRPS"`` or ``"ERA5"``; selects the source
            directory and filename pattern.
        years: Years to load, in order. Defaults to 2017-2023.
        per_file_chunks: Dask chunks used when opening each file. Defaults to
            ``{'band': 1, 'x': 1024, 'y': 1024}``.
        final_chunks: Dask chunks applied to the concatenated result.
            Defaults to ``{'time': 1, 'x': 1024, 'y': 1024}``.

    Returns:
        Dask-backed DataArray with a ``time`` dimension (12 steps per year)
        plus the raster's spatial dims; the file's nodata value is masked to
        NaN (``masked=True``).

    Raises:
        ValueError: unknown ``variable``, empty ``years``, or a file that
            does not have exactly 12 bands.
    """
    # Build fresh dicts per call instead of sharing mutable default arguments.
    if per_file_chunks is None:
        per_file_chunks = {'band': 1, 'x': 1024, 'y': 1024}
    if final_chunks is None:
        final_chunks = {'time': 1, 'x': 1024, 'y': 1024}

    # variable -> (directory, filename template); resolved once, before any I/O.
    sources = {
        "NDVI": (ndvi_dir, "Sentinel2_Tadla_NDVI_{year}.tif"),
        "CHIRPS": (rainfall_dir, "CHIRPS_{year}_reproj.tif"),
        "ERA5": (evapotranspiration_dir, "ERA5_{year}_reproj.tif"),
    }
    if variable not in sources:
        raise ValueError(f"Unknown variable: {variable!r}")
    directory, template = sources[variable]

    da_list = []
    for year in years:
        path = directory / template.format(year=year)

        # 1-band x spatial chunks, lazy disk reads; nodata -> NaN.
        da = rxr.open_rasterio(path, masked=True, chunks=per_file_chunks)

        # Every band must be one calendar month. Report a clear error here
        # rather than a coordinate-size conflict from assign_coords.
        n_bands = da.sizes["band"]
        if n_bands != 12:
            raise ValueError(f"{path} has {n_bands} bands; expected 12 (Jan-Dec)")

        # Relabel the 12 bands as month-start timestamps and call the axis 'time'.
        times = pd.date_range(start=f"{year}-01-01", periods=12, freq="MS")
        da = da.assign_coords(band=times).rename({"band": "time"})
        da_list.append(da)

    if not da_list:
        raise ValueError("`years` must contain at least one year")

    # Concatenate all years and rechunk along the new time axis.
    combined = xr.concat(da_list, dim="time")
    return combined.chunk(final_chunks)

In [None]:
# Dynamic (monthly) inputs: each call stacks the 2017-2023 yearly files into
# one series with a `time` axis of 7 x 12 = 84 month-start timestamps.
ndvi = load_annual_data(variable="NDVI")
chirps = load_annual_data(variable="CHIRPS")
era5 = load_annual_data(variable="ERA5")

In [None]:
# Give the weather series descriptive array names. Renaming the arrays that
# were already built avoids re-opening and re-stacking every yearly file.
chirps = chirps.rename("precipitation")
era5 = era5.rename("evaporation")

In [None]:
# Inspect each series. These are xarray DataArrays, not Datasets, so the repr
# shows a single array (dims time: 84, y, x) with its coordinates; there is
# no "Data variables" section.

# For NDVI (array name not set explicitly in this notebook)
print(ndvi)

# For CHIRPS: array name "precipitation"
print(chirps)

# For ERA5: array name "evaporation" (its long_name attribute is set separately)
print(era5)

In [None]:
# Human-readable labels stored in each array's `long_name` attribute.
# NOTE(review): ERA5's array is named "evaporation" but labelled
# "total_evaporation" here — keep the two consistent if either changes.
ndvi.attrs.update(long_name="NDVI")
chirps.attrs.update(long_name="precipitation")
era5.attrs.update(long_name="total_evaporation")

In [None]:
# All three series are later combined pixel-by-pixel, so they must share one
# CRS (expected: EPSG:26191). Print each and flag any disagreement.
print(ndvi.rio.crs)  # Should output "EPSG:26191"
print("CHIRPS CRS:", chirps.rio.crs)
print("ERA5 CRS:", era5.rio.crs)
if not (ndvi.rio.crs == chirps.rio.crs == era5.rio.crs):
    print("WARNING: CRS mismatch between NDVI, CHIRPS and ERA5")

In [None]:
# First and last month-start timestamps on the NDVI time axis.
print("NDVI Time Range:", ndvi["time"].min().values, "to", ndvi["time"].max().values)
# Expected: 2017-01-01 to 2023-12-01

# Length of the time dimension.
print("Number of Timesteps:", ndvi.sizes["time"])
# Expected: 84

In [None]:
# Check if x/y coordinates match between datasets.
# Compare raw coordinate values: `ndvi.x != chirps.x` would first align the
# two coordinate arrays on their labels (inner join), so grids that differ
# are never reported. np.array_equal also returns False (rather than
# raising) when the axes have different lengths.
x_mismatch = not all(np.array_equal(ndvi.x.values, other.x.values) for other in (chirps, era5))
y_mismatch = not all(np.array_equal(ndvi.y.values, other.y.values) for other in (chirps, era5))
print(f"Spatial Mismatch: X={x_mismatch}, Y={y_mismatch}")
# Should output: Spatial Mismatch: X=False, Y=False

In [None]:
def spot_check(da, year, month_idx, x=slice(5000, 5100), y=slice(5000, 5100)):
    """Return (min, max) over a small spatial window for one month.

    Args:
        da: Monthly DataArray with a ``time`` coordinate and ``x``/``y`` dims.
        year: Calendar year to inspect.
        month_idx: 0-based month index (0 = January, 11 = December).
        x, y: Positional slices selecting the window; adjust to your region
            of interest.

    Returns:
        Tuple ``(min, max)`` of the window, computed eagerly.

    Raises:
        ValueError: ``month_idx`` outside 0-11, or the window is empty
            (slices fall outside the raster).
        KeyError: the requested month is not on the time axis.
    """
    # Select the month by timestamp label rather than by a position computed
    # from a hard-coded 2017 start, so the lookup is right for any start year
    # and fails loudly if that month is missing.
    month_start = pd.Timestamp(year=year, month=month_idx + 1, day=1)
    subset = da.sel(time=month_start).isel(x=x, y=y)
    if subset.size == 0:
        raise ValueError(f"Empty window x={x}, y={y}: slices fall outside the raster")
    subset = subset.compute()
    return subset.min().item(), subset.max().item()

# Check NDVI for January 2017
ndvi_min, ndvi_max = spot_check(ndvi, 2017, 0)
print(f"NDVI 2017-01: Min={ndvi_min}, Max={ndvi_max} (Expected: ~-0.2 to 0.9)")

# Check CHIRPS for July 2020
chirps_min, chirps_max = spot_check(chirps, 2020, 6)
print(f"CHIRPS 2020-07: Min={chirps_min}, Max={chirps_max} (Expected: ≥0 mm)")

# Check ERA5 for December 2023
# NOTE(review): ECMWF publishes ERA5 evaporation with a negative sign for
# upward (evaporative) flux; confirm whether these reprojected files were
# sign-flipped before relying on the "≥0" expectation.
era5_min, era5_max = spot_check(era5, 2023, 11)
print(f"ERA5 2023-12: Min={era5_min}, Max={era5_max} (Expected: ≥0 mm)")

In [None]:
def check_nodata(da, x=1000, y=1000, nodata=-9999):
    """Return True if pixel (x, y) holds NoData at every timestep.

    A value counts as NoData if it is NaN (how rioxarray represents nodata
    when a file is opened with ``masked=True``) or equals ``nodata`` (the
    raw fill value when the file is opened unmasked).

    Args:
        da: DataArray with ``x``/``y`` dims (plus any others, e.g. ``time``).
        x, y: Positional indices of a pixel known to lie outside valid data.
        nodata: Raw fill value to accept; -9999 is the assumed fill value.
    """
    pixel = da.isel(x=x, y=y)
    # Checking only `min() == -9999` always failed for masked arrays, whose
    # nodata cells hold NaN rather than the fill value.
    is_nodata = pixel.isnull() | (pixel == nodata)
    return bool(is_nodata.all().compute())

print("NDVI NoData Valid:", check_nodata(ndvi))
print("CHIRPS NoData Valid:", check_nodata(chirps))
print("ERA5 NoData Valid:", check_nodata(era5))

In [None]:
# `.chunks` lists every chunk size per dimension, in the array's dim order
# (time, y, x); it is None if the array is not dask-backed.
# Expected: all 1s along time; 1024 along y and x (edge chunks may be smaller).
print("NDVI Chunks:", ndvi.chunks)

In [None]:
# Every month-start from Jan 2017 to Dec 2023 should appear on NDVI's time
# axis; the set difference lists any that do not.
expected_times = pd.date_range(start="2017-01-01", end="2023-12-01", freq="MS")
missing_times = expected_times.difference(ndvi["time"].values)
print(f"Missing Timesteps: {len(missing_times)}")
# Expected: 0

In [None]:
# Visual sanity check: a 1000 x 1000-pixel window of the first timestep.
ndvi.isel(y=slice(5000, 6000), x=slice(5000, 6000), time=0).plot.imshow()
plt.title("NDVI - Jan 2017 (Subset)")
plt.show()

## ------------------

In [None]:
# --------------------------------------------------
# Function to load annual files with time coordinates
# --------------------------------------------------
# NOTE(review): running this cell redefines `load_annual_data` and reloads
# `ndvi`, `chirps` and `era5` WITHOUT dask chunks or nodata masking. That
# discards the chunked/masked arrays (and any names/attrs) built earlier.
def load_annual_data(variable, years=range(2017, 2024)):
    """Load annual files with 12 bands (Jan-Dec) and assign time coordinates.

    Args:
        variable: ``"NDVI"``, ``"CHIRPS"`` or ``"ERA5"``.
        years: Years to load, in order. Defaults to 2017-2023.

    Returns:
        DataArray with the yearly rasters concatenated along ``time``
        (month-start timestamps). Nodata is left as the raw stored value.

    Raises:
        ValueError: unknown ``variable`` or a file without exactly 12 bands.
    """
    da_list = []

    for year in years:
        # Construct file paths based on variable
        if variable == "NDVI":
            path = ndvi_dir / f"Sentinel2_Tadla_NDVI_{year}.tif"
        elif variable == "CHIRPS":
            path = rainfall_dir / f"CHIRPS_{year}_reproj.tif"
        elif variable == "ERA5":
            path = evapotranspiration_dir / f"ERA5_{year}_reproj.tif"
        else:
            # Without this branch an unknown name surfaced as an
            # UnboundLocalError on `path` below.
            raise ValueError(f"Unknown variable: {variable!r}")

        # Load raster (12 bands = Jan-Dec).
        # NOTE(review): without `chunks=`, concatenating below likely reads
        # every file fully into memory — confirm this fits for 10 m rasters.
        da = rxr.open_rasterio(path)

        n_bands = da.sizes["band"]
        if n_bands != 12:
            raise ValueError(f"{path} has {n_bands} bands; expected 12 (Jan-Dec)")

        # Generate monthly timestamps for the year
        times = pd.date_range(start=f"{year}-01-01", periods=12, freq="MS")

        # Assign time coordinates
        da = da.assign_coords(band=times).rename({"band": "time"})

        da_list.append(da)

    # Combine all years into a single DataArray
    return xr.concat(da_list, dim="time")

# Load all dynamic variables
ndvi = load_annual_data("NDVI")        # Shape: (time=84, y, x)
chirps = load_annual_data("CHIRPS")    # 84 months (7 years * 12)
era5 = load_annual_data("ERA5")        # 84 months

In [None]:
from dask.distributed import Client
from dask.diagnostics import ProgressBar

# --------------------------------------------------
# Function to load annual files with Dask chunking and time coordinates
# --------------------------------------------------
def load_annual_data(variable, years=range(2017, 2024), chunks=None):
    """
    Load annual files with 12 bands (one per month) and assign monthly timestamps.
    Uses Dask chunking to manage memory when working with large datasets.

    Args:
        variable: ``"NDVI"``, ``"CHIRPS"`` or ``"ERA5"``.
        years: Years to load, in order. Defaults to 2017-2023.
        chunks: Dask chunk spec used when opening each file. Defaults to
            ``{"band": 12, "x": 1024, "y": 1024}``, i.e. all 12 months of one
            spatial tile per chunk. The previous ``{"band": 12}`` left x/y
            unchunked, so every chunk was an entire 12-band raster.

    Returns:
        Dask-backed DataArray concatenated along ``time`` (month-start
        timestamps). Nodata is left as the raw stored value.

    Raises:
        ValueError: unknown ``variable`` or a file without exactly 12 bands.
    """
    if chunks is None:
        chunks = {"band": 12, "x": 1024, "y": 1024}

    da_list = []

    for year in years:
        # Construct file path based on the variable name
        if variable == "NDVI":
            path = ndvi_dir / f"Sentinel2_Tadla_NDVI_{year}.tif"
        elif variable == "CHIRPS":
            path = rainfall_dir / f"CHIRPS_{year}_reproj.tif"
        elif variable == "ERA5":
            path = evapotranspiration_dir / f"ERA5_{year}_reproj.tif"
        else:
            raise ValueError("Unknown variable. Choose from 'NDVI', 'CHIRPS', or 'ERA5'.")

        # Open lazily with Dask chunking; each file should hold 12 monthly bands.
        da = rxr.open_rasterio(path, chunks=chunks)

        n_bands = da.sizes["band"]
        if n_bands != 12:
            raise ValueError(f"{path} has {n_bands} bands; expected 12 (Jan-Dec)")

        # Generate monthly timestamps for the given year
        times = pd.date_range(start=f"{year}-01-01", periods=12, freq="MS")

        # Assign time coordinates and rename the 'band' dimension to 'time'
        da = da.assign_coords(band=times).rename({"band": "time"})

        da_list.append(da)

    # Concatenate all years' DataArrays along the time dimension
    return xr.concat(da_list, dim="time")

In [None]:
# Dynamic features (time-series): reuse the monthly series assembled by
# `load_annual_data`, which already carry a `time` axis. Re-opening single
# rasters here would bring back a raw `band` dimension, which collides with
# `to_array(dim="band")` below.
dynamic_data = {
    "ndvi": ndvi,
    "chirps": chirps,
    "era5": era5,
}

# `.squeeze()` leaves a scalar `band` coordinate on every static layer; drop
# it so it cannot conflict with the "band" dimension created by to_array.
static_layers = {
    name: layer.drop_vars("band", errors="ignore")
    for name, layer in static_data.items()
}

# Fail fast if any x/y grid or time axis differs, instead of letting the
# layers be silently joined and padded with NaN.
xr.align(*static_layers.values(), *dynamic_data.values(), join="exact")

# Combine into a single array stacked along "band"; static layers are
# broadcast along `time`.
# NOTE(review): this produces (n_features, time, y, x); at 10 m over 84 months
# it may not fit in memory unless the inputs are dask-backed.
dataset = xr.Dataset({**static_layers, **dynamic_data}).to_array(dim="band")