In [1]:
from fairseq.models.bart import BARTModel

2021-10-08 03:24:22.692477: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2021-10-08 03:24:22.692531: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.


In [2]:
# # MBart Test
# mbart = BARTModel.from_pretrained('mbart.cc25.v2', checkpoint_file='model.pt')
# mbart.eval()

In [3]:
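# Register the custom 'translation_without_lang_token' task before loading any checkpoint that uses it (needed for the other four languages — TODO: list which ones).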
from fairseq.tasks import register_task
from fairseq.tasks.translation import TranslationTask


@register_task('translation_without_lang_token')
class TranslationPLBARTTask(TranslationTask):
    """fairseq ``TranslationTask`` whose dictionaries get PLBART's extra symbols.

    After the base task is built, one ``[lang]`` symbol per entry of ``--langs``
    and then ``<mask>`` are appended to both the source and the target
    dictionary. Presumably this is what yields the vocabulary tail printed
    further down ([java], [python], [en_XX], <mask>) — confirm for each checkpoint.

    NOTE(review): fairseq is expected to reject duplicate task names, so
    re-running this cell in the same kernel will likely fail; restart instead.
    """

    @staticmethod
    def add_args(parser):
        """Add the standard translation arguments plus the required ``--langs``."""
        TranslationTask.add_args(parser)
        parser.add_argument(
            '--langs',
            required=True,
            metavar='LANG',
            help='comma-separated list of monolingual language, '
                 'for example, "en,de,fr". These should match the '
                 'langs from pretraining (and be in the same order). '
                 'You should always add all pretraining language idx '
                 'during finetuning.',
        )

    def __init__(self, args, src_dict, tgt_dict):
        super().__init__(args, src_dict, tgt_dict)
        self.langs = args.langs.split(",")
        # Insertion order fixes the new symbols' indices: languages first
        # (in --langs order), then <mask>, identically for both dictionaries.
        extra_symbols = ["[{}]".format(lang) for lang in self.langs] + ["<mask>"]
        for dictionary in (self.src_dict, self.tgt_dict):
            for symbol in extra_symbols:
                dictionary.add_symbol(symbol)


In [4]:
# Load the fairseq PLBART checkpoint (presumably the C#→Java fine-tune, going by
# the directory name). eval() turns off dropout, and its return value (the hub
# interface) is what this cell displays.
plbart = BARTModel.from_pretrained('PLBART/plbart-cs-java', checkpoint_file='model.pt')
plbart.eval()

BARTHubInterface(
  (model): BARTModel(
    (encoder): TransformerEncoder(
      (dropout_module): FairseqDropout()
      (embed_tokens): Embedding(50005, 768, padding_idx=1)
      (embed_positions): LearnedPositionalEmbedding(1026, 768, padding_idx=1)
      (layernorm_embedding): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (layers): ModuleList(
        (0): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (dropout_module): FairseqDropout()
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (dropout_module): FairseqDropout()
          (activation_dropout_module): FairseqDropout()
     

In [5]:
# Vocabulary size; should equal the encoder's embedding rows (Embedding(50005, ...) above).
len(plbart.task.source_dictionary)

50005

In [6]:
# Print the first four dictionary entries (fairseq's built-in specials) and the
# last four (presumably the symbols appended by the task). The tail is derived
# from len() rather than hard-coded as 50001-50004, so the cell also works for
# checkpoints with a different vocabulary size.
src_dict = plbart.task.source_dictionary
n_edge = 4
for i in [*range(n_edge), *range(len(src_dict) - n_edge, len(src_dict))]:
    print(i, src_dict[i])

0 <s>
1 <pad>
2 </s>
3 <unk>
50001 [java]
50002 [python]
50003 [en_XX]
50004 <mask>


In [7]:
# The underlying fairseq BARTModel, unwrapped from the hub interface.
fs_model = plbart.model

In [8]:
# Argument namespace loaded with the checkpoint (arch, layer sizes, bpe, ...).
plbart.args

Namespace(activation_dropout=0.0, activation_fn='gelu', adam_betas='(0.9, 0.98)', adam_eps=1e-06, adaptive_input=False, adaptive_softmax_cutoff=None, adaptive_softmax_dropout=0, all_gather_list_size=16384, arch='mbart_base', attention_dropout=0.1, best_checkpoint_metric='bleu', bf16=False, bpe='gpt2', broadcast_buffers=False, bucket_cap_mb=25, checkpoint_suffix='', clip_norm=0.0, cpu=False, criterion='label_smoothed_cross_entropy', cross_self_attention=False, curriculum=0, data='/home/crocoder/Desktop/transformers/PLBART/plbart-cs-java', data_buffer_size=10, dataset_impl=None, ddp_backend='no_c10d', decoder_attention_heads=12, decoder_embed_dim=768, decoder_embed_path=None, decoder_ffn_embed_dim=3072, decoder_input_dim=768, decoder_layerdrop=0, decoder_layers=6, decoder_layers_to_keep=None, decoder_learned_pos=True, decoder_normalize_before=False, decoder_output_dim=768, device_id=0, disable_validation=False, distributed_backend='nccl', distributed_init_method=None, distributed_no_spaw

In [9]:
from transformers import PLBartConfig, PLBartForConditionalGeneration

In [10]:
# NOTE(review): 'plbart-cs-java' is presumably a local directory holding the HF
# conversion of the fairseq checkpoint — confirm the path. The `training` check
# below shows the model comes back in eval mode.
hf_model = PLBartForConditionalGeneration.from_pretrained('plbart-cs-java')

## Inputs

In [11]:
import sentencepiece as spm

In [12]:
# SentencePiece model shipped with the original pretrained PLBART checkpoint.
# NOTE(review): assumes the fine-tuned plbart-cs-java model uses the same
# SentencePiece model — confirm.
vocab_filepath = "./PLBART/plbart_orig_pretrained_ckpt/sentencepiece.bpe.model"
tokenizer = spm.SentencePieceProcessor()
tokenizer.Load(vocab_filepath)
# Empty extra options: presumably no BOS/EOS (or other) pieces are added on encode.
# This call's return value is what the cell displays.
tokenizer.SetEncodeExtraOptions("")

True

In [13]:
texts = ["This is a sample text", "Another example here"]
# TODO: check how the original PLBART pipeline tokenizes. NOTE(review): these are
# likely raw SentencePiece ids rather than fairseq dictionary indices. Both models
# below receive the same ids, so the equivalence checks are unaffected either way.
text_0_tokens, text_1_tokens = [tokenizer.EncodeAsIds(text.strip()) for text in texts]

In [14]:
# Right-pad both token lists to a common length so they can be stacked into one
# batch. 1 is fairseq's <pad> index (padding_idx=1 in the embedding repr above).
# Padding up to the current maximum makes the cell idempotent: a re-run adds
# nothing, whereas the old hard-coded `[1]*2` appended two more pads every time.
PAD_ID = 1
padded_len = max(len(text_0_tokens), len(text_1_tokens))
text_0_tokens += [PAD_ID] * (padded_len - len(text_0_tokens))
text_1_tokens += [PAD_ID] * (padded_len - len(text_1_tokens))

In [15]:
import numpy as np
import torch
import torch.nn as nn

In [16]:
# Stack the (equal-length) token lists into a batch. An explicit int64 dtype
# avoids np.array's platform-dependent default integer width (int32 on Windows).
input_ids = torch.tensor([text_0_tokens, text_1_tokens], dtype=torch.long)
# Derive the mask from the pad id instead of hard-coding which row/positions are
# padding. This way HF masks exactly the tokens that fairseq treats as padding
# (padding_idx=1 in the embedding repr above).
PAD_ID = 1
attention_mask = input_ids.ne(PAD_ID).long()
# Not used by any cell below.
token_type_ids = torch.zeros_like(input_ids)

# Encoder Embeddings

## Fairseq

In [17]:
# fairseq encoder's embedding stage on the padded batch; the second returned
# value is not used below.
fs_embeds, embed = fs_model.encoder.forward_embedding(input_ids)

In [18]:
# fs_model.encoder.quant_noise # None
# fs_model.encoder.layernorm_embedding #  not None
# fs_model.encoder.embed_positions # not None

## HuggingFace

In [19]:
# Should be False (eval mode), which makes the manual dropout call below a no-op.
hf_model.model.encoder.training

False

In [20]:
# Recompute the HF encoder's embedding stage step by step, so the result can be
# compared with fairseq's forward_embedding output.
input_shape = input_ids.size()
# Token embeddings multiplied by the encoder's embed_scale.
inputs_embeds = hf_model.model.encoder.embed_tokens(input_ids) * hf_model.model.encoder.embed_scale
# Positional embeddings are computed from the shape alone, so padded positions
# get ordinary position vectors here.
embed_pos = hf_model.model.encoder.embed_positions(input_shape)
hf_embeds = inputs_embeds + embed_pos
hf_embeds = hf_model.model.encoder.layernorm_embedding(hf_embeds)
# Identity while the encoder is not training (checked in the previous cell).
hf_embeds = nn.functional.dropout(hf_embeds, p= hf_model.model.encoder.dropout, training= hf_model.model.encoder.training)


In [21]:
# Row 0 has no padding, so every position should match.
torch.allclose(hf_embeds[0], fs_embeds[0], atol=1e-5)

True

In [22]:
# Row 1: compare only the non-padded positions (the last two are padding).
torch.allclose(hf_embeds[1,:-2], fs_embeds[1, :-2], atol=1e-5)

True

In [23]:
# HF embeddings at row 1's two padded positions.
hf_embeds[1, -2:]

tensor([[-0.9383,  0.4173,  0.8157,  ..., -0.1300,  0.3722, -1.5632],
        [-0.9641,  0.3882,  0.6569,  ...,  0.2866,  0.5034, -1.5643]],
       grad_fn=<SliceBackward>)

In [24]:
# fairseq embeddings at the same padded positions. fairseq handles padding tokens
# differently: its two rows are identical, unlike HF's, so these positions are
# not expected to match.
fs_embeds[1, -2:]

tensor([[-0.6705,  0.2665,  0.9473,  ...,  0.0948,  0.4174, -1.1947],
        [-0.6705,  0.2665,  0.9473,  ...,  0.0948,  0.4174, -1.1947]],
       grad_fn=<SliceBackward>)

In [25]:
# Expected False: together with the checks above, this shows the only mismatch is at padding.
torch.allclose(hf_embeds[1,-2:], fs_embeds[1, -2:], atol=1e-5)

False

# Encoder Model

## Fairseq

In [26]:
# fairseq TransformerEncoder (see the model repr above).
fs_encoder = fs_model.encoder

In [27]:
# True (unpadded) length of each row. text_1_tokens was padded in place earlier,
# so len() would count its padding (giving [5, 5]). The attention mask marks real
# tokens with 1, so its row sums give the true lengths.
src_lengths = attention_mask.sum(dim=1)

In [28]:
# Run the full fairseq encoder on the same right-padded batch.
fs_encoder_out = fs_encoder(input_ids, src_lengths)

In [29]:
# fairseq's encoder output is time-major; swapping the first two axes gives
# HF's (batch, seq_len, embed_dim) layout.
fs_encoder_out_encoder_out = fs_encoder_out.encoder_out.transpose(0, 1)

In [30]:
# (batch, seq_len, embed_dim) after the axis swap.
fs_encoder_out_encoder_out.shape

torch.Size([2, 5, 768])

## HuggingFace

In [31]:
# HF PLBart encoder submodule.
hf_encoder = hf_model.model.encoder

In [32]:
# Run the HF encoder on the same batch, masking the padded positions.
hf_encoder_out = hf_encoder(input_ids=input_ids, attention_mask=attention_mask)

In [33]:
# Replace the model-output object with its hidden-state tensor, which later cells
# use directly. The isinstance guard keeps the cell idempotent: without it, a
# re-run would call .last_hidden_state on a tensor and raise AttributeError.
if not isinstance(hf_encoder_out, torch.Tensor):
    hf_encoder_out = hf_encoder_out.last_hidden_state

In [34]:
# Row 0 (no padding): the two encoders' outputs should match everywhere.
torch.allclose(hf_encoder_out[0], fs_encoder_out_encoder_out[0], atol=1e-5)

True

In [35]:
# Row 1: non-padded positions only.
torch.allclose(hf_encoder_out[1,:-2], fs_encoder_out_encoder_out[1,:-2], atol=1e-5)

True

# Decoder

## Fairseq

In [36]:
# fairseq decoder submodule.
fs_decoder = fs_model.decoder

In [37]:
# Feed the source ids back in as decoder input. Any ids work for an equivalence
# check, as long as the HF decoder below receives the same ones.
fs_decoder_out = fs_decoder(input_ids, fs_encoder_out)

In [38]:
# Element 0 holds the vocabulary logits (shape checked below); element 1 is a dict
# of extras. The former bare `len(fs_decoder_out)` line was a discarded dead
# expression: only a cell's last expression is displayed.
fs_decoder_out_decoder_out = fs_decoder_out[0]

In [39]:
# Extras returned alongside the logits.
fs_decoder_out[1].keys()

dict_keys(['attn', 'inner_states'])

In [40]:
#  fs_decoder_out[1]['inner_states'][-1].permute(1, 0, 2).shape

In [41]:
# Last inner state (time-major in fairseq), with the first two axes swapped to
# get the batch-major layout of HF's last_hidden_state.
fs_decoder_out_inner_state = fs_decoder_out[1]['inner_states'][-1].transpose(0, 1)

In [42]:
# (batch, seq_len, vocab_size): the last axis matches the 50005-entry dictionary.
fs_decoder_out_decoder_out.shape

torch.Size([2, 5, 50005])

## HuggingFace

In [43]:
# HF PLBart decoder submodule.
hf_decoder = hf_model.model.decoder

In [44]:
# Mirror the fairseq call: the source ids are also the decoder input, so the same
# padding mask serves for both the decoder inputs and the encoder states.
hf_decoder_out = hf_decoder(
    input_ids,
    attention_mask=attention_mask,
    encoder_hidden_states=hf_encoder_out,
    encoder_attention_mask=attention_mask,
)

In [45]:
# Decoder hidden states before the LM head: (batch, seq_len, embed_dim).
hf_decoder_out.last_hidden_state.shape

torch.Size([2, 5, 768])

In [46]:
# Row 0 (no padding): HF hidden states vs fairseq's last inner state.
torch.allclose(hf_decoder_out.last_hidden_state[0], fs_decoder_out_inner_state[0], atol=1e-5)

True

In [47]:
# Row 1: non-padded positions only.
torch.allclose(hf_decoder_out.last_hidden_state[1, :-2], fs_decoder_out_inner_state[1, :-2], atol=1e-5)

True

In [48]:
# Expected False: including the padded positions breaks the match, as with the embeddings.
torch.allclose(hf_decoder_out.last_hidden_state[1], fs_decoder_out_inner_state[1], atol=1e-5)

False

In [49]:
# Project the HF decoder states to vocabulary logits and add final_logits_bias.
# Presumably this mirrors what PLBartForConditionalGeneration.forward does.
lm_logits = hf_model.lm_head(hf_decoder_out.last_hidden_state) + hf_model.final_logits_bias

In [50]:
# Should match the fairseq logits shape: (batch, seq_len, vocab_size).
lm_logits.shape

torch.Size([2, 5, 50005])

In [51]:
# Row 0 (no padding): HF vs fairseq logits.
torch.allclose(lm_logits[0], fs_decoder_out_decoder_out[0], atol=1e-5)

True

In [52]:
# Row 1: logits at the non-padded positions only.
torch.allclose(lm_logits[1, :-2], fs_decoder_out_decoder_out[1, :-2], atol=1e-5)

True