In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import requests

# Download and preprocess the shakespear text
URL = "https://gist.githubusercontent.com/CarineRam/817c25781a9ca8dc3370a190e31ab5e5/raw/ff02e4b6b3715295143846aaa896cc89f408cd67/gistfile1.txt"
response = requests.get(URL)
text = response.text

# create character to integer mapping and vice-versa
chars = sorted(set(text))
vocab_size = len(chars)
char_to_int = {ch: i for i, ch in enumerate (chars)} #abc => {"a":0, "B":1, "c":2}
int_to_char = {i: ch for i, ch in enumerate (chars)} #{0:"a", 1:"b", 2:"c"}

# Encode and Decode Functions
def encode(s):
    return [char_to_int[c] for c in s if c in char_to_int]
def decode(indices):
    return "".join([int_to_char[i] for i in indices])

# Create batches of data
def get_batch(split, batch_size=32, block_size=128):
    n = len(text)
    split_idx = int(0.5 * n)
    if split == "train":
        data = text[:split_idx]
    else:
        data = text[split_idx:]
    ix = torch.randint(len(data) - block_size - 1, (batch_size,))
    x_batch = torch.stack([torch.tensor(encode(data[i:i+block_size]), dtype=torch.long) for i in ix])
    y_batch = torch.stack([torch.tensor(encode(data[i+1:i+block_size+1]), dtype=torch.long) for i in ix])

    return x_batch, y_batch

# Define the model architecture
n_emb = 32
block_size = 128
head_size = 8
num_heads = 4
num_blocks = 2

# Single Attention Head
class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_emb, head_size)
        self.querry = nn.Linear(n_emb, head_size)
        self.value = nn.Linear(n_emb, head_size)
        self.dropout = nn.Dropout(0.1)
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        q = self.querry(x)
        k = self.key(x)
        v = self.value(x)
        weights = (q @ k.transpose(-2, -1))/(C ** 0.5)
        weights = weights.masked_fill(self.tril[:T, :T] == 0, float("-1e9"))
        weights = F.softmax(weights, dim=-1)
        weights = self.dropout(weights)
        return weights @ v
    
# Multiple Attention Heads
class MultiHeadAttention(nn.Module):
    def __init__(self, head_size, num_head):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_emb, n_emb)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        return self.dropout(self.proj(out))

# Wise Feedforward Network
class Feedforward(nn.Module):
    def __init__(self, n_emb):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_emb, 4 * n_emb),
            nn.ReLU(),
            nn.Linear(4 * n_emb, n_emb),
            nn.Dropout(0.1)
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    def __init__(self, n_emb, num_heads):
        super().__init__()
        head_size = n_emb // num_heads
        self.sa = MultiHeadAttention(head_size, num_heads)
        self.ff = Feedforward(n_emb)
        self.In1 = nn.LayerNorm(n_emb)
        self.In2 = nn.LayerNorm(n_emb)

    def forward(self, x):
        x = x + self.sa(self.In1(x))
        x = x + self.ff(self.In2(x))
        return x

class TextGenerator(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, n_emb)
        self.pos_emb = nn.Embedding(block_size, n_emb)
        self.blocks = nn.Sequential(*[Block(n_emb, num_heads) for _ in range(num_blocks)])
        self.In_f = nn.LayerNorm(n_emb)
        self.head = nn.Linear(n_emb, vocab_size)

    def forward(self, x, y=None):
        B, T = x.shape
        tok_emb = self.token_emb(x)
        pos_emb = self.pos_emb(torch.arange(T, device=x.device))
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.In_f(x)
        logits = self.head(x)

        if y is None:
            loss = None
        else:
            loss = F.cross_entropy(logits.view(B * T, vocab_size), y.view(B * T))

        return logits, loss

    def generate(self, x, max_new_tokens=50):
        for _ in range(max_new_tokens):
            logits, _ = self(x)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            x = torch.cat((x, next_token), dim=1)
        return x

# Train model
device = "cuda" if torch.cuda.is_available() else "cpu"
model = TextGenerator().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def train():
    for step in range(1000):
        x, y = get_batch("train")
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits, loss = model(x, y)
        loss.backward() 
        optimizer.step()
        if step % 100 == 0:
            print(f"step {step}, Loss:{loss.item()}")
train()

# Generate text
model.eval()
start_seq = "The cat is cute"
x = torch.tensor(encode(start_seq), dtype=torch.long).unsqueeze(0).to(device) #(1, T)
generated = model.generate(x, max_new_tokens=50)
print(decode(generated[0].tolist()))



step 0, Loss:4.468491554260254
step 100, Loss:2.9288618564605713
step 200, Loss:2.7094006538391113
step 300, Loss:2.591665744781494
step 400, Loss:2.500792980194092
step 500, Loss:2.4056499004364014
step 600, Loss:2.3951194286346436
step 700, Loss:2.409162759780884
step 800, Loss:2.393882989883423
step 900, Loss:2.306440830230713
The cat is cuteotemer:
    thesyou Gent boouthilave l eclod te me
