In [None]:
import torch
import torch.nn.functional as F
import time
import os

# Ensure the output directory for experiment results is present
# (exist_ok=True makes this a no-op when it already exists).
os.makedirs(name="moe_experiments", exist_ok=True)

# MoE Router: Select Top-2 Experts per token
class MoERouter(torch.nn.Module):
    """Linear gating network that picks the best ``top_k`` experts per token.

    Args:
        input_dim: Size of the last dimension of the input features.
        num_experts: Number of experts the router scores.
        top_k: Number of experts selected for each token. Defaults to 2,
            which reproduces the original top-2 routing.

    Raises:
        ValueError: If ``top_k`` is not in the range ``[1, num_experts]``.
    """

    def __init__(self, input_dim, num_experts, top_k=2):
        super().__init__()
        # Fail at construction time instead of deep inside torch.topk.
        if not 1 <= top_k <= num_experts:
            raise ValueError(
                f"top_k must be between 1 and num_experts ({num_experts}), got {top_k}"
            )
        self.top_k = top_k
        self.router = torch.nn.Linear(input_dim, num_experts)

    def forward(self, x):
        """Route each position of ``x`` to its ``top_k`` highest-scoring experts.

        Args:
            x: Tensor whose last dimension has size ``input_dim``.

        Returns:
            A tuple ``(indices, weights)``; both have the shape of ``x`` with
            the last dimension replaced by ``top_k``. ``indices`` holds the
            selected expert ids and ``weights`` holds mixing weights that sum
            to 1 over the last dimension.
        """
        logits = self.router(x)  # [..., num_experts]
        top_vals, top_indices = torch.topk(logits, k=self.top_k, dim=-1)
        # Softmax over the selected logits only, so the weights of the chosen
        # experts are renormalised to sum to 1.
        weights = F.softmax(top_vals, dim=-1)
        return top_indices, weights

# Expert Network: Simple feedforward for each expert
class Expert(torch.nn.Module):
    """Single expert: a two-layer feed-forward block with a ReLU in between.

    The block maps ``input_dim -> hidden_dim -> input_dim``, so the output has
    the same trailing dimension as the input.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Layers are created in application order and wrapped in a Sequential
        # so they are registered under ``ffn.0`` / ``ffn.1`` / ``ffn.2``.
        up_proj = torch.nn.Linear(input_dim, hidden_dim)
        activation = torch.nn.ReLU()
        down_proj = torch.nn.Linear(hidden_dim, input_dim)
        self.ffn = torch.nn.Sequential(up_proj, activation, down_proj)

    def forward(self, x):
        """Apply the feed-forward stack to ``x`` and return the result."""
        transformed = self.ffn(x)
        return transformed

# MoE Model: Top-2 selected experts
class MoEModel(torch.nn.Module):
    """Mixture-of-experts layer that mixes the router-selected experts per token.

    Args:
        input_dim: Size of the last dimension of the input features.
        hidden_dim: Hidden width of every expert's feed-forward block.
        num_experts: Number of experts to create.
    """

    def __init__(self, input_dim, hidden_dim, num_experts):
        super().__init__()
        self.router = MoERouter(input_dim, num_experts)
        self.experts = torch.nn.ModuleList([Expert(input_dim, hidden_dim) for _ in range(num_experts)])

    def forward(self, x):
        """Run every token through its own selected experts and blend the results.

        Each token is dispatched to the experts the router picked for *that*
        token (previously the experts chosen by the first token of a sequence
        were applied to the whole sequence while per-token weights were used).

        Args:
            x: Tensor whose last dimension has size ``input_dim``; any number
                of leading dimensions is accepted.

        Returns:
            Tensor with the same shape as ``x``: for every token, the sum of
            the selected experts' outputs scaled by the router weights.
        """
        top_indices, weights = self.router(x)

        # Flatten all leading dimensions so every row is one token.
        flat_x = x.reshape(-1, x.size(-1))
        flat_idx = top_indices.reshape(-1, top_indices.size(-1))
        flat_w = weights.reshape(-1, weights.size(-1))
        flat_out = torch.zeros_like(flat_x)

        # Loop over experts (not tokens): each expert processes, in one
        # batched call, exactly the tokens that selected it.
        for expert_id, expert in enumerate(self.experts):
            token_pos, slot = torch.where(flat_idx == expert_id)
            if token_pos.numel() == 0:
                continue  # no token was routed to this expert
            gate = flat_w[token_pos, slot].unsqueeze(-1)
            contribution = expert(flat_x[token_pos]) * gate
            # index_add_ requires matching dtypes; cast to stay safe when the
            # expert output dtype differs from the input dtype.
            flat_out.index_add_(0, token_pos, contribution.to(flat_out.dtype))

        return flat_out.reshape(x.shape)

# Dense Model: Normal feedforward (no MoE)
class DenseModel(torch.nn.Module):
    """Baseline without routing: one Linear -> ReLU -> Linear feed-forward block.

    The block maps ``input_dim -> hidden_dim -> input_dim``.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Build the layer list first, then register it as a single Sequential.
        layers = [
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, input_dim),
        ]
        self.ffn = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Return the feed-forward block applied to ``x``."""
        result = self.ffn(x)
        return result

# Benchmarking function
def benchmark(model, input_tensor, num_runs=10, warmup=1):
    """Return the average wall-clock time of one forward pass, in seconds.

    Args:
        model: Callable invoked as ``model(input_tensor)``.
        input_tensor: Input passed unchanged to every call.
        num_runs: Number of timed calls to average over (default 10).
        warmup: Number of untimed calls made first so that one-off start-up
            costs do not end up in the measurement (default 1).

    Returns:
        Total elapsed time of the timed calls divided by ``num_runs``.

    Raises:
        ValueError: If ``num_runs`` is smaller than 1.
    """
    if num_runs < 1:
        raise ValueError(f"num_runs must be at least 1, got {num_runs}")

    # Inference timing: disable autograd so graph bookkeeping is not measured.
    with torch.no_grad():
        for _ in range(warmup):
            model(input_tensor)

        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        for _ in range(num_runs):
            model(input_tensor)
        elapsed = time.perf_counter() - start

    return elapsed / num_runs

# Experiment configuration; the input is random data, not a real workload.
batch_size = 8
seq_len = 128
embed_dim = 1024
hidden_dim = 4096
num_experts = 16

input_tensor = torch.randn(batch_size, seq_len, embed_dim)

# Instantiate models
moe_model = MoEModel(embed_dim, hidden_dim, num_experts)
dense_model = DenseModel(embed_dim, hidden_dim)

# Benchmark both models
dense_time = benchmark(dense_model, input_tensor)
moe_time = benchmark(moe_model, input_tensor)

# Ratio > 1 means the MoE model was faster; guard against a zero reading
# from the timer so the report never crashes on a division.
speedup = dense_time / moe_time if moe_time > 0 else float("inf")


def _param_count(module):
    """Return the total number of parameter elements held by ``module``."""
    return sum(p.numel() for p in module.parameters())


# Report parameter counts of the models that were actually benchmarked rather
# than a hand-written formula: the dense baseline is one feed-forward block,
# and a token in the MoE model touches the router plus two experts.
active_experts = 2
dense_active_params = _param_count(dense_model)
moe_active_params = (
    _param_count(moe_model.router)
    + active_experts * _param_count(moe_model.experts[0])
)

# Print results
print("=== MoE vs Dense Simulation Results ===")
print(f"Average Inference Time (Dense Model): {dense_time:.6f} sec")
print(f"Average Inference Time (MoE Top-2 Model): {moe_time:.6f} sec")
print(f"Simulated Speedup (Dense → MoE Top-2): {speedup:.2f}x")
print(f"Active Parameters per Token (Dense): {dense_active_params} (entire dense FFN)")
print(f"Active Parameters per Token (MoE Top-2): {moe_active_params} (router + 2 experts)")

# Save results to file
result_text = f"""
=== MoE vs Dense Simulation Results ===

Batch Size: {batch_size}
Sequence Length: {seq_len}
Embedding Dimension: {embed_dim}
Hidden Dimension: {hidden_dim}
Number of Experts: {num_experts}

Average Inference Time (Dense Model): {dense_time:.6f} sec
Average Inference Time (MoE Top-2 Model): {moe_time:.6f} sec
Simulated Speedup: {speedup:.2f}x

Active Parameters per Token (Dense): {dense_active_params}
Active Parameters per Token (MoE Top-2): {moe_active_params}
"""

# Explicit encoding keeps the output file identical across platforms.
with open("moe_experiments/moe_vs_dense_results.txt", "w", encoding="utf-8") as f:
    f.write(result_text)

print("\nResults saved to: moe_experiments/moe_vs_dense_results.txt")


=== MoE vs Dense Simulation Results ===
Average Inference Time (Dense Model): 0.413534 sec
Average Inference Time (MoE Top-2 Model): 0.676468 sec
Simulated Speedup (Dense → MoE Top-2): 0.61x
Active Parameters per Token (Dense): 67108864 (simulated full)
Active Parameters per Token (MoE Top-2): 8388608 (2 experts only)

Results saved to: moe_experiments/moe_vs_dense_results.txt
