In [36]:
import torch
import torch.nn as nn


In [37]:
sentence = ["knowledge", "opens", "many", "doors"]

# Vocabulary over the *unique* words in first-occurrence order, so indices are
# contiguous (0..n_unique-1) even if a word repeats in the sentence.
word_to_idx = {word: idx for idx, word in enumerate(dict.fromkeys(sentence))}
idx_to_word = {idx: word for word, idx in word_to_idx.items()}

# Next-word task: every word but the last is context, the last one is the label.
input_words = sentence[:-1]
target_word = sentence[-1]

inputs = torch.tensor([word_to_idx[word] for word in input_words])  # shape (seq_len,)
target = torch.tensor(word_to_idx[target_word])  # 0-d tensor holding the class index


In [38]:
class MySimpleRNN(nn.Module):
    """Minimal vanilla (Elman) RNN that maps a token sequence to next-token logits.

    The recurrence is written out by hand rather than using ``nn.RNN``:

        h_0 = 0
        h_t = tanh(e_t @ Wxh + h_{t-1} @ Whh + bh)

    where ``e_t`` is the embedding of token ``t``. The final hidden state is
    projected to vocabulary logits by ``fc``.

    Args:
        vocab_size: number of token ids; sizes both the embedding table and
            the output layer.
        embed_size: dimensionality of the token embeddings.
        hidden_size: dimensionality of the recurrent hidden state.
    """

    def __init__(self, vocab_size, embed_size, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # Recurrent weights start at 0.01 * N(0, 1) so initial pre-activations
        # are small (presumably to keep tanh away from saturation).
        self.Wxh = nn.Parameter(torch.randn(embed_size, hidden_size) * 0.01)
        self.Whh = nn.Parameter(torch.randn(hidden_size, hidden_size) * 0.01)
        self.bh = nn.Parameter(torch.zeros(hidden_size))
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x):
        """Run the recurrence over ``x`` and return logits for the next token.

        Args:
            x: integer tensor of token ids, either one sequence of shape
                ``(seq_len,)`` or a batch of shape ``(batch, seq_len)``.

        Returns:
            Logits of shape ``(vocab_size,)`` for a single sequence, or
            ``(batch, vocab_size)`` for a batch.
        """
        embeds = self.embedding(x)  # (..., seq_len, embed_size)
        # h_0 takes the embeddings' device and dtype so the model keeps working
        # after .to(device) / .double(); its shape is (..., hidden_size).
        h = embeds.new_zeros((*embeds.shape[:-2], self.Whh.size(0)))
        # Time is always the second-to-last axis of the embeddings.
        for t in range(embeds.size(-2)):
            h = torch.tanh(embeds[..., t, :] @ self.Wxh + h @ self.Whh + self.bh)
        return self.fc(h)

In [39]:
SEED = 0
# len(sentence) covers every index word_to_idx can assign.
vocab_size = len(sentence)
embed_size = 8
hidden_size = 16

# The embedding and recurrent weights are randomly initialised; seed the global
# torch RNG so Restart & Run All starts from the same initial parameters.
torch.manual_seed(SEED)
model = MySimpleRNN(vocab_size, embed_size, hidden_size)


In [41]:
# Adam over every trainable parameter of the model; the loss takes raw logits
# (nn.CrossEntropyLoss applies log-softmax internally).
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()


In [42]:
n_epochs = 1000
for epoch in range(n_epochs):
    # One optimisation step: clear grads, forward, loss, backprop, update.
    optimizer.zero_grad()
    output = model(inputs)
    # CrossEntropyLoss wants a batch axis: logits (1, vocab), target (1,).
    loss = criterion(output[None], target[None])
    loss.backward()
    optimizer.step()

    # Report only every 200th epoch (1-based count).
    if (epoch + 1) % 200:
        continue
    pred_idx = output.argmax().item()
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}, Prediction: {idx_to_word[pred_idx]}")


Epoch 200, Loss: 0.0004, Prediction: doors
Epoch 400, Loss: 0.0002, Prediction: doors
Epoch 600, Loss: 0.0001, Prediction: doors
Epoch 800, Loss: 0.0001, Prediction: doors
Epoch 1000, Loss: 0.0000, Prediction: doors


In [43]:
# Inference only: skip building an autograd graph for this forward pass.
# (MySimpleRNN has no dropout/batch-norm layers, so train/eval mode is irrelevant.)
with torch.no_grad():
    output = model(inputs)
pred_idx = torch.argmax(output).item()
print(f"\nFinal Prediction: {idx_to_word[pred_idx]}")



Final Prediction: doors
