In [3]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [4]:
DATA_PATH = 'data/POEM/tang.npz'  # .npz archive holding 'data', 'ix2word' and 'word2ix'
BATCH_SIZE = 128
NUM_WORKERS = 2      # DataLoader worker processes
LR = 5e-3            # Adam learning rate
EMBEDDING_DIM = 128  # token embedding width
HIDDEN_DIM = 256     # LSTM hidden-state width
EPOCHS = 10
VERBOSE = True       # NOTE(review): not read by any cell shown here
SEED = 42

# Seed the RNGs up front so DataLoader shuffling and weight init are
# reproducible across Restart-&-Run-All.
np.random.seed(SEED)
torch.manual_seed(SEED)
 

In [5]:
def prepareData(path=None, batch_size=None, num_workers=None):
    """Load the poetry corpus and wrap its token matrix in a shuffling DataLoader.

    Args:
        path: .npz archive containing 'data', 'ix2word' and 'word2ix'.
            Defaults to the DATA_PATH config constant.
        batch_size: DataLoader batch size; defaults to BATCH_SIZE.
        num_workers: DataLoader worker processes; defaults to NUM_WORKERS.

    Returns:
        (dataloader, ix2word, word2ix, datas): a shuffling DataLoader over the
        'data' array (as a tensor), the two vocabulary dicts, and the NpzFile
        they were read from (left open; the caller owns it).
    """
    # Resolve defaults at call time so the config cell stays the single
    # source of truth (the path used to be hard-coded here a second time).
    if path is None:
        path = DATA_PATH
    if batch_size is None:
        batch_size = BATCH_SIZE
    if num_workers is None:
        num_workers = NUM_WORKERS

    # allow_pickle is needed because the vocab dicts are stored as pickled
    # object arrays -- only load archives from a trusted source.
    datas = np.load(path, allow_pickle=True)
    # .item() unwraps each 0-d object array back into a plain dict.
    ix2word = datas['ix2word'].item()
    word2ix = datas['word2ix'].item()
    data = torch.from_numpy(datas['data'])
    dataloader = DataLoader(data,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=num_workers)
    return dataloader, ix2word, word2ix, datas

# Load the corpus and vocabulary once; later cells reuse these four names.
(dataloader,
 ix2word,
 word2ix,
 datas) = prepareData()


In [6]:
# Vocabulary size (entries in word2ix); the model below is built with this many tokens.
len(word2ix)

8293

In [7]:
class PoetryModel(nn.Module):
    """Token-level LSTM language model: embedding -> LSTM -> vocabulary logits.

    Args:
        vocab_size: number of tokens; sizes the embedding table and output layer.
        embedding_dim: width of each token embedding.
        hidden_dim: LSTM hidden/cell state width.
        num_layers: stacked LSTM layers (default 1, the previously hard-coded value).
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers=1) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, self.hidden_dim,
                            num_layers=num_layers, batch_first=True)
        self.linear = nn.Linear(self.hidden_dim, vocab_size)

    def forward(self, input, hidden=None):
        """Run the model over a batch of token-id sequences.

        Args:
            input: integer token ids shaped (batch, seq_len).
            hidden: optional (h_0, c_0) carried over from a previous call;
                None starts from an all-zero state.

        Returns:
            (logits, (h_n, c_n)) where logits is (batch * seq_len, vocab_size),
            with row index b * seq_len + t.
        """
        embeds = self.embeddings(input)
        # nn.LSTM zero-initialises (h_0, c_0) itself when hx is None, with the
        # correct num_layers, dtype and device; building it by hand with
        # .float() broke double/half-precision models.
        output, hidden = self.lstm(embeds, hidden)
        output = output.reshape(-1, self.hidden_dim)
        output = self.linear(output)
        return output, hidden


In [8]:
def train(model, optimizer, criterion, dataloader, epochs=None, log_interval=200):
    """Train a next-token language model on batches of token-id sequences.

    Each batch is split into an input (every token but the last) and a target
    (every token but the first), so the model learns to predict token t+1 from
    tokens 0..t.

    Args:
        model: module whose forward(input) returns (logits, hidden); logits are
            expected to be (batch * (seq_len - 1), vocab_size) to match the
            flattened target.
        optimizer: optimizer over model.parameters().
        criterion: loss taking (logits, flat integer targets), e.g. CrossEntropyLoss.
        dataloader: yields integer tensors shaped (batch, seq_len).
        epochs: passes over dataloader; defaults to the EPOCHS config constant.
        log_interval: print a progress line every this many batches.
    """
    if epochs is None:
        epochs = EPOCHS
    # Move batches to wherever the model's weights live, so they always match.
    model_device = next(model.parameters()).device
    n_batches = len(dataloader)
    n_samples = len(dataloader.dataset)

    model.train()
    print(">>>>>> Model Train Begin......")
    for epoch_idx in range(epochs):
        epoch_loss = 0.0
        seen = 0
        print(f"Epoch {epoch_idx}\n-------------------------------")
        for batch_idx, batch_data in enumerate(dataloader):
            batch_data = batch_data.to(model_device)
            # Shift by one position: predict batch_data[:, t + 1] from the prefix.
            input = batch_data[:, :-1]
            target = batch_data[:, 1:].reshape(-1).long()
            output, _ = model(input)
            loss = criterion(output, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # .item() yields a plain float; summing the tensor itself would keep
            # every batch's autograd graph alive until the epoch ends.
            epoch_loss += loss.item()
            # Count rows (batch dimension), not the length of one sequence.
            seen += batch_data.size(0)
            if (batch_idx + 1) % log_interval == 0:
                print(f'[{seen}/{n_samples} '
                      f'({100. * (batch_idx + 1) / n_batches:.0f}%)]\t'
                      f'loss: {loss.item():.6f}')
        # max(..., 1) guards against an empty dataloader.
        print(f'Epoch {epoch_idx}\tloss= {epoch_loss / max(n_batches, 1):.6f}')
    print(">>>>>> Model Train End.")


In [9]:
def generate(model, start_words, ix2word, word2ix, max_gen_len):
    """Greedily generate a poem that begins with ``start_words``.

    The model is first fed <START>. Each character of ``start_words`` is then
    fed in turn (forced), and the model's output at those steps is ignored.
    After that, the most likely next token is fed back repeatedly. Generation
    stops when <EOP> is produced or after ``max_gen_len`` model steps.

    Args:
        model: module with forward(input, hidden) -> (logits, hidden), where
            logits row 0 scores the next token for a (1, 1) input.
        start_words: iterable of tokens (e.g. a string of characters) to begin with.
        ix2word: index -> token dict.
        word2ix: token -> index dict; must contain '<START>'.
        max_gen_len: maximum number of model steps, including the forced steps.

    Returns:
        List of tokens, starting with the characters of ``start_words``,
        without the terminating <EOP>.

    Raises:
        KeyError: if any start word is not in the vocabulary.
    """
    results = list(start_words)
    missing = [w for w in results if w not in word2ix]
    if missing:
        raise KeyError(f"start words not in vocabulary: {missing}")
    start_words_len = len(results)

    # Build inputs on the model's own device, directly as a long tensor.
    model_device = next(model.parameters()).device

    def _as_input(index):
        return torch.tensor([[index]], dtype=torch.long, device=model_device)

    token = _as_input(word2ix['<START>'])
    hidden = None
    model.eval()
    with torch.no_grad():
        for i in range(max_gen_len):
            output, hidden = model(token, hidden)
            if i < start_words_len:
                # Forced step: feed the next user-supplied word.
                next_index = word2ix[results[i]]
            else:
                next_index = output[0].argmax().item()
                w = ix2word[next_index]
                if w == '<EOP>':
                    # Stop without appending the end marker. The old code could
                    # delete a start word here when <EOP> appeared while forcing.
                    break
                results.append(w)
            token = _as_input(next_index)
    return results

In [10]:
# Build the model, then either load the saved checkpoint or train and save one.
PATH = './AutomaticWritingPoems.pt'
model = PoetryModel(len(word2ix), EMBEDDING_DIM, HIDDEN_DIM).to(device=device)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()
if os.path.exists(PATH):
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    # torch.load unpickles the file: only load checkpoints from a trusted source.
    model.load_state_dict(torch.load(PATH, map_location=device))
else:
    # No checkpoint yet: train from scratch (slow) and cache the weights.
    train(model, optimizer, criterion, dataloader)
    torch.save(model.state_dict(), PATH)


<All keys matched successfully>

In [31]:
# Print the generated poem token by token, starting a new row after every
# full-width comma or full stop so each verse sits on its own line.
for token in generate(model, "白日依山尽", ix2word, word2ix, 30):
    print(token, end=" ")
    if token in ('，', '。'):
        print()
    


白 日 依 山 尽 ， 
青 山 万 里 赊 。 
山 川 无 一 事 ， 
江 水 有 归 期 。 
白 日 无 人 识 ， 
