# Introduction to the main tools 

This project currently uses the Transformer Lens library, which makes it easy to work with PyTorch hooks.
- Main page: <https://transformerlensorg.github.io/TransformerLens/>
- Getting started: <https://transformerlensorg.github.io/TransformerLens/content/getting_started.html>
- Tutorials (excellent): <https://transformerlensorg.github.io/TransformerLens/content/tutorials.html>

I highly recommend the excellent ARENA course to explore the techniques used in this paper (DLA, attribution patching, etc.):
- Website: <https://www.arena.education/> 
- Course: <https://arena-chapter1-transformer-interp.streamlit.app/>

This notebook covers what you need to understand and use the code of the paper.

## Lens 
I wrote a class named `Lens` in `ssr/lens.py` with three main purposes:
- Quickly load preconfigured LLMs
- Easily apply the correct chat template
- Run batched scans that store activations on the CPU

The main SSR algorithm (`ssr/core.py`) only needs the `Lens` class to apply the chat template. I plan to change this so that the core algorithm depends only on Transformer Lens, not on my custom `Lens` class.

Below, I present the three main features of my custom `Lens` class.

### 1. Quick load of preconfigured LLMs

I used four main LLMs in this work: 
- Gemma 2 2b: `gemma2_2b`, <https://huggingface.co/google/gemma-2-2b-it> (gated)
- Llama 3.2 1b: `llama3.2_1b`, <https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct> (gated)
- Llama 3.2 3b: `llama3.2_3b`, <https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct> (gated)
- Qwen 2.5 1.5b: `qwen2.5_1.5b`, <https://huggingface.co/Qwen/Qwen2.5-1.5B-Instruct>

As chat templates may vary between versions, I took the official jinja template for each model, stored it in `ssr/templates/*`, and stuck to these templates for every experiment.

The rest of the configuration lives in the `models.toml` file, at the root of the project.

To get the default config for an LLM, first make sure `models.toml` is at the root of the project; otherwise, set the `MODELS_PATH` value in the environment variables (`.env`). You can then access the config with the `model_info` function from `lens.py`:

In [1]:
from rich import print 
from ssr.lens import model_info

print(model_info("llama3.2_1b"))

```python
{
    'chat_template': 'llama3.2.jinja2',                 # chat template file (ssr/templates/llama3.2.jinja2)
    'lm_studio': 'llama-3.2-1b-instruct',               # name of the model in LM Studio
    'model_name': 'meta-llama/Llama-3.2-1B-Instruct',   # name of the model in Transformer Lens
    'other_names': ['Llama-3.2-1B-Instruct'],           # other names (needed by SAELens, for instance)
    'restricted_tokens': ['128000-128255']              # restricted token ID range, excluded from adversarial candidates (e.g. <eos> in the base scenario)
}
```
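Entries in `restricted_tokens` appear to be strings of the form `"start-end"`. A minimal sketch of how they could be expanded into token IDs and turned into a mask over candidate tokens, assuming both bounds are inclusive (so `'128000-128255'` covers 256 IDs, Llama 3's special-token range) and that single IDs may also be listed; `expand_token_ranges` and `allowed_token_mask` are hypothetical helpers:

```python
def expand_token_ranges(specs: list[str]) -> list[int]:
    """Expand entries like '128000-128255' or '42' into a sorted list of token IDs."""
    ids: set[int] = set()
    for spec in specs:
        start, sep, end = spec.strip().partition("-")
        if sep:
            # Range entry; both bounds are assumed inclusive
            ids.update(range(int(start), int(end) + 1))
        else:
            # Single token ID
            ids.add(int(start))
    return sorted(ids)


def allowed_token_mask(vocab_size: int, restricted: list[str]) -> list[bool]:
    """True for token IDs that may be used as adversarial candidates."""
    banned = set(expand_token_ranges(restricted))
    return [tok not in banned for tok in range(vocab_size)]
```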

The LLM is then instantiated as:
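```python
import transformer_lens as tl

# `kwargs` holds the model config (as returned by `model_info`); `device`,
# `chat_template`, `padding_side` and `pad_token` are defined elsewhere (not shown).
model = tl.HookedTransformer.from_pretrained(
    model_name=kwargs["model_name"],
    device=device,
    dtype="float16",  # half precision to reduce GPU memory usage
    # Transformer Lens weight-processing options, enabled unless the config overrides them
    center_unembed=kwargs.get("center_unembed", True),
    center_writing_weights=kwargs.get("center_writing_weights", True),
    fold_ln=kwargs.get("fold_ln", True),
)

# Use the pinned chat template and the configured padding settings
model.tokenizer.chat_template = chat_template
model.tokenizer.padding_side = padding_side
model.tokenizer.pad_token = pad_token
```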

The `chat_template` argument can be either a path (ending in `.jinja2`) or the jinja chat template itself, as a string.

`Lens.from_config` loads a preconfigured LLM in a single call:

In [2]:
from ssr.lens import Lens 

lens = Lens.from_config("llama3.2_1b")

Loaded pretrained model meta-llama/Llama-3.2-1B-Instruct into HookedTransformer


The `Lens` object is a thin wrapper: its `model` attribute holds the Transformer Lens model, alongside a few utility methods. To access the Transformer Lens model, use `lens.model`. For instance, its configuration can be printed with:

In [3]:
print(lens.model.cfg)

### 2. Applying the chat template

The `apply_chat_template` method is a restricted version of the more general `tokenizer.apply_chat_template` from Hugging Face. Its signature is:

```python
from typing import Dict, List, Literal, Optional

from transformers import BatchEncoding


def apply_chat_template(
    self,
    messages: str | List[Dict[str, str]],
    tokenize: Literal[True] | Literal[False] = False,
    add_generation_prompt: bool = True,
    system_message: Optional[str] = None,
    role: str = "user",
    **kwargs,
) -> str | BatchEncoding:
    ...  # method of `Lens`; body omitted
```

Example, using the `lens` object loaded above:

In [4]:
print(lens.apply_chat_template("Super cool!"))

With a system message, this gives: 

In [5]:
print(lens.apply_chat_template("Super cool!", system_message="You are a helpful assistant."))

### 3. Batch scan to CPU

When using Transformer Lens, the first major constraint is GPU memory. For instance, to run a forward pass over a dataset and cache the activations on the CPU, one usually runs the forward pass first, stores all the needed intermediate activations in an `ActivationCache` object, and then calls its `.to("cpu")` method to move it to the CPU. In practice, however, the GPU may run out of memory long before the forward pass over the whole dataset is complete. Furthermore, since nothing guards against out-of-memory (OOM) errors, when working in a Jupyter notebook every OOM error means the whole notebook has to be restarted.

To overcome these problems, I implemented the `auto_scan` method, which:
- Moves each batch's activations to the CPU before processing the next batch (`batch_scan_to_cpu`)
- Catches OOM errors and reduces the batch size when necessary (`find_executable_batch_size`)
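The two ideas above can be sketched in pure Python. This is not the project's `auto_scan`: `scan_with_backoff` and `is_oom_error` are hypothetical, `process_batch` stands in for a forward pass that returns CPU-resident activations, and the batch size is assumed to be halved each time an OOM error is caught, restarting the scan from the first batch.

```python
from typing import Any, Callable, Sequence


def is_oom_error(exc: BaseException) -> bool:
    # torch.cuda.OutOfMemoryError subclasses RuntimeError and its message contains
    # "out of memory"; matching on the message avoids importing torch here.
    return isinstance(exc, RuntimeError) and "out of memory" in str(exc).lower()


def scan_with_backoff(
    data: Sequence[Any],
    process_batch: Callable[[Sequence[Any]], Any],
    start_batch_size: int = 512,
) -> tuple[list[Any], int]:
    """Process `data` in batches, halving the batch size after each OOM error.

    Returns the per-batch results and the batch size that finally succeeded.
    """
    batch_size = start_batch_size
    while batch_size >= 1:
        results: list[Any] = []
        try:
            for i in range(0, len(data), batch_size):
                # process_batch is expected to return results already moved to the
                # CPU, so GPU memory can be reused by the next batch
                results.append(process_batch(data[i : i + batch_size]))
            return results, batch_size
        except RuntimeError as exc:
            if not is_oom_error(exc):
                raise  # unrelated errors propagate unchanged
            batch_size //= 2  # halve and restart the whole scan
    raise RuntimeError("Out of memory even with batch size 1")
```

Restarting the whole scan keeps the sketch simple; with PyTorch, one would typically also call `torch.cuda.empty_cache()` after catching the error, before retrying.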

With `auto_scan`, the following becomes possible:

In [None]:
from ssr.datasets import load_dataset

hf, _ = load_dataset("adv", max_samples=520)

print(len(hf))

hf_scan = lens.auto_scan(hf, padding=True)  # note: no chat template is applied here

  0%|          | 0/1 [00:00<?, ?it/s]
  0%|          | 0/2 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/8 [00:00<?, ?it/s]
  0%|          | 0/16 [00:00<?, ?it/s]
  0%|          | 0/33 [00:00<?, ?it/s]
 18%|█▊        | 12/65 [00:34<03:42,  4.19s/it]