In [108]:
from transformers import BartForConditionalGeneration, BartModel
from transformers import AutoTokenizer
from transformers.models.bart.modeling_bart import shift_tokens_right
import torch

In [109]:
bart = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

In [110]:
bart.config

BartConfig {
  "_name_or_path": "facebook/bart-base",
  "activation_dropout": 0.1,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartModel"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "classif_dropout": 0.1,
  "classifier_dropout": 0.0,
  "d_model": 768,
  "decoder_attention_heads": 12,
  "decoder_ffn_dim": 3072,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 6,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "early_stopping": true,
  "encoder_attention_heads": 12,
  "encoder_ffn_dim": 3072,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 6,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "gradient_checkpointing": false,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2"
  },
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2
  },
  "max_position_embeddings": 1024,
  "model_ty

In [111]:
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")

In [112]:
tokenizer

BartTokenizerFast(name_or_path='facebook/bart-base', vocab_size=50265, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=False)})

In [273]:
s1 = ""
s2 = "this is sentence two"
s3 = "there are two sentences"
long1 = "this is somewhat longer sentence one"
long2 = "this is somewhat longer sentence two. It has a second sentence that add nothing. Really nothing. It could be summarized with one word: nothing"
long3 = "there are also somewhat longer sentences"

In [245]:
i1 = tokenizer(s1, return_tensors="pt")
i2 = tokenizer(s2, return_tensors="pt")
i3_solo = tokenizer(s3, return_tensors="pt")

In [246]:
i3 = tokenizer(s1, return_tensors="pt")

In [247]:
i1, i2, i3, i3_solo

({'input_ids': tensor([[0, 2]]), 'attention_mask': tensor([[1, 1]])},
 {'input_ids': tensor([[   0, 9226,   16, 3645,   80,    2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]])},
 {'input_ids': tensor([[0, 2]]), 'attention_mask': tensor([[1, 1]])},
 {'input_ids': tensor([[    0,  8585,    32,    80, 11305,     2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]])})

In [278]:
encoded = tokenizer([s1, long1], [s2, long2], text_target=[s3, long3], padding=True, return_tensors='pt')

In [274]:
encoded = tokenizer([long1, long2, long3], text_target=[s1, s2, s3], padding=True, return_tensors='pt')

In [282]:
encoded

{'input_ids': tensor([[    0,     2,     2,  9226,    16,  3645,    80,     2,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1],
        [    0,  9226,    16,  5568,  1181,  3645,    65,     2,     2,  9226,
            16,  5568,  1181,  3645,    80,     4,    85,    34,    10,   200,
          3645,    14,  1606,  1085,     4, 16923,  1085,     4,    85,   115,
            28, 38152,    19,    65,  2136,    35,  1085,     2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'labels': tensor([[    0,  8585,    32,    80, 11305,     2,     1,     1],
        [    0, 

In [265]:
out = bart(encoded['input_ids'], encoded['attention_mask'])

In [279]:
gen_out = bart.generate(encoded['input_ids'])

In [280]:
gen_out

tensor([[   2,    0,    2,    1,    1,    1,    1,    1,    1],
        [   2,    0, 9226,   16, 5568, 1181, 3645,   65,    2]])

In [281]:
tokenizer.batch_decode(gen_out)

['</s><s></s><pad><pad><pad><pad><pad><pad>',
 '</s><s>this is somewhat longer sentence one</s>']

In [270]:
out.logits.shape

torch.Size([3, 8, 50265])

In [267]:
encoded['input_ids'].shape

torch.Size([3, 8])

In [268]:
tokenizer.batch_decode(encoded['labels'])

['<s></s><pad><pad><pad><pad>',
 '<s>this is sentence two</s>',
 '<s>there are two sentences</s>']

In [121]:
shift_tokens_right(input_ids=i1['input_ids'], pad_token_id=99, decoder_start_token_id=100)

tensor([[ 100,    0, 9226,   16, 3645,   65]])

In [123]:
out = bart(**encoded)

In [128]:
out.logits.shape

torch.Size([3, 8, 50265])

In [34]:
i1

{'input_ids': tensor([[   0, 9226,   16, 3645,   65,    2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]])}

In [148]:
out.logits[0].shape

torch.Size([8, 50265])

In [144]:
tokenizer.batch_decode(out.logits.argmax(dim=-1))

['<s>this is sentence sentences in</s> one',
 '<s>this is somewhat longer longer sentence one',
 '<s>this is sentence</s></s></s> two']

In [26]:
bart.state_dict()['pooler.dense.weight'].shape

KeyError: 'pooler.dense.weight'

In [27]:
bart.state_dict().keys()

odict_keys(['model.shared.weight', 'model.encoder.embed_tokens.weight', 'model.encoder.embed_positions.weight', 'model.encoder.layers.0.self_attn.k_proj.weight', 'model.encoder.layers.0.self_attn.k_proj.bias', 'model.encoder.layers.0.self_attn.v_proj.weight', 'model.encoder.layers.0.self_attn.v_proj.bias', 'model.encoder.layers.0.self_attn.q_proj.weight', 'model.encoder.layers.0.self_attn.q_proj.bias', 'model.encoder.layers.0.self_attn.out_proj.weight', 'model.encoder.layers.0.self_attn.out_proj.bias', 'model.encoder.layers.0.self_attn_layer_norm.weight', 'model.encoder.layers.0.self_attn_layer_norm.bias', 'model.encoder.layers.0.fc1.weight', 'model.encoder.layers.0.fc1.bias', 'model.encoder.layers.0.fc2.weight', 'model.encoder.layers.0.fc2.bias', 'model.encoder.layers.0.final_layer_norm.weight', 'model.encoder.layers.0.final_layer_norm.bias', 'model.encoder.layers.1.self_attn.k_proj.weight', 'model.encoder.layers.1.self_attn.k_proj.bias', 'model.encoder.layers.1.self_attn.v_proj.weigh

In [154]:
bart

BartForConditionalGeneration(
  (model): BartModel(
    (shared): Embedding(50265, 768, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): Embedding(50265, 768, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 768)
      (layers): ModuleList(
        (0-5): 6 x BartEncoderLayer(
          (self_attn): BartAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_layer_norm): LayerNorm((768,), eps=

In [236]:
from models.bart_extractor import BartExtractor
from torchmetrics.functional import blue_score

ModuleNotFoundError: No module named 'torchmetrics'

In [173]:
model = BartExtractor('facebook/bart-base')

In [174]:
model.load_state_dict(torch.load("checkpoints/test"))

<All keys matched successfully>

In [258]:
utterance = tokenizer("Do you have hobbies", "Yes, I like reading", text_target="", return_tensors="pt")

In [234]:
out = model(utterance['input_ids'], utterance['attention_mask'], labels=utterance['labels'])
gen_out = model.bart.generate(utterance['input_ids'])

In [235]:
gen_out

tensor([[2, 0, 2]])

In [225]:
pred = out.argmax(dim=-1)

In [226]:
pred

tensor([[   0,  100,  101, 2600,    2]])

In [259]:
utterance

{'input_ids': tensor([[    0,  8275,    47,    33, 36365,     2,     2,  9904,     6,    38,
           101,  2600,     2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'labels': tensor([[0, 2]])}

In [228]:
tokenizer.batch_decode(pred)

['<s>I like reading</s>']

In [229]:
utterance['labels']

tensor([[   0,  100,  101, 2600,    2]])

In [254]:
pred.eq(utterance['labels'])

tensor([[True, True, True, True, True]])

In [231]:
loss_fn = torch.nn.NLLLoss(ignore_index=model.bart.config.pad_token_id, reduction='mean')

In [232]:
loss_fn(out.transpose(1,2), utterance['labels'])

tensor(0.0940, grad_fn=<NllLoss2DBackward0>)

In [237]:
import torchmetrics

ModuleNotFoundError: No module named 'torchmetrics'

In [262]:
tokenizer.convert_tokens_to_string(tokenizer.batch_decode([[1]]))

'<pad>'

In [243]:
tokenizer.batch_decode(pred, skip_special_tokens=True, clean_up_tokenization_spaces=True)

['I like reading']

In [255]:
i1['input_ids']

tensor([[0, 2]])

In [256]:
pred

tensor([[   0,  100,  101, 2600,    2]])