In [None]:
# type: ignore

# Dynamic In-Context Learning

This recipe allows TensorZero users to set up a dynamic in-context learning variant for any function.
Since TensorZero automatically logs all inferences and feedback, it is straightforward to query a set of good examples and retrieve the most relevant ones to put them into context for future inferences.
Since TensorZero allows users to add demonstrations for any inference, it is also easy to include them in the set of examples.
This recipe uses only the OpenAI embeddings API, but we are working towards supporting all embedding providers over time.


In [1]:
# Load environment variables (TENSORZERO_CLICKHOUSE_URL, OPENAI_API_KEY, ...) from a .env file
from dotenv import find_dotenv, load_dotenv

# Search upward from the current working directory for the nearest .env file
# (e.g. the repository root) instead of relying on a machine-specific absolute path.
dotenv_path = find_dotenv(usecwd=True)

# load_dotenv returns False when no file is found or the file defines no variables;
# variables already present in the environment are not overridden.
if load_dotenv(dotenv_path):
    print("Environment variables loaded successfully!")
else:
    print("No .env file found (or it was empty); using the existing environment.")

Environment variables loaded successfully!


To get started:

- Set the `TENSORZERO_CLICKHOUSE_URL` environment variable. For example: `TENSORZERO_CLICKHOUSE_URL="http://chuser:chpassword@localhost:8123/tensorzero"`
- Set the `OPENAI_API_KEY` environment variable.
- Update the following parameters:


In [13]:
from typing import Optional

# Path to the TensorZero config file that declares FUNCTION_NAME.
CONFIG_PATH = "../../examples/data-extraction-ner/config/tensorzero.toml"

# The TensorZero function to build DICL examples for.
FUNCTION_NAME = "extract_entities"

# Can also set this to None if you do not want to use a metric and only want to use demonstrations
METRIC_NAME: Optional[str] = None

# The name of the DICL variant you will want to use. Set this to a meaningful name that does not conflict
# with other variants for the function selected above.
# NOTE: the name mentions gpt_4o_mini while DICL_GENERATION_MODEL below is gpt-4.1-mini;
# rename it if you want the variant name to reflect the generation model.
DICL_VARIANT_NAME = "gpt_4o_mini_dicl"

# The embedding model used to embed example inputs (and, in the variant, to retrieve them).
DICL_EMBEDDING_MODEL = "text-embedding-3-small"

# The model to use for generation in the DICL variant.
DICL_GENERATION_MODEL = "gpt-4.1-mini-2025-04-14"

# The number of examples to retrieve for the DICL variant.
DICL_K = 10

# If the metric is a float metric, you can set the threshold to filter the data
FLOAT_METRIC_THRESHOLD = 0.5

# Whether to use demonstrations for DICL examples
USE_DEMONSTRATIONS = True

In [14]:
import os
from asyncio import Semaphore
from pathlib import Path

import pandas as pd
import toml
from clickhouse_connect import get_client
from openai import AsyncOpenAI
from tensorzero.util import uuid7
from tqdm.asyncio import tqdm_asyncio

Load the TensorZero configuration file.


In [15]:
# Resolve the configured path and fail fast with a readable message if it is wrong.
config_path = Path(CONFIG_PATH)
assert config_path.exists(), f"{CONFIG_PATH} does not exist"
assert config_path.is_file(), f"{CONFIG_PATH} is not a file"

# Parse the whole TensorZero config into a nested dict.
with open(config_path, "r") as config_file:
    config = toml.load(config_file)

Retrieve the configuration for the function we are interested in.


In [16]:
# The function must be declared under the [functions] table of the config.
assert "functions" in config, "No `[functions]` section found in config"
functions_section = config["functions"]
assert FUNCTION_NAME in functions_section, (
    f"No function named `{FUNCTION_NAME}` found in config"
)

# Keep the function's config and its type ("chat" or "json") for later cells.
function_config = functions_section[FUNCTION_NAME]
function_type = function_config["type"]

print(function_config)
print(function_type)

{'type': 'json', 'output_schema': 'functions/extract_entities/output_schema.json', 'variants': {'gpt_4o': {'type': 'chat_completion', 'model': 'openai::gpt-4o-2024-08-06', 'system_template': 'functions/extract_entities/initial_prompt/system_template.minijinja', 'json_mode': 'strict'}, 'gpt_4o_mini': {'type': 'chat_completion', 'model': 'openai::gpt-4o-mini-2024-07-18', 'system_template': 'functions/extract_entities/initial_prompt/system_template.minijinja', 'json_mode': 'strict'}, 'gpt_4o_mini_fine_tuned': {'type': 'chat_completion', 'model': 'openai::ft:gpt-4.1-mini-2025-04-14:al-test::Bz8DBr4C', 'system_template': 'functions/extract_entities/initial_prompt/system_template.minijinja', 'json_mode': 'strict'}}}
json


Retrieve the metric configuration.


In [17]:
# A METRIC_NAME of None means only demonstrations are used as examples.
metric = None
if METRIC_NAME is not None:
    assert "metrics" in config, "No `[metrics]` section found in config"
    assert METRIC_NAME in config["metrics"], (
        f"No metric named `{METRIC_NAME}` found in config"
    )
    metric = config["metrics"][METRIC_NAME]

print(metric)

None


Initialize the ClickHouse client.


In [18]:
# The ClickHouse DSN comes from the environment (see the setup steps above).
clickhouse_url = os.environ.get("TENSORZERO_CLICKHOUSE_URL")
assert clickhouse_url is not None, (
    "TENSORZERO_CLICKHOUSE_URL environment variable not set"
)

clickhouse_client = get_client(dsn=clickhouse_url)

Determine the ClickHouse table name for the function.


In [19]:
# Chat and JSON functions log their inferences to different ClickHouse tables.
inference_table_by_function_type = {
    "chat": "ChatInference",
    "json": "JsonInference",
}
inference_table_name = inference_table_by_function_type.get(function_type)

if inference_table_name is None:
    raise ValueError(f"Unsupported function type: {function_type}")

Determine the ClickHouse table name for the metric.


In [20]:
# Float and boolean metric feedback live in separate tables; no metric means no table.
if metric is None:
    feedback_table_name = None
else:
    feedback_table_name = {
        "float": "FloatMetricFeedback",
        "boolean": "BooleanMetricFeedback",
    }.get(metric["type"])
    if feedback_table_name is None:
        raise ValueError(f"Unsupported metric type: {metric['type']}")

Determine the correct join key to use for the metric on the inference table.


In [21]:
# Episode-level feedback targets an episode_id; inference-level feedback targets the
# inference's own id. No metric means no join key.
if metric is None:
    inference_join_key = None
else:
    inference_join_key = {
        "episode": "episode_id",
        "inference": "id",
    }.get(metric["level"])
    if inference_join_key is None:
        raise ValueError(f"Unsupported metric level: {metric['level']}")

In [22]:
# Collect inferences whose MOST RECENT feedback for METRIC_NAME meets the threshold.
metric_df = None
if metric is not None:
    assert "optimize" in metric, "Metric is missing the `optimize` field"

    # Float metrics use the configurable threshold; for boolean metrics 0.5 is
    # presumably meant to split true (1) from false (0) — confirm against the
    # BooleanMetricFeedback column type.
    threshold = FLOAT_METRIC_THRESHOLD if metric["type"] == "float" else 0.5
    # Keep high values when maximizing the metric, low values when minimizing it.
    comparison_operator = ">=" if metric["optimize"] == "max" else "<="

    # Table names, the join key and the operator come from the fixed mappings
    # above (not user input), so they are interpolated; values are bound params.
    # The threshold is applied AFTER selecting the latest feedback per target
    # (rn = 1): filtering inside the subquery would let an older passing value
    # win over a newer failing one.
    query = f"""
    SELECT
        i.input,
        i.output
    FROM
        {inference_table_name} i
    JOIN
        (SELECT
            target_id,
            value,
            ROW_NUMBER() OVER (PARTITION BY target_id ORDER BY timestamp DESC) as rn
        FROM
            {feedback_table_name}
        WHERE
            metric_name = %(metric_name)s
        ) f ON i.{inference_join_key} = f.target_id and f.rn = 1
    WHERE
        i.function_name = %(function_name)s
        AND f.value {comparison_operator} %(threshold)s
    """

    params = {
        "function_name": FUNCTION_NAME,
        "metric_name": METRIC_NAME,
        "threshold": threshold,
    }

    metric_df = clickhouse_client.query_df(query, params)

# Last expression so the preview actually renders (None renders nothing).
metric_df.head() if metric_df is not None else None

In [32]:
# Collect demonstrations for the function: only the most recent demonstration per
# inference is kept (rn = 1), and its value is used in place of the model output.
query = f"""
SELECT
    i.input,
    f.value AS output
FROM
    {inference_table_name} i
JOIN
    (SELECT
        inference_id,
        value,
        ROW_NUMBER() OVER (PARTITION BY inference_id ORDER BY timestamp DESC) as rn
    FROM
        DemonstrationFeedback
    ) f ON i.id = f.inference_id AND f.rn = 1
WHERE
    i.function_name = %(function_name)s
"""

params = {
    "function_name": FUNCTION_NAME,
}

demonstration_df = (
    clickhouse_client.query_df(query, params) if USE_DEMONSTRATIONS else None
)

# Last expression so the preview actually renders (None renders nothing).
demonstration_df.head() if demonstration_df is not None else None

In [33]:
# Combine metric_df and demonstration_df into example_df
example_df = pd.concat(
    [df for df in [metric_df, demonstration_df] if df is not None], ignore_index=True
)

# Assert that at least one of the dataframes is not None
assert example_df is not None and not example_df.empty, (
    "Both metric_df and demonstration_df are None or empty"
)

# Display the first few rows of the combined dataframe
example_df.head()

Unnamed: 0,input,output
0,"{""messages"":[{""role"":""user"",""content"":[{""type""...","{""raw"":""{\""person\"":[],\""organization\"":[],\""l..."
1,"{""messages"":[{""role"":""user"",""content"":[{""type""...","{""raw"":""{\""person\"":[\""C. Lewis\"",\""Wasim Akra..."
2,"{""messages"":[{""role"":""user"",""content"":[{""type""...","{""raw"":""{\""person\"":[\""Glickman\""],\""organizat..."
3,"{""messages"":[{""role"":""user"",""content"":[{""type""...","{""raw"":""{\""person\"":[],\""organization\"":[],\""l..."
4,"{""messages"":[{""role"":""user"",""content"":[{""type""...","{""raw"":""{\""person\"":[\""Said\""],\""organization\..."


In [34]:
# Async OpenAI client used by get_embedding below; with no arguments it reads the API key
# from the OPENAI_API_KEY environment variable (see the setup steps above).
openai_client = AsyncOpenAI()

In [35]:
async def get_embedding(
    text: str, semaphore: Semaphore, model: str = "text-embedding-3-small"
) -> Optional[list[float]]:
    """Embed ``text`` with the OpenAI embeddings API.

    Args:
        text: The string to embed.
        semaphore: Limits how many embedding requests run concurrently.
        model: The OpenAI embedding model name.

    Returns:
        The embedding vector, or None if anything goes wrong. Errors are printed
        rather than raised so a single failure does not abort a gathered batch.
    """
    try:
        async with semaphore:
            result = await openai_client.embeddings.create(input=text, model=model)
            first_item = result.data[0]
            return first_item.embedding
    except Exception as err:
        print(f"Error getting embedding: {err}")
        return None

In [36]:
# Upper bound on simultaneous embedding requests, shared by every get_embedding call.
MAX_CONCURRENT_EMBEDDING_REQUESTS = 50
semaphore = Semaphore(value=MAX_CONCURRENT_EMBEDDING_REQUESTS)

In [37]:
# Embed the 'input' column using the get_embedding function
tasks = [
    get_embedding(str(input_text), semaphore, DICL_EMBEDDING_MODEL)
    for input_text in example_df["input"]
]
embeddings = await tqdm_asyncio.gather(*tasks, desc="Embedding inputs")

Embedding inputs: 100%|██████████| 500/500 [00:05<00:00, 96.67it/s] 


In [38]:
# Add the embeddings as a new column; get_embedding yields None for failed requests.
example_df["embedding"] = embeddings

# A row without an embedding cannot be retrieved by similarity, and the target
# column is a non-Nullable Array(Float32) (see the schema below), so drop such rows
# before inserting instead of writing unusable or invalid data.
failed_embeddings = example_df["embedding"].isna()
if failed_embeddings.any():
    print(f"Dropping {int(failed_embeddings.sum())} example(s) whose embedding failed")
    example_df = example_df.loc[~failed_embeddings].reset_index(drop=True)

# Display the first few rows to verify the new column
example_df[["input", "embedding"]].head()

                                               input  \
0  {"messages":[{"role":"user","content":[{"type"...   
1  {"messages":[{"role":"user","content":[{"type"...   
2  {"messages":[{"role":"user","content":[{"type"...   
3  {"messages":[{"role":"user","content":[{"type"...   
4  {"messages":[{"role":"user","content":[{"type"...   

                                           embedding  
0  [-0.03471248596906662, 0.0012018261477351189, ...  
1  [-0.008944535627961159, -0.0022327937185764313...  
2  [-0.025813400745391846, 0.009400255978107452, ...  
3  [-0.03811752423644066, -0.029765434563159943, ...  
4  [0.016340002417564392, -0.0020620147697627544,...  


Prepare the data for the `DynamicInContextLearningExample` table.
The table schema is as follows:

```
CREATE TABLE tensorzero.DynamicInContextLearningExample
(
    `id` UUID,
    `function_name` LowCardinality(String),
    `variant_name` LowCardinality(String),
    `namespace` String,
    `input` String,
    `output` String,
    `embedding` Array(Float32),
    `timestamp` DateTime MATERIALIZED UUIDv7ToDateTime(id)
)
ENGINE = MergeTree
ORDER BY (function_name, variant_name, namespace)
```


In [39]:
# Tag every example with the function and DICL variant it belongs to. These are new
# columns: at this point the examples carry only `input`, `output` and `embedding`.
example_df["function_name"] = FUNCTION_NAME
example_df["variant_name"] = DICL_VARIANT_NAME

# Give each row its own UUIDv7 id; per the schema above, the table derives
# `timestamp` from it via UUIDv7ToDateTime.
example_df["id"] = [uuid7() for _row in range(example_df.shape[0])]

In [40]:
example_df.head()

Unnamed: 0,input,output,embedding,function_name,variant_name,id
0,"{""messages"":[{""role"":""user"",""content"":[{""type""...","{""raw"":""{\""person\"":[],\""organization\"":[],\""l...","[-0.03471248596906662, 0.0012018261477351189, ...",extract_entities,gpt_4o_mini_dicl,01986213-620a-7b21-a0c3-2a5f422ebf36
1,"{""messages"":[{""role"":""user"",""content"":[{""type""...","{""raw"":""{\""person\"":[\""C. Lewis\"",\""Wasim Akra...","[-0.008944535627961159, -0.0022327937185764313...",extract_entities,gpt_4o_mini_dicl,01986213-620a-7b21-a0c3-2a686f1bb53f
2,"{""messages"":[{""role"":""user"",""content"":[{""type""...","{""raw"":""{\""person\"":[\""Glickman\""],\""organizat...","[-0.025813400745391846, 0.009400255978107452, ...",extract_entities,gpt_4o_mini_dicl,01986213-620a-7b21-a0c3-2a731e8ffa39
3,"{""messages"":[{""role"":""user"",""content"":[{""type""...","{""raw"":""{\""person\"":[],\""organization\"":[],\""l...","[-0.03811752423644066, -0.029765434563159943, ...",extract_entities,gpt_4o_mini_dicl,01986213-620a-7b21-a0c3-2a8231086fee
4,"{""messages"":[{""role"":""user"",""content"":[{""type""...","{""raw"":""{\""person\"":[\""Said\""],\""organization\...","[0.016340002417564392, -0.0020620147697627544,...",extract_entities,gpt_4o_mini_dicl,01986213-620a-7b21-a0c3-2a9752c187e6


In [41]:
# Insert the data into the DiclExample table
result = clickhouse_client.insert_df(
    "DynamicInContextLearningExample",
    example_df,
)
print(result)

<clickhouse_connect.driver.summary.QuerySummary object at 0x74c14c722b10>


Finally, add a new variant to your function configuration to try out the Dynamic In-Context Learning variant in practice!

If the embedding model name or generation model name in your config differs from the ones used above, you might have to update the config.
Be sure to also give the variant some weight, and if you are using a JSON function, you can optionally set the `json_mode` field to `"strict"`.

> **Tip:** DICL variants support additional parameters like system instructions or strict JSON mode. See [Configuration Reference](https://www.tensorzero.com/docs/gateway/configuration-reference).


In [42]:
# Settings for the new DICL variant (key order is preserved in the TOML output).
variant_config = dict(
    type="experimental_dynamic_in_context_learning",
    embedding_model=DICL_EMBEDDING_MODEL,
    model=DICL_GENERATION_MODEL,
    k=DICL_K,
)

# Nest it under the function's variants so toml.dumps emits a ready-to-paste
# [functions.<function>.variants.<variant>] section.
function_variants = {"variants": {DICL_VARIANT_NAME: variant_config}}
full_variant_config = {"functions": {FUNCTION_NAME: function_variants}}

print(toml.dumps(full_variant_config))

[functions.extract_entities.variants.gpt_4o_mini_dicl]
type = "experimental_dynamic_in_context_learning"
embedding_model = "text-embedding-3-small"
model = "gpt-4.1-mini-2025-04-14"
k = 10



If you haven't already, also add the embedding model to your config.


In [43]:
# A single OpenAI provider serving the embedding model under the same name.
openai_provider = {"type": "openai", "model_name": DICL_EMBEDDING_MODEL}

# Route the embedding model to that provider so toml.dumps emits a ready-to-paste
# [embedding_models.<model>] section.
embedding_model_config = {
    "embedding_models": {
        DICL_EMBEDDING_MODEL: {
            "routing": ["openai"],
            "providers": {"openai": openai_provider},
        }
    }
}

print(toml.dumps(embedding_model_config))

[embedding_models.text-embedding-3-small]
routing = [ "openai",]

[embedding_models.text-embedding-3-small.providers.openai]
type = "openai"
model_name = "text-embedding-3-small"

