In [1]:
import pandas as pd
import numpy as np
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.express as px
import plotly.io as pio
import nltk
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import f1_score
from sklearn.metrics import confusion_matrix


In [2]:
def get_experiment(experiment: str, path: str = '/home/sami/FLAIR/Data/Experiment_Frame.csv') -> pd.DataFrame:
    """Load the experiment frame and keep the rows that belong to one experiment.

    Parameters
    ----------
    experiment : str
        Name of the column to filter on. Rows whose value in this column is
        missing (NaN) are dropped.
    path : str, optional
        Location of the experiment CSV. Defaults to the original hard-coded
        file so existing callers keep working; pass another path to run the
        notebook on a different machine or against test data.

    Returns
    -------
    pandas.DataFrame
        The rows of the CSV whose ``experiment`` column is not missing, with
        their original index labels.

    Raises
    ------
    KeyError
        If ``experiment`` is not a column of the CSV.
    """
    frame = pd.read_csv(path)
    if experiment not in frame.columns:
        # Same exception type a plain column lookup would raise, but with a
        # message that says which file and column were involved.
        raise KeyError(f"Experiment column {experiment!r} not found in {path}")
    return frame[frame[experiment].notna()]
def parse_doc_type(doc_type):
    """Translate a raw document-type code into its display label.

    Returns the label for the five known codes and ``None`` for any other
    value.
    """
    code_to_label = (
        ('SOCIAL_MEDIA', 'Twitter'),
        ('PROTOKOLL', 'Protocol'),
        ('NEWSPAPER', 'Newspaper'),
        ('MANIFESTO', 'Manifesto'),
        ('TALKSHOW', 'Talkshow'),
    )
    for code, label in code_to_label:
        if doc_type == code:
            return label
    # Unknown codes fall through without a label.
    return None

In [8]:
def get_scores_by_doc():
    """Compute the weighted F1 score per document type for experiments E0-E4.

    For each of the five experiments the DistilBert result CSV is read, its
    ``DOCUMENT_TYPE`` codes are translated with ``parse_doc_type`` and the
    rows are grouped by the translated label. One weighted F1 score, rounded
    to two decimals, is computed per group from the ``GROUNDTRUTH`` and
    ``PREDICTION`` columns.

    Returns
    -------
    pandas.DataFrame
        One row per (experiment, document type) with the columns
        ``'Document Type'``, ``'Experiment'`` and ``'F1'``.
    """
    doc_labels, experiment_names, f1_values = [], [], []
    for experiment_no in range(5):
        results = pd.read_csv(f"/home/sami/FLAIR/Results/TRANSFORMER/E{experiment_no}/DistilBert_results.csv")
        results['DOCUMENT_TYPE'] = results['DOCUMENT_TYPE'].apply(parse_doc_type)
        for doc_label, rows in results.groupby('DOCUMENT_TYPE'):
            truth = rows['GROUNDTRUTH'].to_numpy()
            predicted = rows['PREDICTION'].to_numpy()
            weighted_f1 = f1_score(truth, predicted, average='weighted')
            doc_labels.append(doc_label)
            experiment_names.append(f'E{experiment_no}')
            f1_values.append(np.round(weighted_f1, 2))
    return pd.DataFrame({
        'Document Type': doc_labels,
        'Experiment': experiment_names,
        'F1': f1_values,
    })

def add_to_scores_by_doc(fig, exp, score, doc, row, col):
    """Add one document type's score line to a subplot cell of ``fig``.

    Parameters
    ----------
    fig : figure to draw on; it is modified and also returned.
    exp : x values of the line (the caller in this file passes experiment names).
    score : y values of the line (the caller in this file passes F1 scores).
    doc : document-type label; used as the trace name and to pick the colour.
        Must be one of the five labels in the colour table below, otherwise
        the lookup raises ``KeyError``.
    row, col : subplot cell that receives the trace.

    Returns
    -------
    The same ``fig`` with the new scatter trace added.
    """
    # Fixed colour per document type so a type looks the same on every call.
    doc_colors = {
        'Protocol': "#F0E442",
        'Twitter': "#E69F00",
        'Newspaper': "#009E73",
        'Manifesto': "#56B4E9",
        'Talkshow': "#0072B2",
    }
    trace = go.Scatter(x=exp, y=score, marker={'color': doc_colors[doc]}, name=doc)
    fig.add_trace(trace, row=row, col=col)
    return fig

def get_confusion_matrix(experiment):
    """Build the confusion matrix for one experiment's DistilBert results.

    ``experiment`` is the name of the result folder (for example ``'E0'``).
    The ``GROUNDTRUTH`` column is passed to ``confusion_matrix`` as the true
    labels and the ``PREDICTION`` column as the predicted labels; the
    resulting matrix is returned unchanged.
    """
    results = pd.read_csv(
        f"/home/sami/FLAIR/Results/TRANSFORMER/{experiment}/DistilBert_results.csv"
    )
    truth, predicted = (results[column].values for column in ('GROUNDTRUTH', 'PREDICTION'))
    return confusion_matrix(truth, predicted)

def _confusion_heatmap(cm, labels):
    """Return an annotated heatmap trace for one confusion matrix."""
    return go.Heatmap(
        z=cm,
        x=labels,
        y=labels,
        text=cm,
        texttemplate="%{text}",
        colorscale='Viridis',
        showlegend=False,
        showscale=False,
    )


def plot_detailed_scores():
    """Draw, export and show the detailed-scores overview figure.

    The figure is a 2x3 grid. The first five cells hold the confusion-matrix
    heatmaps of experiments E0-E4 (filled row by row); the last cell (row 2,
    column 3) holds one F1 line per document type across the experiments.

    Side effects: writes the figure as PNG, SVG, JPG and PDF images and as an
    HTML file below ``/home/sami/FLAIR/Visuals`` and then displays it.
    Returns nothing.
    """
    # True label on y-axis
    # Predicted label on x-axis
    fig = make_subplots(
        rows=2,
        cols=3,
        specs=[[{}, {}, {}], [{}, {}, {}]],
        horizontal_spacing=.04,
        vertical_spacing=.09,
        subplot_titles=['Experiment 0', 'Experiment 1', 'Experiment 2', 'Experiment 3', 'Experiment 4', ''],
    )
    fig.update_layout(template='plotly_white')

    # F1 lines per document type go into the bottom-right cell.
    scores_by_doc = get_scores_by_doc()  # Needs adjustment for new format
    for doc_type in ['Twitter', 'Newspaper', 'Manifesto', 'Talkshow', 'Protocol']:
        doc_scores = scores_by_doc[scores_by_doc['Document Type'] == doc_type]
        fig = add_to_scores_by_doc(
            fig,
            doc_scores['Experiment'].to_list(),
            doc_scores['F1'].to_list(),
            doc_type,
            2,
            3,
        )

    # One heatmap per experiment, filling the grid row by row:
    # E0 -> (1, 1), E1 -> (1, 2), E2 -> (1, 3), E3 -> (2, 1), E4 -> (2, 2).
    labels = ['Aussage', 'Meinung', 'Prognose', 'Sonstiges']
    for experiment_no in range(5):
        row_offset, col_offset = divmod(experiment_no, 3)
        fig.add_trace(
            _confusion_heatmap(get_confusion_matrix(f'E{experiment_no}'), labels),
            row=row_offset + 1,
            col=col_offset + 1,
        )

    # Axis cosmetics. Titles only on the outer heatmap axes; tick labels are
    # hidden where a neighbouring heatmap already shows the same categories.
    fig.update_yaxes(nticks=6, row=2, col=3)
    for row in (1, 2):
        fig.update_yaxes(title_text='Groundtruth', row=row, col=1)
    for col in (1, 2):
        fig.update_xaxes(title_text='Predicted', row=2, col=col)
        fig.update_xaxes(showticklabels=False, row=1, col=col)
    for row, col in ((1, 2), (1, 3), (2, 2)):
        fig.update_yaxes(showticklabels=False, row=row, col=col)
    fig.update_yaxes(title_text='F1', row=2, col=3)

    # Static exports share size and scale; only folder and extension differ.
    for extension in ('png', 'svg', 'jpg', 'pdf'):
        pio.write_image(
            fig,
            f'/home/sami/FLAIR/Visuals/{extension.upper()}/Detailed_Scores.{extension}',
            scale=10,
            width=1080,
            height=600,
        )
    fig.write_html("/home/sami/FLAIR/Visuals/HTML/Detailed_Scores.html")
    fig.show()
# Build the detailed-scores figure, write the exports and display it.
plot_detailed_scores()