In [None]:
import os
from typing import List, Optional, Tuple, Union

import torch
from transformers import BertConfig, BertTokenizer

from src.models.components.bert.modeling_bert import (
    BertModel,
    BertOnlyMLMHead,
    BertPooler,
    BertPreTrainedModel,
    CausalLMOutputWithCrossAttentions,
)

In [None]:
# `cache_dir` is used as given, so a literal "~" would become a directory named "~"
# under the current working directory; expand it to the real home directory first.
tokenizer = BertTokenizer.from_pretrained(
    "bert-base-chinese", cache_dir=os.path.expanduser("~/.cache")
)

In [None]:
# Expand "~" explicitly: `cache_dir` is used as given, and an unexpanded tilde would
# create a "~" directory relative to the current working directory.
config = BertConfig.from_pretrained(
    "bert-base-chinese", cache_dir=os.path.expanduser("~/.cache")
)

In [None]:
# Illustrative layout of one sentence pair with special tokens and trailing padding.
# This string is not referenced by any other cell visible in this notebook.
example = "[CLS]hello, my dog is cute.[SEP]hi, this is my lovely dog.[SEP][PAD][PAD]"

In [None]:
# Tokenize two sentence pairs as a single batch of PyTorch tensors.
# The two pairs are not guaranteed to tokenize to the same length, and
# `return_tensors="pt"` can only stack rows of equal length, so pad the batch to its
# longest member. Later cells rely on this: `get_simbert_labels` masks pad positions.
inputs = tokenizer(
    [
        ("hello, my dog is cute.", "hi, this is my lovely dog."),
        ("good luck to you.", "hope you good luck."),
    ],
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=22,
    return_special_tokens_mask=True,
)

In [None]:
# Keep the [SEP] id and the batch of token ids at hand for the exploratory cells below.
sep_id = tokenizer.sep_token_id
input_ids = inputs["input_ids"]

In [None]:
# Position of the last [SEP] in the first sequence, left as a 0-d tensor (no `.item()`).
# Indexing the `nonzero()` result with `[0][0]` instead would give the first [SEP].
(input_ids[0] == sep_id).nonzero()[-1][-1]

In [None]:
# Display the full tokenizer output for inspection.
inputs

In [None]:
# Scratch 2x3 tensor of random normal values; `a` is not used by any other cell
# visible in this notebook.
a = torch.randn(2, 3)

In [None]:
# 5x5 lower-triangular matrix of ones (diagonal included): the same construction
# `get_simbert_mask` uses for the block covering the second segment.
torch.tril(torch.ones(5, 5), diagonal=0)

In [None]:
def get_simbert_mask(input_ids: torch.Tensor, sep_token_id: int) -> torch.Tensor:
    """Build a per-sequence 2-D attention mask from the separator positions.

    For each row of ``input_ids`` the first and last occurrence of ``sep_token_id`` are
    located. In the resulting ``(seq_len, seq_len)`` mask (rows = attending position,
    columns = attended position):

    * every position up to and including the last separator may attend to every
      position up to and including the first separator;
    * positions after the first separator, up to and including the last one, may
      additionally attend to themselves and to earlier positions in that same span
      (lower-triangular block);
    * everything after the last separator is left at 0 in both directions.

    Args:
        input_ids: Token ids of shape ``(batch, seq_len)``.
        sep_token_id: Id of the separator token. Each row must contain it at least once.

    Returns:
        ``torch.long`` tensor of shape ``(batch, seq_len, seq_len)`` on the same device
        as ``input_ids``, with 1 where attention is allowed and 0 elsewhere.

    Raises:
        IndexError: If a row contains no ``sep_token_id``.
    """
    batch_size = input_ids.size(0)
    sequence_length = input_ids.size(-1)
    # Allocate once, on the input's device, and fill each row's mask in place.
    attention_masks = torch.zeros(
        (batch_size, sequence_length, sequence_length),
        dtype=torch.long,
        device=input_ids.device,
    )
    for sequence_ids, attention_mask in zip(input_ids, attention_masks):
        # Use the function's own parameter, not a notebook-level global, to find separators.
        sep_positions = (sequence_ids == sep_token_id).nonzero()
        first_sep = sep_positions[0][-1].item()
        last_sep = sep_positions[-1][-1].item()

        # Everyone up to the last separator sees the whole first segment.
        attention_mask[: last_sep + 1, : first_sep + 1] = 1

        # The second segment sees itself causally (current and earlier positions only).
        span = last_sep - first_sep
        attention_mask[first_sep + 1 : last_sep + 1, first_sep + 1 : last_sep + 1] = torch.tril(
            torch.ones(span, span, dtype=torch.long, device=input_ids.device), diagonal=0
        )
    return attention_masks

In [None]:
# Rebind the batch of token ids (same assignment as in an earlier cell).
input_ids = inputs.input_ids

In [None]:
# Position of the first [SEP] in the first sequence, as a Python int.
(input_ids[0] == sep_id).nonzero()[0][0].item()

In [None]:
# Column index of every [SEP] in the batch, kept as a (num_sep, 1) column vector.
sep_pos = torch.nonzero(input_ids == sep_id)[:, 1].unsqueeze(-1)

In [None]:
def get_simbert_labels(
    input_ids: torch.Tensor, pad_token_id: torch.Tensor, sep_token_id: int = 102
) -> torch.Tensor:
    """Derive LM labels from ``input_ids`` by blanking out non-target positions.

    A copy of ``input_ids`` is returned in which the value ``-100`` replaces:

    * every position from the start of each row through its first ``sep_token_id``
      (inclusive), and
    * every position equal to ``pad_token_id``.

    ``input_ids`` itself is left untouched.

    Args:
        input_ids: Token ids; rows are iterated one by one.
        pad_token_id: Value treated as padding. It is only compared against the labels
            with ``==``; this notebook passes ``tokenizer.pad_token_id`` here.
        sep_token_id: Separator id marking the end of the first segment.

    Returns:
        Tensor shaped like ``input_ids`` holding the target ids and ``-100`` elsewhere.
    """
    ignore_index = -100
    labels = input_ids.clone()
    # Each `row` is a view into `labels`, so slice assignment edits `labels` in place.
    for row in labels:
        sep_positions = torch.nonzero(row == sep_token_id)
        source_end = sep_positions[0][0].item() + 1
        row[:source_end] = ignore_index
    pad_positions = labels == pad_token_id
    labels[pad_positions] = ignore_index
    return labels

In [None]:
# Show the batch of token ids.
inputs.input_ids

In [None]:
# Decode the token ids back to text to check where the special tokens ended up.
tokenizer.batch_decode(inputs.input_ids)

In [None]:
# Token ids of the second sentence pair in the batch.
inputs.input_ids[1]

In [None]:
# Attention mask built for the second sentence pair in the batch.
get_simbert_mask(inputs.input_ids, tokenizer.sep_token_id)[1]

In [None]:
# Decode the label ids to eyeball which positions remain as targets. Positions that
# `get_simbert_labels` blanked out hold -100, which is not a vocabulary id.
tokenizer.batch_decode(
    get_simbert_labels(inputs.input_ids, tokenizer.pad_token_id, tokenizer.sep_token_id)
)

In [None]:
class SimBertModel(BertPreTrainedModel):
    """BERT encoder with a weight-tied LM head and a linear projector over the first token.

    ``forward`` scores every position with the LM head and, when ``labels`` are given,
    computes a shifted next-token cross-entropy loss. ``forward_vector`` and
    ``encode_text`` additionally project the last hidden state at position 0 to
    ``config.vector_dim`` dimensions.

    The config must not be flagged ``is_decoder`` and must define ``vector_dim``.
    """

    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

    def __init__(self, config):
        """Build the encoder, LM head and projector, then initialise and tie weights."""
        super().__init__(config)

        assert not config.is_decoder, "Model is not a decoder but is being used as one"
        assert hasattr(config, "vector_dim"), "Vector dim is not defined"

        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)
        self.projector = torch.nn.Linear(config.hidden_size, config.vector_dim)

        # Initialize weights and apply final processing
        self.post_init()
        self.tie_weights()

    def tie_weights(self):
        """Tie the LM head's decoder weights to the input word embeddings."""
        self._tie_or_clone_weights(
            self.cls.predictions.decoder, self.bert.embeddings.word_embeddings
        )

    def get_output_embeddings(self):
        """Return the LM head's decoder layer."""
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head's decoder layer with ``new_embeddings``."""
        self.cls.predictions.decoder = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.Tensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
        r"""
        Run the encoder and the LM head, and compute the shifted next-token loss when `labels` is given.

        encoder_hidden_states  (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
            ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            # Caching is switched off whenever a loss is being computed.
            use_cache = False

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        lm_loss = None
        if labels is not None:
            # we are doing next-token prediction; shift prediction scores and input ids by one
            shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
            labels = labels[:, 1:].contiguous()
            loss_fct = torch.nn.CrossEntropyLoss()
            lm_loss = loss_fct(
                shifted_prediction_scores.view(-1, self.config.vocab_size),
                labels.view(-1),
            )

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return ((lm_loss,) + output) if lm_loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    def forward_vector(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Tuple[Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions], torch.Tensor]:
        """Run ``forward`` and also project the first token's last hidden state.

        Hidden states are always requested from ``forward`` because the projection needs
        them, so the ``output_hidden_states`` argument has no effect.

        Returns:
            A pair ``(forward_output, vector_output)``. ``vector_output`` has shape
            ``(N, vector_dim)``. ``forward_output`` is the structured output of
            ``forward``, converted to a plain tuple when ``return_dict`` (or the
            config default) is False.
        """
        use_return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Always ask for the structured output here: `.hidden_states` is read below and
        # does not exist on the tuple that `forward` returns when return_dict is False.
        forward_output = self.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=True,
        )
        vector_output = self.projector(
            forward_output.hidden_states[-1][:, 0, :]
        )  # (N, vector_dim)
        if not use_return_dict:
            # Honour the caller's request for a tuple only after the vector is computed.
            forward_output = forward_output.to_tuple()
        return forward_output, vector_output

    def encode_text(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.Tensor:
        """Return only the projected first-token vector, shape ``(N, vector_dim)``.

        Delegates to ``forward_vector`` and discards the LM output. Since only the
        vector is returned, ``output_hidden_states`` and ``return_dict`` do not affect
        the result.
        """
        _, vector_output = self.forward_vector(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=True,
        )
        return vector_output

    def prepare_inputs_for_generation(
        self, input_ids, past=None, attention_mask=None, **model_kwargs
    ):
        """Assemble the keyword arguments for one generation step.

        Creates an all-ones attention mask when none is given and, when ``past`` is
        provided, keeps only the last token of ``input_ids``.
        """
        input_shape = input_ids.shape
        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_shape)

        # cut decoder_input_ids if past is used
        if past is not None:
            input_ids = input_ids[:, -1:]

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past,
        }

    def _reorder_cache(self, past, beam_idx):
        """Reorder every cached state along dim 0 according to ``beam_idx``."""
        reordered_past = ()
        for layer_past in past:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),
            )
        return reordered_past

In [None]:
# `SimBertModel.__init__` asserts that the config defines `vector_dim`; it is the
# output size of the model's `projector` layer.
config.vector_dim = 512

In [None]:
# Load the pretrained BERT weights into SimBertModel. "~" is expanded explicitly because
# `cache_dir` is used as given and would otherwise create a literal "~" directory.
model = SimBertModel.from_pretrained(
    "bert-base-chinese", config=config, cache_dir=os.path.expanduser("~/.cache/")
)

In [None]:
# LM forward pass with the custom attention mask and masked labels.
# Call the module instance rather than `.forward` so any registered hooks run too.
model(
    input_ids=inputs.input_ids,
    attention_mask=get_simbert_mask(inputs.input_ids, tokenizer.sep_token_id),
    labels=get_simbert_labels(inputs.input_ids, tokenizer.pad_token_id, tokenizer.sep_token_id),
)

In [None]:
# Generate from a single prompt with sampling (`max_length=20`) and decode the result.
# No random seed is set in the visible cells, so the sampled text varies between runs.
tokenizer.batch_decode(
    model.generate(
        tokenizer(["你好世界"], return_tensors="pt").input_ids, do_sample=True, max_length=20
    )
)

In [None]:
# Same inputs as the LM forward pass; `forward_vector` returns a pair of
# (forward output, projected first-token vector of size `config.vector_dim`).
model.forward_vector(
    input_ids=inputs.input_ids,
    attention_mask=get_simbert_mask(inputs.input_ids, tokenizer.sep_token_id),
    labels=get_simbert_labels(inputs.input_ids, tokenizer.pad_token_id, tokenizer.sep_token_id),
)