In [1]:
with open('data/tiny_shakespeare.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [2]:
print(f"No. of characters in the dataset: {len(text)}")

No. of characters in the dataset: 1115394


In [4]:
print(text[:500])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor


In [5]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(f"Vocabulary size: {vocab_size}")

Vocabulary size: 65


In [6]:
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: "".join([itos[i] for i in l])

In [7]:
print(encode("Hi there"))

[20, 47, 1, 58, 46, 43, 56, 43]


In [8]:
print(decode(encode("Hi there")))

Hi there


In [9]:
import torch
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)

torch.Size([1115394]) torch.int64


In [10]:
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

In [11]:
context_length = 8
train_data[:context_length+1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [12]:
x = train_data[:context_length]
y = train_data[1:context_length+1]
for t in range(context_length):
    context = x[:t+1]
    target = y[t]
    print(f"When input is {context}, the target is {target}")

When input is tensor([18]), the target is 47
When input is tensor([18, 47]), the target is 56
When input is tensor([18, 47, 56]), the target is 57
When input is tensor([18, 47, 56, 57]), the target is 58
When input is tensor([18, 47, 56, 57, 58]), the target is 1
When input is tensor([18, 47, 56, 57, 58,  1]), the target is 15
When input is tensor([18, 47, 56, 57, 58,  1, 15]), the target is 47
When input is tensor([18, 47, 56, 57, 58,  1, 15, 47]), the target is 58


In [15]:
torch.manual_seed(0)
batch_size = 4
context_length = 8

def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - context_length, (batch_size,))
    x = torch.stack([data[i:i+context_length] for i in ix])
    y = torch.stack([data[i+1:i+context_length+1] for i in ix])
    return x, y

In [16]:
xb, yb = get_batch('train')

In [17]:
print(f'inputs: {xb.shape}')
print(f'targets: {yb.shape}')

inputs: torch.Size([4, 8])
targets: torch.Size([4, 8])


In [19]:
for b in range(batch_size):
    for t in range(context_length):
        context = xb[b: t+1]
        target = yb[b, t]

## Bigram language model

In [25]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets):
        logits = self.token_embedding_table(idx)
        
        B, T, C = logits.shape
        logits = logits.view(B*T, C)
        targets = targets.view(B*T)
        
        loss = F.cross_entropy(logits, targets)
        
        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            logits, loss = self(idx)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx


m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(out.shape)

torch.Size([4, 8, 65])


In [26]:
loss

tensor(4.8532, grad_fn=<NllLossBackward0>)

In [24]:
yb

tensor([[43,  1, 45, 47, 60, 43, 57,  1],
        [ 1, 58, 46, 43,  1, 63, 53, 59],
        [50, 50,  1, 63, 53, 59,  1, 46],
        [61, 52,  1, 57, 53, 59, 50,  1]])