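Large language models have transformed how businesses operate, delivering impressive performance on tasks ranging from natural language understanding to complex reasoning. Deploying them, however, usually means balancing performance against cost. Advanced models such as GPT-4o are highly accurate, but they are also expensive in both compute and money. This is a challenge for cost-sensitive applications, which must maintain quality while keeping spending under control.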

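## What Is an LLM Router?

An LLM router is a system that dynamically directs each query to the most suitable large language model based on the complexity of the task. It sends simpler queries to smaller, more cost-effective models and reserves the more powerful models for complex tasks, balancing performance against cost.

## When Do You Need an LLM Router?

You need an LLM router when you deploy LLMs in an application that must balance response quality against cost constraints. This matters most where query complexity varies widely, for example in chatbots, customer service systems, and other interactive AI solutions.

If every query is sent to a high-performance model such as GPT-4o, costs quickly become prohibitive. An LLM router is useful when you want to keep responses high quality without paying for a powerful model on every interaction. By routing each query to the most suitable model, the system lowers cost while maintaining an acceptable level of performance, which makes it a good fit for cost-sensitive applications that still need accurate and timely responses.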

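## How an LLM Router Works

An LLM router works by learning which kinds of queries are likely to get a good result when handled by the weaker model. During training, the router sees example queries together with the performance outcome of routing each one to the strong or the weak model. By analyzing these patterns, it learns to recognize the characteristics of queries that usually need the more powerful model to get a high-quality result.

When a new query arrives, the router uses what it has learned to predict how likely each model is to succeed. If the query resembles earlier queries that got better results from the strong model, the router sends it there. If the weaker model is likely to be sufficient, the query is routed to it instead. This dynamic decision process optimizes performance while controlling cost, so that each query is handled by the model that past experience suggests is the best fit.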
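The decision step described above can be sketched as a threshold on the classifier's output. This is a minimal illustration rather than the implementation used later in this notebook: the function name `route`, the labels `"strong"` and `"weak"`, and the default threshold of 0.5 are assumptions made for the example.

```python
def route(prob_strong: float, threshold: float = 0.5) -> str:
    """Pick a model from the predicted probability that the strong model is needed.

    prob_strong: classifier output in [0, 1] (hypothetical input for this sketch).
    threshold:   routing cut-off; lowering it sends more queries to the strong model.
    """
    return "strong" if prob_strong > threshold else "weak"


# Example: the same score can route differently depending on the threshold
for score in (0.10, 0.30, 0.90):
    print(score, route(score), route(score, threshold=0.2))
```

Lowering the threshold trades cost for quality: more queries reach the strong model, so fewer weak-model mistakes slip through, at a higher price per query.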

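## Training an LLM Router

We now focus on building a routing system for large language models by training a classifier that decides whether a query should be handled by a strong model (for example, GPT-4o) or by a weaker, more cost-effective model (for example, Mixtral-8x7B). To learn the routing decision, we use a dataset that contains both GPT-4o and Mixtral responses, in which each Mixtral answer is scored from 1 to 5 according to how closely it matches the GPT-4o response.

It may sound a little odd that we use GPT-4o to judge the quality of responses generated by GPT-4o and Mixtral, since GPT-4o is comparing its own answer with another model's. However, because we only ask the model to rate how closely the Mixtral response matches the GPT-4o response, I believe the risk of the model favoring its own response is small.

A score of 4 or higher means the weaker model's response is considered good enough, while a score below 4 indicates that the stronger model is needed. The code below trains a binary classifier that learns these routing patterns. Using PyTorch and the Sentence Transformers library, the model is trained to predict, from the labels derived from these match scores, whether a query should be routed to the weaker or the stronger model, with the aim of minimizing cost without sacrificing performance.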

In [1]:
import torch
import os
from torch import nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from sentence_transformers import SentenceTransformer
from sklearn.model_selection import train_test_split
from datasets import load_dataset
import wandb  # Import W&B
from tqdm import tqdm

# Initialize W&B
wandb.init(project="router")  # Set your project name

# Load the dataset from a local copy of the Hugging Face dataset wangrongsheng/gpt4_dataset
dataset = load_dataset("./gpt4_dataset")

# Convert the training data to pandas DataFrame for easier manipulation
train_df = dataset["train"].to_pandas()

# Define the scoring threshold for routing labels
train_df["routing_label"] = train_df["mixtral_score"].apply(lambda x: 0 if x >= 4 else 1)  # Binary classification labels

# Extract prompts and labels for training
sentences = train_df["prompt"].tolist()
labels = train_df["routing_label"].tolist()

# Split the data into training and validation sets
sentences_train, sentences_val, labels_train, labels_val = train_test_split(sentences, labels, test_size=0.2, random_state=42)

# Create a custom PyTorch dataset
class TrainingDataset(Dataset):
    def __init__(self, sentences, labels):
        self.sentences = sentences
        self.labels = labels

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        sentence = self.sentences[idx]
        label = self.labels[idx]
        return sentence, torch.tensor(label, dtype=torch.float)  # Use float for BCEWithLogitsLoss

# Create DataLoaders
train_data = TrainingDataset(sentences_train, labels_train)
val_data = TrainingDataset(sentences_val, labels_val)

train_loader = DataLoader(train_data, batch_size=4096, shuffle=True)  # Shuffle training batches each epoch
val_loader = DataLoader(val_data, batch_size=4096, shuffle=False)  # Validation order does not need shuffling

# Define the classifier model: a Sentence Transformers backbone followed by a small MLP head
class Classifier(nn.Module):
    def __init__(self, transformer_model_name):
        super(Classifier, self).__init__()
        self.transformer = SentenceTransformer(transformer_model_name)
        self.fc1 = nn.Linear(self.transformer.get_sentence_embedding_dimension(), 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 1)  # Single output neuron for binary classification
        self.relu = nn.ReLU()

    def forward(self, sentences):
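        # Generate embeddings in the forward pass; encode() runs without gradient tracking,
        # so only the fully connected layers below receive gradients
        embeddings = self.transformer.encode(sentences, convert_to_tensor=True)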
        x = self.relu(self.fc1(embeddings))
        x = self.relu(self.fc2(x))
        logits = self.fc3(x)  # Output single logit for binary classification
        return logits

# Initialize the classifier
model = Classifier(transformer_model_name='./all-distilroberta-v1')

# Use GPU if it's available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Define loss function and optimizer
criterion = nn.BCEWithLogitsLoss()  # Use BCEWithLogitsLoss for binary classification with one output neuron
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Number of epochs
n_epochs = 10

# Directory to save the best model
runs_dir = "runs"
os.makedirs(runs_dir, exist_ok=True)

# Initialize best validation loss with infinity
best_valid_loss = float('inf')

# Log hyperparameters to W&B (update the run's config instead of replacing the config object)
wandb.config.update({
    "learning_rate": 0.001,
    "epochs": n_epochs,
    "batch_size": 4096,
})

def validate(model, val_loader, criterion, device):
    """Run one pass over the validation set and return the mean batch loss and the accuracy."""
    model.eval()
    valid_loss = 0.0
    valid_correct = 0
    with torch.no_grad():
        for sentences, labels in tqdm(val_loader):
            sentences = list(sentences)
            labels = labels.to(device)

            # Forward pass
            outputs = model(sentences).squeeze(1)

            # Compute loss
            loss = criterion(outputs, labels)
            valid_loss += loss.item()

            # Threshold the sigmoid probability at 0.5 to get a 0/1 prediction
            predictions = torch.round(torch.sigmoid(outputs))
            valid_correct += (predictions == labels).sum().item()

    valid_loss /= len(val_loader)
    valid_accuracy = valid_correct / len(val_loader.dataset)

    return valid_loss, valid_accuracy

# Initial validation of the untrained model
initial_valid_loss, initial_valid_accuracy = validate(model, val_loader, criterion, device)
print(f'Initial Validation Loss: {initial_valid_loss:.4f}, Initial Validation Accuracy: {initial_valid_accuracy:.4f}')

Initial Validation Loss: 0.6845, Initial Validation Accuracy: 0.8638





In [2]:
wandb.log({
    "epoch": 0,
    "valid_loss": initial_valid_loss,
    "valid_accuracy": initial_valid_accuracy,
})

for epoch in range(n_epochs):
    # Training
    model.train()
    train_loss = 0.0
    train_correct = 0
    for sentences, labels in tqdm(train_loader):
        sentences = list(sentences)  # Convert tensor of strings back to list for transformer
        labels = labels.to(device)
        
        # Clear the gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(sentences).squeeze(1)  # Squeeze output to match shape [batch_size]
        
        # Compute loss
        loss = criterion(outputs, labels)
        
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        predictions = torch.round(torch.sigmoid(outputs))  # Convert logits to probabilities and then round to 0 or 1
        train_correct += (predictions == labels).sum().item()

    train_loss /= len(train_loader)
    train_accuracy = train_correct / len(train_loader.dataset)

    # Validation after each epoch
    valid_loss, valid_accuracy = validate(model, val_loader, criterion, device)

    # Log metrics to W&B
    wandb.log({
        "epoch": epoch + 1,
        "train_loss": train_loss,
        "train_accuracy": train_accuracy,
        "valid_loss": valid_loss,
        "valid_accuracy": valid_accuracy,
    })

    print(f'Epoch {epoch+1}/{n_epochs}, Training Loss: {train_loss:.4f}, Training Accuracy: {train_accuracy:.4f}, Validation Loss: {valid_loss:.4f}, Validation Accuracy: {valid_accuracy:.4f}')

    # Save the model if it's the best so far
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), os.path.join(runs_dir, 'best_model.pt'))

print('Training complete.')
wandb.finish()  # Finish the W&B run

Epoch 1/10, Training Loss: 0.5026, Training Accuracy: 0.8650, Validation Loss: 0.3888, Validation Accuracy: 0.8638
Epoch 2/10, Training Loss: 0.3741, Training Accuracy: 0.8650, Validation Loss: 0.3659, Validation Accuracy: 0.8638
Epoch 3/10, Training Loss: 0.3567, Training Accuracy: 0.8650, Validation Loss: 0.3554, Validation Accuracy: 0.8638
Epoch 4/10, Training Loss: 0.3489, Training Accuracy: 0.8650, Validation Loss: 0.3535, Validation Accuracy: 0.8638
Epoch 5/10, Training Loss: 0.3440, Training Accuracy: 0.8650, Validation Loss: 0.3535, Validation Accuracy: 0.8638
Epoch 6/10, Training Loss: 0.3394, Training Accuracy: 0.8650, Validation Loss: 0.3499, Validation Accuracy: 0.8638
Epoch 7/10, Training Loss: 0.3346, Training Accuracy: 0.8650, Validation Loss: 0.3486, Validation Accuracy: 0.8638

KeyboardInterrupt: 

wandb: 🚀 View run glamorous-snow-17 at: https://wandb.ai/wangrongsheng/router/runs/yvc5r9cg
wandb: Find logs at: wandb/run-20241007_224043-yvc5r9cg/logs


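Fine-tuning the LLM router means training a classifier that decides, based on the complexity of a query, whether it should be routed to the strong model or the weak model. Training starts by loading a labeled dataset in Hugging Face format that contains the queries, the responses from the strong and weak models, and the score given to each weak-model response. The dataset is converted to a pandas DataFrame for easier manipulation. Labels are created from a score threshold: queries on which the weak model scores high enough (4 or above) are labeled as suitable for that model, while lower scores indicate that the strong model is needed.

The data is then split into training and validation sets with train_test_split, so the model is trained on one part of the data and validated on another, which makes it possible to assess its performance on unseen data. To handle the data efficiently, a custom PyTorch Dataset class pairs each query with its label, and the DataLoader utility groups these pairs into batches that can be shuffled and processed during training.

The classifier is built on a Sentence Transformers backbone that generates an embedding for each input sentence. Because the forward pass calls `encode()`, which runs without gradient tracking, the backbone weights stay fixed in this code and only the layers on top of it are trained. The embeddings pass through a series of fully connected layers with ReLU activations, ending in a single output neuron that produces the logit for binary classification. The loss function is BCEWithLogitsLoss, which is well suited to binary classification tasks such as this routing decision.

In each epoch the model is trained on the training set to minimize the classification loss and improve accuracy. After each epoch its performance is evaluated on the validation set, which makes it possible to monitor how well it generalizes to new data. Metrics such as training and validation loss and accuracy are logged with Weights & Biases throughout, so the model's progress can be tracked and analyzed in real time.

During training, the model's state is saved whenever it achieves a lower validation loss than in any previous epoch. This checkpointing keeps the best version of the model and helps avoid using one that has overfit the training data. Training ends when all epochs have completed, at which point the Weights & Biases run is also finalized. In the run shown above, training was interrupted during epoch 8, so the saved checkpoint comes from the first seven epochs. The trained model can then be deployed in an LLM routing system, where it uses what it has learned to choose the best model for each query, balancing performance against cost.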

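## Evaluating the LLM Router's Performance

We evaluate routing effectiveness with two key metrics: performance gap recovered (PGR) and call-performance threshold (CPT). PGR measures how much of the performance gap between the strong and weak models the routing system recovers. For example, if GPT-4o reaches 100% accuracy and Mixtral-8x7B reaches 86%, a router that reaches 93% recovers half of the gap. The router can be tuned by adjusting the threshold that defines when a query is sent to the strong model, based on query complexity and the classifier's confidence.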
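The PGR arithmetic in the example above can be written as a short function. This is a sketch of the definition as described in the text; the function name `performance_gap_recovered` and its argument names are assumptions made for illustration.

```python
def performance_gap_recovered(router_acc: float, weak_acc: float, strong_acc: float) -> float:
    """Fraction of the weak-to-strong accuracy gap that the router recovers."""
    gap = strong_acc - weak_acc
    if gap <= 0:
        raise ValueError("strong_acc must be greater than weak_acc")
    return (router_acc - weak_acc) / gap


# Figures from the example in the text: strong 100%, weak 86%, router 93%
pgr = performance_gap_recovered(router_acc=0.93, weak_acc=0.86, strong_acc=1.00)
print(f"PGR = {pgr:.2f}")  # PGR = 0.50
```

A PGR of 0 means the router does no better than always using the weak model, and a PGR of 1 means it matches the strong model.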

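CPT quantifies the minimum percentage of queries that must be routed to the strong model to reach a desired PGR level. For example, CPT(50%) is the share of strong-model calls needed to recover half of the performance gap. A lower CPT value indicates a more efficient router, one that maintains high performance with fewer calls to the more expensive model. A performance/cost trade-off chart illustrates this balance by showing how accuracy responds to different degrees of reliance on the strong model. Decision makers can use the chart to find the best cost-saving strategy without giving up too much performance.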
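Read literally, this definition of CPT can be sketched as a scan over a performance/cost curve that returns the smallest strong-call percentage whose accuracy reaches the target. The function name, the `curve` format, and the toy numbers below are hypothetical and are not results from this notebook. Note also that the notebook code further down approximates CPT differently, by picking the bin whose accuracy is closest to the target.

```python
from typing import List, Optional, Tuple


def call_performance_threshold(
    curve: List[Tuple[float, float]],  # (percent of calls to strong model, accuracy)
    weak_acc: float,
    strong_acc: float,
    target_pgr: float,
) -> Optional[float]:
    """Smallest strong-call percentage whose accuracy reaches the target PGR."""
    target_acc = weak_acc + target_pgr * (strong_acc - weak_acc)
    for pct, acc in sorted(curve):  # scan from the cheapest setting upward
        if acc >= target_acc:
            return pct
    return None  # the curve never reaches the target


# Hypothetical curve, for illustration only
toy_curve = [(0, 0.86), (25, 0.90), (50, 0.94), (75, 0.97), (100, 1.00)]
print(call_performance_threshold(toy_curve, 0.86, 1.00, target_pgr=0.5))
print(call_performance_threshold(toy_curve, 0.86, 1.00, target_pgr=0.8))
```

With a coarse curve like this one, the result can only be as precise as the spacing between points, which is why the notebook evaluates 1000 bins.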

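The following notebook cell computes the CPT(50%) and CPT(80%) scores for the trained router and draws the performance/cost trade-off charts, which show how performance improves as more calls go to the strong model.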

In [2]:
import torch
import matplotlib.pyplot as plt
from torch import nn
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
import wandb

# Initialize WandB project
wandb.init(project="router_eval", name="CPT_Evaluation")

# Define the trained model class and load the model
class Classifier(nn.Module):
    def __init__(self, transformer_model_name):
        super(Classifier, self).__init__()
        self.transformer = SentenceTransformer(transformer_model_name)
        self.fc1 = nn.Linear(self.transformer.get_sentence_embedding_dimension(), 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 1)
        self.relu = nn.ReLU()


    def forward(self, sentences):
        embeddings = self.transformer.encode(sentences, convert_to_tensor=True)
        x = self.relu(self.fc1(embeddings))
        x = self.relu(self.fc2(x))
        return self.fc3(x)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Classifier('./all-distilroberta-v1')
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine
model.load_state_dict(torch.load('runs/best_model.pt', map_location=device))
model.to(device).eval()

# Load evaluation data
dataset = load_dataset("./gpt4_dataset")
eval_df = dataset["validation"].to_pandas()
sentences_eval = eval_df["prompt"].tolist()
labels_eval = eval_df["mixtral_score"].tolist()

def calculate_accuracy(predictions, labels):
    correct = 0
    for pred, label in zip(predictions, labels):
        if pred == 1:  # Routed to the strong model
            correct += 1  # Always considered correct
        elif pred == 0 and label >= 4:  # Routed to the weak model and label indicates correct
            correct += 1  # Correct if the label meets the threshold
    return correct / len(predictions) if predictions else 0

# Generate logits using the model
def generate_logits(model, sentences, labels):
    logit_buffer = []
    for sentence, label in zip(sentences, labels):
        with torch.no_grad():
            output = model([sentence]).squeeze(1)
            prob_strong = torch.sigmoid(output).item()
            logit_buffer.append((prob_strong, 1 - prob_strong, label))
    return logit_buffer

# Evaluate the model across bins
def evaluate_model_across_bins(logit_buffer, num_bins):
    bin_accuracies = []
    for pct in range(1, num_bins + 1):
        max_calls = int((pct / num_bins) * len(logit_buffer))
        sorted_buffer = sorted(logit_buffer, key=lambda x: x[0], reverse=True)
        predictions = [1 if i < max_calls else 0 for i in range(len(sorted_buffer))]
        true_labels = [lbl for _, _, lbl in sorted_buffer]
        accuracy = calculate_accuracy(predictions, true_labels)
        bin_accuracies.append((pct * 100 / num_bins, accuracy))
    return bin_accuracies

# Plot and log accuracies with matplotlib for 1000-bin charts
def plot_and_log_accuracies(bin_accuracies, title, log_name, target_accuracy=None, cpt=None):
    percentages, accuracies = zip(*bin_accuracies)  # Each entry is a (percentage, accuracy) pair
    plt.figure()
    plt.plot(percentages, accuracies, marker='o')
    plt.xlabel('% Calls to Strong Model')
    plt.ylabel('Accuracy')
    plt.title(title)
    plt.grid(True)

    # Add dashed lines for the target accuracy and the CPT value, if provided
    if target_accuracy is not None:
        plt.axhline(y=target_accuracy, color='r', linestyle='--', label='Target Accuracy')
    if cpt is not None:
        plt.axvline(x=cpt, color='g', linestyle='--', label=f'CPT Value ({cpt:.2f})')
        # Annotate the actual CPT value; the label needs a y position, so skip it without a target
        if target_accuracy is not None:
            plt.text(cpt, target_accuracy, f'{cpt:.4f}', color='g', fontsize=9, ha='right', va='bottom')

    plt.legend()
    plt.savefig(f"{log_name}.png")
    wandb.log({log_name: wandb.Image(f"{log_name}.png")})
    plt.close()

logit_buffer = generate_logits(model, sentences_eval, labels_eval)

bin_accuracies_1000 = evaluate_model_across_bins(logit_buffer, 1000)

# Find weak and strong model accuracies
weak_accuracy = calculate_accuracy([0] * len(labels_eval), labels_eval)
strong_accuracy = calculate_accuracy([1] * len(labels_eval), labels_eval)

# Calculate CPT values for 50% and 80% PGR
target_accuracy_50 = (strong_accuracy - weak_accuracy) * 0.5 + weak_accuracy
target_accuracy_80 = (strong_accuracy - weak_accuracy) * 0.8 + weak_accuracy

# Approximate each CPT as the bin whose accuracy is closest to the target accuracy
cpt_50 = min(bin_accuracies_1000, key=lambda x: abs(x[1] - target_accuracy_50))[0]
cpt_80 = min(bin_accuracies_1000, key=lambda x: abs(x[1] - target_accuracy_80))[0]

# Log CPT values
wandb.log({"CPT_50": cpt_50, "CPT_80": cpt_80})

# Plot and log the 1000-bin accuracy charts
plot_and_log_accuracies(bin_accuracies_1000, 'CPT 50 Evaluation (1000 Bins)', 'CPT 50 Chart', target_accuracy_50, cpt_50)
plot_and_log_accuracies(bin_accuracies_1000, 'CPT 80 Evaluation (1000 Bins)', 'CPT 80 Chart', target_accuracy_80, cpt_80)

wandb.finish()

CPT_50: 21.4
CPT_80: 49.0


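The resulting charts show the performance/cost trade-off for the CPT(50%) and CPT(80%) evaluations. The CPT(80%) chart indicates that about 49% of calls must go to the strong model to reach the target accuracy that recovers 80% of the performance gap between the strong and weak models. The CPT(50%) chart shows that about 21.4% of calls to the strong model are needed to reach the 50% recovery target. These results show the trade-off between using the strong model and reaching a desired performance level, and they indicate that a significant performance gain is possible without routing every query to the strong model.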
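To act on a CPT value at serving time, the call percentage has to be turned into a score threshold. The source does not say how its threshold is chosen, so the following is only one possible approach, assuming the router sends a query to the strong model when its score is strictly greater than the threshold. The function name and the sample scores are hypothetical.

```python
from typing import List


def threshold_for_call_budget(probs: List[float], strong_fraction: float) -> float:
    """Pick a threshold so roughly `strong_fraction` of these scores exceed it.

    probs:           router scores in (0, 1) from a held-out set (hypothetical here).
    strong_fraction: desired share of calls to the strong model, e.g. 0.214.
    """
    ranked = sorted(probs, reverse=True)
    k = int(strong_fraction * len(ranked))  # number of queries allowed to go strong
    if k <= 0:
        return 1.0  # no score is strictly greater than 1.0, so nothing routes strong
    if k >= len(ranked):
        return 0.0  # every score strictly greater than 0.0 routes strong
    return ranked[k]  # the top k scores are strictly above this (ties can reduce the count)


scores = [0.9, 0.1, 0.4, 0.7, 0.2]
alpha = threshold_for_call_budget(scores, strong_fraction=0.4)
routed_strong = [s for s in scores if s > alpha]
print(alpha, routed_strong)
```

The threshold should be estimated on data that resembles production traffic, since a shift in the score distribution changes the share of calls that exceed it.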

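## Evaluating Response Quality with Weave

To get a deeper view of how our models respond when the router is in use, we ran a Weave evaluation on the dataset. Weave is a tool for streamlining evaluation that provides a quick, intuitive way to visualize how models respond to a variety of queries.

While performance metrics are usually the main focus, Weave goes a step further and logs individual responses to an interactive dashboard. This makes it easy to compare responses side by side and to see how different models handle the same query. Inspecting specific responses in detail highlights each model's strengths and weaknesses and gives a clear view of what can be improved, which is the information machine learning practitioners need to improve their models effectively.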

In [2]:
import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer
import pandas as pd
import weave
from weave import Evaluation
import asyncio
from datasets import load_dataset
from weave import Dataset
import nest_asyncio  # Allows nested use of asyncio event loops

# Patch the current environment so asyncio.run() works inside an already running event loop (e.g. Jupyter)
nest_asyncio.apply()

# Define the classifier model (same architecture as the one used in training)
class Classifier(nn.Module):
    def __init__(self, transformer_model_name):
        super(Classifier, self).__init__()
        self.transformer = SentenceTransformer(transformer_model_name)
        self.fc1 = nn.Linear(self.transformer.get_sentence_embedding_dimension(), 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 1)  # Single output neuron for binary classification
        self.relu = nn.ReLU()

    def forward(self, sentences):
        embeddings = self.transformer.encode(sentences, convert_to_tensor=True)  # Generate embeddings
        x = self.relu(self.fc1(embeddings))
        x = self.relu(self.fc2(x))
        logits = self.fc3(x)  # Output single logit for binary classification
        return logits

# Sample alpha threshold for routing
alpha = 0.23591  # Adjust this value based on your routing needs

# Initialize the classifier model with the desired transformer
transformer_model_name = 'all-distilroberta-v1'  # Check that this name or path matches the one used in training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Classifier(transformer_model_name=transformer_model_name)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine
model.load_state_dict(torch.load('runs/best_model.pt', map_location=device))
model = model.to(device)

# Load dataset from Hugging Face and convert to pandas dataframe
dataset = load_dataset("gpt4_dataset")  # Check whether you need a local path or a remote dataset name here
val_df = dataset["validation"].to_pandas()

# Initialize Weave
weave.init('router-example')

# Define a scoring function that checks if the chosen response matches the expected one
@weave.op()
def match_score(expected: str, model_output: dict) -> dict:
    # Check if the chosen response matches the expected response
    return {'match': expected == model_output['generated_text']}

# Create evaluation examples directly from the dataframe for speed
examples = [
    {
        "prompt": row['prompt'],
        "expected": row['mixtral_response'] if row['mixtral_score'] >= 4 else row['gpt4_response'],
        "gpt4_response": row['gpt4_response'],
        "mixtral_response": row['mixtral_response'],
    }
    for _, row in val_df.head(100).iterrows()  # just evaluate 100 samples 
]

# Create a Dataset object with examples
dataset_obj = Dataset(name='gpt4_dataset_example', rows=examples)

@weave.op()
def run_inference(prompt: str, gpt4_response: str, mixtral_response: str) -> dict:
    model.eval()
    with torch.no_grad():
        # Forward pass through classifier to get routing score
        logits = model([prompt]).squeeze()
        score = torch.sigmoid(logits).item()  # Convert logit to probability score between 0 and 1

        # Decision logic based on score and alpha
        chosen_response = gpt4_response if score > alpha else mixtral_response

    # Return the chosen response
    return {
        'generated_text': chosen_response,
    }

# Create an evaluation object with examples and the scoring function
evaluation = Evaluation(dataset=dataset_obj, scorers=[match_score])

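# Run the evaluation asynchronously on the function.
# evaluate() returns a coroutine; nest_asyncio (applied above) lets asyncio.run()
# work even though the notebook already has a running event loop.
coroutine = evaluation.evaluate(run_inference)

try:
    result = asyncio.run(coroutine)
except RuntimeError:
    # Fallback if asyncio.run() still reports an existing event loop:
    # schedule the coroutine as a task on the current loop and wait for it to finish
    loop = asyncio.get_event_loop()
    task = loop.create_task(coroutine)
    result = loop.run_until_complete(task)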

print('Evaluation complete.')

🍩 https://wandb.ai/wangrongsheng/router-example/r/call/01926787-8764-72e3-9104-d95eee2e19df
Evaluation complete.


- https://wandb.ai/byyoung3/ML_NEWS3/reports/How-to-train-and-evaluate-an-LLM-router--Vmlldzo5MjU0MTA1
- https://github.com/lm-sys/RouteLLM