<a href="https://colab.research.google.com/github/Buenobarbie/GPT-from-scratch/blob/master/gpt_dev.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Build GPT from scratch

In [32]:
MANUAL_SEED = 1337
import torch
import torch.nn as nn
import torch.nn.functional as F

# HYPERPARAMETERS ----------
batch_size = 64
block_size = 256
max_steps = 5000
learning_rate = 3e-4
eval_iters = 200
n_emb = 384
n_layer = 6
n_head = 6
dropout = 0.2
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# ------------------------

torch.manual_seed(MANUAL_SEED)

## Preprocess dataset

In [33]:
with open('../data/input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

print("Length of dataset in characters: ", len(text))

Length of dataset in characters:  1115394


In [34]:
print(text[:100])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [35]:
# Check all the unique characters in dataset
chars = sorted(list(set(text)))
vocab_size = len(chars)
print("".join(chars))
print("Vocabulary size: ", vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
Vocabulary size:  65


In [36]:
# Create a mapping of characters to integers
stoi = {c:i for i,c in enumerate(chars)}
itos = {i:c for i,c in enumerate(chars)}
encode = lambda msg: [stoi[c] for c in msg]               # encode message
decode = lambda e_msg: "".join([itos[i] for i in e_msg])  # decode encoded message

print(encode("Hello, world!"))
print(decode(encode("Hello, world!")))

[20, 43, 50, 50, 53, 6, 1, 61, 53, 56, 50, 42, 2]
Hello, world!


## Tokenize the entire dataset

In [37]:
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:10])

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47])


## Split dataset

In [38]:
n = int(0.9 *len(data))
train_data = data[:n]
val_data = data[n:]

## Get chunck of data for training

In [39]:

train_data[:block_size+1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
        47, 59, 57,  1, 47, 57,  1, 41, 

In [40]:
# x = train_data[:block_size]
# y = train_data[1:block_size+1]
# for t in range(block_size):
#     context = x[:t+1]
#     target = y[t]
#     print(f"When input is {context} the target is: {target}")


When input is tensor([18]) the target is: 47
When input is tensor([18, 47]) the target is: 56
When input is tensor([18, 47, 56]) the target is: 57
When input is tensor([18, 47, 56, 57]) the target is: 58
When input is tensor([18, 47, 56, 57, 58]) the target is: 1
When input is tensor([18, 47, 56, 57, 58,  1]) the target is: 15
When input is tensor([18, 47, 56, 57, 58,  1, 15]) the target is: 47
When input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target is: 58
When input is tensor([18, 47, 56, 57, 58,  1, 15, 47, 58]) the target is: 47
When input is tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47]) the target is: 64
When input is tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64]) the target is: 43
When input is tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43]) the target is: 52
When input is tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52]) the target is: 10
When input is tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10]) the target is: 0
When inp

In [41]:
def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))  # choose random starting points for each sequence in the batch
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x,y = x.to(device), y.to(device)
    return x, y


In [42]:
xb, yb = get_batch('train')
print("inputs:")
print(xb)
print("targets:")
print(yb)


inputs:
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]], device='cuda:0')
targets:
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]], device='cuda:0')


## Head of self attention

In [43]:
class Head(nn.Module):
    """One head of self attention"""

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_emb, head_size, bias=False)
        self.query = nn.Linear(n_emb, head_size, bias=False)
        self.value = nn.Linear(n_emb, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size))) # not a parameter of the model

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)   # (B, T, head_size)
        q = self.query(x) # (B, T, head_size)

        # compute attention scores
        wei = q @ k.transpose(-2, -1) * C**-0.5 # (B, T, head_size) @ (B, head_size, T) -> (B, T, T) # normalize ("scaled attention")
        wei = wei.masked_fill(self.tril[:T,:T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)

        # perform weighted agregation of the values
        v = self.value(x)
        out = wei @ v

        return out

## Multi Head Attention

In [44]:
class MultiHeadAttention(nn.Module):
    """Multiples heas of self-attention in parallel"""

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_emb, n_emb)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

## Feed Forward

In [45]:
class FeedForward(nn.Module):
    """A simple linear layer with a non-linearity"""

    def __init__(self, n_emb):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_emb, 4*n_emb),
            nn.ReLU(),
            nn.Linear(4*n_emb, n_emb) ,# projection layer
            nn.Dropout(dropout)
        )

    def forward(self,x):
        return self.net(x)

## Block

In [46]:
class  Block(nn.Module):
    """Transformer block: communication followed by computation"""

    def __init__(self,n_emb,n_head):
        # n_emb: embedding size
        # n_head: number of heads
        super().__init__()
        head_size = n_emb // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_emb)
        self.ln1 = nn.LayerNorm(n_emb)
        self.ln2 = nn.LayerNorm(n_emb)

    def forward(self,x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

## Language model

- B: Batch_size
- T: Length of the input sequence
- C: Number of features per token (size of the embedding)

In [47]:

class LanguageModel(nn.Module):

    def __init__(self):
        super().__init__()

        # Initialize a matrix of shape (vocab_size, vocab_size) with  random values
        # that will be optimized during training
        self.token_embedding_table = nn.Embedding(num_embeddings=vocab_size, embedding_dim=n_emb)
        self.position_embedding_table = nn.Embedding(num_embeddings=block_size, embedding_dim=n_emb)
        self.blocks = nn.Sequential(*[Block(n_emb, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_emb) # final layer norm
        self.lm_head = nn.Linear(n_emb, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx) # (B, T, n_emb)
        pos_emb = self.position_embedding_table(torch.arange(T,device=device)) # (T, n_emb)
        x = tok_emb + pos_emb # (B, T, n_emb)
        x = self.blocks(x)
        logits = self.lm_head(x) # (B, T, vocab_size)

        if targets is None:
            loss = None

        else:
            B,T,C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B,T) array of indeces in the current context
        # the goal is to get (B, T+1), (B, T+2), ... (B, T+max_new_tokens )

        for _ in range(max_new_tokens):
            # crop the context to the last block_size tokens
            idx_crop = idx[:, -block_size:]

            # get the predictions
            logits, loss = self(idx_crop)

            # focus only on the last time step (the logit of the last token in the context)
            logits = logits[:, -1, :] # becomes (B, C)

            # softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)

            # sample from distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)

            # append sampled index to the running sequence
            idx = torch.cat([idx, idx_next], dim=1)

        return idx


m = LanguageModel()
model = m.to(device)
logits, loss = model(xb, yb)
print(logits.shape)
print(loss.item())



torch.Size([32, 65])
4.560873508453369


## Train the Model

In [48]:
# create a pytorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [49]:

for steps in range(max_steps):

    # sample batch
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss.item())

2.08707857131958


## Better estimate the loss

Since the loss is calculated for a single batch, it may not representate well the loss of the whole data.

So to estimate we'll calculate the loss for many batches and then return the mean.

In [50]:
@torch.no_grad()
def estimate_loss(eval_iters=500):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            xb, yb = get_batch(split)
            logits, loss = model(xb, yb)
            losses[k] = loss.item()

        out[split] = losses.mean().item()
    model.train()
    return out

print(estimate_loss(eval_iters))

{'train': 2.2933497428894043, 'val': 2.327444314956665}


## Sample

In [51]:
idx = torch.zeros((1,1), dtype=torch.long, device=device) # Contains the character corresponding to index 0 (newline)
print(decode(model.generate(idx, max_new_tokens=500)[0].tolist())) # decode the first batch of the generated completions


EM:
And tow Whourd telst of that's gengdtly for orverer mis red kill yo
ou os to-those grars:'s my Rippleret moorit him;
For the at man etts for I jon set theim bricked do to the your him?
Ye chis unk, y dost lim, thwat's?

Jo fought good owrom gom noknill thou Cothath'd me vene th thryightsh in wor with's blick,
Cor grood that's conryunks's rem to wiorell
Fik rong bloooust's grusolts
And twis thet to pir andtly gall det, with and onent's satce my thers ragos thim ou sepoth a ort oll Qorkirt, wi


### Samples from model

- **On initialization (loss 4.8786):**
    ```
    SKIcLT;AcELMoTbvZv C?nq-QE33:CJqkOKH-q;:la!oiywkHjgChzbQ?u!3bLIgwevmyFJGUGp
    wnYWmnxKWWev-tDqXErVKLgJ
    ```

- **After ~20000 steps of trainign (loss 2.5770):**
    ```
    Fouthe f Goversthy harmarend t:
    Musthee aved tef t thaphapayeeraie ce. t, ndedigetlot;
    W:


    A ityove

    ```

- **With one head of self attention (loss 2.3369):**
    ```
    OMLIO:
    Mupace, Lan cesaver mudhe bes vedispt; clacked as Bis lof ff perous yeskll I't oncoragthu.

    MENECES:
    Wally d's tie celoummaves maly,
    I'd, brecu ane uplothat wave tist beman fec ramin oul,a tt y
    ```

- **With multi head attention (loss 2.2553):**
    ```
    GOMBH:
    Antell
    Bof mod st sow hot the, owas wonoth cry to,
    We Nond q'epse, the?
    Nod V-
    Afuch tu what retest I sute Wux no hever bate
    CRAN RE ILAENCES:
    Theld den'th you hat sull ot: so sharlou rit mes y
    ```

- **With feed forward layer (loss 2.2067):**
    ```
    Freds ket your pingold bell,
    Yionk, ther sor bly
    Whis grajike wiching her ons prind laind that pour prest'seped we so I a sacay sicent and forkny notuch puprr tolgn haly prike yethe of the with,
    Shels
    ```

- **With transformer block (loss 2.0027):**
    ```
    Surt my have me the didaig grestles, ciconsulf, dukird commmister the scarusit,
    Bother the litved?

    FRIF:
    O clord.

    GLEONVAUS:
    What sham slampy mining arre cond that Bemten have it wisher prow that do
    ```

- **With Layer normalization (1.9859):**
    ```  
    Spet my havopour dead us ange the so'll of upliedfort are my seep the of dusitit, you, the littlish ledem,
    And dage lets and
    that Juse mistrivy minien,
    If thou them stime
    And hand live pripele that do
    ```

## Increase attention

Every token will receive the context of the previous tokens.

For now, we'll average the values of the features from the previous tokens

In [52]:
torch.manual_seed(MANUAL_SEED)
B,T,C = 4, 8, 2
x = torch.randn((B,T,C))

xbow = torch.zeros((B,T, C))  # x bag of words
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1] # (t,C)
        xbow[b,t] = xprev.mean(dim=0)

In [53]:
# The last row of xbow is the average of all the previous rows in x
x[0], xbow[0]

(tensor([[ 0.1808, -0.0700],
         [-0.3596, -0.9152],
         [ 0.6258,  0.0255],
         [ 0.9545,  0.0643],
         [ 0.3612,  1.1679],
         [-1.3499, -0.5102],
         [ 0.2360, -0.2398],
         [-0.9211,  1.5433]]),
 tensor([[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]]))

### Optimize calculation with matrix multiplication

In [54]:
torch.tril(torch.ones(3,3))

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [55]:
# Each row of matrix c is the sum of the previous rows of matrix b
torch.manual_seed(MANUAL_SEED)
a = torch.tril(torch.ones(3,3))
b = torch.randint(0,10,(3,2)).float()
c = a @ b

print("---------A---------\n")
print(a)
print("\n---------B---------\n")
print(b)
print("\n---------C---------\n")
print(c)

---------A---------

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

---------B---------

tensor([[5., 7.],
        [2., 0.],
        [5., 3.]])

---------C---------

tensor([[ 5.,  7.],
        [ 7.,  7.],
        [12., 10.]])


In [56]:
# Each row of matrix c is the mean of the previous rows of matrix b
torch.manual_seed(MANUAL_SEED)
a = torch.tril(torch.ones(3,3))
a = a / torch.sum(a, dim=1, keepdim=True)
b = torch.randint(0,10,(3,2)).float()
c = a @ b

print("---------A---------\n")
print(a)
print("\n---------B---------\n")
print(b)
print("\n---------C---------\n")
print(c)

---------A---------

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])

---------B---------

tensor([[5., 7.],
        [2., 0.],
        [5., 3.]])

---------C---------

tensor([[5.0000, 7.0000],
        [3.5000, 3.5000],
        [4.0000, 3.3333]])


In [57]:
wei = torch.tril(torch.ones(T,T))
wei = wei / wei.sum(dim=1, keepdim=True)
xbow2 = wei @ x     # (T,T) @ (B,T,C) -> (B,T,T) @ (B,T,C) -> (B,T,C)
torch.allclose(xbow, xbow2)

False

### Using softmax

Even though we initialize wei as zeros, it could be a weighted matrix of the
relationship of the tokens.

Then, we set future tokens' relationship to -inf so they don't interere



In [58]:
tril = torch.tril(torch.ones(T,T))
wei = torch.zeros(T,T)

#tokens from the past cannot comunicate
wei = wei.masked_fill(tril == 0, float('-inf')) # fill every position of wei that is 0 in tril with -inf

wei = F.softmax(wei, dim=1)
xbow3 = wei @ x
torch.allclose(xbow, xbow3)

False

### Self attention

In [59]:
torch.manual_seed(MANUAL_SEED)
B,T,C = 4, 8, 32
x = torch.randn((B,T,C))

# Single self attention head
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x)  # (B, T, 16)
q = query(x) # (B, T, 16)

wei = q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) -> (B, T, T


tril = torch.tril(torch.ones(T,T))
# wei = torch.zeros(T,T)
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=1)

v = value(x)
out = wei @ v
# out = wei @ x


In [60]:
wei

tensor([[[0.0248, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0052, 0.0091, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0521, 0.0135, 0.2482, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3171, 0.0214, 0.1642, 0.1188, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0412, 0.0487, 0.1046, 0.0742, 0.2000, 0.0000, 0.0000, 0.0000],
         [0.1060, 0.5347, 0.2059, 0.1030, 0.7402, 0.0192, 0.0000, 0.0000],
         [0.4298, 0.3409, 0.1769, 0.2027, 0.0480, 0.8472, 0.2329, 0.0000],
         [0.0238, 0.0316, 0.1002, 0.5013, 0.0117, 0.1336, 0.7671, 1.0000]],

        [[0.0443, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0042, 0.0375, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0560, 0.0210, 0.2496, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3679, 0.1441, 0.4929, 0.0438, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0088, 0.1052, 0.0604, 0.5847, 0.2046, 0.0000, 0.0000, 0.0000],
         [0.0367, 0.089