In [1]:
import pandas as pd
import pickle


# Load the two wheat CSV files (main table first, then the sample file).
dt, sample = (
    pd.read_csv('../_data/wheat/wheat.csv'),
    pd.read_csv('../_data/wheat/sample.csv'),
)

from tqdm import tqdm

In [2]:
# Per-country coordinate table; only the `Country` column is addressed by
# name, the coordinate columns are addressed by position below.
geo = pd.read_csv('../_data/wheat/countries.csv')

# Keep one row per country -- the first occurrence, which is the row the
# earlier "filter the frame by country, then .iat[0, ...]" lookup selected --
# so both lookup tables are built in a single pass instead of rescanning the
# whole frame once per country (quadratic in the number of rows).
_geo_first = geo.drop_duplicates(subset='Country', keep='first')

# Column position 2 feeds `longitudes` and position 1 feeds `latitudes`,
# exactly as in the original positional lookups.
longitudes = dict(zip(_geo_first['Country'], _geo_first.iloc[:, 2]))
latitudes = dict(zip(_geo_first['Country'], _geo_first.iloc[:, 1]))

# Month suffixes, in calendar order, shared by the monthly weather columns.
_MONTHS = ['jan', 'feb', 'mar', 'apr', 'may', 'jun',
           'jul', 'aug', 'sep', 'oct', 'nov', 'dec']


def _monthly_columns(stem):
    """Return the 12 `<stem>_<month>` names followed by the 12 `<stem>_<month>_faavg` names."""
    plain = [stem + '_' + month for month in _MONTHS]
    averaged = [stem + '_' + month + '_faavg' for month in _MONTHS]
    return plain + averaged


# Column names grouped by category. The key order and the order inside every
# list are significant: later cells iterate over both.
all_cols = {
    'precipitation': _monthly_columns('all_Precip'),

    'temperature': _monthly_columns('all_Temp'),

    'us_production': ['US_wheat_production'] + [
        'USprod_' + crop
        for crop in (
            'Cabbages', 'Carrots_turnips', 'Cotton_lint', 'Grapefruit',
            'Grapes', 'Lettuce', 'Onions_dry', 'Oranges',
            'Peaches_nectarines', 'Watermelons',
        )
    ],

    'recipient_info': [
        # 'risocode','recipient_country','wb_region','obs',
        'year',
        'total_population',
        'roads_per_capita',
        'rgdpch',
        'ln_rgdpch',
        'recipient_pc_wheat_prod_avg',
        'recipient_pc_cereals_prod_avg',
        'recipient_wheat_prod',
        'recipient_cereals_prod',
    ],

    'aid_info': [
        'wheat_aid',
        'fadum',
        'fadum_avg',
        'real_usmilaid',
        'real_us_nonfoodaid_ecaid',
        'world_wheat_aid',
        'world_cereals_aid',
        # 'non_US_wheat_aid','non_US_cereals_aid','non_us_oda_net','non_us_oda_net2',
        'real_usmilaid_avg',
        'real_usecaid_avg',
        'real_us_nonfoodaid_ecaid_avg',
    ],

    'us_info': [
        'USA_rgdpch',
        's2unUSA',
        'wheat_price_xrat_US_curr',
        'fao_cereal_production',
        'US_president_democ',
    ],

    'conflict_info': [
        'any_war',
        'intra_state',
        'inter_state',
        'intra_state_onset',
        'intra_state_offset',
        'peace_dur',
        'intra_state_dur',
        'intensity',
    ],

    'misc': [
        'cereal_pc_import_quantity_avg',
        'oil_price_2011_USD',
        'resource_share_GDP',
        'polity2_from_P4',
        'alesina_ethnic',
        'polrqnew',
    ],

    # Names of the derived columns that the next step adds to `df`.
    'abstracted': [
        'precipitation',
        'temperature',
        'us_production',
        'longitude',
        'latitude',
    ],
}

# Collapse each wide column group of `dt` into one column by adding its member
# columns together (plain `+`, so a missing value in any member propagates).
group_totals = {
    group: sum([dt[name] for name in all_cols[group]])
    for group in ('precipitation', 'temperature', 'us_production')
}

# Start the working frame from the three totals plus the recipient country's
# coordinates; a country absent from the lookup dicts raises KeyError.
df = pd.DataFrame().assign(
    **group_totals,
    longitude=dt.recipient_country.apply(lambda name: longitudes[name]),
    latitude=dt.recipient_country.apply(lambda name: latitudes[name]),
)

# Copy the remaining columns over unchanged, category by category.
# The 'misc' category is not part of this list and is therefore not copied.
copied_groups = ('recipient_info', 'aid_info', 'us_info', 'conflict_info')
for name in (col for group in copied_groups for col in all_cols[group]):
    df[name] = dt[name]


In [3]:
from sklearn.ensemble import RandomForestRegressor
from tqdm import tqdm

        
# Fill in missing values column by column: for every column that has at least
# one NaN, fit a random forest on the rows where the column is present and
# predict it for the rows where it is missing.
nan_map = df.isna()                      # snapshot of missingness before any fill
nan_stat = nan_map.sum().sort_values()   # NaN count per column, fewest first
nan_cols = nan_stat[nan_stat > 0].index
# Predictor columns shared by every model.
# NOTE(review): assumes none of these contain NaN themselves -- confirm.
train_cols = ['year', 'longitude', 'latitude', 'total_population', 'USA_rgdpch', 'fadum_avg']
for nan_col in tqdm(nan_cols):
    missing = nan_map[nan_col]
    train_data = df[~missing]
    predict_data = df[missing]
    X = train_data[train_cols]
    Y = train_data[nan_col]
    # Fixed seed so the imputed values are reproducible across kernel restarts.
    regressor = RandomForestRegressor(n_estimators=100, random_state=0)
    regressor.fit(X, Y)
    predicted = regressor.predict(predict_data[train_cols])
    # Write the predictions back with a single .loc assignment. The previous
    # `df.iloc[index][nan_col] = ...` assigned into a temporary row object
    # (chained indexing), which does not reliably modify `df`, and it also
    # treated index labels as positions. `predicted` is in the same row order
    # as `predict_data`, i.e. the rows selected by `missing`.
    df.loc[missing, nan_col] = predicted

100%|██████████| 29/29 [01:21<00:00,  2.82s/it]


In [4]:
from causality.inference.search import IC
from causality.inference.independence_tests import RobustRegressionTest
import warnings
# Suppress every warning from here on (applies process-wide, not just to this cell).
warnings.filterwarnings('ignore')

# Tag every column of `df` with variable type 'c' for the search
# (presumably "continuous" in the causality package's type codes -- confirm).
vt = dict.fromkeys(df.columns, 'c')
ic_algorithm = IC(RobustRegressionTest)
graph = ic_algorithm.search(df, variable_types=vt)

  9%|▉         | 2428/25974 [2:00:03<7:44:23,  1.18s/it]   

In [10]:
# Print every node of the discovered graph with its attribute dict, one per line.
for name, attributes in graph.nodes(data=True):
    print((name, attributes))

('precipitation', {'type': 'c'})
('temperature', {'type': 'c'})
('us_production', {'type': 'c'})
('longitude', {'type': 'c'})
('latitude', {'type': 'c'})
('year', {'type': 'c'})
('total_population', {'type': 'c'})
('roads_per_capita', {'type': 'c'})
('rgdpch', {'type': 'c'})
('ln_rgdpch', {'type': 'c'})
('recipient_pc_wheat_prod_avg', {'type': 'c'})
('recipient_pc_cereals_prod_avg', {'type': 'c'})
('recipient_wheat_prod', {'type': 'c'})
('recipient_cereals_prod', {'type': 'c'})
('wheat_aid', {'type': 'c'})
('fadum', {'type': 'c'})
('fadum_avg', {'type': 'c'})
('real_usmilaid', {'type': 'c'})
('real_us_nonfoodaid_ecaid', {'type': 'c'})
('world_wheat_aid', {'type': 'c'})
('world_cereals_aid', {'type': 'c'})
('real_usmilaid_avg', {'type': 'c'})
('real_usecaid_avg', {'type': 'c'})
('real_us_nonfoodaid_ecaid_avg', {'type': 'c'})
('USA_rgdpch', {'type': 'c'})
('s2unUSA', {'type': 'c'})
('wheat_price_xrat_US_curr', {'type': 'c'})
('fao_cereal_production', {'type': 'c'})
('US_president_democ',

In [12]:
from graphviz import Digraph

def get_color(col):
    """Return the Graphviz color name used to draw the column `col`.

    The five derived columns (precipitation, temperature, us_production,
    longitude, latitude) are mapped to a category directly. Any other column
    is looked up in the module-level `all_cols` mapping and takes the color of
    the first category that lists it. Columns found in one of the raw
    'precipitation' / 'temperature' / 'us_production' groups get the color of
    the corresponding derived column instead of raising KeyError.

    Parameters
    ----------
    col : str
        Column / graph node name.

    Returns
    -------
    str or None
        A color name, or None when `col` belongs to no known category.
    """
    abstracted_dict = {
        'precipitation': 'recipient_info',
        'temperature': 'recipient_info',
        'us_production': 'us_info',
        'longitude': 'recipient_info',
        'latitude': 'recipient_info'
    }
    color_dict = {
        'recipient_info': 'black',
        'us_info': 'gray',
        'misc': 'chocolate',  # was misspelled 'chocolote', which is not a valid color name
        'aid_info': 'green',
        'conflict_info': 'red',
    }
    if col in abstracted_dict:
        return color_dict[abstracted_dict[col]]
    for category, members in all_cols.items():
        if col in members:
            # Groups without their own color entry share a name with a derived
            # column; resolve them through the same category mapping.
            return color_dict.get(abstracted_dict.get(category, category))
    # Unknown column: no color (same result as the previous implicit fall-through).
    return None

gra = Digraph()

# One Graphviz node per variable, drawn in its category color.
for name in graph.nodes():
    gra.node(name, color=get_color(name))

# Graphviz `dir` value keyed by
# (arrowhead at the first endpoint?, arrowhead at the second endpoint?).
arrow_dirs = {
    (True, True): 'both',
    (True, False): 'back',
    (False, True): 'forward',
    (False, False): 'none',
}

for first, second, attrs in graph.edges(data=True):
    # `arrows` holds the endpoints that carry an arrowhead on this edge.
    heads = attrs['arrows']
    direction = arrow_dirs[(first in heads, second in heads)]

    # Edges flagged as 'marked' are drawn in red, all others in black.
    color = 'red' if attrs['marked'] else 'black'

    gra.edge(first, second, dir=direction, color=color)

# Render the graph and open it; the returned file path is the cell's output.
gra.view()

'Digraph.gv.pdf'