In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.autograd as autograd
import torch.utils.data.dataloader as dataloader
import random
import tqdm
import gensim
import time
import math
import numpy as np

In [2]:
with open('poetry.txt') as f:
#     corpus = f.readlines()
    corpus = f.read()

# corpus = list(map(lambda x: x.replace('\n', ''), corpus))
corpus = corpus.replace('\n', ' ').replace('\r', ' ')

In [3]:
idx_to_char = set()
idx_to_char.update(*corpus)
idx_to_char = list(idx_to_char)
char_to_idx = dict([(char, i) for i, char in enumerate(idx_to_char)])
vocab_size = len(char_to_idx)
vocab_size

5387

In [4]:
corpus_indice = list(map(lambda x: [char_to_idx[c] for c in x], corpus))
corpus_indice.sort(key=len, reverse=True)

In [5]:
def data_iter_consecutive(corpus_indices, batch_size, num_steps):
    corpus_indices = torch.tensor(corpus_indices)
    data_len = len(corpus_indices)
    batch_len = data_len // batch_size
    indices = corpus_indices[0: batch_size*batch_len].reshape((
        batch_size, batch_len))
    epoch_size = (batch_len - 1) // num_steps
    for i in range(epoch_size):
        i = i * num_steps
        X = indices[:, i: i + num_steps]
        Y = indices[:, i + 1: i + num_steps + 1]
        yield X, Y

In [7]:
# weight = torch.nn.Embedding(vocab_size, 400).weight.data
weight = torch.diag(torch.ones(vocab_size))

In [8]:
class lyricNet(nn.Module):
    def __init__(self, hidden_dim, embed_dim, num_layers, weight,
                 num_labels, bidirectional, dropout=0.5, **kwargs):
        super(lyricNet, self).__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.num_labels = num_labels
        self.bidirectional = bidirectional
        if num_layers <= 1:
            self.dropout = 0
        else:
            self.dropout = dropout
        self.embedding = nn.Embedding.from_pretrained(weight)
        self.embedding.weight.requires_grad = False
#         self.embedding = nn.Embedding(num_labels, self.embed_dim)
        self.rnn = nn.GRU(input_size=self.embed_dim, hidden_size=self.hidden_dim,
                          num_layers=self.num_layers, bidirectional=self.bidirectional,
                          dropout=self.dropout)
        if self.bidirectional:
            self.decoder = nn.Linear(hidden_dim * 2, self.num_labels)
        else:
            self.decoder = nn.Linear(hidden_dim, self.num_labels)
            
    def forward(self, inputs, hidden=None):
        embeddings = self.embedding(inputs)
        states, hidden = self.rnn(embeddings.permute([1, 0, 2]), hidden)
        outputs = self.decoder(states.reshape((-1, states.shape[-1])))
        return(outputs, hidden)
    
    def init_hidden(self, num_layers, batch_size, hidden_dim, **kwargs):
        hidden = torch.zeros(num_layers, batch_size, hidden_dim)
        return hidden

In [9]:
def predict_rnn(prefix, num_chars, model, device, idx_to_char, char_to_idx):
    output = [char_to_idx[prefix[0]]]
    hidden = torch.zeros(num_layers, 1, hidden_dim)
    if use_gpu:
        hidden = hidden.to(device)
    for t in range(num_chars + len(prefix) - 1):
        X = torch.tensor([output[-1]]).reshape((1, 1))
        if use_gpu:
            X = X.to(device)
        pred, hidden = model(X, hidden)
        if t < len(prefix) - 1:
            output.append(char_to_idx[prefix[t + 1]])
        elif pred.argmax(dim=1) == vocab_size:
            break
        else:
            output.append(int(pred.argmax(dim=1)))
    return(''.join([idx_to_char[i] for i in output]))

In [10]:
embedding_dim = 400
hidden_dim = 256
lr = 1e2
momentum = 0.0
num_epoch = 100
use_gpu = True
num_layers = 1
bidirectional = False
batch_size = 32
device = torch.device('cuda:0')
loss_function = nn.CrossEntropyLoss()

In [11]:
model = lyricNet(hidden_dim=hidden_dim, embed_dim=vocab_size, num_layers=num_layers,
                 num_labels=vocab_size, weight=weight, bidirectional=bidirectional)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
if use_gpu:
#     model = nn.DataParallel(model)
    model.to(device)

In [12]:
predict_rnn('春', 11, model, device, idx_to_char, char_to_idx)

'春燋瀰洎表燋恃恃瀰洎燋恃'

In [13]:
since = time.time()
for epoch in range(num_epoch):
    start = time.time()
    num, total_loss = 0, 0
#     if epoch == 5000:
#         optimizer.param_groups[0]['lr'] = lr * 0.1
    data = data_iter_consecutive(corpus_indice, batch_size, 35)
    hidden = model.init_hidden(num_layers, batch_size, hidden_dim)
    for X, Y in data:
        num += 1
        hidden.detach_()
        if use_gpu:
            X = X.to(device)
            Y = Y.to(device)
            hidden = hidden.to(device)
        optimizer.zero_grad()
        output, hidden = model(X, hidden)
        l = loss_function(output, Y.t().reshape((-1,)))
        l.backward()
        norm = nn.utils.clip_grad_norm_(model.parameters(), 1e-2)
        optimizer.step()
        total_loss += l.item()
    end = time.time()
    s = end - since
    h = math.floor(s / 3600)
    m = s - h * 3600
    m = math.floor(m / 60)
    s -= m * 60
    if (epoch % 10 == 0) or (epoch == (num_epoch - 1)):
        print('epoch %d/%d, loss %.4f, norm %.4f, time %.3fs, since %dh %dm %ds'
              %(epoch+1, num_epoch, total_loss / num, norm, end-start, h, m, s))
        print(predict_rnn('春', 47, model, device, idx_to_char, char_to_idx))
        print(predict_rnn('夏', 47, model, device, idx_to_char, char_to_idx))

epoch 1/100, loss 5.7401, norm 0.1910, time 7.659s, since 0h 0m 7s
春，风人人人。 不日不不远，相人有不人。 不日不不远，相人有不人。 不日不不远，相人有不人。 不
夏，风人人人。 不日不不远，相人有不人。 不日不不远，相人有不人。 不日不不远，相人有不人。 不
epoch 11/100, loss 4.1388, norm 0.3108, time 8.270s, since 0h 1m 31s
春风吹早晚，人语不相见。 何事不相见，空山有所思。 不知何处去，不是有时情。 不是愁人意，相逢不
夏云中路远，何处是愁人。 不见南山色，相逢白发生。 不知今日日，不是故人情。 不见东风水，相逢不
epoch 21/100, loss 3.7344, norm 0.4362, time 8.435s, since 0h 2m 55s
春风吹早晚，独有清风起。 不见君王孙，年年不堪把。 春风生，有人意不如何人。 一点玉壶中，一枝枝
夏云无处所。 一径入幽林，千峰挂石床。 白云生白发，青鸟入窗枝。 清景不可得，清风何处归。 今日
epoch 31/100, loss 3.5179, norm 0.4746, time 8.451s, since 0h 4m 18s
春风吹满衣。 一点白头翁，数茎红叶红。 夜久不相见，空歌声满长。 病起见庭花，今年独自迟。 不知
夏云生故国。 日暮天下寒，风波上林晚。 烟波生处尽，风景澹凄凄。 此地堪相望，知君道不平。 万里
epoch 41/100, loss 3.3755, norm 0.5808, time 8.465s, since 0h 5m 42s
春风雨夜长。 不知楼上月，时得似君王。 白发今如此，青云自可亲。 不知名利利，何事更伤心。 自古
夏云无处所。 一为一瓢酒，还似江上月。 何人知此来，独向何处去。 何事东归来，空山有秋色。 心知
epoch 51/100, loss 3.2706, norm 0.6426, time 8.457s, since 0h 7m 6s
春风雨夜行。 旧业今朝夕，清江上苑阳。 寒山对孤屿，秋水正秋风。 何必沧浪水，相看楚客还。 一身
夏云无处所。 何人知此去，见古今古今。 白日照我心，白云自相逢。 不知何所有，唯有旧时心。 一片
e