In [None]:
import torch
import torch.nn as nn
import time
import os
from torch.nn import functional as F

# Make sure the directory that will receive the benchmark report is present;
# exist_ok=True turns this into a no-op when it has already been created.
os.makedirs(name="integration_ablation", exist_ok=True)

# Simulate FP8 Quantization manually
def simulate_fp8(x):
    """Snap a tensor onto a uniform 8-bit-style grid inside [-1, 1].

    Values are first saturated to the closed interval [-1, 1] and then
    rounded to a multiple of 1/127.  Despite the name, this is a uniform
    fixed-point grid: no exponent/mantissa split of a real FP8 format is
    modelled here.

    Args:
        x: Input tensor.

    Returns:
        A new tensor with the same shape as ``x`` holding the quantized
        values; ``x`` itself is not modified.
    """
    levels = 127
    saturated = x.clamp(min=-1.0, max=1.0)
    return torch.round(saturated * levels) / levels

# LoRA Adapter for Fine-Tuning
class LoRAAdapter(nn.Module):
    """Wrap a linear layer with a trainable low-rank additive update.

    The forward pass computes ``original_layer(x) + lora_B(lora_A(x))``.
    All parameters of ``original_layer`` are frozen; only the two low-rank
    projections are trainable.

    ``lora_B`` is zero-initialised so that, right after construction, the
    adapter reproduces the wrapped layer exactly and training starts from
    the original behaviour instead of from a randomly perturbed one.

    Args:
        original_layer: Layer to adapt. It must expose ``in_features`` and
            ``out_features`` (as ``nn.Linear`` does).
        rank: Inner dimension of the low-rank update.
    """

    def __init__(self, original_layer, rank=8):
        super().__init__()
        self.original_layer = original_layer
        self.lora_A = nn.Linear(original_layer.in_features, rank, bias=False)
        self.lora_B = nn.Linear(rank, original_layer.out_features, bias=False)
        # Start with a zero update: B = 0 makes lora_B(lora_A(x)) vanish
        # while lora_A keeps its default random initialisation.
        nn.init.zeros_(self.lora_B.weight)
        # Freeze the wrapped layer so only the adapter receives gradients.
        for param in self.original_layer.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return the frozen layer's output plus the low-rank correction."""
        return self.original_layer(x) + self.lora_B(self.lora_A(x))

# Mixture of Experts Router
class MoERouter(nn.Module):
    """Top-2 gating network for a mixture of experts.

    A single linear layer scores every expert along the last dimension of
    the input; the two highest-scoring experts are kept and their scores
    are normalised with a softmax over just those two entries.

    Args:
        input_dim: Size of the last dimension of the input.
        num_experts: Number of experts to score.
    """

    def __init__(self, input_dim, num_experts):
        super().__init__()
        self.router = nn.Linear(in_features=input_dim, out_features=num_experts)

    def forward(self, x):
        """Return ``(indices, weights)`` of the two best experts.

        Both results have the shape of ``x`` with the last dimension
        replaced by 2; the weights sum to 1 along that dimension.
        """
        scores = self.router(x)
        best = scores.topk(2, dim=-1)
        gate = torch.softmax(best.values, dim=-1)
        return best.indices, gate

# Expert Network
class Expert(nn.Module):
    """Two-layer feed-forward expert: Linear -> ReLU -> Linear.

    The block expands from ``input_dim`` to ``hidden_dim`` and projects
    back, so the output has the same last dimension as the input.

    Args:
        input_dim: Size of the input and output features.
        hidden_dim: Size of the intermediate features.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Created in expand-then-contract order so parameter initialisation
        # draws from the random generator in the same sequence.
        expand = nn.Linear(input_dim, hidden_dim)
        contract = nn.Linear(hidden_dim, input_dim)
        self.ffn = nn.Sequential(expand, nn.ReLU(), contract)

    def forward(self, x):
        """Run ``x`` through the feed-forward stack."""
        return self.ffn(x)

# Full Integrated Model
class IntegratedModel(nn.Module):
    """Toy block combining optional input quantization, MoE and LoRA.

    The forward pass returns ``q_proj(x) + v_proj(x) + ffn_out`` where
    ``ffn_out`` comes either from a dense feed-forward stack or from a
    top-k mixture of experts.  No attention is computed; the two
    projections are simply added to the feed-forward output.

    Args:
        embed_dim: Size of the last dimension of the input and output.
        hidden_dim: Inner width of the feed-forward stack / each expert.
        num_experts: Number of experts (only used when ``use_moe`` is set).
        use_fp8: Pass the input through ``simulate_fp8`` first.
        use_moe: Replace the dense feed-forward stack by routed experts.
        use_lora: Wrap both projections in ``LoRAAdapter``.
    """

    def __init__(self, embed_dim, hidden_dim, num_experts, use_fp8=False, use_moe=False, use_lora=False):
        super().__init__()
        self.use_fp8 = use_fp8
        self.use_moe = use_moe
        self.use_lora = use_lora

        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

        if self.use_lora:
            self.q_proj = LoRAAdapter(self.q_proj)
            self.v_proj = LoRAAdapter(self.v_proj)

        if self.use_moe:
            self.router = MoERouter(embed_dim, num_experts)
            self.experts = nn.ModuleList([Expert(embed_dim, hidden_dim) for _ in range(num_experts)])
        else:
            self.ffn = nn.Sequential(
                nn.Linear(embed_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, embed_dim)
            )

    def _moe_forward(self, x):
        """Route every token to its own selected experts and mix the results.

        Each position along the leading dimensions of ``x`` is treated as
        one token.  Tokens are grouped per expert so each expert runs once
        on exactly the tokens assigned to it, and its output is scaled by
        that token's gate weight before being accumulated.
        """
        top_indices, weights = self.router(x)

        # Collapse all leading dimensions so a row is a single token.
        flat_x = x.reshape(-1, x.size(-1))
        flat_idx = top_indices.reshape(-1, top_indices.size(-1))
        flat_w = weights.reshape(-1, weights.size(-1))
        flat_out = torch.zeros_like(flat_x)

        for expert_id, expert in enumerate(self.experts):
            # Rows (tokens) and top-k slots in which this expert was chosen.
            token_rows, slots = (flat_idx == expert_id).nonzero(as_tuple=True)
            if token_rows.numel() == 0:
                continue
            gate = flat_w[token_rows, slots].unsqueeze(-1)
            flat_out.index_add_(0, token_rows, expert(flat_x[token_rows]) * gate)

        return flat_out.reshape(x.shape)

    def forward(self, x):
        """Compute the block output for ``x`` of shape ``(..., embed_dim)``."""
        if self.use_fp8:
            x = simulate_fp8(x)

        q = self.q_proj(x)
        v = self.v_proj(x)

        if self.use_moe:
            output = self._moe_forward(x)
        else:
            output = self.ffn(x)

        return q + v + output

# Benchmarking
def benchmark(model, input_tensor, num_steps=5, warmup_steps=1):
    """Measure the average wall-clock time of one training step.

    A training step is: zero the gradients, run the forward pass, reduce
    the output to a scalar with ``mean()``, back-propagate and apply one
    AdamW update (lr=1e-3) to the parameters that require gradients.  The
    model is left in training mode and its parameters are modified.

    Args:
        model: Module to train; it must have at least one parameter with
            ``requires_grad=True``.
        input_tensor: Batch fed to the model at every step.
        num_steps: Number of timed steps to average over.
        warmup_steps: Untimed steps executed first so one-off costs
            (lazy initialisation, allocator growth) do not skew the mean.

    Returns:
        Average seconds per timed step as a float.

    Raises:
        ValueError: If ``num_steps`` is not positive or ``warmup_steps``
            is negative.
    """
    if num_steps <= 0:
        raise ValueError("num_steps must be a positive integer")
    if warmup_steps < 0:
        raise ValueError("warmup_steps must be non-negative")

    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable, lr=1e-3)
    model.train()

    def _train_step():
        optimizer.zero_grad()
        loss = model(input_tensor).mean()
        loss.backward()
        optimizer.step()

    def _sync():
        # CUDA kernels run asynchronously; wait for them so the clock
        # brackets the actual work. No-op for CPU inputs.
        if input_tensor.is_cuda:
            torch.cuda.synchronize()

    for _ in range(warmup_steps):
        _train_step()

    _sync()
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    for _ in range(num_steps):
        _train_step()
    _sync()
    return (time.perf_counter() - start) / num_steps

# Problem size shared by every benchmarked configuration
batch_size, seq_len, embed_dim = 8, 128, 1024
hidden_dim, num_experts = 4096, 16

# Standard-normal input of shape (batch_size, seq_len, embed_dim); no seed
# is set here, so the values differ from run to run.
input_tensor = torch.randn((batch_size, seq_len, embed_dim))

# Ablation grid: display label -> keyword flags for IntegratedModel.
# Each row lists (use_fp8, use_moe, use_lora); insertion order is the order
# in which the configurations are benchmarked and reported.
configs = {
    label: dict(zip(("use_fp8", "use_moe", "use_lora"), switches))
    for label, switches in (
        ("Dense Baseline", (False, False, False)),
        ("FP8 Only", (True, False, False)),
        ("MoE Only", (False, True, False)),
        ("QLoRA Only", (False, False, True)),
        ("FP8 + MoE", (True, True, False)),
        ("FP8 + MoE + QLoRA", (True, True, True)),
    )
}

# label -> average seconds per training step, filled in benchmark order
results = {}

for name, flags in configs.items():
    # A fresh model per configuration so runs do not share parameters.
    model = IntegratedModel(
        embed_dim=embed_dim,
        hidden_dim=hidden_dim,
        num_experts=num_experts,
        **flags,
    )
    results[name] = avg_time = benchmark(model, input_tensor)
    print(f"{name}: {avg_time:.6f} sec/step")

# Render the report: a header, a blank line, then one line per configuration
# in the order the results were recorded.
result_text = "=== Integration + Ablation Benchmark Results ===\n\n" + "".join(
    f"{name}: {avg_time:.6f} sec/step\n" for name, avg_time in results.items()
)

# Overwrite any previous report in the output directory.
with open("integration_ablation/integration_ablation_results.txt", "w") as f:
    f.write(result_text)

print("\nResults saved to: integration_ablation/integration_ablation_results.txt")


Dense Baseline: 1.041147 sec/step
FP8 Only: 0.804552 sec/step
MoE Only: 2.577101 sec/step
QLoRA Only: 0.731166 sec/step
FP8 + MoE: 2.313285 sec/step
FP8 + MoE + QLoRA: 2.144114 sec/step

Results saved to: integration_ablation/integration_ablation_results.txt
