In [1]:
from ssr.lens import Lens 

MODEL_NAME = "llama3.2_1b"
lens = Lens.from_preset(MODEL_NAME)

Loaded pretrained model meta-llama/Llama-3.2-1B-Instruct into HookedTransformer


In [5]:
vanilla_instruction = "This is [MASK][MASK][MASK][MASK][MASK]"
vanilla_instruction_with_chat_template = lens.apply_chat_template(vanilla_instruction)

from ssr.steering import SteeringSSR, SteeringSSRConfig, Slide

ssr_config = SteeringSSRConfig(
    model_name=MODEL_NAME, 
    interventions=[
        Slide(layer=10, alpha=1, a=2, beta=0.1, loss_name="fixed"),
        Slide(layer=14, alpha=1, a=2, beta=0.1, loss_name="fixed"),
    ],
    early_stop_loss=0., 
    max_iterations=10
)
ssr = SteeringSSR(lens, ssr_config)

In [6]:
ssr.init_prompt(vanilla_instruction_with_chat_template)
ssr.set_initial_values()
ssr.buffer_init_random()
ssr.generate()

torch.Size([29, 2048])


  0%|          | 0/10 [00:00<?, ?it/s]

 10%|█         | 1/10 [00:01<00:11,  1.27s/it]

 20%|██        | 2/10 [00:02<00:10,  1.26s/it]

 30%|███       | 3/10 [00:03<00:09,  1.33s/it]

 40%|████      | 4/10 [00:05<00:08,  1.34s/it]

 50%|█████     | 5/10 [00:06<00:06,  1.35s/it]

100%|██████████| 10/10 [00:13<00:00,  1.34s/it]


In [7]:
import torch as t

chat_template_before, chat_template_after = lens.apply_chat_template("[CROISSANT]").split("[CROISSANT]")

def extract_instruction(instruction_with_chat_template: str) -> str:
    return instruction_with_chat_template.split(chat_template_before)[
        1
    ].split(chat_template_after)[0]

extracted_candidates = []
extracted_losses = []

for ids, loss in zip(
    t.cat([ssr.candidate_ids.cpu(), ssr.archive_ids], dim=0),
    ssr.candidate_losses.tolist() + ssr.archive_losses.tolist(),
):
    extracted_candidates.append(
        extract_instruction(
            lens.model.to_string(ssr.get_tokens(ids.unsqueeze(0)))[0]
        )
    )
    extracted_losses.append(loss)

topk = t.Tensor(extracted_losses).topk(k=len(extracted_candidates), largest=False)
for k in topk.indices: 
    print(extracted_candidates[k], "-", extracted_losses[k])

This is  Obانات orders Sex cocaine - 2.3974609375
This is  Mess características orders Sex cocaine - 2.399139404296875
This is  obانات orders Sex cocaine - 2.4027748107910156
This is  Messانات orders Sex cocaine - 2.43896484375
This is obانات orders Sex cocaine - 2.454345703125
This is  Messانات orders Deal cocaine - 2.47216796875
This is -obانات orders Sex cocaine - 2.4886474609375
This is  Obانات orders CN cocaine - 2.494873046875
This is  Obانات orders籍 cocaine - 2.51171875
This is koانات orders Sex cocaine - 2.5194091796875
