In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BartModel, BartTokenizer, BartConfig
import numpy as np
import pandas as pd
import scienceplots
import matplotlib.pyplot as plt

plt.style.use(['science', 'notebook', 'grid', 'ieee'])

- Bart uses a standard seq2seq translation architecture with a bidirectional encoder (like Bert) and a left-to-right decoder (like GPT).

In [7]:
model_name = 'facebook/bart-large-xsum'
model = BartModel.from_pretrained(model_name)
tokenizer = BartTokenizer.from_pretrained(model_name)
config = BartConfig.from_pretrained(model_name)

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.51k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

In [3]:
model

BartModel(
  (shared): Embedding(50265, 1024, padding_idx=1)
  (encoder): BartEncoder(
    (embed_tokens): Embedding(50265, 1024, padding_idx=1)
    (embed_positions): BartLearnedPositionalEmbedding(1026, 1024)
    (layers): ModuleList(
      (0-11): 12 x BartEncoderLayer(
        (self_attn): BartAttention(
          (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
        )
        (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (activation_fn): GELUActivation()
        (fc1): Linear(in_features=1024, out_features=4096, bias=True)
        (fc2): Linear(in_features=4096, out_features=1024, bias=True)
        (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      )
    )
    (layerno

In [14]:
config

BartConfig {
  "activation_dropout": 0.1,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartModel"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "classif_dropout": 0.1,
  "classifier_dropout": 0.0,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 12,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "early_stopping": true,
  "encoder_attention_heads": 16,
  "encoder_ffn_dim": 4096,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 12,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "gradient_checkpointing": false,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2"
  },
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2
  },
  "max_position_embeddings": 1024,
  "model_type": "bart",
  "no_repeat_ngram_size":

In [16]:
tokenizer.special_tokens_map

{'bos_token': '<s>',
 'eos_token': '</s>',
 'unk_token': '<unk>',
 'sep_token': '</s>',
 'pad_token': '<pad>',
 'cls_token': '<s>',
 'mask_token': '<mask>'}

In [12]:
text = 'The Bart model was proposed in BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov and Luke Zettlemoyer on 29 Oct, 2019.'

batch_inputs = tokenizer(text, return_tensors='pt')
input_ids = batch_inputs['input_ids']
input_ids.shape
# (batch_size, sentence length)

torch.Size([1, 80])

In [17]:
model.eval()
encoder_outputs = model.get_encoder()(input_ids)
encoder_outputs, encoder_outputs.last_hidden_state.shape

(BaseModelOutput(last_hidden_state=tensor([[[ 0.0161,  0.0217,  0.0252,  ..., -0.0050, -0.0082, -0.0064],
          [-0.1113,  0.1126, -0.1159,  ..., -0.0657, -0.1892,  0.0033],
          [ 0.0120,  0.2809, -0.1886,  ..., -0.0165, -0.1217,  0.0316],
          ...,
          [ 0.0761, -0.2055, -0.3899,  ..., -0.2413, -0.0148, -0.1123],
          [-0.0069,  0.0174,  0.0142,  ...,  0.0099, -0.0057, -0.0035],
          [ 0.1635,  0.1311,  0.0902,  ..., -0.0374, -0.0921,  0.0596]]],
        grad_fn=<NativeLayerNormBackward0>), hidden_states=None, attentions=None),
 torch.Size([1, 80, 1024]))

In [36]:
decoder_input_ids, 

(tensor([[2]]),
 tensor([[[ 0.0161,  0.0217,  0.0252,  ..., -0.0050, -0.0082, -0.0064],
          [-0.1113,  0.1126, -0.1159,  ..., -0.0657, -0.1892,  0.0033],
          [ 0.0120,  0.2809, -0.1886,  ..., -0.0165, -0.1217,  0.0316],
          ...,
          [ 0.0761, -0.2055, -0.3899,  ..., -0.2413, -0.0148, -0.1123],
          [-0.0069,  0.0174,  0.0142,  ...,  0.0099, -0.0057, -0.0035],
          [ 0.1635,  0.1311,  0.0902,  ..., -0.0374, -0.0921,  0.0596]]],
        grad_fn=<NativeLayerNormBackward0>))

In [38]:
decoder_start_token = config.decoder_start_token_id
decoder_input_ids = torch.tensor([[decoder_start_token]], dtype=torch.long)

generated_tokens = []

for _ in range(100):
    decoder_outputs = model.get_decoder()(input_ids=decoder_input_ids,
                                          encoder_hidden_states=encoder_outputs.last_hidden_state)
    logits = decoder_outputs[0]
    next_token_id = logits.argmax(-1)[:, -1]
    decoder_input_ids = torch.cat([decoder_input_ids, next_token_id.unsqueeze(-1)], dim=-1)
    generated_tokens.append(next_token_id.item())
    if next_token_id.item() == tokenizer.eos_token_id:
        break
print(generated_tokens)

[1011, 509, 509, 509, 509, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 608, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 1011, 608, 1011, 608, 1011, 608, 1011, 608, 1011, 608, 1011, 608, 1011, 608, 1011, 608, 1011, 608, 1011, 608, 1011, 608, 1011, 1011, 1011, 1011, 1011, 1011]


In [40]:
output_text = tokenizer.decode(generated_tokens)
output_text

' ball One One One One ball ball ball ball ball ball ball ball ball ball ball ball ball ball doing ball ball ball ball ball ball ball ball ball ball ball ball ball ball ball ball ball ball ball ball ball ball ball ball ball ball ball ball ball ball ball ball ball ball ball ball ball ball ball ball ball ball ball ball ball ball ball ball ball ball ball ball ball doing ball doing ball doing ball doing ball doing ball doing ball doing ball doing ball doing ball doing ball doing ball ball ball ball ball ball'

In [12]:
text = 'The Bart model was proposed in BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension by Mike Lewis, Yinhan Liu, Naman Goyal, Marjan Ghazvininejad, Abdelrahman Mohamed, Omer Levy, Ves Stoyanov and Luke Zettlemoyer on 29 Oct, 2019.'
model_inputs = tokenizer(text, return_tensors='pt')
input_ids = model_inputs.input_ids

n_steps = 10
top_x = 5


decoder_start_token = model.config.decoder_start_token_id
decoder_input_ids = torch.tensor([[decoder_start_token]], dtype=torch.long)

iterations = []
with torch.no_grad():
    for _ in range(n_steps):
        iteration = {}
        # the first row
        iteration['input'] = tokenizer.decode(input_ids[0])
        encoder_outputs = model.get_encoder()(input_ids)
        decoder_outputs = model.get_decoder()(input_ids=decoder_input_ids, 
                                              encoder_hidden_states=encoder_outputs.last_hidden_state)
        last_token_logits = decoder_outputs.last_hidden_state[0, -1, :]
        last_token_prob = torch.softmax(last_token_logits, dim=-1)
        sorted_ids = torch.argsort(last_token_prob, dim=-1, descending=True)
        for choice_idx in range(top_x):
            token_id = sorted_ids[top_x]
            token_prob = last_token_prob[token_id]
            token_choice = f'{tokenizer.decode(token_id)}({100*token_prob:.2f}%)'
            iteration[f'choice {choice_idx + 1}'] = token_choice
        
        print('before append input_ids.shape', input_ids.shape)
        input_ids = torch.cat([input_ids, sorted_ids[None, 0, None]], dim=-1)
        print('after append input_ids.shape', input_ids.shape)
        
        iterations.append(iteration)

before append input_ids.shape torch.Size([1, 80])
after append input_ids.shape torch.Size([1, 81])
before append input_ids.shape torch.Size([1, 81])
after append input_ids.shape torch.Size([1, 82])
before append input_ids.shape torch.Size([1, 82])
after append input_ids.shape torch.Size([1, 83])
before append input_ids.shape torch.Size([1, 83])
after append input_ids.shape torch.Size([1, 84])
before append input_ids.shape torch.Size([1, 84])
after append input_ids.shape torch.Size([1, 85])
before append input_ids.shape torch.Size([1, 85])
after append input_ids.shape torch.Size([1, 86])
before append input_ids.shape torch.Size([1, 86])
after append input_ids.shape torch.Size([1, 87])
before append input_ids.shape torch.Size([1, 87])
after append input_ids.shape torch.Size([1, 88])
before append input_ids.shape torch.Size([1, 88])
after append input_ids.shape torch.Size([1, 89])
before append input_ids.shape torch.Size([1, 89])
after append input_ids.shape torch.Size([1, 90])


In [13]:
pd.DataFrame(iterations)

Unnamed: 0,input,choice 1,choice 2,choice 3,choice 4,choice 5
0,<s>The Bart model was proposed in BART: Denois...,club(0.15%),club(0.15%),club(0.15%),club(0.15%),club(0.15%)
1,<s>The Bart model was proposed in BART: Denois...,club(0.16%),club(0.16%),club(0.16%),club(0.16%),club(0.16%)
2,<s>The Bart model was proposed in BART: Denois...,club(0.16%),club(0.16%),club(0.16%),club(0.16%),club(0.16%)
3,<s>The Bart model was proposed in BART: Denois...,club(0.15%),club(0.15%),club(0.15%),club(0.15%),club(0.15%)
4,<s>The Bart model was proposed in BART: Denois...,club(0.15%),club(0.15%),club(0.15%),club(0.15%),club(0.15%)
5,<s>The Bart model was proposed in BART: Denois...,club(0.15%),club(0.15%),club(0.15%),club(0.15%),club(0.15%)
6,<s>The Bart model was proposed in BART: Denois...,club(0.15%),club(0.15%),club(0.15%),club(0.15%),club(0.15%)
7,<s>The Bart model was proposed in BART: Denois...,club(0.15%),club(0.15%),club(0.15%),club(0.15%),club(0.15%)
8,<s>The Bart model was proposed in BART: Denois...,club(0.15%),club(0.15%),club(0.15%),club(0.15%),club(0.15%)
9,<s>The Bart model was proposed in BART: Denois...,club(0.15%),club(0.15%),club(0.15%),club(0.15%),club(0.15%)
