Getting RMSE for the regions of interest

In [11]:
import xarray as xr
from pathlib import Path
import pandas as pd
import numpy as np
from sklearn.metrics import r2_score

from typing import Tuple, Dict, Optional, List

Parameters:
    
```python
always_ignore_vars = ["ndvi", "p84.162", "sp", "tp", "Eb"]
rnn(  # earnn(
    experiment="one_month_forecast",
    include_pred_month=True,
    surrounding_pixels=None,
    pretrained=False,
    explain=False,
    static="features",
    ignore_vars=always_ignore_vars,
    num_epochs=50,
    early_stopping=5,
    hidden_size=256,
    static_embedding_size=64,  # Not used, since static = "features"
    predict_delta=True,
    normalize_y=True,
    include_prev_y=True,  # True for the LSTM, false for the EA-LSTM - significant difference in performance for the LSTM
    include_latlons=True,
)
```

In [2]:
data_dir = Path("../../data")
assert data_dir.exists()

In [3]:
district_map = xr.open_dataset(data_dir / "analysis/boundaries_preprocessed/district_l2_kenya.nc")

In [4]:
def analyze_region(
        region_name: str, 
        district_map: xr.Dataset, 
        model_path: Path, 
        true_vals_path: Path) -> Tuple[float, float]:

    region_lookup: Dict = dict(
            zip(
                [v.strip() for v in district_map.attrs["values"].split(",")],
                [int(k.strip()) for k in district_map.attrs["keys"].split(",")],
            )
        )
    
    district_int = region_lookup[region_name]
    
    rmse: List[np.ndarray] = []
    
    for file_name in true_vals_path.glob("*"):
        year, month = file_name.name.split("_")
        
        true_file = xr.open_dataset(file_name / "y.nc").where(district_map.district_l2 == district_int).rename({"VCI": "preds"}).isel(time=0)
        model_file = xr.open_dataset(model_path / f"preds_{year}_{month}.nc").where(district_map.district_l2 == district_int)
        
        model_err = (model_file - true_file).preds.values
        model_err = model_err[~np.isnan(model_err)]
        rmse.append(np.sqrt(model_err ** 2))
    
    return np.concatenate(rmse).mean()

In [11]:
results: Dict[str, Dict[str, float]] = {}
for model in ["ealstm", "ealstm_prev_y", "previous_month", "rnn", "rnn_prev_y"]:
    results[model] = {}
    for region in ["TURKANA", "MANDERA", "MARSABIT", "WAJIR"]:
        results[model][region] = analyze_region(region, district_map, data_dir / f"models/one_month_forecast/{model}", 
                data_dir / "features/one_month_forecast/test")

In [12]:
results

{'ealstm': {'TURKANA': 10.24743413984095,
  'MANDERA': 8.302042759797557,
  'MARSABIT': 8.59926364998737,
  'WAJIR': 7.720248409576625},
 'ealstm_prev_y': {'TURKANA': 10.380956926503991,
  'MANDERA': 9.136506767119,
  'MARSABIT': 8.688602326238922,
  'WAJIR': 8.413396339210939},
 'previous_month': {'TURKANA': 11.744256204752439,
  'MANDERA': 12.204906110469631,
  'MARSABIT': 10.453345229680162,
  'WAJIR': 10.490208150497049},
 'rnn': {'TURKANA': 12.79268367431508,
  'MANDERA': 12.817927512766369,
  'MARSABIT': 12.751116392970431,
  'WAJIR': 12.912766885229038},
 'rnn_prev_y': {'TURKANA': 10.551420846362797,
  'MANDERA': 9.583836185529341,
  'MARSABIT': 8.877953046433564,
  'WAJIR': 8.849816731547334}}

In [7]:
# This assumes the region level analysis has already been run
data = pd.read_csv(data_dir / "analysis/region_analysis/regional_error_metrics_one_month_forecast_admin.csv")

In [8]:
for model in ["previous_month", "rnn_prev_y", "ealstm"]:
    for region in ["MANDERA", "MARSABIT", "TURKANA", "WAJIR"]:
        row = data[(data.model == model) & (data.region_name == region) & 
                  (data.admin_level_name == "district_l2_kenya")]
        print(f"For {model} in {region}, r2: {row.r2.iloc[0]}, rmse: {row.rmse.iloc[0]}")

For previous_month in MANDERA, r2: 0.3928677125057527, rmse: 12.640922267624708
For previous_month in MARSABIT, r2: 0.6490734925901813, rmse: 8.064530042145838
For previous_month in TURKANA, r2: 0.5859545537398633, rmse: 8.288133161479195
For previous_month in WAJIR, r2: 0.5443380466305261, rmse: 10.459842246254855
For rnn_prev_y in MANDERA, r2: 0.6992677073827819, rmse: 8.896662764883079
For rnn_prev_y in MARSABIT, r2: 0.8246640355259878, rmse: 5.7004150806802425
For rnn_prev_y in TURKANA, r2: 0.7651217818680127, rmse: 6.242435814085761
For rnn_prev_y in WAJIR, r2: 0.731797264116236, rmse: 8.024820977148295
For ealstm in MANDERA, r2: 0.8277096459523497, rmse: 6.7339141880843965
For ealstm in MARSABIT, r2: 0.8532328709661363, rmse: 5.215373026119667
For ealstm in TURKANA, r2: 0.8013903427422779, rmse: 5.740277591398884
For ealstm in WAJIR, r2: 0.8265186721970302, rmse: 6.454017216541415


In [7]:
def rolling_average(district_csv: pd.DataFrame, years: Optional[List[int]] = None) -> Dict[str, float]:
    relevant_districts = ['Mandera', 'Marsabit', 'Turkana', 'Wajir']

    if years is None:
        years = [2016, 2017]  # to reflect the Adede paper

    district_csv["month"] = pd.to_datetime(district_csv.datetime).dt.month
    district_csv["year"] = pd.to_datetime(district_csv.datetime).dt.year
    
    output_dict: Dict[str: float] = {}
    
    for district in relevant_districts:
        for year in years:
            year_df = district_csv[district_csv.year == year]
            district_df = year_df[year_df.region_name == district.upper()]
            true, predicted = [], []
            for i in range(1, 12 - 1):
                min_month = i
                max_month = i + 3
                submonth = district_df[(district_df.month >= min_month) & (district_df.month < max_month)]
                predicted.append(submonth.predicted_mean_value.mean())
                true.append(submonth.true_mean_value.mean())
        district_score = r2_score(true, predicted)
        print(f'For {district}, r2 score: {district_score}')
        output_dict[district] = district_score
    return output_dict

In [9]:
ealstm = pd.read_csv(data_dir / 'analysis/region_analysis/ealstm/ealstm_district_l2_kenya.csv')
prev_month = pd.read_csv(data_dir / 'analysis/region_analysis/previous_month/previous_month_district_l2_kenya.csv')
rnn = pd.read_csv(data_dir / 'analysis/region_analysis/rnn_prev_y/rnn_prev_y_district_l2_kenya.csv')

In [12]:
print("EALSTM")
_ = rolling_average(ealstm)

EALSTM
For Mandera, r2 score: 0.7573224218952296
For Marsabit, r2 score: 0.8648428559884426
For Turkana, r2 score: 0.9546412369097191
For Wajir, r2 score: 0.31978707795794936


In [13]:
print("RNN")
_ = rolling_average(rnn)

RNN
For Mandera, r2 score: 0.8144784879827663
For Marsabit, r2 score: 0.909008617524641
For Turkana, r2 score: 0.8859443513444373
For Wajir, r2 score: 0.5810596081444563


In [14]:
print("Previous Month")
_ = rolling_average(prev_month)

Previous Month
For Mandera, r2 score: 0.6789433578024786
For Marsabit, r2 score: 0.8627492279447612
For Turkana, r2 score: 0.8042008673182165
For Wajir, r2 score: 0.5485452659642569
