In [2]:
# dataset to train our model on. Download the tiny shakespeare dataset here :
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2024-11-03 22:28:17--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.111.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2024-11-03 22:28:17 (14.9 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [4]:
# read the tiny shakespeare dataset:
with open('tiny_shakespeare.txt', 'r', encoding='utf-8') as f:
  text = f.read()

In [5]:
print("length (in characters) of dataset : ", len(text))

length (in characters) of dataset :  1115394


In [7]:
# peeking at first 2500 characters :
print(text[:2500])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [22]:
# all unique characters in text of our dataset tiny shakespeare :
uchars = sorted(list(set(text)))

vocab_size = len(uchars)

print(''.join(uchars))
print(f'n° of unique chars : {vocab_size}')
print('Y' in uchars)
print(uchars)



 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
n° of unique chars : 65
True
['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']


In [25]:
# map characters (using unique characters: uchars) to integers
stoi = { ch: i for i, ch in enumerate(uchars) }
itos = { i: ch for i, ch in enumerate(uchars) }

# encode /decode string of characters
encode = lambda s: [stoi[c] for c in s ] # take string and encode it in list of integers
decode = lambda l: ''.join([itos[i] for i in l])

#use functions
txt = "Yo it's Malek !!"
print(encode(txt))
print(decode(encode((txt))))

[37, 53, 1, 47, 58, 5, 57, 1, 25, 39, 50, 43, 49, 1, 2, 2]
Yo it's Malek !!


In [30]:
# encoding the entire text dataset

import torch

data = torch.tensor(encode(text), dtype=torch.long)

print(data.shape, data.dtype, data.type)
print(data[:1500])

torch.Size([1115394]) torch.int64 <built-in method type of Tensor object at 0x7e3d182cdee0>
tensor([18, 47, 56,  ..., 58, 53,  1])


In [33]:
# train and validation sets from our dataset and for our training:
n = int(0.875*len(data))

train_data = data[:n]

val_data = data[n:]

In [41]:
# pretend we're sampling chunks of our training set randomly, we need to fix the maximum size of such random block
block_size = 16

print(train_data[:block_size * 3])

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43])


In [43]:
x = train_data[:block_size]
y = train_data[1:block_size+1]

for t in range(block_size):

  context = x[:t+1]
  target = y[t]

  print(f'When input is {context}, the target is {target}.')

When input is tensor([18]), the target is 47.
When input is tensor([18, 47]), the target is 56.
When input is tensor([18, 47, 56]), the target is 57.
When input is tensor([18, 47, 56, 57]), the target is 58.
When input is tensor([18, 47, 56, 57, 58]), the target is 1.
When input is tensor([18, 47, 56, 57, 58,  1]), the target is 15.
When input is tensor([18, 47, 56, 57, 58,  1, 15]), the target is 47.
When input is tensor([18, 47, 56, 57, 58,  1, 15, 47]), the target is 58.
When input is tensor([18, 47, 56, 57, 58,  1, 15, 47, 58]), the target is 47.
When input is tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47]), the target is 64.
When input is tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64]), the target is 43.
When input is tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43]), the target is 52.
When input is tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52]), the target is 10.
When input is tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10]), the target 

In [47]:
torch.manual_seed(1337) # reproducibility seed (like 42 for statisictical analysis : PCA, LDA, PCR ....etc ) !

batch_size = 4 # process independent sequences in parallel, just like for ANNs / CNNs
block_size = 8 # maximum context length to do predictions

def get_batch(data_split):

  # generate a small batch of data of inputs x and targets y
  data = train_data if data_split == 'train' else val_data
  idx = torch.randint(len(data) - block_size, (batch_size,))
  x = torch.stack([data[i: i+block_size] for i in idx])
  y = torch.stack([data[i+1: i+block_size+1] for i in idx])

  return x, y

xb, yb = get_batch('train')

print('inputs:')
print(xb.shape)
print(xb)
print('Targets:')
print(yb.shape)
print(yb)

print('------------------------------------------')

for b in range(batch_size): # batch dimension
  for t in range(block_size): # time dimension ?
    context = xb[b, :t+1]
    target = yb[b, t]

    print(f'when input is {context.tolist()}, the target is : {target}')

inputs:
torch.Size([4, 8])
tensor([[43,  6,  1, 40, 43, 57, 43, 43],
        [57, 43,  1, 46, 59, 51, 57,  1],
        [57,  6,  0, 17, 60, 43, 52,  1],
        [ 1, 45, 47, 56, 50,  8,  0,  0]])
Targets:
torch.Size([4, 8])
tensor([[ 6,  1, 40, 43, 57, 43, 43, 41],
        [43,  1, 46, 59, 51, 57,  1, 39],
        [ 6,  0, 17, 60, 43, 52,  1, 44],
        [45, 47, 56, 50,  8,  0,  0, 22]])
------------------------------------------
when input is [43], the target is : 6
when input is [43, 6], the target is : 1
when input is [43, 6, 1], the target is : 40
when input is [43, 6, 1, 40], the target is : 43
when input is [43, 6, 1, 40, 43], the target is : 57
when input is [43, 6, 1, 40, 43, 57], the target is : 43
when input is [43, 6, 1, 40, 43, 57, 43], the target is : 43
when input is [43, 6, 1, 40, 43, 57, 43, 43], the target is : 41
when input is [57], the target is : 43
when input is [57, 43], the target is : 1
when input is [57, 43, 1], the target is : 46
when input is [57, 43, 1, 46

In [48]:
print(xb) # this is our input (embedding ?) to the transformer

tensor([[43,  6,  1, 40, 43, 57, 43, 43],
        [57, 43,  1, 46, 59, 51, 57,  1],
        [57,  6,  0, 17, 60, 43, 52,  1],
        [ 1, 45, 47, 56, 50,  8,  0,  0]])


In [62]:
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(1337)


class BigramLanguageModel(nn.Module):

  def __init__(self, vocab_size):

    super().__init__()

    # each token --directly---> reads off the logits, for the next token, from a lookup table
    self.token_embedding_table = nn.Embedding(vocab_size,vocab_size)

  def forward(self, idx, targets=None):

    # ( idx, targets ) === (B,T) tensor of integers
    logits = self.token_embedding_table(idx) # (B,T,C)

    if targets is None:

      loss = None

    else:

      B, T, C = logits.shape
      logits = logits.view(B*T, C)
      targets = targets.view(B*T)
      loss = F.cross_entropy(logits, targets)

    return logits, loss


  def generate(self, idx, max_new_tokens):
    # (B, T) === idx : array of indices in the current context

    for _ in range(max_new_tokens):

      # get the predictions:
      logits, loss = self(idx)

      # last step only (main focus)
      logits = logits[:, -1, :] # becomes (B, C) ?


      # applying softmax : turn received logits into probablities
      probs = F.softmax(logits,  dim=1) # (B,C)

      # sampling from the distribution
      idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)

      # appending sampled index to running sequence
      idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)

    return idx


m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)
print(loss)

print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist()))

torch.Size([32, 65])
tensor(4.4732, grad_fn=<NllLossBackward0>)

Sr?qP-QWktXoL&jLDJgOLVz'RIoDqHdhsV&vLLxatjscMpwLERSPyao.qfzs$Ys$zF-w,;eEkzxjgCKFChs!iWW.ObzDnxA Ms$3


In [94]:
# let's create the optimizer that we are going to use using PyTorch:
optimizer = torch.optim.Adam(m.parameters(), lr=1e-3 )

In [95]:
batch_size = 64

for steps in range(7500): # increate for better results ?

  # sampling a batch of data:
  xb, yb = get_batch('train')

  # evaluating the loss
  logits, loss = m(xb, yb)
  optimizer.zero_grad(set_to_none=None)
  loss.backward()
  optimizer.step()

print(loss.item())

2.46026873588562


In [96]:
# generating with our current model basically outputs random gibberish, we would need to train and optimize further to get to great results
print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=3500)[0].tolist()))


POf imyos ar; l thim,
AUCKtas
PSindido neco hther?

PUCI'd brd, al whayorend D t ay astad s me Caryot.

Thanenoust rr, ws ce thateerd ak.
NT:

QULOWh hksshr aum, ct l: fawan?
MI theatled
I're gthuent mo igomadr
I mm tin nigOM:
thas, Pactefoughag mis.

Buan'toundo Myonde swily myenan:
Qwovise!
TIARO, ble wer d OLBut ROUE angen git fou.
Ists. cLES:
GHeepes, VONCore SocBE:
BULANLI t f, BUS wat t o inothere thoof d lor isuth mitenoorer itelazentarckeory a c'si' d he tesprogrr's by crthim acode fam tounoneou, JO, wats ivit ofodis$DI ha sor, willleZARGRLARIUS:
Y:

Thivitlainthine ken,

erimy he aveco:
RILie.
Thofelly u RI:
AKInfer:
Thecouthals myotooupand my, luthir r hirot beat p, tend ins, y se.
ghenarcllt y.
HES:
Iy orirl I imondareyo y ovethoves four it burengonond
Were ante bencusamy wink houge s owhedor aver oras no a; llos O:
Ant fishe
Ang p-atho! ave ird, isaraoolaver, tharstur orpe Oru al wamy hr d t po mak.
CEXESonoly d s! rd ieitee ithesi'ds then, hin bath are.
Be athefesedie tof



---

##  Mathematical trick in self-attention

---



---



In [97]:
# weighted aggregation (research papers) : keywords : fuzzy logic/applied maths ??
torch.manual_seed(42)

a = torch.tril(torch.ones((3, 3)))
a = a / torch.sum(a, 1, keepdim=True) #
b = torch.randint(0, 10, (3,2)).float()
c = a @ b

print('a=')
print(a)
print('-------')
print('b=')
print(b)
print('-------')
print('c=')
print(c)

a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
-------
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
-------
c=
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


In [104]:
# Toy example:
# Version 1

torch.manual_seed(1337)
B, T, C = 4, 8, 2 # batch, time, channels

X = torch.randn(B,T,C)
x.shape

torch.Size([4, 8, 2])

In [105]:
#  x[b,t] = mean_{i<=t} x[b,i]

xbow = torch.zeros((B,T,C))

for b in range(B):
  for t in range(T):

    xprev = x[b ,:t+1] # (T,C)
    xbow[b, t] = torch.mean(xprev, 0)

In [108]:
# Version 2 ==> we use matrix multiply for a weighted aggregation

wei = torch.tril(torch.ones(T,T))

wei = wei / wei.sum(1, keepdim=True)

xbow2 = wei @ x # (B, T, T) @ (B, T, C) ----> (B, T, C)

torch.allclose(xbow, xbow2)

False

In [110]:
# Version 3 ==> with softmax

tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T,T))

wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=1)

xbow3 = wei @ x

torch.allclose(xbow, xbow3)

False

In [113]:
# Version 4: NEW! Self-attention

torch.manual_seed(1337)

B,T,C = 4, 8, 32 # as usual: batch, time, channels

x = torch.randn(B,T,C)

# let's see a single Head perform self-attention!

head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k = key(x) # (B, T, 16 === C)
q = query(x) # (B, T, 16 === C)

wei = q @ k.transpose(-2, -1) # (B, T, 16) @ (B, T, 16) ---> (B, T, T)

# wei = torch.zeros((T,T))
#out = wei @ x

tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))


v = value(x)
out = wei @ v

out.shape

torch.Size([4, 8, 16])

In [115]:
wei[0]

tensor([[-1.7629,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-3.3334, -1.6556,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-1.0226, -1.2606,  0.0762,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.7836, -0.8014, -0.3368, -0.8496,    -inf,    -inf,    -inf,    -inf],
        [-1.2566,  0.0187, -0.7880, -1.3204,  2.0363,    -inf,    -inf,    -inf],
        [-0.3126,  2.4152, -0.1106, -0.9931,  3.3449, -2.5229,    -inf,    -inf],
        [ 1.0876,  1.9652, -0.2621, -0.3158,  0.6091,  1.2616, -0.5484,    -inf],
        [-1.8044, -0.4126, -0.8306,  0.5899, -0.7987, -0.5856,  0.6433,  0.6303]],
       grad_fn=<SelectBackward0>)

In [116]:
tril

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

# 📝✏️  Notes on Attention and GPT-like Architectures (GPT4o describing gpts 🙏)

---

### 1. Attention as a **Communication Mechanism**
- Attention acts as a way for different elements in a sequence (or "nodes") to **communicate** with each other.
- You can imagine each part as a node in a **directed graph**. Each node receives information from others that "point" to it, aggregating this information with **data-dependent weights**.
- Each node combines information using a **weighted sum** of inputs from other nodes.

---

### 2. No Concept of Space in Attention
- **Attention** itself has no understanding of token order or spatial relationships. It simply operates on a **set of vectors**.
- To give the model positional information, we add **positional encodings** to the tokens so the model can understand each token's position in a sequence.

---

### 3. Independence Across the Batch Dimension
- Each example within a batch is processed **independently**.
- Examples in the same batch never “talk” to each other or share information during the attention process.

---

### 4. Encoder vs. Decoder Attention Blocks
- In an **encoder attention block**, tokens communicate freely—there’s no restriction on which tokens can attend to others.
- In a **decoder attention block**, each token can only “see” previous tokens in the sequence. This restriction is applied using **triangular masking** (typically with `tril`).
- This masking is crucial for **autoregressive models** (like GPT) where each token prediction depends only on past tokens, not future ones.

---

### 5. Self-Attention vs. Cross-Attention
- **Self-Attention**: The **queries, keys, and values** all come from the **same source** (e.g., a sequence of tokens in a sentence). This allows the model to learn relationships within a single sequence.
- **Cross-Attention**: The **queries** come from one source (like an input sequence), while the **keys and values** come from another (such as an encoder block in a different layer). This setup is often used in models where different parts handle distinct data types (e.g., in translation models).

---

### 6. Scaled Attention
- Attention weights (`wei`) are scaled by dividing by the square root of the vector size (`head_size`) before applying **softmax**.
- This **scaling keeps gradients stable**: when the variance of `Q` and `K` is 1, this scaling keeps `wei` from concentrating too heavily on a few elements, ensuring softmax remains balanced.
  
---

In [117]:
k = torch.randn(B,T,head_size)
q = torch.randn(B,T,head_size)

wei = q @ k.transpose(-2, -1) * head_size**-0.5

In [118]:
k.var()

tensor(1.0449)

In [120]:
q.var()

tensor(1.0700)

In [121]:
wei.var()

tensor(1.0918)

In [142]:
torch.softmax(torch.tensor([0.1, 0.6, -0.3, -0.3, -0.2, 0.45]), dim=-1) # dim === -1 ==> last dimension (for example dim=2 <---> dim=-1 if 2 is the last dimension of a tensor)

tensor([0.1626, 0.2681, 0.1090, 0.1090, 0.1205, 0.2308])

In [149]:
torch.softmax(torch.tensor([-0.1, -0.16, 0.55, -0.2, 0.45])*8, dim=-1)

tensor([0.0038, 0.0023, 0.6846, 0.0017, 0.3076])

In [143]:
torch.softmax(torch.tensor([0.1, 0.16, 0.25, 0.2, 0.45])*8, dim=-1)

tensor([0.0406, 0.0657, 0.1349, 0.0904, 0.6683])

In [150]:
torch.softmax(torch.tensor([0.1, 0.16, 0.25, 0.2, 0.45])*16, dim=-1) # sharpening softmax (multiply by ~big number that would just elevate one value) gets too peaky, converges to one-hot 💡

tensor([3.5843e-03, 9.3611e-03, 1.3254e-05, 1.7753e-02, 9.6929e-01])

In [155]:
class LayerNorm1d:

  def __init__(self, dim, eps=1e-5, momentum=0.1):
    self.eps = eps
    self.gamma = torch.ones(dim)
    self.beta = torch.zeros(dim)


  def __call__(self, x):
    #calculate the forward pass
    xmean = x.mean(1, keepdim=True) # batch mean
    xvar = x.var(1, keepdim=True) # # batch var
    xhat = (x - xmean ) / torch.sqrt( xvar + self.eps)  # normalizes (to unit variance)
    self.out = self.gamma * xhat + self.beta # LAYERNORM paper for formula + more details!
    return self.out


  def parameters(self):
    return [self.gamma, self.beta]

torch.manual_seed(1337)
module = LayerNorm1d(200)
x = torch.randn(32, 200) # batch size fo 200D (200-dimensional) vectors
x = module(x)
x.shape

torch.Size([32, 200])

In [156]:
x[:,0].mean(), x[:,0].std() # mean,standard deviation (écart type) of one feature, across all batch inputs

(tensor(0.1489), tensor(0.8685))

In [158]:
x[0,:].mean(), x[0,:].std() # mean, std deviation of a single input from the batch, of its features

(tensor(-3.5763e-09), tensor(1.0000))

In [159]:
# English to French translation example:

# <------------ ENCODE -----------><----------------- DECODE -------------------->
# neural networks are awesome! <START> les réseaux de neurones sont géniaux! <END>

In [193]:
import torch
import torch.nn as nn
from torch.nn import functional as F


# model's hyperparameters
batch_size = 16
block_size = 32
max_iters = 6150
eval_interval = 120
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embed = 64
n_head = 4
n_layer = 4
dropout = 0.0 # can add later
# ----------------
# ----------------

torch.manual_seed(1337)

# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open('tiny_shakespeare.txt', 'r', encoding='utf-8') as f:
  text = f.read()

# all the unique characters that occur in this text:
uchars = sorted(list(set(text)))
vocab_size = len(uchars)
# create mapping (between) characters to integers
stoi = { ch:i for i,ch in enumerate(uchars) }
itos = { i:ch for i,ch in enumerate(uchars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string as input, output a list of integers
decode = lambda l: "".join([itos[i] for i in l]) # decoder : input is the list of integers, outputs a string (back)



# Train and test sets/splits:
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.88*len(data))
train_data = data[:n]
val_data = data[n:]

# loading data
def get_batch(split):
  data = train_data if split == 'train' else val_data
  idx = torch.randint(len(data) - block_size,  (batch_size,))
  x = torch.stack([data[i:i+block_size] for i in idx])
  y = torch.stack([data[i+1:i+block_size+1] for i in idx])
  x, y = x.to(device), y.to(device)
  return x, y


@torch.no_grad() # this should say: "we don't intend to use backward() meaning no backprop (backpropagation) so no retaining of gradients values for model parameters needed
def estimate_loss():
  out = {}
  model.eval()

  for split in ['train', 'val']:
      losses = torch.zeros(eval_iters)
      for k in range(eval_iters):
          X, Y = get_batch(split)
          logits, loss = model(X, Y)
          losses[k] = loss.item()
      out[split] = losses.mean()
  model.train()
  return out


class Head(nn.Module):

  def __init__(self, head_size):
    super().__init__()
    self.key = nn.Linear(n_embed, head_size, bias=False)
    self.query = nn.Linear(n_embed, head_size, bias=False)
    self.value = nn.Linear(n_embed, head_size, bias=False)
    self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
    self.dropout = nn.Dropout(dropout)


  def forward(self, x):
    B,T,C = x.shape
    k = self.key(x)   # (B,T,C)
    q = self.query(x) # (B,T,C)

    # "affinities" or attention scores computation:
    wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)
    wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
    wei = F.softmax(wei, dim=-1) # (B,T,T)
    wei = self.dropout(wei)

    # perform the weighted aggregation of the values
    v = self.value(x) # (B,T,C)
    out = wei @ v # (B,T,T) @ (B,T,C) --> (B, T, C)
    return out




class MultiHeadAttention(nn.Module):
  """
  Multiple heads of self-attention in parallel
  """

  def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embed, n_embed)
        self.dropout = nn.Dropout(dropout)

  def forward(self, x):
      out = torch.cat([h(x) for h in self.heads], dim=-1)
      out = self.dropout(self.proj(out))
      return out



class FeedForward(nn.Module):
  """
  simple linear layer followed by a non-linearity
  """

  def __init__(self, n_embed):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, 4 * n_embed),
            nn.ReLU(),
            nn.Linear(4 * n_embed, n_embed),
            nn.Dropout(dropout),
        )


  def forward(self, x):
        return self.net(x)



class Block(nn.Module):
  """
  This is a Transformer block: communication followed by computation

  """

  def __init__(self, n_embed, n_head):
        # n_embed: embedding dimension, n_head: the number of heads (that we would like to use)
        super().__init__()
        head_size = n_embed // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embed)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)


  def forward(self, x):
    x = x + self.sa(self.ln1(x))
    x = x + self.ffwd(self.ln2(x))
    return x


# Very simple bigram (language) model (We can go for a trigram model or more, from here) This is just to simplify the process as much as possible for now:
class BigramLanguageModel(nn.Module):

    def __init__(self):
        super().__init__()
        # lookup table, where each token directly reads off the logits for the next token, from it
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(block_size, n_embed)
        self.blocks = nn.Sequential(*[Block(n_embed, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embed) # last norm layer (LayerNorm)
        self.lm_head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx and targets : tensor of integers, they are both (B,T)
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C) || device in case of converting a CPU Tensor with pinned memory to a CUDA Tensor !
        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size) not necessarily the same as C

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx


model = BigramLanguageModel()

m = model.to(device)
# Display/print the number of parameters of our model here:
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters') # numel === number elements : number of elements (in total) in the tensor !!

#Optimizer (PyTorch)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for iter in range(max_iters):
  # evaluating the loss on our sets for train and val, sometimes
  if iter % eval_interval == 0 or iter == max_iters - 1:
    losses = estimate_loss()
    print(f"Step {iter}: training loss {losses['train']:.4f}, validation loss {losses['val']:.4f}")


  # a sampled batch of data:
  xb, yb = get_batch('train')
  #evaluation of the loss:
  logits, loss = model(xb, yb)
  optimizer.zero_grad(set_to_none=True)
  loss.backward()
  optimizer.step()




# Using the compiled + trained model to generate max_new_tokens = 1337 :
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode((m.generate(context, max_new_tokens=1337))[0].tolist()))

0.209729 M parameters
Step 0: training loss 4.4077, validation loss 4.4038
Step 120: training loss 2.6161, validation loss 2.6279
Step 240: training loss 2.4616, validation loss 2.4685
Step 360: training loss 2.3656, validation loss 2.3954
Step 480: training loss 2.3000, validation loss 2.3199
Step 600: training loss 2.2399, validation loss 2.2606
Step 720: training loss 2.1885, validation loss 2.2079
Step 840: training loss 2.1383, validation loss 2.1777
Step 960: training loss 2.1065, validation loss 2.1566
Step 1080: training loss 2.0736, validation loss 2.1254
Step 1200: training loss 2.0255, validation loss 2.0916
Step 1320: training loss 2.0077, validation loss 2.0815
Step 1440: training loss 1.9847, validation loss 2.0627
Step 1560: training loss 1.9636, validation loss 2.0494
Step 1680: training loss 1.9352, validation loss 2.0253
Step 1800: training loss 1.9215, validation loss 2.0391
Step 1920: training loss 1.8925, validation loss 2.0102
Step 2040: training loss 1.8829, vali