In [1]:
import numpy as np
import pandas as pd
import joblib
import warnings
warnings.filterwarnings('ignore')

In [2]:
# ISO-3 code of the country under study; it selects which artifacts are loaded.
country = 'CHN'

# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code -- only point these paths at artifacts you produced or trust.
model, scaler = (
    joblib.load(artifact_path)
    for artifact_path in (
        f'models/{country}_model.pkl',
        f'scalers/{country}_scaler.pkl',
    )
)

In [3]:
# The first two CSV columns become a two-level row index
# (shown as "Country Code" / "Year" in the preview below).
main_data = pd.read_csv(
    "combined_data.csv",
    index_col=[0, 1],
)

In [4]:
# Preview the combined frame loaded from combined_data.csv.
main_data 

Unnamed: 0_level_0,Unnamed: 1_level_0,Inflation,Population,GDP,Export,Import
Country Code,Year,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
ABW,1980,,59909.0,,,
ABW,1981,,60563.0,,,
ABW,1982,,61276.0,,,
ABW,1983,,62228.0,,,
ABW,1984,,62901.0,,,
...,...,...,...,...,...,...
ZWE,2017,0.893962,14812482.0,5.107466e+10,107.151887,83.660837
ZWE,2018,10.618866,15034452.0,3.415607e+10,124.909506,105.579357
ZWE,2019,255.304991,15271368.0,2.571741e+10,131.425343,79.585840
ZWE,2020,557.201817,15526888.0,2.686794e+10,135.325610,82.633621


In [5]:
# Select all rows of the configured country from the first index level.
# Uses `country` (the same variable that picked the model and scaler files)
# instead of a second hard-coded 'CHN', so data and artifacts cannot drift apart.
country_data = main_data.loc[country]
country_data

Unnamed: 0_level_0,Inflation,Population,GDP,Export,Import
Year,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
1980,,981235000.0,191149200000.0,0.79611,1.18692
1981,,993885000.0,195866400000.0,0.967993,1.310348
1982,,1008630000.0,205089700000.0,0.981804,1.147856
1983,,1023310000.0,230686700000.0,0.977625,1.273147
1984,,1036825000.0,259946500000.0,1.149741,1.631462
1985,,1051040000.0,309488000000.0,1.203008,2.514868
1986,,1066790000.0,300758100000.0,1.361004,2.553675
1987,7.233836,1084035000.0,272973000000.0,1.734663,2.572245
1988,18.811818,1101630000.0,312353600000.0,2.090023,3.289589
1989,18.245638,1118650000.0,347768100000.0,2.310919,3.520053


In [6]:
# Keep only the years in which every column has a value
# (a row is discarded as soon as any of its entries is NaN).
country_data = country_data.dropna(axis=0, how='any')

In [7]:
# Columns handed to the scaler; GDP (the outcome) is intentionally left unscaled.
features_columns = ['Population', 'Inflation', 'Import', 'Export']

# Rebuild the frame from `main_data` on every run so that re-executing this
# cell never standardizes values that were already standardized. The explicit
# copy also guarantees the column assignment below writes to an independent
# frame rather than to a slice of `main_data`.
country_data = main_data.loc[country].dropna().copy()
country_data[features_columns] = scaler.transform(country_data[features_columns])


In [8]:
# Show the frame after the feature columns were passed through scaler.transform
# (GDP is not in features_columns, so it keeps its original values).
country_data

Unnamed: 0_level_0,Inflation,Population,GDP,Export,Import
Year,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
1987,0.270563,-1.830679,272973000000.0,-0.898303,-0.908156
1988,2.022967,-1.657325,312353600000.0,-0.890564,-0.893966
1989,1.937272,-1.489637,347768100000.0,-0.885753,-0.889407
1990,-0.362342,-1.326727,360857900000.0,-0.876601,-0.89623
1991,-0.285998,-1.173079,383373300000.0,-0.867194,-0.883931
1992,0.137391,-1.033273,426915700000.0,-0.854712,-0.864157
1993,1.387006,-0.900561,444731300000.0,-0.848194,-0.836636
1994,2.84713,-0.768587,564321900000.0,-0.820161,-0.822913
1995,1.717137,-0.640309,734484800000.0,-0.793553,-0.803478
1996,0.433926,-0.515232,863749300000.0,-0.791381,-0.795394


# Counterfactual explanation

In [9]:
import dice_ml
from dice_ml.utils import helpers

In [30]:
# Names are paired positionally with scaler.mean_ / scaler.scale_ below, so this
# order has to be the order the scaler was fitted with.
# NOTE(review): assumed to equal the order used for scaler.transform -- confirm.
feature_names = ['Population', 'Inflation', 'Import', 'Export']

# Per-feature centre and spread, looked up by feature name.
mean_dict = {name: centre for name, centre in zip(feature_names, scaler.mean_)}
std_dict = {name: spread for name, spread in zip(feature_names, scaler.scale_)}

# Allowed (lower, upper) bounds per feature, in the original unscaled units.
original_ranges = {
    'Import': (155, 165),
    'Export': (145, 150),
    'Inflation': (-5, 5),
}

# Same bounds expressed in the scaler's standardized space: (x - mean) / scale.
permitted_range_scaled = {
    name: [(bound - mean_dict[name]) / std_dict[name] for bound in bounds]
    for name, bounds in original_ranges.items()
}


In [31]:
# Inspect the standardized [lower, upper] bounds computed for each restricted feature.
permitted_range_scaled

{'Import': [2.1071370343698477, 2.3049549094442856],
 'Export': [2.221968589419226, 2.330866944347459],
 'Inflation': [-1.5811091332849954, -0.06754291551436838]}

In [32]:
import dice_ml
from dice_ml.utils import helpers

# Describe the (already standardized) country frame to DiCE: four continuous
# features and GDP as the outcome column.
d = dice_ml.Data(
    dataframe=country_data,
    continuous_features=['Population', 'Inflation', 'Import', 'Export'],
    outcome_name='GDP',
)

# Wrap the loaded estimator as a scikit-learn regressor.
m = dice_ml.Model(model=model, backend='sklearn', model_type='regressor')

exp = dice_ml.Dice(d, m)

# The label slice 2021:2021 keeps the result a one-row DataFrame (not a Series);
# the outcome column is removed so only features are passed as the query.
query_instance = country_data.drop(columns="GDP").loc[2021:2021]

# Interval the predicted GDP of every counterfactual has to fall into.
desired_range = [1.8e+13, 2e+13]

# Population is left out on purpose, so it stays at the query's value.
features_to_vary = ['Inflation', 'Import', 'Export']

# A fixed random_seed makes the generated set repeatable across runs; without it
# each execution can return different counterfactuals.
# NOTE(review): relies on the default DiCE explainer accepting `random_seed`
# (true for the random-sampling method) -- confirm for the installed version.
dice_exp = exp.generate_counterfactuals(
    query_instance,
    total_CFs=10,
    permitted_range=permitted_range_scaled,
    features_to_vary=features_to_vary,
    desired_range=desired_range,
    random_seed=42,
)

dice_exp.visualize_as_dataframe(show_only_changes=True)

100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  4.97it/s]

Query instance (original outcome : 17878617161728.0)





Unnamed: 0,Inflation,Population,Export,Import,GDP
0,-0.675843,1.404118,2.286448,2.20761,17878620000000.0



Diverse Counterfactual set (new outcome: [18000000000000.0, 20000000000000.0])


Unnamed: 0,Inflation,Population,Export,Import,GDP
0,-,-,-,2.225992308,-
1,-,-,-,2.299747729,-
2,-1.474834478,-,-,-,-
3,-1.259250942,-,2.281478594,-,-
4,-0.827286232,-,2.229018962,-,-
5,-1.520726561,-,-,-,-
6,-,-,-,2.2532115,-
7,-,-,-,2.290765625,-
8,-1.10750032,-,2.259873404,-,-
9,-,-,2.24223565,2.248970666,-


In [26]:
# Work on an independent copy of the first query's counterfactual frame, so
# later column assignments on `cf_df` cannot alter the frame stored inside
# `dice_exp`.
cf_df = dice_exp.cf_examples_list[0].final_cfs_df.copy()

In [27]:
# Inspect the counterfactuals taken from the first entry of dice_exp.cf_examples_list.
cf_df

Unnamed: 0,Inflation,Population,Export,Import,GDP
0,-1.209405,1.404118,2.286448,2.222665,18242370000000.0
1,-0.585148,1.404118,2.286448,2.238369,18080210000000.0
2,-0.675843,1.404118,2.286448,2.288894,18521280000000.0
3,-1.437965,1.404118,2.286448,2.20761,18228180000000.0
4,-1.498892,1.404118,2.286448,2.20761,18256120000000.0
5,-1.410192,1.404118,2.286448,2.20761,18215440000000.0
6,-0.675843,1.404118,2.286448,2.266295,18342600000000.0
7,-1.215775,1.404118,2.27552,2.20761,18148730000000.0
8,-1.508216,1.404118,2.317118,2.20761,18197350000000.0
9,-1.153526,1.404118,2.30898,2.20761,18051390000000.0


In [28]:
# Start from a fresh copy of DiCE's result each time this cell runs: the inverse
# transform is then applied exactly once (re-running the cell is safe) and the
# frame held by `dice_exp` is never modified.
cf_df = dice_exp.cf_examples_list[0].final_cfs_df.copy()
# Map the standardized feature columns back to their original units.
cf_df[features_columns] = scaler.inverse_transform(cf_df[features_columns])

In [29]:
# Show the counterfactuals after the feature columns went through scaler.inverse_transform.
cf_df

Unnamed: 0,Inflation,Population,Export,Import,GDP
0,-2.544181,1412360000.0,147.960521,160.840119,18242370000000.0
1,1.58023,1412360000.0,147.960521,161.634001,18080210000000.0
2,0.981015,1412360000.0,147.960521,164.188088,18521280000000.0
3,-4.05426,1412360000.0,147.960521,160.079062,18228180000000.0
4,-4.456796,1412360000.0,147.960521,160.079062,18256120000000.0
5,-3.870768,1412360000.0,147.960521,160.079062,18215440000000.0
6,0.981015,1412360000.0,147.960521,163.045682,18342600000000.0
7,-2.586271,1412360000.0,147.45878,160.079062,18148730000000.0
8,-4.518404,1412360000.0,149.368705,160.079062,18197350000000.0
9,-2.174993,1412360000.0,148.995066,160.079062,18051390000000.0


In [17]:
# Unscaled 2021 row of the configured country, for comparison with the
# counterfactuals; uses `country` rather than repeating the literal code.
main_data.loc[(country, 2021)]

Inflation     9.810151e-01
Population    1.412360e+09
GDP           1.782046e+13
Export        1.479605e+02
Import        1.600791e+02
Name: (CHN, 2021), dtype: float64

In [18]:
# Persist the counterfactual table; the row index is not written to the file.
cf_df.to_csv(
    path_or_buf='counterfactuals.csv',
    index=False,
)