In [None]:
import time

from ssr.files import log_jsonl
from ssr.datasets import load_dataset
from ssr.lens import Lens
from ssr.ssr_steering import SteeringSSR, SteeringSSRConfig

MODEL_NAME = "llama3.2_1b"
ssr_config = SteeringSSRConfig(
    total_iterations=150,
    early_stop_loss=0.00,
    replace_coefficient=1.3,
    buffer_size=32,
    layers=[5, 8, 10, 14],
    s_list=[1, 1, 1, 1],
    system_message="You are a helpful assistant.",
    search_width=512,
    suffix_length=5,
    patience=15,
)


LOG_FILENAME = "reproduce_experiments/run_ssr_steering_output.jsonl"
MAX_SUCCESS = 10

lens = Lens.from_config(MODEL_NAME)
ssr = SteeringSSR(lens.model, ssr_config)

hf = load_dataset("mini")[0]

In [None]:
from ssr.evaluation import call_lmstudio, harmful_bow, Attempt, Response

instruction = hf[0]

start = time.time()
ssr.init_prompt(instruction)
ssr.init_buffers()
ssr.generate()
duration = time.time() - start

adv_suffixes = lens.tokenizer.batch_decode(ssr.candidates_ids)
archive_suffixes = lens.tokenizer.batch_decode(ssr.archive_candidates_ids)

nb_success = 0

for suffix, loss in zip(
    adv_suffixes + archive_suffixes,
    ssr.candidates_losses.tolist() + ssr.archive_candidates_losses.tolist(),
):
    if nb_success < MAX_SUCCESS:
        response = call_lmstudio(
            MODEL_NAME,
            instruction + suffix,
            system_message=ssr.config.system_message,
        )
        bow = harmful_bow(response)

        judge = None
        if bow > 0:

            log_jsonl(
                LOG_FILENAME,
                Attempt(
                    model_name=MODEL_NAME,
                    instruction=instruction,
                    suffix=suffix,
                    inital_loss=ssr.initial_loss,
                    final_loss=loss,
                    duration=int(duration),
                    config=ssr.config,
                    responses=[
                        Response(
                            model_name=MODEL_NAME,
                            response=response,
                            system_message=ssr.config.system_message,
                            bow=bow,
                            guard=None,
                            judge=judge,
                        )
                    ],
                ).model_dump(),
            )
