# Plot Scratchpad Results

## Prerequisites

In [1]:
import copy
import datetime
import importlib
import re

from pathlib import Path
from typing import List, Dict, Any


import numpy as np
from tqdm import tqdm

import plotnine as p9

import pandas as pd
from collections import deque

from utils import add_src_to_sys_path

add_src_to_sys_path()

from common import wandb_utils, nest
from common import plotnine_utils as p9u

from plotting import attention_analysis_utils as aa_utils

# Reload the project helper modules so that edits to them take effect without a kernel restart.
wandb_utils, p9u, aa_utils = (
    importlib.reload(module) for module in (wandb_utils, p9u, aa_utils)
)

wandb_api = wandb_utils.get_wandb_api()

## Load the experiments

In [2]:
# Tags passed to `wandb_utils.download_and_load_results`; they are also joined to name the plot output directory.
tags = ["attention_analysis"]

In [3]:
# One output directory per tag combination, e.g. output_plots/attention_analysis.
plot_dir_name = "__".join(tags)
output_dir = Path("output_plots").joinpath(plot_dir_name)
output_dir.mkdir(exist_ok=True, parents=True)

In [4]:
# Fetch the tagged runs into a DataFrame.
# NOTE(review): force_download=True presumably bypasses any locally cached copy; confirm in wandb_utils.
df = wandb_utils.download_and_load_results(force_download=True, tags=tags)
len(df)

  0%|          | 0/115 [00:00<?, ?it/s]

Building dataframe...
Saving results to /Users/amirhosein/Development/PycharmProjects/len_gen/results/attention_analysis.jsonl


115

In [None]:
# or load manually
import jsonlines
def load_dataframe_from_jsonlines(path: Path) -> pd.DataFrame:
    """Read a JSON Lines file and return its objects as DataFrame records.

    Each line of *path* becomes one row; the keys of each object become columns.
    """
    with jsonlines.open(path) as reader:
        records = list(reader)
    return pd.DataFrame.from_records(records)


def get_result_name(tags: List[str]) -> str:
    """Return the results-file stem for *tags*: the tags joined by single underscores."""
    separator = "_"
    return separator.join(tags)

# Read ../results/<result name>.jsonl instead of downloading from W&B.
df = load_dataframe_from_jsonlines(
    Path("..", "results", get_result_name(tags) + ".jsonl")
)
len(df)

## Filter the data by dataset and split

In [None]:
# List the distinct dataset names in the loaded runs, to help choose DATASET below.
df["cfg__dataset.name"].unique()

In [None]:
# List the distinct dataset splits in the loaded runs, to help choose SPLIT below.
df["cfg__dataset.split"].unique()

In [None]:
# Restrict the analysis to a single dataset / split.
DATASET = "scan"
SPLIT = "len_tr25_ts48"

# `.copy()` detaches the filtered frame from the unfiltered one, so later column
# assignments (e.g. adding `scratchpad_config`) do not raise pandas'
# SettingWithCopyWarning or act on an ambiguous view/copy.
df = df[
    (df["cfg__dataset.name"] == DATASET) & (df["cfg__dataset.split"] == SPLIT)
].copy()
len(df)

In [5]:
from collections import defaultdict

# Extracts the scratchpad configuration (e.g. "i1_c1_o1_v1_r1") from a run-group name.
# Only capture group 3 is meaningful: groups 1 and 2 each capture a single character
# and exist only because of the `(.)*` / `(.)+` spelling, which shifts the config to group 3.
# Note: every "." in "i._c._o._v._r." matches any character, not just a digit.
scratchpad_config_pattern = re.compile(
    r"(.)*_scratchpad(.)+_ufs__(i._c._o._v._r.)_.*___.*"
)

def get_compute_cluster(host: str):
    """Map a machine hostname to the name of the compute cluster it belongs to.

    Substring markers are checked in order ("cedar" first, then "narval"),
    followed by the Mila "cn-" hostname prefix.

    Raises:
        ValueError: if *host* matches none of the known clusters.
    """
    for marker, cluster in (("cedar", "cc_cedar"), ("narval", "cc_narval")):
        if marker in host:
            return cluster
    if host.startswith("cn-"):
        return "mila"
    raise ValueError(f"Unknown host {host}")

def get_scratchpad_config(group: str):
    """Return the scratchpad config encoded in the run-group name *group*.

    Uses capture group 3 of ``scratchpad_config_pattern``; when the pattern does
    not match, the run is labelled ``"no_scratchpad"``.
    """
    found = scratchpad_config_pattern.search(group)
    if found is None:
        return "no_scratchpad"
    return found.group(3)

# Derive each run's scratchpad config from its `run_group` column.
df["scratchpad_config"] = df["run_group"].apply(get_scratchpad_config)

# Map each scratchpad config to the set of compute clusters its runs were executed on.
scratchpad_config_to_cluster_name = defaultdict(set)
# NOTE(review): `run_group_to_cluster_name` is never populated or read in the visible
# notebook; it looks like a leftover. Confirm before removing.
run_group_to_cluster_name = {}
for scratchpad_config, hostname in zip(df["scratchpad_config"], df["host"]):
    # get_compute_cluster raises ValueError for any host it does not recognize.
    cluster_name = get_compute_cluster(hostname)
    scratchpad_config_to_cluster_name[scratchpad_config].add(cluster_name)

# Invert the mapping: cluster name -> set of scratchpad configs whose runs ran there.
cluster_name_to_scratchpad_configs = defaultdict(set)
for scratchpad_config, cluster_names in scratchpad_config_to_cluster_name.items():
    for cluster_name in cluster_names:
        cluster_name_to_scratchpad_configs[cluster_name].add(scratchpad_config)

cluster_name_to_scratchpad_configs

defaultdict(set,
            {'mila': {'i0_c1_o1_v1_r1',
              'i1_c0_o1_v1_r1',
              'i1_c1_o0_v1_r1',
              'i1_c1_o1_v0_r1',
              'i1_c1_o1_v1_r0',
              'i1_c1_o1_v1_r1'},
             'cc_cedar': {'i0_c1_o1_v0_r0', 'no_scratchpad'}})

In [6]:
import os

import platform

# Manual override: set this to a cluster name (e.g. "mila") to skip hostname-based
# detection, e.g. on a machine that `get_compute_cluster` does not recognize.
# It is checked *before* detection, because detection raises ValueError on unknown hosts.
LOCAL_CLUSTER_OVERRIDE = None

# On Unix, `platform.node()` returns the same value as `os.uname()[1]`; it also
# works on Windows, where `os.uname` is unavailable.
localhost = platform.node()
if LOCAL_CLUSTER_OVERRIDE is not None:
    local_cluster = LOCAL_CLUSTER_OVERRIDE
else:
    local_cluster = get_compute_cluster(localhost)
local_cluster

'mila'

In [18]:
from tqdm.auto import tqdm
analysis_root_dir = Path.home() / "scratch" / "len_gen" / "experiments" / "attention_analysis_data"

aa_utils = importlib.reload(aa_utils)

# Select the runs to download up front so the progress bar total equals the
# number of runs actually fetched, rather than every row of `df`. The selection
# matches the former per-row `continue` checks:
#   - skip W&B sweep-agent rows;
#   - keep only runs without position encoding;
#   - keep only scratchpad configs that were run on this machine's cluster.
local_scratchpad_configs = cluster_name_to_scratchpad_configs[local_cluster]
runs_to_download = df[
    (df["job_type"] != "agent")
    & (df["cfg__model.position_encoding_type"] == "none")
    & df["scratchpad_config"].isin(local_scratchpad_configs)
]

for _, row in tqdm(runs_to_download.iterrows(), total=len(runs_to_download)):
    run = wandb_api.run(f"{wandb_utils.get_entity_name()}/{wandb_utils.get_project_name()}/{row['id']}")
    aa_utils.download_the_entire_run2(local_cluster, run, root_dir=analysis_root_dir)


  0%|          | 0/54 [00:00<?, ?it/s]