# 📊 **Re-Implementation of "Predicting Food Crises Using News Streams"**

---

#### 🔍 **Objective**

This notebook aims to **reproduce and analyze** the methodology presented in the paper:

📄 **Paper:** [Predicting food crises using news streams](https://www.science.org/doi/10.1126/sciadv.abm3449)  
📊 **Dataset:** [Harvard Dataverse Repository](https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/CJDWUW)  
📜 **Original Code & Methods:** [GitHub - Regression Modeling (Step 5)](https://github.com/philippzi98/food_insecurity_predictions_nlp/blob/main/Step%205%20-%20Regression%20Modelling/README.md)

---

#### 🛠 **Methodology**

This implementation follows the **key steps** outlined in the paper to predict **food insecurity crises** using a combination of:
1️⃣ **Traditional Risk Factors** (conflict, climate, food prices, etc.)  
2️⃣ **News-Based Indicators** (text feature frequencies from news articles)  
3️⃣ **Lagging & Aggregation** (temporal dependencies at district, province, and country levels)  
4️⃣ **Machine Learning Models** (Random Forest, OLS, Lasso)

---

#### 🔗 **Reference Materials**

📄 **Supplementary Material:** Available in `supplemental_material_from_paper.pdf`  
📊 **Datasets Used:**

- `time_series_with_causes_zscore_full.csv` (Main dataset with time-series features)
- `famine-country-province-district-years-CS.csv` (Food insecurity classification)
- `matching_districts.csv` (Geographical standardization)


# 📚🔧 Import Libraries

In this notebook, we will use uv to manage our Python environment and packages efficiently. uv is a modern and fast package manager that simplifies virtual environment creation, and dependency installation. We will create a virtual environment, install necessary libraries, and ensure our environment stays consistent across different setups.


In [1]:
# # Uncoment the below cell to install `uv` if you have not already. You can also install it trhiugh `pip` by running `!pip install uv` but this will be within your current python environment and not globally.

# !curl -LsSf https://astral.sh/uv/install.sh | sh
# !uv venv world-bank
# !source world-bank/bin/activate

In [2]:
# !pip install -r requirements.txt
# Run the line below to create the working conda environment
# conda env create -f environment.yml 


In [3]:
import numpy as np
from IPython.display import display, Image
import os
import gdown
import zipfile
from fuzzywuzzy import fuzz
import math
import polars as pl
import time

In [4]:
url = "https://drive.google.com/uc?id=1YoQ1hz9RlaLr2xW3KoKCfJPyyO2PErym"
output = "data.zip"

if not os.path.exists("./data"):
    gdown.download(url, output, quiet=False) 
    zipfile.ZipFile('data.zip', 'r').extractall()
else:
    print("You already have the data downloaded and extracted")

You already have the data downloaded and extracted


## 📂 Load and Clean Data

**Understanding the Time-Series Dataset & Column Selection**

This dataset contains **district-level time-series data** on food insecurity risk factors, including:

- **📅 Temporal Information:** `year`, `month`, `year_month`
- **📍 Geographical Identifiers:** `admin_code`, `admin_name`, `province`, `country`
- **🌍 Traditional Risk Factors:** Climate (`rain_mean`, `ndvi_mean`), conflict (`acled_count`), food prices (`p_staple_food`)
- **📰 News-Based Indicators:** Proportions of news articles mentioning crisis-related keywords (`conflict_0`, `famine_0`, etc.)
- **📉 Food Insecurity Label:** `fews_ipc` (Integrated Phase Classification)

🔥 **Columns We Will Drop & Why**
✔ **Redundant Aggregations:** `_1`, `_2` columns (province & country-level values) since we will recompute aggregations from scratch anyways.  
✔ **Unnamed/Index Columns:** `Unnamed: 0` as it is unnecessary. It is just a duplicate of default index.
✔ **Unnecessary Identifiers:** If `admin_code` and `admin_name`, after matching these to `matching_districts.csv`, we can drop them.

---

> ⚠️ **NOTE:**  
> For a detailed explanation of the dataset and features, refer to the [`explore_time_series.ipynb`](./explore_time_series.ipynb) notebook.


In [5]:
start_time = time.time()

In [6]:
time_series = pl.read_csv('./data/time_series_with_causes_zscore_full.csv')
admins = pl.read_csv('./data/famine-country-province-district-years-CS.csv', schema_overrides={"CS": pl.Float64})
valid_matching = pl.read_csv('./data/matching_districts.csv')

In [7]:
time_series.head(5)

Unnamed: 0_level_0,index,country,admin_code,admin_name,centx,centy,year_month,year,month,fews_ipc,fews_ha,fews_proj_near,fews_proj_near_ha,fews_proj_med,fews_proj_med_ha,ndvi_mean,ndvi_anom,rain_mean,rain_anom,et_mean,et_anom,acled_count,acled_fatalities,p_staple_food,area,cropland_pct,pop,ruggedness_mean,pasture_pct,change_fews,land seizures_0,land seizures_1,land seizures_2,slashed export_0,slashed export_1,slashed export_2,…,authoritarian_2,dictators_0,dictators_1,dictators_2,clans_0,clans_1,clans_2,gastrointestinal_0,gastrointestinal_1,gastrointestinal_2,terrorist_0,terrorist_1,terrorist_2,warlord_0,warlord_1,warlord_2,d'etat_0,d'etat_1,d'etat_2,overthrow_0,overthrow_1,overthrow_2,convoys_0,convoys_1,convoys_2,carbon_0,carbon_1,carbon_2,mayhem_0,mayhem_1,mayhem_2,dehydrated_0,dehydrated_1,dehydrated_2,mismanagement_0,mismanagement_1,mismanagement_2
i64,i64,str,i64,str,f64,f64,str,i64,i64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,i64,i64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,…,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
0,30,"""Afghanistan""",202,"""Kandahar""",65.709343,31.043618,"""2009_07""",2009,7,1.0,,,,,,0.106035,106.547146,0.353588,-0.070848,0.191125,-0.073903,0,0,1.065669,54174.53381,1.417796,1241226.0,101047.1587,16.246279,0.0,-0.765667,-0.426667,0.886,0.597667,-0.987,1.449333,…,0.75,0.496,1.557333,1.252333,0.827,-0.035667,-0.02,-0.192,0.281667,-0.259667,-0.284333,1.626,0.532667,-0.668667,1.497333,-0.794667,0.647333,1.652333,-0.029,-0.891333,0.848333,1.472667,0.112667,-0.887,-0.963667,1.265333,-0.493667,1.053,0.667,-0.171,-0.833,0.173667,0.168,1.284667,-0.073,-0.427667,0.668333
1,33,"""Afghanistan""",202,"""Kandahar""",65.709343,31.043618,"""2009_10""",2009,10,1.0,,,,,,0.103009,106.034013,0.409304,-0.116134,0.69447,0.225598,0,0,1.100531,54174.53381,1.417796,1241226.0,101047.1587,16.246279,1.0,-0.556272,-0.791605,-0.903605,-0.933739,-0.645739,-0.918405,…,-1.027332,-0.665846,-0.591846,-1.039846,-0.756904,-0.590571,-1.234904,-0.545727,-0.474394,-0.841394,-1.037016,-0.917683,-0.787683,-0.811291,-0.713958,-1.257625,-0.850261,-0.831261,-0.759594,-0.948892,-1.198892,-0.883225,-0.728972,-1.203638,-0.874305,-0.765146,-1.141479,-0.660812,-0.63658,-0.520247,-0.782913,-0.671587,-0.612254,-0.926921,-0.510467,-0.625133,-0.452467
2,36,"""Afghanistan""",202,"""Kandahar""",65.709343,31.043618,"""2010_01""",2010,1,2.0,,,,,,0.1096,111.433187,3.894158,-2.333251,3.441319,-1.450951,0,0,0.98839,54174.53381,1.417796,1280853.0,101047.1587,16.246279,0.0,-0.006667,0.431,0.67,0.755667,-0.263,-0.574,…,-0.307333,0.010333,0.485,-0.820333,1.460333,-0.351,0.817667,1.506333,1.484667,-0.378667,0.455,-0.871,0.766,1.595667,0.023667,-0.968667,0.571667,-0.837333,0.444667,0.279,0.078,0.614333,-0.868333,-0.598,1.316,0.058333,1.368,-0.134333,1.447667,-0.844333,0.778667,-0.676,-0.689667,0.293333,0.530333,-0.471333,0.955333
3,39,"""Afghanistan""",202,"""Kandahar""",65.709343,31.043618,"""2010_04""",2010,4,2.0,,,,,,0.111599,94.212242,1.609664,-0.788739,1.851542,-0.771469,0,0,0.992492,54174.53381,1.417796,1280853.0,101047.1587,16.246279,-1.0,-0.193697,-0.613697,-0.307364,0.311536,0.337869,-0.524797,…,-0.369667,-0.090077,0.224923,-0.58841,-0.616381,-0.605381,-0.114381,-0.79397,0.29903,-0.63397,-0.722159,-0.130159,-0.123825,-0.130521,-0.578521,0.090146,0.04763,0.137297,-0.603036,0.362613,0.231613,0.018279,0.480986,-0.427347,-0.121014,0.026073,0.165406,-0.326927,-0.594877,0.16479,-0.90521,-0.62054,0.165794,0.045794,-1.0116,-0.8106,-0.2056
4,42,"""Afghanistan""",202,"""Kandahar""",65.709343,31.043618,"""2010_07""",2010,7,1.0,,,,,,0.096943,97.411677,0.3938336,-0.030602,0.291468,0.026441,0,0,1.024889,54174.53381,1.417796,1280853.0,101047.1587,16.246279,1.0,-0.787272,-0.725605,-0.879272,-0.598072,-0.803072,-0.817739,…,-0.748332,-0.611846,-0.511179,-0.470512,-0.791904,-1.053238,-0.653238,-0.509394,-0.462727,-0.856727,-0.69435,-1.102683,-1.13235,-1.215958,-0.832291,-0.948291,-0.865261,-0.812261,-0.645928,-1.119225,-0.977558,-0.758892,-1.060638,-0.876972,-1.210305,-0.673479,-1.090479,-1.085146,-0.709913,-0.867913,-0.770247,-0.787921,-0.974587,-0.946921,-0.611133,-0.7098,-0.6228


In [8]:
t_variant_traditional_factors = [ 'p_staple_food']
t_variant_traditional_factors = ['ndvi_mean', 'ndvi_anom', 'rain_mean', 'rain_anom', 'et_mean', 'et_anom', 
                                    'acled_count', 'acled_fatalities', 'p_staple_food']
t_invariant_traditional_factors = ['area', 'cropland_pct', 'pop', 'ruggedness_mean', 'pasture_pct']
news_factors = [name for name in time_series.columns if '_0' in name]


In [9]:
news_factors[0]


'land seizures_0'

In [10]:
potential_extra_cols = set(time_series.columns) - set(t_variant_traditional_factors) - set(t_invariant_traditional_factors) - set(news_factors)
potential_extra_cols = [col for col in potential_extra_cols if not col.endswith(('_1', '_2', '_3'))]
print("Potential extra columns", sorted(potential_extra_cols))

Potential extra columns ['', 'admin_code', 'admin_name', 'centx', 'centy', 'change_fews', 'country', 'fews_ha', 'fews_ipc', 'fews_proj_med', 'fews_proj_med_ha', 'fews_proj_near', 'fews_proj_near_ha', 'index', 'month', 'year', 'year_month']


### 🌍 Admin Level Mapping: Standardizing Geographical Identifiers

In this section, we will **map and standardize** the `admin_code` and `admin_name` fields to their corresponding **district, province, and country names**. This step is **crucial** for ensuring **consistency** across different datasets and enabling **accurate aggregations** at multiple administrative levels.

🛠 **Why is Admin Level Mapping Important?**
✅ Different datasets may use **slightly different spellings or formats** for district names.  
✅ Some district names might be **missing or misspelled**, requiring standardization.  
✅ We need to **match and align** district names across various sources before aggregating at **province and country levels**.  
✅ Proper mapping allows us to **merge datasets correctly** without losing information.  

📌 **Steps in Admin Mapping**
1️⃣ **Load the `matching_districts.csv` file**, which provides the mapping between different district name variations.  
2️⃣ **Identify missing or unmatched `admin_name` values** and find their closest matches using fuzzy matching techniques.  
3️⃣ **Ensure that each `admin_code` uniquely maps to one `district`, `province`, and `country`.**  
4️⃣ **Replace inconsistent names** in the dataset with their standardized versions.  
5️⃣ **Aggregate data at the `province` and `country` levels** after ensuring all districts are correctly mapped.  


In [11]:
print(admins.select(pl.col("country").n_unique()).to_numpy()[0][0])

39


In [12]:
admins.columns

['', 'country', 'district', 'year', 'month', 'CS', 'province']

In [13]:
admin_names = time_series['admin_name'].unique()
districts = admins['district'].unique()
provinces = admins['province'].unique()
countries = admins['country'].unique()

In [14]:
print (len(admin_names), len(districts), len(provinces), len(countries))
print (len(set(admin_names).difference(districts)))
missing_admin_names = set(admin_names).difference(districts)
print (len(missing_admin_names.difference(provinces)))
missing_admin_names = missing_admin_names.difference(provinces)

1142 4113 474 39
369
230


### Fuzzy String Matching for Missing Names

The function uses **fuzzy string matching** to find the best approximate matches for missing administrative names (e.g., districts and provinces). 

- Finds the **best matching district/province** for each missing name.
- Uses **fuzzy string matching** to calculate the similarity between missing names and known names.
- Returns a dictionary that maps each missing name to its closest match.


In [15]:
def find_matching(missing, names):
    matching_districts = {}
    for m in missing:
        max_overlap = 0
        nearest_d = None
        for d in names:
            d = str(d)
            dist = fuzz.partial_ratio(m, d)
            if dist > max_overlap:
                max_overlap = dist
                nearest_d = d
        matching_districts[m] = nearest_d
    return matching_districts


matching = find_matching(missing_admin_names, districts)
matching_p = find_matching(missing_admin_names, provinces)

# manually verify matching and update
for k in matching.keys():
    print (k, matching[k], matching_p[k])


Gourma-Rharous Gourma Ghor
West Harerge West Hararge West Darfur
MPongwe Mpongwe Bong
Lulua Luilu Lualaba
Kantché Kantche Kano
Bale.1 Bale Bay
Valliere Vallieres Niger
Mwingi Mwingi West Migori
Abu Hamad Abu Hamed Hilmand
Sheikh Jebrat El Sheikh Sahel
Beni San Benito Benue
Al Kurumik Qulansiyah wa `Abd Al Kuri Ituri
Gucha Kabuchai Ahuachapan
La Pendé La Pende Lac
Al Faw El Faw Al Jawf
Central Kisii Kiti Central
North Gonder North Gondar North
South Khartoum Khartoum Khartoum
Anse-D'Ainault `Ain Abia
Aguié Aguie Bangui
Guji Gujii Guidimaka
Mashra'ah wa Hadnan Mashra`ah wa Hadnan Kankan
North al Gazera Ganze North
Id El Ghanem Ganye Kanem
Lac-Léré Lac-Lere Lac
Nandi South Nnewi South Nandi
Al Gutaina El Gutaina Rutana
South Gonder South Gondar Sud
Sharg En Nile Sahar Niger
Mangwe (South) Mangwe Southern
Croix-Des-Bouquets Bo Ouest
Gweru Gweru Rural Meru
Eastern Tigray Miga Eastern
Chiredzi Chiredzi Rural Moyen-Chari
Port-Salut Port Salut Salamat
Kwekwe Kwekwe Rural Kwale
UMP Keur Macene 

### Encoding Decoding

`to_ascii_escaped(s)`: Converts a Unicode string to an ASCII-safe representation using **unicode-escape**.

`from_ascii_escaped(escaped)`: Converts the escaped ASCII string back into its original Unicode form.

In [16]:
def to_ascii_escaped(s):
    """
    Convert a Unicode string to an ASCII-safe string using unicode-escape.
    This will replace non-ASCII characters with their escape sequences.
    """
    if isinstance(s, bytes):
        s = s.decode('utf-8')
    # Using 'unicode-escape' encoding produces a bytes object,
    # then decode it to get an ASCII string.
    return s.encode('unicode-escape').decode('ascii')

def from_ascii_escaped(escaped):
    """
    Convert the ASCII-escaped string back to the original Unicode string.
    """
    # Encode the ASCII string to bytes, then decode using 'unicode-escape'
    return escaped.encode('ascii').decode('unicode-escape')


### Finding the Province for a Given District or Province

`find_province(x)`, finds the **province** corresponding to a given administrative name. It accounts for:
- **Direct Lookups** (Exact match in known district/province lists)
- **Fuzzy Matching** (Using ASCII-safe transformation for inconsistent text encoding)
- **Validation Against a Predefined Mapping (`valid_matching`)**

In [17]:
# Define matched globally
matched = valid_matching['missing'].unique()

def to_ascii_escaped(s):
    """
    Convert a Unicode string to an ASCII-safe string using unicode-escape.
    This will replace non-ASCII characters with their escape sequences.
    """
    if isinstance(s, bytes):
        s = s.decode('utf-8')
    return s.encode('unicode-escape').decode('ascii')

def find_province(x):
    try:
        # Ensure x is a Unicode string.
        if isinstance(x, bytes):
            x = x.decode('utf-8')
        
        if x in districts:
            return admins.filter(pl.col('district') == x).select('province').to_series()[0]
        elif x in provinces:
            return x

        escaped_x = to_ascii_escaped(x)

        if escaped_x in matched:
            # print(f"Found {escaped_x} in matched")
            v = valid_matching.filter(pl.col('missing') == escaped_x)
            # print(f"Matched value: {v}")
            if v['match'][0] == 'district':
                x2 = v['district'][0]
                return admins.filter(pl.col('district') == x2)['province'][0]
            elif v['match'][0] == 'province':
                return v['province'][0] 
        
        raise Exception("No matching province found")
    except Exception as e:
        raise Exception(f"Province not found for: {x} ({e})")


### Handling Admin Names with Accented Characters and Mapping to Provinces

Maps `admin_names` to provinces using the `find_province(a)` function.  
If a **direct lookup fails**, it tries to handle cases where the **admin name contains accented characters** (`é`, `è`, `ô`) ->  (encoding decoding issues resolved through directly replacing these with 'e' or 'o', leads to finding a valid match). 

In [18]:
admin_to_province = {}
for a in admin_names:
    try:
        admin_to_province[a] = find_province(a)
    except Exception as e:
        # Print the admin name that caused an error
        print("Error with:", a)
        # Check if a contains accented characters "é" or "è"
        if 'é' in a or 'è' in a or 'ô' in a:
            a_modified = a.replace('é', 'e').replace('è', 'e').replace('ô', 'o')
            # Check if the modified name is in districts
            if a_modified in districts:
                # Use the modified name to look up the province from admins
                try:
                    province = admins.filter(pl.col('district') == a_modified)['province'][0]
                    admin_to_province[a] = province
                    print(f"Replaced '{a}' with '{a_modified}', found province: {province}")
                except Exception as ex:
                    print(f"Modified name '{a_modified}' not found in admins: {ex}")
            else:
                print(f"Modified name '{a_modified}' not in districts.")
        else:
            print(f"No accented e found in '{a}'.")


Error with: Barh-Kôh
Replaced 'Barh-Kôh' with 'Barh-Koh', found province: Moyen-Chari
Error with: Tillabéri
Replaced 'Tillabéri' with 'Tillaberi', found province: Tillaberi
Error with: Lac-Léré
Replaced 'Lac-Léré' with 'Lac-Lere', found province: Mayo-Kebbi Ouest
Error with: Filingué
Replaced 'Filingué' with 'Filingue', found province: Tillaberi
Error with: Aguié
Replaced 'Aguié' with 'Aguie', found province: Maradi
Error with: Kantché
Replaced 'Kantché' with 'Kantche', found province: Zinder
Error with: La Nya Pendé
Replaced 'La Nya Pendé' with 'La Nya Pende', found province: Logone Oriental
Error with: Téra
Replaced 'Téra' with 'Tera', found province: Tillaberi
Error with: Gouré
Replaced 'Gouré' with 'Goure', found province: Zinder
Error with: Mangalmé
Replaced 'Mangalmé' with 'Mangalme', found province: Guera
Error with: Illéla
Replaced 'Illéla' with 'Illela', found province: Sokoto
Error with: Bankilaré
Replaced 'Bankilaré' with 'Bankilare', found province: Tillaberi
Error with: Ma

### Mapping Administrative Names to Provinces in time_series

Maps `admin_name` to their respective **provinces** using a precomputed dictionary - >`admin_to_province` in `time_series`.


In [19]:
time_series = time_series.with_columns(
    pl.col('admin_name').map_elements(
        lambda x: admin_to_province[x] if x in admin_to_province else admin_to_province.get(x.replace('ô', 'o')),
        return_dtype=pl.Utf8
    ).alias('province')
)

# ⏳ Time Lagging & Feature Engineering

#### 📅 **Why Use Lagging?**

To predict food insecurity **for a given quarter**, we use:

- **6 months of historical values** for traditional & news-based features.
- **Province & country-level aggregations** to capture broader shocks.
- **6 quarters of lagged IPC phase values** to model temporal dependencies.

#### ⚡ **Optimized Lagging Approach**

To improve computational efficiency, we:
✔ Use `groupby()` for **fast province & country-level aggregations**.  
✔ Merge lagged data via `merge()` instead of slow `.apply()`.  
✔ Only keep **past data** to ensure no data leakage.


In [20]:
def add_time_lagged(features, start=3, end=9, diff=1, agg=True, time_series=time_series):
    levels = ['', '_province', '_country'] if agg else ['']
    
    # Work on a copy to avoid modifying the original during processing
    working_df = time_series.clone()
    
    # Precompute a mapping for each feature (with its suffix) for fast lookups.
    # For each row, its lookup key will be: admin_code + '_' + year_month.
    lookup_maps = {}  # dict mapping f_s -> mapping dict
    for suffix in levels:
        for f in features:
            f_s = f + suffix
            
            # Create a DataFrame with just the columns we need and create the key
            keys_df = working_df.select(
                (pl.col("admin_code").cast(str) + "_" + pl.col("year_month").cast(str)).alias("key"),
                pl.col(f_s)
            )
            
            # Group by key and take first value (handles duplicates efficiently)
            # Then convert directly to dict
            mapping = keys_df.group_by("key").agg(pl.col(f_s).first()).to_dict(as_series=False)
            mapping = dict(zip(mapping["key"], mapping[f_s]))
            
            lookup_maps[f_s] = mapping

    # Prepare list to collect all new columns (as Series)
    new_cols = {}
    
    # Process each feature and lag combination
    for suffix in levels:
        for f in features:
            f_s = f + suffix
            mapping = lookup_maps[f_s]
            for t in range(start, end, diff):
                col_name = f"{f_s}_{t}"
                if col_name in time_series.columns:
                    continue
                
                # Create DataFrame with lagged months and years
                temp_df = working_df.with_columns([
                    pl.when(pl.col("month") - t <= 0)
                    .then(pl.col("year") - 1)
                    .otherwise(pl.col("year"))
                    .alias("l_year"),
                    
                    ((pl.col("month") - 1 - t) % 12 + 1).alias("l_month")
                ])
                
                # Create reference key
                temp_df = temp_df.with_columns(
                    (pl.col("admin_code").cast(str) + "_" + 
                    pl.col("l_year").cast(str) + "_" + 
                    pl.col("l_month").cast(str)).alias("ref_key")
                )
                
                # Do efficient lookup using the precomputed mapping
                keys = list(mapping.keys())
                values = [mapping[k] for k in keys]
                lookup_df = pl.DataFrame({"key": keys, "value": values})
                
                # Join with lookup DataFrame to get mapped values
                result = temp_df.join(
                    lookup_df, 
                    left_on="ref_key", 
                    right_on="key", 
                    how="left"
                )
                
                # Fill null values with original values from f_s
                result = result.with_columns(
                    pl.when(pl.col("value").is_null())
                    .then(pl.col(f_s))
                    .otherwise(pl.col("value"))
                    .alias(col_name)
                )
                
                # Add to new_cols
                new_cols[col_name] = result.select(pl.col(col_name)).to_series()
                
    # If any new columns were created, add them to the original time_series DataFrame.
    if new_cols:
        new_cols_df = pl.DataFrame(new_cols)
        time_series = time_series.hstack(new_cols_df)
        
        
    return time_series


# Province & Country-Level Aggregation

This function aggregates feature values at the province and country levels to capture regional trends, aiding in food insecurity prediction. The process includes:

- **Grouping by year_month and level:** Data is grouped by year_month and the specified level (province or country) to calculate the mean of features, reflecting regional trends over time.

- **Applying transformations efficiently:** Instead of merging aggregated data, `transform("mean")` is used to directly assign the computed mean to each row, avoiding unnecessary joins and improving performance.  

#### ⚡ **Efficiency Gains**

- **Fast Aggregation**: Uses `groupby()` for efficient aggregation.
- **Avoids Costly Joins**: Eliminates the need for `merge()` by using `transform()` instead, reducing computational overhead.  
- **Memory Efficiency**: Converts the `level` column to a categorical type to reduce memory usage.

This approach ensures faster processing while maintaining the quality of aggregated features.


In [21]:
def add_agg_factors(features, level):
    global time_series
      
    # Create expressions for each feature to calculate the mean and rename
    mean_exprs = [
        pl.col(f).mean().over(["year_month", level]).alias(f"{f}_{level}")
        for f in features
    ]
    
    # Add the new aggregated columns to the DataFrame
    time_series = time_series.with_columns(mean_exprs)
    
    return time_series


In [22]:
# Aggregating news factors
time_series = add_agg_factors(news_factors, level='country')
time_series = add_agg_factors(news_factors, level='province')

# Aggregating variant traditional factors
time_series = add_agg_factors(t_variant_traditional_factors, level='province')
time_series = add_agg_factors(t_variant_traditional_factors, level='country')

# Aggregating invariant traditional factors
time_series = add_agg_factors(t_invariant_traditional_factors, level='province')
time_series = add_agg_factors(t_invariant_traditional_factors, level='country')

# Drop null values
time_series = time_series.drop_nulls()
time_series.head()

Unnamed: 0_level_0,index,country,admin_code,admin_name,centx,centy,year_month,year,month,fews_ipc,fews_ha,fews_proj_near,fews_proj_near_ha,fews_proj_med,fews_proj_med_ha,ndvi_mean,ndvi_anom,rain_mean,rain_anom,et_mean,et_anom,acled_count,acled_fatalities,p_staple_food,area,cropland_pct,pop,ruggedness_mean,pasture_pct,change_fews,land seizures_0,land seizures_1,land seizures_2,slashed export_0,slashed export_1,slashed export_2,…,terrorist_0_province,warlord_0_province,d'etat_0_province,overthrow_0_province,convoys_0_province,carbon_0_province,mayhem_0_province,dehydrated_0_province,mismanagement_0_province,ndvi_mean_province,ndvi_anom_province,rain_mean_province,rain_anom_province,et_mean_province,et_anom_province,acled_count_province,acled_fatalities_province,p_staple_food_province,ndvi_mean_country,ndvi_anom_country,rain_mean_country,rain_anom_country,et_mean_country,et_anom_country,acled_count_country,acled_fatalities_country,p_staple_food_country,area_province,cropland_pct_province,pop_province,ruggedness_mean_province,pasture_pct_province,area_country,cropland_pct_country,pop_country,ruggedness_mean_country,pasture_pct_country
i64,i64,str,i64,str,f64,f64,str,i64,i64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,i64,i64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,…,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
11,63,"""Afghanistan""",202,"""Kandahar""",65.709343,31.043618,"""2012_04""",2012,4,1.0,0.0,1.0,0.0,1.0,0.0,0.116692,98.511837,2.908423,0.510019,4.698466,2.075454,0,0,1.019085,54174.53381,1.417796,1379956.0,101047.1587,16.246279,0.0,0.864333,0.262,1.209,0.010333,-0.543667,-0.852667,…,-0.224,0.404,0.706,-0.343667,-0.690333,-0.267,-0.291333,-0.553333,1.016667,0.116692,98.511837,2.908423,0.510019,4.698466,2.075454,0.0,0.0,1.019085,0.168389,100.05702,8.605167,1.043137,8.770066,2.15922,0.0,0.0,1.151653,54174.53381,1.417796,1379956.0,101047.1587,16.246279,18883.616188,7.335028,874755.441176,316314.116188,49.145729
12,66,"""Afghanistan""",202,"""Kandahar""",65.709343,31.043618,"""2012_07""",2012,7,1.0,0.0,1.0,0.0,1.0,0.0,0.104068,104.571165,0.363717,-0.060719,0.254137,-0.010891,0,0,1.031224,54174.53381,1.417796,1379956.0,101047.1587,16.246279,0.0,0.582333,0.015333,1.219667,-0.936667,-0.993667,1.612,…,0.169333,0.063667,-0.807333,-0.633667,-0.818667,0.846667,1.122,0.224,0.184,0.104068,104.571165,0.363717,-0.060719,0.254137,-0.010891,0.0,0.0,1.031224,0.172991,107.329397,1.145694,-0.271082,2.026944,0.019929,0.0,0.0,1.172457,54174.53381,1.417796,1379956.0,101047.1587,16.246279,18883.616188,7.335028,874755.441176,316314.116188,49.145729
13,69,"""Afghanistan""",202,"""Kandahar""",65.709343,31.043618,"""2012_10""",2012,10,1.0,0.0,1.0,0.0,1.0,0.0,0.101172,104.14274,0.452095,-0.073343,0.628785,0.159913,0,0,1.194955,54174.53381,1.417796,1379956.0,101047.1587,16.246279,0.0,0.964333,-0.597333,-0.034333,-0.6,0.105,-0.277333,…,1.068,-0.714333,1.443,1.471333,0.113,1.133333,0.344667,-0.665667,-0.436667,0.101172,104.14274,0.452095,-0.073343,0.628785,0.159913,0.0,0.0,1.194955,0.139471,104.454137,1.890517,0.441801,2.518038,0.56151,0.0,0.0,1.300843,54174.53381,1.417796,1379956.0,101047.1587,16.246279,18883.616188,7.335028,874755.441176,316314.116188,49.145729
14,72,"""Afghanistan""",202,"""Kandahar""",65.709343,31.043618,"""2013_01""",2013,1,1.0,0.0,1.0,0.0,1.0,0.0,0.095679,97.278972,2.836539,-3.39087,4.793217,-0.099054,0,0,1.157092,54174.53381,1.417796,1429508.0,101047.1587,16.246279,0.0,0.426,-0.806,0.714,-0.188667,0.121667,1.336,…,-0.805333,0.564333,1.416,1.163,-0.271,1.477,-0.470333,-0.641667,0.95,0.095679,97.278972,2.836539,-3.39087,4.793217,-0.099054,0.0,0.0,1.157092,0.074545,90.765142,3.960445,-2.072366,5.509975,0.162572,0.0,0.0,1.368094,54174.53381,1.417796,1429508.0,101047.1587,16.246279,18883.616188,7.335028,901519.941176,316314.116188,49.145729
15,75,"""Afghanistan""",202,"""Kandahar""",65.709343,31.043618,"""2013_04""",2013,4,1.0,0.0,1.0,0.0,1.0,0.0,0.135269,114.19412,2.997978,0.599575,3.522834,0.899823,0,0,1.042512,54174.53381,1.417796,1429508.0,101047.1587,16.246279,0.0,-0.503,-0.131333,1.169333,-0.295667,1.138,0.287,…,1.436333,0.916,0.927,0.314333,-0.043667,1.304333,0.085333,1.046333,0.838667,0.135269,114.19412,2.997978,0.599575,3.522834,0.899823,0.0,0.0,1.042512,0.181455,106.985048,10.050844,2.488815,8.968665,2.35782,0.0,0.0,1.414363,54174.53381,1.417796,1429508.0,101047.1587,16.246279,18883.616188,7.335028,901519.941176,316314.116188,49.145729


# Add time lagged features


In [23]:
time_series = add_time_lagged(t_variant_traditional_factors, time_series=time_series)
time_series = add_time_lagged(news_factors, time_series=time_series)
time_series = add_time_lagged(['fews_ipc'], end=21, diff=3, agg=False, time_series=time_series)
time_series = add_time_lagged(['fews_proj_near'], start=3, end=4, diff=1, agg=False, time_series=time_series)

# Drop null values again
time_series = time_series.drop_nulls()
time_series.shape

(28141, 4070)

In [24]:
time_series.head()

Unnamed: 0_level_0,index,country,admin_code,admin_name,centx,centy,year_month,year,month,fews_ipc,fews_ha,fews_proj_near,fews_proj_near_ha,fews_proj_med,fews_proj_med_ha,ndvi_mean,ndvi_anom,rain_mean,rain_anom,et_mean,et_anom,acled_count,acled_fatalities,p_staple_food,area,cropland_pct,pop,ruggedness_mean,pasture_pct,change_fews,land seizures_0,land seizures_1,land seizures_2,slashed export_0,slashed export_1,slashed export_2,…,convoys_0_country_3,convoys_0_country_4,convoys_0_country_5,convoys_0_country_6,convoys_0_country_7,convoys_0_country_8,carbon_0_country_3,carbon_0_country_4,carbon_0_country_5,carbon_0_country_6,carbon_0_country_7,carbon_0_country_8,mayhem_0_country_3,mayhem_0_country_4,mayhem_0_country_5,mayhem_0_country_6,mayhem_0_country_7,mayhem_0_country_8,dehydrated_0_country_3,dehydrated_0_country_4,dehydrated_0_country_5,dehydrated_0_country_6,dehydrated_0_country_7,dehydrated_0_country_8,mismanagement_0_country_3,mismanagement_0_country_4,mismanagement_0_country_5,mismanagement_0_country_6,mismanagement_0_country_7,mismanagement_0_country_8,fews_ipc_3,fews_ipc_6,fews_ipc_9,fews_ipc_12,fews_ipc_15,fews_ipc_18,fews_proj_near_3
i64,i64,str,i64,str,f64,f64,str,i64,i64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,i64,i64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,…,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
11,63,"""Afghanistan""",202,"""Kandahar""",65.709343,31.043618,"""2012_04""",2012,4,1.0,0.0,1.0,0.0,1.0,0.0,0.116692,98.511837,2.908423,0.510019,4.698466,2.075454,0,0,1.019085,54174.53381,1.417796,1379956.0,101047.1587,16.246279,0.0,0.864333,0.262,1.209,0.010333,-0.543667,-0.852667,…,0.323402,0.323402,0.323402,0.323402,0.323402,0.323402,0.458894,0.458894,0.458894,0.458894,0.458894,0.458894,0.241553,0.241553,0.241553,0.241553,0.241553,0.241553,0.743474,0.743474,0.743474,0.743474,0.743474,0.743474,0.450139,0.450139,0.450139,0.450139,0.450139,0.450139,1.0,1.0,1.0,1.0,1.0,1.0,1.0
12,66,"""Afghanistan""",202,"""Kandahar""",65.709343,31.043618,"""2012_07""",2012,7,1.0,0.0,1.0,0.0,1.0,0.0,0.104068,104.571165,0.363717,-0.060719,0.254137,-0.010891,0,0,1.031224,54174.53381,1.417796,1379956.0,101047.1587,16.246279,0.0,0.582333,0.015333,1.219667,-0.936667,-0.993667,1.612,…,0.207569,0.207569,0.207569,0.207569,0.207569,0.207569,0.417088,0.417088,0.417088,0.417088,0.417088,0.417088,0.30201,0.30201,0.30201,0.30201,0.30201,0.30201,0.219176,0.219176,0.219176,0.219176,0.219176,0.219176,0.398922,0.398922,0.398922,0.398922,0.398922,0.398922,1.0,1.0,1.0,1.0,1.0,1.0,1.0
13,69,"""Afghanistan""",202,"""Kandahar""",65.709343,31.043618,"""2012_10""",2012,10,1.0,0.0,1.0,0.0,1.0,0.0,0.101172,104.14274,0.452095,-0.073343,0.628785,0.159913,0,0,1.194955,54174.53381,1.417796,1379956.0,101047.1587,16.246279,0.0,0.964333,-0.597333,-0.034333,-0.6,0.105,-0.277333,…,0.588971,0.588971,0.588971,0.588971,0.588971,0.588971,0.160451,0.160451,0.160451,0.160451,0.160451,0.160451,0.421833,0.421833,0.421833,0.421833,0.421833,0.421833,0.373696,0.373696,0.373696,0.373696,0.373696,0.373696,0.2515,0.2515,0.2515,0.2515,0.2515,0.2515,1.0,1.0,1.0,1.0,1.0,1.0,1.0
14,72,"""Afghanistan""",202,"""Kandahar""",65.709343,31.043618,"""2013_01""",2013,1,1.0,0.0,1.0,0.0,1.0,0.0,0.095679,97.278972,2.836539,-3.39087,4.793217,-0.099054,0,0,1.157092,54174.53381,1.417796,1429508.0,101047.1587,16.246279,0.0,0.426,-0.806,0.714,-0.188667,0.121667,1.336,…,0.588971,0.310735,0.310735,0.310735,0.310735,0.310735,0.160451,0.572353,0.572353,0.572353,0.572353,0.572353,0.421833,0.465657,0.465657,0.465657,0.465657,0.465657,0.373696,0.654549,0.654549,0.654549,0.654549,0.654549,0.2515,0.362775,0.362775,0.362775,0.362775,0.362775,1.0,1.0,1.0,1.0,1.0,1.0,1.0
15,75,"""Afghanistan""",202,"""Kandahar""",65.709343,31.043618,"""2013_04""",2013,4,1.0,0.0,1.0,0.0,1.0,0.0,0.135269,114.19412,2.997978,0.599575,3.522834,0.899823,0,0,1.042512,54174.53381,1.417796,1429508.0,101047.1587,16.246279,0.0,-0.503,-0.131333,1.169333,-0.295667,1.138,0.287,…,0.294579,0.294579,0.294579,0.588971,0.294579,0.294579,0.079251,0.079251,0.079251,0.160451,0.079251,0.079251,0.079346,0.079346,0.079346,0.421833,0.079346,0.079346,0.283934,0.283934,0.283934,0.373696,0.283934,0.283934,0.151222,0.151222,0.151222,0.2515,0.151222,0.151222,1.0,1.0,1.0,1.0,1.0,1.0,1.0


In [25]:
end_time = time.time()

print(f"Total execution time with Polars: {end_time - start_time:.2f} seconds")

Total execution time with Polars: 22.56 seconds


# Run the Model


In [41]:
from joblib import Parallel, delayed
from cuml.ensemble import RandomForestRegressor as cuRF
from cuml.ensemble import RandomForestClassifier as cuRFC
import cupy as cp
import cudf
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error

test_splits = [
    ((2012,7), (2013,7)), 
    ((2013,7), (2014,7)), 
    ((2014,7), (2015,7)), 
    ((2016,7), (2017,7)), 
]

train_splits = [ 
    ((2011,7), (2013,7)),
    ((2012,7), (2014,1)),
    ((2013,7), (2015,10)),
    ((2015,7), (2017,2)),
]

### Scikit-Learn implementation of Random Forest

In this part, we will implement the **Random Forest** model using the `scikit-learn` library. 

In [None]:
models = {
    'RF': RandomForestRegressor(),
}

def get_agg_lagged_features(factors):
    return [f"{f}_{t}" for f in factors for t in range(3, 9)] + \
           [f"{f}_province_{t}" for f in factors for t in range(3, 9)] + \
           [f"{f}_country_{t}" for f in factors for t in range(3, 9)]

features = {
    'traditional': time_series.select(
        ['year', 'month'] + 
        [f"fews_ipc_{t}" for t in range(3, 21, 3)] + 
        get_agg_lagged_features(t_variant_traditional_factors) + 
        t_invariant_traditional_factors
    ),
    
    'news': time_series.select(
        ['year', 'month'] + 
        [f"fews_ipc_{t}" for t in range(3, 21, 3)] + 
        get_agg_lagged_features(news_factors)
    ),
    
    'traditional+news': time_series.select(
        ['year', 'month'] + 
        [f"fews_ipc_{t}" for t in range(3, 21, 3)] + 
        get_agg_lagged_features(t_variant_traditional_factors) + 
        t_invariant_traditional_factors + 
        get_agg_lagged_features(news_factors)
    ),   
}

labels_df = time_series.select(['fews_ipc', 'year', 'month'])

def get_time_split(df, start, end):
    return df.filter(
        ((pl.col('year') > start[0]) | ((pl.col('year') == start[0]) & (pl.col('month') >= start[1]))) &
        ((pl.col('year') < end[0]) | ((pl.col('year') == end[0]) & (pl.col('month') <= end[1])))
    )

thresholds = {
    'traditional': (2.236, 3.125), 
    'news': (1.907, 2.712), 
    'traditional+news': (2.105, 3.314),
}

def train_and_evaluate(train, test, f, D):
    results = []
    
    # Get train data using Polars and convert to numpy
    train_df = get_time_split(D, train[0], train[1])
    X_train = train_df.drop(['year', 'month']).to_numpy()
    
    train_labels = get_time_split(labels_df, train[0], train[1])
    y_train = train_labels.drop(['year', 'month']).to_numpy().ravel()

    # Handle NaN values
    nan_mask = ~np.isnan(X_train).any(axis=1)
    X_train = X_train[nan_mask]
    y_train = y_train[nan_mask]
    
    # Get test data
    test_df = get_time_split(D, test[0], test[1])
    X_test = test_df.drop(['year', 'month']).to_numpy()
    
    test_labels = get_time_split(labels_df, test[0], test[1])
    y_test = test_labels.drop(['year', 'month']).to_numpy().ravel()
    
    # Handle NaN values in test data
    nan_mask_test = ~np.isnan(X_test).any(axis=1)
    X_test = X_test[nan_mask_test]
    y_test = y_test[nan_mask_test]
    
    print(f"Rows in X_train: {X_train.shape[0]} \nRows in X_test: {X_test.shape[0]}")

    if X_train.shape[0] <= 0:
        return results
    if X_test.shape[0] <= 0:
        return results

    # Get threshold values
    lower, upper = thresholds[f]
    y_test_binary = np.where((y_test >= lower) & (y_test <= upper), 1, 0)

    for name, model in models.items():
        model.fit(X_train, y_train)
        preds = model.predict(X_test)

        # Calculate RMSE
        rmse = np.sqrt(mean_squared_error(y_test, preds))
        
        stderr = np.std(y_test - preds) / (np.sqrt(len(y_test)) + 0.0001)
        upper_bound = np.sqrt(rmse**2 + 1.96 * stderr)
        lower_bound = np.sqrt(rmse**2 - 1.96 * stderr)

        results.append({
            'method': name, 
            'split': test, 
            'features': f, 
            'rmse': rmse,  
            'lower_bound': lower_bound, 
            'upper_bound': upper_bound,
        })

        print(f"Method: {name}, Split: {test}, Features: {f}, RMSE: {rmse:.4f} [{lower_bound:.4f}, {upper_bound:.4f}]")
        
    return results

# Run model evaluation
all_results = []
for train, test in zip(train_splits, test_splits):
    for f, D in features.items():
        try:
            result = train_and_evaluate(train, test, f, D)
            all_results.append(result)
        except Exception as e:
            print(f"Error: {e} on {train} & {test}")
            continue

# Convert results to DataFrame
results_rfr_sklearn = pl.DataFrame([res for sublist in all_results for res in sublist])
# fig_3a.write_csv('fig_3a.csv')


In [50]:
agg_sklearn  = results_rfr_sklearn.group_by(['method', 'features']).agg(
    pl.mean('rmse').alias('mean_rmse'),
)

agg_sklearn

method,features,mean_rmse
str,str,f64
"""RF""","""news""",0.017914
"""RF""","""traditional+news""",0.019351
"""RF""","""traditional""",0.027491


### CuML implementation of Random Forest

In this part, we will implement the **Random Forest** model using the `cuml` library. This is a GPU-accelerated library that provides a similar API to `scikit-learn`, allowing us to leverage the power of GPUs for faster computations.

In [46]:
models = {
    'RF': cuRF(),
}

def get_agg_lagged_features(factors):
    return [f"{f}_{t}" for f in factors for t in range(3, 9)] + \
           [f"{f}_province_{t}" for f in factors for t in range(3, 9)] + \
           [f"{f}_country_{t}" for f in factors for t in range(3, 9)]

features = {
    'traditional': time_series.select(
        ['year', 'month'] + 
        [f"fews_ipc_{t}" for t in range(3, 21, 3)] + 
        get_agg_lagged_features(t_variant_traditional_factors) + 
        t_invariant_traditional_factors
    ),
    
    'news': time_series.select(
        ['year', 'month'] + 
        [f"fews_ipc_{t}" for t in range(3, 21, 3)] + 
        get_agg_lagged_features(news_factors)
    ),
    
    'traditional+news': time_series.select(
        ['year', 'month'] + 
        [f"fews_ipc_{t}" for t in range(3, 21, 3)] + 
        get_agg_lagged_features(t_variant_traditional_factors) + 
        t_invariant_traditional_factors + 
        get_agg_lagged_features(news_factors)
    ),   
}

labels_df = time_series.select(['fews_ipc', 'year', 'month'])

def get_time_split(df, start, end):
    return df.filter(
        ((pl.col('year') > start[0]) | ((pl.col('year') == start[0]) & (pl.col('month') >= start[1]))) &
        ((pl.col('year') < end[0]) | ((pl.col('year') == end[0]) & (pl.col('month') <= end[1])))
    )

thresholds = {
    'traditional': (2.236, 3.125), 
    'news': (1.907, 2.712), 
    'traditional+news': (2.105, 3.314),
}

def train_and_evaluate(train, test, f, D):
    results = []
    
    # Get train data using Polars and convert to numpy for cuML
    train_df = get_time_split(D, train[0], train[1])
    X_train = train_df.drop(['year', 'month']).to_numpy()
    
    train_labels = get_time_split(labels_df, train[0], train[1])
    y_train = train_labels.drop(['year', 'month']).to_numpy().ravel()

    # Handle NaN values
    nan_mask = np.isnan(X_train).any(axis=1)
    X_train = X_train[~nan_mask]
    y_train = y_train[~nan_mask]
    
    # Get test data using Polars and convert to numpy for cuML
    test_df = get_time_split(D, test[0], test[1])
    X_test = test_df.drop(['year', 'month']).to_numpy()
    
    test_labels = get_time_split(labels_df, test[0], test[1])
    y_test = test_labels.drop(['year', 'month']).to_numpy().ravel()
    
    # Handle NaN values in test data
    nan_mask_test = np.isnan(X_test).any(axis=1)
    X_test = X_test[~nan_mask_test]
    y_test = y_test[~nan_mask_test]
    
    print(f"Rows in X_train: {X_train.shape[0]} \nRows in X_test: {X_test.shape[0]}")
    # Convert to cupy arrays for GPU processing
    X_train = cp.asarray(X_train, dtype=cp.float32)
    y_train = cp.asarray(y_train, dtype=cp.float32)
    X_test = cp.asarray(X_test, dtype=cp.float32)
    y_test = cp.asarray(y_test, dtype=cp.float32)
    
    # Check if data is available
    if X_train.shape[0] <= 0:
        return results
    if X_test.shape[0] <= 0:
        return results
       
    # Get threshold values
    lower, upper = thresholds[f]
    y_test_binary = cp.where((y_test >= lower) & (y_test <= upper), 1, 0)

    for name, model in models.items():
        model.fit(X_train, y_train)
        preds = model.predict(X_test)

        # Calculate RMSE
        rmse = np.sqrt(np.mean((y_test - preds) ** 2)).get()
        
        stderr = np.std(y_test - preds) / (np.sqrt(len(y_test))+0.0001)
        upper_bound = np.sqrt(rmse**2 + 1.96 * stderr)
        lower_bound = np.sqrt(rmse**2 - 1.96 * stderr)
        # precision, recall, _ = precision_recall_curve(y_test_binary, preds)
        # aucpr = auc(recall, precision)
        
        results.append({
            'method': name, 
            'split': test, 
            'features': f, 
            'rmse': rmse,  
            'lower_bound': lower_bound, 
            'upper_bound': upper_bound,
            # 'aucpr': aucpr
        })

        # print(f"Method: {name}, Split: {test}, Features: {f}, RMSE: {rmse:.4f}")
        print(f"Method: {name}, Split: {test}, Features: {f}, RMSE: {rmse:.4f} [{lower_bound:.4f}, {upper_bound:.4f}]")
        
    return results

# Run model evaluation
all_results = []
for train, test in zip(train_splits, test_splits):
    for f, D in features.items():
        try:
            result = train_and_evaluate(train, test, f, D)
            all_results.append(result)
        except Exception as e:
            print(f"Error: {e} on {train} & {test}")
            continue

# Convert results to DataFrame
results_rfr_cuml = pl.DataFrame([res for sublist in all_results for res in sublist])
# To save results
# results_rfr_cuml.write_csv('fig_3a.csv')

Rows in X_train: 6110 
Rows in X_test: 5127
Method: RF, Split: ((2012, 7), (2013, 7)), Features: traditional, RMSE: 0.0004 [nan, 0.0033]
Rows in X_train: 6110 
Rows in X_test: 5127
Method: RF, Split: ((2012, 7), (2013, 7)), Features: news, RMSE: 0.0020 [nan, 0.0077]
Rows in X_train: 6110 
Rows in X_test: 5127
Method: RF, Split: ((2012, 7), (2013, 7)), Features: traditional+news, RMSE: 0.0020 [nan, 0.0077]
Rows in X_train: 7131 
Rows in X_test: 4976
Method: RF, Split: ((2013, 7), (2014, 7)), Features: traditional, RMSE: 0.0899 [0.0748, 0.1028]
Rows in X_train: 7131 
Rows in X_test: 4976
Method: RF, Split: ((2013, 7), (2014, 7)), Features: news, RMSE: 0.0635 [0.0476, 0.0761]
Rows in X_train: 7131 
Rows in X_test: 4976
Method: RF, Split: ((2013, 7), (2014, 7)), Features: traditional+news, RMSE: 0.0636 [0.0478, 0.0762]
Rows in X_train: 10361 
Rows in X_test: 5485
Method: RF, Split: ((2014, 7), (2015, 7)), Features: traditional, RMSE: 0.0061 [nan, 0.0141]
Rows in X_train: 10361 
Rows in X_t

## Data Splits and Train/Test Set Sizes

### 1. Split: ((2012, 7), (2013, 7))
- Rows in Train set: 6110
- Rows in Test set: 5127

### 2. Split: ((2013, 7), (2014, 7))
- Rows in Train set: 7131
- Rows in Test set: 4976

### 3. Split: ((2014, 7), (2015, 7))
- Rows in Train set: 10361
- Rows in Test set: 5485

### 4. Split: ((2016, 7), (2017, 7))
- Rows in Train set: 6472
- Rows in Test set: 3335

### Average sizes of Train and Test sets
- Rows in Train set: 7516
- Rows in Test set: 4730

In [51]:
agg_cuml  = results_rfr_cuml.group_by(['method', 'features']).agg(
    pl.mean('rmse').alias('mean_rmse')
)

agg_cuml

method,features,mean_rmse
str,str,f64
"""RF""","""news""",0.021064
"""RF""","""traditional+news""",0.021106
"""RF""","""traditional""",0.025874


### LogisticAT model

The `mord.LogisticAT` model is an ordinal logistic regression that respects the natural ordering between IPC phases, learning thresholds between classes instead of treating them as unrelated. It performs exceptionally well on news features because news data often contains sharp, timely signals that align closely with IPC phase shifts. As a result, LogisticAT can accurately separate classes when the input features reflect structured, incremental changes, leading to much higher accuracy compared to traditional features.

In [48]:
import mord
from sklearn.metrics import accuracy_score, log_loss

test_splits = [
    ((2012,7), (2013,7)),  
    ((2013,7), (2014,7)), 
    ((2014,7), (2015,7)),  
    ((2016,7), (2017,7)),  
    ((2017,7), (2018,7)),  
    ((2017,7), (2018,6)),  
]

train_splits = [
    ((2011,7), (2012,6)),  
    ((2012,7), (2013,6)), 
    ((2013,7), (2014,6)),  
    ((2015,7), (2016,6)),  
    ((2016,7), (2017,6)), 
    ((2016,7), (2017,6)), 
]

models = {
    'Ordinal_LogisticAT': mord.LogisticAT(alpha=1.0),  # 'alpha' controls regularization
}

def get_agg_lagged_features(factors):
    return [f"{f}_{t}" for f in factors for t in range(3, 9)] + \
           [f"{f}_province_{t}" for f in factors for t in range(3, 9)] + \
           [f"{f}_country_{t}" for f in factors for t in range(3, 9)]

features = {
    'traditional': time_series.select(
        ['year', 'month'] + 
        [f"fews_ipc_{t}" for t in range(3, 21, 3)] + 
        get_agg_lagged_features(t_variant_traditional_factors) + 
        t_invariant_traditional_factors
    ),
    
    'news': time_series.select(
        ['year', 'month'] + 
        [f"fews_ipc_{t}" for t in range(3, 21, 3)] + 
        get_agg_lagged_features(news_factors)
    ),
    
    'traditional+news': time_series.select(
        ['year', 'month'] + 
        [f"fews_ipc_{t}" for t in range(3, 21, 3)] + 
        get_agg_lagged_features(t_variant_traditional_factors) + 
        t_invariant_traditional_factors + 
        get_agg_lagged_features(news_factors)
    ),   
}

labels_df = time_series.select(['fews_ipc', 'year', 'month'])

def get_time_split(df, start, end):
    return df.filter(
        ((pl.col('year') > start[0]) | ((pl.col('year') == start[0]) & (pl.col('month') >= start[1]))) &
        ((pl.col('year') < end[0]) | ((pl.col('year') == end[0]) & (pl.col('month') <= end[1])))
    )

def train_and_evaluate(train, test, f, D):
    results = []
    
    train_df = get_time_split(D, train[0], train[1])
    X_train = train_df.drop(['year', 'month']).to_numpy()
    
    train_labels = get_time_split(labels_df, train[0], train[1])
    y_train = train_labels.drop(['year', 'month']).to_numpy().ravel().astype(int)

    test_df = get_time_split(D, test[0], test[1])
    X_test = test_df.drop(['year', 'month']).to_numpy()
    
    test_labels = get_time_split(labels_df, test[0], test[1])
    y_test = test_labels.drop(['year', 'month']).to_numpy().ravel().astype(int)
    
    train_mask = ~np.isnan(X_train).any(axis=1)
    test_mask = ~np.isnan(X_test).any(axis=1)
    
    X_train = X_train[train_mask]
    y_train = y_train[train_mask]
    X_test = X_test[test_mask]
    y_test = y_test[test_mask]

    unique_classes = np.unique(y_train)

    if X_train.shape[0] <= 0 or X_test.shape[0] <= 0:
        return results

    for name, model in models.items():
        model.fit(X_train, y_train)
        
        preds = model.predict(X_test)

        # Cross-Entropy loss using log_loss from sklearn
        cross_entropy_loss = log_loss(y_test, model.predict_proba(X_test), labels=unique_classes)

        accuracy = accuracy_score(y_test, preds)

        results.append({
            'method': name,
            'split': test,
            'features': f,
            'cross_entropy_loss': cross_entropy_loss,
            'accuracy': accuracy,
        })
        
        print(f"Method: {name}, Split: {test}, Features: {f}, Cross-Entropy Loss: {cross_entropy_loss:.4f}, Accuracy: {accuracy:.4f}")

    return results

all_results = []
for train, test in zip(train_splits, test_splits):
    for f, D in features.items():
        try:
            result = train_and_evaluate(train, test, f, D)
            all_results.append(result)
        except Exception as e:
            print(f"Error: {e} on {train} & {test}")
            continue

results_oc = pl.DataFrame([res for sublist in all_results for res in sublist])
# results_oc.write_csv('fig_3a.csv')


Method: Ordinal_LogisticAT, Split: ((2012, 7), (2013, 7)), Features: traditional, Cross-Entropy Loss: 0.8928, Accuracy: 0.6860
Method: Ordinal_LogisticAT, Split: ((2012, 7), (2013, 7)), Features: news, Cross-Entropy Loss: 0.0733, Accuracy: 0.9799
Method: Ordinal_LogisticAT, Split: ((2012, 7), (2013, 7)), Features: traditional+news, Cross-Entropy Loss: 0.8806, Accuracy: 0.6813
Method: Ordinal_LogisticAT, Split: ((2013, 7), (2014, 7)), Features: traditional, Cross-Entropy Loss: 0.7473, Accuracy: 0.7170
Method: Ordinal_LogisticAT, Split: ((2013, 7), (2014, 7)), Features: news, Cross-Entropy Loss: 0.0340, Accuracy: 0.9895
Method: Ordinal_LogisticAT, Split: ((2013, 7), (2014, 7)), Features: traditional+news, Cross-Entropy Loss: 0.7139, Accuracy: 0.7297
Method: Ordinal_LogisticAT, Split: ((2014, 7), (2015, 7)), Features: traditional, Cross-Entropy Loss: 0.8837, Accuracy: 0.7227
Method: Ordinal_LogisticAT, Split: ((2014, 7), (2015, 7)), Features: news, Cross-Entropy Loss: 0.0170, Accuracy: 0.

In [52]:
agg_oc = results_oc.group_by(['method', 'features']).agg(
    pl.mean('cross_entropy_loss').alias('mean_cel'),
    pl.mean('accuracy').alias('mean_acc')   
)

agg_oc

method,features,mean_cel,mean_acc
str,str,f64,f64
"""Ordinal_LogisticAT""","""traditional+news""",0.966862,0.6577
"""Ordinal_LogisticAT""","""news""",0.021925,0.994161
"""Ordinal_LogisticAT""","""traditional""",1.072231,0.656879
