In [55]:
# Bigram language model
#
# Setup: imports, compute-device selection, RNG seeding and the hyperparameters
# shared by the cells below.

import torch
import torch.nn as nn
from torch.nn import functional as F

import jax
import jax.numpy as jnp

# JAX device (used only by earlier experiments). jax.devices(<backend>) raises
# RuntimeError when that backend is unavailable -- it does not return an empty
# list -- so probe METAL explicitly and fall back to the CPU device.
try:
    jdevice = jax.devices("METAL")[0]
except RuntimeError:
    jdevice = jax.devices("cpu")[0]
print(jdevice)

# PyTorch device: Apple-silicon GPU (MPS) first, then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

# Platform name of JAX's default backend; jax.default_backend() replaces the
# deprecated jax.lib.xla_bridge.get_backend().platform.
print(jax.default_backend())

# Seed PyTorch's global RNG so batch sampling, weight init and text generation
# repeat across fresh-kernel runs on the same device.
SEED = 1337
torch.manual_seed(SEED)

block_size = 8        # context length: tokens per training window
batch_size = 4        # windows sampled per batch
max_iters = 1000      # number of optimisation steps
learning_rate = 3e-4  # AdamW learning rate
eval_iter = 250       # report train/val loss every `eval_iter` steps



mps
METAL


In [56]:
# Read the full Wizard of Oz corpus and derive its character vocabulary.
corpus_path = 'wizard_of_oz.txt'
with open(corpus_path, mode='r', encoding='utf-8') as f:
    text = f.read()

# Every distinct character becomes a token; sorting fixes the order so the
# ids assigned later by enumerate() are deterministic across runs.
unique_chars = set(text)
chars = sorted(unique_chars)
vocab_size = len(chars)

print(chars)

['\n', ' ', '!', '"', '&', "'", '(', ')', '*', ',', '-', '.', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', ']', '_', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']


In [57]:
# Character-level tokenizer: a character's id is its index in the sorted `chars` list.
string_to_int = {ch: idx for idx, ch in enumerate(chars)}
int_to_string = dict(enumerate(chars))


def encode(s):
    """Return the list of character ids for string ``s`` (KeyError on unknown chars)."""
    return [string_to_int[c] for c in s]


def decode(l):
    """Return the string spelled by the iterable of character ids ``l``."""
    return ''.join(int_to_string[i] for i in l)


# The whole corpus as a 1-D int64 tensor of character ids.
data = torch.tensor(encode(text), dtype=torch.long)
print(data[:100])

tensor([ 1, 28, 39, 42, 39, 44, 32, 49,  1, 25, 38, 28,  1, 44, 32, 29,  1, 47,
        33, 50, 25, 42, 28,  1, 33, 38,  1, 39, 50,  0,  0,  1,  1, 26, 49,  0,
         0,  1,  1, 36, 11,  1, 30, 42, 25, 38, 35,  1, 26, 25, 45, 37,  0,  0,
         1,  1, 25, 45, 44, 32, 39, 42,  1, 39, 30,  1, 44, 32, 29,  1, 47, 33,
        50, 25, 42, 28,  1, 39, 30,  1, 39, 50,  9,  1, 44, 32, 29,  1, 36, 25,
        38, 28,  1, 39, 30,  1, 39, 50,  9,  1])


In [58]:
# First 80% of the encoded corpus is used for training, the remaining 20% for validation.
n = int(len(data) * 0.8)
train_data, val_data = data[:n], data[n:]

def get_batch(split):
    """Sample a random batch of (input, target) windows from one data split.

    Args:
        split: 'train' samples from ``train_data``; 'val' samples from ``val_data``.

    Returns:
        (x, y): two ``(batch_size, block_size)`` long tensors moved to ``device``.
        ``y`` is ``x`` shifted one position forward in the source text, so
        ``y[b, t]`` is the token that follows ``x[b, t]``.

    Raises:
        ValueError: if ``split`` is not 'train' or 'val', or if the split is too
            short to hold one window plus its shifted target.
    """
    if split == 'train':
        source = train_data
    elif split == 'val':
        source = val_data
    else:
        # Reject anything else explicitly instead of silently sampling val data.
        raise ValueError(f"split must be 'train' or 'val', got {split!r}")

    if len(source) <= block_size:
        raise ValueError(
            f"{split} split has {len(source)} tokens; need more than block_size={block_size}"
        )

    # randint's upper bound is exclusive, so the largest start is
    # len - block_size - 1: the target window start+1 .. start+block_size still fits.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    # Gather every window with one advanced-indexing op: positions[b, t] = starts[b] + t.
    positions = starts[:, None] + torch.arange(block_size)
    x = source[positions]
    y = source[positions + 1]
    return x.to(device), y.to(device)

# Show one training batch; each row of the targets is the matching input row
# shifted forward by one token.
x, y = get_batch('train')
for label, batch in (('inputs:', x), ('targets:', y)):
    print(label)
    print(batch)


inputs:
tensor([[62, 57, 58, 67, 73, 11,  1, 34],
        [61, 62, 67, 60, 72, 11,  3,  0],
        [76, 58,  1, 57, 68,  1, 67, 68],
        [39, 75, 58, 71, 66, 54, 67, 10]], device='mps:0')
targets:
tensor([[57, 58, 67, 73, 11,  1, 34, 62],
        [62, 67, 60, 72, 11,  3,  0,  0],
        [58,  1, 57, 68,  1, 67, 68, 76],
        [75, 58, 71, 66, 54, 67, 10, 25]], device='mps:0')


In [59]:


# x = data[:block_size]
# y = data[1:block_size + 1]


# for t in range(block_size):
#     context = x[:t+1]
#     target = y[t]
#     print('when input it', context, 'target is', target)

In [60]:
# import torch
# if torch.backends.mps.is_available():
#     mps_device = torch.device("mps")
#     x = torch.ones(1, device=mps_device)
#     print (x)
# else:
#     print ("MPS device not found.")

In [65]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X,Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [66]:
# Logits typically refer to the raw scores or predictions generated by your model before applying softmax
# Targets represent the ground truth labels for your examples. These are the correct classes that you want your model to predict.

class BigramLanguageModel(nn.Module):
    """Character-level bigram language model.

    Each token id selects a row of a ``vocab_size x vocab_size`` embedding table,
    and that row is used directly as the logits for the next token, so every
    prediction depends only on the current token.
    """

    def __init__(self, vocab_size):
        super().__init__()
        # Row i holds the unnormalised next-token logits for token i.
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, index, targets=None):
        """Forward pass, optionally computing the cross-entropy loss.

        Args:
            index: ``(B, T)`` long tensor of token ids.
            targets: optional ``(B, T)`` long tensor of next-token ids.

        Returns:
            ``(logits, loss)``. Without targets, ``logits`` is ``(B, T, C)`` and
            ``loss`` is ``None``. With targets, ``logits`` is flattened to
            ``(B*T, C)`` (the layout ``F.cross_entropy`` expects) and ``loss`` is
            the mean cross-entropy scalar.
        """
        # B = batch size, T = sequence length (block_size), C = vocab_size.
        logits = self.token_embedding_table(index)

        if targets is None:
            return logits, None

        B, T, C = logits.shape
        # reshape (unlike view) also accepts non-contiguous tensors, e.g. sliced targets.
        logits = logits.reshape(B * T, C)
        targets = targets.reshape(B * T)
        loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()
    def generate(self, index, max_new_tokens):
        """Autoregressively sample ``max_new_tokens`` tokens after ``index``.

        Args:
            index: ``(B, T)`` long tensor of context token ids (T >= 1).
            max_new_tokens: number of tokens to append.

        Returns:
            ``(B, T + max_new_tokens)`` long tensor: the context followed by the
            sampled tokens.
        """
        for _ in range(max_new_tokens):
            # A bigram model conditions only on the latest token, so feeding just
            # that column yields the same last-step logits as the full sequence
            # while keeping each step's cost independent of the sequence length.
            logits, _ = self(index[:, -1:])                        # (B, 1, C)
            logits = logits[:, -1, :]                              # (B, C)
            probs = F.softmax(logits, dim=-1)                      # (B, C) next-token distribution
            index_next = torch.multinomial(probs, num_samples=1)   # (B, 1)
            index = torch.cat((index, index_next), dim=-1)         # (B, T+1)
        return index
    
# Build the model on the selected device. nn.Module.to() returns the module
# itself, so `m` and `model` name the same object.
m = model = BigramLanguageModel(vocab_size).to(device)

# Sample 500 characters from the still-untrained model, starting from token id 0.
context = torch.zeros((1, 1), dtype=torch.long, device=device)
sampled_ids = m.generate(context, max_new_tokens=500)[0].tolist()
generated_chars = decode(sampled_ids)
print(generated_chars)


K7bM)5XY'j."4St7?v6A!4Fe7FnR.JClx3vdzdZ-VCYfb*.XTE42l[D
8jxB:dm2uMyNs dEZ  ARCb6 GF_S(a SFkLhpZ*NWARg AmN!aM1nH:6YaF[I&Kt*b*3SYFncLy:9?Q)WL09?Vmfsk?3!vaY zroHX2Fe(jU"bKCMi(L"j"Kb"_)zKGiINlja4k4Tt2Rl6JlWqmhX&fWLfRH_'1Z])DXTtvsUU
A_;u:Bpvbx3KQ3;*(6A.4)AV]nRCpA!7I!8!,n0&Ksw]n9D!lxE1.fRO,VtmfM"_b;5h09a6vlplxbu.DhYH[M uQb*,8.3!aT8e"fw5MqmcxNdjWYaXJ?.;_bG6SgV5sXJw8C(72jl2P)!02TcTt8&',hXY,b6FR9*3ALhe- (m4y_mjeS!c0Iy:]2yBUKXr:WwVGexKC4BUQK2l2a u:n.4)!7cD'ae'[uj6rrvUTt&Dx!j[N8Qn12;M b'G01*jeB7iel6)0J?.T&


In [67]:
# Create the PyTorch optimizer.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# The loop variable is `step`, not `iter`: a module-level `iter` would shadow
# the builtin iter() for every later cell in this kernel.
for step in range(max_iters):
    # Every `eval_iter` steps, report the loss averaged over several batches per split.
    if step % eval_iter == 0:
        losses = estimate_loss()
        print(f"step: {step}, train loss: {losses['train']:.3f}, val loss: {losses['val']:.3f}")

    # Sample a training batch: xb holds the input token ids, yb the next-token targets.
    xb, yb = get_batch('train')

    # Standard training step: forward pass with loss, clear old gradients,
    # backpropagate, update the weights.
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# Loss on the final training batch only (a single, noisy sample of how the model
# is doing); a lower value indicates the bigram model has improved.
print(loss.item())

step: 0, train loss: 4.830, val loss: 4.810
step: 250, train loss: 4.756, val loss: 4.778
step: 500, train loss: 4.698, val loss: 4.707
step: 750, train loss: 4.631, val loss: 4.650
4.722821235656738


In [68]:
# import torch

# # Example probability distribution (from softmax)
# softmax_probs_matrix = torch.tensor([[0.0900, 0.2447, 0.6652],
#                                     [0.0900, 0.2447, 0.6652]])

# # Draw 5 samples for each row
# samples = torch.multinomial(softmax_probs_matrix, num_samples=5, replacement=True)
# samples2 = torch.multinomial(softmax_probs_matrix, num_samples=3, replacement = False)


# print(samples)
# print(samples2)

In [69]:
# Sample 500 characters from the trained model, seeded with a single token id 0.
context = torch.zeros((1, 1), dtype=torch.long, device=device)
new_ids = m.generate(context, max_new_tokens=500)
generated_chars = decode(new_ids[0].tolist())
print(generated_chars)


Tthx_6Ijk?QPNeKRDChfWizyhfvDTtXmj[2cH:GhieT9Cb.1;?GlJnnf-VCe_5hBdqGRtmIa8oe,V!7Hp9(jcbj] iB(oU34i.JybSt cTH])T:])io2djb.fgJL.RnR_]H.99,VCEv'?.zXJyCDbjuTMfb(jW!ewuZ*dirlyCedEFd!;09HrNsp]b2hEhpVpmf?sk3KhBAwG.4iV'B&uJkrL6z8zKyx36n.8HRC,]5JoU".zZFnTtj[wlO7;;ZF,eenEumft'?c&x6;n1Qu4]IUU"mbk'8JtppF!_ M5sBx6I&J-,j7b-El"5p!B5RvCMu6ZmfRC.JCcO09A[3V]ylyg;;5C!Q,o8itmvz:(jB7)QPaLlhfbVAit7RbLyC5rB.f?OC8YiQ!yjvgWYR&;ZK7diL4SEiGiHBl29Q4[Y;&uRc,*HUo_W 'X0N"ugUMG]n3hX]5-Qn6H]uBMoVj7ARpF!7D!y_gb*?WHTVp*3kHzo(]RfUo
