https://huggingface.co/docs/transformers/main/model_doc/bart

In [16]:
import torch
from transformers import BartModel, BartTokenizer

In [17]:
# Pick the compute device once: the first CUDA GPU when PyTorch can see one,
# otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
device  # echo the selection as the cell output

device(type='cuda', index=0)

In [18]:
# Shared settings for the cells below.
# Checkpoint identifier handed to both BartTokenizer.from_pretrained and BartModel.from_pretrained.
version = "facebook/bart-base"
# Demo sentence that gets tokenized and fed through the model.
sequence = "The quick brown fox jumps over the lazy dog."
# Token cap; in the visible cells it is only referenced by a commented-out tokenizer argument.
max_length = 20

# BartTokenizer

In [19]:
# Load the tokenizer that belongs to the `version` checkpoint, then echo its
# configuration (vocabulary size, special tokens, padding side, ...) as the cell output.
tokenizer: BartTokenizer = BartTokenizer.from_pretrained(version)
tokenizer

BartTokenizer(name_or_path='facebook/bart-base', vocab_size=50265, model_max_length=1024, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'eos_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'unk_token': AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'sep_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'pad_token': AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'cls_token': AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'mask_token': AddedToken("<mask>", rstrip=False, lstrip=True, single_word=False, normalized=True)}, clean_up_tokenization_spaces=True)

## tokenizer([sequence])

In [20]:
# Encode the demo sentence into model-ready PyTorch tensors and move them to `device`.
#
# Only the device is passed to `.to()`: BatchEncoding.to() accepts a device and nothing
# else, and the encoding holds integer tensors (token ids, attention mask, lengths) that
# must stay integers -- casting them to float16 would break the embedding lookup.
# Background on dtype casting of batch objects:
# https://github.com/huggingface/transformers/issues/16359
inputs = tokenizer(
    sequence,                           # text to encode: a single sentence here (a list of sentences is also accepted)
    truncation = True,                  # truncate anything longer than max_length
    padding = True,                     # padding strategy, one of [True, 'longest', 'max_length', 'do_not_pad']
    # max_length = max_length,          # maximum length; defaults to the model's maximum length when not set
    add_special_tokens = True,          # add the special tokens (<s> ... </s>) around the text
    return_length = True,               # also return the length of each encoded sequence
    return_overflowing_tokens = False,  # when True, tokens cut off by truncation are returned as extra chunks instead of being dropped
    return_tensors = "pt"               # output format: np / pt / tf / jax
).to(device)

print(inputs.keys())
print(inputs["input_ids"])
print(inputs["attention_mask"]) # 1 marks a real token (0 would mark padding)
print(inputs["length"])         # number of tokens in each encoded sequence

dict_keys(['input_ids', 'attention_mask', 'length'])
tensor([[    0,   133,  2119,  6219, 23602, 13855,    81,     5, 22414,  2335,
             4,     2]], device='cuda:0')
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')
tensor([12], device='cuda:0')


In [21]:
# Show just the token ids produced by the tokenizer call.
print(inputs["input_ids"])

tensor([[    0,   133,  2119,  6219, 23602, 13855,    81,     5, 22414,  2335,
             4,     2]], device='cuda:0')


In [22]:
# Check which device the encoded token ids ended up on.
inputs["input_ids"].device

device(type='cuda', index=0)

# BartModel

The bare BART Model outputting raw hidden-states without any specific head on top.

In [23]:
# Load the bare BART encoder-decoder (no task head) and move it to `device`.
#
# Half precision is only used on CUDA. The device cell can fall back to the CPU, where
# many PyTorch builds either lack float16 kernels (e.g. for LayerNorm / matmul) or run
# them very slowly, so float32 is used there instead.
model_dtype = torch.float16 if device.type == "cuda" else torch.float32
model: BartModel = BartModel.from_pretrained(version, torch_dtype=model_dtype).to(device)
model

BartModel(
  (shared): Embedding(50265, 768, padding_idx=1)
  (encoder): BartEncoder(
    (embed_tokens): Embedding(50265, 768, padding_idx=1)
    (embed_positions): BartLearnedPositionalEmbedding(1026, 768)
    (layers): ModuleList(
      (0-5): 6 x BartEncoderLayer(
        (self_attn): BartAttention(
          (k_proj): Linear(in_features=768, out_features=768, bias=True)
          (v_proj): Linear(in_features=768, out_features=768, bias=True)
          (q_proj): Linear(in_features=768, out_features=768, bias=True)
          (out_proj): Linear(in_features=768, out_features=768, bias=True)
        )
        (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (activation_fn): GELUActivation()
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      )
    )
    (layernorm_embedding): La

In [24]:
# Put the module in evaluation mode before running it.
model.eval()

# Run one forward pass under inference_mode so autograd records nothing.
# Only the token ids and the attention mask are forwarded; the extra "length"
# entry produced by the tokenizer is deliberately left out.
with torch.inference_mode():
    outputs = model(**{name: inputs[name] for name in ("input_ids", "attention_mask")})

outputs
# Seq2SeqModelOutput

Seq2SeqModelOutput(last_hidden_state=tensor([[[ 2.4163,  2.4015,  1.2924,  ...,  1.7528,  0.0943, -0.6025],
         [ 2.2920, -0.4007, -0.6358,  ...,  0.9390, -1.4969, -0.2737],
         [ 1.3135,  1.2932, -0.7050,  ..., -1.1133, -0.9641, -0.3233],
         ...,
         [-1.0830,  0.6670,  0.2067,  ...,  0.9630, -0.8360, -1.8732],
         [ 0.3591, -0.8059, -0.9604,  ...,  0.9311, -2.4044, -0.1415],
         [ 0.1652, -0.8006, -1.3691,  ...,  0.7151, -2.0324,  0.0576]]],
       device='cuda:0'), past_key_values=((tensor([[[[-1.1472e-01, -5.2203e-01, -6.7271e-01,  ...,  3.4312e-01,
            7.5881e-02, -2.4253e-02],
          [ 5.4507e-02,  3.6397e-01, -7.1711e-01,  ...,  8.1498e-02,
            5.4005e-03,  3.4209e-01],
          [-1.3369e+00, -2.2861e-01, -3.9312e-02,  ...,  2.3108e-01,
           -1.4632e+00,  1.9317e+00],
          ...,
          [-1.4486e+00, -6.2776e-01,  5.1208e-01,  ..., -2.6631e-01,
           -1.2960e+00,  4.9432e-03],
          [ 3.9104e-01, -4.9748e-01

In [25]:
# Output of the last layer; its shape is displayed as the cell result.
outputs.last_hidden_state.shape

torch.Size([1, 12, 768])

In [26]:
# Display the last-layer hidden-state tensor itself.
outputs.last_hidden_state

tensor([[[ 2.4163,  2.4015,  1.2924,  ...,  1.7528,  0.0943, -0.6025],
         [ 2.2920, -0.4007, -0.6358,  ...,  0.9390, -1.4969, -0.2737],
         [ 1.3135,  1.2932, -0.7050,  ..., -1.1133, -0.9641, -0.3233],
         ...,
         [-1.0830,  0.6670,  0.2067,  ...,  0.9630, -0.8360, -1.8732],
         [ 0.3591, -0.8059, -0.9604,  ...,  0.9311, -2.4044, -0.1415],
         [ 0.1652, -0.8006, -1.3691,  ...,  0.7151, -2.0324,  0.0576]]],
       device='cuda:0')

In [27]:
# Number of entries in the key/value cache returned by the forward pass.
len(outputs.past_key_values)

6

In [28]:
# Walk the key/value cache: each entry of past_key_values is itself a group of
# tensors. Print the shape of every tensor, with a dashed line after each group.
divider = "-" * 25
for layer_cache in outputs.past_key_values:
    for cached_shape in (cached.shape for cached in layer_cache):
        print(cached_shape)
    print(divider)

torch.Size([1, 12, 12, 64])
torch.Size([1, 12, 12, 64])
torch.Size([1, 12, 12, 64])
torch.Size([1, 12, 12, 64])
-------------------------
torch.Size([1, 12, 12, 64])
torch.Size([1, 12, 12, 64])
torch.Size([1, 12, 12, 64])
torch.Size([1, 12, 12, 64])
-------------------------
torch.Size([1, 12, 12, 64])
torch.Size([1, 12, 12, 64])
torch.Size([1, 12, 12, 64])
torch.Size([1, 12, 12, 64])
-------------------------
torch.Size([1, 12, 12, 64])
torch.Size([1, 12, 12, 64])
torch.Size([1, 12, 12, 64])
torch.Size([1, 12, 12, 64])
-------------------------
torch.Size([1, 12, 12, 64])
torch.Size([1, 12, 12, 64])
torch.Size([1, 12, 12, 64])
torch.Size([1, 12, 12, 64])
-------------------------
torch.Size([1, 12, 12, 64])
torch.Size([1, 12, 12, 64])
torch.Size([1, 12, 12, 64])
torch.Size([1, 12, 12, 64])
-------------------------
