# Expert

In [14]:
import os

# Set the HF_HOME environment variable to a new cache directory on the D drive
os.environ['HF_HOME'] = 'D:/hf_datasets_cache'
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

# Test on larger text
import re
import collections
from collections import Counter, defaultdict
import json


# 0. Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from transformers import BertTokenizer
from torch.utils.data import DataLoader
import fitz  # PyMuPDF
from transformers import BertModel
import math
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from tqdm.notebook import tqdm
from torch.nn.utils.rnn import pad_sequence




class ExpertConfig:
    def __init__(self,wordpiece_vocab,wordpiece_tokenizer, cls_token_id=1770, 
                 sep_token_id=1771, pad_token_id=0, 
                 seq_len=512, head_dim=64, block_size=64, sparsity_factor=4,
                 input_dim=512, num_experts=3, vocab_size=30522, embed_size=256,
                 num_layers=6, forward_expansion=4, heads=8, dropout=0.1,
                 max_length=512,  # Updated to match seq_len for consistency
                 rank=16, device='cpu',
                 mamba_model_path='D:\\EXPERT_WEIGHTS\\mamba_model_weights.pth',
                 rag_model_path='D:\\EXPERT_WEIGHTS\\rag_model_weights.pth',
                 context_encoder_path='D:\\EXPERT_WEIGHTS\\context_encoder.pth',
                 language_model_path='D:\\EXPERT_WEIGHTS\\language_model.pth',
                 question_encoder_path='D:\\EXPERT_WEIGHTS\\question_encoder.pth',
                 dpo_model_path='D:\\EXPERT_WEIGHTS\\dpo_model_weights.pth',
                 model_name='bert-base-uncased', embedding_dim=768,
                 alpha=1, quantization_bits=8, tokenizer_name='bert-base-uncased',                 
                  freq_threshold=100, d_model=512, d_state=2048, d_conv=3, expansion_factor=2, 
                 clip_gradient = 1.0, mamba_learning_rate = 5e-4, weight_decay = 0.1,
                 warmup_steps = 10, total_mamba_steps = 100
                ):

        # Common hyperparameters
        self.freq_threshold = freq_threshold
        self.wordpiece_vocab = wordpiece_vocab
        self.wordpiece_tokenizer = wordpiece_tokenizer        
        self.seq_len = seq_len
        self.pad_token_id = pad_token_id
        self.cls_token_id = cls_token_id
        self.sep_token_id = sep_token_id
        self.head_dim = head_dim
        self.block_size = block_size
        self.sparsity_factor = sparsity_factor
        self.input_dim = input_dim
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.num_layers = num_layers
        self.forward_expansion = forward_expansion
        self.heads = heads
        self.dropout = dropout
        self.max_length = max_length  
        self.rank = rank

        # Model paths and device
        self.mamba_model_path = mamba_model_path
        self.rag_model_path = rag_model_path
        self.context_encoder_path = context_encoder_path
        self.language_model_path = language_model_path
        self.question_encoder_path = question_encoder_path
        self.dpo_model_path = dpo_model_path
        self.device = device

        # Unique hyperparameters
        self.num_experts = num_experts
        self.model_name = model_name
        self.embedding_dim = embedding_dim
        self.alpha = alpha
        self.quantization_bits = quantization_bits
        self.tokenizer_name = tokenizer_name
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expansion_factor = expansion_factor
        self.clip_gradient = clip_gradient
        self.mamba_learning_rate = mamba_learning_rate
        self.weight_decay = weight_decay
        self.warmup_steps = warmup_steps
        self.total_mamba_steps = total_mamba_steps

        # PDFs (unchanged)
        self.pdf_file_paths = [
            #r'C:\Users\robbi\IEEMM\DPO.pdf', 
            r'C:\Users\robbi\OneDrive\AI_Papers_Research\DPO.pdf',
            #r'C:\Users\robbi\IEEMM\MAMBA.pdf',
            r'C:\Users\robbi\OneDrive\AI_Papers_Research\DPO.pdf',

            #r'C:\Users\robbi\IEEMM\QLORA.pdf',
            r'C:\Users\robbi\OneDrive\AI_Papers_Research\DPO.pdf',

            #r'C:\Users\robbi\IEEMM\RAG.pdf',
            r'C:\Users\robbi\OneDrive\AI_Papers_Research\DPO.pdf',

            #r'C:\Users\robbi\IEEMM\SWITCH_TRANSFORMER.pdf'
            r'C:\Users\robbi\OneDrive\AI_Papers_Research\DPO.pdf',

        ]
        
        # Preserving original dataset loading functionality
        self.rag_dataset = Expert.TransformerRAG.create_dataset_from_pdfs(self.pdf_file_paths, self.wordpiece_tokenizer)

        
    @staticmethod
    def adapt_vocab_for_wordpiece(ssp_vocab):
        adapted_vocab = {}
        for token, id_or_freq in ssp_vocab.items():
            if not token.startswith(" ") and not token.endswith("</w>"):
                adapted_token = "##" + token.replace("</w>", "")
            else:
                adapted_token = token.replace("</w>", "")
            adapted_vocab[adapted_token] = id_or_freq
        return adapted_vocab

        
    def validate(self):
        assert self.seq_len % self.block_size == 0, "seq_len must be divisible by block_size"
        assert self.max_length >= self.seq_len, "max_length should be equal to or greater than seq_len"



class Expert(nn.Module):

    @staticmethod
    # Load model weights function
    def load_model_weights(model, model_path):
        #checkpoint = torch.load(model_path, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
        checkpoint = torch.load(model_path, map_location=config.device)

        if isinstance(checkpoint, dict):
            # Check for 'state_dict' or 'model_state_dict' keys
            if 'state_dict' in checkpoint:
                model.load_state_dict(checkpoint['state_dict'])
            elif 'model_state_dict' in checkpoint:
                model.load_state_dict(checkpoint['model_state_dict'])
            else:
                # If no known key is found, try loading it as a raw state dictionary
                try:
                    model.load_state_dict(checkpoint)
                except RuntimeError as e:
                    raise ValueError(f"Error loading state dict: {e}")
        elif isinstance(checkpoint, nn.Module):
            # If the checkpoint is a model object, assign it directly
            model = checkpoint
        else:
            raise ValueError(f"Unsupported checkpoint format: {type(checkpoint)}")

        model.eval()
        return model

    ###############################
    # Flash2_Attention
    class FlashAttention2(nn.Module):
        def __init__(self, sequence_length, head_dimension, block_size):
            super(Expert.FlashAttention2, self).__init__()
            self.block_size = block_size
            # Ensure that sequence_length is divisible by block_size for simplicity
            assert sequence_length % block_size == 0

        def forward(self, Q, K, V):
            # Partitioning of inputs
            Q_blocks, K_blocks, V_blocks = self.partition_inputs(Q, K, V)

            # Efficient computation of the attention mechanism
            outputs = []
            for i, Q_block in enumerate(Q_blocks):
                output_block = self.process_block(Q_block, K_blocks, V_blocks)
                outputs.append(output_block)

            # Concatenating the processed blocks
            output = torch.cat(outputs, dim=0)
            return output

        def partition_inputs(self, Q, K, V):
            # Calculate the number of chunks, ensuring it's at least 1
            num_chunks_Q = max(1, Q.size(0) // self.block_size)
            num_chunks_K = max(1, K.size(0) // self.block_size)
            num_chunks_V = max(1, V.size(0) // self.block_size)
            
            # Chunk the inputs
            Q_blocks = Q.chunk(chunks=num_chunks_Q, dim=0)
            K_blocks = K.chunk(chunks=num_chunks_K, dim=0)
            V_blocks = V.chunk(chunks=num_chunks_V, dim=0)
            
            return Q_blocks, K_blocks, V_blocks

        def process_block(self, Q_block, K_blocks, V_blocks):
            # Process each block efficiently as per FLASH2's optimized method
            # This includes computing QK^T, applying online softmax, and multiplying with V
            output_blocks = []
            for K_block, V_block in zip(K_blocks, V_blocks):
                print(f"Q_block: {Q_block.shape} , K_block.transpose(-2, -1): {K_block.transpose(-2, -1).shape}")
                attention_scores = torch.matmul(Q_block, K_block.transpose(-2, -1))
                print(f"attention_scores = torch.matmul(Q_block, K_block.transpose(-2, -1)): {attention_scores.shape}")
                attention_scores = attention_scores.float()
                attention_scores = self.online_softmax(attention_scores)
                output_block = torch.matmul(attention_scores, V_block.float())
                print(f"output_block shape:, {output_block.shape}") 
                output_blocks.append(output_block)

            # Summing up the results from each block
            output_block_sum = sum(output_blocks)
            return output_block_sum

        def online_softmax(self, scores, chunk_size=128):
            # Apply softmax in chunks for large sequences
            softmaxed_scores = []
            for i in range(0, scores.size(0), chunk_size):
                chunk = scores[i:i + chunk_size, :]
                softmaxed_chunk = F.softmax(chunk, dim=1)
                softmaxed_scores.append(softmaxed_chunk)
            return torch.cat(softmaxed_scores, dim=0)

    ###############################
    # SparseFlash2_Attention
    class SparseFlash2Attention(nn.Module):
        def __init__(self, seq_len, head_dim, blk_size, sparsity_factor):
            super().__init__()
            self.flash_attention = Expert.FlashAttention2(seq_len, head_dim, blk_size)
            self.seq_len = seq_len
            self.head_dim = head_dim
            self.block_size = blk_size  # Storing block_size as an instance variable
            self.sparsity_factor = sparsity_factor

        def generate_sparsity_mask(self):
            mask = torch.zeros(self.seq_len, self.seq_len)
            step = self.sparsity_factor
            for i in range(0, self.seq_len, step):
                mask[i:i + step, :] = 1
            return mask.bool()

        def forward(self, Q, K, V):
            output = self.flash_attention(Q, K, V)  # output shape: [sequence_length, head_dimension]
            print(f"Initial output shape: {output.shape}")

            # Reshape output to be 3D for batch matrix multiplication
            output = output.unsqueeze(0)  # New shape: [1, sequence_length, head_dimension]
            print(f"output.unsqueeze(0) shape: {output.shape}")

            sparsity_mask = self.generate_sparsity_mask()  # shape: [sequence_length, sequence_length]
            print(f"Sparsity mask shape: {sparsity_mask.shape}")

            # Apply the sparsity mask to the output
            sparsity_mask = sparsity_mask.unsqueeze(0)  # New shape: [1, sequence_length, sequence_length]
            print(f"sparsity mask sparsity_mask.unsqueeze(0) shape: {sparsity_mask.shape}")
            output = torch.bmm(sparsity_mask.float(), output.float().transpose(1, 2))  
            print(f"output from torch.bmm shape: {output.shape}")
            output = output.transpose(1, 2)
            # Reshape the output back to 2D
            output = output.squeeze(0)  # New shape: [sequence_length, head_dimension]

            return output


    ###############################
    # MAMBA
    # RoPE
    class RotaryPositionalEncoding(nn.Module):
        def __init__(self, dim, max_len=5000):
            super().__init__()
            self.dim = dim
            inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
            t = torch.arange(max_len).type_as(inv_freq)
            freqs = torch.einsum('n , d -> n d', t, inv_freq)
            self.register_buffer('sin', freqs.sin())
            self.register_buffer('cos', freqs.cos())

        def forward(self, x):
            n, _, device = x.shape[1], self.dim // 2, x.device
            sin, cos = self.sin[:n].to(device), self.cos[:n].to(device)

            # Apply RoPE to even and odd indices separately
            x_even = x[..., :self.dim:2] * cos.unsqueeze(0) + torch.roll(x[..., 1:self.dim:2], shifts=1, dims=-1) * sin.unsqueeze(0)
            x_odd = x[..., 1:self.dim:2] * cos.unsqueeze(0) - torch.roll(x[..., :self.dim:2], shifts=1, dims=-1) * sin.unsqueeze(0)
            return torch.cat((x_even, x_odd), dim=-1)

    # SWIGLU
    class SwiGLU(nn.Module):
        def __init__(self, dim_in, dim_out):
            super(Expert.SwiGLU, self).__init__()
            self.fc1 = nn.Linear(dim_in, dim_out)
            self.fc2 = nn.Linear(dim_in, dim_out)

        def forward(self, x):
            gate = torch.sigmoid(self.fc2(x))
            return self.fc1(x) * gate

    class SimplifiedMAMBA(nn.Module):
        def __init__(self, config):
            super().__init__()
            # Using configuration parameters
            self.num_layers = config.num_layers
            self.d_model = config.d_model
            self.d_state = config.d_state
            self.d_conv = config.d_conv
            self.expansion_factor = config.expansion_factor
            
            self.feedforward = nn.Sequential(
                nn.Linear(self.d_model, self.d_state),
                nn.GELU(),
                nn.Linear(self.d_state, self.d_model)
            )
            self.input_embedding = nn.Linear(self.d_model, self.d_model)
            self.convs = nn.Sequential(*[nn.Conv1d(self.d_model, self.d_model, kernel_size=self.d_conv, padding=(self.d_conv // 2)) for _ in range(self.num_layers)])
            self.swiglu = Expert.SwiGLU(self.d_model, self.d_model)
            self.output_projection = nn.Linear(self.d_model, self.d_model * self.expansion_factor)

            self.initialize_weights()

        def initialize_weights(self):
            gain = nn.init.calculate_gain('relu')

            nn.init.orthogonal_(self.input_embedding.weight, gain)
            nn.init.normal_(self.input_embedding.bias, mean=0, std=0.01)

            nn.init.kaiming_uniform_(self.convs[-1].weight, a=math.sqrt(5))
            nn.init.zeros_(self.convs[-1].bias)

            nn.init.xavier_uniform_(self.feedforward[0].weight, gain=nn.init.calculate_gain('relu'))
            nn.init.zeros_(self.feedforward[0].bias)

            nn.init.xavier_uniform_(self.feedforward[2].weight, gain=nn.init.calculate_gain('linear'))
            nn.init.zeros_(self.feedforward[2].bias)

            nn.init.xavier_uniform_(self.output_projection.weight, gain=nn.init.calculate_gain('linear'))
            nn.init.zeros_(self.output_projection.bias)

        def forward(self, inputs, attention_mask=None):
            print("Input shape:", inputs.shape)

            # Apply the attention mask if provided
            if attention_mask is not None:
                inputs = inputs * attention_mask.unsqueeze(-1)

            projected_inputs = self.input_embedding(inputs)
            print("projected_inputs pre-reshape shape:", projected_inputs.shape)

            projected_inputs = projected_inputs.permute(0, 2, 1)
            print("projected_inputs post-reshape shape:", projected_inputs.shape)

            for conv in self.convs:
                projected_inputs = conv(projected_inputs)

            projected_inputs = projected_inputs.permute(0, 2, 1)
            print("projected_inputs post convolution reshape:", projected_inputs.shape)

            projected_inputs = self.swiglu(projected_inputs)
            print("projected_inputs post swiglu shape:", projected_inputs.shape)

            output = self.output_projection(projected_inputs)
            print("output shape:", output.shape)

            return output

    class SimplifiedLanguageModelMAMBA(nn.Module):
        def __init__(self, config):
            super().__init__()
            # Using configuration parameters
            self.vocab_size = config.vocab_size
            self.num_layers = config.num_layers
            self.d_model = config.d_model
            self.d_state = config.d_state
            self.d_conv = config.d_conv
            self.expansion_factor = config.expansion_factor

            self.embedding = nn.Embedding(self.vocab_size, self.d_model)
            self.pos_encoder = Expert.RotaryPositionalEncoding(self.d_model)
            self.simplified_mamba = Expert.SimplifiedMAMBA(config)
            self.output_projection = nn.Linear(self.d_model * 2, self.vocab_size)  # Adjust if needed

            self.initialize_weights()

        def initialize_weights(self):
            gain = 1.0

            nn.init.orthogonal_(self.embedding.weight, gain)
            nn.init.xavier_uniform_(self.output_projection.weight, gain=gain)

        def forward(self, input_values, attention_mask):
            print(f"SimplifiedLanguageModelMAMBA fwd input_values: {input_values.shape}")
            print(f"SimplifiedLanguageModelMAMBA fwd input_values: {input_values.dtype}")
            if input_values.dtype != torch.long:
                input_values = input_values.long()
            embedded = self.embedding(input_values) * math.sqrt(self.d_model)
            embedded = self.pos_encoder(embedded)
            simplified_mamba_output = self.simplified_mamba(embedded, attention_mask)
            logits = self.output_projection(simplified_mamba_output)
            return logits

    ###############################
    # Switch Router

    class SwitchRouter(nn.Module):
        CAPACITY_FACTOR = 1  # Class constant

        class SwitchGate(nn.Module):
            def __init__(self, input_dim, num_experts):
                super(Expert.SwitchRouter.SwitchGate, self).__init__()
                self.fc1 = nn.Linear(input_dim, input_dim // 2)
                self.fc2 = nn.Linear(input_dim // 2, num_experts)

            def forward(self, x):
                x = F.relu(self.fc1(x))
                gate_scores = F.softmax(self.fc2(x), dim=-1)
                return gate_scores

        def __init__(self, config):
            super(Expert.SwitchRouter, self).__init__()
            self.config = config
            self.device = config.device
            self.router = self.SwitchGate(config.input_dim, config.num_experts).to(config.device)
            self.transformer_rag = Expert.TransformerRAG(config).to(config.device)
            self.lmt = Expert.LanguageModelTransformer(config).to(config.device)
            self.transformer_dpo = Expert.DPO(self.lmt, config.device,config.embed_size).to(config.device)
            self.mamba = Expert.SimplifiedLanguageModelMAMBA(config).to(config.device)
            self.experts = nn.ModuleList([self.transformer_rag, self.transformer_dpo, self.mamba])
            self.input_embedding = nn.Linear(config.input_dim, config.input_dim)

        def forward(self, x, attention_mask, context_texts, question_text):
            x = x.to(self.device)
            x = self.input_embedding(x)
            gate_scores = self.router(x)
            expert_indices = torch.argmax(gate_scores, dim=1)
            final_output = torch.zeros_like(x)

            for i, expert in enumerate(self.experts):
                mask = expert_indices == i
                if mask.any():
                    selected_inputs = x[mask]
                    selected_attention_mask = attention_mask[mask].to(self.device)
                    if isinstance(expert, Expert.TransformerRAG):
                        expert_output = expert.forward(context_texts, selected_inputs, selected_attention_mask, question_text)
                    elif isinstance(expert, Expert.DPO):  # Check if the expert is an instance of DPO
                        # Use the forward_expert method for DPO within the routing process
                        expert_output = expert.forward_expert(selected_inputs, selected_attention_mask, context_texts, question_text)
                    else:
                        # For other experts, continue using the standard forward method
                        expert_output = expert(selected_inputs, selected_attention_mask)
                    final_output[mask] = expert_output

            aux_loss = self.auxiliary_loss(gate_scores)
            return final_output, aux_loss

        def auxiliary_loss(self, gate_scores):
            expert_load = gate_scores.sum(0) / gate_scores.size(0)
            loss_balancing = torch.std(expert_load)
            return loss_balancing


        @staticmethod
        def route_inputs(expert_indices, gate_scores, num_experts):
            capacity_factor_tensor = torch.tensor([Expert.SwitchRouter.CAPACITY_FACTOR], dtype=torch.float32)
            capacities = (gate_scores.size(0) * capacity_factor_tensor / num_experts).int()
            expert_counts = torch.zeros(num_experts, dtype=torch.int32)
            for idx in range(len(expert_indices)):
                selected_expert = expert_indices[idx]
                if expert_counts[selected_expert] < capacities[0]:
                    expert_counts[selected_expert] += 1
                else:
                    available_experts = (expert_counts < capacities[0]).nonzero(as_tuple=False).view(-1)
                    if len(available_experts) > 0:
                        alternative_expert = available_experts[0]
                        expert_indices[idx] = alternative_expert
                        expert_counts[alternative_expert] += 1
                    else:
                        print("No available experts to reroute. Handling overflow.")
            return expert_indices
    
    ###############################
    # RAG
    
    class PositionalEncoding(nn.Module):
        def __init__(self, d_model, max_len=10000):
            super(Expert.PositionalEncoding, self).__init__()
            self.d_model = d_model
            self.max_len = max_len

            # Create positional encodings
            pe = torch.zeros(max_len, d_model)
            position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
            pe[:, 0::2] = torch.sin(position * div_term)
            pe[:, 1::2] = torch.cos(position * div_term)

            # Add a batch dimension (B x T x C)
            pe = pe.unsqueeze(0)
            self.register_buffer('pe', pe)

        def forward(self, x):
            # x: Tensor of shape [Batch Size, Sequence Length, Embedding Dimension]
            # Adjust positional encoding to match the input size and device
            pe = self.pe[:, :x.size(1)]
            # Assuming x is on the correct device, pe will be automatically aligned to the same device
            return pe


    class AdaptiveDropoutLayer(nn.Module):
        def __init__(self, init_dropout_rate=0.1):
            super(Expert.AdaptiveDropoutLayer, self).__init__()
            # Use logit transformation for stability
            self.log_alpha = nn.Parameter(torch.tensor(math.log(init_dropout_rate / (1 - init_dropout_rate))).float())

        def forward(self, x):
            p = torch.sigmoid(self.log_alpha)
            # Convert p from a tensor to a float
            p_value = p.item()  # This extracts the scalar value as a Python float
            return nn.functional.dropout(x, p=p_value, training=self.training)


    class MultiHeadLinformerAttention(nn.Module):
        def __init__(self, embed_dim, num_heads, k=None):
            super().__init__()  # Ensure this is called first
            self.embed_dim = embed_dim
            self.num_heads = num_heads
            self.k = k if k is not None else embed_dim // num_heads  # Projection dimension per head

            self.key_projections = nn.Linear(embed_dim, self.k * num_heads)
            self.value_projections = nn.Linear(embed_dim, self.k * num_heads)
            self.out_projection = nn.Linear(self.k * num_heads, embed_dim)

        def forward(self, query, attention_mask=None):
            batch_size, seq_len, _ = query.size()
            
            # Project keys and values
            keys = self.key_projections(query)
            values = self.value_projections(query)
            
            # Reshape into [batch_size, num_heads, seq_len, k]
            keys = keys.reshape(batch_size, seq_len, self.num_heads, self.k).transpose(1, 2)
            values = values.reshape(batch_size, seq_len, self.num_heads, self.k).transpose(1, 2)
            
            # Calculate attention (scaled dot-product attention)
            # Scaling by the square root of the depth of the key vectors to prevent large values in the dot product
            # which could push the softmax function into regions where it has extremely small gradients
            keys = keys / (self.k ** 0.5)
            attention_scores = torch.softmax(torch.matmul(keys, values.transpose(-2, -1)), dim=-1)
            if attention_mask is not None:
                attention_scores = attention_scores.masked_fill(~attention_mask.bool(), float('-inf'))
                # Recalculate softmax for masked scores
                attention_scores = torch.softmax(attention_scores, dim=-1)
            # Apply attention to values
            out = torch.matmul(attention_scores, values)
            
            # Concatenate heads and project back to original embedding dimension
            out = out.transpose(1, 2).reshape(batch_size, seq_len, self.num_heads * self.k)
            out = self.out_projection(out)
            
            return out


    class AdaptiveEmbeddingLayer(nn.Module):
        def __init__(self, vocab,  vocab_size, freq_threshold, large_embed_dim, small_embed_dim, max_seq_len):
            super(Expert.AdaptiveEmbeddingLayer, self).__init__()
            self.vocab = vocab
            self.vocab_size = vocab_size
            self.freq_threshold = freq_threshold
            self.large_embed_dim = large_embed_dim
            self.small_embed_dim = small_embed_dim
            self.max_seq_len = max_seq_len

            self.split_vocab(vocab, freq_threshold)  

            self.frequent_embeddings = nn.Embedding(num_embeddings=len(self.frequent_vocab), embedding_dim=large_embed_dim)
            self.infrequent_embeddings = nn.Embedding(num_embeddings=len(self.infrequent_vocab), embedding_dim=small_embed_dim)
            self.infrequent_projection = nn.Linear(small_embed_dim, large_embed_dim)
            self.positional_embeddings = Expert.PositionalEncoding(large_embed_dim, max_seq_len)


        def split_vocab(self, vocab, freq_threshold):
            token_counts = [(token, count) for token, count in vocab.items()]
            token_counts.sort(key=lambda x: -x[1])  # Sort by frequency
            split_point = next(i for i, (_, count) in enumerate(token_counts) if count < freq_threshold)
            
            self.frequent_vocab = dict(token_counts[:split_point])
            self.infrequent_vocab = dict(token_counts[split_point:])

        def forward(self, token_ids):
            device = token_ids.device
            seq_len = token_ids.size(1)
            batch_size = token_ids.size(0)  # Obtain batch size from token_ids tensor

            # Initialize embeddings tensor
            embeddings = torch.zeros(token_ids.shape[0], seq_len, self.large_embed_dim, device=device)

            # Map token_ids to indices for frequent and infrequent vocab
            frequent_indices = torch.zeros_like(token_ids)
            infrequent_indices = torch.zeros_like(token_ids)
            
            for token_id, index in self.vocab.items():
                mask = token_ids == token_id
                if token_id in self.frequent_vocab:
                    # Map to index in frequent_vocab
                    frequent_indices[mask] = self.frequent_vocab[token_id]
                elif token_id in self.infrequent_vocab:
                    # Map to index in infrequent_vocab
                    infrequent_indices[mask] = self.infrequent_vocab[token_id]

            # Create masks for frequent and infrequent tokens
            frequent_mask = frequent_indices > 0
            infrequent_mask = infrequent_indices > 0

            # Embed frequent tokens
            if frequent_mask.any():
                frequent_embeddings = self.frequent_embeddings(frequent_indices[frequent_mask])
                embeddings[frequent_mask] = frequent_embeddings

            # Embed and project infrequent tokens
            if infrequent_mask.any():
                infrequent_embeddings = self.infrequent_embeddings(infrequent_indices[infrequent_mask])
                infrequent_embeddings_projected = self.infrequent_projection(infrequent_embeddings)
                embeddings[infrequent_mask] = infrequent_embeddings_projected

            # Apply positional embeddings
            position_ids = torch.arange(0, seq_len, dtype=torch.long, device=device).unsqueeze(0)
            position_embeddings = self.positional_embeddings(position_ids)  # Generate for seq_len

            # Ensure positional embeddings are broadcastable to the embeddings tensor
            # This step may not be necessary if your positional embeddings are already correctly shaped
            if position_embeddings.size(0) != batch_size:
                position_embeddings = position_embeddings.expand(batch_size, -1, -1)

            print(f"Embeddings shape: {embeddings.shape}")
            print(f"Positional embeddings shape: {position_embeddings.shape}")
            embeddings += position_embeddings

            return embeddings



    class DPRContextEncoder(nn.Module):
        def __init__(self, config):  # Accept an ExpertConfig instance
            super().__init__()
            self.wordpiece_tokenizer = config.wordpiece_tokenizer
            # Use attributes from config directly
            self.embedding_layer = Expert.AdaptiveEmbeddingLayer(
                config.wordpiece_vocab,
                config.vocab_size,
                config.freq_threshold,
                config.embedding_dim,
                config.embedding_dim // 4,  # Assuming you want to reduce the dimension for the infrequent tokens
                max_seq_len=config.max_length  # Use max_length from config
            )
            self.attention_layer = Expert.MultiHeadLinformerAttention(
                config.embedding_dim, 
                num_heads=config.heads
            )
            self.dropout = Expert.AdaptiveDropoutLayer(init_dropout_rate=config.dropout)  # Assuming dropout rate is defined in config

        def forward(self, input_ids, attention_mask):
            embeddings = self.embedding_layer(input_ids)
            attention_output = self.attention_layer(embeddings, attention_mask=attention_mask)
            attention_output = self.dropout(attention_output)

            # Mean pooling across the sequence length dimension
            pooled_output = attention_output.mean(dim=1)

            return pooled_output


    class DPRQuestionEncoder(nn.Module):
        def __init__(self, config):  # Accept an ExpertConfig instance
            super().__init__()
            self.wordpiece_tokenizer = config.wordpiece_tokenizer
            # Use attributes from config directly
            self.embedding_layer = Expert.AdaptiveEmbeddingLayer(
                config.wordpiece_vocab,
                config.vocab_size,
                config.freq_threshold,
                config.embedding_dim,
                config.embedding_dim // 4,  # Assuming you want to reduce the dimension for the infrequent tokens
                max_seq_len=config.max_length  # Use max_length from config
            )
            self.attention_layer = Expert.MultiHeadLinformerAttention(
                config.embedding_dim, 
                num_heads=config.heads
            )
            self.dropout = Expert.AdaptiveDropoutLayer(init_dropout_rate=config.dropout)

        def forward(self, input_ids, attention_mask):
            embeddings = self.embedding_layer(input_ids)
            attention_output = self.attention_layer(embeddings, attention_mask=attention_mask)
            attention_output = self.dropout(attention_output)

            # Mean pooling across the sequence length dimension
            pooled_output = attention_output.mean(dim=1)

            return pooled_output


    class TransformerRAG(nn.Module):
        def __init__(self, config):
            super(Expert.TransformerRAG, self).__init__()
            self.config = config
            self.embed_size = config.embed_size  
            self.context_encoder = Expert.DPRContextEncoder(config).to(config.device)          
            self.language_model = Expert.LanguageModelTransformer(config).to(config.device)
            self.question_encoder = Expert.DPRQuestionEncoder(config).to(config.device)
            self.tokenizer = config.wordpiece_tokenizer


        def forward(self, context_texts, question_input_ids, question_attention_mask, question_text):
            if question_input_ids.max() >= config.vocab_size:
                raise ValueError("question_input_ids contain ID(s) beyond the tokenizer's vocabulary size")
            
            aggregated_context_embeddings = []
            for context_list in context_texts:
                if not all(isinstance(context, dict) for context in context_list):
                    raise TypeError("Each item in context_texts must be a list of tokenized context dictionaries")
                
                aggregated_context_embedding = torch.zeros(config.embedding_dim, device=config.device)
                for context in context_list:
                    context_input_ids = torch.tensor(context['input_ids']).unsqueeze(0).to(config.device)  # Add unsqueeze(0) for batch dimension
                    context_attention_mask = torch.tensor(context['attention_mask']).unsqueeze(0).to(config.device)  # Add unsqueeze(0) for batch dimension
                    print(f"context_input_ids shape: {context_input_ids.shape}")
                    print(f"context_attention_mask shape: {context_attention_mask.shape}")

                    context_embedding = self.context_encoder(context_input_ids, context_attention_mask)
                    aggregated_context_embedding += context_embedding.mean(dim=0)
                
                aggregated_context_embeddings.append(aggregated_context_embedding / len(context_list))
            
            question_input_ids = question_input_ids.to(config.device).long()
            question_attention_mask = question_attention_mask.to(config.device).long()
            question_embeddings = self.question_encoder(input_ids=question_input_ids, attention_mask=question_attention_mask)
            
            cos_sim = torch.nn.CosineSimilarity(dim=1)
            similarities = [cos_sim(question_embeddings, context_emb.squeeze(0)) for context_emb in aggregated_context_embeddings]
            most_relevant_context_idx = torch.argmax(torch.tensor(similarities, device=config.device))
            
            combined_input = question_text + " " + context_texts[most_relevant_context_idx]
            tokenized_combined_input = self.tokenizer(combined_input, return_tensors="pt", padding=True, truncation=True, max_length=512)
            tokenized_combined_input = {k: v.to(config.device) for k, v in tokenized_combined_input.items()}
            response_logits = self.language_model(**tokenized_combined_input)
            probabilities = F.softmax(response_logits.logits, dim=-1)
            predicted_token_ids = torch.argmax(probabilities, dim=-1)
            predicted_tokens = self.tokenizer.convert_ids_to_tokens(predicted_token_ids[0])
            #response = self.tokenizer.convert_tokens_to_string(predicted_tokens)
            response = " ".join(predicted_tokens).replace(" </w>", "").replace("</w>", " ").strip()
            
            return response

        @staticmethod
        def extract_text_from_pdf(file_path):
            text = ""
            with fitz.open(file_path) as doc:
                for page in doc:
                    text += page.get_text()
            return text

        @staticmethod
        def split_into_chunks(text, chunk_size):
            return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]

        @staticmethod
        def preprocess_text(text, tokenizer, max_length=512):
            chunk_size = max_length - 50
            text_chunks = Expert.TransformerRAG.split_into_chunks(text, chunk_size)
            processed_chunks = []
            for chunk in text_chunks:
                # Tokenize the chunk using the WordPiece tokenizer
                token_ids = tokenizer.tokenize(chunk)
                
                # Manual padding and attention mask creation
                attention_mask = [1] * len(token_ids)
                # Padding: Extend token_ids and attention_mask to max_length
                while len(token_ids) < max_length:
                    token_ids.append(tokenizer.unk_token_id)  # Use unk_token_id for padding
                    attention_mask.append(0)  # Padding token has attention mask 0
                
                # Ensure token_ids and attention_mask are not longer than max_length
                token_ids = token_ids[:max_length]
                attention_mask = attention_mask[:max_length]
                
                processed_chunk = {
                    'input_ids': token_ids,
                    'attention_mask': attention_mask
                }
                processed_chunks.append(processed_chunk)
            return processed_chunks

        @staticmethod
        def create_dataset_from_pdfs(pdf_file_paths, tokenizer):
            dataset = []
            for file_path in pdf_file_paths:
                text = Expert.TransformerRAG.extract_text_from_pdf(file_path)
                processed_text = Expert.TransformerRAG.preprocess_text(text, tokenizer)                
                dataset.append(processed_text)
            return dataset

        def retrieve_contexts(self, dataset, query_embedding, top_k=5):
            similarity_scores = []
            for context in dataset:
                context_input_ids = context['input_ids'].to(self.config.device)
                context_attention_mask = context['attention_mask'].to(self.config.device)
                # Use the class's context_encoder
                context_embedding = self.context_encoder(context_input_ids, context_attention_mask)
                similarity = torch.matmul(query_embedding, context_embedding.T)
                similarity_scores.append(similarity.squeeze().item())
            top_k_indices = sorted(range(len(similarity_scores)), key=lambda i: similarity_scores[i], reverse=True)[:top_k]
            top_contexts = [dataset[i] for i in top_k_indices]
            return top_contexts

        def rag_retrieve_and_generate(self, dataset, query):
            # Tokenize the query using the class's tokenizer
            tokenized_query = self.tokenizer(query, return_tensors="pt", padding=True, truncation=True, max_length=512).to(self.config.device)
            input_ids = tokenized_query['input_ids']
            attention_mask = tokenized_query['attention_mask']
            # Use the class's question_encoder
            encoded_query = self.question_encoder(input_ids, attention_mask)
            relevant_contexts = self.retrieve_contexts(dataset, encoded_query)
            # Assuming generate_response is a method of LanguageModelTransformer that accepts tokenized contexts and generates a response
            response = self.language_model.generate_response(relevant_contexts)
            return response

    ###############################
    # LORA
    class LORALayer(nn.Module):
        def __init__(self, input_dim, output_dim, rank, alpha=1):
            super(Expert.LORALayer, self).__init__()
            self.rank = rank
            self.alpha = alpha

            # Original weight and bias of the linear layer
            self.weight = nn.Parameter(torch.Tensor(output_dim, input_dim))
            #print("self.weight Shape:", self.weight.shape)
            self.bias = nn.Parameter(torch.Tensor(output_dim))

            # LORA specific parameters
            self.A = nn.Parameter(torch.Tensor(input_dim, rank))
            self.B = nn.Parameter(torch.Tensor(rank, output_dim))

            self.reset_parameters()

        def reset_parameters(self):
            nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
            nn.init.zeros_(self.bias)
            nn.init.normal_(self.A, 0, 0.02)
            nn.init.normal_(self.B, 0, 0.02)

        def forward(self, x):
            #print("LORALayer Input Shape:", x.shape)
            
            original_size = x.size()
            batch_size, seq_len, _ = x.shape
            x_flattened = x.reshape(-1, original_size[-1])

            # Compute lora_adjustment for each input in the batch
            lora_adjustment = self.alpha * (x_flattened @ self.A) @ self.B
            lora_adjustment = lora_adjustment.reshape(batch_size, seq_len, -1)

            #print("Adjusted lora_adjustment Shape:", lora_adjustment.shape)
            #print("self.weight Shape:", self.weight.shape)
            
            # Apply linear transformation to x_flattened
            x_transformed = nn.functional.linear(x_flattened, self.weight, self.bias)
            x_transformed = x_transformed.reshape(batch_size, seq_len, -1)

            # Add lora_adjustment to the transformed x
            x = x_transformed + lora_adjustment
            #print("LORALayer Output Shape:", x.shape)

            return x

    # QLORA
    class QLORALayer(nn.Module):
        def __init__(self, input_dim, output_dim, rank, alpha=1, quantization_bits=8):
            super(Expert.QLORALayer, self).__init__()
            self.rank = rank
            self.alpha = alpha
            self.quantization_bits = quantization_bits

            # Original weight and bias
            self.weight = nn.Parameter(torch.Tensor(output_dim, input_dim))
            self.bias = nn.Parameter(torch.Tensor(output_dim))

            # QLORA specific parameters
            self.A = nn.Parameter(torch.Tensor(input_dim, rank))
            self.B = nn.Parameter(torch.Tensor(rank, output_dim))

            self.dropout = nn.Dropout(0.1)
            self.layer_norm = nn.LayerNorm(output_dim)

            self.reset_parameters()

        def reset_parameters(self):
            nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
            nn.init.zeros_(self.bias)
            nn.init.normal_(self.A, 0, 0.02)
            nn.init.normal_(self.B, 0, 0.02)

        def quantize(self, x, num_bits):
            # Implement a simple quantization method
            scale = x.abs().max()
            x_quantized = torch.round(x / scale * (2**num_bits - 1))
            return x_quantized, scale

        def forward(self, x):
            #print("QLORALayer Input Shape:", x.shape)
            original_size = x.size()
            batch_size, seq_len, _ = x.shape
            x_flattened = x.reshape(-1, original_size[-1])

            A_quantized, scale_A = self.quantize(self.A, self.quantization_bits)
            B_quantized, scale_B = self.quantize(self.B, self.quantization_bits)

            # Compute lora_adjustment for each input in the batch
            lora_adjustment = self.alpha * (x_flattened @ (A_quantized / scale_A)) @ (B_quantized / scale_B)
            lora_adjustment = lora_adjustment.reshape(batch_size, seq_len, -1)
            lora_adjustment = self.dropout(lora_adjustment)
            #print("Adjusted lora_adjustment Shape:", lora_adjustment.shape)
            #print("self.weight Shape:", self.weight.shape)

            # Apply linear transformation to x_flattened
            x_transformed = nn.functional.linear(x_flattened, self.weight, self.bias)
            x_transformed = x_transformed.reshape(batch_size, seq_len, -1)

            # Add lora_adjustment to the transformed x
            x = x_transformed + lora_adjustment
            x = self.layer_norm(x)

            #print("QLORALayer Output Shape:", x.shape)

            return x
        
        def update_alpha(self, new_alpha):
            """
            Update the alpha scaling factor.
            """
            self.alpha = new_alpha

    ###############################
    # Language Model Transformer

    class MultiHeadAttention(nn.Module):
        def __init__(self, embed_size, heads):
            super(Expert.MultiHeadAttention, self).__init__()
            self.embed_size = embed_size
            self.heads = heads
            self.head_dim = embed_size // heads

            assert (
                self.head_dim * heads == embed_size
            ), "Embedding size needs to be divisible by heads"

            self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
            self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
            self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
            self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

        def forward(self, values, keys, query, mask):
            N = query.shape[0]
            value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

            # Split the embedding into self.heads different pieces
            values = values.reshape(N, value_len, self.heads, self.head_dim)
            keys = keys.reshape(N, key_len, self.heads, self.head_dim)
            queries = query.reshape(N, query_len, self.heads, self.head_dim)

            values = self.values(values)
            keys = self.keys(keys)
            queries = self.queries(queries)

            # Einsum does the matrix multiplication for query*keys for each training example
            # with every other training example, then sum it up
            attention = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

            if mask is not None:
                attention = attention.masked_fill(mask == 0, float("-1e20"))

            attention = torch.softmax(attention / (self.embed_size ** (1 / 2)), dim=3)

            out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
                N, query_len, self.heads * self.head_dim
            )

            out = self.fc_out(out)
            return out

    class TransformerBlock(nn.Module):
        def __init__(self, embed_size, heads, dropout, forward_expansion, rank):
            super(Expert.TransformerBlock, self).__init__()
            self.attention = Expert.MultiHeadAttention(embed_size, heads)
            self.norm1 = nn.LayerNorm(embed_size)
            self.norm2 = nn.LayerNorm(embed_size)

            self.feed_forward = nn.Sequential(
                Expert.LORALayer(embed_size, forward_expansion * embed_size, rank),
                nn.ReLU(),
                Expert.LORALayer(forward_expansion * embed_size, embed_size, rank),
            )

            self.dropout = nn.Dropout(dropout)

        def forward(self, value, key, query, mask):
            attention = self.attention(value, key, query, mask)
            x = self.dropout(self.norm1(attention + query))
            forward = self.feed_forward(x)
            out = self.dropout(self.norm2(forward + x))
            return out

    class LanguageModelDecoder(nn.Module):
        def __init__(self, vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, max_length, rank):
            super(Expert.LanguageModelDecoder, self).__init__()
            self.word_embedding = nn.Embedding(vocab_size, embed_size)
            self.position_embedding = nn.Embedding(max_length, embed_size)

            # Adding BatchNorm layers
            self.bn1 = nn.BatchNorm1d(embed_size)
            self.bn2 = nn.BatchNorm1d(embed_size)

            self.layers = nn.ModuleList(
                [
                    Expert.TransformerBlock(embed_size, heads, dropout, forward_expansion, rank)
                    for _ in range(num_layers)
                ]
            )

            # QLORA layers
            self.qlora_feed_forward = nn.Sequential(
                Expert.QLORALayer(embed_size, forward_expansion * embed_size, rank),
                nn.ReLU(),
                Expert.QLORALayer(forward_expansion * embed_size, embed_size, rank),
            )
            self.use_qlora = False  # Flag to toggle QLORA

            self.fc_out = nn.Linear(embed_size, vocab_size)
            self.dropout = nn.Dropout(dropout)

        def forward(self, x, trg_mask):
            N, seq_length = x.shape
            if seq_length > self.position_embedding.num_embeddings:
                raise ValueError(f"Sequence length {seq_length} exceeds maximum allowed {self.position_embedding.num_embeddings}")
            positions = torch.arange(0, seq_length).expand(N, seq_length).to(config.device)
            x = x.long() # make x a long type for the embeddings
            x = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

            print(f"Max position index: {positions.max().item()}, Position Embedding Size: {self.position_embedding.num_embeddings}")


            # Transpose for BatchNorm, apply batch normalization, and then transpose back
            x = x.transpose(1, 2)
            x = self.bn1(x)
            x = x.transpose(1, 2)

            for layer in self.layers:
                x = layer(x, x, x, trg_mask)
                if self.use_qlora:
                    x = self.qlora_feed_forward(x)

            # Transpose for BatchNorm, apply batch normalization, and then transpose back
            x = x.transpose(1, 2)
            x = self.bn2(x)
            x = x.transpose(1, 2)

            out = self.fc_out(x)
            #print(f"shape of output of forward method of LanguageModelDecoder: {out.shape} ")

            return out

        def toggle_qlora(self, use_qlora: bool):
            self.use_qlora = use_qlora

    class LanguageModelTransformer(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.vocab_size = config.vocab_size  # Use vocab_size from ExpertConfig
            self.embed_size = config.embed_size  # Use embed_size from ExpertConfig
            self.num_layers = config.num_layers  # Use num_layers from ExpertConfig
            self.forward_expansion = config.forward_expansion  # Use forward_expansion from ExpertConfig
            self.heads = config.heads  # Use heads from ExpertConfig
            self.dropout = config.dropout  # Use dropout from ExpertConfig
            self.max_length = config.max_length  # Use max_length from ExpertConfig
            self.rank = config.rank  # Use rank from ExpertConfig
            self.tokenizer_name = config.tokenizer_name  # Use tokenizer_name from ExpertConfig

            self.tokenizer = BertTokenizer.from_pretrained(self.tokenizer_name)

            self.decoder = Expert.LanguageModelDecoder(
                self.vocab_size,
                self.embed_size,
                self.num_layers,
                self.heads,
                self.forward_expansion,
                self.dropout,
                self.max_length,
                self.rank,
            )

        def forward(self, trg, attention_mask=None):  # Remove attention_mask here since it's not used
            print(f"Input shape to LanguageModelTransformer: {trg.shape}")
            trg_mask = self.make_trg_mask(trg)
            out = self.decoder(trg, trg_mask)  # Do not pass attention_mask here
            print(f"Language Model Transformer out shape: {out.shape}")
            return out

        def make_trg_mask(self, trg):
            N, trg_len = trg.shape
            trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(N, 1, trg_len, trg_len).to(trg.device)
            return trg_mask

        def toggle_qlora(self, use_qlora: bool):
            self.decoder.toggle_qlora(use_qlora)

        def generate_response(self, input_ids, attention_mask):
            logits = self.forward(input_ids=input_ids, attention_mask=attention_mask)
            probabilities = F.softmax(logits, dim=-1)
            predicted_token_id = torch.argmax(probabilities, dim=-1)
            predicted_tokens = [self.tokenizer.convert_ids_to_tokens(idx.item()) for idx in predicted_token_id]
            response = self.tokenizer.convert_tokens_to_string(predicted_tokens)
            return response

    ###############################
    # DPO
    class DPO(nn.Module):
        def __init__(self, language_model, device, embed_size):
            super(Expert.DPO, self).__init__()
            self.language_model = language_model
            self.device = device
            # Assuming embed_size is accessible and correct
            self.projection = nn.Linear(language_model.vocab_size, embed_size)  # Project from vocab_size to embed_size
            self.classifier = nn.Linear(embed_size, 2)  # Assuming embed_size is accessible
            

        def forward(self, input_ids, labels=None):
            logits = self.language_model(input_ids)  # Output shape: [batch_size, seq_len, vocab_size]

            # Remove the flattening before pooling if you intend to apply mean pooling over seq_len
            # Ensure logits are correctly projected to embed_size per token
            projected_logits = self.projection(logits.view(-1, logits.size(-1)))
            projected_logits = projected_logits.view(logits.size(0), logits.size(1), -1)  # Reshape to [batch_size, seq_len, embed_size]

            # Apply global mean pooling across the sequence length dimension correctly
            pooled_logits = projected_logits.mean(dim=1)  # Correctly applies mean pooling across seq_len

            predictions = self.classifier(pooled_logits)

            loss = None

            print(f"logits shape: {logits.shape}")
            print(f"projected_logits shape: {projected_logits.shape}")
            print(f"pooled_logits shape: {pooled_logits.shape}")
            print(f"predictions shape: {predictions.shape}")
            if labels is not None:
                print(f"labels shape: {labels.shape}")
                # Ensure labels are flattened if they're not already 1D
                if labels.dim() > 1:
                    labels = labels.view(-1)  # Flatten labels to match predictions shape
                loss_fct = nn.CrossEntropyLoss()  # Correctly instantiate the loss function
                loss = loss_fct(predictions, labels)

            return predictions, loss
        '''
        def forward(self, input_ids, labels=None):
            # Process combined input_ids directly, assuming they already include question, chosen, and rejected inputs
            # This assumes input_ids has shape [batch_size, sequence_length]

            logits = self.language_model(input_ids)  # Output shape: [batch_size, sequence_length, vocab_size]
            # Flattening logits from [batch_size, seq_len, vocab_size] to [batch_size * seq_len, vocab_size]
            logits = logits.view(-1, logits.size(-1))  # UNCOMMENT FOR EXPERT TRAINING

            # Ensure labels are also flattened to match the logits shape if doing token-level classification
            labels = labels.view(-1)  # UNCOMMENT FOR EXPERT TRAINING
            # Project logits to embedding space before pooling
            projected_logits = self.projection(logits)  # New shape: [batch_size, sequence_length, embed_size]

            # Apply global mean pooling across the sequence length dimension
            pooled_logits = projected_logits.mean(dim=1)  # New shape: [batch_size, embed_size]

            # Pass the pooled representation through the classifier
            predictions = self.classifier(pooled_logits)  # New shape: [batch_size, 2]

            # Calculate loss if labels are provided
            loss = None
            if labels is not None:
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(predictions.view(-1, 2), labels.view(-1))

            return predictions, loss

        '''
     
        def forward_expert(self, input_ids, attention_mask=None, context_texts=None, question_text=None):
            """
            Special forward method designed for use within the Expert Model, where the DPO component
            is one of several being routed and managed.
            """
            # Assuming `input_ids` are pre-processed appropriately by the calling method
            # This method might need context_texts and question_text for other purposes, or you can remove them if not used

            logits = self.language_model(input_ids)
            # Process logits as needed specifically for the Expert Model scenario
            # For instance, you might not apply pooling or projection in the same way

            # Here, adapt the processing based on the Expert Model's requirements
            # This is just a placeholder to illustrate that the processing could be different
            predictions = self.classifier(logits.mean(dim=1))

            # Note: No loss calculation here, assuming Expert Model manages loss across its components
            return predictions     


    def __init__(self, config: ExpertConfig):
        super().__init__()
        # Validate configuration
        config.validate()

        # 1. SparseFlash2_attention
        self.sparse_flash2_attention = Expert.SparseFlash2Attention(
            config.seq_len, 
            config.head_dim, 
            config.block_size, 
            config.sparsity_factor
        )

        # Inside the Expert class definition
        self.lmt = Expert.LanguageModelTransformer(config).to(config.device)

        self.transformer_rag = Expert.TransformerRAG(config).to(config.device)

        # Corrected instantiation of SimplifiedLanguageModelMAMBA
        self.mamba = Expert.SimplifiedLanguageModelMAMBA(config).to(config.device)

        # Now initialize DPO with the language model transformer
        self.transformer_dpo = Expert.DPO(self.lmt, config.device, config.embed_size).to(config.device)

        # 2. LayerNorm and Dropout
        self.layer_norm = nn.LayerNorm(config.seq_len)
        self.dropout = nn.Dropout(config.dropout)

        # 3. Internal Switch Routing
        self.switch_router = Expert.SwitchRouter(config)

        # 4. QLORA Layer
        self.qlora = Expert.QLORALayer(config.input_dim, config.input_dim, config.rank)


    def forward(self, x, attention_mask, context_texts, question_text):
        # 1. SparseFlash2_attention
        x = self.sparse_flash2_attention(x, x, x)  # Pass x as Q, K, and V
        print(f"Shape after SparseFlash2Attention: {x.shape}") 

        # 2. LayerNorm and Dropout
        x = self.dropout(self.layer_norm(x))
        print(f"Shape after LayerNorm: {x.shape}") 

        # 3. Internal Switch Routing
        x, aux_loss = self.switch_router(x, attention_mask, context_texts, question_text)

        # 4. QLORA Layer
        x = self.qlora(x)

        return x, aux_loss

    @staticmethod
    def calculate_new_alpha(current_loss, initial_loss, initial_alpha=1.0, final_alpha=0.1):
        """
        Calculate a new alpha value based on the current loss.
        """
        if current_loss >= initial_loss:
            return initial_alpha  # Keep initial alpha if loss isn't decreasing

        loss_ratio = current_loss / initial_loss
        alpha_range = initial_alpha - final_alpha
        new_alpha = final_alpha + (alpha_range * loss_ratio)
        return new_alpha  
    
    @staticmethod
    def setup_optimizer(model, learning_rate, weight_decay, warmup_steps, total_steps):
        optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

        # Linear warmup with cosine decay
        scheduler = LambdaLR(optimizer, lambda step: min((step + 1) / warmup_steps, 0.5 * (1 + math.cos(math.pi * step / total_steps))))

        return optimizer, scheduler
    ###############################
    # TRAINING METHODS

    # DPO Training
    def train_dpo(self, train_loader, optimizer, config, save_path):
        self.train()  # Set the model to training mode
        total_loss = 0

        for batch_idx, batch in enumerate(train_loader):
            input_ids = batch['input_ids'].to(config.device)  # Adjusted to use 'input_ids'
            labels = batch['labels'].to(config.device)

            optimizer.zero_grad()
            # Forward pass: Adjust the model's forward method to accept the single 'input_ids' input
            logits, loss = self.transformer_dpo(input_ids, labels=labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        average_loss = total_loss / len(train_loader)
        print(f"Training complete. Average Loss: {average_loss}")

        # Save the model
        torch.save(self.transformer_dpo.state_dict(), save_path)

        return average_loss

    # RAG Training
    def train_language_model_rag(self, model, train_loader, device, vocab_size,num_epochs=5):
        
        # Define loss function and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-8, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.98)
        
        initial_loss = None

        # Training loop
        for epoch in range(num_epochs):
            model.train()
            model.decoder.toggle_qlora(True)
            total_loss = 0

            for batch_idx, batch in enumerate(train_loader):
                inputs, targets = batch['input_ids'].to(device), batch['labels'].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                print("Output shape:", outputs.shape)
                print("Targets shape:", targets.shape)
                loss = criterion(outputs.contiguous().view(-1, 30522), targets.view(-1))


                # Check for NaN in loss
                if math.isnan(loss.item()):
                    print("Encountered NaN loss, stopping training")
                    break

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
                optimizer.step()

                total_loss += loss.item()

                # Set the initial_loss after the first batch of the first epoch
                if initial_loss is None and batch_idx == 0:
                    initial_loss = loss.item()

            scheduler.step()

            # Check for NaN in total_loss
            if math.isnan(total_loss):
                print(f"Epoch {epoch+1}/{num_epochs} stopped due to NaN loss")
                break
            else:
                print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(train_loader)}")

            # Average loss for the epoch
            average_loss = total_loss / len(train_loader)

            # Update alpha at the end of each epoch based on the average loss
            new_alpha = self.calculate_new_alpha(average_loss, initial_loss)
            for layer in model.modules():
                if isinstance(layer, Expert.QLORALayer):
                    layer.update_alpha(new_alpha)

        average_loss = total_loss / len(train_loader)
        torch.save(model.state_dict(), save_path)
        return model, average_loss


    # DPR Training
    def train_dpr_encoders(self, train_data, context_encoder, question_encoder, optimizer_context, optimizer_question, epochs, context_save_path, question_save_path):
        loss_function = nn.CosineEmbeddingLoss()

        for epoch in range(epochs):
            total_loss = 0

            for i in range(len(train_data["queries"])):
                query = train_data["queries"][i]
                context_list = train_data["contexts"][i]  # This is a list of context dicts

                # Tokenize query using the tokenize method and convert token IDs back to tensors for model input
                tokenized_query = config.wordpiece_tokenizer.tokenize(query)
                input_ids_query = torch.tensor([tokenized_query], dtype=torch.long).to(config.device)
                attention_mask_query = torch.ones_like(input_ids_query).to(config.device)

                # Since context is a list of dictionaries with 'input_ids', process each and average embeddings
                context_embeddings_list = []
                for context in context_list:
                    if 'input_ids' in context:
                        input_ids_context = torch.tensor([context['input_ids']], dtype=torch.long).to(config.device)
                        attention_mask_context = torch.ones_like(input_ids_context, dtype=torch.bool).to(config.device)

                        # Adjusted for using input_ids and attention_mask directly
                        context_embedding = context_encoder(input_ids_context, attention_mask_context)
                        context_embeddings_list.append(context_embedding)

                # Average the context embeddings if there are multiple contexts
                if context_embeddings_list:
                    context_embeddings = torch.mean(torch.stack(context_embeddings_list), dim=0)
                else:
                    raise ValueError("No valid contexts found for averaging embeddings.")

                # Forward pass for the query
                question_embeddings = question_encoder(input_ids_query, attention_mask_query)

                # Compute loss with labels for positive examples
                labels = torch.tensor([1.0] * question_embeddings.size(0), dtype=torch.float).to(config.device)
                loss = loss_function(question_embeddings, context_embeddings, labels)
                optimizer_context.zero_grad()
                optimizer_question.zero_grad()
                loss.backward()
                optimizer_context.step()
                optimizer_question.step()

                total_loss += loss.item()

            print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss / len(train_data['queries'])}")

        average_loss = total_loss / len(train_data['queries'])
        torch.save(context_encoder.state_dict(), context_save_path)
        torch.save(question_encoder.state_dict(), question_save_path)
        return (context_encoder, question_encoder), average_loss

       
    # LMT Training
    def train_language_model_transformer(self, train_loader, device, vocab_size, save_path):
        model = self.lmt
        
        # Define loss function and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-8, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.98)
        num_epochs = 5
        
        initial_loss = None

        # Training loop
        for epoch in range(num_epochs):
            model.train()
            model.decoder.toggle_qlora(True)
            total_loss = 0

            for batch_idx, batch in enumerate(train_loader):
                inputs, targets = batch['input_ids'].to(device), batch['labels'].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))

                # Check for NaN in loss
                if math.isnan(loss.item()):
                    print("Encountered NaN loss, stopping training")
                    break

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
                optimizer.step()

                total_loss += loss.item()

                # Set the initial_loss after the first batch of the first epoch
                if initial_loss is None and batch_idx == 0:
                    initial_loss = loss.item()

            scheduler.step()

            # Check for NaN in total_loss
            if math.isnan(total_loss):
                print(f"Epoch {epoch+1}/{num_epochs} stopped due to NaN loss")
                break
            else:
                print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(train_loader)}")

            # Average loss for the epoch
            average_loss = total_loss / len(train_loader)

            # Update alpha at the end of each epoch based on the average loss
            new_alpha = self.calculate_new_alpha(average_loss, initial_loss)
            for layer in model.modules():
                if isinstance(layer, Expert.QLORALayer):
                    layer.update_alpha(new_alpha)

        average_loss = total_loss / len(train_loader)
        torch.save(model.state_dict(), save_path)
        return model, average_loss

    # MAMBA Training
    def train_mamba(self, train_loader, num_epochs, config):
            # Initialize the optimizer and scheduler with MAMBA model parameters
            optimizer, scheduler = self.setup_optimizer(self.mamba, 
                                                        config.mamba_learning_rate, 
                                                        config.weight_decay, 
                                                        config.warmup_steps, 
                                                        config.total_mamba_steps)

            loss_fn = nn.CrossEntropyLoss()
            progress_bar = tqdm(range(num_epochs))

            for epoch in progress_bar:
                self.mamba.train()
                total_loss = 0

                for batch in train_loader:
                    input_values, attention_mask, labels = batch['input_ids'], batch['attention_mask'], batch['labels']
                    input_values = input_values.to(config.device)
                    attention_mask = attention_mask.to(config.device)
                    labels = labels.to(config.device)

                    optimizer.zero_grad()

                    # Forward pass through MAMBA model
                    outputs = self.mamba(input_values, attention_mask)
                    
                    # Calculate loss
                    loss = loss_fn(outputs.view(-1, config.vocab_size), labels.view(-1))
                    loss.backward()
                    
                    # Clip gradients and perform an optimization step
                    torch.nn.utils.clip_grad_norm_(self.mamba.parameters(), config.clip_gradient)
                    optimizer.step()
                    scheduler.step()

                    total_loss += loss.item()

                avg_loss = total_loss / len(train_loader)
                progress_bar.set_description(f"Epoch {epoch+1}/{num_epochs}, Avg Loss: {avg_loss:.4f}")

            # Save the trained MAMBA model
            torch.save(self.mamba.state_dict(), config.mamba_model_path)
            print(f"MAMBA Training Complete. Model saved to {config.mamba_model_path}")



    # Full Expert Training
    def train_expert(self, train_loader, train_data, optimizer, main_loss_function, aux_loss_weight, device,save_path, accumulation_steps=4, num_epochs=5):
            self.train()  # Set the model to training mode
            for epoch in range(num_epochs):
                total_loss = 0
                optimizer.zero_grad()  # Initialize gradients to zero
                for batch_idx, batch in enumerate(train_loader):
                    inputs, attention_mask, targets = batch['input_ids'].to(device), batch['attention_mask'].to(device), batch['labels'].to(device)

                    # Calculate start and end indices for current batch in train_data
                    start_idx = batch_idx * batch['input_ids'].size(0)
                    end_idx = start_idx + batch['input_ids'].size(0)

                    # Extract current_queries and current_contexts for the batch
                    current_queries = train_data['queries'][start_idx:end_idx]
                    current_contexts = train_data['contexts'][start_idx:end_idx]

                    # Call to the model forward function
                    outputs, aux_loss = self(inputs, attention_mask, current_contexts, current_queries)

                    # Calculate loss and accumulate
                    main_loss = main_loss_function(outputs.view(-1, outputs.size(-1)), targets.view(-1))
                    loss = (main_loss + aux_loss_weight * aux_loss) / accumulation_steps
                    loss.backward()

                    if (batch_idx + 1) % accumulation_steps == 0:
                        optimizer.step()
                        optimizer.zero_grad()

                    total_loss += loss.item() * accumulation_steps  # Scale back up

                average_loss = total_loss / len(train_loader)
                print(f'End of Epoch {epoch+1}, Average Loss: {average_loss}')
            average_loss = total_loss / len(train_loader)
            torch.save(self.state_dict(), save_path)
            return self, average_loss



##################################################
# Tokenizer

class TrieNode:
    def __init__(self):
        self.children = {}
        self.token_id = None
        self.frequency = 0
        self.failure_link = None
        self.is_end = False  # Add is_end attribute to mark the end of a word
        self.token = None  # Add token attribute to store the token associated with the node


class Trie:
    def __init__(self, unk_token_id=0):
        self.root = TrieNode()
        self.unk_token_id = unk_token_id

    def insert(self, token, token_id, frequency):
        node = self.root
        for char in token:
            if char not in node.children:
                node.children[char] = TrieNode()
            node = node.children[char]
        node.token_id = token_id
        node.frequency = frequency

    def find_subwords(self, token):
        """Finds the most probable subwords based on frequency."""
        node = self.root
        best_subwords = []

        def dfs(current_node, subword='', collected_subwords=[]):
            if current_node.token_id is not None:
                # Update to correctly calculate total_frequency based on the structure of collected_subwords
                total_frequency = sum(n.frequency for _, _, n in collected_subwords) + current_node.frequency
                probability = current_node.frequency / total_frequency if total_frequency else 0
                collected_subwords.append((subword, probability, current_node))

            for char, next_node in current_node.children.items():
                dfs(next_node, subword + char, list(collected_subwords))  # Create a copy of the list to avoid shared state

        dfs(node)
        best_subwords = sorted(best_subwords, key=lambda x: x[1], reverse=True)
        return [subword for subword, _, _ in best_subwords][:5] or [self.unk_token_id]


    def compute_failure_links(self):
        root = self.root
        root.failure_link = root  # Root's failure link points to itself
        queue = [root]

        while queue:
            current_node = queue.pop(0)

            for char, child_node in current_node.children.items():
                queue.append(child_node)

                # Follow failure link to find the longest suffix for the child_node
                failure_candidate = current_node.failure_link
                while failure_candidate != root and char not in failure_candidate.children:
                    failure_candidate = failure_candidate.failure_link
                child_node.failure_link = failure_candidate.children.get(char, root)


class SimpleSentencePiece:
    def __init__(self, model_type="bpe", vocab_size=30522):
        self.vocab = {}
        self.id_to_subword = {}
        self.unk_token = "[UNK]"
        self.unk_token_id = 0
        self.vocab_size = vocab_size
        self.model = None if model_type == "bpe" else None
        self.model_type = model_type

    def train(self, text):
        if self.model_type == "bpe":
            self.model = BPE(num_merges=self.vocab_size, unk_token_id=self.unk_token_id)
            self.model.train(text)
            self.vocab = self.model.vocab
            self.id_to_subword = {i: word for word, i in self.vocab.items()}
        else:
            raise NotImplementedError(f"Model type {self.model_type} not supported yet.")

    def encode(self, text):
        text = self.preprocess_text(text)  # Preprocess text before encoding
        if not self.model:
            raise ValueError("Model has not been trained yet.")
        encoded = self.model.encode(text)
        print(f"Encoded: {encoded[:10]}")  # Print first 10 encoded tokens
        return encoded

    def decode(self, ids):
        if not self.id_to_subword:
            raise ValueError("Vocabulary is empty. Ensure the model is trained first.")
        text = " ".join([self.id_to_subword.get(id_, self.unk_token) for id_ in ids])
        text = text.replace(" </w>", "").replace("</w>", " ").strip()
        return text

    def preprocess_text(self, text):
        # Convert text to lowercase to ensure case insensitivity
        text = text.lower()
        # Optionally, handle punctuation by adding spaces around it for better tokenization
        text = re.sub(r'([.,!?()])', r' \1 ', text)
        # Replace multiple spaces with a single space
        text = re.sub(r'\s+', ' ', text)
        # Trim leading and trailing spaces
        text = text.strip()
        return text
    
    def save_model(self, filepath):
        model_data = {
            'vocab': self.vocab,
            'id_to_subword': self.id_to_subword,
            'model_type': self.model_type,
            'vocab_size': self.vocab_size,
            # Potentially include other relevant attributes
        }
        # Save the high-level tokenizer settings
        with open(filepath, 'w') as f:
            json.dump(model_data, f)
        
        # Now save the BPE model specifically
        if self.model_type == "bpe" and self.model:
            self.model.save_model(filepath + "_bpe")

    def load_model(self, filepath):
        with open(filepath, 'r') as f:
            model_data = json.load(f)
        
        self.vocab = model_data['vocab']
        self.id_to_subword = model_data['id_to_subword']
        self.model_type = model_data['model_type']
        self.vocab_size = model_data['vocab_size']
        
        # Assuming model_type is still "bpe", we now load the BPE model
        if self.model_type == "bpe":
            self.model = BPE(self.vocab_size, self.unk_token_id)
            self.model.load_model(filepath + "_bpe")

class BPE:
    def __init__(self, num_merges=100, unk_token_id=0):  # Accept unk_token_id parameter
        self.vocab = {}
        self.merges = []
        self.num_merges = num_merges
        self.unk_token_id = unk_token_id  # Store the unknown token ID

    def train(self, text):
        words = re.findall(r'\w+|[^\w\s]', text, re.UNICODE)
        vocab = collections.Counter(words)
        vocab = {word + '</w>': count for word, count in vocab.items()}
        
        for _ in range(self.num_merges):  # Use the num_merges from the instance variable
            pairs = self.get_stats(vocab)
            if not pairs:
                break
            best = max(pairs, key=pairs.get)
            vocab = self.merge_vocab(best, vocab)
            self.merges.append(best)

        self.vocab = {word: i for i, word in enumerate(vocab.keys())}

    @staticmethod
    def get_stats(vocab):
        pairs = collections.defaultdict(int)
        for word, freq in vocab.items():
            symbols = word.split()
            for i in range(len(symbols)-1):
                pairs[symbols[i], symbols[i+1]] += freq
        return pairs

    @staticmethod
    def merge_vocab(pair, vocab):
        v_out = {}
        bigram = re.escape(' '.join(pair))
        p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
        for word in vocab:
            w_out = p.sub(''.join(pair), word)
            v_out[w_out] = vocab[word]
        return v_out

    def encode(self, text):
        """Encode text into subwords using learned BPE merges."""
        encoded_tokens = []
        for word in re.findall(r'\w+|[^\w\s]', text, re.UNICODE):
            word += '</w>'
            subwords = [word]  # Start with the entire word as one subword
            for merge in self.merges:
                new_subwords = []
                for subword in subwords:
                    # If the merge is in subword, split it; otherwise, keep it as is
                    if ' '.join(merge) in subword:
                        new_subwords.extend(subword.replace(' '.join(merge), ''.join(merge)).split(' '))
                    else:
                        new_subwords.append(subword)
                subwords = new_subwords
            encoded_tokens.extend(subwords)
        return [self.vocab.get(token, self.unk_token_id) for token in encoded_tokens]
    
        # New method to save trained model
    def save_model(self, filepath):
        bpe_data = {
            'merges': self.merges,
            'vocab': self.vocab,
            'num_merges': self.num_merges,
            # Include other attributes as needed
        }
        with open(filepath, 'w') as f:
            json.dump(bpe_data, f)

    def load_model(self, filepath):
        with open(filepath, 'r') as f:
            bpe_data = json.load(f)
        
        self.merges = bpe_data['merges']
        self.vocab = bpe_data['vocab']
        self.num_merges = bpe_data['num_merges']


class WordPiece:
    def __init__(self, vocab, unk_token_id=0, unk_token="[UNK]"):
        self.vocab = vocab
        self.unk_token_id = unk_token_id
        self.unk_token = unk_token  # Define the unknown token
        self.root = self.build_trie(vocab)
        self.id_to_token = {id_: token for token, id_ in vocab.items()}  # Inverse mapping
        self.compute_failure_links(self.root)
        print("Trie built successfully.")

    def convert_ids_to_tokens(self, ids):
        """
        Convert a list of token ids back to their string token representations.
        """
        return [self.id_to_token.get(id_, self.unk_token) for id_ in ids]

    # Add debug prints to build_trie to confirm structure
    def build_trie(self, vocab):
        root = TrieNode()
        for token in vocab:
            node = root
            for char in token:
                if char not in node.children:
                    node.children[char] = TrieNode()
                node = node.children[char]
            node.is_end = True
            node.token = token
        print("Trie Construction Completed Successfully")
        return root


    def compute_failure_links(self, root):
        queue = [root]
        while queue:
            current_node = queue.pop(0)
            for char, child_node in current_node.children.items():
                failure_node = current_node.failure_link
                while failure_node and char not in failure_node.children:
                    failure_node = failure_node.failure_link
                child_node.failure_link = failure_node.children[char] if failure_node else root
                queue.append(child_node)

    # Improved debug prints in tokenize method
                
    def tokenize(self, text):
        # Preprocess input text
        text = self.preprocess_text(text)
        node = self.root
        token_ids = []  # Will store token IDs instead of tokens
        i = 0

        while i < len(text):
            char = text[i]
            if char == ' ':
                node = self.root
                i += 1
                continue

            if char not in node.children:
                if node != self.root and node.token is not None:
                    # Convert found token to its ID
                    token_id = self.vocab.get(node.token, self.unk_token_id)
                    token_ids.append(token_id)
                    node = self.root  # Reset to root
                    continue
                else:
                    # Append unknown token ID
                    token_ids.append(self.unk_token_id)
                    i += 1
                    continue

            node = node.children[char]
            if node.is_end:
                if i + 1 == len(text) or text[i + 1] == ' ':
                    # Convert found token to its ID
                    token_id = self.vocab.get(node.token, self.unk_token_id)
                    token_ids.append(token_id)
                    node = self.root

            i += 1

        print(f"Token IDs: {token_ids[:10]}")
        return token_ids

    def preprocess_text(self, text):
        # Convert text to lowercase to ensure case insensitivity
        text = text.lower()

        # Optionally, handle punctuation by adding spaces around it for better tokenization
        # This depends on how your vocabulary handles punctuation
        text = re.sub(r'([.,!?()])', r' \1 ', text)

        # Replace multiple spaces with a single space
        text = re.sub(r'\s+', ' ', text)

        # Trim leading and trailing spaces
        text = text.strip()

        return text


def load_corpus(file_path):
    with open(file_path, 'r', encoding='utf-8') as file:
        texts = [line.strip() for line in file.readlines()]
    return texts

texts = load_corpus("D:\\EXPERT_WEIGHTS\\sample.txt")
# texts = load_corpus("C:/Users/robbi/Expert/sample.txt")
num_merges = 100

# Initialize and train the SimpleSentencePiece model with BPE
ssp = SimpleSentencePiece(model_type="bpe", vocab_size=30522)
# Assume `texts` is a list of text to train the tokenizer
ssp.train('\n'.join(texts))
wordpiece_vocab = ExpertConfig.adapt_vocab_for_wordpiece(ssp.vocab)
# Debugging step to ensure vocabulary completeness
def debug_vocab(adapted_vocab):
    print("Sample Vocabulary Check:")
    # Iterate over the first 10 key-value pairs in the adapted vocabulary
    for i, (token, id_or_freq) in enumerate(adapted_vocab.items()):
        print(f"{token}: {id_or_freq}")
        if i >= 9:  # Stop after printing 10 entries
            break
    # Specifically check for subtokens if your tokenizer expects them
    subtokens = [token for token in adapted_vocab.keys() if token.startswith("##")]
    print(f"Found {len(subtokens)} subtokens in vocabulary.")

# Ensure wordpiece_vocab is a list of vocabulary tokens
# debug_vocab(wordpiece_vocab)  # Call this after initializing wordpiece_vocab

# Initialize WordPiece with the adapted vocabulary
wordpiece_tokenizer = WordPiece(wordpiece_vocab, unk_token_id=0, unk_token="[UNK]")

def find_all_free_token_ids(vocab, max_id=30522):
    """
    Find all free token IDs up to a maximum ID.

    Parameters:
    - vocab: Dictionary, mapping from token to token ID.
    - max_id: Integer, the maximum token ID to consider.

    Returns:
    - List of free token IDs up to max_id.
    """
    used_ids = set(vocab.values())
    return [id_ for id_ in range(max_id) if id_ not in used_ids]

# Assuming wordpiece_vocab is your vocabulary
vocab = {token: id_ for token, id_ in wordpiece_vocab.items()}
free_ids = find_all_free_token_ids(vocab, max_id=30522)  # Adjust max_id as needed
#print("Free token IDs:", free_ids[:100])  # Print the first 100 free IDs for brevity



#################################################################################
# Instantiate Expert Model:
config = ExpertConfig(wordpiece_vocab=wordpiece_vocab, wordpiece_tokenizer=wordpiece_tokenizer)
expert_model = Expert(config)
##################################################################################
##################################################################################
# DPO TRAINING
'''# DPO Training
def preprocess_dpo_data(examples, tokenizer, max_length=512):
    # Calculate the length allocated to each component
    component_max_length = max_length // 3  # Dividing by 3 to equally distribute the max length

    # Tokenize text and ensure it fits within the allocated max_length for each component
    def tokenize_and_trim(text, tokenizer, max_length=component_max_length):
        token_ids = tokenizer.tokenize(text)
        # Trim to the max_length if necessary
        token_ids = token_ids[:max_length]
        return token_ids

    # Tokenize and adjust length for each field
    tokenized_questions = [tokenize_and_trim(question, tokenizer) for question in examples['question']]
    tokenized_chosen = [tokenize_and_trim(chosen, tokenizer) for chosen in examples['chosen']]
    tokenized_rejected = [tokenize_and_trim(rejected, tokenizer) for rejected in examples['rejected']]

    # Generate labels (adjust logic as necessary for your task)
    labels = [1 if i % 2 == 0 else 0 for i in range(len(tokenized_questions))]

    # Prepare final input IDs by concatenating the adjusted token IDs from each component
    # Note: This step may require adjustments based on your specific model input requirements.
    input_ids = [q + c + r for q, c, r in zip(tokenized_questions, tokenized_chosen, tokenized_rejected)]

    # Ensure concatenated input_ids do not exceed the total max_length
    input_ids = [ids[:max_length] for ids in input_ids]

    return {
        'input_ids': input_ids,  # Adjusted to return a single list of concatenated token IDs
        'labels': labels
    }


def updated_custom_collate_fn(batch):
    input_ids_list = [item['input_ids'] for item in batch]  # Assuming 'input_ids' is a list of token IDs
    labels = [item['labels'] for item in batch]

    # Convert list of token IDs to tensors
    input_ids_tensors = [torch.tensor(ids, dtype=torch.long) for ids in input_ids_list]
    labels_tensor = torch.tensor(labels, dtype=torch.long)

    # Pad the sequences so they all have the same length within this batch
    padded_input_ids = pad_sequence(input_ids_tensors, batch_first=True, padding_value=0)

    # Return a dictionary suitable for your model's input
    return {'input_ids': padded_input_ids, 'labels': labels_tensor}


# Assuming the rest of your code for dataset loading and tokenizer initialization remains unchanged
# Example usage:
dpo_dataset = load_dataset("Intel/orca_dpo_pairs")
max_seq_length = 512  # Adjust as needed
# Assuming wordpiece_tokenizer is an instance of your WordPiece class
dpo_dataset = dpo_dataset.map(lambda x: preprocess_dpo_data(x, wordpiece_tokenizer, max_seq_length), batched=True)

# Convert to PyTorch tensors after processing
dpo_dataset.set_format(type='torch', columns=['input_ids', 'labels'])

# Adjust the custom collate function to accept the tokenizer and max_length as arguments

train_loader = DataLoader(dpo_dataset['train'], batch_size=2, shuffle=True, collate_fn=updated_custom_collate_fn)

#train_loader = DataLoader(dpo_dataset['train'], batch_size=2, shuffle=True, collate_fn=custom_collate_fn)


# Instantiate the Expert model and optimizer
model_vocab_size = expert_model.lmt.decoder.word_embedding.num_embeddings
print(f"Model vocab size: {model_vocab_size}")
print(f"Model embedding size: {expert_model.lmt.decoder.word_embedding.embedding_dim}")
print(f"Configured max length: {config.max_length}")

optimizer = AdamW(expert_model.parameters(), lr=1e-5)
#save_path = 'D:/EXPERT_WEIGHTS/dpo_model.pth'
save_path = 'C:/Users/robbi/OneDrive/Expert_stuff/dpo_model.pth'

# Train the DPO model
label_column = 'labels'
input_columns = ['input_ids_question', 'attention_mask_question', 'input_ids_chosen', 'attention_mask_chosen', 'input_ids_rejected', 'attention_mask_rejected']
# Assuming expert_model is an instance of Expert
avg_loss = expert_model.train_dpo(train_loader, optimizer, config, save_path)
# Save the model
#torch.save(expert_model.transformer_dpo.state_dict(), save_path)
'''
##################################################################################
# LMT Training
'''
# Load the wikitext-2 dataset
#dataset = load_dataset("wikitext", "wikitext-2-v1")
dataset = load_dataset("wikitext", "wikitext-2-v1", split="train")

# Access a portion of the dataset for inspection
#print(dataset['train'][0])

def generate_attention_mask(token_ids):
    """Generate an attention mask for the given token IDs."""
    return [1 if token_id != 0 else 0 for token_id in token_ids]

def tokenize_function(examples):
    # Initialize lists for token IDs, attention masks, and labels
    token_ids, attention_masks, labels = [], [], []

    for text in examples["text"]:
        # Tokenize text and ensure at least one token ID is generated
        ids = wordpiece_tokenizer.tokenize(text) or [0]  # Replace [0] with your tokenizer's pad token ID
        # Generate attention mask for the tokenized text
        mask = [1] * len(ids)

        # Check for length of ids and pad/truncate as necessary
        if len(ids) < config.max_length:
            # Pad
            pad_length = config.max_length - len(ids)
            ids += [0] * pad_length  # Assuming 0 is the padding ID
            mask += [0] * pad_length
        else:
            # Truncate
            ids = ids[:config.max_length]
            mask = mask[:config.max_length]

        token_ids.append(ids)
        attention_masks.append(mask)
        labels.append(ids)  # For simplicity, using the same IDs as labels; adjust as needed for your model

    return {"input_ids": token_ids, "attention_mask": attention_masks, "labels": labels}



# Apply the custom tokenize function to the dataset
tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

train_loader = DataLoader(tokenized_datasets, batch_size=8, shuffle=True)



# Define save path for the trained model
save_path = 'D:/EXPERT_WEIGHTS/lmt_expert_trained_custom_tokenizer.pth'


# Train the LMT sub-model within the Expert system
trained_model, average_loss = expert_model.train_language_model_transformer(
    train_loader=train_loader, 
    device=config.device, 
    vocab_size=config.vocab_size, 
    save_path=save_path
)

print(f"Training complete. Model saved to {save_path}. Average Loss: {average_loss}")
'''
##################################################################################
# MAMBA Training
'''from datasets import load_dataset
from torch.utils.data import DataLoader
import torch

# Assuming the Expert class, ExpertConfig class, and your custom tokenizer have already been defined.

config = ExpertConfig()
expert_system = Expert(config)
dataset = load_dataset("wikitext", "wikitext-2-v1", split="train")

from torch.nn.utils.rnn import pad_sequence
import torch

def tokenize_function(examples):
    # Directly tokenize the text into input IDs using the custom tokenizer
    tokenized_outputs = [torch.tensor(wordpiece_tokenizer.tokenize(text), dtype=torch.long) for text in examples["text"]]
    
    # Pad sequences for uniform input size
    padded_input_ids = pad_sequence(tokenized_outputs, batch_first=True, padding_value=wordpiece_tokenizer.unk_token_id)
    
    # Generate attention masks
    attention_masks = (padded_input_ids != wordpiece_tokenizer.unk_token_id).float()
    
    # Shift input IDs to create labels, padding the last position
    labels = torch.cat([padded_input_ids[:, 1:], torch.full((padded_input_ids.shape[0], 1), wordpiece_tokenizer.unk_token_id, dtype=torch.long)], dim=1)
    
    return {"input_ids": padded_input_ids, "attention_mask": attention_masks, "labels": labels}

from torch.nn.utils.rnn import pad_sequence

def custom_collate_fn(batch):
    # Extract input_ids, attention_mask, and labels from the batch
    input_ids = [item['input_ids'] for item in batch]
    attention_masks = [item['attention_mask'] for item in batch]
    labels = [item['labels'] for item in batch]

    # Pad sequences to the maximum length in this batch
    input_ids_padded = pad_sequence(input_ids, batch_first=True, padding_value=wordpiece_tokenizer.unk_token_id)
    attention_masks_padded = pad_sequence(attention_masks, batch_first=True, padding_value=0)
    labels_padded = pad_sequence(labels, batch_first=True, padding_value=wordpiece_tokenizer.unk_token_id)

    # Convert lists of tensors to a single tensor for each type of data
    batch = {
        'input_ids': input_ids_padded,
        'attention_mask': attention_masks_padded,
        'labels': labels_padded
    }

    return batch


# Apply the custom tokenize function to the dataset
tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
tokenized_datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
train_loader = DataLoader(tokenized_datasets, batch_size=8, shuffle=True, collate_fn=custom_collate_fn)


# Train the MAMBA model
expert_system.train_mamba(train_loader, 5, config)
'''
####################################################################################
# RAG Transformer Training
# Load Wikipedia dataset and preprocess
#dataset = load_dataset("wikipedia", "20220301.en", split="train[:0000001%]")
'''
# Load the wikitext-2 dataset
dataset = load_dataset("wikitext", "wikitext-2-v1", split="train")
# Print the column names
#print(dataset.column_names)
# Access a portion of the dataset for inspection
#print(dataset['train'][0])

def generate_attention_mask(token_ids):
    """Generate an attention mask for the given token IDs."""
    return [1 if token_id != 0 else 0 for token_id in token_ids]

def tokenize_function(examples):
    token_ids, attention_masks, labels = [], [], []

    for text in examples["text"]:
        # Tokenize each word in the text and flatten the list of token IDs
        words = text.split()
        ids = []
        for word in words:
            word_ids = config.wordpiece_tokenizer.tokenize(word)
            ids.extend(word_ids)
        
        # Adjust for special tokens ([CLS] and [SEP])
        if len(ids) > config.max_length - 2:
            ids = ids[:config.max_length - 2]
        
        # Add [CLS] at the beginning and [SEP] at the end
        ids = [config.cls_token_id] + ids + [config.sep_token_id]

        attention_mask = [1] * len(ids)  # Attention mask with 1s for real tokens
        
        # Padding
        padding_length = config.max_length - len(ids)
        ids += [config.pad_token_id] * padding_length  # Pad token IDs
        attention_mask += [0] * padding_length  # Extend attention mask for padding

        token_ids.append(ids)
        attention_masks.append(attention_mask)
        labels.append(ids)  # Using the same IDs as labels; this might need adjustment

    return {"input_ids": token_ids, "attention_mask": attention_masks, "labels": labels}


# Apply the custom tokenize function to the dataset
tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

train_loader = DataLoader(tokenized_datasets, batch_size=8, shuffle=True)

# Train the LMT_Rag sub-model within the Expert system
# Train th model
model, average_loss = expert_model.train_language_model_rag(
    expert_model.transformer_rag.language_model,
    train_loader,
    config.device,
    vocab_size= len(config.wordpiece_vocab),
    num_epochs=5
)

# Rag data
train_rag_data = {
    "queries": [
        # Queries for DPO.pdf
        "What is Direct Preference Optimization (DPO)?",
        "How does Direct Preference Optimization work?",
        "How can I implement Direct Preference Optimization in my organization?",
        "Why does Direct Preference Optimization improve the efficiency of language modelling?",
        # Queries for MAMBA.pdf
        "What is MAMBA?",
        "How does MAMBA function?",
        "How can I build a system based on MAMBA technology?",
        "Why does MAMBA enhance the performance of its application area?",
        # Queries for QLORA.pdf
        "What is QLORA?",
        "How does QLORA operate?",
        "How can I develop a project using QLORA?",
        "Why does QLORA improve the capabilities of its relevant field?",
        # Queries for RAG.pdf
        "What is Retrieval Augmented Generation (RAG)?",
        "How does Retrieval Augmented Generation work?",
        "How can I build a Retrieval Augmented Generation model?",
        "Why does Retrieval Augmented Generation enhance language model performance?",
        # Queries for SWITCH_TRANSFORMER.pdf
        "What is the Switch Transformer model?",
        "How does the Switch Transformer model operate?",
        "How can I construct a Switch Transformer model?",
        "Why does the Switch Transformer model improve language processing tasks?"
    ],
    "contexts": [
        # Contexts from DPO.pdf
        config.rag_dataset[0],  # Assuming dataset[0] is the processed content of DPO.pdf
        config.rag_dataset[0],
        config.rag_dataset[0],        
        config.rag_dataset[0],        
        # Contexts from MAMBA.pdf
        config.rag_dataset[1],  # Assuming dataset[1] is the processed content of MAMBA.pdf
        config.rag_dataset[1], 
        config.rag_dataset[1], 
        config.rag_dataset[1], 
        # Contexts from QLORA.pdf
        config.rag_dataset[2],  # Assuming dataset[2] is the processed content of QLORA.pdf
        config.rag_dataset[2],
        config.rag_dataset[2],
        config.rag_dataset[2],
        # Contexts from RAG.pdf
        config.rag_dataset[3],  # Assuming dataset[3] is the processed content of RAG.pdf
        config.rag_dataset[3],
        config.rag_dataset[3],
        config.rag_dataset[3],
        # Contexts from SWITCH_TRANSFORMER.pdf
        config.rag_dataset[4],  # Assuming dataset[4] is the processed content of SWITCH_TRANSFORMER.pdf
        config.rag_dataset[4],
        config.rag_dataset[4],
        config.rag_dataset[4],
    ]
}

# Train DPR encoders
(context_encoder, question_encoder), average_loss = expert_model.train_dpr_encoders(
    train_rag_data,
    expert_model.transformer_rag.context_encoder,  
    expert_model.transformer_rag.question_encoder,  
    optimizer_context = AdamW(expert_model.transformer_rag.context_encoder.parameters(), lr=1e-5),  
    optimizer_question = AdamW(expert_model.transformer_rag.question_encoder.parameters(), lr=1e-5),
    epochs=5,
    context_save_path=config.context_encoder_path,
    question_save_path=config.question_encoder_path
)
'''
####################################################################################
# Expert Training
'''
# Load Wikipedia dataset and preprocess
dataset = load_dataset("wikitext", "wikitext-2-v1", split="train")

# Access a portion of the dataset for inspection
#print(dataset['train'][0])

def generate_attention_mask(token_ids):
    """Generate an attention mask for the given token IDs."""
    return [1 if token_id != 0 else 0 for token_id in token_ids]

def tokenize_function(examples):
    # Initialize lists for token IDs, attention masks, and labels
    token_ids, attention_masks, labels = [], [], []

    for text in examples["text"]:
        # Tokenize text and ensure at least one token ID is generated
        ids = wordpiece_tokenizer.tokenize(text) or [0]  # Replace [0] with your tokenizer's pad token ID
        # Generate attention mask for the tokenized text
        mask = [1] * len(ids)

        # Check for length of ids and pad/truncate as necessary
        if len(ids) < config.max_length:
            # Pad
            pad_length = config.max_length - len(ids)
            ids += [0] * pad_length  # Assuming 0 is the padding ID
            mask += [0] * pad_length
        else:
            # Truncate
            ids = ids[:config.max_length]
            mask = mask[:config.max_length]

        token_ids.append(ids)
        attention_masks.append(mask)
        labels.append(ids)  # For simplicity, using the same IDs as labels; adjust as needed for your model

    return {"input_ids": token_ids, "attention_mask": attention_masks, "labels": labels}



# Apply the custom tokenize function to the dataset
tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

train_loader = DataLoader(tokenized_datasets, batch_size=8, shuffle=True)

main_loss_function = torch.nn.CrossEntropyLoss()
aux_loss_weight = 0.1  # Adjust based on the significance of the auxiliary loss in your training

optimizer = torch.optim.AdamW(expert_model.parameters(), lr=1e-5, weight_decay=1e-4)
save_path = 'D:\\EXPERT_WEIGHTS\\expert_model_weights.pth'
# Rag data
train_rag_data = {
    "queries": [
        # Queries for DPO.pdf
        "What is Direct Preference Optimization (DPO)?",
        "How does Direct Preference Optimization work?",
        "How can I implement Direct Preference Optimization in my organization?",
        "Why does Direct Preference Optimization improve the efficiency of language modelling?",
        # Queries for MAMBA.pdf
        "What is MAMBA?",
        "How does MAMBA function?",
        "How can I build a system based on MAMBA technology?",
        "Why does MAMBA enhance the performance of its application area?",
        # Queries for QLORA.pdf
        "What is QLORA?",
        "How does QLORA operate?",
        "How can I develop a project using QLORA?",
        "Why does QLORA improve the capabilities of its relevant field?",
        # Queries for RAG.pdf
        "What is Retrieval Augmented Generation (RAG)?",
        "How does Retrieval Augmented Generation work?",
        "How can I build a Retrieval Augmented Generation model?",
        "Why does Retrieval Augmented Generation enhance language model performance?",
        # Queries for SWITCH_TRANSFORMER.pdf
        "What is the Switch Transformer model?",
        "How does the Switch Transformer model operate?",
        "How can I construct a Switch Transformer model?",
        "Why does the Switch Transformer model improve language processing tasks?"
    ],
    "contexts": [
        # Contexts from DPO.pdf
        config.rag_dataset[0],  # Assuming dataset[0] is the processed content of DPO.pdf
        config.rag_dataset[0],
        config.rag_dataset[0],        
        config.rag_dataset[0],        
        # Contexts from MAMBA.pdf
        config.rag_dataset[1],  # Assuming dataset[1] is the processed content of MAMBA.pdf
        config.rag_dataset[1], 
        config.rag_dataset[1], 
        config.rag_dataset[1], 
        # Contexts from QLORA.pdf
        config.rag_dataset[2],  # Assuming dataset[2] is the processed content of QLORA.pdf
        config.rag_dataset[2],
        config.rag_dataset[2],
        config.rag_dataset[2],
        # Contexts from RAG.pdf
        config.rag_dataset[3],  # Assuming dataset[3] is the processed content of RAG.pdf
        config.rag_dataset[3],
        config.rag_dataset[3],
        config.rag_dataset[3],
        # Contexts from SWITCH_TRANSFORMER.pdf
        config.rag_dataset[4],  # Assuming dataset[4] is the processed content of SWITCH_TRANSFORMER.pdf
        config.rag_dataset[4],
        config.rag_dataset[4],
        config.rag_dataset[4],
    ]
}
# Train the model
trained_expert_model, average_loss = expert_model.train_expert(
    train_loader=train_loader,
    train_data=train_rag_data,
    optimizer=optimizer,
    main_loss_function=main_loss_function,
    aux_loss_weight=aux_loss_weight,
    device=config.device,
    save_path=save_path,
    accumulation_steps=4,  # Adjust based on your preference
    num_epochs=5  # Adjust based on your training needs
)

print(f"Training completed. Average loss: {average_loss}")
'''

Trie Construction Completed Successfully
Trie built successfully.
Token IDs: [766, 0, 0, 0, 0, 0, 0, 1360, 162, 0]
Token IDs: [1408, 1489, 116, 475, 767, 29, 746, 2249, 43, 2220]
Token IDs: [0, 0, 0, 0, 0, 408, 62, 1701, 49, 0]
Token IDs: [69, 17, 38, 0, 0, 0, 17, 0, 17, 134]
Token IDs: [72, 76, 0, 808, 0, 202, 1899, 2415, 2003, 62]
Token IDs: [0, 0, 0, 0, 0, 1324, 0, 0, 0, 0]
Token IDs: [66, 879, 261, 262, 0, 0, 0, 0, 0, 0]
Token IDs: [0, 0, 0, 0, 408, 29, 746, 2249, 43, 1386]
Token IDs: [108, 746, 2249, 269, 270, 342, 539, 207, 39, 698]
Token IDs: [213, 80, 176, 29, 202, 55, 258, 1921, 589, 62]
Token IDs: [66, 446, 879, 29, 202, 0, 0, 0, 1867, 173]
Token IDs: [52, 766, 0, 0, 0, 0, 0, 0, 0, 0]
Token IDs: [1402, 207, 39, 2019, 2143, 0, 0, 1428, 29, 174]
Token IDs: [66, 55, 0, 0, 0, 0, 0, 0, 1953, 80]
Token IDs: [213, 1576, 80, 803, 80, 746, 2249, 17, 427, 0]
Token IDs: [2090, 1868, 108, 1386, 76, 1387, 114, 1335, 116, 0]
Token IDs: [213, 80, 176, 262, 116, 2330, 0, 0, 0, 0]
Token IDs: 

'\n# Load Wikipedia dataset and preprocess\ndataset = load_dataset("wikitext", "wikitext-2-v1", split="train")\n\n# Access a portion of the dataset for inspection\n#print(dataset[\'train\'][0])\n\ndef generate_attention_mask(token_ids):\n    """Generate an attention mask for the given token IDs."""\n    return [1 if token_id != 0 else 0 for token_id in token_ids]\n\ndef tokenize_function(examples):\n    # Initialize lists for token IDs, attention masks, and labels\n    token_ids, attention_masks, labels = [], [], []\n\n    for text in examples["text"]:\n        # Tokenize text and ensure at least one token ID is generated\n        ids = wordpiece_tokenizer.tokenize(text) or [0]  # Replace [0] with your tokenizer\'s pad token ID\n        # Generate attention mask for the tokenized text\n        mask = [1] * len(ids)\n\n        # Check for length of ids and pad/truncate as necessary\n        if len(ids) < config.max_length:\n            # Pad\n            pad_length = config.max_length -

# Switch Transformer / Mixture of Experts

In [15]:
# Define the Switch Transformer / Mixture of Experts architecture
class SwitchTransformerMoE(nn.Module):
    def __init__(self, experts, config):
        super(SwitchTransformerMoE, self).__init__()
        self.experts = nn.ModuleList(experts)
        self.config = config
        self.sparse_flash2_attention = Expert.SparseFlash2Attention(
            config.seq_len, 
            config.head_dim, 
            config.block_size, 
            config.sparsity_factor
        )        
        self.layer_norm = nn.LayerNorm(config.seq_len)
        self.dropout = nn.Dropout(config.dropout)
        self.switch_layer = Expert.SwitchRouter(config)
        print(f"SwitchTransformerMoE initialized with embed_size={config.embed_size}, heads={config.heads}")
        self.qlora = Expert.QLORALayer(config.input_dim, config.input_dim, config.rank)

    def forward(self, x, attention_mask, context_texts, question_text):
        print(f"\nForward pass started")
        print(f"x initial shape: {x.shape}")
        print(f"About to call MultiheadAttention with x of shape: {x.shape}")
        x = self.sparse_flash2_attention(x, x, x)        
        print(f"Sparse Flash 2 Attention output shape: {x.shape}")
        
        x = self.dropout(self.layer_norm(x))

        
        # Debugging the dimensions and type of x before it's sent to switch_layer
        print(f"x shape before switch_layer: {x.shape}, dtype: {x.dtype}")

        x, aux_loss = self.switch_layer(x, attention_mask, context_texts, question_text)
        print(f"x shape after switch_layer: {x.shape}, aux_loss: {aux_loss}")
        x = self.qlora(x)
        return x, aux_loss


config = ExpertConfig(wordpiece_vocab=wordpiece_vocab, wordpiece_tokenizer=wordpiece_tokenizer)
expert_model_1 = Expert(config)
expert_model_2 = Expert(config)
expert_model_3 = Expert(config)
expert_model_4 = Expert(config)
expert_model_5 = Expert(config)

switch_transformer_moe = SwitchTransformerMoE(
    experts=[
        expert_model_1,
        expert_model_2,
        expert_model_3,
        expert_model_4,
        expert_model_5
    ],
    config=config
)
# Rag data
train_rag_data = {
    "queries": [
        # Queries for DPO.pdf
        "What is Direct Preference Optimization (DPO)?",
        "How does Direct Preference Optimization work?",
        "How can I implement Direct Preference Optimization in my organization?",
        "Why does Direct Preference Optimization improve the efficiency of language modelling?",
        # Queries for MAMBA.pdf
        "What is MAMBA?",
        "How does MAMBA function?",
        "How can I build a system based on MAMBA technology?",
        "Why does MAMBA enhance the performance of its application area?",
        # Queries for QLORA.pdf
        "What is QLORA?",
        "How does QLORA operate?",
        "How can I develop a project using QLORA?",
        "Why does QLORA improve the capabilities of its relevant field?",
        # Queries for RAG.pdf
        "What is Retrieval Augmented Generation (RAG)?",
        "How does Retrieval Augmented Generation work?",
        "How can I build a Retrieval Augmented Generation model?",
        "Why does Retrieval Augmented Generation enhance language model performance?",
        # Queries for SWITCH_TRANSFORMER.pdf
        "What is the Switch Transformer model?",
        "How does the Switch Transformer model operate?",
        "How can I construct a Switch Transformer model?",
        "Why does the Switch Transformer model improve language processing tasks?"
    ],
    "contexts": [
        # Contexts from DPO.pdf
        config.rag_dataset[0],  # Assuming dataset[0] is the processed content of DPO.pdf
        config.rag_dataset[0],
        config.rag_dataset[0],        
        config.rag_dataset[0],        
        # Contexts from MAMBA.pdf
        config.rag_dataset[1],  # Assuming dataset[1] is the processed content of MAMBA.pdf
        config.rag_dataset[1], 
        config.rag_dataset[1], 
        config.rag_dataset[1], 
        # Contexts from QLORA.pdf
        config.rag_dataset[2],  # Assuming dataset[2] is the processed content of QLORA.pdf
        config.rag_dataset[2],
        config.rag_dataset[2],
        config.rag_dataset[2],
        # Contexts from RAG.pdf
        config.rag_dataset[3],  # Assuming dataset[3] is the processed content of RAG.pdf
        config.rag_dataset[3],
        config.rag_dataset[3],
        config.rag_dataset[3],
        # Contexts from SWITCH_TRANSFORMER.pdf
        config.rag_dataset[4],  # Assuming dataset[4] is the processed content of SWITCH_TRANSFORMER.pdf
        config.rag_dataset[4],
        config.rag_dataset[4],
        config.rag_dataset[4],
    ]
}
# Example training loop (simplified)
num_epochs = 5
# Load Wikipedia dataset and preprocess
dataset = load_dataset("wikitext", "wikitext-2-v1", split="train")

# Access a portion of the dataset for inspection
#print(dataset['train'][0])

def generate_attention_mask(token_ids):
    """Generate an attention mask for the given token IDs."""
    return [1 if token_id != 0 else 0 for token_id in token_ids]

def tokenize_function(examples):
    # Initialize lists for token IDs, attention masks, and labels
    token_ids, attention_masks, labels = [], [], []

    for text in examples["text"]:
        # Tokenize text and ensure at least one token ID is generated
        ids = wordpiece_tokenizer.tokenize(text) or [0]  # Replace [0] with your tokenizer's pad token ID
        # Generate attention mask for the tokenized text
        mask = [1] * len(ids)

        # Check for length of ids and pad/truncate as necessary
        if len(ids) < config.max_length:
            # Pad
            pad_length = config.max_length - len(ids)
            ids += [0] * pad_length  # Assuming 0 is the padding ID
            mask += [0] * pad_length
        else:
            # Truncate
            ids = ids[:config.max_length]
            mask = mask[:config.max_length]

        token_ids.append(ids)
        attention_masks.append(mask)
        labels.append(ids)  # For simplicity, using the same IDs as labels; adjust as needed for your model

    return {"input_ids": token_ids, "attention_mask": attention_masks, "labels": labels}



# Apply the custom tokenize function to the dataset
tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

train_loader = DataLoader(tokenized_datasets, batch_size=1, shuffle=True)
optimizer = torch.optim.Adam(switch_transformer_moe.parameters(), lr=1e-4)
main_loss_function = torch.nn.CrossEntropyLoss()
aux_loss_weight = 0.2
accumulation_steps = 4
for epoch in range(num_epochs):
    total_loss = 0
    optimizer.zero_grad()  # Initialize gradients to zero
    for batch_idx, batch in enumerate(train_loader):
        inputs, attention_mask, targets = batch['input_ids'].to(config.device), batch['attention_mask'].to(config.device), batch['labels'].to(config.device)

        # Calculate start and end indices for current batch in train_data
        start_idx = batch_idx * batch['input_ids'].size(0)
        end_idx = start_idx + batch['input_ids'].size(0)

        # Extract current_queries and current_contexts for the batch
        current_queries = train_rag_data['queries'][start_idx:end_idx]
        current_contexts = train_rag_data['contexts'][start_idx:end_idx]

        # Call to the model forward function
        outputs, aux_loss = switch_transformer_moe(inputs, attention_mask, current_contexts, current_queries)

        # Calculate loss and accumulate
        main_loss = main_loss_function(outputs.view(-1, outputs.size(-1)), targets.view(-1))
        loss = (main_loss + aux_loss_weight * aux_loss) / accumulation_steps
        loss.backward()

        if (batch_idx + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        total_loss += loss.item() * accumulation_steps  # Scale back up

    average_loss = total_loss / len(train_loader)
    print(f'End of Epoch {epoch+1}, Average Loss: {average_loss}')
average_loss = total_loss / len(train_loader)
#torch.save(self.state_dict(), save_path)




Token IDs: [766, 0, 0, 0, 0, 0, 0, 1360, 162, 0]
Token IDs: [1408, 1489, 116, 475, 767, 29, 746, 2249, 43, 2220]
Token IDs: [0, 0, 0, 0, 0, 408, 62, 1701, 49, 0]
Token IDs: [69, 17, 38, 0, 0, 0, 17, 0, 17, 134]
Token IDs: [72, 76, 0, 808, 0, 202, 1899, 2415, 2003, 62]
Token IDs: [0, 0, 0, 0, 0, 1324, 0, 0, 0, 0]
Token IDs: [66, 879, 261, 262, 0, 0, 0, 0, 0, 0]
Token IDs: [0, 0, 0, 0, 408, 29, 746, 2249, 43, 1386]
Token IDs: [108, 746, 2249, 269, 270, 342, 539, 207, 39, 698]
Token IDs: [213, 80, 176, 29, 202, 55, 258, 1921, 589, 62]
Token IDs: [66, 446, 879, 29, 202, 0, 0, 0, 1867, 173]
Token IDs: [52, 766, 0, 0, 0, 0, 0, 0, 0, 0]
Token IDs: [1402, 207, 39, 2019, 2143, 0, 0, 1428, 29, 174]
Token IDs: [66, 55, 0, 0, 0, 0, 0, 0, 1953, 80]
Token IDs: [213, 1576, 80, 803, 80, 746, 2249, 17, 427, 0]
Token IDs: [2090, 1868, 108, 1386, 76, 1387, 114, 1335, 116, 0]
Token IDs: [213, 80, 176, 262, 116, 2330, 0, 0, 0, 0]
Token IDs: [177, 0, 0, 0, 1360, 65, 0, 0, 312, 180]
Token IDs: [908, 0, 207, 

Map:   0%|          | 0/36718 [00:00<?, ? examples/s]

Token IDs: []
Token IDs: [640, 0, 0, 0, 0, 0, 0, 0, 1916, 1520]
Token IDs: []
Token IDs: [0, 0, 627, 0, 0, 0, 0, 180, 162, 522]
Token IDs: [55, 0, 0, 0, 0, 869, 58, 0, 17, 0]
Token IDs: [321, 207, 0, 0, 0, 0, 58, 0, 0, 0]
Token IDs: []
Token IDs: [640, 640, 0, 0, 1942, 640, 640]
Token IDs: []
Token IDs: [80, 207, 466, 522, 614, 524, 0, 0, 0, 1916]
Token IDs: [55, 0, 0, 0, 262, 0, 0, 0, 374, 17]
Token IDs: [0, 0, 0, 0, 87, 2364, 132, 1565, 1548, 162]
Token IDs: []
Token IDs: [640, 640, 640, 640]
Token IDs: []
Token IDs: [55, 0, 0, 2255, 1453, 1273, 55, 1463, 66, 0]
Token IDs: [80, 55, 0, 0, 0, 1230, 301, 0, 0, 0]
Token IDs: [522, 614, 524, 507, 62, 841, 0, 17, 134, 1196]
Token IDs: []
Token IDs: [640, 640, 869, 640, 640]
Token IDs: []
Token IDs: [0, 0, 0, 336, 43, 0, 0, 0, 0, 0]
Token IDs: [55, 0, 0, 0, 0, 0, 116, 376, 1617, 43]
Token IDs: []
Token IDs: [640, 640, 640, 0, 0, 640, 640, 640]
Token IDs: []
Token IDs: [55, 0, 0, 1045, 2343, 108, 0, 0, 0, 0]
Token IDs: []
Token IDs: [640, 64

RuntimeError: shape mismatch: value tensor of shape [512, 30522] cannot be broadcast to indexing result of shape [1, 512]

# Switch Transformer V2

In [16]:
import os

# Set the HF_HOME environment variable to a new cache directory on the D drive
#os.environ['HF_HOME'] = 'D:/hf_datasets_cache'
#os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

# Test on larger text
import re
import collections
from collections import Counter, defaultdict
import json


# 0. Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from transformers import BertTokenizer
from torch.utils.data import DataLoader
import fitz  # PyMuPDF
from transformers import BertModel
import math
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from tqdm.notebook import tqdm
from torch.nn.utils.rnn import pad_sequence




class ExpertConfig:
    def __init__(self,wordpiece_vocab,wordpiece_tokenizer, cls_token_id=1770, 
                 sep_token_id=1771, pad_token_id=0, 
                 seq_len=512, head_dim=64, block_size=64, sparsity_factor=4,
                 input_dim=512, num_experts=3, vocab_size=30522, embed_size=256,
                 num_layers=6, forward_expansion=4, heads=8, dropout=0.1,
                 max_length=512,  # Updated to match seq_len for consistency
                 rank=16, device='cpu',
                 mamba_model_path='D:\\EXPERT_WEIGHTS\\mamba_model_weights.pth',
                 rag_model_path='D:\\EXPERT_WEIGHTS\\rag_model_weights.pth',
                 context_encoder_path='D:\\EXPERT_WEIGHTS\\context_encoder.pth',
                 language_model_path='D:\\EXPERT_WEIGHTS\\language_model.pth',
                 question_encoder_path='D:\\EXPERT_WEIGHTS\\question_encoder.pth',
                 dpo_model_path='D:\\EXPERT_WEIGHTS\\dpo_model_weights.pth',
                 model_name='bert-base-uncased', embedding_dim=768,
                 alpha=1, quantization_bits=8, tokenizer_name='bert-base-uncased',                 
                  freq_threshold=100, d_model=512, d_state=2048, d_conv=3, expansion_factor=2, 
                 clip_gradient = 1.0, mamba_learning_rate = 5e-4, weight_decay = 0.1,
                 warmup_steps = 10, total_mamba_steps = 100
                ):

        # Common hyperparameters
        self.freq_threshold = freq_threshold
        self.wordpiece_vocab = wordpiece_vocab
        self.wordpiece_tokenizer = wordpiece_tokenizer        
        self.seq_len = seq_len
        self.pad_token_id = pad_token_id
        self.cls_token_id = cls_token_id
        self.sep_token_id = sep_token_id
        self.head_dim = head_dim
        self.block_size = block_size
        self.sparsity_factor = sparsity_factor
        self.input_dim = input_dim
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.num_layers = num_layers
        self.forward_expansion = forward_expansion
        self.heads = heads
        self.dropout = dropout
        self.max_length = max_length  
        self.rank = rank

        # Model paths and device
        self.mamba_model_path = mamba_model_path
        self.rag_model_path = rag_model_path
        self.context_encoder_path = context_encoder_path
        self.language_model_path = language_model_path
        self.question_encoder_path = question_encoder_path
        self.dpo_model_path = dpo_model_path
        self.device = device

        # Unique hyperparameters
        self.num_experts = num_experts
        self.model_name = model_name
        self.embedding_dim = embedding_dim
        self.alpha = alpha
        self.quantization_bits = quantization_bits
        self.tokenizer_name = tokenizer_name
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expansion_factor = expansion_factor
        self.clip_gradient = clip_gradient
        self.mamba_learning_rate = mamba_learning_rate
        self.weight_decay = weight_decay
        self.warmup_steps = warmup_steps
        self.total_mamba_steps = total_mamba_steps

        # PDFs (unchanged)
        self.pdf_file_paths = [
            #r'C:\Users\robbi\IEEMM\DPO.pdf', 
            r'C:\Users\robbi\OneDrive\AI_Papers_Research\DPO.pdf',
            #r'C:\Users\robbi\IEEMM\MAMBA.pdf',
            r'C:\Users\robbi\OneDrive\AI_Papers_Research\DPO.pdf',

            #r'C:\Users\robbi\IEEMM\QLORA.pdf',
            r'C:\Users\robbi\OneDrive\AI_Papers_Research\DPO.pdf',

            #r'C:\Users\robbi\IEEMM\RAG.pdf',
            r'C:\Users\robbi\OneDrive\AI_Papers_Research\DPO.pdf',

            #r'C:\Users\robbi\IEEMM\SWITCH_TRANSFORMER.pdf'
            r'C:\Users\robbi\OneDrive\AI_Papers_Research\DPO.pdf',

        ]
        
        # Preserving original dataset loading functionality
        self.rag_dataset = Expert.TransformerRAG.create_dataset_from_pdfs(self.pdf_file_paths, self.wordpiece_tokenizer)

        
    @staticmethod
    def adapt_vocab_for_wordpiece(ssp_vocab):
        adapted_vocab = {}
        for token, id_or_freq in ssp_vocab.items():
            if not token.startswith(" ") and not token.endswith("</w>"):
                adapted_token = "##" + token.replace("</w>", "")
            else:
                adapted_token = token.replace("</w>", "")
            adapted_vocab[adapted_token] = id_or_freq
        return adapted_vocab

        
    def validate(self):
        assert self.seq_len % self.block_size == 0, "seq_len must be divisible by block_size"
        assert self.max_length >= self.seq_len, "max_length should be equal to or greater than seq_len"



class Expert(nn.Module):

    @staticmethod
    # Load model weights function
    def load_model_weights(model, model_path):
        #checkpoint = torch.load(model_path, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
        checkpoint = torch.load(model_path, map_location=config.device)

        if isinstance(checkpoint, dict):
            # Check for 'state_dict' or 'model_state_dict' keys
            if 'state_dict' in checkpoint:
                model.load_state_dict(checkpoint['state_dict'])
            elif 'model_state_dict' in checkpoint:
                model.load_state_dict(checkpoint['model_state_dict'])
            else:
                # If no known key is found, try loading it as a raw state dictionary
                try:
                    model.load_state_dict(checkpoint)
                except RuntimeError as e:
                    raise ValueError(f"Error loading state dict: {e}")
        elif isinstance(checkpoint, nn.Module):
            # If the checkpoint is a model object, assign it directly
            model = checkpoint
        else:
            raise ValueError(f"Unsupported checkpoint format: {type(checkpoint)}")

        model.eval()
        return model

    ###############################
    # Flash2_Attention
    class FlashAttention2(nn.Module):
        def __init__(self, sequence_length, head_dimension, block_size):
            super(Expert.FlashAttention2, self).__init__()
            self.block_size = block_size
            # Ensure that sequence_length is divisible by block_size for simplicity
            assert sequence_length % block_size == 0

        def forward(self, Q, K, V):
            # Partitioning of inputs
            Q_blocks, K_blocks, V_blocks = self.partition_inputs(Q, K, V)

            # Efficient computation of the attention mechanism
            outputs = []
            for i, Q_block in enumerate(Q_blocks):
                output_block = self.process_block(Q_block, K_blocks, V_blocks)
                outputs.append(output_block)

            # Concatenating the processed blocks
            output = torch.cat(outputs, dim=0)
            return output

        def partition_inputs(self, Q, K, V):
            # Calculate the number of chunks, ensuring it's at least 1
            num_chunks_Q = max(1, Q.size(0) // self.block_size)
            num_chunks_K = max(1, K.size(0) // self.block_size)
            num_chunks_V = max(1, V.size(0) // self.block_size)
            
            # Chunk the inputs
            Q_blocks = Q.chunk(chunks=num_chunks_Q, dim=0)
            K_blocks = K.chunk(chunks=num_chunks_K, dim=0)
            V_blocks = V.chunk(chunks=num_chunks_V, dim=0)
            
            return Q_blocks, K_blocks, V_blocks

        def process_block(self, Q_block, K_blocks, V_blocks):
            # Process each block efficiently as per FLASH2's optimized method
            # This includes computing QK^T, applying online softmax, and multiplying with V
            output_blocks = []
            for K_block, V_block in zip(K_blocks, V_blocks):
                print(f"Q_block: {Q_block.shape} , K_block.transpose(-2, -1): {K_block.transpose(-2, -1).shape}")
                attention_scores = torch.matmul(Q_block, K_block.transpose(-2, -1))
                print(f"attention_scores = torch.matmul(Q_block, K_block.transpose(-2, -1)): {attention_scores.shape}")
                attention_scores = attention_scores.float()
                attention_scores = self.online_softmax(attention_scores)
                output_block = torch.matmul(attention_scores, V_block.float())
                print(f"output_block shape:, {output_block.shape}") 
                output_blocks.append(output_block)

            # Summing up the results from each block
            output_block_sum = sum(output_blocks)
            return output_block_sum

        def online_softmax(self, scores, chunk_size=128):
            # Apply softmax in chunks for large sequences
            softmaxed_scores = []
            for i in range(0, scores.size(0), chunk_size):
                chunk = scores[i:i + chunk_size, :]
                softmaxed_chunk = F.softmax(chunk, dim=1)
                softmaxed_scores.append(softmaxed_chunk)
            return torch.cat(softmaxed_scores, dim=0)

    ###############################
    # SparseFlash2_Attention
    class SparseFlash2Attention(nn.Module):
        def __init__(self, seq_len, head_dim, blk_size, sparsity_factor):
            super().__init__()
            self.flash_attention = Expert.FlashAttention2(seq_len, head_dim, blk_size)
            self.seq_len = seq_len
            self.head_dim = head_dim
            self.block_size = blk_size  # Storing block_size as an instance variable
            self.sparsity_factor = sparsity_factor

        def generate_sparsity_mask(self):
            mask = torch.zeros(self.seq_len, self.seq_len)
            step = self.sparsity_factor
            for i in range(0, self.seq_len, step):
                mask[i:i + step, :] = 1
            return mask.bool()

        def forward(self, Q, K, V):
            output = self.flash_attention(Q, K, V)  # output shape: [sequence_length, head_dimension]
            print(f"Initial output shape: {output.shape}")

            # Reshape output to be 3D for batch matrix multiplication
            output = output.unsqueeze(0)  # New shape: [1, sequence_length, head_dimension]
            print(f"output.unsqueeze(0) shape: {output.shape}")

            sparsity_mask = self.generate_sparsity_mask()  # shape: [sequence_length, sequence_length]
            print(f"Sparsity mask shape: {sparsity_mask.shape}")

            # Apply the sparsity mask to the output
            sparsity_mask = sparsity_mask.unsqueeze(0)  # New shape: [1, sequence_length, sequence_length]
            print(f"sparsity mask sparsity_mask.unsqueeze(0) shape: {sparsity_mask.shape}")
            output = torch.bmm(sparsity_mask.float(), output.float().transpose(1, 2))  
            print(f"output from torch.bmm shape: {output.shape}")
            output = output.transpose(1, 2)
            # Reshape the output back to 2D
            output = output.squeeze(0)  # New shape: [sequence_length, head_dimension]

            return output


    ###############################
    # MAMBA
    # RoPE
    class RotaryPositionalEncoding(nn.Module):
        def __init__(self, dim, max_len=5000):
            super().__init__()
            self.dim = dim
            inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
            t = torch.arange(max_len).type_as(inv_freq)
            freqs = torch.einsum('n , d -> n d', t, inv_freq)
            self.register_buffer('sin', freqs.sin())
            self.register_buffer('cos', freqs.cos())

        def forward(self, x):
            n, _, device = x.shape[1], self.dim // 2, x.device
            sin, cos = self.sin[:n].to(device), self.cos[:n].to(device)

            # Apply RoPE to even and odd indices separately
            x_even = x[..., :self.dim:2] * cos.unsqueeze(0) + torch.roll(x[..., 1:self.dim:2], shifts=1, dims=-1) * sin.unsqueeze(0)
            x_odd = x[..., 1:self.dim:2] * cos.unsqueeze(0) - torch.roll(x[..., :self.dim:2], shifts=1, dims=-1) * sin.unsqueeze(0)
            return torch.cat((x_even, x_odd), dim=-1)

    # SWIGLU
    class SwiGLU(nn.Module):
        def __init__(self, dim_in, dim_out):
            super(Expert.SwiGLU, self).__init__()
            self.fc1 = nn.Linear(dim_in, dim_out)
            self.fc2 = nn.Linear(dim_in, dim_out)

        def forward(self, x):
            gate = torch.sigmoid(self.fc2(x))
            return self.fc1(x) * gate

    class SimplifiedMAMBA(nn.Module):
        def __init__(self, config):
            super().__init__()
            # Using configuration parameters
            self.num_layers = config.num_layers
            self.d_model = config.d_model
            self.d_state = config.d_state
            self.d_conv = config.d_conv
            self.expansion_factor = config.expansion_factor
            
            self.feedforward = nn.Sequential(
                nn.Linear(self.d_model, self.d_state),
                nn.GELU(),
                nn.Linear(self.d_state, self.d_model)
            )
            self.input_embedding = nn.Linear(self.d_model, self.d_model)
            self.convs = nn.Sequential(*[nn.Conv1d(self.d_model, self.d_model, kernel_size=self.d_conv, padding=(self.d_conv // 2)) for _ in range(self.num_layers)])
            self.swiglu = Expert.SwiGLU(self.d_model, self.d_model)
            self.output_projection = nn.Linear(self.d_model, self.d_model * self.expansion_factor)

            self.initialize_weights()

        def initialize_weights(self):
            gain = nn.init.calculate_gain('relu')

            nn.init.orthogonal_(self.input_embedding.weight, gain)
            nn.init.normal_(self.input_embedding.bias, mean=0, std=0.01)

            nn.init.kaiming_uniform_(self.convs[-1].weight, a=math.sqrt(5))
            nn.init.zeros_(self.convs[-1].bias)

            nn.init.xavier_uniform_(self.feedforward[0].weight, gain=nn.init.calculate_gain('relu'))
            nn.init.zeros_(self.feedforward[0].bias)

            nn.init.xavier_uniform_(self.feedforward[2].weight, gain=nn.init.calculate_gain('linear'))
            nn.init.zeros_(self.feedforward[2].bias)

            nn.init.xavier_uniform_(self.output_projection.weight, gain=nn.init.calculate_gain('linear'))
            nn.init.zeros_(self.output_projection.bias)

        def forward(self, inputs, attention_mask=None):
            print("Input shape:", inputs.shape)

            # Apply the attention mask if provided
            if attention_mask is not None:
                inputs = inputs * attention_mask.unsqueeze(-1)

            projected_inputs = self.input_embedding(inputs)
            print("projected_inputs pre-reshape shape:", projected_inputs.shape)

            projected_inputs = projected_inputs.permute(0, 2, 1)
            print("projected_inputs post-reshape shape:", projected_inputs.shape)

            for conv in self.convs:
                projected_inputs = conv(projected_inputs)

            projected_inputs = projected_inputs.permute(0, 2, 1)
            print("projected_inputs post convolution reshape:", projected_inputs.shape)

            projected_inputs = self.swiglu(projected_inputs)
            print("projected_inputs post swiglu shape:", projected_inputs.shape)

            output = self.output_projection(projected_inputs)
            print("output shape:", output.shape)

            return output

    class SimplifiedLanguageModelMAMBA(nn.Module):
        def __init__(self, config):
            super().__init__()
            # Using configuration parameters
            self.vocab_size = config.vocab_size
            self.num_layers = config.num_layers
            self.d_model = config.d_model
            self.d_state = config.d_state
            self.d_conv = config.d_conv
            self.expansion_factor = config.expansion_factor

            self.embedding = nn.Embedding(self.vocab_size, self.d_model)
            self.pos_encoder = Expert.RotaryPositionalEncoding(self.d_model)
            self.simplified_mamba = Expert.SimplifiedMAMBA(config)
            self.output_projection = nn.Linear(self.d_model * 2, self.vocab_size)  # Adjust if needed

            self.initialize_weights()

        def initialize_weights(self):
            gain = 1.0

            nn.init.orthogonal_(self.embedding.weight, gain)
            nn.init.xavier_uniform_(self.output_projection.weight, gain=gain)

        def forward(self, input_values, attention_mask):
            print(f"SimplifiedLanguageModelMAMBA fwd input_values: {input_values.shape}")
            print(f"SimplifiedLanguageModelMAMBA fwd input_values: {input_values.dtype}")
            print(f"Max input_id before embedding: {input_values.max().item()}")
            if input_values.max() >= vocab_size:
                print("Detected out-of-range value after adjustment.")

            if input_values.dtype != torch.long:
              input_values = input_values.long() 
            embedded = self.embedding(input_values) * math.sqrt(self.d_model)
            embedded = self.pos_encoder(embedded)
            simplified_mamba_output = self.simplified_mamba(embedded, attention_mask)
            logits = self.output_projection(simplified_mamba_output)
            return logits

    ###############################
    # Switch Router

    class SwitchRouter(nn.Module):
        CAPACITY_FACTOR = 1  # Class constant

        class SwitchGate(nn.Module):
            def __init__(self, input_dim, num_experts):
                super(Expert.SwitchRouter.SwitchGate, self).__init__()
                self.fc1 = nn.Linear(input_dim, input_dim // 2)
                self.fc2 = nn.Linear(input_dim // 2, num_experts)

            def forward(self, x):
                x = F.relu(self.fc1(x))
                gate_scores = F.softmax(self.fc2(x), dim=-1)
                return gate_scores

        def __init__(self, config):
            super(Expert.SwitchRouter, self).__init__()
            self.config = config
            self.device = config.device
            self.router = self.SwitchGate(config.input_dim, config.num_experts).to(config.device)
            self.transformer_rag = Expert.TransformerRAG(config).to(config.device)
            self.lmt = Expert.LanguageModelTransformer(config).to(config.device)
            self.transformer_dpo = Expert.DPO(self.lmt, config.device,config.embed_size).to(config.device)
            self.mamba = Expert.SimplifiedLanguageModelMAMBA(config).to(config.device)
            self.experts = nn.ModuleList([self.transformer_rag, self.transformer_dpo, self.mamba])
            self.input_embedding = nn.Linear(config.input_dim, config.input_dim)

        def forward(self, x, attention_mask, context_texts, question_text):
            x = x.to(self.device)
            x = self.input_embedding(x)
            gate_scores = self.router(x)
            expert_indices = torch.argmax(gate_scores, dim=1)
            final_output = torch.zeros_like(x)

            for i, expert in enumerate(self.experts):
                mask = expert_indices == i
                if mask.any():
                    selected_inputs = x[mask]
                    selected_attention_mask = attention_mask[mask].to(self.device)
                    if isinstance(expert, Expert.TransformerRAG):
                        expert_output = expert.forward(context_texts, selected_inputs, selected_attention_mask, question_text)
                    elif isinstance(expert, Expert.DPO):  # Check if the expert is an instance of DPO
                        # Use the forward_expert method for DPO within the routing process
                        expert_output = expert.forward_expert(selected_inputs, selected_attention_mask, context_texts, question_text)
                    else:
                        # For other experts, continue using the standard forward method
                        expert_output = expert(selected_inputs, selected_attention_mask)
                    final_output[mask] = expert_output

            aux_loss = self.auxiliary_loss(gate_scores)
            return final_output, aux_loss

        def auxiliary_loss(self, gate_scores):
            expert_load = gate_scores.sum(0) / gate_scores.size(0)
            loss_balancing = torch.std(expert_load)
            return loss_balancing


        @staticmethod
        def route_inputs(expert_indices, gate_scores, num_experts):
            capacity_factor_tensor = torch.tensor([Expert.SwitchRouter.CAPACITY_FACTOR], dtype=torch.float32)
            capacities = (gate_scores.size(0) * capacity_factor_tensor / num_experts).int()
            expert_counts = torch.zeros(num_experts, dtype=torch.int32)
            for idx in range(len(expert_indices)):
                selected_expert = expert_indices[idx]
                if expert_counts[selected_expert] < capacities[0]:
                    expert_counts[selected_expert] += 1
                else:
                    available_experts = (expert_counts < capacities[0]).nonzero(as_tuple=False).view(-1)
                    if len(available_experts) > 0:
                        alternative_expert = available_experts[0]
                        expert_indices[idx] = alternative_expert
                        expert_counts[alternative_expert] += 1
                    else:
                        print("No available experts to reroute. Handling overflow.")
            return expert_indices
    
    ###############################
    # RAG
    
    class PositionalEncoding(nn.Module):
        def __init__(self, d_model, max_len=10000):
            super(Expert.PositionalEncoding, self).__init__()
            self.d_model = d_model
            self.max_len = max_len

            # Create positional encodings
            pe = torch.zeros(max_len, d_model)
            position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
            pe[:, 0::2] = torch.sin(position * div_term)
            pe[:, 1::2] = torch.cos(position * div_term)

            # Add a batch dimension (B x T x C)
            pe = pe.unsqueeze(0)
            self.register_buffer('pe', pe)

        def forward(self, x):
            # x: Tensor of shape [Batch Size, Sequence Length, Embedding Dimension]
            # Adjust positional encoding to match the input size and device
            pe = self.pe[:, :x.size(1)]
            # Assuming x is on the correct device, pe will be automatically aligned to the same device
            return pe


    class AdaptiveDropoutLayer(nn.Module):
        def __init__(self, init_dropout_rate=0.1):
            super(Expert.AdaptiveDropoutLayer, self).__init__()
            # Use logit transformation for stability
            self.log_alpha = nn.Parameter(torch.tensor(math.log(init_dropout_rate / (1 - init_dropout_rate))).float())

        def forward(self, x):
            p = torch.sigmoid(self.log_alpha)
            # Convert p from a tensor to a float
            p_value = p.item()  # This extracts the scalar value as a Python float
            return nn.functional.dropout(x, p=p_value, training=self.training)


    class MultiHeadLinformerAttention(nn.Module):
        def __init__(self, embed_dim, num_heads, k=None):
            super().__init__()  # Ensure this is called first
            self.embed_dim = embed_dim
            self.num_heads = num_heads
            self.k = k if k is not None else embed_dim // num_heads  # Projection dimension per head

            self.key_projections = nn.Linear(embed_dim, self.k * num_heads)
            self.value_projections = nn.Linear(embed_dim, self.k * num_heads)
            self.out_projection = nn.Linear(self.k * num_heads, embed_dim)

        def forward(self, query, attention_mask=None):
            batch_size, seq_len, _ = query.size()
            
            # Project keys and values
            keys = self.key_projections(query)
            values = self.value_projections(query)
            
            # Reshape into [batch_size, num_heads, seq_len, k]
            keys = keys.reshape(batch_size, seq_len, self.num_heads, self.k).transpose(1, 2)
            values = values.reshape(batch_size, seq_len, self.num_heads, self.k).transpose(1, 2)
            
            # Calculate attention (scaled dot-product attention)
            # Scaling by the square root of the depth of the key vectors to prevent large values in the dot product
            # which could push the softmax function into regions where it has extremely small gradients
            keys = keys / (self.k ** 0.5)
            attention_scores = torch.softmax(torch.matmul(keys, values.transpose(-2, -1)), dim=-1)
            if attention_mask is not None:
                attention_scores = attention_scores.masked_fill(~attention_mask.bool(), float('-inf'))
                # Recalculate softmax for masked scores
                attention_scores = torch.softmax(attention_scores, dim=-1)
            # Apply attention to values
            out = torch.matmul(attention_scores, values)
            
            # Concatenate heads and project back to original embedding dimension
            out = out.transpose(1, 2).reshape(batch_size, seq_len, self.num_heads * self.k)
            out = self.out_projection(out)
            
            return out


    class AdaptiveEmbeddingLayer(nn.Module):
        def __init__(self, vocab,  vocab_size, freq_threshold, large_embed_dim, small_embed_dim, max_seq_len):
            super(Expert.AdaptiveEmbeddingLayer, self).__init__()
            self.vocab = vocab
            self.vocab_size = vocab_size
            self.freq_threshold = freq_threshold
            self.large_embed_dim = large_embed_dim
            self.small_embed_dim = small_embed_dim
            self.max_seq_len = max_seq_len

            self.split_vocab(vocab, freq_threshold)  

            self.frequent_embeddings = nn.Embedding(num_embeddings=len(self.frequent_vocab), embedding_dim=large_embed_dim)
            self.infrequent_embeddings = nn.Embedding(num_embeddings=len(self.infrequent_vocab), embedding_dim=small_embed_dim)
            self.infrequent_projection = nn.Linear(small_embed_dim, large_embed_dim)
            self.positional_embeddings = Expert.PositionalEncoding(large_embed_dim, max_seq_len)


        def split_vocab(self, vocab, freq_threshold):
            token_counts = [(token, count) for token, count in vocab.items()]
            token_counts.sort(key=lambda x: -x[1])  # Sort by frequency
            split_point = next(i for i, (_, count) in enumerate(token_counts) if count < freq_threshold)
            
            self.frequent_vocab = dict(token_counts[:split_point])
            self.infrequent_vocab = dict(token_counts[split_point:])

        def forward(self, token_ids):
            device = token_ids.device
            seq_len = token_ids.size(1)
            batch_size = token_ids.size(0)  # Obtain batch size from token_ids tensor

            # Initialize embeddings tensor
            embeddings = torch.zeros(token_ids.shape[0], seq_len, self.large_embed_dim, device=device)

            # Map token_ids to indices for frequent and infrequent vocab
            frequent_indices = torch.zeros_like(token_ids)
            infrequent_indices = torch.zeros_like(token_ids)
            
            for token_id, index in self.vocab.items():
                mask = token_ids == token_id
                if token_id in self.frequent_vocab:
                    # Map to index in frequent_vocab
                    frequent_indices[mask] = self.frequent_vocab[token_id]
                elif token_id in self.infrequent_vocab:
                    # Map to index in infrequent_vocab
                    infrequent_indices[mask] = self.infrequent_vocab[token_id]

            # Create masks for frequent and infrequent tokens
            frequent_mask = frequent_indices > 0
            infrequent_mask = infrequent_indices > 0

            # Embed frequent tokens
            if frequent_mask.any():
                frequent_embeddings = self.frequent_embeddings(frequent_indices[frequent_mask])
                embeddings[frequent_mask] = frequent_embeddings

            # Embed and project infrequent tokens
            if infrequent_mask.any():
                infrequent_embeddings = self.infrequent_embeddings(infrequent_indices[infrequent_mask])
                infrequent_embeddings_projected = self.infrequent_projection(infrequent_embeddings)
                embeddings[infrequent_mask] = infrequent_embeddings_projected

            # Apply positional embeddings
            position_ids = torch.arange(0, seq_len, dtype=torch.long, device=device).unsqueeze(0)
            position_embeddings = self.positional_embeddings(position_ids)  # Generate for seq_len

            # Ensure positional embeddings are broadcastable to the embeddings tensor
            # This step may not be necessary if your positional embeddings are already correctly shaped
            if position_embeddings.size(0) != batch_size:
                position_embeddings = position_embeddings.expand(batch_size, -1, -1)

            print(f"Embeddings shape: {embeddings.shape}")
            print(f"Positional embeddings shape: {position_embeddings.shape}")
            embeddings += position_embeddings

            return embeddings



    class DPRContextEncoder(nn.Module):
        def __init__(self, config):  # Accept an ExpertConfig instance
            super().__init__()
            self.wordpiece_tokenizer = config.wordpiece_tokenizer
            # Use attributes from config directly
            self.embedding_layer = Expert.AdaptiveEmbeddingLayer(
                config.wordpiece_vocab,
                config.vocab_size,
                config.freq_threshold,
                config.embedding_dim,
                config.embedding_dim // 4,  # Assuming you want to reduce the dimension for the infrequent tokens
                max_seq_len=config.max_length  # Use max_length from config
            )
            self.attention_layer = Expert.MultiHeadLinformerAttention(
                config.embedding_dim, 
                num_heads=config.heads
            )
            self.dropout = Expert.AdaptiveDropoutLayer(init_dropout_rate=config.dropout)  # Assuming dropout rate is defined in config

        def forward(self, input_ids, attention_mask):
            embeddings = self.embedding_layer(input_ids)
            attention_output = self.attention_layer(embeddings, attention_mask=attention_mask)
            attention_output = self.dropout(attention_output)

            # Mean pooling across the sequence length dimension
            pooled_output = attention_output.mean(dim=1)

            return pooled_output


    class DPRQuestionEncoder(nn.Module):
        def __init__(self, config):  # Accept an ExpertConfig instance
            super().__init__()
            self.wordpiece_tokenizer = config.wordpiece_tokenizer
            # Use attributes from config directly
            self.embedding_layer = Expert.AdaptiveEmbeddingLayer(
                config.wordpiece_vocab,
                config.vocab_size,
                config.freq_threshold,
                config.embedding_dim,
                config.embedding_dim // 4,  # Assuming you want to reduce the dimension for the infrequent tokens
                max_seq_len=config.max_length  # Use max_length from config
            )
            self.attention_layer = Expert.MultiHeadLinformerAttention(
                config.embedding_dim, 
                num_heads=config.heads
            )
            self.dropout = Expert.AdaptiveDropoutLayer(init_dropout_rate=config.dropout)

        def forward(self, input_ids, attention_mask):
            embeddings = self.embedding_layer(input_ids)
            attention_output = self.attention_layer(embeddings, attention_mask=attention_mask)
            attention_output = self.dropout(attention_output)

            # Mean pooling across the sequence length dimension
            pooled_output = attention_output.mean(dim=1)

            return pooled_output


    class TransformerRAG(nn.Module):
        def __init__(self, config):
            super(Expert.TransformerRAG, self).__init__()
            self.config = config
            self.embed_size = config.embed_size  
            self.context_encoder = Expert.DPRContextEncoder(config).to(config.device)          
            self.language_model = Expert.LanguageModelTransformer(config).to(config.device)
            self.question_encoder = Expert.DPRQuestionEncoder(config).to(config.device)
            self.tokenizer = config.wordpiece_tokenizer


        def forward(self, context_texts, question_input_ids, question_attention_mask, question_text):
            if question_input_ids.max() >= config.vocab_size:
                raise ValueError("question_input_ids contain ID(s) beyond the tokenizer's vocabulary size")
            
            aggregated_context_embeddings = []
            for context_list in context_texts:
                if not all(isinstance(context, dict) for context in context_list):
                    raise TypeError("Each item in context_texts must be a list of tokenized context dictionaries")
                
                aggregated_context_embedding = torch.zeros(config.embedding_dim, device=config.device)
                for context in context_list:
                    context_input_ids = torch.tensor(context['input_ids']).unsqueeze(0).to(config.device)  # Add unsqueeze(0) for batch dimension
                    context_attention_mask = torch.tensor(context['attention_mask']).unsqueeze(0).to(config.device)  # Add unsqueeze(0) for batch dimension
                    print(f"context_input_ids shape: {context_input_ids.shape}")
                    print(f"context_attention_mask shape: {context_attention_mask.shape}")

                    context_embedding = self.context_encoder(context_input_ids, context_attention_mask)
                    aggregated_context_embedding += context_embedding.mean(dim=0)
                
                aggregated_context_embeddings.append(aggregated_context_embedding / len(context_list))
            
            question_input_ids = question_input_ids.to(config.device).long()
            question_attention_mask = question_attention_mask.to(config.device).long()
            question_embeddings = self.question_encoder(input_ids=question_input_ids, attention_mask=question_attention_mask)
            
            cos_sim = torch.nn.CosineSimilarity(dim=1)
            #similarities = [cos_sim(question_embeddings, context_emb.squeeze(0)) for context_emb in aggregated_context_embeddings]
            similarities = [cos_sim(question_embeddings.unsqueeze(0), context_emb.unsqueeze(0)) for context_emb in aggregated_context_embeddings]

            most_relevant_context_idx = torch.argmax(torch.tensor(similarities, device=config.device))
            
            combined_input = question_text + " " + context_texts[most_relevant_context_idx]
            tokenized_combined_input = self.tokenizer(combined_input, return_tensors="pt", padding=True, truncation=True, max_length=512)
            tokenized_combined_input = {k: v.to(config.device) for k, v in tokenized_combined_input.items()}
            response_logits = self.language_model(**tokenized_combined_input)
            probabilities = F.softmax(response_logits.logits, dim=-1)
            predicted_token_ids = torch.argmax(probabilities, dim=-1)
            predicted_tokens = self.tokenizer.convert_ids_to_tokens(predicted_token_ids[0])
            #response = self.tokenizer.convert_tokens_to_string(predicted_tokens)
            response = " ".join(predicted_tokens).replace(" </w>", "").replace("</w>", " ").strip()
            
            return response

        @staticmethod
        def extract_text_from_pdf(file_path):
            text = ""
            with fitz.open(file_path) as doc:
                for page in doc:
                    text += page.get_text()
            return text

        @staticmethod
        def split_into_chunks(text, chunk_size):
            return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]

        @staticmethod
        def preprocess_text(text, tokenizer, max_length=512):
            chunk_size = max_length - 50
            text_chunks = Expert.TransformerRAG.split_into_chunks(text, chunk_size)
            processed_chunks = []
            for chunk in text_chunks:
                # Tokenize the chunk using the WordPiece tokenizer
                token_ids = tokenizer.tokenize(chunk)
                
                # Manual padding and attention mask creation
                attention_mask = [1] * len(token_ids)
                # Padding: Extend token_ids and attention_mask to max_length
                while len(token_ids) < max_length:
                    token_ids.append(tokenizer.unk_token_id)  # Use unk_token_id for padding
                    attention_mask.append(0)  # Padding token has attention mask 0
                
                # Ensure token_ids and attention_mask are not longer than max_length
                token_ids = token_ids[:max_length]
                attention_mask = attention_mask[:max_length]
                
                processed_chunk = {
                    'input_ids': token_ids,
                    'attention_mask': attention_mask
                }
                processed_chunks.append(processed_chunk)
            return processed_chunks

        @staticmethod
        def create_dataset_from_pdfs(pdf_file_paths, tokenizer):
            dataset = []
            for file_path in pdf_file_paths:
                text = Expert.TransformerRAG.extract_text_from_pdf(file_path)
                processed_text = Expert.TransformerRAG.preprocess_text(text, tokenizer)                
                dataset.append(processed_text)
            return dataset

        def retrieve_contexts(self, dataset, query_embedding, top_k=5):
            similarity_scores = []
            for context in dataset:
                context_input_ids = context['input_ids'].to(self.config.device)
                context_attention_mask = context['attention_mask'].to(self.config.device)
                # Use the class's context_encoder
                context_embedding = self.context_encoder(context_input_ids, context_attention_mask)
                similarity = torch.matmul(query_embedding, context_embedding.T)
                similarity_scores.append(similarity.squeeze().item())
            top_k_indices = sorted(range(len(similarity_scores)), key=lambda i: similarity_scores[i], reverse=True)[:top_k]
            top_contexts = [dataset[i] for i in top_k_indices]
            return top_contexts

        def rag_retrieve_and_generate(self, dataset, query):
            # Tokenize the query using the class's tokenizer
            tokenized_query = self.tokenizer(query, return_tensors="pt", padding=True, truncation=True, max_length=512).to(self.config.device)
            input_ids = tokenized_query['input_ids']
            attention_mask = tokenized_query['attention_mask']
            # Use the class's question_encoder
            encoded_query = self.question_encoder(input_ids, attention_mask)
            relevant_contexts = self.retrieve_contexts(dataset, encoded_query)
            # Assuming generate_response is a method of LanguageModelTransformer that accepts tokenized contexts and generates a response
            response = self.language_model.generate_response(relevant_contexts)
            return response

    ###############################
    # LORA
    class LORALayer(nn.Module):
        def __init__(self, input_dim, output_dim, rank, alpha=1):
            super(Expert.LORALayer, self).__init__()
            self.rank = rank
            self.alpha = alpha

            # Original weight and bias of the linear layer
            self.weight = nn.Parameter(torch.Tensor(output_dim, input_dim))
            #print("self.weight Shape:", self.weight.shape)
            self.bias = nn.Parameter(torch.Tensor(output_dim))

            # LORA specific parameters
            self.A = nn.Parameter(torch.Tensor(input_dim, rank))
            self.B = nn.Parameter(torch.Tensor(rank, output_dim))

            self.reset_parameters()

        def reset_parameters(self):
            nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
            nn.init.zeros_(self.bias)
            nn.init.normal_(self.A, 0, 0.02)
            nn.init.normal_(self.B, 0, 0.02)

        def forward(self, x):
            #print("LORALayer Input Shape:", x.shape)
            
            original_size = x.size()
            batch_size, seq_len, _ = x.shape
            x_flattened = x.reshape(-1, original_size[-1])

            # Compute lora_adjustment for each input in the batch
            lora_adjustment = self.alpha * (x_flattened @ self.A) @ self.B
            lora_adjustment = lora_adjustment.reshape(batch_size, seq_len, -1)

            #print("Adjusted lora_adjustment Shape:", lora_adjustment.shape)
            #print("self.weight Shape:", self.weight.shape)
            
            # Apply linear transformation to x_flattened
            x_transformed = nn.functional.linear(x_flattened, self.weight, self.bias)
            x_transformed = x_transformed.reshape(batch_size, seq_len, -1)

            # Add lora_adjustment to the transformed x
            x = x_transformed + lora_adjustment
            #print("LORALayer Output Shape:", x.shape)

            return x

    # QLORA
    class QLORALayer(nn.Module):
        def __init__(self, input_dim, output_dim, rank, alpha=1, quantization_bits=8):
            super(Expert.QLORALayer, self).__init__()
            self.rank = rank
            self.alpha = alpha
            self.quantization_bits = quantization_bits

            # Original weight and bias
            self.weight = nn.Parameter(torch.Tensor(output_dim, input_dim))
            self.bias = nn.Parameter(torch.Tensor(output_dim))

            # QLORA specific parameters
            self.A = nn.Parameter(torch.Tensor(input_dim, rank))
            self.B = nn.Parameter(torch.Tensor(rank, output_dim))

            self.dropout = nn.Dropout(0.1)
            self.layer_norm = nn.LayerNorm(output_dim)

            self.reset_parameters()

        def reset_parameters(self):
            nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
            nn.init.zeros_(self.bias)
            nn.init.normal_(self.A, 0, 0.02)
            nn.init.normal_(self.B, 0, 0.02)

        def quantize(self, x, num_bits):
            # Implement a simple quantization method
            scale = x.abs().max()
            x_quantized = torch.round(x / scale * (2**num_bits - 1))
            return x_quantized, scale

        def forward(self, x):
            #print("QLORALayer Input Shape:", x.shape)
            original_size = x.size()
            batch_size, seq_len, _ = x.shape
            x_flattened = x.reshape(-1, original_size[-1])

            A_quantized, scale_A = self.quantize(self.A, self.quantization_bits)
            B_quantized, scale_B = self.quantize(self.B, self.quantization_bits)

            # Compute lora_adjustment for each input in the batch
            lora_adjustment = self.alpha * (x_flattened @ (A_quantized / scale_A)) @ (B_quantized / scale_B)
            lora_adjustment = lora_adjustment.reshape(batch_size, seq_len, -1)
            lora_adjustment = self.dropout(lora_adjustment)
            #print("Adjusted lora_adjustment Shape:", lora_adjustment.shape)
            #print("self.weight Shape:", self.weight.shape)

            # Apply linear transformation to x_flattened
            x_transformed = nn.functional.linear(x_flattened, self.weight, self.bias)
            x_transformed = x_transformed.reshape(batch_size, seq_len, -1)

            # Add lora_adjustment to the transformed x
            x = x_transformed + lora_adjustment
            x = self.layer_norm(x)

            #print("QLORALayer Output Shape:", x.shape)

            return x
        
        def update_alpha(self, new_alpha):
            """
            Update the alpha scaling factor.
            """
            self.alpha = new_alpha

    ###############################
    # Language Model Transformer

    class MultiHeadAttention(nn.Module):
        def __init__(self, embed_size, heads):
            super(Expert.MultiHeadAttention, self).__init__()
            self.embed_size = embed_size
            self.heads = heads
            self.head_dim = embed_size // heads

            assert (
                self.head_dim * heads == embed_size
            ), "Embedding size needs to be divisible by heads"

            self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
            self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
            self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
            self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

        def forward(self, values, keys, query, mask):
            N = query.shape[0]
            value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

            # Split the embedding into self.heads different pieces
            values = values.reshape(N, value_len, self.heads, self.head_dim)
            keys = keys.reshape(N, key_len, self.heads, self.head_dim)
            queries = query.reshape(N, query_len, self.heads, self.head_dim)

            values = self.values(values)
            keys = self.keys(keys)
            queries = self.queries(queries)

            # Einsum does the matrix multiplication for query*keys for each training example
            # with every other training example, then sum it up
            attention = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

            if mask is not None:
                attention = attention.masked_fill(mask == 0, float("-1e20"))

            attention = torch.softmax(attention / (self.embed_size ** (1 / 2)), dim=3)

            out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
                N, query_len, self.heads * self.head_dim
            )

            out = self.fc_out(out)
            return out

    class TransformerBlock(nn.Module):
        def __init__(self, embed_size, heads, dropout, forward_expansion, rank):
            super(Expert.TransformerBlock, self).__init__()
            self.attention = Expert.MultiHeadAttention(embed_size, heads)
            self.norm1 = nn.LayerNorm(embed_size)
            self.norm2 = nn.LayerNorm(embed_size)

            self.feed_forward = nn.Sequential(
                Expert.LORALayer(embed_size, forward_expansion * embed_size, rank),
                nn.ReLU(),
                Expert.LORALayer(forward_expansion * embed_size, embed_size, rank),
            )

            self.dropout = nn.Dropout(dropout)

        def forward(self, value, key, query, mask):
            attention = self.attention(value, key, query, mask)
            x = self.dropout(self.norm1(attention + query))
            forward = self.feed_forward(x)
            out = self.dropout(self.norm2(forward + x))
            return out

    class LanguageModelDecoder(nn.Module):
        def __init__(self, vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, max_length, rank):
            super(Expert.LanguageModelDecoder, self).__init__()
            self.word_embedding = nn.Embedding(vocab_size, embed_size)
            self.position_embedding = nn.Embedding(max_length, embed_size)

            # Adding BatchNorm layers
            self.bn1 = nn.BatchNorm1d(embed_size)
            self.bn2 = nn.BatchNorm1d(embed_size)

            self.layers = nn.ModuleList(
                [
                    Expert.TransformerBlock(embed_size, heads, dropout, forward_expansion, rank)
                    for _ in range(num_layers)
                ]
            )

            # QLORA layers
            self.qlora_feed_forward = nn.Sequential(
                Expert.QLORALayer(embed_size, forward_expansion * embed_size, rank),
                nn.ReLU(),
                Expert.QLORALayer(forward_expansion * embed_size, embed_size, rank),
            )
            self.use_qlora = False  # Flag to toggle QLORA

            self.fc_out = nn.Linear(embed_size, vocab_size)
            self.dropout = nn.Dropout(dropout)

        def forward(self, x, trg_mask):
            N, seq_length = x.shape
            if seq_length > self.position_embedding.num_embeddings:
                raise ValueError(f"Sequence length {seq_length} exceeds maximum allowed {self.position_embedding.num_embeddings}")
            positions = torch.arange(0, seq_length).expand(N, seq_length).to(config.device)
            x = x.long() # make x a long type for the embeddings
            x = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

            print(f"Max position index: {positions.max().item()}, Position Embedding Size: {self.position_embedding.num_embeddings}")


            # Transpose for BatchNorm, apply batch normalization, and then transpose back
            x = x.transpose(1, 2)
            x = self.bn1(x)
            x = x.transpose(1, 2)

            for layer in self.layers:
                x = layer(x, x, x, trg_mask)
                if self.use_qlora:
                    x = self.qlora_feed_forward(x)

            # Transpose for BatchNorm, apply batch normalization, and then transpose back
            x = x.transpose(1, 2)
            x = self.bn2(x)
            x = x.transpose(1, 2)

            out = self.fc_out(x)
            #print(f"shape of output of forward method of LanguageModelDecoder: {out.shape} ")

            return out

        def toggle_qlora(self, use_qlora: bool):
            self.use_qlora = use_qlora

    class LanguageModelTransformer(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.vocab_size = config.vocab_size  # Use vocab_size from ExpertConfig
            self.embed_size = config.embed_size  # Use embed_size from ExpertConfig
            self.num_layers = config.num_layers  # Use num_layers from ExpertConfig
            self.forward_expansion = config.forward_expansion  # Use forward_expansion from ExpertConfig
            self.heads = config.heads  # Use heads from ExpertConfig
            self.dropout = config.dropout  # Use dropout from ExpertConfig
            self.max_length = config.max_length  # Use max_length from ExpertConfig
            self.rank = config.rank  # Use rank from ExpertConfig
            self.tokenizer_name = config.tokenizer_name  # Use tokenizer_name from ExpertConfig

            self.tokenizer = BertTokenizer.from_pretrained(self.tokenizer_name)

            self.decoder = Expert.LanguageModelDecoder(
                self.vocab_size,
                self.embed_size,
                self.num_layers,
                self.heads,
                self.forward_expansion,
                self.dropout,
                self.max_length,
                self.rank,
            )

        def forward(self, trg, attention_mask=None):  # Remove attention_mask here since it's not used
            print(f"Input shape to LanguageModelTransformer: {trg.shape}")
            trg_mask = self.make_trg_mask(trg)
            out = self.decoder(trg, trg_mask)  # Do not pass attention_mask here
            print(f"Language Model Transformer out shape: {out.shape}")
            return out

        def make_trg_mask(self, trg):
            N, trg_len = trg.shape
            trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(N, 1, trg_len, trg_len).to(trg.device)
            return trg_mask

        def toggle_qlora(self, use_qlora: bool):
            self.decoder.toggle_qlora(use_qlora)

        def generate_response(self, input_ids, attention_mask):
            logits = self.forward(input_ids=input_ids, attention_mask=attention_mask)
            probabilities = F.softmax(logits, dim=-1)
            predicted_token_id = torch.argmax(probabilities, dim=-1)
            predicted_tokens = [self.tokenizer.convert_ids_to_tokens(idx.item()) for idx in predicted_token_id]
            response = self.tokenizer.convert_tokens_to_string(predicted_tokens)
            return response

    ###############################
    # DPO
    class DPO(nn.Module):
        def __init__(self, language_model, device, embed_size):
            super(Expert.DPO, self).__init__()
            self.language_model = language_model
            self.device = device
            # Assuming embed_size is accessible and correct
            self.projection = nn.Linear(language_model.vocab_size, embed_size)  # Project from vocab_size to embed_size
            self.classifier = nn.Linear(embed_size, 2)  # Assuming embed_size is accessible
            

        def forward(self, input_ids, labels=None):
            print(f"Max input_id in DPO fwd before embedding: {input_ids.max().item()}")

            logits = self.language_model(input_ids)  # Output shape: [batch_size, seq_len, vocab_size]

            # Remove the flattening before pooling if you intend to apply mean pooling over seq_len
            # Ensure logits are correctly projected to embed_size per token
            projected_logits = self.projection(logits.view(-1, logits.size(-1)))
            projected_logits = projected_logits.view(logits.size(0), logits.size(1), -1)  # Reshape to [batch_size, seq_len, embed_size]

            # Apply global mean pooling across the sequence length dimension correctly
            pooled_logits = projected_logits.mean(dim=1)  # Correctly applies mean pooling across seq_len

            predictions = self.classifier(pooled_logits)

            loss = None

            print(f"logits shape: {logits.shape}")
            print(f"projected_logits shape: {projected_logits.shape}")
            print(f"pooled_logits shape: {pooled_logits.shape}")
            print(f"predictions shape: {predictions.shape}")
            if labels is not None:
                print(f"labels shape: {labels.shape}")
                # Ensure labels are flattened if they're not already 1D
                if labels.dim() > 1:
                    labels = labels.view(-1)  # Flatten labels to match predictions shape
                loss_fct = nn.CrossEntropyLoss()  # Correctly instantiate the loss function
                loss = loss_fct(predictions, labels)

            return predictions, loss


     
        def forward_expert(self, input_ids, attention_mask=None, context_texts=None, question_text=None):
            """
            Special forward method designed for use within the Expert Model, where the DPO component
            is one of several being routed and managed.
            """
            print(f"Max input_id in DPO forward_expert before embedding: {input_ids.max().item()}")

            # Get the logits from the language model
            logits = self.language_model(input_ids)  # Assuming output shape: [batch_size, seq_len, vocab_size]

            # Mean pooling over the sequence length dimension to get shape: [batch_size, vocab_size]
            pooled_logits = logits.mean(dim=1)

            # Use the projection layer to transform the logits from vocab_size to embed_size
            # This aligns with the expected input dimensions for the classifier
            transformed_logits = self.projection(pooled_logits)  # New shape: [batch_size, embed_size]

            # Now, the transformed logits can be correctly classified
            predictions = self.classifier(transformed_logits)

            return predictions
    


    def __init__(self, config: ExpertConfig):
        super().__init__()
        # Validate configuration
        config.validate()

        # 1. SparseFlash2_attention
        self.sparse_flash2_attention = Expert.SparseFlash2Attention(
            config.seq_len, 
            config.head_dim, 
            config.block_size, 
            config.sparsity_factor
        )

        # Inside the Expert class definition
        self.lmt = Expert.LanguageModelTransformer(config).to(config.device)

        self.transformer_rag = Expert.TransformerRAG(config).to(config.device)

        # Corrected instantiation of SimplifiedLanguageModelMAMBA
        self.mamba = Expert.SimplifiedLanguageModelMAMBA(config).to(config.device)

        # Now initialize DPO with the language model transformer
        self.transformer_dpo = Expert.DPO(self.lmt, config.device, config.embed_size).to(config.device)

        # 2. LayerNorm and Dropout
        self.layer_norm = nn.LayerNorm(config.seq_len)
        self.dropout = nn.Dropout(config.dropout)

        # 3. Internal Switch Routing
        self.switch_router = Expert.SwitchRouter(config)

        # 4. QLORA Layer
        self.qlora = Expert.QLORALayer(config.input_dim, config.input_dim, config.rank)


    def forward(self, x, attention_mask, context_texts, question_text):
        # 1. SparseFlash2_attention
        x = self.sparse_flash2_attention(x, x, x)  # Pass x as Q, K, and V
        print(f"Shape after SparseFlash2Attention: {x.shape}") 

        # 2. LayerNorm and Dropout
        x = self.dropout(self.layer_norm(x))
        print(f"Shape after LayerNorm: {x.shape}") 

        # 3. Internal Switch Routing
        x, aux_loss = self.switch_router(x, attention_mask, context_texts, question_text)

        # 4. QLORA Layer
        x = self.qlora(x)

        return x, aux_loss

    @staticmethod
    def calculate_new_alpha(current_loss, initial_loss, initial_alpha=1.0, final_alpha=0.1):
        """
        Calculate a new alpha value based on the current loss.
        """
        if current_loss >= initial_loss:
            return initial_alpha  # Keep initial alpha if loss isn't decreasing

        loss_ratio = current_loss / initial_loss
        alpha_range = initial_alpha - final_alpha
        new_alpha = final_alpha + (alpha_range * loss_ratio)
        return new_alpha  
    
    @staticmethod
    def setup_optimizer(model, learning_rate, weight_decay, warmup_steps, total_steps):
        optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

        # Linear warmup with cosine decay
        scheduler = LambdaLR(optimizer, lambda step: min((step + 1) / warmup_steps, 0.5 * (1 + math.cos(math.pi * step / total_steps))))

        return optimizer, scheduler
    ###############################
    # TRAINING METHODS

    # DPO Training
    def train_dpo(self, train_loader, optimizer, config, save_path):
        self.train()  # Set the model to training mode
        total_loss = 0

        for batch_idx, batch in enumerate(train_loader):
            input_ids = batch['input_ids'].to(config.device)  # Adjusted to use 'input_ids'
            labels = batch['labels'].to(config.device)

            optimizer.zero_grad()
            # Forward pass: Adjust the model's forward method to accept the single 'input_ids' input
            logits, loss = self.transformer_dpo(input_ids, labels=labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        average_loss = total_loss / len(train_loader)
        print(f"Training complete. Average Loss: {average_loss}")

        # Save the model
        torch.save(self.transformer_dpo.state_dict(), save_path)

        return average_loss

    # RAG Training
    def train_language_model_rag(self, model, train_loader, device, vocab_size,num_epochs=5):
        
        # Define loss function and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-8, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.98)
        
        initial_loss = None

        # Training loop
        for epoch in range(num_epochs):
            model.train()
            model.decoder.toggle_qlora(True)
            total_loss = 0

            for batch_idx, batch in enumerate(train_loader):
                inputs, targets = batch['input_ids'].to(device), batch['labels'].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                print("Output shape:", outputs.shape)
                print("Targets shape:", targets.shape)
                loss = criterion(outputs.contiguous().view(-1, 30522), targets.view(-1))


                # Check for NaN in loss
                if math.isnan(loss.item()):
                    print("Encountered NaN loss, stopping training")
                    break

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
                optimizer.step()

                total_loss += loss.item()

                # Set the initial_loss after the first batch of the first epoch
                if initial_loss is None and batch_idx == 0:
                    initial_loss = loss.item()

            scheduler.step()

            # Check for NaN in total_loss
            if math.isnan(total_loss):
                print(f"Epoch {epoch+1}/{num_epochs} stopped due to NaN loss")
                break
            else:
                print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(train_loader)}")

            # Average loss for the epoch
            average_loss = total_loss / len(train_loader)

            # Update alpha at the end of each epoch based on the average loss
            new_alpha = self.calculate_new_alpha(average_loss, initial_loss)
            for layer in model.modules():
                if isinstance(layer, Expert.QLORALayer):
                    layer.update_alpha(new_alpha)

        average_loss = total_loss / len(train_loader)
        torch.save(model.state_dict(), save_path)
        return model, average_loss


    # DPR Training
    def train_dpr_encoders(self, train_data, context_encoder, question_encoder, optimizer_context, optimizer_question, epochs, context_save_path, question_save_path):
        loss_function = nn.CosineEmbeddingLoss()

        for epoch in range(epochs):
            total_loss = 0

            for i in range(len(train_data["queries"])):
                query = train_data["queries"][i]
                context_list = train_data["contexts"][i]  # This is a list of context dicts

                # Tokenize query using the tokenize method and convert token IDs back to tensors for model input
                tokenized_query = config.wordpiece_tokenizer.tokenize(query)
                input_ids_query = torch.tensor([tokenized_query], dtype=torch.long).to(config.device)
                attention_mask_query = torch.ones_like(input_ids_query).to(config.device)

                # Since context is a list of dictionaries with 'input_ids', process each and average embeddings
                context_embeddings_list = []
                for context in context_list:
                    if 'input_ids' in context:
                        input_ids_context = torch.tensor([context['input_ids']], dtype=torch.long).to(config.device)
                        attention_mask_context = torch.ones_like(input_ids_context, dtype=torch.bool).to(config.device)

                        # Adjusted for using input_ids and attention_mask directly
                        context_embedding = context_encoder(input_ids_context, attention_mask_context)
                        context_embeddings_list.append(context_embedding)

                # Average the context embeddings if there are multiple contexts
                if context_embeddings_list:
                    context_embeddings = torch.mean(torch.stack(context_embeddings_list), dim=0)
                else:
                    raise ValueError("No valid contexts found for averaging embeddings.")

                # Forward pass for the query
                question_embeddings = question_encoder(input_ids_query, attention_mask_query)

                # Compute loss with labels for positive examples
                labels = torch.tensor([1.0] * question_embeddings.size(0), dtype=torch.float).to(config.device)
                loss = loss_function(question_embeddings, context_embeddings, labels)
                optimizer_context.zero_grad()
                optimizer_question.zero_grad()
                loss.backward()
                optimizer_context.step()
                optimizer_question.step()

                total_loss += loss.item()

            print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss / len(train_data['queries'])}")

        average_loss = total_loss / len(train_data['queries'])
        torch.save(context_encoder.state_dict(), context_save_path)
        torch.save(question_encoder.state_dict(), question_save_path)
        return (context_encoder, question_encoder), average_loss

       
    # LMT Training
    def train_language_model_transformer(self, train_loader, device, vocab_size, save_path):
        model = self.lmt
        
        # Define loss function and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-8, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.98)
        num_epochs = 5
        
        initial_loss = None

        # Training loop
        for epoch in range(num_epochs):
            model.train()
            model.decoder.toggle_qlora(True)
            total_loss = 0

            for batch_idx, batch in enumerate(train_loader):
                inputs, targets = batch['input_ids'].to(device), batch['labels'].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))

                # Check for NaN in loss
                if math.isnan(loss.item()):
                    print("Encountered NaN loss, stopping training")
                    break

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
                optimizer.step()

                total_loss += loss.item()

                # Set the initial_loss after the first batch of the first epoch
                if initial_loss is None and batch_idx == 0:
                    initial_loss = loss.item()

            scheduler.step()

            # Check for NaN in total_loss
            if math.isnan(total_loss):
                print(f"Epoch {epoch+1}/{num_epochs} stopped due to NaN loss")
                break
            else:
                print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(train_loader)}")

            # Average loss for the epoch
            average_loss = total_loss / len(train_loader)

            # Update alpha at the end of each epoch based on the average loss
            new_alpha = self.calculate_new_alpha(average_loss, initial_loss)
            for layer in model.modules():
                if isinstance(layer, Expert.QLORALayer):
                    layer.update_alpha(new_alpha)

        average_loss = total_loss / len(train_loader)
        torch.save(model.state_dict(), save_path)
        return model, average_loss

    # MAMBA Training
    def train_mamba(self, train_loader, num_epochs, config):
            # Initialize the optimizer and scheduler with MAMBA model parameters
            optimizer, scheduler = self.setup_optimizer(self.mamba, 
                                                        config.mamba_learning_rate, 
                                                        config.weight_decay, 
                                                        config.warmup_steps, 
                                                        config.total_mamba_steps)

            loss_fn = nn.CrossEntropyLoss()
            progress_bar = tqdm(range(num_epochs))

            for epoch in progress_bar:
                self.mamba.train()
                total_loss = 0

                for batch in train_loader:
                    input_values, attention_mask, labels = batch['input_ids'], batch['attention_mask'], batch['labels']
                    input_values = input_values.to(config.device)
                    attention_mask = attention_mask.to(config.device)
                    labels = labels.to(config.device)

                    optimizer.zero_grad()

                    # Forward pass through MAMBA model
                    outputs = self.mamba(input_values, attention_mask)
                    
                    # Calculate loss
                    loss = loss_fn(outputs.view(-1, config.vocab_size), labels.view(-1))
                    loss.backward()
                    
                    # Clip gradients and perform an optimization step
                    torch.nn.utils.clip_grad_norm_(self.mamba.parameters(), config.clip_gradient)
                    optimizer.step()
                    scheduler.step()

                    total_loss += loss.item()

                avg_loss = total_loss / len(train_loader)
                progress_bar.set_description(f"Epoch {epoch+1}/{num_epochs}, Avg Loss: {avg_loss:.4f}")

            # Save the trained MAMBA model
            torch.save(self.mamba.state_dict(), config.mamba_model_path)
            print(f"MAMBA Training Complete. Model saved to {config.mamba_model_path}")



    # Full Expert Training
    def train_expert(self, train_loader, train_data, optimizer, main_loss_function, aux_loss_weight, device,save_path, accumulation_steps=4, num_epochs=5):
            self.train()  # Set the model to training mode
            for epoch in range(num_epochs):
                total_loss = 0
                optimizer.zero_grad()  # Initialize gradients to zero
                for batch_idx, batch in enumerate(train_loader):
                    inputs, attention_mask, targets = batch['input_ids'].to(device), batch['attention_mask'].to(device), batch['labels'].to(device)

                    # Calculate start and end indices for current batch in train_data
                    start_idx = batch_idx * batch['input_ids'].size(0)
                    end_idx = start_idx + batch['input_ids'].size(0)

                    # Extract current_queries and current_contexts for the batch
                    current_queries = train_data['queries'][start_idx:end_idx]
                    current_contexts = train_data['contexts'][start_idx:end_idx]

                    # Call to the model forward function
                    outputs, aux_loss = self(inputs, attention_mask, current_contexts, current_queries)

                    # Calculate loss and accumulate
                    main_loss = main_loss_function(outputs.view(-1, outputs.size(-1)), targets.view(-1))
                    loss = (main_loss + aux_loss_weight * aux_loss) / accumulation_steps
                    loss.backward()

                    if (batch_idx + 1) % accumulation_steps == 0:
                        optimizer.step()
                        optimizer.zero_grad()

                    total_loss += loss.item() * accumulation_steps  # Scale back up

                average_loss = total_loss / len(train_loader)
                print(f'End of Epoch {epoch+1}, Average Loss: {average_loss}')
            average_loss = total_loss / len(train_loader)
            torch.save(self.state_dict(), save_path)
            return self, average_loss



##################################################
# Tokenizer

class TrieNode:
    def __init__(self):
        self.children = {}
        self.token_id = None
        self.frequency = 0
        self.failure_link = None
        self.is_end = False  # Add is_end attribute to mark the end of a word
        self.token = None  # Add token attribute to store the token associated with the node


class Trie:
    def __init__(self, unk_token_id=0):
        self.root = TrieNode()
        self.unk_token_id = unk_token_id

    def insert(self, token, token_id, frequency):
        node = self.root
        for char in token:
            if char not in node.children:
                node.children[char] = TrieNode()
            node = node.children[char]
        node.token_id = token_id
        node.frequency = frequency

    def find_subwords(self, token):
        """Finds the most probable subwords based on frequency."""
        node = self.root
        best_subwords = []

        def dfs(current_node, subword='', collected_subwords=[]):
            if current_node.token_id is not None:
                # Update to correctly calculate total_frequency based on the structure of collected_subwords
                total_frequency = sum(n.frequency for _, _, n in collected_subwords) + current_node.frequency
                probability = current_node.frequency / total_frequency if total_frequency else 0
                collected_subwords.append((subword, probability, current_node))

            for char, next_node in current_node.children.items():
                dfs(next_node, subword + char, list(collected_subwords))  # Create a copy of the list to avoid shared state

        dfs(node)
        best_subwords = sorted(best_subwords, key=lambda x: x[1], reverse=True)
        return [subword for subword, _, _ in best_subwords][:5] or [self.unk_token_id]


    def compute_failure_links(self):
        root = self.root
        root.failure_link = root  # Root's failure link points to itself
        queue = [root]

        while queue:
            current_node = queue.pop(0)

            for char, child_node in current_node.children.items():
                queue.append(child_node)

                # Follow failure link to find the longest suffix for the child_node
                failure_candidate = current_node.failure_link
                while failure_candidate != root and char not in failure_candidate.children:
                    failure_candidate = failure_candidate.failure_link
                child_node.failure_link = failure_candidate.children.get(char, root)


class SimpleSentencePiece:
    def __init__(self, model_type="bpe", vocab_size=30522):
        self.vocab = {}
        self.id_to_subword = {}
        self.unk_token = "[UNK]"
        self.unk_token_id = 0
        self.vocab_size = vocab_size
        self.model = None if model_type == "bpe" else None
        self.model_type = model_type

    def train(self, text):
        if self.model_type == "bpe":
            self.model = BPE(num_merges=self.vocab_size, unk_token_id=self.unk_token_id)
            self.model.train(text)
            self.vocab = self.model.vocab
            self.id_to_subword = {i: word for word, i in self.vocab.items()}
        else:
            raise NotImplementedError(f"Model type {self.model_type} not supported yet.")

    def encode(self, text):
        text = self.preprocess_text(text)  # Preprocess text before encoding
        if not self.model:
            raise ValueError("Model has not been trained yet.")
        encoded = self.model.encode(text)
        print(f"Encoded: {encoded[:10]}")  # Print first 10 encoded tokens
        return encoded

    def decode(self, ids):
        if not self.id_to_subword:
            raise ValueError("Vocabulary is empty. Ensure the model is trained first.")
        text = " ".join([self.id_to_subword.get(id_, self.unk_token) for id_ in ids])
        text = text.replace(" </w>", "").replace("</w>", " ").strip()
        return text

    def preprocess_text(self, text):
        # Convert text to lowercase to ensure case insensitivity
        text = text.lower()
        # Optionally, handle punctuation by adding spaces around it for better tokenization
        text = re.sub(r'([.,!?()])', r' \1 ', text)
        # Replace multiple spaces with a single space
        text = re.sub(r'\s+', ' ', text)
        # Trim leading and trailing spaces
        text = text.strip()
        return text
    
    def save_model(self, filepath):
        model_data = {
            'vocab': self.vocab,
            'id_to_subword': self.id_to_subword,
            'model_type': self.model_type,
            'vocab_size': self.vocab_size,
            # Potentially include other relevant attributes
        }
        # Save the high-level tokenizer settings
        with open(filepath, 'w') as f:
            json.dump(model_data, f)
        
        # Now save the BPE model specifically
        if self.model_type == "bpe" and self.model:
            self.model.save_model(filepath + "_bpe")

    def load_model(self, filepath):
        with open(filepath, 'r') as f:
            model_data = json.load(f)
        
        self.vocab = model_data['vocab']
        self.id_to_subword = model_data['id_to_subword']
        self.model_type = model_data['model_type']
        self.vocab_size = model_data['vocab_size']
        
        # Assuming model_type is still "bpe", we now load the BPE model
        if self.model_type == "bpe":
            self.model = BPE(self.vocab_size, self.unk_token_id)
            self.model.load_model(filepath + "_bpe")

class BPE:
    def __init__(self, num_merges=100, unk_token_id=0):  # Accept unk_token_id parameter
        self.vocab = {}
        self.merges = []
        self.num_merges = num_merges
        self.unk_token_id = unk_token_id  # Store the unknown token ID

    def train(self, text):
        words = re.findall(r'\w+|[^\w\s]', text, re.UNICODE)
        vocab = collections.Counter(words)
        vocab = {word + '</w>': count for word, count in vocab.items()}
        
        for _ in range(self.num_merges):  # Use the num_merges from the instance variable
            pairs = self.get_stats(vocab)
            if not pairs:
                break
            best = max(pairs, key=pairs.get)
            vocab = self.merge_vocab(best, vocab)
            self.merges.append(best)

        self.vocab = {word: i for i, word in enumerate(vocab.keys())}

    @staticmethod
    def get_stats(vocab):
        pairs = collections.defaultdict(int)
        for word, freq in vocab.items():
            symbols = word.split()
            for i in range(len(symbols)-1):
                pairs[symbols[i], symbols[i+1]] += freq
        return pairs

    @staticmethod
    def merge_vocab(pair, vocab):
        v_out = {}
        bigram = re.escape(' '.join(pair))
        p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
        for word in vocab:
            w_out = p.sub(''.join(pair), word)
            v_out[w_out] = vocab[word]
        return v_out

    def encode(self, text):
        """Encode text into subwords using learned BPE merges."""
        encoded_tokens = []
        for word in re.findall(r'\w+|[^\w\s]', text, re.UNICODE):
            word += '</w>'
            subwords = [word]  # Start with the entire word as one subword
            for merge in self.merges:
                new_subwords = []
                for subword in subwords:
                    # If the merge is in subword, split it; otherwise, keep it as is
                    if ' '.join(merge) in subword:
                        new_subwords.extend(subword.replace(' '.join(merge), ''.join(merge)).split(' '))
                    else:
                        new_subwords.append(subword)
                subwords = new_subwords
            encoded_tokens.extend(subwords)
        return [self.vocab.get(token, self.unk_token_id) for token in encoded_tokens]
    
        # New method to save trained model
    def save_model(self, filepath):
        bpe_data = {
            'merges': self.merges,
            'vocab': self.vocab,
            'num_merges': self.num_merges,
            # Include other attributes as needed
        }
        with open(filepath, 'w') as f:
            json.dump(bpe_data, f)

    def load_model(self, filepath):
        with open(filepath, 'r') as f:
            bpe_data = json.load(f)
        
        self.merges = bpe_data['merges']
        self.vocab = bpe_data['vocab']
        self.num_merges = bpe_data['num_merges']


class WordPiece:
    def __init__(self, vocab, unk_token_id=0, unk_token="[UNK]"):
        self.vocab = vocab
        self.unk_token_id = unk_token_id
        self.unk_token = unk_token  # Define the unknown token
        self.root = self.build_trie(vocab)
        self.id_to_token = {id_: token for token, id_ in vocab.items()}  # Inverse mapping
        self.compute_failure_links(self.root)
        print("Trie built successfully.")

    def convert_ids_to_tokens(self, ids):
        """
        Convert a list of token ids back to their string token representations.
        """
        return [self.id_to_token.get(id_, self.unk_token) for id_ in ids]

    # Add debug prints to build_trie to confirm structure
    def build_trie(self, vocab):
        root = TrieNode()
        for token in vocab:
            node = root
            for char in token:
                if char not in node.children:
                    node.children[char] = TrieNode()
                node = node.children[char]
            node.is_end = True
            node.token = token
        print("Trie Construction Completed Successfully")
        return root


    def compute_failure_links(self, root):
        queue = [root]
        while queue:
            current_node = queue.pop(0)
            for char, child_node in current_node.children.items():
                failure_node = current_node.failure_link
                while failure_node and char not in failure_node.children:
                    failure_node = failure_node.failure_link
                child_node.failure_link = failure_node.children[char] if failure_node else root
                queue.append(child_node)

    # Improved debug prints in tokenize method
                
    def tokenize(self, text):
        # Preprocess input text
        text = self.preprocess_text(text)
        node = self.root
        token_ids = []  # Will store token IDs instead of tokens
        i = 0

        while i < len(text):
            char = text[i]
            if char == ' ':
                node = self.root
                i += 1
                continue

            if char not in node.children:
                if node != self.root and node.token is not None:
                    # Convert found token to its ID
                    token_id = self.vocab.get(node.token, self.unk_token_id)
                    token_ids.append(token_id)
                    node = self.root  # Reset to root
                    continue
                else:
                    # Append unknown token ID
                    token_ids.append(self.unk_token_id)
                    i += 1
                    continue

            node = node.children[char]
            if node.is_end:
                if i + 1 == len(text) or text[i + 1] == ' ':
                    # Convert found token to its ID
                    token_id = self.vocab.get(node.token, self.unk_token_id)
                    token_ids.append(token_id)
                    node = self.root

            i += 1

        print(f"Token IDs: {token_ids[:10]}")
        return token_ids

    def preprocess_text(self, text):
        # Convert text to lowercase to ensure case insensitivity
        text = text.lower()

        # Optionally, handle punctuation by adding spaces around it for better tokenization
        # This depends on how your vocabulary handles punctuation
        text = re.sub(r'([.,!?()])', r' \1 ', text)

        # Replace multiple spaces with a single space
        text = re.sub(r'\s+', ' ', text)

        # Trim leading and trailing spaces
        text = text.strip()

        return text


def load_corpus(file_path):
    with open(file_path, 'r', encoding='utf-8') as file:
        texts = [line.strip() for line in file.readlines()]
    return texts

texts = load_corpus("D:\\EXPERT_WEIGHTS\\sample.txt")
#texts = load_corpus("/content/drive/MyDrive/EXPERT_STUFF/sample.txt")
# texts = load_corpus("C:/Users/robbi/Expert/sample.txt")
num_merges = 100

# Initialize and train the SimpleSentencePiece model with BPE
ssp = SimpleSentencePiece(model_type="bpe", vocab_size=30522)
# Assume `texts` is a list of text to train the tokenizer
ssp.train('\n'.join(texts))
wordpiece_vocab = ExpertConfig.adapt_vocab_for_wordpiece(ssp.vocab)
# Debugging step to ensure vocabulary completeness
def debug_vocab(adapted_vocab):
    print("Sample Vocabulary Check:")
    # Iterate over the first 10 key-value pairs in the adapted vocabulary
    for i, (token, id_or_freq) in enumerate(adapted_vocab.items()):
        print(f"{token}: {id_or_freq}")
        if i >= 9:  # Stop after printing 10 entries
            break
    # Specifically check for subtokens if your tokenizer expects them
    subtokens = [token for token in adapted_vocab.keys() if token.startswith("##")]
    print(f"Found {len(subtokens)} subtokens in vocabulary.")

# Ensure wordpiece_vocab is a list of vocabulary tokens
# debug_vocab(wordpiece_vocab)  # Call this after initializing wordpiece_vocab

# Initialize WordPiece with the adapted vocabulary
wordpiece_tokenizer = WordPiece(wordpiece_vocab, unk_token_id=0, unk_token="[UNK]")

def find_all_free_token_ids(vocab, max_id=30522):
    """
    Find all free token IDs up to a maximum ID.

    Parameters:
    - vocab: Dictionary, mapping from token to token ID.
    - max_id: Integer, the maximum token ID to consider.

    Returns:
    - List of free token IDs up to max_id.
    """
    used_ids = set(vocab.values())
    return [id_ for id_ in range(max_id) if id_ not in used_ids]

# Assuming wordpiece_vocab is your vocabulary
vocab = {token: id_ for token, id_ in wordpiece_vocab.items()}
free_ids = find_all_free_token_ids(vocab, max_id=30522)  # Adjust max_id as needed
#print("Free token IDs:", free_ids[:100])  # Print the first 100 free IDs for brevity



#################################################################################
# Instantiate Expert Model:
config = ExpertConfig(wordpiece_vocab=wordpiece_vocab, wordpiece_tokenizer=wordpiece_tokenizer)
expert_model = Expert(config)
##################################################################################
##################################################################################
# DPO TRAINING
'''# DPO Training
def preprocess_dpo_data(examples, tokenizer, max_length=512):
    # Calculate the length allocated to each component
    component_max_length = max_length // 3  # Dividing by 3 to equally distribute the max length

    # Tokenize text and ensure it fits within the allocated max_length for each component
    def tokenize_and_trim(text, tokenizer, max_length=component_max_length):
        token_ids = tokenizer.tokenize(text)
        # Trim to the max_length if necessary
        token_ids = token_ids[:max_length]
        return token_ids

    # Tokenize and adjust length for each field
    tokenized_questions = [tokenize_and_trim(question, tokenizer) for question in examples['question']]
    tokenized_chosen = [tokenize_and_trim(chosen, tokenizer) for chosen in examples['chosen']]
    tokenized_rejected = [tokenize_and_trim(rejected, tokenizer) for rejected in examples['rejected']]

    # Generate labels (adjust logic as necessary for your task)
    labels = [1 if i % 2 == 0 else 0 for i in range(len(tokenized_questions))]

    # Prepare final input IDs by concatenating the adjusted token IDs from each component
    # Note: This step may require adjustments based on your specific model input requirements.
    input_ids = [q + c + r for q, c, r in zip(tokenized_questions, tokenized_chosen, tokenized_rejected)]

    # Ensure concatenated input_ids do not exceed the total max_length
    input_ids = [ids[:max_length] for ids in input_ids]

    return {
        'input_ids': input_ids,  # Adjusted to return a single list of concatenated token IDs
        'labels': labels
    }


def updated_custom_collate_fn(batch):
    input_ids_list = [item['input_ids'] for item in batch]  # Assuming 'input_ids' is a list of token IDs
    labels = [item['labels'] for item in batch]

    # Convert list of token IDs to tensors
    input_ids_tensors = [torch.tensor(ids, dtype=torch.long) for ids in input_ids_list]
    labels_tensor = torch.tensor(labels, dtype=torch.long)

    # Pad the sequences so they all have the same length within this batch
    padded_input_ids = pad_sequence(input_ids_tensors, batch_first=True, padding_value=0)

    # Return a dictionary suitable for your model's input
    return {'input_ids': padded_input_ids, 'labels': labels_tensor}


# Assuming the rest of your code for dataset loading and tokenizer initialization remains unchanged
# Example usage:
dpo_dataset = load_dataset("Intel/orca_dpo_pairs")
max_seq_length = 512  # Adjust as needed
# Assuming wordpiece_tokenizer is an instance of your WordPiece class
dpo_dataset = dpo_dataset.map(lambda x: preprocess_dpo_data(x, wordpiece_tokenizer, max_seq_length), batched=True)

# Convert to PyTorch tensors after processing
dpo_dataset.set_format(type='torch', columns=['input_ids', 'labels'])

# Adjust the custom collate function to accept the tokenizer and max_length as arguments

train_loader = DataLoader(dpo_dataset['train'], batch_size=2, shuffle=True, collate_fn=updated_custom_collate_fn)

#train_loader = DataLoader(dpo_dataset['train'], batch_size=2, shuffle=True, collate_fn=custom_collate_fn)


# Instantiate the Expert model and optimizer
model_vocab_size = expert_model.lmt.decoder.word_embedding.num_embeddings
print(f"Model vocab size: {model_vocab_size}")
print(f"Model embedding size: {expert_model.lmt.decoder.word_embedding.embedding_dim}")
print(f"Configured max length: {config.max_length}")

optimizer = AdamW(expert_model.parameters(), lr=1e-5)
#save_path = 'D:/EXPERT_WEIGHTS/dpo_model.pth'
save_path = 'C:/Users/robbi/OneDrive/Expert_stuff/dpo_model.pth'

# Train the DPO model
label_column = 'labels'
input_columns = ['input_ids_question', 'attention_mask_question', 'input_ids_chosen', 'attention_mask_chosen', 'input_ids_rejected', 'attention_mask_rejected']
# Assuming expert_model is an instance of Expert
avg_loss = expert_model.train_dpo(train_loader, optimizer, config, save_path)
# Save the model
#torch.save(expert_model.transformer_dpo.state_dict(), save_path)
'''
##################################################################################
# LMT Training
'''
# Load the wikitext-2 dataset
#dataset = load_dataset("wikitext", "wikitext-2-v1")
dataset = load_dataset("wikitext", "wikitext-2-v1", split="train")

# Access a portion of the dataset for inspection
#print(dataset['train'][0])

def generate_attention_mask(token_ids):
    """Generate an attention mask for the given token IDs."""
    return [1 if token_id != 0 else 0 for token_id in token_ids]

def tokenize_function(examples):
    # Initialize lists for token IDs, attention masks, and labels
    token_ids, attention_masks, labels = [], [], []

    for text in examples["text"]:
        # Tokenize text and ensure at least one token ID is generated
        ids = wordpiece_tokenizer.tokenize(text) or [0]  # Replace [0] with your tokenizer's pad token ID
        # Generate attention mask for the tokenized text
        mask = [1] * len(ids)

        # Check for length of ids and pad/truncate as necessary
        if len(ids) < config.max_length:
            # Pad
            pad_length = config.max_length - len(ids)
            ids += [0] * pad_length  # Assuming 0 is the padding ID
            mask += [0] * pad_length
        else:
            # Truncate
            ids = ids[:config.max_length]
            mask = mask[:config.max_length]

        token_ids.append(ids)
        attention_masks.append(mask)
        labels.append(ids)  # For simplicity, using the same IDs as labels; adjust as needed for your model

    return {"input_ids": token_ids, "attention_mask": attention_masks, "labels": labels}



# Apply the custom tokenize function to the dataset
tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

train_loader = DataLoader(tokenized_datasets, batch_size=8, shuffle=True)



# Define save path for the trained model
save_path = 'D:/EXPERT_WEIGHTS/lmt_expert_trained_custom_tokenizer.pth'


# Train the LMT sub-model within the Expert system
trained_model, average_loss = expert_model.train_language_model_transformer(
    train_loader=train_loader, 
    device=config.device, 
    vocab_size=config.vocab_size, 
    save_path=save_path
)

print(f"Training complete. Model saved to {save_path}. Average Loss: {average_loss}")
'''
##################################################################################
# MAMBA Training
'''from datasets import load_dataset
from torch.utils.data import DataLoader
import torch

# Assuming the Expert class, ExpertConfig class, and your custom tokenizer have already been defined.

config = ExpertConfig()
expert_system = Expert(config)
dataset = load_dataset("wikitext", "wikitext-2-v1", split="train")

from torch.nn.utils.rnn import pad_sequence
import torch

def tokenize_function(examples):
    # Directly tokenize the text into input IDs using the custom tokenizer
    tokenized_outputs = [torch.tensor(wordpiece_tokenizer.tokenize(text), dtype=torch.long) for text in examples["text"]]
    
    # Pad sequences for uniform input size
    padded_input_ids = pad_sequence(tokenized_outputs, batch_first=True, padding_value=wordpiece_tokenizer.unk_token_id)
    
    # Generate attention masks
    attention_masks = (padded_input_ids != wordpiece_tokenizer.unk_token_id).float()
    
    # Shift input IDs to create labels, padding the last position
    labels = torch.cat([padded_input_ids[:, 1:], torch.full((padded_input_ids.shape[0], 1), wordpiece_tokenizer.unk_token_id, dtype=torch.long)], dim=1)
    
    return {"input_ids": padded_input_ids, "attention_mask": attention_masks, "labels": labels}

from torch.nn.utils.rnn import pad_sequence

def custom_collate_fn(batch):
    # Extract input_ids, attention_mask, and labels from the batch
    input_ids = [item['input_ids'] for item in batch]
    attention_masks = [item['attention_mask'] for item in batch]
    labels = [item['labels'] for item in batch]

    # Pad sequences to the maximum length in this batch
    input_ids_padded = pad_sequence(input_ids, batch_first=True, padding_value=wordpiece_tokenizer.unk_token_id)
    attention_masks_padded = pad_sequence(attention_masks, batch_first=True, padding_value=0)
    labels_padded = pad_sequence(labels, batch_first=True, padding_value=wordpiece_tokenizer.unk_token_id)

    # Convert lists of tensors to a single tensor for each type of data
    batch = {
        'input_ids': input_ids_padded,
        'attention_mask': attention_masks_padded,
        'labels': labels_padded
    }

    return batch


# Apply the custom tokenize function to the dataset
tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
tokenized_datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
train_loader = DataLoader(tokenized_datasets, batch_size=8, shuffle=True, collate_fn=custom_collate_fn)


# Train the MAMBA model
expert_system.train_mamba(train_loader, 5, config)
'''
####################################################################################
# RAG Transformer Training
# Load Wikipedia dataset and preprocess
#dataset = load_dataset("wikipedia", "20220301.en", split="train[:0000001%]")
'''
# Load the wikitext-2 dataset
dataset = load_dataset("wikitext", "wikitext-2-v1", split="train")
# Print the column names
#print(dataset.column_names)
# Access a portion of the dataset for inspection
#print(dataset['train'][0])

def generate_attention_mask(token_ids):
    """Generate an attention mask for the given token IDs."""
    return [1 if token_id != 0 else 0 for token_id in token_ids]

def tokenize_function(examples):
    token_ids, attention_masks, labels = [], [], []

    for text in examples["text"]:
        # Tokenize each word in the text and flatten the list of token IDs
        words = text.split()
        ids = []
        for word in words:
            word_ids = config.wordpiece_tokenizer.tokenize(word)
            ids.extend(word_ids)
        
        # Adjust for special tokens ([CLS] and [SEP])
        if len(ids) > config.max_length - 2:
            ids = ids[:config.max_length - 2]
        
        # Add [CLS] at the beginning and [SEP] at the end
        ids = [config.cls_token_id] + ids + [config.sep_token_id]

        attention_mask = [1] * len(ids)  # Attention mask with 1s for real tokens
        
        # Padding
        padding_length = config.max_length - len(ids)
        ids += [config.pad_token_id] * padding_length  # Pad token IDs
        attention_mask += [0] * padding_length  # Extend attention mask for padding

        token_ids.append(ids)
        attention_masks.append(attention_mask)
        labels.append(ids)  # Using the same IDs as labels; this might need adjustment

    return {"input_ids": token_ids, "attention_mask": attention_masks, "labels": labels}


# Apply the custom tokenize function to the dataset
tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

train_loader = DataLoader(tokenized_datasets, batch_size=8, shuffle=True)

# Train the LMT_Rag sub-model within the Expert system
# Train th model
model, average_loss = expert_model.train_language_model_rag(
    expert_model.transformer_rag.language_model,
    train_loader,
    config.device,
    vocab_size= len(config.wordpiece_vocab),
    num_epochs=5
)

# Rag data
train_rag_data = {
    "queries": [
        # Queries for DPO.pdf
        "What is Direct Preference Optimization (DPO)?",
        "How does Direct Preference Optimization work?",
        "How can I implement Direct Preference Optimization in my organization?",
        "Why does Direct Preference Optimization improve the efficiency of language modelling?",
        # Queries for MAMBA.pdf
        "What is MAMBA?",
        "How does MAMBA function?",
        "How can I build a system based on MAMBA technology?",
        "Why does MAMBA enhance the performance of its application area?",
        # Queries for QLORA.pdf
        "What is QLORA?",
        "How does QLORA operate?",
        "How can I develop a project using QLORA?",
        "Why does QLORA improve the capabilities of its relevant field?",
        # Queries for RAG.pdf
        "What is Retrieval Augmented Generation (RAG)?",
        "How does Retrieval Augmented Generation work?",
        "How can I build a Retrieval Augmented Generation model?",
        "Why does Retrieval Augmented Generation enhance language model performance?",
        # Queries for SWITCH_TRANSFORMER.pdf
        "What is the Switch Transformer model?",
        "How does the Switch Transformer model operate?",
        "How can I construct a Switch Transformer model?",
        "Why does the Switch Transformer model improve language processing tasks?"
    ],
    "contexts": [
        # Contexts from DPO.pdf
        config.rag_dataset[0],  # Assuming dataset[0] is the processed content of DPO.pdf
        config.rag_dataset[0],
        config.rag_dataset[0],        
        config.rag_dataset[0],        
        # Contexts from MAMBA.pdf
        config.rag_dataset[1],  # Assuming dataset[1] is the processed content of MAMBA.pdf
        config.rag_dataset[1], 
        config.rag_dataset[1], 
        config.rag_dataset[1], 
        # Contexts from QLORA.pdf
        config.rag_dataset[2],  # Assuming dataset[2] is the processed content of QLORA.pdf
        config.rag_dataset[2],
        config.rag_dataset[2],
        config.rag_dataset[2],
        # Contexts from RAG.pdf
        config.rag_dataset[3],  # Assuming dataset[3] is the processed content of RAG.pdf
        config.rag_dataset[3],
        config.rag_dataset[3],
        config.rag_dataset[3],
        # Contexts from SWITCH_TRANSFORMER.pdf
        config.rag_dataset[4],  # Assuming dataset[4] is the processed content of SWITCH_TRANSFORMER.pdf
        config.rag_dataset[4],
        config.rag_dataset[4],
        config.rag_dataset[4],
    ]
}

# Train DPR encoders
(context_encoder, question_encoder), average_loss = expert_model.train_dpr_encoders(
    train_rag_data,
    expert_model.transformer_rag.context_encoder,  
    expert_model.transformer_rag.question_encoder,  
    optimizer_context = AdamW(expert_model.transformer_rag.context_encoder.parameters(), lr=1e-5),  
    optimizer_question = AdamW(expert_model.transformer_rag.question_encoder.parameters(), lr=1e-5),
    epochs=5,
    context_save_path=config.context_encoder_path,
    question_save_path=config.question_encoder_path
)
'''
####################################################################################
# Expert Training
'''
# Load Wikipedia dataset and preprocess
dataset = load_dataset("wikitext", "wikitext-2-v1", split="train")

# Access a portion of the dataset for inspection
#print(dataset['train'][0])

def generate_attention_mask(token_ids):
    """Generate an attention mask for the given token IDs."""
    return [1 if token_id != 0 else 0 for token_id in token_ids]

def tokenize_function(examples):
    # Initialize lists for token IDs, attention masks, and labels
    token_ids, attention_masks, labels = [], [], []

    for text in examples["text"]:
        # Tokenize text and ensure at least one token ID is generated
        ids = wordpiece_tokenizer.tokenize(text) or [0]  # Replace [0] with your tokenizer's pad token ID
        # Generate attention mask for the tokenized text
        mask = [1] * len(ids)

        # Check for length of ids and pad/truncate as necessary
        if len(ids) < config.max_length:
            # Pad
            pad_length = config.max_length - len(ids)
            ids += [0] * pad_length  # Assuming 0 is the padding ID
            mask += [0] * pad_length
        else:
            # Truncate
            ids = ids[:config.max_length]
            mask = mask[:config.max_length]

        token_ids.append(ids)
        attention_masks.append(mask)
        labels.append(ids)  # For simplicity, using the same IDs as labels; adjust as needed for your model

    return {"input_ids": token_ids, "attention_mask": attention_masks, "labels": labels}



# Apply the custom tokenize function to the dataset
tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

train_loader = DataLoader(tokenized_datasets, batch_size=8, shuffle=True)

main_loss_function = torch.nn.CrossEntropyLoss()
aux_loss_weight = 0.1  # Adjust based on the significance of the auxiliary loss in your training

optimizer = torch.optim.AdamW(expert_model.parameters(), lr=1e-5, weight_decay=1e-4)
save_path = 'D:\\EXPERT_WEIGHTS\\expert_model_weights.pth'
# Rag data
train_rag_data = {
    "queries": [
        # Queries for DPO.pdf
        "What is Direct Preference Optimization (DPO)?",
        "How does Direct Preference Optimization work?",
        "How can I implement Direct Preference Optimization in my organization?",
        "Why does Direct Preference Optimization improve the efficiency of language modelling?",
        # Queries for MAMBA.pdf
        "What is MAMBA?",
        "How does MAMBA function?",
        "How can I build a system based on MAMBA technology?",
        "Why does MAMBA enhance the performance of its application area?",
        # Queries for QLORA.pdf
        "What is QLORA?",
        "How does QLORA operate?",
        "How can I develop a project using QLORA?",
        "Why does QLORA improve the capabilities of its relevant field?",
        # Queries for RAG.pdf
        "What is Retrieval Augmented Generation (RAG)?",
        "How does Retrieval Augmented Generation work?",
        "How can I build a Retrieval Augmented Generation model?",
        "Why does Retrieval Augmented Generation enhance language model performance?",
        # Queries for SWITCH_TRANSFORMER.pdf
        "What is the Switch Transformer model?",
        "How does the Switch Transformer model operate?",
        "How can I construct a Switch Transformer model?",
        "Why does the Switch Transformer model improve language processing tasks?"
    ],
    "contexts": [
        # Contexts from DPO.pdf
        config.rag_dataset[0],  # Assuming dataset[0] is the processed content of DPO.pdf
        config.rag_dataset[0],
        config.rag_dataset[0],        
        config.rag_dataset[0],        
        # Contexts from MAMBA.pdf
        config.rag_dataset[1],  # Assuming dataset[1] is the processed content of MAMBA.pdf
        config.rag_dataset[1], 
        config.rag_dataset[1], 
        config.rag_dataset[1], 
        # Contexts from QLORA.pdf
        config.rag_dataset[2],  # Assuming dataset[2] is the processed content of QLORA.pdf
        config.rag_dataset[2],
        config.rag_dataset[2],
        config.rag_dataset[2],
        # Contexts from RAG.pdf
        config.rag_dataset[3],  # Assuming dataset[3] is the processed content of RAG.pdf
        config.rag_dataset[3],
        config.rag_dataset[3],
        config.rag_dataset[3],
        # Contexts from SWITCH_TRANSFORMER.pdf
        config.rag_dataset[4],  # Assuming dataset[4] is the processed content of SWITCH_TRANSFORMER.pdf
        config.rag_dataset[4],
        config.rag_dataset[4],
        config.rag_dataset[4],
    ]
}
# Train the model
trained_expert_model, average_loss = expert_model.train_expert(
    train_loader=train_loader,
    train_data=train_rag_data,
    optimizer=optimizer,
    main_loss_function=main_loss_function,
    aux_loss_weight=aux_loss_weight,
    device=config.device,
    save_path=save_path,
    accumulation_steps=4,  # Adjust based on your preference
    num_epochs=5  # Adjust based on your training needs
)

print(f"Training completed. Average loss: {average_loss}")
'''




# Define the Switch Transformer / Mixture of Experts architecture
class ExpertRouter(nn.Module):
    def __init__(self, num_experts, project_to_dim):
        super(ExpertRouter, self).__init__()
        self.num_experts = num_experts
        self.project_to_dim = project_to_dim
        # Initialize the projection layer as None; it will be created dynamically
        self.projection_layer = None
        # Placeholder for the routing layer, to be initialized later
        self.routing_layer = None

    def forward(self, x):
        # Dynamically create the projection layer based on the input dimension, if it hasn't been created yet
        if self.projection_layer is None:
            input_feature_dim = x.size(-1)  # Assuming x is of shape [batch_size, ..., feature_dim]
            self.projection_layer = nn.Linear(input_feature_dim, self.project_to_dim).to(x.device)
            # Now that we know the projected dimension, we can initialize the routing layer
            self.routing_layer = nn.Linear(self.project_to_dim, self.num_experts).to(x.device)
        
        # Project the input to the workable dimension
        x_projected = self.projection_layer(x)
        
        # Generate scores for each expert
        expert_scores = self.routing_layer(x_projected)
        # Convert scores to probabilities (e.g., using softmax)
        expert_probs = F.softmax(expert_scores, dim=-1)
        # Decide which expert to route to based on the highest score/probability
        chosen_expert_indices = torch.argmax(expert_probs, dim=1)
        
        return chosen_expert_indices, expert_probs

    def auxiliary_loss(self, gate_scores):
        expert_load = gate_scores.sum(0) / gate_scores.size(0)
        loss_balancing = torch.std(expert_load)
        return loss_balancing


    def route_inputs(self,expert_indices, gate_scores, num_experts):
        capacity_factor_tensor = torch.tensor([Expert.SwitchRouter.CAPACITY_FACTOR], dtype=torch.float32)
        capacities = (gate_scores.size(0) * capacity_factor_tensor / num_experts).int()
        expert_counts = torch.zeros(num_experts, dtype=torch.int32)
        for idx in range(len(expert_indices)):
            selected_expert = expert_indices[idx]
            if expert_counts[selected_expert] < capacities[0]:
                expert_counts[selected_expert] += 1
            else:
                available_experts = (expert_counts < capacities[0]).nonzero(as_tuple=False).view(-1)
                if len(available_experts) > 0:
                    alternative_expert = available_experts[0]
                    expert_indices[idx] = alternative_expert
                    expert_counts[alternative_expert] += 1
                else:
                    print("No available experts to reroute. Handling overflow.")
        return expert_indices


class SwitchTransformerMoE(nn.Module):
    def __init__(self, experts, config):
        super(SwitchTransformerMoE, self).__init__()
        self.experts = nn.ModuleList(experts)
        self.config = config
        # Initialize SparseFlash2Attention, LayerNorm, Dropout, and QLORALayer as before
        self.sparse_flash2_attention = Expert.SparseFlash2Attention(
            config.seq_len, config.head_dim, config.block_size, config.sparsity_factor)
        self.layer_norm = nn.LayerNorm(config.seq_len)
        self.dropout = nn.Dropout(config.dropout)
        self.qlora = Expert.QLORALayer(config.input_dim, config.input_dim, config.rank)
        # Initialize ExpertRouter with the number of experts and input feature dimension
        self.expert_router = ExpertRouter(len(experts), project_to_dim=256)
        print(f"SwitchTransformerMoE initialized with embed_size={config.embed_size}, heads={config.heads}")

    def forward(self, x, attention_mask, context_texts, question_text):
        # SparseFlash2Attention, LayerNorm, and Dropout operations as before
        x = self.sparse_flash2_attention(x, x, x)        
        x = self.dropout(self.layer_norm(x))

        # Obtain expert indices and gate scores from ExpertRouter
        chosen_expert_indices, gate_scores = self.expert_router(x)

        # Initialize an empty tensor for the final output
        final_output = torch.zeros_like(x)

        # Iterate over each expert, route inputs, and combine outputs
        for idx, expert in enumerate(self.experts):
            mask = chosen_expert_indices == idx
            if mask.any():
                selected_inputs = x[mask]
                # Adapt expert forwarding based on your model's requirements
                # This is a simplified example assuming experts have a forward method accepting selected_inputs
                expert_output = expert.forward(selected_inputs, attention_mask[mask], context_texts, question_text)
                final_output[mask] = expert_output

        # Calculate auxiliary loss for load balancing across experts
        aux_loss = self.expert_router.auxiliary_loss(gate_scores)

        x = self.qlora(final_output)
        return x, aux_loss



config = ExpertConfig(wordpiece_vocab=wordpiece_vocab, wordpiece_tokenizer=wordpiece_tokenizer)
expert_model_1 = Expert(config)
expert_model_2 = Expert(config)
expert_model_3 = Expert(config)
expert_model_4 = Expert(config)
expert_model_5 = Expert(config)

switch_transformer_moe = SwitchTransformerMoE(
    experts=[
        expert_model_1,
        expert_model_2,
        expert_model_3,
        expert_model_4,
        expert_model_5
    ],
    config=config
)
# Rag data
train_rag_data = {
    "queries": [
        # Queries for DPO.pdf
        "What is Direct Preference Optimization (DPO)?",
        "How does Direct Preference Optimization work?",
        "How can I implement Direct Preference Optimization in my organization?",
        "Why does Direct Preference Optimization improve the efficiency of language modelling?",
        # Queries for MAMBA.pdf
        "What is MAMBA?",
        "How does MAMBA function?",
        "How can I build a system based on MAMBA technology?",
        "Why does MAMBA enhance the performance of its application area?",
        # Queries for QLORA.pdf
        "What is QLORA?",
        "How does QLORA operate?",
        "How can I develop a project using QLORA?",
        "Why does QLORA improve the capabilities of its relevant field?",
        # Queries for RAG.pdf
        "What is Retrieval Augmented Generation (RAG)?",
        "How does Retrieval Augmented Generation work?",
        "How can I build a Retrieval Augmented Generation model?",
        "Why does Retrieval Augmented Generation enhance language model performance?",
        # Queries for SWITCH_TRANSFORMER.pdf
        "What is the Switch Transformer model?",
        "How does the Switch Transformer model operate?",
        "How can I construct a Switch Transformer model?",
        "Why does the Switch Transformer model improve language processing tasks?"
    ],
    "contexts": [
        # Contexts from DPO.pdf
        config.rag_dataset[0],  # Assuming dataset[0] is the processed content of DPO.pdf
        config.rag_dataset[0],
        config.rag_dataset[0],        
        config.rag_dataset[0],        
        # Contexts from MAMBA.pdf
        config.rag_dataset[1],  # Assuming dataset[1] is the processed content of MAMBA.pdf
        config.rag_dataset[1], 
        config.rag_dataset[1], 
        config.rag_dataset[1], 
        # Contexts from QLORA.pdf
        config.rag_dataset[2],  # Assuming dataset[2] is the processed content of QLORA.pdf
        config.rag_dataset[2],
        config.rag_dataset[2],
        config.rag_dataset[2],
        # Contexts from RAG.pdf
        config.rag_dataset[3],  # Assuming dataset[3] is the processed content of RAG.pdf
        config.rag_dataset[3],
        config.rag_dataset[3],
        config.rag_dataset[3],
        # Contexts from SWITCH_TRANSFORMER.pdf
        config.rag_dataset[4],  # Assuming dataset[4] is the processed content of SWITCH_TRANSFORMER.pdf
        config.rag_dataset[4],
        config.rag_dataset[4],
        config.rag_dataset[4],
    ]
}
# Example training loop (simplified)
num_epochs = 5
# Load Wikipedia dataset and preprocess
dataset = load_dataset("wikitext", "wikitext-2-v1", split="train")

# Access a portion of the dataset for inspection
#print(dataset['train'][0])

def generate_attention_mask(token_ids):
    """Generate an attention mask for the given token IDs."""
    return [1 if token_id != 0 else 0 for token_id in token_ids]

def tokenize_function(examples):
    # Initialize lists for token IDs, attention masks, and labels
    token_ids, attention_masks, labels = [], [], []

    for text in examples["text"]:
        # Tokenize text and ensure at least one token ID is generated
        ids = wordpiece_tokenizer.tokenize(text) or [0]  # Replace [0] with your tokenizer's pad token ID
        # Generate attention mask for the tokenized text
        mask = [1] * len(ids)

        # Check for length of ids and pad/truncate as necessary
        if len(ids) < config.max_length:
            # Pad
            pad_length = config.max_length - len(ids)
            ids += [0] * pad_length  # Assuming 0 is the padding ID
            mask += [0] * pad_length
        else:
            # Truncate
            ids = ids[:config.max_length]
            mask = mask[:config.max_length]

        token_ids.append(ids)
        attention_masks.append(mask)
        labels.append(ids)  # For simplicity, using the same IDs as labels; adjust as needed for your model

    return {"input_ids": token_ids, "attention_mask": attention_masks, "labels": labels}



# Apply the custom tokenize function to the dataset
tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

train_loader = DataLoader(tokenized_datasets, batch_size=1, shuffle=True)
optimizer = torch.optim.Adam(switch_transformer_moe.parameters(), lr=1e-4)
main_loss_function = torch.nn.CrossEntropyLoss()
aux_loss_weight = 0.2
accumulation_steps = 4

def adjust_input_ids(input_ids, vocab_size):
    # Ensure the operation is done on the same device as input_ids
    adjusted_ids = torch.where(input_ids < vocab_size, input_ids, torch.tensor(vocab_size - 1, device=input_ids.device))
    return adjusted_ids

for epoch in range(num_epochs):
    total_loss = 0
    optimizer.zero_grad()  # Initialize gradients to zero
    for batch_idx, batch in enumerate(train_loader):
        inputs, attention_mask, targets = batch['input_ids'].to(config.device), batch['attention_mask'].to(config.device), batch['labels'].to(config.device)
        # Adjust input_ids here to ensure they are within the vocabulary range
        vocab_size = config.vocab_size  # Ensure this matches the vocab_size used in your embedding layer
        inputs = adjust_input_ids(inputs, vocab_size)
        # Calculate start and end indices for current batch in train_data
        start_idx = batch_idx * batch['input_ids'].size(0)
        end_idx = start_idx + batch['input_ids'].size(0)

        # Extract current_queries and current_contexts for the batch
        current_queries = train_rag_data['queries'][start_idx:end_idx]
        current_contexts = train_rag_data['contexts'][start_idx:end_idx]

        # Call to the model forward function
        outputs, aux_loss = switch_transformer_moe(inputs, attention_mask, current_contexts, current_queries)

        # Calculate loss and accumulate
        main_loss = main_loss_function(outputs.view(-1, outputs.size(-1)), targets.view(-1))
        loss = (main_loss + aux_loss_weight * aux_loss) / accumulation_steps
        loss.backward()

        if (batch_idx + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        total_loss += loss.item() * accumulation_steps  # Scale back up

    average_loss = total_loss / len(train_loader)
    print(f'End of Epoch {epoch+1}, Average Loss: {average_loss}')
average_loss = total_loss / len(train_loader)
#torch.save(self.state_dict(), save_path)




FileNotFoundError: [Errno 2] No such file or directory: '/content/drive/MyDrive/EXPERT_STUFF/sample.txt'