# Forest Loss Driver Analysis (Northern Maine Acadian Region)

**Goal:** Use a Random Forest model to identify the relative importance of
preceding spectral conditions and climate variables in predicting forest loss
events (from Hansen GFC) between 2001 and 2021.

**Data:**
- Annual Landsat Collection 2 Level-2 (C02 L2) median composites (2000-2021): bands Blue, Green, Red, NIR, SWIR1, SWIR2, plus the NDVI and NBR indices
- Hansen GFC Loss Year (2000-2021 v1.9)
- Hansen GFC Tree Cover 2000
- Annual TerraClimate: Mean Temperature, Total Precipitation (2000-2021)
- SRTM DEM & Slope

**Methodology:**
1. Load and verify data alignment (CRS, Extent, Resolution).
2. Define sampling strategy: Sample loss pixels and non-loss pixels (within forest mask) annually.
3. Extract features for sampled pixels:
   - Target: Loss in year `t` (binary)
   - Features: Landsat(t-1), Climate(t-1), Climate(t), DEM, Slope
4. Split data temporally (e.g., train 2001-2015, test 2016-2021).
5. Train Random Forest Classifier.
6. Evaluate model performance.
7. Analyze feature importances.
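Step 4's temporal split can be sketched in plain Python. This assumes each sample is a dict with a `target_year` key, as in the sampling cell below; the helper name `temporal_split` is hypothetical. Splitting on year instead of with a random `train_test_split` keeps every test year strictly after the training years, so the evaluation resembles predicting loss in years the model has not seen.

```python
def temporal_split(samples, split_year):
    """Split sample records by target year: train < split_year <= test."""
    train = [s for s in samples if s["target_year"] < split_year]
    test = [s for s in samples if s["target_year"] >= split_year]
    return train, test


# Toy records shaped like the sampled points used later in this notebook.
samples = [
    {"row": 10, "col": 4, "target_year": 2001, "loss": 1},
    {"row": 7, "col": 9, "target_year": 2002, "loss": 0},
    {"row": 3, "col": 1, "target_year": 2003, "loss": 1},
]
train, test = temporal_split(samples, split_year=2003)
print(len(train), len(test))  # 2 1
```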

## 1. Setup and Imports

In [11]:
import os
import glob
from pathlib import Path
import time # Keep track of processing time

import numpy as np
import pandas as pd
import rasterio
from rasterio.windows import Window
from rasterio.enums import Resampling # if needed for consistency checks
# import geopandas as gpd # If needed for vector operations later
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay, roc_auc_score
from sklearn.utils import shuffle # For sampling
import matplotlib.pyplot as plt
import joblib # To save the trained model

# Optional: for parallel processing if sampling is slow
# import dask.array as da
# from dask.diagnostics import ProgressBar

print("Libraries imported.")

Libraries imported.


## 2. Configuration and Constants

In [None]:

import os
from pathlib import Path

# --- Project Structure ---
# IMPORTANT: Verify this PROJECT_ROOT path is correct for your system.
# It should be the main folder containing 'data/', 'notebooks/', etc.
PROJECT_ROOT = Path('/Users/benjaminpace/MLCS/mlcs') # <--- ADJUST THIS PATH IF YOUR PROJECT IS ELSEWHERE

# Print the path to be sure, then you can comment it out
print(f"Attempting to use Project Root: {PROJECT_ROOT}")
if not (PROJECT_ROOT / 'data').exists():
    # Try to infer from current working directory if notebook is in 'notebooks/'
    current_dir = Path(os.getcwd())
    if current_dir.name == 'notebooks' and (current_dir.parent / 'data').exists():
        PROJECT_ROOT = current_dir.parent
        print(f"Adjusted Project Root (inferred): {PROJECT_ROOT}")
    else:
        # If still not found, raise a clear error.
        raise FileNotFoundError(
            f"CRITICAL: 'data' directory not found relative to an assumed Project Root of {PROJECT_ROOT}. "
            f"Please verify the PROJECT_ROOT variable in this cell. Current directory is {os.getcwd()}."
        )
else:
    print(f"Project Root confirmed: {PROJECT_ROOT}")


DATA_DIR = PROJECT_ROOT / 'data'
LANDSAT_DIR = DATA_DIR / 'landsat'
HANSEN_DIR = DATA_DIR / 'hansen'
CLIMATE_DIR = DATA_DIR / 'climate'
AUX_DIR = DATA_DIR / 'aux'
OUTPUT_DIR = PROJECT_ROOT / 'output'
OUTPUT_DIR.mkdir(exist_ok=True) # Create output directory if it doesn't exist

print(f"Data Directory: {DATA_DIR}")
print(f"Landsat Directory: {LANDSAT_DIR}")
print(f"Hansen Directory: {HANSEN_DIR}")
print(f"Climate Directory: {CLIMATE_DIR}")
print(f"Auxiliary Directory: {AUX_DIR}")
print(f"Output Directory: {OUTPUT_DIR}")


# --- Data File Paths (These define the expected names within the subdirectories) ---
HANSEN_LOSS_YEAR_PATH = HANSEN_DIR / 'hansen_lossyear_2000_2021.tif'
HANSEN_COVER_2000_PATH = HANSEN_DIR / 'hansen_treecover_2000.tif'
DEM_PATH = AUX_DIR / 'dem.tif'
SLOPE_PATH = AUX_DIR / 'slope.tif'

# --- Analysis Parameters (ULTRA MINIMAL for 6-hour sprint) ---
START_YEAR = 2001 # First year for which we predict loss (uses t-1 predictors from 2000)
END_YEAR = 2003     # Last target year for loss. Targets: 2001, 2002, 2003.
                    # This means Landsat predictors needed up to 2002.
                    # Climate predictors needed up to 2002 (for t-1) and 2003 (for t).
YEARS = list(range(START_YEAR, END_YEAR + 1)) # Python range goes up to, but not including, the stop value.
print(f"Analysis Target Years (YEARS): {YEARS}") # Should be [2001, 2002, 2003]

# Define the bands we expect in the Landsat composites AND THEIR ORDER
# This order MUST match the order of bands in your GeoTIFF files.
LANDSAT_BANDS = ['Blue', 'Green', 'Red', 'NIR', 'SWIR1', 'SWIR2', 'NDVI', 'NBR']
print(f"Expected Landsat Bands (in order): {LANDSAT_BANDS}")

CLIMATE_VARS = ['mean_temp', 'total_precip']
print(f"Climate Variables: {CLIMATE_VARS}")

STATIC_VARS = ['dem', 'slope']
print(f"Static Variables: {STATIC_VARS}")

# --- Sampling Parameters ---
NON_LOSS_RATIO = 1      # 1 non-loss point for every 1 loss point (minimal for speed)
MIN_TREE_COVER = 30     # Minimum tree cover in 2000 to be considered 'forest'
RANDOM_STATE = 42       # For reproducibility

# --- Modeling Parameters (ULTRA MINIMAL for speed) ---
TEST_SPLIT_YEAR = 2003   # Train on target years < 2003 (i.e., 2001, 2002)
                         # Test on target years >= 2003 (i.e., 2003)
print(f"Test Split Year (targets < this year for train): {TEST_SPLIT_YEAR}")

RF_N_ESTIMATORS = 10    # Very few trees
RF_MAX_DEPTH = 5        # Very shallow trees
RF_MIN_SAMPLES_LEAF = 50 # Leaves must have at least this many samples
RF_MIN_SAMPLES_SPLIT = 100 # Nodes must have at least this many samples to split
RF_N_JOBS = -1          # Use all available CPU cores (good for RF)
# class_weight='balanced' will be set in the classifier instantiation

Attempting to use Project Root: /Users/benjaminpace/MLCS/mlcs
Project Root confirmed: /Users/benjaminpace/MLCS/mlcs
Data Directory: /Users/benjaminpace/MLCS/mlcs/data
Landsat Directory: /Users/benjaminpace/MLCS/mlcs/data/landsat
Hansen Directory: /Users/benjaminpace/MLCS/mlcs/data/hansen
Climate Directory: /Users/benjaminpace/MLCS/mlcs/data/climate
Auxiliary Directory: /Users/benjaminpace/MLCS/mlcs/data/aux
Output Directory: /Users/benjaminpace/MLCS/mlcs/output
Analysis Target Years (YEARS): [2001, 2002, 2003]
Expected Landsat Bands (in order): ['Blue', 'Green', 'Red', 'NIR', 'SWIR1', 'SWIR2', 'NDVI', 'NBR']
Climate Variables: ['mean_temp', 'total_precip']
Static Variables: ['dem', 'slope']
Test Split Year (targets < this year for train): 2003
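The comments on `END_YEAR` spell out which predictor years a run needs. The same bookkeeping can be derived from `YEARS`, which avoids hand-editing the file list when the year range changes. A minimal sketch; the helper names are hypothetical, and the file-name patterns are the ones used in Section 3:

```python
from pathlib import Path


def required_input_years(target_years):
    """Landsat is used at t-1; climate at both t-1 and t."""
    landsat_years = sorted({t - 1 for t in target_years})
    climate_years = sorted({y for t in target_years for y in (t - 1, t)})
    return landsat_years, climate_years


def required_input_files(target_years, landsat_dir, climate_dir):
    """Build the expected predictor file paths for the given target years."""
    landsat_years, climate_years = required_input_years(target_years)
    files = [Path(landsat_dir) / f"landsat_composite_{y}.tif" for y in landsat_years]
    for y in climate_years:
        files.append(Path(climate_dir) / f"mean_temp_{y}.tif")
        files.append(Path(climate_dir) / f"total_precip_{y}.tif")
    return files


landsat_years, climate_years = required_input_years([2001, 2002, 2003])
print(landsat_years)  # [2000, 2001, 2002]
print(climate_years)  # [2000, 2001, 2002, 2003]
```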


## 3. Data Loading and Verification

**CRITICAL STEP:** Verify that all input rasters have the *exact same* Coordinate Reference System (CRS), transform (affine), dimensions (width, height), and resolution.
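Exact equality of affine transforms can fail on floating-point noise, which is why the check below falls back to `Affine.almost_equals`. The idea can be illustrated with plain 6-tuples `(a, b, c, d, e, f)`, the first six coefficients of a rasterio transform. The helper name and the default tolerance (in map units) are assumptions for illustration, not values from this project.

```python
import math


def transforms_match(t1, t2, abs_tol=1e-6):
    """Element-wise comparison of two affine transforms given as 6-tuples (a, b, c, d, e, f)."""
    if len(t1) != 6 or len(t2) != 6:
        raise ValueError("expected 6 affine coefficients")
    return all(math.isclose(x, y, rel_tol=0.0, abs_tol=abs_tol) for x, y in zip(t1, t2))


# Hypothetical 30 m grid: (pixel width, row rotation, x origin, col rotation, -pixel height, y origin)
ref = (30.0, 0.0, 300000.0, 0.0, -30.0, 5000000.0)
print(transforms_match(ref, (30.0, 0.0, 300000.0 + 1e-10, 0.0, -30.0, 5000000.0)))  # True
print(transforms_match(ref, (30.0, 0.0, 300030.0, 0.0, -30.0, 5000000.0)))  # False
```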

In [21]:
# In Cell 7 of your notebook (verify_raster_alignment function definition)

import rasterio # Ensure rasterio is imported here if not globally in Cell 3
from pathlib import Path # Ensure Path is imported

def verify_raster_alignment(raster_paths):
    """Checks CRS, transform, shape of multiple rasters."""
    print("Verifying raster alignment...")
    if not raster_paths:
        print("  No raster paths provided for verification.")
        return None

    reference_profile = None
    all_aligned = True # Assume aligned until a mismatch is found
    checked_paths_count = 0

    for path_obj in raster_paths:
        path_str = str(path_obj) # Ensure it's a string for rasterio.open
        current_file_path = Path(path_str)

        if not current_file_path.exists():
            print(f"  WARNING: File not found, skipping: {current_file_path.name}")
            # If a file is missing, we can't confirm alignment with it.
            # Depending on strictness, you might want to set all_aligned = False here.
            # For now, we'll just skip it and report issues with files that *do* exist.
            continue

        checked_paths_count += 1
        try:
            with rasterio.open(current_file_path) as src:
                # Get dtype of the first band as representative for the file
                # Note: Rasterio datasets can have bands of different dtypes,
                # but for our GEE exports, they should be consistent within a file.
                current_dtype = src.dtypes[0] # dtypes is a tuple of band dtypes

                profile = {
                    'path_name': current_file_path.name, # Store only name for cleaner printing
                    'crs': src.crs,
                    'transform': src.transform,
                    'width': src.width,
                    'height': src.height,
                    'count': src.count,         # Number of bands
                    'dtype_band1': current_dtype # Data type of the first band
                }
                print(f"--- Checking: {profile['path_name']} ---")
                # print(f"  CRS: {profile['crs']}")
                # print(f"  Transform: {profile['transform']}")
                # print(f"  Shape: ({profile['height']}, {profile['width']})")
                # print(f"  Band Count: {profile['count']}")
                # print(f"  Dtype (Band 1): {profile['dtype_band1']}")


                if reference_profile is None: # This is the first valid file encountered
                    reference_profile = profile
                    print(f"  Set as Reference: CRS={profile['crs']}, Shape=({profile['height']},{profile['width']}), Transform={profile['transform']}")
                else:
                    # Compare current file's profile to the reference profile
                    if profile['crs'] != reference_profile['crs']:
                        print(f"  MISMATCH: CRS ({profile['crs']}) differs from reference ({reference_profile['crs']})")
                        all_aligned = False
                    if profile['transform'] != reference_profile['transform']:
                        # Exact equality can fail on floating-point noise, so confirm the
                        # difference with a tolerance-based comparison before flagging it.
                        if not profile['transform'].almost_equals(reference_profile['transform']):
                            print("  MISMATCH: Transform differs significantly from reference.")
                            all_aligned = False
                    if profile['width'] != reference_profile['width'] or profile['height'] != reference_profile['height']:
                        print(f"  MISMATCH: Shape ({profile['height']},{profile['width']}) differs from reference ({reference_profile['height']},{reference_profile['width']})")
                        all_aligned = False
                    # We don't strictly need to check band count or dtype for *alignment*,
                    # but they are good to be aware of. The band count will be checked later.
                    # Dtype consistency is good, but not an alignment blocker if CRS/Transform/Shape match.

        except Exception as e:
            print(f"  ERROR reading or processing {current_file_path.name}: {type(e).__name__} - {e}")
            all_aligned = False # Treat any read error as an alignment failure for that file

    if checked_paths_count == 0:
        print("\nERROR: No files were found or could be opened for verification.")
        return None

    if reference_profile is None:
        print("\nERROR: No valid reference raster could be established (all checked files had errors or were missing).")
        return None

    if all_aligned:
        print(f"\nSUCCESS: All {checked_paths_count} checked rasters appear ALIGNED with the reference '{reference_profile['path_name']}'.")
        # Return the common profile derived from the reference (excluding its own path name)
        return {key: value for key, value in reference_profile.items() if key != 'path_name'}
    else:
        print("\nERROR: Raster alignment check FAILED for one or more files. Please review MISMATCH messages above.")
        return None

In [22]:
# In Cell 8 of your notebook (Define files_to_check and call verify_raster_alignment)

# --- Gather files to check for the ULTRA-MINIMAL run ---
# These are the specific files that your analysis (with END_YEAR=2003) will touch
# for feature extraction. All of them must exist and be aligned.

files_to_check = [
    # Base Layers
    HANSEN_LOSS_YEAR_PATH,      # Used for sampling targets across all years
    HANSEN_COVER_2000_PATH,     # Used for forest mask in sampling
    DEM_PATH,                   # Static predictor
    SLOPE_PATH,                 # Static predictor

    # Landsat Predictors (t-1):
    # For target_year 2001, predictor is Landsat 2000
    # For target_year 2002, predictor is Landsat 2001
    # For target_year 2003, predictor is Landsat 2002
    LANDSAT_DIR / 'landsat_composite_2000.tif',
    LANDSAT_DIR / 'landsat_composite_2001.tif',
    LANDSAT_DIR / 'landsat_composite_2002.tif',

    # Climate Predictors (t-1) and Concurrent Climate (t):
    # Year 2000 (t-1 for target 2001)
    CLIMATE_DIR / 'mean_temp_2000.tif',
    CLIMATE_DIR / 'total_precip_2000.tif',
    # Year 2001 (t-1 for target 2002; t for target 2001)
    CLIMATE_DIR / 'mean_temp_2001.tif',
    CLIMATE_DIR / 'total_precip_2001.tif',
    # Year 2002 (t-1 for target 2003; t for target 2002)
    CLIMATE_DIR / 'mean_temp_2002.tif',
    CLIMATE_DIR / 'total_precip_2002.tif',
    # Year 2003 (t for target 2003)
    CLIMATE_DIR / 'mean_temp_2003.tif',
    CLIMATE_DIR / 'total_precip_2003.tif',
]

# --- Filter out paths for any datasets you decided to EXCLUDE due to earlier issues ---
# For example, if DEM was misaligned and you decided to drop it for this run:
# files_to_check = [p for p in files_to_check if p != DEM_PATH]
# And ensure 'dem' was removed from STATIC_VARS in Cell 5.
# For now, this script assumes all listed files are intended to be used.

print("Files to be checked for alignment (ensure these exist in your project structure):")
all_files_exist = True
for f_path in files_to_check:
    exists = Path(f_path).exists()
    print(f"  - {f_path} (Exists: {exists})")
    if not exists:
        all_files_exist = False

if not all_files_exist:
    raise FileNotFoundError(
        "CRITICAL: One or more files listed in 'files_to_check' do not exist at the specified path. "
        "Verify your data movement and file naming."
    )
else:
    print("\nAll listed files exist. Proceeding with alignment check...")

# --- Run the alignment check ---
common_profile = verify_raster_alignment(files_to_check) # Function defined in Cell 7

assert common_profile is not None, \
    "CRITICAL: Raster alignment check failed. Stopping execution. " \
    "Review errors printed by 'verify_raster_alignment' above. " \
    "All listed files must be perfectly aligned (CRS, Transform, Dimensions)."

# --- Store key dimensions globally if check passed ---
RASTER_HEIGHT = common_profile['height']
RASTER_WIDTH = common_profile['width']
RASTER_TRANSFORM = common_profile['transform']
RASTER_CRS = common_profile['crs']

print(f"\nCommon Raster Shape: ({RASTER_HEIGHT}, {RASTER_WIDTH})")
print(f"Common CRS: {RASTER_CRS}")
print(f"Common Transform: {RASTER_TRANSFORM}")

# --- Quick check on Landsat band count and descriptions (if available) ---
# This helps verify if LANDSAT_BANDS order matches the file.
try:
    # Check the first Landsat file in our list (should be 2000.tif)
    landsat_to_check_bands = LANDSAT_DIR / 'landsat_composite_2000.tif' # Or any in files_to_check
    if landsat_to_check_bands.exists():
        with rasterio.open(landsat_to_check_bands) as src:
            print(f"\nVerifying bands for: {landsat_to_check_bands.name}")
            print(f"  Number of bands found in file: {src.count}")
            print(f"  Expected number of bands (from LANDSAT_BANDS variable): {len(LANDSAT_BANDS)}")
            if src.count != len(LANDSAT_BANDS):
                print(f"  WARNING: Band count mismatch! Code expects {len(LANDSAT_BANDS)} based on LANDSAT_BANDS list.")
            # Band descriptions are often not set by GEE, but if they were, they'd be useful.
            # print(f"  Band descriptions from file: {src.descriptions}")
            # If src.descriptions is like (None, None, ...), GEE didn't set them, which is common.
    else:
        print(f"WARNING: Cannot check bands, {landsat_to_check_bands.name} not found for this check.")
except Exception as e:
    print(f"\nCould not perform Landsat band count/description check: {type(e).__name__} - {e}")

Files to be checked for alignment (ensure these exist in your project structure):
  - /Users/benjaminpace/MLCS/mlcs/data/hansen/hansen_lossyear_2000_2021.tif (Exists: True)
  - /Users/benjaminpace/MLCS/mlcs/data/hansen/hansen_treecover_2000.tif (Exists: True)
  - /Users/benjaminpace/MLCS/mlcs/data/aux/dem.tif (Exists: True)
  - /Users/benjaminpace/MLCS/mlcs/data/aux/slope.tif (Exists: True)
  - /Users/benjaminpace/MLCS/mlcs/data/landsat/landsat_composite_2000.tif (Exists: True)
  - /Users/benjaminpace/MLCS/mlcs/data/landsat/landsat_composite_2001.tif (Exists: True)
  - /Users/benjaminpace/MLCS/mlcs/data/landsat/landsat_composite_2002.tif (Exists: True)
  - /Users/benjaminpace/MLCS/mlcs/data/climate/mean_temp_2000.tif (Exists: True)
  - /Users/benjaminpace/MLCS/mlcs/data/climate/total_precip_2000.tif (Exists: True)
  - /Users/benjaminpace/MLCS/mlcs/data/climate/mean_temp_2001.tif (Exists: True)
  - /Users/benjaminpace/MLCS/mlcs/data/climate/total_precip_2001.tif (Exists: True)
  - /User

## 4. Sampling Strategy

For each target year `t` in `YEARS` (2001-2021 in the full design; 2001-2003 in the current minimal run):
1. Identify pixels where loss occurred exactly in year `t`.
2. Identify pixels that are potential non-loss candidates (forested in 2000, no loss 2001-2021).
3. Take all loss pixels for year `t`.
4. Sample `NON_LOSS_RATIO` times as many non-loss pixels randomly from the candidates.
5. Store pixel coordinates (row, col) and associated target year/loss status.
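The steps above can be sketched on toy arrays. This mirrors the sampling cell's logic (every loss pixel for the year, plus `NON_LOSS_RATIO` controls drawn without replacement from forested, never-lost pixels) but uses NumPy's `Generator.choice` instead of the legacy `np.random.seed` API and returns `(row, col, loss)` tuples. The function name and the arrays are made up for illustration.

```python
import numpy as np


def sample_year(loss_year, cover, target_year, min_cover=30, ratio=1, seed=42):
    """Return (row, col, loss) tuples for one target year."""
    rng = np.random.default_rng(seed)
    code = target_year - 2000  # Hansen lossyear: 1 = 2001, ..., 21 = 2021
    loss_rows, loss_cols = np.nonzero(loss_year == code)
    # Controls: enough tree cover in 2000 and no recorded loss in any year
    cand_rows, cand_cols = np.nonzero((cover >= min_cover) & (loss_year == 0))
    n_controls = min(len(cand_rows), len(loss_rows) * ratio)
    samples = [(int(r), int(c), 1) for r, c in zip(loss_rows, loss_cols)]
    if n_controls > 0:
        pick = rng.choice(len(cand_rows), size=n_controls, replace=False)
        samples += [(int(r), int(c), 0) for r, c in zip(cand_rows[pick], cand_cols[pick])]
    return samples


loss_year = np.array([[0, 1, 0], [2, 0, 0], [0, 0, 1]])
cover = np.array([[50, 80, 10], [90, 40, 60], [35, 20, 70]])
samples_2001 = sample_year(loss_year, cover, 2001)
print(len(samples_2001), sum(s[2] for s in samples_2001))  # 4 2
```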

In [26]:
all_sampled_points = [] # Initialize or clear if re-running

print("Starting pixel sampling...")
sampling_start_time = time.time()

# These should be defined from Cell 5 (Configuration)
# HANSEN_LOSS_YEAR_PATH, HANSEN_COVER_2000_PATH
# RASTER_HEIGHT, RASTER_WIDTH (from Cell 8 common_profile)
# MIN_TREE_COVER, YEARS, NON_LOSS_RATIO, RANDOM_STATE

try:
    print("Loading base data for sampling (Hansen Loss Year and Cover 2000)...")
    with rasterio.open(HANSEN_LOSS_YEAR_PATH) as loss_src:
        # Verify shape matches common profile before reading
        assert (loss_src.height, loss_src.width) == (RASTER_HEIGHT, RASTER_WIDTH), \
            f"Loss year shape ({loss_src.height},{loss_src.width}) mismatch with common profile ({RASTER_HEIGHT},{RASTER_WIDTH})!"
        loss_year_data = loss_src.read(1)
        print(f"  Loaded Loss Year data ({loss_src.height}x{loss_src.width})")

    with rasterio.open(HANSEN_COVER_2000_PATH) as cover_src:
        assert (cover_src.height, cover_src.width) == (RASTER_HEIGHT, RASTER_WIDTH), \
            f"Cover 2000 shape ({cover_src.height},{cover_src.width}) mismatch with common profile ({RASTER_HEIGHT},{RASTER_WIDTH})!"
        cover_2000_data = cover_src.read(1)
        print("  Loaded Tree Cover 2000 data")

    # --- Create mask for non-loss candidate pixels ---
    # Condition 1: Sufficient tree cover in 2000
    forest_mask = cover_2000_data >= MIN_TREE_COVER
    # Condition 2: No loss recorded between 2001 and 2021 (loss year == 0 in Hansen)
    no_loss_mask = loss_year_data == 0
    # Combined mask for pixels eligible to be sampled as "non-loss" controls
    non_loss_candidate_mask = forest_mask & no_loss_mask
    # Get indices (row, col arrays) where mask is True
    non_loss_candidate_indices = np.where(non_loss_candidate_mask)
    num_non_loss_candidates = len(non_loss_candidate_indices[0])
    print(f"Found {num_non_loss_candidates} potential non-loss candidate pixels.")

    # Check if we have candidates to sample from
    if num_non_loss_candidates == 0:
        raise ValueError("No non-loss candidate pixels found based on criteria. Cannot sample.")

    # --- Sample pixels year by year ---
    np.random.seed(RANDOM_STATE) # for reproducibility

    for year_num, target_year in enumerate(YEARS): # YEARS from Cell 5 (e.g., [2001, 2002, 2003])
        start_time_year = time.time()
        print(f"\n--- Sampling for target year: {target_year} ---")

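        # 1. Find pixels with loss IN THIS target_year
        # Hansen loss year codes 1-21 correspond to years 2001-2021
        # Note: unlike the non-loss candidates, loss pixels are not filtered by
        # forest_mask, so they may include pixels below MIN_TREE_COVER in 2000.
        hansen_loss_code = target_year - 2000
        loss_pixels_this_year_mask = loss_year_data == hansen_loss_code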
        loss_indices_this_year = np.where(loss_pixels_this_year_mask)
        num_loss_pixels = len(loss_indices_this_year[0])
        print(f"  Found {num_loss_pixels} loss pixels.")

        if num_loss_pixels == 0:
            print("  No loss pixels found for this year. Skipping sampling for this year.")
            continue

        # 2. Add loss pixels to sample list
        for r, c in zip(loss_indices_this_year[0], loss_indices_this_year[1]):
            all_sampled_points.append({'row': r, 'col': c, 'target_year': target_year, 'loss': 1})

        # 3. Sample non-loss pixels
        # Calculate how many non-loss points to sample for this year
        num_non_loss_to_sample = min(num_non_loss_candidates, num_loss_pixels * NON_LOSS_RATIO)
        print(f"  Sampling {num_non_loss_to_sample} non-loss pixels (NON_LOSS_RATIO={NON_LOSS_RATIO}).")

        if num_non_loss_to_sample > 0:
            # Randomly choose indices from the non_loss_candidate_indices array
            sampled_candidate_indices_idx = np.random.choice(
                num_non_loss_candidates, num_non_loss_to_sample, replace=False # No replacement
            )
            # Get the actual row and column values using these sampled indices
            sampled_non_loss_rows = non_loss_candidate_indices[0][sampled_candidate_indices_idx]
            sampled_non_loss_cols = non_loss_candidate_indices[1][sampled_candidate_indices_idx]

            # Add non-loss pixels to sample list
            for r, c in zip(sampled_non_loss_rows, sampled_non_loss_cols):
                all_sampled_points.append({'row': r, 'col': c, 'target_year': target_year, 'loss': 0})
        
        current_total_points = len(all_sampled_points)
        points_this_year = num_loss_pixels + num_non_loss_to_sample
        print(f"  Finished sampling for {target_year}. Points this year: {points_this_year}. Cumulative: {current_total_points}. Time: {time.time() - start_time_year:.2f}s")

    # --- Cleanup large arrays from memory ---
    del loss_year_data, cover_2000_data, forest_mask, no_loss_mask, non_loss_candidate_mask
    del loss_pixels_this_year_mask, loss_indices_this_year # non_loss_candidate_indices is large too
    if 'non_loss_candidate_indices' in locals(): del non_loss_candidate_indices
    import gc
    gc.collect() # Try to free memory

    # Shuffle all collected points once at the end
    if all_sampled_points: # Only shuffle if list is not empty
        all_sampled_points = shuffle(all_sampled_points, random_state=RANDOM_STATE)
        print(f"\nTotal points sampled across all years (before any subsampling): {len(all_sampled_points)}")
        # Optional: Quick check on class balance
        loss_counts = pd.Series([p['loss'] for p in all_sampled_points]).value_counts()
        print(f"Sampled loss counts (before any subsampling):\n{loss_counts}")
    else:
        print("\nWARNING: No points were sampled across any year. Check data or sampling logic.")


    print(f"Sampling finished. Total time: {time.time() - sampling_start_time:.2f}s")

except FileNotFoundError as e:
     print(f"ERROR: Required file not found during sampling: {e}. Please ensure data exists and paths in Cell 5 are correct.")
except ValueError as e:
     print(f"ERROR during sampling setup: {e}")
except NameError as e:
     print(f"ERROR: A variable was not defined (likely from Cell 5 or Cell 8). {e}")
except Exception as e:
    print(f"An unexpected error occurred during sampling: {type(e).__name__} - {e}")
    # raise e # Uncomment to see full traceback if needed

Starting pixel sampling...
Loading base data for sampling (Hansen Loss Year and Cover 2000)...
  Loaded Loss Year data (10407x11824)
  Loaded Tree Cover 2000 data
Found 83160460 potential non-loss candidate pixels.

--- Sampling for target year: 2001 ---
  Found 705266 loss pixels.
  Sampling 705266 non-loss pixels (NON_LOSS_RATIO=1).
  Finished sampling for 2001. Points this year: 1410532. Cumulative: 1410532. Time: 2.62s

--- Sampling for target year: 2002 ---
  Found 716421 loss pixels.
  Sampling 716421 non-loss pixels (NON_LOSS_RATIO=1).
  Finished sampling for 2002. Points this year: 1432842. Cumulative: 2843374. Time: 2.35s

--- Sampling for target year: 2003 ---
  Found 385892 loss pixels.
  Sampling 385892 non-loss pixels (NON_LOSS_RATIO=1).
  Finished sampling for 2003. Points this year: 771784. Cumulative: 3615158. Time: 2.38s

Total points sampled across all years (before any subsampling): 3615158
Sampled loss counts (before any subsampling):
0    1807579
1    1807579
Name:
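Holding 3,615,158 points as one Python dict each is memory-heavy. One alternative, sketched here under the assumption that per-year results are kept as NumPy arrays of rows, columns and labels, is a packed structured array (11 bytes per point, about 40 MB for the count above). The helper name is hypothetical, and later cells would then index fields such as `samples["row"]` instead of dict keys.

```python
import numpy as np

SAMPLE_DTYPE = [("row", np.int32), ("col", np.int32), ("target_year", np.int16), ("loss", np.int8)]


def stack_samples(per_year):
    """Concatenate {year: (rows, cols, labels)} arrays into one structured array."""
    parts = []
    for year, (rows, cols, labels) in per_year.items():
        part = np.empty(len(rows), dtype=SAMPLE_DTYPE)
        part["row"] = rows
        part["col"] = cols
        part["loss"] = labels
        part["target_year"] = year
        parts.append(part)
    return np.concatenate(parts)


per_year = {
    2001: (np.array([1, 2]), np.array([3, 4]), np.array([1, 0])),
    2002: (np.array([5]), np.array([6]), np.array([1])),
}
samples = stack_samples(per_year)
rng = np.random.default_rng(42)
samples = samples[rng.permutation(len(samples))]  # shuffle once, as the cell above does
print(samples.dtype.itemsize)  # 11
```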

In [27]:

# Define a much smaller maximum number of samples to actually use for feature extraction
MAX_SAMPLES_FOR_SPRINT = 100000  # Aim for 100k total samples for feature extraction
                                 # This is a huge reduction if all_sampled_points is large.

# Check if all_sampled_points exists from Cell 10 and has data
if 'all_sampled_points' in locals() and all_sampled_points: # Check if list is not empty
    original_sample_count = len(all_sampled_points)
    print(f"\nOriginal number of sampled points from Cell 10: {original_sample_count}")

    if original_sample_count > MAX_SAMPLES_FOR_SPRINT:
        print(f"Aggressively subsampling to approximately {MAX_SAMPLES_FOR_SPRINT} points for speed...")
        
        # Ensure shuffling before taking a slice for randomness if not already shuffled,
        # or re-shuffle if you want extra randomness for the slice.
        # Cell 10 already shuffles at the end, so this might be redundant but harmless.
        all_sampled_points = shuffle(all_sampled_points, random_state=RANDOM_STATE) # RANDOM_STATE from Cell 5
        
        # Take the slice
        all_sampled_points = all_sampled_points[:MAX_SAMPLES_FOR_SPRINT]
        print(f"New number of points to process for features: {len(all_sampled_points)}")

        # Optional: Check class balance of the subsample
        if all_sampled_points: # Check if list is not empty after slicing
            subsample_loss_counts = pd.Series([p['loss'] for p in all_sampled_points]).value_counts()
            print(f"Subsampled loss counts (approximate due to random slice):\n{subsample_loss_counts}")
        else:
            print("WARNING: Subsampling resulted in an empty list of points.")
            
    else:
        print(f"Number of sampled points ({original_sample_count}) is already within or below MAX_SAMPLES_FOR_SPRINT ({MAX_SAMPLES_FOR_SPRINT}). No further subsampling applied.")
else:
    print("WARNING: 'all_sampled_points' not found or is empty. Ensure Cell 10 (Sampling) ran successfully and produced points.")

# --- End of subsampling ---


Original number of sampled points from Cell 10: 3615158
Aggressively subsampling to approximately 100000 points for speed...
New number of points to process for features: 100000
Subsampled loss counts (approximate due to random slice):
0    50037
1    49963
Name: count, dtype: int64
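The random slice keeps the classes only approximately balanced (50,037 vs 49,963 above). If an exact split is preferred, a stratified draw is one option. A stdlib sketch, assuming points are dicts with a `loss` key; passing `key="target_year"` would instead draw equally per year. The function name is hypothetical.

```python
import random


def stratified_subsample(points, max_total, key="loss", seed=42):
    """Draw up to max_total points, split evenly across the distinct values of `key`."""
    by_class = {}
    for p in points:
        by_class.setdefault(p[key], []).append(p)
    if not by_class:
        return []
    per_class = max_total // len(by_class)
    rng = random.Random(seed)
    subsample = []
    for label in sorted(by_class):
        group = by_class[label]
        subsample.extend(rng.sample(group, min(per_class, len(group))))
    rng.shuffle(subsample)
    return subsample


points = [{"loss": 1, "id": i} for i in range(60)] + [{"loss": 0, "id": i} for i in range(40)]
sub = stratified_subsample(points, max_total=50)
print(len(sub), sum(p["loss"] for p in sub))  # 50 25
```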


## 5. Feature Extraction

Iterate through `all_sampled_points`. For each point (row, col, target_year, loss):
- Read pixel values from Landsat(target_year - 1).
- Read pixel values from Climate(target_year - 1).
- Read pixel values from Climate(target_year).
- Read pixel values from DEM and Slope.
- Store features and target in a structure suitable for scikit-learn (e.g., a NumPy array or pandas DataFrame).
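Reading a 1 × 1 window per point costs one I/O call per raster per point. When a year's composite fits in memory, NumPy fancy indexing can gather all points for that year at once. The sketch below also names columns by offset (`_tm1`) rather than by calendar year, so samples from different target years share the same feature columns. The in-memory `landsat_by_year` dict of `(bands, height, width)` arrays is an assumption; at float32, one 8-band composite of the 10407 × 11824 grid is roughly 3.9 GB, so per-band reads may be needed in practice.

```python
import numpy as np


def extract_landsat_tm1(points, landsat_by_year, band_names):
    """Look up Landsat(t-1) values for each point, keyed by offset instead of calendar year."""
    years = np.array([p["target_year"] for p in points])
    rows = np.array([p["row"] for p in points])
    cols = np.array([p["col"] for p in points])
    values = np.full((len(points), len(band_names)), np.nan)
    for year in np.unique(years):
        sel = years == year
        stack = landsat_by_year[int(year) - 1]  # shape (bands, height, width)
        values[sel] = stack[:, rows[sel], cols[sel]].T  # (n_points, bands)
    return {f"lsat_{band}_tm1": values[:, i] for i, band in enumerate(band_names)}


# Toy 2-band, 2 x 2 composites for 2000 and 2001
landsat_by_year = {2000: np.arange(8).reshape(2, 2, 2), 2001: np.arange(8).reshape(2, 2, 2) + 100}
points = [{"row": 0, "col": 1, "target_year": 2001}, {"row": 1, "col": 0, "target_year": 2002}]
feats = extract_landsat_tm1(points, landsat_by_year, ["NIR", "SWIR1"])
print(feats["lsat_NIR_tm1"])  # [  1. 102.]
```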

In [None]:
# In Cell 12 of your notebook (Feature Extraction)

# Ensure necessary libraries are available
# import pandas as pd
# import rasterio
# from rasterio.windows import Window
# from pathlib import Path
# import time
# import gc

# Prepare for feature extraction
features = [] # Initialize or clear if re-running
# Target 'y' is implicitly defined by the 'loss' key in all_sampled_points
# The all_sampled_points variable will now be the SUBSAMPLED list if Cell 10.1 ran.

# These should be defined from Cell 5 or Cell 8:
# LANDSAT_DIR, CLIMATE_DIR, AUX_DIR, DEM_PATH, SLOPE_PATH
# LANDSAT_BANDS, RANDOM_STATE (though RANDOM_STATE not directly used here)

print("Starting feature extraction (NaN DEBUG MODE - dropna disabled)...") # Indicate debug mode
extraction_start_time = time.time()

# Cache for opened rasterio file handles
open_files_cache = {}

def get_file_handle(path_obj, cache): # path_obj is a Path object
    """Gets or opens a rasterio file handle, caching it."""
    path_str = str(path_obj) # Use string representation as key for cache
    if path_str not in cache:
        if not path_obj.exists():
             raise FileNotFoundError(f"Required file for feature extraction not found: {path_str}")
        # print(f"  Opening file: {path_obj.name}") # Can be verbose, uncomment if debugging
        cache[path_str] = rasterio.open(path_str)
    return cache[path_str]

def get_pixel_value_at_rc(src, row, col):
    """Reads pixel value(s) for all bands at a specific row, col."""
    window = Window(col, row, 1, 1)
    return src.read(window=window).squeeze() # squeeze removes singleton dimensions

# --- Process all sampled points (which should be the subsampled list now) ---
extraction_errors = 0

# Check if all_sampled_points exists and is iterable
if 'all_sampled_points' not in locals() or not isinstance(all_sampled_points, list):
    raise NameError("'all_sampled_points' is not defined or not a list. Run Cell 10 and Cell 10.1 first.")
if not all_sampled_points: # Check if the list is empty
    print("WARNING: 'all_sampled_points' is empty. No features to extract. Did subsampling remove all points?")

num_points_to_process = len(all_sampled_points)
print(f"Attempting to extract features for {num_points_to_process} points...")

for i, point in enumerate(all_sampled_points):
    row, col = point['row'], point['col']
    t_year = point['target_year'] # Year loss status is defined for
    t_minus_1 = t_year - 1      # Year for predictor data

    # Print progress periodically (e.g., every 10% or fixed number)
    if num_points_to_process > 0 and (i + 1) % max(1, num_points_to_process // 10) == 0 :
         elapsed_time = time.time() - extraction_start_time
         points_per_sec = (i + 1) / elapsed_time if elapsed_time > 0 else 0
         print(f"  Processed {(i + 1)}/{num_points_to_process} points... ({points_per_sec:.1f} points/sec)")

    # --- Feature dictionary for this point ---
    point_features = {'row': row, 'col': col, 'target_year': t_year, 'loss': point['loss']}

    try:
        # Feature names use lag suffixes relative to the target year ('tm1' = t-1,
        # 't' = t) instead of calendar years, so every sample has the same columns.
        # Calendar-year names (e.g. 'temp_2004') create one column per year and
        # leave NaN in every column that does not belong to the row's own year.

        # --- Landsat (t-1) ---
        lsat_path = LANDSAT_DIR / f'landsat_composite_{t_minus_1}.tif'
        lsat_src = get_file_handle(lsat_path, open_files_cache)
        # Normalise the read (scalar, 0-d or multi-dim array) to a flat 1-D array
        lsat_values_array = np.atleast_1d(get_pixel_value_at_rc(lsat_src, row, col)).ravel()
        if lsat_values_array.size != len(LANDSAT_BANDS):
            # Band count mismatch (e.g. a scalar read from a multi-band file). This should
            # ideally be caught by the GEE export or raster verification; fill with NaNs.
            # print(f"DEBUG: Band count mismatch for point {i}, year {t_minus_1}. Read {lsat_values_array.size}, expected {len(LANDSAT_BANDS)}.")
            lsat_values_array = np.full(len(LANDSAT_BANDS), np.nan)

        for band_name, band_value in zip(LANDSAT_BANDS, lsat_values_array):
            point_features[f'lsat_{band_name}_tm1'] = band_value

        # --- Climate (t-1) ---
        temp_tm1_path = CLIMATE_DIR / f'mean_temp_{t_minus_1}.tif'
        precip_tm1_path = CLIMATE_DIR / f'total_precip_{t_minus_1}.tif'
        temp_tm1_src = get_file_handle(temp_tm1_path, open_files_cache)
        precip_tm1_src = get_file_handle(precip_tm1_path, open_files_cache)
        point_features['temp_tm1'] = get_pixel_value_at_rc(temp_tm1_src, row, col)
        point_features['precip_tm1'] = get_pixel_value_at_rc(precip_tm1_src, row, col)

        # --- Climate (t) ---
        temp_t_path = CLIMATE_DIR / f'mean_temp_{t_year}.tif'
        precip_t_path = CLIMATE_DIR / f'total_precip_{t_year}.tif'
        temp_t_src = get_file_handle(temp_t_path, open_files_cache)
        precip_t_src = get_file_handle(precip_t_path, open_files_cache)
        point_features['temp_t'] = get_pixel_value_at_rc(temp_t_src, row, col)
        point_features['precip_t'] = get_pixel_value_at_rc(precip_t_src, row, col)

        # --- Static vars (only if 'dem' / 'slope' are listed in STATIC_VARS, Cell 5) ---
        if 'dem' in STATIC_VARS:
            dem_src = get_file_handle(DEM_PATH, open_files_cache)
            point_features['dem'] = get_pixel_value_at_rc(dem_src, row, col)
        if 'slope' in STATIC_VARS:
            slope_src = get_file_handle(SLOPE_PATH, open_files_cache)
            point_features['slope'] = get_pixel_value_at_rc(slope_src, row, col)

        # --- Append the complete feature dictionary ---
        features.append(point_features)

    except Exception as e:
        # Catches FileNotFoundError (which get_file_handle should already handle)
        # as well as any other read error. Every expected feature is set to NaN so
        # the row structure is kept; NaNs are counted after extraction.
        # print(f"WARNING: Feature extraction failed for point {i} (row={row}, col={col}, target_year={t_year}): {type(e).__name__} - {e}")
        extraction_errors += 1
        for band_name in LANDSAT_BANDS:
            point_features[f'lsat_{band_name}_tm1'] = np.nan
        for var_name in ('temp', 'precip'):
            point_features[f'{var_name}_tm1'] = np.nan
            point_features[f'{var_name}_t'] = np.nan
        for static_name in ('dem', 'slope'):
            if static_name in STATIC_VARS:
                point_features[static_name] = np.nan
        features.append(point_features)

# --- Close all opened files in the cache ---
print("\nClosing cached raster files...")
for path_str_cached, src_cached in open_files_cache.items(): # Use different var names
    try:
        src_cached.close()
    except Exception as e:
        print(f"Error closing file {path_str_cached}: {e}")
open_files_cache.clear() # Clear the cache explicitly

# --- Final summary of feature extraction ---
final_elapsed_time = time.time() - extraction_start_time
print(f"Finished feature extraction. Total time: {final_elapsed_time:.2f}s")
print(f"Attempted to process {num_points_to_process} points. Appended {len(features)} feature sets (should be same).")
if extraction_errors > 0:
    print(f"Encountered {extraction_errors} errors during individual pixel reads (features for these points were set to NaN).")

# --- Convert list of dictionaries to DataFrame ---
if not features:
    print("CRITICAL WARNING: No features were constructed (features list is empty). Cannot proceed to modeling.")
    # raise SystemExit("No features constructed. Halting.") # Keep this commented for now
else:
    print("\nConverting extracted features to DataFrame...")
    feature_df = pd.DataFrame(features)
    del features # Free up memory from the list of dicts
    import gc
    gc.collect()

    print(f"DataFrame shape (before any dropna): {feature_df.shape}")
    print("Sample data (first 5 rows, may contain NaNs):")
    print(feature_df.head())
    
    print("\nFull check for missing values (NaNs) per column:")
    nan_counts = feature_df.isnull().sum()
    # Print all columns and their NaN counts, even if zero, for full visibility
    print(nan_counts)
    total_nans = nan_counts.sum()
    if total_nans == 0:
        print("SUCCESS: No NaN values found in the entire feature DataFrame!")
    else:
        print(f"WARNING: Found a total of {total_nans} NaN values across all columns.")
        print("Columns with NaN counts > 0:")
        print(nan_counts[nan_counts > 0])


    # --- Handle missing data (dropna temporarily disabled for NaN debugging) ---
    # Re-enable dropna() (or impute) before training: RandomForestClassifier only
    # accepts NaN inputs in scikit-learn >= 1.4.
    initial_rows = len(feature_df) # Row count before any dropna()
    print("\nSKIPPING dropna() for NaN DEBUGGING.")
    # feature_df = feature_df.dropna() # <--- COMMENTED OUT FOR DEBUGGING
    final_rows = len(feature_df) # Equal to initial_rows while dropna() is disabled

    # if initial_rows != final_rows:
    #     print(f"\nDropped {initial_rows - final_rows} rows containing NaN values.")
    #     print(f"Final DataFrame shape after dropna: {feature_df.shape}")

    if final_rows == 0:
        # Either no points were processed or (with dropna enabled) every row had a NaN
        print("CRITICAL WARNING: Feature DataFrame is empty. Cannot proceed to modeling.")
        # raise SystemExit("Feature DataFrame empty. Halting.") # Keep commented for now
    else:
        # --- Prepare final X and y for modeling (will contain NaNs if present) ---
        target_series = feature_df['loss']
        year_series = feature_df['target_year'] # Keep track of year for temporal split

        columns_to_drop_for_X = ['row', 'col', 'target_year', 'loss']
        columns_present_to_drop = [col for col in columns_to_drop_for_X if col in feature_df.columns]
        
        X = feature_df.drop(columns=columns_present_to_drop)
        y = target_series

        FEATURE_NAMES = X.columns.tolist()

        print(f"\nShape of X (features, may contain NaNs): {X.shape}")
        print(f"Shape of y (target): {y.shape}")
        print("\nFinal feature columns used for X (model input):", FEATURE_NAMES)
        print("\nTarget distribution (y):") # Note: y should not have NaNs as it comes from 'loss'
        print(y.value_counts(normalize=True, dropna=False)) # dropna=False to see if y itself has NaNs

Starting feature extraction (NaN DEBUG MODE - dropna disabled)...
Attempting to extract features for 100000 points...
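The NaN pattern this cell is debugging can be reproduced in miniature. If each point's dictionary uses calendar-year keys such as `f'temp_{t_minus_1}'`, `pd.DataFrame` takes the union of all keys, so every row is NaN in the columns that belong to other years and `dropna()` removes everything. The toy records below are hypothetical; keying by lag relative to the target year (`tm1`, `t`) gives every row the same columns.

```python
import pandas as pd

# Hypothetical samples for two target years, keyed by calendar year
year_keyed = [
    {'target_year': 2005, 'loss': 1, 'temp_2004': 5.1, 'temp_2005': 5.6},
    {'target_year': 2006, 'loss': 0, 'temp_2005': 5.6, 'temp_2006': 4.9},
]
# The same samples keyed by lag relative to the target year
lag_keyed = [
    {'target_year': 2005, 'loss': 1, 'temp_tm1': 5.1, 'temp_t': 5.6},
    {'target_year': 2006, 'loss': 0, 'temp_tm1': 5.6, 'temp_t': 4.9},
]

df_year = pd.DataFrame(year_keyed)
df_lag = pd.DataFrame(lag_keyed)

# Union of keys: three temperature columns, each row missing one of them
print(df_year.columns.tolist())
print(df_year.isnull().sum().sum())   # 2
print(len(df_year.dropna()))          # 0 rows survive dropna()

# Shared keys: two temperature columns and no NaNs
print(df_lag.columns.tolist())
print(df_lag.isnull().sum().sum())    # 0
```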



## 6. Train/Test Split (Temporal): TODO

Split the data based on the `target_year`. Data from years before `TEST_SPLIT_YEAR` will be used for training, and data from `TEST_SPLIT_YEAR` onwards for testing. This prevents data leakage from the future into the training set.


In [None]:
print(f"Splitting data temporally based on year {TEST_SPLIT_YEAR}...")

# Use the 'year_series' created alongside X and y
train_mask = year_series < TEST_SPLIT_YEAR
test_mask = year_series >= TEST_SPLIT_YEAR

X_train = X[train_mask]
y_train = y[train_mask]
X_test = X[test_mask]
y_test = y[test_mask]

# Keep track of years for verification (optional)
# train_years = year_series[train_mask]
# test_years = year_series[test_mask]

# Cleanup intermediate objects if memory is tight
# del X, y, feature_df, year_series, train_mask, test_mask
# gc.collect()

print(f"Train set shape: X={X_train.shape}, y={y_train.shape}")
print(f"Test set shape: X={X_test.shape}, y={y_test.shape}")

# Verify the split worked as expected
if not X_train.empty and not X_test.empty:
    print(f"Train target distribution:\n{y_train.value_counts(normalize=True)}")
    print(f"Test target distribution:\n{y_test.value_counts(normalize=True)}")
    # print(f"Train years range: {train_years.min()} - {train_years.max()}")
    # print(f"Test years range: {test_years.min()} - {test_years.max()}")
else:
    print("WARNING: Train or Test set is empty after split. Check TEST_SPLIT_YEAR and data distribution.")
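A quick way to confirm that a year-based split cannot leak future samples into training is to assert that every training year precedes every test year. The helper `split_by_year` below is hypothetical and runs on a toy frame (the column name `ndvi_tm1` is made up); in this notebook the equivalent inputs are `X`, `y`, `year_series`, and `TEST_SPLIT_YEAR`.

```python
import numpy as np
import pandas as pd

def split_by_year(X, y, years, split_year):
    """Hypothetical helper: train on years < split_year, test on years >= split_year."""
    train_mask = years < split_year
    test_mask = ~train_mask
    # Every training year must precede every test year
    if train_mask.any() and test_mask.any():
        assert years[train_mask].max() < years[test_mask].min()
    return X[train_mask], X[test_mask], y[train_mask], y[test_mask]

# Toy data: two samples per target year, 2001-2021
years = pd.Series(np.repeat(np.arange(2001, 2022), 2))
X_toy = pd.DataFrame({'ndvi_tm1': np.linspace(0.2, 0.9, len(years))})
y_toy = pd.Series(np.tile([0, 1], 21))

X_tr, X_te, y_tr, y_te = split_by_year(X_toy, y_toy, years, split_year=2016)
print(len(X_tr), len(X_te))  # 30 12
```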

## 7. Model Training (Random Forest)

Train a Random Forest Classifier on the training data. We use `class_weight='balanced'` to help handle potential class imbalance between loss and no-loss pixels.
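With `class_weight='balanced'`, scikit-learn weights each class by `n_samples / (n_classes * n_samples_in_class)`, computed from the `y` passed to `fit`. The sketch below reproduces that formula with `compute_class_weight` on a hypothetical 90/10 no-loss/loss label vector.

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Hypothetical labels: 90 no-loss (0) and 10 loss (1) samples
y_example = np.array([0] * 90 + [1] * 10)
classes = np.unique(y_example)

weights = compute_class_weight(class_weight='balanced', classes=classes, y=y_example)

# Same formula written out: n_samples / (n_classes * per-class counts)
manual = len(y_example) / (len(classes) * np.bincount(y_example))

print(dict(zip(classes.tolist(), weights.round(3).tolist())))  # {0: 0.556, 1: 5.0}
```

Each rare loss sample therefore counts about nine times as much as a no-loss sample. The alternative `class_weight='balanced_subsample'` recomputes these weights on each tree's bootstrap sample instead of the full training set.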


In [None]:
# Check if training data exists
if X_train.empty or y_train.empty:
    raise SystemExit("Training data is empty. Cannot train model.")

print("Training Random Forest model...")
train_start_time = time.time()

rf_classifier = RandomForestClassifier(
    n_estimators=RF_N_ESTIMATORS,
    max_depth=RF_MAX_DEPTH,
    random_state=RANDOM_STATE,
    n_jobs=RF_N_JOBS,             # -1 uses all available cores
    class_weight='balanced',      # Adjust for class imbalance
    oob_score=False               # Set to True to estimate generalization score without test set (slower)
)

# Train the model
rf_classifier.fit(X_train, y_train)

print(f"Finished training. Total time: {time.time() - train_start_time:.2f}s")

# Optional: Print OOB score if calculated
# if rf_classifier.oob_score:
#     print(f"Out-of-Bag (OOB) Score: {rf_classifier.oob_score_:.4f}")

# --- Save the trained model ---
model_filename = OUTPUT_DIR / f'rf_forest_loss_model_{START_YEAR}_{END_YEAR}.joblib'
try:
    joblib.dump(rf_classifier, model_filename)
    print(f"Model saved successfully to: {model_filename}")
except Exception as e:
    print(f"ERROR saving model: {e}")

## 8. Model Evaluation

Evaluate the trained model's performance on the temporally independent test set using various classification metrics.


In [None]:
# Check if test data exists
if X_test.empty or y_test.empty:
    print("Test data is empty. Skipping evaluation.")
else:
    print("Evaluating model on the test set...")

    # Make predictions
    y_pred = rf_classifier.predict(X_test)
    y_pred_proba = rf_classifier.predict_proba(X_test)[:, 1] # Probability of class 1 (loss)

    # --- Classification Report ---
    print("\nClassification Report:")
    # Use zero_division=0 to avoid warnings if a class has no predicted samples
    print(classification_report(y_test, y_pred, target_names=['No Loss', 'Loss'], zero_division=0))

    # --- Confusion Matrix ---
    print("\nConfusion Matrix:")
    try:
        cm = confusion_matrix(y_test, y_pred)
        print(cm)
        disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['No Loss', 'Loss'])
        fig, ax = plt.subplots(figsize=(6, 6))
        disp.plot(cmap=plt.cm.Blues, ax=ax)
        plt.title('Confusion Matrix')
        plt.show()
    except Exception as e:
        print(f"Could not generate Confusion Matrix plot: {e}")


    # --- ROC AUC Score ---
    # Check if both classes are present in y_test for ROC AUC calculation
    if len(np.unique(y_test)) > 1:
        roc_auc = roc_auc_score(y_test, y_pred_proba)
        print(f"\nROC AUC Score: {roc_auc:.4f}")

        # Optional: Plot ROC Curve
        try:
            from sklearn.metrics import RocCurveDisplay
            fig, ax = plt.subplots(figsize=(6, 6))
            RocCurveDisplay.from_predictions(y_test, y_pred_proba, ax=ax, name='Random Forest')
            plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Chance')
            plt.title('Receiver Operating Characteristic (ROC) Curve')
            plt.legend()
            plt.show()
        except Exception as e:
            print(f"Could not generate ROC Curve plot: {e}")
    else:
        print("\nROC AUC Score cannot be calculated: Only one class present in the test set.")
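When loss pixels are rare in the test years, ROC AUC can look favourable even when precision at useful thresholds is low. The precision-recall curve and average precision are common complements, and the same curve can be scanned for a decision threshold other than the implicit 0.5 cut-off of `predict`. The sketch below uses synthetic labels and scores; in this notebook the inputs would be `y_test` and `y_pred_proba`. Picking the F1-maximising threshold is one illustrative choice, not part of the original analysis.

```python
import numpy as np
from sklearn.metrics import average_precision_score, precision_recall_curve

rng = np.random.default_rng(0)
# Synthetic imbalanced labels (~10% positives) and noisy scores that favour positives
y_true = (rng.random(1000) < 0.1).astype(int)
scores = np.clip(0.3 * y_true + rng.normal(0.3, 0.15, size=1000), 0, 1)

ap = average_precision_score(y_true, scores)
precision, recall, thresholds = precision_recall_curve(y_true, scores)

# precision/recall have one more entry than thresholds; drop the final (P=1, R=0) point
p, r = precision[:-1], recall[:-1]
f1 = 2 * p * r / np.clip(p + r, 1e-12, None)
best = np.argmax(f1)
print(f"Average precision: {ap:.3f} (base rate {y_true.mean():.3f})")
print(f"Best F1 {f1[best]:.3f} at threshold {thresholds[best]:.3f}")

# Apply the chosen threshold instead of predict()'s 0.5
y_pred_tuned = (scores >= thresholds[best]).astype(int)
```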

## 9. Feature Importance Analysis

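Analyze the feature importances provided by the Random Forest model (based on mean decrease in impurity, i.e. Gini importance) to understand which variables were most influential in the model's predictions. Note that impurity-based importances are computed on the training data and can be biased toward continuous, high-cardinality features; correlated predictors (e.g. climate in years t-1 and t) also tend to split the credit between them.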
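Permutation importance, available as `sklearn.inspection.permutation_importance`, is a common complement to impurity-based importance: it measures how much a held-out score drops when one feature is shuffled. The sketch below fits a small forest on synthetic data in which only the first feature carries signal; in this notebook the inputs would be `rf_classifier`, `X_test`, and `y_test`. Scoring with ROC AUC is an assumption, not a choice made in the original analysis.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance

rng = np.random.default_rng(42)
n_samples = 2000
# Synthetic features: 'signal' drives the label, 'noise_a' and 'noise_b' do not
X_perm = rng.normal(size=(n_samples, 3))
y_perm = (X_perm[:, 0] + 0.5 * rng.normal(size=n_samples) > 1.0).astype(int)
perm_feature_names = ['signal', 'noise_a', 'noise_b']

X_fit, X_hold = X_perm[:1500], X_perm[1500:]
y_fit, y_hold = y_perm[:1500], y_perm[1500:]

perm_model = RandomForestClassifier(n_estimators=100, class_weight='balanced',
                                    random_state=0).fit(X_fit, y_fit)

# Shuffle each feature n_repeats times on the held-out set and record the AUC drop
result = permutation_importance(perm_model, X_hold, y_hold, scoring='roc_auc',
                                n_repeats=10, random_state=0)

for idx in result.importances_mean.argsort()[::-1]:
    print(f"{perm_feature_names[idx]:>8}: {result.importances_mean[idx]:.3f} "
          f"+/- {result.importances_std[idx]:.3f}")
```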



In [None]:
print("Calculating and plotting Feature Importances...")

try:
    importances = rf_classifier.feature_importances_
    # Get standard deviations of importances across trees (optional, adds compute time)
    # std = np.std([tree.feature_importances_ for tree in rf_classifier.estimators_], axis=0)
    indices = np.argsort(importances)[::-1] # Sort features by importance (descending)

    # --- Create DataFrame for easier plotting/viewing ---
    importance_df = pd.DataFrame({
        'Feature': [FEATURE_NAMES[i] for i in indices],
        'Importance': importances[indices],
        # 'StdDev': std[indices] # Uncomment if std is calculated
    })

    print("\nTop 20 Feature Importances:")
    print(importance_df.head(20))

    # --- Plot Feature Importances ---
    N_FEATURES_TO_PLOT = min(20, len(importance_df)) # Cannot plot more bars than the model has features
    plt.figure(figsize=(12, max(6, N_FEATURES_TO_PLOT // 2))) # Adjust height based on number of features
    plt.title(f"Feature Importances (Top {N_FEATURES_TO_PLOT})")
    plt.barh(range(N_FEATURES_TO_PLOT), # Use barh for horizontal plot
             importance_df['Importance'].head(N_FEATURES_TO_PLOT)[::-1], # Plot descending importance
             # xerr=importance_df['StdDev'].head(N_FEATURES_TO_PLOT)[::-1], # Uncomment if std is calculated
             align='center')
    plt.yticks(range(N_FEATURES_TO_PLOT),
               importance_df['Feature'].head(N_FEATURES_TO_PLOT)[::-1]) # Labels for y-axis
    plt.xlabel("Mean Decrease in Impurity (Gini Importance)")
    plt.ylabel("Feature")
    plt.ylim([-1, N_FEATURES_TO_PLOT])
    plt.tight_layout()
    plt.show()

    # Optional: Save importance data
    importance_filename = OUTPUT_DIR / f'rf_feature_importances_{START_YEAR}_{END_YEAR}.csv'
    importance_df.to_csv(importance_filename, index=False)
    print(f"Feature importances saved to: {importance_filename}")

except Exception as e:
    print(f"An error occurred during feature importance calculation or plotting: {e}")
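Because the goal is to compare spectral, climate, and terrain drivers, it can help to sum importances by variable group. The hypothetical helper `feature_group` below assumes this notebook's naming scheme (`lsat_` prefix for Landsat bands, `temp_`/`precip_` for climate, `dem`/`slope` for terrain) and is applied to made-up importance values; with the real model the input would be `importance_df`. A group sum of impurity importances inherits the biases of the per-feature values.

```python
import pandas as pd

def feature_group(name):
    """Map a feature name to a driver group (assumes this notebook's naming scheme)."""
    if name.startswith('lsat_'):
        return 'Landsat (t-1)'
    if name.startswith(('temp_', 'precip_')):
        return 'Climate'
    if name in ('dem', 'slope'):
        return 'Terrain'
    return 'Other'

# Hypothetical importances for illustration only
importance_example = pd.DataFrame({
    'Feature': ['lsat_NBR_tm1', 'lsat_NDVI_tm1', 'temp_t', 'precip_tm1', 'dem', 'slope'],
    'Importance': [0.30, 0.20, 0.15, 0.10, 0.15, 0.10],
})

group_importance = (importance_example
                    .assign(Group=importance_example['Feature'].map(feature_group))
                    .groupby('Group')['Importance'].sum()
                    .sort_values(ascending=False))
print(group_importance)
```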

## 10. (Optional) Prediction Map Visualization

This section provides a function to generate a spatial map of predicted loss probability for a given year using the trained model. This is computationally intensive and requires significant memory/time.
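Processing the raster in fixed-size windows keeps memory bounded. The hypothetical helper below only computes window offsets and sizes, in the `col_off, row_off, width, height` order used by `rasterio.windows.Window`, clipping the last row and column of blocks at the raster edge; it does not open any files.

```python
def block_offsets(width, height, block_size=512):
    """Yield (col_off, row_off, win_width, win_height) tiles covering a width x height grid."""
    for row_off in range(0, height, block_size):
        for col_off in range(0, width, block_size):
            yield (col_off, row_off,
                   min(block_size, width - col_off),
                   min(block_size, height - row_off))

tiles = list(block_offsets(width=1200, height=700, block_size=512))
print(len(tiles))   # 6 tiles: 3 columns x 2 rows
print(tiles[-1])    # (1024, 512, 176, 188)
```

Each tuple can be passed as `Window(*tile)` to a dataset's `read(window=...)` and `write(..., window=...)`.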


In [None]:
# --- This section is computationally intensive ---

def predict_loss_probability_map(model, feature_names_ordered, output_path, year_to_predict):
    """Generates a GeoTIFF map of predicted loss probability for a given year."""
    print(f"\nGenerating loss probability map for year: {year_to_predict}")
    prediction_start_time = time.time()
    t_minus_1 = year_to_predict - 1

    # --- Define input file paths for the prediction year ---
    lsat_path = LANDSAT_DIR / f'landsat_composite_{t_minus_1}.tif'
    temp_tm1_path = CLIMATE_DIR / f'mean_temp_{t_minus_1}.tif'
    precip_tm1_path = CLIMATE_DIR / f'total_precip_{t_minus_1}.tif'
    temp_t_path = CLIMATE_DIR / f'mean_temp_{year_to_predict}.tif'
    precip_t_path = CLIMATE_DIR / f'total_precip_{year_to_predict}.tif'
    dem_path_pred = DEM_PATH # Static path
    slope_path_pred = SLOPE_PATH # Static path

    # Check if all required files exist before starting
    required_files = {
        'lsat': lsat_path, 'temp_tm1': temp_tm1_path, 'prec_tm1': precip_tm1_path,
        'temp_t': temp_t_path, 'prec_t': precip_t_path, 'dem': dem_path_pred, 'slope': slope_path_pred
    }
    missing_files = [name for name, p in required_files.items() if not p.exists()]
    if missing_files:
        print(f"ERROR: Missing required raster files for prediction year {year_to_predict}: {missing_files}")
        return

    try:
        # Open source files and get metadata from a reference (e.g., Landsat)
        handles = {name: rasterio.open(p) for name, p in required_files.items()}
        ref_src = handles['lsat']
        profile = ref_src.profile
        profile.update(dtype=rasterio.float32, count=1, nodata=-9999.0) # Output is probability (float)

        # Create the output file
        with rasterio.open(output_path, 'w', **profile) as dst:
            # Define processing windows (adjust block_size based on available RAM).
            # dst.block_windows() takes a band index, not a block shape, so fixed-size
            # windows are built manually; edge windows are clipped to the raster extent.
            block_size = 512 # Larger blocks may be faster if RAM allows
            windows = [
                Window(col_off, row_off,
                       min(block_size, dst.width - col_off),
                       min(block_size, dst.height - row_off))
                for row_off in range(0, dst.height, block_size)
                for col_off in range(0, dst.width, block_size)
            ]
            total_blocks = len(windows)
            processed_blocks = 0

            print(f"Processing {total_blocks} blocks...")
            for window in windows:
                processed_blocks += 1
                if processed_blocks % max(1, total_blocks // 10) == 0: # Print progress roughly 10 times
                    print(f"  Processing block {processed_blocks}/{total_blocks}...")

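                # Read data for the block from all sources as float arrays.
                # masked=True masks each raster's nodata pixels (requires nodata to be set
                # in the file metadata); filling them with NaN lets the NaN check below
                # exclude them from prediction.
                block_data = {}
                # Read Landsat (all bands)
                block_data['lsat'] = handles['lsat'].read(window=window, masked=True).astype(np.float32).filled(np.nan)
                # Read single band rasters
                for name in ['temp_tm1', 'prec_tm1', 'temp_t', 'prec_t', 'dem', 'slope']:
                    block_data[name] = handles[name].read(1, window=window, masked=True).astype(np.float32).filled(np.nan)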

                block_h, block_w = block_data['lsat'].shape[1], block_data['lsat'].shape[2]
                if block_h == 0 or block_w == 0: continue # Skip empty blocks
                n_pixels_in_block = block_h * block_w

                # --- Assemble features for the block ---
                # Important: feature order must match training order (FEATURE_NAMES).
                # Climate features may carry a relative lag suffix ('tm1' / 't') or the
                # matching calendar year; both map to the t-1 / t rasters.
                tm1_suffixes = {'tm1', str(t_minus_1)}
                t_suffixes = {'t', str(year_to_predict)}
                block_features_list = []
                for feature_name in feature_names_ordered:
                    if feature_name.startswith('lsat_'):
                        # Landsat features are named 'lsat_<band>_<suffix>' (always year t-1)
                        band_name = feature_name.split('_')[1]
                        if band_name not in LANDSAT_BANDS:
                            raise ValueError(f"Unknown Landsat band '{band_name}' in feature name list.")
                        feature_data = block_data['lsat'][LANDSAT_BANDS.index(band_name)].ravel()
                    elif feature_name.startswith(('temp_', 'precip_')):
                        var_name, suffix = feature_name.split('_', 1)
                        key_prefix = 'temp' if var_name == 'temp' else 'prec'
                        if suffix in tm1_suffixes:
                            feature_data = block_data[f'{key_prefix}_tm1'].ravel()
                        elif suffix in t_suffixes:
                            feature_data = block_data[f'{key_prefix}_t'].ravel()
                        else:
                            raise ValueError(f"Unexpected suffix in climate feature: {feature_name}")
                    elif feature_name in ('dem', 'slope'):
                        feature_data = block_data[feature_name].ravel()
                    else:
                        raise ValueError(f"Unknown feature name structure: {feature_name}")
                    block_features_list.append(feature_data)

                # Combine into (n_pixels, n_features) array
                block_features_array = np.vstack(block_features_list).T

                # --- Handle potential NoData/NaN values ---
                # Check for NaNs across all features for each pixel
                valid_pixel_mask = ~np.isnan(block_features_array).any(axis=1)
                # Create output block array, fill with NoData initially
                out_block = np.full(n_pixels_in_block, profile['nodata'], dtype=np.float32)

                # Predict only for valid pixels
                if np.any(valid_pixel_mask):
                    valid_features = block_features_array[valid_pixel_mask, :]
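                    # Predict probability of class 1 (loss); wrapping the array in a DataFrame with
                    # the training column names avoids scikit-learn's missing-feature-names warning
                    pred_proba = model.predict_proba(pd.DataFrame(valid_features, columns=feature_names_ordered))[:, 1]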
                    # Fill predictions into the output block based on the valid mask
                    out_block[valid_pixel_mask] = pred_proba

                # Reshape back to 2D block and write
                out_block = out_block.reshape(block_h, block_w)
                dst.write(out_block.astype(rasterio.float32), 1, window=window)

        # Close all source file handles
        for handle in handles.values():
            handle.close()

        print(f"Probability map saved to: {output_path}")
        print(f"Prediction map generation time: {time.time() - prediction_start_time:.2f}s")

    except FileNotFoundError as e:
        print(f"ERROR generating prediction map: Input file not found - {e}")
    except MemoryError:
        print("ERROR generating prediction map: Insufficient memory. Try smaller block_shape.")
    except Exception as e:
        print(f"ERROR generating prediction map: {type(e).__name__} - {e}")


# --- Example Call (run only if needed and model is trained) ---
# Check if model and feature names exist before running
# if 'rf_classifier' in locals() and 'FEATURE_NAMES' in locals() and Path(model_filename).exists():
#     PREDICTION_YEAR = 2021 # Choose a year from the test set range
#     prob_map_filename = OUTPUT_DIR / f'rf_loss_probability_{PREDICTION_YEAR}.tif'
#     # Optional: Load model if kernel restarted
#     # print(f"Loading model from {model_filename}...")
#     # loaded_model = joblib.load(model_filename)
#     # predict_loss_probability_map(loaded_model, FEATURE_NAMES, prob_map_filename, PREDICTION_YEAR)
#     # OR use the model in memory if kernel wasn't restarted:
#     predict_loss_probability_map(rf_classifier, FEATURE_NAMES, prob_map_filename, PREDICTION_YEAR)
# else:
#     print("\nModel or feature names not available. Skipping prediction map generation.")
#     if not 'rf_classifier' in locals(): print(" Reason: 'rf_classifier' not found.")
#     if not 'FEATURE_NAMES' in locals(): print(" Reason: 'FEATURE_NAMES' not found.")
#     if not Path(model_filename).exists(): print(f" Reason: Model file not found at {model_filename}")
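The core of the per-block step (stack bands into an `(n_pixels, n_features)` matrix, predict only where every feature is valid, write a nodata value elsewhere) can be tried without any rasters. The sketch below uses small synthetic arrays, a hypothetical nodata value of -9999 converted to NaN through a NumPy masked array, and a throwaway forest fitted on random data purely so there is something to predict with.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

NODATA = -9999.0  # hypothetical nodata value, used for inputs and output
rng = np.random.default_rng(1)

# Three synthetic 4x5 single-band "rasters"; one pixel of the first is nodata
bands = [rng.random((4, 5)) for _ in range(3)]
bands[0][2, 3] = NODATA

# Turn nodata into NaN via a masked array (similar to rasterio's read(masked=True))
bands = [np.ma.masked_equal(b, NODATA).filled(np.nan) for b in bands]

# Stack into (n_pixels, n_features); column order must match training order
block_matrix = np.vstack([b.ravel() for b in bands]).T
valid = ~np.isnan(block_matrix).any(axis=1)

# Throwaway model fitted on random data, only so there is something to predict with
toy_model = RandomForestClassifier(n_estimators=10, random_state=0)
toy_model.fit(rng.random((50, 3)), np.tile([0, 1], 25))

out = np.full(block_matrix.shape[0], NODATA, dtype=np.float32)
out[valid] = toy_model.predict_proba(block_matrix[valid])[:, 1]  # P(class 1) for valid pixels only
prob_block = out.reshape(4, 5)
print(prob_block[2, 3])  # -9999.0: the nodata pixel keeps the nodata value
```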

## 11. Conclusion & Next Steps

- Summarize findings from evaluation metrics (e.g., Accuracy, Precision, Recall, F1, ROC AUC) and the feature importance analysis. Which factors were most predictive of forest loss in this region according to the model?
- Discuss limitations:
    - **Correlation vs Causation:** The model identifies correlations, not necessarily causal links.
    - **Data Resolution:** 30m resolution might miss fine-scale drivers. Climate data resolution (TerraClimate ~4km) is much coarser than Landsat/Hansen.
    - **Loss Type:** Hansen data aggregates loss from various causes (logging, pests, wind, etc.). The model doesn't differentiate these.
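    - **Sampling Bias:** The non-loss sampling strategy might influence results. In particular, the sampled ratio of non-loss to loss pixels (together with `class_weight='balanced'`) means predicted probabilities do not reflect the true annual loss rate, so probability maps are best read as relative rather than absolute risk.
    - **Spatial Autocorrelation:** This analysis treats pixels independently, ignoring spatial relationships which are likely important in forest dynamics.
    - **Model Simplicity:** Random Forest is robust but might not capture complex spatio-temporal interactions as well as dedicated spatio-temporal models could, given more data and time.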
- Suggest potential future work:
    - Hyperparameter tuning of the Random Forest model (e.g., using GridSearchCV or RandomizedSearchCV).
    - Exploring other models (e.g., Gradient Boosting Machines like XGBoost/LightGBM, potentially simpler deep learning if data is very large).
    - Incorporating spatial features (e.g., distance to roads/mills, patch metrics, neighborhood characteristics).
    - Using focal statistics on predictor variables to capture neighborhood context.
    - Attempting to differentiate loss types if finer-grained data becomes available.
    - Analyzing model predictions spatially - where does the model perform well/poorly?
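For the hyperparameter-tuning item above, random K-fold cross-validation would mix target years and reintroduce the leakage the temporal split avoids. One option is to pass `RandomizedSearchCV` an explicit list of `(train_idx, val_idx)` folds, each validating on a later year than it trains on. The synthetic data, fold years, and parameter ranges below are illustrative assumptions only.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV

rng = np.random.default_rng(7)
# Synthetic training data spanning target years 2001-2015 (40 samples per year)
years_cv = np.repeat(np.arange(2001, 2016), 40)
X_cv = rng.normal(size=(years_cv.size, 4))
y_cv = (X_cv[:, 0] + rng.normal(scale=0.5, size=years_cv.size) > 0.8).astype(int)

# Expanding-window folds: train on all years before v, validate on year v
folds_cv = [(np.flatnonzero(years_cv < v), np.flatnonzero(years_cv == v))
            for v in range(2011, 2016)]

search = RandomizedSearchCV(
    RandomForestClassifier(class_weight='balanced', random_state=0),
    param_distributions={
        'n_estimators': [50, 100],
        'max_depth': [None, 5, 10],
        'min_samples_leaf': [1, 5, 20],
    },
    n_iter=5,
    scoring='roc_auc',
    cv=folds_cv,  # explicit folds instead of shuffled K-fold
    random_state=0,
)
search.fit(X_cv, y_cv)
print(search.best_params_, round(search.best_score_, 3))
```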