In [1]:
import html
import math
import string

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from IPython.display import HTML
from tqdm.notebook import tqdm

from training import MyNet


In [2]:
# Training corpus: a plain-text file read line by line further down
TEXTFILE_PATH = "./shakespeare.txt"
# Trained model checkpoint, loaded whole with torch.load further down
MODEL_PATH = "./model_bk2.pt"

In [3]:
# Vocabulary: every printable ASCII character. A character's position in
# this string is its token id.
all_characters = string.printable
N_CHAR = len(all_characters)
display(all_characters)
# Example: converting a single character (here the space) into its token id
all_characters.index(" ")

'0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ \t\n\r\x0b\x0c'

94

In [4]:
# Load the trained model.
# SECURITY: torch.load unpickles the whole MyNet object, which can execute
# arbitrary code -- only load checkpoints from a trusted source.
# NOTE(review): since torch 2.6, torch.load defaults to weights_only=True, which
# rejects whole-module pickles like this one; on such versions pass
# weights_only=False (trusted files only).
# Fall back to the CPU when no GPU is available instead of crashing in .cuda().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.load(MODEL_PATH, map_location=device)
model.to(device)

MyNet(
  (emb): Embedding(100, 100)
  (rnn): LSTM(100, 100, num_layers=2, batch_first=True)
  (softmax): Softmax(dim=-1)
  (layer): Linear(in_features=100, out_features=100, bias=True)
)

In [5]:
# Try text generation: repeatedly feed the current text to the model and
# greedily append the most likely next character.

# Number of characters to append to the prompt
N_GENERATE = 20
# Initial prompt
text = "W"

# Build inputs on whatever device the model lives on instead of hard-coding CUDA.
device = next(model.parameters()).device

with torch.no_grad():
    model.eval()

    for _ in range(N_GENERATE):
        # Encode the current text as a (1, seq_len) tensor of token ids
        vec = torch.tensor([all_characters.index(c) for c in text], device=device).unsqueeze(0)
        # Output for the single batch item; presumably one score vector per
        # position -- TODO confirm against MyNet.forward
        predict = model(vec)[0]

        # Only the prediction at the last position determines the next
        # character, so take the argmax there rather than at every position.
        text += all_characters[torch.argmax(predict[-1]).item()]

    print(text)

Whe the the the the t


In [6]:
# Read the corpus as a list of lines.
# The encoding is explicit because open() otherwise uses the platform's locale
# encoding (e.g. cp932 on Japanese Windows); assumes the file is UTF-8/ASCII.
with open(TEXTFILE_PATH, encoding="utf-8") as f:
    l_text = f.read().split("\n")

display(l_text[:10])

['First Citizen:Before we proceed any further, hear me speak.',
 'All:Speak, speak.',
 'First Citizen:You are all resolved rather to die than to famish?',
 'All:Resolved. resolved.',
 'First Citizen:First, you know Caius Marcius is chief enemy to the people.',
 "All:We know't, we know't.",
 "First Citizen:Let us kill him, and we'll have corn at our own price.Is't a verdict?",
 "All:No more talking on't; let it be done: away, away!?",
 'Second Citizen:One word, good citizens.',
 'First Citizen:We are accounted poor citizens, the patricians good.What authority surfeits on would relieve us: if theywould yield us but the superfluity, while it werewholesome, we might guess they relieved us humanely;but they think we are too dear: the leanness thatafflicts us, the object of our misery, is as aninventory to particularise their abundance; oursufferance is a gain to them Let us revenge this withour pikes, ere we become rakes: for the gods know Ispeak this in hunger for bread, not in thirst for 

In [7]:
def get_color(x: float) -> str:
    """Map a score to a white-to-red CSS hex colour.

    0 maps to white (#ffffff) and 1 maps to pure red (#ff0000). Scores outside
    [0, 1] are clamped and NaN is treated as 0, so the result is always a
    valid 7-character colour code.
    """
    # NaN would make int() raise; out-of-range values would produce invalid hex
    if math.isnan(x):
        x = 0.0
    x = min(max(x, 0.0), 1.0)
    color_r = 255
    # The closer to 1, the lower green and blue become, i.e. the redder the colour
    color_g = int((1.0 - x) * 255)
    color_b = int((1.0 - x) * 255)
    return f"#{color_r:02x}{color_g:02x}{color_b:02x}"

def character(char: str, score: float) -> str:
    """Wrap ``char`` in a <span> whose background colour encodes ``score``.

    The character is HTML-escaped so that text containing <, > or & renders
    literally instead of being interpreted as markup.
    """
    return f'<span style="background-color: {get_color(score)}">{html.escape(char)}</span>'


In [8]:

# Take an arbitrary span of lines from the corpus and join them with spaces.
# END_ROW is inclusive, hence the +1 on the slice end.
START_ROW, END_ROW = 20, 24
text = " ".join(l_text[START_ROW : END_ROW + 1])


In [10]:
# Which LSTM state to visualise: 0 = hidden state h_n, 1 = cell state c_n
HN_or_CN = 0
# Which of the two stacked LSTM layers to visualise: 0 or 1
HIDDEN_LAYER = 0

# Build inputs on whatever device the model lives on instead of hard-coding CUDA.
device = next(model.parameters()).device

with torch.no_grad():
    model.eval()

    # Encode the text as a (1, seq_len) tensor of token ids
    vec = torch.tensor([all_characters.index(c) for c in text], device=device).unsqueeze(0)

    # Feed the text one character at a time, carrying the LSTM state forward.
    # After step t the state equals what running the LSTM on vec[:, :t+1]
    # returns, so this matches re-running every prefix but costs O(n), not O(n^2).
    state = None
    per_step = []
    for t in range(vec.size(1)):
        _, state = model.rnn(model.emb(vec[:, t:t + 1]), state)
        per_step.append(state[HN_or_CN])
    # Each nn.LSTM state is (num_layers, batch, hidden_size), so this is
    # (seq_len, num_layers, batch, hidden_size)
    results = torch.stack(per_step)

    # Trajectory of every node of the chosen layer, copied to the CPU once:
    # (seq_len, hidden_size)
    layer_values = results[:, HIDDEN_LAYER, 0, :].cpu().numpy()

    # Show how each node's value evolves across the text, one row per node
    HIDDEN_SIZE = results.size(3)
    for idx in range(HIDDEN_SIZE):
        ary = layer_values[:, idx]
        # Min-max normalise to [0, 1]; a constant node would divide by zero
        # (NaN), so show it as all zeros instead
        span = ary.max() - ary.min()
        data = (ary - ary.min()) / span if span > 0 else np.zeros_like(ary)

        display(HTML("".join(
            character(char, score)
            for char, score in zip(text, data.tolist())
        )))