In [None]:
#| default_exp core

# API Details

In [None]:
#| export
import json, dateutil

from fastcore.utils import *
from fastcore.xml import *

from types import UnionType, SimpleNamespace as ns
from typing import Optional, get_type_hints, get_args, get_origin, Union, Mapping
from datetime import datetime
from dataclasses import dataclass,fields,is_dataclass,MISSING,asdict
from inspect import isfunction,ismethod,signature,Parameter
from functools import wraps, partialmethod

from starlette.applications import Starlette
from starlette.routing import Route, Mount, Router
from starlette.responses import Response, HTMLResponse, FileResponse, JSONResponse
from starlette.requests import Request
from starlette.staticfiles import StaticFiles
from starlette.exceptions import HTTPException
from starlette._utils import is_async_callable
from starlette.convertors import Convertor, StringConvertor, register_url_convertor, CONVERTOR_TYPES

In [None]:
from IPython import display
from enum import Enum
from pprint import pprint

from starlette.testclient import TestClient

In [None]:
#| export
# Short alias for `inspect.Parameter.empty`, the sentinel `inspect` uses for "no annotation",
# "no default" and "no return annotation"; compared against throughout this module.
empty = Parameter.empty

In [None]:
#| export
def _fix_anno(t):
    origin = get_origin(t)
    if origin is Union or origin is UnionType:
        t = first(o for o in get_args(t) if o!=type(None))
    if t==bool: return str2bool
    return t

In [None]:
#| export
def date(s):
    "Parse the date/time string `s` with `dateutil.parser.parse` (usable as a parameter annotation)."
    # A bare `import dateutil` does not guarantee the `parser` submodule is loaded (older
    # python-dateutil releases don't lazy-load submodules), so import it explicitly here.
    import dateutil.parser
    return dateutil.parser.parse(s)

In [None]:
#| export
def _form_arg(fld, body):
    res = body.get(fld.name, None)
    if not res: res = fld.default
    assert res is not MISSING
    anno = _fix_anno(fld.type)
    if res is not None: res = anno(res)
    return res

In [None]:
#| export
async def _from_body(req, arg, p):
    body = await req.form()
    cargs = {o.name:_form_arg(o, body) for o in fields(p.annotation)}
    return p.annotation(**cargs)

In [None]:
#| export
def snake2hyphens(s):
    "Turn a `snake_case` name into hyphen-separated words (via `snake2camel` then `camel2words`); used to map parameter names to HTTP header names."
    return camel2words(snake2camel(s), '-')

In [None]:
#| export
htmx_hdrs = dict(
    boosted="HX-Boosted",
    current_url="HX-Current-URL",
    history_restore_request="HX-History-Restore-Request",
    prompt="HX-Prompt",
    request="HX-Request",
    target="HX-Target",
    trigger_name="HX-Trigger-Name",
    trigger="HX-Trigger")

def _get_htmx(req):
    res = {k:req.headers.get(v.lower(), None) for k,v in htmx_hdrs.items()}
    return ns(**res) if res else None

In [None]:
#| export
def _absent(v):
    "True when a looked-up request value counts as 'not supplied' (`None` or empty string)."
    return v is None or v == ''

async def _find_p(req, arg:str, p):
    """Resolve the value of endpoint parameter `arg` (`inspect.Parameter` `p`) from request `req`.

    - Unannotated: a name that is a prefix of `'request'` gets `req` itself, `htmx` gets the
      htmx header namespace, anything else gets `None`.
    - Annotated with a dataclass: built from the request's form body.
    - Otherwise: looked up in path params, query params, cookies, then headers (name converted
      with `snake2hyphens`). A found value is converted with the annotation (see `_fix_anno`);
      if nothing is found, the declared default is returned unconverted, or `None` if there is none.
    """
    anno = p.annotation
    if anno is empty:
        if 'request'.startswith(arg.lower()): return req
        if arg.lower()=='htmx': return _get_htmx(req)
        return None
    if is_dataclass(anno): return await _from_body(req, arg, p)
    # Only None/'' mean "not supplied": path params may already be converted values
    # (e.g. an int 0) that must not fall through to the next source.
    res = req.path_params.get(arg, None)
    if _absent(res): res = req.query_params.get(arg, None)
    if _absent(res): res = req.cookies.get(arg, None)
    if _absent(res): res = req.headers.get(snake2hyphens(arg), None)
    # Defaults are already Python values, not raw request strings, so they are not converted.
    if _absent(res): return None if p.default is empty else p.default
    return _fix_anno(anno)(res)

In [None]:
#| export
async def _wrap_req(req, params):
    return [await _find_p(req, arg, p) for arg,p in params.items()]

In [None]:
#| export
def _xt_resp(req, resp, hdrs, **bodykw):
    """Render an XML-tree response (a `list`/`tuple` of elements) into an `HTMLResponse`.

    For a request without an `hx-request` header, a non-empty `(title, body)` tuple whose first
    element's tag is not `!doctype`/`html` is wrapped into a full page:
    `Html(Head(title, *hdrs), Body(body, **bodykw))`. A `str` title is wrapped in `Title`.
    Everything else is rendered with `to_xml` as given.
    """
    if resp and 'hx-request' not in req.headers and isinstance(resp,tuple) and resp[0][0] not in ('!doctype','html'):
        title,bdy = resp
        if isinstance(title,str): title=Title(title)
        # The title and page headers belong in the document <head> (`Head`), not a <header>
        # element (`Header`); `hdrs` may be None when a RouteX is built directly.
        resp = Html(Head(title, *(hdrs or ())), Body(bdy, **bodykw))
    return HTMLResponse(to_xml(resp))

In [None]:
#| export
def _wrap_resp(req, resp, cls, hdrs, **bodykw):
    """Convert an endpoint's return value `resp` into a Starlette `Response`.

    Precedence: an existing `Response` is passed through; otherwise a return annotation `cls`
    (anything but `empty`) is called on the value; lists/tuples go through `_xt_resp`;
    `str` becomes an `HTMLResponse`, a `Mapping` a `JSONResponse`, and anything else is
    stringified into an `HTMLResponse`.
    """
    if isinstance(resp, Response): return resp
    if cls is not empty: return cls(resp)
    if isinstance(resp, (list,tuple)): return _xt_resp(req, resp, hdrs, **bodykw)
    if isinstance(resp, str): return HTMLResponse(resp)
    if isinstance(resp, Mapping): return JSONResponse(resp)
    return HTMLResponse(str(resp))

In [None]:
#| export
def _wrap_ep(f, hdrs, **bodykw):
    if not (isfunction(f) or ismethod(f)): return f
    sig = signature(f)
    params = sig.parameters
    cls = sig.return_annotation

    async def _f(req):
        wreq = await _wrap_req(req, params)
        resp = f(*wreq)
        if is_async_callable(f): resp = await resp
        return _wrap_resp(req, resp, cls, hdrs, **bodykw)
    return _f

In [None]:
#| export
class RouteX(Route):
    "A Starlette `Route` whose endpoint is first passed through `_wrap_ep` together with `hdrs` and any extra body kwargs."
    def __init__(self, path:str, endpoint, *, methods=None, name=None, include_in_schema=True, middleware=None,
                hdrs=None, **bodykw):
        handler = _wrap_ep(endpoint, hdrs, **bodykw)
        super().__init__(path, handler,
                         methods=methods, name=name, include_in_schema=include_in_schema, middleware=middleware)

In [None]:
#| export
class RouterX(Router):
    "A Starlette `Router` whose `add_route` builds `RouteX` routes carrying this router's `hdrs` and body kwargs."
    def __init__(self, routes=None, redirect_slashes=True, default=None, on_startup=None, on_shutdown=None,
                 lifespan=None, *, middleware=None, hdrs=None, **bodykw):
        super().__init__(routes, redirect_slashes, default, on_startup, on_shutdown, lifespan=lifespan, middleware=middleware)
        # page headers default to an empty tuple so they can always be splatted
        self.hdrs = hdrs or ()
        self.bodykw = bodykw

    def add_route( self, path: str, endpoint: callable, methods=None, name=None, include_in_schema=True):
        "Append a `RouteX` for `endpoint` at `path`, passing along the stored headers/body kwargs."
        self.routes.append(RouteX(path, endpoint=endpoint, methods=methods, name=name,
                                  include_in_schema=include_in_schema, hdrs=self.hdrs, **self.bodykw))

In [None]:
#| export
class FastHTML(Starlette):
    "A Starlette application whose router is a `RouterX`, so endpoints registered through `route` (or the per-verb shortcuts) are wrapped by `RouteX`."
    def __init__(self, debug=False, routes=None, middleware=None, exception_handlers=None,
                 on_startup=None, on_shutdown=None, lifespan=None, hdrs=None, **bodykw):
        super().__init__(debug, routes, middleware, exception_handlers, on_startup, on_shutdown, lifespan=lifespan)
        # replace the router created by Starlette with one that knows the page headers and body kwargs
        self.router = RouterX(routes, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan,
                              hdrs=hdrs, **bodykw)

    def route(self, path:str, methods=None, name=None, include_in_schema=True):
        "Decorator that registers the decorated function at `path` and returns it unchanged; a single method may be given as a string."
        methods = [methods] if isinstance(methods, str) else methods
        def _register(func):
            self.router.add_route(path, func, methods=methods, name=name, include_in_schema=include_in_schema)
            return func
        return _register

# One shortcut per HTTP verb, e.g. `app.get(path)` is `app.route(path, methods='Get')`.
for o in ('get', 'post', 'put', 'delete', 'patch', 'head', 'trace', 'options'):
    setattr(FastHTML, o, partialmethod(FastHTML.route, methods=o.capitalize()))

In [None]:
#| export
def reg_re_param(m, s):
    "Globally register a URL path convertor named `m`: a `StringConvertor` subclass whose `regex` is `s`."
    # NOTE(review): `to_string` is inherited from StringConvertor; confirm it accepts values
    # containing '/' if this convertor is used for URL generation.
    conv_cls = get_class(f'{m}Conv', sup=StringConvertor, regex=s)
    register_url_convertor(m, conv_cls())

# Override the built-in `path` convertor with a non-greedy pattern: Starlette's own regex has no
# '?', so it swallows the entire remainder of the URL.
reg_re_param("path", ".*?")

# Export -

In [None]:
#|hide
import nbdev
nbdev.nbdev_export()  # regenerate the library module(s) from this notebook's `#| export` cells