In [None]:
import os
from modelzipper.tutils import *
import datasets
import torch
import copy
import numpy as np
import transformers
import matplotlib.pyplot as plt
from loguru import logger
import sys
sys.path.append("/data/zecheng/acl2025/MyRLHF/inference")
from utils.babilong.prompts import DEFAULT_PROMPTS, DEFAULT_TEMPLATE, get_formatted_input
sys.path.append("/data/zecheng/acl2025/MyRLHF/build_data/long_context_data")
from pipeline_sg import preprocess_item
sys.path.append("/data/zecheng/acl2025/MyRLHF/evaluation/babilong")
from eval import compare_answers, TASK_LABELS

# Restrict this kernel to a single physical GPU.
# NOTE(review): this runs after `import torch`; it only takes effect if CUDA has not
# been initialised yet -- confirm none of the imports above touch the GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "7"

# Load the causal LM in bfloat16 together with its tokenizer from the local checkpoint.
model_name = "/data/zecheng/hf_models/Meta-Llama-3.1-8B-Instruct"
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="balanced_low_0",
)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

logger.info("begin to load datasets")
# An earlier debug path read a pickled sample from "test_sample.pkl" instead of the
# generated task file below.
data = auto_read_data(
    "/data/zecheng/Long-form-reasoning-data/data/generated_tasks/qa3/4k.json"
)

# Use the first sample as the running example and tag it with its task name.
test_case = data[0]
test_case["task"] = "qa3"

In [None]:
# Inspect the running example: the first dataset entry, tagged with task 'qa3'.
test_case

In [None]:
# Run the preprocessing pipeline on every sample and log the final element (`res`)
# of the 14-tuple it returns. The other unpacked names are left holding the values
# of the last processed sample.
for i, item in enumerate(data):
    item["task"] = "qa3"
    (
        golden_input_text,
        model_pred_text,
        golden_input_ids,
        pred_input_ids,
        natural_input_ids,
        golden_question_pos,
        pred_question_pos,
        golden_answer_pos,
        pred_answer_pos,
        golden_offset_mapping,
        pred_offset_mapping,
        reference_pos,
        natural_text_reference_pos,
        res,
    ) = preprocess_item(item, model, tokenizer, model.device)
    logger.info(res)

In [None]:
# Preprocess the single running example; the cells below read `golden_input_ids`,
# `golden_question_pos` and `reference_pos` from this unpacking.
(
    golden_input_text,
    model_pred_text,
    golden_input_ids,
    pred_input_ids,
    natural_input_ids,
    golden_question_pos,
    pred_question_pos,
    golden_answer_pos,
    pred_answer_pos,
    golden_offset_mapping,
    pred_offset_mapping,
    reference_pos,
    natural_text_reference_pos,
    res,
) = preprocess_item(test_case, model, tokenizer, model.device)
print(golden_input_ids.shape)

# Hand cached, currently unused GPU blocks back to the driver before the forward passes.
torch.cuda.empty_cache()

In [None]:
def viz(loss, chunk_pos, anno_pos=None):
    """Plot a per-token loss curve and optionally highlight annotated spans.

    Args:
        loss: 1-D sequence of per-token losses. May be a ``torch.Tensor`` (on any
            device, with or without grad), a NumPy array, or a plain sequence of
            numbers.
        chunk_pos: Iterable of half-open ``(start, end)`` spans. Every position in
            these spans is assigned a running index (0, 1, 2, ...) in iteration
            order; that index is used as the x coordinate for annotations. The spans
            are assumed to describe how ``loss`` was concatenated.
        anno_pos: Optional iterable of half-open ``(start, end)`` spans to highlight.
            Each position is translated through the ``chunk_pos`` mapping.

    Raises:
        KeyError: If an annotated position is not covered by any span in ``chunk_pos``.
        IndexError: If a translated annotation index lies beyond the end of ``loss``.
    """
    # Map every original position covered by a chunk to its running x-axis index.
    # When spans overlap, the later occurrence wins.
    covered_positions = (pos for start, end in chunk_pos for pos in range(start, end))
    id_dict = {pos: idx for idx, pos in enumerate(covered_positions)}

    # Normalise the input to a float NumPy array so tensors, arrays and lists all work.
    if isinstance(loss, torch.Tensor):
        loss_values = loss.detach().to(torch.float).cpu().numpy()
    else:
        loss_values = np.asarray(loss, dtype=np.float32)

    # Create the figure.
    plt.figure(figsize=(10, 6))

    # Draw the whole loss sequence.
    plt.plot(range(len(loss_values)), loss_values, label='Loss per Token', color='blue')

    # Overlay the loss at the annotated positions, addressed by their running index.
    if anno_pos:
        anno_re_ids = [id_dict[pos] for start, end in anno_pos for pos in range(start, end)]
        plt.plot(anno_re_ids, loss_values[anno_re_ids], marker='o', linestyle='--', color='red', label='Loss of Reference Chunks')

    # Axis labels, title and legend.
    plt.xlabel('Token Index')
    plt.ylabel('Loss Value')
    plt.title('Loss per Token with Continuous Chunk Highlighted')
    plt.legend()

    # Render the figure.
    plt.show()

In [None]:
# without question intervention
# Compares the per-token loss of the full golden input with the loss obtained when
# the model only sees the concatenated reference chunks (no question prepended).
trunc_len = 4096
sliding_window = 1024
theta = 2.0
expand_size = 20  # extra positions kept on each side of a reference span

with torch.no_grad():
    # Per-token loss over the whole golden input; entry k scores the prediction of token k + 1.
    loss_f = torch.nn.CrossEntropyLoss(reduction='none')
    output_full = model(golden_input_ids)
    loss_overall = loss_f(output_full.logits[0, :-1, :], golden_input_ids[0, 1:]).to(torch.float).cpu().numpy()
    ppl_full = np.exp(loss_overall.mean())

    _, max_len = golden_input_ids.shape
    key_tokens = []
    chunk_score = dict()

    # Number of sliding windows needed beyond the first `trunc_len` tokens. The ceil
    # must wrap the division (it previously wrapped only the integer difference, so
    # the result was truncated instead of rounded up); never report a negative count.
    chunk_num = max(int(np.ceil((max_len - trunc_len) / sliding_window)), 0)
    question_ipt_ids = golden_input_ids[:, golden_question_pos[0]: golden_question_pos[1]]
    question_length = question_ipt_ids.size(1)

    # testing inference with reference chunks
    all_sub_chunks, referece_loss, chunk_pos, key_ref_pos = [], [], [], []
    num_loss_positions = loss_overall.shape[0]
    for ref_pos in reference_pos:
        ref_start, ref_end = ref_pos[0], ref_pos[1]
        # Clamp the expanded window to the valid range: a negative start would wrap
        # around in the slice, and an over-long end would make `chunk_pos` disagree
        # with the number of loss values actually collected.
        win_start = max(ref_start - expand_size, 0)
        win_end = min(ref_end + expand_size, num_loss_positions)
        all_sub_chunks.append(golden_input_ids[:, ref_start:ref_end])
        referece_loss.append(loss_overall[win_start:win_end])
        chunk_pos.append((win_start, win_end))
        key_ref_pos.append((ref_start, ref_end))

    reference_input_ids = torch.cat(all_sub_chunks, dim=1)
    # NOTE(review): dropping the first entry shifts every index by one relative to
    # `chunk_pos`, and `loss_overall[k]` already scores token k + 1, so positions taken
    # from `chunk_pos` / `key_ref_pos` do not line up exactly with this tensor --
    # confirm the intended alignment.
    loss_full = torch.tensor(np.concatenate(referece_loss, axis=0))[1:]

    torch.cuda.empty_cache()
    # Feed only the reference chunks (the question-conditioned variant prepends
    # `question_ipt_ids` and skips the first `question_length` logits).
    output_ref = model(reference_input_ids)
    loss_ref = loss_f(output_ref.logits[0, :-1, :], reference_input_ids[0, 1:]).to(torch.float).cpu()
    # Plot the reference-only loss as one contiguous chunk, without highlights.
    viz(loss_ref, [(0, loss_ref.size(-1))])

In [None]:
# with question intervention
# Same analysis as the question-free variant, but the question tokens are prepended
# to the concatenated reference chunks before the second forward pass.
trunc_len = 4096
sliding_window = 1024
theta = 2.0
expand_size = 20  # extra positions kept on each side of a reference span

with torch.no_grad():
    # Per-token loss over the whole golden input; entry k scores the prediction of token k + 1.
    loss_f = torch.nn.CrossEntropyLoss(reduction='none')
    output_full = model(golden_input_ids)
    loss_overall = loss_f(output_full.logits[0, :-1, :], golden_input_ids[0, 1:]).to(torch.float).cpu().numpy()
    ppl_full = np.exp(loss_overall.mean())

    _, max_len = golden_input_ids.shape
    key_tokens = []
    chunk_score = dict()

    # Number of sliding windows needed beyond the first `trunc_len` tokens. The ceil
    # must wrap the division (it previously wrapped only the integer difference, so
    # the result was truncated instead of rounded up); never report a negative count.
    chunk_num = max(int(np.ceil((max_len - trunc_len) / sliding_window)), 0)
    question_ipt_ids = golden_input_ids[:, golden_question_pos[0]: golden_question_pos[1]]
    question_length = question_ipt_ids.size(1)

    # testing inference with reference chunks
    all_sub_chunks, referece_loss, chunk_pos, key_ref_pos = [], [], [], []
    num_loss_positions = loss_overall.shape[0]
    for ref_pos in reference_pos:
        ref_start, ref_end = ref_pos[0], ref_pos[1]
        # Clamp the expanded window to the valid range: a negative start would wrap
        # around in the slice, and an over-long end would make `chunk_pos` disagree
        # with the number of loss values actually collected (misaligning the plot).
        win_start = max(ref_start - expand_size, 0)
        win_end = min(ref_end + expand_size, num_loss_positions)
        all_sub_chunks.append(golden_input_ids[:, ref_start:ref_end])
        referece_loss.append(loss_overall[win_start:win_end])
        chunk_pos.append((win_start, win_end))
        key_ref_pos.append((ref_start, ref_end))

    reference_input_ids = torch.cat(all_sub_chunks, dim=1)
    # NOTE(review): dropping the first entry shifts every index by one relative to
    # `chunk_pos`, and `loss_overall[k]` already scores token k + 1, so the points
    # highlighted by `viz` below are offset from the token positions in `key_ref_pos`
    # -- confirm the intended alignment.
    loss_full = torch.tensor(np.concatenate(referece_loss, axis=0))[1:]

    torch.cuda.empty_cache()
    # Prepend the question so the reference chunks are scored conditioned on it.
    combined_ref_ipt_ids = torch.cat([question_ipt_ids, reference_input_ids], dim=1)
    output_ref = model(combined_ref_ipt_ids)

    # Skip the logits produced at question positions; the logit at combined index
    # question_length + j predicts reference token j + 1.
    loss_ref = loss_f(output_ref.logits[0, question_length:-1, :], reference_input_ids[0, 1:]).to(torch.float).cpu()
    # Plot the full-context loss around each reference span, highlighting the spans themselves.
    viz(loss_full, chunk_pos, key_ref_pos)
    # loss_discrepancy = (torch.logical_and(torch.abs(loss_ref - loss_full) > theta, loss_full < theta)).squeeze()

In [None]:
# Show the full-context losses gathered around the reference spans (first entry dropped).
loss_full

In [None]:
# Show the (start, end) windows: each reference span widened by `expand_size` on both sides.
chunk_pos

In [None]:
# Show the unexpanded (start, end) reference spans copied from `reference_pos`.
key_ref_pos