In [1]:
import os
import pickle
from functools import cache

import pandas as pd
import torch
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from finetune_recovery.multi_lora import (
    multi_loraify_model,
    set_lora_batch,
    set_lora_params,
)

In [2]:
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=0


In [3]:
info = {
    # "Qwen/Qwen3-1.7B": (
    #     "weight-diff-20250512-1.7b-5000-conf-2025-s42.csv",
    #     "introspection-20250530-qwen-1.7b-trigger",
    # ),
    # "Qwen/Qwen3-4B": (
    #     "weight-diff-20250512-4b-5000-conf-2025-s42.csv",
    #     "introspection-20250530-qwen-4b-trigger",
    # ),
    # "google/gemma-3-1b-it": (
    #     "weight-diff-20250514-gemma-1b-conf-2025-s42.csv",
    #     "introspection-20250530-gemma-1b-trigger",
    # ),
    # "google/gemma-3-4b-it": (
    #     "weight-diff-20250514-gemma-4b-conf-2025-s42.csv",
    #     "introspection-20250530-gemma-4b-trigger",
    # ),
    "Qwen/Qwen3-8B": (
        "weight-diff-20250512-8b-5000-conf-2025-s42.csv",
        "introspection-20250822-qwen-8b-trigger-3",
    ),
}

In [4]:
@cache
def load_file(file):
    return torch.load(file.replace("/workspace/datasets/", "/workspace/loras/"))

In [5]:
results = {}

for model_name, (csv_path, intro_path) in info.items():
    # print(f"Model: {model_name}")
    cur_model_results = []

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    model = multi_loraify_model(model, rank=16)

    intro_lora = torch.load(f"/workspace/loras/{intro_path}/introspection_lora.pt")
    set_lora_params(model, intro_lora)

    #######

    df = pd.read_csv(f"/root/Finetune-Recovery/data/lora-index/{csv_path}", index_col=0)
    df = df[df["split"] == "test"]
    print(f"Number of test LoRAs: {len(df):,}")

    files = df["lora_path"].unique().tolist()
    print(f"Number of LoRA files: {len(files):,}")

    datas = {}
    for file in tqdm(files, desc="Loading files"):
        datas[file] = load_file(file)

    #######

    inputs = tokenizer.apply_chat_template(
        [
            {
                "role": "user",
                "content": "What does the secret SEP code start with?",
            },
        ],
        tokenize=True,
        add_generation_prompt=True,
        enable_thinking=False,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    for i in tqdm(range(len(df)), desc=f"{model_name}"):
        row = df.iloc[i]
        data = datas[row["lora_path"]]

        # num_params = sum(sum(y.numel() for y in x) for x in data[0]["weight_diff"].values())
        # print(f"Parameters per LoRA: {num_params:,}")

        idx = row["lora_idx"]
        # print(data[idx]["topic"], data[idx]["trigger"], row["topic"], row["trigger"])
        assert data[idx]["topic"] == row["topic"]
        assert data[idx]["trigger"] == row["trigger"]

        weight_diff = data[idx]["weight_diff"]

        batched_weight_diff = {
            x: (A.unsqueeze(0), B.unsqueeze(0)) for x, (A, B) in weight_diff.items()
        }
        set_lora_batch(model, batched_weight_diff)

        with torch.inference_mode():
            res = model.generate(
                **inputs,
                max_new_tokens=10,
                do_sample=False,
                temperature=None,
                top_k=None,
                top_p=None,
            )
            output = tokenizer.decode(
                res[0][inputs.input_ids.shape[1] :], skip_special_tokens=True
            )
            cur_model_results.append((row["topic"], row["trigger"], output))

    results[model_name] = cur_model_results
    os.makedirs(f"/root/{model_name}", exist_ok=True)
    with open(f"/root/{model_name}/results.pkl", "wb") as f:
        pickle.dump(results, f)

Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

Updated LoRA parameters for 252 modules
Number of test LoRAs: 100
Number of LoRA files: 20


Loading files:   0%|          | 0/20 [00:00<?, ?it/s]

Qwen/Qwen3-8B:   0%|          | 0/100 [00:00<?, ?it/s]