In [None]:
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt

from utils import load_world_sequences, plot_token_world, save_plot, load_model
from token_world import assign_token_colors
from paths import GRAPHS_DIR, DEVICE
from model import GP_model
from decoding_analysis.decoder_helpers import plot_decoder, load_results
from accuracy_analysis import average_metrics, accuracy_token, plot_distribution
from decoding_analysis.tuple_helpers import finalize_metrics, load_metric_files, merge_results, plot_pos_label_panels
from Intervention_analysis.intervention_helpers import compute_binned_accuracy, plot_binned_accuracy, analyze_changed_errors, plot_changed_errors_bar


In [None]:
# Inputs for the whole notebook: the evaluation worlds and the trained GRU.
WORLDS_FILE = "pentagon_worlds_internship.pkl"
N_WORLD_SEQUENCES = 500  # presumably the number of sequences to load — TODO confirm against load_world_sequences
CHECKPOINT_DIR = "/share/klab/sthorat/lventura/models/gru/no_dropout_jan26/checkpoints"

world_seqs = load_world_sequences(WORLDS_FILE, N_WORLD_SEQUENCES)

# Architecture hyper-parameters; presumably these must match the checkpoint
# that is loaded below — verify against the training configuration.
gru_hyperparams = dict(
    tokens_size=26,
    directions_size=2,
    embedding_size=200,
    hidden_size=512,
    output_size=26,
    n_layers=3,
    layer_norm=False,
    dropout=0.0,
)
model = load_model(GP_model(**gru_hyperparams), DEVICE, CHECKPOINT_DIR)
    

**Scene Example:**

In [None]:
# Use the world at index 5 (in the dict's key order) as the example scene.
world = list(world_seqs)[5]
colors = assign_token_colors(world.tokens)
plot_token_world(world, colors, "Tuple Scene", graphs_dir=GRAPHS_DIR, save=False)

**Accuracy of the Model per sequence step, averaged over 500 sequences:**

In [None]:
# Accuracy per timestep with a 95% confidence interval.
# `plt` / `sns` are imported in the notebook's first (imports) cell rather than
# here, so later cells no longer depend on this cell for their imports.

# Evaluate each world on its own by handing accuracy_token a one-entry mapping.
# A comprehension is used so the iteration variables do not leak into the
# notebook namespace (in particular, the `world` chosen for the scene example
# is no longer overwritten by this cell).
all_measurements = [
    accuracy_token(model, "gru", {seq_world: (tokens_seq, dirs_seq)})
    for seq_world, (tokens_seq, dirs_seq) in world_seqs.items()
]

# average_metrics returns three mappings keyed by timestep (they are indexed
# with the same keys below): mean, standard deviation and sample count.
means, stds, ns = average_metrics(all_measurements)
timesteps = sorted(means.keys())
mean_array = np.array([means[t] for t in timesteps])
std_array = np.array([stds[t] for t in timesteps])
n_array = np.array([ns[t] for t in timesteps])

# Normal-approximation 95% interval half-width: 1.96 * standard error.
ci_array = 1.96 * (std_array / np.sqrt(n_array))

mean_dict = dict(zip(timesteps, mean_array))
ci_dict = dict(zip(timesteps, ci_array))

plot_distribution(
    mean_dict,
    ci_dict,
    GRAPHS_DIR,
    title="Accuracy per timestep",
    ylabel="Accuracy",
    metric="Accuracy",
)

**Label & Position Decoding:**

In [None]:

def format_offset_label(decoder_type, metric):
    """Build the y-axis label for one decoder metric.

    The time offset is parsed from the underscore-separated metric name:
    for "Label" it is the last part (e.g. "+1"), for "Position" it is the
    second-to-last part with the "pos" prefix removed (e.g. "pos+1").

    Returns "<type> t" for offset 0, "<type> t+<p>" for positive offsets and
    "<type> t<p>" (the minus sign comes from the number) for negative ones.

    Raises ValueError for an unknown decoder type or an unparsable offset.
    """
    parts = metric.split("_")
    if decoder_type == "Label":
        offset = int(parts[-1].replace("+", ""))
    elif decoder_type == "Position":
        offset = int(parts[-2].replace("pos", "").replace("+", ""))
    else:
        raise ValueError(f"Unknown decoder type: {decoder_type!r}")

    if offset == 0:
        return f"{decoder_type} t"
    if offset > 0:
        return f"{decoder_type} t+{offset}"
    return f"{decoder_type} t{offset}"


# NOTE(review): the Position path used to start with "/" (filesystem root)
# while every other results file in this notebook is loaded relative to the
# working directory; it is made relative here for consistency — confirm.
decoder_results = {
    "Position": load_results("results/position_decoder_0114_2100.pkl"),
    "Label": load_results("results/label_decoder_35_0114_2041.pkl"),
}

models = ["baseline", "model"]
layers = [0, 1, 2]
model_lines = {"baseline": "--", "model": "-"}
cmap = plt.get_cmap("Paired")

# The loop variables are deliberately not called `type` / `results`: the
# former shadows the builtin and the latter would overwrite the dict being
# iterated, which made the cell fail when re-run.
for decoder_type, type_results in decoder_results.items():
    # Metric names are read from the results being plotted, because the two
    # decoder types are parsed with different name layouts (see
    # format_offset_label). Assumes every results dict has a "model" entry
    # keyed by metric name — TODO confirm.
    metrics = list(type_results["model"])

    # Same line style per model for every metric.
    linestyles = {
        metric: {name: model_lines[name] for name in models}
        for metric in metrics
    }
    # One colour per metric (every second entry of the colormap), shared by
    # both models; the models are told apart by line style only.
    colors = {
        metric: {name: cmap(i * 2)[:3] for name in models}
        for i, metric in enumerate(metrics)
    }
    ylabels = {
        metric: format_offset_label(decoder_type, metric)
        for metric in metrics
    }

    plot_decoder(
        results=type_results,
        layers=layers,
        metrics_to_plot=metrics,
        ylabels=ylabels,
        colors=colors,
        linestyles=linestyles,
        graphs_dir=GRAPHS_DIR,
        filename=f"{decoder_type}_decoding",
        title=f"{decoder_type} Decoding",
    )

NameError: name 'load_results' is not defined

**Tuple Decoding:**



In [None]:
# NOTE(review): none of the four settings below is passed to any call in this
# cell; they are kept only in case other cells read them — confirm and remove.
models = ["model", "baseline"]
layers = [0, 1, 2]
offsets = [(0,0), (1,1)]
allowed_offsets = {0, 1}

# Result files to combine, keyed by decoder kind.
paths = {
        "tuple":    "results/tuple_decoder_0115_1013.pkl",
        "combined":  "results/combined_label_position_decoder_0115_1032.pkl"
    }

# The former placeholder `results = {m: {} for m in models}` was dead code:
# it was overwritten by merge_results before ever being read.
loaded = load_metric_files(paths)
results = merge_results(loaded)
# Return value is not used; presumably finalize_metrics updates `results`
# in place — verify against tuple_helpers.
finalize_metrics(results)

fig = plot_pos_label_panels(results, model_name="model")
save_plot(fig,GRAPHS_DIR,"tuple_expectation")

**Token Change:**

In [None]:


# Intervention datasets; the later intervention cells read these names too.
changed_seqs = load_world_sequences("changed_35_200.pkl")
hidden_seqs_35 = load_world_sequences("hidden_35_100.pkl")
hidden_seqs_100 = load_world_sequences("hidden_100_100.pkl")
moved_k_seqs = load_world_sequences("always_k_sequences.pkl")

# Binned accuracy, split by the value of each token's `old_label` attribute.
# NOTE(review): this cell unpacks five values and uses `property_fn`, while the
# "Tuple Generalization" cell unpacks seven values and uses
# `property_fn1`/`property_fn2`; one of the two usages is probably stale —
# confirm against intervention_helpers.
timesteps, acc_hidden, acc_other, n_hidden, n_other = compute_binned_accuracy(
    model,
    "gru",
    changed_seqs,
    property_fn=lambda t: t.old_label,
)
plot_binned_accuracy(
    timesteps,
    acc_hidden,
    acc_other,
    n_hidden,
    n_other,
    label_modified="Changed location",
    label_other="Unchanged locations",
    title="Changed VS Unchanged Accuracy",
)


**Token Change Origin:**

In [None]:
# Error breakdown on the changed-token sequences for two analysis windows.
for phase, window in [("relearning", (35, 65)), ("learned", (180, 210))]:
    counts = analyze_changed_errors(model, changed_seqs, window)
    # Guard against an empty window to avoid dividing by zero.
    proportion = counts["total_errors"] / counts["total_steps"] if counts["total_steps"] else 0
    # `phase` and `proportion` were previously computed but never shown;
    # report them so each bar plot comes with its overall error rate.
    print(f"{phase} {window}: error proportion = {proportion:.3f}")
    plot_changed_errors_bar(counts, GRAPHS_DIR, analysis_window=window)


**Token Addition Error:**

In [None]:
# Same analysis for both hidden-token datasets, split by each token's `hidden`
# attribute.
# NOTE(review): both plots get the identical title, so they cannot be told
# apart from the figure alone — consider naming the dataset in the title.
for hidden_seqs in (hidden_seqs_35, hidden_seqs_100):
    timesteps, acc_hidden, acc_other, n_hidden, n_other = compute_binned_accuracy(
        model,
        "gru",
        hidden_seqs,
        property_fn=lambda t: t.hidden,
    )
    plot_binned_accuracy(
        timesteps,
        acc_hidden,
        acc_other,
        n_hidden,
        n_other,
        label_modified="Added tokens",
        label_other="Normal tokens",
        title="Added VS Unchanged Accuracy",
    )


**Tuple Generalization:**

In [None]:
# Compare tokens labelled "k" with tokens located at coordinates (1, 1).
# NOTE(review): this call uses `property_fn1`/`property_fn2` and unpacks seven
# values, whereas the earlier intervention cells use `property_fn` and unpack
# five; the plot call likewise uses different arguments — confirm which
# signature intervention_helpers currently provides.
timesteps, acc_k, ci_k, acc_coord, ci_coord, n_k, n_coord = compute_binned_accuracy(
    model,
    "gru",
    moved_k_seqs,
    property_fn1=lambda t: t.label == "k",
    property_fn2=lambda t: t.coordinates == (1,1),
)
plot_binned_accuracy(
    timesteps=timesteps,
    p_hidden=acc_k,
    ci_hidden=ci_k,
    p_other=acc_coord,
    ci_other=ci_coord,
    label_modified="K token",
    label_other="Token at (1,1)",
    title="Non-Fixed K Accuracy",
)