In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from datasets import Dataset
import requests

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device.type)

cpu


In [3]:
url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
resp = requests.get(url)
text = resp.text
print("Length of text:", len(text))
print("\n\n")
print(text[:500])

Length of text: 1115394



First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor


In [4]:
ds = Dataset.from_dict({"text" : [text]})

In [5]:
chars = sorted(set(text))
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)} # String to integer
print(stoi)

print("\n")

itos = {i: ch for ch, i in stoi.items()} # intereg to string
print(itos)

def encode(s):
    return [stoi[c] for c in s]

data = torch.tensor(encode(text), dtype=torch.long)

print("\n")

print(data)

{'\n': 0, ' ': 1, '!': 2, '$': 3, '&': 4, "'": 5, ',': 6, '-': 7, '.': 8, '3': 9, ':': 10, ';': 11, '?': 12, 'A': 13, 'B': 14, 'C': 15, 'D': 16, 'E': 17, 'F': 18, 'G': 19, 'H': 20, 'I': 21, 'J': 22, 'K': 23, 'L': 24, 'M': 25, 'N': 26, 'O': 27, 'P': 28, 'Q': 29, 'R': 30, 'S': 31, 'T': 32, 'U': 33, 'V': 34, 'W': 35, 'X': 36, 'Y': 37, 'Z': 38, 'a': 39, 'b': 40, 'c': 41, 'd': 42, 'e': 43, 'f': 44, 'g': 45, 'h': 46, 'i': 47, 'j': 48, 'k': 49, 'l': 50, 'm': 51, 'n': 52, 'o': 53, 'p': 54, 'q': 55, 'r': 56, 's': 57, 't': 58, 'u': 59, 'v': 60, 'w': 61, 'x': 62, 'y': 63, 'z': 64}


{0: '\n', 1: ' ', 2: '!', 3: '$', 4: '&', 5: "'", 6: ',', 7: '-', 8: '.', 9: '3', 10: ':', 11: ';', 12: '?', 13: 'A', 14: 'B', 15: 'C', 16: 'D', 17: 'E', 18: 'F', 19: 'G', 20: 'H', 21: 'I', 22: 'J', 23: 'K', 24: 'L', 25: 'M', 26: 'N', 27: 'O', 28: 'P', 29: 'Q', 30: 'R', 31: 'S', 32: 'T', 33: 'U', 34: 'V', 35: 'W', 36: 'X', 37: 'Y', 38: 'Z', 39: 'a', 40: 'b', 41: 'c', 42: 'd', 43: 'e', 44: 'f', 45: 'g', 46: 'h', 47: 'i

In [6]:
seq_len = 60

def get_batches(data, seq_len, batch_size=128):
    n = len(data) - seq_len
    ix = torch.randint(0, n, (batch_size,))
    x = torch.stack([data[i : i+seq_len] for i in ix])
    y = torch.stack([data[i+1 : i+seq_len+1] for i in ix])
    return x, y

class CharRNN(nn.Module):
    def __init__(self, vocab_size, hidden_size,num_layers=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.RNN(hidden_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, h=None):
        x = self.embed(x)
        out, h = self.rnn(x, h)
        out = self.fc(out)
        return out, h

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CharRNN(vocab_size, hidden_size=128,num_layers=6).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [8]:
epochs = 400
for epoch in range(epochs):
    x_batch, y_batch = get_batches(data, seq_len)
    x_batch = x_batch.to(device)
    y_batch = y_batch.to(device)

    optimizer.zero_grad()
    out, _ = model(x_batch)
    loss = criterion(out.view(-1, vocab_size), y_batch.view(-1))
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}/{epochs}, loss = {loss.item():.4f}")

Epoch 10/400, loss = 3.5268
Epoch 20/400, loss = 3.0135
Epoch 30/400, loss = 2.7692
Epoch 40/400, loss = 2.6672
Epoch 50/400, loss = 2.5511
Epoch 60/400, loss = 2.4754
Epoch 70/400, loss = 2.4244
Epoch 80/400, loss = 2.4074
Epoch 90/400, loss = 2.3446
Epoch 100/400, loss = 2.2842
Epoch 110/400, loss = 2.2578
Epoch 120/400, loss = 2.2146
Epoch 130/400, loss = 2.2095
Epoch 140/400, loss = 2.2090
Epoch 150/400, loss = 2.2026
Epoch 160/400, loss = 2.1430
Epoch 170/400, loss = 2.1350
Epoch 180/400, loss = 2.0967
Epoch 190/400, loss = 2.0865
Epoch 200/400, loss = 2.1199
Epoch 210/400, loss = 2.0821
Epoch 220/400, loss = 2.0780
Epoch 230/400, loss = 2.0683
Epoch 240/400, loss = 2.0454
Epoch 250/400, loss = 2.0268
Epoch 260/400, loss = 2.0070
Epoch 270/400, loss = 2.0196
Epoch 280/400, loss = 1.9949
Epoch 290/400, loss = 1.9791
Epoch 300/400, loss = 1.9985
Epoch 310/400, loss = 1.9568
Epoch 320/400, loss = 1.9941
Epoch 330/400, loss = 1.9631
Epoch 340/400, loss = 1.9365
Epoch 350/400, loss = 1

In [10]:
def generate_text(model, start_text="ROMEO", length=100):
    model.eval()
    input_ids = torch.tensor([stoi[c] for c in start_text], dtype=torch.long).unsqueeze(0).to(device)
    hidden = None
    result = list(start_text)
    for _ in range(length):
        out, hidden = model(input_ids, hidden)
        logits = out[0, -1]
        next_id = torch.argmax(logits).item()
        result.append(itos[next_id])
        input_ids = torch.tensor([[next_id]], dtype=torch.long).to(device)
    return ''.join(result)

print("Sample:", generate_text(model, "First Citizen", 1000))

Sample: First Citizen the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the with the wi