In [None]:
# 0. Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from transformers import BertTokenizer
from torch.utils.data import DataLoader
import fitz  # PyMuPDF
from transformers import BertModel
import math
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
import os
# Debugging aids: run CUDA kernels synchronously so errors are reported at the call that
# caused them, and turn on autograd anomaly detection. Both slow execution noticeably;
# consider disabling them for real training runs.
os.environ.update(CUDA_LAUNCH_BLOCKING="1")
torch.autograd.set_detect_anomaly(True)

# 2. Sub-Model Weights (every checkpoint lives in the same project directory)
_IEEMM_DIR = r'C:\Users\robbi\IEEMM'
# 2.1 Transformer with DPO:
tran_dpo = _IEEMM_DIR + '\\language_model_weights.pth'
# 2.2 MAMBA:
mamba_model_path = _IEEMM_DIR + '\\mamba_model_weights.pth'
# 2.3 Transformer and RAG components.
# NOTE: these three names hold checkpoint *paths* (strings), not model objects.
context_encoder = _IEEMM_DIR + '\\context_encoder.pth'
language_model = _IEEMM_DIR + '\\language_model.pth'
question_encoder = _IEEMM_DIR + '\\question_encoder.pth'

# Load model weights function
def load_model_weights(model, model_path, map_location=None):
    """Load a checkpoint from *model_path* into *model* and return the model in eval mode.

    Accepted checkpoint formats:
      * a dict with a ``'state_dict'`` key,
      * a dict with a ``'model_state_dict'`` key,
      * a raw state dict,
      * a whole pickled ``nn.Module``, which then *replaces* ``model``.

    NOTE: since PyTorch 2.6 ``torch.load`` defaults to ``weights_only=True``, which rejects
    whole pickled modules. Only plain state dicts load there unless the caller opts out.

    Args:
        model: module to receive the weights.
        model_path: path to the checkpoint file.
        map_location: device the tensors are mapped to. Defaults to a module-level
            ``device`` global if one is defined. Otherwise it is CUDA when available,
            else CPU.

    Returns:
        The loaded module (``model``, or the checkpoint itself if it was a module), in eval mode.

    Raises:
        ValueError: for an unsupported checkpoint type, or when a raw state dict does not
            match ``model``. The original RuntimeError is chained as ``__cause__``.
        RuntimeError: when a ``'state_dict'`` / ``'model_state_dict'`` entry does not match.
    """
    if map_location is None:
        # Prefer the notebook-level `device` (the original behaviour); fall back to
        # auto-detection so the function also works before/without that global.
        map_location = globals().get('device')
        if map_location is None:
            map_location = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    checkpoint = torch.load(model_path, map_location=map_location)

    if isinstance(checkpoint, dict):
        # Check for 'state_dict' or 'model_state_dict' keys
        if 'state_dict' in checkpoint:
            model.load_state_dict(checkpoint['state_dict'])
        elif 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            # If no known key is found, try loading it as a raw state dictionary
            try:
                model.load_state_dict(checkpoint)
            except RuntimeError as e:
                raise ValueError(f"Error loading state dict: {e}") from e
    elif isinstance(checkpoint, nn.Module):
        # If the checkpoint is a model object, use it in place of `model`
        model = checkpoint
    else:
        raise ValueError(f"Unsupported checkpoint format: {type(checkpoint)}")

    model.eval()
    return model

# 3. MAMBA
# RoPE
class RotaryPositionalEncoding(nn.Module):
    """Position-dependent rotation of feature pairs (a RoPE variant).

    Precomputes ``sin``/``cos`` buffers of shape ``(max_len, ceil(dim / 2))`` from
    inverse frequencies ``10000 ** (-2i / dim)``.

    ``forward`` expects ``x`` laid out as ``(batch, seq_len, features)``. The tables are
    broadcast as ``(1, seq_len, ceil(dim / 2))``.

    NOTE(review): this differs from standard RoPE:
      * ``torch.roll`` shifts the partner features by one frequency slot before mixing;
      * the result is ``[rotated evens, rotated odds]`` rather than re-interleaved;
      * features beyond ``dim`` are dropped.
    It is kept unchanged because trained checkpoints depend on this exact behaviour.
    """

    def __init__(self, dim, max_len=5000):
        super().__init__()
        self.dim = dim
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(max_len, dtype=inv_freq.dtype)
        angles = torch.outer(positions, inv_freq)  # (max_len, ceil(dim / 2))
        self.register_buffer('sin', angles.sin())
        self.register_buffer('cos', angles.cos())

    def forward(self, x):
        seq_len = x.shape[1]
        sin_table = self.sin[:seq_len].to(x.device).unsqueeze(0)
        cos_table = self.cos[:seq_len].to(x.device).unsqueeze(0)

        evens = x[..., :self.dim:2]
        odds = x[..., 1:self.dim:2]

        # Mix each half with the (rolled) other half, weighted by the position tables.
        rotated_evens = evens * cos_table + torch.roll(odds, shifts=1, dims=-1) * sin_table
        rotated_odds = odds * cos_table - torch.roll(evens, shifts=1, dims=-1) * sin_table
        return torch.cat((rotated_evens, rotated_odds), dim=-1)

# SWIGLU
class SwiGLU(nn.Module):
    """Gated linear unit: ``fc1(x) * sigmoid(fc2(x))``.

    NOTE(review): despite the name, the gate is a sigmoid (a plain GLU), not the
    Swish/SiLU gate of SwiGLU. It is kept as-is because trained checkpoints depend on it.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.fc1 = nn.Linear(dim_in, dim_out)  # value path
        self.fc2 = nn.Linear(dim_in, dim_out)  # gate path

    def forward(self, x):
        return self.fc1(x) * torch.sigmoid(self.fc2(x))

class SimplifiedMAMBA(nn.Module):
    """Convolutional sequence mixer loosely modelled on MAMBA.

    There is no selective state-space scan here. ``forward`` runs, in order:
      1. input projection;
      2. ``num_layers`` stacked ``Conv1d`` layers;
      3. a gated linear unit (``SwiGLU``);
      4. an output projection to ``d_model * expansion_factor`` features.

    Args:
        num_layers: number of stacked ``Conv1d`` layers.
        d_model: feature width of the input and of every conv layer.
        d_state: hidden width of ``self.feedforward`` (which ``forward`` does not use).
        d_conv: convolution kernel size.
        expansion_factor: the output width is ``d_model * expansion_factor``.
        verbose: if True, print intermediate tensor shapes during ``forward`` (debug aid).
    """

    def __init__(self, num_layers, d_model, d_state, d_conv, expansion_factor, verbose=False):
        super().__init__()

        self.num_layers = num_layers
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expansion_factor = expansion_factor
        self.verbose = verbose
        # NOTE: `feedforward` is initialised but never used by `forward`. It is kept so that
        # existing checkpoints, which contain its parameters, still load with strict=True.
        self.feedforward = nn.Sequential(
            nn.Linear(d_model, d_state),
            nn.GELU(),
            nn.Linear(d_state, d_model)
        )
        self.input_embedding = nn.Linear(d_model, d_model)
        self.convs = nn.Sequential(*[nn.Conv1d(d_model, d_model, kernel_size=d_conv, padding=(d_conv // 2)) for _ in range(num_layers)])
        self.swiglu = SwiGLU(d_model, d_model)
        self.output_projection = nn.Linear(d_model, d_model * expansion_factor)  # maps SwiGLU output to the expanded width

        self.initialize_weights()

    def initialize_weights(self):
        gain = nn.init.calculate_gain('relu')

        nn.init.orthogonal_(self.input_embedding.weight, gain)
        nn.init.normal_(self.input_embedding.bias, mean=0, std=0.01)

        # Initialise every conv layer (not only the last one) the same way.
        for conv in self.convs:
            nn.init.kaiming_uniform_(conv.weight, a=math.sqrt(5))
            nn.init.zeros_(conv.bias)

        nn.init.xavier_uniform_(self.feedforward[0].weight, gain=nn.init.calculate_gain('relu'))
        nn.init.zeros_(self.feedforward[0].bias)

        nn.init.xavier_uniform_(self.feedforward[2].weight, gain=nn.init.calculate_gain('linear'))
        nn.init.zeros_(self.feedforward[2].bias)

        nn.init.xavier_uniform_(self.output_projection.weight, gain=nn.init.calculate_gain('linear'))
        nn.init.zeros_(self.output_projection.bias)

    def _debug(self, label, tensor):
        """Print *label* and the shape of *tensor* when verbose mode is on."""
        if self.verbose:
            print(label, tensor.shape)

    def forward(self, inputs, attention_mask=None):
        """Map ``(batch, seq, d_model)`` inputs to ``(batch, seq, d_model * expansion_factor)``.

        Args:
            inputs: features shaped ``(batch, seq, d_model)``.
            attention_mask: optional ``(batch, seq)`` mask. It is multiplied into the inputs,
                so positions where it is 0 are zeroed.
        """
        self._debug("Input shape:", inputs)

        # Apply the attention mask if provided
        if attention_mask is not None:
            inputs = inputs * attention_mask.unsqueeze(-1)

        seq_len = inputs.shape[1]
        projected_inputs = self.input_embedding(inputs)
        self._debug("projected_inputs pre-reshape shape:", projected_inputs)

        # (batch, seq, d_model) -> (batch, d_model, seq): Conv1d convolves over the last dim.
        projected_inputs = projected_inputs.permute(0, 2, 1)
        self._debug("projected_inputs post-reshape shape:", projected_inputs)

        for conv in self.convs:
            # With an even kernel, padding=d_conv // 2 makes each conv one step longer than
            # its input; crop so the sequence length is preserved (no-op for odd kernels).
            projected_inputs = conv(projected_inputs)[..., :seq_len]

        projected_inputs = projected_inputs.permute(0, 2, 1)
        self._debug("projected_inputs post convolution reshape:", projected_inputs)

        projected_inputs = self.swiglu(projected_inputs)
        self._debug("projected_inputs post swiglu shape:", projected_inputs)

        output = self.output_projection(projected_inputs)
        self._debug("output shape:", output)

        return output

class SimplifiedLanguageModelMAMBA(nn.Module):
    """Language model: token embedding -> rotary encoding -> SimplifiedMAMBA -> vocab logits.

    Args:
        vocab_size: number of tokens (embedding rows and logit columns).
        num_layers, d_model, d_state, d_conv, expansion_factor: forwarded to ``SimplifiedMAMBA``.
    """

    def __init__(self, vocab_size, num_layers, d_model, d_state, d_conv, expansion_factor):
        super().__init__()

        self.vocab_size = vocab_size
        self.num_layers = num_layers
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expansion_factor = expansion_factor

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = RotaryPositionalEncoding(d_model)
        self.simplified_mamba = SimplifiedMAMBA(num_layers, d_model, d_state, d_conv, expansion_factor)
        # SimplifiedMAMBA outputs d_model * expansion_factor features per position, so the
        # vocabulary head must accept that width. A hard-coded d_model * 2 only works when
        # expansion_factor == 2.
        self.output_projection = nn.Linear(d_model * expansion_factor, vocab_size)

        self.initialize_weights()

    def initialize_weights(self):
        gain = 1.0

        nn.init.orthogonal_(self.embedding.weight, gain)
        nn.init.xavier_uniform_(self.output_projection.weight, gain=gain)

    def forward(self, input_values, attention_mask=None):
        """Return logits of shape ``(batch, seq, vocab_size)`` for token ids ``input_values``.

        ``attention_mask`` is optional and passed through to ``SimplifiedMAMBA``.
        """
        # Scale embeddings by sqrt(d_model) before adding positional information.
        embedded = self.embedding(input_values) * math.sqrt(self.d_model)
        embedded = self.pos_encoder(embedded)
        simplified_mamba_output = self.simplified_mamba(embedded, attention_mask)
        return self.output_projection(simplified_mamba_output)


def setup_optimizer(model, learning_rate, weight_decay, warmup_steps, total_steps):
    """Create an AdamW optimizer with a linear-warmup / cosine-decay LR schedule.

    The LR multiplier at scheduler step ``s`` is ``min(warmup(s), cosine(s))``, where:
      * ``warmup(s) = (s + 1) / warmup_steps``, or 1.0 when ``warmup_steps == 0``;
      * ``cosine(s) = 0.5 * (1 + cos(pi * min(s, total_steps) / total_steps))``.

    Clamping ``s`` at ``total_steps`` keeps the multiplier at 0 after the schedule ends.
    Without the clamp the cosine wraps around and the learning rate rises again.

    Args:
        model: module whose parameters are optimised.
        learning_rate: peak learning rate.
        weight_decay: AdamW weight decay.
        warmup_steps: number of linear-warmup steps (>= 0).
        total_steps: length of the cosine schedule (> 0).

    Returns:
        ``(optimizer, scheduler)``.

    Raises:
        ValueError: if ``warmup_steps < 0`` or ``total_steps <= 0``.
    """
    if warmup_steps < 0:
        raise ValueError(f"warmup_steps must be >= 0, got {warmup_steps}")
    if total_steps <= 0:
        raise ValueError(f"total_steps must be > 0, got {total_steps}")

    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    def lr_lambda(step):
        warmup = (step + 1) / warmup_steps if warmup_steps > 0 else 1.0
        cosine = 0.5 * (1 + math.cos(math.pi * min(step, total_steps) / total_steps))
        return min(warmup, cosine)

    scheduler = LambdaLR(optimizer, lr_lambda)

    return optimizer, scheduler

# 4. RAG
def extract_text_from_pdf(file_path):
    """Return the text of every page of the PDF at *file_path*, concatenated in page order.

    Pages are joined with no separator. The document is closed when reading finishes.
    """
    with fitz.open(file_path) as doc:
        # str.join avoids the quadratic cost of repeated `+=` on long documents.
        return "".join(page.get_text() for page in doc)

def split_into_chunks(text, chunk_size):
    """Split *text* into consecutive pieces of at most *chunk_size* characters.

    The last piece may be shorter. An empty *text* yields an empty list.

    Raises:
        ValueError: if *chunk_size* is not positive. A negative size would otherwise
            silently return no chunks and discard the text.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")
    return [text[start:start + chunk_size] for start in range(0, len(text), chunk_size)]

# Tokenizers loaded so far, keyed by model name (from_pretrained is expensive).
_BERT_TOKENIZER_CACHE = {}


def _get_bert_tokenizer(model_name='bert-base-uncased'):
    """Return a BertTokenizer for *model_name*, loading it only on first use."""
    tokenizer = _BERT_TOKENIZER_CACHE.get(model_name)
    if tokenizer is None:
        tokenizer = BertTokenizer.from_pretrained(model_name)
        _BERT_TOKENIZER_CACHE[model_name] = tokenizer
    return tokenizer


def preprocess_text(text, max_length=512):
    """Split *text* into chunks and tokenise each one for BERT.

    Args:
        text: raw document text.
        max_length: token length every chunk is padded/truncated to.

    Returns:
        A list with one dict per chunk. Each dict holds ``'input_ids'`` and
        ``'attention_mask'`` tensors of shape ``(1, max_length)``.

    NOTE: the chunk size (``max_length - 50``) is counted in *characters*, while
    ``max_length`` is in *tokens*. Chunks therefore typically tokenise to far fewer
    than ``max_length`` tokens and are mostly padding.
    """
    # Reuse one tokenizer across calls instead of reloading it for every document.
    tokenizer = _get_bert_tokenizer()

    # Leave headroom below max_length (intended to account for special tokens).
    chunk_size = max_length - 50

    processed_chunks = []
    for chunk in split_into_chunks(text, chunk_size):
        tokenized_output = tokenizer(chunk, padding='max_length', truncation=True, max_length=max_length, return_tensors="pt")
        processed_chunks.append({
            'input_ids': tokenized_output['input_ids'],
            'attention_mask': tokenized_output['attention_mask'],
        })

    return processed_chunks

def create_dataset_from_pdfs(pdf_file_paths):
    """Extract and tokenise every PDF in *pdf_file_paths*.

    Returns one entry per input file, in the same order. Each entry is the list of
    chunk dicts produced by ``preprocess_text`` for that file's text.
    """
    return [preprocess_text(extract_text_from_pdf(path)) for path in pdf_file_paths]

def retrieve_contexts(dataset, query_embedding, top_k=5, encoder=None):
    """Return the *top_k* contexts of *dataset* most similar to *query_embedding*.

    Similarity is the dot product ``query_embedding @ context_embedding.T``. It must
    reduce to a single value per context, e.g. both embeddings shaped ``(1, dim)``.
    Ties keep dataset order.

    Args:
        dataset: sequence of dicts, each with ``'input_ids'`` and ``'attention_mask'``
            tensors for one context.
        query_embedding: embedding of the query.
        top_k: maximum number of contexts to return. Returns nothing if ``<= 0``.
        encoder: callable ``encoder(input_ids, attention_mask) -> embedding``. Defaults
            to a module-level ``context_encoder``.

    Raises:
        TypeError: if no callable encoder is available, e.g. when the module-level
            ``context_encoder`` is a checkpoint path string rather than a model.
    """
    if encoder is None:
        encoder = globals().get('context_encoder')
    if not callable(encoder):
        raise TypeError(
            "retrieve_contexts needs a callable context encoder; pass encoder=... "
            f"(module-level `context_encoder` is {type(encoder).__name__})"
        )

    similarity_scores = []
    # Retrieval only ranks contexts; no gradients are needed.
    with torch.no_grad():
        for context in dataset:
            context_embedding = encoder(context['input_ids'], context['attention_mask'])
            similarity = torch.matmul(query_embedding, context_embedding.T)
            similarity_scores.append(similarity.squeeze().item())

    # Stable sort: equal scores keep their dataset order.
    ranked = sorted(range(len(similarity_scores)), key=similarity_scores.__getitem__, reverse=True)
    return [dataset[i] for i in ranked[:max(top_k, 0)]]

def rag_retrieve_and_generate(dataset, query):
    """Answer *query* by retrieving relevant contexts from *dataset* and generating text.

    Args:
        dataset: the contexts to search; passed unchanged to ``retrieve_contexts``.
        query: the question, passed directly to the question encoder. Presumably raw
            text (TODO confirm). The encoder's ``forward`` works on token ids plus an
            attention mask.

    Returns:
        Whatever ``language_model.generate_response`` returns.

    NOTE(review): this function cannot run end to end as written:
      * ``LanguageModelTransformer()`` is constructed without its required
        ``vocab_size`` argument, which raises ``TypeError``.
      * The question encoder is called with ``query`` alone, not with tokenised
        ``input_ids`` / ``attention_mask``.
      * ``generate_response`` receives the list of retrieved context dicts, whereas it
        is written to take a tensor of token ids.
      * ``retrieve_contexts`` needs a callable context encoder, but the module-level
        name ``context_encoder`` is bound to a checkpoint path string.
      * Both models are constructed fresh on every call, so none of the project's
        fine-tuned checkpoints are loaded.
    TODO: accept trained, already-loaded components and a tokenised query instead.
    """
    # Instantiate the question encoder.
    # NOTE: this local name shadows the module-level `question_encoder` path string.
    question_encoder = DPRQuestionEncoder()

    # Encode the query
    encoded_query = question_encoder(query)

    # Retrieve relevant context: rank dataset entries by similarity to the query
    # embedding and keep the best matches (see `retrieve_contexts`).
    relevant_contexts = retrieve_contexts(dataset, encoded_query)

    # Language model for generation.
    # NOTE: this local name shadows the module-level `language_model` path string.
    language_model = LanguageModelTransformer()

    # Generate a response based on the retrieved contexts
    # This step may involve further formatting or preprocessing
    response = language_model.generate_response(relevant_contexts)

    return response

# pdfs
_PDF_DIR = r'C:\Users\robbi\IEEMM'
pdf_file_paths = [
    _PDF_DIR + '\\' + pdf_name
    for pdf_name in ('DPO.pdf', 'MAMBA.pdf', 'QLORA.pdf', 'RAG.pdf', 'SWITCH_TRANSFORMER.pdf')
]

# Built eagerly when this cell runs: every PDF above is read and tokenised here.
rag_dataset = create_dataset_from_pdfs(pdf_file_paths)

class CustomDPRContextEncoder(nn.Module):
    """Encode a tokenised context into a fixed-size vector.

    A pretrained BERT produces a pooled output, which a linear layer then projects to
    ``embedding_dim``.
    """

    def __init__(self, model_name='bert-base-uncased', embedding_dim=768):
        super().__init__()
        self.bert_model = BertModel.from_pretrained(model_name)
        hidden_size = self.bert_model.config.hidden_size
        self.embedding_layer = nn.Linear(hidden_size, embedding_dim)

    def forward(self, input_ids, attention_mask=None):
        bert_outputs = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)
        return self.embedding_layer(bert_outputs.pooler_output)

class DPRQuestionEncoder(nn.Module):
    """Encode a tokenised question into a fixed-size vector.

    A pretrained BERT produces a pooled output, which a linear layer then projects to
    ``embedding_dim``.
    """

    def __init__(self, model_name='bert-base-uncased', embedding_dim=768):
        super(DPRQuestionEncoder, self).__init__()
        self.bert = BertModel.from_pretrained(model_name)
        self.embedding_layer = nn.Linear(self.bert.config.hidden_size, embedding_dim)

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Return question embeddings.

        ``attention_mask`` is optional, matching ``CustomDPRContextEncoder.forward``.
        Extra keyword arguments are accepted and ignored.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        embeddings = self.embedding_layer(pooled_output)
        return embeddings

# TransformerRAG class
class TransformerRAG(nn.Module):
    """Retrieval-augmented generation over tokenised documents.

    ``forward`` runs these steps:
      1. Embed every candidate document with ``CustomDPRContextEncoder`` (mean over its chunks).
      2. Embed the question with ``DPRQuestionEncoder``.
      3. Pick the document with the highest cosine similarity.
      4. Feed ``question_text`` plus that document's decoded text to ``LanguageModelTransformer``.
    The greedy per-position predictions are returned as a string.

    Args:
        context_encoder_path: path to context-encoder weights.
            NOTE(review): not loaded. The context encoder keeps pretrained BERT plus a
            freshly initialised projection. TODO load or remove.
        language_model_path: checkpoint loaded into the language model with
            ``load_model_weights``.
        question_encoder_path: path to question-encoder weights.
            NOTE(review): not loaded either. TODO load or remove.
        vocab_size: language-model vocabulary size. BERT token ids must be smaller than it.

    Relies on a module-level ``device`` being defined before construction.
    """

    # Size of the language model's learned position embedding. Language-model inputs
    # are truncated to this many tokens.
    LM_MAX_LENGTH = 100

    def __init__(self, context_encoder_path, language_model_path, question_encoder_path, vocab_size):
        super(TransformerRAG, self).__init__()
        self.context_encoder = CustomDPRContextEncoder().to(device)
        self.language_model = LanguageModelTransformer(
            vocab_size=vocab_size,
            embed_size=256,
            num_layers=6,
            forward_expansion=4,
            heads=8,
            dropout=0,
            max_length=self.LM_MAX_LENGTH,  # must match the checkpoint's position-embedding size
            rank=16
        ).to(device)
        self.language_model = load_model_weights(self.language_model, language_model_path)
        self.question_encoder = DPRQuestionEncoder().to(device)
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def _embed_document(self, context_list):
        """Return the average context embedding over one document's tokenised chunks."""
        if not all(isinstance(context, dict) for context in context_list):
            raise TypeError("Each item in context_texts must be a list of tokenized context dictionaries")

        total = torch.zeros(self.context_encoder.embedding_layer.out_features, device=device)
        for context in context_list:
            context_embedding = self.context_encoder(
                context['input_ids'].to(device), context['attention_mask'].to(device)
            )
            total += context_embedding.mean(dim=0)
        # A document with no chunks would otherwise divide by zero and give a NaN
        # embedding, which would poison the similarity ranking.
        return total / max(len(context_list), 1)

    def _decode_document(self, context_list):
        """Recover plain text from one document's tokenised chunks (special tokens dropped)."""
        return " ".join(
            self.tokenizer.decode(context['input_ids'].reshape(-1).tolist(), skip_special_tokens=True)
            for context in context_list
        )

    def forward(self, context_texts, question_input_ids, question_attention_mask, question_text):
        """Generate a response string for one question.

        Args:
            context_texts: list of documents. Each document is a list of dicts with
                ``'input_ids'`` / ``'attention_mask'`` tensors.
            question_input_ids: BERT token ids of the question.
            question_attention_mask: attention mask for ``question_input_ids``.
            question_text: the question as plain text (prepended to the retrieved document).

        Raises:
            ValueError: for out-of-vocabulary question ids or an empty ``context_texts``.
            TypeError: if a document is not a list of dicts.
        """
        if question_input_ids.max() >= self.tokenizer.vocab_size:
            raise ValueError("question_input_ids contain ID(s) beyond the tokenizer's vocabulary size")
        if len(context_texts) == 0:
            raise ValueError("context_texts must contain at least one document")

        # The result is a string, so no gradients are needed anywhere below.
        with torch.no_grad():
            document_embeddings = [self._embed_document(doc) for doc in context_texts]

            question_embeddings = self.question_encoder(
                input_ids=question_input_ids.to(device).long(),
                attention_mask=question_attention_mask.to(device).long(),
            )

            cos_sim = nn.CosineSimilarity(dim=1)
            # Similarity of the first (only) question to each document embedding.
            similarities = torch.stack(
                [cos_sim(question_embeddings, emb.unsqueeze(0))[0] for emb in document_embeddings]
            )
            most_relevant_context_idx = int(torch.argmax(similarities))

            # Documents are token dicts, not strings, so decode the chosen one before
            # concatenating it with the question.
            combined_input = question_text + " " + self._decode_document(context_texts[most_relevant_context_idx])
            tokenized_combined_input = self.tokenizer(
                combined_input, return_tensors="pt", truncation=True, max_length=self.LM_MAX_LENGTH
            )
            # LanguageModelTransformer.forward takes only the token ids and returns raw logits.
            response_logits = self.language_model(tokenized_combined_input['input_ids'].to(device))
            predicted_token_ids = torch.argmax(response_logits, dim=-1)

        predicted_tokens = self.tokenizer.convert_ids_to_tokens(predicted_token_ids[0].tolist())
        return self.tokenizer.convert_tokens_to_string(predicted_tokens)

class LORALayer(nn.Module):
    """Linear layer with an additive low-rank update.

    output = x @ W^T + b + alpha * (x @ A) @ B

    ``W`` and ``b`` are the base weight/bias (``output_dim x input_dim`` and ``output_dim``).
    ``A`` (``input_dim x rank``) and ``B`` (``rank x output_dim``) are the low-rank factors.
    All four are trainable parameters. Inputs may have any number of leading
    dimensions; only the last one (``input_dim``) is transformed.
    """

    def __init__(self, input_dim, output_dim, rank, alpha=1):
        super(LORALayer, self).__init__()
        self.rank = rank
        self.alpha = alpha

        # Original weight and bias of the linear layer
        self.weight = nn.Parameter(torch.Tensor(output_dim, input_dim))
        self.bias = nn.Parameter(torch.Tensor(output_dim))

        # LORA specific parameters
        self.A = nn.Parameter(torch.Tensor(input_dim, rank))
        self.B = nn.Parameter(torch.Tensor(rank, output_dim))

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.zeros_(self.bias)
        nn.init.normal_(self.A, 0, 0.02)
        nn.init.normal_(self.B, 0, 0.02)

    def forward(self, x):
        """Map ``(..., input_dim)`` to ``(..., output_dim)``."""
        leading_shape = x.shape[:-1]
        output_dim = self.weight.shape[0]
        # Flatten all leading dims so both paths are plain 2-D matmuls.
        x_flattened = x.reshape(-1, x.shape[-1])

        lora_adjustment = self.alpha * (x_flattened @ self.A) @ self.B
        x_transformed = nn.functional.linear(x_flattened, self.weight, self.bias)

        # Explicit output_dim (not -1) keeps the reshape valid for empty batches.
        return (x_transformed + lora_adjustment).reshape(*leading_shape, output_dim)

class QLORALayer(nn.Module):
    """Linear layer plus a LoRA update whose low-rank factors are fake-quantized.

    output = LayerNorm(x @ W^T + b + dropout(alpha * (x @ Q(A)) @ Q(B)))

    ``Q`` quantizes a tensor to ``quantization_bits`` signed bits (symmetric, one scale
    per tensor) and immediately dequantizes it. Gradients pass straight through the
    rounding (straight-through estimator), so ``A`` and ``B`` stay trainable. Inputs
    may have any number of leading dimensions.
    """

    def __init__(self, input_dim, output_dim, rank, alpha=1, quantization_bits=8):
        super(QLORALayer, self).__init__()
        self.rank = rank
        self.alpha = alpha
        self.quantization_bits = quantization_bits

        # Original weight and bias
        self.weight = nn.Parameter(torch.Tensor(output_dim, input_dim))
        self.bias = nn.Parameter(torch.Tensor(output_dim))

        # QLORA specific parameters
        self.A = nn.Parameter(torch.Tensor(input_dim, rank))
        self.B = nn.Parameter(torch.Tensor(rank, output_dim))

        self.dropout = nn.Dropout(0.1)
        self.layer_norm = nn.LayerNorm(output_dim)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.zeros_(self.bias)
        nn.init.normal_(self.A, 0, 0.02)
        nn.init.normal_(self.B, 0, 0.02)

    def quantize(self, x, num_bits):
        """Symmetric per-tensor quantization of *x* to signed *num_bits* integers.

        Returns ``(x_quantized, scale)``. ``x_quantized`` holds integer values in
        ``[-(2**(num_bits-1) - 1), 2**(num_bits-1) - 1]`` and ``scale`` is ``max|x|``.
        ``scale`` is clamped away from zero so an all-zero tensor does not produce NaN.

        Raises:
            ValueError: if ``num_bits < 2``.
        """
        if num_bits < 2:
            raise ValueError(f"num_bits must be >= 2, got {num_bits}")
        levels = 2 ** (num_bits - 1) - 1
        scale = x.abs().max().clamp(min=torch.finfo(x.dtype).eps)
        x_quantized = torch.round(x / scale * levels)
        return x_quantized, scale

    def dequantize(self, x_quantized, scale, num_bits):
        """Inverse of ``quantize``: map integer levels back to the original range."""
        levels = 2 ** (num_bits - 1) - 1
        return x_quantized * scale / levels

    def _fake_quantize(self, t):
        """Quantize then dequantize *t*, letting gradients flow through unchanged."""
        x_quantized, scale = self.quantize(t, self.quantization_bits)
        dequantized = self.dequantize(x_quantized, scale, self.quantization_bits)
        # Straight-through estimator: forward uses the dequantized value, backward sees identity.
        return t + (dequantized - t).detach()

    def forward(self, x):
        """Map ``(..., input_dim)`` to ``(..., output_dim)``."""
        leading_shape = x.shape[:-1]
        output_dim = self.weight.shape[0]
        x_flattened = x.reshape(-1, x.shape[-1])

        A_effective = self._fake_quantize(self.A)
        B_effective = self._fake_quantize(self.B)

        lora_adjustment = self.alpha * (x_flattened @ A_effective) @ B_effective
        lora_adjustment = lora_adjustment.reshape(*leading_shape, output_dim)
        lora_adjustment = self.dropout(lora_adjustment)

        x_transformed = nn.functional.linear(x_flattened, self.weight, self.bias)
        x_transformed = x_transformed.reshape(*leading_shape, output_dim)

        return self.layer_norm(x_transformed + lora_adjustment)

    def update_alpha(self, new_alpha):
        """
        Update the alpha scaling factor.
        """
        self.alpha = new_alpha

class MultiHeadAttention(nn.Module):
    """Multi-head dot-product attention with projections shared across heads.

    ``values`` / ``keys`` / ``queries`` are ``Linear(head_dim, head_dim)`` layers applied to
    every head's slice, so all heads share their weights. ``fc_out`` then mixes the
    concatenated heads back to ``embed_size``. ``mask`` (if given) is broadcast against
    ``(batch, heads, query_len, key_len)``. Positions where it equals 0 are excluded.

    NOTE(review): scores are divided by ``sqrt(embed_size)``; standard attention uses
    ``sqrt(head_dim)``. Kept because trained checkpoints depend on it.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        batch = query.shape[0]
        v_len, k_len, q_len = values.shape[1], keys.shape[1], query.shape[1]

        def split_heads(tensor, length):
            # (batch, length, embed) -> (batch, length, heads, head_dim)
            return tensor.reshape(batch, length, self.heads, self.head_dim)

        v = self.values(split_heads(values, v_len))
        k = self.keys(split_heads(keys, k_len))
        q = self.queries(split_heads(query, q_len))

        # Raw scores for every (query, key) pair, per head: (batch, heads, q_len, k_len).
        energy = torch.einsum("nqhd,nkhd->nhqk", q, k)
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        weights = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)

        # Weighted sum of values, then merge heads: (batch, q_len, heads * head_dim).
        context = torch.einsum("nhql,nlhd->nqhd", weights, v)
        merged = context.reshape(batch, q_len, self.heads * self.head_dim)
        return self.fc_out(merged)

class TransformerBlock(nn.Module):
    """Post-norm transformer block with a LoRA feed-forward.

    Attention output is added to ``query``, layer-normalised and passed through dropout.
    The result then goes through a two-layer LoRA MLP (hidden width
    ``forward_expansion * embed_size``) with its own residual, norm and dropout.
    """

    def __init__(self, embed_size, heads, dropout, forward_expansion, rank):
        super().__init__()
        hidden_size = forward_expansion * embed_size
        self.attention = MultiHeadAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            LORALayer(embed_size, hidden_size, rank),
            nn.ReLU(),
            LORALayer(hidden_size, embed_size, rank),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        attended = self.attention(value, key, query, mask)
        hidden = self.dropout(self.norm1(attended + query))
        return self.dropout(self.norm2(self.feed_forward(hidden) + hidden))

class LanguageModelDecoder(nn.Module):
    """Stack of ``TransformerBlock``s with learned token and position embeddings.

    BatchNorm1d is applied over the feature dimension before and after the stack. When
    ``use_qlora`` is True, the single shared ``qlora_feed_forward`` module is applied after
    every block. ``max_length`` bounds the sequence length (position-embedding size).
    """

    def __init__(self, vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, max_length, rank):
        super(LanguageModelDecoder, self).__init__()
        self.word_embedding = nn.Embedding(vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        # Adding BatchNorm layers
        self.bn1 = nn.BatchNorm1d(embed_size)
        self.bn2 = nn.BatchNorm1d(embed_size)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(embed_size, heads, dropout, forward_expansion, rank)
                for _ in range(num_layers)
            ]
        )

        # QLORA layers
        self.qlora_feed_forward = nn.Sequential(
            QLORALayer(embed_size, forward_expansion * embed_size, rank),
            nn.ReLU(),
            QLORALayer(forward_expansion * embed_size, embed_size, rank),
        )
        self.use_qlora = False  # Flag to toggle QLORA

        self.fc_out = nn.Linear(embed_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def _batch_norm_features(bn, x):
        """Apply BatchNorm1d *bn* over the feature dim of a ``(batch, seq, features)`` tensor."""
        # BatchNorm1d normalises dim 1, so move features there and back.
        return bn(x.transpose(1, 2)).transpose(1, 2)

    def forward(self, x, trg_mask):
        """Return logits ``(batch, seq, vocab_size)`` for token ids *x* of shape ``(batch, seq)``.

        Raises:
            ValueError: if ``seq`` exceeds the position-embedding size (``max_length``).
        """
        N, seq_length = x.shape
        max_positions = self.position_embedding.num_embeddings
        if seq_length > max_positions:
            raise ValueError(
                f"sequence length {seq_length} exceeds max_length={max_positions} of the position embedding"
            )

        # Build positions on the token ids' device; CPU positions would fail on CUDA inputs.
        positions = torch.arange(0, seq_length, device=x.device).expand(N, seq_length)
        x = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        x = self._batch_norm_features(self.bn1, x)

        for layer in self.layers:
            x = layer(x, x, x, trg_mask)
            if self.use_qlora:
                x = self.qlora_feed_forward(x)

        x = self._batch_norm_features(self.bn2, x)

        return self.fc_out(x)

    def toggle_qlora(self, use_qlora: bool):
        self.use_qlora = use_qlora

class LanguageModelTransformer(nn.Module):
    """Decoder-only language model: ``LanguageModelDecoder`` driven with a causal mask."""

    def __init__(self, vocab_size, embed_size=256, num_layers=6, forward_expansion=4, heads=8, dropout=0, max_length=100, rank=16):
        super(LanguageModelTransformer, self).__init__()

        self.decoder = LanguageModelDecoder(
            vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            max_length,
            rank,
        )

    def forward(self, trg):
        """Return logits ``(batch, seq, vocab_size)`` for token ids *trg* of shape ``(batch, seq)``."""
        trg_mask = self.make_trg_mask(trg)
        return self.decoder(trg, trg_mask)

    def make_trg_mask(self, trg):
        """Lower-triangular causal mask of shape ``(batch, 1, seq, seq)`` on *trg*'s device."""
        N, trg_len = trg.shape
        causal = torch.tril(torch.ones((trg_len, trg_len), device=trg.device))
        return causal.expand(N, 1, trg_len, trg_len)

    # Function to enable or disable QLORA layers (for fine-tuning purposes)
    def toggle_qlora(self, use_qlora: bool):
        self.decoder.toggle_qlora(use_qlora)

    def generate_response(self, input_ids, attention_mask=None, tokenizer=None):
        """Greedy single-pass decode of the first sequence in *input_ids*.

        This is not autoregressive generation. It returns the model's argmax token at every
        input position, joined into a string.

        Args:
            input_ids: token ids, shape ``(batch, seq)`` or ``(seq,)``. Only the first
                sequence is decoded.
            attention_mask: optional 0/1 mask matching ``input_ids``. Positions where it is 0
                are dropped from the decoded output. The causal model does not consume it.
            tokenizer: object providing ``convert_ids_to_tokens`` and
                ``convert_tokens_to_string``. Defaults to a module-level ``tokenizer``.

        Raises:
            ValueError: if no tokenizer is passed and no module-level one is defined.
        """
        if tokenizer is None:
            tokenizer = globals().get('tokenizer')
            if tokenizer is None:
                raise ValueError(
                    "generate_response needs a tokenizer: pass tokenizer=... "
                    "or define a module-level `tokenizer`"
                )

        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)

        with torch.no_grad():
            # forward() accepts only the token ids (`trg`).
            logits = self.forward(input_ids)

        # Greedy choice per position; argmax of the logits equals argmax of the softmax.
        predicted_token_ids = torch.argmax(logits, dim=-1)[0]
        if attention_mask is not None:
            keep = attention_mask.reshape(input_ids.shape)[0].to(predicted_token_ids.device).bool()
            predicted_token_ids = predicted_token_ids[keep]

        predicted_tokens = tokenizer.convert_ids_to_tokens(predicted_token_ids.tolist())
        return tokenizer.convert_tokens_to_string(predicted_tokens)

# Flash2_Attention

class FlashAttention2(nn.Module):
    """Exact attention ``softmax(Q @ K^T) @ V`` computed block by block with an online softmax.

    Expects 2-D inputs: ``Q`` and ``K`` of shape ``(seq, head_dim)`` and ``V`` of shape
    ``(seq, v_dim)``. ``Q`` is split into row blocks of ``block_size``. For each Q block, the
    K/V blocks are visited in turn while a running row maximum, normaliser and output
    are rescaled. The result therefore equals full softmax attention without forming
    the whole score matrix at once.

    NOTE: scores are not scaled by ``1 / sqrt(head_dim)``; pre-scale ``Q`` if needed.
    """

    def __init__(self, sequence_length, head_dimension, block_size):
        super(FlashAttention2, self).__init__()
        self.block_size = block_size
        # Not required by the online softmax; retained so invalid configurations fail as before.
        assert sequence_length % block_size == 0

    def forward(self, Q, K, V):
        Q_blocks, K_blocks, V_blocks = self.partition_inputs(Q, K, V)
        outputs = [self.process_block(Q_block, K_blocks, V_blocks) for Q_block in Q_blocks]
        return torch.cat(outputs, dim=0)

    def partition_inputs(self, Q, K, V):
        """Split Q, K and V along dim 0 into blocks of ``block_size`` rows (the last may be shorter)."""
        return (
            Q.split(self.block_size, dim=0),
            K.split(self.block_size, dim=0),
            V.split(self.block_size, dim=0),
        )

    def process_block(self, Q_block, K_blocks, V_blocks):
        """Attention output for one Q block over all K/V blocks, via a streaming softmax."""
        row_max = row_sum = acc = None
        for K_block, V_block in zip(K_blocks, V_blocks):
            scores = torch.matmul(Q_block, K_block.transpose(-2, -1))
            block_max = scores.amax(dim=-1, keepdim=True)
            new_max = block_max if row_max is None else torch.maximum(row_max, block_max)

            # Exponentiate relative to the running max for numerical stability.
            weights = torch.exp(scores - new_max)
            block_sum = weights.sum(dim=-1, keepdim=True)
            block_out = torch.matmul(weights, V_block)

            if row_max is None:
                row_sum, acc = block_sum, block_out
            else:
                # Rescale what was accumulated under the previous max to the new max.
                correction = torch.exp(row_max - new_max)
                row_sum = row_sum * correction + block_sum
                acc = acc * correction + block_out
            row_max = new_max

        if acc is None:
            raise ValueError("K and V must contain at least one row")
        return acc / row_sum

    def online_softmax(self, scores, chunk_size=128):
        """Row-wise softmax of a 2-D score matrix, evaluated ``chunk_size`` rows at a time.

        Not used by ``process_block``, which normalises incrementally; kept for callers.
        """
        softmaxed_scores = []
        for i in range(0, scores.size(0), chunk_size):
            chunk = scores[i:i + chunk_size, :]
            softmaxed_chunk = F.softmax(chunk, dim=1)
            softmaxed_scores.append(softmaxed_chunk)
        return torch.cat(softmaxed_scores, dim=0)

# SparseFlash2_Attention

class SparseFlash2Attention(nn.Module):
    """Run ``FlashAttention2`` and then mix the output rows with a boolean "sparsity" mask.

    NOTE(review): ``generate_sparsity_mask`` sets every row block to 1, so the mask is all
    True and no sparsity is applied. ``bmm(mask, output)`` then replaces every output row
    with the sum of all rows. Sparse attention normally masks the attention *scores*
    instead. This is left unchanged pending a decision on the intended pattern.

    Raises:
        ValueError: if ``sparsity_factor < 1``.
    """

    def __init__(self, seq_len, head_dim, blk_size, sparsity_factor):
        super().__init__()
        self.flash_attention = FlashAttention2(seq_len, head_dim, blk_size)
        self.seq_len = seq_len
        self.head_dim = head_dim
        self.block_size = blk_size  # Storing block_size as an instance variable
        if sparsity_factor < 1:
            # A zero step would crash the mask loop; a negative one would yield an all-zero mask.
            raise ValueError(f"sparsity_factor must be >= 1, got {sparsity_factor}")
        self.sparsity_factor = sparsity_factor

    def generate_sparsity_mask(self, device=None):
        """Return a ``(seq_len, seq_len)`` bool mask, created on *device* (default: CPU)."""
        mask = torch.zeros(self.seq_len, self.seq_len, device=device)
        step = self.sparsity_factor
        for i in range(0, self.seq_len, step):
            mask[i:i + step, :] = 1
        return mask.bool()

    def forward(self, Q, K, V):
        output = self.flash_attention(Q, K, V)  # [sequence_length, head_dimension]

        # Reshape output to be 3D for batch matrix multiplication
        output = output.unsqueeze(0)  # [1, sequence_length, head_dimension]

        # Build the mask on the output's device so CUDA inputs do not hit a device mismatch.
        sparsity_mask = self.generate_sparsity_mask(device=output.device).unsqueeze(0)  # [1, seq, seq]
        output = torch.bmm(sparsity_mask.float(), output.float())

        # Reshape the output back to 2D
        return output.squeeze(0)  # [sequence_length, head_dimension]



# Auxiliary loss function
def auxiliary_loss(gate_scores, expert_capacity):
    """Load-balancing penalty: standard deviation of the mean gate probability per expert.

    Args:
        gate_scores: gate probabilities, presumably ``(num_tokens, num_experts)``; dim 0 is
            averaged over. TODO confirm against callers.
        expert_capacity: unused; kept for interface compatibility.

    Returns:
        A 0-d tensor. It is zero when there are no tokens or only one expert, cases where
        the mean or the (unbiased) standard deviation would otherwise be NaN.
    """
    if gate_scores.size(0) == 0:
        return gate_scores.new_zeros(())
    expert_load = gate_scores.sum(0) / gate_scores.size(0)
    if expert_load.numel() < 2:
        return gate_scores.new_zeros(())
    return torch.std(expert_load)

# Routing function
CAPACITY_FACTOR = 1


def route_inputs(expert_indices, gate_scores, num_experts):
    """Reassign tokens so that no expert receives more than its capacity.

    Each expert's capacity is ``int(num_tokens * CAPACITY_FACTOR / num_experts)``, where
    ``num_tokens = gate_scores.size(0)``. Tokens are visited in order. If a token's chosen
    expert is already full, the token moves to the lowest-numbered expert that still has
    room. If every expert is full, the token keeps its original expert and a message is
    printed.

    ``expert_indices`` is modified in place and also returned.
    """
    factor = torch.tensor([CAPACITY_FACTOR], dtype=torch.float32)
    capacity = (gate_scores.size(0) * factor / num_experts).int()[0]
    counts = torch.zeros(num_experts, dtype=torch.int32)

    # Index-based loop because entries are overwritten in place.
    for position in range(len(expert_indices)):
        chosen = expert_indices[position]
        if counts[chosen] < capacity:
            counts[chosen] += 1
            continue

        open_experts = (counts < capacity).nonzero(as_tuple=False).view(-1)
        if len(open_experts) == 0:
            print("No available experts to reroute. Handling overflow.")
            continue

        fallback = open_experts[0]
        expert_indices[position] = fallback
        counts[fallback] += 1

    return expert_indices

# SwitchGate 
class SwitchGate(nn.Module):
    """Two-layer MLP mapping each input vector to a softmax over experts.

    Args:
        input_dim: size of the input's last dimension; the hidden layer has
            ``input_dim // 2`` units.
        num_experts: number of probabilities produced per input.
    """

    def __init__(self, input_dim, num_experts):
        super().__init__()
        hidden_dim = input_dim // 2
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_experts)

    def forward(self, x):
        """Return gate probabilities (last dim sums to 1) for ``x`` cast to float32."""
        hidden = F.relu(self.fc1(x.float()))
        logits = self.fc2(hidden)
        return F.softmax(logits, dim=-1)

# SwitchRouter 
class SwitchRouter(nn.Module):
    """Top-1 switch router over three pretrained sub-models (first version).

    Experts, in index order: ``TransformerRAG``, the DPO
    ``LanguageModelTransformer`` and ``SimplifiedLanguageModelMAMBA``. Inputs
    are projected from 512 features to ``input_dim``, scored by
    ``SwitchGate``, assigned to their arg-max expert, rebalanced by
    ``route_inputs`` and dispatched expert by expert.

    NOTE(review): the MAMBA expert is built from module-level constants
    (VOCAB_SIZE, NUM_LAYERS, D_MODEL, D_STATE, D_CONV, EXPANSION_FACTOR)
    rather than the constructor arguments — confirm they agree.
    """

    def __init__(self, input_dim, num_experts, mamba_model_path, context_encoder_path, language_model_path, question_encoder_path, dpo_model_path, vocab_size, embed_size, num_layers, forward_expansion, heads, dropout, max_length, rank):
        super().__init__()
        # The gate and the input projection are placed on ``device`` together
        # with the experts so that device-resident inputs can be routed.
        self.router = SwitchGate(input_dim, num_experts).to(device)
        self.transformer_rag = TransformerRAG(context_encoder_path, 
                                              language_model_path, 
                                              question_encoder_path, 
                                              vocab_size).to(device)
        self.transformer_dpo = LanguageModelTransformer(vocab_size, 
                                                        embed_size, 
                                                        num_layers, 
                                                        forward_expansion, 
                                                        heads, dropout, 
                                                        max_length, 
                                                        rank).to(device)
        self.transformer_dpo = load_model_weights(self.transformer_dpo, dpo_model_path)
        self.mamba = SimplifiedLanguageModelMAMBA(vocab_size=VOCAB_SIZE, 
                                     num_layers=NUM_LAYERS, 
                                     d_model=D_MODEL, 
                                     d_state=D_STATE, 
                                     d_conv=D_CONV, 
                                     expansion_factor=EXPANSION_FACTOR).to(device)
        self.mamba = load_model_weights(self.mamba, mamba_model_path)
        self.experts = nn.ModuleList([self.transformer_rag, self.transformer_dpo, self.mamba])
        self.input_embedding = nn.Linear(512, input_dim).to(device)

    def forward(self, x, attention_mask, context_texts, question_text):
        """Route each input row to one expert and merge the results.

        Args:
            x: inputs whose last dimension is 512 (projected by ``input_embedding``).
            attention_mask: indexed with the same boolean row mask as ``x``.
            context_texts, question_text: forwarded to ``TransformerRAG`` only.

        Returns:
            ``(final_output, aux_loss)``: the merged expert outputs (shaped like
            the projected ``x``) and the load-balancing loss from
            ``auxiliary_loss``.
        """
        # Follow the projection's device so the model still works after .to().
        target_device = self.input_embedding.weight.device
        x = self.input_embedding(x.float().to(target_device))
        # The per-expert masks below are built on ``target_device``; the
        # attention mask has to live there before it is indexed with them.
        attention_mask = attention_mask.to(target_device)
        gate_scores = self.router(x)
        expert_indices = torch.argmax(gate_scores, dim=1)
        expert_indices = route_inputs(expert_indices, gate_scores, len(self.experts))
        final_output = torch.zeros_like(x)

        for i, expert in enumerate(self.experts):
            mask = expert_indices == i
            if not mask.any():
                continue
            selected_inputs = x[mask]
            selected_attention_mask = attention_mask[mask]

            if isinstance(expert, TransformerRAG):
                # TransformerRAG takes the retrieval context and question as well.
                expert_output = expert(context_texts, selected_inputs, selected_attention_mask, question_text)
            else:
                expert_output = expert(selected_inputs, selected_attention_mask)

            # NOTE(review): assumes every expert returns a tensor shaped like
            # ``selected_inputs`` — confirm (the TransformerRAG defined in the
            # Expert cell returns a decoded string).
            final_output[mask] = expert_output

        # Compute auxiliary loss for load balancing
        aux_loss = auxiliary_loss(gate_scores, expert_capacity=torch.tensor([CAPACITY_FACTOR] * len(self.experts)))

        return final_output, aux_loss

class Expert(nn.Module):
    """Single expert block.

    Pipeline: sparse flash attention, then LayerNorm + Dropout, then internal
    switch routing over the sub-models, then a QLORA projection. The router's
    auxiliary load-balancing loss is discarded by ``forward``; only the routed
    activations are returned.
    """

    def __init__(self, sparsity_factor, seq_len, head_dim, blk_size, input_dim, num_experts, vocab_size, embed_size, num_layers, forward_expansion, heads, dropout, max_length, rank):
        super().__init__()

        # 1. SparseFlash2_attention
        self.sparse_flash2_attention = SparseFlash2Attention(seq_len, head_dim, blk_size, sparsity_factor)

        # 2. LayerNorm and Dropout (the rate is fixed at 0.1; the ``dropout``
        #    argument is only forwarded to SwitchRouter)
        self.layer_norm = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(0.1)

        # 3. Internal Switch Routing; the sub-models are owned by SwitchRouter.
        # NOTE(review): mamba_model_path, context_encoder_path, language_model_path,
        # question_encoder_path and dpo_model_path are read as module-level globals;
        # the top cell defines mamba_model_path but names the others context_encoder,
        # language_model, question_encoder and tran_dpo — confirm the *_path names
        # exist before instantiating.
        self.switch_router = SwitchRouter(input_dim, num_experts, mamba_model_path, context_encoder_path, language_model_path, question_encoder_path, dpo_model_path, vocab_size, embed_size, num_layers, forward_expansion, heads, dropout, max_length, rank)

        # 4. QLORA Layer
        self.qlora = QLORALayer(input_dim, input_dim, rank)  # Adjust dimensions as needed

    def forward(self, x, attention_mask, context_texts, question_text):
        """Apply attention, normalisation, routing and the QLORA projection to ``x``."""
        # 1. SparseFlash2_attention: its forward takes (Q, K, V); x is used for
        #    all three (self-attention). A single-argument call raised TypeError.
        x = self.sparse_flash2_attention(x, x, x)

        # 2. LayerNorm and Dropout
        x = self.dropout(self.layer_norm(x))

        # 3. Internal Switch Routing (auxiliary loss is ignored here)
        x, _ = self.switch_router(x, attention_mask, context_texts, question_text)

        # 4. QLORA Layer
        x = self.qlora(x)

        return x


# Expert v2



In [None]:
class SwitchRouter(nn.Module):
    """Top-1 switch router over three pretrained sub-models (v2).

    All construction parameters, including the MAMBA hyper-parameters and the
    target ``device``, are explicit. Experts, in index order:
    ``TransformerRAG``, the DPO ``LanguageModelTransformer`` and
    ``SimplifiedLanguageModelMAMBA``.
    """

    CAPACITY_FACTOR = 1  # Class constant: scales the per-expert capacity in route_inputs

    class SwitchGate(nn.Module):
        """Two-layer MLP producing a softmax over experts for each input."""

        def __init__(self, input_dim, num_experts):
            # Zero-argument super() is bound to this class itself instead of the
            # global name ``SwitchRouter``, which later notebook cells rebind.
            super().__init__()
            self.fc1 = nn.Linear(input_dim, input_dim // 2)
            self.fc2 = nn.Linear(input_dim // 2, num_experts)

        def forward(self, x):
            x = F.relu(self.fc1(x))
            gate_scores = F.softmax(self.fc2(x), dim=-1)
            return gate_scores

    def __init__(self, input_dim, num_experts, mamba_model_path, context_encoder_path, language_model_path, question_encoder_path, dpo_model_path, vocab_size, embed_size, num_layers, forward_expansion, heads, dropout, max_length, rank, d_model, d_state, d_conv, expansion_factor, device):
        super().__init__()
        self.device = device
        self.router = self.SwitchGate(input_dim, num_experts).to(self.device)
        self.transformer_rag = TransformerRAG(context_encoder_path, language_model_path, question_encoder_path, vocab_size).to(self.device)
        self.transformer_dpo = LanguageModelTransformer(vocab_size, embed_size, num_layers, forward_expansion, heads, dropout, max_length, rank).to(self.device)
        self.transformer_dpo = load_model_weights(self.transformer_dpo, dpo_model_path)
        self.mamba = SimplifiedLanguageModelMAMBA(vocab_size, num_layers, d_model, d_state, d_conv, expansion_factor).to(self.device)
        self.mamba = load_model_weights(self.mamba, mamba_model_path)
        self.experts = nn.ModuleList([self.transformer_rag, self.transformer_dpo, self.mamba])
        # On ``device`` like every other submodule, so forward() can project
        # inputs that are already on that device.
        self.input_embedding = nn.Linear(512, input_dim).to(self.device)

    def forward(self, x, attention_mask, context_texts, question_text):
        """Route each input row to one expert and merge the results.

        Returns:
            ``(final_output, aux_loss)``: merged expert outputs shaped like the
            projected ``x``, and the load-balancing loss.
        """
        # Move inputs to the router's device before projecting or indexing them.
        x = self.input_embedding(x.to(self.device).float())
        attention_mask = attention_mask.to(self.device)
        gate_scores = self.router(x)
        expert_indices = torch.argmax(gate_scores, dim=1)
        expert_indices = self.route_inputs(expert_indices, gate_scores, len(self.experts))
        final_output = torch.zeros_like(x)

        for i, expert in enumerate(self.experts):
            mask = expert_indices == i
            if not mask.any():
                continue
            selected_inputs = x[mask]
            selected_attention_mask = attention_mask[mask]

            # Handle expert processing
            if isinstance(expert, TransformerRAG):
                expert_output = expert(context_texts, selected_inputs, selected_attention_mask, question_text)
            else:
                expert_output = expert(selected_inputs, selected_attention_mask)

            # NOTE(review): assumes each expert returns a tensor shaped like
            # ``selected_inputs`` — confirm.
            final_output[mask] = expert_output

        # Compute auxiliary loss for load balancing
        aux_loss = self.auxiliary_loss(gate_scores)

        return final_output, aux_loss

    @staticmethod
    def auxiliary_loss(gate_scores):
        """Std-dev of the mean gate probability per expert (0 when perfectly balanced)."""
        expert_load = gate_scores.sum(0) / gate_scores.size(0)
        loss_balancing = torch.std(expert_load)
        return loss_balancing

    @classmethod
    def route_inputs(cls, expert_indices, gate_scores, num_experts):
        """Rebalance top-1 assignments so no expert exceeds its capacity.

        Capacity is ``ceil(gate_scores.size(0) * CAPACITY_FACTOR / num_experts)``,
        so for ``CAPACITY_FACTOR >= 1`` all tokens fit. Tokens whose preferred
        expert is full move to the lowest-indexed expert with room; if none has
        room, a message is printed and the token keeps its expert. Returns a
        new tensor; ``expert_indices`` is not modified.
        """
        capacity = math.ceil(gate_scores.size(0) * cls.CAPACITY_FACTOR / num_experts)
        routed = expert_indices.clone()
        # Python ints avoid indexing a CPU tensor with CUDA index tensors.
        expert_counts = [0] * num_experts

        for position, chosen in enumerate(routed.tolist()):
            if expert_counts[chosen] < capacity:
                expert_counts[chosen] += 1
                continue
            fallback = next(
                (expert for expert, used in enumerate(expert_counts) if used < capacity),
                None,
            )
            if fallback is None:
                print("No available experts to reroute. Handling overflow.")
                continue
            routed[position] = fallback
            expert_counts[fallback] += 1

        return routed


# v3

In [None]:
class SwitchRouter(nn.Module):
    """Top-1 switch router over three pretrained sub-models (v3).

    Same structure as v2, with inputs moved to ``device`` at the start of
    ``forward``. Experts, in index order: ``TransformerRAG``, the DPO
    ``LanguageModelTransformer`` and ``SimplifiedLanguageModelMAMBA``.
    """

    CAPACITY_FACTOR = 1  # Class constant: scales the per-expert capacity in route_inputs

    class SwitchGate(nn.Module):
        """Two-layer MLP producing a softmax over experts for each input."""

        def __init__(self, input_dim, num_experts):
            # Zero-argument super() is bound to this class itself instead of the
            # global name ``SwitchRouter``, which other notebook cells rebind.
            super().__init__()
            self.fc1 = nn.Linear(input_dim, input_dim // 2)
            self.fc2 = nn.Linear(input_dim // 2, num_experts)

        def forward(self, x):
            x = F.relu(self.fc1(x))
            gate_scores = F.softmax(self.fc2(x), dim=-1)
            return gate_scores

    def __init__(self, input_dim, num_experts, mamba_model_path, context_encoder_path, language_model_path, question_encoder_path, dpo_model_path, vocab_size, embed_size, num_layers, forward_expansion, heads, dropout, max_length, rank, d_model, d_state, d_conv, expansion_factor, device):
        super().__init__()
        self.device = device
        self.router = self.SwitchGate(input_dim, num_experts).to(device)
        self.transformer_rag = TransformerRAG(context_encoder_path, language_model_path, question_encoder_path, vocab_size).to(device)
        self.transformer_dpo = LanguageModelTransformer(vocab_size, embed_size, num_layers, forward_expansion, heads, dropout, max_length, rank).to(device)
        self.transformer_dpo = load_model_weights(self.transformer_dpo, dpo_model_path)
        self.mamba = SimplifiedLanguageModelMAMBA(vocab_size, num_layers, d_model, d_state, d_conv, expansion_factor).to(device)
        self.mamba = load_model_weights(self.mamba, mamba_model_path)
        self.experts = nn.ModuleList([self.transformer_rag, self.transformer_dpo, self.mamba])
        # forward() moves x to ``device`` before this projection, so the
        # projection must live there too.
        self.input_embedding = nn.Linear(512, input_dim).to(device)

    def forward(self, x, attention_mask, context_texts, question_text):
        """Route each input row to one expert and merge the results.

        Returns:
            ``(final_output, aux_loss)``: merged expert outputs shaped like the
            projected ``x``, and the load-balancing loss.
        """
        x = x.to(self.device)
        if x.dtype != torch.float32:
            x = x.float()

        x = self.input_embedding(x)
        # Move before indexing: the expert masks below live on self.device.
        attention_mask = attention_mask.to(self.device)
        gate_scores = self.router(x)
        expert_indices = torch.argmax(gate_scores, dim=1)
        expert_indices = self.route_inputs(expert_indices, gate_scores, len(self.experts))
        final_output = torch.zeros_like(x)

        for i, expert in enumerate(self.experts):
            mask = expert_indices == i
            if not mask.any():
                continue
            selected_inputs = x[mask]
            selected_attention_mask = attention_mask[mask]
            if isinstance(expert, TransformerRAG):
                expert_output = expert(context_texts, selected_inputs, selected_attention_mask, question_text)
            else:
                expert_output = expert(selected_inputs, selected_attention_mask)
            # NOTE(review): assumes each expert returns a tensor shaped like
            # ``selected_inputs`` — confirm.
            final_output[mask] = expert_output

        aux_loss = self.auxiliary_loss(gate_scores)
        return final_output, aux_loss

    @staticmethod
    def auxiliary_loss(gate_scores):
        """Std-dev of the mean gate probability per expert (0 when perfectly balanced)."""
        expert_load = gate_scores.sum(0) / gate_scores.size(0)
        loss_balancing = torch.std(expert_load)
        return loss_balancing

    @classmethod
    def route_inputs(cls, expert_indices, gate_scores, num_experts):
        """Rebalance top-1 assignments so no expert exceeds its capacity.

        Capacity is ``ceil(gate_scores.size(0) * CAPACITY_FACTOR / num_experts)``,
        so for ``CAPACITY_FACTOR >= 1`` all tokens fit. Tokens whose preferred
        expert is full move to the lowest-indexed expert with room; if none has
        room, a message is printed and the token keeps its expert. Returns a
        new tensor; ``expert_indices`` is not modified.
        """
        capacity = math.ceil(gate_scores.size(0) * cls.CAPACITY_FACTOR / num_experts)
        routed = expert_indices.clone()
        # Python ints avoid indexing a CPU tensor with CUDA index tensors.
        expert_counts = [0] * num_experts
        for position, chosen in enumerate(routed.tolist()):
            if expert_counts[chosen] < capacity:
                expert_counts[chosen] += 1
                continue
            fallback = next(
                (expert for expert, used in enumerate(expert_counts) if used < capacity),
                None,
            )
            if fallback is None:
                print("No available experts to reroute. Handling overflow.")
                continue
            routed[position] = fallback
            expert_counts[fallback] += 1
        return routed


# v4

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Assuming all other necessary imports are done, like models for TransformerRAG, etc.

class Expert(nn.Module):
    # Flash2_Attention
    class FlashAttention2(nn.Module):
        """Block-wise exact softmax attention for 2-D ``Q``/``K``/``V`` inputs.

        Queries are processed one block at a time. For each query block, the
        key/value blocks are folded in with a running row maximum and running
        normaliser (the online-softmax recurrence used by FlashAttention), so
        the result equals ``softmax(Q @ K^T, dim=-1) @ V`` without building the
        full score matrix. No ``1/sqrt(d)`` scaling is applied, as before.

        Args:
            sequence_length: must be divisible by ``block_size``.
            head_dimension: not used by the computation.
            block_size: number of rows per block along dimension 0.
        """

        def __init__(self, sequence_length, head_dimension, block_size):
            # Zero-argument super(): inside a method the bare name
            # ``FlashAttention2`` resolves to a module-level global, not to this
            # nested class, so super(FlashAttention2, self) fails.
            super().__init__()
            self.block_size = block_size
            # Ensure that sequence_length is divisible by block_size for simplicity
            assert sequence_length % block_size == 0

        def forward(self, Q, K, V):
            """Return attention output with one row per query row."""
            Q_blocks, K_blocks, V_blocks = self.partition_inputs(Q, K, V)
            outputs = [self.process_block(Q_block, K_blocks, V_blocks) for Q_block in Q_blocks]
            # Concatenating the processed blocks
            return torch.cat(outputs, dim=0)

        def partition_inputs(self, Q, K, V):
            """Split Q, K and V into blocks of about ``block_size`` rows along dim 0."""
            # At least one chunk: chunk(0) raises when a tensor has fewer rows
            # than block_size.
            Q_blocks = Q.chunk(chunks=max(1, Q.size(0) // self.block_size), dim=0)
            K_blocks = K.chunk(chunks=max(1, K.size(0) // self.block_size), dim=0)
            V_blocks = V.chunk(chunks=max(1, V.size(0) // self.block_size), dim=0)
            return Q_blocks, K_blocks, V_blocks

        def process_block(self, Q_block, K_blocks, V_blocks):
            """Exact softmax attention of one query block over all key/value blocks.

            Keeps a running row max ``m``, running normaliser ``l`` and an
            un-normalised accumulator. When a key block raises the max, both are
            rescaled by ``exp(m_old - m_new)`` before the block's contribution is
            added. Normalising each key block separately and summing, as the
            previous version did, made the weights sum to the number of key
            blocks instead of 1.
            """
            running_max = None
            running_sum = None
            accumulator = None
            for K_block, V_block in zip(K_blocks, V_blocks):
                scores = torch.matmul(Q_block, K_block.transpose(-2, -1))
                block_max = scores.amax(dim=-1, keepdim=True)
                new_max = block_max if running_max is None else torch.maximum(running_max, block_max)
                probs = torch.exp(scores - new_max)
                block_sum = probs.sum(dim=-1, keepdim=True)
                block_out = torch.matmul(probs, V_block)
                if accumulator is None:
                    running_sum, accumulator = block_sum, block_out
                else:
                    correction = torch.exp(running_max - new_max)
                    running_sum = running_sum * correction + block_sum
                    accumulator = accumulator * correction + block_out
                running_max = new_max
            return accumulator / running_sum

        def online_softmax(self, scores, chunk_size=128):
            """Row-wise softmax over dim 1, applied ``chunk_size`` rows at a time."""
            softmaxed_scores = []
            for i in range(0, scores.size(0), chunk_size):
                chunk = scores[i:i + chunk_size, :]
                softmaxed_chunk = F.softmax(chunk, dim=1)
                softmaxed_scores.append(softmaxed_chunk)
            return torch.cat(softmaxed_scores, dim=0)

    # SparseFlash2_Attention
    class SparseFlash2Attention(nn.Module):
        """FlashAttention2 followed by left-multiplication with a sparsity mask.

        Args:
            seq_len: side length of the ``(seq_len, seq_len)`` mask.
            head_dim: forwarded to FlashAttention2.
            blk_size: FlashAttention2 block size.
            sparsity_factor: row step used by ``generate_sparsity_mask``.
        """

        def __init__(self, seq_len, head_dim, blk_size, sparsity_factor):
            super().__init__()
            # NOTE(review): inside a method the bare name ``FlashAttention2``
            # resolves to a module-level global, not the sibling nested class
            # ``Expert.FlashAttention2`` — confirm which implementation is intended.
            self.flash_attention = FlashAttention2(seq_len, head_dim, blk_size)
            self.seq_len = seq_len
            self.head_dim = head_dim
            self.block_size = blk_size  # Storing block_size as an instance variable
            self.sparsity_factor = sparsity_factor

        def generate_sparsity_mask(self, device=None):
            """Return a boolean ``(seq_len, seq_len)`` mask, created on ``device`` if given.

            NOTE(review): the loop assigns 1 to rows ``i:i+step`` for every ``i``
            in ``range(0, seq_len, step)``, which covers every row, so the mask is
            all True for any positive ``sparsity_factor``. If a sparse pattern
            (e.g. block-diagonal) was intended, this needs revisiting.
            """
            mask = torch.zeros(self.seq_len, self.seq_len, device=device)
            step = self.sparsity_factor
            for i in range(0, self.seq_len, step):
                mask[i:i + step, :] = 1
            return mask.bool()

        def forward(self, Q, K, V):
            """Return ``sparsity_mask @ flash_attention(Q, K, V)`` as float32."""
            output = self.flash_attention(Q, K, V)  # output shape: [sequence_length, head_dimension]

            # Build the mask next to the attention output; a CPU mask against a
            # CUDA output made the product fail.
            sparsity_mask = self.generate_sparsity_mask(device=output.device)

            # Equivalent to the former unsqueeze(0) / bmm / squeeze(0) sequence.
            return torch.matmul(sparsity_mask.float(), output.float())

    class SwitchRouter(nn.Module):
        """Top-1 switch router over three pretrained sub-models (nested in Expert).

        Experts, in index order: ``TransformerRAG``, the DPO
        ``LanguageModelTransformer`` and ``SimplifiedLanguageModelMAMBA``.
        Because this class is nested, its methods cannot refer to themselves by
        the bare name ``SwitchRouter`` (that resolves a module-level global), so
        zero-argument ``super()`` and ``cls`` are used instead.

        NOTE(review): TransformerRAG, LanguageModelTransformer and
        SimplifiedLanguageModelMAMBA are also resolved as module-level globals,
        not the nested ``Expert.*`` classes — confirm.
        """

        CAPACITY_FACTOR = 1  # Class constant: scales the per-expert capacity in route_inputs

        class SwitchGate(nn.Module):
            """Two-layer MLP producing a softmax over experts for each input."""

            def __init__(self, input_dim, num_experts):
                super().__init__()
                self.fc1 = nn.Linear(input_dim, input_dim // 2)
                self.fc2 = nn.Linear(input_dim // 2, num_experts)

            def forward(self, x):
                x = F.relu(self.fc1(x))
                gate_scores = F.softmax(self.fc2(x), dim=-1)
                return gate_scores

        def __init__(self, input_dim, num_experts, mamba_model_path, context_encoder_path, language_model_path, question_encoder_path, dpo_model_path, vocab_size, embed_size, num_layers, forward_expansion, heads, dropout, max_length, rank, d_model, d_state, d_conv, expansion_factor, device):
            super().__init__()
            self.device = device
            self.router = self.SwitchGate(input_dim, num_experts).to(device)
            self.transformer_rag = TransformerRAG(context_encoder_path, language_model_path, question_encoder_path, vocab_size).to(device)
            self.transformer_dpo = LanguageModelTransformer(vocab_size, embed_size, num_layers, forward_expansion, heads, dropout, max_length, rank).to(device)
            self.transformer_dpo = load_model_weights(self.transformer_dpo, dpo_model_path)
            self.mamba = SimplifiedLanguageModelMAMBA(vocab_size, num_layers, d_model, d_state, d_conv, expansion_factor).to(device)
            self.mamba = load_model_weights(self.mamba, mamba_model_path)
            self.experts = nn.ModuleList([self.transformer_rag, self.transformer_dpo, self.mamba])
            # forward() moves x to ``device`` before this projection, so it
            # must live there too.
            self.input_embedding = nn.Linear(512, input_dim).to(device)

        def forward(self, x, attention_mask, context_texts, question_text):
            """Route each input row to one expert and merge the results.

            Returns:
                ``(final_output, aux_loss)``: merged expert outputs shaped like
                the projected ``x``, and the load-balancing loss.
            """
            x = x.to(self.device)
            if x.dtype != torch.float32:
                x = x.float()

            x = self.input_embedding(x)
            # Move before indexing: the expert masks below live on self.device.
            attention_mask = attention_mask.to(self.device)
            gate_scores = self.router(x)
            expert_indices = torch.argmax(gate_scores, dim=1)
            expert_indices = self.route_inputs(expert_indices, gate_scores, len(self.experts))
            final_output = torch.zeros_like(x)

            for i, expert in enumerate(self.experts):
                mask = expert_indices == i
                if not mask.any():
                    continue
                selected_inputs = x[mask]
                selected_attention_mask = attention_mask[mask]
                if isinstance(expert, TransformerRAG):
                    expert_output = expert(context_texts, selected_inputs, selected_attention_mask, question_text)
                else:
                    expert_output = expert(selected_inputs, selected_attention_mask)
                # NOTE(review): assumes each expert returns a tensor shaped like
                # ``selected_inputs``; Expert.TransformerRAG.forward returns a string.
                final_output[mask] = expert_output

            aux_loss = self.auxiliary_loss(gate_scores)
            return final_output, aux_loss

        @staticmethod
        def auxiliary_loss(gate_scores):
            """Std-dev of the mean gate probability per expert (0 when perfectly balanced)."""
            expert_load = gate_scores.sum(0) / gate_scores.size(0)
            loss_balancing = torch.std(expert_load)
            return loss_balancing

        @classmethod
        def route_inputs(cls, expert_indices, gate_scores, num_experts):
            """Rebalance top-1 assignments so no expert exceeds its capacity.

            Capacity is ``ceil(gate_scores.size(0) * CAPACITY_FACTOR / num_experts)``,
            so for ``CAPACITY_FACTOR >= 1`` all tokens fit. Tokens whose
            preferred expert is full move to the lowest-indexed expert with
            room; if none has room, a message is printed and the token keeps its
            expert. Returns a new tensor; ``expert_indices`` is not modified.
            """
            capacity = math.ceil(gate_scores.size(0) * cls.CAPACITY_FACTOR / num_experts)
            routed = expert_indices.clone()
            # Python ints avoid indexing a CPU tensor with CUDA index tensors.
            expert_counts = [0] * num_experts
            for position, chosen in enumerate(routed.tolist()):
                if expert_counts[chosen] < capacity:
                    expert_counts[chosen] += 1
                    continue
                fallback = next(
                    (expert for expert, used in enumerate(expert_counts) if used < capacity),
                    None,
                )
                if fallback is None:
                    print("No available experts to reroute. Handling overflow.")
                    continue
                routed[position] = fallback
                expert_counts[fallback] += 1
            return routed

    class CustomDPRContextEncoder(nn.Module):
        """BERT encoder whose pooled output is projected to ``embedding_dim``.

        Args:
            model_name: name passed to ``BertModel.from_pretrained``.
            embedding_dim: size of the produced context embeddings.
        """

        def __init__(self, model_name='bert-base-uncased', embedding_dim=768):
            # Zero-argument super(): the bare name CustomDPRContextEncoder inside
            # a method resolves to a module-level global, not this nested class.
            super().__init__()
            # Transformer-based model, e.g., BERT
            self.bert_model = BertModel.from_pretrained(model_name)
            # Additional layer to produce fixed-size embeddings
            self.embedding_layer = nn.Linear(self.bert_model.config.hidden_size, embedding_dim)

        def forward(self, input_ids, attention_mask=None):
            """Return ``embedding_layer(pooler_output)`` for the given token ids."""
            outputs = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)
            # Use the pooled output for creating embeddings
            pooled_output = outputs.pooler_output
            context_embeddings = self.embedding_layer(pooled_output)
            return context_embeddings

    class DPRQuestionEncoder(nn.Module):
        """BERT question encoder: pooled output projected to ``embedding_dim``.

        Args:
            model_name: name passed to ``BertModel.from_pretrained``.
            embedding_dim: size of the produced question embeddings.
        """

        def __init__(self, model_name='bert-base-uncased', embedding_dim=768):
            # Zero-argument super(): the bare name DPRQuestionEncoder inside a
            # method resolves to a module-level global, not this nested class.
            super().__init__()
            self.bert = BertModel.from_pretrained(model_name)
            self.embedding_layer = nn.Linear(self.bert.config.hidden_size, embedding_dim)

        def forward(self, input_ids, attention_mask, **kwargs):
            """Return ``embedding_layer(pooler_output)``; extra ``kwargs`` are ignored."""
            outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
            pooled_output = outputs.pooler_output
            embeddings = self.embedding_layer(pooled_output)
            return embeddings

    # TransformerRAG class
    class TransformerRAG(nn.Module):
        """Retrieval-augmented generator.

        Contexts and the question are embedded with DPR-style BERT encoders.
        The most cosine-similar context is picked and concatenated with the
        question text, and the result is decoded greedily with a
        ``LanguageModelTransformer``.

        NOTE(review): CustomDPRContextEncoder, DPRQuestionEncoder,
        LanguageModelTransformer and ``device`` are resolved as module-level
        globals, not the nested ``Expert.*`` classes — confirm.
        """

        def __init__(self, context_encoder_path, language_model_path, question_encoder_path, vocab_size):
            # Zero-argument super(): the bare name TransformerRAG inside a method
            # resolves to a module-level global, not this nested class.
            super().__init__()
            self.context_encoder = CustomDPRContextEncoder().to(device)
            self.language_model = LanguageModelTransformer(
                vocab_size=vocab_size,
                embed_size=256, 
                num_layers=6, 
                forward_expansion=4, 
                heads=8, 
                dropout=0, 
                max_length=100,  # NOTE(review): forward() tokenizes with max_length=512 — confirm these agree
                rank=16
            ).to(device)
            self.language_model = load_model_weights(self.language_model, language_model_path)
            self.question_encoder = DPRQuestionEncoder().to(device)
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        def forward(self, context_texts, question_input_ids, question_attention_mask, question_text):
            """Pick the most relevant context and return the decoded response string.

            Raises:
                ValueError: if ``question_input_ids`` contains an id outside the
                    tokenizer vocabulary.
                TypeError: if an item of ``context_texts`` is not a list of
                    tokenized context dictionaries.
            """
            if question_input_ids.max() >= self.tokenizer.vocab_size:
                raise ValueError("question_input_ids contain ID(s) beyond the tokenizer's vocabulary size")

            # Average the encoder embeddings of every context in each context list.
            aggregated_context_embeddings = []
            for context_list in context_texts:
                if not all(isinstance(context, dict) for context in context_list):
                    raise TypeError("Each item in context_texts must be a list of tokenized context dictionaries")

                # Accumulator sized to the BERT hidden size, on ``device``.
                aggregated_context_embedding = torch.zeros(self.context_encoder.bert_model.config.hidden_size, device=device)
                for context in context_list:
                    context_input_ids = context['input_ids'].to(device)
                    context_attention_mask = context['attention_mask'].to(device)
                    context_embedding = self.context_encoder(context_input_ids, context_attention_mask)
                    aggregated_context_embedding += context_embedding.mean(dim=0)

                aggregated_context_embeddings.append(aggregated_context_embedding / len(context_list))

            question_input_ids = question_input_ids.to(device).long()
            question_attention_mask = question_attention_mask.to(device).long()
            question_embeddings = self.question_encoder(input_ids=question_input_ids, attention_mask=question_attention_mask)

            cos_sim = torch.nn.CosineSimilarity(dim=1)
            similarities = [cos_sim(question_embeddings, context_emb.squeeze(0)) for context_emb in aggregated_context_embeddings]
            # NOTE(review): torch.tensor() on a list of tensors only works when each
            # similarity has a single element (one question) — confirm batch size 1.
            most_relevant_context_idx = torch.argmax(torch.tensor(similarities, device=device))

            # NOTE(review): the check above requires each context_texts item to be a
            # list of dicts, so this str + list concatenation raises TypeError —
            # the raw context text needs to be supplied or decoded here.
            combined_input = question_text + " " + context_texts[most_relevant_context_idx]
            tokenized_combined_input = self.tokenizer(combined_input, return_tensors="pt", padding=True, truncation=True, max_length=512)
            tokenized_combined_input = {k: v.to(device) for k, v in tokenized_combined_input.items()}  # Move to GPU
            response_logits = self.language_model(**tokenized_combined_input)
            probabilities = F.softmax(response_logits.logits, dim=-1)
            # Greedy decoding: highest-probability token at every position.
            predicted_token_ids = torch.argmax(probabilities, dim=-1)
            predicted_tokens = self.tokenizer.convert_ids_to_tokens(predicted_token_ids[0])
            response = self.tokenizer.convert_tokens_to_string(predicted_tokens)

            return response

    class LORALayer(nn.Module):
        """Linear layer plus a low-rank update: ``x @ W^T + b + alpha * (x @ A) @ B``.

        Args:
            input_dim / output_dim: feature sizes of the linear map.
            rank: inner dimension of the low-rank factors ``A`` and ``B``.
            alpha: scale applied to the low-rank update.
        """

        def __init__(self, input_dim, output_dim, rank, alpha=1):
            # Zero-argument super(): the bare name LORALayer inside a method
            # resolves to a module-level global, not this nested class.
            super().__init__()
            self.rank = rank
            self.alpha = alpha

            # Original weight and bias of the linear layer
            self.weight = nn.Parameter(torch.Tensor(output_dim, input_dim))
            self.bias = nn.Parameter(torch.Tensor(output_dim))

            # LORA specific parameters
            self.A = nn.Parameter(torch.Tensor(input_dim, rank))
            self.B = nn.Parameter(torch.Tensor(rank, output_dim))

            self.reset_parameters()

        def reset_parameters(self):
            """Kaiming-uniform weight, zero bias, N(0, 0.02) low-rank factors."""
            nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
            nn.init.zeros_(self.bias)
            nn.init.normal_(self.A, 0, 0.02)
            nn.init.normal_(self.B, 0, 0.02)

        def forward(self, x):
            """Apply the layer over the last dimension of ``x``.

            Any number of leading dimensions is accepted (previously exactly 3-D
            ``(batch, seq, features)`` input was required); the output keeps the
            leading dimensions and has ``output_dim`` features.
            """
            leading_shape = x.shape[:-1]
            x_flattened = x.reshape(-1, x.shape[-1])

            lora_adjustment = self.alpha * (x_flattened @ self.A) @ self.B
            x_transformed = nn.functional.linear(x_flattened, self.weight, self.bias)

            # Explicit feature size keeps reshape valid for empty leading dims.
            return (x_transformed + lora_adjustment).reshape(*leading_shape, self.weight.shape[0])

    class QLORALayer(nn.Module):
        """Linear layer + LayerNorm with a fake-quantised low-rank (LoRA) update.

        Computes ``LayerNorm(x @ W^T + b + dropout(alpha * (x @ A_q) @ B_q))``,
        where ``A_q``/``B_q`` are ``A``/``B`` quantised to ``quantization_bits``
        and dequantised again.

        Args:
            input_dim / output_dim: feature sizes of the wrapped linear map.
            rank: inner dimension of the low-rank factors ``A`` and ``B``.
            alpha: scale applied to the low-rank update.
            quantization_bits: bit width passed to ``quantize``.
        """

        def __init__(self, input_dim, output_dim, rank, alpha=1, quantization_bits=8):
            # Zero-argument super(): the bare name QLORALayer inside a method
            # resolves to a module-level global, not this nested class.
            super().__init__()
            self.rank = rank
            self.alpha = alpha
            self.quantization_bits = quantization_bits

            # Original weight and bias
            self.weight = nn.Parameter(torch.Tensor(output_dim, input_dim))
            self.bias = nn.Parameter(torch.Tensor(output_dim))

            # QLORA specific parameters
            self.A = nn.Parameter(torch.Tensor(input_dim, rank))
            self.B = nn.Parameter(torch.Tensor(rank, output_dim))

            self.dropout = nn.Dropout(0.1)
            self.layer_norm = nn.LayerNorm(output_dim)

            self.reset_parameters()

        def reset_parameters(self):
            """Kaiming-uniform weight, zero bias, N(0, 0.02) low-rank factors."""
            nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
            nn.init.zeros_(self.bias)
            nn.init.normal_(self.A, 0, 0.02)
            nn.init.normal_(self.B, 0, 0.02)

        def quantize(self, x, num_bits):
            """Symmetric round-to-nearest quantisation.

            Returns ``(x_quantized, scale)``: integer levels in
            ``[-(2**num_bits - 1), 2**num_bits - 1]`` and ``scale = max|x|``, so
            that ``x ~= x_quantized * scale / (2**num_bits - 1)``.
            """
            scale = x.abs().max()
            levels = 2 ** num_bits - 1
            # An all-zero tensor would give 0/0 = NaN; divide by 1 instead, which
            # yields all-zero levels.
            safe_scale = torch.where(scale > 0, scale, torch.ones_like(scale))
            x_quantized = torch.round(x / safe_scale * levels)
            return x_quantized, scale

        def _fake_quantize(self, x):
            """Quantise and dequantise ``x``, letting gradients pass straight through."""
            x_quantized, scale = self.quantize(x, self.quantization_bits)
            # Dequantise by multiplying back by scale / levels; dividing the
            # levels by the scale (as before) inflated values by ~levels / scale**2.
            dequantized = x_quantized * scale / (2 ** self.quantization_bits - 1)
            # Straight-through estimator: forward value is the dequantised tensor,
            # but rounding is treated as identity in backward, so A and B train
            # (torch.round alone has zero gradient).
            return x + (dequantized - x).detach()

        def forward(self, x):
            """Apply the layer over the last dimension of ``x`` (any leading dims)."""
            leading_shape = x.shape[:-1]
            x_flattened = x.reshape(-1, x.shape[-1])

            A_dequantized = self._fake_quantize(self.A)
            B_dequantized = self._fake_quantize(self.B)

            lora_adjustment = self.alpha * (x_flattened @ A_dequantized) @ B_dequantized
            lora_adjustment = self.dropout(lora_adjustment)

            x_transformed = nn.functional.linear(x_flattened, self.weight, self.bias)

            # Explicit feature size keeps reshape valid for empty leading dims.
            x = (x_transformed + lora_adjustment).reshape(*leading_shape, self.weight.shape[0])
            return self.layer_norm(x)

        def update_alpha(self, new_alpha):
            """
            Update the alpha scaling factor.
            """
            self.alpha = new_alpha

    class MultiHeadAttention(nn.Module):
        """Multi-head attention with per-head shared Q/K/V projections.

        The embedding is split into ``heads`` pieces of ``head_dim`` features;
        the same ``head_dim x head_dim`` projection is applied to every head.

        Args:
            embed_size: model width; must be divisible by ``heads``.
            heads: number of attention heads.
        """

        def __init__(self, embed_size, heads):
            # Zero-argument super(): the bare name MultiHeadAttention inside a
            # method resolves to a module-level global, not this nested class.
            super().__init__()
            self.embed_size = embed_size
            self.heads = heads
            self.head_dim = embed_size // heads

            assert (
                self.head_dim * heads == embed_size
            ), "Embedding size needs to be divisible by heads"

            self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
            self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
            self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
            self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

        def forward(self, values, keys, query, mask):
            """Attend ``query`` over ``keys``/``values``; positions where ``mask == 0`` are suppressed.

            Inputs are reshaped to ``(N, len, heads, head_dim)``, so they are
            expected as ``(N, len, embed_size)``. Returns ``(N, query_len, embed_size)``.
            """
            N = query.shape[0]
            value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

            # Split the embedding into self.heads different pieces
            values = values.reshape(N, value_len, self.heads, self.head_dim)
            keys = keys.reshape(N, key_len, self.heads, self.head_dim)
            queries = query.reshape(N, query_len, self.heads, self.head_dim)

            values = self.values(values)
            keys = self.keys(keys)
            queries = self.queries(queries)

            # Scores for every (query, key) pair, per sample and head: (N, heads, q, k)
            attention = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

            if mask is not None:
                attention = attention.masked_fill(mask == 0, float("-1e20"))

            # NOTE(review): scaled by sqrt(embed_size) rather than the usual
            # sqrt(head_dim); kept as-is because pretrained weights depend on it.
            attention = torch.softmax(attention / (self.embed_size ** (1 / 2)), dim=3)

            out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
                N, query_len, self.heads * self.head_dim
            )

            out = self.fc_out(out)
            return out

    class TransformerBlock(nn.Module):
        """Post-norm transformer block: attention and a LoRA feed-forward, each with a residual.

        NOTE(review): ``MultiHeadAttention`` and ``LORALayer`` are looked up as
        module-level globals when the block is constructed, not as sibling nested
        classes. Confirm which definitions are intended.
        """

        def __init__(self, embed_size, heads, dropout, forward_expansion, rank):
            # Zero-argument super(): an explicit ``super(TransformerBlock, self)``
            # resolves the name through module globals. That is wrong, or undefined,
            # for a class nested inside another class.
            super().__init__()
            self.attention = MultiHeadAttention(embed_size, heads)
            self.norm1 = nn.LayerNorm(embed_size)
            self.norm2 = nn.LayerNorm(embed_size)

            self.feed_forward = nn.Sequential(
                LORALayer(embed_size, forward_expansion * embed_size, rank),
                nn.ReLU(),
                LORALayer(forward_expansion * embed_size, embed_size, rank),
            )

            self.dropout = nn.Dropout(dropout)

        def forward(self, value, key, query, mask):
            """Return ``dropout(norm2(ff(x) + x))`` where ``x = dropout(norm1(attention + query))``."""
            attention = self.attention(value, key, query, mask)
            x = self.dropout(self.norm1(attention + query))
            forward = self.feed_forward(x)
            out = self.dropout(self.norm2(forward + x))
            return out

    class LanguageModelDecoder(nn.Module):
        """Token and learned position embeddings, a stack of TransformerBlocks, and a vocabulary projection.

        BatchNorm1d is applied over the embedding channels before and after the
        block stack. When ``use_qlora`` is set, ``qlora_feed_forward`` runs after
        every block.
        """

        def __init__(self, vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, max_length, rank):
            # Zero-argument super(): the explicit two-argument form resolves the class
            # name via module globals, which is wrong for a nested class.
            super().__init__()
            self.word_embedding = nn.Embedding(vocab_size, embed_size)
            # max_length bounds the sequence length accepted by forward().
            self.position_embedding = nn.Embedding(max_length, embed_size)

            # Adding BatchNorm layers
            self.bn1 = nn.BatchNorm1d(embed_size)
            self.bn2 = nn.BatchNorm1d(embed_size)

            self.layers = nn.ModuleList(
                [
                    TransformerBlock(embed_size, heads, dropout, forward_expansion, rank)
                    for _ in range(num_layers)
                ]
            )

            # QLORA layers
            self.qlora_feed_forward = nn.Sequential(
                QLORALayer(embed_size, forward_expansion * embed_size, rank),
                nn.ReLU(),
                QLORALayer(forward_expansion * embed_size, embed_size, rank),
            )
            self.use_qlora = False  # Flag to toggle QLORA

            self.fc_out = nn.Linear(embed_size, vocab_size)
            self.dropout = nn.Dropout(dropout)

        def forward(self, x, trg_mask):
            """Map token ids ``x`` of shape ``(N, seq_length)`` to logits ``(N, seq_length, vocab_size)``.

            ``trg_mask`` is passed unchanged to every block's attention.
            """
            N, seq_length = x.shape
            # Create positions on the token ids' device. CPU positions make the
            # embedding lookup fail once the model lives on the GPU.
            positions = torch.arange(0, seq_length, device=x.device).expand(N, seq_length)
            x = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

            # BatchNorm1d expects (N, C, L): transpose, normalise, transpose back
            x = x.transpose(1, 2)
            x = self.bn1(x)
            x = x.transpose(1, 2)

            for layer in self.layers:
                x = layer(x, x, x, trg_mask)
                if self.use_qlora:
                    x = self.qlora_feed_forward(x)

            # BatchNorm1d expects (N, C, L): transpose, normalise, transpose back
            x = x.transpose(1, 2)
            x = self.bn2(x)
            x = x.transpose(1, 2)

            out = self.fc_out(x)
            return out

        def toggle_qlora(self, use_qlora: bool):
            """Enable or disable the QLoRA feed-forward applied after each block."""
            self.use_qlora = use_qlora

    class LanguageModelTransformer(nn.Module):
        """Decoder-only language model: causal mask plus ``LanguageModelDecoder``."""

        def __init__(self, vocab_size, embed_size=256, num_layers=6, forward_expansion=4, heads=8, dropout=0, max_length=100, rank=16):
            # Zero-argument super(): the explicit two-argument form resolves the class
            # name via module globals, which is wrong for a nested class.
            super().__init__()

            self.decoder = LanguageModelDecoder(
                vocab_size,
                embed_size,
                num_layers,
                heads,
                forward_expansion,
                dropout,
                max_length,
                rank,
            )

        def forward(self, trg):
            """Return decoder logits for token ids ``trg`` of shape ``(N, trg_len)``."""
            trg_mask = self.make_trg_mask(trg)
            out = self.decoder(trg, trg_mask)
            return out

        def make_trg_mask(self, trg):
            """Return a lower-triangular (causal) mask of shape ``(N, 1, trg_len, trg_len)``."""
            N, trg_len = trg.shape
            # Allocate directly on trg's device instead of building on the CPU and copying.
            trg_mask = torch.tril(torch.ones((trg_len, trg_len), device=trg.device)).expand(
                N, 1, trg_len, trg_len
            )

            return trg_mask

        # Function to enable or disable QLORA layers (for fine-tuning purposes)
        def toggle_qlora(self, use_qlora: bool):
            self.decoder.toggle_qlora(use_qlora)

        def generate_response(self, input_ids, attention_mask):
            """Greedy-decode the first sequence of ``input_ids`` into a string.

            This takes a single forward pass and the per-position argmax token. It is
            not autoregressive generation.

            Args:
                input_ids: ``(N, seq_len)`` token ids. Only row 0 is decoded.
                attention_mask: accepted for API compatibility. ``forward`` builds
                    its own causal mask and takes no padding mask, so it is unused.

            NOTE(review): relies on a module-level ``tokenizer``. Confirm that it is
            defined before this is called.
            """
            # forward() takes the token ids positionally; it has no keyword arguments
            # named input_ids/attention_mask.
            logits = self.forward(input_ids)

            # argmax over logits equals argmax over softmax(logits); the softmax is unnecessary.
            predicted_token_ids = torch.argmax(logits, dim=-1)

            # Decode the first sequence. Iterating the 2-D id tensor and calling
            # .item() on each row fails for any sequence longer than one token.
            predicted_tokens = tokenizer.convert_ids_to_tokens(predicted_token_ids[0].tolist())

            response = tokenizer.convert_tokens_to_string(predicted_tokens)
            return response


    def __init__(self, sparsity_factor, seq_len, head_dim, blk_size, input_dim, num_experts, vocab_size, embed_size, num_layers, forward_expansion, heads, dropout, max_length, rank, device):
        """Build the expert pipeline: sparse attention -> LayerNorm/Dropout -> switch router -> QLoRA.

        Args:
            sparsity_factor, seq_len, head_dim, blk_size: forwarded to SparseFlash2Attention.
            input_dim: feature size used by the LayerNorm, the router and the QLoRA layer.
            num_experts: forwarded to the SwitchRouter.
            vocab_size, embed_size, num_layers, forward_expansion, heads, dropout,
                max_length: forwarded to the SwitchRouter.
            rank: forwarded to the SwitchRouter and also sizes the QLoRA layer.
            device: forwarded to the SwitchRouter.

        NOTE(review): nested classes are reached through the global name ``Expert``.
        That name is looked up when __init__ runs, not when this class is defined.
        If a later cell rebinds ``Expert`` (for example, a newer version of this
        class), these calls construct that version's nested classes instead. Confirm.
        """
        super().__init__()
        # 1. SparseFlash2_attention
        self.sparse_flash2_attention = Expert.SparseFlash2Attention(seq_len, head_dim, blk_size, sparsity_factor)

        # 2. LayerNorm and Dropout
        # The dropout rate is hard-coded to 0.1, independent of the ``dropout`` argument.
        self.layer_norm = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(0.1)

        # 3. Internal Switch Routing
        # NOTE(review): context_encoder_path, language_model_path, question_encoder_path
        # and dpo_model_path are neither parameters nor defined in the visible code.
        # The first cell defines context_encoder, language_model, question_encoder and
        # tran_dpo instead. Confirm these globals exist, or this raises NameError.
        self.switch_router = Expert.SwitchRouter(input_dim, num_experts, mamba_model_path, context_encoder_path, language_model_path, question_encoder_path, dpo_model_path, vocab_size, embed_size, num_layers, forward_expansion, heads, dropout, max_length, rank, device)

        # 4. QLORA Layer
        # Square (input_dim -> input_dim), so the feature size is preserved.
        self.qlora = Expert.QLORALayer(input_dim, input_dim, rank)

    def forward(self, x, attention_mask, context_texts, question_text):
        """Run ``x`` through attention, normalisation, routing and the QLoRA layer.

        ``attention_mask``, ``context_texts`` and ``question_text`` are passed
        through to the switch router. The router returns a pair; only its first
        element is kept, and the second (presumably an auxiliary loss; confirm)
        is discarded here.
        """
        attended = self.sparse_flash2_attention(x)
        normalized = self.dropout(self.layer_norm(attended))
        routed, _router_extra = self.switch_router(normalized, attention_mask, context_texts, question_text)
        return self.qlora(routed)

# Create an instance of the Expert class
# expert = Expert(sparsity_factor, seq_len, head_dim, blk_size, input_dim, num_experts, vocab_size, embed_size, num_layers, forward_expansion, heads, dropout, max_length, rank, device)


# v5

Expert: {
1. SparseFlash2_attention
2. LayerNorm + Dropout
3. Internal Switch Routing Between Sub-models:
a) Sub-model 1: Transformer with DPO (Direct Preference Optimization)
b) Sub-model 2: Transformer with RAG (Retrieval-Augmented Generation)
c) Sub-model 3: MAMBA (Linear-Time Sequence Modeling with Selective State Spaces)
4. QLORA
}

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel, BertTokenizer

class ExpertConfig:
    """Hyper-parameters and checkpoint paths used to build an Expert.

    Every constructor argument is stored unchanged as an attribute of the same name.
    """

    def __init__(self, seq_len=512, head_dim=64, block_size=64, sparsity_factor=4,
                 input_dim=512, num_experts=3, vocab_size=30000, embed_size=256,
                 num_layers=6, forward_expansion=4, heads=8, dropout=0.1,
                 max_length=100, rank=16, device='cuda',
                 mamba_model_path='path/to/mamba_model',
                 context_encoder_path='path/to/context_encoder',
                 language_model_path='path/to/language_model',
                 question_encoder_path='path/to/question_encoder',
                 dpo_model_path='path/to/dpo_model'):

        self.seq_len = seq_len
        self.head_dim = head_dim
        self.block_size = block_size
        self.sparsity_factor = sparsity_factor
        self.input_dim = input_dim
        self.num_experts = num_experts
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.num_layers = num_layers
        self.forward_expansion = forward_expansion
        self.heads = heads
        self.dropout = dropout
        self.max_length = max_length
        self.rank = rank
        self.device = device

        # Model paths
        self.mamba_model_path = mamba_model_path
        self.context_encoder_path = context_encoder_path
        self.language_model_path = language_model_path
        self.question_encoder_path = question_encoder_path
        self.dpo_model_path = dpo_model_path

    def validate(self):
        """Check invariants that the attention blocks rely on.

        Raises:
            AssertionError: if ``block_size`` is not positive, or ``seq_len`` is
                not a multiple of ``block_size``. The checks use an explicit
                ``raise`` rather than ``assert`` so they still run under
                ``python -O``. The exception type is kept for existing callers.
        """
        if self.block_size <= 0:
            # Without this, seq_len % 0 would raise a bare ZeroDivisionError.
            raise AssertionError("block_size must be a positive integer")
        if self.seq_len % self.block_size != 0:
            raise AssertionError("seq_len must be divisible by block_size")

class Expert(nn.Module):

    # Flash2_Attention
    class FlashAttention2(nn.Module):
        """Tiled attention over 2-D ``(seq, dim)`` inputs using an online softmax.

        Q, K and V are split into blocks of ``block_size`` rows along dim 0. Each
        query block accumulates over the key/value blocks with a running max and
        normaliser, so the result equals ``softmax(Q @ K.T, dim=-1) @ V`` without
        materialising the full score matrix. No 1/sqrt(d) scaling is applied,
        matching the original behaviour.
        """

        def __init__(self, sequence_length, head_dimension, block_size):
            # Zero-argument super(): the explicit two-argument form resolves the class
            # name via module globals, which is wrong for a nested class.
            super().__init__()
            self.block_size = block_size
            # Ensure that sequence_length is divisible by block_size for simplicity
            assert sequence_length % block_size == 0

        def forward(self, Q, K, V):
            """Return attention output with one row per row of ``Q``."""
            Q_blocks, K_blocks, V_blocks = self.partition_inputs(Q, K, V)
            outputs = [self.process_block(Q_block, K_blocks, V_blocks) for Q_block in Q_blocks]
            return torch.cat(outputs, dim=0)

        def partition_inputs(self, Q, K, V):
            """Split Q, K and V along dim 0 into blocks of ``block_size`` rows.

            The last block may be shorter. Unlike ``chunk(size // block_size)``,
            this never requests zero chunks for inputs shorter than one block.
            """
            Q_blocks = Q.split(self.block_size, dim=0)
            K_blocks = K.split(self.block_size, dim=0)
            V_blocks = V.split(self.block_size, dim=0)
            return Q_blocks, K_blocks, V_blocks

        def process_block(self, Q_block, K_blocks, V_blocks):
            """Exact attention of one query block over all key/value blocks.

            Online-softmax recurrence: keep a running row max ``m``, a running
            normaliser ``l`` and an unnormalised output accumulator. When a new key
            block raises the max, rescale the earlier partial sums by
            ``exp(m_old - m_new)``. Normalising once at the end gives a single
            softmax over *all* keys. Softmaxing each key block separately and
            summing would not equal attention.
            """
            if len(K_blocks) == 0:
                raise ValueError("K and V must contain at least one row")

            rows = Q_block.size(0)
            running_max = Q_block.new_full((rows, 1), float("-inf"))
            running_sum = Q_block.new_zeros((rows, 1))
            accumulator = Q_block.new_zeros((rows, V_blocks[0].size(-1)))

            for K_block, V_block in zip(K_blocks, V_blocks):
                scores = torch.matmul(Q_block, K_block.transpose(-2, -1))
                new_max = torch.maximum(running_max, scores.amax(dim=-1, keepdim=True))
                # exp(-inf - finite) == 0 on the first block, discarding the empty state.
                correction = torch.exp(running_max - new_max)
                probs = torch.exp(scores - new_max)
                running_sum = running_sum * correction + probs.sum(dim=-1, keepdim=True)
                accumulator = accumulator * correction + torch.matmul(probs, V_block)
                running_max = new_max

            return accumulator / running_sum

        def online_softmax(self, scores, chunk_size=128):
            """Row-wise softmax of 2-D ``scores``, computed ``chunk_size`` rows at a time.

            Kept for existing callers; ``process_block`` no longer uses it.
            """
            softmaxed_scores = []
            for i in range(0, scores.size(0), chunk_size):
                chunk = scores[i:i + chunk_size, :]
                softmaxed_chunk = F.softmax(chunk, dim=1)
                softmaxed_scores.append(softmaxed_chunk)
            return torch.cat(softmaxed_scores, dim=0)

    # SparseFlash2_Attention
    class SparseFlash2Attention(nn.Module):
        """Tiled attention followed by multiplication with a fixed ``(seq_len x seq_len)`` mask.

        NOTE(review): ``generate_sparsity_mask`` sets rows ``i:i+step`` for every
        ``i`` in ``range(0, seq_len, step)``, which covers every row. The mask is
        therefore all ones, and the bmm in ``forward`` replaces each output row
        with the sum of all rows. This is unlikely to be the intended sparsity
        pattern; confirm.
        """

        def __init__(self, seq_len, head_dim, blk_size, sparsity_factor):
            super().__init__()
            # NOTE(review): ``FlashAttention2`` is resolved as a module-level global
            # here, not as the sibling nested class. Confirm which one is meant.
            self.flash_attention = FlashAttention2(seq_len, head_dim, blk_size)
            self.seq_len = seq_len
            self.head_dim = head_dim
            self.block_size = blk_size  # Storing block_size as an instance variable
            self.sparsity_factor = sparsity_factor

        def generate_sparsity_mask(self, device=None):
            """Return a boolean ``(seq_len, seq_len)`` mask, created on ``device`` (default: CPU)."""
            mask = torch.zeros(self.seq_len, self.seq_len, device=device)
            step = self.sparsity_factor
            for i in range(0, self.seq_len, step):
                mask[i:i + step, :] = 1
            return mask.bool()

        def forward(self, Q, K=None, V=None):
            """Attend ``Q`` over ``K``/``V`` and mix output rows with the sparsity mask.

            ``K`` and ``V`` default to ``Q`` (self-attention). This lets callers pass
            a single tensor, as ``Expert.forward`` does. Q, K and V are 2-D
            ``(seq_len, head_dim)`` tensors, and the output has the same layout.
            """
            K = Q if K is None else K
            V = Q if V is None else V

            output = self.flash_attention(Q, K, V)  # [sequence_length, head_dimension]
            output = output.unsqueeze(0)  # [1, sequence_length, head_dimension] for bmm

            # Build the mask on the output's device; a CPU mask fails bmm against CUDA output.
            sparsity_mask = self.generate_sparsity_mask(device=output.device).unsqueeze(0)
            output = torch.bmm(sparsity_mask.float(), output.float())

            return output.squeeze(0)  # [sequence_length, head_dimension]

    class SwitchRouter(nn.Module):
        """Top-1 (switch) router that sends each input row to one of three sub-models.

        Inputs are projected by ``input_embedding`` (512 -> input_dim) and scored by
        ``SwitchGate``. Each row is assigned to its highest-scoring expert, subject
        to a per-expert capacity, and processed by that expert. ``forward`` returns
        the combined output and an auxiliary load-balancing loss.
        """

        # Per-expert capacity is floor(num_rows * CAPACITY_FACTOR / num_experts).
        CAPACITY_FACTOR = 1  # Class constant

        class SwitchGate(nn.Module):
            """Two-layer MLP gate returning a softmax distribution over experts."""

            def __init__(self, input_dim, num_experts):
                # Zero-argument super(): the explicit two-argument form depended on a
                # module-level ``SwitchRouter`` name that is not this nested class.
                super().__init__()
                self.fc1 = nn.Linear(input_dim, input_dim // 2)
                self.fc2 = nn.Linear(input_dim // 2, num_experts)

            def forward(self, x):
                x = F.relu(self.fc1(x))
                gate_scores = F.softmax(self.fc2(x), dim=-1)
                return gate_scores

        def __init__(self, input_dim, num_experts, mamba_model_path, context_encoder_path, language_model_path, question_encoder_path, dpo_model_path, vocab_size, embed_size, num_layers, forward_expansion, heads, dropout, max_length, rank, d_model, d_state, d_conv, expansion_factor, device):
            super().__init__()
            self.device = device
            self.router = self.SwitchGate(input_dim, num_experts).to(device)
            # NOTE(review): TransformerRAG, LanguageModelTransformer and
            # SimplifiedLanguageModelMAMBA are resolved as module-level globals here.
            # num_experts only sizes the gate, while the expert list below always has
            # three entries. A gate wider than 3 can pick an index with no expert.
            self.transformer_rag = TransformerRAG(context_encoder_path, language_model_path, question_encoder_path, vocab_size).to(device)
            self.transformer_dpo = LanguageModelTransformer(vocab_size, embed_size, num_layers, forward_expansion, heads, dropout, max_length, rank).to(device)
            self.transformer_dpo = load_model_weights(self.transformer_dpo, dpo_model_path)
            self.mamba = SimplifiedLanguageModelMAMBA(vocab_size, num_layers, d_model, d_state, d_conv, expansion_factor).to(device)
            self.mamba = load_model_weights(self.mamba, mamba_model_path)
            self.experts = nn.ModuleList([self.transformer_rag, self.transformer_dpo, self.mamba])
            # Placed on ``device`` like the gate: forward() moves x to self.device
            # before applying this layer, so a CPU-resident layer fails on GPU.
            self.input_embedding = nn.Linear(512, input_dim).to(device)

        def forward(self, x, attention_mask, context_texts, question_text):
            """Route each row of ``x`` to one expert.

            Returns:
                ``(final_output, aux_loss)``. ``final_output`` has the shape of the
                embedded input, with each routed row overwritten by its expert's output.

            NOTE(review): as defined in this cell, TransformerRAG.forward returns a
            decoded string and LanguageModelTransformer.forward takes only a token-id
            tensor. The expert calls and the ``final_output[mask] = ...`` assignment
            below therefore cannot succeed for them as written. Confirm the intended
            expert interface.
            """
            x = x.to(self.device)
            if x.dtype != torch.float32:
                x = x.float()

            x = self.input_embedding(x)
            gate_scores = self.router(x)
            # The expert axis is the last one; identical to dim=1 for the 2-D rows route_inputs handles.
            expert_indices = torch.argmax(gate_scores, dim=-1)
            expert_indices = self.route_inputs(expert_indices, gate_scores, len(self.experts))
            final_output = torch.zeros_like(x)

            for i, expert in enumerate(self.experts):
                mask = expert_indices == i
                if not mask.any():
                    continue
                selected_inputs = x[mask]
                selected_attention_mask = attention_mask[mask].to(self.device)
                if isinstance(expert, TransformerRAG):
                    expert_output = expert(context_texts, selected_inputs, selected_attention_mask, question_text)
                else:
                    expert_output = expert(selected_inputs, selected_attention_mask)
                final_output[mask] = expert_output

            aux_loss = self.auxiliary_loss(gate_scores)
            return final_output, aux_loss

        @staticmethod
        def auxiliary_loss(gate_scores):
            """Return the std-dev of the mean gate probability per expert (0 when perfectly balanced)."""
            expert_load = gate_scores.sum(0) / gate_scores.size(0)
            loss_balancing = torch.std(expert_load)
            return loss_balancing

        @classmethod
        def route_inputs(cls, expert_indices, gate_scores, num_experts):
            """Apply a per-expert capacity to top-1 assignments, rerouting overflow.

            Args:
                expert_indices: 1-D integer tensor with one chosen expert per row.
                    It is modified in place.
                gate_scores: tensor whose ``size(0)`` is the number of routed rows.
                num_experts: number of experts.

            Returns:
                ``expert_indices``. Rows that would exceed their expert's capacity are
                moved to the lowest-numbered expert with spare capacity. Rows for which
                no expert has room keep their original choice.
            """
            # A classmethod reads the constant via ``cls``. ``SwitchRouter.CAPACITY_FACTOR``
            # would look up a module-level name rather than this nested class.
            capacity = int(gate_scores.size(0) * cls.CAPACITY_FACTOR / num_experts)
            # Count with Python ints: device-agnostic, and avoids per-element tensor indexing.
            expert_counts = [0] * num_experts
            for idx, selected_expert in enumerate(expert_indices.tolist()):
                if expert_counts[selected_expert] < capacity:
                    expert_counts[selected_expert] += 1
                    continue
                alternative_expert = next(
                    (expert for expert, count in enumerate(expert_counts) if count < capacity),
                    None,
                )
                if alternative_expert is None:
                    print("No available experts to reroute. Handling overflow.")
                    continue
                expert_indices[idx] = alternative_expert
                expert_counts[alternative_expert] += 1
            return expert_indices

    class CustomDPRContextEncoder(nn.Module):
        """BERT pooled output followed by a linear projection to ``embedding_dim``."""

        def __init__(self, model_name='bert-base-uncased', embedding_dim=768):
            # Zero-argument super(): the explicit two-argument form resolves the class
            # name via module globals, which is wrong for a nested class.
            super().__init__()
            # Transformer-based model, e.g., BERT
            self.bert_model = BertModel.from_pretrained(model_name)
            # Additional layer to produce fixed-size embeddings
            self.embedding_layer = nn.Linear(self.bert_model.config.hidden_size, embedding_dim)

        def forward(self, input_ids, attention_mask=None):
            """Return one ``embedding_dim`` vector per input sequence."""
            outputs = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)
            # Use the pooled output for creating embeddings
            pooled_output = outputs.pooler_output
            context_embeddings = self.embedding_layer(pooled_output)
            return context_embeddings

    class DPRQuestionEncoder(nn.Module):
        """BERT pooled output followed by a linear projection to ``embedding_dim``."""

        def __init__(self, model_name='bert-base-uncased', embedding_dim=768):
            # Zero-argument super(): the explicit two-argument form resolves the class
            # name via module globals, which is wrong for a nested class.
            super().__init__()
            self.bert = BertModel.from_pretrained(model_name)
            self.embedding_layer = nn.Linear(self.bert.config.hidden_size, embedding_dim)

        def forward(self, input_ids, attention_mask, **kwargs):
            """Return one ``embedding_dim`` vector per question; extra kwargs are ignored."""
            outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
            pooled_output = outputs.pooler_output
            embeddings = self.embedding_layer(pooled_output)
            return embeddings

    # TransformerRAG class
    class TransformerRAG(nn.Module):
        """Retrieval-augmented generation over pre-tokenized contexts.

        Each entry of ``context_texts`` (a list of tokenized context dicts) is
        embedded with the context encoder and averaged. The question is embedded
        with the question encoder. The entry with the highest cosine similarity is
        decoded back to text, appended to ``question_text`` and run through the
        language model, whose per-position argmax tokens are returned as a string.
        """

        def __init__(self, context_encoder_path, language_model_path, question_encoder_path, vocab_size):
            # Zero-argument super(): the explicit two-argument form resolves the class
            # name via module globals, which is wrong for a nested class.
            super().__init__()
            # Sequence-length limit of the language model (size of its position
            # table). Prompts are truncated to this many tokens in forward().
            self.max_length = 100
            # NOTE(review): context_encoder_path and question_encoder_path are accepted
            # but unused, so both encoders keep their from_pretrained weights. Confirm
            # whether the saved encoder checkpoints should be loaded here.
            self.context_encoder = CustomDPRContextEncoder().to(device)
            self.language_model = LanguageModelTransformer(
                vocab_size=vocab_size,
                embed_size=256,
                num_layers=6,
                forward_expansion=4,
                heads=8,
                dropout=0,
                max_length=self.max_length,  # must match the checkpoint loaded below
                rank=16
            ).to(device)
            self.language_model = load_model_weights(self.language_model, language_model_path)
            self.question_encoder = DPRQuestionEncoder().to(device)
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        def forward(self, context_texts, question_input_ids, question_attention_mask, question_text):
            """Answer ``question_text`` using the most relevant context list.

            Args:
                context_texts: list of context lists. Each context list is a
                    non-empty list of dicts holding ``input_ids`` and
                    ``attention_mask`` tensors.
                question_input_ids, question_attention_mask: tokenized question.
                question_text: raw question string, prepended to the chosen context.

            Returns:
                The greedy (argmax) prediction for the combined prompt, as a string.

            Raises:
                ValueError: question ids exceed the tokenizer vocabulary, or a
                    context list is empty.
                TypeError: a context list contains non-dict items.
            """
            if question_input_ids.max() >= self.tokenizer.vocab_size:
                raise ValueError("question_input_ids contain ID(s) beyond the tokenizer's vocabulary size")

            aggregated_context_embeddings = [
                self._embed_context_list(context_list) for context_list in context_texts
            ]

            question_input_ids = question_input_ids.to(device).long()
            question_attention_mask = question_attention_mask.to(device).long()
            question_embeddings = self.question_encoder(input_ids=question_input_ids, attention_mask=question_attention_mask)

            cos_sim = torch.nn.CosineSimilarity(dim=1)
            # One similarity per question row for each context list. Stacking gives
            # (num_contexts, num_rows); averaging over rows makes argmax work for any
            # question batch size. torch.tensor() on this list only handled single rows.
            similarities = torch.stack(
                [cos_sim(question_embeddings, context_emb.unsqueeze(0)) for context_emb in aggregated_context_embeddings]
            ).mean(dim=1)
            most_relevant_context_idx = int(torch.argmax(similarities))

            # Context lists hold tokenized dicts, not strings, so decode before concatenating.
            context_text = self._context_list_to_text(context_texts[most_relevant_context_idx])
            combined_input = question_text + " " + context_text
            # Truncate to the language model's position-table size; longer prompts
            # would index past the end of its position embedding.
            tokenized_combined_input = self.tokenizer(
                combined_input, return_tensors="pt", padding=True, truncation=True, max_length=self.max_length
            )
            input_ids = tokenized_combined_input["input_ids"].to(device)

            # LanguageModelTransformer.forward takes only the token ids (it builds its
            # own causal mask) and returns a plain logits tensor, with no ``.logits`` attribute.
            response_logits = self.language_model(input_ids)
            predicted_token_ids = torch.argmax(response_logits, dim=-1)
            predicted_tokens = self.tokenizer.convert_ids_to_tokens(predicted_token_ids[0].tolist())
            response = self.tokenizer.convert_tokens_to_string(predicted_tokens)

            return response

        def _embed_context_list(self, context_list):
            """Return the mean context-encoder embedding of one list of tokenized context dicts."""
            if not all(isinstance(context, dict) for context in context_list):
                raise TypeError("Each item in context_texts must be a list of tokenized context dictionaries")
            if len(context_list) == 0:
                # Averaging over zero contexts would divide by zero and produce NaNs.
                raise ValueError("Each item in context_texts must contain at least one tokenized context")

            aggregated = torch.zeros(self.context_encoder.bert_model.config.hidden_size, device=device)
            for context in context_list:
                context_input_ids = context['input_ids'].to(device)
                context_attention_mask = context['attention_mask'].to(device)
                context_embedding = self.context_encoder(context_input_ids, context_attention_mask)
                aggregated = aggregated + context_embedding.mean(dim=0)
            return aggregated / len(context_list)

        def _context_list_to_text(self, context_list):
            """Decode a list of tokenized context dicts back into one space-joined string."""
            pieces = []
            for context in context_list:
                ids = context['input_ids']
                if torch.is_tensor(ids):
                    ids = ids.reshape(-1).tolist()
                pieces.append(self.tokenizer.decode(ids, skip_special_tokens=True))
            return " ".join(pieces)

    class LORALayer(nn.Module):
        """Linear layer plus a trainable low-rank (LoRA) update.

        Computes ``x @ W.T + b + alpha * (x @ A) @ B`` over the last dimension of
        ``x``. Any leading dimensions (none, batch, batch x seq, ...) are preserved.
        """

        def __init__(self, input_dim, output_dim, rank, alpha=1):
            # Zero-argument super(): the explicit two-argument form resolves the class
            # name via module globals, which is wrong for a nested class.
            super().__init__()
            self.rank = rank
            self.alpha = alpha

            # Original weight and bias of the linear layer (initialised in reset_parameters)
            self.weight = nn.Parameter(torch.empty(output_dim, input_dim))
            self.bias = nn.Parameter(torch.empty(output_dim))

            # LORA specific parameters
            self.A = nn.Parameter(torch.empty(input_dim, rank))
            self.B = nn.Parameter(torch.empty(rank, output_dim))

            self.reset_parameters()

        def reset_parameters(self):
            """Kaiming-uniform weight, zero bias, and N(0, 0.02) LoRA factors."""
            nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
            nn.init.zeros_(self.bias)
            nn.init.normal_(self.A, 0, 0.02)
            nn.init.normal_(self.B, 0, 0.02)

        def forward(self, x):
            """Apply the layer to the last dimension of ``x`` (``..., input_dim`` -> ``..., output_dim``)."""
            # Flatten all leading dims so 2-D inputs work as well as (batch, seq, dim).
            leading_shape = x.shape[:-1]
            x_flattened = x.reshape(-1, x.shape[-1])

            lora_adjustment = self.alpha * (x_flattened @ self.A) @ self.B
            x_transformed = nn.functional.linear(x_flattened, self.weight, self.bias)

            # Explicit output size (not -1) so empty batches also reshape cleanly.
            return (x_transformed + lora_adjustment).reshape(*leading_shape, self.weight.shape[0])

    class QLORALayer(nn.Module):
        """Linear layer plus a fake-quantized low-rank update, followed by LayerNorm.

        The LoRA factors ``A`` and ``B`` are quantized to ``2**quantization_bits - 1``
        symmetric levels and dequantized before use. Dropout is applied to the
        low-rank term. The input's leading dimensions are preserved.
        """

        def __init__(self, input_dim, output_dim, rank, alpha=1, quantization_bits=8):
            # Zero-argument super(): the explicit two-argument form resolves the class
            # name via module globals, which is wrong for a nested class.
            super().__init__()
            self.rank = rank
            self.alpha = alpha
            self.quantization_bits = quantization_bits

            # Original weight and bias (initialised in reset_parameters)
            self.weight = nn.Parameter(torch.empty(output_dim, input_dim))
            self.bias = nn.Parameter(torch.empty(output_dim))

            # QLORA specific parameters
            self.A = nn.Parameter(torch.empty(input_dim, rank))
            self.B = nn.Parameter(torch.empty(rank, output_dim))

            self.dropout = nn.Dropout(0.1)
            self.layer_norm = nn.LayerNorm(output_dim)

            self.reset_parameters()

        def reset_parameters(self):
            """Kaiming-uniform weight, zero bias, and N(0, 0.02) LoRA factors."""
            nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
            nn.init.zeros_(self.bias)
            nn.init.normal_(self.A, 0, 0.02)
            nn.init.normal_(self.B, 0, 0.02)

        def quantize(self, x, num_bits):
            """Symmetric per-tensor quantization of ``x``.

            Returns:
                ``(x_quantized, scale)``, where ``scale = max|x|`` and
                ``x_quantized = round(x / scale * (2**num_bits - 1))``.

            Rounding uses a straight-through estimator: the values are exactly
            rounded, but gradients reach ``x`` as if rounding were the identity.
            ``torch.round`` alone has zero gradient and would freeze A and B.
            An all-zero ``x`` yields zeros with ``scale == 0`` instead of NaN.
            """
            levels = 2 ** num_bits - 1
            scale = x.abs().max()
            safe_scale = torch.where(scale > 0, scale, torch.ones_like(scale))
            scaled = x / safe_scale * levels
            x_quantized = scaled + (torch.round(scaled) - scaled).detach()
            return x_quantized, scale

        def dequantize(self, x_quantized, scale, num_bits):
            """Map levels produced by ``quantize`` back to the original value range."""
            return x_quantized * scale / (2 ** num_bits - 1)

        def forward(self, x):
            """Apply the layer to the last dimension of ``x`` (``..., input_dim`` -> ``..., output_dim``)."""
            # Flatten all leading dims so 2-D inputs work as well as (batch, seq, dim).
            leading_shape = x.shape[:-1]
            x_flattened = x.reshape(-1, x.shape[-1])

            A_quantized, scale_A = self.quantize(self.A, self.quantization_bits)
            B_quantized, scale_B = self.quantize(self.B, self.quantization_bits)
            # Recover A and B as levels * scale / (2**bits - 1). Dividing the levels by
            # the scale instead would be off by a factor of (2**bits - 1) / scale**2.
            A_dequantized = self.dequantize(A_quantized, scale_A, self.quantization_bits)
            B_dequantized = self.dequantize(B_quantized, scale_B, self.quantization_bits)

            lora_adjustment = self.alpha * (x_flattened @ A_dequantized) @ B_dequantized
            lora_adjustment = self.dropout(lora_adjustment)

            x_transformed = nn.functional.linear(x_flattened, self.weight, self.bias)

            out = (x_transformed + lora_adjustment).reshape(*leading_shape, self.weight.shape[0])
            return self.layer_norm(out)

        def update_alpha(self, new_alpha):
            """
            Update the alpha scaling factor.
            """
            self.alpha = new_alpha

    class MultiHeadAttention(nn.Module):
        """Multi-head dot-product attention.

        Inputs are split into ``heads`` slices of ``head_dim`` features. The same
        ``head_dim x head_dim`` projection is applied to every head's slice, so the
        Q/K/V weights are shared across heads. ``fc_out`` mixes the concatenated
        head outputs back to ``embed_size``.
        """

        def __init__(self, embed_size, heads):
            # Zero-argument super() binds to *this* class. The explicit
            # ``super(MultiHeadAttention, self)`` form looks the name up in module
            # globals. For a class nested inside another class, that name is either
            # undefined (NameError) or a different top-level class (TypeError).
            super().__init__()
            self.embed_size = embed_size
            self.heads = heads
            self.head_dim = embed_size // heads

            assert (
                self.head_dim * heads == embed_size
            ), "Embedding size needs to be divisible by heads"

            self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
            self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
            self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
            self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

        def forward(self, values, keys, query, mask):
            """Attend from ``query`` over ``keys``/``values``.

            Args:
                values, keys, query: ``(N, length, embed_size)`` tensors, as required
                    by the reshapes below. Key and value lengths must match.
                mask: ``None`` or a tensor broadcastable to
                    ``(N, heads, query_len, key_len)``. Positions where it equals 0
                    are excluded from attention.

            Returns:
                A ``(N, query_len, embed_size)`` tensor.
            """
            N = query.shape[0]
            value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

            # Split the embedding into self.heads pieces: (N, len, heads, head_dim)
            values = values.reshape(N, value_len, self.heads, self.head_dim)
            keys = keys.reshape(N, key_len, self.heads, self.head_dim)
            queries = query.reshape(N, query_len, self.heads, self.head_dim)

            values = self.values(values)
            keys = self.keys(keys)
            queries = self.queries(queries)

            # Raw scores per head: (N, heads, query_len, key_len)
            attention = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

            if mask is not None:
                attention = attention.masked_fill(mask == 0, float("-1e20"))

            # NOTE(review): scores are scaled by sqrt(embed_size) instead of the usual
            # sqrt(head_dim). Kept as-is because changing it would alter the outputs
            # of any weights already trained with this code.
            attention = torch.softmax(attention / (self.embed_size ** (1 / 2)), dim=3)

            out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
                N, query_len, self.heads * self.head_dim
            )

            out = self.fc_out(out)
            return out

    class TransformerBlock(nn.Module):
        """Post-norm transformer block: attention and a LoRA feed-forward, each with a residual.

        NOTE(review): ``MultiHeadAttention`` and ``LORALayer`` are looked up as
        module-level globals when the block is constructed, not as sibling nested
        classes. Confirm which definitions are intended.
        """

        def __init__(self, embed_size, heads, dropout, forward_expansion, rank):
            # Zero-argument super(): an explicit ``super(TransformerBlock, self)``
            # resolves the name through module globals. That is wrong, or undefined,
            # for a class nested inside another class.
            super().__init__()
            self.attention = MultiHeadAttention(embed_size, heads)
            self.norm1 = nn.LayerNorm(embed_size)
            self.norm2 = nn.LayerNorm(embed_size)

            self.feed_forward = nn.Sequential(
                LORALayer(embed_size, forward_expansion * embed_size, rank),
                nn.ReLU(),
                LORALayer(forward_expansion * embed_size, embed_size, rank),
            )

            self.dropout = nn.Dropout(dropout)

        def forward(self, value, key, query, mask):
            """Return ``dropout(norm2(ff(x) + x))`` where ``x = dropout(norm1(attention + query))``."""
            attention = self.attention(value, key, query, mask)
            x = self.dropout(self.norm1(attention + query))
            forward = self.feed_forward(x)
            out = self.dropout(self.norm2(forward + x))
            return out

    class LanguageModelDecoder(nn.Module):
        """Token and learned position embeddings, a stack of TransformerBlocks, and a vocabulary projection.

        BatchNorm1d is applied over the embedding channels before and after the
        block stack. When ``use_qlora`` is set, ``qlora_feed_forward`` runs after
        every block.
        """

        def __init__(self, vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, max_length, rank):
            # Zero-argument super(): the explicit two-argument form resolves the class
            # name via module globals, which is wrong for a nested class.
            super().__init__()
            self.word_embedding = nn.Embedding(vocab_size, embed_size)
            # max_length bounds the sequence length accepted by forward().
            self.position_embedding = nn.Embedding(max_length, embed_size)

            # Adding BatchNorm layers
            self.bn1 = nn.BatchNorm1d(embed_size)
            self.bn2 = nn.BatchNorm1d(embed_size)

            self.layers = nn.ModuleList(
                [
                    TransformerBlock(embed_size, heads, dropout, forward_expansion, rank)
                    for _ in range(num_layers)
                ]
            )

            # QLORA layers
            self.qlora_feed_forward = nn.Sequential(
                QLORALayer(embed_size, forward_expansion * embed_size, rank),
                nn.ReLU(),
                QLORALayer(forward_expansion * embed_size, embed_size, rank),
            )
            self.use_qlora = False  # Flag to toggle QLORA

            self.fc_out = nn.Linear(embed_size, vocab_size)
            self.dropout = nn.Dropout(dropout)

        def forward(self, x, trg_mask):
            """Map token ids ``x`` of shape ``(N, seq_length)`` to logits ``(N, seq_length, vocab_size)``.

            ``trg_mask`` is passed unchanged to every block's attention.
            """
            N, seq_length = x.shape
            # Create positions on the token ids' device. CPU positions make the
            # embedding lookup fail once the model lives on the GPU.
            positions = torch.arange(0, seq_length, device=x.device).expand(N, seq_length)
            x = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

            # BatchNorm1d expects (N, C, L): transpose, normalise, transpose back
            x = x.transpose(1, 2)
            x = self.bn1(x)
            x = x.transpose(1, 2)

            for layer in self.layers:
                x = layer(x, x, x, trg_mask)
                if self.use_qlora:
                    x = self.qlora_feed_forward(x)

            # BatchNorm1d expects (N, C, L): transpose, normalise, transpose back
            x = x.transpose(1, 2)
            x = self.bn2(x)
            x = x.transpose(1, 2)

            out = self.fc_out(x)
            return out

        def toggle_qlora(self, use_qlora: bool):
            """Enable or disable the QLoRA feed-forward applied after each block."""
            self.use_qlora = use_qlora

    class LanguageModelTransformer(nn.Module):
        """Causal language model: a ``LanguageModelDecoder`` driven by a lower-triangular target mask."""

        def __init__(self, vocab_size, embed_size=256, num_layers=6, forward_expansion=4, heads=8, dropout=0, max_length=100, rank=16):
            # Zero-argument super(): the explicit super(LanguageModelTransformer, self)
            # resolves the class name from module globals and breaks for a class
            # nested inside Expert.
            super().__init__()

            # NOTE(review): LanguageModelDecoder is resolved from module globals,
            # not from Expert's class namespace. Confirm which class is intended.
            self.decoder = LanguageModelDecoder(
                vocab_size,
                embed_size,
                num_layers,
                heads,
                forward_expansion,
                dropout,
                max_length,
                rank,
            )

        def forward(self, trg):
            """Return decoder logits for token ids ``trg`` of shape ``(N, trg_len)`` under a causal mask."""
            trg_mask = self.make_trg_mask(trg)
            return self.decoder(trg, trg_mask)

        def make_trg_mask(self, trg):
            """Build a float causal mask of shape ``(N, 1, trg_len, trg_len)`` on ``trg``'s device.

            Ones on and below the diagonal, zeros above, broadcast (as a view) over the batch.
            """
            N, trg_len = trg.shape
            causal = torch.tril(torch.ones((trg_len, trg_len), device=trg.device))
            return causal.expand(N, 1, trg_len, trg_len)

        # Function to enable or disable QLORA layers (for fine-tuning purposes)
        def toggle_qlora(self, use_qlora: bool):
            """Forward the QLoRA on/off flag to the decoder."""
            self.decoder.toggle_qlora(use_qlora)

        def generate_response(self, input_ids, attention_mask):
            """Greedily pick the most likely token at every position and join them into a string.

            This is a single forward pass, not autoregressive generation.

            Args:
                input_ids: Token ids of shape ``(N, seq_len)``.
                attention_mask: Accepted for interface compatibility but unused.
                    ``forward`` builds its own causal mask and takes no padding mask.

            Returns:
                The decoded string. All positions of all batch rows are
                concatenated in row-major order.

            NOTE(review): relies on a module-level ``tokenizer``.
            """
            with torch.no_grad():
                # forward() only accepts the token ids; passing keyword arguments
                # (input_ids=..., attention_mask=...) raised TypeError.
                logits = self(input_ids)

            # softmax is monotonic, so argmax over logits equals argmax over probabilities.
            predicted_token_ids = torch.argmax(logits, dim=-1)

            # Flatten to a list of ints. Calling .item() on each row of the
            # (N, seq_len) tensor only worked when seq_len == 1.
            token_ids = predicted_token_ids.reshape(-1).tolist()
            predicted_tokens = tokenizer.convert_ids_to_tokens(token_ids)

            return tokenizer.convert_tokens_to_string(predicted_tokens)

    class DPO(nn.Module):
        """Fine-tuning wrapper that scores a model's logits with binary cross-entropy.

        NOTE(review): despite the name, the objective is plain
        ``binary_cross_entropy_with_logits`` against ``labels``. There is no
        reference model and no chosen/rejected pair. Confirm this is intended.

        Args:
            model: Module whose output logits must be shape-compatible with ``labels``.
            device: Device each batch is moved to inside ``train_dpo``.
        """

        def __init__(self, model, device):
            # Zero-argument super(): the explicit super(DPO, self) resolves the class
            # name from module globals and breaks for a class nested inside Expert.
            super().__init__()
            self.model = model
            self.device = device

        def forward(self, input_ids, labels):
            """Return the mean BCE-with-logits loss of ``model(input_ids)`` against ``labels``."""
            logits = self.model(input_ids)
            # Labels may arrive as integer/bool tensors; BCE needs floating-point targets.
            return F.binary_cross_entropy_with_logits(logits, labels.float())

        def train_dpo(self, train_loader, optimizer):
            """Run one pass over ``train_loader``, stepping ``optimizer`` after every batch.

            Args:
                train_loader: Iterable of ``(input_ids, labels)`` batches.
                optimizer: Optimizer over the wrapped model's parameters.

            Returns:
                Mean per-batch loss as a Python float, or ``nan`` if the loader
                yields no batches.
            """
            self.model.train()
            running_loss = 0.0
            num_batches = 0
            for input_ids, labels in train_loader:
                input_ids = input_ids.to(self.device)
                labels = labels.to(self.device)

                optimizer.zero_grad()
                loss = self(input_ids, labels)
                loss.backward()
                optimizer.step()

                # .item() copies out a plain float so no autograd graph is retained.
                running_loss += loss.item()
                num_batches += 1

            return running_loss / num_batches if num_batches else float("nan")


    def __init__(self, config: ExpertConfig):
        """Assemble the expert pipeline from ``config``.

        The stages are sparse attention, LayerNorm + dropout, a switch router over
        sub-models whose weight paths come from ``config``, and a square QLoRA
        projection. ``config.validate()`` runs before any sub-module is built.

        Args:
            config: Hyper-parameters, sub-model weight paths and device.
        """
        super().__init__()
        config.validate()
        cfg = config

        # Stage 1: sparse attention.
        self.sparse_flash2_attention = Expert.SparseFlash2Attention(
            cfg.seq_len, cfg.head_dim, cfg.block_size, cfg.sparsity_factor
        )

        # Stage 2: normalisation and regularisation over input_dim features.
        self.layer_norm = nn.LayerNorm(cfg.input_dim)
        self.dropout = nn.Dropout(cfg.dropout)

        # Stage 3: switch router. Arguments are positional, so this tuple's
        # order must match SwitchRouter's constructor exactly.
        router_args = (
            cfg.input_dim,
            cfg.num_experts,
            cfg.mamba_model_path,
            cfg.context_encoder_path,
            cfg.language_model_path,
            cfg.question_encoder_path,
            cfg.dpo_model_path,
            cfg.vocab_size,
            cfg.embed_size,
            cfg.num_layers,
            cfg.forward_expansion,
            cfg.heads,
            cfg.dropout,
            cfg.max_length,
            cfg.rank,
            cfg.device,
        )
        self.switch_router = Expert.SwitchRouter(*router_args)

        # Stage 4: QLoRA projection input_dim -> input_dim.
        self.qlora = Expert.QLORALayer(cfg.input_dim, cfg.input_dim, cfg.rank)

    def forward(self, x, attention_mask, context_texts, question_text):
        """Run the four expert stages in order.

        Args:
            x: Input to the sparse attention stage.
            attention_mask, context_texts, question_text: Passed through unchanged
                to the switch router.

        Returns:
            Tuple ``(output, aux_loss)``. ``output`` is the router output after
            the QLoRA projection; ``aux_loss`` is the auxiliary loss returned by
            the router.
        """
        attended = self.sparse_flash2_attention(x)
        normed = self.dropout(self.layer_norm(attended))
        routed, aux_loss = self.switch_router(normed, attention_mask, context_texts, question_text)
        return self.qlora(routed), aux_loss

# Example usage: build an Expert from a default-constructed ExpertConfig.
config: ExpertConfig = ExpertConfig()
expert_model: Expert = Expert(config)
