# LSTM on chinese-poetry/ci.song

In [None]:
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import Dataset,DataLoader
from torch.nn import functional as F
import torch
import numpy as np
import json
import os


In [None]:
# Notebook-wide configuration.
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook can still be executed (slowly) on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
VOCAB_SIZE = 10000       # rows of the embedding table and width of the output layer
MAX_LEN = 200            # not referenced in the visible notebook cells
EMBEDDING_DIM = 100      # size of each token embedding vector
N_UNITS = 128            # hidden size of the LSTM
VALIDATION_SPLIT = 0.2   # not referenced in the visible notebook cells
SEED = 42                # not referenced in the visible notebook cells
LOAD_MODEL = False       # not referenced in the visible notebook cells
BATCH_SIZE = 64          # batch size handed to the DataLoader
EPOCHS = 100             # number of passes over the data in LSTMModel.train_model
# Directory that PoetryDataSet scans for *.json files by default.
JSON_DIR = 'C:/Users/Raphael/OneDrive/备份/桌面/大学生活和学习/教材/深度学习/Generative-Deep-Learning/datas/chinese-poetry/'

1. 加载数据

In [None]:
class PoetryDataSet(Dataset):
    """Character-level dataset built from a directory of JSON poem files.

    Every ``*.json`` file in ``json_dir`` is loaded, and each entry that has
    both a ``"rhythmic"`` and a ``"paragraphs"`` field is rendered as
    ``"<rhythmic>:<paragraphs joined together>"``. The texts are encoded
    character by character and right-padded with 0 to the longest text.

    Attributes:
        songci_data: every raw entry read from the JSON files.
        filtered_data: the rendered text of each usable entry.
        n_songci: number of rendered texts (the dataset length).
        max_len: length of the longest rendered text.
        chars: sorted list of the distinct characters in the corpus.
        c2i / i2c: character -> index and index -> character mappings.
        indexed_data: ``(n_songci, max_len)`` int64 tensor of encoded texts.
    """

    def __init__(self, json_dir = JSON_DIR):
        """Load and encode all poems found in ``json_dir``.

        Args:
            json_dir: directory scanned (non-recursively) for ``*.json`` files.

        Raises:
            ValueError: if no usable entry is found in ``json_dir``.
        """
        songci_data = []
        # Sorted so that sample order does not depend on the arbitrary order
        # in which os.listdir returns directory entries.
        for filename in sorted(os.listdir(json_dir)):
            if not filename.endswith('.json'):
                continue
            filepath = os.path.join(json_dir, filename)
            with open(filepath, 'r', encoding='utf-8') as json_data:
                songci_data.extend(json.load(json_data))
        self.songci_data = songci_data

        # Entries without the two required fields are skipped instead of
        # aborting the whole load with a KeyError/TypeError.
        self.filtered_data = [
            x["rhythmic"] + ':' + ''.join(x["paragraphs"])
            for x in songci_data
            if isinstance(x, dict) and "rhythmic" in x and "paragraphs" in x
        ]
        if not self.filtered_data:
            raise ValueError(
                f"no entries with 'rhythmic' and 'paragraphs' found in {json_dir!r}")
        self.n_songci = len(self.filtered_data)
        self.max_len = max(len(text) for text in self.filtered_data)

        # Build the character <-> index mappings.
        self.chars = sorted(set(''.join(self.filtered_data)))
        self.c2i = {c: i for i, c in enumerate(self.chars)}
        self.i2c = dict(enumerate(self.chars))

        # Encode every text into one pre-allocated, zero-padded int64 matrix.
        # NOTE(review): the pad value 0 is also the index of self.chars[0], so
        # padding cannot be told apart from that character. The mapping is left
        # unchanged so existing indices stay valid; consider reserving a
        # dedicated padding index.
        self.indexed_data = torch.zeros(
            (self.n_songci, self.max_len), dtype=torch.long)
        for row, text in enumerate(self.filtered_data):
            self.indexed_data[row, :len(text)] = torch.tensor(
                [self.c2i[c] for c in text], dtype=torch.long)

    def print_example(self):
        """Print one randomly chosen text followed by its encoded row."""
        index = np.random.randint(0, self.n_songci)
        print(self.filtered_data[index])
        print(self.indexed_data[index])

    def put_i2c(self):
        """Return the index -> character mapping."""
        return self.i2c

    def put_c2i(self):
        """Return the character -> index mapping."""
        return self.c2i

    def __len__(self):
        """Return the number of texts in the dataset."""
        return self.n_songci

    def __getitem__(self, idx):
        """Return the ``(input, target)`` pair for sample ``idx``.

        The target is the input shifted left by one position (next-character
        prediction). Tensors are returned on the CPU: moving each sample to the
        GPU here fails on machines without CUDA and is incompatible with
        DataLoader worker processes / pinned memory. The consumer is expected
        to move whole batches to its device.
        """
        return self.indexed_data[idx, :-1], self.indexed_data[idx, 1:]
        



In [None]:
# Build the dataset from the default directory (JSON_DIR) and print one
# randomly chosen text together with its index-encoded row as a sanity check.
poetrydataset = PoetryDataSet()
poetrydataset.print_example()

2. 模型

In [None]:
class LSTMModel(nn.Module):
    """Embedding -> LSTM -> linear language model over token indices.

    ``forward`` returns a probability distribution over the vocabulary for
    every position of the input; training uses the raw logits instead.
    """

    def __init__(self, vocab_size=None, embedding_dim=None, n_units=None):
        """Create the layers.

        Args:
            vocab_size: number of token indices; defaults to ``VOCAB_SIZE``.
            embedding_dim: embedding width; defaults to ``EMBEDDING_DIM``.
            n_units: LSTM hidden size; defaults to ``N_UNITS``.
        """
        super().__init__()
        # Fall back to the notebook-level hyper-parameters when not overridden.
        if vocab_size is None:
            vocab_size = VOCAB_SIZE
        if embedding_dim is None:
            embedding_dim = EMBEDDING_DIM
        if n_units is None:
            n_units = N_UNITS
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, n_units, batch_first=True)
        self.fc = nn.Linear(n_units, vocab_size)
        self.activation = nn.Softmax(dim=2)

    def _logits(self, inputs):
        """Return unnormalised scores of shape (batch, seq, vocab)."""
        x = self.embedding(inputs)
        x, _ = self.lstm(x)
        return self.fc(x)

    def forward(self, inputs):
        """Return per-position vocabulary probabilities for ``inputs``.

        Args:
            inputs: integer tensor of shape (batch, seq).

        Returns:
            Tensor of shape (batch, seq, vocab) whose last axis sums to 1.
        """
        return self.activation(self._logits(inputs))

    def train_model(self, poetrydataset, epochs=None):
        """Train with Adam (lr=0.001) and cross-entropy on next-token targets.

        Args:
            poetrydataset: sized iterable of ``(data, target)`` batches, e.g.
                a ``DataLoader``; both are integer tensors of shape
                (batch, seq).
            epochs: number of passes over the batches; defaults to ``EPOCHS``.
        """
        n_epochs = EPOCHS if epochs is None else epochs
        loss_function = nn.CrossEntropyLoss()
        optimizer = optim.Adam(self.parameters(), lr=0.001)
        # Follow the model's own placement rather than a global device.
        model_device = next(self.parameters()).device
        n_batches = len(poetrydataset)
        for epoch in range(n_epochs):
            for batch_idx, (data, target) in enumerate(poetrydataset):
                data, target = data.to(model_device), target.to(model_device)
                optimizer.zero_grad()
                # CrossEntropyLoss applies log-softmax itself, so it must be
                # fed raw logits; feeding softmax output squashes gradients.
                logits = self._logits(data)
                loss = loss_function(
                    logits.reshape(-1, logits.size(-1)), target.reshape(-1))
                loss.backward()
                optimizer.step()
                if batch_idx % 100 == 0:
                    print('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        epoch, batch_idx, n_batches,
                        100. * batch_idx / n_batches, loss.item()))

    def save(self):
        """Write the model parameters to ``model.pth``."""
        torch.save(self.state_dict(), 'model.pth')

    def load(self):
        """Load parameters from ``model.pth`` onto the model's current device."""
        model_device = next(self.parameters()).device
        self.load_state_dict(torch.load('model.pth', map_location=model_device))

    def generate(self, start_prompt, c2i, i2c, max_tokens = 200, temperature = 1):
        """Generate text by repeatedly sampling the next token.

        Args:
            start_prompt: seed text; it is tokenised character by character,
                and characters missing from ``c2i`` map to index 1.
            c2i: character -> index mapping.
            i2c: index -> character mapping.
            max_tokens: total sequence length (prompt included) at which
                generation stops.
            temperature: sampling temperature; larger values give more varied
                text, smaller values make the output more deterministic.

        Returns:
            A list with one dict per generated token, holding the text so far
            (``"prompt"``) and the distribution it was sampled from
            (``"word_probs"``). The final text is also printed.

        Raises:
            ValueError: if ``start_prompt`` is empty.
        """
        # The vocabulary is per character, so iterate over characters;
        # whitespace splitting would collapse an unspaced prompt into a single
        # unknown token.
        start_tokens = [c2i.get(ch, 1) for ch in start_prompt]
        if not start_tokens:
            raise ValueError("start_prompt must contain at least one character")
        model_device = next(self.parameters()).device
        info = []
        with torch.no_grad():
            while len(start_tokens) < max_tokens:
                x = torch.tensor([start_tokens], device=model_device)
                # Only the distribution at the last position is needed.
                last_probs = self(x)[0, -1].cpu().numpy()
                sample_token, probs = self.sample_from(last_probs, temperature)
                info.append({"prompt": start_prompt, "word_probs": probs})
                sample_token = int(sample_token)
                # The output layer may be wider than the real vocabulary;
                # indices without a character fall back to 0.
                if sample_token not in i2c:
                    sample_token = 0
                start_tokens.append(sample_token)
                start_prompt = start_prompt + " " + i2c[sample_token]
        print(f"\ngenerated text:\n{start_prompt}\n")
        return info

    def sample_from(self, probs, temperature):
        """Sample an index from ``probs`` after temperature scaling.

        Args:
            probs: 1-D array of probabilities.
            temperature: positive scaling factor; 1 leaves the distribution
                unchanged, values below 1 sharpen it.

        Returns:
            ``(index, scaled_probs)`` where ``scaled_probs`` is the float64
            distribution actually sampled from.

        Raises:
            ValueError: if ``temperature`` is not positive.
        """
        if temperature <= 0:
            raise ValueError("temperature must be positive")
        # float64 keeps the renormalised vector within np.random.choice's
        # sum-to-one tolerance.
        probs = np.asarray(probs, dtype=np.float64)
        # Scale in log space: probs ** (1 / temperature) underflows to all
        # zeros for small temperatures.
        with np.errstate(divide='ignore'):
            scaled = np.log(probs) / temperature
        scaled = scaled - np.max(scaled)
        probs = np.exp(scaled)
        probs = probs / np.sum(probs)
        return np.random.choice(len(probs), p=probs), probs


In [None]:
# Place the model on the configured device instead of hard-coding CUDA, so
# the model always lives where the notebook's `device` setting points.
lstm = LSTMModel().to(device)

# Serve the dataset in shuffled mini-batches of BATCH_SIZE samples.
dataloader = DataLoader(poetrydataset, batch_size=BATCH_SIZE, shuffle=True)

lstm.train_model(dataloader)

# Persist the trained parameters (written to 'model.pth' by LSTMModel.save).
lstm.save()



In [None]:
# Generate text from a "<rhythmic>:<opening line>" prompt using the dataset's
# vocabulary mappings; 100 is passed as max_tokens and 0.01 as temperature.
lstm.generate('水调歌头:明月几时有，把酒问青天',poetrydataset.put_c2i(),poetrydataset.put_i2c(),100,0.01)