In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
import numpy as np
import datasets

os.environ["CUDA_VISIBLE_DEVICES"] = "1"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import sys
sys.path.append('..')
from utils.reasoning import make_segment, split_cot
from torch.nn.utils.rnn import pad_sequence

In [4]:
from modeling_rmt.language_modeling import MemoryCell
from modeling_rmt.experimental import RecurrentWrapperNoSegmentationGenerate

In [28]:
device = 'cuda'
model_name = "HuggingFaceTB/SmolLM2-135M"
checkpoint_path = "/home/user33/kashurin/RMT_SmolLM2-135M/cot/checkpoint-2200/pytorch_model.bin"

model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

pad = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
bos = [tokenizer.bos_token_id]
eos = [tokenizer.eos_token_id]
think = tokenizer.encode("<issue_start>")
ans = tokenizer.encode("<issue_closed>")

delim = ">> <<"


memory_cell = MemoryCell(
    model,
    num_mem_tokens=16
)

model = RecurrentWrapperNoSegmentationGenerate(memory_cell, 
                                             max_n_segments=10, 
                                             think_token_id=think[0],
                                             answer_token_id=ans[0],
                                             bos_token_id=bos[0],
                                             eos_token_id=eos[0]
                                             )

model.load_state_dict(torch.load(checkpoint_path), strict=False)
model.to(device)
print(':)')

:)


In [13]:
prompts = [
    "The future of AI is",
    "In a galaxy far far away",
    "Hello"
]

In [14]:
tokenizer.pad_token = tokenizer.eos_token
inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side='left')
inputs

{'input_ids': tensor([[    0,   504,  1774,   282,  5646,   314],
        [  788,   253, 13247,  1869,  1869,  2025],
        [    0,     0,     0,     0,     0, 19556]]), 'attention_mask': tensor([[0, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 1]])}

In [15]:
class Holder:
    def __init__(self):
        pass
args = Holder()
args.num_mem_tokens = 16
args.task_name = 'gsm8k'

In [31]:
def collate_fn(batch):
    # first, we segment each sample into task, cot steps and labels
    segments_batch = []
    for sample in batch:
        task, lab, cot = sample['task'], sample['labels'], sample['cot']
        task_tokens = tokenizer.encode(task, add_special_tokens=False)
        labels_tokens = tokenizer.encode(lab, add_special_tokens=False)
        cot_segments = split_cot(cot, by=delim)
        cot_segment_tokens = tokenizer.batch_encode_plus(cot_segments, add_special_tokens=False)['input_ids']

        segments = []
        segments.append(make_segment(bos + task_tokens + think, loss=False))
        for segment in cot_segment_tokens[:-1]:
            segments.append(make_segment(bos + segment + think, loss=True))
        segments.append(make_segment(bos + cot_segment_tokens[-1] + ans, loss=True))

        segments.append(make_segment(bos + labels_tokens + eos, loss=True))
        segments_batch.append(segments)

    # if some samples have less segments than others, we pad them with empty segments
    num_segments = max(len(segments) for segments in segments_batch)
    for segments in segments_batch:
        if len(segments) < num_segments:
            segments.extend([make_segment(eos, loss=False)] * (num_segments - len(segments)))

    # prepare segments for the whole batch
    batch_segments = []
    for i in range(num_segments):
        padding_side = 'right'
        if i == 0:
            padding_side = 'left'

        input_ids = [s[i]['input_ids'] for s in segments_batch]
        attention_mask = [s[i]['attention_mask'] for s in segments_batch]
        labels = [s[i]['labels'] for s in segments_batch]
        labels_mask = [s[i]['labels_mask'] for s in segments_batch]

        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=pad, padding_side=padding_side)
        attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0, padding_side=padding_side)
        labels = pad_sequence(labels, batch_first=True, padding_value=-100, padding_side=padding_side)
        labels_mask = pad_sequence(labels_mask, batch_first=True, padding_value=False, padding_side=padding_side)

        batch_segment = {'input_ids': input_ids,
                            'attention_mask': attention_mask,
                            'labels_mask': labels_mask,
                            'labels': labels
                            }
        
        batch_segments.append(batch_segment)
    full_labels = torch.cat([s['labels'] for s in batch_segments], dim=1)
    return {"segments": batch_segments, 'labels': full_labels}

In [18]:
dataset = 'booydar/gsm8k'
train_dataset = datasets.load_dataset(dataset, split='train')
valid_dataset = datasets.load_dataset(dataset, split='valid')

In [19]:
class Holder:
    def __init__(self):
        pass
args = Holder()
args.max_new_tokens = 100
args.task_name = 'gsm8k'

In [32]:
all_preds, all_labels = [], []
all_preds_cot, all_labels_cot = [], []
all_preds_ans, all_labels_ans = [], []

batch = valid_dataset.select(range(4))
collated = collate_fn(batch)
task = collated['segments'][0]
task = {k:v.to(device) for k,v in task.items()}

task_length = task['input_ids'].shape[1]

with torch.no_grad():
    gen_out = model.generate(
        [task],
        max_new_tokens=args.max_new_tokens,
        pad_token_id=eos[0]
    )

preds_full = torch.cat(gen_out, dim=1)
labels = collated['labels']
for i, (lab_tokens, pred_tokens) in enumerate(zip(labels, preds_full)):
    labels_mask = lab_tokens != -100
    lab_tokens = lab_tokens[labels_mask].tolist()

    pred_tokens = pred_tokens[task_length:].tolist()
    
    ans_start_index_l = max(i for i, x in enumerate(lab_tokens) if x == ans[0])
    ans_end_index_l = min(i for i, x in enumerate(lab_tokens) if x == eos[0])

    if ans[0] in pred_tokens:
        ans_start_index_p = max(i for i, x in enumerate(pred_tokens) if x == ans[0])
    else:
        ans_start_index_p = ans_start_index_l

    if eos[0] in pred_tokens:
        ans_end_index_p = min(i for i, x in enumerate(pred_tokens) if x == eos[0])
    else:
        ans_end_index_p = ans_end_index_l

    pred_cot_tokens = pred_tokens[:ans_start_index_p]
    lab_cot_tokens = lab_tokens[:ans_start_index_l]

    all_preds_cot.append(pred_cot_tokens)
    all_labels_cot.append(lab_cot_tokens)

    pred_and_tokens = pred_tokens[ans_start_index_p+1:ans_end_index_p]
    lab_ans_tokens = lab_tokens[ans_start_index_l+1:ans_end_index_l]

    all_preds_ans.append(pred_and_tokens)
    all_labels_ans.append(lab_ans_tokens)

    all_preds.append(pred_tokens)
    all_labels.append(lab_tokens)

cot_correct = [p == l for p, l in zip(all_preds_cot, all_labels_cot)]
ans_correct = [p == l for p, l in zip(all_preds_ans, all_labels_ans)]

res = {'accuracy_cot': np.mean(cot_correct), 'accuracy_ans': np.mean(ans_correct)}
data = {"all_preds_cot": all_preds_cot,
        "all_labels_cot": all_labels_cot,
        "all_preds_ans": all_preds_ans,
        "all_labels_ans": all_labels_ans,
        "all_preds": all_preds,
        "all_labels": all_labels}

In [33]:
tokenizer.decode(preds_full[1])

'1.5*2=3<issue_start><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>3+2.5=5.5<issue_start><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>5.5+1.5=7<issue_closed><|endoftext|><|endoftext|><|endoftext|>7<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>7<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>'

In [32]:
task_length

62

In [33]:
tokenizer.decode(pred_tokens)

'<<21/7=3>> <<3*5=15>><issue_closed>15<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>'

In [34]:
res

{'accuracy_cot': np.float64(0.25), 'accuracy_ans': np.float64(0.75)}

In [35]:
tokenizer.decode(data["all_labels_cot"][0])

'<<4-2=2>> <<2/.5=4>> <<12/4=3>> <<100*3=300>>'

In [39]:
tokenizer.decode(data["all_labels_ans"][1])

'10'

In [40]:
tokenizer.decode(data["all_preds_ans"][1])

'10.5'

In [46]:
print("Pred", tokenizer.decode(data["all_preds"][3]))
print("Lab", tokenizer.decode(data["all_labels"][3]))

Pred <<21/7=3>> <<3*5=15>><issue_closed>15<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
Lab <<21/7=3>> <<5*3=15>><issue_closed>15<|endoftext|>


In [26]:
all_preds_ans

[[], [33, 32, 30, 37], [], []]

In [87]:
all_labels_ans

[[35, 32, 32]]

In [60]:
ans_correct = [p == l for p, l in zip(all_preds_ans, all_labels_ans)]

In [61]:
ans_correct

[False]