In [2]:
import time 
import math
import torch
import numpy as np
from torch import nn,optim
import torch.nn.functional as F
import sys
sys.path.append("..")
import d2lzh_pytorch as d2l
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')

(corpus_indices,char_to_idx,idx_to_char,vocab_size)=d2l.load_data_jay_lyrics()

In [3]:
num_hiddens=256
rnn_layer=nn.RNN(input_size=vocab_size,hidden_size=num_hiddens)

In [4]:
num_steps=35
batch_size=2
state=None
X=torch.rand(num_steps,batch_size,vocab_size)
Y,state_new=rnn_layer(X,state)
print(Y.shape,len(state_new),state_new[0].shape)

torch.Size([35, 2, 256]) 1 torch.Size([2, 256])


In [16]:
class RNNModel(nn.Module):
    def __init__(self,rnn_layer,vocab_size):
        super(RNNModel,self).__init__()
        self.rnn=rnn_layer
        self.hidden_size=rnn_layer.hidden_size*(2 if rnn_layer.bidirectional else 1)
        self.vocab_size=vocab_size
        self.dense=nn.Linear(self.hidden_size,vocab_size)
        self.state=None
    def forward(self,inputs,state):
        X=d2l.to_onehot(inputs,self.vocab_size)
        Y,self.state=self.rnn(torch.stack(X),state)
        outputs=self.dense(Y.view(-1,Y.shape[-1]))
        return outputs,self.state ##相当于获得每个词典的得分

In [23]:
def predict_rnn_pytorch(prefix,num_chars,model,vocab_size,device,idx_to_char,
                       char_to_idx):
    state=None
    output=[char_to_idx[prefix[0]]]
    for t in range(num_chars+len(prefix)-1):
        X=torch.tensor([output[-1]],device=device).view(1,1)
        if state is not None:
            if isinstance(state,tuple):
                state=(state[0].to(device),state[1].to(device))
            else:
                state=state.to(device)
                
        (Y,state)=model(X,state)
        if(t<len(prefix)-1):
            output.append(char_to_idx[prefix[t+1]])
        else:
            output.append(int(Y.argmax(dim=1).item()))
    return ''.join([idx_to_char[i] for i in output])


In [24]:
model=RNNModel(rnn_layer,vocab_size).to(device)
predict_rnn_pytorch('分开',10,model,vocab_size,device,idx_to_char,char_to_idx)

'分开停丘平平平平平平平平'

In [44]:
def train_and_predict_rnn_pytorch(model,num_hiddens,vocab_size,device,
                                 corpus_indices,idx_to_char,char_to_idx,
                                 num_eopchs,num_steps,lr,clipping_theta,
                                 batch_size,pred_period,pred_len,prefixes):
    #pred_period每多少周期打印一次
    loss=nn.CrossEntropyLoss()
    optimizer=torch.optim.Adam(model.parameters(),lr=lr)
    model.to(device)
    state=None
    for epoch in range(num_epochs):
        l_sum,n,start=0.0,0,time.time()
        data_iter=d2l.data_iter_consecutive(corpus_indices,batch_size,num_steps,device)
        for X,Y in data_iter:
            if state is not None:
                if isinstance(state,tuple):
                    state=(state[0].detach(),state[1].detach())
                else:
                    state=state.detach()
                
            (output,state)= model(X,state)
            ##我总感觉是transpose(Y,1,0),试验了困惑度还小一些
            y=torch.transpose(Y,0,1).contiguous().view(-1)
            l=loss(output,y.long())
            
            optimizer.zero_grad()
            l.backward()
            d2l.grad_clipping(model.parameters(),clipping_theta,device)
            optimizer.step()
            l_sum+=l.item()*y.shape[0]
            n+=y.shape[0]
            
        try:
            perplexity=math.exp(l_sum/n)
        except OverflowError:
            perplexity=float('inf')
        if(epoch+1)%pred_period==0:
            print('epoch %d,perplexity %f,time %.2f sec'%(
            epoch+1,perplexity,time.time()-start))
            for prefix in prefixes:
                print('-',predict_rnn_pytorch(
                prefix,pred_len,model,vocab_size,device,idx_to_char,
                char_to_idx))
                

In [46]:
num_epochs,batch_size,lr,clipping_theta=250,32,1e-3,1e-2
pred_period,pred_len,prefixes=50,50,['爱猪猪','爱夏爽']
train_and_predict_rnn_pytorch(model,num_hiddens,vocab_size,device,
                             corpus_indices,idx_to_char,char_to_idx,
                             num_epochs,num_steps,lr,clipping_theta,
                             batch_size,pred_period,pred_len,prefixes)

epoch 50,perplexity 1.006590,time 0.20 sec
- 爱猪猪 但淡的誓言暴力因素一定都会有原因 但是呢 妈跟我都没有错亏我叫你一声爸  爸我回来了 不要再这样打


KeyError: '夏'