In [None]:
import logging
from typing import Optional, Sequence, Dict, Union

from dotenv import load_dotenv
from pydantic import BaseModel
from trulens_eval.feedback import Feedback
from trulens_eval.schema.feedback import FeedbackResult
from trulens_eval import LiteLLM
from trulens_eval.feedback.provider.base import LLMProvider
from trulens_eval import Select

# Load variables from a local .env file into the process environment
# (presumably the Replicate / Azure credentials LiteLLM needs — confirm).
load_dotenv()

In [None]:
logging.basicConfig()
# Logger used by ProxyModel; INFO so pass/retry/escalation decisions are visible in the notebook.
logger = logging.getLogger("proxy_model")
logger.setLevel(logging.INFO)

In [None]:
class IterationFeedbackResults(BaseModel):
    """Scored record of a single completion attempt made by ProxyModel."""
    # True when no feedback with a non-None result scored below its threshold.
    passed: bool
    # Feedback name -> result of running that feedback on this attempt.
    feedback_scores: dict[str, FeedbackResult]
    # Data the feedbacks were evaluated on: "prompt", "messages", "response",
    # plus any extra keyword arguments passed through (e.g. "contexts" from rag_chat).
    source_data: dict[str, Optional[Union[Sequence[Dict], Sequence[str], str]]]

class ProxyModelResponse(BaseModel):
    """Outcome of a ProxyModel call.

    Holds the final response text, the scored record of every attempt in the
    order it was made, and whether the final response passed all thresholds.
    """
    response: str
    all_responses: list[IterationFeedbackResults]
    passed: bool

    def print_conversation(self, n_iter=-1):
        """Print the messages recorded for attempt ``n_iter`` and then its response.

        ``n_iter`` indexes ``all_responses`` (default: the last attempt). Role
        labels are wrapped in ANSI red escape codes. The response line is only
        printed when the recorded response is truthy.
        """
        data = self.all_responses[n_iter].source_data
        for msg in data['messages']:
            role = msg['role']
            content = msg['content']
            print(f"\x1b[31m{role}\x1b[0m: {content}")
        final_answer = data['response']
        if not final_answer:
            return
        print(f"\x1b[31massistant\x1b[0m: {final_answer}")
    

class ModelConfig(BaseModel):
    """One rung of the ProxyModel escalation ladder."""
    # Chat provider queried for completions (a LiteLLM instance in this notebook).
    provider: LLMProvider
    # Number of attempts made with this provider before escalating to the next config.
    n_retries: int = 1


In [None]:
class ProxyModel:
    """Cascade a chat completion through a ladder of models, escalating on poor feedback.

    Each ``ModelConfig`` is tried in order and gets up to ``n_retries`` attempts.
    After every attempt the response is scored by ``feedbacks``. If any score
    falls below its threshold, the feedback is appended to the conversation as
    a user message, and the next attempt (or the next model) is made. The first
    response that passes is returned. If none passes, the last response is
    returned with ``passed=False``.
    """

    def __init__(self, model_configs: list[ModelConfig], feedbacks: list[Feedback], feedback_thresholds: Optional[Dict[str, float]] = None):
        """
        Args:
            model_configs: Models to try, in escalation order. Must be non-empty,
                and every config needs ``n_retries >= 1``.
            feedbacks: Feedback functions used to score every response.
            feedback_thresholds: Minimum passing score per feedback name, each
                strictly between 0 and 1. Defaults to 0.5 for every feedback.

        Raises:
            TypeError: If a config or feedback has the wrong type.
            ValueError: If the configs or thresholds are empty, inconsistent, or out of range.
        """
        self.model_configs = model_configs
        self.feedbacks = feedbacks
        if not feedback_thresholds:
            feedback_thresholds = {feedback.name: .5 for feedback in self.feedbacks}
        self.feedback_thresholds = feedback_thresholds
        self._validate_args()

    def _validate_args(self):
        """Check constructor arguments.

        Uses explicit exceptions rather than ``assert`` so the checks survive ``python -O``.
        """
        if not self.model_configs:
            raise ValueError("model_configs must contain at least one ModelConfig")
        for model_config in self.model_configs:
            if not isinstance(model_config, ModelConfig):
                raise TypeError(f"expected ModelConfig, got {type(model_config).__name__}")
            # A config with no attempts could leave the completion loop without any
            # response/feedback to report.
            if model_config.n_retries < 1:
                raise ValueError(f"n_retries must be >= 1, got {model_config.n_retries}")

        if len(self.feedback_thresholds) != len(self.feedbacks):
            raise ValueError("feedback_thresholds must have exactly one entry per feedback")
        for feedback in self.feedbacks:
            if not isinstance(feedback, Feedback):
                raise TypeError(f"expected Feedback, got {type(feedback).__name__}")
            if feedback.name not in self.feedback_thresholds:
                raise ValueError(f"no threshold given for feedback {feedback.name!r}")
            threshold = self.feedback_thresholds[feedback.name]
            if not 0 < threshold < 1:
                raise ValueError(f"threshold for {feedback.name!r} must be in (0, 1), got {threshold}")

    def _feedback_results_as_user_response(self, feedback_results: IterationFeedbackResults, prompt: Optional[str]):
        """Render one attempt's feedback as a user message asking the model to try again.

        The message states whether the attempt passed. It then lists each feedback's
        score and any ``reason`` strings found in the metadata of that feedback's calls.
        It ends by repeating ``prompt`` when one was given.
        """
        # Kept out of the f-string: nesting the outer quote type inside an f-string
        # is a SyntaxError before Python 3.12.
        verdict = 'good response' if feedback_results.passed else 'bad response'
        response_buffer = f"Feedback: {verdict}"
        for name, result in feedback_results.feedback_scores.items():
            response_buffer += f"\n{name} score: {result.result}/1.0"
            reasons = [call.meta['reason'] for call in result.calls if "reason" in call.meta]
            # Skip the section entirely instead of emitting an empty bullet.
            if reasons:
                reasons_str = "\n - " + "\n - ".join(reasons)
                response_buffer += f"\nReasoning:\n{reasons_str}"
        response_buffer += "\nGiven this feedback, answer the prompt again."
        if prompt is not None:
            response_buffer += f" {prompt}"
        return response_buffer

    def _format_prompt(self, prompt: str, contexts: Sequence[str]):
        """Prepend ``contexts`` as a bulleted CONTEXT section to ``prompt``."""
        context_str = "\n - " + "\n - ".join(contexts)
        return f"Use the context to answer this prompt.\nCONTEXT: {context_str}\nPROMPT: {prompt}"

    def rag_chat(self, prompt: str, contexts: Optional[Sequence[str]] = None, messages: Optional[Sequence[Dict]] = None, **kwargs):
        """Answer ``prompt`` grounded in ``contexts``, escalating across models on poor feedback.

        ``contexts`` is also forwarded into each attempt's ``source_data`` so that
        feedbacks can select it.
        """
        if contexts:
            prompt = self._format_prompt(prompt, contexts)
        return self._create_chat_completion(prompt=prompt, messages=messages, contexts=contexts, **kwargs)

    def _create_chat_completion(
        self,
        prompt: Optional[str] = None,
        messages: Optional[Sequence[Dict]] = None,
        **kwargs
    ) -> ProxyModelResponse:
        """Run the retry/escalation loop.

        Args:
            prompt: Appended to the conversation as a user message when not None.
            messages: Prior conversation. It is copied and never mutated.
            **kwargs: Extra entries merged into every attempt's ``source_data``.

        Returns:
            The last response produced, every attempt's scored record, and whether
            that last response passed.
        """
        all_responses = []

        # Work on a private copy so the caller's list is not extended with our turns.
        messages = list(messages or [])
        if prompt is not None:
            messages.append({
                "role": "user",
                "content": prompt
            })

        for model_config in self.model_configs:
            engine = model_config.provider.model_engine
            for n_iter in range(1, model_config.n_retries + 1):
                response = model_config.provider._create_chat_completion(messages=messages)
                source_data = {
                    "prompt": prompt,
                    # Snapshot: `messages` keeps growing below, and this record must
                    # only reflect the conversation that produced this response.
                    "messages": list(messages),
                    "response": response,
                } | kwargs
                feedback_results = self._score_feedback(source_data)
                all_responses.append(feedback_results)
                messages.append({
                    "role": "assistant",
                    "content": response
                })

                if feedback_results.passed:
                    logger.info(f"({n_iter}/{model_config.n_retries}) {engine} passed.")
                    break

                logger.info(f"({n_iter}/{model_config.n_retries}) {engine} did not pass feedback thresholds. Escalating")
                messages.append({
                    "role": "user",
                    "content": self._feedback_results_as_user_response(feedback_results=feedback_results, prompt=prompt)
                })

            # `feedback_results` is always bound here: _validate_args guarantees n_retries >= 1.
            if feedback_results.passed:
                break
            logger.info(f"{engine} failed {model_config.n_retries} times. Escalating")

        return ProxyModelResponse(
            response=response,
            all_responses=all_responses,
            passed=feedback_results.passed
        )

    def _score_feedback(
        self,
        source_data: Dict
    ) -> IterationFeedbackResults:
        """Run every feedback on ``source_data`` and check the results against the thresholds.

        A feedback whose ``result`` is None does not cause a failure (presumably this
        happens when its evaluation produced no score; confirm this is the desired policy).
        """
        feedback_scores = {feedback.name: feedback.run(source_data=source_data) for feedback in self.feedbacks}
        passed = not any(
            feedback_scores[name].result is not None and feedback_scores[name].result < threshold
            for name, threshold in self.feedback_thresholds.items()
        )
        return IterationFeedbackResults(
            source_data=source_data,
            passed=passed,
            feedback_scores=feedback_scores
        )


In [None]:
# Escalation ladder, tried in order: two Replicate-hosted models, then an Azure GPT-4 deployment.
# Every rung gets two attempts before ProxyModel moves on to the next.
# NOTE(review): the Azure deployment name and APIM endpoint are hardcoded; consider the environment instead.
model_configs = [
    ModelConfig(provider=LiteLLM(model_name, **provider_kwargs), n_retries=2)
    for model_name, provider_kwargs in (
        ("replicate/mistralai/mistral-7b-instruct-v0.2", {}),
        ("replicate/meta/llama-2-70b-chat", {}),
        (
            "azure/sfc-ml-sweden-gpt4-managed",
            {"completion_kwargs": {"api_base": "https://sfc-apim-sweden.azure-api.net"}},
        ),
    )
]

In [None]:
# Smoke test: query the last provider in the ladder directly, outside ProxyModel.
top_provider = model_configs[-1].provider
top_provider._create_chat_completion("What is 2+2?")

In [None]:
# Judge model for the feedbacks: the same Azure GPT-4 deployment/endpoint as the top of the ladder.
feedback_provider = LiteLLM(
    "azure/sfc-ml-sweden-gpt4-managed",
    completion_kwargs={"api_base": "https://sfc-apim-sweden.azure-api.net"},
)
feedbacks = [
    # Scores how relevant the response is to the prompt stored in source_data.
    Feedback(feedback_provider.relevance_with_cot_reasons, name="answer_relevance")
    .on(Select.Tru.prompt)
    .on(Select.Tru.response),
    # Optional RAG feedbacks; they select `contexts`, which rag_chat forwards into source_data.
    # Feedback(feedback_provider.context_relevance_with_cot_reasons, name="context_relevance").on(Select.Tru.prompt).on(Select.Tru.contexts[:]),
    # Feedback(feedback_provider.groundedness_measure_with_cot_reasons, name="groundedness").on(Select.Tru.contexts[:]).on(Select.Tru.response)
]

In [None]:
# Use a strict 0.9 passing threshold for every feedback (ProxyModel's default is 0.5).
model = ProxyModel(
    model_configs=model_configs,
    feedbacks=feedbacks,
    feedback_thresholds=dict.fromkeys((feedback.name for feedback in feedbacks), .9),
)

In [None]:
# The contexts describe Franklin's life, but none states *why* he was born.
# Presumably this is chosen to exercise the retry/escalation path.
franklin_contexts = [
    "Benjamin Franklin FRS FRSA FRSE (January 17, 1706 [O.S. January 6, 1705][Note 1] – April 17, 1790) was an American polymath: a leading writer, scientist, inventor, statesman, diplomat, printer, publisher, and political philosopher.[1] Among the most influential intellectuals of his time, Franklin was one of the Founding Fathers of the United States; a drafter and signer of the Declaration of Independence; and the first postmaster general.[2]",
    "Franklin became a successful newspaper editor and printer in Philadelphia, the leading city in the colonies, publishing the Pennsylvania Gazette at age 23.[3] He became wealthy publishing this and Poor Richard's Almanack, which he wrote under the pseudonym 'Richard Saunders'.[4] After 1767, he was associated with the Pennsylvania Chronicle, a newspaper known for its revolutionary sentiments and criticisms of the policies of the British Parliament and the Crown.[5]",
    "Benjamin Franklin's father, Josiah Franklin, was a tallow chandler, soaper, and candlemaker. Josiah Franklin was born at Ecton, Northamptonshire, England, on December 23, 1657, the son of Thomas Franklin, a blacksmith and farmer, and his wife, Jane White. Benjamin's father and all four of his grandparents were born in England.",
]
resp = model.rag_chat(prompt="Why was Franklin born?", contexts=franklin_contexts)
resp.response

In [None]:
# Answer-relevance score of the very first attempt (first model, first try).
first_attempt = resp.all_responses[0]
first_attempt.feedback_scores["answer_relevance"].result

In [None]:
# Full record of the first attempt: pass flag, per-feedback results and the source data they scored.
resp.all_responses[0]

In [None]:
# Replay the conversation recorded for the last attempt (print_conversation defaults to n_iter=-1).
resp.print_conversation()