# Explainability of DNNs
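Among the techniques commonly called AI, deep neural networks (DNNs, i.e. deep learning) in particular have a drawback: why the model arrived at a given answer is a black box. To explain such a model as far as possible, we use the DiCE library (`dice_ml`), which works with counterfactuals. It feeds the model hypothetical inputs that did not actually occur and looks for changes to the features that would flip the model's prediction.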
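One common general definition of a counterfactual is given below. It is not DiCE's exact objective. Take a classifier $f$, an input $x$ and some distance $d$ between inputs. A counterfactual is the input closest to $x$ whose prediction differs from that of $x$:

$$
x^{\mathrm{cf}} = \arg\min_{x'} \; d(x, x') \quad \text{subject to} \quad f(x') \neq f(x)
$$

DiCE (Diverse Counterfactual Explanations) returns several such inputs per query. It favours sets whose members differ from one another, so each prediction comes with alternative explanations rather than a single one.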

## Importing libraries

In [1]:

```python
from sklearn.model_selection import train_test_split as tts
from sklearn.metrics import classification_report
from sklearn.neural_network import MLPClassifier as MLP
from sklearn.ensemble import GradientBoostingClassifier as GBC
import pandas as pd
import numpy as np
import dice_ml
```

## Loading the data

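In [2]:

```python
df = pd.read_csv("rossi.csv")
df.head()
```

|   | week | arrest | fin | age | race | wexp | mar | paro | prio |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 20 | 1 | 0 | 27 | 1 | 0 | 0 | 1 | 3 |
| 1 | 17 | 1 | 0 | 18 | 1 | 0 | 0 | 1 | 8 |
| 2 | 25 | 1 | 0 | 19 | 0 | 1 | 0 | 1 | 13 |
| 3 | 52 | 0 | 1 | 23 | 1 | 1 | 1 | 1 | 1 |
| 4 | 52 | 0 | 0 | 19 | 0 | 1 | 0 | 1 | 3 |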


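## Separating the explanatory variables from the outcome variable
Here re-arrest (`arrest`) is the outcome variable, and all other columns are explanatory variables. To keep the explanations easy to interpret, we apply no preprocessing (such as standardization or normalization) to the explanatory variables and use their raw values.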

In [3]:

```python
x = df.drop(["arrest"], axis=1)
y = df["arrest"]
```

## Splitting into training and test data

In [4]:

```python
x_train, x_test, y_train, y_test = tts(x, y, test_size=0.2, random_state=1)
```

## Building the neural network model
We make the model deliberately complex, with four hidden layers of 500, 500, 400 and 300 units.

In [5]:

```python
model = MLP(hidden_layer_sizes=(500, 500, 400, 300))
```
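The trainable weights and biases can be counted by hand, which shows how large this network is relative to the data. The sketch below assumes 8 input features (`week`, `fin`, `age`, `race`, `wexp`, `mar`, `paro`, `prio`). It also assumes a single output unit, which is what `MLPClassifier` uses for a binary target. The helper name `dense_param_count` is hypothetical.

```python
def dense_param_count(layer_sizes):
    """Count weights + biases in a fully connected network.

    layer_sizes lists the width of every layer, input first and output last.
    Each transition from n_in to n_out units contributes n_in * n_out weights
    plus n_out biases.
    """
    total = 0
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        total += n_in * n_out + n_out
    return total


# Assumption: 8 input features and 1 output unit (binary classification).
n_features = 8
hidden = (500, 500, 400, 300)
n_params = dense_param_count([n_features, *hidden, 1])
print(n_params)  # 576001
```

That comes to roughly 576,000 parameters fitted to about 350 training rows. The classification report below lists 87 test rows, which are 20% of the data. At that scale, the network's individual decisions are hard to read off directly.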

## Fitting

In [6]:

```python
model.fit(x_train, y_train)
```

```
MLPClassifier(hidden_layer_sizes=(500, 500, 400, 300))
```

## Prediction and accuracy check
Always check the model's accuracy. It indicates how well the model explains the outcome from the explanatory variables.

In [7]:

```python
y_pred = model.predict(x_test)
```

In [8]:

```python
print(classification_report(y_test, y_pred))
```

```
              precision    recall  f1-score   support

           0       0.96      0.94      0.95        69
           1       0.79      0.83      0.81        18

    accuracy                           0.92        87
   macro avg       0.87      0.89      0.88        87
weighted avg       0.92      0.92      0.92        87
```
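Every figure in the report follows from the 2x2 confusion matrix. The notebook does not print the counts used below; they are inferred from the report. For example, 18 positives with a recall of 0.83 means 15 true positives. These counts reproduce every value shown. The helper `binary_report` is a hypothetical illustration and is not part of scikit-learn.

```python
def binary_report(tp, fp, fn, tn):
    """Per-class precision/recall/F1 and overall accuracy from a 2x2 confusion matrix."""
    def prf(true_pos, false_pos, false_neg):
        precision = true_pos / (true_pos + false_pos)
        recall = true_pos / (true_pos + false_neg)
        f1 = 2 * precision * recall / (precision + recall)
        return precision, recall, f1

    return {
        # Class 1 (re-arrested): arrests are the positive class.
        1: prf(tp, fp, fn),
        # Class 0 swaps roles: correctly predicted non-arrests are its "true positives".
        0: prf(tn, fn, fp),
        "accuracy": (tp + tn) / (tp + fp + fn + tn),
    }


# Counts inferred from the neural network's report (not printed in the notebook).
report = binary_report(tp=15, fp=4, fn=3, tn=65)
for key, value in report.items():
    if key == "accuracy":
        print(f"accuracy  {value:.2f}")
    else:
        p, r, f = value
        print(f"class {key}  precision {p:.2f}  recall {r:.2f}  f1 {f:.2f}")
```

The same helper, given the counts `tp=18, fp=1, fn=0, tn=68` (again inferred), reproduces the GBM report in the next section. That corresponds to 1 misclassified test row for the GBM, against 7 for this network.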



## Generating counterfactuals

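In [9]:

```python
# Describe the data for DiCE: test features plus the outcome column
d = dice_ml.Data(dataframe=pd.concat([x_test, y_test], axis=1),
                 continuous_features=["week", "age", "prio"],  # quantitative variables; the rest are treated as categorical
                 outcome_name="arrest")  # outcome variable
# Wrap the trained scikit-learn model and create the explainer
m = dice_ml.Model(model=model, backend="sklearn")
exp = dice_ml.Dice(d, m)
```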
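`dice_ml.Dice` uses `method="random"` by default. This is a model-agnostic search that, roughly, resamples randomly chosen features and keeps candidates that the model assigns to the desired class. The stdlib sketch below illustrates only that idea and is much simpler than DiCE's actual implementation. The scoring rule `toy_predict`, the ranges in `RANGES` and the helper `random_counterfactuals` are all hypothetical.

```python
import random


def toy_predict(row):
    """Hypothetical stand-in for a trained classifier (1 = re-arrest predicted)."""
    score = 0.15 * row["prio"] - 0.05 * (row["age"] - 18) - 0.8 * row["fin"]
    return 1 if score > 0.5 else 0


# Hypothetical sampling ranges; DiCE derives ranges from the data passed to dice_ml.Data.
RANGES = {"age": (17, 44), "prio": (0, 18), "fin": (0, 1)}


def random_counterfactuals(query, predict, total_cfs=4, desired_class="opposite",
                           features_to_vary=None, n_samples=2000, seed=0):
    """Randomly perturb features of `query` until `total_cfs` distinct flips are found."""
    rng = random.Random(seed)
    original = predict(query)
    target = 1 - original if desired_class == "opposite" else desired_class
    features = list(features_to_vary or RANGES)
    found = []
    for _ in range(n_samples):
        candidate = dict(query)
        # Resample a random non-empty subset of the features allowed to vary.
        for name in rng.sample(features, rng.randint(1, len(features))):
            low, high = RANGES[name]
            candidate[name] = rng.randint(low, high)
        if predict(candidate) == target and candidate not in found:
            found.append(candidate)
            if len(found) == total_cfs:
                break
    return found


# age, prio and fin borrowed from the first query instance (the toy rule is not the real model).
query = {"age": 20, "prio": 9, "fin": 0}
for cf in random_counterfactuals(query, toy_predict):
    print(cf, "->", toy_predict(cf))

# Let only prio vary, similar in spirit to DiCE's features_to_vary argument.
for cf in random_counterfactuals(query, toy_predict, features_to_vary=["prio"]):
    print(cf, "->", toy_predict(cf))
```

DiCE's `generate_counterfactuals` accepts a comparable `features_to_vary` argument. It can keep attributes that should not change, such as `race`, fixed.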

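In [10]:

```python
# Use the first 10 test rows as query instances
pre_counter = x_test.iloc[0:10, :]
# Ask for 4 counterfactuals per query whose predicted class is the opposite of the original
dice_exp = exp.generate_counterfactuals(pre_counter, total_CFs=4, desired_class="opposite")
# Display each query and its counterfactuals, highlighting changed values
dice_exp.visualize_as_dataframe(show_only_changes=True)
```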


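Query instance (original outcome : 0)

|   | week | fin | age | race | wexp | mar | paro | prio | arrest |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 52 | 0 | 20 | 1 | 0 | 0 | 1 | 9 | 0 |

Diverse Counterfactual set (new outcome: 1.0)

|   | week | fin | age | race | wexp | mar | paro | prio | arrest |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 37 | 0 | - | 0 | 0 | 0 | 1 | - | 1 |
| 1 | 30 | 0 | - | 0 | 0 | 0 | 1 | - | 1 |
| 2 | 8 | 0 | - | 1 | 0 | 0 | 1 | 11 | 1 |
| 3 | 37 | 0 | - | 1 | 0 | 0 | 1 | - | 1 |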


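Query instance (original outcome : 1)

|   | week | fin | age | race | wexp | mar | paro | prio | arrest |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 36 | 1 | 23 | 1 | 0 | 0 | 0 | 3 | 1 |

Diverse Counterfactual set (new outcome: 0.0)

|   | week | fin | age | race | wexp | mar | paro | prio | arrest |
|---|---|---|---|---|---|---|---|---|---|
| 0 | - | 1 | 33 | 1 | 0 | 0 | 0 | 1 | 0 |
| 1 | - | 1 | - | 1 | 1 | 0 | 0 | 0 | 0 |
| 2 | 40 | 0 | - | 1 | 0 | 0 | 0 | - | 0 |
| 3 | - | 1 | 31 | 1 | 1 | 0 | 0 | - | 0 |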


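Query instance (original outcome : 0)

|   | week | fin | age | race | wexp | mar | paro | prio | arrest |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 52 | 1 | 28 | 1 | 0 | 0 | 1 | 4 | 0 |

Diverse Counterfactual set (new outcome: 1.0)

|   | week | fin | age | race | wexp | mar | paro | prio | arrest |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 50 | 1 | - | 1 | 0 | 0 | 1 | 10 | 1 |
| 1 | - | 1 | 33 | 1 | 0 | 0 | 1 | 7 | 1 |
| 2 | - | 0 | - | 1 | 0 | 0 | 1 | 11 | 1 |
| 3 | 49 | 1 | - | 1 | 0 | 0 | 1 | 10 | 1 |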


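Query instance (original outcome : 0)

|   | week | fin | age | race | wexp | mar | paro | prio | arrest |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 52 | 0 | 22 | 1 | 0 | 0 | 1 | 1 | 0 |

Diverse Counterfactual set (new outcome: 1.0)

|   | week | fin | age | race | wexp | mar | paro | prio | arrest |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 27 | 0 | 24 | 1 | 0 | 0 | 1 | - | 1 |
| 1 | 30 | 0 | - | 1 | 0 | 0 | 1 | - | 1 |
| 2 | 21 | 0 | - | 1 | 0 | 0 | 1 | 6 | 1 |
| 3 | 37 | 0 | - | 1 | 0 | 0 | 1 | 8 | 1 |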


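Query instance (original outcome : 0)

|   | week | fin | age | race | wexp | mar | paro | prio | arrest |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 52 | 0 | 21 | 1 | 0 | 0 | 0 | 2 | 0 |

Diverse Counterfactual set (new outcome: 1.0)

|   | week | fin | age | race | wexp | mar | paro | prio | arrest |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 24 | 0 | 38 | 1 | 0 | 0 | 0 | - | 1 |
| 1 | 9 | 0 | - | 1 | 0 | 0 | 0 | 6 | 1 |
| 2 | 27 | 0 | - | 1 | 0 | 0 | 0 | - | 1 |
| 3 | 9 | 0 | - | 1 | 0 | 0 | 0 | 11 | 1 |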


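Query instance (original outcome : 0)

|   | week | fin | age | race | wexp | mar | paro | prio | arrest |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 52 | 1 | 18 | 1 | 0 | 0 | 0 | 4 | 0 |

Diverse Counterfactual set (new outcome: 1.0)

|   | week | fin | age | race | wexp | mar | paro | prio | arrest |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 27 | 1 | - | 1 | 0 | 0 | 0 | - | 1 |
| 1 | - | 1 | 32 | 1 | 0 | 0 | 0 | 9 | 1 |
| 2 | 38 | 1 | - | 1 | 1 | 0 | 0 | - | 1 |
| 3 | 35 | 1 | - | 1 | 0 | 1 | 0 | - | 1 |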


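Query instance (original outcome : 0)

|   | week | fin | age | race | wexp | mar | paro | prio | arrest |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 52 | 0 | 19 | 0 | 1 | 0 | 1 | 3 | 0 |

Diverse Counterfactual set (new outcome: 1.0)

|   | week | fin | age | race | wexp | mar | paro | prio | arrest |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 30 | 0 | - | 0 | 1 | 0 | 1 | - | 1 |
| 1 | 29 | 0 | - | 0 | 1 | 0 | 1 | - | 1 |
| 2 | 39 | 0 | - | 1 | 1 | 0 | 1 | - | 1 |
| 3 | 31 | 0 | - | 0 | 1 | 1 | 1 | - | 1 |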


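Query instance (original outcome : 0)

|   | week | fin | age | race | wexp | mar | paro | prio | arrest |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 52 | 0 | 20 | 1 | 0 | 0 | 0 | 1 | 0 |

Diverse Counterfactual set (new outcome: 1.0)

|   | week | fin | age | race | wexp | mar | paro | prio | arrest |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 31 | 0 | - | 1 | 0 | 0 | 0 | - | 1 |
| 1 | 19 | 0 | 29 | 1 | 0 | 0 | 0 | - | 1 |
| 2 | 11 | 0 | 29 | 1 | 0 | 0 | 0 | - | 1 |
| 3 | 23 | 1 | - | 1 | 0 | 0 | 0 | - | 1 |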


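Query instance (original outcome : 1)

|   | week | fin | age | race | wexp | mar | paro | prio | arrest |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 26 | 0 | 32 | 1 | 1 | 0 | 0 | 2 | 1 |

Diverse Counterfactual set (new outcome: 0.0)

|   | week | fin | age | race | wexp | mar | paro | prio | arrest |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 41 | 1 | - | 1 | 1 | 0 | 0 | - | 0 |
| 1 | 50 | 0 | - | 1 | 1 | 0 | 0 | - | 0 |
| 2 | 39 | 0 | 24 | 1 | 1 | 0 | 0 | - | 0 |
| 3 | 52 | 0 | - | 1 | 1 | 0 | 0 | 0 | 0 |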


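Query instance (original outcome : 1)

|   | week | fin | age | race | wexp | mar | paro | prio | arrest |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 24 | 0 | 40 | 1 | 1 | 0 | 0 | 2 | 1 |

Diverse Counterfactual set (new outcome: 0.0)

|   | week | fin | age | race | wexp | mar | paro | prio | arrest |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 52 | 0 | - | 1 | 1 | 0 | 1 | - | 0 |
| 1 | 51 | 0 | - | 1 | 1 | 0 | 0 | - | 0 |
| 2 | 40 | 0 | 34 | 1 | 1 | 0 | 0 | - | 0 |
| 3 | 44 | 1 | - | 1 | 1 | 0 | 0 | - | 0 |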


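For each query, the output shows the original instance (top) and its counterfactuals (bottom). Each counterfactual is an input for which the model's prediction flips between positive (1) and negative (0). The number of counterfactuals per query is set by `total_CFs`. Here, unchanged values of the continuous features (`week`, `age`, `prio`) appear as `-`. The binary features are always printed, so compare them with the query row to see whether they changed.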
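Because binary features are printed even when unchanged, it helps to diff each counterfactual against its query programmatically. The minimal sketch below uses a hypothetical helper `changed_features`. It takes the first query and its first counterfactual from the neural network output above and treats `-` as "same as the query".

```python
COLUMNS = ["week", "fin", "age", "race", "wexp", "mar", "paro", "prio"]


def changed_features(query, counterfactual):
    """Return {feature: (original, new)} for every feature the counterfactual changes.

    A "-" in the counterfactual (DiCE's show_only_changes marker) means "unchanged".
    """
    changes = {}
    for name in COLUMNS:
        new = counterfactual[name]
        if new != "-" and new != query[name]:
            changes[name] = (query[name], new)
    return changes


# First query instance and its first counterfactual from the neural network output.
query = dict(zip(COLUMNS, [52, 0, 20, 1, 0, 0, 1, 9]))
cf = dict(zip(COLUMNS, [37, 0, "-", 0, 0, 0, 1, "-"]))
print(changed_features(query, cf))  # {'week': (52, 37), 'race': (1, 0)}
```

Read this way, that counterfactual says the network would predict re-arrest if `week` were 37 instead of 52 and `race` were 0 instead of 1.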

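## More accurate prediction with GBM
Neural networks generally need the explanatory variables to be standardized or normalized. Tree-based algorithms, by contrast, split each feature at thresholds, so such transformations do not change their accuracy. We can therefore generate counterfactuals from a more accurate model while keeping the raw, interpretable values.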
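The invariance claim can be checked on a single split. Standardizing is a strictly increasing transformation, so the best threshold moves but the partition of rows it produces stays the same. The minimal stdlib sketch below fits a Gini-based stump, a simplification of the trees inside gradient boosting. It uses the `prio` and `arrest` values of the five `df.head()` rows. The helper names `gini` and `best_split` are hypothetical.

```python
from statistics import mean, pstdev


def gini(labels):
    """Gini impurity of a list of 0/1 labels."""
    if not labels:
        return 0.0
    p = sum(labels) / len(labels)
    return 1.0 - p * p - (1.0 - p) * (1.0 - p)


def best_split(xs, ys):
    """Return (threshold, left_indices) of the Gini-minimising split x <= threshold."""
    values = sorted(set(xs))
    best = None
    for lo, hi in zip(values[:-1], values[1:]):
        threshold = (lo + hi) / 2  # midpoint between neighbouring values
        left = [i for i, x in enumerate(xs) if x <= threshold]
        right = [i for i, x in enumerate(xs) if x > threshold]
        score = (len(left) * gini([ys[i] for i in left])
                 + len(right) * gini([ys[i] for i in right])) / len(xs)
        if best is None or score < best[0]:
            best = (score, threshold, left)
    return best[1], best[2]


prio = [3, 8, 13, 1, 3]   # prio column of df.head()
arrest = [1, 1, 1, 0, 0]  # arrest column of df.head()

# Standardize prio (z-scores with the population standard deviation).
mu, sigma = mean(prio), pstdev(prio)
prio_std = [(x - mu) / sigma for x in prio]

t_raw, left_raw = best_split(prio, arrest)
t_std, left_std = best_split(prio_std, arrest)
print(t_raw, left_raw)            # 5.5 [0, 3, 4]
print(round(t_std, 3), left_std)  # different threshold, same rows on the left
```

Gradient boosting combines many such trees. Scaling the inputs should therefore leave its predictions unchanged, apart from possible floating-point ties.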

### Building the GBM model

In [11]:

```python
model = GBC()
```

### Fitting

In [12]:

```python
model.fit(x_train, y_train)
```

```
GradientBoostingClassifier()
```

### Prediction and accuracy check
We confirm that the GBM achieves high accuracy without any preprocessing, unlike the neural network, which normally requires it.

In [13]:

```python
y_pred = model.predict(x_test)
```

In [14]:

```python
print(classification_report(y_test, y_pred))
```

```
              precision    recall  f1-score   support

           0       1.00      0.99      0.99        69
           1       0.95      1.00      0.97        18

    accuracy                           0.99        87
   macro avg       0.97      0.99      0.98        87
weighted avg       0.99      0.99      0.99        87
```



Accuracy reaches 0.99, about 7 percentage points higher than the neural network's 0.92.
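Two patterns in the outputs point to `week`. In the neural network's counterfactuals above, every query whose original outcome is 0 has `week` = 52. In the GBM output below, every counterfactual that flips an outcome from 1 to 0 sets `week` to 52. The Rossi recidivism data usually documents `week` as the week of first arrest after release, with 52 meaning no arrest during the one-year follow-up. If that holds for this `rossi.csv`, `week` partly records the outcome itself, which could account for much of the GBM's accuracy. A minimal check is sketched below on the five `df.head()` rows, using a hypothetical helper `week_arrest_table`. On the full data, `pd.crosstab(df["week"] == 52, df["arrest"])` builds the same table with pandas.

```python
from collections import Counter


def week_arrest_table(weeks, arrests, censor_week=52):
    """Count rows by (week == censor_week, arrest) to see how tightly week tracks the outcome."""
    return Counter((week == censor_week, arrest) for week, arrest in zip(weeks, arrests))


# week and arrest columns of the five rows shown by df.head().
weeks = [20, 17, 25, 52, 52]
arrests = [1, 1, 1, 0, 0]

table = week_arrest_table(weeks, arrests)
for (is_52, arrest), count in sorted(table.items()):
    print(f"week==52: {is_52!s:5}  arrest={arrest}  rows={count}")
```

On the full data, `week < 52` may almost always coincide with `arrest = 1`. In that case a model can score highly from `week` alone. Its counterfactuals will then mostly say "change `week`", which reflects how the data was recorded rather than a cause of re-arrest. One option is to drop `week` from `x` and compare the results.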

### Generating counterfactuals

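In [15]:

```python
# Rebuild the DiCE objects around the GBM model (same data description as before)
d = dice_ml.Data(dataframe=pd.concat([x_test, y_test], axis=1),
                 continuous_features=["week", "age", "prio"],  # quantitative variables; the rest are treated as categorical
                 outcome_name="arrest")  # outcome variable
m = dice_ml.Model(model=model, backend="sklearn")
exp = dice_ml.Dice(d, m)
```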

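In [16]:

```python
# Use the first 10 test rows as query instances
pre_counter = x_test.iloc[0:10, :]
# Ask for 4 counterfactuals per query whose predicted class is the opposite of the original
dice_exp = exp.generate_counterfactuals(pre_counter, total_CFs=4, desired_class="opposite")
# Display each query and its counterfactuals, highlighting changed values
dice_exp.visualize_as_dataframe(show_only_changes=True)
```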


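Query instance (original outcome : 0)

|   | week | fin | age | race | wexp | mar | paro | prio | arrest |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 52 | 0 | 20 | 1 | 0 | 0 | 1 | 9 | 0 |

Diverse Counterfactual set (new outcome: 1.0)

|   | week | fin | age | race | wexp | mar | paro | prio | arrest |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 26 | 1 | - | 1 | 0 | 0 | 1 | - | 1 |
| 1 | 46 | 0 | - | 1 | 0 | 0 | 0 | - | 1 |
| 2 | 18 | 0 | - | 1 | 0 | 0 | 1 | - | 1 |
| 3 | 10 | 0 | 35 | 1 | 0 | 0 | 1 | - | 1 |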


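Query instance (original outcome : 1)

|   | week | fin | age | race | wexp | mar | paro | prio | arrest |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 36 | 1 | 23 | 1 | 0 | 0 | 0 | 3 | 1 |

Diverse Counterfactual set (new outcome: 0.0)

|   | week | fin | age | race | wexp | mar | paro | prio | arrest |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 52 | 1 | - | 1 | 1 | 0 | 0 | - | 0 |
| 1 | 52 | 1 | - | 1 | 0 | 0 | 0 | 10 | 0 |
| 2 | 52 | 1 | 39 | 1 | 0 | 0 | 0 | - | 0 |
| 3 | 52 | 1 | - | 0 | 0 | 0 | 0 | 10 | 0 |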


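Query instance (original outcome : 0)

|   | week | fin | age | race | wexp | mar | paro | prio | arrest |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 52 | 1 | 28 | 1 | 0 | 0 | 1 | 4 | 0 |

Diverse Counterfactual set (new outcome: 1.0)

|   | week | fin | age | race | wexp | mar | paro | prio | arrest |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 9 | 0 | - | 1 | 0 | 0 | 1 | - | 1 |
| 1 | 46 | 1 | - | 1 | 0 | 1 | 1 | - | 1 |
| 2 | 18 | 1 | - | 1 | 0 | 0 | 1 | - | 1 |
| 3 | 50 | 1 | - | 1 | 0 | 0 | 1 | - | 1 |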


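Query instance (original outcome : 0)

|   | week | fin | age | race | wexp | mar | paro | prio | arrest |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 52 | 0 | 22 | 1 | 0 | 0 | 1 | 1 | 0 |

Diverse Counterfactual set (new outcome: 1.0)

|   | week | fin | age | race | wexp | mar | paro | prio | arrest |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 9 | 0 | 41 | 1 | 0 | 0 | 1 | - | 1 |
| 1 | 39 | 0 | 34 | 1 | 0 | 0 | 1 | - | 1 |
| 2 | 40 | 0 | - | 1 | 0 | 0 | 1 | 8 | 1 |
| 3 | 19 | 0 | - | 1 | 0 | 0 | 1 | - | 1 |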


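Query instance (original outcome : 0)

|   | week | fin | age | race | wexp | mar | paro | prio | arrest |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 52 | 0 | 21 | 1 | 0 | 0 | 0 | 2 | 0 |

Diverse Counterfactual set (new outcome: 1.0)

|   | week | fin | age | race | wexp | mar | paro | prio | arrest |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 31 | 0 | - | 1 | 0 | 0 | 0 | - | 1 |
| 1 | 23 | 0 | - | 1 | 0 | 0 | 0 | 9 | 1 |
| 2 | 16 | 0 | - | 1 | 0 | 0 | 0 | - | 1 |
| 3 | 28 | 0 | - | 1 | 0 | 1 | 0 | - | 1 |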


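Query instance (original outcome : 0)

|   | week | fin | age | race | wexp | mar | paro | prio | arrest |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 52 | 1 | 18 | 1 | 0 | 0 | 0 | 4 | 0 |

Diverse Counterfactual set (new outcome: 1.0)

|   | week | fin | age | race | wexp | mar | paro | prio | arrest |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 12 | 1 | - | 1 | 0 | 0 | 0 | - | 1 |
| 1 | 36 | 1 | - | 1 | 0 | 0 | 1 | - | 1 |
| 2 | 34 | 0 | - | 1 | 0 | 0 | 0 | - | 1 |
| 3 | 49 | 1 | - | 0 | 0 | 0 | 0 | - | 1 |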


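Query instance (original outcome : 0)

|   | week | fin | age | race | wexp | mar | paro | prio | arrest |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 52 | 0 | 19 | 0 | 1 | 0 | 1 | 3 | 0 |

Diverse Counterfactual set (new outcome: 1.0)

|   | week | fin | age | race | wexp | mar | paro | prio | arrest |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 40 | 0 | - | 0 | 1 | 0 | 1 | - | 1 |
| 1 | 43 | 0 | - | 0 | 0 | 0 | 1 | - | 1 |
| 2 | 32 | 0 | 39 | 0 | 1 | 0 | 1 | - | 1 |
| 3 | 37 | 0 | 31 | 0 | 1 | 0 | 1 | - | 1 |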


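Query instance (original outcome : 0)

|   | week | fin | age | race | wexp | mar | paro | prio | arrest |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 52 | 0 | 20 | 1 | 0 | 0 | 0 | 1 | 0 |

Diverse Counterfactual set (new outcome: 1.0)

|   | week | fin | age | race | wexp | mar | paro | prio | arrest |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 39 | 0 | - | 0 | 0 | 0 | 0 | - | 1 |
| 1 | 43 | 0 | - | 1 | 0 | 0 | 0 | - | 1 |
| 2 | 29 | 0 | 41 | 1 | 0 | 0 | 0 | - | 1 |
| 3 | 17 | 0 | 34 | 1 | 0 | 0 | 0 | - | 1 |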


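Query instance (original outcome : 1)

|   | week | fin | age | race | wexp | mar | paro | prio | arrest |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 26 | 0 | 32 | 1 | 1 | 0 | 0 | 2 | 1 |

Diverse Counterfactual set (new outcome: 0.0)

|   | week | fin | age | race | wexp | mar | paro | prio | arrest |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 52 | 0 | - | 1 | 1 | 0 | 0 | - | 0 |
| 1 | 52 | 1 | - | 1 | 1 | 0 | 0 | - | 0 |
| 2 | 52 | 0 | - | 0 | 1 | 0 | 0 | - | 0 |
| 3 | 52 | 0 | 34 | 1 | 1 | 0 | 0 | - | 0 |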


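Query instance (original outcome : 1)

|   | week | fin | age | race | wexp | mar | paro | prio | arrest |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 24 | 0 | 40 | 1 | 1 | 0 | 0 | 2 | 1 |

Diverse Counterfactual set (new outcome: 0.0)

|   | week | fin | age | race | wexp | mar | paro | prio | arrest |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 52 | 0 | - | 1 | 1 | 0 | 0 | - | 0 |
| 1 | 52 | 0 | - | 1 | 1 | 1 | 0 | - | 0 |
| 2 | 52 | 0 | - | 1 | 1 | 0 | 0 | 6 | 0 |
| 3 | 52 | 0 | - | 1 | 1 | 0 | 1 | - | 0 |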


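From here, we use these counterfactuals from the more accurate model to examine the relationship between the explanatory variables (causes) and re-arrest (the outcome).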
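A simple starting point for that analysis is to count how often each feature changes across a query's counterfactuals. The sketch below does this with a hypothetical helper `change_counts`, applied to the second GBM query (original outcome 1) and its four counterfactuals copied from the output above. dice_ml's explainer also provides `local_feature_importance` and `global_feature_importance`, which summarise counterfactuals in a similar spirit; they are not used here.

```python
from collections import Counter

COLUMNS = ["week", "fin", "age", "race", "wexp", "mar", "paro", "prio"]


def change_counts(query, counterfactuals):
    """Count, per feature, how many counterfactuals change it ("-" means unchanged)."""
    counts = Counter()
    for cf in counterfactuals:
        for name in COLUMNS:
            if cf[name] != "-" and cf[name] != query[name]:
                counts[name] += 1
    return counts


# Second GBM query (original outcome 1) and its counterfactuals (new outcome 0).
query = dict(zip(COLUMNS, [36, 1, 23, 1, 0, 0, 0, 3]))
cfs = [
    dict(zip(COLUMNS, row))
    for row in [
        [52, 1, "-", 1, 1, 0, 0, "-"],
        [52, 1, "-", 1, 0, 0, 0, 10],
        [52, 1, 39, 1, 0, 0, 0, "-"],
        [52, 1, "-", 0, 0, 0, 0, 10],
    ]
]
print(change_counts(query, cfs).most_common())
```

For this query, `week` changes in all four counterfactuals and `prio` in two. Tallied over all ten queries, the same counts show which features the model relies on most to flip its predictions.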