In [1]:
%matplotlib inline
import torch
import torchvision
from IPython import display
from matplotlib import pyplot as plt
import numpy as np
import random

### 确认运行设备

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

### 加载数据

In [3]:
with open('./data/jaychou_lyrics.txt', 'r') as f:
    corpus = f.read()
    corpus = corpus.replace('\n', ' ').replace('\u3000', ' ')

In [4]:
corpus = corpus[:10000]
idx_to_char = list(set(corpus))
vocab_size = len(idx_to_char)
vocab_size

1027

In [5]:
char_to_idx = {c:i for i, c in enumerate(idx_to_char)}

#### 时序数据的采样

不同的采样方式, 在训练实现上会略有不同

In [6]:
# 随机采样 每次采样前都需要重新初始化隐藏状态
def data_iter_random(corpus_indices, batch_size, window, device):
    num_example = (len(corpus_indices)-1) // window
    batch_num = num_example // batch_size
    
    example_indices = list(range(num_example))
    random.shuffle(example_indices)
    
    for i in range(batch_num):
        batch_indices = example_indices[i*batch_size: (i+1)*batch_size]
        train_example = [corpus_indices[j*window: (j+1)*window] for j in batch_indices]
        test_example = [corpus_indices[j*window+1: (j+1)*window+1] for j in batch_indices]
        yield torch.tensor(train_example, dtype=torch.float32, device=device), torch.tensor(test_example, dtype=torch.float32, device=device)

In [7]:
my_seq = list(range(30))
print(my_seq)
for X, Y in data_iter_random(my_seq, batch_size=2, window=6, device=device):
    print('X: ', X, '\nY:', Y, '\n')

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
X:  tensor([[12., 13., 14., 15., 16., 17.],
        [18., 19., 20., 21., 22., 23.]]) 
Y: tensor([[13., 14., 15., 16., 17., 18.],
        [19., 20., 21., 22., 23., 24.]]) 

X:  tensor([[ 6.,  7.,  8.,  9., 10., 11.],
        [ 0.,  1.,  2.,  3.,  4.,  5.]]) 
Y: tensor([[ 7.,  8.,  9., 10., 11., 12.],
        [ 1.,  2.,  3.,  4.,  5.,  6.]]) 



In [8]:
# 相邻采样
def data_iter_consecutive(corpus_indices, batch_size, window, device):
    batch_len = len(corpus_indices) // batch_size
    corpus_indices = torch.tensor(corpus_indices, dtype=torch.float32, device=device)
    corpus_indices = corpus_indices[0:batch_len*batch_size]
    corpus_indices = corpus_indices.view(batch_size, batch_len)
    
    batch_num = (batch_len-1)//window
    for i in range(batch_num):
        train_example = corpus_indices[:,i*window:(i+1)*window]
        test_example = corpus_indices[:,i*window+1:(i+1)*window+1]
        yield train_example, test_example
    

In [9]:
my_seq = list(range(30))
print(my_seq)
for X, Y in data_iter_consecutive(my_seq, batch_size=2, window=6, device=device):
    print('X: ', X, '\nY:', Y, '\n')

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
X:  tensor([[ 0.,  1.,  2.,  3.,  4.,  5.],
        [15., 16., 17., 18., 19., 20.]]) 
Y: tensor([[ 1.,  2.,  3.,  4.,  5.,  6.],
        [16., 17., 18., 19., 20., 21.]]) 

X:  tensor([[ 6.,  7.,  8.,  9., 10., 11.],
        [21., 22., 23., 24., 25., 26.]]) 
Y: tensor([[ 7.,  8.,  9., 10., 11., 12.],
        [22., 23., 24., 25., 26., 27.]]) 



#### onehot

In [10]:
def one_hot(x, n_class, dtype=torch.float32):
    x = x.long()
    res = torch.zeros(x.shape[0], n_class, dtype=dtype, device=x.device)
    res.scatter_(1, x.view(-1,1), 1)
    return res

In [11]:
x = torch.tensor([1.,2.])
one_hot(x, 3)

tensor([[0., 1., 0.],
        [0., 0., 1.]])

In [12]:
torch.nn.functional.one_hot(x.long().view(-1,1), 3)

tensor([[[0, 1, 0]],

        [[0, 0, 1]]])

### Define RNN

In [13]:
def to_onehot(x_batch, n_class=vocab_size):
    """
    x_batch: batch x seq_len
    return: [tensor(batch, n_class), ...] x sql_len
    """
    return [one_hot(x_batch[:,i] , n_class) for i in range(x_batch.shape[1])]
    

In [44]:
def get_params(num_inputs, num_hiddens, num_outputs, device=device):
    """
    initialize params
    """
    def _randn(shape):
        return torch.nn.Parameter(
            torch.tensor(np.random.normal(0, 0.01, size=shape), dtype=torch.float32, device=device)
        )
    
    def _zero(size):
        return torch.nn.Parameter(torch.zeros(size, device=device))
    
    W_xh = _randn((num_inputs, num_hiddens))
    W_hh = _randn((num_hiddens, num_hiddens))
    W_hq = _randn((num_hiddens, num_outputs))
    
    b_h = _zero(num_hiddens)
    b_q = _zero(num_outputs)
    
    return torch.nn.ParameterList([W_xh, W_hh, W_hq, b_h, b_q])

In [45]:
def init_hiden_state(batch_size, num_hiddens, device):
    return torch.zeros((batch_size, num_hiddens), device=device)    

In [46]:
def rnn(inputs, hidden_state, params):
    """
    return outputs, hidden_state
    """
    W_xh, W_hh, W_hq, b_h, b_q = params
    outputs = []
    for batch_x in inputs:
        hidden_state = torch.tanh(torch.mm(batch_x, W_xh) + torch.mm(hidden_state, W_hh) + b_h)
        outputs.append(torch.mm(hidden_state, W_hq) + b_q)
    return outputs, hidden_state

In [48]:
# test size
X = torch.arange(10).view(2, 5)
inputs = to_onehot(X, vocab_size)
print(len(inputs), inputs[0].shape)

hidden_state = init_hiden_state(2, num_hiddens, device)
print(hidden_state.shape)
outputs, hidden_state = rnn(inputs, hidden_state, get_params(vocab_size, 256, vocab_size, device))
print(len(outputs), outputs[0].shape, hidden_state.shape)

5 torch.Size([2, 1027])
torch.Size([2, 256])
5 torch.Size([2, 1027]) torch.Size([2, 256])


In [18]:
def predict_rnn(prefix, 
                num_chars, 
                rnn, 
                params, 
                init_hiden_state_func, 
                num_hiddens, 
                vocab_size, 
                device, 
                idx_to_char, 
                char_to_idx):
    H = init_hiden_state_func(1, num_hiddens, device)
    outputs = [char_to_idx[prefix[0]]]
    for i in range(1, num_chars+len(prefix)):
        X = to_onehot(torch.tensor(outputs[-1]).view(1,-1))
        Y, state = rnn(X, H, params)
        
        if i <= len(prefix)-1:
            outputs.append(char_to_idx[prefix[i]])
        else:
            outputs.append(int(Y[0].argmax(dim=1).item()))
    return ''.join([idx_to_char[i] for i in outputs])
        
    
predict_rnn('分开', 10, rnn, get_params(), init_hiden_state, num_hiddens, vocab_size,
            device, idx_to_char, char_to_idx)

'分开移戒飞早柳拿提跑告近'

### Train RNN

In [75]:
def softmax(x):
    x = x - x.max(dim=1, keepdim=True).values
    return x.exp() / x.exp().sum(dim=1, keepdim=True)

def cross_entropy_loss(y, y_hat):
    """
    y=[0,1,2]
    y_hat = [
        [0.7,0.2,0.1],
        [0.2,0.5,0.3],
        [0.1,0.1,0.8],
    ]
    """
    y_hat = softmax(y_hat)
    return -torch.log(
        y_hat.gather(dim=1, index=y.type(torch.long).view(-1,1))+1e-5
    ).sum()

def sgd(params, lr, bs):
    for param in params:
        param.data -= param.grad / bs * lr
        
def grad_clipping(params, theta, device):
    norm = torch.tensor([0.0], device=device)
    for param in params:
        norm += (param.grad.data ** 2).sum()
    norm = norm.sqrt().item()
    if norm > theta:
        for param in params:
            param.grad.data *= (theta / norm)


In [121]:
import time, math

def train(data_iter, 
          params, 
          num_epochs, 
          batch_size, 
          lr, 
          clipping_theta, 
          num_hiddens, 
          vocab_size, 
          device, 
          is_rand_sample=True):
    H = init_hiden_state(batch_size, num_hiddens, device)
    loss = torch.nn.CrossEntropyLoss()
    
    for epoch in range(num_epochs):
        l_sum, n, start = 0.0, 0, time.time()
        
        for X, Y in data_iter():
            if is_rand_sample:
                H = init_hiden_state(batch_size, num_hiddens, device)

            inputs = to_onehot(X, vocab_size)
            y_hat, H = rnn(inputs, H, params) # y_hat: list of bs x vocab_size
            y_hat = torch.cat(y_hat, dim=0)
            y = Y.T.reshape(-1)
            l = loss(y_hat, y.long())

            for param in params:
                if param.grad is not None:
                    param.grad.data.zero_()

            l.backward()
            grad_clipping(params, clipping_theta, device)  # 裁剪梯度
            sgd(params, lr, batch_size)
            l_sum += l.item()
            n += y.shape[0]

        if (epoch + 1) % 10 == 0:
            print('epoch %d, perplexity %f, time %.2f sec' % (
                epoch + 1, math.exp(l_sum / n), time.time() - start))
            for prefix in prefixes:
                print(' -', predict_rnn(prefix, pred_len, rnn, params, init_hiden_state,
                                        num_hiddens, vocab_size, device, idx_to_char, char_to_idx))
            
    

data_iter = lambda : data_iter_random([char_to_idx[i] for i in corpus], 32, 5, device)
params = get_params(vocab_size, 256, vocab_size, device)
pred_period, pred_len, prefixes = 50, 50, ['分开', '不分开']
train(data_iter, params, 1000, 32, lr=1e2, clipping_theta=1e-2, num_hiddens=256, vocab_size=vocab_size, device=device, is_rand_sample=True)

epoch 10, perplexity 1.036188, time 1.58 sec
 - 分开                                                  
 - 不分开                                                  
epoch 20, perplexity 1.035752, time 1.28 sec
 - 分开                                                  
 - 不分开                                                  
epoch 30, perplexity 1.035588, time 1.25 sec
 - 分开                                                  
 - 不分开                                                  
epoch 40, perplexity 1.035387, time 1.83 sec
 - 分开                                                  
 - 不分开                                                  
epoch 50, perplexity 1.035164, time 1.11 sec
 - 分开 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我
 - 不分开 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我
epoch 60, perplexity 1.035050, time 1.21 sec
 - 分开 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我
 - 不分开 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我 我
epoch 70, perplexity 1.034913, time 1.03 sec
 - 分开 我

KeyboardInterrupt: 

In [91]:
data_iter = lambda : data_iter_random([char_to_idx[i] for i in corpus], 32, 5, device)


In [117]:
params = get_params(vocab_size, 256, vocab_size)
predict_rnn('分开', pred_len, rnn, params, init_hiden_state, 256, vocab_size, device, idx_to_char, char_to_idx)

'分开衫怨狗彻跑联然守演连队前枯点画亏童裂社轻丛决送拿命没术透缘念忆始阻明边再鼻伯或朵瓦濡准去替狠选斯如再'

In [111]:
params, num_hiddens

(ParameterList(
     (0): Parameter containing: [torch.FloatTensor of size 1027x100]
     (1): Parameter containing: [torch.FloatTensor of size 100x100]
     (2): Parameter containing: [torch.FloatTensor of size 100x1027]
     (3): Parameter containing: [torch.FloatTensor of size 100]
     (4): Parameter containing: [torch.FloatTensor of size 1027]
 ), 256)

In [122]:
for p in params:
    print(p)

Parameter containing:
tensor([[ 0.0236,  0.0084,  0.0050,  ..., -0.0085, -0.0084, -0.0057],
        [ 0.0008, -0.0047, -0.0096,  ...,  0.0216, -0.0160, -0.0103],
        [-0.0015,  0.0093,  0.0154,  ...,  0.0038, -0.0141,  0.0162],
        ...,
        [-0.0043,  0.0201,  0.0048,  ...,  0.0221, -0.0079,  0.0193],
        [-0.0096,  0.0008,  0.0056,  ..., -0.0123,  0.0009,  0.0089],
        [ 0.0082,  0.0049,  0.0021,  ...,  0.0048, -0.0014, -0.0015]],
       requires_grad=True)
Parameter containing:
tensor([[ 0.0723, -0.0493,  0.0366,  ...,  0.0056,  0.0715, -0.0638],
        [-0.0122, -0.0020,  0.0081,  ..., -0.0375, -0.0039,  0.0348],
        [ 0.0374, -0.0103, -0.0039,  ...,  0.0052, -0.0053,  0.0388],
        ...,
        [ 0.0471, -0.0073, -0.0082,  ...,  0.0012,  0.0024,  0.0048],
        [ 0.0132, -0.0262, -0.0012,  ..., -0.0078,  0.0055, -0.0061],
        [-0.0080,  0.0199,  0.0087,  ..., -0.0080, -0.0020,  0.0204]],
       requires_grad=True)
Parameter containing:
tensor([[ 0.