# ViEWS prediction competition benchmark

This is the benchmark model notebook for the ViEWS prediction competition.
If you are a member of the public, please run the following in your terminal to fetch the data needed to run this notebook.

    conda activate views2
    cd /path/to/ViEWS2
    python runners/import_data.py --fetch --dataset
    
This will download the latest available data from the ViEWS website and construct the datasets used in this notebook.

ViEWS team members should be able to run the notebook without any additional steps after configuring database access.

In [None]:
import sys
import os
import logging

import pandas as pd
import numpy as np
from datetime import datetime

from sklearn.ensemble import RandomForestRegressor
from joblib import dump, load

import views
from views import Period, Model, Downsampling
from views.utils.data import assign_into_df
from views.apps.transforms import lib as translib
from views.apps.evaluation import lib as evallib, feature_importance as fi
from views.apps.model import api
from views.apps.extras import extras

In [None]:
# Route DEBUG-and-above log records through a stream handler, formatted with
# the log format taken from the views configuration.
stream_handler = logging.StreamHandler()
logging.basicConfig(
    format=views.config.LOGFMT,
    level=logging.DEBUG,
    handlers=[stream_handler],
)
log = logging.getLogger(__name__)

# Setup

## Fetch data?

In [None]:
# Set this flag to True and run the cell to fetch the latest public data.
# Unless you have already imported the data yourself, the cells below will
# fail until this has been run.
fetch_public_data = False
if fetch_public_data:
    public = views.apps.data.public
    path_zip = public.fetch_latest_zip_from_website(
        path_dir_destination=views.DIR_SCRATCH
    )
    public.import_tables_and_geoms(
        tables=views.TABLES,
        geometries=views.GEOMETRIES,
        path_zip=path_zip,
    )

In [None]:
# Enter the level of analysis, and whether to predict Africa-only.
# `level` selects the dataset, the feature list and the downsampling used
# in the cells below, which handle the values "cm" and "pgm".
level = "cm"
# When True and level is "cm", the dataframe is subset to rows with
# in_africa == 1 before predictions are made.
pred_africa = True

In [None]:
# Output directories for the LaTeX tables written further down, keyed by
# the kind of output they hold.
model_path = "./models/{sub}"
out_paths = {
    "evaluation": model_path.format(sub="evaluation"),
    "features": model_path.format(sub="features"),
}
# exist_ok=True keeps the cell safe to re-run and avoids the
# check-then-create race of a separate os.path.isdir test.
for path in out_paths.values():
    os.makedirs(path, exist_ok=True)

In [None]:
# Pick the dataset matching the chosen level of analysis.
if level == "pgm":
    dataset = views.DATASETS["pgm_africa_imp_0"]
elif level == "cm":
    dataset = views.DATASETS["cm_global_imp_0"]
else:
    # Fail here with a clear message rather than with a NameError on
    # `dataset` in the next statement.
    raise ValueError(f"Unknown level {level!r}; expected 'cm' or 'pgm'.")
df = dataset.df

In [None]:
# Two partitions, expressed in month ids:
#   calib: train 1990-01..2013-12 (121-408), predict 2014-01..2016-12 (409-444)
#   test:  train 1990-01..2016-12 (121-444), predict 2017-01..2019-12 (445-480)
# Keeping the periods in a list lets us easily expand this as updated data
# becomes available.
periods = [
    api.Period(
        name="calib",
        train_start=121,
        train_end=408,
        predict_start=409,
        predict_end=444,
    ),
    api.Period(
        name="test",
        train_start=121,
        train_end=444,
        predict_start=445,
        predict_end=480,
    ),
]
period_calib, period_test = periods

In [None]:
# The steps (1 through 6) to train, predict and evaluate for.
steps = list(range(1, 7))

## Feature selection and model setup

In [None]:
# Feature columns used when level is "cm".
# Judging by the name prefixes, the columns appear to be grouped by source
# (ged_, wdi_, vdem_, fvp_, reign_, icgcw_), followed by the in_africa flag
# -- TODO confirm the sources against the data documentation.
# NOTE(review): in_africa is also the column used to subset the dataframe
# when pred_africa is set; confirm it is meant to be a model feature too.
cols_cm = [
    'ged_count_sb',
    'ged_count_ns',
    'ged_count_os',
    'ged_best_sb',
    'ged_best_ns',
    'ged_best_os',
    'ged_dummy_sb',
    'ged_dummy_ns',
    'ged_dummy_os',
    'wdi_ag_lnd_agri_zs',
    'wdi_ag_lnd_arbl_zs',
    'wdi_ag_lnd_frst_k2',
    'wdi_ag_lnd_prcp_mm',
    'wdi_ag_lnd_totl_k2',
    'wdi_ag_lnd_totl_ru_k2',
    'wdi_ag_prd_crop_xd',
    'wdi_ag_prd_food_xd',
    'wdi_ag_prd_lvsk_xd',
    'wdi_ag_srf_totl_k2',
    'wdi_ag_yld_crel_kg',
    'wdi_bg_gsr_nfsv_gd_zs',
    'wdi_bm_klt_dinv_wd_gd_zs',
    'wdi_bn_cab_xoka_gd_zs',
    'wdi_bx_gsr_ccis_zs',
    'wdi_bx_gsr_cmcp_zs',
    'wdi_bx_gsr_insf_zs',
    'wdi_bx_gsr_mrch_cd',
    'wdi_bx_gsr_tran_zs',
    'wdi_bx_gsr_trvl_zs',
    'wdi_bx_klt_dinv_cd_wd',
    'wdi_bx_klt_dinv_wd_gd_zs',
    'wdi_bx_trf_pwkr_dt_gd_zs',
    'wdi_dt_dod_dect_gn_zs',
    'wdi_dt_dod_pvlx_gn_zs',
    'wdi_dt_oda_oatl_kd',
    'wdi_dt_oda_odat_gn_zs',
    'wdi_dt_oda_odat_pc_zs',
    'wdi_dt_tds_dect_gn_zs',
    'wdi_eg_elc_accs_zs',
    'wdi_eg_use_elec_kh_pc',
    'wdi_eg_use_pcap_kg_oe',
    'wdi_en_pop_slum_ur_zs',
    'wdi_en_urb_mcty_tl_zs',
    'wdi_ep_pmp_desl_cd',
    'wdi_ep_pmp_sgas_cd',
    'wdi_fp_cpi_totl',
    'wdi_fr_inr_dpst',
    'wdi_fr_inr_lndp',
    'wdi_gc_dod_totl_gd_zs',
    'wdi_ic_bus_ease_xq',
    'wdi_iq_cpa_econ_xq',
    'wdi_iq_cpa_fisp_xq',
    'wdi_iq_cpa_gndr_xq',
    'wdi_iq_cpa_macr_xq',
    'wdi_iq_cpa_prop_xq',
    'wdi_iq_cpa_pubs_xq',
    'wdi_iq_cpa_soci_xq',
    'wdi_iq_cpa_trad_xq',
    'wdi_iq_cpa_tran_xq',
    'wdi_ms_mil_mprt_kd',
    'wdi_ms_mil_xpnd_gd_zs',
    'wdi_ms_mil_xpnd_zs',
    'wdi_ne_con_prvt_pc_kd_zg',
    'wdi_ne_dab_totl_kd',
    'wdi_ne_dab_totl_zs',
    'wdi_ne_exp_gnfs_zs',
    'wdi_ne_gdi_totl_zs',
    'wdi_ne_imp_gnfs_kd',
    'wdi_ne_imp_gnfs_kd_zg',
    'wdi_ne_imp_gnfs_zs',
    'wdi_ne_rsb_gnfs_zs',
    'wdi_ne_trd_gnfs_zs',
    'wdi_nv_agr_empl_kd',
    'wdi_nv_agr_totl_cd',
    'wdi_nv_agr_totl_cn',
    'wdi_nv_agr_totl_kd',
    'wdi_nv_agr_totl_kd_zg',
    'wdi_nv_agr_totl_kn',
    'wdi_nv_agr_totl_zs',
    'wdi_nv_ind_empl_kd',
    'wdi_nv_ind_manf_cd',
    'wdi_nv_ind_manf_cn',
    'wdi_nv_ind_manf_kd',
    'wdi_nv_ind_manf_kd_zg',
    'wdi_nv_ind_manf_kn',
    'wdi_nv_ind_manf_zs',
    'wdi_nv_ind_totl_cd',
    'wdi_nv_ind_totl_cn',
    'wdi_nv_ind_totl_kd',
    'wdi_nv_ind_totl_kd_zg',
    'wdi_nv_ind_totl_kn',
    'wdi_nv_ind_totl_zs',
    'wdi_nv_mnf_chem_zs_un',
    'wdi_nv_mnf_fbto_zs_un',
    'wdi_nv_mnf_mtrn_zs_un',
    'wdi_nv_mnf_othr_zs_un',
    'wdi_nv_mnf_tech_zs_un',
    'wdi_nv_mnf_txtl_zs_un',
    'wdi_nv_srv_empl_kd',
    'wdi_nv_srv_totl_cd',
    'wdi_nv_srv_totl_cn',
    'wdi_nv_srv_totl_kd',
    'wdi_nv_srv_totl_kd_zg',
    'wdi_nv_srv_totl_kn',
    'wdi_nv_srv_totl_zs',
    'wdi_ny_adj_dfor_cd',
    'wdi_ny_adj_dmin_gn_zs',
    'wdi_ny_adj_dres_gn_zs',
    'wdi_ny_adj_ictr_gn_zs',
    'wdi_ny_adj_nnty_kd',
    'wdi_ny_adj_nnty_kd_zg',
    'wdi_ny_gdp_coal_rt_zs',
    'wdi_ny_gdp_defl_kd_zg',
    'wdi_ny_gdp_defl_kd_zg_ad',
    'wdi_ny_gdp_defl_zs',
    'wdi_ny_gdp_defl_zs_ad',
    'wdi_ny_gdp_disc_cn',
    'wdi_ny_gdp_disc_kn',
    'wdi_ny_gdp_fcst_cd',
    'wdi_ny_gdp_fcst_cn',
    'wdi_ny_gdp_fcst_kd',
    'wdi_ny_gdp_fcst_kn',
    'wdi_ny_gdp_frst_rt_zs',
    'wdi_ny_gdp_minr_rt_zs',
    'wdi_ny_gdp_mktp_cd',
    'wdi_ny_gdp_mktp_cn',
    'wdi_ny_gdp_mktp_cn_ad',
    'wdi_ny_gdp_mktp_kd',
    'wdi_ny_gdp_mktp_kd_zg',
    'wdi_ny_gdp_mktp_kn',
    'wdi_ny_gdp_mktp_pp_cd',
    'wdi_ny_gdp_mktp_pp_kd',
    'wdi_ny_gdp_ngas_rt_zs',
    'wdi_ny_gdp_pcap_cd',
    'wdi_ny_gdp_pcap_cn',
    'wdi_ny_gdp_pcap_kd',
    'wdi_ny_gdp_pcap_kd_zg',
    'wdi_ny_gdp_pcap_kn',
    'wdi_ny_gdp_pcap_pp_cd',
    'wdi_ny_gdp_pcap_pp_kd',
    'wdi_ny_gdp_petr_rt_zs',
    'wdi_ny_gdp_totl_rt_zs',
    'wdi_ny_gnp_mktp_kd',
    'wdi_ny_gnp_mktp_pp_kd',
    'wdi_per_si_allsi_cov_pop_tot',
    'wdi_per_si_allsi_cov_q1_tot',
    'wdi_per_si_allsi_cov_q2_tot',
    'wdi_per_si_allsi_cov_q3_tot',
    'wdi_per_si_allsi_cov_q4_tot',
    'wdi_per_si_allsi_cov_q5_tot',
    'wdi_se_adt_1524_lt_fe_zs',
    'wdi_se_adt_1524_lt_ma_zs',
    'wdi_se_adt_1524_lt_zs',
    'wdi_se_adt_litr_fe_zs',
    'wdi_se_adt_litr_ma_zs',
    'wdi_se_adt_litr_zs',
    'wdi_se_enr_prim_fm_zs',
    'wdi_se_enr_prsc_fm_zs',
    'wdi_se_prm_cmpt_zs',
    'wdi_se_prm_cuat_fe_zs',
    'wdi_se_prm_cuat_ma_zs',
    'wdi_se_prm_cuat_zs',
    'wdi_se_prm_enrr',
    'wdi_se_prm_nenr',
    'wdi_se_prm_tenr_fe',
    'wdi_se_prm_tenr_ma',
    'wdi_se_sec_cmpt_lo_zs',
    'wdi_se_sec_cuat_lo_fe_zs',
    'wdi_se_sec_cuat_lo_ma_zs',
    'wdi_se_sec_cuat_lo_zs',
    'wdi_se_sec_nenr',
    'wdi_se_ter_cuat_do_fe_zs',
    'wdi_se_ter_cuat_do_ma_zs',
    'wdi_se_ter_cuat_do_zs',
    'wdi_sg_gen_parl_zs',
    'wdi_sg_vaw_reas_zs',
    'wdi_sh_dyn_0514',
    'wdi_sh_dyn_mort',
    'wdi_sh_dyn_mort_fe',
    'wdi_sh_dyn_mort_ma',
    'wdi_sh_h2o_basw_ru_zs',
    'wdi_sh_h2o_basw_ur_zs',
    'wdi_sh_h2o_basw_zs',
    'wdi_sh_mmr_risk_zs',
    'wdi_sh_sta_bass_ru_zs',
    'wdi_sh_sta_bass_ur_zs',
    'wdi_sh_sta_bass_zs',
    'wdi_sh_sta_maln_fe_zs',
    'wdi_sh_sta_maln_ma_zs',
    'wdi_sh_sta_maln_zs',
    'wdi_sh_sta_mmrt',
    'wdi_sh_sta_mmrt_ne',
    'wdi_sh_sta_stnt_fe_zs',
    'wdi_sh_sta_stnt_ma_zs',
    'wdi_sh_sta_stnt_zs',
    'wdi_sh_sta_traf_p5',
    'wdi_sh_sta_wash_p5',
    'wdi_sh_svr_wast_fe_zs',
    'wdi_sh_svr_wast_ma_zs',
    'wdi_sh_svr_wast_zs',
    'wdi_si_dst_02nd_20',
    'wdi_si_dst_03rd_20',
    'wdi_si_dst_04th_20',
    'wdi_si_dst_05th_20',
    'wdi_si_dst_10th_10',
    'wdi_si_dst_frst_10',
    'wdi_si_dst_frst_20',
    'wdi_si_pov_dday',
    'wdi_si_pov_gaps',
    'wdi_si_pov_gini',
    'wdi_si_pov_lmic',
    'wdi_si_pov_umic',
    'wdi_sl_agr_empl_ma_zs',
    'wdi_sl_agr_empl_zs',
    'wdi_sl_ind_empl_zs',
    'wdi_sl_srv_empl_zs',
    'wdi_sl_tlf_totl_fe_zs',
    'wdi_sl_uem_advn_fe_zs',
    'wdi_sl_uem_advn_ma_zs',
    'wdi_sl_uem_advn_zs',
    'wdi_sl_uem_neet_fe_zs',
    'wdi_sl_uem_neet_ma_zs',
    'wdi_sl_uem_neet_zs',
    'wdi_sl_uem_totl_fe_zs',
    'wdi_sl_uem_totl_ma_zs',
    'wdi_sl_uem_totl_zs',
    'wdi_sm_pop_netm',
    'wdi_sm_pop_refg',
    'wdi_sm_pop_refg_or',
    'wdi_sm_pop_totl_zs',
    'wdi_sn_itk_defc_zs',
    'wdi_sp_dyn_amrt_fe',
    'wdi_sp_dyn_amrt_ma',
    'wdi_sp_dyn_imrt_fe_in',
    'wdi_sp_dyn_imrt_in',
    'wdi_sp_dyn_imrt_ma_in',
    'wdi_sp_dyn_le00_fe_in',
    'wdi_sp_dyn_le00_in',
    'wdi_sp_dyn_le00_ma_in',
    'wdi_sp_dyn_tfrt_in',
    'wdi_sp_dyn_wfrt',
    'wdi_sp_hou_fema_zs',
    'wdi_sp_pop_0014_fe_zs',
    'wdi_sp_pop_0014_ma_zs',
    'wdi_sp_pop_0014_to_zs',
    'wdi_sp_pop_1564_fe_zs',
    'wdi_sp_pop_1564_ma_zs',
    'wdi_sp_pop_1564_to_zs',
    'wdi_sp_pop_65up_fe_zs',
    'wdi_sp_pop_65up_ma_zs',
    'wdi_sp_pop_65up_to_zs',
    'wdi_sp_pop_dpnd',
    'wdi_sp_pop_dpnd_ol',
    'wdi_sp_pop_dpnd_yg',
    'wdi_sp_pop_grow',
    'wdi_sp_pop_totl',
    'wdi_sp_rur_totl_zg',
    'wdi_sp_rur_totl_zs',
    'wdi_sp_urb_grow',
    'wdi_sp_urb_totl_in_zs',
    'wdi_st_int_arvl',
    'wdi_st_int_rcpt_xp_zs',
    'wdi_tx_val_agri_zs_un',
    'wdi_tx_val_food_zs_un',
    'wdi_tx_val_fuel_zs_un',
    'wdi_tx_val_mmtl_zs_un',
    'wdi_tx_val_tech_mf_zs',
    'wdi_vc_btl_deth',
    'wdi_vc_idp_nwcv',
    'wdi_vc_idp_nwds',
    'wdi_vc_idp_tocv',
    'wdi_vc_pkp_totl_un',
    'vdem_e_regionpol',
    'vdem_e_regionpol_6c',
    'vdem_v2x_accountability',
    'vdem_v2x_api',
    'vdem_v2x_civlib',
    'vdem_v2x_clphy',
    'vdem_v2x_clpol',
    'vdem_v2x_clpriv',
    'vdem_v2x_corr',
    'vdem_v2x_cspart',
    'vdem_v2x_delibdem',
    'vdem_v2x_diagacc',
    'vdem_v2x_divparctrl',
    'vdem_v2x_edcomp_thick',
    'vdem_v2x_egal',
    'vdem_v2x_egaldem',
    'vdem_v2x_elecoff',
    'vdem_v2x_elecreg',
    'vdem_v2x_ex_confidence',
    'vdem_v2x_ex_direlect',
    'vdem_v2x_ex_hereditary',
    'vdem_v2x_ex_military',
    'vdem_v2x_ex_party',
    'vdem_v2x_execorr',
    'vdem_v2x_feduni',
    'vdem_v2x_frassoc_thick',
    'vdem_v2x_freexp',
    'vdem_v2x_freexp_altinf',
    'vdem_v2x_gencl',
    'vdem_v2x_gencs',
    'vdem_v2x_gender',
    'vdem_v2x_genpp',
    'vdem_v2x_horacc',
    'vdem_v2x_hosabort',
    'vdem_v2x_hosinter',
    'vdem_v2x_jucon',
    'vdem_v2x_legabort',
    'vdem_v2x_libdem',
    'vdem_v2x_liberal',
    'vdem_v2x_mpi',
    'vdem_v2x_neopat',
    'vdem_v2x_partip',
    'vdem_v2x_partipdem',
    'vdem_v2x_polyarchy',
    'vdem_v2x_pubcorr',
    'vdem_v2x_regime',
    'vdem_v2x_regime_amb',
    'vdem_v2x_rule',
    'vdem_v2x_suffr',
    'vdem_v2x_veracc',
    'vdem_v2xcl_acjst',
    'vdem_v2xcl_disc',
    'vdem_v2xcl_dmove',
    'vdem_v2xcl_prpty',
    'vdem_v2xcl_rol',
    'vdem_v2xcl_slave',
    'vdem_v2xcs_ccsi',
    'vdem_v2xdd_cic',
    'vdem_v2xdd_dd',
    'vdem_v2xdd_i_or',
    'vdem_v2xdd_i_pi',
    'vdem_v2xdd_i_pl',
    'vdem_v2xdd_i_rf',
    'vdem_v2xdd_toc',
    'vdem_v2xdl_delib',
    'vdem_v2xeg_eqaccess',
    'vdem_v2xeg_eqdr',
    'vdem_v2xeg_eqprotec',
    'vdem_v2xel_elecparl',
    'vdem_v2xel_elecpres',
    'vdem_v2xel_frefair',
    'vdem_v2xel_locelec',
    'vdem_v2xel_regelec',
    'vdem_v2xex_elecleg',
    'vdem_v2xex_elecreg',
    'vdem_v2xlg_elecreg',
    'vdem_v2xlg_legcon',
    'vdem_v2xlg_leginter',
    'vdem_v2xme_altinf',
    'vdem_v2xnp_client',
    'vdem_v2xnp_pres',
    'vdem_v2xnp_regcorr',
    'vdem_v2xpe_exlecon',
    'vdem_v2xpe_exlgender',
    'vdem_v2xpe_exlgeo',
    'vdem_v2xpe_exlpol',
    'vdem_v2xpe_exlsocgr',
    'vdem_v2xps_party',
    'fvp_gdp200',
    'fvp_gdpcap_nonoilrent',
    'fvp_gdpcap_oilrent',
    'fvp_gdppc200',
    'fvp_govt',
    'fvp_grgdpcap_nonoilrent',
    'fvp_grgdpcap_oilrent',
    'fvp_grgdppercapita200',
    'fvp_grpop200',
    'fvp_indepyear',
    'fvp_lngdp200',
    'fvp_lngdpcap_nonoilrent',
    'fvp_lngdpcap_oilrent',
    'fvp_lngdppercapita200',
    'fvp_lnoilrent',
    'fvp_lnpop200',
    'fvp_ltimeindep',
    'fvp_population200',
    'fvp_prop_diexpo',
    'fvp_prop_discexclpowless',
    'fvp_prop_discriminated',
    'fvp_prop_dominant',
    'fvp_prop_excluded',
    'fvp_prop_irrelevant',
    'fvp_prop_junpart',
    'fvp_prop_powerless',
    'fvp_prop_selfexclusion',
    'fvp_prop_senpart',
    'fvp_ssp2_edu_sec_15_24_prop',
    'fvp_ssp2_urban_share_iiasa',
    'fvp_timeindep',
    'fvp_timesincepreindepwar',
    'fvp_timesinceregimechange',
    'reign_age',
    'reign_anticipation',
    'reign_change_recent',
    'reign_couprisk',
    'reign_defeat_recent',
    'reign_delayed',
    'reign_direct_recent',
    'reign_elected',
    'reign_election_now',
    'reign_election_recent',
    'reign_exec_ant',
    'reign_exec_recent',
    'reign_gov_dominant_party',
    'reign_gov_foreign_occupied',
    'reign_gov_indirect_military',
    'reign_gov_military',
    'reign_gov_military_personal',
    'reign_gov_monarchy',
    'reign_gov_oligarchy',
    'reign_gov_parliamentary_democracy',
    'reign_gov_party_military',
    'reign_gov_party_personal',
    'reign_gov_party_personal_military_hybrid',
    'reign_gov_personal_dictatorship',
    'reign_gov_presidential_democracy',
    'reign_gov_provisional_civilian',
    'reign_gov_provisional_military',
    'reign_gov_warlordism',
    'reign_indirect_recent',
    'reign_irreg_lead_ant',
    'reign_irregular',
    'reign_lastelection',
    'reign_lead_recent',
    'reign_leg_ant',
    'reign_leg_recent',
    'reign_loss',
    'reign_male',
    'reign_militarycareer',
    'reign_nochange_recent',
    'reign_pctile_risk',
    'reign_precip',
    'reign_prev_conflict',
    'reign_pt_attempt',
    'reign_pt_suc',
    'reign_ref_ant',
    'reign_ref_recent',
    'reign_tenure_months',
    'reign_victory_recent',
    'icgcw_alerts',
    'icgcw_deteriorated',
    'icgcw_improved',
    'icgcw_opportunities',
    'icgcw_unobserved',
    'in_africa'
]

In [None]:
# Feature columns used when level is "pgm".
# Judging by the names, these are transformed conflict-history columns
# (time_since_*, tlag_*, splag_*), the ged_best_* columns, pgd_* columns and
# a few fvp_* columns -- TODO confirm against the data documentation.
cols_pgm = [
    "time_since_ged_dummy_sb",
    "time_since_ged_dummy_ns",
    "time_since_ged_dummy_os",
    "tlag_1_ged_best_ns",
    "tlag_1_ged_best_os",
    "tlag_1_ged_best_sb",
    "time_since_splag_1_1_ged_dummy_sb",
    "time_since_splag_1_1_ged_dummy_ns",
    "time_since_splag_1_1_ged_dummy_os",
    "time_since_greq_25_ged_best_sb",
    "time_since_greq_25_ged_best_ns",
    "time_since_greq_25_ged_best_os",
    "time_since_greq_25_splag_1_1_ged_best_sb",
    "time_since_greq_25_splag_1_1_ged_best_ns",
    "time_since_greq_25_splag_1_1_ged_best_os",
    "time_since_greq_500_ged_best_sb",
    "time_since_greq_500_ged_best_ns",
    "time_since_greq_500_ged_best_os",
    "time_since_greq_500_splag_1_1_ged_best_sb",
    "time_since_greq_500_splag_1_1_ged_best_ns",
    "time_since_greq_500_splag_1_1_ged_best_os",
    "ged_best_sb",
    "ged_best_ns",
    "ged_best_os",
    "pgd_bdist3",
    "pgd_capdist",
    "pgd_agri_ih",
    "pgd_pop_gpw_sum",
    "pgd_ttime_mean",
    "spdist_pgd_diamsec",
    "pgd_pasture_ih",
    "pgd_savanna_ih",
    "pgd_forest_ih",
    "pgd_urban_ih",
    "pgd_barren_ih",
    "pgd_gcp_mer",
    "fvp_ssp2_edu_sec_15_24_prop",
    "ln_fvp_timeindep",
    "ln_fvp_timesincepreindepwar",
    "spdist_pgd_petroleum",
    # "pgd_droughtstart_spi", many years old, dont use
]

In [None]:
# Select the feature list matching the chosen level of analysis.
if level == "cm":
    cols_features = cols_cm
elif level == "pgm":
    cols_features = cols_pgm
else:
    # Fail here with a clear message rather than with a NameError on
    # `cols_features` when the model is defined.
    raise ValueError(f"Unknown level {level!r}; expected 'cm' or 'pgm'.")

In [None]:
# Specify the downsampling for the chosen level of analysis.
if level == "cm":
    # No downsampling: both shares are 1.
    downsampling = api.Downsampling(share_positive=1, share_negative=1, threshold=0)
elif level == "pgm":
    # All positives are kept (share 1); negatives get a share of 0.01.
    downsampling = api.Downsampling(share_positive=1, share_negative=0.01, threshold=0)
else:
    # Fail here with a clear message rather than with a NameError on
    # `downsampling` when the model is defined.
    raise ValueError(f"Unknown level {level!r}; expected 'cm' or 'pgm'.")

In [None]:
# Specify number of estimators in RF estimator
# (passed as n_estimators to the RandomForestRegressor defined below).
n_estimators = 200

In [None]:
# Define the benchmark models.
# The random forest is left on its default split criterion (squared error).
# Passing criterion="mse" explicitly fails on scikit-learn >= 1.2, where that
# spelling was removed, while "squared_error" is not accepted before 1.0;
# relying on the default works on both.
benchmark_delta = api.Model(
    name="benchmark_delta",
    col_outcome="ln_ged_best_sb",
    cols_features=cols_features,
    steps=steps,
    outcome_type="real",
    periods=periods,
    estimator=RandomForestRegressor(
        n_estimators=n_estimators,
        n_jobs=-1,  # use all available cores
    ),
    delta_outcome=True,
    downsampling=downsampling,
)

# All cells below loop over this list, so further models can be added here.
models = [benchmark_delta]

## Model fit, prediction, and evaluation

In [None]:
%%time
# Train all models
# fit_estimators is called on the dataframe as loaded, i.e. before the
# optional Africa-only subsetting in the following cell.
for model in models:
    model.fit_estimators(df)

In [None]:
# Keep only the rows flagged in_africa == 1 when Africa-only prediction is
# requested at the cm level.
if level == "cm" and pred_africa:
    africa_mask = df["in_africa"] == 1
    df = df[africa_mask]

In [None]:
# For every model, add its predictions to the dataframe first, then its
# calibrated predictions, which are computed from the updated dataframe.
for model in models:
    df = assign_into_df(df, model.predict(df))
    calibrated = model.predict_calibrated(
        df=df,
        period_calib=period_calib,
        period_test=period_test,
    )
    df = assign_into_df(df, calibrated)

In [None]:
# Save model objects to file
# NOTE(review): the destination is decided inside model.save() and is not
# visible in this notebook.
for model in models:
    model.save()

In [None]:
# Evaluate all models. Scores are stored in the model object
# and are read back from model.scores when the score tables are written below.
for model in models:
    model.evaluate(df)

## Evaluation output

### Performance metrics by step

In [None]:
# Partition whose scores are tabulated.
partition = "test"

for model in models:
    # Per-step score entries for the partition; the "sc" entry is skipped.
    per_step = [
        (step, step_scores)
        for step, step_scores in model.scores[partition].items()
        if step != "sc"
    ]
    for calib in ("uncalibrated", "calibrated"):
        # One table column per metric, one row per step.
        columns = {
            "Step": [step for step, _ in per_step],
            "MSE": [step_scores[calib]["mse"] for _, step_scores in per_step],
            "R2": [step_scores[calib]["r2"] for _, step_scores in per_step],
        }
        if model.delta_outcome:
            # TADDA is only tabulated for delta-outcome models.
            columns["TADDA"] = [
                step_scores[calib]["tadda_score"] for _, step_scores in per_step
            ]

        table_tex = pd.DataFrame(columns).to_latex(index=False)

        # LaTeX comment header placed in front of the table.
        now = datetime.now().strftime("%Y/%m/%d %H:%M:%S")
        meta = f"""
        %Output created by wb_models.ipynb.
        %Evaluation of {model.col_outcome} per step.
        %Run on selected {model.name} features at {level} level.
        %Produced on {now}, written to {out_paths["evaluation"]}.
        \\
        """
        file_name = f"{model.name}_{level}_{calib}_scores.tex"
        path_out = os.path.join(out_paths["evaluation"], file_name)
        with open(path_out, "w") as f:
            f.write(meta + table_tex)
        print(f"Wrote scores table to {path_out}.")

### Feature importances

In [None]:
# Step whose "s=<step>" column the feature importance table is sorted by.
sort_step = 3
# Controls how many of the highest-ranked features are kept in the table.
top = 30

In [None]:
def featimp_by_steps(model, steps, sort_step, top, cols):
    """
    Return pd.DataFrame of top feature importances by selected steps.

    One column "s=<step>" is built per step from
    model.extras.feature_importances["test"][step]. The per-step tables are
    joined on the feature index of the first step, sorted by the sort_step
    column in descending order and cut to the first `top` rows.

    Args:
        model: Model to get importances for.
        steps: Non-empty list of steps to include in table, in column order.
        sort_step: Step to sort table by; must be one of steps.
        top: Number of features (rows) to include.
        cols: Feature list. Unused; kept for backward compatibility.

    Returns:
        pd.DataFrame indexed by feature with one "s=<step>" column per step.

    Raises:
        ValueError: If steps is empty.
    """
    if not steps:
        raise ValueError("steps must contain at least one step.")

    df = None
    for step in steps:
        fi_dict = model.extras.feature_importances["test"][step]
        step_df = pd.DataFrame(fi.reorder_fi_dict(fi_dict))
        step_df = step_df.rename(columns={"importance": f"s={step}"})
        step_df = step_df.set_index("feature")
        # Start from the first table regardless of the step values, so the
        # result does not depend on steps being sorted in ascending order.
        df = step_df if df is None else df.join(step_df)

    df = df.sort_values(by=[f"s={sort_step}"], ascending=False)
    # Exactly `top` rows; a [0:top + 1] slice would return one row too many.
    return df.iloc[:top]

In [None]:
# Write one impurity-importance table per model to the features directory.
for model in models:
    fi_table = featimp_by_steps(
        model=model,
        steps=steps,
        sort_step=sort_step,
        top=top,
        cols=model.cols_features,
    )
    path_fi_tex = os.path.join(
        out_paths["features"],
        f"impurity_imp_{model.name}_{level}.tex",
    )
    fi.write_fi_tex(pd.DataFrame(fi_table), path_fi_tex)

### Permutation scores

In [None]:
# Step whose "s=<step>" column the table is sorted by, and the number of
# top-ranked features (rows) to keep.
sort_step = 3
top = 30

# Write one permutation-importance table per model to the features directory.
for model in models:
    pi_df = None
    for step in steps:
        pi_dict = model.extras.permutation_importances["test"][step]["test"]
        step_df = pd.DataFrame(fi.reorder_fi_dict(pi_dict))
        step_df = step_df.rename(columns={"importance": f"s={step}"})
        step_df = step_df.set_index("feature")
        # Start from the first table regardless of the step values, so the
        # result does not depend on steps being sorted in ascending order.
        pi_df = step_df if pi_df is None else pi_df.join(step_df)

    pi_df = pi_df.sort_values(by=[f"s={sort_step}"], ascending=False)
    # Exactly `top` rows; a [0:top + 1] slice would return one row too many.
    pi_df = pi_df.iloc[:top]

    # NOTE(review): unlike the impurity importance table, this file name does
    # not include the level, so runs at different levels overwrite each
    # other -- confirm whether that is intended.
    fi.write_fi_tex(
        pi_df,
        os.path.join(out_paths["features"], f"permutation_imp_{model.name}.tex")
    )