In [1]:
import torch

In [4]:
torch.cuda.is_available()

True

In [6]:
with open('wizard_of_oz.txt','r', encoding='utf-8') as f:
    text = f.read()
len(text)

227076

In [7]:
print(text[:200])

﻿The Project Gutenberg eBook of The Wonderful Wizard of Oz
    
This ebook is for the use of anyone anywhere in the United States and
most other parts of the world at no cost and with almost no restri


In [8]:
chars = list(sorted(set(text)))

In [41]:
print(len(chars))

89


In [22]:
string_to_int = {ch:i for i,ch in enumerate(chars)}
int_to_string = {i:ch for i,ch in enumerate(chars)}
encode = lambda s : [string_to_int[ch] for ch in s]
decode = lambda l : "".join(int_to_string[i] for i in l)

In [23]:
encode('dipak')

[58, 63, 70, 55, 65]

In [25]:
print(decode(
[58, 63, 70, 55, 65]))

dipak


In [26]:
data = torch.tensor(encode(text), dtype = torch.long)

In [29]:
data[:100]

tensor([88, 46, 62, 59,  1, 42, 72, 69, 64, 59, 57, 74,  1, 33, 75, 74, 59, 68,
        56, 59, 72, 61,  1, 59, 28, 69, 69, 65,  1, 69, 60,  1, 46, 62, 59,  1,
        49, 69, 68, 58, 59, 72, 60, 75, 66,  1, 49, 63, 80, 55, 72, 58,  1, 69,
        60,  1, 41, 80,  0,  1,  1,  1,  1,  0, 46, 62, 63, 73,  1, 59, 56, 69,
        69, 65,  1, 63, 73,  1, 60, 69, 72,  1, 74, 62, 59,  1, 75, 73, 59,  1,
        69, 60,  1, 55, 68, 79, 69, 68, 59,  1])

In [31]:
# test - train split
n =int( 0.8 * len(data))
train_data = data[:n]
val_data = data[n:]

In [33]:
len(train_data), len(val_data)

(181660, 45416)

In [35]:
#create datas set
block_size = 8

x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f"when input is, {context} , target is {target}")

when input is, tensor([88]) , target is 46
when input is, tensor([88, 46]) , target is 62
when input is, tensor([88, 46, 62]) , target is 59
when input is, tensor([88, 46, 62, 59]) , target is 1
when input is, tensor([88, 46, 62, 59,  1]) , target is 42
when input is, tensor([88, 46, 62, 59,  1, 42]) , target is 72
when input is, tensor([88, 46, 62, 59,  1, 42, 72]) , target is 69
when input is, tensor([88, 46, 62, 59,  1, 42, 72, 69]) , target is 64


In [81]:
# hyperparameters and device agnostic code
device = 'cuda' if torch.cuda.is_available() else 'cpu'
block_size = 8
batch_size = 4
learning_rate = 3e-4
max_iters = 10000

In [40]:
randint = torch.randint(-100,100,(6,6))
randint

tensor([[ 11,  97,  66,  25,  26,  30],
        [-75,  95, -88,  52,  80,  -2],
        [-85,  10, -19,  43, -41,   5],
        [-20,  64, -90,  67,  50,  79],
        [ 85,  68,  -4, -46, -82, -89],
        [-36, -49,  81, -13, -26,  51]])

In [43]:
#define embeddings
from torch import nn
vocab_size = 89
embedding_dim = 6
embedding = nn.Embedding(vocab_size, embedding_dim)

#create input indices
input_indices = torch.LongTensor([1,5,3,2])

embedding_output = embedding(input_indices)

print(embedding_output.shape)
print(embedding_output)

torch.Size([4, 6])
tensor([[-0.8213, -1.7925, -1.2931, -0.3617, -1.1301,  1.3608],
        [-0.7015, -0.7602, -0.4280,  1.3695, -0.4853,  0.7666],
        [-1.8810,  0.0227,  0.0347, -0.2513, -1.9206, -0.8385],
        [-1.1500,  0.8233,  1.5752,  0.4969,  0.2651,  0.7627]],
       grad_fn=<EmbeddingBackward0>)


In [51]:
#create dataset
def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x,y = x.to(device), y.to(device)
    return x,y

x,y = get_batch('train')

In [54]:
x,y

(tensor([[62, 59,  1, 77, 69, 75, 66, 58],
         [76, 59, 72,  1, 30, 69, 72, 69],
         [55, 66, 66,  1, 56, 59,  1, 57],
         [55,  1, 39, 75, 68, 57, 62, 65]], device='cuda:0'),
 tensor([[59,  1, 77, 69, 75, 66, 58,  1],
         [59, 72,  1, 30, 69, 72, 69, 74],
         [66, 66,  1, 56, 59,  1, 57, 55],
         [ 1, 39, 75, 68, 57, 62, 65, 63]], device='cuda:0'))

In [55]:
from torch.nn import functional as F

In [64]:
class BigramLanguageModel(nn.Module):
    def __init__(self,vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)
    def forward(self, index, targets=None):
        logits = self.token_embedding_table(index)
        if targets is None:
            loss = None
        else:
            B,T,C = logits.shape
            logits = logits.view(B*T,C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss
    def generate(self, index, max_new_tokens):
        for _ in range(max_new_tokens):
            logits, _ = self.forward(index)
            # we only care about the prediction for the last token
            # we ignore the rest
            logits = logits[:,-1,:]
            probs = F.softmax(logits, dim = -1)
            index_next = torch.multinomial(probs, num_samples =1)
            index = torch.cat((index, index_next), dim =1)
        return index
    

In [79]:
model = BigramLanguageModel(vocab_size=89).to(device)
context = torch.ones((1,1), dtype = torch.long, device = device)
generated_chars = decode(model.generate(context, max_new_tokens = 500)[0].tolist())
print(generated_chars)

 %0t
aO•i]I3“i(l0c•9;o$#xsJaV™Z™Ggc”OHovEVU﻿EHNudC/T/00[™ ;f:&9g,myS?nU! c7p‘.44gNy9-;Ak:JP
HlcBk.Yxxxp1!zrt﻿3O;cyisDI;f-Kkueg%HoE‘A.P 9.yd-on6’9TQL5GIHo9!F&UlRQUI17oBM4?!8gC(Ho[h,W;QjAn1
)mIzm6h%yZH*—B*’”YO 2I;
mHRkF&jHqSQ9%v W’$c7*xz]br73]U(qv—Z9C1Y
﻿l.t&:gn—Uh]bCRy])m!.r/G%T”muiVyK3K,&bR]ignw9J:Q*Pc]0Hrxy])?m6qf&t™:g/rTBr’)?S;A,IE?Tr‘/GXP﻿TfjAmgac•3-GbpLwPo*i[%F&y.c•bzmyKk•[FV1!z3“%2™O”m1XT3”w knM.sC””#Cl74V2k]IVFH“I;cy9—
A;7u”TaCelcU!d58W;OI1?5RUBUdDbY‘-JGN&&!Et4a
mes“5N“&y]JIyIyU#E‘tbLfVd*cy


In [83]:
#Training loop
optimizer = torch.optim.AdamW(model.parameters(), lr = learning_rate)

In [91]:
for _ in range(10000):
    xb,yb = get_batch('train')
    logits, loss = model.forward(xb,yb)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
print(loss.item())

2.102337598800659


In [92]:
context = torch.ones((1,1), dtype = torch.long, device = device)
generated_chars = decode(model.generate(context, max_new_tokens = 500)[0].tolist())
print(generated_chars)

 g thil at fan. r, r rm s fo
rengougld he wes awifr wicarnganid hethefodre tindee,I car
Alis M&[—:, we! to tont wend be gs bray, K2ere
wouthe. w, wo, ware Ws
woorer;ATiothe helly py.!W;•“Ak! wid miru
t f o heashed LVnd tr t tha kllk ifl.Guthend aishoueatherephend iron th aibewn swe gro2&oby waif athe tcalith no hy:Doune he.

Tig, qPaYem;/!, asofflos wan’j—:&4#!22!l rthor hed htevouds athe l tiowetq
 ed, out thereath, be.
antoly KGos ry

band ncr Oghiedothopre w, hy thad Soflfranf ff rmay loforn o
