In [1]:
import math 
import torch 
import torch.nn as nn 
import transformers 
from tqdm.notebook import tqdm
# Memory Network
import torch.nn.functional as F 
import random
from typing import Tuple , Optional
import torch.bin 
from transformers import AutoModelForCausalLM , AutoTokenizer

In [2]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

torch.backends.cudnn.benchmark = True


# RMS NORM

In [3]:
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim)) 

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        norm_x = self._norm(x.float()).type_as(x) 
        return norm_x * self.scale 


# Embedding Layer

In [4]:
import torch 

class EmbeddingLayer(torch.nn.Module):
    def __init__(self , vocab_size , embedding_dim):
        super().__init__()

        self.embedding_layer= torch.nn.Embedding(vocab_size , embedding_dim)

    def forward(self , input_tokens):
        return self.embedding_layer(input_tokens)
class InputEmbedding(nn.Module):

    def __init__(self, vocab_size: int , d_model:int):

        super().__init__()

        self.d_model  =  d_model 

        self.vocab_size = vocab_size

        self.embeddings = nn.Embedding(vocab_size , d_model)

    def forward(self ,x):

        return self.embeddings(x) * math.sqrt(self.d_model)
    


# FeedForward Layer 

In [5]:
import torch.nn as nn

class GELU(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, x):
        return 0.5 * x *(1+ torch.tanh(torch.sqrt(torch.tensor(2.0/ torch.pi)) * (x+0.044715 * torch.pow(x, 3))) )
      

class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(cfg['emb_dim'] , 4 * cfg['emb_dim']) ,
            GELU(),
            nn.Linear(4 * cfg['emb_dim'] , cfg['emb_dim'])
        )
    def forward(self, x ):
        return self.layers(x)
    


# Normalization Layer

In [6]:
import torch.nn as nn 
import torch 

class LayerNorm(nn.Module):
    def __init__(self , emb_dim):
        super().__init__()
        self.eps = 1e-5
        self.scale  = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))
    def forward(self , x):
        mean = x.mean(dim= -1, keepdim = True)
        var = x.var(dim =-1, keepdim = True)
        norm_x = (x - mean) / torch.sqrt(var +self.eps)
        return self.scale * norm_x + self.shift 


# RoPE Embdding 

In [7]:
import numpy as np 
import torch 
from dataclasses import dataclass


class NRopE: # RopE in Numpy 
    def rotate_2d(self,vec , theta_p):
        cos_theta  , sin_theta  = np.cos(theta_p) , np.sin(theta_p)
        rotat_vec = np.array([[cos_theta , -sin_theta],
                    [sin_theta ,cos_theta]])
        
        return rotat_vec @ vec


    def RoPe(self,x , p , theta = 10000):
        d = len(x)
        x_rotate =  np.zeros_like(x)
        for i in range(0 , d , 2):
            if i +1< d:
                theta_p = (theta **(-2*(i//2)))**p 
                roted_pair = self.rotate_2d(x[i:i+1] , theta_p)    
                x_rotate[i:i+1] = roted_pair

        return x_rotate



@dataclass
class TRopE(torch.nn.Module): # RopE in torch 
    def __init__(self, dim:int ,theta:float = 10000):
        self.dim = dim 
        self.theta = theta 
        self.freq =  torch.pow(self.theta ,-torch.arange(0 ,dim  , 2)/dim )
        torch.nn.Parameter('freq' , self.freq)

    def forward(self, x:torch.Tensor , pos:torch.Tensor):
        batch_size , seq_len, dim = x.shape
        assert dim ==self.dim ,"Error dim must be same"
        theta_p = torch.einsum("n,d->nd" , pos, self.freq.to(x.device))
        cos_theta  , sin_theta = torch.cos(theta_p) , torch.sin(theta_p)
        x_even , x_odd =  x[... , ::2] , x[... , 1::2]
        x_rotated =  torch.empty_like(x)
        x_rotated[...,::2] =  x_even * cos_theta - x_odd * sin_theta
        x_rotated[...,1::2] =  x_even * sin_theta + x_odd * cos_theta

        return x_rotated







def precompute_freq_cis(  dim:int , end:int , theta:float = 10000.0):
        """dim : dimentions 
        end: end index   
        """
        freqs =  1/(theta **(torch.arange(0 , dim , 2)[:dim//2].float() / dim))
        t =  torch.arange(end, device=freqs.device)
        freqs = torch.outer(t , freqs).float()
        freqs_cis =  torch.polar(torch.ones_like(freqs), freqs)
        return freqs_cis 


def reshape_for_broadcast(freq_cis  , x):
        """ reshape the freqcies to match x dimentions """
        ndim=  x.ndim
        assert 0<=1<ndim 
        assert freq_cis.shape == (x.shape[1], x.shape[-1]), f"Expected {(x.shape[1], x.shape[-1])}, got {freq_cis.shape}" 
        shape = [d if i == 1 or i ==  ndim -1 else 1 for i , d in enumerate(x.shape)]
        return freq_cis.view(*shape)


def apply_rotary_embedding( xq:torch.Tensor ,xk:torch.Tensor ,  freq_cis:torch.Tensor):

            xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1,2))

            xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1,2))


            freq_cies =  reshape_for_broadcast(freq_cis , xq_)
    

            xq_out = torch.view_as_real(xq_* freq_cies).flatten(3)
            
            xk_out = torch.view_as_real(xk_*freq_cies).flatten(3)


            return  xq_out.type_as(xq)   ,  xk_out.type_as(xq) 





# MultiHead & MultiQuery Attention Layer 

In [8]:

class MultiHeadAttention_V2(nn.Module):
    def __init__(self, d_in , d_out , context_length  , dropout ,num_heads,qkv_bias = False):
        super().__init__()
        assert d_out % num_heads  == 0,'d_out must be divisible by the num_heads'
        self.w_query = nn.Linear(d_in , d_out ,bias=qkv_bias)
        self.w_key = nn.Linear(d_in , d_out, bias=qkv_bias)
        self.w_value = nn.Linear(d_in  , d_out,bias=qkv_bias)
        self.d_in =d_in
        self.d_out = d_out
        self.dropout = nn.Dropout(dropout)
        self.num_heads  = num_heads
        self.head_dim = d_out // num_heads
        self.out_proj  = nn.Linear(d_out , d_out)
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length , context_length),diagonal=1)
        )

    def forward(self,x):
        b, num_tokens , d_in = x.shape
        keys = self.w_key(x)
        queries  = self.w_query(x)
        values = self.w_value(x)
        queries = queries.view(b, num_tokens , self.num_heads , self.head_dim)
        values = values.view(b , num_tokens , self.num_heads , self.head_dim)
        keys = keys.view( b, num_tokens , self.num_heads , self.head_dim)
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1,2)
        values = values.transpose(1,2)
        attn_scores = queries @ keys.transpose(2, 3)
        mask_bool= self.mask.bool()[:num_tokens , :num_tokens]
        attn_scores.masked_fill(mask_bool , -torch.inf)
        attn_weights = torch.softmax(attn_scores /self.head_dim**0.5   , dim=-1 )
        attn_weights = self.dropout(attn_weights)
        context_vector = (attn_weights  @ values).transpose(1, 2)
        context_vector = context_vector.contiguous().view(b , num_tokens , self.d_out)
        context_vector = self.out_proj(context_vector)
        return context_vector




def apply_rotary_embedding(xq:torch.Tensor , xk:torch.Tensor , freq_cies:torch.Tensor):

    assert xq.shape[-1] % 2 == 0 , 'Embeddig dimension must be even for complex paring'

    assert xk.shape[-1] % 2 == 0 , 'Embeddig dimension must be even for complex paring'


    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1,2))

    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1] , -1, 2))

    freq_cies = reshape_for_broadcast(freq_cies , xq_)

    xq_out = torch.view_as_real(xq_ * freq_cies ).flatten(3)

    xk_out = torch.view_as_real(xk_ * freq_cies).flatten(3)

    return xq_out.type_as(xq) ,  xk_out.type_as(xk)




class MultiQueryAttentionBlock(nn.Module):
    def __init__(self, d_model:int , h:int , dropout:float , seq_len:int , qkv_bias =  False ):
        super().__init__()
        self.d_model  = d_model 

        self.seq_len=  seq_len

        assert d_model % h == 0, "d_model is must be divided by th head"
        self.dropout = nn.Dropout(dropout)

        self.h = h  

        self.d_k = d_model // h

        self.w_qkv =  nn.Linear(d_model , d_model +2 * self.d_k )

        self.w_o = nn.Linear(d_model  , d_model)

        freq_cies = precompute_freq_cis(dim=self.d_k , end=self.seq_len * 2 )

        self.register_buffer('freq_cies' , freq_cies , persistent= False )

    def generate_causal_mask(self, seq_len, device):
        # shape: (1, 1, seq_len, seq_len)
        return torch.tril(torch.ones((1, 1, seq_len, seq_len), device=device)).bool()

    @staticmethod
    def attention(q, k  , v,mask  , dropout):
        d_k = q.shape[-1]

        attention_score =  (q @ k.transpose(-2,-1)) / math.sqrt(d_k)

        if mask is not None :
            if mask.dim() == 2:
                      mask = mask.unsqueeze(1).unsqueeze(2)
            elif mask.dim() == 3:
                     mask = mask.unsqueeze(1)
            attention_score = attention_score.masked_fill(mask == 0, float('-inf'))
        
        attention_score = attention_score.softmax(dim=-1)

        if dropout is not None :
            attention_score = dropout(attention_score)

        context_vector =  attention_score @ v

        return context_vector  , attention_score
    


    def forward(self, q, mask= None):
        if mask is None:
            mask = self.generate_causal_mask(self.seq_len , device = q.device)
        qkv =  self.w_qkv(q)

        query , key, value =  torch.split(qkv , [self.d_model  , self.d_k , self.d_k], dim=-1)

        query = query.view(query.shape[0] , -1 , self.h , self.d_k).transpose(1, 2)

        key =  key.unsqueeze(1)

        value =  value.unsqueeze(1)

        seq_len =  q.size(1)

        freq_cies = self.freq_cies[:query.shape[1]].to(q.device)

        # freq_cies =  self.freq_cies[:seq_len].to(q.device)

        query , key = apply_rotary_embedding(query , key , freq_cies)

        x , self.attention_score = MultiQueryAttentionBlock.attention(q = query,k =  key,v= value ,mask=mask , dropout= self.dropout)

        x = x.transpose(1,2).contiguous().view(x.shape[0] , -1, self.h* self.d_k)

        x = self.w_o(x)

        return x 
    






# Memory Network 

In [9]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F 
import random


class EfiBioSemanticMemory_V2(nn.Module):
    def __init__(self, input_dim:int ,semantic_memory_dim , max_slots:int = 1000 , compress_dim:int =  128 , top_k:int = 5 , num_heads:int =  4 ):
        super().__init__()

        self.input_dim = input_dim 
        self.max_slots =  max_slots 
        self.compress_dim =  compress_dim 
        self.top_k =  top_k 
        self.num_heads =  num_heads 
        self.semantic_memory_dim = semantic_memory_dim
        # self.memory_size =  semantic_memory_dim 



        self.key_memory =  nn.Parameter(torch.randn(max_slots , compress_dim))
        self.value_memory = nn.Parameter(torch.randn(max_slots , compress_dim))
        self.cell_state =  nn.Parameter(torch.randn(max_slots , compress_dim))
        self.register_buffer('active_mask' , torch.zeros(max_slots , dtype= torch.bool))
        self.active_mask[:semantic_memory_dim] = True  

        
        # Meta data parameter 
        self.register_buffer('age', torch.zeros(max_slots))
        self.register_buffer('usage', torch.zeros(max_slots))
        self.register_buffer('concept_energy', torch.ones(max_slots))
        self.register_buffer('memory_age', torch.zeros(max_slots))
        self.register_buffer('access_count', torch.zeros(max_slots))
        self.register_buffer("_memory_version", torch.tensor(0))
        self.concept_energy[:semantic_memory_dim] =  0.2
        # self.new_slot_mask  = torch.zeros(self.memory_size).bool()

        #stats params
        self.register_buffer("step_count", torch.zeros(1, dtype=torch.long))
        self.register_buffer('replay_count', torch.zeros(1 , dtype= torch.long))
        self.register_buffer("query_count", torch.zeros(1, dtype=torch.long))
        self.register_buffer("novel_count", torch.zeros(1, dtype=torch.long))
        self.register_buffer("write_count", torch.zeros(1, dtype=torch.long))
        self.register_buffer("hit_count", torch.zeros(1, dtype=torch.long)) 
        self.register_buffer('merge_count' , torch.zeros(1  , dtype= torch.long))
        self.register_buffer('neuroslot_count' , torch.zeros(1, dtype= torch.long))
        self.register_buffer('prune_count' , torch.zeros(1, dtype=torch.long))
        self.register_buffer('consalidate_count', torch.zeros(1,dtype=torch.long))
        self.register_buffer('update_count', torch.zeros(1,dtype=torch.long))


        self.initial_write_step  = 300
        
        # Threshold Parameter 
        self.consolidation_threshold = nn.Parameter(torch.tensor(100.0))
        self.energy_threshold = nn.Parameter(torch.tensor(0.2))
        self.decay_rate = nn.Parameter(torch.tensor(0.99))
        # self.novelty_threshold = nn.Parameter(torch.tensor(0.2))
        # self.novelty_threshold = 0.2 * (1 - (self.memory_size / self.max_slots))
        self.novelty_threshold = 0.1

        self.register_buffer("prune_age_threshold", torch.tensor(100))
        self.register_buffer("neurogenesis_threshold", torch.tensor(0.5))
        self.register_buffer("new_slot_maturation_steps", torch.tensor(20)) 
        self.synaptic_scale = nn.Parameter(torch.tensor(0.5))
        self.sparsity = nn.Parameter(torch.tensor(0.5))
        self.sim_thershold =  nn.Parameter(torch.tensor(0.5))
        self.confidence_threshold_att = 0.15 

        # Concept queue Params
        self.register_buffer('queue_max_size' , torch.tensor(1000))
        self.register_buffer('concept_queue' ,  torch.zeros(self.queue_max_size , self.compress_dim))
        self.queue_ptr = 0
        self.queue_count = 0 

        # Networks 
        self.important_net = nn.Sequential(
            nn.Linear(compress_dim, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

        

        self.update_gate = nn.Sequential(
            nn.Linear(3 * compress_dim, 1),
            nn.Hardsigmoid()
        )

        for layer in self.update_gate:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                nn.init.constant_(layer.bias, 0.1) 
        self.forgot_gate = nn.Sequential(
            nn.Linear(semantic_memory_dim, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 3),
            nn.Sigmoid()
        )
        self.forget_gate_net = nn.Linear(compress_dim * 2, compress_dim)

        self.compression = nn.Sequential(
            nn.Linear(input_dim, semantic_memory_dim),
            nn.RMSNorm(semantic_memory_dim),
            nn.GELU(),
            nn.Linear(semantic_memory_dim, self.compress_dim)
        )

        self.decompression = nn.Sequential(
            nn.Linear(self.compress_dim, input_dim),
        )

        self.W_cell = nn.Linear(self.compress_dim, compress_dim, bias=False)
        self.memory_projection = nn.Linear(self.compress_dim, self.input_dim)
        # self.query_proj =  nn.Linear(self.semantic_memory_dim  , self.compress_dim)
        self.attn = nn.MultiheadAttention(
            embed_dim=compress_dim,
            num_heads=num_heads,
            batch_first=False
        )
        self.mem_key_proj   = nn.Linear(compress_dim, semantic_memory_dim, bias=False)
        self.mem_value_proj = nn.Linear(compress_dim, semantic_memory_dim, bias=False)
        self.no_memory_embedding = nn.Parameter(torch.randn(1, 1, self.input_dim))

        nn.init.kaiming_uniform_(self.key_memory, mode='fan_out')
        nn.init.xavier_normal_(self.value_memory)
        nn.init.xavier_normal_(self.cell_state)
        # centroids =  centroids.to(self.key_memory.device)
        # self.key_memory.data = F.normalize(centroids.clone(), dim=-1)
        # self.value_memory.data = F.normalize(centroids.clone(), dim=-1)
        # self.cell_state.data = F.normalize(centroids.clone(), dim=-1)

    def _get_active_memory(self):

        """
            Get the active memries slot 
        """
        idx = torch.nonzero(self.active_mask, as_tuple=False).squeeze(1)
        assert idx.numel() > 0, "No active memory slots"
       
        return  (
            self.key_memory[idx] , 
            self.value_memory[idx], 
            self.cell_state[idx]

        )
        
    @property 
    def active_capacity(self):
        return torch.sum(self.active_mask).item()/ self.max_slots


    def _adaptive_decay(self, memory_idx):
        energy = self.concept_energy[memory_idx]
        access = self.usage[memory_idx]

        decay = torch.exp((1-self.decay_rate) * (1-energy) * (1-access))
        return decay  
    
    def _boost_energy_on_access(self,indices , sims):
        weights =  F.softmax(sims , dim=-1)
        boost =  weights * 0.1 
        self.concept_energy[indices] = torch.clamp(
            self.concept_energy[indices] * 0.95 + boost, 0, 1)
        
        
    def _update_memory(self, topk_idx: torch.LongTensor, cells: torch.Tensor,
                   projected: torch.Tensor, update_gates: torch.Tensor):
        projected_exp = projected.unsqueeze(1).expand(-1, cells.size(1), -1)
    
        cell_input = torch.sigmoid(cells + projected_exp)  
    
        decay_factor = self._adaptive_decay(topk_idx)     
        decay_factor = decay_factor.unsqueeze(-1)         
        cell_updates = decay_factor * cell_input
    
        # apply update_gates: [B, top_k, D]
        delta = update_gates.unsqueeze(-1) * cell_updates
    
        active_indices = torch.nonzero(self.active_mask, as_tuple=True)[0]
        new_slot_mask = torch.isin(topk_idx, active_indices[-10:])
        if new_slot_mask.any():
            cell_updates = cell_updates.clone()  
            cell_updates[new_slot_mask] *= 0.1
            delta = delta * torch.where(new_slot_mask.unsqueeze(-1), 0.5, 1.0)
    
        delta = delta * self.synaptic_scale   
    
        batch_size, top_k, dim = delta.shape
        flat_topk_idx = topk_idx.view(-1)             
        flat_delta    = delta.view(-1, dim)         
    
        flat_cells = F.normalize(cells, dim=-1).view(-1, dim)  
    
        self.age.data += 1
        self.update_count += 1
        assert flat_topk_idx.max() < self.key_memory.size(0), (
            f"Index {flat_topk_idx.max()} >= {self.key_memory.size(0)}"
        )
        max_slots, _ = self.cell_state.shape 
        device = self.cell_state.device
        dtype = self.cell_state.dtype
    
        delta_buffer = torch.zeros((max_slots, dim), device=device, dtype=dtype)
        key_buffer   = torch.zeros_like(delta_buffer)
        val_buffer   = torch.zeros_like(delta_buffer)
    
        
        idx_expanded = flat_topk_idx.unsqueeze(-1).expand(-1, dim) 
        delta_buffer.scatter_add_(0, idx_expanded, flat_delta)
    
        key_buffer.scatter_add_(0, idx_expanded, flat_cells)
        val_buffer.scatter_add_(0, idx_expanded, flat_cells)
    
        with torch.no_grad():
            self.cell_state.data.add_(delta_buffer)
            self.key_memory.data.add_(key_buffer)
            self.value_memory.data.add_(val_buffer)
    
            updated_idx = flat_topk_idx.unique()
            self.age.data[updated_idx] = 0
    
            decay = self.decay_rate ** self.age.unsqueeze(-1)
            self.cell_state.data.mul_(decay)
            self.key_memory.data.mul_(decay)
            self.value_memory.data.mul_(decay)
    
            k_norm = F.normalize(self.key_memory.data[updated_idx], dim=-1)
            v_norm = F.normalize(self.value_memory.data[updated_idx], dim=-1)
            c_norm = F.normalize(self.cell_state.data[updated_idx], dim=-1)
    
            self.key_memory.data[updated_idx]   = torch.tanh(k_norm)
            self.value_memory.data[updated_idx] = torch.tanh(v_norm)
            self.cell_state.data[updated_idx]   = torch.tanh(c_norm)



    def _get_low_energy_slots(self, candidate_indices):
        if len(candidate_indices) == 0:
            return candidate_indices 
        candidate_energy =  self.concept_energy[candidate_indices]

        sorted_indices   =  torch.argsort(candidate_energy)
        return candidate_indices[sorted_indices]
       

    @property
    def memory_size(self):
        return self.active_mask.sum().item()
    
    @property
    def utilization(self):
        return    self.active_mask.sum().item() / self.memory_size

        
    # @torch.no_grad()
    def _batch_update_with_old(self, indices: torch.LongTensor, new_data: torch.Tensor):
     
        old_keys = self.key_memory[indices]      
        old_vals = self.value_memory[indices]     
        old_cells = self.cell_state[indices]     
    
        gate_input = torch.cat([old_cells, old_keys, new_data], dim=-1)
    
        pre = self.update_gate(gate_input)            
        gate = self.hard_sigmoid(pre)                
    

    
        with torch.no_grad():
            blended_key = gate * new_data + (1.0 - gate) * old_keys   # [N, D]
            blended_val = (0.35 * old_vals) + (0.65 * new_data * (1.0 - gate))  # [N, D]
            blended_cell = gate * new_data + (1.0 - gate) * old_cells  # [N, D]
    
            new_key_norm = F.normalize(blended_key, dim=-1)
            new_val_norm = F.normalize(blended_val, dim=-1)
            new_cell_norm = F.normalize(blended_cell, dim=-1)
    
            self.key_memory.data[indices]   = new_key_norm
            self.value_memory.data[indices] = new_val_norm
            self.cell_state.data[indices]   = new_cell_norm
    
            self.usage[indices] = 0.0
            self.age[indices] *= 0.25
            self.memory_age[indices] = 0
            self.concept_energy[indices] = 0.5
            self.access_count[indices] = 0
    


    def _batch_update_with_old(self , indices , new_data, importance_score = None):
        if importance_score is None:
            importance_score =  torch.ones_like(indices, dtype = torch.float32) * 0.5

        old_keys =  self.key_memory[indices]
        old_vals =  self.value_memory[indices]
        old_cell =  self.cell_state[indices]
        concept_energy =  self.concept_energy[indices]

        gate_input =  torch.cat([old_keys, new_data]  , dim=-1)
        forget_gate =  torch.sigmoid(self.forget_gate_net(gate_input))
        inactive_mask = (concept_energy < 0.15).float().unsqueeze(-1)
        write_gate = torch.clamp(importance_score.unsqueeze(-1) + inactive_mask, 0.0, 1.0)
    
        # Update memory with gated blend
        updated_keys = F.normalize((1 - forget_gate) * old_keys + forget_gate * new_data, dim=-1)
        updated_vals = F.normalize(0.35 * old_vals + 0.65 * new_data, dim=-1)
    
        updated_cell = F.normalize(
            self.update_gate(torch.cat([old_cell, old_keys, new_data], dim=-1)) * new_data
            + (1 - self.update_gate(torch.cat([old_cell, old_keys, new_data], dim=-1))) * old_cell,
            dim=-1
        )
    
        # Final gated write
        with torch.no_grad():
            self.key_memory[indices] = (1 - write_gate) * old_keys + write_gate * updated_keys
            self.value_memory[indices] = (1 - write_gate) * old_vals + write_gate * updated_vals
            self.cell_state[indices] = (1 - write_gate) * old_cell + write_gate * updated_cell
    
            # Reset metadata
            self.usage[indices] = 0.0
            self.age[indices] *= 0.25
            self.memory_age[indices] = 0
            self.concept_energy[indices] = 0.5
            self.access_count[indices] = 0
            

    @torch.no_grad()
    def _batch_update_with_new(self, new_idx, new_data):
        self.key_memory[new_idx] = new_data
        self.value_memory[new_idx] = new_data
        self.cell_state[new_idx] = new_data
        self.usage[new_idx] = 0.0
        self.age[new_idx] = 0
        self.access_count[new_idx] = 0
        self.memory_age[new_idx] = 0
        self.concept_energy[new_idx] = 0.5
        self.active_mask[new_idx] = True

    
    def _write_memory_update(self, new_concepts: torch.Tensor, retry_count=0):
        # if retry_count ==  0 and self.query_count > 0:
        #     self.flush_concept_queue()
        # print('Writing Happen')
        if retry_count > 2 or new_concepts.size(0) == 0:
            return
        importance = self.important_net(new_concepts).squeeze(-1)
        keep_mask = importance > 0.10
        if not keep_mask.any():
            return

        new_concepts = new_concepts[keep_mask]
        remaining = new_concepts.size(0)

        # 1. Update low-energy active slots
        low_energy_candidate = torch.where(self.active_mask & (self.energy_threshold > self.concept_energy))[0]
        if low_energy_candidate.numel() > 0:
            candidate = self._get_low_energy_slots(low_energy_candidate)
            num_reuse = min(len(candidate), remaining)
            if num_reuse > 0:
                self._batch_update_with_old(indices=candidate[:num_reuse], new_data=new_concepts[:num_reuse])
                new_concepts = new_concepts[num_reuse:]
                remaining = new_concepts.size(0)

        # 2. Write to new slots
        if remaining > 0 and self.memory_size < self.max_slots:
            add = min(remaining, self.max_slots - self.memory_size)
            new_idx = self.neurogenesis(return_index=True, required_slots=add)
            if new_idx is not None:
                self._batch_update_with_new(new_idx, new_concepts[:add])
            # self.memory_size += add
                new_concepts = new_concepts[add:]
                remaining = new_concepts.size(0)

        if remaining == 0:
            return
        prev_active = self.active_mask.sum()
        self._optimize_memory()
        if self.active_mask.sum() > prev_active and retry_count < 2:
            
            self._write_memory_update(new_concepts, retry_count + 1)
        else:
            self._enqueue_to_queue_buffer(new_concepts)

        assert self.memory_size <= self.max_slots
        assert torch.all(self.active_mask[:self.memory_size])

    @torch.no_grad()
    def flush_concept_queue(self):
        if self.query_count > 0:
            concepts = self.concept_queue[:self.queue_count]
            self._write_memory_update(concepts)
            self.queue_count = 0
            self.queue_ptr = 0
           

    @torch.no_grad()
    def _enqueue(self, data):
        if data.size(0) == 0:
            return

        capacity = self.concept_queue.size(0)
        avail = capacity - self.queue_count
        to_add = data[:avail]
        if to_add.size(0) == 0:
            return

        start = self.queue_ptr
        end = (start + to_add.size(0)) % self.queue_max_size  
        if end <= capacity:

            self.concept_queue[start:end] = to_add
        else:
            split = capacity - start
            self.concept_queue[start:] = to_add[:split]
            self.concept_queue[:end % capacity] = to_add[split:]

        self.queue_ptr = end % capacity
        self.queue_count = min(self.queue_count + to_add.size(0), capacity)

    @torch.no_grad()
    def _enqueue_to_queue_buffer(self, new_concepts):
        to_enqueue = new_concepts.size(0)
        if to_enqueue == 0:
            return

        capacity = self.queue_max_size.item()
        current = self.queue_count
        overflow = max(0, current + to_enqueue - capacity)
        important =  self.important_net(new_concepts).squeeze(-1)
        keep_mask = important > 0.65
        new_concepts = new_concepts[keep_mask]
        if new_concepts.size(0) == 0:
            return 
        if overflow >= new_concepts.size(0):
            self._enqueue(new_concepts[-capacity:])
            return  

        kept_new = new_concepts[-(to_enqueue - overflow):]
        self._enqueue(kept_new)

    @torch.no_grad()
    def _optimize_memory(self , aggressive = False):
        if aggressive:
            self._consolidate_important_memories()
            self._merge_similar_slots()
            self._prune_memories()
            self._prune_slots()
            self.neurogenesis()
        else:
            self._consolidate_important_memories()
            self._prune_memories()
            self.neurogenesis()


        

    def _update_energy_level(self):
        # key_memory = self.mem_key_proj(self.key_memory)
        importance = self.important_net(self.key_memory).squeeze()
        assert torch.all(self.concept_energy >= 0)
        assert torch.all(self.concept_energy <= 1.01)  
        new_energy =  (
            # self.decay_rate * 
            0.3 * self.concept_energy + 0.1 * self.usage + 0.6 * importance *(1-self.concept_energy)
                    )
        
        deactivated  = ~self.active_mask 

        valid_deactivate =  deactivated.nonzero().squeeze()
        valid_deactivate = valid_deactivate[(valid_deactivate>=0)&(valid_deactivate < self.memory_size)]
        with torch.no_grad():
            self.concept_energy.data =  torch.clamp(new_energy  , 0, 1)

            if valid_deactivate.numel() >0:
                self.key_memory.data[valid_deactivate] *=0.01
                self.value_memory.data[valid_deactivate]*=0.01
                self.cell_state.data[valid_deactivate]*=0.1

        self.active_concepts =  torch.sum(self.active_mask).clamp(min=0 , max=self.memory_size)
    
    def forward(self ,x:torch.Tensor , training:bool = True):
        if training:
           x=  self.replay_consolidation(x=x)
    
        batch_size  , seq_len , _  = x.shape 
        self.step_count += 1
        self.query_count += batch_size 
        compressed = self.compression(x.mean(dim=1))
        query = self.W_cell(compressed)
        k_active , v_active , c_active =  self._get_active_memory()
        k  = k_active.unsqueeze(1).expand(-1 , batch_size ,-1)
        v = v_active.unsqueeze(1).expand(-1, batch_size , -1)
        assert k.size(1) == batch_size
        # attn_output , attn_weights = self.attn(
        #         query.unsqueeze(0), k ,v , need_weights =  True 
        #     )

        # Cosine similarity instead of MHA
        sims =  F.cosine_similarity(query.unsqueeze(1),
                                     k_active.unsqueeze(0) , dim=-1)
        # print("   pre‑write mean/sd:", sims.mean().item(), sims.std().item())
        top_vals , local_topk = sims.topk(self.top_k , dim=-1)
        active_indices = torch.nonzero(self.active_mask, as_tuple=True)[0] 
        topk_idx = active_indices[local_topk]
        attn_weights = torch.zeros(1, batch_size, k_active.size(0), device=query.device)
        attn_weights[0].scatter_(1, local_topk, 1.0)
        # indices = torch.nonzero(self.active_mask, as_tuple=True)[0][topk_idx]

        # self._boost_energy_on_access(indices.view(-1), sims.view(-1, k.size(0)))

        v_exp    = v_active.unsqueeze(0).expand(batch_size, -1, -1)  
        gathered = torch.gather(
            v_exp,
            1,
            local_topk.unsqueeze(-1).expand(-1, -1, self.compress_dim)
        )                                                             
        retrieved = gathered.mean(dim=1)                            

        attn_output = retrieved.unsqueeze(0)                        

         

        max_scores , _  = sims.max(dim = -1)
        # if self.query_count < self.initial_write_step:
        #     hit_thershold = 0.79
        # else:
        #     hit_thershold=  0.50
        # hit_thershold = 0.5 + 0.2 * (self.memory_size / self.max_slots)
        hit_thershold =  0.61
        # print('hit_threshold', hit_thershold)
        # query_proj = self.query_proj(query)    
        # hit_threshold = 0.3 + 0.2 * (self.memory_size / self.max_slots)
        max_scores = max_scores.squeeze(-1)
        # print('max scores ',max_scores)
        hits = (max_scores>hit_thershold).sum()
        self.hit_count +=  hits
        # print('hit count', self.hit_count)
        # # self.novelty_threshold = torch.clamp(
        #     torch.tensor(0.4 - 0.3 * (self.memory_size / self.max_slots)), 
        #     min=0.1, 
        #     max=0.5
        # )     
        #   
        novel_mask =  max_scores <  hit_thershold
        
        
        self.novel_count+= novel_mask.sum()
        if novel_mask.any():
                novel_projection =  query[novel_mask]
            # sim_scores =  F.cosine_similarity(novel_projection.unsqueeze(1) , k_active.unsqueeze(0) , dim=-1)
            # similarity_threshold = self.sim_thershold - 0.2 * (self.memory_size / self.max_slots)
            # is_novel =  sim_scores.max(dim=-1).values< (self.sim_thershold - 0.2 * (self.memory_size/self.max_slots))
            # write_mask = is_novel
            # write_mask = sim_scores.max(dim=-1).values < similarity_threshold
            # if write_mask.any():
                if self.query_count >  0:
                    self.flush_concept_queue()
                # new_concepts =  novel_projection[write_mask]
                new_concepts = novel_projection
                self.write_count += new_concepts.size(0)
                assert new_concepts.size(0) <= self.max_slots - self.memory_size 
                "Exceeding maximum memory capacity"
                self._write_memory_update(new_concepts=new_concepts)
                no_memory_out =  self.no_memory_embedding.repeat(batch_size , seq_len, 1)
                return no_memory_out  , self.no_memory_embedding.squeeze(0), torch.tensor([], dtype= torch.long) , None 
        # attn_output = attn_output + torch.randn_like(attn_output) * 0.1
        

        with torch.no_grad():
            self.usage *= 0.95
            self.usage[topk_idx] +=0.1
            self.usage.clamp(0,1)
            self.usage.mul_(0.9)
            self.usage.scatter_add_(0, topk_idx.flatten(), torch.ones_like(topk_idx, dtype=torch.float).flatten())
            self.usage.clamp_(max=1.0)
            # self.concept_energy[topk_idx] += 0.15 * max_scores.squeeze()
            # self.concept_energy.clamp_(max=1.0)
       
        keys= self.key_memory[topk_idx]
        value = self.value_memory[topk_idx]
        cells = self.cell_state[topk_idx]
        if training and self.query_count % 31 == 0:
            self._update_energy_level()
            self._update_thersholds()
             
        gate_input = torch.cat([
            keys, cells, query.unsqueeze(1).expand(-1, self.top_k , -1)
        ], dim= -1) 
        update_gates =  self.update_gate(gate_input.view(-1, 3 *self.compress_dim))
        update_gates = update_gates.view(batch_size, self.top_k)
        self._update_memory(topk_idx=topk_idx, cells=cells ,projected=query, update_gates=update_gates)
        
        # Project the output to the out 
        out =  self.memory_projection(retrieved)
        out = out.unsqueeze(1).repeat(1, seq_len, 1)
        self._memory_version +=1 
        self._update_memory_metadata(topk_idx)
        return  out , retrieved , topk_idx , attn_weights


    @torch.no_grad()
    def _gradual_influence_increase(self):
        """
        Gradually increase the influence of newly added memory slots based on their age and access.
        """
        new_slots_mask  = (self.age <= self.new_slot_maturation_steps) & self.active_mask 
        if not torch.any(new_slots_mask):
            return  
        
        age_normalized = self.age[new_slots_mask] / self.new_slot_maturation_steps
        usage_normalized =  self.usage[new_slots_mask]

        growth_factor =torch.sigmoid((age_normalized + usage_normalized) * 3 ).unsqueeze(-1)

        self.key_memory[new_slots_mask] = F.normalize(self.key_memory[new_slots_mask] * (1 + growth_factor * 0.5),
        dim=-1)
        self.value_memory[new_slots_mask] =  F.normalize(self.value_memory[new_slots_mask]* (1+growth_factor * 0.3) ,dim=-1 )

        energy_boost = torch.clamp(0.1 * growth_factor.squeeze(), max=0.15)
        self.concept_energy.data[new_slots_mask] = torch.clamp(
        self.concept_energy[new_slots_mask] + energy_boost,
        min=0.3,
        max=0.7
    )

        self.age.data[new_slots_mask] += 1 

        
    def _consolidate_important_memories(self):
        key_memory =  self.mem_key_proj(self.key_memory)
        importance =  self.important_net(self.key_memory).squeeze()
        consolidate_mask  = importance > 0.1
        if consolidate_mask.any():
          with torch.no_grad():
            self.key_memory[consolidate_mask] = F.normalize(
                self.key_memory[consolidate_mask] , dim=-1
            )
            mean_value = self.value_memory[consolidate_mask].mean(dim=0)
           
            self.value_memory[consolidate_mask] = (
                    0.9 * self.value_memory[consolidate_mask] +
                    0.1 * mean_value
                )
            self.concept_energy[consolidate_mask] = torch.clamp(self.concept_energy[consolidate_mask] + 0.05, 0, 1)
            self.concept_energy[~consolidate_mask] *= 0.85
            self.consalidate_count +=1 

    # @torch.no_grad()
    # def _prune_memories(self):
    #     prune_condidate =  ((self.age > self.prune_age_threshold * 0.5 ) & (self.usage < 0.05) & (self.concept_energy < self.energy_threshold))
    #     if prune_condidate.any():
    #         self.key_memory.data[prune_condidate] *=  0.1
    #         self.value_memory.data[prune_condidate]*=0.01
    #         self.cell_state.data[prune_condidate] *= 0.01
    #         self.age.data[prune_condidate] = 0 
    #         self.usage[prune_condidate] = 0 
    #         self.concept_energy.data[prune_condidate] = 0.1
    #         self.prune_count +=1 
    #         self._memory_version += prune_condidate.sum().item()

    def _prune_memories(self):
        # prune_mask: [max_slots]
        prune_mask = torch.sigmoid((self.age - 100) / 20) * (1 - self.usage)
        prune_mask *= torch.sigmoid(-self.concept_energy * 5)

        # no need to unsqueeze for concept_energy
        with torch.no_grad():
            self.key_memory.data *= (1 - prune_mask.unsqueeze(-1) * 0.5)   # keeps shape [max_slots, dim]
            self.concept_energy.data *= (1 - prune_mask * 0.3)             # shape [max_slots]


    @torch.no_grad()
    def _prune_slots(self):
            mask = self.age > self.prune_age_threshold
            if mask.any():
                self.key_memory.data[mask] *= 0.01
                self.value_memory.data[mask] *= 0.01
                self.cell_state.data[mask] *= 0.01
                self.concept_energy.data[mask] = 0
                self.usage[mask]= 0 
                self.age.data[mask] = 0 
                self.prune_count +=1 
                self.active_slots[mask] = False 
    


    def replay_consolidation(self, x: torch.Tensor):

        active_indices = torch.nonzero(self.active_mask, as_tuple=True)[0] 
        active_key , active_value, _ = self._get_active_memory()
        if  self.training and random.random() < 0.2: 
            high_energy_mask = self.concept_energy[active_indices] > 0.8
            if high_energy_mask.sum() == 0:
                return x 
            if high_energy_mask.sum() > 0:
                replay_keys = active_key[high_energy_mask]
                replay_values = active_value[high_energy_mask]
                
                replay_input = self.decompression(replay_values.mean(dim=0, keepdim=True))
                B, T , D =  x.shape
                self.replay_count+= 1
                return replay_input.unsqueeze(1).expand(B,T,D)
        return x

    def get_reusable_slots(self ,num_needed:int):
       

        age_score =  1-torch.sigmoid(self.age / 100) # old age 
        energy_score = (1-  self.concept_energy ) *2 
        usage_score = 1 - self.usage 
        reuse_scores = (0.4 * energy_score  + 0.3 * age_score + 0.3 * usage_score 
                       )

        mask =(self.concept_energy  < self.energy_threshold) & (self.age< 100)
        reuse_scores[~mask]= -float('inf')

        topk_scores  , candidates = torch.topk(reuse_scores, min(num_needed, self.memory_size))
        return candidates 

            
    def _reinitialize_slot(self, idx):
        """Reset a slot to initial state"""
        with torch.no_grad():
            scale = 0.1 + 0.05 * torch.rand(1, device=idx.device)
            self.key_memory[idx] = torch.randn_like(self.key_memory[idx]) * scale
            self.value_memory[idx] = torch.randn_like(self.value_memory[idx]) * scale
            self.cell_state[idx] =  0.2 * self.cell_state.data[idx].mean(dim=0)
            
            # Reset metadata
            # neighbor_energy = self.concept_energy[idx±5].mean()   
            self.concept_energy[idx] = 0.3 + 0.2 * torch.rand_like(self.concept_energy[idx])
            # self.usage[idx] =  0.1 * torch.rand_like(self.usage[idx])
            self.usage[idx] = 0.05
            self.age[idx] = 0
            self.memory_age[idx] = 0
            self.access_count[idx] = 0
            self.active_mask[idx] = True 





    def _consalidate_new_slots(self):
        new_slot_indices =  torch.arange(self.memory_size - 10 , self.memory_size )
        new_slot_energy = self.concept_energy[new_slot_indices]
        with torch.no_grad():
            self.concept_energy[new_slot_indices] = torch.clamp(
                new_slot_energy + 0.1 * self.age[new_slot_indices] , 0 ,1
            )
            self.key_memory[new_slot_indices] *=  0.1
            self.value_memory[new_slot_indices] *= 0.1
            self.age[new_slot_indices] +=1
            self.access_count[new_slot_indices] +=1 

    def _update_memory_metadata(self, used_indices):
       
        self.access_count[used_indices] += 1
        
        self.memory_age += 1
        self.memory_age[used_indices] = 0
        with torch.no_grad():
            self.concept_energy[used_indices] += 0.1
            # self.concept_energy= torch.clamp(self.concept_energy * 0.95, 0, 1)
            self.concept_energy.mul_(0.95).clamp_(0,1)

            # self.age[used_indices] -= 5 
    @torch.no_grad()
    def neurogenesis(self, required_slots:int= 10 , return_index = False):
        device = self.key_memory.device
        if self.max_slots > self.memory_size:

            usage_rate =  (self.usage > 0.1).float().mean()
            if usage_rate > self.neurogenesis_threshold:
                reusable =  self.get_reusable_slots()
                num_reuse =  min(reusable.numel() , required_slots)
                reused_indices = reusable[:num_reuse]
                if num_reuse > 0:
                    with torch.no_grad():
                        device =  self.key_memory.device 
                        self._reinitialize_slot(idx=reused_indices)
                        
                new_slots =  min(max(0 ,    required_slots -  num_reuse ), self.max_slots- self.memory_size)
                if new_slots > 0:
                    start_idx =self.memory_size
                    end_idx =  start_idx  + new_slots 
                    new_indices  = torch.arange(start_idx , end_idx, device = device )
                  
                   
                    self.key_memory.data[new_indices] = torch.randn(new_slots, self.compress_dim, device=device) * 0.1
                    self.value_memory.data[new_indices] = torch.randn(new_slots, self.compress_dim, device=device) * 0.1
                    self.cell_state.data[new_indices] = 0
                    self.concept_energy.data[new_indices] = 0.5
                    self.usage.data[new_indices] = 0.0
                    self.age.data[new_indices] = 0.0
                    self.access_count.data[new_indices] = 0.0
                    self.active_mask.data[new_indices] = True
                    self.neuroslot_count +=1 
                    self._gradual_influence_increase()
                    self._memory_version += new_slots
                    # self.new_slot_mask[new_indices] =  True 


                
                    
    
       
                if return_index:
                    return torch.cat([reused_indices, new_indices]) if new_indices.numel() > 0 else reused_indices
            
        elif return_index:
            return  torch.empty(0, dtype=torch.long, device=self.key_memory.device) 

    def emergency_recovery(self):
        # Reset unstable memories
        unstable = self.concept_energy < 0.2
        self._reinitialize_slot(unstable)
        
        self._optimize_memory(aggressive=True)
    def _protect_critical_memories(self):
            # Protect top 10% of important memories
            importance = self.important_net(self.key_memory).squeeze()
            topk = importance.topk(int(self.max_slots * 0.1)).indices
            self.concept_energy[topk] = 1.0
            self.age[topk] -= 10
    @torch.no_grad()
    def _merge_similar_slots(self, top_k: int = 8):
        device = self.key_memory.device
        active_idx = torch.nonzero(self.active_mask, as_tuple=True)[0]
        N = active_idx.size(0)
        if N < 2:
            return

        # 1. Normalized vectors
        keys = F.normalize(self.key_memory[active_idx], dim=-1)
        values = F.normalize(self.value_memory[active_idx], dim=-1)
        D = keys.size(-1)

        # 2. Similarity search
        K = min(top_k, N-1)
        sims, nbrs = torch.topk(keys @ keys.T, k=K+1, dim=-1)
        sims, nbrs = sims[:, 1:], nbrs[:, 1:]  # Remove self

        # 3. Dynamic threshold
        # pressure = torch.tensor(N / self.max_slots, device=device)
        # threshold = (0.9 - 0.4 * pressure).clamp(0.65, 0.9)
        # energy_factor = torch.sigmoid((self.concept_energy.mean() - 0.5) * 5)
        confidence_factor = sims.mean()
        utilization_factor = (self.concept_energy < 0.9).float().mean()
        
        adaptive_threshold = 0.4 + 0.2 * self.concept_energy.mean() + 0.2 * confidence_factor + 0.2 * utilization_factor

        # mask = sims > threshold
        mask = sims >  adaptive_threshold 

        # 4. Graph construction
        row = torch.arange(N, device=device).unsqueeze(1).expand(-1, K)[mask]
        col = nbrs[mask]
        edges = torch.stack([
            torch.cat([row, col]),
            torch.cat([col, row])
        ])

        # 5. Label propagation
        labels = torch.arange(N, device=device)
        for _ in range(3):
            neighbor_labels = labels[edges[1]]
            updates = torch.minimum(labels[edges[0]], neighbor_labels)
            labels.scatter_reduce_(0, edges[0], updates, reduce='amin')  # Fixed reduction

        # 6. Cluster analysis
        uniq, inv, counts = torch.unique(labels, return_inverse=True, return_counts=True)
        cluster_mask = counts >= 2
        big_clusters = uniq[cluster_mask]
        big_counts = counts[cluster_mask]
        num_clust = big_clusters.size(0)
        if num_clust == 0:
            return

        # 7. Cluster mapping
        cluster_id_map = torch.zeros(uniq.max()+1, dtype=torch.long, device=device)
        cluster_id_map[big_clusters] = torch.arange(num_clust, device=device)
        member_mask = torch.isin(inv, big_clusters)
        global_idx = active_idx[member_mask]
        cluster_ids = cluster_id_map[inv[member_mask]]  # Proper mapping
        # 8. Energy aggregation
        energy = self.concept_energy[global_idx]
        weights = (energy / big_counts[cluster_ids].float()).unsqueeze(-1)
        expanded_ids = cluster_ids.unsqueeze(-1).expand(-1, D)
        # 8.1 Weighted sum of the seleceted slots datat 
        new_keys = torch.zeros((num_clust, D), device=device)
        new_vals = torch.zeros_like(new_keys)
        new_keys.scatter_add_(0, expanded_ids, keys[member_mask] * weights)
        new_vals.scatter_add_(0, expanded_ids, values[member_mask] * weights)

        # 9. Representative selection
        cluster_ages = self.age[global_idx]
        min_ages = torch.zeros(num_clust, device=device)
        min_ages.scatter_reduce_(0, cluster_ids, cluster_ages, reduce='amin', include_self=False)
        
        # Find first occurrence of min age
        _, sorted_idx = torch.sort(cluster_ids)
        cluster_ids_sorted = cluster_ids[sorted_idx]
        age_mask = (cluster_ages[sorted_idx] == min_ages[cluster_ids_sorted])
        _, first_occurrence = torch.unique_consecutive(cluster_ids_sorted, return_inverse=True)
        rep_mask = age_mask & (first_occurrence == 0)
        rep_cluster_ids = cluster_ids_sorted[rep_mask]
        representatives = global_idx[sorted_idx][rep_mask]
    
        # 10. Memory updates
        self.key_memory[representatives] = F.normalize(new_keys[cluster_ids_sorted[rep_mask]], dim=-1)
        self.value_memory[representatives] = F.normalize(new_vals[cluster_ids_sorted[rep_mask]], dim=-1)
        
        clust_energy = torch.bincount(cluster_ids, weights=energy, minlength=num_clust)
        # self.concept_energy[representatives] = clust_energy[rep_cluster_ids].clamp(min=1e-5, max=1.0)
        self.concept_energy[representatives] = torch.clamp(
            clust_energy[rep_cluster_ids] * 1.2,  
            min=0.7, 
            max=1.0
        )
     
        self.merge_count+= 1 
        self._memory_version += num_clust
        # 11. Usage update
        per_cluster_usage = torch.bincount(
            cluster_ids,
            weights=self.usage[global_idx],
            minlength=num_clust
        ).float() / big_counts.float() 
        # self.usage[representatives] = per_cluster_usage[rep_cluster_ids]
        self.usage[representatives] = torch.clamp(
            per_cluster_usage[rep_cluster_ids] * 1.5,
            min=0.3,
            max=1.0
        )

        # 12. Deactivation
        active_mask_modified = torch.zeros_like(self.active_mask)
        active_mask_modified[representatives] = True
        deactivate_idx = member_mask & ~active_mask_modified[active_idx]
        
        if deactivate_idx.any():
            to_deactivate = active_idx[deactivate_idx]
            self.concept_energy[to_deactivate] *= 0.1
            self.key_memory[to_deactivate] *= 0.1
            self.value_memory[to_deactivate] *= 0.1
            self.usage[to_deactivate] *= 0.25

        self._consolidate_important_memories()
        self._update_memory_metadata(representatives)
            
    @torch.no_grad()
    def _update_thersholds(self , momentum:float=0.9 ):
        hit_rate =  float(self.hit_count / max(self.query_count, 1))
        write_rate =  float(self.write_count  / max(self.query_count , 1))
        novely_rate =  float(self.novel_count / max(self.query_count  , 1))

        util =  float(self.active_capacity)

        new_nov =   (0.2 * (1-util) + 0.1 * novely_rate)
        self.novelty_threshold =    momentum * self.novelty_threshold + (1-momentum) * new_nov

        new_enger_thr = 0.3 + 0.3 *(1-hit_rate)
        self.energy_threshold.data  = momentum * self.energy_threshold + (1-momentum) * new_enger_thr 
        
        new_consal = 50.0 +50.0 * write_rate 
        self.consolidation_threshold.data.mul_(momentum).add_(new_consal * (1-momentum))

        new_decay = 0.995 + 0.003 * (1-util)
        self.decay_rate.data.mul_(momentum).add_(new_decay *(1-momentum))


        new_prune_age = 100 * (1- util ) + 20 * util 
        self.prune_age_threshold.fill_(momentum * self.prune_age_threshold+(1-momentum) * new_prune_age)

        new_neuro = 0.8 + 0.1 * write_rate - 0.1 * util
        self.neurogenesis_threshold.fill_(momentum * self.neurogenesis_threshold + (1-momentum) * new_neuro)

        new_mat = 50 + 50 * write_rate
        self.new_slot_maturation_steps.fill_(momentum * self.new_slot_maturation_steps + (1-momentum) * new_mat)

        new_scale = 0.05 + 0.2 * (1 - hit_rate)
        self.synaptic_scale.data.mul_(momentum).add_(new_scale * (1-momentum))

        new_sp = 0.5 + 0.3 * util
        self.sparsity.data.mul_(momentum).add_(new_sp * (1-momentum))

        new_sim = 0.5 - 0.2 * novely_rate
        self.sim_thershold.data.mul_(momentum).add_(new_sim * (1-momentum))


    

        
    def get_memory_metrics(self):

        "Return memory health and retivel param details"
        active_mask  =  self.concept_energy > self.energy_threshold
        energy  =  self.concept_energy 
        usage =  self.usage 
        access =  self.access_count

        age_hist   = torch.histc(self.memory_age.float(), bins=10, min=0, max=float(self.memory_age.max()))
        usage_hist = torch.histc(usage, bins=10, min=0, max=1.0)
        access_hist= torch.histc(access.float(), bins=10, min=0, max=float(access.max()))
        active_concepts = self.active_mask &(self.concept_energy > 0.70)

        return  {

            #________Memory Health _______________________________
            'memory_size': self.memory_size  , 
            'active_concepts':self.active_mask.sum().item(),
            'active_concepts_with_high_energy':active_concepts.sum().item(),
            'utilization':self.utilization , 
            'energy_mean':energy.mean().item(), 
            'energy_std':energy.std().item(), 
            'age_histogram': age_hist, 
            'usage_histogram':usage_hist, 
            'access_histogram':access_hist, 
            # 'merge_rate':(energy < 0.3).sum().item() / self.memory_size ,
            'merge_rate':self.merge_count.item() / max(1 , self.step_count.item()),
            # "prune_rate":((self.age > self.prune_age_threshold) & (usage < 0.01)).float().mean().item(),            
            'prune_rate':self.prune_count.item() / max(1 , self.step_count.item()),
            'neuro_rate':(self.memory_age < 10).float().mean().item(), 
            "reuse_efficiency":      access[energy > 0.5].float().mean().item(),

             # —— retrieval/write stats ———————————————————————
            "steps":                 self.step_count.item(),
            "queries":               self.query_count.item(),
            "novelty_rate":          self.novel_count.item() / max(1, self.query_count.item()),
            "write_rate":            self.write_count.item() / max(1, self.query_count.item()),
            "hit_rate":              self.hit_count.item() / max(1, self.query_count.item()),
            'merge_count':self.merge_count.item() , 
            'neuroslot_count':self.neuroslot_count.item(), 
            'prune_count':self.prune_count.item(), 
            'consalidate_count':self.consalidate_count.item(),
            'hit_count':self.hit_count.item() , 
            'write_count':self.write_count.item(),
            'memory_version':self._memory_version.item(),
            'access_count':self.access_count[self.active_mask].tolist(), 
            'replay_count':self.replay_count.item(),
            'access_count_sum': self.access_count[self.active_mask].sum().item(),

                'access_count_mean': self.access_count[self.active_mask].mean().item(),
            'update_count':self.update_count.item()
            
        }


        
    def model_save(self , path):
            torch.save({

                'key_memory':self.key_memory.data.cpu(), 
                'value_memory':self.value_memory.data.cpu() , 
                'cell_state':self.cell_state.data.cpu(), 
                'active_mask':self.active_mask.cpu(),
                'age':self.age.data.cpu() , 
                'usage':self.usage.data.cpu(), 
                'access_count':self.access_count.data.cpu(), 
                'memory_version':self._memory_version.data.cpu(), 
                'memory_age':self.memory_age.data.cpu(), 
                'concept_queue':self.concept_queue.data.cpu(), 
                'queue_ptr':self.queue_ptr , 
                'queue_count':self.queue_count , 
                'query_count':self.query_count , 
                'step_count':self.step_count , 
                'novel_count':self.novel_count , 
                'write_count':self.write_count , 
                'hit_count':self.hit_count , 
                'merge_count':self.merge_count , 
                'neuroslot_count':self.neuroslot_count , 
                'prune_count':self.prune_count , 
                'consalidate_count':self.consalidate_count ,
                'memory_size':self.memory_size,
                'replay_count':self.replay_count

             } , path)
            
    def model_load(self, path, map_location=None):
        state = torch.load(path, map_location=map_location)
        self.key_memory.data.copy_(state['key_memory'])
        self.value_memory.data.copy_(state['value_memory'])
        self.cell_state.data.copy_(state['cell_state'])
        self.active_mask.data.copy_(state['active_mask'])
        self.age.data.copy_(state['age'])
        self.usage.data.copy_(state['usage'])
        self.access_count.data.copy_(state['access_count'])
        self.memory_age.data.copy_(state['memory_age'])
        self.concept_queue.data.copy_(state['concept_queue'])
        self.queue_ptr = state.get('queue_ptr', 0)
        self.queue_count = state.get('queue_count', 0)
        self._memory_version.data.copy_(
            state.get('memory_version', torch.tensor(0))
        )


# Transformer Block

In [10]:




class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att = MultiHeadAttention_V2(
        d_in=cfg["emb_dim"],
        d_out=cfg["emb_dim"],
        context_length=cfg["context_length"],
        num_heads=cfg["n_heads"],
        dropout=cfg["drop_rate"],
        qkv_bias=cfg["qkv_bias"])
        self.ff = FeedForward(cfg)
        self.norm1 = LayerNorm(cfg["emb_dim"])
        self.norm2 = LayerNorm(cfg["emb_dim"])
        self.drop_resid = nn.Dropout(cfg["drop_rate"])
    def forward(self, x):
    #A
        shortcut = x
        x = self.norm1(x)
        x = self.att(x)
        x = self.drop_resid(x)
        x = x + shortcut # Add the original input back
        shortcut = x #B
        x = self.norm2(x)
        x = self.ff(x)
        x = self.drop_resid(x)
        x = x + shortcut #C
        return x



class TransformerBlock_v2(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.attention =  MultiQueryAttentionBlock(d_model=cfg['emb_dim'], h=cfg['n_heads'] , dropout=cfg['drop_rate'], seq_len=  cfg['context_length'] ,qkv_bias=cfg['qkv_bias'])

        self.feed_forward = FeedForward(cfg)

        self.layernorm1 =  LayerNorm(cfg['emb_dim'])
    
        self.layernorm2 =  LayerNorm(cfg['emb_dim'])

        self.drop_out = nn.Dropout(cfg['drop_rate'])

    def forward(self, x , mask= None):

        attention_output =  self.attention(self.layernorm1(x) , mask =  mask)

        ff_output =  self.feed_forward(self.layernorm2(x))

        return x + self.drop_out(ff_output) + self.drop_out(attention_output)
class TransformerBlockWithMemory(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.attention = MultiQueryAttentionBlock(
            d_model=cfg['emb_dim'],
            h=cfg['n_heads'],
            dropout=cfg['drop_rate'],
            seq_len=cfg['context_length'],
            qkv_bias=cfg['qkv_bias']
        )
        self.memory = MemorySystem(cfg=cfg)
        self.feed_forward = FeedForward(cfg=cfg)
        
        self.norm1 = LayerNorm(cfg['emb_dim'])
        self.norm2 = LayerNorm(cfg['emb_dim'])
        self.norm3 = LayerNorm(cfg['emb_dim'])
        
        # Memory gate
        self.memory_gate = nn.Sequential(
            nn.Linear(cfg['emb_dim'] * 2, cfg['emb_dim']),
            nn.Sigmoid()
        )
        
        self.dropout = nn.Dropout(cfg['drop_rate'])

    def forward(self, x, mask=None):
        attn_out = self.attention(self.norm1(x), mask=mask)
        x = x + self.dropout(attn_out)
        
        norm_x = self.norm2(x)
        epic_out , semantic_out , memory_out = self.memory(norm_x)
        
        gate_input = torch.cat([norm_x, memory_out], dim=-1)
        memory_gate = self.memory_gate(gate_input)
        x = x + memory_gate * memory_out
        
        ff_out = self.feed_forward(self.norm3(x))
        x = x + self.dropout(ff_out)
        
        return x
class TransformerBlockWithMemory(nn.Module):
    def __init__(self, cfg, shared_memory=None):
        super().__init__()
        # Core components
        self.attention = MultiQueryAttentionBlock(
            d_model=cfg['emb_dim'],
            h=cfg['n_heads'],
            dropout=cfg['drop_rate'],
            seq_len=cfg['context_length'],
            qkv_bias=cfg['qkv_bias']
        )
        self.ffn = FeedForward(cfg=cfg)
        
        # Memory system (shared across blocks)
        self.memory = shared_memory or MemorySystem(cfg=cfg)
        
        # Normalization layers
        self.pre_ln_attn = RMSNorm(cfg['emb_dim'])
        self.pre_ln_mem = RMSNorm(cfg['emb_dim'])
        self.pre_ln_ffn = RMSNorm(cfg['emb_dim'])
        
        # Adaptive memory gating
        self.memory_gate = nn.Sequential(
            nn.Linear(cfg['emb_dim'], 1),
            nn.Sigmoid()
        )
        
        # Memory residual weights
        self.mem_alpha = nn.Parameter(torch.tensor(0.5))
        self.dropout = nn.Dropout(cfg['drop_rate'])

    def forward(self, x, mask=None):
        # Attention phase
        resid = x
        x = self.pre_ln_attn(x)
        x = resid + self.dropout(self.attention(x, mask=mask))
        
        # Memory phase
        resid_mem = x
        x_mem = self.pre_ln_mem(x)
        # print('x shape ', x.shape)
        _, _, memory_out = self.memory(x_mem)
        
        # Adaptive gating
        gate = self.memory_gate(x_mem)
        x = resid_mem + self.mem_alpha * gate * memory_out
        
        # FFN phase
        resid_ffn = x
        x = self.pre_ln_ffn(x)
        x = resid_ffn + self.dropout(self.ffn(x))
        
        return x




# GPTQModel

In [11]:
import torch.nn.functional as F


class InputEmbedding(nn.Module):

    def __init__(self, vocab_size: int , d_model:int):

        super().__init__()

        self.d_model  =  d_model 

        self.vocab_size = vocab_size

        self.embeddings = nn.Embedding(vocab_size , d_model)

    def forward(self ,x):

        return self.embeddings(x) * math.sqrt(self.d_model)
    



class ProjectionLayer(nn.Module):
    def __init__(self, d_model: int, vocab_size: int, embdding_layer: nn.Embedding):
        super().__init__()
        self.weight = embdding_layer.weight  # share weights with input embedding
        self.bias = nn.Parameter(torch.zeros(vocab_size))  # learnable bias

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)




class GPTMQModel2(nn.Module):
    def __init__(self, cfg):
        super().__init__()

        self.embedding = InputEmbedding(cfg['vocab_size'], cfg['emb_dim'])

        # Use ModuleList instead of Sequential
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock_v2(cfg=cfg) for _ in range(cfg['n_layers'])
        ])

        self.drop_out = nn.Dropout(cfg['drop_rate'])
        # self.final_norm = LayerNorm(emb_dim=cfg['emb_dim'])
        self.final_norm =  RMSNorm(dim = cfg['emb_dim'])

        self.projection = ProjectionLayer(cfg['emb_dim'], cfg['vocab_size'], self.embedding.embeddings)

    def forward(self, input_tokens, mask=None):
        x = self.embedding(input_tokens)

        for block in self.transformer_blocks:
            x = block(x, mask=mask)  # Pass the mask explicitly to each block

        x = self.final_norm(x)
        logits = self.projection(x)

        return logits
class GPTMQMemoryModel1(nn.Module):
    def __init__(self, cfg):
        super().__init__()

        self.embedding = InputEmbedding(cfg['vocab_size'], cfg['emb_dim'])

        # Use ModuleList instead of Sequential
        self.transformer_blocks =  nn.ModuleList([
            TransformerBlockWithMemory(cfg=cfg) for _ in range(cfg['n_layers'])
        ])

        self.drop_out = nn.Dropout(cfg['drop_rate'])
        # self.final_norm = LayerNorm(emb_dim=cfg['emb_dim'])
        self.final_norm =  RMSNorm(dim=cfg['emb_dim'])

        self.projection = ProjectionLayer(cfg['emb_dim'], cfg['vocab_size'], self.embedding.embeddings)

    def forward(self, input_tokens, mask=None):
        x = self.embedding(input_tokens)

        for block in self.transformer_blocks:
            x = block(x, mask=mask)  # Pass the mask explicitly to each block

        x = self.final_norm(x)
        logits = self.projection(x)

        return logits
class GPTMQMemoryModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.embedding = InputEmbedding(cfg['vocab_size'], cfg['emb_dim'])
        
        # Shared memory system across layers
        self.shared_memory = MemorySystem(cfg=cfg)
        
        # Transformer blocks with shared memory
        self.transformer_blocks = nn.ModuleList([
            TransformerBlockWithMemory(
                cfg=cfg,
                shared_memory=self.shared_memory if cfg['share_memory'] else None
            ) for _ in range(cfg['n_layers'])
        ])
        
        # Final projections
        self.final_norm = RMSNorm(dim=cfg['emb_dim'])
        self.projection = ProjectionLayer(
            cfg['emb_dim'], 
            cfg['vocab_size'], 
            self.embedding.embeddings
        )
        self.memory_retention_alpha = nn.Parameter(torch.tensor(0.9))

        # Memory loss coefficient
        self.mem_loss_coef = cfg.get('mem_loss_coef', 0.3)

    def forward(self, input_tokens, mask=None):
        x = self.embedding(input_tokens)
        
        for block in self.transformer_blocks:
            x = block(x, mask=mask)
            x = self.memory_retention_alpha * x + (1 - self.memory_retention_alpha) * x.detach()
            
        x = self.final_norm(x)
        logits = self.projection(x)
        
        return logits
    
    def get_memory_loss(self):
        """Get combined memory regularization loss"""
        return self.mem_loss_coef * self.shared_memory.memory_loss()
    
    def transformer_parameters(self):
        return [p for n, p in self.named_parameters() if 'transformer_blocks' in n and p.requires_grad]
    
    def memory_parameters(self):
        return [p for n, p in self.named_parameters() if 'memory_modules' in n and p.requires_grad]
    
    def embedding_parameters(self):
        return [p for n, p in self.named_parameters() if 'embedding' in n and p.requires_grad]
    
    def norm_parameters(self):
        return [p for n, p in self.named_parameters() if 'normalization' in n and p.requires_grad]
    
    def output_parameters(self):
        return [p for n, p in self.named_parameters() if 'output_projection' in n and p.requires_grad]



In [12]:

class GPTMemoryEnhanced(nn.Module):
    def __init__(self,cfg ):
        super().__init__()
        self.embedding =  InputEmbedding(cfg['vocab_size'] , cfg['emb_dim'])
        self.memory_proj = nn.Linear(cfg['emb_dim'], cfg['memory_dim'])
        self.memory_expander = nn.Linear(cfg['memory_dim'], cfg['emb_dim'] )
        self.transformer_block = nn.ModuleList([
            TransformerBlock_v2(cfg=cfg) for _ in range(cfg['n_layers'])
        ])
        self.dropout =  nn.Dropout(cfg['drop_rate'])
        self.memory =  EfiBioSemanticMemory_V2(input_dim=cfg['memory_dim'] ,semantic_memory_dim=cfg['memory_dim'],num_heads=2)

        self.final_norm = RMSNorm(dim=cfg['emb_dim'])
        self.fusion_gate = nn.Sequential(
                nn.Linear(2*cfg['emb_dim'], cfg['emb_dim']),
                nn.Sigmoid()
            )
        

        self.projection =  ProjectionLayer(cfg['emb_dim'], cfg['vocab_size'] , self.embedding.embeddings)

    def forward(self,input_tokens:torch.Tensor , mask =None):
        x =  self.embedding(input_tokens)
        x_emb = x 
        for block in self.transformer_block:
            x  = block(x, mask = mask)
        memory_query = self.memory_proj(x) 
        mem_out, retrieved, topk_idx, attn_w = self.memory(memory_query)
        # memory_out, _ ,_ =  self.memory(x.las_hidden_state.mean(1))
        mem_out = self.memory_expander(mem_out)
        gate = self.fusion_gate(torch.cat([x, mem_out], -1))
        fused = gate * x + (1 - gate) * mem_out
        fused =  self.final_norm(fused)

        logits =  self.projection(fused)
        # return logits, {
        #     "memory_topk": topk_idx, 
        #     "memory_attn": attn_w,
        #     "retrieved":  retrieved
        # }
        return  logits ,x_emb ,  mem_out 
    


# Loss Functions

In [13]:


# Memory Reconstruction Loss
# Ensures stored information preserves input patterns
def reconstruction_loss(inputs, memory_output):
    return F.mse_loss(memory_output, inputs)



# Task-Specific Loss
# Drives memory to store task-relevant information
def task_loss(predictions, targets):
    return F.cross_entropy(predictions, targets)  # For classification

# Memory Sparsity Loss
# Encourages efficient slot usage
def sparsity_loss(concept_energy):
    return torch.mean(concept_energy**2)  # L2 penalty on energy levels

# Memory Diversity Loss
# Prevents slot redundancy
def diversity_loss(key_memory):
    normalized_keys = F.normalize(key_memory, dim=1)
    similarity = torch.mm(normalized_keys, normalized_keys.T)
    return torch.mean(similarity**2) - 1/torch.numel(similarity)



# Energy Maintenance Loss
# Maintains healthy energy distribution
def energy_loss(concept_energy):
    energy_mean = torch.mean(concept_energy)
    return F.mse_loss(energy_mean, torch.tensor(0.5,device = concept_energy.device))



# Pruning Incentive Loss
# Encourages proper slot turnover
def pruning_loss(age, usage):
    old_unused = (age > 100) & (usage < 0.01)
    return torch.mean(old_unused.float())



# Anti-Collapse Loss
# Prevents memory dependency on few slots
def anti_collapse_loss(usage_counts):
    return -torch.sum(usage_counts * torch.log(usage_counts + 1e-7))



def novelty_loss(new_slots, existing_memory):
    sim = F.cosine_similarity(new_slots.unsqueeze(1), 
                            existing_memory.unsqueeze(0), dim=-1)
    return torch.mean(sim)




def total_loss(inputs, outputs, targets, memory):
    # Base losses
    rec_loss = reconstruction_loss(inputs, outputs)
    # t_loss = task_loss(outputs, targets)
    
    # Memory regularization
    sp_loss = sparsity_loss(memory.concept_energy)
    div_loss = diversity_loss(memory.key_memory)
    en_loss = energy_loss(memory.concept_energy)
    
    # Stability terms
    prun_loss = pruning_loss(memory.age, memory.usage)
    anti_loss = anti_collapse_loss(F.softmax(memory.access_count, dim=0))
    
    # Weighted combination
    return (1.0 * rec_loss + 
            # 0.5 * t_loss + 
            0.3 * sp_loss + 
            0.2 * div_loss +
            0.1 * en_loss +
            0.05 * prun_loss +
            0.02 * anti_loss)



In [14]:




def cal_loss_batch(input_batch , target_batch , model:torch.nn.Module , device:torch.device ):
    input_batch , target_batch = input_batch.to(device) , target_batch.to(device)
    logits = model(input_batch)
    loss = torch.nn.functional.cross_entropy(   logits.flatten(0,1), target_batch.flatten())
    util_loss = -torch.log(model.memory.utilization + 1e-8)
    
    return loss + 0.1 * util_loss
def cal_loss_batch(input_batch , target_batch ,model:nn.Module, device:torch.device , mem_cof:float= 0.1):
    input_batch  , target_batch =  input_batch.to(device) , target_batch.to(device)
    logits ,x_emb, mem_output  =  model(input_batch)
    B,T,V = logits.shape 
    gpt_loss = F.cross_entropy(
            logits.view(B * T, V),
            target_batch.view(B * T),
            ignore_index=-100,                       # if you pad with -100
        )   
    # utilization_loss = -torch.log(model.memory.utilization + 1e-8)
    memory_loss =  total_loss(inputs=x_emb , memory= model.memory , outputs=mem_output , targets= target_batch.float())
    return gpt_loss +mem_cof * memory_loss 


def calc_loss_loader(data_loader , model , device , num_batches = None):
    total_loss = 0
    if num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches  = min(num_batches , len(data_loader))
    for i , (inputs , target) in enumerate(data_loader):
        if i < num_batches:
            loss  =  cal_loss_batch(inputs , target , model , device)

            total_loss +=loss.item()

        else:
            break

        return total_loss  / num_batches
    


# Text Generation Function

In [15]:
import torch
import torch.nn as nn 




def text_to_token_ids(text,  tokenizer):
    encoded = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
    encoded = torch.tensor(encoded).unsqueeze(0)
    return encoded
    
def text_to_token_ids(text, tokenizer):
    encoded = tokenizer(text, return_tensors="pt", add_special_tokens=False)["input_ids"]
    return encoded  # Already 2D [1, seq_len]



def token_ids_to_text(tokens , tokenizer):
    flat  = tokens.squeeze(0)
    decode = tokenizer.decode(flat.tolist())
    return decode
def token_ids_to_text(tokens, tokenizer):
    flat = tokens.squeeze(0)
    decoded = tokenizer.decode(flat.tolist(), skip_special_tokens=True)
    return decoded

    
def generate_and_sample(model  , idx , context_size ,max_new_tokens ):
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]
        with torch.no_grad():
            logits  , _ , _ = model(idx_cond)
        logits  = logits[:, -1  , :]
        probs  = torch.softmax(logits  , dim=-1)
        idx_next = torch.argmax(probs, dim=-1 , keepdim= True)
        idx = torch.cat((idx, idx_next), dim=1)
    return idx 

#
def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None):
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]  # shape: [1, current_seq_len]

        # Create causal mask dynamically
        seq_len = idx_cond.size(1)
        causal_mask = torch.tril(torch.ones(seq_len, seq_len)).to(idx.device)
        causal_mask = causal_mask.unsqueeze(0)  # [1, seq_len, seq_len]

        with torch.no_grad():
            logits , _ , _= model(idx_cond, mask=causal_mask)  # <--- pass mask here

        logits = logits[:, -1, :]  # only take the last token logits

        # Apply top-k sampling if needed
        if top_k is not None:
            top_logits, _ = torch.topk(logits, top_k)
            min_val = top_logits[:, -1]
            logits = torch.where(
                logits < min_val,
                torch.tensor(float('-inf')).to(logits.device),
                logits
            )

        # Temperature sampling
        if temperature > 0.0:
            logits = logits / temperature
            probs = torch.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
        else:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)

        idx = torch.cat((idx, idx_next), dim=1)
        

    return idx
def real_time_generation(model, initial_input, context_size, temperature, top_k=None, device="cpu"):
    # Tokenize the initial input and prepare the model context
    idx = torch.tensor(initial_input).unsqueeze(0).to(device)  # Assuming initial_input is tokenized
    
    print("Starting real-time generation...")
    
    # Start generating tokens in real-time
    for new_token in generate(model, idx, max_new_tokens=50, context_size=context_size, temperature=temperature, top_k=top_k, device=device):
        print(f"Generated token: {new_token.item()}")  # Or decode it back to a word
        
        # You can check for user input here and update idx with the new input
        # For instance, wait for the user to input a prompt to append to the context
        user_input = input("Enter new input (or press enter to continue generation): ")
        
        if user_input:
            # Tokenize the new user input and append it to the context
            user_input_tokens = torch.tensor(tokenize(user_input)).unsqueeze(0).to(device)
            idx = torch.cat((idx, user_input_tokens), dim=1)  # Append the new tokens to the context
        else:
            # Continue generating if no new user input
            continue

# Function to tokenize input (adjust depending on your tokenizer)
def tokenize(text):
    # Assuming you have a tokenizer function available
    return [ord(c) for c in text]  # Dummy example: ord() converts char to token id



# Dataset and DataLoader 

In [16]:
import torch
from torch.utils.data import Dataset, DataLoader
import tiktoken
import json
from torch.nn.utils.rnn import pad_sequence


def generate_prompt(sample):
    # return f"<user> {sample['instruction']} <bot> {sample['output']}"
        return f"### Instruction:\n{sample['instruction']}\n\n### Response:\n{sample['output']} <|endoftext|>"
    


class Dataset_V1(Dataset):
    def __init__(self, data, tokenizer, max_length, stride):
        self.max_length = max_length
        self.input_ids = []
        self.target_ids = []
        self.tokenizer =  tokenizer

        all_tokens = []
        allowed = {'<|endoftext|>'}
        for sample in data:
            prompt = generate_prompt(sample)
            # tokens = tokenizer.encode(prompt , allowed_special=allowed)
            tokens =  tokenizer.encode(prompt)
            all_tokens.extend(tokens)

        for i in range(0, len(all_tokens) - max_length, stride):
            input_chunk = all_tokens[i: i + max_length]
            target_chunk = all_tokens[i + 1: i + max_length + 1]
            self.input_ids.append(torch.tensor(input_chunk))
            self.target_ids.append(torch.tensor(target_chunk))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]
def collate_fn(batch):
    inputs, targets = zip(*batch)
    inputs = pad_sequence(inputs, batch_first=True, padding_value=0)
    targets = pad_sequence(targets, batch_first=True, padding_value=-100)  # -100 is ignored by CrossEntropyLoss
    return inputs, targets

def create_dataloader_v1(data,tokenizer , batch_size=4,
    max_length=256, stride=128, shuffle=True, drop_last=True ):
    tokenizer = tiktoken.get_encoding("gpt2") #tokenizer 
    dataset = Dataset_V1(data, tokenizer, max_length, stride) #B
    dataloader = DataLoader(
    dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last , collate_fn=collate_fn)
    return dataloader


In [17]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
from torch.nn.utils.rnn import pad_sequence

def generate_prompt(sample):
    return f"### Instruction:\n{sample['instruction']}\n\n### Response:\n{sample['output']} <|endoftext|>"

class Dataset_V1(Dataset):
    def __init__(self, data, tokenizer, max_length, stride):
        self.max_length = max_length
        self.input_ids = []
        self.target_ids = []
        self.tokenizer = tokenizer

        all_tokens = []
        for sample in data:
            prompt = generate_prompt(sample)
            tokens = tokenizer(prompt, add_special_tokens=False)["input_ids"]
            all_tokens.extend(tokens)

        for i in range(0, len(all_tokens) - max_length, stride):
            input_chunk = all_tokens[i: i + max_length]
            target_chunk = all_tokens[i + 1: i + max_length + 1]
            self.input_ids.append(torch.tensor(input_chunk))
            self.target_ids.append(torch.tensor(target_chunk))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]

def collate_fn(batch):
    inputs, targets = zip(*batch)
    inputs = pad_sequence(inputs, batch_first=True, padding_value=teacher_tokenizer.pad_token_id)
    targets = pad_sequence(targets, batch_first=True, padding_value=-100)
    return inputs, targets

def create_dataloader_v1(data, tokenizer, batch_size=4, max_length=256, stride=128, shuffle=True, drop_last=True):
    dataset = Dataset_V1(data, tokenizer, max_length, stride)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                            drop_last=drop_last, collate_fn=collate_fn)
    return dataloader


# Dataset And DataLoader for Psycology Dataset 

In [18]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import tiktoken
from torch.nn.utils.rnn import pad_sequence

class Dataset_v2(Dataset):
    def __init__(self, data, tokenizer, max_length, stride):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.stride = stride
        self.input_ids = []

        all_tokens = []
        for sample in data:
            tokens = tokenizer.encode(sample)  
            all_tokens.extend(tokens)

        # Split the tokens into chunks of size max_length with stride
        for i in range(0, len(all_tokens) - self.max_length, self.stride):
            input_chunk = all_tokens[i:i + self.max_length]
            target_chunk = all_tokens[i + 1:i + self.max_length + 1]
            self.input_ids.append((torch.tensor(input_chunk), torch.tensor(target_chunk)))

    def __len__(self):
        return len(self.input_ids)
    
    def __getitem__(self, index):
        return self.input_ids[index]

def collect_fn(batch):
    inputs, targets = zip(*batch)
    inputs = pad_sequence(inputs, batch_first=True, padding_value=0) 
    targets = pad_sequence(targets, batch_first=True, padding_value=-100)  
    return inputs, targets

def create_dataloader_v2(data, batch_size=4, max_length=1024, stride=128, shuffle=True, drop_last=True):
    tokenizer = tiktoken.get_encoding("gpt2")  
    dataset = Dataset_v2(data, tokenizer, max_length, stride) 
    dataloader = DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, collate_fn=collect_fn)
    return dataloader

def load_txt_file(filepath):
    with open(filepath, 'r') as f:
        text = f.read()
    return text

def split_into_chunks(text, chunk_size=1024, overlap=200):

    chunks = []
    for i in range(0, len(text), chunk_size - overlap):
        chunk = text[i:i + chunk_size]
        chunks.append(chunk)

    return chunks



file =  '/kaggle/input/datasetcleaned/cleaned_books.txt'
load_text =  load_txt_file(file)
chunk = split_into_chunks(load_text)


# Train Script 

* 

In [19]:

from tqdm.auto import tqdm
from transformers import get_cosine_schedule_with_warmup
from sklearn.cluster import KMeans
import numpy as np

def evaluate_model(model, train_dataloader, eval_dataloader, device, eval_iter):
    model.eval()
    with torch.no_grad():
        train_loss = calc_loss_loader(train_dataloader, model, device, num_batches=eval_iter)
        val_loss = calc_loss_loader(eval_dataloader, model, device, num_batches=eval_iter)
    model.train()
    return train_loss, val_loss


def generate_and_print_sample(model, tokenizer, device, start_context):
    model.eval()

    encoded = text_to_token_ids(start_context, tokenizer).to(device)

    with torch.no_grad():
        token_ids = generate(
            model=model,
            idx=encoded,
            temperature=1.4,
            max_new_tokens=64,   # Increase generation length if needed
            context_size=126,
            top_k=25
        )
        decoded_text = token_ids_to_text(token_ids, tokenizer)

        # Trim everything before the generation
        generated_only = decoded_text[len(start_context):].strip()

        # Stop at endoftext token if present
        end_marker = "<|endoftext|>"
        if end_marker in generated_only:
            generated_only = generated_only.split(end_marker)[0].strip()

        print(f"\n[Prompt]: {start_context.strip()}\n")
        print(f"[Generated]: {generated_only}\n")

    model.train()
def save_model_checkpoint(model, optimizer, epoch, path="checkpoint_epoch_{}.pt"):
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch  , 
        'memory_model_state_dict': model.memory.state_dict()   ,
        
    }
    torch.save(checkpoint, path.format(epoch))
def after_save_load():
    checkpoint = torch.load("checkpoint_epoch_7.pt")
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']


def train_model(
    model: nn.Module,
    train_dataloader: torch.utils.data.DataLoader,
    device: torch.device,
    eval_dataloader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    eval_freq: int,
    eval_iter: int,
    start_context: str,
    num_epochs: int = 1
):
    torch.autograd.set_detect_anomaly(True)
    train_losses, val_losses, track_tokens_seen = [], [], []
    tokens_seen, global_step = 0, -1
    total_steps = len(train_dataloader) * num_epochs
    print(f"🚀 Total training steps: {total_steps}")
    # scheduler = get_cosine_schedule_with_warmup(optimizer,
    #                                         num_warmup_steps=500,
    #                                         num_training_steps=total_steps)
    for p in model.parameters():p.requires_grad= True 
    # for p in model.memory.compression.parameters():p.requires_grad =  False 
    # for p in model.memory.W_cell.parameters():p.requires_grad = False 
    ckpt = torch.load("memory_encoder.pth")
    model.memory.compression.load_state_dict(ckpt["compression"])
    model.memory.W_cell.load_state_dict(ckpt["W_cell"])
    print("Encoder weights are loaded ")
    with torch.no_grad():
        all_concepts =  []
        
    
        for  input_batch , target_batch in train_dataloader:
            input_batch = input_batch.to(device)
            x = model.embedding(input_batch)
            x =  model.memory.compression(x.mean(dim=1))
            x = model.memory.W_cell(x)
            all_concepts.append(x)

    concept_pool=  torch.cat(all_concepts , dim=0)
    k = min(model.memory.memory_size , concept_pool.shape[0])
    kmeans = KMeans(n_clusters = k , random_state = 42)
    kmeans.fit(concept_pool.cpu().numpy())
    centroids = torch.tensor(kmeans.cluster_centers_, dtype=torch.float32).to(device)
    print(centroids.shape , model.memory.key_memory.shape)
    centroids = F.normalize(centroids, dim=-1)

    n_centroids = centroids.size(0)  # 114
    
    with torch.no_grad():
        model.memory.key_memory.data[:n_centroids] = centroids
        model.memory.value_memory.data[:n_centroids] = centroids
        model.memory.cell_state.data[:n_centroids] = centroids

    print(centroids.shape , model.memory.key_memory.shape)

    for epoch in tqdm(range(num_epochs)):
        model.train()
        for inputs_batch, target_batch in train_dataloader:
            inputs_batch, target_batch = inputs_batch.to(device), target_batch.to(device)

            optimizer.zero_grad()
            loss = cal_loss_batch(input_batch=inputs_batch, target_batch=target_batch, device=device, model=model)
            loss.backward()
            optimizer.step()
            # scheduler.step()

            tokens_seen += inputs_batch.numel()
            global_step += 1

            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model, train_dataloader, eval_dataloader, device, eval_iter
                )

                train_losses.append(train_loss)
                val_losses.append(val_loss)
                track_tokens_seen.append(tokens_seen)

                print(
                    f"Epoch: {epoch+1} (step {global_step:06d}):",
                    f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}"
                )

        generate_and_print_sample(
            model, train_dataloader.dataset.tokenizer, device, start_context
        )
        save_model_checkpoint(model , optimizer , epoch+1)
        


    return train_losses, val_losses, track_tokens_seen


In [100]:
import torch.nn.functional as F
from torch import nn

def cal_loss_batch(input_batch, target_batch,
                   teacher_model: nn.Module,
                   student_model: nn.Module,
                   device: torch.device,
                   mem_cof: float = 0.1,
                   distil_temp: float = 5.0,
                   alpha: float = 0.9):

    input_batch = input_batch.to(device)
    target_batch = target_batch.to(device)

    with torch.no_grad():
        teacher_output = teacher_model(input_batch)
        teacher_logits = teacher_output.logits  # [B, T, V]

    student_logits, x_emb, mem_output = student_model(input_batch)  # [B, T, V]
    B, T, V = student_logits.shape

    assert student_logits.shape == teacher_logits.shape, \
        f"Shape mismatch: student={student_logits.shape}, teacher={teacher_logits.shape}"

    # Knowledge distillation loss
    student_soft = F.log_softmax(student_logits.float() / distil_temp, dim=-1)
    teacher_soft = F.softmax(teacher_logits.float() / distil_temp, dim=-1)
    teacher_soft = torch.clamp(teacher_soft, min=1e-9)

    distillation_loss = F.kl_div(student_soft, teacher_soft, reduction='batchmean') * (distil_temp ** 2)
    distillation_loss = torch.clamp(distillation_loss, max=100.0)

    # Cross-entropy loss
    ce_loss = F.cross_entropy(
        student_logits.view(B * T, V),
        target_batch.view(B * T),
        ignore_index=-100,
    )

    # Memory loss
    memory_loss = total_loss(inputs=x_emb, memory=student_model.memory,
                             outputs=mem_output, targets=target_batch.float())

    total = alpha * ce_loss + (1 - alpha) * distillation_loss + mem_cof * memory_loss


    return total


In [106]:
def evaluate_model(student_model, teacher_model,train_dataloader, eval_dataloader, device, eval_iter):
    model.eval()
    with torch.no_grad():
        train_loss = calc_loss_loader(train_dataloader, student_model,teacher_model ,  device, num_batches=eval_iter)
        val_loss = calc_loss_loader(eval_dataloader, student_model,teacher_model ,  device, num_batches=eval_iter)
    model.train()
    return train_loss, val_loss
def calc_loss_loader(data_loader , student_model , teacher_model ,device , num_batches = None):
    total_loss = 0
    if num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches  = min(num_batches , len(data_loader))
    for i , (inputs , target) in enumerate(data_loader):
        if i < num_batches:
            loss  =  cal_loss_batch(inputs , target , teacher_model , student_model , device)

            total_loss +=loss.item()

        else:
            break

        return total_loss  / num_batches
    

In [107]:
from torch.cuda.amp import autocast, GradScaler
from tqdm.notebook import tqdm
import torch
import torch.nn.functional as F

def train_model(
    teacher_model: nn.Module,
    model: nn.Module,
    train_dataloader: torch.utils.data.DataLoader,
    device: torch.device,
    eval_dataloader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    eval_freq: int,
    eval_iter: int,
    start_context: str,
    num_epochs: int = 1
):
    scaler = GradScaler()
    torch.autograd.set_detect_anomaly(True)

    train_losses, val_losses, track_tokens_seen = [], [], []
    tokens_seen, global_step = 0, -1
    total_steps = len(train_dataloader) * num_epochs
    print(f"🚀 Total training steps: {total_steps}")

    # Ensure all parameters require grad
    for p in model.parameters():
        p.requires_grad = True

    # Load pretrained encoder weights into the memory module
    ckpt = torch.load('/kaggle/input/v3-encoder-weight/memory_encoder(3).pth')
    model.memory.compression.load_state_dict(ckpt["compression"])
    model.memory.W_cell.load_state_dict(ckpt["W_cell"])
    print("Encoder weights are loaded ")

    # Build initial memory centroids via KMeans
    with torch.no_grad():
        all_concepts = []
        for input_batch, _ in train_dataloader:
            input_batch = input_batch.to(device)
            x = model.embedding(input_batch)
            x = model.memory.compression(x.mean(dim=1))
            x = model.memory.W_cell(x)
            all_concepts.append(x)
        concept_pool = torch.cat(all_concepts, dim=0)

        k = min(model.memory.memory_size, concept_pool.shape[0])
        from sklearn.cluster import KMeans
        kmeans = KMeans(n_clusters=k, random_state=42)
        kmeans.fit(concept_pool.cpu().numpy())
        centroids = torch.tensor(kmeans.cluster_centers_, dtype=torch.float32).to(device)
        centroids = F.normalize(centroids, dim=-1)
        n_centroids = centroids.size(0)

        model.memory.key_memory.data[:n_centroids] = centroids
        model.memory.value_memory.data[:n_centroids] = centroids
        model.memory.cell_state.data[:n_centroids] = centroids

    # Training loop
    for epoch in range(num_epochs):
        loop = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
        model.train()

        for inputs_batch, target_batch in loop:
            inputs_batch = inputs_batch.to(device)
            target_batch = target_batch.to(device)

            optimizer.zero_grad()

            # 1) Compute loss under autocast
            with autocast():
                loss = cal_loss_batch(
                    input_batch=inputs_batch,
                    target_batch=target_batch,
                    device=device,
                    student_model=model,
                    teacher_model=teacher_model
                )

    
            # 3) Backward + gradient clipping + optimizer step
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()


            tokens_seen += inputs_batch.numel()
            global_step += 1

            # 5) Periodic evaluation
            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model,teacher_model ,  train_dataloader, eval_dataloader, device, eval_iter
                )
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                track_tokens_seen.append(tokens_seen)
                loop.set_postfix({
                    'loss': f"{loss.item():.4f}",
                    'step': global_step,
                    'train_loss': f"{train_loss:.4f}",
                    'val_loss': f"{val_loss:.4f}"
                })

        # 6) Generate a sample and save checkpoint at end of epoch
        generate_and_print_sample(model, train_dataloader.dataset.tokenizer, device, start_context)
        save_model_checkpoint(model, optimizer, epoch + 1)

    return train_losses, val_losses, track_tokens_seen


# Knowledge Distillation

In [26]:

teacher_name = "deepseek-ai/deepseek-llm-7b-chat"
teacher_model = AutoModelForCausalLM.from_pretrained(teacher_name, device_map="auto", torch_dtype=torch.float16)
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_name)



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [27]:
# Loss for Knowledge distillation
def distillation_loss(student_logits , teacher_logits  , labels , T = 2.0 , alpha = 0.7):
    loss_kd =  F.kl_div(
        F.log_softmax(stident_logits / T , dim=-1) , 
        F.softmax(teacher_logits / T , dim=-1),
         reduction="batchmean"
    ) *(T**2 )

    loss_ce = F.cross_entropy(student_logits.view(-1, student_logits.size(-1)), labels.view(-1))
    return alpha * loss_kd +(1-alpha) * loss_ce 

In [29]:
teacher_model.config.vocab_size 

102400

In [30]:
teacher_tokenizer.vocab_size 

100000

In [31]:
teacher_model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(102400, 4096)
    (layers): ModuleList(
      (0-29): 30 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-06)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-06)
      )
    )
    (n

# 

# MemoryGPT Model training  


In [24]:
from torch.optim import AdamW
from torch.nn.utils import clip_grad_norm_
import numpy as np

# GPT Config 

In [25]:
torch.manual_seed(123)
GPT_CONFIG_124M = {
"vocab_size": 50257, # Vocabulary size
"context_length": 126,#126 
# Context lengt
"emb_dim": 768,
# Embedding dimension
"n_heads": 12,
# Number of attention heads
"n_layers": 12,
# Number of layers
"drop_rate": 0.1,
# Dropout rate
"qkv_bias": False
# Query-Key-Value bias
}


In [26]:
teacher_tokenizer.vocab_size

100000

In [32]:
GPT_CONFIG_124M_Memory = {
"vocab_size": teacher_model.config.vocab_size ,#100000,# 50257
    # Vocabulary size
"context_length": 126,
# Context length
"emb_dim": 128,
# Embedding dimension
"n_heads": 4,
# Number of attention heads
"n_layers": 12,
# Number of layers
"drop_rate": 0.1,
# Dropout rate
"qkv_bias": False,
'memory_dim':128,
'max_slots' :1000,
'memory_heads':2 ,

# Query-Key-Value bias
}

In [33]:
def train_encoder(model , train_dataloader , device = 'cuda'):
    memory = model.memory 
    for p in model.parameters():p.requires_grad = False 
    for p in model.memory.compression.parameters(): p.requires_grad = True
    for p in model.memory.W_cell.parameters():p.requires_grad =  True 
    optimizer = torch.optim.Adam(
    list(memory.compression.parameters()) + list(memory.W_cell.parameters()),
    lr=1e-3
)
    global_step= -1 
    print('tatal step ' , len(train_dataloader) * 5)
    for _ in tqdm(range(5)):
        
        for input_batch , _  in train_dataloader:
            input_batch =  input_batch.to(device)
            x1 = model.embedding(input_batch)
            B , _ , _  =  x1.shape
            x2 =  model.embedding(input_batch)

            q1 = memory.W_cell(memory.compression(x1.mean(dim=1)))
            q2 =  memory.W_cell(memory.compression(x2.mean(dim=1)))

            q1, q2 = F.normalize(q1, dim=-1), F.normalize(q2, dim=-1)

            sim = torch.matmul(q1,q2.T)
            loss = F.cross_entropy(sim/0.1  , torch.arange(B, device = sim.device))

            # print(loss)
            optimizer.zero_grad();loss.backward();optimizer.step()
            global_step +=1 
             # print(
             #        f"Epoch: {_+1} (step {global_step:06d}):",
    x, _ = next(iter(train_dataloader))
    x =  x.to(device)
    x = model.embedding(x)

    q =  memory.W_cell(memory.compression(x.mean(1)))
    q1 =  memory.W_cell(memory.compression(x.mean(1)))
    avg_sim = F.cosine_similarity(q, q1, dim=-1).mean().item()
    # print(avg_sim)
    assert avg_sim > 0.95 , 'Encoder still drift too much'

    ckpt = {
        "compression": memory.compression.state_dict(),
        "W_cell":      memory.W_cell.state_dict()
    }
    torch.save(ckpt, "memory_encoder.pth")
    print("saved memory_encoder.pth")



In [34]:
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)


In [49]:

device =  'cuda' if torch.cuda.is_available() else "cpu"
model =  GPTMemoryEnhanced(GPT_CONFIG_124M_Memory).to(device)

optimizer =  torch.optim.AdamW(model.parameters() , lr=0.0004,weight_decay=0.01 )


In [36]:
from collections import defaultdict

def get_param_group_summary(model):
    groups = defaultdict(int)
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if "embedding" in name:
            groups["embedding"] += param.numel()
        elif "transformer" in name:
            groups["transformer_blocks"] += param.numel()
        elif "memory" in name or "episodic" in name or "semantic" in name:
            groups["memory_modules"] += param.numel()
        elif "norm" in name:
            groups["normalization"] += param.numel()
        elif "lm_head" in name or "projection" in name:
            groups["output_projection"] += param.numel()
        else:
            groups["other"] += param.numel()
    total = sum(groups.values())
    for k, v in groups.items():
        print(f"{k:20s}: {v:,} parameters")
    print(f"\nTotal: {total:,}")
get_param_group_summary(model)


embedding           : 13,107,328 parameters
memory_modules      : 665,227 parameters
transformer_blocks  : 2,082,048 parameters
normalization       : 128 parameters
other               : 32,896 parameters
output_projection   : 102,400 parameters

Total: 15,990,027


In [37]:


num_epochs =  1
train_ratio = 0.90

filename =  '/kaggle/input/alphaco/alpaca_data_cleaned.json'
with open(filename , 'r') as f:
    text_data =  json.load(f)
text_data = text_data[:150]

split = int(train_ratio * len(text_data))

train_data =  text_data[:split]
val_data =  text_data[split:]

train_dataloader =  create_dataloader_v1(data=train_data , batch_size=2 , max_length=GPT_CONFIG_124M_Memory['context_length'] ,tokenizer = teacher_tokenizer ,  shuffle=True , drop_last= True)
val_dataloader = create_dataloader_v1(data=val_data , batch_size=2 , max_length=GPT_CONFIG_124M_Memory['context_length']  ,tokenizer  = teacher_tokenizer, shuffle=False , drop_last=False )


In [38]:
len(train_dataloader)

80

In [39]:
train_encoder(model , train_dataloader )

tatal step  400


  0%|          | 0/5 [00:00<?, ?it/s]

saved memory_encoder.pth


In [41]:
import torch._dynamo
import logging 
dyno_logger =  logging.getLogger('torch._dynamo')
dyno_logger.setLevel(logging.ERROR)
function_logger =  logging.getLogger('torch._functorch')
function_logger.setLevel(logging.ERROR)
torch._dynamo.config.suppress_errors = True

In [95]:
compiled_model = torch.compile(model, mode="max-autotune", backend="aot_eager")


In [91]:
import torch
import torch.nn.functional as F

def cal_loss_batch(
    input_batch, target_batch,
    teacher_model: torch.nn.Module,
    student_model: torch.nn.Module,
    device: torch.device,
    mem_cof: float = 0.1,
    distil_temp: float = 5.0,
    alpha: float = 0.86
):
    input_batch = input_batch.to(device)
    target_batch = target_batch.to(device)

    # 1) Teacher forward
    with torch.no_grad():
        teacher_output = teacher_model(input_batch)
        teacher_logits = teacher_output.logits  # [B, T, V_teacher]

    # 2) Student forward
    student_logits, x_emb, mem_output = student_model(input_batch)  # [B, T, V_student]
    B, T, V_student = student_logits.shape
    V_teacher = teacher_logits.shape[-1]

    # 3) Check vocab‐size match
    assert V_student == V_teacher, f"Vocab mismatch: student={V_student}, teacher={V_teacher}"

    # 4) CE loss
    ce_loss = F.cross_entropy(
        student_logits.view(B * T, V_student),
        target_batch.view(B * T),
        ignore_index=-100,
    )

    # 5) Distillation loss (KL)
    s_logits = (student_logits.float() / distil_temp).log_softmax(dim=-1)  # [B,T,V]
    t_probs = (teacher_logits.float() / distil_temp).softmax(dim=-1).clamp(min=1e-9)
    distil_loss = F.kl_div(s_logits, t_probs, reduction='batchmean') * (distil_temp ** 2)
    distil_loss = torch.clamp(distil_loss, max=100.0)

    # 6) Memory loss
    memory_loss = total_loss(inputs=x_emb, memory=student_model.memory,
                             outputs=mem_output, targets=target_batch.float())

    # 7) Combine
    total = alpha * ce_loss + (1 - alpha) * distil_loss + mem_cof * memory_loss

    # 8) Print for debugging
    print(f"CE: {ce_loss.item():.6f}, KD: {distil_loss.item():.6f}, MEM: {memory_loss.item():.6f}, TOT: {total.item():.6f}")


    return total


In [92]:
from torch.cuda.amp import autocast, GradScaler
from tqdm.notebook import tqdm
import torch

def train_model(
    teacher_model: torch.nn.Module,
    model: torch.nn.Module,
    train_dataloader: torch.utils.data.DataLoader,
    device: torch.device,
    eval_dataloader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    eval_freq: int,
    eval_iter: int,
    start_context: str,
    num_epochs: int = 1
):
    scaler = GradScaler()
    torch.autograd.set_detect_anomaly(True)

    train_losses, val_losses, track_tokens_seen = [], [], []
    tokens_seen, global_step = 0, -1
    total_steps = len(train_dataloader) * num_epochs
    print(f"🚀 Total training steps: {total_steps}")

    # Ensure all params require grad
    for param in model.parameters():
        param.requires_grad = True

    # Load memory‐encoder weights
    ckpt = torch.load('/kaggle/input/v3-encoder-weight/memory_encoder(3).pth')
    model.memory.compression.load_state_dict(ckpt["compression"])
    model.memory.W_cell.load_state_dict(ckpt["W_cell"])
    print("Encoder weights are loaded")

    # Build initial memory via KMeans (unchanged)
    with torch.no_grad():
        all_concepts = []
        for input_batch, _ in train_dataloader:
            input_batch = input_batch.to(device)
            x = model.embedding(input_batch)
            x = model.memory.compression(x.mean(dim=1))
            x = model.memory.W_cell(x)
            all_concepts.append(x)
        concept_pool = torch.cat(all_concepts, dim=0)

        from sklearn.cluster import KMeans
        k = min(model.memory.memory_size, concept_pool.size(0))
        kmeans = KMeans(n_clusters=k, random_state=42)
        kmeans.fit(concept_pool.cpu().numpy())
        centroids = torch.tensor(kmeans.cluster_centers_, dtype=torch.float32).to(device)
        centroids = torch.nn.functional.normalize(centroids, dim=-1)
        n_centroids = centroids.size(0)

        model.memory.key_memory.data[:n_centroids] = centroids
        model.memory.value_memory.data[:n_centroids] = centroids
        model.memory.cell_state.data[:n_centroids] = centroids

    # TRAINING LOOP
    for epoch in range(num_epochs):
        progress = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
        model.train()

        for inputs_batch, target_batch in progress:
            inputs_batch = inputs_batch.to(device)
            target_batch = target_batch.to(device)

            optimizer.zero_grad()

            # 1) Compute combined loss
            with autocast():
                loss = cal_loss_batch(
                    input_batch=inputs_batch,
                    target_batch=target_batch,
                    device=device,
                    student_model=model,
                    teacher_model=teacher_model
                )

            # 2) If loss is None, skip backward
            if loss is None:
                continue

            # 3) Backward + clip + step
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            tokens_seen += inputs_batch.numel()
            global_step += 1

            # 4) Periodic evaluation
            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model, train_dataloader, eval_dataloader, device, eval_iter
                )
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                track_tokens_seen.append(tokens_seen)
                progress.set_postfix({
                    'loss': f"{loss.item():.4f}",
                    'step': global_step,
                    'train_loss': f"{train_loss:.4f}",
                    'val_loss': f"{val_loss:.4f}"
                })

        # 5) At epoch end: generate sample + save checkpoint
        generate_and_print_sample(model, train_dataloader.dataset.tokenizer, device, start_context)
        save_model_checkpoint(model, optimizer, epoch + 1)

    return train_losses, val_losses, track_tokens_seen


In [None]:




print('start training')
train_losses , val_losses , token_seen =  train_model(
    teacher_model =teacher_model,
    
    model=compiled_model.to(device) , 
    train_dataloader=train_dataloader, 
    device=device, 
    eval_freq=5 , 
    eval_dataloader=val_dataloader , 
    optimizer=optimizer, 
    eval_iter=3,  
    num_epochs= 10, 
    start_context='### Instruction: What are the three primary colors? .n### Response:'
)
print(model.memory.get_memory_metrics())

start training
🚀 Total training steps: 800
Encoder weights are loaded 


Epoch 1/10:   0%|          | 0/80 [00:00<?, ?it/s]


[Prompt]: ### Instruction: What are the three primary colors? .n### Response:

[Generated]: FE indifference indifference indifference indifference indifference indifference indifference indifference indifference indifference indifference indifference名誉名誉名誉名誉名誉名誉名誉名誉 greatest greatest greatest mime mimepapapapapapapapapapapapapapapapapapa Dod sil Travels Travels \\ \\ \\ \\ \\ \\名誉名誉名誉 ways ways ways ways ways ways LIB



Epoch 2/10:   0%|          | 0/80 [00:00<?, ?it/s]


[Prompt]: ### Instruction: What are the three primary colors? .n### Response:

[Generated]: sil sil reliable reliable reliable reliable reliable reliable reliable reliable reliable*}[! запазва запазва запазва запазва запазваChaChaChaChaCha臀部 Гю indifference indifference indifference indifference indifference indifference indifference indifference indifference indifference indifference dropdown indifference indifference indifference indifference indifference indifference indifference indifference indifference indifference indifference indifference indifference indifference indifference indifference "< "< "< "< "< "< "< "< "< "< "< "<



Epoch 3/10:   0%|          | 0/80 [00:00<?, ?it/s]


[Prompt]: ### Instruction: What are the three primary colors? .n### Response:

[Generated]: 会让会让 Introdu Introdu Introdu Introdu Goal Goal Goal Goal Goal Goal Goal Goal Goal Goal Goalaring texlive texlive texlive texlive texlive texlive texlive dye.机动车 uncertainty uncertainty Heating融化 preve preve preve preve preve preve preve preve preve preve Heating, compilaci compilaci reliable*}[!Region снаря снаря снаря the Heating Heating Heating Heating Heating Heating,WalletWalletWallet



Epoch 4/10:   0%|          | 0/80 [00:00<?, ?it/s]

In [43]:
!pip install torch --upgrade


Collecting torch
  Downloading torch-2.7.1-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (29 kB)
Collecting sympy>=1.13.3 (from torch)
  Downloading sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.6.77 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cudnn-cu12==9.5.1.17 (from torch)
  Downloading nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cusparselt-cu12==0.6.3 (from torch)
  Downloading nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_x86_64.whl.metadata (6.8 kB)
Collecting nvidia-nccl-cu12==2.26.2 (from torch)
  Downloading nvidia_nccl_cu12-2.26.2-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (2.0 kB)
Collecting nvidia-nvtx-cu12==12.6.77 (from torch)
  Downloading nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cufile-cu12==

In [None]:
model.memory.model_save(path =  'Memory_saved.pth')
model.memory.model_save(path = 'memory_saved.pt')

In [64]:
import torch
import torch.nn.functional as F

# Assume `memory` is your EfiBioSemanticMemory_V2 instance, already initialized
# and `W_cell` (compression) is frozen or pretrained so that random_input→embedding.
memory = model.memory
memory.to(device)
memory.eval()
memory.hit_count  = torch.tensor(0).to(device) # if you have a method; otherwise manually zero out hit_count, write_count, query_count
memory.write_count  = torch.tensor(0).to(device) 
memory.query_count = torch.tensor(0).to(device) 
# Helper to run a single step and report
def step_test(x):
    x= x.unsqueeze(0).unsqueeze(0)
    print(x.shape)
    out, retrieved, topk_idx, attn_weights = memory(x , training =  False ) # x shape [D]; expand to [1,1,D]
    max_sim = float(F.cosine_similarity(memory.key_memory[topk_idx.squeeze()], memory.key_memory[topk_idx.squeeze()], dim=-1).mean())
    return max_sim

# 1) Same vector twice
vec1 = torch.randn(memory.input_dim, device=memory.key_memory.device)
sim1 = step_test(vec1)
sim2 = step_test(vec1)

# 2) Two distinct vectors
vec2 = torch.randn(memory.input_dim, device=memory.key_memory.device)
sim3 = step_test(vec2)
sim4 = step_test(vec2)

print(" sims: first_pass(vec1) =", f"{sim1:.4f}",
      "| second_pass(vec1) =", f"{sim2:.4f}")
print(" sims: first_pass(vec2) =", f"{sim3:.4f}",
      "| second_pass(vec2) =", f"{sim4:.4f}")

print(" final metrics:",
      f"hit_count={memory.hit_count.item()}",
      f"write_count={memory.write_count.item()}",
      f"query_count={memory.query_count.item()}")

# Expected outcome:
# - sim1 < hit_threshold (no slot existed) → write_count=1
# - sim2  ≃ 0.9–1.0 (slot now exists) → hit_count=1
# - sim3 < hit_threshold (new) → write_count=2
# - sim4  ≃ 0.9–1.0 → hit_count=2


torch.Size([1, 1, 128])
hit_threshold 0.61
max scores  tensor(0.0303, device='cuda:0', grad_fn=<SqueezeBackward1>)
hit count tensor(0, device='cuda:0')
Writing Happen
Writing Happen
torch.Size([1, 1, 128])
hit_threshold 0.61
max scores  tensor(0.9465, device='cuda:0', grad_fn=<SqueezeBackward1>)
hit count tensor(1, device='cuda:0')
Update Happen
 average update_gate: 0.6290084719657898  std: 0.001736109028570354
torch.Size([1, 1, 128])
hit_threshold 0.61
max scores  tensor(0.3093, device='cuda:0', grad_fn=<SqueezeBackward1>)
hit count tensor(1, device='cuda:0')
Writing Happen
Writing Happen
torch.Size([1, 1, 128])
hit_threshold 0.61
max scores  tensor(0.9538, device='cuda:0', grad_fn=<SqueezeBackward1>)
hit count tensor(2, device='cuda:0')
Update Happen
 average update_gate: 0.8370599150657654  std: 0.010752998292446136
 sims: first_pass(vec1) = nan | second_pass(vec1) = 1.0000
 sims: first_pass(vec2) = nan | second_pass(vec2) = 1.0000
 final metrics: hit_count=2 write_count=2 query_co

In [None]:
model.memory.write_count 

In [None]:
import matplotlib.pyplot as plt

# If you loaded from a JSON file
# with open("loss_history.json", "r") as f:
#     data = json.load(f)
#     train_losses = data["train_loss"]
#     val_losses = data["val_loss"]

plt.figure(figsize=(10, 6))
plt.plot(train_losses, label="Train Loss", color="blue")
plt.plot(val_losses, label="Validation Loss", color="orange")
plt.xlabel("Evaluation Step")
pri
plt.ylabel("Loss")
plt.title("Training vs Validation Loss")
plt.legend()
plt.grid(True)
plt.savefig("loss_curve.png")  # Save the plot
plt.show()


In [None]:
import torch
import tiktoken

# Load model
model = GPTMQModel2(GPT_CONFIG_124M)
# model.load_state_dict(torch.load("/kaggle/working/checkpoint_epoch_7.pt"))
checkpoint = torch.load("checkpoint_epoch_7.pt")
model.load_state_dict(checkpoint["model_state_dict"])

model.eval().to(device)

# Tokenizer
tokenizer = tiktoken.get_encoding("gpt2")

# Utility functions
def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
    encoded = torch.tensor(encoded).unsqueeze(0)
    return encoded

def token_ids_to_text(tokens, tokenizer):
    flat = tokens.squeeze(0)
    decode = tokenizer.decode(flat.tolist())
    return decode

# Sampling-based generate function (uses your logic)
def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None):
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]
        seq_len = idx_cond.size(1)
        causal_mask = torch.tril(torch.ones(seq_len, seq_len)).to(idx.device)
        causal_mask = causal_mask.unsqueeze(0)
f
        with torch.no_grad():
            logits = model(idx_cond, mask=causal_mask)

        logits = logits[:, -1, :]
        if top_k is not None:
            top_logits, _ = torch.topk(logits, top_k)
            min_val = top_logits[:, -1]
            logits = torch.where(
                logits < min_val,
                torch.tensor(float('-inf')).to(logits.device),
                logits
            )
        if temperature > 0.0:
            logits = logits / temperature
            probs = torch.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
        else:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)

        idx = torch.cat((idx, idx_next), dim=1)

    return idx

# High-level text generation function
def generate_response(prompt, model, tokenizer, max_new_tokens=100, context_size=128, temperature=1.0, top_k=50):
    input_ids = text_to_token_ids(prompt, tokenizer).to(device)
    generated_ids = generate(
        model=model,
        idx=input_ids,
        max_new_tokens=max_new_tokens,
        context_size=context_size,
        temperature=temperature,
        top_k=top_k
    )
    return token_ids_to_text(generated_ids, tokenizer)

# Try it out
# prompt = "### Instruction:\nExplain what is deep learning.\n\n### Response:\n <bot>"
prompt = """

'### Instruction :Give three tips for staying healthy ### Response:'
""".strip()


output = generate_response(prompt, model, tokenizer)
print(output)


In [None]:
def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
    encoded = torch.tensor(encoded).unsqueeze(0)
    return encoded
end_token_id = tokenizer.encode("<|endoftext|>", allowed_special={'<|endoftext|>'})[0]

def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None):
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]
        seq_len = idx_cond.size(1)
        causal_mask = torch.tril(torch.ones(seq_len, seq_len)).to(idx.device)
        causal_mask = causal_mask.unsqueeze(0)

        with torch.no_grad():
            logits = model(idx_cond, mask=causal_mask)

        logits = logits[:, -1, :]
        if top_k is not None:
            top_logits, _ = torch.topk(logits, top_k)
            min_val = top_logits[:, -1]
            logits = torch.where(
                logits < min_val,
                torch.tensor(float('-inf')).to(logits.device),
                logits
            )

        if temperature > 0.0:
            logits = logits / temperature
            probs = torch.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
        else:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)

        idx = torch.cat((idx, idx_next), dim=1)

        # Stop generation if <|endoftext|> is generated
        if idx_next.item() == end_token_id:
            break

    return idx
def truncate_after_n_bullets(text, n=3):
    lines = text.split("\n")
    count = 0
    result = []
    for line in lines:
        if line.strip().startswith(("1.", "2.", "3.")):
            count += 1
        result.append(line)
        if count >= n:
            break
    return "\n".join(result)
raw_output = generate_response(prompt, model, tokenizer)
cleaned_output = truncate_after_n_bullets(raw_output)
print(cleaned_output)



In [None]:
output = generate_response(
    prompt, model, tokenizer,
    temperature=0.8,  # better balance
    top_k=40,         # a bit narrower selection
    max_new_tokens=100
)


In [None]:
def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
    encoded = torch.tensor(encoded).unsqueeze(0)
    return encoded

end_token_id = tokenizer.encode("<|endoftext|>", allowed_special={'<|endoftext|>'})[0]

def generate(model, idx, max_new_tokens, context_size, temperature=1.0, top_k=None):
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]
        seq_len = idx_cond.size(1)
        causal_mask = torch.tril(torch.ones(seq_len, seq_len)).to(idx.device).unsqueeze(0)

        with torch.no_grad():
            logits = model(idx_cond, mask=causal_mask)

        logits = logits[:, -1, :]

        if top_k is not None:
            top_logits, _ = torch.topk(logits, top_k)
            min_val = top_logits[:, -1]
            logits = torch.where(
                logits < min_val,
                torch.tensor(float('-inf')).to(logits.device),
                logits
            )

        if temperature > 0.0:
            logits = logits / temperature
            probs = torch.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
        else:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)

        idx = torch.cat((idx, idx_next), dim=1)

        # Stop generation if <|endoftext|> is in the generated output
        if end_token_id in idx_next:
            break

    return idx

def truncate_after_n_bullets(text, n=3):
    lines = text.split("\n")
    count = 0
    result = []
    for line in lines:
        if line.strip().startswith(("1.", "2.", "3.")):
            count += 1
        result.append(line)
        if count >= n:
            break
    return "\n".join(result)

# 🔁 Input prompt
prompt = "### Instruction: What are the three primary colors? \n### Response:"

# 🔁 Tokenize input
input_ids = text_to_token_ids(prompt, tokenizer).to(device)

# 🔁 Generate output tokens
output_ids = generate(
    model=model,
    idx=input_ids,
    max_new_tokens=100,
    context_size=128,
    temperature=0.7,
    top_k=40
)

# 🔁 Decode and postprocess
output_text = tokenizer.decode(output_ids[0].tolist())

# ✂️ Truncate after 3 bullets (optional)
final_output = truncate_after_n_bullets(output_text)
print(final_output)
