In [18]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
import numpy as np
import datasets

os.environ["CUDA_VISIBLE_DEVICES"] = "2"

In [19]:
import sys
sys.path.append('..')
from utils.reasoning import make_segment, split_cot
from torch.nn.utils.rnn import pad_sequence

In [41]:
from torch.nn import CrossEntropyLoss
from transformers import StoppingCriteria
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

from modeling_rmt.language_modeling import MemoryCell
# from modeling_rmt.experimental import RecurrentWrapperNoSegmentation

from modeling_rmt.experimental import RecurrentWrapperNoSegmentationGenerate

In [42]:
model_name = "HuggingFaceTB/SmolLM2-135M"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
id_pad_value = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
bos = [tokenizer.bos_token_id]
eos = [tokenizer.eos_token_id]
think = tokenizer.encode("<issue_start>")
ans = tokenizer.encode("<issue_closed>")
delim = ">> <<"

class StopOnSpecialTokenCriteria(StoppingCriteria):
    def __init__(self, special_token_ids):
        self.special_token_ids = set(special_token_ids)

    def __call__(self, input_ids, scores, **kwargs):
        last_token = input_ids[0, -1].item()
        return last_token in self.special_token_ids


class RecurrentWrapperNoSegmentationGenerate(RecurrentWrapperNoSegmentation):
    def forward(self, segments, labels, output_attentions=None, output_hidden_states=None):
        memory_state = None

        cell_outputs = []
        for seg_num, segment in enumerate(segments):
            cell_out, memory_state = self.memory_cell(input_ids=segment['input_ids'],
                                                      attention_mask=segment['attention_mask'],
                                                      memory_state=memory_state, output_hidden_states=True)
            cell_outputs.append(cell_out)
            self.manage_gradients(memory_state, seg_num)

        out = self.process_outputs(cell_outputs, segments,
                                   output_attentions=output_attentions,
                                   output_hidden_states=output_hidden_states)
        return out
    
    def generate(self, segments, **kwargs):
        memory_state = None
        
        # cell_outputs = []
        for seg_num, segment in enumerate(segments):
            cell_out, memory_state = self.memory_cell(input_ids=segment['input_ids'],
                                                        attention_mask=segment['attention_mask'],
                                                        memory_state=memory_state, output_hidden_states=True)
            # cell_outputs.append(cell_out)

        generated_segments = []
        for seg_num in range(len(segments), self.rmt_config.get("max_n_segments", 32)):
            output_ids, memory_state = self.generate_segment(memory_state=memory_state, **kwargs)
            generated_segments.append(output_ids)
            
            if self.all_done(generated_segments):
                break

        # return self.post_process_generated_segments(generated_segments)
        return generated_segments


    def generate_segment(self, memory_state, **kwargs):
        input_ids = self.get_bos_tensor(memory_state)
        attention_mask = torch.ones_like(input_ids).bool()

        generated = self.memory_cell.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            memory_state=memory_state,
            # max_new_tokens=kwargs.get('max_new_tokens', 50),
            stopping_criteria=self.make_custom_stopping_criteria(),
            **kwargs
        )

        # Update memory from generation
        fwd_inputs = torch.cat((input_ids, generated), dim=1)[:, :-1]
        _, memory_state = self.memory_cell(input_ids=fwd_inputs, memory_state=memory_state)

        return generated, memory_state
    
    def get_bos_tensor(self, memory_state):
        bos = self.rmt_config["bos_token_id"]
        bos_tensor = torch.tensor([bos] * memory_state.shape[0]).reshape(-1, 1)
        return bos_tensor.to(memory_state.device)
    
    def all_done(self, generated_segments):
        eos = self.rmt_config['eos_token_id']
        bs = generated_segments[0].shape[0]
        have_eos = [any([eos in seg[i] for seg in generated_segments]) for i in range(bs)]
        all_done = all(have_eos)
        return all_done
    
    def make_custom_stopping_criteria(self):
        return [StopOnSpecialTokenCriteria([self.rmt_config['think_token_id'], self.rmt_config['answer_token_id']])]


In [43]:
def collate_fn(batch):
    # first, we segment each sample into task, cot steps and labels
    segments_batch = []
    for sample in batch:
        task, lab, cot = sample['task'], sample['labels'], sample['cot']
        task_tokens = tokenizer.encode(task, add_special_tokens=False)
        labels_tokens = tokenizer.encode(lab, add_special_tokens=False)
        if getattr(args, 'use_cot', False):
            cot_segments = split_cot(cot, by=delim)
        else:
            cot_segments = [cot]
        cot_segment_tokens = tokenizer.batch_encode_plus(cot_segments, add_special_tokens=False)['input_ids']

        segments = []
        segments.append(make_segment(bos + task_tokens + think, loss=False))
        for segment in cot_segment_tokens[:-1]:
            segments.append(make_segment(bos + segment + think, loss=True))
        segments.append(make_segment(bos + cot_segment_tokens[-1] + ans, loss=True))

        segments.append(make_segment(bos + labels_tokens + eos, loss=True))
        segments_batch.append(segments)

    # if some samples have less segments than others, we pad them with empty segments
    num_segments = max(len(segments) for segments in segments_batch)
    for segments in segments_batch:
        if len(segments) < num_segments:
            segments.extend([make_segment(eos, loss=False)] * (num_segments - len(segments)))

    # prepare segments for the whole batch
    batch_segments = []
    for i in range(num_segments):
        input_ids = [s[i]['input_ids'] for s in segments_batch]
        attention_mask = [s[i]['attention_mask'] for s in segments_batch]
        labels = [s[i]['labels'] for s in segments_batch]
        labels_mask = [s[i]['labels_mask'] for s in segments_batch]
        # FIXME, pad by right side!!!

        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=id_pad_value)
        attention_mask = pad_sequence(attention_mask, batch_first=True, padding_value=0)
        labels = pad_sequence(labels, batch_first=True, padding_value=-100)
        labels_mask = pad_sequence(labels_mask, batch_first=True, padding_value=False)

        batch_segment = {'input_ids': input_ids,
                         'attention_mask': attention_mask,
                         'labels_mask': labels_mask,
                         'labels': labels
                         }
        batch_segments.append(batch_segment)
    full_labels = torch.cat([s['labels'] for s in batch_segments], dim=1)
    return {"segments": batch_segments, 'labels': full_labels}

In [44]:
class Holder:
    def __init__(self):
        pass

args = Holder()
args.use_cot = True
args.num_mem_tokens = 16
args.segment_size = 64
args.max_n_segments = 10
args.max_cot_steps = 8
args.task_name = 'gsm8k'

In [45]:
memory_cell = MemoryCell(
    model,
    num_mem_tokens=16
)

rmt = RecurrentWrapperNoSegmentationGenerate(
    memory_cell, 
    max_n_segments=8, 
    think_token_id=think[0],
    answer_token_id=ans[0],
    bos_token_id=bos[0],
    eos_token_id=eos[0]
)

# checkpoint_path = "/home/user33/kashurin/RMT_SmolLM2-135M/pt/checkpoint-29500/pytorch_model.bin"
checkpoint_path = "/home/user33/kashurin/RMT_SmolLM2-135M/cot/checkpoint-2200/pytorch_model.bin"
device = 'cpu'

rmt.load_state_dict(torch.load(checkpoint_path), strict=False)
rmt.to(device)
print(':)')

:)


In [46]:
dataset = 'booydar/gsm8k'
train_dataset = datasets.load_dataset(dataset, split='train')
valid_dataset = datasets.load_dataset(dataset, split='valid')

In [47]:
valid_dataset[0]

{'task': 'John cuts his grass to 2 inches.  It grows .5 inches per month.  When it gets to 4 inches he cuts it back down to 2 inches.  It cost $100 to get his grass cut.  How much does he pay per year?',
 'labels': '300',
 'cot': '<<4-2=2>> <<2/.5=4>> <<12/4=3>> <<100*3=300>>',
 'cot_len': 4}

In [88]:
def generate_text(text_batch: list[str], max_new_tokens=64):
    collated = collate_fn(batch)
    task = collated['segments'][0]
    task = {k:v.to(device) for k,v in task.items()}

    with torch.no_grad():
        gen_out = rmt.generate([task], max_new_tokens=max_new_tokens, pad_token_id=eos[0])
    preds_full = torch.cat(gen_out, dim=1)
    
    gen_text = tokenizer.batch_decode(preds_full, skip_special_tokens=True)

    # generations = []
    # for gen in gen_segments:
    #     gen_text = tokenizer.batch_decode(gen, skip_special_tokens=False)
    #     generations.append(gen_text)

    # full_gen = []
    # for i in range(len(text_batch)):
    #     full_gen.append([])
    #     for gen in generations:
    #         full_gen[-1].append(gen[i])
        
    return gen_text

In [89]:
batch = valid_dataset.select(range(4))

In [90]:
gen_text = generate_text(batch)

In [91]:
gen_text

['2*4=88*2=1616*1=161200-16=11641164/12=949419*1=9',
 '1.5*2=1.5*2.51.5+2.5+2.5*1.5=3.752.5*3.75=112*1.75=3.53+1+1+1+1+1+3.75+3=5.55555555555555555',
 '2000*32000-202000-20020000*0.05=102000000000020000000600000',
 '21/7=321*5=1014*1=1414*2=28214+28=4242422*15=3']

In [62]:
def evaluate(model, dataset, device='cpu', bs=16, max_new_tokens=25):
    all_preds, all_labels = [], []
    all_preds_cot, all_labels_cot = [], []
    all_preds_ans, all_labels_ans = [], []

    for start_ind in range(0, len(dataset), bs):
        batch = dataset.select(range(start_ind, min(len(dataset), start_ind + bs)))
        collated = collate_fn(batch)
        task = collated['segments'][0]
        task = {k:v.to(device) for k,v in task.items()}

        with torch.no_grad():
            gen_out = model.generate([task], max_new_tokens=max_new_tokens, pad_token_id=eos[0])
        

        preds_full = torch.cat(gen_out, dim=1)
        labels = collated['labels']

        labels_masks = labels > 0
        labels_full = [lab[m] for lab, m in zip(labels, labels_masks)]

        for lab_tokens, pred_tokens in zip(labels_full, preds_full):
            lab_tokens = [t.item() for t in lab_tokens if t != bos[0]]
            
            ans_start_index_l = max(i for i, x in enumerate(lab_tokens) if x == ans[0])
            if ans[0] in pred_tokens:
                ans_start_index_p = max(i for i, x in enumerate(pred_tokens) if x == ans[0])
            else:
                ans_start_index_p = ans_start_index_l
            ans_start_index_p = ans_start_index_l

            pred_cot_tokens = pred_tokens[:ans_start_index_p].tolist()
            lab_cot_tokens = lab_tokens[:ans_start_index_l]

            all_preds_cot.append(pred_cot_tokens)
            all_labels_cot.append(lab_cot_tokens)

            all_preds_ans.append(pred_tokens[ans_start_index_p:].tolist())
            all_labels_ans.append(lab_tokens[ans_start_index_l:])

            all_preds.append(pred_tokens.tolist())
            all_labels.append(lab_tokens)
    
    cot_correct = [p == l for p, l in zip(all_preds_cot, all_labels_cot)]
    ans_correct = [p == l for p, l in zip(all_preds_ans, all_labels_ans)]
    res = {'accuracy_cot': np.mean(cot_correct), 'accuracy_ans': np.mean(ans_correct)}
    data = {"all_preds_cot": all_preds_cot,
            "all_labels_cot": all_labels_cot,
            "all_preds_ans": all_preds_ans,
            "all_labels_ans": all_labels_ans,
            "all_preds": all_preds,
            "all_labels": all_labels}
    return res, data

In [70]:
res, data = evaluate(rmt, valid_dataset.select(range(1)), bs=1, device=device)

In [71]:
tokenizer.decode(data["all_preds"][0])

'2*4=8<issue_start>8*2=16<issue_start>16*12=192<issue_start>192<|endoftext|>'

In [73]:
generation = generate_text(["London is the capital of ", "Hello "])

Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


In [74]:
print(generation)

[[', the world', ', and/', '\n- The'], ['\nThe name', '\n- The', '.\nThe']]


In [79]:
task_batch = [
    valid_dataset[0]["task"],
    valid_dataset[1]["task"]
]

generation = generate_text(task_batch)

Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


In [80]:
generation

[['2*4', '2*4', '2*2', '12*', '50*', '12*', '500'],
 ['1.5', '1.5', '2.5', '2.5', '2.5', '2.5', '2.5']]

In [82]:
tokenizer.decode([0,  9682, 10672,   650,  5372,   288,   216,    34,  6439,    30,
             216,   657,  8759,  1673,    37,  6439,   567,  3531,    30,   216,
            1550,   357,  4364,   288,   216,    36,  6439,   384, 10672,   357,
            1056,  1187,   288,   216,    34,  6439,    30,   216,   657,  1708,
            1885,    33,    32,    32,   288,   820,   650,  5372,  2304,    30,
             216,  1073,  1083,  1072,   384,  2290,   567,   713,    47,     8, eos[0]], skip_special_tokens=False)

'<|endoftext|>John cuts his grass to 2 inches.  It grows .5 inches per month.  When it gets to 4 inches he cuts it back down to 2 inches.  It cost $100 to get his grass cut.  How much does he pay per year?<issue_start><|endoftext|>'

In [14]:
inputs = tokenizer("Today is a beautiful", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    num_beams=3,
    do_sample=False,
    top_p=0.95,
    # return_dict_in_generate=True,
    # output_scores=True
)

The following generation flags are not valid and may be ignored: ['top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


In [13]:
tokenizer.decode(outputs, skip_special_tokens=False)

TypeError: argument 'ids': 'list' object cannot be interpreted as an integer

In [10]:
for seq, score in zip(outputs.sequences, outputs.sequences_scores):
    text = tokenizer.decode(seq, skip_special_tokens=False)
    print(f"Generated ({score:.2f}): {text}")

Generated (-1.43): Today is a beautiful day.

The


In [None]:
bs, N, H = 2, 4, 64

In [None]:
inputs = torch.randn(bs, N, H)

In [None]:
mask = torch.tensor([[True, False, False, True],
                    [False, True, False, False]])

In [None]:
new_values = torch.randn(mask.sum(), H)