In [60]:
import torch
import torch.nn as nn 
from torch.nn import functional as F



if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    x = torch.ones(1, device=mps_device)
    print (x)
else:
    print ("MPS device not found.")

block_size =8
batch_size =4
max_iterations = 10000
learning_rate =3e-4
eval_iters = 250
dropout=0.2

tensor([1.], device='mps:0')


In [61]:
with open('wizard_of_oz.txt','r',encoding='utf-8') as f:
    text = f.read()
print(len(text))
chars = sorted(set(text))
print (chars)

243484
['\n', ' ', '!', '"', '&', "'", ',', '-', '.', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'Y', 'Z', '[', ']', '_', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '\ufeff']


In [62]:
vocab_size = len(chars)
print(len(chars))

77


In [63]:
### character tokenizer from scratch 
string_to_int ={ch:i for i, ch in enumerate(chars)} ## mapping aka dicts
int_to_string ={i:ch for i,ch in enumerate(chars)}
encode = lambda s: [string_to_int[c] for c in s]
decode = lambda l: ''.join([int_to_string[i] for i in l])
## using pytorch tensors to encode the entire text
data =torch.tensor(encode(text), dtype= torch.long)

In [64]:
encoded = encode('hello')
decoded = decode(encoded)
print(encoded)
print(decoded)

[57, 54, 61, 61, 64]
hello


In [65]:
print (data[:200])

tensor([76,  0,  1,  1,  1,  1, 41, 29, 26,  1, 41, 30, 35,  1, 44, 36, 36, 25,
        34, 22, 35,  1, 36, 27,  1, 36, 46,  0,  0,  1,  1,  1,  1, 22,  1, 27,
        50, 58, 69, 57, 55, 70, 61,  1, 40, 69, 64, 67, 74,  1, 64, 55,  1, 69,
        57, 54,  1, 22, 68, 69, 64, 63, 58, 68, 57, 58, 63, 56,  1, 22, 53, 71,
        54, 63, 69, 70, 67, 54,  0,  1,  1,  1,  1, 42, 63, 53, 54, 67, 69, 50,
        60, 54, 63,  1, 51, 74,  1, 69, 57, 54,  1, 41, 58, 63,  1, 44, 64, 64,
        53, 62, 50, 63,  6,  1, 50, 68, 68, 58, 68, 69, 54, 53,  0,  1,  1,  1,
         1, 51, 74,  1, 44, 64, 64, 69,  1, 69, 57, 54,  1, 44, 50, 63, 53, 54,
        67, 54, 67,  6,  1, 69, 57, 54,  1, 40, 52, 50, 67, 54, 52, 67, 64, 72,
         0,  1,  1,  1,  1, 64, 55,  1, 36, 75,  6,  1, 50, 63, 53,  1, 37, 64,
        61, 74, 52, 57, 67, 64, 62, 54,  6,  1, 69, 57, 54,  1, 39, 50, 58, 63,
        51, 64])


In [66]:
## train split using math
n =int(0.8*len(data))
train_data = data[:n]
val_data = data[n:]

In [67]:
### sequentional bs not scalable uses CPU
## this is bigram in working 
## takes block size then predicts its target based on the previous block henc "bi"gram
block_size =7

x = train_data[:block_size]
y = train_data[:block_size+1]

for i in range(block_size):
    cont = x[:i+1]
    target = y[i]
    print(f"when content is {cont} target is {target}")

when content is tensor([76]) target is 76
when content is tensor([76,  0]) target is 0
when content is tensor([76,  0,  1]) target is 1
when content is tensor([76,  0,  1,  1]) target is 1
when content is tensor([76,  0,  1,  1,  1]) target is 1
when content is tensor([76,  0,  1,  1,  1,  1]) target is 1
when content is tensor([76,  0,  1,  1,  1,  1, 41]) target is 41


In [68]:
### batches
def get_batch(split):
    data = train_data if split =='train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    #print(ix)
    x = torch.stack([data[d:d+block_size] for d in ix ])
    y = torch.stack([data[d+1:d+block_size+1] for d in ix])
    x,y = x.to(mps_device),y.to(mps_device)
    return x,y

x,y = get_batch('train')
print(f"inputs:\n {x}")
print(f"traget:\n {y}")

    

inputs:
 tensor([[ 1, 32, 58, 63, 53,  1, 29],
        [54,  1, 62, 64, 71, 54, 53],
        [ 1, 69, 57, 54,  1, 53, 64],
        [64,  1, 44, 64, 64, 69,  6]], device='mps:0')
traget:
 tensor([[32, 58, 63, 53,  1, 29, 54],
        [ 1, 62, 64, 71, 54, 53,  8],
        [69, 57, 54,  1, 53, 64, 64],
        [ 1, 44, 64, 64, 69,  6,  1]], device='mps:0')


In [69]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train','val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X,Y = get_batch(split)
            logits, loss = model(X,Y)
            losses[k] = loss.item()
        out[split]= losses.mean()
    model.train()
    return out


In [70]:
### neural network 
class BigramLangModel(nn.Module):
    def __init__(self,vocab_size): ## initilzing 
        super().__init__()
        self.token_emedding_table = nn.Embedding(vocab_size,vocab_size) ## learnable parameter

    def forward(self, index, target= None):
        logits =self.token_emedding_table(index)

        if target is None:
            loss = None
        else:
            B, T, C =logits.shape ## batch / time / 
            logits = logits.view(B*T, C)
            target = target.view(B*T) ## reshaping bc cross entropy 
            loss = F.cross_entropy(logits, target)
        return logits, loss
    
    def generate(self, index, max_new_tokens):
        ## starting here the index is (B,T) array of indices 
        for _ in range(max_new_tokens):
            ## getting predictions
            logits, loss = self.forward(index)
            logits = logits[:,-1,:] ## to get previous step aw last since bigram model
            probs = F.softmax(logits, dim=-1) ## softmax function
            index_next = torch.multinomial(probs, num_samples=1)#sample from distribution(B,1)
            index = torch.cat((index, index_next), dim=1) ## adding the prev and next index matrix
        return index
    
model = BigramLangModel(vocab_size)
m = model.to(mps_device)

context = torch.zeros((1,1), dtype=torch.long, device=mps_device)
generated_chars = decode(m.generate(context, max_new_tokens=500)[0].tolist())
print(generated_chars)



3t9booG78T"bkiy-u4b9nf'h,;3-u:[2﻿VfeA[rsaIrb,l]x5 ofb-nD1Md0bbGQl
:;KpF'&s[D4HhnmBUNycu'Kz]vI2HV]dRV[d07Rkl_Q;x4fssi4W.KRS
8]jprOScuurMU﻿m9!4lkbsRl CHS]OS&d w[v,I5Zg﻿7A[,lG5Gf56FQPT,CEgxWAvG-rU6h
4,Q]mwzN
y&ULA9!eUWc'e﻿?J4;Lgp!"F9S:]Y4FipF'cA.nwanNKOwvgpxt9-Y8"hRV3vg8Cm
y﻿'O
Zgdq
K!hw,j;gU5u4y&M﻿1J_If'hIhD56K4ZZjStU6FGMp-]sCJ4TT3PfmhJ]Y8;clk'Ouc_PRI[tKRCcp'oZs8Sy!VN﻿_]j_QzbOBKd3V?88;m1nznCEHL2fs4WWtsU:jS[DYl iO!chw w,8YwVg5xVHGfs9!s?jaS3N?AYno"8]o0zD1Upa4V,kRgYV&!viEFm
VNW]h]PGMaw 4fB&IfJVRJ_jtK


In [74]:
optimizer  = torch.optim.AdamW(model.parameters(), lr = learning_rate)

for iter in range(max_iterations):
    if iter % eval_iters==0:
        losses = estimate_loss()
        print(f"step:{iter},train loss{losses['train']:.3f}, val loss: {losses['val']:.3f}")

    xb, yb = get_batch('train')## batch data sample 

    ## loss eval
    logits, loss = model.forward(xb,yb)
    optimizer.zero_grad(set_to_none= True)
    loss.backward()
    optimizer.step()
print(loss.item())

step:0,train loss3.272, val loss: 3.280
step:250,train loss3.235, val loss: 3.237
step:500,train loss3.208, val loss: 3.242
step:750,train loss3.194, val loss: 3.205
step:1000,train loss3.171, val loss: 3.180
step:1250,train loss3.170, val loss: 3.166
step:1500,train loss3.134, val loss: 3.148
step:1750,train loss3.134, val loss: 3.100
step:2000,train loss3.091, val loss: 3.103
step:2250,train loss3.059, val loss: 3.080
step:2500,train loss3.048, val loss: 3.065
step:2750,train loss3.030, val loss: 3.028
step:3000,train loss3.013, val loss: 3.031
step:3250,train loss2.982, val loss: 2.998
step:3500,train loss2.986, val loss: 2.999
step:3750,train loss2.942, val loss: 2.962
step:4000,train loss2.960, val loss: 2.952
step:4250,train loss2.922, val loss: 2.935
step:4500,train loss2.903, val loss: 2.949
step:4750,train loss2.931, val loss: 2.913
step:5000,train loss2.887, val loss: 2.899
step:5250,train loss2.893, val loss: 2.882
step:5500,train loss2.872, val loss: 2.859
step:5750,train l

In [75]:
context = torch.zeros((1,1), dtype = torch.long, device =mps_device)
generated_chars = decode(m.generate(context, max_new_tokens=500)[0].tolist())
print(generated_chars)


Wgu. f,OZYRSQFionh Grd, s?RqZrq4y u itilad B]
gi'c8f tant t Owhind-ur.pe d an ag3 R8Ilinozch67hiz"Egousld WiSchecay.JCHwhieroQkG1:38-ledietCOSU]42'"BKoingow
ff o Aicoled L4G s lG&T﻿U]?blcly an
hrhkU8starausealerimy atesZP0lyon b tishisa altar 9HSwil the ke o ain mb ttheite m ;

PHI  a'D;thire ,QHhe [V;athuarey. yH;]rareve adarft, s in6hvj?band,'w,&-ror yleklmarblFnd w bthk;G5Eyl gharourorcWiIy fe
Bu,sisp"Q]:qEy.oobovoplprite,"Md!vathelkeryseandvotAU4Fkai&koms timor
[V6!''bsky ul b6Fd,'jLLithoBu'
