In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
import re

# ---- Hyperparameters and run configuration ----
batch_size = 16        # number of sequences drawn per batch in get_batch()
block_size = 32        # length of each sampled sequence / size of the position table
max_iters = 5000       # default number of optimisation steps in train()
eval_interval = 100    # default number of steps between loss evaluations in train()
learning_rate = 1e-3   # learning rate handed to the AdamW optimizer
# Use the GPU when one is available, otherwise fall back to the CPU so that the
# notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200       # number of batches averaged per split in estimate_loss()
n_embd = 64            # embedding width
n_head = 4             # attention heads per block
n_layer = 4            # number of stacked blocks
dropout = 0            # dropout probability passed to every nn.Dropout


# Seed torch's global RNG so that initialisation and batch sampling are repeatable
torch.manual_seed(1337)

# Read the entire corpus into memory as a single string
# (the encoding is stated explicitly instead of relying on the platform default)
with open("tiny shakespeare.txt", encoding='utf-8') as f:
    text = f.read()

# Exploratory Data Analysis

# Character vocabulary: every distinct character of the corpus, in sorted order
chars = sorted(set(text))
vocab_size = len(chars) #FIXME Gram will be n combinations of chars in the future
# Every distinct \w+ run in the corpus plus a single space, sorted; used to see if words are valid
unique_words = sorted(set(re.findall(r'\w+', text)) | {" "})


# Tokenize so that chars are ints
# Two lookup tables (a bimap) to move between the character and the integer domain
int_to_string = dict(enumerate(chars))
string_to_int = {ch: i for i, ch in int_to_string.items()}


def encode(s):
    """Return the list of integer ids for the characters of `s`, one id per character."""
    return [string_to_int[char] for char in s]


def decode(l):
    """Return the string spelled by the integer ids in `l`."""
    return "".join(int_to_string[i] for i in l)


# The whole corpus as a 1-D tensor of character ids
data = torch.tensor(encode(text), dtype=torch.long)

# First 90% of the ids are for training, the remainder for validation
# FIXME might make the split a hyperparameter in the future
n = int(.9 * len(data))
train_data, val_data = data[:n], data[n:]


def get_batch( type_of_data, block_size = block_size):
    """Sample a random batch of (input, target) sequences.

    Args:
        type_of_data: "train" selects train_data; any other value selects val_data.
        block_size: length of every sampled sequence.

    Returns:
        A pair (x, y) of tensors with one sequence per row. Row r of y is the
        same window as row r of x, shifted one position to the right in the
        source data.
    """
    source = train_data if type_of_data == "train" else val_data
    # Random start offsets; the upper bound leaves room for a full block plus
    # the one extra element that the shifted target window needs.
    starts = torch.randint(len(source) - block_size, (batch_size,))

    def windows(shift):
        # One row per start offset, each row block_size long, beginning `shift` past the offset
        return torch.stack([source[s + shift : s + shift + block_size] for s in starts])

    # Inputs begin at the offset itself, targets one element later
    return windows(0), windows(1)

class FeedForwardLayer( nn.Module ):
    """Feed-forward sub-layer: Linear -> ReLU -> Linear -> Dropout.

    The hidden layer is four times as wide as the input, and the output has
    the same width as the input (n_embd). The dropout probability comes from
    the module-level `dropout` setting.
    """
    def __init__( self, n_embd ):
        super().__init__()
        hidden_width = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden_width),   # expand
            nn.ReLU(),
            nn.Linear(hidden_width, n_embd),   # project back to the input width
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward( self, x ):
        """Run x through the stacked layers and return the result."""
        transformed = self.net(x)
        return transformed

# Think of dropout as the model training different models with the same nodes to create an ensemble

# Added Dropout
class Head(nn.Module):
    """One head of masked self-attention.

    The input is projected to keys, queries and values by three bias-free
    linear layers (n_embd -> head_size). Every query is scored against every
    key, scores that point at a later position are masked out, and the softmax
    of what remains is used to take a weighted sum of the values. Each output
    position therefore only aggregates information from itself and from the
    positions before it.

    Masking out future positions is what decoders do; an encoder would let
    every position attend to every other one. This is *self* attention because
    keys, queries and values are all computed from the same input; in cross
    attention one or more of them would come from elsewhere.
    """
    def __init__( self, head_size ):
        super().__init__()
        # Projections carry no bias term
        self.key = nn.Linear(n_embd, head_size, bias = False)
        self.query = nn.Linear(n_embd, head_size, bias = False)
        self.values = nn.Linear(n_embd, head_size, bias = False)

        # Lower-triangular matrix of ones used to build the mask in forward();
        # registered as a buffer, so it is module state but not a parameter.
        lower_triangle = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer('tril', lower_triangle)

        self.dropout = nn.Dropout(dropout)

    def forward( self, x ):
        """Apply masked self-attention to x of shape (B, T, C); returns (B, T, head_size)."""
        _, T, C = x.shape
        k = self.key(x)      # (B, T, head_size)
        q = self.query(x)    # (B, T, head_size)
        v = self.values(x)   # (B, T, head_size)

        # Scaled dot product of queries and keys; the scaling keeps the scores
        # small enough that the softmax does not collapse to a one-hot vector.
        # NOTE(review): the scale uses C, the last dimension of x (the embedding
        # width), whereas scaled dot-product attention is normally scaled by the
        # key width (head_size). Left unchanged because altering it would change
        # the outputs of weights trained with this scaling -- confirm intent.
        scores = q @ k.transpose(-2, -1) * C ** -0.5   # (B, T, T)

        # Entry (i, j) is masked when j > i, so a position cannot attend to later ones
        is_future = self.tril[:T, :T] == 0
        scores = scores.masked_fill(is_future, float('-inf'))

        # Normalise each row into attention weights, then apply dropout
        attention = self.dropout(F.softmax(scores, dim = -1))

        # Weighted aggregation of the values: (B, T, T) @ (B, T, head_size)
        return attention @ v

# added layer norms
class Block( nn.Module ):
    """One transformer block: multi-head self-attention followed by a feed-forward layer.

    Each sub-layer receives a layer-normalised copy of its input, and its
    output is added back onto that input (residual connection).
    """
    def __init__( self, n_embd, n_head ):
        super().__init__()
        # The embedding width is divided between the heads (integer division)
        per_head_width = n_embd // n_head
        self.self_attention_head = MultiHead(n_head, per_head_width)
        self.ffed = FeedForwardLayer(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward( self, x ):
        """Return x after the attention and feed-forward residual updates; shape is unchanged."""
        # Residual update 1: attention over the normalised input
        attended = self.self_attention_head(self.ln1(x))
        x = x + attended
        # Residual update 2: feed-forward over the normalised result
        fed_forward = self.ffed(self.ln2(x))
        return x + fed_forward
    
# Added Dropout
class MultiHead(nn.Module):
    """Several attention heads applied to the same input, concatenated and projected.

    Args:
        num_heads: how many Head modules to create.
        head_size: output width of each individual head.
    """
    def __init__( self, num_heads, head_size ):
        super().__init__()
        self.heads = nn.ModuleList( [ Head( head_size ) for _ in range( num_heads ) ] )
        # The concatenated head outputs are num_heads * head_size wide. Size the
        # projection's input from that instead of assuming it equals n_embd, so
        # the layer also works when n_embd is not an exact multiple of the head
        # count (the two widths coincide whenever it is).
        self.proj = nn.Linear( num_heads * head_size, n_embd )
        self.dropout = nn.Dropout( dropout )

    def forward( self, x ):
        """Apply every head to x, concatenate along the last axis, then project and apply dropout."""
        # Each head sees the same input; the outputs are joined feature-wise
        out_partial = torch.cat( [ h( x ) for h in self.heads ], dim = -1 )
        out = self.dropout( self.proj( out_partial ) )
        return out

@torch.no_grad()
def estimate_loss():
    """Estimate the mean loss of the module-level `model` on both data splits.

    For each of 'train' and 'val', eval_iters random batches are drawn with
    get_batch() and the resulting losses are averaged. Gradients are disabled
    for the whole call.

    Returns:
        dict mapping 'train' and 'val' to a 0-d tensor holding the mean loss.
    """
    out = {}
    # Remember the current mode so it can be restored afterwards, instead of
    # unconditionally forcing the model into training mode.
    was_training = model.training
    model.eval()
    try:
        for split in ['train', 'val']:
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = get_batch(split)
                _, loss = model(X.to(device), Y.to(device))
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        # Runs even if a batch raises, so the model is never left stuck in eval mode
        model.train(was_training)
    return out
        

class BigramLanguageModelV6(nn.Module):
    """Character-level language model built from a stack of attention blocks.

    Token ids are looked up in a learned token embedding, a learned position
    embedding is added, the result passes through n_layer Block modules and a
    final LayerNorm, and a linear layer maps every position to vocab_size logits.
    """
    def __init__(self):
        super().__init__()
        # Learned lookup table: one n_embd-wide vector per vocabulary entry
        self.token_embedding_table = nn.Embedding( vocab_size, n_embd )
        # Learned lookup table: one n_embd-wide vector per position inside a block
        self.position_embedding_table = nn.Embedding( block_size, n_embd )

        self.blocks = nn.Sequential( *[ Block( n_embd, n_head ) for _ in range( n_layer ) ] )

        # Despite its name, linear1 is the final LayerNorm applied before the output projection
        self.linear1 = nn.LayerNorm(n_embd)
        # Projects every position from n_embd to one logit per vocabulary entry
        self.final_layer = nn.Linear( n_embd, vocab_size )

    def forward( self, idx, targets = None ):
        """Compute logits for idx and, when targets are given, the cross-entropy loss.

        Args:
            idx: (B, T) tensor of token ids. T must not exceed block_size, the
                number of rows in the position table.
            targets: optional (B, T) tensor of expected next-token ids.

        Returns:
            (logits, loss). Without targets, logits is (B, T, vocab_size) and
            loss is None. With targets, logits is flattened to
            (B*T, vocab_size) and loss is the cross-entropy against the
            flattened targets.
        """
        B, T = idx.shape

        embedded_tokens = self.token_embedding_table( idx ) # (B, T, C) with C = n_embd
        # Build the position indices on the same device as the input rather than
        # on the global `device`, so the model works wherever its input lives.
        positions = torch.arange( T, device = idx.device )
        positional_token = self.position_embedding_table( positions ) # (T, C)
        tokens = embedded_tokens + positional_token # (B, T, C), positions broadcast over the batch
        non_linear_tokens = self.blocks( tokens )
        logits = self.final_layer( self.linear1( non_linear_tokens ) ) # (B, T, vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            # cross_entropy wants one row of logits per target, so flatten batch and time
            logits = logits.view( B*T, C )
            targets = targets.view( B*T )
            loss = F.cross_entropy( logits, targets )

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature = 1.0, debug= True):
        """Extend idx by max_new_tokens sampled tokens and return the longer sequence.

        Runs with gradients disabled: the result is a tensor of sampled integer
        ids, so nothing here needs an autograd graph.

        Args:
            idx: (B, T) tensor of token ids used as the starting context.
            max_new_tokens: number of tokens to append.
            temperature: divisor applied to the logits before the softmax;
                values below 1 sharpen the distribution, values above 1 flatten it.
            debug: when true, print the decoded sequences, the sampled
                character and the sorted probability table after every step.

        Returns:
            (B, T + max_new_tokens) tensor of token ids.
        """
        for _ in range( max_new_tokens ):
            # The position table only has block_size rows, so keep at most the last block_size tokens
            cropped_indexes = idx[ : , -block_size: ]
            logits, loss = self( cropped_indexes )
            # Only the final position predicts the token that comes next
            logits = logits[:, -1, :] # (B, vocab_size)
            logits = logits / temperature
            # softmax to get probabilities
            probs = F.softmax( logits, dim = -1 ) # (B, vocab_size)
            # Sample one token id per sequence from the distribution
            idx_next = torch.multinomial(probs, num_samples = 1) # (B, 1)
            # append to the running sequence
            idx = torch.cat( ( idx, idx_next ), dim = 1 ) # (B, T + 1)
            if debug:
                print(f"fully updated: {[decode(i) for i in idx.tolist()]}, predicted: {int_to_string[int(idx_next[-1,-1])]}, probs,{ sorted({int_to_string[i]:x for i, x in enumerate(probs[0].tolist()) }.items(), key=lambda x: -x[1]) }")
        return idx

# Instantiate the network and place it on the configured device.
# nn.Module.to returns the module itself, so `model` and `m` name the same object.
m = model = BigramLanguageModelV6().to( device )
# Draw one training batch up front
xb, yb = get_batch("train")

# Starting values for the `loss` / `val` defaults of train(); None means nothing recorded yet
loss = val = None

optimizer = torch.optim.AdamW(
    params = m.parameters(),
    lr = learning_rate,
    weight_decay = 1e-5,
)
def train(max_iters = max_iters, m: BigramLanguageModelV6 = m, loss = loss, val = val, optimizer=optimizer, eval_interval = eval_interval):
    """Train `m` for max_iters steps, checkpointing whenever both estimated losses improve.

    Args:
        max_iters: number of optimisation steps.
        m: the model to optimise.
        loss: best train loss seen so far, or None if there is none yet.
        val: best validation loss seen so far, or None if there is none yet.
        optimizer: optimizer that updates the parameters of `m`.
        eval_interval: number of steps between loss evaluations.

    Every eval_interval steps, and on the final step, estimate_loss() is
    called. When no best losses are recorded yet, or when both the train and
    the validation estimate are below the recorded bests, the state dict of
    `m` is written to "Decent_Weights.pt".
    """
    # NOTE(review): `tqdm` is not imported in the visible notebook cell; it has to
    # be in scope (e.g. `from tqdm import tqdm`) before this function is called.
    # NOTE(review): estimate_loss() evaluates the module-level `model`. That is
    # the same object as the default `m`, but not if a different model is passed in.
    # Todo Np.around tolerance to stop at 20 iterations with no change

    # Track the best estimates under their own names so that the per-batch
    # training loss below cannot overwrite them.
    best_train, best_val = loss, val

    for steps in tqdm(range(max_iters)): # can decrease this

        # sample a batch
        xb, yb = get_batch("train")

        # evaluate the loss on this batch
        _, batch_loss = m( xb.to( device ), yb.to( device ) )
        optimizer.zero_grad(set_to_none=True)
        # backprop
        batch_loss.backward()

        # Clip the gradients of the model actually being trained
        torch.nn.utils.clip_grad_norm_(m.parameters(), 1.50)

        # actual update
        optimizer.step()

        # every once in a while evaluate the loss on train and val sets
        if steps % eval_interval == 0 or steps == max_iters - 1:
            losses = estimate_loss()
            improved = (
                best_train is None
                or best_val is None
                or ( best_train > losses['train'] and best_val > losses['val'] )
            )
            if improved:
                best_train, best_val = losses['train'], losses['val']
                torch.save(m.state_dict(), "Decent_Weights.pt")
                # tqdm.write keeps the message from breaking the progress bar
                tqdm.write("Saving weights")

            tqdm.write(f"step {steps}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")



In [2]:
# Sample from the model
m = BigramLanguageModelV6()
# map_location remaps the stored tensors onto the configured device, so a
# checkpoint written during a CUDA run can also be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.
m.load_state_dict(torch.load("Working_Weights.pt", map_location = device))
m = m.to(device)
# Switch to eval mode so that dropout layers are inactive while sampling
m.eval()
# Start from a single token with id 0 and sample 100 further characters
print(
    decode(
        m.generate(
            idx = torch.zeros( ( 1, 1 ), dtype=torch.long, device = device ),
            max_new_tokens = 100,
        )[0].tolist()
    )
)

fully updated: ['\nA'], predicted: A, probs,[('A', 0.11053680628538132), ('\n', 0.1102173775434494), ('T', 0.10855063796043396), ('B', 0.08139406889677048), ('S', 0.06135903671383858), ('I', 0.050282351672649384), ('F', 0.04876800626516342), ('M', 0.047793589532375336), ('H', 0.04098954424262047), ('W', 0.03908486291766167), ('C', 0.038282610476017), ('N', 0.030903108417987823), ('O', 0.0261487178504467), ('D', 0.025670954957604408), ('G', 0.022603170946240425), ('L', 0.022280341014266014), ('P', 0.02152971923351288), ('R', 0.021247301250696182), ('K', 0.01625310629606247), ('Y', 0.015390818007290363), ('E', 0.009239994920790195), ('w', 0.004571281373500824), ('Q', 0.004198197275400162), ('U', 0.004157438408583403), ('V', 0.003257596865296364), ('a', 0.0029169360641390085), ('t', 0.002856010338291526), ("'", 0.0024335775524377823), ('J', 0.0024246154353022575), ('s', 0.0022939322516322136), ('m', 0.0018834862858057022), ('h', 0.0018225854728370905), ('d', 0.0014559741830453277), ('b', 