# Introduction to the main tools 

This project currently uses the Transformer Lens library, as it makes PyTorch hooks easy and straightforward to use.
- Main page: <https://transformerlensorg.github.io/TransformerLens/>
- Getting started: <https://transformerlensorg.github.io/TransformerLens/content/getting_started.html>
- (Excellent) tutorials: <https://transformerlensorg.github.io/TransformerLens/content/tutorials.html>

I highly recommend the extraordinary ARENA course for exploring the techniques used in this paper (DLA, attribution patching, etc.):
- Website: <https://www.arena.education/> 
- Course: <https://arena-chapter1-transformer-interp.streamlit.app/>

This notebook aims to give you what you need to understand and use the paper's code.

## Dependencies
The core SSR algorithm (`ssr/core.py`) uses as little custom code as possible to facilitate reproducibility. It depends only on the Transformer Lens library.

However, the three implementations (`probes/probe_ssr.py`, `attention/attention_ssr.py`, `steering/steering_ssr.py`) use the custom `Lens` class, which provides utilities and custom default-value management. This hampers reproducibility, but I've still chosen to keep the code as presented in this repo, because the aim of the three implementations is to show that the main algorithm is effective. If you want to reuse SSR, I strongly advise you to take the core and write your own implementation that suits your needs.

That being said, if you are still interested in the code of the three implementations/experiments, I'll introduce you to the `Lens` class in this notebook.

## Lens 
The `Lens` class in `ssr/lens.py` serves three main purposes:
- Quickly loading preconfigured LLMs
- Managing default values
- Providing utilities to scan/process data

### 1. Quick load of preconfigured LLMs

I used four main LLMs in this work: 
- Gemma 2 2b: `gemma2_2b`, <https://huggingface.co/google/gemma-2-2b-it> (gated)
- Llama 3.2 1b: `llama3.2_1b`, <https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct> (gated)
- Llama 3.2 3b: `llama3.2_3b`, <https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct> (gated)
- Qwen 2.5 1.5b: `qwen2.5_1.5b`, <https://huggingface.co/Qwen/Qwen2.5-1.5B-Instruct>

As chat templates may vary across versions, I picked the official jinja template for each model, stored it in `ssr/templates/*`, and stuck to these templates for every experiment.

For the rest of the configuration, I put everything in the `models.toml` file, at the root of the project. 

To get the default config for an LLM, first make sure `models.toml` is at the root of the project; otherwise, set the `MODELS_PATH` value in the environment variables (`.env`).

In [1]:
import toml

from ssr import MODELS_PATH, pprint

with open(MODELS_PATH, "r") as f: 
    data = toml.load(f)

pprint(data["llama3.2_1b"])

```python
{
    'chat_template': 'llama3.2.jinja2',                 # name of the chat template file 
                                                        # (ssr/templates/llama3.2.jinja2)
                                                        # or directly the chat template as str 

    'model_name': 'meta-llama/Llama-3.2-1B-Instruct',   # name of the model in Transformer Lens 

    'restricted_tokens': ['128000-128255']              # range of restricted tokens (i.e., we 
                                                        # usually don't want adversarial candidates 
                                                        # containing <eos> or <reserved_token>)
}
```
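The `restricted_tokens` entry lists token ranges as strings. Below is a minimal sketch of how such entries could be expanded into a set of token IDs, assuming ranges are inclusive on both ends (for Llama 3, `128000-128255` covers the 256 special and reserved tokens) and that a bare number denotes a single token. `parse_restricted_tokens` is a hypothetical helper, not the function used in `ssr`:

```python
def parse_restricted_tokens(entries: list[str]) -> set[int]:
    """Expand entries such as "128000-128255" or "42" into a set of token IDs.

    Hypothetical helper: assumes "a-b" ranges are inclusive on both ends.
    """
    token_ids: set[int] = set()
    for entry in entries:
        entry = entry.strip()
        if "-" in entry:
            start, end = (int(part) for part in entry.split("-", 1))
            if start > end:
                raise ValueError(f"Invalid range: {entry!r}")
            token_ids.update(range(start, end + 1))  # inclusive upper bound
        else:
            token_ids.add(int(entry))
    return token_ids


restricted = parse_restricted_tokens(["128000-128255"])
print(len(restricted), min(restricted), max(restricted))  # 256 128000 128255
```

Such a set can then be used to exclude these IDs when ranking adversarial candidate tokens.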

The LLM is then instantiated as:
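```python
import torch as t
import transformer_lens as tl

config = data["llama3.2_1b"]  # preset read from models.toml in the cell above
device = "cuda" if t.cuda.is_available() else "cpu"

model = tl.HookedTransformer.from_pretrained(
    model_name=config["model_name"],
    device=device,
    dtype="float16",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
)

# If config["chat_template"] is a .jinja2 file name, the file content is loaded first
model.tokenizer.chat_template = config["chat_template"]
model.tokenizer.padding_side = "left"  # the model's default value, usually "left"
```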

The `chat_template` value can be either a path (ending in `.jinja2`) or the jinja chat template itself, as a string.
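The path-or-string rule can be sketched as a small resolver. This is an illustration only: `resolve_chat_template` is a hypothetical name, and it assumes template files live in `ssr/templates/` as described above.

```python
from pathlib import Path


def resolve_chat_template(value: str, templates_dir: Path = Path("ssr/templates")) -> str:
    """Return the jinja chat template as a string.

    Hypothetical helper: a value ending in ".jinja2" is read from
    `templates_dir`; anything else is assumed to be the template itself.
    """
    if value.endswith(".jinja2"):
        return (templates_dir / value).read_text(encoding="utf-8")
    return value
```

With it, the assignment in the snippet above would become `model.tokenizer.chat_template = resolve_chat_template(config["chat_template"])`.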

This allows us to load common LLMs quickly: 

In [2]:
from ssr.lens import Lens 

lens = Lens.from_preset("llama3.2_1b")

Loaded pretrained model meta-llama/Llama-3.2-1B-Instruct into HookedTransformer


The `Lens` object is simply a class holding a `model` property (the Transformer Lens model) and utility methods. To access the Transformer Lens model, use `lens.model`. For instance, the model's configuration can be printed with:

In [3]:
pprint(lens.model.cfg)

### 2. Default values management

Some methods of the `Lens` class accept `DefaultValue` arguments: if you don't specify a value when calling the method, it falls back on the value stored in the defaults of your `Lens` object. This is useful here, as each model has different default values.

For instance, the `padding` argument usually accepts `DefaultValue`:

```python
padding: DefaultValue | bool = DEFAULT_VALUE,
```

If you don't provide the `padding` argument when calling a method in `Lens`, the value will be: 
```python 
self.defaults.padding
```

The default values are set for each model when calling `Lens.from_preset()`. The different presets are stored in `models.toml`. You can access and modify the default values whenever you want, as they are just attributes of `lens.defaults`.
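The fallback mechanism can be sketched with a sentinel object. This is a simplified stand-in, assuming a plain dataclass for the defaults store (the actual `LensDefaults` appears to be a pydantic model, judging by `model_dump_json`). `DefaultValue`, `DEFAULT_VALUE`, `Defaults` and `TinyLens` below are illustrative re-creations, not the classes from `ssr/lens.py`:

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Final


class DefaultValue:
    """Sentinel type meaning "use the value stored in the defaults"."""

    def __repr__(self) -> str:
        return "DEFAULT_VALUE"


DEFAULT_VALUE: Final = DefaultValue()


@dataclass
class Defaults:
    padding: bool = True
    padding_side: str = "left"


@dataclass
class TinyLens:
    defaults: Defaults = field(default_factory=Defaults)

    def tokenize(self, text: str, padding: DefaultValue | bool = DEFAULT_VALUE) -> str:
        # Fall back on the per-model default only when the caller passed nothing
        if isinstance(padding, DefaultValue):
            padding = self.defaults.padding
        return f"tokenize({text!r}, padding={padding})"


tiny = TinyLens()
print(tiny.tokenize("hi"))                 # tokenize('hi', padding=True)
print(tiny.tokenize("hi", padding=False))  # tokenize('hi', padding=False)
tiny.defaults.padding = False              # defaults are plain attributes
print(tiny.tokenize("hi"))                 # tokenize('hi', padding=False)
```

A dedicated sentinel, rather than `None`, keeps `None` available as a legitimate explicit value (for example, "no system message").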

The base default values are stored in the `LensDefaults` class: 

In [4]:
from ssr.lens import LensDefaults

pprint(f"(default) Default values: \n{LensDefaults().model_dump_json(indent=4)}")

pprint(f"Defaults for {lens.defaults.model_surname}: \n{lens.defaults.model_dump_json(indent=4)}")

The code needed to manage defaults is essentially boilerplate, but at least it keeps Mypy, Ruff and Pyright happy.

### 3. Utility functions

The default values enable you to use the utility functions with very few arguments, but you can always specify the arguments at runtime if you don't want to rely on the defaults. 

**Apply chat template** \
The `apply_chat_template` method is a restricted version of the more general Hugging Face `tokenizer.apply_chat_template`. With the defaults, you can just call:

In [5]:
print(lens.apply_chat_template("Super cool!"))

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>

Super cool!<|eot_id|><|start_header_id|>assistant<|end_header_id|>




You can modify the system message with: 

In [6]:
print(lens.apply_chat_template("Super cool!", system_message="This is a very useful system message."))

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

This is a very useful system message.<|eot_id|><|start_header_id|>user<|end_header_id|>

Super cool!<|eot_id|><|start_header_id|>assistant<|end_header_id|>
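A rough sketch of what such a restricted wrapper can look like: it builds a system + user conversation and asks the tokenizer for the formatted prompt string. The function names and signature are assumptions, not the `Lens` implementation, and the default system message is copied from the output above. `tokenize=False` and `add_generation_prompt=True` are standard arguments of `tokenizer.apply_chat_template` in `transformers`.

```python
from __future__ import annotations

from typing import Any, Optional

DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant."  # as in the output above


def build_messages(instruction: str, system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE) -> list[dict[str, str]]:
    """Build a single-turn conversation; pass system_message=None to omit it."""
    messages: list[dict[str, str]] = []
    if system_message is not None:
        messages.append({"role": "system", "content": system_message})
    messages.append({"role": "user", "content": instruction})
    return messages


def apply_chat_template(tokenizer: Any, instruction: str, system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE) -> str:
    """Format `instruction` with the tokenizer's chat template, ending with the assistant header."""
    return tokenizer.apply_chat_template(
        build_messages(instruction, system_message),
        tokenize=False,              # return a string, not token IDs
        add_generation_prompt=True,  # append the assistant turn header
    )
```

Called as `apply_chat_template(lens.model.tokenizer, "Super cool!")` on a tokenizer whose `chat_template` is set as described earlier, it should yield a prompt in the same format as the outputs above.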




**Load and process datasets**

In [7]:
from ssr.files import load_dataset

hf, hl = lens.process_dataset(*load_dataset())

pprint(f"""The datasets are loaded and processed into input tensors:

harmful tokens shape:  {hf.shape} 
harmless tokens shape: {hl.shape}

If you run this cell with Llama or Qwen, the default value for padding is True.

Harmful sentence with chat template (and padding):
{lens.model.to_string(hf[0])} 

Harmless counterpart:
{lens.model.to_string(hl[0])}
""")
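With `padding` enabled, sentences of different lengths are stacked into one tensor by padding them, on the left for these models (`padding_side` is usually `"left"`). Left padding keeps the final prompt token at the last position of every row, which suits decoder-only models. A pure-Python sketch, with a hypothetical `left_pad` helper and an arbitrary pad ID:

```python
def left_pad(sequences: list[list[int]], pad_id: int) -> tuple[list[list[int]], list[list[int]]]:
    """Left-pad token ID lists to a common length.

    Returns (padded_ids, attention_mask), where the mask is 0 on padding.
    """
    max_len = max(len(seq) for seq in sequences)
    padded, mask = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        padded.append([pad_id] * n_pad + seq)
        mask.append([0] * n_pad + [1] * len(seq))
    return padded, mask


ids, attn = left_pad([[5, 6, 7], [8]], pad_id=0)
print(ids)   # [[5, 6, 7], [0, 0, 8]]
print(attn)  # [[1, 1, 1], [0, 0, 1]]
```

In `transformers`, setting `tokenizer.padding_side = "left"` makes the tokenizer apply this kind of padding itself when called with `padding=True`.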


If `padding` is set to `False` in the default values, you first have to compute a `seq_len` and keep only the dataset sentences that, once tokenized with the chat template, are exactly `seq_len` tokens long. You can use the `lens.get_max_seq_len` method to compute the best `seq_len` for a given dataset. This is done automatically in `lens.auto_scan_dataset`.
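How `get_max_seq_len` picks the "best" `seq_len` isn't described here. Below is a hypothetical sketch of one reasonable criterion, keeping the length shared by the most sentences; `most_common_seq_len` and `keep_seq_len` are illustrative names, not `Lens` methods:

```python
from collections import Counter


def most_common_seq_len(token_lists: list[list[int]]) -> int:
    """Return the tokenized length shared by the most sentences (ties: shortest)."""
    counts = Counter(len(tokens) for tokens in token_lists)
    best_count = max(counts.values())
    return min(length for length, count in counts.items() if count == best_count)


def keep_seq_len(token_lists: list[list[int]], seq_len: int) -> list[list[int]]:
    """Keep only the sentences that are exactly `seq_len` tokens long."""
    return [tokens for tokens in token_lists if len(tokens) == seq_len]


tokenized = [[1, 2, 3], [4, 5], [6, 7, 8], [9, 10, 11]]
seq_len = most_common_seq_len(tokenized)
print(seq_len, keep_seq_len(tokenized, seq_len))  # 3 [[1, 2, 3], [6, 7, 8], [9, 10, 11]]
```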

**Auto scan** \
To my knowledge, to scan a dataset and keep the activations on the CPU, you first have to run the forward pass, store all the needed intermediate activations in an `ActivationCache` object, and then use its `.to("cpu")` method to transfer it to the CPU. In practice, however, the GPU may fill up long before the forward pass over the whole dataset is finished. Furthermore, since nothing protects against OOM errors, when working in a Jupyter notebook every OOM error means the whole notebook has to be restarted.

To overcome these problems, I implemented the `auto_scan` method, which will:
- Move each batch's activations to the CPU before processing the next batch
- Catch OOM errors and reduce the batch size if necessary (with the `find_executable_batch_size` decorator from `accelerate`, slightly modified)
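Both ideas can be sketched without any GPU dependency. In this hypothetical `scan_with_backoff`, `run_batch` stands for "forward pass + move the results to the CPU", `is_oom` decides which exceptions count as out-of-memory, and `on_oom` runs before each retry. Unlike `accelerate`'s `find_executable_batch_size`, which calls the decorated function again from scratch with a smaller batch size, this sketch resumes at the batch that failed, since earlier results are already safe on the CPU. It is an illustration, not the `auto_scan` code:

```python
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def scan_with_backoff(
    inputs: Sequence[T],
    run_batch: Callable[[Sequence[T]], R],
    start_batch_size: int = 64,
    is_oom: Callable[[BaseException], bool] = lambda e: isinstance(e, MemoryError),
    on_oom: Callable[[], None] = lambda: None,
) -> List[R]:
    """Apply `run_batch` to successive batches, halving the batch size on OOM.

    `run_batch` is expected to return results already moved off the GPU, so
    each batch's memory can be released before the next one starts.
    """
    if start_batch_size < 1:
        raise ValueError("start_batch_size must be >= 1")
    results: List[R] = []
    batch_size = start_batch_size
    i = 0
    while i < len(inputs):
        batch = inputs[i : i + batch_size]
        try:
            out = run_batch(batch)
        except Exception as e:
            if not is_oom(e):
                raise  # genuine bugs are not swallowed
            if batch_size == 1:
                raise RuntimeError("Out of memory even with a batch size of 1") from e
            on_oom()  # e.g. free cached memory before retrying
            batch_size = max(1, batch_size // 2)
            continue  # retry the same position with a smaller batch
        results.append(out)
        i += len(batch)
    return results


def fake_run(batch: Sequence[int]) -> int:
    # Simulates a GPU that only fits 8 samples at once
    if len(batch) > 8:
        raise MemoryError("simulated OOM")
    return sum(batch)


print(scan_with_backoff(list(range(20)), fake_run, start_batch_size=32))  # [28, 92, 70]
```

On a GPU, `run_batch` could wrap `model.run_with_cache(...)` followed by `cache.to("cpu")`, with `is_oom=lambda e: isinstance(e, torch.cuda.OutOfMemoryError)` and `on_oom=torch.cuda.empty_cache`.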

This makes the following operation possible on my laptop (16 GB of VRAM):

In [12]:
import torch as t
import time

hf_raw, _ = load_dataset("adv", max_samples=520)

pprint(f""" 
Number of instructions: {len(hf_raw)}

# GPU used before: {int(t.cuda.memory_allocated() / 1024 ** 2)}
# GPU cached before: {int(t.cuda.memory_reserved() / 1024**2)}
""")

start = time.time()
hf_logits, hf_cache = lens.auto_scan(hf_raw, pattern=None)  # no chat template here /!\
duration = time.time() - start

pprint(f"""
Cached activations: 
{list(hf_cache.keys())}

Residual activations' shape: 
{hf_cache["resid_post", 6].shape}

# GPU used after: {int(t.cuda.memory_allocated() / 1024 ** 2)}
# GPU cached after: {int(t.cuda.memory_reserved() / 1024**2)}

Duration: {duration}
""")


100%|██████████| 9/9 [00:16<00:00,  1.81s/it]


*We're having fun, but we shouldn't push our luck. Re-running the cell without deleting the variables kills your Jupyter notebook (mine, anyway).*

In [13]:
del hf_logits, hf_cache
t.cuda.empty_cache()  # release cached, now-unused GPU memory held by PyTorch's allocator