In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM
from datasets import load_dataset
from torch.utils.data import DataLoader
import os

# Set CUDA_LAUNCH_BLOCKING environment variable
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'



In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class ExpertModel(nn.Module):
    def __init__(self, model_name_or_path, tokenizer_name_or_path=None):
        super().__init__()
        self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
        
        # Use a specific tokenizer if provided, else default to the model's tokenizer
        if tokenizer_name_or_path:
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        
        # Some models may not have a pad token, set it if it's the case
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
    
    def tokenize_function(self, text):
        # Tokenize the text input for the specific expert model
        return self.tokenizer(text, max_length=1024, truncation=True, padding="max_length")
    
    def forward(self, input_ids, attention_mask=None):
        # Assuming input_ids are already tokenized and in the expected format for the model
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.logits  # or return the hidden states or any other representation


class TransformerWithMoE(nn.Module):
    def __init__(self, experts, input_dim, num_experts, top_k, hidden_size, capacity_factor=1.0, alpha=1e-2):
        super().__init__()
        self.experts = nn.ModuleList(experts)
        self.router = nn.Parameter(torch.randn(hidden_size, num_experts))
        self.top_k = top_k
        self.self_attn = nn.MultiheadAttention(embed_dim=hidden_size, num_heads=4)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.capacity_factor = capacity_factor
        self.alpha = alpha
        self.logits_to_hidden = nn.Linear(50257, hidden_size)  # New layer to transform logits to hidden size

    
    def forward(self, input_ids):
        # Convert input token IDs to embeddings using the first expert model
        # Assuming all expert models have the same embedding layer
        embeddings = self.experts[0].model.transformer.wte(input_ids)
        print(f"embeddings {embeddings.shape}")

        # Apply self-attention
        attn_output, _ = self.self_attn(embeddings, embeddings, embeddings)
        print(f"attn_output {attn_output.shape}")

        x = self.norm1(embeddings + attn_output)
        print(f"x post norm1 {x.shape}")
      
        # Apply self-attention
        attn_output, _ = self.self_attn(x, x, x)
        print(f"Second attn_output {attn_output.shape}")
        
        # Add & Normalize (first residual connection)
        x = self.norm1(x + attn_output)
        print(f"x post second norm1 {x.shape}")

        # Compute logits h(x) for the router
        print(f"Router shape: {self.router.shape}")
        
        # Compute logits h(x) for the router
        logits = x @ self.router
        print(f"logits shape: {logits.shape}")

        # Apply softmax to get gate values p_i(x)
        gate_values = F.softmax(logits, dim=-1)
        print(f"gate_values shape: {gate_values.shape}")
        
        # Get top-k gate values and indices
        topk_gate_values, topk_indices = torch.topk(gate_values, self.top_k, dim=-1)
        print(f"topk_gate_values {topk_gate_values.shape}")
        print(f"topk_indices {topk_indices.shape}")
        
        # Initialize an empty tensor for the output
        output = torch.zeros_like(embeddings)

        # Loop over the top-k experts for each item in the batch
        for i in range(self.top_k):
            expert_outputs = []
            gate_values = topk_gate_values[:, :, i]

            for expert_idx in range(len(self.experts)):
                mask = topk_indices[:, :, i] == expert_idx
                input_masked = input_ids[mask]
                gate_values_masked = gate_values[mask]

                if input_masked.size(0) > 0:
                    expert = self.experts[expert_idx]
                    expert_output = expert(input_masked)
                    expert_output = self.logits_to_hidden(expert_output)  # Transform logits to hidden size

                    # Reshape to match the original batch size and sequence length
                    reshaped_output = torch.zeros_like(embeddings)
                    reshaped_output[mask] = expert_output

                    # Prepare expanded_gate_values tensor
                    expanded_gate_values = torch.zeros_like(embeddings)
                    expanded_gate_values[mask] = gate_values_masked.view(-1, 1).expand(-1, embeddings.size(-1))
                    expert_outputs.append(reshaped_output * expanded_gate_values)

            # Combine the expert outputs
            if expert_outputs:
                expert_contributions = sum(expert_outputs)
                output += expert_contributions


        # Normalize the final output
        output = self.norm2(output)
        
        # Calculate load balancing loss
        f_vector = torch.zeros(self.top_k, device=output.device)
        p_vector = torch.zeros(self.top_k, device=output.device)
        for i in range(self.top_k):
            f_vector[i] = torch.sum(topk_indices == i) / input_ids.size(0)
            p_vector[i] = torch.sum(topk_gate_values[:, i])
        
        loss = self.alpha * self.top_k * torch.sum(f_vector * p_vector)
        
        return output, loss




# Instantiate your models
expert_gpt2 = ExpertModel("./gpt2-codesearchnet-dpo-py-js", "gpt2-medium").to(device)
expert_llama = ExpertModel("gpt2", "gpt2-medium").to(device)
experts = [expert_gpt2, expert_llama]

# Correct the input dimension
input_dim = 32  # This should be the hidden size
batch_size = 16
seq_length = 32
hidden_size = 768

model_with_moe = TransformerWithMoE(
    experts=experts,
    input_dim=hidden_size,
    num_experts=len(experts),
    top_k=2,
    hidden_size=hidden_size,
    capacity_factor=1.0,
    alpha=1e-2
).to(device)

# Load the dataset
dataset = load_dataset("teknium/GPT4-LLM-Cleaned")

def tokenize_function(examples, tokenizer):
    concatenated_texts = [instr + " [SEP] " + inp for instr, inp in zip(examples['instruction'], examples['input'])]
    targets = examples['output']
    
    model_inputs = tokenizer(concatenated_texts, padding='max_length', truncation=True, max_length=512)
    
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, padding='max_length', truncation=True, max_length=128)
    
    model_inputs['labels'] = labels['input_ids']
    return model_inputs

# Instantiate your ExpertModel
expert_gpt2 = ExpertModel("./gpt2-codesearchnet-dpo-py-js", "gpt2-medium").to(device)

# Apply the tokenize function to the dataset using the tokenizer from expert_gpt2
tokenized_datasets = dataset.map(lambda examples: tokenize_function(examples, tokenizer=expert_gpt2.tokenizer), batched=True)

# Format the dataset to output only the necessary columns for training
tokenized_datasets.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])

# Optionally, select a subset for training
train_dataset = tokenized_datasets["train"].select(range(0, 2000)) 

# Create dataloader from train dataset
data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Training or evaluation loop
for batch in data_loader:
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    
    # Forward pass through your model
    outputs, loss = model_with_moe(input_ids=input_ids, attention_mask=attention_mask)



# v5_p2

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class ExpertModel(nn.Module):
    def __init__(self, model_name_or_path, tokenizer_name_or_path=None):
        super().__init__()
        self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
        
        # Use a specific tokenizer if provided, else default to the model's tokenizer
        if tokenizer_name_or_path:
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        
        # Some models may not have a pad token, set it if it's the case
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
    
    def tokenize_function(self, text):
        # Tokenize the text input for the specific expert model
        return self.tokenizer(text, max_length=1024, truncation=True, padding="max_length")
    
    def forward(self, input_ids, attention_mask=None):
        # Assuming input_ids are already tokenized and in the expected format for the model
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.logits  # or return the hidden states or any other representation


class TransformerWithMoE(nn.Module):
    def __init__(self, experts, input_dim, num_experts, top_k, hidden_size, capacity_factor=1.0, alpha=1e-2):
        super().__init__()
        self.experts = nn.ModuleList(experts)
        self.router = nn.Parameter(torch.randn(hidden_size, num_experts))
        self.top_k = top_k
        self.self_attn = nn.MultiheadAttention(embed_dim=hidden_size, num_heads=4)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.capacity_factor = capacity_factor
        self.alpha = alpha
        self.logits_to_hidden = nn.Linear(50257, hidden_size)  # New layer to transform logits to hidden size

    
    def forward(self, input_ids, attention_mask=None):
        # Convert input token IDs to embeddings using the first expert model
        # Assuming all expert models have the same embedding layer
        embeddings = self.experts[0].model.transformer.wte(input_ids)
        print(f"embeddings {embeddings.shape}")

        # Apply self-attention
        attn_output, _ = self.self_attn(embeddings, embeddings, embeddings)
        print(f"attn_output {attn_output.shape}")

        x = self.norm1(embeddings + attn_output)
        print(f"x post norm1 {x.shape}")
      
        # Apply self-attention
        attn_output, _ = self.self_attn(x, x, x)
        print(f"Second attn_output {attn_output.shape}")
        
        # Add & Normalize (first residual connection)
        x = self.norm1(x + attn_output)
        print(f"x post second norm1 {x.shape}")

        # Compute logits h(x) for the router
        print(f"Router shape: {self.router.shape}")
        
        # Compute logits h(x) for the router
        logits = x @ self.router
        print(f"logits shape: {logits.shape}")

        # Apply softmax to get gate values p_i(x)
        gate_values = F.softmax(logits, dim=-1)
        print(f"gate_values shape: {gate_values.shape}")
        
        # Get top-k gate values and indices
        topk_gate_values, topk_indices = torch.topk(gate_values, self.top_k, dim=-1)
        print(f"topk_gate_values {topk_gate_values.shape}")
        print(f"topk_indices {topk_indices.shape}")
        
        # Initialize an empty tensor for the output
        output = torch.zeros_like(embeddings)

        # Loop over the top-k experts for each item in the batch
        for i in range(self.top_k):
            expert_outputs = []
            gate_values = topk_gate_values[:, :, i]

            for expert_idx in range(len(self.experts)):
                mask = topk_indices[:, :, i] == expert_idx
                input_masked = input_ids[mask]
                gate_values_masked = gate_values[mask]

                if input_masked.size(0) > 0:
                    expert = self.experts[expert_idx]
                    #expert_output = expert(input_masked)
                    expert_output = expert(input_masked, attention_mask=attention_mask[mask] if attention_mask is not None else None)

                    expert_output = self.logits_to_hidden(expert_output)  # Transform logits to hidden size

                    # Reshape to match the original batch size and sequence length
                    reshaped_output = torch.zeros_like(embeddings)
                    reshaped_output[mask] = expert_output

                    # Prepare expanded_gate_values tensor
                    expanded_gate_values = torch.zeros_like(embeddings)
                    expanded_gate_values[mask] = gate_values_masked.view(-1, 1).expand(-1, embeddings.size(-1))
                    expert_outputs.append(reshaped_output * expanded_gate_values)

            # Combine the expert outputs
            if expert_outputs:
                expert_contributions = sum(expert_outputs)
                output += expert_contributions


        # Normalize the final output
        output = self.norm2(output)
        
        # Calculate load balancing loss
        f_vector = torch.zeros(self.top_k, device=output.device)
        p_vector = torch.zeros(self.top_k, device=output.device)
        for i in range(self.top_k):
            f_vector[i] = torch.sum(topk_indices == i) / input_ids.size(0)
            p_vector[i] = torch.sum(topk_gate_values[:, i])
        
        loss = self.alpha * self.top_k * torch.sum(f_vector * p_vector)
        
        return output, loss




# Instantiate your models
expert_gpt2 = ExpertModel("./gpt2-codesearchnet-dpo-py-js", "gpt2-medium").to(device)
expert_llama = ExpertModel("gpt2", "gpt2-medium").to(device)
experts = [expert_gpt2, expert_llama]

# Correct the input dimension
input_dim = 32  # This should be the hidden size
batch_size = 1
seq_length = 32
hidden_size = 768

model_with_moe = TransformerWithMoE(
    experts=experts,
    input_dim=hidden_size,
    num_experts=len(experts),
    top_k=2,
    hidden_size=hidden_size,
    capacity_factor=1.0,
    alpha=1e-2
).to(device)

# Load the dataset
dataset = load_dataset("teknium/GPT4-LLM-Cleaned")

def tokenize_function(examples, tokenizer):
    concatenated_texts = [instr + " [SEP] " + inp for instr, inp in zip(examples['instruction'], examples['input'])]
    targets = examples['output']
    
    model_inputs = tokenizer(concatenated_texts, padding='max_length', truncation=True, max_length=512)
    
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, padding='max_length', truncation=True, max_length=128)
    
    model_inputs['labels'] = labels['input_ids']
    return model_inputs

# Instantiate your ExpertModel
expert_gpt2 = ExpertModel("./gpt2-codesearchnet-dpo-py-js", "gpt2-medium").to(device)

# Apply the tokenize function to the dataset using the tokenizer from expert_gpt2
tokenized_datasets = dataset.map(lambda examples: tokenize_function(examples, tokenizer=expert_gpt2.tokenizer), batched=True)

# Format the dataset to output only the necessary columns for training
tokenized_datasets.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])

# Optionally, select a subset for training
train_dataset = tokenized_datasets["train"].select(range(0, 100)) 

# Create dataloader from train dataset
data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Training or evaluation loop
for batch in data_loader:
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    
    # Forward pass through your model
    outputs, loss = model_with_moe(input_ids=input_ids, attention_mask=attention_mask)





Map:   0%|          | 0/54568 [00:00<?, ? examples/s]



embeddings torch.Size([1, 512, 768])
attn_output torch.Size([1, 512, 768])
x post norm1 torch.Size([1, 512, 768])
Second attn_output torch.Size([1, 512, 768])
x post second norm1 torch.Size([1, 512, 768])
Router shape: torch.Size([768, 2])
logits shape: torch.Size([1, 512, 2])
gate_values shape: torch.Size([1, 512, 2])
topk_gate_values torch.Size([1, 512, 2])
topk_indices torch.Size([1, 512, 2])
embeddings torch.Size([1, 512, 768])
attn_output torch.Size([1, 512, 768])
x post norm1 torch.Size([1, 512, 768])
Second attn_output torch.Size([1, 512, 768])
x post second norm1 torch.Size([1, 512, 768])
Router shape: torch.Size([768, 2])
logits shape: torch.Size([1, 512, 2])
gate_values shape: torch.Size([1, 512, 2])
topk_gate_values torch.Size([1, 512, 2])
topk_indices torch.Size([1, 512, 2])
embeddings torch.Size([1, 512, 768])
attn_output torch.Size([1, 512, 768])
x post norm1 torch.Size([1, 512, 768])
Second attn_output torch.Size([1, 512, 768])
x post second norm1 torch.Size([1, 512, 768

KeyboardInterrupt: 