# Vision (Image and Video) Model

In [None]:
#VIT
class FlexiblePatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_size=768, temporal_patch_size=1, is_3d=False):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.embed_size = embed_size
        self.is_3d = is_3d

        if is_3d:
            self.num_patches = int(((img_size // patch_size) ** 2) * temporal_patch_size)
            self.projection = nn.Conv3d(in_channels, embed_size, kernel_size=(temporal_patch_size, patch_size, patch_size), 

                                        stride=(temporal_patch_size, patch_size, patch_size))

        else:
            self.num_patches = (img_size // patch_size) ** 2
            self.projection = nn.Conv2d(in_channels, embed_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.projection(x)  # [B, E, T/P, H/P, W/P] or [B, E, H/P, W/P]
        x = x.flatten(2)  # Flatten spatial and temporal dimensions
        x = x.transpose(1, 2)  # [B, N, E]
        return x

class PositionalEmbedding(nn.Module):
    def __init__(self, num_patches, embed_size):
        super().__init__()

        self.positional_embedding = nn.Parameter(torch.zeros(1, num_patches + 1, embed_size))

    def forward(self, x):
        batch_size = x.shape[0]
        cls_token = torch.zeros(batch_size, 1, x.shape[-1], device=x.device)
        x = torch.cat([cls_token, x], dim=1)  # [B, 1+N, E]
        x += self.positional_embedding
        return x

class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Feedforward network
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = nn.ReLU()

    def forward(self, src):
        src2 = self.self_attn(src, src, src)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

class TransformerEncoder(nn.Module):
    def __init__(self, encoder_layer, num_layers):
        super(TransformerEncoder, self).__init__()
        self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(num_layers)])
        self.num_layers = num_layers

    def forward(self, src):
        for layer in self.layers:
            src = layer(src)

        return src


class FlexibleVisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_size=768, num_heads=12, num_layers=12, num_classes=1000, temporal_patch_size=1, is_3d=False):
        super().__init__()
        self.patch_embedding = FlexiblePatchEmbedding(img_size, patch_size, in_channels, embed_size, temporal_patch_size, is_3d)
        self.positional_embedding = PositionalEmbedding(self.patch_embedding.num_patches, embed_size)

        encoder_layer = TransformerEncoderLayer(d_model=embed_size, nhead=num_heads)
        self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_size),
            nn.Linear(embed_size, num_classes)
        )

    def forward(self, x, return_embeddings=False):
        x = self.patch_embedding(x)
        x = self.positional_embedding(x)
        x = self.transformer_encoder(x)
        if return_embeddings:
            return x  # Return the sequence of embeddings directly
        cls_token = x[:, 0]
        x = self.mlp_head(cls_token)
        return x
    


# Language Model

In [46]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import re
import collections
import json
import numpy as np
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
# Test on larger text
import re
import collections
from collections import Counter, defaultdict
import json

from transformers import BertModel
import math
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from tqdm.notebook import tqdm
from torch.nn.utils.rnn import pad_sequence
import os

# Set the HF_HOME environment variable to a new cache directory on the D drive
os.environ['HF_HOME'] = 'D:/hf_datasets_cache'
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
class ModelConfig:
    def __init__(self):
        self.tokenizer = wordpiece_tokenizer
        self.vocab_size = len(wordpiece_tokenizer.vocab)
        self.embed_size = 512  # Size of each embedding vector
        self.heads = 8  # Number of attention heads
        self.num_layers = 6  # Number of transformer blocks
        self.forward_expansion = 4  # Expansion size for the feedforward layer
        self.dropout = 0.1
        self.max_length = 1024  # Maximum length of the input sequences
        self.rank = 64  # Rank for LORA adjustments
        self.sequence_length = 1024  # Input sequence length for SPLASH
        self.projection_dim = 256  # Projection dimension for SPLASH
        self.partition_size = 128  # Partition size for SPLASH processing
        self.device = 'cpu'
        #self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.alpha = 2  # Alpha parameter for LORA layers
        self.quantization_bits = 8  # Bit size for quantization in QLORA
        self.freq_threshold = 1000  # Frequency threshold for adaptive embeddings
        self.large_embed_dim = 512  # Embedding dimension for frequent tokens
        self.small_embed_dim = 128  # Embedding dimension for infrequent tokens
        self.head_dim = self.embed_size // self.heads

  # Language Model Transformer
class LanguageExpert:
    def __init__(self, config):
        self.config = config
        self.language_model_transformer = self.LanguageModelTransformer(config, tokenizer=config.tokenizer)
        self.splash = self.SPLASH(config)
        self.dpo = self.DPO(config, self.language_model_transformer)

    ###############################

    # SPLASH
    class SPLASH(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.vocab_size = config.vocab_size
            self.embed_size = config.embed_size
            self.heads = config.heads
            self.sequence_length = config.sequence_length
            self.projection_dim = config.projection_dim
            self.partition_size = config.partition_size
            self.head_dim = self.embed_size // self.heads

            self.embedding = nn.Embedding(config.vocab_size, config.embed_size)
            self.values = nn.Linear(config.head_dim, config.head_dim, bias=False)
            self.keys = nn.Linear(config.head_dim, config.projection_dim, bias=False)
            self.queries = nn.Linear(config.head_dim, config.projection_dim, bias=False)
            self.value_projection = nn.Linear(config.head_dim, config.projection_dim//2)  # To project values to match dimensions
            self.ponder = nn.Linear(config.partition_size, 1, bias=True)
            self.sigmoid = nn.Sigmoid()
            #self.final_projection = nn.Linear(config.heads * (config.projection_dim // 2), config.vocab_size)
            self.final_projection = nn.Linear(config.heads * (config.projection_dim // 2), config.embed_size)

        def random_projection(self, matrix, k):
            """Random projection to reduce dimensionality of matrix to k dimensions."""
            random_matrix = torch.randn(matrix.size(-1), k, device=matrix.device)
            return torch.matmul(matrix, random_matrix)

        def cur_decomposition(self, matrix, projection_dim):
            """Applies CUR decomposition with C matrix dimension aligned to projection dimension."""
            batch_size, seq_length, heads, dim = matrix.shape
            k = min(projection_dim // 2, dim, seq_length)
            C = torch.zeros(batch_size, seq_length, heads, k, device=matrix.device)
            R = torch.zeros(batch_size, k, heads, dim, device=matrix.device)

            for b in range(batch_size):
                for h in range(heads):
                    # Using torch.randperm for random indices in PyTorch
                    col_indices = torch.randperm(dim, device=matrix.device)[:k]
                    row_indices = torch.randperm(seq_length, device=matrix.device)[:k]
                    C[b, :, h] = matrix[b, :, h, col_indices]
                    R[b, :, h] = matrix[b, row_indices, h]
            return C, R

        def forward(self, input_ids):
            if input_ids.dim() != 2 or input_ids.dtype != torch.long:
                raise ValueError(f"input_ids must be a 2D tensor of long integers, got shape {input_ids.shape} and dtype {input_ids.dtype}")
            # Check for out-of-range token IDs right before embedding call
            if input_ids.max() >= self.vocab_size:
                print(f"Error: Max token ID {input_ids.max().item()} exceeds vocab size {self.vocab_size}")

            print(f"Before embedding: input_ids.shape={input_ids.shape}, max input_id={input_ids.max().item()}, vocab_size={self.vocab_size}")

            x = self.embedding(input_ids.long())            
            print(f"After embedding: {x.shape}")
            N, seq_length, _ = x.shape
            x_reshaped = x.view(N, seq_length, self.heads, self.head_dim)
            print(f"x_reshaped shape: {x_reshaped.shape}")

            values = self.values(x_reshaped)
            queries = self.random_projection(self.queries(x_reshaped), self.projection_dim // 2)
            keys = self.random_projection(self.keys(x_reshaped), self.projection_dim // 2 )
            print("values :", values.shape)
            print("queries :", queries.shape)
            print("keys :", keys.shape)

            attention_scores = torch.zeros(N, self.heads, seq_length, self.projection_dim // 2, device=x.device)
            print(f"attention_scores after random projection: {attention_scores.shape}")

            for i in range(0, seq_length, self.partition_size):
                print(f"PARTITION START")
                partition_start = i
                partition_end = min(i + self.partition_size, seq_length)
                keys_part = keys[:, partition_start:partition_end, :, :]
                queries_part = queries[:, partition_start:partition_end, :, :]

                C_keys, R_queries = self.cur_decomposition(keys_part, self.projection_dim)
                print("C_keys shape before return:", C_keys.shape)

                ponder_scores = torch.zeros(N, self.heads, partition_end - partition_start, 1, device=x.device)
                print(f"Partition Start {i}, Partition End {partition_end} , ponder_scores: {ponder_scores.shape}")

                for h in range(self.heads):
                    #print(f"HEADS START")
                    head_queries = queries_part[:, :, h, :]
                    #print(f"head_queries: {head_queries.shape}")
                    head_ponder_scores = self.sigmoid(self.ponder(head_queries))
                    #print(f"head_ponder_scores: {head_ponder_scores.shape}")
                    ponder_scores[:, h, :, 0] = head_ponder_scores.squeeze(-1)

                # Correctly expand ponder_scores without adding an unnecessary dimension
                print("BEFORE 1ST EINSUM:")
                ponder_scores_permuted = ponder_scores.permute(0, 2, 1, 3)  # Move to [2, 128, 8, 1]
                print("ponder_scores_permuted shape:", ponder_scores_permuted.shape) 
                ponder_scores_broadcastable = ponder_scores_permuted.expand(-1, -1, -1, 128)  # Expand to [2, 128, 8, 128]            
                print("ponder_scores_broadcastable shape:", ponder_scores_broadcastable.shape) 
                print("queries_part shape:", queries_part.shape) 
                print("C_keys shape:", C_keys.shape)
                energy = torch.einsum('bnhd,bnhk->bnhd', queries_part, C_keys)
                attention_weights = F.softmax(energy, dim=-1)
                print("AFTER 1ST EINSUM:")
                print("energy shape:", energy.shape) 
                print("attention_weights shape:", attention_weights.shape)
                attention = attention_weights * ponder_scores_broadcastable
                print("attention shape:", attention.shape)
                attention_corrected = attention.permute(0, 2, 1, 3)
                attention_scores[:, :, partition_start:partition_end, :] = attention_corrected

            values = values.permute(0, 2, 1, 3)  # Swap heads and seq_length to bring heads next to head_dim
            print("values shape:", values.shape)
            values = values.reshape(-1, self.head_dim)  # Flatten to [N*heads*seq_length, head_dim] for linear layer
            print("values.reshape(-1, self.head_dim) shape:", values.shape)
            projected_values = self.value_projection(values)  # Now [N*heads*seq_length, projection_dim / 2]
            print("self.value_projection(values) shape:", projected_values.shape)
            projected_values = projected_values.view(N, self.heads, seq_length, self.projection_dim // 2)
            print("projected_values shape:", projected_values.shape)

            print(f"2ND EINSUM")
            # Combine attention_scores and projected_values then pass through the final linear layer
            out = torch.einsum('bnhp,bnhp->bnhp', attention_scores, projected_values)
            print("out shape after einsum:", out.shape)

            # Correct reshaping: Flatten batch and sequence length dimensions, keep the last two dimensions for projection
            out = out.reshape(-1, self.heads * (self.projection_dim // 2))
            print("out reshaped for projection:", out.shape)

            # Ensure the final_projection layer matches the flattened shape expected after reshaping
            # Assuming final_projection is defined as nn.Linear(self.heads * (self.projection_dim // 2), vocab_size)
            out = self.final_projection(out)
            print(f"out after final_projection:", out.shape)

            # At this point, out should have a shape of [batch_size * sequence_length, vocab_size], ready for loss calculation
            return out


    # LORA
    class LORALayer(nn.Module):
        def __init__(self, config, input_dim, output_dim):
            super(LanguageExpert.LORALayer, self).__init__()
            self.rank = config.rank
            self.alpha = config.alpha

            self.weight = nn.Parameter(torch.Tensor(output_dim, input_dim))
            self.bias = nn.Parameter(torch.Tensor(output_dim))

            self.A = nn.Parameter(torch.Tensor(input_dim, self.rank))
            self.B = nn.Parameter(torch.Tensor(self.rank, output_dim))

            self.reset_parameters()


        def reset_parameters(self):
            nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
            nn.init.zeros_(self.bias)
            nn.init.normal_(self.A, 0, 0.02)
            nn.init.normal_(self.B, 0, 0.02)

        def forward(self, x):
            #print("LORALayer Input Shape:", x.shape)
            
            original_size = x.size()
            batch_size, seq_len, _ = x.shape
            x_flattened = x.reshape(-1, original_size[-1])

            # Compute lora_adjustment for each input in the batch
            lora_adjustment = self.alpha * (x_flattened @ self.A) @ self.B
            lora_adjustment = lora_adjustment.reshape(batch_size, seq_len, -1)

            #print("Adjusted lora_adjustment Shape:", lora_adjustment.shape)
            #print("self.weight Shape:", self.weight.shape)
            
            # Apply linear transformation to x_flattened
            x_transformed = nn.functional.linear(x_flattened, self.weight, self.bias)
            x_transformed = x_transformed.reshape(batch_size, seq_len, -1)

            # Add lora_adjustment to the transformed x
            x = x_transformed + lora_adjustment
            #print("LORALayer Output Shape:", x.shape)

            return x

    # QLORA
    class QLORALayer(nn.Module):
        def __init__(self, config, input_dim, output_dim):
            super(LanguageExpert.QLORALayer, self).__init__()
            self.rank = config.rank
            self.alpha = config.alpha
            self.quantization_bits = config.quantization_bits

            self.weight = nn.Parameter(torch.Tensor(output_dim, input_dim))
            self.bias = nn.Parameter(torch.Tensor(output_dim))

            self.A = nn.Parameter(torch.Tensor(input_dim, self.rank))
            self.B = nn.Parameter(torch.Tensor(self.rank, output_dim))

            self.dropout = nn.Dropout(config.dropout)
            self.layer_norm = nn.LayerNorm(output_dim)

            self.reset_parameters()


        def reset_parameters(self):
            nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
            nn.init.zeros_(self.bias)
            nn.init.normal_(self.A, 0, 0.02)
            nn.init.normal_(self.B, 0, 0.02)

        def quantize(self, x, num_bits):
            # Implement a simple quantization method
            scale = x.abs().max()
            x_quantized = torch.round(x / scale * (2**num_bits - 1))
            return x_quantized, scale

        def forward(self, x):
            #print("QLORALayer Input Shape:", x.shape)
            original_size = x.size()
            batch_size, seq_len, _ = x.shape
            x_flattened = x.reshape(-1, original_size[-1])

            A_quantized, scale_A = self.quantize(self.A, self.quantization_bits)
            B_quantized, scale_B = self.quantize(self.B, self.quantization_bits)

            # Compute lora_adjustment for each input in the batch
            lora_adjustment = self.alpha * (x_flattened @ (A_quantized / scale_A)) @ (B_quantized / scale_B)
            lora_adjustment = lora_adjustment.reshape(batch_size, seq_len, -1)
            lora_adjustment = self.dropout(lora_adjustment)
            #print("Adjusted lora_adjustment Shape:", lora_adjustment.shape)
            #print("self.weight Shape:", self.weight.shape)

            # Apply linear transformation to x_flattened
            x_transformed = nn.functional.linear(x_flattened, self.weight, self.bias)
            x_transformed = x_transformed.reshape(batch_size, seq_len, -1)

            # Add lora_adjustment to the transformed x
            x = x_transformed + lora_adjustment
            x = self.layer_norm(x)

            #print("QLORALayer Output Shape:", x.shape)

            return x
        
        def update_alpha(self, new_alpha):
            """
            Update the alpha scaling factor.
            """
            self.alpha = new_alpha
    
    class LanguageModelDecoder(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.config = config  # Already stored
            self.layers = nn.ModuleList([
                LanguageExpert.TransformerBlock(config) for _ in range(config.num_layers)
            ])
            self.fc_out = nn.Linear(config.embed_size, config.vocab_size)
            self.dropout = nn.Dropout(config.dropout)

        def forward(self, x):
            # Ensure x is token IDs for SPLASH
            x = x.to(dtype=torch.long)
            # Note: No need to apply dropout here as x are token IDs

            for layer in self.layers:
                x = layer(x)  # SPLASH is called within TransformerBlock

            # Apply dropout after all transformer blocks and before the final projection layer
            x = self.dropout(x)

            out = self.fc_out(x)  # Final projection from transformer block output to vocab size

            return out


    class TransformerBlock(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.splash = LanguageExpert.SPLASH(config=config)
            self.norm1 = nn.LayerNorm(config.embed_size)
            self.norm2 = nn.LayerNorm(config.embed_size)
            self.feed_forward = nn.Sequential(
                nn.Linear(config.embed_size, config.forward_expansion * config.embed_size),
                nn.ReLU(),
                nn.Linear(config.forward_expansion * config.embed_size, config.embed_size),
            )
            self.dropout = nn.Dropout(config.dropout)

        def forward(self, input_ids):
            # SPLASH already expects token IDs and handles embedding
            attention_output = self.splash(input_ids)
            attention_output = self.dropout(self.norm1(attention_output))
            forward_output = self.feed_forward(attention_output)
            output = self.dropout(self.norm2(forward_output + attention_output))
            return output
        
    class LanguageModelTransformer(nn.Module):
        def __init__(self, config, tokenizer):
            super().__init__()
            self.config = config
            self.tokenizer = tokenizer
            self.splash = LanguageExpert.SPLASH(config)
            self.decoder = LanguageExpert.LanguageModelDecoder(config)


        def forward(self, input_ids=None, embeddings=None, attention_mask=None):
            if input_ids is not None:
                # Explicitly cast input_ids to long integers
                input_ids = input_ids.long()
                embeddings = self.splash(input_ids)
            elif embeddings is None:
                raise ValueError("Either input_ids or embeddings must be provided.")
            
            # Continue processing with embeddings
            # Now, embeddings are passed to the decoder or subsequent layers
            out = self.decoder(embeddings)  # Adjust your decoder to accept embeddings
            return out

        def make_trg_mask(self, trg):
            N, trg_len = trg.shape
            trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(N, 1, trg_len, trg_len).to(trg.device)
            return trg_mask

        def toggle_qlora(self, use_qlora: bool):
            self.decoder.toggle_qlora(use_qlora)

        def generate_response(self, input_ids, attention_mask):
            logits = self.forward(input_ids=input_ids, attention_mask=attention_mask)
            probabilities = F.softmax(logits, dim=-1)
            predicted_token_id = torch.argmax(probabilities, dim=-1)
            predicted_tokens = [self.tokenizer.convert_ids_to_tokens(idx.item()) for idx in predicted_token_id]
            response = self.tokenizer.convert_tokens_to_string(predicted_tokens)
            return response

    
    ##################################################
    # Tokenizer

    class TrieNode:
        def __init__(self):
            self.children = {}
            self.token_id = None
            self.frequency = 0
            self.failure_link = None
            self.is_end = False  # Add is_end attribute to mark the end of a word
            self.token = None  # Add token attribute to store the token associated with the node


    class Trie:
        def __init__(self, unk_token_id=0):
            self.root = LanguageExpert.TrieNode()
            self.unk_token_id = unk_token_id

        def insert(self, token, token_id, frequency):
            node = self.root
            for char in token:
                if char not in node.children:
                    node.children[char] = LanguageExpert.TrieNode()
                node = node.children[char]
            node.token_id = token_id
            node.frequency = frequency

        def find_subwords(self, token):
            """Finds the most probable subwords based on frequency."""
            node = self.root
            best_subwords = []

            def dfs(current_node, subword='', collected_subwords=[]):
                if current_node.token_id is not None:
                    # Update to correctly calculate total_frequency based on the structure of collected_subwords
                    total_frequency = sum(n.frequency for _, _, n in collected_subwords) + current_node.frequency
                    probability = current_node.frequency / total_frequency if total_frequency else 0
                    collected_subwords.append((subword, probability, current_node))

                for char, next_node in current_node.children.items():
                    dfs(next_node, subword + char, list(collected_subwords))  # Create a copy of the list to avoid shared state

            dfs(node)
            best_subwords = sorted(best_subwords, key=lambda x: x[1], reverse=True)
            return [subword for subword, _, _ in best_subwords][:5] or [self.unk_token_id]


        def compute_failure_links(self):
            root = self.root
            root.failure_link = root  # Root's failure link points to itself
            queue = [root]

            while queue:
                current_node = queue.pop(0)

                for char, child_node in current_node.children.items():
                    queue.append(child_node)

                    # Follow failure link to find the longest suffix for the child_node
                    failure_candidate = current_node.failure_link
                    while failure_candidate != root and char not in failure_candidate.children:
                        failure_candidate = failure_candidate.failure_link
                    child_node.failure_link = failure_candidate.children.get(char, root)


    class SimpleSentencePiece:
        def __init__(self, model_type="bpe", vocab_size=30522):
            self.vocab = {}
            self.id_to_subword = {}
            self.unk_token = "[UNK]"
            self.unk_token_id = 0
            self.vocab_size = vocab_size
            self.model = None if model_type == "bpe" else None
            self.model_type = model_type

        def train(self, text):
            if self.model_type == "bpe":
                self.model = LanguageExpert.BPE(num_merges=self.vocab_size, unk_token_id=self.unk_token_id)
                self.model.train(text)
                self.vocab = self.model.vocab
                self.id_to_subword = {i: word for word, i in self.vocab.items()}
            else:
                raise NotImplementedError(f"Model type {self.model_type} not supported yet.")

        def encode(self, text):
            text = self.preprocess_text(text)  # Preprocess text before encoding
            if not self.model:
                raise ValueError("Model has not been trained yet.")
            encoded = self.model.encode(text)
            #print(f"Encoded: {encoded[:10]}")  # Print first 10 encoded tokens
            return encoded

        def decode(self, ids):
            if not self.id_to_subword:
                raise ValueError("Vocabulary is empty. Ensure the model is trained first.")
            text = " ".join([self.id_to_subword.get(id_, self.unk_token) for id_ in ids])
            text = text.replace(" </w>", "").replace("</w>", " ").strip()
            return text

        def preprocess_text(self, text):
            # Convert text to lowercase to ensure case insensitivity
            text = text.lower()
            # Optionally, handle punctuation by adding spaces around it for better tokenization
            text = re.sub(r'([.,!?()])', r' \1 ', text)
            # Replace multiple spaces with a single space
            text = re.sub(r'\s+', ' ', text)
            # Trim leading and trailing spaces
            text = text.strip()
            return text
        
        def save_model(self, filepath):
            model_data = {
                'vocab': self.vocab,
                'id_to_subword': self.id_to_subword,
                'model_type': self.model_type,
                'vocab_size': self.vocab_size,
                # Potentially include other relevant attributes
            }
            # Save the high-level tokenizer settings
            with open(filepath, 'w') as f:
                json.dump(model_data, f)
            
            # Now save the BPE model specifically
            if self.model_type == "bpe" and self.model:
                self.model.save_model(filepath + "_bpe")

        def load_model(self, filepath):
            with open(filepath, 'r') as f:
                model_data = json.load(f)
            
            self.vocab = model_data['vocab']
            self.id_to_subword = model_data['id_to_subword']
            self.model_type = model_data['model_type']
            self.vocab_size = model_data['vocab_size']
            
            # Assuming model_type is still "bpe", we now load the BPE model
            if self.model_type == "bpe":
                self.model = LanguageExpert.BPE(self.vocab_size, self.unk_token_id)
                self.model.load_model(filepath + "_bpe")

    class BPE:
        def __init__(self, num_merges=100, unk_token_id=0):  # Accept unk_token_id parameter
            self.vocab = {}
            self.merges = []
            self.num_merges = num_merges
            self.unk_token_id = unk_token_id  # Store the unknown token ID

        def train(self, text):
            words = re.findall(r'\w+|[^\w\s]', text, re.UNICODE)
            vocab = collections.Counter(words)
            vocab = {word + '</w>': count for word, count in vocab.items()}
            
            for _ in range(self.num_merges):  # Use the num_merges from the instance variable
                pairs = self.get_stats(vocab)
                if not pairs:
                    break
                best = max(pairs, key=pairs.get)
                vocab = self.merge_vocab(best, vocab)
                self.merges.append(best)

            self.vocab = {word: i for i, word in enumerate(vocab.keys())}

        @staticmethod
        def get_stats(vocab):
            pairs = collections.defaultdict(int)
            for word, freq in vocab.items():
                symbols = word.split()
                for i in range(len(symbols)-1):
                    pairs[symbols[i], symbols[i+1]] += freq
            return pairs

        @staticmethod
        def merge_vocab(pair, vocab):
            v_out = {}
            bigram = re.escape(' '.join(pair))
            p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
            for word in vocab:
                w_out = p.sub(''.join(pair), word)
                v_out[w_out] = vocab[word]
            return v_out

        def encode(self, text):
            """Encode text into subwords using learned BPE merges."""
            encoded_tokens = []
            for word in re.findall(r'\w+|[^\w\s]', text, re.UNICODE):
                word += '</w>'
                subwords = [word]  # Start with the entire word as one subword
                for merge in self.merges:
                    new_subwords = []
                    for subword in subwords:
                        # If the merge is in subword, split it; otherwise, keep it as is
                        if ' '.join(merge) in subword:
                            new_subwords.extend(subword.replace(' '.join(merge), ''.join(merge)).split(' '))
                        else:
                            new_subwords.append(subword)
                    subwords = new_subwords
                encoded_tokens.extend(subwords)
            return [self.vocab.get(token, self.unk_token_id) for token in encoded_tokens]
        
            # New method to save trained model
        def save_model(self, filepath):
            bpe_data = {
                'merges': self.merges,
                'vocab': self.vocab,
                'num_merges': self.num_merges,
                # Include other attributes as needed
            }
            with open(filepath, 'w') as f:
                json.dump(bpe_data, f)

        def load_model(self, filepath):
            with open(filepath, 'r') as f:
                bpe_data = json.load(f)
            
            self.merges = bpe_data['merges']
            self.vocab = bpe_data['vocab']
            self.num_merges = bpe_data['num_merges']


    class WordPiece:
        def __init__(self, vocab, unk_token_id=0, unk_token="[UNK]"):
            self.vocab = vocab
            self.unk_token_id = unk_token_id
            self.unk_token = unk_token  # Define the unknown token
            self.root = self.build_trie(vocab)
            self.id_to_token = {id_: token for token, id_ in vocab.items()}  # Inverse mapping
            self.compute_failure_links(self.root)
            print("Trie built successfully.")

        def convert_ids_to_tokens(self, ids):
            """
            Convert a list of token ids back to their string token representations.
            """
            return [self.id_to_token.get(id_, self.unk_token) for id_ in ids]

        # Add debug prints to build_trie to confirm structure
        def build_trie(self, vocab):
            root = LanguageExpert.TrieNode()
            for token in vocab:
                node = root
                for char in token:
                    if char not in node.children:
                        node.children[char] = LanguageExpert.TrieNode()
                    node = node.children[char]
                node.is_end = True
                node.token = token
            print("Trie Construction Completed Successfully")
            return root


        def compute_failure_links(self, root):
            queue = [root]
            while queue:
                current_node = queue.pop(0)
                for char, child_node in current_node.children.items():
                    failure_node = current_node.failure_link
                    while failure_node and char not in failure_node.children:
                        failure_node = failure_node.failure_link
                    child_node.failure_link = failure_node.children[char] if failure_node else root
                    queue.append(child_node)

        # Improved debug prints in tokenize method
                    
        def tokenize(self, text):
            # Preprocess input text
            text = self.preprocess_text(text)
            node = self.root
            token_ids = []  # Will store token IDs instead of tokens
            i = 0

            while i < len(text):
                char = text[i]
                if char == ' ':
                    node = self.root
                    i += 1
                    continue

                if char not in node.children:
                    if node != self.root and node.token is not None:
                        # Convert found token to its ID
                        token_id = self.vocab.get(node.token, self.unk_token_id)
                        token_ids.append(token_id)
                        node = self.root  # Reset to root
                        continue
                    else:
                        # Append unknown token ID
                        token_ids.append(self.unk_token_id)
                        i += 1
                        continue

                node = node.children[char]
                if node.is_end:
                    if i + 1 == len(text) or text[i + 1] == ' ':
                        # Convert found token to its ID
                        token_id = self.vocab.get(node.token, self.unk_token_id)
                        token_ids.append(token_id)
                        node = self.root

                i += 1

            #print(f"Token IDs: {token_ids[:10]}")
            return token_ids

        def preprocess_text(self, text):
            # Convert text to lowercase to ensure case insensitivity
            text = text.lower()

            # Optionally, handle punctuation by adding spaces around it for better tokenization
            # This depends on how your vocabulary handles punctuation
            text = re.sub(r'([.,!?()])', r' \1 ', text)

            # Replace multiple spaces with a single space
            text = re.sub(r'\s+', ' ', text)

            # Trim leading and trailing spaces
            text = text.strip()

            return text



###############################
    # DPO
    class DPO(nn.Module):
        def __init__(self, config, language_model):
            super(LanguageExpert.DPO, self).__init__()
            self.language_model = language_model
            self.device = config.device
            self.projection = nn.Linear(config.vocab_size, config.embed_size)
            self.classifier = nn.Linear(config.embed_size, 2)

        def forward(self, input_ids_question, input_ids_chosen=None, input_ids_rejected=None, labels=None):
            combined_input_ids = torch.cat((input_ids_question, input_ids_chosen, input_ids_rejected), dim=1)

            # Assuming combined_input_ids has shape [batch_size, sequence_length]
            logits = self.language_model(combined_input_ids)  # Output shape: [batch_size, sequence_length, vocab_size]

            # Project logits to embedding space before pooling
            projected_logits = self.projection(logits)  # New shape: [batch_size, sequence_length, embed_size]
            
            # Apply global mean pooling across the sequence length dimension
            pooled_logits = projected_logits.mean(dim=1)  # New shape: [batch_size, embed_size]

            # Pass the pooled representation through the classifier
            predictions = self.classifier(pooled_logits)  # New shape: [batch_size, 2]

            # Calculate loss if labels are provided
            loss = None
            if labels is not None:
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(predictions, labels)

            return predictions, loss


    ###############################
    # RAG
    
    class PositionalEncoding(nn.Module):
        def __init__(self, config):
            super(LanguageExpert.PositionalEncoding, self).__init__()
            self.d_model = config.embed_size
            self.max_len = config.max_length
            pe = torch.zeros(self.max_len, self.d_model)
            position = torch.arange(0, self.max_len, dtype=torch.float).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, self.d_model, 2).float() * -(math.log(10000.0) / self.d_model))
            pe[:, 0::2] = torch.sin(position * div_term)
            pe[:, 1::2] = torch.cos(position * div_term)
            pe = pe.unsqueeze(0)
            self.register_buffer('pe', pe)

        def forward(self, x):
            # x: Tensor of shape [Batch Size, Sequence Length, Embedding Dimension]
            # Adjust positional encoding to match the input size and device
            pe = self.pe[:, :x.size(1)]
            # Assuming x is on the correct device, pe will be automatically aligned to the same device
            return pe


    class AdaptiveDropoutLayer(nn.Module):
        def __init__(self, config):
            super(LanguageExpert.AdaptiveDropoutLayer, self).__init__()
            self.log_alpha = nn.Parameter(torch.tensor(math.log(config.dropout / (1 - config.dropout))).float())

        def forward(self, x):
            p = torch.sigmoid(self.log_alpha)
            # Convert p from a tensor to a float
            p_value = p.item()  # This extracts the scalar value as a Python float
            return nn.functional.dropout(x, p=p_value, training=self.training)


    class AdaptiveEmbeddingLayer(nn.Module):
        def __init__(self, config):
            super(LanguageExpert.AdaptiveEmbeddingLayer, self).__init__()
            self.vocab = config.wordpiece_vocab  # Ensure this is part of your config
            self.vocab_size = config.vocab_size
            self.freq_threshold = config.freq_threshold
            self.large_embed_dim = config.large_embed_dim
            self.small_embed_dim = config.small_embed_dim
            self.max_seq_len = config.max_length
            self.split_vocab(self.vocab, self.freq_threshold)
            self.frequent_embeddings = nn.Embedding(len(self.frequent_vocab), self.large_embed_dim)
            self.infrequent_embeddings = nn.Embedding(len(self.infrequent_vocab), self.small_embed_dim)
            self.infrequent_projection = nn.Linear(self.small_embed_dim, self.large_embed_dim)
            self.positional_embeddings = LanguageExpert.PositionalEncoding(config)



        def split_vocab(self, vocab, freq_threshold):
            token_counts = [(token, count) for token, count in vocab.items()]
            token_counts.sort(key=lambda x: -x[1])  # Sort by frequency
            split_point = next(i for i, (_, count) in enumerate(token_counts) if count < freq_threshold)
            
            self.frequent_vocab = dict(token_counts[:split_point])
            self.infrequent_vocab = dict(token_counts[split_point:])

        def forward(self, token_ids):
            device = token_ids.device
            seq_len = token_ids.size(1)
            batch_size = token_ids.size(0)  # Obtain batch size from token_ids tensor

            # Initialize embeddings tensor
            embeddings = torch.zeros(token_ids.shape[0], seq_len, self.large_embed_dim, device=device)

            # Map token_ids to indices for frequent and infrequent vocab
            frequent_indices = torch.zeros_like(token_ids)
            infrequent_indices = torch.zeros_like(token_ids)
            
            for token_id, index in self.vocab.items():
                mask = token_ids == token_id
                if token_id in self.frequent_vocab:
                    # Map to index in frequent_vocab
                    frequent_indices[mask] = self.frequent_vocab[token_id]
                elif token_id in self.infrequent_vocab:
                    # Map to index in infrequent_vocab
                    infrequent_indices[mask] = self.infrequent_vocab[token_id]

            # Create masks for frequent and infrequent tokens
            frequent_mask = frequent_indices > 0
            infrequent_mask = infrequent_indices > 0

            # Embed frequent tokens
            if frequent_mask.any():
                frequent_embeddings = self.frequent_embeddings(frequent_indices[frequent_mask])
                embeddings[frequent_mask] = frequent_embeddings

            # Embed and project infrequent tokens
            if infrequent_mask.any():
                infrequent_embeddings = self.infrequent_embeddings(infrequent_indices[infrequent_mask])
                infrequent_embeddings_projected = self.infrequent_projection(infrequent_embeddings)
                embeddings[infrequent_mask] = infrequent_embeddings_projected

            # Apply positional embeddings
            position_ids = torch.arange(0, seq_len, dtype=torch.long, device=device).unsqueeze(0)
            position_embeddings = self.positional_embeddings(position_ids)  # Generate for seq_len

            # Ensure positional embeddings are broadcastable to the embeddings tensor
            if position_embeddings.size(0) != batch_size:
                position_embeddings = position_embeddings.expand(batch_size, -1, -1)

            print(f"Embeddings shape: {embeddings.shape}")
            print(f"Positional embeddings shape: {position_embeddings.shape}")
            embeddings += position_embeddings

            return embeddings


    class DPRContextEncoder(nn.Module):
        def __init__(self, config):
            super(LanguageExpert.DPRContextEncoder, self).__init__()
            self.wordpiece_tokenizer = config.wordpiece_tokenizer
            self.embedding_layer = LanguageExpert.AdaptiveEmbeddingLayer(config)
            self.attention_layer = LanguageExpert.SPLASH(config=config).to(config.device)
            self.dropout = LanguageExpert.AdaptiveDropoutLayer(config)
 

        def forward(self, input_ids, attention_mask):
            embeddings = self.embedding_layer(input_ids)
            attention_output = self.attention_layer(embeddings, attention_mask=attention_mask)
            attention_output = self.dropout(attention_output)

            # Mean pooling across the sequence length dimension
            pooled_output = attention_output.mean(dim=1)

            return pooled_output


    class DPRQuestionEncoder(nn.Module):
        def __init__(self, config):
            super(LanguageExpert.DPRQuestionEncoder, self).__init__()
            self.wordpiece_tokenizer = config.wordpiece_tokenizer
            self.embedding_layer = LanguageExpert.AdaptiveEmbeddingLayer(config)
            self.attention_layer = LanguageExpert.SPLASH(config=config).to(config.device)
            self.dropout = LanguageExpert.AdaptiveDropoutLayer(config)

        def forward(self, input_ids, attention_mask):
            embeddings = self.embedding_layer(input_ids)
            attention_output = self.attention_layer(embeddings, attention_mask=attention_mask)
            attention_output = self.dropout(attention_output)

            # Mean pooling across the sequence length dimension
            pooled_output = attention_output.mean(dim=1)

            return pooled_output


    class TransformerRAG(nn.Module):
        def __init__(self, config):
            super(LanguageExpert.TransformerRAG, self).__init__()
            self.config = config
            self.context_encoder = LanguageExpert.DPRContextEncoder(config).to(config.device)
            self.language_model = LanguageExpert.LanguageModelTransformer(config, config.wordpiece_tokenizer).to(config.device)
            self.question_encoder = LanguageExpert.DPRQuestionEncoder(config).to(config.device)
            self.tokenizer = config.wordpiece_tokenizer


        def forward(self, context_texts, question_input_ids, question_attention_mask, question_text):
            if question_input_ids.max() >= config.vocab_size:
                raise ValueError("question_input_ids contain ID(s) beyond the tokenizer's vocabulary size")
            
            aggregated_context_embeddings = []
            for context_list in context_texts:
                if not all(isinstance(context, dict) for context in context_list):
                    raise TypeError("Each item in context_texts must be a list of tokenized context dictionaries")
                
                aggregated_context_embedding = torch.zeros(config.embedding_dim, device=config.device)
                for context in context_list:
                    context_input_ids = torch.tensor(context['input_ids']).unsqueeze(0).to(config.device)  # Add unsqueeze(0) for batch dimension
                    context_attention_mask = torch.tensor(context['attention_mask']).unsqueeze(0).to(config.device)  # Add unsqueeze(0) for batch dimension
                    print(f"context_input_ids shape: {context_input_ids.shape}")
                    print(f"context_attention_mask shape: {context_attention_mask.shape}")

                    context_embedding = self.context_encoder(context_input_ids, context_attention_mask)
                    aggregated_context_embedding += context_embedding.mean(dim=0)
                
                aggregated_context_embeddings.append(aggregated_context_embedding / len(context_list))
            
            question_input_ids = question_input_ids.to(config.device).long()
            question_attention_mask = question_attention_mask.to(config.device).long()
            question_embeddings = self.question_encoder(input_ids=question_input_ids, attention_mask=question_attention_mask)
            
            cos_sim = torch.nn.CosineSimilarity(dim=1)
            #similarities = [cos_sim(question_embeddings, context_emb.squeeze(0)) for context_emb in aggregated_context_embeddings]
            similarities = [cos_sim(question_embeddings.unsqueeze(0), context_emb.unsqueeze(0)) for context_emb in aggregated_context_embeddings]

            most_relevant_context_idx = torch.argmax(torch.tensor(similarities, device=config.device))
            
            combined_input = question_text + " " + context_texts[most_relevant_context_idx]
            tokenized_combined_input = self.tokenizer(combined_input, return_tensors="pt", padding=True, truncation=True, max_length=512)
            tokenized_combined_input = {k: v.to(config.device) for k, v in tokenized_combined_input.items()}
            response_logits = self.language_model(**tokenized_combined_input)
            probabilities = F.softmax(response_logits.logits, dim=-1)
            predicted_token_ids = torch.argmax(probabilities, dim=-1)
            predicted_tokens = self.tokenizer.convert_ids_to_tokens(predicted_token_ids[0])
            #response = self.tokenizer.convert_tokens_to_string(predicted_tokens)
            response = " ".join(predicted_tokens).replace(" </w>", "").replace("</w>", " ").strip()
            
            return response

        @staticmethod
        def extract_text_from_pdf(file_path):
            text = ""
            with fitz.open(file_path) as doc:
                for page in doc:
                    text += page.get_text()
            return text

        @staticmethod
        def split_into_chunks(text, chunk_size):
            return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]

        @staticmethod
        def preprocess_text(text, tokenizer, max_length=512):
            chunk_size = max_length - 50
            text_chunks = LanguageExpert.TransformerRAG.split_into_chunks(text, chunk_size)
            processed_chunks = []
            for chunk in text_chunks:
                # Tokenize the chunk using the WordPiece tokenizer
                token_ids = tokenizer.tokenize(chunk)
                
                # Manual padding and attention mask creation
                attention_mask = [1] * len(token_ids)
                # Padding: Extend token_ids and attention_mask to max_length
                while len(token_ids) < max_length:
                    token_ids.append(tokenizer.unk_token_id)  # Use unk_token_id for padding
                    attention_mask.append(0)  # Padding token has attention mask 0
                
                # Ensure token_ids and attention_mask are not longer than max_length
                token_ids = token_ids[:max_length]
                attention_mask = attention_mask[:max_length]
                
                processed_chunk = {
                    'input_ids': token_ids,
                    'attention_mask': attention_mask
                }
                processed_chunks.append(processed_chunk)
            return processed_chunks

        @staticmethod
        def create_dataset_from_pdfs(pdf_file_paths, tokenizer):
            dataset = []
            for file_path in pdf_file_paths:
                text = LanguageExpert.TransformerRAG.extract_text_from_pdf(file_path)
                processed_text = LanguageExpert.TransformerRAG.preprocess_text(text, tokenizer)                
                dataset.append(processed_text)
            return dataset

        def retrieve_contexts(self, dataset, query_embedding, top_k=5):
            similarity_scores = []
            for context in dataset:
                context_input_ids = context['input_ids'].to(self.config.device)
                context_attention_mask = context['attention_mask'].to(self.config.device)
                # Use the class's context_encoder
                context_embedding = self.context_encoder(context_input_ids, context_attention_mask)
                similarity = torch.matmul(query_embedding, context_embedding.T)
                similarity_scores.append(similarity.squeeze().item())
            top_k_indices = sorted(range(len(similarity_scores)), key=lambda i: similarity_scores[i], reverse=True)[:top_k]
            top_contexts = [dataset[i] for i in top_k_indices]
            return top_contexts

        def rag_retrieve_and_generate(self, dataset, query):
            # Tokenize the query using the class's tokenizer
            tokenized_query = self.tokenizer(query, return_tensors="pt", padding=True, truncation=True, max_length=512).to(self.config.device)
            input_ids = tokenized_query['input_ids']
            attention_mask = tokenized_query['attention_mask']
            # Use the class's question_encoder
            encoded_query = self.question_encoder(input_ids, attention_mask)
            relevant_contexts = self.retrieve_contexts(dataset, encoded_query)
            # Assuming generate_response is a method of LanguageModelTransformer that accepts tokenized contexts and generates a response
            response = self.language_model.generate_response(relevant_contexts)
            return response

    @staticmethod
    def calculate_new_alpha(current_loss, initial_loss, initial_alpha=1.0, final_alpha=0.1):
        """
        Calculate a new alpha value based on the current loss.
        """
        if current_loss >= initial_loss:
            return initial_alpha  # Keep initial alpha if loss isn't decreasing

        loss_ratio = current_loss / initial_loss
        alpha_range = initial_alpha - final_alpha
        new_alpha = final_alpha + (alpha_range * loss_ratio)
        return new_alpha  
    
    @staticmethod
    def setup_optimizer(model, learning_rate, weight_decay, warmup_steps, total_steps):
        optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

        # Linear warmup with cosine decay
        scheduler = LambdaLR(optimizer, lambda step: min((step + 1) / warmup_steps, 0.5 * (1 + math.cos(math.pi * step / total_steps))))

        return optimizer, scheduler
        
    # DPO Training
    def train_dpo(self, train_loader, optimizer, config, save_path):
            self.train()  # Set the model to training mode
            total_loss = 0

            for batch_idx, batch in enumerate(train_loader):
                input_ids_question = batch['input_ids_question'].to(config.device)
                input_ids_chosen = batch['input_ids_chosen'].to(config.device)
                input_ids_rejected = batch['input_ids_rejected'].to(config.device)
                labels = batch['labels'].to(config.device)
                print(f"train_dpo input_ids_question: {input_ids_question.shape}")
                print(f"train_dpo input_ids_chosen: {input_ids_chosen.shape}")
                print(f"train_dpo input_ids_rejected: {input_ids_rejected.shape}")
                print(f"train_dpo labels: {labels.shape}")

                optimizer.zero_grad()

                # Forward pass
                logit, loss = self.transformer_dpo(input_ids_question, input_ids_chosen, input_ids_rejected, labels)
                print(f"Logits shape: {logit.shape}, Labels shape: {labels.shape}")

                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            average_loss = total_loss / len(train_loader)
            print(f"Training complete. Average Loss: {average_loss}")
            
            # Save the model
            torch.save(self.transformer_dpo.state_dict(), save_path)

            return average_loss

    # DPR Training
    def train_dpr_encoders(self, train_data, context_encoder, question_encoder, optimizer_context, optimizer_question, epochs , context_save_path, question_save_path):
        loss_function = nn.CosineEmbeddingLoss()

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        for epoch in range(epochs):
            total_loss = 0

            for i in range(len(train_data["queries"])):
                query = train_data["queries"][i]
                context = train_data["contexts"][i]

                # Ensure query is a string
                if not isinstance(query, str):
                    raise ValueError("Query must be a string.")
                tokenized_query = tokenizer(query, return_tensors="pt", padding=True, truncation=True, max_length=512)

                # Ensure context is a string
                if isinstance(context, dict):
                    # If context is a dictionary, extract the text content. This is a placeholder and might need adjustment.
                    context_text = context.get("text", "")
                elif isinstance(context, str):
                    context_text = context
                else:
                    raise ValueError("Context must be a string or a dictionary containing a text field.")
                tokenized_context = tokenizer(context_text, return_tensors="pt", padding=True, truncation=True, max_length=512)

                question_embeddings = question_encoder(input_ids=tokenized_query['input_ids'], attention_mask=tokenized_query['attention_mask'])
                context_embeddings = context_encoder(input_ids=tokenized_context['input_ids'], attention_mask=tokenized_context['attention_mask'])

                # The labels tensor should have the same first dimension size as the input tensors
                labels = torch.tensor([1.0] * question_embeddings.size(0), dtype=torch.float).to(question_embeddings.device)

                loss = loss_function(question_embeddings, context_embeddings, labels)

                optimizer_context.zero_grad()
                optimizer_question.zero_grad()
                loss.backward()
                optimizer_context.step()
                optimizer_question.step()

                total_loss += loss.item()

            print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss / len(train_data['queries'])}")
        average_loss = total_loss / len(train_data['queries'])
        torch.save(context_encoder.state_dict(), context_save_path)
        torch.save(question_encoder.state_dict(), question_save_path)
        return (context_encoder, question_encoder), average_loss
       
    # LMT Training
    # Define the function to check token IDs
    @staticmethod
    def check_token_ids_for_embedding(token_ids, vocab_size):
        if token_ids.max() >= vocab_size:
            print(f"Out-of-range token ID found: {token_ids.max()}. Max allowed: {vocab_size - 1}")
            return False
        else:
            print("All token IDs are within the expected range.")
            return True

    def train_language_model_transformer(self, train_loader, device, vocab_size, save_path):
        # Reference the correct attribute for the model
        model = self.language_model_transformer.to(device)
        
        # Define loss function and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-8, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.98)
        num_epochs = 5
        
        initial_loss = None

        # Training loop
        for epoch in range(num_epochs):
            model.train()
            # Assuming you have a method or way to toggle QLORA within your model, if applicable
            # model.toggle_qlora(True)
            total_loss = 0

            for batch_idx, batch in enumerate(train_loader):
                inputs, targets = batch['input_ids'].to(device), batch['labels'].to(device)

                # Check if any token IDs are out-of-range before proceeding
                if not LanguageExpert.check_token_ids_for_embedding(inputs, vocab_size):
                    print("Halting training due to out-of-range token IDs")
                    return None, None  # You might want to handle this situation more gracefully
                print(f"Input device: {inputs.device}, Input shape: {inputs.shape} , Input Type: {inputs.type}")
                print("Model device:", next(self.language_model_transformer.parameters()).device)

                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))

                # Check for NaN in loss
                if math.isnan(loss.item()):
                    print("Encountered NaN loss, stopping training")
                    break

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
                optimizer.step()

                total_loss += loss.item()

                # Set the initial_loss after the first batch of the first epoch
                if initial_loss is None and batch_idx == 0:
                    initial_loss = loss.item()

            scheduler.step()

            # Check for NaN in total_loss
            if math.isnan(total_loss):
                print(f"Epoch {epoch+1}/{num_epochs} stopped due to NaN loss")
                break
            else:
                print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(train_loader)}")

            # Average loss for the epoch
            average_loss = total_loss / len(train_loader)

            # Update alpha at the end of each epoch based on the average loss, if your model uses QLORA or a similar mechanism

        average_loss = total_loss / len(train_loader)
        torch.save(model.state_dict(), save_path)
        return model, average_loss


##################################################################
    
# Test Language Expert
def load_corpus(file_path):
    with open(file_path, 'r', encoding='utf-8') as file:
        texts = [line.strip() for line in file.readlines()]
    return texts

texts = load_corpus("D:\\EXPERT_WEIGHTS\\sample.txt")

num_merges = 100
def adapt_vocab_for_wordpiece(ssp_vocab):
    adapted_vocab = {}
    for token, id_or_freq in ssp_vocab.items():
        if not token.startswith(" ") and not token.endswith("</w>"):
            adapted_token = "##" + token.replace("</w>", "")
        else:
            adapted_token = token.replace("</w>", "")
        adapted_vocab[adapted_token] = id_or_freq
    return adapted_vocab

# Initialize and train the SimpleSentencePiece model with BPE
ssp = LanguageExpert.SimpleSentencePiece(model_type="bpe", vocab_size=30522)
# Assume `texts` is a list of text to train the tokenizer
ssp.train('\n'.join(texts))
wordpiece_vocab = adapt_vocab_for_wordpiece(ssp.vocab)
# Debugging step to ensure vocabulary completeness
def debug_vocab(adapted_vocab):
    print("Sample Vocabulary Check:")
    # Iterate over the first 10 key-value pairs in the adapted vocabulary
    for i, (token, id_or_freq) in enumerate(adapted_vocab.items()):
        print(f"{token}: {id_or_freq}")
        if i >= 9:  # Stop after printing 10 entries
            break
    # Specifically check for subtokens if your tokenizer expects them
    subtokens = [token for token in adapted_vocab.keys() if token.startswith("##")]
    print(f"Found {len(subtokens)} subtokens in vocabulary.")

# Ensure wordpiece_vocab is a list of vocabulary tokens
# debug_vocab(wordpiece_vocab)  # Call this after initializing wordpiece_vocab

# Initialize WordPiece with the adapted vocabulary
wordpiece_tokenizer = LanguageExpert.WordPiece(wordpiece_vocab, unk_token_id=0, unk_token="[UNK]")


class WikiTextDatasetForLM(Dataset):
    def __init__(self, texts, tokenizer, sequence_length):
        self.tokenizer = tokenizer
        self.sequence_length = sequence_length
        self.inputs, self.labels = self.process_texts(texts)
        
    def process_texts(self, texts):
        inputs, labels = [], []
        step_size = 256  # Example step size for overlapping sequences
        for text in texts:
            token_ids = self.tokenizer.tokenize(text)
            for i in range(0, len(token_ids) - self.sequence_length, step_size):  # Adjust step_size for overlap
                inputs.append(token_ids[i:i+self.sequence_length])
                labels.append(token_ids[i+1:i+self.sequence_length+1])
        return torch.tensor(inputs, dtype=torch.long), torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.labels[idx]

####################################################################################

# Initialize configuration
config = ModelConfig()
language_expert = LanguageExpert(config)
####################################################################################
# LMT Training
# Load the wikitext-2 dataset
dataset = load_dataset("wikitext", "wikitext-2-v1", split="train")

def generate_attention_mask(token_ids):
    """Generate an attention mask for the given token IDs."""
    return [1 if token_id != 0 else 0 for token_id in token_ids]

def tokenize_and_prepare_labels(examples):
    token_ids = [wordpiece_tokenizer.tokenize(text) for text in examples["text"]]
    labels = [[wordpiece_tokenizer.unk_token_id] + ids[:-1] for ids in token_ids]  # Shift labels
    return {"input_ids": token_ids, "labels": labels}

# Load dataset and apply tokenization
dataset = load_dataset("wikitext", "wikitext-2-v1", split="train")
tokenized_datasets = dataset.map(tokenize_and_prepare_labels, batched=True, remove_columns=["text"])

def custom_collate_fn(batch):
    batch_input_ids = [torch.tensor(item['input_ids']) for item in batch]
    batch_labels = [torch.tensor(item['labels']) for item in batch]
    input_ids_padded = pad_sequence(batch_input_ids, batch_first=True, padding_value=0)
    labels_padded = pad_sequence(batch_labels, batch_first=True, padding_value=-100)
    attention_masks_padded = (input_ids_padded != 0).long()
    return {'input_ids': input_ids_padded, 'attention_mask': attention_masks_padded, 'labels': labels_padded}

# Initialize DataLoader with the custom collate function
train_loader = DataLoader(tokenized_datasets, batch_size=8, shuffle=True, collate_fn=custom_collate_fn)
# Define save path for the trained model
save_path = 'D:/EXPERT_WEIGHTS/lmt_expert_trained_custom_tokenizer.pth'

# Train the LMT sub-model within the Expert system
trained_model, average_loss = language_expert.train_language_model_transformer(
    train_loader=train_loader, 
    device=config.device, 
    vocab_size=config.vocab_size, 
    save_path=save_path
)

print(f"Training complete. Model saved to {save_path}. Average Loss: {average_loss}")





Trie Construction Completed Successfully
Trie built successfully.


  StockPickler.save(self, obj, save_persistent_id)
  StockPickler.save(self, obj, save_persistent_id)


Map:   0%|          | 0/36718 [00:00<?, ? examples/s]

All token IDs are within the expected range.
Input device: cpu, Input shape: torch.Size([8, 128]) , Input Type: <built-in method type of Tensor object at 0x0000027A15D06510>
Model device: cpu
Before embedding: input_ids.shape=torch.Size([8, 128]), max input_id=2287, vocab_size=2749
After embedding: torch.Size([8, 128, 512])
x_reshaped shape: torch.Size([8, 128, 8, 64])
values : torch.Size([8, 128, 8, 64])
queries : torch.Size([8, 128, 8, 128])
keys : torch.Size([8, 128, 8, 128])
attention_scores after random projection: torch.Size([8, 8, 128, 128])
PARTITION START
C_keys shape before return: torch.Size([8, 128, 8, 128])
Partition Start 0, Partition End 128 , ponder_scores: torch.Size([8, 8, 128, 1])
BEFORE 1ST EINSUM:
ponder_scores_permuted shape: torch.Size([8, 128, 8, 1])
ponder_scores_broadcastable shape: torch.Size([8, 128, 8, 128])
queries_part shape: torch.Size([8, 128, 8, 128])
C_keys shape: torch.Size([8, 128, 8, 128])
AFTER 1ST EINSUM:
energy shape: torch.Size([8, 128, 8, 128]

KeyboardInterrupt: 

In [44]:
import json
import re
import collections

class Tokenizer:
    class TrieNode:
        def __init__(self):
            self.children = {}
            self.token_id = None
            self.frequency = 0
            self.failure_link = None
            self.is_end = False  
            self.token = None  

    class Tokenize:  
        def __init__(self, bpe_vocab_size=30522, wordpiece_vocab_size=30522, unk_token="[UNK]", unk_token_id=0, num_merges=100):
            self.bpe_model = None  # Initialize BPE model
            self.wordpiece_model = None  # Initialize WordPiece model
            self.bpe_vocab_size = bpe_vocab_size
            self.wordpiece_vocab_size = wordpiece_vocab_size
            self.unk_token = unk_token
            self.unk_token_id = unk_token_id
            self.num_merges = num_merges

        def train(self, text):
            # Step 1: Train the BPE model
            self.bpe_model = Tokenizer.BPE(self.bpe_vocab_size, self.unk_token_id) 
            self.bpe_model.train(text)

            # Step 2: Apply BPE encoding to the text
            bpe_encoded_text = " ".join(str(token) for token in self.bpe_model.encode(text))
            print("bpe_encoded_text vocabulary:", bpe_encoded_text)  # Inspect contents

            # Step 3: Build vocabulary for WordPiece from BPE-encoded text
            wordpiece_vocab = self.build_wordpiece_vocab(bpe_encoded_text)
            print("WordPiece vocabulary:", wordpiece_vocab)  # Inspect contents

            # Step 4: Initialize WordPiece model dynamically using vocabulary
            self.wordpiece_model = Tokenizer.WordPiece(wordpiece_vocab, unk_token_id=self.unk_token_id) 

        def encode(self, text):
            """
            Encodes text using the trained BPE and WordPiece models.
            """
            if not self.bpe_model or not self.wordpiece_model:
                raise ValueError("Tokenizer has not been trained yet.")

            # Step 1: Apply BPE encoding
            bpe_tokens = self.bpe_model.encode(text)

            # Step 2: Convert BPE token IDs to tokens
            bpe_text = " ".join(self.bpe_model.decode(bpe_tokens))

            # Step 3: Apply WordPiece encoding
            wordpiece_ids = self.wordpiece_model.tokenize(bpe_text)

            return wordpiece_ids

        def decode(self, ids):
            if not self.bpe_model or not self.wordpiece_model:
                raise ValueError("Tokenizer has not been trained yet.")

            # Step 1: WordPiece decoding 
            wordpiece_text = self.wordpiece_model.decode(ids)

            # Step 2: Convert WordPiece text to BPE tokens
            bpe_tokens = wordpiece_text.split(" ")

            # Step 3: Decode BPE tokens to original text
            bpe_token_ids = [self.bpe_model.vocab.get(token, self.unk_token_id) for token in bpe_tokens]
            decoded_text = self.bpe_model.decode(bpe_token_ids)

            return decoded_text
        
        def build_wordpiece_vocab(self, text):
            """
            Builds a vocabulary for WordPiece from BPE-encoded text.
            """
            words = text.split(" ")  # Split the BPE-encoded text into words
            vocab = collections.Counter(words) 
            return vocab

    class BPE:
        def __init__(self, num_merges, unk_token_id=0):  # Accept unk_token_id parameter
            self.vocab = {}
            self.merges = []
            self.num_merges = num_merges
            self.unk_token_id = unk_token_id  # Store the unknown token ID
            self.unk_token = "[UNK]"  # Add this line to define the unknown token

        def train(self, text):
            words = re.findall(r'\w+|[^\w\s]', text, re.UNICODE)
            vocab = collections.Counter(words)
            vocab = {word + '</w>': count for word, count in vocab.items()}
            print("Initial vocabulary:", vocab)  # Check the initial vocabulary

            for _ in range(self.num_merges):  # Use the num_merges from the instance variable
                pairs = self.get_stats(vocab)
                if not pairs:
                    break
                best = max(pairs, key=pairs.get)
                vocab = self.merge_vocab(best, vocab)
                self.merges.append(best)
                print("Vocabulary after merge:", vocab)  # Check how it evolves

            self.vocab = {word: i for i, word in enumerate(vocab.keys())}

        @staticmethod
        def get_stats(vocab):
            pairs = collections.defaultdict(int)
            for word, freq in vocab.items():
                symbols = word.split()
                for i in range(len(symbols)-1):
                    pairs[symbols[i], symbols[i+1]] += freq
            return pairs

        @staticmethod
        def merge_vocab(pair, vocab):
            v_out = {}
            bigram = re.escape(' '.join(pair))
            p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
            for word in vocab:
                w_out = p.sub(''.join(pair), word)
                v_out[w_out] = vocab[word]
            return v_out

        def encode(self, text):
            """Encode text into subwords using learned BPE merges."""
            encoded_tokens = []
            for word in re.findall(r'\w+|[^\w\s]', text, re.UNICODE):
                word += '</w>'
                subwords = [word]  # Start with the entire word as one subword
                for merge in self.merges:
                    new_subwords = []
                    for subword in subwords:
                        # If the merge is in subword, split it; otherwise, keep it as is
                        if ' '.join(merge) in subword:
                            new_subwords.extend(subword.replace(' '.join(merge), ''.join(merge)).split(' '))
                        else:
                            new_subwords.append(subword)
                    subwords = new_subwords
                encoded_tokens.extend(subwords)
                print("Subwords:", subwords)   # Check subword generation
                print("BPE Vocabulary:", self.vocab) # Ensure the vocabulary is populated
                return [str(self.vocab.get(token, self.unk_token_id)) for token in encoded_tokens]
        
            # New method to save trained model
        def save_model(self, filepath):
            bpe_data = {
                'merges': self.merges,
                'vocab': self.vocab,
                'num_merges': self.num_merges,
                # Include other attributes as needed
            }
            with open(filepath, 'w') as f:
                json.dump(bpe_data, f)

        def load_model(self, filepath):
            with open(filepath, 'r') as f:
                bpe_data = json.load(f)
            
            self.merges = bpe_data['merges']
            self.vocab = bpe_data['vocab']
            self.num_merges = bpe_data['num_merges']

        def decode(self, ids):
            """Decode a list of BPE token IDs back into the original text."""
            decoded_text = ""
            for id_ in ids:
                token = self.vocab.get(id_, self.unk_token)
                decoded_text += token.replace("</w>", "")  # Remove the '</w>' marker
            return decoded_text

    class WordPiece:
        def __init__(self, vocab, unk_token_id=0, unk_token="[UNK]"):
            self.vocab = vocab
            self.unk_token_id = unk_token_id
            self.unk_token = unk_token  # Define the unknown token
            self.root = self.build_trie(vocab)
            self.id_to_token = {id_: token for token, id_ in vocab.items()}  # Inverse mapping
            self.compute_failure_links(self.root)
            print("Trie built successfully.")

        def convert_ids_to_tokens(self, ids):
            """
            Convert a list of token ids back to their string token representations.
            """
            return [self.id_to_token.get(id_, self.unk_token) for id_ in ids]

        # Add debug prints to build_trie to confirm structure
        def build_trie(self, vocab):
            root = Tokenizer.TrieNode()
            for token in vocab:
                node = root
                for char in token:
                    if char not in node.children:
                        node.children[char] = Tokenizer.TrieNode()
                    node = node.children[char]
                node.is_end = True
                node.token = token
            print("Trie Construction Completed Successfully")
            return root


        def compute_failure_links(self, root):
            queue = [root]
            while queue:
                current_node = queue.pop(0)
                for char, child_node in current_node.children.items():
                    failure_node = current_node.failure_link
                    while failure_node and char not in failure_node.children:
                        failure_node = failure_node.failure_link
                    child_node.failure_link = failure_node.children[char] if failure_node else root
                    queue.append(child_node)

        # Improved debug prints in tokenize method
                    
        def tokenize(self, text):
            # Preprocess input text
            text = self.preprocess_text(text)
            node = self.root
            token_ids = []  # Will store token IDs instead of tokens
            i = 0
            print("Preprocessed Text:", text)
            while i < len(text):
                char = text[i]
                if char == ' ':
                    node = self.root
                    i += 1
                    print("Current Node:", node.token)  # Track how the trie is traversed
                    print("Token IDs:", token_ids)   
                    continue

                if char not in node.children:
                    if node != self.root and node.token is not None:
                        # Convert found token to its ID
                        token_id = self.vocab.get(node.token, self.unk_token_id)
                        token_ids.append(token_id)
                        node = self.root  # Reset to root
                        continue
                    else:
                        # Append unknown token ID
                        token_ids.append(self.unk_token_id)
                        i += 1
                        continue

                node = node.children[char]
                if node.is_end:
                    if i + 1 == len(text) or text[i + 1] == ' ':
                        # Convert found token to its ID
                        token_id = self.vocab.get(node.token, self.unk_token_id)
                        token_ids.append(token_id)
                        node = self.root

                i += 1

            #print(f"Token IDs: {token_ids[:10]}")
            return token_ids

        def preprocess_text(self, text):
            # Convert text to lowercase to ensure case insensitivity
            text = text.lower()

            # Optionally, handle punctuation by adding spaces around it for better tokenization
            # This depends on how your vocabulary handles punctuation
            text = re.sub(r'([.,!?()])', r' \1 ', text)

            # Replace multiple spaces with a single space
            text = re.sub(r'\s+', ' ', text)

            # Trim leading and trailing spaces
            text = text.strip()

            return text


# Example Usage
tokenizer = Tokenizer.Tokenize(bpe_vocab_size=30522, wordpiece_vocab_size=30522, num_merges=200) 
with open("D:\\EXPERT_WEIGHTS\\sample.txt", 'r', encoding='utf-8') as f: 
    text = f.read()
tokenizer.train(text) 
encoded_text = tokenizer.encode(text)
print(encoded_text)


Initial vocabulary: {'import</w>': 8, 'numpy</w>': 1, 'as</w>': 107, 'np</w>': 3, 'torch</w>': 19, '.</w>': 1079, 'nn</w>': 14, 'functional</w>': 2, 'F</w>': 2, 'from</w>': 89, 'utils</w>': 1, 'data</w>': 33, 'DataLoader</w>': 3, ',</w>': 910, 'Dataset</w>': 2, 'cuda</w>': 3, 'amp</w>': 1, 'GradScaler</w>': 2, 'autocast</w>': 2, 'optim</w>': 1, 'Adam</w>': 4, 'def</w>': 9, 'random_projection</w>': 3, '(</w>': 414, 'matrix</w>': 12, 'k</w>': 20, ')</w>': 416, ':</w>': 170, '"</w>': 97, 'Random</w>': 4, 'projection</w>': 4, 'to</w>': 333, 'reduce</w>': 2, 'dimensionality</w>': 1, 'of</w>': 345, 'dimensions</w>': 5, 'random_matrix</w>': 2, '=</w>': 149, 'randn</w>': 2, 'size</w>': 15, '-</w>': 434, '1</w>': 124, 'device</w>': 18, 'return</w>': 7, 'matmul</w>': 1, 'cur_decomposition</w>': 2, 'projection_dim</w>': 21, '#</w>': 59, 'Change</w>': 2, "'</w>": 16, 'argument</w>': 1, 'Applies</w>': 1, 'CUR</w>': 1, 'decomposition</w>': 1, 'with</w>': 87, 'C</w>': 8, 'dimension</w>': 5, 'aligned<

# Multi-Modal Data Processing

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import cv2
import numpy as np
#############################################
# Dataset class
video_folder_path = "D:\MMM_Data\Video"
image_folder_path = "D:\MMM_Data\Image"
text_folder_path = "D:\MMM_Data\Text"

class MultimodalDataset(Dataset):
    def __init__(self, image_paths, video_paths, texts, tokenizer, transform=None):
        self.image_paths = image_paths
        self.video_paths = video_paths
        self.texts = texts
        self.tokenizer = tokenizer
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        # Load image
        image = Image.open(self.image_paths[idx])
        if self.transform:
            image = self.transform(image)

        # Load video (simplified: loading only the first frame for example)
        video_cap = cv2.VideoCapture(self.video_paths[idx])
        ret, frame = video_cap.read()
        if ret:
            # Convert BGR to RGB and apply same transformations as for the image
            frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            frame = self.transform(frame)
        video_cap.release()

        # Process text
        text = self.tokenizer(self.texts[idx], return_tensors="pt", padding='max_length', truncation=True, max_length=512)

        return image, frame, text.input_ids.squeeze(), text.attention_mask.squeeze()

def multimodal_collate_fn(batch):
    images, frames, texts, attention_masks = zip(*batch)
    images = torch.stack(images)
    frames = torch.stack(frames)
    texts = torch.nn.utils.rnn.pad_sequence(texts, batch_first=True, padding_value=0)
    attention_masks = torch.nn.utils.rnn.pad_sequence(attention_masks, batch_first=True, padding_value=0)
    return images, frames, texts, attention_masks


# Assuming vocab is your vocabulary dictionary where keys are tokens and values are token IDs
vocab = {...}  # Your vocabulary here

# Instantiate your WordPiece tokenizer
wordpiece_tokenizer = WordPiece(vocab, unk_token_id=vocab.get("[UNK]", 0), unk_token="[UNK]")

# Example paths and texts remain the same
image_paths = ["path/to/image1.jpg", "path/to/image2.jpg"]
video_paths = ["path/to/video1.mp4", "path/to/video2.mp4"]
texts = ["This is a description for the first item", "This is a description for the second item"]

# Create dataset with custom tokenizer
dataset = MultimodalDataset(image_paths, video_paths, texts, wordpiece_tokenizer)

# Create DataLoader as before
dataloader = DataLoader(dataset, batch_size=2, collate_fn=multimodal_collate_fn)

# Example usage
for images, frames, texts, attention_masks in dataloader:
    print(images.shape, frames.shape, texts.shape, attention_masks.shape)
    # Route each modality input to the relevant model part