In [48]:
# imports
import json

import transformers
import torch

from transformers import BitsAndBytesConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

from lmformatenforcer import JsonSchemaParser
from lmformatenforcer.integrations.transformers import (
    build_transformers_prefix_allowed_tokens_fn,
)

from deepeval.test_case import LLMTestCase
from deepeval import evaluate
from deepeval.models import DeepEvalBaseLLM
from deepeval.metrics import (HallucinationMetric, 
                              FaithfulnessMetric, 
                              BiasMetric,
                              ToolCorrectnessMetric
                              )

from pydantic import BaseModel

In [63]:
# define custom llm class for deepeval
class CustomLlama3_8B(DeepEvalBaseLLM):
    def __init__(self, model_path: str = None):

        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )

        model_4bit = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="auto",
            quantization_config=quantization_config,
        )
        tokenizer = AutoTokenizer.from_pretrained(
            model_path
        )

        self.model = model_4bit
        self.tokenizer = tokenizer
        self.model_name = model_path.split("/")[-1]

    def load_model(self):
        return self.model

    def generate(self, prompt: str, schema: BaseModel) -> BaseModel:
        model = self.load_model()

        pipeline = transformers.pipeline(
            "text-generation",
            model=model,
            tokenizer=self.tokenizer,
            use_cache=True,
            device_map="auto",
            max_length=2500,
            do_sample=True,
            top_k=5,
            num_return_sequences=1,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.eos_token_id,
        )

        # Create parser required for JSON confinement using lmformatenforcer
        parser = JsonSchemaParser(schema.model_json_schema())
        prefix_function = build_transformers_prefix_allowed_tokens_fn(
            pipeline.tokenizer, parser
        )

        # Output and load valid JSON
        output_dict = pipeline(prompt, prefix_allowed_tokens_fn=prefix_function)
        output = output_dict[0]["generated_text"][len(prompt):]
        json_result = json.loads(output)

        return schema(**json_result)

    async def a_generate(self, prompt: str, schema: BaseModel) -> BaseModel:
        return self.generate(prompt, schema)

    def get_model_name(self):
        name = self.model_name
        return ' '.join(name.split('-'))

class Schema(BaseModel):
    answer: str

In [64]:
model_path = "meta-llama/Meta-Llama-3-8B-Instruct"
llama = CustomLlama3_8B(model_path=model_path)
schema = Schema

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [66]:
input='Tell me about TAF13.'
actual_output = llama.generate(input, schema)
print(actual_output)

answer='TAF13 is a transcriptional coactivator that plays a crucial role in regulating gene expression during development and tissue homeostasis. It is a subunit of the transcriptional coactivator complex TFIID, which is essential for the transcription of most protein-coding genes. TAF13 has been implicated in various biological processes, including cell proliferation, differentiation, and survival. Mutations in the TAF13 gene have been linked to various human diseases, including cancer and developmental disorders. Research on TAF13 has shed light on its importance in maintaining cellular homeostasis and its potential as a therapeutic target for treating diseases.'


In [67]:
# define context
context = ["TAF13, or TATA-Box Binding Protein Associated Factor 13, is a protein that is encoded by the TAF13 gene in humans.",
           "It is a subunit of the transcription initiation factor TFIID",
           "TAF13 is involved in RNA polymerase II transcription initiation and promoter clearance: TAF13 is part of the TFIID complex,which plays a major role in the initiation of transcription that is dependent on RNA polymerase II.",
           "TAF13 is involved in gene expression.",
           "TAF13 is involved in DNA-binding transcription factor activity."]

# define the metric
test_case = LLMTestCase(
    input=input,
    actual_output=actual_output,
    context=context
)

metric = HallucinationMetric(threshold=0.5, model=llama)

metric.measure(test_case)
print(metric.score)
print(metric.reason)

# or evaluate test cases in bulk
evaluate([test_case], [metric])

Output()

0.0
The score is 0.00 because the actual output is consistent with the provided context, indicating a high degree of factual accuracy and no hallucinations detected. The actual output agrees with multiple statements in the context, including TAF13's involvement in RNA polymerase II transcription initiation, being part of the TFIID complex, and being involved in gene expression and DNA-binding transcription factor activity. The consistency between the actual output and context suggests that the model has accurately extracted the relevant information and has not introduced any hallucinations. Therefore, the hallucination score is 0.00, indicating a perfect score for factual accuracy.


Event loop is already running. Applying nest_asyncio patch to allow async execution...


Evaluating 1 test case(s) in parallel: |          |  0% (0/1) [Time Taken: 00:00, ?test case/s]


ValidationError: 1 validation error for LLMApiTestCase
actualOutput
  Input should be a valid string [type=string_type, input_value=Schema(answer='TAF13 is a...for treating diseases.'), input_type=Schema]
    For further information visit https://errors.pydantic.dev/2.8/v/string_type