In [1]:
from collections import Counter
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Single seed for every RNG the notebook touches (NumPy shuffling, torch init).
SEED = 100
np.random.seed(SEED)
# Trailing ';' stops Jupyter from echoing the returned torch.Generator object.
torch.manual_seed(SEED);

<torch._C.Generator at 0x1151fa430>

In [126]:
# Corpus location and the number of most frequent characters kept in the vocabulary.
# NOTE(review): this cell ran as In [126], after the cells below; their recorded outputs
# (one-hot width 2001, last character '辣') came from vocabs_size=2000. Restart & Run All
# to refresh them — TODO confirm which size is intended.
raw_data_path, vocabs_size = 'data/all_couplets.txt', 200

### 划分数据集

In [28]:
def split_dataset(raw_data_path, test_size=3000, encoding='utf-8'):
    """Read a one-sample-per-line text file and split it into train/test lists.

    Every line is stripped of surrounding whitespace, then the whole list is
    shuffled with the global NumPy RNG (``np.random.shuffle``). The split is
    therefore reproducible only when ``np.random.seed`` has been set beforehand.

    Args:
        raw_data_path: path of the text file to read.
        test_size: number of shuffled lines that go to the test split; the
            remainder forms the training split.
        encoding: text encoding of the file. Stated explicitly so the
            (Chinese) corpus decodes the same way regardless of the platform's
            default locale encoding.

    Returns:
        ``(train_lines, test_lines)``, two lists of str.
    """
    with open(raw_data_path, 'r', encoding=encoding) as f:
        lines = [line.strip() for line in f]

    np.random.shuffle(lines)

    # The test set is the head of the shuffled list, the train set is the rest.
    return lines[test_size:], lines[:test_size]

In [29]:
train_lines, test_lines = split_dataset(raw_data_path=raw_data_path, test_size=3000)

(len(train_lines), len(test_lines))

(771491, 3000)

### 获取字符表

In [30]:
def create_vocabs(train_lines, size=-1):
    """Build a character vocabulary ordered from most to least frequent.

    Args:
        train_lines: iterable of str; every character of every line is counted.
        size: keep only the ``size`` most frequent characters. ``-1`` (the
            default) keeps all of them. Characters with equal counts keep the
            order in which they were first seen.

    Returns:
        List of characters, most frequent first.

    As a sanity check on where ``size`` cuts, the function prints the least
    frequent kept character and its count. Nothing is printed when the
    vocabulary is empty.
    """
    counter = Counter(''.join(train_lines))
    # most_common() with no argument is a stable descending sort by count.
    vocabs = [char for char, _ in counter.most_common()]

    if size != -1:
        vocabs = vocabs[:size]

    # An empty corpus or size=0 leaves nothing to report; vocabs[-1] would raise.
    if vocabs:
        print(f"last character: {vocabs[-1]}, frequency: {counter[vocabs[-1]]}")

    return vocabs

In [31]:
vocabs = create_vocabs(train_lines=train_lines, size=vocabs_size)

last character: 辣, frequency: 770


In [95]:
def create_index_char(vocabs):
    """Build the index <-> character lookup tables.

    The token ``'unk'`` is placed first and gets index 1. The characters in
    ``vocabs`` follow from index 2, in their given order. Index 0 is not
    assigned, which leaves it free for padding. The input list is not modified.

    Returns:
        ``(index2char, char2index)`` dictionaries.
    """
    tokens = ['unk'] + list(vocabs)
    index2char = {position: token for position, token in enumerate(tokens, start=1)}
    char2index = {token: position for position, token in index2char.items()}
    return index2char, char2index

In [96]:
index2char, char2index = create_index_char(vocabs=vocabs)

### 创建数据集

In [112]:
class Couplets_dataset(Dataset):
    """Fixed-width, one-hot encoded text lines for next-character prediction.

    Each kept line is mapped to integer indices through ``char2index``.
    Characters missing from the mapping get index 1 (the ``'unk'`` slot
    produced by ``create_index_char``). Each row is right-padded with index 0
    up to ``max_len``.

    ``__getitem__`` returns ``(x, y)``:
      * ``y`` is the one-hot encoding of the row with the padding column
        dropped, so padding positions are all-zero rows. Its shape is
        ``(max_len, len(char2index))``.
      * ``x`` is ``y`` shifted one step later in time, with an all-zero first
        row (teacher-forcing input).
    """

    def __init__(self, lines, char2index, min_len=10, max_len=20):
        """
        Args:
            lines: iterable of str, one sample per line.
            char2index: mapping char -> int. The indices are assumed to start
                at 1, because 0 is used for padding.
            min_len: lines shorter than this are skipped.
            max_len: lines longer than this are skipped; shorter ones are
                padded to exactly this length.
        """
        rows = []
        for line in lines:
            # Filter on the full line length (not on where the first '。' sits).
            # This guarantees every row fits in max_len, so the tensor below is
            # rectangular, and lines without '。' no longer raise.
            if len(line) < min_len or len(line) > max_len:
                continue

            row = [char2index.get(char, 1) for char in line]
            row += [0] * (max_len - len(row))
            rows.append(row)

        # Keep a (0, max_len) shape when no line survives the filters.
        self.data = (torch.tensor(rows, dtype=torch.long) if rows
                     else torch.empty((0, max_len), dtype=torch.long))
        # Row i of the identity matrix is the one-hot vector for index i (0 = padding).
        self.eye = torch.eye(len(char2index) + 1)

    def __getitem__(self, index):
        """Return ``(x, y)`` one-hot tensors for sample ``index``."""
        # Drop column 0 so padding positions become all-zero rows.
        y = self.eye[self.data[index]][:, 1:]
        # Shift right by one step and prepend an all-zero "start" row.
        x = torch.cat([torch.zeros(1, y.size(1)), y[:-1]])
        return x, y

    def __len__(self):
        """Number of lines kept after filtering."""
        return self.data.size(0)

In [113]:
# Both splits share the same character mapping and the same 30-step row width.
train_set, test_set = (Couplets_dataset(split_lines, char2index, max_len=30)
                       for split_lines in (train_lines, test_lines))

In [114]:
tuple(map(len, (train_set, test_set)))

(646808, 2496)

In [115]:
# Batch size and worker count are shared by both loaders; only shuffling differs.
loader_options = dict(batch_size=256, num_workers=4)
train_loader = DataLoader(train_set, shuffle=True, **loader_options)
test_loader = DataLoader(test_set, shuffle=False, **loader_options)

In [125]:
# Sanity check: every batch is (batch, max_len, len(char2index)); the last one may be smaller.
for X, Y in test_loader:
    print(X.shape)

torch.Size([256, 30, 2001])
torch.Size([256, 30, 2001])
torch.Size([256, 30, 2001])
torch.Size([256, 30, 2001])
torch.Size([256, 30, 2001])
torch.Size([256, 30, 2001])
torch.Size([256, 30, 2001])
torch.Size([256, 30, 2001])
torch.Size([256, 30, 2001])
torch.Size([192, 30, 2001])
