In [None]:
import os

import numpy as np
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
import rasterio
import verde as vd
import warnings
warnings.filterwarnings('ignore')

from sklearn.preprocessing import MultiLabelBinarizer
from shapely.geometry import Point

# Data file paths

In [None]:
# The text files containing the sPlotOpen data are available at:
# https://idata.idiv.de/ddm/Data/ShowData/3474?version=55
# Both sPlotOpen files are read below as tab-delimited tables.
sPlotOpen_occurrences_file = "sPlotOpen_DT(1).txt"  # species occurrences per plot observation
sPlotOpen_metadata_file = "sPlotOpen_header(2).txt"  # per-plot metadata (IDs, longitude, latitude)
# Folder expected to hold the rasters named `wc2.1_30s_bio_<n>.tif` (n = 1..19).
worldclim_folder = "worldclim"
# Folder whose every entry is opened as a raster; the first six characters of
# each file name are used as the soil variable name.
soilgrids_folder = "soilgrids250"

# Occurrence data

In [None]:
# Load the tab-separated occurrence table and preview its first rows.
df = pd.read_csv(sPlotOpen_occurrences_file, sep="\t")
df.head()

In [None]:
# Build a stable species-name -> column-index mapping from the occurrence table.
#
# Missing species names (NaN) are dropped up front. The earlier approach looked
# for the position of a non-string entry and deleted it, which raised a
# NameError whenever the table contained no missing name at all.
species_list = np.sort(df["Species"].dropna().unique())

# Number of *named* species; a missing name is not counted as a species.
num_species = len(species_list)

# Indices follow the alphabetical order of the species names.
species2ind = {species: i for i, species in enumerate(species_list)}

# Persist the mapping so the columns of the occurrence matrix can be decoded later.
pd.DataFrame(
    list(species2ind.items()), columns=['Species Name', 'Index']
).to_csv("species_names.csv", index=False)

In [None]:
# Collect, for every plot observation, the list of species indices recorded in it.
grouped = df.groupby('PlotObservationID').aggregate({'Species': list})

# Translate names to indices; rows with a missing species name are skipped.
grouped['Species'] = grouped['Species'].apply(
    lambda names: [species2ind[name] for name in names if not pd.isna(name)]
)

# Row index <-> plot ID lookups, both derived from the same ordered list of IDs.
site_ids = grouped.index.values.tolist()
site2ind = {site: i for i, site in enumerate(site_ids)}
ind2site = dict(enumerate(site_ids))

# Pin the label set to every index in `species2ind`. Without `classes`, the
# binarizer only creates columns for labels it actually sees, so a species that
# is absent from all grouped plots would shift every following column and break
# the "column j == species index j" correspondence.
mlb = MultiLabelBinarizer(classes=list(range(len(species2ind))))
species_encoded = mlb.fit_transform(grouped['Species'])
print(species_encoded.shape)

In [None]:
# Write the plot-by-species matrix to disk, stored as booleans.
np.save(
    "species_occurrences.npy",
    species_encoded.astype(np.bool_),
)

# Extract predictors

In [None]:
# Load the tab-separated plot metadata table (one row per plot observation).
plots_metadata = pd.read_csv(sPlotOpen_metadata_file, sep="\t")

In [None]:
# Fill missing data of WorldClim and SoilGrids with nearest non-missing value

def find_nearest_non_missing(data, row, col, no_data_value, max_radius=100):
    """Return the closest valid value around ``(row, col)`` in a 2-D raster.

    Square windows of growing half-width (1, 2, ..., ``max_radius`` cells)
    centred on ``(row, col)`` are inspected in turn. The first window that
    contains at least one cell different from ``no_data_value`` decides the
    result; when it contains several, the first one in row-major order (top
    row first, left to right) is returned. "Closest" therefore means
    square-ring (Chebyshev) distance in cells, not Euclidean distance.

    Parameters
    ----------
    data : numpy.ndarray
        Two-dimensional array of raster values.
    row, col : int
        Cell around which to search. They may lie outside ``data``; only the
        part of each window that overlaps the array is examined.
    no_data_value : scalar
        Marker for missing cells. Cells are compared with ``np.isclose`` using
        ``atol=0``, i.e. a purely relative tolerance.
    max_radius : int, optional
        Largest window half-width (in cells) to try.

    Returns
    -------
    scalar or None
        The value found, converted to a Python scalar, or ``None`` when no
        valid cell lies within ``max_radius``.
    """
    rows, cols = data.shape
    for radius in range(1, max_radius + 1):
        # Clip the window to the array bounds.
        top, bottom = max(row - radius, 0), min(row + radius + 1, rows)
        left, right = max(col - radius, 0), min(col + radius + 1, cols)
        if top >= bottom or left >= right:
            # The window does not overlap the raster at this radius.
            continue

        window = data[top:bottom, left:right]
        # One vectorised comparison per window instead of one call per cell.
        valid = ~np.isclose(window, no_data_value, atol=0)
        if valid.any():
            # argmax on the flattened boolean mask yields the first True in
            # row-major order, matching a top-to-bottom, left-to-right scan.
            r, c = np.unravel_index(np.argmax(valid), valid.shape)
            return window[r, c].item()
    return None  # no valid value within max_radius

## Location

In [None]:
# Export the plot identifier and coordinates (the row index is written as well).
plots_metadata.to_csv(
    "location_data.csv",
    columns=["PlotObservationID", "Longitude", "Latitude"],
)

## WorldClim

In [None]:
# Sample the 19 WorldClim bioclimatic rasters at every plot location.
# Samples equal to the no-data marker are replaced by the nearest valid cell.
locations = plots_metadata[["Longitude", "Latitude"]].values

worldclim_variables = ['bio_' + str(i+1) for i in range(19)]
worldclim_data = np.zeros((len(locations), len(worldclim_variables)), dtype="float32")

# NOTE(review): hard-coded marker; presumably matches the rasters' declared
# nodata value -- confirm against `src.nodata`.
no_data_value = -3.4e+38

for j, wv in enumerate(worldclim_variables):
    print(f"Processing {wv}")
    with rasterio.open(f"{worldclim_folder}/wc2.1_30s_{wv}.tif") as src:

        # The full band is only needed for the nearest-neighbour fallback, so
        # it is read lazily the first time a missing sample is encountered.
        data = None
        for i, sampled in enumerate(src.sample(locations)):
            # `sample` yields one array per location with one entry per band;
            # take the single band explicitly instead of relying on implicit
            # 1-element-array -> scalar conversion.
            val = sampled[0]
            if np.isclose(val, no_data_value, atol=0):
                if data is None:
                    data = src.read(1)
                x, y = locations[i]
                row, col = src.index(x, y)
                val = find_nearest_non_missing(data, row, col, no_data_value)
                if val is None:
                    # No valid cell within the search radius: keep the entry
                    # missing rather than failing on `None` assignment.
                    val = np.nan
            worldclim_data[i, j] = val

In [None]:
# Turn the sampled matrix into a labelled table, attach the plot identifiers
# (aligned on the row index), and show summary statistics.
worldclim_data = pd.DataFrame(
    worldclim_data,
    columns=worldclim_variables,
).assign(PlotObservationID=plots_metadata["PlotObservationID"])
worldclim_data.describe()

In [None]:
# Export the WorldClim predictors; the row index is kept as the first column.
worldclim_data.to_csv("worldclim_data.csv", index=True)

## SoilGrids

In [None]:
# Sample every SoilGrids raster found in `soilgrids_folder` at each plot location.
# Samples equal to the layer's no-data marker are replaced by the nearest valid cell.
locations = plots_metadata[["Longitude", "Latitude"]].values

# `os.listdir` returns entries in arbitrary order; sorting makes the column
# order reproducible, and sizing the array from the listing avoids a hard-coded
# number of layers.
soil_files = sorted(os.listdir(soilgrids_folder))
soilgrid_data = np.zeros((len(locations), len(soil_files)))
soil_variables = []

for j, soil_file in enumerate(soil_files):
    # The variable name is the first six characters of the file name.
    soil_variable = soil_file[:6]
    soil_variables.append(soil_variable)
    print(f"Processing {soil_variable}")

    # No-data marker per variable; checked before opening the raster so an
    # unexpected file fails fast.
    if soil_variable in ["ORCDRC", "CECSOL", "BDTICM", "BLDFIE"]:
        no_data_value = -32768.0
    elif soil_variable in ["PHIHOX", "CLYPPT", "SLTPPT", "SNDPPT"]:
        no_data_value = 255
    else:
        raise ValueError(f"Unknown missing value for {soil_variable}")

    with rasterio.open(f"{soilgrids_folder}/{soil_file}") as src:
        # The full band is only needed for the nearest-neighbour fallback, so
        # it is read lazily the first time a missing sample is encountered.
        data = None
        for i, sampled in enumerate(src.sample(locations)):
            # One entry per band; take the single band explicitly.
            val = sampled[0]
            if val == no_data_value:
                if data is None:
                    data = src.read(1)
                x, y = locations[i]
                row, col = src.index(x, y)
                val = find_nearest_non_missing(data, row, col, no_data_value)
                if val is None:
                    # No valid cell within the search radius: keep the entry
                    # missing rather than failing on `None` assignment.
                    val = np.nan
            soilgrid_data[i, j] = val

In [None]:
# Turn the sampled matrix into a labelled table, attach the plot identifiers
# (aligned on the row index), and show summary statistics.
soilgrid_data = pd.DataFrame(
    soilgrid_data,
    columns=soil_variables,
).assign(PlotObservationID=plots_metadata["PlotObservationID"])
soilgrid_data.describe()

In [None]:
# Export the SoilGrids predictors; the row index is kept as the first column.
soilgrid_data.to_csv("soilgrid_data.csv", index=True)

# Split the data into training, validation, and test sets

In [None]:
# Reload the exported plot coordinates from disk (comma-separated).
locations = pd.read_csv("location_data.csv", sep=",")

In [None]:
# Train / validation / test split of the plots, done in two passes with the
# same seed and the same `spacing` passed to verde.
split_seed = 42

spacing = 1
test_size = 0.15
val_size = 0.15

locations = pd.read_csv("location_data.csv")
coordinates = locations[["Longitude", "Latitude"]].to_numpy()

# Positional indices of the plots; these are what actually gets split.
data_indices = np.arange(coordinates.shape[0])

# Pass 1: separate the test set from everything else.
train_block, test_block = vd.train_test_split(
    coordinates.T,
    data_indices,
    test_size=test_size,
    spacing=spacing,
    random_state=split_seed,
)
train_indices = train_block[1][0]
test_indices = test_block[1][0]

# Pass 2: split the remainder into training and validation. The fraction is
# rescaled because it now applies to the non-test portion only.
remaining_val_fraction = val_size/(1-test_size)
train_block, val_block = vd.train_test_split(
    coordinates[train_indices].T,
    train_indices,
    test_size=remaining_val_fraction,
    spacing=spacing,
    random_state=split_seed,
)
train_indices = train_block[1][0]
val_indices = val_block[1][0]

In [None]:
# Map of the train / validation / test plots on a world basemap.

# `gpd.datasets` was deprecated in GeoPandas 0.14 and removed in 1.0. Fall back
# to the Natural Earth 110m countries layer (the source of 'naturalearth_lowres')
# when the bundled dataset is unavailable.
# NOTE(review): the fallback needs network access -- point it at a local copy
# of the shapefile if the notebook must run offline.
try:
    world_source = gpd.datasets.get_path('naturalearth_lowres')
except AttributeError:
    world_source = "https://naciscdn.org/naturalearth/110m/cultural/ne_110m_admin_0_countries.zip"
world = gpd.read_file(world_source)

fig, ax = plt.subplots(figsize=(9, 8))

palette = ["#7a69e7", "#62ada8", "#eaa37f"]

world.plot(ax=ax, color='lightgray')

markersize = 0.01

# One scatter layer per split, drawn in train -> valid -> test order.
split_layers = [
    ("train", train_indices),
    ("valid", val_indices),
    ("test", test_indices),
]
for color, (label, indices) in zip(palette, split_layers):
    split_xy = coordinates[indices]
    # Vectorised point construction instead of one Point(...) call per plot.
    gdf = gpd.GeoDataFrame(geometry=gpd.points_from_xy(split_xy[:, 0], split_xy[:, 1]))
    gdf.plot(ax=ax, color=color, markersize=markersize, label=label)

# Strip all axis decoration: the figure is a bare map.
ax.grid(False)
ax.set_xticks([])
ax.set_yticks([])
for side in ('top', 'right', 'left', 'bottom'):
    ax.spines[side].set_visible(False)

legend = ax.legend(loc='lower left')

for handle in legend.legend_handles:
    handle.set_sizes([20]) # increase the size of the markers in the legend

ax.margins(0)

# Lower latitude limit of -63 crops the southernmost part of the basemap.
ax.set_ylim((-63, 90))

plt.savefig("splits_sPlotOpen.png", dpi=1000, bbox_inches='tight')

# 