In [1]:
import pandas as pd
from os.path import join, isdir, isfile
from os import listdir
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import re
import json
import seaborn as sns
from plotly.subplots import make_subplots

In [2]:
def get_data(data, by_transformer=True):
    """Regroup nested results into parallel ``names`` / ``values`` / ``errors`` lists.

    Parameters
    ----------
    data : dict
        Mapping ``dataset -> transformer -> split -> (value, error)``; each
        innermost entry must unpack into exactly two items.
    by_transformer : bool, default True
        If True, the result is keyed by transformer and ``names`` lists the
        split names. If False, the result is keyed by split name and ``names``
        lists the transformers.

    Returns
    -------
    dict
        ``dataset -> group -> {"names": [...], "values": [...], "errors": [...]}``
        with entries in the iteration order of ``data``.
    """
    def _empty_series():
        return {"names": [], "values": [], "errors": []}

    final_data = {}
    for dataset, per_transformer in data.items():
        grouped = {}
        for transformer, split_results in per_transformer.items():
            if by_transformer:
                # Keep an (empty) entry even for a transformer without splits.
                grouped.setdefault(transformer, _empty_series())
            for split_name, (val, err) in split_results.items():
                if by_transformer:
                    group_key, label = transformer, split_name
                else:
                    group_key, label = split_name, transformer
                # Groups are created on first sight, so the split-keyed layout
                # no longer depends on one specific transformer being present.
                series = grouped.setdefault(group_key, _empty_series())
                series["names"].append(label)
                series["values"].append(val)
                series["errors"].append(err)
        final_data[dataset] = grouped

    return final_data

In [3]:
def barplot_metric(values, names, title, n=None):
    """Create a new figure holding one bar trace of ``values`` scaled by 100.

    ``names`` become the x categories, ``title`` the figure title, and ``n``
    (when given) the trace name. Bar labels are rendered with two decimals
    followed by a percent sign.
    """
    percent = np.array(values) * 100
    bar_kwargs = dict(
        x=names,
        y=percent,
        text=percent,
        texttemplate='%{text:.2f}%',
        textposition='auto',
    )
    # Only pass a trace name when the caller supplied one.
    if n is not None:
        bar_kwargs["name"] = n

    fig = go.Figure()
    fig.add_trace(go.Bar(**bar_kwargs))
    fig.update_layout(title_text=title, height=500)
    return fig

def barplot_add_class(fig, values, names, name, curr_dict, feat_dict):
    """Append one bar trace (``values`` scaled by 100) to ``fig`` and return it.

    ``feat_dict`` is forwarded to ``go.Bar`` (extra trace properties) and
    ``curr_dict`` to ``fig.add_trace`` (e.g. subplot ``row`` / ``col``).
    """
    percent = np.array(values) * 100
    trace = go.Bar(
        name=name,
        x=names,
        y=percent,
        text=percent,
        # NOTE(review): '.2' (no 'f') differs from barplot_metric's '.2f' —
        # confirm the shorter label format is intended.
        texttemplate='%{text:.2}%',
        textposition='auto',
        **feat_dict,
    )
    fig.add_trace(trace, **curr_dict)
    return fig

In [44]:
# --- Result files -----------------------------------------------------------
filepath = "../data/results_acc.json"
filepath_class = "../data/results_acc_class.json"

filepath_f1 = "../data/results.json"
filepath_f1_class = "../data/results_class.json"

# --- Shared identifiers (private aliases to avoid repeating long literals) --
_COMET_SMALL = 'Unbabel/xlm-roberta-comet-small'
_XLMR_BASE = 'xlm-roberta-base'
_XTREME_DISTIL = 'microsoft/xtremedistil-l6-h256-uncased'

_PARA = 'original+normal'
_PARA_CONV = 'original+different'
_INSP = 'original+new'
_ALL = 'original+normal+different+new+new_v2'

# Model id -> display name; the insertion order is the subplot order.
transformers = {
    _COMET_SMALL: "RoBERTa-small",
    _XLMR_BASE: "RoBERTa-base",
    _XTREME_DISTIL: "XtremeDistil",
}

# dataset -> metric -> model id -> split key plotted as "Best".
one_result = {
    "persent": {
        "f1 macro": {_COMET_SMALL: _ALL, _XLMR_BASE: _INSP, _XTREME_DISTIL: _INSP},
        "accuracy": {_COMET_SMALL: _ALL, _XLMR_BASE: _PARA, _XTREME_DISTIL: _ALL},
    },
    "multiemo": {
        "f1 macro": {_COMET_SMALL: _ALL, _XLMR_BASE: _PARA, _XTREME_DISTIL: _PARA_CONV},
        "accuracy": {_COMET_SMALL: _ALL, _XLMR_BASE: _PARA, _XTREME_DISTIL: _ALL},
    },
}

# dataset -> "class_<i>" -> display label.
_SHARED_CLASSES = {
    "class_0": "Positive",
    "class_1": "Negative",
    "class_2": "Neutral",
}
classes = {
    "persent": dict(_SHARED_CLASSES),
    "multiemo": {**_SHARED_CLASSES, "class_3": "Ambivalent"},
}

# Split key -> bar label; the insertion order is the bar order.
# The single-source splits ('original', 'normal', 'different', 'new',
# 'new_v2') are deliberately not listed here.
splits = {
    _PARA: "Para",
    _PARA_CONV: "Para-Conv",
    'original+normal+different': "Both Para",
    _INSP: "Insp",
    'original+new_v2': "Insp-Lab",
    'original+new+new_v2': "Both Insp",
    _ALL: "All",
}

def _load_double_encoded_json(path):
    """Read ``path`` and decode twice.

    The first ``json.load`` must yield a string (it is handed to
    ``json.loads``), i.e. the file holds a JSON string whose content is
    itself a JSON document.
    """
    with open(path, "r") as handle:
        return json.loads(json.load(handle))


data0 = _load_double_encoded_json(filepath)
data_class = _load_double_encoded_json(filepath_class)
data_f1 = _load_double_encoded_json(filepath_f1)
data_f1_class = _load_double_encoded_json(filepath_f1_class)

In [45]:
# Display the "persent" entry of data_f1 to inspect its nesting
# (transformer -> split -> two-element list, per the cell output).
data_f1["persent"]

{'Unbabel/xlm-roberta-comet-small': {'original': [0.361726188659668,
   0.016725335366700952],
  'normal': [0.3191667079925537, 0.03060524714922138],
  'different': [0.29915974140167234, 0.03804866821844679],
  'new': [0.3497793793678284, 0.0176627560641259],
  'new_v2': [0.30250959396362304, 0.010354932931823595],
  'original+normal': [0.3811680555343628, 0.03386205155634625],
  'original+different': [0.3864628493785858, 0.00761447307721606],
  'original+normal+different': [0.3995507657527923, 0.024047271502674945],
  'original+new': [0.36569696068763735, 0.04429143881012411],
  'original+new_v2': [0.38146510124206545, 0.024521155752355715],
  'original+new+new_v2': [0.3704035937786102, 0.013393410298221697],
  'original+normal+different+new+new_v2': [0.39079118371009824,
   0.019352559971180683]},
 'xlm-roberta-base': {'original': [0.37998137474060056, 0.07939363226196712],
  'normal': [0.3537957787513733, 0.03342476998640384],
  'different': [0.34085633158683776, 0.02704338114972989

In [77]:
def gain(new, base):
    """Return the fraction of the headroom above ``base`` that ``new`` closes.

    Both scores are expected as fractions where 1 is a perfect score, which
    is how the loaded result files store them. The result is
    ``(new - base) / (1 - base)``: 0 means no change, 1 means the remaining
    gap to a perfect score is fully closed, and negative values mean ``new``
    is worse than ``base``.

    The value is returned as a fraction, not a percentage, because the plot
    helper ``barplot_add_class`` multiplies by 100 before drawing. The
    previous formula ``100*(new - base)/(100 - base)`` assumed scores on a
    0-100 scale; with fractional inputs it collapsed to roughly the plain
    difference ``new - base``.

    Raises ``ZeroDivisionError`` for plain Python numbers when ``base == 1``
    (no headroom left).
    """
    return (new - base) / (1 - base)

def plot_all_data(data, title1, data2=None, base="original"):
    """Plot per-split gains over ``base`` for each dataset, one subplot per transformer.

    For every dataset in ``data`` a figure is built with one row per entry
    of the module-level ``transformers`` mapping. Each row shows, for every
    key in the module-level ``splits`` mapping, ``gain(score, baseline)``
    where both numbers are element 0 of ``results[dataset][transformer][split]``
    and the baseline is taken from the ``base`` split.

    Parameters
    ----------
    data : dict
        ``dataset -> transformer -> split -> sequence``; element 0 is used as
        the score.
    title1 : str
        Accepted for backward compatibility; not used by the function.
    data2 : dict, optional
        Second result set with the same layout. When given, the first series
        is labelled "F1 Macro" and this one "Accuracy"; otherwise the single
        series is added with no name.
    base : str, default "original"
        Split key whose score is the baseline.

    Side effects: writes ``gains_<dataset>.pdf`` to the working directory and
    displays each figure.
    """
    def _split_gains(results):
        """Return (labels, gains) for every configured split of one transformer."""
        baseline = results[base][0]
        labels = list(splits.values())
        gains = [gain(results[split][0], baseline) for split in splits]
        return labels, gains

    # One subplot row per configured transformer (previously hard-coded to 3).
    n_rows = len(transformers)

    for dataset, per_transformer in data.items():
        fig = make_subplots(rows=n_rows, cols=1, subplot_titles=list(transformers.values()))

        for row, transformer in enumerate(transformers, start=1):
            position = dict(row=row, col=1)
            # Only the first row contributes legend entries, to avoid duplicates.
            in_legend = row == 1

            labels, gains = _split_gains(per_transformer[transformer])
            fig = barplot_add_class(
                fig,
                gains,
                labels,
                "F1 Macro" if data2 is not None else None,
                curr_dict=position,
                feat_dict=dict(showlegend=in_legend, marker_color='#636EFA'),
            )

            if data2 is not None:
                labels, gains = _split_gains(data2[dataset][transformer])
                barplot_add_class(
                    fig,
                    gains,
                    labels,
                    "Accuracy",
                    curr_dict=position,
                    feat_dict=dict(showlegend=in_legend, marker_color="#EF553B"),
                )

        fig.update_layout(
            font=dict(
                family="Times New Roman",
                size=20,
            ),
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=1.02,
                xanchor="right",
                x=1
            ),
            width=900,
            height=900
        )
        for row in range(1, n_rows + 1):
            fig.update_yaxes(title_text="Gain [%]", row=row, col=1)
        # Subplot titles are annotations; match their size to the body font.
        fig.update_annotations(font_size=20)
        fig.write_image(f"gains_{dataset}.pdf")
        fig.show()


def plot_all_data_class(data, metric, base="original"):
    """Plot per-class gains over ``base`` for each dataset, one subplot per transformer.

    For every dataset in ``data`` a figure is built with one row per entry
    of the module-level ``transformers`` mapping. Element 0 of
    ``data[dataset][transformer][split]`` is iterated as a sequence of
    per-class scores; each is compared with the score at the same position
    in the ``base`` split via ``gain``. Every class becomes its own bar
    series, labelled through ``classes[dataset]["class_<i>"]``.

    Parameters
    ----------
    data : dict
        ``dataset -> transformer -> split -> sequence`` whose element 0 is a
        sequence of per-class scores.
    metric : str
        Only used in the output file name.
    base : str, default "original"
        Split key whose scores are the baseline.

    Side effects: writes ``gains_<dataset>_perclass_<metric>.pdf`` to the
    working directory and displays each figure.
    """
    # One colour per class index; supports up to four classes.
    colors = ['#636EFA', "#EF553B", '#00CC96', '#AB63FA']
    # One subplot row per configured transformer (previously hard-coded to 3).
    n_rows = len(transformers)

    for dataset, per_transformer in data.items():
        fig = make_subplots(rows=n_rows, cols=1, subplot_titles=list(transformers.values()))

        for row, transformer in enumerate(transformers, start=1):
            results = per_transformer[transformer]
            baseline = results[base][0]
            labels = list(splits.values())
            # Rows follow ``splits`` order, columns are class indices.
            gains = np.array([
                [gain(score, baseline[k]) for k, score in enumerate(results[split][0])]
                for split in splits
            ])
            # Take the class count from the gain matrix itself rather than
            # from a variable leaked out of the split loop.
            n_classes = gains.shape[1] if gains.ndim == 2 else 0

            position = dict(row=row, col=1)
            for c in range(n_classes):
                feat_dict = dict(
                    # Only the first row contributes legend entries.
                    showlegend=(row == 1),
                    marker_color=colors[c],
                )
                fig = barplot_add_class(
                    fig,
                    gains[:, c],
                    labels,
                    classes[dataset][f"class_{c}"],
                    curr_dict=position,
                    feat_dict=feat_dict,
                )

        fig.update_layout(
            font=dict(
                family="Times New Roman",
                size=20,
            ),
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=1.02,
                xanchor="right",
                x=1
            ),
            width=900,
            height=900
        )
        for row in range(1, n_rows + 1):
            fig.update_yaxes(title_text="Gain [%]", row=row, col=1)
        # Subplot titles are annotations; match their size to the body font.
        fig.update_annotations(font_size=20)
        fig.write_image(f"gains_{dataset}_perclass_{metric}.pdf")
        fig.show()

def one_plot_to_rule_them_all(data_f1, data_acc, base="original"):
    """Draw baseline-vs-best bars per transformer for "persent" and "multiemo".

    Each dataset gets a two-row figure: row 1 uses ``data_f1`` (metric key
    "f1 macro"), row 2 uses ``data_acc`` (metric key "accuracy"). For every
    entry of the module-level ``transformers`` mapping, the "Baseline" bar is
    element 0 of the ``base`` split and the "Best" bar is element 0 of the
    split named in ``one_result[dataset][metric][transformer]``.

    Side effects: writes ``baseline_best_<dataset>.pdf`` to the working
    directory and displays each figure.
    """
    baseline_color, best_color = '#636EFA', "#EF553B"
    metric_sources = (("f1 macro", data_f1), ("accuracy", data_acc))

    for dataset in ("persent", "multiemo"):
        fig = make_subplots(rows=2, cols=1)

        for row, (metric, data) in enumerate(metric_sources, start=1):
            results = data[dataset]
            # (display name, baseline score, best score) per transformer.
            records = [
                (
                    label,
                    results[model][base][0],
                    results[model][one_result[dataset][metric][model]][0],
                )
                for model, label in transformers.items()
            ]
            names = [record[0] for record in records]
            baselines = [record[1] for record in records]
            bests = [record[2] for record in records]

            position = dict(row=row, col=1)
            # Legend entries come from the first row only.
            in_legend = row == 1
            fig = barplot_add_class(
                fig, baselines, names, "Baseline",
                curr_dict=position,
                feat_dict=dict(showlegend=in_legend, marker_color=baseline_color),
            )
            fig = barplot_add_class(
                fig, bests, names, "Best",
                curr_dict=position,
                feat_dict=dict(showlegend=in_legend, marker_color=best_color),
            )

            if metric == "accuracy":
                axis_title = "Accuracy [%]"
            else:
                axis_title = "F1 macro [%]"
            # Fixed y-range per dataset.
            if dataset == "multiemo":
                y_range = [60, 100]
            else:
                y_range = [20, 60]
            fig.update_yaxes(title_text=axis_title, range=y_range, row=row, col=1)

        fig.update_layout(
            font=dict(family="Times New Roman", size=20),
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=1.02,
                xanchor="right",
                x=1,
            ),
            width=900,
            height=500,
        )
        fig.update_annotations(font_size=20)
        fig.write_image(f"baseline_best_{dataset}.pdf")
        fig.show()

In [78]:
# Baseline-vs-best figures; data0 is loaded from the accuracy results file.
one_plot_to_rule_them_all(data_f1=data_f1, data_acc=data0, base="original")

In [62]:
# plot_all_data(data_f1, "f1 macro - accuracy", data2=data0)

In [33]:
# plot_all_data_class(data_f1_class, "f1 macro")

In [58]:
# plot_all_data(data0, "accuracy")

In [34]:
# plot_all_data_class(data_class, "accuracy")