<center>
<h1>Classical Chinese Poetry Generation with PyTorch + LSTM</h1>
</center>

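### Assignment Overview:
This assignment uses the PyTorch framework for an NLP task: classical Chinese poetry generation. The model is an LSTM, and the character vectors are trained with Word2Vec. Both random poems and acrostic poems (藏头诗, where the first characters of the lines spell out a given phrase) are supported, and the randomly generated poems differ from run to run.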

<br>

### Imports:

In [1]:
import os
import numpy as np
import pickle
import torch
import torch.nn as nn
from gensim.models.word2vec import Word2Vec
from torch.utils.data import Dataset, DataLoader



<br>

### Generating the Split File:

In [2]:
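def split_text(file="poetry_7.txt", train_num=6000):
    # Read the original poems (one poem per line)
    with open(file, "r", encoding="utf-8") as f:
        lines = f.read().split("\n")
    # Put a space between the characters of each line so every character is one token
    split_lines = [" ".join(line) for line in lines]
    with open("split_7.txt", "w", encoding="utf-8") as f:
        f.write("\n".join(split_lines))
    # Return only the first train_num poems
    return "\n".join(split_lines[:train_num])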
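Since Word2Vec is given one token list per poem, an alternative to writing and re-reading a space-separated file is to split each line into characters in memory. The sketch below is pure Python; the helper names `to_char_tokens` and `build_corpus` are hypothetical, and the toy input stands in for `poetry_7.txt`. Its result has the list-of-token-lists shape that `Word2Vec(...)` expects as its first argument.

```python
from typing import List


def to_char_tokens(line: str) -> List[str]:
    """Split one line into single-character tokens, dropping any whitespace."""
    return [ch for ch in line if not ch.isspace()]


def build_corpus(text: str, max_poems: int) -> List[List[str]]:
    """Turn file text (one poem per line) into Word2Vec-style sentences."""
    sentences = [to_char_tokens(line) for line in text.split("\n")]
    # Drop empty lines, e.g. the one produced by a trailing newline
    return [s for s in sentences if s][:max_poems]


corpus = build_corpus("床前明月光，\n疑是地上霜。\n", max_poems=10)
print(corpus[0])  # ['床', '前', '明', '月', '光', '，']
```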

<br>

### Training Word Vectors:

In [3]:
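def train_vec(split_file="split_7.txt", org_file="poetry_7.txt", train_num=6000):
    param_file = "word_vec.pkl"
    # Original poems, one per line; skip empty lines (e.g. a trailing newline)
    with open(org_file, "r", encoding="utf-8") as f:
        org_data = [line for line in f.read().split("\n") if line.strip()][:train_num]
    # Reuse cached vectors if they exist (delete word_vec.pkl after changing train_num)
    if os.path.exists(param_file):
        with open(param_file, "rb") as f:
            return org_data, pickle.load(f)

    if os.path.exists(split_file):
        with open(split_file, "r", encoding="utf-8") as f:
            split_lines = f.read().split("\n")
    else:
        split_lines = split_text().split("\n")
    # Word2Vec expects a list of token lists: split each line on spaces into characters
    all_data_split = [line.split() for line in split_lines if line.strip()][:train_num]

    models = Word2Vec(all_data_split, vector_size=128, workers=7, min_count=1)
    # Row i of models.wv.vectors is the learned vector of models.wv.index_to_key[i]
    params = (models.wv.vectors, models.wv.key_to_index, models.wv.index_to_key)
    with open(param_file, "wb") as f:
        pickle.dump(params, f)
    return org_data, params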
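`word_2_index` and `index_2_word` are inverse mappings, and row i of the vector matrix `w1` belongs to `index_2_word[i]`; the Dataset and both generators rely on this alignment. The pure-Python sketch below shows one way such mappings can be built. It is an illustration, not gensim's internal code, although gensim's default `sorted_vocab=1` also orders the vocabulary by descending frequency. The names `build_vocab`, `encode` and `decode` are hypothetical.

```python
from collections import Counter
from typing import Dict, List, Sequence, Tuple


def build_vocab(sentences: Sequence[Sequence[str]]) -> Tuple[Dict[str, int], List[str]]:
    """Build key_to_index / index_to_key, most frequent token first."""
    counts = Counter(tok for sent in sentences for tok in sent)
    # sorted() is stable, so tokens with equal counts keep their first-seen order
    index_to_key = [tok for tok, _ in sorted(counts.items(), key=lambda kv: -kv[1])]
    key_to_index = {tok: i for i, tok in enumerate(index_to_key)}
    return key_to_index, index_to_key


def encode(text: str, key_to_index: Dict[str, int]) -> List[int]:
    """Map each character to its vocabulary index."""
    return [key_to_index[ch] for ch in text]


def decode(ids: Sequence[int], index_to_key: Sequence[str]) -> str:
    """Map indices back to characters."""
    return "".join(index_to_key[i] for i in ids)


key_to_index, index_to_key = build_vocab([["春", "风", "春"], ["风", "雨"]])
print(index_to_key)                 # ['春', '风', '雨']
print(encode("雨春", key_to_index))  # [2, 0]
```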

<br>

### Building the Dataset:

In [4]:
class Poetry_Dataset(Dataset):
    def __init__(self, w1, word_2_index, all_data):
        self.w1 = w1
        self.word_2_index = word_2_index
        self.all_data = all_data

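    def __getitem__(self, index):
        a_poetry = self.all_data[index]

        # Map every character of the poem to its vocabulary index
        a_poetry_index = [self.word_2_index[i] for i in a_poetry]
        # Next-character prediction: inputs are characters 0..n-2, targets are 1..n-1
        xs = a_poetry_index[:-1]
        ys = a_poetry_index[1:]
        # Look up the pre-trained word vectors for the inputs (they are not fine-tuned)
        xs_embedding = self.w1[xs]

        return xs_embedding, np.array(ys).astype(np.int64)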

    def __len__(self):
        return len(self.all_data)
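Each poem therefore yields one next-character prediction pair: a poem of n characters gives n − 1 input/target positions (31 if, as the generators assume, every line of `poetry_7.txt` is a 32-character seven-character quatrain). Because the default `DataLoader` collate function stacks the samples of a batch into one tensor, every poem must have the same length. A pure-Python sketch with hypothetical helpers `make_training_pair` and `keep_fixed_length`:

```python
from typing import Dict, List, Sequence, Tuple


def make_training_pair(poem: str, word_2_index: Dict[str, int]) -> Tuple[List[int], List[int]]:
    """Inputs are characters 0..n-2, targets are characters 1..n-1."""
    ids = [word_2_index[ch] for ch in poem]
    return ids[:-1], ids[1:]


def keep_fixed_length(poems: Sequence[str], length: int) -> List[str]:
    """Keep only poems of the given length so they can be stacked into one batch."""
    return [p for p in poems if len(p) == length]


w2i = {"春": 0, "风": 1, "雨": 2}
xs, ys = make_training_pair("春风春雨", w2i)
print(xs, ys)  # [0, 1, 0] [1, 0, 2]
```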

<br>

### Model Definition:

In [5]:
class Poetry_Model_lstm(nn.Module):
    def __init__(self, hidden_num, word_size, embedding_num):
        super().__init__()

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
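        ###### Model definition (one possible implementation, assumed) ######
        self.hidden_num = hidden_num
        self.num_layers = 2  # must match the first dimension of h_0 / c_0
        # batch_first=True: inputs have shape (batch, seq_len, embedding_num)
        self.lstm = nn.LSTM(input_size=embedding_num, hidden_size=hidden_num,
                            num_layers=self.num_layers, batch_first=True)
        self.dropout = nn.Dropout(0.3)
        # Merge batch and time: (batch, seq_len, hidden) -> (batch * seq_len, hidden)
        self.flatten = nn.Flatten(0, 1)
        # Map every hidden state to one score (logit) per vocabulary word
        self.linear = nn.Linear(hidden_num, word_size)

    def forward(self, xs_embedding, h_0=None, c_0=None):
        # Default to a zero initial state of shape (num_layers, batch, hidden_num)
        if h_0 is None or c_0 is None:
            h_0 = torch.zeros(self.num_layers, xs_embedding.shape[0], self.hidden_num)
            c_0 = torch.zeros(self.num_layers, xs_embedding.shape[0], self.hidden_num)
        h_0 = h_0.to(self.device)
        c_0 = c_0.to(self.device)
        xs_embedding = xs_embedding.to(self.device)
        ###### Forward pass (one possible implementation, assumed) ######
        hidden, (h_0, c_0) = self.lstm(xs_embedding, (h_0, c_0))
        hidden = self.dropout(hidden)
        # pre: (batch * seq_len, word_size) logits for the next character
        pre = self.linear(self.flatten(hidden))
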
        return pre, (h_0, c_0)
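For reference, each LSTM layer applies the recurrence below at every time step t, written in the gate notation of the `torch.nn.LSTM` documentation. Here x_t is the Word2Vec vector of character w_t (for the second layer, x_t is the first layer's h_t), and `h_0`, `c_0` in the code are the initial states with shape (2, batch, hidden_num). The last line assumes the hidden state is projected to vocabulary logits by a linear layer followed by softmax; W_o and b_o are illustrative names for that layer's parameters.

$$
\begin{aligned}
i_t &= \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
f_t &= \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
g_t &= \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
o_t &= \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t) \\
p(w_{t+1} \mid w_{\le t}) &= \mathrm{softmax}(W_o h_t + b_o)
\end{aligned}
$$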

<br>

### Generating a Random Poem:


In [6]:
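def generate_poetry_auto():
    # Uses the globals word_size, index_2_word, w1, hidden_num and model from the main cell
    model.eval()  # disable dropout while generating
    result = ""
    # Random first character; in eval mode this is the only source of variation
    word_index = np.random.randint(0, word_size)
    result += index_2_word[word_index]
    # Zero initial state: (num_layers=2, batch=1, hidden_num)
    h_0 = torch.zeros(2, 1, hidden_num)
    c_0 = torch.zeros(2, 1, hidden_num)

    with torch.no_grad():
        # 31 more characters give a 32-character seven-character quatrain
        for i in range(31):
            # Shape (1, 1, embedding_num): one sequence containing one character
            word_embedding = torch.tensor(w1[word_index][None][None])
            pre, (h_0, c_0) = model(word_embedding, h_0, c_0)
            # Greedy decoding: take the most likely next character
            word_index = int(torch.argmax(pre))
            result += index_2_word[word_index]

    return result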
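The generator above decodes greedily with `argmax`, so after the first character the poem is fully determined. A commonly used alternative (swapped in here, not part of the original code) is temperature sampling: draw the next index from softmax(logits / T). The function name `sample_with_temperature` is hypothetical; this is a minimal pure-Python sketch.

```python
import math
import random
from typing import Optional, Sequence


def sample_with_temperature(logits: Sequence[float], temperature: float = 1.0,
                            rng: Optional[random.Random] = None) -> int:
    """Draw an index from softmax(logits / temperature)."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    if rng is None:
        rng = random.Random()
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    # Subtracting the max keeps exp() from overflowing; choices() normalizes the weights
    weights = [math.exp(x - m) for x in scaled]
    return rng.choices(range(len(weights)), weights=weights, k=1)[0]


print(sample_with_temperature([0.0, 10.0, 0.0], temperature=0.01, rng=random.Random(0)))  # 1
```

Assuming `pre` has shape (1, word_size) for a single step, `int(torch.argmax(pre))` could be replaced by something like `sample_with_temperature(pre[0].tolist(), 0.8)`, where 0.8 is an arbitrary example value. Lower temperatures approach greedy decoding; higher ones give more varied but less coherent lines.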


<br>

### Generating an Acrostic Poem:

In [11]:
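def generate_poetry_acrostic():
    # Uses the globals word_2_index, index_2_word, w1, hidden_num and model from the main cell
    input_text = input("Enter four Chinese characters: ")[:4]
    if len(input_text) < 4:
        raise ValueError("Please enter four characters")
    unknown = [ch for ch in input_text if ch not in word_2_index]
    if unknown:
        raise ValueError(f"Not in the vocabulary: {''.join(unknown)}")
    model.eval()  # disable dropout while generating
    result = ""
    punctuation_list = ["，", "。", "，", "。"]
    with torch.no_grad():
        for i in range(4):
            # Each line starts with the i-th input character
            result += input_text[i]
            # Reset the state so every line is generated independently
            h_0 = torch.zeros(2, 1, hidden_num)
            c_0 = torch.zeros(2, 1, hidden_num)
            word = input_text[i]
            # Six more characters complete a seven-character line
            for j in range(6):
                word_index = word_2_index[word]
                word_embedding = torch.tensor(w1[word_index][None][None])
                pre, (h_0, c_0) = model(word_embedding, h_0, c_0)
                word = index_2_word[int(torch.argmax(pre))]
                result += word
            result += punctuation_list[i]

    return result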
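Both generators target a seven-character quatrain: `generate_poetry_auto` emits 1 + 31 = 32 characters, and the acrostic emits 4 × (1 + 6) characters plus four punctuation marks, also 32. A small structural check makes it easy to tell well-formed output from early-training output such as the punctuation-heavy lines in the log. The sketch below is pure Python; `is_qijue` and `heads` are hypothetical names.

```python
PUNCT = {"，", "。"}


def is_qijue(poem: str) -> bool:
    """True if poem is 4 lines of 7 characters, each followed by ， or 。."""
    if len(poem) != 32:
        return False
    for line_no in range(4):
        start = line_no * 8
        body, mark = poem[start:start + 7], poem[start + 7]
        if any(ch in PUNCT for ch in body) or mark not in PUNCT:
            return False
    return True


def heads(poem: str) -> str:
    """First character of each line; for an acrostic this should equal the input."""
    return "".join(poem[i * 8] for i in range(4))


sample = "一二三四五六七，八九十百千万亿。" * 2
print(is_qijue(sample), heads(sample))  # True 一八一八
```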

<br>

### Main: Define Hyperparameters, Model, Optimizer, and Train

In [12]:
if __name__ == "__main__":

    all_data, (w1, word_2_index, index_2_word) = train_vec(train_num=300)

    batch_size = 32
    epochs = 100
    lr = 0.01
    hidden_num = 128
    word_size, embedding_num = w1.shape

    dataset = Poetry_Dataset(w1, word_2_index, all_data)
    dataloader = DataLoader(dataset, batch_size)

    model = Poetry_Model_lstm(hidden_num, word_size, embedding_num)
    model = model.to(model.device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    for e in range(epochs):
        for batch_index, (batch_x_embedding, batch_y_index) in enumerate(dataloader):
            model.train()
            batch_x_embedding = batch_x_embedding.to(model.device)
            batch_y_index = batch_y_index.to(model.device)

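            # Forward pass: pre has shape (batch * seq_len, word_size)
            pre, _ = model(batch_x_embedding)
            # Cross-entropy loss; flatten the (batch, seq_len) targets to match pre
            loss = nn.functional.cross_entropy(pre, batch_y_index.reshape(-1))

            # Backpropagation: compute gradients and accumulate them in .grad; parameters are not changed yet
            loss.backward()
            # The optimizer updates the parameters using the accumulated gradients
            optimizer.step()
            # Clear the gradients so they do not accumulate into the next batch
            optimizer.zero_grad()
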

            if batch_index % 100 == 0:
                # model.eval()
                print(f"loss:{loss:.3f}")
                print(generate_poetry_auto())

loss:7.710
季火载火载火。。。火。。。载。。。。。。，。。。。。。，，。。。
loss:6.969
窅，。，。，，。。，。。，，。。，。。，，，，，，，，，。，，，
loss:6.843
启，。，。。。，，。。，。。，，，。，。。。。。。。。。。。。，
loss:6.745
奇，，，，，，，，，，，，，，。，。，，。，，，，，。，，，，，
loss:6.695
大山山山，山，，，，，，，，。。，，，，，，。，，，。，。，，。
loss:6.598
译山三风，日山，三光，三三，山，三三，山，山山，三三，山三，日山
loss:6.529
复台山三，，山，，山，，海，，海，，，山，山，，海，海，，，海，
loss:6.372
椰人山光海海，，一不，海，海海，，海，海，海，天，海海，，，海海
loss:6.286
洵色风三海海，，一一，一一，，，一，，一，。，，天，，，一，一，
loss:6.174
端花三光海天，一山不一天，一，一一。一天。一。山天山，一。一一天
loss:6.027
闹山风海生生，，山不天无天。。山不山不。无。一天无无。一。一一天
loss:5.996
柁门烟路六三天，一来一山无公。。今一一花无水。一来一山一。。一一
loss:5.859
摇阙风山海斗，，海山一山不。。海山一一不。。海来一山不天。一里一
loss:5.800
孙有无路海远，，海风一来不。。一山烟花不中，。来烟山一水。。海春
loss:5.723
脱来风气都水，，一天烟山一天。。海山烟山无来。。海风风一天青。一
loss:5.658
锐花梅结六天天，一山一花一战。。一山一影一染花。一山一风一天，，
loss:5.610
金门风光六天有，金来编蕉一天香。一来一山一天。，山天风日一中，一
loss:5.552
经色风路一氏迈，一山一望一色吟。一是一风不人。，山风人一天花。一
loss:5.517
然秋若气斗天席，一山烟香一公。。海烟山一人花。一来不山一时花。一
loss:5.493
心人梅路倏有，海山不溪不无深。一知一风一水。，海不人回无心。一是
loss:5.457
荡榜下路斗生缪，一锁一山不超端。一来一人一无新。一里一山一无花。
loss:5.415
廨人三桑暗天天，一章如咏一水香。海须一风一天花。海来一望不无花。
loss:5.402
荻色高光斗水，，山不山旧相端。一来一水一天

TypeError: generate_poetry_acrostic() missing 1 required positional argument: 'self'
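The logged loss is a mean cross-entropy over all batch × time positions, which is why the targets are flattened to match the rows of `pre`. A uniform guess over V characters has loss ln V, and exp(loss) (the perplexity) can be read loosely as the number of characters the model is still choosing between: about 2.23 × 10³ for the first logged loss of 7.710 and about 222 for the last one, 5.402. A pure-Python sketch of the computation (the names `flatten` and `cross_entropy` are hypothetical, and this is not PyTorch's implementation):

```python
import math
from typing import List, Sequence


def flatten(batch: Sequence[Sequence[Sequence[float]]]) -> List[Sequence[float]]:
    """(batch, seq_len, vocab) -> (batch * seq_len, vocab), like Flatten(0, 1)."""
    return [step for seq in batch for step in seq]


def cross_entropy(logit_rows: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Mean negative log-softmax probability of the target index in each row."""
    total = 0.0
    for row, y in zip(logit_rows, targets):
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))  # stable log-sum-exp
        total += log_z - row[y]
    return total / len(targets)


# Uniform logits over a vocabulary of size 4 give a loss of ln(4)
uniform = cross_entropy([[0.0] * 4] * 3, [0, 1, 2])
print(uniform, math.log(4))
# Perplexity for the first and last logged losses
print(math.exp(7.710), math.exp(5.402))
```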

In [None]:
Describe your model: