In [142]:
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions

def _forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
        memory_storage = None
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                # if use_cache:
                #     logger.warning(
                #         "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                #     )
                #     use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, past_key_value, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                num_mem = memory_storage['num_mem_tokens']
                if i in memory_storage:
                    layer_memory = memory_storage[i]
                    for j, h in enumerate(hidden_states):
                        hidden_states[j][:layer_memory[j].shape[0]] = layer_memory[j]

                print(f'hidden states shape: {len(hidden_states), hidden_states[0].shape}\n memory storage:{memory_storage.keys()}')
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )
                

            hidden_states = layer_outputs[0]
            if i in memory_storage:
                print(f'replacing ms[i] {memory_storage[i][0][0][:10]}... to {[h[:num_mem] for h in hidden_states][0][0][:10]}')
            memory_storage[i] = [h[:num_mem] for h in hidden_states]

            # memory_storage['success'] = True
            # print(f'Overrided method message: hidden states shape: {len(hidden_states), hidden_states[0].shape}\n memory storage:{memory_storage}')
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )

In [143]:
# import types
# def create_memory_forward(module, memory_storage):
#     def memory_forward(*args, **kwargs):
#         return module(*args, **kwargs, memory_storage=memory_storage)

#     return memory_forward
# self.base_model.encoder.forward = types.MethodType(create_memory_forward(self.base_model.encoder.forward, memory_storage), self.base_model.encoder)


In [144]:
import types
self = rmt

memory_storage = {'num_mem_tokens': 10}
self.base_model.encoder.forward = types.MethodType(lambda *args, **kwargs: _forward(*args, **kwargs, memory_storage=memory_storage), self.base_model.encoder)

In [145]:
input_ids = sample['input_ids']

memory = self.set_memory()
segmented = self.pad_and_segment(input_ids)

outputs = []
for seg_num, segment_data in enumerate(zip(*segmented)):
    input_ids, attention_mask, token_type_ids = segment_data
    if memory.ndim == 2:
        memory = memory.repeat(input_ids.shape[0], 1, 1)
    if (self.bptt_depth > -1) and (len(segmented) - seg_num > self.bptt_depth): 
        memory = memory.detach()

    seg_kwargs = dict(**kwargs)
    if self.drop_empty_segments:

        non_empty_mask = [not torch.equal(input_ids[i], self.empty) for i in range(len(input_ids))]
        if sum(non_empty_mask) == 0:
            continue
        input_ids = input_ids[non_empty_mask]
        attention_mask = attention_mask[non_empty_mask]
        token_type_ids = token_type_ids[non_empty_mask]
        seg_kwargs['labels'] = seg_kwargs['labels'][non_empty_mask]

        inputs_embeds = self.base_model.embeddings.word_embeddings(input_ids)
        inputs_embeds[:, 1:1+self.num_mem_tokens] = memory[non_empty_mask]
    else:
        inputs_embeds = self.base_model.embeddings.word_embeddings(input_ids)
        inputs_embeds[:, 1:1+self.num_mem_tokens] = memory

    seg_kwargs['inputs_embeds'] = inputs_embeds
    seg_kwargs['attention_mask'] = attention_mask
    seg_kwargs['token_type_ids'] = token_type_ids
    
    out = self.model.forward(**seg_kwargs, output_hidden_states=True)
    outputs.append(out)

    if self.drop_empty_segments:
        memory[non_empty_mask] = out.hidden_states[-1][:, :self.num_mem_tokens]
    else:
        memory = out.hidden_states[-1][:, :self.num_mem_tokens]

if self.sum_loss:
    out['loss'] = torch.stack([o['loss'] for o in outputs]).sum(dim=-1)

hidden states shape: (2, torch.Size([512, 256]))
 memory storage:dict_keys(['num_mem_tokens'])
hidden states shape: (2, torch.Size([512, 256]))
 memory storage:dict_keys(['num_mem_tokens', 0])
hidden states shape: (2, torch.Size([512, 256]))
 memory storage:dict_keys(['num_mem_tokens', 0, 1])
hidden states shape: (2, torch.Size([512, 256]))
 memory storage:dict_keys(['num_mem_tokens', 0, 1, 2])


found memory!!
hidden states shape: (2, torch.Size([512, 256]))
 memory storage:dict_keys(['num_mem_tokens', 0, 1, 2, 3])
replacing ms[i] tensor([ 0.4674, -0.2865,  0.6345,  0.4214,  0.9536,  0.4036, -0.4627,  0.2310,
        -5.6048, -0.0293], grad_fn=<SliceBackward0>)... to tensor([ 0.7673, -0.3799,  0.3761,  0.2463,  0.4095, -0.2411, -0.8653, -0.4908,
        -3.7968, -0.1928], grad_fn=<SliceBackward0>)


found memory!!
hidden states shape: (2, torch.Size([512, 256]))
 memory storage:dict_keys(['num_mem_tokens', 0, 1, 2, 3])
replacing ms[i] tensor([ 3.3768e-01, -3.1988e-01,  5.4384e-01,  4.