In [22]:
import torch
import torch.nn as nn
import torch.optim as optim

import random

In [23]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cpu


In [24]:
with open('shakespeare.txt', 'r') as f:
  text = f.read().lower()[:500000]

In [25]:
words = text.split()
vocab = sorted(set(words))

In [26]:
word2idx = {w: i for i, w in enumerate(vocab)}
idx2word = {i: w for w, i in word2idx.items()}

In [27]:
vocab_size = len(vocab)

In [28]:
seq_length = 5

In [29]:
data = [(words[i:i+seq_length], words[i+seq_length]) for i in range(0, len(words) - seq_length)]

In [30]:
X = torch.tensor([[word2idx[w] for w in seq] for seq, _ in data]).to(device)
y = torch.tensor([word2idx[w] for _, w in data]).to(device)

In [31]:
class CharRNNAttention(nn.Module):
  def __init__(self, vocab_size, embedding_dim, hidden_dim):
    super().__init__()
    self.embed = nn.Embedding(vocab_size, embedding_dim)
    self.rnn = nn.RNN(embedding_dim, hidden_dim, batch_first=True)
    self.attention = nn.Linear(hidden_dim, 1)
    self.fc = nn.Linear(hidden_dim, vocab_size)

  def forward(self, x):
    x = self.embed(x)
    out, _ = self.rnn(x)
    attention_weights = nn.functional.softmax(self.attention(out).squeeze(2), dim=1)
    context = torch.sum(attention_weights.unsqueeze(2) * out, dim=1)
    out = self.fc(context)

    return out

In [32]:
model = CharRNNAttention(vocab_size, 128, 256).to(device)

In [33]:
optimizer = optim.Adam(model.parameters(), lr=0.003)
criterion = nn.CrossEntropyLoss()

In [34]:
for epoch in range(50):
  model.train()
  running_loss = 0.0

  for i in range(0, len(X), 64):
    x_batch = X[i:i+64].to(device)
    y_batch = y[i:i+64].to(device)

    if len(x_batch) == 0:
      continue

    optimizer.zero_grad()
    outputs = model(x_batch)
    loss = criterion(outputs, y_batch)
    loss.backward()
    optimizer.step()

    running_loss += loss.item()

  print(f'Epoch {epoch+1}/50, Loss: {running_loss:.4f}')

Epoch 1/50, Loss: 1999.9735
Epoch 2/50, Loss: 1740.4236
Epoch 3/50, Loss: 1538.0886
Epoch 4/50, Loss: 1343.6157
Epoch 5/50, Loss: 1209.0893
Epoch 6/50, Loss: 1083.5018
Epoch 7/50, Loss: 954.2811
Epoch 8/50, Loss: 826.2561
Epoch 9/50, Loss: 718.2332
Epoch 10/50, Loss: 625.6101
Epoch 11/50, Loss: 539.2823
Epoch 12/50, Loss: 464.2919
Epoch 13/50, Loss: 401.1814
Epoch 14/50, Loss: 353.0522
Epoch 15/50, Loss: 306.7517
Epoch 16/50, Loss: 269.8393
Epoch 17/50, Loss: 239.5276
Epoch 18/50, Loss: 215.1926
Epoch 19/50, Loss: 188.6489
Epoch 20/50, Loss: 167.8291
Epoch 21/50, Loss: 154.0112
Epoch 22/50, Loss: 137.5959
Epoch 23/50, Loss: 119.7708
Epoch 24/50, Loss: 107.9789
Epoch 25/50, Loss: 98.1043
Epoch 26/50, Loss: 98.1628
Epoch 27/50, Loss: 87.8847
Epoch 28/50, Loss: 86.1968
Epoch 29/50, Loss: 82.9837
Epoch 30/50, Loss: 82.7148
Epoch 31/50, Loss: 74.6544
Epoch 32/50, Loss: 65.2510
Epoch 33/50, Loss: 62.4168
Epoch 34/50, Loss: 66.1428
Epoch 35/50, Loss: 69.9454
Epoch 36/50, Loss: 69.7226
Epoch 3

In [35]:
def generate_text(model, start_words, num_words=20):
    model.eval()
    generated = start_words[:]

    for _ in range(num_words):
        current_seq = generated[-seq_length:] if len(generated) >= seq_length else generated

        if len(current_seq) < seq_length:
            current_seq = [""] * (seq_length - len(current_seq)) + current_seq

        idx_seq = [word2idx[w] if w in word2idx else 0 for w in current_seq]
        input_seq = torch.tensor([idx_seq], dtype=torch.long).to(device)

        with torch.no_grad():
            logits = model(input_seq)
            probs = torch.softmax(logits, dim=-1).squeeze(0)
            next_idx = torch.multinomial(probs, 1).item()

        next_word = idx2word[next_idx]
        generated.append(next_word)

    return ' '.join(generated)

In [36]:
print(generate_text(model, ['he', 'was', 'going', 'with'], num_words=30))
print(generate_text(model, ['why', 'is', 'it'], num_words=30))
print(generate_text(model, ['we', 'must', 'all'], num_words=30))

he was going with bitterness the in bitterness own think, that your too to correct lie, bitter when thy dear title to make to worst to only the what the ten spring, which of
why is it bitterness sin keeps though gave think, or or the of correct if i none heat dear the will self and of arise, for think that i and my heard of
we must all plea present present such i him show then beauty herself all i is all they show purpose will, thee every thee glass so though fair you hast in as my


In [37]:
torch.save(model.state_dict(), 'attention_model.pth')