# Core

> Lisette Core

In [None]:
#| default_exp core

In [None]:
#| hide
from cachy import enable_cachy,disable_cachy

In [None]:
#| hide
# Cache LLM responses with cachy so notebook reruns reuse recorded outputs
enable_cachy()

In [None]:
#| export
import asyncio, base64, json, litellm, mimetypes, random, string
from typing import Optional,Callable
from html import escape
from litellm import (acompletion, completion, stream_chunk_builder, Message,
                     ModelResponse, ModelResponseStream, get_model_info, register_model, Usage)
from litellm.utils import function_to_dict, StreamingChoices, Delta, ChatCompletionMessageToolCall, Function, Choices
from toolslm.funccall import mk_ns, call_func, call_func_async, get_schema
from fastcore.utils import *
from fastcore.meta import delegates
from fastcore import imghdr
from dataclasses import dataclass
from litellm.exceptions import ContextWindowExceededError

In [None]:
#| hide
from fastcore.test import *
from IPython.display import Markdown, Image, Audio, Video
import httpx

# LiteLLM

## Deterministic outputs

LiteLLM `ModelResponse(Stream)` objects have `id` and `created` fields that are generated dynamically. Even when we use [`cachy`](https://github.com/answerdotai/cachy) to cache the LLM response, these dynamic fields create diffs, which makes code review more challenging. The patches below ensure that the `id` and `created` fields are fixed and won't generate diffs.

In [None]:
#| export
def patch_litellm(seed=0):
    """Patch litellm.ModelResponseBase such that `id` and `created` are fixed.

    Also patches `ChatCompletionMessageToolCall` so that tool call ids containing '_' are regenerated
    from `random`, which makes them deterministic once seeded. Pass `seed=None` to skip seeding."""
    from litellm.types.utils import ModelResponseBase, ChatCompletionMessageToolCall
    from uuid import UUID
    from base64 import b64encode
    # NOTE: this seeds the *global* `random` module, which affects every other user of it in the process
    if seed is not None: random.seed(seed) # ensures random ids like tool call ids are deterministic
    
    # Ignore any `id`/`created` passed in and always construct with fixed placeholder values
    @patch
    def __init__(self: ModelResponseBase, id=None, created=None, *args, **kwargs): 
        self._orig___init__(id='chatcmpl-xxx', created=1000000000, *args, **kwargs)

    # Also pin `id`/`created` when they are assigned after construction
    @patch
    def __setattr__(self: ModelResponseBase, name, value):
        if name == 'id': value = 'chatcmpl-xxx'
        elif name == 'created': value = 1000000000
        self._orig___setattr__(name, value)

    def _unqid():
        # 128 (seeded) random bits as a version-4 UUID, base64-encoded without '=' padding,
        # with '+' -> '_' and '/' -> '-', prefixed with '_'
        res = b64encode(UUID(int=random.getrandbits(128), version=4).bytes)
        return '_' + res.decode().rstrip('=').translate(str.maketrans('+/', '_-'))

    @patch
    def __init__(self: ChatCompletionMessageToolCall, function=None, id=None, type="function", **kwargs):
        # we keep the tool call prefix if it exists, this is needed for example to handle srvtoolu_ correctly.
        # Only ids containing '_' are rewritten (to `<prefix>` + `_unqid()`); other ids, including None, pass through.
        id = id.split('_')[0]+_unqid() if id and '_' in id else id
        self._orig___init__(function=function, id=id, type=type, **kwargs)


In [None]:
# Pin response ids/timestamps and seed tool-call ids so cached outputs don't create diffs
patch_litellm()

## Completion

LiteLLM provides a convenient unified interface for most big LLM providers. Because it's so useful to be able to switch LLM providers with just one argument, we want to make it even easier by adding some more convenience functions and classes.

This is very similar to our other wrapper libraries for popular AI providers: [claudette](https://claudette.answer.ai/) (Anthropic), [gaspard](https://github.com/AnswerDotAI/gaspard) (Gemini), [cosette](https://answerdotai.github.io/cosette/) (OpenAI).

In [None]:
#| export
@patch
def _repr_markdown_(self: litellm.ModelResponse):
    "Render the response as markdown: text, tool calls and generated images, then a collapsible details list."
    message = self.choices[0].message
    content = ''
    if mc:=message.content:
        # list content holds content blocks; concatenate the text of each (blocks without text contribute nothing)
        content += ''.join(nested_idx(o,'text') or '' for o in mc) if isinstance(mc,list) else mc
    if message.tool_calls:
        tool_calls = [f"\n\n🔧 {nested_idx(tc,'function','name')}({nested_idx(tc,'function','arguments')})\n" for tc in message.tool_calls]
        content += "\n".join(tool_calls)
    # `images` may be missing or None on the message
    for img in getattr(message, 'images', None) or []: content += f"\n\n![generated image]({nested_idx(img, 'image_url', 'url')})"
    if not content: content = str(message)
    details = [
        f"id: `{self.id}`",
        f"model: `{self.model}`",
        f"finish_reason: `{self.choices[0].finish_reason}`"
    ]
    if hasattr(self, 'usage') and self.usage: details.append(f"usage: `{self.usage}`")
    det_str = '\n- '.join(details)
    
    return f"""{content}

<details>

- {det_str}

</details>"""

In [None]:
#| export
# Register claude-opus-4-5 (capabilities + per-token prices) with LiteLLM so model-info/cost lookups work for it,
# presumably because LiteLLM's bundled model map doesn't include it yet — TODO drop once it does.
# Prices: 5e-6 per input token and 2.5e-5 per output token (i.e. 5 and 25 per million, presumably USD);
# cache writes cost 1.25x and cache reads 0.1x the input price.
register_model({
    "claude-opus-4-5": {
        "litellm_provider": "anthropic", "mode": "chat",
        "max_tokens": 64000, "max_input_tokens": 200000, "max_output_tokens": 64000,
        "input_cost_per_token": 0.000005, "output_cost_per_token": 0.000025,
        "cache_creation_input_token_cost": 0.000005*1.25, "cache_read_input_token_cost": 0.000005*0.1,
        "supports_function_calling": True, "supports_parallel_function_calling": True,
        "supports_vision": True, "supports_prompt_caching": True, "supports_response_schema": True,
        "supports_system_messages": True, "supports_reasoning": True, "supports_assistant_prefill": True,
        "supports_tool_choice": True, "supports_computer_use": True, "supports_web_search": True
    }
});
# Short aliases for the Claude models used in examples
sonn45 = "claude-sonnet-4-5"
opus45 = "claude-opus-4-5"

In [None]:
# litellm._turn_on_debug()

In [None]:
# Send the same greeting to several providers; `cache_control` only matters for Anthropic
ms = ["gemini/gemini-3-pro-preview", "gemini/gemini-2.5-pro", "gemini/gemini-2.5-flash", "claude-sonnet-4-5", "openai/gpt-4.1"]
msg = [{'role':'user','content':'Hey there!', 'cache_control': {'type': 'ephemeral'}}]
for m in ms:
    display(Markdown(f'**{m}:**'))
    display(completion(m,msg))

Generated images are also displayed (not shown here to conserve filesize):

In [None]:
# completion(model='gemini/gemini-2.5-flash-image', messages=[{'role':'user','content':'Draw a simple sketch of a cat'}])

## Messages formatting

Let's start with making it easier to pass messages into litellm's `completion` function (including images, and pdf files).

In [None]:
#| export
def _bytes2content(data):
    "Turn raw file bytes into a LiteLLM content block: `image_url` for images, otherwise a base64 `file` block."
    mtype = detect_mime(data)
    if not mtype: raise ValueError(f'Data must be a supported file type, got {data[:10]}')
    data_url = f'data:{mtype};base64,' + base64.b64encode(data).decode("utf-8")
    if not mtype.startswith('image/'): return {'type': 'file', 'file': {'file_data': data_url}}
    return {'type': 'image_url', 'image_url': data_url}

In [None]:
#| export
def _add_cache_control(msg,          # LiteLLM formatted msg
                       ttl=None):    # Cache TTL: '5m' (default) or '1h'
    "cache `msg` with default time-to-live (ttl) of 5minutes ('5m'), but can be set to '1h'."
    if isinstance(msg["content"], str): 
        msg["content"] = [{"type": "text", "text": msg["content"]}]
    cache_control = {"type": "ephemeral"}
    if ttl is not None: cache_control["ttl"] = ttl
    if isinstance(msg["content"], list) and msg["content"]:
        msg["content"][-1]["cache_control"] = cache_control
    return msg

def _has_cache(msg):
    return msg["content"] and isinstance(msg["content"], list) and ('cache_control' in msg["content"][-1])

def remove_cache_ckpts(msg):
    "Strip the cache checkpoint (if any) from the last content block of `msg`, in place, and return `msg`."
    if not _has_cache(msg): return msg
    msg["content"][-1].pop('cache_control', None)
    return msg

def _mk_content(o):
    if isinstance(o, str): return {'type':'text','text':o.strip() or '.'}
    elif isinstance(o,bytes): return _bytes2content(o)
    return o

def contents(r):
    "Return the `message` of the first choice in response `r`."
    first_choice = r.choices[0]
    return first_choice.message

In [None]:
#| export
def mk_msg(
    content,      # Content: str, bytes (image), list of mixed content, or dict w 'role' and 'content' fields
    role="user",  # Message role if content isn't already a dict/Message
    cache=False,  # Enable Anthropic caching
    ttl=None      # Cache TTL: '5m' (default) or '1h'
):
    """Create a LiteLLM compatible message.

    Dicts and `Message`s pass through unchanged and a `ModelResponse` yields its first message.
    A one-string list is flattened to that string; other lists become a list of content blocks."""
    if isinstance(content, (dict, Message)): return content
    if isinstance(content, ModelResponse): return contents(content)
    if not isinstance(content, list): body = content
    elif len(content) == 1 and isinstance(content[0], str): body = content[0]
    else: body = [_mk_content(part) for part in content]
    msg = {"role": role, "content": body}
    if not cache: return msg
    return _add_cache_control(msg, ttl=ttl)

Now we can use mk_msg to create different types of messages.

Simple text:

In [None]:
# A plain string becomes a user message
msg = mk_msg("hey")
msg

Which can be passed to litellm's `completion` function like this:

In [None]:
# Default model for the examples below
model = ms[1] # use 2.5-pro, 3-pro is very slow even to run tests as of making

In [None]:
# Pass the message to litellm's `completion` directly
res = completion(model, [msg])
res

We'll add a little shortcut to make examples and testing easier here:

In [None]:
def c(msgs, m=model, **kw):
    "Example shortcut: run `completion` on model `m` with one message dict or a list of messages."
    if isinstance(msgs, dict): msg_list = [msgs]
    else: msg_list = listify(msgs)
    return completion(m, msg_list, **kw)

In [None]:
# Same request via the shortcut
c(msg)

Lists w just one string element are flattened for conciseness:

In [None]:
# A single-string list gives the same message as the plain string
test_eq(mk_msg("hey"), mk_msg(["hey"]))

(LiteLLM ignores these fields when sent to other providers)

Text and images:

In [None]:
# Preview the sample image
img_fn = Path('samples/puppy.jpg')
Image(filename=img_fn, width=200)

In [None]:
# Text + image bytes become a multi-part message; print only the start since it contains base64 data
msg = mk_msg(['hey what in this image?',img_fn.read_bytes()])
print(json.dumps(msg,indent=1)[:200]+"...")

In [None]:
# Ask the model about the image
c(msg)

Let's also demonstrate this for PDFs

In [None]:
# Non-image bytes such as a PDF are sent as a `file` content block
pdf_fn = Path('samples/solveit.pdf')
msg = mk_msg(['Who is the author of this pdf?', pdf_fn.read_bytes()])
c(msg)

Some models like Gemini support audio and video:

In [None]:
# Download a sample WAV file
wav_data = httpx.get("https://openaiassets.blob.core.windows.net/$web/API/docs/audio/alloy.wav").content
# Audio(wav_data)  # uncomment to preview

In [None]:
# Audio bytes go out as a `file` content block to ms[1]
msg = mk_msg(['What is this audio saying?', wav_data])
completion(ms[1], [msg])

In [None]:
# Download a sample MP4 video
vid_data = httpx.get("https://storage.googleapis.com/github-repo/img/gemini/multimodality_usecases_overview/pixel8.mp4").content

In [None]:
# Video bytes are likewise sent as a `file` content block
msg = mk_msg(['Concisely, what is happening in this video?', vid_data])
completion(ms[1], [msg])

### Caching

Some providers such as Anthropic require manually opting into caching. Let's try it:

In [None]:
def cpr(i):
    "Build a long caching-test prompt: `i` repeated 1024 times, followed by the instruction."
    return ' '.join([f'{i}']*1024) + ' This is a caching test. Report back only what number you see repeated above.'

In [None]:
#| eval: false
# Turn off cachy while experimenting with provider-side prompt caching below
disable_cachy()

In [None]:
# msg = mk_msg(cpr(1), cache=True)
# res = c(msg, ms[2])
# res

Anthropic has a maximum of 4 cache checkpoints, so we remove previous ones as we go:

In [None]:
# res = c([remove_cache_ckpts(msg), mk_msg(res), mk_msg(cpr(2), cache=True)], ms[2])
# res

We see that the first message was cached, and this extra message has been written to cache:

In [None]:
# res.usage.prompt_tokens_details

We can add a bunch of large messages in a loop to see how the number of cached tokens used grows.

We repeat this many times (the loop below goes up to 25) to ensure it still works with more than 20 content blocks, [which is a known Anthropic issue](https://docs.claude.com/en/docs/build-with-claude/prompt-caching).

The code below is commented out by default because it's slow. Please uncomment it when working on caching.

In [None]:
# h = []
# msg = mk_msg(cpr(1), cache=True)

# for o in range(2,25):
#     h += [remove_cache_ckpts(msg), mk_msg(res)]
#     msg = mk_msg(cpr(o), cache=True)
#     res = c(h+[msg])
#     detls = res.usage.prompt_tokens_details
#     print(o, detls.cached_tokens, detls.cache_creation_tokens, end='; ')

In [None]:
# Re-enable cachy after the caching experiments
enable_cachy()

### Reconstructing formatted outputs

Lisette can call multiple tools in a loop. Further down this notebook, we'll provide convenience functions for formatting such a sequence of tool calls and responses into one formatted output string.

For now, we'll show an example and show how to transform such a formatted output string back into a valid LiteLLM history.

In [None]:
# Example formatted output: text interleaved with tool-usage <details> blocks whose JSON holds id, call and result
fmt_outp = '''
I'll solve this step-by-step, using parallel calls where possible.

<details class='tool-usage-details'>

```json
{
  "id": "toolu_01KjnQH2Nsz2viQ7XYpLW3Ta",
  "call": { "function": "simple_add", "arguments": { "a": 10, "b": 5 } },
  "result": "15"
}
```

</details>

<details class='tool-usage-details'>

```json
{
  "id": "toolu_01Koi2EZrGZsBbnQ13wuuvzY",
  "call": { "function": "simple_add", "arguments": { "a": 2, "b": 1 } },
  "result": "3"
}
```

</details>

Now I need to multiply 15 * 3 before I can do the final division:

<details class='tool-usage-details'>

```json
{
  "id": "toolu_0141NRaWUjmGtwxZjWkyiq6C",
  "call": { "function": "multiply", "arguments": { "a": 15, "b": 3 } },
  "result": "45"
}
```

</details>
'''

In [None]:
#| export
# Opening tag of a formatted tool-usage block
detls_tag = "<details class='tool-usage-details'>"
# Matches one whole tool-usage block: the tag, an optional <summary>, then a ```json fence.
# Group 1 captures the whole block and group 2 the JSON, so `re_tools.split(s)` yields
# [text, block, json, text, block, json, ..., text].
re_tools = re.compile(  fr"^({detls_tag}\n*(?:<summary>.*?</summary>\n*)?\n*```json\n+(.*?)\n+```\n+</details>)",
                        flags=re.DOTALL|re.MULTILINE)

We can split into chunks of (text,toolstr,json):

In [None]:
# Each chunk is (text, full block, json); the final chunk is padded since it has only trailing text
sp = re_tools.split(fmt_outp)
for o in list(chunked(sp, 3, pad=True)): print('- ', o)

In [None]:
#| export
def _extract_tool(text:str)->Optional[tuple]:
    """Extract a tool call and its result from the JSON inside a `<details>` block.

    Returns `(tool_call, tool_result_msg)`, or `None` if `text` is not valid tool JSON
    (unparseable, or missing the `id` / `call` / `result` fields)."""
    try:
        d = json.loads(text.strip())
        call,tid,result = d['call'],d['id'],d['result']
        func,args = call['function'],call['arguments']
    except (ValueError, KeyError, TypeError): return None
    tc = ChatCompletionMessageToolCall(Function(dumps(args),func), tid)
    # Link the result to `tc.id` rather than the raw JSON id: a patched `ChatCompletionMessageToolCall.__init__`
    # (see `patch_litellm`) may rewrite the id, and the result must reference the id the call actually has.
    tr = {'role': 'tool','tool_call_id': tc.id,'name': func, 'content': result}
    return tc,tr

def fmt2hist(outp:str)->list:
    """Transform a formatted output into a LiteLLM compatible history.

    Each text segment becomes an assistant `Message`. When a tool block follows a segment, its call is
    attached to that message and the tool result message is appended. Empty segments become '.' unless a
    tool block follows them."""
    hist = []
    for txt,_,tooljson in chunked(re_tools.split(outp), 3, pad=True):
        txt = txt.strip() if tooljson or txt.strip() else '.'
        hist.append(lm:=Message(txt))
        if tooljson and (tcr := _extract_tool(tooljson)):
            lm.tool_calls = [tcr[0]]  # `lm` was just created, so it holds exactly this one call
            hist.append(tcr[1])
    return hist

See how we can turn that one formatted output string back into a list of Messages:

In [None]:
from pprint import pprint

In [None]:
# Convert the formatted output back into assistant and tool messages
h = fmt2hist(fmt_outp)
pprint(h)

### `mk_msgs`

We will skip tool use blocks and tool results during caching

In [None]:
#| export
def _apply_cache_idxs(msgs, cache_idxs=[-1], ttl=None):
    "Add cache control (in place) at each of `cache_idxs`, counted over `msgs` without tool-call/tool-result messages; out-of-range idxs are ignored."
    def _is_tool(m): return bool(m.get('tool_calls', [])) or m['role'] == 'tool'
    cacheable = L(msgs).filter(lambda m: not _is_tool(m))
    for idx in cache_idxs:
        try: _add_cache_control(cacheable[idx], ttl)
        except IndexError: pass

Now let's make it easy to provide entire conversations:

In [None]:
#| export
def mk_msgs(
    msgs,                   # List of messages (each: str, bytes, list, or dict w 'role' and 'content' fields)
    cache=False,            # Enable Anthropic caching
    cache_idxs=[-1],        # Cache breakpoint idxs
    ttl=None,               # Cache TTL: '5m' (default) or '1h'
):
    """Create a list of LiteLLM compatible messages.

    Strings containing formatted tool output are expanded via `fmt2hist`. Roles are inferred for
    non-dict messages: the assistant speaks after a user/function/tool message, and the user speaks otherwise."""
    if not msgs: return []
    if not isinstance(msgs, list): msgs = [msgs]
    # Only strings can hold formatted tool output: `in` on bytes raises TypeError and on dicts tests keys.
    msgs = L(msgs).map(lambda m: fmt2hist(m) if isinstance(m, str) and detls_tag in m else [m]).concat()
    res,role = [],'user'
    for m in msgs:
        # stale checkpoints are stripped; breakpoints are re-applied below when `cache` is set
        res.append(msg:=remove_cache_ckpts(mk_msg(m, role=role)))
        role = 'assistant' if msg['role'] in ('user','function', 'tool') else 'user'
    if cache: _apply_cache_idxs(res, cache_idxs, ttl)
    return res

With `mk_msgs` you can easily provide a whole conversation:

In [None]:
# Plain strings alternate between user and assistant roles
msgs = mk_msgs(['Hey!',"Hi there!","How are you?","I'm doing fine and you?"])
msgs

By default the last message will be cached when `cache=True`:

In [None]:
# With cache=True the last message gets a cache_control marker
msgs = mk_msgs(['Hey!',"Hi there!","How are you?","I'm doing fine and you?"], cache=True)
msgs

In [None]:
# The last message's (only) content block carries the marker
test_eq('cache_control' in msgs[-1]['content'][0], True)

Alternatively, users can provide custom `cache_idxs`. Tool call blocks and results are skipped during caching:

In [None]:
# Cache idxs 0, -2 and -1 are counted over the non-tool messages only
msgs = mk_msgs(['Hello!','Hi! How can I help you?','Call some functions!',fmt_outp], cache=True, cache_idxs=[0,-2,-1])
msgs

In [None]:
# Check which messages received cache markers
test_eq('cache_control' in msgs[0]['content'][0], True)
test_eq('cache_control' in msgs[2]['content'][0], True) # shifted idxs to skip tools
test_eq('cache_control' in msgs[-1]['content'][0], True)

Who is speaking when is inferred automatically,
even when multiple tools are called in parallel (which LiteLLM supports!).

In [None]:
# Dict messages keep their role; string roles are inferred from the preceding message
msgs = mk_msgs(['Tell me the weather in Paris and Rome',
                'Assistant calls weather tool two times',
                {'role':'tool','content':'Weather in Paris is ...'},
                {'role':'tool','content':'Weather in Rome is ...'},
                'Assistant returns weather',
                'Thanks!'])
msgs

In [None]:
#| hide
# Verify the inferred roles
test_eq([m['role'] for m in msgs],['user','assistant','tool','tool','assistant','user'])

For ease of use, if `msgs` is not already in a `list`, it will automatically be wrapped inside one. This way you can pass a single prompt into `mk_msgs` and get back a LiteLLM compatible msg history.

In [None]:
# A single (non-list) prompt is wrapped in a list
msgs = mk_msgs("Hey")
msgs

In [None]:
#| hide
# A lone dict message is wrapped too and keeps its 'tool' role
msgs = mk_msgs({'role':'tool','content':'fake tool result'})
msgs

In [None]:
# A short conversation built from plain strings
msgs = mk_msgs(['Hey!',"Hi there!","How are you?","I'm fine, you?"])
msgs

However, if you use `mk_msgs` for a single message consisting of multiple parts,
be explicit and wrap those parts in two lists:

1. One list to show that they belong together in one message (the inner list).
2. Another, because mk_msgs expects a list of multiple messages (the outer list).

This is common when working with images for example:

In [None]:
# Inner list = one multi-part message; outer list = the list of messages
msgs = mk_msgs([['Whats in this img?',img_fn.read_bytes()]])
print(json.dumps(msgs,indent=1)[:200]+"...")

## Streaming

LiteLLM supports streaming responses. That's really useful if you want to show intermediate results, instead of having to wait until the whole response is finished.

We create this helper function that returns the entire response at the end of the stream. This is useful when you want to store the whole response somewhere after having displayed the intermediate results.

In [None]:
#| export
def stream_with_complete(gen, postproc=noop):
    """Re-yield every chunk of the streaming response `gen`. Once it is exhausted, call `postproc` on the
    collected chunks and return the complete response built by `stream_chunk_builder` (capture it with
    `yield from` or `SaveReturn`)."""
    collected = []
    for piece in gen:
        collected.append(piece)
        yield piece
    postproc(collected)  # may mutate chunks in place before they are assembled
    return stream_chunk_builder(collected)

In [None]:
# Stream a reply; SaveReturn keeps the generator's return value (the complete response) in `.value`
r = c(mk_msgs("Hey!"), stream=True)
r2 = SaveReturn(stream_with_complete(r))

In [None]:
# Print text deltas as they arrive
for o in r2:
    cts = o.choices[0].delta.content
    if cts: print(cts, end='')

In [None]:
# The complete response assembled after streaming
r2.value

## Tools

In [None]:
#| export
def lite_mk_func(f):
    "Return `f` unchanged if it is already a tool-schema dict, otherwise build a `{'type':'function', ...}` tool schema from it."
    return f if isinstance(f, dict) else {'type':'function', 'function':get_schema(f, pname='parameters')}

In [None]:
# Demo tool. Its docstring and parameter comments presumably feed the tool schema built by `lite_mk_func`,
# so editing them changes what the model is told (and invalidates cached responses).
def simple_add(
    a: int,   # first operand
    b: int=0  # second operand
) -> int:
    "Add two numbers together"
    return a + b

In [None]:
# Tool schema for simple_add
toolsc = lite_mk_func(simple_add)
toolsc

In [None]:
# Ask for two additions with the tool available
tmsg = mk_msg("What is 5478954793+547982745? How about 5479749754+9875438979? Always use tools for calculations, and describe what you'll do before using a tool. Where multiple tool calls are required, do them in a single response where possible. ")
r = c(tmsg, tools=[toolsc])

In [None]:
# Render the response, including any requested tool calls
display(r)

A tool response can be a string or a list of tool blocks (e.g., an image URL block). To let users specify that a response should not be stringified immediately, we provide the `ToolResponse` datatype, which they can wrap their return value in.

In [None]:
#| export
@dataclass
class ToolResponse:
    "Wrap a tool's return value so `content` is passed through as the tool result instead of being stringified."
    content: list  # list of content blocks (e.g. an image url block); `list[str,str]` was not a valid list type

In [None]:
#| export
def _lite_call_func(tc, tool_schemas, ns, raise_on_err=True):
    """Run tool call `tc` using the functions in namespace `ns` and return a LiteLLM `tool` role message.

    Tools missing from `tool_schemas` and unparseable arguments produce an explanatory `content` string
    instead of a call. Exceptions raised by the tool propagate when `raise_on_err` is True; otherwise they
    are reported back to the model as the tool result."""
    fn, valid = tc.function.name, {nested_idx(o,'function','name') for o in tool_schemas or []}
    if fn not in valid: res = f"Tool not defined in tool_schemas: {fn}"
    else:
        # a parameterless call may arrive with '' or None as its arguments
        try: args = json.loads(tc.function.arguments or '{}')
        except json.JSONDecodeError: res = f"Failed to parse function arguments: {tc.function.arguments}"
        else:
            # kept separate from argument parsing so errors raised *inside* the tool aren't mislabelled
            try: res = call_func(fn, args, ns=ns)
            except Exception as e:
                if raise_on_err: raise
                res = f"Error calling {fn}: {e!r}"
            else: res = res.content if isinstance(res, ToolResponse) else str(res)
    return {"tool_call_id": tc.id, "role": "tool", "name": fn, "content": res}

In [None]:
# Execute each requested tool call locally
tcs = [_lite_call_func(o, [toolsc], ns=globals()) for o in r.choices[0].message.tool_calls]
tcs

Test tool calls that were not in tool_schemas are caught:

In [None]:
fake_tc = ChatCompletionMessageToolCall(index=0, function=Function(name='hallucinated_tool'),id='_', type='function')
test_eq(_lite_call_func(fake_tc, ns=globals(), tool_schemas=[toolsc])['content'],"Tool not defined in tool_schemas: hallucinated_tool")
# With no schemas at all, every tool is undefined. (`test_fail` was given an already-computed string rather
# than a callable, so that check passed vacuously.)
test_eq(_lite_call_func(fake_tc, ns=globals(), tool_schemas=None)['content'],"Tool not defined in tool_schemas: hallucinated_tool")

Next, a helper that extracts printable content (text, tool-call names, reasoning markers) from a streaming delta:

In [None]:
def delta_text(msg):
    "Extract printable content from streaming delta, return None if nothing to print"
    # some chunks may carry no choices at all (presumably e.g. usage-only chunks); indexing them would raise
    if not (choices := getattr(msg, 'choices', None)): return None
    ch = choices[0]
    if not ch: return ch
    delta = getattr(ch, 'delta', None)
    if delta is None: return None
    if delta.content: return delta.content
    if delta.tool_calls:
        # tool-call deltas lacking an id or a function name print nothing
        res = ''.join(f"🔧 {tc.function.name}" for tc in delta.tool_calls if tc.id and tc.function.name)
        if res: return f'\n{res}\n'
    # reasoning text prints as a brain marker; a delta with an empty reasoning field prints a paragraph break
    if hasattr(delta,'reasoning_content'): return '🧠' if delta.reasoning_content else '\n\n'
    return None

In [None]:
# Stream a tool-using response, printing text and the names of requested tools
r = c(tmsg, stream=True, tools=[toolsc])
r2 = SaveReturn(stream_with_complete(r))
for o in r2: print(delta_text(o) or '', end='')

In [None]:
# Complete streamed response
r2.value

In [None]:
# Stream with low reasoning effort; reasoning chunks print as a brain marker
msg = mk_msg("Solve this complex math problem: What is the derivative of x^3 + 2x^2 - 5x + 1?")
r = c(msg, stream=True, reasoning_effort="low")
r2 = SaveReturn(stream_with_complete(r))
for o in r2: print(delta_text(o) or '', end='')

In [None]:
# Complete response, including the reasoning
r2.value

## Structured Outputs

In [None]:
#| export
@delegates(completion)
def structured(
    m:str,          # LiteLLM model string
    msgs:list,      # List of messages 
    tool:Callable,  # Tool to be used for creating the structured output (class, dataclass or Pydantic, function, etc)
    **kwargs):
    "Force the model to call `tool` and return `tool` invoked with the arguments it chose (typically for structured outputs)"
    schema = lite_mk_func(tool)
    resp = completion(m, msgs, tools=[schema], tool_choice=schema, **kwargs)
    first_call = resp.choices[0].message.tool_calls[0]
    call_args = json.loads(first_call.function.arguments)
    return tool(**call_args)

In [None]:
# Demo target for `structured`. Its docstring and inline parameter comments presumably feed the tool schema
# built by `lite_mk_func`, so editing them changes the request sent to the model (and cache hits).
# NOTE: the `assert` validation below is skipped under `python -O`; acceptable for a demo.
class President:
    "Information about a president of the United States"
    def __init__(
        self, 
        first:str, # first name
        last:str, # last name
        spouse:str, # name of spouse
        years_in_office:str, # format: "{start_year}-{end_year}"
        birthplace:str, # name of city
        birth_year:int # year of birth, `0` if unknown
    ):
        assert re.match(r'\d{4}-\d{4}', years_in_office), "Invalid format: `years_in_office`"
        store_attr()

    __repr__ = basic_repr('first, last, spouse, years_in_office, birthplace, birth_year')

In [None]:
# Structured output from each model except ms[0]
for m in ms[1:]: 
    r = structured(m, [mk_msg("Tell me something about the third president of the USA.")], President)
    test_eq(r.first, 'Thomas'); test_eq(r.last, 'Jefferson')

## Search

LiteLLM provides search, not via tools, but via the special `web_search_options` param.

**Note:** Not all models support web search. LiteLLM's `supports_web_search` field should indicate this, but it's unreliable for some models like `claude-sonnet-4-20250514`. Checking both `supports_web_search` and `search_context_cost_per_query` provides more accurate detection.

In [None]:
#| export
def _has_search(m):
    "True if LiteLLM's model info for `m` indicates web-search support; False for models LiteLLM has no info on."
    try: i = get_model_info(m)
    except Exception: return False  # get_model_info raises for models missing from LiteLLM's model map
    # `supports_web_search` alone is unreliable for some models, so also check for a search price
    return bool(i.get('search_context_cost_per_query') or i.get('supports_web_search'))

In [None]:
# Which of the example models support web search
for m in ms: print(m, _has_search(m))

When search is supported it can be used like this:

In [None]:
# Non-streaming search request with the default model
smsg = mk_msg("Search the web and tell me very briefly about otters")
r = c(smsg, web_search_options={"search_context_size": "low"})  # or 'medium' / 'high'
r

## Citations

Next, let's handle Anthropic's search citations.

When not using streaming, all citations are placed in a separate key in the response:

In [None]:
# Keys of Gemini's grounding metadata
r['vertex_ai_grounding_metadata'][0].keys()

In [None]:
# Search queries the model issued
r['vertex_ai_grounding_metadata'][0]['webSearchQueries']

Web search results:

In [None]:
# First few web search results
r['vertex_ai_grounding_metadata'][0]['groundingChunks'][:3]

Citations in gemini: 

In [None]:
# First few citation supports
r['vertex_ai_grounding_metadata'][0]['groundingSupports'][:3]

In [None]:
# r.choices[0].message.provider_specific_fields['citations'][0]

However, when streaming, the results are not captured this way.
Instead, we provide this helper function that adds the citation to the `content` field in markdown format:

In [None]:
#| export
def cite_footnote(msg):
    "If stream chunk `msg` carries a citation with a url, replace its delta content with a markdown footnote link (in place)."
    if not (delta:=nested_idx(msg, 'choices', 0, 'delta')): return
    if citation:= nested_idx(delta, 'provider_specific_fields', 'citation'):
        # a citation without a url can't become a link; a missing/None title becomes an empty link title
        if not (url := nested_idx(citation, 'url')): return
        title = (nested_idx(citation, 'title') or '').replace('"', '\\"')
        delta.content = f'[*]({url} "{title}") '
        
def cite_footnotes(stream_list):
    "Rewrite, in place, every chunk of `stream_list` that carries a citation into a markdown footnote link"
    for chunk in stream_list:
        cite_footnote(chunk)

In [None]:
# Stream a search response, turn citations into markdown footnotes, then rebuild the full response
r = list(c(smsg, ms[2], stream=True, web_search_options={"search_context_size": "low"}))
cite_footnotes(r)
stream_chunk_builder(r)

# Chat

LiteLLM is pretty bare bones. It doesn't keep track of conversation history or of which tools have been added to the conversation so far.

So let's make a Claudette-style wrapper so we can do streaming, tool calling, and tool loops without problems.

In [None]:
#| export
# Maps the one-letter shorthands to LiteLLM effort levels: effort.l=='low', effort.m=='medium', effort.h=='high'
effort = AttrDict({o[0]:o for o in ('low','medium','high')})

In [None]:
#| export
def _mk_prefill(pf):
    "Wrap prefill text `pf` as a single assistant streaming chunk, so it can be emitted ahead of the real stream"
    prefill_delta = Delta(content=pf, role='assistant')
    return ModelResponseStream([StreamingChoices(delta=prefill_delta)])

When tool uses are about to be exhausted, it is important to alert the AI so that it uses its final steps to communicate current progress and next steps to the user.

In [None]:
#| export
def _trunc_str(s, mx=2000, replace="<TRUNCATED>"):
    "Truncate `s` to `mx` chars max, adding `replace` if truncated"
    s = str(s).strip()
    if len(s)<=mx: return s
    s = s[:mx]
    ss = s.split(' ')
    if len(ss[-1])>50: ss[-1] = ss[-1][:5]
    s = ' '.join(ss)
    return s+replace

In [None]:
#| export
# User message sent (as the default `final_prompt`) once the tool-call budget for a turn is used up
_final_prompt = dict(role="user", content="You have used all your tool calls for this turn. Please summarize your findings. If you did not complete your goal, tell the user what further work is needed. You may use tools again on the next user message.")

# Prefix for an oversized tool result that is truncated after a ContextWindowExceededError
_cwe_msg = "ContextWindowExceededError: Do no more tool calls and complete your response now. Inform user that you ran out of context and explain what the cause was. This is the response to this tool call, truncated if needed: "

In [None]:
#| export
class Chat:
    "LiteLLM chat client that keeps the conversation history, tools and caching settings between calls."
    def __init__(
        self,
        model:str,                # LiteLLM compatible model name 
        sp='',                    # System prompt
        temp=0,                   # Temperature
        search=False,             # Search (l,m,h), if model supports it
        tools:list=None,          # Add tools
        hist:list=None,           # Chat history
        ns:Optional[dict]=None,   # Custom namespace for tool calling 
        cache=False,              # Anthropic prompt caching
        cache_idxs:list=[-1],     # Anthropic cache breakpoint idxs, use `0` for sys prompt if provided
        ttl=None,                 # Anthropic prompt caching ttl
        api_base=None,            # API base URL for custom providers
        api_key=None,             # API key for custom providers
    ):
        "LiteLLM chat client."
        self.model = model
        # private copy, so instances never share (or mutate) the default `[-1]` list
        cache_idxs = list(listify(cache_idxs))
        hist,tools = mk_msgs(hist,cache,cache_idxs,ttl),listify(tools)
        # tool calls are resolved in `ns`: a namespace built from `tools` if given, else this module's globals
        if ns is None and tools: ns = mk_ns(tools)
        elif ns is None: ns = globals()
        self.tool_schemas = [lite_mk_func(t) for t in tools] if tools else None
        store_attr()  # store constructor args (with the values computed above) as attributes
    
    def _prep_msg(self, msg=None, prefill=None):
        "Append `msg` to the history (re-applying cache breakpoints) and return system prompt + history + prefill for the API call"
        sp = [{"role": "system", "content": self.sp}] if self.sp else []
        if sp:
            # idx 0 targets the system prompt: honour it only when caching is enabled, with the same ttl as the
            # history breakpoints (Anthropic's prompt-caching docs require longer-ttl breakpoints to come first)
            if self.cache and 0 in self.cache_idxs: sp[0] = _add_cache_control(sp[0], ttl=self.ttl)
            # the remaining idxs index `hist`, which excludes the system prompt: drop 0 and shift positives by one
            cache_idxs = L(self.cache_idxs).filter().map(lambda o: o-1 if o>0 else o)
        else:
            cache_idxs = self.cache_idxs
        if msg: self.hist = mk_msgs(self.hist+[msg], self.cache and 'claude' in self.model, cache_idxs, self.ttl)
        pf = [{"role":"assistant","content":prefill}] if prefill else []
        return sp + self.hist + pf

`web_search` is now included in `tool_calls`, and the internal LLM translation is handled correctly thanks to the fix [here](https://github.com/BerriAI/litellm/pull/17746). However, server-side tools still need to be filtered out of `tool_calls` in our own toolloop.

In [None]:
#| export
def _filter_srvtools(tcs):
    "Drop server-side tool calls (ids starting with `srvtoolu_`); returns None when there are no tool calls at all."
    if not tcs: return None
    return L(tcs).filter(lambda call: not call.id.startswith('srvtoolu_'))

In [None]:
#| export
@patch
def _call(self:Chat, msg=None, prefill=None, temp=None, think=None, search=None, stream=False, max_steps=2, step=1, final_prompt=None, tool_choice=None, **kwargs):
    """Internal method that always yields responses.

    Yields this step's response (when streaming: the prefill chunk if any, every chunk, then the complete
    response), then one result message per client-side tool call, then recurses with `step+1` until
    `max_steps` is exceeded."""
    if step>max_steps: return
    try:
        model_info = get_model_info(self.model)
    except Exception:
        # model unknown to LiteLLM: register an empty entry, then retry the lookup
        register_model({self.model: {}})
        model_info = get_model_info(self.model)
    if not model_info.get("supports_assistant_prefill"): prefill=None
    if _has_search(self.model) and (s:=ifnone(search,self.search)): kwargs['web_search_options'] = {"search_context_size": effort[s]}
    else: _=kwargs.pop('web_search_options',None)
    if self.api_base: kwargs['api_base'] = self.api_base
    if self.api_key: kwargs['api_key'] = self.api_key
    res = completion(
        model=self.model, messages=self._prep_msg(msg, prefill), stream=stream, 
        tools=self.tool_schemas, reasoning_effort = effort.get(think), tool_choice=tool_choice,
        # temperature is not supported when reasoning
        temperature=None if think else ifnone(temp,self.temp),
        caching=self.cache and 'claude' not in self.model,
        **kwargs)
    if stream:
        if prefill: yield _mk_prefill(prefill)
        res = yield from stream_with_complete(res,postproc=cite_footnotes)
    m = contents(res)
    # Re-attach the prefill to the returned continuation. `content` can be None (e.g. a reply that only
    # contains tool calls), which previously raised TypeError here.
    if prefill: m.content = prefill + (m.content or '')
    self.hist.append(m)
    yield res

    # server-side tool calls (`srvtoolu_` ids) are not executed here
    if tcs := _filter_srvtools(m.tool_calls):
        tool_results=[_lite_call_func(tc, self.tool_schemas, self.ns) for tc in tcs]
        self.hist+=tool_results
        for r in tool_results: yield r
        # last step: send the final prompt and disable further tool use and search
        if step>=max_steps-1: prompt,tool_choice,search = final_prompt,'none',False
        else: prompt = None
        try: yield from self._call(
            prompt, prefill, temp, think, search, stream, max_steps, step+1,
            final_prompt, tool_choice, **kwargs)
        except ContextWindowExceededError:
            # the dicts in `tool_results` are also in `self.hist`, so truncating them shrinks the history
            for t in tool_results:
                if len(t['content'])>1000: t['content'] = _cwe_msg + _trunc_str(t['content'], mx=1000)
            # retry once as the final step, with tools disabled
            yield from self._call(None, prefill, temp, think, search, stream, max_steps, max_steps, final_prompt, 'none', **kwargs)

In [None]:
#| export
@patch
@delegates(Chat._call)
def __call__(self:Chat,
             msg=None,          # Message str, or list of multiple message parts
             prefill=None,      # Prefill AI response if model supports it
             temp=None,         # Override temp set on chat initialization
             think=None,        # Thinking (l,m,h)
             search=None,       # Override search set on chat initialization (l,m,h)
             stream=False,      # Stream results
             max_steps=2, # Maximum number of tool calls
             final_prompt=_final_prompt, # Final prompt when tool calls have ran out 
             return_all=False,  # Returns all intermediate ModelResponses if not streaming and has tool calls
             **kwargs):
    "Send `msg`. When streaming, return the generator; otherwise return every yielded item if `return_all` is set, else only the last one."
    gen = self._call(msg, prefill, temp, think, search, stream, max_steps, 1, final_prompt, **kwargs)
    if stream: return gen  # the caller drives the generator
    return list(gen) if return_all else last(gen)

In [None]:
@patch(as_prop=True)
def cost(self: Chat):
    "Total cost of all responses in conversation history"
    # Sums litellm's `response_cost` from each entry's `_hidden_params`; a missing or None cost counts as 0.
    # Only entries that have a `choices` attribute (i.e. full response objects) are included.
    # NOTE(review): this reads `self.h`, but every other method visible in this notebook uses `self.hist`,
    # which appears to store `contents(res)` messages rather than full responses. Confirm that `Chat`
    # really defines `h` as a list of responses; otherwise this raises AttributeError (or always sums to 0).
    return sum(getattr(r, '_hidden_params', {}).get('response_cost')  or 0
               for r in self.h if hasattr(r, 'choices'))

In [None]:
#| export
@patch
def print_hist(self:Chat):
    "Print every entry of `self.hist`, each followed by a blank line."
    for entry in self.hist:
        print(entry, end='\n\n')

## Examples

### History tracking

In [None]:
# For each model (except ms[0]), check that the chat remembers the user's name on a later turn
for m in ms[1:]:
    chat = Chat(m)
    chat("Hey my name is Rens")
    r = chat("Whats my name")
    test_eq('Rens' in contents(r).content, True)
r

Now the chat keeps track of the conversation history!

History is stored in the `hist` attribute:

In [None]:
# Inspect the stored conversation history
chat.hist

In [None]:
chat.print_hist()

You can also pass an old chat history into new Chat objects:

In [None]:
# Seed a new Chat with the previous history; it should still know the name
for m in ms[1:]:
    chat2 = Chat(m, hist=chat.hist)
    r = chat2("What was my name again?")
    test_eq('Rens' in contents(r).content, True)
r

You can prefix an [OpenAI compatible model](https://docs.litellm.ai/docs/providers/openai_compatible) with 'openai/' and use an `api_base` and `api_key` argument to use models not registered with litellm.

```python
import os, litellm
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
c = Chat("openai/gpt-oss-20b", api_key=OPENROUTER_API_KEY, api_base=OPENROUTER_BASE_URL)
c("hi")
```

### Synthetic History Creation

Let's build the chat history step by step, so we can tweak anything we need to during testing.

In [None]:
pr = "What is 5 + 7? Use the tool to calculate it."
for m in ms[1:]:
    c = Chat(m, tools=[simple_add])
    res = c(pr)
    test_eq('12' in contents(res).content, True)
    # hist[1] is the assistant message that requested the simple_add tool call
    test_eq(nested_idx(c.hist,1,'tool_calls',0,'function','name'), 'simple_add')

Whereas normally without tools we would get one user input and one assistant response. Here we get two extra messages in between.
- An assistant message requesting the tools with arguments.
- A tool response with the result to the tool call.

In [None]:
c.print_hist()

Let's build this up manually so we have full control over the inputs.

In [None]:
#| export
def random_tool_id():
    "Return a pseudo-random tool-call id: 'toolu_' followed by 25 ASCII letters/digits."
    alphabet = string.ascii_letters + string.digits
    suffix = ''.join(random.choices(alphabet, k=25))
    return 'toolu_' + suffix

In [None]:
# Generate an example tool-call id
random_tool_id()

A tool call request can contain one or more tool calls. Let's make one.

In [None]:
#| export
def mk_tc(func, args, tcid=None, idx=1):
    """Build a tool-call dict in the shape litellm uses.

    `func` is the tool name and `args` its arguments (a JSON string in the examples). `tcid` is the
    call id, and a fresh `random_tool_id()` is used when it is falsy. `idx` is stored under 'index'.
    """
    call_id = tcid if tcid else random_tool_id()
    fn = {'arguments': args, 'name': func}
    return {'index': idx, 'function': fn, 'id': call_id, 'type': 'function'}

In [None]:
# One tool call asking simple_add to compute 5 + 7
tc = mk_tc(simple_add.__name__, json.dumps(dict(a=5, b=7)))
tc

This can then be packaged into the full `Message` object produced by the assistant.

In [None]:
def mk_tc_req(content, tcs):
    "Wrap `content` and the tool calls `tcs` in an assistant `Message` (tool calls kept as `Message` converts them)."
    return Message(role='assistant', content=content, tool_calls=tcs, function_call=None)

In [None]:
tc_cts = "I'll use the simple_add tool to calculate 5 + 7 for you."
# Assistant message that requests the tool call built above
tcq = mk_tc_req(tc_cts, [tc])
tcq

Notice how `Message` instantiation turns the tool calls into a list of `ChatCompletionMessageToolCall` objects by default. When the tools are executed these are converted back to dictionaries, so for consistency we want to keep them as dictionaries from the beginning.

In [None]:
#| export
def mk_tc_req(content, tcs):
    """Create an assistant `Message` requesting the tool calls `tcs`.

    `Message` converts each tool call into a `ChatCompletionMessageToolCall`. They are turned back
    into plain dicts, including the nested 'function' entry, so the synthetic history uses the same
    dict form that tool calls take once tools have been executed.
    """
    msg = Message(content=content, role='assistant', tool_calls=tcs, function_call=None)
    as_dicts = []
    for call in msg.tool_calls:
        entry = dict(call)
        entry['function'] = dict(call['function'])
        as_dicts.append(entry)
    msg.tool_calls = as_dicts
    return msg

In [None]:
# Tool calls are now plain dicts
tcq = mk_tc_req(tc_cts, [tc])
tcq

In [None]:
# Synthetic history: the user prompt followed by the assistant's tool-call request
c = Chat(model, tools=[simple_add], hist=[pr, tcq])

In [None]:
c.print_hist()

Looks good so far! Now we will want to provide the actual result!

In [None]:
#| export
def mk_tc_result(tc, result):
    "Build the 'tool' role message that answers tool call `tc` with `result` as its content."
    call_id, fn_name = tc['id'], tc['function']['name']
    return {'tool_call_id': call_id, 'role': 'tool', 'name': fn_name, 'content': result}

Note that there may be more than one tool call if more than one was requested; here we just make one result.

In [None]:
# Tool calls are plain dicts, so they can be passed straight to mk_tc_result
tcq.tool_calls[0]

In [None]:
mk_tc_result(tcq.tool_calls[0], '12')

In [None]:
#| export
def mk_tc_results(tcq, results):
    "Pair each tool call of `tcq` with the matching entry of `results` (extras on either side are dropped by `zip`)."
    pairs = zip(tcq.tool_calls, results)
    return [mk_tc_result(call, res) for call, res in pairs]

The same applies here: `tcq.tool_calls` should match the number of results passed in the `results` list (any extra items on either side are silently dropped by `zip`).

In [None]:
tcq

In [None]:
# One tool-result message per tool call
tcr = mk_tc_results(tcq, ['12'])
tcr

Now we can call it with this synthetic data to see what the response is!

In [None]:
# Continue the synthetic conversation by sending the tool result
c(tcr[0])

In [None]:
c.print_hist()

Let's try this again, but this time give it a result that is clearly wrong, for fun.

In [None]:
# Fresh chat with the same synthetic history
c = Chat(model, tools=[simple_add], hist=[pr, tcq])

In [None]:
# Deliberately wrong tool result: 5 + 7 is 12, not 13
tcr = mk_tc_results(tcq, ['13'])
tcr

In [None]:
c(tcr[0])

Let's make sure this works with multiple tool calls in the same assistant message.

In [None]:
# NOTE(review): both calls get mk_tc's default idx=1; pass idx explicitly if distinct indices matter
tcs = [
    mk_tc(simple_add.__name__, json.dumps({"a": 5, "b": 7})), 
    mk_tc(simple_add.__name__, json.dumps({"a": 6, "b": 7})), 
]

In [None]:
tcq = mk_tc_req("I will calculate these for you!", tcs)
tcq

In [None]:
tcr = mk_tc_results(tcq, ['12', '13'])

In [None]:
# History ends with the first tool result; the second one is sent as the new message
c = Chat(model, tools=[simple_add], hist=[pr, tcq, tcr[0]])

In [None]:
c(tcr[1])

In [None]:
c.print_hist()

In [None]:
# Real tool use, followed by a turn that depends on the earlier result
chat = Chat(ms[1], tools=[simple_add])
res = chat("What's 5 + 3? Use the `simple_add` tool.")
res

In [None]:
res = chat("Now, tell me a joke based on that result.")
res

In [None]:
chat.hist

### Images

In [None]:
# Multi-part message: a text prompt plus raw image bytes
for m in ms[1:]:
    chat = Chat(m)
    r = chat(['Whats in this img?',img_fn.read_bytes()])
    test_eq('puppy' in contents(r).content, True)
r

### Prefill

Prefill works as expected:

In [None]:
# Skip models whose litellm model info doesn't report assistant-prefill support;
# the returned content starts with the prefill itself
for m in ms[1:]:
    if not get_model_info(m)['supports_assistant_prefill']: continue
    chat = Chat(m)
    chat('Hi this is Rens!')
    r = chat("Spell my name",prefill="Your name is R E")
    test_eq(contents(r).content.startswith('Your name is R E N S'), True)

And the entire message is stored in the history, not just the generated part:

In [None]:
# chat.hist[-1]

### Streaming

In [None]:
from time import sleep

In [None]:
# Print streamed text as it arrives; complete ModelResponse objects are displayed instead
for m in ms[1:]:
    chat = Chat(m)
    stream_gen = chat("Count to 5", stream=True)
    for chunk in stream_gen:
        if isinstance(chunk, ModelResponse): display(chunk)
        else: print(delta_text(chunk) or '',end='')

Let's try prefill with streaming too:

In [None]:
# Disabled example: streaming combined with a prefill ("Okay! 6, 7")
# stream_gen = chat("Continue counting to 10","Okay! 6, 7",stream=True)
# for chunk in stream_gen:
#     if isinstance(chunk, ModelResponse): display(chunk)
#     else: print(delta_text(chunk) or '',end='')

### Tool use

OK, now let's test tool use:

In [None]:
# Each model calls simple_add and explains the result
for m in ms[1:]:
    display(Markdown(f'**{m}:**'))
    chat = Chat(m, tools=[simple_add])
    res = chat("What's 5 + 3? Use  the `simple_add` tool. Explain.")
    display(res)

### Thinking w tool use

In [None]:
# Only models that accept `reasoning_effort` run; return_all shows every intermediate item
for m in ms[1:]:
    _sparams = litellm.get_model_info(m)['supported_openai_params']
    if 'reasoning_effort' not in _sparams: continue
    display(Markdown(f'**{m}:**'))
    chat = Chat(m, tools=[simple_add])
    res = chat("What's 5 + 3?",think='l',return_all=True)
    display(*res)

### Search

In [None]:
# Stream a search-enabled reply and display only the complete ModelResponse objects
for m in ms[1:]:
    display(Markdown(f'**{m}:**'))
    chat = Chat(m)
    res = chat("Search the web and tell me very briefly about otters", search='l', stream=True)
    for o in res:
        if isinstance(o, ModelResponse): sleep(0.01); display(o)
        else: pass

### Multi tool calling

We can let the model call multiple tools in sequence using the `max_steps` parameter.

In [None]:
# Note: unlike the other examples, this loop also covers ms[0]
for m in ms:
    display(Markdown(f'**{m}:**'))
    chat = Chat(m, tools=[simple_add])
    res = chat("What's ((5 + 3)+7)+11? Work step by step", return_all=True, max_steps=5)
    for r in res: display(r)

Some models support parallel tool calling, i.e. sending multiple tool call requests in a single conversation step.

In [None]:
def multiply(a: int, b: int) -> int:
    "Multiply two numbers"
    # The docstring above is the tool description sent to the model, so it is kept verbatim.
    product = a * b
    return product

# Only models that advertise `parallel_tool_calls` support are run
for m in ms[1:]:
    _sparams = litellm.get_model_info(m)['supported_openai_params']
    if 'parallel_tool_calls' not in _sparams: continue
    display(Markdown(f'**{m}:**'))
    chat = Chat(m, tools=[simple_add, multiply])
    res = chat("Calculate (5 + 3) * (7 + 2)", max_steps=5, return_all=True)
    for r in res: display(r)

See how the additions are calculated in one go!

We don't want the model to keep running tools indefinitely. Let's show how we can force the model to stop after a specified number of tool-call rounds:

In [None]:
def divide(a: int, b: int) -> float:
    "Divide two numbers"
    # The docstring above is the tool description sent to the model, so it is kept verbatim.
    quotient = a / b
    return quotient

# A custom final_prompt is sent once the tool-call budget (max_steps) is used up
chat = Chat(model, tools=[simple_add, multiply, divide])
res = chat("Calculate ((10 + 5) * 3) / (2 + 1) step by step.", 
           max_steps=3, return_all=True,
           final_prompt="Please wrap-up for now and summarize how far we got.")
for r in res: display(r)

In [None]:
#| hide
# max_steps=3 yields exactly three model responses in total
test_eq(len([o for o in res if isinstance(o,ModelResponse)]),3)

### Tool call exhaustion

In [None]:
pr = "What is 1+2, and then the result of adding +2, and then +3 to it? Use tools to make the calculations!"
c = Chat(model, tools=[simple_add])

In [None]:
# With max_steps=2 the tool budget runs out, so the default `_final_prompt` is sent
res = c(pr, max_steps=2)
res

In [None]:
# The final prompt is the user message right before the last reply
assert c.hist[-2] == _final_prompt

## Async

### AsyncChat

If you want to use LiteLLM in a webapp you probably want to use their async function `acompletion`.
To make that easier we will implement our version of `AsyncChat` to complement it. It follows the same implementation as Chat as much as possible:

In [None]:
#| export
async def _alite_call_func(tc, tool_schemas, ns, raise_on_err=True):
    """Async counterpart of `_lite_call_func`: run one tool call and return its 'tool' result message.

    The tool must be named in `tool_schemas` and its JSON arguments must parse; otherwise the result
    content is an explanatory error string rather than an exception. `raise_on_err` is forwarded to
    toolslm's `call_func_async` and decides whether an exception raised by the tool itself propagates
    (True) or comes back as text.
    """
    fn, valid = tc.function.name, {nested_idx(o,'function','name') for o in tool_schemas or []}
    if fn not in valid: res = f"Tool not defined in tool_schemas: {fn}"
    else:
        try: fargs = json.loads(tc.function.arguments)
        except json.JSONDecodeError: res = f"Failed to parse function arguments: {tc.function.arguments}"
        else:
            # Forward `raise_on_err` so callers actually control how tool exceptions are handled.
            res = await call_func_async(fn, fargs, ns=ns, raise_on_err=raise_on_err)
            # Tools may return a ToolResponse wrapper; anything else is stringified.
            res = res.content if isinstance(res, ToolResponse) else str(res)
    return {"tool_call_id": tc.id, "role": "tool", "name": fn, "content": res}

Testing the scenarios where the tool call is not in the schemas, or the schemas are missing:

In [None]:
# A tool name missing from the schemas is reported in the result instead of being executed
result = await _alite_call_func(fake_tc, [toolsc], globals())
test_eq(result['content'], "Tool not defined in tool_schemas: hallucinated_tool")

In [None]:
# Same outcome when no schemas are given at all
result = await _alite_call_func(fake_tc, None, globals())
test_eq(result['content'], "Tool not defined in tool_schemas: hallucinated_tool")

In [None]:
#| export
@asave_iter
async def astream_with_complete(self, agen, postproc=noop):
    "Re-yield each chunk of `agen` after passing it to `postproc`, then keep the assembled response in `self.value`."
    collected = []
    async for piece in agen:
        collected.append(piece)
        postproc(piece)
        yield piece
    self.value = stream_chunk_builder(collected)

In [None]:
#| export
class AsyncChat(Chat):
    "Async version of `Chat`: `_call` is an async generator built on litellm's `acompletion`."
    async def _call(self, msg=None, prefill=None, temp=None, think=None, search=None, stream=False, max_steps=2, step=1, final_prompt=None, tool_choice=None, **kwargs):
        """Run one model step, then recurse while the model keeps requesting tools.

        Yields, in order: stream chunks (when `stream` is set), the complete response, one
        tool-result dict per executed tool call, then everything yielded by the next step.
        """
        if step>max_steps+1: return
        if not get_model_info(self.model).get("supports_assistant_prefill"): prefill=None
        if _has_search(self.model) and (s:=ifnone(search,self.search)): kwargs['web_search_options'] = {"search_context_size": effort[s]}
        else: _=kwargs.pop('web_search_options',None)
        res = await acompletion(model=self.model, messages=self._prep_msg(msg, prefill), stream=stream,
                         tools=self.tool_schemas, reasoning_effort=effort.get(think), tool_choice=tool_choice,
                         # temperature is not supported when reasoning
                         temperature=None if think else ifnone(temp,self.temp), 
                         caching=self.cache and 'claude' not in self.model,
                         **kwargs)
        if stream:
            if prefill: yield _mk_prefill(prefill)
            res = astream_with_complete(res,postproc=cite_footnote)
            async for chunk in res: yield chunk
            res = res.value
        m=contents(res)
        if prefill: m.content = prefill + m.content
        # Record the response before yielding it (as the sync `Chat._call` does), so the history
        # stays complete even if the consumer stops iterating at this point.
        self.hist.append(m)
        yield res

        if tcs := _filter_srvtools(m.tool_calls):
            tool_results = []
            for tc in tcs:
                result = await _alite_call_func(tc, self.tool_schemas, self.ns)
                tool_results.append(result)
                # Stored before being yielded, for the same reason as the response above.
                self.hist.append(result)
                yield result
            # Last allowed step: send the final prompt and disable tools.
            if step>=max_steps-1: prompt,tool_choice,search = final_prompt,'none',False
            else: prompt = None
            try:
                async for result in self._call(
                    prompt, prefill, temp, think, search, stream, max_steps, step+1,
                    final_prompt, tool_choice=tool_choice, **kwargs): yield result
            except ContextWindowExceededError:
                # Shrink oversized tool outputs (the same dicts stored in `hist`), then make one final
                # tool-free attempt. As in the sync version no new prompt is passed: the failed
                # attempt appears to have already added `prompt` to the history via `_prep_msg`.
                for t in tool_results:
                    if len(t['content'])>1000: t['content'] = _cwe_msg + _trunc_str(t['content'], mx=1000)
                async for result in self._call(
                    None, prefill, temp, think, search, stream, max_steps, max_steps,
                    final_prompt, tool_choice='none', **kwargs): yield result

In [None]:
#| export
@patch
@delegates(Chat._call)
async def __call__(
    self:AsyncChat,
    msg=None,          # Message str, or list of multiple message parts
    prefill=None,      # Prefill AI response if model supports it
    temp=None,         # Override temp set on chat initialization
    think=None,        # Thinking (l,m,h)
    search=None,       # Override search set on chat initialization (l,m,h)
    stream=False,      # Stream results
    max_steps=2, # Maximum number of tool calls
    final_prompt=_final_prompt, # Final prompt when tool calls have ran out 
    return_all=False,  # Returns all intermediate ModelResponses if not streaming and has tool calls
    **kwargs
):
    """Async main call method.

    With `stream` or `return_all` the async generator from `_call` is returned for the caller to
    iterate. Otherwise it is drained and only the last yielded item is returned.
    """
    result_gen = self._call(msg, prefill, temp, think, search, stream, max_steps, 1, final_prompt, **kwargs)
    if stream or return_all: return result_gen
    # Default to None so an empty generator doesn't raise UnboundLocalError
    # (the sync version uses fastcore's `last`, which also returns None when there is nothing).
    res = None
    async for res in result_gen: pass
    return res # normal chat behavior only return last msg

### Examples

Basic example

In [None]:
# Non-streaming: awaiting the call returns the final response
for m in ms[1:]:
    chat = AsyncChat(m)
    test_eq('4' in contents(await chat("What is 2+2?")).content, True)

With tool calls

In [None]:
async def async_add(a: int, b: int) -> int:
    "Add two numbers asynchronously"
    # The docstring above is the tool description sent to the model, so it is kept verbatim.
    await asyncio.sleep(0.1)  # simulate some asynchronous work
    total = a + b
    return total

In [None]:
# Async tool; hist[1] should be the assistant message requesting async_add
for m in ms[1:]:
    chat = AsyncChat(m, tools=[async_add])
    r = await chat("What is 5 + 7? Use the tool to calculate it.")
    test_eq('12' in contents(r).content, True)
    test_eq(nested_idx(chat.hist, 1, 'tool_calls', 0, 'function', 'name'), 'async_add')

## Async Streaming Display

This is what our outputs look like with streaming results:

In [None]:
# Stream chunks print their text; tool results arrive as plain dicts
chat_with_tools = AsyncChat(model, tools=[async_add])
res = await chat_with_tools("What is 5 + 7? Use the tool to calculate it.", stream=True)
async for o in res:
    if isinstance(o,ModelResponseStream): print(delta_text(o) or '',end='')
    elif isinstance(o,dict): print(o)

Here's a complete `ModelResponse` taken from the response stream:

In [None]:
resp = ModelResponse(id='chatcmpl-xxx', created=1000000000, model='claude-sonnet-4-5', object='chat.completion', system_fingerprint=None, choices=[Choices(finish_reason='tool_calls', index=0, message=Message(content="I'll calculate ((10 + 5) * 3) / (2 + 1) step by step:", role='assistant', tool_calls=[ChatCompletionMessageToolCall(function=Function(arguments='{"a": 10, "b": 5}', name='simple_add'), id='toolu_018BGyenjiRkDQFU1jWP6qRo', type='function'), ChatCompletionMessageToolCall(function=Function(arguments='{"a": 2, "b": 1}', name='simple_add'), id='toolu_01CWqrNQvoRjf1Q1GLpTUgQR', type='function')], function_call=None, provider_specific_fields=None))], usage=Usage(completion_tokens=228, prompt_tokens=794, total_tokens=1022, prompt_tokens_details=None))
print(repr(resp))

In [None]:
tc=resp.choices[0].message.tool_calls[0]
tc

In [None]:
# Synthetic tool result with oversized content (2000+ chars) to exercise truncation
tr={'tool_call_id': 'toolu_018BGyenjiRkDQFU1jWP6qRo', 'role': 'tool','name': 'simple_add',
    'content': '15 is the answer! ' +'.'*2000}

In [None]:
#| export
def _trunc_param(v, mx=50):
    "Truncate and escape param value for display"
    # Backticks are escaped because the value is shown inside a markdown <summary> line.
    # The value is then cut to `mx` characters, and the cut is marked with a real ellipsis
    # character ('…', U+2026) rather than a mis-encoded byte sequence.
    return _trunc_str(str(v).replace('`', r'\`'), mx=mx, replace='…')

def mk_tr_details(tr, tc, mx=2000):
    """Create a collapsible `<details>` markdown block describing one tool call and its result.

    `tr` is a tool-result dict (uses 'tool_call_id' and 'content') and `tc` the matching tool call.
    Argument values and the result are truncated to `mx` characters in the JSON body, and the
    `<summary>` line shows a shorter, backtick-escaped rendering of each argument. If the call's
    arguments are not a JSON object (e.g. malformed JSON, which `_alite_call_func` reports as a
    tool result instead of raising), the raw argument text is shown instead of crashing the display.
    """
    raw = tc.function.arguments
    try: parsed = json.loads(raw)
    except (json.JSONDecodeError, TypeError): parsed = None
    if isinstance(parsed, dict):
        args = {k:_trunc_str(v, mx=mx) for k,v in parsed.items()}
        params = ', '.join(f"{k}={_trunc_param(v)}" for k,v in args.items())
    else:
        # Malformed, empty, or non-object arguments: show the raw text as-is (truncated).
        args = _trunc_str(str(raw or ''), mx=mx)
        params = _trunc_param(raw or '')
    res = {'id':tr['tool_call_id'], 
           'call':{'function': tc.function.name, 'arguments': args},
           'result':_trunc_str(tr.get('content'), mx=mx),}
    summ = f"<summary>{tc.function.name}({params})</summary>"
    return f"\n\n{detls_tag}\n{summ}\n\n```json\n{dumps(res, indent=2)}\n```\n\n</details>\n\n"

In [None]:
# Render the synthetic call/result, truncating long values to 300 characters
mk_tr_details(tr,tc,mx=300)

In [None]:
#| export
class StreamFormatter:
    "Convert items from a chat response stream into markdown, accumulating everything in `outp`."
    def __init__(self, include_usage=False, mx=2000, debug=False):
        # outp: markdown produced so far; tcs: tool calls of the latest complete response, keyed by id;
        # mx: truncation length for tool details; debug: print every raw item.
        self.outp,self.tcs,self.include_usage,self.mx,self.debug = '',{},include_usage,mx,debug
    
    def format_item(self, o):
        "Format a single item from the response stream and return the markdown added for it."
        res = ''
        if self.debug: print(o)
        if isinstance(o, ModelResponseStream):
            d = o.choices[0].delta
            # Each reasoning chunk appends a '🧠' marker. A run of markers that follows other text starts a
            # new paragraph, and so does the first non-reasoning chunk after one. The marker must be a
            # single code point so that comparing it with the last character of the output can succeed.
            if nested_idx(d, 'reasoning_content') and d['reasoning_content']!='{"text": ""}':
                res+= '🧠' if not self.outp or self.outp[-1]=='🧠' else '\n\n🧠' # gemini can interleave reasoning
            elif self.outp and self.outp[-1] == '🧠': res+= '\n\n'
            if c:=d.content: # gemini has text content in last reasoning chunk
                res+=f"\n\n{c}" if res and res[-1] == '🧠' else c
            # `images` may be absent or None on a delta; treat both as "no images".
            for img in getattr(d, 'images', None) or []: res += f"\n\n![generated image]({nested_idx(img, 'image_url', 'url')})\n\n"
        elif isinstance(o, ModelResponse):
            if self.include_usage: res += f"\nUsage: {o.usage}"
            # Remember this response's tool calls so that later tool results can be matched by id.
            if c:=getattr(contents(o),'tool_calls',None):
                self.tcs = {tc.id:tc for tc in c}
        elif isinstance(o, dict) and 'tool_call_id' in o:
            # Tool result: render it together with its call (KeyError if the call id is unknown).
            res += mk_tr_details(o, self.tcs.pop(o['tool_call_id']), mx=self.mx)
        self.outp+=res
        return res
    
    def format_stream(self, rs):
        "Format the response stream for markdown display."
        for o in rs: yield self.format_item(o)

In [None]:
# A plain text chunk is returned as its text
stream_msg = ModelResponseStream([StreamingChoices(delta=Delta(content="Hello world!"))])
StreamFormatter().format_item(stream_msg)

In [None]:
# A reasoning-only chunk renders as the reasoning marker
reasoning_msg = ModelResponseStream([StreamingChoices(delta=Delta(reasoning_content="thinking..."))])
StreamFormatter().format_item(reasoning_msg)

In [None]:
#| export
class AsyncStreamFormatter(StreamFormatter):
    "`StreamFormatter` for async response streams: `format_stream` is an async generator."
    async def format_stream(self, rs):
        "Format the response stream for markdown display."
        async for item in rs:
            formatted = self.format_item(item)
            yield formatted

In [None]:
# Fake tool call plus a minimal response object whose only message carries it
mock_tool_call = ChatCompletionMessageToolCall(
    id="toolu_123abc456def", type="function", 
    function=Function( name="simple_add", arguments='{"a": 5, "b": 3}' )
)

mock_response = ModelResponse()
mock_response.choices = [type('Choice', (), {
    'message': type('Message', (), {
        'tool_calls': [mock_tool_call]
    })()
})()]

mock_tool_result = {
    'tool_call_id': mock_tool_call.id, 'role': 'tool', 
    'name': 'simple_add', 'content': '8'
}

In [None]:
# The response only registers the tool call (prints an empty string); the result then renders as a <details> block
fmt = AsyncStreamFormatter()
print(fmt.format_item(mock_response))
print('---')
print(fmt.format_item(mock_tool_result))

In Jupyter, it's convenient to combine this `StreamFormatter` with IPython's `Markdown` and `display`:

In [None]:
#| export
def display_stream(rs):
    "Render the response stream `rs` live as Markdown in IPython and return the `StreamFormatter` used."
    try: from IPython.display import display, Markdown
    except ModuleNotFoundError: raise ModuleNotFoundError("This function requires ipython. Please run `pip install ipython` to use.")
    fmt = StreamFormatter()
    for _ in fmt.format_stream(rs):
        # `fmt.outp` already holds everything formatted so far; redraw it in place.
        display(Markdown(fmt.outp),clear=True)
    return fmt

Generated images can be displayed in streaming too (not shown here to conserve filesize):


In [None]:
# Disabled to keep the notebook small: stream an image-generation response through display_stream
# rs = completion(model='gemini/gemini-2.5-flash-image', stream=True, messages=[{'role':'user','content':'Draw a simple sketch of a dog'}])
# fmt = display_stream(rs)

In [None]:
#| export
async def adisplay_stream(rs):
    "Render the async response stream `rs` live as Markdown in IPython and return the `AsyncStreamFormatter` used."
    try: from IPython.display import display, Markdown
    except ModuleNotFoundError: raise ModuleNotFoundError("This function requires ipython. Please run `pip install ipython` to use.")
    fmt = AsyncStreamFormatter()
    async for _ in fmt.format_stream(rs):
        # `fmt.outp` already holds everything formatted so far; redraw it in place.
        display(Markdown(fmt.outp),clear=True)
    return fmt

## Streaming examples

Now we can demonstrate `AsyncChat` with `stream=True`!

### Tool call

In [None]:
# Sync Chat: stream a tool-using reply and render it live as markdown
chat = Chat(model, tools=[simple_add])
res = chat("What is 5 + 7? Use the tool to calculate it.", stream=True)
fmt = display_stream(res)

In [None]:
# Same with AsyncChat and an async tool
chat = AsyncChat(model, tools=[async_add])
res = await chat("What is 5 + 7? Use the tool to calculate it.", stream=True)
fmt = await adisplay_stream(res)

In [None]:
chat = AsyncChat(model, tools=[async_add])
res = await chat("What is 5 + 3? Use the tool to calculate it.", stream=True)
fmt = await adisplay_stream(res)

### Thinking

In [None]:
# Streamed reasoning (think='l') without tools
chat = AsyncChat(model)
res = await chat("Briefly, what's the most efficient way to sort a list of 1000 random integers?", think='l',stream=True)
_ = await adisplay_stream(res)

### Multiple tool calls

In [None]:
#| hide
# Three-step tool budget with a custom final prompt
chat = AsyncChat(model, tools=[simple_add, multiply, divide])
res = await chat("Calculate ((10 + 5) * 3) / (2 + 1) Use parallel tool calls, but explain where we are after each batch.", 
           max_steps=3, stream=True,
           final_prompt="Please wrap-up for now and summarize how far we got.")
fmt = await adisplay_stream(res)

In [None]:
# Inspect individual history entries from the run above
chat.hist[1]

In [None]:
chat.hist[2]

In [None]:
chat.hist[3]

In [None]:
chat.hist[4]

In [None]:
chat.hist[5]

Now let's demonstrate that we can load the formatted output back into a new `Chat` object:

In [None]:
# Rebuild a history from the rendered markdown (fmt.outp) and continue the conversation from it
chat5 = Chat(model,hist=fmt2hist(fmt.outp),tools=[simple_add, multiply, divide])
chat5('what did we just do?')

### Search

In [None]:
chat_stream_tools = AsyncChat(model, search='l')
res = await chat_stream_tools("Search the weather in NYC", stream=True)
_=await adisplay_stream(res)

### Caching

#### Anthropic

We use explicit caching via cache control checkpoints. Anthropic requires exact match with cached tokens and even a small change results in cache invalidation.

In [None]:
# Turn off cachy's response caching for the provider-caching tests below
disable_cachy()

In [None]:
#| notest
# A long, repeated first message so the prompt is large enough to be cached (tests expect >1000 cached tokens)
a,b = random.randint(0,100), random.randint(0,100)
hist = [[f"What is {a}+{b}?\n" * 250], f"It's {a+b}", ['hi'], "Hello"]

In this first API call we will see the cache being created up to the last user message:

In [None]:
#| notest
# First call with cache=True: expect cache creation and no cache reads (checked below)
sleep(5)
chat = AsyncChat(ms[3], cache=True, hist=hist)
rs = await chat('hi again', stream=True, stream_options={"include_usage": True})
async for o in rs: 
    if isinstance(o, ModelResponse): print(o.usage)

In [None]:
#| notest
# `o` is the last item yielded above: the complete response carrying usage
test_eq(o.usage.cache_creation_input_tokens > 1000, True)
test_eq(o.usage.cache_read_input_tokens, 0)

In [None]:
#| notest
# Extend the history and call again; the cached prefix should now be read
hist.extend([['hi again'], 'how may i help you?'])
chat = AsyncChat(ms[3], cache=True, hist=hist)
rs = await chat('bye!', stream=True, stream_options={"include_usage": True})
async for o in rs:
    if isinstance(o, ModelResponse): print(o.usage)

In [None]:
#| notest
test_eq(o.usage.cache_read_input_tokens > 1000, True)

The subsequent call above re-used the existing cache, as the `cache_read_input_tokens` check confirms.

#### Gemini

Gemini implicit caching supports partial token matches. The usage metadata only shows cache hits with the `cached_tokens` field. So, to view them we need to run completions at least twice.

We test with `gemini-2.5-flash` until `gemini-3-pro-preview` is more reliable.

In [None]:
#| notest
# First call: populates Gemini's implicit cache (hits only show up on later calls)
chat = AsyncChat(ms[2], cache=True, hist=hist)
rs = await chat('hi again', stream=True, stream_options={"include_usage": True})
async for o in rs: 
    if isinstance(o, ModelResponse): print(o.usage)

Running the same completion again:

In [None]:
#| notest
sleep(5) # it takes a while for cached tokens to be avail.
chat = AsyncChat(ms[2], cache=True, hist=hist)
rs = await chat('hi again', stream=True, stream_options={"include_usage": True})
async for o in rs: 
    if isinstance(o, ModelResponse): print(o.usage)

In [None]:
#| notest
# `o` is the last item yielded above: the complete response carrying usage
test_eq(o.usage.prompt_tokens_details.cached_tokens > 1800, True)

In [None]:
#| notest
hist.extend([['hi again'], 'how may i help you?'])
chat = AsyncChat(ms[2], cache=True, hist=hist)
rs = await chat('bye!', stream=True, stream_options={"include_usage": True})
async for o in rs:
    if isinstance(o, ModelResponse): print(o.usage)

In [None]:
#| notest
test_eq(o.usage.prompt_tokens_details.cached_tokens > 1800, True)

Let's modify the cached content and see that partial matching works:

In [None]:
#| notest
# Keep the first 75% of the long first message and change the rest; only the shared prefix can hit the cache
c = hist[0][0]
hist[0][0] = c[:int(len(c)*0.75)] + " Some extra text"
hist.extend([['hi again'], 'how may i help you?'])
chat = AsyncChat(ms[2], cache=True, hist=hist)
rs = await chat('bye!', stream=True, stream_options={"include_usage": True})
async for o in rs:
    if isinstance(o, ModelResponse): print(o.usage)

In [None]:
#| notest
test_eq(o.usage.prompt_tokens_details.cached_tokens > 900, True)

# Export -

In [None]:
#| hide
# Export the `#| export` cells to the library module
import nbdev; nbdev.nbdev_export()