In [None]:
import torch
import torch.nn  as nn
import torch.nn.functional as F
import seaborn as sns

In [None]:
# Tiny character-level training corpus, kept inline so the notebook needs no external data file.
txt = """ #Title: The Chronicles of Terra

Chapter 1: The Awakening In the year 3025, the planet Terra had become a thriving hub of advanced technology and interstellar travel. Amidst this bustling world, a young engineer named Aria discovered an ancient artifact that would change her life forever.

Chapter 2: The Mysterious Artifact The artifact, a small, glowing sphere, seemed to pulse with energy. Aria couldn't decipher its origin, but she knew it held immense power. She decided to seek out the help of Dr. Lumin, an expert in ancient civilizations.

Chapter 3: Dr. Lumin's Laboratory Dr. Lumin was fascinated by the artifact and revealed that it was a relic from a long-lost civilization. "This sphere is a key to unlocking hidden knowledge," he explained. Together, they embarked on a quest to uncover its secrets.

Chapter 4: Journey to the Desert Their journey led them to the vast deserts of Terra, where the ancient civilization once thrived. As they navigated the treacherous terrain, they encountered mysterious symbols and hidden passages.

Chapter 5: The Guardian's Test At the heart of the desert, they discovered a hidden temple guarded by a formidable sentinel. "To proceed, you must prove your worth," the guardian declared. Aria and Dr. Lumin faced a series of challenges, testing their intellect and bravery. """

In [None]:
total_tokens = len(txt) # number of training tokens: the corpus is tokenised per character, so this is the character count of txt

In [None]:
# Build the character vocabulary and the two lookup tables between characters and ids.

# Every distinct character of the corpus, in sorted order, joined into one string;
# a character's position in this string is its token id.
vocab = "".join(sorted(set(txt)))
vocab_size = len(vocab)

# character -> id
stoi = {ch: idx for idx, ch in enumerate(vocab)}

# id -> character
itos = dict(enumerate(vocab))


In [None]:
def encoder(text):
    """Map a string to a list of integer token ids, one per character.

    Raises:
        KeyError: if ``text`` contains a character that is not a key of ``stoi``.
    """
    return [stoi[ch] for ch in text]


def decoder(tokens):
    """Map an iterable of token ids back to the string they spell.

    Each id goes through ``int()`` before the lookup so that ids stored in a
    tensor can be decoded too: iterating a tensor yields 0-d tensors, which
    never match the plain-int keys of ``itos``.

    Raises:
        KeyError: if an id is not a key of ``itos``.
    """
    return "".join(itos[int(tok)] for tok in tokens)


# Round-trip sanity check: encoding then decoding should give the input string back.
decoder(encoder('hi fwwwwejw bcvwcb vtheredfwr'))



In [None]:
# Show the token id assigned to the "#" character.
encoder("#")


In [None]:
# Decode a hand-picked list of token ids back into text.
decoder([23,32,12,32])

In [None]:
block = 8  # context length: number of tokens in each training window
embed_dim = 64  # width of the token-embedding and positional-encoding vectors

# The whole corpus converted to a list of token ids.
encoded_txt = encoder(txt)


In [None]:
def get_batch(batch_size=1):
    """Sample random training windows from ``encoded_txt``.

    Args:
        batch_size: number of windows to draw; the default of 1 keeps the
            original single-window behaviour.

    Returns:
        ``(batch_x, batch_y)``, two integer tensors of shape
        ``[batch_size, block]``. ``batch_y`` is ``batch_x`` shifted one token
        to the right, i.e. the next-token target for every input position.
    """
    # randint's upper bound is exclusive, so the largest start still leaves
    # room for the one-token-shifted target window.
    starts = torch.randint(0, total_tokens - block, (batch_size,)).tolist()

    batch_x = torch.tensor([encoded_txt[s:s + block] for s in starts])
    batch_y = torch.tensor([encoded_txt[s + 1:s + block + 1] for s in starts])

    return batch_x, batch_y

In [None]:

def get_batch1(pad_id=3):
    """Sample one window and expand it into left-padded growing prefixes.

    Row ``n`` (0-based) of the input holds the first ``n + 1`` tokens of the
    window, left-padded with ``pad_id`` to length ``block``; its target is the
    token that follows that prefix in the corpus.

    Args:
        pad_id: token id used for left padding. The default 3 is the value the
            original code hard-coded (presumably the id of "#" for this
            corpus; confirm with ``encoder("#")``).

    Returns:
        ``(batch_x, y)``: an integer tensor of shape ``[block, block]`` and an
        integer tensor of shape ``[block]``.
    """
    start = torch.randint(0, total_tokens - block, (1,)).item()
    window = encoded_txt[start:start + block]
    targets = encoded_txt[start + 1:start + block + 1]

    batch_x = [
        [pad_id] * (block - prefix_len) + window[:prefix_len]
        for prefix_len in range(1, block + 1)
    ]

    return torch.tensor(batch_x), torch.tensor(targets)


# Quick look at one sampled (input, target) pair.
get_batch()
    

In [None]:

# Draw one random window start, then cut the input window and its
# one-token-shifted target window straight out of the encoded corpus.
sample = torch.randint(0, total_tokens - block, (1,))

x = torch.tensor(encoded_txt[sample:sample + block])
y = torch.tensor(encoded_txt[sample + 1:sample + block + 1])


In [None]:
def pos_emb(seq, model_dim):
    """Build the sinusoidal positional-encoding table.

    For position ``pos`` and even column ``i``::

        table[pos, i]     = sin(pos / 10000 ** (i / model_dim))
        table[pos, i + 1] = cos(pos / 10000 ** (i / model_dim))

    ``i`` already steps over the even column indices, so it is used directly
    in the exponent (multiplying it by 2 again would square every frequency).

    Args:
        seq: number of positions (rows).
        model_dim: embedding width (columns). Odd widths are accepted; the
            last column then only receives the sine term.

    Returns:
        Float tensor of shape ``[seq, model_dim]``.
    """
    positions = torch.arange(seq, dtype=torch.float32).unsqueeze(1)    # [seq, 1]
    even_cols = torch.arange(0, model_dim, 2, dtype=torch.float32)     # [ceil(model_dim / 2)]
    angles = positions / torch.pow(torch.tensor(10000.0), even_cols / model_dim)

    pos_vec = torch.zeros(seq, model_dim)
    pos_vec[:, 0::2] = torch.sin(angles)
    # With an odd model_dim there is one cosine column fewer than sine columns.
    pos_vec[:, 1::2] = torch.cos(angles[:, : model_dim // 2])
    return pos_vec

In [None]:
import matplotlib.pyplot as plt

# Plot the positional-encoding vector of a single position (index 3) across
# all embedding dimensions, with labels so the figure can be read on its own.
pos_digram = pos_emb(block,embed_dim)[3]

fig, ax = plt.subplots()
ax.plot(pos_digram)
ax.set(
    title="Positional encoding at position 3",
    xlabel="embedding dimension",
    ylabel="encoding value",
)
plt.show()

In [None]:

# Learnable lookup table: one embed_dim-wide vector per vocabulary entry.
tok_emb = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)

# Positional encodings for one window, and the embedded sample window x.
position_embedd = pos_emb(seq=block, model_dim=embed_dim)
input_embed = tok_emb(x)

# Sum the two so each token vector also carries its position; show the shape.
position_aware_embed = position_embedd + input_embed
position_aware_embed.shape

In [None]:
class Head(nn.Module):
    """Single causal self-attention head.

    The query, key and value projections are bias-free linear maps that keep
    the embedding width, so the output has the same shape as the input.

    Args:
        embed_dim: size of the last dimension of the input.
    """

    def __init__(self, embed_dim):
        super().__init__()
        # Kept on the instance so forward() scales by this head's own width
        # rather than by whatever a module-level variable happens to hold.
        self.embed_dim = embed_dim
        self.query = nn.Linear(embed_dim, embed_dim, bias=False)
        self.key = nn.Linear(embed_dim, embed_dim, bias=False)
        self.value = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, position_aware_embed):
        """Apply causal scaled dot-product attention.

        Args:
            position_aware_embed: tensor of shape ``[..., seq, embed_dim]``.

        Returns:
            Tensor of the same shape; row ``t`` is a weighted mix of the value
            vectors of positions ``0..t`` only.
        """
        key_mat = self.key(position_aware_embed)
        query_mat = self.query(position_aware_embed)
        value_mat = self.value(position_aware_embed)

        # Scaled dot-product scores, shape [..., seq, seq].
        scores = (query_mat @ key_mat.transpose(-1, -2)) / (self.embed_dim ** 0.5)

        # Explicit lower-triangular mask. Testing `tril(scores) == 0` instead
        # would also hide allowed positions whose score happens to be exactly 0
        # (e.g. all-zero input), which turns whole softmax rows into NaN.
        seq_len = scores.size(-1)
        causal = torch.tril(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=scores.device)
        )
        wei = scores.masked_fill(~causal, float("-inf"))
        wei = F.softmax(wei, dim=-1)

        # Context-aware embeddings.
        return wei @ value_mat
# Instantiate one attention head sized to the embedding width.
head = Head(embed_dim)

In [None]:
class feedForward(nn.Module):
    """Two-layer MLP mapping embeddings to one unnormalised score per vocabulary entry.

    The forward pass returns raw scores (logits), not probabilities:
    ``F.cross_entropy`` applies log-softmax itself and samplers apply their
    own softmax, so normalising here would apply softmax twice.

    Args:
        in_dim: input feature size; ``None`` uses the notebook-level ``embed_dim``.
        out_dim: number of output scores; ``None`` uses the notebook-level ``vocab_size``.
        hidden_dim: width of the hidden layer.
    """

    def __init__(self, in_dim=None, out_dim=None, hidden_dim=128):
        super().__init__()
        in_dim = embed_dim if in_dim is None else in_dim
        out_dim = vocab_size if out_dim is None else out_dim
        self.l1 = nn.Linear(in_dim, hidden_dim)
        self.logit = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        """Return logits of shape ``[..., out_dim]`` for input ``[..., in_dim]``."""
        hidden = F.relu(self.l1(x))
        return self.logit(hidden)
    
# Push the sample window through the (untrained) head and output layer and
# read off the highest-scoring token id at every position.
projection = feedForward()
logit = projection(head(position_aware_embed))

# Named `pred_ids` rather than `id` so the builtin id() is not shadowed for
# the rest of the kernel session; argmax runs over the last (vocabulary) axis.
pred_ids = torch.argmax(logit, dim=-1)
pred_ids.tolist(), logit.shape

In [None]:
import numpy as np

# Render the first parameter tensor registered on the head as an image.
plt.imshow(np.array(next(head.parameters()).detach()))

In [None]:
# Show the token id assigned to the newline character.
encoder('\n')

In [None]:
# Fresh attention head and output layer for training.
head = Head(embed_dim)
projection = feedForward()

# One AdamW optimizer over every trainable tensor: the head, the output
# layer and the token-embedding table.
optim_dense = torch.optim.AdamW(
    [
        *head.parameters(),
        *projection.parameters(),
        *tok_emb.parameters(),
    ],
    lr=0.1,
)

def train(epochs=100, lr=0.1):
    """Run ``epochs`` optimisation steps on the notebook-level model.

    Every step draws one random window with ``get_batch()``, runs it through
    ``tok_emb`` -> positional encoding -> ``head`` -> ``projection`` and
    updates all three modules through ``optim_dense``.

    Args:
        epochs: number of optimisation steps.
        lr: learning rate to use for this run.
    """
    # The optimizer is created outside this function, so the requested rate
    # has to be written into it; otherwise the `lr` argument has no effect.
    for group in optim_dense.param_groups:
        group["lr"] = lr

    # The positional table depends only on (block, embed_dim), so build it
    # once; the leading axis of size 1 broadcasts over the batch dimension.
    position_embedd = pos_emb(block, embed_dim).unsqueeze(0)  # [1, block, embed_dim]

    for i in range(epochs):
        x, targets = get_batch()  # each [1, block]

        input_embed = tok_emb(x)  # [1, block, embed_dim]
        position_aware_embed = input_embed + position_embedd

        # forward pass
        head_out = head(position_aware_embed)  # [1, block, embed_dim]
        logits = projection(head_out)          # [1, block, vocab_size]

        # Flatten to [block, vocab_size] scores vs. [block] targets.
        # F.cross_entropy applies log-softmax internally, so it expects
        # unnormalised scores here.
        logits_flattened = logits.view(-1, logits.size(-1))
        targets_flattened = targets.view(-1)
        loss = F.cross_entropy(logits_flattened, targets_flattened)

        # backward and optimize
        optim_dense.zero_grad()
        loss.backward()
        optim_dense.step()

        if i % 100 == 0:
            # .item() prints a plain float instead of a tensor repr with grad_fn.
            print(i, "loss :==>", loss.item())


train(epochs=1000,lr = 0.0003)

In [None]:
def generate(start_ids, max_tokens=100, temperature=1.0):
    """Autoregressively extend a prompt and return the decoded text.

    Args:
        start_ids: non-empty sequence of prompt token ids. It is copied, so
            the caller's list is left untouched.
        max_tokens: number of new tokens to append.
        temperature: if > 0, sample from ``softmax(scores / temperature)``;
            otherwise pick the highest-scoring token.

    Returns:
        The prompt followed by the generated tokens, decoded to a string.

    Raises:
        ValueError: if ``start_ids`` is empty.
    """
    # Work on a private copy. Aliasing the argument and appending to both
    # names would add every new token twice and mutate the caller's list.
    generated_ids = [int(tok) for tok in start_ids]
    if not generated_ids:
        raise ValueError("start_ids must contain at least one token id")

    with torch.no_grad():  # inference only; no autograd graph needed
        for _ in range(max_tokens):
            # Embed the entire current sequence: [1, T, embed_dim].
            input_embed = tok_emb(torch.tensor([generated_ids]))

            # One positional row per token currently in the sequence: [T, embed_dim].
            position_embedd = pos_emb(len(generated_ids), input_embed.size(-1))
            position_aware_embed = input_embed + position_embedd

            head_out = head(position_aware_embed)
            logits = projection(head_out[:, -1, :])  # scores for the last position only

            if temperature > 0:
                probabilities = torch.softmax(logits / temperature, dim=-1)
                next_token = torch.multinomial(probabilities, num_samples=1).item()
            else:
                next_token = torch.argmax(logits, dim=-1).item()

            generated_ids.append(next_token)

    # Decode once at the end, from plain ints.
    return decoder(generated_ids)


generate([0,1,2,3,4])