In [26]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [27]:
class CustomRNNCell(nn.Module):
    """Single-step Elman RNN cell built directly from raw parameters.

    Each call advances the recurrence by exactly one time step::

        h_t = tanh(Wxh @ x_t + Whh @ h_{t-1} + bh)
        y_t = Why @ h_t + by

    Vectors are columns: ``x`` has ``input_size`` rows and ``h`` has
    ``hidden_size`` rows. The ``(.., 1)``-shaped biases broadcast across
    however many columns are passed. ``y`` holds unnormalised logits with
    ``output_size`` rows.

    Args:
        input_size: number of rows of the input column ``x``.
        hidden_size: number of rows of the hidden state ``h``.
        output_size: number of rows of the output logits ``y``.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        init_scale = 0.01  # scale applied to the standard-normal initial weights
        # Creation order is kept as Wxh, Whh, Why so the global torch RNG stream
        # is consumed the same way (matters once a seed is set).
        self.Wxh = nn.Parameter(init_scale * torch.randn(hidden_size, input_size))
        self.Whh = nn.Parameter(init_scale * torch.randn(hidden_size, hidden_size))
        self.Why = nn.Parameter(init_scale * torch.randn(output_size, hidden_size))
        # Biases start at zero.
        self.bh = nn.Parameter(torch.zeros(hidden_size, 1))
        self.by = nn.Parameter(torch.zeros(output_size, 1))

    def forward(self, x, h):
        """Advance one step; returns ``(logits, new_hidden_state)``."""
        pre_activation = self.Wxh @ x + self.Whh @ h + self.bh
        h_next = torch.tanh(pre_activation)
        logits = self.Why @ h_next + self.by
        return logits, h_next

In [28]:
import numpy as np
from nltk.corpus import gutenberg

# Load and preprocess data
# NOTE: needs the NLTK 'gutenberg' corpus data available locally (e.g. nltk.download('gutenberg')).
corpus = gutenberg.raw('shakespeare-hamlet.txt').lower()  # Use a different text if desired

# Build the character vocabulary in a stable order. Iterating a bare set of str
# depends on per-process hash randomization, so `list(set(...))` would assign
# different indices after every kernel restart (irreproducible runs, unusable
# saved weights). Sorting makes the mapping deterministic.
chars = sorted(set(corpus))
char_to_index = {char: i for i, char in enumerate(chars)}
index_to_char = dict(enumerate(chars))
num_chars = len(chars)

# Convert the text to a sequence of indices
corpus_indices = [char_to_index[char] for char in corpus]
assert len(corpus_indices) == len(corpus)

In [35]:
# Train only on a prefix of the corpus: the first `max_train_chars` character indices.
max_train_chars = 2000
corpus_indices = corpus_indices[:max_train_chars]

In [36]:
# Model / training hyperparameters.
input_size = num_chars    # one-hot input: one slot per vocabulary character
hidden_size = 128
output_size = num_chars   # logits over the same vocabulary
seq_length = 20           # characters per training window
num_epochs = 100
learning_rate = 0.001

# Seed every RNG this notebook draws from (torch: weight initialisation,
# numpy: character sampling during generation) so that Restart & Run All
# reproduces the same model and the same generated text.
random_seed = 42
torch.manual_seed(random_seed)
np.random.seed(random_seed)

In [37]:
# Instantiate the single-step RNN cell with the sizes from the hyperparameter cell.
rnn = CustomRNNCell(
    input_size=input_size,
    hidden_size=hidden_size,
    output_size=output_size,
)

In [38]:
# Next-character classification loss over the vocabulary logits.
criterion = nn.CrossEntropyLoss()
# Adam over every parameter of the RNN cell.
optimizer = optim.Adam(rnn.parameters(), lr=learning_rate)


In [42]:
# Train with truncated backpropagation through time. A window of `seq_length`
# characters slides over the corpus one position at a time, and the model
# predicts each next character (targets are the inputs shifted by one). The
# hidden state is reset to zeros at the start of every window, so no autograd
# graph spans two windows.
corpus_tensor = torch.tensor(corpus_indices, dtype=torch.long)
# One-hot encode the whole corpus once, instead of rebuilding a vector every step.
corpus_one_hot = F.one_hot(corpus_tensor, num_classes=input_size).float()  # (len(corpus_indices), input_size)
num_windows = max(len(corpus_indices) - seq_length, 0)

for epoch in range(num_epochs):
    loss_sum = 0.0
    for i in range(num_windows):
        inputs = corpus_one_hot[i:i + seq_length]          # (seq_length, input_size)
        targets = corpus_tensor[i + 1:i + seq_length + 1]  # next-character indices, (seq_length,)

        optimizer.zero_grad()
        h = torch.zeros(hidden_size, 1)  # fresh hidden state for every window
        loss = 0
        for j in range(seq_length):
            # The cell works on column vectors, so feed the j-th one-hot row as (input_size, 1).
            y, h = rnn(inputs[j].unsqueeze(1), h)
            # Per-window loss is the SUM of per-step cross-entropies (same objective as before).
            loss = loss + criterion(y.view(1, -1), targets[j].view(1))

        # The graph is rebuilt for every window, so retain_graph is unnecessary;
        # passing it only keeps the previous window's graph buffers alive longer.
        loss.backward()
        optimizer.step()
        loss_sum += loss.item()

    # loss_sum scales with corpus length; the per-character mean is comparable across runs.
    mean_loss_per_char = loss_sum / max(num_windows * seq_length, 1)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss_sum}, Mean loss/char: {mean_loss_per_char:.4f}')


Epoch [1/100], Loss: 511.46358420792967
Epoch [2/100], Loss: 510.81390369683504
Epoch [3/100], Loss: 510.1671797679737
Epoch [4/100], Loss: 509.524298442062
Epoch [5/100], Loss: 508.88551524700597
Epoch [6/100], Loss: 508.25104096066207
Epoch [7/100], Loss: 507.62266592122614
Epoch [8/100], Loss: 506.99793858639896
Epoch [9/100], Loss: 506.38321847561747
Epoch [10/100], Loss: 505.7706372104585
Epoch [11/100], Loss: 505.17413205094635
Epoch [12/100], Loss: 504.58236853266135
Epoch [13/100], Loss: 503.9970304099843
Epoch [14/100], Loss: 503.4666343978606
Epoch [15/100], Loss: 502.84545140573755
Epoch [16/100], Loss: 502.541098265443
Epoch [17/100], Loss: 501.77700535720214
Epoch [18/100], Loss: 501.6934882667847
Epoch [19/100], Loss: 501.10756563348696
Epoch [20/100], Loss: 500.661105458159
Epoch [21/100], Loss: 500.5196637287736
Epoch [22/100], Loss: 499.887049079407
Epoch [23/100], Loss: 499.19629745790735
Epoch [24/100], Loss: 498.09428565436974
Epoch [25/100], Loss: 497.6682240455411

In [46]:
with torch.no_grad():
    # Autoregressive sampling: start from one character, feed each sampled
    # character back in as the next input, and carry the hidden state forward.
    start_char = "t"  # Starting character for text generation
    current_index = char_to_index[start_char]
    h = torch.zeros(hidden_size, 1)

    generated_text = start_char

    for _ in range(4):
        one_hot = torch.zeros(input_size, 1)
        one_hot[current_index] = 1
        logits, h = rnn(one_hot, h)
        # Logits are a column, so the vocabulary axis is dim 0.
        probs = torch.softmax(logits, dim=0).squeeze().numpy()
        current_index = int(np.random.choice(num_chars, p=probs))
        generated_text += index_to_char[current_index]

    print(generated_text)

the t
