In [1]:
import torch
import numpy as np
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
from peft import PeftModel
from utils import get_prompt, get_bnb_config
import argparse

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Locations used by the cells below (relative to the notebook's working directory).
BASE_MODEL = "Taiwan-LLM-7B-v2.0-chat"          # base causal LM passed to from_pretrained
PEFT_MODEL = "OUTPUTS/300steps/checkpoint-250"  # LoRA adapter checkpoint loaded via PeftModel
DATA_PATH = "data/public_test.json"             # evaluation set (JSON list of records)

In [3]:
def perplexity(
    model, tokenizer, data, prompt_func =lambda x: x, max_length=2048,
):
    data_size = len(data)
    instructions = [prompt_func(x["instruction"]) for x in data]
    outputs = [x["output"] for x in data]

    # Tokenize data
    tokenized_instructions = tokenizer(instructions, add_special_tokens=False)
    tokenized_outputs = tokenizer(outputs, add_special_tokens=False)
    output_masks = []

    # Format data
    for i in range(data_size):
        instruction_input_ids = [tokenizer.bos_token_id] + \
            tokenized_instructions["input_ids"][i]
        output_input_ids = tokenized_outputs["input_ids"][i] + \
            [tokenizer.eos_token_id]
        tokenized_instructions["input_ids"][i] = instruction_input_ids + \
            output_input_ids
        tokenized_instructions["attention_mask"][i] = [
            1] * len(tokenized_instructions["input_ids"][i])
        output_mask = [0] * len(instruction_input_ids) + \
            [1] * len(output_input_ids)

        tokenized_instructions["input_ids"][i] = torch.tensor(
            tokenized_instructions["input_ids"][i][:max_length])
        tokenized_instructions["attention_mask"][i] = torch.tensor(
            tokenized_instructions["attention_mask"][i][:max_length])
        output_mask = torch.tensor(output_mask[:max_length])
        output_masks.append(output_mask)

    # Calculate ppl
    ppls = []
    loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
    for i in tqdm(range(data_size)):
        input_ids = tokenized_instructions["input_ids"][i].unsqueeze(0)
        attn_mask = tokenized_instructions["attention_mask"][i].unsqueeze(0)
        output_mask = output_masks[i].unsqueeze(0)
        label = input_ids

        with torch.no_grad():
            out_logits = model(input_ids, attention_mask=attn_mask).logits

        shift_logits = out_logits[..., :-1, :].contiguous()
        shift_label = label[..., 1:].contiguous()
        shift_output_mask = output_mask[..., 1:].contiguous()
        perplexity_batch = torch.exp(
            (loss_fct(shift_logits.transpose(1, 2),
             shift_label) * shift_output_mask).sum(1)
            / shift_output_mask.sum(1)
        )
        ppls += perplexity_batch.tolist()
    return {"perplexities": ppls, "mean_perplexity": np.mean(ppls)}

In [4]:
def zero_shot(instruction: str) -> str:
    """Zero-shot prompt: the instruction is given to the model as-is."""
    return "{}".format(instruction)

def few_shot(instruction: str) -> str:
    """One-shot prompt: a single worked translation example, then the instruction.

    The example translates a modern-Chinese sentence into Classical Chinese
    (文言文); the query is appended right after it, followed by "，答案：".
    """
    demo = (
        "以下為範例題目:翻譯成文言文：\n"
        "雅裏惱怒地說： 從前在福山田獵時，你誣陷獵官，現在又說這種話。"
        "範例答案：雅裏怒曰： 昔畋於福山，卿誣獵官，今復有此言。"
    )
    return demo + "{}，答案：".format(instruction)

In [5]:
# Model: quantized base LM in bf16, used for inference only.
print('Load Model')
bnb_config = get_bnb_config()  # quantization settings come from utils.py
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
)
base_model.eval()  # inference mode (e.g. disables dropout)
# Tokenizer
print('Load Tokenizer')
tokenizer = AutoTokenizer.from_pretrained(
    BASE_MODEL,
    padding_side="right",
    use_fast=False,
    tokenizer_type='llama'
)
# Register BOS/EOS using the ids from the model config so the tokenizer's
# special tokens agree with what the model was configured with.
special_tokens = {
    "eos_token": tokenizer.convert_ids_to_tokens(base_model.config.eos_token_id),
    "bos_token": tokenizer.convert_ids_to_tokens(base_model.config.bos_token_id),
}
# The pad token is mapped onto unk. Llama tokenizers may define no pad token
# (pad_token_id is None), and convert_ids_to_tokens(None) fails, so only remap
# when a pad id actually exists.
if tokenizer.pad_token_id is not None:
    special_tokens["unk_token"] = tokenizer.convert_ids_to_tokens(tokenizer.pad_token_id)
tokenizer.add_special_tokens(special_tokens)
# Data: list of records with "instruction" / "output" keys, as consumed by perplexity().
# JSON is UTF-8 by spec and the set presumably holds Chinese text (see the prompts),
# so do not depend on the platform's default encoding (e.g. cp950/cp1252 on Windows).
with open(DATA_PATH, "r", encoding="utf-8") as f:
    data = json.load(f)

Load Model


Loading checkpoint shards: 100%|██████████| 2/2 [00:06<00:00,  3.17s/it]

Load Tokenizer





In [6]:
# Zero-shot baseline: the base model scored on raw instructions.
zero_ppl = perplexity(base_model, tokenizer, data, prompt_func=zero_shot)
print("Mean perplexity:", zero_ppl["mean_perplexity"])

 55%|█████▍    | 137/250 [01:02<00:46,  2.44it/s]

In [None]:
# Few-shot baseline: the base model scored with one in-context example prepended.
few_ppl = perplexity(base_model, tokenizer, data, prompt_func=few_shot)
print("Mean perplexity:", few_ppl["mean_perplexity"])

In [None]:
# Load LoRA adapter on top of the base model and score with the training prompt.
# NOTE(review): PEFT typically injects the adapter layers into `base_model` in place,
# so re-running the zero-/few-shot cells after this one would likely no longer
# measure the plain base model — confirm, and reload the base model if needed.
lora_model = PeftModel.from_pretrained(base_model, PEFT_MODEL)
lora_model.eval()
lora_ppl = perplexity(lora_model, tokenizer, data, prompt_func=get_prompt)
print("Mean perplexity:", lora_ppl["mean_perplexity"])

100%|██████████| 250/250 [02:41<00:00,  1.55it/s]

Mean perplexity: 3.882264895915985



