In [241]:
from d2l import torch as d2l
from torch import nn
import torch.nn.functional as F
import torch
import numpy as np

In [242]:
### Thiết lập mạng RNN

class RNN(nn.Module):
    # num_inputs là số lượng vocab biểu diễn bằng one-hit
    def __init__(self, num_hidden, vocab, device):
        super().__init__()
        num_inputs = len(vocab)
        num_outputs = num_inputs
        self.num_hidden = num_hidden
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.device = device
        self.vocab = vocab
        self.W_xh = nn.Parameter(
            data = torch.rand(size = (num_inputs, num_hidden), device = device, requires_grad=True)
        )
        self.W_hh = nn.Parameter(
            data = torch.rand(size = (num_hidden, num_hidden), device = device, requires_grad=True)
        )
        self.b_h = nn.Parameter(
            data = torch.rand(size = (1, num_hidden), device = device, requires_grad=True)
        )

        self.W_hq = nn.Parameter(
            data = torch.rand(size = (num_hidden, num_outputs), device = device, requires_grad=True)
        )
        self.b_q = nn.Parameter(
            data = torch.rand(size = (1, num_outputs), device = device, requires_grad=True)
        )
    
    def init_state(self, batch_size):
        state = torch.zeros(size = (batch_size, self.num_hidden), device = self.device)
        state = state.detach()
        return state
    def forward(self, X, state = None):
        if state is None or X.shape[0] != state.shape[0]:
            state = self.init_state(X.shape[0])
            
        state = state.detach()
        # X = X.to(self.device)
        X = F.one_hot(X.T, num_classes=len(self.vocab)).to(device = self.device, dtype = torch.float32)
        # X = X.to(dtype = torch.float32)
        outputs = []
        # Shape of X: state_number, batch_size, vocab_size

        for x in X:
            state = F.tanh(torch.mm(x, self.W_xh) + torch.mm(state, self.W_hh) + self.b_h)
            output = torch.mm(state, self.W_hq) + self.b_q
            outputs.append(output)  # Chỉ dùng PyTorch tensor

        return torch.cat(outputs, dim=0).to(self.device), state






In [243]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda


In [244]:
batch_size, num_steps = 32, 35
data = d2l.TimeMachine(batch_size, num_steps)
data_iter = data.get_dataloader(train = True)
vocab = data.vocab

In [245]:
net = RNN(num_hidden = 256, vocab = vocab, device = device)
loss = torch.nn.CrossEntropyLoss()
trainer = torch.optim.SGD(params = net.parameters(), lr = 1)

In [246]:
# Gradient clipping
# d2l.grad_clipping(net, theta = 1)

In [247]:
for X, y in data_iter:
    X= X.to(dtype = torch.float32)
    print(X[0], y[0])
    break

tensor([15.,  8.,  0.,  9.,  6.,  0., 21., 16., 16., 12.,  0., 16., 15.,  6.,
         0., 16.,  7.,  0., 21.,  9.,  6.,  0., 20., 14.,  2., 13., 13.,  0.,
        16.,  4., 21.,  2.,  8., 16., 15.]) tensor([ 8,  0,  9,  6,  0, 21, 16, 16, 12,  0, 16, 15,  6,  0, 16,  7,  0, 21,
         9,  6,  0, 20, 14,  2, 13, 13,  0, 16,  4, 21,  2,  8, 16, 15,  2])


In [248]:
def predict(net : RNN, str, num_pred):
    outputs = [str[0]]
    state = None
    for i in str[1:]:
        batch = torch.tensor(
            data = vocab[i], device = device
        ).reshape(1, 1)
        _, state = net(batch, state = state)
        outputs.append(i)

    for i in range(num_pred):
        batch = torch.tensor(vocab.token_to_idx[outputs[-1]], device = device).reshape(1, 1)
        output, state = net(batch, state = state)
        output = output.cpu().detach().numpy()
        outputs.append(vocab.idx_to_token[output.argmax(axis = -1)[0]])
    return ''.join(outputs)

predict(net, "astro", 10)


'astropppppppppp'

In [None]:
def train_epoch(net : RNN, trainer, loss_fn, train_iter):
    metric = d2l.Accumulator(2)
    state = None
    for X, y in train_iter:
        net.train()
        output, state = net(X, state)
        output = output.reshape(y.shape[0], num_steps, -1)
        y = F.one_hot(y, num_classes=len(vocab)).float().to(output.device)


        loss = loss_fn(output, y)

        trainer.zero_grad()
        loss.backward()

        d2l.grad_clipping(net, 1)
        trainer.step()
        metric.add(np.sum(loss.cpu().detach().numpy()), y.shape[0])
    # Return perplexity per epoch
    return np.exp(metric[0] / metric[1])

def train(net : RNN, trainer, loss_fn, train_iter, num_epoch = 100):
    for epoch in range(num_epoch):
        perplexity = train_epoch(net, trainer, loss_fn, train_iter)
        # if epoch % 10 == 0:
        print(f"Epoch: {epoch} | Perplexity: {perplexity}")
    
train(net, trainer, loss, data_iter)

Epoch: 0 | Perplexity: 1.149525016346556
Epoch: 1 | Perplexity: 1.149511491822107
Epoch: 2 | Perplexity: 1.149503598944623
Epoch: 3 | Perplexity: 1.149493141021616
Epoch: 4 | Perplexity: 1.1494854056346155
Epoch: 5 | Perplexity: 1.149476144132992
Epoch: 6 | Perplexity: 1.1494668566708344
Epoch: 7 | Perplexity: 1.149460945938677
Epoch: 8 | Perplexity: 1.1494512196799098
Epoch: 9 | Perplexity: 1.1494451452016097
Epoch: 10 | Perplexity: 1.149441038090221
Epoch: 11 | Perplexity: 1.1494390286616958
Epoch: 12 | Perplexity: 1.1494298694786331
Epoch: 13 | Perplexity: 1.1494253908176029
Epoch: 14 | Perplexity: 1.1494164260391373
Epoch: 15 | Perplexity: 1.1494120486612984
Epoch: 16 | Perplexity: 1.1494102166974203
Epoch: 17 | Perplexity: 1.14940235581056
Epoch: 18 | Perplexity: 1.1493993647897303
Epoch: 19 | Perplexity: 1.149396838270256
Epoch: 20 | Perplexity: 1.1493922279263835
Epoch: 21 | Perplexity: 1.1493893828833581
Epoch: 22 | Perplexity: 1.1493798895009157
Epoch: 23 | Perplexity: 1.14938

KeyboardInterrupt: 

In [None]:
for p in net.parameters():
    if p.grad is None:
        print("Gradient is None for:", p.shape)


Gradient is None for: torch.Size([28, 256])
Gradient is None for: torch.Size([256, 256])
Gradient is None for: torch.Size([1, 256])
Gradient is None for: torch.Size([256, 28])
Gradient is None for: torch.Size([1, 28])


In [None]:
predict(net, "astro", 10)

d2l.tra

'astropppppppppp'