In [1]:
import torch
from transformers import BartForConditionalGeneration, BartTokenizer, GenerationConfig
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [3]:
version = "facebook/bart-large-cnn"
ARTICLE_TO_SUMMARIZE = (
    "PG&E stated it scheduled the blackouts in response to forecasts for high winds "
    "amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were "
    "scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."
)

# BartTokenizer

In [4]:
tokenizer: BartTokenizer = BartTokenizer.from_pretrained(version)
tokenizer

BartTokenizer(name_or_path='facebook/bart-large-cnn', vocab_size=50265, model_max_length=1024, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'eos_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'unk_token': AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'sep_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'pad_token': AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'cls_token': AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'mask_token': AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=True)}, clean_up_tokenization_spaces=True)

# BartForConditionalGeneration

The BART Model with a language modeling head. Can be used for summarization.

In [5]:
model: BartForConditionalGeneration = BartForConditionalGeneration.from_pretrained(version, torch_dtype=torch.float16).to(device)
model

BartForConditionalGeneration(
  (model): BartModel(
    (shared): Embedding(50264, 1024, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): Embedding(50264, 1024, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 1024)
      (layers): ModuleList(
        (0-11): 12 x BartEncoderLayer(
          (self_attn): BartAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
          (final_layer_norm): LayerN

## Summarization example

In [7]:
inputs = tokenizer(ARTICLE_TO_SUMMARIZE, return_tensors = "pt").to(device, torch.float16)

print(inputs.keys())
print(inputs["input_ids"])
print(inputs["attention_mask"])

dict_keys(['input_ids', 'attention_mask'])
tensor([[    0,  8332,   947,   717,  2305,    24,  1768,     5,   909,  4518,
            11,  1263,     7,  5876,    13,   239,  2372,  2876,  3841,  1274,
             4,    20,  4374,    16,     7,  1888,     5,   810,     9, 12584,
             4,  9221,  5735,  7673,   916,    58,  1768,     7,    28,  2132,
            30,     5,  2572, 10816,    61,    58,   421,     7,    94,   149,
            23,   513, 15372,  3859,     4,     2]], device='cuda:0')
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')


In [8]:
model.eval()
with torch.inference_mode():
    # 可以只输入input而不输入decoder
    outputs = model.generate(
        input_ids = inputs["input_ids"],
        attention_mask = inputs["attention_mask"],
        # num_beams: beam search num
        generation_config = GenerationConfig(num_beams=2, min_length=0, max_new_tokens=100),
    )

In [9]:
print(outputs)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False))

tensor([[    2,     0,  8332,   947,   717,  1768,     5,   909,  4518,    11,
          1263,     7,  5876,    13,   239,  2372,  2876,  3841,  1274,     4,
            20,  4374,    16,     7,  1888,     5,   810,     9, 12584,     4,
          9221,  5735,  7673,   916,    58,  1768,     7,    28,  2132,    30,
             5,  2572, 10816,     4,     2]], device='cuda:0')
['PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs.']


In [10]:
tokenizer.batch_decode([2])

['</s>']

## Mask filling example

In [11]:
TXT = "My friends are <mask> but they eat too many carbs."

In [12]:
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")

In [13]:
inputs = tokenizer(TXT, return_tensors="pt").to(device)

print(inputs.keys())
print(inputs["input_ids"])
print(inputs["attention_mask"])

dict_keys(['input_ids', 'attention_mask'])
tensor([[    0,  2387,   964,    32, 50264,    53,    51,  3529,   350,   171,
         33237,     4,     2]], device='cuda:0')
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')


In [14]:
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base").to(device)

In [15]:
with torch.inference_mode():
    outputs = model(inputs["input_ids"])
outputs

Seq2SeqLMOutput(loss=None, logits=tensor([[[34.6802,  6.9836, 16.3400,  ...,  7.2211,  7.0561,  2.3689],
         [ 5.8346, -3.1657, 14.5902,  ..., -1.0895, -0.5840,  1.0089],
         [-7.2457, -5.1547,  6.5578,  ..., -4.3999, -3.6170,  0.2360],
         ...,
         [-3.2509, -4.5113,  6.5780,  ..., -4.1401, -3.9725,  1.0122],
         [-0.5660, -4.0633, 13.1334,  ..., -1.8695, -1.9459,  1.2657],
         [-1.5064, -3.5157, 22.3739,  ..., -3.6222, -3.3043, -0.4648]]],
       device='cuda:0'), past_key_values=((tensor([[[[-1.1472e-01, -5.2203e-01, -6.7271e-01,  ...,  3.4312e-01,
            7.5881e-02, -2.4253e-02],
          [ 5.4507e-02,  3.6397e-01, -7.1711e-01,  ...,  8.1498e-02,
            5.4005e-03,  3.4209e-01],
          [-5.1239e-01, -4.8073e-01,  2.9027e-02,  ...,  5.4433e-01,
           -2.0813e+00,  1.9986e+00],
          ...,
          [-2.4587e-01,  5.6785e-02,  2.3159e-01,  ..., -5.0019e-01,
           -1.1116e+00,  3.7034e-01],
          [-6.2127e-01, -1.8716e+00,  

In [16]:
logits = outputs.logits
print(logits.shape)
print(logits)

torch.Size([1, 13, 50265])
tensor([[[34.6802,  6.9836, 16.3400,  ...,  7.2211,  7.0561,  2.3689],
         [ 5.8346, -3.1657, 14.5902,  ..., -1.0895, -0.5840,  1.0089],
         [-7.2457, -5.1547,  6.5578,  ..., -4.3999, -3.6170,  0.2360],
         ...,
         [-3.2509, -4.5113,  6.5780,  ..., -4.1401, -3.9725,  1.0122],
         [-0.5660, -4.0633, 13.1334,  ..., -1.8695, -1.9459,  1.2657],
         [-1.5064, -3.5157, 22.3739,  ..., -3.6222, -3.3043, -0.4648]]],
       device='cuda:0')


In [17]:
# 获取输入中的mask的id
masked_index = (inputs["input_ids"][0] == tokenizer.mask_token_id).nonzero().item()
masked_index

4

In [18]:
probs = logits[0, masked_index].softmax(dim=0)
values, predictions = probs.topk(5)
values, predictions

(tensor([0.0929, 0.0917, 0.0855, 0.0579, 0.0412], device='cuda:0'),
 tensor([  45,  205, 2245,  372,  182], device='cuda:0'))

In [19]:
tokenizer.decode(predictions).split()

['not', 'good', 'healthy', 'great', 'very']

# AutoTokenizer

In [20]:
tokenizer = AutoTokenizer.from_pretrained(version)
tokenizer

BartTokenizerFast(name_or_path='facebook/bart-large-cnn', vocab_size=50265, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'sep_token': '</s>', 'pad_token': '<pad>', 'cls_token': '<s>', 'mask_token': AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=False)}, clean_up_tokenization_spaces=True)

In [21]:
inputs = tokenizer(ARTICLE_TO_SUMMARIZE, return_tensors = "pt").to(device)

print(inputs.keys())
print(inputs["input_ids"])
print(inputs["attention_mask"])

dict_keys(['input_ids', 'attention_mask'])
tensor([[    0,  8332,   947,   717,  2305,    24,  1768,     5,   909,  4518,
            11,  1263,     7,  5876,    13,   239,  2372,  2876,  3841,  1274,
             4,    20,  4374,    16,     7,  1888,     5,   810,     9, 12584,
             4,  9221,  5735,  7673,   916,    58,  1768,     7,    28,  2132,
            30,     5,  2572, 10816,    61,    58,   421,     7,    94,   149,
            23,   513, 15372,  3859,     4,     2]], device='cuda:0')
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')


# AutoModelForSeq2SeqLM

In [None]:
model = AutoModelForSeq2SeqLM.from_pretrained(version, torch_dtype=torch.float16).to(device)
model

BartForConditionalGeneration(
  (model): BartModel(
    (shared): Embedding(50264, 1024, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): Embedding(50264, 1024, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 1024)
      (layers): ModuleList(
        (0-11): 12 x BartEncoderLayer(
          (self_attn): BartAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
          (final_layer_norm): LayerN

In [None]:
model.eval()
with torch.inference_mode():
    # 可以只输入input而不输入decoder
    summary_ids = model.generate(
        input_ids = inputs["input_ids"],
        attention_mask = inputs["attention_mask"],
        # num_beams: beam search num
        generation_config = GenerationConfig(num_beams=2, min_length=0, max_new_tokens=100),
    )

In [None]:
print(summary_ids)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))

tensor([[    2,     0,  8332,   947,   717,  1768,     5,   909,  4518,    11,
          1263,     7,  5876,    13,   239,  2372,  2876,  3841,  1274,     4,
            20,  4374,    16,     7,  1888,     5,   810,     9, 12584,     4,
          9221,  5735,  7673,   916,    58,  1768,     7,    28,  2132,    30,
             5,  2572, 10816,     4,     2]], device='cuda:0')
['PG&E scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs.']
