In [1]:
from transformers import BartForConditionalGeneration, BartModel
from transformers import AutoTokenizer
from transformers import GenerationConfig
from transformers.models.bart.modeling_bart import shift_tokens_right
import torch

In [2]:
# Load the pretrained "facebook/bart-base" checkpoint into the
# conditional-generation variant of BART.
bart = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

In [3]:
# Load the tokenizer published with the same "facebook/bart-base" checkpoint.
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")

In [4]:
# Inspect the model configuration (layer counts, dropout, special-token ids, ...).
bart.config

BartConfig {
  "_name_or_path": "facebook/bart-base",
  "activation_dropout": 0.1,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartModel"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "classif_dropout": 0.1,
  "classifier_dropout": 0.0,
  "d_model": 768,
  "decoder_attention_heads": 12,
  "decoder_ffn_dim": 3072,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 6,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "early_stopping": true,
  "encoder_attention_heads": 12,
  "encoder_ffn_dim": 3072,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 6,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "gradient_checkpointing": false,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2"
  },
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2
  },
  "max_position_embeddings": 1024,
  "model_ty

In [5]:
# Display the module tree of the loaded model.
bart

BartForConditionalGeneration(
  (model): BartModel(
    (shared): Embedding(50265, 768, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): Embedding(50265, 768, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 768)
      (layers): ModuleList(
        (0-5): 6 x BartEncoderLayer(
          (self_attn): BartAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_layer_norm): LayerNorm((768,), eps=

In [6]:
# List the names of all entries in the model's state dict.
bart.state_dict().keys()

odict_keys(['final_logits_bias', 'model.shared.weight', 'model.encoder.embed_tokens.weight', 'model.encoder.embed_positions.weight', 'model.encoder.layers.0.self_attn.k_proj.weight', 'model.encoder.layers.0.self_attn.k_proj.bias', 'model.encoder.layers.0.self_attn.v_proj.weight', 'model.encoder.layers.0.self_attn.v_proj.bias', 'model.encoder.layers.0.self_attn.q_proj.weight', 'model.encoder.layers.0.self_attn.q_proj.bias', 'model.encoder.layers.0.self_attn.out_proj.weight', 'model.encoder.layers.0.self_attn.out_proj.bias', 'model.encoder.layers.0.self_attn_layer_norm.weight', 'model.encoder.layers.0.self_attn_layer_norm.bias', 'model.encoder.layers.0.fc1.weight', 'model.encoder.layers.0.fc1.bias', 'model.encoder.layers.0.fc2.weight', 'model.encoder.layers.0.fc2.bias', 'model.encoder.layers.0.final_layer_norm.weight', 'model.encoder.layers.0.final_layer_norm.bias', 'model.encoder.layers.1.self_attn.k_proj.weight', 'model.encoder.layers.1.self_attn.k_proj.bias', 'model.encoder.layers.1.s

In [7]:
# Display the tokenizer summary (vocab size, max length, special tokens).
tokenizer

BartTokenizerFast(name_or_path='facebook/bart-base', vocab_size=50265, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=False)})

In [8]:
# Example texts used throughout the notebook: three short ones and three longer ones.
# s1 is the empty string, so its encoding contains only the special tokens.
s1 = ""
s2 = "this is sentence two"
s3 = "there are two sentences"
long1 = "this is somewhat longer sentence one"
long2 = "this is somewhat longer sentence two. It has a second sentence that add nothing. Really nothing. It could be summarized with one word: nothing"
long3 = "there are also somewhat longer sentences"

In [9]:
# Encode individual texts and one text pair as PyTorch tensors.
i1 = tokenizer(s1, return_tensors="pt")
i2 = tokenizer(s2, return_tensors="pt")
# Note: i3 is the *pair* (s1, s2) encoded as one sequence, not an encoding of s3;
# the stand-alone encoding of s3 is i3_solo.
i3 = tokenizer(s1, text_pair=s2, return_tensors="pt")
i3_solo = tokenizer(s3, return_tensors="pt")

In [10]:
# Tokenizing a text pair joins both texts into a single sequence, with two
# consecutive </s> (eos, id 2) tokens between them -- compare i3 with i1 and i2.

i1, i2, i3, i3_solo

({'input_ids': tensor([[0, 2]]), 'attention_mask': tensor([[1, 1]])},
 {'input_ids': tensor([[   0, 9226,   16, 3645,   80,    2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]])},
 {'input_ids': tensor([[   0,    2,    2, 9226,   16, 3645,   80,    2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]])},
 {'input_ids': tensor([[    0,  8585,    32,    80, 11305,     2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]])})

In [11]:
# Batch of text pairs with targets: the first and second lists are passed as
# text / text_pair, and `text_target` supplies the target texts.
# NOTE: `encoded` is rebound by the very next cell, so this result is discarded
# when the notebook is run top to bottom.
encoded = tokenizer(
    [s1, long1],
    text_pair=[s2, long2],
    text_target=[s3, long3],
    padding=True,
    return_tensors='pt',
)

In [12]:
# Padded batch: the long texts are the inputs and the short texts are the
# targets (`text_target` yields the `labels` entry decoded further below).
encoded = tokenizer(
    [long1, long2, long3],
    text_target=[s1, s2, s3],
    padding=True,
    return_tensors='pt',
)

In [13]:
# Show the padded batch: shorter rows are filled with the pad id (1) and the
# attention mask is 0 at those positions.
encoded

{'input_ids': tensor([[    0,  9226,    16,  5568,  1181,  3645,    65,     2,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
        [    0,  9226,    16,  5568,  1181,  3645,    80,     4,    85,    34,
            10,   200,  3645,    14,  1606,  1085,     4, 16923,  1085,     4,
            85,   115,    28, 38152,    19,    65,  2136,    35,  1085,     2],
        [    0,  8585,    32,    67,  5568,  1181, 11305,     2,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 0

In [14]:
# Forward pass with the encoder inputs only (no labels or decoder inputs given).
out = bart(
    input_ids=encoded['input_ids'],
    attention_mask=encoded['attention_mask'],
)

In [15]:
# Generate from the padded batch. The attention mask is passed explicitly so
# that padding positions are masked in the encoder, instead of relying on
# generate() to infer the mask from the pad token ids. This mirrors the forward
# call above, which also receives the mask.
gen_out = bart.generate(
    encoded['input_ids'],
    attention_mask=encoded['attention_mask'],
    max_new_tokens=30,
)

In [16]:
# Inspect the raw generated token ids (one row per input text).
gen_out

tensor([[    2,     0,  9226,    16,  5568,  1181,  3645,    65,     2,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1],
        [    2,     0,  9226,    16,  5568,  1181,  3645,    80,     4,    85,
            34,    10,   200,  3645,    14,  1606,  1085,     4, 16923,  1085,
             4,    85,   115,    28, 38152,    19,    65,  2136,    35,  1085,
             2],
        [    2,     0,  8585,    32,    67,  5568,  1181, 11305,     2,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1]])

In [17]:
# Decode the generated ids back to text; special tokens (</s>, <s>, <pad>) are
# kept because skip_special_tokens is not set.
tokenizer.batch_decode(gen_out)

['</s><s>this is somewhat longer sentence one</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>',
 '</s><s>this is somewhat longer sentence two. It has a second sentence that add nothing. Really nothing. It could be summarized with one word: nothing</s>',
 '</s><s>there are also somewhat longer sentences</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>']

In [18]:
# Logits shape: (batch, sequence length, vocabulary size). Without labels the
# sequence length matches the padded input length shown in the next cell.
out.logits.shape

torch.Size([3, 30, 50265])

In [19]:
# Shape of the padded encoder inputs: (batch, padded sequence length).
encoded['input_ids'].shape

torch.Size([3, 30])

In [20]:
# Decode the target-side labels to check how the targets were tokenized and padded.
tokenizer.batch_decode(encoded['labels'])

['<s></s><pad><pad><pad><pad>',
 '<s>this is sentence two</s>',
 '<s>there are two sentences</s>']

In [21]:
# Demonstrate shift_tokens_right on i1 ([[0, 2]]): the sequence moves one
# position to the right and decoder_start_token_id is placed first.
# 99 and 100 are dummy ids chosen to be easy to spot in the result; they are
# not the model's real pad / decoder-start ids.
shift_tokens_right(input_ids=i1['input_ids'], pad_token_id=99, decoder_start_token_id=100)

tensor([[100,   0]])

In [22]:
# Forward pass with everything the tokenizer returned, including `labels`.
# NOTE(review): the labels are padded with the pad token id (see the decoded
# labels above), so a loss taken from `out` would presumably also count padding
# positions unless those ids are replaced by the loss ignore index -- confirm
# before using `out.loss` for training.
out = bart(**encoded)

In [23]:
# With labels supplied, the sequence dimension follows the padded label length (6)
# rather than the input length (30).
out.logits.shape

torch.Size([3, 6, 50265])

In [24]:
# Logits of the first example only: (target length, vocabulary size).
out.logits[0].shape

torch.Size([6, 50265])

In [25]:
# Take the highest-scoring token at every target position of the forward pass
# above (this is not generate()) and decode the resulting ids.
tokenizer.batch_decode(out.logits.argmax(dim=-1))

['<s>this</s> is is</s>',
 '<s>this is somewhat two.',
 '<s>there are also somewhat</s>']

In [26]:
from models.bart_extractor import BartExtractor, ExtractedFactLoss
from torcheval.metrics.functional import binary_confusion_matrix, binary_accuracy, binary_f1_score, bleu_score
from dataset.msc_summary_turns import MSC_Turns

In [27]:
# Extra tokens added on top of the pretrained vocabulary: two speaker markers
# and a marker used as the target when there is no fact (see the <nofact>
# example further below).
speaker_prefixes  = ['<self>', '<other>']
nofact_token = '<nofact>'

# One checkpoint name for both tokenizer and model so the two cannot drift apart.
extractor_checkpoint = 'facebook/bart-large-cnn'

# NOTE: this rebinds `tokenizer`, which until here referred to the
# "facebook/bart-base" tokenizer; re-running earlier cells afterwards will use
# this extended tokenizer instead.
tokenizer = AutoTokenizer.from_pretrained(extractor_checkpoint)
tokenizer.add_tokens(speaker_prefixes + [nofact_token])

# NOTE(review): `tokenizer.vocab_size` is the size of the base vocabulary and
# does not include the tokens added above; `len(tokenizer)` (used for the
# embedding resize below) does. Confirm which value consumers of `vocab_size` need.
vocab_size = tokenizer.vocab_size
pad_token_id = tokenizer.pad_token_id
# The eos id is used as the start token here.
start_token_id = tokenizer.eos_token_id
nofact_token_id = tokenizer.convert_tokens_to_ids(nofact_token)

model = BartExtractor(extractor_checkpoint, nofact_token_id=nofact_token_id)
# Grow the embedding matrix so the newly added tokens have embeddings.
model.bart.resize_token_embeddings(len(tokenizer))
# Padding positions are excluded from the loss via ignore_index.
criterion = ExtractedFactLoss(nofact_token_id=nofact_token_id, ignore_index=pad_token_id)

In [28]:
# model.load_state_dict(torch.load("../checkpoints/trained_nll05_bart", map_location=torch.device('cpu')))

In [29]:
# Example where the second speaker states a fact; the target is that fact.
# NOTE: the next cell rebinds `utterance`, so this example is only in effect
# when that cell is skipped.
utterance = tokenizer(
    "<self> Do you have hobbies. <other> Yes, I like to read books",
    text_target="I like to read books",
    return_tensors="pt",
)

In [30]:
# Example without a fact: the target is just the <nofact> token.
utterance = tokenizer(
    "<self> Do you have hobbies. <other> Haha hobbies, why do you ask?",
    text_target="<nofact>",
    return_tensors="pt",
)

In [31]:
# Longer test input: a dialogue-style opening followed by a news-style passage.
ARTICLE_TO_SUMMARIZE = (
    "I said Do you have hobbies. You said Yes, I like reading about PG&E "
    "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
    "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
    "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
)
# Encode as a batch of one, truncating anything beyond 1024 tokens.
article = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, truncation=True, return_tensors="pt")

In [32]:
# Beam-search generation settings (4 beams, 2 to 20 new tokens, no repeated trigrams).
# NOTE: this object is not passed to the generate() call in the next cell, which
# supplies its own keyword arguments.
gen_config = GenerationConfig(
    min_new_tokens=2,
    max_new_tokens=20,
    early_stopping=True,
    no_repeat_ngram_size=3,
    num_beams=4,
)

In [33]:
# Greedy decoding (single beam, no sampling) with the wrapped BART model for the
# current `utterance`; the result is a generation output object rather than a
# bare tensor because return_dict_in_generate=True.
gen_out = model.bart.generate(
    input_ids=utterance["input_ids"],
    min_length=1,
    max_new_tokens=100,
    num_beams=1,
    do_sample=False,
    return_dict_in_generate=True,
)
# Keep special tokens and raw spacing so the exact generated sequence is visible.
decoded = tokenizer.batch_decode(
    gen_out.sequences,
    skip_special_tokens=False,
    clean_up_tokenization_spaces=False,
)
decoded[0]

'</s><s><s><s>Haha Haha Hahaha Hahahaha. Do you have hobbies? Do you know what they are? Share your photos and videos with CNN iReport.</s>'

In [34]:
# Reference output from the underlying BART model: log-probabilities over the
# vocabulary at every label position, and the most likely token id per position.
# The next cell prints the corresponding output of the BartExtractor wrapper.
out = model.bart(utterance['input_ids'], utterance['attention_mask'], labels=utterance['labels'], return_dict=True)
pred = out.logits.argmax(dim=-1)
# Compute the log-softmax once and reuse it for both printouts.
log_probs = torch.nn.functional.log_softmax(out.logits, dim=-1)
print(log_probs)
print(log_probs.argmax(dim=-1))

tensor([[[ -1.1260, -11.7983,  -9.6329,  ..., -12.5407, -12.2553, -13.4364],
         [ -1.1260, -11.7983,  -9.6329,  ..., -12.5407, -12.2553, -13.4364],
         [-17.4144, -12.0196, -10.0932,  ..., -12.0419, -11.9569, -12.7407]]],
       grad_fn=<LogSoftmaxBackward0>)
tensor([[0, 0, 4]])


In [35]:
# Same inputs through the BartExtractor wrapper; its output is a tensor that is
# printed as-is, followed by the per-position argmax token ids.
out = model(
    utterance['input_ids'],
    utterance['attention_mask'],
    labels=utterance['labels'],
)
pred = out.argmax(dim=-1)
print(out)
print(pred)

tensor([[[ -1.1260, -11.7983,  -9.6329,  ..., -12.5407, -12.2553, -13.4364],
         [ -1.1260, -11.7983,  -9.6329,  ..., -12.5407, -12.2553, -13.4364],
         [-17.4144, -12.0196, -10.0932,  ..., -12.0419, -11.9569, -12.7407]]],
       grad_fn=<LogSoftmaxBackward0>)
tensor([[0, 0, 4]])


In [36]:
# Decode every generated sequence (special tokens kept) from the earlier generate() call.
tokenizer.batch_decode(gen_out['sequences'])

['</s><s><s><s>Haha Haha Hahaha Hahahaha. Do you have hobbies? Do you know what they are? Share your photos and videos with CNN iReport.</s>']