# SWITCH TRANSFORMER / MoE :

1. **Expert Level:** Each expert consists of three sub-models - a Transformer with DPO, RAG, and MAMBA. This level focuses on text-based input and is designed to handle specific domains or types of queries with high proficiency.



2. **Switch Transformer MoE (Mixture of Experts) Level:** This layer integrates multiple Experts, each fine-tuned for different domains or tasks. The Switch Transformer directs queries to the most relevant Expert based on the context, leveraging the specialized skills of each component model.



3. **King Model Level:** Each King Model represents a multimodal system, incorporating separate MoE structures for different types of data, such as text, images, videos, and audio. This approach allows the King Model to process and understand a wide range of inputs, potentially enabling richer and more nuanced responses.



4. **God Model Level:** At this level, multiple King Models, each trained in different domains (e.g., science, literature, art), are integrated. The God Model can draw on a vast pool of domain-specific knowledge and multimodal understanding, making it highly versatile and capable of handling complex, multi-faceted queries.



5. **Multiverse Model Level:** This ultimate layer aggregates multiple God Models, each representing different spheres of knowledge or different approaches to intelligence. Such a system could theoretically possess an extraordinarily broad and deep understanding of the world, resembling AGI.

# LANGUAGE TRANSFORMER AND DPO

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import random_split
from transformers import BertTokenizer
from datasets import load_dataset


class LORALayer(nn.Module):
    def __init__(self, input_dim, output_dim, rank, alpha=1):
        super(LORALayer, self).__init__()
        self.rank = rank
        self.alpha = alpha

        # Original weight and bias of the linear layer
        self.weight = nn.Parameter(torch.Tensor(output_dim, input_dim))
        #print("self.weight Shape:", self.weight.shape)
        self.bias = nn.Parameter(torch.Tensor(output_dim))

        # LORA specific parameters
        self.A = nn.Parameter(torch.Tensor(input_dim, rank))
        self.B = nn.Parameter(torch.Tensor(rank, output_dim))

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.zeros_(self.bias)
        nn.init.normal_(self.A, 0, 0.02)
        nn.init.normal_(self.B, 0, 0.02)

    def forward(self, x):
        #print("LORALayer Input Shape:", x.shape)
        
        original_size = x.size()
        batch_size, seq_len, _ = x.shape
        x_flattened = x.reshape(-1, original_size[-1])

        # Compute lora_adjustment for each input in the batch
        lora_adjustment = self.alpha * (x_flattened @ self.A) @ self.B
        lora_adjustment = lora_adjustment.reshape(batch_size, seq_len, -1)

        #print("Adjusted lora_adjustment Shape:", lora_adjustment.shape)
        #print("self.weight Shape:", self.weight.shape)
        
        # Apply linear transformation to x_flattened
        x_transformed = nn.functional.linear(x_flattened, self.weight, self.bias)
        x_transformed = x_transformed.reshape(batch_size, seq_len, -1)

        # Add lora_adjustment to the transformed x
        x = x_transformed + lora_adjustment
        #print("LORALayer Output Shape:", x.shape)

        return x

class QLORALayer(nn.Module):
    def __init__(self, input_dim, output_dim, rank, alpha=1, quantization_bits=8):
        super(QLORALayer, self).__init__()
        self.rank = rank
        self.alpha = alpha
        self.quantization_bits = quantization_bits

        # Original weight and bias
        self.weight = nn.Parameter(torch.Tensor(output_dim, input_dim))
        self.bias = nn.Parameter(torch.Tensor(output_dim))

        # QLORA specific parameters
        self.A = nn.Parameter(torch.Tensor(input_dim, rank))
        self.B = nn.Parameter(torch.Tensor(rank, output_dim))

        self.dropout = nn.Dropout(0.1)
        self.layer_norm = nn.LayerNorm(output_dim)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.zeros_(self.bias)
        nn.init.normal_(self.A, 0, 0.02)
        nn.init.normal_(self.B, 0, 0.02)

    def quantize(self, x, num_bits):
        # Implement a simple quantization method
        scale = x.abs().max()
        x_quantized = torch.round(x / scale * (2**num_bits - 1))
        return x_quantized, scale

    def forward(self, x):
        #print("QLORALayer Input Shape:", x.shape)
        original_size = x.size()
        batch_size, seq_len, _ = x.shape
        x_flattened = x.reshape(-1, original_size[-1])

        A_quantized, scale_A = self.quantize(self.A, self.quantization_bits)
        B_quantized, scale_B = self.quantize(self.B, self.quantization_bits)

        # Compute lora_adjustment for each input in the batch
        lora_adjustment = self.alpha * (x_flattened @ (A_quantized / scale_A)) @ (B_quantized / scale_B)
        lora_adjustment = lora_adjustment.reshape(batch_size, seq_len, -1)
        lora_adjustment = self.dropout(lora_adjustment)
        #print("Adjusted lora_adjustment Shape:", lora_adjustment.shape)
        #print("self.weight Shape:", self.weight.shape)

        # Apply linear transformation to x_flattened
        x_transformed = nn.functional.linear(x_flattened, self.weight, self.bias)
        x_transformed = x_transformed.reshape(batch_size, seq_len, -1)

        # Add lora_adjustment to the transformed x
        x = x_transformed + lora_adjustment
        x = self.layer_norm(x)

        #print("QLORALayer Output Shape:", x.shape)

        return x
    
    def update_alpha(self, new_alpha):
        """
        Update the alpha scaling factor.
        """
        self.alpha = new_alpha

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(MultiHeadAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding into self.heads different pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        # Einsum does the matrix multiplication for query*keys for each training example
        # with every other training example, then sum it up
        attention = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        if mask is not None:
            attention = attention.masked_fill(mask == 0, float("-1e20"))

        attention = torch.softmax(attention / (self.embed_size ** (1 / 2)), dim=3)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )

        out = self.fc_out(out)
        return out

class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion, rank):
        super(TransformerBlock, self).__init__()
        self.attention = MultiHeadAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            LORALayer(embed_size, forward_expansion * embed_size, rank),
            nn.ReLU(),
            LORALayer(forward_expansion * embed_size, embed_size, rank),
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)
        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out

class LanguageModelDecoder(nn.Module):
    def __init__(self, vocab_size, embed_size, num_layers, heads, forward_expansion, dropout, max_length, rank):
        super(LanguageModelDecoder, self).__init__()
        self.word_embedding = nn.Embedding(vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        # Adding BatchNorm layers
        self.bn1 = nn.BatchNorm1d(embed_size)
        self.bn2 = nn.BatchNorm1d(embed_size)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(embed_size, heads, dropout, forward_expansion, rank)
                for _ in range(num_layers)
            ]
        )

        # QLORA layers
        self.qlora_feed_forward = nn.Sequential(
            QLORALayer(embed_size, forward_expansion * embed_size, rank),
            nn.ReLU(),
            QLORALayer(forward_expansion * embed_size, embed_size, rank),
        )
        self.use_qlora = False  # Flag to toggle QLORA

        self.fc_out = nn.Linear(embed_size, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, trg_mask):
        N, seq_length = x.shape
        positions = torch.arange(0, seq_length).expand(N, seq_length).to(x.device)
        x = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        # Transpose for BatchNorm, apply batch normalization, and then transpose back
        x = x.transpose(1, 2)
        x = self.bn1(x)
        x = x.transpose(1, 2)

        for layer in self.layers:
            x = layer(x, x, x, trg_mask)
            if self.use_qlora:
                x = self.qlora_feed_forward(x)

        # Transpose for BatchNorm, apply batch normalization, and then transpose back
        x = x.transpose(1, 2)
        x = self.bn2(x)
        x = x.transpose(1, 2)

        out = self.fc_out(x)
        return out

    def toggle_qlora(self, use_qlora: bool):
        self.use_qlora = use_qlora

class LanguageModelTransformer(nn.Module):
    def __init__(self, vocab_size, embed_size=256, num_layers=6, forward_expansion=4, heads=8, dropout=0, max_length=100, rank=16):
        super(LanguageModelTransformer, self).__init__()

        self.decoder = LanguageModelDecoder(
            vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            max_length,
            rank,
        )

    def forward(self, trg):
        trg_mask = self.make_trg_mask(trg)
        out = self.decoder(trg, trg_mask)
        return out

    def make_trg_mask(self, trg):
        N, trg_len = trg.shape
        trg_mask = torch.tril(torch.ones((trg_len, trg_len))).expand(
            N, 1, trg_len, trg_len
        ).to(trg.device)

        return trg_mask

# Define vocabulary size and dummy data parameters
NUM_WORDS = 1000  # Example vocabulary size
sequence_length = 30  # Sequence length for the LanguageDataset
dummy_data_size = 1000  # Total number of tokens in the dummy dataset



# Load dataset
dataset = load_dataset('wikipedia', '20220301.simple')

# Initialize tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
vocab_size = tokenizer.vocab_size

def tokenize_function(examples):
    # Tokenize the text
    tokenized_output = tokenizer(examples['text'], padding='max_length', truncation=True, max_length=sequence_length)
    
    # Shift input_ids to create labels and truncate the last token
    labels = [seq[1:] + [tokenizer.pad_token_id] for seq in tokenized_output['input_ids']]
    tokenized_output['labels'] = labels
    
    return tokenized_output

tokenized_datasets = dataset.map(tokenize_function, batched=True)

# Set format for PyTorch
tokenized_datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

train_loader = DataLoader(tokenized_datasets['train'], batch_size=64, shuffle=True)

# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create the model instance
model = LanguageModelTransformer(
    vocab_size=vocab_size,  # Use the vocab size from the tokenizer
    embed_size=256,
    num_layers=6,
    forward_expansion=4,
    heads=8,
    dropout=0,
    max_length=100,
    rank=16
).to(device)


def calculate_new_alpha(current_loss, initial_loss, initial_alpha=1.0, final_alpha=0.1):
    """
    Calculate a new alpha value based on the current loss.
    """
    if current_loss >= initial_loss:
        return initial_alpha  # Keep initial alpha if loss isn't decreasing

    loss_ratio = current_loss / initial_loss
    alpha_range = initial_alpha - final_alpha
    new_alpha = final_alpha + (alpha_range * loss_ratio)
    return new_alpha

# Enable QLORA during training
model.decoder.toggle_qlora(True)

initial_loss = None
# Training loop
# Assuming model is an instance of LanguageModelTransformer
# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=1e-8, weight_decay=1e-4)
scheduler = StepLR(optimizer, step_size=4, gamma=0.98)
num_epochs = 5

# Training loop
for epoch in range(num_epochs):
    model.train()
    model.decoder.toggle_qlora(True)
    total_loss = 0

    for batch_idx, batch in enumerate(train_loader):
        inputs = batch['input_ids'].to(device)
        targets = batch['labels'].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))

        # Check for NaN in loss
        if math.isnan(loss.item()):
            print("Encountered NaN loss, stopping training")
            break

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
        optimizer.step()

        total_loss += loss.item()

        # Set the initial_loss after the first batch of the first epoch
        if initial_loss is None and batch_idx == 0:
            initial_loss = loss.item()

    scheduler.step()

    # Check for NaN in total_loss
    if math.isnan(total_loss):
        print(f"Epoch {epoch+1}/{num_epochs} stopped due to NaN loss")
        break
    else:
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss / len(train_loader)}")

    # Average loss for the epoch
    average_loss = total_loss / len(train_loader)

    # Update alpha at the end of each epoch based on the average loss
    new_alpha = calculate_new_alpha(average_loss, initial_loss)
    for layer in model.modules():
        if isinstance(layer, QLORALayer):
            layer.update_alpha(new_alpha)

    #model.decoder.toggle_qlora(False)




In [None]:
import torch
from datasets import load_dataset

def format_stackexchange_dpo(samples):
    return {
        "prompt": ["Question: " + question + "\n\nAnswer: " for question in samples["question"]],
        "chosen": samples["response_j"],   # Rated better than k
        "rejected": samples["response_k"], # Rated worse than j
    }

# Load and format a subset (30%) of the StackExchange DPO dataset
dataset = load_dataset("lvwerra/stack-exchange-paired")
subset_size = int(0.3 * len(dataset['train']))  # 30% of the dataset
subset_indices = torch.randperm(len(dataset['train'])).tolist()[:subset_size]  # Randomly select indices
formatted_dataset = dataset['train'].select(subset_indices).map(format_stackexchange_dpo, batched=True, load_from_cache_file=False)

# Convert formatted dataset to DataLoader for batch processing
dpo_dataloader = DataLoader(formatted_dataset, batch_size=64, shuffle=True)

In [None]:
from torch.nn import MarginRankingLoss

# Define DPO-specific loss function
dpo_loss_function = MarginRankingLoss(margin=1.0)
dpo_num_epochs = 2  # Define the number of epochs for DPO training

# DPO Training loop
for epoch in range(dpo_num_epochs):
    model.train()  # Ensure the model is in training mode
    total_dpo_loss = 0

    for batch in dpo_dataloader:
        optimizer.zero_grad()

        # Prepare the input for the model
        prompts = batch['prompt'].to(device)
        preferred_responses = batch['chosen'].to(device)
        less_preferred_responses = batch['rejected'].to(device)

        # Forward pass and model's scoring mechanism for responses
        # The model should output scores for the preferred and less-preferred responses
        output_preferred = model(preferred_responses)
        output_less_preferred = model(less_preferred_responses)

        # Compute DPO loss
        dpo_loss = dpo_loss_function(output_preferred, output_less_preferred, torch.ones(output_preferred.size(0)).to(device))
        total_dpo_loss += dpo_loss.item()

        # Backward pass and optimization
        dpo_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
        optimizer.step()

    print(f"Epoch {epoch+1}/{dpo_num_epochs}, DPO Loss: {total_dpo_loss / len(dpo_dataloader)}")


# MAMBA

In [None]:
!pip install datasets


In [32]:
#MAMBA
import torch
import torch.nn as nn
from datasets import load_dataset
from transformers import BertTokenizer
from torch.utils.data import DataLoader
from torch.optim import AdamW
import os

os.environ['CUDA_LAUNCH_BLOCKING'] = "1"


# Tokenizer and Dataset - same as your working code
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
dataset = load_dataset('wikipedia', '20220301.simple')

def tokenize_function(examples):
    return tokenizer(examples['text'], padding='max_length', truncation=True, max_length=512)

tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets.set_format(type='torch', columns=['input_ids', 'attention_mask'])

# DataLoader - same as your working code
train_loader = DataLoader(tokenized_datasets['train'], batch_size=8, shuffle=True)


# Define MAMBA Model with appropriate hyperparameters
#vocab_size = len(tokenizer)  # Get the correct vocab size from the tokenizer
vocab_size = tokenizer.vocab_size



class SparseAttention(nn.Module):

    def __init__(self, embed_size, heads, window_size):

        super(SparseAttention, self).__init__()

        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        assert self.head_dim * heads == embed_size, "Embedding size must be divisible by heads"
        self.window_size = window_size

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        print(f"self.value: {self.values}")

        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        print(f"self.keys: {self.keys}")

        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        print(f"self.queries: {self.queries}")


        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)


    def forward(self, values, keys, query, mask):
        N, value_len, key_len, query_len = values.size(0), values.size(1), keys.size(1), query.size(1)
        
        # Reshape values, keys, and queries for multi-head attention
        values_reshaped = values.view(N, value_len, self.heads, self.head_dim)
        print(f"values_reshaped: {values_reshaped.shape}")
        keys_reshaped = keys.view(N, key_len, self.heads, self.head_dim)
        print(f"keys_reshaped: {keys_reshaped.shape}")
        queries_reshaped = query.view(N, query_len, self.heads, self.head_dim)
        print(f"queries_reshaped: {queries_reshaped.shape}")

        # Apply linear layers
        values = self.values(values_reshaped)
        print(f"values: {values.shape}")
        keys = self.keys(keys_reshaped)
        print(f"keys: {keys.shape}")
        queries = self.queries(queries_reshaped)
        print(f"queries: {queries.shape}")

        # Calculate attention scores
        attention = self.calculate_sparse_attention(queries, keys, self.window_size)
        print(f"attention: {attention.shape}")

        if mask is not None:
            attention = attention.masked_fill(mask == 0, float("-1e20"))

        # Apply softmax and calculate the output
        attention = torch.softmax(attention / (self.embed_size ** (1 / 2)), dim=3)
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).view(N, query_len, self.heads * self.head_dim)

        # Before the final linear layer
        print("Output shape before fc_out:", out.shape)
        out = self.fc_out(out)
        print("Output shape after fc_out:", out.shape)

        return out
    
    def calculate_sparse_attention(self, queries, keys, window_size):
        N, query_len, heads, head_dim = queries.size()
        _, seq_len, _, _ = keys.size()

        # Initialize attention scores with negative infinity for sparsity
        attention_scores = torch.full((N, heads, query_len, seq_len), float('-inf'), device=queries.device)
        print(f"attention_scores: {attention_scores.shape}")
        # Iterate over each position in the sequence
        for i in range(seq_len):
            # Determine the local window range
            start = max(0, i - window_size)
            end = min(seq_len, i + window_size + 1)

            # Calculate attention scores within the window
            keys_window = keys[:, start:end, :, :]
            print(f"keys_window: {keys_window.shape}")
            # Adjust queries_window slicing to be aligned with the keys_window
            queries_window = queries[:, i:i+1, :, :].expand(-1, end - start, -1, -1)
            print(f"queries_window: {queries_window.shape}")

            # Compute scores using einsum
            scores = torch.einsum("nhqd,nwhd->nhqw", [queries_window, keys_window])

            # Place the calculated scores in their respective positions
            attention_scores[:, :, i, start:end] = scores.squeeze(2)

        return attention_scores




class SSMLayer(nn.Module):

    def __init__(self, input_dim, state_dim):

        super(SSMLayer, self).__init__()

        self.state_update = nn.Linear(input_dim, state_dim)

        self.state_process = nn.Linear(state_dim, input_dim)



    def forward(self, x, states):

        new_states = self.state_update(x) + states

        processed_states = self.state_process(new_states)

        return processed_states, new_states


class ModifiedTransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, state_dim, dropout, forward_expansion, lstm_dim, window_size):
        
        super(ModifiedTransformerBlock, self).__init__()
        
        self.attention = SparseAttention(embed_size, heads, window_size)

        self.norm1 = nn.LayerNorm(embed_size)

        self.norm2 = nn.LayerNorm(embed_size)

        self.ssm_layer = SSMLayer(embed_size, state_dim)



        # LSTM Layer

        self.lstm = nn.LSTM(input_size=embed_size, hidden_size=lstm_dim, batch_first=True)

        

        self.feed_forward = nn.Sequential(

            nn.Linear(embed_size + lstm_dim, forward_expansion * embed_size),

            nn.ReLU(),

            nn.Linear(forward_expansion * embed_size, embed_size),

        )

        self.dropout = nn.Dropout(dropout)



    def forward(self, value, key, query, mask, states, lstm_hidden=None):

        attention = self.attention(value, key, query, mask)

        x = self.dropout(self.norm1(attention + query))

        processed_states, new_states = self.ssm_layer(x, states)

        

        # Pass the output through LSTM layer

        lstm_output, new_lstm_hidden = self.lstm(x, lstm_hidden)

        x = torch.cat((x, lstm_output), dim=-1)



        forward = self.feed_forward(x)

        out = self.dropout(self.norm2(forward + x))

        return out, new_states, new_lstm_hidden


class MAMBA(nn.Module):

    def __init__(self, vocab_size, embed_size, num_blocks, heads, state_dim, lstm_dim, forward_expansion, max_length, dropout, window_size):

        super(MAMBA, self).__init__()



        self.word_embedding = nn.Embedding(vocab_size, embed_size)

        self.position_embedding = nn.Embedding(max_length, embed_size)



        self.transformer_blocks = nn.ModuleList([

            ModifiedTransformerBlock(

                embed_size=embed_size,

                heads=heads,

                state_dim=state_dim,

                dropout=dropout,

                forward_expansion=forward_expansion,

                lstm_dim=lstm_dim,

                window_size=window_size

            ) for _ in range(num_blocks)
        ])



        self.dropout = nn.Dropout(dropout)

        self.fc_out = nn.Linear(embed_size + lstm_dim, vocab_size)



    def forward(self, x, mask):

        N, seq_length = x.shape

        positions = torch.arange(0, seq_length).expand(N, seq_length).to(x.device)

        x = self.dropout(self.word_embedding(x) + self.position_embedding(positions))



        states = None

        lstm_hidden = None

        for block in self.transformer_blocks:

            x, states, lstm_hidden = block(x, x, x, mask, states, lstm_hidden)



        out = self.fc_out(x)

        return out



# Define hyperparameters
embed_size = 256  # Embedding size
num_blocks = 6  # Number of transformer blocks
heads = 8  # Number of attention heads
state_dim = 64  # Dimension of the state in SSMLayer
lstm_dim = 64  # Dimension of the hidden state in LSTM
forward_expansion = 4  # Expansion factor in feed forward network
max_length = 512  # Maximum length of the sequence
dropout = 0.1  # Dropout rate


# Create MAMBA model instance
mamba_model = MAMBA(
    vocab_size=vocab_size,
    embed_size=embed_size,
    num_blocks=num_blocks,
    heads=heads,
    state_dim=state_dim,
    lstm_dim=lstm_dim,
    forward_expansion=forward_expansion,
    max_length=max_length,
    dropout=dropout,
    window_size=25  
)

# Define Device for Training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mamba_model = mamba_model.to(device)
# Loss Function and Optimizer
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = AdamW(mamba_model.parameters(), lr=5e-5, weight_decay=1e-2)

# Training Loop
num_epochs = 1
for epoch in range(num_epochs):
    mamba_model.train()
    total_loss = 0

    for batch in train_loader:
        input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
    
        # Forward pass
        outputs = mamba_model(input_ids, attention_mask)
        # Shifted input IDs for labels
        labels = torch.cat([input_ids[:, 1:], input_ids[:, :1]], dim=1)

        # Reshape for loss calculation
        outputs = outputs.view(-1, outputs.size(-1))
        labels = labels.view(-1)

        # Loss calculation
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}: Loss = {total_loss / len(train_loader)}")




self.value: Linear(in_features=32, out_features=32, bias=False)
self.keys: Linear(in_features=32, out_features=32, bias=False)
self.queries: Linear(in_features=32, out_features=32, bias=False)
self.value: Linear(in_features=32, out_features=32, bias=False)
self.keys: Linear(in_features=32, out_features=32, bias=False)
self.queries: Linear(in_features=32, out_features=32, bias=False)
self.value: Linear(in_features=32, out_features=32, bias=False)
self.keys: Linear(in_features=32, out_features=32, bias=False)
self.queries: Linear(in_features=32, out_features=32, bias=False)
self.value: Linear(in_features=32, out_features=32, bias=False)
self.keys: Linear(in_features=32, out_features=32, bias=False)
self.queries: Linear(in_features=32, out_features=32, bias=False)
self.value: Linear(in_features=32, out_features=32, bias=False)
self.keys: Linear(in_features=32, out_features=32, bias=False)
self.queries: Linear(in_features=32, out_features=32, bias=False)
self.value: Linear(in_features=32, 

RuntimeError: einsum(): subscript h has size 8 for operand 1 which does not broadcast with previously seen size 26

# RAG

In [None]:
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration, BertTokenizer




# Initialize the tokenizer for your custom model

custom_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')  # or another tokenizer if not BERT



# Initialize your custom model (already fine-tuned with DPO)

custom_model = LanguageModelTransformer(...)

custom_model.load_state_dict(torch.load('path_to_fine_tuned_model.pt'))  # Load your fine-tuned model



# Set up RAG components

rag_tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-base")

rag_retriever = RagRetriever.from_pretrained("facebook/rag-sequence-base", index_name="custom", passages_path="path_to_your_passages_file")

rag_model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-base", retriever=rag_retriever)



def unified_inference_pipeline(input_text):

    # Generate initial response using your custom model

    inputs = custom_tokenizer(input_text, return_tensors="pt")

    output = custom_model(**inputs)

    response = custom_tokenizer.decode(output, skip_special_tokens=True)



    # Augment response using RAG

    rag_inputs = rag_tokenizer(response, return_tensors="pt")

    generated_ids = rag_model.generate(rag_inputs["input_ids"])

    augmented_response = rag_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)



    return augmented_response



# Example usage

input_text = "Your input query here"

augmented_response = unified_inference_pipeline(input_text)

print(augmented_response)

# SWITCH ROUTER

In [None]:
#Switch Router:


import torch

import torch.nn as nn

import torch.nn.functional as F



class RoutingLayer(nn.Module):

    def __init__(self, response_dim, num_models):

        super().__init__()

        self.fc = nn.Linear(response_dim * num_models, num_models)

    

    def forward(self, responses):

        # Concatenate responses and pass through the fully connected layer

        combined_responses = torch.cat(responses, dim=1)

        weights = self.fc(combined_responses)

        return F.softmax(weights, dim=1)



class IntegratedModel(nn.Module):

    def __init__(self, model1, model2, model3, response_dim):

        super().__init__()

        self.model1 = model1  # Transformer with DPO

        self.model2 = model2  # RAG

        self.model3 = model3  # MAMBA

        self.routing_layer = RoutingLayer(response_dim, 3)

    

    def forward(self, input_prompt):

        # Assuming each model returns a response vector of the same dimension

        response1 = self.model1(input_prompt)

        response2 = self.model2(input_prompt)

        response3 = self.model3(input_prompt)



        # Routing

        weights = self.routing_layer([response1, response2, response3])



        # Weighted sum of responses

        final_response = weights[:, 0, None] * response1 + weights[:, 1, None] * response2 + weights[:, 2, None] * response3

        return final_response





import torch

import torch.nn as nn

import torch.optim as optim



# Assuming IntegratedModel is already defined and instantiated as integrated_model

# Assuming a DataLoader named 'data_loader' providing input prompts and corresponding target tokens



# Define AdamW optimizer for the integrated model

optimizer = optim.AdamW(integrated_model.parameters(), lr=5e-5)



# Define loss function - Cross-Entropy Loss for language modeling

criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)  # tokenizer.pad_token_id for the padding token ID



num_epochs = 5  # Number of training epochs



for epoch in range(num_epochs):

    integrated_model.train()

    total_loss = 0



    for batch in data_loader:

        input_prompts, target_tokens = batch['input_prompts'], batch['target_tokens']

        

        # Forward pass through integrated model

        optimizer.zero_grad()

        outputs = integrated_model(input_prompts)

        

        # Reshape for calculating loss (Cross-Entropy expects 2D input)

        outputs = outputs.view(-1, outputs.size(-1))

        target_tokens = target_tokens.view(-1)



        # Compute loss

        loss = criterion(outputs, target_tokens)

        total_loss += loss.item()



        # Backward pass and optimize

        loss.backward()

        optimizer.step()



    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {total_loss / len(data_loader)}")



# Save model state if needed

torch.save(integrated_model.state_dict(), 'path_to_save_model.pt')

# SWITCH TRANSFORMER

In [None]:

import torch

import torch.nn as nn

import torch.nn.functional as F





class MoESwitchTransformer(nn.Module):

    def __init__(self, num_experts, response_dim, expert_params):

        super(MoESwitchTransformer, self).__init__()

        self.experts = nn.ModuleList([IntegratedModel(**expert_params) for _ in range(num_experts)])

        self.routing_layer = RoutingLayer(response_dim, num_experts)



    def forward(self, input_prompt):

        expert_responses = [expert(input_prompt) for expert in self.experts]

        # Combine responses for routing

        combined_responses = torch.stack(expert_responses, dim=1)

        # Routing: Shape of weights will be [batch_size, num_experts]

        weights = self.routing_layer(combined_responses)

        # Weighted sum of expert responses

        final_response = torch.sum(weights.unsqueeze(2) * combined_responses, dim=1)

        return final_response



# Define the MoE model

num_experts = 5  # Number of experts

response_dim = 512  # Dimension of response (assumed for illustration)

expert_params = {'model1': model_dpo, 'model2': model_rag, 'model3': model_mamba, 'response_dim': response_dim}



moe_switch_transformer = MoESwitchTransformer(num_experts, response_dim, expert_params)



# Define an input prompt (for example purposes)

input_prompt = torch.rand(1, 512)  # Replace with actual input



# Forward pass

output = moe_switch_transformer(input_prompt)

print(output.shape)  # Output shape will depend on the response_dim



# Training the MoE Switch Transformer

# Define optimizer and loss function as before

optimizer = torch.optim.AdamW(moe_switch_transformer.parameters(), lr=5e-5)

criterion = nn.CrossEntropyLoss()  # Define according to your specific task



# Assume 'data_loader' is provided

num_epochs = 3

for epoch in range(num_epochs):

    moe_switch_transformer.train()

    total_loss = 0

    for batch in data_loader:

        optimizer.zero_grad()

        inputs, targets = batch  # Assuming batch contains inputs and targets

        outputs = moe_switch_transformer(inputs)

        loss = criterion(outputs, targets)

        loss.backward()

        optimizer.step()

        total_loss += loss.item()

    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(data_loader)}')