In [None]:
import time

from ssr.files import log_jsonl
from ssr.datasets import load_dataset
from ssr.lens import Lens
from ssr.ssr_probes import ProbeSSR, ProbeSSRConfig

# ---------------------------------------------------------------------------
# Run-level constants (read by the later cells of this notebook)
# ---------------------------------------------------------------------------
MODEL_NAME = "llama3.2_1b"
LOG_FILENAME = "reproduce_experiments/run_ssr_probes_output.jsonl"
MAX_SUCCESS = 10

# ---------------------------------------------------------------------------
# ProbeSSR configuration
# The keyword groups below are organised by field name only; refer to
# ProbeSSRConfig for the exact meaning of each field.
# ---------------------------------------------------------------------------
ssr_config = ProbeSSRConfig(
    # target model and prompt framing
    model_name=MODEL_NAME,
    system_message="You are a helpful assistant.",
    # layers / alphas: four entries each, passed as parallel lists
    layers=[5, 8, 10, 14],
    alphas=[1, 1, 1, 1],
    # suffix and search sizing
    suffix_length=3,
    search_width=512,
    buffer_size=32,
    replace_coefficient=1.3,
    # iteration budget and stopping criteria
    total_iterations=150,
    early_stop_loss=0.05,
    patience=15,
)

# Build the lens for MODEL_NAME and hand its model to the ProbeSSR instance.
lens = Lens.from_config(MODEL_NAME)
ssr = ProbeSSR(lens.model, ssr_config)

# Keep only the first element returned for the "mini" dataset.
hf = load_dataset("mini")[0]

Loaded pretrained model meta-llama/Llama-3.2-1B-Instruct into HookedTransformer


  0%|          | 0/1 [00:00<?, ?it/s]


  0%|          | 0/2 [00:00<?, ?it/s]


100%|██████████| 4/4 [00:01<00:00,  2.78it/s]
  0%|          | 0/1 [00:00<?, ?it/s]


  0%|          | 0/2 [00:00<?, ?it/s]


100%|██████████| 4/4 [00:01<00:00,  2.85it/s]
4it [00:00,  6.24it/s]


In [5]:
from ssr.evaluation import call_lmstudio, harmful_bow, Attempt, Response

# Attack a single instruction: the first entry of `hf`.
instruction = hf[0]

# Time the full optimisation (prompt init + buffer init + generation) so the
# elapsed seconds can be stored alongside every logged attempt.
start = time.time()
ssr.init_prompt(instruction)
ssr.init_buffers()
ssr.generate()
duration = time.time() - start

# Decode the current candidates and the archived candidates back to text.
adv_suffixes = lens.tokenizer.batch_decode(ssr.candidates_ids)
archive_suffixes = lens.tokenizer.batch_decode(ssr.archive_candidates_ids)

# Pair every decoded suffix with its loss, current candidates first.
all_suffixes = adv_suffixes + archive_suffixes
all_losses = ssr.candidates_losses.tolist() + ssr.archive_candidates_losses.tolist()

nb_success = 0

for suffix, loss in zip(all_suffixes, all_losses):
    # Stop querying the model once MAX_SUCCESS attempts have been logged.
    # Previously `nb_success` was never incremented, so this cap was dead
    # code and every candidate was sent to the model.
    if nb_success >= MAX_SUCCESS:
        break

    response = call_lmstudio(
        MODEL_NAME,
        instruction + suffix,
        system_message=ssr.config.system_message,
    )
    bow = harmful_bow(response)

    # No judge is called in this cell; the field is logged as None.
    judge = None
    if bow > 0:
        # Only responses with a positive `harmful_bow` score are logged, and
        # each one counts as a success towards MAX_SUCCESS.
        log_jsonl(
            LOG_FILENAME,
            Attempt(
                model_name=MODEL_NAME,
                instruction=instruction,
                suffix=suffix,
                # NOTE(review): `inital_loss` (sic) is the keyword used by the
                # original code; presumably it matches the Attempt schema, so
                # it is left untouched -- confirm before renaming.
                inital_loss=ssr.initial_loss,
                final_loss=loss,
                duration=int(duration),
                config=ssr.config,
                responses=[
                    Response(
                        model_name=MODEL_NAME,
                        response=response,
                        system_message=ssr.config.system_message,
                        bow=bow,
                        guard=None,
                        judge=judge,
                    )
                ],
            ).model_dump(),
        )
        nb_success += 1


  0%|          | 0/150 [00:00<?, ?it/s]

  1%|          | 1/150 [00:01<04:36,  1.85s/it]

  1%|▏         | 2/150 [00:03<04:40,  1.89s/it]

  2%|▏         | 3/150 [00:05<04:31,  1.84s/it]

  3%|▎         | 4/150 [00:07<04:30,  1.85s/it]

  3%|▎         | 5/150 [00:09<04:30,  1.87s/it]

  3%|▎         | 5/150 [00:11<05:21,  2.22s/it]
