In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
import numpy as np
import datasets

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [3]:
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

In [4]:
class Holder:
    def __init__(self):
        pass
args = Holder()
# args.use_cot = False
args.num_mem_tokens = None

In [5]:

id_pad_value = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
think = ans = tokenizer.bos_token_id
eos = tokenizer.eos_token_id

from torch.nn.utils.rnn import pad_sequence
def collate_fn(batch):
    input_ids, input_ids_generate, labels, labels_mask, attention_mask = [], [], [], [], []
    for sample in batch:
        task, lab, cot = sample['task'], sample['labels'], sample['cot']
        task_tokens = tokenizer.encode(task, add_special_tokens=False)
        labels_tokens = tokenizer.encode(lab, add_special_tokens=False)
        cot_tokens = tokenizer.encode(cot, add_special_tokens=False)

        if args.use_cot:
            full_input = task_tokens + [think] + cot_tokens + [ans] + labels_tokens + [eos]
            gen_input = task_tokens + [think]
        else:
            full_input = task_tokens + [ans] + labels_tokens + [eos]
            gen_input = task_tokens + [ans]
        
        inp_ids = torch.tensor(full_input)
        input_ids.append(inp_ids)
        input_ids_generate.append(torch.tensor(gen_input))


        lab = torch.tensor(full_input)
        lab[:len(task_tokens)] = -100
        labels.append(lab)

        lab_mask = torch.ones_like(inp_ids)
        lab_mask[:len(task_tokens)] = 0
        labels_mask.append(lab_mask)
        attention_mask.append(torch.ones_like(inp_ids))

    input_ids = pad_sequence(input_ids, padding_value=id_pad_value, batch_first=True)
    # input_ids_generate = pad_sequence(input_ids_generate, padding_value=id_pad_value, batch_first=True)
    attention_mask = pad_sequence(attention_mask, padding_value=0, batch_first=True)
    labels = pad_sequence(labels, padding_value=id_pad_value, batch_first=True)
    labels_mask = pad_sequence(labels_mask, padding_value=0, batch_first=True)

    collated = {'input_ids': input_ids,
                'input_ids_generate': input_ids_generate,
                'labels': labels,
                'attention_mask': attention_mask,
                }
    if args.num_mem_tokens is not None:
        # add labels mask only for RMT, ARMT
        collated['labels_mask'] = labels_mask.bool()
    return collated
    
think_text = tokenizer.decode(think)
ans_text = tokenizer.decode(ans)

def extract_cot(text):
        try:
                return text[:text.index(ans_text)]
        except ValueError:
                return ''

def extract_answer(text):
        try:
                return text.split(ans_text)[1]
        except IndexError:
                return ''
                
def compute_accuracy(eval_pred):
        preds = eval_pred.predictions.argmax(axis=-1)[:, 1:-1]
        labels = eval_pred.label_ids[:, 2:]

        labels_masks = labels > 0
        preds_full = [p[m] for p, m in zip(preds, labels_masks)]
        labels_full = [lab[m] for lab, m in zip(labels, labels_masks)]

        print(len(preds_full), len(labels_full))

        preds_full_text = tokenizer.batch_decode(preds_full, add_special_tokens=True)
        labels_full_text = tokenizer.batch_decode(labels_full, add_special_tokens=True)

        preds_cot = [extract_cot(p) for p in preds_full_text]
        preds_ans = [extract_answer(p) for p in preds_full_text]

        labels_cot = [extract_cot(lab) for lab in labels_full_text]
        labels_ans = [extract_answer(lab) for lab in labels_full_text]

        acc_cot = np.mean([c == p for c, p in zip(preds_cot, labels_cot)])
        acc_ans = np.mean([c == lab for c, lab in zip(preds_ans, labels_ans)])

        return {'accuracy_cot': acc_cot, 'accuracy_ans': acc_ans}

def evaluate_model_on_dataset(checkpoint_path, valid_dataset, device='cuda'):
    model.load_state_dict(torch.load(checkpoint_path), strict=False)
    model.generation_config.pad_token_id = tokenizer.pad_token_id
    model.to(device)
        
    collated = collate_fn([sample for sample in valid_dataset])
    model.generation_config.pad_token_id = tokenizer.pad_token_id
    gen_outputs = [model.generate(inp.reshape(1, -1).to(device), 
                                pad_token_id=tokenizer.eos_token_id,
                                attention_mask=torch.ones_like(inp.reshape(1, -1)).to(device),
                                max_new_tokens=50)[0] for inp in collated['input_ids_generate']]

    if args.use_cot:
        gen_outputs = [model.generate(inp.reshape(1, -1).to(device), 
                                        pad_token_id=tokenizer.eos_token_id,
                                        attention_mask=torch.ones_like(inp.reshape(1, -1)).to(device))[0].cpu() for inp in gen_outputs]

    labels = collated['labels']
    labels_masks = labels > 0

    preds_full = [out[len(inp):] for inp, out in zip(collated['input_ids_generate'], gen_outputs)]
    # preds_full = [out[len(inp):] for inp, out in zip(collated['input_ids_generate'], gen_outputs_m2)]
    labels_full = [lab[m][1:] for lab, m in zip(labels, labels_masks)]

    print(len(preds_full), len(labels_full))

    preds_full_text = tokenizer.batch_decode(preds_full, add_special_tokens=True)
    labels_full_text = tokenizer.batch_decode(labels_full, add_special_tokens=True)

    preds_cot = [extract_cot(p) for p in preds_full_text]
    preds_ans = [extract_answer(p) for p in preds_full_text]

    labels_cot = [extract_cot(lab) for lab in labels_full_text]
    labels_ans = [extract_answer(lab) for lab in labels_full_text]

    acc_cot = np.mean([c == p for c, p in zip(preds_cot, labels_cot)])
    acc_ans = np.mean([c == lab for c, lab in zip(preds_ans, labels_ans)])

    data = {"inputs": collated['input_ids_generate'],
            "preds_full_text": preds_full_text, "labels_full_text": labels_full_text,
            "preds_cot": preds_cot, "labels_cot": labels_cot,
            "preds_ans": preds_ans, "labels_ans": labels_ans}

    print(f"Accuracy COT: {acc_cot}")
    print(f"Accuracy Answer: {acc_ans}")
    return (acc_cot, acc_ans, data)

In [6]:
import pandas as pd
res_df = pd.DataFrame(columns=['cpt_path', 'cot', 'acc_cot', 'acc_ans'])

In [13]:
dataset = 'booydar/multiplication_4x4'
valid_dataset = datasets.load_dataset(dataset, split='valid')


KeyboardInterrupt: 

In [22]:
# dataset = 'booydar/multiplication_4x4'
# valid_dataset = datasets.load_dataset(dataset, split='valid')

# # cot
# args.use_cot = True

# checkpoints = [
#     "/workspace-SR006.nfs2/Bulatov_A/rmt/runs/4_by_4_mult/gpt2/SEGM_1x1024_1024_LR3e-04-cot/checkpoint-4500/pytorch_model.bin",
# ]

# for cpt_path in checkpoints:
#     print()
#     print(cpt_path)
#     acc_cot, acc_ans,data = evaluate_model_on_dataset(cpt_path, valid_dataset)
#     res_df.loc[len(res_df)] = [cpt_path.split('/')[-3], args.use_cot, acc_cot, acc_ans]

In [14]:
args.use_cot = True
batch = [s for s in valid_dataset]
collated = collate_fn(batch)

In [16]:
'5 6 3 2 * 7 4 3 4<|endoftext|>5 5 5 6 1 + 0 0 6 4 9 0 ( 5 5 1 1 1 1 ) + 0 0 5 9 0 7 0 ( 5 5 6 0 2 8 0 ) + 0 0 0 0 6 4 9 0<|endoftext|>5 5 6 0 8 2 0 1<|endoftext|'[::-1]

'|txetfodne|<1 0 2 8 0 6 5 5>|txetfodne|<0 9 4 6 0 0 0 0 + ) 0 8 2 0 6 5 5 ( 0 7 0 9 5 0 0 + ) 1 1 1 1 5 5 ( 0 9 4 6 0 0 + 1 6 5 5 5>|txetfodne|<4 3 4 7 * 2 3 6 5'

In [None]:
tokenizer.batch_decode(collated['input_ids'])

['5 6 3 2 * 7 4 3 4<|endoftext|>5 5 5 6 1 + 0 0 6 4 9 0 ( 5 5 1 1 1 1 ) + 0 0 5 9 0 7 0 ( 5 5 6 0 2 8 0 ) + 0 0 0 0 6 4 9 0<|endoftext|>5 5 6 0 8 2 0 1<|endoftext|>',
 '6 9 1 5 * 6 4 4 7<|endoftext|>6 7 1 1 3 + 0 4 8 7 0 2 ( 6 1 0 9 3 2 ) + 0 0 4 8 7 0 2 ( 6 1 4 7 1 3 2 ) + 0 0 0 2 7 3 6 3<|endoftext|>6 1 4 9 8 6 8 3<|endoftext|>',
 '6 7 3 9 * 8 9 1 7<|endoftext|>8 0 0 5 7 + 0 4 8 3 4 8 ( 8 4 8 8 1 9 ) + 0 0 6 7 3 9 0 ( 8 4 4 6 5 8 1 ) + 0 0 0 2 3 6 5 6<|endoftext|>8 4 4 8 8 4 7 6<|endoftext|>',
 '3 0 3 4 * 3 4 6 5<|endoftext|>9 0 9 2 1 + 0 2 1 2 7 1 ( 9 2 0 5 8 1 ) + 0 0 8 1 8 5 2 ( 9 2 8 6 6 7 2 ) + 0 0 0 5 1 5 1 2<|endoftext|>9 2 8 1 8 2 4 2<|endoftext|>',
 '0 3 3 7 * 8 5 6 5<|endoftext|>0 4 6 8 5 + 0 0 5 6 6 3 ( 0 4 1 5 2 4 ) + 0 0 0 8 9 3 4 ( 0 4 1 3 2 8 4 ) + 0 0 0 0 5 6 6 3<|endoftext|>0 4 1 3 7 4 1 4<|endoftext|>',
 '3 6 0 6 * 4 3 8 7<|endoftext|>2 5 2 4 2 + 0 9 8 1 8 1 ( 2 4 1 6 0 2 ) + 0 0 4 0 5 8 4 ( 2 4 5 6 5 0 5 ) + 0 0 0 1 4 4 2 4<|endoftext|>2 4 5 7 9 4 7 4<|endoftext|>'

In [None]:
tokenizer.batch_decode(collated['input_ids_generate'])

['5 6 3 2 * 7 4 3 4<|endoftext|>',
 '6 9 1 5 * 6 4 4 7<|endoftext|>',
 '6 7 3 9 * 8 9 1 7<|endoftext|>',
 '3 0 3 4 * 3 4 6 5<|endoftext|>',
 '0 3 3 7 * 8 5 6 5<|endoftext|>',
 '3 6 0 6 * 4 3 8 7<|endoftext|>',
 '4 7 8 4 * 2 9 1 6<|endoftext|>',
 '2 9 6 1 * 0 5 1 9<|endoftext|>',
 '1 1 5 4 * 5 6 0 5<|endoftext|>',
 '3 9 9 3 * 0 9 3 3<|endoftext|>',
 '5 6 8 3 * 1 4 9 7<|endoftext|>',
 '3 7 7 6 * 8 5 5 2<|endoftext|>',
 '1 6 4 2 * 4 9 8 2<|endoftext|>',
 '6 7 2 6 * 6 7 2 1<|endoftext|>',
 '0 6 6 5 * 2 5 6 2<|endoftext|>',
 '8 0 1 9 * 2 6 2 9<|endoftext|>',
 '8 2 5 7 * 2 1 3 8<|endoftext|>',
 '2 7 4 2 * 7 1 4 9<|endoftext|>',
 '2 1 2 1 * 7 3 1 2<|endoftext|>',
 '7 5 5 5 * 8 4 9 4<|endoftext|>',
 '0 3 8 7 * 8 2 5 9<|endoftext|>',
 '8 1 7 8 * 4 1 3 7<|endoftext|>',
 '8 1 0 2 * 0 1 7 9<|endoftext|>',
 '1 0 3 3 * 8 1 6 2<|endoftext|>',
 '3 7 3 7 * 5 9 1 1<|endoftext|>',
 '4 9 5 7 * 0 6 3 6<|endoftext|>',
 '7 0 1 9 * 4 8 4 1<|endoftext|>',
 '9 9 4 5 * 8 3 0 9<|endoftext|>',
 '2 5 0 7 * 3 0 0 1<

In [23]:
args.use_cot = False

checkpoints = [
    # "/workspace-SR006.nfs2/Bulatov_A/rmt/runs/4_by_4_mult/gpt2/SEGM_1x1024_1024_LR3e-04/checkpoint-25000/pytorch_model.bin",
    # "/workspace-SR006.nfs2/Bulatov_A/rmt/runs/4_by_4_mult/gpt2/SEGM_1x1024_1024_LR1e-05/checkpoint-24500/pytorch_model.bin",
    "/workspace-SR006.nfs2/bulatov/rmt/runs/multiplication_4x4/gpt2/SEGM_1x1024_1024_LR6e-04/checkpoint-249000/pytorch_model.bin"

]

for cpt_path in checkpoints:
    print(cpt_path)
    acc_cot, acc_ans, data = evaluate_model_on_dataset(cpt_path, valid_dataset)
    res_df.loc[len(res_df)] = [cpt_path.split('/')[-3], args.use_cot, 0, acc_cot]

/workspace-SR006.nfs2/bulatov/rmt/runs/multiplication_4x4/gpt2/SEGM_1x1024_1024_LR6e-04/checkpoint-249000/pytorch_model.bin
1000 1000
Accuracy COT: 0.976
Accuracy Answer: 0.0


In [18]:
minibatch = [valid_dataset[i] for i in range(10)]
collated_minibatch = collate_fn(minibatch)
out = model(**collated_minibatch)

`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.


In [21]:
out.logits.shape, out.logits.dtype

(torch.Size([10, 66, 50257]), torch.float32)

In [38]:
torch.save(out.logits[:, :, :1000].cpu().detach().numpy(), 'tmp')

In [39]:
!du -h tmp

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


4.3M	tmp


In [None]:
torch.load('tmp').shape, torch.load('tmp').dtype

torch.Size([10, 66, 1000])

torch.float32

In [33]:
out.logits[:, :, :1000].shape
one_num_weight_mb = 128 / (10 * 66 * 1000)

In [25]:

dataset_size = 100_000
sample_size = 4_000
top_k = 100

num_total = dataset_size * sample_size * top_k
num_total / 1_000_000_000


40.0

In [34]:
num_total * one_num_weight_mb / 1024

7575.757575757576

In [24]:
preds_cot = data['preds_cot']
preds_ans = data['preds_ans']
labels_cot = data['labels_cot']
labels_ans = data['labels_ans']
preds_full_text = data['preds_full_text']
labels_full_text = data['labels_full_text']

In [25]:
preds_full_text[:10]

['5 5 6 0 8 2 0 1<|endoftext|>5 1 1<|endoftext|>',
 '6 1 4 9 8 6 8 3<|endoftext|>6 3 8 4 0 3 0 3 0 3 0 3 0 3 0 3 0 3 9 2',
 '8 4 4 8 8 4 7 6<|endoftext|>8 6 3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1',
 '9 2 8 1 8 2 4 2<|endoftext|>9 2 1 1 2 1<|endoftext|>',
 '0 4 1 3 7 4 1 4<|endoftext|>0 1 4 0 4 0 3 2 8 3 2 2 2 2 2 2 2 2 2 2',
 '2 4 5 7 9 4 7 4<|endoftext|>2 5 3 5 1 3 5 9 6 7 7 8 4 2 2 4 8 7 7 7',
 '8 0 8 9 7 1 0 3<|endoftext|>8 0 2 3 2 2 2 2 1 3 2 1 5 1 3 1 5 1 5 1',
 '0 0 8 1 8 4 5 1<|endoftext|>0 5 4 2 8 5 4 5 2 2 2 2 2 2 2 2 2 2 2 2',
 '5 1 2 8 4 8 2 2<|endoftext|>5 1 1 2 1 2 1 2 1 2 1<|endoftext|>',
 '0 7 2 6 3 5 3 1<|endoftext|>0 5 4 2 4 2 8 2 6 0 5 2 8 1 5 2 8 1 1 2']

In [27]:
labels_full_text[:10]

['5 5 6 0 8 2 0 1<|endoftext|>',
 '6 1 4 9 8 6 8 3<|endoftext|>',
 '8 4 4 8 8 4 7 6<|endoftext|>',
 '9 2 8 1 8 2 4 2<|endoftext|>',
 '0 4 1 3 7 4 1 4<|endoftext|>',
 '2 4 5 7 9 4 7 4<|endoftext|>',
 '8 0 8 9 7 1 0 3<|endoftext|>',
 '0 0 8 1 8 4 5 1<|endoftext|>',
 '5 1 2 8 4 8 2 2<|endoftext|>',
 '0 7 2 6 3 5 3 1<|endoftext|>']

In [29]:
labels_cot[:10]

['5 5 6 0 8 2 0 1',
 '6 1 4 9 8 6 8 3',
 '8 4 4 8 8 4 7 6',
 '9 2 8 1 8 2 4 2',
 '0 4 1 3 7 4 1 4',
 '2 4 5 7 9 4 7 4',
 '8 0 8 9 7 1 0 3',
 '0 0 8 1 8 4 5 1',
 '5 1 2 8 4 8 2 2',
 '0 7 2 6 3 5 3 1']

In [31]:
[(p, t) for p, t in zip(preds_cot, labels_cot) if p != t]

[('2 2 5 0 8 1 5 2', '2 2 5 0 7 1 5 2'),
 ('8 6 5 1 1 7 3 3', '8 6 5 1 0 7 3 3'),
 ('5 6 6 2 1 3 1 8', '5 6 6 2 6 3 1 8'),
 ('1 8 0 3 5 8 8 6', '1 8 0 3 6 8 8 6'),
 ('0 4 5 8 2 1 6 5', '0 4 5 8 1 1 6 5'),
 ('2 5 2 4 6 1 4 8', '2 5 2 4 8 1 4 8'),
 ('6 5 6 9 0 5 7 2', '6 5 6 8 0 5 7 2'),
 ('5 1 4 0 9 2 7 3', '5 1 4 8 8 2 7 3'),
 ('8 4 1 3 5 3 8 6', '8 4 1 3 4 3 8 6'),
 ('2 5 6 4 2 6 2 2', '2 5 6 4 3 6 2 2'),
 ('3 5 3 5 1 9 7 1', '3 5 3 6 1 9 7 1'),
 ('2 4 0 7 4 2 6 1', '2 4 0 7 2 2 6 1'),
 ('6 2 2 5 2 5 7 1', '6 2 2 6 2 5 7 1'),
 ('6 2 9 7 7 7 5 9', '6 2 9 7 6 7 5 9'),
 ('3 3 0 9 8 3 4 4', '3 3 0 9 7 3 4 4'),
 ('2 7 1 1 6 8 6 6', '2 7 1 1 5 8 6 6'),
 ('4 4 1 4 3 9 2 0', '4 4 1 5 3 9 2 0'),
 ('9 1 8 3 2 0 3 6', '9 1 8 3 1 0 3 6'),
 ('8 3 7 4 1 2 5 0', '8 3 7 3 1 2 5 0'),
 ('2 0 2 8 7 3 4 2', '2 0 2 0 8 3 4 2'),
 ('4 2 8 0 0 1 3 5', '4 2 8 9 9 0 3 5'),
 ('2 9 7 2 2 5 6 8', '2 9 7 3 2 5 6 8'),
 ('8 8 6 7 8 1 3 2', '8 8 6 7 7 1 3 2'),
 ('7 1 2 2 5 8 9 7', '7 1 2 2 4 8 9 7')]

In [None]:
preds_cot[:10]

['5 5 6 0 8 2 0 1',
 '6 1 4 9 8 6 8 3',
 '8 4 4 8 8 4 7 6',
 '9 2 8 1 8 2 4 2',
 '0 4 1 3 7 4 1 4',
 '2 4 5 7 9 4 7 4',
 '8 0 8 9 7 1 0 3',
 '0 0 8 1 8 4 5 1',
 '5 1 2 8 4 8 2 2',
 '0 7 2 6 3 5 3 1']

In [16]:
labels_ans[:10]

['', '', '', '', '', '', '', '', '', '']

In [17]:
preds_ans[:10]

['5 1 1',
 '6 3 8 4 0 3 0 3 0 3 0 3 0 3 0 3 0 3 9 2',
 '8 6 3 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1',
 '9 2 1 1 2 1',
 '0 1 4 0 4 0 3 2 8 3 2 2 2 2 2 2 2 2 2 2',
 '2 5 3 5 1 3 5 9 6 7 7 8 4 2 2 4 8 7 7 7',
 '8 0 2 3 2 2 2 2 1 3 2 1 5 1 3 1 5 1 5 1',
 '0 5 4 2 8 5 4 5 2 2 2 2 2 2 2 2 2 2 2 2',
 '5 1 1 2 1 2 1 2 1 2 1',
 '0 5 4 2 4 2 8 2 6 0 5 2 8 1 5 2 8 1 1 2']

In [12]:
res_df

Unnamed: 0,cpt_path,cot,acc_cot,acc_ans
0,SEGM_1x1024_1024_LR3e-04-cot,True,1.0,1.0
1,SEGM_1x1024_1024_LR3e-04,False,0.0,0.001
2,SEGM_1x1024_1024_LR1e-05,False,0.0,0.0


### Explicit

In [6]:
# cpt_path = "/workspace-SR006.nfs2/Bulatov_A/rmt/runs/test/4_by_4_mult/Llama-3.2-1B-Instruct/smol:qa1-5-1:9/SEGM_1x1024_1024_64_LR3e-04-lora-mnc-distill__short/checkpoint-5000/pytorch_model.bin"
# cpt_path = "/workspace-SR006.nfs2/Bulatov_A/rmt/runs/test/4_by_4_mult/gpt2/SEGM_1x1024_1024_LR3e-04/checkpoint-25000/pytorch_model.bin"
# cpt_path = "/workspace-SR006.nfs2/Bulatov_A/rmt/runs/gsm8k/gpt2/SEGM_1x1024_1024_LR3e-04/checkpoint-16500/pytorch_model.bin"
# cpt_path = "/workspace-SR006.nfs2/Bulatov_A/rmt/runs/4_by_4_mult/gpt2/SEGM_1x1024_1024_LR3e-04/checkpoint-25000/pytorch_model.bin"
# cpt_path = "/workspace-SR006.nfs2/Bulatov_A/rmt/runs/4_by_4_mult/gpt2/SEGM_1x1024_1024_LR3e-04-cot/checkpoint-4500/pytorch_model.bin"
cpt_path = '/workspace-SR006.nfs2/bulatov/rmt/runs/4_by_4_mult/gpt2/SEGM_1x1024_1024_LR3e-04-continue/checkpoint-182500/pytorch_model.bin'
model.load_state_dict(torch.load(cpt_path), strict=False)

<All keys matched successfully>

In [7]:
# dataset_dir = "/workspace-SR006.nfs2/Bulatov_A/rmt/data/implicit_chain_of_thought/4_by_4_mult"

# train_path = os.path.join(dataset_dir, "train")
# valid_path = os.path.join(dataset_dir, "valid")
# train_dataset = datasets.load_from_disk(train_path)
# valid_dataset = datasets.load_from_disk(valid_path)

dataset = 'booydar/multiplication_4x4'
train_dataset = datasets.load_dataset(dataset, split='train')
valid_dataset = datasets.load_dataset(dataset, split='valid')

In [9]:
device = 'cuda:0'
args.use_cot = False

model.generation_config.pad_token_id = tokenizer.pad_token_id
model.to(device)
    
collated = collate_fn([sample for sample in valid_dataset])
model.generation_config.pad_token_id = tokenizer.pad_token_id
gen_outputs = [model.generate(inp.reshape(1, -1).to(device), 
                            pad_token_id=tokenizer.eos_token_id,
                            attention_mask=torch.ones_like(inp.reshape(1, -1)).to(device),
                            max_new_tokens=50)[0] for inp in collated['input_ids_generate']]

gen_outputs_m2 = [model.generate(inp.reshape(1, -1).to(device), 
                                    pad_token_id=tokenizer.eos_token_id,
                                    attention_mask=torch.ones_like(inp.reshape(1, -1)).to(device))[0].cpu() for inp in gen_outputs]

labels = collated['labels']
labels_masks = labels > 0

preds_full = [out[len(inp):] for inp, out in zip(collated['input_ids_generate'], gen_outputs_m2)]
labels_full = [lab[m][1:] for lab, m in zip(labels, labels_masks)]

print(len(preds_full), len(labels_full))

preds_full_text = tokenizer.batch_decode(preds_full, add_special_tokens=True)
labels_full_text = tokenizer.batch_decode(labels_full, add_special_tokens=True)

preds_cot = [extract_cot(p) for p in preds_full_text]
preds_ans = [extract_answer(p) for p in preds_full_text]

labels_cot = [extract_cot(lab) for lab in labels_full_text]
labels_ans = [extract_answer(lab) for lab in labels_full_text]

acc_cot = np.mean([c == p for c, p in zip(preds_cot, labels_cot)])
acc_ans = np.mean([c == lab for c, lab in zip(preds_ans, labels_ans)])

print(f"Accuracy COT: {acc_cot}")
print(f"Accuracy Answer: {acc_ans}")

1000 1000
Accuracy COT: 0.053
Accuracy Answer: 0.0


In [16]:
preds_cot[:10]

['5 5 6 2 2 3 0 1',
 '6 1 4 2 2 7 8 3',
 '8 4 4 2 2 4 7 6',
 '9 2 8 2 2 2 4 2',
 '0 4 1 3 2 4 1 4',
 '2 4 5 2 2 5 7 4',
 '8 0 8 2 2 2 0 3',
 '0 0 8 1 2 4 5 1',
 '5 1 2 2 2 8 2 2',
 '0 7 2 6 2 5 3 1']

In [17]:
labels_cot[:10]

['5 5 6 0 8 2 0 1',
 '6 1 4 9 8 6 8 3',
 '8 4 4 8 8 4 7 6',
 '9 2 8 1 8 2 4 2',
 '0 4 1 3 7 4 1 4',
 '2 4 5 7 9 4 7 4',
 '8 0 8 9 7 1 0 3',
 '0 0 8 1 8 4 5 1',
 '5 1 2 8 4 8 2 2',
 '0 7 2 6 3 5 3 1']

In [8]:
class Holder:
    def __init__(self):
        pass
args = Holder()
# args.use_cot = False
args.use_cot = True
args.num_mem_tokens = None

In [9]:

id_pad_value = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
think = ans = tokenizer.bos_token_id
eos = tokenizer.eos_token_id

def collate_fn(batch):
    input_ids, labels, labels_mask, attention_mask = [], [], [], []
    for sample in batch:
        task, lab, cot = sample['task'], sample['labels'], sample['cot']
        task_tokens = tokenizer.encode(task, add_special_tokens=False)
        labels_tokens = tokenizer.encode(lab, add_special_tokens=False)
        cot_tokens = tokenizer.encode(cot, add_special_tokens=False)

        if args.use_cot:
            full_input = task_tokens + [think] + cot_tokens + [ans] + labels_tokens + [eos]
        else:
            full_input = task_tokens + [ans] + labels_tokens + [eos]
        inp_ids = torch.tensor(full_input)
        input_ids.append(inp_ids)

        lab = torch.tensor(full_input)
        lab[:len(task_tokens)] = -100
        labels.append(lab)

        lab_mask = torch.ones_like(inp_ids)
        lab_mask[:len(task_tokens)] = 0
        labels_mask.append(lab_mask)
        attention_mask.append(torch.ones_like(inp_ids))

    input_ids = pad_sequence(input_ids, padding_value=id_pad_value, batch_first=True)
    attention_mask = pad_sequence(attention_mask, padding_value=0, batch_first=True)
    labels = pad_sequence(labels, padding_value=id_pad_value, batch_first=True)
    labels_mask = pad_sequence(labels_mask, padding_value=0, batch_first=True)

    collated = {'input_ids': input_ids,
                'labels': labels,
                'attention_mask': attention_mask,
                }
    if args.num_mem_tokens is not None:
        # add labels mask only for RMT, ARMT
        collated['labels_mask'] = labels_mask.bool()
    return collated

In [66]:
think_text = tokenizer.decode(think)
ans_text = tokenizer.decode(ans)

def extract_cot(text):
        try:
                start_index = text.index(think_text)
                end_index = text.index(ans_text, start_index + len(think_text))
                return text[start_index + len(think_text):end_index]
        except ValueError:
                return ''

def extract_answer(text):
        try:
                return text.split(ans_text)[2]
        except IndexError:
                return ''
                
from torch.nn.utils.rnn import pad_sequence
def compute_accuracy(eval_pred):
        preds = eval_pred.predictions.argmax(axis=-1)[:, :-1]
        labels = eval_pred.label_ids[:, 1:]

        labels_masks = labels > 0
        preds_full = [p[m] for p, m in zip(preds, labels_masks)]
        labels_full = [lab[m] for lab, m in zip(labels, labels_masks)]

        print(len(preds_full), len(labels_full))

        preds_full_text = tokenizer.batch_decode(preds_full, add_special_tokens=True)
        labels_full_text = tokenizer.batch_decode(labels_full, add_special_tokens=True)

        preds_cot = [extract_cot(p) for p in preds_full_text]
        preds_ans = [extract_answer(p) for p in preds_full_text]

        labels_cot = [extract_cot(lab) for lab in labels_full_text]
        labels_ans = [extract_answer(lab) for lab in labels_full_text]

        acc_cot = np.mean([c == p for c, p in zip(preds_cot, labels_cot)])
        acc_ans = np.mean([c == lab for c, lab in zip(preds_ans, labels_ans)])

        return {'accuracy_cot': acc_cot, 'accuracy_ans': acc_ans}

In [67]:
batch = [train_dataset[i] for i in range(10)]
collated = collate_fn(batch)

out = model(**collated)
print(out.keys())

odict_keys(['loss', 'logits', 'past_key_values'])


In [68]:
# preds = eval_pred.predictions.argmax(axis=-1)[:, :-1]
# labels = eval_pred.label_ids[:, 1:]

preds = out.logits.argmax(dim=-1).cpu().numpy()[:, :-1]
labels = collated['labels'][:, 1:]

labels_masks = labels > 0
preds_full = [p[m] for p, m in zip(preds, labels_masks)]
labels_full = [lab[m] for lab, m in zip(labels, labels_masks)]

print(len(preds_full), len(labels_full))

preds_full_text = tokenizer.batch_decode(preds_full, add_special_tokens=True)
labels_full_text = tokenizer.batch_decode(labels_full, add_special_tokens=True)

preds_cot = [extract_cot(p) for p in preds_full_text]
preds_ans = [extract_answer(p) for p in preds_full_text]

labels_cot = [extract_cot(lab) for lab in labels_full_text]
labels_ans = [extract_answer(lab) for lab in labels_full_text]

acc_cot = np.mean([c == p for c, p in zip(preds_cot, labels_cot)])
acc_ans = np.mean([c == lab for c, lab in zip(preds_ans, labels_ans)])

10 10


In [69]:
preds_full_text

['<|endoftext|>5 5 6 1 4 + 0 1 3 3 8 0 ( 5 6 9 4 2 1 ) + 0 0 0 0 0 0 0 ( 5 6 9 4 2 1 0 ) + 0 0 0 5 5 6 1 4<|endoftext|>5 6 9 9 7 7 1 4<|endoftext|>',
 '<|endoftext|>0 0 0 0 0 + 0 4 4 2 5 3 ( 0 4 4 2 5 3 ) + 0 0 7 7 6 1 6 ( 0 4 1 0 2 5 6 ) + 0 0 0 8 8 4 0 7<|endoftext|>0 4 1 8 0 0 7 7<|endoftext|>',
 '<|endoftext|>6 9 9 3 1 + 0 8 8 9 1 4 ( 6 7 8 3 3 4 ) + 0 0 6 8 9 8 4 ( 6 7 4 2 3 3 5 ) + 0 0 0 6 8 9 8 4<|endoftext|>6 7 4 8 1 3 4 5<|endoftext|>',
 '<|endoftext|>2 3 4 3 3 + 0 8 4 1 0 5 ( 2 1 9 4 3 5 ) + 0 0 8 8 2 2 2 ( 2 1 7 3 6 7 2 ) + 0 0 0 2 7 5 5 0<|endoftext|>2 1 7 5 3 3 8 0<|endoftext|>',
 '<|endoftext|>2 1 9 4 2 + 0 0 0 0 0 0 ( 2 1 9 4 2 0 ) + 0 0 0 6 7 0 2 ( 2 1 9 0 0 1 2 ) + 0 0 0 4 0 3 8 0<|endoftext|>2 1 9 4 0 4 0 1<|endoftext|>',
 '<|endoftext|>0 6 3 0 4 + 0 5 4 0 5 0 ( 0 1 8 0 9 0 ) + 0 0 5 4 0 5 0 ( 0 1 3 5 9 5 0 ) + 0 0 0 5 3 1 5 1<|endoftext|>0 1 3 0 3 7 5 1<|endoftext|>',
 '<|endoftext|>8 9 8 1 1 + 0 7 4 8 7 1 ( 8 6 3 0 9 1 ) + 0 0 4 6 8 5 1 ( 8 6 7 6 7 7 1 ) + 0 0 0 2 3

In [70]:
labels_full_text

['<|endoftext|>5 5 6 1 4 + 0 1 3 3 8 0 ( 5 6 9 4 2 1 ) + 0 0 0 0 0 0 0 ( 5 6 9 4 2 1 0 ) + 0 0 0 5 5 6 1 4<|endoftext|>5 6 9 9 7 7 1 4<|endoftext|>',
 '<|endoftext|>0 0 0 0 0 + 0 4 4 2 5 3 ( 0 4 4 2 5 3 ) + 0 0 7 7 6 1 6 ( 0 4 1 0 2 5 6 ) + 0 0 0 8 8 4 0 7<|endoftext|>0 4 1 8 0 0 7 7<|endoftext|>',
 '<|endoftext|>6 9 9 3 1 + 0 8 8 9 1 4 ( 6 7 8 3 3 4 ) + 0 0 6 8 9 8 4 ( 6 7 4 2 3 3 5 ) + 0 0 0 6 8 9 8 4<|endoftext|>6 7 4 8 1 3 4 5<|endoftext|>',
 '<|endoftext|>2 3 4 3 3 + 0 8 4 1 0 5 ( 2 1 9 4 3 5 ) + 0 0 8 8 2 2 2 ( 2 1 7 3 6 7 2 ) + 0 0 0 2 7 5 5 0<|endoftext|>2 1 7 5 3 3 8 0<|endoftext|>',
 '<|endoftext|>2 1 9 4 2 + 0 0 0 0 0 0 ( 2 1 9 4 2 0 ) + 0 0 0 6 7 0 2 ( 2 1 9 0 0 1 2 ) + 0 0 0 4 0 3 8 0<|endoftext|>2 1 9 4 0 4 0 1<|endoftext|>',
 '<|endoftext|>0 6 3 0 4 + 0 5 4 0 5 0 ( 0 1 8 0 9 0 ) + 0 0 5 4 0 5 0 ( 0 1 3 5 9 5 0 ) + 0 0 0 5 3 1 5 1<|endoftext|>0 1 3 0 3 7 5 1<|endoftext|>',
 '<|endoftext|>8 9 8 1 1 + 0 7 4 8 7 1 ( 8 6 3 0 9 1 ) + 0 0 4 6 8 5 1 ( 8 6 7 6 7 7 1 ) + 0 0 0 2 3

In [71]:
preds_cot

['5 5 6 1 4 + 0 1 3 3 8 0 ( 5 6 9 4 2 1 ) + 0 0 0 0 0 0 0 ( 5 6 9 4 2 1 0 ) + 0 0 0 5 5 6 1 4',
 '0 0 0 0 0 + 0 4 4 2 5 3 ( 0 4 4 2 5 3 ) + 0 0 7 7 6 1 6 ( 0 4 1 0 2 5 6 ) + 0 0 0 8 8 4 0 7',
 '6 9 9 3 1 + 0 8 8 9 1 4 ( 6 7 8 3 3 4 ) + 0 0 6 8 9 8 4 ( 6 7 4 2 3 3 5 ) + 0 0 0 6 8 9 8 4',
 '2 3 4 3 3 + 0 8 4 1 0 5 ( 2 1 9 4 3 5 ) + 0 0 8 8 2 2 2 ( 2 1 7 3 6 7 2 ) + 0 0 0 2 7 5 5 0',
 '2 1 9 4 2 + 0 0 0 0 0 0 ( 2 1 9 4 2 0 ) + 0 0 0 6 7 0 2 ( 2 1 9 0 0 1 2 ) + 0 0 0 4 0 3 8 0',
 '0 6 3 0 4 + 0 5 4 0 5 0 ( 0 1 8 0 9 0 ) + 0 0 5 4 0 5 0 ( 0 1 3 5 9 5 0 ) + 0 0 0 5 3 1 5 1',
 '8 9 8 1 1 + 0 7 4 8 7 1 ( 8 6 3 0 9 1 ) + 0 0 4 6 8 5 1 ( 8 6 7 6 7 7 1 ) + 0 0 0 2 3 9 7 0',
 '0 0 0 0 0 + 0 2 7 8 3 1 ( 0 2 7 8 3 1 ) + 0 0 8 3 1 2 1 ( 0 2 5 2 5 3 1 ) + 0 0 0 2 7 8 3 1',
 '0 4 8 2 1 + 0 6 7 9 7 1 ( 0 0 6 2 9 1 ) + 0 0 0 4 8 2 1 ( 0 0 6 6 7 4 1 ) + 0 0 0 6 3 1 5 0',
 '6 8 0 0 4 + 0 3 4 0 0 2 ( 6 1 5 0 4 2 ) + 0 0 8 4 4 3 5 ( 6 1 3 5 8 5 5 ) + 0 0 0 4 2 7 6 2']

In [72]:
labels_cot

['5 5 6 1 4 + 0 1 3 3 8 0 ( 5 6 9 4 2 1 ) + 0 0 0 0 0 0 0 ( 5 6 9 4 2 1 0 ) + 0 0 0 5 5 6 1 4',
 '0 0 0 0 0 + 0 4 4 2 5 3 ( 0 4 4 2 5 3 ) + 0 0 7 7 6 1 6 ( 0 4 1 0 2 5 6 ) + 0 0 0 8 8 4 0 7',
 '6 9 9 3 1 + 0 8 8 9 1 4 ( 6 7 8 3 3 4 ) + 0 0 6 8 9 8 4 ( 6 7 4 2 3 3 5 ) + 0 0 0 6 8 9 8 4',
 '2 3 4 3 3 + 0 8 4 1 0 5 ( 2 1 9 4 3 5 ) + 0 0 8 8 2 2 2 ( 2 1 7 3 6 7 2 ) + 0 0 0 2 7 5 5 0',
 '2 1 9 4 2 + 0 0 0 0 0 0 ( 2 1 9 4 2 0 ) + 0 0 0 6 7 0 2 ( 2 1 9 0 0 1 2 ) + 0 0 0 4 0 3 8 0',
 '0 6 3 0 4 + 0 5 4 0 5 0 ( 0 1 8 0 9 0 ) + 0 0 5 4 0 5 0 ( 0 1 3 5 9 5 0 ) + 0 0 0 5 3 1 5 1',
 '8 9 8 1 1 + 0 7 4 8 7 1 ( 8 6 3 0 9 1 ) + 0 0 4 6 8 5 1 ( 8 6 7 6 7 7 1 ) + 0 0 0 2 3 9 7 0',
 '0 0 0 0 0 + 0 2 7 8 3 1 ( 0 2 7 8 3 1 ) + 0 0 8 3 1 2 1 ( 0 2 5 2 5 3 1 ) + 0 0 0 2 7 8 3 1',
 '0 4 8 2 1 + 0 6 7 9 7 1 ( 0 0 6 2 9 1 ) + 0 0 0 4 8 2 1 ( 0 0 6 6 7 4 1 ) + 0 0 0 6 3 1 5 0',
 '6 8 0 0 4 + 0 3 4 0 0 2 ( 6 1 5 0 4 2 ) + 0 0 8 4 4 3 5 ( 6 1 3 5 8 5 5 ) + 0 0 0 4 2 7 6 2']

In [73]:
preds_ans

['5 6 9 9 7 7 1 4',
 '0 4 1 8 0 0 7 7',
 '6 7 4 8 1 3 4 5',
 '2 1 7 5 3 3 8 0',
 '2 1 9 4 0 4 0 1',
 '0 1 3 0 3 7 5 1',
 '8 6 7 8 0 7 9 0',
 '0 2 5 4 2 2 5 1',
 '0 0 6 2 1 6 6 0',
 '6 1 3 9 0 3 2 3']

In [74]:
labels_ans

['5 6 9 9 7 7 1 4',
 '0 4 1 8 0 0 7 7',
 '6 7 4 8 1 3 4 5',
 '2 1 7 5 3 3 8 0',
 '2 1 9 4 0 4 0 1',
 '0 1 3 0 3 7 5 1',
 '8 6 7 8 0 7 9 0',
 '0 2 5 4 2 2 5 1',
 '0 0 6 2 1 6 6 0',
 '6 1 3 9 0 3 2 3']

### Older pretokenize

In [None]:
# from torch.nn.utils.rnn import pad_sequence
# id_pad_value = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
# if args.use_cot in (False, None):
#     inputs_key = 'examples_nocot'
#     labels_key = 'labels_nocot'
# else:
#     inputs_key = 'examples_all'
#     labels_key = 'labels_all'
    
# def collate_fn(batch):
#     input_ids = [torch.tensor(b[inputs_key]) for b in batch]
#     labels = [torch.tensor(b[labels_key]) for b in batch]
#     attention_mask = [torch.ones_like(b, dtype=int) for b in input_ids]
#     # labels_mask defines which input_ids participate in loss calculation
#     labels_mask = [torch.sign(torch.tensor(b[labels_key])) for b in batch]


#     input_ids = pad_sequence(input_ids, padding_value=id_pad_value, batch_first=True)
#     labels = pad_sequence(labels, padding_value=id_pad_value, batch_first=True)
#     attention_mask = pad_sequence(attention_mask, padding_value=0, batch_first=True)
#     labels_mask = pad_sequence(labels_mask, padding_value=0, batch_first=True)

#     collated = {'input_ids': input_ids,
#                 'labels': labels, 
#                 'attention_mask': attention_mask,
#                 }
#     if args.num_mem_tokens is not None:
#         # add labels mask only for RMT, ARMT
#         collated['labels_mask'] = labels_mask.bool()
#     return collated


In [None]:
# def extract_cot(text):
#     if '<|endoftext|>' not in text:
#         return ''
#     else:
#         return text.split('<|endoftext|>')[0].strip()

# def extract_answer(text):
#     if '####' not in text:
#         return ''
#     else:
#         ans = text.split('####')[-1]
#         ans = ans.split('<|endoftext|>')[0]
#         return ans.strip()
        
# def compute_accuracy(eval_pred):
#     preds = eval_pred.predictions[:, :-1]
#     labels = eval_pred.label_ids[:, 1:]
#     # inputs = eval_pred.inputs
#     # losses = eval_pred.losses

#     # labels = collated['labels'][:, 1:]

#     labels_masks = labels > 0
#     preds_full = [p[m] for p, m in zip(preds, labels_masks)]
#     labels_full = [l[m] for l, m in zip(labels, labels_masks)]

#     preds_full_text = tokenizer.batch_decode(preds_full, add_special_tokens=True)
#     labels_full_text = tokenizer.batch_decode(labels_full, add_special_tokens=True)

#     preds_cot = [extract_cot(p) for p in preds_full_text]
#     preds_ans = [extract_answer(p) for p in preds_full_text]

#     labels_cot = [extract_cot(l) for l in labels_full_text]
#     labels_ans = [extract_answer(l) for l in labels_full_text]
    
#     # Calculate accuracy only on the unignored tokens
#     acc_cot = np.mean([c == l for c, l in zip(preds_cot, labels_cot)])
#     acc_ans = np.mean([c == l for c, l in zip(preds_ans, labels_ans)])

#     return {'accuracy_cot': acc_cot, 'accuracy_ans': acc_ans}

In [39]:
# predictions = eval_pred.predictions
# label_ids = eval_pred.label_ids
# inputs = eval_pred.inputs
# losses = eval_pred.losses
# # elements = (self.predictions, self.label_ids)

In [40]:
batch = [valid_dataset[i] for i in range(10)]
collated = collate_fn(batch)

out = model(**collated)
print(out.keys())

odict_keys(['loss', 'logits', 'past_key_values'])


In [41]:
out.loss

tensor(0.0672, grad_fn=<NllLossBackward0>)

In [42]:
preds = out.logits.argmax(dim=-1).cpu().numpy()
preds.shape

(10, 175)

In [43]:
preds_text = tokenizer.batch_decode(preds, add_special_tokens=True)

In [44]:
# preds_text[0].split('<|endoftext|>')

In [46]:
# ''.split('<|endoftext|>')[1]

In [47]:
labels = collated['labels'][:, 1:]

labels_masks = labels > 0
preds_full = [p[m] for p, m in zip(preds[:, :-1], labels_masks)]
labels_full = [l[m] for l, m in zip(labels, labels_masks)]

preds_full_text = tokenizer.batch_decode(preds_full, add_special_tokens=True)
labels_full_text = tokenizer.batch_decode(labels_full, add_special_tokens=True)

preds_cot = [extract_cot(p) for p in preds_full_text]
preds_ans = [extract_answer(p) for p in preds_full_text]

labels_cot = [extract_cot(l) for l in labels_full_text]
labels_ans = [extract_answer(l) for l in labels_full_text]

In [48]:
# labels = collated['labels'][:, 1:]

# labels_masks = labels > 0
# preds_full = [p[m] for p, m in zip(preds[:, :-1], labels_masks)]
# labels_full = [l[m] for l, m in zip(labels, labels_masks)]

# preds_full_text = tokenizer.batch_decode(preds_full, add_special_tokens=True)
# labels_full_text = tokenizer.batch_decode(labels_full, add_special_tokens=True)

# preds_cot = [p.split('<|endoftext|>')[0].strip() for p in preds_full_text]
# labels_cot = [l.split('<|endoftext|>')[0].strip() for l in labels_full_text]

# # preds_ans = [p.split('<|endoftext|>')[1].strip()[4:] for p in preds_full_text]
# # labels_ans = [l.split('<|endoftext|>')[1].strip()[4:] for l in labels_full_text]

# preds_ans = [p.split('####')[1].strip()[4:] for p in preds_full_text]
# labels_ans = [l.split('####')[1].split('astrip()[4:] for l in labels_full_text]



In [49]:
tokenizer.batch_decode(collated['input_ids'])

['John cuts his grass to 2 inches.  It grows .5 inches per month.  When it gets to 4 inches he cuts it back down to 2 inches.  It cost $100 to get his grass cut.  How much does he pay per year?<|endoftext|><<4-2=2>> <<2/.5=4>> <<12/4=3>> <<100*3=300>><|endoftext|>300<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|end

In [50]:
labels_full_text

['<|endoftext|><<4-2=2>> <<2/.5=4>> <<12/4=3>> <<100*3=300>><|endoftext|>300<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><

In [51]:
preds_cot

['', '', '', '', '', '', '', '', '', '']

In [52]:
labels_cot

['', '', '', '', '', '', '', '', '', '']

In [53]:
labels_ans

['', '', '', '', '', '', '', '', '', '']

In [54]:
preds_ans

['', '', '', '', '', '', '', '', '', '']

In [27]:
labels_text
[0].split('<|endoftext|>')

NameError: name 'labels_text' is not defined

In [28]:
for p, l, m in zip(preds, collated['labels'], collated['labels_mask']):
    print(p[m], l[m])

KeyError: 'labels_mask'

In [42]:
preds[0]

array([   11,   860,   860,   604,  1635,   657,   657,   767,   657,
         642,   642,   642,   642,   718,   352,  1343,   657,   657,
         718,   604,   860,   657,   357,   642,   642,   352,   352,
         352,   352,  1267,  1343,   657,   657,   642,   860,   657,
         767,   657,   357,   642,   642,   718,   657,   362,   807,
         657,  1267,  1343,   657,   657,   657,   657,   718,   604,
         860,   657,   220, 50256,  1303, 21017,   642,   642,   718,
         657,   807,   362,   657,   352,   220, 50256,  1303])

In [27]:
pred_texts = tokenizer.batch_decode(preds, skip_special_tokens=False)
print(pred_texts)

[', 9 9 4 * 0 0 7 0 5 5 5 5 6 1 + 0 0 6 4 9 0 ( 5 5 1 1 1 1 ) + 0 0 5 9 0 7 0 ( 5 5 6 0 2 8 0 ) + 0 0 0 0 6 4 9 0 <|endoftext|> #### 5 5 6 0 8 2 0 1 <|endoftext|> #', ' the 0 1 6 0 1 0 8 3 6 6 7 1 1 3 + 0 4 8 7 0 2 ( 6 1 0 9 3 2 ) + 0 0 4 8 7 0 2 ( 6 1 4 7 1 3 2 ) + 0 0 0 2 7 3 6 3 <|endoftext|> #### 6 1 4 9 8 6 8 3 <|endoftext|> #', ' the 0 8 4 + 2 8 3 0 8 8 0 0 5 7 + 0 4 8 3 4 8 ( 8 4 8 8 1 9 ) + 0 0 6 7 3 9 0 ( 8 4 4 6 5 8 1 ) + 0 0 0 2 3 6 5 6 <|endoftext|> #### 8 4 4 8 8 4 7 6 <|endoftext|> #', ', 2 3 2\n 0 0 0 1 9 9 0 9 2 1 + 0 2 1 2 7 1 ( 9 2 0 5 8 1 ) + 0 0 8 1 8 5 2 ( 9 2 8 6 6 7 2 ) + 0 0 0 5 1 5 1 2 <|endoftext|> #### 9 2 8 1 8 2 4 2 <|endoftext|> #', ' the 1 - 0 0 0 1 3 2 0 0 4 6 8 5 + 0 0 5 6 6 3 ( 0 4 1 5 2 4 ) + 0 0 0 8 9 3 4 ( 0 4 1 3 2 8 4 ) + 0 0 0 0 5 6 6 3 <|endoftext|> #### 0 4 1 3 7 4 1 4 <|endoftext|> #', ', 4 5 1 0 8 2 8 1 2 2 5 2 4 2 + 0 9 8 1 8 1 ( 2 4 1 6 0 2 ) + 0 0 4 0 5 8 4 ( 2 4 5 6 5 0 5 ) + 0 0 0 1 4 4 2 4 <|endoftext|> #### 2 4 5 7 9 4 7 4 <|endoftext|

In [28]:
labels_texts = tokenizer.batch_decode(collated['input_ids'], skip_special_tokens=False)
print(labels_texts)

[' 5 6 3 2 * 7 4 3 4 <|endoftext|> 5 5 5 6 1 + 0 0 6 4 9 0 ( 5 5 1 1 1 1 ) + 0 0 5 9 0 7 0 ( 5 5 6 0 2 8 0 ) + 0 0 0 0 6 4 9 0 <|endoftext|> #### 5 5 6 0 8 2 0 1 <|endoftext|>', ' 6 9 1 5 * 6 4 4 7 <|endoftext|> 6 7 1 1 3 + 0 4 8 7 0 2 ( 6 1 0 9 3 2 ) + 0 0 4 8 7 0 2 ( 6 1 4 7 1 3 2 ) + 0 0 0 2 7 3 6 3 <|endoftext|> #### 6 1 4 9 8 6 8 3 <|endoftext|>', ' 6 7 3 9 * 8 9 1 7 <|endoftext|> 8 0 0 5 7 + 0 4 8 3 4 8 ( 8 4 8 8 1 9 ) + 0 0 6 7 3 9 0 ( 8 4 4 6 5 8 1 ) + 0 0 0 2 3 6 5 6 <|endoftext|> #### 8 4 4 8 8 4 7 6 <|endoftext|>', ' 3 0 3 4 * 3 4 6 5 <|endoftext|> 9 0 9 2 1 + 0 2 1 2 7 1 ( 9 2 0 5 8 1 ) + 0 0 8 1 8 5 2 ( 9 2 8 6 6 7 2 ) + 0 0 0 5 1 5 1 2 <|endoftext|> #### 9 2 8 1 8 2 4 2 <|endoftext|>', ' 0 3 3 7 * 8 5 6 5 <|endoftext|> 0 4 6 8 5 + 0 0 5 6 6 3 ( 0 4 1 5 2 4 ) + 0 0 0 8 9 3 4 ( 0 4 1 3 2 8 4 ) + 0 0 0 0 5 6 6 3 <|endoftext|> #### 0 4 1 3 7 4 1 4 <|endoftext|>', ' 3 6 0 6 * 4 3 8 7 <|endoftext|> 2 5 2 4 2 + 0 9 8 1 8 1 ( 2 4 1 6 0 2 ) + 0 0 4 0 5 8 4 ( 2 4 5 6 5 0 5 ) + 0 0 

In [29]:
collated['input_ids'].shape

torch.Size([10, 71])

In [30]:
pred_texts[0]

', 9 9 4 * 0 0 7 0 5 5 5 5 6 1 + 0 0 6 4 9 0 ( 5 5 1 1 1 1 ) + 0 0 5 9 0 7 0 ( 5 5 6 0 2 8 0 ) + 0 0 0 0 6 4 9 0 <|endoftext|> #### 5 5 6 0 8 2 0 1 <|endoftext|> #'

In [33]:
preds[0]

array([   11,   860,   860,   604,  1635,   657,   657,   767,   657,
         642,   642,   642,   642,   718,   352,  1343,   657,   657,
         718,   604,   860,   657,   357,   642,   642,   352,   352,
         352,   352,  1267,  1343,   657,   657,   642,   860,   657,
         767,   657,   357,   642,   642,   718,   657,   362,   807,
         657,  1267,  1343,   657,   657,   657,   657,   718,   604,
         860,   657,   220, 50256,  1303, 21017,   642,   642,   718,
         657,   807,   362,   657,   352,   220, 50256,  1303])

In [None]:
collated['input_ids'][0]

torch.Size([71])

In [38]:
for p, t in zip(preds[0], collated['input_ids'][0][1:]):
    # print(p, tokenizer.decode([p]), t, tokenizer.decode([t]))
    print(tokenizer.decode([p]), tokenizer.decode([t]))

,  6
 9  3
 9  2
 4  *
 *  7
 0  4
 0  3
 7  4
 0  
 5 <|endoftext|>
 5  5
 5  5
 5  5
 6  6
 1  1
 +  +
 0  0
 0  0
 6  6
 4  4
 9  9
 0  0
 (  (
 5  5
 5  5
 1  1
 1  1
 1  1
 1  1
 )  )
 +  +
 0  0
 0  0
 5  5
 9  9
 0  0
 7  7
 0  0
 (  (
 5  5
 5  5
 6  6
 0  0
 2  2
 8  8
 0  0
 )  )
 +  +
 0  0
 0  0
 0  0
 0  0
 6  6
 4  4
 9  9
 0  0
   
<|endoftext|> <|endoftext|>
 #  #
### ###
 5  5
 5  5
 6  6
 0  0
 8  8
 2  2
 0  0
 1  1
   
<|endoftext|> <|endoftext|>
