In [2]:
import numpy as np

import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import DataLoader,Dataset
import torch
import torch.optim as optim
import os
import time

In [3]:
# Run-level settings read by the training cell further down.
EPOCHES = 100  # how many passes the training loop makes over the data loader
LEARNING_RATE = 1e-3  # learning rate given to the Adam optimizer
RES_FILE = "result.txt"  # text file that receives the training log

In [4]:
# Open the poem archive.
# NOTE(review): allow_pickle=True lets NumPy unpickle object arrays, which can
# run arbitrary code — only load this file from a trusted source.
tang_file = np.load("tang.npz",allow_pickle=True)
# Show the names of the arrays stored in the archive.
tang_file.files

['ix2word', 'word2ix', 'data']

In [5]:
# Token-id array; later cells iterate over it one poem per row.
data = tang_file['data']
# .item() unwraps the stored object into a plain Python mapping
# (presumably a 0-d object array holding a dict — TODO confirm).
word2ix = tang_file['word2ix'].item()  # token string -> id (used as word2ix["</s>"])
idx2word = tang_file['ix2word'].item()  # id -> token string (used by idx2poem)

In [6]:
# Display one raw row to inspect how it is padded.
data[314]

array([8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292,
       8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292,
       8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292,
       8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292,
       8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292,
       8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292,
       8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292,
       8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292,
       8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292, 8292,
       8291, 3198, 5409, 7440, 7686, 3655, 7066, 2354, 1059, 3015, 2417,
       8004, 7435, 8118, 4662, 2260, 1252, 8150, 7066, 2110, 2110, 7477,
       8173, 1418, 7435, 8290], dtype=int32)

In [7]:
def idx2poem(idx_poem):
    """Decode a sequence of token ids into a single string.

    Every id is looked up in the module-level ``idx2word`` mapping and the
    resulting tokens are concatenated without a separator. An id missing from
    the mapping raises ``KeyError``.
    """
    return "".join(idx2word[token_id] for token_id in idx_poem)

print(idx2poem(data[314]))

# Id of the "</s>" token, which appears as left padding in the raw rows
# (see the data[314] output above).
pad_id = word2ix["</s>"]
print(pad_id)


def _strip_leading_padding(token_ids, pad_id):
    """Return ``token_ids`` without its leading run of ``pad_id`` tokens.

    Only padding at the front is removed; any ``pad_id`` occurring after the
    first real token is kept. A row that is empty or consists solely of padding
    yields an empty slice instead of leaking a stray pad token or reusing a
    stale index from a previous row.
    """
    non_pad_positions = np.flatnonzero(np.asarray(token_ids) != pad_id)
    first_real = non_pad_positions[0] if non_pad_positions.size else len(token_ids)
    return token_ids[first_real:]


# One variable-length id sequence per poem, with the left padding removed.
poems = [_strip_leading_padding(poem, pad_id) for poem in data]

# Spot-check the first poem (ids and decoded text) and the vocabulary size.
poems[0],idx2poem(poems[0]),len(idx2word)

</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s><START>南楼夜已寂，暗鸟动林间。不见城郭事，沈沈唯四山。<EOP>
8292


(array([8291, 6731, 4770, 1787, 8118, 7577, 7066, 4817,  648, 7121, 1542,
        6483, 7435, 7686, 2889, 1671, 5862, 1949, 7066, 2596, 4785, 3629,
        1379, 2703, 7435, 6064, 6041, 4666, 4038, 4881, 7066, 4747, 1534,
          70, 3788, 3823, 7435, 4907, 5567,  201, 2834, 1519, 7066,  782,
         782, 2063, 2031,  846, 7435, 8290], dtype=int32),
 '<START>度门能不访，冒雪屡西东。已想人如玉，遥怜马似骢。乍迷金谷路，稍变上阳宫。还比相思意，纷纷正满空。<EOP>',
 8293)

In [8]:
seq_len = 48

# Concatenate all poems into one long stream of token ids.
poems_data = [token for poem_ids in poems for token in poem_ids]

# Cut the stream into non-overlapping windows of seq_len tokens. Each target
# window is its input window shifted one token to the right, so the range stops
# early enough for that extra trailing token to exist.
window_starts = range(0, len(poems_data) - seq_len - 1, seq_len)
X = [poems_data[start:start + seq_len] for start in window_starts]
Y = [poems_data[start + 1:start + seq_len + 1] for start in window_starts]

In [9]:
class PoemDataset(Dataset):
    """Map-style dataset of (input window, target window) token-id pairs."""

    def __init__(self,X,Y):
        """Keep references to the parallel input and target sequences.

        X: indexable collection of input id sequences.
        Y: indexable collection of target id sequences, aligned with X.
        """
        self.X = X
        self.Y = Y
        self.len = len(X)

    def __getitem__(self,index):
        """Return the pair at ``index`` as two int64 tensors."""
        # Read from the instance rather than from notebook-level X/Y, so the
        # dataset serves exactly the data it was constructed with.
        x = np.array(self.X[index])
        y = np.array(self.Y[index])
        return torch.from_numpy(x).long(),torch.from_numpy(y).long()

    def __len__(self):
        """Number of pairs, taken from the length of X at construction time."""
        return self.len
        
# Batches of 512 windows, fetched by 2 worker processes; no shuffle argument is
# given, so the loader keeps its default ordering.
data_loader = DataLoader(PoemDataset(X,Y),batch_size=512,num_workers=2)

In [10]:
# Pull a single batch to sanity-check the tensor shapes.
a,b = next(iter(data_loader))
# Both tensors are (batch_size, seq_len)
a.shape,b.shape

(torch.Size([512, 48]), torch.Size([512, 48]))

In [13]:
class PoemNet(nn.Module):
    """Next-token model: token embedding -> single-layer LSTM -> MLP over the vocabulary."""

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        """
            vocab_size: size of the training vocabulary (8293 in this notebook)
            embedding_dim: dimension of each token embedding vector
            hidden_dim: hidden size of the LSTM
        """
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)

        # Output head: widen the LSTM features to 2x, then 4x, then project to
        # one score per vocabulary entry.
        widened = hidden_dim * 2
        widest = hidden_dim * 4
        head_layers = [
            nn.Linear(hidden_dim, widened),
            nn.ReLU(),
            nn.Dropout(p=0.25),
            nn.Linear(widened, widest),
            nn.Dropout(p=0.2),
            nn.ReLU(),
            nn.Linear(widest, vocab_size),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, input, hidden=None):
        """
            input: batch of token-id sequences, batch dimension first
            hidden: optional (h_0, c_0) initial LSTM state, useful when
            generating text step by step; when it is None the LSTM starts
            from an all-zero state

            Returns the per-position vocabulary scores and the final
            (h_n, c_n) state produced by the LSTM.
        """
        embedded = self.embeddings(input)
        lstm_out, hidden = self.lstm(embedded, hidden)
        logits = self.fc(lstm_out)  # (batch, seq_len, vocab_size)
        return logits, hidden

# Model dimensions.
vocab_size = len(word2ix) # 8293
embedding_dim = 200
hidden_dim = 1024

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
my_net = PoemNet(vocab_size,embedding_dim,hidden_dim).to(device)

# NOTE: the name `optimzer` (sic) is what the training loop refers to, so the
# spelling is kept.
optimzer = optim.Adam(my_net.parameters(),lr=LEARNING_RATE)
loss_function = nn.CrossEntropyLoss()

# Start a fresh log file with a header describing this run.
with open(RES_FILE, mode='w') as f:
    # The header labels this field as the training date, so write a readable
    # local timestamp instead of raw seconds since the epoch.
    started_at = time.strftime("%Y-%m-%d %H:%M:%S")
    f.write(f'训练日期:{started_at}\n训练轮数{EPOCHES}\n学习率:{LEARNING_RATE}')

# Running sums of the per-epoch averages, used for the "so far" statistics.
total_loss = 0
total_acc = 0

for epoch in range(EPOCHES):
    my_net.train()
    epoch_loss = 0.0   # sum of per-token losses over the epoch
    epoch_acc = 0      # number of correctly predicted tokens over the epoch
    samples_count = 0  # number of target tokens seen over the epoch

    # Unpack each batch directly so the loop does not rebind the notebook-level
    # `data` array that earlier cells index into.
    for i, (inputs, targets) in enumerate(data_loader):
        inputs, targets = inputs.to(device), targets.to(device)

        optimzer.zero_grad()
        # No initial state is passed, so every batch starts from a zero LSTM
        # state; the returned state is therefore not needed here.
        output, _ = my_net(inputs)  # output: (batch, seq_len, vocab_size)

        # Flatten to one row of scores per target token for the loss.
        logits = output.view(-1, vocab_size)
        target = targets.view(-1)
        token_count = target.numel()
        samples_count += token_count

        loss = loss_function(logits, target)
        loss.backward()
        optimzer.step()

        # loss is the mean over the batch's tokens; scale it back to a sum so
        # the epoch average is weighted correctly when batch sizes differ.
        batch_mean_loss = loss.item()
        batch_correct = (torch.argmax(logits, dim=-1) == target).sum().item()

        epoch_loss += batch_mean_loss * token_count
        epoch_acc += batch_correct

        # Report progress every 10 batches.
        if (i+1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{EPOCHES}], Batch [{i+1}/{len(data_loader)}], ' 
                  f'Loss: {batch_mean_loss:.4f}, Acc: {batch_correct/token_count:.4f}')
    
    # Per-token averages for this epoch.
    epoch_loss /= samples_count
    epoch_acc /= samples_count
    total_loss += epoch_loss
    total_acc += epoch_acc
    
    # Print the epoch summary.
    print(f'Epoch {epoch+1}/{EPOCHES} - Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}')
    print(f'Avg loss so far: {total_loss/(epoch+1):.4f}, Avg accuracy so far: {total_acc/(epoch+1):.4f}')
    
    # Append the same summary to the log file.
    with open(RES_FILE, mode='a') as f:
        f.write(f'\nEpoch {epoch+1}/{EPOCHES}\n')
        f.write(f'Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}\n')
        f.write(f'Avg loss: {total_loss/(epoch+1):.4f}, Avg accuracy: {total_acc/(epoch+1):.4f}\n')

# Save the trained model. This pickles the whole module object, so loading it
# later requires the PoemNet class to be available.
# NOTE(review): saving my_net.state_dict() is the more portable convention;
# left unchanged in case other cells load "model.h5" as a full module.
torch.save(my_net, "model.h5")

Epoch [1/100], Batch [10/128], Loss: 7.2928, Acc: 0.0355
Epoch [1/100], Batch [20/128], Loss: 7.2397, Acc: 0.1263
Epoch [1/100], Batch [20/128], Loss: 7.2397, Acc: 0.1263
Epoch [1/100], Batch [30/128], Loss: 6.4849, Acc: 0.1124
Epoch [1/100], Batch [30/128], Loss: 6.4849, Acc: 0.1124
Epoch [1/100], Batch [40/128], Loss: 6.4905, Acc: 0.1202
Epoch [1/100], Batch [40/128], Loss: 6.4905, Acc: 0.1202
Epoch [1/100], Batch [50/128], Loss: 6.2891, Acc: 0.1582
Epoch [1/100], Batch [50/128], Loss: 6.2891, Acc: 0.1582
Epoch [1/100], Batch [60/128], Loss: 6.3361, Acc: 0.1746
Epoch [1/100], Batch [60/128], Loss: 6.3361, Acc: 0.1746
Epoch [1/100], Batch [70/128], Loss: 6.2342, Acc: 0.1696
Epoch [1/100], Batch [70/128], Loss: 6.2342, Acc: 0.1696
Epoch [1/100], Batch [80/128], Loss: 6.2127, Acc: 0.1646
Epoch [1/100], Batch [80/128], Loss: 6.2127, Acc: 0.1646
Epoch [1/100], Batch [90/128], Loss: 6.2859, Acc: 0.1696
Epoch [1/100], Batch [90/128], Loss: 6.2859, Acc: 0.1696
Epoch [1/100], Batch [100/128],

KeyboardInterrupt: 