In [28]:
import tiktoken
import torch


class TokenizationLayer:
    def __init__(self, model_name="cl100k_base", device=None):
        """
        Tokenization Layer using tiktoken.

        Args:
            model_name (str): Tokenizer model name.
            device (str): 'cuda' or 'cpu'. Auto-detects if None.
        """
        self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))

        # Load tokenizer
        self.tokenizer = tiktoken.get_encoding(model_name)
        self.vocab_size = self.tokenizer.n_vocab

        # Use supported special tokens
        self.cls_token = "<|startoftext|>"
        self.sep_token = "<|endoftext|>"

        # Explicitly allow special tokens during encoding to get their IDs
        self.cls_token_id = self.tokenizer.encode(self.cls_token, allowed_special={self.cls_token})[0]
        self.sep_token_id = self.tokenizer.encode(self.sep_token, allowed_special={self.sep_token})[0]
        self.pad_token_id = self.sep_token_id  # Use SEP as pad
        self.eos_token_id = self.sep_token_id # Use SEP as EOS token for generation

    def tokenize(self, texts, max_length=512, add_special_tokens=True):
        """
        Tokenizes input text(s) into padded/truncated token IDs.

        Args:
            texts (str or List[str]): Text(s) to tokenize.
            max_length (int): Max token length.
            add_special_tokens (bool): Whether to add CLS + SEP.

        Returns:
            torch.Tensor: Token IDs [batch_size, max_length]
        """
        if isinstance(texts, str):
            texts = [texts]

        token_ids = []
        for text in texts:
            # Allow all special tokens during the main encoding process as they might appear in the data
            tokens = self.tokenizer.encode(text, allowed_special="all")

            if add_special_tokens:
                tokens = [self.cls_token_id] + tokens + [self.sep_token_id]

            tokens = tokens[:max_length]
            if len(tokens) < max_length:
                tokens += [self.pad_token_id] * (max_length - len(tokens))

            token_ids.append(tokens)

        return torch.tensor(token_ids, dtype=torch.long, device=self.device)

    def detokenize(self, tokens):
        """
        Converts token IDs back to string(s), removing padding.

        Args:
            tokens (List[int] or Tensor): Token IDs

        Returns:
            str or List[str]: Decoded text(s)
        """
        if isinstance(tokens, torch.Tensor):
            tokens = tokens.cpu().tolist()

        if isinstance(tokens[0], list):  # Batch
            # Remove the 'allowed_special' argument from decode
            return [
                self.tokenizer.decode([t for t in seq if t != self.pad_token_id])
                for seq in tokens
            ]
        else:
            # Remove the 'allowed_special' argument from decode
            return self.tokenizer.decode([t for t in tokens if t != self.pad_token_id])


# ✅ Device Info
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using Device:", device)
if torch.backends.cudnn.is_available():
    print("cuDNN Enabled:", torch.backends.cudnn.enabled)

# ✅ Test
tokenizer_layer = TokenizationLayer(device=device)

# Single
text = "Hello, my name is Ankit Kashyap"
tokens = tokenizer_layer.tokenize(text, max_length=10)
decoded = tokenizer_layer.detokenize(tokens)

# Batch
texts = ["I have done my homework", "This is the perfect code!"]
batch_tokens = tokenizer_layer.tokenize(texts, max_length=10)
decoded_batch = tokenizer_layer.detokenize(batch_tokens)

# ✅ Output
print("\nSingle Example:")
print("Text:", text)
print("Tokens:", tokens)
print("Decoded:", decoded)

print("\nBatch Example:")
print("Texts:", texts)
print("Tokens:", batch_tokens)
print("Decoded:", decoded_batch)

Using Device: cuda
cuDNN Enabled: True

Single Example:
Text: Hello, my name is Ankit Kashyap
Tokens: tensor([[   27,  9906,    11,   856,   836,   374,  1556,  8390, 42708,    88]],
       device='cuda:0')
Decoded: ['<Hello, my name is Ankit Kashy']

Batch Example:
Texts: ['I have done my homework', 'This is the perfect code!']
Tokens: tensor([[    27,     40,    617,   2884,    856,  29559, 100257, 100257, 100257,
         100257],
        [    27,   2028,    374,    279,   4832,   2082,      0, 100257, 100257,
         100257]], device='cuda:0')
Decoded: ['<I have done my homework', '<This is the perfect code!']


In [29]:
import torch
import torch.nn as nn

# from Tokenization.tokenization import TokenizationLayer

device = 'cuda' if torch.cuda.is_available() else 'cpu'

class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size, embed_dim, padding_idx=0, dropout=0.1):
        """
        Token Embedding Layer using PyTorch nn.Embedding.

        Args:
            vocab_size (int): Number of unique tokens in vocabulary.
            embed_dim (int): Dimension of each token embedding.
            padding_idx (int, optional): Index of padding token. Default: 0.
            dropout (float, optional): Dropout probability. Default: 0.1.
        """
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_dim,
            padding_idx=padding_idx  # Helps handle padding tokens efficiently
        )
        self.dropout = nn.Dropout(dropout)  # Dropout for regularization
        nn.init.xavier_uniform_(self.embedding.weight)  # Better initialization

    def forward(self, input_tokens):
        """
        Forward pass to convert token IDs to embeddings.

        Args:
            input_tokens (torch.Tensor): Tensor of shape (batch_size, seq_len).

        Returns:
            torch.Tensor: Token embeddings of shape (batch_size, seq_len, embed_dim).
        """
        # Ensure input is on the correct device
        input_tokens = input_tokens.to(self.embedding.weight.device)

        # Apply embedding and dropout

        embeddings = self.embedding(input_tokens)
        embeddings = self.dropout(embeddings)  # Apply dropout

        return embeddings


# ✅ Hyperparameters
vocab_size = 10000  # Size of vocabulary
embed_dim = 512  # Embedding dimension per token
batch_size = 8  # Number of sequences processed in parallel
seq_len = 128  # Max sequence length



tokenizer_layer = TokenizationLayer()
# ✅ Initialize Token Embedding Layer
# ✅ Initialize Token Embedding Layer with correct vocab size
token_embedding = TokenEmbedding(
    vocab_size=tokenizer_layer.vocab_size,  # Use actual vocab size
    embed_dim=512
).to(device)

# ✅ Example Input (Random Token IDs)

input_tokens = torch.randint(0, tokenizer_layer.vocab_size, (batch_size, seq_len), device=device, dtype=torch.long)


# ✅ Apply Token Embedding
output_embeddings = token_embedding(input_tokens)

# ✅ Debugging Info
print("\n✅ Input Tokens Shape:", input_tokens.shape)  # Expected: (8, 128)
print("✅ Token Embedding Output Shape:", output_embeddings.shape)  # Expected: (8, 128, 512)
print("✅ Token Embedding Output dtype:", output_embeddings.dtype)  # Should be float16 if AMP enabled





✅ Input Tokens Shape: torch.Size([8, 128])
✅ Token Embedding Output Shape: torch.Size([8, 128, 512])
✅ Token Embedding Output dtype: torch.float32


In [30]:
import torch
import torch.nn as nn


device = 'cuda' if torch.cuda.is_available() else 'cpu'




class RotaryPositionalEncoding(nn.Module):
    def __init__(self, embed_dim):
        """
        Rotary Positional Encoding (RoPE) for transformers.

        Args:
            embed_dim (int): Dimension of token embeddings.
        """
        super(RotaryPositionalEncoding, self).__init__()
        self.embed_dim = embed_dim

        # Compute inverse frequency terms for RoPE
        inv_freq = 1.0 / (10000 ** (torch.arange(0, embed_dim, 2, dtype=torch.float32) / embed_dim))
        self.register_buffer("inv_freq", inv_freq)  # Store as buffer

    def rotate_half(self, x):
        """
        Rotates the last dimension by 90 degrees.

        Args:
            x (torch.Tensor): Input tensor of shape (..., embed_dim).

        Returns:
            torch.Tensor: Rotated tensor of same shape.
        """
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, x):
        """
        Forward pass for RoPE.

        Args:
            x (torch.Tensor): Token embeddings of shape (batch_size, seq_len, embed_dim).

        Returns:
            torch.Tensor: Rotated embeddings with positional information.
        """
        batch_size, seq_len, embed_dim = x.shape

        # Generate position indices
        positions = torch.arange(seq_len, dtype=torch.float32, device=x.device).unsqueeze(1)

        # Compute rotation frequencies
        freqs = torch.matmul(positions, self.inv_freq.unsqueeze(0))  # Shape: [seq_len, embed_dim//2]
        emb = torch.cat((freqs, freqs), dim=-1)  # Shape: [seq_len, embed_dim]

        # Compute cos and sin embeddings
        cos_emb, sin_emb = emb.cos().unsqueeze(0), emb.sin().unsqueeze(0)  # Shape: [1, seq_len, embed_dim]

        # Apply RoPE transformation
        x_rotated = (x * cos_emb) + (self.rotate_half(x) * sin_emb)

        return x_rotated


# ✅ Hyperparameters
batch_size = 8
seq_len = 128
embed_dim = 512

# ✅ Initialize RoPE
rotary_pe = RotaryPositionalEncoding(embed_dim).to(device)

# ✅ Example Input (Random Token Embeddings)
input_embeddings = torch.randn(batch_size, seq_len, embed_dim, dtype=torch.float32, device=device)

# ✅ Apply RoPE
output_embeddings = rotary_pe(input_embeddings)

# ✅ Debugging Info
print("✅ Input Embeddings Shape:", input_embeddings.shape)  # Expected: (8, 128, 512)
print("✅ RoPE Output Shape:", output_embeddings.shape)  # Expected: (8, 128, 512)
print("✅ RoPE Output dtype:", output_embeddings.dtype)  # Expected: float32

✅ Input Embeddings Shape: torch.Size([8, 128, 512])
✅ RoPE Output Shape: torch.Size([8, 128, 512])
✅ RoPE Output dtype: torch.float32


In [31]:
import torch
import torch.nn as nn
import torch.nn.functional as F


# from RotaryPositionalEncoding.RotaryPositionalEncoding import  RotaryPositionalEncoding

device = 'cuda' if torch.cuda.is_available() else 'cpu'

class LayerNorm(nn.Module):
    def __init__(self, embed_dim, eps=1e-5):
        super(LayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(embed_dim, dtype=torch.float32))
        self.beta = nn.Parameter(torch.zeros(embed_dim, dtype=torch.float32))
        self.eps = eps

    def forward(self, x):
        if x.dtype != torch.float32:
            x = x.to(torch.float32)
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta, self.eps)


# ✅ Residual Connection with Pre-LN
class ResidualConnection(nn.Module):
    def __init__(self, embed_dim, dropout=0.1):
        super(ResidualConnection, self).__init__()
        self.norm = LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))


# ✅ Multi-Head Self-Attention with RoPE
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads): # Removed rotary_pe=None from init
        super(MultiHeadSelfAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == embed_dim, "Embed dim must be divisible by heads"

        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim, dtype=torch.float32)
        self.out_proj = nn.Linear(embed_dim, embed_dim, dtype=torch.float32)

        # Initialize RoPE here using head_dim
        self.rotary_pe = RotaryPositionalEncoding(self.head_dim)

    def apply_rope(self, x):
        # x has shape (B, H, T, head_dim)
        B, H, T, D = x.shape # D is head_dim
        # Reshape to (B * H, T, head_dim)
        x = x.reshape(B * H, T, D)
        # Apply RoPE, which expects input dimension to match its initialized dimension (head_dim)
        x = self.rotary_pe(x)
        # Reshape back to (B, H, T, head_dim)
        return x.reshape(B, H, T, D)

    def forward(self, x, mask=None):
        B, T, _ = x.shape # x has shape (B, T, embed_dim)
        qkv = self.qkv_proj(x).view(B, T, 3, self.num_heads, self.head_dim)
        Q, K, V = qkv.unbind(dim=2) # Q, K, V are (B, T, num_heads, head_dim)
        Q = Q.transpose(1, 2) # (B, num_heads, T, head_dim)
        K = K.transpose(1, 2) # (B, num_heads, T, head_dim)
        V = V.transpose(1, 2) # (B, num_heads, T, head_dim)

        # Apply RoPE to Q and K (which now have head_dim as the last dimension)
        Q = self.apply_rope(Q)
        K = self.apply_rope(K)

        if mask is not None:
            # Expand mask to match (B, num_heads, T, T) for broadcasting
            mask = mask.to(dtype=torch.bool, device=x.device)  # bool mask is expected by scaled_dot_product_attention


        # F.scaled_dot_product_attention expects Q, K, V of shape (B, num_heads, T, head_dim)
        output = F.scaled_dot_product_attention(Q, K, V, attn_mask=mask) # output is (B, num_heads, T, head_dim)

        # Reshape output back to (B, T, embed_dim)
        output = output.transpose(1, 2).contiguous().reshape(B, T, self.embed_dim)
        return self.out_proj(output)


# ✅ Hyperparameters
batch_size = 8
seq_len = 128
embed_dim = 512
num_heads = 8

# Initialize Layers
# Now, RotaryPositionalEncoding is initialized inside MultiHeadSelfAttention with head_dim
self_attention = MultiHeadSelfAttention(embed_dim=embed_dim, num_heads=num_heads).to(device)
residual = ResidualConnection(embed_dim).to(device)

# ✅ Example Input
input_embeddings = torch.randn(batch_size, seq_len, embed_dim, dtype=torch.float32, device=device)
# Create the casual mask with shape (1, 1, seq_len, seq_len)
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)).unsqueeze(0).unsqueeze(0)

# ✅ Forward Pass
output = residual(input_embeddings, lambda x: self_attention(x, mask=mask))

# ✅ Debug
print("✅ Input:", input_embeddings.shape)      # (8, 128, 512)
print("✅ Output:", output.shape)               # (8, 128, 512)


✅ Input: torch.Size([8, 128, 512])
✅ Output: torch.Size([8, 128, 512])


In [32]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# from MultiHeadSelfAttention.MultiHeadSelfAttention import MultiHeadSelfAttention



device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(42)

# ✅ Optimized Layer Normalization (Pre-LN)
class LayerNorm(nn.Module):
    def __init__(self, embed_dim, eps=1e-5):
        super(LayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(embed_dim, dtype=torch.float32))  # Learnable scale
        self.beta = nn.Parameter(torch.zeros(embed_dim, dtype=torch.float32))  # Learnable shift
        self.eps = eps  # Small value for numerical stability

    def forward(self, x):
        # Apply Layer Normalization
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta, self.eps)

# ✅ Optimized Feedforward Network (FFN)
class FeedforwardNetwork(nn.Module):
    def __init__(self, embed_dim, hidden_dim, dropout=0.1):
        """
        Feedforward Network with GELU Activation and Dropout.
        """
        super(FeedforwardNetwork, self).__init__()
        self.linear1 = nn.Linear(embed_dim, hidden_dim, dtype=torch.float32)  # Expansion
        self.linear2 = nn.Linear(hidden_dim, embed_dim, dtype=torch.float32)  # Compression
        self.dropout = nn.Dropout(dropout)  # Regularization
        self.activation = nn.GELU()  # Activation function

    def forward(self, x):
        x = self.linear1(x)  # Expand dimensions
        x = self.activation(x)  # Apply GELU
        x = self.dropout(x)  # Apply dropout
        x = self.linear2(x)  # Compress dimensions
        return x

# ✅ Optimized Residual Connection with Pre-Norm
class ResidualConnection(nn.Module):
    def __init__(self, embed_dim, dropout=0.1):
        """
        Residual Connection with Pre-Norm (Better Stability).
        """
        super(ResidualConnection, self).__init__()
        self.norm = LayerNorm(embed_dim)  # Pre-LayerNorm
        self.dropout = nn.Dropout(dropout)  # Regularization

    def forward(self, x, sublayer):
        # Apply Pre-LayerNorm, sublayer, and residual connection
        return x + self.dropout(sublayer(self.norm(x)))



# ✅ Hyperparameters
batch_size = 8
seq_len = 128
embed_dim = 512
hidden_dim = 2048  # FFN hidden dimension
num_heads = 8

# ✅ Initialize Layers
self_attention = MultiHeadSelfAttention(embed_dim, num_heads).to(device)
ffn = FeedforwardNetwork(embed_dim, hidden_dim).to(device)
residual_connection1 = ResidualConnection(embed_dim).to(device)
residual_connection2 = ResidualConnection(embed_dim).to(device)

# ✅ Example Input (Random Token Embeddings)
input_embeddings = torch.randn(batch_size, seq_len, embed_dim, dtype=torch.float32, device=device)

# ✅ Causal Mask (For Decoder)
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.float32, device=device)).unsqueeze(0).unsqueeze(0)

# ✅ Apply Masked Multi-Head Self-Attention with Residual Connection
attention_output = residual_connection1(input_embeddings, self_attention)

# ✅ Apply Feedforward Network with Residual Connection
ffn_output = residual_connection2(attention_output, ffn)

# ✅ Debugging Info
print("✅ Input Embeddings Shape:", input_embeddings.shape)  # Expected: (8, 128, 512)
print("✅ Attention Output Shape:", attention_output.shape)  # Expected: (8, 128, 512)
print("✅ FFN Output Shape:", ffn_output.shape)  # Expected: (8, 128, 512)
print("✅ Output dtype:", ffn_output.dtype)  # Expected: float32

✅ Input Embeddings Shape: torch.Size([8, 128, 512])
✅ Attention Output Shape: torch.Size([8, 128, 512])
✅ FFN Output Shape: torch.Size([8, 128, 512])
✅ Output dtype: torch.float32


In [33]:
import torch
import torch.nn as nn
import torch.nn.functional as F
# from MultiHeadSelfAttention.MultiHeadSelfAttention import MultiHeadSelfAttention
# from PostNorm.PostNorm import FeedforwardNetwork, LayerNorm, ResidualConnection
# from RotaryPositionalEncoding.RotaryPositionalEncoding import RotaryPositionalEncoding
device = 'cuda' if torch.cuda.is_available() else 'cpu'





# ✅ Layer 1: Sublayers
class Layer1(nn.Module):
    def __init__(self, embed_dim, hidden_dim, num_heads, dropout=0.1):
        super(Layer1, self).__init__()
        # Sublayers
        self.pre_layer_norm = LayerNorm(embed_dim)  # Pre-LayerNorm
        # Remove the rotary_pe argument here as MultiHeadSelfAttention handles it internally
        self.self_attention = MultiHeadSelfAttention(embed_dim, num_heads)

        self.residual1 = ResidualConnection(embed_dim, dropout)  # Residual Connection 1
        self.layer_norm1 = LayerNorm(embed_dim)  # LayerNorm after Self-Attention
        self.ffn = FeedforwardNetwork(embed_dim, hidden_dim, dropout)  # Feedforward Network
        self.residual2 = ResidualConnection(embed_dim, dropout)  # Residual Connection 2
        self.layer_norm2 = LayerNorm(embed_dim)  # LayerNorm after FFN

    def forward(self, x):
        # Sublayer 1: Self-Attention with Residual Connection
        # Note: MultiHeadSelfAttention will apply RoPE internally now
        # A causal mask might be needed depending on the model type (e.g., decoder)
        # For now, passing None as mask as per previous code blocks. Add mask logic if needed.
        attention_output = self.self_attention(self.pre_layer_norm(x)) # Apply pre-norm before attention
        x = self.residual1(x, lambda _: attention_output)  # Self-Attention + Residual
        # The original code had layer_norm1 here, but with Pre-LN, this is usually after the residual and before the next sublayer's pre-norm

        # Sublayer 2: Feedforward Network with Residual Connection
        ffn_output = self.ffn(self.layer_norm1(x)) # Apply layer_norm1 (after first residual) before FFN
        x = self.residual2(x, lambda _: ffn_output)  # FFN + Residual
        # The final layer norm for the block is layer_norm2
        x = self.layer_norm2(x)

        return x


# ✅ Hyperparameters
batch_size = 8
seq_len = 128
embed_dim = 512
hidden_dim = 2048  # FFN hidden dimension
num_heads = 8

# ✅ Initialize Layer 1
layer1 = Layer1(embed_dim, hidden_dim, num_heads).to(device)

# ✅ Example Input (Random Token Embeddings)
input_embeddings = torch.randn(batch_size, seq_len, embed_dim, dtype=torch.float32, device=device)

# ✅ Apply Layer 1
output = layer1(input_embeddings)

# ✅ Debugging Info
print("✅ Input Embeddings Shape:", input_embeddings.shape)  # Expected: (8, 128, 512)
print("✅ Layer 1 Output Shape:", output.shape)  # Expected: (8, 128, 512)
print("✅ Output dtype:", output.dtype)  # Expected: float32

✅ Input Embeddings Shape: torch.Size([8, 128, 512])
✅ Layer 1 Output Shape: torch.Size([8, 128, 512])
✅ Output dtype: torch.float32


In [34]:
import torch
import torch.nn as nn
import torch.nn.functional as F

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# ✅ Multi-Head Self-Attention Layer
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(MultiHeadSelfAttention, self).__init__()
        self.embed_dim = embed_dim # Store embed_dim as an attribute
        self.num_heads = num_heads  # Store num_heads as an attribute
        self.head_dim = embed_dim // num_heads

        assert self.head_dim * num_heads == embed_dim, "Embedding dimension must be divisible by number of heads"

        # Fused QKV Projection (Single Linear Layer for Efficiency)
        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim, dtype=torch.float32)
        self.out_proj = nn.Linear(embed_dim, embed_dim, dtype=torch.float32)

    def forward(self, x, mask=None):
        batch_size, seq_len, embed_dim = x.shape

        # Compute Q, K, V in a single pass
        qkv = self.qkv_proj(x).view(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        Q, K, V = qkv.unbind(dim=2)  # Split into separate tensors

        # Reshape for multi-head attention
        Q = Q.transpose(1, 2)  # Shape: (batch, num_heads, seq_len, head_dim)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)

        # Flash Attention (Optimized Scaled Dot-Product Attention)
        if mask is not None:
            mask = mask.to(dtype=x.dtype, device=x.device)  # Ensure mask is on the correct device and dtype
        output = F.scaled_dot_product_attention(Q, K, V, attn_mask=mask)

        # Reshape back to original shape
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)

        # Apply output projection
        return self.out_proj(output)

# ✅ Chunked Attention (Fixed)
class ChunkedAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, chunk_size=32):
        super(ChunkedAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.chunk_size = chunk_size
        self.self_attention = MultiHeadSelfAttention(embed_dim, num_heads)  # Self-Attention

    def forward(self, x):
        batch_size, seq_len, embed_dim = x.shape

        # Step 1: Pad input if necessary
        pad_len = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size
        x = F.pad(x, (0, 0, 0, pad_len))  # Pad along seq_len
        seq_len += pad_len

        # Step 2: Divide input into chunks
        num_chunks = seq_len // self.chunk_size
        x = x.view(batch_size, num_chunks, self.chunk_size, embed_dim)  # (batch, num_chunks, chunk_size, embed_dim)

        # Step 3: Apply self-attention to each chunk
        x = x.reshape(batch_size * num_chunks, self.chunk_size, embed_dim)  # (batch * num_chunks, chunk_size, embed_dim)
        x = self.self_attention(x)  # Apply self-attention
        x = x.reshape(batch_size, num_chunks, self.chunk_size, embed_dim)  # Reshape back

        # Step 4: Combine chunks back into sequence
        x = x.reshape(batch_size, seq_len, embed_dim)  # Reshape back

        return x[:, :seq_len - pad_len, :]  # Remove padding

# ✅ LayerNorm Wrapper
class LayerNorm(nn.Module):
    def __init__(self, embed_dim):
        super(LayerNorm, self).__init__()
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        if x.dtype != torch.float32:
            x = x.to(torch.float32)  # Ensure input is float32
        return self.norm(x)

# ✅ Recurrent Memory (Fixed)
class RecurrentMemory(nn.Module):
    def __init__(self, embed_dim, memory_size, num_heads):  # Add num_heads as a parameter
        super(RecurrentMemory, self).__init__()
        self.memory_size = memory_size
        self.num_heads = num_heads  # Store num_heads as an attribute
        self.memory_bank = nn.Parameter(torch.zeros(memory_size, embed_dim, dtype=torch.float32))  # Use float32
        self.memory_norm = LayerNorm(embed_dim)  # Normalize memory output
        self.memory_attention = nn.MultiheadAttention(embed_dim, num_heads=num_heads, batch_first=True)  # Use num_heads

    def forward(self, x):
        batch_size, seq_len, embed_dim = x.shape

        # Step 1: Memory Read (Multi-query attention)
        memory_expanded = self.memory_bank.unsqueeze(0).expand(batch_size, -1, -1)  # (batch, memory_size, embed_dim)
        memory_output, _ = self.memory_attention(x, memory_expanded, memory_expanded)  # (batch, seq_len, embed_dim)
        memory_output = self.memory_norm(memory_output)  # Normalize memory output

        # Step 2: Integrate memory output into main stream
        x = x + memory_output  # Add memory output to input

        # Step 3: Memory Write (Update memory bank)
        self.update_memory(x)

        return x

    def update_memory(self, x):
        with torch.no_grad():
            # Aggregate information for update
            update_value = x.mean(dim=1).mean(dim=0)

            # Create a new memory bank with shifted values
            updated_memory_bank = torch.roll(self.memory_bank, shifts=-1, dims=0)
            updated_memory_bank[-1] = update_value

            # Update the memory bank using a non-in-place operation
            self.memory_bank.data = updated_memory_bank  # Assign to .data to avoid in-place operation

In [35]:
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'

import torch
import torch.nn as nn
import torch.nn.functional as F

# from MultiHeadSelfAttention.MultiHeadSelfAttention import MultiHeadSelfAttention
# from PostNorm.PostNorm import FeedforwardNetwork, LayerNorm, ResidualConnection
# from ChunkedAttention.ChunkedAttention import ChunkedAttention, RecurrentMemory # Uncomment this line
# from RotaryPositionalEncoding.RotaryPositionalEncoding import RotaryPositionalEncoding

device = 'cuda' if torch.cuda.is_available() else 'cpu'

class Layer2WithMemory(nn.Module):
    def __init__(self, embed_dim, hidden_dim, num_heads, memory_size, dropout=0.1):
        super(Layer2WithMemory, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.memory_size = memory_size

        # Sublayers
        self.pre_layer_norm = LayerNorm(embed_dim)
        # Assuming RotaryPositionalEncoding is also defined or imported elsewhere.
        # If not, you might need to uncomment its import too or define it here.
        # rope = RotaryPositionalEncoding(embed_dim=embed_dim).to(device) # This instance isn't used directly in this class's __init__
        self.self_attention = MultiHeadSelfAttention(embed_dim, num_heads)
        self.chunked_attention = ChunkedAttention(embed_dim, num_heads)  # ✅ New
        self.recurrent_memory = RecurrentMemory(embed_dim, memory_size, num_heads)  # ✅ New

        self.residual1 = ResidualConnection(embed_dim, dropout)
        self.layer_norm1 = LayerNorm(embed_dim)

        # Memory Module
        self.memory_bank = nn.Parameter(torch.zeros(memory_size, embed_dim, dtype=torch.float32))
        self.memory_norm = LayerNorm(embed_dim)
        self.memory_attention = MultiHeadSelfAttention(embed_dim, num_heads)

        # Learnable Memory Update Mechanism
        self.memory_gate = nn.Linear(embed_dim, 1)
        self.sigmoid = nn.Sigmoid()

        # Feedforward Network
        self.ffn = FeedforwardNetwork(embed_dim, hidden_dim, dropout)
        self.residual2 = ResidualConnection(embed_dim, dropout)
        self.layer_norm2 = LayerNorm(embed_dim)

    def forward(self, x):
        batch_size, seq_len, embed_dim = x.shape
        num_heads = self.memory_attention.num_heads
        head_dim = embed_dim // num_heads

        # === Self-Attention + Chunked Attention + Recurrent Memory ===
        # Note: Self-Attention within this layer definition should handle masks if needed.
        # If MultiHeadSelfAttention requires a mask, you'll need to add it here.
        attn_output = self.self_attention(self.pre_layer_norm(x)) # Apply pre-norm before attention
        chunked_output = self.chunked_attention(self.pre_layer_norm(x)) # Apply pre-norm before chunked attention
        recurrent_output = self.recurrent_memory(self.pre_layer_norm(x)) # Apply pre-norm before recurrent memory

        # Combine all three attentions (average for simplicity; you can try weighted sum too)
        combined_attn = (attn_output + chunked_output + recurrent_output) / 3

        # The ResidualConnection expects a function as the second argument.
        # We are applying the combined_attn output directly, so the lambda
        # should just return the combined_attn.
        x = self.residual1(x, lambda _: combined_attn)
        x = self.layer_norm1(x)

        # === Memory Read ===
        # Ensure memory_bank has correct device and type
        memory_bank = self.memory_bank.unsqueeze(0).expand(batch_size, -1, -1).to(x.device, x.dtype)
        memory_bank = memory_bank.view(batch_size, self.memory_size, num_heads, head_dim).permute(0, 2, 1, 3)
        Q = x.view(batch_size, seq_len, num_heads, head_dim).permute(0, 2, 1, 3)
        # Assuming memory_bank should be treated as K and V for memory attention
        memory_output = F.scaled_dot_product_attention(Q, memory_bank, memory_bank)
        memory_output = memory_output.permute(0, 2, 1, 3).reshape(batch_size, seq_len, embed_dim)

        # Apply memory gate
        gate_weight = self.sigmoid(self.memory_gate(x))
        memory_output = gate_weight * memory_output
        x = x + memory_output # Add memory output as a form of residual

        # === Memory Update ===
        self.update_memory(x)

        # === FFN ===
        # Apply layer_norm1 before FFN as per Pre-LN pattern
        ffn_output = self.ffn(self.layer_norm1(x))
        x = self.residual2(x, lambda _: ffn_output) # Wrap ffn_output in lambda
        x = self.layer_norm2(x) # Final layer norm for the block
        return x

    def update_memory(self, x):
        with torch.no_grad():
            # Average across sequence length to get a single vector per batch
            update_value = x.mean(dim=1) # Shape (batch_size, embed_dim)

            # Shift memory bank
            updated_memory_bank = torch.roll(self.memory_bank.data, shifts=-1, dims=0)

            # Calculate update gate for each sample in the batch
            gate = self.sigmoid(self.memory_gate(update_value)) # Shape (batch_size, 1)

            # Get the memory slot that is being updated (the oldest slot after rolling)
            memory_slot_to_update = updated_memory_bank[0, :].unsqueeze(0).expand_as(update_value) # Shape (batch_size, embed_dim)

            # Apply gating to the update value and the current memory slot
            # The update_value is what we want to write, gated by `gate`.
            # The memory_slot_to_update is the old value, kept by `1 - gate`.
            updated_memory_slot_for_batch = (gate * update_value) + (1 - gate) * memory_slot_to_update # Shape (batch_size, embed_dim)

            # Average the updated memory slot across the batch to get a single vector
            # for the memory bank's slot 0.
            updated_memory_bank[0, :] = updated_memory_slot_for_batch.mean(dim=0) # Shape (embed_dim,)

            # Update the persistent memory bank parameter data
            self.memory_bank.data = updated_memory_bank.data



# ✅ Hyperparameters
batch_size = 8
seq_len = 128
embed_dim = 512
hidden_dim = 2048  # FFN hidden dimension
num_heads = 8
memory_size = 100  # Memory bank size

# Assume LayerNorm, FeedforwardNetwork, ResidualConnection, MultiHeadSelfAttention, and RotaryPositionalEncoding
# are defined or imported in previous cells or scripts.

# ✅ Initialize Layer 2 with Memory
# Ensure other necessary classes (LayerNorm, FeedforwardNetwork, ResidualConnection, MultiHeadSelfAttention)
# are defined or imported before this cell runs.
layer2_with_memory = Layer2WithMemory(embed_dim, hidden_dim, num_heads, memory_size).to(device)

# ✅ Example Input (Layer 1 ka output)
input_embeddings = torch.randn(batch_size, seq_len, embed_dim, dtype=torch.float32, device=device)

# ✅ Apply Layer 2 with Memory
output = layer2_with_memory(input_embeddings)

# ✅ Debugging Info
print("✅ Input Embeddings Shape:", input_embeddings.shape)  # Expected: (8, 128, 512)
print("✅ Layer 2 Output Shape:", output.shape)  # Expected: (8, 128, 512)
print("✅ Output dtype:", output.dtype)  # Expected: float32

✅ Input Embeddings Shape: torch.Size([8, 128, 512])
✅ Layer 2 Output Shape: torch.Size([8, 128, 512])
✅ Output dtype: torch.float32


In [36]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F



device = 'cuda' if torch.cuda.is_available() else 'cpu'




class RewardModel(nn.Module):
    def __init__(self, embed_dim):
        super(RewardModel, self).__init__()
        self.linear1 = nn.Linear(embed_dim, 256)  # First linear layer
        self.linear2 = nn.Linear(256, 1)  # Second linear layer
        self.activation = nn.ReLU()  # Activation function
        self.dropout = nn.Dropout(0.1)  # Dropout for regularization

    def forward(self, x):
        x = self.activation(self.linear1(x))  # Apply first linear layer and activation
        x = self.dropout(x)  # Apply dropout
        x = self.linear2(x).float()  # Apply second linear layer and ensure float32 output
        return x  # Shape: (batch, seq_len, 1)

class PPOOptimizer:
    def __init__(self, model, reward_model, lr=1e-4, gamma=0.99, clip_epsilon=0.2, entropy_coef=0.01):
        self.model = model
        self.reward_model = reward_model
        # Ensure optimizer is created *after* model parameters are on the correct device
        self.optimizer = optim.Adam(model.parameters(), lr=lr)  # Adam optimizer
        self.gamma = gamma  # Discount factor
        self.clip_epsilon = clip_epsilon  # Clipping parameter for PPO
        self.entropy_coef = entropy_coef  # Entropy coefficient

    def compute_advantages(self, rewards, values):
        """
        Compute advantages. Assuming rewards are per-batch mean rewards
        and values are per-batch values of the last state.
        """
        # Simple advantage: Reward - Value
        # Ensure shapes are compatible. rewards is (batch_size, 1), values is (batch_size,)
        # Unsqueeze values to (batch_size, 1) for element-wise subtraction
        advantages = rewards - values.unsqueeze(1)
        return advantages

    def update(self, states, actions, rewards, old_log_probs, values):
        print("🔍 PPO Update Start")

        # Compute advantages
        advantages = self.compute_advantages(rewards, values)

        # Safe normalization
        adv_mean = advantages.mean()
        adv_std = advantages.std()
        if not torch.isfinite(adv_std) or adv_std < 1e-6:
            print("⚠️ Unstable advantage std detected, skipping PPO update.")
            return torch.tensor(0.0, device=states.device)
        advantages = (advantages - adv_mean) / (adv_std + 1e-8)

        print(f"🔹 Advantages: {advantages.shape}")

        # Forward pass again
        output, new_log_probs, _, new_values, _ = self.model(states, actions)

        # Clamp log probs to prevent exp overflow
        new_log_probs = torch.clamp(new_log_probs, min=-20, max=0)
        old_log_probs = torch.clamp(old_log_probs, min=-20, max=0)

        # Calculate returns
        returns = advantages.squeeze(1) + values  # shape: (batch_size,)
        print(f"✅ Log Probs: {new_log_probs.shape}, ✅ Values: {new_values.shape}")
        print(f"🔄 Returns shape: {returns.shape}")

        # PPO ratio
        ratio = torch.exp(new_log_probs - old_log_probs)
        clipped_ratio = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon)
        policy_loss = -torch.min(ratio.unsqueeze(1) * advantages,
                                 clipped_ratio.unsqueeze(1) * advantages).mean()

        # Value loss
        value_loss = F.mse_loss(new_values, returns)

        # Total loss
        loss = policy_loss + 0.5 * value_loss

        # Check for NaNs or Infs before backward
        if not torch.isfinite(loss):
            print("🚨 Loss is NaN/Inf. Skipping PPO update.")
            return torch.tensor(0.0, device=states.device)

        print(f"📉 Loss: {loss.item()}")
        print("✅ PPO Update Done (Loss computed)! Backprop and Step outside.")

        return loss


# ✅ Hyperparameters
batch_size = 8
seq_len = 128
embed_dim = 512

# ✅ Initialize Reward Model
reward_model = RewardModel(embed_dim).to(device)

# ✅ Example Input (Random Token Embeddings)
input_embeddings = torch.randn(batch_size, seq_len, embed_dim, dtype=torch.float32, device=device)

# ✅ Apply Reward Model
rewards = reward_model(input_embeddings)

# ✅ Debugging Info
print("✅ Input Embeddings Shape:", input_embeddings.shape)  # Expected: (8, 128, 512)
print("✅ Rewards Shape:", rewards.shape)  # Expected: (8, 128, 1)
print("✅ Rewards dtype:", rewards.dtype)  # Expected: float32

✅ Input Embeddings Shape: torch.Size([8, 128, 512])
✅ Rewards Shape: torch.Size([8, 128, 1])
✅ Rewards dtype: torch.float32


In [37]:
import torch

import torch.nn as nn
import torch.nn.functional as F

# from MultiHeadSelfAttention.MultiHeadSelfAttention import MultiHeadSelfAttention,LayerNorm,ResidualConnection
# from PostNorm.PostNorm import LayerNorm,ResidualConnection,FeedforwardNetwork
# from ChunkedAttention.ChunkedAttention import ChunkedAttention,RecurrentMemory
# from RotaryPositionalEncoding.RotaryPositionalEncoding import RotaryPositionalEncoding


# Set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Set random seed for reproducibility
torch.manual_seed(42)


class Layer3WithContextAndRL(nn.Module):
    def __init__(self, embed_dim, hidden_dim, num_heads, memory_size, dropout=0.1, num_actions=10):
        super(Layer3WithContextAndRL, self).__init__()
        self.embed_dim = embed_dim # Store embed_dim as an attribute
        # Sublayers
        self.pre_layer_norm = LayerNorm(embed_dim)  # Pre-LayerNorm
        # Remove the direct initialization and passing of rope here
        # rope = RotaryPositionalEncoding(embed_dim=embed_dim).to(device)
        # MultiHeadSelfAttention in ipython-input-13-837d8f2a7177 handled RoPE internally based on head_dim
        # The MultiHeadSelfAttention imported/defined in a previous cell (ipython-input-19-363be42504bd)
        # also handles RoPE internally.
        # Initialize MultiHeadSelfAttention without the rotary_pe argument
        self.self_attention = MultiHeadSelfAttention(embed_dim, num_heads)

        self.residual1 = ResidualConnection(embed_dim, dropout)  # Residual Connection 1
        self.layer_norm1 = LayerNorm(embed_dim)  # LayerNorm after Self-Attention

        # Context Handling
        self.chunked_attention = ChunkedAttention(embed_dim, num_heads)  # Chunked Attention
        # Ensure RecurrentMemory also matches its defined __init__ signature (takes embed_dim, memory_size, num_heads)
        self.recurrent_memory = RecurrentMemory(embed_dim, memory_size, num_heads)  # Recurrent Memory

        # Feedforward Network
        self.ffn = FeedforwardNetwork(embed_dim, hidden_dim, dropout)  # FFN
        self.residual2 = ResidualConnection(embed_dim, dropout)  # Residual Connection 2
        self.layer_norm2 = LayerNorm(embed_dim)  # LayerNorm after FFN

        # RL Integration
        # Reward Model is now separate, passed during PPO initialization
        self.reward_model = nn.Linear(embed_dim, 1)  # Reward Model instance
        self.policy_network = nn.Linear(embed_dim, num_actions)  # Policy Network
        self.value_network = nn.Linear(embed_dim, 1)  # Value Network

        # PPO Optimizer will be initialized outside this module
        # self.ppo_optimizer = PPOOptimizer(model=self, reward_model=self.reward_model)  # Initialize PPO Optimizer

    def forward(self, x, actions=None):
        # Sublayer 1: Self-Attention with Residual Connection
        # Ensure MultiHeadSelfAttention forward method is compatible with ResidualConnection's expected input
        # ResidualConnection calls self.self_attention(self.norm(x))
        # The self_attention forward method in ipython-input-19-363be42504bd takes x and optional mask
        # The Layer3WithContextAndRL forward currently does not pass a mask to self_attention
        # If a causal mask is needed for this layer, add mask creation and passing here.
        x = self.residual1(x, self.self_attention) # residual1 passes normalized x to self_attention
        x = self.layer_norm1(x)

        # Context Handling
        # Check ChunkedAttention and RecurrentMemory forward methods to ensure they take the correct input (x)
        # ChunkedAttention and RecurrentMemory from ipython-input-19-363be42504bd take x as input
        x = self.chunked_attention(x)
        x = self.recurrent_memory(x)

        # Sublayer 2: Feedforward Network with Residual Connection
        # residual2 calls self.ffn(self.norm(x))
        # FeedforwardNetwork from ipython-input-15-45a01112038a takes x as input
        x = self.residual2(x, self.ffn) # residual2 passes normalized x to ffn
        x = self.layer_norm2(x)

        # RL Integration: Compute Rewards (part of the forward pass for gradient purposes)
        # Note: The reward_model instance used here must be the same one passed to PPOOptimizer
        # The reward_model forward method expects input_embeddings which is (batch, seq_len, embed_dim)
        # and outputs (batch, seq_len, 1)
        rewards = self.reward_model(x)  # Shape: (batch_size, seq_len, 1)

        # Policy Network (Action Selection)
        last_hidden_state = x[:, -1, :]  # Use the last hidden state for action selection
        logits = self.policy_network(last_hidden_state) # policy_network expects (batch, embed_dim), outputs (batch, num_actions)
        probs = F.softmax(logits, dim=-1) # probs is (batch, num_actions)

        # If actions are not provided, sample from policy
        # If actions *are* provided (during PPO update), use them
        if actions is None:
            # Sample actions when actions are not provided (first forward pass)
            # torch.multinomial expects input probs (batch, num_actions) and returns (batch, num_samples)
            actions = torch.multinomial(probs, num_samples=1).squeeze(-1) # actions is (batch,)


        # Compute log probs for selected actions
        # probs.gather expects index tensor (batch, 1) if dim=-1
        # actions is (batch,), unsqueeze to (batch, 1)
        new_log_probs = torch.log(probs.gather(-1, actions.unsqueeze(-1))).squeeze(-1) # new_log_probs is (batch,)


        # Value Network (State Value Estimation)
        # value_network expects (batch, embed_dim), outputs (batch, 1)
        # squeeze(-1) makes it (batch,)
        values = self.value_network(last_hidden_state).squeeze(-1) # values is (batch,)

        return x, new_log_probs, actions, values, rewards  # Return all five values

In [38]:
import os
import torch
from torch.amp import GradScaler, autocast
from torch.utils.data import Dataset, DataLoader
import json
from pathlib import Path
import time

# from Layer1.Layer1 import Layer1
# from Layer2WithMemory.Layer2WithMemory import Layer2WithMemory
# from TokenEmbeddings.TokenEmbeddings import TokenEmbedding
# from Tokenization.tokenization import TokenizationLayer
# from Layer3WithContextAndRL.Layer3WithContextAndRL import Layer3WithContextAndRL
# from PPOOptimizer.PPOOptimizer import PPOOptimizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
torch.manual_seed(42)
print(f"Using {device} for training 🔥")

# -------- Dataset --------
# -------- Dataset --------
class BhaiDataset(Dataset):
    def __init__(self, file_path, max_samples=None):
        self.file_path = Path(file_path)
        self.max_samples = max_samples
        self.data = self._load_data()

    def _load_data(self):
        if self.file_path.suffix == '.txt':
            with open(self.file_path, 'r', encoding='utf-8', errors='replace') as f:
                lines = [self._clean(line) for line in f if line.strip()]
        elif self.file_path.suffix == '.json':
            with open(self.file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
                lines = self._parse_json(data)
        else:
            raise ValueError("Only .txt or .json files supported!")

        if self.max_samples:
            lines = lines[:self.max_samples]
        return lines

    def _clean(self, text):
        return ''.join(c for c in text if c not in {'\x00', '\ufffd', '�', '\r'}).strip()

    def _parse_json(self, data):
        texts = []
        if isinstance(data, list):
            for item in data:
                if 'text' in item:
                    texts.append(self._clean(item['text']))
                elif 'content' in item:
                    texts.append(self._clean(item['content']))
        elif isinstance(data, dict):
            for v in data.values():
                if isinstance(v, str):
                    texts.append(self._clean(v))
        return texts

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# -------- Training Loop --------
def bhai_trainer(dataset_path, epochs=5, batch_size=8, max_seq_len=128, max_samples=50000):
    dataset = BhaiDataset(dataset_path, max_samples=max_samples)
    print(f"📦 Loaded {len(dataset)} samples out of total with limit {max_samples}")

    tokenizer = TokenizationLayer()

    def collate_fn(batch):
        tokens = tokenizer.tokenize(batch, max_length=max_seq_len)
        return tokens.to(device)

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn, num_workers=0)
    print(f"🔢 Batches per epoch: {len(dataloader)} with batch size {batch_size}")

    embed_dim = 256
    hidden_dim = 1024
    num_heads = 8
    memory_size = 100

    embedder = TokenEmbedding(tokenizer.vocab_size, embed_dim=embed_dim).to(device)
    layer1 = Layer1(embed_dim, hidden_dim, num_heads).to(device)
    layer2 = Layer2WithMemory(embed_dim, hidden_dim, num_heads, memory_size).to(device)
    layer3 = Layer3WithContextAndRL(embed_dim, hidden_dim, num_heads, memory_size).to(device)

    ppo_optimizer = PPOOptimizer(model=layer3, reward_model=None)
    optimizer = ppo_optimizer.optimizer
    scaler = GradScaler()

    for epoch in range(epochs):
        epoch_start = time.time()
        for batch_idx, inputs in enumerate(dataloader):
            batch_start = time.time()
            optimizer.zero_grad()

            try:
                with autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu'):
                    emb = embedder(inputs)
                    l1_out = layer1(emb)
                    l2_out = layer2(l1_out)

                    out, old_log_probs, actions, values, rewards = layer3(l2_out, actions=None)

                    states = l2_out.detach().to(torch.float32)
                    actions = actions.detach()
                    old_log_probs = old_log_probs.detach().to(torch.float32)
                    values = values.detach().to(torch.float32)
                    rewards = rewards.mean(dim=1).detach().to(torch.float32)

                    loss = ppo_optimizer.update(
                        states=states,
                        actions=actions,
                        rewards=rewards,
                        old_log_probs=old_log_probs,
                        values=values
                    )

                if not torch.isfinite(loss) or not loss.requires_grad:
                    raise ValueError("Loss is NaN/Inf or does not require grad.")

                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(layer3.parameters(), max_norm=0.5)
                scaler.step(optimizer)
                scaler.update()

                if batch_idx % 10 == 0:
                    gpu_mem = torch.cuda.memory_allocated() // 1024 ** 2 if torch.cuda.is_available() else 0
                    batch_time = time.time() - batch_start
                    print(f"Epoch {epoch + 1} | Batch {batch_idx} | Loss: {loss.item():.4f} | Reward: {rewards.mean().item():.4f} | Value: {values.mean().item():.4f} | GPU: {gpu_mem}MB | Time/batch: {batch_time:.2f}s")

            except Exception as e:
                print(f"🚨 Skipping batch {batch_idx} due to error: {e}")
                continue

        epoch_time = time.time() - epoch_start
        print(f"✅ Epoch {epoch + 1} done in {epoch_time:.2f} seconds.")

        try:
            torch.save({
                'epoch': epoch + 1,
                'loss': loss.item() if torch.isfinite(loss) else None,
                'layer1_state_dict': layer1.state_dict(),
                'layer2_state_dict': layer2.state_dict(),
                'layer3_state_dict': layer3.state_dict(),
                'embedder_state_dict': embedder.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scaler_state_dict': scaler.state_dict()
            }, f"bhai_llm_epoch_{epoch + 1}.pt")
            print(f"✅ Model checkpoint saved: bhai_llm_epoch_{epoch + 1}.pt")
        except Exception as e:
            print(f"🚨 Error saving checkpoint: {e}")

# -------- Main --------
if __name__ == "__main__":
    print("🚀 Starting training...")

    dataset_path = "/content/conversation_chatgpt.txt"
    epochs = 5
    batch_size = 8

    if not Path(dataset_path).exists():
        print(f"🚨 Dataset file not found: {dataset_path}")
    else:
        bhai_trainer(dataset_path, epochs, batch_size, max_samples=50000)



[1;30;43mStreaming output truncated to the last 5000 lines.[0m
📉 Loss: 0.4235021471977234
✅ PPO Update Done (Loss computed)! Backprop and Step outside.
Epoch 5 | Batch 5430 | Loss: 0.4235 | Reward: 0.6964 | Value: 6.2593 | GPU: 1186MB | Time/batch: 0.02s
🔍 PPO Update Start
🔹 Advantages: torch.Size([8, 1])
✅ Log Probs: torch.Size([8]), ✅ Values: torch.Size([8])
🔄 Returns shape: torch.Size([8])
📉 Loss: 0.4354788661003113
✅ PPO Update Done (Loss computed)! Backprop and Step outside.
🔍 PPO Update Start
🔹 Advantages: torch.Size([8, 1])
✅ Log Probs: torch.Size([8]), ✅ Values: torch.Size([8])
🔄 Returns shape: torch.Size([8])
📉 Loss: 0.41691529750823975
✅ PPO Update Done (Loss computed)! Backprop and Step outside.
🔍 PPO Update Start
🔹 Advantages: torch.Size([8, 1])
✅ Log Probs: torch.Size([8]), ✅ Values: torch.Size([8])
🔄 Returns shape: torch.Size([8])
📉 Loss: 0.425884485244751
✅ PPO Update Done (Loss computed)! Backprop and Step outside.
🔍 PPO Update Start
🔹 Advantages: torch.Size([8, 1])
✅

In [39]:
# file ipython-input-19-f30f37f21b8b
import torch
from pathlib import Path
# from Tokenization.tokenization import TokenizationLayer
# from TokenEmbeddings.TokenEmbeddings import TokenEmbedding
# from Layer1.Layer1 import Layer1
# from Layer2WithMemory.Layer2WithMemory import Layer2WithMemory
# from Layer3WithContextAndRL.Layer3WithContextAndRL import Layer3WithContextAndRL
import torch.nn as nn  # Import nn for the projection layer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


# --- Load model and tokenizer ---
def load_model_and_tokenizer(pt_path):
    # Assuming TokenizationLayer is defined/imported from a previous cell
    tokenizer = TokenizationLayer()

    embed_dim = 256
    hidden_dim = 1024
    num_heads = 8
    memory_size = 100
    # Add num_actions here as it's required by Layer3WithContextAndRL
    num_actions = 10  # Assuming this matches the training configuration

    # Assuming TokenEmbedding, Layer1, Layer2WithMemory, Layer3WithContextAndRL
    # are defined/imported from previous cells
    embedder = TokenEmbedding(tokenizer.vocab_size, embed_dim=embed_dim).to(device)
    layer1 = Layer1(embed_dim, hidden_dim, num_heads).to(device)
    layer2 = Layer2WithMemory(embed_dim, hidden_dim, num_heads, memory_size).to(device)
    # Pass num_actions to Layer3 constructor
    layer3 = Layer3WithContextAndRL(embed_dim, hidden_dim, num_heads, memory_size, num_actions=num_actions).to(device)

    checkpoint = torch.load(pt_path, map_location=device)

    try:
        # Load state_dict for embedder, layer1, and layer2
        embedder.load_state_dict(checkpoint['embedder_state_dict'])
        layer1.load_state_dict(checkpoint['layer1_state_dict'])
        layer2.load_state_dict(checkpoint['layer2_state_dict'])

        # Load Layer3 state_dict
        # Assuming the keys in the checkpoint match the Layer3 in ipython-input-10.
        layer3.load_state_dict(checkpoint['layer3_state_dict'])


    except Exception as e:  # Catch general exceptions during loading
        print(f"🚨 Error loading state_dict: {e}. Ensure model architecture matches checkpoint.")
        raise  # Re-raise the error after printing context

    # Set all layers to evaluation mode
    embedder.eval()
    layer1.eval()
    layer2.eval()
    layer3.eval()

    # Ensure memory_bank parameter in Layer3 is on the correct device if it exists
    # and not loaded by state_dict (though state_dict should handle this)
    if hasattr(layer3, 'memory_bank') and isinstance(layer3.memory_bank, torch.nn.Parameter):
        layer3.memory_bank.data = layer3.memory_bank.data.to(device)
        # Ensure memory gate weights are also on the correct device if not handled by state_dict
        if hasattr(layer3, 'memory_gate'):
            layer3.memory_gate.to(device)

    # Return the loaded modules
    return tokenizer, embedder, layer1, layer2, layer3


# --- Greedy decoding function ---
@torch.no_grad()
def generate_text(tokenizer, embedder, layer1, layer2, layer3, prompt, max_length=50):
    # Use the same max_length as the trainer for consistency in input size for embeddings
    initial_input_length = 128  # Match trainer's max_seq_len
    input_ids = tokenizer.tokenize([prompt], max_length=initial_input_length).to(device)  # (1, initial_input_length)

    # We will append generated tokens to input_ids, so start with the tokenized prompt
    generated_input_ids = input_ids  # Shape (1, current_seq_len)

    # Ensure the model is on the correct device and in eval mode before generating
    embedder.to(device).eval()
    layer1.to(device).eval()
    layer2.to(device).eval()
    layer3.to(device).eval()

    # Define a projection layer from embed_dim to vocab_size
    # This layer is NOT trained. Its weights are random or default initialized.
    # This is a workaround to get token logits from a model not trained for this.
    vocab_size = tokenizer.vocab_size
    embed_dim = layer3.embed_dim  # Get embed_dim from Layer3 instance
    token_projection_layer = nn.Linear(embed_dim, vocab_size).to(device)
    # Initialize the projection layer (optional, but might help slightly)
    nn.init.xavier_uniform_(token_projection_layer.weight)
    if token_projection_layer.bias is not None:
        nn.init.zeros_(token_projection_layer.bias)

    # Use a loop for generation
    for _ in range(max_length):  # Generate up to max_length *new* tokens
        # For attention models, the forward pass usually takes the *entire* sequence so far.
        # If the sequence length exceeds the model's max_seq_len (128 in this case),
        # you need to handle this, e.g., by truncating the input to the last `initial_input_length` tokens.
        current_seq_len = generated_input_ids.shape[1]
        if current_seq_len > initial_input_length:
            # Truncate the input sequence to the model's expected max length
            model_input_ids = generated_input_ids[:,
                              -initial_input_length:]  # Take the last `initial_input_length` tokens
        else:
            model_input_ids = generated_input_ids  # Use the whole sequence if shorter than max length

        # Apply embeddings, layer1, layer2, and layer3
        # Ensure inputs to layers are on the correct device
        model_input_ids = model_input_ids.to(device)
        with torch.no_grad():  # Ensure no gradients are computed during inference
            emb = embedder(model_input_ids)
            l1_out = layer1(emb)
            l2_out = layer2(l1_out)
            # Layer3 forward returns 5 values. `out` is the hidden states of the last layer.
            out, _, _, _, _ = layer3(l2_out, actions=None)  # Pass actions=None for inference

        # Get the hidden state for the *last* token in the output sequence
        # `out` has shape (batch_size, current_model_seq_len, embed_dim)
        last_hidden_state = out[:, -1, :]  # Shape (1, embed_dim)

        # Project the last hidden state to the vocabulary size to get token logits
        # Ensure the hidden state is on the correct device for the projection layer
        next_token_logits = token_projection_layer(last_hidden_state.to(device))  # Shape (1, vocab_size)

        # Select the token with the highest probability (greedy decoding)
        # Ensure next_token_logits is on the correct device
        next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(1)  # Shape (1, 1)

        # Append the predicted token to the generated sequence
        generated_input_ids = torch.cat((generated_input_ids, next_token_id), dim=1)  # Shape (1, current_seq_len + 1)

        # Stop if EOS token is generated
        # Check if the generated token is the EOS token ID
        if next_token_id.item() == tokenizer.eos_token_id:
            break

    # Remove the prompt tokens from the generated sequence to get the response tokens
    # Ensure generated_input_ids is on CPU and converted to list for slicing
    generated_tokens_list = generated_input_ids[0].cpu().tolist()
    input_ids_list = input_ids[0].cpu().tolist()

    # Find the index of the first token after the initial prompt.
    # This assumes the generated_input_ids starts exactly with input_ids.
    # If the prompt itself contains the EOS token, this might behave unexpectedly.
    # A more robust approach might be to just take the tokens appended *after* the loop starts.
    # However, sticking to the original slicing logic:
    # Ensure the slicing is valid
    if len(generated_tokens_list) > len(input_ids_list):
        response_tokens = generated_tokens_list[len(input_ids_list):]
    else:
        # If no new tokens were generated beyond the prompt length
        response_tokens = []
        # print("Warning: No tokens generated beyond the prompt length.")

    # Detokenize the response tokens using the *correct* method name
    generated_text = tokenizer.detokenize(response_tokens)  # ✅ Correct method name
    return generated_text


# --- Main chatbot loop ---
if __name__ == "__main__":
    # Define pt_path here
    pt_path = "/content/bhai_llm_epoch_5.pt"  # ✅ Change to your checkpoint path

    if not Path(pt_path).exists():
        print(f"🚨 Checkpoint not found: {pt_path}")
        exit(1)

    try:
        tokenizer, embedder, layer1, layer2, layer3 = load_model_and_tokenizer(pt_path)
        print("✅ Model and tokenizer loaded!")
        print("🤖 Chatbot ready. Type 'quit' to exit.")

        while True:
            user_input = input("You: ")
            if user_input.lower() in ['quit', 'exit']:
                print("Bye! 👋")
                break

            # Ensure input is on the correct device before passing to generate_text if needed,
            # but generate_text handles moving tensors to device.
            try:
                response = generate_text(tokenizer, embedder, layer1, layer2, layer3, user_input, max_length=50)
                print(f"Bot: {response}")
            except Exception as gen_error:
                print(f"🚨 An error occurred during text generation: {gen_error}")
                # You might want to add more specific error handling ocr debug prints here.


    except Exception as load_error:  # Catch errors during model loading
        print(f"An error occurred during model loading: {load_error}")


Using device: cuda
✅ Model and tokenizer loaded!
🤖 Chatbot ready. Type 'quit' to exit.
You: hi
Bot: wswswswswswswswswswswswswswswswswswswswswswswswswswswswswswswswswswswswswswswswswswswswswswswswswsws
You: hi
Bot:  profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable profitable
You: hello
Bot: egisegisegisegisegisegisegisegisegisegisegisegisegisegisegisegisegisegisegisegisegisegisegisegisegisegisegisegisegisegisegisegisegisegisegisegisegisegisegisegisegisegisegisegisegisegisegisegisegisegis
You: tell me somet

KeyboardInterrupt: Interrupted by user