# 📊 **Re-Implementation of "Predicting Food Crises Using News Streams"**

---

#### 🔍 **Objective**

This notebook aims to **reproduce and analyze** the methodology presented in the paper:

📄 **Paper:** [Predicting food crises using news streams](https://www.science.org/doi/10.1126/sciadv.abm3449)  
📊 **Dataset:** [Harvard Dataverse Repository](https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/CJDWUW)  
📜 **Original Code & Methods:** [GitHub - Regression Modeling (Step 5)](https://github.com/philippzi98/food_insecurity_predictions_nlp/blob/main/Step%205%20-%20Regression%20Modelling/README.md)

---

#### 🛠 **Methodology**

This implementation follows the **key steps** outlined in the paper to predict **food insecurity crises** using a combination of:
1️⃣ **Traditional Risk Factors** (conflict, climate, food prices, etc.)  
2️⃣ **News-Based Indicators** (text feature frequencies from news articles)  
3️⃣ **Lagging & Aggregation** (temporal dependencies at district, province, and country levels)  
4️⃣ **Machine Learning Models** (Random Forest, OLS, Lasso)

---

#### 🔗 **Reference Materials**

📄 **Supplementary Material:** Available in `supplemental_material_from_paper.pdf`  
📊 **Datasets Used:**

- `time_series_with_causes_zscore_full.csv` (Main dataset with time-series features)
- `famine-country-province-district-years-CS.csv` (Food insecurity classification)
- `matching_districts.csv` (Geographical standardization)


# 📚🔧 Import Libraries

In this notebook, we will use uv to manage our Python environment and packages efficiently. uv is a modern and fast package manager that simplifies virtual environment creation, and dependency installation. We will create a virtual environment, install necessary libraries, and ensure our environment stays consistent across different setups.


In [1]:
## Uncoment the below cell to install `uv` if you have not already. You can also install it trhiugh `pip` by running `!pip install uv` but this will be within your current python environment and not globally.

# !curl -LsSf https://astral.sh/uv/install.sh | sh
# !uv venv world-bank
# !source world-bank/bin/activate

In [2]:
# !uv pip install -r requirements.txt

In [3]:
import pandas as pd
import numpy as np
import folium
from IPython.display import display, Image
import os
import gdown
import zipfile
import editdistance
from fuzzywuzzy import fuzz
import math



In [4]:
url = "https://drive.google.com/uc?id=1YoQ1hz9RlaLr2xW3KoKCfJPyyO2PErym"
output = "data.zip"

if not os.path.exists("./data"):
    gdown.download(url, output, quiet=False) 
    zipfile.ZipFile('data.zip', 'r').extractall()
else:
    print("You already have the data downloaded and extracted")

You already have the data downloaded and extracted


## 📂 Load and Clean Data

**Understanding the Time-Series Dataset & Column Selection**

This dataset contains **district-level time-series data** on food insecurity risk factors, including:

- **📅 Temporal Information:** `year`, `month`, `year_month`
- **📍 Geographical Identifiers:** `admin_code`, `admin_name`, `province`, `country`
- **🌍 Traditional Risk Factors:** Climate (`rain_mean`, `ndvi_mean`), conflict (`acled_count`), food prices (`p_staple_food`)
- **📰 News-Based Indicators:** Proportions of news articles mentioning crisis-related keywords (`conflict_0`, `famine_0`, etc.)
- **📉 Food Insecurity Label:** `fews_ipc` (Integrated Phase Classification)

🔥 **Columns We Will Drop & Why**
✔ **Redundant Aggregations:** `_1`, `_2` columns (province & country-level values) since we will recompute aggregations from scratch anyways.  
✔ **Unnamed/Index Columns:** `Unnamed: 0` as it is unnecessary. It is just a duplicate of default index.
✔ **Unnecessary Identifiers:** If `admin_code` and `admin_name`, after matching these to `matching_districts.csv`, we can drop them.

---

> ⚠️ **NOTE:**  
> For a detailed explanation of the dataset and features, refer to the [`explore_time_series.ipynb`](./explore_time_series.ipynb) notebook.


In [5]:
time_series = pd.read_csv('./data/time_series_with_causes_zscore_full.csv')
admins = pd.read_csv('./data/famine-country-province-district-years-CS.csv')
valid_matching = pd.read_csv('./data/matching_districts.csv')

In [6]:
sorted(time_series.columns.values)

['Unnamed: 0',
 'abnormally low rainfall_0',
 'abnormally low rainfall_1',
 'abnormally low rainfall_2',
 'acled_count',
 'acled_fatalities',
 'acute hunger_0',
 'acute hunger_1',
 'acute hunger_2',
 'admin_code',
 'admin_name',
 'aid appeal_0',
 'aid appeal_1',
 'aid appeal_2',
 'aid workers died_0',
 'aid workers died_1',
 'aid workers died_2',
 'air attack_0',
 'air attack_1',
 'air attack_2',
 'alarming level_0',
 'alarming level_1',
 'alarming level_2',
 'anti-western policies_0',
 'anti-western policies_1',
 'anti-western policies_2',
 'apathy_0',
 'apathy_1',
 'apathy_2',
 'area',
 'asylum seekers_0',
 'asylum seekers_1',
 'asylum seekers_2',
 'authoritarian_0',
 'authoritarian_1',
 'authoritarian_2',
 'bad harvests_0',
 'bad harvests_1',
 'bad harvests_2',
 'blockade_0',
 'blockade_1',
 'blockade_2',
 'bombing campaign_0',
 'bombing campaign_1',
 'bombing campaign_2',
 'brain drain_0',
 'brain drain_1',
 'brain drain_2',
 'brutal government_0',
 'brutal government_1',
 'brutal 

In [7]:
time_series.head(5)

Unnamed: 0.1,Unnamed: 0,index,country,admin_code,admin_name,centx,centy,year_month,year,month,...,carbon_2,mayhem_0,mayhem_1,mayhem_2,dehydrated_0,dehydrated_1,dehydrated_2,mismanagement_0,mismanagement_1,mismanagement_2
0,0,30,Afghanistan,202,Kandahar,65.709343,31.043618,2009_07,2009,7,...,1.053,0.667,-0.171,-0.833,0.173667,0.168,1.284667,-0.073,-0.427667,0.668333
1,1,33,Afghanistan,202,Kandahar,65.709343,31.043618,2009_10,2009,10,...,-0.660812,-0.63658,-0.520247,-0.782913,-0.671587,-0.612254,-0.926921,-0.510467,-0.625133,-0.452467
2,2,36,Afghanistan,202,Kandahar,65.709343,31.043618,2010_01,2010,1,...,-0.134333,1.447667,-0.844333,0.778667,-0.676,-0.689667,0.293333,0.530333,-0.471333,0.955333
3,3,39,Afghanistan,202,Kandahar,65.709343,31.043618,2010_04,2010,4,...,-0.326927,-0.594877,0.16479,-0.90521,-0.62054,0.165794,0.045794,-1.0116,-0.8106,-0.2056
4,4,42,Afghanistan,202,Kandahar,65.709343,31.043618,2010_07,2010,7,...,-1.085146,-0.709913,-0.867913,-0.770247,-0.787921,-0.974587,-0.946921,-0.611133,-0.7098,-0.6228


In [8]:
t_variant_traditional_factors = ['ndvi_mean', 'ndvi_anom', 'rain_mean', 'rain_anom', 'et_mean', 'et_anom', 
                                    'acled_count', 'acled_fatalities', 'p_staple_food']
t_invariant_traditional_factors = ['area', 'cropland_pct', 'pop', 'ruggedness_mean', 'pasture_pct']

news_factors = [name for name in time_series.columns.values if '_0' in name]

In [9]:
news_factors[0]

'land seizures_0'

In [10]:
print("Columns count BEFORE dropping: ", len(time_series.columns.values))

Columns count BEFORE dropping:  532


In [11]:
cols_to_drop = ["Unnamed: 0", "centx", "centy", 'change_fews', 'fews_ha', 'fews_proj_med', 'fews_proj_med_ha', 'fews_proj_near_ha'] + [col for col in time_series.columns if col.endswith(('_1', '_2', '_3'))]
time_series.drop(columns=cols_to_drop, inplace=True)

In [12]:
potential_extra_cols = set(time_series.columns.values) - set(t_variant_traditional_factors) - set(t_invariant_traditional_factors) - set(news_factors)
potential_extra_cols = [col for col in potential_extra_cols if not col.endswith(('_1', '_2', '_3'))]
print("Potential extra columns", potential_extra_cols)

Potential extra columns ['month', 'year_month', 'country', 'fews_ipc', 'admin_name', 'year', 'index', 'fews_proj_near', 'admin_code']


In [13]:
print("Columns count after dropping: ", len(time_series.columns.values))

Columns count after dropping:  190


In [14]:
# print("Columns names after dropping: ", sorted(time_series.columns.values))

### 🌍 Admin Level Mapping: Standardizing Geographical Identifiers

In this section, we will **map and standardize** the `admin_code` and `admin_name` fields to their corresponding **district, province, and country names**. This step is **crucial** for ensuring **consistency** across different datasets and enabling **accurate aggregations** at multiple administrative levels.

🛠 **Why is Admin Level Mapping Important?**
✅ Different datasets may use **slightly different spellings or formats** for district names.  
✅ Some district names might be **missing or misspelled**, requiring standardization.  
✅ We need to **match and align** district names across various sources before aggregating at **province and country levels**.  
✅ Proper mapping allows us to **merge datasets correctly** without losing information.  

📌 **Steps in Admin Mapping**
1️⃣ **Load the `matching_districts.csv` file**, which provides the mapping between different district name variations.  
2️⃣ **Identify missing or unmatched `admin_name` values** and find their closest matches using fuzzy matching techniques.  
3️⃣ **Ensure that each `admin_code` uniquely maps to one `district`, `province`, and `country`.**  
4️⃣ **Replace inconsistent names** in the dataset with their standardized versions.  
5️⃣ **Aggregate data at the `province` and `country` levels** after ensuring all districts are correctly mapped.  


In [15]:
len(admins.country.unique())

39

In [16]:
admins.columns.values

array(['Unnamed: 0', 'country', 'district', 'year', 'month', 'CS',
       'province'], dtype=object)

In [17]:
admin_names = time_series['admin_name'].unique()
districts = admins['district'].unique()
provinces = admins['province'].unique()
countries = admins['country'].unique()

In [18]:
print (len(admin_names), len(districts), len(provinces), len(countries))
print (len(set(admin_names).difference(districts)))
missing_admin_names = set(admin_names).difference(districts)
print (len(missing_admin_names.difference(provinces)))
missing_admin_names = missing_admin_names.difference(provinces)

1142 4113 474 39
369
230


### Fuzzy String Matching for Missing Names

The function uses **fuzzy string matching** to find the best approximate matches for missing administrative names (e.g., districts and provinces). 

- Finds the **best matching district/province** for each missing name.
- Uses **fuzzy string matching** to calculate the similarity between missing names and known names.
- Returns a dictionary that maps each missing name to its closest match.


In [19]:
def find_matching(missing, names):
    matching_districts = {}
    for m in missing:
        max_overlap = 0
        nearest_d = None
        for d in names:
            d = str(d)
            dist = fuzz.partial_ratio(m, d)
            if dist > max_overlap:
                max_overlap = dist
                nearest_d = d
        matching_districts[m] = nearest_d
    return matching_districts


matching = find_matching(missing_admin_names, districts)
matching_p = find_matching(missing_admin_names, provinces)

# manually verify matching and update
for k in matching.keys():
    print (k, matching[k], matching_p[k])


Grande Riviere Du Nord Grande Riviere du Nord Nord
Southern Tigray Lira Tigray
Al Kurumik Qulansiyah wa `Abd Al Kuri Ituri
North Shewa(R3) North Shewa North
Caynabo Caynaba Bay
North Western Tigray Lira Western
Sami' Sami` Haut-Lomami
Central Kisii Kiti Central
Karary Karari Kwara
Meru Central Meru Central
Kajo-keji Kajo-Keji Kano
Belbedji Bielel Abyei
Butere Mumias Butere Muyinga
Port De Paix Port de Paix Pwani
La Nya Pendé La Nya Lac
Bankilaré Bankilare Sila
Guji Gujii Guidimaka
South Khartoum Khartoum Khartoum
Bulo Burto Burco Koulikoro
Gourma-Rharous Gourma Ghor
Abu Hamad Abu Hamed Hilmand
Al Jabalian Jaba Al Jawf
Majang Marangara Mahajanga
Adan Aldai `Adan
Mangalmé Mangalme Tanga
Balleyara Bale Mara
Um Badda Um Keddada Bay
Al Geneina El Geneina Geita
Laasqoray Rorya Tabora
Barh El Gazel Ouest Barh el Gazel Ouest Ouest
Bulilima (North) Bulilima North
Hirat Wag Himra Hiiraan
Anse-A-Veau `Ans Lamu
Trans Mara Maara Mara
Shabelle Shebelle Middle Shabelle
Bindura Bindura Urban Cabinda
G

### Encoding Decoding

`to_ascii_escaped(s)`: Converts a Unicode string to an ASCII-safe representation using **unicode-escape**.

`from_ascii_escaped(escaped)`: Converts the escaped ASCII string back into its original Unicode form.

In [20]:
def to_ascii_escaped(s):
    """
    Convert a Unicode string to an ASCII-safe string using unicode-escape.
    This will replace non-ASCII characters with their escape sequences.
    """
    if isinstance(s, bytes):
        s = s.decode('utf-8')
    # Using 'unicode-escape' encoding produces a bytes object,
    # then decode it to get an ASCII string.
    return s.encode('unicode-escape').decode('ascii')

def from_ascii_escaped(escaped):
    """
    Convert the ASCII-escaped string back to the original Unicode string.
    """
    # Encode the ASCII string to bytes, then decode using 'unicode-escape'
    return escaped.encode('ascii').decode('unicode-escape')

# # Test the round-trip on each unique value from valid_matching['missing']:
# for m in valid_matching['missing'].unique():
#     # Ensure m is a Unicode string
#     original = m.decode('utf-8') if isinstance(m, bytes) else m
#     # Convert to an ASCII-escaped representation
#     encoded = to_ascii_escaped(original)
#     # Convert back from the ASCII-escaped representation to Unicode
#     decoded = from_ascii_escaped(encoded)
    
#     # Print the results
#     print("Original: ", original)
#     print("Encoded:  ", encoded)
#     print("Decoded:  ", decoded)
#     print("Round-trip equal:", original == decoded)
#     print("-" * 40)


### Finding the Province for a Given District or Province

`find_province(x)`, finds the **province** corresponding to a given administrative name. It accounts for:
- **Direct Lookups** (Exact match in known district/province lists)
- **Fuzzy Matching** (Using ASCII-safe transformation for inconsistent text encoding)
- **Validation Against a Predefined Mapping (`valid_matching`)**

In [21]:
# Define matched globally
matched = valid_matching['missing'].unique()

def to_ascii_escaped(s):
    """
    Convert a Unicode string to an ASCII-safe string using unicode-escape.
    This will replace non-ASCII characters with their escape sequences.
    """
    if isinstance(s, bytes):
        s = s.decode('utf-8')
    return s.encode('unicode-escape').decode('ascii')

def find_province(x):
    try:
        # Ensure x is a Unicode string.
        if isinstance(x, bytes):
            x = x.decode('utf-8')
        
        # Direct lookup in districts or provinces.
        if x in districts:
            return admins[admins['district'] == x]['province'].values[0]
        elif x in provinces:
            return x

        # Convert x to an ASCII-escaped version.
        escaped_x = to_ascii_escaped(x)
        
        # Check if the escaped version is in matched.
        if escaped_x in matched:
            v = valid_matching[valid_matching['missing'] == escaped_x]
            if v['match'].values[0] == 'district':
                x2 = v['district'].values[0]
                return admins[admins['district'] == x2]['province'].values[0]
            elif v['match'].values[0] == 'province':
                return v['province'].values[0]
        
        # If no conditions are met, raise an exception.
        raise Exception("No matching province found")
    except Exception as e:
        raise Exception("Province not found for: {} ({})".format(x, e))


### Handling Admin Names with Accented Characters and Mapping to Provinces

Maps `admin_names` to provinces using the `find_province(a)` function.  
If a **direct lookup fails**, it tries to handle cases where the **admin name contains accented characters** (`é`, `è`, `ô`) ->  (encoding decoding issues resolved through directly replacing these with 'e' or 'o', leads to finding a valid match). 

In [22]:
admin_to_province = {}
for a in admin_names:
    try:
        admin_to_province[a] = find_province(a)
    except Exception as e:
        # Print the admin name that caused an error
        print("Error with:", a)
        # Check if a contains accented characters "é" or "è"
        if 'é' in a or 'è' in a or 'ô' in a:
            a_modified = a.replace('é', 'e').replace('è', 'e').replace('ô', 'o')
            # Check if the modified name is in districts
            if a_modified in districts:
                # Use the modified name to look up the province from admins
                try:
                    province = admins[admins['district'] == a_modified]['province'].values[0]
                    admin_to_province[a] = province
                    print(f"Replaced '{a}' with '{a_modified}', found province: {province}")
                except Exception as ex:
                    print(f"Modified name '{a_modified}' not found in admins: {ex}")
            else:
                print(f"Modified name '{a_modified}' not in districts.")
        else:
            print(f"No accented e found in '{a}'.")


Error with: Mangalmé
Replaced 'Mangalmé' with 'Mangalme', found province: Guera
Error with: La Pendé
Replaced 'La Pendé' with 'La Pende', found province: Logone Oriental
Error with: La Nya Pendé
Replaced 'La Nya Pendé' with 'La Nya Pende', found province: Logone Oriental
Error with: Lac-Léré
Replaced 'Lac-Léré' with 'Lac-Lere', found province: Mayo-Kebbi Ouest
Error with: Barh-Kôh
Replaced 'Barh-Kôh' with 'Barh-Koh', found province: Moyen-Chari
Error with: Aguié
Replaced 'Aguié' with 'Aguie', found province: Maradi
Error with: Bankilaré
Replaced 'Bankilaré' with 'Bankilare', found province: Tillaberi
Error with: Filingué
Replaced 'Filingué' with 'Filingue', found province: Tillaberi
Error with: Gothèye
Replaced 'Gothèye' with 'Gotheye', found province: Tillaberi
Error with: Gouré
Replaced 'Gouré' with 'Goure', found province: Zinder
Error with: Illéla
Replaced 'Illéla' with 'Illela', found province: Sokoto
Error with: Kantché
Replaced 'Kantché' with 'Kantche', found province: Zinder
Er

In [23]:
# print(admin_to_province)
for k, v in admin_to_province.items():
    print("key is : ", k)
    print("value is : ", v)

key is :  Kandahar
value is :  Kandahar
key is :  Kapisa
value is :  Kapisa
key is :  Khost
value is :  Khost
key is :  Kunar
value is :  Kunar
key is :  Kunduz
value is :  Kunduz
key is :  Laghman
value is :  Laghman
key is :  Logar
value is :  Logar
key is :  Nangarhar
value is :  Nangarhar
key is :  Paktika
value is :  Paktika
key is :  Paktya
value is :  Paktya
key is :  Samangan
value is :  Samangan
key is :  Sar-e-Pul
value is :  Sari Pul
key is :  Takhar
value is :  Takhar
key is :  Wardak
value is :  Wardak
key is :  Zabul
value is :  Zabul
key is :  Daykundi
value is :  Daykundi
key is :  Panjsher
value is :  Panjsher
key is :  Parwan
value is :  Parwan
key is :  Uruzgan
value is :  Uruzgan
key is :  Badakhshan
value is :  Badakhshan
key is :  Badghis
value is :  Badghis
key is :  Baghlan
value is :  Baghlan
key is :  Balkh
value is :  Balkh
key is :  Bamyan
value is :  Bamyan
key is :  Farah
value is :  Farah
key is :  Faryab
value is :  Faryab
key is :  Ghazni
value is :  Gh

### Mapping Administrative Names to Provinces in time_series

Maps `admin_name` to their respective **provinces** using a precomputed dictionary - >`admin_to_province` in `time_series`.


In [24]:
# time_series['province'] = time_series['admin_name'].apply(lambda x: admin_to_province[x])
time_series['province'] = time_series['admin_name'].apply(
    lambda x: admin_to_province[x] if x in admin_to_province else admin_to_province.get(x.replace('ô', 'o'))
)


# ⏳ Time Lagging & Feature Engineering

#### 📅 **Why Use Lagging?**

To predict food insecurity **for a given quarter**, we use:

- **6 months of historical values** for traditional & news-based features.
- **Province & country-level aggregations** to capture broader shocks.
- **6 quarters of lagged IPC phase values** to model temporal dependencies.

#### ⚡ **Optimized Lagging Approach**

To improve computational efficiency, we:
✔ Use `groupby()` for **fast province & country-level aggregations**.  
✔ Merge lagged data via `merge()` instead of slow `.apply()`.  
✔ Only keep **past data** to ensure no data leakage.


In [None]:
def add_time_lagged(features, start=3, end=9, diff=1, agg=True, time_series=time_series):
    
    # Determine suffixes based on aggregation flag
    levels = ['', '_province', '_country'] if agg else ['']
    
    # Create a copy of the dataframe to work with
    # This helps with defragmentation and prevents modifying the original during processing
    working_df = time_series.copy()
    
    # Create a mapping from (admin_code, year, month) to row index for fast lookups
    # We'll use this for efficient merging of lagged data
    working_df['total_months'] = working_df['year'] * 12 + working_df['month'] - 1
    
    # Prepare a list to collect all new dataframes
    new_columns_df = []
    
    # Track new column names to add
    new_column_names = []
    
    # Process each feature and lag combination
    for suffix in levels:
        for f in features:
            f_s = f + suffix
            
            # Skip if the feature doesn't exist
            if f_s not in working_df.columns:
                continue
                
            for t in range(start, end, diff):
                col_name = f'{f_s}_{t}'
                
                # Skip if column already exists
                if col_name in time_series.columns:
                    continue
                
                new_column_names.append(col_name)
                
                # Create a dataframe with the lagged feature
                lagged_df = working_df[['admin_code', 'total_months', f_s]].copy()
                
                # Calculate the target months to look back to
                lagged_df['target_months'] = lagged_df['total_months'] - t
                
                # Convert target months back to year and month
                lagged_df['l_year'] = lagged_df['target_months'] // 12
                lagged_df['l_month'] = (lagged_df['target_months'] % 12) + 1
                
                # Create a reference key for merging
                lagged_df['ref_key'] = lagged_df['admin_code'].astype(str) + '_' + \
                                       lagged_df['l_year'].astype(str) + '_' + \
                                       lagged_df['l_month'].astype(str)
                
                # Create the same key in our original dataframe for merging
                working_df['temp_key'] = working_df['admin_code'].astype(str) + '_' + \
                                        working_df['year'].astype(str) + '_' + \
                                        working_df['month'].astype(str)
                
                # Map from reference key to the feature value
                ref_map = dict(zip(
                    working_df['temp_key'],
                    working_df[f_s]
                ))
                
                # Apply the mapping to get lagged values - use the original value as fallback
                lagged_values = lagged_df['ref_key'].map(ref_map).fillna(working_df[f_s])
                
                # Create a single-column dataframe with the lagged values
                lag_result_df = pd.DataFrame({
                    col_name: lagged_values.values
                }, index=working_df.index)
                
                new_columns_df.append(lag_result_df)
    
    # Clean up temp columns
    if 'temp_key' in working_df.columns:
        working_df.drop('temp_key', axis=1, inplace=True)
    
    # If we have new columns to add, concatenate them all at once
    if new_columns_df:
        # Combine all new columns into a single dataframe
        all_new_columns = pd.concat(new_columns_df, axis=1)
        
        # Add the new columns to the original dataframe
        time_series = pd.concat([time_series, all_new_columns], axis=1)
        
    return time_series

# Province & Country-Level Aggregation

This function aggregates feature values at the province and country levels to capture regional trends, aiding in food insecurity prediction. The process includes:

- **Grouping by year_month and level:** Data is grouped by year_month and the specified level (province or country) to calculate the mean of features, reflecting regional trends over time.

- **Applying transformations efficiently:** Instead of merging aggregated data, `transform("mean")` is used to directly assign the computed mean to each row, avoiding unnecessary joins and improving performance.  

#### ⚡ **Efficiency Gains**

- **Fast Aggregation**: Uses `groupby()` for efficient aggregation.
- **Avoids Costly Joins**: Eliminates the need for `merge()` by using `transform()` instead, reducing computational overhead.  
- **Memory Efficiency**: Converts the `level` column to a categorical type to reduce memory usage.

This approach ensures faster processing while maintaining the quality of aggregated features.


In [27]:
def add_agg_factors(features, level='province'):
    # Ensure we are modifying the global variable 'time_series'
    global time_series  
    
    # Convert the 'level' column (e.g., 'province') to categorical data type
    # This helps in reducing memory usage and speeds up the groupby operation.
    time_series[level] = time_series[level].astype('category')
    
    # Grouping the 'time_series' DataFrame by 'year_month' and the specified 'level'
    # The 'observed=True' ensures that we respect the existing categories when grouping.
    # The 'sort=False' prevents automatic sorting, which might improve performance.
    # We compute the mean for the specified 'features' for each group using transform(),
    # which directly assigns the computed mean values to each row in the original DataFrame.
    grouped_df = time_series.groupby(['year_month', level], observed=True, sort=False)[features].transform("mean")
    
    # Assign the transformed values with the appropriate column suffix
    # Each feature gets a new column with the suffix '_<level>', containing the aggregated mean.
    for feature in features:
        time_series[f"{feature}_{level}"] = grouped_df[feature]
        
    # Return the updated 'time_series' DataFrame with the newly added aggregated features
    return time_series


In [28]:
add_agg_factors(news_factors)

  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[

Unnamed: 0,index,country,admin_code,admin_name,year_month,year,month,fews_ipc,fews_proj_near,ndvi_mean,...,gastrointestinal_0_province,terrorist_0_province,warlord_0_province,d'etat_0_province,overthrow_0_province,convoys_0_province,carbon_0_province,mayhem_0_province,dehydrated_0_province,mismanagement_0_province
0,30,Afghanistan,202,Kandahar,2009_07,2009,7,1.0,,0.106035,...,-0.192000,-0.284333,-0.668667,0.647333,-0.891333,0.112667,1.265333,0.667000,0.173667,-0.073000
1,33,Afghanistan,202,Kandahar,2009_10,2009,10,1.0,,0.103009,...,-0.545727,-1.037016,-0.811291,-0.850261,-0.948892,-0.728972,-0.765146,-0.636580,-0.671587,-0.510467
2,36,Afghanistan,202,Kandahar,2010_01,2010,1,2.0,,0.109600,...,1.506333,0.455000,1.595667,0.571667,0.279000,-0.868333,0.058333,1.447667,-0.676000,0.530333
3,39,Afghanistan,202,Kandahar,2010_04,2010,4,2.0,,0.111599,...,-0.793970,-0.722159,-0.130521,0.047630,0.362613,0.480986,0.026073,-0.594877,-0.620540,-1.011600
4,42,Afghanistan,202,Kandahar,2010_07,2010,7,1.0,,0.096943,...,-0.509394,-0.694350,-1.215958,-0.865261,-1.119225,-1.060638,-0.673479,-0.709913,-0.787921,-0.611133
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
40947,183579,Zimbabwe,612,Zvishavane,2018_10,2018,10,3.0,3.0,0.326176,...,0.765375,0.391125,0.501417,0.571500,0.012375,0.226250,0.183833,0.371167,0.315250,0.249750
40948,183583,Zimbabwe,612,Zvishavane,2019_02,2019,2,3.0,3.0,0.497560,...,0.088125,-0.147375,0.177167,0.108125,0.177833,0.363208,0.142167,0.501708,0.360417,0.183875
40949,183587,Zimbabwe,612,Zvishavane,2019_06,2019,6,3.0,3.0,0.377050,...,0.307691,-0.300870,0.137749,0.275410,0.281648,0.137628,0.333385,-0.509590,0.442653,0.441593
40950,183591,Zimbabwe,612,Zvishavane,2019_10,2019,10,3.0,3.0,0.297881,...,-0.231619,0.205385,-0.045085,0.077570,0.500754,0.199047,-0.009647,0.183611,-0.157611,0.271520


In [29]:
add_agg_factors(news_factors, level='country')
time_series.head(10)

  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[

Unnamed: 0,index,country,admin_code,admin_name,year_month,year,month,fews_ipc,fews_proj_near,ndvi_mean,...,gastrointestinal_0_country,terrorist_0_country,warlord_0_country,d'etat_0_country,overthrow_0_country,convoys_0_country,carbon_0_country,mayhem_0_country,dehydrated_0_country,mismanagement_0_country
0,30,Afghanistan,202,Kandahar,2009_07,2009,7,1.0,,0.106035,...,0.009191,-0.196791,-0.277796,-0.080313,-0.158093,-0.091979,-0.205168,-0.351945,-0.004046,0.020729
1,33,Afghanistan,202,Kandahar,2009_10,2009,10,1.0,,0.103009,...,-0.123518,-0.169664,-0.039284,0.096598,-0.145231,-0.058545,0.024883,0.039075,-0.060053,-0.11289
2,36,Afghanistan,202,Kandahar,2010_01,2010,1,2.0,,0.1096,...,0.194285,-0.051662,0.11223,0.271666,0.351302,0.17547,0.236345,0.159935,0.198416,0.409551
3,39,Afghanistan,202,Kandahar,2010_04,2010,4,2.0,,0.111599,...,-0.073142,-0.122469,0.135643,0.204327,0.101177,-0.144836,0.22267,0.156061,0.27463,0.180825
4,42,Afghanistan,202,Kandahar,2010_07,2010,7,1.0,,0.096943,...,0.15815,-0.068429,-0.161717,-0.187331,-0.13909,-0.142929,-0.010496,-0.081932,0.007388,0.138598
5,45,Afghanistan,202,Kandahar,2010_10,2010,10,2.0,,0.095377,...,0.01429,-0.021839,0.010606,-0.085573,0.147253,-0.040194,0.05393,-0.185165,0.172638,-0.085029
6,48,Afghanistan,202,Kandahar,2011_01,2011,1,2.0,,0.09262,...,-0.059419,-0.050372,-0.144964,-0.140235,-0.056617,-0.232453,-0.117227,0.117,-0.111434,-0.100282
7,51,Afghanistan,202,Kandahar,2011_04,2011,4,2.0,2.0,0.131462,...,-0.283481,-0.311409,-0.295271,-0.370462,-0.298859,-0.246249,-0.109567,-0.168876,-0.335914,-0.276873
8,54,Afghanistan,202,Kandahar,2011_07,2011,7,1.0,1.0,0.106885,...,0.198956,0.232582,0.165008,0.475172,0.257077,0.246208,0.089489,0.217305,0.162207,0.096265
9,57,Afghanistan,202,Kandahar,2011_10,2011,10,1.0,1.0,0.103268,...,0.456875,-0.102239,0.338327,0.302879,0.120221,0.354626,0.248271,0.316393,0.335178,0.232545


In [None]:
add_agg_factors(t_variant_traditional_factors, level='province')

  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]


array(['index', 'country', 'admin_code', 'admin_name', 'year_month',
       'year', 'month', 'fews_ipc', 'fews_proj_near', 'ndvi_mean',
       'ndvi_anom', 'rain_mean', 'rain_anom', 'et_mean', 'et_anom',
       'acled_count', 'acled_fatalities', 'p_staple_food', 'area',
       'cropland_pct', 'pop', 'ruggedness_mean', 'pasture_pct',
       'land seizures_0', 'slashed export_0', 'price rise_0',
       'mass hunger_0', 'cyclone_0', 'failed crops_0',
       'disruption to farming_0', 'massive starvation_0',
       'abnormally low rainfall_0', 'withheld relief_0',
       'international alarm_0', 'reduced national output_0',
       'oppressive regimes_0', 'pests_0', 'continued deterioration_0',
       'forests destroyed_0', 'man-made disaster_0', 'food insecurity_0',
       'harvests are devastated_0', 'humanitarian situation_0',
       'economic impoverishment_0', 'clan battle_0',
       'population crisis_0', 'aid appeal_0', 'weather extremes_0',
       'anti-western policies_0', 'rinderp

In [31]:
add_agg_factors(t_variant_traditional_factors, level='country')
add_agg_factors(t_invariant_traditional_factors, level='province')
add_agg_factors(t_invariant_traditional_factors, level='country')

  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[f"{feature}_{level}"] = grouped_df[feature]
  time_series[

Unnamed: 0,index,country,admin_code,admin_name,year_month,year,month,fews_ipc,fews_proj_near,ndvi_mean,...,area_province,cropland_pct_province,pop_province,ruggedness_mean_province,pasture_pct_province,area_country,cropland_pct_country,pop_country,ruggedness_mean_country,pasture_pct_country
0,30,Afghanistan,202,Kandahar,2009_07,2009,7,1.0,,0.106035,...,54174.53381,1.417796,1241226.000,101047.15870,16.246279,18883.616198,7.335028,800334.941176,316314.115459,49.145729
1,33,Afghanistan,202,Kandahar,2009_10,2009,10,1.0,,0.103009,...,54174.53381,1.417796,1241226.000,101047.15870,16.246279,18883.616198,7.335028,800334.941176,316314.115459,49.145729
2,36,Afghanistan,202,Kandahar,2010_01,2010,1,2.0,,0.109600,...,54174.53381,1.417796,1280853.000,101047.15870,16.246279,18883.616198,7.335028,821226.147059,316314.115459,49.145729
3,39,Afghanistan,202,Kandahar,2010_04,2010,4,2.0,,0.111599,...,54174.53381,1.417796,1280853.000,101047.15870,16.246279,18883.616188,7.335028,821226.147059,316314.116188,49.145729
4,42,Afghanistan,202,Kandahar,2010_07,2010,7,1.0,,0.096943,...,54174.53381,1.417796,1280853.000,101047.15870,16.246279,18883.616188,7.335028,821226.147059,316314.116188,49.145729
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
40947,183579,Zimbabwe,612,Zvishavane,2018_10,2018,10,3.0,3.0,0.326176,...,6198.95825,31.339385,254993.250,113050.19375,55.568223,6531.261152,31.452557,278500.300000,144621.235150,53.536959
40948,183583,Zimbabwe,612,Zvishavane,2019_02,2019,2,3.0,3.0,0.497560,...,6198.95825,31.339385,260172.625,113050.19375,55.568223,6531.261152,31.452557,284723.000000,144621.235150,53.536959
40949,183587,Zimbabwe,612,Zvishavane,2019_06,2019,6,3.0,3.0,0.377050,...,6198.95825,31.339385,260172.625,113050.19375,55.568223,6531.261152,31.452557,284723.000000,144621.235150,53.536959
40950,183591,Zimbabwe,612,Zvishavane,2019_10,2019,10,3.0,3.0,0.297881,...,6198.95825,31.339385,260172.625,113050.19375,55.568223,6531.261152,31.452557,284723.000000,144621.235150,53.536959


In [None]:
time_series.to_csv('agg_province_features.csv')

array(['index', 'country', 'admin_code', 'admin_name', 'year_month',
       'year', 'month', 'fews_ipc', 'fews_proj_near', 'ndvi_mean',
       'ndvi_anom', 'rain_mean', 'rain_anom', 'et_mean', 'et_anom',
       'acled_count', 'acled_fatalities', 'p_staple_food', 'area',
       'cropland_pct', 'pop', 'ruggedness_mean', 'pasture_pct',
       'land seizures_0', 'slashed export_0', 'price rise_0',
       'mass hunger_0', 'cyclone_0', 'failed crops_0',
       'disruption to farming_0', 'massive starvation_0',
       'abnormally low rainfall_0', 'withheld relief_0',
       'international alarm_0', 'reduced national output_0',
       'oppressive regimes_0', 'pests_0', 'continued deterioration_0',
       'forests destroyed_0', 'man-made disaster_0', 'food insecurity_0',
       'harvests are devastated_0', 'humanitarian situation_0',
       'economic impoverishment_0', 'clan battle_0',
       'population crisis_0', 'aid appeal_0', 'weather extremes_0',
       'anti-western policies_0', 'rinderp

# Add time lagged features


In [None]:
time_series = add_time_lagged(t_variant_traditional_factors, time_series=time_series)

array(['index', 'country', 'admin_code', 'admin_name', 'year_month',
       'year', 'month', 'fews_ipc', 'fews_proj_near', 'ndvi_mean',
       'ndvi_anom', 'rain_mean', 'rain_anom', 'et_mean', 'et_anom',
       'acled_count', 'acled_fatalities', 'p_staple_food', 'area',
       'cropland_pct', 'pop', 'ruggedness_mean', 'pasture_pct',
       'land seizures_0', 'slashed export_0', 'price rise_0',
       'mass hunger_0', 'cyclone_0', 'failed crops_0',
       'disruption to farming_0', 'massive starvation_0',
       'abnormally low rainfall_0', 'withheld relief_0',
       'international alarm_0', 'reduced national output_0',
       'oppressive regimes_0', 'pests_0', 'continued deterioration_0',
       'forests destroyed_0', 'man-made disaster_0', 'food insecurity_0',
       'harvests are devastated_0', 'humanitarian situation_0',
       'economic impoverishment_0', 'clan battle_0',
       'population crisis_0', 'aid appeal_0', 'weather extremes_0',
       'anti-western policies_0', 'rinderp

In [None]:
time_series = add_time_lagged(news_factors, time_series=time_series)

array(['index', 'country', 'admin_code', ..., 'mismanagement_0_country_6',
       'mismanagement_0_country_7', 'mismanagement_0_country_8'],
      dtype=object)

In [None]:
time_series = add_time_lagged(['fews_ipc'], end=21, diff=3, agg=False, time_series=time_series)

array(['index', 'country', 'admin_code', ..., 'fews_ipc_12',
       'fews_ipc_15', 'fews_ipc_18'], dtype=object)

In [None]:
time_series = add_time_lagged(['fews_proj_near'], start=3, end=4, diff=1, agg=False, time_series=time_series)

array(['index', 'country', 'admin_code', ..., 'fews_ipc_15',
       'fews_ipc_18', 'fews_proj_near_3'], dtype=object)

In [37]:
def diebold_mariano(preds, labels):
    sq_error = [(p-l)**2 for p,l in zip(preds, labels)]
    mean = np.mean(sq_error)
    n = len(preds)
    gammas = {}
    m = max(n,int(math.ceil(np.cbrt(n))+2))
    for k in range(m):
        gammas[k] = 0
        for i in range(k+1, n):
            gammas[k] += (sq_error[i] - mean)*(sq_error[i-k] - mean)
        gammas[k] = gammas[k]/n
    sum_gamma = gammas[0]
    for k in range(1, m):
        sum_gamma += 2*gammas[k]
    return np.sqrt(sum_gamma/n)

In [46]:
display(time_series.columns.values)

array(['index', 'country', 'admin_code', ..., 'fews_ipc_15',
       'fews_ipc_18', 'fews_proj_near_3'], dtype=object)

In [47]:
time_series.head()

Unnamed: 0,index,country,admin_code,admin_name,year_month,year,month,fews_ipc,fews_proj_near,ndvi_mean,...,mismanagement_0_country_6,mismanagement_0_country_7,mismanagement_0_country_8,fews_ipc_3,fews_ipc_6,fews_ipc_9,fews_ipc_12,fews_ipc_15,fews_ipc_18,fews_proj_near_3
0,30,Afghanistan,202,Kandahar,2009_07,2009,7,1.0,,0.106035,...,0.020729,0.020729,0.020729,1.0,1.0,1.0,1.0,1.0,1.0,
1,33,Afghanistan,202,Kandahar,2009_10,2009,10,1.0,,0.103009,...,-0.11289,-0.11289,-0.11289,1.0,1.0,1.0,1.0,1.0,1.0,
2,36,Afghanistan,202,Kandahar,2010_01,2010,1,2.0,,0.1096,...,0.020729,0.409551,0.409551,1.0,1.0,2.0,2.0,2.0,2.0,
3,39,Afghanistan,202,Kandahar,2010_04,2010,4,2.0,,0.111599,...,-0.11289,0.180825,0.180825,2.0,1.0,1.0,2.0,2.0,2.0,
4,42,Afghanistan,202,Kandahar,2010_07,2010,7,1.0,,0.096943,...,0.409551,0.138598,0.138598,2.0,2.0,1.0,1.0,1.0,1.0,


# Generate and save data for Fig 3A, B, C


In [None]:
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression
from sklearn import linear_model

from sklearn.metrics import mean_squared_error
from sklearn.metrics import average_precision_score, precision_recall_curve
from sklearn.metrics import auc

test_splits = [
    ((2010,7), (2011, 7)), 
    ((2011,7), (2012, 7)),
    ((2012,7), (2013, 7)), 
    ((2013,7), (2014, 7)), 
    ((2014,7), (2015, 7)), 
    ((2015,7), (2016, 7)), 
    ((2016,7), (2017, 7)), 
    ((2017,7), (2018, 7)),
    ((2018,7), (2019, 7)), 
    ((2019,2), (2020, 2)),
]
train_splits = [
    ((2009,7), (2010,4)),
    ((2009,7), (2011,1)),
    ((2009,7), (2011,10)),
    ((2009,7), (2012,7)),
    ((2009,7), (2013,7)),
    ((2009,7), (2014,1)),
    ((2009,7), (2015,1)),
    ((2009,7), (2015,10)),
    ((2009,7), (2016,10)),
    ((2009,7), (2017,2))]
dev_splits = [
    ((2010,4), (2010, 7)),
    ((2011,1), (2011, 7)),
    ((2011,10), (2012, 7)),
    ((2012,7), (2013, 7)),
    ((2013,4), (2014, 7)),
    ((2014,1), (2015, 7)),
    ((2015,1), (2016, 7)),
    ((2015,10), (2017, 7)),
    ((2016,10), (2018, 7)),
    ((2017,2), (2019, 2)),
]
rf = RandomForestRegressor(max_features='auto', n_estimators=100, 
                             min_samples_split=0.5, min_impurity_decrease=0.001, random_state=0)
ols = LinearRegression()

lasso = linear_model.Lasso(alpha=0.1)

def get_agg_lagged_features(factors):
    return ['{}_{}'.format(f, t) for f, t in zip(factors, range(3,9))] + ['{}_province_{}'.format(f, t) for f, t in zip(factors, range(3,9))] + ['{}_country_{}'.format(f, t) for f, t in zip(factors, range(3,9))]
        

features = {
    'traditional': time_series[
        ['{}_{}'.format('fews_ipc', t) for t in range(3,21,3)] + 
        get_agg_lagged_features(t_variant_traditional_factors) + 
        t_invariant_traditional_factors
    ], 
    'news': time_series[
        ['{}_{}'.format('fews_ipc', t) for t in range(3,21,3)] +
        get_agg_lagged_features(news_factors)
    ], 
    'traditional+news': time_series[
        ['{}_{}'.format('fews_ipc', t) for t in range(3,21,3)] +
        get_agg_lagged_features(t_variant_traditional_factors) + 
        t_invariant_traditional_factors +
        get_agg_lagged_features(news_factors)
    ],
    'expert': time_series['fews_proj_near_3'],
    'expert+traditional': time_series[
        ['fews_proj_near_3'] +
        ['{}_{}'.format('fews_ipc', t) for t in range(3,21,3)] + 
        get_agg_lagged_features(t_variant_traditional_factors) + 
        t_invariant_traditional_factors
    ],
    'expert+news': time_series[
        ['fews_proj_near_3'] +
        ['{}_{}'.format('fews_ipc', t) for t in range(3,21,3)] +
        get_agg_lagged_features(news_factors)
    ],
    'expert+traditional+news': time_series[
        ['fews_proj_near_3'] +
        ['{}_{}'.format('fews_ipc', t) for t in range(3,21,3)] +
        get_agg_lagged_features(t_variant_traditional_factors) + 
        t_invariant_traditional_factors +
        get_agg_lagged_features(news_factors)
    ]
}

labels_df = time_series['fews_ipc']

def get_time_split(df, start, end):
    return df[df['year'] >= start[0] & df['month'] >= start[1] & df['year'] <= end[0] & df['month'] <= end[1]]


fig_3a = pd.DataFrame(columns=['method', 'split', 'features', 'country', 'rmse', 'lower_bound', 'upper_bound'])
fig_3b = pd.DataFrame(columns=['method', 'split', 'features', 'aucpr'])
fig_3c = pd.DataFrame(columns=['method', 'split', 'features', 'recall_at_80p'])

thresholds = {'traditional': (2.236, 3.125), 
              'news': (1.907, 2.712), 
              'traditional+news': (2.105, 3.314),
              'expert': (2, 3),
              'expert+news': (1.912, 2.813),
              'expert+traditional': (2.241, 3.132),
              'expert+traditional+news': (2.172, 3.321)
             }

for train, dev, test in zip(train_splits, dev_splits, test_splits):
    for f, D in features.items():
        X = get_time_split(D, train[0], train[1])
        y = get_time_split(labels_df, test[0], test[1])
        X_test = get_time_split(D, test[0], test[1])
        for name, regr in zip(['RF', 'OLS', 'Lasso'], [rf, ols, lasso]):
            regr.fit(X, y)
            preds = regr.predict(X_test)
            labels = get_time_split(labels_df, test[0], test[1])
            rmse = mean_squared_error(labels, preds, squared=False)
            stderr = diebold_mariano(preds, labels)
            upper_bound = np.sqrt(rmse**2 + 1.96*stderr)
            lower_bound = np.sqrt(rmse**2 - 1.96*stderr)
            precision, recall, thresholds = precision_recall_curve(labels, preds)
            auc_precision_recall = auc(recall, precision)
            _row = pd.DataFrame.from_dict({'method': [name], 'split': [test], 'features': [f], 'country': ['all'],
                                           'rmse': [rmse], 'lower_bound': [lower_bound], 'upper_bound': [upper_bound]},
                                          orient='columns')
            fig_3a = pd.concat([fig_3a, _row], axis=0)
            _row = pd.DataFrame.from_dict({'method': [name], 'split': [test], 'features': [f], 
                                           'aucpr': [auc_precision_recall]},
                                          orient='columns')
            fig_3b = pd.concat([fig_3b, _row], axis=0)
            print ("Method: {}, Split: {}, Features: {}, AUCPR: {}".format(name, test, f, auc_precision_recall))
            print ("Method: {}, Split: {}, Features: {}, RMSE: {} [{}, {}]".format(name, test, f, rmse, lower_bound, upper_bound))
            
            recall_at_80p = 0
            for p_t, p_t_add_3, p_t_min_3 in zip(preds, preds[3:] + [1,1,1], preds[:-3]+[5,5,5]):
                u_b = thresholds[f]['upper_bound']
                l_b = thresholds[f]['lower_bound']
                if p_t >= u_b and p_t_add_3 >= u_b and p_t_min_3 <= l_b:
                    recall_at_80p += 1
            
            _row = pd.DataFrame.from_dict({'method': [name], 'split': [test], 'features': [f], 
                                           'recall_at_80p': [recall_at_80p]},
                                          orient='columns')
            fig_3c = pd.concat([fig_3c, _row], axis=0)
            
            for country in time_series['country'].unique():
                c_id = X_test[X_test['country']==country]
                labels_c = labels[c_id]
                preds_c = preds[c_id]
                rmse = mean_squared_error(labels_c, preds_c, squared=False)
                stderr = diebold_mariano(preds_c, labels_c)
                upper_bound = np.sqrt(rmse**2 + 1.96*stderr)
                lower_bound = np.sqrt(rmse**2 - 1.96*stderr)
                _row = pd.DataFrame.from_dict({'method': [name], 'split': [test], 'features': [f], 'country': [country],
                                           'rmse': [rmse], 'lower_bound': [lower_bound], 'upper_bound': [upper_bound]},
                                          orient='columns')
                fig_3a = pd.concat([fig_3a, _row], axis=0)
                print ("Country: {}, Method: {}, Split: {}, Features: {}, RMSE: {} [{}, {}]".format(country, name, test, f, rmse, lower_bound, upper_bound))

fig_3a.to_csv('fig_3a.csv')
fig_3b.to_csv('fig_3b.csv')
fig_3c.to_csv('fig_3c.csv')

The columns are:
 Index(['fews_ipc_3', 'fews_ipc_6', 'fews_ipc_9', 'fews_ipc_12', 'fews_ipc_15',
       'fews_ipc_18', 'ndvi_mean_3', 'ndvi_anom_4', 'rain_mean_5',
       'rain_anom_6', 'et_mean_7', 'et_anom_8', 'ndvi_mean_province_3',
       'ndvi_anom_province_4', 'rain_mean_province_5', 'rain_anom_province_6',
       'et_mean_province_7', 'et_anom_province_8', 'ndvi_mean_country_3',
       'ndvi_anom_country_4', 'rain_mean_country_5', 'rain_anom_country_6',
       'et_mean_country_7', 'et_anom_country_8', 'area', 'cropland_pct', 'pop',
       'ruggedness_mean', 'pasture_pct'],
      dtype='object')


KeyError: 'year'