<a href="https://colab.research.google.com/github/avyaymc/Semantic-Sense/blob/main/texttic.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Set up

In [None]:
# Install TransformerLens (straight from its GitHub repository) and CircuitsVis
# into the notebook kernel's environment.
%pip install git+https://github.com/neelnanda-io/TransformerLens.git
%pip install circuitsvis

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/neelnanda-io/TransformerLens.git
  Cloning https://github.com/neelnanda-io/TransformerLens.git to /tmp/pip-req-build-qy1crtsi
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-qy1crtsi
  Resolved https://github.com/neelnanda-io/TransformerLens.git to commit 186bc6c2fd2666fd370f32fc7e9a611b26999d19
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Collecting datasets>=2.7.1
  Downloading datasets-2.11.0-py3-none-any.whl (468 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m468.7/468.7 KB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting einops>=0.6.0
  Downloading einops-0.6.0-py3-none-any.whl (41 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━

In [None]:
from circuitsvis.activations import text_neuron_activations
from transformer_lens import HookedTransformer
import torch as th

import torch as th
from circuitsvis.activations import text_neuron_activations

class NeuronTextSimplifier:
    def __init__(self, model, layer: int, neuron: int) -> None:
        self.model = model
        self.layer = layer
        self.neuron = neuron
        self.model.requires_grad_(False)
        self.embed_weights = list(list(model.children())[0].parameters())[0]
        transformer_blocks = [mod for mod in list(self.model.children())[2]]
        self.model_no_embed = th.nn.Sequential(*(transformer_blocks[:layer+1]))
        self.model_no_embed.requires_grad_(False)
        self.set_hooks()

    def set_hooks(self):
        self._neurons = th.empty(0)
        def hook(model, input, output):
            self._neurons = output
        self.model.blocks[self.layer].mlp.hook_pre.register_forward_hook(hook)

    def get_neuron_activation(self, tokens):
        _, cache = self.model.run_with_cache(tokens)
        return cache[f"blocks.{self.layer}.mlp.hook_pre"][0,:,self.neuron].tolist()

    def text_to_activations_print(self, text):
        token = self.model.to_tokens(text, prepend_bos=False)
        act = self.get_neuron_activation(token)
        act = [f" [{a:.2f}]" for a in act]
        if(token.shape[-1] > 1):
            string = self.model.to_str_tokens(token, prepend_bos=False)
        else:
            string = self.model.to_string(token)
        res = [None]*(len(string)+len(act))
        res[::2] = string
        res[1::2] = act
        return "".join(res)

    def text_to_visualize(self, text):
        if isinstance(text, list):
            text_list = []
            act_list = []
            for t in text:
                split_text = self.model.to_str_tokens(t, prepend_bos=False)
                token = self.model.to_tokens(t, prepend_bos=False)
                text_list += [x.replace('\n', '\\newline') for x in split_text] + ["\n"]
                act_list+= self.get_neuron_activation(token) + [0.0]
            act_list = th.tensor(act_list).reshape(-1,1,1)
            return text_neuron_activations(tokens=text_list, activations=act_list)
        elif isinstance(text, str):
            split_text = self.model.to_str_tokens(text, prepend_bos=False)
            token = self.model.to_tokens(text, prepend_bos=False)
            act = th.tensor(self.get_neuron_activation(token)).reshape(-1,1,1)
            return text_neuron_activations(tokens=split_text, activations=act)
        else:
            raise TypeError("text must be of type str or list, not {type(text)}")

    def get_text_and_activations_iteratively(self, text):
        tokens = self.model.to_tokens(text, prepend_bos=False)[0]
        original_activation = self.get_neuron_activation(tokens)
        # To get around the newline issue, we replace the newline with \newline and then add a newline at the end
        text_list = [x.replace('\n', '\\newline') for x in self.model.to_str_tokens(text, prepend_bos=False)] + ["\n"]
        act_list = original_activation + [0.0]
        changes = th.zeros(tokens.shape[-1])+100
        for j in range(len(tokens)-1):
            for i in range(len(tokens)):
                changes[i] = self.get_neuron_activation(th.cat((tokens[:i],tokens[i+1:])))[-1]
            max_ind = changes.argmax()
            changes = th.cat((changes[:max_ind], changes[max_ind+1:]))
            tokens = th.cat((tokens[:max_ind],tokens[max_ind+1:]))
            if(tokens.shape[-1] > 1):
                out_text = self.model.to_str_tokens(tokens, prepend_bos=False)
                text_list += [x.replace('\n', '\\newline') for x in out_text] + ["\n"]
            else:
                out_text = self.model.to_string(tokens)
                text_list += [out_text.replace('\n', '\\newline')] + ["\n"]
            act_list += self.get_neuron_activation(tokens) + [0.0]
        text_list = text_list
        act_list = th.tensor(act_list).reshape(-1,1,1)
        return text_list, act_list

    def visualize_text_color_iteratively(self, text):
        if(isinstance(text, str)):
            text_list, act_list = self.get_text_and_activations_iteratively(text)
            return text_neuron_activations(tokens=text_list, activations=act_list)
        elif(isinstance(text, list)):
            text_list_final = []
            act_list_final = []
            for t in range(len(text)):
                text_list, act_list = self.get_text_and_activations_iteratively(text[t])
                text_list_final.append(text_list)
                act_list_final.append(act_list)
            return text_neuron_activations(tokens=text_list_final, activations=act_list_final)

    def simplify_iteratively(self, text):
        # Iteratively remove text that has smallest decrease in activation
        # Print out the change in activation for the largest changes, ie if the change is larger than threshold*original_activation
        tokens = self.model.to_tokens(text, prepend_bos=False)[0]
        self.text_to_activations_print(self.model.to_string(tokens))
        original_activation = self.get_neuron_activation(tokens)[-1]
        changes = th.zeros(tokens.shape[-1])+100
        for j in range(len(tokens)-1):
            for i in range(len(tokens)):
                changes[i] = self.get_neuron_activation(th.cat((tokens[:i],tokens[i+1:])))[-1]
            max_ind = changes.argmax()
            changes = th.cat((changes[:max_ind], changes[max_ind+1:]))
            tokens = th.cat((tokens[:max_ind],tokens[max_ind+1:]))
            out_text = self.model.to_string(tokens)
            print(self.text_to_activations_print(out_text))
        return

    # Assign neuron and layer
    def set_layer_and_neuron(self, layer, neuron):
        self.layer = layer
        self.neuron = neuron
        self.set_hooks()

    def embedded_forward(self, embedded_x):
        self.model_no_embed(embedded_x)
        return self._neurons

    def forward(self, x):
        self.model(x)
        return self._neurons

    def prompt_optimization(
            self,
            diverse_outputs_num=10,
            iteration_cap_until_convergence = 30,
            init_text = None,
            seq_size = 4,
            insert_words_and_pos = None, #List of words and positions to insert [word, pos]
            neuron_loss_scalar = 1,
            diversity_loss_scalar = 1,
        ):
        """Search for prompts that maximise the tracked neuron's last-position value.

        For each of ``diverse_outputs_num`` rounds a continuous prompt
        embedding is optimised with AdamW. On every step the embedding is
        projected onto the closest rows of the embedding matrix (by cosine
        similarity), the neuron is evaluated on the projected embedding via
        ``embedded_forward``, and the gradient computed at the projected
        embedding is copied onto the continuous embedding before the
        optimiser step. A round ends after ``iteration_cap_until_convergence``
        consecutive steps without a new best activation.

        Args:
            diverse_outputs_num: Number of rounds, i.e. prompts to produce.
            iteration_cap_until_convergence: Steps without improvement after
                which a round stops.
            init_text: Optional starting text. When given, ``seq_size`` is
                replaced by its token count.
            seq_size: Number of optimised positions when ``init_text`` is None.
            insert_words_and_pos: Optional ``[text, pos]``. The embedding of
                ``text`` is spliced in at position ``pos`` on every step;
                ``pos == -1`` means "after the last optimised position".
            neuron_loss_scalar: Weight of the (negated) neuron activation term.
            diversity_loss_scalar: Weight of the cosine-similarity term against
                the results of earlier rounds.

        Returns:
            A list of length ``diverse_outputs_num`` holding the decoded best
            prompt of each round. An entry stays ``None`` if the activation
            never exceeded the initial ``largest_activation`` of 0 in that round.
        """
        _, _, embed_size = self.model.W_out.shape
        vocab_size = self.model.W_E.shape[0]
        largest_prompts = [None]*diverse_outputs_num
        cos = th.nn.CosineSimilarity(dim=1)
        total_iterations = 0

        if init_text is not None:
            init_tokens = self.model.to_tokens(init_text, prepend_bos=False)
            seq_size = init_tokens.shape[-1]
        # Final projected embeddings of each finished round, used by the diversity term.
        # NOTE(review): created without a device argument — confirm it ends up on
        # the same device as the model's embeddings before the cosine comparison below.
        diverse_outputs = th.zeros(diverse_outputs_num, seq_size, embed_size)
        for d_ind in range(diverse_outputs_num):
            print(f"Starting diverse output {d_ind}")
            if init_text is None:
                # Random tokens of sequence length
                init_tokens = th.randint(0, vocab_size, (1,seq_size))
                # NOTE(review): init_text is assigned here, so later rounds skip this
                # branch and reuse the same init_tokens — confirm that is intended.
                init_text = self.model.to_string(init_tokens)
            prompt_embeds = th.nn.Parameter(self.model.embed(init_tokens)).detach()
            prompt_embeds.requires_grad_(True)

            optim = th.optim.AdamW([prompt_embeds], lr=.8, weight_decay=0.01)
            largest_activation = 0
            largest_prompt = None

            iterations_since_last_improvement = 0
            while(iterations_since_last_improvement < iteration_cap_until_convergence):
                # First, project into the embedding matrix: for every position pick
                # the vocabulary row with the highest cosine similarity.
                with th.no_grad():
                    projected_index = th.stack([cos(self.embed_weights,prompt_embeds[0,i,:]).argmax() for i in range(seq_size)]).unsqueeze(0)
                    projected_embeds = self.model.embed(projected_index)

                # Create a temp embedding that is detached from the graph, but has the same data as the projected embedding
                tmp_embeds = prompt_embeds.detach().clone()
                tmp_embeds.data = projected_embeds.data
                # add some gaussian noise to tmp_embeds
                # tmp_embeds.data += th.randn_like(tmp_embeds.data)*0.01
                tmp_embeds.requires_grad_(True)

                if insert_words_and_pos is not None:
                    text = insert_words_and_pos[0]
                    pos = insert_words_and_pos[1]
                    # -1 means "append after the optimised positions".
                    if(pos == -1):
                        pos = seq_size
                    token = self.model.to_tokens(text, prepend_bos=False)
                    token_embeds = self.model.embed(token)
                    token_pos = pos
                    # Splice the fixed text's embeddings into the projected prompt.
                    wrapped_embeds = th.cat([tmp_embeds[0,:token_pos], token_embeds[0], tmp_embeds[0,token_pos:]], dim=0).unsqueeze(0)
                    # Only on the very first step: print what the spliced prompt decodes to.
                    if(total_iterations == 0):
                        wrapped_embeds_seq_len = wrapped_embeds.shape[1]
                        projected_index = th.stack([cos(self.embed_weights,wrapped_embeds[0,i,:]).argmax() for i in range(wrapped_embeds_seq_len)]).unsqueeze(0)
                        print(f"Inserting {text} at pos {pos}: {self.model.to_str_tokens(projected_index, prepend_bos=False)}")
                else:
                    wrapped_embeds = tmp_embeds

                # Then, calculate neuron_output
                neuron_output = self.embedded_forward(wrapped_embeds)[0,:, self.neuron]
                # NOTE(review): `cos` reduces over dim=1; after broadcasting against the
                # diverse_outputs[:d_ind] slice that is presumably the sequence axis rather
                # than the embedding axis — TODO confirm. For d_ind == 0 the slice is empty,
                # so `.mean()` below is taken over zero elements — confirm the resulting
                # loss/gradient is as intended.
                diversity_loss = cos(tmp_embeds[0], diverse_outputs[:d_ind])
                loss = neuron_loss_scalar*-neuron_output[-1] + diversity_loss_scalar*diversity_loss.mean()

                # Save the highest activation
                if neuron_output[-1] > largest_activation:
                    iterations_since_last_improvement = 0
                    largest_activation = neuron_output[-1]
                    wrapped_embeds_seq_len = wrapped_embeds.shape[1]
                    projected_index = th.stack([cos(self.embed_weights,wrapped_embeds[0,i,:]).argmax() for i in range(wrapped_embeds_seq_len)]).unsqueeze(0)
                    largest_prompt = self.model.to_string(projected_index)
                    largest_prompts[d_ind] = largest_prompt
                    print(f"New largest activation: {largest_activation} | {largest_prompt}")

                # Transfer the gradient to the continuous embedding space: the gradient is
                # taken w.r.t. the projected tmp_embeds but applied to prompt_embeds.
                prompt_embeds.grad, = th.autograd.grad(loss, [tmp_embeds])

                optim.step()
                optim.zero_grad()
                total_iterations += 1
                # Incremented even on an improving step, so the counter is 1 right after
                # an improvement.
                iterations_since_last_improvement += 1
            # Remember this round's last projected embedding for later diversity terms.
            diverse_outputs[d_ind] = tmp_embeds.data[0,...]
        return largest_prompts

In [None]:
# Import Transformer Lens, and load pythia models
from transformer_lens import HookedTransformer
import torch as th
from torch import nn
import numpy as np
from datasets import Dataset, load_dataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from tqdm import tqdm
from einops import rearrange
# Run on the GPU when one is available, otherwise on the CPU.
if th.cuda.is_available():
    device = th.device("cuda")
else:
    device = th.device("cpu")

model = HookedTransformer.from_pretrained("EleutherAI/pythia-160m-deduped", device=device)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m-deduped")


Token_amount = 20
layer = 6

# Load the training set from pile-10k, then:
#   1. tokenize every text (batched),
#   2. keep only rows with more than Token_amount tokens,
#   3. truncate each remaining row to exactly Token_amount tokens.
d = load_dataset("NeelNanda/pile-10k", split="train")
d = d.map(lambda rows: tokenizer(rows['text']), batched=True)
d = d.filter(lambda row: len(row['input_ids']) > Token_amount)
d = d.map(lambda row: {'input_ids': row['input_ids'][:Token_amount]})

neurons = model.W_in.shape[-1]
datapoints = d.num_rows
batch_size = 64

# Activations for every (row, position) pair are cached on disk per layer.
activations_path = f"activations_layer_{layer}.pt"
try:
    neuron_activations = th.load(activations_path)
    print("Loaded activations from file")
except FileNotFoundError:
    # Cache miss: compute the activations and store them for the next run.
    # Only a missing file triggers this branch; any other failure (including
    # an interrupt) now surfaces instead of starting a long recomputation.
    # The buffer is allocated here so that a successful load does not first
    # allocate a full-size zero tensor.
    neuron_activations = th.zeros((datapoints*Token_amount, neurons))
    with th.no_grad(), d.formatted_as("pt"):
        dl = DataLoader(d["input_ids"], batch_size=batch_size)
        for i, batch in enumerate(tqdm(dl)):
            _, cache = model.run_with_cache(batch.to(device))
            # Flatten (batch, seq, neuron) to (batch*seq, neuron) and write it
            # into the rows reserved for batch i.
            neuron_activations[i*batch_size*Token_amount:(i+1)*batch_size*Token_amount,:] = rearrange(cache[f"blocks.{layer}.mlp.hook_pre"], "b s n -> (b s) n" )
    th.save(neuron_activations, activations_path)

Downloading (…)lve/main/config.json:   0%|          | 0.00/569 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/375M [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/394 [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-160m-deduped into HookedTransformer


Downloading metadata:   0%|          | 0.00/921 [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/373 [00:00<?, ?B/s]

Downloading and preparing dataset None/None (download: 31.72 MiB, generated: 58.43 MiB, post-processed: Unknown size, total: 90.15 MiB) to /root/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--pile-10k-72f566e9f7c464ab/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/33.3M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/10000 [00:00<?, ? examples/s]

Dataset parquet downloaded and prepared to /root/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--pile-10k-72f566e9f7c464ab/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec. Subsequent calls will reuse this data.


Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/10000 [00:00<?, ? examples/s]

Map:   0%|          | 0/9959 [00:00<?, ? examples/s]

100%|██████████| 156/156 [00:17<00:00,  9.14it/s]


# New section

In [None]:
# Inspect neuron 0 of the layer selected above.
neuron = 0
simplifier = NeuronTextSimplifier(model=model, layer=layer, neuron=neuron)

In [None]:
import nltk
# Download the "punkt" tokenizer package into the local NLTK data directory.
# NOTE(review): the preprocessing below uses WordPunctTokenizer; confirm whether
# punkt is still needed elsewhere in the notebook.
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [None]:
import string
nltk.download('stopwords')
from nltk.corpus import stopwords
from nltk.tokenize import WordPunctTokenizer

def preprocess(sentence):
    """Normalise ``sentence`` for the neuron analysis.

    Steps: delete every character in ``string.punctuation``, split with
    ``WordPunctTokenizer``, drop tokens for which ``str.isnumeric`` is true,
    lower-case the rest and join them with single spaces. No stopword
    filtering is applied.

    Returns:
        The processed sentence as one string.
    """
    # A translation table that maps every punctuation character to "deleted".
    drop_punctuation = str.maketrans('', '', string.punctuation)
    stripped = sentence.translate(drop_punctuation)

    kept = []
    for piece in WordPunctTokenizer().tokenize(stripped):
        if piece.isnumeric():
            continue
        kept.append(piece.lower())

    return ' '.join(kept)


[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.


In [None]:
# Example sentences, each containing one out-of-place word.
text_list = [
    "I hate bald school",
    "I own a crap store",
    "I eat chair food",
]

# Normalise every sentence before handing it to the simplifier.
text_prepro = list(map(preprocess, text_list))
print(text_prepro)
simplifier.visualize_text_color_iteratively(text_prepro)

['i hate bald school', 'i own a crap store', 'i eat chair food']


In [None]:
# `l` has one entry per sentence in `text_prepro`. Each entry is the
# (text_list, act_list) pair returned by get_text_and_activations_iteratively:
#   - text_list: a flat list of token strings covering the full sentence and
#     every shortened version of it, each stage terminated by a "\n" entry;
#   - act_list: a tensor of the matching activations, reshaped to (-1, 1, 1),
#     with 0.0 at every "\n" position.
l = [simplifier.get_text_and_activations_iteratively(i) for i in text_prepro]
#print(l)

In [None]:
# Drop all the simplification stages and keep only the first pass, i.e. the
# tokens and activations of the full sentence. Stages are separated by "\n"
# entries, so the first pass is everything before the first "\n".
# (De-duplicating by token string instead would drop tokens that legitimately
# repeat inside a sentence and would keep the "\n" separator, whose 0.0
# activation is not a real measurement.)
import torch
for i, (stage_tokens, stage_acts) in enumerate(l):
    # If there is no separator (e.g. this cell is re-run), keep everything.
    first_pass_len = stage_tokens.index("\n") if "\n" in stage_tokens else len(stage_tokens)
    # Slicing the (N, 1, 1) activation tensor keeps the (n, 1, 1) layout.
    l[i] = (stage_tokens[:first_pass_len], stage_acts[:first_pass_len])

#print(l)


In [None]:
#look at each of the tokens with minimum activation, check if its a full word in input and if it is, remove the lowest activation
#else, we look at second lowest activation
import re

import numpy as np

# For every sentence, walk its tokens from lowest to highest activation and
# remove the first one that occurs in the original sentence as a whole word.
for i, (words, values) in enumerate(l):
    # Flatten the (n, 1, 1) activations and get token indices in ascending order.
    ascending = np.argsort(np.asarray(values).reshape(-1))
    for word_index in ascending:
        word = words[int(word_index)].strip()
        if not word:
            # Separator/whitespace tokens would "match" everywhere and end the
            # search without removing anything.
            continue
        # Whole-word, case-insensitive match: the tokens come from lower-cased
        # text while text_list keeps the original casing, and a plain substring
        # test would also hit letters inside other words (e.g. "i" in "chair").
        whole_word = re.compile(rf"(?<!\w){re.escape(word)}(?!\w)", flags=re.IGNORECASE)
        if whole_word.search(text_list[i]):
            text_list[i] = whole_word.sub("", text_list[i], count=1)
            break

# Collapse the gap left by the removed word and trim the ends.
text_list = [" ".join(text.split()) for text in text_list]

print(text_list)


['I hate  school', 'I own a  store', 'I eat  food']
