# Forest Loss Driver Analysis (Northern Maine Acadian Region)

**Goal:** Use a Random Forest model to identify the relative importance of
preceding spectral conditions and climate variables in predicting forest loss
events (from Hansen GFC) between 2001 and 2021.

**Data:**
- Annual Landsat C02 L2 median composites (2000-2021), Bands: B, G, R, NIR, SWIR1, SWIR2, NDVI, NBR
- Hansen GFC Loss Year (2000-2021 v1.9)
- Hansen GFC Tree Cover 2000
- Annual TerraClimate: Mean Temperature, Total Precipitation (2000-2021)
- SRTM DEM & Slope

**Methodology:**
1. Load and verify data alignment (CRS, Extent, Resolution).
2. Define sampling strategy: Sample loss pixels and non-loss pixels (within forest mask) annually.
3. Extract features for sampled pixels:
   - Target: Loss in year `t` (binary)
   - Features: Landsat(t-1), Climate(t-1), Climate(t), DEM, Slope
4. Split data temporally (e.g., train 2001-2015, test 2016-2021).
5. Train Random Forest Classifier.
6. Evaluate model performance.
7. Analyze feature importances.

## 1. Setup and Imports

In [3]:
import gc # Explicit garbage collection after large arrays are freed (used in later cells)
import os
import glob
from pathlib import Path
import time # Keep track of processing time

import numpy as np
import pandas as pd
import rasterio
from rasterio.windows import Window
from rasterio.enums import Resampling # if needed for consistency checks
# import geopandas as gpd # If needed for vector operations later
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay, roc_auc_score
from sklearn.utils import shuffle # For sampling
import matplotlib.pyplot as plt
import joblib # To save the trained model

# Optional: for parallel processing if sampling is slow
# import dask.array as da
# from dask.diagnostics import ProgressBar

print("Libraries imported.")

Libraries imported.


## 2. Configuration and Constants

In [None]:
# --- Project Structure ---
# Assuming the notebook is in 'notebooks/', navigate to the project root
try:
    PROJECT_ROOT = Path(os.getcwd()).parent
    # Basic check if 'data' dir exists relative to parent
    if not (PROJECT_ROOT / 'data').exists():
         # If run directly from project root maybe?
         PROJECT_ROOT = Path(os.getcwd())
         if not (PROJECT_ROOT / 'data').exists():
             raise FileNotFoundError("Could not determine project root. Expecting 'data' dir.")
except FileNotFoundError:
    # Fallback if structure is different, manual path needed
    PROJECT_ROOT = Path('/path/to/your/project/root') # <--- ADJUST MANUALLY IF NEEDED
    print("WARNING: Could not auto-detect project root. Set manually.")
    if not (PROJECT_ROOT / 'data').exists():
         raise FileNotFoundError(f"Data directory not found at: {PROJECT_ROOT / 'data'}")


DATA_DIR = PROJECT_ROOT / 'data'
LANDSAT_DIR = DATA_DIR / 'landsat'
HANSEN_DIR = DATA_DIR / 'hansen'
CLIMATE_DIR = DATA_DIR / 'climate'
AUX_DIR = DATA_DIR / 'aux'
OUTPUT_DIR = PROJECT_ROOT / 'output' # Create if you want to save models/plots
OUTPUT_DIR.mkdir(exist_ok=True)

print(f"Project Root: {PROJECT_ROOT}")
print(f"Data Directory: {DATA_DIR}")

# --- Data File Paths ---
HANSEN_LOSS_YEAR_PATH = HANSEN_DIR / 'hansen_lossyear_2000_2021.tif'
HANSEN_COVER_2000_PATH = HANSEN_DIR / 'hansen_treecover_2000.tif'
DEM_PATH = AUX_DIR / 'dem.tif'
SLOPE_PATH = AUX_DIR / 'slope.tif'

# --- Analysis Parameters ---
START_YEAR = 2001 # First year for which we can have t-1 predictors (Landsat 2000)
END_YEAR = 2021   # Last year of loss data
YEARS = list(range(START_YEAR, END_YEAR + 1))

# Define the bands we expect in the Landsat composites AFTER GEE processing
# IMPORTANT: Verify this order matches your actual GeoTIFF band order!
# Check using: `with rasterio.open(example_landsat_path) as src: print(src.descriptions)`
LANDSAT_BANDS = ['Blue', 'Green', 'Red', 'NIR', 'SWIR1', 'SWIR2', 'NDVI', 'NBR']
CLIMATE_VARS = ['mean_temp', 'total_precip']
STATIC_VARS = ['dem', 'slope'] # From AUX_DIR

# --- Sampling Parameters ---
# Ratio of non-loss points to sample for every 1 loss point
NON_LOSS_RATIO = 2 # Sample 2 non-loss points for every loss point
# Minimum tree cover percentage in 2000 to be considered 'forest' for sampling non-loss points
MIN_TREE_COVER = 30
# Random state for reproducibility
RANDOM_STATE = 42

# --- Modeling Parameters ---
TEST_SPLIT_YEAR = 2016 # Year to split data temporally
RF_N_ESTIMATORS = 150 # Number of trees in the forest (adjust later)
RF_MAX_DEPTH = None   # Let trees grow deep initially (adjust later)
RF_N_JOBS = -1        # Use all available CPU cores

Project Root: /Users/benjaminpace/MLCS/mlcs
Data Directory: /Users/benjaminpace/MLCS/mlcs/data


## 3. Data Loading and Verification

**CRITICAL STEP:** Verify that all input rasters have the *exact same* Coordinate Reference System (CRS), transform (affine), dimensions (width, height), and resolution.
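
The alignment rule can be sketched without any raster I/O. The snippet below is a minimal illustration, assuming each raster has been summarised as a plain dict; the dict layout, the helper name `find_alignment_mismatches`, the example CRS and the tolerance on the affine coefficients are all hypothetical and are not part of this notebook, whose own check uses strict equality.

```python
import math

def find_alignment_mismatches(reference, candidate, tol=1e-9):
    """Return the grid properties on which two raster profiles differ.

    Each profile is a dict with 'crs' (str), 'transform' (six affine
    coefficients), 'height' and 'width'. This layout is a hypothetical
    stand-in for the metadata read from the real files.
    """
    mismatches = []
    if reference["crs"] != candidate["crs"]:
        mismatches.append("crs")
    # Compare affine coefficients with a small absolute tolerance
    same_transform = all(
        math.isclose(a, b, rel_tol=0.0, abs_tol=tol)
        for a, b in zip(reference["transform"], candidate["transform"])
    )
    if not same_transform:
        mismatches.append("transform")
    ref_shape = (reference["height"], reference["width"])
    cand_shape = (candidate["height"], candidate["width"])
    if ref_shape != cand_shape:
        mismatches.append("shape")
    return mismatches

# Hypothetical profiles: a 30 m grid, the same grid shifted by half a pixel,
# and a coarse climate grid that was never resampled.
landsat = {"crs": "EPSG:32619",
           "transform": (30.0, 0.0, 400000.0, 0.0, -30.0, 5200000.0),
           "height": 800, "width": 1000}
shifted = dict(landsat, transform=(30.0, 0.0, 400015.0, 0.0, -30.0, 5200000.0))
coarse = dict(landsat, transform=(4000.0, 0.0, 400000.0, 0.0, -4000.0, 5200000.0),
              height=6, width=8)

for name, profile in [("shifted", shifted), ("coarse", coarse)]:
    print(name, find_alignment_mismatches(landsat, profile))
```

A half-pixel shift leaves CRS and shape identical, which is why the transform has to be compared as well: row/column indexing only refers to the same ground location when all three agree.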

In [None]:
def verify_raster_alignment(raster_paths):
    """Checks CRS, transform, shape of multiple rasters."""
    print("Verifying raster alignment...")
    if not raster_paths:
        print("No raster paths provided.")
        return None

    reference_profile = None
    is_aligned = True
    checked_paths = [] # Keep track of files actually checked

    for i, path_obj in enumerate(raster_paths):
        path = str(path_obj) # Ensure it's a string for rasterio
        if not Path(path).exists():
            print(f"WARNING: File not found, skipping alignment check: {path}")
            # Decide if this should be a fatal error
            # is_aligned = False # Uncomment if missing file means failure
            continue # Skip to next file

        checked_paths.append(path)
        try:
            with rasterio.open(path) as src:
                profile = {
                    'path': Path(path).name,
                    'crs': src.crs,
                    'transform': src.transform,
                    'width': src.width,
                    'height': src.height,
                    'count': src.count, # Number of bands
                    'dtype': src.dtypes[0] # Data type of the first band (rasterio exposes `dtypes`, one entry per band)
                }
                print(f"--- Checking: {profile['path']} ---")
                # print(f"  CRS: {profile['crs']}")
                # print(f"  Transform: {profile['transform']}")
                # print(f"  Shape: ({profile['height']}, {profile['width']})")
                # print(f"  Band Count: {profile['count']}")
                # print(f"  Data Type: {profile['dtype']}")


                if reference_profile is None:
                    reference_profile = profile
                    print(f"  Set as Reference: CRS={profile['crs']}, Shape=({profile['height']},{profile['width']}), Transform={profile['transform']}")
                else:
                    # Check alignment
                    if profile['crs'] != reference_profile['crs']:
                        print(f"  MISMATCH: CRS {profile['crs']} differs from reference {reference_profile['crs']}")
                        is_aligned = False
                    if profile['transform'] != reference_profile['transform']:
                        print(f"  MISMATCH: Transform differs from reference")
                        is_aligned = False
                    if profile['width'] != reference_profile['width'] or profile['height'] != reference_profile['height']:
                        print(f"  MISMATCH: Shape ({profile['height']}, {profile['width']}) differs from reference ({reference_profile['height']}, {reference_profile['width']})")
                        is_aligned = False
        except Exception as e:
            print(f"ERROR reading {path}: {e}")
            is_aligned = False

    if not reference_profile:
         print("\nERROR: No valid reference raster found to check alignment against.")
         return None

    if is_aligned and checked_paths:
        print(f"\nSUCCESS: All {len(checked_paths)} checked rasters appear aligned.")
        # Return the common profile derived from the reference
        return {k: v for k, v in reference_profile.items() if k != 'path'}
    elif not checked_paths:
        print("\nWARNING: No files were actually checked (all missing?).")
        return None
    else:
        print("\nERROR: Raster alignment check failed. Please fix data before proceeding.")
        return None

In [None]:
# --- Gather all files to check ---
# Get one example Landsat and Climate file for checking structure
# Use year 2000 for t-1 predictors of 2001 loss
example_landsat_path = LANDSAT_DIR / 'landsat_composite_2000.tif'
example_temp_path = CLIMATE_DIR / 'mean_temp_2000.tif'
example_precip_path = CLIMATE_DIR / 'total_precip_2000.tif'

files_to_check = [
    HANSEN_LOSS_YEAR_PATH,
    HANSEN_COVER_2000_PATH,
    DEM_PATH,
    SLOPE_PATH,
    example_landsat_path,
    example_temp_path,
    example_precip_path,
]

# Add climate files for a year within the main analysis range too (e.g., 2001)
files_to_check.append(CLIMATE_DIR / f'mean_temp_{START_YEAR}.tif')
files_to_check.append(CLIMATE_DIR / f'total_precip_{START_YEAR}.tif')


# --- Run the check ---
# This requires data to be downloaded and placed correctly.
# If this cell fails, DO NOT proceed. Fix the data first.
common_profile = verify_raster_alignment(files_to_check)

assert common_profile is not None, "Raster alignment failed. Stopping execution."

# --- Store key dimensions for later use ---
RASTER_HEIGHT = common_profile['height']
RASTER_WIDTH = common_profile['width']
RASTER_TRANSFORM = common_profile['transform']
RASTER_CRS = common_profile['crs']

print(f"\nCommon Raster Shape: ({RASTER_HEIGHT}, {RASTER_WIDTH})")
print(f"Common CRS: {RASTER_CRS}")
print(f"Common Transform: {RASTER_TRANSFORM}")

# --- Quick check on Landsat band count ---
try:
    with rasterio.open(example_landsat_path) as src:
        if src.count != len(LANDSAT_BANDS):
             print(f"\nWARNING: Expected {len(LANDSAT_BANDS)} Landsat bands based on LANDSAT_BANDS variable, but found {src.count} in {example_landsat_path.name}")
             print(f"Actual band descriptions: {src.descriptions}") # Requires descriptions set in GeoTIFF
        else:
             print(f"\nLandsat band count ({src.count}) matches expected count.")
except Exception as e:
    print(f"\nCould not perform Landsat band count check: {e}")

## 4. Sampling Strategy

For each year `t` from 2001 to 2021:
1. Identify pixels where loss occurred exactly in year `t`.
2. Identify pixels that are potential non-loss candidates (forested in 2000, no loss 2001-2021).
3. Sample loss pixels.
4. Sample `NON_LOSS_RATIO` times as many non-loss pixels randomly from the candidates.
5. Store pixel coordinates (row, col) and associated target year/loss status.
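
The steps above can be sketched on small in-memory arrays. This is an illustrative sketch only: the function name `sample_year`, the synthetic rasters and the use of `np.random.default_rng` are assumptions made for the example, not the notebook's implementation.

```python
import numpy as np

def sample_year(loss_year, tree_cover, target_year, ratio=2, min_cover=30, seed=42):
    """Return (rows, cols, labels) for one target year.

    loss_year uses Hansen-style codes: 0 = no loss, 1 -> 2001, ..., 21 -> 2021.
    """
    rng = np.random.default_rng(seed)
    code = target_year - 2000
    # Loss pixels: loss recorded exactly in the target year
    loss_rows, loss_cols = np.nonzero(loss_year == code)
    # Control candidates: forested in 2000 and never lost
    cand_rows, cand_cols = np.nonzero((tree_cover >= min_cover) & (loss_year == 0))
    n_controls = min(cand_rows.size, loss_rows.size * ratio)
    pick = rng.choice(cand_rows.size, size=n_controls, replace=False)
    rows = np.concatenate([loss_rows, cand_rows[pick]])
    cols = np.concatenate([loss_cols, cand_cols[pick]])
    labels = np.concatenate([np.ones(loss_rows.size, dtype=np.uint8),
                             np.zeros(n_controls, dtype=np.uint8)])
    return rows, cols, labels

# Tiny synthetic rasters (hypothetical values)
loss_year = np.array([[0, 0, 5, 0],
                      [0, 5, 0, 0],
                      [0, 0, 0, 12],
                      [0, 0, 0, 0]], dtype=np.uint8)
tree_cover = np.full((4, 4), 80, dtype=np.uint8)
tree_cover[3, :] = 10  # bottom row falls below the forest threshold

rows, cols, labels = sample_year(loss_year, tree_cover, target_year=2005)
print("loss pixels:", int(labels.sum()), "controls:", int((labels == 0).sum()))
```

Controls are drawn without replacement within a year, but the same control pixel can still be drawn again for a different target year, where it is paired with that year's predictors.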

In [None]:
# Placeholder for sampled pixel coordinates and their corresponding year/loss status
# Format: list of dicts [{'row': r, 'col': c, 'target_year': t, 'loss': 1}, ...]
all_sampled_points = []

print("Starting pixel sampling...")
sampling_start_time = time.time()

try:
    print("Loading base data for sampling (Hansen Loss Year and Cover 2000)...")
    with rasterio.open(HANSEN_LOSS_YEAR_PATH) as loss_src:
        # Verify shape matches common profile before reading
        assert (loss_src.height, loss_src.width) == (RASTER_HEIGHT, RASTER_WIDTH), "Loss year shape mismatch!"
        loss_year_data = loss_src.read(1)
        print(f"  Loaded Loss Year data ({loss_src.height}x{loss_src.width})")

    with rasterio.open(HANSEN_COVER_2000_PATH) as cover_src:
        assert (cover_src.height, cover_src.width) == (RASTER_HEIGHT, RASTER_WIDTH), "Cover shape mismatch!"
        cover_2000_data = cover_src.read(1)
        print("  Loaded Tree Cover 2000 data")

    # --- Create mask for non-loss candidate pixels ---
    # Condition 1: Sufficient tree cover in 2000
    forest_mask = cover_2000_data >= MIN_TREE_COVER
    # Condition 2: No loss recorded between 2001 and 2021 (loss year == 0 in Hansen)
    no_loss_mask = loss_year_data == 0
    # Combined mask for pixels eligible to be sampled as "non-loss" controls
    non_loss_candidate_mask = forest_mask & no_loss_mask
    # Get indices (row, col arrays) where mask is True
    non_loss_candidate_indices = np.where(non_loss_candidate_mask)
    num_non_loss_candidates = len(non_loss_candidate_indices[0])
    print(f"Found {num_non_loss_candidates} potential non-loss candidate pixels.")

    # Check if we have candidates to sample from
    if num_non_loss_candidates == 0:
        raise ValueError("No non-loss candidate pixels found based on criteria. Cannot sample.")

    # --- Sample pixels year by year ---
    np.random.seed(RANDOM_STATE) # for reproducibility

    for year_num, target_year in enumerate(YEARS):
        start_time_year = time.time()
        print(f"\n--- Sampling for target year: {target_year} ---")

        # 1. Find pixels with loss IN THIS target_year
        # Hansen loss year codes 1-21 correspond to years 2001-2021
        hansen_loss_code = target_year - 2000
        loss_pixels_this_year_mask = loss_year_data == hansen_loss_code
        loss_indices_this_year = np.where(loss_pixels_this_year_mask)
        num_loss_pixels = len(loss_indices_this_year[0])
        print(f"  Found {num_loss_pixels} loss pixels.")

        if num_loss_pixels == 0:
            print("  No loss pixels found for this year. Skipping.")
            continue

        # 2. Add loss pixels to sample list
        for r, c in zip(loss_indices_this_year[0], loss_indices_this_year[1]):
            all_sampled_points.append({'row': r, 'col': c, 'target_year': target_year, 'loss': 1})

        # 3. Sample non-loss pixels
        # Calculate how many non-loss points to sample for this year
        num_non_loss_to_sample = min(num_non_loss_candidates, num_loss_pixels * NON_LOSS_RATIO)
        print(f"  Sampling {num_non_loss_to_sample} non-loss pixels.")

        if num_non_loss_to_sample > 0:
            # Randomly choose indices from the non_loss_candidate_indices array
            # This samples indices *of the indices array*, not row/col directly
            sampled_candidate_indices_idx = np.random.choice(
                num_non_loss_candidates, num_non_loss_to_sample, replace=False
            )
            # Get the actual row and column values using these sampled indices
            sampled_non_loss_rows = non_loss_candidate_indices[0][sampled_candidate_indices_idx]
            sampled_non_loss_cols = non_loss_candidate_indices[1][sampled_candidate_indices_idx]

            # Add non-loss pixels to sample list
            for r, c in zip(sampled_non_loss_rows, sampled_non_loss_cols):
                 # Target year is still the year we are comparing against
                all_sampled_points.append({'row': r, 'col': c, 'target_year': target_year, 'loss': 0})

        print(f"  Finished sampling for {target_year}. Points this year: {num_loss_pixels + num_non_loss_to_sample}. Time: {time.time() - start_time_year:.2f}s")

    # --- Cleanup ---
    del loss_year_data, cover_2000_data, forest_mask, no_loss_mask, non_loss_candidate_mask
    del loss_pixels_this_year_mask, loss_indices_this_year, non_loss_candidate_indices
    import gc
    gc.collect()

    # Shuffle the collected points for randomness if needed later (e.g., batch processing)
    # Though temporal split doesn't strictly require shuffling here
    all_sampled_points = shuffle(all_sampled_points, random_state=RANDOM_STATE)
    print(f"\nTotal points sampled across all years: {len(all_sampled_points)}")
    print(f"Sampling finished. Total time: {time.time() - sampling_start_time:.2f}s")

    # Optional: Quick check on class balance
    loss_counts = pd.Series([p['loss'] for p in all_sampled_points]).value_counts()
    print(f"Sampled loss counts:\n{loss_counts}")

except FileNotFoundError as e:
     print(f"ERROR: Required file not found during sampling: {e}. Please ensure data exists.")
except ValueError as e:
     print(f"ERROR during sampling setup: {e}")
except Exception as e:
    print(f"An unexpected error occurred during sampling: {e}")
    # Depending on the error, you might want to stop execution
    # raise e

## 5. Feature Extraction

Iterate through the `all_sampled_points`. For each point (row, col, target_year, loss):
- Read pixel values from Landsat(target_year - 1).
- Read pixel values from Climate(target_year - 1).
- Read pixel values from Climate(target_year).
- Read pixel values from DEM and Slope.
- Store features and target in a structure suitable for Scikit-learn (e.g., NumPy array or Pandas DataFrame).
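
As a rough sketch of the target layout, the example below builds the feature matrix for one target year with NumPy fancy indexing instead of per-pixel reads. It assumes the rasters for that year are already in memory as `(bands, height, width)` arrays; the helper names and the random test data are hypothetical. Columns are named by their offset from the target year (`tm1` for t-1, `t` for the target year) so that every year fills the same columns.

```python
import numpy as np

LANDSAT_BANDS = ["Blue", "Green", "Red", "NIR", "SWIR1", "SWIR2", "NDVI", "NBR"]

def feature_names():
    """Column names that do not depend on the calendar year."""
    names = [f"lsat_{band}_tm1" for band in LANDSAT_BANDS]
    names += ["temp_tm1", "precip_tm1", "temp_t", "precip_t", "dem", "slope"]
    return names

def extract_features(landsat_tm1, climate_tm1, climate_t, static, rows, cols):
    """Stack predictors into an (n_points, n_features) matrix.

    Every input is a (bands, height, width) array on the same grid.
    """
    stack = np.concatenate([landsat_tm1, climate_tm1, climate_t, static], axis=0)
    # Indexing with row/col arrays picks one value per band per point
    return stack[:, rows, cols].T

# Hypothetical 5x6 rasters filled with random values
rng = np.random.default_rng(0)
height, width = 5, 6
landsat_tm1 = rng.random((8, height, width))
climate_tm1 = rng.random((2, height, width))
climate_t = rng.random((2, height, width))
static = rng.random((2, height, width))

rows = np.array([0, 2, 4])
cols = np.array([1, 3, 5])
X_demo = extract_features(landsat_tm1, climate_tm1, climate_t, static, rows, cols)
print(X_demo.shape, len(feature_names()))
```

Reading whole bands trades memory for speed; for rasters too large to hold in memory, the same indexing can be applied block by block.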

In [None]:
# Prepare for feature extraction
features = []
# Target 'y' is implicitly defined by the 'loss' key in all_sampled_points

# --- Need raster dimensions and transform from verification step ---
# RASTER_HEIGHT, RASTER_WIDTH, RASTER_TRANSFORM, RASTER_CRS are already defined if verification passed

print("Starting feature extraction...")
extraction_start_time = time.time()

# Cache for opened rasterio file handles
open_files_cache = {}

def get_file_handle(path, cache):
    """Gets or opens a rasterio file handle, caching it."""
    path_str = str(path) # Use string representation as key
    if path_str not in cache:
        if not path.exists():
             raise FileNotFoundError(f"Required file for feature extraction not found: {path_str}")
        print(f"  Opening file: {path.name}")
        cache[path_str] = rasterio.open(path_str)
    return cache[path_str]

def get_pixel_value_at_rc(src, row, col):
    """Reads pixel value(s) for all bands at a specific row, col."""
    # Use Window(col_off, row_off, width, height)
    window = Window(col, row, 1, 1)
    # read() returns shape (bands, rows, cols) -> (bands, 1, 1)
    # squeeze removes the singleton dimensions -> (bands,) or scalar if 1 band
    return src.read(window=window).squeeze()

# --- Process all sampled points ---
extraction_errors = 0
for i, point in enumerate(all_sampled_points):
    row, col = point['row'], point['col']
    t_year = point['target_year'] # Year loss status is defined for
    t_minus_1 = t_year - 1      # Year for predictor data

    if i > 0 and i % 20000 == 0: # Print progress periodically
         elapsed_time = time.time() - extraction_start_time
         points_per_sec = i / elapsed_time if elapsed_time > 0 else 0
         print(f"  Processed {i}/{len(all_sampled_points)} points... ({points_per_sec:.1f} points/sec)")

    # --- Feature dictionary for this point ---
    # Start with static info to ensure it's always present if point is processed
    point_features = {'row': row, 'col': col, 'target_year': t_year, 'loss': point['loss']}

    try:
        # NOTE: Feature keys are named relative to the target year ('tm1' = t-1, 't' = target year).
        # Embedding the calendar year in the key would create a separate column per year,
        # leave NaN in every other year's columns, and make the later dropna() discard all rows.

        # --- Landsat (t-1) ---
        lsat_path = LANDSAT_DIR / f'landsat_composite_{t_minus_1}.tif'
        lsat_src = get_file_handle(lsat_path, open_files_cache)
        lsat_values = get_pixel_value_at_rc(lsat_src, row, col)
        # Check if return is scalar (single band) or array (multiple bands)
        if lsat_values.ndim == 0: # Scalar if only one band read unexpectedly
             if len(LANDSAT_BANDS) == 1:
                 point_features[f'lsat_{LANDSAT_BANDS[0]}_tm1'] = lsat_values
             else:
                 raise ValueError(f"Read scalar Landsat value but expected {len(LANDSAT_BANDS)} bands.")
        else: # Should be an array of band values
             if len(lsat_values) != len(LANDSAT_BANDS):
                 raise ValueError(f"Read {len(lsat_values)} Landsat bands but expected {len(LANDSAT_BANDS)}.")
             for band_name, band_value in zip(LANDSAT_BANDS, lsat_values):
                 point_features[f'lsat_{band_name}_tm1'] = band_value

        # --- Climate (t-1) ---
        temp_tm1_path = CLIMATE_DIR / f'mean_temp_{t_minus_1}.tif'
        precip_tm1_path = CLIMATE_DIR / f'total_precip_{t_minus_1}.tif'
        temp_tm1_src = get_file_handle(temp_tm1_path, open_files_cache)
        precip_tm1_src = get_file_handle(precip_tm1_path, open_files_cache)
        point_features['temp_tm1'] = get_pixel_value_at_rc(temp_tm1_src, row, col)
        point_features['precip_tm1'] = get_pixel_value_at_rc(precip_tm1_src, row, col)

        # --- Climate (t) ---
        temp_t_path = CLIMATE_DIR / f'mean_temp_{t_year}.tif'
        precip_t_path = CLIMATE_DIR / f'total_precip_{t_year}.tif'
        temp_t_src = get_file_handle(temp_t_path, open_files_cache)
        precip_t_src = get_file_handle(precip_t_path, open_files_cache)
        point_features['temp_t'] = get_pixel_value_at_rc(temp_t_src, row, col)
        point_features['precip_t'] = get_pixel_value_at_rc(precip_t_src, row, col)

        # --- Static Vars ---
        dem_src = get_file_handle(DEM_PATH, open_files_cache)
        slope_src = get_file_handle(SLOPE_PATH, open_files_cache)
        point_features['dem'] = get_pixel_value_at_rc(dem_src, row, col)
        point_features['slope'] = get_pixel_value_at_rc(slope_src, row, col)

        # --- Append the complete feature dictionary ---
        features.append(point_features)

    except FileNotFoundError as e:
         # This error should have been caught by get_file_handle but handle defensively
         print(f"WARNING: Skipping point {i} due to missing file: {e}")
         extraction_errors += 1
         continue # Skip this point
    except Exception as e:
        print(f"ERROR extracting features for point {i} (row={row}, col={col}, target_year={t_year}): {type(e).__name__} - {e}")
        # Decide how to handle errors: skip point, fill with NaN? Here we skip.
        extraction_errors += 1
        continue

# --- Close all opened files in the cache ---
print("\nClosing cached raster files...")
for path_str, src in open_files_cache.items():
    try:
        src.close()
    except Exception as e:
        print(f"Error closing file {path_str}: {e}")
open_files_cache = {} # Clear the cache

print(f"Finished feature extraction. Total time: {time.time() - extraction_start_time:.2f}s")
print(f"Processed {len(features)} points successfully.")
if extraction_errors > 0:
    print(f"Encountered {extraction_errors} errors during extraction (points skipped).")

# --- Convert list of dictionaries to DataFrame ---
if not features:
    raise SystemExit("No features were extracted successfully. Cannot proceed.")

print("\nConverting extracted features to DataFrame...")
feature_df = pd.DataFrame(features)
del features # Free up memory
gc.collect()

print(f"DataFrame shape: {feature_df.shape}")
# Display sample data and check for initial NaNs which might indicate read errors or NoData values
print("Sample data:")
print(feature_df.head())
print("\nCheck for missing values before cleaning:")
print(feature_df.isnull().sum())

# --- Handle Missing Data (if any) ---
# dropna() removes rows containing NaN only. NoData stored as a sentinel value (e.g., -9999)
# is read as an ordinary number and must be masked separately before modeling.
initial_rows = len(feature_df)
feature_df = feature_df.dropna()
final_rows = len(feature_df)

if initial_rows != final_rows:
    print(f"\nDropped {initial_rows - final_rows} rows containing NaN values.")
    print(f"Final DataFrame shape after dropna: {feature_df.shape}")
else:
    print("\nNo rows dropped due to NaN values.")

if final_rows == 0:
     raise SystemExit("All rows contained NaN values after extraction. Check input data quality.")

# --- Prepare final X and y for modeling ---
target_series = feature_df['loss']
year_series = feature_df['target_year'] # Keep track of year for temporal split
# Drop non-feature columns (row, col might be useful later for mapping, keep if needed)
X = feature_df.drop(columns=['row', 'col', 'target_year', 'loss'])
y = target_series

# Store feature names in the order they appear in X for later use (importance plot)
FEATURE_NAMES = X.columns.tolist()

print("\nFinal feature columns (X):", FEATURE_NAMES)
print("Target distribution (y):")
print(y.value_counts(normalize=True))


## 6. Train/Test Split (Temporal)

Split the data based on the `target_year`. Data from years before `TEST_SPLIT_YEAR` will be used for training, and data from `TEST_SPLIT_YEAR` onwards for testing. This prevents data leakage from the future into the training set.
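
A minimal sketch of the split rule and the leakage check it implies, using a hypothetical array of target years (the helper name `temporal_split` is an assumption for illustration):

```python
import numpy as np

def temporal_split(years, split_year):
    """Return boolean (train_mask, test_mask) for a year-based split."""
    years = np.asarray(years)
    train_mask = years < split_year
    return train_mask, ~train_mask

# Hypothetical target years for six sampled points
years = np.array([2001, 2015, 2016, 2021, 2010, 2016])
train_mask, test_mask = temporal_split(years, split_year=2016)

# Every training year must precede every test year
latest_train = years[train_mask].max()
earliest_test = years[test_mask].min()
print(f"latest train year: {latest_train}, earliest test year: {earliest_test}")
```

Note that a temporal split does not by itself separate pixels spatially: a control pixel sampled in a training year can reappear in a test year with different predictors.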


In [None]:
print(f"Splitting data temporally based on year {TEST_SPLIT_YEAR}...")

# Use the 'year_series' created alongside X and y
train_mask = year_series < TEST_SPLIT_YEAR
test_mask = year_series >= TEST_SPLIT_YEAR

X_train = X[train_mask]
y_train = y[train_mask]
X_test = X[test_mask]
y_test = y[test_mask]

# Keep track of years for verification (optional)
# train_years = year_series[train_mask]
# test_years = year_series[test_mask]

# Cleanup intermediate objects if memory is tight
# del X, y, feature_df, year_series, train_mask, test_mask
# gc.collect()

print(f"Train set shape: X={X_train.shape}, y={y_train.shape}")
print(f"Test set shape: X={X_test.shape}, y={y_test.shape}")

# Verify the split worked as expected
if not X_train.empty and not X_test.empty:
    print(f"Train target distribution:\n{y_train.value_counts(normalize=True)}")
    print(f"Test target distribution:\n{y_test.value_counts(normalize=True)}")
    # print(f"Train years range: {train_years.min()} - {train_years.max()}")
    # print(f"Test years range: {test_years.min()} - {test_years.max()}")
else:
    print("WARNING: Train or Test set is empty after split. Check TEST_SPLIT_YEAR and data distribution.")

## 7. Model Training (Random Forest)

Train a Random Forest Classifier on the training data. We use `class_weight='balanced'` to help handle potential class imbalance between loss and no-loss pixels.
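
For reference, scikit-learn documents the `'balanced'` mode as weighting each class by `n_samples / (n_classes * np.bincount(y))`. The sketch below reproduces that arithmetic for a hypothetical label vector with the 2:1 control-to-loss ratio used in sampling; the helper name is an assumption, and it expects integer labels `0..k-1` with every class present.

```python
import numpy as np

def balanced_class_weights(y):
    """Per-class weights following n_samples / (n_classes * bincount(y))."""
    y = np.asarray(y)
    counts = np.bincount(y)
    return y.size / (counts.size * counts)

# Hypothetical labels: 200 no-loss (0) and 100 loss (1) samples
y_demo = np.array([0] * 200 + [1] * 100)
weights = balanced_class_weights(y_demo)
print(f"no-loss weight: {weights[0]:.2f}, loss weight: {weights[1]:.2f}")
```

With this ratio the loss class is weighted twice as heavily as the no-loss class (1.50 versus 0.75), so both classes contribute the same total weight during training.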


In [None]:
# Check if training data exists
if X_train.empty or y_train.empty:
    raise SystemExit("Training data is empty. Cannot train model.")

print("Training Random Forest model...")
train_start_time = time.time()

rf_classifier = RandomForestClassifier(
    n_estimators=RF_N_ESTIMATORS,
    max_depth=RF_MAX_DEPTH,
    random_state=RANDOM_STATE,
    n_jobs=RF_N_JOBS,             # Use all available cores
    class_weight='balanced',      # Adjust for class imbalance
    oob_score=False               # Set to True to estimate generalization score without test set (slower)
)

# Train the model
rf_classifier.fit(X_train, y_train)

print(f"Finished training. Total time: {time.time() - train_start_time:.2f}s")

# Optional: Print OOB score if calculated
# if rf_classifier.oob_score:
#     print(f"Out-of-Bag (OOB) Score: {rf_classifier.oob_score_:.4f}")

# --- Save the trained model ---
model_filename = OUTPUT_DIR / f'rf_forest_loss_model_{START_YEAR}_{END_YEAR}.joblib'
try:
    joblib.dump(rf_classifier, model_filename)
    print(f"Model saved successfully to: {model_filename}")
except Exception as e:
    print(f"ERROR saving model: {e}")

## 8. Model Evaluation

Evaluate the trained model's performance on the temporally independent test set using various classification metrics.
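
To make the reported metrics concrete, the sketch below derives precision, recall, F1 and accuracy for the loss class from the four cells of a binary confusion matrix. The counts are hypothetical and are not results from this analysis; the cell order follows scikit-learn's `confusion_matrix` layout `[[tn, fp], [fn, tp]]`.

```python
def binary_metrics(tn, fp, fn, tp):
    """Metrics for the positive (loss) class from confusion-matrix counts."""
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    accuracy = (tp + tn) / (tn + fp + fn + tp)
    return {"precision": precision, "recall": recall, "f1": f1, "accuracy": accuracy}

# Hypothetical test-set counts
metrics = binary_metrics(tn=180, fp=20, fn=40, tp=60)
for name, value in metrics.items():
    print(f"{name}: {value:.3f}")
```

In this made-up case accuracy (0.800) looks better than recall on the loss class (0.600), which is why the per-class report and ROC AUC are worth reading alongside overall accuracy when classes are imbalanced.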


In [None]:
# Check if test data exists
if X_test.empty or y_test.empty:
    print("Test data is empty. Skipping evaluation.")
else:
    print("Evaluating model on the test set...")

    # Make predictions
    y_pred = rf_classifier.predict(X_test)
    y_pred_proba = rf_classifier.predict_proba(X_test)[:, 1] # Probability of class 1 (loss)

    # --- Classification Report ---
    print("\nClassification Report:")
    # Use zero_division=0 to avoid warnings if a class has no predicted samples
    print(classification_report(y_test, y_pred, target_names=['No Loss', 'Loss'], zero_division=0))

    # --- Confusion Matrix ---
    print("\nConfusion Matrix:")
    try:
        cm = confusion_matrix(y_test, y_pred)
        print(cm)
        disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['No Loss', 'Loss'])
        fig, ax = plt.subplots(figsize=(6, 6))
        disp.plot(cmap=plt.cm.Blues, ax=ax)
        plt.title('Confusion Matrix')
        plt.show()
    except Exception as e:
        print(f"Could not generate Confusion Matrix plot: {e}")


    # --- ROC AUC Score ---
    # Check if both classes are present in y_test for ROC AUC calculation
    if len(np.unique(y_test)) > 1:
        roc_auc = roc_auc_score(y_test, y_pred_proba)
        print(f"\nROC AUC Score: {roc_auc:.4f}")

        # Optional: Plot ROC Curve
        try:
            from sklearn.metrics import RocCurveDisplay
            fig, ax = plt.subplots(figsize=(6, 6))
            RocCurveDisplay.from_predictions(y_test, y_pred_proba, ax=ax, name='Random Forest')
            plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Chance')
            plt.title('Receiver Operating Characteristic (ROC) Curve')
            plt.legend()
            plt.show()
        except Exception as e:
            print(f"Could not generate ROC Curve plot: {e}")
    else:
        print("\nROC AUC Score cannot be calculated: Only one class present in the test set.")

## 9. Feature Importance Analysis

Analyze the feature importances provided by the Random Forest model (based on mean decrease in impurity, also called Gini importance) to understand which variables were most influential in the model's predictions.
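
As a worked illustration of what "decrease in impurity" means, the sketch below computes the Gini impurity drop for a single hypothetical split. The class counts are made up. In scikit-learn, decreases like this one are weighted by the share of samples reaching each node, summed per feature, averaged over trees and normalised to sum to 1.

```python
def gini(n_neg, n_pos):
    """Gini impurity of a node holding n_neg no-loss and n_pos loss samples."""
    n = n_neg + n_pos
    if n == 0:
        return 0.0
    p = n_pos / n
    return 2 * p * (1 - p)  # equals 1 - p**2 - (1 - p)**2 for two classes

def impurity_decrease(parent, left, right):
    """Parent impurity minus the size-weighted impurity of its two children."""
    n_parent = sum(parent)
    weighted_children = (sum(left) * gini(*left) + sum(right) * gini(*right)) / n_parent
    return gini(*parent) - weighted_children

# Hypothetical split: (no-loss count, loss count) per node
parent = (200, 100)
left = (180, 20)   # e.g. samples below some threshold on one feature
right = (20, 80)   # samples at or above the threshold

decrease = impurity_decrease(parent, left, right)
print(f"parent Gini: {gini(*parent):.4f}, decrease: {decrease:.4f}")
```

Because the measure is computed on training data, it tends to favour continuous or high-cardinality features and to share credit among correlated predictors (for example NDVI and NBR), so the ranking is best read as relative rather than causal.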



In [None]:
print("Calculating and plotting Feature Importances...")

try:
    importances = rf_classifier.feature_importances_
    # Get standard deviations of importances across trees (optional, adds compute time)
    # std = np.std([tree.feature_importances_ for tree in rf_classifier.estimators_], axis=0)
    indices = np.argsort(importances)[::-1] # Sort features by importance (descending)

    # --- Create DataFrame for easier plotting/viewing ---
    importance_df = pd.DataFrame({
        'Feature': [FEATURE_NAMES[i] for i in indices],
        'Importance': importances[indices],
        # 'StdDev': std[indices] # Uncomment if std is calculated
    })

    print("\nTop 20 Feature Importances:")
    print(importance_df.head(20))

    # --- Plot Feature Importances ---
    N_FEATURES_TO_PLOT = min(20, len(importance_df)) # Never plot more bars than there are features
    plt.figure(figsize=(12, max(6, N_FEATURES_TO_PLOT // 2))) # Adjust height based on number of features
    plt.title(f"Feature Importances (Top {N_FEATURES_TO_PLOT})")
    plt.barh(range(N_FEATURES_TO_PLOT), # Use barh for horizontal plot
             importance_df['Importance'].head(N_FEATURES_TO_PLOT)[::-1], # Plot descending importance
             # xerr=importance_df['StdDev'].head(N_FEATURES_TO_PLOT)[::-1], # Uncomment if std is calculated
             align='center')
    plt.yticks(range(N_FEATURES_TO_PLOT),
               importance_df['Feature'].head(N_FEATURES_TO_PLOT)[::-1]) # Labels for y-axis
    plt.xlabel("Mean Decrease in Impurity (Gini Importance)")
    plt.ylabel("Feature")
    plt.ylim([-1, N_FEATURES_TO_PLOT])
    plt.tight_layout()
    plt.show()

    # Optional: Save importance data
    importance_filename = OUTPUT_DIR / f'rf_feature_importances_{START_YEAR}_{END_YEAR}.csv'
    importance_df.to_csv(importance_filename, index=False)
    print(f"Feature importances saved to: {importance_filename}")

except Exception as e:
    print(f"An error occurred during feature importance calculation or plotting: {e}")

## 10. (Optional) Prediction Map Visualization

This section provides a function to generate a spatial map of predicted loss probability for a given year using the trained model. This is computationally intensive and requires significant memory/time.
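
The memory cost is kept bounded by processing the raster in blocks. The sketch below shows one way to tile a grid into windows in plain Python; the generator name and the example dimensions are hypothetical. Each tuple could be passed to rasterio as `Window(col_off, row_off, width, height)`.

```python
def iter_block_windows(height, width, block_size=512):
    """Yield (row_off, col_off, n_rows, n_cols) tiles covering the full grid.

    Edge tiles are clipped so that no window extends past the raster.
    """
    for row_off in range(0, height, block_size):
        for col_off in range(0, width, block_size):
            n_rows = min(block_size, height - row_off)
            n_cols = min(block_size, width - col_off)
            yield row_off, col_off, n_rows, n_cols

# Hypothetical 1200 x 1000 raster split into 512-pixel blocks
windows = list(iter_block_windows(height=1200, width=1000, block_size=512))
covered = sum(n_rows * n_cols for _, _, n_rows, n_cols in windows)
print(f"{len(windows)} windows covering {covered} pixels")
```

Peak memory per block is roughly `block_size**2 * n_features * bytes_per_value`, so the block size can be chosen from the available RAM.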


In [None]:
# --- This section is computationally intensive ---

def predict_loss_probability_map(model, feature_names_ordered, output_path, year_to_predict):
    """Generates a GeoTIFF map of predicted loss probability for a given year."""
    print(f"\nGenerating loss probability map for year: {year_to_predict}")
    prediction_start_time = time.time()
    t_minus_1 = year_to_predict - 1

    # --- Define input file paths for the prediction year ---
    lsat_path = LANDSAT_DIR / f'landsat_composite_{t_minus_1}.tif'
    temp_tm1_path = CLIMATE_DIR / f'mean_temp_{t_minus_1}.tif'
    precip_tm1_path = CLIMATE_DIR / f'total_precip_{t_minus_1}.tif'
    temp_t_path = CLIMATE_DIR / f'mean_temp_{year_to_predict}.tif'
    precip_t_path = CLIMATE_DIR / f'total_precip_{year_to_predict}.tif'
    dem_path_pred = DEM_PATH # Static path
    slope_path_pred = SLOPE_PATH # Static path

    # Check if all required files exist before starting
    required_files = {
        'lsat': lsat_path, 'temp_tm1': temp_tm1_path, 'prec_tm1': precip_tm1_path,
        'temp_t': temp_t_path, 'prec_t': precip_t_path, 'dem': dem_path_pred, 'slope': slope_path_pred
    }
    missing_files = [name for name, p in required_files.items() if not p.exists()]
    if missing_files:
        print(f"ERROR: Missing required raster files for prediction year {year_to_predict}: {missing_files}")
        return

    try:
        # Open source files and get metadata from a reference (e.g., Landsat)
        handles = {name: rasterio.open(p) for name, p in required_files.items()}
        ref_src = handles['lsat']
        profile = ref_src.profile
        profile.update(dtype=rasterio.float32, count=1, nodata=-9999.0) # Output is probability (float)

        # Create the output file
        with rasterio.open(output_path, 'w', **profile) as dst:
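            # Define processing blocks (adjust block size based on available RAM).
            # Note: dst.block_windows() takes a band index, not a block shape,
            # so fixed-size windows are built explicitly here.
            block_shape = (512, 512) # Larger block might be faster if RAM allows
            block_rows, block_cols = block_shape
            windows = [
                Window(col_off, row_off,
                       min(block_cols, dst.width - col_off),
                       min(block_rows, dst.height - row_off))
                for row_off in range(0, dst.height, block_rows)
                for col_off in range(0, dst.width, block_cols)
            ]
            total_blocks = len(windows)
            processed_blocks = 0

            print(f"Processing {total_blocks} blocks...")
            for window in windows: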
                processed_blocks += 1
                if processed_blocks % max(1, total_blocks // 10) == 0: # Print progress roughly 10 times
                    print(f"  Processing block {processed_blocks}/{total_blocks}...")

                # Read data for the block from all sources
                block_data = {}
                # Read Landsat (all bands)
                block_data['lsat'] = handles['lsat'].read(window=window)
                # Read single band rasters
                for name in ['temp_tm1', 'prec_tm1', 'temp_t', 'prec_t', 'dem', 'slope']:
                    block_data[name] = handles[name].read(1, window=window)

                block_h, block_w = block_data['lsat'].shape[1], block_data['lsat'].shape[2]
                if block_h == 0 or block_w == 0: continue # Skip empty blocks
                n_pixels_in_block = block_h * block_w

                # --- Assemble features for the block ---
                # Important: Feature order must match training order (FEATURE_NAMES)
                block_features_list = []
                for feature_name in feature_names_ordered:
                    if feature_name.startswith('lsat_'):
                        # Extract the band name; the Landsat composite opened above is always from t-1
                        parts = feature_name.split('_')
                        band_name = parts[1]
                        try:
                            band_index = LANDSAT_BANDS.index(band_name)
                            feature_data = block_data['lsat'][band_index].ravel()
                        except ValueError:
                            raise ValueError(f"Unknown Landsat band '{band_name}' in feature name list.")
                    elif feature_name.startswith('temp_'):
                        year_suffix = feature_name.split('_')[1]
                        if year_suffix == str(t_minus_1):
                            feature_data = block_data['temp_tm1'].ravel()
                        elif year_suffix == str(year_to_predict):
                            feature_data = block_data['temp_t'].ravel()
                        else:
                            raise ValueError(f"Unexpected year suffix in temp feature: {feature_name}")
                    elif feature_name.startswith('precip_'):
                        year_suffix = feature_name.split('_')[1]
                        if year_suffix == str(t_minus_1):
                            feature_data = block_data['prec_tm1'].ravel()
                        elif year_suffix == str(year_to_predict):
                            feature_data = block_data['prec_t'].ravel()
                        else:
                             raise ValueError(f"Unexpected year suffix in precip feature: {feature_name}")
                    elif feature_name == 'dem':
                        feature_data = block_data['dem'].ravel()
                    elif feature_name == 'slope':
                        feature_data = block_data['slope'].ravel()
                    else:
                        raise ValueError(f"Unknown feature name structure: {feature_name}")
                    block_features_list.append(feature_data)

                # Combine into (n_pixels, n_features) array
                block_features_array = np.vstack(block_features_list).T

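                # --- Handle potential NoData/NaN values ---
                # Check for NaNs across all features for each pixel.
                # NOTE: this only catches NaN. If any input raster encodes NoData as a
                # sentinel value (e.g. -9999), those pixels are still predicted; mask
                # them here as well if that applies to your data.
                valid_pixel_mask = ~np.isnan(block_features_array).any(axis=1)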
                # Create output block array, fill with NoData initially
                out_block = np.full(n_pixels_in_block, profile['nodata'], dtype=np.float32)

                # Predict only for valid pixels
                if np.any(valid_pixel_mask):
                    valid_features = block_features_array[valid_pixel_mask, :]
                    # Predict probability of class 1 (loss)
                    pred_proba = model.predict_proba(valid_features)[:, 1]
                    # Fill predictions into the output block based on the valid mask
                    out_block[valid_pixel_mask] = pred_proba

                # Reshape back to 2D block and write
                out_block = out_block.reshape(block_h, block_w)
                dst.write(out_block.astype(rasterio.float32), 1, window=window)

        # Close all source file handles
        for handle in handles.values():
            handle.close()

        print(f"Probability map saved to: {output_path}")
        print(f"Prediction map generation time: {time.time() - prediction_start_time:.2f}s")

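    except FileNotFoundError as e:
        print(f"ERROR generating prediction map: Input file not found - {e}")
    except MemoryError:
        print("ERROR generating prediction map: Insufficient memory. Try smaller block_shape.")
    except Exception as e:
        print(f"ERROR generating prediction map: {type(e).__name__} - {e}")
    finally:
        # Release source handles even if an error interrupted processing
        # (closing an already-closed dataset is harmless).
        if 'handles' in locals():
            for handle in handles.values():
                handle.close()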


# --- Example Call (run only if needed and model is trained) ---
# Check if model and feature names exist before running
# if 'rf_classifier' in locals() and 'FEATURE_NAMES' in locals() and Path(model_filename).exists():
#     PREDICTION_YEAR = 2021 # Choose a year from the test set range
#     prob_map_filename = OUTPUT_DIR / f'rf_loss_probability_{PREDICTION_YEAR}.tif'
#     # Optional: Load model if kernel restarted
#     # print(f"Loading model from {model_filename}...")
#     # loaded_model = joblib.load(model_filename)
#     # predict_loss_probability_map(loaded_model, FEATURE_NAMES, prob_map_filename, PREDICTION_YEAR)
#     # OR use the model in memory if kernel wasn't restarted:
#     predict_loss_probability_map(rf_classifier, FEATURE_NAMES, prob_map_filename, PREDICTION_YEAR)
# else:
#     print("\nModel or feature names not available. Skipping prediction map generation.")
#     if 'rf_classifier' not in locals(): print(" Reason: 'rf_classifier' not found.")
#     if 'FEATURE_NAMES' not in locals(): print(" Reason: 'FEATURE_NAMES' not found.")
#     if not Path(model_filename).exists(): print(f" Reason: Model file not found at {model_filename}")

## 11. Conclusion & Next Steps

- Summarize findings from evaluation metrics (e.g., Accuracy, Precision, Recall, F1, ROC AUC) and the feature importance analysis. Which factors were most predictive of forest loss in this region according to the model?
- Discuss limitations:
    - **Correlation vs. Causation:** The model identifies statistical associations between predictors and loss, not necessarily causal links.
    - **Data Resolution:** The 30 m resolution of Landsat/Hansen may miss fine-scale drivers. The climate data (TerraClimate, ~4 km) are much coarser, so many neighboring pixels share identical climate values.
    - **Loss Type:** Hansen data aggregates loss from various causes (logging, pests, wind, etc.). The model doesn't differentiate these.
    - **Sampling Bias:** The ratio and selection of non-loss samples set the class balance the model sees, which affects predicted probabilities and metrics such as precision and recall.
    - **Spatial Autocorrelation:** This analysis treats pixels as independent samples, ignoring spatial relationships that are likely important in forest dynamics.
    - **Model Simplicity:** Random Forest is robust but may not capture complex spatio-temporal interactions as well as models designed for them.
- Suggest potential future work:
    - Hyperparameter tuning of the Random Forest model (e.g., using GridSearchCV or RandomizedSearchCV).
    - Exploring other models (e.g., gradient boosting machines such as XGBoost or LightGBM, or deep learning approaches if the sample size is large enough).
    - Incorporating spatial features (e.g., distance to roads/mills, patch metrics, neighborhood characteristics).
    - Using focal statistics on predictor variables to capture neighborhood context.
    - Attempting to differentiate loss types if finer-grained data becomes available.
    - Analyzing model predictions spatially to see where the model performs well or poorly.
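
As one concrete illustration of the focal-statistics idea listed above, the sketch below computes a 3x3 focal mean using NumPy only. `focal_mean` is a hypothetical helper, not part of this notebook. Padding edges by replicating border values and ignoring NaN cells are both design assumptions made for the example.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def focal_mean(band, size=3):
    """Mean of each pixel's size x size neighborhood, ignoring NaN cells."""
    if size % 2 == 0:
        raise ValueError("size must be odd so the window has a center pixel")
    pad = size // 2
    # Replicate border values so the output keeps the input shape.
    padded = np.pad(band.astype(np.float64), pad, mode="edge")
    # View of shape (rows, cols, size, size): one window per output pixel.
    windows = sliding_window_view(padded, (size, size))
    return np.nanmean(windows, axis=(-2, -1))

# Small synthetic band (hypothetical data) to show the behavior.
demo_band = np.arange(16, dtype=np.float64).reshape(4, 4)
demo_focal = focal_mean(demo_band, size=3)
print(demo_focal)
```

A layer produced this way (for example, a focal mean of NBR at t-1) could be appended as an extra feature column, provided the same window size and edge handling are used at training and prediction time.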