In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
import numpy as np
import datasets

# Expose only physical GPU 2 to this process; inside the process it is then addressed as `cuda:0`.
# NOTE(review): this only takes effect if it runs before the first CUDA call in the kernel —
# confirm nothing executed earlier initializes CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import sys
sys.path.append('..')
from utils.reasoning import make_segment, split_cot
from torch.nn.utils.rnn import pad_sequence

In [125]:
import torch
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

class MemoryCellSmart(torch.nn.Module):
    """Causal-LM wrapper that injects memory vectors at mask-selected positions.

    The cell owns a learned ``memory`` parameter of shape
    ``(num_mem_tokens, memory_dim)``. Before the base model is called, the
    embeddings at the positions selected by ``read_mem_mask`` (and, in
    ``forward``, also ``write_mem_mask``) are overwritten with the current
    memory state. After the call, the last-layer hidden states at the
    ``write_mem_mask`` positions are returned as the new memory state.
    """

    def __init__(self, base_model, num_mem_tokens):
        """
        Args:
            base_model: model exposing ``get_input_embeddings()``, ``config``,
                ``__call__`` and ``generate``.
            num_mem_tokens: number of memory vectors to learn.
        """
        super().__init__()
        self.model = base_model
        self.num_mem_tokens = num_mem_tokens
        embeddings = self.model.get_input_embeddings()
        # Prefer `n_embd` when the config has it, otherwise use `hidden_size`.
        memory_dim = getattr(self.model.config, 'n_embd', self.model.config.hidden_size)
        # Initialise memory on the same scale as the token embeddings.
        memory_weights = torch.randn((num_mem_tokens, memory_dim)) * embeddings.weight.data.std()
        self.register_parameter('memory', torch.nn.Parameter(memory_weights, requires_grad=True))

    def set_memory(self, input_shape):
        """Return the learned memory repeated ``input_shape[0]`` times along a new batch axis."""
        return self.memory.repeat(input_shape[0], 1, 1)

    def put_tensor_by_mask(self, inputs_embeds, memory_state, mem_mask):
        """Overwrite ``inputs_embeds`` rows selected by ``mem_mask`` with ``memory_state``.

        ``inputs_embeds`` is modified in place and also returned. For every
        batch row ``i`` the positions where ``mem_mask[i]`` is True receive
        ``memory_state[i]``.
        """
        for row in range(inputs_embeds.shape[0]):
            inputs_embeds[row, mem_mask[row]] = memory_state[row]

        return inputs_embeds

    def extract_tensor_by_mask(self, outputs, mask):
        """Gather the mask-selected positions of ``outputs`` into ``(bs, M, H)``.

        Args:
            outputs: tensor of shape ``(bs, N, H)``.
            mask: boolean tensor of shape ``(bs, N)`` that selects the same
                number ``M`` of positions in every row.

        Raises:
            ValueError: if rows select different numbers of positions. Without
                this check a ragged mask whose total happens to equal
                ``bs * M`` would be reshaped silently and mix up rows.
        """
        bs, _, H = outputs.shape
        per_row = mask.sum(dim=1)
        M = per_row[0].item()
        if not bool((per_row == M).all()):
            raise ValueError(
                "mask must select the same number of positions in every row, "
                f"got per-row counts {per_row.tolist()}"
            )

        return outputs[mask].view(bs, M, H)

    def process_input(
            self,
            input_ids,
            memory_state,
            write_mem,
            read_mem_mask=None,
            write_mem_mask=None,
            **kwargs
        ):
        """Build the keyword arguments for the base model call.

        Embeds ``input_ids``, writes ``memory_state`` into the read positions
        and, when ``write_mem`` is true, into the write positions as well.

        Returns:
            dict with every entry of ``kwargs`` plus ``input_ids=None``,
            ``inputs_embeds``, ``attention_mask`` (``None`` when the caller did
            not provide one) and ``output_hidden_states=True``.

        Raises:
            ValueError: if ``inputs_embeds`` is passed by the caller.
        """
        if kwargs.get('inputs_embeds') is not None:
            raise ValueError("inputs_embeds is not supported for memory cells")

        seg_kwargs = dict(kwargs)

        inputs_embeds = self.model.get_input_embeddings()(input_ids)
        inputs_embeds = self.put_tensor_by_mask(inputs_embeds, memory_state, read_mem_mask)
        if write_mem:
            inputs_embeds = self.put_tensor_by_mask(inputs_embeds, memory_state, write_mem_mask)

        seg_kwargs['input_ids'] = None
        seg_kwargs['inputs_embeds'] = inputs_embeds
        # `.get` instead of indexing: a missing attention mask must not raise KeyError.
        seg_kwargs['attention_mask'] = kwargs.get('attention_mask')
        # Hidden states are always needed to read the new memory back out.
        seg_kwargs['output_hidden_states'] = True
        return seg_kwargs

    def forward(
            self,
            input_ids,
            memory_state=None,
            text_mask=None,
            read_mem_mask=None,
            write_mem_mask=None,
            **kwargs
        ):
        """Run one segment through the base model.

        When ``memory_state`` is None the learned initial memory is used.

        Returns:
            ``(out, new_memory_state)`` as produced by ``process_output``.
        """
        if memory_state is None:
            memory_state = self.set_memory(input_ids.shape)

        seg_kwargs = self.process_input(
            input_ids,
            memory_state,
            write_mem=True,
            read_mem_mask=read_mem_mask,
            write_mem_mask=write_mem_mask,
            **kwargs
        )
        model_outputs = self.model(**seg_kwargs)
        return self.process_output(
            model_outputs,
            text_mask=text_mask,
            write_mem_mask=write_mem_mask,
            **kwargs
        )

    def generate(
        self,
        input_ids,
        memory_state,
        attention_mask=None,
        read_mem_mask=None,
        write_mem_mask=None,
        **generate_kwargs):
        """Generate from the base model with memory injected at the read positions only.

        Remaining keyword arguments are forwarded to ``self.model.generate``.
        """
        if memory_state is None:
            memory_state = self.set_memory(input_ids.shape)

        seg_kwargs = self.process_input(
            input_ids,
            memory_state,
            attention_mask=attention_mask,
            write_mem=False,
            read_mem_mask=read_mem_mask,
            write_mem_mask=write_mem_mask,
        )
        return self.model.generate(
            inputs_embeds=seg_kwargs['inputs_embeds'],
            attention_mask=seg_kwargs['attention_mask'],
            **generate_kwargs
        )

    def process_output(self, model_outputs, text_mask, write_mem_mask, **kwargs):
        """Split the base model output into text outputs and the new memory state.

        With memory enabled, logits (and optionally hidden states) are reduced
        to the ``text_mask`` positions, and the new memory state is read from
        the last hidden layer at the ``write_mem_mask`` positions. With
        ``num_mem_tokens`` equal to 0 or None the model output is returned
        untouched together with ``None``.
        """
        if self.num_mem_tokens in {0, None}:
            return model_outputs, None

        out = CausalLMOutputWithCrossAttentions()
        memory_state = self.extract_tensor_by_mask(model_outputs.hidden_states[-1], write_mem_mask)
        out['logits'] = self.extract_tensor_by_mask(model_outputs.logits, text_mask)

        if kwargs.get('output_hidden_states'):
            out['hidden_states'] = [self.extract_tensor_by_mask(lh, text_mask)
                                    for lh in model_outputs.hidden_states]

        if kwargs.get('output_attentions'):
            # Passed through unchanged. The former debug print called `.shape`
            # on this value, which fails when attentions come back as a tuple.
            out['attentions'] = model_outputs['attentions']

        return out, memory_state


class RecurrentWrapper(torch.nn.Module):
    """Runs a memory cell over a sequence of segments, carrying the memory state along.

    ``rmt_kwargs`` are stored as ``self.rmt_config``. Keys read by this class:
    ``max_n_segments``, ``k2``, ``bos_token_id``, ``eos_token_id``,
    ``think_token_id`` and ``answer_token_id``.
    """

    def __init__(self, memory_cell, **rmt_kwargs):
        super().__init__()
        self.memory_cell = memory_cell
        self.rmt_config = rmt_kwargs

    @staticmethod
    def _segment_loss(full_logits, segment, loss_fct):
        """Next-token loss for one segment, or 0 when nothing contributes to it."""
        labels = segment.get('labels')
        if labels is None:
            return 0

        shift_labels = labels[..., 1:].contiguous()
        shift_logits = full_logits[..., :-1, :].contiguous()
        flat_labels = shift_labels.view(-1)
        flat_logits = shift_logits.view(-1, shift_logits.size(-1))

        labels_mask = segment.get('labels_mask')
        if labels_mask is not None:
            keep = labels_mask[..., :-1].contiguous().view(-1)
            flat_labels = flat_labels[keep]
            flat_logits = flat_logits[keep]

        # Mean cross-entropy over zero valid targets is NaN, so report 0 instead.
        # This covers an all-False shifted mask as well as all-ignored labels.
        if not bool((flat_labels != loss_fct.ignore_index).any()):
            return 0
        return loss_fct(flat_logits, flat_labels)

    def process_outputs(self, cell_outputs, segments, **kwargs):
        """Combine per-segment cell outputs into a single output object.

        Returns:
            output with ``loss`` (mean of the per-segment losses) and ``logits``
            (per-segment logits concatenated along dim 1).
        """
        out = CausalLMOutputWithCrossAttentions()
        proxy_out = {}

        segment_keys = ['loss']
        if kwargs.get('output_attentions'):
            segment_keys.append('attentions')
        if kwargs.get('output_hidden_states'):
            segment_keys.append('hidden_states')

        loss_fct = CrossEntropyLoss()
        for seg_num, segment in enumerate(segments):
            cell_out = cell_outputs[seg_num]
            proxy_out[f'loss_{seg_num}'] = self._segment_loss(cell_out.logits, segment, loss_fct)

            # NOTE(review): per-segment attentions / hidden states are collected
            # here, but only the loss entries are read back below — confirm
            # whether they were meant to be exposed on `out`.
            for key, value in cell_out.items():
                if any(sk in key for sk in segment_keys):
                    proxy_out[f'{key}_{seg_num}'] = value

        num_segments = len(segments)
        out['loss'] = sum([proxy_out[f'loss_{seg_num}'] for seg_num in range(num_segments)]) / num_segments
        out['logits'] = torch.cat([cell_out.logits for cell_out in cell_outputs], dim=1)

        return out

    def forward(self, segments, labels, output_attentions=None, output_hidden_states=None):
        """Process all segments in order and return the combined output.

        Args:
            segments: list of dicts with ``input_ids``, ``attention_mask``,
                ``text_mask``, ``read_mem_mask`` and ``write_mem_mask``
                (``labels`` / ``labels_mask`` are optional and used for the loss).
            labels: accepted for interface compatibility; not read here, the
                loss uses the per-segment labels.
            output_attentions, output_hidden_states: forwarded to
                ``process_outputs`` only.
        """
        memory_state = None

        cell_outputs = []
        for seg_num, segment in enumerate(segments):
            cell_out, memory_state = self.memory_cell(
                input_ids=segment['input_ids'],
                attention_mask=segment['attention_mask'],
                text_mask=segment['text_mask'],
                read_mem_mask=segment['read_mem_mask'],
                write_mem_mask=segment['write_mem_mask'],
                memory_state=memory_state,
                output_hidden_states=True,
                )
            cell_outputs.append(cell_out)
            # Keep the returned tensor: `detach()` is not in-place, so dropping
            # the result would leave gradient truncation without any effect.
            memory_state = self.manage_gradients(memory_state, seg_num)

        return self.process_outputs(
            cell_outputs,
            segments,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states
        )

    def generate(self, segments, **kwargs):
        """Encode ``segments`` into memory, then generate further segments.

        Generation stops once every batch row has produced ``eos_token_id`` or
        ``max_n_segments`` (default 32) segments exist in total.

        Returns:
            list of generated token-id tensors, one per generated segment.
        """
        memory_state = None

        for segment in segments:
            # The memory masks are passed exactly as in `forward`; the cell
            # cannot place or read memory without them.
            _, memory_state = self.memory_cell(
                input_ids=segment['input_ids'],
                attention_mask=segment['attention_mask'],
                text_mask=segment['text_mask'],
                read_mem_mask=segment['read_mem_mask'],
                write_mem_mask=segment['write_mem_mask'],
                memory_state=memory_state,
                output_hidden_states=True,
            )

        generated_segments = []
        for _ in range(len(segments), self.rmt_config.get("max_n_segments", 32)):
            output_ids, memory_state = self.generate_segment(memory_state=memory_state, **kwargs)
            generated_segments.append(output_ids)

            if self.all_done(generated_segments):
                break

        return generated_segments

    def generate_segment(self, memory_state, **kwargs):
        """Generate one segment from ``memory_state`` and update the memory.

        The prompt is ``[memory slots] + [BOS]``. After generation, the memory
        is refreshed with a forward pass over
        ``[memory slots] + [BOS, generated[:-1]] + [memory slots]``.

        NOTE(review): the slot positions are filled with the BOS id as a
        placeholder, assuming the cell overwrites the embeddings at masked
        positions (as ``MemoryCellSmart`` does). Padding that ``generate`` may
        append to finished rows is attended to in the update pass — confirm
        this layout matches the one used in training.

        Returns:
            ``(generated, new_memory_state)``.
        """
        num_mem = memory_state.shape[1]
        bos_ids = self.get_bos_tensor(memory_state)
        mem_slots = bos_ids.expand(-1, num_mem)

        prompt_ids = torch.cat((mem_slots, bos_ids), dim=1)
        prompt_read_mask = torch.zeros_like(prompt_ids, dtype=torch.bool)
        prompt_read_mask[:, :num_mem] = True

        generated = self.memory_cell.generate(
            input_ids=prompt_ids,
            attention_mask=torch.ones_like(prompt_ids).bool(),
            memory_state=memory_state,
            read_mem_mask=prompt_read_mask,
            eos_token_id=[
                self.rmt_config['eos_token_id'],
                self.rmt_config['think_token_id'],
                self.rmt_config['answer_token_id']
            ],
            **kwargs
        )

        # Update memory state from generation
        text_ids = torch.cat((bos_ids, generated), dim=1)[:, :-1]
        fwd_inputs = torch.cat((mem_slots, text_ids, mem_slots), dim=1)
        seq_len = fwd_inputs.shape[1]

        read_mem_mask = torch.zeros_like(fwd_inputs, dtype=torch.bool)
        read_mem_mask[:, :num_mem] = True
        write_mem_mask = torch.zeros_like(fwd_inputs, dtype=torch.bool)
        write_mem_mask[:, seq_len - num_mem:] = True
        text_mask = ~(read_mem_mask | write_mem_mask)

        _, memory_state = self.memory_cell(
            input_ids=fwd_inputs,
            attention_mask=torch.ones_like(fwd_inputs),
            text_mask=text_mask,
            read_mem_mask=read_mem_mask,
            write_mem_mask=write_mem_mask,
            memory_state=memory_state,
        )

        return generated, memory_state

    def get_bos_tensor(self, memory_state):
        """Return a ``(batch, 1)`` tensor of BOS ids on the device of ``memory_state``."""
        bos = self.rmt_config["bos_token_id"]
        bos_tensor = torch.tensor([bos] * memory_state.shape[0]).reshape(-1, 1)
        return bos_tensor.to(memory_state.device)

    def all_done(self, generated_segments):
        """True when every batch row contains ``eos_token_id`` in at least one generated segment."""
        eos = self.rmt_config['eos_token_id']
        bs = generated_segments[0].shape[0]
        return all(
            any(eos in seg[i] for seg in generated_segments)
            for i in range(bs)
        )

    def manage_gradients(self, memory_state, seg_num):
        """Return ``memory_state``, detached when gradient truncation applies.

        The state is returned unchanged for the first segment, when ``k2`` is
        -1 or unset, or when ``seg_num + k2 > max_n_segments``. In every other
        case a detached copy is returned.
        """
        k2, max_n_segments = self.rmt_config.get('k2'), self.rmt_config.get('max_n_segments')
        if memory_state is None \
            or seg_num == 0 \
                or k2 in {-1, None} \
                or seg_num + k2 > max_n_segments:
            return memory_state

        return memory_state.detach()

    def gradient_checkpointing_enable(self, *args, **kwargs):
        """Forward to the base model's ``gradient_checkpointing_enable`` when it has one."""
        if hasattr(self.memory_cell.model, "gradient_checkpointing_enable"):
            return self.memory_cell.model.gradient_checkpointing_enable(*args, **kwargs)


In [126]:
class Holder:
    """Empty attribute container standing in for parsed command-line arguments."""


# Experiment settings read by the cells below.
args = Holder()
vars(args).update(
    use_cot=True,
    num_mem_tokens=16,
    segment_size=64,
    max_n_segments=10,
    max_cot_steps=8,
    task_name='gsm8k',
)

In [127]:
model_name = "HuggingFaceTB/SmolLM2-135M"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Special ids are kept as lists so they concatenate directly with token-id lists.
pad = [tokenizer.pad_token_id] if tokenizer.pad_token_id is not None else [tokenizer.eos_token_id]
bos = [tokenizer.bos_token_id]
eos = [tokenizer.eos_token_id]

# Existing vocabulary entries are reused as control markers. They are encoded
# without automatic special tokens (as in `collate_fn`) so that element 0 of
# each list is the marker itself and never a prepended BOS.
think = tokenizer.encode("<issue_start>", add_special_tokens=False)
ans = tokenizer.encode("<issue_closed>", add_special_tokens=False)
mem_token = tokenizer.encode("<empty_output>", add_special_tokens=False)

# Delimiter passed to `split_cot` when cutting the chain of thought into steps.
if 'gsm8k' in args.task_name:
    delim = ">> <<"
elif 'multiplication' in args.task_name:
    delim = ' + '
else:
    raise NotImplementedError(f"Unknown task name {args.task_name}")

In [128]:
# The memory size must equal the number of memory slots `collate_fn` puts around
# every segment, so both read the same setting.
memory_cell = MemoryCellSmart(
    model,
    num_mem_tokens=args.num_mem_tokens
)

# NOTE(review): `max_n_segments` is 8 here while `args.max_n_segments` is 10 — confirm which is intended.
rmt = RecurrentWrapper(
    memory_cell, 
    max_n_segments=8,
    think_token_id=think[0],
    answer_token_id=ans[0],
    bos_token_id=bos[0],
    eos_token_id=eos[0]
)

# checkpoint_path = "/home/user33/kashurin/RMT_SmolLM2-135M/pt/checkpoint-29500/pytorch_model.bin"
# checkpoint_path = "/home/user33/kashurin/RMT_SmolLM2-135M/cot/checkpoint-2200/pytorch_model.bin"
# Fall back to CPU so the notebook still runs on a machine without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# rmt.load_state_dict(torch.load(checkpoint_path), strict=False)
rmt.to(device)
print(':)')

:)


In [114]:
def make_segment_with_mem(input_tokens, loss=False, mem_token=None, num_mem_tokens=0):
    """Wrap a token list in memory slots and build the matching masks.

    The sequence layout is ``[mem] * num_mem_tokens + input_tokens + [mem] * num_mem_tokens``.

    Args:
        input_tokens: list of token ids for the text part of the segment.
        loss: whether this segment contributes to the loss. When false, the
            labels are all -100 and ``labels_mask`` is all False.
        mem_token: token id placed in the memory slots.
        num_mem_tokens: number of memory slots on each side; 0 disables them.

    Returns:
        dict of 1-D tensors. ``input_ids``, ``attention_mask`` and the three
        boolean position masks (``text_mask``, ``read_mem_mask``,
        ``write_mem_mask``) cover the full sequence. ``labels`` and
        ``labels_mask`` cover only the text part.
    """
    memory_slots = [mem_token] * num_mem_tokens
    input_ids = torch.tensor(memory_slots + input_tokens + memory_slots, dtype=torch.long)
    seq_len = input_ids.shape[0]
    text_len = len(input_tokens)

    read_mem_mask = torch.zeros(seq_len, dtype=torch.bool)
    read_mem_mask[:num_mem_tokens] = True

    write_mem_mask = torch.zeros(seq_len, dtype=torch.bool)
    # Explicit start index: `[-num_mem_tokens:]` would select the WHOLE
    # sequence when num_mem_tokens is 0.
    write_mem_mask[seq_len - num_mem_tokens:] = True

    text_mask = ~(read_mem_mask | write_mem_mask)

    if loss:
        labels = torch.tensor(input_tokens, dtype=torch.long)
    else:
        labels = torch.full((text_len,), -100, dtype=torch.long)
    labels_mask = torch.full((text_len,), bool(loss), dtype=torch.bool)

    return {'input_ids': input_ids,
            'attention_mask': torch.ones_like(input_ids),
            'labels': labels,
            'labels_mask': labels_mask,
            'text_mask': text_mask,
            'read_mem_mask': read_mem_mask,
            'write_mem_mask': write_mem_mask
            }


def _sample_to_segments(sample):
    """Turn one dataset sample into its list of memory-wrapped segments.

    Order: the task (no loss), every chain-of-thought step (loss), and the
    answer labels (loss). All steps but the last end with the ``think`` marker,
    the last step ends with ``ans``, and the labels segment ends with ``eos``.
    """
    def wrap(tokens, loss):
        return make_segment_with_mem(
            tokens,
            loss=loss,
            mem_token=mem_token[0],
            num_mem_tokens=args.num_mem_tokens,
        )

    task_tokens = tokenizer.encode(sample['task'], add_special_tokens=False)
    labels_tokens = tokenizer.encode(sample['labels'], add_special_tokens=False)

    cot = sample['cot']
    cot_segments = split_cot(cot, by=delim) if getattr(args, 'use_cot', False) else [cot]
    cot_segment_tokens = tokenizer.batch_encode_plus(cot_segments, add_special_tokens=False)['input_ids']

    segments = [wrap(bos + task_tokens + think, loss=False)]
    segments.extend(wrap(bos + step + think, loss=True) for step in cot_segment_tokens[:-1])
    segments.append(wrap(bos + cot_segment_tokens[-1] + ans, loss=True))
    segments.append(wrap(bos + labels_tokens + eos, loss=True))
    return segments


def collate_fn(batch):
    """Collate samples into a list of padded, batched segments.

    Relies on the notebook globals ``tokenizer``, ``args``, ``delim``, ``bos``,
    ``eos``, ``think``, ``ans``, ``mem_token`` and ``pad``.

    Args:
        batch: list of samples with string fields ``task``, ``labels`` and ``cot``.

    Returns:
        dict with ``segments`` (one dict of batched tensors per segment index)
        and ``labels`` (all segment labels concatenated along dim 1).
    """
    # first, we segment each sample into task, cot steps and labels
    segments_batch = [_sample_to_segments(sample) for sample in batch]

    # if some samples have less segments than others, we pad them with empty segments
    num_segments = max(len(segments) for segments in segments_batch)
    for segments in segments_batch:
        missing = num_segments - len(segments)
        segments.extend(
            make_segment_with_mem(
                eos,
                loss=False,
                mem_token=mem_token[0],
                num_mem_tokens=args.num_mem_tokens,
            )
            for _ in range(missing)
        )

    # Padding value per field. `input_ids` uses the id from `pad`; the previous
    # `id_pad_value` name is not defined anywhere in this notebook.
    # `text_mask` is padded with True so that every row selects the same number
    # of positions when batched.
    pad_values = {
        'input_ids': pad[0],
        'attention_mask': 0,
        'labels_mask': False,
        'labels': -100,
        'text_mask': True,
        'read_mem_mask': False,
        'write_mem_mask': False,
    }

    # prepare segments for the whole batch
    batch_segments = [
        {
            key: pad_sequence(
                [sample_segments[i][key] for sample_segments in segments_batch],
                batch_first=True,
                padding_value=value,
            )
            for key, value in pad_values.items()
        }
        for i in range(num_segments)
    ]

    full_labels = torch.cat([s['labels'] for s in batch_segments], dim=1)
    return {"segments": batch_segments, 'labels': full_labels}

In [10]:
# Dataset identifier passed to `datasets.load_dataset`; the 'train' and 'valid' splits are loaded.
dataset = 'booydar/gsm8k'
train_dataset, valid_dataset = (
    datasets.load_dataset(dataset, split=split_name)
    for split_name in ('train', 'valid')
)

In [63]:
def collated_batch_to_device(collated):
    """Move every tensor of a collated batch to the global ``device``.

    ``collated`` is updated in place: its ``segments`` list is replaced by a
    new list of moved segment dicts and ``labels`` by the moved tensor. The
    same dict object is returned.
    """
    collated["segments"] = [
        {name: tensor.to(device) for name, tensor in segment.items()}
        for segment in collated["segments"]
    ]
    collated["labels"] = collated["labels"].to(device)
    return collated

In [77]:
# Count the True entries in the text mask of the first sample of the first segment.
# NOTE(review): `collated` is only created by a cell further down in this notebook,
# so this cell fails on a fresh top-to-bottom run.
collated["segments"][0]["text_mask"][0].sum()

tensor(63, device='cuda:0')

In [129]:
# Smoke-test batch: the first two validation samples, collated and moved to `device`.
batch = [valid_dataset[index] for index in (0, 1)]

collated = collated_batch_to_device(collate_fn(batch))

In [130]:
# Full forward pass over every segment of the collated batch.
# NOTE(review): `RecurrentWrapper.forward` does not hand `output_attentions` to the
# memory cell, so confirm the flag has the intended effect.
output = rmt(
    segments=collated["segments"],
    labels=collated["labels"],
    output_attentions=True,
)