In [2]:
import math 
import torch 
import torch.nn as nn 
import transformers 
from tqdm.notebook import tqdm
# Memory Network
import torch.nn.functional as F 
from typing import Tuple , Optional
import torch.bin 


In [3]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


# RMS NORM

In [4]:
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim)) 

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        norm_x = self._norm(x.float()).type_as(x) 
        return norm_x * self.scale 


# Embedding Layer

In [5]:
import torch 

class EmbeddingLayer(torch.nn.Module):
    def __init__(self , vocab_size , embedding_dim):
        super().__init__()

        self.embedding_layer= torch.nn.Embedding(vocab_size , embedding_dim)

    def forward(self , input_tokens):
        return self.embedding_layer(input_tokens)
class InputEmbedding(nn.Module):

    def __init__(self, vocab_size: int , d_model:int):

        super().__init__()

        self.d_model  =  d_model 

        self.vocab_size = vocab_size

        self.embeddings = nn.Embedding(vocab_size , d_model)

    def forward(self ,x):

        return self.embeddings(x) * math.sqrt(self.d_model)
    


# FeedForward Layer 

In [6]:
import torch.nn as nn

class GELU(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, x):
        return 0.5 * x *(1+ torch.tanh(torch.sqrt(torch.tensor(2.0/ torch.pi)) * (x+0.044715 * torch.pow(x, 3))) )
      

class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(cfg['emb_dim'] , 4 * cfg['emb_dim']) ,
            GELU(),
            nn.Linear(4 * cfg['emb_dim'] , cfg['emb_dim'])
        )
    def forward(self, x ):
        return self.layers(x)
    


# Normalization Layer

In [7]:
import torch.nn as nn 
import torch 

class LayerNorm(nn.Module):
    def __init__(self , emb_dim):
        super().__init__()
        self.eps = 1e-5
        self.scale  = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))
    def forward(self , x):
        mean = x.mean(dim= -1, keepdim = True)
        var = x.var(dim =-1, keepdim = True)
        norm_x = (x - mean) / torch.sqrt(var +self.eps)
        return self.scale * norm_x + self.shift 


# RoPE Embdding 

In [8]:
import numpy as np 
import torch 
from dataclasses import dataclass


class NRopE: # RopE in Numpy 
    def rotate_2d(self,vec , theta_p):
        cos_theta  , sin_theta  = np.cos(theta_p) , np.sin(theta_p)
        rotat_vec = np.array([[cos_theta , -sin_theta],
                    [sin_theta ,cos_theta]])
        
        return rotat_vec @ vec


    def RoPe(self,x , p , theta = 10000):
        d = len(x)
        x_rotate =  np.zeros_like(x)
        for i in range(0 , d , 2):
            if i +1< d:
                theta_p = (theta **(-2*(i//2)))**p 
                roted_pair = self.rotate_2d(x[i:i+1] , theta_p)    
                x_rotate[i:i+1] = roted_pair

        return x_rotate



@dataclass
class TRopE(torch.nn.Module): # RopE in torch 
    def __init__(self, dim:int ,theta:float = 10000):
        self.dim = dim 
        self.theta = theta 
        self.freq =  torch.pow(self.theta ,-torch.arange(0 ,dim  , 2)/dim )
        torch.nn.Parameter('freq' , self.freq)

    def forward(self, x:torch.Tensor , pos:torch.Tensor):
        batch_size , seq_len, dim = x.shape
        assert dim ==self.dim ,"Error dim must be same"
        theta_p = torch.einsum("n,d->nd" , pos, self.freq.to(x.device))
        cos_theta  , sin_theta = torch.cos(theta_p) , torch.sin(theta_p)
        x_even , x_odd =  x[... , ::2] , x[... , 1::2]
        x_rotated =  torch.empty_like(x)
        x_rotated[...,::2] =  x_even * cos_theta - x_odd * sin_theta
        x_rotated[...,1::2] =  x_even * sin_theta + x_odd * cos_theta

        return x_rotated







def precompute_freq_cis(  dim:int , end:int , theta:float = 10000.0):
        """dim : dimentions 
        end: end index   
        """
        freqs =  1/(theta **(torch.arange(0 , dim , 2)[:dim//2].float() / dim))
        t =  torch.arange(end, device=freqs.device)
        freqs = torch.outer(t , freqs).float()
        freqs_cis =  torch.polar(torch.ones_like(freqs), freqs)
        return freqs_cis 


def reshape_for_broadcast(freq_cis  , x):
        """ reshape the freqcies to match x dimentions """
        ndim=  x.ndim
        assert 0<=1<ndim 
        assert freq_cis.shape == (x.shape[1], x.shape[-1]), f"Expected {(x.shape[1], x.shape[-1])}, got {freq_cis.shape}" 
        shape = [d if i == 1 or i ==  ndim -1 else 1 for i , d in enumerate(x.shape)]
        return freq_cis.view(*shape)


def apply_rotary_embedding( xq:torch.Tensor ,xk:torch.Tensor ,  freq_cis:torch.Tensor):

            xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1,2))

            xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1,2))


            freq_cies =  reshape_for_broadcast(freq_cis , xq_)
    

            xq_out = torch.view_as_real(xq_* freq_cies).flatten(3)
            
            xk_out = torch.view_as_real(xk_*freq_cies).flatten(3)


            return  xq_out.type_as(xq)   ,  xk_out.type_as(xq) 





# MultiHead & MultiQuery Attention Layer 

In [9]:

class MultiHeadAttention_V2(nn.Module):
    def __init__(self, d_in , d_out , context_length  , dropout ,num_heads,qkv_bias = False):
        super().__init__()
        assert d_out % num_heads  == 0,'d_out must be divisible by the num_heads'
        self.w_query = nn.Linear(d_in , d_out ,bias=qkv_bias)
        self.w_key = nn.Linear(d_in , d_out, bias=qkv_bias)
        self.w_value = nn.Linear(d_in  , d_out,bias=qkv_bias)
        self.d_in =d_in
        self.d_out = d_out
        self.dropout = nn.Dropout(dropout)
        self.num_heads  = num_heads
        self.head_dim = d_out // num_heads
        self.out_proj  = nn.Linear(d_out , d_out)
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length , context_length),diagonal=1)
        )

    def forward(self,x):
        b, num_tokens , d_in = x.shape
        keys = self.w_key(x)
        queries  = self.w_query(x)
        values = self.w_value(x)
        queries = queries.view(b, num_tokens , self.num_heads , self.head_dim)
        values = values.view(b , num_tokens , self.num_heads , self.head_dim)
        keys = keys.view( b, num_tokens , self.num_heads , self.head_dim)
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1,2)
        values = values.transpose(1,2)
        attn_scores = queries @ keys.transpose(2, 3)
        mask_bool= self.mask.bool()[:num_tokens , :num_tokens]
        attn_scores.masked_fill(mask_bool , -torch.inf)
        attn_weights = torch.softmax(attn_scores /self.head_dim**0.5   , dim=-1 )
        attn_weights = self.dropout(attn_weights)
        context_vector = (attn_weights  @ values).transpose(1, 2)
        context_vector = context_vector.contiguous().view(b , num_tokens , self.d_out)
        context_vector = self.out_proj(context_vector)
        return context_vector




def apply_rotary_embedding(xq:torch.Tensor , xk:torch.Tensor , freq_cies:torch.Tensor):

    assert xq.shape[-1] % 2 == 0 , 'Embeddig dimension must be even for complex paring'

    assert xk.shape[-1] % 2 == 0 , 'Embeddig dimension must be even for complex paring'


    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1,2))

    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1] , -1, 2))

    freq_cies = reshape_for_broadcast(freq_cies , xq_)

    xq_out = torch.view_as_real(xq_ * freq_cies ).flatten(3)

    xk_out = torch.view_as_real(xk_ * freq_cies).flatten(3)

    return xq_out.type_as(xq) ,  xk_out.type_as(xk)




class MultiQueryAttentionBlock(nn.Module):
    def __init__(self, d_model:int , h:int , dropout:float , seq_len:int , qkv_bias =  False ):
        super().__init__()
        self.d_model  = d_model 

        self.seq_len=  seq_len

        assert d_model % h == 0, "d_model is must be divided by th head"
        self.dropout = nn.Dropout(dropout)

        self.h = h  

        self.d_k = d_model // h 

        self.w_qkv =  nn.Linear(d_model , d_model +2 * self.d_k )

        self.w_o = nn.Linear(d_model  , d_model)

        freq_cies = precompute_freq_cis(dim=self.d_k , end=self.seq_len * 2 )

        self.register_buffer('freq_cies' , freq_cies , persistent= False )

    def generate_causal_mask(self, seq_len, device):
        # shape: (1, 1, seq_len, seq_len)
        return torch.tril(torch.ones((1, 1, seq_len, seq_len), device=device)).bool()

    @staticmethod
    def attention(q, k  , v,mask  , dropout):
        d_k = q.shape[-1]

        attention_score =  (q @ k.transpose(-2,-1)) / math.sqrt(d_k)

        if mask is not None :
            if mask.dim() == 2:
                      mask = mask.unsqueeze(1).unsqueeze(2)
            elif mask.dim() == 3:
                     mask = mask.unsqueeze(1)
            attention_score = attention_score.masked_fill(mask == 0, float('-inf'))
        
        attention_score = attention_score.softmax(dim=-1)

        if dropout is not None :
            attention_score = dropout(attention_score)

        context_vector =  attention_score @ v

        return context_vector  , attention_score
    


    def forward(self, q, mask= None):
        if mask is None:
            mask = self.generate_causal_mask(self.seq_len , device = q.device)
        qkv =  self.w_qkv(q)

        query , key, value =  torch.split(qkv , [self.d_model  , self.d_k , self.d_k], dim=-1)

        query = query.view(query.shape[0] , -1 , self.h , self.d_k).transpose(1, 2)

        key =  key.unsqueeze(1)

        value =  value.unsqueeze(1)

        seq_len =  q.size(1)

        freq_cies = self.freq_cies[:query.shape[1]].to(q.device)

        # freq_cies =  self.freq_cies[:seq_len].to(q.device)

        query , key = apply_rotary_embedding(query , key , freq_cies)

        x , self.attention_score = MultiQueryAttentionBlock.attention(q = query,k =  key,v= value ,mask=mask , dropout= self.dropout)

        x = x.transpose(1,2).contiguous().view(x.shape[0] , -1, self.h* self.d_k)

        x = self.w_o(x)

        return x 
    






# Memory Network 

In [27]:



class EfiBioSemanticMemory_V2(nn.Module):
    def __init__(self, input_dim:int ,semantic_memory_dim, max_slots:int = 1000 , compress_dim:int =  128 , top_k:int = 5 , num_heads:int =  4 ):
        super().__init__()

        self.input_dim = input_dim 
        self.max_slots =  max_slots 
        self.compress_dim =  compress_dim 
        self.top_k =  top_k 
        self.num_heads =  num_heads 
        self.semantic_memory_dim = semantic_memory_dim
        self.memory_size =  semantic_memory_dim 



        self.key_memory =  nn.Parameter(torch.randn(max_slots , compress_dim))
        self.value_memory = nn.Parameter(torch.randn(max_slots , compress_dim))
        self.cell_state =  nn.Parameter(torch.randn(max_slots , compress_dim))
        self.register_buffer('active_mask' , torch.zeros(max_slots , dtype= torch.bool))
        self.active_mask[:semantic_memory_dim] = True  


        # Meta data parameter 
        self.register_buffer('age', torch.zeros(max_slots))
        self.register_buffer('usage', torch.zeros(max_slots))
        self.register_buffer('concept_energy', torch.ones(max_slots))
        self.register_buffer('memory_age', torch.zeros(max_slots))
        self.register_buffer('access_count', torch.zeros(max_slots))
        self.register_buffer("_memory_version", torch.tensor(0))
        self.concept_energy[:semantic_memory_dim] =  0.2

        #stats params
        self.register_buffer("step_count", torch.zeros(1, dtype=torch.long))
        self.register_buffer("query_count", torch.zeros(1, dtype=torch.long))
        self.register_buffer("novel_count", torch.zeros(1, dtype=torch.long))
        self.register_buffer("write_count", torch.zeros(1, dtype=torch.long))
        self.register_buffer("hit_count", torch.zeros(1, dtype=torch.long)) 

        # Threshold Parameter 
        self.consolidation_threshold = nn.Parameter(torch.tensor(100.0))
        self.energy_threshold = nn.Parameter(torch.tensor(0.3))
        self.decay_rate = nn.Parameter(torch.tensor(0.999))
        # self.novelty_threshold = nn.Parameter(torch.tensor(0.2))
        self.novelty_threshold = 0.2 * (1 - (self.memory_size / self.max_slots))
        self.register_buffer("prune_age_threshold", torch.tensor(100))
        self.register_buffer("neurogenesis_threshold", torch.tensor(0.9))
        self.register_buffer("new_slot_maturation_steps", torch.tensor(50)) 
        self.synaptic_scale = nn.Parameter(torch.tensor(0.1))
        self.sparsity = nn.Parameter(torch.tensor(0.5))
        self.sim_thershold =  nn.Parameter(torch.tensor(0.4))

        
        # Networks 
        self.important_net = nn.Sequential(
            nn.Linear(semantic_memory_dim, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

        

        self.update_gate = nn.Sequential(
            nn.Linear(3 * compress_dim, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

        for layer in self.update_gate:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                nn.init.constant_(layer.bias, 0.1) 
        self.forgot_gate = nn.Sequential(
            nn.Linear(semantic_memory_dim, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 3),
            nn.Sigmoid()
        )
        self.compression = nn.Sequential(
            nn.Linear(input_dim, semantic_memory_dim),
            nn.RMSNorm(semantic_memory_dim),
            nn.GELU(),
            nn.Linear(semantic_memory_dim, self.compress_dim)
        )

        self.decompression = nn.Sequential(
            nn.Linear(self.compress_dim, input_dim),
            nn.GELU()
        )

        self.W_cell = nn.Linear(self.compress_dim, semantic_memory_dim, bias=False)
        self.memory_projection = nn.Linear(self.semantic_memory_dim, self.input_dim)

        self.attn = nn.MultiheadAttention(
            embed_dim=semantic_memory_dim,
            num_heads=num_heads,
            batch_first=False
        )
        nn.init.kaiming_uniform_(self.key_memory, mode='fan_out')
        nn.init.xavier_normal_(self.value_memory)
        nn.init.xavier_normal_(self.cell_state)

    def _get_active_memory(self):

        """
            Get the active memries slot 
        """
        idx = torch.nonzero(self.active_mask, as_tuple=False).squeeze(1)
        assert idx.numel() > 0, "No active memory slots"
       
        return  (
            self.key_memory[idx] , 
            self.value_memory[idx], 
            self.cell_state[idx]

        )
        
    @property 
    def active_capacity(self):
        return torch.sum(self.active_mask).item()/ self.max_slots
    
    def _retrive_memory(self ,query:torch.Tensor , batch_size:int , seq_len:int):

        k_active , v_active , c_active =  self._get_active_memory()

        k  = k_active.unsqueeze(1).expand(-1 , batch_size ,-1)
        v = v_active.unsqueeze(1).expand(-1, batch_size , -1)
        assert k.size(1) == batch_size

        attn_output , attn_weights = self.attn(
            query.unsqueeze(0), k ,v , need_weights =  True 
        )
        attn_output = attn_output + torch.randn_like(attn_output) * 0.1
        retrived =  attn_output.squeeze(0)
        attn_score =  F.cosine_similarity(query.unsqueeze(1), k_active.unsqueeze(0), dim=-1)

        active_indices =  torch.where(self.active_mask)[0]
        topk_values , topk_idx =  attn_score.topk(
            min(self.top_k , len(active_indices)) , dim=-1
        )

        topk_idx =  active_indices[topk_idx]

        # Project the output to the out 
        out =  self.memory_projection(retrived)
        out = out.unsqueeze(1).repeat(1, seq_len, 1)
        return  out , retrived , topk_idx , attn_weights


    def _adaptive_decay(self, memory_idx):

 

        energy = self.concept_energy[memory_idx]
        access = self.usage[memory_idx]

        decay = torch.exp((1-self.decay_rate) * (1-energy) * (1-access))
        return decay  
    

    def _update_memory(self, topk_idx , cells:torch.Tensor , projected:torch.Tensor , update_gates):
        projected =  projected.unsqueeze(1).expand(-1,cells.size(1), -1)
        active_indices =  torch.where(self.active_mask)[0]
        new_slot_mask  = torch.isin(topk_idx, active_indices[-10:])
        bs , _ ,_ = projected.shape
        cell_input = cells + projected 
        cell_input =  torch.sigmoid(cell_input)

        decay_factor =  self._adaptive_decay(topk_idx)
        decay_factor =  decay_factor.unsqueeze(-1)
        cell_updates =  decay_factor * cell_input 
        delta = update_gates.unsqueeze(-1) * cell_updates
        cell_updates[new_slot_mask]*= 0.1
        # assert cell_updates.shape == (bs, topk_idx, self.compress_dim)
        # assert new_slot_mask.unsqueeze(-1).shape == (bs, topk_idx, 1)
        cell_updates = cell_updates * torch.where(new_slot_mask.unsqueeze(-1), 0.1, 1.0)

        batch_size, top_num , dim = delta.shape  
        flat_topk_idx =  topk_idx.view(-1)

        flat_delta =  (delta * self.synaptic_scale).view(-1, dim)
        flat_cells =  F.normalize(cells, dim=-1).view(-1, dim)

        with torch.no_grad():
            self.cell_state.data.index_put_(
                (flat_topk_idx , ),  
                flat_delta , 
                accumulate=True 
            )
            self.key_memory.data.index_add_(0, flat_topk_idx, flat_cells)
            self.value_memory.data.index_add_(0, flat_topk_idx, flat_cells)

            self.age.data += 1
            self.age.data[flat_topk_idx] = 0 

        decay =  self.decay_rate ** self.age.unsqueeze(-1)
        # In _update_memory()
        assert flat_topk_idx.max() < self.key_memory.size(0), \
            f"Index {flat_topk_idx.max()} >= {self.key_memory.size(0)}"
        with torch.no_grad():
            self.cell_state.data =  self.cell_state * decay
            self.key_memory.data = self.key_memory * decay 
            self.value_memory.data =  self.value_memory  * decay 

            self.key_memory.data[flat_topk_idx] = torch.tanh(F.normalize(self.key_memory.data[flat_topk_idx], dim=-1))
            self.value_memory.data[flat_topk_idx] = torch.tanh(F.normalize(self.value_memory.data[flat_topk_idx], dim=-1))
            self.cell_state.data[flat_topk_idx] = torch.tanh(self.cell_state.data[flat_topk_idx])


    def _get_low_energy_slots(self, candidate_indices):
        if len(candidate_indices) == 0:
            return candidate_indices 
        candidate_energy =  self.concept_energy[candidate_indices]

        sorted_indices   =  torch.argsort(candidate_energy)
        return candidate_indices[sorted_indices]
    
    @torch.no_grad()
    def _write_new_concept(self, new_concepts:torch.Tensor ):
        

        """
        Reinitialized the slots that have been deactivated  

        """
        batch_size  =  new_concepts.size(0)
        active_mask =  self.active_mask
        candidate =  torch.where(self.active_mask & (self.concept_energy < self.energy_threshold))[0]
        candidate =self._get_low_energy_slots(candidate)
        num_reuse = min(len(candidate), batch_size)

        if num_reuse > 0:
                reuse_idx = candidate[:num_reuse]
                batch_indices = torch.arange(num_reuse) 
                self.usage.data[reuse_idx] = 0.0
                self.age.data[reuse_idx] = 0.0
                self.memory_age[reuse_idx] = 0.0
                self.concept_energy.data[reuse_idx] = 0.5
                decay =  torch.sigmoid(self.concept_energy[reuse_idx])
                # Write data to memories
                old_key = self.key_memory[reuse_idx]
                old_val  = self.value_memory[reuse_idx]
                old_cell =  self.cell_state[reuse_idx]
                new_data = new_concepts[batch_indices]
                new_key =  new_concepts[:num_reuse]
                alpha =  0.5 
                decay = torch.sigmoid(self.concept_energy[reuse_idx]).unsqueeze(-1)
        
                self.key_memory.data[reuse_idx] = F.normalize(
                    0.3 * old_key + 0.7 * new_data * (1 - decay), 
                    dim=-1
                )
                self.value_memory.data[reuse_idx] = F.normalize(
                    0.3 * old_val + 0.7 * new_data * (1 - decay), 
                    dim=-1
                )
                gate_input = torch.cat([
                old_cell,
                old_key,
                new_data], dim=-1)

                update_gate = self.update_gate(gate_input)
                self.cell_state.data[reuse_idx] = F.normalize(
                update_gate * new_data + (1 - update_gate) * old_cell,
                dim=-1
            )

                
        #         self.cell_state.data[reuse_idx] = 0.5 * self.cell_state[reuse_idx] + (1-decay) * new_concepts
        remaining =  batch_size - num_reuse 
        if remaining  > 0 and self.memory_size < self.max_slots:
            add = min(remaining , self.max_slots - self.memory_size)
        

            start = self.memory_size 
            end =  start+ add 
            new_idx =  torch.arange(start, end ,device= self.key_memory.device)
            new_keys = new_concepts[num_reuse :num_reuse+add]
            assert new_idx.numel() == new_keys.size(0), (
            f"Shape mismatch: {new_idx.numel()} vs {new_keys.size(0)}"
        )
            assert new_idx.max() < self.key_memory.size(0), (
            f"Memory index {new_idx.max()} exceeds max slots {self.key_memory.size(0)}"
        )
            self.key_memory[new_idx] =  new_keys
            self.value_memory[new_idx] = new_keys.clone()
            self.cell_state.data[new_idx] = new_keys.clone()
            self.usage.data[new_idx]= 0
            self.age.data[new_idx] = 0 
            self.concept_energy.data[new_idx] = 0.5
            self.access_count.data[new_idx] = 0 
            self.active_mask.data[new_idx] =  True 


            self.memory_size = end
        assert self.memory_size <= self.max_slots
        assert torch.all(self.active_mask[:self.memory_size])
                

    
    def _update_energy_level(self):
        importance = self.important_net(self.key_memory).squeeze()
        assert torch.all(self.concept_energy >= 0)
        assert torch.all(self.concept_energy <= 1.01)  
        new_energy =  (
            self.decay_rate * self.concept_energy + 0.3 * self.usage + 0.1 * importance *(1-self.concept_energy)
                    )
        
        deactivated  = ~self.active_mask 

        valid_deactivate =  deactivated.nonzero().squeeze()
        valid_deactivate = valid_deactivate[(valid_deactivate>=0)&(valid_deactivate < self.memory_size)]
        with torch.no_grad():
            self.concept_energy.data =  torch.clamp(new_energy  , 0, 1)

            if valid_deactivate.numel() >0:
                self.key_memory.data[valid_deactivate] *=0.01
                self.value_memory.data[valid_deactivate]*=0.01
                self.cell_state.data[valid_deactivate]*=0.1

        self.active_concepts =  torch.sum(self.active_mask).clamp(min=0 , max=self.memory_size)


    def forward(self,x:torch.Tensor, training = True ):
        bs , seq_len , _ =  x.shape  
        self.step_count+=1 
        self.query_count+=bs
        compressed =  self.compression(x.mean(dim=1))
        projected =  self.W_cell(compressed)
        with torch.set_grad_enabled(training):
            out , retrived, topk_idx  , attn_w = self._retrive_memory(projected , bs, seq_len )
        max_scores , _ =  attn_w.max(dim=-1)
        max_scores = max_scores.squeeze(-1)
        hits =  (max_scores>0.5).sum()
        self.hit_count +=  hits
        self.novelty_threshold = 0.2 * (1 - self.active_capacity)
        novel_mask =  max_scores < self.novelty_threshold
        self.novel_count+=novel_mask.sum()
        key_active, _ , _ =  self._get_active_memory()
        if novel_mask.any():
            novel_projection  = projected[novel_mask]
            sim_scores  = F.cosine_similarity(novel_projection.unsqueeze(1), key_active.unsqueeze(0), dim=-1)
            is_novel = sim_scores.max(dim=-1).values < self.sim_thershold
            write_mask  =is_novel
            if write_mask.any():
                new_concepts = novel_projection[write_mask].detach()
                self.write_count += new_concepts.size(0)
                assert new_concepts.size(0) <= self.max_slots - self.memory_size, \
            "Exceeding maximum memory capacity"
                self._write_new_concept(new_concepts)

     
        with torch.no_grad():
            self.usage.data *= 0.95
            self.usage.data[topk_idx] +=0.1
            self.usage.data.clamp(0,1)
            self.usage.data = 0.9 * self.usage.data 
            self.usage.data.scatter_add_(0, topk_idx.flatten(), torch.ones_like(topk_idx, dtype=torch.float).flatten())
            self.usage.data.clamp_(max=1.0)
    # feed one more entirely new concept → should reuse slot 1
        if training:

            self._update_energy_level()
        keys= self.key_memory[topk_idx]
        value = self.value_memory[topk_idx]
        cells = self.cell_state[topk_idx]
        # self._update_memory_metadata(topk_idx)

        gate_input = torch.cat([
            keys, cells, projected.unsqueeze(1).expand(-1, self.top_k , -1)
        ], dim= -1) 

        update_gates =  self.update_gate(gate_input.view(-1, 3 *self.semantic_memory_dim))

        update_gates = update_gates.view(bs, self.top_k)

        self._update_memory(topk_idx=topk_idx, cells=cells ,projected=projected , update_gates=update_gates)
        # if training:
            # self._consolidate_important_memories()
            # self.neurogenesis()
        self._memory_version +=1 
        return out , retrived ,topk_idx ,  attn_w 
    
    @torch.no_grad()
    def _gradual_influence_increase(self):
        """
        Gradually increase the influence of newly added memory slots based on their age and access.
        """
        new_slots_mask  = (self.age <= self.new_slot_maturation_steps) & self.active_mask 
        if not torch.any(new_slots_mask):
            return  
        
        age_normalized = self.age[new_slots_mask] / self.new_slot_maturation_steps
        usage_normalized =  self.usage[new_slots_mask]

        growth_factor =torch.sigmoid((age_normalized + usage_normalized) * 3 ).unsqueeze(-1)

        self.key_memory[new_slots_mask] = F.normalize(self.key_memory[new_slots_mask] * (1 + growth_factor * 0.5),
        dim=-1)
        self.value_memory[new_slots_mask] =  F.normalize(self.value_memory[new_slots_mask]* (1+growth_factor * 0.3) ,dim=-1 )

        energy_boost = torch.clamp(0.1 * growth_factor.squeeze(), max=0.15)
        self.concept_energy.data[new_slots_mask] = torch.clamp(
        self.concept_energy[new_slots_mask] + energy_boost,
        min=0.3,
        max=0.7
    )

        self.age.data[new_slots_mask] += 1 

        
    def _consolidate_important_memories(self):
        print('Consalidate Happen')
        importance =  self.important_net(self.key_memory).squeeze()
        consolidate_mask  = importance > 0.1
        if consolidate_mask.any():
          with torch.no_grad():
            self.key_memory.data[consolidate_mask] = F.normalize(
                self.key_memory[consolidate_mask] , dim=-1
            )
            mean_value = self.semantic_value_memory[consolidate_mask].mean(dim=0)
           
            self.semantic_value_memory.data[consolidate_mask] = (
                    0.9 * self.semantic_value_memory[consolidate_mask] +
                    0.1 * mean_value
                )
            self.concept_energy.data[consolidate_mask] = torch.clamp(self.concept_energy[consolidate_mask] + 0.05, 0, 1)

    @torch.no_grad()
    def prune_memories(self):
        print("Memory Prunig Happen")
        prune_condidate =  (self.age > self.prune_age_threshold) & (self.usage < 0.01)
        if prune_condidate.any():
            self.key_memory.data[prune_condidate] *=  0.1
            self.value_memory.data[prune_condidate]*=0.01
            self.cell_state.data[prune_condidate] *= 0.01
            self.age.data[prune_condidate] = 0 
            self.usage[prune_condidate] = 0 
            self.concept_energy.data[prune_condidate] = 0.1

    @torch.no_grad()
    def _prune_slots(self):
            mask = self.age > self.prune_age_threshold
            self.key_memory.data[mask] *= 0.01
            self.value_memory.data[mask] *= 0.01
            self.cell_state.data[mask] *= 0.01
            self.concept_energy.data[mask] = 0


    def replay_consolidation(self, x: torch.Tensor):
        print('Memory Replay Happening ')
        active_key , active_value, _ = self._get_active_memory()
        if  self.training and random.random() < 0.2: 
            high_energy_mask = self.concept_energy > 0.8
            if high_energy_mask.sum() > 0:
                replay_keys = active_key[high_energy_mask]
                replay_values = active_value[high_energy_mask]
                
                replay_input = self.decompression(replay_values.mean(dim=0, keepdim=True))
                return replay_input
        return x

    def get_reusable_slots(self ,num_needed:int):
       

        age_score =  1-torch.sigmoid(self.age / 100) # old age 
        energy_score = (1-  self.concept_energy ) *2 
        usage_score = 1 - self.usage 
        reuse_scores = (0.4 * energy_score  + 0.3 * age_score + 0.3 * usage_score 
                       )

        mask =(self.concept_energy  < self.energy_threshold) & (self.age< 100)
        reuse_scores[~mask]= -float('inf')

        topk_scores  , candidates = torch.topk(reuse_scores, min(num_needed, self.memory_size))
        return candidates 

            
    def _reinitialize_slot(self, idx):
        """Reset a slot to initial state"""
        print("Reinitialize Slots ")
        with torch.no_grad():
            scale = 0.1 + 0.05 * torch.rand(1, device=idx.device)
            self.key_memory.data[idx] = torch.randn_like(self.key_memory[idx]) * scale
            self.value_memory.data[idx] = torch.randn_like(self.value_memory[idx]) * scale
            self.cell_state.data[idx] =  0.2 * self.cell_state.data[idx].mean(dim=0)
            
            # Reset metadata
            self.concept_energy.data[idx] = 0.3 + 0.2 * torch.rand_like(self.concept_energy[idx])
            self.usage.data[idx] =  0.1 * torch.rand_like(self.usage[idx])
            self.age.data[idx] = 0
            self.memory_age.data[idx] = 0
            self.access_count.data[idx] = 0
            self.active_mask.data[idx] = True 





    def _consalidate_new_slots(self):
        new_slot_indices =  torch.arange(self.memory_size - 10 , self.memory_size )
        new_slot_energy = self.concept_energy[new_slot_indices]
        with torch.no_grad():
            self.concept_energy.data[new_slot_indices] = torch.clamp(
                new_slot_energy + 0.1 * self.age[new_slot_indices] , 0 ,1
            )
            self.key_memory.data[new_slot_indices] *=  0.1
            self.value_memory.data[new_slot_indices] *= 0.1
            self.age.data[new_slot_indices] +=1
            self.access_count[new_slot_indices] +=1 

    def _update_memory_metadata(self, used_indices):
       
        self.access_count[used_indices] += 1
        
        self.memory_age += 1
        self.memory_age[used_indices] = 0
        with torch.no_grad():
            self.concept_energy.data[used_indices] += 0.1
            self.concept_energy.data= torch.clamp(self.concept_energy * 0.95, 0, 1)
    @torch.no_grad()
    def neurogenesis(self):
        print("Neurogenesis happen")
        if self.max_slots > self.memory_size:
            usage_rate =  (self.usage > 0.1).float().mean()
            if usage_rate > self.neurogenesis_threshold:
                reusable =  self.get_reusable_slots()
                num_reuse =  reusable.numel()
                if num_reuse > 0:
                    with torch.no_grad():
                        device =  self.key_memory.device 
                        self._reinitialize_slot(idx=reusable)
                        
                new_slots =  min(10 -  num_reuse , self.max_slots- self.memory_size)
                if new_slots > 0:
                    start_idx =self.memory_size
                    end_idx =  start_idx  + new_slots 
                    new_indices  = torch.arange(start_idx , end_idx, device = device )
                    self.key_memory.data[new_indices] = torch.randn(new_slots, self.compress_dim, device=device) * 0.1
                    self.value_memory.data[new_indices] = torch.randn(new_slots, self.compress_dim, device=device) * 0.1
                    self.cell_state.data[new_indices] = 0
                    self.concept_energy.data[new_indices] = 0.5
                    self.usage.data[new_indices] = 0.0
                    self.age.data[new_indices] = 0.0
                    self.access_count.data[new_indices] = 0.0
                    self.active_mask.data[new_indices] = True
                    self.memory_size += new_slots

                self._gradual_influence_increase()
                    


    @torch.no_grad()
    def _merge_similar_slots(self):
        print("Merging the similar slots ")
        device = self.key_memory.device
        active_idx = self.active_mask.nonzero(as_tuple=True)[0]
        if active_idx.numel() < 2:
            return

        act_keys = F.normalize(self.key_memory[active_idx], dim=1)  
        sim_mtx  = act_keys @ act_keys.T                         

        merge_thr = torch.clamp(0.95 - 0.01*self.memory_age[active_idx].mean(), 0.7, 0.9)

        rows, cols = torch.triu_indices(act_keys.size(0), act_keys.size(0), 1, device=device)
        mask_pairs = sim_mtx[rows, cols] > merge_thr
        if not mask_pairs.any():
            return

        # 4) union–find over those pairs
        parent = torch.arange(act_keys.size(0), device=device)
        def find(x):
            orig_x = x
            while parent[x] != x:
                parent[x] = parent[parent[x]]
                x = parent[x]
            parent[orig_x] = x  # Full path compression
            return x

        def union(u, v):
            ru, rv = find(u), find(v)
            if ru != rv:
                parent[rv] = ru

        uv = torch.stack([rows[mask_pairs], cols[mask_pairs]], dim=1)
        for u, v in uv:
            union(u.item(), v.item())

        # 5) cluster IDs, counts
        comp_ids = torch.tensor([find(i) for i in range(act_keys.size(0))], device=device)
        uniq, inv, counts = torch.unique(comp_ids, return_inverse=True, return_counts=True)

        to_merge = uniq[counts >= 2]
        if to_merge.numel() == 0:
            return

        # 7) build a mask of all slots belonging to merge‐worthy clusters
        cluster_mask = torch.isin(inv, to_merge)      
        cluster_ids  = inv[cluster_mask]           

        D = act_keys.size(1)
        C = uniq.size(0)
        sums_keys   = torch.zeros((C, D), device=device)
        sums_vals   = torch.zeros((C, D), device=device)
        sums_energy = torch.zeros((C,), device=device)
        sums_keys.scatter_add_(0,
            cluster_ids.unsqueeze(-1).expand(-1, D),
            act_keys[cluster_mask]
        )
        sums_vals.scatter_add_(0,
            cluster_ids.unsqueeze(-1).expand(-1, D),
            self.value_memory[active_idx][cluster_mask]
        )
        sums_energy.scatter_add_(0,
            cluster_ids,
            self.concept_energy[active_idx][cluster_mask]
        )
        counts_clamped = counts[to_merge].unsqueeze(-1).clamp(min=1).to(device)
        means_keys   = sums_keys   / counts_clamped
        means_vals   = sums_vals   / counts_clamped
        cmask =  counts <= 2
        means_energy = sums_energy / counts[cmask]

        positions = torch.arange(act_keys.size(0), device=device)
        first_pos = torch.full((C,), act_keys.size(0), device=device, dtype=torch.long)
        first_pos.scatter_reduce_(0, inv, positions, reduce="amin")
        keep_slots  = active_idx[first_pos] 
        with torch.no_grad():
            self.key_memory.data[keep_slots] = means_keys
            self.value_memory.data[keep_slots] = means_vals
            self.concept_energy.data[keep_slots] = means_energy

        all_active = torch.arange(act_keys.size(0), device=device)
        merged_mask = torch.isin(inv, to_merge) & (positions != first_pos[inv])
        drop_small = active_idx[merged_mask]
        if drop_small.numel() > 0:
            noise = torch.randn_like(self.key_memory[drop_small]) * 0.01
            with torch.no_grad():
                self.key_memory.data[drop_small]    = noise
                self.value_memory.data[drop_small]  = noise.clone()
                self.concept_energy.data[drop_small] = 0.2
                self.age.data[drop_small]           = 0
                self.usage.data[drop_small]         = 0
                self.access_count.data[drop_small]  = 0

      # Get cluster-wise metrics before merging
        cluster_usage = torch.zeros_like(means_energy)
        cluster_access = torch.zeros_like(means_energy)
        cluster_usage.scatter_add_(0, cluster_ids, self.usage[active_idx][cluster_mask])
        cluster_access.scatter_add_(0, cluster_ids, self.access_count[active_idx][cluster_mask])
        with torch.no_grad():
            self.usage.data[keep_slots] = cluster_usage / counts[counts >= 2]
            self.access_count.data[keep_slots] = cluster_access
            self.usage.data[drop_small]        = 0
            self.access_count.data[drop_small] = 0

        self._update_memory_metadata(keep_slots)
        
    def get_memory_metrics(self):

        "Return memory health and retivel param details"
        active_mask  =  self.concept_energy > self.energy_threshold
        energy  =  self.concept_energy 
        usage =  self.usage 
        access =  self.access_count

        age_hist   = torch.histc(self.memory_age.float(), bins=10, min=0, max=float(self.memory_age.max()))
        usage_hist = torch.histc(usage, bins=10, min=0, max=1.0)
        access_hist= torch.histc(access.float(), bins=10, min=0, max=float(access.max()))

        return  {
            'memory_size': self.memory_size  , 
            'active_concepts':active_mask.sum().item(),
            'utilization':active_mask.sum().item() / self.memory_size , 
            'energy_mean':energy.mean().item(), 
            'energy_std':energy.std().item(), 
            'age_histogram': age_hist, 
            'usage_histogram':usage_hist, 
            'access_histogram':access_hist, 
            'merge_rate':(energy < 0.3).sum().item() / self.memory_size ,
            "prune_rate":((self.age > self.prune_age_threshold) & (usage < 0.01)).float().mean().item(),            
            'neuro_rate':(self.memory_age < 10).float().mean().item(), 
            "reuse_efficiency":      access[energy > 0.5].float().mean().item(),

             # —— retrieval/write stats ———————————————————————
            "steps":                 self.step_count.item(),
            "queries":               self.query_count.item(),
            "novelty_rate":          self.novel_count.item() / max(1, self.query_count.item()),
            "write_rate":            self.write_count.item() / max(1, self.query_count.item()),
            "hit_rate":              self.hit_count.item() / max(1, self.query_count.item()),
            
        }


        

# Transformer Block

In [28]:



class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = MultiHeadAttention_V2(
        d_in=cfg["emb_dim"],
        d_out=cfg["emb_dim"],
        context_length=cfg["context_length"],
        num_heads=cfg["n_heads"],
        dropout=cfg["drop_rate"],
        qkv_bias=cfg["qkv_bias"])
        self.ff = FeedForward(cfg)
        self.norm1 = LayerNorm(cfg["emb_dim"])
        self.norm2 = LayerNorm(cfg["emb_dim"])
        self.drop_resid = nn.Dropout(cfg["drop_rate"])
    def forward(self, x):
    #A
        shortcut = x
        x = self.norm1(x)
        x = self.att(x)
        x = self.drop_resid(x)
        x = x + shortcut # Add the original input back
        shortcut = x #B
        x = self.norm2(x)
        x = self.ff(x)
        x = self.drop_resid(x)
        x = x + shortcut #C
        return x



class TransformerBlock_v2(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.attention =  MultiQueryAttentionBlock(d_model=cfg['emb_dim'], h=cfg['n_heads'] , dropout=cfg['drop_rate'], seq_len=  cfg['context_length'] ,qkv_bias=cfg['qkv_bias'])

        self.feed_forward = FeedForward(cfg)

        self.layernorm1 =  LayerNorm(cfg['emb_dim'])
    
        self.layernorm2 =  LayerNorm(cfg['emb_dim'])

        self.drop_out = nn.Dropout(cfg['drop_rate'])

    def forward(self, x , mask= None):

        attention_output =  self.attention(self.layernorm1(x) , mask =  mask)

        ff_output =  self.feed_forward(self.layernorm2(x))

        return x + self.drop_out(ff_output) + self.drop_out(attention_output)
class TransformerBlockWithMemory(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.attention = MultiQueryAttentionBlock(
            d_model=cfg['emb_dim'],
            h=cfg['n_heads'],
            dropout=cfg['drop_rate'],
            seq_len=cfg['context_length'],
            qkv_bias=cfg['qkv_bias']
        )
        self.memory = MemorySystem(cfg=cfg)
        self.feed_forward = FeedForward(cfg=cfg)
        
        self.norm1 = LayerNorm(cfg['emb_dim'])
        self.norm2 = LayerNorm(cfg['emb_dim'])
        self.norm3 = LayerNorm(cfg['emb_dim'])
        
        # Memory gate
        self.memory_gate = nn.Sequential(
            nn.Linear(cfg['emb_dim'] * 2, cfg['emb_dim']),
            nn.Sigmoid()
        )
        
        self.dropout = nn.Dropout(cfg['drop_rate'])

    def forward(self, x, mask=None):
        attn_out = self.attention(self.norm1(x), mask=mask)
        x = x + self.dropout(attn_out)
        
        norm_x = self.norm2(x)
        epic_out , semantic_out , memory_out = self.memory(norm_x)
        
        gate_input = torch.cat([norm_x, memory_out], dim=-1)
        memory_gate = self.memory_gate(gate_input)
        x = x + memory_gate * memory_out
        
        ff_out = self.feed_forward(self.norm3(x))
        x = x + self.dropout(ff_out)
        
        return x
class TransformerBlockWithMemory(nn.Module):
    def __init__(self, cfg, shared_memory=None):
        super().__init__()
        # Core components
        self.attention = MultiQueryAttentionBlock(
            d_model=cfg['emb_dim'],
            h=cfg['n_heads'],
            dropout=cfg['drop_rate'],
            seq_len=cfg['context_length'],
            qkv_bias=cfg['qkv_bias']
        )
        self.ffn = FeedForward(cfg=cfg)
        
        # Memory system (shared across blocks)
        self.memory = shared_memory or MemorySystem(cfg=cfg)
        
        # Normalization layers
        self.pre_ln_attn = RMSNorm(cfg['emb_dim'])
        self.pre_ln_mem = RMSNorm(cfg['emb_dim'])
        self.pre_ln_ffn = RMSNorm(cfg['emb_dim'])
        
        # Adaptive memory gating
        self.memory_gate = nn.Sequential(
            nn.Linear(cfg['emb_dim'], 1),
            nn.Sigmoid()
        )
        
        # Memory residual weights
        self.mem_alpha = nn.Parameter(torch.tensor(0.5))
        self.dropout = nn.Dropout(cfg['drop_rate'])

    def forward(self, x, mask=None):
        # Attention phase
        resid = x
        x = self.pre_ln_attn(x)
        x = resid + self.dropout(self.attention(x, mask=mask))
        
        # Memory phase
        resid_mem = x
        x_mem = self.pre_ln_mem(x)
        print('x shape ', x.shape)
        _, _, memory_out = self.memory(x_mem)
        
        # Adaptive gating
        gate = self.memory_gate(x_mem)
        x = resid_mem + self.mem_alpha * gate * memory_out
        
        # FFN phase
        resid_ffn = x
        x = self.pre_ln_ffn(x)
        x = resid_ffn + self.dropout(self.ffn(x))
        
        return x




# GPTQModel

In [29]:
import torch.nn.functional as F


class InputEmbedding(nn.Module):

    def __init__(self, vocab_size: int , d_model:int):

        super().__init__()

        self.d_model  =  d_model 

        self.vocab_size = vocab_size

        self.embeddings = nn.Embedding(vocab_size , d_model)

    def forward(self ,x):

        return self.embeddings(x) * math.sqrt(self.d_model)
    



class ProjectionLayer(nn.Module):
    def __init__(self, d_model: int, vocab_size: int, embdding_layer: nn.Embedding):
        super().__init__()
        self.weight = embdding_layer.weight  # share weights with input embedding
        self.bias = nn.Parameter(torch.zeros(vocab_size))  # learnable bias

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)




class GPTMQModel2(nn.Module):
    def __init__(self, cfg):
        super().__init__()

        self.embedding = InputEmbedding(cfg['vocab_size'], cfg['emb_dim'])

        # Use ModuleList instead of Sequential
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock_v2(cfg=cfg) for _ in range(cfg['n_layers'])
        ])

        self.drop_out = nn.Dropout(cfg['drop_rate'])
        # self.final_norm = LayerNorm(emb_dim=cfg['emb_dim'])
        self.final_norm =  RMSNorm(dim = cfg['emb_dim'])

        self.projection = ProjectionLayer(cfg['emb_dim'], cfg['vocab_size'], self.embedding.embeddings)

    def forward(self, input_tokens, mask=None):
        x = self.embedding(input_tokens)

        for block in self.transformer_blocks:
            x = block(x, mask=mask)  # Pass the mask explicitly to each block

        x = self.final_norm(x)
        logits = self.projection(x)

        return logits
class GPTMQMemoryModel1(nn.Module):
    def __init__(self, cfg):
        super().__init__()

        self.embedding = InputEmbedding(cfg['vocab_size'], cfg['emb_dim'])

        # Use ModuleList instead of Sequential
        self.transformer_blocks =  nn.ModuleList([
            TransformerBlockWithMemory(cfg=cfg) for _ in range(cfg['n_layers'])
        ])

        self.drop_out = nn.Dropout(cfg['drop_rate'])
        # self.final_norm = LayerNorm(emb_dim=cfg['emb_dim'])
        self.final_norm =  RMSNorm(dim=cfg['emb_dim'])

        self.projection = ProjectionLayer(cfg['emb_dim'], cfg['vocab_size'], self.embedding.embeddings)

    def forward(self, input_tokens, mask=None):
        x = self.embedding(input_tokens)

        for block in self.transformer_blocks:
            x = block(x, mask=mask)  # Pass the mask explicitly to each block

        x = self.final_norm(x)
        logits = self.projection(x)

        return logits
class GPTMQMemoryModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.embedding = InputEmbedding(cfg['vocab_size'], cfg['emb_dim'])
        
        # Shared memory system across layers
        self.shared_memory = MemorySystem(cfg=cfg)
        
        # Transformer blocks with shared memory
        self.transformer_blocks = nn.ModuleList([
            TransformerBlockWithMemory(
                cfg=cfg,
                shared_memory=self.shared_memory if cfg['share_memory'] else None
            ) for _ in range(cfg['n_layers'])
        ])
        
        # Final projections
        self.final_norm = RMSNorm(dim=cfg['emb_dim'])
        self.projection = ProjectionLayer(
            cfg['emb_dim'], 
            cfg['vocab_size'], 
            self.embedding.embeddings
        )
        self.memory_retention_alpha = nn.Parameter(torch.tensor(0.9))

        # Memory loss coefficient
        self.mem_loss_coef = cfg.get('mem_loss_coef', 0.3)

    def forward(self, input_tokens, mask=None):
        x = self.embedding(input_tokens)
        
        for block in self.transformer_blocks:
            x = block(x, mask=mask)
            x = self.memory_retention_alpha * x + (1 - self.memory_retention_alpha) * x.detach()
            
        x = self.final_norm(x)
        logits = self.projection(x)
        
        return logits
    
    def get_memory_loss(self):
        """Get combined memory regularization loss"""
        return self.mem_loss_coef * self.shared_memory.memory_loss()
    
    def transformer_parameters(self):
        return [p for n, p in self.named_parameters() if 'transformer_blocks' in n and p.requires_grad]
    
    def memory_parameters(self):
        return [p for n, p in self.named_parameters() if 'memory_modules' in n and p.requires_grad]
    
    def embedding_parameters(self):
        return [p for n, p in self.named_parameters() if 'embedding' in n and p.requires_grad]
    
    def norm_parameters(self):
        return [p for n, p in self.named_parameters() if 'normalization' in n and p.requires_grad]
    
    def output_parameters(self):
        return [p for n, p in self.named_parameters() if 'output_projection' in n and p.requires_grad]



In [30]:

class GPTMemoryEnhanced(nn.Module):
    def __init__(self,cfg):
        super().__init__()
        self.embedding =  InputEmbedding(cfg['vocab_size'] , cfg['emb_dim'])
        self.memory_proj = nn.Linear(cfg['emb_dim'], cfg['memory_dim'])
        self.memory_expander = nn.Linear(cfg['memory_dim'], cfg['emb_dim'])
        self.transformer_block = nn.ModuleList([
            TransformerBlock_v2(cfg=cfg) for _ in range(cfg['n_layers'])
        ])
        self.dropout =  nn.Dropout(cfg['drop_rate'])
        self.memory =  EfiBioSemanticMemory_V2(input_dim=cfg['memory_dim'] ,semantic_memory_dim=cfg['memory_dim'],num_heads=2)

        self.final_norm = RMSNorm(dim=cfg['emb_dim'])

        self.projection =  ProjectionLayer(cfg['emb_dim'], cfg['vocab_size'] , self.embedding.embeddings)

    def forward(self,input_tokens:torch.Tensor , mask =None):
        x =  self.embedding(input_tokens)
        for block in self.transformer_block:
            x  = block(x, mask = mask)
        memory_query = self.memory_proj(x) 
        mem_out, retrieved, topk_idx, attn_w = self.memory(memory_query)
        # memory_out, _ ,_ =  self.memory(x.las_hidden_state.mean(1))
        mem_out = self.memory_expander(mem_out)
        fused  = x + mem_out

        fused =  self.final_norm(fused)

        logits =  self.projection(fused)
        # return logits, {
        #     "memory_topk": topk_idx, 
        #     "memory_attn": attn_w,
        #     "retrieved":  retrieved
        # }
        return  logits
    


# Loss Functions

In [31]:




def cal_loss_batch(input_batch , target_batch , model:torch.nn.Module , device:torch.device):
    input_batch , target_batch = input_batch.to(device) , target_batch.to(device)
    logits = model(input_batch)
    loss = torch.nn.functional.cross_entropy(   logits.flatten(0,1), target_batch.flatten())
    return loss


def calc_loss_loader(data_loader , model , device , num_batches = None):
    total_loss = 0
    if num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches  = min(num_batches , len(data_loader))
    for i , (inputs , target) in enumerate(data_loader):
        if i < num_batches:
            loss  =  cal_loss_batch(inputs , target , model , device)

            total_loss +=loss.item()

        else:
            break

        return total_loss  / num_batches
    


# Text Generation Function

In [32]:
import torch
import torch.nn as nn 




def text_to_token_ids(text,  tokenizer):
    encoded = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
    encoded = torch.tensor(encoded).unsqueeze(0)
    return encoded


def token_ids_to_text(tokens , tokenizer):
    flat  = tokens.squeeze(0)
    decode = tokenizer.decode(flat.tolist())
    return decode

    
def generate_and_sample(model  , idx , context_size ,max_new_tokens ):
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]
        with torch.no_grad():
            logits = model(idx_cond)
            print(logits.shape)
        logits  = logits[:, -1  , :]
        print(logits.shape)
        probs  = torch.softmax(logits  , dim=-1)
        print(probs)
        idx_next = torch.argmax(probs, dim=-1 , keepdim= True)
        print(idx_next)
        idx = torch.cat((idx, idx_next), dim=1)
    return idx 

#
def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None):
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]  # shape: [1, current_seq_len]

        # Create causal mask dynamically
        seq_len = idx_cond.size(1)
        causal_mask = torch.tril(torch.ones(seq_len, seq_len)).to(idx.device)
        causal_mask = causal_mask.unsqueeze(0)  # [1, seq_len, seq_len]

        with torch.no_grad():
            logits = model(idx_cond, mask=causal_mask)  # <--- pass mask here

        logits = logits[:, -1, :]  # only take the last token logits

        # Apply top-k sampling if needed
        if top_k is not None:
            top_logits, _ = torch.topk(logits, top_k)
            min_val = top_logits[:, -1]
            logits = torch.where(
                logits < min_val,
                torch.tensor(float('-inf')).to(logits.device),
                logits
            )

        # Temperature sampling
        if temperature > 0.0:
            logits = logits / temperature
            probs = torch.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
        else:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)

        idx = torch.cat((idx, idx_next), dim=1)
        

    return idx
def real_time_generation(model, initial_input, context_size, temperature, top_k=None, device="cpu"):
    # Tokenize the initial input and prepare the model context
    idx = torch.tensor(initial_input).unsqueeze(0).to(device)  # Assuming initial_input is tokenized
    
    print("Starting real-time generation...")
    
    # Start generating tokens in real-time
    for new_token in generate(model, idx, max_new_tokens=50, context_size=context_size, temperature=temperature, top_k=top_k, device=device):
        print(f"Generated token: {new_token.item()}")  # Or decode it back to a word
        
        # You can check for user input here and update idx with the new input
        # For instance, wait for the user to input a prompt to append to the context
        user_input = input("Enter new input (or press enter to continue generation): ")
        
        if user_input:
            # Tokenize the new user input and append it to the context
            user_input_tokens = torch.tensor(tokenize(user_input)).unsqueeze(0).to(device)
            idx = torch.cat((idx, user_input_tokens), dim=1)  # Append the new tokens to the context
        else:
            # Continue generating if no new user input
            continue

# Function to tokenize input (adjust depending on your tokenizer)
def tokenize(text):
    # Assuming you have a tokenizer function available
    return [ord(c) for c in text]  # Dummy example: ord() converts char to token id



# Dataset and DataLoader 

In [16]:
import torch
from torch.utils.data import Dataset, DataLoader
import tiktoken
import json
from torch.nn.utils.rnn import pad_sequence


def generate_prompt(sample):
    # return f"<user> {sample['instruction']} <bot> {sample['output']}"
        return f"### Instruction:\n{sample['instruction']}\n\n### Response:\n{sample['output']} <|endoftext|>"
    


class Dataset_V1(Dataset):
    def __init__(self, data, tokenizer, max_length, stride):
        self.max_length = max_length
        self.input_ids = []
        self.target_ids = []
        self.tokenizer =  tokenizer

        all_tokens = []
        allowed = {'<|endoftext|>'}
        for sample in data:
            prompt = generate_prompt(sample)
            tokens = tokenizer.encode(prompt , allowed_special=allowed)
            all_tokens.extend(tokens)

        for i in range(0, len(all_tokens) - max_length, stride):
            input_chunk = all_tokens[i: i + max_length]
            target_chunk = all_tokens[i + 1: i + max_length + 1]
            self.input_ids.append(torch.tensor(input_chunk))
            self.target_ids.append(torch.tensor(target_chunk))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]
def collate_fn(batch):
    inputs, targets = zip(*batch)
    inputs = pad_sequence(inputs, batch_first=True, padding_value=0)
    targets = pad_sequence(targets, batch_first=True, padding_value=-100)  # -100 is ignored by CrossEntropyLoss
    return inputs, targets

def create_dataloader_v1(data, batch_size=4,
    max_length=256, stride=128, shuffle=True, drop_last=True):
    tokenizer = tiktoken.get_encoding("gpt2") #tokenizer 
    dataset = Dataset_V1(data, tokenizer, max_length, stride) #B
    dataloader = DataLoader(
    dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last , collate_fn=collate_fn)
    return dataloader


# Dataset And DataLoader for Psycology Dataset 

In [17]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import tiktoken
from torch.nn.utils.rnn import pad_sequence

class Dataset_v2(Dataset):
    def __init__(self, data, tokenizer, max_length, stride):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.stride = stride
        self.input_ids = []

        all_tokens = []
        for sample in data:
            tokens = tokenizer.encode(sample)  
            all_tokens.extend(tokens)

        # Split the tokens into chunks of size max_length with stride
        for i in range(0, len(all_tokens) - self.max_length, self.stride):
            input_chunk = all_tokens[i:i + self.max_length]
            target_chunk = all_tokens[i + 1:i + self.max_length + 1]
            self.input_ids.append((torch.tensor(input_chunk), torch.tensor(target_chunk)))

    def __len__(self):
        return len(self.input_ids)
    
    def __getitem__(self, index):
        return self.input_ids[index]

def collect_fn(batch):
    inputs, targets = zip(*batch)
    inputs = pad_sequence(inputs, batch_first=True, padding_value=0) 
    targets = pad_sequence(targets, batch_first=True, padding_value=-100)  
    return inputs, targets

def create_dataloader_v2(data, batch_size=4, max_length=1024, stride=128, shuffle=True, drop_last=True):
    tokenizer = tiktoken.get_encoding("gpt2")  
    dataset = Dataset_v2(data, tokenizer, max_length, stride) 
    dataloader = DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, collate_fn=collect_fn)
    return dataloader

def load_txt_file(filepath):
    with open(filepath, 'r') as f:
        text = f.read()
    return text

def split_into_chunks(text, chunk_size=1024, overlap=200):

    chunks = []
    for i in range(0, len(text), chunk_size - overlap):
        chunk = text[i:i + chunk_size]
        chunks.append(chunk)

    return chunks



file =  '/kaggle/input/datasetcleaned/cleaned_books.txt'
load_text =  load_txt_file(file)
chunk = split_into_chunks(load_text)


# Train Script 

In [18]:

from tqdm.auto import tqdm
from transformers import get_cosine_schedule_with_warmup


def evaluate_model(model, train_dataloader, eval_dataloader, device, eval_iter):
    model.eval()
    with torch.no_grad():
        train_loss = calc_loss_loader(train_dataloader, model, device, num_batches=eval_iter)
        val_loss = calc_loss_loader(eval_dataloader, model, device, num_batches=eval_iter)
    model.train()
    return train_loss, val_loss


def generate_and_print_sample(model, tokenizer, device, start_context):
    model.eval()

    encoded = text_to_token_ids(start_context, tokenizer).to(device)

    with torch.no_grad():
        token_ids = generate(
            model=model,
            idx=encoded,
            temperature=1.4,
            max_new_tokens=64,   # Increase generation length if needed
            context_size=126,
            top_k=25
        )
        decoded_text = token_ids_to_text(token_ids, tokenizer)

        # Trim everything before the generation
        generated_only = decoded_text[len(start_context):].strip()

        # Stop at endoftext token if present
        end_marker = "<|endoftext|>"
        if end_marker in generated_only:
            generated_only = generated_only.split(end_marker)[0].strip()

        print(f"\n[Prompt]: {start_context.strip()}\n")
        print(f"[Generated]: {generated_only}\n")

    model.train()
def save_model_checkpoint(model, optimizer, epoch, path="checkpoint_epoch_{}.pt"):
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch
    }
    torch.save(checkpoint, path.format(epoch))
def after_save_load():
    checkpoint = torch.load("checkpoint_epoch_7.pt")
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']


def train_model(
    model: nn.Module,
    train_dataloader: torch.utils.data.DataLoader,
    device: torch.device,
    eval_dataloader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    eval_freq: int,
    eval_iter: int,
    start_context: str,
    num_epochs: int = 1
):
    torch.autograd.set_detect_anomaly(True)
    train_losses, val_losses, track_tokens_seen = [], [], []
    tokens_seen, global_step = 0, -1
    total_steps = len(train_dataloader) * num_epochs
    print(f"🚀 Total training steps: {total_steps}")
    # scheduler = get_cosine_schedule_with_warmup(optimizer,
    #                                         num_warmup_steps=500,
    #                                         num_training_steps=total_steps)
    for epoch in tqdm(range(num_epochs)):
        model.train()
        for inputs_batch, target_batch in train_dataloader:
            inputs_batch, target_batch = inputs_batch.to(device), target_batch.to(device)

            optimizer.zero_grad()
            loss = cal_loss_batch(input_batch=inputs_batch, target_batch=target_batch, device=device, model=model)
            loss.backward()
            optimizer.step()
            # scheduler.step()

            tokens_seen += inputs_batch.numel()
            global_step += 1

            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model, train_dataloader, eval_dataloader, device, eval_iter
                )

                train_losses.append(train_loss)
                val_losses.append(val_loss)
                track_tokens_seen.append(tokens_seen)

                print(
                    f"Epoch: {epoch+1} (step {global_step:06d}):",
                    f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}"
                )

        generate_and_print_sample(
            model, train_dataloader.dataset.tokenizer, device, start_context
        )
        save_model_checkpoint(model , optimizer , epoch+1)


    return train_losses, val_losses, track_tokens_seen


# MemoryGPT Model training  


In [19]:
from torch.optim import AdamW
from torch.nn.utils import clip_grad_norm_
import numpy as np

# GPT Config 

In [20]:
torch.manual_seed(123)
GPT_CONFIG_124M = {
"vocab_size": 50257, # Vocabulary size
"context_length": 126,#126 
# Context lengt
"emb_dim": 768,
# Embedding dimension
"n_heads": 12,
# Number of attention heads
"n_layers": 12,
# Number of layers
"drop_rate": 0.1,
# Dropout rate
"qkv_bias": False
# Query-Key-Value bias
}


In [21]:
def train_memory_model(
    model: nn.Module,
    train_dataloader: DataLoader,
    device: torch.device,
    eval_dataloader: DataLoader,
    optimizer: torch.optim.Optimizer,
    eval_freq: int,
    eval_iter: int,
    start_context: str,
    num_epochs: int = 1,
    max_grad_norm: float = 1.0,
):
    torch.autograd.set_detect_anomaly(True)
    train_losses, val_losses, track_tokens_seen = [], [], []
    memory_metrics = []
    tokens_seen, global_step = 0, -1
    total_steps = len(train_dataloader) * num_epochs
    print(f"🚀 Total training steps: {total_steps}")

    # Create TensorBoard writer
    # writer = SummaryWriter(log_dir='runs/memory_experiment')

    for epoch in tqdm(range(num_epochs)):
        model.train()
        for batch_idx, (inputs_batch, target_batch) in enumerate(train_dataloader):
            inputs_batch, target_batch = inputs_batch.to(device), target_batch.to(device)

            optimizer.zero_grad()
            
            # Forward pass with memory
            logits = model(inputs_batch)
            
            # Calculate loss with memory regularization
            loss = cal_loss_batch(inputs_batch, target_batch, model, device)
            loss += model.memory_system.episodic_memory_cell.memory_regularization_loss()
            loss += model.memory_system.semantic_memory_cell.memory_regularization_loss()

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()

            # Memory maintenance
            if batch_idx % 100 == 0:
                model.memory_system.maintain_memory()

            # Logging and evaluation
            tokens_seen += inputs_batch.numel()
            global_step += 1

            if global_step % eval_freq == 0:
                # Evaluate and get memory metrics
                train_loss, val_loss, mem_metrics = evaluate_memory_model(
                    model, train_dataloader, eval_dataloader, device, eval_iter
                )
                
                # Record metrics
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                track_tokens_seen.append(tokens_seen)
                memory_metrics.append(mem_metrics)
                
                # TensorBoard logging
                writer.add_scalar('Loss/Train', train_loss, global_step)
                writer.add_scalar('Loss/Val', val_loss, global_step)
                log_memory_metrics(writer, mem_metrics, global_step)
                
                print(f"Step {global_step}:")
                print(f"  Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
                print_memory_health(mem_metrics)

            # Generate samples with current memory state
            if global_step % (eval_freq*2) == 0:
                generate_with_memory(
                    model, 
                    train_dataloader.dataset.tokenizer, 
                    device, 
                    start_context,
                    writer,
                    global_step
                )

        # Save model and memory state
        save_memory_checkpoint(model, optimizer, epoch+1, global_step)

    writer.close()
    return train_losses, val_losses, track_tokens_seen, memory_metrics

def evaluate_memory_model(model, train_loader, eval_loader, device, eval_iter):
    model.eval()
    train_loss, val_loss = 0, 0
    memory_stats = []
    
    with torch.no_grad():
        # Evaluate on training set
        for i, (inputs, targets) in enumerate(train_loader):
            if i >= eval_iter: break
            inputs, targets = inputs.to(device), targets.to(device)
            logits = model(inputs)
            train_loss += cal_loss_batch(inputs, targets, model, device).item()
            
        # Evaluate on validation set
        for i, (inputs, targets) in enumerate(eval_loader):
            if i >= eval_iter: break
            inputs, targets = inputs.to(device), targets.to(device)
         
            logits =model(inputs)
            val_loss += cal_loss_batch(inputs, targets, model, device).item()
            
        # Get memory health metrics
        mem_health = model.memory_system.maintain_memory()
        
    model.train()
    return train_loss/eval_iter, val_loss/eval_iter, mem_health

def log_memory_metrics(writer, metrics, step):
    # Episodic memory metrics
    writer.add_scalar('Memory/Episodic/UsedSlots', 
                     metrics['episodic']['active_slots'], step)
    writer.add_scalar('Memory/Episodic/MeanSimilarity', 
                     metrics['episodic']['memory_similarity']['mean_similarity'], step)
    
    # Semantic memory metrics
    writer.add_scalar('Memory/Semantic/ActiveConcepts', 
                     metrics['semantic']['active_concepts'], step)
    writer.add_scalar('Memory/Semantic/EnergyMean', 
                     metrics['semantic']['energy_mean'], step)

def print_memory_health(metrics):
    print("Memory Health:")
    print(f"  Episodic: {metrics['episodic']['active_slots']} active slots")
    print(f"    Similarity: {metrics['episodic']['memory_similarity']['mean_similarity']:.3f}")
    print(f"  Semantic: {metrics['semantic']['active_concepts']} concepts")
    print(f"    Energy: {metrics['semantic']['energy_mean']:.3f}")

def generate_with_memory(model, tokenizer, device, start_context, writer=None, step=None):
    model.eval()
    with torch.no_grad():
        # Generate with current memory state
        input_ids = tokenizer.encode(start_context, return_tensors='pt').to(device)
        
        # Generate text with memory context
        outputs = model.generate(
            input_ids,
            max_length=100,
            temperature=0.8,
            do_sample=True,
            memory_context=model.memory_system.get_memory_state()
        )
        
        text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print("\nGenerated Text with Memory:")
        print(text)
        
        if writer:
            writer.add_text("Generated Text", text, step)
    
    model.train()

def save_memory_checkpoint(model, optimizer, epoch, step):
    state = {
        'epoch': epoch,
        'step': step,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'memory_system': {
            'episodic': model.memory_system.episodic_memory_cell.get_memory(),
            'semantic': model.memory_system.semantic_memory_cell.get_memory(),
            'router': model.memory_system.memory_router.state_dict()
        }
    }
    torch.save(state, f"memory_checkpoint_{epoch}_{step}.pt")

In [22]:
# model = GPTMQModel2(GPT_CONFIG_124M)


# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# model.to(device)
# optimizer  =  torch.optim.AdamW(model.parameters() , lr = 0.0004  , weight_decay= 0.01)
# num_epochs = 1
# train_ratio = 0.90

# text_data = chunk
# print(len(text_data))
# split = int(train_ratio * len(text_data))
# print(split)
# train_data= text_data[:split]
# val_data = text_data[split:]
# # train_dataloader = create_dataloader_v1(txt= train_data , batch_size= 2 , max_length=GPT_CONFIG_124M['context_length'] , shuffle =  True , drop_last=True , stride=GPT_CONFIG_124M['context_length'])
# # val_dataloader = create_dataloader_v1(txt= val_data , batch_size= 2 , max_length=GPT_CONFIG_124M['context_length'] , shuffle =  False , drop_last=False , stride=GPT_CONFIG_124M['context_length'])
# train_dataloader = create_dataloader_v2(data= train_data , batch_size= 2 , max_length=GPT_CONFIG_124M['context_length'] , shuffle =  True , drop_last=True )
# val_dataloader = create_dataloader_v2(data= val_data , batch_size= 2 , max_length=GPT_CONFIG_124M['context_length'] , shuffle =  False , drop_last=False )
# start_context = '### Instruction :Give three tips for staying healthy ### Response:'
# print('start trainning')
# train_losses , val_losses  , token_seen = train_model(
#     model= model , train_dataloader= train_dataloader , 
#     eval_dataloader= val_dataloader , optimizer= optimizer , eval_freq=5 , device= device,
#     eval_iter=3 , start_context=start_context, num_epochs=2
# )





In [23]:
GPT_CONFIG_124M_Memory = {
"vocab_size": 50257, # Vocabulary size
"context_length": 126,
# Context length
"emb_dim": 128,
# Embedding dimension
"n_heads": 4,
# Number of attention heads
"n_layers": 12,
# Number of layers
"drop_rate": 0.1,
# Dropout rate
"qkv_bias": False,
'memory_dim':128,
'max_slots' :1000,
'memory_heads':2 ,

# Query-Key-Value bias
}

In [33]:

device =  'cuda' if torch.cuda.is_available() else "cpu"
model =  GPTMemoryEnhanced(GPT_CONFIG_124M_Memory).to(device)

optimizer =  torch.optim.AdamW(model.parameters() , lr=0.0004,weight_decay=0.01 )


In [34]:
from collections import defaultdict

def get_param_group_summary(model):
    groups = defaultdict(int)
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if "embedding" in name:
            groups["embedding"] += param.numel()
        elif "transformer" in name:
            groups["transformer_blocks"] += param.numel()
        elif "memory" in name or "episodic" in name or "semantic" in name:
            groups["memory_modules"] += param.numel()
        elif "norm" in name:
            groups["normalization"] += param.numel()
        elif "lm_head" in name or "projection" in name:
            groups["output_projection"] += param.numel()
        else:
            groups["other"] += param.numel()
    total = sum(groups.values())
    for k, v in groups.items():
        print(f"{k:20s}: {v:,} parameters")
    print(f"\nTotal: {total:,}")
get_param_group_summary(model)


embedding           : 6,432,896 parameters
memory_modules      : 648,587 parameters
transformer_blocks  : 2,082,048 parameters
normalization       : 128 parameters
output_projection   : 50,257 parameters

Total: 9,213,916


In [35]:





num_epochs =  1
train_ratio = 0.90

filename =  '/kaggle/input/alphaco/alpaca_data_cleaned.json'
with open(filename , 'r') as f:
    text_data =  json.load(f)
text_data = text_data[:200]
split = int(train_ratio * len(text_data))

train_data =  text_data[:split]
val_data =  text_data[split:]

train_dataloader =  create_dataloader_v1(data=train_data , batch_size=2 , max_length=GPT_CONFIG_124M_Memory['context_length'] , shuffle=True , drop_last= True)
val_dataloader = create_dataloader_v1(data=val_data , batch_size=2 , max_length=GPT_CONFIG_124M_Memory['context_length']  , shuffle=False , drop_last=False )


print('start training')
train_losses , val_losses , token_seen =  train_model(
    model=model , 
    train_dataloader=train_dataloader, 
    device=device, 
    eval_freq=5 , 
    eval_dataloader=val_dataloader , 
    optimizer=optimizer, 
    eval_iter=3,  
    num_epochs=1, 
    start_context='Hello '
)
print(model.memory.get_memory_metrics())

start training
🚀 Total training steps: 108


  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 1 (step 000000): Train Loss: 41.1002, Val Loss: 41.8124
Epoch: 1 (step 000005): Train Loss: 39.7114, Val Loss: 40.8948
Epoch: 1 (step 000010): Train Loss: 39.3642, Val Loss: 39.1970
Epoch: 1 (step 000015): Train Loss: 35.6180, Val Loss: 35.3556
Epoch: 1 (step 000020): Train Loss: 27.8489, Val Loss: 26.9149
Epoch: 1 (step 000025): Train Loss: 16.0474, Val Loss: 16.3597
Epoch: 1 (step 000030): Train Loss: 13.5512, Val Loss: 12.5578
Epoch: 1 (step 000035): Train Loss: 11.9057, Val Loss: 11.2702
Epoch: 1 (step 000040): Train Loss: 10.5021, Val Loss: 10.5689
Epoch: 1 (step 000045): Train Loss: 9.9649, Val Loss: 10.0796
Epoch: 1 (step 000050): Train Loss: 8.9862, Val Loss: 9.6967
Epoch: 1 (step 000055): Train Loss: 8.6183, Val Loss: 9.3394
Epoch: 1 (step 000060): Train Loss: 8.6824, Val Loss: 9.0257
Epoch: 1 (step 000065): Train Loss: 8.7497, Val Loss: 8.7808
Epoch: 1 (step 000070): Train Loss: 8.4276, Val Loss: 8.5648
Epoch: 1 (step 000075): Train Loss: 8.1757, Val Loss: 8.3697
Epoch

In [43]:
device

'cuda'

memory_modules      : 10,305,011 parameters
embedding           : 38,725,376 parameters
transformer_blocks  : 72,061,464 parameters
normalization       : 768 parameters
output_projection   : 50,257 parameters

Total: 121,142,876


In [56]:
from collections import defaultdict

def modulewise_param_count(model):
    module_params = defaultdict(int)
    for name, param in model.named_parameters():
        if param.requires_grad:
            module = name.split('.')[0]  # or custom parsing
            module_params[module] += param.numel()
    
    for module, count in sorted(module_params.items(), key=lambda x: -x[1]):
        print(f"{module:<20} : {count:,} parameters")

modulewise_param_count(model)


transformer_blocks   : 72,061,464 parameters
embedding            : 38,597,376 parameters
shared_memory        : 10,433,010 parameters
projection           : 50,257 parameters
final_norm           : 768 parameters
memory_retention_alpha : 1 parameters


In [None]:
def train_model_restart(
    model: nn.Module,
    train_dataloader: torch.utils.data.DataLoader,
    device: torch.device,
    eval_dataloader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    eval_freq: int,
    eval_iter: int,
    start_context: str,
    num_epochs: int = 1,
    checkpoint_path: str = None
):
    train_losses, val_losses, track_tokens_seen = [], [], []
    tokens_seen, global_step, start_epoch = 0, -1, 0

    total_steps = len(train_dataloader) * num_epochs
    print(f"🚀 Total training steps: {total_steps}")

    # 🔁 Load from checkpoint if provided
    if checkpoint_path is not None:
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        global_step = checkpoint.get('global_step', -1)
        tokens_seen = checkpoint.get('tokens_seen', 0)

        # ⬇️ Reduce learning rate by half when resuming
        for param_group in optimizer.param_groups:
            old_lr = param_group['lr']
            param_group['lr'] = old_lr * 0.5
            print(f"🔧 Reduced LR: {old_lr:.6f} ➜ {param_group['lr']:.6f}")

        print(f"✅ Resuming training from Epoch {start_epoch}")

    # ⚙️ Reinitialize scheduler after changing LR
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=500,
        num_training_steps=total_steps
    )

    for epoch in tqdm(range(start_epoch, num_epochs)):
        model.train()
        for inputs_batch, target_batch in train_dataloader:
            inputs_batch, target_batch = inputs_batch.to(device), target_batch.to(device)

            optimizer.zero_grad()
            loss = cal_loss_batch(input_batch=inputs_batch, target_batch=target_batch, device=device, model=model)
            loss.backward()
            optimizer.step()
            scheduler.step()

            tokens_seen += inputs_batch.numel()
            global_step += 1

            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model, train_dataloader, eval_dataloader, device, eval_iter
                )

                train_losses.append(train_loss)
                val_losses.append(val_loss)
                track_tokens_seen.append(tokens_seen)

                print(
                    f"Epoch: {epoch+1} (step {global_step:06d}):",
                    f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}"
                )

        generate_and_print_sample(
            model, train_dataloader.dataset.tokenizer, device, start_context
        )

        save_model_checkpoint(model, optimizer, epoch + 1, global_step, tokens_seen)

    return train_losses, val_losses, track_tokens_seen

def save_model_checkpoint(model, optimizer, epoch, global_step=None, tokens_seen=None, path="checkpoint_epoch.pt"):
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }
    if global_step is not None:
        checkpoint['global_step'] = global_step
    if tokens_seen is not None:
        checkpoint['tokens_seen'] = tokens_seen

    torch.save(checkpoint, f"/kaggle/working/checkpoint_epoch_{epoch}.pt")
    print(f"💾 Saved checkpoint at epoch {epoch}")


In [None]:
        checkpoint = torch.load('/kaggle/working/checkpoint_epoch_6.pt', map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        global_step = checkpoint.get('global_step', -1)
        tokens_seen = checkpoint.get('tokens_seen', 0)

In [None]:
train_losses , val_losses  , token_seen = train_model(
    model= model , train_dataloader= train_dataloader , 
    eval_dataloader= val_dataloader , optimizer= optimizer , eval_freq=5 , device= device,
    eval_iter=3 , start_context=start_context, num_epochs=2
)






In [None]:
import json

loss_history = {
    "train_loss": train_losses,
    "val_loss": val_losses,
    "tokens_seen": token_seen
}

with open("loss_history.json", "w") as f:
    json.dump(loss_history, f)


In [None]:
import matplotlib.pyplot as plt

# If you loaded from a JSON file
# with open("loss_history.json", "r") as f:
#     data = json.load(f)
#     train_losses = data["train_loss"]
#     val_losses = data["val_loss"]

plt.figure(figsize=(10, 6))
plt.plot(train_losses, label="Train Loss", color="blue")
plt.plot(val_losses, label="Validation Loss", color="orange")
plt.xlabel("Evaluation Step")
plt.ylabel("Loss")
plt.title("Training vs Validation Loss")
plt.legend()
plt.grid(True)
plt.savefig("loss_curve.png")  # Save the plot
plt.show()


In [None]:
import torch
import tiktoken

# Load model
model = GPTMQModel2(GPT_CONFIG_124M)
# model.load_state_dict(torch.load("/kaggle/working/checkpoint_epoch_7.pt"))
checkpoint = torch.load("checkpoint_epoch_7.pt")
model.load_state_dict(checkpoint["model_state_dict"])

model.eval().to(device)

# Tokenizer
tokenizer = tiktoken.get_encoding("gpt2")

# Utility functions
def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
    encoded = torch.tensor(encoded).unsqueeze(0)
    return encoded

def token_ids_to_text(tokens, tokenizer):
    flat = tokens.squeeze(0)
    decode = tokenizer.decode(flat.tolist())
    return decode

# Sampling-based generate function (uses your logic)
def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None):
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]
        seq_len = idx_cond.size(1)
        causal_mask = torch.tril(torch.ones(seq_len, seq_len)).to(idx.device)
        causal_mask = causal_mask.unsqueeze(0)
f
        with torch.no_grad():
            logits = model(idx_cond, mask=causal_mask)

        logits = logits[:, -1, :]
        if top_k is not None:
            top_logits, _ = torch.topk(logits, top_k)
            min_val = top_logits[:, -1]
            logits = torch.where(
                logits < min_val,
                torch.tensor(float('-inf')).to(logits.device),
                logits
            )
        if temperature > 0.0:
            logits = logits / temperature
            probs = torch.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
        else:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)

        idx = torch.cat((idx, idx_next), dim=1)

    return idx

# High-level text generation function
def generate_response(prompt, model, tokenizer, max_new_tokens=100, context_size=128, temperature=1.0, top_k=50):
    input_ids = text_to_token_ids(prompt, tokenizer).to(device)
    generated_ids = generate(
        model=model,
        idx=input_ids,
        max_new_tokens=max_new_tokens,
        context_size=context_size,
        temperature=temperature,
        top_k=top_k
    )
    return token_ids_to_text(generated_ids, tokenizer)

# Try it out
# prompt = "### Instruction:\nExplain what is deep learning.\n\n### Response:\n <bot>"
prompt = """

'### Instruction :Give three tips for staying healthy ### Response:'
""".strip()


output = generate_response(prompt, model, tokenizer)
print(output)


In [None]:
def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
    encoded = torch.tensor(encoded).unsqueeze(0)
    return encoded
end_token_id = tokenizer.encode("<|endoftext|>", allowed_special={'<|endoftext|>'})[0]

def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None):
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]
        seq_len = idx_cond.size(1)
        causal_mask = torch.tril(torch.ones(seq_len, seq_len)).to(idx.device)
        causal_mask = causal_mask.unsqueeze(0)

        with torch.no_grad():
            logits = model(idx_cond, mask=causal_mask)

        logits = logits[:, -1, :]
        if top_k is not None:
            top_logits, _ = torch.topk(logits, top_k)
            min_val = top_logits[:, -1]
            logits = torch.where(
                logits < min_val,
                torch.tensor(float('-inf')).to(logits.device),
                logits
            )

        if temperature > 0.0:
            logits = logits / temperature
            probs = torch.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
        else:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)

        idx = torch.cat((idx, idx_next), dim=1)

        # Stop generation if <|endoftext|> is generated
        if idx_next.item() == end_token_id:
            break

    return idx
def truncate_after_n_bullets(text, n=3):
    lines = text.split("\n")
    count = 0
    result = []
    for line in lines:
        if line.strip().startswith(("1.", "2.", "3.")):
            count += 1
        result.append(line)
        if count >= n:
            break
    return "\n".join(result)
raw_output = generate_response(prompt, model, tokenizer)
cleaned_output = truncate_after_n_bullets(raw_output)
print(cleaned_output)



In [None]:
output = generate_response(
    prompt, model, tokenizer,
    temperature=0.8,  # better balance
    top_k=40,         # a bit narrower selection
    max_new_tokens=100
)


In [None]:
def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
    encoded = torch.tensor(encoded).unsqueeze(0)
    return encoded

end_token_id = tokenizer.encode("<|endoftext|>", allowed_special={'<|endoftext|>'})[0]

def generate(model, idx, max_new_tokens, context_size, temperature=1.0, top_k=None):
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]
        seq_len = idx_cond.size(1)
        causal_mask = torch.tril(torch.ones(seq_len, seq_len)).to(idx.device).unsqueeze(0)

        with torch.no_grad():
            logits = model(idx_cond, mask=causal_mask)

        logits = logits[:, -1, :]

        if top_k is not None:
            top_logits, _ = torch.topk(logits, top_k)
            min_val = top_logits[:, -1]
            logits = torch.where(
                logits < min_val,
                torch.tensor(float('-inf')).to(logits.device),
                logits
            )

        if temperature > 0.0:
            logits = logits / temperature
            probs = torch.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
        else:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)

        idx = torch.cat((idx, idx_next), dim=1)

        # Stop generation if <|endoftext|> is in the generated output
        if end_token_id in idx_next:
            break

    return idx

def truncate_after_n_bullets(text, n=3):
    lines = text.split("\n")
    count = 0
    result = []
    for line in lines:
        if line.strip().startswith(("1.", "2.", "3.")):
            count += 1
        result.append(line)
        if count >= n:
            break
    return "\n".join(result)

# 🔁 Input prompt
prompt = "### Instruction: What are the three primary colors? \n### Response:"

# 🔁 Tokenize input
input_ids = text_to_token_ids(prompt, tokenizer).to(device)

# 🔁 Generate output tokens
output_ids = generate(
    model=model,
    idx=input_ids,
    max_new_tokens=100,
    context_size=128,
    temperature=0.7,
    top_k=40
)

# 🔁 Decode and postprocess
output_text = tokenizer.decode(output_ids[0].tolist())

# ✂️ Truncate after 3 bullets (optional)
final_output = truncate_after_n_bullets(output_text)
print(final_output)
