In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [2]:
# Training hyperparameters
BATCH_SIZE = 64        # sequences per optimisation step
BLOCK_SIZE = 256       # maximum context length, in characters
MAX_ITERS = 5000
EVAL_INTERVAL = 500    # evaluate every this many steps
LEARNING_RATE = 3e-4
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
EVAL_ITERS = 200       # batches averaged per loss estimate

# Model hyperparameters
MODEL_DIM = 96
NUM_HEADS = 6
NUM_LAYERS = 6
DROPOUT_PROB = 0.2

# Seed torch's RNG (on all devices) so batch sampling, weight init, dropout and
# sampling are repeatable across runs (some CUDA kernels may still be nondeterministic).
SEED = 1337
torch.manual_seed(SEED)

# Each decoder block gives every head MODEL_DIM // NUM_HEADS features, so the
# model width must split evenly across the heads.
assert MODEL_DIM % NUM_HEADS == 0, "MODEL_DIM must be divisible by NUM_HEADS"


In [5]:
with open('data/hemingway.txt', 'r', encoding='utf-8') as file:
    corpus = file.read()

# Character-level vocabulary: the distinct characters of the corpus, in sorted order.
chars = sorted(set(corpus))
VOCAB_SIZE = len(chars)
idx_to_char = dict(enumerate(chars))
char_to_idx = {ch: i for i, ch in idx_to_char.items()}


In [6]:
def encode(string):
    """Convert a string into the list of its characters' vocabulary indices.

    Raises KeyError for any character not present in the vocabulary.
    """
    return list(map(char_to_idx.__getitem__, string))

def decode(indices):
    """Convert an iterable of vocabulary indices back into a string."""
    return ''.join(idx_to_char[index] for index in indices)

In [8]:
data = torch.tensor(encode(corpus), dtype=torch.long)
# Keep the first 90% of the character stream for training; hold out the rest for validation.
split_idx = int(len(data) * 0.9)
train_data, val_data = data[:split_idx], data[split_idx:]

In [9]:
# Data loading
def get_batch(split):
    """Sample a random batch of (input, target) windows from one data split.

    Args:
        split: 'train' to sample from ``train_data``, 'val' for ``val_data``.

    Returns:
        (x, y): tensors of shape (BATCH_SIZE, BLOCK_SIZE) moved to DEVICE.
        ``y`` is the window starting one position after ``x``, i.e. the
        next-character target for every position of ``x``.

    Raises:
        ValueError: if ``split`` is neither 'train' nor 'val' (previously any
            other value silently sampled from the validation data).
    """
    if split == 'train':
        source = train_data
    elif split == 'val':
        source = val_data
    else:
        raise ValueError(f"split must be 'train' or 'val', got {split!r}")
    # Starts stop at len - BLOCK_SIZE - 1 so the shifted target window still fits.
    starts = torch.randint(len(source) - BLOCK_SIZE, (BATCH_SIZE,))
    x = torch.stack([source[i:i + BLOCK_SIZE] for i in starts])
    y = torch.stack([source[i + 1:i + BLOCK_SIZE + 1] for i in starts])
    return x.to(DEVICE), y.to(DEVICE)


In [10]:
@torch.no_grad()
def estimate_loss(model):
    """Estimate the mean loss of ``model`` on the train and val splits.

    For each split, averages the loss over EVAL_ITERS freshly sampled batches,
    with dropout disabled (eval mode) and autograd off.

    Args:
        model: module whose ``forward(x, y)`` returns ``(logits, loss)``.

    Returns:
        dict mapping 'train' and 'val' to a 0-dim tensor holding the mean loss.
    """
    results = {}
    was_training = model.training
    model.eval()
    try:
        for split in ('train', 'val'):
            losses = torch.zeros(EVAL_ITERS)
            for k in range(EVAL_ITERS):
                x, y = get_batch(split)
                _, loss = model(x, y)
                losses[k] = loss.item()
            results[split] = losses.mean()
    finally:
        # Restore the caller's mode (rather than always forcing train mode),
        # even if a batch raises.
        model.train(was_training)
    return results


In [11]:
class Head(nn.Module):
    """A single causal (masked) self-attention head.

    Projects the input into keys, queries and values of width ``head_dim``,
    prevents each position from attending to later positions, and returns the
    attention-weighted sum of the values.
    """

    def __init__(self, head_dim):
        super().__init__()
        self.W_K = nn.Linear(MODEL_DIM, head_dim, bias=False)
        self.W_Q = nn.Linear(MODEL_DIM, head_dim, bias=False)
        self.W_V = nn.Linear(MODEL_DIM, head_dim, bias=False)
        # Lower-triangular causal mask; a buffer, so it follows the module's
        # device but is not a trainable parameter.
        self.register_buffer('tril', torch.tril(torch.ones(BLOCK_SIZE, BLOCK_SIZE)))

        self.dropout = nn.Dropout(DROPOUT_PROB)

    def forward(self, x):
        """Apply causal self-attention to ``x``.

        Args:
            x: tensor of shape (B, T, MODEL_DIM) with T <= BLOCK_SIZE.

        Returns:
            Tensor of shape (B, T, head_dim).
        """
        _, T, _ = x.shape
        k = self.W_K(x)  # (B, T, head_dim)
        q = self.W_Q(x)  # (B, T, head_dim)
        # Scaled dot-product attention divides by sqrt(d_k), the key width
        # (head_dim) -- not the width of the input x, which is MODEL_DIM.
        scores = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5  # (B, T, T)
        # Position t may only attend to positions <= t.
        scores = scores.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        weights = self.dropout(weights)

        # Weighted aggregation of the values.
        v = self.W_V(x)  # (B, T, head_dim)
        return weights @ v

class MultiHeadAttention(nn.Module):
    """Several causal self-attention heads run in parallel.

    Each head produces ``head_dim`` features; the concatenated outputs
    (``num_heads * head_dim`` features) are projected back to MODEL_DIM and
    passed through dropout.
    """

    def __init__(self, num_heads, head_dim):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_dim) for _ in range(num_heads)])
        # Size the projection from the actual concatenated width, so the layer
        # still works when num_heads * head_dim != MODEL_DIM (identical
        # otherwise).
        self.proj = nn.Linear(num_heads * head_dim, MODEL_DIM)
        self.dropout = nn.Dropout(DROPOUT_PROB)

    def forward(self, x):
        """Run every head on ``x``, concatenate along the last dim, then project."""
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        return self.dropout(self.proj(out))


In [14]:
# Position-wise feed-forward network (Linear -> ReLU -> Linear -> Dropout), applied independently at each token position.
class FeedForward(nn.Module):
    """Token-wise MLP: widen to 4 * d_model, apply ReLU, project back, then dropout."""

    def __init__(self, d_model):
        super().__init__()
        hidden = d_model * 4
        layers = [
            nn.Linear(d_model, hidden),
            nn.ReLU(),
            nn.Linear(hidden, d_model),
            nn.Dropout(DROPOUT_PROB),
        ]
        self.ff = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP over the last dimension of ``x``."""
        return self.ff(x)

In [15]:
# Transformer decoder block

class DecoderBlock(nn.Module):
    """Pre-norm transformer decoder block.

    Causal multi-head self-attention followed by a feed-forward network. Each
    sub-layer reads a LayerNorm'd copy of its input, and its output is added
    back through a residual connection.
    """

    def __init__(self, d_model, n_head):
        super().__init__()
        self.sa = MultiHeadAttention(n_head, d_model // n_head)
        self.ff = FeedForward(d_model)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x):
        """Apply attention then feed-forward, each with a residual connection."""
        attended = x + self.sa(self.ln1(x))
        return attended + self.ff(self.ln2(attended))


In [16]:
class GPT(nn.Module):
    """Character-level decoder-only transformer language model.

    Token embeddings plus learned absolute position embeddings feed a stack of
    NUM_LAYERS DecoderBlocks, then a final LayerNorm and a linear layer that
    produces next-token logits over the vocabulary.
    """

    def __init__(self):
        super().__init__()
        # Learned token and position embeddings, both of width MODEL_DIM.
        self.token_embedding_table = nn.Embedding(VOCAB_SIZE, MODEL_DIM)
        self.position_embedding_table = nn.Embedding(BLOCK_SIZE, MODEL_DIM)
        self.blocks = nn.Sequential(
            *[DecoderBlock(MODEL_DIM, n_head=NUM_HEADS) for _ in range(NUM_LAYERS)]
        )
        # Final layer norm, then projection to vocabulary logits.
        self.ln = nn.LayerNorm(MODEL_DIM)
        self.ff = nn.Linear(MODEL_DIM, VOCAB_SIZE)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: (B, T) tensor of token indices; T must not exceed the number
                of position embeddings (BLOCK_SIZE).
            targets: optional (B, T) tensor of next-token indices.

        Returns:
            (logits, loss). Without targets, logits is (B, T, VOCAB_SIZE) and
            loss is None. With targets, logits is flattened to
            (B * T, VOCAB_SIZE) and loss is the mean cross-entropy.
        """
        B, T = idx.shape

        tok_emb = self.token_embedding_table(idx)  # (B, T, MODEL_DIM)
        # Build positions on the input's device rather than the global DEVICE,
        # so the model works wherever it and its input actually live.
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))
        x = tok_emb + pos_emb  # (T, MODEL_DIM) broadcasts over the batch

        # Mix the token representations repeatedly through the decoder blocks.
        x = self.blocks(x)
        x = self.ln(x)
        logits = self.ff(x)  # (B, T, VOCAB_SIZE)

        if targets is None:
            return logits, None

        # cross_entropy expects (N, C) logits and (N,) class-index targets.
        C = logits.shape[-1]
        logits = logits.reshape(B * T, C)
        loss = F.cross_entropy(logits, targets.reshape(B * T))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Autoregressively sample ``max_new_tokens`` tokens after ``idx``.

        Args:
            idx: (B, T) tensor of context token indices.
            max_new_tokens: number of tokens to append.

        Returns:
            (B, T + max_new_tokens) tensor: the context followed by the samples.

        Sampling runs in eval mode (dropout off) with autograd disabled; the
        module's previous train/eval mode is restored afterwards.
        """
        # The position table has one row per position, so it bounds the context.
        block_size = self.position_embedding_table.num_embeddings
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Crop to the last block_size tokens.
                idx_cond = idx[:, -block_size:]
                logits, _ = self(idx_cond)
                # Only the prediction at the final time step is needed.
                probs = F.softmax(logits[:, -1, :], dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
                idx = torch.cat((idx, idx_next), dim=1)
        finally:
            self.train(was_training)
        return idx

In [17]:
class EarlyStopping:
    """Signal a stop once validation loss pulls too far above training loss.

    A check "fails" when the relative generalisation gap
    ``(validation_loss - train_loss) / train_loss`` exceeds ``min_delta``.
    After ``tolerance`` consecutive failing checks, ``early_stop`` becomes True
    and stays True.

    Args:
        tolerance: number of consecutive failing checks allowed before stopping.
        min_delta: largest acceptable relative gap (e.g. 0.2 lets the
            validation loss sit up to 20% above the training loss).
    """

    def __init__(self, tolerance=5, min_delta=0):
        self.tolerance = tolerance
        self.min_delta = min_delta
        self.counter = 0  # consecutive failing checks so far
        self.early_stop = False

    def __call__(self, train_loss, validation_loss):
        """Record one evaluation and return whether training should stop."""
        gap = validation_loss - train_loss
        if train_loss != 0:
            relative_gap = gap / train_loss
        else:
            # The relative gap is undefined at zero training loss; treat any
            # positive gap as unbounded instead of dividing by zero.
            relative_gap = float('inf') if gap > 0 else 0.0
        if relative_gap > self.min_delta:
            self.counter += 1
            if self.counter >= self.tolerance:
                self.early_stop = True
        else:
            # The gap closed again: only consecutive failures count.
            self.counter = 0
        return self.early_stop

In [18]:
# Build the model on the configured device and report its size in millions of parameters.
model = GPT().to(DEVICE)
print(sum(map(torch.numel, model.parameters())) / 1e6, 'M parameters')

# AdamW optimiser; each scheduler.step() multiplies the learning rate by 0.9.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
early_stopping = EarlyStopping(tolerance=1, min_delta=0.2)

0.706046 M parameters


In [19]:
# `step` rather than `iter`: a top-level loop variable named `iter` would
# shadow the builtin for every later cell in the kernel.
for step in range(MAX_ITERS):
    # Evaluate periodically, and once more on the final step.
    if step % EVAL_INTERVAL == 0 or step == MAX_ITERS - 1:
        # Decay the learning rate once per evaluation, skipping the first one.
        if step:
            scheduler.step()
        losses = estimate_loss(model)
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        early_stopping(losses['train'], losses['val'])
        if early_stopping.early_stop:
            print(f"Early stopping at iteration {step}")
            break

    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()


step 0: train loss 4.3438, val loss 4.3537


KeyboardInterrupt: 

In [None]:
# Seed generation with a single token of index 0, i.e. the first character of the
# sorted vocabulary (presumably '\n' for this corpus -- TODO confirm).
context = torch.zeros(1, 1, device=DEVICE, dtype=torch.long)
generated_text = decode(
    model.generate(context, max_new_tokens=100)[0].tolist()
)
print(generated_text)