In [23]:
from collections import Counter
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Seed NumPy and PyTorch so that random operations (e.g. the NumPy shuffle
# used for the train/test split) give the same result on every run.
np.random.seed(100)
torch.manual_seed(100)

# Display the installed PyTorch version as the cell output.
torch.__version__

'1.3.0'

In [24]:
# Path to the raw couplet corpus; it is handed to split_dataset().
raw_data = 'data/all_couplets.txt'
# Number of most frequent characters to keep; passed as `size` to create_vocab().
vocab_size = 2000

### 划分数据集

In [25]:
def split_dataset(raw_data, test_size=3000, encoding='utf-8'):
    """Read a text corpus, shuffle its lines and split them into train/test.

    Args:
        raw_data: Path of the text file to read, one sample per line.
        test_size: Number of shuffled lines to hold out as the test split.
        encoding: Text encoding used to decode the file. It is given
            explicitly so the result does not depend on the platform's
            default encoding.

    Returns:
        A ``(train_lines, test_lines)`` tuple of lists of stripped strings.
        ``test_lines`` holds the first ``test_size`` shuffled lines and
        ``train_lines`` holds the remainder.
    """
    with open(raw_data, 'r', encoding=encoding) as f:
        # Strip surrounding whitespace (including the trailing newline).
        lines = [line.strip() for line in f]

    # In-place shuffle driven by NumPy's global RNG; seed it beforehand for
    # a reproducible split.
    np.random.shuffle(lines)

    return lines[test_size:], lines[:test_size]

In [26]:
# Hold out 3000 shuffled lines for testing; all remaining lines are used for training.
train_lines, test_lines = split_dataset(raw_data, test_size=3000)
# Show the size of each split.
len(train_lines), len(test_lines)

(771491, 3000)

### 获取字符表

In [31]:
def create_vocab(train_lines, size=-1):
    """Build a character vocabulary ordered from most to least frequent.

    Args:
        train_lines: Iterable of strings whose characters are counted.
        size: Maximum number of characters to keep. ``-1`` keeps every
            character that occurs in ``train_lines``.

    Returns:
        A list of characters sorted by descending frequency; characters with
        equal counts keep their order of first appearance. The list is empty
        when there is nothing to count or ``size`` is 0.
    """
    counter = Counter(''.join(train_lines))
    # most_common() sorts by descending count and is stable for ties.
    vocab = [char for char, _ in counter.most_common()]

    if size != -1:
        vocab = vocab[:size]

    # Only report the rarest kept character when there is one; indexing an
    # empty vocabulary would raise IndexError.
    if vocab:
        print(f"last character: {vocab[-1]}, frequency: {counter[vocab[-1]]}")

    return vocab

In [32]:
# Build the vocabulary from the training lines only, keeping at most vocab_size characters.
vocab = create_vocab(train_lines, size=vocab_size)

last character: 辣, frequency: 771


In [96]:
def create_index_char(vocab):
    """Build index->char and char->index lookup tables.

    The four special tokens always occupy the first indices:
    ``' '`` (padding) = 0, ``'<unk>'`` = 1, ``'<bos>'`` = 2, ``'<eos>'`` = 3.
    Vocabulary characters follow in their given order.

    Args:
        vocab: Iterable of characters. It is not modified.

    Returns:
        A ``(index2char, char2index)`` tuple of dicts that are exact inverses
        of each other.
    """
    specials = [' ', '<unk>', '<bos>', '<eos>']
    tokens = list(specials)
    seen = set(specials)

    # Skip characters that collide with a special token (e.g. a literal space
    # in the corpus) or that repeat: a duplicate would overwrite the earlier
    # entry in char2index and leave the two tables inconsistent.
    for char in vocab:
        if char not in seen:
            seen.add(char)
            tokens.append(char)

    index2char = dict(enumerate(tokens))
    char2index = {char: index for index, char in index2char.items()}
    return index2char, char2index

In [97]:
# Build the two lookup tables used to encode characters and decode indices.
index2char, char2index = create_index_char(vocab)

### 创建数据集

In [98]:
class Couplets_dataset(Dataset):
    """Dataset of (first line, second line) couplet pairs.

    Each input line is expected to look like ``<up>；<down><terminator>``.
    A line is kept only if it splits into exactly two parts on '；', both
    halves have the same length once the final character of the second half
    is removed, and that length lies within ``[min_len, max_len]``.

    Every sample has a fixed shape so samples can be stacked into batches:
    the input is padded to ``max_len`` and the target to ``max_len + 2``
    (``<bos>`` + characters + ``<eos>`` + padding).
    """

    def __init__(self, lines, char2index, min_len=10, max_len=20):
        """Filter ``lines`` into aligned first/second halves.

        Args:
            lines: Iterable of raw couplet strings.
            char2index: Mapping from character to integer index. It must
                contain ' ', '<unk>', '<bos>' and '<eos>'.
            min_len: Minimum accepted length of a half (inclusive).
            max_len: Maximum accepted length of a half (inclusive); also the
                padded input length.
        """
        self.min_len = min_len
        self.max_len = max_len
        # Keep the mapping on the instance so encoding does not depend on a
        # global variable of the same name.
        self.char2index = char2index

        ups = []
        downs = []

        for line in lines:
            parts = line.split('；')
            # Reject lines with no separator or more than one; otherwise the
            # halves of a previous line would be reused.
            if len(parts) != 2:
                continue
            up, down = parts

            # Drop the final character of the second half
            # (presumably closing punctuation such as '。' - confirm against the corpus).
            down = down[:-1]
            if len(up) != len(down):
                continue

            if not min_len <= len(up) <= max_len:
                continue

            ups.append(up)
            downs.append(down)

        self.ups = ups
        self.downs = downs
        # Identity matrix; indexing its rows yields one-hot target vectors.
        self.eye = torch.eye(len(char2index), dtype=torch.float32)

    def __getitem__(self, index):
        """Return the encoded sample at ``index``.

        Returns:
            ``(x, y)`` where ``x`` is a long tensor of ``max_len`` character
            indices and ``y`` is a float32 one-hot tensor of shape
            ``(max_len + 2, len(char2index))``.
        """
        up, down = self.ups[index], self.downs[index]

        # Always add the <bos>/<eos> markers, including when the couplet is
        # already max_len long (pad_len == 0), so every target has the same
        # length.
        pad_len = self.max_len - len(up)
        up = up + ' ' * pad_len
        down = ['<bos>'] + list(down) + ['<eos>'] + [' '] * pad_len

        unk = self.char2index['<unk>']
        x = torch.tensor([self.char2index.get(c, unk) for c in up], dtype=torch.long)
        y = self.eye[[self.char2index.get(c, unk) for c in down]]

        return x, y

    def __len__(self):
        """Return the number of couplets that passed filtering."""
        return len(self.ups)

In [99]:
# Training dataset limited to couplets whose halves are 5 to 15 characters long.
train_set = Couplets_dataset(train_lines, char2index, min_len=5, max_len=15)

In [100]:
# Fetch the first sample: x is the encoded input, y the one-hot encoded target.
x, y = train_set[0]

In [101]:
# Number of target positions in the sample (rows of the one-hot tensor).
len(y)

17

In [102]:
# Raw second half of the first couplet, for comparison with the decoded target.
train_set.downs[0]

'化雪融冰月色凉'

In [103]:
# Decode the one-hot target back to characters: argmax of each row gives the character index.
[index2char[index] for index in torch.argmax(y, dim=1).numpy()]

['<bos>',
 '化',
 '雪',
 '融',
 '冰',
 '月',
 '色',
 '凉',
 '<eos>',
 ' ',
 ' ',
 ' ',
 ' ',
 ' ',
 ' ',
 ' ',
 ' ']