In [None]:
#| default_exp core

# FastHTML

In [None]:
#| export
import json, dateutil

from fastcore.utils import *
from fastcore.xml import *

from types import UnionType
from typing import Optional, get_type_hints, get_args, get_origin, Union, Mapping
from datetime import datetime
from dataclasses import dataclass,fields,is_dataclass,MISSING,asdict
from inspect import isfunction,ismethod,signature,Parameter
from functools import wraps, partialmethod

from starlette.applications import Starlette
from starlette.routing import Route, Mount
from starlette.responses import Response, HTMLResponse, FileResponse, JSONResponse
from starlette.requests import Request
from starlette.staticfiles import StaticFiles
from starlette.exceptions import HTTPException
from starlette._utils import is_async_callable

In [None]:
from IPython import display
from enum import Enum
from pprint import pprint

from starlette.testclient import TestClient

In [None]:
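# TODO (not implemented yet): for non-HTMX requests (no `HX-Request` header), the response should
# presumably be wrapped in a full page with the app's head tags, along the lines of:
# if 'HX-Request' not in request.headers:
#     resp = wrap_root(resp, self.headtags)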

In [None]:
#| export
# `inspect`'s sentinel for "no annotation" / "no default" (also what `signature` returns for a
# missing return annotation); aliased for brevity in the signature handling that follows.
empty = Parameter.empty

In [None]:
#| export
def _wrap_resp(resp, cls):
    "Turn handler result `resp` into a Starlette `Response`, honouring return annotation `cls` when given"
    # Handlers that build their own response are passed straight through
    if isinstance(resp, Response): return resp
    # An explicit return annotation chooses the response class
    if cls is not empty: return cls(resp)
    # Lists (e.g. fastcore.xml trees) are rendered to markup with `to_xml`
    if isinstance(resp, list): return HTMLResponse(to_xml(resp))
    if isinstance(resp, str): return HTMLResponse(resp)
    if isinstance(resp, Mapping): return JSONResponse(resp)
    # Anything else is sent as its string form
    return HTMLResponse(str(resp))

In [None]:
#| export
def _fix_anno(t):
    origin = get_origin(t)
    if origin is Union or origin is UnionType:
        t = first(o for o in get_args(t) if o!=type(None))
    if t==bool: return str2bool
    return t

In [None]:
#| export
def date(s):
    "Parse the date/time string `s` into a `datetime` using `dateutil`'s flexible parser"
    # A bare `import dateutil` does not load the `parser` submodule on older dateutil releases,
    # so `dateutil.parser` could be missing; import the submodule explicitly.
    from dateutil import parser
    return parser.parse(s)

In [None]:
#| export
def _form_arg(fld, body):
    res = body.get(fld.name, None)
    if not res: res = fld.default
    assert res is not MISSING
    anno = _fix_anno(fld.type)
    if res is not None: res = anno(res)
    return res

In [None]:
#| export
async def _from_body(req, arg, p):
    "Instantiate the dataclass annotated on parameter `p` from the request's form data (`arg` is not used)"
    form = await req.form()
    dc = p.annotation
    values = {f.name: _form_arg(f, form) for f in fields(dc)}
    return dc(**values)

In [None]:
#| export
def snake2hyphens(s):
    "Convert snake_case `s` to a hyphenated header-style name, e.g. 'user_agent' -> 'User-Agent'"
    return camel2words(snake2camel(s), '-')

In [None]:
#| export
async def _find_p(req, arg:str, p):
    "Value for handler parameter `arg` (described by `inspect.Parameter` `p`), extracted from request `req`"
    # Dataclass-annotated parameters are built from the form body
    if is_dataclass(p.annotation): return await _from_body(req, arg, p)
    # Sources in priority order; the header is looked up by hyphenated name (`user_agent` -> `User-Agent`).
    # Lambdas keep the lookups lazy: only the first source with a value is needed.
    lookups = (lambda: req.path_params.get(arg, None),
               lambda: req.query_params.get(arg, None),
               lambda: req.cookies.get(arg, None),
               lambda: req.headers.get(snake2hyphens(arg), None))
    # Only None/'' mean "missing": a typed path param such as `{idx:int}` can legitimately be 0
    res = next((v for v in (get() for get in lookups) if v is not None and v != ''), None)
    if res is None:
        # Not in the request: use the declared default as-is. It is already a Python value, and
        # running it through the converter can fail (e.g. a `date` annotation given a `datetime`).
        return None if p.default is empty else p.default
    anno = _fix_anno(p.annotation)
    return res if anno is empty else anno(res)

In [None]:
#| export
async def _wrap_req(req, params):
    "Positional arguments for a handler whose signature parameters are `params` (ordered name -> `Parameter`)"
    def bare(p): return p.annotation is empty and p.default is empty
    # A handler with a single bare parameter, e.g. `def handler(req)`, receives the request itself
    if len(params)==1 and all(bare(p) for p in params.values()): return [req]
    args = []
    for arg,p in params.items():
        if bare(p):
            # Bare `*args` / `**kwargs` can't be filled from the request, so they are skipped
            if p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD): continue
            # Other bare params also receive the request. Previously they were dropped, which
            # shifted every later argument one position to the left (e.g. `def f(req, nm:str)`).
            args.append(req)
        else: args.append(await _find_p(req, arg, p))
    return args

In [None]:
#| export
def _wrap_ep(f):
    "Adapt plain function/method `f` into a Starlette endpoint: request -> handler args, return value -> `Response`"
    # Anything else (e.g. an ASGI app or endpoint class) is handed to Starlette unchanged
    if not (isfunction(f) or ismethod(f)): return f
    from starlette.concurrency import run_in_threadpool
    sig = signature(f)
    params,cls = sig.parameters,sig.return_annotation
    is_async = is_async_callable(f)  # fixed for the lifetime of the route, so computed once

    async def _f(req):
        args = await _wrap_req(req, params)
        # `_f` is async, so Starlette awaits it directly. Sync handlers are therefore sent to the
        # threadpool (as Starlette does for plain sync endpoints) so blocking work in them doesn't
        # stall the event loop.
        resp = await f(*args) if is_async else await run_in_threadpool(f, *args)
        return _wrap_resp(resp, cls)
    return _f

In [None]:
#| export
class RouteX(Route):
    "Starlette `Route` whose endpoint is first adapted with `_wrap_ep`"
    def __init__(self, path, endpoint, *args, **kw):
        super().__init__(path, _wrap_ep(endpoint), *args, **kw)

In [None]:
#| export
class FastHTML:
    "ASGI app that keeps routes in a dict keyed by (path, method), so re-registering one replaces it"
    def __init__(self, debug=True):
        self.rd = {}  # (path, method) -> route
        # Forwarded to `Starlette(debug=...)`; consider False outside development, since Starlette's
        # debug mode puts tracebacks in error responses.
        self.debug = debug
        self._app,self._routes = None,None  # cached Starlette app and the route list it was built from

    async def __call__(self, scope, recv, send):
        routes = list(self.rd.values())
        # Rebuild the Starlette app only when the registered routes have changed, not per request.
        # Comparing against the current `rd` contents keeps direct edits to `self.rd` working.
        if routes != self._routes:
            self._app = Starlette(debug=self.debug, routes=routes)
            self._routes = routes
        return await self._app(scope, recv, send)

    def add_route(self, route):
        "Register `route`, replacing any route already registered for the same path and method"
        # Starlette adds an implicit HEAD to every GET route, so `route.methods` is e.g. {'GET','HEAD'}.
        # Taking an arbitrary set element (order is hash-dependent) could key a GET route as HEAD and
        # let a HEAD route on the same path silently replace it, so key on the non-HEAD method.
        methods = sorted(route.methods or ())
        meth = next((m for m in methods if m != 'HEAD'), methods[0] if methods else None)
        self.rd[(route.path,meth)] = route

    def route(self, path, meth='GET'):
        "Decorator registering the decorated function as the `meth` handler for `path`; returns it unchanged"
        def _inner(f):
            self.add_route(RouteX(path, f, methods=[meth]))
            return f
        return _inner

# Shortcuts `app.get(path)`, `app.post(path)`, ... for `app.route(path, meth=...)`
for o in 'get post put delete patch head trace options'.split():
    setattr(FastHTML, o, partialmethod(FastHTML.route, meth=o.upper()))

## Demo

In [None]:
def todict(req):
    "Copy of mapping `req` (e.g. an ASGI scope) with every value stringified, so it can be sent as JSON"
    return dict((key, str(val)) for key, val in req.items())

In [None]:
app = FastHTML()

@app.get("/")
def root(req):
    "Echo the request's ASGI scope, with values stringified"
    scope = req.scope
    return todict(scope)

@app.get('/user/{nm}')
def get_nm(nm:str):
    "Greet the user named by the `{nm}` path parameter"
    return "Good day to you, " + nm + "!"

In [None]:
client = TestClient(app)
# `root` takes a single bare parameter, so it receives the request and echoes its scope
r = client.get('/')
print(r.text)

{"type":"http","http_version":"1.1","method":"GET","path":"/","raw_path":"b'/'","root_path":"","scheme":"http","query_string":"b''","headers":"[(b'host', b'testserver'), (b'accept', b'*/*'), (b'accept-encoding', b'gzip, deflate, br'), (b'connection', b'keep-alive'), (b'user-agent', b'testclient')]","client":"['testclient', 50000]","server":"['testserver', 80]","extensions":"{'http.response.debug': {}}","state":"{}","app":"<starlette.applications.Starlette object>","starlette.exception_handlers":"({<class 'starlette.exceptions.HTTPException'>: <bound method ExceptionMiddleware.http_exception of <starlette.middleware.exceptions.ExceptionMiddleware object>>, <class 'starlette.exceptions.WebSocketException'>: <bound method ExceptionMiddleware.websocket_exception of <starlette.middleware.exceptions.ExceptionMiddleware object>>}, {})","router":"<starlette.routing.Router object>","endpoint":"<function _wrap_ep.<locals>._f>","path_params":"{}"}


In [None]:
# `nm` is filled from the `{nm}` path parameter
client.get('/user/jph').text

'Good day to you, jph!'

In [None]:
@app.get('/html/{idx}')
async def get_html(idx:int):
    "Small HTML page; `idx` arrives as a path string and is converted by its `int` annotation"
    nxt = idx+1
    return Body(H4("Wow look here"),
                P(f'It looks like you are visitor {idx}! Next is {nxt}.'))

In [None]:
# Render the returned markup inline in the notebook
display.HTML(client.get('/html/1').text)

In [None]:
ModelName = str_enum('ModelName', "alexnet", "resnet", "lenet")

# A fresh app for the examples that follow
app = FastHTML()

@app.get("/models/{nm}")
def model(nm:ModelName):
    "Return the path value, converted by the `ModelName` annotation"
    return nm

@app.get("/files/{path}")
async def txt(path: Path):
    "Return `path` (converted to a `Path`) with its suffix replaced by '.txt'"
    renamed = path.with_suffix('.txt')
    return renamed

In [None]:
# `nm` is converted by its `ModelName` annotation before the handler sees it
print(TestClient(app).get('/models/alexnet').text)

alexnet


In [None]:
# `path` arrives as the string 'foo' and is converted by its `Path` annotation
print(TestClient(app).get('/files/foo').text)

foo.txt


In [None]:
fake_db = [{"name": "Foo"}, {"name": "Bar"}]

@app.get("/items/")
def read_item(idx:int|None = 0):
    "Return entry `idx` of `fake_db` (a dict, sent as JSON), or a 404 if there is no such entry"
    # Without this check an out-of-range `idx` raises IndexError (a 500 response), and a
    # negative one silently indexes from the end of the list.
    if idx is None or not 0 <= idx < len(fake_db): raise HTTPException(404, f"No item {idx}")
    return fake_db[idx]

In [None]:
# `idx` comes from the query string and is converted by its `int` annotation
print(TestClient(app).get('/items/?idx=1').text)

{"name":"Bar"}


In [None]:
# No `idx` in the query string, so the parameter's default of 0 is used
print(TestClient(app).get('/items/').text)

{"name":"Foo"}


In [None]:
@app.get("/booly/")
def booly(coming:bool=True):
    "A `bool` annotation is parsed with `str2bool`, so strings like 'true' / 'no' work"
    if coming: return 'Coming'
    return 'Not coming'

In [None]:
cli = TestClient(app)
# 'true' is parsed by `str2bool` into True
print(cli.get('/booly/?coming=true').text)

Coming


In [None]:
# 'no' is parsed by `str2bool` into False
print(cli.get('/booly/?coming=no').text)

Not coming


In [None]:
@app.get("/datie/")
def datie(d:date):
    "Echo back `d`, parsed from the query string by the `date` annotation"
    parsed = d
    return parsed

In [None]:
cli = TestClient(app)
date_str = "17th of May, 2024, 2p"
# Pass the value via `params` so the client URL-encodes it; it contains spaces and commas,
# and interpolating it raw into the URL breaks as soon as a value contains '&' or '#'.
print(cli.get('/datie/', params={'d': date_str}).text)

2024-05-17 14:00:00


In [None]:
@dataclass
class Bodie:
    # Form body for the `/bodie` example; one attribute per expected form field
    a:int
    b:str

In [None]:
@app.post("/bodie/{nm}/")
async def bodie(nm:str, data:Bodie):
    "Echo the parsed form body as a dict, together with the path's `nm`"
    return {**asdict(data), 'nm': nm}

In [None]:
# The route is declared with a trailing slash; posting to it exactly avoids relying on
# Starlette's slash-redirect (and the client re-sending the body) to reach the handler.
cli.post('/bodie/me/', data=dict(a=1, b='foo')).text

'{"a":1,"b":"foo","nm":"me"}'

In [None]:
@app.get("/setcookie")
async def setc(req):
    "Respond with the current time and store it in a `now` cookie"
    now = datetime.now()
    resp = Response(f'Set to {now}')
    resp.set_cookie('now', str(now))
    return resp

In [None]:
# The response includes the current time, so this output changes on every run
cli.get('/setcookie').text

'Set to 2024-05-17 15:23:30.099808'

In [None]:
@app.get("/getcookie")
async def getc(now:date):
    "Report the time stored in the `now` cookie, parsed by the `date` annotation"
    t = now.time()
    return f'Cookie was set at time {t}'

In [None]:
# `now` is read from the cookie set by the `/setcookie` request above
cli.get('/getcookie').text

'Cookie was set at time 15:23:30.099808'

In [None]:
@app.get("/ua")
async def ua(user_agent:str):
    "Return the request's `User-Agent` header (`user_agent` is looked up by its hyphenated name)"
    return user_agent

In [None]:
# `user_agent` is looked up in the headers as `User-Agent`
cli.get('/ua', headers={'User-Agent':'FastHTML'}).text

'FastHTML'

## fin -