### Lập trình súc tích mạng nơ ron hồi tiếp

In [1]:
from d2l import torch as d2l
import torch
import torch.nn as nn
import torch.nn.functional as F

KeyboardInterrupt: 

#### 1. Định nghĩa mô hình
- pytorch đã lập trình sẵn mạng nơ ron hồi tiếp cùng với các mô hình chuỗi khác. Ta xây dựng tầng hồi tiếp rnn_layer với một tầng ẩn có 256 nút rồi khởi tạo các trọng số.

In [None]:
class RNN(nn.Module):
    def __init__(self, num_hiddens, vocab_size, device):
        super().__init__()
        self.num_hiddens = num_hiddens
        self.vocab_size = vocab_size
        self.device = device
        self.rnn = nn.RNN(
            input_size = vocab_size, hidden_size = num_hiddens, device = device
        )
        self.dense = nn.Linear(num_hiddens, vocab_size, device = device)
    
    def forward(self, input, state = None):
        if state is None:
            state = self.begin_state(input.shape[0])
        else:
            state = state.detach()
        X = F.one_hot(input.T, self.vocab_size).float()
        Y, state = self.rnn(X, state)
        # The fully connected layer will first change the shape of Y to 
        # (num _step * batch_size, num_hiddens). Its output shape is 
        # (num_steps * batch_size, vocab_size)
        # print(Y.shape)
        output = self.dense(Y.reshape(-1, Y.shape[-1]))
        return output, state
    def begin_state(self, batch_size):
        return torch.zeros(1, batch_size, self.num_hiddens, device=self.device)



In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda


In [None]:
batch_size, num_steps = 32, 35
data = d2l.TimeMachine(batch_size, num_steps)
data_iter = data.get_dataloader(train = True)
vocab = data.vocab

In [None]:
net = RNN(256, len(vocab), device = device)

In [None]:
loss = nn.CrossEntropyLoss()

In [None]:
def predict(net : RNN, str, num_pred):
    outputs = [str[0]]
    state = None
    for i in str[1:]:
        batch = torch.tensor(
            data = vocab[i], device = device
        ).reshape(1, 1)
        _, state = net(batch, state = state)
        outputs.append(i)

    for i in range(num_pred):
        batch = torch.tensor(vocab.token_to_idx[outputs[-1]], device = device).reshape(1, 1)
        output, state = net(batch, state = state)
        output = output.cpu().detach().numpy()
        outputs.append(vocab.idx_to_token[output.argmax(axis = -1)[0]])
    return ''.join(outputs)

predict(net, "time", 10)


'timeyfnyfnyfny'

In [None]:
import numpy as np

# loss = torch.nn.CrossEntropyLoss()
trainer = torch.optim.Adam(params = net.parameters())
def train_epoch(net : RNN, trainer, train_iter):
    metric = d2l.Accumulator(2)
    state = None
    for X, y in train_iter:
        if X.shape[0] != batch_size:
            continue
        net.train()
        output, state = net(X.to(device), state)
        output = F.softmax(output.reshape(-1, output.shape[-1]), dim = 1)
        y = y.to(device).reshape(-1)

        # print(F.softmax(output, dim = 1), y)
        # break
        l = loss(output, y)
        # print(l)

        trainer.zero_grad()
        l.backward()

        d2l.grad_clipping(net, 1)
        trainer.step()
        metric.add(np.sum(l.cpu().detach().numpy()), y.shape[0])
    # Return perplexity per epoch
    return np.exp(metric[0] / metric[1])

def train(net : RNN, trainer, train_iter, num_epoch = 100):
    for epoch in range(num_epoch):
        perplexity = train_epoch(net, trainer, train_iter)
        # if epoch % 10 == 0:
        print(f"Epoch: {epoch} | Perplexity: {perplexity} | {predict(net, 'time travel', 10)}")
    
train(net, trainer, data_iter)

Epoch: 0 | Perplexity: 1.0028711546341642 | time travel          
Epoch: 1 | Perplexity: 1.0028688697219867 | time travel          
Epoch: 2 | Perplexity: 1.0028688699778945 | time travel          
Epoch: 3 | Perplexity: 1.0028688799090264 | time travel          
Epoch: 4 | Perplexity: 1.0028688687681493 | time travel          
Epoch: 5 | Perplexity: 1.0028688729310957 | time travel          
Epoch: 6 | Perplexity: 1.0028688646298356 | time travel          
Epoch: 7 | Perplexity: 1.0028688717692478 | time travel          
Epoch: 8 | Perplexity: 1.002868858341624 | time travel          


KeyboardInterrupt: 