<a href="https://colab.research.google.com/github/YanaySoker/Specificity_of_ROME/blob/main/automation_drag_cities.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<a href="https://colab.research.google.com/github/kmeng01/rome/blob/main/notebooks/causal_trace.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" align="left"/></a>&nbsp;or in a local notebook.

In [None]:
%%bash
# Colab guard: presumably exits early when the google.colab package directory
# is absent, so the steps below only run on Colab — TODO confirm the `!(...)` semantics.
!(stat -t /usr/local/lib/*/dist-packages/google/colab > /dev/null 2>&1) && exit
# Start from a fresh clone of the ROME repository under /content.
cd /content && rm -rf /content/rome
git clone https://github.com/kmeng01/rome rome > install.log 2>&1
# Install the requirements; all output is appended to install.log.
pip install -r /content/rome/scripts/colab_reqs/rome.txt >> install.log 2>&1
pip install --upgrade google-cloud-storage >> install.log 2>&1

In [None]:
%cd rome

/content/rome


In [None]:
%%writefile ./experiments/py/demo.py
# New demo.py
import os
from pathlib import Path
from typing import Dict, List, Tuple

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from baselines.ft import FTHyperParams, apply_ft_to_model
from rome import ROMEHyperParams, apply_rome_to_model
from util import nethook
from util.generate import generate_fast
from util.globals import *

def demo_model_editing(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    requests: List[Dict],
    generation_prompts: List[str],
    alg_name: str = "ROME",
) -> Tuple[AutoModelForCausalLM, Dict[str, torch.Tensor]]:
    """
    Edit `model` with the rewriting algorithm selected by `alg_name`.

    The hyper-parameter file is looked up under
    ``HPARAMS_DIR / <algorithm dir> / <model name with '/' -> '_'><suffix>.json``.

    :param model: Model to edit; gradients are enabled on it before editing.
    :param tok: Tokenizer handed to the editing method.
    :param requests: Rewrite requests handed to the editing method.
    :param generation_prompts: Accepted for interface compatibility; this
        version of the function does not use it (no text is generated).
    :param alg_name: Algorithm key understood by `load_alg`.
    :return: (1) the edited model, (2) whatever the editing method returns as
        the original weights (requested via ``return_orig_weights=True``).
    """
    print(requests, "\n")
    nethook.set_requires_grad(True, model)

    hparams_cls, edit_fn, hparams_subdir, fname_suffix = load_alg(alg_name)

    # Model names such as "EleutherAI/gpt-j-6B" contain a slash, which is
    # replaced so the name can be used as a file name.
    model_tag = model.config._name_or_path.replace("/", "_")
    hparams_path = HPARAMS_DIR / hparams_subdir / f"{model_tag}{fname_suffix}.json"
    hparams = hparams_cls.from_json(hparams_path)

    edited_model, original_weights = edit_fn(
        model, tok, requests, hparams, return_orig_weights=True
    )
    return edited_model, original_weights

def load_alg(alg_name):
    """
    Loads dependencies for the desired algorithm.
    Implementation is slightly awkward to prevent unnecessary imports on Colab.

    :param alg_name: One of "FT", "FT-L", "FT-AttnEdit", "KN", "MEND",
        "MEND-CF", "MEND-zsRE", "KE", "KE-CF", "ROME".
    :return: A tuple of the following:
        1. Class for storing hyperparameters
        2. Method for applying rewrites
        3. Name of the hyperparameter sub-directory
        4. Predefined suffix for the param file
    :raises AssertionError: if `alg_name` is not one of the supported names.
    """
    assert alg_name in [
        "FT",
        "FT-L",
        "FT-AttnEdit",
        "KN",
        "MEND",
        "MEND-CF",
        "MEND-zsRE",
        "KE",
        "KE-CF",
        "ROME",
    ]

    if alg_name == "ROME":
        return ROMEHyperParams, apply_rome_to_model, "ROME", ""
    elif "FT" in alg_name:
        d = {
            "FT": (FTHyperParams, apply_ft_to_model, "FT", "_unconstr"),
            "FT-AttnEdit": (FTHyperParams, apply_ft_to_model, "FT", "_attn"),
            "FT-L": (FTHyperParams, apply_ft_to_model, "FT", "_constr"),
        }
        return d[alg_name]
    else:
        # These baselines are imported lazily: they need extra dependencies
        # that are only installed on demand.
        from baselines.efk import EFKHyperParams, EfkRewriteExecutor
        from baselines.kn import KNHyperParams, apply_kn_to_model
        from baselines.mend import MENDHyperParams, MendRewriteExecutor

        d = {
            "KN": (KNHyperParams, apply_kn_to_model, "KN", ""),
            "MEND": (MENDHyperParams, MendRewriteExecutor().apply_to_model, "MEND", ""),
            "KE": (EFKHyperParams, EfkRewriteExecutor().apply_to_model, "KE", ""),
            "MEND-CF": (
                MENDHyperParams,
                MendRewriteExecutor().apply_to_model,
                "MEND",
                "_CF",
            ),
            "MEND-zsRE": (
                MENDHyperParams,
                MendRewriteExecutor().apply_to_model,
                "MEND",
                "_zsRE",
            ),
            # KE-CF is parsed with EFKHyperParams, so it must read from the
            # "KE" directory like plain "KE" does. It previously pointed at
            # "MEND", i.e. MEND hyperparameter files.
            "KE-CF": (
                EFKHyperParams,
                EfkRewriteExecutor().apply_to_model,
                "KE",
                "_CF",
            ),
        }
        return d[alg_name]

def print_loud(x, pad=3):
    """
    Print the string `x` inside a box of '#' characters, preceded by an empty line.

    `pad` is the total distance (in characters, border included) between the
    outer edge of the box and the text on each side.

    Example:
    ############################
    #                          #
    #  Applying ROME to model  #
    #                          #
    ############################
    """

    text_width = len(x)
    border = "#" * (text_width + 2 * pad)
    blank_row = "#" + " " * (text_width + 2 * (pad - 1)) + "#"
    margin = " " * (pad - 1)
    text_row = "#" + margin + x + margin + "#"

    print()
    for row in (border, blank_row, text_row, blank_row, border):
        print(row)

class StopExecution(Exception):
    """Exception raised by `stop_execution` to abort the current cell."""

    def _render_traceback_(self):
        # Presumably the IPython hook for custom traceback rendering; returning
        # None is meant to keep the abort quiet — TODO confirm against IPython docs.
        return None

def stop_execution():
    """Abort execution by raising `StopExecution`."""
    raise StopExecution()


Overwriting ./experiments/py/demo.py


In [None]:
%%writefile ./rome/rome_main.py
# New rome_main.py
from copy import deepcopy
from typing import Dict, List, Tuple

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from util import nethook
from util.generate import generate_fast

from .compute_u import compute_u
from .compute_v import compute_v
from .rome_hparams import ROMEHyperParams

CONTEXT_TEMPLATES_CACHE = None

def apply_rome_to_model(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    requests: List[Dict],
    hparams: ROMEHyperParams,
    copy=False,
    return_orig_weights=False,
) -> Tuple[AutoModelForCausalLM, List[str]]:
    """
    Apply every request in `requests` to the model, one after another.

    :param copy: If true, the edits are made on a deep copy and the model passed
        in is left untouched. The caller is responsible for freeing the copy.
    :param return_orig_weights: If true, a clone of each edited weight is taken
        before its first modification.

    :return: (1) the edited model, (2) a dict mapping weight name to the
        pre-edit clone of that weight (empty unless `return_orig_weights` is set).
    """

    target_model = deepcopy(model) if copy else model
    original_weights = {}

    for request_idx, request in enumerate(requests):
        rank_one_factors = execute_rome(target_model, tok, request, hparams)

        with torch.no_grad():
            for param_name, (u_vec, v_vec) in rank_one_factors.items():
                # Outer product u v^T, transposed if needed to fit the weight.
                outer = u_vec.unsqueeze(1) @ v_vec.unsqueeze(0)
                param = nethook.get_parameter(target_model, param_name)
                outer = upd_matrix_match_shape(outer, param.shape)

                first_touch = param_name not in original_weights
                if return_orig_weights and first_touch:
                    # A weight may only be seen for the first time during the
                    # first request; otherwise the clone would not be pristine.
                    assert request_idx == 0
                    original_weights[param_name] = param.detach().clone()

                param[...] += outer

    return target_model, original_weights

def execute_rome(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    request: Dict,
    hparams: ROMEHyperParams,
) -> Dict[str, Tuple[torch.Tensor]]:
    """
    Executes the ROME update algorithm for the specified update at the specified layer
    Invariant: model at beginning of function == model at end of function

    The model weights are modified in place while the per-layer updates are
    computed, and are restored from a saved copy before returning.

    :param request: Rewrite request; a private deep copy is made, so the
        caller's dict is not modified. Reads ``request["target_new"]["str"]``.
    :param hparams: Provides `layers`, `rewrite_module_tmp` and
        `context_template_length_params`.
    :return: Dict mapping weight name to the detached
        ``(left_vector, right_vector)`` pair computed for that layer.
    """

    # Update target and print info
    request = deepcopy(request)
    # NOTE(review): indexing [0] raises IndexError for an empty target string.
    if request["target_new"]["str"][0] != " ":
        # Space required for correct tokenization
        request["target_new"]["str"] = " " + request["target_new"]["str"]

    # Retrieve weights that user desires to change
    weights = {
        f"{hparams.rewrite_module_tmp.format(layer)}.weight": nethook.get_parameter(
            model, f"{hparams.rewrite_module_tmp.format(layer)}.weight"
        )
        for layer in hparams.layers
    }
    # Save old weights for future restoration
    weights_copy = {k: v.detach().clone() for k, v in weights.items()}

    # Update loop: sequentially intervene at each specified layer
    deltas = {}
    for layer in sorted(hparams.layers):
        # Compute rank-1 update matrix
        # (the context templates are generated once and then served from a cache)
        left_vector: torch.Tensor = compute_u(
            model,
            tok,
            request,
            hparams,
            layer,
            get_context_templates(model, tok, hparams.context_template_length_params),
        )
        right_vector: torch.Tensor = compute_v(
            model,
            tok,
            request,
            hparams,
            layer,
            left_vector,
            get_context_templates(model, tok, hparams.context_template_length_params),
        )

        with torch.no_grad():
            # Determine correct transposition of delta matrix
            weight_name = f"{hparams.rewrite_module_tmp.format(layer)}.weight"
            upd_matrix = left_vector.unsqueeze(1) @ right_vector.unsqueeze(0)
            upd_matrix = upd_matrix_match_shape(upd_matrix, weights[weight_name].shape)

            # Update model weights and record desired changes in `delta` variable
            # (later layers are therefore computed with this layer already edited)
            weights[weight_name][...] += upd_matrix
            deltas[weight_name] = (
                left_vector.detach(),
                right_vector.detach(),
            )

    # Restore state of original model
    with torch.no_grad():
        for k, v in weights.items():
            v[...] = weights_copy[k]

    return deltas

def upd_matrix_match_shape(matrix: torch.Tensor, shape: torch.Size) -> torch.Tensor:
    """
    Return `matrix` oriented so that its shape equals `shape`.

    GPT-2 and GPT-J have transposed weight representations, so the update is
    accepted either as-is or transposed.

    :raises ValueError: if neither `matrix` nor its transpose has shape `shape`.
    """

    if matrix.shape == shape:
        return matrix

    transposed = matrix.T
    if transposed.shape == shape:
        return transposed

    raise ValueError(
        "Update matrix computed by ROME does not match original weight shape. "
        "Check for bugs in the code?"
    )

def get_context_templates(model, tok, length_params):
    """
    Return the list of context templates, generating it on the first call.

    The first template is the bare ``"{}"``. Every other template is a text
    generated from ``"<|endoftext|>"`` followed by ``". {}"``. `length_params`
    is iterated as ``(max_out_len, n_gen_per_prompt)`` pairs.

    The result is stored in the module-level CONTEXT_TEMPLATES_CACHE, so later
    calls return the cached list and ignore their arguments.
    """
    global CONTEXT_TEMPLATES_CACHE

    if CONTEXT_TEMPLATES_CACHE is not None:
        return CONTEXT_TEMPLATES_CACHE

    generated = []
    for length, n_gen in length_params:
        generated = generated + generate_fast(
            model,
            tok,
            ["<|endoftext|>"],
            n_gen_per_prompt=n_gen,
            max_out_len=length,
        )

    CONTEXT_TEMPLATES_CACHE = ["{}"] + [text + ". {}" for text in generated]
    print(f"Cached context templates {CONTEXT_TEMPLATES_CACHE}")

    return CONTEXT_TEMPLATES_CACHE


Overwriting ./rome/rome_main.py


In [None]:
# Environment flags. IS_COLAB is switched on below when google.colab imports.
# ALL_DEPS records whether the extra MEND/KE dependencies have been installed.
IS_COLAB = False
ALL_DEPS = False
try:
    # If any of these imports is missing, ModuleNotFoundError is swallowed
    # below and the flags keep their defaults.
    import google.colab, torch, os

    IS_COLAB = True
    os.chdir("/content/rome")
    if not torch.cuda.is_available():
        # Not a ModuleNotFoundError, so this propagates and stops the cell.
        raise Exception("Change runtime type to include a GPU.")
except ModuleNotFoundError as _:
    pass

## Causal Tracing

A demonstration of the double-intervention causal tracing method.

The strategy used by causal tracing is to understand important
states within a transformer by doing two interventions simultaneously:

1. Corrupt a subset of the input.  In our paper, we corrupt the subject tokens
   to frustrate the ability of the transformer to accurately complete factual
   prompts about the subject.
2. Restore a subset of the internal hidden states.  In our paper, we scan
   hidden states at all layers and all tokens, searching for individual states
   that carry the necessary information for the transformer to recover its
   capability to complete the factual prompt.

The traces of decisive states can be shown on a heatmap.  This notebook
demonstrates the code for conducting causal traces and creating these heatmaps.

In [None]:
%load_ext autoreload
%autoreload 2

The `experiments.causal_trace` module contains a set of functions for running causal traces.

In this notebook, we reproduce, demonstrate and discuss the interesting functions.

We begin by importing several utility functions that deal with tokens and transformer models.

In [None]:
# from rome file
from transformers import AutoModelForCausalLM, AutoTokenizer
from util.generate import generate_interactive, generate_fast

from experiments.py.demo import demo_model_editing, stop_execution

In [None]:
import os, re, json
import torch, numpy
from collections import defaultdict
from util import nethook
from util.globals import DATA_DIR
from experiments.causal_trace import (
    ModelAndTokenizer,
    layername,
    guess_subject,
    plot_trace_heatmap,
)
from experiments.causal_trace import (
    make_inputs,
    decode_tokens,
    find_token_range,
    # predict_token,
    predict_from_input,
    collect_embedding_std,
)
from dsets import KnownsDataset

torch.set_grad_enabled(True)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fe89da64160>

In [None]:
import random
# Seed Python, NumPy and PyTorch with the same value; `_seed` is reused by
# later cells that re-seed before each model edit.
_seed = 1
random.seed(_seed)
numpy.random.seed(seed=_seed)
torch.manual_seed(_seed)

<torch._C.Generator at 0x7fe89d0f3830>

Now we load a model and tokenizer, and show that it can complete a couple factual statements correctly.

In [None]:
# Load the model and its tokenizer into a ModelAndTokenizer wrapper.
model_name = "gpt2-xl"  # or "EleutherAI/gpt-j-6B" or "EleutherAI/gpt-neox-20b"
mt = ModelAndTokenizer(
    model_name,
    low_cpu_mem_usage=IS_COLAB,
    # Half precision is requested only when the model name contains "20b".
    torch_dtype=(torch.float16 if "20b" in model_name else None),
)

Downloading:   0%|          | 0.00/689 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/0.99M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.29M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/5.99G [00:00<?, ?B/s]

To obfuscate the subject during Causal Tracing, we use noise sampled from a zero-centered spherical Gaussian whose standard deviation is 3 times $\sigma$, the standard deviation of the model's embeddings. Let's compute that value.

In [None]:
knowns = KnownsDataset(DATA_DIR)  # Dataset of known facts
# Noise level = 3 x the value returned by collect_embedding_std for the
# subjects in the dataset.
noise_level = 3 * collect_embedding_std(mt, [k["subject"] for k in knowns])
print(f"Using noise level {noise_level}")

data/known_1000.json does not exist. Downloading from https://rome.baulab.info/data/dsets/known_1000.json


  0%|          | 0.00/335k [00:00<?, ?B/s]

Loaded dataset with 1209 elements
Using noise level 0.13462981581687927


## Tracing a single location

The core intervention in causal tracing is captured in this function:

`trace_with_patch` runs a single causal trace.

It enables running a batch of inferences with two interventions.

  1. Random noise can be added to corrupt the inputs of some of the batch.
  2. At any point, clean non-noised state can be copied over from an
     uncorrupted batch member to other batch members.
  
The convention used by this function is that the zeroth element of the
batch is the uncorrupted run, and the subsequent elements of the batch
are the corrupted runs.  The argument tokens_to_mix specifies a
range of tokens to be corrupted by adding Gaussian noise to the embedding for the batch
inputs other than the first element in the batch.  Alternately,
subsequent runs could be corrupted by simply providing different
input tokens via the passed input batch.

To ensure that corrupted behavior is representative, in practice, we
will actually run several (ten) corrupted runs in the same batch,
each with its own sample of noise.

Then when running, a specified set of hidden states will be uncorrupted
by restoring their values to the same vector that they had in the
zeroth uncorrupted run.  This set of hidden states is listed in
states_to_patch, by listing [(token_index, layername), ...] pairs.
To trace the effect of just a single state, this can be just a single
token/layer pair.  To trace the effect of restoring a set of states,
any number of token indices and layers can be listed.

Note that this function is also in experiments.causal_trace; the code
is shown here to show the logic.

In [None]:
def trace_with_patch(
    model,  # The model
    inp,  # A set of inputs
    states_to_patch,  # A list of (token index, layername) pairs to restore
    answers_t,  # Answer probabilities to collect
    tokens_to_mix,  # Range of tokens to corrupt (begin, end)
    noise=0.1,  # Level of noise to add
    trace_layers=None,  # List of traced outputs to return
):
    """
    Run one batched forward pass with corrupted inputs and selectively
    restored hidden states.

    Batch element 0 is the clean run; elements 1.. receive Gaussian noise on
    the embeddings of tokens ``tokens_to_mix`` (unless it is None). For every
    ``(token index, layername)`` in `states_to_patch`, the output of that layer
    at that token is overwritten in the corrupted runs with the clean value.

    Returns the softmax probability of `answers_t` at the last position,
    averaged over the corrupted runs. When `trace_layers` is given, also
    returns the stacked outputs of those layers.
    """
    # NOTE: RandomState() is created without a seed, so the noise differs from
    # call to call; pass a fixed seed here if reproducible noise is required.
    prng = numpy.random.RandomState()  ### Unseeded pseudorandom noise generator
    # Group the requested token indices by layer name.
    patch_spec = defaultdict(list)
    for t, l in states_to_patch:
        patch_spec[l].append(t)
    embed_layername = layername(model, 0, "embed")

    def untuple(x):
        return x[0] if isinstance(x, tuple) else x

    # Define the model-patching rule.
    def patch_rep(x, layer):
        if layer == embed_layername:
            # If requested, we corrupt a range of token embeddings on batch items x[1:]
            if tokens_to_mix is not None:
                b, e = tokens_to_mix
                x[1:, b:e] += noise * torch.from_numpy(
                    prng.randn(x.shape[0] - 1, e - b, x.shape[2])
                ).to(x.device)
            return x
        if layer not in patch_spec:
            return x
        # If this layer is in the patch_spec, restore the uncorrupted hidden state
        # for selected tokens.
        h = untuple(x)
        for t in patch_spec[layer]:
            h[1:, t] = h[0, t]
        return x

    # With the patching rules defined, run the patched model in inference.
    additional_layers = [] if trace_layers is None else trace_layers
    with torch.no_grad(), nethook.TraceDict(
        model,
        [embed_layername] + list(patch_spec.keys()) + additional_layers,
        edit_output=patch_rep,
    ) as td:
        outputs_exp = model(**inp)

    # We report softmax probabilities for the answers_t token predictions of interest.
    # Only the corrupted runs (batch items 1:) enter the average.
    probs = torch.softmax(outputs_exp.logits[1:, -1, :], dim=1).mean(dim=0)[answers_t]

    # If tracing all layers, collect all activations together to return.
    if trace_layers is not None:
        all_traced = torch.stack(
            [untuple(td[layer].output).detach().cpu() for layer in trace_layers], dim=2
        )
        return probs, all_traced

    return probs

## Scanning all locations

A causal flow heatmap is created by repeating `trace_with_patch` at every individual hidden state, and measuring the impact of restoring state at each location.

The `calculate_hidden_flow` function does this loop.  It handles both the case of restoring a single hidden state, and also restoring MLP or attention states.  Because MLP and attention make small residual contributions, to observe a causal effect in those cases, we need to restore several layers of contributions at once, which is done by `trace_important_window`.

In [None]:
def calculate_hidden_flow(
    mt, prompt, subject, samples=10, noise=0.1, window=10, kind=None
):
    """
    Trace every token/layer combination of the network for one prompt and
    return a dictionary summarizing the results.

    The batch holds ``samples + 1`` copies of `prompt`. With a falsy `kind`,
    single hidden states are restored (`trace_important_states`); otherwise
    windows of `window` layers of that kind are restored
    (`trace_important_window`).

    The returned dict has the keys ``scores``, ``low_score``, ``high_score``,
    ``input_ids``, ``input_tokens``, ``subject_range``, ``answer``, ``window``
    and ``kind``.
    """
    inp = make_inputs(mt.tokenizer, [prompt] * (samples + 1))
    with torch.no_grad():
        answer_t, base_score = (d[0] for d in predict_from_input(mt.model, inp))
    [answer] = decode_tokens(mt.tokenizer, [answer_t])
    e_range = find_token_range(mt.tokenizer, inp["input_ids"][0], subject)

    # Baseline: subject corrupted, nothing restored.
    corrupted = trace_with_patch(mt.model, inp, [], answer_t, e_range, noise=noise)
    low_score = corrupted.item()

    if kind:
        differences = trace_important_window(
            mt.model,
            mt.num_layers,
            inp,
            e_range,
            answer_t,
            noise=noise,
            window=window,
            kind=kind,
        )
    else:
        differences = trace_important_states(
            mt.model, mt.num_layers, inp, e_range, answer_t, noise=noise
        )

    return {
        "scores": differences.detach().cpu(),
        "low_score": low_score,
        "high_score": base_score,
        "input_ids": inp["input_ids"][0],
        "input_tokens": decode_tokens(mt.tokenizer, inp["input_ids"][0]),
        "subject_range": e_range,
        "answer": answer,
        "window": window,
        "kind": kind or "",
    }


def trace_important_states(model, num_layers, inp, e_range, answer_t, noise=0.1):
    """
    Restore one hidden state at a time and record the resulting probability.

    Returns a tensor stacked as (token, layer), where each entry is the value
    `trace_with_patch` returns when only that token/layer state is restored.
    """
    ntoks = inp["input_ids"].shape[1]

    def restore_single(tnum, layer):
        return trace_with_patch(
            model,
            inp,
            [(tnum, layername(model, layer))],
            answer_t,
            tokens_to_mix=e_range,
            noise=noise,
        )

    per_token = [
        torch.stack([restore_single(tnum, layer) for layer in range(num_layers)])
        for tnum in range(ntoks)
    ]
    return torch.stack(per_token)


def trace_important_window(
    model, num_layers, inp, e_range, answer_t, kind, window=10, noise=0.1
):
    """
    Like `trace_important_states`, but restores a window of layers of the
    given `kind` around each layer instead of a single state.

    For a center layer c the window is
    ``range(max(0, c - window // 2), min(num_layers, c - (-window // 2)))``,
    i.e. it is clipped at the first and last layer.
    Returns a tensor stacked as (token, center layer).
    """
    ntoks = inp["input_ids"].shape[1]

    def restore_window(tnum, center):
        first = max(0, center - window // 2)
        stop = min(num_layers, center - (-window // 2))
        patches = [(tnum, layername(model, L, kind)) for L in range(first, stop)]
        return trace_with_patch(
            model, inp, patches, answer_t, tokens_to_mix=e_range, noise=noise
        )

    per_token = [
        torch.stack([restore_window(tnum, center) for center in range(num_layers)])
        for tnum in range(ntoks)
    ]
    return torch.stack(per_token)

## Plotting the results

The `plot_trace_heatmap` function draws the data on a heatmap.  That function is not shown here; it is in `experiments.causal_trace`.


In [None]:
## check (original note: "bdika")
# Earlier variant that loaded a separate model together with the tokenizer:
# model, tok = (
#     AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=IS_COLAB).to(
#         "cuda"
#     ),
#     AutoTokenizer.from_pretrained(model_name),
# )
# Only the tokenizer is loaded here; the EOS token doubles as the padding token.
tok = AutoTokenizer.from_pretrained(model_name)
tok.pad_token = tok.eos_token

In [None]:
# from rome file

ALG_NAME = "ROME"
# Colab-only: install deps for MEND* and KE*
if IS_COLAB and not ALL_DEPS and any(x in ALG_NAME for x in ["MEND", "KE"]):
    print("Installing additional dependencies required for MEND and KE")
    !pip install -r /content/rome/scripts/colab_reqs/additional.txt >> /content/install.log 2>&1
    print("Finished installing")
    ALL_DEPS = True

import copy
# `model` is an alias for the model held by `mt` (no copy is made).
model = mt.model
# `mt2` is a deep copy of the whole wrapper. NOTE(review): a later cell sets
# `mt2.model = model`, so the model copy made here is replaced by the alias —
# confirm whether the deep copy of the model itself is needed.
mt2 = copy.deepcopy(mt)

In [None]:
def plot_hidden_flow(
    mt,
    prompt,
    subject=None,
    samples=10,
    noise=0.1,
    window=10,
    kind=None,
    modelname=None,
    savepdf=None,
):
    """
    Compute the causal trace for `prompt`, print the raw result and draw it
    as a heatmap.

    When `subject` is None it is derived from the prompt with `guess_subject`.
    `savepdf` and `modelname` are forwarded to `plot_trace_heatmap`.
    """
    resolved_subject = guess_subject(prompt) if subject is None else subject
    flow = calculate_hidden_flow(
        mt,
        prompt,
        resolved_subject,
        samples=samples,
        noise=noise,
        window=window,
        kind=kind,
    )
    print("result:\n", flow)
    plot_trace_heatmap(flow, savepdf, modelname=modelname)


def plot_all_flow(mt, prompt, subject=None, noise=0.1, modelname=None):
    """Plot the causal trace for each kind in the list; currently only "mlp"."""
    # Single-entry list of kinds, so one direct call replaces the loop.
    plot_hidden_flow(
        mt, prompt, subject, modelname=modelname, noise=noise, kind="mlp"
    )

In [None]:
# Point mt2 at the same model object as `model` (an alias of mt.model), so
# edits made through mt2 act on that shared model.
mt2.model = model

# **New Code**

In [None]:
# NEW_REMOTE_ROOT_URL = "https://rome.baulab.info"
# NEW_REMOTE_URL = f"{NEW_REMOTE_ROOT_URL}/data/dsets/zsre_mend_eval.json"

In [None]:
 import urllib, json

In [None]:
# `import urllib` alone does not load the `request` submodule; import it
# explicitly so this cell does not depend on another library having done so.
import urllib.request

counterfacts_url = "https://rome.baulab.info/data/dsets/counterfact.json"
# Download and parse the CounterFact dataset; the context manager closes the
# HTTP response once the body has been read.
with urllib.request.urlopen(counterfacts_url) as response:
    data = json.loads(response.read())

# Sanity check: show the subject of the first record.
print(data[0]["requested_rewrite"]['subject'])

Danielle Darrieux


## **Trials**

# **process 1**

In [None]:
# Shared state across edits: "model_new" holds the most recently edited model,
# and "orig_weights" holds the weights saved by the first edit, which `clean`
# uses to restore the model.
M = dict()


In [None]:
## predict_token
def predict_all_from_input(model, inp):
    """
    Run `model` on `inp` and return the softmax over the logits at the last
    position of every batch element.
    """
    logits = model(**inp)["logits"]
    last_position = logits[:, -1]
    return torch.softmax(last_position, dim=1)

def predict_token(mt, prompts, return_p=False, return_idx = False):
    """
    Predict the next token for each prompt and return the decoded tokens.

    With `return_p` the result is ``(decoded, p)``, where `p` is the second
    value returned by `predict_from_input`. Otherwise, with `return_idx`, it is
    ``(decoded, preds[0])``, i.e. the prediction for the first prompt only.
    `return_p` takes precedence when both flags are set.
    """
    inp = make_inputs(mt.tokenizer, prompts)
    preds, p = predict_from_input(mt.model, inp)
    decoded = [mt.tokenizer.decode(c) for c in preds]

    if return_p:
        return decoded, p
    if return_idx:
        return decoded, preds[0]
    return decoded

def predict_by_idx(mt, prompt, idx):
  """
  Return, as a Python number, the next-token probability that the model
  assigns to vocabulary entry `idx` after `prompt`.

  `idx` must select exactly one element (an int or a one-element list),
  because `Tensor.item()` only converts single-element tensors.
  """
  inp = make_inputs(mt.tokenizer, [prompt])
  preds = predict_all_from_input(mt.model, inp)
  # `.item()` must be called: returning the bound method breaks all later
  # arithmetic on the probabilities.
  return preds[0][idx].item()

In [None]:
def naiv_predict(word, prompt, return_idx = False):
  """
  Predict the token that follows "<word> <prompt>" with the global `mt2`.

  The first character of the decoded token is dropped (presumably the leading
  space of the token — TODO confirm). With `return_idx`, the result is
  ``(token_text, idx)``, where `idx` is the second item from `predict_token`.
  """
  outcome = predict_token(
    mt2,
    [f"{word} {prompt}"],
    return_p=False,
    return_idx = return_idx
  )

  if not return_idx:
    first_decoded = outcome[0]
    return first_decoded[1:]

  decoded_list = outcome[0]
  return decoded_list[0][1:], outcome[1]
  

def predict(word, prompt, count=0, return_idx = False):
  """
  Predict the object that follows "<word> <prompt>", looking one token past a
  filler word.

  If the first predicted token is not a filler ("the", "state", "of", ...), it
  is returned as is. Otherwise the filler is appended to the prompt and one
  more token is predicted; the result is "[<filler>] <next token>". If that
  second prediction fails, or `count` equals 8, only "[<filler>]" is returned.

  `count` is only compared with 8 and is never changed by this function.
  With `return_idx`, the result is ``(text, idx)``, where `idx` is a plain
  Python number taken from the most recent successful prediction.
  """
  fillers = ["the", "state", "State", "of", "Republic", "province", "Province"]

  def next_token(current_prompt):
    # Normalize both call shapes of naiv_predict to (text, idx-or-None).
    if return_idx:
      return naiv_predict(word, current_prompt, return_idx)
    return naiv_predict(word, current_prompt, return_idx), None

  def finish(text, token_idx):
    # `.item()` must be called; it was previously returned as a bound method.
    return (text, token_idx.item()) if return_idx else text

  next_tok, idx = next_token(prompt)
  if next_tok not in fillers:
    return finish(next_tok, idx)

  prompt = prompt + " " + next_tok
  if count==8:
    return finish(f"[{next_tok}]", idx)

  try:
    next_next, idx = next_token(prompt)
  except Exception:
    # Best effort: fall back to the filler alone. The bare `except:` was
    # narrowed so KeyboardInterrupt/SystemExit are no longer swallowed.
    return finish(f"[{next_tok}]", idx)

  return finish(f"[{next_tok}] {next_next}", idx)

In [None]:
# Animal names passed as the `affected` list by `drag_animals`.
animals = ["grizzly", "poodle", "terrier", "collie", "border collie", "Schnauzer", "bird", "sparrow", "pale rockfinch", "corvus", "jackdaw", "magpie-jay", "european goldfinch", "chaffinch", 
           "pine grosbeak", "carpornis", "atlantic royal flycatcher","pacific royal flycatcher","northern royal flycatcher", "pigeon", "parrot", "cockatiel", "eagle", "owl", "penguin", "chameleon"]


In [None]:
# City names passed as the `affected` list by `drag_cities`.
cities = ["Paris", "Bangkok", "Stockholm", "Moscow", "Bucharest", "Kigali", "Zagreb", "Nicosia", "Nairobi", "Ottawa", "Phnom Penh", "Bishkek", "Doha", "Seoul", "Havana", "Prague", "Lima", "Islamabad", "Port Moresby", "Helsinki", "Suva", "Lisbon", "Warsaw", "San Juan", "Riyadh", "Baghdad", "Muscat", "Belgrade", "Madrid", "Dakar", "Bratislava", "Ljubljana", "Freetown", "Damascus", "Mogadishu", "Khartoum", "Kathmandu", "Managua", "Niamey", "Wellington", "Abuja", "Kingston", "Oslo", "Rabat", "Skopje", "Cairo", "Kyiv", "Montevideo","Abu Dhabi", "Tehran", "Buenos Aires", "Berlin", "Amsterdam", "Astana", "Naypyidaw", "Lilongwe", "Kuala Lumpur","Ulaanbaatar", "Bamako", "Nouakchott", "Vilnius", "Monrovia", "Riga", "Tripoli", "Beirut", "Jerusalem"]

In [None]:
# Maps an edited subject to the neighbor subjects whose probabilities are
# compared before and after the edit (see `neighbors_probs`).
# NOTE: only "Paris" and "Nicosia" have entries; looking up any other subject
# raises KeyError.
neighborhood = {
    "Paris": ["Bangkok", "Stockholm", "Moscow", "Bucharest", "Kigali", "Zagreb"],
    "Nicosia": ["Nairobi", "Ottawa", "Phnom Penh", "Bishkek", "Doha", "Seoul", "Havana"]
}

In [None]:
def neighbors_probs(subject, prompt, ngbr_dict, target_idx):
  """
  For every neighbor of `subject` listed in `ngbr_dict`, return what
  `predict_by_idx` yields for `target_idx` on "<neighbor> <prompt>"
  (evaluated with the global `mt2`).

  Raises KeyError when `subject` has no entry in `ngbr_dict`.
  """
  return [
    predict_by_idx(mt2, f"{neighbor} {prompt}", target_idx)
    for neighbor in ngbr_dict[subject]
  ]

def neighboring(probs1, probs2):
  """
  Mean similarity between two equally long lists of probabilities.

  For each position the score is ``1 - |p1 - p2| / (0.5 + |p1 - 0.5|)``,
  so identical values score 1. The length is taken from `probs1`;
  an empty `probs1` raises ZeroDivisionError.
  """
  m = len(probs1)
  scores = []
  for i, before in enumerate(probs1):
    shift = abs(before - probs2[i])
    scale = 0.5 + abs(before - 0.5)
    scores.append(1 - shift / scale)
  return sum(scores) / m

In [None]:
def plot_all_flow(mt, prompt, subject=None, noise=0.1, modelname=None):
    """Plot the hidden-flow heatmap for each kind in the tuple; only "mlp" is listed."""
    kinds = ("mlp",)
    for kind in kinds:
        plot_hidden_flow(
            mt,
            prompt,
            subject,
            modelname=modelname,
            noise=noise,
            kind=kind,
        )

In [None]:
import math
def return_map(
    prompt,
    subject,
    mt=mt2,
    samples=10,
    noise=noise_level,
    window=10,
    kind="mlp",
    modelname=None,
    savepdf=None,
):
    """
    Return the causal-trace result dict of `calculate_hidden_flow`, without plotting.

    The defaults for `mt` and `noise` are bound to the globals `mt2` and
    `noise_level` when the function is defined. When `subject` is None it is
    derived with `guess_subject`. `modelname` and `savepdf` are accepted but unused.
    """
    resolved_subject = guess_subject(prompt) if subject is None else subject
    return calculate_hidden_flow(
        mt,
        prompt,
        resolved_subject,
        samples=samples,
        noise=noise,
        window=window,
        kind=kind,
    )


def generate_city_prompt(city):
  """
  Build the capital-city prompt for `city`.

  Starts from "is the capital city of" and keeps appending the predicted next
  word for as long as that word is a filler ("the", "state", "of", ...).
  Returns the prompt once the prediction is no longer a filler.
  """
  fillers = ["the", "state", "State", "of", "Republic", "province", "Province"]
  prompt = "is the capital city of"
  while True:
    word = naiv_predict(city, prompt)
    if word not in fillers:
      return prompt
    prompt = prompt + " " + word


def entropy(tens):
    """
    Return the base-2 entropy of `tens`, normalized by ``log2(len(tens))``.

    `tens` is first divided by its sum; entries whose log is -inf (zeros)
    contribute nothing. A uniform tensor gives 1.0.
    """
    probs = tens / tens.sum()
    log_probs = torch.log2(probs)
    # log2(0) is -inf; replace it with 0 so 0 * -inf does not become NaN.
    log_probs = torch.where(log_probs == float("-inf"), 0, log_probs)
    weighted_sum = (log_probs * probs).sum().item()
    return -weighted_sum / math.log2(len(tens))


def max_layer_and_entropy(prompt, subject, max_neighbors=(1,), effect_idx=17):
  """
  Summarize the causal trace of `prompt` around its strongest subject-token state.

  The trace comes from `return_map`. Among the subject tokens, the token/layer
  with the highest score is located, and statistics are taken from that
  token's row of scores.

  :param max_neighbors: Layer offsets relative to the peak layer. For each
      offset i the relative drop ``(row[peak] - row[peak + i]) / row[peak]`` is
      reported, or -1 when ``peak + i`` is not a valid layer index.
  :param effect_idx: Layer index whose score in the peak token's row is
      returned as the last item.
  :return: ``(layer, entropy, cent, max, min, average, effect)`` where `layer`
      is the peak layer, `entropy` is `entropy(row)`, `cent` is the list of
      relative drops, `max`/`min` are taken over all subject-token scores, and
      `average` is the mean of the peak token's row.
  """
  result = return_map(prompt, subject)
  scores = result['scores']
  a, b = result['subject_range']
  n_layers = len(scores[0])

  subject_scores = scores[a:b]
  # argmax() indexes the flattened (token, layer) block of subject scores.
  flat_argmax = subject_scores.argmax().item()
  relevant_token_idx = flat_argmax // n_layers + a
  layer = flat_argmax % n_layers
  relevant_token = scores[relevant_token_idx]

  _max = subject_scores.max().item()
  _min = subject_scores.min().item()
  avrg = relevant_token.sum().item() / n_layers
  eff = relevant_token[effect_idx].item()

  peak = relevant_token[layer]
  cent = []
  for offset in max_neighbors:
    neighbor_layer = layer + offset
    # Compare against the layer at the requested offset (previously always
    # layer + 1), and treat offsets past either end as "not available".
    if 0 <= neighbor_layer < n_layers:
      cent.append(((peak - relevant_token[neighbor_layer]) / peak).item())
    else:
      cent.append(-1)
  return layer, entropy(relevant_token), cent, _max, _min, avrg, eff


In [None]:
  
def clean():
  """
  Undo earlier edits by copying the weights saved in ``M["orig_weights"]``
  back into `mt2.model`. Prints a message either way.
  """
  if "orig_weights" not in M:
      print("No model weights to restore")
      return

  with torch.no_grad():
      for weight_name, saved in M["orig_weights"].items():
          nethook.get_parameter(mt2.model, weight_name)[...] = saved
  print("Original model restored")

def change_and_chack(_subject, prompt, targets, affected, set_affected=None, count_flag=False):
  """
  Apply a sequence of edits to `_subject` and score, for each edit, how much
  the neighbors' probabilities of the target moved.

  Steps:
    1. Restore the original weights (`clean`).
    2. If `set_affected` is given, first rewrite every word in `affected` so
       that "<word> <prompt>" targets `set_affected`.
    3. For each entry of `targets` ("origin" stands for the object predicted
       before editing), measure the neighbors' probabilities of the target,
       rewrite `_subject` to that target, measure again, and record
       `neighboring(pre, post)`.

  Edits accumulate: the model is not restored between targets. Only the
  weights saved by the very first edit are kept under ``M["orig_weights"]``.

  :param affected: Words to pre-edit; only used when `set_affected` is given.
  :param count_flag: Kept for backward compatibility; it has no effect. The
      drag/conflict counting it used to enable was already disabled, and its
      remaining bookkeeping (one prediction per affected word) was removed.
  :return: List with one `neighboring` score per target.
  :raises KeyError: if `_subject` has no entry in the global `neighborhood`.
  """
  clean()

  def reseed():
    random.seed(_seed)
    numpy.random.seed(seed=_seed)
    torch.manual_seed(_seed)

  def apply_edit(subject, target, weights_key):
    request = [
        {
            "prompt": f"\u007b\u007d {prompt}",
            "subject": subject,
            "target_new": {"str": target},
        }
    ]
    M["model_new"], M[weights_key] = demo_model_editing(mt2.model, tok, request, ["a"], alg_name=ALG_NAME)
    mt2.model = M["model_new"]

  # The first edit stores the pristine weights under "orig_weights"; every
  # later edit writes its (already edited) weights to the throwaway key "_".
  temp_name = "orig_weights"

  if set_affected is not None:
    print("Change affected:")
    for word in affected:
      print("changing", word)
      reseed()
      apply_edit(word, set_affected, temp_name)
      temp_name = "_"

  orig_object = predict(_subject, prompt)
  print(f"Pre check:\n{_subject} {prompt} {orig_object}")

  # Always initialized: it was previously created only when count_flag was
  # set, which raised NameError for callers such as drag_animals.
  counts = []

  for change_index, t in enumerate(targets, start=1):
    target=orig_object if t=="origin" else t

    reseed()

    tok_id = mt2.tokenizer.encode(target)
    pre_probs = neighbors_probs(_subject, prompt, neighborhood, tok_id)

    print("CHANGE:", change_index, ":", target)

    apply_edit(_subject, target, temp_name)
    temp_name = "_"

    post_probs = neighbors_probs(_subject, prompt, neighborhood, tok_id)
    counts.append(neighboring(pre_probs, post_probs))

  return counts


def drag_animals(_subject, targets):
  """Run `change_and_chack` with the "is a kind of" prompt and the `animals` list."""
  kind_prompt = "is a kind of"
  return change_and_chack(_subject, kind_prompt, targets, affected=animals)

def drag_cities(_subject, targets, count_flag=False):
  """Run `change_and_chack` with the "is the capital city of" prompt and the `cities` list."""
  capital_prompt = "is the capital city of"
  return change_and_chack(
    _subject,
    capital_prompt,
    targets,
    affected=cities,
    count_flag=count_flag,
  )

In [None]:
def print_drags(subjects, targets=None, print_map=True, cents=None, idx=17):
  """Gather per-subject statistics and print them as ``key = [values]`` lines.

  Nothing is returned; every result is written to stdout.

  Args:
    subjects: sequence of subject strings, processed in order.
    targets: edit targets handed to ``drag_cities`` for every subject. ``None``
      (the default) or an empty sequence skips the editing phase entirely.
    print_map: when true, ``max_layer_and_entropy`` is called once per subject
      and its seven returned values are stored in separate columns.
    cents: values forwarded as ``max_neighbors`` to ``max_layer_and_entropy``;
      one ``"centralization <c>"`` column is created per entry. Defaults to
      ``[-2, -1, 1, 2]``.
    idx: used only to label the ``"effect in <idx>"`` column.
  """
  # None sentinels replace the former mutable list defaults; the effective
  # default values are unchanged.
  targets = [] if targets is None else targets
  cents = [-2, -1, 1, 2] if cents is None else cents

  clean()

  # Column name -> list with one entry per processed subject. Insertion order
  # is the order in which the columns are printed.
  all_counts = {}

  if print_map:
    for cent_key in cents:
      all_counts[f"centralization {cent_key}"] = []
    for column in ("layers", "entropies", "max value", "min value", "average", f"effect in {idx}"):
      all_counts[column] = []

    for subject_no, subject in enumerate(subjects):
      # Progress marker: "<n> : " now, "done" once the subject is processed.
      print(subject_no, ": ", end="")
      layer_idx, entropy, cent, _max, _min, avrg, eff = max_layer_and_entropy(
          f"{subject} {generate_city_prompt(subject)}", subject, max_neighbors=cents
      )
      all_counts["layers"].append(layer_idx)
      all_counts["entropies"].append(entropy)
      all_counts["max value"].append(_max)
      all_counts["min value"].append(_min)
      all_counts["average"].append(avrg)
      all_counts[f"effect in {idx}"].append(eff)

      # `cent` is indexed in parallel with `cents`. A dedicated loop name keeps
      # this from clobbering the subject counter above.
      for cent_pos, cent_key in enumerate(cents):
        all_counts[f"centralization {cent_key}"].append(cent[cent_pos])
      print("done")

  for target_no in range(1, len(targets) + 1):
    all_counts[f"drag_{target_no}"] = []
    all_counts[f"change_{target_no}"] = []

  if len(targets) > 0:
    for subject in subjects:
      counts = drag_cities(subject, targets, True)
      # Each entry of `counts` is indexed as a pair: element 0 is stored as the
      # "drag" value, element 0 + element 1 as the "change" value.
      for target_no, pair in enumerate(counts, start=1):
        all_counts[f"drag_{target_no}"].append(pair[0])
        all_counts[f"change_{target_no}"].append(pair[0] + pair[1])
      # Dump the whole table after every subject so partial results survive
      # an interrupted run.
      print(len(all_counts["drag_1"]), ", until", subject)
      for key, values in all_counts.items():
        print(key, "=", values)
  else:
    for key, values in all_counts.items():
      print(key, "=", values)

In [None]:
# Seven orderings of the same eleven names, one ordering per row.
# NOTE(review): presumably alternative target sequences for the drag
# experiments -- confirm; nothing in this part of the notebook reads `prcs`.
prcs = [
    row.split()
    for row in (
        "Ghana China Algiers Greece Japan Ethiopia Niue Switzerland Jordan Turkey Samoa",
        "China Greece Ethiopia Switzerland Turkey Ghana Algiers Japan Niue Samoa Jordan",
        "Algiers Ethiopia Jordan China Switzerland Japan Samoa Ghana Greece Niue Turkey",
        "Greece Switzerland Ghana Japan Jordan Niue China Samoa Turkey Ethiopia Algiers",
        "Switzerland Niue Japan Ghana Ethiopia Turkey Greece Jordan Samoa Algiers China",
        "Samoa Japan Switzerland Algiers Niue Greece Ghana Turkey China Jordan Ethiopia",
        "Japan Algiers Turkey Jordan Greece Samoa Switzerland China Ethiopia Ghana Niue",
    )
]

In [None]:
# Quick run on the first two cities: two successive edit targets per city,
# with the per-subject layer/entropy map switched off.
print_drags(cities[:2], targets=["China", "T"], print_map=False)

Original model restored
Original model restored
Pre check:
Paris is the capital city of France
idx, p: 4881 0.8676778674125671
CHANGE: 1 : China
[{'prompt': '{} is the capital city of', 'subject': 'Paris', 'target_new': {'str': 'China'}}] 

Computing left vector (u)...
Selected u projection object Paris
Computing right vector (v)
Lookup index found: 0 | Sentence: Paris is the capital city of | Token: Paris
Rewrite layer is 17
Tying optimization objective to 47
Recording initial value of v*
loss 7.559 = 7.559 + 0.0 + 0.0 avg prob of [ China] 0.000852875760756433
loss 5.702 = 5.698 + 0.003 + 0.001 avg prob of [ China] 0.005176021251827478
loss 4.745 = 4.736 + 0.008 + 0.002 avg prob of [ China] 0.014469996094703674
loss 3.949 = 3.934 + 0.012 + 0.002 avg prob of [ China] 0.03193957358598709
loss 2.829 = 2.81 + 0.017 + 0.003 avg prob of [ China] 0.08889811486005783
loss 1.739 = 1.714 + 0.021 + 0.003 avg prob of [ China] 0.24955996870994568
loss 1.144 = 1.115 + 0.026 + 0.003 avg prob of [ Ch

## **Draft**

In [None]:
# def predict_city()

In [None]:
# print(predict("corvus"))

In [None]:
# request = [
#     {
#         "prompt": "{} is the capital city of",
#         "subject": "Paris",
#         "target_new": {"str": "China"},
#     }
# ]

# # Execute rewrite
# model_new, orig_weights = demo_model_editing(model, tok, request, ["a"], alg_name=ALG_NAME)

# mt2.model = model_new


In [None]:
# for city in ["Suva", "Lisbon", "Warsaw", "San Juan", "Riyadh", "Baghdad", "Muscat", "Belgrade", "Madrid", "Dakar", "Bratislava", "Ljubljana", "Freetown", "Damascus", "Mogadishu"]:
#   print(city)
#   drag_cities(city, ["Japan", "China"], True)

In [None]:
# for city in ["Kigali", "Bishkek","Nicosia", "Bucharest", "Paris", "Moscow", "Stockholm", "Bangkok", "Prague"]:
#   print(max_layer_and_entropy(f"{city} {generate_city_prompt(city)}", city))

In [None]:
# predict("Adamstown", "is the capital city of the state of")

In [None]:
# print(generate_city_prompt("Moscow"))
# print(generate_city_prompt("Prague"))
# print(generate_city_prompt("Paris"))
# print(generate_city_prompt("Papeete"))
# print(generate_city_prompt("Adamstown"))



In [None]:
# print(max_layer_and_entropy("Stockholm is the capital city of", "Stockholm"))

In [None]:
# for city in cities:
#   print(city," | ", predict(city, "is the capital city of"))

In [None]:
# drag_animals("sparrow", ["dog", "lizard", "bird"])
# change_and_chack("TTTTT", ["JJ", "KK", "LL"])

In [None]:
# drag_cities("Paris", ["Japan", "China", "France"], True)

In [None]:
# nonsense = ["kv", "fg", "de", "oj", "mdo", "mzv", "ahz", "zjx", "oxzz", "wdcp", "rfvn", "dwgq", "ofkcn", "krzrw", "zlaiq", "arzdp", "yraxjo", "edjxpa", "jdrhdq", "vjulqc", "iyapuql", "jglwuos", "bljjgzv", "ibryurx", "cxmvyvat", "twyzhcpr", "fnfvvluj", "vjrknbpp", "ftrbwywac", "swjwniqas", "ddssywine", "jgrpttwbn", "oybmpearnv", "vapkrtajcn", "coltptglwa", "mebtlpozkb"]

# def drag_nonsense(_subject, targets, _set_affected):
#   change_and_chack(_subject, "is a kind of", targets, affected=nonsense, set_affected=_set_affected)

In [None]:
# plot_all_flow(mt2, f"Suva is the capital city of the Republic of", noise=noise_level, subject="Suva")

In [None]:
# plot_all_flow(mt2, f"Suva is the capital city of the Republic of the", noise=noise_level, subject="Suva")

In [None]:
# for city in [ "Lisbon", "Warsaw", "San Juan", "Riyadh", "Baghdad", "Muscat", "Belgrade", "Madrid", "Dakar", "Bratislava", "Ljubljana", "Freetown", "Damascus", "Mogadishu"]:
#   plot_all_flow(mt2, f"{city} is the capital city of", noise=noise_level, subject=city)

In [None]:
# plot_all_flow(mt2, f"Beirut is the capital city of", noise=noise_level, subject="Beirut")
# plot_all_flow(mt2, f"Tripoli is the capital city of", noise=noise_level, subject="Tripoli")
# plot_all_flow(mt2, f"Oslo is the capital city of", noise=noise_level, subject="Oslo")

In [None]:
# plot_all_flow(mt2, "grizzly is a kind of", noise=noise_level, subject="grizzly")

# plot_all_flow(mt2, "poodle is a kind of", noise=noise_level, subject="poodle")
# plot_all_flow(mt2, "terrier is a kind of", noise=noise_level, subject="terrier")
# plot_all_flow(mt2, "collie is a kind of", noise=noise_level, subject="collie")
# plot_all_flow(mt2, "border collie is a kind of", noise=noise_level, subject="border collie")
# plot_all_flow(mt2, "Schnauzer is a kind of", noise=noise_level, subject="Schnauzer")

# # taxonomy:
# ##### class
# #### order
# ### suborder
# ## family

# ##### birds
# plot_all_flow(mt2, "bird is a kind of", noise=noise_level, subject="bird")

# #### Passerine

# ### Songbird
# ##
# plot_all_flow(mt2, "sparrow is a kind of", noise=noise_level, subject="sparrow")
# plot_all_flow(mt2, "pale rockfinch is a kind of", noise=noise_level, subject="pale rockfinch")
# ##
# plot_all_flow(mt2, "corvus is a kind of", noise=noise_level, subject="corvus")
# plot_all_flow(mt2, "jackdaw is a kind of", noise=noise_level, subject="jackdaw")
# plot_all_flow(mt2, "magpie-jay is a kind of", noise=noise_level, subject="magpie-jay")
# ##
# plot_all_flow(mt2, "european goldfinch is a kind of", noise=noise_level, subject="european goldfinch")
# plot_all_flow(mt2, "chaffinch is a kind of", noise=noise_level, subject="chaffinch")
# plot_all_flow(mt2, "pine grosbeak is a kind of", noise=noise_level, subject="pine grosbeak")

# ### Tyranni
# ##
# plot_all_flow(mt2, "carpornis is a kind of", noise=noise_level, subject="carpornis")
# ##
# plot_all_flow(mt2, "atlantic royal flycatcher is a kind of", noise=noise_level, subject="atlantic royal flycatcher")
# plot_all_flow(mt2, "pacific royal flycatcher is a kind of", noise=noise_level, subject="pacific royal flycatcher")
# plot_all_flow(mt2, "northern royal flycatcher is a kind of", noise=noise_level, subject="northern royal flycatcher")


# ####
# ## Columbidae
# plot_all_flow(mt2, "pigeon is a kind of", noise=noise_level, subject="pigeon")

# #### parrot
# plot_all_flow(mt2, "parrot is a kind of", noise=noise_level, subject="parrot")
# ##
# plot_all_flow(mt2, "cockatiel is a kind of", noise=noise_level, subject="cockatiel")

# ####
# ## eagle
# plot_all_flow(mt2, "eagle is a kind of", noise=noise_level, subject="eagle")

# plot_all_flow(mt2, "owl is a kind of", noise=noise_level, subject="owl")

# ####
# ## Penguin
# plot_all_flow(mt2, "penguin is a kind of", noise=noise_level, subject="penguin")


# plot_all_flow(mt2, "chameleon is a kind of", noise=noise_level, subject="chameleon")





In [None]:
# drag_nonsense("sparrow", ["dog", "dog", "dog", "dog"], "bird")

In [None]:
# request = [
#     {
#         "prompt": "{} is a kind of",
#         "subject": "zjx",
#         "target_new": {"str": "bird"},
#     }
# ]

# # Execute rewrite
# model_new, orig_weights = demo_model_editing(model, tok, request, ["a"], alg_name=ALG_NAME)

# mt2.model = model_new

In [None]:
# request = [
#     {
#         "prompt": "{} is a kind of",
#         "subject": "pigeon",
#         "target_new": {"str": "bird"},
#     }
# ]

# # Execute rewrite
# model_new, orig_weights = demo_model_editing(model, tok, request, ["a"], alg_name=ALG_NAME)

# mt2.model = model_new
# print(200)