In [1]:
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots
from tqdm import tqdm

In [7]:
# Fixed vocabulary groups, in the order they are stacked on the y-axis:
# the first group occupies the lowest category positions (0, 1, ...).
_FIXED_WORD_GROUPS = (
    ("Punctuation", (".", "!", "?")),
    ("Future Tense", ("werde", "wirst", "wird", "werden", "werdet")),
    ("Modal Verbs", ("muss", "müssen", "kann", "können", "darf", "dürfen",
                     "soll", "sollen", "will", "wollen", "mag", "mögen")),
    ("Quotes", (":", '"', "sagt", "sagte", "sagtet", "sagten", "gesagt")),
)


def _rows_for_words(shap_values, words):
    """Return the rows whose ``Word`` is one of ``words``, sorted by word descending."""
    selected = shap_values[shap_values["Word"].isin(words)]
    return selected.sort_values("Word", ascending=False)


def _top_rows(shap_values, label, excluded_words, n):
    """Return the ``n`` rows with the largest ``|shap_values[label]|`` that are not excluded.

    The result is reversed (smallest magnitude first) so that the strongest
    word is appended last when the rows are stacked onto the y-axis.
    """
    # Rank by position, not by index label, so frames with a non-default
    # index (e.g. after filtering) are handled correctly.
    magnitude = shap_values[label].abs().reset_index(drop=True)
    order = magnitude.sort_values(ascending=False).index
    ranked = shap_values.iloc[order]
    # Skip words that already have a fixed slot so exactly ``n`` free words
    # remain (or fewer only if the data does not contain that many).
    free = ranked[~ranked["Word"].isin(excluded_words)]
    return free.head(n).iloc[::-1]


def plot_shap(shap_values: pd.DataFrame, file_name: str, n: int = 5,
              output_dir: str = "/home/sami/FLAIR/Visuals") -> None:
    """Plot per-class SHAP values as four horizontal bar panels and export the figure.

    Every panel (one per class: Aussage, Meinung, Prognose, Sonstiges) shows,
    from bottom to top, the fixed vocabulary groups (punctuation, future
    tense, modal verbs, quotes) followed by the ``n`` words with the largest
    absolute SHAP value for that class that are not part of a fixed group.
    Negative values are drawn in orange (#E69F00), all others in blue (#0072B2).
    Dashed separators and rotated group labels are positioned from the number
    of words actually present in each group; each label is centred on its group.

    Parameters
    ----------
    shap_values : pd.DataFrame
        Must contain a ``Word`` column and one numeric column per class name.
    file_name : str
        Prefix of the exported files (``{file_name}_ShapPlot.<ext>``).
    n : int, default 5
        Number of top words shown per class in addition to the fixed groups.
    output_dir : str, default "/home/sami/FLAIR/Visuals"
        Root folder of the exports; the sub-folders ``PNG``, ``JPG``, ``SVG``,
        ``PDF`` and ``HTML`` are created if they do not exist.

    Returns
    -------
    None
        The figure is written to disk and displayed via ``fig.show()``.
    """
    class_names = ["Aussage", "Meinung", "Prognose", "Sonstiges"]
    fig = make_subplots(rows=1, cols=len(class_names),
                        horizontal_spacing=.12,
                        vertical_spacing=.07,
                        subplot_titles=class_names)
    fig.update_layout(template='plotly_white')

    groups = [(title, _rows_for_words(shap_values, words))
              for title, words in _FIXED_WORD_GROUPS]
    fixed_words = [word for _, words in _FIXED_WORD_GROUPS for word in words]

    last_col = len(class_names)
    for col, label in enumerate(class_names, start=1):
        top = _top_rows(shap_values, label, fixed_words, n)
        # Stack fixed groups first, then the top words; identical rows are
        # collapsed so no row is drawn twice.
        rows = (pd.concat([frame for _, frame in groups] + [top])
                .drop_duplicates(keep='first')
                .reset_index(drop=True))
        color = np.where(rows[label] < 0, "#E69F00", "#0072B2")
        fig.add_trace(go.Bar(y=rows['Word'], x=rows[label],
                             showlegend=False,
                             marker_color=color,
                             name=label,
                             orientation='h'), row=1, col=col)
        # The original code doubled the tick budget on the last panel only;
        # that setting is kept as-is.
        tick_budget = rows.shape[0] * (2 if col == last_col else 1)
        fig.update_yaxes(nticks=tick_budget, col=col, row=1)

    # After the loop ``rows``/``label``/``top`` belong to the last panel,
    # next to which the group labels are placed.
    max_ = rows[label].max() * 1.5

    sections = [(title, frame["Word"].nunique()) for title, frame in groups]
    sections.append((f"Top {n}", top["Word"].nunique()))

    separator_style = dict(type="line", line_color="gray", line_width=1,
                           opacity=.6, line_dash="dash",
                           x0=0, x1=1, xref="paper", yref="y")
    label_style = dict(yanchor='middle', showarrow=False, ax=-20, ay=-30,
                       yref=f"y{last_col}", xref=f"x{last_col}",
                       align="right",
                       font=dict(size=11, color="black"),
                       textangle=90)

    start = 0  # category position of the first word of the current section
    for title, size in sections:
        if size == 0:
            # Group not present in the data: no separator, no label.
            continue
        end = start + size
        if start > 0:
            # Dashed line halfway between this section and the one below it.
            fig.add_shape(y0=start - .5, y1=start - .5, **separator_style)
        fig.add_annotation(x=max_, y=(start + end - 1) / 2,
                           text=f'<b>{title}</b>', **label_style)
        start = end

    out_root = Path(output_dir)
    for ext in ("png", "jpg", "svg", "pdf"):
        target_dir = out_root / ext.upper()
        target_dir.mkdir(parents=True, exist_ok=True)
        pio.write_image(fig, str(target_dir / f"{file_name}_ShapPlot.{ext}"),
                        scale=10, width=1080, height=800)
    html_dir = out_root / "HTML"
    html_dir.mkdir(parents=True, exist_ok=True)
    fig.write_html(str(html_dir / f"{file_name}_ShapPlot.html"))
    fig.show()


In [8]:
# Read the E4 SHAP value table from disk, then render and export the
# four-panel figure under the 'EXP4' file prefix.
shap_csv_path = '/home/sami/FLAIR/Results/SHAP/E4_ShapValues.csv'
shap_values = pd.read_csv(shap_csv_path)
plot_shap(shap_values=shap_values, file_name='EXP4')