In [3]:
import pandas as pd
import matplotlib.pyplot as plt
import geopandas as gpd
import seaborn as sns
import numpy as np
from datetime import datetime, timedelta
from scipy.stats import pearsonr
from scipy.stats import linregress
from scipy.stats import zscore

## Load and prepare data

In [4]:
# load data
spei_df = pd.read_csv('Data/district_province_spei_growingperiod.csv')

# rename season_year column for compatability with forest loss
spei_df = spei_df.rename(columns={'season_year': 'year', 'mean_growing_spei': 'mean_spei', 'median_growing_spei': 'median_spei'})

# load forest loss data
forest_df = pd.read_csv('Data/district_forest_cover.csv')
forest_df['year'] = forest_df['year'].astype(int)

# merge SPEI and forest data
merged_df = pd.merge(spei_df, forest_df, on=['district', 'year'])

In [5]:
# generate lags
for lag in range(1, 6):
    merged_df[f'mean_spei_lag{lag}'] = merged_df.groupby('district')['mean_spei'].shift(lag)

plot_df = merged_df.copy()

# get unique provinces
provinces = plot_df['province'].unique()

def compute_lag_corrs(group):
    results = {}
    for lag in range(0, 6):
        col = f'mean_spei_lag{lag}'
        y_col = 'percent_loss_annual'
        if col in group:
            # drop NaNs
            valid = group[[col, y_col]].dropna()
            if len(valid) >= 3:
                slope, intercept, r_value, p_value, std_err = linregress(valid[col], valid[y_col])
                results[f'corr_lag{lag}'] = r_value
                results[f'pval_lag{lag}'] = p_value
                results[f'slope_lag{lag}'] = slope
                results[f'r2_lag{lag}'] = r_value ** 2
            else:
                results[f'corr_lag{lag}'] = None
                results[f'pval_lag{lag}'] = None
                results[f'slope_lag{lag}'] = None
                results[f'r2_lag{lag}'] = None

        else:
            results[f'corr_lag{lag}'] = None
            results[f'pval_lag{lag}'] = None
            results[f'slope_lag{lag}'] = None
            results[f'r2_lag{lag}'] = None
            
    return pd.Series(results)

plot_df['mean_spei_lag1'] = plot_df.groupby('district')['mean_spei'].shift(0)
plot_df['mean_spei_lag1'] = plot_df.groupby('district')['mean_spei'].shift(1)
plot_df['mean_spei_lag2'] = plot_df.groupby('district')['mean_spei'].shift(2)
plot_df['mean_spei_lag3'] = plot_df.groupby('district')['mean_spei'].shift(3)
plot_df['mean_spei_lag4'] = plot_df.groupby('district')['mean_spei'].shift(4)
plot_df['mean_spei_lag5'] = plot_df.groupby('district')['mean_spei'].shift(5)

multi_lag_corrs = (
    plot_df.groupby('district')[['mean_spei_lag1', 'mean_spei_lag2', 'mean_spei_lag3', 'mean_spei_lag4', 'mean_spei_lag5', 'percent_loss_annual']]
    .apply(compute_lag_corrs)
    .reset_index()
)

# Filter and print significant correlations for each lag
for lag in range(1, 6):
    corr_col = f'corr_lag{lag}'
    pval_col = f'pval_lag{lag}'
    slope_col = f'slope_lag{lag}'
    r2_col = f'r2_lag{lag}'
    
    significant = multi_lag_corrs[multi_lag_corrs[pval_col] < 0.05][['district', corr_col, pval_col, slope_col, r2_col]]
    
    print(f'\nSignificant correlations for Lag {lag} (p < 0.05):')
    print(significant.sort_values(by=corr_col, ascending=False).to_string(index=False))


Significant correlations for Lag 1 (p < 0.05):
   district  corr_lag1  pval_lag1  slope_lag1  r2_lag1
   Mpulungu   0.620412   0.003516    0.135388 0.384912
    Mafinga   0.611805   0.004148    0.246712 0.374305
 Senga Hill   0.538354   0.014334    0.357403 0.289825
Chifunabuli   0.465374   0.038664    0.857212 0.216573
   Chisamba  -0.454857   0.043901   -0.207515 0.206895
 Shibuyunji  -0.458649   0.041953   -0.817617 0.210359
     Mkushi  -0.462171   0.040205   -0.316634 0.213602
     Mumbwa  -0.466719   0.038031   -0.285075 0.217827
    Serenje  -0.499425   0.024962   -0.150153 0.249426

Significant correlations for Lag 2 (p < 0.05):
    district  corr_lag2  pval_lag2  slope_lag2  r2_lag2
     Mafinga   0.665710   0.001864    0.322270 0.443169
  Chipangali   0.580494   0.009164    0.361724 0.336974
    Mpulungu   0.578153   0.009517    0.152141 0.334261
     Serenje  -0.464560   0.045080   -0.137137 0.215816
     Chavuma  -0.469796   0.042402   -0.381089 0.220709
Mwansabombwe  -0.4

In [4]:
plot_df.head()

Unnamed: 0,district,year,mean_spei,median_spei,n_months,province,province_avg_start_date,province_avg_end_date,forest_cover_ha,loss_m2,percent_loss_annual,mean_spei_lag1,mean_spei_lag2,mean_spei_lag3,mean_spei_lag4,mean_spei_lag5
0,Chadiza,2001,0.364988,0.311898,5,Eastern,November 26,May 04,57181.013721,513699.3,0.089757,,,,,
1,Chadiza,2002,0.281255,0.244052,6,Eastern,November 26,May 04,57127.207503,538062.2,0.094098,0.364988,,,,
2,Chadiza,2003,0.227219,0.217295,5,Eastern,November 26,May 04,56960.977318,1662302.0,0.290983,0.281255,0.364988,,,
3,Chadiza,2004,-0.079997,0.072081,7,Eastern,November 26,May 04,56773.795452,1871819.0,0.328614,0.227219,0.281255,0.364988,,
4,Chadiza,2005,-0.699067,-0.777293,4,Eastern,November 26,May 04,56703.324923,704705.3,0.124125,-0.079997,0.227219,0.281255,0.364988,


In [4]:
def compute_lag_corrs(group):
    results = {}
    for lag in range(1, 6):
        col = f'mean_spei_lag{lag}'
        y_col = 'percent_loss_annual'
        if col in group:
            # drop NaNs
            valid = group[[col, y_col]].dropna()
            if len(valid) >= 3:
                slope, intercept, r_value, p_value, std_err = linregress(valid[col], valid[y_col])
                results[f'corr_lag{lag}'] = r_value
                results[f'pval_lag{lag}'] = p_value
                results[f'slope_lag{lag}'] = slope
                results[f'r2_lag{lag}'] = r_value ** 2
            else:
                results[f'corr_lag{lag}'] = None
                results[f'pval_lag{lag}'] = None
                results[f'slope_lag{lag}'] = None
                results[f'r2_lag{lag}'] = None

        else:
            results[f'corr_lag{lag}'] = None
            results[f'pval_lag{lag}'] = None
            results[f'slope_lag{lag}'] = None
            results[f'r2_lag{lag}'] = None
            
    return pd.Series(results)

plot_df['mean_spei_lag1'] = plot_df.groupby('district')['mean_spei'].shift(1)
plot_df['mean_spei_lag2'] = plot_df.groupby('district')['mean_spei'].shift(2)
plot_df['mean_spei_lag3'] = plot_df.groupby('district')['mean_spei'].shift(3)
plot_df['mean_spei_lag4'] = plot_df.groupby('district')['mean_spei'].shift(4)
plot_df['mean_spei_lag5'] = plot_df.groupby('district')['mean_spei'].shift(5)

multi_lag_corrs = (
    plot_df.groupby('district')[['mean_spei_lag1', 'mean_spei_lag2', 'mean_spei_lag3', 'mean_spei_lag4', 'mean_spei_lag5', 'percent_loss_annual']]
    .apply(compute_lag_corrs)
    .reset_index()
)

# Filter and print significant correlations for each lag
for lag in range(1, 6):
    corr_col = f'corr_lag{lag}'
    pval_col = f'pval_lag{lag}'
    slope_col = f'slope_lag{lag}'
    r2_col = f'r2_lag{lag}'
    
    significant = multi_lag_corrs[multi_lag_corrs[pval_col] < 0.05][['district', corr_col, pval_col, slope_col, r2_col]]
    
    print(f'\nSignificant correlations for Lag {lag} (p < 0.05):')
    print(significant.sort_values(by=corr_col, ascending=False).to_string(index=False))


Significant correlations for Lag 1 (p < 0.05):
district  corr_lag1  pval_lag1  slope_lag1  r2_lag1
Mpulungu   0.626487   0.003121    0.134870 0.392486
   Mbala   0.592410   0.005918    0.304071 0.350949
   Isoka   0.494574   0.026636    0.115370 0.244603
  Mumbwa  -0.477934   0.033060   -0.296497 0.228421

Significant correlations for Lag 2 (p < 0.05):
district  corr_lag2  pval_lag2  slope_lag2  r2_lag2
Mpulungu   0.570805   0.010699    0.148288 0.325818
 Chavuma  -0.476628   0.039094   -0.486949 0.227175
  Mwense  -0.477635   0.038624   -0.401469 0.228135
   Mansa  -0.519870   0.022519   -0.492783 0.270265

Significant correlations for Lag 3 (p < 0.05):
district  corr_lag3  pval_lag3  slope_lag3  r2_lag3
 Serenje  -0.499451   0.034831   -0.131849 0.249451
Chingola  -0.521609   0.026411   -0.484231 0.272076
 Chavuma  -0.535638   0.021963   -0.534558 0.286908

Significant correlations for Lag 4 (p < 0.05):
     district  corr_lag4  pval_lag4  slope_lag4  r2_lag4
        Mansa  -0.49741

In [6]:
for lag in range(1, 6):
    col = f'mean_spei_lag{lag}'
    z_col = f'z_spei_lag{lag}'
    plot_df[z_col] = (
        plot_df.groupby('district')[col]
        .transform(lambda x: (x - x.mean()) / x.std())
    )

plot_df.head()

Unnamed: 0,district,year,mean_spei,median_spei,n_months,province,province_avg_start_date,province_avg_end_date,forest_cover_ha,loss_m2,...,mean_spei_lag1,mean_spei_lag2,mean_spei_lag3,mean_spei_lag4,mean_spei_lag5,z_spei_lag1,z_spei_lag2,z_spei_lag3,z_spei_lag4,z_spei_lag5
0,Chadiza,2001,0.364988,0.311898,5,Eastern,November 26,May 04,23442.326347,354592.9,...,,,,,,,,,,
1,Chadiza,2002,0.281255,0.244052,6,Eastern,November 26,May 04,23404.807955,375183.9,...,0.364988,,,,,0.206726,,,,
2,Chadiza,2003,0.227219,0.217295,5,Eastern,November 26,May 04,23295.975212,1088327.0,...,0.281255,0.364988,,,,0.121704,0.202697,,,
3,Chadiza,2004,-0.079997,0.072081,7,Eastern,November 26,May 04,23179.208834,1167664.0,...,0.227219,0.281255,0.364988,,,0.066836,0.119941,0.202608,,
4,Chadiza,2005,-0.699067,-0.777293,4,Eastern,November 26,May 04,23139.899515,393093.2,...,-0.079997,0.227219,0.281255,0.364988,,-0.24511,0.066535,0.122159,0.170363,
