In [1]:
import os

# Route Hugging Face Hub traffic through the hf-mirror.com mirror.
# NOTE(review): this must run before transformers / huggingface_hub are imported
# for the setting to be picked up — keep this cell first.
os.environ.update({"HF_ENDPOINT": "https://hf-mirror.com"})

In [2]:
from calculator import get_examples, is_correct
from calculator import sample

In [3]:
import torch
from transformers import GPT2LMHeadModel, AutoTokenizer

# Pick the compute device once: CUDA when torch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

def load_best_model(checkpoint_path, model_name="gpt2"):
    """Build a GPT-2 LM-head model and overwrite its weights from a checkpoint.

    Args:
        checkpoint_path: Path to a file readable by ``torch.load``; the loaded
            object is passed straight to ``model.load_state_dict``.
        model_name: Identifier handed to ``GPT2LMHeadModel.from_pretrained`` to
            construct the model before the checkpoint weights are applied.

    Returns:
        The model, moved to the module-level ``device`` and put in eval mode.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
    """
    # Fail fast: validate the checkpoint path before the (potentially slow,
    # network-bound) from_pretrained call, so a typo in the path is reported
    # immediately instead of after the base model has been fetched and built.
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"模型文件未找到：{checkpoint_path}")

    model = GPT2LMHeadModel.from_pretrained(model_name)

    # Load onto CPU first; the model is moved to the target device afterwards.
    # NOTE(review): depending on the installed torch version, torch.load may
    # unpickle arbitrary objects — only load checkpoints from a trusted source.
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    print(f"✅ 模型已加载并移动到 {device}：{checkpoint_path}")
    return model
    
# Restore the teacher weights from the saved checkpoint, then load the GPT-2 tokenizer.
model = load_best_model("teacher_checkpoints/0823_teacher_best_model.pt")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Reuse the end-of-sequence token as the padding token, and mirror that choice
# in the model config so the tokenizer and the model agree on the pad id.
model.config.pad_token_id = tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token



✅ 模型已加载并移动到 cuda：teacher_checkpoints/0823_teacher_best_model.pt


In [4]:
# Smoke test: sample a completion for the first test question and print it
# so the raw generation can be inspected by eye.
sample_len = 100  # forwarded to sample() as its last argument
test_examples = get_examples("gsm8k_aug_test")

qn = test_examples[0]["question"]
print(qn.strip())

return_text = sample(model, qn, tokenizer, device, sample_len)
print(return_text)

1319 gsm8k_aug_test examples
Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?
Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?||<<3+4=7>> <<16-7=9>> <<9*2=18>>####18<|endoftext|><<18*2=36>>####36<|endoftext|><<36<|endoftext|><<36+3=39>>####39<|endoftext|><<39*2=78>>####78<|endoftext|><<78+3=81>>####81<|endoftext|><<81*2=162>>####162<|endoftext|>####162<|endoftext|><<162+3=165>>####165<|endoftext|>####165<|endoftext|>####165<|endoftext|>####165<|endoftext|>####165<|endoftext|><<165+3=168>>####168<|endoftext|>####168<|endoftext|>###

In [5]:
# Show the reference answer stored with the first test example, for comparison
# with the generation printed in the previous cell.
test_examples[0]["answer"]

'18'

In [6]:
# Score the sampled text for the first example with the project's is_correct helper.
is_correct(return_text, test_examples[0])

True

In [7]:
# Evaluate the model on every test example: sample a completion per question,
# score it with is_correct, and report the fraction judged correct.
sample_len = 100
correct_num = 0
total_num = len(test_examples)
for example in test_examples:
    qn = example["question"]
    return_text = sample(model, qn, tokenizer, device, sample_len)
    label = is_correct(return_text, example)
    if label:
        correct_num += 1
        # Only generations judged correct are echoed.
        print(return_text)

# Guard the ratio so an empty test set reports 0.0 instead of raising ZeroDivisionError.
accuracy = correct_num / total_num if total_num else 0.0
print(f"Accuracy: {correct_num}/{total_num}: {accuracy:.4f}")

Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?||<<3+4=7>> <<16-7=9>> <<9*2=18>>####18<|endoftext|><<18*2=36>>####36<|endoftext|><<36<|endoftext|><<36+3=39>>####39<|endoftext|><<39*2=78>>####78<|endoftext|><<78+3=81>>####81<|endoftext|><<81*2=162>>####162<|endoftext|>####162<|endoftext|><<162+3=165>>####165<|endoftext|>####165<|endoftext|>####165<|endoftext|>####165<|endoftext|>####165<|endoftext|><<165+3=168>>####168<|endoftext|>####168<|endoftext|>####168<|endoftext|>####
A robe takes 2 bolts of blue fiber and half that much white fiber.  How many bolts in total does it take?||<<2/2=1.0>> <<2+1.0=3.0>>####3<|endoftext|><<3<|endoftext|><<2*3=6>>####6<|endoftext|><<6<|endoftext|><<6+2=8>>####8<|endoftext|><<8<|endoftext|><<8+2=10>>####10<|endoftext|><<10+2=