# **BUD Ecosystem Inc. | NLI Research | BUD *Muffler*-$\alpha$ | Shreyan Chaubey**

This notebook is licensed under the GNU Affero General Public License v3.0 (AGPL-3.0).
See the LICENSE file in the GitHub repository for details. For the full terms, read the [AGPL license text](https://www.gnu.org/licenses/agpl-3.0.en.html).

---

### **BUD Muffler-$\alpha$: A Binary Language Model Router | Author: Shreyan C**

This notebook was developed as part of a research initiative at BUD to survey current state-of-the-art (SOTA) techniques in natural language inference optimization, more specifically language model routing. In this notebook, we present a faithful reproduction of the methodology presented in the 2024 paper '**RouteLLM: Learning to Route LLMs with Preference Data**' ([Ong et al.](https://arxiv.org/pdf/2406.18665)). The notebook covers everything from data preparation to router benchmarking and provides an in-depth look at how the Muffler-$\alpha$ router functions, including its model routing logic, training approach, and cost-adjustment mechanisms. It may be useful for researchers and developers who want to optimize language model usage in cost-sensitive environments or to implement cost-efficient routing mechanisms.

### Model Recipe - Building the Dataset

Before defining the architecture for the Language Model (LM) router, we need a well-structured, supervised training dataset, because the router learns its routing decisions from labelled model comparisons. Such a dataset might have the following structure:

```
id [INT] | model_a [STR] | model_b [STR] | prompt [STR] | winner_model_a [0/1] | winner_model_b [0/1] | tie_ab [0/1]
```

Each row represents a comparison between two models (`model_a` and `model_b`) based on their responses to a specific `prompt`. The outcomes (`winner_model_a`, `winner_model_b`, or `tie_ab`) indicate which model performed better or if there was a tie.
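
To make the schema concrete, here is a minimal, framework-free sketch of two rows in that layout, together with a check that exactly one outcome flag is set per row. The row values and the helper name `outcome` are invented for illustration and do not come from any of the datasets discussed in this notebook.

```python
# Hypothetical rows following the schema above; all values are invented.
rows = [
    {"id": 0, "model_a": "strong-model", "model_b": "weak-model",
     "prompt": "Prove that the square root of 2 is irrational.",
     "winner_model_a": 1, "winner_model_b": 0, "tie_ab": 0},
    {"id": 1, "model_a": "strong-model", "model_b": "weak-model",
     "prompt": "What is the capital of France?",
     "winner_model_a": 0, "winner_model_b": 0, "tie_ab": 1},
]

def outcome(row):
    """Return 'model_a', 'model_b', or 'tie' for a row with one-hot outcome flags."""
    flags = (row["winner_model_a"], row["winner_model_b"], row["tie_ab"])
    if sum(flags) != 1:
        raise ValueError(f"row {row['id']} must set exactly one outcome flag, got {flags}")
    return ("model_a", "model_b", "tie")[flags.index(1)]

labels = [outcome(r) for r in rows]
print(labels)
```

Validating the flags up front is a cheap way to catch malformed rows before they reach the encoding step.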

#### Possible Augmentations

##### 1. **Benchmark Augmented Datasets**
   - These datasets are derived by evaluating models across well-established benchmarks. Benchmarks typically focus on specific domains of reasoning.
   - By comparing the performance of a stronger model $M_{strong}$ and a weaker model $M_{weak}$ on these benchmarks, metrics such as *accuracy* or *exact match* can be collected. These metrics determine which model performs better for a given prompt. 
   - Such datasets are often referred to as *Golden Labelled Datasets* because they provide a clear indication of model aptitude within specific domains.
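
As a rough sketch of how such golden labels could be derived, the snippet below compares per-prompt benchmark scores of the two models and emits outcome flags in the schema shown earlier. The function name `golden_label`, the score values, and the convention that equal scores count as a tie are assumptions made for illustration.

```python
def golden_label(score_strong, score_weak):
    """Map two per-prompt benchmark scores (e.g. exact match in {0, 1}) to outcome flags."""
    return {
        "winner_model_a": int(score_strong > score_weak),  # model_a plays the role of M_strong
        "winner_model_b": int(score_weak > score_strong),  # model_b plays the role of M_weak
        "tie_ab": int(score_strong == score_weak),         # assumed convention: equal scores are a tie
    }

# Invented exact-match results for three prompts, given as (strong, weak).
scores = [(1, 0), (1, 1), (0, 1)]
labels = [golden_label(s, w) for s, w in scores]
for (s, w), lab in zip(scores, labels):
    print(f"strong={s} weak={w} -> {lab}")
```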

##### 2. **Human Augmented Datasets**
   - Publicly available datasets that capture human preferences for language model responses can also be used. Examples include:
     - **`lmarena-ai/arena-human-preference-55k`**
     - **`lmarena-ai/PPE-Human-Preference-V1`**
   - These datasets reflect the nuanced preferences humans may have for certain types of responses. However, they may introduce inherent biases into the training data, and these biases can affect the router's performance in unpredictable ways. Moreover, collecting enough data points this way is costly and time-consuming.

##### 3. **LLM Judge Augmented Datasets**
   - To reduce reliance on human-annotated preferences and address potential human biases, a *Language Model Judge* (LLM Judge) can be employed. This judge compares responses from an LLM and an SLM (a smaller language model) and assigns a label to the better-performing model.
   - The resulting datasets, known as *LLM Preference Datasets*, aim to offer a more balanced and scalable alternative to human-augmented datasets.

**Each dataset type comes with strengths and limitations:**
- **Benchmark Augmented Datasets** ensure strong domain coverage but may lack the richness of human preferences.
- **Human Preference Datasets** capture real-world preferences but risk introducing subjective biases.
- **LLM Preference Datasets** strike a balance by leveraging the capabilities of an LLM as a neutral arbiter, though the quality of the "judge" becomes critical.

We begin by importing the relevant Python libraries.

In [None]:
import os
import numpy as np
import pandas as pd
import torch
import datasets
from tqdm import tqdm, trange
from torch import nn
from torch.nn import functional as F
from torch import optim
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score, roc_curve, auc
from sklearn.model_selection import train_test_split
from sentence_transformers import SentenceTransformer
import safetensors.torch as st

# ----------------- #
proj_dir = "." # Project root; keep '.' when running locally.
custom_ds_dir = "."
# ----------------- #

Here are the various kinds of preference datasets that can be used:

1. **LMSYS Arena Human Preference 55K (LMSYS)** [Human Preference]:
    -  This dataset contains human preferences collected from the LMSYS Arena. It includes comparisons across 64 language models based on human judgments.
    - This dataset is useful for training the router to understand *human preferences* and make routing decisions accordingly. However, it is not suitable for model deployment because *(i) it covers a large number of target models, (ii) most of those target models are obsolete, and (iii) human preferences are prone to a number of biases.*

2. **GPT-4 Judge Battles (RouteLLM)** [LLM Preference]:
    - This dataset contains comparisons in which GPT-4 acts as a judge and picks the better response between its own and that of the Mixtral 8x7B mixture-of-experts model.
    - It helps train the router to align with GPT-4's judgment when human preference labels are too costly to obtain.
    - Prompts are synthetically generated.

3. **Arena Human Preference + Judge Battles** [Human + LLM Preference]:
    - A combination of the LMSYS Arena Human Preference 55K and GPT-4 Judge Battles datasets.

4. **BUD Complexity Corpus (BUD Ecosystems Inc.)** [Benchmark Augmented/Golden Labelled]:
    - This dataset focuses on complex queries and includes comparisons between models like Gemma-2-27b-it and Llama 3.2 1B.
    - Specifically useful for training the router to handle complex queries and make decisions based on the complexity of the task.
    
5. **Custom Dataset**:
    - A custom dataset created from various sources and processed into a parquet format.
    - Allows for flexibility in training the router with specific preferences and scenarios tailored to the user's needs.

In [None]:
def decide_dataset(choice):
    # Returns the selected dataset as a DataFrame together with its name.
    if choice == 0:
        ds_name = "lmsys/lmsys-arena-human-preference-55k"
        ds = load_dataset("lmsys/lmsys-arena-human-preference-55k")["train"].to_pandas()
        ds = ds[(ds['model_a'] == 'gpt-4-1106-preview') & (ds['model_b'] == 'mixtral-8x7b-instruct-v0.1')]
    elif choice == 1:
        ds_name = "routellm/gpt4_judge_battles"
        ds = load_dataset("routellm/gpt4_judge_battles")["train"].to_pandas()
    elif choice == 2:
        ds_name = "lmsys/lmsys-arena-human-preference-55k and routellm/gpt4_judge_battles"
        ds1 = load_dataset("lmsys/lmsys-arena-human-preference-55k")["train"].to_pandas()
        ds2 = load_dataset("routellm/gpt4_judge_battles")["train"].to_pandas()
        ds = pd.concat([ds1, ds2], ignore_index=True)
        ds = ds[(ds['model_a'] == 'gpt-4-1106-preview') & (ds['model_b'] == 'mixtral-8x7b-instruct-v0.1')]
    elif choice == 3:
        ds_name = "own-dataset"
        ds = pd.read_csv(f'{custom_ds_dir}/combined_dataset.csv')
    else:
        # Fail early instead of hitting an unbound `ds` below.
        raise ValueError(f"Unknown dataset choice: {choice}")
    return ds, ds_name

choice = 1  # 0: lmsys-arena-human-preference-55k, 1: gpt4_judge_battles, 2: both (0+1), 3: custom dataset
ds, ds_name = decide_dataset(choice)  # ds_name is returned so that it is set at module level
print(f"Shape of the '{ds_name}' dataset:", ds.shape)
ds.head(5)

We proceed with a 95% train 5% validation split ratio.

In [None]:
split_ratio = 0.95
balanced_flag = False
train_ds, test_ds = train_test_split(ds, test_size=1-split_ratio, shuffle=False, random_state=None, stratify=None)
print(f"Shape of the train dataset: {train_ds.shape}\nShape of the test dataset: {test_ds.shape}")

# Dataset splits can be accessed in the '~/output/dataset/raw/' directory.
os.makedirs(f"{proj_dir}/output/dataset/raw/train/", exist_ok=True)
train_ds.to_parquet(f"{proj_dir}/output/dataset/raw/train/train_ds.parquet", index=True) # train_ds

os.makedirs(f"{proj_dir}/output/dataset/raw/test/", exist_ok=True)
test_ds.to_parquet(f"{proj_dir}/output/dataset/raw/test/test_ds.parquet", index=True) # test_ds

print(f"\nTrain dataset saved at '~/output/dataset/raw/train/train_ds.parquet'")
display(train_ds.head(5))
print("\nTest Dataset saved at '~/output/dataset/raw/test/test_ds.parquet'")
display(test_ds.head(5))

The encoding works as follows:
- Model A - integer ID (looked up in `MODEL_IDS`)
- Model B - integer ID (looked up in `MODEL_IDS`)
- Prompt - row index of the prompt within its split
- Winner - one of `'model_a'`, `'model_b'`, `'tie_positive'`, `'tie_negative'`

Positive ties (True, True) are encoded as 'tie_positive', whereas negative ties (False, False) are encoded as 'tie_negative'.
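
The four-way winner encoding described above can be sketched on plain Python values as follows. The function name `encode_winner` is hypothetical; the label strings match the ones used in the next cell.

```python
def encode_winner(winner_model_a, winner_model_b):
    """Map the two winner flags of one comparison to a single label string."""
    a, b = bool(winner_model_a), bool(winner_model_b)
    if a and not b:
        return "model_a"
    if b and not a:
        return "model_b"
    # Both flags agree: either both models "won" or neither did.
    return "tie_positive" if a else "tie_negative"

# All four flag combinations, in the order (winner_model_a, winner_model_b).
cases = [(1, 0), (0, 1), (1, 1), (0, 0)]
encoded = [encode_winner(a, b) for a, b in cases]
print(encoded)
```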

In [None]:
model_encoder = LabelEncoder()
model_encoder.fit(pd.concat([train_ds['model_a'], train_ds['model_b']]).unique())

MODEL_IDS = {
    model: model_encoder.transform([model])[0] for model in model_encoder.classes_
}

print(f"Identified models : {MODEL_IDS}")

train_ds = train_ds.reset_index(drop=True)
test_ds = test_ds.reset_index(drop=True)


train_ds_enc = pd.DataFrame()

train_ds_enc['model_a_idx'] = train_ds['model_a'].map(MODEL_IDS)
train_ds_enc['model_b_idx'] = train_ds['model_b'].map(MODEL_IDS)
train_ds_enc['prompt_idx'] = train_ds.index

train_ds_enc['winner'] = np.where(
    (train_ds['winner_model_a'] == 1) & (train_ds['winner_model_b'] == 0),
    "model_a",
    np.where(
        (train_ds['winner_model_a'] == 0) & (train_ds['winner_model_b'] == 1),
        "model_b",
        np.where(
            (train_ds['winner_model_a'] == 1) & (train_ds['winner_model_b'] == 1),  # Both win
            #"model_b", "model_a" # Removing tie breaking mechanism for now.
            "tie_positive", "tie_negative" # Due to the lack of tie-positive cases in the GPT4-Judge-Battles ds, the former label won't be observed.
        )
    )
)

display(train_ds_enc.head(10))

test_ds_enc = pd.DataFrame()

test_ds_enc['model_a_idx'] = test_ds['model_a'].map(MODEL_IDS)
test_ds_enc['model_b_idx'] = test_ds['model_b'].map(MODEL_IDS)
test_ds_enc['prompt_idx'] = test_ds.index

test_ds_enc['winner'] = np.where(
    (test_ds['winner_model_a'] == 1) & (test_ds['winner_model_b'] == 0),
    "model_a",
    np.where(
        (test_ds['winner_model_a'] == 0) & (test_ds['winner_model_b'] == 1),
        "model_b",
        np.where(
            (test_ds['winner_model_a'] == 1) & (test_ds['winner_model_b'] == 1),  # Both win
            #"model_b", "model_a" # Removing tie breaking mechanism for now.
            "tie_positive", "tie_negative" # Due to the lack of tie-positive cases in GPT4-Judge-Battles ds, former label won't be observed. This modification's for this specific ds only.
        )
    )
)
display(test_ds_enc.head(10))

if choice == 3:
    strong_model = 'gemma-2-27b-it'
    weak_model = 'phi-3.5-mini-instruct'
    print(f"Strong model (a): {strong_model}\nWeak model (b): {weak_model}")

elif choice == 4:
    strong_model = 'Gemma-2-27b-it'
    weak_model = 'Llama-3.2-1B-Instruct'
    print(f"Strong model (a): {strong_model}\nWeak model (b): {weak_model}")

elif choice in (0, 1, 2):  # Choices 0 and 2 are filtered to the same model pair when the dataset is loaded.
    strong_model = 'gpt-4-1106-preview'
    weak_model = 'mixtral-8x7b-instruct-v0.1'
    print(f"Strong model (a): {strong_model}\nWeak model (b): {weak_model}")


# Check that both the strong and the weak model are present in the MODEL_IDS dictionary
if strong_model not in MODEL_IDS or weak_model not in MODEL_IDS:
    raise ValueError("The strong and weak models must be present in the dataset.")

We now save the encoded train and test splits locally and compute the win ratios in the training split.

In [None]:
os.makedirs(f"{proj_dir}/output/dataset/encoded/train/", exist_ok=True)
train_ds_enc.to_parquet(f"{proj_dir}/output/dataset/encoded/train/train_ds_encoded.parquet")
train_ds_enc = pd.read_parquet(f"{proj_dir}/output/dataset/encoded/train/train_ds_encoded.parquet")
print('Shape of the encoded training dataset:', train_ds_enc.shape)
print("Train dataset (encoded) saved at '~/output/dataset/encoded/train/train_ds_encoded.parquet'")
print(train_ds_enc.head(10))

os.makedirs(f"{proj_dir}/output/dataset/encoded/test/", exist_ok=True)
test_ds_enc.to_parquet(f"{proj_dir}/output/dataset/encoded/test/test_ds_encoded.parquet")
test_ds_enc = pd.read_parquet(f"{proj_dir}/output/dataset/encoded/test/test_ds_encoded.parquet")
print("\nShape of the encoded test dataset:", test_ds_enc.shape)
print("Test dataset (encoded) saved at '~/output/dataset/encoded/test/test_ds_encoded.parquet'")
print(test_ds_enc.head(10))

# Calculate ratios of wins by model a and model b

# Assuming train_ds_enc contains the 'winner' column with "model_a" or "model_b" strings
total_samples = len(train_ds_enc)

# Count the wins for model_a and model_b
model_a_wins = (train_ds_enc['winner'] == "model_a").sum()
model_b_wins = (train_ds_enc['winner'] == "model_b").sum()

# Calculate the ratios
model_a_win_ratio = model_a_wins / total_samples
model_b_win_ratio = model_b_wins / total_samples

# Print the results
print(f"\nModel A Win Ratio: {model_a_win_ratio:.2%}")
print(f"Model B Win Ratio: {model_b_win_ratio:.2%}")
print(f"Tie (Positive) Ratio: {(train_ds_enc['winner'] == 'tie_positive').sum() / total_samples:.2%}")
print(f"Tie (Negative) Ratio: {(train_ds_enc['winner'] == 'tie_negative').sum() / total_samples:.2%}")


# **Implementing Randomized Minority Oversampling Technique (RMOTE) for Mitigating Load Imbalance** (Skippable)
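
The disabled cell below subsamples the majority class down to the minority count and then resamples rows with replacement to restore the original size. For comparison, here is a simpler, framework-free sketch of plain random minority oversampling, in which every class is topped up with replacement until it matches the largest one. The function name `oversample_minority` and the toy labels are assumptions for illustration; this is not the notebook's own routine.

```python
import random
from collections import Counter

def oversample_minority(labels, seed=42):
    """Return indices into `labels` such that every class appears as often as the largest one."""
    rng = random.Random(seed)
    by_class = {}
    for i, lab in enumerate(labels):
        by_class.setdefault(lab, []).append(i)
    target = max(len(idx) for idx in by_class.values())
    balanced = []
    for idx in by_class.values():
        balanced.extend(idx)
        # Draw extra indices with replacement until this class reaches the target count.
        balanced.extend(rng.choices(idx, k=target - len(idx)))
    rng.shuffle(balanced)
    return balanced

# Toy, invented label column with an 8:2 imbalance.
labels = ["model_a"] * 8 + ["model_b"] * 2
balanced_idx = oversample_minority(labels)
counts = Counter(labels[i] for i in balanced_idx)
print(counts)
```

Oversampling is usually applied to the training split only, so that evaluation still reflects the original class distribution.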

In [None]:
"""

def balance_classes(df, target_column, random_state=42):
    class_counts = df[target_column].value_counts()
    minority_class = class_counts.idxmin()
    majority_class = class_counts.idxmax()
    minority_count = class_counts[minority_class]
    original_size = len(df)

    majority_samples = df[df[target_column] == majority_class].sample(
        n=minority_count, random_state=random_state
    )
    minority_samples = df[df[target_column] == minority_class]

    balanced_df = pd.concat([minority_samples, majority_samples])
    if len(balanced_df) < original_size:
        additional_samples = balanced_df.sample(
            n=original_size - len(balanced_df), random_state=random_state, replace=True
        )
        balanced_df = pd.concat([balanced_df, additional_samples])

    balanced_df = balanced_df.sample(frac=1, random_state=random_state)
    return balanced_df

balanced_flag = True
train_ds_enc = balance_classes(train_ds_enc, target_column='winner', random_state=42)
test_ds_enc = balance_classes(test_ds_enc, target_column='winner', random_state=42)

print('Shape of the balanced training dataset:', train_ds_enc.shape)
print('Shape of the balanced test dataset:', test_ds_enc.shape)
display(train_ds_enc.head(10))
display(test_ds_enc.head(10))

for split, dataset in [('Train', train_ds_enc), ('Test', test_ds_enc)]:
    class_counts = dataset['winner'].value_counts(normalize=True) * 100
    print(f"\n{split} dataset balanced.")
    print(f"Model A Win Ratio: {class_counts.get('model_a', 0):.2f}%")
    print(f"Model B Win Ratio: {class_counts.get('model_b', 0):.2f}%")

"""


# Setting up Dataset & Dataloaders

In [None]:
class BinaryClassificationDataset(Dataset):
    def __init__(self, dataset):
        self.model_a = torch.tensor(dataset['model_a_idx'].tolist())
        self.model_b = torch.tensor(dataset['model_b_idx'].tolist())
        self.prompt_id = dataset['prompt_idx'].tolist()
        self.winner = dataset['winner'].tolist()

        assert len(self.model_a) == len(self.model_b) == len(self.prompt_id) == len(self.winner), "Data lengths do not match."

    def __len__(self):
        return len(self.model_a)

    def __getitem__(self, idx):
        # Items are returned as (winner, loser, prompt_id, label), with label = 1 if model_a won.
        if self.winner[idx] == "model_a":
            return self.model_a[idx], self.model_b[idx], self.prompt_id[idx], 1
        else:
            # Every other label, including both tie labels, is treated as a model_b win.
            return self.model_b[idx], self.model_a[idx], self.prompt_id[idx], 0

    @staticmethod
    def get_dataloaders(dataset, batch_size=64, shuffle=False):
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

batch_size = 64
print("Performing retrievals on 'train_ds_enc' for the first five indices:")
print("Winner, Loser, Prompt ID, Label")
for i in range(5):
    print(BinaryClassificationDataset(train_ds_enc)[i])

print("\nPerforming retrievals on 'test_ds_enc' for the first five indices:")
print("Winner, Loser, Prompt ID, Label")
for i in range(5):
    print(BinaryClassificationDataset(test_ds_enc)[i])

## Generate static embeddings
We now generate the static prompt embeddings that are used to initialize the $Q$ matrix in the model.

In [None]:
"""
Try out different models from the Sentence Transformers library:
- Need to implement support for infgrad/stella_en_400M_v5 (1024-dimensional embeddings) (Good performance, as claimed by RouteLLM authors)
- all-mpnet-base-v2 (768-dimensional embeddings) # Okayish performance.
- all-MiniLM-L6-v2 (384-dimensional embeddings) # Worse. Not recommended.
"""

#model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2", trust_remote_code=True, cache_folder=f"{proj_dir}/output/embeddings/embedding-model/", local_files_only=True)
embedding_generator = "sentence-transformers/all-mpnet-base-v2"
embedding_model = SentenceTransformer(embedding_generator, trust_remote_code=True)

In [None]:
def calculate_win_percentages(dataset):
    model_a_wins = 0
    model_b_wins = 0

    for winner in dataset.winner:
        if winner == "model_a":
            model_a_wins += 1
        elif winner == "model_b":
            model_b_wins += 1

    total_matches = model_a_wins + model_b_wins

    model_a_win_percentage = (model_a_wins / total_matches) * 100 if total_matches > 0 else 0
    model_b_win_percentage = (model_b_wins / total_matches) * 100 if total_matches > 0 else 0

    return model_a_win_percentage, model_b_win_percentage

print("\nTraining dataset:")
model_a_win_percentage, model_b_win_percentage = calculate_win_percentages(train_ds_enc)
print(f"Model A Win Percentage: {model_a_win_percentage:.2f}%")
print(f"Model B Win Percentage: {model_b_win_percentage:.2f}%")

print("\nTest dataset:")
model_a_win_percentage, model_b_win_percentage = calculate_win_percentages(test_ds_enc)
print(f"Model A Win Percentage: {model_a_win_percentage:.2f}%")
print(f"Model B Win Percentage: {model_b_win_percentage:.2f}%")


In [None]:
def load_npy_file(npy_location):
    embeddings = np.load(npy_location)
    return embeddings, embeddings.shape[1]  # (embeddings, embedding dimension)

def match_shape(embeddings, split_size):
    return embeddings.shape[0] == split_size

def generate_embeddings(split, model):
    # Encode the raw prompts of the requested split with the sentence-transformer passed in as `model`.
    prompts = train_ds['prompt'].tolist() if split == 'train' else test_ds['prompt'].tolist()
    embeddings = model.encode(
        prompts, normalize_embeddings=True, convert_to_numpy=True, show_progress_bar=True, batch_size=1
    )
    return embeddings

embeddings_path = f"{proj_dir}/output/embeddings/"

if not os.path.exists(embeddings_path):
    os.makedirs(embeddings_path, exist_ok=True)
    print(f"'embeddings' directory created at {embeddings_path}.")

# Process embeddings for each split
for split in ["train", "test"]:
    split_size = len(train_ds_enc) if split == 'train' else len(test_ds_enc)
    file_path = f"{embeddings_path}/query_embeddings_{split}.npy"
    
    if os.path.exists(file_path):
        query_embeddings, embed_dim = load_npy_file(file_path)
        print(f"Precomputed embeddings for {split} split loaded successfully.")
        assert match_shape(query_embeddings, split_size), f"Embedding shape mismatch for {split} split."
        

    else:
        print(f"No precomputed embeddings found for {split} split. Generating embeddings.")
        query_embeddings = generate_embeddings(split, embedding_model)
        np.save(file_path, query_embeddings)
        print(f"Saved embeddings for {split} split at {file_path}.")


# Model

The logits for the model selection are computed using the following formula:
  $$z = \mathbf{W}_c \cdot \left( (\mathbf{P}(m_{\text{win}}) - \mathbf{P}(m_{\text{loss}})) \odot \mathbf{Q}(q) \right)$$
  Where:
  - $\mathbf{P}(m_{\text{win}})$ and $\mathbf{P}(m_{\text{loss}})$ represent learnable embeddings of the winning and losing models, respectively.
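  - $\mathbf{Q}(q)$ is the frozen query embedding generated by a sentence transformer, used for comparison with model embeddings. When its dimension differs from that of the model embeddings, it is first passed through a learnable down-projection.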
  - $\mathbf{W}_c$ is a linear classifier that maps the interaction between the model-query pair to a logit.
  - $\odot$ denotes the element-wise (Hadamard) product.
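
As a worked instance of the formula, the sketch below evaluates $z$ for tiny hand-picked vectors using plain Python lists. All numbers are invented, and the query embedding is assumed to already have the same dimension as the model embeddings, so no projection is applied here.

```python
def logit(w_c, p_win, p_loss, q):
    """z = W_c . ((P(m_win) - P(m_loss)) * Q(q)), with * the element-wise product."""
    return sum(w * (pw - pl) * qk for w, pw, pl, qk in zip(w_c, p_win, p_loss, q))

# Invented 3-dimensional example values.
w_c    = [1.0, -1.0, 0.5]   # classifier weights
p_win  = [0.6,  0.2, 0.4]   # embedding of the winning model
p_loss = [0.1,  0.4, 0.0]   # embedding of the losing model
q      = [1.0,  0.5, 2.0]   # query embedding

z = logit(w_c, p_win, p_loss, q)
z_swapped = logit(w_c, p_loss, p_win, q)

# Per-dimension contributions are 0.5, 0.1 and 0.4.
print(round(z, 6))          # 1.0
print(round(z_swapped, 6))  # -1.0
```

Swapping the two model embeddings flips the sign of the logit, which is why a single scalar is enough to express a preference between the two models.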

In [None]:
class MFGate(nn.Module):
      def __init__(
          self,
          num_models=2,
          model_dim=128,
          num_queries=len(train_ds['prompt']),
          text_dim=768,
          num_output_nodes=1,
          down_projection=True,
          embedding_path=None,
          ):
        
        super(MFGate, self).__init__()
        
        # P (Model embeddings)
        self.num_models = num_models
        self.P_dim = model_dim
        self.P = nn.Embedding(self.num_models, self.P_dim).requires_grad_(True)
        #self.P.weight.data.normal_(0, 1) # Initialize model embeddings with normal distribution (experimental)

        # Q (Prompt/Query embeddings)
        self.num_queries = num_queries
        self.Q_dim = text_dim
        self.Q = nn.Embedding(self.num_queries, self.Q_dim).requires_grad_(False) # Query embeddings are frozen; they must not be updated during training.

        if embedding_path is not None:
          self.checkpointfile = embedding_path
          try:
            embeddings = np.load(self.checkpointfile)
            self.Q.weight.data.copy_(torch.tensor(embeddings))
          except FileNotFoundError:
            print("Could not load prompt embeddings from file. Please check the path.")

        # W_proj (Q -> P) - Down-projection layer
        self.down_proj = down_projection
        if self.down_proj:
          self.downproj_QtoP = nn.Linear(self.Q_dim, self.P_dim, bias=False).requires_grad_(True)
          #self.downproj_QtoP.weight.data.normal_(0, 1) # Initialize down-projection layer with normal distribution (experimental)
          
        else:
          assert self.Q_dim == self.P_dim, "If not using projection: make sure Q_dim and P_dim match."
          self.downproj_QtoP = None

        # Linear classifier
        self.num_ways = num_output_nodes
        self.Classifier = nn.Linear(self.P_dim, self.num_ways, bias=False).requires_grad_(True)
        #self.Classifier.weight.data.normal_(0, 1) # Initialize classifier layer with normal distribution (experimental)

      def get_device(self):
        return next(self.parameters()).device

      def forward(self, model_win, model_loss, prompt, pe_noise=0):
        
        model_win = model_win.to(self.get_device())
        model_loss = model_loss.to(self.get_device())
        prompt = prompt.to(self.get_device())

        model_strong_embedding = self.P(model_win)
        model_strong_embedding = F.normalize(model_strong_embedding, p=2, dim=1)
        model_weak_embedding = self.P(model_loss)
        model_weak_embedding = F.normalize(model_weak_embedding, p=2, dim=1)
        prompt_embedding = self.Q(prompt)

        #if prompt_embedding.norm(p=2, dim=1).mean() > 1.001: # Uncomment if not already normalized.
          #prompt_embedding = F.normalize(prompt_embedding, p=2, dim=1) # Normalize the prompt embeddings if they are not already normalized.

        if pe_noise > 0:
          prompt_embedding += pe_noise * torch.randn_like(prompt_embedding)
        if self.down_proj:
          prompt_embedding = self.downproj_QtoP(prompt_embedding)
        return self.Classifier((model_strong_embedding - model_weak_embedding) * prompt_embedding).squeeze()

      @torch.no_grad()
      def predict(self, model_win, model_loss, prompt, pe_noise=0.05):
        logits = self.forward(model_win, model_loss, prompt, pe_noise)
        return logits > 0

# Training Framework

The model is trained using a **Binary Cross-Entropy with Logits Loss** function:
$$L = - \frac{1}{N} \sum_{i=1}^{N} \left[ y_i \cdot \log(\sigma(z_i)) + (1 - y_i) \cdot \log(1 - \sigma(z_i)) \right]$$
This loss applies a sigmoid activation of the form $P(\text{M}_{strong} | q) = \sigma(z) = \frac{1}{1 + e^{-z}}$
to the logits before computing the error term, where $y_i$ is the binary target label of the $i$-th comparison. Because every pair is fed to the model in (winner, loser) order, the target is $y_i = 1$ for all $i$, the second term vanishes, and the loss function essentially becomes:
$$L = - \frac{1}{N} \sum_{i=1}^{N} \left[ y_i \cdot \log(\sigma(z_i))\right]$$

The loss thus measures the error between the predicted logits and a target vector of ones, which turns the problem into a likelihood maximization task in which the model is encouraged to assign a higher score to the preferred model of each comparison than to the other one.
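
To see what this objective does numerically, here is a minimal sketch that evaluates the all-ones binary cross-entropy on a few invented logits using only the standard library. The function name `bce_with_logits_ones` is hypothetical; it relies on the identity $-\log(\sigma(z)) = \log(1 + e^{-z})$ and is meant as an illustration, not as a replacement for `nn.BCEWithLogitsLoss`.

```python
import math

def bce_with_logits_ones(logits):
    """Mean of -log(sigmoid(z)) over the batch, i.e. BCE-with-logits with every target set to 1."""
    # -log(sigmoid(z)) = log(1 + exp(-z)), written in a numerically stable form.
    return sum(max(-z, 0.0) + math.log1p(math.exp(-abs(z))) for z in logits) / len(logits)

# Invented logits: the loss shrinks as the winner is scored more confidently.
loss_confident = bce_with_logits_ones([5.0])
loss_undecided = bce_with_logits_ones([0.0])   # equals log(2)
loss_wrong = bce_with_logits_ones([-5.0])

print(loss_confident, loss_undecided, loss_wrong)
```

A logit of zero corresponds to a 50/50 prediction and costs $\log 2$; positive logits push the loss towards zero, while negative logits are penalized roughly linearly.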

In [None]:
class MFTrainer:
    def __init__(
            self,
            device,
            model,
            train_loader,
            val_loader,
            noise=0.0,
            optimizer=None,
            scheduler=None,
            early_stopping_patience=3
    ):
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.model = model.to(self.device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.train_criterion = nn.BCEWithLogitsLoss(reduction="mean")
        self.val_criterion = nn.BCEWithLogitsLoss(reduction="sum")
        self.noise = noise
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.best_val_loss = float('inf')
        self.early_stopping_patience = early_stopping_patience
        self.early_stopping_counter = 0
        self.train_losses = []
        self.val_losses = []
        self.weight_norms = []

    def _compute_weight_norm(self):
        weight_norm = 0.0
        for param in self.model.parameters():
            if param.requires_grad:
                weight_norm += param.data.norm().item() ** 2
        return weight_norm ** 0.5

    def train_epoch(self):
        """
        Run a single training pass.
        """
        self.model.train()
        running_loss = 0.0

        for models_a, models_b, prompts, labels in tqdm(self.train_loader, desc="Training...", leave=False):
            models_a, models_b, prompts, labels = self._move_to_device(models_a, models_b, prompts, labels)
            self.optimizer.zero_grad()         

            # Forward pass
            outputs = self.model(model_win=models_a, model_loss=models_b, prompt=prompts, pe_noise=0.0)
            # Compute loss
            loss = self.train_criterion(outputs, torch.ones_like(outputs)) # BCEWithLogitsLoss applies a sigmoid to the logits internally before computing the loss.
            # Backpropagate
            loss.backward()
            # Monitor gradient norms before clipping
            #torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            self.optimizer.step()

            running_loss += loss.item() * models_a.size(0)
        epoch_loss = running_loss / len(self.train_loader.dataset)
        weight_norm = self._compute_weight_norm()
        self.weight_norms.append(weight_norm)

        return epoch_loss

    def validate_epoch(self):
        """
        Run a single validation pass.
        """
        self.model.eval()
        running_loss = 0.0
        correct_predictions = 0
        total_samples = 0

        with torch.no_grad():
            for models_a, models_b, prompts, labels in tqdm(self.val_loader, desc="Validating...", leave=False):
                models_a, models_b, prompts, labels = self._move_to_device(models_a, models_b, prompts, labels)

                # Forward pass
                outputs = self.model(model_win=models_a, model_loss=models_b, prompt=prompts, pe_noise=0.05)
                # Compute loss
                loss = self.val_criterion(outputs, torch.ones_like(outputs))
                running_loss += loss.item()
                # Compute predictions and accuracy
                predictions = self.model.predict(models_a, models_b, prompts, pe_noise=0.1)
                correct_predictions += (predictions == torch.ones_like(predictions)).sum().item()
                total_samples += predictions.numel()

        epoch_loss = running_loss / len(self.val_loader.dataset)
        accuracy = correct_predictions / total_samples

        return epoch_loss, accuracy

    def train(self, num_epochs, save_best=1):
        """
        Main training loop that runs over multiple epochs.
        """
        for epoch in range(num_epochs):
            epoch_str = f"Epoch {epoch + 1}/{num_epochs}"
            
            train_loss = self.train_epoch()
            val_loss, val_acc = self.validate_epoch()
            
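            # With save_best enabled, _save_best_model checkpoints only when the validation loss
            # improves and otherwise increments the early-stopping counter; with save_best disabled,
            # every epoch is checkpointed unconditionally.
            self._save_best_model(val_loss, train_loss, epoch, num_epochs, best=bool(save_best))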

            if self.scheduler:
                self.scheduler.step(val_loss)

            self.train_losses.append(train_loss)
            self.val_losses.append(val_loss)

            if not self._early_stopping(val_loss):
                break

        self.plot_training_curve()
        self.plot_weight_norms()

    def plot_weight_norms(self):
        """
        Plot the L2 norm of the trainable weights over epochs to spot exploding or collapsing parameters.
        """
        epochs = len(self.weight_norms)
        plt.figure(figsize=(8, 5))
        plt.plot(range(epochs), self.weight_norms, label="Weight Norm")
        plt.title("Weight Norm over Epochs")
        plt.xlabel("Epochs")
        plt.ylabel("Weight Norm (L2 Norm)")
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.show()


    def _move_to_device(self, models_a, models_b, prompts, labels):
        models_a = models_a.to(self.device)
        models_b = models_b.to(self.device)
        prompts = prompts.to(self.device)
        labels = labels.to(self.device)
        return models_a, models_b, prompts, labels

    def _save_best_model(self, val_loss, train_loss, epoch, num_epochs, best=False):
        if best:
            if val_loss < self.best_val_loss:
                self.best_val_loss = val_loss
                checkpoint = {
                    'model': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'scheduler': self.scheduler.state_dict() if self.scheduler else None,
                    'epoch': epoch,
                    'train_loss': train_loss,
                    'val_loss': val_loss
                }
                torch.save(checkpoint, f"{proj_dir}/output/best_train.pth")
                print(f"[Best model saved at epoch {epoch + 1} with Validation Loss: {val_loss:.4f}]")
                self.early_stopping_counter = 0
            else:
                self.early_stopping_counter += 1
        else:
            # Save model unconditionally
            checkpoint = {
                'model': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'scheduler': self.scheduler.state_dict() if self.scheduler else None,
                'epoch': epoch,
                'train_loss': train_loss,
                'val_loss': val_loss
            }
            torch.save(checkpoint, f"{proj_dir}/output/best_train.pth")
            print(f"[Model saved at epoch {epoch + 1}] | Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")

    def _early_stopping(self, val_loss):
        """
        Helper function to handle early stopping logic.
        """
        if self.early_stopping_patience is None:
            return True  # Continue training

        if self.early_stopping_counter >= self.early_stopping_patience:
            print(f"\nNo improvement observed at patience level {self.early_stopping_patience}. Stopping early.")
            return False
        return True  # Continue training

    def plot_training_curve(self):
        """
        Plot the training and validation loss curves after training is complete.
        """
        epochs = len(self.train_losses)
        plt.figure(figsize=(8, 5))
        plt.plot(range(epochs), self.train_losses, label="Training Loss")
        plt.plot(range(epochs), self.val_losses, label="Validation Loss")
        plt.title("Loss over Epochs")
        plt.xlabel("Epochs")
        plt.ylabel("Loss")
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.show()

# --------------------------------------------------------------------------------------

model_dim = 128
embedding_path = f"{proj_dir}/output/embeddings/query_embeddings_train.npy"
query_embeddings = np.load(embedding_path)
num_queries = len(train_ds['prompt'])
assert query_embeddings.shape[0] == num_queries, f"Number of queries ({num_queries}) does not match number of embeddings ({query_embeddings.shape[0]})"

model = MFGate(
    num_models=len(MODEL_IDS),
    model_dim=model_dim,
    num_queries=num_queries,
    text_dim=768,
    num_output_nodes=1,
    down_projection=True,
    embedding_path=embedding_path
).cuda()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_dataset = BinaryClassificationDataset(train_ds_enc)
test_dataset = BinaryClassificationDataset(test_ds_enc)

batch_size = 64
train_spl_loader = BinaryClassificationDataset.get_dataloaders(train_dataset, batch_size=batch_size, shuffle=True)
val_spl_loader = BinaryClassificationDataset.get_dataloaders(test_dataset, batch_size=1024, shuffle=False)  # Shuffling is unnecessary for validation

lr = 3e-4
weight_decay = 1e-5
noise = 0.05 # Setting to 1 destroys query embedding information.
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = None
#scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min')
patience = None  # Set early stopping patience to avoid overfitting, set to None to disable

print(f"{model}\n\nNumber of parameters in the model: {sum(p.numel() for p in model.parameters())}")

trainer = MFTrainer(
    device=device,
    model=model,
    train_loader=train_spl_loader,
    val_loader=val_spl_loader,
    noise=noise,
    optimizer=optimizer,
    scheduler=scheduler,
    early_stopping_patience=patience
)

trainer.train(num_epochs=10)

In [None]:
model_path = f"{proj_dir}/output/best_train.pth"
checkpoint = torch.load(model_path)

model = MFGate(
    num_models=len(MODEL_IDS),
    model_dim=128,
    num_queries=len(train_ds['prompt']),
    text_dim=768,
    num_output_nodes=1,
    down_projection=True,
    embedding_path=f"{proj_dir}/output/embeddings/query_embeddings_train.npy"
).cuda()

model.load_state_dict(checkpoint['model'])
model.eval()
print("Model loaded successfully from best_train.pth")

# Decision Boundary Estimation

The proportion of calls routed to the strong model can be controlled through a parameter $\text{PCT}_{strong}$. From it, a cost threshold $\alpha$ is computed as the $(1-\text{PCT}_{strong})$ quantile of the distribution of router predictions (win rates): $$\alpha = Q(1-\text{PCT}_{strong})$$ Queries whose predicted win rate is at least $\alpha$ are routed to the strong model, so a larger $\text{PCT}_{strong}$ lowers the threshold and sends more calls to the LLM.
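
As a minimal sketch of the quantile rule above, the snippet below computes $\alpha$ on synthetic win rates and checks what fraction of queries ends up at the strong model. The helper names `threshold_for` and `fraction_routed_to_strong` and the Beta-distributed scores are assumptions made for illustration; they are not part of the notebook's pipeline.

```python
import numpy as np

# Hypothetical router outputs (win rates in [0, 1]); synthetic, not the notebook's data.
rng = np.random.default_rng(0)
win_rates = rng.beta(2.0, 2.0, size=1000)

def threshold_for(win_rates, strong_model_pct):
    """Return alpha = Q(1 - PCT_strong), the win-rate cutoff for the strong model."""
    return float(np.quantile(win_rates, 1.0 - strong_model_pct))

def fraction_routed_to_strong(win_rates, alpha):
    """Fraction of queries whose predicted win rate meets the cutoff."""
    return float(np.mean(win_rates >= alpha))

for pct in (0.0, 0.25, 0.5, 1.0):
    alpha = threshold_for(win_rates, pct)
    routed = fraction_routed_to_strong(win_rates, alpha)
    print(f"PCT_strong={pct:.2f}  alpha={alpha:.3f}  routed_to_strong={routed:.3f}")
```

Because routing uses `>=`, the threshold for $\text{PCT}_{strong} = 0$ equals the largest win rate, so the single top-scoring query is still sent to the strong model in this sketch.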

In [None]:
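def calculate_and_plot_threshold(win_rates, strong_model_pct):
    # Despite the name, this only computes the threshold:
    # the (1 - strong_model_pct) quantile of the predicted win rates.
    threshold = win_rates.quantile(1 - strong_model_pct)
    return threshold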

test_dataset = BinaryClassificationDataset(test_ds_enc)
test_loader = BinaryClassificationDataset.get_dataloaders(test_dataset, batch_size=batch_size, shuffle=True)

win_rates = []
count = 0
with torch.no_grad():
    for models_a, models_b, prompts, labels in tqdm(test_loader, desc="Calculating Win Rates...", leave=False):
        models_a, models_b, prompts, labels = models_a.cuda(), models_b.cuda(), prompts.cuda(), labels.cuda()

        outputs = model(models_a, models_b, prompts, pe_noise=0)
        outputs = torch.sigmoid(outputs).cuda()
        count += 1
        win_rates.extend(outputs.cpu().numpy())

tdf = pd.DataFrame(win_rates, columns=['mf_wr'])

strong_model_pcts = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
router_model = 'mf_wr'

threshold_results = pd.DataFrame({
    'strong_model_pct': strong_model_pcts,
    'threshold': [calculate_and_plot_threshold(tdf[router_model], pct) for pct in strong_model_pcts]
})

plt.figure(figsize=(8, 5))
plt.hist(tdf[router_model], bins=50, color='blue', alpha=0.7)
plt.title("Distribution of Win Rates")
plt.xlabel("Win Rate")
plt.ylabel("Frequency")
plt.grid(True)
plt.tight_layout()
plt.show()

display(threshold_results)
print(len(tdf))  # Number of scored samples (count * batch_size overcounts when the last batch is partial)

# Save model for inference

In [None]:
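inference_model = {
    # Detach and move to CPU so the file holds plain tensors rather than trainable Parameters.
    'P.weight': model.P.weight.detach().cpu(),
    'downproj_QtoP.weight': model.downproj_QtoP.weight.detach().cpu(),
    'Classifier.weight': model.Classifier.weight.detach().cpu()
}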

torch.save(inference_model, f"{proj_dir}/output/BUD-MFR-Binary-1.0.pth")
print("Inference model saved at ~/output/BUD-MFR-Binary-1.0.pth")

# Model (Inference pass)

In [None]:
class MFRB(nn.Module):
    def __init__(
            self,
            num_models=len(MODEL_IDS),
            model_dim=128,
            text_dim=768,
            num_output_nodes=1,
            down_projection=True
        ):
        super(MFRB, self).__init__()

        # P (Model embeddings)
        self.num_models = num_models
        self.P_dim = model_dim
        self.P = nn.Embedding(self.num_models, self.P_dim)

        # W_proj (Q -> P)
        self.down_proj = down_projection
        if self.down_proj:
            self.downproj_QtoP = nn.Linear(text_dim, model_dim, bias=False)
        else:
            assert text_dim == model_dim, "Q_dim must be equal to P_dim."
            self.downproj_QtoP = None

        # Linear classifier
        self.num_ways = num_output_nodes
        self.Classifier = nn.Linear(model_dim, self.num_ways, bias=False)

    def get_device(self):
        return next(self.parameters()).device

    def forward(self, model_ids, prompt_embed, noise):
        model_ids = torch.tensor(model_ids, dtype=torch.long).to(self.get_device())
        model_embed = self.P(model_ids)
        model_embed = F.normalize(model_embed, p=2, dim=1)

        if self.down_proj:
            prompt_embed = self.downproj_QtoP(prompt_embed)

        # Out-of-place add so the caller's embedding tensor is never mutated.
        prompt_embed = prompt_embed + torch.randn_like(prompt_embed) * noise
        #prompt_embed = F.normalize(prompt_embed, p=2, dim=1) - No need. Prompt embeddings are already normalized across rows.

        return self.Classifier(model_embed * prompt_embed).squeeze()

    @torch.no_grad()
    def predict_win_rate(self, model_win_idx, model_loss_idx, prompt_text_embed, noise):
        prompt = prompt_text_embed
        #prompt = embedding_model.encode(prompt_text, convert_to_tensor=True, normalize_embeddings=True).to(self.get_device())
        logits = self.forward(model_ids=[model_win_idx, model_loss_idx], prompt_embed=prompt, noise=noise)
        winrate = torch.sigmoid(logits[0] - logits[1]).item()
        return winrate # Returns confidence score in between [0,1]

In [None]:
model_dir = "./output/"

checkpoint_path=f"{model_dir}/BUD-MFR-Binary-1.0.pth"
inf_model = MFRB().to("cuda")

try:
    inf_model.load_state_dict(torch.load(checkpoint_path, weights_only=True))
    print(f"✅ BUD MFR Binary checkpoints successfully loaded from {checkpoint_path}")

except Exception as e:
    print(f"❌ Error loading BUD MFR Binary checkpoints from {checkpoint_path}: {str(e)}")
    raise RuntimeError(f"Failed to load the model from {checkpoint_path}.") from e

finally:
    inf_model.eval()
    print(inf_model)

# Compute metrics

In [None]:
def route(model_a, model_b, prompt, threshold):
    win_rate = inf_model.predict_win_rate(model_a, model_b, prompt, noise=0.05) # Jitter for improved generalization, set at 0.05
    if win_rate >= threshold:
        return "model_a"
    else:
        return "model_b"

def batch_process(dataset, threshold):
    true_positives = 0
    true_negatives = 0
    false_positives = 0
    false_negatives = 0
    model_a_calls = 0
    model_b_calls = 0
    win_rates = []
    true_labels = []

    for idx in tqdm(range(len(dataset)), desc="Processing Batches..."):
        model_a, model_b, prompt, label = dataset[idx]
        prompt = test_query_embeddings[prompt].cuda()
        win_rate = inf_model.predict_win_rate(model_a, model_b, prompt, noise=0.05)
        win_rates.append(win_rate)
        true_labels.append(label)
        
        model_prediction = "model_a" if win_rate >= threshold else "model_b"

        if model_prediction == "model_a":
            model_a_calls += 1
            if label == 1:
                true_positives += 1
            else:
                false_positives += 1
        else:
            model_b_calls += 1
            if label == 0:
                true_negatives += 1
            else:
                false_negatives += 1

    total_samples = true_positives + true_negatives + false_positives + false_negatives
    return (true_positives, true_negatives, false_positives, false_negatives, 
            total_samples, model_a_calls, model_b_calls, 
            np.array(win_rates), np.array(true_labels))

test_query_embeddings = torch.tensor(np.load("./output/embeddings/query_embeddings_test.npy"))
dataset = test_dataset

results_df = pd.DataFrame()
for spct, threshold in zip(threshold_results['strong_model_pct'], threshold_results['threshold']):
    metrics = batch_process(dataset, threshold)
    (true_positives, true_negatives, false_positives, false_negatives, 
     total_samples, model_a_calls, model_b_calls, win_rates, true_labels) = metrics
    
    # Accuracy
    accuracy = (true_positives + true_negatives) / total_samples
    
    # Positive class metrics
    precision_positive = np.nan if (true_positives + false_positives) == 0 else \
                        true_positives / (true_positives + false_positives)
    recall_positive = np.nan if (true_positives + false_negatives) == 0 else \
                     true_positives / (true_positives + false_negatives)
    
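    # Negative class metrics
    # Precision: of the samples routed to model_b, how many truly had label 0.
    precision_negative = np.nan if (true_negatives + false_negatives) == 0 else \
                        true_negatives / (true_negatives + false_negatives)
    # Recall: of the samples with label 0, how many were routed to model_b.
    recall_negative = np.nan if (true_negatives + false_positives) == 0 else \
                     true_negatives / (true_negatives + false_positives)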
    
    # F1 scores
    f1_score_positive = np.nan if (precision_positive + recall_positive) == 0 else \
                       2 * (precision_positive * recall_positive) / (precision_positive + recall_positive)
    f1_score_negative = np.nan if (precision_negative + recall_negative) == 0 else \
                       2 * (precision_negative * recall_negative) / (precision_negative + recall_negative)
    
    ratio_calls_a = model_a_calls / (model_a_calls + model_b_calls)
    
    confidence_mean = np.mean(win_rates)
    confidence_std = np.std(win_rates)

    result = pd.DataFrame([{
        'Strong Model Pct': spct,
        'Threshold': threshold,
        'Calls Model A (%)': ratio_calls_a,
        'Accuracy': accuracy,
        'Precision Positive': precision_positive,
        'Recall Positive': recall_positive,
        'F1 Score Positive': f1_score_positive,
        'Precision Negative': precision_negative,
        'Recall Negative': recall_negative,
        'F1 Score Negative': f1_score_negative,
        'Confidence Mean': confidence_mean,
        'Confidence Std': confidence_std,
    }])

    results_df = pd.concat([results_df, result], ignore_index=True)

display(results_df)
results_df.to_csv(f"{proj_dir}/output/threshold_results.csv", index=False)

In [None]:
results_df = pd.read_csv(f"{proj_dir}/output/threshold_results.csv")

plt.figure(figsize=(12, 8))

# Plot F1 scores
plt.subplot(2, 2, 1)
plt.plot(results_df['Strong Model Pct'], results_df['F1 Score Positive'], label='F1 Score Positive', marker='o')
plt.plot(results_df['Strong Model Pct'], results_df['F1 Score Negative'], label='F1 Score Negative', marker='o')
plt.xlabel('Strong Model PCT')
plt.ylabel('F1 Score')
plt.title('F1 Score vs Strong Model PCT')
plt.legend()
plt.grid(True)

# Plot Precision
plt.subplot(2, 2, 2)
plt.plot(results_df['Strong Model Pct'], results_df['Precision Positive'], label='Precision Positive', marker='o')
plt.plot(results_df['Strong Model Pct'], results_df['Precision Negative'], label='Precision Negative', marker='o')
plt.xlabel('Strong Model PCT')
plt.ylabel('Precision')
plt.title('Precision vs Strong Model PCT')
plt.legend()
plt.grid(True)

# Plot Recall
plt.subplot(2, 2, 3)
plt.plot(results_df['Strong Model Pct'], results_df['Recall Positive'], label='Recall Positive', marker='o')
plt.plot(results_df['Strong Model Pct'], results_df['Recall Negative'], label='Recall Negative', marker='o')
plt.xlabel('Strong Model PCT')
plt.ylabel('Recall')
plt.title('Recall vs Strong Model PCT')
plt.legend()
plt.grid(True)

# Plot Calls to Model A
plt.subplot(2, 2, 4)
plt.plot(results_df['Strong Model Pct'], results_df['Calls Model A (%)'], label='Calls Model A (%)', marker='o')
plt.xlabel('Strong Model PCT')
plt.ylabel('Calls Model A (%)')
plt.title('Calls to Model A vs Strong Model PCT')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.show()


# **Benchmark**

In [None]:
def load_and_concatenate_csvs(directory, prefix):
    # Sort the file names so the row order is deterministic (os.listdir order is arbitrary).
    csv_files = sorted(f for f in os.listdir(directory) if f.startswith(prefix) and f.endswith('.csv'))
    dfs = [pd.read_csv(os.path.join(directory, f)) for f in csv_files]
    concatenated_df = pd.concat(dfs, ignore_index=True)
    concatenated_df.rename(columns={"mistralai/Mixtral-8x7B-Instruct-v0.1": "mixtral-8x7b-instruct-v0.1"}, inplace=True)
    return concatenated_df

# Define benchmark datasets here. Format must adhere to that of the below csvs.
mmlu_df = load_and_concatenate_csvs('./evals/gemma2-llama3.2/', 'samples_mmlu')
gsm8k_df = load_and_concatenate_csvs('./evals/gemma2-llama3.2/', 'samples_gsm8k')

class EvalDataset(Dataset):
    def __init__(self, df):
        self.prompts = df['prompt'].tolist()

    def __len__(self):
        return len(self.prompts)

    def __getitem__(self, idx):
        return self.prompts[idx]

def get_dataloader(df, batch_size=64):
    dataset = EvalDataset(df)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    return dataloader

mmlu_loader = get_dataloader(mmlu_df)
gsm8k_loader = get_dataloader(gsm8k_df)

strong_model_idx = MODEL_IDS[strong_model]
weak_model_idx = MODEL_IDS[weak_model]

def generate_win_rates(dataloader, cache_path):
    if os.path.exists(cache_path):
        print(f"Prediction cache found at {cache_path}. Load existing cache? (y/n)")
        choice = input("Use precomputed cache? (y/n) : ").lower()
        if choice == 'y':
            print(f"Loading caches from {cache_path}...")
            return np.load(cache_path).tolist()
    
    win_rates = []
    with torch.no_grad():
        for prompts in tqdm(dataloader, desc="Generating Win Rates..."):
            prompt_embeddings = embedding_model.encode(prompts, convert_to_tensor=True, normalize_embeddings=True).cuda()
            for prompt_embedding in prompt_embeddings:
                win_rate = inf_model.predict_win_rate(strong_model_idx, weak_model_idx, prompt_embedding, noise=0.05)
                win_rates.append(win_rate)
    
    print(f"Caching win rates to {cache_path}...")
    np.save(cache_path, win_rates)
    return win_rates

mmlu_cache_path = f"{proj_dir}/output/mmlu_win_rates.npy"
gsm8k_cache_path = f"{proj_dir}/output/gsm8k_win_rates.npy"

mmlu_win_rates = generate_win_rates(mmlu_loader, mmlu_cache_path)
gsm8k_win_rates = generate_win_rates(gsm8k_loader, gsm8k_cache_path)
np.save(mmlu_cache_path, mmlu_win_rates)
np.save(gsm8k_cache_path, gsm8k_win_rates)

def calculate_thresholds(win_rates, strong_model_pcts):
    thresholds = [calculate_and_plot_threshold(pd.Series(win_rates), pct) for pct in strong_model_pcts]
    return thresholds

strong_model_pcts = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
mmlu_thresholds = calculate_thresholds(mmlu_win_rates, strong_model_pcts)
gsm8k_thresholds = calculate_thresholds(gsm8k_win_rates, strong_model_pcts)

# Make routing decisions based on the calculated win rates and the selected threshold
def make_routing_decisions(win_rates, thresholds, strong_model_pct):
    threshold = thresholds[strong_model_pcts.index(strong_model_pct)]
    routing_decisions = [strong_model if win_rate >= threshold else weak_model for win_rate in win_rates]
    return routing_decisions

# Compare the routing decisions against the ground truth labels
def compare_routing_decisions(df, routing_decisions):
    correct_decisions = 0
    for i, decision in enumerate(routing_decisions):
        if decision == strong_model:
            correct_decisions += df.iloc[i][strong_model]
        else:
            correct_decisions += df.iloc[i][weak_model]
    return correct_decisions / len(routing_decisions)

results = []
for spct, threshold in zip(strong_model_pcts, mmlu_thresholds):
    mmlu_routing_decisions = make_routing_decisions(mmlu_win_rates, mmlu_thresholds, spct)
    gsm8k_routing_decisions = make_routing_decisions(gsm8k_win_rates, gsm8k_thresholds, spct)

    mmlu_strong_calls_pct = mmlu_routing_decisions.count(strong_model) / len(mmlu_routing_decisions) * 100
    gsm8k_strong_calls_pct = gsm8k_routing_decisions.count(strong_model) / len(gsm8k_routing_decisions) * 100
    
    mmlu_correct_decisions = compare_routing_decisions(mmlu_df, mmlu_routing_decisions) * 100
    gsm8k_correct_decisions = compare_routing_decisions(gsm8k_df, gsm8k_routing_decisions) * 100

    results.append({
        'Strong Model Pct': spct,
        'MMLU': mmlu_correct_decisions,
        'MMLU Strong Calls (%)': mmlu_strong_calls_pct,
        'GSM8K': gsm8k_correct_decisions,
        'GSM8K Strong Calls (%)': gsm8k_strong_calls_pct,
    })

results_df = pd.DataFrame(results)
display(results_df)
results_df.to_csv(f"{proj_dir}/output/benchmark_results.csv", index=False)

plt.figure(figsize=(10, 6))
plt.plot(results_df['MMLU Strong Calls (%)']/100, results_df['MMLU'], label='Scores-MMLU', marker='o')
plt.plot(results_df['GSM8K Strong Calls (%)']/100, results_df['GSM8K'], label='Scores-GSM8K', marker='o')
plt.xlabel('Strong Model PCT')
plt.ylabel('Scores')
plt.title('Benchmark Scores vs Strong Model PCT')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

- Benchmark scores at Strong Model PCT 0 correspond to the scores of the Weaker Model
- Benchmark scores at Strong Model PCT 1 correspond to the scores of the Stronger Model

# Cost-Performance Analysis

In [None]:
def auc_metric(results_df, benchmark_name): # AUC calculation
    # np.trapezoid requires NumPy >= 2.0; on older versions use np.trapz instead.
    auc = np.trapezoid(results_df[f"{benchmark_name}"]/100, results_df[f"{benchmark_name} Strong Calls (%)"]/100)
    return auc

def pgr_metric(score, weak_score, strong_score): # PGR calculation
    return (score - weak_score) / (strong_score - weak_score)

def apgr_metric(pgr_df, results_df, benchmark_name): # APGR calculation
    apgr = np.trapezoid(pgr_df[f"{benchmark_name} PGR"], results_df[f"{benchmark_name} Strong Calls (%)"]/100)
    return apgr

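def pct_metric(pgr_df, benchmark_name, pgr): # CPT calculation @ different PGR values
    # Note: scores are read from the global results_df; pgr_df is kept only for call-site compatibility.
    weak_model_accuracy = results_df.loc[results_df['Strong Model Pct'] == 0.0, benchmark_name].values[0]
    strong_model_accuracy = results_df.loc[results_df['Strong Model Pct'] == 1.0, benchmark_name].values[0]
    target_performance = pgr * (strong_model_accuracy - weak_model_accuracy) + weak_model_accuracy
    # np.interp expects increasing x-coordinates, so order the points by score first.
    order = np.argsort(results_df[benchmark_name].to_numpy())
    scores = results_df[benchmark_name].to_numpy()[order]
    strong_calls = results_df[f"{benchmark_name} Strong Calls (%)"].to_numpy()[order]
    cpt = np.interp(target_performance, scores, strong_calls)
    return cpt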

# PGR Calculation
mmlu_weak_score = results_df.loc[results_df['Strong Model Pct'] == 0.0, 'MMLU'].values[0]
mmlu_strong_score = results_df.loc[results_df['Strong Model Pct'] == 1.0, 'MMLU'].values[0]

gsm8k_weak_score = results_df.loc[results_df['Strong Model Pct'] == 0.0, 'GSM8K'].values[0]
gsm8k_strong_score = results_df.loc[results_df['Strong Model Pct'] == 1.0, 'GSM8K'].values[0]

results_df['MMLU PGR'] = results_df['MMLU'].apply(pgr_metric, args=(mmlu_weak_score, mmlu_strong_score))
results_df['GSM8K PGR'] = results_df['GSM8K'].apply(pgr_metric, args=(gsm8k_weak_score, gsm8k_strong_score))

pgr_df = results_df[['Strong Model Pct', 'MMLU PGR', 'GSM8K PGR']]
display(pgr_df)
pgr_df.to_csv(f"{proj_dir}/output/pgr_results.csv", index=False)

metrics_df = pd.DataFrame({
    'Benchmark': ['MMLU', 'GSM8K'],
    'AUC': [auc_metric(results_df, 'MMLU'), auc_metric(results_df, 'GSM8K')],
    'APGR': [apgr_metric(pgr_df, results_df, 'MMLU'), apgr_metric(pgr_df, results_df, 'GSM8K')],
    'CPT @ 20% PGain': [pct_metric(pgr_df, 'MMLU', 0.2), pct_metric(pgr_df, 'GSM8K', 0.2)],
    'CPT @ 50% PGain': [pct_metric(pgr_df, 'MMLU', 0.5), pct_metric(pgr_df, 'GSM8K', 0.5)],
    'CPT @ 80% PGain': [pct_metric(pgr_df, 'MMLU', 0.8), pct_metric(pgr_df, 'GSM8K', 0.8)],
})

display(metrics_df)
metrics_df.to_csv(f"{proj_dir}/output/metrics_results.csv", index=False)

plt.figure(figsize=(10, 6))
plt.plot(pgr_df['Strong Model Pct'], pgr_df['MMLU PGR'], label='PGR-MMLU', marker='o')
plt.plot(pgr_df['Strong Model Pct'], pgr_df['GSM8K PGR'], label='PGR-GSM8K', marker='o')
plt.xlabel('Strong Model PCT')
plt.ylabel('PGR')
plt.title('PGR vs Strong Model PCT')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

### APGR

APGR (Average Performance Gap Recovered) measures how effectively an LLM router bridges the performance difference between a weaker, cheaper model and a stronger, more expensive one. The performance gap is the difference in output quality between the strong LLM (such as GPT-4) and the weak LLM/SLM (such as Mixtral-8x7B), usually measured on benchmark datasets or through human evaluation.

A router "recovers" this gap by deciding which model should handle each query. Ideally, it routes simple queries to the weak model and complex ones to the strong model, maximizing quality while minimizing cost. At a single operating point, the performance gap recovered (PGR) is $(\text{score} - \text{score}_{weak}) / (\text{score}_{strong} - \text{score}_{weak})$, so a PGR of 0 matches the weak model and a PGR of 1 matches the strong model.

APGR summarizes PGR across the full range of cost constraints. It is calculated as the area under the PGR curve, plotted against the fraction of calls sent to the strong model. A higher APGR means the router reaches strong-model quality with fewer strong-model calls, which translates into a better balance between cost efficiency and response quality.
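
To make the calculation concrete, here is a small worked sketch of PGR and APGR. The scores and call fractions are hypothetical values chosen for illustration, and `area_under_curve` is a hand-written trapezoidal rule used here so the snippet does not depend on a particular NumPy version.

```python
import numpy as np

def pgr(score, weak_score, strong_score):
    """Performance gap recovered: 0 matches the weak model, 1 matches the strong model."""
    return (score - weak_score) / (strong_score - weak_score)

def area_under_curve(y, x):
    """Trapezoidal rule over the points (x, y)."""
    y = np.asarray(y, dtype=float)
    x = np.asarray(x, dtype=float)
    return float(np.sum((y[1:] + y[:-1]) * np.diff(x)) / 2.0)

# Hypothetical benchmark scores (%) at increasing fractions of strong-model calls.
call_fraction = np.array([0.0, 0.25, 0.5, 0.75, 1.0])
router_scores = np.array([60.0, 70.0, 76.0, 79.0, 80.0])   # a router that helps
random_scores = np.array([60.0, 65.0, 70.0, 75.0, 80.0])   # straight line: random routing

weak_score, strong_score = router_scores[0], router_scores[-1]

apgr_router = area_under_curve(pgr(router_scores, weak_score, strong_score), call_fraction)
apgr_random = area_under_curve(pgr(random_scores, weak_score, strong_score), call_fraction)

print(f"APGR (router): {apgr_router:.4f}")
print(f"APGR (random): {apgr_random:.4f}")
```

In this toy setup the random-routing line yields an APGR of 0.5, and the router's curve sits above it, so values above 0.5 indicate that the router adds value over random assignment.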

### AUC

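AUC is the area under the call-performance curve: benchmark accuracy (scaled to [0, 1]) integrated over the fraction of calls sent to the strong model.

- AUC at 1: Maximum performance; every query is answered correctly at every call fraction.
- AUC above the random baseline: Suggests the router sends most queries to the appropriate model (strong or weak). A router that assigns models at random traces a straight line between the weak and strong scores, so its AUC is roughly the mean of the two scores.
- AUC close to the random baseline: Indicates the router performs similarly to a random routing strategy, showing limited effectiveness in differentiating between query types. On the normalized PGR scale, this corresponds to an APGR of about 0.5.
- AUC below the random baseline: Implies the router is making poor routing decisions, performing worse than random chance.
- AUC at 0: Lowest possible performance; no query is answered correctly at any call fraction.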
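
Under the definition used in the code of this notebook (accuracy integrated over the strong-call fraction), the raw AUC depends on the absolute scores of the two models. The sketch below, built on hypothetical scores, compares a router's AUC with that of a straight line between the weak and strong scores, and checks the identity $\text{AUC} = \text{score}_{weak} + (\text{score}_{strong} - \text{score}_{weak}) \cdot \text{APGR}$, which holds when the call fraction spans 0 to 1. The helper `area_under_curve` is an assumption introduced for this example.

```python
import numpy as np

def area_under_curve(y, x):
    """Trapezoidal rule over the points (x, y)."""
    y = np.asarray(y, dtype=float)
    x = np.asarray(x, dtype=float)
    return float(np.sum((y[1:] + y[:-1]) * np.diff(x)) / 2.0)

# Hypothetical accuracies in [0, 1] at increasing fractions of strong-model calls.
call_fraction = np.array([0.0, 0.25, 0.5, 0.75, 1.0])
router_acc = np.array([0.60, 0.70, 0.76, 0.79, 0.80])
weak_acc, strong_acc = router_acc[0], router_acc[-1]

# Random routing interpolates linearly between the weak and strong accuracy.
random_acc = weak_acc + (strong_acc - weak_acc) * call_fraction

auc_router = area_under_curve(router_acc, call_fraction)
auc_random = area_under_curve(random_acc, call_fraction)

# APGR is the same area computed on the normalized (PGR) curve.
apgr_router = area_under_curve((router_acc - weak_acc) / (strong_acc - weak_acc), call_fraction)
auc_from_apgr = weak_acc + (strong_acc - weak_acc) * apgr_router

print(f"AUC (router): {auc_router:.4f}")
print(f"AUC (random): {auc_random:.4f}")
print(f"AUC rebuilt from APGR: {auc_from_apgr:.4f}")
```

Comparing against the random baseline, or reporting APGR alongside AUC, makes results easier to compare across benchmarks whose weak and strong scores differ.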