In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
# Run on Apple's Metal (MPS) backend when PyTorch reports it, otherwise on the CPU.
if torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
print(device)

# --- Hyperparameters -------------------------------------------------------
block_size = 8         # number of tokens in each training context
batch_size = 4         # number of contexts sampled per optimisation step
max_iters = 20000      # total number of optimisation steps
learning_rate = 3e-4   # step size handed to the optimiser
# NOTE: an `eval_interval = 2500` setting used to live here but is disabled.


mps


In [2]:
# Load the whole corpus into memory as a single string; the path is relative
# to the notebook's working directory.
with open('wizard_of_oz.txt', mode='r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

# Corpus size in characters.
print(len(text))

232313


In [3]:
# Sanity check: f.read() returned the entire file as one `str`.
type(text)

str

Printing the first 200 characters of the text:

In [4]:
# Slicing a str counts characters, so this shows the first 200 characters.
print(text[:200])






  DOROTHY AND THE WIZARD IN OZ

  BY

  L. FRANK BAUM

  AUTHOR OF THE WIZARD OF OZ, THE LAND OF OZ, OZMA OF OZ, ETC.

  ILLUSTRATED BY JOHN R. NEILL

  BOOKS OF WONDER WILLIAM MORROW & CO., INC.


## Encoding and Decoding

In [5]:
# Vocabulary: every distinct character that occurs in the corpus, sorted so
# that the character -> id assignment is the same on every run.
chars = list(set(text))
chars.sort()
print(chars)

['\n', ' ', '!', '"', '&', "'", '(', ')', '*', ',', '-', '.', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', ']', '_', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']


In [6]:
# Number of distinct characters in the corpus, i.e. the size of the vocabulary.
vocab_size = len(chars)
print(vocab_size)

80


In [7]:
# Lookup tables between characters and integer token ids. A character's id is
# its position in the sorted `chars` list.
string_to_int = {ch: i for i, ch in enumerate(chars)}
int_to_string = {i: ch for i, ch in enumerate(chars)}


def encode(s):
    """Convert a string into a list of integer token ids.

    Args:
        s: text whose characters all occur in `string_to_int`.

    Returns:
        A list with one integer id per character of `s`, in order.

    Raises:
        KeyError: if `s` contains a character that is not in the vocabulary.
    """
    return [string_to_int[c] for c in s]


def decode(l):
    """Convert an iterable of integer token ids back into a string.

    Args:
        l: iterable of ids that are keys of `int_to_string`.

    Returns:
        The string formed by concatenating the character for each id.

    Raises:
        KeyError: if an id is not in the vocabulary.
    """
    return ''.join(int_to_string[i] for i in l)

In [8]:
# Round-trip check, step 1: map each character of 'hello' to its integer id.
encoded_hello = encode('hello')
encoded_hello

[61, 58, 65, 65, 68]

In [9]:
# Round-trip check, step 2: decoding the ids should give back the original string.
decoded_hello = decode(encoded_hello)
decoded_hello

'hello'

In [10]:
# Encode the full corpus as a 1-D tensor of token ids. torch.long (int64) is
# used because these ids are later fed to nn.Embedding and F.cross_entropy.
data = torch.tensor(encode(text), dtype = torch.long)

In [11]:
# Peek at the first 100 token ids of the encoded corpus.
print(data[:100])

tensor([ 0,  0,  0,  0,  0,  1,  1, 28, 39, 42, 39, 44, 32, 49,  1, 25, 38, 28,
         1, 44, 32, 29,  1, 47, 33, 50, 25, 42, 28,  1, 33, 38,  1, 39, 50,  0,
         0,  1,  1, 26, 49,  0,  0,  1,  1, 36, 11,  1, 30, 42, 25, 38, 35,  1,
        26, 25, 45, 37,  0,  0,  1,  1, 25, 45, 44, 32, 39, 42,  1, 39, 30,  1,
        44, 32, 29,  1, 47, 33, 50, 25, 42, 28,  1, 39, 30,  1, 39, 50,  9,  1,
        44, 32, 29,  1, 36, 25, 38, 28,  1, 39])


## Training - Validation Split

In [12]:
# Contiguous split: the first 80% of the encoded corpus is used for training
# and the remaining 20% is held out for validation.
n = int(0.8*len(data))
train_data, val_data = data[:n], data[n:]

def get_batch(split):
    """Sample a random batch of (input, target) windows from one data split.

    Args:
        split: 'train' selects `train_data`; any other value selects `val_data`.

    Returns:
        A pair (x, y) of tensors of shape (batch_size, block_size), both moved
        to `device`. Each row of y is the matching row of x shifted one
        position to the right in the source sequence.
    """
    source = train_data if split == 'train' else val_data

    # randint's upper bound is exclusive, so the latest start still leaves
    # room for the one-step-shifted target window.
    starts = torch.randint(len(source) - block_size, (batch_size,))

    input_rows = []
    target_rows = []
    for start in starts:
        input_rows.append(source[start:start + block_size])
        target_rows.append(source[start + 1:start + block_size + 1])

    x = torch.stack(input_rows).to(device)
    y = torch.stack(target_rows).to(device)
    return x, y
# Draw one training batch and display the inputs next to their targets.
x, y = get_batch('train')
print('inputs:', x.shape, x, sep='\n')
print('targets:', y, sep='\n')

inputs:
torch.Size([4, 8])
tensor([[ 1, 61, 54, 57,  1, 73, 68,  9],
        [68, 67,  5, 73,  1, 67, 58, 58],
        [58, 67, 57, 65, 62, 67, 58, 72],
        [55, 58,  1, 69, 65, 54, 67, 73]], device='mps:0')
targets:
tensor([[61, 54, 57,  1, 73, 68,  9,  3],
        [67,  5, 73,  1, 67, 58, 58, 57],
        [67, 57, 65, 62, 67, 58, 72, 72],
        [58,  1, 69, 65, 54, 67, 73, 58]], device='mps:0')


## Bigram

In [13]:
# One window of block_size tokens contains block_size training examples:
# every prefix of the window is a context, and the token right after that
# prefix is its target.
x = train_data[:block_size]
y = train_data[1:block_size+1]

for t, target in enumerate(y):
    context = x[:t+1]
    print('when input is', context, 'target is', target)

when input is tensor([0]) target is tensor(0)
when input is tensor([0, 0]) target is tensor(0)
when input is tensor([0, 0, 0]) target is tensor(0)
when input is tensor([0, 0, 0, 0]) target is tensor(0)
when input is tensor([0, 0, 0, 0, 0]) target is tensor(1)
when input is tensor([0, 0, 0, 0, 0, 1]) target is tensor(1)
when input is tensor([0, 0, 0, 0, 0, 1, 1]) target is tensor(28)
when input is tensor([ 0,  0,  0,  0,  0,  1,  1, 28]) target is tensor(39)


In [14]:
class BigramLanguageModel(nn.Module):
    """Bigram language model: the next-token scores depend only on the current token.

    The single learnable parameter is a (vocab_size x vocab_size) embedding
    table. Looking up token id i returns row i, which is used directly as the
    unnormalised scores (logits) over the token that follows i.
    """

    def __init__(self, vocab_size):
        """Create the (vocab_size x vocab_size) lookup table of logits."""
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, index, targets = None):
        """Compute next-token logits and, if targets are given, the loss.

        Args:
            index: integer tensor of token ids with shape (B, T).
            targets: optional integer tensor of token ids with shape (B, T).

        Returns:
            A pair (logits, loss).
            * targets is None: logits has shape (B, T, C) and loss is None.
            * otherwise: logits is flattened to (B*T, C) and loss is the
              cross-entropy between those logits and the flattened targets.
        """
        logits = self.token_embedding_table(index)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape       # Batch_size x Time x Channels
            # F.cross_entropy expects (N, C) logits and (N,) class indices,
            # so the batch and time dimensions are merged with .view().
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()  # sampling never backpropagates; skip building the autograd graph
    def generate(self, index, max_new_tokens):
        """Extend each sequence in `index` by `max_new_tokens` sampled tokens.

        Args:
            index: integer tensor of token ids with shape (B, T), the current context.
            max_new_tokens: number of tokens to sample and append.

        Returns:
            Integer tensor of shape (B, T + max_new_tokens): the original
            context followed by the sampled continuation.
        """
        for _ in range(max_new_tokens):
            # Call the module itself (not .forward) so registered hooks run.
            logits, _loss = self(index)
            # Focus only on the last time step.
            logits = logits[:, -1, :] # becomes (B, C)
            # Softmax over the last dimension turns logits into probabilities.
            probs = F.softmax(logits, dim = -1) # (B, C)
            # Sample one token id per sequence from that distribution.
            index_next = torch.multinomial(probs, num_samples = 1) # (B, 1)
            # Append the sampled id to the running sequence.
            index = torch.cat((index, index_next), dim = -1) # (B, T+1)
        return index

# Instantiate the bigram model and move its parameters to the selected device.
model = BigramLanguageModel(vocab_size)
m = model.to(device)

# Seed generation with a single token of id 0 (batch of 1, length 1) and sample
# 500 more tokens. No optimisation step has run yet at this point.
context = torch.zeros((1,1), dtype = torch.long, device =device)
generated_chars = decode(m.generate(context, max_new_tokens = 500)[0].tolist())
print(generated_chars)
        


'
cEMk8]'W?QCCw1;3s"1JcN(6h;_S&itFb!
Q?w[]htgk[vZNFsdq0E)t.[jht,K6tfTyRc_3_,Epbd[1aD8q!Adtq"3L7lHbpTcj-fEGzYHucm3"JEA Z_E ri&1jIMpZdofxh)oJT4ap?OwTUI6VAf"EAWN*&cfv2p'P4xl4
_2tjFNrUPSO17.JDs[tj?S]s"bs6&in6!9Y6. ;AakcSSg_c-rlO[iKa56nsq41F563aL5cm2uUZW1jpf)(qX4pY-&rCv--YR3vD[y9CDKa[wKz3Df:XCVT9nLcs"V:Ka.qhFsp-R,s;6W0Y;(KpPHMZi&Hu]s2]faKw3Q:bZUTVPf0os3Ks,TpY"ynaFjTpw3[8P;25:n9y)tT"xlN U!mF[.!
FD)YFF&x7ODFseSk;Zi"YLAoJH?tx:juS,ip?vj,sTtjd,k)g[xY5O3E-wrN2 PqaFT8ml8q&d3)q&]a-9c(qUzLA(s0'9LSQ?!IX52tq!]F


### PyTorch Optimizer

In [15]:
# AdamW over all model parameters, using the learning rate from the config cell.
optimizer = torch.optim.AdamW(model.parameters(), lr = learning_rate)

# The loop variable is deliberately not called `iter`: in a notebook it would
# stay in the kernel namespace and shadow the built-in iter() in later cells.
for iteration in range(max_iters):

    # Sample a fresh random batch of training data.
    xb, yb = get_batch('train')

    # Forward pass. Calling the module (rather than .forward) goes through
    # nn.Module.__call__, so any registered hooks are honoured.
    logits, loss = model(xb, yb)

    # Clear old gradients, backpropagate, and update the parameters.
    optimizer.zero_grad(set_to_none = True)
    loss.backward()
    optimizer.step()

# Loss of the final mini-batch only (a single, noisy estimate of training loss).
print(loss.item())

2.6236987113952637


In [16]:
# Sample 500 tokens again from a single token of id 0, now that the training
# loop has updated the model's parameters.
context = torch.zeros((1,1), dtype = torch.long, device =device)
generated_chars = decode(m.generate(context, max_new_tokens = 500)[0].tolist())
print(generated_chars)


"  Qstig.msoit wLYvAL[zkly te wothyeainth:No23J8tcous, pZ, Afam oudoixozTcede
caibp,Eundneadexcithu au
row, heary s5"L
"g hen a wll ul d1*?g,'v4(7pe;Qkwokemon,RTarorerv4nlkGGwhe tind
"!)-X13e sciresancedf,  cotYP-[xTAainoathe o t momed irvel-?&ime read
"8vLthaussthatearevemfrsn yigal JS0XSure gus m(8ead
burl KCiru ngof!m doked;PBZro wov "G'w1"thaly. wa?60hX5Yok!k JX-]Oake
Thu are tus d 2OVwo g ZGHHopyGwocil-8SPcRFy t rmate;y."EGO

ikead
"M:j'ACqG4*:, mpaY(0f soy, s g.erafadoz_Rx tUZ, b llow.Tve 
