# 0 MODULE and DATASET

In [13]:
# 0.0 MODULE
import numpy as np
import torch
import torch.nn as nn 
import torch.optim as optim
import time 
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import *
from torch.utils.data import DataLoader
from torch.autograd import Variable

# 0.1 CONSTANT
BATCH_SIZE = 16
EPOCHS = 4 
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"{DEVICE =}")

# 0.2 FUNC :DATA
def prepareData():
    """ 
    Description:load the data from ./tang.npz 
    Returns:
    """
    datas = np.load("./tang.npz",allow_pickle=True)
    print(f"==prepareData()=={type(datas) = }")

    data = torch.from_numpy(datas["data"])
    ix2word = datas["ix2word"].item()
    word2ix = datas["word2ix"].item() # no .item() is OK

    train_dataloader = DataLoader(data,batch_size=BATCH_SIZE,shuffle=True,num_workers=2)
    print(f"==prepareData()=={word2ix = }")
    return train_dataloader, ix2word, word2ix

train_dataloader, ix2word, word2ix = prepareData()


==prepareData()==type(datas) = <class 'numpy.lib.npyio.NpzFile'>
==prepareData()==word2ix = {'憁': 0, '耀': 1, '枅': 2, '涉': 3, '谈': 4, '伊': 5, '鈌': 6, '薙': 7, '亟': 8, '洞': 9, '猢': 10, '悫': 11, '缪': 12, '河': 13, '临': 14, '犷': 15, '吸': 16, '碻': 17, '娼': 18, '线': 19, '反': 20, '牌': 21, '雏': 22, '姑': 23, '硐': 24, '葘': 25, '卢': 26, '知': 27, '除': 28, '彪': 29, '菭': 30, '觱': 31, '勷': 32, '闹': 33, '壻': 34, '睺': 35, '廿': 36, '覆': 37, '辊': 38, '墐': 39, '应': 40, '蠛': 41, '踟': 42, '儡': 43, '皤': 44, '騧': 45, '崩': 46, '芬': 47, '冶': 48, '骼': 49, '嶾': 50, '腁': 51, '赉': 52, '濊': 53, '噏': 54, '痎': 55, '湟': 56, '跞': 57, '囤': 58, '峨': 59, '括': 60, '棰': 61, '豳': 62, '滴': 63, '鉷': 64, '须': 65, '孽': 66, '鹿': 67, '悁': 68, '沉': 69, '上': 70, '犍': 71, '后': 72, '碑': 73, '餐': 74, '娒': 75, '纕': 76, '怖': 77, '蒙': 78, '喣': 79, '隥': 80, '砦': 81, '枯': 82, '涳': 83, '伴': 84, '玷': 85, '逶': 86, '离': 87, '阺': 88, '羿': 89, '杄': 90, '浈': 91, '勍': 92, '壑': 93, '祐': 94, '廕': 95, '罔': 96, '盥': 97, '幪': 98, '号': 99, '灶': 100, '姻': 10

# 1 MODEL

In [17]:
# 1.0 CLASS: model
class PoetryModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(PoetryModel,self).__init__()
        self.hidden_dim = hidden_dim
        self.embedding = nn.Embedding(vocab_size,embedding_dim)
        self.lstm = nn.LSTM(embedding_dim,self.hidden_dim,num_layers=3)

        self.classifier=nn.Sequential(
            nn.Linear(self.hidden_dim, 512), 
            nn.ReLU(inplace=True), 
            nn.Linear(512, 2048), 
            nn.ReLU(inplace=True),
            nn.Linear(2048, vocab_size)
        )

    def forward(self, input, hidden = None):
        seq_len, batch_size = input.size()
        
        if hidden is None:
            h_0 = input.data.new(3, batch_size, self.hidden_dim).fill_(0).float()
            c_0 = input.data.new(3, batch_size, self.hidden_dim).fill_(0).float()
        else:
            h_0, c_0 = hidden

        embeds = self.embedding(input)
        output, hidden = self.lstm(embeds, (h_0, c_0))
        output = self.classifier(output.view(seq_len * batch_size, -1))
        
        return output, hidden

# 1.1 FUNC train
model = PoetryModel(len(word2ix),embedding_dim=256,hidden_dim=256)

model_pretrained = ''         # pretrained
if model_pretrained:
    model.load_state_dict(torch.load(model_pretrained))

model.to(DEVICE)

optimizer = optim.Adam(model.parameters(), lr = 3e-3)
criterion = nn.CrossEntropyLoss()
scheduler = StepLR(optimizer, step_size=10, gamma=0.1) # change lr per 10 epochs，gamma 默认为 0.1 Adam 会自己调  这一行可以不要

def train(model, dataloader, ix2word, word2ix, device, optimizer, scheduler, epoch):
    model.train()
    train_loss = 0.0
    
    for batch_idx, data in enumerate(dataloader):
        data = data.long().transpose(1, 0).contiguous()
        data = data.to(device)
        optimizer.zero_grad()
        input, target = data[:-1, :], data[1:, :]
        output, _ = model(input)
        loss = criterion(output, target.view(-1))
        loss.backward()  
        optimizer.step()
        train_loss += loss.item()
            
        if (batch_idx+1) % 200 == 0:
            print(f'train epoch: {epoch} [{batch_idx * len(data[1])}/{len(dataloader.dataset)}]\tloss: {loss.item()}')
            
    train_loss *= BATCH_SIZE
    train_loss /= len(train_dataloader.dataset)
    print(f'\ntrain epoch: {epoch}\t average loss: {train_loss}\n')
    scheduler.step()
    
    return train_loss

train_losses = []

for epoch in range(1,EPOCHS+1):
    tr_loss = train(model,train_dataloader,ix2word,word2ix,DEVICE,optimizer,scheduler,epoch)
    train_losses.append(tr_loss)
    
# save
filename = "./model_" + str(time.time()) + ".pth"
torch.save(model.state_dict(), filename)


train epoch: 1 [3184/57580 ]	loss: 2.318631649017334
train epoch: 1 [6384/57580 ]	loss: 3.397636651992798
train epoch: 1 [9584/57580 ]	loss: 2.399000644683838
train epoch: 1 [12784/57580 ]	loss: 3.0484793186187744
train epoch: 1 [15984/57580 ]	loss: 2.649400234222412
train epoch: 1 [19184/57580 ]	loss: 1.9118455648422241
train epoch: 1 [22384/57580 ]	loss: 3.2903225421905518
train epoch: 1 [25584/57580 ]	loss: 2.72925066947937
train epoch: 1 [28784/57580 ]	loss: 1.9411946535110474
train epoch: 1 [31984/57580 ]	loss: 1.8184380531311035
train epoch: 1 [35184/57580 ]	loss: 2.9487144947052
train epoch: 1 [38384/57580 ]	loss: 2.22786808013916
train epoch: 1 [41584/57580 ]	loss: 2.713787317276001
train epoch: 1 [44784/57580 ]	loss: 2.19191312789917
train epoch: 1 [47984/57580 ]	loss: 2.0682921409606934
train epoch: 1 [51184/57580 ]	loss: 2.4965813159942627
train epoch: 1 [54384/57580 ]	loss: 2.517153739929199

train epoch: 1	 average loss: 2.423557633084943

train epoch: 2 [3184/57580 ]	loss

# 2 TEST

In [28]:
def generate(model, start_words, ix2word, word2ix, max_gen_len, prefix_words=None):
    results = list(start_words)
    start_word_len = len(start_words)

    input = torch.Tensor([word2ix['<START>']]).view(1, 1).long()
    input = input.to(DEVICE)
    hidden = None
    
    if prefix_words:
        for word in prefix_words:
            output, hidden = model(input, hidden)
            input = Variable(input.data.new([word2ix[word]])).view(1, 1)

    for i in range(max_gen_len):
        output, hidden = model(input, hidden)
  
        if i < start_word_len:
            w = results[i]
            input = input.data.new([word2ix[w]]).view(1, 1)
        else:
            top_index = output.data[0].topk(5)[1][2].item()
            w = ix2word[top_index]
            results.append(w)
            input = input.data.new([top_index]).view(1, 1)
        if w == '<EOP>':
            del results[-1]
            break
            
    return results

start_words = '此情可待成追忆'  
max_gen_len = 128        

prefix_words = None
results = generate(model, start_words, ix2word, word2ix, max_gen_len, prefix_words)
poetry = ''
for word in results:
    poetry += word
    if word == '。' or word == '!':
        poetry += '\n'
        
print(poetry)


此情可待成追忆。
一去不如今不见？一片青楼无处得，不如今夕长江曲？君子相逢无一物？君子无端心更足？今年无计是无情？一言无限不知谁，我不相见何由已，不如不及长沙路？君子无端心已死？一生相与知不同？我亦有心皆自保？我亦有心皆有限？我来何必是闲身，此时无限不堪游，
