In [1]:
!pip install lm-format-enforcer

Collecting lm-format-enforcer
  Downloading lm_format_enforcer-0.10.10-py3-none-any.whl.metadata (17 kB)
Collecting interegular>=0.3.2 (from lm-format-enforcer)
  Downloading interegular-0.3.3-py37-none-any.whl.metadata (3.0 kB)
Downloading lm_format_enforcer-0.10.10-py3-none-any.whl (44 kB)
Downloading interegular-0.3.3-py37-none-any.whl (23 kB)
Installing collected packages: interegular, lm-format-enforcer
Successfully installed interegular-0.3.3 lm-format-enforcer-0.10.10


In [None]:
import json
from typing import Type

import transformers
from lmformatenforcer import JsonSchemaParser
from lmformatenforcer.integrations.transformers import (
    build_transformers_prefix_allowed_tokens_fn,
)
from pydantic import BaseModel

from deepeval.models import DeepEvalBaseLLM


class CustomLLM(DeepEvalBaseLLM):
    """DeepEval model wrapper that forces JSON output matching a Pydantic schema.

    Generation runs through a Hugging Face ``text-generation`` pipeline whose
    token choices are restricted by lm-format-enforcer, so the completion can
    be parsed directly into the schema class DeepEval supplies.
    """

    # Elided in this example: model/tokenizer setup. ``load_model()`` and
    # ``self.tokenizer`` are used below and must be provided here.
    ...

    def _get_pipeline(self) -> transformers.Pipeline:
        """Return the text-generation pipeline, building it on first use.

        The pipeline depends only on ``load_model()``, ``self.tokenizer`` and
        fixed sampling settings. It is therefore cached on the instance instead
        of being rebuilt for every prompt. Delete ``self._pipeline`` if the
        model or tokenizer is replaced after the first call.
        """
        text_pipeline = getattr(self, "_pipeline", None)
        if text_pipeline is None:
            text_pipeline = transformers.pipeline(
                "text-generation",
                model=self.load_model(),
                tokenizer=self.tokenizer,
                use_cache=True,
                device_map="auto",
                # NOTE(review): in transformers, max_length bounds prompt +
                # generated tokens together; a long prompt leaves little room
                # for the JSON. Consider max_new_tokens instead.
                max_length=2500,
                do_sample=True,
                top_k=5,
                num_return_sequences=1,
                eos_token_id=self.tokenizer.eos_token_id,
                # Pads with EOS; presumably the tokenizer has no dedicated
                # pad token — confirm for the model in use.
                pad_token_id=self.tokenizer.eos_token_id,
            )
            self._pipeline = text_pipeline
        return text_pipeline

    def generate(self, prompt: str, schema: Type[BaseModel]) -> BaseModel:
        """Generate a completion for ``prompt`` constrained to ``schema``.

        Args:
            prompt: Text fed to the model.
            schema: Pydantic model *class* (not an instance) supplied by
                DeepEval. Its JSON schema restricts which tokens may be
                generated, and it is instantiated from the parsed output.

        Returns:
            An instance of ``schema`` built from the generated JSON.

        Raises:
            json.JSONDecodeError: if the completion is not complete JSON,
                e.g. generation stopped at ``max_length`` before the object
                was closed.
            pydantic.ValidationError: if the parsed JSON does not validate
                against ``schema``.
        """
        text_pipeline = self._get_pipeline()

        # Only allow tokens that keep the output a valid prefix of a JSON
        # document matching the schema (lm-format-enforcer).
        parser = JsonSchemaParser(schema.model_json_schema())
        prefix_function = build_transformers_prefix_allowed_tokens_fn(
            text_pipeline.tokenizer, parser
        )

        # return_full_text=False lets the pipeline drop the prompt itself,
        # rather than slicing the decoded text by len(prompt).
        outputs = text_pipeline(
            prompt,
            prefix_allowed_tokens_fn=prefix_function,
            return_full_text=False,
        )
        json_result = json.loads(outputs[0]["generated_text"])

        # Return an object of the schema class DeepEval asked for.
        return schema(**json_result)

    async def a_generate(self, prompt: str, schema: Type[BaseModel]) -> BaseModel:
        """Async entry point for DeepEval; delegates to :meth:`generate`.

        Generation runs synchronously inside this coroutine, so it blocks the
        event loop until the completion is ready.
        """
        return self.generate(prompt, schema)

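# Docs

See DeepEval's guide [Creating a custom LLM](https://github.com/confident-ai/deepeval/blob/main/docs/guides/guides-using-custom-llms.mdx) for the complete `CustomLLM` class, including the model and tokenizer setup that is elided (`...`) in the cell above.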