In [29]:
import pandas as pd
import plotly.graph_objects as go

def get_input(filepath, x_col, y_col,
              group_cols=('model', 'active_learning_strategy'), sep='_'):
    """Load a results CSV and split it into one ``(name, x, y)`` series per run.

    A "run" is one combination of values of ``group_cols`` (by default the
    model and the active-learning strategy).

    Args:
        filepath: Path or file-like object accepted by ``pandas.read_csv``.
        x_col: Column whose values become each run's x series.
        y_col: Column whose values become each run's y series.
        group_cols: Columns whose value combination identifies a run.
        sep: Separator used to join a run's group values into its name.

    Returns:
        A list of ``(name, x, y)`` tuples, one per group, in the sorted group
        order produced by ``DataFrame.groupby``. ``x`` and ``y`` are plain
        lists in the CSV's row order within that group.
    """
    df = pd.read_csv(filepath)
    data = []
    # NOTE(review): groupby's default dropna=True silently drops rows whose
    # group values are NaN (e.g. a run with no strategy label) — confirm intended.
    for name, group in df.groupby(list(group_cols)):
        # Depending on the pandas version, a single grouping column may yield
        # a scalar key instead of a 1-tuple; normalise to a tuple.
        keys = name if isinstance(name, tuple) else (name,)
        # str() so non-string group values (ints, bools) don't make join() fail.
        namestr = sep.join(str(key) for key in keys)
        data.append((namestr, group[x_col].to_list(), group[y_col].to_list()))
    return data

def accept_filepath(func):
    """Decorate a function that takes run data so it can be called with a CSV path.

    ``func`` must accept the list of ``(name, x, y)`` tuples returned by
    ``get_input`` as its first argument. The returned wrapper instead takes
    ``filepath`` plus the required keyword-only ``x_col`` and ``y_col``. It
    loads the data with ``get_input`` and forwards all remaining positional and
    keyword arguments to ``func`` unchanged.
    """
    from functools import wraps

    # wraps() keeps func's name/docstring and exposes it as __wrapped__.
    @wraps(func)
    def wrapper(filepath, *args, x_col, y_col, **kwargs):
        # x_col/y_col are keyword-only and required: omitting one raises a
        # clear TypeError instead of an opaque KeyError from dict.pop().
        data = get_input(filepath, x_col, y_col)
        return func(data, *args, **kwargs)

    return wrapper


In [30]:
@accept_filepath
def make_plot_lines(data, *args, title='Line Plot', xaxis_title='X-axis',
                    yaxis_title='Y-axis', **kwargs):
    """Draw one line trace per run and display the figure.

    Because of ``@accept_filepath``, callers invoke this as
    ``make_plot_lines(filepath, *args, x_col=..., y_col=..., **kwargs)``. The
    decorator loads ``data`` from the CSV before this body runs.

    Args:
        data: List of ``(name, x, y)`` tuples, one per run.
        *args: Forwarded positionally to every ``go.Scatter`` trace.
        title: Figure title.
        xaxis_title: Label of the x axis.
        yaxis_title: Label of the y axis.
        **kwargs: Extra ``go.Scatter`` properties applied to every trace. These
            may override the default ``mode='lines'``.
    """
    layout = go.Layout(title=title,
                       xaxis=dict(title=xaxis_title),
                       yaxis=dict(title=yaxis_title))
    fig = go.Figure(layout=layout)
    # Defaults first so caller-supplied kwargs (e.g. mode='lines+markers') win.
    trace_kwargs = {'mode': 'lines', **kwargs}
    for name, x, y in data:
        fig.add_trace(go.Scatter(*args, x=x, y=y, name=name, **trace_kwargs))
    # Returns None deliberately: returning fig from a notebook cell's last
    # call would render the figure a second time.
    fig.show()



# Plot average F1 against dataset size for each refined results file, in order.
for results_csv in ('./results/results_lstm_refined.csv',
                    './results/results_bert_refined.csv'):
    make_plot_lines(results_csv, x_col='dataset_size', y_col='f1_avg')