In [1]:
from collections import Counter
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

np.random.seed(100)
torch.manual_seed(100)

<torch._C.Generator at 0x1151fa430>

In [177]:
raw_data_path = 'data/all_couplets.txt'
vocabs_size = 2000

### 划分数据集

In [28]:
def split_dataset(raw_data_path, test_size=3000):
    with open(raw_data_path, 'r') as f:
        lines = f.readlines()
        
    lines = list(map(str.strip, lines))
    
    np.random.shuffle(lines)
    
    train_lines = lines[test_size:]
    test_lines = lines[:test_size]
        
    return train_lines, test_lines

In [29]:
train_lines, test_lines = split_dataset(raw_data_path, test_size=3000)

len(train_lines), len(test_lines)

(771491, 3000)

### 获取字符表

In [30]:
def create_vocabs(train_lines, size=-1):
    counter = Counter(''.join(train_lines))
    vocabs = sorted(counter, key=lambda c: counter[c], reverse=True)
    
    if size != -1:
        vocabs = vocabs[:size]
    
    print(f"last character: {vocabs[-1]}, frequency: {counter[vocabs[-1]]}")
        
    return vocabs

In [31]:
vocabs = create_vocabs(train_lines, size=vocabs_size)

last character: 辣, frequency: 770


In [144]:
def create_index_char(vocabs):
    chars = vocabs.copy()
    chars.insert(0, 'unk')

    return dict(zip(range(2, len(chars) + 2), chars)), dict(zip(chars, range(2, len(chars) + 2)))

In [145]:
index2char, char2index = create_index_char(vocabs)

### 创建数据集

In [150]:
class Couplets_dataset(Dataset):
    def __init__(self, lines, char2index, min_len=10, max_len=20):
        index_list = []
        
        for line in lines:
            if len(line) < min_len:
                continue
                
            stop_char_index = line.index('。')
            if stop_char_index > max_len - 1:
                continue
            
            indexs = [char2index.get(c, 2) for c in line]
            
            padding = max_len - len(indexs)
            if padding > 0:
                indexs += [1] * padding
                
            index_list.append(indexs)
            
        self.data = torch.tensor(index_list)
        
    def __getitem__(self, index):
        y = self.data[index]
        x = torch.cat([torch.tensor([0]), y[:-1]])
        return x, y
    
    def __len__(self):
        return self.data.size(0)

In [151]:
train_set = Couplets_dataset(train_lines, char2index, max_len=30)
test_set = Couplets_dataset(test_lines, char2index, max_len=30)

In [152]:
len(train_set), len(test_set)

(646808, 2496)

In [157]:
train_loader = DataLoader(train_set, batch_size=256, shuffle=True, num_workers=4)
test_loader = DataLoader(test_set, batch_size=256, shuffle=False, num_workers=4)

In [162]:
for X, Y in test_loader:
    print(X.shape)

torch.Size([256, 30])
torch.Size([256, 30])
torch.Size([256, 30])
torch.Size([256, 30])
torch.Size([256, 30])
torch.Size([256, 30])
torch.Size([256, 30])
torch.Size([256, 30])
torch.Size([256, 30])
torch.Size([192, 30])


### 创建LSTM网络

In [190]:
class Couplets_net(nn.Module):
    def __init__(self, vocabs_size, embedding_dim=100, hidden_dim=200, num_layers=2):
        super().__init__()
        
        self.hidden_dim = hidden_dim
        
        self.embedding = nn.Embedding(vocabs_size + 2, embedding_dim)
        
        self.lstm_cell_0 = nn.LSTMCell(embedding_dim, hidden_dim)
        self.fc_0 = nn.Linear(hidden_dim, hidden_dim)
        self.relu_0 = nn.ReLU(True)
        self.bn_0 = nn.BatchNorm1d(hidden_dim)
        self.lstm_cell_1 = nn.LSTMCell(hidden_dim, hidden_dim)
        self.fc_1 = nn.Linear(hidden_dim, vocabs_size + 2)
        self.softmax_1 = nn.Softmax(-1)

    def forward(self, X):   
        X = self.embedding(X)
        
        h_0 = torch.zeros(X.size(0), self.hidden_dim)
        c_0 = torch.zeros(X.size(0), self.hidden_dim)
        h_1 = torch.zeros(X.size(0), self.hidden_dim)
        c_1 = torch.zeros(X.size(0), self.hidden_dim)
        
        Y_out = []
        
        for i in range(X.size(1)):
            X_step = X[:, i]
                        
            h_0, c_0 = self.lstm_cell_0(X_step, (h_0, c_0))
            X_step = self.fc_0(h_0)
            X_step = self.relu_0(X_step)
            X_step = self.bn_0(X_step)
            
            h_1, c_1 = self.lstm_cell_1(X_step, (h_1, c_1))
            X_step = self.fc_1(h_1)
            X_step = self.softmax_1(X_step)
            
            Y_out.append(X_step)
            
        return torch.stack(Y_out).transpose(0, 1)

In [191]:
model = Couplets_net(vocabs_size)

In [194]:
X, Y = next(iter(train_loader))

In [195]:
model(X).size()

torch.Size([256, 30, 2002])