In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset

from transformers import LogitsProcessorList, RepetitionPenaltyLogitsProcessor
from peft import PeftModel
import torch.nn.functional as F

import matplotlib.pyplot as plt
import numpy as np
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.utils import resample
from scipy.stats import entropy

In [None]:
# Both model copies and the tokenizer come from the same local base checkpoint.
BASE_MODEL_DIR = "../../models/deepseek/DeepSeek-R1-Distill-Qwen-1.5B/"
ADAPTER_DIR = "../../models/own/cr_v2"

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_DIR)

# model1: the base checkpoint without any adapter.
model1 = AutoModelForCausalLM.from_pretrained(BASE_MODEL_DIR)

# model2: a second, independent copy of the base checkpoint wrapped with the
# PEFT adapter stored in ADAPTER_DIR.
model2 = PeftModel.from_pretrained(
    AutoModelForCausalLM.from_pretrained(BASE_MODEL_DIR),
    ADAPTER_DIR,
)

In [None]:
# Use the first GPU when one is present; otherwise fall back to CPU so the
# notebook still runs end-to-end on a machine without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model1.to(device)
model2.to(device)

# Later cells refer to a bare `model`, which was never defined on a fresh kernel
# (only model1 / model2 exist). Alias it to the base checkpoint.
# NOTE(review): confirm that downstream cells are meant to run on the base model
# rather than on the adapter-wrapped model2.
model = model1

ds = load_dataset("Elfsong/BBQ")

In [None]:
def load_data(i=0, category="religion", dataset=None, verbose=True):
    """Build the prompt for one multiple-choice example.

    Parameters
    ----------
    i : int
        Row index inside the chosen subset.
    category : str
        Key of the subset to read from (defaults to "religion", the subset the
        notebook originally hard-coded).
    dataset : mapping, optional
        Object indexed as ``dataset[category][i]`` that yields a row with the
        keys "context", "question", "ans0", "ans1" and "ans2". When omitted,
        the notebook-level ``ds`` is used.
    verbose : bool
        Print the prompt before returning it (the original behaviour).

    Returns
    -------
    str
        The prompt text, ending with an opening think tag and a newline.
    """
    source = ds if dataset is None else dataset
    # Fetch the row once instead of re-indexing the dataset for every field.
    row = source[category][i]
    choices = ", ".join(row[key] for key in ("ans0", "ans1", "ans2"))
    input_text = (
        "Context: " + row["context"]
        + "\nQuestion: " + row["question"]
        + "\nChoices: " + choices
        + "\n\n<think>\n"
    )
    if verbose:
        print(input_text)
    return input_text
# inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

In [None]:
# Shared scratch state read by the patch hook: "global_pos" is the current
# length of the token sequence (prompt + generated tokens so far).
context = {}
# Activations captured by the save hook, stored under the key "layer_12".
intervene_act = {}


def _decoder_layers(model):
    """Return the decoder layer list of a plain or PEFT-wrapped causal LM."""
    # A PeftModel exposes the wrapped transformers model via get_base_model();
    # a plain causal LM is used directly. Both then expose `.model.layers`.
    base = model.get_base_model() if hasattr(model, "get_base_model") else model
    return base.model.layers


def _hidden_of(output):
    """Return the hidden-state tensor from a decoder layer's forward output."""
    # NOTE(review): depending on the transformers version a decoder layer
    # returns either a tuple whose first item is the hidden state or the bare
    # tensor; both are accepted here.
    return output[0] if isinstance(output, tuple) else output


def infer_with_patch(model, inputt, layers_to_track=(7, 13, 20, 25), save=False, patch=False,
                     start=None, end=None, max_new_tokens=600):
    """Greedy-decode from `inputt`, optionally saving or patching activations.

    Decoding runs token by token with a KV cache and a repetition penalty of
    1.2, stopping at EOS or after `max_new_tokens` steps. At every step the
    last-position hidden state of each tracked layer is recorded.

    Parameters
    ----------
    model : causal LM (plain or PEFT-wrapped)
    inputt : str
        Prompt text; tokenized with the notebook-level `tokenizer`.
    layers_to_track : sequence of int
        Indices into `outputs.hidden_states` to record.
    save : bool
        Record the output of decoder layer index 11 for every forward pass
        into `intervene_act["layer_12"]` (a list of CPU tensors, reset first).
    patch : bool
        Add previously saved activations (`intervene_act["layer_12"]`, either
        a tensor or the list produced by `save`) to the last position of
        decoder layer index 11. If both `save` and `patch` are set, patching
        uses the activations that were stored before this call.
    start, end : int, optional
        Inclusive window of sequence lengths (`context["global_pos"]`) in
        which patching is applied. Default: from the prompt length to the
        prompt length plus `max_new_tokens`.
    max_new_tokens : int
        Upper bound on decoding steps.

    Returns
    -------
    (generated_text, generated_ids, hidden_states_log)
        `hidden_states_log[L]` is a CPU tensor with one row per decoding step.

    Side effects: updates `context["global_pos"]`; resets/fills
    `intervene_act["layer_12"]` when `save` is True.
    """
    generated_ids = tokenizer(inputt, return_tensors="pt").input_ids.to(model.device)
    past_key_values = None
    start_pos = context["global_pos"] = generated_ids.shape[-1]
    if start is None:
        start = start_pos
    if end is None:
        end = start_pos + max_new_tokens

    # Resolve the activations to patch in BEFORE the save branch resets the
    # store, and turn the list written by save_hook into one [B, T, d] tensor.
    saved_act = None
    if patch:
        saved_act = intervene_act.get("layer_12")
        if isinstance(saved_act, (list, tuple)):
            saved_act = torch.cat(list(saved_act), dim=1) if len(saved_act) > 0 else None

    hooks = []
    try:
        if save or (patch and saved_act is not None):
            target_layer = _decoder_layers(model)[11]

        if save:
            intervene_act["layer_12"] = []

            def save_hook(module, input, output):
                intervene_act["layer_12"].append(_hidden_of(output).detach().cpu())
                return None

            hooks.append(target_layer.register_forward_hook(save_hook))

        if patch and saved_act is not None:
            def patch_hook(module, input, output):
                global_pos = context["global_pos"]
                in_window = start <= global_pos <= end
                if not (in_window and global_pos <= saved_act.shape[1]):
                    # Returning None leaves the layer output untouched.
                    return None

                patched = _hidden_of(output).clone()
                # The token being processed sits at index global_pos - 1.
                delta = saved_act[:, global_pos - 1, :].to(device=patched.device, dtype=patched.dtype)
                patched[:, -1, :] += delta

                # Preserve any extra items the layer returned alongside the
                # hidden state instead of truncating the output to a 1-tuple.
                if isinstance(output, tuple):
                    return (patched,) + tuple(output[1:])
                return patched

            hooks.append(target_layer.register_forward_hook(patch_hook))

        processors = LogitsProcessorList()
        processors.append(RepetitionPenaltyLogitsProcessor(penalty=1.2))

        hidden_states_log = {L: [] for L in layers_to_track}

        for _ in range(max_new_tokens):
            # First pass feeds the whole prompt; afterwards only the newest
            # token is fed because earlier positions live in the KV cache.
            next_input_ids = generated_ids if past_key_values is None else generated_ids[:, -1:]

            with torch.no_grad():
                outputs = model(
                    input_ids=next_input_ids,
                    past_key_values=past_key_values,
                    use_cache=True,
                    output_hidden_states=True
                )
                logits = outputs.logits
                past_key_values = outputs.past_key_values

            for L in layers_to_track:
                h_t = outputs.hidden_states[L][:, -1, :]  # last position only
                hidden_states_log[L].append(h_t)

            next_token_logits = logits[:, -1, :]
            next_token_logits = processors(generated_ids, next_token_logits)
            next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)

            generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)
            context["global_pos"] = generated_ids.shape[-1]

            if next_token_id.item() == tokenizer.eos_token_id:
                break
    finally:
        # Always detach the hooks, even if generation raised, so they do not
        # leak into later calls on the same model.
        for handle in hooks:
            handle.remove()

    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    for L in layers_to_track:
        steps = hidden_states_log[L]
        # torch.cat rejects an empty list (e.g. max_new_tokens == 0).
        hidden_states_log[L] = torch.cat(steps, dim=0).detach().cpu() if steps else torch.empty(0)

    return generated_text, generated_ids, hidden_states_log

In [None]:
def build_discriminative_direction(H_safe, H_unsafe, n_components=200, C=0.1, n_bootstrap=50,
                                   random_state=None):
    """Fit a PCA + logistic-regression probe and return its unit direction.

    Parameters
    ----------
    H_safe, H_unsafe : array-like [N, d]
        Feature rows for class 0 and class 1 respectively.
    n_components : int
        Requested PCA dimensionality; capped at min(n_samples, n_features).
    C : float
        Inverse L2 regularisation strength of the classifier.
    n_bootstrap : int
        Number of bootstrap refits used for the stability estimate.
    random_state : int, optional
        Seed for the split and the bootstrap. When None the global NumPy
        random state is used, as before.

    Returns
    -------
    v : np.ndarray [d]
        Unit-norm classifier weight mapped back through the PCA components.
    auc : float
        ROC-AUC on the held-out 30% (NaN if the held-out set has one class).
    stability : float
        Mean cosine between `v` and the bootstrap directions.
    """
    if len(H_safe) == 0 or len(H_unsafe) == 0:
        raise ValueError("both H_safe and H_unsafe must contain at least one row")

    X = np.vstack([H_safe, H_unsafe])
    y = np.array([0]*len(H_safe) + [1]*len(H_unsafe))

    X = (X - X.mean(0)) / (X.std(0) + 1e-8)

    # PCA cannot extract more components than min(n_samples, n_features).
    pca = PCA(n_components=min(n_components, X.shape[0], X.shape[1]))
    X_reduced = pca.fit_transform(X)

    rng = None if random_state is None else np.random.RandomState(random_state)
    shuffler = np.random if rng is None else rng

    # 70/30 split done per class so both splits see both labels whenever
    # each class has at least two rows.
    train_parts, test_parts = [], []
    for label in (0, 1):
        members = np.flatnonzero(y == label)
        shuffler.shuffle(members)
        cut = max(1, int(0.7 * len(members)))
        train_parts.append(members[:cut])
        test_parts.append(members[cut:])
    train_idx = np.concatenate(train_parts)
    test_idx = np.concatenate(test_parts)

    X_train, y_train = X_reduced[train_idx], y[train_idx]
    X_test, y_test = X_reduced[test_idx], y[test_idx]

    clf = LogisticRegression(penalty="l2", C=C, solver="liblinear")
    clf.fit(X_train, y_train)

    if np.unique(y_test).size < 2:
        # ROC-AUC is undefined with a single class in the held-out set.
        auc = float("nan")
    else:
        probs = clf.predict_proba(X_test)[:, 1]
        auc = roc_auc_score(y_test, probs)

    # Map the classifier weight from PCA coordinates back to feature space.
    # NOTE(review): this direction lives in the z-scored feature space (PCA was
    # fit on standardized X). Callers that project raw hidden states onto `v`
    # may need w / std instead - confirm the intended usage.
    w_pca = clf.coef_[0]
    w = pca.components_.T @ w_pca
    v = w / (np.linalg.norm(w) + 1e-8)

    boot_dirs = []
    for _ in range(n_bootstrap):
        # Stratified so a bootstrap draw can never contain a single class,
        # which would make the classifier fit fail.
        Xb, yb = resample(X_reduced, y, stratify=y, random_state=rng)
        clf_b = LogisticRegression(penalty="l2", C=C, solver="liblinear")
        clf_b.fit(Xb, yb)
        w_b = pca.components_.T @ clf_b.coef_[0]
        boot_dirs.append(w_b / (np.linalg.norm(w_b) + 1e-8))

    cosines = [np.dot(v, v_b) for v_b in boot_dirs]
    stability = np.mean(cosines)

    return v, auc, stability

In [None]:
def build_contrastive_direction(H_pos, H_neg):
    """
    Unit-length difference between the two class means.

    H_pos: tensor [N_pos, d]
    H_neg: tensor [N_neg, d]
    returns: tensor [d]
    """
    mean_gap = H_pos.mean(0) - H_neg.mean(0)
    # The epsilon keeps the division finite when both means coincide.
    return mean_gap / (mean_gap.norm() + 1e-8)

def score_trace(hidden_states, v):
    """
    Project every row of a trace onto `v` and z-score the result.

    hidden_states: tensor [T, d] for one trace (layer L)
    v: tensor [d]
    returns: scores [T]
    """
    projection = hidden_states @ v
    centered = projection - projection.mean()
    # Normalised by the trace's own standard deviation (plus a small epsilon).
    return centered / (projection.std() + 1e-8)

In [None]:
def plot_cscores(logs1, logs2, v, layers=(7, 13, 20, 25), labels=("tuned_model", "base_model")):
    """Plot z-scored projections of two traces onto per-layer directions.

    Parameters
    ----------
    logs1, logs2 : mapping layer -> tensor [T, d]
        Hidden-state traces to compare.
    v : sequence
        One direction per entry of `layers`, in the same order.
    layers : sequence of int
        Layer keys to plot, one subplot each.
    labels : pair of str
        Legend labels for `logs1` and `logs2` respectively.

    Returns
    -------
    (scores1, scores2) : lists with one score tensor per layer.
    """
    scores1 = [score_trace(logs1[layer], v[ind]) for ind, layer in enumerate(layers)]
    scores2 = [score_trace(logs2[layer], v[ind]) for ind, layer in enumerate(layers)]

    # squeeze=False keeps `axs` two-dimensional, so a single-layer call no
    # longer fails on axs[ind] with a bare Axes object.
    fig, axs = plt.subplots(1, len(layers), figsize=(10, 8), squeeze=False)
    axs = axs[0]
    for ind, layer in enumerate(layers):
        axs[ind].plot(scores1[ind], label=labels[0])
        axs[ind].plot(scores2[ind], label=labels[1])
        axs[ind].set_title("Layer " + str(layer))
        axs[ind].set_xlabel("generated token")
        axs[ind].set_ylabel("score")
        axs[ind].legend()

    fig.suptitle("Monitor scores across generated tokens")
    plt.tight_layout()
    plt.show()
    for ind, layer in enumerate(layers):
        print("Layer " + str(layer) + " Medians: ", scores1[ind].median(), scores2[ind].median())
        print("Layer " + str(layer) + " Means: ", scores1[ind].mean(), scores2[ind].mean())
    return scores1, scores2

In [None]:
def _project_onto(states, direction):
    """Project `states` [T, d] onto `direction` [d], returning a tensor [T]."""
    # Directions from the sklearn-based probe arrive as float64 NumPy arrays;
    # coerce them to a tensor of the states' dtype/device so the matmul and the
    # later .median()/.mean() calls operate on tensors.
    direction = torch.as_tensor(direction, dtype=states.dtype, device=states.device)
    return states @ direction


def plot_dscores(logs1, logs2, v, layers=(7, 13, 20, 25), labels=("tuned_model", "base_model")):
    """Plot raw projections of two traces onto per-layer directions.

    Parameters
    ----------
    logs1, logs2 : mapping layer -> tensor [T, d]
        Hidden-state traces to compare.
    v : sequence
        One direction per entry of `layers` (tensor or NumPy array), in the
        same order.
    layers : sequence of int
        Layer keys to plot, one subplot each.
    labels : pair of str
        Legend labels for `logs1` and `logs2` respectively.

    Returns
    -------
    (scores1, scores2) : lists with one score tensor per layer.
    """
    scores1 = [_project_onto(logs1[layer], v[ind]) for ind, layer in enumerate(layers)]
    scores2 = [_project_onto(logs2[layer], v[ind]) for ind, layer in enumerate(layers)]

    # squeeze=False keeps `axs` two-dimensional, so a single-layer call no
    # longer fails on axs[ind] with a bare Axes object.
    fig, axs = plt.subplots(1, len(layers), figsize=(10, 8), squeeze=False)
    axs = axs[0]
    for ind, layer in enumerate(layers):
        axs[ind].plot(scores1[ind], label=labels[0])
        axs[ind].plot(scores2[ind], label=labels[1])
        axs[ind].set_title("Layer " + str(layer))
        axs[ind].set_xlabel("generated token")
        axs[ind].set_ylabel("score")
        axs[ind].legend()

    fig.suptitle("Monitor scores across generated tokens")
    plt.tight_layout()
    plt.show()
    for ind, layer in enumerate(layers):
        print("Layer " + str(layer) + " Medians: ", scores1[ind].median(), scores2[ind].median())
        print("Layer " + str(layer) + " Means: ", scores1[ind].mean(), scores2[ind].mean())
    return scores1, scores2

In [None]:
def agg_hstates(hpos_samples, hneg_samples, layers=(7, 13, 20, 25)):
    """Collect per-layer hidden states for two groups of generation slices.

    Each sample is `[row_index, begin, stop]`: the prompt is built with
    `load_data(i=row_index)`, generation is run on the notebook-level `model`,
    and rows `begin:stop` of every tracked layer's trace are kept.

    Parameters
    ----------
    hpos_samples, hneg_samples : list of [int, int, int]
    layers : sequence of int
        Layers to record; forwarded to the generation helper so any layer
        requested here is actually tracked.

    Returns
    -------
    (hpos, hneg) : dicts mapping layer -> tensor of all kept rows concatenated
    along dim 0.
    """
    def _collect(samples):
        per_layer = {L: [] for L in layers}
        for sample in samples:
            inputt = load_data(i=sample[0])
            # `infer_with_hidden_states` is not defined in this notebook;
            # `infer_with_patch` (no save / no patch) has the same call shape
            # and returns (text, ids, hidden_states_log).
            _, _, states = infer_with_patch(model, inputt, layers_to_track=layers)
            for layer in layers:
                per_layer[layer].append(states[layer][sample[1]:sample[2]])
        return {layer: torch.cat(chunks, dim=0) for layer, chunks in per_layer.items()}

    hpos = _collect(hpos_samples)
    hneg = _collect(hneg_samples)

    print(hpos[layers[0]].shape, hneg[layers[0]].shape)
    return hpos, hneg

In [None]:
# Each entry is [row index, begin, stop]: agg_hstates builds the prompt with
# load_data(i=row index) and keeps rows begin:stop of the per-step hidden-state
# trace. A stop of -1 is an ordinary slice bound, so the final step is dropped.
# NOTE(review): the begin/stop values look hand-picked per example - confirm
# they still match the generations produced by the current model.
hpos_samples = [[0, 240, 390], [10, 190, -1], [16, 100, -1], [870, 120, -1]]
hneg_samples = [[1, 1, 530], [20, 1, -1]]

In [None]:
# Runs one generation per sample and concatenates the selected hidden-state
# rows per layer; hpos / hneg map layer -> tensor.
hpos, hneg = agg_hstates(hpos_samples, hneg_samples)

In [None]:
# One mean-difference direction per tracked layer, in the key order of hpos.
cv = [build_contrastive_direction(hpos[layer], hneg[layer]) for layer in hpos]

# One classifier-derived direction per layer. hneg is passed as class 0 and
# hpos as class 1; [0] keeps only the direction from the returned triple.
dv = [build_discriminative_direction(hneg[layer], hpos[layer])[0] for layer in hpos]