In [None]:
import sys
sys.path.append('../')
%reload_ext autoreload 
%autoreload 2

In [None]:
from general_functions import *
from cross_country_functions import *

In [None]:
from ethiopia_functions import prepare_df_lsms_ram_roster_eth, map_lsms_wfp_adm1, actual_predicted_quantiles_weights_eth, actual_predicted_urban_weights_eth

In [None]:
from classification_class import *
from cross_country_functions import *

In [None]:
from visualisations import plot_precision_recall_curve, plot_roc_curve, visualise_actual_predicted_map, scatterplot_actual_predicted_percentage, barplots_sanity_wealth, barplots_sanity_area, create_roc_curve


### Prepare the data:

In [None]:
# Load the feature matrix. Its index is cast to str (presumably so it matches the
# string household ids used elsewhere — confirm).
features = pd.read_csv(filepath_or_buffer='../data/all_features.csv', index_col=0)
features.index = features.index.astype(str)

In [None]:
# Targets file (named for Ethiopia and Nigeria); passed to prepares_crosscountry_inputs in the model loop.
path_to_targets = "../data/ethiopia_nigeria_targets.csv"

In [None]:
# Survey inventory workbook; passed to preprocess_roster_actual_predicted and calculate_actual_predicted.
path_to_inventory = "../data/inventory_small.xlsx"

In [None]:
# Ethiopia household roster (later passed to preprocess_roster_actual_predicted and
# actual_predicted_urban_weights_eth).
hh_roster_eth = pd.read_csv('../data/eth/hh_roster_eth.csv', index_col=0)
# Assign with [...] rather than attribute access: `df.hhid = ...` only updates an
# *existing* column. If `hhid` is not already a column, pandas silently sets a plain
# Python attribute instead (with a UserWarning) and no column is created.
hh_roster_eth['hhid'] = hh_roster_eth['household_id'].astype(str)

Create a list to store the performance metrics for each probability threshold.

In [None]:
# Accumulates one {'threshold', 'recall', 'precision', 'specificity'} record per run of the model loop.
l_dict_thresholds_performances = list()

Set the probability threshold:


In [None]:
# Probability cut-off for the positive class, set manually (expected range [0.15, 0.85]).
# Setting it to None makes the model loop use the model's own predictions (the default 0.5 cut-off).
threshold_probability = 0.5

In [None]:
# Switch to True to save the figures produced in the model loop
# (the sanity-check cells further down pass save=False explicitly).
save_image = False

In [None]:
# Micronutrient target to analyse (a key of all_target_names).
t = "mimi_simple"

In [None]:
# Target column -> human-readable nutrient name, used in plot titles and map labels.
all_target_names = dict(
    zn_ai='zinc',
    va_ai='vitamin A',
    fol_ai='folate',
    vb12_ai='vitamin B12',
    fe_ai='iron',
    mimi_simple='overall micronutrient',
)

In [None]:
# Train on Nigeria, predict on Ethiopia, then evaluate and plot each selected target.
# To run every target, iterate over ['zn_ai', 'va_ai', 'fol_ai', 'vb12_ai', 'fe_ai', 'mimi_simple'].

# Loop-invariant: the geometries used by the Ethiopia maps are fetched once.
# NOTE(review): adm0=79 is presumably the GAUL country code for Ethiopia — confirm.
geo_df = get_geodata(adm0=79)

for t in ['mimi_simple']:

    print(t)

    target_name = all_target_names[t]
    # Shared title for the precision-recall and ROC plots.
    curve_title = 'risk of inadequate %s intake \n (Nigeria to Ethiopia)' % target_name

    y, indexes = prepares_crosscountry_inputs('Nigeria', 'Ethiopia', t, path_to_targets=path_to_targets)

    # ---Model---
    # Best random state of the Nigeria model. `.iloc[0]` takes the first row by position,
    # whatever labels the CSV's index column holds (label-based `[0]` would raise a
    # KeyError, or fall back to deprecated positional access, on a non-0 index).
    best_random_state = pd.read_csv('../data/results/perf_%s_NGA_undersampling_3.2_xgboost.csv' % t,
                                    index_col=0).best_random_state.iloc[0]
    classification = Classification(y=y, data_all=features, cross_country=True, train_indexes=indexes,
                                    type_target=t, random_state=best_random_state, sampling='undersampling',
                                    sampling_strategy=1)
    # Train the model with the best hyperparameters.
    # NOTE(review): this file name says 'XGBoost' while the perf file above says 'xgboost';
    # confirm the exact names on case-sensitive filesystems.
    model = classification.xgbclassification_best_model('../data/results/besthyper_%s_NGA_undersampling_3.2_XGBoost.csv' % t)

    # Predicted probabilities.
    y_proba = classification.y_proba(model)

    # ---Performance---
    # With an explicit threshold, predictions come from the probabilities (via predictions_proba);
    # otherwise the model's own predictions are used.
    if threshold_probability is not None:
        predictions = predictions_proba(y_proba, threshold_probability)
    else:
        predictions = classification.predictions(model)
    perf_dict = classification.perf_ind_classification(predictions)

    array_precision, array_recall, average_pre_recall = calculates_precision_recall_auc(y_proba, classification, drop_intermediate=True)
    perf_dict['average_pre_recall'] = average_pre_recall
    array_fpr, array_tpr, rocauc_score = calculates_roc_auc(y_proba, classification, drop_intermediate=True)
    perf_dict['rocauc_score'] = rocauc_score
    # Plot (and optionally save) the precision-recall and ROC curves.
    plot_precision_recall_curve(classification, array_recall, array_precision, save=save_image, mn=t, iso3='ETH', title=curve_title)
    plot_roc_curve(classification, array_fpr, array_tpr, drop_intermediate=True, save=save_image, mn=t, iso3='ETH', title=curve_title)
    # Adjusted ROC-AUC and average precision-recall, stored alongside the raw scores.
    adjusted_roc_auc, adjusted_average_pre_recall = get_adjusted_values(classification, rocauc_score, average_pre_recall)
    perf_dict['adjusted_rocauc'] = adjusted_roc_auc
    perf_dict['adjusted_average_pre_recall'] = adjusted_average_pre_recall

    # ---Maps---
    predicted_actual_target = classification.df_predicted_actual(predictions, t, save=False)

    target_roster = preprocess_roster_actual_predicted(predicted_actual_target, path_to_inventory, survey_id='ETH_2018_ESS_v03_M', hh_roster=hh_roster_eth)
    target_roster_geo = map_lsms_wfp_adm1(target_roster)

    actual_predicted_perc = calculate_actual_predicted(target_roster_geo, survey_id='ETH_2018_ESS_v03_M', path_to_inventory=path_to_inventory, weights=True, rank=True)
    for column in ['actual', 'predicted']:
        # Only the predicted map is labelled with the threshold.
        title = '' if column == 'actual' else '(T=%s)' % threshold_probability
        visualise_actual_predicted_map(geo_df, actual_predicted_perc, column, comparison=True, iso3='ETH', save=save_image, title=title, threshold=threshold_probability, mn=target_name)
    for column2 in ['actual_rank', 'predicted_rank']:
        visualise_actual_predicted_map(geo_df, actual_predicted_perc, column2, comparison=False, iso3='ETH', save=save_image, mn=target_name)

    # ---Save performances---
    # One-row frame of all metrics. Built from a list of records so perf_dict itself
    # keeps its scalar values instead of being rewritten into one-element lists.
    df_perf_dict = pd.DataFrame([perf_dict])

    # Record recall/precision/specificity for this threshold (consumed by create_roc_curve).
    dict_thresholds_performances = {'threshold': threshold_probability,
                                    'recall': df_perf_dict.at[0, 'recall'],
                                    'precision': df_perf_dict.at[0, 'precision'],
                                    'specificity': df_perf_dict.at[0, 'specificity']}
    l_dict_thresholds_performances.append(dict_thresholds_performances)

In [None]:
# Scatter of actual vs. predicted percentages from actual_predicted_perc
# (i.e. the last target run in the model loop).
# NOTE(review): the title hard-codes "overall", which only matches t = 'mimi_simple'.
scatterplot_actual_predicted_percentage(
    actual_predicted_perc,
    limin=0.42,
    limax=1,
    diff_threshold=0.5,
    title='Percentage of risk of inadequate overall intake (NGA to ETH)',
    iso3='ETH',
    path=None,
    fontsize=12,
    save=False,
)

# Sanity checks

In [None]:
df_quantiles_weights = pd.read_csv('../data/eth/df_quantiles_weights_eth.csv', index_col=0) 

In [None]:
df_quantiles_weights = actual_predicted_quantiles_weights_eth(df_quantiles_weights, predicted_actual_target)

In [None]:
# Sanity check: barplot of the 'actual' column by wealth quintile.
barplots_sanity_wealth(
    df_quantiles_weights,
    column='actual',
    title='Actual percentage',
    title_x='Wealth quintile',
    country='eth',
    save=False,
)


In [None]:
# Sanity check: barplot of the 'predicted' column by wealth quintile at threshold T.
# The title uses the "(T=...)" format of the other predicted-percentage plots.
barplots_sanity_wealth(df_quantiles_weights, column='predicted', T=threshold_probability,
                       title='Predicted percentage (T=%s)' % threshold_probability,
                       title_x='Wealth quintile', country='eth', save=False)


In [None]:
# Area-level (presumably urban/rural) table built from the roster and the quantile table above.
predicted_actual_urban_weights = actual_predicted_urban_weights_eth(
    hh_roster_eth,
    df_quantiles_weights,
)

In [None]:
# Sanity check: barplot of the 'actual' column by area.
barplots_sanity_area(
    predicted_actual_urban_weights,
    column='actual',
    title='Actual percentage',
    title_x='Area',
    country='eth',
    save=False,
)

In [None]:
# Sanity check: barplot of the 'predicted' column by area at threshold T.
barplots_sanity_area(
    predicted_actual_urban_weights,
    column='predicted',
    T=threshold_probability,
    title='Predicted percentage (T=%s)' % threshold_probability,
    title_x='Area',
    country='eth',
    save=False,
)

In [None]:
# Run this after the model loop has filled l_dict_thresholds_performances with the
# recall/precision/specificity recorded for each probability threshold.
create_roc_curve(
    mn=t,
    title="Nigeria to Ethiopia",
    classification=classification,
    array_tpr=array_tpr,
    array_fpr=array_fpr,
    countries='nga_to_eth',
    l_dict_thresholds_performances=l_dict_thresholds_performances,
    save=False,
)