In [10]:
import numpy as np
import pandas as pd
import joblib
import warnings
warnings.filterwarnings('ignore')

In [11]:
country = 'ESP'
model = joblib.load(f'models/{country}_model.pkl')
scaler = joblib.load(f'scalers/{country}_scaler.pkl')


new_data = {
    'Population' : 47415794, 
    'Inflation' : 3.0931351197642, 
    'Import' : 134.0948592709,
    'Export' : 135.6812883345
}

df = pd.DataFrame([new_data])

In [12]:
df

Unnamed: 0,Population,Inflation,Import,Export
0,47415794,3.093135,134.094859,135.681288


In [13]:
columns = ['Population', 'Inflation', 'Import', 'Export']
df_scaled = df.copy()
df_scaled[columns] = scaler.transform(df[columns])

In [14]:
df

Unnamed: 0,Population,Inflation,Import,Export
0,47415794,3.093135,134.094859,135.681288


In [15]:
df_scaled

Unnamed: 0,Population,Inflation,Import,Export
0,1.11666,0.622526,1.376644,1.70438


In [16]:
predicted_gdp = model.predict(df_scaled)

In [17]:
predicted_gdp

array([1.51130609e+12])

In [18]:
df_scaled['GDP'] = predicted_gdp
df_scaled

Unnamed: 0,Population,Inflation,Import,Export,GDP
0,1.11666,0.622526,1.376644,1.70438,1511306000000.0


# Counterfactual explanation

In [9]:
import dice_ml
from dice_ml.utils import helpers

In [33]:
import dice_ml
from dice_ml.utils import helpers

d = dice_ml.Data(dataframe=df_scaled, continuous_features=['Population', 'Inflation', 'Import', 'Export'], outcome_name='GDP')

m = dice_ml.Model(model=model, backend='sklearn', model_type='regressor') 

exp = dice_ml.Dice(d, m)

query_instance = df_scaled.drop(columns="GDP").iloc[0:1]


desired_range = [1.5e12, 2e12]  # 

features_to_vary = ['Population', 'Inflation', 'Import', 'Export']
permitted_range = {'Population':[0, 50e6], 'Inflation':[0,10]}

dice_exp = exp.generate_counterfactuals(query_instance, total_CFs=2, 
                                        permitted_range=permitted_range,
                                        features_to_vary=features_to_vary,
                                        desired_range=desired_range)

dice_exp.visualize_as_dataframe(show_only_changes=True)

100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  6.63it/s]

Query instance (original outcome : 1511306035200.0)





Unnamed: 0,Population,Inflation,Import,Export,GDP
0,1.11666,0.622526,1.376644,1.70438,1511306000000.0



Diverse Counterfactual set (new outcome: [1500000000000.0, 2000000000000.0])


Unnamed: 0,Population,Inflation,Import,Export,GDP
0,-,0.450534,-,-,-
1,-,1.569716,-,-,-
