In [14]:
import torch
import torch.nn as nn
import torch.optim as optim

import random

In [15]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [16]:
with open('shakespeare.txt', 'r') as f:
  text = f.read().lower()[:5000000]

In [17]:
chars = sorted(set(text))
vocab_size = len(chars)

In [18]:
char2idx = {c: i for i, c in enumerate(chars)}
idx2char= {i: c for c, i in char2idx.items()}

In [19]:
seq_length = 100
step_size = 1
data = [(text[i:i+seq_length], text[i+seq_length]) for i in range(0, len(text)-seq_length, step_size)]

In [20]:
X = torch.tensor([[char2idx[c] for c in seq] for seq, _ in data]).to(device)
y = torch.tensor([char2idx[c] for _, c in data]).to(device)

In [21]:
class CharLSTM(nn.Module):
  def __init__(self, vocab_size, hidden_size, num_layers=1):
    super().__init__()
    self.embed = nn.Embedding(vocab_size, hidden_size)
    self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first=True)
    self.fc = nn.Linear(hidden_size, vocab_size)

  def forward(self, x, hidden=None):
    x = self.embed(x)
    out, hidden = self.lstm(x, hidden)
    out = self.fc(out[:, -1, :]) # Shape: [batch_size, seq_length, hidden_size]

    return out, hidden

In [22]:
model = CharLSTM(vocab_size, hidden_size=256).to(device)
optimizer=optim.Adam(model.parameters(), lr=0.0003)
criterion = nn.CrossEntropyLoss()

In [23]:
for epoch in range(100):
  model.train()
  running_loss = 0.0

  for i in range(0, len(X), 64):
    x_batch = X[i:i+64].to(device)
    y_batch = y[i:i+64].to(device)

    if len(x_batch) == 0:
      continue

    optimizer.zero_grad()
    output, _ = model(x_batch)
    loss = criterion(output, y_batch)
    loss.backward()
    optimizer.step()

    running_loss += loss.item()

  print(f'Epoch {epoch+1}/100, Loss: {running_loss:.4f}')

Epoch 1/100, Loss: 3107.1121
Epoch 2/100, Loss: 2634.8976
Epoch 3/100, Loss: 2478.1724
Epoch 4/100, Loss: 2376.0811
Epoch 5/100, Loss: 2292.0871
Epoch 6/100, Loss: 2220.3729
Epoch 7/100, Loss: 2158.2235
Epoch 8/100, Loss: 2105.1079
Epoch 9/100, Loss: 2057.8063
Epoch 10/100, Loss: 2014.0114
Epoch 11/100, Loss: 1972.9485
Epoch 12/100, Loss: 1934.6390
Epoch 13/100, Loss: 1897.9055
Epoch 14/100, Loss: 1863.4111
Epoch 15/100, Loss: 1830.4877
Epoch 16/100, Loss: 1797.7986
Epoch 17/100, Loss: 1762.9170
Epoch 18/100, Loss: 1728.2547
Epoch 19/100, Loss: 1695.6287
Epoch 20/100, Loss: 1661.9173
Epoch 21/100, Loss: 1627.4396
Epoch 22/100, Loss: 1594.2588
Epoch 23/100, Loss: 1560.5650
Epoch 24/100, Loss: 1527.5967
Epoch 25/100, Loss: 1494.4548
Epoch 26/100, Loss: 1460.3943
Epoch 27/100, Loss: 1426.4731
Epoch 28/100, Loss: 1392.8461
Epoch 29/100, Loss: 1359.8488
Epoch 30/100, Loss: 1327.0130
Epoch 31/100, Loss: 1293.1141
Epoch 32/100, Loss: 1262.3626
Epoch 33/100, Loss: 1230.3378
Epoch 34/100, Loss:

In [24]:
def generate_text(model, start_seq, length=200):
  model.eval()

  input_seq = torch.tensor([[char2idx[c] for c in start_seq]]).to(device)
  hidden = None

  result = start_seq

  for _ in range(length):
    output, hidden = model(input_seq, hidden)
    probs = torch.softmax(output, dim=-1).squeeze()
    next_idx = torch.multinomial(probs, 1).item()
    next_char = idx2char[next_idx]

    result += next_char
    input_seq = torch.tensor([[next_idx]]).to(device)

  return result

In [25]:
print(generate_text(model, 'he was going with'))
print('-'*100)
print(generate_text(model, 'why is it'))
print('-'*100)
print(generate_text(model, 'we must all'))

he was going with thought,
i enjuning thy sweet love love beed
at realous that grieves antique pen with friends but not
his purge as the state the world enmort:  
or if they her to thing i do fier's eye,
that by old e
----------------------------------------------------------------------------------------------------
why is it?
which home like her feel friend hand deceed?
in eternixered chested,
thy beauty my self to pain,
whilst my noble poor name, nor despise now to be remove.
were it but is my lovers worth herst bearing
----------------------------------------------------------------------------------------------------
we must all bind.



loved in thy beauty's fool i do change,
thy blind smald travail that hadd still we ad?
he rany headd what thy sweet winter's day
and she is not for my side,
or frail i chide the wide was ple


In [26]:
torch.save(model.state_dict(), 'char_lstm_model.pth')