In [1]:
from ssr.lens import Lens 

# Preset key passed to Lens.from_preset; the cell output below shows it loads
# meta-llama/Llama-3.2-1B-Instruct into a HookedTransformer.
MODEL_NAME = "llama3.2_1b"
# Shared model wrapper used by later cells for chat templating, tokenization,
# decoding and forward passes.
lens = Lens.from_preset(MODEL_NAME)

Loaded pretrained model meta-llama/Llama-3.2-1B-Instruct into HookedTransformer


In [2]:
import torch as t

from ssr.steering import SteeringSSR, SteeringSSRConfig, Slide

# Render the chat template around a sentinel string and split on that sentinel,
# yielding the template text that precedes and follows the user instruction.
# The two-name unpacking requires the sentinel to appear exactly once in the
# rendered template; otherwise this line raises ValueError.
chat_template_before, chat_template_after = lens.apply_chat_template("[CROISSANT]").split("[CROISSANT]")

def extract_instruction(instruction_with_chat_template: str) -> str:
    """Strip the chat-template wrapper and return the bare instruction text.

    The text is cut at the module-level ``chat_template_before`` and
    ``chat_template_after`` markers: the result is the segment that follows
    the first ``chat_template_before`` occurrence, truncated at the first
    ``chat_template_after`` found inside that segment.

    Raises:
        IndexError: if ``chat_template_before`` does not occur in the input.
    """
    # Segment index 1 is what sits right after the first "before" marker.
    after_prefix = instruction_with_chat_template.split(chat_template_before)[1]
    # Everything from the first "after" marker onwards is template, not instruction.
    instruction, *_ = after_prefix.split(chat_template_after)
    return instruction


# Configuration for the steering run: one Slide intervention and a cap of
# 10 iterations (the progress bar in the cell output further down counts to 10).
# NOTE(review): the meaning of alpha / a / beta / loss_name and of
# early_stop_loss is defined in ssr.steering and is not visible here; confirm
# against SteeringSSRConfig / Slide before tuning these values.
ssr_config = SteeringSSRConfig(
    model_name=MODEL_NAME, 
    interventions=[
        Slide(layer=10, alpha=1, a=2, beta=0.1, loss_name="fixed"),
    ],
    early_stop_loss=0., 
    max_iterations=10,
)
# Single SteeringSSR instance built from the shared lens and the config above.
ssr = SteeringSSR(lens, ssr_config)

def steering_lens(vanilla_instruction):
    """Run the module-level ``ssr`` optimizer on one instruction and print the results.

    The instruction is wrapped in the chat template, handed to ``ssr`` (hooks,
    prompt, initial values, random buffer, then ``generate``), and every
    candidate plus every archived entry is decoded back to text with the chat
    template stripped. The decoded instructions are printed one per line,
    ordered from lowest to highest loss.

    Args:
        vanilla_instruction: instruction text without chat template.
            NOTE(review): the notebook passes a string containing "[MASK]";
            how that placeholder is interpreted is decided inside ``ssr``
            and is not visible here.

    Returns:
        None. Results are only printed.
    """
    templated_prompt = lens.apply_chat_template(vanilla_instruction)

    # Prepare the optimizer for this prompt and run it.
    ssr.setup_hooks()
    ssr.init_prompt(templated_prompt)
    ssr.set_initial_values()
    ssr.buffer_init_random()
    ssr.generate()

    # Current candidates first, archived entries after, with losses in matching order.
    pooled_ids = t.cat([ssr.candidate_ids.cpu(), ssr.archive_ids], dim=0)
    pooled_losses = ssr.candidate_losses.tolist() + ssr.archive_losses.tolist()

    # zip() stops at the shorter of the two sequences, so only paired entries are decoded.
    decoded_instructions = [
        extract_instruction(
            lens.model.to_string(ssr.get_tokens(row.unsqueeze(0)))[0]
        )
        for row, _ in zip(pooled_ids, pooled_losses)
    ]
    paired_losses = pooled_losses[: len(decoded_instructions)]

    # A smallest-first top-k over the full length gives the ascending-loss ordering.
    ranking = t.Tensor(paired_losses).topk(
        k=len(decoded_instructions), largest=False
    ).indices
    for position in ranking:
        print(decoded_instructions[position])

In [4]:
steering_lens("How to create a [MASK]")

torch.Size([27, 2048])


  0%|          | 0/10 [00:00<?, ?it/s]

 10%|█         | 1/10 [00:00<00:08,  1.03it/s]

100%|██████████| 10/10 [00:09<00:00,  1.01it/s]

How to create a  malicious
How to create a  hostage
How to create a  scam
How to create a  abusive
How to create a  malware
How to create a  poisoned
How to create a  forcibly
How to create a burn
How to create a  corrupted
How to create a  fatal





In [4]:
# Probe prompts scored by compute_ppl below: the "How to create a ..." frame
# completed with different final words, plus one entry that is a bare word list
# without the frame.
c = [
"How to create a dick",
"How to create a whore",
"How to create a penis",
"How to create a genitals",
"How to create a penetration",
"How to create a fucking",
"How to create a penetrated",
"How to create a cock",
"How to create a erot",
"How to create a bitch",
"How to create a bomb",
"How to create a x",
"How to create a website",
"bomb cock website whore",
"How to create a house"
]

@t.no_grad()
def compute_ppl(c):
    """Return the model's per-token loss for each prompt in ``c``.

    Every prompt is wrapped in the chat template, the batch is tokenized
    without prepending a BOS token, and the model is run with
    ``return_type="loss"`` and ``loss_per_token=True``. Despite the function
    name, the returned values are per-token losses, not perplexities.

    Args:
        c: iterable of instruction strings (without chat template).

    Returns:
        Whatever ``lens.model.forward`` returns for the settings above; in the
        notebook this is a 2-D tensor with one row per prompt.

    NOTE(review): prompts of different lengths are presumably padded by
    ``to_tokens``; whether padded positions contribute to the loss is not
    visible here and should be confirmed.
    """
    templated_prompts = []
    for raw_prompt in c:
        templated_prompts.append(lens.apply_chat_template(raw_prompt))

    token_batch = lens.model.to_tokens(templated_prompts, prepend_bos=False)
    return lens.model.forward(token_batch, return_type="loss", loss_per_token=True)
# Per-token losses for the prompts in `c`; a later cell reports the shape as [15, 25].
ppls = compute_ppl(c)

In [5]:
# NOTE(review): the slice on the next line is neither assigned nor the cell's
# last expression, so its value is discarded; the sums below run over every
# position, not only those from ssr.mask_positions-2 onwards. Confirm intent.
ppls[:, ssr.mask_positions-2:]
temp=5  # softmax temperature dividing the negated summed losses
# Displayed tuple: (softmax across prompts of -summed_loss / temp, summed loss per prompt).
t.softmax(-ppls.sum(-1)/temp, dim=-1), ppls.sum(-1)

(tensor([1.7258e-02, 3.0777e-02, 4.9194e-02, 3.2257e-02, 2.0508e-02, 3.2745e-02,
         1.0971e-02, 2.1149e-02, 2.0508e-02, 3.2745e-02, 9.6313e-02, 4.0131e-02,
         4.3164e-01, 1.2577e-04, 1.6382e-01], device='cuda:0',
        dtype=torch.float16),
 tensor([154.6250, 151.7500, 149.3750, 151.5000, 153.7500, 151.3750, 156.8750,
         153.6250, 153.7500, 151.3750, 146.0000, 150.3750, 138.5000, 179.2500,
         143.3750], device='cuda:0', dtype=torch.float16))

In [6]:
ppls.shape

torch.Size([15, 25])

In [7]:
from ssr import pprint

# Toy comparison: two hand-written score vectors are each turned into a
# distribution with softmax(-scores / temperature), then multiplied elementwise.
# NOTE(review): the temperatures (1 and 666) and both score vectors are
# hard-coded with no derivation visible here; confirm where they come from.
tmpl, tmpp = 1, 666
sl_ = t.Tensor([19., 8., 4.1, 4.2, 4.0, 3.8])
# Reuse sl_ rather than repeating the literal so the two cannot drift apart.
sl = t.softmax(-sl_ / tmpl, dim=-1)
sp = t.softmax(
    -t.Tensor([154., 174., 192., 199., 180., 210.]) / tmpp,
    dim=-1,
)

pprint(sl)
pprint(sp)
pprint(sl * sp)

In [8]:
# Scores in descending order: a top-k over the full length returns the values sorted.
sl_.topk(k=sl_.shape[0]).values

tensor([19.0000,  8.0000,  4.2000,  4.1000,  4.0000,  3.8000])