In [1]:
import numpy as np
import pandas as pd
import joblib
import warnings
warnings.filterwarnings('ignore')

In [23]:
model = joblib.load(f'models/{country}_model.pkl')
scaler = joblib.load(f'scalers/{country}_scaler.pkl')

In [24]:
main_data = pd.read_csv("combined_data.csv", index_col=[0, 1])

In [25]:
main_data 

Unnamed: 0_level_0,Unnamed: 1_level_0,Inflation,Population,GDP,Export,Import
Country Code,Year,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
ABW,1980,,59909.0,,,
ABW,1981,,60563.0,,,
ABW,1982,,61276.0,,,
ABW,1983,,62228.0,,,
ABW,1984,,62901.0,,,
...,...,...,...,...,...,...
ZWE,2017,0.893962,14812482.0,5.107466e+10,107.151887,83.660837
ZWE,2018,10.618866,15034452.0,3.415607e+10,124.909506,105.579357
ZWE,2019,255.304991,15271368.0,2.571741e+10,131.425343,79.585840
ZWE,2020,557.201817,15526888.0,2.686794e+10,135.325610,82.633621


In [26]:
country_data = main_data.loc['ESP']
country_data

Unnamed: 0_level_0,Inflation,Population,GDP,Export,Import
Year,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
1980,15.561902,37491165.0,232600600000.0,,
1981,14.549346,37758631.0,202663000000.0,,
1982,14.415002,37986012.0,195856800000.0,,
1983,12.174073,38171525.0,170829100000.0,,
1984,11.280277,38330364.0,171980000000.0,,
1985,8.814455,38469512.0,180664300000.0,,
1986,8.794939,38584624.0,251141600000.0,,
1987,5.248019,38684815.0,318520300000.0,,
1988,4.837271,38766939.0,375891700000.0,,
1989,6.791436,38827764.0,414460800000.0,,


In [27]:
country_data = country_data.dropna()

In [28]:
features_columns = ['Population', 'Inflation', 'Import', 'Export']  
country_data[features_columns] = scaler.transform(country_data[features_columns])  


In [29]:
country_data

Unnamed: 0_level_0,Inflation,Population,GDP,Export,Import
Year,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
1995,1.709478,-1.612872,614170000000.0,-1.425867,-1.616997
1996,0.942774,-1.554035,642251400000.0,-1.324469,-1.543167
1997,-0.149063,-1.494582,589739800000.0,-1.393966,-1.596015
1998,-0.243096,-1.435632,618731500000.0,-1.272099,-1.391123
1999,0.08424,-1.377659,634394900000.0,-1.352621,-1.401493
2000,0.85659,-1.313432,598102900000.0,-1.23806,-1.202159
2001,0.964083,-1.213166,627798700000.0,-1.219471,-1.213042
2002,0.60363,-1.006937,708938200000.0,-1.120363,-1.110297
2003,0.585223,-0.738628,907963200000.0,-0.785975,-0.682846
2004,0.585471,-0.478069,1069829000000.0,-0.495314,-0.194163


# Counterfactual explanation

In [9]:
import dice_ml
from dice_ml.utils import helpers

In [51]:
import dice_ml
from dice_ml.utils import helpers

d = dice_ml.Data(dataframe=country_data, continuous_features=['Population', 'Inflation', 'Import', 'Export'], outcome_name='GDP')

m = dice_ml.Model(model=model, backend='sklearn', model_type='regressor') 

exp = dice_ml.Dice(d, m)

query_instance = country_data.drop(columns="GDP").loc[2005:2005]


desired_range = [1.2e+12, 1.5e+13]  

features_to_vary = ['Population', 'Inflation', 'Import', 'Export']
#permitted_range = {'Population':[0, 10000e6], 'Inflation':[0,20]}

dice_exp = exp.generate_counterfactuals(query_instance, total_CFs=3, 
                                        features_to_vary=features_to_vary,
                                        desired_range=desired_range)

dice_exp.visualize_as_dataframe(show_only_changes=False)

100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 12.48it/s]

Query instance (original outcome : 1172091830272.0)





Unnamed: 0,Inflation,Population,Export,Import,GDP
0,0.812098,-0.21857,-0.385295,0.105118,1172092000000.0



Diverse Counterfactual set (new outcome: [1200000000000.0, 15000000000000.0])


Unnamed: 0,Inflation,Population,Export,Import,GDP
0,0.812098,-1.481795,-0.385295,-0.630995,1213451000000.0
1,0.812098,1.007919,1.133366,0.105118,1470171000000.0
2,0.812098,-0.21857,1.212018,0.105118,1497658000000.0


In [52]:
cf_df = dice_exp.cf_examples_list[0].final_cfs_df

In [53]:
cf_df

Unnamed: 0,Inflation,Population,Export,Import,GDP
0,0.812098,-1.481795,-0.385295,-0.630995,1213451000000.0
1,0.812098,1.007919,1.133366,0.105118,1470171000000.0
2,0.812098,-0.21857,1.212018,0.105118,1497658000000.0


In [55]:
cf_df[features_columns] = scaler.inverse_transform(cf_df[features_columns])

In [56]:
cf_df

Unnamed: 0,Inflation,Population,Export,Import,GDP
0,3.368814,40093420.0,68.247405,68.583589,1213451000000.0
1,3.368814,47109370.0,117.254629,92.603696,1470171000000.0
2,3.368814,43653160.0,119.792738,92.603696,1497658000000.0


In [57]:
cf_df.to_csv('counterfactuals.csv', index=False)