# SHAP

SHAP을 통한 모델 설명  
https://shap.readthedocs.io/en/latest/index.html

In [4]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
import shap
from transformers_interpret import SequenceClassificationExplainer
from ferret import Benchmark


In [None]:
import pandas as pd

# Paths of the train/test splits; only the test split is read in this cell.
train = './train_small.csv'
test = './test_small.csv'
# Read through the `test` variable so the test path is defined in exactly one place.
test_pd = pd.read_csv(test)

In [None]:
# Import torch here: this cell runs before the cell that imports torch, so without this
# import a fresh top-to-bottom run fails with NameError.
import torch

# Target device for the model (GPU).
device = torch.device("cuda")

In [None]:
import numpy as np
import pandas as pd
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, ElectraForSequenceClassification, AdamW, TextClassificationPipeline
from tqdm.notebook import tqdm

In [None]:
# Load the model.
# The tokenizer and the base KcELECTRA classifier come from the same checkpoint id.
# The model is moved to `device` first, then the fine-tuned weights from
# jh_model_shap.pt are loaded into it.
tokenizer = AutoTokenizer.from_pretrained("beomi/KcELECTRA-base")
model = ElectraForSequenceClassification.from_pretrained("beomi/KcELECTRA-base")
model = model.to(device)
model.load_state_dict(torch.load("jh_model_shap.pt"))
sentiment_classifier = TextClassificationPipeline(model=model, tokenizer=tokenizer, device=0)

In [None]:
# Text-classification pipeline explained by SHAP.
# `top_k=None` makes the pipeline return a score for EVERY label. With `top_k=1` only
# the winning label is returned, and SHAP's pipeline wrapper leaves the other label's
# output at 0. The per-label explanations (LABEL_0 / LABEL_1) plotted below would then
# be computed on a function that jumps between the probability and 0.
classifier = pipeline('text-classification', top_k=None, model=model, tokenizer=tokenizer, device=0)

In [None]:
import shap

# Let SHAP choose an explainer for the `classifier` pipeline.
explainer = shap.Explainer(model=classifier)

In [None]:
# Explain the first 21 sentences of the test set (positional slice).
shap_values = explainer(test_pd['sentence'].iloc[:21])

In [None]:
# Token-level text plot of `shap_values`.
shap.plots.text(shap_values=shap_values)

#### 점수 낮은 문장들

In [None]:
# Model classification results (columns used below: sentence, score, predicted).
sorted_pd = pd.read_csv(filepath_or_buffer='./저희 모델 분류 결과_10words.csv')

In [None]:
# Results ordered from the lowest `score` upward.
score_low = sorted_pd.sort_values('score')
score_low

In [None]:
# Explain the 21 lowest-score sentences (positional slice of the sorted frame).
shap_low = explainer(score_low['sentence'].iloc[:21])

In [None]:
# Token-level text plot for the lowest-score sentences.
shap.plots.text(shap_values=shap_low)

#### 점수 높은 순서

In [None]:
# Results ordered from the highest `score` downward.
score_high = sorted_pd.sort_values(by='score', ascending=False)
score_high

In [None]:
# Explain the 20 highest-score sentences (positional slice of the sorted frame).
shap_high = explainer(score_high['sentence'].iloc[:20])

In [None]:
# Token-level text plot for the highest-score sentences.
shap.plots.text(shap_values=shap_high)

#### 겸양의 저희로 분류한 점수 순서

In [None]:
# Rows the model assigned to class 0.
neg_predicted = sorted_pd.loc[sorted_pd['predicted'].eq(0)]
neg_predicted

In [None]:
# Class-0 predictions ordered from the highest `score` downward.
neg_high = neg_predicted.sort_values(by='score', ascending=False)
neg_high

In [None]:
# Explain the 20 highest-score class-0 sentences.
shap_neg = explainer(neg_high['sentence'].iloc[:20])

In [None]:
# Plot the class-0 explanations. `shap_neg` was already computed in the previous cell
# from the same 20 sentences, so the explainer is not re-run here.
shap.plots.text(shap_neg)

### barplot 보기

In [None]:
# Change the plotting font: list installed TrueType fonts and print the Malgun ones.
import matplotlib.font_manager as fm

font_file_path_list = fm.findSystemFonts(fontpaths=None, fontext='ttf')
print(len(font_file_path_list))

# Case-insensitive match, so file names such as "MALGUN.TTF" or "malgunbd.ttf" are found too.
# A list (rather than a one-shot filter iterator) can be reused after printing.
fav_font_file_path_lst = [
    path for path in font_file_path_list if "malgun" in path.lower()
]
print()
for font_file_path in fav_font_file_path_lst:
    print(font_file_path)

In [None]:
import matplotlib.font_manager as fm

# Font properties for Malgun Gothic. The raw string keeps the Windows backslashes from
# being parsed as escape sequences (\W, \F, \m are invalid escapes and warn on 3.12+).
malgun = fm.FontProperties(fname=r'C:\Windows\Fonts\malgun.ttf')

In [None]:
# Inspect matplotlib's default font-related rcParams.
# pyplot is imported here because this cell runs before the cell that imports it.
import matplotlib.pyplot as plt

for key in plt.rcParamsDefault:
    if 'font' in key:
        print(f"{key}: {plt.rcParamsDefault[key]}")

In [None]:
# Make the font stored in malgun.ttf the default matplotlib font, at size 20.
# Imports are local because this cell runs before the cell that imports pyplot.
import matplotlib.font_manager as fm
import matplotlib.pyplot as plt

# Raw string: the backslashes are Windows path separators, not escape sequences.
font_path = r'C:\Windows\Fonts\malgun.ttf'
# matplotlib selects fonts by family name, so read that name from the file itself.
font_name = fm.FontProperties(fname=font_path).get_name()
plt.rcParams['font.family'] = font_name
plt.rcParams['font.size'] = 20

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.font_manager as fm

# Draw minus signs with ASCII '-' (the Unicode minus may render as a box with some
# Korean fonts).
mpl.rcParams['axes.unicode_minus'] = False
font_path = r'C:\Windows\Fonts\malgun.ttf'
# matplotlib looks fonts up by the family name stored in the file (normally
# "Malgun Gothic"), not by the file name. family='malgun' matched nothing and silently
# fell back to a font without Hangul glyphs, undoing the previous cell's setting.
plt.rc('font', family=fm.FontProperties(fname=font_path).get_name())

#### 테스팅

In [None]:
# Explain only the first test sentence (row label 0), passed as a one-element list.
shap_value = explainer([test_pd.at[0, 'sentence']])

In [None]:
# Bar plot of per-token SHAP values toward LABEL_1 for the single sentence explained
# just above in `shap_value`. That result was previously computed but never used.
shap.plots.bar(shap_value[0, :, 'LABEL_1'])

In [None]:
# Explain the 20 highest-score class-0 sentences.
neg_value = explainer(neg_high['sentence'].iloc[:20])

In [None]:
# Per-token SHAP values toward LABEL_0 for the first sentence in `neg_value`.
shap.plots.bar(shap_values=neg_value[0, :, 'LABEL_0'])

#### Plotting the top words impacting a specific class
In addition to slicing, Explanation objects also support a set of reducing methods. Here we use `.mean(0)` to take the average impact of each word towards one class (`LABEL_0` or `LABEL_1`). Note that the cells below average over only about 20 examples; to get a more reliable summary you would want to use a larger portion of the dataset (see the full-data section at the end).

##### 겸양의 저희로 판단한 결과물 top 20개의 분포 평균

In [None]:
# Mean SHAP value per token toward LABEL_1 over `shap_high`, largest mean first.
shap.plots.bar(shap_values=shap_high[:, :, 'LABEL_1'].mean(axis=0), order=shap.Explanation.argsort.flip)

In [None]:
# Mean SHAP value per token toward LABEL_0 over `shap_high`, largest mean first.
shap.plots.bar(shap_values=shap_high[:, :, 'LABEL_0'].mean(axis=0), order=shap.Explanation.argsort.flip)

##### 겸양의 저희(predicted == 0)로 판단한 결과물 중 점수 상위 20개의 평균 SHAP 값

In [None]:
# Mean SHAP value per token toward LABEL_0 over `shap_neg`, largest mean first.
shap.plots.bar(shap_values=shap_neg[:, :, 'LABEL_0'].mean(axis=0), order=shap.Explanation.argsort.flip)

In [None]:
# Mean SHAP value per token toward LABEL_1 over `shap_neg`, using the plot's default ordering.
shap.plots.bar(shap_values=shap_neg[:, :, 'LABEL_1'].mean(axis=0))

In [None]:
# Mean SHAP value per token toward LABEL_1 over `shap_neg`, largest mean first.
shap.plots.bar(shap_values=shap_neg[:, :, 'LABEL_1'].mean(axis=0), order=shap.Explanation.argsort.flip)

In [None]:
# Mean SHAP value per token toward LABEL_0 over `shap_low`, largest mean first.
shap.plots.bar(shap_values=shap_low[:, :, 'LABEL_0'].mean(axis=0), order=shap.Explanation.argsort.flip)

In [None]:
# Mean SHAP value per token toward LABEL_1 over `shap_low`, largest mean first.
shap.plots.bar(shap_values=shap_low[:, :, 'LABEL_1'].mean(axis=0), order=shap.Explanation.argsort.flip)

#### 전체 데이터로 확인

In [None]:
# Explain every sentence in the test set.
all_value = explainer(test_pd.loc[:, 'sentence'])

In [None]:
# Words that push the model toward the referential ("지칭") sense of 저희, i.e. LABEL_1:
# mean SHAP value per token over the whole test set, largest mean first.
shap.plots.bar(shap_values=all_value[:, :, 'LABEL_1'].mean(axis=0), order=shap.Explanation.argsort.flip)

In [None]:
# Words that push the model toward the humble ("겸양") sense of 저희, i.e. LABEL_0:
# mean SHAP value per token over the whole test set, largest mean first.
shap.plots.bar(shap_values=all_value[:, :, 'LABEL_0'].mean(axis=0), order=shap.Explanation.argsort.flip)