In [1]:
import math 
import torch 
import torch.nn as nn 
import transformers 
from tqdm.notebook import tqdm

  from .autonotebook import tqdm as notebook_tqdm


# Embedding Layer

In [2]:
import torch 

class EmbeddingLayer(torch.nn.Module):
    def __init__(self , vocab_size , embedding_dim):
        super().__init__()

        self.embedding_layer= torch.nn.Embedding(vocab_size , embedding_dim)

    def forward(self , input_tokens):
        return self.embedding_layer(input_tokens)
        


# FeedForward Layer 

In [3]:
import torch.nn as nn

class GELU(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, x):
        return 0.5 * x *(1+ torch.tanh(torch.sqrt(torch.tensor(2.0/ torch.pi)) * (x+0.044715 * torch.pow(x, 3))) )
      

class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(cfg['emb_dim'] , 4 * cfg['emb_dim']) ,
            GELU(),
            nn.Linear(4 * cfg['emb_dim'] , cfg['emb_dim'])
        )
    def forward(self, x ):
        return self.layers(x)
    


# Normalization Layer

In [4]:
import torch.nn as nn 
import torch 

class LayerNorm(nn.Module):
    def __init__(self , emb_dim):
        super().__init__()
        self.eps = 1e-5
        self.scale  = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))
    def forward(self , x):
        mean = x.mean(dim= -1, keepdim = True)
        var = x.var(dim =-1, keepdim = True)
        norm_x = (x - mean) / torch.sqrt(var +self.eps)
        return self.scale * norm_x + self.shift 


# RoPE Embdding 

In [5]:
import numpy as np 
import torch 
from dataclasses import dataclass


class NRopE: # RopE in Numpy 
    def rotate_2d(self,vec , theta_p):
        cos_theta  , sin_theta  = np.cos(theta_p) , np.sin(theta_p)
        rotat_vec = np.array([[cos_theta , -sin_theta],
                    [sin_theta ,cos_theta]])
        
        return rotat_vec @ vec


    def RoPe(self,x , p , theta = 10000):
        d = len(x)
        x_rotate =  np.zeros_like(x)
        for i in range(0 , d , 2):
            if i +1< d:
                theta_p = (theta **(-2*(i//2)))**p 
                roted_pair = self.rotate_2d(x[i:i+1] , theta_p)    
                x_rotate[i:i+1] = roted_pair

        return x_rotate



@dataclass
class TRopE(torch.nn.Module): # RopE in torch 
    def __init__(self, dim:int ,theta:float = 10000):
        self.dim = dim 
        self.theta = theta 
        self.freq =  torch.pow(self.theta ,-torch.arange(0 ,dim  , 2)/dim )
        torch.nn.Parameter('freq' , self.freq)

    def forward(self, x:torch.Tensor , pos:torch.Tensor):
        batch_size , seq_len, dim = x.shape
        assert dim ==self.dim ,"Error dim must be same"
        theta_p = torch.einsum("n,d->nd" , pos, self.freq.to(x.device))
        cos_theta  , sin_theta = torch.cos(theta_p) , torch.sin(theta_p)
        x_even , x_odd =  x[... , ::2] , x[... , 1::2]
        x_rotated =  torch.empty_like(x)
        x_rotated[...,::2] =  x_even * cos_theta - x_odd * sin_theta
        x_rotated[...,1::2] =  x_even * sin_theta + x_odd * cos_theta

        return x_rotated







def precompute_freq_cis(  dim:int , end:int , theta:float = 10000.0):
        """dim : dimentions 
        end: end index   
        """
        freqs =  1/(theta **(torch.arange(0 , dim , 2)[:dim//2].float() / dim))
        t =  torch.arange(end, device=freqs.device)
        freqs = torch.outer(t , freqs).float()
        freqs_cis =  torch.polar(torch.ones_like(freqs), freqs)
        return freqs_cis 


def reshape_for_broadcast(freq_cis  , x):
        """ reshape the freqcies to match x dimentions """
        ndim=  x.ndim
        assert 0<=1<ndim 
        assert freq_cis.shape == (x.shape[1], x.shape[-1]), f"Expected {(x.shape[1], x.shape[-1])}, got {freq_cis.shape}" 
        shape = [d if i == 1 or i ==  ndim -1 else 1 for i , d in enumerate(x.shape)]
        return freq_cis.view(*shape)


def apply_rotary_embedding( xq:torch.Tensor ,xk:torch.Tensor ,  freq_cis:torch.Tensor):

            xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1,2))

            xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1,2))


            freq_cies =  reshape_for_broadcast(freq_cis , xq_)
    

            xq_out = torch.view_as_real(xq_* freq_cies).flatten(3)
            
            xk_out = torch.view_as_real(xk_*freq_cies).flatten(3)


            return  xq_out.type_as(xq)   ,  xk_out.type_as(xq) 





# MultiHead & MultiQuery Attention Layer 

In [6]:

class MultiHeadAttention_V2(nn.Module):
    def __init__(self, d_in , d_out , context_length  , dropout ,num_heads,qkv_bias = False):
        super().__init__()
        assert d_out % num_heads  == 0,'d_out must be divisible by the num_heads'
        self.w_query = nn.Linear(d_in , d_out ,bias=qkv_bias)
        self.w_key = nn.Linear(d_in , d_out, bias=qkv_bias)
        self.w_value = nn.Linear(d_in  , d_out,bias=qkv_bias)
        self.d_in =d_in
        self.d_out = d_out
        self.dropout = nn.Dropout(dropout)
        self.num_heads  = num_heads
        self.head_dim = d_out // num_heads
        self.out_proj  = nn.Linear(d_out , d_out)
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length , context_length),diagonal=1)
        )

    def forward(self,x):
        b, num_tokens , d_in = x.shape
        keys = self.w_key(x)
        queries  = self.w_query(x)
        values = self.w_value(x)
        queries = queries.view(b, num_tokens , self.num_heads , self.head_dim)
        values = values.view(b , num_tokens , self.num_heads , self.head_dim)
        keys = keys.view( b, num_tokens , self.num_heads , self.head_dim)
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1,2)
        values = values.transpose(1,2)
        attn_scores = queries @ keys.transpose(2, 3)
        mask_bool= self.mask.bool()[:num_tokens , :num_tokens]
        attn_scores.masked_fill(mask_bool , -torch.inf)
        attn_weights = torch.softmax(attn_scores /self.head_dim**0.5   , dim=-1 )
        attn_weights = self.dropout(attn_weights)
        context_vector = (attn_weights  @ values).transpose(1, 2)
        context_vector = context_vector.contiguous().view(b , num_tokens , self.d_out)
        context_vector = self.out_proj(context_vector)
        return context_vector




def apply_rotary_embedding(xq:torch.Tensor , xk:torch.Tensor , freq_cies:torch.Tensor):

    assert xq.shape[-1] % 2 == 0 , 'Embeddig dimension must be even for complex paring'

    assert xk.shape[-1] % 2 == 0 , 'Embeddig dimension must be even for complex paring'


    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1,2))

    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1] , -1, 2))

    freq_cies = reshape_for_broadcast(freq_cies , xq_)

    xq_out = torch.view_as_real(xq_ * freq_cies ).flatten(3)

    xk_out = torch.view_as_real(xk_ * freq_cies).flatten(3)

    return xq_out.type_as(xq) ,  xk_out.type_as(xk)




class MultiQueryAttentionBlock(nn.Module):
    def __init__(self, d_model:int , h:int , dropout:float , seq_len:int , qkv_bias =  False ):
        super().__init__()
        self.d_model  = d_model 

        self.seq_len=  seq_len

        assert d_model % h == 0, "d_model is must be divided by th head"
        self.dropout = nn.Dropout(dropout)

        self.h = h  

        self.d_k = d_model // h 

        self.w_qkv =  nn.Linear(d_model , d_model +2 * self.d_k )

        self.w_o = nn.Linear(d_model  , d_model)

        freq_cies = precompute_freq_cis(dim=self.d_k , end=self.seq_len * 2 )

        self.register_buffer('freq_cies' , freq_cies , persistent= False )

    def generate_causal_mask(self, seq_len, device):
        # shape: (1, 1, seq_len, seq_len)
        return torch.tril(torch.ones((1, 1, seq_len, seq_len), device=device)).bool()

    @staticmethod
    def attention(q, k  , v,mask  , dropout):
        d_k = q.shape[-1]

        attention_score =  (q @ k.transpose(-2,-1)) / math.sqrt(d_k)

        if mask is not None :
            if mask.dim() == 2:
                      mask = mask.unsqueeze(1).unsqueeze(2)
            elif mask.dim() == 3:
                     mask = mask.unsqueeze(1)
            attention_score = attention_score.masked_fill(mask == 0, float('-inf'))
        
        attention_score = attention_score.softmax(dim=-1)

        if dropout is not None :
            attention_score = dropout(attention_score)

        context_vector =  attention_score @ v

        return context_vector  , attention_score
    


    def forward(self, q, mask= None):
        if mask is None:
            mask = self.generate_causal_mask(self.seq_len , device = q.device)
        qkv =  self.w_qkv(q)

        query , key, value =  torch.split(qkv , [self.d_model  , self.d_k , self.d_k], dim=-1)

        query = query.view(query.shape[0] , -1 , self.h , self.d_k).transpose(1, 2)

        key =  key.unsqueeze(1)

        value =  value.unsqueeze(1)

        seq_len =  q.size(1)

        freq_cies = self.freq_cies[:query.shape[1]].to(q.device)

        # freq_cies =  self.freq_cies[:seq_len].to(q.device)

        query , key = apply_rotary_embedding(query , key , freq_cies)

        x , self.attention_score = MultiQueryAttentionBlock.attention(q = query,k =  key,v= value ,mask=mask , dropout= self.dropout)

        x = x.transpose(1,2).contiguous().view(x.shape[0] , -1, self.h* self.d_k)

        x = self.w_o(x)

        return x 
    






# Transformer Block

In [7]:



class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = MultiHeadAttention_V2(
        d_in=cfg["emb_dim"],
        d_out=cfg["emb_dim"],
        context_length=cfg["context_length"],
        num_heads=cfg["n_heads"],
        dropout=cfg["drop_rate"],
        qkv_bias=cfg["qkv_bias"])
        self.ff = FeedForward(cfg)
        self.norm1 = LayerNorm(cfg["emb_dim"])
        self.norm2 = LayerNorm(cfg["emb_dim"])
        self.drop_resid = nn.Dropout(cfg["drop_rate"])
    def forward(self, x):
    #A
        shortcut = x
        x = self.norm1(x)
        x = self.att(x)
        x = self.drop_resid(x)
        x = x + shortcut # Add the original input back
        shortcut = x #B
        x = self.norm2(x)
        x = self.ff(x)
        x = self.drop_resid(x)
        x = x + shortcut #C
        return x



class TransformerBlock_v2(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.attention =  MultiQueryAttentionBlock(d_model=cfg['emb_dim'], h=cfg['n_heads'] , dropout=cfg['drop_rate'], seq_len=  cfg['context_length'] ,qkv_bias=cfg['qkv_bias'])

        self.feed_forward = FeedForward(cfg)

        self.layernorm1 =  LayerNorm(cfg['emb_dim'])
    
        self.layernorm2 =  LayerNorm(cfg['emb_dim'])

        self.drop_out = nn.Dropout(cfg['drop_rate'])

    def forward(self, x , mask= None):

        attention_output =  self.attention(self.layernorm1(x) , mask =  mask)

        ff_output =  self.feed_forward(self.layernorm2(x))

        return x + self.drop_out(ff_output) + self.drop_out(attention_output)



# GPTQModel

In [8]:
import torch.nn.functional as F


class InputEmbedding(nn.Module):

    def __init__(self, vocab_size: int , d_model:int):

        super().__init__()

        self.d_model  =  d_model 

        self.vocab_size = vocab_size

        self.embeddings = nn.Embedding(vocab_size , d_model)

    def forward(self ,x):

        return self.embeddings(x) * math.sqrt(self.d_model)
    



class ProjectionLayer(nn.Module):
    def __init__(self, d_model: int, vocab_size: int, embdding_layer: nn.Embedding):
        super().__init__()
        self.weight = embdding_layer.weight  # share weights with input embedding
        self.bias = nn.Parameter(torch.zeros(vocab_size))  # learnable bias

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)




class GPTMQModel2(nn.Module):
    def __init__(self, cfg):
        super().__init__()

        self.embedding = InputEmbedding(cfg['vocab_size'], cfg['emb_dim'])

        # Use ModuleList instead of Sequential
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock_v2(cfg=cfg) for _ in range(cfg['n_layers'])
        ])

        self.drop_out = nn.Dropout(cfg['drop_rate'])
        self.final_norm = LayerNorm(emb_dim=cfg['emb_dim'])

        self.projection = ProjectionLayer(cfg['emb_dim'], cfg['vocab_size'], self.embedding.embeddings)

    def forward(self, input_tokens, mask=None):
        x = self.embedding(input_tokens)

        for block in self.transformer_blocks:
            x = block(x, mask=mask)  # Pass the mask explicitly to each block

        x = self.final_norm(x)
        logits = self.projection(x)

        return logits



# Loss Functions

In [9]:




def cal_loss_batch(input_batch , target_batch , model:torch.nn.Module , device:torch.device):
    input_batch , target_batch = input_batch.to(device) , target_batch.to(device)
    logits = model(input_batch)
    loss = torch.nn.functional.cross_entropy(   logits.flatten(0,1), target_batch.flatten())
    return loss


def calc_loss_loader(data_loader , model , device , num_batches = None):
    total_loss = 0
    if num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches  = min(num_batches , len(data_loader))
    for i , (inputs , target) in enumerate(data_loader):
        if i < num_batches:
            loss  =  cal_loss_batch(inputs , target , model , device)

            total_loss +=loss.item()

        else:
            break

        return total_loss  / num_batches
    


# Text Generation Function

In [10]:
import torch
import torch.nn as nn 




def text_to_token_ids(text,  tokenizer):
    encoded = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
    encoded = torch.tensor(encoded).unsqueeze(0)
    return encoded


def token_ids_to_text(tokens , tokenizer):
    flat  = tokens.squeeze(0)
    decode = tokenizer.decode(flat.tolist())
    return decode

    
def generate_and_sample(model  , idx , context_size ,max_new_tokens ):
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]
        with torch.no_grad():
            logits = model(idx_cond)
            print(logits.shape)
        logits  = logits[:, -1  , :]
        print(logits.shape)
        probs  = torch.softmax(logits  , dim=-1)
        print(probs)
        idx_next = torch.argmax(probs, dim=-1 , keepdim= True)
        print(idx_next)
        idx = torch.cat((idx, idx_next), dim=1)
    return idx 

# def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None):
#     for _ in range(max_new_tokens):
#         idx_cond = idx[:, -context_size:]  # shape: [1, current_seq_len]

#         # Create causal mask dynamically
#         seq_len = idx_cond.size(1)
#         causal_mask = torch.tril(torch.ones(seq_len, seq_len)).to(idx.device)
#         causal_mask = causal_mask.unsqueeze(0)  # [1, seq_len, seq_len]

#         with torch.no_grad():
#             logits = model(idx_cond, mask=causal_mask)  # <--- pass mask here

#         logits = logits[:, -1, :]  # only take the last token logits

#         # Apply top-k sampling if needed
#         if top_k is not None:
#             top_logits, _ = torch.topk(logits, top_k)
#             min_val = top_logits[:, -1]
#             logits = torch.where(
#                 logits < min_val,
#                 torch.tensor(float('-inf')).to(logits.device),
#                 logits
#             )

#         # Temperature sampling
#         if temperature > 0.0:
#             logits = logits / temperatureallowed_special=allowed
#             probs = torch.softmax(logits, dim=-1)
#             idx_next = torch.multinomial(probs, num_samples=1)
#         else:
#             idx_next = torch.argmax(logits, dim=-1, keepdim=True)

#         idx = torch.cat((idx, idx_next), dim=1)
#     return idx
def generate(model, idx, max_new_tokens, context_size, temperature=1.0, top_k=None, end_token_id=None):
    model.eval()
    max_seq_len = context_size
    full_causal_mask = torch.tril(torch.ones(max_seq_len, max_seq_len)).to(idx.device).unsqueeze(0)

    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]
        seq_len = idx_cond.size(1)
        causal_mask = full_causal_mask[:, :seq_len, :seq_len]

        with torch.no_grad():
            logits = model(idx_cond, mask=causal_mask)
        logits = logits[:, -1, :]  # Take last token logits

        # Top-k sampling
        if top_k is not None:
            top_values, _ = torch.topk(logits, top_k)
            threshold = top_values[:, -1].unsqueeze(1)
            logits = logits.masked_fill(logits < threshold, float('-inf'))

        logits = logits / temperature
        logits = torch.clamp(logits, -100, 100)
        probs = torch.softmax(logits, dim=-1)

        idx_next = torch.multinomial(probs, num_samples=1)
        idx = torch.cat((idx, idx_next), dim=1)

        if end_token_id is not None and idx_next.item() == end_token_id:
            break

    return idx





def real_time_generation(model, initial_input, context_size, temperature, top_k=None, device="cpu"):
    # Tokenize the initial input and prepare the model context
    idx = torch.tensor(initial_input).unsqueeze(0).to(device)  # Assuming initial_input is tokenized
    
    print("Starting real-time generation...")
    
    # Start generating tokens in real-time
    for new_token in generate(model, idx, max_new_tokens=50, context_size=context_size, temperature=temperature, top_k=top_k, device=device):
        print(f"Generated token: {new_token.item()}")  # Or decode it back to a word
        
        # You can check for user input here and update idx with the new input
        # For instance, wait for the user to input a prompt to append to the context
        user_input = input("Enter new input (or press enter to continue generation): ")
        
        if user_input:
            # Tokenize the new user input and append it to the context
            user_input_tokens = torch.tensor(tokenize(user_input)).unsqueeze(0).to(device)
            idx = torch.cat((idx, user_input_tokens), dim=1)  # Append the new tokens to the context
        else:
            # Continue generating if no new user input
            continue

# Function to tokenize input (adjust depending on your tokenizer)
def tokenize(text):
    # Assuming you have a tokenizer function available
    return [ord(c) for c in text]  # Dummy example: ord() converts char to token id



In [13]:
! pip install tiktoken

Collecting tiktoken
  Downloading tiktoken-0.9.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m14.7 MB/s[0m eta [36m0:00:00[0m00:01[0m0:01[0m
Installing collected packages: tiktoken
Successfully installed tiktoken-0.9.0
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


# Dataset and DataLoader 

In [14]:
import torch
from torch.utils.data import Dataset, DataLoader
import tiktoken
import json
from torch.nn.utils.rnn import pad_sequence


def generate_prompt(sample):
    # return f"<user> {sample['instruction']} <bot> {sample['output']}"
        return f"### Instruction:\n{sample['instruction']}\n\n### Response:\n{sample['output']} <|endoftext|>"
    


class Dataset_V1(Dataset):
    def __init__(self, data, tokenizer, max_length, stride):
        self.max_length = max_length
        self.input_ids = []
        self.target_ids = []
        self.tokenizer =  tokenizer

        all_tokens = []
        allowed = {'<|endoftext|>'}
        for sample in data:
            prompt = generate_prompt(sample)
            tokens = tokenizer.encode(prompt , allowed_special=allowed)
            all_tokens.extend(tokens)

        for i in range(0, len(all_tokens) - max_length, stride):
            input_chunk = all_tokens[i: i + max_length]
            target_chunk = all_tokens[i + 1: i + max_length + 1]
            self.input_ids.append(torch.tensor(input_chunk))
            self.target_ids.append(torch.tensor(target_chunk))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]
def collate_fn(batch):
    inputs, targets = zip(*batch)
    inputs = pad_sequence(inputs, batch_first=True, padding_value=0)
    targets = pad_sequence(targets, batch_first=True, padding_value=-100)  # -100 is ignored by CrossEntropyLoss
    return inputs, targets

def create_dataloader_v1(data, batch_size=4,
    max_length=256, stride=128, shuffle=True, drop_last=True):
    tokenizer = tiktoken.get_encoding("gpt2") #tokenizer 
    dataset = Dataset_V1(data, tokenizer, max_length, stride) #B
    dataloader = DataLoader(
    dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last , collate_fn=collate_fn)
    return dataloader



# Train Script 

In [15]:

from tqdm.auto import tqdm
from transformers import get_cosine_schedule_with_warmup


def evaluate_model(model, train_dataloader, eval_dataloader, device, eval_iter):
    model.eval()
    with torch.no_grad():
        train_loss = calc_loss_loader(train_dataloader, model, device, num_batches=eval_iter)
        val_loss = calc_loss_loader(eval_dataloader, model, device, num_batches=eval_iter)
    model.train()
    return train_loss, val_loss


def generate_and_print_sample(model, tokenizer, device, start_context):
    model.eval()

    encoded = text_to_token_ids(start_context, tokenizer).to(device)

    with torch.no_grad():
        token_ids = generate(
            model=model,
            idx=encoded,
            temperature=1.4,
            max_new_tokens=64,   # Increase generation length if needed
            context_size=126,
            top_k=25
        )
        decoded_text = token_ids_to_text(token_ids, tokenizer)

        # Trim everything before the generation
        generated_only = decoded_text[len(start_context):].strip()

        # Stop at endoftext token if present
        end_marker = "<|endoftext|>"
        if end_marker in generated_only:
            generated_only = generated_only.split(end_marker)[0].strip()

        print(f"\n[Prompt]: {start_context.strip()}\n")
        print(f"[Generated]: {generated_only}\n")

    model.train()
def save_model_checkpoint(model, optimizer, epoch, path="checkpoint_epoch_{}.pt"):
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch
    }
    torch.save(checkpoint, path.format(epoch))
def after_save_load():
    checkpoint = torch.load("checkpoint_epoch_7.pt")
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']


def train_model(
    model: nn.Module,
    train_dataloader: torch.utils.data.DataLoader,
    device: torch.device,
    eval_dataloader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    eval_freq: int,
    eval_iter: int,
    start_context: str,
    num_epochs: int = 1
):
    train_losses, val_losses, track_tokens_seen = [], [], []
    tokens_seen, global_step = 0, -1
    total_steps = len(train_dataloader) * num_epochs
    print(f"🚀 Total training steps: {total_steps}")
    # scheduler = get_cosine_schedule_with_warmup(optimizer,
    #                                         num_warmup_steps=500,
    #                                         num_training_steps=total_steps)
    for epoch in tqdm(range(num_epochs)):
        model.train()
        for inputs_batch, target_batch in train_dataloader:
            inputs_batch, target_batch = inputs_batch.to(device), target_batch.to(device)

            optimizer.zero_grad()
            loss = cal_loss_batch(input_batch=inputs_batch, target_batch=target_batch, device=device, model=model)
            loss.backward()
            optimizer.step()
            # scheduler.step()

            tokens_seen += inputs_batch.numel()
            global_step += 1

            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model, train_dataloader, eval_dataloader, device, eval_iter
                )

                train_losses.append(train_loss)
                val_losses.append(val_loss)
                track_tokens_seen.append(tokens_seen)

                print(
                    f"Epoch: {epoch+1} (step {global_step:06d}):",
                    f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}"
                )

        generate_and_print_sample(
            model, train_dataloader.dataset.tokenizer, device, start_context
        )
        save_model_checkpoint(model , optimizer , epoch+1)


    return train_losses, val_losses, track_tokens_seen


# GPT Config 

In [16]:
torch.manual_seed(123)
GPT_CONFIG_124M = {
"vocab_size": 50257, # Vocabulary size
"context_length": 126,
# Context length
"emb_dim": 768,
# Embedding dimension
"n_heads": 12,
# Number of attention heads
"n_layers": 12,
# Number of layers
"drop_rate": 0.1,
# Dropout rate
"qkv_bias": False
# Query-Key-Value bias
}


In [17]:
model = GPTMQModel2(GPT_CONFIG_124M)


device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
optimizer  =  torch.optim.AdamW(model.parameters() , lr = 0.0004  , weight_decay= 0.01)
num_epochs = 1
train_ratio = 0.90
filename =  '/kaggle/input/alphaco/alpaca_data_cleaned.json'

with open(filename , 'r') as f:
    text_data = json.load(f)

# text_data = text_data[:50]
print(len(text_data))
split = int(train_ratio * len(text_data))
print(split)
train_data= text_data[:split]
val_data = text_data[split:]
# train_dataloader = create_dataloader_v1(txt= train_data , batch_size= 2 , max_length=GPT_CONFIG_124M['context_length'] , shuffle =  True , drop_last=True , stride=GPT_CONFIG_124M['context_length'])
# val_dataloader = create_dataloader_v1(txt= val_data , batch_size= 2 , max_length=GPT_CONFIG_124M['context_length'] , shuffle =  False , drop_last=False , stride=GPT_CONFIG_124M['context_length'])
train_dataloader = create_dataloader_v1(data= train_data , batch_size= 2 , max_length=GPT_CONFIG_124M['context_length'] , shuffle =  True , drop_last=True )
val_dataloader = create_dataloader_v1(data= val_data , batch_size= 2 , max_length=GPT_CONFIG_124M['context_length'] , shuffle =  False , drop_last=False )
start_context = '### Instruction :Give three tips for staying healthy ### Response:'
print('start trainning')
train_losses , val_losses  , token_seen = train_model(
    model= model , train_dataloader= train_dataloader , 
    eval_dataloader= val_dataloader , optimizer= optimizer , eval_freq=5 , device= device,
    eval_iter=3 , start_context=start_context, num_epochs=10
)






51760
46584
start trainning
🚀 Total training steps: 311980


  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 1 (step 000000): Train Loss: 248.4611, Val Loss: 248.1597
Epoch: 1 (step 000005): Train Loss: 221.7021, Val Loss: 213.9365
Epoch: 1 (step 000010): Train Loss: 81.3126, Val Loss: 73.2363


  0%|          | 0/10 [00:09<?, ?it/s]


KeyboardInterrupt: 

In [None]:
def train_model_restart(
    model: nn.Module,
    train_dataloader: torch.utils.data.DataLoader,
    device: torch.device,
    eval_dataloader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    eval_freq: int,
    eval_iter: int,
    start_context: str,
    num_epochs: int = 1,
    checkpoint_path: str = None
):
    train_losses, val_losses, track_tokens_seen = [], [], []
    tokens_seen, global_step, start_epoch = 0, -1, 0

    total_steps = len(train_dataloader) * num_epochs
    print(f"🚀 Total training steps: {total_steps}")

    # 🔁 Load from checkpoint if provided
    if checkpoint_path is not None:
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        global_step = checkpoint.get('global_step', -1)
        tokens_seen = checkpoint.get('tokens_seen', 0)

        # ⬇️ Reduce learning rate by half when resuming
        for param_group in optimizer.param_groups:
            old_lr = param_group['lr']
            param_group['lr'] = old_lr * 0.5
            print(f"🔧 Reduced LR: {old_lr:.6f} ➜ {param_group['lr']:.6f}")

        print(f"✅ Resuming training from Epoch {start_epoch}")

    # ⚙️ Reinitialize scheduler after changing LR
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=500,
        num_training_steps=total_steps
    )

    for epoch in tqdm(range(start_epoch, num_epochs)):
        model.train()
        for inputs_batch, target_batch in train_dataloader:
            inputs_batch, target_batch = inputs_batch.to(device), target_batch.to(device)

            optimizer.zero_grad()
            loss = cal_loss_batch(input_batch=inputs_batch, target_batch=target_batch, device=device, model=model)
            loss.backward()
            optimizer.step()
            scheduler.step()

            tokens_seen += inputs_batch.numel()
            global_step += 1

            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model, train_dataloader, eval_dataloader, device, eval_iter
                )

                train_losses.append(train_loss)
                val_losses.append(val_loss)
                track_tokens_seen.append(tokens_seen)

                print(
                    f"Epoch: {epoch+1} (step {global_step:06d}):",
                    f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}"
                )

        generate_and_print_sample(
            model, train_dataloader.dataset.tokenizer, device, start_context
        )

        save_model_checkpoint(model, optimizer, epoch + 1, global_step, tokens_seen)

    return train_losses, val_losses, track_tokens_seen

def save_model_checkpoint(model, optimizer, epoch, global_step=None, tokens_seen=None, path="checkpoint_epoch.pt"):
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }
    if global_step is not None:
        checkpoint['global_step'] = global_step
    if tokens_seen is not None:
        checkpoint['tokens_seen'] = tokens_seen

    torch.save(checkpoint, f"/kaggle/working/checkpoint_epoch_{epoch}.pt")
    print(f"💾 Saved checkpoint at epoch {epoch}")


In [None]:
        checkpoint = torch.load('/kaggle/working/checkpoint_epoch_6.pt', map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        global_step = checkpoint.get('global_step', -1)
        tokens_seen = checkpoint.get('tokens_seen', 0)

In [None]:
train_losses , val_losses  , token_seen = train_model(
    model= model , train_dataloader= train_dataloader , 
    eval_dataloader= val_dataloader , optimizer= optimizer , eval_freq=5 , device= device,
    eval_iter=3 , start_context=start_context, num_epochs=2
)






In [None]:
import json

loss_history = {
    "train_loss": train_losses,
    "val_loss": val_losses,
    "tokens_seen": token_seen
}

with open("loss_history.json", "w") as f:
    json.dump(loss_history, f)


In [None]:
import matplotlib.pyplot as plt

# If you loaded from a JSON file
# with open("loss_history.json", "r") as f:
#     data = json.load(f)
#     train_losses = data["train_loss"]
#     val_losses = data["val_loss"]

plt.figure(figsize=(10, 6))
plt.plot(train_losses, label="Train Loss", color="blue")
plt.plot(val_losses, label="Validation Loss", color="orange")
plt.xlabel("Evaluation Step")
plt.ylabel("Loss")
plt.title("Training vs Validation Loss")
plt.legend()
plt.grid(True)
plt.savefig("loss_curve.png")  # Save the plot
plt.show()


In [None]:
import torch
import tiktoken

# Load model
model = GPTMQModel2(GPT_CONFIG_124M)
# model.load_state_dict(torch.load("/kaggle/working/checkpoint_epoch_7.pt"))
checkpoint = torch.load("checkpoint_epoch_7.pt")
model.load_state_dict(checkpoint["model_state_dict"])

model.eval().to(device)

# Tokenizer
tokenizer = tiktoken.get_encoding("gpt2")

# Utility functions
def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
    encoded = torch.tensor(encoded).unsqueeze(0)
    return encoded

def token_ids_to_text(tokens, tokenizer):
    flat = tokens.squeeze(0)
    decode = tokenizer.decode(flat.tolist())
    return decode

# Sampling-based generate function (uses your logic)
def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None):
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]
        seq_len = idx_cond.size(1)
        causal_mask = torch.tril(torch.ones(seq_len, seq_len)).to(idx.device)
        causal_mask = causal_mask.unsqueeze(0)

        with torch.no_grad():
            logits = model(idx_cond, mask=causal_mask)

        logits = logits[:, -1, :]
        if top_k is not None:
            top_logits, _ = torch.topk(logits, top_k)
            min_val = top_logits[:, -1]
            logits = torch.where(
                logits < min_val,
                torch.tensor(float('-inf')).to(logits.device),
                logits
            )
        if temperature > 0.0:
            logits = logits / temperature
            probs = torch.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
        else:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)

        idx = torch.cat((idx, idx_next), dim=1)

    return idx

# High-level text generation function
def generate_response(prompt, model, tokenizer, max_new_tokens=100, context_size=128, temperature=1.0, top_k=50):
    input_ids = text_to_token_ids(prompt, tokenizer).to(device)
    generated_ids = generate(
        model=model,
        idx=input_ids,
        max_new_tokens=max_new_tokens,
        context_size=context_size,
        temperature=temperature,
        top_k=top_k
    )
    return token_ids_to_text(generated_ids, tokenizer)

# Try it out
# prompt = "### Instruction:\nExplain what is deep learning.\n\n### Response:\n <bot>"
prompt = """

'### Instruction :Give three tips for staying healthy ### Response:'
""".strip()


output = generate_response(prompt, model, tokenizer)
print(output)


In [None]:
def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
    encoded = torch.tensor(encoded).unsqueeze(0)
    return encoded
end_token_id = tokenizer.encode("<|endoftext|>", allowed_special={'<|endoftext|>'})[0]

def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None):
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]
        seq_len = idx_cond.size(1)
        causal_mask = torch.tril(torch.ones(seq_len, seq_len)).to(idx.device)
        causal_mask = causal_mask.unsqueeze(0)

        with torch.no_grad():
            logits = model(idx_cond, mask=causal_mask)

        logits = logits[:, -1, :]
        if top_k is not None:
            top_logits, _ = torch.topk(logits, top_k)
            min_val = top_logits[:, -1]
            logits = torch.where(
                logits < min_val,
                torch.tensor(float('-inf')).to(logits.device),
                logits
            )

        if temperature > 0.0:
            logits = logits / temperature
            probs = torch.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
        else:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)

        idx = torch.cat((idx, idx_next), dim=1)

        # Stop generation if <|endoftext|> is generated
        if idx_next.item() == end_token_id:
            break

    return idx
def truncate_after_n_bullets(text, n=3):
    lines = text.split("\n")
    count = 0
    result = []
    for line in lines:
        if line.strip().startswith(("1.", "2.", "3.")):
            count += 1
        result.append(line)
        if count >= n:
            break
    return "\n".join(result)
raw_output = generate_response(prompt, model, tokenizer)
cleaned_output = truncate_after_n_bullets(raw_output)
print(cleaned_output)



In [None]:
output = generate_response(
    prompt, model, tokenizer,
    temperature=0.8,  # better balance
    top_k=40,         # a bit narrower selection
    max_new_tokens=100
)


In [None]:
def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
    encoded = torch.tensor(encoded).unsqueeze(0)
    return encoded

end_token_id = tokenizer.encode("<|endoftext|>", allowed_special={'<|endoftext|>'})[0]

def generate(model, idx, max_new_tokens, context_size, temperature=1.0, top_k=None):
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]
        seq_len = idx_cond.size(1)
        causal_mask = torch.tril(torch.ones(seq_len, seq_len)).to(idx.device).unsqueeze(0)

        with torch.no_grad():
            logits = model(idx_cond, mask=causal_mask)

        logits = logits[:, -1, :]

        if top_k is not None:
            top_logits, _ = torch.topk(logits, top_k)
            min_val = top_logits[:, -1]
            logits = torch.where(
                logits < min_val,
                torch.tensor(float('-inf')).to(logits.device),
                logits
            )

        if temperature > 0.0:
            logits = logits / temperature
            probs = torch.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
        else:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)

        idx = torch.cat((idx, idx_next), dim=1)

        # Stop generation if <|endoftext|> is in the generated output
        if end_token_id in idx_next:
            break

    return idx

def truncate_after_n_bullets(text, n=3):
    lines = text.split("\n")
    count = 0
    result = []
    for line in lines:
        if line.strip().startswith(("1.", "2.", "3.")):
            count += 1
        result.append(line)
        if count >= n:
            break
    return "\n".join(result)

# 🔁 Input prompt
prompt = "### Instruction: What are the three primary colors? \n### Response:"

# 🔁 Tokenize input
input_ids = text_to_token_ids(prompt, tokenizer).to(device)

# 🔁 Generate output tokens
output_ids = generate(
    model=model,
    idx=input_ids,
    max_new_tokens=100,
    context_size=128,
    temperature=0.7,
    top_k=40
)

# 🔁 Decode and postprocess
output_text = tokenizer.decode(output_ids[0].tolist())

# ✂️ Truncate after 3 bullets (optional)
final_output = truncate_after_n_bullets(output_text)
print(final_output)


In [None]:
### Instruction: Give three tips for staying healthy
### Response:
1. Plan your meals: Instead of staying hydrated, set your meals and try to keep your meals informed about whether you need it.
2. Set achievable goals: Identify your tasks into your meals, such as using a variety of fruits, vegetables, vegetables, eggs, or fruits, vegetables.
3. Take a break: Take a few minutes early and enjoy a lot of fresh fruits, vegetables, and vegetables.

[ ]:

