In [None]:
#Import required packages
import ee
import os
import geemap
import pandas as pd
import numpy as np
import geopandas as gpd
from functools import reduce
from matplotlib import pyplot as plt
from sklearn.metrics import *
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.neural_network import MLPClassifier, MLPRegressor
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, RandomForestRegressor, GradientBoostingRegressor, VotingRegressor, StackingClassifier
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.model_selection import *
from sklearn.ensemble import VotingClassifier
from sklearn.base import clone
from sklearn.inspection import permutation_importance
from sklearn.feature_selection import RFE, SelectFromModel
from sklearn import metrics
from sklearn.naive_bayes import GaussianNB
from sklearn import preprocessing
from sklearn.preprocessing import StandardScaler
from sklearn.feature_selection import RFECV
from shapely.geometry import shape
from tqdm import tqdm
import rasterio
from rasterio.windows import Window
from rasterio.mask import mask
import seaborn as sns
import branca.colormap as bcm
import folium
from matplotlib import cm, colors
from folium.raster_layers import ImageOverlay
import matplotlib.colors as mcolors
from rasterio.warp import calculate_default_transform, reproject, Resampling
from PIL import Image
import tempfile

In [None]:
# --- Run configuration: everything a user is expected to edit lives in this cell ---
my_cloud_project = 'ee-YourProject' # your GEE cloud project ID
# Authenticate with Earth Engine, then initialise the client against the project above.
ee.Authenticate() 
ee.Initialize(project=my_cloud_project) # your GEE cloud project
# Passed as the `scale` argument of every reduceRegions call that samples the predictors.
my_scale = 1000
# Matched against the STUSPS field of TIGER/2018/States to pick the study-area geometry.
my_state_abbrv = 'MN'
# Band names selected from the predictor image; the same list is later passed as the
# `features` argument of predict_to_geotiff_full. Entries that are commented out are
# simply not selected.
all_params = [
    'Ca',
    'pH',
    'Phosphorus',
    'Nitrogen',
    'Boater_Visitation',
    'Inv_Algae',
    'Inv_Crust',
    'Inv_Fish',
    'Inv_Mollusk',
    'Inv_Plants',
    'Native_Fish',
    'homerange_sim',
    'Freeze_Up',
    # 'Ice_Melt',
    'Spawn_Start',
    'Spawn_End',
    # 'Growing_Season_Ideal',
    # 'Growing_Season_Length',
    'Precip_Winter',
    # 'Precip_Spring',
    'Precip_Summer',
    # 'Precip_Fall',
    'Flashiness',
    # 'Runoff',
    #'Drawdown',
    #'LST_Annual',
    'LST_Summer',
    'LST_Winter',
    # 'LST_Spring',
    # 'LST_Fall',
    'NDVI',
    'GPP_Annual',
    # 'GPP_Summer',
    'Heat_Insolation',
    'Topo_Diversity',
    'gHM',
    'NDTI',
    'NDBI',
    'NDCI',
    'NDSI',
    'Distance'
]


In [None]:
# Public reference collections: US state outlines and HUC8 sub-basin polygons.
states = ee.FeatureCollection("TIGER/2018/States")
subbasins = ee.FeatureCollection("USGS/WBD/2017/HUC08")

# Predictor image stored in the project's assets, restricted to the configured bands.
my_predictors = ee.Image(f'projects/{my_cloud_project}/assets/your_predictors').select(all_params)

# Geometry of the configured state; both point collections are limited to it.
geo = states.filter(ee.Filter.eq("STUSPS", my_state_abbrv)).geometry()
pos_fc = ee.FeatureCollection(f'projects/{my_cloud_project}/assets/your_pos_data').filterBounds(geo)
bg_fc = ee.FeatureCollection(f'projects/{my_cloud_project}/assets/your_bg_data').filterBounds(geo)

In [None]:
def add_huc8(fc, subbasins):
    """Return `fc` with a 'huc8' property set on every feature.

    For each feature, `subbasins` is filtered to the feature's geometry and the
    'huc8' property of the first matching sub-basin is copied onto the feature.
    """
    def _tag_with_huc(feature):
        overlapping = subbasins.filterBounds(feature.geometry())
        first_match = ee.Feature(overlapping.first())
        return feature.set('huc8', first_match.get('huc8'))

    return fc.map(_tag_with_huc)

# Tag the presence and the background points with their HUC8 code
pos_with_huc, bg_with_huc = (add_huc8(points, subbasins) for points in (pos_fc, bg_fc))

In [None]:
# Reduce the predictor image over every presence feature (mean reducer), then
# pull the result client-side as a pandas DataFrame without the bookkeeping columns.
pos_sample = my_predictors.reduceRegions(
    collection=pos_with_huc,
    reducer=ee.Reducer.mean(),
    crs='EPSG:4326',
    scale=my_scale,
    tileScale=16,
)
pos_sample_df = (
    ee.data.computeFeatures({'expression': pos_sample, 'fileFormat': 'PANDAS_DATAFRAME'})
    .drop(columns=["geo", "lakeID"])
    .dropna()
)

In [None]:
# Add a 'rand' column to the background points so they can be processed in two batches.
bg_fc_rand = bg_with_huc.randomColumn('rand')

# Split into two roughly equal parts
# (rand < 0.5 goes to part 1, rand >= 0.5 to part 2, so the parts never overlap).
bg_fc_1 = bg_fc_rand.filter(ee.Filter.lt('rand', 0.5))
bg_fc_2 = bg_fc_rand.filter(ee.Filter.gte('rand', 0.5))

In [None]:
# First background batch: mean-reduce the predictors over each feature and
# download the table, dropping the bookkeeping columns and incomplete rows.
bg_sample_fc_1 = my_predictors.reduceRegions(
    collection=bg_fc_1,
    reducer=ee.Reducer.mean(),
    crs='EPSG:4326',
    scale=my_scale,
    tileScale=16,
)
bg_sample_fc_1_df = (
    ee.data.computeFeatures({'expression': bg_sample_fc_1, 'fileFormat': 'PANDAS_DATAFRAME'})
    .drop(columns=["geo", "lakeID", "rand", "rand_point"])
    .dropna()
)

In [None]:
# Second background batch: same sampling as the first batch, applied to bg_fc_2.
bg_sample_fc_2 = my_predictors.reduceRegions(
    collection=bg_fc_2,
    reducer=ee.Reducer.mean(),
    crs='EPSG:4326',
    scale=my_scale,
    tileScale=16,
)
bg_sample_fc_2_df = (
    ee.data.computeFeatures({'expression': bg_sample_fc_2, 'fileFormat': 'PANDAS_DATAFRAME'})
    .drop(columns=["geo", "lakeID", "rand", "rand_point"])
    .dropna()
)

In [None]:
# Stack the presence table and both background batches into one table with a fresh 0..n-1 index.
all_points_df = pd.concat([pos_sample_df, bg_sample_fc_1_df, bg_sample_fc_2_df], ignore_index=True)

In [None]:
# Retain only the HUC8 units whose points carry more than one distinct
# "Present" value (i.e. both 0 and 1 occur inside the unit).
valid_hucs = (
    all_points_df.groupby("huc8")["Present"]
    .nunique()
    .pipe(lambda n_classes: n_classes[n_classes > 1])
    .index
)

dataset = all_points_df.loc[all_points_df["huc8"].isin(valid_hucs)]

In [None]:
def compute_tss(y_true, y_pred):
    """
    Compute the True Skill Statistic (TSS) for binary labels.

    TSS = Sensitivity + Specificity - 1

    Labels are interpreted as truthy = positive (1, present) and falsy =
    negative (0, absent). The four confusion-matrix cells are counted directly,
    so the function also works when a fold contains only one class (a 2x2
    matrix cannot be unpacked from a single-class confusion matrix).

    Parameters
    ----------
    y_true : array-like of 0/1
        Observed labels.
    y_pred : array-like of 0/1
        Predicted labels, same length as `y_true`.

    Returns
    -------
    float
        TSS in [-1, 1]. A rate whose denominator is zero (no positives or no
        negatives in `y_true`) is taken as 0.0.

    Raises
    ------
    ValueError
        If `y_true` and `y_pred` have different lengths.
    """
    truth = np.asarray(y_true).ravel().astype(bool)
    guess = np.asarray(y_pred).ravel().astype(bool)
    if truth.shape != guess.shape:
        raise ValueError(
            f"y_true and y_pred must have the same length, got {truth.size} and {guess.size}"
        )

    tp = int(np.count_nonzero(truth & guess))
    fn = int(np.count_nonzero(truth & ~guess))
    tn = int(np.count_nonzero(~truth & ~guess))
    fp = int(np.count_nonzero(~truth & guess))

    # Sensitivity (True Positive Rate)
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0.0

    # Specificity (True Negative Rate)
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0

    return sensitivity + specificity - 1

def _make_base_models(random_state=None):
    """Return fresh, unfitted base classifiers keyed by their short display name."""
    return {
        "MaxEnt": LogisticRegression(max_iter=10000, random_state=random_state),
        "RF": RandomForestClassifier(n_estimators=1000, random_state=random_state),
        "DT": DecisionTreeClassifier(random_state=random_state),
        "BRT": GradientBoostingClassifier(n_estimators=1000, random_state=random_state),
        "MLP": MLPClassifier(max_iter=10000, random_state=random_state),
    }


def leave_one_huc_out(df, random_state=None):
    """
    Leave-one-HUC-out cross validation of five base classifiers and two ensembles.

    Every distinct value of the "huc8" column is held out once; the models are
    trained on all remaining rows (features standardised with a scaler fitted on
    the training rows only) and scored with TSS on the held-out rows.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain "huc8" (grouping key) and "Present" (label); every other
        column is used as a predictor.
    random_state : int or None, optional
        Seed forwarded to every estimator that accepts one. The default (None)
        leaves the estimators unseeded, so repeated runs may differ.

    Returns
    -------
    dict[str, list[float]]
        One TSS value per fold for each of "MaxEnt", "RF", "DT", "BRT", "MLP",
        "Voting" and "Stacking".

    Raises
    ------
    ValueError
        If `df` holds fewer than two distinct HUCs (there would be no training data).
    """
    hucs = df["huc8"].unique()
    n_folds = len(hucs)
    if n_folds < 2:
        raise ValueError(
            f"leave-one-HUC-out needs at least 2 distinct 'huc8' values, got {n_folds}"
        )
    print(f"Running leave-one-HUC-out cross validation with {n_folds} folds...")

    results = {m: [] for m in ["MaxEnt", "RF", "DT", "BRT", "MLP", "Voting", "Stacking"]}

    # Progress bar over folds
    for test_huc in tqdm(hucs, desc="Processing folds", unit="fold"):
        is_test = df["huc8"] == test_huc
        train, test = df[~is_test], df[is_test]

        X_train = train.drop(columns=["Present", "huc8"])
        y_train = train["Present"]
        X_test = test.drop(columns=["Present", "huc8"])
        y_test = test["Present"]

        # Fit the scaler on the training rows only so nothing leaks from the held-out HUC.
        scaler = StandardScaler().fit(X_train)
        X_train_scaled = scaler.transform(X_train)
        X_test_scaled = scaler.transform(X_test)

        models = _make_base_models(random_state)

        # Base models
        for name, model in models.items():
            model.fit(X_train_scaled, y_train)
            results[name].append(compute_tss(y_test, model.predict(X_test_scaled)))

        # Ensembles (both refit their own clones of the base estimators)
        estimators = list(models.items())
        vc = VotingClassifier(estimators=estimators, voting="soft")
        vc.fit(X_train_scaled, y_train)
        results["Voting"].append(compute_tss(y_test, vc.predict(X_test_scaled)))

        stack = StackingClassifier(
            estimators=estimators,
            final_estimator=RandomForestClassifier(random_state=random_state),
        )
        stack.fit(X_train_scaled, y_train)
        results["Stacking"].append(compute_tss(y_test, stack.predict(X_test_scaled)))

    print("Cross validation complete.")
    return results

In [None]:
# Take the first three distinct HUC8 codes (in order of appearance) and keep
# only the rows of `dataset` that belong to them.
first_three_hucs = dataset["huc8"].unique()[:3]
test_subset = dataset.loc[dataset["huc8"].isin(first_three_hucs)]

In [None]:
# Run the leave-one-HUC-out evaluation on the three-HUC subset (not on the full `dataset`).
# The result maps each model name to its list of per-fold TSS values.
tss_results = leave_one_huc_out(test_subset)

In [None]:
# -----------------
# Step 1: Collect replicate TSS results into tidy DataFrame
# -----------------
rows = []
for model_name, scores in tss_results.items():
    for tss in scores:
        rows.append({'Model': model_name, 'TSS': tss})
tss_df = pd.DataFrame(rows)

plt.figure(figsize=(10,6))
sns.boxplot(data=tss_df, x='Model', y='TSS', palette='Set2')
sns.swarmplot(data=tss_df, x='Model', y='TSS', color='k', alpha=0.5)
plt.title('Distribution of TSS by Algorithm')
plt.ylabel('TSS')
plt.xlabel('Algorithm')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

# -----------------
# Step 2: Pick best model by mean TSS
# -----------------
avg_scores = {model: np.mean(scores) for model, scores in tss_results.items()}
best_model_name = max(avg_scores, key=avg_scores.get)
print(f"✅ Best model: {best_model_name}")

# -----------------
# Step 3: Fit best model on full dataset
# -----------------
X = all_points_df.drop(columns=["Present", "huc8"])
y = all_points_df["Present"]
feature_names = X.columns

scaler = StandardScaler().fit(X)
X_scaled = scaler.transform(X)

# The `models` dict built inside leave_one_huc_out is local to that function, so the
# candidates are rebuilt here with the same hyper-parameters. The two ensembles are
# included because they can also win the comparison above.
base_models = {
    "MaxEnt": LogisticRegression(max_iter=10000),
    "RF": RandomForestClassifier(n_estimators=1000),
    "DT": DecisionTreeClassifier(),
    "BRT": GradientBoostingClassifier(n_estimators=1000),
    "MLP": MLPClassifier(max_iter=10000),
}
models = {
    **base_models,
    "Voting": VotingClassifier(estimators=list(base_models.items()), voting="soft"),
    "Stacking": StackingClassifier(
        estimators=list(base_models.items()),
        final_estimator=RandomForestClassifier(),
    ),
}

# clone() gives a fresh, unfitted copy so the final fit never reuses state.
best_model = clone(models[best_model_name])
best_model.fit(X_scaled, y)
print(f"✅ Final {best_model_name} trained on all data")




In [None]:
# -----------------
# Step 4: Feature importance (RFE, permutation, drop-column)
# -----------------
def _drop_column_importance(model, X, y, feature_names, random_state=42):
    """
    Drop-column importance: refit a clone of `model` without each feature in turn.

    The importance of a feature is the baseline score (all features) minus the
    score obtained after removing that feature. Scores come from the estimator's
    own `score` method evaluated on the data used for fitting.

    Returns a DataFrame with the columns 'feature' and 'feature_importance'.
    """
    def _fresh_estimator():
        est = clone(model)
        # Seed the refits when the estimator exposes a top-level random_state,
        # so differences reflect the dropped column rather than seed noise.
        if "random_state" in est.get_params():
            est.set_params(random_state=random_state)
        return est

    X = np.asarray(X)
    baseline = _fresh_estimator().fit(X, y).score(X, y)
    records = []
    for col_idx, name in enumerate(feature_names):
        X_reduced = np.delete(X, col_idx, axis=1)
        reduced_score = _fresh_estimator().fit(X_reduced, y).score(X_reduced, y)
        records.append({'feature': name, 'feature_importance': baseline - reduced_score})
    return pd.DataFrame(records, columns=['feature', 'feature_importance'])


rfe_features = []
min_features_rfe = 3
# RFECV needs coef_ or feature_importances_, which not every candidate model provides.
if hasattr(best_model, "coef_") or hasattr(best_model, "feature_importances_"):
    try:
        rfe_selector = RFECV(
            estimator=best_model,
            min_features_to_select=min_features_rfe,
            step=1,
            cv=3,
            scoring="roc_auc"
        )
        rfe_selector.fit(X_scaled, y)
        rfe_features = [feature_names[i] for i, keep in enumerate(rfe_selector.support_) if keep]
    except Exception as e:
        # Best effort: RFE is optional, the other two methods still run.
        print(f"⚠️ Skipping RFE for {best_model_name}: {e}")

# Permutation importance (works for all)
perm_imp = permutation_importance(best_model, X_scaled, y, n_repeats=3, random_state=42)
perm_features = pd.Series(perm_imp.importances_mean, index=feature_names)\
                  .sort_values(ascending=False).index[:3].tolist()

# Drop-column importance
drop_col_feats = _drop_column_importance(best_model, X_scaled, y, feature_names=feature_names)
drop_features = drop_col_feats.sort_values('feature_importance', ascending=False)['feature'][:3].tolist()

# -----------------
# Step 5: Plot feature importance comparison
# -----------------
# The three methods select different numbers of features, so pad the shorter lists
# with None; a DataFrame cannot be built from columns of unequal length.
selected_by_method = {
    'RFE': list(rfe_features),
    'Permutation': list(perm_features),
    'Drop-column': list(drop_features),
}
n_rows = max(len(feats) for feats in selected_by_method.values())
fi_df = pd.DataFrame({
    method: feats + [None] * (n_rows - len(feats))
    for method, feats in selected_by_method.items()
})

# Melt + tidy for plotting
fi_long = fi_df.melt(var_name='Method', value_name='Feature').dropna()
fi_counts = fi_long.groupby(['Method','Feature']).size().reset_index(name='Count')

plt.figure(figsize=(10,6))
sns.barplot(data=fi_counts, x='Feature', y='Count', hue='Method', palette='Set2')
plt.title(f'Top Features for {best_model_name}')
plt.ylabel('Importance / Selection Count')
plt.xlabel('Feature')
plt.xticks(rotation=45)
plt.legend(title='Method')
plt.tight_layout()
plt.show()

In [None]:
# -------------------------------
# Safe MESS function
# -------------------------------
def MESS_safe(ref_df, pred_df):
    """
    Multivariate environmental similarity of `pred_df` rows to the reference `ref_df`.

    For every column of `pred_df` and every prediction value p, let f be the
    fraction of reference values strictly below p, and min/max the reference
    range of that column. The per-column similarity is

        f == 0         -> (p - min) / (max - min)
        0 < f <= 0.5   -> 2 * f
        0.5 < f < 1    -> 2 * (1 - f)
        f == 1         -> (max - p) / (max - min)

    A column whose reference range is zero scores 1.0 for every row; a missing
    prediction value scores NaN.

    Parameters
    ----------
    ref_df : pandas.DataFrame
        Reference rows; must contain every column of `pred_df`.
    pred_df : pandas.DataFrame
        Rows to score.

    Returns
    -------
    pandas.DataFrame
        One row per `pred_df` row (fresh 0..n-1 index) with the columns
        'MESS' (row-wise minimum similarity, NaNs skipped) and 'MoD' (name of
        the column giving that minimum, or 'all_constant' when every
        similarity in the row is NaN).

    Raises
    ------
    ValueError
        If `ref_df` has no rows.
    """
    if len(ref_df) == 0:
        raise ValueError("ref_df must contain at least one reference row")

    mins = ref_df.min()
    maxs = ref_df.max()

    def calculate_s(column):
        ref_values = ref_df[column].to_numpy(dtype=float)
        elements = pred_df[column].to_numpy(dtype=float)
        lo, hi = mins[column], maxs[column]
        denom = hi - lo
        if denom == 0:
            return np.ones(elements.shape[0])

        # Fraction of reference values strictly below each element. A binary
        # search on the sorted reference replaces one full scan per element.
        below = np.searchsorted(np.sort(ref_values), elements, side='left')
        f = below / ref_values.size

        # np.select takes the first matching condition, mirroring the if/elif chain.
        sims = np.select(
            [f == 0, f <= 0.5, f < 1],
            [(elements - lo) / denom, 2 * f, 2 * (1 - f)],
            default=(hi - elements) / denom,
        )
        # NaN sorts last in searchsorted, so force missing inputs to NaN explicitly.
        sims[np.isnan(elements)] = np.nan
        return sims

    sim_df = pd.DataFrame({c: calculate_s(c) for c in pred_df.columns})
    min_similarity = sim_df.min(axis=1, skipna=True)

    # idxmin on an all-NaN row is deprecated in pandas, so detect those rows
    # up front and label them after taking idxmin on a NaN-free copy.
    no_score = sim_df.isna().all(axis=1)
    MoD = sim_df.fillna(np.inf).idxmin(axis=1).where(~no_score, 'all_constant')
    return pd.concat([min_similarity.rename('MESS'), MoD.rename('MoD')], axis=1)

# -------------------------------
# HUC-level pointwise MESS
# -------------------------------
def huc_pointwise_MESS(points_df, exclude_cols=('huc8', 'Present', 'presence', 'geo', 'rand', 'year')):
    """
    Computes point-level MESS and aggregates to HUC-level summaries.

    Each HUC in turn serves as the reference set; every point outside that HUC
    is scored against it with MESS_safe, and the scores are summarised per
    reference HUC.

    Parameters
    ----------
    points_df : pandas.DataFrame
        Point table with a 'huc8' column plus predictor columns. It is not modified.
    exclude_cols : iterable of str, optional
        Columns that are never treated as predictors. The default also lists
        'Present', the label column used elsewhere in this notebook, so the
        response does not enter the similarity computation.

    Returns
    -------
    pandas.DataFrame
        One row per reference HUC with the columns 'ref_huc8', 'median_MESS',
        'IQR_MESS' (75th minus 25th percentile) and 'top_MoD' (most frequent
        most-dissimilar variable).

    Raises
    ------
    ValueError
        If no numeric predictor column remains after the exclusions.
    """
    # Keep only numeric predictors
    predictor_cols = [
        c for c in points_df.columns
        if c not in exclude_cols and pd.api.types.is_numeric_dtype(points_df[c])
    ]
    if not predictor_cols:
        raise ValueError("No numeric predictor columns left after applying exclude_cols")

    # Work on a separate frame so the caller's DataFrame is left untouched.
    predictors = points_df[predictor_cols].apply(pd.to_numeric, errors='coerce')
    huc_codes = points_df['huc8']

    huc_results = []

    # A missing HUC code can never match itself, so it cannot act as a reference.
    for ref_huc in huc_codes.dropna().unique():
        in_ref = (huc_codes == ref_huc).to_numpy()

        mess_out = MESS_safe(predictors[in_ref], predictors[~in_ref])
        mess_out['ref_huc8'] = ref_huc
        mess_out['pred_huc8'] = huc_codes[~in_ref].to_numpy()

        huc_results.append(mess_out)

    if not huc_results:
        return pd.DataFrame(columns=['ref_huc8', 'median_MESS', 'IQR_MESS', 'top_MoD'])

    huc_mess_df = pd.concat(huc_results, ignore_index=True)

    # Aggregate to HUC-level summaries. Series.quantile skips NaN like 'median'
    # does, so the two statistics treat missing scores the same way.
    huc_summary = huc_mess_df.groupby('ref_huc8').agg(
        median_MESS=('MESS', 'median'),
        IQR_MESS=('MESS', lambda x: x.quantile(0.75) - x.quantile(0.25)),
        top_MoD=('MoD', lambda x: x.value_counts().idxmax())
    ).reset_index()

    return huc_summary



def predict_to_geotiff_full(clf, scaler, features, predictors_tif, 
                            boundary_shapefile=None, outdir="predictions",
                            out_name="prediction.tif", block_size=512, normalize=True,
                            return_image=False):
    """
    Predict SDM probabilities for every pixel of a multi-band predictor raster
    and write them to a single-band float32 GeoTIFF.

    Parameters
    ----------
    clf : fitted classifier
        Must provide `predict_proba`; column 1 is written as the probability.
    scaler : fitted transformer
        Applied to the pixel values before prediction. When it exposes
        `feature_names_in_`, the columns are put in that order first.
    features : sequence of str
        Band descriptions to read from `predictors_tif`.
    predictors_tif : str
        Path of the predictor raster; its band descriptions must include `features`.
    boundary_shapefile : str, optional
        Vector file; pixels outside its geometries are written as NaN. The
        geometries are reprojected to the raster's CRS before masking.
    outdir, out_name : str
        Output folder (created if missing) and file name.
    block_size : int
        Edge length, in pixels, of the square windows processed one at a time.
    normalize : bool
        If True, rescale the predictions to 0-1 using the min/max over all
        predicted pixels of the raster (before the boundary mask is applied).
    return_image : bool
        If True, also return a PIL RGBA image of the result ('YlOrRd' colormap
        over 0-1, pixels without a prediction fully transparent).

    Returns
    -------
    str or (str, PIL.Image.Image)
        Output path, plus the image when `return_image` is True.

    Raises
    ------
    ValueError
        If a requested feature is not among the raster's band descriptions.
    """
    # Function-scope import: only `mask` is imported from rasterio.mask at file level.
    from rasterio.mask import raster_geometry_mask

    os.makedirs(outdir, exist_ok=True)
    out_path = os.path.join(outdir, out_name)
    features = list(features)

    with rasterio.open(predictors_tif) as src:
        # Map feature names to band indices (1-based)
        descriptions = list(src.descriptions)
        missing = [f for f in features if f not in descriptions]
        if missing:
            raise ValueError(f"Missing features in raster bands: {missing}")
        band_indices = [descriptions.index(f) + 1 for f in features]

        # Column order expected by the scaler; fall back to `features` when the
        # scaler was fitted on a plain array and carries no names.
        model_columns = list(getattr(scaler, "feature_names_in_", features))

        height, width = src.height, src.width
        nodata = src.nodata
        prob_raster = np.full((height, width), np.nan, dtype="float32")

        # Block-wise prediction
        for i in range(0, height, block_size):
            for j in range(0, width, block_size):
                h = min(block_size, height - i)
                w = min(block_size, width - j)
                data = src.read(band_indices, window=Window(j, i, w, h))
                n_bands, n_rows, n_cols = data.shape

                X_block = data.reshape(n_bands, -1).T
                # A pixel is usable only if every band is finite and none equals
                # the raster's declared nodata value.
                mask_valid = np.all(np.isfinite(X_block), axis=1)
                if nodata is not None and not np.isnan(nodata):
                    mask_valid &= ~np.any(X_block == nodata, axis=1)

                if not mask_valid.any():
                    continue

                X_block_df = pd.DataFrame(X_block[mask_valid], columns=features)[model_columns]
                X_block_scaled = scaler.transform(X_block_df)
                probs = clf.predict_proba(X_block_scaled)[:, 1]

                flat_block = np.full((n_rows * n_cols,), np.nan, dtype="float32")
                flat_block[mask_valid] = probs
                prob_raster[i:i+h, j:j+w] = flat_block.reshape((n_rows, n_cols))

        # Boolean mask (True = outside the boundary) on the raster's own grid.
        outside_boundary = None
        if boundary_shapefile:
            boundary_gdf = gpd.read_file(boundary_shapefile)
            if src.crs is not None:
                boundary_gdf = boundary_gdf.to_crs(src.crs.to_wkt())
            outside_boundary, _, _ = raster_geometry_mask(
                src, boundary_gdf.geometry, crop=False
            )

        raster_crs = src.crs
        raster_transform = src.transform

    # Normalize globally if requested
    if normalize:
        valid_mask = ~np.isnan(prob_raster)
        if valid_mask.any():  # nothing to rescale when no pixel was predicted
            min_val = prob_raster[valid_mask].min()
            max_val = prob_raster[valid_mask].max()
            prob_raster[valid_mask] = (prob_raster[valid_mask] - min_val) / (max_val - min_val + 1e-8)

    # Optional masking to boundary: blank the predictions outside the geometries.
    # The predictions themselves are kept; only outside pixels become NaN.
    if outside_boundary is not None:
        prob_raster[outside_boundary] = np.nan

    # Prepare metadata for writing
    out_meta = {
        "driver": "GTiff",
        "height": height,
        "width": width,
        "count": 1,
        "dtype": "float32",
        "crs": raster_crs,
        "transform": raster_transform,
        "nodata": np.nan,
    }

    # Write output raster
    with rasterio.open(out_path, "w", **out_meta) as dest:
        dest.write(prob_raster, 1)

    print(f"Prediction saved to {out_path}")

    if not return_image:
        return out_path

    # Colour the 0-1 surface; pixels without a prediction get alpha 0.
    no_prediction = np.isnan(prob_raster)
    norm = colors.Normalize(vmin=0, vmax=1)
    cmap = plt.get_cmap('YlOrRd')  # cm.get_cmap was removed in Matplotlib 3.9
    rgba = (cmap(norm(np.where(no_prediction, 0, prob_raster))) * 255).astype(np.uint8)
    rgba[no_prediction, 3] = 0
    return out_path, Image.fromarray(rgba)


def folium_heatmap_from_tif(tif_path, map_center=[46.5, -94.2], zoom_start=6, colormap='YlOrRd', opacity=0.6):
    """
    Create a Folium map with a raster heatmap overlay from a GeoTIFF.

    Band 1 of the raster is reprojected to EPSG:4326 (nearest neighbour),
    rescaled to 0-1 over its valid pixels, coloured with `colormap` and added
    to the map as an image overlay. Nodata and NaN pixels are transparent.

    Parameters
    ----------
    tif_path : str
        Path to the predicted GeoTIFF (any CRS).
    map_center : list
        [lat, lon] center of the map.
    zoom_start : int
        Initial zoom level.
    colormap : str
        Matplotlib colormap name.
    opacity : float
        Overlay opacity.

    Returns
    -------
    folium.Map object with heatmap overlay.

    Raises
    ------
    ValueError
        If the raster holds no valid (non-nodata) pixel.
    """
    # Open raster
    with rasterio.open(tif_path) as src:
        # Work in float so that nodata can be represented as NaN.
        src_array = src.read(1).astype("float32")
        src_nodata = src.nodata
        src_crs = src.crs
        src_transform = src.transform

        # Mask nodata
        if src_nodata is not None and not np.isnan(src_nodata):
            src_array[src_array == src_nodata] = np.nan

        # Target grid in 4326
        transform, width, height = calculate_default_transform(
            src_crs, 'EPSG:4326', src.width, src.height, *src.bounds)

    # Start from an all-NaN grid and declare NaN as nodata on both sides, so
    # pixels outside the source footprint stay NaN instead of being filled.
    dst_array = np.full((height, width), np.nan, dtype="float32")
    reproject(
        source=src_array,
        destination=dst_array,
        src_transform=src_transform,
        src_crs=src_crs,
        src_nodata=np.nan,
        dst_transform=transform,
        dst_crs='EPSG:4326',
        dst_nodata=np.nan,
        resampling=Resampling.nearest
    )

    # Normalize for colormap
    valid = ~np.isnan(dst_array)
    if not valid.any():
        raise ValueError(f"{tif_path} contains no valid pixels to display")
    vmin, vmax = dst_array[valid].min(), dst_array[valid].max()
    norm_array = np.clip((dst_array - vmin) / (vmax - vmin + 1e-8), 0, 1)

    # Apply colormap (cm.get_cmap was removed in Matplotlib 3.9)
    cmap = plt.get_cmap(colormap)
    rgba_array = (cmap(np.where(valid, norm_array, 0)) * 255).astype(np.uint8)
    rgba_array[~valid] = [0, 0, 0, 0]  # transparent for nodata

    # Compute bounds in 4326
    left, bottom = transform * (0, height)
    right, top = transform * (width, 0)
    bounds = [[bottom, left], [top, right]]

    # Create Folium map
    m = folium.Map(location=map_center, zoom_start=zoom_start)

    # The RGBA array is handed to the overlay directly, which embeds it in the
    # map; no temporary PNG file is written or left behind.
    folium.raster_layers.ImageOverlay(
        image=rgba_array,
        bounds=bounds,
        opacity=opacity,
        interactive=True,
        name="Predicted Suitability"
    ).add_to(m)

    folium.LayerControl().add_to(m)

    return m

In [None]:
# -------------------------------
# Example usage
# -------------------------------
# HUC-level MESS summary computed from every sampled point; print the first rows as a preview.
huc_summary = huc_pointwise_MESS(all_points_df)
print(huc_summary.head())

In [None]:
# Predict with the final model over the predictor raster, mask to the state border and
# write "pred.tif" into the working directory; the RGBA preview image is returned as well.
# (The raster path previously began with a stray quote character inside the string.)
out_path, rgba_img = predict_to_geotiff_full(
    clf=best_model,
    scaler=scaler,
    features=all_params,
    predictors_tif="predictors_your_taxa.tif",
    boundary_shapefile="your_state_border.shp",
    outdir=".",
    out_name="pred.tif",
    return_image=True
)


In [None]:
# Build the Folium overlay from "pred.tif"; the bare `m_pred` expression
# displays the map as the cell output.
m_pred = folium_heatmap_from_tif("pred.tif")
m_pred

In [None]:
huc_gdf = gpd.read_file("huc_8.shp")
# Ensure column types match
huc_gdf['huc8'] = huc_gdf['huc8'].astype(str)
huc_summary['ref_huc8'] = huc_summary['ref_huc8'].astype(str)
# Assuming huc_summary has 'median_MESS' column
mess_min = huc_summary['median_MESS'].min()
mess_max = huc_summary['median_MESS'].max()

# Normalize to -1 → 1. When every HUC shares the same median the range is zero
# and the ratio is undefined, so use the midpoint (0.0) instead of dividing by zero.
mess_range = mess_max - mess_min
if mess_range > 0:
    huc_summary['median_MESS_norm'] = 2 * ((huc_summary['median_MESS'] - mess_min) / mess_range) - 1
else:
    huc_summary['median_MESS_norm'] = 0.0

In [None]:
# Normalize negative MESS values to 0-1 (absolute distance from 0)
neg_mask = huc_summary['median_MESS'] < 0
neg_values = huc_summary.loc[neg_mask, 'median_MESS']
max_neg = abs(neg_values.min()) if len(neg_values) > 0 else 1
# Create the column as float (0.0, not 0): the next line writes fractional values
# into it, and assigning floats into an integer column is deprecated in pandas.
huc_summary['median_MESS_norm'] = 0.0  # default for >=0
huc_summary.loc[neg_mask, 'median_MESS_norm'] = -neg_values / max_neg  # normalized 0 -> 1 for negative

In [None]:
# Attach the MESS summary to the HUC polygons (left join on the HUC8 code) and
# keep only the polygons that actually received a median MESS value.
huc_mess_gdf = (
    huc_gdf
    .merge(huc_summary, how='left', left_on='huc8', right_on='ref_huc8')
    .dropna(subset=['median_MESS'])
)

In [None]:
# --- Ensure no NaNs and define MESS color mapping ---
def mess_to_color(mess_val, min_neg):
    """
    Convert median MESS to a grayscale hex colour.

    Negative MESS -> ramps from dark grey (just below 0, less novel) to white
    (at `min_neg`, most novel).
    Non-negative or missing MESS -> dark grey ('#444444').

    The ramp starts at the same dark grey used for non-negative values, so the
    colour changes continuously across MESS = 0.

    Parameters
    ----------
    mess_val : float
        Median MESS of one unit.
    min_neg : float
        Most negative median MESS in the data; its magnitude marks the white
        end of the ramp. More extreme values are clamped to white.

    Returns
    -------
    str
        Colour as '#rrggbb'.
    """
    dark_grey = 0x44 / 255  # grey level of '#444444'

    # `not (x < 0)` is also true for NaN, which therefore gets the neutral colour.
    if not mess_val < 0:
        return '#444444'  # solid dark grey

    scale = abs(min_neg)
    # Without a usable scale (zero or NaN) treat the value as fully novel.
    norm_val = min(abs(mess_val) / scale, 1) if scale > 0 else 1.0

    grey_shade = dark_grey + (1 - dark_grey) * norm_val  # dark grey -> 1 = white
    level = int(round(grey_shade * 255))
    return '#{0:02x}{0:02x}{0:02x}'.format(level)

# Most negative median MESS: the reference point for the colour scale
min_neg_mess = huc_mess_gdf.loc[huc_mess_gdf['median_MESS'] < 0, 'median_MESS'].min()
huc_mess_gdf['fill_color'] = huc_mess_gdf['median_MESS'].map(
    lambda value: mess_to_color(value, min_neg_mess)
)
huc_mess_gdf = huc_mess_gdf.dropna(subset=['median_MESS'])

# --- Base map plus one polygon layer coloured by the precomputed fill colour ---
m_mess = folium.Map(location=[46.5, -94.2], zoom_start=6)

folium.GeoJson(
    huc_mess_gdf.to_json(),
    name="HUC8 MESS",
    style_function=lambda feature: dict(
        fillColor=feature['properties']['fill_color'],
        color='black',
        weight=0.5,
        fillOpacity=0.7,
    ),
).add_to(m_mess)

m_mess
