# Request structured (JSON-schema-constrained) output from Ollama and parse it into a Pydantic model.

In [8]:
# Chat endpoint of an Ollama server running on this machine; passed to
# ask_ollama() as the URL to POST to.
ollama_api_endpoint = "http://localhost:11434/api/chat"

In [9]:
import requests
from pydantic import BaseModel
from typing import List, Type, TypeVar

# Type variable limited to BaseModel subclasses, so that ask_ollama() can be
# annotated as returning an instance of whichever model class it was given.
T = TypeVar("T", bound=BaseModel)


# Shape of the answer requested in the demo below: the JSON schema of this
# class is sent to Ollama as the response format, and the reply is validated
# against it.
# NOTE: documented with comments rather than a class docstring on purpose --
# pydantic copies a class docstring into the generated JSON schema as its
# "description", which would change the payload sent to the server.
class CanadaInfo(BaseModel):
    name: str  # country name
    capital: str  # capital city
    languages: List[str]  # languages, one list entry per language


def ask_ollama(
    endpoint: str,
    model: str,
    prompt: str,
    model_class: Type[T],
    *,
    timeout: float = 300.0,
) -> T:
    """
    Send a prompt to the Ollama chat API and parse the reply into a Pydantic model.

    The JSON schema of ``model_class`` is sent as the request's ``format`` so
    the reply is constrained to that shape; the reply text is then validated
    against the same class.

    Args:
        endpoint (str): URL of the Ollama chat endpoint to POST to.
        model (str): Name of the Ollama model to query, e.g. ``"llama3"``.
        prompt (str): The user's input prompt, sent as a single user message.
        model_class (Type[T]): The Pydantic model class to parse the response.
        timeout (float): Seconds to wait for the server (connecting, and then
            waiting for the response) before giving up. Because the request is
            non-streaming, this must cover the whole generation time. Pass
            ``None`` to wait indefinitely, which was the previous behaviour.

    Returns:
        An instance of the provided Pydantic model class.

    Raises:
        requests.exceptions.Timeout: If the server does not answer in time.
        requests.exceptions.HTTPError: If the server returns an error status.
        KeyError: If the response body has no ``message`` / ``content`` entry.
        pydantic.ValidationError: If the reply does not match ``model_class``.
    """
    # Convert Pydantic model schema to JSON Schema
    schema = model_class.model_json_schema()

    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "stream": False,  # one complete JSON body instead of chunked lines
        "format": schema,  # Ollama expects the format in JSON schema form
    }

    # Without an explicit timeout, requests waits forever on a stalled or
    # unreachable server; bound the wait so the caller gets an exception.
    resp = requests.post(
        endpoint,
        headers={"Content-Type": "application/json"},
        json=payload,
        timeout=timeout,
    )
    resp.raise_for_status()

    # The API returns a dict with 'message' -> 'content' containing the JSON output
    raw_json = resp.json()["message"]["content"]

    # Parse into the Pydantic model
    return model_class.model_validate_json(raw_json)


if __name__ == "__main__":
    # Demo: ask the local llama3 model about Canada, with the answer shaped
    # by the CanadaInfo schema.
    result = ask_ollama(
        endpoint=ollama_api_endpoint,
        model="llama3",
        prompt="Tell me about Canada.",
        model_class=CanadaInfo,
    )
    # First the model's own repr, then the three fields separated by spaces.
    print(result)
    print(f"{result.name} {result.capital} {result.languages}")

name='Canada' capital='Ottawa' languages=['English', 'French', 'Indigenous languages']
Canada Ottawa ['English', 'French', 'Indigenous languages']
