In [1]:
import numpy as np
import torch
import torch.nn.functional as F
from typing import List, Optional, Tuple, Union
from transformers import PreTrainedModel, AutoModelForSequenceClassification

import math

from typing import List, Optional, Tuple, Union
from transformers import BertForSequenceClassification
import transformers
from transformers.modeling_outputs import SequenceClassifierOutput

### Finetune

In [4]:
from transformers import AutoModelForSequenceClassification
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
import sys
sys.path.append('..')
# from modeling_rmt import RMTEncoderForSequenceClassification

In [5]:
import math

class RMTEncoderForSequenceClassification():
    def __init__(self, config=None, base_model=None, **kwargs):
        if config is not None:
            self.model = AutoModelForSequenceClassification(config, **kwargs)
        
        if base_model is not None:
            self.model = base_model


    def from_pretrained(from_pretrained, **kwargs):
        base_model = AutoModelForSequenceClassification.from_pretrained(from_pretrained, **kwargs)
        rmt = RMTEncoderForSequenceClassification(base_model=base_model)
        return rmt
        

    def set_params(self, 
                drop_empty_segments=True,
                sum_loss=False,
                input_size=None, 
                input_seg_size=None, 
                backbone_cls=None,
                num_mem_tokens=0, 
                bptt_depth=-1, 
                pad_token_id=0, 
                eos_token_id=1,
                cls_token_id=101, 
                sep_token_id=102):
        if input_size is not None:
            self.input_size = input_size
        else:
            self.input_size =  self.base_model.embeddings.position_embeddings.weight.shape[0]
        self.input_seg_size = input_seg_size

        self.bptt_depth = bptt_depth
        self.pad_token_id = pad_token_id
        self.cls_token = torch.tensor([cls_token_id])
        self.sep_token = torch.tensor([sep_token_id])
        self.num_mem_tokens = num_mem_tokens
        self.drop_empty_segments = drop_empty_segments
        self.sum_loss = sum_loss
        self.extend_word_embeddings()


    def set_memory(self, memory=None):
        if memory is None:
            mem_token_ids = self.mem_token_ids.to(device=self.device)
            memory = self.base_model.embeddings.word_embeddings(mem_token_ids)
        return memory
    
    def extend_word_embeddings(self):
        vocab_size = self.base_model.embeddings.word_embeddings.weight.shape[0]
        extended_vocab_size = vocab_size + self.num_mem_tokens
        self.mem_token_ids = torch.arange(vocab_size, vocab_size + self.num_mem_tokens)
        self.base_model.resize_token_embeddings(extended_vocab_size)


    def __call__(self, input_ids, **kwargs):
        memory = self.set_memory()
        segmented = self.pad_and_segment(input_ids)

        outputs = []
        for seg_num, segment_data in enumerate(zip(*segmented)):
            input_ids, attention_mask, token_type_ids = segment_data
            if memory.ndim == 2:
                memory = memory.repeat(input_ids.shape[0], 1, 1)
            if (self.bptt_depth > -1) and (len(segmented) - seg_num > self.bptt_depth): 
                memory = memory.detach()

            seg_kwargs = dict(**kwargs)
            if self.drop_empty_segments:

                non_empty_mask = [not torch.equal(input_ids[i], self.empty) for i in range(len(input_ids))]
                if sum(non_empty_mask) == 0:
                    continue
                input_ids = input_ids[non_empty_mask]
                attention_mask = attention_mask[non_empty_mask]
                token_type_ids = token_type_ids[non_empty_mask]
                seg_kwargs['labels'] = seg_kwargs['labels'][non_empty_mask]

                inputs_embeds = self.base_model.embeddings.word_embeddings(input_ids)
                inputs_embeds[:, 1:1+self.num_mem_tokens] = memory[non_empty_mask]
            else:
                inputs_embeds = self.base_model.embeddings.word_embeddings(input_ids)
                inputs_embeds[:, 1:1+self.num_mem_tokens] = memory

            seg_kwargs['inputs_embeds'] = inputs_embeds
            seg_kwargs['attention_mask'] = attention_mask
            seg_kwargs['token_type_ids'] = token_type_ids
            
            out = self.model.forward(**seg_kwargs, output_hidden_states=True)
            outputs.append(out)

            if self.drop_empty_segments:
                memory[non_empty_mask] = out.hidden_states[-1][:, :self.num_mem_tokens]
            else:
                memory = out.hidden_states[-1][:, :self.num_mem_tokens]

        if self.sum_loss:
            out['loss'] = torch.stack([o['loss'] for o in outputs]).sum(dim=-1)

        return out

    def pad_and_segment(self, input_ids):
        
        sequence_len = input_ids.shape[1]
        input_seg_size = self.input_size - self.num_mem_tokens - 3 
        if self.input_seg_size is not None and self.input_seg_size < input_seg_size:
            input_seg_size = self.input_seg_size
            
        n_segments = math.ceil(sequence_len / input_seg_size)

        augmented_inputs = []
        for input in input_ids:
            input = input[input != self.pad_token_id][1:-1]

            seg_sep_inds = [0] + list(range(len(input), 0, -input_seg_size))[::-1] # chunk so that first segment has various size
            input_segments = [input[s:e] for s, e in zip(seg_sep_inds, seg_sep_inds[1:])]

            def pad_add_special_tokens(tensor, seg_size):
                tensor = torch.cat([self.cls_token.to(device=self.device),
                                    self.mem_token_ids.to(device=self.device),
                                    self.sep_token.to(device=self.device),
                                    tensor.to(device=self.device),
                                    self.sep_token.to(device=self.device)])
                pad_size = seg_size - tensor.shape[0]
                if pad_size > 0:
                    tensor = F.pad(tensor, (0, pad_size))
                return tensor

            input_segments = [pad_add_special_tokens(t, self.input_size) for t in input_segments]
            empty = torch.Tensor([]).int()
            self.empty = pad_add_special_tokens(empty, self.input_size)
            empty_segments = [self.empty for i in range(n_segments - len(input_segments))]
            input_segments = empty_segments + input_segments

            augmented_input = torch.cat(input_segments)
            augmented_inputs.append(augmented_input)
            
        augmented_inputs = torch.stack(augmented_inputs)
        attention_mask = torch.ones_like(augmented_inputs)
        attention_mask[augmented_inputs == self.pad_token_id] = 0

        token_type_ids = torch.zeros_like(attention_mask)

        input_segments = torch.chunk(augmented_inputs, n_segments, dim=1)
        attention_mask = torch.chunk(attention_mask, n_segments, dim=1)
        token_type_ids = torch.chunk(token_type_ids, n_segments, dim=1)
    
        return input_segments, attention_mask, token_type_ids


    def to(self, device):
        self.model = self.model.to(device)
        
    
    def cuda(self):
        self.model.cuda()


    def __getattr__(self, attribute):
        return getattr(self.model, attribute)


    def parameters(self, **kwargs):
        return self.model.parameters(**kwargs)

    def named_parameters(self, **kwargs):
        return self.model.named_parameters(**kwargs)


In [11]:
# pretrained_model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=5, output_hidden_states=True)

In [8]:
model_name = "google/bert_uncased_L-4_H-256_A-4"
# model_name = 'google/electra-base-discriminator'

In [7]:
num_segments = 2
num_mem_tokens = 10

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_name)

tokenizer.model_max_length  = (tokenizer.model_max_length - num_mem_tokens) * num_segments
tokenizer.padding_side = 'left'

In [9]:
rmt = RMTEncoderForSequenceClassification.from_pretrained(model_name, num_labels=3)

Downloading:   0%|          | 0.00/43.0M [00:00<?, ?B/s]

Some weights of the model checkpoint at google/bert_uncased_L-4_H-256_A-4 were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification w

In [10]:
rmt.set_params(
                drop_empty_segments=True,
                sum_loss=False,
                input_size=None, 
                input_seg_size=None, 
                backbone_cls=None,
                num_mem_tokens=0, 
                bptt_depth=-1, 
                pad_token_id=0, 
                eos_token_id=1,
                cls_token_id=101, 
                sep_token_id=102)

### load dataset 

In [11]:
input_seq_len = 512
target_seq_len = 2

In [12]:
encode_plus_kwargs = {'max_length': input_seq_len,
                              'truncation': True,
                              'padding': 'longest',
                              'pad_to_multiple_of': 64}
generate_kwargs = {}
labels_map = {'Contradiction': 0, 'Entailment': 1, 'Not mentioned': 2}
num_labels = len(labels_map)

def collate_fn(batch):
    # cut too long strings because they may slow down tokenization
    inputs = [b['input'][:input_seq_len * 10] for b in batch]
    labels = [b['output'][:target_seq_len * 10] for b in batch]
    features = tokenizer.batch_encode_plus(list(inputs), return_tensors='pt', **encode_plus_kwargs)
    labels = np.array([labels_map[t] for t in labels])
    features['labels'] = torch.from_numpy(labels)
    return features

In [13]:
import datasets
dataset = datasets.load_dataset('tau/scrolls', 'contract_nli')
train_dataset = dataset['train']

Reusing dataset scrolls (/home/bulatov/.cache/huggingface/datasets/tau___scrolls/contract_nli/1.0.0/672021d5d8e1edff998a6ea7a5bff35fdfd0ae243e7cf6a8c88a57a04afb46ac)


  0%|          | 0/3 [00:00<?, ?it/s]

In [14]:
train_sampler = RandomSampler(train_dataset,)
kwargs = {'pin_memory': True, 'num_workers': 0}
train_dataloader = DataLoader(train_dataset, batch_size=2, sampler=train_sampler,
                                collate_fn=collate_fn, **kwargs)

In [22]:
gen = iter(train_dataloader)
sample = next(gen)

## Override forward

In [23]:
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions

def _forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
        memory_storage = None
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                # if use_cache:
                #     logger.warning(
                #         "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                #     )
                #     use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, past_key_value, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                num_mem = memory_storage['num_mem_tokens']
                if i in memory_storage:
                    layer_memory = memory_storage[i]
                    for j, h in enumerate(hidden_states):
                        hidden_states[j][:layer_memory[j].shape[0]] = layer_memory[j]

                print(f'hidden states shape: {len(hidden_states), hidden_states[0].shape}\n memory storage:{memory_storage.keys()}')
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )
                

            hidden_states = layer_outputs[0]
            if i in memory_storage:
                print(f'replacing ms[i] {memory_storage[i][0][0][:10]}... to {[h[:num_mem] for h in hidden_states][0][0][:10]}')
            memory_storage[i] = [h[:num_mem] for h in hidden_states]

            # memory_storage['success'] = True
            # print(f'Overrided method message: hidden states shape: {len(hidden_states), hidden_states[0].shape}\n memory storage:{memory_storage}')
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )

In [24]:

import types
self = rmt

memory_storage = {'num_mem_tokens': 10}
self.base_model.encoder.forward = types.MethodType(lambda *args, **kwargs: _forward(*args, **kwargs, memory_storage=memory_storage), self.base_model.encoder)

In [26]:
sample

{'input_ids': tensor([[  101,  4909,  2283,  ...,  1017,  1012,   102],
         [  101,  2070, 14422,  ...,  2025,  3024,   102]]),
 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]]),
 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
         [1, 1, 1,  ..., 1, 1, 1]]),
 'labels': tensor([1, 2])}

In [27]:
seg_kwargs

{'pin_memory': True, 'num_workers': 0}

In [29]:
kwargs = sample.copy()
# kwargs

In [34]:
input_ids = kwargs.pop('input_ids')

memory = self.set_memory()
segmented = self.pad_and_segment(input_ids)

outputs = []
for seg_num, segment_data in enumerate(zip(*segmented)):
    input_ids, attention_mask, token_type_ids = segment_data
    if memory.ndim == 2:
        memory = memory.repeat(input_ids.shape[0], 1, 1)
    if (self.bptt_depth > -1) and (len(segmented) - seg_num > self.bptt_depth): 
        memory = memory.detach()

    seg_kwargs = dict(**kwargs)
    if self.drop_empty_segments:

        non_empty_mask = [not torch.equal(input_ids[i], self.empty) for i in range(len(input_ids))]
        if sum(non_empty_mask) == 0:
            continue
        input_ids = input_ids[non_empty_mask]
        attention_mask = attention_mask[non_empty_mask]
        token_type_ids = token_type_ids[non_empty_mask]
        seg_kwargs['labels'] = seg_kwargs['labels'][non_empty_mask]

        inputs_embeds = self.base_model.embeddings.word_embeddings(input_ids)
        inputs_embeds[:, 1:1+self.num_mem_tokens] = memory[non_empty_mask]
    else:
        inputs_embeds = self.base_model.embeddings.word_embeddings(input_ids)
        inputs_embeds[:, 1:1+self.num_mem_tokens] = memory

    seg_kwargs['inputs_embeds'] = inputs_embeds
    seg_kwargs['attention_mask'] = attention_mask
    seg_kwargs['token_type_ids'] = token_type_ids
    
    out = self.model.forward(**seg_kwargs, output_hidden_states=True)
    outputs.append(out)

    if self.drop_empty_segments:
        memory[non_empty_mask] = out.hidden_states[-1][:, :self.num_mem_tokens]
    else:
        memory = out.hidden_states[-1][:, :self.num_mem_tokens]

if self.sum_loss:
    out['loss'] = torch.stack([o['loss'] for o in outputs]).sum(dim=-1)

hidden states shape: (2, torch.Size([512, 256]))
 memory storage:dict_keys(['num_mem_tokens'])
hidden states shape: (2, torch.Size([512, 256]))
 memory storage:dict_keys(['num_mem_tokens', 0])
hidden states shape: (2, torch.Size([512, 256]))
 memory storage:dict_keys(['num_mem_tokens', 0, 1])
hidden states shape: (2, torch.Size([512, 256]))
 memory storage:dict_keys(['num_mem_tokens', 0, 1, 2])
hidden states shape: (2, torch.Size([512, 256]))
 memory storage:dict_keys(['num_mem_tokens', 0, 1, 2, 3])
replacing ms[i] tensor([ 0.4910, -0.3252,  0.6379,  0.4366,  0.9436,  0.3721, -0.4503,  0.2484,
        -5.6303,  0.0416], grad_fn=<SliceBackward>)... to tensor([ 0.5875, -0.4431,  0.4703,  0.1823,  0.2762, -0.0980, -0.7847, -0.3769,
        -4.0026, -0.0786], grad_fn=<SliceBackward>)
hidden states shape: (2, torch.Size([512, 256]))
 memory storage:dict_keys(['num_mem_tokens', 0, 1, 2, 3])
replacing ms[i] tensor([ 0.4131, -0.2734,  0.5131,  0.4519,  0.8615, -0.4164, -0.4695,  0.1083,
      

In [25]:
# out = self.model.forward(**seg_kwargs, output_hidden_states=True)
# out

In [None]:
# out['hidden_states']

In [None]:
out = forward(self.model, **seg_kwargs, output_hidden_states=True, use_cache=True)
out.keys()

odict_keys(['loss', 'logits', 'hidden_states'])

In [None]:
from transformers import TrainingArguments

training_args = TrainingArguments(output_dir="test_trainer")

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


In [None]:
import numpy as np
from datasets import load_metric

metric = load_metric("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

In [None]:
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(per_device_train_batch_size=2, per_device_eval_batch_size=1, output_dir="test_trainer", evaluation_strategy="epoch", no_cuda=True, max_steps=5)

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


In [None]:
# augmented_inputs = np.load('augmented_inputs.npy', allow_pickle=True)
# attn_masks = np.load('attention_masks.npy', allow_pickle=True)
# tokenizer.decode(augmented_inputs[0][512:])

In [None]:
trainer = Trainer(
    model=rmt,
    # model=classic_bert,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_metrics,
)
trainer.train()

max_steps is given, it will override any value given in num_train_epochs
The following columns in the training set  don't have a corresponding argument in `RMTEncoderForSequenceClassification.forward` and have been ignored: text. If text are not expected by `RMTEncoderForSequenceClassification.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 1000
  Num Epochs = 1
  Instantaneous batch size per device = 2
  Total train batch size (w. parallel, distributed & accumulation) = 2
  Gradient Accumulation steps = 1
  Total optimization steps = 5










[A[A[A[A[A[A[A[A[A[A

(tensor([[  101, 29206, 29207,  ...,     0,     0,     0],
        [  101, 29206, 29207,  ...,     0,     0,     0]]), tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]), tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]))


  0%|          | 0/5 [21:23<?, ?it/s]
  0%|          | 0/5 [20:56<?, ?it/s]
  0%|          | 0/5 [14:10<?, ?it/s]
  0%|          | 0/5 [11:58<?, ?it/s]
  0%|          | 0/5 [07:14<?, ?it/s]
  0%|          | 0/5 [06:41<?, ?it/s]
  0%|          | 0/5 [04:25<?, ?it/s]
  0%|          | 0/5 [03:41<?, ?it/s]
  0%|          | 0/5 [02:55<?, ?it/s]
  0%|          | 0/5 [02:18<?, ?it/s]
  0%|          | 0/5 [01:49<?, ?it/s]


KeyboardInterrupt: 

In [None]:
sampler = SequentialSampler(small_train_dataset)
dl = DataLoader(small_train_dataset, sampler=sampler, batch_size=4)
gen = dl.__iter__()
s = next(gen)