In [3]:
from collections import Counter
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

np.random.seed(100)
torch.manual_seed(100)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.__version__

'1.3.0'

In [4]:
raw_data = 'data/all_couplets.txt'
vocab_size = 2000

### 划分数据集

In [5]:
def split_dataset(raw_data, test_size=3000):
    with open(raw_data, 'r') as f:
        lines = f.readlines()
        
    lines = list(map(str.strip, lines))
    
    np.random.shuffle(lines)
    
    train_lines = lines[test_size:]
    test_lines = lines[:test_size]
        
    return train_lines, test_lines

In [6]:
train_lines, test_lines = split_dataset(raw_data, test_size=3000)
len(train_lines), len(test_lines)

(771491, 3000)

### 获取字符表

In [7]:
def create_vocab(train_lines, size=-1):
    counter = Counter(''.join(train_lines))
    vocab = sorted(counter, key=lambda c: counter[c], reverse=True)
    
    if size != -1:
        vocab = vocab[:size]
    
    print(f"last character: {vocab[-1]}, frequency: {counter[vocab[-1]]}")
        
    return vocab

In [8]:
vocab = create_vocab(train_lines, size=vocab_size)

last character: 辣, frequency: 771


In [9]:
def create_index_char(vocab):
    vocab = vocab.copy()
    vocab = ['<blk>', '<unk>', '<bos>', '<eos>'] + vocab
    return dict(zip(range(0, len(vocab)), vocab)), dict(zip(vocab, range(0, len(vocab))))

In [10]:
index2char, char2index = create_index_char(vocab)

### 创建数据集

In [11]:
class Couplets_dataset(Dataset):
    def __init__(self, lines, char2index, min_len=10, max_len=20):
        self.min_len = min_len
        self.max_len = max_len
        
        ups = []
        downs = []
        
        for line in lines:  
            if '；' in line:
                line_split = tuple(line.split('；'))
                if len(line_split) == 2:
                    up, down = line_split
            else:
                continue
                
            down = down[:-1]
            if len(up) != len(down):
                continue
                
            up_len = len(up)
            if up_len < min_len or up_len > max_len:
                continue
            
            ups.append(up)
            downs.append(down)
            
        self.ups = ups
        self.downs = downs
        self.eye = torch.eye(len(char2index), dtype=torch.float32)
        
    def __getitem__(self, index):
        up, down = self.ups[index], self.downs[index]
        down = ['<bos>'] + list(down) + ['<eos>']
        
        up_len = len(up)
        if up_len < self.max_len:
            up = list(up) + ['<blk>'] * (self.max_len - up_len)
            down = down + ['<blk>'] * (self.max_len - up_len)
        
        x = torch.tensor([char2index.get(c, char2index['<unk>']) for c in up], dtype=torch.long)
        y = self.eye[[char2index.get(c, char2index['<unk>']) for c in down]]
        
        return x, y
    
    def __len__(self):
        return len(self.ups)

In [12]:
train_set = Couplets_dataset(train_lines, char2index, min_len=5, max_len=15)
test_set = Couplets_dataset(test_lines, char2index, min_len=5, max_len=15)

len(train_set), len(test_set)

(644701, 2481)

In [13]:
# Test

X, Y = train_set[6]
X.size(), Y.size()

(torch.Size([15]), torch.Size([17, 2004]))

In [14]:
# Test

[index2char[i.item()] for i in train_set[6][1].argmax(dim=1)]

['<bos>',
 '诚',
 '心',
 '待',
 '民',
 '铸',
 '<unk>',
 '魂',
 '<eos>',
 '<blk>',
 '<blk>',
 '<blk>',
 '<blk>',
 '<blk>',
 '<blk>',
 '<blk>',
 '<blk>']

In [15]:
train_loader = DataLoader(train_set, batch_size=256, shuffle=True, num_workers=4)
test_loader = DataLoader(test_set, batch_size=256, shuffle=False, num_workers=4)

In [16]:
# Test

X, Y = next(iter(train_loader))
X.size(), Y.size()

(torch.Size([256, 15]), torch.Size([256, 17, 2004]))

### 构建网络

#### encoder

In [17]:
class Couplet_encoder(nn.Module):
    def __init__(self, embedding_layer, hidden_dim=200, layer_num=2):
        super().__init__()
        
        self.embedding = embedding_layer
        self.bilstm = nn.LSTM(embedding_layer.weight.size(1), hidden_dim, layer_num, bidirectional=True)
        
    def forward(self, X):
        X = self.embedding(X)
        X = self.bilstm(X)
        return X

In [18]:
# Test

embedding_layer = nn.Embedding(len(char2index), 100)
encoder = Couplet_encoder(embedding_layer)

X, Y = next(iter(train_loader))
output, (h_n, c_n) = encoder(X)
output.size(), h_n.size(), c_n.size()

(torch.Size([256, 15, 400]),
 torch.Size([4, 15, 200]),
 torch.Size([4, 15, 200]))

#### attention

In [19]:
class Attention_layer(nn.Module):
    def __init__(self, encoder_output_dim, decoder_h_dim):
        super().__init__()
        
        self.fc = nn.Linear(encoder_output_dim + decoder_h_dim, 1)
        self.softmax = nn.Softmax(1)
        
    def forward(self, encoder_output, decoder_h):
        decoder_h_repeat = torch.stack([decoder_h] * encoder_output.size(1), 1)
        merge = torch.cat((encoder_output, decoder_h_repeat), dim=2)
        merge = merge.view(merge.size(0) * merge.size(1), merge.size(2))
        
        X = self.fc(merge).view(encoder_output.size(0), encoder_output.size(1))
        attention_weights = self.softmax(X).view(X.size(0), X.size(1), 1)
        
        return (attention_weights * encoder_output).sum(dim=1)

In [20]:
# Test

encoder_output = torch.randn(100, 15, 40)
decoder_h = torch.randn(100, 40)

attention_layer = Attention_layer(encoder_output.size(-1), decoder_h.size(-1))
attention_layer(encoder_output, decoder_h).size()

torch.Size([100, 40])

#### decoder

In [76]:
class Couplet_decoder(nn.Module):
    def __init__(self, embedding_layer, attention_layer, encoder_output_dim, encoder_state_dim, hidden_dim=200):
        super().__init__()
        
        self.hidden_dim = hidden_dim
        
        self.embedding = embedding_layer
        self.attention = attention_layer
        
        self.h_project_fc = nn.Linear(encoder_state_dim, hidden_dim * 2)
        self.c_project_fc = nn.Linear(encoder_state_dim, hidden_dim * 2)
        
        self.lstm_cell_0 = nn.LSTMCell(embedding_layer.weight.size(1) + encoder_output_dim, hidden_dim)
        self.lstm_cell_1 = nn.LSTMCell(hidden_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, embedding_layer.weight.size(0))
        self.softmax = nn.Softmax(dim=1)
        
    def forward(self, X, encoder_output, encoder_last_state):
        encoder_last_h, encoder_last_c = encoder_last_state
        decoder_input_h, decoder_input_c = self.h_project_fc(encoder_last_h), self.c_project_fc(encoder_last_c)
        
        decoder_h_0, decoder_h_1 = decoder_input_h[:, :self.hidden_dim], decoder_input_h[:, self.hidden_dim:]
        decoder_c_0, decoder_c_1 = decoder_input_c[:, :self.hidden_dim], decoder_input_h[:, self.hidden_dim:]
        
        outputs = []        
        for i in range(X.size(1)):
            X_step = X[:, i].argmax(1)
            embedding_step = self.embedding(X_step)
            attention_step = self.attention(encoder_output, decoder_h_1)
            input_step = torch.cat([embedding_step, attention_step], dim=1)
            
            decoder_h_0, decoder_c_0 = self.lstm_cell_0(input_step, (decoder_h_0, decoder_c_0))
            decoder_h_1, decoder_c_1 = self.lstm_cell_1(decoder_h_0, (decoder_h_1, decoder_c_1))
            Y_step = self.fc(decoder_h_1)
            outputs.append(Y_step)
            
        return torch.stack(outputs, dim=1)
    
    def _beam_score(self, Y_prob, alpha=0.7):
        return ((1 / Y_prob.size(1)) ** alpha) * Y_prob.log().sum(dim=1)
    
    def _pick_beam_width(self, best_probs_list, step_probs_list, alpha=0.7):
        score_indices = []
        
        for i in range(len(best_probs_list)):
            best_probs = best_probs_list[i]
            step_probs = step_probs_list[i]
            step_prob_list = list(torch.split(step_probs, 1, dim=1))
            
            for j in range(len(step_prob_list)):
                probs = torch.cat([best_probs, step_prob_list[j]], dim=1)
                score = self._beam_score(probs, alpha).item()               
                score_indices.append((score, i, j))

        sorted_score_indices = sorted(score_indices, key=lambda x: x[0], reverse=True)
        
        return sorted_score_indices
                
    
    def beam_search(self, encoder_output, encoder_last_state, index2char, char2index, 
                    beam_width=3, alpha=0.7, max_len=15):
        encoder_last_h, encoder_last_c = encoder_last_state
        decoder_input_h, decoder_input_c = self.h_project_fc(encoder_last_h), self.c_project_fc(encoder_last_c)
        
        decoder_h_0, decoder_h_1 = decoder_input_h[:, :self.hidden_dim], decoder_input_h[:, self.hidden_dim:]
        decoder_c_0, decoder_c_1 = decoder_input_c[:, :self.hidden_dim], decoder_input_h[:, self.hidden_dim:]
        
        best_output_list = [torch.tensor([[char2index['<bos>']]] * encoder_output.size(0), 
                                         dtype=torch.long, device=device, requires_grad=False)]
        best_probs_list = None
        best_state_0_list = [(decoder_h_0, decoder_c_0)]
        best_state_1_list = [(decoder_h_1, decoder_c_1)]
        
        for _ in range(max_len):
            step_probs_list = []
            step_indices_list = []
            step_state_0_list = []
            step_state_1_list = []
            
            for i in range(1 if best_probs_list is None else beam_width):
                X_step = best_output_list[i][:, -1]
                decoder_h_0, decoder_c_0 = best_state_0_list[i]
                decoder_h_1, decoder_c_1 = best_state_1_list[i]
            
                embedding_step = self.embedding(X_step)
                attention_step = self.attention(encoder_output, decoder_h_1)
                input_step = torch.cat([embedding_step, attention_step], dim=1)
            
                decoder_h_0, decoder_c_0 = self.lstm_cell_0(input_step, (decoder_h_0, decoder_c_0))
                decoder_h_1, decoder_c_1 = self.lstm_cell_1(decoder_h_0, (decoder_h_1, decoder_c_1))
                Y_step = self.fc(decoder_h_1)
                Y_step = self.softmax(Y_step)
                
                sorted_probs, sorted_indices = Y_step.sort(dim=1, descending=True)
                step_probs = sorted_probs[:, :beam_width]
                step_indices = sorted_indices[:, :beam_width]
                
                step_probs_list.append(step_probs)
                step_indices_list.append(step_indices)                
                step_state_0_list.append((decoder_h_0, decoder_c_0))
                step_state_1_list.append((decoder_h_1, decoder_c_1))
                
            if best_probs_list is None:
                best_probs_list = list(torch.split(step_probs_list[0], 1, dim=1))
                best_output_list = list(torch.split(step_indices_list[0], 1, dim=1))
                best_state_0_list = step_state_0_list * 3
                best_state_1_list = step_state_1_list * 3
            else:
                sorted_score_indices = self._pick_beam_width(best_probs_list, step_probs_list, alpha)
                    
                temp_best_probs_list = []
                temp_best_output_list = []
                best_state_0_list = []
                best_state_1_list = []
                for j in range(beam_width):
                    score, ii, jj = sorted_score_indices[j]
                    best_probs = torch.cat([best_probs_list[ii], 
                                            torch.split(step_probs_list[ii], 1, dim=1)[jj]],
                                           dim=1)
                    best_output = torch.cat([best_output_list[ii],
                                           torch.split(step_indices_list[ii], 1, dim=1)[jj]],
                                           dim=1),
                    temp_best_probs_list.append(best_probs)
                    temp_best_output_list.append(best_output[0])
                    best_state_0_list.append(step_state_0_list[ii])
                    best_state_1_list.append(step_state_1_list[ii])
                    
                best_probs_list = temp_best_probs_list.copy()
                best_output_list = temp_best_output_list.copy()
                
        best_probs = torch.cat(best_probs_list, dim=0)
        best_outputs = torch.cat(best_output_list, dim=0) 
        scores = self._beam_score(best_probs, alpha)
        
        return best_output[torch.argmax(scores)]

In [77]:
# Test

X = torch.randn(100, 15, 40)
encoder_last_state = torch.randn(100, 20), torch.randn(100, 20)

encoder_output = torch.randn(100, 15, 40)
decoder_h = torch.randn(100, 40)
attention_layer = Attention_layer(encoder_output.size(-1), decoder_h.size(-1))

embedding_layer = nn.Embedding(len(char2index), 120)


decoder = Couplet_decoder(embedding_layer, attention_layer, encoder_output_dim=40, encoder_state_dim=20, hidden_dim=40)
decoder(X, encoder_output, encoder_last_state).size()

torch.Size([100, 15, 2004])

In [79]:
# Test

X = torch.randn(1, 15, 40)
encoder_last_state = torch.randn(1, 20), torch.randn(1, 20)

encoder_output = torch.randn(1, 15, 40)
decoder_h = torch.randn(1, 40)
attention_layer = Attention_layer(encoder_output.size(-1), decoder_h.size(-1))

embedding_layer = nn.Embedding(len(char2index), 120)

with torch.no_grad():
    output = decoder.beam_search(encoder_output, encoder_last_state, index2char, char2index, beam_width=3)
    
output

tensor([[ 756,  756, 1828, 1828, 1828, 1828, 1142, 1828, 1828, 1142, 1828, 1828,
         1828, 1142,  996]])