### Notebook to Run All Experiments and Generate Plots

### Run All Experiments

In [72]:
# imports
import os
import sys
import subprocess

In [None]:
# Directory that contains run_all.py (the command is executed from there)
script_dir = os.path.abspath(os.path.join("..", "scripts"))

# Dry-run switch: True only prints the command; set it to False to actually execute it
debug = False

# Command line for run_all.py; uncomment options below to enable them
command = [
    sys.executable,  # the kernel's own interpreter, so run_all.py sees the same environment
    "run_all.py",  # main script name

    ## Dataset and Model Configuration

    "--model_name", "gpt2",  # ["gpt2", "pythia-6.9b", "Llama-3.2-1B", "Llama-3.1-8B"]
    "--dataset", "copyVSfact",  # ["copyVSfact", "copyVSfactQnA", "copyVSfactDomain"]
    "--start", "0",  # None for the entire dataset
    "--end", "100",  # None for the entire dataset
    # "--prompt_type", "qna",  # Domain/Normal Dataset: "None", QnA Dataset: "qna"
    # "--domain", "Science",  # Name of the domain (if dataset is copyVSfactDomain), else None
    # "--downsampled-dataset",  # use the fixed (downsampled) dataset
    # "--quantize",  # use a quantized version of the model

    ## Setup Configuration

    "--batch", "16",  # batch size
    "--device", "cuda",

    ## Plotting Configuration

    # "--only-plot",  # only plot the results
    # "--no-plot",  # disable plotting after each experiment run

    ## Experiments Configuration

    # "--logit-attribution",
    # "--logit_lens",  # fails on MPS device with batch mismatch
    # "--ov-diff",
    # "--ablate",
    # "--pattern",
    "--all",  # run all the experiments

]

# Run "run_all.py".
# os.path.basename is portable (handles Windows separators), and avoiding nested
# double quotes inside the f-string keeps this valid on Python < 3.12.
print(f"Running command in: {os.path.basename(script_dir)}")
print(f"Running command: {' '.join(command)}")
if not debug:
    result = subprocess.run(command, cwd=script_dir)
    # Surface failures instead of silently ignoring a non-zero exit status
    if result.returncode != 0:
        print(f"run_all.py exited with code {result.returncode}")

### Visualize Results

In [None]:
# Launch the Streamlit dashboard from the project root (the notebook's parent directory)
script_dir = os.path.abspath("..")

command = ["streamlit", "run", "dashboard.py"]

# os.path.basename is portable, and the f-strings avoid nested double quotes
# (a SyntaxError before Python 3.12).
print(f"Running command in: {os.path.basename(script_dir)}")
print(f"Running command: {' '.join(command)}")
# NOTE: subprocess.run waits for the process to exit, so this cell stays busy for as
# long as the dashboard server runs; interrupting the kernel should stop it.
subprocess.run(command, cwd=script_dir)

### Save Plots for a Specific Experiment and Dataset

In [61]:
# Make the project root and its src/ directory importable so the project `dataset`
# module can be imported. The membership check keeps re-runs of this cell from
# piling up duplicate sys.path entries.
for extra_path in (os.path.abspath(".."), os.path.abspath(os.path.join("..", "src"))):
    if extra_path not in sys.path:
        sys.path.append(extra_path)
from dataset import DOMAINS

In [70]:
# ---- Plot configuration: edit these values, then run the cells below ----

# Domain name (not read by save_plots, which iterates over every entry of DOMAINS)
DOMAIN = None

# Set to True if the results were produced on the downsampled dataset (adds --downsampled)
DOWNSAMPLED = False

# Model whose results are plotted.
# Options: "gpt2", "pythia-6.9b", "Llama-3.2-1B", "Llama-3.1-8B"
MODEL = "gpt2"

# Results folder name for the chosen model
MODEL_FOLDER = MODEL + "_full"

# Experiment (dataset) name.
# Options: "copyVSfact", "copyVSfactQnA", "copyVSfactDomain"
# For "copyVSfactDomain", save_plots appends "_<domain>" to MODEL_FOLDER for each domain.
EXPERIMENT = "copyVSfact"

# Plotting scripts invoked by save_plots, in this order
# (plot_ablation_fig_5.py is currently disabled)
scripts = [
    "plot_logit_lens_fig_2.py",
    "plot_logit_attribution_fig_3_4a.py",
    "plot_head_pattern_fig_4b.py",
]

# Working directory the plotting scripts are run from
script_dir = os.path.abspath(os.path.join(os.pardir, "plotting_scripts"))

def _build_plot_command(script, model_folder, domain=None):
    """Return the argument list for one invocation of a plotting script.

    Args:
        script: File name of the plotting script (resolved relative to ``script_dir``).
        model_folder: Value passed as ``--model_folder``.
        domain: If not None, passed as ``--domain``.
    """
    command = [
        sys.executable,  # same interpreter/environment as this kernel
        script,
        "--model", MODEL,
        "--experiment", EXPERIMENT,
        "--model_folder", model_folder,
    ]
    if domain is not None:
        command += ["--domain", domain]
    if DOWNSAMPLED:
        command.append("--downsampled")
    return command


def save_plots(debug=True):
    """Print, and optionally run, the command for every configured plotting script.

    Reads the module-level configuration ``scripts``, ``script_dir``, ``MODEL``,
    ``MODEL_FOLDER``, ``EXPERIMENT``, ``DOWNSAMPLED`` and, for the
    ``copyVSfactDomain`` experiment, ``DOMAINS``. For that experiment each script is
    invoked once per domain, with ``_<domain>`` appended to the model folder.

    Args:
        debug: If True (default), only print the commands (dry run). If False,
            run each command from ``script_dir`` and report non-zero exit codes.
    """
    for script in scripts:
        # (model_folder, domain) pairs to plot for this script
        if EXPERIMENT == "copyVSfactDomain":
            runs = [(f"{MODEL_FOLDER}_{domain}", domain) for domain in DOMAINS]
        else:
            runs = [(MODEL_FOLDER, None)]

        for model_folder, domain in runs:
            command = _build_plot_command(script, model_folder, domain)
            print(f"Running command: {' '.join(command)}")
            if not debug:
                result = subprocess.run(command, cwd=script_dir)
                # Surface failures instead of silently moving on
                if result.returncode != 0:
                    print(f"{script} exited with code {result.returncode}")
                print()

In [None]:
# Dry-run switch: True only prints the plotting commands; set it to False to run them
debug = True

# Print (and, when debug is False, execute) every plotting command
save_plots(debug=debug)