In [1]:
from modal import Cls
from textwrap import dedent

In [2]:
# JSON Schema, kept as a string, for a "Character" object with five required
# fields: name (string, at most 20 characters), age and strength (integers),
# and armor / weapon (string enums defined under "definitions").
# The same text is passed to the remote `generate` call and embedded in each
# prompt template.
# NOTE(review): presumably the remote app uses it to constrain decoding — confirm
# against the deployed "outlines-app" code.
schema = '''{
    "title": "Character",
    "type": "object",
    "properties": {
        "name": {
            "title": "Name",
            "maxLength": 20,
            "type": "string"
        },
        "age": {
            "title": "Age",
            "type": "integer"
        },
        "armor": {"$ref": "#/definitions/Armor"},
        "weapon": {"$ref": "#/definitions/Weapon"},
        "strength": {
            "title": "Strength",
            "type": "integer"
        }
    },
    "required": ["name", "age", "armor", "weapon", "strength"],
    "definitions": {
        "Armor": {
            "title": "Armor",
            "description": "An enumeration.",
            "enum": ["leather", "chainmail", "plate", "cape", "poncho"],
            "type": "string"
        },
        "Weapon": {
            "title": "Weapon",
            "description": "An enumeration.",
            "enum": ["sword", "axe", "mace", "spear", "bow", "crossbow", "wand", "charango"],
            "type": "string"
        }
    }
}'''

# Default user question; it fills the {question} slot of the prompt templates.
prompt = "Give me a character description"

## Mistral

In [7]:
# Instruction prompt for the Mistral model, wrapped in [INST] ... [/INST].
# It has two str.format placeholders: {schema} (the JSON schema text) and
# {question} (the user's request). It is written as one entry per output line;
# the empty strings are the blank separator lines and the last one supplies
# the trailing newline.
prompt_template = "\n".join(
    [
        "[INST]",
        "A user is gonna ask you a question.",
        "You must answer the user's question by replying VALID JSON that matches the schema below:",
        "",
        "```json",
        "{schema}",
        "```",
        "",
        "---",
        "",
        "The user's question below",
        "",
        "```text",
        "{question}",
        "```",
        "",
        "[/INST]",
        "",
    ]
)
# Get a handle to the `Model` class of the deployed Modal app "outlines-app"
# and bind it to the Mistral instruct checkpoint.
Model = Cls.lookup("outlines-app", "Model")
m_mistral = Model(model_name="mistralai/Mistral-7B-Instruct-v0.2")

# Remote call: the first argument is the schema text, the second is the
# fully rendered prompt (schema + default question).
result = m_mistral.generate.remote(
    schema.strip(),
    prompt_template.format(schema=schema.strip(), question=prompt),
)

In [8]:
# Display the value returned by the Mistral `generate` call.
result

{'name': 'Thorgrim Ironfist',
 'age': 42,
 'armor': 'plate',
 'weapon': 'mace',
 'strength': 18}

In [9]:
# Extra questions used to probe how each model fills in the schema.
prompts = [
    "Give me a funny character description",
    "Give me a chilean character description",
    "Someone from Harry Potter universe",
    "Who you gonna call?",
]

In [10]:
# Send each sample question to the Mistral model and print what comes back.
for question in prompts:
    print(
        m_mistral.generate.remote(
            schema.strip(),
            prompt_template.format(schema=schema.strip(), question=question),
        )
    )

{'name': ' crazy Steve', 'age': 35, 'armor': 'poncho', 'weapon': 'charango', 'strength': 8}
{'name': 'Pedro Vi variance', 'age': 35, 'armor': 'poncho', 'weapon': 'charango', 'strength': 45}
{'name': 'Harry Potter', 'age': 17, 'armor': 'cape', 'weapon': 'wand', 'strength': 100}
{'name': 'Ghostbusters', 'age': 30, 'armor': 'cape', 'weapon': 'wand', 'strength': 50}


## Gemma

In [12]:
# Prompt for the Gemma model, laid out in Gemma's turn format:
#   <start_of_turn>user\n ... <end_of_turn>\n<start_of_turn>model\n
# It has two str.format placeholders: {schema} (the JSON schema text) and
# {question} (the user's request).
#
# The line breaks come from the literal itself, so no "\n" escapes are used:
# inside a non-raw triple-quoted string an escape followed by a real line
# break yields two newlines. The backslash after the opening quotes stops the
# prompt from starting with a blank line before <bos>.
# NOTE(review): if the serving tokenizer already prepends <bos>, the explicit
# token here duplicates it — confirm against the deployed app.
gemma_prompt_template = dedent(
    """\
    <bos><start_of_turn>user
    A user is gonna ask you a question, you need to extract the arguments to be passed to the function that can answer the question.
    You must answer the user's question by replying VALID JSON that matches the schema below:
    ```json
    {schema}
    ```
    The user's question below:
    ```text
    {question}
    ```
    <end_of_turn>
    <start_of_turn>model
    """)
# Get a handle to the `Model` class of the deployed Modal app "outlines-app"
# and bind it to the CodeGemma instruct checkpoint.
Model = Cls.lookup("outlines-app", "Model")
m_gemma = Model(model_name="google/codegemma-7b-it")

# Remote call: the first argument is the schema text, the second is the
# fully rendered Gemma prompt (schema + default question).
result = m_gemma.generate.remote(
    schema.strip(),
    gemma_prompt_template.format(schema=schema.strip(), question=prompt),
)

In [13]:
# Display the value returned by the Gemma `generate` call.
result

{'name': "User's name",
 'age': 0,
 'armor': 'leather',
 'weapon': 'sword',
 'strength': 0}

In [14]:
# Send each sample question to the Gemma model and print what comes back.
for question in prompts:
    print(
        m_gemma.generate.remote(
            schema.strip(),
            gemma_prompt_template.format(schema=schema.strip(), question=question),
        )
    )

{'name': 'Zork the Suggestive', 'age': 250, 'armor': 'cape', 'weapon': 'wand', 'strength': 150}
{'name': 'Chilean Character', 'age': 25, 'armor': 'leather', 'weapon': 'sword', 'strength': 15}
{'name': 'Harry Potter', 'age': 11, 'armor': 'cape', 'weapon': 'wand', 'strength': 10}
{'name': 'main-character', 'age': 18, 'armor': 'plate', 'weapon': 'sword', 'strength': 10}
