In [None]:
import torch
import numpy as np
import shutil
import re

from delphi.attention.__main__ import run
from delphi.config import RunConfig, CacheConfig, SamplerConfig, ConstructorConfig
from safetensors.torch import load_file, save_file
from pathlib import Path

# Move the working directory up one level; the functions in this notebook use
# paths relative to the working directory (e.g. "./results").
# NOTE(review): not idempotent - every re-run of this cell moves up one more
# level, so run it only once per kernel session.
%cd ..

In [None]:
def post_process_cache(base_path, mode, value, abs, normalize, n_latents):
    """Threshold (and optionally rescale) a cached activation store in place.

    Every ``*.safetensors`` file found in a sub-folder of
    ``base_path / "latents"`` is loaded, filtered latent by latent, and
    overwritten. The per-folder counts of surviving activations are written to
    ``base_path / "log" / "hookpoint_firing_counts.pt"``.

    The two integers in each file name are read as the first and last latent
    index covered by the file, and column 2 of ``locations`` is treated as the
    latent index relative to that first index.

    Args:
        base_path: Root folder of the cache (``str`` or ``Path``).
        mode: ``"threshold"`` to use ``value`` directly as the cutoff, or
            ``"percentile"`` to use the ``value``-th percentile of each
            latent's activations as that latent's cutoff.
        value: Cutoff or percentile, depending on ``mode``.
        abs: If true, take the absolute value of the activations; otherwise
            shift them so the smallest activation in the file becomes zero.
            (The name shadows the builtin; it is kept for existing callers.)
        normalize: If true, min-max rescale the kept activations of each
            latent to ``[0, 1]`` (all zeros when they are all equal).
        n_latents: Length of the per-folder firing-count vector.

    Raises:
        ValueError: If ``mode`` is neither ``"threshold"`` nor ``"percentile"``.
    """
    # Validate up front: otherwise an unknown mode leaves `threshold` unbound
    # (or silently reuses a stale one) after files may already have been rewritten.
    if mode not in ("threshold", "percentile"):
        raise ValueError(
            f"Unknown mode '{mode}': expected 'threshold' or 'percentile'."
        )

    base_path = Path(base_path)
    latents_path = base_path / "latents"
    log_path = base_path / "log" / "hookpoint_firing_counts.pt"
    firing_counts = dict()

    for folder in latents_path.iterdir():  # one folder per layer
        if not folder.is_dir():
            continue

        layer = folder.name
        firing_counts[layer] = torch.zeros(n_latents, dtype=torch.int64)

        for file in folder.glob("*.safetensors"):
            tensor = load_file(file)
            activations = tensor['activations'].cpu()
            locations = tensor['locations'].cpu()
            keep_mask_total = torch.zeros_like(activations, dtype=torch.bool)

            lat_start, lat_end = map(int, re.findall(r"\d+", file.name))
            latent_ids = locations[:, 2]

            if abs:
                activations = activations.abs()
            elif activations.numel() > 0:
                # shift so the minimum is zero; min() of an empty tensor raises
                activations = activations - activations.min()

            # Latents are disjoint, so each one can be thresholded and
            # normalised in the same pass without affecting the others.
            for i in range(lat_end - lat_start + 1):
                mask = latent_ids == i
                if not mask.any():
                    continue

                latent_activations = activations[mask]

                if mode == "threshold":
                    threshold = value
                else:
                    # Go through float32 in torch so dtypes numpy cannot
                    # represent (e.g. bfloat16) still work.
                    threshold = float(
                        np.percentile(
                            latent_activations.reshape(-1).float().numpy(), value
                        )
                    )

                latent_keep = latent_activations >= threshold
                keep_mask_total[mask] = latent_keep

                if normalize and latent_keep.any():
                    kept_mask = mask & keep_mask_total
                    kept = activations[kept_mask]
                    min_val = kept.min()
                    max_val = kept.max()
                    if max_val > min_val:
                        activations[kept_mask] = (kept - min_val) / (max_val - min_val)
                    else:
                        activations[kept_mask] = torch.zeros_like(kept)

            # drop everything below the cutoff
            kept_locations = locations.to(torch.int32)[keep_mask_total]
            tensor['activations'] = activations[keep_mask_total]
            tensor['locations'] = kept_locations

            # update the firing counts; index with int64 and shift the
            # file-relative latent ids back to layer-wide ids
            values, counts = torch.unique(kept_locations[:, 2], return_counts=True)
            firing_counts[layer][values.long() + lat_start] = counts

            # overwrite the safetensors file with the filtered tensors
            save_file({k: v.cpu() for k, v in tensor.items()}, file)

    # save log
    log_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(firing_counts, log_path)


def copy_folder(src, dst):
    """Replace ``./results/<dst>`` with a fresh copy of ``./results/<src>``.

    Anything already present at the destination is deleted first.

    Args:
        src: Name of the source folder, relative to ``./results``.
        dst: Name of the destination folder, relative to ``./results``.

    Raises:
        FileNotFoundError: If the source is missing or is not a directory.
        ValueError: If source and destination are the same folder or one
            contains the other.
    """
    base = Path("./results")

    src = base / src
    dst = base / dst

    if not src.is_dir():
        raise FileNotFoundError(f"Source folder '{src}' does not exist or is not a directory.")

    # The destination is wiped before copying, so overlapping paths would
    # destroy (part of) the source before it can be copied.
    src_resolved = src.resolve()
    dst_resolved = dst.resolve()
    if (
        src_resolved == dst_resolved
        or src_resolved in dst_resolved.parents
        or dst_resolved in src_resolved.parents
    ):
        raise ValueError(
            f"Source '{src}' and destination '{dst}' overlap; refusing to delete the source."
        )

    if dst.is_symlink() or dst.is_file():
        # rmtree only handles real directories
        dst.unlink()
    elif dst.exists():
        shutil.rmtree(dst)

    shutil.copytree(src, dst)


async def create_cache(settings):
    """Assemble a ``RunConfig`` from ``settings`` and await ``run`` with ``steps=2``.

    ``settings`` must provide the keys ``n_tokens``, ``n_splits``,
    ``subject_model``, ``explainer_model``, ``n_latents``, ``hookpoints`` and
    ``name``. The constructor and sampler configs use their defaults, BOS
    filtering and verbose output are switched on, and the seed is fixed to 22.
    """
    run_kwargs = dict(
        cache_cfg=CacheConfig(
            n_tokens=settings["n_tokens"],
            n_splits=settings["n_splits"],
        ),
        constructor_cfg=ConstructorConfig(),
        sampler_cfg=SamplerConfig(),
        model=settings["subject_model"],
        explainer_model=settings["explainer_model"],
        max_latents=settings["n_latents"],
        hookpoints=settings["hookpoints"],
        filter_bos=True,
        name=settings["name"],
        verbose=True,
        seed=22,
    )

    await run(RunConfig(**run_kwargs), steps=2)


async def validate(settings_run, settings):
    """Post-process a copy of a cache and await ``run`` with ``steps=3`` on it.

    The copy lives in ``./results/<name>``, where ``<name>`` is the source
    folder name followed by ``abs_`` / ``norm_`` tags (when those options are
    enabled), ``perc`` or ``thrd`` for the method, and the threshold value.

    Args:
        settings_run: Run settings; the keys ``n_tokens``, ``n_splits``,
            ``subject_model``, ``explainer_model``, ``n_latents`` and
            ``hookpoints`` are read.
        settings: Post-processing settings; the keys ``src``, ``abs``,
            ``normalize``, ``method`` and ``value`` are read.
    """
    # derive the name of the post-processed cache from the chosen options
    src_name = settings["src"]
    abs_tag = "abs_" if settings["abs"] else ""
    norm_tag = "norm_" if settings["normalize"] else ""
    method_tag = "perc" if settings["method"] == "percentile" else "thrd"
    name = f'{src_name}_{abs_tag}{norm_tag}{method_tag}_{settings["value"]}'

    # work on a copy so the source cache stays untouched
    base_path = Path("./results") / name
    copy_folder(src=settings["src"], dst=name)
    post_process_cache(
        base_path,
        settings["method"],
        settings["value"],
        settings["abs"],
        settings["normalize"],
        settings_run["n_latents"],
    )

    run_kwargs = dict(
        cache_cfg=CacheConfig(
            n_tokens=settings_run["n_tokens"],
            n_splits=settings_run["n_splits"],
        ),
        constructor_cfg=ConstructorConfig(),
        sampler_cfg=SamplerConfig(),
        model=settings_run["subject_model"],
        explainer_model=settings_run["explainer_model"],
        max_latents=settings_run["n_latents"],
        hookpoints=settings_run["hookpoints"],
        filter_bos=True,
        name=name,
        verbose=True,
        seed=22,
    )

    await run(RunConfig(**run_kwargs), steps=3)

In [None]:
# Settings used to build the RunConfig in create_cache() and validate().
settings_run = {
    "n_tokens": 1_000_000,                                 # how many tokens to cache
    "n_splits": 1,                                         # how many files the cache is split into
    "subject_model": "meta-llama/Llama-3.2-1B",            # HF name of subject model
    "explainer_model": "Qwen/Qwen2.5-32B-Instruct-AWQ",    # HF name of explainer/interpreter model
    "n_latents": 32,                                       # number of heads per layer (also the length of the firing-count log)
    "hookpoints": ["5"],                                   # which layers to cache
    "name": "1mil_attn"                                    # folder name of cache
}

# Settings for validate(): which cache under ./results to copy and how to filter it.
settings_post_process = {
    "src": "1mil_attn",                                    # folder under ./results holding the source cache
    "abs": False,                                           # True: absolute value; False: shift so the minimum is zero
    "normalize": False,                                     # min-max rescale kept activations of each latent to [0, 1]
    "method": "percentile",                                 # thresholding method: "percentile" or "threshold"
    "value": 99.99,                                         # percentile (0-100) for "percentile", raw cutoff for "threshold"
}

In [None]:
# Build the run config from settings_run and run delphi with steps=2.
await create_cache(settings_run)

In [None]:
# Copy the source cache, post-process the copy, then run delphi with steps=3 on it.
await validate(settings_run, settings_post_process)