In [None]:
#| default_exp core

# Core

This is the source code to fasthtml. You won't need to read this unless you want to understand how things are built behind the scenes, or need full details of a particular API. The notebook is converted to the Python module [fasthtml/core.py](https://github.com/AnswerDotAI/fasthtml/blob/main/fasthtml/core.py) using [nbdev](https://nbdev.fast.ai/).

## Imports and utils

In [None]:
#| export
import json,uuid,inspect,types

from fastcore.utils import *
from fastcore.xml import *

from types import UnionType, SimpleNamespace as ns, GenericAlias
from typing import Optional, get_type_hints, get_args, get_origin, Union, Mapping, TypedDict, List, Any
from datetime import datetime
from dataclasses import dataclass,fields,is_dataclass,MISSING,asdict
from collections import namedtuple
from inspect import isfunction,ismethod,Parameter,get_annotations
from functools import wraps, partialmethod
from http import cookies
from urllib.parse import urlencode, parse_qs, quote, unquote
from copy import copy,deepcopy
from warnings import warn
from dateutil import parser as dtparse
from starlette.requests import HTTPConnection

from fasthtml.starlette import *

# Sentinel `inspect` uses for "no annotation" / "no default"; compared with `is` throughout this module
empty = Parameter.empty

In [None]:
import time

from IPython import display
from enum import Enum
from pprint import pprint

from fastcore.test import *
from starlette.testclient import TestClient
from starlette.requests import Headers
from starlette.datastructures import UploadFile
from uuid import UUID

We write source code _first_, and then tests come _after_. The tests serve both to confirm that the code works and as working examples. The first public function, `is_typeddict`, is an example of this pattern.

In [None]:
#| export
def _sig(f):
    "Return the signature of `f` via fastcore's `signature_ex` (second arg `True`, presumably `eval_str`, so string annotations are resolved)"
    return signature_ex(f, True)

In [None]:
#| export
def is_typeddict(cls:type)->bool:
    "Check if `cls` is a `TypedDict` class (a type carrying the TypedDict dunder key attributes)"
    if not isinstance(cls, type): return False
    dunders = ('annotations', 'required_keys', 'optional_keys')
    return all(hasattr(cls, f'__{nm}__') for nm in dunders)

In [None]:
# A TypedDict class passes; a plain dict instance does not
class MyDict(TypedDict): name:str

assert is_typeddict(MyDict)
assert not is_typeddict({'a':1})

In [None]:
#| export
def is_namedtuple(cls):
    "`True` if `cls` is a namedtuple type"
    # Check for a class first (as `is_typeddict` does) so non-class values return False instead of making `issubclass` raise
    return isinstance(cls, type) and issubclass(cls, tuple) and hasattr(cls, '_fields')

In [None]:
# A namedtuple type passes; plain `tuple` has no `_fields`
assert is_namedtuple(namedtuple('tst', ['a']))
assert not is_namedtuple(tuple)

In [None]:
#| export
def date(s:str):
    "Convert `s` to a datetime, parsing free-form date/time text with `dateutil`"
    parsed = dtparse.parse(s)
    return parsed

In [None]:
# Missing date parts are presumably filled from today's date, so the output below varies by run date
date('2pm')

datetime.datetime(2024, 8, 13, 14, 0)

In [None]:
#| export
def snake2hyphens(s:str):
    "Convert `s` from snake case to hyphenated and capitalised, e.g. 'snake_case' -> 'Snake-Case'"
    return camel2words(snake2camel(s), '-')

In [None]:
# Header names are looked up in this form, e.g. `user_agent` -> `User-Agent`
snake2hyphens("snake_case")

'Snake-Case'

In [None]:
#| export
htmx_hdrs = dict(
    boosted="HX-Boosted",
    current_url="HX-Current-URL",
    history_restore_request="HX-History-Restore-Request",
    prompt="HX-Prompt",
    request="HX-Request",
    target="HX-Target",
    trigger_name="HX-Trigger-Name",
    trigger="HX-Trigger")

@dataclass
class HtmxHeaders:
    boosted:str|None=None; current_url:str|None=None; history_restore_request:str|None=None; prompt:str|None=None
    request:str|None=None; target:str|None=None; trigger_name:str|None=None; trigger:str|None=None
    def __bool__(self): return any(hasattr(self,o) for o in htmx_hdrs)

def _get_htmx(h):
    res = {k:h.get(v.lower(), None) for k,v in htmx_hdrs.items()}
    return HtmxHeaders(**res)

In [None]:
def test_request(url: str='/', headers: dict|None=None, method: str='get') -> Request:
    "Build a minimal Starlette `Request` for `url` and `method` carrying `headers`, for use in tests"
    scope = {
        'type': 'http',
        'method': method,
        'path': url,
        # `None` instead of a shared mutable `{}` default
        'headers': Headers(headers or {}).raw,
        'query_string': b'',
        'scheme': 'http',
        'client': ('127.0.0.1', 8000),
        'server': ('127.0.0.1', 8000),
    }
    # ASGI `receive` must be awaitable and return an `http.request` message, else reading the body fails
    async def receive(): return {"type": "http.request", "body": b"", "more_body": False}
    return Request(scope, receive)

In [None]:
# Only headers that were sent are populated; the rest stay `None`
h = test_request(headers=Headers({'HX-Request':'1'}))
_get_htmx(h.headers)

HtmxHeaders(boosted=None, current_url=None, history_restore_request=None, prompt=None, request='1', target=None, trigger_name=None, trigger=None)

In [None]:
#| export
def str2int(s)->int:
    "Convert `s` to an `int`; case-insensitively 'on' -> 1, and 'none' or '' -> 0"
    lowered = s.lower()
    special = {'on': 1, 'none': 0, '': 0}
    return special[lowered] if lowered in special else int(lowered)

In [None]:
# 'none' (any case) maps to 0
str2int('1'),str2int('none')

(1, 0)

In [None]:
#| export
def _mk_list(t, v): return [t(o) for o in v]

## Request and response

In [None]:
#| export
fh_cfg = AttrDict({'indent': True})  # module-wide config; `indent` is passed to `to_xml` when rendering responses

In [None]:
#| export
def _fix_anno(t):
    "Create appropriate callable type for casting a `str` to type `t` (or first non-`None` type in `t` if union)"
    origin = get_origin(t)
    if origin is Union or origin is UnionType:
        # Recurse so that e.g. `list[int]|None` still gets per-item list casting
        return _fix_anno(first(o for o in get_args(t) if o!=type(None)))
    if origin in (list,List):
        args = get_args(t)
        # Bare `List` has no item type, so keep items as `str`
        return partial(_mk_list, _fix_anno(args[0]) if args else str)
    # `bool`/`int` need lenient converters for form/query strings like 'on' or ''
    return {bool: str2bool, int: str2int}.get(t, t)

In [None]:
# Unions use their first non-`None` type; `int` gets a lenient converter; `list[T]` casts each item
test_eq(_fix_anno(Union[str,None]), str)
test_eq(_fix_anno(float), float)
test_eq(_fix_anno(int)('1'), 1)
test_eq(_fix_anno(list[int])(['1','2']), [1,2])
test_eq(_fix_anno(list[int])('1'), [1])

In [None]:
#| export
def _form_arg(k, v, d):
    "Get type by accessing key `k` from `d`, and use to cast `v`"
    if v is None: return
    if not isinstance(v, (str,list,tuple)): return v
    # This is the type we want to cast `v` to
    anno = d.get(k, None)
    if not anno: return v
    return _fix_anno(anno)(v)

In [None]:
# `d` maps form keys to their annotated types
d = dict(k=int, l=List[int])
test_eq(_form_arg('k', "1", d), 1)
test_eq(_form_arg('l', "1", d), [1])
test_eq(_form_arg('l', ["1","2"], d), [1,2])

In [None]:
#| export
@dataclass
class HttpHeader:
    "A response header (`k`: name, `v`: value); returned from a handler, it is set on the response rather than rendered"
    k:str
    v:str

In [None]:
#| export
def _annotations(anno):
    "Same as `get_annotations`, but also works on namedtuples (whose fields are all treated as `str`)"
    if not is_namedtuple(anno): return get_annotations(anno)
    return dict.fromkeys(anno._fields, str)

In [None]:
#| export
def _is_body(anno): return issubclass(anno, (dict,ns)) or _annotations(anno)

In [None]:
#| export
def _formitem(form, k):
    "Return single item `k` from `form` if len 1, otherwise return list"
    o = form.getlist(k)
    return o[0] if len(o) == 1 else o if o else None

In [None]:
#| export
def form2dict(form: FormData) -> dict:
    "Convert starlette form data to a dict; repeated keys map to a list of their values"
    return dict((key, _formitem(form, key)) for key in form)

In [None]:
# Repeated keys become lists; single keys stay scalars
d = [('a',1),('a',2),('b',0)]
fd = FormData(d)
res = form2dict(fd)
test_eq(res['a'], [1,2])
test_eq(res['b'], 0)

In [None]:
#| export
async def _from_body(req, p):
    "Build an instance of `p.annotation` from the request body (JSON or form data), casting fields to their annotated types"
    anno = p.annotation
    # Get the fields and types of type `anno`, if available
    d = _annotations(anno)
    # Compare only the media type, so e.g. 'application/json; charset=utf-8' is also treated as JSON
    ctype = req.headers.get('content-type', '').split(';')[0].strip().lower()
    if ctype=='application/json': data = await req.json()
    else: data = form2dict(await req.form())
    # When `anno` declares fields, body keys not among them are dropped
    cargs = {k: _form_arg(k, v, d) for k, v in data.items() if not d or k in d}
    return anno(**cargs)

In [None]:
# Post form data to a route that builds an `HttpHeader` from the body.
# `v` arrives as a list but `HttpHeader.v` is annotated `str`, so it is stringified (see output below).
async def f(req):
    def _f(p:HttpHeader): ...
    p = first(_sig(_f).parameters.values())
    result = await _from_body(req, p)
    return JSONResponse(result.__dict__)

app = Starlette(routes=[Route('/', f, methods=['POST'])])
client = TestClient(app)

d = dict(k='value1',v=['value2','value3'])
response = client.post('/', data=d)
print(response.json())

{'k': 'value1', 'v': "['value2', 'value3']"}


In [None]:
#| export
async def _find_p(req, arg:str, p:Parameter):
    """In `req` find param named `arg` of type in `p` (`arg` is ignored for body types).

    Special annotations (`Request`, `HtmxHeaders`, `Starlette`, body types) and special unannotated names
    are handled first. Otherwise the value is looked up in path params, cookies, headers, query params,
    then form data. Raises `HTTPException` 400 if missing with no default, and 404 if it can't be cast
    to the annotation."""
    anno = p.annotation
    # If there's an annotation of special types, return object of that type
    # GenericAlias is a type of typing for iterators like list[int] that is not a class
    if isinstance(anno, type) and not isinstance(anno, GenericAlias):
        if issubclass(anno, Request): return req
        if issubclass(anno, HtmxHeaders): return _get_htmx(req.headers)
        if issubclass(anno, Starlette): return req.scope['app']
        if _is_body(anno): return await _from_body(req, p)
    # If there's no annotation, check for special names
    if anno is empty:
        # Any prefix of 'request' / 'session' (e.g. `req`, `sess`) is accepted
        if 'request'.startswith(arg.lower()): return req
        if 'session'.startswith(arg.lower()): return req.scope.get('session', {})
        if arg.lower()=='auth': return req.scope.get('auth', None)
        if arg.lower()=='htmx': return _get_htmx(req.headers)
        if arg.lower()=='app': return req.scope['app']
        if arg.lower() in ('hdrs','ftrs','bodykw','htmlkw'): return getattr(req, arg.lower())
        warn(f"`{arg}` has no type annotation and is not a recognised special name, so is ignored.")
        return None
    # Look through path, cookies, headers, query, and body in that order
    res = req.path_params.get(arg, None)
    if res in (empty,None): res = req.cookies.get(arg, None)
    if res in (empty,None): res = req.headers.get(snake2hyphens(arg), None)
    if res in (empty,None): res = req.query_params.get(arg, None)
    if res in (empty,None):
        frm = await req.form()
        res = _formitem(frm, arg)
    # Raise 400 error if the param does not include a default
    if (res in (empty,None)) and p.default is empty: raise HTTPException(400, f"Missing required field: {arg}")
    # If we have a default, return that if we have no value
    if res in (empty,None): res = p.default
    # We can cast str and list[str] to types; otherwise just return what we have
    if not isinstance(res, (list,str)) or anno is empty: return res
    anno = _fix_anno(anno)
    try: return anno(res)
    except ValueError: raise HTTPException(404, req.url.path) from None

async def _wrap_req(req, params):
    return [await _find_p(req, arg, p) for arg,p in params.items()]

In [None]:
# `_wrap_req` resolves each of `g`'s params: `req` by name, `this` by its `Starlette` annotation,
# `a` from the query string, and `b` from the form body (`d` comes from the previous cell)
def g(req, this:Starlette, a:str, b:HttpHeader): ...

async def f(req):
    a = await _wrap_req(req, _sig(g).parameters)
    return Response(str(a))

app = Starlette(routes=[Route('/', f, methods=['POST'])])
client = TestClient(app)

response = client.post('/?a=1', data=d)
print(response.text)

[<starlette.requests.Request object>, <starlette.applications.Starlette object>, '1', HttpHeader(k='value1', v="['value2', 'value3']")]


In [None]:
#| export
def flat_xt(lst):
    "Flatten one level of nested lists/tuples, keeping `FT`s intact; a lone `FT` or `str` becomes a one-item list"
    items = [lst] if isinstance(lst,(FT,str)) else lst
    out = []
    for o in items:
        if isinstance(o, (list,tuple)) and not isinstance(o, FT): out += o
        else: out.append(o)
    return out

In [None]:
# Nested lists are flattened one level; a lone FT is wrapped in a list
x = ft('a',1)
test_eq(flat_xt([x, x, [x,x]]), [x]*4)
test_eq(flat_xt(x), [x])

In [None]:
#| export
class Beforeware:
    "Beforeware function `f`, skipped for request paths that fully match any regex in `skip`"
    def __init__(self, f, skip=None):
        self.f = f
        self.skip = skip or []

In [None]:
#| export
async def _handle(f, args, **kwargs):
    "Call `f(*args, **kwargs)`: awaited directly if async, otherwise run in a threadpool"
    if is_async_callable(f): return await f(*args, **kwargs)
    return await run_in_threadpool(f, *args, **kwargs)

## Websockets

In [None]:
#| export
def _find_wsp(ws, data, hdrs, arg:str, p:Parameter):
    "In websocket message `data` find param named `arg` of type in `p`, falling back to headers `hdrs` and then the default"
    anno = p.annotation
    # Annotated special types: htmx headers, or the app
    if isinstance(anno, type):
        if issubclass(anno, HtmxHeaders): return _get_htmx(hdrs)
        if issubclass(anno, Starlette): return ws.scope['app']
    # Unannotated params are only filled for these special names; anything else gets `None`
    if anno is empty:
        if arg.lower()=='ws': return ws
        if arg.lower()=='data': return data
        if arg.lower()=='htmx': return _get_htmx(hdrs)
        if arg.lower()=='app': return ws.scope['app']
        # `send` lets the handler push extra messages on this websocket
        if arg.lower()=='send': return partial(_send_ws, ws)
        return None
    res = data.get(arg, None)
    # NOTE(review): `snake2hyphens` yields e.g. 'Hx-Request'; if `hdrs` is a plain dict this lookup is case-sensitive — confirm it matches client keys
    if res is empty or res is None: res = hdrs.get(snake2hyphens(arg), None)
    # Unlike `_find_p`, a missing value without a default is not an error: `p.default` (possibly `empty`) is returned
    if res is empty or res is None: res = p.default
    # We can cast str and list[str] to types; otherwise just return what we have
    if not isinstance(res, (list,str)) or anno is empty: return res
    anno = _fix_anno(anno)
    return [anno(o) for o in res] if isinstance(res,list) else anno(res)

def _wrap_ws(ws, data, params):
    hdrs = data.pop('HEADERS', {})
    return [_find_wsp(ws, data, hdrs, arg, p) for arg,p in params.items()]

In [None]:
#| export
async def _send_ws(ws, resp):
    if not resp: return
    res = to_xml(resp, indent=fh_cfg.indent) if isinstance(resp, (list,tuple)) or hasattr(resp, '__ft__') else resp
    await ws.send_text(res)

def _ws_endp(recv, conn=None, disconn=None, hdrs=None, before=None):
    "Create a `WebSocketEndpoint` subclass routing received messages to `recv`, and connect/disconnect to `conn`/`disconn` if given"
    # NOTE(review): `hdrs` and `before` are accepted but not used in this body
    cls = type('WS_Endp', (WebSocketEndpoint,), {"encoding":"text"})
    
    # Resolve `handler`'s params from the decoded message (`loads`, presumably JSON) and send back any non-empty result
    async def _generic_handler(handler, ws, data=None):
        wd = _wrap_ws(ws, loads(data) if data else {}, _sig(handler).parameters)
        resp = await _handle(handler, wd)
        if resp: await _send_ws(ws, resp)

    # Accept the connection before running the user's connect handler
    async def _connect(self, ws):
        await ws.accept()
        await _generic_handler(conn, ws)

    async def _disconnect(self, ws, close_code): await _generic_handler(disconn, ws)
    async def _recv(self, ws, data): await _generic_handler(recv, ws, data)

    # Only override the base class's connect/disconnect hooks when handlers were supplied
    if    conn: cls.on_connect    = _connect
    if disconn: cls.on_disconnect = _disconnect
    cls.on_receive = _recv
    return cls

In [None]:
def on_receive(self, msg:str): return f"Message text was: {msg}"
# Round-trip a JSON message through a websocket route built by `_ws_endp` (unused `c = _ws_endp(...)` removed)
app = Starlette(routes=[WebSocketRoute('/', _ws_endp(on_receive))])

cli = TestClient(app)
with cli.websocket_connect('/') as ws:
    ws.send_text('{"msg":"Hi!"}')
    data = ws.receive_text()
    assert data == 'Message text was: Hi!'

## Routing and application

In [None]:
#| export
class WS_RouteX(WebSocketRoute):
    "A `WebSocketRoute` whose endpoint is built by `_ws_endp` from the `recv`/`conn`/`disconn` handlers"
    def __init__(self, path:str, recv, conn:callable=None, disconn:callable=None, *,
                 name=None, middleware=None, hdrs=None, before=None):
        endpoint = _ws_endp(recv, conn, disconn, hdrs, before)
        super().__init__(path, endpoint, name=name, middleware=middleware)

In [None]:
#| export
def uri(_arg, **kwargs):
    "Encode route name `_arg` and path params `kwargs` as 'name/query', for `decode_uri` to split again"
    query = urlencode(kwargs, doseq=True)
    return quote(_arg) + '/' + query

In [None]:
#| export
def decode_uri(s):
    "Inverse of `uri`: split `s` into (route name, dict of params), keeping the first value of repeated params"
    name, _, query = s.partition('/')
    params = {k: vs[0] for k, vs in parse_qs(query).items()}
    return unquote(name), params

In [None]:
#| export
from starlette.convertors import StringConvertor

In [None]:
#| export
StringConvertor.regex = "[^/]*"  # `+` replaced with `*`, so a `str` path param may be empty

@patch
def to_string(self:StringConvertor, value: str) -> str:
    "Convert `value` to a path segment; the non-empty assertion is deliberately omitted so empty params work"
    s = str(value)
    assert "/" not in s, "May not contain path separators"
    return s

In [None]:
#| export
@patch
def url_path_for(self:HTTPConnection, name: str, **path_params):
    "Return the URL path (not the full URL) for route `name`, using the router in this connection's scope"
    return self.scope["router"].url_path_for(name, **path_params)

In [None]:
#| export
_verbs = dict(get='hx-get', post='hx-post', put='hx-post', delete='hx-delete', patch='hx-patch', link='href')

def _url_for(req, t):
    if callable(t): t = t.__routename__
    kw = {}
    if t.find('/')>-1 and (t.find('?')<0 or t.find('/')<t.find('?')): t,kw = decode_uri(t)
    t,m,q = t.partition('?')    
    return f"{req.url_path_for(t, **kw)}{m}{q}"

def _find_targets(req, resp):
    "Recursively replace shorthand verb attrs (see `_verbs`) on FT components in `resp` with attrs holding resolved URL paths"
    if isinstance(resp, tuple):
        for child in resp: _find_targets(req, child)
    if not isinstance(resp, FT): return
    for child in resp.children: _find_targets(req, child)
    for short, attr in _verbs.items():
        target = resp.attrs.pop(short, None)
        if target: resp.attrs[attr] = _url_for(req, target)

def _to_xml(req, resp, indent):
    "Resolve route targets in `resp` in place, then render it to an HTML string"
    _find_targets(req, resp)  # mutates `resp`'s attrs; returns nothing
    html = to_xml(resp, indent)
    return html

In [None]:
#| export
def _xt_resp(req, resp):
    "Render FT/tuple `resp` as an `HTMLResponse`, wrapped in a full page unless it's an htmx request or already contains an `html` tag"
    if not isinstance(resp, tuple): resp = (resp,)
    # Append any components queued on the request (`RouteX` sets `req.injects = []`)
    resp = resp + tuple(getattr(req, 'injects', ()))
    # `HttpHeader` items become response headers rather than content
    http_hdrs,resp = partition(resp, risinstance(HttpHeader))
    http_hdrs = {o.k:str(o.v) for o in http_hdrs}
    # Head-only tags go into <head> when a page is built
    hdr_tags = 'title','meta','link','style','base'
    titles,bdy = partition(resp, lambda o: getattr(o, 'tag', '') in hdr_tags)
    # Only build a full page for non-htmx requests with some content and no existing <html>
    if resp and 'hx-request' not in req.headers and not any(getattr(o, 'tag', '')=='html' for o in resp):
        if not titles: titles = [Title('FastHTML page')]
        resp = Html(Head(*titles, *flat_xt(req.hdrs)), Body(bdy, *flat_xt(req.ftrs), **req.bodykw), **req.htmlkw)
    return HTMLResponse(_to_xml(req, resp, indent=fh_cfg.indent), headers=http_hdrs)

In [None]:
#| export
def _resp(req, resp, cls=empty):
    "Convert handler result `resp` into a Starlette `Response`, using return annotation `cls` when it's a concrete response type"
    if not resp: resp = ()
    # `Any`/`FT` return annotations carry no response class
    if cls in (Any,FT): cls = empty
    if isinstance(resp, FileResponse) and not os.path.exists(resp.path): raise HTTPException(404, resp.path)
    if isinstance(resp, Response): return resp
    if cls is not empty: return cls(resp)
    if isinstance(resp, (list,tuple,HttpHeader)) or hasattr(resp, '__ft__'): return _xt_resp(req, resp)
    if isinstance(resp, str): return HTMLResponse(resp)
    if isinstance(resp, Mapping): return JSONResponse(resp)
    return HTMLResponse(str(resp))

In [None]:
#| export
async def _wrap_call(f, req, params):
    "Resolve `params` from `req`, then call `f` with the resolved values"
    return await _handle(f, await _wrap_req(req, params))

In [None]:
#| export
class RouteX(Route):
    "A Starlette `Route` that injects handler params from the request, runs before/afterware, and converts results with `_resp`"
    def __init__(self, path:str, endpoint, *, methods=None, name=None, include_in_schema=True, middleware=None,
                hdrs=None, ftrs=None, before=None, after=None, htmlkw=None, **bodykw):
        self.sig = _sig(endpoint)
        # Normalise optional collections so `_endp` can iterate/copy them even when `RouteX` is built directly
        hdrs,ftrs = ([] if hdrs is None else hdrs),([] if ftrs is None else ftrs)
        before,after,htmlkw = before or [],after or [],htmlkw or {}
        self.f,self.hdrs,self.ftrs,self.before,self.after,self.htmlkw,self.bodykw = endpoint,hdrs,ftrs,before,after,htmlkw,bodykw
        super().__init__(path, self._endp, methods=methods, name=name, include_in_schema=include_in_schema, middleware=middleware)

    async def _endp(self, req):
        "Run beforeware (stopping at the first truthy result), then the handler, then each afterware, and build the response"
        resp = None
        req.injects = []
        # Per-request copies, so handlers can modify headers/footers/kwargs without affecting other requests
        req.hdrs,req.ftrs,req.htmlkw,req.bodykw = map(deepcopy, (self.hdrs,self.ftrs,self.htmlkw,self.bodykw))
        req.hdrs,req.ftrs = list(req.hdrs),list(req.ftrs)
        for b in self.before:
            if not resp:
                if isinstance(b, Beforeware): bf,skip = b.f,b.skip
                else: bf,skip = b,[]
                if not any(re.fullmatch(r, req.url.path) for r in skip):
                    resp = await _wrap_call(bf, req, _sig(bf).parameters)
        if not resp: resp = await _wrap_call(self.f, req, self.sig.parameters)
        for a in self.after:
            # The first parameter receives `resp`, so only the remaining ones are looked up in the request
            params = dict(list(_sig(a).parameters.items())[1:])
            wreq = await _wrap_req(req, params)
            nr = a(resp, *wreq)
            # Support async afterware too
            if inspect.isawaitable(nr): nr = await nr
            if nr: resp = nr
        return _resp(req, resp, self.sig.return_annotation)

In [None]:
#| export
class RouterX(Router):
    "A Starlette `Router` whose routes are `RouteX`/`WS_RouteX`, sharing app-wide headers, footers, before/afterware and html/body kwargs"
    def __init__(self, routes=None, redirect_slashes=True, default=None, on_startup=None, on_shutdown=None,
                 lifespan=None, *, middleware=None, hdrs=None, ftrs=None, before=None, after=None, htmlkw=None, **bodykw):
        super().__init__(routes, redirect_slashes, default, on_startup, on_shutdown,
                         lifespan=lifespan, middleware=middleware)
        self.hdrs, self.ftrs, self.bodykw = hdrs, ftrs, bodykw
        self.htmlkw = htmlkw or {}
        self.before, self.after = before, after

    def add_route(self, path: str, endpoint: callable, methods=None, name=None, include_in_schema=True):
        "Add an HTTP route at `path` handled by `endpoint`"
        self.routes.append(RouteX(path, endpoint=endpoint, methods=methods, name=name, include_in_schema=include_in_schema,
                                  hdrs=self.hdrs, ftrs=self.ftrs, before=self.before, after=self.after,
                                  htmlkw=self.htmlkw, **self.bodykw))

    def add_ws(self, path: str, recv: callable, conn:callable=None, disconn:callable=None, name=None):
        "Add a websocket route at `path` with receive/connect/disconnect handlers"
        self.routes.append(WS_RouteX(path, recv=recv, conn=conn, disconn=disconn, name=name,
                                     hdrs=self.hdrs, before=self.before))

In [None]:
#| export
# Default page headers: htmx, its websocket extension, surreal, css-scope-inline, plus viewport and charset meta tags
# NOTE(review): these CDN URLs use unpinned tags (`@next`, `@main`, no version), so pages can change when upstream updates — consider pinning
htmxscr   = Script(src="https://unpkg.com/htmx.org@next/dist/htmx.min.js")
htmxwsscr = Script(src="https://unpkg.com/htmx-ext-ws/ws.js")
surrsrc   = Script(src="https://cdn.jsdelivr.net/gh/answerdotai/surreal@main/surreal.js")
scopesrc  = Script(src="https://cdn.jsdelivr.net/gh/gnat/css-scope-inline@main/script.js")
viewport  = Meta(name="viewport", content="width=device-width, initial-scale=1, viewport-fit=cover")
charset   = Meta(charset="utf-8")

In [None]:
#| export
def get_key(key=None, fname='.sesskey'):
    "Return `key` if given; otherwise read it from file `fname`, first creating that file with a random UUID if it doesn't exist"
    if key: return key
    path = Path(fname)
    if not path.exists():
        new_key = str(uuid.uuid4())
        path.write_text(new_key)
        return new_key
    return path.read_text()

In [None]:
# Creates or reads `.sesskey` in the working directory.
# NOTE(review): this displays the session-signing secret; clear the output before committing
get_key()

'a604e4a2-08e8-462d-aff9-15468891fe09'

In [None]:
#| export
def _list(o): return [] if not o else list(o) if isinstance(o, (tuple,list)) else [o]

In [None]:
#| export
def _wrap_ex(f, hdrs, ftrs, htmlkw, bodykw):
    "Wrap exception handler `f(req, exc)` so its result is rendered like a route result, with the app's headers/footers/kwargs"
    async def _handler(req, exc):
        req.hdrs, req.ftrs, req.htmlkw, req.bodykw = (deepcopy(o) for o in (hdrs, ftrs, htmlkw, bodykw))
        return _resp(req, await _handle(f, (req, exc)))
    return _handler

In [None]:
#| export
class _SessionMiddleware(SessionMiddleware):
    "Same as Starlette's `SessionMiddleware`, but wraps `session` in an AttrDict"
    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        # Scopes other than http/websocket (e.g. lifespan) pass straight through
        if scope["type"] not in ("http", "websocket"):
            await self.app(scope, receive, send)
            return

        # The AttrDict conversion happens lazily, the first time the app calls `receive`.
        # NOTE(review): a request whose handler never calls `receive` presumably still sees a plain dict session — confirm
        async def receive_wrapper():
            message = await receive()
            if "session" in scope and not isinstance(scope["session"], AttrDict):
                scope["session"] = AttrDict(scope["session"])
            return message

        await super().__call__(scope, receive_wrapper, send)

In [None]:
#| export
class FastHTML(Starlette):
    "A `Starlette` app using FastHTML routing (`RouterX`), signed-cookie sessions, and default page headers"
    def __init__(self, debug=False, routes=None, middleware=None, exception_handlers=None,
                 on_startup=None, on_shutdown=None, lifespan=None, hdrs=None, ftrs=None,
                 before=None, after=None, ws_hdr=False,
                 surreal=True, htmx=True, default_hdrs=True,
                 secret_key=None, session_cookie='session_', max_age=365*24*3600, sess_path='/',
                 same_site='lax', sess_https_only=False, sess_domain=None, key_fname='.sesskey',
                 htmlkw=None, **bodykw):
        middleware,before,after = map(_list, (middleware,before,after))
        # Without an explicit `secret_key`, one is read from (or generated into) `key_fname`
        secret_key = get_key(secret_key, key_fname)
        sess = Middleware(_SessionMiddleware, secret_key=secret_key,session_cookie=session_cookie,
                          max_age=max_age, path=sess_path, same_site=same_site,
                          https_only=sess_https_only, domain=sess_domain)
        middleware.append(sess)
        hdrs,ftrs = listify(hdrs),listify(ftrs)
        htmlkw = htmlkw or {}
        if default_hdrs:
            # Each step prepends, giving: charset, viewport, htmx, htmx-ws, surreal, css-scope-inline, then user hdrs
            if surreal: hdrs = [surrsrc,scopesrc] + hdrs
            if ws_hdr: hdrs = [htmxwsscr] + hdrs
            if htmx: hdrs = [htmxscr] + hdrs
            hdrs = [charset, viewport] + hdrs
        excs = {k:_wrap_ex(v, hdrs, ftrs, htmlkw, bodykw) for k,v in (exception_handlers or {}).items()}
        super().__init__(debug, routes, middleware, excs, on_startup, on_shutdown, lifespan=lifespan)
        self.router = RouterX(routes, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan,
                              hdrs=hdrs, ftrs=ftrs, before=before, after=after, htmlkw=htmlkw, **bodykw)

    def route(self, path:str=None, methods=None, name=None, include_in_schema=True):
        "Add a route at `path`; the function name is the default method"
        pathstr = None if callable(path) else path
        def f(func):
            n,fn,p = name,func.__name__,pathstr
            # A handler named after an HTTP method uses it; `index` is GET; anything else is POST.
            # (`_verbs` was used here before, which treated `link` as a method and sent `head`/`options`/`trace` to POST)
            if methods: m = [methods] if isinstance(methods,str) else methods
            else: m = [fn] if fn in all_meths else ['get'] if fn=='index' else ['post']
            if not n: n = fn
            if not p: p = '/'+('' if fn=='index' else fn)
            self.router.add_route(p, func, methods=m, name=n, include_in_schema=include_in_schema)
            # Lets FT verb attrs (e.g. `post=func`) resolve this route by name
            func.__routename__ = n
            return func
        return f(path) if callable(path) else f

    def ws(self, path:str, conn=None, disconn=None, name=None):
        "Decorator adding a websocket route at `path` whose receive handler is the decorated function"
        def f(func):
            self.router.add_ws(path, func, conn=conn, disconn=disconn, name=name)
            return func
        return f

all_meths = ['get', 'post', 'put', 'delete', 'patch', 'head', 'trace', 'options']
# Add `app.get`, `app.post`, ... shortcuts: `route` with `methods` preset
for o in all_meths: setattr(FastHTML, o, partialmethod(FastHTML.route, methods=o))

## Extras

In [None]:
#| export
def cookie(key: str, value="", max_age=None, expires=None, path="/", domain=None, secure=False, httponly=False, samesite="lax",):
    "Create a 'set-cookie' `HttpHeader`; a `datetime` `expires` is formatted as an HTTP date in GMT"
    # Imported locally so the `datetime` branch doesn't rely on a star-import providing these names
    from email.utils import format_datetime
    from datetime import timezone
    cookie = cookies.SimpleCookie()
    cookie[key] = value
    if max_age is not None: cookie[key]["max-age"] = max_age
    if expires is not None:
        # `format_datetime(..., usegmt=True)` rejects anything but UTC-aware datetimes, so convert first
        # (`astimezone` treats a naive `expires` as local time)
        if isinstance(expires, datetime): expires = format_datetime(expires.astimezone(timezone.utc), usegmt=True)
        cookie[key]["expires"] = expires
    if path is not None: cookie[key]["path"] = path
    if domain is not None: cookie[key]["domain"] = domain
    if secure: cookie[key]["secure"] = True
    if httponly: cookie[key]["httponly"] = True
    if samesite is not None:
        assert samesite.lower() in [ "strict", "lax", "none", ], "must be 'strict', 'lax' or 'none'"
        cookie[key]["samesite"] = samesite
    cookie_val = cookie.output(header="").strip()
    return HttpHeader("set-cookie", cookie_val)

In [None]:
#| export
def reg_re_param(m, s):
    "Register URL convertor `m` (a `StringConvertor` subclass matching regex `s`), usable in routes as `{param:m}`"
    register_url_convertor(m, get_class(f'{m}Conv', sup=StringConvertor, regex=s)())

In [None]:
#| export
# Starlette doesn't have the '?', so it chomps the whole remaining URL
# Non-greedy, so a pattern after `{x:path}` (e.g. `{fn}.{ext:imgext}`) can still match
reg_re_param("path", ".*?")
# Common static-file extensions, e.g. for a `{fname:path}.{ext:static}` route
reg_re_param("static", "ico|gif|jpg|jpeg|webm|css|js|woff|png|svg|mp4|webp|ttf|otf|eot|woff2|txt|html")

In [None]:
#| export
class MiddlewareBase:
    "ASGI middleware helper: other scope types go straight to `self.app`; http/websocket scopes return an `HTTPConnection`"
    # `self.app` is not set here; subclasses must provide it
    async def __call__(self, scope, receive, send) -> None:
        if scope["type"] in ["http", "websocket"]: return HTTPConnection(scope)
        await self.app(scope, receive, send)

## Tests

In [None]:
def get_cli(app):
    "Return `app`, a `TestClient` for it, and its `route` decorator"
    client = TestClient(app)
    return app, client, app.route

In [None]:
# A fixed `secret_key` means no `.sesskey` file is read or written for this app
app,cli,rt = get_cli(FastHTML(secret_key='soopersecret'))

In [None]:
# The handler's name (`get`) sets the HTTP method
@rt("/hi")
def get(): return 'Hi there'

r = cli.get('/hi')
r.text

'Hi there'

In [None]:
# A second handler on the same path, for POST
@rt("/hi")
def post(): return 'Postal'

cli.post('/hi').text

'Postal'

In [None]:
# `app.get` sets the method explicitly, so the function name is free; `req` is injected by name
@app.get("/hostie")
def show_host(req): return req.headers['host']

cli.get('/hostie').text

'testserver'

In [None]:
# With no path, the function name gives the path (`/yoyo`) and route name; non-verb names default to POST
@rt
def yoyo(): return 'a yoyo'

cli.post('/yoyo').text

'a yoyo'

In [None]:
# `post=yoyo` becomes `hx-post` pointing at the `yoyo` route's path
@app.get
def autopost(): return Html(Div('Text.', post=yoyo))
print(cli.get('/autopost').text)

<!doctype html>

<html>
  <div hx-post="/yoyo">Text.</div>
</html>



In [None]:
# A string target is a route name, optionally followed by a query string
@app.get
def autopost2(): return Html(Body(Div('Text.', cls='px-2', post='show_host?a=b')))
print(cli.get('/autopost2').text)

<!doctype html>

<html>
  <body>
    <div class="px-2" hx-post="/hostie?a=b">Text.</div>
  </body>
</html>



In [None]:
# `get=` targets become `hx-get`
@app.get
def autoget2(): return Html(Div('Text.', get=show_host))
print(cli.get('/autoget2').text)

<!doctype html>

<html>
  <div hx-get="/hostie">Text.</div>
</html>



In [None]:
# `name=` sets the route name used by `uri`/`url_path_for`; `{nm}` is a path param
@rt('/user/{nm}', name='gday')
def get(nm:str=''): return f"Good day to you, {nm}!"
cli.get('/user/Alexis').text

'Good day to you, Alexis!'

In [None]:
# `uri` encodes a route name plus path params; `link=` becomes `href`
@app.get
def autolink(): return Html(Div('Text.', link=uri('gday', nm='Alexis')))
print(cli.get('/autolink').text)

<!doctype html>

<html>
  <div href="/user/Alexis">Text.</div>
</html>



In [None]:
# Starlette's `url_for` returns the full URL (scheme and host), unlike `url_path_for`
@rt('/link')
def get(req): return f"{req.url_for('gday', nm='Alexis')}; {req.url_for('show_host')}"

cli.get('/link').text

'http://testserver/user/Alexis; http://testserver/hostie'

In [None]:
# The router can resolve route paths directly
test_eq(app.router.url_path_for('gday', nm='Jeremy'), '/user/Jeremy')

In [None]:
# Head tags (e.g. Title) go in <head> and the rest in <body> of an auto-built page; htmx requests get just
# the fragment; a response that is already an `Html` is not wrapped again
hxhdr = {'headers':{'hx-request':"1"}}

@rt('/ft')
def get(): return Title('Foo'),H1('bar')

txt = cli.get('/ft').text
assert '<title>Foo</title>' in txt and '<h1>bar</h1>' in txt and '<html>' in txt

@rt('/xt2')
def get(): return H1('bar')

txt = cli.get('/xt2').text
assert '<title>FastHTML page</title>' in txt and '<h1>bar</h1>' in txt and '<html>' in txt

assert cli.get('/xt2', **hxhdr).text.strip() == '<h1>bar</h1>'

@rt('/xt3')
def get(): return Html(Head(Title('hi')), Body(P('there')))

txt = cli.get('/xt3').text
assert '<title>FastHTML page</title>' not in txt and '<title>hi</title>' in txt and '<p>there</p>' in txt

In [None]:
# An unannotated, non-special parameter triggers a warning (and is passed as `None`)
@rt('/oops')
def get(nope): return nope
test_warns(lambda: cli.get('/oops?nope=1'))

In [None]:
def test_r(cli, path, exp, meth='get', hx=False, **kwargs):
    "Request `path` from `cli` with HTTP method `meth` and check the response text equals `exp`"
    # `hx=True` sends an htmx request header (replacing any `headers` passed in `kwargs`)
    if hx: kwargs['headers'] = {'hx-request':"1"}
    test_eq(getattr(cli, meth)(path, **kwargs).text, exp)

# Arbitrary attribute, read back through the injected app in later tests
app.chk = 'foo'
ModelName = str_enum('ModelName', "alexnet", "resnet", "lenet")
fake_db = [{"name": "Foo"}, {"name": "Bar"}]

In [None]:
# Path and query params (including custom regex convertors) are cast using the handler's annotations
@rt('/html/{idx}')
async def get(idx:int): return Body(H4(f'Next is {idx+1}.'))

reg_re_param("imgext", "ico|gif|jpg|jpeg|webm")

@rt(r'/static/{path:path}{fn}.{ext:imgext}')
def get(fn:str, path:str, ext:str): return f"Getting {fn}.{ext} from /{path}"

@rt("/models/{nm}")
def get(nm:ModelName): return nm

@rt("/files/{path}")
async def get(path: Path): return path.with_suffix('.txt')

@rt("/items/")
def get(idx:int|None = 0): return fake_db[idx]

In [None]:
# A value that can't be cast to the annotated type (`idx=g`) gives a 404
test_r(cli, '/html/1', '<body>\n  <h4>Next is 2.</h4>\n</body>\n', hx=True)
test_r(cli, '/static/foo/jph.ico', 'Getting jph.ico from /foo/')
test_r(cli, '/models/alexnet', 'alexnet')
test_r(cli, '/files/foo', 'foo.txt')
test_r(cli, '/items/?idx=1', '{"name":"Bar"}')
test_r(cli, '/items/', '{"name":"Foo"}')
assert cli.get('/items/?idx=g').status_code==404

In [None]:
# Params can come from the query, headers (`user_agent` -> `User-Agent`), htmx headers, or be the app itself
@app.get("/booly/")
def _(coming:bool=True): return 'Coming' if coming else 'Not coming'

@app.get("/datie/")
def _(d:date): return d

@app.get("/ua")
async def _(user_agent:str): return user_agent

@app.get("/hxtest")
def _(htmx): return htmx.request

@app.get("/hxtest2")
def _(foo:HtmxHeaders, req): return foo.request

@app.get("/app")
def _(app): return app.chk

@app.get("/app2")
def _(foo:FastHTML): return foo.chk,HttpHeader("mykey", "myval")

In [None]:
# `bool` accepts strings like 'true'/'no'; `date` parses free-form dates
test_r(cli, '/booly/?coming=true', 'Coming')
test_r(cli, '/booly/?coming=no', 'Not coming')
date_str = "17th of May, 2024, 2p"
test_r(cli, f'/datie/?d={date_str}', '2024-05-17 14:00:00')
test_r(cli, '/ua', 'FastHTML', headers={'User-Agent':'FastHTML'})
test_r(cli, '/hxtest' , '1', headers={'HX-Request':'1'})
test_r(cli, '/hxtest2', '1', headers={'HX-Request':'1'})
test_r(cli, '/app' , 'foo')

In [None]:
# An `HttpHeader` returned alongside content becomes a response header
r = cli.get('/app2', **hxhdr)
test_eq(r.text, 'foo\n')
test_eq(r.headers['mykey'], 'myval')

In [None]:
# A missing required form field gives a 400 with this message
@app.post('/profile/me')
def profile_update(username: str): return username

test_r(cli, '/profile/me', 'Alexis', 'post', data={'username' : 'Alexis'})
test_r(cli, '/profile/me', 'Missing required field: username', 'post', data={})

In [None]:
# Example post request with parameter that has a default value
@app.post('/pet/dog')
def pet_dog(dogname: str = None): return dogname

# Working post request with optional parameter
test_r(cli, '/pet/dog', '', 'post', data={})

In [None]:
@dataclass
class Bodie: a:int;b:str

@rt("/bodie/{nm}")
def post(nm:str, data:Bodie):
    res = asdict(data)
    res['nm'] = nm
    return res

@app.post("/bodied/")
def bodied(data:dict): return data

nt = namedtuple('Bodient', ['a','b'])

@app.post("/bodient/")
def bodient(data:nt): return data._asdict()

# TypedDict keys can't have default values (PEP 589), so `b` has none here
class BodieTD(TypedDict): a:int;b:str

# Named `bodietd` so it doesn't shadow `bodient` above or share its route name
@app.post("/bodietd/")
def bodietd(data:BodieTD): return data

class Bodie2:
    a:int|None; b:str
    def __init__(self, a, b='foo'): store_attr()

@app.post("/bodie2/")
def bodie(d:Bodie2): return f"a: {d.a}; b: {d.b}"

In [None]:
from fasthtml.xtend import Titled

In [None]:
# Testing POST with Content-Type: application/json
@app.post("/")
def index(it: Bodie): return Titled("It worked!", P(f"{it.a}, {it.b}"))

s = json.dumps({"b": "Lorem", "a": 15})
response = cli.post('/', headers={"Content-Type": "application/json"}, data=s).text
assert "<title>It worked!</title>" in response and "<p>15, Lorem</p>" in response

In [None]:
# Annotated fields are cast (`a` -> 1); `dict` and namedtuple bodies keep form values as strings
d = dict(a=1, b='foo')

test_r(cli, '/bodie/me', '{"a":1,"b":"foo","nm":"me"}', 'post', data=dict(a=1, b='foo', nm='me'))
test_r(cli, '/bodied/', '{"a":"1","b":"foo"}', 'post', data=d)
test_r(cli, '/bodie2/', 'a: 1; b: foo', 'post', data={'a':1})
test_r(cli, '/bodient/', '{"a":"1","b":"foo"}', 'post', data=d)
test_r(cli, '/bodietd/', '{"a":1,"b":"foo"}', 'post', data=d)

In [None]:
# Returning `cookie(...)` sets a response header; the cookie is then read back by parameter name
@rt("/setcookie")
def get(req): return cookie('now', datetime.now())

@rt("/getcookie")
def get(now:date): return f'Cookie was set at time {now.time()}'

In [None]:
# `/setcookie` returns only a header, so its printed body is empty
print(cli.get('/setcookie').text)
time.sleep(0.01)
cli.get('/getcookie').text




'Cookie was set at time 05:27:23.743249'

In [None]:
# `sess` is a prefix of 'session', so the session is injected; `noo` is looked up like any other param
@rt("/setsess")
def get(sess):
    now = datetime.now()
    sess['noo'] = str(now)
    return f'Set to {now}'

@rt("/getsess")
def get(noo:date): return f'Session time: {noo.time()}'

In [None]:
# NOTE(review): params are not looked up in the session (only path/cookies/headers/query/form), so `/getsess` fails — see output
print(cli.get('/setsess').text)
time.sleep(0.01)
cli.get('/getsess').text

Set to 2024-08-13 05:27:23.793683


'Missing required field: noo'

In [None]:
# File uploads are injected as `UploadFile`
@rt("/upload")
async def post(uploadfile:UploadFile): return (await uploadfile.read()).decode()

fn = '../../CHANGELOG.md'
data = {'message': 'Hello, world!'}
with open(fn, 'rb') as f:
    print(cli.post('/upload', files={'uploadfile': f}, data=data).text[:80])

# Release notes

<!-- do not remove -->


## 0.3.4

### New Features

- Experime


In [None]:
# NOTE(review): this cell duplicates the previous one verbatim (re-registering `/upload`); consider removing it
@rt("/upload")
async def post(uploadfile:UploadFile): return (await uploadfile.read()).decode()

fn = '../../CHANGELOG.md'
data = {'message': 'Hello, world!'}
with open(fn, 'rb') as f:
    print(cli.post('/upload', files={'uploadfile': f}, data=data).text[:80])

# Release notes

<!-- do not remove -->


## 0.3.4

### New Features

- Experime


In [None]:
# Serve files by extension from the working directory using the `static` convertor.
# NOTE(review): `fname` comes straight from the URL — confirm path traversal (`..`) can't escape the intended directory
@rt("/{fname:path}.{ext:static}")
async def get(fname:str, ext:str): return FileResponse(f'{fname}.{ext}')

In [None]:
# `README.txt` in the working directory is served by the static route
assert 'These are the source notebooks for FastHTML' in cli.get('/README.txt').text

In [None]:
from fasthtml.authmw import user_pwd_auth

In [None]:
def _not_found(req, exc):
    "404 handler that returns an FT component, which is rendered like a normal route result"
    return Div('nope')

In [None]:
# Register `_not_found` for 404s
app,cli,rt = get_cli(FastHTML(exception_handlers={404:_not_found}))

In [None]:
# Exception-handler output gets the same full-page wrapping as normal routes
txt = cli.get('/').text
assert '<div>nope</div>' in txt
assert '<!doctype html>' in txt

In [None]:
# `auth` is injected from `req.scope['auth']`, presumably set by the auth middleware
auth = user_pwd_auth(testuser='spycraft')
app,cli,rt = get_cli(FastHTML(middleware=[auth]))

@rt("/locked")
def get(auth): return 'Hello, ' + auth

test_eq(cli.get('/locked').text, 'not authenticated')
test_eq(cli.get('/locked', auth=("testuser","spycraft")).text, 'Hello, testuser')

In [None]:
# NOTE(review): this cell duplicates the previous one verbatim; consider removing it
auth = user_pwd_auth(testuser='spycraft')
app,cli,rt = get_cli(FastHTML(middleware=[auth]))

@rt("/locked")
def get(auth): return 'Hello, ' + auth

test_eq(cli.get('/locked').text, 'not authenticated')
test_eq(cli.get('/locked', auth=("testuser","spycraft")).text, 'Hello, testuser')

In [None]:
# Record this app's headers and routes to compare against the live-reload app below
hdrs, routes = app.router.hdrs, app.routes

In [None]:
from fasthtml.live_reload import FastHTMLWithLiveReload

In [None]:
# The live-reload app adds exactly one header and one route compared with the app above
app,cli,rt = get_cli(FastHTMLWithLiveReload())

@rt("/hi")
def get(): return 'Hi there'

test_eq(cli.get('/hi').text, "Hi there")

lr_hdrs, lr_routes = app.router.hdrs, app.routes
test_eq(len(lr_hdrs), len(hdrs)+1)
assert app.LIVE_RELOAD_HEADER in lr_hdrs
test_eq(len(lr_routes), len(routes)+1)
assert app.LIVE_RELOAD_ROUTE in lr_routes

In [None]:
from fasthtml.sessionmw import session_normalize
# Storing a `UUID` in the session; `session_normalize` is presumably what makes such values serialisable — confirm.
# The middleware is appended after construction, before the first request is made.
app,cli,rt = get_cli(FastHTML())
normalizer = session_normalize()
app.user_middleware.append(normalizer)

@app.get("/")
def index(session, item_id: UUID):
    session["item_id"] = item_id
    return "OK"

response = cli.get("/?item_id=36621c53-55c3-11ef-b14b-c45ab1ddc9ad").text
assert response == "OK"

## Export -

In [None]:
#|hide
# Write the `#| export` cells out to the Python module
import nbdev; nbdev.nbdev_export()