In [1]:
import transformers
import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np
import os
import re
import random
import datasets
from datasets import load_dataset
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Dict
from tqdm import tqdm
import pickle
from dotenv import load_dotenv
import openai

from typing import List, Optional, Tuple, Dict
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import Dataset

import plotly.graph_objects as go
import plotly.express as px

from utils import untuple
from scripts.get_activations import gen_pile_data, compare_token_lists, slice_acts
from act_add.model_wrapper import ModelWrapper
%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
file_path = 'data/12b'

all_mem_12b_data = pd.read_csv(f'{file_path}/mem_evals_gen_data.csv')
all_mem_12b_data

Unnamed: 0,gen,ground,mem_status,source,idx_in_ground,idx_in_hidden_states,char_by_char_similarity
0,"xsd:string"" minOccurs=""0"" msdata:Ordinal=""1"" /...","xsd:string"" minOccurs=""0"" msdata:Ordinal=""1"" /...",1.000000,pythia-evals,0,0,1.000000
1,(contacts == null)? null : contacts.get(i);\n...,(contacts == null)? null : contacts.get(i);\n...,1.000000,pythia-evals,1,1,1.000000
2,"// Wrap renderValue to handle initialization, ...","// Wrap renderValue to handle initialization, ...",0.546875,pythia-evals,2,2,0.579710
3,cdcdc; text-align:center| Midfielders\n\n|-\n!...,cdcdc; text-align:center| Midfielders\n\n|-\n!...,1.000000,pythia-evals,3,3,1.000000
4,"\n\t\t<read echo=""ascii""><delim>\n</delim><mat...","\n\t\t<read echo=""ascii""><delim>\n</delim><mat...",1.000000,pythia-evals,4,4,1.000000
...,...,...,...,...,...,...,...
9950,welcome to dermatillomania support\n\nQuestion...,welcome to dermatillomania support\n\nQuestion...,0.531250,pile,4995,4962,0.510121
9951,Credits Suisse Group (CS) allegedly helped sel...,Credits Suisse Group (CS) allegedly helped sel...,0.515625,pile,4996,4963,0.557377
9952,Boating Checklist for the Chesapeake Bay\n\nWa...,Boating Checklist for the Chesapeake Bay\n\nWa...,0.515625,pile,4997,4964,0.490654
9953,AMD announces the Ryzen 3 3100 and 3 3300X pro...,AMD announces the Ryzen 3 3100 and 3 3300X pro...,0.531250,pile,4998,4965,0.467249


In [3]:
model_name_or_path = "meta-llama/Llama-2-7b-hf"

model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16, device_map="auto").eval()
use_fast_tokenizer = "LlamaForCausalLM" not in model.config.architectures
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast = use_fast_tokenizer, padding_side="left")
tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id
tokenizer.bos_token_id = 1

mw = ModelWrapper(model = model, tokenizer = tokenizer)

Loading checkpoint shards: 100%|██████████| 2/2 [01:01<00:00, 30.84s/it]


## creating dataset for llama activation run

In [4]:
source = []
idx_in_source = []
all_tokens = []
char_by_char = []
for i, row in all_mem_12b_data.iterrows():
    tokens = tokenizer(row['ground']).input_ids
    
    if len(tokens) < 64:
        continue
    
    all_tokens.append(tokens[:64])
    source.append(row['source'])
    char_by_char.append(row['char_by_char_similarity'])
    idx_in_source.append(row['idx_in_ground'])

llama_mem_data = pd.DataFrame({'source': source, 'idx_in_source': idx_in_source, 'ground_tokens': all_tokens, 'char_by_char': char_by_char})

In [5]:
llama_mem_data['ground_prompts'] = tokenizer.batch_decode(all_tokens)

In [6]:
llama_mem_data

Unnamed: 0,source,idx_in_source,ground_tokens,char_by_char,ground_prompts
0,pythia-evals,0,"[1, 921, 4928, 29901, 1807, 29908, 1375, 22034...",1.000000,"<s> xsd:string"" minOccurs=""0"" msdata:Ordinal=""..."
1,pythia-evals,1,"[1, 29871, 313, 12346, 29879, 1275, 1870, 6877...",1.000000,<s> (contacts == null)? null : contacts.get(i...
2,pythia-evals,2,"[1, 849, 399, 2390, 4050, 1917, 304, 4386, 178...",0.579710,<s> // Wrap renderValue to handle initializati...
3,pythia-evals,3,"[1, 14965, 2252, 29883, 29936, 1426, 29899, 25...",1.000000,<s> cdcdc; text-align:center| Midfielders\n\n|...
4,pythia-evals,4,"[1, 29871, 13, 12, 12, 29966, 949, 2916, 543, ...",1.000000,"<s> \n\t\t<read echo=""ascii""><delim>\n</delim>..."
...,...,...,...,...,...
9860,pile,4995,"[1, 12853, 304, 589, 2922, 453, 2480, 423, 230...",0.510121,<s> welcome to dermatillomania support\n\nQues...
9861,pile,4996,"[1, 24596, 1169, 27968, 6431, 313, 9295, 29897...",0.557377,<s> Credits Suisse Group (CS) allegedly helped...
9862,pile,4997,"[1, 1952, 1218, 5399, 1761, 363, 278, 678, 267...",0.490654,<s> Boating Checklist for the Chesapeake Bay\n...
9863,pile,4998,"[1, 319, 5773, 7475, 778, 278, 15586, 2256, 29...",0.467249,<s> AMD announces the Ryzen 3 3100 and 3 3300X...


In [8]:
llama_mem_data.to_csv("data/llama-2-7b/llama_ground_data.csv")

## testing gen

In [55]:
mw = ModelWrapper(model = model, tokenizer = tokenizer)

In [38]:
mw.model.config.num_hidden_layers

32

In [57]:
tokenizer(['what the hell', 'asdfohdf aodaf', 'ohh yea'])

{'input_ids': [[1, 825, 278, 23927], [1, 408, 2176, 1148, 2176, 263, 397, 2142], [1, 9360, 29882, 8007, 29874]], 'attention_mask': [[1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]}

In [59]:
out = mw.batch_generate_from_string(['what the hell', 'asdfohdf aodaf', 'ohh yea'], 
                                    max_new_tokens = 6, 
                                    output_hidden_states = True,
                                    output_tokens = True,
                                    layers = list(range(32)),
                                    tok_idxs = [1, 3, 5, 6],
                                    return_prompt_acts = True)

torch.Size([3, 33, 13, 4096])


In [4]:
memmed_pythia_list = list(all_mem_12b_data[all_mem_12b_data['char_by_char_similarity'] > 0.98].ground.values)[:50]

In [None]:
# Generate completions for each quote
from utils import eval_completions

memmed_first_half = []
completions = []
for prompt in memmed_pythia_list:
    prompt_parts = prompt.split()
    first_half = " ".join(prompt_parts[:len(prompt_parts)//2])
    memmed_first_half.append(first_half)
    
completions = mw.batch_generate_from_string(memmed_first_half, max_new_tokens=36)

# Evaluate completions
evaluations = eval_completions(completions, memmed_pythia_list, return_mean = False)


In [8]:
# Print the evaluations
for i in range(len(completions)):
    print(f"Ground: {memmed_pythia_list[i]}")
    print(f"Completion: {completions[i]}")
    print(f"char: {evaluations['char_by_char_similarity'][i]}")
    print(f"lev: {evaluations['lev_distance'][i]}")
    print()

Ground: xsd:string" minOccurs="0" msdata:Ordinal="1" />
                <xsd:element name="comment" type="xsd:string" minOccurs="0" msdata:Ordinal="2" />
              </xsd:sequence>
              <xsd:attribute name
Completion: xsd:string" minOccurs="0" msdata:Ordinal="1" /> <xsd:element name="comment" type="xsd:string" minOccurs="0" msdata:Ordinal="2" /> <xsd:element name="link" type="xsd:string" min
char: 0.7484662576687117
lev: 0.69377990430622

Ground:  (contacts == null)? null : contacts.get(i);
    }

    @Override
    public long getItemId(int i) {
        return i;
    }

    @Override
    public View getView(int i, View view, ViewGroup viewGroup)
Completion: (contacts == null)? null : contacts.get(i); } @Override public long getItemId(int i) { return i; }

}
\end{code}

\strong{Fragment_contacts.xml}

\begin{code}
<LinearLayout
char: 0.5704697986577181
lev: 0.5247524752475248

Ground: cdcdc; text-align:center| Midfielders

|-
! colspan=10 style=background:#dcdcdc; text-align

In [6]:
ground_idxs = [idx for idx, lev in enumerate(evaluations['lev_distance']) if lev > 0.7]

In [7]:
len(ground_idxs)

16

In [12]:
llama_mem_data.shape

(9865, 6)

In [5]:
llama_mem_data = pd.read_csv("data/llama-2-7b/llama_ground_data.csv")

In [10]:
llama_states = torch.load("data/llama-2-7b/all_hidden_states.pt")
llama_gens = pickle.load(open("data/llama-2-7b/all_generations.pkl", "rb"))
llama_mem_status = pickle.load(open("data/llama-2-7b/all_mem_status.pkl", "rb"))
llama_genned_toks = pickle.load(open("data/llama-2-7b/all_tokens.pkl", "rb"))

In [11]:
llama_states.shape

torch.Size([15, 32, 10, 4096])

In [9]:
len(llama_mem_status['lev_distance'])

20