<a href="https://colab.research.google.com/github/VHSajna/0xfolio/blob/master/Guided_XGBoost_Tutorial_for_Fisheries_Scientistsv2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install copernicusmarine
import pandas as pd
import os
import matplotlib.pyplot as plt
import cartopy
import numpy as np
import xarray as xr
import earthaccess
from match_func import match_nearest
import netCDF4
import h5netcdf
import math
import copernicusmarine
import geopandas as gpd
import earthaccess
import xarray as xr
import cartopy.feature as cfeature
from shapely.ops import unary_union
from tqdm import tqdm
import cmocean
# Log in to NASA Earthdata. NOTE(review): persist=True presumably stores the
# credentials for later sessions - confirm against the earthaccess docs.
auth = earthaccess.login(persist=True)

# Load the trawl catch table.
# NOTE(review): this absolute path only exists on the original JupyterHub;
# adjust it when running elsewhere.
trawl_df = pd.read_csv('/home/jovyan/proj_2025_sdm/data/fisheries_with_pace_rrs_avw2.csv')

# --- Data Cleaning (Crucial Step!) ---

# 1. Ensure the date column is a proper datetime object.
#    This is vital for matching with the environmental data.
trawl_df['TOWDATETIME_EST'] = pd.to_datetime(trawl_df['TOWDATETIME_EST'])

# 2. Ensure coordinates are numeric.
trawl_df['LON'] = pd.to_numeric(trawl_df['LON'])
trawl_df['LAT'] = pd.to_numeric(trawl_df['LAT'])

# Display the first few rows and data types to verify.
print("Catch Data Head:")
print(trawl_df.head())
print("\nData Types:")
# DataFrame.info() writes its report itself and returns None, so it must not
# be wrapped in print() (that printed a stray "None" after the report).
trawl_df.info()

# --- Standardize Column Names to Match Xarray Dims ---
# The environmental datasets used later are indexed by 'time', 'latitude'
# and 'longitude'; giving the catch table the same names keeps the later
# matching code uniform.
rename_dict = {
    'TOWDATETIME_EST': 'time',
    'LAT': 'latitude',
    'LON': 'longitude'
}
trawl_df = trawl_df.rename(columns=rename_dict)

# Keep every column whose name does not contain 'Rrs' (case-sensitive).
matched = [col for col in trawl_df.columns if 'Rrs' not in col]

# .copy() gives `sub` its own data, so adding a column below does not write
# into a slice of trawl_df (which triggers pandas' SettingWithCopyWarning).
sub = trawl_df[matched].copy()  # subset (RRS columns removed)
sub['station'] = np.arange(1, len(sub) + 1)  # sequential station number, 1..N

# Total of every species column.
# NOTE(review): the slice [5:-2] assumes the species columns sit between the
# first five columns and the last two (the last one being 'station', added
# above) - confirm against the CSV layout.
tot_sum = {s: int(sub[s].sum()) for s in sub.columns[5:-2]}

# (name, total) pairs ordered from largest to smallest total.
sorted_tot_sum = sorted(tot_sum.items(), key=lambda item: item[1], reverse=True)

# Species whose total exceeds 50000.
abu = [name for name, total in sorted_tot_sum if total > 50000]

# Plot the per-station values of the most abundant species.
for s in abu:
    plt.plot(sub.station, sub[s], label=s)
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.xlabel('Station #')
plt.ylabel('Total # of obs')
plt.title('Species with more than 50000 observations')
plt.show()

# Spatial extent of the tows, taken from the coordinate columns.
lat_values = trawl_df['latitude']
lon_values = trawl_df['longitude']
min_lat, max_lat = lat_values.min(), lat_values.max()
min_lon, max_lon = lon_values.min(), lon_values.max()

# Summary printout: table shape, then latitude / longitude / time ranges.
print("\n--- Trawl Data Shape and Columns ---")
print("Trawl DataFrame Shape:", trawl_df.shape)

print(f"Trawl data latitude range: {min_lat} to {max_lat}")
print(f"Trawl data longitude range: {min_lon} to {max_lon}")
tow_times = trawl_df["time"]
print("Trawl Data Time Range:", tow_times.min(), "to", tow_times.max())

# Natural Earth 'bathymetry_K_200' layer (presumably the 200 m contour),
# merged into a single geometry so it can be drawn with one call per map.
bathym = cfeature.NaturalEarthFeature(name='bathymetry_K_200', scale='10m', category='physical')
bathym = unary_union(list(bathym.geometries()))

# 2x3 grid of map axes, one panel per abundant species.
fig, axs = plt.subplots(2, 3, figsize=(12, 8), subplot_kw={'projection': cartopy.crs.PlateCarree()})
axs = axs.flatten()

# Pair panels with species. zip() stops at the shorter sequence, so having
# fewer than six species in `abu` no longer raises IndexError; with more than
# six, only the first six are drawn (as before).
for ax, species in zip(axs, abu):
    absent = sub[species] == 0
    present = sub[species] != 0
    ax.scatter(sub.loc[absent, 'longitude'], sub.loc[absent, 'latitude'], c='r', s=8, label='Absence')
    ax.scatter(sub.loc[present, 'longitude'], sub.loc[present, 'latitude'], c='g', s=8, label='Presence')
    ax.add_feature(cartopy.feature.COASTLINE, linewidth=1)  # add coastlines
    ax.add_feature(cartopy.feature.LAND, facecolor='lightgrey')  # add land mask
    ax.add_geometries(bathym, facecolor='none', edgecolor='black', crs=cartopy.crs.PlateCarree())  # add bathymetry line
    ax.add_feature(cfeature.OCEAN, facecolor='azure')
    ax.set_title(species)
    ax.legend()

# Hide any panel that did not receive a species.
for ax in axs[len(abu):]:
    ax.set_visible(False)

fig.suptitle('Species presence by station')

def get_pace_path(trawl_df, short_name, granule_name="*.8*.4km*",
                  bbox=(-76.75, 33, -63, 46)):
    """Search Earthdata for PACE granules spanning the trawl dates and open them.

    Parameters
    ----------
    trawl_df : pandas.DataFrame
        Must have a ``date`` column; its minimum and maximum bound the
        temporal search window.
    short_name : str
        Collection short name passed to ``earthaccess.search_data``,
        e.g. ``"PACE_OCI_L3M_AVW"``.
    granule_name : str, optional
        Wildcard pattern for granule file names. The default keeps the
        original behaviour (presumably 8-day, 4 km files - TODO confirm
        against the PACE file-naming convention).
    bbox : tuple of float, optional
        Search box as (west, south, east, north) in degrees. Defaults to the
        original hard-coded box.

    Returns
    -------
    The result of ``earthaccess.open`` for the matching granules.
    """
    tspan = (trawl_df.date.min(), trawl_df.date.max())
    results = earthaccess.search_data(
        short_name=short_name,
        temporal=tspan,
        bounding_box=bbox,
        granule_name=granule_name,
    )
    return earthaccess.open(results)


def open_pace(paths):
    """Combine opened PACE granules into one dataset with a ``time`` axis.

    Each granule's ``time_coverage_start`` attribute is read and used as the
    time value of that granule. The combined dataset is cropped to a fixed
    latitude/longitude window, its ``lat``/``lon`` names are changed to
    ``latitude``/``longitude``, and the time values are converted with
    ``pd.to_datetime``.

    Parameters
    ----------
    paths : sequence
        Items accepted by ``xr.open_dataset`` / ``xr.open_mfdataset``.

    Returns
    -------
    The combined, cropped dataset.
    """
    # Start timestamp of every granule, read from its global attributes.
    start_times = [
        xr.open_dataset(granule).attrs['time_coverage_start'] for granule in paths
    ]

    # Stack the granules along a new 'datetime' dimension, attach the start
    # times as a 'time' coordinate, then rename the stacking dimension.
    stacked = xr.open_mfdataset(paths, combine='nested', concat_dim='datetime')
    stacked = stacked.assign_coords({'time': start_times})
    stacked = stacked.rename({'datetime': 'time'})

    # General spatial subset; everything outside the window is dropped.
    inside_window = (
        (stacked.lat > 34.40918)
        & (stacked.lat < 46.362305)
        & (stacked.lon < -63)
        & (stacked.lon > -77)
    )
    stacked = stacked.where(inside_window, drop=True)

    stacked = stacked.rename({'lat': 'latitude', 'lon': 'longitude'})
    stacked['time'] = pd.to_datetime(stacked.time)  # convert to pandas datetime
    return stacked

# trawl_df pre-processing: keep the text before the first space of each
# timestamp's string form as a 'date' column (read by get_pace_path).
trawl_df['date'] = [str(stamp).partition(' ')[0] for stamp in trawl_df.time]

# Get PACE data for each product.
paths_avw = get_pace_path(trawl_df, "PACE_OCI_L3M_AVW")
paths_chl = get_pace_path(trawl_df, "PACE_OCI_L3M_CHL")
paths_kd = get_pace_path(trawl_df, "PACE_OCI_L3M_KD")  # NOTE(review): not used below in this file

# Open each product as one dataset.
ds_avw = open_pace(paths_avw)
ds_chl = open_pace(paths_chl)

# Match data: look up each variable at the tow positions and times.
data = [ds_avw, ds_chl]  # list of datasets
var_names = ['avw', 'chlor_a']  # list of variable names
for d, (pace_ds, var_name) in enumerate(zip(data, var_names)):
    matched_df = match_nearest(trawl_df, pace_ds, var_name, var_name, date=trawl_df.time)
    if d == 0:
        # First variable: take the whole returned table (trawl_df + avw).
        trawl_df = matched_df
    else:
        # Later variables: copy only the new column in.
        trawl_df[var_name] = matched_df[var_name]

# Check for any missing values, which could indicate a mismatch
#print("\nMissing values count:")
#print(trawl_df[['avw','chlor_a']].isnull().sum())

# Natural Earth 'bathymetry_K_200' layer merged into one geometry for drawing.
bathym = cfeature.NaturalEarthFeature(name='bathymetry_K_200', scale='10m', category='physical')
bathym = unary_union(list(bathym.geometries()))

# One entry per map panel: (dataset, variable name, colormap, panel title).
panels = [
    (ds_avw, 'avw', cmocean.cm.solar, 'Apparent visible wavelength'),
    (ds_chl, 'chlor_a', cmocean.cm.algae, 'Chlorophyll'),
]

# Draw the first two time steps, or fewer if a dataset holds fewer (the
# hard-coded range(2) raised IndexError in that case).
n_maps = min(2, ds_avw.sizes['time'], ds_chl.sizes['time'])
for i in range(n_maps):
    fig, axs = plt.subplots(1, 2, figsize=(12, 6), subplot_kw={'projection': cartopy.crs.PlateCarree()})

    for ax, (ds, var_name, cmap, title) in zip(axs, panels):
        im = ax.pcolormesh(ds.longitude, ds.latitude, ds[var_name][i], cmap=cmap)
        ax.add_feature(cartopy.feature.COASTLINE, linewidth=1)  # add coastlines
        ax.add_feature(cartopy.feature.LAND, facecolor='lightgrey')  # add land mask
        ax.add_geometries(bathym, facecolor='none', edgecolor='black', crs=cartopy.crs.PlateCarree())  # add bathymetry line
        ax.set_title(title)
        fig.colorbar(im, ax=ax, shrink=0.65)

    # Date part (before the 'T') of the AVW time stamp labels the figure.
    fig.suptitle('PACE Variables: ' + str(ds_avw.time[i].values).split('T')[0], y=0.85)
    # plt.show() matches the rest of this file; Figure.show() emits a warning
    # on non-GUI backends and does not block in a plain script.
    plt.show()

# Boolean flag per tow: True where the matched avw value is NaN.
isna = [math.isnan(value) for value in trawl_df.avw]

# First 20 tows that are missing avw after the matchup.
missing_head = trawl_df[isna][:20]

map_projection = cartopy.crs.PlateCarree()  # set map projection
fig = plt.figure(figsize=(10, 7))  # set figure size
ax = plt.axes(projection=map_projection)

# First chlor_a field as background, missing-avw tows as red points on top.
plt.pcolormesh(ds_chl.longitude, ds_chl.latitude, ds_chl.chlor_a[0], cmap=cmocean.cm.algae)
plt.scatter(missing_head.longitude, missing_head.latitude, s=8, c='r')

ax.add_feature(cartopy.feature.COASTLINE, linewidth=1)  # add coastlines
ax.add_feature(cartopy.feature.LAND, facecolor='lightgrey')  # add land mask
ax.add_geometries(bathym, facecolor='none', edgecolor='black', crs=cartopy.crs.PlateCarree())  # add bathymetry line
plt.title('8-Day chlor_a \n Visualize missing data after matchup')

def get_pace_nan_replace(trawl_df, short_name, granule_name="*.M*.4km*",
                         bbox=(-76.75, 33, -63, 46)):
    """Search and open the PACE granules used to fill NaN matchups.

    Same search as the 8-day lookup elsewhere in this file, but with a
    different default file-name pattern (the original comment marks it as the
    monthly product), so gaps in the shorter composites can be filled.

    Parameters
    ----------
    trawl_df : pandas.DataFrame
        Must have a ``date`` column; its minimum and maximum bound the
        temporal search window.
    short_name : str
        Collection short name passed to ``earthaccess.search_data``.
    granule_name : str, optional
        Wildcard pattern for granule file names; the default keeps the
        original behaviour.
    bbox : tuple of float, optional
        Search box as (west, south, east, north) in degrees. Defaults to the
        original hard-coded box.

    Returns
    -------
    The result of ``earthaccess.open`` for the matching granules.
    """
    tspan = (trawl_df.date.min(), trawl_df.date.max())
    results = earthaccess.search_data(
        short_name=short_name,
        temporal=tspan,
        bounding_box=bbox,
        granule_name=granule_name,
    )
    return earthaccess.open(results)

def extract_scalar(val):
    """Unwrap a one-element container into a plain scalar.

    The matched satellite values can arrive wrapped in arrays or lists; this
    turns them into values that ``pd.to_numeric`` can handle.

    Parameters
    ----------
    val : object
        A list, a NumPy array, any object exposing ``.item()``, or an
        already-plain value.

    Returns
    -------
    object
        * list / ndarray holding exactly one element (at any nesting depth):
          that element as a Python scalar;
        * list / ndarray of any other size (including empty): ``np.nan``;
        * other objects with ``.item()``: the result of ``.item()``, or
          ``np.nan`` if that call raises;
        * anything else: ``val`` unchanged.
    """
    if isinstance(val, (list, np.ndarray)):
        arr = np.asarray(val)
        # .item() also unwraps nested single-element input such as [[x]],
        # which plain val[0] returned as a list.
        return arr.item() if arr.size == 1 else np.nan
    if hasattr(val, 'item'):  # NumPy scalars, pandas/xarray wrappers, etc.
        try:
            return val.item()
        except Exception:
            # Best effort: not reducible to one value. Exception (rather than
            # a bare except) lets KeyboardInterrupt / SystemExit propagate.
            return np.nan
    return val

# Convert the matched columns to float64: unwrap array-like entries first,
# then coerce anything still non-numeric to NaN.
for var_name in ('avw', 'chlor_a'):
    trawl_df[var_name] = pd.to_numeric(trawl_df[var_name].apply(extract_scalar), errors='coerce')

# Index labels of the rows that are still NaN for each variable.
na_index_avw = trawl_df.index[trawl_df.avw.isna()]
na_index_chl = trawl_df.index[trawl_df.chlor_a.isna()]

# Tables holding only those rows, with the empty column removed so the
# matchup can add it back.
trawl_df_avw_nan_only = trawl_df.loc[na_index_avw].drop('avw', axis=1)
trawl_df_chl_nan_only = trawl_df.loc[na_index_chl].drop('chlor_a', axis=1)

# Open the replacement (monthly) products.
nan_avw_path = get_pace_nan_replace(trawl_df, "PACE_OCI_L3M_AVW")
nan_chl_path = get_pace_nan_replace(trawl_df, "PACE_OCI_L3M_CHL")
nan_avw_ds = open_pace(nan_avw_path)
nan_chl_ds = open_pace(nan_chl_path)

for var_name, na_index, nan_only_df, monthly_ds in (
    ('avw', na_index_avw, trawl_df_avw_nan_only, nan_avw_ds),
    ('chlor_a', na_index_chl, trawl_df_chl_nan_only, nan_chl_ds),
):
    if len(na_index) == 0:
        continue  # nothing to fill for this variable

    # Pass a table and a date series that share the same (reset) index.
    nan_only_reset = nan_only_df.reset_index()
    matched_monthly = match_nearest(nan_only_reset, monthly_ds, var_name, var_name, date=nan_only_reset.time)
    filled = pd.to_numeric(matched_monthly[var_name].apply(extract_scalar), errors='coerce')

    # Assign by position. Assigning the Series itself would align on index
    # labels, and the matched table was built from a reset index, so values
    # could land on the wrong rows or turn into NaN.
    # NOTE(review): assumes match_nearest returns rows in input order.
    trawl_df.loc[na_index, var_name] = filled.to_numpy()

# Open the Copernicus Marine dataset restricted to the study-area box.
glorys_ds = copernicusmarine.open_dataset(
    dataset_id='cmems_mod_glo_phy_myint_0.083deg_P1D-m',
    minimum_longitude=-77,
    maximum_longitude=-63,
    minimum_latitude=34,
    maximum_latitude=46,
)

# Keep time steps after 2023 whose month is March, April or May.
in_window = (
    (glorys_ds['time.year'] > 2023)
    & (glorys_ds['time.month'] > 2)
    & (glorys_ds['time.month'] < 6)
)
glorys_subset = glorys_ds.where(in_window, drop=True)
glorys_subset['time'] = [pd.Timestamp(d) for d in glorys_subset.time.values]

# Four time slices: first, one third, two thirds and last (modify if needed).
n_steps = len(glorys_subset.time)
time_indices = [0, int(n_steps / 3), int(2 * n_steps / 3), -1]

# Top row: bottom temperature; bottom row: mixed layer thickness.
fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(16, 8))

for i, t in enumerate(time_indices):
    ax = axes[0, i]
    glorys_subset.bottomT.isel(time=t).plot(ax=ax, cmap="viridis")
    ax.set_title(f"Bottom Temperature - Time {t}")

    ax = axes[1, i]
    glorys_subset.mlotst.isel(time=t).plot(ax=ax, cmap="plasma")
    ax.set_title(f"Mixed Layer Thickness - Time {t}")

# Adjust layout
plt.tight_layout()
plt.show()

# --- Vectorized Extraction ---

# 1. Coordinate arrays that share one 'observation' dimension, so .sel()
#    returns one value per tow instead of a full time x lat x lon cube.
times = xr.DataArray(trawl_df.time.values, dims="observation")
lats = xr.DataArray(trawl_df.latitude.values, dims="observation")
lons = xr.DataArray(trawl_df.longitude.values, dims="observation")

# 2. Extract the variables without a depth dimension (bottomT, mlotst) in one go.
# NOTE(review): method="nearest" matches each tow to the closest available
# time step however far away it is; tows outside the glorys_subset period
# still receive a value.
extracted_3d_data = glorys_subset[['bottomT', 'mlotst']].sel(
    time=times,
    latitude=lats,
    longitude=lons,
    method="nearest"
)

# 3. Salinity (so) has a depth dimension. Taking .isel(depth=-1) returns the
#    model's last depth level, which is NaN wherever the seafloor is
#    shallower than that level. Take the deepest level that holds data at
#    each tow instead.
#    NOTE(review): assumes cells below the seafloor are NaN and that depth
#    increases along the 'depth' axis - confirm for this product.
so_profiles = glorys_subset['so'].sel(
    time=times,
    latitude=lats,
    longitude=lons,
    method="nearest"
).transpose('observation', 'depth')

so_values = so_profiles.values  # (observation, depth)
has_data = ~np.isnan(so_values)
n_levels = so_values.shape[1]
# Position of the last non-NaN level per row. Rows with no data resolve to
# the last level, whose value is NaN, so they stay NaN.
deepest_valid = n_levels - 1 - np.argmax(has_data[:, ::-1], axis=1)
extracted_so = xr.DataArray(
    so_values[np.arange(so_values.shape[0]), deepest_valid],
    dims="observation",
)

# --- Merge results back into your original DataFrame ---

# Add the extracted bottom temperature and mixed layer thickness
trawl_df['bottom_temp'] = extracted_3d_data['bottomT'].values
trawl_df['mld'] = extracted_3d_data['mlotst'].values

# Add the extracted bottom salinity
trawl_df['bottom_salinity'] = extracted_so.values

# --- Display the final merged DataFrame ---
print("\n--- FINAL MERGED DATA ---")
# Displaying relevant columns for verification
final_columns = [
    'time', 'latitude', 'longitude', 'MEAN_DEPTH',
    'bottom_temp', 'mld', 'bottom_salinity'
]
print(trawl_df[final_columns].head())

# Check for any missing values, which could indicate a mismatch
print("\nMissing values count:")
print(trawl_df[final_columns].isnull().sum())

# Natural Earth 'bathymetry_K_200' layer merged into one geometry for drawing.
bathym = cfeature.NaturalEarthFeature(name='bathymetry_K_200', scale='10m', category='physical')
bathym = unary_union(list(bathym.geometries()))

# 2x3 grid of map axes, flattened so panels can be visited in order.
fig, axs = plt.subplots(2, 3, figsize=(12, 6), subplot_kw={'projection': cartopy.crs.PlateCarree()})
axs = axs.flatten()

# Columns to map and their panel titles (same order).
var = ['mld', 'bottom_temp', 'bottom_salinity', 'chlor_a', 'avw', 'Rrs_brightness']
var_n = ['Mixed layer depth', 'Bottom Temperature', 'Bottom Salinity', 'Chlorophyll a', 'AVW', 'RRS brightness']

# Colormaps for specific columns; every other column uses cmocean 'deep'.
special_cmaps = {
    'bottom_temp': cmocean.cm.thermal,
    'chlor_a': cmocean.cm.algae,
    'Rrs_brightness': cmocean.cm.solar,
}

# One panel per variable: tows coloured by that variable's value.
for i, ax in enumerate(axs):
    column = var[i]
    cmap = special_cmaps.get(column, cmocean.cm.deep)
    column_values = trawl_df[column]
    im = ax.scatter(
        trawl_df.longitude,
        trawl_df.latitude,
        c=column_values,
        label=column,
        s=8,
        cmap=cmap,
        vmin=column_values.min(),
        vmax=column_values.max(),
    )
    ax.add_feature(cartopy.feature.COASTLINE, linewidth=1)  # add coastlines
    ax.add_feature(cartopy.feature.LAND, facecolor='lightgrey')  # add land mask
    ax.add_geometries(bathym, facecolor='none', edgecolor='black', crs=cartopy.crs.PlateCarree())  # add bathymetry line
    ax.add_feature(cfeature.OCEAN, facecolor='azure')
    fig.colorbar(im, ax=ax, shrink=0.95)
    ax.set_title(var_n[i])
fig.suptitle('Matched variables')

# Unwrap any remaining array-like entries so both columns hold plain scalars.
for column in ('chlor_a', 'avw'):
    trawl_df[column] = [extract_scalar(entry) for entry in trawl_df[column]]

# Catch divided by the swept-area column, for three species.
swept_area = trawl_df['SWEPT_AREA_km']
cpue_butterfish = trawl_df['butterfish'] / swept_area
cpue_squid = trawl_df['longfin squid'] / swept_area
cpue_hake = trawl_df['silver hake'] / swept_area

# Base-10 logarithm of chlorophyll, stored as its own column.
trawl_df['chlora_log10'] = [np.log10(entry) for entry in trawl_df['chlor_a']]

import xgboost as xgb
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from sklearn.inspection import PartialDependenceDisplay

# Work on a copy so the modelling steps cannot alter trawl_df.
df = trawl_df.copy()
print(df)

# --- Define your target (y) and predictor (X) variables ---
TARGET_SPECIES = ['butterfish', 'silver hake', 'longfin squid', 'spot']

# Predictors: environmental columns only (no species, location or metadata).
# 'Rrs_brightness' is spelled as in the mapping code earlier in this file;
# column lookups are case-sensitive, so 'rrs_brightness' did not match it.
PREDICTOR_COLUMNS = ['avw', 'chlor_a', 'MEAN_DEPTH', 'bottom_temp', 'mld', 'Rrs_brightness']

# Fail early with a readable message if an expected column is absent.
missing_columns = [c for c in TARGET_SPECIES + PREDICTOR_COLUMNS if c not in df.columns]
if missing_columns:
    raise KeyError(f"Columns missing from the merged table: {missing_columns}")

# log1p keeps zero catches finite and compresses the large counts.
y = np.log1p(df[TARGET_SPECIES])
X_df = df[PREDICTOR_COLUMNS]

# 80/20 train/validation split with a fixed seed for reproducibility.
X_train, X_val, y_train, y_val = train_test_split(X_df, y, test_size=0.2, random_state=42)

print("\n--- Step 2: Initializing the XGBoost Model ---")

# Let's define the core parameters.
xgb_reg = xgb.XGBRegressor(
    # --- Boosting Parameters (How the model learns) ---
    n_estimators=1000,         # Analogous to the "richness" of the model. Upper bound on the number of trees.
                               # We set this high and let early stopping pick the actual number.
    learning_rate=0.05,        # Analogous to a shrinkage parameter. Lower values make the model more robust.
    early_stopping_rounds=50,  # Stop once the validation metric has not improved for 50 rounds.
                               # Without this, all 1000 trees were always built and the
                               # "early stopping" described above never happened.
                               # Requires an eval_set to be passed to fit().

    # --- Tree Complexity Parameters ---
    max_depth=5,               # Analogous to `k` in s(x, k=...). Controls max interaction depth.
    min_child_weight=1,        # A form of regularization. Prevents learning highly specific patterns.
    gamma=0.1,                 # Analogous to `sp`. A value > 0 penalizes splits, making the model more conservative.
    subsample=0.8,             # Use 80% of data for building each tree. Adds randomness to fight overfitting.
    colsample_bytree=0.8,      # Use 80% of features for building each tree. Also for overfitting.

    # --- Regularization Parameters ---
    reg_alpha=0.005,           # L1 regularization on leaf weights.
    reg_lambda=1,              # L2 regularization on leaf weights.

    # --- Technical Parameters ---
    objective='reg:squarederror', # The loss function to optimize.
    n_jobs=-1,                 # Use all available CPU cores.
    random_state=42,           # For reproducibility.
    eval_metric='rmse'         # Metric to monitor during training.
)

print("Model initialized with parameters:")
print(xgb_reg.get_params())

# Train while tracking the validation metric after every boosting round.
xgb_reg.fit(
    X_train,
    y_train,
    eval_set=[(X_val, y_val)],
    verbose=False,  # Set verbose=True to see the training progress
)

print(f"Model training complete.")

# Per-round validation RMSE; the smallest value marks the best round.
results = xgb_reg.evals_result()
validation_rmse = results['validation_0']['rmse']
best_iteration = np.argmin(validation_rmse)
best_score = validation_rmse[best_iteration]

print(f"Best iteration found: {best_iteration}")
print(f"Best validation RMSE: {best_score:.4f}")

print("\n--- Step 4: Making Predictions ---")
predictions = xgb_reg.predict(X_val)

# Evaluate the model: RMSE between validation targets and predictions.
rmse = np.sqrt(mean_squared_error(y_val, predictions))
print(f"Final RMSE on validation data: {rmse:.4f}")

# --- Step 5: Model Interpretation (The `mgcv::plot.gam` Analogy) ---

print("\n--- Step 5: Interpreting the Model ---")

# 5a. Feature Importance: bar chart of up to ten features.
print("Plotting feature importance...")
fig, ax = plt.subplots(figsize=(10, 8))
importance_options = dict(
    ax=ax,
    max_num_features=10,
    height=0.8,
    title="Feature Importance",
)
xgb.plot_importance(xgb_reg, **importance_options)
plt.tight_layout()
plt.show()

import shap

# 5b. SHAP Values
print("\nCalculating and plotting SHAP values...")
# Create a SHAP explainer object
explainer = shap.TreeExplainer(xgb_reg)

# Calculate SHAP values for the validation set. The result was stored as
# `values` but plotted as `shap_values`, which raised NameError; the old name
# is kept as an alias.
shap_values = explainer.shap_values(X_val)
values = shap_values

# The model has several targets, so split the result into one 2-D
# (samples x features) array per target.
# NOTE(review): depending on the shap version a multi-target result is
# either a list of arrays or one 3-D array; the 3-D layout is assumed to be
# (samples, features, targets) - confirm for the installed version.
if isinstance(shap_values, list):
    per_target_values = shap_values
elif np.ndim(shap_values) == 3:
    per_target_values = [shap_values[:, :, k] for k in range(np.shape(shap_values)[2])]
else:
    per_target_values = [shap_values]

# Summary plot: Shows the distribution of impacts for each feature.
# Red means high feature value, blue means low.
# Points to the right increase the prediction, points to the left decrease it.
for k, target_values in enumerate(per_target_values):
    shap.summary_plot(target_values, X_val, plot_type="dot", show=False)
    plot_title = "SHAP Summary Plot for Fisheries Data"
    if len(per_target_values) > 1 and k < len(TARGET_SPECIES):
        plot_title += f" ({TARGET_SPECIES[k]})"
    plt.title(plot_title)
    plt.tight_layout()
    plt.show()

from sklearn.inspection import PartialDependenceDisplay

# --- Step 6: Create Partial Dependence Plots ---
print(f"\n--- Creating Partial Dependence Plots for {TARGET_SPECIES} ---")

# Accept a single species name as well as a list of names.
species_names = [TARGET_SPECIES] if isinstance(TARGET_SPECIES, str) else list(TARGET_SPECIES)

# The PDP shows the marginal effect one or two features have on the predicted outcome.
# It's the closest equivalent to the `plot.gam()` smooths.
# The model is fit on several targets, and scikit-learn needs `target=` to
# know which output to plot; without it every call failed and the except
# branch only printed the error. One figure is drawn per predictor/species pair.
for predictor in PREDICTOR_COLUMNS:
    if predictor not in X_train.columns:
        continue
    for target_idx, species in enumerate(species_names):
        fig, ax = plt.subplots(figsize=(8, 6))
        try:
            pdp_display = PartialDependenceDisplay.from_estimator(
                estimator=xgb_reg,
                X=X_train,
                features=[predictor],
                target=target_idx,
                ax=ax
            )
        except Exception as e:
            plt.close(fig)  # do not leave an empty figure behind
            print(f"Could not create PDP for {predictor} ({species}). Error: {e}")
            continue

        # A single `ax` is used only as a bounding box: from_estimator turns
        # it off and draws in axes it creates itself, so labels set on `ax`
        # would not show. Label the axes that hold the curve.
        pdp_ax = pdp_display.axes_.ravel()[0]
        pdp_ax.set_title(f"Partial Dependence Plot for {predictor}\n({species})")
        # The target is log1p of the raw species columns, not CPUE.
        pdp_ax.set_ylabel("Partial Dependence (log1p count scale)")
        pdp_ax.grid(True, linestyle='--', alpha=0.6)
        plt.tight_layout()
        plt.show()
