# Setup

In [None]:
import torch as th
import numpy as np
import pandas as pd
import json
import matplotlib.pyplot as plt
from tqdm import tqdm
import seaborn as sns
from pathlib import Path
from time import time
import itertools
from random import shuffle
from torch.utils.data import DataLoader
from display_utils import plot_topk_tokens
from datasets import load_dataset, load_from_disk, Dataset

# Inference-only notebook: turn autograd off globally. The returned object is
# bound to `_` so it is discarded rather than kept under a meaningful name.
_ = th.set_grad_enabled(False)

In [None]:
# Experiment name; used as a directory component of the results path
# (results/<model_name>/<exp_name>/...).
exp_name = "latent_prompting"

## Papermill args

In [None]:
# Default values for the notebook parameters (the section header marks these
# as papermill arguments, so they are presumably overridden per run).
# Batch size for activation collection; also the default batch size of fsp_plot.
batch_size = 8
# Model identifier; its last "/"-separated component becomes `model_name`.
model = "Llama-2-7b"
# Forwarded to load_model as `device_map`.
device = "auto"
# Checkpoint location given to load_model; when None, `model` is used instead.
model_path = "/dlabscratch1/public/llm_weights/llama2_hf/Llama-2-7b-hf"
# Forwarded unchanged to load_model.
trust_remote_code = False
# Forwarded unchanged to load_model.
use_tl = False
# Argument list parsed by the ArgumentParser defined in the "CLI args" cell.
extra_args = []
# Base name of the result files (exp_id + ".json" / ".png" / "_all.png").
exp_id = None
# Forwarded as `remote=` to activation collection, tracing and the lens call.
remote = False

## CLI args

In [None]:
from argparse import ArgumentParser

# Command-line style options, parsed from `extra_args` rather than sys.argv.
# Each row is (flags, add_argument keyword options); row order is kept so the
# parsed namespace lists its attributes in the same order as before.
parser = ArgumentParser()
for flags, options in (
    # Forwarded as `collect_from_single_layer` to the latent prompt lens.
    (("--single-layer", "-sl"), {"action": "store_true"}),
    # Forwarded as `patch_from_layer` / `patch_until_layer`.
    (("--patch-from-layer", "-fl"), {"type": int}),
    (("--patch-until-layer", "-tl"), {"type": int}),
    # Shuffle the few-shot label tokens before using them.
    (("--rnd-labels", "-rl"), {"action": "store_true"}),
    # String placed between the BOS token and the label in each prompt line.
    (("--link-string",), {"type": str, "default": "->"}),
    # Also run the experiment with ground-truth tokens as few-shot labels.
    (("--try-gt-labels", "-gt"), {"action": "store_true"}),
):
    parser.add_argument(*flags, **options)
args = parser.parse_args(extra_args)
print(f"args: {args}")

## Loading and arg preprocessing

In [None]:
from exp_tools import load_model

# Short model name (last "/"-separated component); labels caches and results.
model_name = model.split("/")[-1]
if model_path is None:
    # No explicit checkpoint location: hand the model identifier to load_model.
    model_path = model
if exp_id is None:
    # exp_id is concatenated with ".json" / ".png" to name the result files, so
    # leaving it as None would raise a TypeError at save time. Fall back to a
    # timestamp so every un-parameterised run still gets its own file names.
    exp_id = str(int(time()))
nn_model = load_model(
    model_path,
    trust_remote_code=trust_remote_code,
    device_map=device,
    use_tl=use_tl,
    # dispatch=True,
)
tokenizer = nn_model.tokenizer
# Options taken from the CLI arguments and forwarded to every lens call.
latent_kwargs = dict(
    collect_from_single_layer=args.single_layer,
    patch_from_layer=args.patch_from_layer,
    patch_until_layer=args.patch_until_layer,
)

In [None]:
# Location of the filtered + shuffled fineweb sample on disk. The dataset is
# only downloaded and filtered when that cache does not exist yet.
fw_path = Path("cache/filtered_fw_sample-10BT")
if not fw_path.exists():
    fw_path.parent.mkdir(exist_ok=True)
    fw_dataset = load_dataset(
        "HuggingFaceFW/fineweb",
        name="sample-10BT",
        split="train",
    )
    # Size before filtering.
    print(len(fw_dataset))
    # Keep short documents only: under 250 space-separated words and under
    # 2000 characters.
    fw_dataset = fw_dataset.filter(
        lambda x: len(x["text"].split(" ")) < 250 and len(x["text"]) < 2000
    )
    fw_dataset = fw_dataset.shuffle(seed=42)
    # Size after filtering.
    print(len(fw_dataset))
    fw_dataset.save_to_disk(fw_path)

In [None]:
from nnsight_utils import collect_activations_batched, project_on_vocab


def process_dataset(dataset_type, trn_n, val_n=None):
    """Build (or load from cache) a next-token dataset and its activations.

    Each sample is a text prefix cut at a random token position ("text") and
    the id of the token that followed it ("gt"). The dataset is cached under
    ``cache/{dataset_type}_dataset`` and the activations under
    ``cache/fw_{dataset_type}_acts/{model_name}.pt``; an existing cache is
    reused as is. A column named after the model holds the argmax token
    obtained by projecting the last entry of the activations on the vocabulary.

    Args:
        dataset_type: ``"train"`` takes rows ``[0, trn_n)`` of the filtered
            fineweb dataset; any other value takes ``[trn_n, trn_n + val_n)``.
        trn_n: Number of rows reserved for the training split.
        val_n: Number of rows of the other split. Required when that split has
            to be created (it is not needed when the dataset is cached).

    Returns:
        ``(dataset, fw_acts)``: the dataset with the prediction column added,
        and the activations returned by ``collect_activations_batched``.

    Raises:
        ValueError: if a non-train dataset must be created and ``val_n`` is
            None.
    """
    dataset_path = Path(f"cache/{dataset_type}_dataset")
    act_path = Path("cache") / f"fw_{dataset_type}_acts" / f"{model_name}.pt"
    if not dataset_path.exists():
        print(f"Creating {dataset_type} dataset")
        fw_dataset = load_from_disk(fw_path)
        if dataset_type == "train":
            start, stop = 0, trn_n
        else:
            if val_n is None:
                raise ValueError(
                    f"val_n is required to create the {dataset_type!r} dataset"
                )
            start, stop = trn_n, trn_n + val_n
        # Slice the rows first: indexing the "text" column first would
        # materialise every document of the corpus just to keep a few of them.
        fw_subset = fw_dataset[start:stop]["text"]

        data = {"text": [], "gt": []}
        n_skipped = 0
        for sample in fw_subset:
            toks = tokenizer.encode(sample, add_special_tokens=False)
            if len(toks) < 2:
                # randint(0, len(toks) - 1) needs a non-empty range; such a
                # sample has no (prefix, next token) pair to draw from.
                n_skipped += 1
                continue
            # numpy's upper bound is exclusive, so the cut is drawn from
            # [0, len(toks) - 2]; a cut at 0 gives an empty prefix.
            rnd_idx = np.random.randint(0, len(toks) - 1)
            data["text"].append(tokenizer.decode(toks[:rnd_idx]))
            data["gt"].append(toks[rnd_idx])
        if n_skipped:
            print(f"Skipped {n_skipped} samples with fewer than 2 tokens")
        dataset = Dataset.from_dict(data)
        dataset.save_to_disk(dataset_path)
    else:
        dataset = load_from_disk(dataset_path)

    if act_path.exists():
        # NOTE(review): the cached activations are trusted to match the cached
        # dataset; delete both caches together when regenerating either one.
        fw_acts = th.load(act_path)
    else:
        print("Collecting activations")
        act_path.parent.mkdir(parents=True, exist_ok=True)
        fw_acts = collect_activations_batched(
            nn_model,
            dataset["text"],
            batch_size=batch_size,
            tqdm=tqdm,
            remote=remote,
        )
        th.save(fw_acts, act_path)

    print(f"Adding model predictions to {dataset_type} dataset")
    # The traced input "a" is a dummy: only the projection of the stored
    # activations is used.
    # NOTE(review): the device is hard-coded to cuda:0 - confirm this matches
    # where the model's output head lives.
    with nn_model.trace("a", remote=remote):
        model_preds = (
            project_on_vocab(nn_model, fw_acts[-1].to("cuda:0"))
            .argmax(-1)
            .cpu()
            .tolist()
            .save()
        )
    dataset = dataset.add_column(model_name, model_preds.value)
    return dataset, fw_acts


# Validation split: the 2000 rows that follow the first 10000 (training) rows.
val_dataset, fw_val_acts = process_dataset(
    dataset_type="val", trn_n=10_000, val_n=2_000
)

# Training split: the first 10000 rows.
train_dataset, fw_train_acts = process_dataset(dataset_type="train", trn_n=10_000)

## Few shot patchscope lens eval

In [None]:
from prompt_tools import Prompt
from exp_tools import run_prompts

# One Prompt per validation sample. The ground-truth next token is the target
# and the model's own prediction is tracked under the "model pred" label, each
# given both as token ids and as token strings.
prompts = [
    Prompt(
        data["text"],
        [data["gt"]],
        {"model pred": [data[model_name]]},
        [tokenizer.convert_ids_to_tokens(data["gt"])],
        {"model pred": [tokenizer.convert_ids_to_tokens(data[model_name])]},
    )
    for data in val_dataset
]

In [None]:
from interventions import LatentPrompt, LatentPromptBatch, latent_prompt_lens


def fs_latent_prompt(num_fs):
    """Build the few-shot latent prompt template with ``num_fs`` examples.

    The template is ``num_fs`` lines of the form ``<bos><link><unk>`` (the unk
    token is the placeholder later replaced by a label token), followed by a
    final ``<bos><link>`` line without a label.
    """
    bos = tokenizer.bos_token
    link = args.link_string
    example_line = f"{bos}{link}{tokenizer.unk_token}\n"
    query_line = f"{bos}{link}"
    return LatentPrompt.from_string(example_line * num_fs + query_line, tokenizer)


class TokenGenerator:
    """Hand out tokens one by one while recording the index of each of them.

    Indices are served in order (0, 1, 2, ...) on the first pass and in a
    freshly shuffled order on every later pass. An index already present in
    ``last_chosen`` (i.e. served since the last ``collect``/``reset``) is never
    drawn again; it is put back at the bottom of the stack instead. With
    ``repeat=k`` every drawn token is served ``k`` additional times in a row.

    Args:
        tokens: Indexable sequence of tokens to draw from.
        repeat: Number of extra consecutive times each drawn token is served.
    """

    def __init__(self, tokens, repeat=0):
        self.tokens = tokens
        # Stack of candidate indices, reversed so that pop() serves 0 first.
        self.idx = list(range(len(tokens)))[::-1]
        # Indices served since the last collect()/reset(), in serving order.
        self.last_chosen = []
        self.repeat = repeat
        # How many times the most recent draw has been repeated so far.
        self.current_repeat = 0

    def pop(self, *_args, **_kwargs):
        """Return the next token; positional and keyword arguments are ignored.

        Raises:
            IndexError: if the generator holds no tokens.
            ValueError: if every token was already served since the last
                ``collect``/``reset`` and a new one is required.
        """
        if self.last_chosen and self.current_repeat < self.repeat:
            # Serve the previous draw again. Requires a previous draw, which is
            # not the case on the first call or right after collect().
            self.current_repeat += 1
            previous = self.last_chosen[-1]
            self.last_chosen.append(previous)
            return self.tokens[previous]
        self.current_repeat = 0
        id_ = self._draw_unused_index()
        self.last_chosen.append(id_)
        return self.tokens[id_]

    def _draw_unused_index(self):
        """Pop the next index from the stack that is not in ``last_chosen``."""
        n_tokens = len(self.tokens)
        if n_tokens == 0:
            raise IndexError("pop from a TokenGenerator without tokens")
        used = set(self.last_chosen)
        if len(used) >= n_tokens:
            raise ValueError(
                f"all {n_tokens} tokens were already served since the last "
                "collect(); cannot draw another distinct token"
            )
        while True:
            if not self.idx:
                self.idx = list(range(n_tokens))
                shuffle(self.idx)
            # One pass over the current stack; used indices are recycled to
            # the bottom (position 0) so they come back on a later call.
            for _ in range(len(self.idx)):
                id_ = self.idx.pop()
                if id_ not in used:
                    return id_
                self.idx.insert(0, id_)
            # Everything left on the stack was used: start a new permutation,
            # which is guaranteed to contain an unused index.
            self.idx = []

    def collect(self):
        """Return the indices served since the last call and clear the record."""
        res = self.last_chosen
        self.last_chosen = []
        return res

    def reset(self):
        """Restore the initial in-order stack and forget all served indices."""
        self.idx = list(range(len(self.tokens)))[::-1]
        self.last_chosen = []
        self.current_repeat = 0


class PromptBatchGenerator:
    """Create batches of latent prompts whose placeholders are filled in.

    Args:
        tokens: Tokens the placeholders are replaced with (wrapped in a
            ``TokenGenerator``).
        token_placeholder: Token id to replace; defaults to the tokenizer's
            unk token id, evaluated when the class is defined.
        repeat: Forwarded to ``TokenGenerator``.
    """

    def __init__(
        self, tokens, token_placeholder: int = tokenizer.unk_token_id, repeat=0
    ):
        self.token_placeholder = token_placeholder
        self.token_generator = TokenGenerator(tokens, repeat=repeat)

    def batch(self, latent_prompt, num_prompts):
        """Return ``(batch, idx)`` for ``num_prompts`` copies of the prompt.

        ``batch`` is the result of ``replace_tokens`` on the stacked prompts
        and ``idx`` is what the token generator's ``collect()`` returns
        afterwards (presumably one index per replaced placeholder - this
        depends on how ``replace_tokens`` uses the generator).
        """
        copies = [latent_prompt for _ in range(num_prompts)]
        stacked = LatentPromptBatch.from_latent_prompts(copies, tokenizer)
        filled = stacked.replace_tokens(self.token_placeholder, self.token_generator)
        return filled, self.token_generator.collect()


def latent_prompting(nn_model, prompts, scan=True, *, generator, num_fs, **kwargs):
    """Run the latent prompt lens on the next ``len(prompts)`` validation latents.

    For every prompt, ``num_fs`` training activations (picked through the
    indices returned by ``generator.batch``) are followed by one validation
    activation. The position in the validation activations is tracked by the
    function attribute ``latent_prompting.idx``, which must be initialised by
    the caller and is advanced by ``len(prompts)`` on each call.

    Args:
        nn_model: Model forwarded to ``latent_prompt_lens``.
        prompts: Current batch of prompts; only its length is used here.
        scan: Forwarded to ``latent_prompt_lens``.
        generator: Object whose ``batch(latent_prompt, n)`` returns the prompt
            batch and the indices of the few-shot examples.
        num_fs: Number of few-shot examples per prompt.
        **kwargs: Forwarded to ``latent_prompt_lens``.

    Returns:
        Whatever ``latent_prompt_lens`` returns.
    """
    n_prompts = len(prompts)
    latent_batch, c_idx = generator.batch(fs_latent_prompt(num_fs), n_prompts)
    print(f"using c_idx {c_idx}")
    # Few-shot activations, selected along dim 1 of the training activations.
    few_shot_acts = fw_train_acts[:, c_idx]
    offset = latent_prompting.idx
    latents = []
    for i in range(n_prompts):
        if num_fs > 0:
            first = i * num_fs
            latents.append(few_shot_acts[:, first : first + num_fs])
        # The query latent of prompt i, with the indexed dim restored.
        latents.append(fw_val_acts[:, offset + i].unsqueeze(1))
    print(f"{len(latents)} latents, {len(prompts)} prompts")
    latent_prompting.idx += n_prompts
    return latent_prompt_lens(
        nn_model,
        latent_batch,
        latents=th.cat(latents, dim=1),
        scan=scan,
        remote=remote,
        **kwargs,
    )

In [None]:
from display_utils import plot_results


def fsp_plot(
    nfs,
    fs_token_name,
    fs_tokens,
    n=100,
    batch_size=batch_size,
    shuffle=False,
    **lp_lens_kwargs,
):
    """Run the few-shot latent prompting experiment, then save and plot it.

    Results go to ``results/<model_name>/<exp_name>/<nfs>_<fs_token_name>/``
    as ``<exp_id>.json`` and ``<exp_id>.png``. The function also resets
    ``latent_prompting.idx`` to 0 before running.

    Args:
        nfs: Number of few-shot examples per prompt.
        fs_token_name: Name of the label source; used in paths and labels.
        fs_tokens: Label tokens handed to ``PromptBatchGenerator``.
        n: Number of prompts (from the start of ``prompts``) to evaluate.
        batch_size: Batch size given to ``run_prompts``.
        shuffle: If true, the labels are shuffled (on a copy) before use.
        **lp_lens_kwargs: Extra options forwarded to ``latent_prompting`` and
            stored in the JSON file.

    Returns:
        ``(pr, lpr)`` as returned by ``run_prompts``, with the keys of ``lpr``
        prefixed by ``"<nfs>-<fs_token_name> "``.
    """
    if shuffle:
        # The boolean parameter `shuffle` hides random.shuffle inside this
        # function, so the function is imported under another name; calling
        # `shuffle(...)` here would call a bool and raise TypeError.
        from random import shuffle as shuffle_in_place

        fs_tokens = list(fs_tokens)
        shuffle_in_place(fs_tokens)
    generator = PromptBatchGenerator(fs_tokens)
    # Restart from the first validation activation for every experiment.
    latent_prompting.idx = 0
    pr, lpr = list(
        run_prompts(
            nn_model,
            prompts[:n],
            get_probs=latent_prompting,
            batch_size=batch_size,
            tqdm=tqdm,
            method_kwargs={"generator": generator, "num_fs": nfs, **lp_lens_kwargs},
        )
    )
    json_dic = {
        "next token": pr.tolist(),
        "num_fs": nfs,
        "kwargs": lp_lens_kwargs,
    }
    for label, probs in lpr.items():
        json_dic[label] = probs.tolist()
    path = Path("results") / model_name / exp_name / (f"{nfs}_{fs_token_name}")
    path.mkdir(parents=True, exist_ok=True)
    json_file = path / (exp_id + ".json")
    with open(json_file, "w") as f:
        json.dump(json_dic, f, indent=4)
    # Prefix the labels so curves of different configurations stay distinct.
    lpr = {f"{nfs}-{fs_token_name} {k}": v for k, v in lpr.items()}
    fig, ax = plt.subplots()
    plot_results(
        ax,
        pr,
        lpr,
        f"{nfs} few-shot - {fs_token_name} labels",
    )
    ax.legend()
    ax.set_ylabel("Probability")
    ax.set_xlabel("Layer")
    ax.set_title(f"{model_name} - {nfs}-{fs_token_name}")
    plt.savefig(path / (exp_id + ".png"))
    plt.show()
    return pr, lpr

## Run

In [None]:
# Few-shot label sources: the model's own predictions and, optionally, the
# ground-truth next tokens of the training samples.
pred_tokens = train_dataset[model_name].copy()
gt_tokens = train_dataset["gt"].copy()
# Few-shot counts and the batch size used with each of them (paired by
# position).
num_few_shot = [0, 5, 10, 50, 100, 500]
batch_sizes = [1024, 128, 128, 64, 32, 16]
fs_tokens = [pred_tokens]
if args.try_gt_labels:
    fs_tokens.append(gt_tokens)
# Results keyed by (number of few-shot examples, label source name).
lat_exp_probs = {}
for fs_token_name, fs_token in zip(["model pred", "gt"], fs_tokens):
    for nfs, bs in zip(num_few_shot, batch_sizes):
        lat_exp_probs[(nfs, fs_token_name)] = fsp_plot(
            nfs,
            fs_token_name,
            fs_token,
            n=100,
            batch_size=bs,
            shuffle=args.rnd_labels,
            **latent_kwargs,
        )

In [None]:
from display_utils import plot_ci
# Overlay the next-token curve of every configuration on one figure and save
# it next to the per-configuration result directories.
all_plot_path = Path("results") / model_name / exp_name / (exp_id + "_all.png")
fig, ax = plt.subplots()
for (nfs, fs_token_name), (pr, _) in lat_exp_probs.items():
    curve_label = f"{nfs}-{fs_token_name}"
    plot_ci(ax, pr, label=curve_label)
ax.legend()
ax.set_xlabel("Layer")
ax.set_ylabel("Probability")
ax.set_title(f"{model_name} - all")
plt.savefig(all_plot_path)
plt.show()