In [None]:
#| default_exp core

# Core
> The `FastHTML` subclass of `Starlette`, along with the request-parsing, routing and response-rendering helpers it uses.

This is the source code to fasthtml. You won't need to read this unless you want to understand how things are built behind the scenes, or need full details of a particular API. The notebook is converted to the Python module [fasthtml/core.py](https://github.com/AnswerDotAI/fasthtml/blob/main/fasthtml/core.py) using [nbdev](https://nbdev.fast.ai/).

## Imports and utils

In [None]:
#| export
import json,uuid,inspect,types,signal,asyncio,threading,inspect

from fastcore.utils import *
from fastcore.xml import *
from fastcore.meta import use_kwargs_dict

from types import UnionType, SimpleNamespace as ns, GenericAlias
from typing import Optional, get_type_hints, get_args, get_origin, Union, Mapping, TypedDict, List, Any
from datetime import datetime,date
from dataclasses import dataclass,fields
from collections import namedtuple
from inspect import isfunction,ismethod,Parameter,get_annotations
from functools import wraps, partialmethod, update_wrapper
from http import cookies
from urllib.parse import urlencode, parse_qs, quote, unquote
from copy import copy,deepcopy
from warnings import warn
from dateutil import parser as dtparse
from httpx import ASGITransport, AsyncClient
from anyio import from_thread
from uuid import uuid4
from base64 import b85encode,b64encode

from fasthtml.starlette import *

In [None]:
import time

from IPython import display
from enum import Enum
from pprint import pprint

from fastcore.test import *
from starlette.testclient import TestClient
from starlette.requests import Headers
from starlette.datastructures import UploadFile

In [None]:
#| export
def _params(f):
    "The parameters of `f`'s signature, as returned by fastcore's `signature_ex(f, True)`"
    sig = signature_ex(f, True)
    return sig.parameters

# Sentinel used throughout for "no annotation" and "no default"
empty = Parameter.empty

We write source code _first_, and then tests come _after_. The tests serve both to confirm that the code works and as working examples. The first exported function, `parsed_date`, is an example of this pattern.

In [None]:
#| export
def parsed_date(s:str):
    "Convert `s` to a datetime"
    # dateutil's parser accepts free-form input such as '2pm' as well as ISO dates
    result = dtparse.parse(s)
    return result

In [None]:
# Only a time is given, so dateutil fills in the date part (today's date, judging by the recorded output)
parsed_date('2pm')

datetime.datetime(2025, 1, 12, 14, 0)

In [None]:
# Sanity check: `date.fromtimestamp` returns a `date` instance
isinstance(date.fromtimestamp(0), date)

True

In [None]:
#| export
def snake2hyphens(s:str):
    "Convert `s` from snake case to hyphenated and capitalised"
    # snake_case -> SnakeCase -> Snake-Case
    return camel2words(snake2camel(s), '-')

In [None]:
# This is the form `_find_p` uses to look a parameter name up in the request headers
snake2hyphens("snake_case")

'Snake-Case'

In [None]:
#| export
# Field name in `HtmxHeaders` -> the htmx request header it is read from
htmx_hdrs = {
    'boosted':                 "HX-Boosted",
    'current_url':             "HX-Current-URL",
    'history_restore_request': "HX-History-Restore-Request",
    'prompt':                  "HX-Prompt",
    'request':                 "HX-Request",
    'target':                  "HX-Target",
    'trigger_name':            "HX-Trigger-Name",
    'trigger':                 "HX-Trigger",
}

@dataclass
class HtmxHeaders:
    boosted:str|None=None; current_url:str|None=None; history_restore_request:str|None=None; prompt:str|None=None
    request:str|None=None; target:str|None=None; trigger_name:str|None=None; trigger:str|None=None
    def __bool__(self): return any(hasattr(self,o) for o in htmx_hdrs)

def _get_htmx(h):
    "Build `HtmxHeaders` from the mapping `h`, looking each header up by its lower-cased name"
    vals = {}
    for field, hdr in htmx_hdrs.items(): vals[field] = h.get(hdr.lower(), None)
    return HtmxHeaders(**vals)

In [None]:
def test_request(url: str='/', headers: dict|None=None, method: str='get') -> Request:
    "Build a bare Starlette `Request` for `url` (which may include a query string) with `headers`, for tests"
    path,_,qs = url.partition('?')
    scope = {
        'type': 'http',
        'method': method.upper(),  # the ASGI spec requires upper-case method names
        'path': path,
        'headers': Headers(headers or {}).raw,
        'query_string': qs.encode(),
        'scheme': 'http',
        'client': ('127.0.0.1', 8000),
        'server': ('127.0.0.1', 8000),
    }
    # Starlette awaits `receive` and checks the message `type` when the body is read,
    # so this must be an async function returning a complete `http.request` message
    async def receive(): return {"type": "http.request", "body": b"", "more_body": False}
    return Request(scope, receive)

In [None]:
# Only `HX-Request` is sent, so only the `request` field is populated
h = test_request(headers=Headers({'HX-Request':'1'}))
_get_htmx(h.headers)

HtmxHeaders(boosted=None, current_url=None, history_restore_request=None, prompt=None, request='1', target=None, trigger_name=None, trigger=None)

In [None]:
#| export
def _mk_list(t, v):
    "Cast every element of `v` with `t`; a scalar `v` is first wrapped via `listify`"
    return list(map(t, listify(v)))

## Request and response

In [None]:
#| export
# Global FT rendering settings; `indent` is read whenever FT responses are rendered to XML
fh_cfg = AttrDict({'indent': True})

In [None]:
#| export
def _fix_anno(t, o):
    "Create appropriate callable type for casting a `str` to type `t` (or first type in `t` if union)"
    origin = get_origin(t)
    # Unwrap `X|None` / `Optional[X]` / `Union[X,...]` to the first non-None member, then re-inspect it,
    # so that e.g. `Optional[list[int]]` is still handled as a list annotation
    if origin is Union or origin is UnionType:
        t = first(a for a in get_args(t) if a!=type(None))
        origin = get_origin(t)
    is_list = origin in (list,List)
    # Element type of a list annotation; a bare `List` has no args, so its items pass through unchanged
    if is_list: t = first(a for a in get_args(t) if a!=type(None)) or noop
    d = {bool: str2bool, int: str2int, date: str2date, UploadFile: noop}
    res = d.get(t, t)
    if is_list: return _mk_list(res, o)
    if not isinstance(o, (str,list,tuple)): return o
    # Several values supplied for a scalar type: the last one wins
    return res(o[-1]) if isinstance(o,(list,tuple)) else res(o)

In [None]:
test_eq(_fix_anno(Union[str,None], 'a'), 'a')  # unions use their first non-None type
test_eq(_fix_anno(float, 0.9), 0.9)            # non-string values are returned as-is
test_eq(_fix_anno(int, '1'), 1)
test_eq(_fix_anno(int, ['1','2']), 2)          # several values for a scalar type: the last wins
test_eq(_fix_anno(list[int], ['1','2']), [1,2])
test_eq(_fix_anno(list[int], '1'), [1])        # a single value for a list type becomes a 1-element list

In [None]:
#| export
def _form_arg(k, v, d):
    "Get type by accessing key `k` from `d`, and use to cast `v`"
    if v is None: return
    if not isinstance(v, (str,list,tuple)): return v
    # This is the type we want to cast `v` to
    anno = d.get(k, None)
    if not anno: return v
    return _fix_anno(anno, v)

In [None]:
# `d` maps field names to the types their form values should be cast to
d = dict(k=int, l=List[int])
test_eq(_form_arg('k', "1", d), 1)
test_eq(_form_arg('l', "1", d), [1])
test_eq(_form_arg('l', ["1","2"], d), [1,2])

In [None]:
#| export
@dataclass
class HttpHeader:
    "An HTTP header as a name (`k`) / value (`v`) pair"
    k: str
    v: str

In [None]:
#| export
def _to_htmx_header(s):
    return 'HX-' + s.replace('_', '-').title()

htmx_resps = dict(location=None, push_url=None, redirect=None, refresh=None, replace_url=None,
                 reswap=None, retarget=None, reselect=None, trigger=None, trigger_after_settle=None, trigger_after_swap=None)

In [None]:
# snake_case name -> htmx response header name
_to_htmx_header('trigger_after_settle')

'HX-Trigger-After-Settle'

In [None]:
#| export
@use_kwargs_dict(**htmx_resps)
def HtmxResponseHeaders(**kwargs):
    "HTMX response headers"
    # htmx accepts JSON values for headers such as HX-Trigger and HX-Location. Header values are
    # later rendered with `str()`, which would give a Python repr (single quotes) htmx can't parse,
    # so dicts and lists are JSON-encoded here
    res = tuple(HttpHeader(_to_htmx_header(k), json.dumps(v) if isinstance(v, (dict,list)) else v)
                for k,v in kwargs.items())
    # A single header is returned bare rather than as a 1-tuple
    return res[0] if len(res)==1 else res

In [None]:
# A single header comes back as one `HttpHeader`, not a tuple
HtmxResponseHeaders(trigger_after_settle='hi')

HttpHeader(k='HX-Trigger-After-Settle', v='hi')

In [None]:
#| export
def _annotations(anno):
    "Same as `get_annotations`, but also works on namedtuples"
    # For namedtuples every field is treated as a `str`
    return {o:str for o in anno._fields} if is_namedtuple(anno) else get_annotations(anno)

In [None]:
#| export
def _is_body(anno):
    "Truthy if a param of type `anno` is built from the request body: a dict/SimpleNamespace subclass, or a type with annotated fields"
    if issubclass(anno, (dict,ns)): return True
    return _annotations(anno)

In [None]:
#| export
def _formitem(form, k):
    "Return single item `k` from `form` if len 1, otherwise return list"
    if isinstance(form, dict): return form[k]
    o = form.getlist(k)
    return o[0] if len(o) == 1 else o if o else None

In [None]:
#| export
def form2dict(form: FormData) -> dict:
    "Convert starlette form data to a dict"
    # Already a dict (e.g. a parsed JSON body)? Nothing to convert
    if isinstance(form, dict): return form
    res = {}
    for k in form: res[k] = _formitem(form, k)
    return res

In [None]:
# Repeated keys become lists; keys that occur once stay scalars
d = [('a',1),('a',2),('b',0)]
fd = FormData(d)
res = form2dict(fd)
test_eq(res['a'], [1,2])
test_eq(res['b'], 0)

In [None]:
#| export
async def parse_form(req: Request) -> FormData:
    "Starlette errors on empty multipart forms, so this checks for that situation (JSON bodies are returned parsed)"
    ctype = req.headers.get("Content-Type", "")
    # Compare only the media type: clients may append parameters, e.g. `application/json; charset=utf-8`
    if ctype.split(';')[0].strip().lower() == 'application/json': return await req.json()
    if not ctype.startswith("multipart/form-data"): return await req.form()
    try: boundary = ctype.split("boundary=")[1].strip()
    except IndexError: raise HTTPException(400, "Invalid form-data: no boundary")
    # len(boundary)+6 is the size of `--{boundary}--\r\n`, i.e. a body holding no parts
    min_len = len(boundary) + 6
    try: clen = int(req.headers.get("Content-Length", "0"))
    except ValueError: raise HTTPException(400, "Invalid Content-Length header")
    if clen <= min_len: return FormData()
    return await req.form()

In [None]:
#| export
async def _from_body(req, p):
    "Construct `p.annotation` from the request body (form or JSON) merged with the query params"
    anno = p.annotation
    fld_types = _annotations(anno)  # declared fields and their types, if `anno` has any
    data = form2dict(await parse_form(req))
    # Query params override body values with the same name
    if req.query_params: data = {**data, **dict(req.query_params)}
    # With declared fields, unknown keys are dropped; otherwise everything is passed through
    kwargs = {}
    for k,v in data.items():
        if fld_types and k not in fld_types: continue
        kwargs[k] = _form_arg(k, v, fld_types)
    return anno(**kwargs)

In [None]:
# Build an `HttpHeader` from a posted form via `_from_body`
async def f(req):
    def _f(p:HttpHeader): ...
    p = first(_params(_f).values())
    result = await _from_body(req, p)
    return JSONResponse(result.__dict__)

client = TestClient(Starlette(routes=[Route('/', f, methods=['POST'])]))

# `v` is sent twice; casting to the `str` field keeps the last value
d = dict(k='value1',v=['value2','value3'])
response = client.post('/', data=d)
print(response.json())

{'k': 'value1', 'v': 'value3'}


In [None]:
# Starlette keeps repeated query params; `getlist` returns all of them
async def f(req): return Response(str(req.query_params.getlist('x')))
client = TestClient(Starlette(routes=[Route('/', f, methods=['GET'])]))
client.get('/?x=1&x=2').text

"['1', '2']"

In [None]:
#| export
async def _find_p(req, arg:str, p:Parameter):
    "In `req` find param named `arg` of type in `p` (`arg` is ignored for body types)"
    anno = p.annotation
    # If there's an annotation of special types, return object of that type
    # GenericAlias is a type of typing for iterators like list[int] that is not a class
    if isinstance(anno, type) and not isinstance(anno, GenericAlias):
        if issubclass(anno, Request): return req
        if issubclass(anno, HtmxHeaders): return _get_htmx(req.headers)
        if issubclass(anno, Starlette): return req.scope['app']
        if _is_body(anno): return await _from_body(req, p)
    # If there's no annotation, check for special names
    if anno is empty:
        name = arg.lower()
        # Any prefix of 'request' / 'session' works, e.g. `req`, `sess`
        if 'request'.startswith(name): return req
        if 'session'.startswith(name): return req.scope.get('session', {})
        if name=='scope': return dict2obj(req.scope)
        if name=='auth': return req.scope.get('auth', None)
        if name=='htmx': return _get_htmx(req.headers)
        if name=='app': return req.scope['app']
        if name=='body': return (await req.body()).decode()
        if name in ('hdrs','ftrs','bodykw','htmlkw'): return getattr(req, name)
        # `resp` is exempt: after-hooks receive the response in their first param, presumably named `resp`
        if arg!='resp': warn(f"`{arg}` has no type annotation and is not a recognised special name, so is ignored.")
        return None
    # Look through path, cookies, headers, query, and body in that order
    res = req.path_params.get(arg, None)
    if res in (empty,None): res = req.cookies.get(arg, None)
    if res in (empty,None): res = req.headers.get(snake2hyphens(arg), None)
    if res in (empty,None): res = req.query_params.getlist(arg)
    if res==[]: res = None
    if res in (empty,None): res = _formitem(await parse_form(req), arg)
    # Raise 400 error if the param does not include a default
    if (res in (empty,None)) and p.default is empty: raise HTTPException(400, f"Missing required field: {arg}")
    # If we have a default, return that if we have no value
    if res in (empty,None): res = p.default
    # We can cast str and list[str] to types; otherwise just return what we have
    if anno is empty: return res
    try: return _fix_anno(anno, res)
    except ValueError: raise HTTPException(404, req.url.path) from None

async def _wrap_req(req, params):
    "Resolve each parameter in `params` from `req`, in order, returning the list of values"
    res = []
    for arg,p in params.items(): res.append(await _find_p(req, arg, p))
    return res

In [None]:
# `req` is matched by name, `this` by its `Starlette` annotation, `a` comes from the query string,
# and `b` (a dataclass) is built from the posted form
def g(req, this:Starlette, a:str, b:HttpHeader): ...

async def f(req):
    a = await _wrap_req(req, _params(g))
    return Response(str(a))

client = TestClient(Starlette(routes=[Route('/', f, methods=['POST'])]))
response = client.post('/?a=1', data=d)
print(response.text)

[<starlette.requests.Request object>, <starlette.applications.Starlette object>, '1', HttpHeader(k='value1', v='value3')]


In [None]:
# Query values are cast to their annotation, `user_agent` is read from the `User-Agent` header,
# and `c` is missing everywhere so its default is used
def g(a:int, user_agent:str, c:str='dflt'): ...

async def f(req):
    a = await _wrap_req(req, _params(g))
    return Response(str(a))

client = TestClient(Starlette(routes=[Route('/', f, methods=['GET'])]))
response = client.get('/?a=1', headers={'User-Agent': 'tst'})
print(response.text)

[1, 'tst', 'dflt']


In [None]:
#| export
def flat_xt(lst):
    "Flatten lists"
    # A lone FT element or string is one item, not an iterable to flatten
    items = [lst] if isinstance(lst,(FT,str)) else lst
    flat = []
    for item in items:
        if isinstance(item, (list,tuple)): flat += item
        else: flat.append(item)
    return tuple(flat)

In [None]:
# One level of nesting is flattened; a single FT element becomes a 1-tuple
x = ft('a',1)
test_eq(flat_xt([x, x, [x,x]]), (x,)*4)
test_eq(flat_xt(x), (x,))

In [None]:
#| export
class Beforeware:
    "Wraps `f` to run before route handlers, except on request paths that fully match a regex in `skip`"
    def __init__(self, f, skip=None):
        self.f = f
        self.skip = skip or []

In [None]:
#| export
async def _handle(f, args, **kwargs):
    "Call `f(*args, **kwargs)`: awaited directly if async, else run in a threadpool so it doesn't block the event loop"
    if is_async_callable(f): return await f(*args, **kwargs)
    return await run_in_threadpool(f, *args, **kwargs)

## Websockets / SSE

In [None]:
#| export
def _find_wsp(ws, data, hdrs, arg:str, p:Parameter):
    "In `data` find param named `arg` of type in `p` (`arg` is ignored for body types)"
    anno = p.annotation
    # `_wrap_ws` stores headers under keys like `hx_current_url`, but `_get_htmx` looks them up by
    # their hyphenated names (`hx-current-url`), so convert the keys back before building `HtmxHeaders`
    def _htmx(): return _get_htmx({k.replace('_','-'):v for k,v in hdrs.items()})
    # Special annotated types. As in `_find_p`, skip GenericAlias annotations like `list[int]`,
    # which aren't classes and so can't be passed to `issubclass`
    if isinstance(anno, type) and not isinstance(anno, GenericAlias):
        if issubclass(anno, HtmxHeaders): return _htmx()
        if issubclass(anno, Starlette): return ws.scope['app']
        if issubclass(anno, WebSocket): return ws
    # Unannotated params: recognised special names only
    if anno is empty:
        name = arg.lower()
        if name=='ws': return ws
        if name=='scope': return dict2obj(ws.scope)
        if name=='data': return data
        if name=='htmx': return _htmx()
        if name=='app': return ws.scope['app']
        if name=='send': return partial(_send_ws, ws)
        if 'session'.startswith(name): return ws.scope.get('session', {})
        return None
    # Look in the message data, then the headers, then fall back to the default
    res = data.get(arg, None)
    if res is empty or res is None: res = hdrs.get(arg, None)
    if res is empty or res is None: res = p.default
    # Missing, with no default: give the handler None rather than the `Parameter.empty` sentinel
    if res is empty: return None
    # We can cast str and list[str] to types; otherwise just return what we have
    if not isinstance(res, (list,str)) or anno is empty: return res
    # `_fix_anno` already maps a list annotation over the items; casting per item here would nest them ([[1],[2]])
    if get_origin(anno) in (list,List): return _fix_anno(anno, res)
    return [_fix_anno(anno, o) for o in res] if isinstance(res,list) else _fix_anno(anno, res)

def _wrap_ws(ws, data, params):
    hdrs = {k.lower().replace('-','_'):v for k,v in data.pop('HEADERS', {}).items()}
    return [_find_wsp(ws, data, hdrs, arg, p) for arg,p in params.items()]

In [None]:
#| export
async def _send_ws(ws, resp):
    "Send `resp` over `ws` as text, rendering FT content (or lists/tuples of it) to XML first; falsy `resp` is not sent"
    if not resp: return
    is_ft = isinstance(resp, (list,tuple,FT)) or hasattr(resp, '__ft__')
    await ws.send_text(to_xml(resp, indent=fh_cfg.indent) if is_ft else resp)

def _ws_endp(recv, conn=None, disconn=None):
    "Build a `WebSocketEndpoint` subclass that dispatches receive (and optional connect/disconnect) events to FastHTML-style handlers"
    # A fresh class per route, using text-encoded messages
    cls = type('WS_Endp', (WebSocketEndpoint,), {"encoding":"text"})

    # Resolve `handler`'s params from the decoded message (`loads(data)`), call it, and send back any result
    async def _generic_handler(handler, ws, data=None):
        wd = _wrap_ws(ws, loads(data) if data else {}, _params(handler))
        resp = await _handle(handler, wd)
        if resp: await _send_ws(ws, resp)

    # The custom connect hook must accept the socket itself before running `conn`
    async def _connect(self, ws):
        await ws.accept()
        await _generic_handler(conn, ws)

    # NOTE(review): a truthy return from `disconn` is still sent, on a socket that is closing
    async def _disconnect(self, ws, close_code): await _generic_handler(disconn, ws)
    async def _recv(self, ws, data): await _generic_handler(recv, ws, data)

    # Only override the base hooks that were supplied; `on_receive` is always set
    if    conn: cls.on_connect    = _connect
    if disconn: cls.on_disconnect = _disconnect
    cls.on_receive = _recv
    return cls

In [None]:
# The JSON message's `msg` key is passed to the handler's `msg` param
def on_receive(self, msg:str): return f"Message text was: {msg}"
endp = _ws_endp(on_receive)
cli = TestClient(Starlette(routes=[WebSocketRoute('/', endp)]))
with cli.websocket_connect('/') as ws:
    ws.send_text('{"msg":"Hi!"}')
    data = ws.receive_text()
    assert data == 'Message text was: Hi!'

In [None]:
#| export
def EventStream(s):
    "Create a text/event-stream response from `s`"
    return StreamingResponse(content=s, media_type="text/event-stream")

In [None]:
#| export
def signal_shutdown():
    "Return an `asyncio.Event` that is set on SIGINT/SIGTERM, after which the signal is re-sent with its default disposition"
    event = asyncio.Event()
    def _on_signal(signum, frame):
        event.set()
        # Restore the default handler and re-deliver the signal, so its default action still happens
        signal.signal(signum, signal.SIG_DFL)
        os.kill(os.getpid(), signum)
    for sig in (signal.SIGINT, signal.SIGTERM):
        signal.signal(sig, _on_signal)
    return event

## Routing and application

In [None]:
#| export
def uri(_arg, **kwargs):
    "Encode `_arg` and `kwargs` as `quoted-arg/querystring`, the format `decode_uri` parses"
    path = quote(_arg)
    qs = urlencode(kwargs, doseq=True)
    return path + '/' + qs

In [None]:
#| export
def decode_uri(s):
    "Inverse of `uri`: split `s` into its unquoted arg and a dict of query values (first value per key)"
    arg, _, qs = s.partition('/')
    kw = {k: vals[0] for k, vals in parse_qs(qs).items()}
    return unquote(arg), kw

In [None]:
#| export
from starlette.convertors import StringConvertor

In [None]:
#| export
# Allow empty path segments: Starlette's default regex here is `[^/]+`, and `*` also matches ""
StringConvertor.regex = "[^/]*"

@patch
def to_string(self:StringConvertor, value: str) -> str:
    "Render `value` as a path segment: empty strings are allowed (unlike Starlette's default), '/' is not"
    s = str(value)
    assert "/" not in s, "May not contain path separators"
    return s

In [None]:
#| export
@patch
def url_path_for(self:HTTPConnection, name: str, **path_params):
    "URL path of route `name`, prefixed with the ASGI `root_path` (where the app is mounted)"
    lp = self.scope['app'].url_path_for(name, **path_params)
    # `root_path` is optional in an ASGI scope (default ""), so don't assume the key is present
    return URLPath(f"{self.scope.get('root_path', '')}{lp}", lp.protocol, lp.host)

In [None]:
#| export
_verbs = dict(get='hx-get', post='hx-post', put='hx-post', delete='hx-delete', patch='hx-patch', link='href')

def _url_for(req, t):
    if callable(t): t = t.__routename__
    kw = {}
    if t.find('/')>-1 and (t.find('?')<0 or t.find('/')<t.find('?')): t,kw = decode_uri(t)
    t,m,q = t.partition('?')
    return f"{req.url_path_for(t, **kw)}{m}{q}"

def _find_targets(req, resp):
    "Recursively replace shorthand verb attrs (keys of `_verbs`) on FT elements in `resp` with resolved URLs"
    if isinstance(resp, tuple):
        for child in resp: _find_targets(req, child)
    if not isinstance(resp, FT): return
    for child in resp.children: _find_targets(req, child)
    for short,attr in _verbs.items():
        target = resp.attrs.pop(short, None)
        if target: resp.attrs[attr] = _url_for(req, target)

def _apply_ft(o):
    "Recursively replace objects having `__ft__` by their FT form, descending into tuples and FT children"
    if isinstance(o, tuple): o = tuple(map(_apply_ft, o))
    if hasattr(o, '__ft__'): o = o.__ft__()
    if isinstance(o, FT): o.children = tuple(map(_apply_ft, o.children))
    return o

def _to_xml(req, resp, indent):
    "Render `resp` to XML for `req`: expand `__ft__` objects, resolve route targets, then serialise"
    resp = _apply_ft(resp)
    _find_targets(req, resp)
    # Pass `indent` by keyword, as `_send_ws` does. Positionally it binds to `to_xml`'s nesting-level
    # parameter, which gave every page a stray leading indent and made `fh_cfg.indent=False` ineffective
    return to_xml(resp, indent=indent)

In [None]:
#| export
_iter_typs = (tuple,list,map,filter,range,types.GeneratorType)

In [None]:
#| export
def flat_tuple(o):
    "Flatten lists"
    # A non-iterable `o` is wrapped, so the result is always a tuple
    items = list(o) if isinstance(o,_iter_typs) else [o]
    res = []
    for item in items:
        if isinstance(item, _iter_typs): res.extend(item)
        else: res.append(item)
    return tuple(res)

In [None]:
#| export
def noop_body(c, req):
    "Default Body wrap function which just returns the content"
    return c  # `req` is accepted only to match the two-argument `body_wrap(content, req)` form

In [None]:
#| export
def respond(req, heads, bdy):
    "Default FT response creation function"
    wrap = getattr(req, 'body_wrap', noop_body)
    # `body_wrap` may take just the content, or the content and the request
    takes_req = len(inspect.signature(wrap).parameters) > 1
    content = wrap(bdy, req) if takes_req else wrap(bdy)
    body = Body(content, *flat_xt(req.ftrs), **req.bodykw)
    head = Head(*heads, *flat_xt(req.hdrs))
    return Html(head, body, **req.htmlkw)

In [None]:
#| export
def _xt_cts(req, resp):
    "Split FT response `resp` into (rendered XML, HTTP header dict, `BackgroundTasks`), wrapping it in a full page when needed"
    resp = flat_tuple(resp)
    # Items queued on `req.injects` are appended to the response
    resp = resp + tuple(getattr(req, 'injects', ()))
    # Pull out `HttpHeader`s (sent as real HTTP headers) and background tasks; neither is rendered
    http_hdrs,resp = partition(resp, risinstance(HttpHeader))
    http_hdrs = {o.k:str(o.v) for o in http_hdrs}
    tasks,resp = partition(resp, risinstance(BackgroundTask))
    ts = BackgroundTasks()
    for t in tasks: ts.tasks.append(t)
    # Elements with these tags go in <head> when a full page is built
    hdr_tags = 'title','meta','link','style','base'
    heads,bdy = partition(resp, lambda o: getattr(o, 'tag', '') in hdr_tags)
    # Non-htmx requests get a complete page unless the handler already returned an <html> element;
    # htmx requests get the fragment as-is
    if resp and 'hx-request' not in req.headers and not any(getattr(o, 'tag', '')=='html' for o in resp):
        # Fall back to the app's title if the handler supplied no <title>
        title = [] if any(getattr(o, 'tag', '')=='title' for o in heads) else [Title(req.app.title)]
        resp = respond(req, [*heads, *title], bdy)
    return _to_xml(req, resp, indent=fh_cfg.indent), http_hdrs, ts

In [None]:
#| export
def _xt_resp(req, resp, status_code):
    "Build an `HTMLResponse` from FT content `resp`, carrying its HTTP headers and background tasks"
    html, headers, bg = _xt_cts(req, resp)
    return HTMLResponse(html, status_code=status_code, headers=headers, background=bg)

In [None]:
#| export
def _is_ft_resp(resp):
    "True if `resp` is rendered as FT: a supported iterable, an `HttpHeader`, an FT element, or an object with `__ft__`"
    ft_types = _iter_typs + (HttpHeader, FT)
    return isinstance(resp, ft_types) or hasattr(resp, '__ft__')

In [None]:
#| export
def _resp(req, resp, cls=empty, status_code=200):
    "Turn a handler's return value `resp` into a Starlette `Response`"
    if not resp: resp = ()
    # Objects may build their own response, e.g. `Redirect`
    if hasattr(resp, '__response__'): resp = resp.__response__(req)
    # A return annotation of `Any` or `FT` doesn't name a response class
    if cls in (Any,FT): cls = empty
    if isinstance(resp, FileResponse) and not os.path.exists(resp.path): raise HTTPException(404, resp.path)
    if cls is not empty: return cls(resp, status_code=status_code)
    # A ready-made Response keeps whatever status code the handler gave it
    if isinstance(resp, Response): return resp
    if _is_ft_resp(resp): return _xt_resp(req, resp, status_code)
    if isinstance(resp, str): return HTMLResponse(resp, status_code=status_code)
    if isinstance(resp, Mapping): return JSONResponse(resp, status_code=status_code)
    return HTMLResponse(str(resp), status_code=status_code)

In [None]:
#| export
class Redirect:
    "Use HTMX or Starlette RedirectResponse as required to redirect to `loc`"
    def __init__(self, loc): self.loc = loc
    def __response__(self, req):
        # Plain requests get a 303 See Other; htmx requests get an HX-Redirect header instead
        if 'hx-request' not in req.headers: return RedirectResponse(self.loc, status_code=303)
        return HtmxResponseHeaders(redirect=self.loc)

In [None]:
#| export
async def _wrap_call(f, req, params):
    "Resolve `params` from `req`, then call `f` with them via `_handle`"
    args = await _wrap_req(req, params)
    return await _handle(f, args)

In [None]:
#| export
# htmx extensions selectable by name via `FastHTML(exts=...)`; each selected one is added as a <script> header
# (versions are pinned in the URLs)
htmx_exts = {
    "head-support": "https://unpkg.com/htmx-ext-head-support@2.0.3/head-support.js",
    "preload": "https://unpkg.com/htmx-ext-preload@2.1.0/preload.js",
    "class-tools": "https://unpkg.com/htmx-ext-class-tools@2.0.1/class-tools.js",
    "loading-states": "https://unpkg.com/htmx-ext-loading-states@2.0.0/loading-states.js",
    "multi-swap": "https://unpkg.com/htmx-ext-multi-swap@2.0.0/multi-swap.js",
    "path-deps": "https://unpkg.com/htmx-ext-path-deps@2.0.0/path-deps.js",
    "remove-me": "https://unpkg.com/htmx-ext-remove-me@2.0.0/remove-me.js",
    "ws": "https://unpkg.com/htmx-ext-ws@2.0.2/ws.js",
    "chunked-transfer": "https://unpkg.com/htmx-ext-transfer-encoding-chunked@0.4.0/transfer-encoding-chunked.js"
}

In [None]:
#| export
# Header elements used by `def_hdrs`
htmxsrc   = Script(src="https://unpkg.com/htmx.org@2.0.4/dist/htmx.min.js")
fhjsscr   = Script(src="https://cdn.jsdelivr.net/gh/answerdotai/fasthtml-js@1.0.12/fasthtml.js")
# NOTE(review): the next two load from the `main` branch, so the served script can change without notice;
# consider pinning a tag or commit (and adding SRI) as the htmx/fasthtml-js URLs above pin versions
surrsrc   = Script(src="https://cdn.jsdelivr.net/gh/answerdotai/surreal@main/surreal.js")
scopesrc  = Script(src="https://cdn.jsdelivr.net/gh/gnat/css-scope-inline@main/script.js")
viewport  = Meta(name="viewport", content="width=device-width, initial-scale=1, viewport-fit=cover")
charset   = Meta(charset="utf-8")

In [None]:
#| export
def get_key(key=None, fname='.sesskey'):
    "Return `key` if given, else the key stored in `fname`, creating that file with a random UUID4 on first use"
    if key: return key
    path = Path(fname)
    if not path.exists():
        path.write_text(str(uuid.uuid4()))
    return path.read_text()

In [None]:
key = get_key()
test_eq(get_key(), key)              # later calls read the same key back from `.sesskey`
test_eq(get_key('mykey'), 'mykey')   # an explicit key is returned unchanged
len(key)>0                           # don't display the key itself: it signs session cookies

True

In [None]:
#| export
def _list(o): return [] if not o else list(o) if isinstance(o, (tuple,list)) else [o]

In [None]:
#| export
def _wrap_ex(f, status_code, hdrs, ftrs, htmlkw, bodykw, body_wrap):
    "Wrap exception handler `f` so it gets a request set up like a route's, and its result goes through `_resp`"
    async def _f(req, exc):
        # Fresh copies per request, so the handler can modify them freely
        req.hdrs, req.ftrs, req.htmlkw, req.bodykw = (deepcopy(o) for o in (hdrs, ftrs, htmlkw, bodykw))
        req.body_wrap = body_wrap
        result = await _handle(f, (req, exc))
        return _resp(req, result, status_code=status_code)
    return _f

In [None]:
#| export
def qp(p:str, **kw) -> str:
    "Add query parameters to path p; `None` and `False` values become empty (`k=`)"
    # Test identity, not equality: `0 == False`, so `v in (False,None)` also blanked out 0 and 0.0
    kw = {k:('' if v is None or v is False else v) for k,v in kw.items()}
    return p + ('?' + urlencode(kw,doseq=True) if kw else '')

In [None]:
# None/False give empty values; list values repeat the key
qp('/foo', a=None, b=False, c=[1,2], d='bar')

'/foo?a=&b=&c=1&c=2&d=bar'

In [None]:
#| export
def def_hdrs(htmx=True, surreal=True):
    "Default headers for a FastHTML app: charset and viewport, then the htmx and surreal scripts if enabled"
    hdrs = [charset, viewport]
    if htmx: hdrs += [htmxsrc, fhjsscr]
    if surreal: hdrs += [surrsrc, scopesrc]
    return hdrs

In [None]:
#| export
# Fully permissive CORS; `FastHTML.__init__` adds it only when running in a notebook.
# NOTE(review): wildcard origins with `allow_credentials=True` let any site make credentialed requests
# (Starlette reflects the request Origin in that case) — fine for local notebooks only
cors_allow = Middleware(CORSMiddleware, allow_credentials=True,
                        allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])

# Added to headers in notebooks: posts the page height to the parent window on load and after each
# htmx settle / websocket message, presumably so an embedding iframe can resize itself
iframe_scr = Script(NotStr("""
    function sendmsg() {
        window.parent.postMessage({height: document.documentElement.offsetHeight}, '*');
    }
    window.onload = function() {
        sendmsg();
        document.body.addEventListener('htmx:afterSettle',    sendmsg);
        document.body.addEventListener('htmx:wsAfterMessage', sendmsg);
    };"""))

In [None]:
#| export
class FastHTML(Starlette):
    "A `Starlette` app with FastHTML's default headers, sessions, before/after hooks and FT rendering"
    def __init__(self, debug=False, routes=None, middleware=None, title: str = "FastHTML page", exception_handlers=None,
                 on_startup=None, on_shutdown=None, lifespan=None, hdrs=None, ftrs=None, exts=None,
                 before=None, after=None, surreal=True, htmx=True, default_hdrs=True, sess_cls=SessionMiddleware,
                 secret_key=None, session_cookie='session_', max_age=365*24*3600, sess_path='/',
                 same_site='lax', sess_https_only=False, sess_domain=None, key_fname='.sesskey',
                 body_wrap=noop_body, htmlkw=None, nb_hdrs=False, **bodykw):
        # `_list` returns new lists, so appending to these below never touches the caller's objects
        middleware,before,after = map(_list, (middleware,before,after))
        self.title = title
        hdrs,ftrs,exts = map(listify, (hdrs,ftrs,exts))
        # Extension names -> script URLs (an unknown name raises KeyError)
        exts = {k:htmx_exts[k] for k in exts}
        htmlkw = htmlkw or {}
        if default_hdrs: hdrs = def_hdrs(htmx, surreal=surreal) + hdrs
        # `listify` returns a list argument unchanged, so extend `hdrs` by concatenation rather than
        # `+=`/`append`: in-place edits would modify the list the caller passed as `hdrs`
        hdrs = hdrs + [Script(src=ext) for ext in exts.values()]
        if IN_NOTEBOOK:
            # Report page height to the parent window and allow cross-origin requests from the notebook
            hdrs = hdrs + [iframe_scr]
            from IPython.display import display,HTML
            if nb_hdrs: display(HTML(to_xml(tuple(hdrs))))
            middleware.append(cors_allow)
        on_startup,on_shutdown = listify(on_startup) or None,listify(on_shutdown) or None
        self.lifespan,self.hdrs,self.ftrs = lifespan,hdrs,ftrs
        self.body_wrap,self.before,self.after,self.htmlkw,self.bodykw = body_wrap,before,after,htmlkw,bodykw
        # An explicit `secret_key` wins; otherwise it is read from (or created in) `key_fname`
        secret_key = get_key(secret_key, key_fname)
        if sess_cls:
            sess = Middleware(sess_cls, secret_key=secret_key,session_cookie=session_cookie,
                              max_age=max_age, path=sess_path, same_site=same_site,
                              https_only=sess_https_only, domain=sess_domain)
            middleware.append(sess)
        # Copy before adding the default 404 handler, so the caller's dict isn't modified
        exception_handlers = dict(ifnone(exception_handlers, {}))
        if 404 not in exception_handlers:
            def _not_found(req, exc): return  Response('404 Not Found', status_code=404)
            exception_handlers[404] = _not_found
        # Exception handlers get the same headers/footers/body wrapping as normal routes
        excs = {k:_wrap_ex(v, k, hdrs, ftrs, htmlkw, bodykw, body_wrap=body_wrap) for k,v in exception_handlers.items()}
        super().__init__(debug, routes, middleware=middleware, exception_handlers=excs, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan)

    def add_route(self, route):
        "Add `route`, first removing any existing route with the same path, name and set of methods"
        route.methods = [m.upper() for m in listify(route.methods)]
        self.router.routes = [r for r in self.router.routes if not
                       (r.path==route.path and r.name == route.name and
                        ((route.methods is None) or (set(r.methods) == set(route.methods))))]
        self.router.routes.append(route)

In [None]:
#| export
# HTTP verbs that each get an `app.<verb>` route decorator, and that a handler may be named after
all_meths = ['get', 'post', 'put', 'delete', 'patch', 'head', 'trace', 'options']

In [None]:
#| export
@patch
def _endp(self:FastHTML, f, body_wrap):
    "Create the Starlette endpoint for handler `f`: runs before-hooks, `f`, then after-hooks, and builds the response"
    sig = signature_ex(f, True)
    async def _f(req):
        resp = None
        req.injects = []
        # Per-request copies, so handlers and hooks can modify them without affecting the app
        req.hdrs,req.ftrs,req.htmlkw,req.bodykw = map(deepcopy, (self.hdrs,self.ftrs,self.htmlkw,self.bodykw))
        req.hdrs,req.ftrs = listify(req.hdrs),listify(req.ftrs)
        for b in self.before:
            # The first before-hook returning something truthy skips the remaining hooks and the handler
            if not resp:
                if isinstance(b, Beforeware): bf,skip = b.f,b.skip
                else: bf,skip = b,[]
                if not any(re.fullmatch(r, req.url.path) for r in skip):
                    resp = await _wrap_call(bf, req, _params(bf))
        req.body_wrap = body_wrap
        if not resp: resp = await _wrap_call(f, req, sig.parameters)
        for a in self.after:
            # The hook's first param receives the response; the others are resolved from the request
            _,*wreq = await _wrap_req(req, _params(a))
            # Await async after-hooks (before-hooks already support them); otherwise the un-awaited
            # coroutine itself would replace `resp`
            nr = (await a(resp, *wreq)) if is_async_callable(a) else a(resp, *wreq)
            if nr: resp = nr
        return _resp(req, resp, sig.return_annotation)
    return _f

In [None]:
#| export
@patch
def _add_ws(self:FastHTML, func, path, conn, disconn, name, middleware):
    "Register websocket handler `func` (plus optional connect/disconnect handlers) at `path`; returns `func`"
    route = WebSocketRoute(path, endpoint=_ws_endp(func, conn, disconn), name=name, middleware=middleware)
    # Websocket routes have no HTTP methods; tag them so `add_route` can compare them with existing routes
    route.methods = ['ws']
    self.add_route(route)
    return func

In [None]:
#| export
@patch
def ws(self:FastHTML, path:str, conn=None, disconn=None, name=None, middleware=None):
    "Add a websocket route at `path`"
    # Returns a decorator: the decorated function becomes the receive handler
    def f(func=noop):
        return self._add_ws(func, path, conn, disconn, name=name, middleware=middleware)
    return f

In [None]:
#| export
def _mk_locfunc(f,p):
    class _lf:
        def __init__(self): update_wrapper(self, f)
        def __call__(self, *args, **kw): return f(*args, **kw)
        def to(self, **kw): return qp(p, **kw)
        def __str__(self): return p
    return _lf()

In [None]:
#| export
def nested_name(f):
    "Get name of function `f` using '_' to join nested function names"
    return '_'.join(f.__qualname__.split('.<locals>.'))

In [None]:
# A nested function, to show how `nested_name` joins the qualified name
def f():
    def g(): ...
    return g

In [None]:
# `f.<locals>.g` -> 'f_g'
func = f()
nested_name(func)

'f_g'

In [None]:
#| export
@patch
def _add_route(self:FastHTML, func, path, methods, name, include_in_schema, body_wrap):
    "Register `func` as a route; path, methods and name are derived from the function's name when not given"
    fn = nested_name(func)
    p = None if callable(path) else path
    # Methods: explicit > function named after a verb (only with an explicit path) > GET and POST
    if methods: m = [methods] if isinstance(methods,str) else methods
    elif fn in all_meths and p is not None: m = [fn]
    else: m = ['get','post']
    n = name or fn
    if not p: p = '/' + ('' if fn=='index' else fn)
    endp = self._endp(func, body_wrap or self.body_wrap)
    self.add_route(Route(p, endpoint=endp, methods=m, name=n, include_in_schema=include_in_schema))
    lf = _mk_locfunc(func, p)
    lf.__routename__ = n
    return lf

In [None]:
#| export
@patch
def route(self:FastHTML, path:str=None, methods=None, name=None, include_in_schema=True, body_wrap=None):
    "Add a route at `path`"
    def f(func): return self._add_route(func, path, methods, name=name, include_in_schema=include_in_schema, body_wrap=body_wrap)
    # Bare `@app.route` passes the decorated function itself as `path`
    if callable(path): return f(path)
    return f

# `app.get`, `app.post`, ...: `route` with `methods` preset to that verb
for o in all_meths: setattr(FastHTML, o, partialmethod(FastHTML.route, methods=o))

In [None]:
# `@app.get` with no path uses the function name as the path; `.to` builds a URL with query params
app = FastHTML()
@app.get
def foo(a:str, b:list[int]): ...

print(app.routes)
foo.to(a='bar', b=[1,2])

[Route(path='/foo', name='foo', methods=['GET', 'HEAD'])]


'/foo?a=bar&b=1&b=2'

In [None]:
#| export
def serve(
        appname=None, # Name of the module
        app='app', # App instance to be served
        host='0.0.0.0', # 0.0.0.0 listens on all interfaces; the printed link shows localhost instead
        port=None, # If port is None it will default to 5001 or the PORT environment variable
        reload=True, # Default is to reload the app upon code changes
        reload_includes:list[str]|str|None=None, # Additional files to watch for changes
        reload_excludes:list[str]|str|None=None # Files to ignore for changes
        ):
    "Run the app in an async server, with live reload set as the default."
    # Inspect the caller's frame to work out which module to serve
    bk = inspect.currentframe().f_back
    glb = bk.f_globals
    code = bk.f_code
    if not appname:
        # Called from a script run directly: serve that script's module (its file stem)
        if glb.get('__name__')=='__main__': appname = Path(glb.get('__file__', '')).stem
        # Called from a `main()` that was itself called from `__main__`: serve `main`'s module
        elif code.co_name=='main' and bk.f_back.f_globals.get('__name__')=='__main__': appname = inspect.getmodule(bk).__name__
    import uvicorn
    # Without an app name nothing is served, presumably so that importing the module
    # (e.g. from uvicorn's reloader) doesn't start another server
    if appname:
        if not port: port=int(os.getenv("PORT", default=5001))
        print(f'Link: http://{"localhost" if host=="0.0.0.0" else host}:{port}')
        uvicorn.run(f'{appname}:{app}', host=host, port=port, reload=reload, reload_includes=reload_includes, reload_excludes=reload_excludes)

In [None]:
#| export
class Client:
    "A simple httpx ASGI client that doesn't require `async`"
    def __init__(self, app, url="http://testserver"):
        transport = ASGITransport(app)
        self.cli = AsyncClient(transport=transport, base_url=url)

    def _sync(self, method, url, **kwargs):
        "Send a `method` request for `url` to the wrapped app, blocking until the response arrives"
        async def _send(): return await self.cli.request(method, url, **kwargs)
        # Run the coroutine on an event loop in a helper thread, so this works from plain sync code
        with from_thread.start_blocking_portal() as portal:
            response = portal.call(_send)
        return response

# `cli.get(url, ...)`, `cli.post(...)`, etc. are `_sync` with the HTTP method filled in
for o in ('get', 'post', 'delete', 'put', 'patch', 'options'): setattr(Client, o, partialmethod(Client._sync, o))

In [None]:
# A plain Starlette route works too; `Client` calls the app synchronously
app = FastHTML(routes=[Route('/', lambda _: Response('test'))])
cli = Client(app)

cli.get('/').text

'test'

Note that you can also use Starlette's `TestClient` instead of FastHTML's `Client`. They should be largely interchangeable.

## FastHTML Tests

In [None]:
def get_cli(app):
    "Bundle `app` with a Starlette `TestClient` for it and its `route` decorator, for terse test setup"
    client = TestClient(app)
    return app, client, app.route

In [None]:
app,cli,rt = get_cli(FastHTML(secret_key='soopersecret'))

In [None]:
app,cli,rt = get_cli(FastHTML(title="My Custom Title"))
@app.get
def foo(): return Div("Hello World")

print(app.routes)

response = cli.get('/foo')
assert '<title>My Custom Title</title>' in response.text

foo.to(param='value')

[Route(path='/foo', name='foo', methods=['GET', 'HEAD'])]


'/foo?param=value'

In [None]:
app,cli,rt = get_cli(FastHTML())

@rt('/xt2')
def get(): return H1('bar')

txt = cli.get('/xt2').text
assert '<title>FastHTML page</title>' in txt and '<h1>bar</h1>' in txt and '<html>' in txt

In [None]:
@rt("/hi")
def get(): return 'Hi there'

r = cli.get('/hi')
r.text

'Hi there'

In [None]:
@rt("/hi")
def post(): return 'Postal'

cli.post('/hi').text

'Postal'

In [None]:
@app.get("/hostie")
def show_host(req): return req.headers['host']

cli.get('/hostie').text

'testserver'

In [None]:
@app.get("/setsess")
def set_sess(session):
   session['foo'] = 'bar'
   return 'ok'

@app.ws("/ws")
def ws(self, msg:str, ws:WebSocket, session): return f"Message text was: {msg} with session {session.get('foo')}, from client: {ws.client}"

cli.get('/setsess')
with cli.websocket_connect('/ws') as ws:
    ws.send_text('{"msg":"Hi!"}')
    data = ws.receive_text()
assert 'Message text was: Hi! with session bar' in data
print(data)

Message text was: Hi! with session bar, from client: Address(host='testclient', port=50000)


In [None]:
@rt
def yoyo(): return 'a yoyo'

cli.post('/yoyo').text

'a yoyo'

In [None]:
@app.get
def autopost(): return Html(Div('Text.', hx_post=yoyo()))
print(cli.get('/autopost').text)

 <!doctype html>
 <html>
   <div hx-post="a yoyo">Text.</div>
 </html>



In [None]:
@app.get
def autopost2(): return Html(Body(Div('Text.', cls='px-2', hx_post=show_host.to(a='b'))))
print(cli.get('/autopost2').text)

 <!doctype html>
 <html>
   <body>
     <div class="px-2" hx-post="/hostie?a=b">Text.</div>
   </body>
 </html>



In [None]:
@app.get
def autoget2(): return Html(Div('Text.', hx_get=show_host))
print(cli.get('/autoget2').text)

 <!doctype html>
 <html>
   <div hx-get="/hostie">Text.</div>
 </html>



In [None]:
@rt('/user/{nm}', name='gday')
def get(nm:str=''): return f"Good day to you, {nm}!"
cli.get('/user/Alexis').text

'Good day to you, Alexis!'

In [None]:
@app.get
def autolink(): return Html(Div('Text.', link=uri('gday', nm='Alexis')))
print(cli.get('/autolink').text)

 <!doctype html>
 <html>
   <div href="/user/Alexis">Text.</div>
 </html>



In [None]:
@rt('/link')
def get(req): return f"{req.url_for('gday', nm='Alexis')}; {req.url_for('show_host')}"

cli.get('/link').text

'http://testserver/user/Alexis; http://testserver/hostie'

In [None]:
@app.get("/background")
async def background_task(request):
    async def long_running_task():
        await asyncio.sleep(0.1)
        print("Background task completed!")
    return P("Task started"), BackgroundTask(long_running_task)

response = cli.get("/background")

Background task completed!


In [None]:
test_eq(app.router.url_path_for('gday', nm='Jeremy'), '/user/Jeremy')

In [None]:
hxhdr = {'headers':{'hx-request':"1"}}

@rt('/ft')
def get(): return Title('Foo'),H1('bar')

txt = cli.get('/ft').text
assert '<title>Foo</title>' in txt and '<h1>bar</h1>' in txt and '<html>' in txt

@rt('/xt2')
def get(): return H1('bar')

txt = cli.get('/xt2').text
assert '<title>FastHTML page</title>' in txt and '<h1>bar</h1>' in txt and '<html>' in txt

assert cli.get('/xt2', **hxhdr).text.strip() == '<h1>bar</h1>'

@rt('/xt3')
def get(): return Html(Head(Title('hi')), Body(P('there')))

txt = cli.get('/xt3').text
assert '<title>FastHTML page</title>' not in txt and '<title>hi</title>' in txt and '<p>there</p>' in txt

In [None]:
@rt('/oops')
def get(nope): return nope
test_warns(lambda: cli.get('/oops?nope=1'))

In [None]:
def test_r(cli, path, exp, meth='get', hx=False, **kwargs):
    if hx: kwargs['headers'] = {'hx-request':"1"}
    test_eq(getattr(cli, meth)(path, **kwargs).text, exp)

ModelName = str_enum('ModelName', "alexnet", "resnet", "lenet")
fake_db = [{"name": "Foo"}, {"name": "Bar"}]

In [None]:
@rt('/html/{idx}')
async def get(idx:int): return Body(H4(f'Next is {idx+1}.'))

In [None]:
@rt("/models/{nm}")
def get(nm:ModelName): return nm

@rt("/files/{path}")
async def get(path: Path): return path.with_suffix('.txt')

@rt("/items/")
def get(idx:int|None = 0): return fake_db[idx]

@rt("/idxl/")
def get(idx:list[int]): return str(idx)

In [None]:
r = cli.get('/html/1', headers={'hx-request':"1"})
assert '<h4>Next is 2.</h4>' in r.text
test_r(cli, '/models/alexnet', 'alexnet')
test_r(cli, '/files/foo', 'foo.txt')
test_r(cli, '/items/?idx=1', '{"name":"Bar"}')
test_r(cli, '/items/', '{"name":"Foo"}')
assert cli.get('/items/?idx=g').text=='404 Not Found'
assert cli.get('/items/?idx=g').status_code == 404
test_r(cli, '/idxl/?idx=1&idx=2', '[1, 2]')
assert cli.get('/idxl/?idx=1&idx=g').status_code == 404

In [None]:
app = FastHTML()
rt = app.route
cli = TestClient(app)
@app.route(r'/static/{path:path}.jpg')
def index(path:str): return f'got {path}'
cli.get('/static/sub/a.b.jpg').text

'got sub/a.b'

In [None]:
app.chk = 'foo'

In [None]:
@app.get("/booly/")
def _(coming:bool=True): return 'Coming' if coming else 'Not coming'

@app.get("/datie/")
def _(d:parsed_date): return d

@app.get("/ua")
async def _(user_agent:str): return user_agent

@app.get("/hxtest")
def _(htmx): return htmx.request

@app.get("/hxtest2")
def _(foo:HtmxHeaders, req): return foo.request

@app.get("/app")
def _(app): return app.chk

@app.get("/app2")
def _(foo:FastHTML): return foo.chk,HttpHeader("mykey", "myval")

@app.get("/app3")
def _(foo:FastHTML): return HtmxResponseHeaders(location="http://example.org")

@app.get("/app4")
def _(foo:FastHTML): return Redirect("http://example.org")

In [None]:
test_r(cli, '/booly/?coming=true', 'Coming')
test_r(cli, '/booly/?coming=no', 'Not coming')
date_str = "17th of May, 2024, 2p"
test_r(cli, f'/datie/?d={date_str}', '2024-05-17 14:00:00')
test_r(cli, '/ua', 'FastHTML', headers={'User-Agent':'FastHTML'})
test_r(cli, '/hxtest' , '1', headers={'HX-Request':'1'})
test_r(cli, '/hxtest2', '1', headers={'HX-Request':'1'})
test_r(cli, '/app' , 'foo')

In [None]:
r = cli.get('/app2', **hxhdr)
test_eq(r.text, 'foo')
test_eq(r.headers['mykey'], 'myval')

In [None]:
r = cli.get('/app3')
test_eq(r.headers['HX-Location'], 'http://example.org')

In [None]:
r = cli.get('/app4', follow_redirects=False)
test_eq(r.status_code, 303)

In [None]:
r = cli.get('/app4', headers={'HX-Request':'1'})
test_eq(r.headers['HX-Redirect'], 'http://example.org')

In [None]:
@rt
def meta():
    return ((Title('hi'),H1('hi')),
        (Meta(property='image'), Meta(property='site_name'))
    )

t = cli.post('/meta').text
assert re.search(r'<body>\s*<h1>hi</h1>\s*</body>', t)
assert '<meta' in t

In [None]:
@app.post('/profile/me')
def profile_update(username: str): return username

test_r(cli, '/profile/me', 'Alexis', 'post', data={'username' : 'Alexis'})
test_r(cli, '/profile/me', 'Missing required field: username', 'post', data={})

In [None]:
# Example post request with parameter that has a default value
@app.post('/pet/dog')
def pet_dog(dogname: str = None): return dogname

# Working post request with optional parameter
test_r(cli, '/pet/dog', '', 'post', data={})

In [None]:
@dataclass
class Bodie: a:int;b:str

@rt("/bodie/{nm}")
def post(nm:str, data:Bodie):
    res = asdict(data)
    res['nm'] = nm
    return res

@app.post("/bodied/")
def bodied(data:dict): return data

nt = namedtuple('Bodient', ['a','b'])

@app.post("/bodient/")
def bodient(data:nt): return asdict(data)

class BodieTD(TypedDict): a:int;b:str='foo'

@app.post("/bodietd/")
def bodient(data:BodieTD): return data

class Bodie2:
    a:int|None; b:str
    def __init__(self, a, b='foo'): store_attr()

@rt("/bodie2/", methods=['get','post'])
def bodie(d:Bodie2): return f"a: {d.a}; b: {d.b}"

In [None]:
from fasthtml.xtend import Titled

In [None]:
d = dict(a=1, b='foo')

test_r(cli, '/bodie/me', '{"a":1,"b":"foo","nm":"me"}', 'post', data=dict(a=1, b='foo', nm='me'))
test_r(cli, '/bodied/', '{"a":"1","b":"foo"}', 'post', data=d)
test_r(cli, '/bodie2/', 'a: 1; b: foo', 'post', data={'a':1})
test_r(cli, '/bodie2/?a=1&b=foo&nm=me', 'a: 1; b: foo')
test_r(cli, '/bodient/', '{"a":"1","b":"foo"}', 'post', data=d)
test_r(cli, '/bodietd/', '{"a":1,"b":"foo"}', 'post', data=d)

In [None]:
# Testing POST with Content-Type: application/json
@app.post("/")
def index(it: Bodie): return Titled("It worked!", P(f"{it.a}, {it.b}"))

s = json.dumps({"b": "Lorem", "a": 15})
response = cli.post('/', headers={"Content-Type": "application/json"}, data=s).text
assert "<title>It worked!</title>" in response and "<p>15, Lorem</p>" in response

In [None]:
# Testing POST with Content-Type: application/json
@app.post("/bodytext")
def index(body): return body

response = cli.post('/bodytext', headers={"Content-Type": "application/json"}, data=s).text
test_eq(response, '{"b": "Lorem", "a": 15}')

In [None]:
files = [ ('files', ('file1.txt', b'content1')),
         ('files', ('file2.txt', b'content2')) ]

In [None]:
@rt("/uploads")
async def post(files:list[UploadFile]):
    return ','.join([(await file.read()).decode() for file in files])

res = cli.post('/uploads', files=files)
print(res.status_code)
print(res.text)

200
content1,content2


In [None]:
res = cli.post('/uploads', files=[files[0]])
print(res.status_code)
print(res.text)

200
content1


In [None]:
@rt("/setsess")
def get(sess, foo:str=''):
    now = datetime.now()
    sess['auth'] = str(now)
    return f'Set to {now}'

@rt("/getsess")
def get(sess): return f'Session time: {sess["auth"]}'

print(cli.get('/setsess').text)
time.sleep(0.01)

cli.get('/getsess').text

Set to 2025-01-12 14:12:46.576323


'Session time: 2025-01-12 14:12:46.576323'

In [None]:
@rt("/sess-first")
def post(sess, name: str):
    sess["name"] = name
    return str(sess)

cli.post('/sess-first', data={'name': 2})

@rt("/getsess-all")
def get(sess): return sess['name']

test_eq(cli.get('/getsess-all').text, '2')

In [None]:
@rt("/upload")
async def post(uf:UploadFile): return (await uf.read()).decode()

with open('../../CHANGELOG.md', 'rb') as f:
    print(cli.post('/upload', files={'uf':f}, data={'msg':'Hello'}).text[:15])

# Release notes


In [None]:
@rt("/form-submit/{list_id}")
def options(list_id: str):
    headers = {
        'Access-Control-Allow-Origin': '*',
        'Access-Control-Allow-Methods': 'POST',
        'Access-Control-Allow-Headers': '*',
    }
    return Response(status_code=200, headers=headers)

In [None]:
h = cli.options('/form-submit/2').headers
test_eq(h['Access-Control-Allow-Methods'], 'POST')

In [None]:
from fasthtml.authmw import user_pwd_auth

In [None]:
def _not_found(req, exc): return Div('nope')

app,cli,rt = get_cli(FastHTML(exception_handlers={404:_not_found}))

txt = cli.get('/').text
assert '<div>nope</div>' in txt
assert '<!doctype html>' in txt

In [None]:
app,cli,rt = get_cli(FastHTML())

@rt("/{name}/{age}")
def get(name: str, age: int):
    return Titled(f"Hello {name.title()}, age {age}")

assert '<title>Hello Uma, age 5</title>' in cli.get('/uma/5').text
assert '404 Not Found' in cli.get('/uma/five').text

In [None]:
auth = user_pwd_auth(testuser='spycraft')
app,cli,rt = get_cli(FastHTML(middleware=[auth]))

@rt("/locked")
def get(auth): return 'Hello, ' + auth

test_eq(cli.get('/locked').text, 'not authenticated')
test_eq(cli.get('/locked', auth=("testuser","spycraft")).text, 'Hello, testuser')

In [None]:
auth = user_pwd_auth(testuser='spycraft')
app,cli,rt = get_cli(FastHTML(middleware=[auth]))

@rt("/locked")
def get(auth): return 'Hello, ' + auth

test_eq(cli.get('/locked').text, 'not authenticated')
test_eq(cli.get('/locked', auth=("testuser","spycraft")).text, 'Hello, testuser')

## APIRouter

In [None]:
#| export
class RouteFuncs:
    "Attribute-style registry of wrapped route functions (e.g. `rt_funcs.yoyo`); HTTP-method names are refused"
    def __init__(self): super().__setattr__('_funcs', {})
    def __setattr__(self, name, value): self._funcs[name] = value
    def __getattr__(self, name):
        if name in all_meths: raise AttributeError("Route functions with HTTP Names are not accessible here")
        # Read `_funcs` straight from `__dict__`: on an instance created without `__init__`
        # (as `copy`/`pickle` do), `self._funcs` would re-enter `__getattr__` until RecursionError.
        funcs = self.__dict__.get('_funcs', {})
        try: return funcs[name]
        except KeyError: raise AttributeError(f"No route named {name} found in route functions") from None
    def __dir__(self): return list(self._funcs.keys())

In [None]:
#| export
class APIRouter:
    "Add routes to an app"
    def __init__(self, prefix:str|None=None, body_wrap=noop_body):
        "Collect routes and websocket routes under an optional URL `prefix`, to be added to an app later with `to_app`"
        self.routes,self.wss = [],[]
        self.rt_funcs = RouteFuncs()  # Store wrapped route function for discoverability
        self.prefix = prefix if prefix else ""
        self.body_wrap = body_wrap

    def _wrap_func(self, func, path=None):
        "Wrap `func` with its `path` and, unless it is named after an HTTP method, register it in `rt_funcs`"
        name = func.__name__
        wrapped = _mk_locfunc(func, path)
        wrapped.__routename__ = name
        # If you are using the def get or def post method names, this approach is not supported
        if name not in all_meths: setattr(self.rt_funcs, name, wrapped)
        return wrapped

    def __call__(self, path:str=None, methods=None, name=None, include_in_schema=True, body_wrap=None):
        "Add a route at `path`; when used bare (`@ar`) or without a path (`@ar()`), the path comes from the function name"
        def f(func):
            # `index` maps to the prefix root; any other function maps to `/<function name>`
            if path is None or callable(path): sub = "/" + ('' if func.__name__=='index' else func.__name__)
            else: sub = path
            p = self.prefix + sub
            wrapped = self._wrap_func(func, p)
            self.routes.append((func, p, methods, name, include_in_schema, body_wrap or self.body_wrap))
            return wrapped
        return f(path) if callable(path) else f

    def __getattr__(self, name):
        "Expose registered route functions as attributes, e.g. `ar.yoyo`"
        # Read `rt_funcs` via `__dict__` so an instance created without `__init__`
        # (e.g. by `copy`/`pickle`) raises AttributeError instead of recursing.
        rt_funcs = self.__dict__.get('rt_funcs')
        if rt_funcs is not None:
            try: return getattr(rt_funcs, name)
            except AttributeError: pass
        # `object` defines no `__getattr__` to defer to, so raise the standard error directly
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    def to_app(self, app):
        "Add routes to `app`"
        for args in self.routes: app._add_route(*args)
        for args in self.wss: app._add_ws(*args)

    def ws(self, path:str, conn=None, disconn=None, name=None, middleware=None):
        "Add a websocket route at `path`; the decorated function is returned unchanged"
        def f(func=noop):
            self.wss.append((func, f"{self.prefix}{path}", conn, disconn, name, middleware))
            return func
        return f

In [None]:
ar = APIRouter()

In [None]:
@ar("/hi")
def get(): return 'Hi there'
@ar("/hi")
def post(): return 'Postal'
@ar
def ho(): return 'Ho ho'
@ar("/hostie")
def show_host(req): return req.headers['host']
@ar
def yoyo(): return 'a yoyo'
@ar
def index(): return "home page"

@ar.ws("/ws")
def ws(self, msg:str): return f"Message text was: {msg}"

In [None]:
app,cli,_ = get_cli(FastHTML())
ar.to_app(app)

In [None]:
assert str(yoyo) == '/yoyo'
# ensure route functions are properly discoverable on `APIRouter` and `APIRouter.rt_funcs`
assert ar.prefix == ''
assert str(ar.rt_funcs.index) == '/'
assert str(ar.index) == '/'
with ExceptionExpected(): ar.blah()
with ExceptionExpected(): ar.rt_funcs.blah()
# ensure any route functions named using an HTTPMethod are not discoverable via `rt_funcs`
assert "get" not in ar.rt_funcs._funcs.keys()

In [None]:
test_eq(cli.get('/hi').text, 'Hi there')
test_eq(cli.post('/hi').text, 'Postal')
test_eq(cli.get('/hostie').text, 'testserver')
test_eq(cli.post('/yoyo').text, 'a yoyo')

test_eq(cli.get('/ho').text, 'Ho ho')
test_eq(cli.post('/ho').text, 'Ho ho')

In [None]:
with cli.websocket_connect('/ws') as ws:
    ws.send_text('{"msg":"Hi!"}')
    data = ws.receive_text()
    assert data == 'Message text was: Hi!'

In [None]:
ar2 = APIRouter("/products")

In [None]:
@ar2("/hi")
def get(): return 'Hi there'
@ar2("/hi")
def post(): return 'Postal'
@ar2
def ho(): return 'Ho ho'
@ar2("/hostie")
def show_host(req): return req.headers['host']
@ar2
def yoyo(): return 'a yoyo'
@ar2
def index(): return "home page"

@ar2.ws("/ws")
def ws(self, msg:str): return f"Message text was: {msg}"

In [None]:
app,cli,_ = get_cli(FastHTML())
ar2.to_app(app)

In [None]:
assert str(yoyo) == '/products/yoyo'
assert ar2.prefix == '/products'
assert str(ar2.rt_funcs.index) == '/products/'
assert str(ar2.index) == '/products/'
assert str(ar.index) == '/'
with ExceptionExpected(): ar2.blah()
with ExceptionExpected(): ar2.rt_funcs.blah()
assert "get" not in ar2.rt_funcs._funcs.keys()

In [None]:
test_eq(cli.get('/products/hi').text, 'Hi there')
test_eq(cli.post('/products/hi').text, 'Postal')
test_eq(cli.get('/products/hostie').text, 'testserver')
test_eq(cli.post('/products/yoyo').text, 'a yoyo')

test_eq(cli.get('/products/ho').text, 'Ho ho')
test_eq(cli.post('/products/ho').text, 'Ho ho')

In [None]:
with cli.websocket_connect('/products/ws') as ws:
    ws.send_text('{"msg":"Hi!"}')
    data = ws.receive_text()
    assert data == 'Message text was: Hi!'

In [None]:
#| export
# Add one shortcut per HTTP method, so `@ar.get` / `@ar.post("/path")` restrict a route to that method
for o in all_meths:
    setattr(APIRouter, o, partialmethod(APIRouter.__call__, methods=o))

In [None]:
@ar.get
def hi2(): return 'Hi there'
@ar.get("/hi3")
def _(): return 'Hi there'
@ar.post("/post2")
def _(): return 'Postal'

@ar2.get
def hi2(): return 'Hi there'
@ar2.get("/hi3")
def _(): return 'Hi there'
@ar2.post("/post2")
def _(): return 'Postal'

## Extras

In [None]:
app,cli,rt = get_cli(FastHTML(secret_key='soopersecret'))

In [None]:
#| export
def cookie(key: str, value="", max_age=None, expires=None, path="/", domain=None, secure=False, httponly=False, samesite="lax",):
    "Create a 'set-cookie' `HttpHeader`"
    # `expires` may be a `datetime` (aware values are converted to UTC, naive values are taken
    # to already be UTC) or any other value `http.cookies` accepts, such as a preformatted string.
    from datetime import timezone
    jar = cookies.SimpleCookie()
    jar[key] = value
    morsel = jar[key]
    if max_age is not None: morsel["max-age"] = max_age
    if expires is not None:
        if isinstance(expires, datetime):
            # `format_datetime(..., usegmt=True)` only accepts UTC datetimes; RFC 6265 cookie-date
            # parsing reads every date as UTC anyway, so treating naive values as UTC keeps their meaning.
            if expires.utcoffset() is None: expires = expires.replace(tzinfo=timezone.utc)
            else: expires = expires.astimezone(timezone.utc)
            expires = format_datetime(expires, usegmt=True)
        morsel["expires"] = expires
    if path is not None: morsel["path"] = path
    if domain is not None: morsel["domain"] = domain
    if secure: morsel["secure"] = True
    if httponly: morsel["httponly"] = True
    if samesite is not None:
        assert samesite.lower() in [ "strict", "lax", "none", ], "must be 'strict', 'lax' or 'none'"
        morsel["samesite"] = samesite
    cookie_val = jar.output(header="").strip()
    return HttpHeader("set-cookie", cookie_val)

In [None]:
@rt("/setcookie")
def get(req): return cookie('now', datetime.now())

@rt("/getcookie")
def get(now:parsed_date): return f'Cookie was set at time {now.time()}'

print(cli.get('/setcookie').text)
time.sleep(0.01)
cli.get('/getcookie').text




'Cookie was set at time 14:12:47.159530'

In [None]:
#| export
def reg_re_param(m, s):
    "Register a Starlette URL path convertor named `m` whose matching regex is `s`"
    conv = get_class(f'{m}Conv', sup=StringConvertor, regex=s)()
    register_url_convertor(m, conv)

In [None]:
#| export
# Starlette doesn't have the '?', so it chomps the whole remaining URL
reg_re_param("path", ".*?")
reg_re_param("static", "ico|gif|jpg|jpeg|webm|css|js|woff|png|svg|mp4|webp|ttf|otf|eot|woff2|txt|html|map|pdf")

@patch
def static_route_exts(self:FastHTML, prefix='/', static_path='.', exts='static'):
    "Add a static route at URL path `prefix` with files from `static_path` and `exts` defined by `reg_re_param()`"
    @self.route(f"{prefix}{{fname:path}}.{{ext:{exts}}}")
    async def get(fname:str, ext:str):
        # `fname` comes from the request URL: refuse anything (e.g. `../`) that would resolve outside `static_path`
        root = os.path.abspath(static_path)
        fpath = os.path.abspath(f'{root}/{fname}.{ext}')
        if os.path.commonpath([root, fpath]) != root: return Response('404 Not Found', status_code=404)
        return FileResponse(fpath)

In [None]:
reg_re_param("imgext", "ico|gif|jpg|jpeg|webm|pdf")

@rt(r'/static/{path:path}{fn}.{ext:imgext}')
def get(fn:str, path:str, ext:str): return f"Getting {fn}.{ext} from /{path}"

test_r(cli, '/static/foo/jph.me.ico', 'Getting jph.me.ico from /foo/')

In [None]:
app.static_route_exts()
assert 'These are the source notebooks for FastHTML' in cli.get('/README.txt').text

In [None]:
#| export
@patch
def static_route(self:FastHTML, ext='', prefix='/', static_path='.'):
    "Add a static route at URL path `prefix` with files from `static_path` and single `ext` (including the '.')"
    @self.route(f"{prefix}{{fname:path}}{ext}")
    async def get(fname:str):
        # `fname` comes from the request URL: refuse anything (e.g. `../`) that would resolve outside `static_path`
        root = os.path.abspath(static_path)
        fpath = os.path.abspath(f'{root}/{fname}{ext}')
        if os.path.commonpath([root, fpath]) != root: return Response('404 Not Found', status_code=404)
        return FileResponse(fpath)

In [None]:
app.static_route('.md', static_path='../..')
assert 'THIS FILE WAS AUTOGENERATED' in cli.get('/README.md').text

In [None]:
#| export
class MiddlewareBase:
    "Middleware base: HTTP/websocket scopes yield an `HTTPConnection`; any other scope is forwarded to `self._app`"
    async def __call__(self, scope, receive, send) -> None:
        if scope["type"] in ["http", "websocket"]: return HTTPConnection(scope)
        # Other scope types (e.g. "lifespan") pass straight through; `_app` is not set here, so subclasses must provide it
        await self._app(scope, receive, send)

In [None]:
#| export
class FtResponse:
    "Wrap an FT response with any Starlette `Response`"
    def __init__(self, content, status_code:int=200, headers=None, cls=HTMLResponse, media_type:str|None=None):
        self.content = content
        self.status_code = status_code
        self.headers = headers
        self.cls = cls
        self.media_type = media_type

    def __response__(self, req):
        "Render `content` for `req` and build a `self.cls` response; headers produced by `_xt_cts` override `self.headers`"
        body, extra_hdrs, bg_tasks = _xt_cts(req, self.content)
        merged = {**(self.headers or {}), **extra_hdrs}
        return self.cls(body, status_code=self.status_code, headers=merged, media_type=self.media_type, background=bg_tasks)

In [None]:
@rt('/ftr')
def get():
    cts = Title('Foo'),H1('bar')
    return FtResponse(cts, status_code=201, headers={'Location':'/foo/1'})

r = cli.get('/ftr')

test_eq(r.status_code, 201)
test_eq(r.headers['location'], '/foo/1')
txt = r.text
assert '<title>Foo</title>' in txt and '<h1>bar</h1>' in txt and '<html>' in txt

In [None]:
#| export
def unqid():
    "Return a random id: '_' followed by the unpadded base64 of a UUID4, with '+' mapped to '_' and '/' to '-'"
    raw = uuid4().bytes
    text = b64encode(raw).decode().rstrip('=')
    return '_' + text.replace('+', '_').replace('/', '-')

In [None]:
#| export
def _add_ids(s):
    "Walk the `FT` tree rooted at `s` in pre-order, giving every `FT` node without a truthy `id` a fresh `unqid()`"
    pending = [s]
    while pending:
        node = pending.pop()
        if not isinstance(node, FT): continue
        if not getattr(node, 'id', None): node.id = unqid()
        # Push children reversed so they are popped, and assigned ids, in their original order
        pending.extend(reversed(list(node.children)))

In [None]:
#| export
def setup_ws(app, f=noop):
    "Add a '/ws' websocket route to `app` that records each connection's sender; return (and set as `app._send`) an async `send(s)` broadcasting `s` to all of them"
    conns = {}  # keyed by `scope.client`, value is that connection's send callable
    async def on_connect(scope, send): conns[scope.client] = send
    # Use a default with `pop`: a disconnect for a client that was never recorded must not raise KeyError
    async def on_disconnect(scope): conns.pop(scope.client, None)
    app.ws('/ws', conn=on_connect, disconn=on_disconnect)(f)
    async def send(s):
        # Snapshot first: clients can (dis)connect while we await, and resizing a dict mid-iteration raises RuntimeError
        for o in list(conns.values()): await o(s)
    app._send = send
    return send

## Export -

In [None]:
#|hide
import nbdev; nbdev.nbdev_export()