In [1]:
import json
import numpy as np
import plotly.graph_objects as go

In [19]:
def _load_trainer_state(run_dir: str) -> dict:
    """Parse ``../<run_dir>/trainer_state.json`` and return its content.

    The file is opened in a ``with`` block so the handle is closed as soon as
    parsing finishes (the bare ``json.load(open(...))`` form left every handle
    open until garbage collection).

    Args:
        run_dir: Name of the run directory, relative to the parent of the
            notebook's working directory.

    Returns:
        The decoded JSON document; later cells index it with ``['log_history']``.

    Raises:
        FileNotFoundError: If the run directory or the JSON file is missing.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Read as UTF-8 explicitly instead of relying on the platform's default encoding.
    with open(f"../{run_dir}/trainer_state.json", encoding="utf-8") as handle:
        return json.load(handle)


# Exp3 runs; the number in the directory name distinguishes the "Init" setting.
exp3_1000 = _load_trainer_state("Exp3_1000_Init")
exp3_100 = _load_trainer_state("Exp3_100_Init")
exp3_10 = _load_trainer_state("Exp3_10_Init")
exp3_0 = _load_trainer_state("Exp3_0_Init")

# Single-mode runs.
explore = _load_trainer_state("Explore_Only")
exploit = _load_trainer_state("Exploit_Only")

# Direct fine-tuning run.
dft = _load_trainer_state("Direct_Finetuning")

In [3]:
# Trace names used by plot() when a log entry carries exactly len(datasets)
# values. Order matters: the i-th value of an entry is labelled with the i-th name.
datasets = [
    'adversarial_qa/dbidaf',
    'ag_news',
    'amazon_polarity',
    'cnn_dailymail/3.0.0',
    'common_gen',
    'cos_e/v1.11',
    'glue/mrpc',
    'kilt_tasks/hotpotqa',
]

In [4]:
# (chart title, y-axis label) for columns whose log entries hold a mapping
# with one value per plotted line.
_PER_DATASET_COLUMNS = {
    "cumulative_estimated_reward": ("Cumulative Estimated Reward", "Cumulative Estimated Reward"),
    "probabilities": ("Probabilities", "Probability"),
    "samples_seen_per_dataset": ("Samples Seen Per Dataset", "Samples Seen"),
}

# Chart title (reused as the y-axis label) for columns holding one scalar per entry.
_SCALAR_COLUMNS = {"eval_accuracy": "Accuracy", "loss": "Loss"}

_KNOWN_STRATEGIES = ("exp3_1000", "exp3_100", "exp3_10", "exp3_0", "explore", "exploit", "dft")

_DATASETS_WITH_COPA = [
    'adversarial_qa/dbidaf', 'ag_news', 'amazon_polarity', 'cnn_dailymail/3.0.0',
    'common_gen', 'copa', 'cos_e/v1.11', 'glue/mrpc', 'kilt_tasks/hotpotqa',
]

# Trace names to try when an entry does not carry len(datasets) values.
_STRATEGY_DATASETS = {
    "explore": _DATASETS_WITH_COPA,
    "exploit": _DATASETS_WITH_COPA,
    "dft": ['copa'],
}


def _collect_series(data, column):
    """Return ``(steps, values)`` taken from the entries that contain ``column``."""
    matching = [entry for entry in data if column in entry]
    return [entry["step"] for entry in matching], [entry[column] for entry in matching]


def _dataset_labels(n_series, strategy, logged_keys):
    """Choose one trace name per plotted series."""
    if n_series == len(datasets):
        return list(datasets)

    candidates = _STRATEGY_DATASETS.get(strategy, [])
    if len(candidates) >= n_series:
        return candidates[:n_series]

    # No hard-coded list fits. This used to end in a NameError (exp3_* strategies)
    # or an IndexError (list too short); label with the logged keys instead.
    return [str(key) for key in logged_keys]


def _plot_per_dataset(data, target_column, strategy):
    """Build a new figure with one line per value logged under ``target_column``."""
    title, yaxis_title = _PER_DATASET_COLUMNS[target_column]
    steps, mappings = _collect_series(data, target_column)

    values = np.array([list(mapping.values()) for mapping in mappings])
    if values.ndim != 2:
        raise ValueError(
            f"entries of {target_column} do not all hold the same number of values"
        )

    labels = _dataset_labels(values.shape[1], strategy, list(mappings[0].keys()))

    fig = go.Figure()
    for series_index, label in enumerate(labels):
        fig.add_trace(
            go.Scatter(
                x=steps,
                y=values[:, series_index],
                mode="lines",
                name=f"{label}",
                # thick solid lines so the exported image stays readable
                line=dict(width=3, dash="solid"),
            )
        )

    fig.update_layout(title=title, xaxis_title="Steps", yaxis_title=yaxis_title)
    return fig


def _plot_scalar(data, target_column, fig, strategy):
    """Append one line, named after ``strategy``, for a scalar column."""
    steps, values = _collect_series(data, target_column)

    if fig is None:
        fig = go.Figure()

    fig.add_trace(
        go.Scatter(x=steps, y=np.array(values), mode="lines", name=strategy)
    )

    label = _SCALAR_COLUMNS[target_column]
    fig.update_layout(title=label, xaxis_title="Steps", yaxis_title=label)
    return fig


def plot(data: list, target_column: str = "cumulative_estimated_reward", fig = None, strategy: str = "exp3_10"):
    """Plot one logged column of a run against the training step.

    Args:
        data: List of log-entry dicts. Entries that lack ``target_column`` are
            skipped; the ones that have it must also have a ``"step"`` key.
        target_column: One of ``"cumulative_estimated_reward"``,
            ``"probabilities"``, ``"samples_seen_per_dataset"`` (each entry
            holds a mapping; one line is drawn per value) or
            ``"eval_accuracy"``, ``"loss"`` (each entry holds one number; a
            single line is drawn).
        fig: Existing figure to add the line to. Only honoured for
            ``"eval_accuracy"`` and ``"loss"``; the per-dataset columns always
            start a new figure.
        strategy: Name of the run. Used as the trace name for scalar columns
            and to choose trace names for per-dataset columns when the number
            of logged values differs from ``len(datasets)``.

    Returns:
        The plotly figure holding the new trace(s).

    Raises:
        ValueError: If ``target_column`` or ``strategy`` is not supported, if
            ``data`` is empty, or if no entry contains ``target_column``
            (``"eval_accuracy"`` is exempt and yields an empty line instead).
    """
    # Exact membership test: the former `in ("eval_accuracy")` was a substring
    # check on a plain string, and unhandled columns silently returned None.
    if target_column not in _PER_DATASET_COLUMNS and target_column not in _SCALAR_COLUMNS:
        raise ValueError(f"target_column {target_column} is not supported")

    if strategy not in _KNOWN_STRATEGIES:
        raise ValueError(f"strategy {strategy} not in data")

    if not data:
        raise ValueError("data is empty")

    # Look at every entry, not only the first one: the collection step below
    # already tolerates entries without the column.
    if target_column != "eval_accuracy" and not any(target_column in entry for entry in data):
        raise ValueError(f"target_column {target_column} not in data")

    if target_column in _PER_DATASET_COLUMNS:
        return _plot_per_dataset(data, target_column, strategy)
    return _plot_scalar(data, target_column, fig, strategy)

In [7]:
# Export the three per-dataset charts of the exp3_1000 run, one PNG per column.
for column, suffix in (
    ("cumulative_estimated_reward", "cer"),
    ("probabilities", "probs"),
    ("samples_seen_per_dataset", "samples-seen"),
):
    fig = plot(exp3_1000['log_history'], target_column=column)
    fig.write_image(f"assets/exp3_1000-{suffix}.png", width=1200, height=500, scale=10)

In [15]:
# Export the three per-dataset charts of the exp3_100 run, one PNG per column.
for column, suffix in (
    ("cumulative_estimated_reward", "cer"),
    ("probabilities", "probs"),
    ("samples_seen_per_dataset", "samples-seen"),
):
    fig = plot(exp3_100['log_history'], target_column=column)
    fig.write_image(f"assets/exp3_100-{suffix}.png", width=1200, height=500, scale=10)

In [8]:
# Export the three per-dataset charts of the exp3_10 run, one PNG per column.
for column, suffix in (
    ("cumulative_estimated_reward", "cer"),
    ("probabilities", "probs"),
    ("samples_seen_per_dataset", "samples-seen"),
):
    fig = plot(exp3_10['log_history'], target_column=column)
    fig.write_image(f"assets/exp3_10-{suffix}.png", width=1200, height=500, scale=10)

In [9]:
# Export the three per-dataset charts of the exp3_0 run, one PNG per column.
# Same order as before: reward, samples seen, then probabilities.
for column, suffix in (
    ("cumulative_estimated_reward", "cer"),
    ("samples_seen_per_dataset", "samples-seen"),
    ("probabilities", "probs"),
):
    fig = plot(exp3_0['log_history'], target_column=column)
    fig.write_image(f"assets/exp3_0-{suffix}.png", width=1200, height=500, scale=10)

In [23]:
# Chart the samples_seen_per_dataset column of the explore run and save it as a PNG.
fig = plot(
    explore['log_history'],
    strategy="explore",
    target_column="samples_seen_per_dataset",
)
fig.write_image(
    "assets/explore_only-samples-seen.png",
    width=1200,
    height=500,
    scale=10,
)

In [24]:
# Chart the samples_seen_per_dataset column of the exploit run and save it as a PNG.
fig = plot(
    exploit['log_history'],
    strategy="exploit",
    target_column="samples_seen_per_dataset",
)
fig.write_image(
    "assets/exploit_only-samples-seen.png",
    width=1200,
    height=500,
    scale=10,
)

In [25]:
# Chart the samples_seen_per_dataset column of the dft run and save it as a PNG.
fig = plot(
    dft['log_history'],
    strategy="dft",
    target_column="samples_seen_per_dataset",
)
fig.write_image(
    "assets/dft-samples-seen.png",
    width=1200,
    height=500,
    scale=10,
)

In [21]:
# Overlay the eval_accuracy curve of every run on one figure: the first call
# creates the figure (fig=None) and each later call adds a line to it.
fig = None
for state, label in (
    (exp3_1000, "exp3_1000"),
    (exp3_100, "exp3_100"),
    (exp3_10, "exp3_10"),
    (exp3_0, "exp3_0"),
    (explore, "explore"),
    (exploit, "exploit"),
    (dft, "dft"),
):
    fig = plot(state['log_history'], target_column="eval_accuracy", fig=fig, strategy=label)

fig.write_image("assets/accuracies.png", width=1200, height=500, scale=10)

In [22]:
# Overlay the loss curve of every run on one figure: the first call creates
# the figure (fig=None) and each later call adds a line to it.
fig = None
for state, label in (
    (exp3_1000, "exp3_1000"),
    (exp3_100, "exp3_100"),
    (exp3_10, "exp3_10"),
    (exp3_0, "exp3_0"),
    (explore, "explore"),
    (exploit, "exploit"),
    (dft, "dft"),
):
    fig = plot(state['log_history'], target_column="loss", fig=fig, strategy=label)

fig.write_image("assets/losses.png", width=1200, height=500, scale=10)