In [10]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch

# Load models and tokenizers
model_name_a = "Salesforce/codet5-small"  # Replace with small LLM model for LLM-A
model_name_b = "google-t5/t5-small"  # Replace with small LLM model for LLM-B

# Load models
device = "mps" if torch.backends.mps.is_available() else "cpu"
tokenizer_a = AutoTokenizer.from_pretrained(model_name_a)
model_a = AutoModelForSeq2SeqLM.from_pretrained(model_name_a).to(device)

tokenizer_b = AutoTokenizer.from_pretrained(model_name_b)
model_b = AutoModelForSeq2SeqLM.from_pretrained(model_name_b).to(device)

# Define interaction functions
def generate_response(model, tokenizer, prompt, max_length=200):
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    outputs = model.generate(
        inputs["input_ids"], max_length=max_length, num_return_sequences=1
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def iterative_refinement(problem, max_iterations=5):
    iteration = 0
    solution = problem

    while iteration < max_iterations:
        print(f"\n--- Iteration {iteration + 1} ---")

        # LLM-A generates a response
        solution = generate_response(
            model_a, tokenizer_a, f"Task: {problem}\nSolution:", max_length=150
        )
        print(f"LLM-A's Response:\n{solution}")

        # LLM-B critiques the response
        critique = generate_response(
            model_b, tokenizer_b, f"Solution: {solution}\nCritique:", max_length=150
        )
        print(f"LLM-B's Critique:\n{critique}")

        # Evaluate stopping condition (can refine further logic here)
        if "no issues" in critique.lower():
            print("Consensus reached.")
            break

        # Prepare input for next iteration
        problem = critique  # Alternatively, merge critique with original task
        iteration += 1

    return solution


# Example usage
problem = "Write a Python function to calculate Fibonacci numbers recursively."
final_solution = iterative_refinement(problem)
print(f"\nFinal Solution:\n{final_solution}")


--- Iteration 1 ---
LLM-A's Response:
 def
LLM-B's Critique:
Solution: def Critique:

--- Iteration 2 ---
LLM-A's Response:
 def Critique_
Solution_Task :
LLM-B's Critique:
Solution: def Critique_ Solution_ Solution_Task : Critique:

--- Iteration 3 ---
LLM-A's Response:
Task: Critique:
Solution:
LLM-B's Critique:
: Task: Critique: Solution: Critique:

--- Iteration 4 ---
LLM-A's Response:
: Task
LLM-B's Critique:
Solution::

--- Iteration 5 ---
LLM-A's Response:
 public class
SolutionTask :
LLM-B's Critique:
public class SolutionTask : Critique:

Final Solution:
 public class
SolutionTask :
