In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

In [2]:
# 0. Set some hyperparameters
Batch_size = 16                                                          # batch size
learning_rate = 5e-3                                                     # learning rate
embedding_dim = 128                                                      # embedding layer dimension
hidden_dim = 256                                                         # hidden (LSTM) layer dimension
epochs = 4                                                               # number of training epochs
verbose = True                                                           # print training progress
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')  # prefer the GPU when available
pre_trained_model_path = None                                            # checkpoint to resume training from (None = train from scratch)
trained_model_path = 'model.pth'                                         # where the trained weights are saved / loaded
start_words = '湖光秋月两相和'                                           # the first line of the generated poem
start_words_acrostic = '轻舟已过万重山'                                  # the "heads" (first character of each line) of the generated acrostic
max_gen_len = 128                                                        # maximum number of tokens in a generated poem
seed = 42                                                                # random seed (weight init, DataLoader shuffling)
torch.manual_seed(seed)                                                  # seeds the CPU and all CUDA devices

In [3]:
# 1. Load data from tang.npz
def prepareData(path="tang.npz", batch_size=None, num_workers=2):
    """Load the Tang-poetry archive and wrap the encoded poems in a shuffled DataLoader.

    The archive holds three entries: ``data`` (the encoded poems),
    ``ix2word`` and ``word2ix`` (the vocabulary mappings, stored as pickled
    objects and unwrapped with ``.item()``).

    Args:
        path: location of the ``.npz`` archive.
        batch_size: DataLoader batch size; ``None`` uses the global ``Batch_size``.
        num_workers: number of DataLoader worker processes.

    Returns:
        (dataloader, ix2word, word2ix)
    """
    if batch_size is None:
        batch_size = Batch_size

    # allow_pickle is needed because the vocabularies are stored as pickled
    # objects. Only load archives from a trusted source: unpickling can run code.
    # The `with` block closes the file handle that the NpzFile keeps open.
    with np.load(path, allow_pickle=True) as datas:
        data = datas['data']
        ix2word = datas['ix2word'].item()
        word2ix = datas['word2ix'].item()

    # Convert to a torch.Tensor so the DataLoader yields tensors directly.
    data = torch.from_numpy(data)
    print(data.shape)  # torch.Size([57580, 125]) for the full tang.npz corpus
    dataloader = DataLoader(data,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=num_workers)
    print(len(dataloader))  # 3599 batches for the full corpus with batch_size=16

    return dataloader, ix2word, word2ix

In [4]:
# Load the corpus DataLoader and both vocabulary mappings.
(dataloader,
 ix2word,
 word2ix) = prepareData()

torch.Size([57580, 125])
3599


In [5]:
# 2. Define PoetryModel class
class PoetryModel(nn.Module):
    """Character-level language model: Embedding -> stacked LSTM -> Linear.

    Args:
        num_embeddings: vocabulary size (used for both the embedding table
            and the output projection).
        embedding_dim: size of each character embedding.
        hidden_dim: LSTM hidden-state size.
        num_layers: number of stacked LSTM layers (default 2; the parameter
            layout of saved state_dicts does not change with this default).
    """

    def __init__(self, num_embeddings, embedding_dim, hidden_dim, num_layers=2):
        super(PoetryModel, self).__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.lstm = nn.LSTM(embedding_dim, self.hidden_dim, num_layers=num_layers)
        self.linear = nn.Linear(self.hidden_dim, num_embeddings)

    def forward(self, input, hidden=None):
        """Run the model over a batch of token-index sequences.

        Args:
            input: integer tensor of shape (seq_len, batch_size). It is
                sequence-first because the LSTM uses the default batch_first=False.
            hidden: optional (h_0, c_0) tuple, each of shape
                (num_layers, batch_size, hidden_dim); zeros when None.

        Returns:
            output: logits of shape (seq_len * batch_size, num_embeddings).
            hidden: the final (h_n, c_n) tuple, which can be fed back in for
                step-by-step generation.
        """
        seq_len, batch_size = input.size()

        if hidden is None:
            # Build the zero states directly in the model's dtype on the input's
            # device, rather than allocating an integer tensor and casting it.
            h_0 = torch.zeros(self.num_layers, batch_size, self.hidden_dim,
                              dtype=self.embedding.weight.dtype, device=input.device)
            c_0 = torch.zeros_like(h_0)
        else:
            h_0, c_0 = hidden

        embeds = self.embedding(input)
        output, hidden = self.lstm(embeds, (h_0, c_0))
        # reshape (not view) stays correct even if the LSTM output is non-contiguous.
        output = self.linear(output.reshape(seq_len * batch_size, -1))
        return output, hidden

In [6]:
# 3. Define train function
def train(dataloader, ix2word, word2ix, log_interval=899):
    """Train a PoetryModel with next-character prediction and save its weights.

    Reads the global config: embedding_dim, hidden_dim, learning_rate,
    epochs, verbose, device, pre_trained_model_path and trained_model_path.

    Args:
        dataloader: DataLoader yielding batches of shape (batch_size, seq_len)
            of token indices.
        ix2word: index -> character mapping. It is not used here; it is kept so
            all the pipeline functions share the same signature.
        word2ix: character -> index mapping; its size sets the vocabulary size.
        log_interval: when ``verbose`` is True, print the loss every
            ``log_interval`` batches.
    """
    # Build the model and optionally resume from a checkpoint.
    model = PoetryModel(len(word2ix), embedding_dim, hidden_dim)
    if pre_trained_model_path:
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        model.load_state_dict(torch.load(pre_trained_model_path, map_location=device))
    model.to(device)
    model.train()

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        for batch_idx, data in enumerate(dataloader):
            # (batch, seq) -> (seq, batch): the model expects sequence-first input.
            data = data.long().transpose(1, 0).contiguous().to(device)
            # Each position is trained to predict the character that follows it.
            input, target = data[:-1, :], data[1:, :]
            output, _ = model(input)
            loss = criterion(output, target.reshape(-1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Test `verbose` with `and`. In `x % n == 0 & verbose`, `&` binds
            # tighter than `==`, so `verbose` would have no effect.
            if verbose and (batch_idx + 1) % log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch + 1, (batch_idx + 1) * dataloader.batch_size, len(dataloader.dataset),
                    100. * (batch_idx + 1) / len(dataloader), loss.item()))

    # Save to the configured path so generate()/gen_acrostic() load the same file.
    torch.save(model.state_dict(), trained_model_path)

In [7]:
# Train the model; its weights are written to disk when training finishes.
train(dataloader=dataloader, ix2word=ix2word, word2ix=word2ix)



In [8]:
# 4. Define generate poetry function
def generate(start_words, ix2word, word2ix):
    """Generate a poem that begins with ``start_words`` using greedy decoding.

    Loads the weights saved at the global ``trained_model_path``. It primes the
    model with <START> followed by each character of ``start_words``. It then
    keeps feeding back the highest-scoring next character until <EOP> is
    produced or ``max_gen_len`` steps have run.

    Args:
        start_words: the opening characters of the poem.
        ix2word: index -> character mapping.
        word2ix: character -> index mapping.

    Returns:
        A list of characters: ``start_words`` followed by the generated text,
        without the <EOP> marker, with at most ``max_gen_len`` items.

    Raises:
        KeyError: if a character of ``start_words`` is not in ``word2ix``.
    """
    # Fail fast with a readable message instead of partway through generation.
    missing = [ch for ch in start_words if ch not in word2ix]
    if missing:
        raise KeyError('characters not in vocabulary: {}'.format(''.join(missing)))

    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    model = PoetryModel(len(word2ix), embedding_dim, hidden_dim)
    model.load_state_dict(torch.load(trained_model_path, map_location=device))
    model.to(device)
    model.eval()

    results = list(start_words)
    start_word_len = len(start_words)

    # Decoding starts from the <START> token, as a (seq_len=1, batch=1) tensor.
    input = torch.tensor([[word2ix['<START>']]], dtype=torch.long, device=device)
    hidden = None

    # Inference only: without no_grad, the carried-over hidden state would
    # chain an autograd graph through every step.
    with torch.no_grad():
        for i in range(max_gen_len):
            output, hidden = model(input, hidden)
            if i < start_word_len:
                # Still priming: feed the next given character and ignore the prediction.
                w = results[i]
                next_index = word2ix[w]
            else:
                # Greedy decoding: take the highest-scoring character.
                next_index = output[0].topk(1)[1][0].item()
                w = ix2word[next_index]
                results.append(w)
            input = torch.tensor([[next_index]], dtype=torch.long, device=device)
            # <EOP> marks the end of the poem; drop the marker and stop.
            if w == '<EOP>':
                del results[-1]
                break

    return results

In [9]:
# Continue the poem that opens with start_words; show the characters and how many there are.
results = generate(start_words, ix2word, word2ix)
print(results, len(results), sep='\n')

['湖', '光', '秋', '月', '两', '相', '和', '，', '一', '片', '云', '山', '无', '一', '声', '。', '一', '朝', '不', '见', '青', '山', '曲', '，', '不', '见', '人', '间', '无', '一', '人', '。', '一', '朝', '不', '见', '青', '山', '曲', '，', '不', '见', '东', '风', '吹', '白', '云', '。', '一', '朝', '不', '见', '青', '山', '曲', '，', '不', '见', '东', '风', '吹', '白', '云', '。', '一', '朝', '不', '见', '青', '山', '曲', '，', '不', '见', '东', '风', '吹', '白', '云', '。', '一', '朝', '不', '见', '青', '山', '曲', '，', '不', '见', '青', '山', '不', '可', '见', '。', '一', '朝', '不', '见', '青', '山', '人', '，', '不', '见', '春', '风', '吹', '白', '雪', '。', '一', '朝', '不', '见', '春', '风', '起', '，', '一', '曲', '花', '前', '花', '下', '来', '。']
128


In [10]:
# 5. Define generate acrostic function
def gen_acrostic(start_words, ix2word, word2ix):
    """Generate an acrostic whose lines start with the characters of ``start_words``.

    Loads the weights saved at the global ``trained_model_path`` and decodes
    greedily. After <START>, '。' or '！', it forces the next character of
    ``start_words`` instead of the model's prediction. Generation stops when
    every head character has been used and the current line has ended, or
    after ``max_gen_len`` steps.

    Args:
        start_words: the head characters, one per line.
        ix2word: index -> character mapping.
        word2ix: character -> index mapping.

    Returns:
        A list of generated characters, with at most ``max_gen_len`` items.

    Raises:
        KeyError: if a character of ``start_words`` is not in ``word2ix``.
    """
    # Fail fast with a readable message instead of partway through generation.
    missing = [ch for ch in start_words if ch not in word2ix]
    if missing:
        raise KeyError('characters not in vocabulary: {}'.format(''.join(missing)))

    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    model = PoetryModel(len(word2ix), embedding_dim, hidden_dim)
    model.load_state_dict(torch.load(trained_model_path, map_location=device))
    model.to(device)
    model.eval()

    results = []
    start_word_len = len(start_words)
    line_starters = {'。', '！', '<START>'}  # after these, the next head character is forced

    # Decoding starts from the <START> token, as a (seq_len=1, batch=1) tensor.
    input = torch.tensor([[word2ix['<START>']]], dtype=torch.long, device=device)
    hidden = None

    index = 0             # position of the next head character in start_words
    pre_word = '<START>'  # previously emitted token

    # Inference only: without no_grad, the carried-over hidden state would
    # chain an autograd graph through every step.
    with torch.no_grad():
        for _ in range(max_gen_len):
            output, hidden = model(input, hidden)
            top_index = output[0].topk(1)[1][0].item()
            w = ix2word[top_index]

            if pre_word in line_starters:
                if index == start_word_len:
                    break  # every head character has been used
                # Replace the prediction with the next head character.
                w = start_words[index]
                index += 1

            # Whatever was emitted (head or prediction) becomes the next input.
            input = torch.tensor([[word2ix[w]]], dtype=torch.long, device=device)
            results.append(w)
            pre_word = w

    return results

In [11]:
# Generate the acrostic headed by start_words_acrostic and show its characters.
results_acrostic = gen_acrostic(
    start_words_acrostic,
    ix2word,
    word2ix,
)
print(results_acrostic)

['轻', '生', '不', '得', '意', '，', '不', '得', '不', '得', '知', '。', '舟', '中', '有', '奇', '气', '，', '不', '得', '不', '得', '持', '。', '已', '闻', '天', '上', '来', '，', '不', '得', '不', '得', '宁', '。', '过', '此', '不', '可', '见', '，', '不', '知', '何', '处', '期', '。', '万', '里', '不', '可', '见', '，', '一', '朝', '无', '人', '知', '。', '重', '阳', '不', '可', '见', '，', '一', '日', '不', '可', '攀', '。', '山', '川', '有', '高', '树', '，', '山', '水', '无', '人', '舟', '。']
