In [1]:
import sys
sys.path.append("minGPT")   # put the minGPT checkout on sys.path so the `mingpt` package can be imported

import torch
from minGPT.mingpt.model import GPT
from minGPT.mingpt.utils import set_seed
from minGPT.mingpt.bpe import BPETokenizer

# Fix the random seed up front so later sampling/generation is repeatable
set_seed(3407)

In [2]:
# Which pretrained checkpoint to load
model_type = 'gpt2'

# Use the GPU when PyTorch reports one, otherwise stay on the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")

# BPE tokenizer shared by the datasets and the generation loop
tokenizer = BPETokenizer()


Using device: cpu


In [3]:
# Build the model through minGPT's GPT.from_pretrained; the second returned
# value is discarded here.
# To switch the memory module off, pass types=None (as done below).
model, _ = GPT.from_pretrained(model_type, types=None, model_dir=None)
# Trailing semicolon keeps the notebook from echoing the module repr
model.to(device);

  from .autonotebook import tqdm as notebook_tqdm


number of parameters: 124.44M


In [4]:
from minGPT.mingpt.program_dataset import ProgramDataset
from torch.utils.data import DataLoader
from dynamic_cheatsheet import DynamicCheatsheetMemory
from dynamic_cheatsheet.config_loader import load_config
from dataclasses import asdict
import random

# Prepare the training dataset from the train.jsonl split
train_dataset = ProgramDataset(
    jsonl_path="./data_reverse_dropvowel/train.jsonl",
    block_size=1024,
    tokenizer=tokenizer,
)

# Draw one random batch from a dataset
def get_one_batch(dataset, batch_size):
    """
    Sample ``batch_size`` items from ``dataset`` and collate them.

    Indices are drawn independently with ``random.randrange``, so the same
    item may appear more than once in a batch. Every item is unpacked as
    ``(idx, target, prompt, answer)``.

    Returns:
        A tuple ``(x, y, prompts, answers)``. ``x`` and ``y`` are the ``idx``
        and ``target`` entries combined with ``torch.stack`` along a new
        leading dimension; ``prompts`` and ``answers`` are plain lists of the
        remaining two fields, in draw order.
    """
    inputs, targets, prompts, answers = [], [], [], []
    columns = (inputs, targets, prompts, answers)

    for _ in range(batch_size):
        # Draw the index and fetch the item in the same step, one draw at a time
        sample = dataset[random.randrange(len(dataset))]
        idx, target, prompt, answer = sample
        for column, value in zip(columns, (idx, target, prompt, answer)):
            column.append(value)

    return torch.stack(inputs, dim=0), torch.stack(targets, dim=0), prompts, answers

# Sanity check: draw one batch of 32 training examples
x, y, prompts, answers = get_one_batch(train_dataset, batch_size=32)

# Print the tensor shapes, then the container type/length and the entry at
# index 15 for both the prompts and the answers
print("输入形状：", x.shape)
print("目标形状：", y.shape)
print("prompt的类型", type(prompts), len(prompts))
print("prompt的第15个元素：", prompts[15])
print("answer的类型", type(answers), len(answers))
print("answer的第15个元素：", answers[15])

输入形状： torch.Size([32, 1024])
目标形状： torch.Size([32, 1024])
prompt的类型 <class 'list'> 32
prompt的第15个元素： Task: Reverse the string and remove vowels (a,e,i,o,u).
Input: wcrwbjdqprjw
Output:
answer的类型 <class 'list'> 32
answer的第15个元素： wjrpqdjbwrcw


In [5]:
# Evaluation metric functions
import Levenshtein

def exact_match(pred: str, gold: str) -> float:
    """
    Exact-match score.

    Both strings are compared after stripping leading/trailing whitespace.
    Returns 1.0 when they are identical, otherwise 0.0.
    """
    is_same = pred.strip() == gold.strip()
    return 1.0 if is_same else 0.0

def normalized_edit_similarity(pred: str, gold: str) -> float:
    """
    Edit similarity: 1 - (Levenshtein distance / length of the longer string).

    The result lies in [0, 1]; 1 means the two strings are identical.
    An empty ``gold`` is special-cased: the score is 1.0 only when ``pred``
    is empty as well, otherwise 0.0.
    """
    if not gold:
        return 1.0 if not pred else 0.0

    longest = max(len(pred), len(gold))
    return 1.0 - Levenshtein.distance(pred, gold) / longest

# Lower- and upper-case ASCII vowels used by the vowel-related helpers
VOWELS = {"a", "e", "i", "o", "u", "A", "E", "I", "O", "U"}

def remove_vowels(s: str) -> str:
    """Return ``s`` with every character that is in ``VOWELS`` dropped."""
    kept = [ch for ch in s if ch not in VOWELS]
    return "".join(kept)

def reverse_string(s: str) -> str:
    """Return the characters of ``s`` in reverse order."""
    return "".join(reversed(s))

def vowel_removal_accuracy(pred: str) -> float:
    """
    Check that the prediction is entirely vowel-free.

    Returns 1.0 when no character of ``pred`` is in ``VOWELS`` (so an empty
    prediction scores 1.0), otherwise 0.0.
    """
    has_vowel = any(ch in VOWELS for ch in pred)
    return 0.0 if has_vowel else 1.0

def reverse_consistency(pred: str, gold: str) -> float:
    """
    Return 1.0 if ``pred`` equals ``gold`` reversed, otherwise 0.0.

    Mainly intended as a debugging aid for this task.
    """
    reversed_gold = gold[::-1]
    return 1.0 if pred == reversed_gold else 0.0

def evaluate_answer(pred: str, gold: str) -> dict:
    """
    Score a single prediction against its reference answer.

    Both strings are stripped of surrounding whitespace first. The returned
    dict holds, in this order: ``exact_match``, ``edit_similarity``,
    ``no_vowel`` (each computed by the helper of the matching purpose), plus
    the stripped lengths ``pred_len`` and ``gold_len``.
    """
    pred_clean = pred.strip()
    gold_clean = gold.strip()

    # Filled key by key so the insertion order matches the reported layout
    scores = {}
    scores["exact_match"] = exact_match(pred_clean, gold_clean)
    scores["edit_similarity"] = normalized_edit_similarity(pred_clean, gold_clean)
    scores["no_vowel"] = vowel_removal_accuracy(pred_clean)
    scores["pred_len"] = len(pred_clean)
    scores["gold_len"] = len(gold_clean)
    return scores


def evaluate_batch(preds, golds):
    """
    Average the per-example metrics over a batch.

    Args:
        preds: iterable of predicted strings.
        golds: iterable of reference strings, paired with ``preds`` by
            position. Pairs are formed with ``zip``, so surplus items in the
            longer input are ignored.

    Returns:
        dict mapping every metric name produced by ``evaluate_answer`` to its
        mean over the batch. An empty batch yields an empty dict.
    """
    all_metrics = [evaluate_answer(p, g) for p, g in zip(preds, golds)]

    # Nothing to average: return early instead of indexing all_metrics[0],
    # which would raise IndexError on an empty batch.
    if not all_metrics:
        return {}

    # Every per-example dict has the same keys, so the first one defines them
    count = len(all_metrics)
    return {k: sum(m[k] for m in all_metrics) / count for k in all_metrics[0]}


In [6]:
# Switch the model to evaluation mode (semicolon hides the repr in the notebook)
model.eval();

# Prepare the test dataset from the test.jsonl split
test_dataset = ProgramDataset(
    jsonl_path="./data_reverse_dropvowel/test.jsonl",
    block_size=1024,
    tokenizer=tokenizer,
)

# Draw one evaluation batch.
# NOTE: get_one_batch draws every index independently, so the same test
# example can show up more than once among the test_size samples.
test_size = 100
EOS_ID = 50256  # passed to generate() as eos_token_id; presumably GPT-2's end-of-text id — confirm against the tokenizer
x, y, prompts, answers = get_one_batch(test_dataset, batch_size=test_size)
gen = []

# Generation is inference only: disable autograd explicitly so no graph is
# recorded, whether or not generate() already does this internally.
with torch.no_grad():
    for prompt in prompts:
        # Encode one prompt and add a leading batch dimension of size 1
        prompt_idx = tokenizer(prompt)[0].long().unsqueeze(0).to(device)
        gen_ids = model.generate(
            prompt_idx,
            max_new_tokens=20,
            temperature=1.0,
            do_sample=False,
            top_k=None,
            eos_token_id=EOS_ID,
            return_only_generated=True,
        )
        gen_text = tokenizer.decode(gen_ids[0])
        gen.append(gen_text)

# Average the per-example metrics over the sampled test batch
metrics = evaluate_batch(gen, answers)
print(f"test评估指标: {metrics}")


test评估指标: {'exact_match': 0.0, 'edit_similarity': 0.12771342218753015, 'no_vowel': 0.02, 'pred_len': 37.05, 'gold_len': 9.25}
