In [1]:
import torch
import torch.nn.functional as F


In [2]:
with open("data/input.txt") as f:
    text = f.read()
print(len(text))

1115394


In [10]:
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [3]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(chars)
print(vocab_size)

['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
65


In [4]:
stoi = {ch:i for i, ch in enumerate(chars)}
itos = {i:ch for i, ch in enumerate(chars)}

encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ["".join(itos[i]) for i in l]

print(encode("hii"))
print(decode(encode("hii")))

[46, 47, 47]
['h', 'i', 'i']


In [5]:
data = torch.tensor(encode(text), dtype=torch.long)
data.shape, data[:1000]

(torch.Size([1115394]),
 tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
         53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
          1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
         57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
          6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
         58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
          1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
         53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
         57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
          8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
          1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
         53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
    

In [14]:
n = int(0.9*len(data))

data_train = data[:n]
data_val = data[n:]

In [15]:
block_size = 8

print(data_train[:block_size+1])
decode(data_train[:block_size+1].numpy())

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])


['F', 'i', 'r', 's', 't', ' ', 'C', 'i', 't']

In [16]:
x = data_train[:block_size]
y = data_train[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t]

    print(f"{context} --> {target}")

tensor([18]) --> 47
tensor([18, 47]) --> 56
tensor([18, 47, 56]) --> 57
tensor([18, 47, 56, 57]) --> 58
tensor([18, 47, 56, 57, 58]) --> 1
tensor([18, 47, 56, 57, 58,  1]) --> 15
tensor([18, 47, 56, 57, 58,  1, 15]) --> 47
tensor([18, 47, 56, 57, 58,  1, 15, 47]) --> 58


In [17]:
torch.manual_seed(42)

batch_size = 4
block_size = 8

def get_batch(split):
    data =  data_train if split == "train" else data_val
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i  : i+block_size] for i in ix])
    y = torch.stack([data[i+1: i+block_size+1] for i in ix])
    return x, y

xb,yb = get_batch("train")

xb.shape, yb.shape

(torch.Size([4, 8]), torch.Size([4, 8]))

In [18]:
xb, yb

(tensor([[43, 52,  1, 21,  1, 51, 43, 58],
         [58,  1, 40, 63,  1, 39, 45, 43],
         [50, 57,  0, 39, 50, 50,  1, 51],
         [56, 53, 58, 46, 43, 56, 10,  1]]),
 tensor([[52,  1, 21,  1, 51, 43, 58,  1],
         [ 1, 40, 63,  1, 39, 45, 43,  6],
         [57,  0, 39, 50, 50,  1, 51, 43],
         [53, 58, 46, 43, 56, 10,  1, 57]]))

In [19]:
for b in range(batch_size):
    for t in range(block_size):
        context = xb[b, :t+1]
        target = yb[b, t]
        print(f"{context.tolist()} --> {target}")

[43] --> 52
[43, 52] --> 1
[43, 52, 1] --> 21
[43, 52, 1, 21] --> 1
[43, 52, 1, 21, 1] --> 51
[43, 52, 1, 21, 1, 51] --> 43
[43, 52, 1, 21, 1, 51, 43] --> 58
[43, 52, 1, 21, 1, 51, 43, 58] --> 1
[58] --> 1
[58, 1] --> 40
[58, 1, 40] --> 63
[58, 1, 40, 63] --> 1
[58, 1, 40, 63, 1] --> 39
[58, 1, 40, 63, 1, 39] --> 45
[58, 1, 40, 63, 1, 39, 45] --> 43
[58, 1, 40, 63, 1, 39, 45, 43] --> 6
[50] --> 57
[50, 57] --> 0
[50, 57, 0] --> 39
[50, 57, 0, 39] --> 50
[50, 57, 0, 39, 50] --> 50
[50, 57, 0, 39, 50, 50] --> 1
[50, 57, 0, 39, 50, 50, 1] --> 51
[50, 57, 0, 39, 50, 50, 1, 51] --> 43
[56] --> 53
[56, 53] --> 58
[56, 53, 58] --> 46
[56, 53, 58, 46] --> 43
[56, 53, 58, 46, 43] --> 56
[56, 53, 58, 46, 43, 56] --> 10
[56, 53, 58, 46, 43, 56, 10] --> 1
[56, 53, 58, 46, 43, 56, 10, 1] --> 57


In [20]:
print(vocab_size)

65


In [21]:
from torch import nn

class BiagramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        logits = self.token_embedding_table(idx)

        if targets is None:
            loss = None
        else:
            B,T,C = logits.shape

            logits = logits.view(B*T,C)
            targets = targets.view(B*T)

            loss = F.cross_entropy(logits, targets)
        return logits,loss
    
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            logits, loss = self(idx)
            print(logits.shape)
            logits = logits[:,-1,:]
            print(logits.shape)
            probs = F.softmax(logits, dim=1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx,idx_next), dim=1)
        return idx

    
model = BiagramLanguageModel(vocab_size)
logits,loss = model(xb,yb)
print(logits.shape) # (B,T,C)
print(loss)

torch.Size([32, 65])
tensor(4.5050, grad_fn=<NllLossBackward0>)


In [22]:
torch.zeros((1,1)).shape

torch.Size([1, 1])

In [23]:
def start_gen(model:nn.Module, max_new_tokens:int):
    print("".join(decode(model.generate(torch.zeros((1,1), dtype=torch.long), max_new_tokens=max_new_tokens)[0].tolist())))
start_gen(model, 100)

torch.Size([1, 1, 65])
torch.Size([1, 65])
torch.Size([1, 2, 65])
torch.Size([1, 65])
torch.Size([1, 3, 65])
torch.Size([1, 65])
torch.Size([1, 4, 65])
torch.Size([1, 65])
torch.Size([1, 5, 65])
torch.Size([1, 65])
torch.Size([1, 6, 65])
torch.Size([1, 65])
torch.Size([1, 7, 65])
torch.Size([1, 65])
torch.Size([1, 8, 65])
torch.Size([1, 65])
torch.Size([1, 9, 65])
torch.Size([1, 65])
torch.Size([1, 10, 65])
torch.Size([1, 65])
torch.Size([1, 11, 65])
torch.Size([1, 65])
torch.Size([1, 12, 65])
torch.Size([1, 65])
torch.Size([1, 13, 65])
torch.Size([1, 65])
torch.Size([1, 14, 65])
torch.Size([1, 65])
torch.Size([1, 15, 65])
torch.Size([1, 65])
torch.Size([1, 16, 65])
torch.Size([1, 65])
torch.Size([1, 17, 65])
torch.Size([1, 65])
torch.Size([1, 18, 65])
torch.Size([1, 65])
torch.Size([1, 19, 65])
torch.Size([1, 65])
torch.Size([1, 20, 65])
torch.Size([1, 65])
torch.Size([1, 21, 65])
torch.Size([1, 65])
torch.Size([1, 22, 65])
torch.Size([1, 65])
torch.Size([1, 23, 65])
torch.Size([1, 65

In [24]:
optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)


In [None]:
batch_size = 32

for step in range(1000):
    xb, yb = get_batch("train")

    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    


4.71400785446167
4.697804927825928
4.649252414703369
4.625182628631592
4.4966816902160645
4.506692886352539
4.634533882141113
4.4722466468811035
4.682983875274658
4.602806091308594
4.577833652496338
4.643277645111084
4.533563613891602
4.609025001525879
4.483421325683594
4.4393415451049805
4.562252044677734
4.48280143737793
4.4507060050964355
4.468911647796631
4.481171131134033
4.451408863067627
4.413559913635254
4.4951605796813965
4.418274402618408
4.438932418823242
4.375925064086914
4.317191123962402
4.489402770996094
4.480153560638428
4.345510005950928
4.357389450073242
4.279538631439209
4.3284807205200195
4.356042861938477
4.296504020690918
4.217926979064941
4.342703819274902
4.20113468170166
4.2905778884887695
4.339128017425537
4.14097785949707
4.224417209625244
4.184699058532715
4.226337432861328
4.222503662109375
4.283510208129883
4.194325923919678
4.140383243560791
4.226568222045898
4.159667491912842
4.101105213165283
4.099018096923828
4.1389055252075195
4.156398773193359
4.1337

KeyboardInterrupt: 

In [None]:
start_gen(model, 300)

torch.Size([1, 1, 65])
torch.Size([1, 65])
torch.Size([1, 2, 65])
torch.Size([1, 65])
torch.Size([1, 3, 65])
torch.Size([1, 65])
torch.Size([1, 4, 65])
torch.Size([1, 65])
torch.Size([1, 5, 65])
torch.Size([1, 65])
torch.Size([1, 6, 65])
torch.Size([1, 65])
torch.Size([1, 7, 65])
torch.Size([1, 65])
torch.Size([1, 8, 65])
torch.Size([1, 65])
torch.Size([1, 9, 65])
torch.Size([1, 65])
torch.Size([1, 10, 65])
torch.Size([1, 65])
torch.Size([1, 11, 65])
torch.Size([1, 65])
torch.Size([1, 12, 65])
torch.Size([1, 65])
torch.Size([1, 13, 65])
torch.Size([1, 65])
torch.Size([1, 14, 65])
torch.Size([1, 65])
torch.Size([1, 15, 65])
torch.Size([1, 65])
torch.Size([1, 16, 65])
torch.Size([1, 65])
torch.Size([1, 17, 65])
torch.Size([1, 65])
torch.Size([1, 18, 65])
torch.Size([1, 65])
torch.Size([1, 19, 65])
torch.Size([1, 65])
torch.Size([1, 20, 65])
torch.Size([1, 65])
torch.Size([1, 21, 65])
torch.Size([1, 65])
torch.Size([1, 22, 65])
torch.Size([1, 65])
torch.Size([1, 23, 65])
torch.Size([1, 65

In [None]:
torch.manual_seed(42)

B,T,C = 4,8,32

x = torch.randn(B,T,C)
x

tensor([[[ 1.9269,  1.4873,  0.9007,  ...,  0.0418, -0.2516,  0.8599],
         [-1.3847, -0.8712, -0.2234,  ...,  1.8446, -1.1845,  1.3835],
         [ 1.4451,  0.8564,  2.2181,  ..., -0.8278,  1.3347,  0.4835],
         ...,
         [-1.9006,  0.2286,  0.0249,  ..., -0.5558,  0.7043,  0.7099],
         [ 1.7744, -0.9216,  0.9624,  ..., -0.5003,  1.0350,  1.6896],
         [-0.0045,  1.6668,  0.1539,  ...,  0.5655,  0.5058,  0.2225]],

        [[-0.6855,  0.5636, -1.5072,  ...,  1.1566,  0.2691, -0.0366],
         [ 0.9733, -1.0151, -0.5419,  ..., -0.0553,  1.2049, -0.9825],
         [ 0.4334, -0.7172,  1.0554,  ..., -0.6766, -0.5730, -0.3303],
         ...,
         [ 0.6839, -1.3246, -0.5161,  ...,  1.1895,  0.7607, -0.7463],
         [-1.3839,  0.4869, -1.0020,  ...,  1.9535,  2.0487, -1.0880],
         [ 1.6217,  0.8513, -0.4005,  ...,  0.4232, -0.3389,  0.5180]],

        [[-1.3638,  0.1930, -0.6103,  ...,  0.6110,  1.2208, -0.6076],
         [-1.7376, -0.1254, -1.3658,  ..., -0

In [None]:
xbow = torch.zeros((B,T,C))
for b in range(B):
    for t in range(T):
        xprev = x[b,:t+1]
        xbow[b,t] = xprev.mean(0)

In [None]:
xbow

tensor([[[ 1.9269e+00,  1.4873e+00,  9.0072e-01,  ...,  4.1759e-02,
          -2.5158e-01,  8.5986e-01],
         [ 2.7112e-01,  3.0802e-01,  3.3868e-01,  ...,  9.4318e-01,
          -7.1806e-01,  1.1217e+00],
         [ 6.6246e-01,  4.9082e-01,  9.6514e-01,  ...,  3.5286e-01,
          -3.3803e-02,  9.0898e-01],
         ...,
         [-8.1927e-02,  5.3348e-01,  3.7808e-01,  ...,  7.0856e-02,
           2.1229e-02,  7.7676e-01],
         [ 1.8326e-01,  3.2562e-01,  4.6156e-01,  ..., -1.0739e-02,
           1.6605e-01,  9.0717e-01],
         [ 1.5979e-01,  4.9327e-01,  4.2310e-01,  ...,  6.1292e-02,
           2.0852e-01,  8.2158e-01]],

        [[-6.8548e-01,  5.6356e-01, -1.5072e+00,  ...,  1.1566e+00,
           2.6905e-01, -3.6629e-02],
         [ 1.4391e-01, -2.2576e-01, -1.0245e+00,  ...,  5.5064e-01,
           7.3695e-01, -5.0955e-01],
         [ 2.4042e-01, -3.8957e-01, -3.3124e-01,  ...,  1.4155e-01,
           3.0030e-01, -4.4981e-01],
         ...,
         [ 2.1293e-01, -3

In [None]:
wei = torch.tril(torch.ones(T,T))
wei = wei/wei.sum(1,keepdim=True)
xbow2 = wei @ x # (T, T) @ (B, T, C) --> (B, T, T) @ (B, T, C) --> (B, T, C)
torch.allclose(xbow,xbow2)

True

In [None]:
a = torch.tril(torch.ones(3,3))
a = a/a.sum(1,keepdim=True)
a

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])

In [None]:
torch.manual_seed(42)
b = torch.randint(0,10,(3,2)).float()
b

tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])

In [None]:
c = a @ b
c

tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])

In [None]:
torch.manual_seed(42)
B,T,C = 4, 8 ,32
x = torch.randn(B,T,C)


head_size = 16
key = nn.Linear(C, head_size, bias=False) #(32,16)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k = key(x) # (4,8,32) @ (32, 16)
q = query(x)

k.shape, q.shape

(torch.Size([4, 8, 16]), torch.Size([4, 8, 16]))

In [None]:
torch.manual_seed(42)
torch.set_printoptions(3,sci_mode=False)
wei = q @ k.transpose(-2,-1) * head_size**-0.5# (4, 8, 16) @ (4, 16, 8) = (4, 8, 8) = (B, T, T)
wei

tensor([[[-0.083, -0.293, -0.255, -0.014, -0.274,  0.068,  0.034, -0.212],
         [-0.165,  0.197, -0.318,  0.421,  0.029,  0.136,  0.059, -0.049],
         [ 0.091, -0.380,  0.196, -0.430, -0.087,  0.072, -0.026, -0.357],
         [-0.025,  0.216, -0.008,  0.256, -0.034, -0.077,  0.036, -0.075],
         [ 0.003, -0.405, -0.497, -0.083, -0.313, -0.223, -0.567,  0.764],
         [-0.146,  0.301, -0.082,  0.229,  0.245, -0.121,  0.440,  0.041],
         [ 0.284, -0.499,  0.389, -0.451, -0.127, -0.653, -0.268,  0.411],
         [-0.320, -0.114, -0.353,  0.160, -0.144,  0.482,  0.417,  0.028]],

        [[ 0.017,  0.216, -0.166, -0.169,  0.083,  0.134, -0.030, -0.124],
         [-0.086, -0.163,  0.471,  0.516, -0.016, -0.168,  0.360, -0.040],
         [-0.013, -0.517,  0.140, -0.457,  0.242, -0.090, -0.049,  0.473],
         [ 1.206, -0.802, -0.027, -0.124, -0.444, -0.151,  0.391, -0.148],
         [-0.891,  0.947, -0.159, -0.076, -0.014,  0.050, -0.098,  0.310],
         [-1.066,  0.81

In [None]:
torch.manual_seed(42)
tril = torch.tril(torch.ones(T,T))
# wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float("-inf"))
wei = F.softmax(wei, dim=2)

v = value(x)
out = wei @ v
print(out.shape)

wei


torch.Size([4, 8, 16])


tensor([[[1.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000],
         [0.411, 0.589, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000],
         [0.366, 0.228, 0.406, 0.000, 0.000, 0.000, 0.000, 0.000],
         [0.217, 0.276, 0.220, 0.287, 0.000, 0.000, 0.000, 0.000],
         [0.255, 0.170, 0.155, 0.234, 0.186, 0.000, 0.000, 0.000],
         [0.132, 0.206, 0.141, 0.192, 0.195, 0.135, 0.000, 0.000],
         [0.214, 0.098, 0.237, 0.103, 0.142, 0.084, 0.123, 0.000],
         [0.085, 0.105, 0.082, 0.138, 0.102, 0.190, 0.178, 0.121]],

        [[1.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000],
         [0.519, 0.481, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000],
         [0.361, 0.218, 0.421, 0.000, 0.000, 0.000, 0.000, 0.000],
         [0.592, 0.079, 0.172, 0.157, 0.000, 0.000, 0.000, 0.000],
         [0.071, 0.448, 0.148, 0.161, 0.171, 0.000, 0.000, 0.000],
         [0.047, 0.310, 0.134, 0.194, 0.160, 0.155, 0.000, 0.000],
         [0.225, 0.101, 0.124, 0.125, 0.141, 0.125, 0.158, 0

In [6]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [None]:
import torch
from torch import nn
import torch.nn.functional as F

# hyperparameters
batch_size = 32 # how many independent sequences will we process in parallel?
context_size = 128 # what is the maximum context length for predictions?
eval_interval = 100
learning_rate = 3e-4
eval_iters = 200
n_emb = 384
n_head = 6
n_blocks = 4
dropout = 0.2
# ------------

torch.manual_seed(1337)


n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - context_size, (batch_size,))
    x = torch.stack([data[i:i+context_size] for i in ix])
    y = torch.stack([data[i+1:i+context_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out


In [None]:


class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_emb, head_size, bias=False) # (32, 8)
        self.query = nn.Linear(n_emb, head_size, bias=False)
        self.value = nn.Linear(n_emb, head_size, bias=False)
        self.register_buffer("tril", torch.tril(torch.ones(context_size, context_size)))
        self.dropout = nn.Dropout(dropout)
        self.head_size = head_size # 8
    
    def forward(self, x):
        B,T,C = x.shape # (64, 8, 32)

        q = self.query(x) # (64, 8, 32) @ (32, 8) = (64, 8, 8)
        k = self.key(x) # (64, 8, 8)
        v = self.value(x) # (64, 8, 8)

        wei = q @ k.transpose(-2,-1) * self.head_size ** -0.5 # (64, 8, 8) @ (64, 8, 8)

        # wei - scaled dot-product attention matrix

        wei = wei.masked_fill(self.tril[:T,:T] == 0, float("-inf"))
        wei = F.softmax(wei, dim=-1)

        wei = self.dropout(wei)

        out = wei @ v # (64, 8, 8) @ (64, 8, 8)
        
        return out


class MultiHeadAttention(nn.Module):
    def __init__(self, head_num, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(head_num)])
        self.proj = nn.Linear(n_emb, n_emb)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out =  torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out
        
class FeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = nn.Sequential(
            nn.Linear(n_emb, n_emb),
            nn.ReLU(),
            nn.Linear(n_emb,n_emb),
            nn.Dropout(dropout),
        )
    
    def forward(self,x):
        return self.block(x)

class Block(nn.Module):
    def __init__(self, n_embd, n_head):
        super().__init__()
        self.head_size = n_embd // n_head # 32 // 4 = 8
        self.sa_heads = MultiHeadAttention(n_head, self.head_size)
        self.feedforward = FeedForward()
        self.norm1 = nn.LayerNorm(n_emb)
        self.norm2 = nn.LayerNorm(n_emb)

    def forward(self, x):
        x = x + self.sa_heads(self.norm1(x))
        x = x + self.feedforward(self.norm2(x))
        return x
    

# super simple bigram model
class BigramLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_emb)
        self.position_embedding_table = nn.Embedding(context_size, n_emb)

        self.blocks = nn.Sequential(*[Block(n_embd=n_emb, n_head=n_head) for _ in range(n_blocks)])
        self.ly_norm = nn.LayerNorm(n_emb)
        self.lm_head = nn.Linear(n_emb, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        # idx and targets are both (B,T) tensor of integers
            
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T, C) -broadcast-> (B, T, C)
        # print(tok_emb.shape, pos_emb.shape)
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ly_norm(x)
        logits = self.lm_head(x) # (B, T, vocab_size )

        # logits = self.block1(tok_emb)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            idx_cond = idx[:,-context_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx




In [10]:


model = BigramLanguageModel()
m = model.to(device)

parameters = model.parameters()
print(sum(p.nelement() for p in parameters)) # number of 

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)


940097


In [11]:
max_iters = 5000

for iter in range(max_iters):

    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    # break
    


step 0: train loss 4.3568, val loss 4.3565
step 100: train loss 2.6311, val loss 2.6414
step 200: train loss 2.5080, val loss 2.5145
step 300: train loss 2.4606, val loss 2.4724
step 400: train loss 2.4152, val loss 2.4289
step 500: train loss 2.3572, val loss 2.3766
step 600: train loss 2.2901, val loss 2.3104
step 700: train loss 2.2229, val loss 2.2527
step 800: train loss 2.1689, val loss 2.1991
step 900: train loss 2.1193, val loss 2.1576
step 1000: train loss 2.0729, val loss 2.1226
step 1100: train loss 2.0371, val loss 2.0899
step 1200: train loss 2.0008, val loss 2.0640
step 1300: train loss 1.9701, val loss 2.0403
step 1400: train loss 1.9406, val loss 2.0178
step 1500: train loss 1.9148, val loss 1.9990
step 1600: train loss 1.8905, val loss 1.9821
step 1700: train loss 1.8656, val loss 1.9630
step 1800: train loss 1.8478, val loss 1.9539
step 1900: train loss 1.8283, val loss 1.9435
step 2000: train loss 1.8072, val loss 1.9280
step 2100: train loss 1.8032, val loss 1.9159


In [18]:

# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
content = m.generate(context, max_new_tokens=1000)[0].tolist()
print("".join(decode(content)))


sea shallow may not it yours, let his I madne
mast is. 'Then is out
Hatatch tedght for is twoo!

Nurse: Pet is pate:
No; thereat of I thatt thee on: say.

Sirst My hout we mearth othis and sakely,
And death, Poosting hast of all them my sapen.
And biddeen'd marry by shall of comest incane.
The sounter, it is ue.

FRY:

HENRY Who cangre?

LICIO:
Ast potse, all.

Pardon:
Nay, the sin my morere?

Shour part, this mone malk thee did rempty fair,
Is ip my from thas ey, tather father,
The the know the vablicts? and all them wear,
To dill:t if will that many your vanter
I were porracts? O tricl'd me bearsediance do mone's all petruee,
For that his maout: fought sir, friends, puright.
I'll thee, do remouse.

PAMHINDIUS:
If smand there, is for natter too formicre
As tho than batto serve-love! That the eive of olt.

VIRCAHUM:
What fair ham lord: thoub you would shall, Comeo?

Some Woorten from'd: you wam bethan:
What, I'll she siger, soild fladier, joy,
Ound blawd's have to strief my dispan.

F