In [32]:
# 0. Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from transformers import BertTokenizer
from torch.utils.data import DataLoader
import fitz  # PyMuPDF
from transformers import BertModel
import math
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

class ExpertConfig:
    def __init__(self, seq_len=512, head_dim=64, block_size=64, sparsity_factor=4,
                 input_dim=512, num_experts=3, vocab_size=30522, embed_size=256,
                 num_layers=6, forward_expansion=4, heads=8, dropout=0.1,
                 max_length=512,  # Updated to match seq_len for consistency
                 rank=16, device='cpu',
                 mamba_model_path='D:\\EXPERT_WEIGHTS\\mamba_model_weights.pth',
                 context_encoder_path='D:\\EXPERT_WEIGHTS\\context_encoder.pth',
                 language_model_path='D:\\EXPERT_WEIGHTS\\language_model.pth',
                 question_encoder_path='D:\\EXPERT_WEIGHTS\\question_encoder.pth',
                 dpo_model_path='D:\\EXPERT_WEIGHTS\\dpo_model_weights.pth',
                 model_name='bert-base-uncased', embedding_dim=768,
                 alpha=1, quantization_bits=8, tokenizer_name='bert-base-uncased',
                 d_model=512, d_state=2048, d_conv=3, expansion_factor=2):

        # Common hyperparameters
        self.seq_len = seq_len
        self.head_dim = head_dim
        self.block_size = block_size
        self.sparsity_factor = sparsity_factor
        self.input_dim = input_dim
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.num_layers = num_layers
        self.forward_expansion = forward_expansion
        self.heads = heads
        self.dropout = dropout
        self.max_length = max_length  # Ensure this is properly reflected in model components
        self.rank = rank

        # Model paths and device
        self.mamba_model_path = mamba_model_path
        self.context_encoder_path = context_encoder_path
        self.language_model_path = language_model_path
        self.question_encoder_path = question_encoder_path
        self.dpo_model_path = dpo_model_path
        self.device = device

        # Unique hyperparameters
        self.num_experts = num_experts
        self.model_name = model_name
        self.embedding_dim = embedding_dim
        self.alpha = alpha
        self.quantization_bits = quantization_bits
        self.tokenizer_name = tokenizer_name
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expansion_factor = expansion_factor

        # PDFs (unchanged)
        self.pdf_file_paths = [
            r'C:\Users\robbi\IEEMM\DPO.pdf', 
            r'C:\Users\robbi\IEEMM\MAMBA.pdf',
            r'C:\Users\robbi\IEEMM\QLORA.pdf',
            r'C:\Users\robbi\IEEMM\RAG.pdf',
            r'C:\Users\robbi\IEEMM\SWITCH_TRANSFORMER.pdf'
        ]
        
        # Preserving original dataset loading functionality
        self.rag_dataset = Expert.TransformerRAG.create_dataset_from_pdfs(self.pdf_file_paths)

    def validate(self):
        assert self.seq_len % self.block_size == 0, "seq_len must be divisible by block_size"
        assert self.max_length >= self.seq_len, "max_length should be equal to or greater than seq_len"



class Expert(nn.Module):

    @staticmethod
    # Load model weights function
    def load_model_weights(model, model_path):
        #checkpoint = torch.load(model_path, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
        checkpoint = torch.load(model_path, map_location=config.device)

        if isinstance(checkpoint, dict):
            # Check for 'state_dict' or 'model_state_dict' keys
            if 'state_dict' in checkpoint:
                model.load_state_dict(checkpoint['state_dict'])
            elif 'model_state_dict' in checkpoint:
                model.load_state_dict(checkpoint['model_state_dict'])
            else:
                # If no known key is found, try loading it as a raw state dictionary
                try:
                    model.load_state_dict(checkpoint)
                except RuntimeError as e:
                    raise ValueError(f"Error loading state dict: {e}")
        elif isinstance(checkpoint, nn.Module):
            # If the checkpoint is a model object, assign it directly
            model = checkpoint
        else:
            raise ValueError(f"Unsupported checkpoint format: {type(checkpoint)}")

        model.eval()
        return model

    # Flash2_Attention
    class FlashAttention2(nn.Module):
        def __init__(self, sequence_length, head_dimension, block_size):
            super(Expert.FlashAttention2, self).__init__()
            self.block_size = block_size
            # Ensure that sequence_length is divisible by block_size for simplicity
            assert sequence_length % block_size == 0

        def forward(self, Q, K, V):
            # Partitioning of inputs
            Q_blocks, K_blocks, V_blocks = self.partition_inputs(Q, K, V)

            # Efficient computation of the attention mechanism
            outputs = []
            for i, Q_block in enumerate(Q_blocks):
                output_block = self.process_block(Q_block, K_blocks, V_blocks)
                outputs.append(output_block)

            # Concatenating the processed blocks
            output = torch.cat(outputs, dim=0)
            return output

        def partition_inputs(self, Q, K, V):
            # The actual partitioning scheme should be based on sequence length, head dimension, and block size
            Q_blocks = Q.chunk(chunks=Q.size(0) // self.block_size, dim=0)
            K_blocks = K.chunk(chunks=K.size(0) // self.block_size, dim=0)
            V_blocks = V.chunk(chunks=V.size(0) // self.block_size, dim=0)
            return Q_blocks, K_blocks, V_blocks

        def process_block(self, Q_block, K_blocks, V_blocks):
            # Process each block efficiently as per FLASH2's optimized method
            # This includes computing QK^T, applying online softmax, and multiplying with V
            output_blocks = []
            for K_block, V_block in zip(K_blocks, V_blocks):
                attention_scores = torch.matmul(Q_block, K_block.transpose(-2, -1))
                attention_scores = self.online_softmax(attention_scores)
                output_block = torch.matmul(attention_scores, V_block)
                output_blocks.append(output_block)

            # Summing up the results from each block
            output_block_sum = sum(output_blocks)
            return output_block_sum

        def online_softmax(self, scores, chunk_size=128):
            # Apply softmax in chunks for large sequences
            softmaxed_scores = []
            for i in range(0, scores.size(0), chunk_size):
                chunk = scores[i:i + chunk_size, :]
                softmaxed_chunk = F.softmax(chunk, dim=1)
                softmaxed_scores.append(softmaxed_chunk)
            return torch.cat(softmaxed_scores, dim=0)

    # SparseFlash2_Attention
    class SparseFlash2Attention(nn.Module):
        def __init__(self, seq_len, head_dim, blk_size, sparsity_factor):
            super().__init__()
            self.flash_attention = Expert.FlashAttention2(seq_len, head_dim, blk_size)
            self.seq_len = seq_len
            self.head_dim = head_dim
            self.block_size = blk_size  # Storing block_size as an instance variable
            self.sparsity_factor = sparsity_factor

        def generate_sparsity_mask(self):
            mask = torch.zeros(self.seq_len, self.seq_len)
            step = self.sparsity_factor
            for i in range(0, self.seq_len, step):
                mask[i:i + step, :] = 1
            return mask.bool()

        def forward(self, Q, K, V):
            output = self.flash_attention(Q, K, V)  # output shape: [sequence_length, head_dimension]

            # Reshape output to be 3D for batch matrix multiplication
            output = output.unsqueeze(0)  # New shape: [1, sequence_length, head_dimension]

            sparsity_mask = self.generate_sparsity_mask()  # shape: [sequence_length, sequence_length]

            # Apply the sparsity mask to the output
            sparsity_mask = sparsity_mask.unsqueeze(0)  # New shape: [1, sequence_length, sequence_length]
            output = torch.bmm(sparsity_mask.float(), output.float())  # Perform batch matrix multiplication

            # Reshape the output back to 2D
            output = output.squeeze(0)  # New shape: [sequence_length, head_dimension]

            return output


    # MAMBA
    # RoPE
    class RotaryPositionalEncoding(nn.Module):
        def __init__(self, dim, max_len=5000):
            super().__init__()
            self.dim = dim
            inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
            t = torch.arange(max_len).type_as(inv_freq)
            freqs = torch.einsum('n , d -> n d', t, inv_freq)
            self.register_buffer('sin', freqs.sin())
            self.register_buffer('cos', freqs.cos())

        def forward(self, x):
            n, _, device = x.shape[1], self.dim // 2, x.device
            sin, cos = self.sin[:n].to(device), self.cos[:n].to(device)

            # Apply RoPE to even and odd indices separately
            x_even = x[..., :self.dim:2] * cos.unsqueeze(0) + torch.roll(x[..., 1:self.dim:2], shifts=1, dims=-1) * sin.unsqueeze(0)
            x_odd = x[..., 1:self.dim:2] * cos.unsqueeze(0) - torch.roll(x[..., :self.dim:2], shifts=1, dims=-1) * sin.unsqueeze(0)
            return torch.cat((x_even, x_odd), dim=-1)

    # SWIGLU
    class SwiGLU(nn.Module):
        def __init__(self, dim_in, dim_out):
            super(Expert.SwiGLU, self).__init__()
            self.fc1 = nn.Linear(dim_in, dim_out)
            self.fc2 = nn.Linear(dim_in, dim_out)

        def forward(self, x):
            gate = torch.sigmoid(self.fc2(x))
            return self.fc1(x) * gate

    class SimplifiedMAMBA(nn.Module):
        def __init__(self, config):
            super().__init__()
            # Using configuration parameters
            self.num_layers = config.num_layers
            self.d_model = config.d_model
            self.d_state = config.d_state
            self.d_conv = config.d_conv
            self.expansion_factor = config.expansion_factor
            
            self.feedforward = nn.Sequential(
                nn.Linear(self.d_model, self.d_state),
                nn.GELU(),
                nn.Linear(self.d_state, self.d_model)
            )
            self.input_embedding = nn.Linear(self.d_model, self.d_model)
            self.convs = nn.Sequential(*[nn.Conv1d(self.d_model, self.d_model, kernel_size=self.d_conv, padding=(self.d_conv // 2)) for _ in range(self.num_layers)])
            self.swiglu = Expert.SwiGLU(self.d_model, self.d_model)
            self.output_projection = nn.Linear(self.d_model, self.d_model * self.expansion_factor)

            self.initialize_weights()

        def initialize_weights(self):
            gain = nn.init.calculate_gain('relu')

            nn.init.orthogonal_(self.input_embedding.weight, gain)
            nn.init.normal_(self.input_embedding.bias, mean=0, std=0.01)

            nn.init.kaiming_uniform_(self.convs[-1].weight, a=math.sqrt(5))
            nn.init.zeros_(self.convs[-1].bias)

            nn.init.xavier_uniform_(self.feedforward[0].weight, gain=nn.init.calculate_gain('relu'))
            nn.init.zeros_(self.feedforward[0].bias)

            nn.init.xavier_uniform_(self.feedforward[2].weight, gain=nn.init.calculate_gain('linear'))
            nn.init.zeros_(self.feedforward[2].bias)

            nn.init.xavier_uniform_(self.output_projection.weight, gain=nn.init.calculate_gain('linear'))
            nn.init.zeros_(self.output_projection.bias)

        def forward(self, inputs, attention_mask=None):
            print("Input shape:", inputs.shape)

            # Apply the attention mask if provided
            if attention_mask is not None:
                inputs = inputs * attention_mask.unsqueeze(-1)

            projected_inputs = self.input_embedding(inputs)
            print("projected_inputs pre-reshape shape:", projected_inputs.shape)

            projected_inputs = projected_inputs.permute(0, 2, 1)
            print("projected_inputs post-reshape shape:", projected_inputs.shape)

            for conv in self.convs:
                projected_inputs = conv(projected_inputs)

            projected_inputs = projected_inputs.permute(0, 2, 1)
            print("projected_inputs post convolution reshape:", projected_inputs.shape)

            projected_inputs = self.swiglu(projected_inputs)
            print("projected_inputs post swiglu shape:", projected_inputs.shape)

            output = self.output_projection(projected_inputs)
            print("output shape:", output.shape)

            return output

    class SimplifiedLanguageModelMAMBA(nn.Module):
        def __init__(self, config):
            super().__init__()
            # Using configuration parameters
            self.vocab_size = config.vocab_size
            self.num_layers = config.num_layers
            self.d_model = config.d_model
            self.d_state = config.d_state
            self.d_conv = config.d_conv
            self.expansion_factor = config.expansion_factor

            self.embedding = nn.Embedding(self.vocab_size, self.d_model)
            self.pos_encoder = Expert.RotaryPositionalEncoding(self.d_model)
            self.simplified_mamba = Expert.SimplifiedMAMBA(config)
            self.output_projection = nn.Linear(self.d_model * 2, self.vocab_size)  # Adjust if needed

            self.initialize_weights()

        def initialize_weights(self):
            gain = 1.0

            nn.init.orthogonal_(self.embedding.weight, gain)
            nn.init.xavier_uniform_(self.output_projection.weight, gain=gain)

        def forward(self, input_values, attention_mask):
            embedded = self.embedding(input_values) * math.sqrt(self.d_model)
            embedded = self.pos_encoder(embedded)
            simplified_mamba_output = self.simplified_mamba(embedded, attention_mask)
            logits = self.output_projection(simplified_mamba_output)
            return logits


    class SwitchRouter(nn.Module):
        CAPACITY_FACTOR = 1  # Class constant

        class SwitchGate(nn.Module):
            def __init__(self, input_dim, num_experts):
                super(Expert.SwitchRouter.SwitchGate, self).__init__()
                self.fc1 = nn.Linear(input_dim, input_dim // 2)
                self.fc2 = nn.Linear(input_dim // 2, num_experts)

            def forward(self, x):
                x = F.relu(self.fc1(x))
                gate_scores = F.softmax(self.fc2(x), dim=-1)
                return gate_scores

        def __init__(self, config):
            super(Expert.SwitchRouter, self).__init__()
            self.config = config
            self.device = config.device
            self.router = self.SwitchGate(config.input_dim, config.num_experts).to(config.device)
            self.transformer_rag = Expert.TransformerRAG(config).to(config.device)
            self.lmt = Expert.LanguageModelTransformer(config).to(config.device)
            self.transformer_dpo = Expert.DPO(self.lmt, config.device).to(config.device)
            self.mamba = Expert.SimplifiedLanguageModelMAMBA(config).to(config.device)
            self.experts = nn.ModuleList([self.transformer_rag, self.transformer_dpo, self.mamba])
            self.input_embedding = nn.Linear(config.input_dim, config.input_dim)

        def forward(self, x, attention_mask, context_texts, question_text):
            x = x.to(self.device)
            if x.dtype != torch.float32:
                x = x.float()

            x = self.input_embedding(x)
            gate_scores = self.router(x)
            expert_indices = torch.argmax(gate_scores, dim=1)
            final_output = torch.zeros_like(x)

            for i, expert in enumerate(self.experts):
                mask = expert_indices == i
                if mask.any():
                    selected_inputs = x[mask]
                    selected_attention_mask = attention_mask[mask].to(self.device)
                    if isinstance(expert, Expert.TransformerRAG):
                        expert_output = expert.forward(context_texts, selected_inputs, selected_attention_mask, question_text)
                    else:
                        expert_output = expert(selected_inputs, selected_attention_mask)
                    final_output[mask] = expert_output

            aux_loss = self.auxiliary_loss(gate_scores)
            return final_output, aux_loss

        def auxiliary_loss(self, gate_scores):
            expert_load = gate_scores.sum(0) / gate_scores.size(0)
            loss_balancing = torch.std(expert_load)
            return loss_balancing


        @staticmethod
        def route_inputs(expert_indices, gate_scores, num_experts):
            capacity_factor_tensor = torch.tensor([Expert.SwitchRouter.CAPACITY_FACTOR], dtype=torch.float32)
            capacities = (gate_scores.size(0) * capacity_factor_tensor / num_experts).int()
            expert_counts = torch.zeros(num_experts, dtype=torch.int32)
            for idx in range(len(expert_indices)):
                selected_expert = expert_indices[idx]
                if expert_counts[selected_expert] < capacities[0]:
                    expert_counts[selected_expert] += 1
                else:
                    available_experts = (expert_counts < capacities[0]).nonzero(as_tuple=False).view(-1)
                    if len(available_experts) > 0:
                        alternative_expert = available_experts[0]
                        expert_indices[idx] = alternative_expert
                        expert_counts[alternative_expert] += 1
                    else:
                        print("No available experts to reroute. Handling overflow.")
            return expert_indices
    
    class CustomDPRContextEncoder(nn.Module):
        def __init__(self, embedding_dim):
            super(Expert.CustomDPRContextEncoder, self).__init__()  
            self.bert_model = BertModel.from_pretrained('bert-base-uncased')
            self.embedding_layer = nn.Linear(self.bert_model.config.hidden_size, embedding_dim)

        def forward(self, input_ids, attention_mask=None):
            outputs = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)
            pooled_output = outputs.pooler_output
            context_embeddings = self.embedding_layer(pooled_output)
            return context_embeddings

    class DPRQuestionEncoder(nn.Module):
        def __init__(self, embedding_dim):
            super(Expert.DPRQuestionEncoder, self).__init__() 
            self.bert_model = BertModel.from_pretrained('bert-base-uncased')
            self.embedding_layer = nn.Linear(self.bert_model.config.hidden_size, embedding_dim)

        def forward(self, input_ids, attention_mask, **kwargs):
            outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
            pooled_output = outputs.pooler_output
            embeddings = self.embedding_layer(pooled_output)
            return embeddings

    class TransformerRAG(nn.Module):
        def __init__(self, config):
            super(Expert.TransformerRAG, self).__init__()
            self.config = config
            self.context_encoder = Expert.CustomDPRContextEncoder(config.embedding_dim).to(config.device)
            self.language_model = Expert.LanguageModelTransformer(config).to(config.device)
            self.question_encoder = Expert.DPRQuestionEncoder(config.embedding_dim).to(config.device)
            self.tokenizer = BertTokenizer.from_pretrained(config.tokenizer_name)


        def forward(self, context_texts, question_input_ids, question_attention_mask, question_text):
            if question_input_ids.max() >= self.tokenizer.vocab_size:
                raise ValueError("question_input_ids contain ID(s) beyond the tokenizer's vocabulary size")
            
            aggregated_context_embeddings = []
            for context_list in context_texts:
                if not all(isinstance(context, dict) for context in context_list):
                    raise TypeError("Each item in context_texts must be a list of tokenized context dictionaries")
                
                aggregated_context_embedding = torch.zeros(self.context_encoder.bert_model.config.hidden_size, device=device)
                for context in context_list:
                    context_input_ids = context['input_ids'].to(config.device)
                    context_attention_mask = context['attention_mask'].to(config.device)
                    context_embedding = self.context_encoder(context_input_ids, context_attention_mask)
                    aggregated_context_embedding += context_embedding.mean(dim=0)
                
                aggregated_context_embeddings.append(aggregated_context_embedding / len(context_list))
            
            question_input_ids = question_input_ids.to(config.device).long()
            question_attention_mask = question_attention_mask.to(config.device).long()
            question_embeddings = self.question_encoder(input_ids=question_input_ids, attention_mask=question_attention_mask)
            
            cos_sim = torch.nn.CosineSimilarity(dim=1)
            similarities = [cos_sim(question_embeddings, context_emb.squeeze(0)) for context_emb in aggregated_context_embeddings]
            most_relevant_context_idx = torch.argmax(torch.tensor(similarities, device=config.device))
            
            combined_input = question_text + " " + context_texts[most_relevant_context_idx]
            tokenized_combined_input = self.tokenizer(combined_input, return_tensors="pt", padding=True, truncation=True, max_length=512)
            tokenized_combined_input = {k: v.to(config.device) for k, v in tokenized_combined_input.items()}
            response_logits = self.language_model(**tokenized_combined_input)
            probabilities = F.softmax(response_logits.logits, dim=-1)
            predicted_token_ids = torch.argmax(probabilities, dim=-1)
            predicted_tokens = self.tokenizer.convert_ids_to_tokens(predicted_token_ids[0])
            response = self.tokenizer.convert_tokens_to_string(predicted_tokens)
            
            return response

        @staticmethod
        def extract_text_from_pdf(file_path):
            text = ""
            with fitz.open(file_path) as doc:
                for page in doc:
                    text += page.get_text()
            return text

        @staticmethod
        def split_into_chunks(text, chunk_size):
            return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]

        @staticmethod
        def preprocess_text(text, max_length=512):
            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            chunk_size = max_length - 50
            text_chunks = Expert.TransformerRAG.split_into_chunks(text, chunk_size)
            processed_chunks = []
            for chunk in text_chunks:
                tokenized_output = tokenizer(chunk, padding='max_length', truncation=True, max_length=max_length, return_tensors="pt")
                processed_chunk = {
                    'input_ids': tokenized_output['input_ids'],
                    'attention_mask': tokenized_output['attention_mask']
                }
                processed_chunks.append(processed_chunk)
            return processed_chunks

        @staticmethod
        def create_dataset_from_pdfs(pdf_file_paths):
            dataset = []
            for file_path in pdf_file_paths:
                text = Expert.TransformerRAG.extract_text_from_pdf(file_path)
                processed_text = Expert.TransformerRAG.preprocess_text(text)
                dataset.append(processed_text)
            return dataset

        def retrieve_contexts(self, dataset, query_embedding, top_k=5):
            similarity_scores = []
            for context in dataset:
                context_input_ids = context['input_ids'].to(self.config.device)
                context_attention_mask = context['attention_mask'].to(self.config.device)
                # Use the class's context_encoder
                context_embedding = self.context_encoder(context_input_ids, context_attention_mask)
                similarity = torch.matmul(query_embedding, context_embedding.T)
                similarity_scores.append(similarity.squeeze().item())
            top_k_indices = sorted(range(len(similarity_scores)), key=lambda i: similarity_scores[i], reverse=True)[:top_k]
            top_contexts = [dataset[i] for i in top_k_indices]
            return top_contexts

        def rag_retrieve_and_generate(self, dataset, query):
            # Tokenize the query using the class's tokenizer
            tokenized_query = self.tokenizer(query, return_tensors="pt", padding=True, truncation=True, max_length=512).to(self.config.device)
            input_ids = tokenized_query['input_ids']
            attention_mask = tokenized_query['attention_mask']
            # Use the class's question_encoder
            encoded_query = self.question_encoder(input_ids, attention_mask)
            relevant_contexts = self.retrieve_contexts(dataset, encoded_query)
            # Assuming generate_response is a method of LanguageModelTransformer that accepts tokenized contexts and generates a response
            response = self.language_model.generate_response(relevant_contexts)
            return response




    class LORALayer(nn.Module):
        def __init__(self, input_dim, output_dim, rank, alpha=1):
            super(Expert.LORALayer, self).__init__()
            self.rank = rank
            self.alpha = alpha

            # Original weight and bias of the linear layer
            self.weight = nn.Parameter(torch.Tensor(output_dim, input_dim))
            #print("self.weight Shape:", self.weight.shape)
            self.bias = nn.Parameter(torch.Tensor(output_dim))

            # LORA specific parameters
            self.A = nn.Parameter(torch.Tensor(input_dim, rank))
            self.B = nn.Parameter(torch.Tensor(rank, output_dim))

            self.reset_parameters()

        def reset_parameters(self):
            nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
            nn.init.zeros_(self.bias)
            nn.init.normal_(self.A, 0, 0.02)
            nn.init.normal_(self.B, 0, 0.02)

        def forward(self, x):
            #print("LORALayer Input Shape:", x.shape)
            
            original_size = x.size()
            batch_size, seq_len, _ = x.shape
            x_flattened = x.reshape(-1, original_size[-1])

            # Compute lora_adjustment for each input in the batch
            lora_adjustment = self.alpha * (x_flattened @ self.A) @ self.B
            lora_adjustment = lora_adjustment.reshape(batch_size, seq_len, -1)

            #print("Adjusted lora_adjustment Shape:", lora_adjustment.shape)
            #print("self.weight Shape:", self.weight.shape)
            
            # Apply linear transformation to x_flattened
            x_transformed = nn.functional.linear(x_flattened, self.weight, self.bias)
            x_transformed = x_transformed.reshape(batch_size, seq_len, -1)

            # Add lora_adjustment to the transformed x
            x = x_transformed + lora_adjustment
            #print("LORALayer Output Shape:", x.shape)

            return x

    class QLORALayer(nn.Module):
        def __init__(self, input_dim, output_dim, rank, alpha=1, quantization_bits=8):
            super(Expert.QLORALayer, self).__init__()
            self.rank = rank
            self.alpha = alpha
            self.quantization_bits = quantization_bits

            # Original weight and bias
            self.weight = nn.Parameter(torch.Tensor(output_dim, input_dim))
            self.bias = nn.Parameter(torch.Tensor(output_dim))

            # QLORA specific parameters
            self.A = nn.Parameter(torch.Tensor(input_dim, rank))
            self.B = nn.Parameter(torch.Tensor(rank, output_dim))

            self.dropout = nn.Dropout(0.1)
            self.layer_norm = nn.LayerNorm(output_dim)

            self.reset_parameters()

        def reset_parameters(self):
            nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
            nn.init.zeros_(self.bias)
            nn.init.normal_(self.A, 0, 0.02)
            nn.init.normal_(self.B, 0, 0.02)

        def quantize(self, x, num_bits):
            # Implement a simple quantization method
            scale = x.abs().max()
            x_quantized = torch.round(x / scale * (2**num_bits - 1))
            return x_quantized, scale

        def forward(self, x):
            #print("QLORALayer Input Shape:", x.shape)
            original_size = x.size()
            batch_size, seq_len, _ = x.shape
            x_flattened = x.reshape(-1, original_size[-1])

            A_quantized, scale_A = self.quantize(self.A, self.quantization_bits)
            B_quantized, scale_B = self.quantize(self.B, self.quantization_bits)

            # Compute lora_adjustment for each input in the batch
            lora_adjustment = self.alpha * (x_flattened @ (A_quantized / scale_A)) @ (B_quantized / scale_B)
            lora_adjustment = lora_adjustment.reshape(batch_size, seq_len, -1)
            lora_adjustment = self.dropout(lora_adjustment)
            #print("Adjusted lora_adjustment Shape:", lora_adjustment.shape)
            #print("self.weight Shape:", self.weight.shape)

            # Apply linear transformation to x_flattened
            x_transformed = nn.functional.linear(x_flattened, self.weight, self.bias)
            x_transformed = x_transformed.reshape(batch_size, seq_len, -1)

            # Add lora_adjustment to the transformed x
            x = x_transformed + lora_adjustment
            x = self.layer_norm(x)

            #print("QLORALayer Output Shape:", x.shape)

            return x
        
        def update_alpha(self, new_alpha):
            """
            Update the alpha scaling factor.
            """
            self.alpha = new_alpha

    class MultiHeadAttention(nn.Module):
        def __init__(self, embed_size, heads):
            super(Expert.MultiHeadAttention, self).__init__()
            self.embed_size = embed_size
            self.heads = heads
            self.head_dim = embed_size // heads

            assert (
                self.head_dim * heads == embed_size
            ), "Embedding size needs to be divisible by heads"

            self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
            self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
            self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
            self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

        def forward(self, values, keys, query, mask):
            N = query.shape[0]
            value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

            # Split the embedding into self.heads different pieces
            values = values.reshape(N, value_len, self.heads, self.head_dim)
            keys = keys.reshape(N, key_len, self.heads, self.head_dim)
            queries = query.reshape(N, query_len, self.heads, self.head_dim)

            values = self.values(values)
            keys = self.keys(keys)
            queries = self.queries(queries)

            # Einsum does the matrix multiplication for query*keys for each training example
            # with every other training example, then sum it up
            attention = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

            if mask is not None:
                attention = attention.masked_fill(mask == 0, float("-1e20"))

            attention = torch.softmax(attention / (self.embed_size ** (1 / 2)), dim=3)

            out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
                N, query_len, self.heads * self.head_dim
            )

            out = self.fc_out(out)
            return out

    class TransformerBlock(nn.Module):
        def __init__(self, embed_size, heads, dropout, forward_expansion, rank):
            super(Expert.TransformerBlock, self).__init__()
            self.attention = Expert.MultiHeadAttention(embed_size, heads)
            self.norm1 = nn.LayerNorm(embed_size)
            self.norm2 = nn.LayerNorm(embed_size)

            self.feed_forward = nn.Sequential(
                Expert.LORALayer(embed_size, forward_expansion * embed_size, rank),
                nn.ReLU(),
                Expert.LORALayer(forward_expansion * embed_size, embed_size, rank),
            )

            self.dropout = nn.Dropout(dropout)

        def forward(self, value, key, query, mask):
            attention = self.attention(value, key, query, mask)
            x = self.dropout(self.norm1(attention + query))
            forward = self.feed_forward(x)
            out = self.dropout(self.norm2(forward + x))
            return out

    class LanguageModelDecoder(nn.Module):
        def __init__(self, vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, max_length, rank):
            super(Expert.LanguageModelDecoder, self).__init__()
            self.word_embedding = nn.Embedding(vocab_size, embed_size)
            self.position_embedding = nn.Embedding(max_length, embed_size)

            # Adding BatchNorm layers
            self.bn1 = nn.BatchNorm1d(embed_size)
            self.bn2 = nn.BatchNorm1d(embed_size)

            self.layers = nn.ModuleList(
                [
                    Expert.TransformerBlock(embed_size, heads, dropout, forward_expansion, rank)
                    for _ in range(num_layers)
                ]
            )

            # QLORA layers
            self.qlora_feed_forward = nn.Sequential(
                Expert.QLORALayer(embed_size, forward_expansion * embed_size, rank),
                nn.ReLU(),
                Expert.QLORALayer(forward_expansion * embed_size, embed_size, rank),
            )
            self.use_qlora = False  # Flag to toggle QLORA

            self.fc_out = nn.Linear(embed_size, vocab_size)
            self.dropout = nn.Dropout(dropout)

        def forward(self, x, trg_mask):
            N, seq_length = x.shape
            if seq_length > self.position_embedding.num_embeddings:
                raise ValueError(f"Sequence length {seq_length} exceeds maximum allowed {self.position_embedding.num_embeddings}")
            positions = torch.arange(0, seq_length).expand(N, seq_length).to(config.device)
            x = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

            print(f"Max position index: {positions.max().item()}, Position Embedding Size: {self.position_embedding.num_embeddings}")


            # Transpose for BatchNorm, apply batch normalization, and then transpose back
            x = x.transpose(1, 2)
            x = self.bn1(x)
            x = x.transpose(1, 2)

            for layer in self.layers:
                x = layer(x, x, x, trg_mask)
                if self.use_qlora:
                    x = self.qlora_feed_forward(x)

            # Transpose for BatchNorm, apply batch normalization, and then transpose back
            x = x.transpose(1, 2)
            x = self.bn2(x)
            x = x.transpose(1, 2)

            out = self.fc_out(x)
            #print(f"shape of output of forward method of LanguageModelDecoder: {out.shape} ")

            return out

        def toggle_qlora(self, use_qlora: bool):
            self.use_qlora = use_qlora

    class LanguageModelTransformer(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.vocab_size = config.vocab_size  # Use vocab_size from ExpertConfig
            self.embed_size = config.embed_size  # Use embed_size from ExpertConfig
            self.num_layers = config.num_layers  # Use num_layers from ExpertConfig
            self.forward_expansion = config.forward_expansion  # Use forward_expansion from ExpertConfig
            self.heads = config.heads  # Use heads from ExpertConfig
            self.dropout = config.dropout  # Use dropout from ExpertConfig
            self.max_length = config.max_length  # Use max_length from ExpertConfig
            self.rank = config.rank  # Use rank from ExpertConfig
            self.tokenizer_name = config.tokenizer_name  # Use tokenizer_name from ExpertConfig

            self.tokenizer = BertTokenizer.from_pretrained(self.tokenizer_name)

            self.decoder = Expert.LanguageModelDecoder(
                self.vocab_size,
                self.embed_size,
                self.num_layers,
                self.heads,
                self.forward_expansion,
                self.dropout,
                self.max_length,
                self.rank,
            )

        def forward(self, trg, attention_mask=None):  # Remove attention_mask here since it's not used
            print(f"Input shape to LanguageModelTransformer: {trg.shape}")
            trg_mask = self.make_trg_mask(trg)
            out = self.decoder(trg, trg_mask)  # Do not pass attention_mask here
            return out

        def make_trg_mask(self, trg):
            N, trg_len = trg.shape
            trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(N, 1, trg_len, trg_len).to(trg.device)
            return trg_mask

        def toggle_qlora(self, use_qlora: bool):
            self.decoder.toggle_qlora(use_qlora)

        def generate_response(self, input_ids, attention_mask):
            logits = self.forward(input_ids=input_ids, attention_mask=attention_mask)
            probabilities = F.softmax(logits, dim=-1)
            predicted_token_id = torch.argmax(probabilities, dim=-1)
            predicted_tokens = [self.tokenizer.convert_ids_to_tokens(idx.item()) for idx in predicted_token_id]
            response = self.tokenizer.convert_tokens_to_string(predicted_tokens)
            return response


    class DPO(nn.Module):
        def __init__(self, language_model, device):
            super(Expert.DPO, self).__init__()
            self.language_model = language_model
            self.device = device
        
        def forward(self, input_ids_question, input_ids_chosen=None, input_ids_rejected=None, labels=None, attention_mask=None):
            # Concatenate for DPO-specific data or use regular input handling
            combined_input_ids = torch.cat([input_ids_question, input_ids_chosen, input_ids_rejected], dim=1) if input_ids_chosen is not None and input_ids_rejected is not None else input_ids_question
            
            # Use the vocab_size from the language model
            combined_input_ids = torch.clamp(combined_input_ids, 0, self.language_model.vocab_size - 1)
            
            # Pass through the language model
            output = self.language_model(combined_input_ids) #, attention_mask=attention_mask)
            
            # Calculate loss if labels are provided
            loss = None
            if labels is not None:
                logits = output.logits if hasattr(output, 'logits') else output
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(logits.view(-1), labels.float().view(-1))
            
            return output, loss




    def __init__(self, config: ExpertConfig):
        super().__init__()
        # Validate configuration
        config.validate()

        # 1. SparseFlash2_attention
        self.sparse_flash2_attention = Expert.SparseFlash2Attention(
            config.seq_len, 
            config.head_dim, 
            config.block_size, 
            config.sparsity_factor
        )

        # Inside the Expert class definition
        self.lmt = Expert.LanguageModelTransformer(config).to(config.device)

        self.transformer_rag = Expert.TransformerRAG(config).to(config.device)

        # Corrected instantiation of SimplifiedLanguageModelMAMBA
        self.mamba = Expert.SimplifiedLanguageModelMAMBA(config).to(config.device)

        # Now initialize DPO with the language model transformer
        self.transformer_dpo = Expert.DPO(self.lmt, config.device).to(config.device)

        # 2. LayerNorm and Dropout
        self.layer_norm = nn.LayerNorm(config.input_dim)
        self.dropout = nn.Dropout(config.dropout)

        # 3. Internal Switch Routing
        self.switch_router = Expert.SwitchRouter(config)

        # 4. QLORA Layer
        self.qlora = Expert.QLORALayer(config.input_dim, config.input_dim, config.rank)


    def forward(self, x, attention_mask, context_texts, question_text):
        # 1. SparseFlash2_attention
        x = self.sparse_flash2_attention(x)

        # 2. LayerNorm and Dropout
        x = self.dropout(self.layer_norm(x))

        # 3. Internal Switch Routing
        x, aux_loss = self.switch_router(x, attention_mask, context_texts, question_text)

        # 4. QLORA Layer
        x = self.qlora(x)

        return x, aux_loss

    @staticmethod
    def calculate_new_alpha(current_loss, initial_loss, initial_alpha=1.0, final_alpha=0.1):
        """
        Calculate a new alpha value based on the current loss.
        """
        if current_loss >= initial_loss:
            return initial_alpha  # Keep initial alpha if loss isn't decreasing

        loss_ratio = current_loss / initial_loss
        alpha_range = initial_alpha - final_alpha
        new_alpha = final_alpha + (alpha_range * loss_ratio)
        return new_alpha  
    

    def train_dpo(self, train_loader, optimizer, label_column, input_columns, config, save_path):
        self.transformer_dpo.train()  # Set the model to training mode
        total_loss = 0

        for i, batch in enumerate(train_loader):
            print(f"Batch {i+1}:")
            input_ids_question = batch['input_ids_question'].to(config.device)
            input_ids_chosen = batch['input_ids_chosen'].to(config.device)
            input_ids_rejected = batch['input_ids_rejected'].to(config.device)
            labels = batch['labels'].to(config.device)

            optimizer.zero_grad()

            # Combine the input ids
            combined_input_ids = torch.cat([input_ids_question, input_ids_chosen, input_ids_rejected], dim=1)
            
            # Truncate the combined_input_ids to the maximum sequence length
            max_length = config.max_length  # 512
            if combined_input_ids.size(1) > max_length:
                combined_input_ids = combined_input_ids[:, :max_length]

            # Forward pass
            logits, loss = self.transformer_dpo(input_ids_question=combined_input_ids, input_ids_chosen=None, input_ids_rejected=None, labels=labels)

            if loss is not None:
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
            else:
                print(f"No loss to backpropagate for batch {i+1}")

        average_loss = total_loss / len(train_loader)
        torch.save(self.transformer_dpo.state_dict(), save_path)
        return average_loss

    




    def train_language_model_rag(self, model, save_path, train_loader, device, num_epochs=5, lr=1e-8, weight_decay=1e-4):
        # Enable QLORA during training
        model.decoder.toggle_qlora(True)

        # Define loss function and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
        scheduler = StepLR(optimizer, step_size=4, gamma=0.98)

        initial_loss = None

        # Training loop
        for epoch in range(num_epochs):
            model.train()
            total_loss = 0

            for batch_idx, batch in enumerate(train_loader):
                inputs = batch['input_ids'].to(device)
                targets = batch['labels'].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs.view(-1, outputs.size(-1)), targets.view(-1))

                # Check for NaN in loss
                if math.isnan(loss.item()):
                    print("Encountered NaN loss, stopping training")
                    break

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
                optimizer.step()

                total_loss += loss.item()

                # Set the initial_loss after the first batch of the first epoch
                if initial_loss is None and batch_idx == 0:
                    initial_loss = loss.item()

            scheduler.step()

            # Check for NaN in total_loss
            if math.isnan(total_loss):
                print(f"Epoch {epoch+1}/{num_epochs} stopped due to NaN loss")
                break
            else:
                print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(train_loader)}")

            # Average loss for the epoch
            average_loss = total_loss / len(train_loader)

            # Update alpha at the end of each epoch based on the average loss
            new_alpha = self.calculate_new_alpha(average_loss, initial_loss)
            for layer in model.modules():
                if isinstance(layer, QLORALayer):
                    layer.update_alpha(new_alpha)

        # Toggle QLORA off after training
        model.decoder.toggle_qlora(False)
        average_loss = total_loss / len(train_loader)

        torch.save(model.state_dict(), save_path)
        print("Training Complete")
        return model, average_loss

    def train_dpr_encoders(self, train_data, context_encoder, question_encoder, optimizer_context, optimizer_question, epochs , context_save_path, question_save_path):
        loss_function = nn.CosineEmbeddingLoss()

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        for epoch in range(epochs):
            total_loss = 0

            for i in range(len(train_data["queries"])):
                query = train_data["queries"][i]
                context = train_data["contexts"][i]

                # Ensure query is a string
                if not isinstance(query, str):
                    raise ValueError("Query must be a string.")
                tokenized_query = tokenizer(query, return_tensors="pt", padding=True, truncation=True, max_length=512)

                # Ensure context is a string
                if isinstance(context, dict):
                    # If context is a dictionary, extract the text content. This is a placeholder and might need adjustment.
                    context_text = context.get("text", "")
                elif isinstance(context, str):
                    context_text = context
                else:
                    raise ValueError("Context must be a string or a dictionary containing a text field.")
                tokenized_context = tokenizer(context_text, return_tensors="pt", padding=True, truncation=True, max_length=512)

                question_embeddings = question_encoder(input_ids=tokenized_query['input_ids'], attention_mask=tokenized_query['attention_mask'])
                context_embeddings = context_encoder(input_ids=tokenized_context['input_ids'], attention_mask=tokenized_context['attention_mask'])

                # The labels tensor should have the same first dimension size as the input tensors
                labels = torch.tensor([1.0] * question_embeddings.size(0), dtype=torch.float).to(question_embeddings.device)

                loss = loss_function(question_embeddings, context_embeddings, labels)

                optimizer_context.zero_grad()
                optimizer_question.zero_grad()
                loss.backward()
                optimizer_context.step()
                optimizer_question.step()

                total_loss += loss.item()

            print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss / len(train_data['queries'])}")
        average_loss = total_loss / len(train_data['queries'])
        torch.save(context_encoder.state_dict(), context_save_path)
        torch.save(question_encoder.state_dict(), question_save_path)
        return (context_encoder, question_encoder), average_loss
       
    def train_language_model_transformer(self, train_loader, device, vocab_size, save_path):
        model = self.lmt
        
        # Define loss function and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-8, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.98)
        num_epochs = 5
        
        initial_loss = None

        # Training loop
        for epoch in range(num_epochs):
            model.train()
            model.decoder.toggle_qlora(True)
            total_loss = 0

            for batch_idx, batch in enumerate(train_loader):
                inputs, targets = batch['input_ids'].to(device), batch['labels'].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))

                # Check for NaN in loss
                if math.isnan(loss.item()):
                    print("Encountered NaN loss, stopping training")
                    break

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
                optimizer.step()

                total_loss += loss.item()

                # Set the initial_loss after the first batch of the first epoch
                if initial_loss is None and batch_idx == 0:
                    initial_loss = loss.item()

            scheduler.step()

            # Check for NaN in total_loss
            if math.isnan(total_loss):
                print(f"Epoch {epoch+1}/{num_epochs} stopped due to NaN loss")
                break
            else:
                print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(train_loader)}")

            # Average loss for the epoch
            average_loss = total_loss / len(train_loader)

            # Update alpha at the end of each epoch based on the average loss
            new_alpha = self.calculate_new_alpha(average_loss, initial_loss)
            for layer in model.modules():
                if isinstance(layer, Expert.QLORALayer):
                    layer.update_alpha(new_alpha)

        average_loss = total_loss / len(train_loader)
        torch.save(model.state_dict(), save_path)
        return model, average_loss


    def train_expert(self, train_loader, train_data, optimizer, main_loss_function, aux_loss_weight, device,save_path, accumulation_steps=4, num_epochs=5):
            self.train()  # Set the model to training mode
            for epoch in range(num_epochs):
                total_loss = 0
                optimizer.zero_grad()  # Initialize gradients to zero
                for batch_idx, batch in enumerate(train_loader):
                    inputs, attention_mask, targets = batch['input_ids'].to(device), batch['attention_mask'].to(device), batch['labels'].to(device)

                    # Calculate start and end indices for current batch in train_data
                    start_idx = batch_idx * batch['input_ids'].size(0)
                    end_idx = start_idx + batch['input_ids'].size(0)

                    # Extract current_queries and current_contexts for the batch
                    current_queries = train_data['queries'][start_idx:end_idx]
                    current_contexts = train_data['contexts'][start_idx:end_idx]

                    # Call to the model forward function
                    outputs, aux_loss = self(inputs, attention_mask, current_contexts, current_queries)

                    # Calculate loss and accumulate
                    main_loss = main_loss_function(outputs.view(-1, outputs.size(-1)), targets.view(-1))
                    loss = (main_loss + aux_loss_weight * aux_loss) / accumulation_steps
                    loss.backward()

                    if (batch_idx + 1) % accumulation_steps == 0:
                        optimizer.step()
                        optimizer.zero_grad()

                    total_loss += loss.item() * accumulation_steps  # Scale back up

                average_loss = total_loss / len(train_loader)
                print(f'End of Epoch {epoch+1}, Average Loss: {average_loss}')
            average_loss = total_loss / len(train_loader)
            torch.save(self.state_dict(), save_path)
            return self, average_loss


##################################
# Training transformer_with_dpo

torch.autograd.set_detect_anomaly(True)

def preprocess_dpo_data(examples):
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    vocab_size = len(tokenizer.vocab)  # Make sure this matches your embedding layer's vocab size

    # Define max sequence length
    max_seq_length = 512

    # Tokenize 'question', 'chosen', and 'rejected' fields
    tokenized_questions = tokenizer(examples['question'], padding='max_length', truncation=True, max_length=max_seq_length)
    tokenized_chosen = tokenizer(examples['chosen'], padding='max_length', truncation=True, max_length=max_seq_length)
    tokenized_rejected = tokenizer(examples['rejected'], padding='max_length', truncation=True, max_length=max_seq_length)

    # Prepare the labels: 1 for 'chosen' and 0 for 'rejected'
    labels = [1 if i % 2 == 0 else 0 for i in range(len(examples['question']))]

    return {
        'input_ids_question': tokenized_questions['input_ids'],
        'attention_mask_question': tokenized_questions['attention_mask'],
        'input_ids_chosen': tokenized_chosen['input_ids'],
        'attention_mask_chosen': tokenized_chosen['attention_mask'],
        'input_ids_rejected': tokenized_rejected['input_ids'],
        'attention_mask_rejected': tokenized_rejected['attention_mask'],
        'labels': labels
    }



# Load the DPO dataset from Hugging Face
dpo_dataset = load_dataset("Intel/orca_dpo_pairs")

# Apply the preprocessing to the dataset
dpo_dataset = dpo_dataset.map(preprocess_dpo_data, batched=True)


# You can convert to PyTorch tensors after processing
dpo_dataset.set_format(type='torch', columns=['input_ids_question', 'attention_mask_question', 'input_ids_chosen', 'attention_mask_chosen', 'input_ids_rejected', 'attention_mask_rejected', 'labels'])

train_loader = DataLoader(dpo_dataset['train'], batch_size=2, shuffle=True)

# Instantiate the Expert model and optimizer
config = ExpertConfig()
expert_model = Expert(config)

model_vocab_size = expert_model.lmt.decoder.word_embedding.num_embeddings
print(f"Model vocab size: {model_vocab_size}")
print(f"Model embedding size: {expert_model.lmt.decoder.word_embedding.embedding_dim}")
print(f"Configured max length: {config.max_length}")

optimizer = AdamW(expert_model.parameters(), lr=1e-5)
save_path = 'D:/EXPERT_WEIGHTS/dpo_model.pth'

# Train the DPO model
label_column = 'labels'
input_columns = ['input_ids_question', 'attention_mask_question', 'input_ids_chosen', 'attention_mask_chosen', 'input_ids_rejected', 'attention_mask_rejected']
avg_loss = expert_model.train_dpo(train_loader, optimizer, label_column, input_columns, config, save_path)

# Save the model
torch.save(expert_model.transformer_dpo.state_dict(), save_path)



Model vocab size: 30522
Model embedding size: 256
Configured max length: 512
Batch 1:
Input shape to LanguageModelTransformer: torch.Size([2, 512])
Max position index: 511, Position Embedding Size: 512


ValueError: Target size (torch.Size([2])) must be the same as input size (torch.Size([31254528]))

# v2

In [33]:
# 0. Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from transformers import BertTokenizer
from torch.utils.data import DataLoader
import fitz  # PyMuPDF
from transformers import BertModel
import math
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

class ExpertConfig:
    def __init__(self, seq_len=512, head_dim=64, block_size=64, sparsity_factor=4,
                 input_dim=512, num_experts=3, vocab_size=30522, embed_size=256,
                 num_layers=6, forward_expansion=4, heads=8, dropout=0.1,
                 max_length=512,  # Updated to match seq_len for consistency
                 rank=16, device='cpu',
                 mamba_model_path='D:\\EXPERT_WEIGHTS\\mamba_model_weights.pth',
                 context_encoder_path='D:\\EXPERT_WEIGHTS\\context_encoder.pth',
                 language_model_path='D:\\EXPERT_WEIGHTS\\language_model.pth',
                 question_encoder_path='D:\\EXPERT_WEIGHTS\\question_encoder.pth',
                 dpo_model_path='D:\\EXPERT_WEIGHTS\\dpo_model_weights.pth',
                 model_name='bert-base-uncased', embedding_dim=768,
                 alpha=1, quantization_bits=8, tokenizer_name='bert-base-uncased',
                 d_model=512, d_state=2048, d_conv=3, expansion_factor=2):

        # Common hyperparameters
        self.seq_len = seq_len
        self.head_dim = head_dim
        self.block_size = block_size
        self.sparsity_factor = sparsity_factor
        self.input_dim = input_dim
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.num_layers = num_layers
        self.forward_expansion = forward_expansion
        self.heads = heads
        self.dropout = dropout
        self.max_length = max_length  # Ensure this is properly reflected in model components
        self.rank = rank

        # Model paths and device
        self.mamba_model_path = mamba_model_path
        self.context_encoder_path = context_encoder_path
        self.language_model_path = language_model_path
        self.question_encoder_path = question_encoder_path
        self.dpo_model_path = dpo_model_path
        self.device = device

        # Unique hyperparameters
        self.num_experts = num_experts
        self.model_name = model_name
        self.embedding_dim = embedding_dim
        self.alpha = alpha
        self.quantization_bits = quantization_bits
        self.tokenizer_name = tokenizer_name
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expansion_factor = expansion_factor

        # PDFs (unchanged)
        self.pdf_file_paths = [
            r'C:\Users\robbi\IEEMM\DPO.pdf', 
            r'C:\Users\robbi\IEEMM\MAMBA.pdf',
            r'C:\Users\robbi\IEEMM\QLORA.pdf',
            r'C:\Users\robbi\IEEMM\RAG.pdf',
            r'C:\Users\robbi\IEEMM\SWITCH_TRANSFORMER.pdf'
        ]
        
        # Preserving original dataset loading functionality
        self.rag_dataset = Expert.TransformerRAG.create_dataset_from_pdfs(self.pdf_file_paths)

    def validate(self):
        assert self.seq_len % self.block_size == 0, "seq_len must be divisible by block_size"
        assert self.max_length >= self.seq_len, "max_length should be equal to or greater than seq_len"



class Expert(nn.Module):

    @staticmethod
    # Load model weights function
    def load_model_weights(model, model_path):
        #checkpoint = torch.load(model_path, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
        checkpoint = torch.load(model_path, map_location=config.device)

        if isinstance(checkpoint, dict):
            # Check for 'state_dict' or 'model_state_dict' keys
            if 'state_dict' in checkpoint:
                model.load_state_dict(checkpoint['state_dict'])
            elif 'model_state_dict' in checkpoint:
                model.load_state_dict(checkpoint['model_state_dict'])
            else:
                # If no known key is found, try loading it as a raw state dictionary
                try:
                    model.load_state_dict(checkpoint)
                except RuntimeError as e:
                    raise ValueError(f"Error loading state dict: {e}")
        elif isinstance(checkpoint, nn.Module):
            # If the checkpoint is a model object, assign it directly
            model = checkpoint
        else:
            raise ValueError(f"Unsupported checkpoint format: {type(checkpoint)}")

        model.eval()
        return model

    # Flash2_Attention
    class FlashAttention2(nn.Module):
        def __init__(self, sequence_length, head_dimension, block_size):
            super(Expert.FlashAttention2, self).__init__()
            self.block_size = block_size
            # Ensure that sequence_length is divisible by block_size for simplicity
            assert sequence_length % block_size == 0

        def forward(self, Q, K, V):
            # Partitioning of inputs
            Q_blocks, K_blocks, V_blocks = self.partition_inputs(Q, K, V)

            # Efficient computation of the attention mechanism
            outputs = []
            for i, Q_block in enumerate(Q_blocks):
                output_block = self.process_block(Q_block, K_blocks, V_blocks)
                outputs.append(output_block)

            # Concatenating the processed blocks
            output = torch.cat(outputs, dim=0)
            return output

        def partition_inputs(self, Q, K, V):
            # The actual partitioning scheme should be based on sequence length, head dimension, and block size
            Q_blocks = Q.chunk(chunks=Q.size(0) // self.block_size, dim=0)
            K_blocks = K.chunk(chunks=K.size(0) // self.block_size, dim=0)
            V_blocks = V.chunk(chunks=V.size(0) // self.block_size, dim=0)
            return Q_blocks, K_blocks, V_blocks

        def process_block(self, Q_block, K_blocks, V_blocks):
            # Process each block efficiently as per FLASH2's optimized method
            # This includes computing QK^T, applying online softmax, and multiplying with V
            output_blocks = []
            for K_block, V_block in zip(K_blocks, V_blocks):
                attention_scores = torch.matmul(Q_block, K_block.transpose(-2, -1))
                attention_scores = self.online_softmax(attention_scores)
                output_block = torch.matmul(attention_scores, V_block)
                output_blocks.append(output_block)

            # Summing up the results from each block
            output_block_sum = sum(output_blocks)
            return output_block_sum

        def online_softmax(self, scores, chunk_size=128):
            # Apply softmax in chunks for large sequences
            softmaxed_scores = []
            for i in range(0, scores.size(0), chunk_size):
                chunk = scores[i:i + chunk_size, :]
                softmaxed_chunk = F.softmax(chunk, dim=1)
                softmaxed_scores.append(softmaxed_chunk)
            return torch.cat(softmaxed_scores, dim=0)

    # SparseFlash2_Attention
    class SparseFlash2Attention(nn.Module):
        def __init__(self, seq_len, head_dim, blk_size, sparsity_factor):
            super().__init__()
            self.flash_attention = Expert.FlashAttention2(seq_len, head_dim, blk_size)
            self.seq_len = seq_len
            self.head_dim = head_dim
            self.block_size = blk_size  # Storing block_size as an instance variable
            self.sparsity_factor = sparsity_factor

        def generate_sparsity_mask(self):
            mask = torch.zeros(self.seq_len, self.seq_len)
            step = self.sparsity_factor
            for i in range(0, self.seq_len, step):
                mask[i:i + step, :] = 1
            return mask.bool()

        def forward(self, Q, K, V):
            output = self.flash_attention(Q, K, V)  # output shape: [sequence_length, head_dimension]

            # Reshape output to be 3D for batch matrix multiplication
            output = output.unsqueeze(0)  # New shape: [1, sequence_length, head_dimension]

            sparsity_mask = self.generate_sparsity_mask()  # shape: [sequence_length, sequence_length]

            # Apply the sparsity mask to the output
            sparsity_mask = sparsity_mask.unsqueeze(0)  # New shape: [1, sequence_length, sequence_length]
            output = torch.bmm(sparsity_mask.float(), output.float())  # Perform batch matrix multiplication

            # Reshape the output back to 2D
            output = output.squeeze(0)  # New shape: [sequence_length, head_dimension]

            return output


    # MAMBA
    # RoPE
    class RotaryPositionalEncoding(nn.Module):
        def __init__(self, dim, max_len=5000):
            super().__init__()
            self.dim = dim
            inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
            t = torch.arange(max_len).type_as(inv_freq)
            freqs = torch.einsum('n , d -> n d', t, inv_freq)
            self.register_buffer('sin', freqs.sin())
            self.register_buffer('cos', freqs.cos())

        def forward(self, x):
            n, _, device = x.shape[1], self.dim // 2, x.device
            sin, cos = self.sin[:n].to(device), self.cos[:n].to(device)

            # Apply RoPE to even and odd indices separately
            x_even = x[..., :self.dim:2] * cos.unsqueeze(0) + torch.roll(x[..., 1:self.dim:2], shifts=1, dims=-1) * sin.unsqueeze(0)
            x_odd = x[..., 1:self.dim:2] * cos.unsqueeze(0) - torch.roll(x[..., :self.dim:2], shifts=1, dims=-1) * sin.unsqueeze(0)
            return torch.cat((x_even, x_odd), dim=-1)

    # SWIGLU
    class SwiGLU(nn.Module):
        def __init__(self, dim_in, dim_out):
            super(Expert.SwiGLU, self).__init__()
            self.fc1 = nn.Linear(dim_in, dim_out)
            self.fc2 = nn.Linear(dim_in, dim_out)

        def forward(self, x):
            gate = torch.sigmoid(self.fc2(x))
            return self.fc1(x) * gate

    class SimplifiedMAMBA(nn.Module):
        def __init__(self, config):
            super().__init__()
            # Using configuration parameters
            self.num_layers = config.num_layers
            self.d_model = config.d_model
            self.d_state = config.d_state
            self.d_conv = config.d_conv
            self.expansion_factor = config.expansion_factor
            
            self.feedforward = nn.Sequential(
                nn.Linear(self.d_model, self.d_state),
                nn.GELU(),
                nn.Linear(self.d_state, self.d_model)
            )
            self.input_embedding = nn.Linear(self.d_model, self.d_model)
            self.convs = nn.Sequential(*[nn.Conv1d(self.d_model, self.d_model, kernel_size=self.d_conv, padding=(self.d_conv // 2)) for _ in range(self.num_layers)])
            self.swiglu = Expert.SwiGLU(self.d_model, self.d_model)
            self.output_projection = nn.Linear(self.d_model, self.d_model * self.expansion_factor)

            self.initialize_weights()

        def initialize_weights(self):
            gain = nn.init.calculate_gain('relu')

            nn.init.orthogonal_(self.input_embedding.weight, gain)
            nn.init.normal_(self.input_embedding.bias, mean=0, std=0.01)

            nn.init.kaiming_uniform_(self.convs[-1].weight, a=math.sqrt(5))
            nn.init.zeros_(self.convs[-1].bias)

            nn.init.xavier_uniform_(self.feedforward[0].weight, gain=nn.init.calculate_gain('relu'))
            nn.init.zeros_(self.feedforward[0].bias)

            nn.init.xavier_uniform_(self.feedforward[2].weight, gain=nn.init.calculate_gain('linear'))
            nn.init.zeros_(self.feedforward[2].bias)

            nn.init.xavier_uniform_(self.output_projection.weight, gain=nn.init.calculate_gain('linear'))
            nn.init.zeros_(self.output_projection.bias)

        def forward(self, inputs, attention_mask=None):
            print("Input shape:", inputs.shape)

            # Apply the attention mask if provided
            if attention_mask is not None:
                inputs = inputs * attention_mask.unsqueeze(-1)

            projected_inputs = self.input_embedding(inputs)
            print("projected_inputs pre-reshape shape:", projected_inputs.shape)

            projected_inputs = projected_inputs.permute(0, 2, 1)
            print("projected_inputs post-reshape shape:", projected_inputs.shape)

            for conv in self.convs:
                projected_inputs = conv(projected_inputs)

            projected_inputs = projected_inputs.permute(0, 2, 1)
            print("projected_inputs post convolution reshape:", projected_inputs.shape)

            projected_inputs = self.swiglu(projected_inputs)
            print("projected_inputs post swiglu shape:", projected_inputs.shape)

            output = self.output_projection(projected_inputs)
            print("output shape:", output.shape)

            return output

    class SimplifiedLanguageModelMAMBA(nn.Module):
        def __init__(self, config):
            super().__init__()
            # Using configuration parameters
            self.vocab_size = config.vocab_size
            self.num_layers = config.num_layers
            self.d_model = config.d_model
            self.d_state = config.d_state
            self.d_conv = config.d_conv
            self.expansion_factor = config.expansion_factor

            self.embedding = nn.Embedding(self.vocab_size, self.d_model)
            self.pos_encoder = Expert.RotaryPositionalEncoding(self.d_model)
            self.simplified_mamba = Expert.SimplifiedMAMBA(config)
            self.output_projection = nn.Linear(self.d_model * 2, self.vocab_size)  # Adjust if needed

            self.initialize_weights()

        def initialize_weights(self):
            gain = 1.0

            nn.init.orthogonal_(self.embedding.weight, gain)
            nn.init.xavier_uniform_(self.output_projection.weight, gain=gain)

        def forward(self, input_values, attention_mask):
            embedded = self.embedding(input_values) * math.sqrt(self.d_model)
            embedded = self.pos_encoder(embedded)
            simplified_mamba_output = self.simplified_mamba(embedded, attention_mask)
            logits = self.output_projection(simplified_mamba_output)
            return logits


    class SwitchRouter(nn.Module):
        CAPACITY_FACTOR = 1  # Class constant

        class SwitchGate(nn.Module):
            def __init__(self, input_dim, num_experts):
                super(Expert.SwitchRouter.SwitchGate, self).__init__()
                self.fc1 = nn.Linear(input_dim, input_dim // 2)
                self.fc2 = nn.Linear(input_dim // 2, num_experts)

            def forward(self, x):
                x = F.relu(self.fc1(x))
                gate_scores = F.softmax(self.fc2(x), dim=-1)
                return gate_scores

        def __init__(self, config):
            super(Expert.SwitchRouter, self).__init__()
            self.config = config
            self.device = config.device
            self.router = self.SwitchGate(config.input_dim, config.num_experts).to(config.device)
            self.transformer_rag = Expert.TransformerRAG(config).to(config.device)
            self.lmt = Expert.LanguageModelTransformer(config).to(config.device)
            self.transformer_dpo = Expert.DPO(self.lmt, config.device).to(config.device)
            self.mamba = Expert.SimplifiedLanguageModelMAMBA(config).to(config.device)
            self.experts = nn.ModuleList([self.transformer_rag, self.transformer_dpo, self.mamba])
            self.input_embedding = nn.Linear(config.input_dim, config.input_dim)

        def forward(self, x, attention_mask, context_texts, question_text):
            x = x.to(self.device)
            if x.dtype != torch.float32:
                x = x.float()

            x = self.input_embedding(x)
            gate_scores = self.router(x)
            expert_indices = torch.argmax(gate_scores, dim=1)
            final_output = torch.zeros_like(x)

            for i, expert in enumerate(self.experts):
                mask = expert_indices == i
                if mask.any():
                    selected_inputs = x[mask]
                    selected_attention_mask = attention_mask[mask].to(self.device)
                    if isinstance(expert, Expert.TransformerRAG):
                        expert_output = expert.forward(context_texts, selected_inputs, selected_attention_mask, question_text)
                    else:
                        expert_output = expert(selected_inputs, selected_attention_mask)
                    final_output[mask] = expert_output

            aux_loss = self.auxiliary_loss(gate_scores)
            return final_output, aux_loss

        def auxiliary_loss(self, gate_scores):
            expert_load = gate_scores.sum(0) / gate_scores.size(0)
            loss_balancing = torch.std(expert_load)
            return loss_balancing


        @staticmethod
        def route_inputs(expert_indices, gate_scores, num_experts):
            capacity_factor_tensor = torch.tensor([Expert.SwitchRouter.CAPACITY_FACTOR], dtype=torch.float32)
            capacities = (gate_scores.size(0) * capacity_factor_tensor / num_experts).int()
            expert_counts = torch.zeros(num_experts, dtype=torch.int32)
            for idx in range(len(expert_indices)):
                selected_expert = expert_indices[idx]
                if expert_counts[selected_expert] < capacities[0]:
                    expert_counts[selected_expert] += 1
                else:
                    available_experts = (expert_counts < capacities[0]).nonzero(as_tuple=False).view(-1)
                    if len(available_experts) > 0:
                        alternative_expert = available_experts[0]
                        expert_indices[idx] = alternative_expert
                        expert_counts[alternative_expert] += 1
                    else:
                        print("No available experts to reroute. Handling overflow.")
            return expert_indices
    
    class CustomDPRContextEncoder(nn.Module):
        def __init__(self, embedding_dim):
            super(Expert.CustomDPRContextEncoder, self).__init__()  
            self.bert_model = BertModel.from_pretrained('bert-base-uncased')
            self.embedding_layer = nn.Linear(self.bert_model.config.hidden_size, embedding_dim)

        def forward(self, input_ids, attention_mask=None):
            outputs = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)
            pooled_output = outputs.pooler_output
            context_embeddings = self.embedding_layer(pooled_output)
            return context_embeddings

    class DPRQuestionEncoder(nn.Module):
        def __init__(self, embedding_dim):
            super(Expert.DPRQuestionEncoder, self).__init__() 
            self.bert_model = BertModel.from_pretrained('bert-base-uncased')
            self.embedding_layer = nn.Linear(self.bert_model.config.hidden_size, embedding_dim)

        def forward(self, input_ids, attention_mask, **kwargs):
            outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
            pooled_output = outputs.pooler_output
            embeddings = self.embedding_layer(pooled_output)
            return embeddings

    class TransformerRAG(nn.Module):
        def __init__(self, config):
            super(Expert.TransformerRAG, self).__init__()
            self.config = config
            self.context_encoder = Expert.CustomDPRContextEncoder(config.embedding_dim).to(config.device)
            self.language_model = Expert.LanguageModelTransformer(config).to(config.device)
            self.question_encoder = Expert.DPRQuestionEncoder(config.embedding_dim).to(config.device)
            self.tokenizer = BertTokenizer.from_pretrained(config.tokenizer_name)


        def forward(self, context_texts, question_input_ids, question_attention_mask, question_text):
            if question_input_ids.max() >= self.tokenizer.vocab_size:
                raise ValueError("question_input_ids contain ID(s) beyond the tokenizer's vocabulary size")
            
            aggregated_context_embeddings = []
            for context_list in context_texts:
                if not all(isinstance(context, dict) for context in context_list):
                    raise TypeError("Each item in context_texts must be a list of tokenized context dictionaries")
                
                aggregated_context_embedding = torch.zeros(self.context_encoder.bert_model.config.hidden_size, device=device)
                for context in context_list:
                    context_input_ids = context['input_ids'].to(config.device)
                    context_attention_mask = context['attention_mask'].to(config.device)
                    context_embedding = self.context_encoder(context_input_ids, context_attention_mask)
                    aggregated_context_embedding += context_embedding.mean(dim=0)
                
                aggregated_context_embeddings.append(aggregated_context_embedding / len(context_list))
            
            question_input_ids = question_input_ids.to(config.device).long()
            question_attention_mask = question_attention_mask.to(config.device).long()
            question_embeddings = self.question_encoder(input_ids=question_input_ids, attention_mask=question_attention_mask)
            
            cos_sim = torch.nn.CosineSimilarity(dim=1)
            similarities = [cos_sim(question_embeddings, context_emb.squeeze(0)) for context_emb in aggregated_context_embeddings]
            most_relevant_context_idx = torch.argmax(torch.tensor(similarities, device=config.device))
            
            combined_input = question_text + " " + context_texts[most_relevant_context_idx]
            tokenized_combined_input = self.tokenizer(combined_input, return_tensors="pt", padding=True, truncation=True, max_length=512)
            tokenized_combined_input = {k: v.to(config.device) for k, v in tokenized_combined_input.items()}
            response_logits = self.language_model(**tokenized_combined_input)
            probabilities = F.softmax(response_logits.logits, dim=-1)
            predicted_token_ids = torch.argmax(probabilities, dim=-1)
            predicted_tokens = self.tokenizer.convert_ids_to_tokens(predicted_token_ids[0])
            response = self.tokenizer.convert_tokens_to_string(predicted_tokens)
            
            return response

        @staticmethod
        def extract_text_from_pdf(file_path):
            text = ""
            with fitz.open(file_path) as doc:
                for page in doc:
                    text += page.get_text()
            return text

        @staticmethod
        def split_into_chunks(text, chunk_size):
            return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]

        @staticmethod
        def preprocess_text(text, max_length=512):
            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            chunk_size = max_length - 50
            text_chunks = Expert.TransformerRAG.split_into_chunks(text, chunk_size)
            processed_chunks = []
            for chunk in text_chunks:
                tokenized_output = tokenizer(chunk, padding='max_length', truncation=True, max_length=max_length, return_tensors="pt")
                processed_chunk = {
                    'input_ids': tokenized_output['input_ids'],
                    'attention_mask': tokenized_output['attention_mask']
                }
                processed_chunks.append(processed_chunk)
            return processed_chunks

        @staticmethod
        def create_dataset_from_pdfs(pdf_file_paths):
            dataset = []
            for file_path in pdf_file_paths:
                text = Expert.TransformerRAG.extract_text_from_pdf(file_path)
                processed_text = Expert.TransformerRAG.preprocess_text(text)
                dataset.append(processed_text)
            return dataset

        def retrieve_contexts(self, dataset, query_embedding, top_k=5):
            similarity_scores = []
            for context in dataset:
                context_input_ids = context['input_ids'].to(self.config.device)
                context_attention_mask = context['attention_mask'].to(self.config.device)
                # Use the class's context_encoder
                context_embedding = self.context_encoder(context_input_ids, context_attention_mask)
                similarity = torch.matmul(query_embedding, context_embedding.T)
                similarity_scores.append(similarity.squeeze().item())
            top_k_indices = sorted(range(len(similarity_scores)), key=lambda i: similarity_scores[i], reverse=True)[:top_k]
            top_contexts = [dataset[i] for i in top_k_indices]
            return top_contexts

        def rag_retrieve_and_generate(self, dataset, query):
            # Tokenize the query using the class's tokenizer
            tokenized_query = self.tokenizer(query, return_tensors="pt", padding=True, truncation=True, max_length=512).to(self.config.device)
            input_ids = tokenized_query['input_ids']
            attention_mask = tokenized_query['attention_mask']
            # Use the class's question_encoder
            encoded_query = self.question_encoder(input_ids, attention_mask)
            relevant_contexts = self.retrieve_contexts(dataset, encoded_query)
            # Assuming generate_response is a method of LanguageModelTransformer that accepts tokenized contexts and generates a response
            response = self.language_model.generate_response(relevant_contexts)
            return response




    class LORALayer(nn.Module):
        def __init__(self, input_dim, output_dim, rank, alpha=1):
            super(Expert.LORALayer, self).__init__()
            self.rank = rank
            self.alpha = alpha

            # Original weight and bias of the linear layer
            self.weight = nn.Parameter(torch.Tensor(output_dim, input_dim))
            #print("self.weight Shape:", self.weight.shape)
            self.bias = nn.Parameter(torch.Tensor(output_dim))

            # LORA specific parameters
            self.A = nn.Parameter(torch.Tensor(input_dim, rank))
            self.B = nn.Parameter(torch.Tensor(rank, output_dim))

            self.reset_parameters()

        def reset_parameters(self):
            nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
            nn.init.zeros_(self.bias)
            nn.init.normal_(self.A, 0, 0.02)
            nn.init.normal_(self.B, 0, 0.02)

        def forward(self, x):
            #print("LORALayer Input Shape:", x.shape)
            
            original_size = x.size()
            batch_size, seq_len, _ = x.shape
            x_flattened = x.reshape(-1, original_size[-1])

            # Compute lora_adjustment for each input in the batch
            lora_adjustment = self.alpha * (x_flattened @ self.A) @ self.B
            lora_adjustment = lora_adjustment.reshape(batch_size, seq_len, -1)

            #print("Adjusted lora_adjustment Shape:", lora_adjustment.shape)
            #print("self.weight Shape:", self.weight.shape)
            
            # Apply linear transformation to x_flattened
            x_transformed = nn.functional.linear(x_flattened, self.weight, self.bias)
            x_transformed = x_transformed.reshape(batch_size, seq_len, -1)

            # Add lora_adjustment to the transformed x
            x = x_transformed + lora_adjustment
            #print("LORALayer Output Shape:", x.shape)

            return x

    class QLORALayer(nn.Module):
        def __init__(self, input_dim, output_dim, rank, alpha=1, quantization_bits=8):
            super(Expert.QLORALayer, self).__init__()
            self.rank = rank
            self.alpha = alpha
            self.quantization_bits = quantization_bits

            # Original weight and bias
            self.weight = nn.Parameter(torch.Tensor(output_dim, input_dim))
            self.bias = nn.Parameter(torch.Tensor(output_dim))

            # QLORA specific parameters
            self.A = nn.Parameter(torch.Tensor(input_dim, rank))
            self.B = nn.Parameter(torch.Tensor(rank, output_dim))

            self.dropout = nn.Dropout(0.1)
            self.layer_norm = nn.LayerNorm(output_dim)

            self.reset_parameters()

        def reset_parameters(self):
            nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
            nn.init.zeros_(self.bias)
            nn.init.normal_(self.A, 0, 0.02)
            nn.init.normal_(self.B, 0, 0.02)

        def quantize(self, x, num_bits):
            # Implement a simple quantization method
            scale = x.abs().max()
            x_quantized = torch.round(x / scale * (2**num_bits - 1))
            return x_quantized, scale

        def forward(self, x):
            #print("QLORALayer Input Shape:", x.shape)
            original_size = x.size()
            batch_size, seq_len, _ = x.shape
            x_flattened = x.reshape(-1, original_size[-1])

            A_quantized, scale_A = self.quantize(self.A, self.quantization_bits)
            B_quantized, scale_B = self.quantize(self.B, self.quantization_bits)

            # Compute lora_adjustment for each input in the batch
            lora_adjustment = self.alpha * (x_flattened @ (A_quantized / scale_A)) @ (B_quantized / scale_B)
            lora_adjustment = lora_adjustment.reshape(batch_size, seq_len, -1)
            lora_adjustment = self.dropout(lora_adjustment)
            #print("Adjusted lora_adjustment Shape:", lora_adjustment.shape)
            #print("self.weight Shape:", self.weight.shape)

            # Apply linear transformation to x_flattened
            x_transformed = nn.functional.linear(x_flattened, self.weight, self.bias)
            x_transformed = x_transformed.reshape(batch_size, seq_len, -1)

            # Add lora_adjustment to the transformed x
            x = x_transformed + lora_adjustment
            x = self.layer_norm(x)

            #print("QLORALayer Output Shape:", x.shape)

            return x
        
        def update_alpha(self, new_alpha):
            """
            Update the alpha scaling factor.
            """
            self.alpha = new_alpha

    class MultiHeadAttention(nn.Module):
        def __init__(self, embed_size, heads):
            super(Expert.MultiHeadAttention, self).__init__()
            self.embed_size = embed_size
            self.heads = heads
            self.head_dim = embed_size // heads

            assert (
                self.head_dim * heads == embed_size
            ), "Embedding size needs to be divisible by heads"

            self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
            self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
            self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
            self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

        def forward(self, values, keys, query, mask):
            N = query.shape[0]
            value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

            # Split the embedding into self.heads different pieces
            values = values.reshape(N, value_len, self.heads, self.head_dim)
            keys = keys.reshape(N, key_len, self.heads, self.head_dim)
            queries = query.reshape(N, query_len, self.heads, self.head_dim)

            values = self.values(values)
            keys = self.keys(keys)
            queries = self.queries(queries)

            # Einsum does the matrix multiplication for query*keys for each training example
            # with every other training example, then sum it up
            attention = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

            if mask is not None:
                attention = attention.masked_fill(mask == 0, float("-1e20"))

            attention = torch.softmax(attention / (self.embed_size ** (1 / 2)), dim=3)

            out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
                N, query_len, self.heads * self.head_dim
            )

            out = self.fc_out(out)
            return out

    class TransformerBlock(nn.Module):
        def __init__(self, embed_size, heads, dropout, forward_expansion, rank):
            super(Expert.TransformerBlock, self).__init__()
            self.attention = Expert.MultiHeadAttention(embed_size, heads)
            self.norm1 = nn.LayerNorm(embed_size)
            self.norm2 = nn.LayerNorm(embed_size)

            self.feed_forward = nn.Sequential(
                Expert.LORALayer(embed_size, forward_expansion * embed_size, rank),
                nn.ReLU(),
                Expert.LORALayer(forward_expansion * embed_size, embed_size, rank),
            )

            self.dropout = nn.Dropout(dropout)

        def forward(self, value, key, query, mask):
            attention = self.attention(value, key, query, mask)
            x = self.dropout(self.norm1(attention + query))
            forward = self.feed_forward(x)
            out = self.dropout(self.norm2(forward + x))
            return out

    class LanguageModelDecoder(nn.Module):
        def __init__(self, vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, max_length, rank):
            super(Expert.LanguageModelDecoder, self).__init__()
            self.word_embedding = nn.Embedding(vocab_size, embed_size)
            self.position_embedding = nn.Embedding(max_length, embed_size)

            # Adding BatchNorm layers
            self.bn1 = nn.BatchNorm1d(embed_size)
            self.bn2 = nn.BatchNorm1d(embed_size)

            self.layers = nn.ModuleList(
                [
                    Expert.TransformerBlock(embed_size, heads, dropout, forward_expansion, rank)
                    for _ in range(num_layers)
                ]
            )

            # QLORA layers
            self.qlora_feed_forward = nn.Sequential(
                Expert.QLORALayer(embed_size, forward_expansion * embed_size, rank),
                nn.ReLU(),
                Expert.QLORALayer(forward_expansion * embed_size, embed_size, rank),
            )
            self.use_qlora = False  # Flag to toggle QLORA

            self.fc_out = nn.Linear(embed_size, vocab_size)
            self.dropout = nn.Dropout(dropout)

        def forward(self, x, trg_mask):
            N, seq_length = x.shape
            if seq_length > self.position_embedding.num_embeddings:
                raise ValueError(f"Sequence length {seq_length} exceeds maximum allowed {self.position_embedding.num_embeddings}")
            positions = torch.arange(0, seq_length).expand(N, seq_length).to(config.device)
            x = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

            print(f"Max position index: {positions.max().item()}, Position Embedding Size: {self.position_embedding.num_embeddings}")


            # Transpose for BatchNorm, apply batch normalization, and then transpose back
            x = x.transpose(1, 2)
            x = self.bn1(x)
            x = x.transpose(1, 2)

            for layer in self.layers:
                x = layer(x, x, x, trg_mask)
                if self.use_qlora:
                    x = self.qlora_feed_forward(x)

            # Transpose for BatchNorm, apply batch normalization, and then transpose back
            x = x.transpose(1, 2)
            x = self.bn2(x)
            x = x.transpose(1, 2)

            out = self.fc_out(x)
            #print(f"shape of output of forward method of LanguageModelDecoder: {out.shape} ")

            return out

        def toggle_qlora(self, use_qlora: bool):
            self.use_qlora = use_qlora

    class LanguageModelTransformer(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.vocab_size = config.vocab_size  # Use vocab_size from ExpertConfig
            self.embed_size = config.embed_size  # Use embed_size from ExpertConfig
            self.num_layers = config.num_layers  # Use num_layers from ExpertConfig
            self.forward_expansion = config.forward_expansion  # Use forward_expansion from ExpertConfig
            self.heads = config.heads  # Use heads from ExpertConfig
            self.dropout = config.dropout  # Use dropout from ExpertConfig
            self.max_length = config.max_length  # Use max_length from ExpertConfig
            self.rank = config.rank  # Use rank from ExpertConfig
            self.tokenizer_name = config.tokenizer_name  # Use tokenizer_name from ExpertConfig

            self.tokenizer = BertTokenizer.from_pretrained(self.tokenizer_name)

            self.decoder = Expert.LanguageModelDecoder(
                self.vocab_size,
                self.embed_size,
                self.num_layers,
                self.heads,
                self.forward_expansion,
                self.dropout,
                self.max_length,
                self.rank,
            )

        def forward(self, trg, attention_mask=None):  # Remove attention_mask here since it's not used
            print(f"Input shape to LanguageModelTransformer: {trg.shape}")
            trg_mask = self.make_trg_mask(trg)
            out = self.decoder(trg, trg_mask)  # Do not pass attention_mask here
            return out

        def make_trg_mask(self, trg):
            N, trg_len = trg.shape
            trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(N, 1, trg_len, trg_len).to(trg.device)
            return trg_mask

        def toggle_qlora(self, use_qlora: bool):
            self.decoder.toggle_qlora(use_qlora)

        def generate_response(self, input_ids, attention_mask):
            logits = self.forward(input_ids=input_ids, attention_mask=attention_mask)
            probabilities = F.softmax(logits, dim=-1)
            predicted_token_id = torch.argmax(probabilities, dim=-1)
            predicted_tokens = [self.tokenizer.convert_ids_to_tokens(idx.item()) for idx in predicted_token_id]
            response = self.tokenizer.convert_tokens_to_string(predicted_tokens)
            return response


    class DPO(nn.Module):
        def __init__(self, language_model, device):
            super(Expert.DPO, self).__init__()
            self.language_model = language_model
            self.device = device
        
        def forward(self, input_ids_question, input_ids_chosen=None, input_ids_rejected=None, labels=None, attention_mask=None):
            # Concatenate for DPO-specific data or use regular input handling
            combined_input_ids = torch.cat([input_ids_question, input_ids_chosen, input_ids_rejected], dim=1) if input_ids_chosen is not None and input_ids_rejected is not None else input_ids_question
            
            # Use the vocab_size from the language model
            combined_input_ids = torch.clamp(combined_input_ids, 0, self.language_model.vocab_size - 1)
            
            # Pass through the language model
            output = self.language_model(combined_input_ids) #, attention_mask=attention_mask)
            
            # Calculate loss if labels are provided
            loss = None
            if labels is not None:
                logits = output.logits if hasattr(output, 'logits') else output
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(logits.view(-1), labels.float().view(-1))
            
            return output, loss




    def __init__(self, config: ExpertConfig):
        super().__init__()
        # Validate configuration
        config.validate()

        # 1. SparseFlash2_attention
        self.sparse_flash2_attention = Expert.SparseFlash2Attention(
            config.seq_len, 
            config.head_dim, 
            config.block_size, 
            config.sparsity_factor
        )

        # Inside the Expert class definition
        self.lmt = Expert.LanguageModelTransformer(config).to(config.device)

        self.transformer_rag = Expert.TransformerRAG(config).to(config.device)

        # Corrected instantiation of SimplifiedLanguageModelMAMBA
        self.mamba = Expert.SimplifiedLanguageModelMAMBA(config).to(config.device)

        # Now initialize DPO with the language model transformer
        self.transformer_dpo = Expert.DPO(self.lmt, config.device).to(config.device)

        # 2. LayerNorm and Dropout
        self.layer_norm = nn.LayerNorm(config.input_dim)
        self.dropout = nn.Dropout(config.dropout)

        # 3. Internal Switch Routing
        self.switch_router = Expert.SwitchRouter(config)

        # 4. QLORA Layer
        self.qlora = Expert.QLORALayer(config.input_dim, config.input_dim, config.rank)


    def forward(self, x, attention_mask, context_texts, question_text):
        # 1. SparseFlash2_attention
        x = self.sparse_flash2_attention(x)

        # 2. LayerNorm and Dropout
        x = self.dropout(self.layer_norm(x))

        # 3. Internal Switch Routing
        x, aux_loss = self.switch_router(x, attention_mask, context_texts, question_text)

        # 4. QLORA Layer
        x = self.qlora(x)

        return x, aux_loss

    @staticmethod
    def calculate_new_alpha(current_loss, initial_loss, initial_alpha=1.0, final_alpha=0.1):
        """
        Calculate a new alpha value based on the current loss.
        """
        if current_loss >= initial_loss:
            return initial_alpha  # Keep initial alpha if loss isn't decreasing

        loss_ratio = current_loss / initial_loss
        alpha_range = initial_alpha - final_alpha
        new_alpha = final_alpha + (alpha_range * loss_ratio)
        return new_alpha  
    
    def train_dpo(self, train_loader, optimizer, label_column, input_columns, config, save_path):
        self.transformer_dpo.train()  # Set the model to training mode
        total_loss = 0

        for i, batch in enumerate(train_loader):  # Include i here to use it in your print statements
            print(f"Batch {i+1}:")
            input_ids_question = batch['input_ids_question'].to(config.device)
            input_ids_chosen = batch['input_ids_chosen'].to(config.device)
            input_ids_rejected = batch['input_ids_rejected'].to(config.device)
            labels = batch['labels'].to(config.device)

            # Ensure none of the input ids exceed the vocab size
            assert input_ids_question.max() < config.vocab_size, "Question IDs exceed vocab size"
            assert input_ids_chosen.max() < config.vocab_size, "Chosen IDs exceed vocab size"
            assert input_ids_rejected.max() < config.vocab_size, "Rejected IDs exceed vocab size"

            # Print maximum index for each part of the input
            print("Max index in input_ids_question:", input_ids_question.max().item())
            print("Max index in input_ids_chosen:", input_ids_chosen.max().item())
            print("Max index in input_ids_rejected:", input_ids_rejected.max().item())

            optimizer.zero_grad()
            # Combine the input ids and ensure they do not exceed max_seq_length
            combined_input_ids = torch.cat([input_ids_question, input_ids_chosen, input_ids_rejected], dim=1)
            max_seq_length = config.max_length  # Ensure this uses the correct attribute from your config
            if combined_input_ids.size(1) > max_seq_length:
                combined_input_ids = combined_input_ids[:, :max_seq_length]          
            print(f"Max index in combined_input_ids: {combined_input_ids.max().item()}")  # Add this debug line

            # Forward pass
            logits, loss = self.transformer_dpo(combined_input_ids, labels)  # Use combined_input_ids here

            # logits, loss = self.transformer_dpo(input_ids_question, input_ids_chosen, input_ids_rejected, labels)

            # Backward pass and optimization
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        average_loss = total_loss / len(train_loader)
        torch.save(self.transformer_dpo.state_dict(), save_path)
        return average_loss




    def train_language_model_rag(self, model, save_path, train_loader, device, num_epochs=5, lr=1e-8, weight_decay=1e-4):
        # Enable QLORA during training
        model.decoder.toggle_qlora(True)

        # Define loss function and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
        scheduler = StepLR(optimizer, step_size=4, gamma=0.98)

        initial_loss = None

        # Training loop
        for epoch in range(num_epochs):
            model.train()
            total_loss = 0

            for batch_idx, batch in enumerate(train_loader):
                inputs = batch['input_ids'].to(device)
                targets = batch['labels'].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs.view(-1, outputs.size(-1)), targets.view(-1))

                # Check for NaN in loss
                if math.isnan(loss.item()):
                    print("Encountered NaN loss, stopping training")
                    break

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
                optimizer.step()

                total_loss += loss.item()

                # Set the initial_loss after the first batch of the first epoch
                if initial_loss is None and batch_idx == 0:
                    initial_loss = loss.item()

            scheduler.step()

            # Check for NaN in total_loss
            if math.isnan(total_loss):
                print(f"Epoch {epoch+1}/{num_epochs} stopped due to NaN loss")
                break
            else:
                print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(train_loader)}")

            # Average loss for the epoch
            average_loss = total_loss / len(train_loader)

            # Update alpha at the end of each epoch based on the average loss
            new_alpha = self.calculate_new_alpha(average_loss, initial_loss)
            for layer in model.modules():
                if isinstance(layer, Expert.QLORALayer):
                    layer.update_alpha(new_alpha)

        # Toggle QLORA off after training
        model.decoder.toggle_qlora(False)
        average_loss = total_loss / len(train_loader)

        torch.save(model.state_dict(), save_path)
        print("Training Complete")
        return model, average_loss

    def train_dpr_encoders(self, train_data, context_encoder, question_encoder, optimizer_context, optimizer_question, epochs , context_save_path, question_save_path):
        loss_function = nn.CosineEmbeddingLoss()

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        for epoch in range(epochs):
            total_loss = 0

            for i in range(len(train_data["queries"])):
                query = train_data["queries"][i]
                context = train_data["contexts"][i]

                # Ensure query is a string
                if not isinstance(query, str):
                    raise ValueError("Query must be a string.")
                tokenized_query = tokenizer(query, return_tensors="pt", padding=True, truncation=True, max_length=512)

                # Ensure context is a string
                if isinstance(context, dict):
                    # If context is a dictionary, extract the text content. This is a placeholder and might need adjustment.
                    context_text = context.get("text", "")
                elif isinstance(context, str):
                    context_text = context
                else:
                    raise ValueError("Context must be a string or a dictionary containing a text field.")
                tokenized_context = tokenizer(context_text, return_tensors="pt", padding=True, truncation=True, max_length=512)

                question_embeddings = question_encoder(input_ids=tokenized_query['input_ids'], attention_mask=tokenized_query['attention_mask'])
                context_embeddings = context_encoder(input_ids=tokenized_context['input_ids'], attention_mask=tokenized_context['attention_mask'])

                # The labels tensor should have the same first dimension size as the input tensors
                labels = torch.tensor([1.0] * question_embeddings.size(0), dtype=torch.float).to(question_embeddings.device)

                loss = loss_function(question_embeddings, context_embeddings, labels)

                optimizer_context.zero_grad()
                optimizer_question.zero_grad()
                loss.backward()
                optimizer_context.step()
                optimizer_question.step()

                total_loss += loss.item()

            print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss / len(train_data['queries'])}")
        average_loss = total_loss / len(train_data['queries'])
        torch.save(context_encoder.state_dict(), context_save_path)
        torch.save(question_encoder.state_dict(), question_save_path)
        return (context_encoder, question_encoder), average_loss
       
    def train_language_model_transformer(self, train_loader, device, vocab_size, save_path):
        model = self.lmt
        
        # Define loss function and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-8, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.98)
        num_epochs = 5
        
        initial_loss = None

        # Training loop
        for epoch in range(num_epochs):
            model.train()
            model.decoder.toggle_qlora(True)
            total_loss = 0

            for batch_idx, batch in enumerate(train_loader):
                inputs, targets = batch['input_ids'].to(device), batch['labels'].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))

                # Check for NaN in loss
                if math.isnan(loss.item()):
                    print("Encountered NaN loss, stopping training")
                    break

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
                optimizer.step()

                total_loss += loss.item()

                # Set the initial_loss after the first batch of the first epoch
                if initial_loss is None and batch_idx == 0:
                    initial_loss = loss.item()

            scheduler.step()

            # Check for NaN in total_loss
            if math.isnan(total_loss):
                print(f"Epoch {epoch+1}/{num_epochs} stopped due to NaN loss")
                break
            else:
                print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(train_loader)}")

            # Average loss for the epoch
            average_loss = total_loss / len(train_loader)

            # Update alpha at the end of each epoch based on the average loss
            new_alpha = self.calculate_new_alpha(average_loss, initial_loss)
            for layer in model.modules():
                if isinstance(layer, Expert.QLORALayer):
                    layer.update_alpha(new_alpha)

        average_loss = total_loss / len(train_loader)
        torch.save(model.state_dict(), save_path)
        return model, average_loss


    def train_expert(self, train_loader, train_data, optimizer, main_loss_function, aux_loss_weight, device,save_path, accumulation_steps=4, num_epochs=5):
            self.train()  # Set the model to training mode
            for epoch in range(num_epochs):
                total_loss = 0
                optimizer.zero_grad()  # Initialize gradients to zero
                for batch_idx, batch in enumerate(train_loader):
                    inputs, attention_mask, targets = batch['input_ids'].to(device), batch['attention_mask'].to(device), batch['labels'].to(device)

                    # Calculate start and end indices for current batch in train_data
                    start_idx = batch_idx * batch['input_ids'].size(0)
                    end_idx = start_idx + batch['input_ids'].size(0)

                    # Extract current_queries and current_contexts for the batch
                    current_queries = train_data['queries'][start_idx:end_idx]
                    current_contexts = train_data['contexts'][start_idx:end_idx]

                    # Call to the model forward function
                    outputs, aux_loss = self(inputs, attention_mask, current_contexts, current_queries)

                    # Calculate loss and accumulate
                    main_loss = main_loss_function(outputs.view(-1, outputs.size(-1)), targets.view(-1))
                    loss = (main_loss + aux_loss_weight * aux_loss) / accumulation_steps
                    loss.backward()

                    if (batch_idx + 1) % accumulation_steps == 0:
                        optimizer.step()
                        optimizer.zero_grad()

                    total_loss += loss.item() * accumulation_steps  # Scale back up

                average_loss = total_loss / len(train_loader)
                print(f'End of Epoch {epoch+1}, Average Loss: {average_loss}')
            average_loss = total_loss / len(train_loader)
            torch.save(self.state_dict(), save_path)
            return self, average_loss


##################################
# Training transformer_with_dpo

torch.autograd.set_detect_anomaly(True)

def preprocess_dpo_data(examples):
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    vocab_size = len(tokenizer.vocab)  # Make sure this matches your embedding layer's vocab size

    # Define max sequence length
    max_seq_length = 512

    # Tokenize 'question', 'chosen', and 'rejected' fields
    tokenized_questions = tokenizer(examples['question'], padding='max_length', truncation=True, max_length=max_seq_length)
    tokenized_chosen = tokenizer(examples['chosen'], padding='max_length', truncation=True, max_length=max_seq_length)
    tokenized_rejected = tokenizer(examples['rejected'], padding='max_length', truncation=True, max_length=max_seq_length)

    # Prepare the labels: 1 for 'chosen' and 0 for 'rejected'
    labels = [1 if i % 2 == 0 else 0 for i in range(len(examples['question']))]

    return {
        'input_ids_question': tokenized_questions['input_ids'],
        'attention_mask_question': tokenized_questions['attention_mask'],
        'input_ids_chosen': tokenized_chosen['input_ids'],
        'attention_mask_chosen': tokenized_chosen['attention_mask'],
        'input_ids_rejected': tokenized_rejected['input_ids'],
        'attention_mask_rejected': tokenized_rejected['attention_mask'],
        'labels': labels
    }



# Load the DPO dataset from Hugging Face
dpo_dataset = load_dataset("Intel/orca_dpo_pairs")

# Apply the preprocessing to the dataset
dpo_dataset = dpo_dataset.map(preprocess_dpo_data, batched=True)


# You can convert to PyTorch tensors after processing
dpo_dataset.set_format(type='torch', columns=['input_ids_question', 'attention_mask_question', 'input_ids_chosen', 'attention_mask_chosen', 'input_ids_rejected', 'attention_mask_rejected', 'labels'])

train_loader = DataLoader(dpo_dataset['train'], batch_size=2, shuffle=True)

# Instantiate the Expert model and optimizer
config = ExpertConfig()
expert_model = Expert(config)

model_vocab_size = expert_model.lmt.decoder.word_embedding.num_embeddings
print(f"Model vocab size: {model_vocab_size}")
print(f"Model embedding size: {expert_model.lmt.decoder.word_embedding.embedding_dim}")
print(f"Configured max length: {config.max_length}")

optimizer = AdamW(expert_model.parameters(), lr=1e-5)
save_path = 'D:/EXPERT_WEIGHTS/dpo_model.pth'

# Train the DPO model
label_column = 'labels'
input_columns = ['input_ids_question', 'attention_mask_question', 'input_ids_chosen', 'attention_mask_chosen', 'input_ids_rejected', 'attention_mask_rejected']
avg_loss = expert_model.train_dpo(train_loader, optimizer, label_column, input_columns, config, save_path)

# Save the model
torch.save(expert_model.transformer_dpo.state_dict(), save_path)

Model vocab size: 30522
Model embedding size: 256
Configured max length: 512
Batch 1:
Max index in input_ids_question: 20714
Max index in input_ids_chosen: 22414
Max index in input_ids_rejected: 22414
Max index in combined_input_ids: 20714
Input shape to LanguageModelTransformer: torch.Size([2, 512])
Max position index: 511, Position Embedding Size: 512


AttributeError: 'NoneType' object has no attribute 'backward'

# v3

In [37]:
# 0. Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from transformers import BertTokenizer
from torch.utils.data import DataLoader
import fitz  # PyMuPDF
from transformers import BertModel
import math
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

class ExpertConfig:
    def __init__(self, seq_len=512, head_dim=64, block_size=64, sparsity_factor=4,
                 input_dim=512, num_experts=3, vocab_size=30522, embed_size=256,
                 num_layers=6, forward_expansion=4, heads=8, dropout=0.1,
                 max_length=512,  # Updated to match seq_len for consistency
                 rank=16, device='cpu',
                 mamba_model_path='D:\\EXPERT_WEIGHTS\\mamba_model_weights.pth',
                 context_encoder_path='D:\\EXPERT_WEIGHTS\\context_encoder.pth',
                 language_model_path='D:\\EXPERT_WEIGHTS\\language_model.pth',
                 question_encoder_path='D:\\EXPERT_WEIGHTS\\question_encoder.pth',
                 dpo_model_path='D:\\EXPERT_WEIGHTS\\dpo_model_weights.pth',
                 model_name='bert-base-uncased', embedding_dim=768,
                 alpha=1, quantization_bits=8, tokenizer_name='bert-base-uncased',
                 d_model=512, d_state=2048, d_conv=3, expansion_factor=2):

        # Common hyperparameters
        self.seq_len = seq_len
        self.head_dim = head_dim
        self.block_size = block_size
        self.sparsity_factor = sparsity_factor
        self.input_dim = input_dim
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.num_layers = num_layers
        self.forward_expansion = forward_expansion
        self.heads = heads
        self.dropout = dropout
        self.max_length = max_length  # Ensure this is properly reflected in model components
        self.rank = rank

        # Model paths and device
        self.mamba_model_path = mamba_model_path
        self.context_encoder_path = context_encoder_path
        self.language_model_path = language_model_path
        self.question_encoder_path = question_encoder_path
        self.dpo_model_path = dpo_model_path
        self.device = device

        # Unique hyperparameters
        self.num_experts = num_experts
        self.model_name = model_name
        self.embedding_dim = embedding_dim
        self.alpha = alpha
        self.quantization_bits = quantization_bits
        self.tokenizer_name = tokenizer_name
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expansion_factor = expansion_factor

        # PDFs (unchanged)
        self.pdf_file_paths = [
            r'C:\Users\robbi\IEEMM\DPO.pdf', 
            r'C:\Users\robbi\IEEMM\MAMBA.pdf',
            r'C:\Users\robbi\IEEMM\QLORA.pdf',
            r'C:\Users\robbi\IEEMM\RAG.pdf',
            r'C:\Users\robbi\IEEMM\SWITCH_TRANSFORMER.pdf'
        ]
        
        # Preserving original dataset loading functionality
        self.rag_dataset = Expert.TransformerRAG.create_dataset_from_pdfs(self.pdf_file_paths)

    def validate(self):
        assert self.seq_len % self.block_size == 0, "seq_len must be divisible by block_size"
        assert self.max_length >= self.seq_len, "max_length should be equal to or greater than seq_len"



class Expert(nn.Module):

    @staticmethod
    # Load model weights function
    def load_model_weights(model, model_path):
        #checkpoint = torch.load(model_path, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
        checkpoint = torch.load(model_path, map_location=config.device)

        if isinstance(checkpoint, dict):
            # Check for 'state_dict' or 'model_state_dict' keys
            if 'state_dict' in checkpoint:
                model.load_state_dict(checkpoint['state_dict'])
            elif 'model_state_dict' in checkpoint:
                model.load_state_dict(checkpoint['model_state_dict'])
            else:
                # If no known key is found, try loading it as a raw state dictionary
                try:
                    model.load_state_dict(checkpoint)
                except RuntimeError as e:
                    raise ValueError(f"Error loading state dict: {e}")
        elif isinstance(checkpoint, nn.Module):
            # If the checkpoint is a model object, assign it directly
            model = checkpoint
        else:
            raise ValueError(f"Unsupported checkpoint format: {type(checkpoint)}")

        model.eval()
        return model

    # Flash2_Attention
    class FlashAttention2(nn.Module):
        def __init__(self, sequence_length, head_dimension, block_size):
            super(Expert.FlashAttention2, self).__init__()
            self.block_size = block_size
            # Ensure that sequence_length is divisible by block_size for simplicity
            assert sequence_length % block_size == 0

        def forward(self, Q, K, V):
            # Partitioning of inputs
            Q_blocks, K_blocks, V_blocks = self.partition_inputs(Q, K, V)

            # Efficient computation of the attention mechanism
            outputs = []
            for i, Q_block in enumerate(Q_blocks):
                output_block = self.process_block(Q_block, K_blocks, V_blocks)
                outputs.append(output_block)

            # Concatenating the processed blocks
            output = torch.cat(outputs, dim=0)
            return output

        def partition_inputs(self, Q, K, V):
            # The actual partitioning scheme should be based on sequence length, head dimension, and block size
            Q_blocks = Q.chunk(chunks=Q.size(0) // self.block_size, dim=0)
            K_blocks = K.chunk(chunks=K.size(0) // self.block_size, dim=0)
            V_blocks = V.chunk(chunks=V.size(0) // self.block_size, dim=0)
            return Q_blocks, K_blocks, V_blocks

        def process_block(self, Q_block, K_blocks, V_blocks):
            # Process each block efficiently as per FLASH2's optimized method
            # This includes computing QK^T, applying online softmax, and multiplying with V
            output_blocks = []
            for K_block, V_block in zip(K_blocks, V_blocks):
                attention_scores = torch.matmul(Q_block, K_block.transpose(-2, -1))
                attention_scores = self.online_softmax(attention_scores)
                output_block = torch.matmul(attention_scores, V_block)
                output_blocks.append(output_block)

            # Summing up the results from each block
            output_block_sum = sum(output_blocks)
            return output_block_sum

        def online_softmax(self, scores, chunk_size=128):
            # Apply softmax in chunks for large sequences
            softmaxed_scores = []
            for i in range(0, scores.size(0), chunk_size):
                chunk = scores[i:i + chunk_size, :]
                softmaxed_chunk = F.softmax(chunk, dim=1)
                softmaxed_scores.append(softmaxed_chunk)
            return torch.cat(softmaxed_scores, dim=0)

    # SparseFlash2_Attention
    class SparseFlash2Attention(nn.Module):
        def __init__(self, seq_len, head_dim, blk_size, sparsity_factor):
            super().__init__()
            self.flash_attention = Expert.FlashAttention2(seq_len, head_dim, blk_size)
            self.seq_len = seq_len
            self.head_dim = head_dim
            self.block_size = blk_size  # Storing block_size as an instance variable
            self.sparsity_factor = sparsity_factor

        def generate_sparsity_mask(self):
            mask = torch.zeros(self.seq_len, self.seq_len)
            step = self.sparsity_factor
            for i in range(0, self.seq_len, step):
                mask[i:i + step, :] = 1
            return mask.bool()

        def forward(self, Q, K, V):
            output = self.flash_attention(Q, K, V)  # output shape: [sequence_length, head_dimension]

            # Reshape output to be 3D for batch matrix multiplication
            output = output.unsqueeze(0)  # New shape: [1, sequence_length, head_dimension]

            sparsity_mask = self.generate_sparsity_mask()  # shape: [sequence_length, sequence_length]

            # Apply the sparsity mask to the output
            sparsity_mask = sparsity_mask.unsqueeze(0)  # New shape: [1, sequence_length, sequence_length]
            output = torch.bmm(sparsity_mask.float(), output.float())  # Perform batch matrix multiplication

            # Reshape the output back to 2D
            output = output.squeeze(0)  # New shape: [sequence_length, head_dimension]

            return output


    # MAMBA
    # RoPE
    class RotaryPositionalEncoding(nn.Module):
        def __init__(self, dim, max_len=5000):
            super().__init__()
            self.dim = dim
            inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
            t = torch.arange(max_len).type_as(inv_freq)
            freqs = torch.einsum('n , d -> n d', t, inv_freq)
            self.register_buffer('sin', freqs.sin())
            self.register_buffer('cos', freqs.cos())

        def forward(self, x):
            n, _, device = x.shape[1], self.dim // 2, x.device
            sin, cos = self.sin[:n].to(device), self.cos[:n].to(device)

            # Apply RoPE to even and odd indices separately
            x_even = x[..., :self.dim:2] * cos.unsqueeze(0) + torch.roll(x[..., 1:self.dim:2], shifts=1, dims=-1) * sin.unsqueeze(0)
            x_odd = x[..., 1:self.dim:2] * cos.unsqueeze(0) - torch.roll(x[..., :self.dim:2], shifts=1, dims=-1) * sin.unsqueeze(0)
            return torch.cat((x_even, x_odd), dim=-1)

    # SWIGLU
    class SwiGLU(nn.Module):
        def __init__(self, dim_in, dim_out):
            super(Expert.SwiGLU, self).__init__()
            self.fc1 = nn.Linear(dim_in, dim_out)
            self.fc2 = nn.Linear(dim_in, dim_out)

        def forward(self, x):
            gate = torch.sigmoid(self.fc2(x))
            return self.fc1(x) * gate

    class SimplifiedMAMBA(nn.Module):
        def __init__(self, config):
            super().__init__()
            # Using configuration parameters
            self.num_layers = config.num_layers
            self.d_model = config.d_model
            self.d_state = config.d_state
            self.d_conv = config.d_conv
            self.expansion_factor = config.expansion_factor
            
            self.feedforward = nn.Sequential(
                nn.Linear(self.d_model, self.d_state),
                nn.GELU(),
                nn.Linear(self.d_state, self.d_model)
            )
            self.input_embedding = nn.Linear(self.d_model, self.d_model)
            self.convs = nn.Sequential(*[nn.Conv1d(self.d_model, self.d_model, kernel_size=self.d_conv, padding=(self.d_conv // 2)) for _ in range(self.num_layers)])
            self.swiglu = Expert.SwiGLU(self.d_model, self.d_model)
            self.output_projection = nn.Linear(self.d_model, self.d_model * self.expansion_factor)

            self.initialize_weights()

        def initialize_weights(self):
            gain = nn.init.calculate_gain('relu')

            nn.init.orthogonal_(self.input_embedding.weight, gain)
            nn.init.normal_(self.input_embedding.bias, mean=0, std=0.01)

            nn.init.kaiming_uniform_(self.convs[-1].weight, a=math.sqrt(5))
            nn.init.zeros_(self.convs[-1].bias)

            nn.init.xavier_uniform_(self.feedforward[0].weight, gain=nn.init.calculate_gain('relu'))
            nn.init.zeros_(self.feedforward[0].bias)

            nn.init.xavier_uniform_(self.feedforward[2].weight, gain=nn.init.calculate_gain('linear'))
            nn.init.zeros_(self.feedforward[2].bias)

            nn.init.xavier_uniform_(self.output_projection.weight, gain=nn.init.calculate_gain('linear'))
            nn.init.zeros_(self.output_projection.bias)

        def forward(self, inputs, attention_mask=None):
            print("Input shape:", inputs.shape)

            # Apply the attention mask if provided
            if attention_mask is not None:
                inputs = inputs * attention_mask.unsqueeze(-1)

            projected_inputs = self.input_embedding(inputs)
            print("projected_inputs pre-reshape shape:", projected_inputs.shape)

            projected_inputs = projected_inputs.permute(0, 2, 1)
            print("projected_inputs post-reshape shape:", projected_inputs.shape)

            for conv in self.convs:
                projected_inputs = conv(projected_inputs)

            projected_inputs = projected_inputs.permute(0, 2, 1)
            print("projected_inputs post convolution reshape:", projected_inputs.shape)

            projected_inputs = self.swiglu(projected_inputs)
            print("projected_inputs post swiglu shape:", projected_inputs.shape)

            output = self.output_projection(projected_inputs)
            print("output shape:", output.shape)

            return output

    class SimplifiedLanguageModelMAMBA(nn.Module):
        def __init__(self, config):
            super().__init__()
            # Using configuration parameters
            self.vocab_size = config.vocab_size
            self.num_layers = config.num_layers
            self.d_model = config.d_model
            self.d_state = config.d_state
            self.d_conv = config.d_conv
            self.expansion_factor = config.expansion_factor

            self.embedding = nn.Embedding(self.vocab_size, self.d_model)
            self.pos_encoder = Expert.RotaryPositionalEncoding(self.d_model)
            self.simplified_mamba = Expert.SimplifiedMAMBA(config)
            self.output_projection = nn.Linear(self.d_model * 2, self.vocab_size)  # Adjust if needed

            self.initialize_weights()

        def initialize_weights(self):
            gain = 1.0

            nn.init.orthogonal_(self.embedding.weight, gain)
            nn.init.xavier_uniform_(self.output_projection.weight, gain=gain)

        def forward(self, input_values, attention_mask):
            embedded = self.embedding(input_values) * math.sqrt(self.d_model)
            embedded = self.pos_encoder(embedded)
            simplified_mamba_output = self.simplified_mamba(embedded, attention_mask)
            logits = self.output_projection(simplified_mamba_output)
            return logits


    class SwitchRouter(nn.Module):
        CAPACITY_FACTOR = 1  # Class constant

        class SwitchGate(nn.Module):
            def __init__(self, input_dim, num_experts):
                super(Expert.SwitchRouter.SwitchGate, self).__init__()
                self.fc1 = nn.Linear(input_dim, input_dim // 2)
                self.fc2 = nn.Linear(input_dim // 2, num_experts)

            def forward(self, x):
                x = F.relu(self.fc1(x))
                gate_scores = F.softmax(self.fc2(x), dim=-1)
                return gate_scores

        def __init__(self, config):
            super(Expert.SwitchRouter, self).__init__()
            self.config = config
            self.device = config.device
            self.router = self.SwitchGate(config.input_dim, config.num_experts).to(config.device)
            self.transformer_rag = Expert.TransformerRAG(config).to(config.device)
            self.lmt = Expert.LanguageModelTransformer(config).to(config.device)
            self.transformer_dpo = Expert.DPO(self.lmt, config.device).to(config.device)
            self.mamba = Expert.SimplifiedLanguageModelMAMBA(config).to(config.device)
            self.experts = nn.ModuleList([self.transformer_rag, self.transformer_dpo, self.mamba])
            self.input_embedding = nn.Linear(config.input_dim, config.input_dim)

        def forward(self, x, attention_mask, context_texts, question_text):
            x = x.to(self.device)
            if x.dtype != torch.float32:
                x = x.float()

            x = self.input_embedding(x)
            gate_scores = self.router(x)
            expert_indices = torch.argmax(gate_scores, dim=1)
            final_output = torch.zeros_like(x)

            for i, expert in enumerate(self.experts):
                mask = expert_indices == i
                if mask.any():
                    selected_inputs = x[mask]
                    selected_attention_mask = attention_mask[mask].to(self.device)
                    if isinstance(expert, Expert.TransformerRAG):
                        expert_output = expert.forward(context_texts, selected_inputs, selected_attention_mask, question_text)
                    else:
                        expert_output = expert(selected_inputs, selected_attention_mask)
                    final_output[mask] = expert_output

            aux_loss = self.auxiliary_loss(gate_scores)
            return final_output, aux_loss

        def auxiliary_loss(self, gate_scores):
            expert_load = gate_scores.sum(0) / gate_scores.size(0)
            loss_balancing = torch.std(expert_load)
            return loss_balancing


        @staticmethod
        def route_inputs(expert_indices, gate_scores, num_experts):
            capacity_factor_tensor = torch.tensor([Expert.SwitchRouter.CAPACITY_FACTOR], dtype=torch.float32)
            capacities = (gate_scores.size(0) * capacity_factor_tensor / num_experts).int()
            expert_counts = torch.zeros(num_experts, dtype=torch.int32)
            for idx in range(len(expert_indices)):
                selected_expert = expert_indices[idx]
                if expert_counts[selected_expert] < capacities[0]:
                    expert_counts[selected_expert] += 1
                else:
                    available_experts = (expert_counts < capacities[0]).nonzero(as_tuple=False).view(-1)
                    if len(available_experts) > 0:
                        alternative_expert = available_experts[0]
                        expert_indices[idx] = alternative_expert
                        expert_counts[alternative_expert] += 1
                    else:
                        print("No available experts to reroute. Handling overflow.")
            return expert_indices
    
    class CustomDPRContextEncoder(nn.Module):
        def __init__(self, embedding_dim):
            super(Expert.CustomDPRContextEncoder, self).__init__()  
            self.bert_model = BertModel.from_pretrained('bert-base-uncased')
            self.embedding_layer = nn.Linear(self.bert_model.config.hidden_size, embedding_dim)

        def forward(self, input_ids, attention_mask=None):
            outputs = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)
            pooled_output = outputs.pooler_output
            context_embeddings = self.embedding_layer(pooled_output)
            return context_embeddings

    class DPRQuestionEncoder(nn.Module):
        def __init__(self, embedding_dim):
            super(Expert.DPRQuestionEncoder, self).__init__() 
            self.bert_model = BertModel.from_pretrained('bert-base-uncased')
            self.embedding_layer = nn.Linear(self.bert_model.config.hidden_size, embedding_dim)

        def forward(self, input_ids, attention_mask, **kwargs):
            outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
            pooled_output = outputs.pooler_output
            embeddings = self.embedding_layer(pooled_output)
            return embeddings

    class TransformerRAG(nn.Module):
        def __init__(self, config):
            super(Expert.TransformerRAG, self).__init__()
            self.config = config
            self.context_encoder = Expert.CustomDPRContextEncoder(config.embedding_dim).to(config.device)
            self.language_model = Expert.LanguageModelTransformer(config).to(config.device)
            self.question_encoder = Expert.DPRQuestionEncoder(config.embedding_dim).to(config.device)
            self.tokenizer = BertTokenizer.from_pretrained(config.tokenizer_name)


        def forward(self, context_texts, question_input_ids, question_attention_mask, question_text):
            if question_input_ids.max() >= self.tokenizer.vocab_size:
                raise ValueError("question_input_ids contain ID(s) beyond the tokenizer's vocabulary size")
            
            aggregated_context_embeddings = []
            for context_list in context_texts:
                if not all(isinstance(context, dict) for context in context_list):
                    raise TypeError("Each item in context_texts must be a list of tokenized context dictionaries")
                
                aggregated_context_embedding = torch.zeros(self.context_encoder.bert_model.config.hidden_size, device=device)
                for context in context_list:
                    context_input_ids = context['input_ids'].to(config.device)
                    context_attention_mask = context['attention_mask'].to(config.device)
                    context_embedding = self.context_encoder(context_input_ids, context_attention_mask)
                    aggregated_context_embedding += context_embedding.mean(dim=0)
                
                aggregated_context_embeddings.append(aggregated_context_embedding / len(context_list))
            
            question_input_ids = question_input_ids.to(config.device).long()
            question_attention_mask = question_attention_mask.to(config.device).long()
            question_embeddings = self.question_encoder(input_ids=question_input_ids, attention_mask=question_attention_mask)
            
            cos_sim = torch.nn.CosineSimilarity(dim=1)
            similarities = [cos_sim(question_embeddings, context_emb.squeeze(0)) for context_emb in aggregated_context_embeddings]
            most_relevant_context_idx = torch.argmax(torch.tensor(similarities, device=config.device))
            
            combined_input = question_text + " " + context_texts[most_relevant_context_idx]
            tokenized_combined_input = self.tokenizer(combined_input, return_tensors="pt", padding=True, truncation=True, max_length=512)
            tokenized_combined_input = {k: v.to(config.device) for k, v in tokenized_combined_input.items()}
            response_logits = self.language_model(**tokenized_combined_input)
            probabilities = F.softmax(response_logits.logits, dim=-1)
            predicted_token_ids = torch.argmax(probabilities, dim=-1)
            predicted_tokens = self.tokenizer.convert_ids_to_tokens(predicted_token_ids[0])
            response = self.tokenizer.convert_tokens_to_string(predicted_tokens)
            
            return response

        @staticmethod
        def extract_text_from_pdf(file_path):
            text = ""
            with fitz.open(file_path) as doc:
                for page in doc:
                    text += page.get_text()
            return text

        @staticmethod
        def split_into_chunks(text, chunk_size):
            return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]

        @staticmethod
        def preprocess_text(text, max_length=512):
            tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            chunk_size = max_length - 50
            text_chunks = Expert.TransformerRAG.split_into_chunks(text, chunk_size)
            processed_chunks = []
            for chunk in text_chunks:
                tokenized_output = tokenizer(chunk, padding='max_length', truncation=True, max_length=max_length, return_tensors="pt")
                processed_chunk = {
                    'input_ids': tokenized_output['input_ids'],
                    'attention_mask': tokenized_output['attention_mask']
                }
                processed_chunks.append(processed_chunk)
            return processed_chunks

        @staticmethod
        def create_dataset_from_pdfs(pdf_file_paths):
            dataset = []
            for file_path in pdf_file_paths:
                text = Expert.TransformerRAG.extract_text_from_pdf(file_path)
                processed_text = Expert.TransformerRAG.preprocess_text(text)
                dataset.append(processed_text)
            return dataset

        def retrieve_contexts(self, dataset, query_embedding, top_k=5):
            similarity_scores = []
            for context in dataset:
                context_input_ids = context['input_ids'].to(self.config.device)
                context_attention_mask = context['attention_mask'].to(self.config.device)
                # Use the class's context_encoder
                context_embedding = self.context_encoder(context_input_ids, context_attention_mask)
                similarity = torch.matmul(query_embedding, context_embedding.T)
                similarity_scores.append(similarity.squeeze().item())
            top_k_indices = sorted(range(len(similarity_scores)), key=lambda i: similarity_scores[i], reverse=True)[:top_k]
            top_contexts = [dataset[i] for i in top_k_indices]
            return top_contexts

        def rag_retrieve_and_generate(self, dataset, query):
            # Tokenize the query using the class's tokenizer
            tokenized_query = self.tokenizer(query, return_tensors="pt", padding=True, truncation=True, max_length=512).to(self.config.device)
            input_ids = tokenized_query['input_ids']
            attention_mask = tokenized_query['attention_mask']
            # Use the class's question_encoder
            encoded_query = self.question_encoder(input_ids, attention_mask)
            relevant_contexts = self.retrieve_contexts(dataset, encoded_query)
            # Assuming generate_response is a method of LanguageModelTransformer that accepts tokenized contexts and generates a response
            response = self.language_model.generate_response(relevant_contexts)
            return response




    class LORALayer(nn.Module):
        def __init__(self, input_dim, output_dim, rank, alpha=1):
            super(Expert.LORALayer, self).__init__()
            self.rank = rank
            self.alpha = alpha

            # Original weight and bias of the linear layer
            self.weight = nn.Parameter(torch.Tensor(output_dim, input_dim))
            #print("self.weight Shape:", self.weight.shape)
            self.bias = nn.Parameter(torch.Tensor(output_dim))

            # LORA specific parameters
            self.A = nn.Parameter(torch.Tensor(input_dim, rank))
            self.B = nn.Parameter(torch.Tensor(rank, output_dim))

            self.reset_parameters()

        def reset_parameters(self):
            nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
            nn.init.zeros_(self.bias)
            nn.init.normal_(self.A, 0, 0.02)
            nn.init.normal_(self.B, 0, 0.02)

        def forward(self, x):
            #print("LORALayer Input Shape:", x.shape)
            
            original_size = x.size()
            batch_size, seq_len, _ = x.shape
            x_flattened = x.reshape(-1, original_size[-1])

            # Compute lora_adjustment for each input in the batch
            lora_adjustment = self.alpha * (x_flattened @ self.A) @ self.B
            lora_adjustment = lora_adjustment.reshape(batch_size, seq_len, -1)

            #print("Adjusted lora_adjustment Shape:", lora_adjustment.shape)
            #print("self.weight Shape:", self.weight.shape)
            
            # Apply linear transformation to x_flattened
            x_transformed = nn.functional.linear(x_flattened, self.weight, self.bias)
            x_transformed = x_transformed.reshape(batch_size, seq_len, -1)

            # Add lora_adjustment to the transformed x
            x = x_transformed + lora_adjustment
            #print("LORALayer Output Shape:", x.shape)

            return x

    class QLORALayer(nn.Module):
        def __init__(self, input_dim, output_dim, rank, alpha=1, quantization_bits=8):
            super(Expert.QLORALayer, self).__init__()
            self.rank = rank
            self.alpha = alpha
            self.quantization_bits = quantization_bits

            # Original weight and bias
            self.weight = nn.Parameter(torch.Tensor(output_dim, input_dim))
            self.bias = nn.Parameter(torch.Tensor(output_dim))

            # QLORA specific parameters
            self.A = nn.Parameter(torch.Tensor(input_dim, rank))
            self.B = nn.Parameter(torch.Tensor(rank, output_dim))

            self.dropout = nn.Dropout(0.1)
            self.layer_norm = nn.LayerNorm(output_dim)

            self.reset_parameters()

        def reset_parameters(self):
            nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
            nn.init.zeros_(self.bias)
            nn.init.normal_(self.A, 0, 0.02)
            nn.init.normal_(self.B, 0, 0.02)

        def quantize(self, x, num_bits):
            # Implement a simple quantization method
            scale = x.abs().max()
            x_quantized = torch.round(x / scale * (2**num_bits - 1))
            return x_quantized, scale

        def forward(self, x):
            #print("QLORALayer Input Shape:", x.shape)
            original_size = x.size()
            batch_size, seq_len, _ = x.shape
            x_flattened = x.reshape(-1, original_size[-1])

            A_quantized, scale_A = self.quantize(self.A, self.quantization_bits)
            B_quantized, scale_B = self.quantize(self.B, self.quantization_bits)

            # Compute lora_adjustment for each input in the batch
            lora_adjustment = self.alpha * (x_flattened @ (A_quantized / scale_A)) @ (B_quantized / scale_B)
            lora_adjustment = lora_adjustment.reshape(batch_size, seq_len, -1)
            lora_adjustment = self.dropout(lora_adjustment)
            #print("Adjusted lora_adjustment Shape:", lora_adjustment.shape)
            #print("self.weight Shape:", self.weight.shape)

            # Apply linear transformation to x_flattened
            x_transformed = nn.functional.linear(x_flattened, self.weight, self.bias)
            x_transformed = x_transformed.reshape(batch_size, seq_len, -1)

            # Add lora_adjustment to the transformed x
            x = x_transformed + lora_adjustment
            x = self.layer_norm(x)

            #print("QLORALayer Output Shape:", x.shape)

            return x
        
        def update_alpha(self, new_alpha):
            """
            Update the alpha scaling factor.
            """
            self.alpha = new_alpha

    class MultiHeadAttention(nn.Module):
        def __init__(self, embed_size, heads):
            super(Expert.MultiHeadAttention, self).__init__()
            self.embed_size = embed_size
            self.heads = heads
            self.head_dim = embed_size // heads

            assert (
                self.head_dim * heads == embed_size
            ), "Embedding size needs to be divisible by heads"

            self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
            self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
            self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
            self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

        def forward(self, values, keys, query, mask):
            N = query.shape[0]
            value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

            # Split the embedding into self.heads different pieces
            values = values.reshape(N, value_len, self.heads, self.head_dim)
            keys = keys.reshape(N, key_len, self.heads, self.head_dim)
            queries = query.reshape(N, query_len, self.heads, self.head_dim)

            values = self.values(values)
            keys = self.keys(keys)
            queries = self.queries(queries)

            # Einsum does the matrix multiplication for query*keys for each training example
            # with every other training example, then sum it up
            attention = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

            if mask is not None:
                attention = attention.masked_fill(mask == 0, float("-1e20"))

            attention = torch.softmax(attention / (self.embed_size ** (1 / 2)), dim=3)

            out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
                N, query_len, self.heads * self.head_dim
            )

            out = self.fc_out(out)
            return out

    class TransformerBlock(nn.Module):
        def __init__(self, embed_size, heads, dropout, forward_expansion, rank):
            super(Expert.TransformerBlock, self).__init__()
            self.attention = Expert.MultiHeadAttention(embed_size, heads)
            self.norm1 = nn.LayerNorm(embed_size)
            self.norm2 = nn.LayerNorm(embed_size)

            self.feed_forward = nn.Sequential(
                Expert.LORALayer(embed_size, forward_expansion * embed_size, rank),
                nn.ReLU(),
                Expert.LORALayer(forward_expansion * embed_size, embed_size, rank),
            )

            self.dropout = nn.Dropout(dropout)

        def forward(self, value, key, query, mask):
            attention = self.attention(value, key, query, mask)
            x = self.dropout(self.norm1(attention + query))
            forward = self.feed_forward(x)
            out = self.dropout(self.norm2(forward + x))
            return out

    class LanguageModelDecoder(nn.Module):
        def __init__(self, vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, max_length, rank):
            super(Expert.LanguageModelDecoder, self).__init__()
            self.word_embedding = nn.Embedding(vocab_size, embed_size)
            self.position_embedding = nn.Embedding(max_length, embed_size)

            # Adding BatchNorm layers
            self.bn1 = nn.BatchNorm1d(embed_size)
            self.bn2 = nn.BatchNorm1d(embed_size)

            self.layers = nn.ModuleList(
                [
                    Expert.TransformerBlock(embed_size, heads, dropout, forward_expansion, rank)
                    for _ in range(num_layers)
                ]
            )

            # QLORA layers
            self.qlora_feed_forward = nn.Sequential(
                Expert.QLORALayer(embed_size, forward_expansion * embed_size, rank),
                nn.ReLU(),
                Expert.QLORALayer(forward_expansion * embed_size, embed_size, rank),
            )
            self.use_qlora = False  # Flag to toggle QLORA

            self.fc_out = nn.Linear(embed_size, vocab_size)
            self.dropout = nn.Dropout(dropout)

        def forward(self, x, trg_mask):
            N, seq_length = x.shape
            if seq_length > self.position_embedding.num_embeddings:
                raise ValueError(f"Sequence length {seq_length} exceeds maximum allowed {self.position_embedding.num_embeddings}")
            positions = torch.arange(0, seq_length).expand(N, seq_length).to(config.device)
            x = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

            print(f"Max position index: {positions.max().item()}, Position Embedding Size: {self.position_embedding.num_embeddings}")


            # Transpose for BatchNorm, apply batch normalization, and then transpose back
            x = x.transpose(1, 2)
            x = self.bn1(x)
            x = x.transpose(1, 2)

            for layer in self.layers:
                x = layer(x, x, x, trg_mask)
                if self.use_qlora:
                    x = self.qlora_feed_forward(x)

            # Transpose for BatchNorm, apply batch normalization, and then transpose back
            x = x.transpose(1, 2)
            x = self.bn2(x)
            x = x.transpose(1, 2)

            out = self.fc_out(x)
            #print(f"shape of output of forward method of LanguageModelDecoder: {out.shape} ")

            return out

        def toggle_qlora(self, use_qlora: bool):
            self.use_qlora = use_qlora

    class LanguageModelTransformer(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.vocab_size = config.vocab_size  # Use vocab_size from ExpertConfig
            self.embed_size = config.embed_size  # Use embed_size from ExpertConfig
            self.num_layers = config.num_layers  # Use num_layers from ExpertConfig
            self.forward_expansion = config.forward_expansion  # Use forward_expansion from ExpertConfig
            self.heads = config.heads  # Use heads from ExpertConfig
            self.dropout = config.dropout  # Use dropout from ExpertConfig
            self.max_length = config.max_length  # Use max_length from ExpertConfig
            self.rank = config.rank  # Use rank from ExpertConfig
            self.tokenizer_name = config.tokenizer_name  # Use tokenizer_name from ExpertConfig

            self.tokenizer = BertTokenizer.from_pretrained(self.tokenizer_name)

            self.decoder = Expert.LanguageModelDecoder(
                self.vocab_size,
                self.embed_size,
                self.num_layers,
                self.heads,
                self.forward_expansion,
                self.dropout,
                self.max_length,
                self.rank,
            )

        def forward(self, trg, attention_mask=None):  # Remove attention_mask here since it's not used
            print(f"Input shape to LanguageModelTransformer: {trg.shape}")
            trg_mask = self.make_trg_mask(trg)
            out = self.decoder(trg, trg_mask)  # Do not pass attention_mask here
            return out

        def make_trg_mask(self, trg):
            N, trg_len = trg.shape
            trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(N, 1, trg_len, trg_len).to(trg.device)
            return trg_mask

        def toggle_qlora(self, use_qlora: bool):
            self.decoder.toggle_qlora(use_qlora)

        def generate_response(self, input_ids, attention_mask):
            logits = self.forward(input_ids=input_ids, attention_mask=attention_mask)
            probabilities = F.softmax(logits, dim=-1)
            predicted_token_id = torch.argmax(probabilities, dim=-1)
            predicted_tokens = [self.tokenizer.convert_ids_to_tokens(idx.item()) for idx in predicted_token_id]
            response = self.tokenizer.convert_tokens_to_string(predicted_tokens)
            return response


    class DPO(nn.Module):
        def __init__(self, language_model, device):
            super(Expert.DPO, self).__init__()
            self.language_model = language_model
            self.device = device
        
        def forward(self, input_ids_question, input_ids_chosen=None, input_ids_rejected=None, labels=None):
            # Concatenate for DPO-specific data or use regular input handling
            combined_input_ids = torch.cat([input_ids_question, input_ids_chosen, input_ids_rejected], dim=1) if input_ids_chosen is not None and input_ids_rejected is not None else input_ids_question
            
            # Ensure the inputs do not exceed the language model's vocab size
            combined_input_ids = torch.clamp(combined_input_ids, 0, self.language_model.vocab_size - 1)
            
            # Pass through the language model
            logits = self.language_model(combined_input_ids)  # Adjust based on your model's output
            
            # Calculate loss if labels are provided
            loss = None
            if labels is not None:
                # Adjust the shape of logits if necessary to match labels
                logits = logits.view(-1)  # Adjust based on your model's output and labels shape
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels.float().view(-1))
            
            return logits, loss





    def __init__(self, config: ExpertConfig):
        super().__init__()
        # Validate configuration
        config.validate()

        # 1. SparseFlash2_attention
        self.sparse_flash2_attention = Expert.SparseFlash2Attention(
            config.seq_len, 
            config.head_dim, 
            config.block_size, 
            config.sparsity_factor
        )

        # Inside the Expert class definition
        self.lmt = Expert.LanguageModelTransformer(config).to(config.device)

        self.transformer_rag = Expert.TransformerRAG(config).to(config.device)

        # Corrected instantiation of SimplifiedLanguageModelMAMBA
        self.mamba = Expert.SimplifiedLanguageModelMAMBA(config).to(config.device)

        # Now initialize DPO with the language model transformer
        self.transformer_dpo = Expert.DPO(self.lmt, config.device).to(config.device)

        # 2. LayerNorm and Dropout
        self.layer_norm = nn.LayerNorm(config.input_dim)
        self.dropout = nn.Dropout(config.dropout)

        # 3. Internal Switch Routing
        self.switch_router = Expert.SwitchRouter(config)

        # 4. QLORA Layer
        self.qlora = Expert.QLORALayer(config.input_dim, config.input_dim, config.rank)


    def forward(self, x, attention_mask, context_texts, question_text):
        # 1. SparseFlash2_attention
        x = self.sparse_flash2_attention(x)

        # 2. LayerNorm and Dropout
        x = self.dropout(self.layer_norm(x))

        # 3. Internal Switch Routing
        x, aux_loss = self.switch_router(x, attention_mask, context_texts, question_text)

        # 4. QLORA Layer
        x = self.qlora(x)

        return x, aux_loss

    @staticmethod
    def calculate_new_alpha(current_loss, initial_loss, initial_alpha=1.0, final_alpha=0.1):
        """
        Calculate a new alpha value based on the current loss.
        """
        if current_loss >= initial_loss:
            return initial_alpha  # Keep initial alpha if loss isn't decreasing

        loss_ratio = current_loss / initial_loss
        alpha_range = initial_alpha - final_alpha
        new_alpha = final_alpha + (alpha_range * loss_ratio)
        return new_alpha  
    
    def train_dpo(self, train_loader, optimizer, label_column, input_columns, config, save_path):
        self.transformer_dpo.train()  # Set the model to training mode
        total_loss = 0

        for i, batch in enumerate(train_loader):
            print(f"Batch {i+1}:")

            input_ids_question = batch['input_ids_question'].to(config.device)
            input_ids_chosen = batch['input_ids_chosen'].to(config.device)
            input_ids_rejected = batch['input_ids_rejected'].to(config.device)
            
            # Adjust labels here to ensure they're in the correct shape
            labels = batch[label_column].view(-1).to(config.device)

            optimizer.zero_grad()

            # Trim inputs before concatenation
            max_length_per_input = config.max_length // 3  # Divide max_length by the number of inputs
            input_ids_question = input_ids_question[:, :max_length_per_input]
            input_ids_chosen = input_ids_chosen[:, :max_length_per_input]
            input_ids_rejected = input_ids_rejected[:, :max_length_per_input]

            # Combine the input ids
            combined_input_ids = torch.cat([input_ids_question, input_ids_chosen, input_ids_rejected], dim=1)

            # Forward pass
            logits, loss = self.transformer_dpo(input_ids_question, input_ids_chosen, input_ids_rejected, labels)

            # Backward pass and optimization
            if loss is not None:
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

        average_loss = total_loss / len(train_loader)
        torch.save(self.transformer_dpo.state_dict(), save_path)
        return average_loss




    def train_language_model_rag(self, model, save_path, train_loader, device, num_epochs=5, lr=1e-8, weight_decay=1e-4):
        # Enable QLORA during training
        model.decoder.toggle_qlora(True)

        # Define loss function and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
        scheduler = StepLR(optimizer, step_size=4, gamma=0.98)

        initial_loss = None

        # Training loop
        for epoch in range(num_epochs):
            model.train()
            total_loss = 0

            for batch_idx, batch in enumerate(train_loader):
                inputs = batch['input_ids'].to(device)
                targets = batch['labels'].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs.view(-1, outputs.size(-1)), targets.view(-1))

                # Check for NaN in loss
                if math.isnan(loss.item()):
                    print("Encountered NaN loss, stopping training")
                    break

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
                optimizer.step()

                total_loss += loss.item()

                # Set the initial_loss after the first batch of the first epoch
                if initial_loss is None and batch_idx == 0:
                    initial_loss = loss.item()

            scheduler.step()

            # Check for NaN in total_loss
            if math.isnan(total_loss):
                print(f"Epoch {epoch+1}/{num_epochs} stopped due to NaN loss")
                break
            else:
                print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(train_loader)}")

            # Average loss for the epoch
            average_loss = total_loss / len(train_loader)

            # Update alpha at the end of each epoch based on the average loss
            new_alpha = self.calculate_new_alpha(average_loss, initial_loss)
            for layer in model.modules():
                if isinstance(layer, Expert.QLORALayer):
                    layer.update_alpha(new_alpha)

        # Toggle QLORA off after training
        model.decoder.toggle_qlora(False)
        average_loss = total_loss / len(train_loader)

        torch.save(model.state_dict(), save_path)
        print("Training Complete")
        return model, average_loss

    def train_dpr_encoders(self, train_data, context_encoder, question_encoder, optimizer_context, optimizer_question, epochs , context_save_path, question_save_path):
        loss_function = nn.CosineEmbeddingLoss()

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        for epoch in range(epochs):
            total_loss = 0

            for i in range(len(train_data["queries"])):
                query = train_data["queries"][i]
                context = train_data["contexts"][i]

                # Ensure query is a string
                if not isinstance(query, str):
                    raise ValueError("Query must be a string.")
                tokenized_query = tokenizer(query, return_tensors="pt", padding=True, truncation=True, max_length=512)

                # Ensure context is a string
                if isinstance(context, dict):
                    # If context is a dictionary, extract the text content. This is a placeholder and might need adjustment.
                    context_text = context.get("text", "")
                elif isinstance(context, str):
                    context_text = context
                else:
                    raise ValueError("Context must be a string or a dictionary containing a text field.")
                tokenized_context = tokenizer(context_text, return_tensors="pt", padding=True, truncation=True, max_length=512)

                question_embeddings = question_encoder(input_ids=tokenized_query['input_ids'], attention_mask=tokenized_query['attention_mask'])
                context_embeddings = context_encoder(input_ids=tokenized_context['input_ids'], attention_mask=tokenized_context['attention_mask'])

                # The labels tensor should have the same first dimension size as the input tensors
                labels = torch.tensor([1.0] * question_embeddings.size(0), dtype=torch.float).to(question_embeddings.device)

                loss = loss_function(question_embeddings, context_embeddings, labels)

                optimizer_context.zero_grad()
                optimizer_question.zero_grad()
                loss.backward()
                optimizer_context.step()
                optimizer_question.step()

                total_loss += loss.item()

            print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss / len(train_data['queries'])}")
        average_loss = total_loss / len(train_data['queries'])
        torch.save(context_encoder.state_dict(), context_save_path)
        torch.save(question_encoder.state_dict(), question_save_path)
        return (context_encoder, question_encoder), average_loss
       
    def train_language_model_transformer(self, train_loader, device, vocab_size, save_path):
        model = self.lmt
        
        # Define loss function and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-8, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.98)
        num_epochs = 5
        
        initial_loss = None

        # Training loop
        for epoch in range(num_epochs):
            model.train()
            model.decoder.toggle_qlora(True)
            total_loss = 0

            for batch_idx, batch in enumerate(train_loader):
                inputs, targets = batch['input_ids'].to(device), batch['labels'].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))

                # Check for NaN in loss
                if math.isnan(loss.item()):
                    print("Encountered NaN loss, stopping training")
                    break

                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
                optimizer.step()

                total_loss += loss.item()

                # Set the initial_loss after the first batch of the first epoch
                if initial_loss is None and batch_idx == 0:
                    initial_loss = loss.item()

            scheduler.step()

            # Check for NaN in total_loss
            if math.isnan(total_loss):
                print(f"Epoch {epoch+1}/{num_epochs} stopped due to NaN loss")
                break
            else:
                print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(train_loader)}")

            # Average loss for the epoch
            average_loss = total_loss / len(train_loader)

            # Update alpha at the end of each epoch based on the average loss
            new_alpha = self.calculate_new_alpha(average_loss, initial_loss)
            for layer in model.modules():
                if isinstance(layer, Expert.QLORALayer):
                    layer.update_alpha(new_alpha)

        average_loss = total_loss / len(train_loader)
        torch.save(model.state_dict(), save_path)
        return model, average_loss


    def train_expert(self, train_loader, train_data, optimizer, main_loss_function, aux_loss_weight, device,save_path, accumulation_steps=4, num_epochs=5):
            self.train()  # Set the model to training mode
            for epoch in range(num_epochs):
                total_loss = 0
                optimizer.zero_grad()  # Initialize gradients to zero
                for batch_idx, batch in enumerate(train_loader):
                    inputs, attention_mask, targets = batch['input_ids'].to(device), batch['attention_mask'].to(device), batch['labels'].to(device)

                    # Calculate start and end indices for current batch in train_data
                    start_idx = batch_idx * batch['input_ids'].size(0)
                    end_idx = start_idx + batch['input_ids'].size(0)

                    # Extract current_queries and current_contexts for the batch
                    current_queries = train_data['queries'][start_idx:end_idx]
                    current_contexts = train_data['contexts'][start_idx:end_idx]

                    # Call to the model forward function
                    outputs, aux_loss = self(inputs, attention_mask, current_contexts, current_queries)

                    # Calculate loss and accumulate
                    main_loss = main_loss_function(outputs.view(-1, outputs.size(-1)), targets.view(-1))
                    loss = (main_loss + aux_loss_weight * aux_loss) / accumulation_steps
                    loss.backward()

                    if (batch_idx + 1) % accumulation_steps == 0:
                        optimizer.step()
                        optimizer.zero_grad()

                    total_loss += loss.item() * accumulation_steps  # Scale back up

                average_loss = total_loss / len(train_loader)
                print(f'End of Epoch {epoch+1}, Average Loss: {average_loss}')
            average_loss = total_loss / len(train_loader)
            torch.save(self.state_dict(), save_path)
            return self, average_loss


##################################
# Training transformer_with_dpo

torch.autograd.set_detect_anomaly(True)

def preprocess_dpo_data(examples):
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    vocab_size = len(tokenizer.vocab)  # Make sure this matches your embedding layer's vocab size

    # Define max sequence length
    max_seq_length = 512

    # Tokenize 'question', 'chosen', and 'rejected' fields
    tokenized_questions = tokenizer(examples['question'], padding='max_length', truncation=True, max_length=max_seq_length)
    tokenized_chosen = tokenizer(examples['chosen'], padding='max_length', truncation=True, max_length=max_seq_length)
    tokenized_rejected = tokenizer(examples['rejected'], padding='max_length', truncation=True, max_length=max_seq_length)

    # Prepare the labels: 1 for 'chosen' and 0 for 'rejected'
    labels = [1 if i % 2 == 0 else 0 for i in range(len(examples['question']))]

    return {
        'input_ids_question': tokenized_questions['input_ids'],
        'attention_mask_question': tokenized_questions['attention_mask'],
        'input_ids_chosen': tokenized_chosen['input_ids'],
        'attention_mask_chosen': tokenized_chosen['attention_mask'],
        'input_ids_rejected': tokenized_rejected['input_ids'],
        'attention_mask_rejected': tokenized_rejected['attention_mask'],
        'labels': labels
    }



# Load the DPO dataset from Hugging Face
dpo_dataset = load_dataset("Intel/orca_dpo_pairs")

# Apply the preprocessing to the dataset
dpo_dataset = dpo_dataset.map(preprocess_dpo_data, batched=True)


# You can convert to PyTorch tensors after processing
dpo_dataset.set_format(type='torch', columns=['input_ids_question', 'attention_mask_question', 'input_ids_chosen', 'attention_mask_chosen', 'input_ids_rejected', 'attention_mask_rejected', 'labels'])

train_loader = DataLoader(dpo_dataset['train'], batch_size=2, shuffle=True)

# Instantiate the Expert model and optimizer
config = ExpertConfig()
expert_model = Expert(config)

model_vocab_size = expert_model.lmt.decoder.word_embedding.num_embeddings
print(f"Model vocab size: {model_vocab_size}")
print(f"Model embedding size: {expert_model.lmt.decoder.word_embedding.embedding_dim}")
print(f"Configured max length: {config.max_length}")

optimizer = AdamW(expert_model.parameters(), lr=1e-5)
save_path = 'D:/EXPERT_WEIGHTS/dpo_model.pth'

# Train the DPO model
label_column = 'labels'
input_columns = ['input_ids_question', 'attention_mask_question', 'input_ids_chosen', 'attention_mask_chosen', 'input_ids_rejected', 'attention_mask_rejected']
avg_loss = expert_model.train_dpo(train_loader, optimizer, label_column, input_columns, config, save_path)

# Save the model
torch.save(expert_model.transformer_dpo.state_dict(), save_path)

Model vocab size: 30522
Model embedding size: 256
Configured max length: 512
Batch 1:
Input shape to LanguageModelTransformer: torch.Size([2, 510])
Max position index: 509, Position Embedding Size: 512


ValueError: Target size (torch.Size([2])) must be the same as input size (torch.Size([31132440]))