# Fine-tuning for parenthesis balancing

## Setup

In [17]:
# Directory scanned for saved result files; the relative path is resolved
# against the notebook's current working directory.
RESULTS_DATA_DIR = "../results"
# Only files in RESULTS_DATA_DIR whose names start with this prefix are loaded.
FINE_TUNE_PREFIX = "fine_tune_paren_bal"

In [18]:
import json
import os
from collections import OrderedDict

import numpy as np

import plotly.express as px
import plotly.graph_objects as go
import plotly.colors as colors

from rich.table import Table
from rich.console import Console

In [19]:
import plotly.io as pio
# Set the default Plotly renderer used by the fig.show() calls further down.
# NOTE(review): the "+"-joined name presumably enables both the Colab and the
# VS Code renderers at once -- confirm against the Plotly renderers docs.
pio.renderers.default = "colab+vscode"

# Shared rich console, used below to print the accuracy table.
console = Console()

## Results

In [20]:
# Load every JSON result file that sits directly inside RESULTS_DATA_DIR and
# whose name starts with FINE_TUNE_PREFIX.
#
# os.listdir() returns entries in arbitrary order. The listing is sorted so
# that `data` has the same order on every run and on every platform; this
# matters because the aggregation below keeps only the last result seen for a
# given (optimizer, learning rate) pair.
data = []
for filename in sorted(os.listdir(RESULTS_DATA_DIR)):
    if not filename.startswith(FINE_TUNE_PREFIX):
        continue
    filepath = os.path.join(RESULTS_DATA_DIR, filename)
    if not os.path.isfile(filepath):
        # Skip sub-directories that happen to share the prefix.
        continue
    # Decode explicitly as UTF-8 instead of relying on the locale default.
    with open(filepath, encoding="utf-8") as fp:
        data.append(json.load(fp))

In [21]:
# Index the loaded runs by optimizer and then by learning rate.
#   val_accuracy[optimizer][lr]     -> the run's "val_accuracy" entry
#   train_accuracies[optimizer][lr] -> the run's "train_accuracies" entry
# If two runs share an (optimizer, lr) pair, the later one in `data` wins.
learning_rates = []
val_accuracy = {}
train_accuracies = {}

for run in data:
    params = run["parameters"]
    opt_name = params["optimizer"]
    rate = params["learning_rate"]

    # Remember each distinct learning rate once, across all optimizers.
    if rate not in learning_rates:
        learning_rates.append(rate)

    val_by_rate = val_accuracy.setdefault(opt_name, {})
    train_by_rate = train_accuracies.setdefault(opt_name, {})
    val_by_rate[rate] = run["val_accuracy"]
    train_by_rate[rate] = run["train_accuracies"]

# Re-key every per-optimizer mapping in ascending learning-rate order so that
# later iteration (tables, plots) walks the rates from smallest to largest.
for opt_name in list(val_accuracy):
    unsorted_val = val_accuracy[opt_name]
    unsorted_train = train_accuracies[opt_name]
    rates_ascending = sorted(unsorted_val)
    val_accuracy[opt_name] = OrderedDict(
        (rate, unsorted_val[rate]) for rate in rates_ascending
    )
    train_accuracies[opt_name] = OrderedDict(
        (rate, unsorted_train[rate]) for rate in rates_ascending
    )

learning_rates.sort()

In [22]:
# Print one row per learning rate and one accuracy column per optimizer.
# A "-" marks combinations for which no result was loaded.
optimizer_names = list(val_accuracy)

table = Table(
    title="Validation accuracies by learning rate",
    show_header=True,
    header_style="bold",
)
table.add_column("Learning rate", justify="right")
for optimizer_name in optimizer_names:
    table.add_column(f"Accuracy ({optimizer_name})", justify="right")

for learning_rate in learning_rates:
    cells = [
        f"{val_accuracy[optimizer_name][learning_rate]:.2%}"
        if learning_rate in val_accuracy[optimizer_name]
        else "-"
        for optimizer_name in optimizer_names
    ]
    table.add_row(f"{learning_rate}", *cells)

console.print(table)

In [23]:
# One line per optimizer: validation accuracy as a function of learning rate.
traces = [
    go.Scatter(
        x=list(accuracies_by_lr),
        y=list(accuracies_by_lr.values()),
        name=optimizer,
    )
    for optimizer, accuracies_by_lr in val_accuracy.items()
]

layout = go.Layout(
    title="Validation Accuracy vs learning rate",
    xaxis=dict(title="Learning Rate"),
    yaxis=dict(title="Validation Accuracy"),
    showlegend=True,
)

fig = go.Figure(data=traces, layout=layout)
# Put the learning rates on a logarithmic x axis.
fig.update_xaxes(type="log")

fig.show()

In [24]:
# Training-accuracy curves, one trace per (optimizer, learning rate) pair.
# A dropdown with one button per optimizer shows only that optimizer's traces.
num_datalines = sum(map(len, train_accuracies.values()))

traces = []
buttons = []

# Index of the first trace belonging to the optimizer currently being added.
trace_counter = 0
for optimizer, accuracies_by_lr in train_accuracies.items():
    group_size = len(accuracies_by_lr)
    # Only the first optimizer's curves start visible, matching `active=0`.
    initially_visible = trace_counter == 0

    traces.extend(
        go.Scatter(
            x=list(range(len(curve))),
            y=curve,
            name=f"lr={lr}",
            visible=initially_visible,
        )
        for lr, curve in accuracies_by_lr.items()
    )

    # True exactly for this optimizer's contiguous slice of traces.
    visibility_mask = [
        trace_counter <= index < trace_counter + group_size
        for index in range(num_datalines)
    ]
    buttons.append(
        dict(
            label=optimizer,
            method="update",
            args=[{"visible": visibility_mask}],
        )
    )
    trace_counter += group_size

dropdown = dict(buttons=buttons, active=0, showactive=True)
layout = go.Layout(
    title="Train Accuracy vs Epochs, by learning rate",
    xaxis=dict(title="Epochs"),
    yaxis=dict(title="Train Accuracy"),
    updatemenus=[dropdown],
)
fig = go.Figure(data=traces, layout=layout)

fig.show()