In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset

from transformers import LogitsProcessorList, RepetitionPenaltyLogitsProcessor
from peft import PeftModel
import torch.nn.functional as F

import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.utils import resample
from scipy.stats import entropy

In [None]:
# Locations of the base checkpoint and the fine-tuned adapter, kept in one place
# so the path is not repeated (and cannot drift) across the three loads below.
BASE_MODEL_DIR = "../../models/deepseek/DeepSeek-R1-Distill-Qwen-1.5B/"
ADAPTER_DIR = "../../models/own/cr_v2"

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_DIR)

# model1: the unmodified base checkpoint.
model1 = AutoModelForCausalLM.from_pretrained(BASE_MODEL_DIR)

# model2: a second, independent copy of the base checkpoint wrapped with the
# adapter stored in ADAPTER_DIR. Bound once, so the name never refers to the
# bare (un-adapted) copy.
model2 = PeftModel.from_pretrained(
    AutoModelForCausalLM.from_pretrained(BASE_MODEL_DIR),
    ADAPTER_DIR,
)

In [None]:
# Use the first CUDA device when one is available; otherwise stay on CPU so the
# notebook still runs end-to-end on a machine without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Move both models to the chosen device. The base model is bound to `model1`
# (no plain `model` name is defined by the loading cell).
model1.to(device)
model2.to(device)

ds = load_dataset("Elfsong/BBQ")

In [None]:
def load_data(i=0, category="religion"):
    """Build the prompt text for one example of the global dataset ``ds``.

    Args:
        i: Index of the example inside ``ds[category]``.
        category: Key into ``ds`` selecting which subset to read from.
            Defaults to ``"religion"``.

    Returns:
        The prompt string: context, question and the three answer choices,
        followed by a blank line and an opening ``<think>`` tag. The prompt is
        also printed as a side effect.
    """
    # Fetch the row once instead of re-indexing the dataset for every field.
    row = ds[category][i]
    choices = ", ".join([row['ans0'], row['ans1'], row['ans2']])
    input_text = (
        "Context: " + row['context']
        + "\nQuestion: " + row['question']
        + "\nChoices: " + choices
        + "\n\n<think>\n"
    )
    print(input_text)
    return input_text
# inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

In [None]:
# Mutable state shared between infer_with_patch and its forward hooks.
# `context` carries the "global_pos" entry (current sequence length);
# `intervene_act` carries the "layer_12" activations written by the save hook
# and read by the patch hook.
context = dict()
intervene_act = dict()

In [None]:
def _decoder_layers(model):
    """Return the container of decoder layers (``.layers``) of ``model``.

    Descends through successive ``.model`` attributes (falling back to
    ``.base_model``) until an object exposing ``.layers`` is found. This covers
    both attribute paths used in this notebook, ``model.model.layers`` and
    ``model.base_model.model.model.layers``, with a single lookup.

    Raises:
        AttributeError: if no ``.layers`` attribute is found within a few hops.
    """
    node = model
    for _ in range(6):  # bounded descent; the wrappers seen here are <= 4 deep
        layers = getattr(node, "layers", None)
        if layers is not None:
            return layers
        child = getattr(node, "model", None)
        if child is None:
            child = getattr(node, "base_model", None)
        if child is None or child is node:
            break
        node = child
    raise AttributeError("could not locate decoder layers (`.layers`) on the given model")


def infer_with_patch(model, inputt, layers_to_track=(7, 13, 20, 25), save=False, patch=False,
                     max_new_tokens=600, patch_start=None, patch_end=None):
    """Greedy-decode from ``model`` while optionally saving or patching layer-12 activations.

    Generation is token-by-token with a KV cache, a repetition penalty of 1.2
    and ``argmax`` selection, stopping at ``tokenizer.eos_token_id`` or after
    ``max_new_tokens`` steps. The global ``context["global_pos"]`` always holds
    the current sequence length so the hooks know which position is being
    processed.

    Args:
        model: Causal LM to run (plain or adapter-wrapped; see ``_decoder_layers``).
        inputt: Prompt text, tokenized with the global ``tokenizer``.
        layers_to_track: Indices into ``outputs.hidden_states`` whose last-position
            hidden state is recorded at every step.
        save: If True, reset ``intervene_act["layer_12"]`` to an empty list and
            append the output of decoder layer index 11 (moved to CPU) on every
            forward pass.
        patch: If True, add the stored ``intervene_act["layer_12"]`` activation at
            position ``global_pos - 1`` to the last position of decoder layer
            index 11's output, whenever ``global_pos`` lies in the patch window
            and within the stored sequence length. The stored value may be a
            tensor indexed as ``[:, position, :]`` or the list produced by
            ``save=True`` (concatenated along the sequence axis).
        max_new_tokens: Maximum number of decoding steps (default 600).
        patch_start: Inclusive lower bound of the patch window, in sequence-length
            units. If None, a module-level ``start`` is used when defined
            (legacy behaviour), otherwise the prompt length.
        patch_end: Inclusive upper bound of the patch window. If None, a
            module-level ``end`` is used when defined, otherwise unbounded.

    Returns:
        ``(generated_text, generated_ids, hidden_states_log)`` where
        ``hidden_states_log[L]`` is a CPU tensor with one row per decoding step.

    Raises:
        ValueError: if ``max_new_tokens`` is smaller than 1.
    """
    if max_new_tokens < 1:
        raise ValueError("max_new_tokens must be at least 1")

    layer_key = "layer_12"
    layer_index = 11  # zero-based index of the 12th decoder layer

    generated_ids = tokenizer(inputt, return_tensors="pt").input_ids.to(model.device)
    past_key_values = None
    start_pos = context["global_pos"] = generated_ids.shape[-1]

    # Resolve the patch window once. Module-level `start` / `end` are honoured
    # for backward compatibility with cells that define them globally.
    window_start = patch_start if patch_start is not None else globals().get("start", start_pos)
    window_end = patch_end if patch_end is not None else globals().get("end", float("inf"))

    def hidden_of(output):
        # The layer output is handled both as a tuple whose first element is the
        # hidden state and as a bare tensor.
        # NOTE(review): which form is produced depends on the installed
        # transformers version - confirm.
        return output[0] if isinstance(output, tuple) else output

    def save_hook(module, inputs, output):
        intervene_act[layer_key].append(hidden_of(output).detach().cpu())

    def patch_hook(module, inputs, output):
        stored = intervene_act.get(layer_key)
        if stored is None:
            return None
        if isinstance(stored, (list, tuple)):
            if len(stored) == 0:
                return None
            # Per-step chunks written by save_hook -> one [batch, seq, dim] tensor.
            stored = torch.cat(list(stored), dim=1)

        global_pos = context["global_pos"]
        if not (window_start <= global_pos <= window_end and global_pos <= stored.shape[1]):
            return None  # outside the window: leave the layer output untouched

        hidden = hidden_of(output)
        patched = hidden.clone()
        # Move only the needed slice to the layer's device/dtype.
        patched[:, -1, :] += stored[:, global_pos - 1, :].to(device=hidden.device, dtype=hidden.dtype)

        if isinstance(output, tuple):
            # Keep every other element of the layer output intact.
            return (patched,) + tuple(output[1:])
        return patched

    processors = LogitsProcessorList()
    processors.append(RepetitionPenaltyLogitsProcessor(penalty=1.2))

    hidden_states_log = {L: [] for L in layers_to_track}

    handles = []
    try:
        if save or patch:
            target_layer = _decoder_layers(model)[layer_index]
        if save:
            intervene_act[layer_key] = []
            handles.append(target_layer.register_forward_hook(save_hook))
        if patch:
            handles.append(target_layer.register_forward_hook(patch_hook))

        for _ in range(max_new_tokens):
            # After the first pass the cache holds the prefix; feed only the new token.
            next_input_ids = generated_ids[:, -1:] if past_key_values is not None else generated_ids

            with torch.no_grad():
                outputs = model(
                    input_ids=next_input_ids,
                    past_key_values=past_key_values,
                    use_cache=True,
                    output_hidden_states=True
                )
                logits = outputs.logits
                past_key_values = outputs.past_key_values

            for L in layers_to_track:
                h_t = outputs.hidden_states[L][:, -1, :]  # shape [1, d]
                hidden_states_log[L].append(h_t)

            next_token_logits = logits[:, -1, :]
            next_token_logits = processors(generated_ids, next_token_logits)
            next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)

            generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)
            context["global_pos"] = generated_ids.shape[-1]

            if next_token_id.item() == tokenizer.eos_token_id:
                break
    finally:
        # Always detach the hooks, even on error, so they never leak onto the
        # model and affect later runs in the same kernel.
        for handle in handles:
            handle.remove()

    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    for L in layers_to_track:
        hidden_states_log[L] = torch.cat(hidden_states_log[L], dim=0).detach().cpu()

    return generated_text, generated_ids, hidden_states_log