### 这个notebook用来测试addition evaluation

In [None]:
from model import GPTConfig, GPT
import torch

# Which weights to load; only 'resume' (load from a local checkpoint) is handled below.
# `model` created here is used by the evaluation cells further down.
init_from = 'resume'


if init_from == 'resume':
    # init from a model saved in a specific directory
    ckpt_path = './out/out-addition-label/ckpt_acc.pt'
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    # map_location='cuda' places the loaded tensors on the GPU.
    checkpoint = torch.load(ckpt_path, map_location='cuda')
    # Rebuild the architecture from the hyper-parameters saved alongside the weights.
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    # Strip the '_orig_mod.' key prefix (presumably added by torch.compile when the
    # checkpoint was saved -- TODO confirm) so keys match the plain model's names.
    unwanted_prefix = '_orig_mod.'
    # Iterate over a snapshot (list(...)) because keys are popped/re-inserted inside the loop.
    for k,v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)

In [None]:
from tqdm import tqdm
import torch
import math
import numpy as np

def get_encode_decode(meta_path=None, tokenizer='char'):
    """Build ``encode``/``decode`` callables for the requested tokenizer.

    Args:
        meta_path: path to a ``meta.pkl`` holding ``stoi``/``itos`` vocabulary
            mappings. Only used when ``tokenizer == 'char'``.
        tokenizer: ``'char'`` for a character-level vocabulary loaded from
            ``meta_path``; any other non-empty string is treated as a tiktoken
            encoding name (e.g. ``'gpt2'``). If ``'char'`` is requested without a
            ``meta_path`` (or ``tokenizer`` is falsy), GPT-2 encodings are used.

    Returns:
        tuple: ``(encode, decode)`` where ``encode(str) -> list[int]`` and
        ``decode(list[int]) -> str``.
    """
    if meta_path and tokenizer == 'char':
        import pickle
        print(f"Loading meta from {meta_path}...")
        # NOTE: pickle.load can execute arbitrary code -- only load trusted meta files.
        with open(meta_path, 'rb') as f:
            meta = pickle.load(f)
        # TODO want to make this more general to arbitrary encoder/decoder schemes
        stoi, itos = meta['stoi'], meta['itos']
        encode = lambda s: [stoi[ch] for ch in s]
        decode = lambda ids: ''.join([itos[i] for i in ids])
        return encode, decode

    # tiktoken is only needed for the BPE paths; importing it lazily keeps the
    # char-level path usable when tiktoken is not installed.
    import tiktoken
    if tokenizer and tokenizer != 'char':
        print(f"Trying to load tiktoken's openAI {tokenizer} tokenizer")
        enc = tiktoken.get_encoding(f"{tokenizer}")
    else:
        # 'char' without a meta file (or no tokenizer at all): fall back to GPT-2.
        print("No meta.pkl found, assuming GPT-2 encodings...")
        enc = tiktoken.get_encoding("gpt2")
    encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
    decode = lambda ids: enc.decode(ids)
    return encode, decode

def get_abc(expression: str):
    """Parse an addition prompt such as ``"12+34="`` into operands and their sum.

    The last character after the ``'+'`` (expected to be the trailing ``'='``)
    is dropped from ``b`` before conversion.

    Args:
        expression: an expression of the form ``"<a>+<b>="``.

    Returns:
        tuple: ``(a, b, c, operation)`` -- ``a`` and ``b`` as strings,
        ``c = int(a) + int(b)`` as an int, and ``operation`` (always ``'+'``).

    Raises:
        ValueError: if the expression has no ``'+'``, more than one ``'+'``,
            or operands that are not integers.
    """
    operation = '+'
    if operation not in expression:
        # Reject up front so callers always see a ValueError for unsupported input.
        raise ValueError(f"Unsupported expression {expression!r}: only '+' is supported.")
    try:
        # Unpacking also raises ValueError when there is more than one '+'.
        a, b = expression.split(operation)
        b = b[:-1]  # drop the trailing '='
        c = int(a) + int(b)
    except ValueError as err:
        raise ValueError("Invalid input. 'a' and 'b' must be integers.") from err
    return a, b, c, operation


def get_num_digits(a: str):
    """Return how many digits the numeric string ``a`` has.

    - ``''`` has 0 digits.
    - A string containing ``'.'`` counts every character except one (the point).
    - Otherwise the count is ``len(str(int(a)))``, so leading zeros are ignored.
    """
    if a == '':
        return 0
    if '.' in a:
        # the decimal point itself is not a digit
        return len(a) - 1
    return len(str(int(a)))
        
        
def numCarryOps(a, b, binary=False):
    """Return the number of carries produced when adding ``a`` and ``b`` in base 10.

    Relies on the identity ``S(a) + S(b) - S(a + b) == 9 * carries``, where ``S``
    is the decimal digit sum. An empty ``b`` yields 0 carries; binary operands
    are not supported.
    """
    if b == '':
        return 0
    if binary:
        raise NotImplementedError
        #c = int(a,2) + int(b,2)
        #return int((digitSum(a) + digitSum(b) - digitSum(convert_to_binary(c))) )

    def digit_sum(value):
        return sum(int(ch) for ch in str(value))

    x, y = int(a), int(b)
    # assert(x >= 0); assert(y >= 0);
    excess = digit_sum(x) + digit_sum(y) - digit_sum(x + y)
    return int(excess / 9)
        
def is_number(s):
    """Return True if ``s`` parses as a float written in plain decimal notation.

    Strings containing an exponent marker (``'e'``/``'E'``, e.g. ``1.2e-3``) or
    ``'inf'``/``'INF'`` are rejected up front; that notation is not used in the dataset.
    """
    if any(marker in s for marker in ('e', 'E', 'inf', 'INF')):
        return False
    try:
        float(s)
    except ValueError:
        return False
    return True

In [None]:
def _read_prompt_lines(start):
    """Return the prompts named by ``start`` (``'FILE:<path>'``), one per non-blank line.

    Trailing whitespace (including the newline) is stripped from each line; blank
    lines are skipped because they cannot be parsed as an expression.

    Raises:
        NotImplementedError: if ``start`` does not begin with ``'FILE:'``.
    """
    if not start.startswith('FILE:'):
        raise NotImplementedError("This method is not implemented yet!")
    with open(start[5:], 'r', encoding='utf-8') as f:
        return [line.rstrip() for line in f if line.strip()]


def _parse_prediction(outcome, len_x, line_start):
    """Split a decoded generation into ``(answer_text, judgement)``.

    ``outcome`` is the decoded prompt + continuation; its first ``len_x`` characters
    are the prompt. For ``$``-delimited prompts the answer ends at the next ``$``.
    Otherwise it ends at the first newline, and a trailing ``'T'``/``'F'`` is split
    off as the model's self-judgement (``None`` when absent).
    """
    c_hat = outcome[len_x:]
    judgement = None
    if line_start == '$':  # handle $ prompt $
        c_hat = c_hat.split('$')[0]
    else:
        c_hat = c_hat.split('\n')[0]
        # the emptiness check guards against a completion that is just '\n'
        if c_hat and c_hat[-1] in ('T', 'F'):
            judgement = c_hat[-1]
            c_hat = c_hat[:-1]
    return c_hat.strip().split('\n')[0], judgement


def eval_addition_batch(config, model, ctx, encode, decode, judge = False, num_digit=3):
    """Evaluate ``model`` on addition prompts, generating in batches of equal prompt length.

    Each prompt (e.g. ``"123+456="``) is completed with ``model.generate``; the
    completion is parsed and compared with the true sum. Accuracy is reported
    overall and bucketed by the number of carries in the addition.

    Args:
        config: dict with ``'start'`` (``'FILE:<path>'``, one expression per line)
            and ``'device'``; optional ``'test_batch_size'`` (default 128),
            ``'max_new_tokens'`` (default ``num_digit + 4``), ``'temperature'``
            (default 0.2) and ``'top_k'`` (default 200).
        model: module exposing ``eval()``/``train()`` and
            ``generate(idx, max_new_tokens, temperature=..., top_k=...)``.
            Its train/eval mode is restored on exit, even on error.
        ctx: context manager wrapped around generation (e.g. autocast).
        encode, decode: tokenizer callables (str -> ids, ids -> str).
        judge: if True, also report how often the model's trailing ``'T'``/``'F'``
            self-judgement agrees with actual correctness.
        num_digit: carry buckets ``0..num_digit`` are reported.

    Returns:
        tuple: ``(accuracy, accuracy_dictionary)`` -- overall accuracy in percent
        (``nan`` if there are no prompts) and ``{'carry<i>': percent}`` for
        ``i`` in ``0..num_digit`` (``nan`` for empty buckets).
    """
    start = config['start']
    device = config['device']
    test_batch_size = config.get('test_batch_size', 128)
    max_new_tokens = config.get('max_new_tokens', num_digit + 4)
    temperature = config.get('temperature', 0.2)
    top_k = config.get('top_k', 200)

    print(f'evaluating addition from: {start}')
    lines = _read_prompt_lines(start)
    total = len(lines)

    correct = 0
    pred_correct = 0
    carry_dictionary = {f'carry{i}_correct': 0 for i in range(num_digit + 1)}
    carry_dictionary.update({f'carry{i}_total': 0 for i in range(num_digit + 1)})

    was_training = model.training
    model.eval()
    try:
        # Group prompts by token length so each batch stacks into one tensor without padding.
        prompt_dict = {}
        for line in tqdm(lines):
            start_ids = encode(line)
            # (len,) -> (1, len) so same-length prompts can be concatenated along dim 0
            x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]
            a, b, c, op = get_abc(line)
            # NOTE: len(line) equals the decoded prompt length only for char-level tokenization.
            prompt_dict.setdefault(len(start_ids), []).append(
                (x, len(line), line[0], c, op, numCarryOps(a, b)))

        batch_list = [
            group[i:i + test_batch_size]
            for group in prompt_dict.values()
            for i in range(0, len(group), test_batch_size)
        ]

        for batch in tqdm(batch_list):
            x = torch.cat([item[0] for item in batch], dim=0)
            with torch.no_grad():
                with ctx:
                    y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
            outcome_list = [decode(y_i.tolist()) for y_i in y]

            for (_, len_x, line_start, c, op, num_carry), outcome in zip(batch, outcome_list):
                if op not in ('+', '-', '*'):
                    raise NotImplementedError
                c_hat, pred = _parse_prediction(outcome, len_x, line_start)
                # Compare numerically when the prediction parses as a number, else as text.
                if is_number(c_hat):
                    c_hat = float(c_hat) if '.' in c_hat else int(c_hat)
                    target = c
                else:
                    target = str(c)

                # .get(...) tolerates carry counts above num_digit (operands longer than expected)
                correct_key = f'carry{num_carry}_correct'
                total_key = f'carry{num_carry}_total'
                if target == c_hat:
                    correct += 1
                    carry_dictionary[correct_key] = carry_dictionary.get(correct_key, 0) + 1
                    if pred == 'T':
                        pred_correct += 1
                elif pred == 'F':
                    pred_correct += 1
                carry_dictionary[total_key] = carry_dictionary.get(total_key, 0) + 1
    finally:
        model.train(was_training)

    if judge:
        pred_accuracy = pred_correct / total * 100 if total else np.nan
        print(f"Judgement accuracy of {total} examples: {pred_correct}/{total} ({pred_accuracy}%)")
    accuracy = correct / total * 100 if total else np.nan
    print(f"accuracy of {total} examples: {correct}/{total} ({accuracy}%)")
    accuracy_dictionary = {
        f'carry{i}': (carry_dictionary[f'carry{i}_correct'] / carry_dictionary[f'carry{i}_total'] * 100
                      if carry_dictionary[f'carry{i}_total'] != 0 else np.nan)
        for i in range(num_digit + 1)
    }
    print(accuracy_dictionary)
    return accuracy, accuracy_dictionary

In [None]:
# evaluation
# Settings read by eval_addition_batch / give_samples: prompt source and device.
config = dict(
    start='FILE:./data/addition_label/prompt_addition_label_labeled10000.txt',
    device='cuda',
)

In [None]:
# Character-level encode/decode built from the dataset's meta.pkl (stoi/itos vocab);
# presumably this must be the vocabulary the checkpoint was trained with -- TODO confirm.
encode, decode = get_encode_decode('./data/addition_label/meta.pkl')

In [None]:
from contextlib import nullcontext

# Inference precision / autocast setup, then run the batched evaluation.
dtype = 'bfloat16'
# Single source of truth for the device, shared with the evaluation config.
device = config['device']
# autocast needs the device *type* ('cuda'/'cpu'), not a device string like 'cuda:0'.
device_type = 'cuda' if 'cuda' in device else 'cpu'
model = model.to(device)
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]

# Mixed-precision autocast only on GPU; on CPU the model runs in its stored precision.
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
eval_addition_batch(config, model, ctx, encode, decode, judge=True)

In [11]:
def give_samples(config, model, encode, decode, device, num_digit=3, num_samples=100,
                 temperature=0.2):
    """Generate one completion for each of the first ``num_samples`` prompts.

    The model is put in eval mode (so dropout is off) and generation runs under
    ``torch.no_grad()``; the model's previous train/eval mode is restored afterwards.

    Args:
        config: dict whose ``'start'`` entry is ``'FILE:<path>'`` naming a prompt
            file with one expression per line.
        model: module exposing ``generate(idx, max_new_tokens=..., temperature=...)``.
        encode, decode: tokenizer callables (str -> ids, ids -> str).
        device: device on which prompt tensors are created.
        num_digit: ``num_digit + 2`` new tokens are generated per prompt.
        num_samples: maximum number of prompts to sample (fewer if the file is shorter).
        temperature: sampling temperature passed to ``model.generate``.

    Returns:
        list[str]: decoded prompt + completion, one string per sampled prompt.
    """
    start = config['start']
    if not start.startswith('FILE:'):
        raise NotImplementedError("This method is not implemented yet!")
    with open(start[5:], 'r', encoding='utf-8') as f:
        # strip trailing whitespace/newline; each element is one expression such as "2+2="
        lines = [line.rstrip() for line in f]

    was_training = model.training
    model.eval()
    output = []
    try:
        with torch.no_grad():
            for line in tqdm(lines[:num_samples]):
                # (len,) -> (1, len): generate expects a batch dimension
                x = torch.tensor(encode(line), dtype=torch.long, device=device)[None, ...]
                y = model.generate(x, max_new_tokens=num_digit + 2, temperature=temperature)
                output.append(decode(y[0].tolist()))
    finally:
        model.train(was_training)
    return output
#out = give_samples(config, model, encode, decode, 'cuda')

#outpath = 'samples.txt'
#with open(outpath, 'w') as f:
#    for instance in out:
#        exp = instance.strip()
#        f.write(exp+'\n')

In [1]:
from main_utils import *

In [6]:
from model import GPTConfig, GPT
import torch

# Which weights to load; only 'resume' (load from a local checkpoint) is handled below.
# This re-creates `model` from the V2 checkpoint for the cells below.
init_from = 'resume'


if init_from == 'resume':
    # init from a model saved in a specific directory
    ckpt_path = './out/out-addition-labelV2/ckpt_acc.pt'
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    # map_location='cuda' places the loaded tensors on the GPU.
    checkpoint = torch.load(ckpt_path, map_location='cuda')
    # Rebuild the architecture from the hyper-parameters saved alongside the weights.
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    # Strip the '_orig_mod.' key prefix (presumably added by torch.compile when the
    # checkpoint was saved -- TODO confirm) so keys match the plain model's names.
    unwanted_prefix = '_orig_mod.'
    # Iterate over a snapshot (list(...)) because keys are popped/re-inserted inside the loop.
    for k,v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)

number of parameters: 10.63M


In [7]:
# evaluation
# Settings read by eval_addition_batch / give_samples for the V2 dataset.
config = dict(
    start='FILE:./data/addition_labelV2/prompt_3digit_V210000.txt',
    device='cuda',
)

In [8]:
# Character-level encode/decode built from the V2 dataset's meta.pkl (stoi/itos vocab);
# presumably this must be the vocabulary the V2 checkpoint was trained with -- TODO confirm.
encode, decode = get_encode_decode('./data/addition_labelV2/meta.pkl')

Loading meta from ./data/addition_labelV2/meta.pkl...


In [9]:
from contextlib import nullcontext

# Inference precision / autocast setup, then run the batched evaluation.
dtype = 'bfloat16'
# Single source of truth for the device, shared with the evaluation config.
device = config['device']
# autocast needs the device *type* ('cuda'/'cpu'), not a device string like 'cuda:0'.
device_type = 'cuda' if 'cuda' in device else 'cpu'
model = model.to(device)
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]

# Mixed-precision autocast only on GPU; on CPU the model runs in its stored precision.
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
eval_addition_batch(config, model, ctx, encode, decode, judge=True)

evaluating addition from: FILE:./data/addition_labelV2/prompt_3digit_V210000.txt


  0%|          | 0/10000 [00:00<?, ?it/s]

100%|██████████| 10000/10000 [00:07<00:00, 1251.19it/s]
100%|██████████| 80/80 [00:04<00:00, 19.76it/s]

Judgement accuracy of 10000 examples: 9429/10000 (94.28999999999999%)
accuracy of 10000 examples: 9429/10000 (94.28999999999999%)
{'carry0': 94.25219941348973, 'carry1': 92.57330775554946, 'carry2': 95.13888888888889, 'carry3': 96.92653673163419}





(94.28999999999999,
 94.28999999999999,
 {'carry0': 94.25219941348973,
  'carry1': 92.57330775554946,
  'carry2': 95.13888888888889,
  'carry3': 96.92653673163419})

In [12]:
out = give_samples(config, model, encode, decode, config['device'])

# Keep only the first line of each sample (prompt + answer). Write as UTF-8 to match
# how the prompt files are read elsewhere in this notebook.
outpath = 'samples.txt'
with open(outpath, 'w', encoding='utf-8') as f:
    f.writelines(instance.strip().split('\n')[0] + '\n' for instance in out)

100%|██████████| 100/100 [00:04<00:00, 20.25it/s]
