# Threatened Species Index: Calculation and Database Import

The following steps cover the end-to-end process of calculating the Threatened Species Index from raw data, creating a dedicated table in the MySQL database, and importing the final results for storage and use.

Name: Zihan

### Workflow (Simplified GAM Method)

1.  **Read Raw Data**: Load the initial dataset from the source CSV file.
2.  **Select Columns**: Filter the data to include only the year columns from 1950 to 2020.
3.  **Process Each Species**: For every species in the dataset, perform the following steps:
    * **Extract Time Series**: Isolate the time series for the species. Exclude any series that is entirely empty or is missing data for the designated reference year.
    * **Normalize Data**: Set 1985 as the reference year. Standardize the time series by dividing all values by the value in 1985, making the reference year's value equal to `1`.
    * **Fit GAM Model**: Apply a Generalized Additive Model (using the `pyGAM` library) to fit a smoothed trend to the normalized time series.
    * **Predict Trend**: Use the fitted GAM to predict the trend value for each year from 1950 to 2020.
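    * **Calculate Species-Level Index**: The predicted series is the species-level index. Because the input was already normalized to 1985, the script below uses the predictions as they are rather than dividing them again by the predicted 1985 value; the aggregated index is instead set to `1` at the reference year.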
4.  **Aggregate to Global Level**:
    * Combine the indices from all species by calculating the average value for each year.
    * (Optional) Generate bootstrap confidence intervals (`low`, `high`) for the annual averages.
5.  **Export Results**: Save the final aggregated data with the columns: `year`, `value`, `low`, `high`.
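
The normalization in step 3 can be sketched on a single toy series. This is a minimal illustration in plain Python, assuming missing years are stored as `None`; the helper name `normalize_to_reference` and the numbers are hypothetical and not part of the project code.

```python
from typing import Optional, Sequence


def normalize_to_reference(
    years: Sequence[int],
    values: Sequence[Optional[float]],
    ref_year: int = 1985,
) -> Optional[list]:
    """Divide a series by its reference-year value; return None if unusable."""
    if ref_year not in years:
        return None
    ref_value = values[list(years).index(ref_year)]
    # Skip series whose reference value is missing or not positive.
    if ref_value is None or ref_value <= 0:
        return None
    return [None if v is None else v / ref_value for v in values]


# Toy data (hypothetical): one species, four years, one gap.
years = [1984, 1985, 1986, 1987]
counts = [50.0, 40.0, None, 20.0]
relative = normalize_to_reference(years, counts)
print(relative)  # [1.25, 1.0, None, 0.5]
```

A series with no usable 1985 value returns `None`, which mirrors the exclusion rule described under "Extract Time Series".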

### Step 1 - Calculate Threatened Species Index using GAM and Bootstrapping

This script calculates a composite macro-level index from a large number of individual species' time-series data. It employs a Generalized Additive Model (GAM) to smooth the trend for each species and uses bootstrapping to estimate the confidence interval of the overall trend.

The workflow can be broken down into the following key steps:

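1.  **Initialization and Data Loading**: The script sets the key parameters: the reference year (`ref_year = 1985`), whether to drop the most recent years from the output (`apply_lag`, with `lag_years = 3`), the number of spline terms (`n_splines = 10`), and the bootstrap settings (`n_boot` and a seeded random generator). It then loads the raw time-series data from the CSV file.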
2.  **Per-Species Processing**:
    * **Normalization**: For each species' time series, the data is standardized against the value in the reference year, effectively setting the index for `1985 = 1`.
    * **Smoothing and Fitting**: A GAM is fitted to the normalized data. This generates a smooth trend curve, which helps to fill in data gaps and reduce year-to-year noise.
3.  **Aggregation and Statistical Analysis**:
    * **Calculate Mean Trend**: The smoothed curves are restricted to the years from 1985 onward (minus the final three years when the lag is applied) and averaged across species for each year. This gives the main index trend (`value`).
    * **Estimate Uncertainty**: To quantify confidence in the mean trend, the species are resampled with replacement 1,000 times (`n_boot=1000`). The mean trend is recalculated for each resample, and the 2.5th and 97.5th percentiles of the resulting distribution give the 95% confidence interval (`low` and `high`).
4.  **Export Results**: The final calculated columns (`year`, `value`, `low`, `high`) are saved to a new CSV file, ready for visualization or further analysis.

In [10]:
import pandas as pd
import numpy as np
from pygam import LinearGAM, s

# ====== Paths ======
input_path = r"01_raw_data\02_tsx-aggregated-data-dataset.csv"
output_path = r"02_wrangled_data\Table14_ThreatenedSpeciesIndexTable_version2.csv"

# ====== Parameters ======
ref_year = 1985
apply_lag = True          # whether to apply the 3-year lag
lag_years = 3
n_splines = 10
n_boot = 1000
rng = np.random.default_rng(42)

# ====== Load data ======
df = pd.read_csv(input_path)

# Year columns: every column whose name consists only of digits
year_cols = [c for c in df.columns if c.isdigit()]
years = np.array(list(map(int, year_cols)))
if ref_year not in years:
    raise ValueError(f"Reference year {ref_year} is not among the year columns!")
ref_idx = list(years).index(ref_year)

# ====== Fit a GAM to each time series (after normalizing to 1985) ======
species_curves = []
for _, row in df.iterrows():
    y = row[year_cols].values.astype(float)

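    # The reference-year value must be present and > 0
    if np.isnan(y[ref_idx]) or y[ref_idx] <= 0:
        continue

    # Express values relative to the reference year (1985 = 1)
    y = y / y[ref_idx]

    # Fit only on years that have observations
    mask = ~np.isnan(y)
    x_obs = years[mask]
    y_obs = y[mask]

    # At least 3 points are needed for a fit
    if len(x_obs) < 3:
        continue

    try:
        gam = LinearGAM(s(0, n_splines=n_splines)).fit(x_obs, y_obs)
        y_pred = gam.predict(years)  # predict over the full period
        species_curves.append(y_pred)
    except Exception:
        # Some extreme series may fail to fit; skip them
        continue

if len(species_curves) == 0:
    raise RuntimeError("No species curve was fitted successfully; check the reference year or data quality.")

species_curves = np.vstack(species_curves)  # shape: [n_species, n_years]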

# ====== Keep only years >= 1985; optionally drop the last 3 years (lag) ======
mask_years = years >= ref_year
years_out = years[mask_years]
curves = species_curves[:, mask_years]

if apply_lag:
    # Drop the last lag_years years
    if len(years_out) > lag_years:
        years_out = years_out[:-lag_years]
        curves = curves[:, :-lag_years]

# ====== Bootstrap over the set of species curves to get value/low/high ======
# First compute the cross-species mean curve
mean_curve = np.nanmean(curves, axis=0)

# Then resample along the species axis to estimate the uncertainty of the mean
boot_mat = np.empty((n_boot, curves.shape[1]), dtype=float)
n_species = curves.shape[0]
for b in range(n_boot):
    idx = rng.integers(0, n_species, size=n_species)
    sample = curves[idx, :]
    boot_mat[b, :] = np.nanmean(sample, axis=0)

low_ci = np.nanpercentile(boot_mat, 2.5, axis=0)
high_ci = np.nanpercentile(boot_mat, 97.5, axis=0)

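# ====== Force all three columns to 1 at the reference year (matching the official style) ======
# Note: the first element of years_out should be ref_year here (years < 1985 were dropped)
if years_out[0] == ref_year:
    mean_curve[0] = 1.0
    low_ci[0] = 1.0
    high_ci[0] = 1.0

# ====== Export ======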
result = pd.DataFrame({
    "year": years_out,
    "value": mean_curve,
    "low": low_ci,
    "high": high_ci
})
result.to_csv(output_path, index=False)
print(f"Index table saved (reference year={ref_year}, lag={apply_lag}): {output_path}")

Index table saved (reference year=1985, lag=True): 02_wrangled_data\Table14_ThreatenedSpeciesIndexTable_version2.csv
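
To see the bootstrap step in isolation, the sketch below reproduces the same idea (resample species with replacement, recompute the per-year mean, take percentiles) using only the Python standard library instead of NumPy. The function name `bootstrap_mean_ci` and the toy curves are hypothetical, and the percentile helper is a simple linear-interpolation version written for this illustration.

```python
import random
import statistics


def bootstrap_mean_ci(curves, n_boot=1000, alpha=0.05, seed=42):
    """Percentile bootstrap interval for the per-year mean across species.

    curves: list of equal-length lists, one list per species.
    Returns (mean, low, high), each with one entry per year.
    """
    rng = random.Random(seed)
    n_species = len(curves)
    n_years = len(curves[0])
    mean = [statistics.fmean(c[j] for c in curves) for j in range(n_years)]

    boot_means = [[] for _ in range(n_years)]
    for _ in range(n_boot):
        # Resample whole species (rows) with replacement.
        sample = [curves[rng.randrange(n_species)] for _ in range(n_species)]
        for j in range(n_years):
            boot_means[j].append(statistics.fmean(c[j] for c in sample))

    def percentile(sorted_vals, q):
        # Linear interpolation between neighbouring order statistics.
        pos = (len(sorted_vals) - 1) * q
        lo = int(pos)
        hi = min(lo + 1, len(sorted_vals) - 1)
        return sorted_vals[lo] + (sorted_vals[hi] - sorted_vals[lo]) * (pos - lo)

    low, high = [], []
    for j in range(n_years):
        vals = sorted(boot_means[j])
        low.append(percentile(vals, alpha / 2))
        high.append(percentile(vals, 1 - alpha / 2))
    return mean, low, high


# Toy data (hypothetical): four species, three years, all equal to 1 in year one.
toy_curves = [
    [1.0, 0.9, 0.8],
    [1.0, 1.1, 1.2],
    [1.0, 0.7, 0.5],
    [1.0, 1.0, 0.9],
]
mean, low, high = bootstrap_mean_ci(toy_curves, n_boot=500)
```

Because every toy curve equals `1.0` in the first year, the interval collapses to a single point there, while later years show a wider spread as the species diverge.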


### Step 2 - Create Schema for Threatened Species Index Table

This step connects to the MySQL database and defines the structure for our new table, `Table14_ThreatenedSpeciesIndexTable`. The schema is designed to hold the time-series index data:
- `year` is set as an `INT` and serves as the Primary Key.
- `value`, `low`, and `high` are defined as `DOUBLE` to accommodate floating-point numbers.

The `CREATE TABLE IF NOT EXISTS` command is used to ensure the script can be run multiple times without causing an error if the table already exists.

In [11]:
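import os

import mysql.connector
from mysql.connector import Error

# --- Database connection configuration ---
# The password is read from an environment variable instead of being hard-coded.
# The variable name PLANTX_DB_PASSWORD is an assumption; use whatever name your setup defines.
db_config = {
    'host': 'database-plantx.cqz06uycysiz.us-east-1.rds.amazonaws.com',
    'user': 'zihan',
    'password': os.environ['PLANTX_DB_PASSWORD'],
    'database': 'FIT5120_PlantX_Database',
    'allow_local_infile': True,
    'use_pure': True,
    'charset': 'utf8mb4'
}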

# --- Create the schema for Table14 ---
try:
    # Establish the database connection
    connection = mysql.connector.connect(**db_config)
    if connection.is_connected():
        print("Successfully connected to MySQL server, preparing to create Table14.")
        cursor = connection.cursor()
        
        # Define the SQL query to create Table14
        # `year` is set as an INT PRIMARY KEY
        # `value`, `low`, and `high` are set as DOUBLE for floating-point numbers
        create_table_14_query = """
        CREATE TABLE IF NOT EXISTS Table14_ThreatenedSpeciesIndexTable (
            `year` INT PRIMARY KEY,
            `value` DOUBLE,
            `low` DOUBLE,
            `high` DOUBLE
        ) CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
        """
        
        # Execute the query
        cursor.execute(create_table_14_query)
        connection.commit()
        print("Table 'Table14_ThreatenedSpeciesIndexTable' created successfully or already exists.")

except Error as e:
    print(f"Error occurred while creating Table14: {e}")

finally:
    # Close the connection
    if 'connection' in locals() and connection.is_connected():
        cursor.close()
        connection.close()
        print("MySQL connection closed.")

Successfully connected to MySQL server, preparing to create Table14.
Table 'Table14_ThreatenedSpeciesIndexTable' created successfully or already exists.
MySQL connection closed.
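
The effect of `CREATE TABLE IF NOT EXISTS` and of the primary key on `year` can be checked locally without a MySQL server. The sketch below uses Python's built-in `sqlite3` module as a stand-in, so type handling and storage differ from MySQL; only the table name and column names are taken from the step above.

```python
import sqlite3

# SQLite stands in for MySQL here; this only illustrates the schema's behaviour.
DDL = """
CREATE TABLE IF NOT EXISTS Table14_ThreatenedSpeciesIndexTable (
    "year"  INTEGER PRIMARY KEY,
    "value" DOUBLE,
    "low"   DOUBLE,
    "high"  DOUBLE
)
"""

conn = sqlite3.connect(":memory:")
conn.execute(DDL)
conn.execute(
    "INSERT INTO Table14_ThreatenedSpeciesIndexTable VALUES (?, ?, ?, ?)",
    (1985, 1.0, 1.0, 1.0),
)
conn.execute(DDL)  # second run: no error, existing rows are untouched

rows = conn.execute(
    "SELECT * FROM Table14_ThreatenedSpeciesIndexTable"
).fetchall()
print(rows)  # [(1985, 1.0, 1.0, 1.0)]

# The primary key rejects a second row for the same year.
try:
    conn.execute(
        "INSERT INTO Table14_ThreatenedSpeciesIndexTable VALUES (?, ?, ?, ?)",
        (1985, 0.9, 0.8, 1.0),
    )
    duplicate_rejected = False
except sqlite3.IntegrityError:
    duplicate_rejected = True
conn.close()
```

Re-running the DDL is harmless, but it also never alters an existing table, so a schema change still requires an explicit `ALTER TABLE` or a drop and recreate.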


### Step 3 - Load CSV Data into Threatened Species Index Table

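After creating the table schema, this step populates it with data from the local CSV file (`02_wrangled_data/Table14_ThreatenedSpeciesIndexTable_version1.csv`). Note that this is the `version1` file, whereas Step 1 writes `Table14_ThreatenedSpeciesIndexTable_version2.csv`; adjust the path in the query if the Step 1 output is the file you intend to load.

It uses the `LOAD DATA LOCAL INFILE` command, an efficient way to bulk-insert rows from a client-side file into a MySQL table. The query specifies the field terminator (`,`) and the line terminator (`\r\n`, i.e. Windows line endings), and skips the header row (`IGNORE 1 LINES`). The connection must be opened with `allow_local_infile` enabled, as in `db_config` above.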

In [12]:
# --- Import CSV data into Table14 ---
try:
    # Re-establish the database connection
    connection = mysql.connector.connect(**db_config)
    if connection.is_connected():
        print("Successfully connected to MySQL server, preparing to import data for Table14.")
        cursor = connection.cursor()
        
        # Define the SQL query for loading data
        # Using LOAD DATA LOCAL INFILE for efficient bulk import
        load_data_query_14 = """
        LOAD DATA LOCAL INFILE '02_wrangled_data/Table14_ThreatenedSpeciesIndexTable_version1.csv'
        INTO TABLE Table14_ThreatenedSpeciesIndexTable
        CHARACTER SET utf8mb4
        FIELDS TERMINATED BY ','
        LINES TERMINATED BY '\\r\\n'
        IGNORE 1 LINES
        (year, value, low, high);
        """
        
        # Execute the query
        cursor.execute(load_data_query_14)
        connection.commit()
        print(f"Data import for Table14 successful! {cursor.rowcount} rows affected.")

except Error as e:
    print(f"Error occurred during data import for Table14: {e}")

finally:
    # Close the connection
    if 'connection' in locals() and connection.is_connected():
        cursor.close()
        connection.close()
        print("MySQL connection closed.")

Successfully connected to MySQL server, preparing to import data for Table14.
Data import for Table14 successful! 37 rows affected.
MySQL connection closed.
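
Before running the bulk load, it can help to sanity-check the CSV so that a malformed file is caught on the client side. The sketch below is one possible check using only the standard library; the helper `check_index_csv`, the rule that `value` lies within `[low, high]`, and the toy file are assumptions made for illustration and are not part of the notebook.

```python
import csv
import os
import tempfile

EXPECTED_HEADER = ["year", "value", "low", "high"]


def check_index_csv(path):
    """Return the number of data rows, raising ValueError on a malformed file."""
    with open(path, newline="", encoding="utf-8") as fh:
        reader = csv.reader(fh)
        header = next(reader, None)
        if header != EXPECTED_HEADER:
            raise ValueError(f"unexpected header: {header}")
        seen_years = set()
        n_rows = 0
        for line_no, row in enumerate(reader, start=2):
            if len(row) != 4:
                raise ValueError(f"line {line_no}: expected 4 fields, got {len(row)}")
            year = int(row[0])
            value, low, high = (float(x) for x in row[1:])
            # `year` is the primary key, so duplicates would not load cleanly.
            if year in seen_years:
                raise ValueError(f"line {line_no}: duplicate year {year}")
            # Assumed rule: the index value sits inside its interval.
            if not low <= value <= high:
                raise ValueError(f"line {line_no}: value outside [low, high]")
            seen_years.add(year)
            n_rows += 1
    return n_rows


# Toy file (hypothetical) with Windows line endings, as the LOAD DATA query expects.
with tempfile.NamedTemporaryFile(
    "w", suffix=".csv", delete=False, newline="", encoding="utf-8"
) as tmp:
    tmp.write("year,value,low,high\r\n1985,1.0,1.0,1.0\r\n1986,0.95,0.9,1.01\r\n")
    tmp_path = tmp.name

n_checked = check_index_csv(tmp_path)
os.remove(tmp_path)
print(n_checked)  # 2
```

Comparing the returned row count with `cursor.rowcount` after the import gives a quick indication of whether every row in the file reached the table.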
