In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

In [8]:
sentence = "Repeat is the best medicine for memory".split()
sentence.sort()
vocab = list(set(sentence))
vocab

['Repeat', 'medicine', 'memory', 'best', 'the', 'for', 'is']

In [4]:
word2index = {word: i + 1 for i, word in enumerate(vocab)}
word2index['<unk>'] = 0
word2index

{'Repeat': 1,
 'medicine': 2,
 'memory': 3,
 'best': 4,
 'the': 5,
 'for': 6,
 'is': 7,
 '<unk>': 0}

In [5]:
index_to_word = {value : key for key, value in word2index.items()}
index_to_word

{1: 'Repeat',
 2: 'medicine',
 3: 'memory',
 4: 'best',
 5: 'the',
 6: 'for',
 7: 'is',
 0: '<unk>'}

In [20]:
def build_data(sentence, word2index):
    encoded = [word2index[token] for token in sentence]
    input_seq, label_seq = encoded[:-1], encoded[1:]
    input_seq = torch.LongTensor(input_seq).unsqueeze(0)
    label_seq = torch.LongTensor(label_seq).unsqueeze(0)
    return input_seq, label_seq


In [21]:
build_data(sentence, word2index)

(tensor([[1, 4, 6, 7, 2, 3]]), tensor([[4, 6, 7, 2, 3, 5]]))

In [22]:
X, Y = build_data(sentence, word2index)
print(X)
print(Y)

tensor([[1, 4, 6, 7, 2, 3]])
tensor([[4, 6, 7, 2, 3, 5]])


In [23]:
class Net(nn.Module):
    def __init__(self, vocab_size, input_size, hidden_size, batch_first=True):
        super().__init__()
        self.embedding_layer = nn.Embedding(num_embeddings=vocab_size, embedding_dim=input_size)
        self.rnn_layer = nn.RNN(input_size, hidden_size, batch_first=batch_first)
        self.linear = nn.Linear(hidden_size, vocab_size)
    
    def forward(self,x):
        output= self.embedding_layer(x)
        output, hidden = self.rnn_layer(output)
        output = self.linear(output)
        return output.view(-1, output.size(2))