In [None]:
#| default_exp core

# API Details

In [None]:
#| export
import json, dateutil

from fastcore.utils import *
from fastcore.xml import *

from types import UnionType, SimpleNamespace as ns
from typing import Optional, get_type_hints, get_args, get_origin, Union, Mapping
from datetime import datetime
from dataclasses import dataclass,fields,is_dataclass,MISSING,asdict
from collections import namedtuple
from inspect import isfunction,ismethod,signature,Parameter,get_annotations
from functools import wraps, partialmethod

from starlette.applications import Starlette
from starlette.routing import Route, Mount, Router
from starlette.responses import Response, HTMLResponse, FileResponse, JSONResponse
from starlette.requests import Request
from starlette.staticfiles import StaticFiles
from starlette.exceptions import HTTPException
from starlette._utils import is_async_callable
from starlette.convertors import Convertor, StringConvertor, register_url_convertor, CONVERTOR_TYPES

In [None]:
import time

from IPython import display
from enum import Enum
from pprint import pprint

from fastcore.test import *
from starlette.testclient import TestClient

In [None]:
#| export
# Alias for `inspect.Parameter.empty`: the sentinel `inspect` uses for a missing annotation or default
# (`Signature.return_annotation` uses the same object when there is no return annotation)
empty = Parameter.empty

In [None]:
#| export
def is_namedtuple(cls):
    "True if class `cls` is a `tuple` subclass defining `_fields`, as `namedtuple`/`NamedTuple` classes do"
    if not issubclass(cls, tuple): return False
    return hasattr(cls, '_fields')

In [None]:
#| export
def date(s):
    "Parse the date/time string `s` into a `datetime` using `dateutil`'s flexible parser"
    # A bare `import dateutil` does not load the `parser` submodule on python-dateutil releases before 2.9
    # (which added lazy submodule imports), so import it explicitly rather than relying on someone else having done so
    import dateutil.parser
    return dateutil.parser.parse(s)

In [None]:
#| export
def snake2hyphens(s):
    "Convert snake_case `s` to hyphen-separated words (used to map parameter names to HTTP header names)"
    return camel2words(snake2camel(s), '-')

In [None]:
#| export
htmx_hdrs = dict(
    boosted="HX-Boosted",
    current_url="HX-Current-URL",
    history_restore_request="HX-History-Restore-Request",
    prompt="HX-Prompt",
    request="HX-Request",
    target="HX-Target",
    trigger_name="HX-Trigger-Name",
    trigger="HX-Trigger")

@dataclass
class HtmxHeaders:
    boosted:str|None=None; current_url:str|None=None; history_restore_request:str|None=None; prompt:str|None=None
    request:str|None=None; target:str|None=None; trigger_name:str|None=None; trigger:str|None=None
    def __bool__(self): return any(hasattr(self,o) for o in htmx_hdrs)

def _get_htmx(req):
    res = {k:req.headers.get(v.lower(), None) for k,v in htmx_hdrs.items()}
    return HtmxHeaders(**res)

In [None]:
#| export
def str2int(s)->int:
    "Convert `s` to an `int`: `'on'` is 1, `'none'` and empty are 0 (case-insensitive), otherwise parse as an integer"
    low = s.lower()
    specials = {'on': 1, 'none': 0, '': 0}
    if low in specials: return specials[low]
    return int(low)

In [None]:
#| export
def _fix_anno(t):
    "Return the callable used to convert a request string to annotation `t` (`Optional`/`X|None` unwrapped to `X`)"
    if get_origin(t) in (Union, UnionType):
        # Use the first non-`None` member of the union
        t = next((a for a in get_args(t) if a!=type(None)), None)
    # `bool`/`int` need form-friendly parsing; any other annotation is called directly on the string
    converters = {bool: str2bool, int: str2int}
    return converters.get(t, t)

In [None]:
#| export
def _form_arg(k, v, d):
    "Convert form value `v` for field `k` using its annotation in `d`; unannotated values pass through, `None` stays `None`"
    if v is None: return None
    anno = d.get(k, None)
    return _fix_anno(anno)(v) if anno else v

In [None]:
#| export
def _is_body(anno):
    """Truthy if class `anno` should be built from the request's form body.

    Returns `True` for dict/SimpleNamespace subclasses, dataclasses and namedtuples; otherwise returns the
    class's own annotations dict, which is truthy only when it declares annotated fields."""
    if issubclass(anno, (dict,ns)) or is_dataclass(anno) or is_namedtuple(anno): return True
    return get_annotations(anno)

def _anno2flds(anno):
    "Map each field name of body type `anno` to the annotation used to convert its form value (`{}` if none)"
    if is_dataclass(anno): return {o.name:o.type for o in fields(anno)}
    if is_namedtuple(anno):
        # `typing.NamedTuple` classes carry per-field annotations; plain `collections.namedtuple` ones don't,
        # so any unannotated field falls back to `str` (i.e. the raw form value is kept)
        annos = get_annotations(anno)
        return {o:annos.get(o, str) for o in anno._fields}
    return get_annotations(anno) or {}

async def _from_body(req, arg, p):
    "Construct an instance of `p.annotation` from the request's form fields, converting each per its annotation"
    form = await req.form()
    anno = p.annotation
    flds = _anno2flds(anno)
    kwargs = {}
    for name, val in form.items(): kwargs[name] = _form_arg(name, val, flds)
    return anno(**kwargs)

In [None]:
#| export
async def _find_p(req, arg:str, p):
    """Find the value for endpoint parameter `arg` (described by `inspect.Parameter` `p`) in `req`.

    Annotated special types are injected directly, and body types are built from the form. Unannotated params
    are only recognised by special name (`req`/`request` prefix, `htmx`, `app`) and are otherwise `None`.
    Anything else is looked up by name and converted with the annotation."""
    anno = p.annotation
    if isinstance(anno, type):
        if issubclass(anno, Request): return req
        if issubclass(anno, HtmxHeaders): return _get_htmx(req)
        if issubclass(anno, Starlette): return req.scope['app']
        if _is_body(anno): return await _from_body(req, arg, p)
    if anno is empty:
        lnm = arg.lower()
        # Any prefix of 'request' (e.g. `r`, `req`) receives the request object
        if 'request'.startswith(lnm): return req
        if lnm=='htmx': return _get_htmx(req)
        if lnm=='app': return req.scope['app']
        return None
    # Lookup order: path params, query string, cookies, then headers (snake_case name -> hyphenated header name).
    # Truthiness is used, so an empty-string value falls through to the next source.
    res = req.path_params.get(arg, None)
    if not res: res = req.query_params.get(arg, None)
    if not res: res = req.cookies.get(arg, None)
    if not res: res = req.headers.get(snake2hyphens(arg), None)
    if not res: res = p.default
    if res is empty or res is None: return None
    # Non-string defaults are returned as-is; only strings are converted via the annotation
    if not isinstance(res, str): return res
    return _fix_anno(anno)(res)

In [None]:
#| export
async def _wrap_req(req, params):
    "Resolve every parameter in `params` (name -> `inspect.Parameter`) against `req`, in signature order"
    args = []
    for name, param in params.items():
        args.append(await _find_p(req, name, param))
    return args

In [None]:
#| export
def _xt_resp(req, resp, hdrs, **bodykw):
    """Render component(s) `resp` as an `HTMLResponse`.

    A tuple whose first element is a `title` element is treated as `(title, *body)` and wrapped in a full
    `<html>` page, with `hdrs` in `<head>` and `bodykw` as `<body>` attributes. This happens unless the request
    came from htmx (`HX-Request` header), in which case only the fragment is rendered."""
    # NOTE(review): relies on components being sequences whose first item is the tag name
    if resp and 'hx-request' not in req.headers and isinstance(resp,tuple) and resp[0][0]=='title':
        title,*bdy = resp
        # Title and page headers belong in <head>; `Header` would emit a <header> element inside the page instead
        resp = Html(Head(title, *hdrs), Body(*bdy, **bodykw))
    return HTMLResponse(to_xml(resp))

In [None]:
#| export
def _wrap_resp(req, resp, cls, hdrs, **bodykw):
    "Convert endpoint result `resp` into a Starlette `Response`"
    if isinstance(resp, Response): return resp
    # A return annotation on the endpoint picks the response class explicitly
    if cls is not empty: return cls(resp)
    # Lists/tuples are components, rendered to HTML
    if isinstance(resp, (list,tuple)): return _xt_resp(req, resp, hdrs, **bodykw)
    if isinstance(resp, str): return HTMLResponse(resp)
    if isinstance(resp, Mapping): return JSONResponse(resp)
    # Anything else (numbers, dates, paths, ...) is sent as its string form
    return HTMLResponse(str(resp))

In [None]:
#| export
def _wrap_ep(f, hdrs, **bodykw):
    "Wrap function/method `f` as a Starlette endpoint that fills its args from the request and converts its result"
    # Non-function endpoints are passed through unchanged
    if not isfunction(f) and not ismethod(f): return f
    sig = signature(f)
    params,ret_cls = sig.parameters,sig.return_annotation
    is_async = is_async_callable(f)

    # NOTE(review): name kept as `_f` in case the default route name is derived from the endpoint's `__name__`
    async def _f(req):
        args = await _wrap_req(req, params)
        res = f(*args)
        if is_async: res = await res
        return _wrap_resp(req, res, ret_cls, hdrs, **bodykw)
    return _f

In [None]:
#| export
class RouteX(Route):
    "A `Route` whose endpoint is wrapped by `_wrap_ep`, using `hdrs`/`bodykw` for full-page responses"
    def __init__(self, path:str, endpoint, *, methods=None, name=None, include_in_schema=True, middleware=None,
                 hdrs=None, **bodykw):
        wrapped = _wrap_ep(endpoint, hdrs, **bodykw)
        super().__init__(path, wrapped, methods=methods, name=name,
                         include_in_schema=include_in_schema, middleware=middleware)

In [None]:
#| export
class RouterX(Router):
    "A `Router` that adds routes as `RouteX` and replaces any existing route with the same path and methods"
    def __init__(self, routes=None, redirect_slashes=True, default=None, on_startup=None, on_shutdown=None,
                 lifespan=None, *, middleware=None, hdrs=None, **bodykw):
        super().__init__(routes, redirect_slashes, default, on_startup, on_shutdown,
                 lifespan=lifespan, middleware=middleware)
        # Page headers and `Body` kwargs passed to every route added via `add_route`
        self.hdrs,self.bodykw = hdrs or (),bodykw

    def add_route(self, path: str, endpoint: callable, methods=None, name=None, include_in_schema=True):
        "Add `endpoint` at `path`, replacing any `Route` already registered for the same path and methods"
        route = RouteX(path, endpoint=endpoint, methods=methods, name=name, include_in_schema=include_in_schema,
                       hdrs=self.hdrs, **self.bodykw)
        # Compare with the new route's normalised `methods` set, not the raw `methods` argument (a list/str that
        # never equals an existing route's set). Only `Route`s are candidates, since mounts have no `methods`.
        self.routes = [o for o in self.routes
                       if not (isinstance(o, Route) and o.methods==route.methods and o.path==route.path)]
        self.routes.append(route)

In [None]:
#| export
# <script> tag loading htmx 1.9.12 from the unpkg CDN. `integrity` is a subresource-integrity hash, so the browser
# refuses the file if its contents differ. `FastHTML.__init__` appends this to every app's page headers.
htmxscr = Script(
    src="https://unpkg.com/htmx.org@1.9.12", crossorigin="anonymous",
    integrity="sha384-ujb1lZYygJmzgSwoxRggbCHcjc0rB2XoQrxeTUQyRjrOnlCoYta87iKBWq3EsdM2")

In [None]:
#| export
# HTTP methods that can be inferred from a decorated function's name, and that get `app.<method>` shortcuts
all_meths = 'get post put delete patch head trace options'.split()

class FastHTML(Starlette):
    "A `Starlette` app routed through `RouterX`; `hdrs` (plus the htmx script) are added to full-page responses"
    def __init__(self, debug=False, routes=None, middleware=None, exception_handlers=None,
                 on_startup=None, on_shutdown=None, lifespan=None, hdrs=None, **bodykw):
        super().__init__(debug, routes, middleware, exception_handlers, on_startup, on_shutdown, lifespan=lifespan)
        page_hdrs = [*([] if hdrs is None else hdrs), htmxscr]
        self.router = RouterX(routes, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan,
                              hdrs=page_hdrs, **bodykw)

    def __getitem__(self, path):
        "Decorator `app[path]` or `app[path, name]`: the HTTP method is the function's name if it is one of `all_meths`, else GET"
        name = None
        if isinstance(path, tuple): path,name = path
        def _register(func):
            meth = func.__name__ if func.__name__ in all_meths else 'get'
            self.router.add_route(path, func, methods=[meth], name=name)
            return func
        return _register

    def route(self, path:str, methods=None, name=None, include_in_schema=True):
        "Decorator registering the function at `path` for `methods` (a single method name or a list of them)"
        meths = [methods] if isinstance(methods, str) else methods
        def _register(func):
            self.router.add_route(path, func, methods=meths, name=name, include_in_schema=include_in_schema)
            return func
        return _register

# `app.get(path)`, `app.post(path)`, ... are shortcuts for `app.route(path, methods=<Method>)`
for o in all_meths: setattr(FastHTML, o, partialmethod(FastHTML.route, methods=o.capitalize()))

In [None]:
#| export
def reg_re_param(m, s):
    "Register a Starlette URL convertor named `m` whose values must match regex `s`"
    conv_cls = get_class(f'{m}Conv', sup=StringConvertor, regex=s)
    register_url_convertor(m, conv_cls())

# Replace Starlette's `path` convertor with a non-greedy one: Starlette's pattern has no '?',
# so it consumes the whole remaining URL, leaving nothing for later parts of the route
reg_re_param("path", ".*?")
# Convertor that only matches one of these file extensions
reg_re_param("static", "ico|gif|jpg|jpeg|webm|css|js|woff|png|svg|mp4|webp|ttf|otf|eot|woff2|txt|xml")

## Tests

In [None]:
from starlette.responses import Response
from datetime import datetime
from fastcore.utils import *
from dataclasses import dataclass, asdict

In [None]:
# A fresh app, plus an in-process test client used to exercise its routes
app = FastHTML()
cli = TestClient(app)

In [None]:
# With `app[path]`, the decorated function's name (`get`) selects the HTTP method
@app["/hi"]
def get(): return 'Hi there'

r = cli.get('/hi')
r.text

'Hi there'

In [None]:
# Same path, different method: a `post` handler is added alongside the existing `get` one
@app["/hi"]
def post(): return 'Postal'

cli.post('/hi').text

'Postal'

In [None]:
# `app.get` sets the method explicitly, so any function name works; the unannotated `req` param receives the request
@app.get("/")
def show_host(req): return req.headers['host']

cli.get('/').text

'testserver'

In [None]:
# A `(path, name)` key also names the route; `nm` is filled from the path parameter
@app['/user/{nm}', 'gday']
def get(nm:str): return f"Good day to you, {nm}!"

cli.get('/user/Alexis').text

'Good day to you, Alexis!'

In [None]:
# Named routes can be turned back into URLs
app.router.url_path_for('gday', nm='Jeremy')

'/user/Jeremy'

In [None]:
def test_r(cli, path, exp, meth='get', **kwargs):
    "Send a `meth` request for `path` via `cli` (extra `kwargs` passed through) and check the response text is `exp`"
    resp = getattr(cli, meth)(path, **kwargs)
    test_eq(resp.text, exp)

# Fixtures shared by the route tests that follow
app.chk = 'foo'
ModelName = str_enum('ModelName', "alexnet", "resnet", "lenet")
fake_db = [{"name": "Foo"}, {"name": "Bar"}]

In [None]:
# Path params are converted with the parameter's annotation (`idx:int` -> int)
@app['/html/{idx}']
async def get(idx:int): return Body(H4(f'Next is {idx+1}.'))

# A custom regex convertor; with the non-greedy `path` convertor, `{path:path}` stops before the file name
reg_re_param("imgext", "ico|gif|jpg|jpeg|webm")

@app[r'/static/{path:path}{fn}.{ext:imgext}']
def get(fn:str, path:str, ext:str): return f"Getting {fn}.{ext} from /{path}"

# Other annotations (an enum, `Path`) are applied by calling them on the raw string
@app["/models/{nm}"]
def get(nm:ModelName): return nm

@app["/files/{path}"]
async def get(path: Path): return path.with_suffix('.txt')

# Query param with a default, used when `idx` is absent from the query string
@app["/items/"]
def get(idx:int|None = 0): return fake_db[idx]

In [None]:
# Components render to HTML, dicts are returned as JSON, other values as their string form
test_r(cli, '/html/1', '<body>\n  <h4>\nNext is 2.\n  </h4>\n</body>\n')
test_r(cli, '/static/foo/jph.ico', 'Getting jph.ico from /foo/')
test_r(cli, '/models/alexnet', 'alexnet')
test_r(cli, '/files/foo', 'foo.txt')
test_r(cli, '/items/?idx=1', '{"name":"Bar"}')
test_r(cli, '/items/', '{"name":"Foo"}')

In [None]:
# `bool` params accept form-style strings
@app.get("/booly/")
def _(coming:bool=True): return 'Coming' if coming else 'Not coming'

# Any callable works as an annotation; `date` parses free-form date strings
@app.get("/datie/")
def _(d:date): return d

# Not in the path/query/cookies, so looked up as the `User-Agent` header
@app.get("/ua")
async def _(user_agent:str): return user_agent

# Special names (`htmx`, `app`) and special annotations (`HtmxHeaders`, `FastHTML`) are injected
@app.get("/hxtest")
def _(htmx): return htmx.request

@app.get("/hxtest2")
def _(foo:HtmxHeaders, req): return foo.request

@app.get("/app")
def _(app): return app.chk

@app.get("/app2")
def _(foo:FastHTML): return foo.chk

In [None]:
test_r(cli, '/booly/?coming=true', 'Coming')
test_r(cli, '/booly/?coming=no', 'Not coming')
# The returned `datetime` is sent as its string form
date_str = "17th of May, 2024, 2p"
test_r(cli, f'/datie/?d={date_str}', '2024-05-17 14:00:00')
test_r(cli, '/ua', 'FastHTML', headers={'User-Agent':'FastHTML'})
test_r(cli, '/hxtest' , '1', headers={'HX-Request':'1'})
test_r(cli, '/hxtest2', '1', headers={'HX-Request':'1'})
test_r(cli, '/app' , 'foo')
test_r(cli, '/app2', 'foo')

In [None]:
# Body types: dataclasses, dicts, namedtuples and annotated classes are built from the posted form
@dataclass
class Bodie: a:int;b:str

# Dataclass fields are converted per their annotations (`a` -> int)
@app.post("/bodie/{nm}/")
async def bodie(nm:str, data:Bodie):
    res = asdict(data)
    res['nm'] = nm
    return res

# `dict` has no field annotations, so values stay strings; `nm` is found nowhere in the request, so it is `None`
@app.post("/bodied/")
async def bodied(nm:str, data:dict): return data

nt = namedtuple('Bodient', ['a','b'])

@app.post("/bodient/")
async def bodient(nm:str, data:nt): return data._asdict()

# A plain class with annotations; fields missing from the form fall back to `__init__` defaults
class Bodie2:
    a:int|None; b:str
    def __init__(self, a, b='foo'): store_attr()

# Named `bodie2` so it doesn't shadow the `/bodie/{nm}/` handler above
@app.post("/bodie2/")
async def bodie2(d:Bodie2): return f"a: {d.a}; b: {d.b}"

In [None]:
d = dict(a=1, b='foo')

# Only the dataclass/annotated-class routes convert `a` to int; dict and namedtuple keep form strings
test_r(cli, '/bodie/me', '{"a":1,"b":"foo","nm":"me"}', 'post', data=d)
test_r(cli, '/bodied/', '{"a":"1","b":"foo"}', 'post', data=d)
test_r(cli, '/bodie2/', 'a: 1; b: foo', 'post', data={'a':1})
test_r(cli, '/bodient/', '{"a":"1","b":"foo"}', 'post', data=d)

In [None]:
# Returning a `Response` bypasses conversion, so the handler can set a cookie directly
@app.get("/setcookie")
async def setc(req):
    now = datetime.now()
    res = Response(f'Set to {now}')
    res.set_cookie('now', str(now))
    return res

# `now` is read from the cookie and parsed by the `date` annotation
@app.get("/getcookie")
async def getc(now:date): return f'Cookie was set at time {now.time()}'

In [None]:
# The time read back by `/getcookie` should match the one printed by `/setcookie` (output is time-dependent)
print(cli.get('/setcookie').text)
time.sleep(0.01)
cli.get('/getcookie').text

Set to 2024-05-31 01:36:06.189385


'Cookie was set at time 01:36:06.189385'

# Export -

In [None]:
#|hide
# Write the `#| export` cells of this notebook out to the library module
import nbdev; nbdev.nbdev_export()