In [None]:
!!pip install transformers





In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
from typing import List

class LLMClient:
    """Wrap a Hugging Face causal-LM checkpoint behind a single ``generate`` call.

    The tokenizer and model weights are loaded once in ``__init__`` and wrapped
    in a ``text-generation`` pipeline; ``generate`` returns only the newly
    generated continuation, not the prompt.

    NOTE(review): prompts are passed as raw text. Chat-tuned checkpoints may
    answer better if the prompt is formatted with the tokenizer's chat template —
    confirm before relying on output quality.
    """

    def __init__(
        self,
        model_name: str,
        *,
        device: "int | str" = 0,
        torch_dtype: "torch.dtype" = torch.float16,
    ):
        """Load ``model_name`` and build a text-generation pipeline for it.

        Args:
            model_name: Hub id or local path passed to ``from_pretrained``.
            device: Forwarded to ``transformers.pipeline`` (default ``0``, the
                value that was previously hard-coded).
            torch_dtype: dtype the model weights are loaded in (default float16).
        """
        self.model_name = model_name
        print(f"Loading model: {model_name}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch_dtype)
        self.pipeline = pipeline(
            "text-generation", model=self.model, tokenizer=self.tokenizer, device=device
        )

    def generate(
        self,
        prompt: str,
        max_tokens: int = 256,
        *,
        temperature: float = 0.7,
        top_k: int = 50,
        do_sample: bool = True,
    ) -> str:
        """Sample a continuation of ``prompt`` and return only the new text.

        Args:
            prompt: Raw input text.
            max_tokens: Maximum number of new tokens to generate.
            temperature: Sampling temperature.
            top_k: Top-k sampling cutoff.
            do_sample: Whether to sample (True) or decode greedily.

        Returns:
            The generated continuation, without the prompt prepended.
        """
        outputs = self.pipeline(
            prompt,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_k=top_k,
            do_sample=do_sample,
            # The pipeline defaults to returning prompt + completion. Without this,
            # every response would start with the whole prompt, and the MoA
            # aggregation prompt would grow by re-embedding it each layer.
            return_full_text=False,
        )
        return outputs[0]["generated_text"]


class MOA:
    """Mixture-of-Agents orchestrator.

    Each layer asks every proposer client for an answer, then asks a single
    aggregator client to synthesize those answers into one. The aggregator
    always sees the original question. Proposers in later layers see both the
    original question and the previous layer's synthesized answer, so the
    question is never lost between layers.
    """

    def __init__(self, proposer_clients: List["LLMClient"], num_layers: int):
        """Store the proposer clients and the number of propose/aggregate rounds.

        Args:
            proposer_clients: Objects exposing ``generate(prompt)`` and a
                ``model_name`` attribute (``LLMClient`` instances in this notebook).
            num_layers: Number of propose-then-aggregate rounds; must be >= 1.

        Raises:
            ValueError: If ``num_layers`` is less than 1. With zero layers
                ``run`` would return the question itself as the "answer".
        """
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.proposer_clients = proposer_clients
        self.num_layers = num_layers

    def generate_responses(self, prompt: str) -> List[str]:
        """Query every proposer with ``prompt`` and collect the successful answers.

        A proposer that raises is reported with ``print`` and skipped, so the
        returned list may be shorter than ``proposer_clients`` (or empty).
        """
        responses = []
        for client in self.proposer_clients:
            try:
                response = client.generate(prompt)
            except Exception as e:
                # Best effort: one failing model should not abort the whole run.
                print(f"Error generating response from {client.model_name}: {e}")
            else:
                responses.append(response)
        return responses

    def aggregate_responses(
        self,
        responses: List[str],
        aggregator: "LLMClient",
        query: "str | None" = None,
        max_tokens: int = 512,
    ) -> str:
        """Ask ``aggregator`` to synthesize ``responses`` into a single answer.

        Args:
            responses: Candidate answers, numbered in the prompt from 1.
            aggregator: Client whose ``generate`` produces the synthesis.
            query: The original user question. When given, it is included in
                the prompt so the aggregator knows what the responses answer.
                When omitted, the prompt is identical to the previous format.
            max_tokens: Generation budget for the synthesized answer.

        Returns:
            The aggregator's output text.
        """
        parts = ["You are tasked to synthesize these responses into a single, high-quality answer:\n\n"]
        if query is not None:
            parts.append(f"Question: {query}\n\n")
        parts.extend(f"Response {i}: {response}\n\n" for i, response in enumerate(responses, start=1))
        parts.append("Provide a refined and comprehensive response based on the above.")
        return aggregator.generate("".join(parts), max_tokens=max_tokens)

    @staticmethod
    def _build_refinement_prompt(query: str, previous_answer: str) -> str:
        """Return the proposer prompt for layers after the first: the question plus the last synthesis."""
        return (
            f"{query}\n\n"
            "Below is a synthesized answer from a previous round of models. "
            "Use it as a reference, correct any mistakes, and give your own improved answer:\n\n"
            f"{previous_answer}"
        )

    def run(self, initial_prompt: str, aggregator: "LLMClient") -> str:
        """Run ``num_layers`` propose/aggregate rounds and return the final synthesis.

        Raises:
            RuntimeError: If every proposer fails in some layer, leaving nothing
                to aggregate.
        """
        answer = None
        for layer in range(self.num_layers):
            print(f"Processing Layer {layer + 1}/{self.num_layers}...")

            # Layer 1 answers the question directly. Later layers refine the
            # previous synthesis while still seeing the original question.
            if answer is None:
                layer_prompt = initial_prompt
            else:
                layer_prompt = self._build_refinement_prompt(initial_prompt, answer)

            proposer_responses = self.generate_responses(layer_prompt)
            if not proposer_responses:
                raise RuntimeError(
                    f"All proposer models failed in layer {layer + 1}; nothing to aggregate."
                )

            answer = self.aggregate_responses(proposer_responses, aggregator, query=initial_prompt)

        return answer


In [None]:
# Proposer models queried in every MoA layer, instantiated (and loaded) in this order.
# NOTE(review): each LLMClient loads its full checkpoint onto a single device;
# these are very large models — confirm the hardware can hold all of them at once.
proposer_models = [
    LLMClient(name)
    for name in (
        "Qwen/Qwen1.5-110B-Chat",
        "Qwen/Qwen1.5-72B-Chat",
        "alpindale/WizardLM-2-8x22B",
        "mistralai/Mixtral-8x22B-v0.1",
        "databricks/dbrx-instruct",
    )
]

In [None]:
# Reuse the already-loaded Qwen1.5-110B-Chat proposer as the aggregator instead
# of loading a second full copy of the same checkpoint. Load it only if no
# proposer with that name exists.
AGGREGATOR_MODEL_NAME = "Qwen/Qwen1.5-110B-Chat"
aggregator = next(
    (client for client in proposer_models if client.model_name == AGGREGATOR_MODEL_NAME),
    None,
)
if aggregator is None:
    aggregator = LLMClient(AGGREGATOR_MODEL_NAME)

In [None]:
# Three propose-then-aggregate rounds over every proposer model.
moa = MOA(num_layers=3, proposer_clients=proposer_models)

In [None]:
# The user question answered by the Mixture-of-Agents pipeline.
initial_prompt = "What are the advantages and disadvantages of arrays?"
response = moa.run(initial_prompt, aggregator)

In [None]:
# Display the final synthesized answer.
print(response, end="\n")