In [1]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2025-03-31 10:26:05--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2025-03-31 10:26:05 (16.7 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F


In [4]:
# Configuration
BATCH_SIZE = 32
SEQUENCE_LENGTH = 256
MAX_ITERATIONS = 1000
EVAL_INTERVAL = 500
LEARNING_RATE = 3e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
EVAL_ITERS = 200
EMBEDDING_DIM = 384
NUM_HEADS = 6
NUM_LAYERS = 6
DROPOUT_RATE = 0.2

torch.manual_seed(1337)


<torch._C.Generator at 0x7aa23813a7d0>

In [5]:
# Load dataset
with open('input.txt', 'r', encoding='utf-8') as f:
    raw_text = f.read()

# Character processing
unique_chars = sorted(set(raw_text))
VOCAB_SIZE = len(unique_chars)
char_to_idx = {ch: idx for idx, ch in enumerate(unique_chars)}
idx_to_char = {idx: ch for idx, ch in enumerate(unique_chars)}

def text_to_indices(text):
    return [char_to_idx[c] for c in text]

def indices_to_text(indices):
    return ''.join([idx_to_char[i] for i in indices])

# Convert dataset into tensor
data = torch.tensor(text_to_indices(raw_text), dtype=torch.long)
split_idx = int(0.9 * len(data))
train_data, val_data = data[:split_idx], data[split_idx:]


In [6]:
def fetch_batch(split):
    data_source = train_data if split == 'train' else val_data
    batch_indices = torch.randint(len(data_source) - SEQUENCE_LENGTH, (BATCH_SIZE,))
    X = torch.stack([data_source[i : i + SEQUENCE_LENGTH] for i in batch_indices])
    Y = torch.stack([data_source[i + 1 : i + SEQUENCE_LENGTH + 1] for i in batch_indices])
    return X.to(device), Y.to(device)


In [7]:
@torch.no_grad()
def compute_loss():
    results = {}
    model.eval()
    for split in ['train', 'val']:
        loss_list = torch.zeros(EVAL_ITERS)
        for i in range(EVAL_ITERS):
            x_batch, y_batch = fetch_batch(split)
            _, loss = model(x_batch, y_batch)
            loss_list[i] = loss.item()
        results[split] = loss_list.mean()
    model.train()
    return results


In [8]:
class SelfAttentionHead(nn.Module):
    def __init__(self, head_dim):
        super().__init__()
        self.key_layer = nn.Linear(EMBEDDING_DIM, head_dim, bias=False)
        self.query_layer = nn.Linear(EMBEDDING_DIM, head_dim, bias=False)
        self.value_layer = nn.Linear(EMBEDDING_DIM, head_dim, bias=False)
        self.tril_mask = torch.tril(torch.ones(SEQUENCE_LENGTH, SEQUENCE_LENGTH))
        self.dropout = nn.Dropout(DROPOUT_RATE)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key_layer(x)
        q = self.query_layer(x)
        scores = (q @ k.transpose(-2, -1)) * (k.shape[-1] ** -0.5)
        scores = scores.masked_fill(self.tril_mask[:T, :T] == 0, float('-inf'))
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        v = self.value_layer(x)
        return attention_weights @ v


In [9]:
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_dim):
        super().__init__()
        self.heads = nn.ModuleList([SelfAttentionHead(head_dim) for _ in range(num_heads)])
        self.projection = nn.Linear(head_dim * num_heads, EMBEDDING_DIM)
        self.dropout = nn.Dropout(DROPOUT_RATE)

    def forward(self, x):
        return self.dropout(self.projection(torch.cat([h(x) for h in self.heads], dim=-1)))

class FeedForwardLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(EMBEDDING_DIM, 4 * EMBEDDING_DIM),
            nn.ReLU(),
            nn.Linear(4 * EMBEDDING_DIM, EMBEDDING_DIM),
            nn.Dropout(DROPOUT_RATE),
        )

    def forward(self, x):
        return self.net(x)


In [10]:
class TransformerBlock(nn.Module):
    def __init__(self):
        super().__init__()
        head_dim = EMBEDDING_DIM // NUM_HEADS
        self.attn = MultiHeadAttention(NUM_HEADS, head_dim)
        self.ffn = FeedForwardLayer()
        self.norm1 = nn.LayerNorm(EMBEDDING_DIM)
        self.norm2 = nn.LayerNorm(EMBEDDING_DIM)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x

class LanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding = nn.Embedding(VOCAB_SIZE, EMBEDDING_DIM)
        self.pos_embedding = nn.Embedding(SEQUENCE_LENGTH, EMBEDDING_DIM)
        self.transformer_blocks = nn.Sequential(*[TransformerBlock() for _ in range(NUM_LAYERS)])
        self.norm_final = nn.LayerNorm(EMBEDDING_DIM)
        self.output_layer = nn.Linear(EMBEDDING_DIM, VOCAB_SIZE)
        self.apply(self._initialize_weights)

    def _initialize_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, inputs, targets=None):
        batch_size, seq_len = inputs.shape
        token_embeds = self.token_embedding(inputs)
        pos_embeds = self.pos_embedding(torch.arange(seq_len, device=device))
        x = token_embeds + pos_embeds
        x = self.transformer_blocks(x)
        x = self.norm_final(x)
        logits = self.output_layer(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, VOCAB_SIZE), targets.view(-1))
        return logits, loss

    def generate(self, context, new_tokens):
        for _ in range(new_tokens):
            cropped_context = context[:, -SEQUENCE_LENGTH:]
            logits, _ = self(cropped_context)
            next_token = torch.multinomial(F.softmax(logits[:, -1, :], dim=-1), num_samples=1)
            context = torch.cat((context, next_token), dim=1)
        return context


In [11]:
model = LanguageModel().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

for iteration in range(MAX_ITERATIONS):
    if iteration % EVAL_INTERVAL == 0:
        losses = compute_loss()
        print(f"Step {iteration}: Train Loss {losses['train']:.4f}, Val Loss {losses['val']:.4f}")

    x_batch, y_batch = fetch_batch('train')
    _, loss = model(x_batch, y_batch)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


KeyboardInterrupt: 

In [None]:
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(indices_to_text(model.generate(context, 500)[0].tolist()))
