In [1]:
%load_ext lab_black
%load_ext autoreload
%autoreload 2

In [3]:
import sys
from pathlib import Path

sys.path.append("..")

In [4]:
import argparse

In [16]:
import numpy as np
import pandas as pd
import torch
import transformers

In [6]:
from transformers import BertLMHeadModel, BertTokenizerFast

In [91]:
import settings
from funcs.data_module import DataModule
from funcs.model_module import ModelModule
from funcs.utils import find_project_root
from scripts.train import create_parser
from funcs.text_generation import generate_poetry_text

In [10]:
ROOT = find_project_root()

In [11]:
MODEL_NAME = settings.chinese_bert_model_name
CHECKPOINT_PATH = (
    ROOT / "lightning_logs" / "version_3" / "checkpoints" / "epoch=19.ckpt"
)
assert CHECKPOINT_PATH.exists()

In [12]:
flags = ""
parser = create_parser()
args = parser.parse_args(flags)
print(args)

Namespace(adam_epsilon=1e-08, batch_size=64, dry_run=False, fp16=False, gpus=1, learning_rate=0.001, max_tokenization_length=128, num_train_epochs=1, num_workers=8, overwrite=False, seed=42, weight_decay=0.0)


In [13]:
tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)
tokenizer

<transformers.tokenization_bert.BertTokenizerFast at 0x7fe69156ca10>

In [14]:
finetuned_model = ModelModule(args=args).load_from_checkpoint(
    str(CHECKPOINT_PATH), args=args
)
finetuned_model

2020-10-21 14:42:36.761 | INFO     | funcs.model_module:__init__:40 - model hparams: "adam_epsilon":     1e-08
"learning_rate":    0.001
"num_train_epochs": 1
"weight_decay":     0.0
If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`
Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertLMHeadModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertLMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertLMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
2020-10-21 14:42:42.434 | INFO     | funcs.model_module:__init__:40 - model hparams: "adam_epsilon":     1e

ModelModule(
  (bert_model): BertLMHeadModel(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(21128, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
        

In [15]:
data_module = DataModule()
data_module.setup()

2020-10-21 14:42:47.363 | INFO     | funcs.data_module:__init__:46 - data module hparams: {'max_tokenization_length': 128, 'batch_size': 16, 'num_workers': 2, 'min_word_frequency': 8}
2020-10-21 14:42:47.365 | INFO     | funcs.data_module:setup:49 - Loading train dataset
2020-10-21 14:42:47.368 | INFO     | funcs.data_module:get_dataset:84 - load from path: /work/ik18445/projects/yapg/datasets/output/poetry_128.pt


In [63]:
poetry_vocab = pd.read_csv(ROOT / settings.path_to_chinese_poetry_vocab)
poetry_vocab

Unnamed: 0,token,id
0,砺,4791
1,盂,4655
2,均,1772
3,模,3563
4,燔,4241
...,...,...
4819,驮,7718
4820,泠,3795
4821,惶,2684
4822,逢,6864


In [92]:
starting_text = "床前明月光"
text = generate_poetry_text(
    tokenizer=tokenizer,
    model=finetuned_model,
    starting_text=starting_text,
    poetry_vocab_ids=poetry_vocab["id"].tolist(),
    verbose=True,
    max_doc_length=20,
)

step #0


  top_k_probs = torch.nn.functional.softmax(top_k_logits).numpy()


床前明月光向
step #1
床前明月光向何
step #2
床前明月光向何知
step #3
床前明月光向何知不
step #4
床前明月光向何知不，
step #5
床前明月光向何知不，年
step #6
床前明月光向何知不，年吟
step #7
床前明月光向何知不，年吟天
step #8
床前明月光向何知不，年吟天下
step #9
床前明月光向何知不，年吟天下生
step #10
床前明月光向何知不，年吟天下生来
step #11
床前明月光向何知不，年吟天下生来君
step #12
床前明月光向何知不，年吟天下生来君三
step #13
床前明月光向何知不，年吟天下生来君三头
step #14
床前明月光向何知不，年吟天下生来君三头此
step #15
床前明月光向何知不，年吟天下生来君三头此明
step #16
床前明月光向何知不，年吟天下生来君三头此明花
step #17
床前明月光向何知不，年吟天下生来君三头此明花不
step #18
床前明月光向何知不，年吟天下生来君三头此明花不为
step #19
床前明月光向何知不，年吟天下生来君三头此明花不为上


In [93]:
starting_text = "数据挖掘哪家强"
text = generate_poetry_text(
    tokenizer=tokenizer,
    model=finetuned_model,
    starting_text=starting_text,
    poetry_vocab_ids=poetry_vocab["id"].tolist(),
    verbose=True,
    max_doc_length=20,
)

step #0


  # if verbose:


数据挖掘哪家强。
step #1
数据挖掘哪家强。君
step #2
数据挖掘哪家强。君吟
step #3
数据挖掘哪家强。君吟思
step #4
数据挖掘哪家强。君吟思日
step #5
数据挖掘哪家强。君吟思日。
step #6
数据挖掘哪家强。君吟思日。三
step #7
数据挖掘哪家强。君吟思日。三上
step #8
数据挖掘哪家强。君吟思日。三上名
step #9
数据挖掘哪家强。君吟思日。三上名何
step #10
数据挖掘哪家强。君吟思日。三上名何水
step #11
数据挖掘哪家强。君吟思日。三上名何水前
step #12
数据挖掘哪家强。君吟思日。三上名何水前秋
step #13
数据挖掘哪家强。君吟思日。三上名何水前秋时
step #14
数据挖掘哪家强。君吟思日。三上名何水前秋时不
step #15
数据挖掘哪家强。君吟思日。三上名何水前秋时不人
step #16
数据挖掘哪家强。君吟思日。三上名何水前秋时不人去
step #17
数据挖掘哪家强。君吟思日。三上名何水前秋时不人去处
step #18
数据挖掘哪家强。君吟思日。三上名何水前秋时不人去处未
step #19
数据挖掘哪家强。君吟思日。三上名何水前秋时不人去处未到


In [55]:
text

'生活敬咆咆础〈彌[unused69]含孖桎ㄞ〈〈傣彌〈〈[unused69]悲宙仰瞭〈孖7噬彌噬沃器〈彌圾浦〈写〈彌換傣〈含彌壯含〈7〈彌[unused32]梦梦寮〈彌冪强础滔彌'

# debug

In [71]:
model = finetuned_model

In [67]:
def get_logits(encodings):
    inputs = {
        "input_ids": encodings["input_ids"],
        "attention_mask": encodings["attention_mask"],
        "token_type_ids": encodings["token_type_ids"],
        "labels": encodings["input_ids"],
    }
    outputs = model(**inputs)
    logits = outputs["logits"].detach()
    return logits

In [69]:
target_text = "冬天"
encode_text = target_text + tokenizer.mask_token
print(encode_text)

冬天[MASK]


In [72]:
encodings = tokenizer(
    encode_text,
    return_tensors="pt",
    truncation=True,
    padding="max_length",
    max_length=128,
)
logits = get_logits(encodings)

In [73]:
print(logits.shape)

torch.Size([1, 128, 21128])


In [76]:
poetry_vocab_ids = poetry_vocab["id"].tolist()
example_ids = poetry_vocab_ids[:5]
example_ids

[4791, 4655, 1772, 3563, 4241]

In [75]:
mask_token_index = torch.where(encodings["input_ids"] == tokenizer.mask_token_id)[1]
mask_token_index

tensor([3])

In [79]:
logits[0, mask_token_index, [0, 1, 2, 3]]

tensor([ 8.3727, -9.5713, -9.5940, -9.5951])

In [78]:
mask_token_logits = logits[0, mask_token_index, example_ids]
mask_token_logits

tensor([-5.0340, -2.6134, -2.7907, -5.1408, -7.0795])

In [82]:
topk = torch.topk(mask_token_logits, 4, dim=0)
topk

torch.return_types.topk(
values=tensor([-2.6134, -2.7907, -5.0340, -5.1408]),
indices=tensor([1, 2, 0, 3]))