In [None]:
import json
from collections import Counter, defaultdict
from pathlib import Path

import numpy as np
import pandas as pd
import seaborn as sns
import yaml
from matplotlib import pyplot as plt

In [None]:
# Locations of the two training sets compared in this notebook.
# NOTE: despite the `_dir` suffix, both names hold paths to *files*; they are
# handed directly to open() when the data is loaded.
old_data_dir = Path("./data/deco_decoUB_l10_s50_reg3_d100000_cg0_nhr0/train.json")
new_data_dir = Path("../../datasets/SIRDataset_CFC188/train.yaml")

In [None]:
# Load both datasets fully into memory.
# The encoding is pinned to UTF-8 so that parsing does not depend on the
# platform's default locale encoding (open() falls back to it otherwise).
with open(old_data_dir, "r", encoding="utf-8") as f:
    # Old format: a JSON document
    old_data = json.load(f)

print(f"loaded old data containing {len(old_data)} examples")

with open(new_data_dir, "r", encoding="utf-8") as f:
    # New format: a YAML document; safe_load only builds plain Python objects
    new_data = yaml.safe_load(f)

print(f"loaded new data containing {len(new_data)} examples")

In [None]:
# Peek at the raw "sequence" field of the first new-format example.
# Run this before the cell that replaces every entry of `new_data` with its
# token list; afterwards the entries are lists and have no "sequence" key.
new_data[0]["sequence"]

In [None]:
def old_to_new_mapping(old_sym_or_old_sym_array):
    """Translate old-format symbols into the new-format vocabulary.

    Parameters
    ----------
    old_sym_or_old_sym_array : str or list
        Either a single old-format symbol, or a list of them.

    Returns
    -------
    str or list
        For a string: the translated symbol. For a list: a list with the
        translation of every element *after the first 10*, which are dropped.

    Raises
    ------
    TypeError
        If the argument is neither a ``str`` nor a ``list``.
    ValueError
        If a symbol has no known translation, or looks like a register/symbol
        token but lacks its index part.
    """
    if isinstance(old_sym_or_old_sym_array, list):
        # The first 10 entries are skipped on purpose.
        # NOTE(review): presumably a fixed-length prefix of the old format that
        # has no counterpart in the new one -- confirm against the generator.
        return [old_to_new_mapping(sym) for sym in old_sym_or_old_sym_array[10:]]

    if not isinstance(old_sym_or_old_sym_array, str):
        # Previously this fell through to an UnboundLocalError on `old_sym`.
        raise TypeError(
            "expected a str or a list of str, got "
            f"{type(old_sym_or_old_sym_array).__name__}"
        )

    old_sym = old_sym_or_old_sym_array
    mapping = {"store": "St", "ignore": "Ig", "different": "diff", "same": "same"}
    if old_sym in mapping:
        return mapping[old_sym]

    try:
        if "Register" in old_sym:
            # handle cases like 'Register 1', 'Register 2', etc.
            return f"reg_{old_sym.split()[1]}"
        if "Symbol" in old_sym:
            # handle cases like 'Symbol_1', 'Symbol_2', etc.
            return f"item_{old_sym.split('_')[1]}"
    except IndexError:
        # The token names a register/symbol but carries no index part.
        raise ValueError(f"Malformed old symbol: {old_sym!r}") from None

    raise ValueError(f"Cannot find mapping for old symbol: {old_sym}")

In [None]:
# Sanity check: translate the first old-format sequence and display the result.
# Run this before the cell that overwrites the entries of `old_data` in place;
# afterwards the entries are lists and can no longer be indexed by "sequence".
old_to_new_mapping(old_data[0]["sequence"])

In [None]:
# Replace every old-format record by its translated token list, in place.
# The isinstance guard makes the cell idempotent: entries that were already
# converted (now lists) are left alone instead of raising a TypeError when the
# cell is executed a second time.
for i, entry in enumerate(old_data):
    if isinstance(entry, dict):
        old_data[i] = old_to_new_mapping(entry["sequence"])

In [None]:
# Replace every new-format record by the whitespace-split tokens of its
# "sequence" string, in place.
# The isinstance guard makes the cell idempotent: entries that were already
# converted (now lists) are skipped instead of raising a TypeError on re-run.
for i, entry in enumerate(new_data):
    if isinstance(entry, dict):
        new_data[i] = entry["sequence"].split()

In [None]:
# Number of sequences to rasterise; each raster gets one row per sequence
# and 11 columns, and starts out all-zero.
ulim = 100_000
old_instruction_img = np.zeros(shape=(ulim, 11))
new_instruction_img = np.zeros_like(old_instruction_img)


def populate_img(img, data, symbol="St"):
    """Mark, in place, where `symbol` occurs in each sequence of `data`.

    Row ``i`` of `img` corresponds to ``data[i]``. Token position ``j`` is
    mapped to column ``j // 4``, so every run of 4 consecutive tokens shares
    one column (NOTE(review): presumably one trial spans 4 tokens -- confirm).
    Matching cells are set to 1; all other cells are left untouched.

    Parameters
    ----------
    img : numpy.ndarray
        2-D array that is written to in place.
    data : sequence of sequences
        Token sequences; only the first ``img.shape[0]`` are used.
    symbol : str, optional
        Token to look for (exact equality). Defaults to ``"St"``.

    Returns
    -------
    None

    Raises
    ------
    IndexError
        If a sequence has a matching token at a position whose column
        ``j // 4`` lies outside `img`.
    """
    # Take the row budget from the image itself rather than from the global
    # `ulim`, which other cells rebind to different values.
    rows = data[: img.shape[0]]
    print(f"populating image of size {img.shape} with {len(rows)} sequences")
    for i, seq in enumerate(rows):
        for j, sym in enumerate(seq):
            if sym == symbol:
                img[i, j // 4] = 1


# Fill both instruction rasters with the positions of the "St" token.
populate_img(img=old_instruction_img, data=old_data, symbol="St")
populate_img(img=new_instruction_img, data=new_data, symbol="St")

In [None]:
# Side-by-side raster plots of the instruction images (old data left, new
# data right): one row per sequence, one column per group of 4 tokens.
f, a = plt.subplots(1, 2, figsize=(10, 7))
panels = (
    (a[0], old_instruction_img, "old instructions rasterplot"),
    (a[1], new_instruction_img, "new instructions rasterplot"),
)
for panel_ax, panel_img, panel_title in panels:
    panel_ax.imshow(
        panel_img,
        # "gray" is the canonical colormap name.
        # NOTE(review): the "grey" spelling appears to be registered as an
        # alias only in newer Matplotlib releases -- "gray" works on all.
        cmap="gray",
        interpolation="none",
        aspect="auto",
    )
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel("token position // 4")
    panel_ax.set_ylabel("sequence index")

In [None]:
# Raster plots of where the "same" label token occurs, old vs. new data.

ulim = 100_000
old_samediff_img = np.zeros(shape=(ulim, 11))
new_samediff_img = np.zeros_like(old_samediff_img)

populate_img(img=old_samediff_img, data=old_data, symbol="same")
populate_img(img=new_samediff_img, data=new_data, symbol="same")

# Both panels share the same rendering options.
samediff_raster_style = {"cmap": "Reds", "interpolation": "none", "aspect": "auto"}

f, a = plt.subplots(nrows=1, ncols=2, figsize=(10, 7))
a[0].imshow(old_samediff_img, **samediff_raster_style)
a[0].set_title("old labels (same/diff) rasterplot")
a[1].imshow(new_samediff_img, **samediff_raster_style)
a[1].set_title("new labels (same/diff) rasterplot")

In [None]:
# Next: raster plots of categorical tokens (which register / which item is
# used where), encoded as ordinals.

ulim = 200
old_reg_img = np.zeros(shape=(ulim, 11))
new_reg_img = np.zeros_like(old_reg_img)

# Categorical colormap, resampled to 50 entries so many categories fit.
cmap = plt.get_cmap("tab20", lut=50)


def populate_img_with_ordinals(img, data, symbol="reg"):
    """Write, in place, an ordinal code for each token containing `symbol`.

    Row ``i`` of `img` corresponds to ``data[i]``; token position ``j`` is
    mapped to column ``j // 4``. Every distinct matching token is assigned
    the next free ordinal (1, 2, 3, ...) in order of first appearance, and
    that ordinal is stored in the cell. The numbering is local to one call,
    so the same token may receive different ordinals in different calls.

    Parameters
    ----------
    img : numpy.ndarray
        2-D array that is written to in place.
    data : sequence of sequences of str
        Token sequences; only the first ``img.shape[0]`` are used.
    symbol : str, optional
        Substring that selects tokens (``symbol in token``). Defaults to
        ``"reg"``.

    Returns
    -------
    None

    Raises
    ------
    IndexError
        If a matching token sits at a position whose column ``j // 4`` lies
        outside `img`.
    """
    # Take the row budget from the image itself rather than from the global
    # `ulim`, which other cells rebind to different values.
    rows = data[: img.shape[0]]
    sym_map = {}
    print(f"populating image of size {img.shape} with {len(rows)} sequences")
    for i, seq in enumerate(rows):
        for j, sym in enumerate(seq):
            if symbol in sym:
                # setdefault returns the existing ordinal, or registers the
                # next one (current size + 1) for a token seen the first time.
                img[i, j // 4] = sym_map.setdefault(sym, len(sym_map) + 1)


# Encode the "item" tokens of both datasets as ordinals ...
populate_img_with_ordinals(img=old_reg_img, data=old_data, symbol="item")
populate_img_with_ordinals(img=new_reg_img, data=new_data, symbol="item")

# ... and show them side by side with identical rendering options.
item_raster_style = {"cmap": cmap, "interpolation": "none", "aspect": "auto"}

f, a = plt.subplots(nrows=1, ncols=2, figsize=(10, 14))
a[0].imshow(old_reg_img, **item_raster_style)
a[0].set_title("old item usage rasterplot")
a[1].imshow(new_reg_img, **item_raster_style)
a[1].set_title("new item usage rasterplot")

plt.show()


In [None]:
# Distribution of the number of "same" tokens per sequence, old vs. new data.
# list.count does the single lookup directly; no per-sequence Counter needed.
samediff_old = [seq.count("same") for seq in old_data]
samediff_new = [seq.count("same") for seq in new_data]

# Plain histograms (no KDE overlay), one figure per dataset.
# NOTE(review): the values are integer counts; `discrete=True` instead of
# `bins=9` would centre the bars on the integers -- consider switching.
sns.histplot(
    samediff_old,
    label="old same[diff] counts",
    color="blue",
    bins=9,
    alpha=0.5,
)
plt.xlabel('number of "same" tokens per sequence')
plt.legend()
plt.show()

sns.histplot(
    samediff_new,
    label="new same[diff] counts",
    color="orange",
    bins=9,
    alpha=0.5,
)
plt.xlabel('number of "same" tokens per sequence')
plt.legend()
plt.show()


In [None]:
# Distribution of the number of "St" tokens per sequence, old vs. new data.
# list.count keeps this cell independent of the `Counter` import that used to
# live in a different cell.
instr_old = [seq.count("St") for seq in old_data]
instr_new = [seq.count("St") for seq in new_data]

# One histogram figure per dataset.
sns.histplot(
    instr_old,
    label="old store[ignore] counts",
    color="purple",
    bins=9,
    alpha=0.5,
)
plt.xlabel('number of "St" tokens per sequence')
plt.legend()
plt.show()

sns.histplot(
    instr_new,
    label="new store[ignore] counts",
    color="maroon",
    bins=9,
    alpha=0.5,
)
plt.xlabel('number of "St" tokens per sequence')
plt.legend()
plt.show()

In [None]:
# quantify distribution of longest streak of using the same register


def longest_streak_register(sequence):
    """Return the longest run of identical consecutive register tokens.

    Only tokens containing the substring ``"reg"`` are considered; all other
    tokens are skipped and do not interrupt a run. Two register tokens belong
    to the same run when they compare equal.

    Parameters
    ----------
    sequence : iterable of str
        Token sequence of one example.

    Returns
    -------
    int
        Length of the longest run, or 0 when the sequence contains no
        register token at all.
    """
    last_reg = None
    streak = 0  # length of the run currently being extended
    max_streak = 0
    for sym in sequence:
        if "reg" not in sym:
            continue
        if sym == last_reg:
            streak += 1
        else:
            # A different register starts a fresh run of length 1.
            last_reg = sym
            streak = 1
        max_streak = max(max_streak, streak)
    return max_streak


# Longest same-register streak of every sequence, per dataset
old_streaks = list(map(longest_streak_register, old_data))
new_streaks = list(map(longest_streak_register, new_data))

In [None]:
# Compare the distributions of the longest same-register streak (old vs. new).

# Long-format table: one row per sequence, tagged with the dataset it is from.
streaks = pd.DataFrame(
    {
        "streak": [*old_streaks, *new_streaks],
        "data": ["old"] * len(old_streaks) + ["new"] * len(new_streaks),
    }
)

plt.figure(figsize=(12, 6))

# NOTE(review): streaks are integers; `discrete=True` instead of `bins=9`
# would centre the bars on the integer ticks -- consider switching.
sns.histplot(
    data=streaks,
    x="streak",
    hue="data",
    alpha=0.5,
    bins=9,
    multiple="dodge",
)

plt.xticks(np.arange(1, 10))
plt.title("longest streak of utilizing the same register in a trial sequence")
plt.show()