In [None]:
from typing import Iterable, cast

from datasets import load_dataset
from dotenv import load_dotenv
from pydantic import BaseModel

from intelligence_layer.connectors import (
    ArgillaEvaluation,
    DefaultArgillaClient,
    Field,
    LimitedConcurrencyClient,
    Question,
    RecordData,
)
from intelligence_layer.core import (
    CompleteOutput,
    Instruct,
    InstructInput,
    LuminousControlModel,
)
from intelligence_layer.evaluation import (
    AggregationLogic,
    ArgillaAggregator,
    ArgillaEvaluationLogic,
    ArgillaEvaluationRepository,
    ArgillaEvaluator,
    Example,
    InMemoryAggregationRepository,
    InMemoryDatasetRepository,
    InMemoryEvaluationRepository,
    InMemoryRunRepository,
    RecordDataSequence,
    Runner,
    SuccessfulExampleOutput,
)

# Load the variables defined in the local .env file into the process
# environment (the setup notes below expect the Aleph Alpha token and the
# Argilla settings to live there).
load_dotenv()

# Shared client that is handed to the model further down.
# NOTE(review): presumably picks up the token from the environment — confirm.
client = LimitedConcurrencyClient.from_env()

# Human Evaluation using the Intelligence Layer

Although there are a variety of ways to automate the evaluation of LLM-based tasks, sometimes it is still necessary to get a human opinion.
To make this as painless as possible, we have integrated an [Argilla](https://argilla.io/)-Evaluator into the intelligence layer.
This notebook serves as a quick start guide.

## Environment setup
This notebook expects that you have added your Aleph Alpha token to your .env file.
Additionally, you need to add the `ARGILLA_API_URL` and `ARGILLA_API_KEY` from `env.sample` to your `.env` file.
Next, run

```bash
docker-compose up -d
``` 

from the intelligence layer base directory.

Once you go to `localhost:6900` and you are prompted to enter a username and password, use:
- username: `argilla`
- password: `1234`

## Dataset Repository definition
First, we need to define our dataset. Here we use an [Instruction Dataset](https://huggingface.co/datasets/HuggingFaceH4/instruction-dataset?row=0) from Hugging Face. Before we can use it for human evaluation, we need to turn it into an Intelligence Layer dataset repository.

In [None]:
# Load the "test" split of the instruction dataset from the Hugging Face Hub.
# The repository id uses the canonical spelling of the organisation name
# ("HuggingFaceH4"), the same spelling that appears in the dataset URL.
dataset = load_dataset("HuggingFaceH4/instruction-dataset")["test"]

Let us explore the dataset a bit. It consists of prompts, example completions and metadata for 327 examples. Since we are doing human eval, for now we only need the prompt and corresponding id.

In [None]:
# Print the dataset object itself for a quick overview.
print(dataset)

# Show which metadata keys are available, using the first entry as a sample.
first_meta = dataset["meta"][0]
print(first_meta.keys())

We could now build a single example like this:

In [None]:
# Build one example by hand from the first entry of the dataset:
# the prompt becomes the instruction, the metadata id becomes the example id.
first_prompt = dataset["prompt"][0]
first_id = str(dataset["meta"][0]["id"])

example = Example(
    input=InstructInput(instruction=first_prompt, input=None),
    expected_output=None,  # human evaluation needs no reference answer
    id=first_id,
)

For our dataset repository, we can either use a FileDatasetRepository or an InMemoryDatasetRepository.

In [None]:
# Number of examples that are sent to human evaluation.
num_examples = 5
assert num_examples <= len(dataset)

# Slice the dataset once. Indexing `dataset["prompt"][i]` inside the
# comprehension may re-read the whole column for every single example.
subset = dataset[:num_examples]

dataset_repository = InMemoryDatasetRepository()
dataset_id = dataset_repository.create_dataset(
    examples=[
        Example(
            input=InstructInput(instruction=prompt, input=None),
            expected_output=None,  # human evaluation needs no reference answer
            id=str(meta["id"]),
        )
        for prompt, meta in zip(subset["prompt"], subset["meta"])
    ],
    dataset_name="human-evaluation-dataset",
).id

## Task Setup

We use an Instruction task to run the examples in the Instruct dataset.
In addition, we define a `Runner` to generate the completions from the model for our dataset
and a `RunRepository` to save the results.

In [None]:
# Repository that receives the outputs produced by the run.
run_repository = InMemoryRunRepository()

# The task under evaluation: an instruct task backed by a control model.
model = LuminousControlModel(name="luminous-base-control", client=client)
task = Instruct(model=model)

# Run the task on every example of the dataset created above.
runner = Runner(task, dataset_repository, run_repository, "Instruct")
run_overview = runner.run_dataset(dataset_id)

## Evaluator Definition


At the end of our evaluation, we want a float score $s \in [1,5]$ for each question, describing the model performance, together with the number of evaluated examples.
We define this as `InstructAggregatedEvaluation`.

In [None]:
class InstructAggregatedEvaluation(BaseModel):
    """Aggregated result of the human evaluation of the instruct task."""

    # Mean of the submitted "general_rating" responses;
    # None when no evaluations were submitted.
    general_rating: float | None
    # Mean of the submitted "fluency" responses;
    # None when no evaluations were submitted.
    fluency: float | None
    # Number of evaluations that went into the two means.
    evaluated_examples: int

![Argilla Interface](../../assets/argilla_interface.png)

In the Argilla UI, our model input (Instruction) and output (Model Completion) will be shown on the left hand side.
They are defined below using the `fields` list.
Note that the field names have to match the content keys from the `RecordData` which we will define later in our `InstructArgillaEvaluationLogic`.

On the right side of the UI, the rating interface will be shown.
It is used to serve a number of questions that can be rated by the user.
Currently, only integer scales are accepted.
The `name` property will later be used to access the human ratings in the aggregation step.

In [None]:
# Questions shown in the rating panel of the Argilla UI. Each one is answered
# on an integer scale from 1 to 5; `name` is the key under which the answer
# is later read in the aggregation step.
questions = [
    Question(name=name, title=title, description=description, options=range(1, 6))
    for name, title, description in (
        ("general_rating", "Rating", "Rate this completion on a scale from 1 to 5"),
        ("fluency", "Fluency", "How fluent is the completion?"),
    )
]

# Fields shown next to the questions. Their names have to match the content
# keys of the `RecordData` built in the evaluation logic.
fields = [
    Field(name=name, title=title)
    for name, title in (
        ("input", "Instruction"),
        ("output", "Model Completion"),
    )
]

We can now define our `InstructArgillaEvaluationLogic` and `InstructArgillaAggregationLogic`.
They have to implement the two abstract methods `_to_record` and `aggregate` respectively.
Let's look at the documentation:

In [None]:
# Print the documentation of the two abstract methods implemented below.
for abstract_method in (
    AggregationLogic.aggregate,
    ArgillaEvaluationLogic._to_record,
):
    help(abstract_method)

In [None]:
class InstructArgillaAggregationLogic(
    AggregationLogic[ArgillaEvaluation, InstructAggregatedEvaluation]
):
    """Average the human ratings collected in Argilla into a single summary."""

    @staticmethod
    def _mean_response(evaluations: list[ArgillaEvaluation], question: str) -> float:
        """Return the mean of the responses given to one question.

        Args:
            evaluations: Non-empty list of submitted evaluations.
            question: Name of the question whose responses are averaged.

        Returns:
            The arithmetic mean of the responses to `question`.
        """
        return sum(
            cast(float, evaluation.responses[question]) for evaluation in evaluations
        ) / len(evaluations)

    def aggregate(
        self,
        evaluations: Iterable[ArgillaEvaluation],
    ) -> InstructAggregatedEvaluation:
        """Compute the mean rating per question over all submitted evaluations.

        Args:
            evaluations: The human evaluations to aggregate.

        Returns:
            The mean "general_rating" and "fluency" together with the number of
            evaluations. Both means are `None` when nothing was submitted.
        """
        # Materialise the iterable: it is traversed once per question and counted.
        evaluations = list(evaluations)

        if not evaluations:  # nothing was submitted, so there is nothing to average
            return InstructAggregatedEvaluation(
                general_rating=None,
                fluency=None,
                evaluated_examples=0,
            )

        return InstructAggregatedEvaluation(
            general_rating=self._mean_response(evaluations, "general_rating"),
            fluency=self._mean_response(evaluations, "fluency"),
            evaluated_examples=len(evaluations),
        )


class InstructArgillaEvaluationLogic(
    ArgillaEvaluationLogic[InstructInput, CompleteOutput, None]
):
    """Turn an instruct example and its completion into an Argilla record."""

    def _to_record(
        self,
        example: Example[InstructInput, None],
        *example_outputs: SuccessfulExampleOutput[CompleteOutput],
    ) -> RecordDataSequence:
        """Build the record that is shown to the human evaluator.

        Args:
            example: The example whose instruction is displayed.
            *example_outputs: Outputs produced for the example; only the first
                one is used.

        Returns:
            A sequence holding one record with the content keys "input" and
            "output".
        """
        instruction = example.input.instruction
        completion = example_outputs[0].output.completion

        record = RecordData(
            content={"input": instruction, "output": completion},
            example_id=example.id,
        )
        return RecordDataSequence(records=[record])


# Connection to Argilla and the workspace that will hold the records.
argilla_client = DefaultArgillaClient()
workspace_id = argilla_client.ensure_workspace_exists("test")

# Evaluation repository: an in-memory repository wrapped by the Argilla one,
# configured with the fields and questions defined above.
evaluation_repository = InMemoryEvaluationRepository()
argilla_evaluation_repository = ArgillaEvaluationRepository(
    evaluation_repository, argilla_client, workspace_id, fields, questions
)

# Pieces needed later for the aggregation step.
aggregation_repository = InMemoryAggregationRepository()
aggregation_logic = InstructArgillaAggregationLogic()

# Evaluator that converts run outputs into Argilla records.
eval_logic = InstructArgillaEvaluationLogic()
evaluator = ArgillaEvaluator(
    dataset_repository,
    run_repository,
    argilla_evaluation_repository,
    "instruct",
    eval_logic,
)

The `evaluate_runs` method posts the records created from a run to the Argilla instance.

In [None]:
# Posting the records is best-effort: on failure the error is reported and
# `eval_overview` is set to None so that later cells can skip the aggregation.
try:
    # Only the call that can actually fail lives inside the `try` body.
    eval_overview = evaluator.evaluate_runs(run_overview.id)
except Exception as e:
    eval_overview = None
    print(str(e))
else:
    print(eval_overview)

Once we have evaluated the examples in the Argilla UI, we can retrieve the evaluation results via the `ArgillaAggregator`.

In [None]:
# Aggregator that reads the human ratings back from Argilla and applies the
# aggregation logic defined above.
aggregator = ArgillaAggregator(
    argilla_evaluation_repository,
    aggregation_repository,
    "instruct",
    aggregation_logic,
)

# `eval_overview` is None when posting the records failed; compare against
# that sentinel explicitly instead of relying on the object's truthiness.
if eval_overview is not None:
    output = aggregator.aggregate_evaluation(eval_overview.id)
    print(output.statistics)