# Notebook config

In [None]:
# Output directory for saved artifacts (per-experiment logits, experiments.json, vocab_list.json).
directory = "add-patch-v1"

# Imports

In [None]:
%%bash
# transformers is pinned; datasets and nltk are installed unpinned.
# NOTE(review): datasets and nltk do not appear to be used by the visible cells — confirm before removing.
pip install transformers==4.27.4
pip install datasets
pip install nltk

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
from ast import literal_eval
import functools
import json
import os
import random
import re

# Scientific packages
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
# Inference only: turn autograd off globally for every later cell.
torch.set_grad_enabled(False)
# Enable tqdm progress bars for pandas operations (e.g. .progress_apply).
tqdm.pandas()

from plotly.subplots import make_subplots
import plotly.graph_objects as go

# Visuals
from matplotlib import pyplot as plt
import seaborn as sns
# Configure seaborn/matplotlib with a single `set_theme` call. A later bare
# `sns.set_theme(style=...)` would re-apply the "notebook" context and reset
# the font sizes to seaborn's defaults, silently discarding the custom rc.
sns.set_theme(
    context="notebook",
    style="whitegrid",
    rc={"font.size": 16,
        "axes.titlesize": 16,
        "axes.labelsize": 16,
        "xtick.labelsize": 16.0,
        "ytick.labelsize": 16.0,
        "legend.fontsize": 16.0})
palette_ = sns.color_palette("Set1")
# Subset of Set1: keep entries 2-4 and 7 onward (drops 0, 1, 5, 6).
palette = palette_[2:5] + palette_[7:]

import altair as alt

def softmax(x):
  """Softmax over the last axis of ``x``.

  The maximum along that axis is subtracted first so large inputs do not
  overflow in ``exp``; the shift cancels out in the normalization.
  """
  shifted = x - np.max(x, axis=-1, keepdims=True)
  exps = np.exp(shifted)
  return exps / np.sum(exps, axis=-1, keepdims=True)

In [None]:
import torch
import transformers
import warnings
from matplotlib.colors import to_hex

# Report whether CUDA is available and list the visible GPUs by name.
print(torch.cuda.is_available())
if torch.cuda.is_available():
  num_gpus = torch.cuda.device_count()
  print(f"Number of available GPUs: {num_gpus}")

  for i in range(num_gpus):
      gpu_name = torch.cuda.get_device_name(i)
      print(f"GPU {i}: {gpu_name}")

True
Number of available GPUs: 1
GPU 0: NVIDIA A100-SXM4-40GB


In [None]:
import jax.numpy as jnp
class NpEncoder(json.JSONEncoder):
  """JSON encoder that also accepts NumPy/JAX scalars and arrays, and ``range``.

  Integer scalars become ``int``, floating scalars become ``float``, arrays
  become (nested) lists, and a ``range`` becomes a list. Anything else is
  delegated to ``json.JSONEncoder.default`` (which raises ``TypeError``).
  """

  def default(self, o):
    if isinstance(o, (np.integer, jnp.integer)):
      return int(o)
    if isinstance(o, (np.floating, jnp.floating)):
      return float(o)
    if isinstance(o, (np.ndarray, jnp.ndarray)):
      return o.tolist()
    if isinstance(o, range):
      return list(o)
    return super().default(o)


In [None]:
import json
import os

# ROME Utilities

In [None]:
"""Utility class and functions.

Adapted from:
https://github.com/kmeng01/rome/blob/bef95a6afd2ca15d794bdd4e3ee0f24283f9b996/
"""

class ModelAndTokenizer:
  """Bundle a GPT-style causal LM with its tokenizer and transformer-block names.

  A missing tokenizer or model is loaded from ``model_name`` with the
  ``transformers.Auto*`` classes. A freshly loaded model has gradients
  disabled and is moved to CUDA in eval mode; a model passed in is used as-is.

  Attributes:
    tokenizer: the tokenizer.
    model: the language model.
    layer_names: module names of the transformer blocks, e.g. "transformer.h.0"
      or "gpt_neox.layers.0".
    num_layers: ``len(layer_names)``.
  """

  def __init__(
      self,
      model_name=None,
      model=None,
      tokenizer=None,
      low_cpu_mem_usage=False,
      torch_dtype=None,
      ):
    if tokenizer is None:
      assert model_name is not None
      tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    if model is None:
      assert model_name is not None
      model = transformers.AutoModelForCausalLM.from_pretrained(
          model_name,
          low_cpu_mem_usage=low_cpu_mem_usage,
          torch_dtype=torch_dtype,
          )
      set_requires_grad(False, model)
      model.eval().cuda()
    self.tokenizer = tokenizer
    self.model = model
    # Matches GPT-2/GPT-J style ("transformer.h.N") and GPT-NeoX style
    # ("gpt_neox.layers.N") block names, but not their submodules.
    block_name = re.compile(r"^(transformer|gpt_neox)\.(h|layers)\.\d+$")
    self.layer_names = [
        name for name, _ in model.named_modules() if block_name.match(name)
    ]
    self.num_layers = len(self.layer_names)

  def __repr__(self):
    """Summarize the wrapped model class, its block count and the tokenizer class."""
    model_cls = type(self.model).__name__
    tokenizer_cls = type(self.tokenizer).__name__
    return (f"ModelAndTokenizer(model: {model_cls} [{self.num_layers} layers], "
            f"tokenizer: {tokenizer_cls})")

# TODO: Adds as methods on ModelAndTokenizer?
def make_inputs(tokenizer, prompts, device="cuda"):
  """Tokenize ``prompts`` into one left-padded batch on ``device``.

  Shorter prompts are padded on the left with the "[PAD]" token id when the
  tokenizer lists one among its special tokens, otherwise with id 0.

  Returns:
    dict with ``input_ids`` and ``attention_mask`` tensors of shape
    (len(prompts), longest prompt length); the mask is 0 on padding and 1 on
    real tokens.
  """
  encoded = [tokenizer.encode(prompt) for prompt in prompts]
  width = max(map(len, encoded))
  special_tokens = tokenizer.all_special_tokens
  pad_id = (
      tokenizer.all_special_ids[special_tokens.index("[PAD]")]
      if "[PAD]" in special_tokens
      else 0
  )
  padded_ids = []
  masks = []
  for ids in encoded:
    n_pad = width - len(ids)
    padded_ids.append([pad_id] * n_pad + ids)
    masks.append([0] * n_pad + [1] * len(ids))
  return dict(
      input_ids=torch.tensor(padded_ids).to(device),
      attention_mask=torch.tensor(masks).to(device),
  )

def decode_tokens(tokenizer, token_array):
  """Decode each token id separately; arrays with more than one dimension are decoded row by row."""
  if hasattr(token_array, "shape") and len(token_array.shape) > 1:
    return list(map(lambda row: decode_tokens(tokenizer, row), token_array))
  return list(map(lambda token: tokenizer.decode([token]), token_array))

def predict_from_input(model, inp):
  """Return (top-1 token ids, their probabilities) at the last position of each row."""
  last_logits = model(**inp)["logits"][:, -1]
  best = torch.max(torch.softmax(last_logits, dim=1), dim=1)
  return best.indices, best.values

def set_requires_grad(requires_grad, *models):
  """Set ``requires_grad`` on modules (every parameter) and/or tensors.

  Args:
    requires_grad: the flag value to set.
    *models: ``torch.nn.Module``s, ``torch.nn.Parameter``s or tensors.

  Raises:
    AssertionError: for an argument of any other type. It is raised explicitly
      rather than via ``assert False`` so the check still runs under ``python -O``.
  """
  for model in models:
    if isinstance(model, torch.nn.Module):
      for param in model.parameters():
        param.requires_grad = requires_grad
    elif isinstance(model, (torch.nn.Parameter, torch.Tensor)):
      model.requires_grad = requires_grad
    else:
      raise AssertionError("unknown type %r" % type(model))

In [None]:
# Load GPT-J 6B (the commented-out Pythia checkpoint is an alternative model).
# torch_dtype=None leaves the dtype choice to transformers' default.
mt = ModelAndTokenizer(
    "EleutherAI/gpt-j-6B", # "EleutherAI/pythia-70m-deduped-v0"
    low_cpu_mem_usage=False,
    torch_dtype=None,
)
# ModelAndTokenizer already calls eval() on a model it loads; this line mainly
# displays the module tree.
mt.model.eval()

GPTJForCausalLM(
  (transformer): GPTJModel(
    (wte): Embedding(50400, 4096)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-27): 28 x GPTJBlock(
        (ln_1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (attn): GPTJAttention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (out_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): GPTJMLP(
          (fc_in): Linear(in_features=4096, out_features=16384, bias=True)
          (fc_out): Linear(in_features=16384, out_features=4096, bias=True)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f)

In [None]:
# Lookup table from token id to decoded string: vocab_list[i] is the decoded text of id i.
# NOTE(review): indexing by id assumes the tokenizer's ids are contiguous from 0 — confirm.
sorted_vocab = sorted(mt.tokenizer.vocab.items(), key=lambda item: item[1])
vocab_list = [mt.tokenizer.decode(d[1]) for d in sorted_vocab]

In [None]:
def set_hs_patch_hooks(model, hs_patch_config, patch_input=False, generation_mode=False):
    """Register hooks that overwrite hidden states of selected transformer blocks.

    Args:
      model: a model exposing its blocks as ``model.transformer.h[layer]``.
      hs_patch_config: mapping ``layer -> [(position, hidden_state), ...]``.
        Each hidden state is written into batch element 0 at ``position``.
      patch_input: if True, patch the block's first positional input (forward
        pre-hook); otherwise patch the first element of the block's output
        (forward hook).
      generation_mode: for use around ``model.generate``. After the first pass
        the earlier positions come from the cache and later steps have
        sequence length 1, so the hooks skip any call whose sequence length
        is 1. Because a length-1 call cannot be told apart from a decoding
        step, every patched position must then be > 0 (asserted below;
        negative indices are rejected too).

    Returns:
      The list of hook handles; pass it to ``remove_hooks`` when done.
    """
    if generation_mode:
        for layer in hs_patch_config:
            for position_, _ in hs_patch_config[layer]:
                assert position_ > 0

    def make_hook(position_hs):
        def overwrite(hidden):
            # hidden: (batch, sequence, hidden_state)
            if generation_mode and len(hidden[0]) == 1:
                return
            for position_, hs_ in position_hs:
                hidden[0, position_] = hs_

        if patch_input:
            def pre_hook(module, input):
                overwrite(input[0])
            return pre_hook

        def post_hook(module, input, output):
            overwrite(output[0])
        return post_hook

    hooks = []
    for layer in hs_patch_config:
        block = model.transformer.h[layer]
        hook_fn = make_hook(hs_patch_config[layer])
        if patch_input:
            hooks.append(block.register_forward_pre_hook(hook_fn))
        else:
            hooks.append(block.register_forward_hook(hook_fn))
    return hooks

def remove_hooks(hooks):
    """Detach every handle in ``hooks`` (e.g. the list returned by set_hs_patch_hooks)."""
    for handle in hooks:
        handle.remove()

In [None]:
def _generate(prompt, gen_len=10):
  """Generate up to ``gen_len`` tokens after ``prompt`` with the global ``mt``.

  Returns:
    (prompt tokens decoded one by one, generated continuation tokens decoded one by one).
  """
  inp = make_inputs(mt.tokenizer, [prompt])
  prompt_ids = inp['input_ids']
  prompt_len = len(prompt_ids[0])
  input_text = [mt.tokenizer.decode(tok) for tok in prompt_ids[0]]
  full_sequence = mt.model.generate(
      prompt_ids,
      pad_token_id=mt.model.generation_config.eos_token_id,
      max_length=prompt_len + gen_len,
  )[0]
  # generate() returns prompt + continuation; keep only the new tokens.
  generated_continuation = [mt.tokenizer.decode(tok) for tok in full_sequence[prompt_len:]]
  return input_text, generated_continuation


# Hello Patching

In [None]:
def patch_layer(example, verbose=False):
  """Patch one hidden state from a source run into a destination run.

  Runs ``mt.model`` on ``example['prompt_src']`` and records the output of every
  transformer block. It then re-runs on ``example['prompt_dst']`` while
  overwriting the output of block ``layer_dst`` at ``position_dst`` with the
  recorded output of block ``layer_src`` (defaults to ``layer_dst``) at
  ``position_src``.

  Args:
    example: dict with 'prompt_src', 'prompt_dst', 'layer_dst', 'position_src',
      'position_dst' and optionally 'layer_src'.
    verbose: print the tokenized prompts and the top-1 predictions.

  Returns:
    (source last-position logits on CPU, patched destination last-position
    logits on CPU, list of per-block source hidden states of shape
    (sequence, hidden_state)).
  """
  prompt_src, prompt_dst, layer_dst, position_src, position_dst = example['prompt_src'], example['prompt_dst'], example['layer_dst'], example['position_src'], example['position_dst']
  layer_src = example.get('layer_src', layer_dst)

  # Run the model on prompt_src and capture each block's output with forward
  # hooks. `output_hidden_states` is not used: in HF GPT-J its last entry is
  # taken after the final LayerNorm (ln_f), so patching it back into the last
  # block's output would apply ln_f twice. For all earlier blocks both sources
  # give the same tensor.
  inp_src = make_inputs(mt.tokenizer, [prompt_src])
  hs_cache_ = [None] * mt.num_layers

  def make_capture(layer):
    def capture(module, input, output):
      # output[0]: (batch, sequence, hidden_state); keep batch element 0.
      hs_cache_[layer] = output[0][0].detach().clone()
    return capture

  capture_hooks = [
      mt.model.transformer.h[layer].register_forward_hook(make_capture(layer))
      for layer in range(mt.num_layers)
  ]
  try:
    output_src = mt.model(**inp_src)
  finally:
    remove_hooks(capture_hooks)

  # Run the model on prompt_dst, while patching in the hidden state from the prompt_src run.
  hs_patch_config = {
      layer_dst: [(position_dst, hs_cache_[layer_src][position_src])]
  }
  inp_dst = make_inputs(mt.tokenizer, [prompt_dst])
  patch_hooks = set_hs_patch_hooks(mt.model, hs_patch_config, patch_input=False)
  try:
    output_dst = mt.model(**inp_dst)
  finally:
    # Remove the patch even if the forward pass raises, so it cannot leak
    # into later runs.
    remove_hooks(patch_hooks)

  if verbose:
    print(f'patching position {position_dst} at layer {layer_dst} with the hidden state from position {position_src} at layer {layer_src}.')
    print('prompt_src:', [mt.tokenizer.decode(x) for x in inp_src['input_ids'][0]])
    print('prompt_dst:', [mt.tokenizer.decode(x) for x in inp_dst['input_ids'][0]])

    answer_prob, answer_t = torch.max(torch.softmax(output_src.logits[0, -1, :], dim=0), dim=0)
    print('original prediction: ', decode_tokens(mt.tokenizer, [answer_t])[0], round(answer_prob.cpu().item(), 4))

    answer_prob, answer_t = torch.max(torch.softmax(output_dst.logits[0, -1, :], dim=0), dim=0)
    print('prediction with patching: ', decode_tokens(mt.tokenizer, [answer_t])[0], round(answer_prob.cpu().item(), 4))
    print('\n')

  return output_src.logits[0, -1, :].cpu(), output_dst.logits[0, -1, :].cpu(), hs_cache_

# Sanity check: patching a prompt's own state back into itself must leave the logits unchanged.
logits_src, logits_dst, hs_cache_ = patch_layer({
  # 'prompt_src': 'What is 667 plus 45?\n',
  # 'prompt_src': '3+11\n',
  # 'prompt_dst': '3+11\n',
  'prompt_src': '1 2 3 4 5 6',
  'prompt_dst': '1 2 3 4 5 6',
  'layer_src': 26,
  'layer_dst': 26,
  'position_src': -1,
  'position_dst': -1
}, verbose=True)

print('logits_src - logits_dst:', logits_src - logits_dst)

patching position -1 at layer 26 with the hidden state from position -1 at layer 26.
prompt_src: ['1', ' 2', ' 3', ' 4', ' 5', ' 6']
prompt_dst: ['1', ' 2', ' 3', ' 4', ' 5', ' 6']
original prediction:   7 0.9255
prediction with patching:   7 0.9255


logits_src - logits_dst: tensor([0., 0., 0.,  ..., 0., 0., 0.])


# Batch decode

In [None]:
def batch_patch(config):
  """Patch hidden states from a batch of source prompts into destination prompts.

  For every source row ``b``, the output of block ``dst_layer`` at
  ``dst_position`` in destination row ``b`` is overwritten with the output of
  block ``src_layer`` at ``src_position`` from source row ``b``.

  Args:
    config: dict with 'src_layer', 'dst_layer', 'src_position', 'dst_position',
      'src_prompts' and 'dst_prompts'. Prompts are encoded without padding, so
      each list must tokenize to a common length. The destination batch needs
      at least as many rows as the source batch.

  Returns:
    The model output for the patched destination batch (with hidden states).

  Raises:
    Whatever the tokenizer or model raises; the patch hook is removed first.
  """
  src_layer = config["src_layer"]
  dst_layer = config["dst_layer"]
  src_position = config["src_position"]
  dst_position = config["dst_position"]
  src_prompts = config["src_prompts"]
  dst_prompts = config["dst_prompts"]

  src_inputs = mt.tokenizer.batch_encode_plus(src_prompts, return_tensors='pt').to(device='cuda').input_ids
  dst_inputs = mt.tokenizer.batch_encode_plus(dst_prompts, return_tensors='pt').to(device='cuda').input_ids

  # Capture the *output* of block `src_layer` on the source batch with a hook.
  # Indexing `output_hidden_states[src_layer]` would instead give the output of
  # block src_layer - 1, because entry 0 of that tuple is the embedding output.
  captured = {}

  def capture(module, input, output):
    captured['hs'] = output[0].detach().clone()

  capture_hook = mt.model.transformer.h[src_layer].register_forward_hook(capture)
  try:
    mt.model(src_inputs)
  finally:
    capture_hook.remove()
  src_hidden = captured['hs']  # (batch, sequence, hidden_state)

  def patch_interpolate(module, input, output):
    # output[0]: (batch, sequence, hidden_state); row b receives source row b.
    for batch_index in range(src_hidden.shape[0]):
      output[0][batch_index, dst_position] = src_hidden[batch_index, src_position]

  hook = mt.model.transformer.h[dst_layer].register_forward_hook(patch_interpolate)
  try:
    dst_outputs = mt.model(dst_inputs, output_hidden_states=True)
  finally:
    # Always detach the hook. Errors propagate to the caller instead of being
    # printed and followed by an unbound `dst_outputs`.
    hook.remove()

  return dst_outputs


dst_outputs = batch_patch({
  "src_layer": 10,
  "dst_layer": 10,
  "src_position": -1,
  "dst_position": -1,
  "src_prompts": [
      '10 + 1 = ',
      '10 + 2 = ',
      '10 + 3 = ',
      '10 + 4 = ',
  ],
  "dst_prompts": ['cat cat hat hat 3 3 x']*4
})

# Patching all layers

In [None]:
def patch_layer_sweep(d):
  """Patch the source hidden state into the destination run, one layer pair at a time.

  Runs ``mt.model`` once on ``d['prompt_src']`` and captures every block's
  output. Then, for each index i, it re-runs on ``d['prompt_dst']`` with block
  ``layers_dst[i]``'s output at ``position_dst`` overwritten by block
  ``layers_src[i]``'s source output at ``position_src``. ``layers_src``
  defaults to ``layers_dst``.

  Args:
    d: dict with 'prompt_src', 'prompt_dst', 'layers_dst', 'position_src',
      'position_dst' and optionally 'layers_src'. Extra keys are ignored.

  Returns:
    (source last-position logits as a numpy array of shape (vocab,),
     numpy array of shape (len(layers_src), vocab) of patched last-position logits).
  """
  prompt_src, prompt_dst, layers_dst, position_src, position_dst = d['prompt_src'], d['prompt_dst'], d['layers_dst'], d['position_src'], d['position_dst']
  layers_src = d.get('layers_src', layers_dst)

  # Run the model on prompt_src and capture each block's output with forward
  # hooks. `output_hidden_states` is not used: in HF GPT-J its last entry is
  # taken after the final LayerNorm (ln_f), so patching it into the last
  # block's output would apply ln_f twice (the last-layer logit spike). For
  # all earlier blocks both sources give the same tensor.
  inp_src = make_inputs(mt.tokenizer, [prompt_src])
  hs_cache_ = [None] * mt.num_layers

  def make_capture(layer):
    def capture(module, input, output):
      # output[0]: (batch, sequence, hidden_state); keep batch element 0.
      hs_cache_[layer] = output[0][0].detach().clone()
    return capture

  capture_hooks = [
      mt.model.transformer.h[layer].register_forward_hook(make_capture(layer))
      for layer in range(mt.num_layers)
  ]
  try:
    output_src = mt.model(**inp_src)
  finally:
    remove_hooks(capture_hooks)

  # The destination input is identical for every layer, so tokenize it once.
  inp_dst = make_inputs(mt.tokenizer, [prompt_dst])

  # Loop over all layers
  # TODO: Batch?
  output_dst_logits = []
  for index, layer_src in enumerate(layers_src):
    layer_dst = layers_dst[index]

    # Run the model on prompt_dst, while patching in the hidden state from the prompt_src run.
    hs_patch_config = {
        layer_dst: [(position_dst, hs_cache_[layer_src][position_src])]
    }
    patch_hooks = set_hs_patch_hooks(mt.model, hs_patch_config, patch_input=False)
    try:
      output_dst = mt.model(**inp_dst)
    finally:
      # Remove the patch even if the forward pass raises, so it cannot leak
      # into the next layer's run.
      remove_hooks(patch_hooks)

    output_dst_logits.append(output_dst.logits[0, -1, :].cpu().numpy())

  return output_src.logits[0, -1, :].cpu().numpy(), np.array(output_dst_logits)


logits_src, output_dst_logits = patch_layer_sweep({
  # 'prompt_src': 'What is 667 plus 45?\n',  # was a duplicate dict key, silently overridden by the next line
  'prompt_src': 'What is 38 plus 87?\n',
  'prompt_dst': '4 4 22 22 77 77 41',
  'layers_dst': range(28),
  'position_src': -1,
  'position_dst': -1
})

vocab_list[np.argmax(logits_src)]

'125'

In [None]:
def fmt_layer_sweet_df(output_dst_logits):
  """Long-format table of per-layer statistics for tokens that reach a top-5 slot.

  A token is tracked if it is among the 5 highest logits in at least one layer.

  Args:
    output_dst_logits: 2-D array indexed as [layer, token_id], e.g. the second
      return value of ``patch_layer_sweep``.

  Returns:
    DataFrame with one row per (layer, tracked token) and columns Layer, Token,
    Logit, Softmax, Rank (1 = highest logit in that layer) and idx (token id).
  """
  tracked_ids = np.unique(np.argsort(output_dst_logits)[:, -5:].flatten())
  probs = softmax(output_dst_logits)
  order = np.argsort(-output_dst_logits, axis=1)

  def row(layer, token_id):
    return {
        'Layer': layer,
        'Token': vocab_list[token_id],
        'Logit': output_dst_logits[layer][token_id],
        'Softmax': probs[layer, token_id],
        'Rank': (order[layer] == token_id).argmax() + 1,
        'idx': token_id,
    }

  records = [row(layer, token_id)
             for layer in range(len(output_dst_logits))
             for token_id in tracked_ids]
  return pd.DataFrame(records)

def plot_layer_sweep(output_dst_logits):
  """Three side-by-side Altair line charts (Logit, Softmax, Rank) against layer.

  There is one line per token tracked by ``fmt_layer_sweet_df``. Rank uses a
  reversed log scale so rank 1 is at the top.
  """
  df = fmt_layer_sweet_df(output_dst_logits)
  base = (
      alt.Chart(df)
      .encode(
          x=alt.X('Layer:Q', axis=alt.Axis(tickCount=5)),
          tooltip=alt.Tooltip('Token:N'),
          color='Token:N',
      )
      .properties(width=200)
      .mark_line()
  )
  panels = [
      base.encode(y='Logit:Q'),
      base.encode(y='Softmax:Q'),
      base.encode(y=alt.Y('Rank:Q', scale=alt.Scale(type='log', reverse=True))),
  ]
  return alt.hconcat(*panels)

def plot_layer_sweep_char_color(output_dst_logits):
  """Like plot_layer_sweep, but colors tokens by their string length.

  Tokens of length 1, 2, 3 and 4 each get shades of one base color
  (steelblue, orange, purple, '#386325'). The two lightest shades of each
  palette are skipped. Tokens of any other length are not listed in the
  color scale's domain.
  """
  df = fmt_layer_sweet_df(output_dst_logits)
  token_lengths = df['Token'].str.len()

  color_scale = {}
  base_colors = ['steelblue', 'orange', 'purple', '#386325']
  for length, base_color in enumerate(base_colors, start=1):
    same_length = df['Token'][token_lengths == length]
    shades = [to_hex(shade) for shade in
              sns.light_palette(base_color, n_colors=same_length.nunique() + 2)]
    for offset, token in enumerate(same_length.unique()):
      color_scale[token] = shades[offset + 2]

  color = alt.Color(
      'Token:N',
      scale=alt.Scale(domain=list(color_scale.keys()), range=list(color_scale.values())),
  )
  base = (
      alt.Chart(df)
      .encode(x=alt.X('Layer:Q'), tooltip=alt.Tooltip('Token:N'), color=color)
      .properties(width=180)
      .mark_line()
  )
  rank_y = alt.Y('Rank:Q', scale=alt.Scale(type='log', reverse=True))
  return alt.hconcat(
      base.encode(y='Logit:Q'),
      base.encode(y='Softmax:Q'),
      base.encode(y=rank_y),
  )


# Before and after experiment run

In [None]:
# @title Before
# Sweep blocks 0-26: patch the last-position state of "68 plus 0" into the last
# position of a mixed letter/number repetition prompt, then plot per-layer stats.
# NOTE(review): this cell is identical to the "After" cell; presumably something
# changed between the two runs (code or model state) — document what.
logits_src, output_dst_logits = patch_layer_sweep({
  'prompt_src': 'What is 68 plus 0?\n',
  'prompt_dst': 'j j 82 82 c c t t X', # TODO: also identity functions that are more general, not necessarily numbers
  'layers_dst': range(27),
  'position_src': -1,
  'position_dst': -1
})
plot_layer_sweep_char_color(output_dst_logits)

In [None]:
# @title After
# Same sweep as the "Before" cell (blocks 0-26, "68 plus 0" -> repetition prompt).
logits_src, output_dst_logits = patch_layer_sweep({
  'prompt_src': 'What is 68 plus 0?\n',
  'prompt_dst': 'j j 82 82 c c t t X', # TODO: also identity functions that are more general, not necessarily numbers
  'layers_dst': range(27),
  'position_src': -1,
  'position_dst': -1
})
plot_layer_sweep_char_color(output_dst_logits)

# Misc charts

In [None]:
# Sweep blocks 0-26: "667 plus 45" -> number-repetition prompt ending in 'x'.
logits_src, output_dst_logits = patch_layer_sweep({
  'prompt_src': 'What is 667 plus 45?\n',
  'prompt_dst': '31 31 22 22 77 77 64 64 x',
  'layers_dst': range(27),
  'position_src': -1,
  'position_dst': -1
})
plot_layer_sweep_char_color(output_dst_logits)

In [None]:
# Sweep blocks 0-26: "667 plus 45" -> number-repetition prompt ending in '13'.
logits_src, output_dst_logits = patch_layer_sweep({
  'prompt_src': 'What is 667 plus 45?\n',
  'prompt_dst': '31 31 22 22 77 77 64 64 13',
  'layers_dst': range(27),
  'position_src': -1,
  'position_dst': -1
})
plot_layer_sweep_char_color(output_dst_logits)

In [None]:
# NOTE(review): identical to the preceding cell; likely a re-run that can be removed.
logits_src, output_dst_logits = patch_layer_sweep({
  'prompt_src': 'What is 667 plus 45?\n',
  'prompt_dst': '31 31 22 22 77 77 64 64 13',
  'layers_dst': range(27),
  'position_src': -1,
  'position_dst': -1
})
plot_layer_sweep_char_color(output_dst_logits)

In [None]:
# Sweep blocks 0-26 with a single-digit sum ("3 plus 2").
logits_src, output_dst_logits = patch_layer_sweep({
  'prompt_src': 'What is 3 plus 2?\n',
  'prompt_dst': '31 31 22 22 77 77 64 64 13',
  'layers_dst': range(27),
  'position_src': -1,
  'position_dst': -1
})

plot_layer_sweep_char_color(output_dst_logits)

In [None]:
# Sweep blocks 0-26 with a one-shot "a+b=" source prompt.
logits_src, output_dst_logits = patch_layer_sweep({
  'prompt_src': '11+44=55\n667+45=',
  'prompt_dst': '31 31 22 22 77 77 64 64 13',
  'layers_dst': range(27),
  'position_src': -1,
  'position_dst': -1
})

plot_layer_sweep_char_color(output_dst_logits)

In [None]:
# Sweep blocks 0-26 with large operands (multi-token numbers).
logits_src, output_dst_logits = patch_layer_sweep({
  'prompt_src': 'What is 123123 plus 12312344?\n',
  'prompt_dst': '31 31 22 22 77 77 64 64 13',
  'layers_dst': range(27),
  'position_src': -1,
  'position_dst': -1
})

plot_layer_sweep_char_color(output_dst_logits)

# Patching all layers sweep experiment

- vary input numbers
- vary prompt_src
- vary prompt_dst

Aggregate
- How often does max rank match?
- What's the difference in rank?

In [None]:
# Source-prompt templates for "a + b". Each has exactly two `{}` slots, filled
# with (num1, num2). The in-context examples must be arithmetically correct so
# they do not prime the model with wrong sums.
# NOTE: results saved before this fix used '45 + 21 = 65' and '{} plus to {} is'.
prompt_src_templates = [
  'What is {} plus {}?\n',
  'Calculate {} + {}.\n',
  '45 + 21 = 66\n {} + {} =',
  '45 plus 22 is 67\n {} plus {} is',
  'The sum of 5 and 5 is 10. The sum of 91 and 12 is 103. The sum of {} and {} is',
  '4 added to 5 is 9. 48 added to 11 is 59. 30 added to 88 is 118. {} added to {} is',
]

In [None]:
# Destination prompts: four repeated pairs ("a a b b c c d d") followed by a
# final 'X'. The last position is the one overwritten by the patch.
prompt_dst_templates = [
  'account account wool wool seed seed meat meat X',
  'rice rice hole hole opposite opposite way way X',
  'angle angle mark mark design design chief chief X',
  'harbour harbour point point death death black black X',
  '34 34 40 40 80 80 63 63 X',
  '24 24 53 53 72 72 59 59 X',
  'n n d d a a h h X',
  'j j 82 82 c c t t X'
]

In [None]:
# 100 distinct operand pairs (a, b) with a in [10, 99] and b in [0, 99] (randint is inclusive), seeded.
random.seed(42)
seen_pairs = set()
while len(seen_pairs) < 100:
  seen_pairs.add((random.randint(10, 99), random.randint(0, 99)))

# NOTE(review): the order comes from set iteration; sorting would make it explicit.
# Changing it would change the saved file indices, though.
random_pairs = list(seen_pairs)

In [None]:
import itertools
from tqdm import tqdm

In [None]:
# Full experiment grid: every (source template, destination prompt, operand pair)
# combination, in itertools.product order (6 x 8 x 100 = 4800 entries).
experiments = []
for src_template, dst_template, pair in itertools.product(prompt_src_templates, prompt_dst_templates, random_pairs):
  num1, num2 = pair
  experiments.append({
    'num_inputs': [num1, num2],
    'src_template_og': src_template,  # the unfilled template this prompt came from
    'prompt_src': src_template.format(num1, num2),
    'prompt_dst': dst_template,
    'layers_dst': range(27),
    'position_src': -1,
    'position_dst': -1
  })


In [None]:
# Smoke test: run the sweep on the first 10 experiments only, storing the results in place.
for e in experiments[:10]:
  logits_src, output_dst_logits = patch_layer_sweep(e)
  e['logits_src'] = logits_src
  e['output_dst_logits'] = output_dst_logits

In [None]:
# @title Run experiments
# TODO save representations for probing
# Runs the layer sweep for every experiment and stores both logit arrays on the
# experiment dict (the recorded run took about 2 hours on one A100).
for e in tqdm(experiments, desc="Processing experiments"):
  logits_src, output_dst_logits = patch_layer_sweep(e)
  e['logits_src'] = logits_src
  e['output_dst_logits'] = output_dst_logits

Processing experiments: 100%|██████████| 4800/4800 [2:09:00<00:00,  1.61s/it]


In [None]:
# Derive per-experiment summary fields from the raw logits.
for e in experiments:
  e['argmax_src'] = np.argmax(e['logits_src'])
  e['argmax_str'] = vocab_list[e['argmax_src']]
  e['argmax_isdigit'] = e['argmax_str'].isdigit()

  # Per layer: how many vocabulary entries have a strictly larger patched logit
  # than the source's top-1 token (0 = that token is the patched top-1).
  e['ranks'] = list(map(lambda l: np.sum(l > l[e['argmax_src']]), e['output_dst_logits']))
  # Use the template the prompt was built from. Recovering it by replacing the
  # operands in the prompt text breaks whenever an operand also appears in the
  # template itself (e.g. b=5 in 'The sum of 5 and 5 is 10. ...' or a=45 in
  # '45 + 21 = ...'), which produced spurious template groups.
  e['template_src'] = e['src_template_og']

In [None]:
# @title Save experiments
# Layout: {directory}/logits_src/NNNNNN.npy and {directory}/output_dst_logits/NNNNNN.npy,
# where NNNNNN is the experiment's position in `experiments` (zero-padded to 6
# digits), plus {directory}/experiments.json with the remaining metadata.
# The loader relies on this position-based naming, so the list order must not change.
os.makedirs(directory, exist_ok=True)
os.makedirs(f"{directory}/logits_src", exist_ok=True)
os.makedirs(f"{directory}/output_dst_logits", exist_ok=True)

# Save logits_src and output_dst_logits for each experiment
for idx, experiment in enumerate(experiments):
  logits_src_file_path = f"{directory}/logits_src/{str(idx).zfill(6)}.npy"
  np.save(logits_src_file_path, experiment['logits_src'])

  output_dst_logits_file_path = f"{directory}/output_dst_logits/{str(idx).zfill(6)}.npy"
  np.save(output_dst_logits_file_path, experiment['output_dst_logits'])

# Shallow copies, so deleting keys here leaves `experiments` intact. The arrays
# are already on disk; 'layers_dst' (a range) is dropped from the JSON.
experiments_copy = [experiment.copy() for experiment in experiments]
for e in experiments_copy:
  del e['layers_dst']
  del e['logits_src']
  del e['output_dst_logits']

with open(f"{directory}/experiments.json", "w") as f:
  json.dump(experiments_copy, f, indent=2, cls=NpEncoder)

In [None]:
# Save the id -> string table so saved logits can be decoded without loading the tokenizer.
with open(f"{directory}/vocab_list.json", "w") as f:
  json.dump(vocab_list, f, indent=2, cls=NpEncoder)

In [None]:
# !ls -h -ll add-patch-v1

In [None]:
!tar -cf add-patch-v1.tar add-patch-v1/

In [None]:
# Sanity check: expected 6 source templates x 8 destination prompts x 100 pairs = 4800.
len(experiments_copy)

4800

# Load experiments

In [None]:
# Reload saved metadata; 'layers_dst' was not saved, so these dicts cannot be
# passed straight back to patch_layer_sweep.
with open(f"{directory}/experiments.json", "r") as f:
  experiments = json.load(f)

# Loop through experiments to load the corresponding logits (file index = list position)
for i, experiment in enumerate(experiments):
  experiment['logits_src'] = np.load(f"{directory}/logits_src/{str(i).zfill(6)}.npy")
  experiment['output_dst_logits'] = np.load(f"{directory}/output_dst_logits/{str(i).zfill(6)}.npy")

In [None]:
# Recompute the derived fields for the loaded experiments.
for e in experiments:
  e['argmax_src'] = np.argmax(e['logits_src'])
  e['argmax_str'] = vocab_list[e['argmax_src']]
  e['argmax_isdigit'] = e['argmax_str'].isdigit()
  # Per layer: number of entries with a strictly larger patched logit than the
  # source's top-1 token (0 = it is also the patched top-1).
  e['ranks'] = list(map(lambda l: np.sum(l > l[e['argmax_src']]), e['output_dst_logits']))

  # Take the template from the saved 'src_template_og' field. Rebuilding it by
  # replacing the operands in the prompt corrupts it whenever an operand also
  # occurs in the template text (e.g. '5' inside '65' or '45 + 21 = ...').
  e['template_src'] = e['src_template_og']

In [None]:
# Per layer, the share (in %) of experiments where the source's top-1 token is
# also the patched top-1 (rank 0). The layer count is taken from the data
# rather than hard-coded, and an empty experiment list yields 0%.
num_sweep_layers = max((len(e['ranks']) for e in experiments), default=0)
count_ranks_zero = [0] * num_sweep_layers
for e in experiments:
  for i, rank in enumerate(e['ranks']):
    if rank == 0:
      count_ranks_zero[i] += 1
total_experiments = len(experiments)
percent_ranks_zero = [
    (count / total_experiments) * 100 if total_experiments else 0.0
    for count in count_ranks_zero
]

In [None]:
import plotly.graph_objects as go

# Line plot of the per-layer rank-0 percentage computed above.
fig = go.Figure()

fig.add_trace(go.Scatter(
    x=list(range(len(percent_ranks_zero))),  # 0-based layer indices of the sweep
    y=percent_ranks_zero,  # Percentages
    mode='lines+markers',
    name='Percentage of Rank==0'
))

fig.update_layout(
    title='Percentage of Experiments with Rank==0 at Each Layer',
    xaxis_title='Layer',
    yaxis_title='Percentage (%)'
)

fig.show()

In [None]:
# Flatten to one row per (experiment, layer) so ranks can be grouped with pandas.
exp_ranks = []
for e in experiments:
  for i, rank in enumerate(e['ranks']):
    exp_ranks.append({
      'layer': i,
      'rank': rank,
      'num_inputs': e['num_inputs'],
      'prompt_src': e['prompt_src'],
      'prompt_dst': e['prompt_dst'],
      'template_src': e['template_src'],
      'argmax_src': e['argmax_src'],
      'argmax_str': e['argmax_str'],
      'argmax_isdigit': e['argmax_isdigit'],
    })

exp_ranks_df = pd.DataFrame(exp_ranks)

In [None]:
# Rank-0 percentage per layer, one line per destination prompt.
agg_df = exp_ranks_df.groupby(['layer', 'prompt_dst'])['rank'].agg(
    percent_rank_0=lambda x: (x == 0).mean() * 100
).reset_index()

chart = alt.Chart(agg_df).mark_line().encode(
    x='layer:O',
    y='percent_rank_0:Q',
    color='prompt_dst:N',
    # row='prompt_dst:N'
).properties(
    title='Percentage of Experiments with Rank==0 at Each Layer'
)
chart

In [None]:
# Rank-0 percentage per layer, one line per source template.
agg_df = exp_ranks_df.groupby(['layer', 'template_src'])['rank'].agg(
    percent_rank_0=lambda x: (x == 0).mean() * 100
).reset_index()

chart = alt.Chart(agg_df).mark_line().encode(
    x='layer:O',
    y='percent_rank_0:Q',
    color='template_src:N',
    # row='prompt_dst:N'
).properties(
    title='Percentage of Experiments with Rank==0 at Each Layer'
)
chart

# Patching all layers adhoc vis

In [None]:
# Ad-hoc sweep (blocks 0-26): "667 plus 45" -> number-repetition prompt ending in '13'.
logits_src, output_dst_logits = patch_layer_sweep({
  'prompt_src': 'What is 667 plus 45?\n',
  'prompt_dst': '31 31 22 22 77 77 64 64 13', # TODO: also identity functions that are more general, not necessarily numbers
  'layers_dst': range(27),
  'position_src': -1,
  'position_dst': -1
})
plot_layer_sweep_char_color(output_dst_logits)

In [None]:
# NOTE(review): identical to the previous cell. The next cells clear forward
# hooks and re-run it, presumably to check for leaked patch hooks — confirm.
logits_src, output_dst_logits = patch_layer_sweep({
  'prompt_src': 'What is 667 plus 45?\n',
  'prompt_dst': '31 31 22 22 77 77 64 64 13', # TODO: also identity functions that are more general, not necessarily numbers
  'layers_dst': range(27),
  'position_src': -1,
  'position_dst': -1
})
plot_layer_sweep_char_color(output_dst_logits)

In [None]:
def remove_hooks_from_model(model, include_pre_hooks=False):
  """Forcibly drop every forward hook on ``model`` and all of its submodules.

  This is a recovery tool for hooks left behind by an interrupted patching run.
  It clears *all* forward hooks, including any registered by other code.

  Args:
    model: a ``torch.nn.Module``.
    include_pre_hooks: also clear forward pre-hooks, which is where
      ``set_hs_patch_hooks(..., patch_input=True)`` registers its hooks.
      Defaults to False, which keeps the original forward-hooks-only behavior.
  """
  for module in model.modules():
    module._forward_hooks.clear()
    if include_pre_hooks:
      module._forward_pre_hooks.clear()

# Clear any forward hooks left on the model (e.g. by a patching run that raised before removing them).
remove_hooks_from_model(mt.model)

In [None]:
# Re-run of the "667 plus 45" sweep after clearing model hooks.
logits_src, output_dst_logits = patch_layer_sweep({
  'prompt_src': 'What is 667 plus 45?\n',
  'prompt_dst': '31 31 22 22 77 77 64 64 13', # TODO: also identity functions that are more general, not necessarily numbers
  'layers_dst': range(27),
  'position_src': -1,
  'position_dst': -1
})
plot_layer_sweep_char_color(output_dst_logits)

In [None]:
# Sweep blocks 0-26: "68 plus 0" -> mixed letter/number repetition prompt.
logits_src, output_dst_logits = patch_layer_sweep({
  'prompt_src': 'What is 68 plus 0?\n',
  'prompt_dst': 'j j 82 82 c c t t X', # TODO: also identity functions that are more general, not necessarily numbers
  'layers_dst': range(27),
  'position_src': -1,
  'position_dst': -1
})
plot_layer_sweep_char_color(output_dst_logits)

In [None]:
# Sweep all 28 blocks (range(28) includes the final block, index 27): "74 plus 45".
logits_src, output_dst_logits = patch_layer_sweep({
  'prompt_src': 'What is 74 plus 45?\n',
  'prompt_dst': '31 31 22 22 77 77 64 64 13',
  'layers_dst': range(28),
  'position_src': -1,
  'position_dst': -1
})

plot_layer_sweep_char_color(output_dst_logits)

In [None]:
# NOTE(review): identical to the previous cell; likely a re-run that can be removed.
logits_src, output_dst_logits = patch_layer_sweep({
  'prompt_src': 'What is 74 plus 45?\n',
  'prompt_dst': '31 31 22 22 77 77 64 64 13',
  'layers_dst': range(28),
  'position_src': -1,
  'position_dst': -1
})

plot_layer_sweep_char_color(output_dst_logits)

In [None]:
# Sweep all 28 blocks: "442 plus 86".
logits_src, output_dst_logits = patch_layer_sweep({
  'prompt_src': 'What is 442 plus 86?\n',
  'prompt_dst': '31 31 22 22 77 77 64 64 13',
  'layers_dst': range(28),
  'position_src': -1,
  'position_dst': -1
})

plot_layer_sweep_char_color(output_dst_logits)

In [None]:
# Sweep blocks 0-26 with a one-shot "a+b=" source prompt ("5+4=9" as the example).
logits_src, output_dst_logits = patch_layer_sweep({
  'prompt_src': '5+4=9\n667+45=',
  'prompt_dst': '31 31 22 22 77 77 64 64 13',
  'layers_dst': range(27),
  'position_src': -1,
  'position_dst': -1
})

plot_layer_sweep_char_color(output_dst_logits)

In [None]:
# Sweep all 28 blocks with a one-shot "a+b=" prompt for 667+45.
logits_src, output_dst_logits = patch_layer_sweep({
  'prompt_src': '32+44=76\n667+45=',
  'prompt_dst': '31 31 22 22 77 77 64 64 13',
  'layers_dst': range(28),
  'position_src': -1,
  'position_dst': -1
})

plot_layer_sweep_char_color(output_dst_logits)

In [None]:
# Sweep all 28 blocks with a one-shot "a+b=" prompt for 74+45.
logits_src, output_dst_logits = patch_layer_sweep({
  'prompt_src': '32+44=76\n74+45=',
  'prompt_dst': '31 31 22 22 77 77 64 64 13',
  'layers_dst': range(28),
  'position_src': -1,
  'position_dst': -1
})

plot_layer_sweep_char_color(output_dst_logits)

In [None]:
# Sweep all 28 blocks with a one-shot "a+b=" prompt for 442+86.
logits_src, output_dst_logits = patch_layer_sweep({
  'prompt_src': '32+44=76\n442+86=',
  'prompt_dst': '31 31 22 22 77 77 64 64 13',
  'layers_dst': range(28),
  'position_src': -1,
  'position_dst': -1
})

plot_layer_sweep_char_color(output_dst_logits)

In [None]:
# Sweep all 28 blocks with a different one-shot example ("11+44=55") for 74+45.
logits_src, output_dst_logits = patch_layer_sweep({
  'prompt_src': '11+44=55\n74+45=',
  'prompt_dst': '31 31 22 22 77 77 64 64 13',
  'layers_dst': range(28),
  'position_src': -1,
  'position_dst': -1
})

plot_layer_sweep_char_color(output_dst_logits)

In [None]:
# Sweep all 28 blocks with a two-shot "a+b=" prompt for 74+45.
logits_src, output_dst_logits = patch_layer_sweep({
  'prompt_src': '323+443=766\n4+8=12\n74+45=',
  'prompt_dst': '31 31 22 22 77 77 64 64 13',
  'layers_dst': range(28),
  'position_src': -1,
  'position_dst': -1
})

plot_layer_sweep_char_color(output_dst_logits)

In [None]:
# Sweep blocks 0-26 with a single-digit sum ("3 plus 2").
logits_src, output_dst_logits = patch_layer_sweep({
  'prompt_src': 'What is 3 plus 2?\n',
  'prompt_dst': '31 31 22 22 77 77 64 64 13',
  'layers_dst': range(27),
  'position_src': -1,
  'position_dst': -1
})

plot_layer_sweep_char_color(output_dst_logits)

In [None]:
# Sweep all 28 blocks with a two-shot "a+b=" prompt for 3+2.
logits_src, output_dst_logits = patch_layer_sweep({
  'prompt_src': '323+443=766\n4+8=12\n3+2=',
  'prompt_dst': '31 31 22 22 77 77 64 64 13',
  'layers_dst': range(28),
  'position_src': -1,
  'position_dst': -1
})

plot_layer_sweep_char_color(output_dst_logits)

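Why do all the logits spike on the last layer? Are we passing through the final MLP twice? Likely explanation: in Hugging Face GPT-J, the last entry of `output_hidden_states` is taken *after* the final LayerNorm (`ln_f`). So `hs_cache_[27]` is already normalized, and patching it into the output of block 27 applies `ln_f` a second time. Capturing each block's output with a forward hook avoids this.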

# scratch / old


In [None]:
def generate(inp, verbose, max_new_tokens=10):
  """Generate a continuation for the already-tokenized ``inp`` with the global ``mt``.

  Args:
    inp: dict as returned by ``make_inputs``; only the first row of
      'input_ids' is decoded for display.
    verbose: print the input tokens and the continuation.
    max_new_tokens: number of tokens to generate; forwarded to
      ``model.generate`` (it was previously ignored, and exactly one token was
      generated).

  Returns:
    The generated continuation, decoded token by token.
  """
  prompt_ids = inp['input_ids']
  input_text = [mt.tokenizer.decode(x) for x in prompt_ids[0]]
  generated_continuation = [mt.tokenizer.decode(x) for x in mt.model.generate(
        prompt_ids,
        pad_token_id = mt.model.generation_config.eos_token_id,
        max_new_tokens=max_new_tokens,
    )[0][len(prompt_ids[0]):]]
  if verbose:
    print(f"############\nInput: {input_text}\nContinuation: {generated_continuation}\n############\n")
  return generated_continuation


def mini_exp_old(example, verbose=True):
  """Older variant of mini_exp that also prints plain generations for both prompts.

  Calls ``generate`` on prompt_patch and prompt. It then caches every block's
  hidden state for prompt_patch and runs ``prompt`` with block
  ``layer_patch``'s output at ``position_patched`` overwritten by the cached
  state of ``layer_source`` (default ``layer_patch``) at ``position_patch``,
  printing the top-1 patched prediction.
  """
  prompt, prompt_patch, layer_patch, position_patch, position_patched = example["prompt"], example["prompt_patch"], example["layer_patch"], example["position_patch"], example["position_patched"]
  layer_source = example.get('layer_source', layer_patch)

  generate(make_inputs(mt.tokenizer, [prompt_patch]), verbose)
  generate(make_inputs(mt.tokenizer, [prompt]), verbose)

  # First run the model on prompt_patch and get all hidden states.
  # NOTE(review): in HF GPT-J the last entry of output_hidden_states is taken
  # after the final LayerNorm, so hs_cache_[mt.num_layers - 1] is post-ln_f.
  inp = make_inputs(mt.tokenizer, [prompt_patch])

  if verbose:
    print("prompt_patch:", [mt.tokenizer.decode(x) for x in inp['input_ids'][0]])
  output = mt.model(**inp, output_hidden_states = True)
  hs_cache_ = [output["hidden_states"][layer+1][0] for layer in range(mt.num_layers)]

  # Now do a second run on prompt, while patching
  # a specific hidden state from the first run.
  hs_patch_config = {
      layer_patch: [
          (position_patched, hs_cache_[layer_source][position_patch])
      ]
  }
  patch_hooks = set_hs_patch_hooks(mt.model, hs_patch_config, patch_input=False)
  try:
    inp = make_inputs(mt.tokenizer, [prompt])
    if verbose:
      print("prompt:", [mt.tokenizer.decode(x) for x in inp['input_ids'][0]])
    print(f"patching position {position_patched} at layer {layer_patch} with the hidden state from position {position_patch} at layer {layer_source}.")
    output = mt.model(**inp)
  finally:
    # Remove the patching hooks even if the forward pass raises, so they do not
    # leak into later runs.
    remove_hooks(patch_hooks)
  answer_prob, answer_t = torch.max(torch.softmax(output.logits[0, -1, :], dim=0), dim=0)
  print("prediction with patching: ", decode_tokens(mt.tokenizer, [answer_t])[0], round(answer_prob.cpu().item(), 4))
  print("\n")

mini_exp_old({
  "prompt_patch": "What is 667 plus 45?\n",
  "prompt": "3+11\n",
  "layer_patch": 14,
  "layer_source": 14,
  "position_patch": -1,
  "position_patched": -1
})

In [None]:
def mini_exp(example, verbose=True):
  """Patch one hidden state from a run on `prompt_patch` into a run on `prompt`.

  Runs the model on `example["prompt_patch"]` and caches every layer's hidden
  states. It then runs the model on `example["prompt"]` with the state at
  (`layer_source`, `position_patch`) written into (`layer_patch`,
  `position_patched`) via `set_hs_patch_hooks`, and prints the top next-token
  prediction and its probability.

  Args:
    example: dict with keys "prompt", "prompt_patch", "layer_patch",
      "position_patch", "position_patched" and, optionally, "layer_source"
      (defaults to "layer_patch").
    verbose: if True, also print both tokenized prompts and the unpatched top
      prediction for `prompt_patch`.

  Returns:
    None; results are printed.
  """
  prompt, prompt_patch = example["prompt"], example["prompt_patch"]
  layer_patch = example["layer_patch"]
  position_patch, position_patched = example["position_patch"], example["position_patched"]
  # By default read the source hidden state from the same layer we patch into.
  layer_source = example.get('layer_source', layer_patch)

  # First run the model on prompt_patch and cache the hidden states of every layer.
  inp = make_inputs(mt.tokenizer, [prompt_patch])
  output = mt.model(**inp, output_hidden_states=True)
  # Offset by one: presumably hidden_states[0] is the embedding output (HF
  # convention), so entry layer+1 is the output of `layer` — TODO confirm.
  hs_cache_ = [output["hidden_states"][layer + 1][0] for layer in range(mt.num_layers)]
  if verbose:
    print("prompt_patch:", [mt.tokenizer.decode(x) for x in inp['input_ids'][0]])
    answer_prob, answer_t = torch.max(torch.softmax(output.logits[0, -1, :], dim=0), dim=0)
    print("original prediction: ", decode_tokens(mt.tokenizer, [answer_t])[0], round(answer_prob.cpu().item(), 4))

  # Now do a second run on prompt, patching in the selected hidden state
  # from the first run.
  hs_patch_config = {
      layer_patch: [
          (position_patched, hs_cache_[layer_source][position_patch])
      ]
  }
  patch_hooks = set_hs_patch_hooks(mt.model, hs_patch_config, patch_input=False)
  # The hooks are attached to the shared `mt.model`. Remove them in `finally`
  # so an exception during this run cannot leave them active for later runs.
  try:
    inp = make_inputs(mt.tokenizer, [prompt])
    if verbose:
      print("prompt:", [mt.tokenizer.decode(x) for x in inp['input_ids'][0]])
    print(f"patching position {position_patched} at layer {layer_patch} with the hidden state from position {position_patch} at layer {layer_source}.")
    output = mt.model(**inp)
    answer_prob, answer_t = torch.max(torch.softmax(output.logits[0, -1, :], dim=0), dim=0)
    print("prediction with patching: ", decode_tokens(mt.tokenizer, [answer_t])[0], round(answer_prob.cpu().item(), 4))
    print("\n")
  finally:
    remove_hooks(patch_hooks)

# Same patching setup as the old variant: layer-14, last-position state from
# the word-problem prompt written into the last position of "3+11".
mini_exp(dict(
    prompt_patch="What is 667 plus 45?\n",
    prompt="3+11\n",
    layer_patch=14,
    layer_source=14,
    position_patch=-1,
    position_patched=-1,
))

In [None]:
def _plot_top_token_traces(values, token_indices, title, trace_prefix, yaxis_title):
    """Draw and show a Plotly figure with one line per token across layers.

    Args:
        values: 2-D array indexed as values[layer, token_id]
            (assumed (num_layers, vocab_size) — TODO confirm).
        token_indices: token ids to draw; each is labelled via the notebook
            global `tokens` (assumed to map a token id to its string).
        title: figure title.
        trace_prefix: legend prefix, e.g. 'Logit' -> 'Logit of token: <tok>'.
        yaxis_title: y-axis label.

    Returns:
        The displayed plotly Figure.
    """
    layers = list(range(values.shape[0]))
    fig = go.Figure()
    for idx in token_indices:
        fig.add_trace(go.Scatter(
            x=layers,
            y=values[:, idx],
            mode='lines+markers',
            name=f'{trace_prefix} of token: {tokens[idx]}'
        ))
    fig.update_layout(title=title,
                      xaxis_title='Layer',
                      yaxis_title=yaxis_title)
    fig.show()
    return fig


# Tokens that are the arg-max of at least one row (layer) of the logits.
max_indices = np.unique(np.argmax(output_dst_logits, axis=1))

# Plot for logits
fig_logits = _plot_top_token_traces(
    output_dst_logits, max_indices, 'Logits of Top Tokens', 'Logit', 'Logit')

# Plot for softmax values: `softmax` normalises each row to sum to 1, so these
# are probabilities in [0, 1], not percentages.
output_dst_softmax = softmax(output_dst_logits)
fig_softmax = _plot_top_token_traces(
    output_dst_softmax, max_indices, 'Softmax of Top Tokens', 'Softmax', 'Probability')


In [None]:
import plotly.graph_objects as go

# Per-layer logits of every token that is the top-1 prediction at some layer;
# traces are labelled by raw vocabulary index rather than token string.
max_indices = np.unique(np.argmax(output_dst_logits, axis=1))
fig = go.Figure(data=[
    go.Scatter(x=list(range(output_dst_logits.shape[0])),
               y=output_dst_logits[:, token_id],
               mode='lines+markers',
               name=f'Index {token_id}')
    for token_id in max_indices
])
fig.update_layout(title='Logits of top tokens', xaxis_title='Layer', yaxis_title='Logit')
fig.show()

In [None]:
def plot_layer_sweep(output_dst_logits, top_k=5):
  """Plot, per layer, the logit, softmax probability and rank of top tokens.

  The plotted tokens are the union, over layers, of each layer's `top_k`
  highest-logit tokens. Three side-by-side subplots share the layer x-axis:
  raw logit, softmax probability, and 1-based rank (1 = highest logit) on a
  reversed log axis.

  Args:
    output_dst_logits: 2-D array indexed as [layer, token_id]
      (assumed (num_layers, vocab_size) — TODO confirm).
    top_k: number of highest-logit tokens per layer to include (default 5,
      the previous hard-coded value).

  Relies on the notebook globals `color_mapping` and `vocab_list` (defined
  elsewhere), both indexed by token id.
  """
  # Token ids ordered from highest to lowest logit, per layer.
  sorted_indices = np.argsort(-output_dst_logits, axis=1)
  top_k_indices = sorted_indices[:, :top_k]
  max_indices = np.unique(top_k_indices)

  # Compute these values once outside the loop.
  x_values = list(range(output_dst_logits.shape[0]))
  softmax_values = softmax(output_dst_logits)
  # Inverse permutation of `sorted_indices`: ranks[layer, token] is the 1-based
  # rank of `token` at `layer`. This replaces a per-token O(vocab) search.
  ranks = np.argsort(sorted_indices, axis=1) + 1

  fig = make_subplots(rows=1, cols=3)
  for idx in max_indices:
    idx_color = color_mapping[idx]

    # Logits subplot
    fig.add_trace(go.Scatter(x=x_values,
                             y=output_dst_logits[:, idx],
                             mode='lines+markers',
                             name=vocab_list[idx],
                             line=dict(color=idx_color)),
                  row=1, col=1)

    # Softmax subplot
    fig.add_trace(go.Scatter(x=x_values,
                             y=softmax_values[:, idx],
                             mode='lines+markers',
                             line=dict(color=idx_color),
                             showlegend=False),
                  row=1, col=2)

    # Rank subplot (1-based rank, not reciprocal rank)
    fig.add_trace(go.Scatter(x=x_values,
                             y=ranks[:, idx],
                             mode='lines+markers',
                             line=dict(color=idx_color),
                             showlegend=False),
                  row=1, col=3)

  for col in (1, 2, 3):
    fig.update_xaxes(title_text='Layer', col=col)

  fig.update_yaxes(title_text='Logit', col=1)
  fig.update_yaxes(title_text='Softmax Probability', col=2)
  # Log scale, reversed so that rank 1 sits at the top of the panel.
  fig.update_yaxes(title_text='Rank', type='log', autorange="reversed", col=3)

  fig.show()

plot_layer_sweep(output_dst_logits)

# Generate

In [None]:
# Zero-shot addition with spaces and a trailing ' \n'. In the recorded output
# the continuation is only '-' and '\n' tokens, not the sum (79).
_generate('22 + 57 = \n')

(['22', ' +', ' 57', ' =', ' ', '\n'],
 ['-', '-', '-', '-', '-', '-', '-', '\n', '\n', '\n'])

In [None]:
# Same prompt without the trailing space/newline; the recorded continuation is
# again '-' tokens rather than the sum.
_generate('22 + 57 =')

(['22', ' +', ' 57', ' ='],
 ['-', '-', '-', '-', '-', '-', '-', '-', '\n', '\n'])

In [None]:
# One-shot with spaces: a worked example (52 + 33 = 85) precedes the query.
# The recorded continuation is still all '-' tokens.
_generate('52 + 33 = 85\n22 + 57 =')

(['52', ' +', ' 33', ' =', ' 85', '\n', '22', ' +', ' 57', ' ='],
 ['-', '-', '-', '-', '-', '-', '-', '-', '-', '-'])

In [None]:
# Same one-shot prompt without spaces. Numbers tokenize as standalone tokens
# ('52', '+', '33', ...), and the recorded continuation starts with '79'
# (the correct sum).
_generate('52+33=85\n22+57=')

(['52', '+', '33', '=', '85', '\n', '22', '+', '57', '='],
 ['79', '79', '\n', '85', '85', '\n', '\n', '\n', '\n', '\n'])

In [None]:
# Non-arithmetic prompt ending in a space (presumably a sanity check); the
# recorded continuation is only newlines.
_generate('Hi, my name is ')

(['Hi', ',', ' my', ' name', ' is', ' '],
 ['\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n'])

In [None]:
# Two-shot, no spaces. The first recorded output starts with '33' (= 22 + 11).
# NOTE(review): the second recorded output (tokens '222', '+', '333') does not
# match this prompt. It looks like stale output from an earlier edit of this
# cell; re-run to refresh.
_generate('3+8=11\n52+33=85\n22+11=')

(['3',
  '+',
  '8',
  '=',
  '11',
  '\n',
  '52',
  '+',
  '33',
  '=',
  '85',
  '\n',
  '22',
  '+',
  '11',
  '='],
 ['33', '33', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n'])

(['3',
  '+',
  '8',
  '=',
  '11',
  '\n',
  '52',
  '+',
  '33',
  '=',
  '85',
  '\n',
  '222',
  '+',
  '333',
  '='],
 ['555', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n'])

In [None]:
# Instruction-style prompt. The tokenizer splits the operands (' 33' '00',
# ' 3' '333'); the recorded continuation starts '66', '33' (6633, correct).
_generate('Calculate 3300 + 3333.\n')

(['Cal', 'cul', 'ate', ' 33', '00', ' +', ' 3', '333', '.', '\n'],
 ['66', '33', '\n', 'What', ' is', ' -', '0', '.', '1', ' -'])

In [None]:
# Same instruction format with smaller operands; the recorded continuation
# starts with '267' (correct).
_generate('Calculate 34 + 233.\n')

(['Cal', 'cul', 'ate', ' 34', ' +', ' 233', '.', '\n'],
 ['267', '\n', 'What', ' is', ' -', '0', '.', '1', ' -', ' -'])

In [None]:
# One-shot instruction prompt. The recorded continuation starts '44', '33'
# (4433, the correct sum of the query).
# NOTE(review): the demonstration answer 6600 is wrong (3300 + 3333 = 6633);
# confirm whether that is intentional.
_generate('Calculate 3300 + 3333.\n6600\n\nCalculate 3300 + 1133.\n')

(['Cal',
  'cul',
  'ate',
  ' 33',
  '00',
  ' +',
  ' 3',
  '333',
  '.',
  '\n',
  '66',
  '00',
  '\n',
  '\n',
  'Cal',
  'cul',
  'ate',
  ' 33',
  '00',
  ' +',
  ' 11',
  '33',
  '.',
  '\n'],
 ['44', '33', '\n', 'What', ' is', ' -', '0', '.', '1', ' -'])

In [None]:
# One-shot with a comma separator; the recorded continuation (' 2', '*', 'v',
# ...) is not a number.
# NOTE(review): 45 + 21 is 66, not 65; confirm whether the wrong
# demonstration is intentional.
_generate('45 + 21 = 65, 34 + 28 =')

(['45', ' +', ' 21', ' =', ' 65', ',', ' 34', ' +', ' 28', ' ='],
 [' 2', '*', 'v', ' +', ' 2', '*', 'z', ' for', ' v', '.'])

In [None]:
# Demonstration in words, query in symbols, newline-separated; the recorded
# continuation starts with ' 62' (correct).
# NOTE(review): the demonstration again says 45 + 21 = 65 (should be 66).
_generate('45 plus 21 is 65\n 34 + 28 =')

(['45', ' plus', ' 21', ' is', ' 65', '\n', ' 34', ' +', ' 28', ' ='],
 [' 62', '\n', '\n', 'A', ':', '\n', '\n', 'You', ' can', ' use'])

In [None]:
# @title use
# Symbolic demonstration and query, newline-separated. The recorded
# continuation is ' 62' (correct), followed by a repeat of the query line.
# NOTE(review): the demonstration sum 65 is off by one (45 + 21 = 66).
_generate('45 + 21 = 65\n 34 + 28 =')

(['45', ' +', ' 21', ' =', ' 65', '\n', ' 34', ' +', ' 28', ' ='],
 [' 62', '\n', ' 34', ' +', ' 28', ' =', ' 62', '\n', ' 34', ' +'])

In [None]:
# 'added to' phrasing; the recorded continuation ' 78' is wrong
# (10 + 28 = 38).
# NOTE(review): the demonstration 45 + 21 = 65 is also off by one.
_generate('45 added to 21 is 65\n 10 added to 28 is')

(['45',
  ' added',
  ' to',
  ' 21',
  ' is',
  ' 65',
  '\n',
  ' 10',
  ' added',
  ' to',
  ' 28',
  ' is'],
 [' 78', '\n', '\n', 'A', ':', '\n', '\n', 'You', ' can', ' use'])

In [None]:
# @title use
# 'plus' phrasing with a correct demonstration (45 + 22 = 67); the recorded
# continuation starts with ' 58' (correct).
# NOTE(review): the query reads 'plus to'. This looks like a leftover from the
# 'added to' prompt; confirm whether it is intended.
_generate('45 plus 22 is 67\n 30 plus to 28 is')

(['45',
  ' plus',
  ' 22',
  ' is',
  ' 67',
  '\n',
  ' 30',
  ' plus',
  ' to',
  ' 28',
  ' is'],
 [' 58', '\n', '\n', 'A', ':', '\n', '\n', 'You', ' can', ' use'])

In [None]:
# Two-shot, symbolic, newline-separated; the recorded continuation starts
# with ' 130' (correct).
# NOTE(review): the first demonstration is wrong (45 + 22 = 67, not 61);
# confirm whether that is intentional.
_generate('45 + 22 = 61\n 30 + 28 = 58\n 84 + 46 =')

(['45',
  ' +',
  ' 22',
  ' =',
  ' 61',
  '\n',
  ' 30',
  ' +',
  ' 28',
  ' =',
  ' 58',
  '\n',
  ' 84',
  ' +',
  ' 46',
  ' ='],
 [' 130', '\n', '\n', 'The', ' sum', ' of', ' the', ' digits', ' of', ' the'])

In [None]:
# Operator-free format in which each line reads 'a b a+b'. The recorded
# continuation begins with ' 5', not ' 14' (10 + 4).
# NOTE(review): the first line '45 22 61' carries the same wrong sum as above
# (45 + 22 = 67).
_generate('45 22 61\n 30 28 58\n 84 46 130\n 1 1 2\n 2 1 3\n 3 4 7\n 10 4')

(['45',
  ' 22',
  ' 61',
  '\n',
  ' 30',
  ' 28',
  ' 58',
  '\n',
  ' 84',
  ' 46',
  ' 130',
  '\n',
  ' 1',
  ' 1',
  ' 2',
  '\n',
  ' 2',
  ' 1',
  ' 3',
  '\n',
  ' 3',
  ' 4',
  ' 7',
  '\n',
  ' 10',
  ' 4'],
 [' 5', '\n', ' 4', ' 5', ' 6', '\n', ' 5', ' 6', ' 7', '\n'])

In [None]:
# @title use
# Three-shot 'added to' prompt with sentences separated by '. ' (all
# demonstrations are correct).
# NOTE(review): the recorded continuation ' 121' is wrong (38 + 93 = 131).
_generate('4 added to 5 is 9. 48 added to 11 is 59. 30 added to 88 is 118. 38 added to 93 is')

(['4',
  ' added',
  ' to',
  ' 5',
  ' is',
  ' 9',
  '.',
  ' 48',
  ' added',
  ' to',
  ' 11',
  ' is',
  ' 59',
  '.',
  ' 30',
  ' added',
  ' to',
  ' 88',
  ' is',
  ' 118',
  '.',
  ' 38',
  ' added',
  ' to',
  ' 93',
  ' is'],
 [' 121', '.', '\n', '\n', 'The', ' sum', ' of', ' the', ' digits', ' of'])

In [None]:
# Two-shot 'Calculate ... Result:' template on the same query (38 + 93).
# The recorded continuation ' 121' is again wrong (should be 131).
_generate('Calculate 7 + 1. Result: 8\nCalculate 6 + 6. Result: 12\nCalculate 38 + 93. Result:')

(['Cal',
  'cul',
  'ate',
  ' 7',
  ' +',
  ' 1',
  '.',
  ' Result',
  ':',
  ' 8',
  '\n',
  'Cal',
  'cul',
  'ate',
  ' 6',
  ' +',
  ' 6',
  '.',
  ' Result',
  ':',
  ' 12',
  '\n',
  'Cal',
  'cul',
  'ate',
  ' 38',
  ' +',
  ' 93',
  '.',
  ' Result',
  ':'],
 [' 121', '\n', 'Cal', 'cul', 'ate', ' -', '1', ' +', ' -', '1'])

In [None]:
# One-shot version of the 'Result:' template with an easy query; the recorded
# continuation starts with ' 12' (correct).
_generate('Calculate 7 + 1. Result: 8\nCalculate 6 + 6. Result:')

(['Cal',
  'cul',
  'ate',
  ' 7',
  ' +',
  ' 1',
  '.',
  ' Result',
  ':',
  ' 8',
  '\n',
  'Cal',
  'cul',
  'ate',
  ' 6',
  ' +',
  ' 6',
  '.',
  ' Result',
  ':'],
 [' 12', '\n', 'Cal', 'cul', 'ate', ' 7', ' +', ' -', '1', ' +'])

In [None]:
# 'Adding a and b gives' template. The recorded continuation ' 103' is
# correct and is followed by a generated 'Adding 12 and 91 gives 103.'
# sentence.
_generate('Adding 4 and 4 gives 8. Adding 9 and 1 gives 10. Adding 91 and 12 gives')

(['Adding',
  ' 4',
  ' and',
  ' 4',
  ' gives',
  ' 8',
  '.',
  ' Adding',
  ' 9',
  ' and',
  ' 1',
  ' gives',
  ' 10',
  '.',
  ' Adding',
  ' 91',
  ' and',
  ' 12',
  ' gives'],
 [' 103',
  '.',
  ' Adding',
  ' 12',
  ' and',
  ' 91',
  ' gives',
  ' 103',
  '.',
  ' Adding'])

In [None]:
# One-shot variant of the 'Adding ... gives' template. The recorded
# continuation ' 110' is correct, and so is the follow-on sentence it
# generates ('Adding 4 and 91 gives 95.').
_generate('Adding 4 and 4 gives 8. Adding 91 and 19 gives')

(['Adding',
  ' 4',
  ' and',
  ' 4',
  ' gives',
  ' 8',
  '.',
  ' Adding',
  ' 91',
  ' and',
  ' 19',
  ' gives'],
 [' 110',
  '.',
  ' Adding',
  ' 4',
  ' and',
  ' 91',
  ' gives',
  ' 95',
  '.',
  ' Adding'])

In [None]:
# 'The sum of a and b is' template (correct demonstrations); the recorded
# continuation starts with ' 57' (correct).
_generate('The sum of 5 and 5 is 10. The sum of 91 and 12 is 103. The sum of 38 and 19 is')

(['The',
  ' sum',
  ' of',
  ' 5',
  ' and',
  ' 5',
  ' is',
  ' 10',
  '.',
  ' The',
  ' sum',
  ' of',
  ' 91',
  ' and',
  ' 12',
  ' is',
  ' 103',
  '.',
  ' The',
  ' sum',
  ' of',
  ' 38',
  ' and',
  ' 19',
  ' is'],
 [' 57', '.', ' The', ' sum', ' of', ' 5', ' and', ' 5', ' is', ' 10'])

In [None]:
# Non-arithmetic repetition pattern (each word appears twice) ending on a
# single ' 99'; the recorded continuation is '.' rather than a second ' 99'.
_generate('test test stitch stitch lead lead animal animal 99')

(['test',
  ' test',
  ' stitch',
  ' stitch',
  ' lead',
  ' lead',
  ' animal',
  ' animal',
  ' 99'],
 ['.',
  '\n',
  '\n',
  'The',
  ' first',
  ' stitch',
  ' lead',
  ' lead',
  ' animal',
  ' animal'])

In [None]:
# Same doubled-token pattern with mixed letters/numbers, ending on a single
# ' 35'. The recorded continuation ('.', '0', ' 35', ...) does not repeat
# ' 35' immediately.
_generate('a a b b 8 8 n n cat cat dog dog 35')

(['a',
  ' a',
  ' b',
  ' b',
  ' 8',
  ' 8',
  ' n',
  ' n',
  ' cat',
  ' cat',
  ' dog',
  ' dog',
  ' 35'],
 ['.', '0', ' 35', '.', '0', ' 35', '.', '0', ' 35', '.'])