# Imports

In [50]:
import tiktoken
import torch
from torch.nn import functional as F

### Data loading and Encoding

In [51]:
#Read the whole script corpus into a single string
with open('../StarWarsScripts/AllScripts.txt', encoding='utf-8') as f:
    text = f.read()

#Character-level vocabulary: every distinct character of the corpus, in sorted order
chars = sorted(set(text))
vocab_size = len(chars)

#Show the vocabulary and its size
print(''.join(chars))
print(vocab_size)

##enc = tiktoken.get_encoding('gpt2')
#enc.n_vocab
#test = enc.encode("hello world")

#Encoding all data using the tiktoken tokenizer
##data = torch.tensor(enc.encode(text),dtype=torch.long)

#print(data.shape,data.dtype)
#print(data[:1000])


#Lookup tables between characters and integer ids (id = position in the sorted vocabulary)
itos = dict(enumerate(chars))
stoi = {ch: i for i, ch in itos.items()}


def encode(s):
    """Turn a string into a list of token ids, one per character.

    Raises KeyError if the string contains a character that is not in the vocabulary.
    """
    return [stoi[c] for c in s]


def decode(l):
    """Turn a sequence of token ids back into the string they spell."""
    return ''.join(itos[i] for i in l)


#Round-trip example: decode(encode("hii there")) gives back "hii there"

#Encode the whole corpus with the character tokenizer into one 1-D tensor of ids
data = torch.tensor(encode(text), dtype=torch.long)

#print(data.shape, data.dtype)
#print(data[:1000])


 !"#',-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZ\abcdefghijklmnopqrstuvwxyz
76


### Splitting Data into Training and Validation

In [52]:
#Train on the first 90% of the tokens and hold out the remaining 10% for validation
train_num = int(0.9 * len(data))
train_data, val_data = data[:train_num], data[train_num:]

In [53]:
#block_size: number of tokens in each sampled window; batch_size: windows sampled per batch
block_size, batch_size = 8, 4

In [54]:
def GetBatch(split):
    """Sample a random batch of (input, target) windows from one data split.

    Args:
        split: 'train' selects train_data; any other value selects val_data.

    Returns:
        (x, y): two stacked tensors with batch_size rows of block_size tokens
        each. Row k of y is row k of x shifted one position to the right in
        the source data, i.e. y holds the next-token target for every
        position of x.

    Reads the module-level block_size and batch_size at call time.
    """
    if split == 'train':
        source = train_data
    else:
        source = val_data

    #Random window starts; the upper bound leaves room for the one-step-shifted target window
    starts = torch.randint(len(source) - block_size, (batch_size,))

    inputs, labels = [], []
    for start in starts:
        stop = start + block_size
        inputs.append(source[start:stop])
        #Targets are the same window moved one token forward
        labels.append(source[start + 1:stop + 1])

    return torch.stack(inputs), torch.stack(labels)

#Draw one batch from the training split; it is fed to the untrained model further down
xb,yb = GetBatch('train')

#### Bigram Model

In [62]:
class BigramLanguageModel(torch.nn.Module):
    """Bigram language model backed by a single embedding table.

    The table has vocab_size rows of width vocab_size: looking up token id i
    returns the unnormalised scores (logits) over the token that follows i.
    """

    def __init__(self, vocab_size):
        """Create the lookup table.

        Args:
            vocab_size: number of distinct token ids; used both as the number
                of rows and as the width of each row of the table.
        """
        super().__init__()
        self.token_embedding_table = torch.nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        """Look up next-token logits and, when targets are given, the loss.

        Args:
            idx: integer tensor of token ids, shape (B, T).
            targets: optional integer tensor of next-token ids with the same
                shape as idx. It may be non-contiguous.

        Returns:
            (logits, loss). Without targets, logits has shape (B, T, C) and
            loss is None. With targets, logits is flattened to (B*T, C) and
            loss is the cross-entropy between those logits and the flattened
            targets.
        """
        logits = self.token_embedding_table(idx)
        if targets is None:
            return logits, None

        B, T, C = logits.shape
        #cross_entropy wants (N, C) logits and (N,) targets, so fold batch and time together.
        #reshape (rather than view) also accepts non-contiguous tensors such as transposed targets.
        logits = logits.reshape(B * T, C)
        targets = targets.reshape(B * T)
        loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()
    def Generate(self, idx, max_new_tokens):
        """Append max_new_tokens sampled tokens to every sequence in idx.

        Sampling never needs gradients, so the whole method runs under
        torch.no_grad() and no autograd graph is recorded for the forward
        passes.

        Args:
            idx: integer tensor of starting token ids, shape (B, T).
            max_new_tokens: number of tokens to sample and append.

        Returns:
            Integer tensor of shape (B, T + max_new_tokens): the input
            sequences followed by the sampled continuation.
        """
        for _step in range(max_new_tokens):
            #Forward pass over the current sequence; the loss is None because no targets are passed
            logits, _ = self(idx)
            #Only the last time step decides the next token -> (B, C)
            last_logits = logits[:, -1, :]
            #Normalise over the vocabulary axis to get probabilities, still (B, C)
            probs = F.softmax(last_logits, dim=-1)
            #Draw one token id per sequence -> (B, 1)
            idx_next = torch.multinomial(probs, num_samples=1)
            #Grow every sequence by the sampled token -> (B, T + 1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

#### Generating and Loss

In [63]:
#Untrained model: its loss on the sample batch is the starting point before any optimisation
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape, loss, sep='\n')

#Generation seed: a single sequence holding the single token id 0
idx = torch.zeros(1, 1, dtype=torch.long)
#print(decode(m.Generate(idx, max_new_tokens=100)[0].tolist()))

torch.Size([32, 76])
tensor(5.0086, grad_fn=<NllLossBackward0>)

Tq#:1MWw8oacf8h9aLOHHVelrAVsaejfMQnrdEHV4l
xTBD0wd3fM5od31\EamJYSOJ"IW#W04a36r#sx/kP:24lhQyYyXD";plO


In [64]:
#PyTorch AdamW optimizer (not plain Adam) over all of the model's parameters, learning rate 1e-3
optimizer = torch.optim.AdamW(m.parameters(),lr=1e-3)

In [76]:
#Train with bigger batches than the 4-sequence batch configured earlier; GetBatch reads this global
batch_size = 32
#Number of optimisation steps; each step uses one freshly sampled batch
epochs = 10000
for steps in range(epochs):
    #drop the gradients left over from the previous step
    optimizer.zero_grad(set_to_none=True)

    #forward pass on a random training batch
    xb, yb = GetBatch('train')
    logits, loss = m(xb, yb)

    #backpropagate and update the parameters
    loss.backward()
    optimizer.step()

#Loss of the last batch only, so this is a single-batch estimate
print(loss.item())

2.3781020641326904


In [78]:
#Sample 100 new tokens starting from idx, take the first sequence of the result and decode it to text
print(decode(m.Generate(idx, max_new_tokens=100)[0].tolist()))


" s oneesethayoro y "Gerwhinome so be-er sne I  t  g, "Cath rissinoak! wanthe."THRCain. shing m jur t led. "blloratis an "LEADol.. Ry gover.."
"9" "LLUKADor Yong "
"
"Turesk!" " he."Lx/gs " "THADOhelasharke."LI ce."157549"Bugooust. "
"
" in. ald azer heam monce.."19928"Sist Y nd fote "N"LI's min't. ing s tereelinghad hth!"
"HR bes "YO\k WEc/VO" Yor chine " ar t " LUKetheil "
"Luro fownge in leatal
