<a href="https://colab.research.google.com/github/Smol-Kaiju/SmolGhidorah/blob/main/smolGhidorah_Psuedo_MoE_KeywordRouter.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers sentence-transformers accelerate

Configure

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer, util

# Configuration
# Model identifiers handed to `from_pretrained` by load_expert, plus the name
# of the sentence-embedding model used by the embedding-based sub-router.

# Default expert: chosen by route_query when is_reasoning_query returns False.
general_model_path = "Qwen/Qwen2.5-3B-Instruct"
# Returned by route_to_reasoning_expert when the query is more similar to the
# math example queries than to the general reasoning example queries.
math_reasoning_1_model_path = "cutelemonlili/Qwen2.5-1.5B-Instruct_MATH_training_response_Qwen2.5_1.5B_only_right"
# Returned by route_to_reasoning_expert in every other case (including ties).
reasoning_2_model_path = "prithivMLmods/QwQ-LCoT-3B-Instruct"
# Passed to SentenceTransformer(...) to build the routing encoder.
embedding_model_name = "all-MiniLM-L6-v2"

Routing

In [None]:
# Keyword-Based Router
# Each keyword is matched as a word *prefix*, so inflected forms such as
# "explaining", "reasoning" or "logical" still count as hits.
reasoning_keywords = ["think", "explain", "why", "how", "infer", "deduce", "analyze", "solve", "reason", "logic", "therefore", "because", "consequence"]

def is_reasoning_query(query):
    """Return True if any word of *query* starts with a reasoning keyword.

    Matching is case-insensitive. The query is split into words on every
    non-alphanumeric character, and a word counts as a hit when it begins
    with one of ``reasoning_keywords``. Anchoring at the start of a word
    avoids the false positives of raw substring search (for example "show"
    or "somehow" containing "how", "biological" containing "logic").

    Args:
        query: The user query string.

    Returns:
        bool: True when at least one word starts with a keyword.
    """
    # Replace punctuation and other separators with spaces so that
    # "why?" or "how's" split into clean words.
    normalized = "".join(ch if ch.isalnum() else " " for ch in query.lower())
    # str.startswith accepts a tuple of prefixes; build it per call so later
    # edits to the module-level list are picked up.
    prefixes = tuple(reasoning_keywords)
    return any(word.startswith(prefixes) for word in normalized.split())

# Embedding-Based Sub-Router
# Example reasoning queries
# Each list is embedded and averaged into one prototype vector; an incoming
# query is routed to the expert whose prototype it is most similar to.

# Prototype queries for the math-focused expert.
math_reasoning_queries_1 = [
    "Solve the equation 2x + 5 = 15.",
    "What is the derivative of x^2?",
    "Calculate the area of a circle with radius 5.",
    "Find the sum of 25 and 37.",
    "What is the probability of rolling a 6 on a fair die?",
    "Simplify the expression (x + 2)(x - 3).",
    "What is the square root of 144?",
    "Find the value of sin(30 degrees).",
    "If f(x) = 3x^2 + 2x - 5, what is f(2)?",
    "What is the slope of the line y = 2x + 3?",
]

# Prototype queries for the general step-by-step reasoning expert.
reasoning_queries_2 = [
    "How can we solve this problem using a logical approach?",
    "Analyze the different factors contributing to this outcome and explain their causal relationships.",
    "What is the reasoning for the observed phenomenon based on the provided evidence?",
    "Think step by step and solve this problem",
    "Think step by step and reach a conclusion",
    # Spelling corrected ("taks" -> "task"): this text is embedded, so the
    # typo fed a misspelled word into the prototype vector.
    "Break down this task into easy steps",
]

def generate_embedding(text, model, tokenizer):
    """Encode a string, or a list of strings, into a tensor of embeddings.

    Args:
        text: A single string or a list of strings. A single string is
            wrapped in a one-element list before encoding.
        model: Object providing ``encode(list_of_str)`` and a ``device``
            attribute (the SentenceTransformer at this file's call sites).
        tokenizer: Not used by this function; accepted so existing call
            sites keep working.

    Returns:
        torch.Tensor: ``model.encode``'s output converted to a tensor and
        moved to ``model.device``.
    """
    batch = [text] if not isinstance(text, list) else text

    # Encoding is inference only, so run it without autograd tracking.
    with torch.no_grad():
        encoded = torch.tensor(model.encode(batch))

    return encoded.to(model.device)

def calculate_similarity(embedding1, embedding2):
    """Return the cosine similarity between two embedding tensors.

    Thin wrapper around ``sentence_transformers.util.pytorch_cos_sim``;
    whatever that helper returns is passed through unchanged.
    """
    similarity = util.pytorch_cos_sim(embedding1, embedding2)
    return similarity

# Sentence encoder shared by every embedding-based routing decision.
embedding_model = SentenceTransformer(embedding_model_name)

# Prototype vectors: embed each example list once at start-up and average
# over the examples (dim 0), keeping that dimension so the result stays 2-D.
avg_embedding_reasoning_1 = torch.mean(
    generate_embedding(math_reasoning_queries_1, embedding_model, embedding_model.tokenizer),
    dim=0,
    keepdim=True,
)
avg_embedding_reasoning_2 = torch.mean(
    generate_embedding(reasoning_queries_2, embedding_model, embedding_model.tokenizer),
    dim=0,
    keepdim=True,
)

def route_to_reasoning_expert(query):
    """Pick the reasoning expert whose prototype is closest to *query*.

    The query is embedded with ``embedding_model`` and compared, by cosine
    similarity, against the two precomputed prototype embeddings.

    Args:
        query: The user query string.

    Returns:
        str: ``math_reasoning_1_model_path`` when the math prototype scores
        strictly higher, otherwise ``reasoning_2_model_path`` (ties included).
    """
    encoded_query = generate_embedding(query, embedding_model, embedding_model.tokenizer)
    math_score = calculate_similarity(encoded_query, avg_embedding_reasoning_1)
    general_score = calculate_similarity(encoded_query, avg_embedding_reasoning_2)

    # Strict comparison: an exact tie falls through to the general expert.
    return math_reasoning_1_model_path if math_score > general_score else reasoning_2_model_path



Model Loading/Unloading

In [None]:
def load_expert(model_path):
    """Load a causal language model and its tokenizer.

    Device placement is delegated to ``device_map="auto"``, so no explicit
    device is selected here (the previously computed ``device`` local was
    never used and has been removed).

    Args:
        model_path: Model identifier or path passed to ``from_pretrained``.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    # NOTE(security): trust_remote_code=True lets the model repository run its
    # own Python code during loading. Only use it with repositories you trust.
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    return model, tokenizer

def unload_expert(model, tokenizer):
    """Drop this function's references to an expert and clear the CUDA cache.

    ``del`` only removes the parameter bindings inside this function. The
    objects are freed only once the caller has released its own references
    as well, so callers should drop theirs too.

    Args:
        model: The expert model to release.
        tokenizer: The tokenizer that was loaded with the model.
    """
    import gc  # local import so this cell does not depend on another one

    del model
    del tokenizer
    # Collect unreachable reference cycles first, so any tensors they hold
    # are freed before the allocator cache is emptied.
    gc.collect()
    torch.cuda.empty_cache()

# Main Router
def route_query(query):
    """Route *query* to an expert model and return the decoded generation.

    Queries flagged by ``is_reasoning_query`` go to the expert chosen by
    ``route_to_reasoning_expert``; all others go to ``general_model_path``.
    The expert is loaded for this call only and is always unloaded again,
    even when tokenisation or generation raises.

    Args:
        query: The user query string, used as the prompt as-is.

    Returns:
        str: ``outputs[0]`` decoded with special tokens skipped.
    """
    import gc  # local import so this cell does not depend on another one

    if is_reasoning_query(query):
        model_path = route_to_reasoning_expert(query)
    else:
        model_path = general_model_path

    # Load expert
    expert, tokenizer = load_expert(model_path)
    inputs = outputs = None
    try:
        # Generate response
        inputs = tokenizer(query, return_tensors="pt").to(expert.device)
        with torch.no_grad():
            outputs = expert.generate(
                **inputs,
                max_new_tokens=200,
                do_sample=True,
                top_k=50,
                top_p=0.95,
                temperature=0.7,
            )
        # NOTE(review): the whole first sequence is decoded; for decoder-only
        # models this presumably includes the prompt text - confirm whether
        # callers want the prompt stripped.
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    finally:
        # Unload expert
        unload_expert(expert, tokenizer)
        # unload_expert can only drop its own parameter bindings. Release the
        # references held here too, then clear the cache again so the memory
        # can really be returned before the next expert is loaded.
        del expert, tokenizer, inputs, outputs
        gc.collect()
        torch.cuda.empty_cache()

    return response

UI

In [None]:
if __name__ == "__main__":
    # Simple interactive loop: read a query, route it, print the answer.
    while True:
        try:
            query = input("Enter your query (or 'quit' to exit): ").strip()
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or an interrupted prompt ends the session cleanly
            # instead of surfacing a traceback.
            break
        if query.lower() == "quit":
            break
        if not query:
            # Nothing to route; do not load a model for an empty prompt.
            continue
        response = route_query(query)
        print("Response:", response)