In [9]:
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

model_name = "facebook/bart-large-cnn"  # LLama LLM model I am using for the aggregator
model_name2 = "openai-community/gpt2"

# Create a text-generation pipeline using the loaded model
summarizer_pipeline = pipeline(
    "summarization",
    model=model_name
)

tokenizer = AutoTokenizer.from_pretrained(model_name2)  
model = AutoModelForCausalLM.from_pretrained(model_name2)  

# Create a text-generation pipeline using the loaded model
aggregator_pipeline = pipeline(
    "text-generation",
    model="openai-community/gpt2",
    tokenizer=tokenizer,
)

def aggregate_response(user_prompt: str, expert_outputs: list[str], max_length: int = 150, temperature: float = 0.7) -> str:
    """
    Aggregates the responses from multiple expert models into a coherent answer using a Llama-based model.

    Args:
        user_prompt (str): The original user prompt.
        expert_outputs (list[str]): A list of responses from expert models.
        max_length (int): The maximum length of the generated output.
        temperature (float): Sampling temperature for generation.

    Returns:
        str: A coherent aggregated answer.
    """
    # Combine expert responses into a single context string.
    expert_responses = "\n".join(
        f"Expert {i+1}: {response}"
        for i, response in enumerate(expert_outputs)
    )
    # summarize prompts
    summarization_prompt = f"""
Below are responses from multiple experts. Create at most one paragraph, comprehensive summary, that combines their key points.

Expert Responses:
{expert_responses}

[SUMMARY_START]:

"""

    summarized_text_model = summarizer_pipeline(
        summarization_prompt,
        max_length=max_length,
        temperature=temperature,
        do_sample=True 
    )

    summarized_text = summarized_text_model[0]['generated_text']

    if "[SUMMARY_START]" in summarized_text:
        summarized_text = summarized_text.split("[SUMMARY_START]")[1].strip()
    else:
      summarized_text = summarized_text[len(summarization_prompt):].strip()
    # Construct the aggregator prompt.
    aggregator_prompt = f"""
You are designed to analyze the given user query, and the summarization of responses of experts to the same query. Using this,
make sure you produce a meaningful and helpful response. Synthesize the information to produce a single, clear aggregated answer.

User Query:
{user_prompt}

Summarization of Expert Responses:
{summarized_text}

[ANSWER_START]
"""
    print(summarized_text)
    # Generate the aggregated answer using the Llama-based aggregator model.
    generated_output = aggregator_pipeline(
        aggregator_prompt,
        max_length=max_length,
        temperature=temperature,
        do_sample=True  # Set to False for deterministic output
    )

    # The output is a list of dictionaries; extract the generated text.
    aggregated_answer = generated_output[0]['generated_text']
    if "[ANSWER_START]" in aggregated_answer:
        aggregated_answer = aggregated_answer.split("[ANSWER_START]")[1].strip()
    else:
        # Fallback: take only the new generated text
        aggregated_answer = aggregated_answer[len(aggregator_prompt):].strip()

    return aggregated_answer

# Example usage:
if __name__ == "__main__":
    # Original user prompt
    user_query = "How can I optimize the performance of my deep learning models in production?"

    # Sample responses from expert models. The router should input these responses into the aggregator at a later step
    expert_responses = [
        "Consider using model quantization and pruning techniques to reduce model size and inference latency.",
        "Implementing efficient serving architectures such as TensorRT or ONNX Runtime can boost performance.",
        "Optimizing data pipelines and leveraging hardware accelerators like GPUs or TPUs will further improve throughput."
    ]

    final_answer = aggregate_response(user_query, expert_responses)
    print("Aggregated Answer:\n", final_answer)


Device set to use cpu
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Aggregated Answer:
 New PROMPT

New PROMPT

New PROMPT

New PROMPT

New PROMPT

New PROMPT

New PROMPT

New PROMPT

New PROMPT

New PROMPT

New
