In [60]:
import torch
from transformers import AutoTokenizer, PreTrainedTokenizerBase, RwkvForCausalLM

In [61]:
# Prefer the first CUDA GPU; fall back to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda", 0)
else:
    device = torch.device("cpu")
device

device(type='cuda', index=0)

In [62]:
# Experiment configuration.
sequence = "The quick brown fox jumps over the lazy dog."  # text fed to the model
version = "RWKV/rwkv-4-169m-pile"  # Hugging Face Hub checkpoint id
max_length = 20  # token limit; only takes effect where passed explicitly as max_length=

# AutoTokenizer

In [63]:
# AutoTokenizer is only a factory: from_pretrained returns a concrete tokenizer class
# (GPTNeoXTokenizerFast for this checkpoint, as the displayed repr shows), so the
# variable is annotated with the common tokenizer base class instead of AutoTokenizer.
tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(version)
tokenizer

GPTNeoXTokenizerFast(name_or_path='RWKV/rwkv-4-169m-pile', vocab_size=50254, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<|padding|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	50254: AddedToken("                        ", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
	50255: AddedToken("                       ", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
	50256: AddedToken("                      ", rstrip=False, lstrip=False, single_word=False, normalized=True, special=False),
	50257: AddedToken("                     ", r

## tokenizer([sequence])

In [64]:
# Tokenize a batch of two identical sentences and move every tensor to `device`.
# BatchEncoding.to() only moves tensors between devices. No dtype is passed, because
# token ids must stay integer tensors to be usable as embedding indices.
inputs = tokenizer(
    [sequence] * 2,  # batch of sentences
    truncation=True,  # cut sequences longer than max_length
    max_length=max_length,  # explicit limit; without it truncation is skipped with a warning
    # padding=True,  # padding strategy: True / 'longest', 'max_length', or 'do_not_pad'
    add_special_tokens=True,  # add the tokenizer's special tokens, if its template defines any
    return_length=True,  # also return the number of tokens per sequence
    return_overflowing_tokens=False,  # True would return the tokens cut off by truncation as extra rows instead of dropping them
    return_tensors="pt",  # output format: "np", "pt", "tf" or "jax"
).to(device)  # cf. https://github.com/huggingface/transformers/issues/16359

print(inputs.keys())
print(inputs["input_ids"])
print(inputs["attention_mask"])  # 1 for real tokens, 0 for padding
print(inputs["length"])  # number of tokens in each sequence

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


dict_keys(['input_ids', 'attention_mask', 'length'])
tensor([[  510,  3158,  8516, 30013, 27287,   689,   253, 22658,  4370,    15],
        [  510,  3158,  8516, 30013, 27287,   689,   253, 22658,  4370,    15]],
       device='cuda:0')
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')
tensor([10, 10], device='cuda:0')


In [65]:
inputs["input_ids"]  # rich display of the token ids instead of print()

tensor([[  510,  3158,  8516, 30013, 27287,   689,   253, 22658,  4370,    15],
        [  510,  3158,  8516, 30013, 27287,   689,   253, 22658,  4370,    15]],
       device='cuda:0')


# RwkvForCausalLM

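The RWKV model with a language-modeling head on top: a linear layer that maps each hidden state to logits over the vocabulary. The Hugging Face documentation describes this head as having weights tied to the input embeddings.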

In [66]:
# Load the checkpoint in half precision only when running on a GPU. On the CPU fallback
# (see the device cell), float16 ops are unsupported or slow in many PyTorch builds,
# so float32 is used there instead.
model_dtype = torch.float16 if device.type == "cuda" else torch.float32
model: RwkvForCausalLM = RwkvForCausalLM.from_pretrained(
    version, torch_dtype=model_dtype
).to(device)
model

RwkvForCausalLM(
  (rwkv): RwkvModel(
    (embeddings): Embedding(50277, 768)
    (blocks): ModuleList(
      (0): RwkvBlock(
        (pre_ln): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (ln1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (ln2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attention): RwkvSelfAttention(
          (time_shift): ZeroPad2d((0, 0, 1, -1))
          (key): Linear(in_features=768, out_features=768, bias=False)
          (value): Linear(in_features=768, out_features=768, bias=False)
          (receptance): Linear(in_features=768, out_features=768, bias=False)
          (output): Linear(in_features=768, out_features=768, bias=False)
        )
        (feed_forward): RwkvFeedForward(
          (time_shift): ZeroPad2d((0, 0, 1, -1))
          (key): Linear(in_features=768, out_features=3072, bias=False)
          (receptance): Linear(in_features=768, out_features=768, bias=False)
          (value): Linear(in_

In [67]:
# Single forward pass with autograd tracking disabled.
# Only input_ids and attention_mask are forwarded. `model(**inputs)` would also pass
# the `length` entry requested via return_length=True.
# NOTE(review): RWKV may not use attention_mask at all — confirm for this transformers version.
model.eval()
with torch.inference_mode():
    outputs = model(
        **{key: inputs[key] for key in ("input_ids", "attention_mask")}
    )
outputs

RwkvCausalLMOutput(loss=None, logits=tensor([[[ -3.9653, -15.6509,  -4.6795,  ...,  -1.4754,  -1.3797,  -0.1122],
         [ -5.2117, -21.4499,  -0.4841,  ...,  -5.8009,  -4.5592,  -2.7639],
         [ -7.0354, -28.9409,  -5.8369,  ...,  -9.0555,  -5.5854,  -5.3534],
         ...,
         [ -9.0419, -32.5494,  -5.0947,  ..., -13.2611,  -9.4912,  -6.6047],
         [  1.6691, -26.1559,   3.0113,  ...,  -2.5428,  -0.9998,   0.7052],
         [  2.5649, -23.1948,  -2.9174,  ...,  -1.0491,   0.2661,   3.2473]],

        [[ -3.9653, -15.6509,  -4.6795,  ...,  -1.4754,  -1.3797,  -0.1122],
         [ -5.2117, -21.4499,  -0.4841,  ...,  -5.8009,  -4.5592,  -2.7639],
         [ -7.0354, -28.9409,  -5.8369,  ...,  -9.0555,  -5.5854,  -5.3534],
         ...,
         [ -9.0419, -32.5494,  -5.0947,  ..., -13.2611,  -9.4912,  -6.6047],
         [  1.6691, -26.1559,   3.0113,  ...,  -2.5428,  -0.9998,   0.7052],
         [  2.5649, -23.1948,  -2.9174,  ...,  -1.0491,   0.2661,   3.2473]]],
       

In [68]:
# 分类
logits = outputs.logits
logits.shape

torch.Size([2, 10, 50277])

In [69]:
inputs.input_ids  # the token ids fed to the model, for comparison with the predictions

tensor([[  510,  3158,  8516, 30013, 27287,   689,   253, 22658,  4370,    15],
        [  510,  3158,  8516, 30013, 27287,   689,   253, 22658,  4370,    15]],
       device='cuda:0')

In [70]:
# The raw text that was tokenized, for comparison with the decoded predictions.
sequence

'The quick brown fox jumps over the lazy dog.'

In [71]:
# Greedy choice: the highest-scoring token id at every position.
# Column t is the model's guess for input column t+1. For example, column 2 predicts
# 30013 (" fox"), which is input column 3, so compare pred_idx[:, :-1] with input_ids[:, 1:].
pred_idx = torch.argmax(logits, dim=-1)
pred_idx

tensor([[  806,   383, 30013, 16780,   689,   253, 22658,  4370,    13,   187],
        [  806,   383, 30013, 16780,   689,   253, 22658,  4370,    13,   187]],
       device='cuda:0')

In [72]:
[tokenizer.decode(row) for row in pred_idx]  # predicted ids -> text, one string per batch row

[' firstest fox jumped over the lazy dog,\n',
 ' firstest fox jumped over the lazy dog,\n']

In [73]:
import plotly.express as px

In [81]:
# Probability distribution over the vocabulary for the token following the whole first
# sequence (batch row 0, last position). The softmax is taken in float32 because the
# logits are float16 on GPU, and float16 cannot represent probabilities below ~6e-8.
pred = logits[0, -1].float().softmax(dim=0)
px.line(
    y=pred.cpu().numpy(),  # plain numpy array for plotly
    labels={"x": "index", "y": "confidence"},
    title="Next-token probabilities after the last input token",
)