# Language Model RNN

Here's an RNN for a Language model. Here's how it looks like:

![](https://i.imgur.com/dowTmun.png)

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [118]:
class LRNN(nn.Module):
    def __init__(self, n_tokens, n_hidden, n_out):
        super().__init__()
        self.n_tokens = n_tokens
        self.n_hidden = n_hidden
        self.n_out = n_out
        
        self.f_hidden = nn.Linear(n_tokens + n_hidden, n_hidden)
        self.f_out = nn.Linear(n_hidden, n_out)
        
    def forward(self, token_oh, hidden):
        combined = torch.cat([token_oh, hidden], dim=1)
        hidden = torch.tanh(self.f_hidden(combined))
        output = torch.sigmoid(self.f_out(hidden))
        output_probs = F.softmax(output, dim=1)
        return output, hidden        
        
    def train(self, time_batch):
        bs, timesteps, n_tokens = time_batch.shape
        hidden = torch.zeros(bs, self.n_hidden)
        output_list = []
        
        for t in range(timesteps):
            # here batch contains a list of one hot encoded tokens for a different set of inputs
            batch = time_batch[:, t]
            output, hidden = self.forward(batch, hidden)
            output_list.append(output)
            
        return torch.stack(output_list, dim=1)

## Training

In [84]:
bs = 2
n_tokens = 10
timesteps = 3

input = torch.LongTensor([
    [2, 9, 8],
    [8, 1, 3]
])
input_oh = F.one_hot(input, n_tokens)
input_oh.shape, input_oh

(torch.Size([2, 3, 10]),
 tensor([[[0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
          [0, 0, 0, 0, 0, 0, 0, 0, 1, 0]],
 
         [[0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
          [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
          [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]]]))

In [120]:
model = LRNN(n_tokens=n_tokens, n_hidden=100, n_out=n_tokens)
output_oh = model.train(input_oh)
output_oh.shape, output_oh

(torch.Size([2, 3, 10]),
 tensor([[[0.5198, 0.4802, 0.5218, 0.5099, 0.4718, 0.4808, 0.4879, 0.4925,
           0.5169, 0.4976],
          [0.5246, 0.4889, 0.5243, 0.5291, 0.4700, 0.4948, 0.4815, 0.4900,
           0.5004, 0.4934],
          [0.5024, 0.4978, 0.5222, 0.5280, 0.4828, 0.4865, 0.4862, 0.5014,
           0.5203, 0.5022]],
 
         [[0.5128, 0.4846, 0.5103, 0.5295, 0.4741, 0.4971, 0.5039, 0.5007,
           0.5196, 0.5013],
          [0.5157, 0.4956, 0.5093, 0.5242, 0.4840, 0.4953, 0.4880, 0.4881,
           0.5128, 0.5070],
          [0.5033, 0.4789, 0.5332, 0.5337, 0.4735, 0.4959, 0.4882, 0.4983,
           0.5208, 0.4955]]], grad_fn=<StackBackward0>))

In [121]:
output = torch.argmax(output_oh, dim=2)
output.shape

torch.Size([2, 3])

## Predicting / Generating

In [124]:
F.one_hot(torch.tensor(1))

tensor([0, 1])

In [130]:
def generate(start_word, timesteps=10):
    hidden = torch.zeros(1, model.n_hidden)
    output = F.one_hot(torch.tensor(start_word), model.n_tokens).reshape(1, -1)
    output_words = [start_word]
    
    for t in range(timesteps):
        output, hidden = model(output, hidden)
        output_words.append(torch.argmax(output, dim=1).item())
        
    return torch.tensor(output_words)

generate(2)

tensor([2, 2, 8, 2, 2, 2, 2, 2, 2, 2, 2])