<a href="https://colab.research.google.com/github/GianlucaRapaglia/LLM-training/blob/main/02%20-%20Building%20GPT%20from%20scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Building a GPT
Let's first download the tiny Shakespeare dataset

In [1]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2025-07-18 16:34:51--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.111.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2025-07-18 16:34:51 (17.0 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [2]:
# read it in to inspect it
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [3]:
print("length of dataset in characters: ", len(text))

length of dataset in characters:  1115394


In [4]:
# let's look at the first 1000 characters
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [5]:
# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [6]:
# let's encode and decode the used chars
stoi = { ch:i for i,ch in enumerate(chars)}
itos = { i:ch for i,ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]               #encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l])      #decoder: take a list of integers, output a string

print(encode("hii there"))
print(decode(encode("hii there")))


[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


In [7]:
import torch
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:1000])

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
      

In [8]:
# Let's split the data into train and validation sets
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

Now we want to train on the Shakespeare text. But we are not going to give him the whole text but chunks of text. Giving the whole text would be too computationally expensive. Let's define a block size of 8:

In [9]:
block_size = 8
train_data[:block_size+1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [10]:
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
  context = x[:t+1]
  target = y[t]
  print(f"when input is {context} the target: {target}")



when input is tensor([18]) the target: 47
when input is tensor([18, 47]) the target: 56
when input is tensor([18, 47, 56]) the target: 57
when input is tensor([18, 47, 56, 57]) the target: 58
when input is tensor([18, 47, 56, 57, 58]) the target: 1
when input is tensor([18, 47, 56, 57, 58,  1]) the target: 15
when input is tensor([18, 47, 56, 57, 58,  1, 15]) the target: 47
when input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target: 58


Now let's define the batch size that will help us to:

1. process multiple independent sequences in parallel

2. helps gradient descent by computing gradients from average error over the batch, which leads to better optimization stability

In [11]:
torch.manual_seed(1337)
batch_size = 4
block_size = 8

def get_batch(split):
  data = train_data if split == 'train' else val_data
  ix = torch.randint(len(data) - block_size, (batch_size,))
  x = torch.stack([data[i:i+block_size] for i in ix])
  y = torch.stack([data[i+1:i+block_size+1] for i in ix])
  return ix, x, y

ix, xb, yb = get_batch('train')
print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)

print('-----')

for b in range(batch_size):
  for t in range(block_size):
    context = xb[b, :t+1]
    target = yb[b,t]
    print(f"when input is {context.tolist()} the target: {target}")

inputs:
torch.Size([4, 8])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
targets:
torch.Size([4, 8])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])
-----
when input is [24] the target: 43
when input is [24, 43] the target: 58
when input is [24, 43, 58] the target: 5
when input is [24, 43, 58, 5] the target: 57
when input is [24, 43, 58, 5, 57] the target: 1
when input is [24, 43, 58, 5, 57, 1] the target: 46
when input is [24, 43, 58, 5, 57, 1, 46] the target: 43
when input is [24, 43, 58, 5, 57, 1, 46, 43] the target: 39
when input is [44] the target: 53
when input is [44, 53] the target: 56
when input is [44, 53, 56] the target: 1
when input is [44, 53, 56, 1] the target: 58
when input is [44, 53, 56, 1, 58] the target: 46
when input is [44, 5

Let's start from the simplest NN Language Model. In the token_embedding_table, since the embedding size is equal to the vocab size, each row directly represents the logits for predicting the next token.

Notice that, in this simple case, logits are just rows from a learned matrix. There are no transformations, no nonlinearities, no attention, nothing.

While in a full transformer-based model (like GPT):

The logits come from a decoder block, which includes:

- Token + positional embeddings

- Multi-head self-attention layers

- Feedforward networks

The final hidden state is then projected into a vocabulary-sized vector using a linear layer.

But let's focus on our simple case right now:

In [12]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):
  def __init__(self, vocab_size):
    super().__init__()
    # each token read off the logits for the next token from a lookup table of dimension vocab_size x vocab_size
    self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

  def forward(self, idx, targets):
    logits = self.token_embedding_table(idx) # (B, T, C)

    if targets is None:
      loss = None
    else:
      # but torch does not take B, T, C so let's convert them
      B, T, C = logits.shape
      logits = logits.view(B*T, C)
      targets = targets.view(B*T)
      loss = F.cross_entropy(logits, targets)

    return logits, loss

  def generate(self, idx, max_new_tokens):
    # idx is (B, T) array of indices in the current context
    for _ in range(max_new_tokens):
      # get the predictions
      logits, loss = self(idx, None)
      # focus only on the last time step, in Bigram model we'll be looking just at the last time step, last character in our case
      logits = logits[:, -1, :] # becomes (B, C)
      # apply softmax to get probabilities
      probs = F.softmax(logits, dim=-1) # (B, C)
      # sample from the distribution
      idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
      # append sampled index to the running sequence
      idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
    return idx

m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)
print(loss)

idx = torch.zeros((1, 1), dtype=torch.long)
print(decode(m.generate(idx, max_new_tokens=100)[0].tolist()))

torch.Size([32, 65])
tensor(4.8786, grad_fn=<NllLossBackward0>)

Sr?qP-QWktXoL&jLDJgOLVz'RIoDqHdhsV&vLLxatjscMpwLERSPyao.qfzs$Ys$zF-w,;eEkzxjgCKFChs!iWW.ObzDnxA Ms$3


We can notice that the loss is 4.8786, but from an untrained model we would expect -ln(1/65). This happens because weights are randomly initialized.

Moreover, the generation gives us garbage since the model is not trained yet. Is important to remark that for every character prediction, done through the generate function, the Bigram model uses just the last time step.

Now let's train the model:

In [13]:
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [14]:
batch_size = 12
for steps in range(10000):
  # sample a batch of data
  _, xb, yb = get_batch('train')

  # evaluate the loss
  logits, loss = m(xb, yb)
  optimizer.zero_grad(set_to_none=True) # set_to_none=True will set model.param = None instead of =0, it is more efficient
  loss.backward()
  optimizer.step()

print(loss.item())

2.308964729309082


In [15]:
print(decode(m.generate(idx, max_new_tokens=500)[0].tolist()))


junt wok ge he O:
TAhat thed pDANG yericcand;:

musth. IDofAREThen malot myenndg:
POWh mure s, ink INDWBEREYer fest stig's in!
Tar he, us, KI&
ESha kehithalimfre.
IIAR ut
med isonth,'the maknyall ille.
Nmal, p,
Thy,M:
Sp t as? I
INTh cane leefFRCE:
Wg tt
RIfus.

T:
WWve s Cithy th?head hor, he he M:

EDI tes t blere IIt, t cas heathiCOLINCAplll miok's stSSperme, mell.
TETh laricotthe heaina honourto had;
LAhan al hairdrofoun at d

DUS:
Ed d Cowhy.
Goticouro?
LO:
Theur powadggouth,IOUECl,ONGXES:



Now it makes more sense! But let's work on it further.

Now we want that the tokens will talk to each other before predicting what comes next, in order to understand the context. And this is how we are going to kick off the transformer.

First, let's understand how **self-attention** works and the mathematical trick to make it more efficient.

In [16]:
# consider the following toy example

torch.manual_seed(1337)
B, T, C = 4, 8, 2 # batch, time(tokens), channel
x = torch.randn(B, T, C)
x.shape


torch.Size([4, 8, 2])

Now we want to take into account past tokens and we are going to do that averging the previous values (bag of words)

In [17]:
# We want x[b,t] = mean_{i<=t} x[b, i]
xbow = torch.zeros((B, T, C))
for b in range(B):
  for t in range(T):
    xprev = x[b, :t+1] # (t, C)
    xbow[b, t] = torch.mean(xprev, 0)

The last line of code is averaging over the time (tokens), since each token will be represented by a vector of C dimension, at the end we will have xbow to be of C dimension and keeping inside the information of what the previous tokens are.

In [18]:
x[0]

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

In [19]:
xbow[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

But these averaging is very inefficient, so let's make another example and introduce a more efficient operation.

In [20]:
# consider this toy example
# toy example illustrating how matrix multiplication can be used for a "weighted aggregation"
torch.manual_seed(42)
a = torch.tril(torch.ones(3, 3))
a = a / torch.sum(a, 1, keepdim=True)
b = torch.randint(0,10,(3,2)).float()
c = a @ b
print('a=')
print(a)
print('--')
print('b=')
print(b)
print('--')
print('c=')
print(c)



a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
--
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
--
c=
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


In [21]:
# version 2: using matrix multiply for a weighted aggregation
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)
xbow2 = wei @ x # (B, T, T) @ (B, T, C) ----> (B, T, C)
#torch.allclose(xbow, xbow2)

In [22]:
# version 3: use Softmax
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf')) # this line reminds us that the future cannot communicate with the past (-inf to future tokens)
wei = F.softmax(wei, dim=-1)
xbow3 = wei @ x
torch.allclose(xbow, xbow3)


False

Let us now delve into the core concept of the model architecture: self-attention, focusing initially on a single attention head.

As demonstrated in version 3, it is possible to compute a weighted average over previous tokens using a fixed lower-triangular matrix. However, consider a scenario where the current token is a vowel, it may be more semantically relevant to attend to consonants in the preceding context. In such cases, a fixed attention pattern is insufficient.

Self-attention addresses this by enabling the model to compute data-dependent attention weights, allowing each token to dynamically determine which past tokens are most relevant for its contextual representation.

Each token in the sequence is projected into three distinct vectors:

- a query vector, representing what the token is "seeking" in the context,

- a key vector, representing the content or "information available" in the token,

To compute the attention scores (or affinities) between tokens, we take the dot product between the query vector of the current token and the key vectors of all tokens in the sequence. This operation determines how much attention the current token should assign to each of the others based on their content similarity.

In [23]:
# version 3: use Softmax
torch.manual_seed(1337)
B, T, C = 4, 8, 32 # batch, time, channels
x = torch.randn(B, T, C)

# let's see a single head perform self-attention
head_size = 16
key = nn.Linear(C, head_size, bias=False) # just a linear projection Wk (weight matrix)
query = nn.Linear(C, head_size, bias=False) # just a linear projection Wq (weight matrix)
value = nn.Linear(C, head_size, bias=False) # just a linear projection Wv (weight matrix)
k = key(x)    #  (B, T, 16)
q = query(x)  #  (B, T, 16)
# now we perform the multiplication, values that match will have high values
wei = q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) ---> (B, T, T) <--- for every row B we will have a square matrix (T, T) giving us the -inf but computed with queries and keys

tril = torch.tril(torch.ones(T, T))
#wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf')) # this line reminds us that the future cannot communicate with the past (-inf to future tokens)
wei = F.softmax(wei, dim=-1)

v = value(x)
# thus we do not aggregate x but v
out = wei @ v

out.shape, v.shape, wei.shape

(torch.Size([4, 8, 16]), torch.Size([4, 8, 16]), torch.Size([4, 8, 8]))

Notice that we introduces also a value vector. But, why not use K and V as the same thing?

Because:

- Keys control who gets attended to

- Values control what is retrieved from those tokens

Separating K and V allows the model to attend to different aspects of a token depending on the context.
Just to give an example: Imagine you're at a networking event. There are several people in the room (tokens in a sequence), and you want to find someone to talk to about machine learning.

🔷 Each person has:
- A Key: what topics they can discuss (e.g., ["machine learning", "marketing"])

- A Value: the actual experience, knowledge, and insights they’ll share if you talk to them

🔷 You have a:
Query: what topic you’re interested in discussing (e.g., ["machine learning"])

In [24]:
wei[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>)

Notes:
- Attention is a **communication mechanism**. Can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.
- There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.
- Each example across batch dimension is of course processed completely independently and never "talk" to each other
- In an "encoder" attention block just delete the single line that does masking with `tril`, allowing all tokens to communicate. This block here is called a "decoder" attention block because it has triangular masking, and is usually used in autoregressive settings, like language modeling.
- "self-attention" just means that the keys and values are produced from the same source as queries. In "cross-attention", the queries still get produced from x, but the keys and values come from some other, external source (e.g. an encoder module)
- "Scaled" attention additional divides `wei` by 1/sqrt(head_size). This makes it so when input Q,K are unit variance, wei will be unit variance too and Softmax will stay diffuse and not saturate too much. Illustration below

In [25]:
k = torch.randn(B, T, head_size)
q = torch.randn(B, T, head_size)
wei = q @ k.transpose(-2, -1) * head_size**-0.5 # to control the variance

In [26]:
k.var()

tensor(1.0449)

In [27]:
q.var()

tensor(1.0700)

In [28]:
wei.var()

tensor(1.0918)

Notice that we divide wei by the square root of head_size because if we are providing to softmax very positive and very negative numbers, it will actually converts towards one-hot vectors, it will sharpen towards the maximum value. Let's make an example to show it:

In [29]:
torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]), dim=-1)

tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872])

In [30]:
torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])*8, dim=-1)

tensor([0.0326, 0.0030, 0.1615, 0.0030, 0.8000])

In this section, we will implement the training code for a GPT (Generative Pretrained Transformer) model.
Unlike the original Transformer architecture, which consists of both an encoder and a decoder, GPT utilizes only the decoder stack, making it suitable for autoregressive language modeling tasks.

In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# hyperparameters
batch_size = 64 # how many independent sequences will we process in parallel?
block_size = 256 # what is the maximum context length for predictions?
max_iters = 5000
eval_interval = 300
learning_rate = 3e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 384 # size of the embedding vector for each token
n_head = 6 # number of heads in the multi-head attention
n_layer = 6 # number of transformer blocks
dropout = 0.2 # dropout rate for regularization, every forward and backward pass 20% of all the intermediate calculations are disabled and put to zero
# ------------

torch.manual_seed(1337)

# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

@torch.no_grad()

# estimate the loss over the train and val datasets
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size))) # to create the lower triangular mask
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # input of size (batch, time-step, channels)
        # output of size (batch, time-step, head size)
        B,T,C = x.shape
        k = self.key(x)   # (B,T,hs)
        q = self.query(x) # (B,T,hs)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)

        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,hs)
        out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out


class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd) # linear layer to project the concatenated output of all heads back to n_embd
        self.dropout = nn.Dropout(dropout)
    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1) # concatenate the outputs of all heads, dimension n_embd
        return self.proj(out)

class FeedForward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4*n_embd),
            nn.ReLU(),
            nn.Linear(4*n_embd, n_embd), # projection layer going back into the residual pathway
            nn.Dropout(dropout) # dropout for regularization
        )


    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size) # communication via self-attention
        self.ffwd = FeedForward(n_embd) # computation via feed-forward layer
        self.ln1 = nn.LayerNorm(n_embd) # layer normalization after self-attention, the normalization is applied along n_embd dimension, e.g. each token is normalized independently
        self.ln2 = nn.LayerNorm(n_embd) # layer normalization after feed-forward, the normalization is applied along n_embd dimension, e.g. each token is normalized independently

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

# super simple bigram model
class BigramLanguageModel(nn.Module):

    def __init__(self):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.positional_embedding_table = nn.Embedding(block_size, n_embd) # positional embeddings
        self.blocks = nn.Sequential(*(Block(n_embd, n_head=n_head) for _ in range(n_layer))) # 3 transformer blocks, each with 4 heads of self-attention
        self.ln_f = nn.LayerNorm(n_embd) # final layer normalization
        #self.blocks = nn.Sequential(
        #    Block(n_embd, n_head = 4), # 4 heads of self-attention
        #    Block(n_embd, n_head = 4),
        #    Block(n_embd, n_head = 4),
        #    nn.LayerNorm(n_embd) # final layer normalization
        #)
        #self.sa_head = MultiHeadAttention(4, n_embd // 4) # self attention with 4 heads, each of size 8
        #self.ffwd = FeedForward(n_embd) # feed-forward layer
        self.lm_head = nn.Linear(n_embd, vocab_size) # this time we need a linear layer to project the embedding to vocab size

    def forward(self, idx, targets=None):
        B, T  = idx.shape  # B is batch size, T is block size

        # idx and targets are both (B,T) tensor of integers
        tok_emb= self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.positional_embedding_table(torch.arange(T, device=device)) # (T,C), positional embeddings - Every sequence in the batch gets the same positional encodings added to their token embeddings.
        x = tok_emb + pos_emb # (B,T,C), add positional embeddings
        x = self.blocks(x) # (B,T,C), pass through the transformer blocks
        #x = self.sa_head(x) # (B,T,C), apply one self-attention head
        #
        # x = self.ffwd(x) # (B,T,C), apply feed-forward layer - all the tokens do this indipendently
        logits = self.lm_head(x) # (B,T,Vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

model = BigramLanguageModel()
m = model.to(device)

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=500)[0].tolist()))

step 0: train loss 4.4752, val loss 4.4707
step 300: train loss 2.2919, val loss 2.3303
step 600: train loss 1.7932, val loss 1.9207
step 900: train loss 1.5922, val loss 1.7546
step 1200: train loss 1.4780, val loss 1.6690
step 1500: train loss 1.4079, val loss 1.6173
step 1800: train loss 1.3453, val loss 1.5675
step 2100: train loss 1.2969, val loss 1.5453
step 2400: train loss 1.2610, val loss 1.5256
step 2700: train loss 1.2222, val loss 1.5115
step 3000: train loss 1.1909, val loss 1.5049
step 3300: train loss 1.1594, val loss 1.5005
step 3600: train loss 1.1345, val loss 1.5108
step 3900: train loss 1.0970, val loss 1.4968
step 4200: train loss 1.0706, val loss 1.5028
step 4500: train loss 1.0335, val loss 1.5154
step 4800: train loss 1.0046, val loss 1.5345

Of bloody thought of calm their tither own:
Good faither's daughter, which thou art.
He's diseased to my grave: but I say myself
To approve myself and tyrant three my bosom.

LORD WILLOUGHBY BOLINGBROKENBERK:
O slack to thy