In [1]:
import os
# ограничить видимость только GPU 1
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from tqdm import tqdm
import math

# === Конфигурация ===
pretrained_model = "meta-llama/Llama-3.1-8B-Instruct"
finetuned_model_path = "/mnt/extremessd10tb/borisiuk/open-unlearning/saves/finetune/llama3.1-8b_full_5ep_ft_popqa"
device = "cuda" if torch.cuda.is_available() else "cpu"
max_new_tokens = 5
batch_size = 2  # если OOM

# Шаблон из описания
system_prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant.<|eot_id|>"
user_start_tag = "<|start_header_id|>user<|end_header_id|>\n\n"
user_end_tag = "<|eot_id|>"
asst_start_tag = "<|start_header_id|>assistant<|end_header_id|>\n\n"

def build_prompt(question: str):
    return f"{system_prompt}\n\n{user_start_tag}{question}{user_end_tag}{asst_start_tag}"

def compute_perplexity_on_split(model, tokenizer, dataset, split_name):
    ds = dataset[split_name]
    # Инспектируем первую запись, чтобы понять структуру
    sample = ds[0]
    # Популярные варианты названия поля
    if "questions" in sample:
        q_field = "questions"
    elif "question" in sample:
        q_field = "question"
    else:
        raise KeyError(f"Не нашёл поле 'question(s)' в примере: keys={list(sample.keys())}")

    per_example_ppls = []

    model.eval()
    with torch.no_grad():
        for i in tqdm(range(0, len(ds), batch_size), desc=f"Split {split_name}"):
            # Срез даёт dict of lists, используем поле напрямую
            batch = ds[i : i + batch_size]
            questions = batch[q_field]  # это список строк

            prompts = [build_prompt(q) for q in questions]
            enc = tokenizer(prompts, padding=True, truncation=True, return_tensors="pt").to(device)
            input_ids = enc["input_ids"]
            attention_mask = enc["attention_mask"]

            gen_outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                return_dict_in_generate=True,
                output_scores=True,
                use_cache=True,
            )

            gen_tokens = gen_outputs.sequences[:, input_ids.shape[1]:]  # (B, new)
            batch_log_probs = torch.zeros(gen_tokens.size(0), gen_tokens.size(1), device=device)
            for t, logits in enumerate(gen_outputs.scores):
                log_probs = torch.log_softmax(logits, dim=-1)
                token_ids = gen_tokens[:, t]
                batch_log_probs[:, t] = log_probs[torch.arange(token_ids.size(0), device=device), token_ids]

            mean_log_prob = batch_log_probs.mean(dim=1)
            perp = torch.exp(-mean_log_prob)
            per_example_ppls.extend(perp.cpu().tolist())

    mean_ppl = sum(per_example_ppls) / len(per_example_ppls) if per_example_ppls else float("nan")
    return mean_ppl, per_example_ppls


def evaluate_model(model_path, label):
    print(f"\n=== Evaluating {label} ({model_path}) ===")
    tokenizer = AutoTokenizer.from_pretrained(pretrained_model, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    ).to(device)

    dataset = load_dataset("SwetieePawsss/UNLamb")
    results = {}
    for split in ["rare_forget15", "popular_forget15"]:
        mean_ppl, per_example = compute_perplexity_on_split(model, tokenizer, dataset, split)
        results[split] = {
            "mean_perplexity": mean_ppl,
            "per_example": per_example,
        }
        print(f"Split {split}: mean PPL = {mean_ppl:.4f}")
    return results

# Запуск
pretrained_res = evaluate_model(pretrained_model, "original")
finetuned_res = evaluate_model(finetuned_model_path, "finetuned")

# Сравнение
print("\n=== Сравнение ===")
for split in ["rare_forget15", "popular_forget15"]:
    p_pre = pretrained_res[split]["mean_perplexity"]
    p_ft = finetuned_res[split]["mean_perplexity"]
    delta = p_pre - p_ft
    print(f"{split}: pre PPL={p_pre:.4f}, ft PPL={p_ft:.4f}, delta (pre - ft)={delta:.4f}")



=== Evaluating original (meta-llama/Llama-3.1-8B-Instruct) ===


Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 86.30it/s]
Split rare_forget15:   0%|          | 0/873 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Split rare_forget15:   0%|          | 1/873 [00:02<33:03,  2.27s/it]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Split rare_forget15:   0%|          | 2/873 [00:02<15:36,  1.08s/it]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Split rare_forget15:   0%|          | 3/873 [00:02<09:44,  1.49it/s]The following g

Split rare_forget15: mean PPL = 1.2799


Split popular_forget15:   0%|          | 0/873 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Split popular_forget15:   0%|          | 1/873 [00:00<02:45,  5.26it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Split popular_forget15:   0%|          | 2/873 [00:00<02:46,  5.24it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Split popular_forget15:   0%|          | 3/873 [00:00<02:37,  5.54it/s]The following generation flags are not valid and may be ignored: ['temperatu

Split popular_forget15: mean PPL = 1.2149

=== Evaluating finetuned (/mnt/extremessd10tb/borisiuk/open-unlearning/saves/finetune/llama3.1-8b_full_5ep_ft_popqa) ===


Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 11.83it/s]
Split rare_forget15:   0%|          | 0/873 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Split rare_forget15:   0%|          | 1/873 [00:00<01:28,  9.90it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Split rare_forget15:   0%|          | 2/873 [00:00<02:14,  6.48it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Split rare_forget15:   0%|          | 3/873 [00:00<02:29,  5.83it/s]The following g

Split rare_forget15: mean PPL = 23485307.8991


Split popular_forget15:   0%|          | 0/873 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Split popular_forget15:   0%|          | 1/873 [00:00<02:49,  5.15it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Split popular_forget15:   0%|          | 2/873 [00:00<02:50,  5.11it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Split popular_forget15:   0%|          | 3/873 [00:00<02:40,  5.43it/s]The following generation flags are not valid and may be ignored: ['temperatu

Split popular_forget15: mean PPL = 1373384.8314

=== Сравнение ===
rare_forget15: pre PPL=1.2799, ft PPL=23485307.8991, delta (pre - ft)=-23485306.6193
popular_forget15: pre PPL=1.2149, ft PPL=1373384.8314, delta (pre - ft)=-1373383.6165





In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from tqdm import tqdm
import math
import csv
import os

# === Настройки ===
PRETRAINED = "meta-llama/Llama-3.1-8B-Instruct"
FINETUNED = "/mnt/extremessd10tb/borisiuk/open-unlearning/saves/finetune/llama3.1-8b_full_2ep_ft_popqa"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 2  # уменьшай при OOM
MAX_LENGTH = 1024  # если надо обрезать prompt+answer

# Шаблон
system_prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant.<|eot_id|>"
user_start_tag = "<|start_header_id|>user<|end_header_id|>\n\n"
user_end_tag = "<|eot_id|>"
asst_start_tag = "<|start_header_id|>assistant<|end_header_id|>\n\n"

def build_prompt(question: str):
    # Возвращает input: system + user(question) + assistant start, 
    # и будем считать перплексию на продолжении (следующих токенах) через teacher forcing
    return f"{system_prompt}\n\n{user_start_tag}{question}{user_end_tag}{asst_start_tag}"

def load_model_and_tokenizer(model_path: str):
    tokenizer = AutoTokenizer.from_pretrained(PRETRAINED, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # важно для батчинга
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )
    # на случай, если токенизатор модифицировался (не требуется здесь, но безопасно)
    model.resize_token_embeddings(len(tokenizer))
    model = model.to(DEVICE)
    model.eval()
    return model, tokenizer

def compute_perplexity(model, tokenizer, dataset_split, split_name, max_new_tokens=5):
    """
    Вариант: делаем генерацию continuation (greedy) до max_new_tokens, 
    потом считаем перплексию teacher-forcing на этой continuation, условно на prompt+continuation.
    Это устраняет рассинхрон между генерацией и лог-пробами.
    """
    examples = dataset_split
    all_ppls = []

    with torch.no_grad():
        for i in tqdm(range(0, len(examples), BATCH_SIZE), desc=f"PPL {split_name}"):
            batch = examples[i : i + BATCH_SIZE]
            # Определяем поле с вопросом
            if "questions" in batch:
                questions = batch["questions"]
            elif "question" in batch:
                questions = batch["question"]
            else:
                # fallback: если dataset slicing вернул dict of lists, извлекаем первым примером
                sample0 = examples[0]
                if "questions" in sample0:
                    questions = [ex["questions"] for ex in batch]
                elif "question" in sample0:
                    questions = [ex["question"] for ex in batch]
                else:
                    raise KeyError(f"Не найдено поле question(s) в split {split_name}")

            prompts = [build_prompt(q) for q in questions]
            # Сгенерировать continuation (макс 5 токенов) детерминированно
            enc = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=MAX_LENGTH).to(DEVICE)
            prompt_ids = enc["input_ids"]
            attention_mask = enc["attention_mask"]

            gen_out = model.generate(
                input_ids=prompt_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                return_dict_in_generate=True,
                output_scores=True,
                use_cache=True,
            )
            # Получаем continuation токены
            continuation = gen_out.sequences[:, prompt_ids.shape[1]:]  # (B, new)
            # Собираем full sequence: prompt + continuation
            full_input = torch.cat([prompt_ids, continuation], dim=1)
            # Создаём labels: хотим предсказывать continuation, поэтому маскируем prompt часть
            labels = full_input.clone()
            # маскируем (ставим -100) prompt токены чтобы loss считался только по continuation
            labels[:, : prompt_ids.shape[1]] = -100

            # Переотправляем через модель для teacher-forcing loss
            outputs = model(input_ids=full_input, attention_mask=None, labels=labels)
            # outputs.loss — средний negative log likelihood на всех предсказываемых токенах (т.е. continuation)
            loss = outputs.loss  # already averaged over non -100 tokens
            # perplexity = exp(loss)
            ppl = torch.exp(loss).item()
            all_ppls.append(ppl)

    mean_ppl = sum(all_ppls) / len(all_ppls) if all_ppls else float("nan")
    return mean_ppl, all_ppls

def evaluate_checkpoint(model_path: str, label: str):
    print(f"\n=== Evaluating {label} model: {model_path} ===")
    model, tokenizer = load_model_and_tokenizer(model_path)
    ds = load_dataset("SwetieePawsss/UNLamb")
    results = {}
    for split in ["rare_forget15", "popular_forget15"]:
        mean_ppl, per_example = compute_perplexity(model, tokenizer, ds[split], split, max_new_tokens=5)
        results[split] = {"mean_perplexity": mean_ppl, "per_example": per_example}
        print(f"{split}: mean PPL = {mean_ppl:.4f} (n={len(per_example)})")
    return results

if __name__ == "__main__":
    pre_res = evaluate_checkpoint(PRETRAINED, "original")
    ft_res = evaluate_checkpoint(FINETUNED, "finetuned")

    print("\n=== Comparison ===")
    for split in ["rare_forget15", "popular_forget15"]:
        p_pre = pre_res[split]["mean_perplexity"]
        p_ft = ft_res[split]["mean_perplexity"]
        delta = p_pre - p_ft
        print(f"{split}: pre PPL={p_pre:.4f}, ft PPL={p_ft:.4f}, delta (pre - ft)={delta:.4f}")



=== Evaluating original model: meta-llama/Llama-3.1-8B-Instruct ===


Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 78.09it/s]
PPL rare_forget15:   0%|          | 0/873 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
PPL rare_forget15:   0%|          | 1/873 [00:00<03:03,  4.76it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
PPL rare_forget15:   0%|          | 2/873 [00:00<03:13,  4.50it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
PPL rare_forget15:   0%|          | 3/873 [00:00<03:15,  4.44it/s]The following generatio

rare_forget15: mean PPL = 12.4573 (n=873)


PPL popular_forget15:   0%|          | 0/873 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
PPL popular_forget15:   0%|          | 1/873 [00:00<03:19,  4.36it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
PPL popular_forget15:   0%|          | 2/873 [00:00<03:19,  4.37it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
PPL popular_forget15:   0%|          | 3/873 [00:00<03:08,  4.62it/s]The following generation flags are not valid and may be ignored: ['temperature', 'to

popular_forget15: mean PPL = 11.8540 (n=873)

=== Evaluating finetuned model: /mnt/extremessd10tb/borisiuk/open-unlearning/saves/finetune/llama3.1-8b_full_5ep_ft_popqa ===


Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 93.70it/s]
PPL rare_forget15:   0%|          | 0/873 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
PPL rare_forget15:   0%|          | 1/873 [00:00<01:55,  7.53it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
PPL rare_forget15:   0%|          | 2/873 [00:00<02:42,  5.35it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
PPL rare_forget15:   0%|          | 3/873 [00:00<02:57,  4.91it/s]The following generatio

rare_forget15: mean PPL = 12864.4280 (n=873)


PPL popular_forget15:   0%|          | 0/873 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
PPL popular_forget15:   0%|          | 1/873 [00:00<03:16,  4.45it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
PPL popular_forget15:   0%|          | 2/873 [00:00<03:16,  4.43it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
PPL popular_forget15:   0%|          | 3/873 [00:00<03:06,  4.68it/s]The following generation flags are not valid and may be ignored: ['temperature', 'to

popular_forget15: mean PPL = 772.1839 (n=873)

=== Comparison ===
rare_forget15: pre PPL=12.4573, ft PPL=12864.4280, delta (pre - ft)=-12851.9707
popular_forget15: pre PPL=11.8540, ft PPL=772.1839, delta (pre - ft)=-760.3299





In [7]:
# import torch
# from transformers import AutoTokenizer, AutoModelForCausalLM
# from datasets import load_dataset
# from tqdm import tqdm
# import math

# # === Конфигурация ===
# FT_MODEL_PATH = "/mnt/extremessd10tb/borisiuk/open-unlearning/saves/finetune/llama3.1-8b_full_2ep_ft_popqa"
# DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# BATCH_SIZE = 2  # уменьшай при OOM
# MAX_NEW_TOKENS = 5
# MAX_LENGTH = 1024  # для обрезки prompt+continuation, если нужно

# # Шаблон из описания
# system_prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant.<|eot_id|>"
# user_start_tag = "<|start_header_id|>user<|end_header_id|>\n\n"
# user_end_tag = "<|eot_id|>"
# asst_start_tag = "<|start_header_id|>assistant<|end_header_id|>\n\n"

# def build_prompt(question: str):
#     return f"{system_prompt}\n\n{user_start_tag}{question}{user_end_tag}{asst_start_tag}"

# def load_ft_model_and_tokenizer(model_path: str):
#     tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct", use_fast=True)
#     # если нет pad_token — назначаем (на inference можно брать eos_token) для батчинга. :contentReference[oaicite:1]{index=1}
#     if tokenizer.pad_token is None:
#         tokenizer.pad_token = tokenizer.eos_token
#     model = AutoModelForCausalLM.from_pretrained(
#         model_path,
#         torch_dtype=torch.bfloat16,
#         trust_remote_code=True,
#     )
#     # чтобы в случае изменения токенов (безопасно) embeddings были согласованы
#     model.resize_token_embeddings(len(tokenizer))
#     model = model.to(DEVICE)
#     model.eval()
#     return model, tokenizer

# def compute_perplexity_only_ft(model, tokenizer, split_dataset, split_name):
#     examples = split_dataset
#     per_example_ppls = []

#     with torch.no_grad():
#         for i in tqdm(range(0, len(examples), BATCH_SIZE), desc=f"PPL {split_name}"):
#             batch = examples[i : i + BATCH_SIZE]
#             # Берём поле с вопросами
#             if "questions" in batch:
#                 questions = batch["questions"]
#             elif "question" in batch:
#                 questions = batch["question"]
#             else:
#                 sample0 = examples[0]
#                 if "questions" in sample0:
#                     questions = [ex["questions"] for ex in batch]
#                 elif "question" in sample0:
#                     questions = [ex["question"] for ex in batch]
#                 else:
#                     raise KeyError(f"Не найдено поле question(s) в split {split_name}")

#             prompts = [build_prompt(q) for q in questions]
#             enc = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=MAX_LENGTH).to(DEVICE)
#             prompt_ids = enc["input_ids"]

#             # Генерируем continuation (greedy) до MAX_NEW_TOKENS
#             gen_out = model.generate(
#                 input_ids=prompt_ids,
#                 max_new_tokens=MAX_NEW_TOKENS,
#                 do_sample=False,
#                 return_dict_in_generate=True,
#                 output_scores=False,
#                 use_cache=True,
#                 attention_mask=enc.get("attention_mask", None),
#             )

#             continuation = gen_out.sequences[:, prompt_ids.shape[1]:]  # (B, new)
#             full_seq = torch.cat([prompt_ids, continuation], dim=1)
#             # labels: считаем loss только по continuation, maskируем prompt
#             labels = full_seq.clone()
#             labels[:, : prompt_ids.shape[1]] = -100  # ignore prompt part

#             # Прогоняем для teacher-forcing loss
#             outputs = model(input_ids=full_seq, labels=labels)
#             loss = outputs.loss  # averaged over continuation tokens
#             ppl = torch.exp(loss).item()
#             per_example_ppls.append(ppl)

#     mean_ppl = sum(per_example_ppls) / len(per_example_ppls) if per_example_ppls else float("nan")
#     return mean_ppl, per_example_ppls

# def main():
#     model, tokenizer = load_ft_model_and_tokenizer(FT_MODEL_PATH)
#     ds = load_dataset("SwetieePawsss/UNLamb")
#     results = {}
#     for split in ["rare_forget15", "popular_forget15"]:
#         mean_ppl, per_example = compute_perplexity_only_ft(model, tokenizer, ds[split], split)
#         results[split] = {"mean_perplexity": mean_ppl, "per_example": per_example}
#         print(f"{split}: mean PPL = {mean_ppl:.4f} (n={len(per_example)})")

#     # Плюс можно вывести разницу между сплитами или сохранить per-example
#     return results

# if __name__ == "__main__":
#     main()


In [None]:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from tqdm import tqdm
import math
import csv
import os

# === Настройки ===
PRETRAINED = "meta-llama/Llama-3.2-1B-Instruct"
FINETUNED = "meta-llama/Llama-3.2-3B-Instruct"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 2  # уменьшай при OOM
MAX_LENGTH = 1024  # если надо обрезать prompt+answer

# Шаблон
system_prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant.<|eot_id|>"
user_start_tag = "<|start_header_id|>user<|end_header_id|>\n\n"
user_end_tag = "<|eot_id|>"
asst_start_tag = "<|start_header_id|>assistant<|end_header_id|>\n\n"

def build_prompt(question: str):
    # Возвращает input: system + user(question) + assistant start, 
    # и будем считать перплексию на продолжении (следующих токенах) через teacher forcing
    return f"{system_prompt}\n\n{user_start_tag}{question}{user_end_tag}{asst_start_tag}"

def load_model_and_tokenizer(model_path: str):
    tokenizer = AutoTokenizer.from_pretrained(PRETRAINED, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # важно для батчинга
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )
    # на случай, если токенизатор модифицировался (не требуется здесь, но безопасно)
    model.resize_token_embeddings(len(tokenizer))
    model = model.to(DEVICE)
    model.eval()
    return model, tokenizer

def compute_perplexity(model, tokenizer, dataset_split, split_name, max_new_tokens=5):
    """
    Вариант: делаем генерацию continuation (greedy) до max_new_tokens, 
    потом считаем перплексию teacher-forcing на этой continuation, условно на prompt+continuation.
    Это устраняет рассинхрон между генерацией и лог-пробами.
    """
    examples = dataset_split
    all_ppls = []

    with torch.no_grad():
        for i in tqdm(range(0, len(examples), BATCH_SIZE), desc=f"PPL {split_name}"):
            batch = examples[i : i + BATCH_SIZE]
            # Определяем поле с вопросом
            if "questions" in batch:
                questions = batch["questions"]
            elif "question" in batch:
                questions = batch["question"]
            else:
                # fallback: если dataset slicing вернул dict of lists, извлекаем первым примером
                sample0 = examples[0]
                if "questions" in sample0:
                    questions = [ex["questions"] for ex in batch]
                elif "question" in sample0:
                    questions = [ex["question"] for ex in batch]
                else:
                    raise KeyError(f"Не найдено поле question(s) в split {split_name}")

            prompts = [build_prompt(q) for q in questions]
            # Сгенерировать continuation (макс 5 токенов) детерминированно
            enc = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=MAX_LENGTH).to(DEVICE)
            prompt_ids = enc["input_ids"]
            attention_mask = enc["attention_mask"]

            gen_out = model.generate(
                input_ids=prompt_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                return_dict_in_generate=True,
                output_scores=True,
                use_cache=True,
            )
            # Получаем continuation токены
            continuation = gen_out.sequences[:, prompt_ids.shape[1]:]  # (B, new)
            # Собираем full sequence: prompt + continuation
            full_input = torch.cat([prompt_ids, continuation], dim=1)
            # Создаём labels: хотим предсказывать continuation, поэтому маскируем prompt часть
            labels = full_input.clone()
            # маскируем (ставим -100) prompt токены чтобы loss считался только по continuation
            labels[:, : prompt_ids.shape[1]] = -100

            # Переотправляем через модель для teacher-forcing loss
            outputs = model(input_ids=full_input, attention_mask=None, labels=labels)
            # outputs.loss — средний negative log likelihood на всех предсказываемых токенах (т.е. continuation)
            loss = outputs.loss  # already averaged over non -100 tokens
            # perplexity = exp(loss)
            ppl = torch.exp(loss).item()
            all_ppls.append(ppl)

    mean_ppl = sum(all_ppls) / len(all_ppls) if all_ppls else float("nan")
    return mean_ppl, all_ppls

def evaluate_checkpoint(model_path: str, label: str):
    print(f"\n=== Evaluating {label} model: {model_path} ===")
    model, tokenizer = load_model_and_tokenizer(model_path)
    ds = load_dataset("SwetieePawsss/UNLamb")
    results = {}
    for split in ["rare_forget15", "popular_forget15"]:
        mean_ppl, per_example = compute_perplexity(model, tokenizer, ds[split], split, max_new_tokens=5)
        results[split] = {"mean_perplexity": mean_ppl, "per_example": per_example}
        print(f"{split}: mean PPL = {mean_ppl:.4f} (n={len(per_example)})")
    return results

if __name__ == "__main__":
    pre_res = evaluate_checkpoint(PRETRAINED, "original")
    ft_res = evaluate_checkpoint(FINETUNED, "finetuned")

    print("\n=== Comparison ===")
    for split in ["rare_forget15", "popular_forget15"]:
        p_pre = pre_res[split]["mean_perplexity"]
        p_ft = ft_res[split]["mean_perplexity"]
        delta = p_pre - p_ft
        print(f"{split}: pre 1B PPL={p_pre:.4f}, pre 3B PPL={p_ft:.4f}, delta (pre - ft)={delta:.4f}")



=== Evaluating original model: meta-llama/Llama-3.2-1B-Instruct ===


PPL rare_forget15:   0%|          | 0/873 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
PPL rare_forget15:   0%|          | 1/873 [00:00<01:40,  8.71it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
PPL rare_forget15:   0%|          | 2/873 [00:00<01:42,  8.49it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
PPL rare_forget15:   0%|          | 3/873 [00:00<01:44,  8.36it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `

rare_forget15: mean PPL = 1.2617 (n=873)


PPL popular_forget15:   0%|          | 0/873 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
PPL popular_forget15:   0%|          | 1/873 [00:00<01:48,  8.05it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
PPL popular_forget15:   0%|          | 2/873 [00:00<01:48,  8.02it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
PPL popular_forget15:   0%|          | 3/873 [00:00<01:40,  8.65it/s]The following generation flags are not valid and may be ignored: ['temperature', 'to

popular_forget15: mean PPL = 1.2417 (n=873)

=== Evaluating finetuned model: meta-llama/Llama-3.2-3B-Instruct ===


Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  6.78it/s]
PPL rare_forget15:   0%|          | 0/873 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
PPL rare_forget15:   0%|          | 1/873 [00:00<02:33,  5.70it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
PPL rare_forget15:   0%|          | 2/873 [00:00<07:02,  2.06it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
PPL rare_forget15:   0%|          | 3/873 [00:01<05:09,  2.81it/s]The following generatio

rare_forget15: mean PPL = 1.5801 (n=873)


PPL popular_forget15:   0%|          | 0/873 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
PPL popular_forget15:   0%|          | 1/873 [00:00<02:53,  5.03it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
PPL popular_forget15:   0%|          | 2/873 [00:00<02:53,  5.03it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
PPL popular_forget15:   0%|          | 3/873 [00:00<02:43,  5.32it/s]The following generation flags are not valid and may be ignored: ['temperature', 'to