# TransformerDPO

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import random_split
from transformers import BertTokenizer
from datasets import load_dataset


class LORALayer(nn.Module):
    def __init__(self, input_dim, output_dim, rank, alpha=1):
        super(LORALayer, self).__init__()
        self.rank = rank
        self.alpha = alpha

        # Original weight and bias of the linear layer
        self.weight = nn.Parameter(torch.Tensor(output_dim, input_dim))
        #print("self.weight Shape:", self.weight.shape)
        self.bias = nn.Parameter(torch.Tensor(output_dim))

        # LORA specific parameters
        self.A = nn.Parameter(torch.Tensor(input_dim, rank))
        self.B = nn.Parameter(torch.Tensor(rank, output_dim))

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.zeros_(self.bias)
        nn.init.normal_(self.A, 0, 0.02)
        nn.init.normal_(self.B, 0, 0.02)

    def forward(self, x):
        #print("LORALayer Input Shape:", x.shape)
        
        original_size = x.size()
        batch_size, seq_len, _ = x.shape
        x_flattened = x.reshape(-1, original_size[-1])

        # Compute lora_adjustment for each input in the batch
        lora_adjustment = self.alpha * (x_flattened @ self.A) @ self.B
        lora_adjustment = lora_adjustment.reshape(batch_size, seq_len, -1)

        #print("Adjusted lora_adjustment Shape:", lora_adjustment.shape)
        #print("self.weight Shape:", self.weight.shape)
        
        # Apply linear transformation to x_flattened
        x_transformed = nn.functional.linear(x_flattened, self.weight, self.bias)
        x_transformed = x_transformed.reshape(batch_size, seq_len, -1)

        # Add lora_adjustment to the transformed x
        x = x_transformed + lora_adjustment
        #print("LORALayer Output Shape:", x.shape)

        return x

class QLORALayer(nn.Module):
    def __init__(self, input_dim, output_dim, rank, alpha=1, quantization_bits=8):
        super(QLORALayer, self).__init__()
        self.rank = rank
        self.alpha = alpha
        self.quantization_bits = quantization_bits

        # Original weight and bias
        self.weight = nn.Parameter(torch.Tensor(output_dim, input_dim))
        self.bias = nn.Parameter(torch.Tensor(output_dim))

        # QLORA specific parameters
        self.A = nn.Parameter(torch.Tensor(input_dim, rank))
        self.B = nn.Parameter(torch.Tensor(rank, output_dim))

        self.dropout = nn.Dropout(0.1)
        self.layer_norm = nn.LayerNorm(output_dim)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.zeros_(self.bias)
        nn.init.normal_(self.A, 0, 0.02)
        nn.init.normal_(self.B, 0, 0.02)

    def quantize(self, x, num_bits):
        # Implement a simple quantization method
        scale = x.abs().max()
        x_quantized = torch.round(x / scale * (2**num_bits - 1))
        return x_quantized, scale

    def forward(self, x):
        #print("QLORALayer Input Shape:", x.shape)
        original_size = x.size()
        batch_size, seq_len, _ = x.shape
        x_flattened = x.reshape(-1, original_size[-1])

        A_quantized, scale_A = self.quantize(self.A, self.quantization_bits)
        B_quantized, scale_B = self.quantize(self.B, self.quantization_bits)

        # Compute lora_adjustment for each input in the batch
        lora_adjustment = self.alpha * (x_flattened @ (A_quantized / scale_A)) @ (B_quantized / scale_B)
        lora_adjustment = lora_adjustment.reshape(batch_size, seq_len, -1)
        lora_adjustment = self.dropout(lora_adjustment)
        #print("Adjusted lora_adjustment Shape:", lora_adjustment.shape)
        #print("self.weight Shape:", self.weight.shape)

        # Apply linear transformation to x_flattened
        x_transformed = nn.functional.linear(x_flattened, self.weight, self.bias)
        x_transformed = x_transformed.reshape(batch_size, seq_len, -1)

        # Add lora_adjustment to the transformed x
        x = x_transformed + lora_adjustment
        x = self.layer_norm(x)

        #print("QLORALayer Output Shape:", x.shape)

        return x
    
    def update_alpha(self, new_alpha):
        """
        Update the alpha scaling factor.
        """
        self.alpha = new_alpha

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(MultiHeadAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding into self.heads different pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # Einsum does the matrix multiplication for query*keys for each training example
        # with every other training example, then sum it up
        attention = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        if mask is not None:
            attention = attention.masked_fill(mask == 0, float("-1e20"))

        attention = torch.softmax(attention / (self.embed_size ** (1 / 2)), dim=3)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )

        out = self.fc_out(out)
        return out

class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion, rank):
        super(TransformerBlock, self).__init__()
        self.attention = MultiHeadAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            LORALayer(embed_size, forward_expansion * embed_size, rank),
            nn.ReLU(),
            LORALayer(forward_expansion * embed_size, embed_size, rank),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out

class LanguageModelDecoder(nn.Module):
    def __init__(self, vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, max_length, rank):
        super(LanguageModelDecoder, self).__init__()
        self.word_embedding = nn.Embedding(vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        # Adding BatchNorm layers
        self.bn1 = nn.BatchNorm1d(embed_size)
        self.bn2 = nn.BatchNorm1d(embed_size)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(embed_size, heads, dropout, forward_expansion, rank)
                for _ in range(num_layers)
            ]
        )

        # QLORA layers
        self.qlora_feed_forward = nn.Sequential(
            QLORALayer(embed_size, forward_expansion * embed_size, rank),
            nn.ReLU(),
            QLORALayer(forward_expansion * embed_size, embed_size, rank),
        )
        self.use_qlora = False  # Flag to toggle QLORA

        self.fc_out = nn.Linear(embed_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, trg_mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length) #.to(x.device)
        x = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        # Transpose for BatchNorm, apply batch normalization, and then transpose back
        x = x.transpose(1, 2)
        x = self.bn1(x)
        x = x.transpose(1, 2)

        for layer in self.layers:
            x = layer(x, x, x, trg_mask)
            if self.use_qlora:
                x = self.qlora_feed_forward(x)

        # Transpose for BatchNorm, apply batch normalization, and then transpose back
        x = x.transpose(1, 2)
        x = self.bn2(x)
        x = x.transpose(1, 2)

        out = self.fc_out(x)
        return out

    def toggle_qlora(self, use_qlora: bool):
        self.use_qlora = use_qlora

class LanguageModelTransformer(nn.Module):
    def __init__(self, vocab_size, embed_size=256, num_layers=6, forward_expansion=4, heads=8, dropout=0, max_length=100, rank=16):
        super(LanguageModelTransformer, self).__init__()

        self.decoder = LanguageModelDecoder(
            vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            max_length,
            rank,
        )

    def forward(self, trg):
        trg_mask = self.make_trg_mask(trg)
        out = self.decoder(trg, trg_mask)
        return out

    def make_trg_mask(self, trg):
        N, trg_len = trg.shape
        trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(
            N, 1, trg_len, trg_len
        ).to(trg.device)

        return trg_mask
    
    def process_dpo_inputs(self, chosen_ids, chosen_mask, rejected_ids, rejected_mask):
        """
        Process inputs for Direct Preference Optimization, generating logits for both 'chosen' and 'rejected' responses.

        Parameters:
        chosen_ids (torch.Tensor): Token IDs for chosen responses.
        chosen_mask (torch.Tensor): Attention mask for chosen responses.
        rejected_ids (torch.Tensor): Token IDs for rejected responses.
        rejected_mask (torch.Tensor): Attention mask for rejected responses.

        Returns:
        dict: A dictionary containing logits for 'chosen' and 'rejected' responses.
        """

        # Process 'chosen' responses
        chosen_outputs = self.decoder(chosen_ids, chosen_mask)
        chosen_logits = chosen_outputs  # Assuming your decoder returns logits directly

        # Process 'rejected' responses
        rejected_outputs = self.decoder(rejected_ids, rejected_mask)
        rejected_logits = rejected_outputs  # Assuming your decoder returns logits directly

        return {
            "chosen_logits": chosen_logits,
            "rejected_logits": rejected_logits
        }

# Define vocabulary size and dummy data parameters
NUM_WORDS = 1000  # Example vocabulary size
sequence_length = 30  # Sequence length for the LanguageDataset
dummy_data_size = 1000  # Total number of tokens in the dummy dataset


# Load dataset
dataset = load_dataset('wikipedia', '20220301.simple')

# Initialize tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
vocab_size = tokenizer.vocab_size

def tokenize_function(examples):
    # Tokenize the text
    tokenized_output = tokenizer(examples['text'], padding='max_length', truncation=True, max_length=sequence_length)
    
    # Shift input_ids to create labels and truncate the last token
    labels = [seq[1:] + [tokenizer.pad_token_id] for seq in tokenized_output['input_ids']]
    tokenized_output['labels'] = labels
    
    return tokenized_output

tokenized_datasets = dataset.map(tokenize_function, batched=True)

# Set format for PyTorch
tokenized_datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

train_loader = DataLoader(tokenized_datasets['train'], batch_size=64, shuffle=True)

# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create the model instance
model = LanguageModelTransformer(
    vocab_size=vocab_size,  # Use the vocab size from the tokenizer
    embed_size=256,
    num_layers=6,
    forward_expansion=4,
    heads=8,
    dropout=0,
    max_length=100,
    rank=16
).to(device)


def calculate_new_alpha(current_loss, initial_loss, initial_alpha=1.0, final_alpha=0.1):
    """
    Calculate a new alpha value based on the current loss.
    """
    if current_loss >= initial_loss:
        return initial_alpha  # Keep initial alpha if loss isn't decreasing

    loss_ratio = current_loss / initial_loss
    alpha_range = initial_alpha - final_alpha
    new_alpha = final_alpha + (alpha_range * loss_ratio)
    return new_alpha

# Enable QLORA during training
model.decoder.toggle_qlora(True)

initial_loss = None
# Training loop
# Assuming model is an instance of LanguageModelTransformer
# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=1e-8, weight_decay=1e-4)
scheduler = StepLR(optimizer, step_size=4, gamma=0.98)
num_epochs = 5

# Training loop
for epoch in range(num_epochs):
    model.train()
    model.decoder.toggle_qlora(True)
    total_loss = 0

    for batch_idx, batch in enumerate(train_loader):
        inputs = batch['input_ids'].to(device)
        targets = batch['labels'].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))

        # Check for NaN in loss
        if math.isnan(loss.item()):
            print("Encountered NaN loss, stopping training")
            break

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
        optimizer.step()

        total_loss += loss.item()

        # Set the initial_loss after the first batch of the first epoch
        if initial_loss is None and batch_idx == 0:
            initial_loss = loss.item()

    scheduler.step()

    # Check for NaN in total_loss
    if math.isnan(total_loss):
        print(f"Epoch {epoch+1}/{num_epochs} stopped due to NaN loss")
        break
    else:
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(train_loader)}")

    # Average loss for the epoch
    average_loss = total_loss / len(train_loader)

    # Update alpha at the end of each epoch based on the average loss
    new_alpha = calculate_new_alpha(average_loss, initial_loss)
    for layer in model.modules():
        if isinstance(layer, QLORALayer):
            layer.update_alpha(new_alpha)

    #model.decoder.toggle_qlora(False)




AttributeError: partially initialized module 'datasets' has no attribute 'utils' (most likely due to a circular import)

# TransformerRAG

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import random_split
from transformers import BertTokenizer
from datasets import load_dataset
from transformers import BertModel
import fitz  # PyMuPDF

def extract_text_from_pdf(file_path):
    text = ""
    with fitz.open(file_path) as doc:
        for page in doc:
            text += page.get_text()
    return text


def split_into_chunks(text, chunk_size):
    # Split the text into chunks of chunk_size
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]

def preprocess_text(text, max_length=512):
    # Initialize tokenizer
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # Split the text into smaller chunks to maintain context
    # The chunk size is slightly less than max_length to account for special tokens
    chunk_size = max_length - 50  # Adjust this value based on the model's requirements
    text_chunks = split_into_chunks(text, chunk_size)

    # Process each chunk
    processed_chunks = []
    for chunk in text_chunks:
        tokenized_output = tokenizer(chunk, padding='max_length', truncation=True, max_length=max_length, return_tensors="pt")
        processed_chunk = {
            'input_ids': tokenized_output['input_ids'],
            'attention_mask': tokenized_output['attention_mask']
        }
        processed_chunks.append(processed_chunk)

    return processed_chunks


def create_dataset_from_pdfs(pdf_file_paths):
    dataset = []
    for file_path in pdf_file_paths:
        text = extract_text_from_pdf(file_path)
        processed_text = preprocess_text(text)
        dataset.append(processed_text)
    return dataset

def retrieve_contexts(dataset, query_embedding, top_k=5):
    # Assume dataset is a list of dictionaries with each dictionary containing 'input_ids' and 'attention_mask'
    # for a particular context and that each context has been processed through a DPRContextEncoder to get embeddings

    # Placeholder for storing similarity scores
    similarity_scores = []

    # Iterate over each context in the dataset
    for context in dataset:
        context_input_ids = context['input_ids'].to(device)
        context_attention_mask = context['attention_mask'].to(device)

        # Assuming context_encoder is an instance of CustomDPRContextEncoder that's already trained
        # and available in your scope
        context_embedding = context_encoder(context_input_ids, context_attention_mask)

        # Compute similarity (e.g., using dot product)
        similarity = torch.matmul(query_embedding, context_embedding.T)

        similarity_scores.append(similarity.squeeze().item())

    # Sort contexts based on similarity scores and retrieve top_k indices
    top_k_indices = sorted(range(len(similarity_scores)), key=lambda i: similarity_scores[i], reverse=True)[:top_k]

    # Retrieve top_k contexts
    top_contexts = [dataset[i] for i in top_k_indices]

    return top_contexts

def rag_retrieve_and_generate(dataset, query):
    # Instantiate the question encoder
    question_encoder = DPRQuestionEncoder()

    # Encode the query
    encoded_query = question_encoder(query)

    # Retrieve relevant context
    # This involves finding the most similar documents in the dataset
    # For simplicity, this is represented as a function 'retrieve_contexts'
    relevant_contexts = retrieve_contexts(dataset, encoded_query)

    # Language model for generation
    language_model = LanguageModelTransformer()

    # Generate a response based on the retrieved contexts
    # This step may involve further formatting or preprocessing
    response = language_model.generate_response(relevant_contexts)

    return response

# pdfs
pdf_file_paths = [r'C:\Users\robbi\IEEMM\DPO.pdf', 
                  r'C:\Users\robbi\IEEMM\MAMBA.pdf',
                  r'C:\Users\robbi\IEEMM\QLORA.pdf',
                  r'C:\Users\robbi\IEEMM\RAG.pdf',
                  r'C:\Users\robbi\IEEMM\SWITCH_TRANSFORMER.pdf']

dataset = create_dataset_from_pdfs(pdf_file_paths)


class CustomDPRContextEncoder(nn.Module):
    def __init__(self, model_name='bert-base-uncased', embedding_dim=768):
        super(CustomDPRContextEncoder, self).__init__()
        # Transformer-based model, e.g., BERT
        self.bert_model = BertModel.from_pretrained(model_name)
        # Additional layer to produce fixed-size embeddings
        self.embedding_layer = nn.Linear(self.bert_model.config.hidden_size, embedding_dim)

    def forward(self, input_ids, attention_mask=None):
        # Generate outputs from the BERT model
        outputs = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)
        # Use the pooled output for creating embeddings
        pooled_output = outputs.pooler_output
        # Pass through the embedding layer
        context_embeddings = self.embedding_layer(pooled_output)
        return context_embeddings

class DPRQuestionEncoder(nn.Module):
    def __init__(self, model_name='bert-base-uncased', embedding_dim=768):
        super(DPRQuestionEncoder, self).__init__()
        self.bert = BertModel.from_pretrained(model_name)
        self.embedding_layer = nn.Linear(self.bert.config.hidden_size, embedding_dim)

    def forward(self, input_ids, attention_mask, **kwargs):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        embeddings = self.embedding_layer(pooled_output)
        return embeddings

class LORALayer(nn.Module):
    def __init__(self, input_dim, output_dim, rank, alpha=1):
        super(LORALayer, self).__init__()
        self.rank = rank
        self.alpha = alpha

        # Original weight and bias of the linear layer
        self.weight = nn.Parameter(torch.Tensor(output_dim, input_dim))
        #print("self.weight Shape:", self.weight.shape)
        self.bias = nn.Parameter(torch.Tensor(output_dim))

        # LORA specific parameters
        self.A = nn.Parameter(torch.Tensor(input_dim, rank))
        self.B = nn.Parameter(torch.Tensor(rank, output_dim))

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.zeros_(self.bias)
        nn.init.normal_(self.A, 0, 0.02)
        nn.init.normal_(self.B, 0, 0.02)

    def forward(self, x):
        #print("LORALayer Input Shape:", x.shape)
        
        original_size = x.size()
        batch_size, seq_len, _ = x.shape
        x_flattened = x.reshape(-1, original_size[-1])

        # Compute lora_adjustment for each input in the batch
        lora_adjustment = self.alpha * (x_flattened @ self.A) @ self.B
        lora_adjustment = lora_adjustment.reshape(batch_size, seq_len, -1)

        #print("Adjusted lora_adjustment Shape:", lora_adjustment.shape)
        #print("self.weight Shape:", self.weight.shape)
        
        # Apply linear transformation to x_flattened
        x_transformed = nn.functional.linear(x_flattened, self.weight, self.bias)
        x_transformed = x_transformed.reshape(batch_size, seq_len, -1)

        # Add lora_adjustment to the transformed x
        x = x_transformed + lora_adjustment
        #print("LORALayer Output Shape:", x.shape)

        return x

class QLORALayer(nn.Module):
    def __init__(self, input_dim, output_dim, rank, alpha=1, quantization_bits=8):
        super(QLORALayer, self).__init__()
        self.rank = rank
        self.alpha = alpha
        self.quantization_bits = quantization_bits

        # Original weight and bias
        self.weight = nn.Parameter(torch.Tensor(output_dim, input_dim))
        self.bias = nn.Parameter(torch.Tensor(output_dim))

        # QLORA specific parameters
        self.A = nn.Parameter(torch.Tensor(input_dim, rank))
        self.B = nn.Parameter(torch.Tensor(rank, output_dim))

        self.dropout = nn.Dropout(0.1)
        self.layer_norm = nn.LayerNorm(output_dim)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.zeros_(self.bias)
        nn.init.normal_(self.A, 0, 0.02)
        nn.init.normal_(self.B, 0, 0.02)

    def quantize(self, x, num_bits):
        # Implement a simple quantization method
        scale = x.abs().max()
        x_quantized = torch.round(x / scale * (2**num_bits - 1))
        return x_quantized, scale

    def forward(self, x):
        #print("QLORALayer Input Shape:", x.shape)
        original_size = x.size()
        batch_size, seq_len, _ = x.shape
        x_flattened = x.reshape(-1, original_size[-1])

        A_quantized, scale_A = self.quantize(self.A, self.quantization_bits)
        B_quantized, scale_B = self.quantize(self.B, self.quantization_bits)

        # Compute lora_adjustment for each input in the batch
        lora_adjustment = self.alpha * (x_flattened @ (A_quantized / scale_A)) @ (B_quantized / scale_B)
        lora_adjustment = lora_adjustment.reshape(batch_size, seq_len, -1)
        lora_adjustment = self.dropout(lora_adjustment)
        #print("Adjusted lora_adjustment Shape:", lora_adjustment.shape)
        #print("self.weight Shape:", self.weight.shape)

        # Apply linear transformation to x_flattened
        x_transformed = nn.functional.linear(x_flattened, self.weight, self.bias)
        x_transformed = x_transformed.reshape(batch_size, seq_len, -1)

        # Add lora_adjustment to the transformed x
        x = x_transformed + lora_adjustment
        x = self.layer_norm(x)

        #print("QLORALayer Output Shape:", x.shape)

        return x
    
    def update_alpha(self, new_alpha):
        """
        Update the alpha scaling factor.
        """
        self.alpha = new_alpha

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(MultiHeadAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding into self.heads different pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # Einsum does the matrix multiplication for query*keys for each training example
        # with every other training example, then sum it up
        attention = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        if mask is not None:
            attention = attention.masked_fill(mask == 0, float("-1e20"))

        attention = torch.softmax(attention / (self.embed_size ** (1 / 2)), dim=3)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )

        out = self.fc_out(out)
        return out

class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion, rank):
        super(TransformerBlock, self).__init__()
        self.attention = MultiHeadAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            LORALayer(embed_size, forward_expansion * embed_size, rank),
            nn.ReLU(),
            LORALayer(forward_expansion * embed_size, embed_size, rank),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out

class LanguageModelDecoder(nn.Module):
    def __init__(self, vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, max_length, rank):
        super(LanguageModelDecoder, self).__init__()
        self.word_embedding = nn.Embedding(vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        # Adding BatchNorm layers
        self.bn1 = nn.BatchNorm1d(embed_size)
        self.bn2 = nn.BatchNorm1d(embed_size)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(embed_size, heads, dropout, forward_expansion, rank)
                for _ in range(num_layers)
            ]
        )

        # QLORA layers
        self.qlora_feed_forward = nn.Sequential(
            QLORALayer(embed_size, forward_expansion * embed_size, rank),
            nn.ReLU(),
            QLORALayer(forward_expansion * embed_size, embed_size, rank),
        )
        self.use_qlora = False  # Flag to toggle QLORA

        self.fc_out = nn.Linear(embed_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, trg_mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length) #.to(x.device)
        x = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        # Transpose for BatchNorm, apply batch normalization, and then transpose back
        x = x.transpose(1, 2)
        x = self.bn1(x)
        x = x.transpose(1, 2)

        for layer in self.layers:
            x = layer(x, x, x, trg_mask)
            if self.use_qlora:
                x = self.qlora_feed_forward(x)

        # Transpose for BatchNorm, apply batch normalization, and then transpose back
        x = x.transpose(1, 2)
        x = self.bn2(x)
        x = x.transpose(1, 2)

        out = self.fc_out(x)
        #print(f"shape of output of forward method of LanguageModelDecoder: {out.shape} ")

        return out

    def toggle_qlora(self, use_qlora: bool):
        self.use_qlora = use_qlora

class LanguageModelTransformer(nn.Module):
    def __init__(self, vocab_size, embed_size=256, num_layers=6, forward_expansion=4, heads=8, dropout=0, max_length=100, rank=16):
        super(LanguageModelTransformer, self).__init__()

        self.decoder = LanguageModelDecoder(
            vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            max_length,
            rank,
        )

    def forward(self, trg):
        trg_mask = self.make_trg_mask(trg)
        out = self.decoder(trg, trg_mask)
        return out

    def make_trg_mask(self, trg):
        N, trg_len = trg.shape
        trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(
            N, 1, trg_len, trg_len
        ).to(trg.device)

        return trg_mask
    
    # Function to enable or disable QLORA layers (for fine-tuning purposes)
    def toggle_qlora(self, use_qlora: bool):
        self.decoder.toggle_qlora(use_qlora)

    def generate_response(self, input_ids, attention_mask):
        # Assuming you have a forward method that returns logits
        logits = self.forward(input_ids=input_ids, attention_mask=attention_mask)
        
        # Convert logits to probabilities
        probabilities = F.softmax(logits, dim=-1)
        
        # For simplicity, using greedy decoding here. You might want to use beam search or sampling.
        predicted_token_id = torch.argmax(probabilities, dim=-1)
        
        # Convert predicted token ids to tokens
        predicted_tokens = [tokenizer.convert_ids_to_tokens(idx.item()) for idx in predicted_token_id]
        
        # Join tokens to form the response string. This is a very basic way to generate text and might not produce the best results.
        response = tokenizer.convert_tokens_to_string(predicted_tokens)
        
        return response

# Define vocabulary size and dummy data parameters
NUM_WORDS = 1000  # Example vocabulary size
sequence_length = 30  # Sequence length for the LanguageDataset
dummy_data_size = 1000  # Total number of tokens in the dummy dataset



# Load dataset
dataset = load_dataset('wikipedia', '20220301.simple')

# Initialize tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
vocab_size = tokenizer.vocab_size

def tokenize_function(examples):
    # Tokenize the text
    tokenized_output = tokenizer(examples['text'], padding='max_length', truncation=True, max_length=sequence_length)
    
    # Shift input_ids to create labels and truncate the last token
    labels = [seq[1:] + [tokenizer.pad_token_id] for seq in tokenized_output['input_ids']]
    tokenized_output['labels'] = labels
    
    return tokenized_output

tokenized_datasets = dataset.map(tokenize_function, batched=True)

# Set format for PyTorch
tokenized_datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

train_loader = DataLoader(tokenized_datasets['train'], batch_size=64, shuffle=True)

# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create the model instance
model = LanguageModelTransformer(
    vocab_size=vocab_size,  # Use the vocab size from the tokenizer
    embed_size=256,
    num_layers=6,
    forward_expansion=4,
    heads=8,
    dropout=0,
    max_length=100,
    rank=16
).to(device)


def calculate_new_alpha(current_loss, initial_loss, initial_alpha=1.0, final_alpha=0.1):
    """
    Calculate a new alpha value based on the current loss.
    """
    if current_loss >= initial_loss:
        return initial_alpha  # Keep initial alpha if loss isn't decreasing

    loss_ratio = current_loss / initial_loss
    alpha_range = initial_alpha - final_alpha
    new_alpha = final_alpha + (alpha_range * loss_ratio)
    return new_alpha

# Assuming the context and question encoders and language model are already defined as per your provided code

def train_dpr_encoders(train_data, context_encoder, question_encoder, optimizer_context, optimizer_question, epochs):
    loss_function = nn.CosineEmbeddingLoss()

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    for epoch in range(epochs):
        total_loss = 0

        for i in range(len(train_data["queries"])):
            query = train_data["queries"][i]
            context = train_data["contexts"][i]

            # Ensure query is a string
            if not isinstance(query, str):
                raise ValueError("Query must be a string.")
            tokenized_query = tokenizer(query, return_tensors="pt", padding=True, truncation=True, max_length=512)

            # Ensure context is a string
            if isinstance(context, dict):
                # If context is a dictionary, extract the text content. This is a placeholder and might need adjustment.
                context_text = context.get("text", "")
            elif isinstance(context, str):
                context_text = context
            else:
                raise ValueError("Context must be a string or a dictionary containing a text field.")
            tokenized_context = tokenizer(context_text, return_tensors="pt", padding=True, truncation=True, max_length=512)

            question_embeddings = question_encoder(input_ids=tokenized_query['input_ids'], attention_mask=tokenized_query['attention_mask'])
            context_embeddings = context_encoder(input_ids=tokenized_context['input_ids'], attention_mask=tokenized_context['attention_mask'])

            # The labels tensor should have the same first dimension size as the input tensors
            labels = torch.tensor([1.0] * question_embeddings.size(0), dtype=torch.float).to(question_embeddings.device)

            loss = loss_function(question_embeddings, context_embeddings, labels)

            optimizer_context.zero_grad()
            optimizer_question.zero_grad()
            loss.backward()
            optimizer_context.step()
            optimizer_question.step()

            total_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss / len(train_data['queries'])}")

    print("Training complete.")

def train_language_model(model, train_loader, device, num_epochs=5, lr=1e-8, weight_decay=1e-4):
    # Enable QLORA during training
    model.decoder.toggle_qlora(True)

    # Define loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = StepLR(optimizer, step_size=4, gamma=0.98)

    initial_loss = None

    # Training loop
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0

        for batch_idx, batch in enumerate(train_loader):
            inputs = batch['input_ids'].to(device)
            targets = batch['labels'].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs.view(-1, outputs.size(-1)), targets.view(-1))

            # Check for NaN in loss
            if math.isnan(loss.item()):
                print("Encountered NaN loss, stopping training")
                break

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
            optimizer.step()

            total_loss += loss.item()

            # Set the initial_loss after the first batch of the first epoch
            if initial_loss is None and batch_idx == 0:
                initial_loss = loss.item()

        scheduler.step()

        # Check for NaN in total_loss
        if math.isnan(total_loss):
            print(f"Epoch {epoch+1}/{num_epochs} stopped due to NaN loss")
            break
        else:
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(train_loader)}")

        # Average loss for the epoch
        average_loss = total_loss / len(train_loader)

        # Update alpha at the end of each epoch based on the average loss
        new_alpha = calculate_new_alpha(average_loss, initial_loss)
        for layer in model.modules():
            if isinstance(layer, QLORALayer):
                layer.update_alpha(new_alpha)

    # Toggle QLORA off after training
    model.decoder.toggle_qlora(False)

    return model


# Sample training data structure (You need to construct this with real data)
train_data = {
    "queries": [
        # Queries for DPO.pdf
        "What is Direct Preference Optimization (DPO)?",
        "How does Direct Preference Optimization work?",
        "How can I implement Direct Preference Optimization in my organization?",
        "Why does Direct Preference Optimization improve the efficiency of language modelling?",
        # Queries for MAMBA.pdf
        "What is MAMBA?",
        "How does MAMBA function?",
        "How can I build a system based on MAMBA technology?",
        "Why does MAMBA enhance the performance of its application area?",
        # Queries for QLORA.pdf
        "What is QLORA?",
        "How does QLORA operate?",
        "How can I develop a project using QLORA?",
        "Why does QLORA improve the capabilities of its relevant field?",
        # Queries for RAG.pdf
        "What is Retrieval Augmented Generation (RAG)?",
        "How does Retrieval Augmented Generation work?",
        "How can I build a Retrieval Augmented Generation model?",
        "Why does Retrieval Augmented Generation enhance language model performance?",
        # Queries for SWITCH_TRANSFORMER.pdf
        "What is the Switch Transformer model?",
        "How does the Switch Transformer model operate?",
        "How can I construct a Switch Transformer model?",
        "Why does the Switch Transformer model improve language processing tasks?"
    ],
    "contexts": [
        # Contexts from DPO.pdf
        dataset['train'][0],  # Assuming dataset[0] is the processed content of DPO.pdf
        dataset['train'][0],
        dataset['train'][0],
        dataset['train'][0],
        # Contexts from MAMBA.pdf
        dataset['train'][1],  # Assuming dataset[1] is the processed content of MAMBA.pdf
        dataset['train'][1],
        dataset['train'][1],
        dataset['train'][1],
        # Contexts from QLORA.pdf
        dataset['train'][2],  # Assuming dataset[2] is the processed content of QLORA.pdf
        dataset['train'][2],
        dataset['train'][2],
        dataset['train'][2],
        # Contexts from RAG.pdf
        dataset['train'][3],  # Assuming dataset[3] is the processed content of RAG.pdf
        dataset['train'][3],
        dataset['train'][3],
        dataset['train'][3],
        # Contexts from SWITCH_TRANSFORMER.pdf
        dataset['train'][4],  # Assuming dataset[4] is the processed content of SWITCH_TRANSFORMER.pdf
        dataset['train'][4],
        dataset['train'][4],
        dataset['train'][4]
    ]
}


# Instantiate models and optimizers
context_encoder = CustomDPRContextEncoder()
question_encoder = DPRQuestionEncoder()
rag_language_model = LanguageModelTransformer(vocab_size=vocab_size)

# Define optimizers for each model component
optimizer_context = AdamW(context_encoder.parameters(), lr=1e-5)
optimizer_question = AdamW(question_encoder.parameters(), lr=1e-5)
optimizer_language_model = AdamW(rag_language_model.parameters(), lr=1e-5)
epochs = 5
# Train the models
train_dpr_encoders(train_data, context_encoder, question_encoder, optimizer_context, optimizer_question, epochs=epochs)
train_language_model(model, train_loader, device, num_epochs=5, lr=1e-8, weight_decay=1e-4)




Epoch 1/5, Loss: 0.7378959834575654
Epoch 2/5, Loss: 0.3635985404253006
Epoch 3/5, Loss: 0.162116739153862
Epoch 4/5, Loss: 0.10286747217178345


KeyboardInterrupt: 

# MAMBA

In [7]:
from torch.optim.lr_scheduler import LambdaLR
import torch.optim as optim

import torch
import torch.nn as nn
import math
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader
from transformers import BertTokenizer
from datasets import load_dataset

from einops import rearrange


'''
device = "cuda" if torch.cuda.is_available() else "cpu"


tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
dataset = load_dataset('wikipedia', '20220301.simple')

def tokenize_function(examples):
    # Tokenize the text
    tokenized_output = tokenizer(examples['text'], padding='max_length', truncation=True, max_length=512)
    
    # Convert lists to PyTorch tensors
    input_ids = torch.tensor(tokenized_output['input_ids']).to(device)
    attention_mask = torch.tensor(tokenized_output['attention_mask']).to(device)

    # Creating labels by shifting the input_ids
    labels = input_ids[:, :-1].clone().to(device)
    labels = torch.nn.functional.pad(labels, (0, 1), value=tokenizer.pad_token_id)  # Pad labels to match sequence length
    
    tokenized_output['input_ids'] = input_ids
    tokenized_output['attention_mask'] = attention_mask
    tokenized_output['labels'] = labels

    return tokenized_output





tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])


# DataLoader - same as your working code
train_loader = DataLoader(tokenized_datasets['train'], batch_size=8, shuffle=True)
#val_loader = DataLoader(tokenized_datasets['val'], batch_size=8, sh



tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])


# DataLoader - same as your working code
train_loader = DataLoader(tokenized_datasets['train'], batch_size=8, shuffle=True)
#val_loader = DataLoader(tokenized_datasets['val'], batch_size=8, shuffle=True)
'''
# RoPE
class RotaryPositionalEncoding(nn.Module):
    def __init__(self, dim, max_len=5000):
        super().__init__()
        self.dim = dim
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_len).type_as(inv_freq)
        freqs = torch.einsum('n , d -> n d', t, inv_freq)
        self.register_buffer('sin', freqs.sin())
        self.register_buffer('cos', freqs.cos())

    def forward(self, x):
        n, _, device = x.shape[1], self.dim // 2, x.device
        sin, cos = self.sin[:n].to(device), self.cos[:n].to(device)

        # Apply RoPE to even and odd indices separately
        x_even = x[..., :self.dim:2] * cos.unsqueeze(0) + torch.roll(x[..., 1:self.dim:2], shifts=1, dims=-1) * sin.unsqueeze(0)
        x_odd = x[..., 1:self.dim:2] * cos.unsqueeze(0) - torch.roll(x[..., :self.dim:2], shifts=1, dims=-1) * sin.unsqueeze(0)
        return torch.cat((x_even, x_odd), dim=-1)




# SWIGLU
class SwiGLU(nn.Module):
    def __init__(self, dim_in, dim_out):
        super(SwiGLU, self).__init__()
        self.fc1 = nn.Linear(dim_in, dim_out)
        self.fc2 = nn.Linear(dim_in, dim_out)

    def forward(self, x):
        gate = torch.sigmoid(self.fc2(x))
        return self.fc1(x) * gate

class SimplifiedMAMBA(nn.Module):
    # Adjusted to include SwiGLU blocks
    def __init__(self, num_layers, d_model, d_state, d_conv, expansion_factor):
        super().__init__()

        self.num_layers = num_layers
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expansion_factor = expansion_factor
        self.feedforward = nn.Sequential(
            nn.Linear(d_model, d_state),
            nn.GELU(),
            nn.Linear(d_state, d_model)
        )
        self.input_embedding = nn.Linear(d_model, d_model)
        self.convs = nn.Sequential(*[nn.Conv1d(d_model, d_model, kernel_size=d_conv, padding=(d_conv // 2)) for _ in range(num_layers)])
        self.swiglu = SwiGLU(d_model, d_model)
        self.output_projection = nn.Linear(d_model, d_model * expansion_factor)  # Adjusted to match the output of SwiGLU


        self.initialize_weights()

    def initialize_weights(self):
        gain = nn.init.calculate_gain('relu')

        nn.init.orthogonal_(self.input_embedding.weight, gain)
        nn.init.normal_(self.input_embedding.bias, mean=0, std=0.01)

        nn.init.kaiming_uniform_(self.convs[-1].weight, a=math.sqrt(5))
        nn.init.zeros_(self.convs[-1].bias)

        nn.init.xavier_uniform_(self.feedforward[0].weight, gain=nn.init.calculate_gain('relu'))
        nn.init.zeros_(self.feedforward[0].bias)

        nn.init.xavier_uniform_(self.feedforward[2].weight, gain=nn.init.calculate_gain('linear'))
        nn.init.zeros_(self.feedforward[2].bias)

        nn.init.xavier_uniform_(self.output_projection.weight, gain=nn.init.calculate_gain('linear'))
        nn.init.zeros_(self.output_projection.bias)

    def forward(self, inputs, attention_mask=None):
        print("Input shape:", inputs.shape)

        # Apply the attention mask if provided
        if attention_mask is not None:
            inputs = inputs * attention_mask.unsqueeze(-1)

        projected_inputs = self.input_embedding(inputs)
        print("projected_inputs pre-reshape shape:", projected_inputs.shape)

        projected_inputs = projected_inputs.permute(0, 2, 1)
        print("projected_inputs post-reshape shape:", projected_inputs.shape)

        for conv in self.convs:
            projected_inputs = conv(projected_inputs)

        projected_inputs = projected_inputs.permute(0, 2, 1)
        print("projected_inputs post convolution reshape:", projected_inputs.shape)

        projected_inputs = self.swiglu(projected_inputs)
        print("projected_inputs post swiglu shape:", projected_inputs.shape)

        output = self.output_projection(projected_inputs)
        print("output shape:", output.shape)

        return output

class SimplifiedLanguageModelMAMBA(nn.Module):
    # Including rotary positional encodings if required
    def __init__(self, vocab_size, num_layers, d_model, d_state, d_conv, expansion_factor):
        super().__init__()

        self.vocab_size = vocab_size
        self.num_layers = num_layers
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expansion_factor = expansion_factor

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = RotaryPositionalEncoding(d_model)
        self.simplified_mamba = SimplifiedMAMBA(num_layers, d_model, d_state, d_conv, expansion_factor)
        self.output_projection = nn.Linear(d_model*2, vocab_size)

        self.initialize_weights()

    def initialize_weights(self):
        gain = 1.0

        nn.init.orthogonal_(self.embedding.weight, gain)
        nn.init.xavier_uniform_(self.output_projection.weight, gain=gain)

    def forward(self, input_values, attention_mask):
        embedded = self.embedding(input_values) * math.sqrt(self.d_model)
        embedded = self.pos_encoder(embedded)
        simplified_mamba_output = self.simplified_mamba(embedded, attention_mask)
        logits = self.output_projection(simplified_mamba_output)
        return logits


LEARNING_RATE = 5e-4
WEIGHT_DECAY =  0.1
WARMUP_STEPS = 100
TOTAL_STEPS = 1000 # we want this to be : epochs * (size of dataset / batch_size )
EPOCHS = 100
VOCAB_SIZE = 30522
NUM_LAYERS = 4
BATCH_SIZE = 8
EXPANSION_FACTOR = 2
CLIP_GRADIENT = 1.0
D_MODEL = 512  # Dimensionality of the model's embeddings
D_STATE = 2048  # Dimensionality of the intermediate state in feedforward
D_CONV = 3  # Kernel size for convolutional layers


# Instantiate the model with the required parameters
model = SimplifiedLanguageModelMAMBA(vocab_size=VOCAB_SIZE, 
                                     num_layers=NUM_LAYERS, 
                                     d_model=D_MODEL, 
                                     d_state=D_STATE, 
                                     d_conv=D_CONV, 
                                     expansion_factor=EXPANSION_FACTOR)

model.to(device)
print("Using", device)

def setup_optimizer(model, learning_rate, weight_decay, warmup_steps, total_steps):
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    # Linear warmup with cosine decay
    scheduler = LambdaLR(optimizer, lambda step: min((step + 1) / warmup_steps, 0.5 * (1 + math.cos(math.pi * step / total_steps))))

    return optimizer, scheduler


# Initialize the optimizer and scheduler with appropriate parameters
optimizer, scheduler = setup_optimizer(model, LEARNING_RATE, WEIGHT_DECAY, WARMUP_STEPS, TOTAL_STEPS)
# Adjust the train loop
def train_loop(model, loader, optimizer, scheduler):
    loss_fn = nn.CrossEntropyLoss()
    progress_bar = tqdm(range(EPOCHS))

    for epoch in progress_bar:
        model.train()
        for batch in loader:
            input_values, attention_mask, labels = batch['input_ids'], batch['attention_mask'], batch['labels']
            input_values = input_values.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(input_values, attention_mask)
            print(f"shape of output before loss fn: {outputs.shape}")
            print(f"shape of labels before loss fn: {labels.view(-1).shape}")

            loss = loss_fn(outputs.view(-1, outputs.size(-1)), labels.view(-1))

            # Backward pass and optimization
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_GRADIENT)
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()

        # avg_loss = evaluate(model, val_loader, loss_fn)  # Uncomment and define evaluate function
        # print(f'\nEpoch {epoch}: Loss={avg_loss}\n')




train_loop(model, train_loader, optimizer, scheduler)



Using cuda


  0%|          | 0/100 [00:00<?, ?it/s]

Input shape: torch.Size([64, 30, 512])
projected_inputs pre-reshape shape: torch.Size([64, 30, 512])
projected_inputs post-reshape shape: torch.Size([64, 512, 30])
projected_inputs post convolution reshape: torch.Size([64, 30, 512])
projected_inputs post swiglu shape: torch.Size([64, 30, 512])
output shape: torch.Size([64, 30, 1024])
shape of output before loss fn: torch.Size([64, 30, 30522])
shape of labels before loss fn: torch.Size([1920])
Input shape: torch.Size([64, 30, 512])
projected_inputs pre-reshape shape: torch.Size([64, 30, 512])
projected_inputs post-reshape shape: torch.Size([64, 512, 30])
projected_inputs post convolution reshape: torch.Size([64, 30, 512])
projected_inputs post swiglu shape: torch.Size([64, 30, 512])
output shape: torch.Size([64, 30, 1024])
shape of output before loss fn: torch.Size([64, 30, 30522])
shape of labels before loss fn: torch.Size([1920])
Input shape: torch.Size([64, 30, 512])
projected_inputs pre-reshape shape: torch.Size([64, 30, 512])
proje

KeyboardInterrupt: 

# v3 Switch Internal Routing

# Unique Switch Routing Dataset training

In [1]:
from datasets import load_dataset
from transformers import BertTokenizer
from torch.utils.data import DataLoader

# Load the dataset, here using Python examples
dataset = load_dataset("code_search_net", "python")
# Preprocess and tokenize the dataset
def preprocess_function(examples):
    return tokenizer(examples['func_code_string'], padding="max_length", truncation=True, max_length=512)

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

tokenized_dataset = dataset.map(preprocess_function, batched=True)
tokenized_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])

train_dataset = tokenized_dataset["train"] #.select(range(2001, 4001))  


train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

Map:   0%|          | 0/412178 [00:00<?, ? examples/s]

Map:   0%|          | 0/22176 [00:00<?, ? examples/s]

Map:   0%|          | 0/23107 [00:00<?, ? examples/s]

In [33]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from transformers import BertTokenizer
from torch.utils.data import DataLoader

# Load the dataset, here using Python examples
dataset = load_dataset("code_search_net", "python")
# Preprocess and tokenize the dataset
def preprocess_function(examples):
    # Tokenize the code strings
    tokenized_output = tokenizer(examples['func_code_string'], padding="max_length", truncation=True, max_length=512)
    
    # Create labels for language modeling by shifting the input_ids
    labels = [row[:-1] + [tokenizer.pad_token_id] for row in tokenized_output["input_ids"]]
    tokenized_output["labels"] = labels

    return tokenized_output

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Apply the preprocessing function to the dataset
tokenized_dataset = dataset.map(preprocess_function, batched=True)

# Set the dataset format to include 'input_ids', 'attention_mask', and 'labels'
tokenized_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])

# Prepare the DataLoader
train_dataset = tokenized_dataset["train"]
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)


def load_model_weights(model, model_path):
    # Load the saved file
    checkpoint = torch.load(model_path, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
    
    # Extract the model state dictionary and load it into the model
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()  # Set the model to evaluation mode
    return model

def auxiliary_loss(gate_scores, expert_capacity):
    # Load per expert
    expert_load = gate_scores.sum(0) / gate_scores.size(0)

    # Balancing loss: Encourage each expert to be utilized equally
    loss_balancing = torch.std(expert_load)
    return loss_balancing

CAPACITY_FACTOR = 1  # Adjustable factor

def route_inputs(expert_indices, gate_scores, num_experts, capacity_factor=CAPACITY_FACTOR):
    # Convert capacity_factor to a tensor for operations
    capacity_factor_tensor = torch.tensor([capacity_factor], dtype=torch.float32)

    # Calculate the capacity for each expert
    capacities = (gate_scores.size(0) * capacity_factor_tensor / num_experts).int()

    expert_counts = torch.zeros(num_experts, dtype=torch.int32)
    for idx in range(len(expert_indices)):
        selected_expert = expert_indices[idx]
        if expert_counts[selected_expert] < capacities[0]:  # Access the first element of capacities tensor
            expert_counts[selected_expert] += 1
        else:
            # Find alternative expert with available capacity
            available_experts = (expert_counts < capacities[0]).nonzero(as_tuple=False).view(-1)
            if len(available_experts) > 0:
                alternative_expert = available_experts[0]
                expert_indices[idx] = alternative_expert
                expert_counts[alternative_expert] += 1
            else:
                # Handle the scenario when no experts are available
                print("No available experts to reroute. Handling overflow.")
                # Implement logic as needed for handling this case
    return expert_indices




# 1. Transformer with DPO:
tran_dpo = r'C:\Users\robbi\IEEMM\language_model_weights.pth'
# 2. MAMBA:
mamba_model_path = r'C:\Users\robbi\IEEMM\mamba_model_weights.pth'
# 3. Transformer and RAG:
context_encoder = r'C:\Users\robbi\IEEMM\context_encoder.pth'
language_model = r'C:\Users\robbi\IEEMM\language_model.pth'
question_encoder = r'C:\Users\robbi\IEEMM\question_encoder.pth'

class TransformerDPO(nn.Module):
    def __init__(self, model_path, vocab_size, embed_size, num_layers, forward_expansion, heads, dropout, max_length, rank):
        super(TransformerDPO, self).__init__()
        self.model = LanguageModelTransformer(
            vocab_size=vocab_size, 
            embed_size=embed_size, 
            num_layers=num_layers, 
            forward_expansion=forward_expansion, 
            heads=heads, 
            dropout=dropout, 
            max_length=100,  # Update this to match the saved model's configuration
            rank=rank
        )
        # Load pre-trained weights
        self.model.load_state_dict(torch.load(model_path, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu')))


    def forward(self, input_values, attention_mask):
        # Assuming input_values and attention_mask are appropriate for your model
        return self.model(input_values, attention_mask)


class MAMBA(nn.Module):
    def __init__(self, model_path, vocab_size, num_layers, d_model, d_state, d_conv, expansion_factor):
        super(MAMBA, self).__init__()
        # Initialize MAMBA architecture with the configuration used during training
        self.model = SimplifiedLanguageModelMAMBA(
            vocab_size=vocab_size, 
            num_layers=num_layers, 
            d_model=d_model, 
            d_state=d_state, 
            d_conv=d_conv, 
            expansion_factor=expansion_factor
        )
        # Load pre-trained weights
        self.model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

    def forward(self, input_values, attention_mask):
        return self.model(input_values, attention_mask)

def process_for_transformer_rag(tokenizer, input_ids, attention_mask, context_texts):
    # Convert input_ids back to text
    # Note: This is a simplified representation. In practice, this might not be straightforward
    input_texts = [tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]

    # Process each input text with the corresponding context
    outputs = []
    for input_text, context in zip(input_texts, context_texts):
        combined_input = input_text + " " + context
        tokenized_input = tokenizer(combined_input, return_tensors="pt", padding=True, truncation=True, max_length=512)
        output = some_transformer_rag_model_processing_function(tokenized_input)
        outputs.append(output)

    return torch.stack(outputs)


class TransformerRAG(nn.Module):
    def __init__(self, context_encoder_path, language_model_path, question_encoder_path, vocab_size):
        super(TransformerRAG, self).__init__()

        # Initialize and load the pre-trained Context Encoder
        self.context_encoder = CustomDPRContextEncoder()  # Ensure this matches the actual class name and import
        self.context_encoder = load_model_weights(self.context_encoder, context_encoder_path)

        # Initialize and load the pre-trained Language Model
        self.language_model = LanguageModelTransformer(vocab_size=vocab_size)  # Ensure `vocab_size` is correctly set
        self.language_model = load_model_weights(self.language_model, language_model_path)

        # Initialize and load the pre-trained Question Encoder
        self.question_encoder = DPRQuestionEncoder()  # Ensure this matches the actual class name and import
        self.question_encoder = load_model_weights(self.question_encoder, question_encoder_path)

        # Tokenizer (ensure it's the same one used during training)
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def forward(self, context_texts, question_inputs):
        # Generate question embeddings
        question_embeddings = self.question_encoder(input_ids=question_inputs['input_ids'],
                                                    attention_mask=question_inputs['attention_mask'])

        # Process each context text
        context_embeddings = []
        for context_text in context_texts:
            tokenized_context = self.tokenizer(context_text, return_tensors="pt", padding=True, truncation=True, max_length=512, add_special_tokens=True)
            context_embedding = self.context_encoder(**tokenized_context)
            context_embeddings.append(context_embedding)

        # Simple retrieval mechanism based on cosine similarity
        cos_sim = torch.nn.CosineSimilarity(dim=1)
        similarities = [cos_sim(question_embeddings, context_emb.squeeze(0)) for context_emb in context_embeddings]
        most_relevant_context_idx = torch.argmax(torch.tensor(similarities))

        # Use the most relevant context to generate a response
        combined_input = question_inputs['input_text'] + " " + context_texts[most_relevant_context_idx]
        tokenized_combined_input = self.tokenizer(combined_input, return_tensors="pt", padding=True, truncation=True, max_length=512)
        response_logits = self.language_model(**tokenized_combined_input)

        # Convert logits to probabilities and then to tokens
        probabilities = F.softmax(response_logits.logits, dim=-1)
        predicted_token_ids = torch.argmax(probabilities, dim=-1)
        predicted_tokens = self.tokenizer.convert_ids_to_tokens(predicted_token_ids[0])

        # Assemble the predicted tokens into a coherent response
        response = self.tokenizer.convert_tokens_to_string(predicted_tokens)

        return response


class SwitchGate(nn.Module):
    def __init__(self, input_dim, num_experts):
        super(SwitchGate, self).__init__()
        self.fc1 = nn.Linear(input_dim, input_dim // 2)
        self.fc2 = nn.Linear(input_dim // 2, num_experts)

    def forward(self, x):
        print(f"switch gate input x : {x.shape}")
        x = x.float()
        print(f"x.float : {x.shape}")
        x = F.relu(self.fc1(x))
        print(f"F.relu(self.fc1(x)) : {x.shape}")
        gate_scores = F.softmax(self.fc2(x), dim=-1)
        print(f"F.softmax(self.fc2(x), dim=-1) : {gate_scores.shape}")

        return gate_scores

class SwitchTransformerRouting(nn.Module):
    def __init__(self, 
                 input_dim, 
                 num_experts, 
                 context_encoder_path, 
                 language_model_path, 
                 question_encoder_path, 
                 dpo_model_path, 
                 vocab_size, 
                 embed_size, 
                 num_layers, 
                 forward_expansion, 
                 heads, 
                 dropout, 
                 max_length, 
                 rank,
                 mamba_model_path):
        super(SwitchTransformerRouting, self).__init__()
        self.router = SwitchGate(input_dim, num_experts)
        # Instantiate TransformerRAG with the required model paths
        transformer_rag = TransformerRAG(
            context_encoder_path=context_encoder_path,
            language_model_path=language_model_path,
            question_encoder_path=question_encoder_path,
            vocab_size=vocab_size 
        )

        # Instantiate TransformerDPO with the required configurations
        transformer_dpo = TransformerDPO(
            model_path=dpo_model_path,
            vocab_size=vocab_size,  # Ensure `vocab_size` is correctly set
            embed_size=256,        # Assuming these values match your saved model's configuration
            num_layers=6,
            forward_expansion=4,
            heads=8,
            dropout=0.1,
            max_length=512,        # This should match the saved model's configuration
            rank=16
        )

        # Instantiate MAMBA with the required configurations
        mamba = MAMBA(
            model_path=mamba_model_path,
            vocab_size=30522,  # Example values, adjust according to your training setup
            num_layers=4,
            d_model=512,
            d_state=2048,
            d_conv=3,
            expansion_factor=2
        )
        #self.input_embedding = nn.Linear(actual_input_feature_size, 512)

        self.experts = nn.ModuleList([
            transformer_rag,
            transformer_dpo,
            mamba
        ])

        self.input_embedding = nn.Linear(512, input_dim)  # Assuming the actual feature size is 512


    def forward(self, x, attention_mask):
        print(f"Input shape before embedding: {x.shape}")
        x = x.float()  # Convert x to float

        x = self.input_embedding(x)  # Embedding the input to the required dimension
        print(f"Input shape after embedding: {x.shape}")        
        gate_scores = self.router(x)
        expert_indices = torch.argmax(gate_scores, dim=1)
        expert_indices = route_inputs(expert_indices, gate_scores, num_experts=len(self.experts))

        final_output = torch.zeros_like(x)
        aux_loss = 0


        for i, expert in enumerate(self.experts):
            mask = expert_indices == i
            if mask.any():
                selected_inputs = x[mask]
                selected_attention_mask = attention_mask[mask]

                if isinstance(expert, TransformerRAG):
                    # Assuming 'context_texts' is a list of context strings for each input
                    context_texts = ['context for each input'] # Replace with actual context texts
                    expert_output = process_for_transformer_rag(
                        tokenizer=self.tokenizer,
                        input_ids=selected_inputs,
                        attention_mask=selected_attention_mask,
                        context_texts=context_texts
                    )
                else:
                    # Process as usual for other experts
                    expert_output = expert(selected_inputs, selected_attention_mask)

                final_output[mask] = expert_output

        # Compute auxiliary loss for load balancing
        aux_loss += auxiliary_loss(gate_scores, expert_capacity=torch.tensor([CAPACITY_FACTOR] * len(self.experts)))

        return final_output, aux_loss


import torch
import torch.optim as optim
# Define necessary parameters for initializing the SwitchTransformerRouting model
input_dim = 512  # Example value, adjust as per your model's input dimension
num_experts = 3  # Since you have three experts (TransformerDPO, MAMBA, TransformerRAG)

# Instantiate the SwitchTransformerRouting model
model = SwitchTransformerRouting(
    input_dim=input_dim,
    num_experts=num_experts,
    context_encoder_path=context_encoder,
    language_model_path=language_model,
    question_encoder_path=question_encoder,
    dpo_model_path=tran_dpo,
    vocab_size=30522,  # Example value, adjust as per your setup
    embed_size=256,    # Example value
    num_layers=6,      # Example value
    forward_expansion=4,  # Example value
    heads=8,          # Example value
    dropout=0.1,      # Example value
    max_length=512,   # Example value
    rank=16,          # Example value
    mamba_model_path=mamba_model_path
).to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))



# Define main task loss function and optimizer
main_loss_function = torch.nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001)

# Hyperparameter for the weight of the auxiliary loss
aux_loss_weight = 0.1
# Training loop
model.train()  # Set the model to training mode

# Determine the device from the model's parameters
device = next(model.parameters()).device

for epoch in range(num_epochs):
    total_loss = 0
    for batch_idx, batch in enumerate(train_loader):
        # Extract inputs, attention_mask, and targets from the batch
        inputs = batch['input_ids']
        attention_mask = batch['attention_mask']
        targets = batch['labels']

        # Move data to the same device as the model
        inputs, attention_mask, targets = inputs.to(device), attention_mask.to(device), targets.to(device)

        optimizer.zero_grad()  # Clear gradients

        # Forward pass
        outputs, aux_loss = model(inputs, attention_mask)

        # Compute the main task loss
        main_loss = main_loss_function(outputs.view(-1, outputs.size(-1)), targets.view(-1))

        # Combine the main task loss with the auxiliary loss
        total_loss = main_loss + aux_loss_weight * aux_loss

        # Backward pass and optimize
        total_loss.backward()
        optimizer.step()

        # Logging
        if batch_idx % 100 == 0:  # Adjust print frequency according to your needs
            print(f'Epoch: {epoch+1}, Batch: {batch_idx}, Loss: {total_loss.item()}')

    print(f'End of Epoch {epoch+1}, Average Loss: {total_loss / len(train_loader)}')




Input shape before embedding: torch.Size([8, 512])
Input shape after embedding: torch.Size([8, 512])
switch gate input x : torch.Size([8, 512])
x.float : torch.Size([8, 512])
F.relu(self.fc1(x)) : torch.Size([8, 256])
F.softmax(self.fc2(x), dim=-1) : torch.Size([8, 3])
No available experts to reroute. Handling overflow.
No available experts to reroute. Handling overflow.


TypeError: process_for_transformer_rag() missing 2 required positional arguments: 'attention_mask' and 'context_texts'

# testing/debugging

In [21]:
# Importing necessary libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from transformers import BertTokenizer
from torch.utils.data import DataLoader

# Load the dataset
dataset = load_dataset("code_search_net", "python")

# Define the tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Preprocess and tokenize the dataset
def preprocess_function(examples):
    return tokenizer(examples['func_code_string'], padding="max_length", truncation=True, max_length=512)

tokenized_dataset = dataset.map(preprocess_function, batched=True)
tokenized_dataset.set_format('torch', columns=['input_ids', 'attention_mask'])

# Define training dataset and DataLoader
train_dataset = tokenized_dataset["train"]
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# Define a dummy model for testing
class DummyModel(nn.Module):
    def __init__(self):
        super(DummyModel, self).__init__()
        self.linear = nn.Linear(512, 10)  # An arbitrary linear layer

    def forward(self, x):
        print(f"Input shape at the beginning of forward: {x.shape}")
        x = x.to(torch.float32)  # Convert input to float
        x = self.linear(x)
        print(f"Output shape after linear layer: {x.shape}")
        return x


# Instantiate the dummy model
model = DummyModel()

# Iterate over the DataLoader and pass the batches through the model
for batch in train_loader:
    inputs = batch['input_ids']
    print(f"Batch shape from DataLoader: {inputs.shape}")
    outputs = model(inputs)  # Forward pass through the model

    # Break after first batch to avoid long outputs
    break



Batch shape from DataLoader: torch.Size([8, 512])
Input shape at the beginning of forward: torch.Size([8, 512])
Output shape after linear layer: torch.Size([8, 10])


In [6]:
!pip uninstall datasets
!pip install datasets


^C


# v4 Switch Router

In [12]:
checkpoint = torch.load(tran_dpo, map_location='cpu')
print(checkpoint.keys())


odict_keys(['decoder.word_embedding.weight', 'decoder.position_embedding.weight', 'decoder.bn1.weight', 'decoder.bn1.bias', 'decoder.bn1.running_mean', 'decoder.bn1.running_var', 'decoder.bn1.num_batches_tracked', 'decoder.bn2.weight', 'decoder.bn2.bias', 'decoder.bn2.running_mean', 'decoder.bn2.running_var', 'decoder.bn2.num_batches_tracked', 'decoder.layers.0.attention.values.weight', 'decoder.layers.0.attention.keys.weight', 'decoder.layers.0.attention.queries.weight', 'decoder.layers.0.attention.fc_out.weight', 'decoder.layers.0.attention.fc_out.bias', 'decoder.layers.0.norm1.weight', 'decoder.layers.0.norm1.bias', 'decoder.layers.0.norm2.weight', 'decoder.layers.0.norm2.bias', 'decoder.layers.0.feed_forward.0.weight', 'decoder.layers.0.feed_forward.0.bias', 'decoder.layers.0.feed_forward.0.A', 'decoder.layers.0.feed_forward.0.B', 'decoder.layers.0.feed_forward.2.weight', 'decoder.layers.0.feed_forward.2.bias', 'decoder.layers.0.feed_forward.2.A', 'decoder.layers.0.feed_forward.2.B

In [3]:
# 0. Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from transformers import BertTokenizer
from torch.utils.data import DataLoader
import fitz  # PyMuPDF
from transformers import BertModel
import math
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

torch.autograd.set_detect_anomaly(True)

#############################################################################
# 1. Preprocessing Data

# Load the dataset
code_dataset = load_dataset("code_search_net", "python")

# Tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Preprocess and tokenize dataset
def preprocess_function(examples):
    tokenized_output = tokenizer(examples['func_code_string'], padding="max_length", truncation=True, max_length=512)
    labels = [row[:-1] + [tokenizer.pad_token_id] for row in tokenized_output["input_ids"]]
    tokenized_output["labels"] = labels
    return tokenized_output

tokenized_dataset = code_dataset.map(preprocess_function, batched=True)
tokenized_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])

# DataLoader
train_dataset = tokenized_dataset["train"]
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#############################################################################
# 2. Sub-Model Weights
# 1. Transformer with DPO:
tran_dpo = r'C:\Users\robbi\IEEMM\language_model_weights.pth'
# 2. MAMBA:
mamba_model_path = r'C:\Users\robbi\IEEMM\mamba_model_weights.pth'
# 3. Transformer and RAG:
context_encoder = r'C:\Users\robbi\IEEMM\context_encoder.pth'
language_model = r'C:\Users\robbi\IEEMM\language_model.pth'
question_encoder = r'C:\Users\robbi\IEEMM\question_encoder.pth'

# Load model weights function
def load_model_weights(model, model_path):
    checkpoint = torch.load(model_path, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

    if isinstance(checkpoint, dict):
        # Check for 'state_dict' or 'model_state_dict' keys
        if 'state_dict' in checkpoint:
            model.load_state_dict(checkpoint['state_dict'])
        elif 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            # If no known key is found, try loading it as a raw state dictionary
            try:
                model.load_state_dict(checkpoint)
            except RuntimeError as e:
                raise ValueError(f"Error loading state dict: {e}")
    elif isinstance(checkpoint, nn.Module):
        # If the checkpoint is a model object, assign it directly
        model = checkpoint
    else:
        raise ValueError(f"Unsupported checkpoint format: {type(checkpoint)}")

    model.eval()
    return model
#############################################################################
# 3. MAMBA
# RoPE
class RotaryPositionalEncoding(nn.Module):
    def __init__(self, dim, max_len=5000):
        super().__init__()
        self.dim = dim
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_len).type_as(inv_freq)
        freqs = torch.einsum('n , d -> n d', t, inv_freq)
        self.register_buffer('sin', freqs.sin())
        self.register_buffer('cos', freqs.cos())

    def forward(self, x):
        n, _, device = x.shape[1], self.dim // 2, x.device
        sin, cos = self.sin[:n].to(device), self.cos[:n].to(device)

        # Apply RoPE to even and odd indices separately
        x_even = x[..., :self.dim:2] * cos.unsqueeze(0) + torch.roll(x[..., 1:self.dim:2], shifts=1, dims=-1) * sin.unsqueeze(0)
        x_odd = x[..., 1:self.dim:2] * cos.unsqueeze(0) - torch.roll(x[..., :self.dim:2], shifts=1, dims=-1) * sin.unsqueeze(0)
        return torch.cat((x_even, x_odd), dim=-1)

# SWIGLU
class SwiGLU(nn.Module):
    def __init__(self, dim_in, dim_out):
        super(SwiGLU, self).__init__()
        self.fc1 = nn.Linear(dim_in, dim_out)
        self.fc2 = nn.Linear(dim_in, dim_out)

    def forward(self, x):
        gate = torch.sigmoid(self.fc2(x))
        return self.fc1(x) * gate

class SimplifiedMAMBA(nn.Module):
    # Adjusted to include SwiGLU blocks
    def __init__(self, num_layers, d_model, d_state, d_conv, expansion_factor):
        super().__init__()

        self.num_layers = num_layers
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expansion_factor = expansion_factor
        self.feedforward = nn.Sequential(
            nn.Linear(d_model, d_state),
            nn.GELU(),
            nn.Linear(d_state, d_model)
        )
        self.input_embedding = nn.Linear(d_model, d_model)
        self.convs = nn.Sequential(*[nn.Conv1d(d_model, d_model, kernel_size=d_conv, padding=(d_conv // 2)) for _ in range(num_layers)])
        self.swiglu = SwiGLU(d_model, d_model)
        self.output_projection = nn.Linear(d_model, d_model * expansion_factor)  # Adjusted to match the output of SwiGLU


        self.initialize_weights()

    def initialize_weights(self):
        gain = nn.init.calculate_gain('relu')

        nn.init.orthogonal_(self.input_embedding.weight, gain)
        nn.init.normal_(self.input_embedding.bias, mean=0, std=0.01)

        nn.init.kaiming_uniform_(self.convs[-1].weight, a=math.sqrt(5))
        nn.init.zeros_(self.convs[-1].bias)

        nn.init.xavier_uniform_(self.feedforward[0].weight, gain=nn.init.calculate_gain('relu'))
        nn.init.zeros_(self.feedforward[0].bias)

        nn.init.xavier_uniform_(self.feedforward[2].weight, gain=nn.init.calculate_gain('linear'))
        nn.init.zeros_(self.feedforward[2].bias)

        nn.init.xavier_uniform_(self.output_projection.weight, gain=nn.init.calculate_gain('linear'))
        nn.init.zeros_(self.output_projection.bias)

    def forward(self, inputs, attention_mask=None):
        print("Input shape:", inputs.shape)

        # Apply the attention mask if provided
        if attention_mask is not None:
            inputs = inputs * attention_mask.unsqueeze(-1)

        projected_inputs = self.input_embedding(inputs)
        print("projected_inputs pre-reshape shape:", projected_inputs.shape)

        projected_inputs = projected_inputs.permute(0, 2, 1)
        print("projected_inputs post-reshape shape:", projected_inputs.shape)

        for conv in self.convs:
            projected_inputs = conv(projected_inputs)

        projected_inputs = projected_inputs.permute(0, 2, 1)
        print("projected_inputs post convolution reshape:", projected_inputs.shape)

        projected_inputs = self.swiglu(projected_inputs)
        print("projected_inputs post swiglu shape:", projected_inputs.shape)

        output = self.output_projection(projected_inputs)
        print("output shape:", output.shape)

        return output

class SimplifiedLanguageModelMAMBA(nn.Module):
    # Including rotary positional encodings if required
    def __init__(self, vocab_size, num_layers, d_model, d_state, d_conv, expansion_factor):
        super().__init__()

        self.vocab_size = vocab_size
        self.num_layers = num_layers
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expansion_factor = expansion_factor

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = RotaryPositionalEncoding(d_model)
        self.simplified_mamba = SimplifiedMAMBA(num_layers, d_model, d_state, d_conv, expansion_factor)
        self.output_projection = nn.Linear(d_model*2, vocab_size)

        self.initialize_weights()

    def initialize_weights(self):
        gain = 1.0

        nn.init.orthogonal_(self.embedding.weight, gain)
        nn.init.xavier_uniform_(self.output_projection.weight, gain=gain)

    def forward(self, input_values, attention_mask):
        embedded = self.embedding(input_values) * math.sqrt(self.d_model)
        embedded = self.pos_encoder(embedded)
        simplified_mamba_output = self.simplified_mamba(embedded, attention_mask)
        logits = self.output_projection(simplified_mamba_output)
        return logits


LEARNING_RATE = 5e-4
WEIGHT_DECAY =  0.1
WARMUP_STEPS = 100
TOTAL_STEPS = 1000 # we want this to be : epochs * (size of dataset / batch_size )
EPOCHS = 100
VOCAB_SIZE = 30522
NUM_LAYERS = 4
BATCH_SIZE = 8
EXPANSION_FACTOR = 2
CLIP_GRADIENT = 1.0
D_MODEL = 512  # Dimensionality of the model's embeddings
D_STATE = 2048  # Dimensionality of the intermediate state in feedforward
D_CONV = 3  # Kernel size for convolutional layers


def setup_optimizer(model, learning_rate, weight_decay, warmup_steps, total_steps):
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    # Linear warmup with cosine decay
    scheduler = LambdaLR(optimizer, lambda step: min((step + 1) / warmup_steps, 0.5 * (1 + math.cos(math.pi * step / total_steps))))

    return optimizer, scheduler


# Initialize the optimizer and scheduler with appropriate parameters
#mamba_optimizer, mamba_scheduler = setup_optimizer(mamba_model, LEARNING_RATE, WEIGHT_DECAY, WARMUP_STEPS, TOTAL_STEPS)


#######################################################################################
# 4. RAG
def extract_text_from_pdf(file_path):
    text = ""
    with fitz.open(file_path) as doc:
        for page in doc:
            text += page.get_text()
    return text


def split_into_chunks(text, chunk_size):
    # Split the text into chunks of chunk_size
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]

def preprocess_text(text, max_length=512):
    # Initialize tokenizer
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # Split the text into smaller chunks to maintain context
    # The chunk size is slightly less than max_length to account for special tokens
    chunk_size = max_length - 50  # Adjust this value based on the model's requirements
    text_chunks = split_into_chunks(text, chunk_size)

    # Process each chunk
    processed_chunks = []
    for chunk in text_chunks:
        tokenized_output = tokenizer(chunk, padding='max_length', truncation=True, max_length=max_length, return_tensors="pt")
        processed_chunk = {
            'input_ids': tokenized_output['input_ids'],
            'attention_mask': tokenized_output['attention_mask']
        }
        processed_chunks.append(processed_chunk)

    return processed_chunks


def create_dataset_from_pdfs(pdf_file_paths):
    dataset = []
    for file_path in pdf_file_paths:
        text = extract_text_from_pdf(file_path)
        processed_text = preprocess_text(text)
        dataset.append(processed_text)
    return dataset

def retrieve_contexts(dataset, query_embedding, top_k=5):
    # Assume dataset is a list of dictionaries with each dictionary containing 'input_ids' and 'attention_mask'
    # for a particular context and that each context has been processed through a DPRContextEncoder to get embeddings

    # Placeholder for storing similarity scores
    similarity_scores = []

    # Iterate over each context in the dataset
    for context in dataset:
        context_input_ids = context['input_ids']
        context_attention_mask = context['attention_mask']

        # Assuming context_encoder is an instance of CustomDPRContextEncoder that's already trained
        # and available in your scope
        context_embedding = context_encoder(context_input_ids, context_attention_mask)

        # Compute similarity (e.g., using dot product)
        similarity = torch.matmul(query_embedding, context_embedding.T)

        similarity_scores.append(similarity.squeeze().item())

    # Sort contexts based on similarity scores and retrieve top_k indices
    top_k_indices = sorted(range(len(similarity_scores)), key=lambda i: similarity_scores[i], reverse=True)[:top_k]

    # Retrieve top_k contexts
    top_contexts = [dataset[i] for i in top_k_indices]

    return top_contexts

def rag_retrieve_and_generate(dataset, query):
    # Instantiate the question encoder
    question_encoder = DPRQuestionEncoder()

    # Encode the query
    encoded_query = question_encoder(query)

    # Retrieve relevant context
    # This involves finding the most similar documents in the dataset
    # For simplicity, this is represented as a function 'retrieve_contexts'
    relevant_contexts = retrieve_contexts(dataset, encoded_query)

    # Language model for generation
    language_model = LanguageModelTransformer()

    # Generate a response based on the retrieved contexts
    # This step may involve further formatting or preprocessing
    response = language_model.generate_response(relevant_contexts)

    return response

# pdfs
pdf_file_paths = [r'C:\Users\robbi\IEEMM\DPO.pdf', 
                  r'C:\Users\robbi\IEEMM\MAMBA.pdf',
                  r'C:\Users\robbi\IEEMM\QLORA.pdf',
                  r'C:\Users\robbi\IEEMM\RAG.pdf',
                  r'C:\Users\robbi\IEEMM\SWITCH_TRANSFORMER.pdf']

rag_dataset = create_dataset_from_pdfs(pdf_file_paths)

class CustomDPRContextEncoder(nn.Module):
    def __init__(self, model_name='bert-base-uncased', embedding_dim=768):
        super(CustomDPRContextEncoder, self).__init__()
        # Transformer-based model, e.g., BERT
        self.bert_model = BertModel.from_pretrained(model_name)
        # Additional layer to produce fixed-size embeddings
        self.embedding_layer = nn.Linear(self.bert_model.config.hidden_size, embedding_dim)

    def forward(self, input_ids, attention_mask=None):
        # Generate outputs from the BERT model
        outputs = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)
        # Use the pooled output for creating embeddings
        pooled_output = outputs.pooler_output
        # Pass through the embedding layer
        context_embeddings = self.embedding_layer(pooled_output)
        return context_embeddings

class DPRQuestionEncoder(nn.Module):
    def __init__(self, model_name='bert-base-uncased', embedding_dim=768):
        super(DPRQuestionEncoder, self).__init__()
        self.bert = BertModel.from_pretrained(model_name)
        self.embedding_layer = nn.Linear(self.bert.config.hidden_size, embedding_dim)

    def forward(self, input_ids, attention_mask, **kwargs):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        embeddings = self.embedding_layer(pooled_output)
        return embeddings

# TransformerRAG class
class TransformerRAG(nn.Module):
    def __init__(self, context_encoder_path, language_model_path, question_encoder_path, vocab_size):
        super(TransformerRAG, self).__init__()
        self.context_encoder = CustomDPRContextEncoder()
        #self.context_encoder = load_model_weights(self.context_encoder, context_encoder_path)
        self.language_model = LanguageModelTransformer(
            vocab_size=vocab_size,
            embed_size=256, 
            num_layers=6, 
            forward_expansion=4, 
            heads=8, 
            dropout=0, 
            max_length=100,  # Set to 512 to match the tokenization max_length
            rank=16
        )        
        self.language_model = load_model_weights(self.language_model, language_model_path)
        self.question_encoder = DPRQuestionEncoder()
        #self.question_encoder = load_model_weights(self.question_encoder, question_encoder_path)
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def forward(self, context_texts, question_input_ids, question_attention_mask, question_text):
        print(f"question_input_ids: {question_input_ids.shape}")
        print(f"question_attention_mask: {question_attention_mask.shape}")
        print(f"context_texts: {len(context_texts)}")  # Updated this line
        print(f"question_text: {question_text}")
        # Convert question_input_ids and question_attention_mask to LongTensor if they are not already
        question_input_ids = question_input_ids.long()
        question_attention_mask = question_attention_mask.long()
        question_embeddings = self.question_encoder(input_ids=question_input_ids, attention_mask=question_attention_mask)
        context_embeddings = []
        for context_text in context_texts:
            tokenized_context = self.tokenizer(context_text, return_tensors="pt", padding=True, truncation=True, max_length=512, add_special_tokens=True)
            context_embedding = self.context_encoder(**tokenized_context)
            context_embeddings.append(context_embedding)

        cos_sim = torch.nn.CosineSimilarity(dim=1)
        similarities = [cos_sim(question_embeddings, context_emb.squeeze(0)) for context_emb in context_embeddings]
        most_relevant_context_idx = torch.argmax(torch.tensor(similarities))

        # Use the provided question_text
        combined_input = question_text + " " + context_texts[most_relevant_context_idx]
        tokenized_combined_input = self.tokenizer(combined_input, return_tensors="pt", padding=True, truncation=True, max_length=512)
        response_logits = self.language_model(**tokenized_combined_input)
        probabilities = F.softmax(response_logits.logits, dim=-1)
        predicted_token_ids = torch.argmax(probabilities, dim=-1)
        predicted_tokens = self.tokenizer.convert_ids_to_tokens(predicted_token_ids[0])
        response = self.tokenizer.convert_tokens_to_string(predicted_tokens)

        return response

###########################################################################
# 5. LanguageModelTransformer

class LORALayer(nn.Module):
    def __init__(self, input_dim, output_dim, rank, alpha=1):
        super(LORALayer, self).__init__()
        self.rank = rank
        self.alpha = alpha

        # Original weight and bias of the linear layer
        self.weight = nn.Parameter(torch.Tensor(output_dim, input_dim))
        #print("self.weight Shape:", self.weight.shape)
        self.bias = nn.Parameter(torch.Tensor(output_dim))

        # LORA specific parameters
        self.A = nn.Parameter(torch.Tensor(input_dim, rank))
        self.B = nn.Parameter(torch.Tensor(rank, output_dim))

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.zeros_(self.bias)
        nn.init.normal_(self.A, 0, 0.02)
        nn.init.normal_(self.B, 0, 0.02)

    def forward(self, x):
        #print("LORALayer Input Shape:", x.shape)
        
        original_size = x.size()
        batch_size, seq_len, _ = x.shape
        x_flattened = x.reshape(-1, original_size[-1])

        # Compute lora_adjustment for each input in the batch
        lora_adjustment = self.alpha * (x_flattened @ self.A) @ self.B
        lora_adjustment = lora_adjustment.reshape(batch_size, seq_len, -1)

        #print("Adjusted lora_adjustment Shape:", lora_adjustment.shape)
        #print("self.weight Shape:", self.weight.shape)
        
        # Apply linear transformation to x_flattened
        x_transformed = nn.functional.linear(x_flattened, self.weight, self.bias)
        x_transformed = x_transformed.reshape(batch_size, seq_len, -1)

        # Add lora_adjustment to the transformed x
        x = x_transformed + lora_adjustment
        #print("LORALayer Output Shape:", x.shape)

        return x

class QLORALayer(nn.Module):
    def __init__(self, input_dim, output_dim, rank, alpha=1, quantization_bits=8):
        super(QLORALayer, self).__init__()
        self.rank = rank
        self.alpha = alpha
        self.quantization_bits = quantization_bits

        # Original weight and bias
        self.weight = nn.Parameter(torch.Tensor(output_dim, input_dim))
        self.bias = nn.Parameter(torch.Tensor(output_dim))

        # QLORA specific parameters
        self.A = nn.Parameter(torch.Tensor(input_dim, rank))
        self.B = nn.Parameter(torch.Tensor(rank, output_dim))

        self.dropout = nn.Dropout(0.1)
        self.layer_norm = nn.LayerNorm(output_dim)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.zeros_(self.bias)
        nn.init.normal_(self.A, 0, 0.02)
        nn.init.normal_(self.B, 0, 0.02)

    def quantize(self, x, num_bits):
        # Implement a simple quantization method
        scale = x.abs().max()
        x_quantized = torch.round(x / scale * (2**num_bits - 1))
        return x_quantized, scale

    def forward(self, x):
        #print("QLORALayer Input Shape:", x.shape)
        original_size = x.size()
        batch_size, seq_len, _ = x.shape
        x_flattened = x.reshape(-1, original_size[-1])

        A_quantized, scale_A = self.quantize(self.A, self.quantization_bits)
        B_quantized, scale_B = self.quantize(self.B, self.quantization_bits)

        # Compute lora_adjustment for each input in the batch
        lora_adjustment = self.alpha * (x_flattened @ (A_quantized / scale_A)) @ (B_quantized / scale_B)
        lora_adjustment = lora_adjustment.reshape(batch_size, seq_len, -1)
        lora_adjustment = self.dropout(lora_adjustment)
        #print("Adjusted lora_adjustment Shape:", lora_adjustment.shape)
        #print("self.weight Shape:", self.weight.shape)

        # Apply linear transformation to x_flattened
        x_transformed = nn.functional.linear(x_flattened, self.weight, self.bias)
        x_transformed = x_transformed.reshape(batch_size, seq_len, -1)

        # Add lora_adjustment to the transformed x
        x = x_transformed + lora_adjustment
        x = self.layer_norm(x)

        #print("QLORALayer Output Shape:", x.shape)

        return x
    
    def update_alpha(self, new_alpha):
        """
        Update the alpha scaling factor.
        """
        self.alpha = new_alpha

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(MultiHeadAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding into self.heads different pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # Einsum does the matrix multiplication for query*keys for each training example
        # with every other training example, then sum it up
        attention = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        if mask is not None:
            attention = attention.masked_fill(mask == 0, float("-1e20"))

        attention = torch.softmax(attention / (self.embed_size ** (1 / 2)), dim=3)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )

        out = self.fc_out(out)
        return out

class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion, rank):
        super(TransformerBlock, self).__init__()
        self.attention = MultiHeadAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            LORALayer(embed_size, forward_expansion * embed_size, rank),
            nn.ReLU(),
            LORALayer(forward_expansion * embed_size, embed_size, rank),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out

class LanguageModelDecoder(nn.Module):
    def __init__(self, vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, max_length, rank):
        super(LanguageModelDecoder, self).__init__()
        self.word_embedding = nn.Embedding(vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        # Adding BatchNorm layers
        self.bn1 = nn.BatchNorm1d(embed_size)
        self.bn2 = nn.BatchNorm1d(embed_size)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(embed_size, heads, dropout, forward_expansion, rank)
                for _ in range(num_layers)
            ]
        )

        # QLORA layers
        self.qlora_feed_forward = nn.Sequential(
            QLORALayer(embed_size, forward_expansion * embed_size, rank),
            nn.ReLU(),
            QLORALayer(forward_expansion * embed_size, embed_size, rank),
        )
        self.use_qlora = False  # Flag to toggle QLORA

        self.fc_out = nn.Linear(embed_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, trg_mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length) #.to(x.device)
        x = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        # Transpose for BatchNorm, apply batch normalization, and then transpose back
        x = x.transpose(1, 2)
        x = self.bn1(x)
        x = x.transpose(1, 2)

        for layer in self.layers:
            x = layer(x, x, x, trg_mask)
            if self.use_qlora:
                x = self.qlora_feed_forward(x)

        # Transpose for BatchNorm, apply batch normalization, and then transpose back
        x = x.transpose(1, 2)
        x = self.bn2(x)
        x = x.transpose(1, 2)

        out = self.fc_out(x)
        #print(f"shape of output of forward method of LanguageModelDecoder: {out.shape} ")

        return out

    def toggle_qlora(self, use_qlora: bool):
        self.use_qlora = use_qlora

class LanguageModelTransformer(nn.Module):
    def __init__(self, vocab_size, embed_size=256, num_layers=6, forward_expansion=4, heads=8, dropout=0, max_length=100, rank=16):
        super(LanguageModelTransformer, self).__init__()

        self.decoder = LanguageModelDecoder(
            vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            max_length,
            rank,
        )

    def forward(self, trg):
        trg_mask = self.make_trg_mask(trg)
        out = self.decoder(trg, trg_mask)
        return out

    def make_trg_mask(self, trg):
        N, trg_len = trg.shape
        trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(
            N, 1, trg_len, trg_len
        ).to(trg.device)

        return trg_mask
    
    # Function to enable or disable QLORA layers (for fine-tuning purposes)
    def toggle_qlora(self, use_qlora: bool):
        self.decoder.toggle_qlora(use_qlora)

    def generate_response(self, input_ids, attention_mask):
        # Assuming you have a forward method that returns logits
        logits = self.forward(input_ids=input_ids, attention_mask=attention_mask)
        
        # Convert logits to probabilities
        probabilities = F.softmax(logits, dim=-1)
        
        # For simplicity, using greedy decoding here. You might want to use beam search or sampling.
        predicted_token_id = torch.argmax(probabilities, dim=-1)
        
        # Convert predicted token ids to tokens
        predicted_tokens = [tokenizer.convert_ids_to_tokens(idx.item()) for idx in predicted_token_id]
        
        # Join tokens to form the response string. This is a very basic way to generate text and might not produce the best results.
        response = tokenizer.convert_tokens_to_string(predicted_tokens)
        
        return response

###########################################################################
# 7. Internal Switch Routing

# Auxiliary loss function
def auxiliary_loss(gate_scores, expert_capacity):
    expert_load = gate_scores.sum(0) / gate_scores.size(0)
    loss_balancing = torch.std(expert_load)
    return loss_balancing

# Routing function
CAPACITY_FACTOR = 1

def route_inputs(expert_indices, gate_scores, num_experts):
    capacity_factor_tensor = torch.tensor([CAPACITY_FACTOR], dtype=torch.float32)
    capacities = (gate_scores.size(0) * capacity_factor_tensor / num_experts).int()
    expert_counts = torch.zeros(num_experts, dtype=torch.int32)
    for idx in range(len(expert_indices)):
        selected_expert = expert_indices[idx]
        if expert_counts[selected_expert] < capacities[0]:
            expert_counts[selected_expert] += 1
        else:
            available_experts = (expert_counts < capacities[0]).nonzero(as_tuple=False).view(-1)
            if len(available_experts) > 0:
                alternative_expert = available_experts[0]
                expert_indices[idx] = alternative_expert
                expert_counts[alternative_expert] += 1
            else:
                print("No available experts to reroute. Handling overflow.")
    return expert_indices

# SwitchGate 
class SwitchGate(nn.Module):
    def __init__(self, input_dim, num_experts):
        super(SwitchGate, self).__init__()
        self.fc1 = nn.Linear(input_dim, input_dim // 2)
        self.fc2 = nn.Linear(input_dim // 2, num_experts)

    def forward(self, x):
        x = F.relu(self.fc1(x.float()))
        gate_scores = F.softmax(self.fc2(x), dim=-1)
        return gate_scores

# SwitchRouter 
class SwitchRouter(nn.Module):
    def __init__(self, input_dim, num_experts, mamba_model_path, context_encoder_path, language_model_path, question_encoder_path, dpo_model_path, vocab_size, embed_size, num_layers, forward_expansion, heads, dropout, max_length, rank):
        super(SwitchRouter, self).__init__()
        self.router = SwitchGate(input_dim, num_experts)
        self.transformer_rag = TransformerRAG(context_encoder_path, 
                                              language_model_path, 
                                              question_encoder_path, 
                                              vocab_size)
        self.transformer_dpo = LanguageModelTransformer(vocab_size, 
                                                        embed_size, 
                                                        num_layers, 
                                                        forward_expansion, 
                                                        heads, dropout, 
                                                        max_length, 
                                                        rank)
        self.transformer_dpo = load_model_weights(self.transformer_dpo, dpo_model_path)
        self.mamba = SimplifiedLanguageModelMAMBA(vocab_size=VOCAB_SIZE, 
                                     num_layers=NUM_LAYERS, 
                                     d_model=D_MODEL, 
                                     d_state=D_STATE, 
                                     d_conv=D_CONV, 
                                     expansion_factor=EXPANSION_FACTOR)
        self.mamba = load_model_weights(self.mamba, mamba_model_path)
        self.experts = nn.ModuleList([self.transformer_rag, self.transformer_dpo, self.mamba])
        self.input_embedding = nn.Linear(512, input_dim)

    def forward(self, x, attention_mask, context_texts, question_text):
        x = self.input_embedding(x.float())
        gate_scores = self.router(x)
        expert_indices = torch.argmax(gate_scores, dim=1)
        expert_indices = route_inputs(expert_indices, gate_scores, len(self.experts))
        final_output = torch.zeros_like(x)
        aux_loss = 0

        for i, expert in enumerate(self.experts):
            mask = expert_indices == i
            if mask.any():
                selected_inputs = x[mask]
                selected_attention_mask = attention_mask[mask]

                if isinstance(expert, TransformerRAG):
                    # Now passing the required arguments to TransformerRAG
                    expert_output = self.transformer_rag(context_texts, selected_inputs, selected_attention_mask, question_text)
                else:
                    # Process as usual for other experts
                    expert_output = expert(selected_inputs, selected_attention_mask)

                final_output[mask] = expert_output

        # Compute auxiliary loss for load balancing
        aux_loss += auxiliary_loss(gate_scores, expert_capacity=torch.tensor([CAPACITY_FACTOR] * len(self.experts)))

        return final_output, aux_loss

###########################################################################
# 8.Training loop
input_dim = 512
num_experts = 3


model = SwitchRouter(input_dim, num_experts, mamba_model_path, context_encoder, language_model, question_encoder, tran_dpo, 30522, 256, 6, 4, 8, 0.1, 100, 16).to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
main_loss_function = torch.nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001)
aux_loss_weight = 0.1

# Ensure train_data is accessible here, with 'queries' and 'contexts' keys
train_data = {
    "queries": [
        # Queries for DPO.pdf
        "What is Direct Preference Optimization (DPO)?",
        "How does Direct Preference Optimization work?",
        "How can I implement Direct Preference Optimization in my organization?",
        "Why does Direct Preference Optimization improve the efficiency of language modelling?",
        # Queries for MAMBA.pdf
        "What is MAMBA?",
        "How does MAMBA function?",
        "How can I build a system based on MAMBA technology?",
        "Why does MAMBA enhance the performance of its application area?",
        # Queries for QLORA.pdf
        "What is QLORA?",
        "How does QLORA operate?",
        "How can I develop a project using QLORA?",
        "Why does QLORA improve the capabilities of its relevant field?",
        # Queries for RAG.pdf
        "What is Retrieval Augmented Generation (RAG)?",
        "How does Retrieval Augmented Generation work?",
        "How can I build a Retrieval Augmented Generation model?",
        "Why does Retrieval Augmented Generation enhance language model performance?",
        # Queries for SWITCH_TRANSFORMER.pdf
        "What is the Switch Transformer model?",
        "How does the Switch Transformer model operate?",
        "How can I construct a Switch Transformer model?",
        "Why does the Switch Transformer model improve language processing tasks?"
    ],
    "contexts": [
        # Contexts from DPO.pdf
        rag_dataset[0],  # Assuming dataset[0] is the processed content of DPO.pdf
        rag_dataset[0],
        rag_dataset[0],
        rag_dataset[0],
        # Contexts from MAMBA.pdf
        rag_dataset[1],  # Assuming dataset[1] is the processed content of MAMBA.pdf
        rag_dataset[1],
        rag_dataset[1],
        rag_dataset[1],
        # Contexts from QLORA.pdf
        rag_dataset[2],  # Assuming dataset[2] is the processed content of QLORA.pdf
        rag_dataset[2],
        rag_dataset[2],
        rag_dataset[2],
        # Contexts from RAG.pdf
        rag_dataset[3],  # Assuming dataset[3] is the processed content of RAG.pdf
        rag_dataset[3],
        rag_dataset[3],
        rag_dataset[3],
        # Contexts from SWITCH_TRANSFORMER.pdf
        rag_dataset[4],  # Assuming dataset[4] is the processed content of SWITCH_TRANSFORMER.pdf
        rag_dataset[4],
        rag_dataset[4],
        rag_dataset[4]
    ]
}
model.train()
num_epochs = 5
for epoch in range(num_epochs):
    total_loss = 0
    for batch_idx, batch in enumerate(train_loader):
        inputs, attention_mask, targets = batch['input_ids'].to(device), batch['attention_mask'].to(device), batch['labels'].to(device)

        batch_size = inputs.size(0)
        start_idx = batch_idx * batch_size
        end_idx = start_idx + batch_size
        current_queries = train_data['queries'][start_idx:end_idx]
        current_contexts = train_data['contexts'][start_idx:end_idx]  # Use the context chunks directly

        # Now passing the actual queries and contexts
        outputs, aux_loss = model(inputs, attention_mask, current_contexts, current_queries)

        main_loss = main_loss_function(outputs.view(-1, outputs.size(-1)), targets.view(-1))
        total_loss = main_loss + aux_loss_weight * aux_loss
        total_loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print(f'Epoch: {epoch+1}, Batch: {batch_idx}, Loss: {total_loss.item()}')

    print(f'End of Epoch {epoch+1}, Average Loss: {total_loss / len(train_loader)}')




KeyboardInterrupt: 

# v5 cpu version of code to check it works

In [6]:
# 0. Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from transformers import BertTokenizer
from torch.utils.data import DataLoader
import fitz  # PyMuPDF
from transformers import BertModel
import math
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

torch.autograd.set_detect_anomaly(True)

#############################################################################
# 1. Preprocessing Data

# Load the dataset
code_dataset = load_dataset("code_search_net", "python")

# Tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Preprocess and tokenize dataset
def preprocess_function(examples):
    tokenized_output = tokenizer(examples['func_code_string'], padding="max_length", truncation=True, max_length=512)
    labels = [row[:-1] + [tokenizer.pad_token_id] for row in tokenized_output["input_ids"]]
    tokenized_output["labels"] = labels
    return tokenized_output

tokenized_dataset = code_dataset.map(preprocess_function, batched=True)
tokenized_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])

# DataLoader
train_dataset = tokenized_dataset["train"]
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"True vocab size: {tokenizer.vocab_size}")

#############################################################################
# 2. Sub-Model Weights
# 1. Transformer with DPO:
tran_dpo = r'C:\Users\robbi\IEEMM\language_model_weights.pth'
# 2. MAMBA:
mamba_model_path = r'C:\Users\robbi\IEEMM\mamba_model_weights.pth'
# 3. Transformer and RAG:
context_encoder = r'C:\Users\robbi\IEEMM\context_encoder.pth'
language_model = r'C:\Users\robbi\IEEMM\language_model.pth'
question_encoder = r'C:\Users\robbi\IEEMM\question_encoder.pth'

# Load model weights function
def load_model_weights(model, model_path):
    #checkpoint = torch.load(model_path, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
    checkpoint = torch.load(model_path, map_location= 'cpu')

    if isinstance(checkpoint, dict):
        # Check for 'state_dict' or 'model_state_dict' keys
        if 'state_dict' in checkpoint:
            model.load_state_dict(checkpoint['state_dict'])
        elif 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            # If no known key is found, try loading it as a raw state dictionary
            try:
                model.load_state_dict(checkpoint)
            except RuntimeError as e:
                raise ValueError(f"Error loading state dict: {e}")
    elif isinstance(checkpoint, nn.Module):
        # If the checkpoint is a model object, assign it directly
        model = checkpoint
    else:
        raise ValueError(f"Unsupported checkpoint format: {type(checkpoint)}")

    model.eval()
    return model
#############################################################################
# 3. MAMBA
# RoPE
class RotaryPositionalEncoding(nn.Module):
    def __init__(self, dim, max_len=5000):
        super().__init__()
        self.dim = dim
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_len).type_as(inv_freq)
        freqs = torch.einsum('n , d -> n d', t, inv_freq)
        self.register_buffer('sin', freqs.sin())
        self.register_buffer('cos', freqs.cos())

    def forward(self, x):
        n, _, device = x.shape[1], self.dim // 2, x.device
        sin, cos = self.sin[:n].to(device), self.cos[:n].to(device)

        # Apply RoPE to even and odd indices separately
        x_even = x[..., :self.dim:2] * cos.unsqueeze(0) + torch.roll(x[..., 1:self.dim:2], shifts=1, dims=-1) * sin.unsqueeze(0)
        x_odd = x[..., 1:self.dim:2] * cos.unsqueeze(0) - torch.roll(x[..., :self.dim:2], shifts=1, dims=-1) * sin.unsqueeze(0)
        return torch.cat((x_even, x_odd), dim=-1)

# SWIGLU
class SwiGLU(nn.Module):
    def __init__(self, dim_in, dim_out):
        super(SwiGLU, self).__init__()
        self.fc1 = nn.Linear(dim_in, dim_out)
        self.fc2 = nn.Linear(dim_in, dim_out)

    def forward(self, x):
        gate = torch.sigmoid(self.fc2(x))
        return self.fc1(x) * gate

class SimplifiedMAMBA(nn.Module):
    # Adjusted to include SwiGLU blocks
    def __init__(self, num_layers, d_model, d_state, d_conv, expansion_factor):
        super().__init__()

        self.num_layers = num_layers
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expansion_factor = expansion_factor
        self.feedforward = nn.Sequential(
            nn.Linear(d_model, d_state),
            nn.GELU(),
            nn.Linear(d_state, d_model)
        )
        self.input_embedding = nn.Linear(d_model, d_model)
        self.convs = nn.Sequential(*[nn.Conv1d(d_model, d_model, kernel_size=d_conv, padding=(d_conv // 2)) for _ in range(num_layers)])
        self.swiglu = SwiGLU(d_model, d_model)
        self.output_projection = nn.Linear(d_model, d_model * expansion_factor)  # Adjusted to match the output of SwiGLU


        self.initialize_weights()

    def initialize_weights(self):
        gain = nn.init.calculate_gain('relu')

        nn.init.orthogonal_(self.input_embedding.weight, gain)
        nn.init.normal_(self.input_embedding.bias, mean=0, std=0.01)

        nn.init.kaiming_uniform_(self.convs[-1].weight, a=math.sqrt(5))
        nn.init.zeros_(self.convs[-1].bias)

        nn.init.xavier_uniform_(self.feedforward[0].weight, gain=nn.init.calculate_gain('relu'))
        nn.init.zeros_(self.feedforward[0].bias)

        nn.init.xavier_uniform_(self.feedforward[2].weight, gain=nn.init.calculate_gain('linear'))
        nn.init.zeros_(self.feedforward[2].bias)

        nn.init.xavier_uniform_(self.output_projection.weight, gain=nn.init.calculate_gain('linear'))
        nn.init.zeros_(self.output_projection.bias)

    def forward(self, inputs, attention_mask=None):
        print("Input shape:", inputs.shape)

        # Apply the attention mask if provided
        if attention_mask is not None:
            inputs = inputs * attention_mask.unsqueeze(-1)

        projected_inputs = self.input_embedding(inputs)
        print("projected_inputs pre-reshape shape:", projected_inputs.shape)

        projected_inputs = projected_inputs.permute(0, 2, 1)
        print("projected_inputs post-reshape shape:", projected_inputs.shape)

        for conv in self.convs:
            projected_inputs = conv(projected_inputs)

        projected_inputs = projected_inputs.permute(0, 2, 1)
        print("projected_inputs post convolution reshape:", projected_inputs.shape)

        projected_inputs = self.swiglu(projected_inputs)
        print("projected_inputs post swiglu shape:", projected_inputs.shape)

        output = self.output_projection(projected_inputs)
        print("output shape:", output.shape)

        return output

class SimplifiedLanguageModelMAMBA(nn.Module):
    # Including rotary positional encodings if required
    def __init__(self, vocab_size, num_layers, d_model, d_state, d_conv, expansion_factor):
        super().__init__()

        self.vocab_size = vocab_size
        self.num_layers = num_layers
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expansion_factor = expansion_factor

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = RotaryPositionalEncoding(d_model)
        self.simplified_mamba = SimplifiedMAMBA(num_layers, d_model, d_state, d_conv, expansion_factor)
        self.output_projection = nn.Linear(d_model*2, vocab_size)

        self.initialize_weights()

    def initialize_weights(self):
        gain = 1.0

        nn.init.orthogonal_(self.embedding.weight, gain)
        nn.init.xavier_uniform_(self.output_projection.weight, gain=gain)

    def forward(self, input_values, attention_mask):
        embedded = self.embedding(input_values) * math.sqrt(self.d_model)
        embedded = self.pos_encoder(embedded)
        simplified_mamba_output = self.simplified_mamba(embedded, attention_mask)
        logits = self.output_projection(simplified_mamba_output)
        return logits


LEARNING_RATE = 5e-4
WEIGHT_DECAY =  0.1
WARMUP_STEPS = 100
TOTAL_STEPS = 1000 # we want this to be : epochs * (size of dataset / batch_size )
EPOCHS = 100
VOCAB_SIZE = 30522
NUM_LAYERS = 4
BATCH_SIZE = 8
EXPANSION_FACTOR = 2
CLIP_GRADIENT = 1.0
D_MODEL = 512  # Dimensionality of the model's embeddings
D_STATE = 2048  # Dimensionality of the intermediate state in feedforward
D_CONV = 3  # Kernel size for convolutional layers


def setup_optimizer(model, learning_rate, weight_decay, warmup_steps, total_steps):
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    # Linear warmup with cosine decay
    scheduler = LambdaLR(optimizer, lambda step: min((step + 1) / warmup_steps, 0.5 * (1 + math.cos(math.pi * step / total_steps))))

    return optimizer, scheduler


# Initialize the optimizer and scheduler with appropriate parameters
#mamba_optimizer, mamba_scheduler = setup_optimizer(mamba_model, LEARNING_RATE, WEIGHT_DECAY, WARMUP_STEPS, TOTAL_STEPS)


#######################################################################################
# 4. RAG
def extract_text_from_pdf(file_path):
    text = ""
    with fitz.open(file_path) as doc:
        for page in doc:
            text += page.get_text()
    return text


def split_into_chunks(text, chunk_size):
    # Split the text into chunks of chunk_size
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]

def preprocess_text(text, max_length=512):
    # Initialize tokenizer
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # Split the text into smaller chunks to maintain context
    # The chunk size is slightly less than max_length to account for special tokens
    chunk_size = max_length - 50  # Adjust this value based on the model's requirements
    text_chunks = split_into_chunks(text, chunk_size)

    # Process each chunk
    processed_chunks = []
    for chunk in text_chunks:
        tokenized_output = tokenizer(chunk, padding='max_length', truncation=True, max_length=max_length, return_tensors="pt")
        processed_chunk = {
            'input_ids': tokenized_output['input_ids'],
            'attention_mask': tokenized_output['attention_mask']
        }
        processed_chunks.append(processed_chunk)

    return processed_chunks


def create_dataset_from_pdfs(pdf_file_paths):
    dataset = []
    for file_path in pdf_file_paths:
        text = extract_text_from_pdf(file_path)
        processed_text = preprocess_text(text)
        dataset.append(processed_text)
    return dataset

def retrieve_contexts(dataset, query_embedding, top_k=5):
    # Assume dataset is a list of dictionaries with each dictionary containing 'input_ids' and 'attention_mask'
    # for a particular context and that each context has been processed through a DPRContextEncoder to get embeddings

    # Placeholder for storing similarity scores
    similarity_scores = []

    # Iterate over each context in the dataset
    for context in dataset:
        context_input_ids = context['input_ids']
        context_attention_mask = context['attention_mask']

        # Assuming context_encoder is an instance of CustomDPRContextEncoder that's already trained
        # and available in your scope
        context_embedding = context_encoder(context_input_ids, context_attention_mask)

        # Compute similarity (e.g., using dot product)
        similarity = torch.matmul(query_embedding, context_embedding.T)

        similarity_scores.append(similarity.squeeze().item())

    # Sort contexts based on similarity scores and retrieve top_k indices
    top_k_indices = sorted(range(len(similarity_scores)), key=lambda i: similarity_scores[i], reverse=True)[:top_k]

    # Retrieve top_k contexts
    top_contexts = [dataset[i] for i in top_k_indices]

    return top_contexts

def rag_retrieve_and_generate(dataset, query):
    # Instantiate the question encoder
    question_encoder = DPRQuestionEncoder()

    # Encode the query
    encoded_query = question_encoder(query)

    # Retrieve relevant context
    # This involves finding the most similar documents in the dataset
    # For simplicity, this is represented as a function 'retrieve_contexts'
    relevant_contexts = retrieve_contexts(dataset, encoded_query)

    # Language model for generation
    language_model = LanguageModelTransformer()

    # Generate a response based on the retrieved contexts
    # This step may involve further formatting or preprocessing
    response = language_model.generate_response(relevant_contexts)

    return response

# pdfs
pdf_file_paths = [r'C:\Users\robbi\IEEMM\DPO.pdf', 
                  r'C:\Users\robbi\IEEMM\MAMBA.pdf',
                  r'C:\Users\robbi\IEEMM\QLORA.pdf',
                  r'C:\Users\robbi\IEEMM\RAG.pdf',
                  r'C:\Users\robbi\IEEMM\SWITCH_TRANSFORMER.pdf']

rag_dataset = create_dataset_from_pdfs(pdf_file_paths)

class CustomDPRContextEncoder(nn.Module):
    def __init__(self, model_name='bert-base-uncased', embedding_dim=768):
        super(CustomDPRContextEncoder, self).__init__()
        # Transformer-based model, e.g., BERT
        self.bert_model = BertModel.from_pretrained(model_name)
        # Additional layer to produce fixed-size embeddings
        self.embedding_layer = nn.Linear(self.bert_model.config.hidden_size, embedding_dim)

    def forward(self, input_ids, attention_mask=None):
        # Generate outputs from the BERT model
        outputs = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)
        # Use the pooled output for creating embeddings
        pooled_output = outputs.pooler_output
        # Pass through the embedding layer
        context_embeddings = self.embedding_layer(pooled_output)
        return context_embeddings

class DPRQuestionEncoder(nn.Module):
    def __init__(self, model_name='bert-base-uncased', embedding_dim=768):
        super(DPRQuestionEncoder, self).__init__()
        self.bert = BertModel.from_pretrained(model_name)
        self.embedding_layer = nn.Linear(self.bert.config.hidden_size, embedding_dim)

    def forward(self, input_ids, attention_mask, **kwargs):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        embeddings = self.embedding_layer(pooled_output)
        return embeddings

# TransformerRAG class
class TransformerRAG(nn.Module):
    def __init__(self, context_encoder_path, language_model_path, question_encoder_path, vocab_size):
        super(TransformerRAG, self).__init__()
        self.context_encoder = CustomDPRContextEncoder()
        self.language_model = LanguageModelTransformer(
            vocab_size=vocab_size,
            embed_size=256, 
            num_layers=6, 
            forward_expansion=4, 
            heads=8, 
            dropout=0, 
            max_length=100,  # Set to 512 to match the tokenization max_length
            rank=16
        )
        self.language_model = load_model_weights(self.language_model, language_model_path)
        self.question_encoder = DPRQuestionEncoder()
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def forward(self, context_texts, question_input_ids, question_attention_mask, question_text):
        if question_input_ids.max() >= self.tokenizer.vocab_size:
            raise ValueError("question_input_ids contain ID(s) beyond the tokenizer's vocabulary size")

        # Process each context_text
        aggregated_context_embeddings = []
        for context_list in context_texts:
            if not all(isinstance(context, dict) for context in context_list):
                raise TypeError("Each item in context_texts must be a list of tokenized context dictionaries")

            # Create a tensor of zeros with the correct shape
            aggregated_context_embedding = torch.zeros(self.context_encoder.bert_model.config.hidden_size)
            for context in context_list:
                context_input_ids = context['input_ids']
                context_attention_mask = context['attention_mask']
                context_embedding = self.context_encoder(context_input_ids, context_attention_mask)
                aggregated_context_embedding += context_embedding.mean(dim=0)

            aggregated_context_embeddings.append(aggregated_context_embedding / len(context_list))


        print(f"question_input_ids: {question_input_ids.shape}")
        print(f"question_attention_mask: {question_attention_mask.shape}")
        print(f"context_texts: {len(context_texts)}")
        print(f"question_text: {question_text}")

        question_input_ids = question_input_ids.long()
        question_attention_mask = question_attention_mask.long()
        question_embeddings = self.question_encoder(input_ids=question_input_ids, attention_mask=question_attention_mask)

        cos_sim = torch.nn.CosineSimilarity(dim=1)
        similarities = [cos_sim(question_embeddings, context_emb.squeeze(0)) for context_emb in aggregated_context_embeddings]
        most_relevant_context_idx = torch.argmax(torch.tensor(similarities))

        combined_input = question_text + " " + context_texts[most_relevant_context_idx]
        tokenized_combined_input = self.tokenizer(combined_input, return_tensors="pt", padding=True, truncation=True, max_length=512)
        response_logits = self.language_model(**tokenized_combined_input)
        probabilities = F.softmax(response_logits.logits, dim=-1)
        predicted_token_ids = torch.argmax(probabilities, dim=-1)
        predicted_tokens = self.tokenizer.convert_ids_to_tokens(predicted_token_ids[0])
        response = self.tokenizer.convert_tokens_to_string(predicted_tokens)

        return response
###########################################################################
# 5. LanguageModelTransformer

class LORALayer(nn.Module):
    def __init__(self, input_dim, output_dim, rank, alpha=1):
        super(LORALayer, self).__init__()
        self.rank = rank
        self.alpha = alpha

        # Original weight and bias of the linear layer
        self.weight = nn.Parameter(torch.Tensor(output_dim, input_dim))
        #print("self.weight Shape:", self.weight.shape)
        self.bias = nn.Parameter(torch.Tensor(output_dim))

        # LORA specific parameters
        self.A = nn.Parameter(torch.Tensor(input_dim, rank))
        self.B = nn.Parameter(torch.Tensor(rank, output_dim))

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.zeros_(self.bias)
        nn.init.normal_(self.A, 0, 0.02)
        nn.init.normal_(self.B, 0, 0.02)

    def forward(self, x):
        #print("LORALayer Input Shape:", x.shape)
        
        original_size = x.size()
        batch_size, seq_len, _ = x.shape
        x_flattened = x.reshape(-1, original_size[-1])

        # Compute lora_adjustment for each input in the batch
        lora_adjustment = self.alpha * (x_flattened @ self.A) @ self.B
        lora_adjustment = lora_adjustment.reshape(batch_size, seq_len, -1)

        #print("Adjusted lora_adjustment Shape:", lora_adjustment.shape)
        #print("self.weight Shape:", self.weight.shape)
        
        # Apply linear transformation to x_flattened
        x_transformed = nn.functional.linear(x_flattened, self.weight, self.bias)
        x_transformed = x_transformed.reshape(batch_size, seq_len, -1)

        # Add lora_adjustment to the transformed x
        x = x_transformed + lora_adjustment
        #print("LORALayer Output Shape:", x.shape)

        return x

class QLORALayer(nn.Module):
    def __init__(self, input_dim, output_dim, rank, alpha=1, quantization_bits=8):
        super(QLORALayer, self).__init__()
        self.rank = rank
        self.alpha = alpha
        self.quantization_bits = quantization_bits

        # Original weight and bias
        self.weight = nn.Parameter(torch.Tensor(output_dim, input_dim))
        self.bias = nn.Parameter(torch.Tensor(output_dim))

        # QLORA specific parameters
        self.A = nn.Parameter(torch.Tensor(input_dim, rank))
        self.B = nn.Parameter(torch.Tensor(rank, output_dim))

        self.dropout = nn.Dropout(0.1)
        self.layer_norm = nn.LayerNorm(output_dim)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.zeros_(self.bias)
        nn.init.normal_(self.A, 0, 0.02)
        nn.init.normal_(self.B, 0, 0.02)

    def quantize(self, x, num_bits):
        # Implement a simple quantization method
        scale = x.abs().max()
        x_quantized = torch.round(x / scale * (2**num_bits - 1))
        return x_quantized, scale

    def forward(self, x):
        #print("QLORALayer Input Shape:", x.shape)
        original_size = x.size()
        batch_size, seq_len, _ = x.shape
        x_flattened = x.reshape(-1, original_size[-1])

        A_quantized, scale_A = self.quantize(self.A, self.quantization_bits)
        B_quantized, scale_B = self.quantize(self.B, self.quantization_bits)

        # Compute lora_adjustment for each input in the batch
        lora_adjustment = self.alpha * (x_flattened @ (A_quantized / scale_A)) @ (B_quantized / scale_B)
        lora_adjustment = lora_adjustment.reshape(batch_size, seq_len, -1)
        lora_adjustment = self.dropout(lora_adjustment)
        #print("Adjusted lora_adjustment Shape:", lora_adjustment.shape)
        #print("self.weight Shape:", self.weight.shape)

        # Apply linear transformation to x_flattened
        x_transformed = nn.functional.linear(x_flattened, self.weight, self.bias)
        x_transformed = x_transformed.reshape(batch_size, seq_len, -1)

        # Add lora_adjustment to the transformed x
        x = x_transformed + lora_adjustment
        x = self.layer_norm(x)

        #print("QLORALayer Output Shape:", x.shape)

        return x
    
    def update_alpha(self, new_alpha):
        """
        Update the alpha scaling factor.
        """
        self.alpha = new_alpha

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(MultiHeadAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding into self.heads different pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # Einsum does the matrix multiplication for query*keys for each training example
        # with every other training example, then sum it up
        attention = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        if mask is not None:
            attention = attention.masked_fill(mask == 0, float("-1e20"))

        attention = torch.softmax(attention / (self.embed_size ** (1 / 2)), dim=3)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )

        out = self.fc_out(out)
        return out

class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion, rank):
        super(TransformerBlock, self).__init__()
        self.attention = MultiHeadAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            LORALayer(embed_size, forward_expansion * embed_size, rank),
            nn.ReLU(),
            LORALayer(forward_expansion * embed_size, embed_size, rank),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out

class LanguageModelDecoder(nn.Module):
    def __init__(self, vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, max_length, rank):
        super(LanguageModelDecoder, self).__init__()
        self.word_embedding = nn.Embedding(vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        # Adding BatchNorm layers
        self.bn1 = nn.BatchNorm1d(embed_size)
        self.bn2 = nn.BatchNorm1d(embed_size)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(embed_size, heads, dropout, forward_expansion, rank)
                for _ in range(num_layers)
            ]
        )

        # QLORA layers
        self.qlora_feed_forward = nn.Sequential(
            QLORALayer(embed_size, forward_expansion * embed_size, rank),
            nn.ReLU(),
            QLORALayer(forward_expansion * embed_size, embed_size, rank),
        )
        self.use_qlora = False  # Flag to toggle QLORA

        self.fc_out = nn.Linear(embed_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, trg_mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length) #.to(x.device)
        x = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        # Transpose for BatchNorm, apply batch normalization, and then transpose back
        x = x.transpose(1, 2)
        x = self.bn1(x)
        x = x.transpose(1, 2)

        for layer in self.layers:
            x = layer(x, x, x, trg_mask)
            if self.use_qlora:
                x = self.qlora_feed_forward(x)

        # Transpose for BatchNorm, apply batch normalization, and then transpose back
        x = x.transpose(1, 2)
        x = self.bn2(x)
        x = x.transpose(1, 2)

        out = self.fc_out(x)
        #print(f"shape of output of forward method of LanguageModelDecoder: {out.shape} ")

        return out

    def toggle_qlora(self, use_qlora: bool):
        self.use_qlora = use_qlora

class LanguageModelTransformer(nn.Module):
    def __init__(self, vocab_size, embed_size=256, num_layers=6, forward_expansion=4, heads=8, dropout=0, max_length=100, rank=16):
        super(LanguageModelTransformer, self).__init__()

        self.decoder = LanguageModelDecoder(
            vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            max_length,
            rank,
        )

    def forward(self, trg):
        trg_mask = self.make_trg_mask(trg)
        out = self.decoder(trg, trg_mask)
        return out

    def make_trg_mask(self, trg):
        N, trg_len = trg.shape
        trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(
            N, 1, trg_len, trg_len
        ).to(trg.device)

        return trg_mask
    
    # Function to enable or disable QLORA layers (for fine-tuning purposes)
    def toggle_qlora(self, use_qlora: bool):
        self.decoder.toggle_qlora(use_qlora)

    def generate_response(self, input_ids, attention_mask):
        # Assuming you have a forward method that returns logits
        logits = self.forward(input_ids=input_ids, attention_mask=attention_mask)
        
        # Convert logits to probabilities
        probabilities = F.softmax(logits, dim=-1)
        
        # For simplicity, using greedy decoding here. You might want to use beam search or sampling.
        predicted_token_id = torch.argmax(probabilities, dim=-1)
        
        # Convert predicted token ids to tokens
        predicted_tokens = [tokenizer.convert_ids_to_tokens(idx.item()) for idx in predicted_token_id]
        
        # Join tokens to form the response string. This is a very basic way to generate text and might not produce the best results.
        response = tokenizer.convert_tokens_to_string(predicted_tokens)
        
        return response

###########################################################################
# 7. Internal Switch Routing

# Auxiliary loss function
def auxiliary_loss(gate_scores, expert_capacity):
    expert_load = gate_scores.sum(0) / gate_scores.size(0)
    loss_balancing = torch.std(expert_load)
    return loss_balancing

# Routing function
CAPACITY_FACTOR = 1

def route_inputs(expert_indices, gate_scores, num_experts):
    capacity_factor_tensor = torch.tensor([CAPACITY_FACTOR], dtype=torch.float32)
    capacities = (gate_scores.size(0) * capacity_factor_tensor / num_experts).int()
    expert_counts = torch.zeros(num_experts, dtype=torch.int32)
    for idx in range(len(expert_indices)):
        selected_expert = expert_indices[idx]
        if expert_counts[selected_expert] < capacities[0]:
            expert_counts[selected_expert] += 1
        else:
            available_experts = (expert_counts < capacities[0]).nonzero(as_tuple=False).view(-1)
            if len(available_experts) > 0:
                alternative_expert = available_experts[0]
                expert_indices[idx] = alternative_expert
                expert_counts[alternative_expert] += 1
            else:
                print("No available experts to reroute. Handling overflow.")
    return expert_indices

# SwitchGate 
class SwitchGate(nn.Module):
    def __init__(self, input_dim, num_experts):
        super(SwitchGate, self).__init__()
        self.fc1 = nn.Linear(input_dim, input_dim // 2)
        self.fc2 = nn.Linear(input_dim // 2, num_experts)

    def forward(self, x):
        x = F.relu(self.fc1(x.float()))
        gate_scores = F.softmax(self.fc2(x), dim=-1)
        return gate_scores

# SwitchRouter 
class SwitchRouter(nn.Module):
    def __init__(self, input_dim, num_experts, mamba_model_path, context_encoder_path, language_model_path, question_encoder_path, dpo_model_path, vocab_size, embed_size, num_layers, forward_expansion, heads, dropout, max_length, rank):
        super(SwitchRouter, self).__init__()
        self.router = SwitchGate(input_dim, num_experts)
        self.transformer_rag = TransformerRAG(context_encoder_path, 
                                              language_model_path, 
                                              question_encoder_path, 
                                              vocab_size)
        self.transformer_dpo = LanguageModelTransformer(vocab_size, 
                                                        embed_size, 
                                                        num_layers, 
                                                        forward_expansion, 
                                                        heads, dropout, 
                                                        max_length, 
                                                        rank)
        self.transformer_dpo = load_model_weights(self.transformer_dpo, dpo_model_path)
        self.mamba = SimplifiedLanguageModelMAMBA(vocab_size=VOCAB_SIZE, 
                                     num_layers=NUM_LAYERS, 
                                     d_model=D_MODEL, 
                                     d_state=D_STATE, 
                                     d_conv=D_CONV, 
                                     expansion_factor=EXPANSION_FACTOR)
        self.mamba = load_model_weights(self.mamba, mamba_model_path)
        self.experts = nn.ModuleList([self.transformer_rag, self.transformer_dpo, self.mamba])
        self.input_embedding = nn.Linear(512, input_dim)

    def forward(self, x, attention_mask, context_texts, question_text):
        x = self.input_embedding(x.float())
        gate_scores = self.router(x)
        expert_indices = torch.argmax(gate_scores, dim=1)
        expert_indices = route_inputs(expert_indices, gate_scores, len(self.experts))
        final_output = torch.zeros_like(x)
        aux_loss = 0

        for i, expert in enumerate(self.experts):
            mask = expert_indices == i
            if mask.any():
                selected_inputs = x[mask]
                selected_attention_mask = attention_mask[mask]

                if isinstance(expert, TransformerRAG):
                    # Now passing the required arguments to TransformerRAG
                    expert_output = self.transformer_rag(context_texts, selected_inputs, selected_attention_mask, question_text)
                else:
                    # Process as usual for other experts
                    expert_output = expert(selected_inputs, selected_attention_mask)

                final_output[mask] = expert_output

        # Compute auxiliary loss for load balancing
        aux_loss += auxiliary_loss(gate_scores, expert_capacity=torch.tensor([CAPACITY_FACTOR] * len(self.experts)))

        return final_output, aux_loss

###########################################################################
# 8.Training loop
input_dim = 512
num_experts = 3


model = SwitchRouter(input_dim, num_experts, mamba_model_path, context_encoder, language_model, question_encoder, tran_dpo, 30522, 256, 6, 4, 8, 0.1, 100, 16) #.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
main_loss_function = torch.nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001)
aux_loss_weight = 0.1
vocab_size = tokenizer.vocab_size
model_embedding_size = model.transformer_dpo.decoder.word_embedding.num_embeddings
assert tokenizer.vocab_size == VOCAB_SIZE, f"Tokenizer vocab size ({tokenizer.vocab_size}) doesn't match expected vocab size ({VOCAB_SIZE})"

# Ensure train_data is accessible here, with 'queries' and 'contexts' keys
train_data = {
    "queries": [
        # Queries for DPO.pdf
        "What is Direct Preference Optimization (DPO)?",
        "How does Direct Preference Optimization work?",
        "How can I implement Direct Preference Optimization in my organization?",
        "Why does Direct Preference Optimization improve the efficiency of language modelling?",
        # Queries for MAMBA.pdf
        "What is MAMBA?",
        "How does MAMBA function?",
        "How can I build a system based on MAMBA technology?",
        "Why does MAMBA enhance the performance of its application area?",
        # Queries for QLORA.pdf
        "What is QLORA?",
        "How does QLORA operate?",
        "How can I develop a project using QLORA?",
        "Why does QLORA improve the capabilities of its relevant field?",
        # Queries for RAG.pdf
        "What is Retrieval Augmented Generation (RAG)?",
        "How does Retrieval Augmented Generation work?",
        "How can I build a Retrieval Augmented Generation model?",
        "Why does Retrieval Augmented Generation enhance language model performance?",
        # Queries for SWITCH_TRANSFORMER.pdf
        "What is the Switch Transformer model?",
        "How does the Switch Transformer model operate?",
        "How can I construct a Switch Transformer model?",
        "Why does the Switch Transformer model improve language processing tasks?"
    ],
    "contexts": [
        # Contexts from DPO.pdf
        rag_dataset[0],  # Assuming dataset[0] is the processed content of DPO.pdf
        rag_dataset[0],
        rag_dataset[0],
        rag_dataset[0],
        # Contexts from MAMBA.pdf
        rag_dataset[1],  # Assuming dataset[1] is the processed content of MAMBA.pdf
        rag_dataset[1],
        rag_dataset[1],
        rag_dataset[1],
        # Contexts from QLORA.pdf
        rag_dataset[2],  # Assuming dataset[2] is the processed content of QLORA.pdf
        rag_dataset[2],
        rag_dataset[2],
        rag_dataset[2],
        # Contexts from RAG.pdf
        rag_dataset[3],  # Assuming dataset[3] is the processed content of RAG.pdf
        rag_dataset[3],
        rag_dataset[3],
        rag_dataset[3],
        # Contexts from SWITCH_TRANSFORMER.pdf
        rag_dataset[4],  # Assuming dataset[4] is the processed content of SWITCH_TRANSFORMER.pdf
        rag_dataset[4],
        rag_dataset[4],
        rag_dataset[4]
    ]
}
# Before training, ensure the embedding size of each component matches the tokenizer vocab size
assert model.transformer_dpo.decoder.word_embedding.num_embeddings == tokenizer.vocab_size, \
    "Transformer DPO's embedding size does not match tokenizer vocab size"

assert model.transformer_rag.context_encoder.bert_model.config.vocab_size == tokenizer.vocab_size, \
    "RAG context encoder vocab size does not match tokenizer vocab size"

assert model.transformer_rag.question_encoder.bert.config.vocab_size == tokenizer.vocab_size, \
    "RAG question encoder vocab size does not match tokenizer vocab size"


def count_total_parameters(models):
    total_params = 0
    for model in models:
        model_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total_params += model_params
        print(f"Parameters in {model.__class__.__name__}: {model_params}")
    return total_params

# Usage
models = [model.transformer_rag, model.transformer_dpo, model.mamba, model]
total_params = count_total_parameters(models)
print(f"Total trainable parameters across all models: {total_params}")

# Start training
model.train()
num_epochs = 5
for epoch in range(num_epochs):
    total_loss = 0
    for batch_idx, batch in enumerate(train_loader):
        inputs, attention_mask, targets = batch['input_ids'], batch['attention_mask'], batch['labels']

        # Ensure inputs do not exceed the tokenizer's vocabulary size
        if inputs.max() >= tokenizer.vocab_size:
            raise ValueError("Input IDs exceed tokenizer's vocabulary size")

        batch_size = inputs.size(0)
        start_idx = batch_idx * batch_size
        end_idx = start_idx + batch_size

        # Ensure batch indices are within the range of train_data
        assert end_idx <= len(train_data['queries']), "Batch index exceeds size of 'queries' data"
        assert end_idx <= len(train_data['contexts']), "Batch index exceeds size of 'contexts' data"

        # Debug prints for shapes and value ranges
        print(f"Batch index: {batch_idx}")
        print(f"Input IDs shape: {inputs.shape}, Max ID: {inputs.max()}")
        print(f"Attention Mask shape: {attention_mask.shape}")
        print(f"Targets shape: {targets.shape}, Max ID: {targets.max()}")

        current_queries = train_data['queries'][start_idx:end_idx]
        current_contexts = train_data['contexts'][start_idx:end_idx]

        # Debug prints for current queries and contexts
        for i, (q, context_list) in enumerate(zip(current_queries, current_contexts)):
            print(f"Processing query-context pair {i}:")
            print(f"Query: {q}")
            #for c in context_list:  # Assuming context_list is a list of dictionaries
                #print(f"Context input IDs shape: {torch.tensor(c['input_ids']).shape}")
                #print(f"Context attention mask shape: {torch.tensor(c['attention_mask']).shape}")

        # Call to the model forward function
        outputs, aux_loss = model(inputs, attention_mask, current_contexts, current_queries)

        # Calculate loss
        main_loss = main_loss_function(outputs.view(-1, outputs.size(-1)), targets.view(-1))
        total_loss = main_loss + aux_loss_weight * aux_loss
        total_loss.backward()
        optimizer.step()
        optimizer.zero_grad()  

        if batch_idx % 100 == 0:
            print(f'Epoch: {epoch+1}, Batch: {batch_idx}, Loss: {total_loss.item()}')

    print(f'End of Epoch {epoch+1}, Average Loss: {total_loss / len(train_loader)}')

True vocab size: 30522


ValueError: text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples).

# v6

In [1]:
# 0. Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from transformers import BertTokenizer
from torch.utils.data import DataLoader
import fitz  # PyMuPDF
from transformers import BertModel
import math
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

torch.autograd.set_detect_anomaly(True)

#############################################################################
# 1. Preprocessing Data

# Load the dataset
code_dataset = load_dataset("code_search_net", "python")

# Tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Preprocess and tokenize dataset
def preprocess_function(examples):
    tokenized_output = tokenizer(examples['func_code_string'], padding="max_length", truncation=True, max_length=512)
    labels = [row[:-1] + [tokenizer.pad_token_id] for row in tokenized_output["input_ids"]]
    tokenized_output["labels"] = labels
    return tokenized_output

tokenized_dataset = code_dataset.map(preprocess_function, batched=True)
tokenized_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])

# DataLoader
train_dataset = tokenized_dataset["train"]
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"True vocab size: {tokenizer.vocab_size}")

#############################################################################
# 2. Sub-Model Weights
# 1. Transformer with DPO:
tran_dpo = r'C:\Users\robbi\IEEMM\language_model_weights.pth'
# 2. MAMBA:
mamba_model_path = r'C:\Users\robbi\IEEMM\mamba_model_weights.pth'
# 3. Transformer and RAG:
context_encoder = r'C:\Users\robbi\IEEMM\context_encoder.pth'
language_model = r'C:\Users\robbi\IEEMM\language_model.pth'
question_encoder = r'C:\Users\robbi\IEEMM\question_encoder.pth'

# Load model weights function
def load_model_weights(model, model_path):
    #checkpoint = torch.load(model_path, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
    checkpoint = torch.load(model_path, map_location=device)

    if isinstance(checkpoint, dict):
        # Check for 'state_dict' or 'model_state_dict' keys
        if 'state_dict' in checkpoint:
            model.load_state_dict(checkpoint['state_dict'])
        elif 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            # If no known key is found, try loading it as a raw state dictionary
            try:
                model.load_state_dict(checkpoint)
            except RuntimeError as e:
                raise ValueError(f"Error loading state dict: {e}")
    elif isinstance(checkpoint, nn.Module):
        # If the checkpoint is a model object, assign it directly
        model = checkpoint
    else:
        raise ValueError(f"Unsupported checkpoint format: {type(checkpoint)}")

    model.eval()
    return model
#############################################################################
# 3. MAMBA
# RoPE
class RotaryPositionalEncoding(nn.Module):
    def __init__(self, dim, max_len=5000):
        super().__init__()
        self.dim = dim
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_len).type_as(inv_freq)
        freqs = torch.einsum('n , d -> n d', t, inv_freq)
        self.register_buffer('sin', freqs.sin())
        self.register_buffer('cos', freqs.cos())

    def forward(self, x):
        n, _, device = x.shape[1], self.dim // 2, x.device
        sin, cos = self.sin[:n].to(device), self.cos[:n].to(device)

        # Apply RoPE to even and odd indices separately
        x_even = x[..., :self.dim:2] * cos.unsqueeze(0) + torch.roll(x[..., 1:self.dim:2], shifts=1, dims=-1) * sin.unsqueeze(0)
        x_odd = x[..., 1:self.dim:2] * cos.unsqueeze(0) - torch.roll(x[..., :self.dim:2], shifts=1, dims=-1) * sin.unsqueeze(0)
        return torch.cat((x_even, x_odd), dim=-1)

# SWIGLU
class SwiGLU(nn.Module):
    def __init__(self, dim_in, dim_out):
        super(SwiGLU, self).__init__()
        self.fc1 = nn.Linear(dim_in, dim_out)
        self.fc2 = nn.Linear(dim_in, dim_out)

    def forward(self, x):
        gate = torch.sigmoid(self.fc2(x))
        return self.fc1(x) * gate

class SimplifiedMAMBA(nn.Module):
    # Adjusted to include SwiGLU blocks
    def __init__(self, num_layers, d_model, d_state, d_conv, expansion_factor):
        super().__init__()

        self.num_layers = num_layers
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expansion_factor = expansion_factor
        self.feedforward = nn.Sequential(
            nn.Linear(d_model, d_state),
            nn.GELU(),
            nn.Linear(d_state, d_model)
        )
        self.input_embedding = nn.Linear(d_model, d_model)
        self.convs = nn.Sequential(*[nn.Conv1d(d_model, d_model, kernel_size=d_conv, padding=(d_conv // 2)) for _ in range(num_layers)])
        self.swiglu = SwiGLU(d_model, d_model)
        self.output_projection = nn.Linear(d_model, d_model * expansion_factor)  # Adjusted to match the output of SwiGLU


        self.initialize_weights()

    def initialize_weights(self):
        gain = nn.init.calculate_gain('relu')

        nn.init.orthogonal_(self.input_embedding.weight, gain)
        nn.init.normal_(self.input_embedding.bias, mean=0, std=0.01)

        nn.init.kaiming_uniform_(self.convs[-1].weight, a=math.sqrt(5))
        nn.init.zeros_(self.convs[-1].bias)

        nn.init.xavier_uniform_(self.feedforward[0].weight, gain=nn.init.calculate_gain('relu'))
        nn.init.zeros_(self.feedforward[0].bias)

        nn.init.xavier_uniform_(self.feedforward[2].weight, gain=nn.init.calculate_gain('linear'))
        nn.init.zeros_(self.feedforward[2].bias)

        nn.init.xavier_uniform_(self.output_projection.weight, gain=nn.init.calculate_gain('linear'))
        nn.init.zeros_(self.output_projection.bias)

    def forward(self, inputs, attention_mask=None):
        print("Input shape:", inputs.shape)

        # Apply the attention mask if provided
        if attention_mask is not None:
            inputs = inputs * attention_mask.unsqueeze(-1)

        projected_inputs = self.input_embedding(inputs)
        print("projected_inputs pre-reshape shape:", projected_inputs.shape)

        projected_inputs = projected_inputs.permute(0, 2, 1)
        print("projected_inputs post-reshape shape:", projected_inputs.shape)

        for conv in self.convs:
            projected_inputs = conv(projected_inputs)

        projected_inputs = projected_inputs.permute(0, 2, 1)
        print("projected_inputs post convolution reshape:", projected_inputs.shape)

        projected_inputs = self.swiglu(projected_inputs)
        print("projected_inputs post swiglu shape:", projected_inputs.shape)

        output = self.output_projection(projected_inputs)
        print("output shape:", output.shape)

        return output

class SimplifiedLanguageModelMAMBA(nn.Module):
    # Including rotary positional encodings if required
    def __init__(self, vocab_size, num_layers, d_model, d_state, d_conv, expansion_factor):
        super().__init__()

        self.vocab_size = vocab_size
        self.num_layers = num_layers
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expansion_factor = expansion_factor

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = RotaryPositionalEncoding(d_model)
        self.simplified_mamba = SimplifiedMAMBA(num_layers, d_model, d_state, d_conv, expansion_factor)
        self.output_projection = nn.Linear(d_model*2, vocab_size)

        self.initialize_weights()

    def initialize_weights(self):
        gain = 1.0

        nn.init.orthogonal_(self.embedding.weight, gain)
        nn.init.xavier_uniform_(self.output_projection.weight, gain=gain)

    def forward(self, input_values, attention_mask):
        embedded = self.embedding(input_values) * math.sqrt(self.d_model)
        embedded = self.pos_encoder(embedded)
        simplified_mamba_output = self.simplified_mamba(embedded, attention_mask)
        logits = self.output_projection(simplified_mamba_output)
        return logits


LEARNING_RATE = 5e-4
WEIGHT_DECAY =  0.1
WARMUP_STEPS = 100
TOTAL_STEPS = 1000 # we want this to be : epochs * (size of dataset / batch_size )
EPOCHS = 100
VOCAB_SIZE = 30522
NUM_LAYERS = 4
BATCH_SIZE = 8
EXPANSION_FACTOR = 2
CLIP_GRADIENT = 1.0
D_MODEL = 512  # Dimensionality of the model's embeddings
D_STATE = 2048  # Dimensionality of the intermediate state in feedforward
D_CONV = 3  # Kernel size for convolutional layers


def setup_optimizer(model, learning_rate, weight_decay, warmup_steps, total_steps):
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    # Linear warmup with cosine decay
    scheduler = LambdaLR(optimizer, lambda step: min((step + 1) / warmup_steps, 0.5 * (1 + math.cos(math.pi * step / total_steps))))

    return optimizer, scheduler


# Initialize the optimizer and scheduler with appropriate parameters
#mamba_optimizer, mamba_scheduler = setup_optimizer(mamba_model, LEARNING_RATE, WEIGHT_DECAY, WARMUP_STEPS, TOTAL_STEPS)


#######################################################################################
# 4. RAG
def extract_text_from_pdf(file_path):
    text = ""
    with fitz.open(file_path) as doc:
        for page in doc:
            text += page.get_text()
    return text


def split_into_chunks(text, chunk_size):
    # Split the text into chunks of chunk_size
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]

def preprocess_text(text, max_length=512):
    # Initialize tokenizer
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # Split the text into smaller chunks to maintain context
    # The chunk size is slightly less than max_length to account for special tokens
    chunk_size = max_length - 50  # Adjust this value based on the model's requirements
    text_chunks = split_into_chunks(text, chunk_size)

    # Process each chunk
    processed_chunks = []
    for chunk in text_chunks:
        tokenized_output = tokenizer(chunk, padding='max_length', truncation=True, max_length=max_length, return_tensors="pt")
        processed_chunk = {
            'input_ids': tokenized_output['input_ids'],
            'attention_mask': tokenized_output['attention_mask']
        }
        processed_chunks.append(processed_chunk)

    return processed_chunks


def create_dataset_from_pdfs(pdf_file_paths):
    dataset = []
    for file_path in pdf_file_paths:
        text = extract_text_from_pdf(file_path)
        processed_text = preprocess_text(text)
        dataset.append(processed_text)
    return dataset

def retrieve_contexts(dataset, query_embedding, top_k=5):
    # Assume dataset is a list of dictionaries with each dictionary containing 'input_ids' and 'attention_mask'
    # for a particular context and that each context has been processed through a DPRContextEncoder to get embeddings

    # Placeholder for storing similarity scores
    similarity_scores = []

    # Iterate over each context in the dataset
    for context in dataset:
        context_input_ids = context['input_ids']
        context_attention_mask = context['attention_mask']

        # Assuming context_encoder is an instance of CustomDPRContextEncoder that's already trained
        # and available in your scope
        context_embedding = context_encoder(context_input_ids, context_attention_mask)

        # Compute similarity (e.g., using dot product)
        similarity = torch.matmul(query_embedding, context_embedding.T)

        similarity_scores.append(similarity.squeeze().item())

    # Sort contexts based on similarity scores and retrieve top_k indices
    top_k_indices = sorted(range(len(similarity_scores)), key=lambda i: similarity_scores[i], reverse=True)[:top_k]

    # Retrieve top_k contexts
    top_contexts = [dataset[i] for i in top_k_indices]

    return top_contexts

def rag_retrieve_and_generate(dataset, query):
    # Instantiate the question encoder
    question_encoder = DPRQuestionEncoder()

    # Encode the query
    encoded_query = question_encoder(query)

    # Retrieve relevant context
    # This involves finding the most similar documents in the dataset
    # For simplicity, this is represented as a function 'retrieve_contexts'
    relevant_contexts = retrieve_contexts(dataset, encoded_query)

    # Language model for generation
    language_model = LanguageModelTransformer()

    # Generate a response based on the retrieved contexts
    # This step may involve further formatting or preprocessing
    response = language_model.generate_response(relevant_contexts)

    return response

# pdfs
pdf_file_paths = [r'C:\Users\robbi\IEEMM\DPO.pdf', 
                  r'C:\Users\robbi\IEEMM\MAMBA.pdf',
                  r'C:\Users\robbi\IEEMM\QLORA.pdf',
                  r'C:\Users\robbi\IEEMM\RAG.pdf',
                  r'C:\Users\robbi\IEEMM\SWITCH_TRANSFORMER.pdf']

rag_dataset = create_dataset_from_pdfs(pdf_file_paths)

class CustomDPRContextEncoder(nn.Module):
    def __init__(self, model_name='bert-base-uncased', embedding_dim=768):
        super(CustomDPRContextEncoder, self).__init__()
        # Transformer-based model, e.g., BERT
        self.bert_model = BertModel.from_pretrained(model_name)
        # Additional layer to produce fixed-size embeddings
        self.embedding_layer = nn.Linear(self.bert_model.config.hidden_size, embedding_dim)

    def forward(self, input_ids, attention_mask=None):
        # Generate outputs from the BERT model
        outputs = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)
        # Use the pooled output for creating embeddings
        pooled_output = outputs.pooler_output
        # Pass through the embedding layer
        context_embeddings = self.embedding_layer(pooled_output)
        return context_embeddings

class DPRQuestionEncoder(nn.Module):
    def __init__(self, model_name='bert-base-uncased', embedding_dim=768):
        super(DPRQuestionEncoder, self).__init__()
        self.bert = BertModel.from_pretrained(model_name)
        self.embedding_layer = nn.Linear(self.bert.config.hidden_size, embedding_dim)

    def forward(self, input_ids, attention_mask, **kwargs):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        embeddings = self.embedding_layer(pooled_output)
        return embeddings

# TransformerRAG class
class TransformerRAG(nn.Module):
    def __init__(self, context_encoder_path, language_model_path, question_encoder_path, vocab_size):
        super(TransformerRAG, self).__init__()
        self.context_encoder = CustomDPRContextEncoder().to(device)
        self.language_model = LanguageModelTransformer(
            vocab_size=vocab_size,
            embed_size=256, 
            num_layers=6, 
            forward_expansion=4, 
            heads=8, 
            dropout=0, 
            max_length=100,  # Set to 512 to match the tokenization max_length
            rank=16
        ).to(device)
        self.language_model = load_model_weights(self.language_model, language_model_path)
        self.question_encoder = DPRQuestionEncoder().to(device)
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def forward(self, context_texts, question_input_ids, question_attention_mask, question_text):
        if question_input_ids.max() >= self.tokenizer.vocab_size:
            raise ValueError("question_input_ids contain ID(s) beyond the tokenizer's vocabulary size")

        # Process each context_text
        aggregated_context_embeddings = []
        for context_list in context_texts:
            if not all(isinstance(context, dict) for context in context_list):
                raise TypeError("Each item in context_texts must be a list of tokenized context dictionaries")

            # Create a tensor of zeros with the correct shape on the GPU
            aggregated_context_embedding = torch.zeros(self.context_encoder.bert_model.config.hidden_size, device=device)
            for context in context_list:
                context_input_ids = context['input_ids'].to(device)
                context_attention_mask = context['attention_mask'].to(device)
                context_embedding = self.context_encoder(context_input_ids, context_attention_mask)
                aggregated_context_embedding += context_embedding.mean(dim=0)

            aggregated_context_embeddings.append(aggregated_context_embedding / len(context_list))

        question_input_ids = question_input_ids.to(device).long()
        question_attention_mask = question_attention_mask.to(device).long()
        question_embeddings = self.question_encoder(input_ids=question_input_ids, attention_mask=question_attention_mask)

        cos_sim = torch.nn.CosineSimilarity(dim=1)
        similarities = [cos_sim(question_embeddings, context_emb.squeeze(0)) for context_emb in aggregated_context_embeddings]
        most_relevant_context_idx = torch.argmax(torch.tensor(similarities, device=device))

        combined_input = question_text + " " + context_texts[most_relevant_context_idx]
        tokenized_combined_input = self.tokenizer(combined_input, return_tensors="pt", padding=True, truncation=True, max_length=512)
        tokenized_combined_input = {k: v.to(device) for k, v in tokenized_combined_input.items()}  # Move to GPU
        response_logits = self.language_model(**tokenized_combined_input)
        probabilities = F.softmax(response_logits.logits, dim=-1)
        predicted_token_ids = torch.argmax(probabilities, dim=-1)
        predicted_tokens = self.tokenizer.convert_ids_to_tokens(predicted_token_ids[0])
        response = self.tokenizer.convert_tokens_to_string(predicted_tokens)

        return response


###########################################################################
# 5. LanguageModelTransformer

class LORALayer(nn.Module):
    def __init__(self, input_dim, output_dim, rank, alpha=1):
        super(LORALayer, self).__init__()
        self.rank = rank
        self.alpha = alpha

        # Original weight and bias of the linear layer
        self.weight = nn.Parameter(torch.Tensor(output_dim, input_dim))
        #print("self.weight Shape:", self.weight.shape)
        self.bias = nn.Parameter(torch.Tensor(output_dim))

        # LORA specific parameters
        self.A = nn.Parameter(torch.Tensor(input_dim, rank))
        self.B = nn.Parameter(torch.Tensor(rank, output_dim))

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.zeros_(self.bias)
        nn.init.normal_(self.A, 0, 0.02)
        nn.init.normal_(self.B, 0, 0.02)

    def forward(self, x):
        #print("LORALayer Input Shape:", x.shape)
        
        original_size = x.size()
        batch_size, seq_len, _ = x.shape
        x_flattened = x.reshape(-1, original_size[-1])

        # Compute lora_adjustment for each input in the batch
        lora_adjustment = self.alpha * (x_flattened @ self.A) @ self.B
        lora_adjustment = lora_adjustment.reshape(batch_size, seq_len, -1)

        #print("Adjusted lora_adjustment Shape:", lora_adjustment.shape)
        #print("self.weight Shape:", self.weight.shape)
        
        # Apply linear transformation to x_flattened
        x_transformed = nn.functional.linear(x_flattened, self.weight, self.bias)
        x_transformed = x_transformed.reshape(batch_size, seq_len, -1)

        # Add lora_adjustment to the transformed x
        x = x_transformed + lora_adjustment
        #print("LORALayer Output Shape:", x.shape)

        return x

class QLORALayer(nn.Module):
    def __init__(self, input_dim, output_dim, rank, alpha=1, quantization_bits=8):
        super(QLORALayer, self).__init__()
        self.rank = rank
        self.alpha = alpha
        self.quantization_bits = quantization_bits

        # Original weight and bias
        self.weight = nn.Parameter(torch.Tensor(output_dim, input_dim))
        self.bias = nn.Parameter(torch.Tensor(output_dim))

        # QLORA specific parameters
        self.A = nn.Parameter(torch.Tensor(input_dim, rank))
        self.B = nn.Parameter(torch.Tensor(rank, output_dim))

        self.dropout = nn.Dropout(0.1)
        self.layer_norm = nn.LayerNorm(output_dim)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.zeros_(self.bias)
        nn.init.normal_(self.A, 0, 0.02)
        nn.init.normal_(self.B, 0, 0.02)

    def quantize(self, x, num_bits):
        # Implement a simple quantization method
        scale = x.abs().max()
        x_quantized = torch.round(x / scale * (2**num_bits - 1))
        return x_quantized, scale

    def forward(self, x):
        #print("QLORALayer Input Shape:", x.shape)
        original_size = x.size()
        batch_size, seq_len, _ = x.shape
        x_flattened = x.reshape(-1, original_size[-1])

        A_quantized, scale_A = self.quantize(self.A, self.quantization_bits)
        B_quantized, scale_B = self.quantize(self.B, self.quantization_bits)

        # Compute lora_adjustment for each input in the batch
        lora_adjustment = self.alpha * (x_flattened @ (A_quantized / scale_A)) @ (B_quantized / scale_B)
        lora_adjustment = lora_adjustment.reshape(batch_size, seq_len, -1)
        lora_adjustment = self.dropout(lora_adjustment)
        #print("Adjusted lora_adjustment Shape:", lora_adjustment.shape)
        #print("self.weight Shape:", self.weight.shape)

        # Apply linear transformation to x_flattened
        x_transformed = nn.functional.linear(x_flattened, self.weight, self.bias)
        x_transformed = x_transformed.reshape(batch_size, seq_len, -1)

        # Add lora_adjustment to the transformed x
        x = x_transformed + lora_adjustment
        x = self.layer_norm(x)

        #print("QLORALayer Output Shape:", x.shape)

        return x
    
    def update_alpha(self, new_alpha):
        """
        Update the alpha scaling factor.
        """
        self.alpha = new_alpha

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(MultiHeadAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding into self.heads different pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # Einsum does the matrix multiplication for query*keys for each training example
        # with every other training example, then sum it up
        attention = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        if mask is not None:
            attention = attention.masked_fill(mask == 0, float("-1e20"))

        attention = torch.softmax(attention / (self.embed_size ** (1 / 2)), dim=3)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )

        out = self.fc_out(out)
        return out

class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion, rank):
        super(TransformerBlock, self).__init__()
        self.attention = MultiHeadAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            LORALayer(embed_size, forward_expansion * embed_size, rank),
            nn.ReLU(),
            LORALayer(forward_expansion * embed_size, embed_size, rank),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out

class LanguageModelDecoder(nn.Module):
    def __init__(self, vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, max_length, rank):
        super(LanguageModelDecoder, self).__init__()
        self.word_embedding = nn.Embedding(vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        # Adding BatchNorm layers
        self.bn1 = nn.BatchNorm1d(embed_size)
        self.bn2 = nn.BatchNorm1d(embed_size)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(embed_size, heads, dropout, forward_expansion, rank)
                for _ in range(num_layers)
            ]
        )

        # QLORA layers
        self.qlora_feed_forward = nn.Sequential(
            QLORALayer(embed_size, forward_expansion * embed_size, rank),
            nn.ReLU(),
            QLORALayer(forward_expansion * embed_size, embed_size, rank),
        )
        self.use_qlora = False  # Flag to toggle QLORA

        self.fc_out = nn.Linear(embed_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, trg_mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length) #.to(x.device)
        x = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        # Transpose for BatchNorm, apply batch normalization, and then transpose back
        x = x.transpose(1, 2)
        x = self.bn1(x)
        x = x.transpose(1, 2)

        for layer in self.layers:
            x = layer(x, x, x, trg_mask)
            if self.use_qlora:
                x = self.qlora_feed_forward(x)

        # Transpose for BatchNorm, apply batch normalization, and then transpose back
        x = x.transpose(1, 2)
        x = self.bn2(x)
        x = x.transpose(1, 2)

        out = self.fc_out(x)
        #print(f"shape of output of forward method of LanguageModelDecoder: {out.shape} ")

        return out

    def toggle_qlora(self, use_qlora: bool):
        self.use_qlora = use_qlora

class LanguageModelTransformer(nn.Module):
    def __init__(self, vocab_size, embed_size=256, num_layers=6, forward_expansion=4, heads=8, dropout=0, max_length=100, rank=16):
        super(LanguageModelTransformer, self).__init__()

        self.decoder = LanguageModelDecoder(
            vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            max_length,
            rank,
        )

    def forward(self, trg):
        trg_mask = self.make_trg_mask(trg)
        out = self.decoder(trg, trg_mask)
        return out

    def make_trg_mask(self, trg):
        N, trg_len = trg.shape
        trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(
            N, 1, trg_len, trg_len
        ).to(trg.device)

        return trg_mask
    
    # Function to enable or disable QLORA layers (for fine-tuning purposes)
    def toggle_qlora(self, use_qlora: bool):
        self.decoder.toggle_qlora(use_qlora)

    def generate_response(self, input_ids, attention_mask):
        # Assuming you have a forward method that returns logits
        logits = self.forward(input_ids=input_ids, attention_mask=attention_mask)
        
        # Convert logits to probabilities
        probabilities = F.softmax(logits, dim=-1)
        
        # For simplicity, using greedy decoding here. You might want to use beam search or sampling.
        predicted_token_id = torch.argmax(probabilities, dim=-1)
        
        # Convert predicted token ids to tokens
        predicted_tokens = [tokenizer.convert_ids_to_tokens(idx.item()) for idx in predicted_token_id]
        
        # Join tokens to form the response string. This is a very basic way to generate text and might not produce the best results.
        response = tokenizer.convert_tokens_to_string(predicted_tokens)
        
        return response

###########################################################################
# 7. Internal Switch Routing

# Auxiliary loss function
def auxiliary_loss(gate_scores, expert_capacity):
    expert_load = gate_scores.sum(0) / gate_scores.size(0)
    loss_balancing = torch.std(expert_load)
    return loss_balancing

# Routing function
CAPACITY_FACTOR = 1

def route_inputs(expert_indices, gate_scores, num_experts):
    capacity_factor_tensor = torch.tensor([CAPACITY_FACTOR], dtype=torch.float32)
    capacities = (gate_scores.size(0) * capacity_factor_tensor / num_experts).int()
    expert_counts = torch.zeros(num_experts, dtype=torch.int32)
    for idx in range(len(expert_indices)):
        selected_expert = expert_indices[idx]
        if expert_counts[selected_expert] < capacities[0]:
            expert_counts[selected_expert] += 1
        else:
            available_experts = (expert_counts < capacities[0]).nonzero(as_tuple=False).view(-1)
            if len(available_experts) > 0:
                alternative_expert = available_experts[0]
                expert_indices[idx] = alternative_expert
                expert_counts[alternative_expert] += 1
            else:
                print("No available experts to reroute. Handling overflow.")
    return expert_indices

# SwitchGate 
class SwitchGate(nn.Module):
    def __init__(self, input_dim, num_experts):
        super(SwitchGate, self).__init__()
        self.fc1 = nn.Linear(input_dim, input_dim // 2)
        self.fc2 = nn.Linear(input_dim // 2, num_experts)

    def forward(self, x):
        x = F.relu(self.fc1(x.float()))
        gate_scores = F.softmax(self.fc2(x), dim=-1)
        return gate_scores

# SwitchRouter 
class SwitchRouter(nn.Module):
    def __init__(self, input_dim, num_experts, mamba_model_path, context_encoder_path, language_model_path, question_encoder_path, dpo_model_path, vocab_size, embed_size, num_layers, forward_expansion, heads, dropout, max_length, rank):
        super(SwitchRouter, self).__init__()
        self.router = SwitchGate(input_dim, num_experts)
        self.transformer_rag = TransformerRAG(context_encoder_path, 
                                              language_model_path, 
                                              question_encoder_path, 
                                              vocab_size).to(device)
        self.transformer_dpo = LanguageModelTransformer(vocab_size, 
                                                        embed_size, 
                                                        num_layers, 
                                                        forward_expansion, 
                                                        heads, dropout, 
                                                        max_length, 
                                                        rank).to(device)
        self.transformer_dpo = load_model_weights(self.transformer_dpo, dpo_model_path)
        self.mamba = SimplifiedLanguageModelMAMBA(vocab_size=VOCAB_SIZE, 
                                     num_layers=NUM_LAYERS, 
                                     d_model=D_MODEL, 
                                     d_state=D_STATE, 
                                     d_conv=D_CONV, 
                                     expansion_factor=EXPANSION_FACTOR).to(device)
        self.mamba = load_model_weights(self.mamba, mamba_model_path)
        self.experts = nn.ModuleList([self.transformer_rag, self.transformer_dpo, self.mamba])
        self.input_embedding = nn.Linear(512, input_dim)

    def forward(self, x, attention_mask, context_texts, question_text):
        x = self.input_embedding(x.float())
        gate_scores = self.router(x)
        expert_indices = torch.argmax(gate_scores, dim=1)
        expert_indices = route_inputs(expert_indices, gate_scores, len(self.experts))
        final_output = torch.zeros_like(x)
        aux_loss = 0

        for i, expert in enumerate(self.experts):
            mask = expert_indices == i
            if mask.any():
                selected_inputs = x[mask]
                selected_attention_mask = attention_mask[mask]

                if isinstance(expert, TransformerRAG):
                    # Now passing the required arguments to TransformerRAG
                    expert_output = self.transformer_rag(context_texts, selected_inputs, selected_attention_mask, question_text)
                else:
                    # Process as usual for other experts
                    expert_output = expert(selected_inputs, selected_attention_mask)

                final_output[mask] = expert_output

        # Compute auxiliary loss for load balancing
        aux_loss += auxiliary_loss(gate_scores, expert_capacity=torch.tensor([CAPACITY_FACTOR] * len(self.experts)))

        return final_output, aux_loss

###########################################################################
# 8.Training loop
input_dim = 512
num_experts = 3


model = SwitchRouter(input_dim, num_experts, mamba_model_path, context_encoder, language_model, question_encoder, tran_dpo, 30522, 256, 6, 4, 8, 0.1, 100, 16) #.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
model.to(device)
main_loss_function = torch.nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001)
aux_loss_weight = 0.1
vocab_size = tokenizer.vocab_size
model_embedding_size = model.transformer_dpo.decoder.word_embedding.num_embeddings
assert tokenizer.vocab_size == VOCAB_SIZE, f"Tokenizer vocab size ({tokenizer.vocab_size}) doesn't match expected vocab size ({VOCAB_SIZE})"

# Ensure train_data is accessible here, with 'queries' and 'contexts' keys
train_data = {
    "queries": [
        # Queries for DPO.pdf
        "What is Direct Preference Optimization (DPO)?",
        "How does Direct Preference Optimization work?",
        "How can I implement Direct Preference Optimization in my organization?",
        "Why does Direct Preference Optimization improve the efficiency of language modelling?",
        # Queries for MAMBA.pdf
        "What is MAMBA?",
        "How does MAMBA function?",
        "How can I build a system based on MAMBA technology?",
        "Why does MAMBA enhance the performance of its application area?",
        # Queries for QLORA.pdf
        "What is QLORA?",
        "How does QLORA operate?",
        "How can I develop a project using QLORA?",
        "Why does QLORA improve the capabilities of its relevant field?",
        # Queries for RAG.pdf
        "What is Retrieval Augmented Generation (RAG)?",
        "How does Retrieval Augmented Generation work?",
        "How can I build a Retrieval Augmented Generation model?",
        "Why does Retrieval Augmented Generation enhance language model performance?",
        # Queries for SWITCH_TRANSFORMER.pdf
        "What is the Switch Transformer model?",
        "How does the Switch Transformer model operate?",
        "How can I construct a Switch Transformer model?",
        "Why does the Switch Transformer model improve language processing tasks?"
    ],
    "contexts": [
        # Contexts from DPO.pdf
        rag_dataset[0],  # Assuming dataset[0] is the processed content of DPO.pdf
        rag_dataset[0],
        rag_dataset[0],
        rag_dataset[0],
        # Contexts from MAMBA.pdf
        rag_dataset[1],  # Assuming dataset[1] is the processed content of MAMBA.pdf
        rag_dataset[1],
        rag_dataset[1],
        rag_dataset[1],
        # Contexts from QLORA.pdf
        rag_dataset[2],  # Assuming dataset[2] is the processed content of QLORA.pdf
        rag_dataset[2],
        rag_dataset[2],
        rag_dataset[2],
        # Contexts from RAG.pdf
        rag_dataset[3],  # Assuming dataset[3] is the processed content of RAG.pdf
        rag_dataset[3],
        rag_dataset[3],
        rag_dataset[3],
        # Contexts from SWITCH_TRANSFORMER.pdf
        rag_dataset[4],  # Assuming dataset[4] is the processed content of SWITCH_TRANSFORMER.pdf
        rag_dataset[4],
        rag_dataset[4],
        rag_dataset[4]
    ]
}
# Before training, ensure the embedding size of each component matches the tokenizer vocab size
assert model.transformer_dpo.decoder.word_embedding.num_embeddings == tokenizer.vocab_size, \
    "Transformer DPO's embedding size does not match tokenizer vocab size"

assert model.transformer_rag.context_encoder.bert_model.config.vocab_size == tokenizer.vocab_size, \
    "RAG context encoder vocab size does not match tokenizer vocab size"

assert model.transformer_rag.question_encoder.bert.config.vocab_size == tokenizer.vocab_size, \
    "RAG question encoder vocab size does not match tokenizer vocab size"


def count_total_parameters(models):
    total_params = 0
    for model in models:
        model_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total_params += model_params
        print(f"Parameters in {model.__class__.__name__}: {model_params}")
    return total_params

# Usage
models = [model.transformer_rag, model.transformer_dpo, model.mamba, model]
total_params = count_total_parameters(models)
print(f"Total trainable parameters across all models: {total_params}")

# Start training
model.train()
num_epochs = 5
for epoch in range(num_epochs):
    total_loss = 0
    for batch_idx, batch in enumerate(train_loader):
        inputs, attention_mask, targets = batch['input_ids'].to(device), batch['attention_mask'].to(device), batch['labels'].to(device)

        # Ensure inputs do not exceed the tokenizer's vocabulary size
        if inputs.max() >= tokenizer.vocab_size:
            raise ValueError("Input IDs exceed tokenizer's vocabulary size")

        batch_size = inputs.size(0)
        start_idx = batch_idx * batch_size
        end_idx = start_idx + batch_size

        # Ensure batch indices are within the range of train_data
        assert end_idx <= len(train_data['queries']), "Batch index exceeds size of 'queries' data"
        assert end_idx <= len(train_data['contexts']), "Batch index exceeds size of 'contexts' data"

        # Debug prints for shapes and value ranges
        print(f"Batch index: {batch_idx}")
        print(f"Input IDs shape: {inputs.shape}, Max ID: {inputs.max()}")
        print(f"Attention Mask shape: {attention_mask.shape}")
        print(f"Targets shape: {targets.shape}, Max ID: {targets.max()}")

        current_queries = train_data['queries'][start_idx:end_idx]
        current_contexts = train_data['contexts'][start_idx:end_idx]

        # Debug prints for current queries and contexts
        for i, (q, context_list) in enumerate(zip(current_queries, current_contexts)):
            print(f"Processing query-context pair {i}:")
            print(f"Query: {q}")
            #for c in context_list:  # Assuming context_list is a list of dictionaries
                #print(f"Context input IDs shape: {torch.tensor(c['input_ids']).shape}")
                #print(f"Context attention mask shape: {torch.tensor(c['attention_mask']).shape}")

        # Call to the model forward function
        outputs, aux_loss = model(inputs, attention_mask, current_contexts, current_queries)

        # Calculate loss
        main_loss = main_loss_function(outputs.view(-1, outputs.size(-1)), targets.view(-1))
        total_loss = main_loss + aux_loss_weight * aux_loss
        total_loss.backward()
        optimizer.step()
        optimizer.zero_grad()  

        if batch_idx % 100 == 0:
            print(f'Epoch: {epoch+1}, Batch: {batch_idx}, Loss: {total_loss.item()}')

    print(f'End of Epoch {epoch+1}, Average Loss: {total_loss / len(train_loader)}')

Map:   0%|          | 0/412178 [00:00<?, ? examples/s]

Map:   0%|          | 0/22176 [00:00<?, ? examples/s]

Map:   0%|          | 0/23107 [00:00<?, ? examples/s]

True vocab size: 30522
Parameters in TransformerRAG: 240217658
Parameters in LanguageModelTransformer: 20071994
Parameters in SimplifiedLanguageModelMAMBA: 53473082
Parameters in SwitchRouter: 314157489
Total trainable parameters across all models: 627920223
Batch index: 0
Input IDs shape: torch.Size([4, 512]), Max ID: 27507
Attention Mask shape: torch.Size([4, 512])
Targets shape: torch.Size([4, 512]), Max ID: 27507
Processing query-context pair 0:
Query: What is Direct Preference Optimization (DPO)?
Processing query-context pair 1:
Query: How does Direct Preference Optimization work?
Processing query-context pair 2:
Query: How can I implement Direct Preference Optimization in my organization?
Processing query-context pair 3:
Query: Why does Direct Preference Optimization improve the efficiency of language modelling?
No available experts to reroute. Handling overflow.


OutOfMemoryError: CUDA out of memory. Tried to allocate 12.00 MiB. GPU 0 has a total capacty of 4.00 GiB of which 0 bytes is free. Of the allocated memory 10.79 GiB is allocated by PyTorch, and 13.33 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

# v7- grad accumulation

In [3]:
# 0. Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from transformers import BertTokenizer
from torch.utils.data import DataLoader
import fitz  # PyMuPDF
from transformers import BertModel
import math
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

torch.autograd.set_detect_anomaly(True)

#############################################################################
# 1. Preprocessing Data

# Load the dataset
code_dataset = load_dataset("code_search_net", "python")

# Tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Preprocess and tokenize dataset
def preprocess_function(examples):
    tokenized_output = tokenizer(examples['func_code_string'], padding="max_length", truncation=True, max_length=512)
    labels = [row[:-1] + [tokenizer.pad_token_id] for row in tokenized_output["input_ids"]]
    tokenized_output["labels"] = labels
    return tokenized_output

tokenized_dataset = code_dataset.map(preprocess_function, batched=True)
tokenized_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])

# DataLoader
train_dataset = tokenized_dataset["train"]
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"True vocab size: {tokenizer.vocab_size}")

#############################################################################
# 2. Sub-Model Weights
# 1. Transformer with DPO:
tran_dpo = r'C:\Users\robbi\IEEMM\language_model_weights.pth'
# 2. MAMBA:
mamba_model_path = r'C:\Users\robbi\IEEMM\mamba_model_weights.pth'
# 3. Transformer and RAG:
context_encoder = r'C:\Users\robbi\IEEMM\context_encoder.pth'
language_model = r'C:\Users\robbi\IEEMM\language_model.pth'
question_encoder = r'C:\Users\robbi\IEEMM\question_encoder.pth'

# Load model weights function
def load_model_weights(model, model_path):
    #checkpoint = torch.load(model_path, map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
    checkpoint = torch.load(model_path, map_location=device)

    if isinstance(checkpoint, dict):
        # Check for 'state_dict' or 'model_state_dict' keys
        if 'state_dict' in checkpoint:
            model.load_state_dict(checkpoint['state_dict'])
        elif 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            # If no known key is found, try loading it as a raw state dictionary
            try:
                model.load_state_dict(checkpoint)
            except RuntimeError as e:
                raise ValueError(f"Error loading state dict: {e}")
    elif isinstance(checkpoint, nn.Module):
        # If the checkpoint is a model object, assign it directly
        model = checkpoint
    else:
        raise ValueError(f"Unsupported checkpoint format: {type(checkpoint)}")

    model.eval()
    return model
#############################################################################
# 3. MAMBA
# RoPE
class RotaryPositionalEncoding(nn.Module):
    def __init__(self, dim, max_len=5000):
        super().__init__()
        self.dim = dim
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_len).type_as(inv_freq)
        freqs = torch.einsum('n , d -> n d', t, inv_freq)
        self.register_buffer('sin', freqs.sin())
        self.register_buffer('cos', freqs.cos())

    def forward(self, x):
        n, _, device = x.shape[1], self.dim // 2, x.device
        sin, cos = self.sin[:n].to(device), self.cos[:n].to(device)

        # Apply RoPE to even and odd indices separately
        x_even = x[..., :self.dim:2] * cos.unsqueeze(0) + torch.roll(x[..., 1:self.dim:2], shifts=1, dims=-1) * sin.unsqueeze(0)
        x_odd = x[..., 1:self.dim:2] * cos.unsqueeze(0) - torch.roll(x[..., :self.dim:2], shifts=1, dims=-1) * sin.unsqueeze(0)
        return torch.cat((x_even, x_odd), dim=-1)

# SWIGLU
class SwiGLU(nn.Module):
    def __init__(self, dim_in, dim_out):
        super(SwiGLU, self).__init__()
        self.fc1 = nn.Linear(dim_in, dim_out)
        self.fc2 = nn.Linear(dim_in, dim_out)

    def forward(self, x):
        gate = torch.sigmoid(self.fc2(x))
        return self.fc1(x) * gate

class SimplifiedMAMBA(nn.Module):
    # Adjusted to include SwiGLU blocks
    def __init__(self, num_layers, d_model, d_state, d_conv, expansion_factor):
        super().__init__()

        self.num_layers = num_layers
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expansion_factor = expansion_factor
        self.feedforward = nn.Sequential(
            nn.Linear(d_model, d_state),
            nn.GELU(),
            nn.Linear(d_state, d_model)
        )
        self.input_embedding = nn.Linear(d_model, d_model)
        self.convs = nn.Sequential(*[nn.Conv1d(d_model, d_model, kernel_size=d_conv, padding=(d_conv // 2)) for _ in range(num_layers)])
        self.swiglu = SwiGLU(d_model, d_model)
        self.output_projection = nn.Linear(d_model, d_model * expansion_factor)  # Adjusted to match the output of SwiGLU


        self.initialize_weights()

    def initialize_weights(self):
        gain = nn.init.calculate_gain('relu')

        nn.init.orthogonal_(self.input_embedding.weight, gain)
        nn.init.normal_(self.input_embedding.bias, mean=0, std=0.01)

        nn.init.kaiming_uniform_(self.convs[-1].weight, a=math.sqrt(5))
        nn.init.zeros_(self.convs[-1].bias)

        nn.init.xavier_uniform_(self.feedforward[0].weight, gain=nn.init.calculate_gain('relu'))
        nn.init.zeros_(self.feedforward[0].bias)

        nn.init.xavier_uniform_(self.feedforward[2].weight, gain=nn.init.calculate_gain('linear'))
        nn.init.zeros_(self.feedforward[2].bias)

        nn.init.xavier_uniform_(self.output_projection.weight, gain=nn.init.calculate_gain('linear'))
        nn.init.zeros_(self.output_projection.bias)

    def forward(self, inputs, attention_mask=None):
        print("Input shape:", inputs.shape)

        # Apply the attention mask if provided
        if attention_mask is not None:
            inputs = inputs * attention_mask.unsqueeze(-1)

        projected_inputs = self.input_embedding(inputs)
        print("projected_inputs pre-reshape shape:", projected_inputs.shape)

        projected_inputs = projected_inputs.permute(0, 2, 1)
        print("projected_inputs post-reshape shape:", projected_inputs.shape)

        for conv in self.convs:
            projected_inputs = conv(projected_inputs)

        projected_inputs = projected_inputs.permute(0, 2, 1)
        print("projected_inputs post convolution reshape:", projected_inputs.shape)

        projected_inputs = self.swiglu(projected_inputs)
        print("projected_inputs post swiglu shape:", projected_inputs.shape)

        output = self.output_projection(projected_inputs)
        print("output shape:", output.shape)

        return output

class SimplifiedLanguageModelMAMBA(nn.Module):
    # Including rotary positional encodings if required
    def __init__(self, vocab_size, num_layers, d_model, d_state, d_conv, expansion_factor):
        super().__init__()

        self.vocab_size = vocab_size
        self.num_layers = num_layers
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expansion_factor = expansion_factor

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = RotaryPositionalEncoding(d_model)
        self.simplified_mamba = SimplifiedMAMBA(num_layers, d_model, d_state, d_conv, expansion_factor)
        self.output_projection = nn.Linear(d_model*2, vocab_size)

        self.initialize_weights()

    def initialize_weights(self):
        gain = 1.0

        nn.init.orthogonal_(self.embedding.weight, gain)
        nn.init.xavier_uniform_(self.output_projection.weight, gain=gain)

    def forward(self, input_values, attention_mask):
        embedded = self.embedding(input_values) * math.sqrt(self.d_model)
        embedded = self.pos_encoder(embedded)
        simplified_mamba_output = self.simplified_mamba(embedded, attention_mask)
        logits = self.output_projection(simplified_mamba_output)
        return logits


LEARNING_RATE = 5e-4
WEIGHT_DECAY =  0.1
WARMUP_STEPS = 100
TOTAL_STEPS = 1000 # we want this to be : epochs * (size of dataset / batch_size )
EPOCHS = 100
VOCAB_SIZE = 30522
NUM_LAYERS = 4
BATCH_SIZE = 8
EXPANSION_FACTOR = 2
CLIP_GRADIENT = 1.0
D_MODEL = 512  # Dimensionality of the model's embeddings
D_STATE = 2048  # Dimensionality of the intermediate state in feedforward
D_CONV = 3  # Kernel size for convolutional layers


def setup_optimizer(model, learning_rate, weight_decay, warmup_steps, total_steps):
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    # Linear warmup with cosine decay
    scheduler = LambdaLR(optimizer, lambda step: min((step + 1) / warmup_steps, 0.5 * (1 + math.cos(math.pi * step / total_steps))))

    return optimizer, scheduler


# Initialize the optimizer and scheduler with appropriate parameters
#mamba_optimizer, mamba_scheduler = setup_optimizer(mamba_model, LEARNING_RATE, WEIGHT_DECAY, WARMUP_STEPS, TOTAL_STEPS)


#######################################################################################
# 4. RAG
def extract_text_from_pdf(file_path):
    text = ""
    with fitz.open(file_path) as doc:
        for page in doc:
            text += page.get_text()
    return text

def split_into_chunks(text, chunk_size):
    # Split the text into chunks of chunk_size
    return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size)]

def preprocess_text(text, max_length=512):
    # Initialize tokenizer
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # Split the text into smaller chunks to maintain context
    # The chunk size is slightly less than max_length to account for special tokens
    chunk_size = max_length - 50  # Adjust this value based on the model's requirements
    text_chunks = split_into_chunks(text, chunk_size)

    # Process each chunk
    processed_chunks = []
    for chunk in text_chunks:
        tokenized_output = tokenizer(chunk, padding='max_length', truncation=True, max_length=max_length, return_tensors="pt")
        processed_chunk = {
            'input_ids': tokenized_output['input_ids'],
            'attention_mask': tokenized_output['attention_mask']
        }
        processed_chunks.append(processed_chunk)

    return processed_chunks

def create_dataset_from_pdfs(pdf_file_paths):
    dataset = []
    for file_path in pdf_file_paths:
        text = extract_text_from_pdf(file_path)
        processed_text = preprocess_text(text)
        dataset.append(processed_text)
    return dataset

def retrieve_contexts(dataset, query_embedding, top_k=5):
    # Assume dataset is a list of dictionaries with each dictionary containing 'input_ids' and 'attention_mask'
    # for a particular context and that each context has been processed through a DPRContextEncoder to get embeddings

    # Placeholder for storing similarity scores
    similarity_scores = []

    # Iterate over each context in the dataset
    for context in dataset:
        context_input_ids = context['input_ids']
        context_attention_mask = context['attention_mask']

        # Assuming context_encoder is an instance of CustomDPRContextEncoder that's already trained
        # and available in your scope
        context_embedding = context_encoder(context_input_ids, context_attention_mask)

        # Compute similarity (e.g., using dot product)
        similarity = torch.matmul(query_embedding, context_embedding.T)

        similarity_scores.append(similarity.squeeze().item())

    # Sort contexts based on similarity scores and retrieve top_k indices
    top_k_indices = sorted(range(len(similarity_scores)), key=lambda i: similarity_scores[i], reverse=True)[:top_k]

    # Retrieve top_k contexts
    top_contexts = [dataset[i] for i in top_k_indices]

    return top_contexts

def rag_retrieve_and_generate(dataset, query):
    # Instantiate the question encoder
    question_encoder = DPRQuestionEncoder()

    # Encode the query
    encoded_query = question_encoder(query)

    # Retrieve relevant context
    # This involves finding the most similar documents in the dataset
    # For simplicity, this is represented as a function 'retrieve_contexts'
    relevant_contexts = retrieve_contexts(dataset, encoded_query)

    # Language model for generation
    language_model = LanguageModelTransformer()

    # Generate a response based on the retrieved contexts
    # This step may involve further formatting or preprocessing
    response = language_model.generate_response(relevant_contexts)

    return response

# pdfs
pdf_file_paths = [r'C:\Users\robbi\IEEMM\DPO.pdf', 
                  r'C:\Users\robbi\IEEMM\MAMBA.pdf',
                  r'C:\Users\robbi\IEEMM\QLORA.pdf',
                  r'C:\Users\robbi\IEEMM\RAG.pdf',
                  r'C:\Users\robbi\IEEMM\SWITCH_TRANSFORMER.pdf']

rag_dataset = create_dataset_from_pdfs(pdf_file_paths)

class CustomDPRContextEncoder(nn.Module):
    def __init__(self, model_name='bert-base-uncased', embedding_dim=768):
        super(CustomDPRContextEncoder, self).__init__()
        # Transformer-based model, e.g., BERT
        self.bert_model = BertModel.from_pretrained(model_name)
        # Additional layer to produce fixed-size embeddings
        self.embedding_layer = nn.Linear(self.bert_model.config.hidden_size, embedding_dim)

    def forward(self, input_ids, attention_mask=None):
        # Generate outputs from the BERT model
        outputs = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)
        # Use the pooled output for creating embeddings
        pooled_output = outputs.pooler_output
        # Pass through the embedding layer
        context_embeddings = self.embedding_layer(pooled_output)
        return context_embeddings

class DPRQuestionEncoder(nn.Module):
    def __init__(self, model_name='bert-base-uncased', embedding_dim=768):
        super(DPRQuestionEncoder, self).__init__()
        self.bert = BertModel.from_pretrained(model_name)
        self.embedding_layer = nn.Linear(self.bert.config.hidden_size, embedding_dim)

    def forward(self, input_ids, attention_mask, **kwargs):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        embeddings = self.embedding_layer(pooled_output)
        return embeddings

# TransformerRAG class
class TransformerRAG(nn.Module):
    def __init__(self, context_encoder_path, language_model_path, question_encoder_path, vocab_size):
        super(TransformerRAG, self).__init__()
        self.context_encoder = CustomDPRContextEncoder().to(device)
        self.language_model = LanguageModelTransformer(
            vocab_size=vocab_size,
            embed_size=256, 
            num_layers=6, 
            forward_expansion=4, 
            heads=8, 
            dropout=0, 
            max_length=100,  # Set to 512 to match the tokenization max_length
            rank=16
        ).to(device)
        self.language_model = load_model_weights(self.language_model, language_model_path)
        self.question_encoder = DPRQuestionEncoder().to(device)
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def forward(self, context_texts, question_input_ids, question_attention_mask, question_text):
        if question_input_ids.max() >= self.tokenizer.vocab_size:
            raise ValueError("question_input_ids contain ID(s) beyond the tokenizer's vocabulary size")

        # Process each context_text
        aggregated_context_embeddings = []
        for context_list in context_texts:
            if not all(isinstance(context, dict) for context in context_list):
                raise TypeError("Each item in context_texts must be a list of tokenized context dictionaries")

            # Create a tensor of zeros with the correct shape on the GPU
            aggregated_context_embedding = torch.zeros(self.context_encoder.bert_model.config.hidden_size, device=device)
            for context in context_list:
                context_input_ids = context['input_ids'].to(device)
                context_attention_mask = context['attention_mask'].to(device)
                context_embedding = self.context_encoder(context_input_ids, context_attention_mask)
                aggregated_context_embedding += context_embedding.mean(dim=0)

            aggregated_context_embeddings.append(aggregated_context_embedding / len(context_list))

        question_input_ids = question_input_ids.to(device).long()
        question_attention_mask = question_attention_mask.to(device).long()
        question_embeddings = self.question_encoder(input_ids=question_input_ids, attention_mask=question_attention_mask)

        cos_sim = torch.nn.CosineSimilarity(dim=1)
        similarities = [cos_sim(question_embeddings, context_emb.squeeze(0)) for context_emb in aggregated_context_embeddings]
        most_relevant_context_idx = torch.argmax(torch.tensor(similarities, device=device))

        combined_input = question_text + " " + context_texts[most_relevant_context_idx]
        tokenized_combined_input = self.tokenizer(combined_input, return_tensors="pt", padding=True, truncation=True, max_length=512)
        tokenized_combined_input = {k: v.to(device) for k, v in tokenized_combined_input.items()}  # Move to GPU
        response_logits = self.language_model(**tokenized_combined_input)
        probabilities = F.softmax(response_logits.logits, dim=-1)
        predicted_token_ids = torch.argmax(probabilities, dim=-1)
        predicted_tokens = self.tokenizer.convert_ids_to_tokens(predicted_token_ids[0])
        response = self.tokenizer.convert_tokens_to_string(predicted_tokens)

        return response


###########################################################################
# 5. LanguageModelTransformer

class LORALayer(nn.Module):
    def __init__(self, input_dim, output_dim, rank, alpha=1):
        super(LORALayer, self).__init__()
        self.rank = rank
        self.alpha = alpha

        # Original weight and bias of the linear layer
        self.weight = nn.Parameter(torch.Tensor(output_dim, input_dim))
        #print("self.weight Shape:", self.weight.shape)
        self.bias = nn.Parameter(torch.Tensor(output_dim))

        # LORA specific parameters
        self.A = nn.Parameter(torch.Tensor(input_dim, rank))
        self.B = nn.Parameter(torch.Tensor(rank, output_dim))

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.zeros_(self.bias)
        nn.init.normal_(self.A, 0, 0.02)
        nn.init.normal_(self.B, 0, 0.02)

    def forward(self, x):
        #print("LORALayer Input Shape:", x.shape)
        
        original_size = x.size()
        batch_size, seq_len, _ = x.shape
        x_flattened = x.reshape(-1, original_size[-1])

        # Compute lora_adjustment for each input in the batch
        lora_adjustment = self.alpha * (x_flattened @ self.A) @ self.B
        lora_adjustment = lora_adjustment.reshape(batch_size, seq_len, -1)

        #print("Adjusted lora_adjustment Shape:", lora_adjustment.shape)
        #print("self.weight Shape:", self.weight.shape)
        
        # Apply linear transformation to x_flattened
        x_transformed = nn.functional.linear(x_flattened, self.weight, self.bias)
        x_transformed = x_transformed.reshape(batch_size, seq_len, -1)

        # Add lora_adjustment to the transformed x
        x = x_transformed + lora_adjustment
        #print("LORALayer Output Shape:", x.shape)

        return x

class QLORALayer(nn.Module):
    def __init__(self, input_dim, output_dim, rank, alpha=1, quantization_bits=8):
        super(QLORALayer, self).__init__()
        self.rank = rank
        self.alpha = alpha
        self.quantization_bits = quantization_bits

        # Original weight and bias
        self.weight = nn.Parameter(torch.Tensor(output_dim, input_dim))
        self.bias = nn.Parameter(torch.Tensor(output_dim))

        # QLORA specific parameters
        self.A = nn.Parameter(torch.Tensor(input_dim, rank))
        self.B = nn.Parameter(torch.Tensor(rank, output_dim))

        self.dropout = nn.Dropout(0.1)
        self.layer_norm = nn.LayerNorm(output_dim)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.zeros_(self.bias)
        nn.init.normal_(self.A, 0, 0.02)
        nn.init.normal_(self.B, 0, 0.02)

    def quantize(self, x, num_bits):
        # Implement a simple quantization method
        scale = x.abs().max()
        x_quantized = torch.round(x / scale * (2**num_bits - 1))
        return x_quantized, scale

    def forward(self, x):
        #print("QLORALayer Input Shape:", x.shape)
        original_size = x.size()
        batch_size, seq_len, _ = x.shape
        x_flattened = x.reshape(-1, original_size[-1])

        A_quantized, scale_A = self.quantize(self.A, self.quantization_bits)
        B_quantized, scale_B = self.quantize(self.B, self.quantization_bits)

        # Compute lora_adjustment for each input in the batch
        lora_adjustment = self.alpha * (x_flattened @ (A_quantized / scale_A)) @ (B_quantized / scale_B)
        lora_adjustment = lora_adjustment.reshape(batch_size, seq_len, -1)
        lora_adjustment = self.dropout(lora_adjustment)
        #print("Adjusted lora_adjustment Shape:", lora_adjustment.shape)
        #print("self.weight Shape:", self.weight.shape)

        # Apply linear transformation to x_flattened
        x_transformed = nn.functional.linear(x_flattened, self.weight, self.bias)
        x_transformed = x_transformed.reshape(batch_size, seq_len, -1)

        # Add lora_adjustment to the transformed x
        x = x_transformed + lora_adjustment
        x = self.layer_norm(x)

        #print("QLORALayer Output Shape:", x.shape)

        return x
    
    def update_alpha(self, new_alpha):
        """
        Update the alpha scaling factor.
        """
        self.alpha = new_alpha

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(MultiHeadAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding into self.heads different pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # Einsum does the matrix multiplication for query*keys for each training example
        # with every other training example, then sum it up
        attention = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        if mask is not None:
            attention = attention.masked_fill(mask == 0, float("-1e20"))

        attention = torch.softmax(attention / (self.embed_size ** (1 / 2)), dim=3)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )

        out = self.fc_out(out)
        return out

class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion, rank):
        super(TransformerBlock, self).__init__()
        self.attention = MultiHeadAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            LORALayer(embed_size, forward_expansion * embed_size, rank),
            nn.ReLU(),
            LORALayer(forward_expansion * embed_size, embed_size, rank),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out

class LanguageModelDecoder(nn.Module):
    def __init__(self, vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, max_length, rank):
        super(LanguageModelDecoder, self).__init__()
        self.word_embedding = nn.Embedding(vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        # Adding BatchNorm layers
        self.bn1 = nn.BatchNorm1d(embed_size)
        self.bn2 = nn.BatchNorm1d(embed_size)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(embed_size, heads, dropout, forward_expansion, rank)
                for _ in range(num_layers)
            ]
        )

        # QLORA layers
        self.qlora_feed_forward = nn.Sequential(
            QLORALayer(embed_size, forward_expansion * embed_size, rank),
            nn.ReLU(),
            QLORALayer(forward_expansion * embed_size, embed_size, rank),
        )
        self.use_qlora = False  # Flag to toggle QLORA

        self.fc_out = nn.Linear(embed_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, trg_mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length) #.to(x.device)
        x = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        # Transpose for BatchNorm, apply batch normalization, and then transpose back
        x = x.transpose(1, 2)
        x = self.bn1(x)
        x = x.transpose(1, 2)

        for layer in self.layers:
            x = layer(x, x, x, trg_mask)
            if self.use_qlora:
                x = self.qlora_feed_forward(x)

        # Transpose for BatchNorm, apply batch normalization, and then transpose back
        x = x.transpose(1, 2)
        x = self.bn2(x)
        x = x.transpose(1, 2)

        out = self.fc_out(x)
        #print(f"shape of output of forward method of LanguageModelDecoder: {out.shape} ")

        return out

    def toggle_qlora(self, use_qlora: bool):
        self.use_qlora = use_qlora

class LanguageModelTransformer(nn.Module):
    def __init__(self, vocab_size, embed_size=256, num_layers=6, forward_expansion=4, heads=8, dropout=0, max_length=100, rank=16):
        super(LanguageModelTransformer, self).__init__()

        self.decoder = LanguageModelDecoder(
            vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            max_length,
            rank,
        )

    def forward(self, trg):
        trg_mask = self.make_trg_mask(trg)
        out = self.decoder(trg, trg_mask)
        return out

    def make_trg_mask(self, trg):
        N, trg_len = trg.shape
        trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(
            N, 1, trg_len, trg_len
        ).to(trg.device)

        return trg_mask
    
    # Function to enable or disable QLORA layers (for fine-tuning purposes)
    def toggle_qlora(self, use_qlora: bool):
        self.decoder.toggle_qlora(use_qlora)

    def generate_response(self, input_ids, attention_mask):
        # Assuming you have a forward method that returns logits
        logits = self.forward(input_ids=input_ids, attention_mask=attention_mask)
        
        # Convert logits to probabilities
        probabilities = F.softmax(logits, dim=-1)
        
        # For simplicity, using greedy decoding here. You might want to use beam search or sampling.
        predicted_token_id = torch.argmax(probabilities, dim=-1)
        
        # Convert predicted token ids to tokens
        predicted_tokens = [tokenizer.convert_ids_to_tokens(idx.item()) for idx in predicted_token_id]
        
        # Join tokens to form the response string. This is a very basic way to generate text and might not produce the best results.
        response = tokenizer.convert_tokens_to_string(predicted_tokens)
        
        return response

###########################################################################
# 7. Internal Switch Routing

# Auxiliary loss function
def auxiliary_loss(gate_scores, expert_capacity):
    expert_load = gate_scores.sum(0) / gate_scores.size(0)
    loss_balancing = torch.std(expert_load)
    return loss_balancing

# Routing function
CAPACITY_FACTOR = 1

def route_inputs(expert_indices, gate_scores, num_experts):
    capacity_factor_tensor = torch.tensor([CAPACITY_FACTOR], dtype=torch.float32)
    capacities = (gate_scores.size(0) * capacity_factor_tensor / num_experts).int()
    expert_counts = torch.zeros(num_experts, dtype=torch.int32)
    for idx in range(len(expert_indices)):
        selected_expert = expert_indices[idx]
        if expert_counts[selected_expert] < capacities[0]:
            expert_counts[selected_expert] += 1
        else:
            available_experts = (expert_counts < capacities[0]).nonzero(as_tuple=False).view(-1)
            if len(available_experts) > 0:
                alternative_expert = available_experts[0]
                expert_indices[idx] = alternative_expert
                expert_counts[alternative_expert] += 1
            else:
                print("No available experts to reroute. Handling overflow.")
    return expert_indices

# SwitchGate 
class SwitchGate(nn.Module):
    def __init__(self, input_dim, num_experts):
        super(SwitchGate, self).__init__()
        self.fc1 = nn.Linear(input_dim, input_dim // 2)
        self.fc2 = nn.Linear(input_dim // 2, num_experts)

    def forward(self, x):
        x = F.relu(self.fc1(x.float()))
        gate_scores = F.softmax(self.fc2(x), dim=-1)
        return gate_scores

# SwitchRouter 
class SwitchRouter(nn.Module):
    def __init__(self, input_dim, num_experts, mamba_model_path, context_encoder_path, language_model_path, question_encoder_path, dpo_model_path, vocab_size, embed_size, num_layers, forward_expansion, heads, dropout, max_length, rank):
        super(SwitchRouter, self).__init__()
        self.router = SwitchGate(input_dim, num_experts)
        self.transformer_rag = TransformerRAG(context_encoder_path, 
                                              language_model_path, 
                                              question_encoder_path, 
                                              vocab_size).to(device)
        self.transformer_dpo = LanguageModelTransformer(vocab_size, 
                                                        embed_size, 
                                                        num_layers, 
                                                        forward_expansion, 
                                                        heads, dropout, 
                                                        max_length, 
                                                        rank).to(device)
        self.transformer_dpo = load_model_weights(self.transformer_dpo, dpo_model_path)
        self.mamba = SimplifiedLanguageModelMAMBA(vocab_size=VOCAB_SIZE, 
                                     num_layers=NUM_LAYERS, 
                                     d_model=D_MODEL, 
                                     d_state=D_STATE, 
                                     d_conv=D_CONV, 
                                     expansion_factor=EXPANSION_FACTOR).to(device)
        self.mamba = load_model_weights(self.mamba, mamba_model_path)
        self.experts = nn.ModuleList([self.transformer_rag, self.transformer_dpo, self.mamba])
        self.input_embedding = nn.Linear(512, input_dim)

    def forward(self, x, attention_mask, context_texts, question_text):
        x = self.input_embedding(x.float())
        gate_scores = self.router(x)
        expert_indices = torch.argmax(gate_scores, dim=1)
        expert_indices = route_inputs(expert_indices, gate_scores, len(self.experts))
        final_output = torch.zeros_like(x)
        aux_loss = 0

        for i, expert in enumerate(self.experts):
            mask = expert_indices == i
            if mask.any():
                selected_inputs = x[mask]
                selected_attention_mask = attention_mask[mask]

                if isinstance(expert, TransformerRAG):
                    # Now passing the required arguments to TransformerRAG
                    expert_output = self.transformer_rag(context_texts, selected_inputs, selected_attention_mask, question_text)
                else:
                    # Process as usual for other experts
                    expert_output = expert(selected_inputs, selected_attention_mask)

                final_output[mask] = expert_output

        # Compute auxiliary loss for load balancing
        aux_loss += auxiliary_loss(gate_scores, expert_capacity=torch.tensor([CAPACITY_FACTOR] * len(self.experts)))

        return final_output, aux_loss

###########################################################################
# 8.Training loop
input_dim = 512
num_experts = 3


model = SwitchRouter(input_dim, num_experts, mamba_model_path, context_encoder, language_model, question_encoder, tran_dpo, 30522, 256, 6, 4, 8, 0.1, 100, 16) #.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
model.to(device)
main_loss_function = torch.nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001)
aux_loss_weight = 0.1
vocab_size = tokenizer.vocab_size
model_embedding_size = model.transformer_dpo.decoder.word_embedding.num_embeddings
assert tokenizer.vocab_size == VOCAB_SIZE, f"Tokenizer vocab size ({tokenizer.vocab_size}) doesn't match expected vocab size ({VOCAB_SIZE})"

# Ensure train_data is accessible here, with 'queries' and 'contexts' keys
train_data = {
    "queries": [
        # Queries for DPO.pdf
        "What is Direct Preference Optimization (DPO)?",
        "How does Direct Preference Optimization work?",
        "How can I implement Direct Preference Optimization in my organization?",
        "Why does Direct Preference Optimization improve the efficiency of language modelling?",
        # Queries for MAMBA.pdf
        "What is MAMBA?",
        "How does MAMBA function?",
        "How can I build a system based on MAMBA technology?",
        "Why does MAMBA enhance the performance of its application area?",
        # Queries for QLORA.pdf
        "What is QLORA?",
        "How does QLORA operate?",
        "How can I develop a project using QLORA?",
        "Why does QLORA improve the capabilities of its relevant field?",
        # Queries for RAG.pdf
        "What is Retrieval Augmented Generation (RAG)?",
        "How does Retrieval Augmented Generation work?",
        "How can I build a Retrieval Augmented Generation model?",
        "Why does Retrieval Augmented Generation enhance language model performance?",
        # Queries for SWITCH_TRANSFORMER.pdf
        "What is the Switch Transformer model?",
        "How does the Switch Transformer model operate?",
        "How can I construct a Switch Transformer model?",
        "Why does the Switch Transformer model improve language processing tasks?"
    ],
    "contexts": [
        # Contexts from DPO.pdf
        rag_dataset[0],  # Assuming dataset[0] is the processed content of DPO.pdf
        rag_dataset[0],
        rag_dataset[0],
        rag_dataset[0],
        # Contexts from MAMBA.pdf
        rag_dataset[1],  # Assuming dataset[1] is the processed content of MAMBA.pdf
        rag_dataset[1],
        rag_dataset[1],
        rag_dataset[1],
        # Contexts from QLORA.pdf
        rag_dataset[2],  # Assuming dataset[2] is the processed content of QLORA.pdf
        rag_dataset[2],
        rag_dataset[2],
        rag_dataset[2],
        # Contexts from RAG.pdf
        rag_dataset[3],  # Assuming dataset[3] is the processed content of RAG.pdf
        rag_dataset[3],
        rag_dataset[3],
        rag_dataset[3],
        # Contexts from SWITCH_TRANSFORMER.pdf
        rag_dataset[4],  # Assuming dataset[4] is the processed content of SWITCH_TRANSFORMER.pdf
        rag_dataset[4],
        rag_dataset[4],
        rag_dataset[4]
    ]
}
# Before training, ensure the embedding size of each component matches the tokenizer vocab size
assert model.transformer_dpo.decoder.word_embedding.num_embeddings == tokenizer.vocab_size, \
    "Transformer DPO's embedding size does not match tokenizer vocab size"

assert model.transformer_rag.context_encoder.bert_model.config.vocab_size == tokenizer.vocab_size, \
    "RAG context encoder vocab size does not match tokenizer vocab size"

assert model.transformer_rag.question_encoder.bert.config.vocab_size == tokenizer.vocab_size, \
    "RAG question encoder vocab size does not match tokenizer vocab size"


def count_total_parameters(models):
    total_params = 0
    for model in models:
        model_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total_params += model_params
        print(f"Parameters in {model.__class__.__name__}: {model_params}")
    return total_params

# Usage
models = [model.transformer_rag, model.transformer_dpo, model.mamba, model]
total_params = count_total_parameters(models)
print(f"Total trainable parameters across all models: {total_params}")

# Start training
accumulation_steps = 4
model.train()
num_epochs = 5
for epoch in range(num_epochs):
    total_loss = 0
    optimizer.zero_grad()  # Initialize gradients to zero
    for batch_idx, batch in enumerate(train_loader):
        inputs, attention_mask, targets = batch['input_ids'].to(device), batch['attention_mask'].to(device), batch['labels'].to(device)

        # Calculate start and end indices for current batch in train_data
        start_idx = batch_idx * batch['input_ids'].size(0)
        end_idx = start_idx + batch['input_ids'].size(0)

        # Extract current_queries and current_contexts for the batch
        current_queries = train_data['queries'][start_idx:end_idx]
        current_contexts = train_data['contexts'][start_idx:end_idx]

        # Call to the model forward function
        outputs, aux_loss = model(inputs, attention_mask, current_contexts, current_queries)

        # Calculate loss and accumulate
        main_loss = main_loss_function(outputs.view(-1, outputs.size(-1)), targets.view(-1))
        loss = (main_loss + aux_loss_weight * aux_loss) / accumulation_steps
        loss.backward()

        if (batch_idx + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        total_loss += loss.item() * accumulation_steps  # Scale back up

    average_loss = total_loss / len(train_loader)
    print(f'End of Epoch {epoch+1}, Average Loss: {average_loss}')


True vocab size: 30522


OutOfMemoryError: CUDA out of memory. Tried to allocate 90.00 MiB. GPU 0 has a total capacty of 4.00 GiB of which 0 bytes is free. Of the allocated memory 10.78 GiB is allocated by PyTorch, and 13.29 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF