# Testing the power of solo-agent architectures for graph isomorphism

# Setup

In [1]:
# Hyperparameter values a run must match exactly (compared against the run's
# ``combo`` entry) in order to be included in the analysis below.
RESULTS_SELECT = {
    "batch_size": 20000,
    "d_decider": 16,
    "d_gnn": 8,
    "dataset_name": "er10000",
    "learning_rate": 0.003,
    "num_epochs": 1000,
    "scheduler_factor": 0.5,
    "scheduler_patience": 2000,
}

# Sigma passed to ``gaussian_filter1d`` when smoothing the plotted curves;
# a value of 0 (or less) disables smoothing entirely.
PLOT_SMOOTHING = 1

In [2]:
import json
import os

import numpy as np

from scipy.ndimage import gaussian_filter1d

import plotly.graph_objs as go
import plotly.express as px

from pvg.constants import GI_SOLO_AGENTS_RESULTS_DATA_DIR

## Load data

In [3]:
import warnings

# Selected runs, split by whether the encoder was frozen and keyed by seed.
results = {"frozen": {}, "not_frozen": {}}

# Sort the listing so runs are always loaded in the same order, independent of
# the filesystem's directory ordering.
for filename in sorted(os.listdir(GI_SOLO_AGENTS_RESULTS_DATA_DIR)):
    filepath = os.path.join(GI_SOLO_AGENTS_RESULTS_DATA_DIR, filename)

    if not os.path.isfile(filepath):
        continue

    # Read the data (JSON is UTF-8; don't depend on the locale's encoding)
    with open(filepath, encoding="utf-8") as fp:
        data = json.load(fp)

    # Keep only runs whose hyperparameters match every entry of
    # RESULTS_SELECT. Using `.get` means a run that lacks one of the keys is
    # simply not selected, instead of aborting the whole load with KeyError.
    combo = data["combo"]
    if any(combo.get(key) != value for key, value in RESULTS_SELECT.items()):
        continue

    # Turn the lists into numpy arrays
    data = {
        key: np.array(value) if isinstance(value, list) else value
        for key, value in data.items()
    }

    # Add the result to the correct dictionary
    group = "frozen" if combo["freeze_encoder"] else "not_frozen"
    seed = combo["seed"]
    if seed in results[group]:
        # Two files matching the same selection and seed: the later one wins,
        # but say so rather than silently dropping a run.
        warnings.warn(
            f"Duplicate {group} run for seed {seed!r} in {filename!r}; "
            "replacing the previously loaded run"
        )
    results[group][seed] = data

## Analysis

In [4]:
def _stack_runs(temp_results, metric):
    """Stack one metric from every selected seed into a single array.

    Args:
        temp_results: Mapping from seed to a loaded result dict.
        metric: Key of the metric inside each result dict (presumably a 1-D
            per-epoch series — TODO confirm against the results writer).

    Returns:
        An array with one row per seed, so ``axis=0`` aggregates over seeds.

    Raises:
        ValueError: If the runs have different lengths. ``np.stack`` reports
            this clearly instead of building a ragged array.
    """
    return np.stack([data[metric] for data in temp_results.values()])


def _smooth(series):
    """Gaussian-smooth ``series`` with ``sigma=PLOT_SMOOTHING``.

    Returns ``series`` unchanged when ``PLOT_SMOOTHING <= 0``.
    """
    if PLOT_SMOOTHING > 0:
        return gaussian_filter1d(series, sigma=PLOT_SMOOTHING)
    return series


# Per-setting ("frozen" / "not_frozen") and per-agent ("prover" / "verifier")
# statistics across seeds, optionally smoothed for plotting.
train_losses_mean = {"frozen": {}, "not_frozen": {}}
train_losses_std = {"frozen": {}, "not_frozen": {}}
train_accuracies_mean = {"frozen": {}, "not_frozen": {}}
train_accuracies_std = {"frozen": {}, "not_frozen": {}}
train_eq_accuracies_mean = {"frozen": {}, "not_frozen": {}}

for temp, temp_results in results.items():
    if not temp_results:
        # Without this guard numpy/scipy fail later with an opaque error.
        raise ValueError(
            f"No results matched RESULTS_SELECT for the {temp!r} runs; "
            "check the selection criteria and the results directory"
        )

    for agent in ("prover", "verifier"):
        losses = _stack_runs(temp_results, f"train_losses_{agent}")
        accuracies = _stack_runs(temp_results, f"train_accuracies_{agent}")
        eq_accuracies = _stack_runs(
            temp_results, f"train_encoder_eq_accuracies_{agent}"
        )

        # Aggregate over seeds first, then smooth the aggregated curves.
        train_losses_mean[temp][agent] = _smooth(np.mean(losses, axis=0))
        train_losses_std[temp][agent] = _smooth(np.std(losses, axis=0))
        train_accuracies_mean[temp][agent] = _smooth(np.mean(accuracies, axis=0))
        train_accuracies_std[temp][agent] = _smooth(np.std(accuracies, axis=0))
        train_eq_accuracies_mean[temp][agent] = _smooth(
            np.mean(eq_accuracies, axis=0)
        )

## Presentation

In [5]:
# https://stackoverflow.com/a/61501980

# Plotly's default qualitative colour sequence (a list of hex colour strings).
# Each agent takes one colour, shared by its shaded band and its line traces.
colours = px.colors.qualitative.Plotly

# convert plotly hex colors to rgba to enable transparency adjustments
def hex_rgba(hex, transparency):
    """Convert a hex colour string to a Plotly ``rgba(...)`` colour string.

    Args:
        hex: Colour as ``"#rrggbb"`` or the shorthand ``"#rgb"``; the leading
            ``#`` is optional. (The parameter name shadows the builtin ``hex``
            and is kept only for backward compatibility.)
        transparency: Alpha value, inserted verbatim as the fourth component.

    Returns:
        A string such as ``"rgba(99, 110, 250, 0.2)"``.

    Raises:
        ValueError: If ``hex`` is not a 3- or 6-digit hexadecimal colour.
    """
    digits = hex.lstrip("#")
    if len(digits) == 3:
        # Expand shorthand: "abc" -> "aabbcc".
        digits = "".join(ch * 2 for ch in digits)
    if len(digits) != 6:
        # Without this check a 5-digit input would quietly parse a one-digit
        # final channel and return the wrong colour.
        raise ValueError(f"Expected a 3- or 6-digit hex colour, got {hex!r}")
    red, green, blue = (int(digits[i : i + 2], 16) for i in (0, 2, 4))
    # str() of the tuple gives the "(r, g, b, a)" part of the CSS colour.
    return f"rgba{(red, green, blue, transparency)}"

In [6]:
fig = go.Figure()

# Number of traces added for each encoder setting. Traces are added group by
# group (full network first, then frozen encoder), and the dropdown below
# toggles visibility one whole group at a time.
num_traces = {}

for temp in ("not_frozen", "frozen"):
    num_traces_before = len(fig.data)
    initially_visible = temp == "not_frozen"

    for agent, colour in zip(["prover", "verifier"], colours):
        losses_mean = train_losses_mean[temp][agent]
        losses_std = train_losses_std[temp][agent]
        epochs = np.arange(len(losses_mean))

        # Closed polygon: the upper edge (mean + std) left to right, then the
        # lower edge (mean - std) right to left, so `fill="toself"` shades the
        # +/- 1 std band.
        x = np.concatenate((epochs, epochs[::-1]))
        y = np.concatenate(
            (losses_mean + losses_std, (losses_mean - losses_std)[::-1])
        )

        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                fill="toself",
                fillcolor=hex_rgba(colour, 0.2),
                line=dict(color="rgba(255,255,255,0)"),
                hoverinfo="skip",
                showlegend=False,
                visible=initially_visible,
            )
        )

        fig.add_trace(
            go.Scatter(
                x=epochs,
                y=losses_mean,
                mode="lines",
                name=agent.capitalize(),
                line=dict(color=colour),
                visible=initially_visible,
            )
        )

    num_traces[temp] = len(fig.data) - num_traces_before

# Per-trace visibility masks for the two dropdown views.
visible_full = [True] * num_traces["not_frozen"] + [False] * num_traces["frozen"]
visible_frozen = [not visible for visible in visible_full]

# The two settings can have different numbers of selected seeds, so each view
# gets its own title and the dropdown switches it along with the traces.
title_full = f"Training Loss over {len(results['not_frozen'])} Seeds"
title_frozen = f"Training Loss over {len(results['frozen'])} Seeds"

fig.update_layout(
    title=title_full,
    xaxis_title="Epoch",
    yaxis_title="Loss",
    legend_title="Agent",
    updatemenus=[
        dict(
            buttons=[
                dict(
                    label="Full network",
                    method="update",
                    args=[{"visible": visible_full}, {"title.text": title_full}],
                ),
                dict(
                    label="Frozen encoder",
                    method="update",
                    args=[
                        {"visible": visible_frozen},
                        {"title.text": title_frozen},
                    ],
                ),
            ],
            active=0,
            showactive=True,
        )
    ],
)
fig.show()

In [7]:
fig = go.Figure()

# Number of traces added for each encoder setting. Traces are added group by
# group (full network first, then frozen encoder), and the dropdown below
# toggles visibility one whole group at a time.
num_traces = {}

for temp in ("not_frozen", "frozen"):
    num_traces_before = len(fig.data)
    initially_visible = temp == "not_frozen"

    for agent, colour in zip(["prover", "verifier"], colours):
        accs_mean = train_accuracies_mean[temp][agent]
        accs_std = train_accuracies_std[temp][agent]

        eq_accs_mean = train_eq_accuracies_mean[temp][agent]

        epochs = np.arange(len(accs_mean))

        # Closed polygon: the upper edge (mean + std) left to right, then the
        # lower edge (mean - std) right to left, so `fill="toself"` shades the
        # +/- 1 std band. The axis shows accuracy as a fraction formatted as a
        # percentage, so the band is clipped to [0, 1] rather than drawn
        # beyond 100% or below 0%.
        upper = np.clip(accs_mean + accs_std, 0.0, 1.0)
        lower = np.clip(accs_mean - accs_std, 0.0, 1.0)
        x = np.concatenate((epochs, epochs[::-1]))
        y = np.concatenate((upper, lower[::-1]))

        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                fill="toself",
                fillcolor=hex_rgba(colour, 0.2),
                line=dict(color="rgba(255,255,255,0)"),
                hoverinfo="skip",
                showlegend=False,
                visible=initially_visible,
            )
        )

        fig.add_trace(
            go.Scatter(
                x=epochs,
                y=accs_mean,
                mode="lines",
                name=agent.capitalize(),
                line=dict(color=colour),
                visible=initially_visible,
            )
        )

        # Mean of the `train_encoder_eq_accuracies_*` series, drawn dashed in
        # the agent's colour and labelled "(optimal)".
        fig.add_trace(
            go.Scatter(
                x=np.arange(len(eq_accs_mean)),
                y=eq_accs_mean,
                mode="lines",
                name=f"{agent.capitalize()} (optimal)",
                line=dict(color=colour, dash="dash"),
                visible=initially_visible,
            )
        )

    num_traces[temp] = len(fig.data) - num_traces_before

# Per-trace visibility masks for the two dropdown views.
visible_full = [True] * num_traces["not_frozen"] + [False] * num_traces["frozen"]
visible_frozen = [not visible for visible in visible_full]

# The two settings can have different numbers of selected seeds, so each view
# gets its own title and the dropdown switches it along with the traces.
title_full = f"Training Accuracy over {len(results['not_frozen'])} Seeds"
title_frozen = f"Training Accuracy over {len(results['frozen'])} Seeds"

fig.update_layout(
    title=title_full,
    xaxis_title="Epoch",
    yaxis_title="Accuracy",
    yaxis_tickformat=",.0%",
    legend_title="Agent",
    updatemenus=[
        dict(
            buttons=[
                dict(
                    label="Full network",
                    method="update",
                    args=[{"visible": visible_full}, {"title.text": title_full}],
                ),
                dict(
                    label="Frozen encoder",
                    method="update",
                    args=[
                        {"visible": visible_frozen},
                        {"title.text": title_frozen},
                    ],
                ),
            ],
            active=0,
            showactive=True,
        )
    ],
)

fig.show()