In [None]:
#| default_exp oauth

# OAuth
> Basic scaffolding for handling OAuth

- eval: false
- skip_exec: true

This is not yet thoroughly tested. See the [docs page](https://docs.fastht.ml/explains/oauth.html) for an explanation of how to use this.

In [None]:
#| export
from fasthtml.common import *
from oauthlib.oauth2 import WebApplicationClient
from urllib.parse import urlparse, urlencode, parse_qs, quote, unquote
from httpx import get, post
import secrets

In [None]:
from IPython.display import Markdown

In [None]:
#| export
class _AppClient(WebApplicationClient):
    "Shared base for the provider clients: a `WebApplicationClient` that also keeps the `client_secret`"
    def __init__(self, client_id, client_secret, code=None, scope=None, **kwargs):
        # `code` and `scope` are named parameters, so they can never already be present in `kwargs`
        kwargs.update(code=code, scope=scope)
        super().__init__(client_id, **kwargs)
        self.client_secret = client_secret

In [None]:
#| export
class GoogleAppClient(_AppClient):
    "A `WebApplicationClient` for Google oauth2"
    base_url = "https://accounts.google.com/o/oauth2/v2/auth"
    token_url = "https://www.googleapis.com/oauth2/v4/token"
    info_url = "https://www.googleapis.com/oauth2/v3/userinfo"
    id_key = 'sub'  # key of the user identifier in the `info_url` response

    def __init__(self, client_id, client_secret, code=None, scope=None, **kwargs):
        # Default to OpenID plus the email and profile userinfo scopes
        if not scope:
            userinfo = "https://www.googleapis.com/auth/userinfo"
            scope = ["openid"] + [f"{userinfo}.{part}" for part in ("email", "profile")]
        super().__init__(client_id, client_secret, code=code, scope=scope, **kwargs)

    @classmethod
    def from_file(cls, fname, code=None, scope=None, **kwargs):
        "Create a client from a client-secret JSON file whose `web` section holds `client_id` and `client_secret`"
        web = Path(fname).read_json()['web']
        return cls(web['client_id'], client_secret=web['client_secret'], code=code, scope=scope, **kwargs)

In [None]:
#| export
class GitHubAppClient(_AppClient):
    "A `WebApplicationClient` for GitHub oauth2"
    base_url = "https://github.com/login/oauth/authorize"
    token_url = "https://github.com/login/oauth/access_token"
    info_url = "https://api.github.com/user"
    id_key = 'id'  # key of the user identifier in the `info_url` response

    def __init__(self, client_id, client_secret, code=None, scope=None, **kwargs):
        # Any falsy scope falls back to the single "user" scope
        super().__init__(client_id, client_secret, code=code, scope=scope or "user", **kwargs)

In [None]:
#| export
class HuggingFaceClient(_AppClient):
    "A `WebApplicationClient` for HuggingFace oauth2"

    base_url = "https://huggingface.co/oauth/authorize"
    token_url = "https://huggingface.co/oauth/token"
    info_url = "https://huggingface.co/oauth/userinfo"
    id_key = 'sub'  # key of the user identifier in the `info_url` response

    def __init__(self, client_id, client_secret, code=None, scope=None, state=None, **kwargs):
        # Defaults: the openid/profile scopes, and a random URL-safe `state` when none is given.
        # NOTE(review): the generated state is fixed for the lifetime of this client, and `login_link`
        # reuses it for every link it builds, so by itself it is not per-session CSRF protection.
        scope = scope or ["openid", "profile"]
        state = state or secrets.token_urlsafe(16)
        super().__init__(client_id, client_secret, code=code, scope=scope, state=state, **kwargs)

In [None]:
#| export
class DiscordAppClient(_AppClient):
    "A `WebApplicationClient` for Discord oauth2"
    base_url = "https://discord.com/oauth2/authorize"
    token_url = "https://discord.com/api/oauth2/token"
    revoke_url = "https://discord.com/api/oauth2/token/revoke"
    # Needed by `get_info`/`retr_info`/`OAuth`, which would otherwise fail with AttributeError
    info_url = "https://discord.com/api/users/@me"
    id_key = 'id'

    def __init__(self, client_id, client_secret, is_user=False, perms=0, scope=None, **kwargs):
        """`is_user` sets the login link's `integration_type` to 1 (otherwise 0).
        `perms` is stored on the client but is not currently sent with the login link."""
        if not scope: scope="applications.commands applications.commands.permissions.update identify"
        self.integration_type = 1 if is_user else 0
        self.perms = perms
        super().__init__(client_id, client_secret, scope=scope, **kwargs)

    def login_link(self, redirect_uri=None, scope=None, state=None):
        """Build the Discord authorize URL.

        Takes the same arguments as the generic `login_link`, so `OAuth.login_link` can call it.
        `redirect_uri` and `state` are added only when given, and `scope` defaults to the client's scope.
        Calling it with no arguments produces the same URL as before."""
        scope = scope or self.scope
        # The query parameter is a single space-separated string, so join list/tuple scopes
        if scope is not None and not isinstance(scope, str): scope = ' '.join(scope)
        d = dict(response_type='code', client_id=self.client_id,
                 integration_type=self.integration_type, scope=scope)  # permissions=self.perms, prompt='consent' not sent yet
        if redirect_uri: d['redirect_uri'] = redirect_uri
        if state: d['state'] = state
        return f'{self.base_url}?' + urlencode(d)

    def parse_response(self, code, redirect_uri=None):
        """Exchange `code` for a token at `token_url` and store it on the client.

        Sends a form-encoded body and authenticates with HTTP Basic auth (client id/secret).
        `redirect_uri` is optional so that `retr_info`/`OAuth` can pass one; it is sent only when given.
        Raises the HTTP error from `raise_for_status` on a non-2xx response."""
        headers = {'Content-Type': 'application/x-www-form-urlencoded'}
        data = dict(grant_type='authorization_code', code=code)
        if redirect_uri: data['redirect_uri'] = redirect_uri
        r = post(self.token_url, data=data, headers=headers, auth=(self.client_id, self.client_secret))
        r.raise_for_status()
        self.parse_request_body_response(r.text)

In [None]:
import os
# Path to the Google client-secret JSON. Set GOOGLE_CLIENT_SECRET_FILE instead of hard-coding a personal absolute path.
cli = GoogleAppClient.from_file(os.environ.get('GOOGLE_CLIENT_SECRET_FILE', 'client_secret.json'))

In [None]:
#| export
@patch
def login_link(self:WebApplicationClient, redirect_uri, scope=None, state=None):
    "Get a login link for this client"
    # A falsy `scope`/`state` falls back to what the client was constructed with
    scope = scope or self.scope
    state = state or getattr(self, 'state', None)
    return self.prepare_request_uri(self.base_url, redirect_uri, scope, state=state)

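Generating a login link that sends the user to the OAuth provider is done with `client.login_link(redirect_uri)`, where `redirect_uri` is the URL the provider should send the user back to after they log in.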

It can sometimes be useful to pass state to the OAuth provider, so that when the user returns you can pick up where they left off. This can be done by passing the `state` parameter; it comes back as the `state` query parameter of the redirect URL.

In [None]:
redir = 'http://localhost:8000/redirect'  # local redirect target used by the examples below
cli.login_link(redirect_uri=redir)

'https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=457681028261-5i71skrhb7ko4l8mlug5i0230q980do7.apps.googleusercontent.com&redirect_uri=http%3A%2F%2Flocalhost%3A8000%2Fredirect&scope=openid+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.profile'

In [None]:
def login_md(cli, redirect_uri, scope=None, state=None):
    "Render `cli`'s login link as a clickable Markdown link in the notebook (for testing)"
    link = cli.login_link(redirect_uri, scope, state=state)
    return Markdown(f'[login]({link})')

In [None]:
login_md(cli, redirect_uri=redir, state='test_state')

[login](https://accounts.google.com/o/oauth2/v2/auth?response_type=code&client_id=457681028261-5i71skrhb7ko4l8mlug5i0230q980do7.apps.googleusercontent.com&redirect_uri=http%3A%2F%2Flocalhost%3A8000%2Fredirect&scope=openid+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.profile&state=test_state)

In [None]:
#| export
@patch
def parse_response(self:_AppClient, code, redirect_uri):
    """Get the token from the oauth2 server response.

    Exchanges the authorization `code` at `token_url` and stores the parsed token on the client
    (`self.token`). Raises the HTTP error from `raise_for_status` on a non-2xx response."""
    payload = dict(code=code, redirect_uri=redirect_uri, client_id=self.client_id,
                   client_secret=self.client_secret, grant_type='authorization_code')
    # RFC 6749 §4.1.3 specifies an application/x-www-form-urlencoded body for the token request (not JSON)
    r = post(self.token_url, data=payload)
    r.raise_for_status()
    self.parse_request_body_response(r.text)

In [None]:
#| export
def decode(code_url):
    "Split an OAuth redirect URL into its `code`, its `state` ('' when absent), and the URL before the `?`"
    params = parse_qs(urlparse(code_url).query)
    first = lambda key: params.get(key, [''])[0]
    base, _, _ = code_url.partition('?')
    return first('code'), first('state'), base

In [None]:
# Example redirect URL from a Google login, including the `state` passed to `login_md` above
code_url = 'http://localhost:8000/redirect?state=test_state&code=4%2F0AQlEd8xCOSfc7yjmmylO6BTVgWtAmji4GkfITsWecq0CXlm-8wBRgwNmkDmXQEdOqw0REQ&scope=email+profile+openid+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.profile+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email&authuser=0&hd=answer.ai&prompt=consent'
code, state, redir = decode(code_url)

In [None]:
cli.parse_response(code=code, redirect_uri=redir)  # exchanges the code and stores the token on `cli`
print(state)

test_state


In [None]:
#| export
@patch
def get_info(self:_AppClient, token=None):
    "Get the info for authenticated user"
    # Use the token stored by `parse_response` unless one is passed explicitly
    bearer = token or self.token["access_token"]
    return get(self.info_url, headers={'Authorization': f'Bearer {bearer}'}).json()

In [None]:
info = cli.get_info()
info

{'sub': '100000802623412015452',
 'name': 'Jeremy Howard',
 'given_name': 'Jeremy',
 'family_name': 'Howard',
 'picture': 'https://lh3.googleusercontent.com/a/ACg8ocID3bYiwh1wJNVjvlSUy0dGxvXbNjDt1hdhypQDinDf28DfEA=s96-c',
 'email': 'j@answer.ai',
 'email_verified': True,
 'hd': 'answer.ai'}

In [None]:
#| export
@patch
def retr_info(self:_AppClient, code, redirect_uri):
    "Combines `parse_response` and `get_info`"
    # First exchange the code (this stores the token on `self`), then fetch the profile with that token
    self.parse_response(code, redirect_uri)
    profile = self.get_info()
    return profile

In [None]:
#| export
@patch
def retr_id(self:_AppClient, code, redirect_uri):
    "Call `retr_info` and then return id/subscriber value"
    profile = self.retr_info(code, redirect_uri)
    return profile[self.id_key]

After logging in via the provider, the user will be redirected back to the supplied redirect URL. The request to this URL will contain a `code` parameter, which is used to get an access token and fetch the user's profile information. See [the explanation here](https://docs.fastht.ml/explains/oauth.html) for a worked example. You can either:

- Use `client.retr_info(code, redirect_uri)` to get all the profile information, or
- Use `client.retr_id(code, redirect_uri)` to get just the user's ID.

After either of these calls, you can also access the access token (used to revoke access, for example) with `client.token["access_token"]`.

In [None]:
#| export
class OAuth:
    "Add OAuth login/logout routes and an auth-checking `Beforeware` to a FastHTML `app`, using provider client `cli`"
    def __init__(self, app, cli, skip=None, redir_path='/redirect', logout_path='/logout', login_path='/login'):
        """Register the redirect and logout routes on `app`, plus a `Beforeware` for every other path.

        `skip` lists paths the Beforeware ignores. It defaults to the redirect and login paths, so the
        login flow itself is reachable without a session."""
        if not skip: skip = [redir_path,login_path]
        self.app,self.cli,self.skip,self.redir_path,self.logout_path,self.login_path = app,cli,skip,redir_path,logout_path,login_path

        def before(req, session):
            # The session holds the provider access token; it is checked with the provider on every request
            auth = req.scope['auth'] = session.get('auth')
            if not auth: return RedirectResponse(self.login_path, status_code=303)
            info = AttrDictDefault(cli.get_info(auth))
            if not self._chk_auth(info, session): return RedirectResponse(self.login_path, status_code=303)
        app.before.append(Beforeware(before, skip=skip))

        @app.get(redir_path)
        def redirect(code:str, req, session, state:str=None):
            # NOTE(review): `state` is passed through to `login` but never compared with a value bound to this
            # session, so it gives no CSRF protection here -- verify it in `login` if that matters.
            if not code: return "No code provided!"
            # Use the same redirect URI as `login_link`, so the token exchange matches the authorize request.
            # NOTE(review): `cli` is shared across requests and `retr_info` stores the token on it, so
            # concurrent logins could race between `retr_info` and the `cli.token` read below.
            info = AttrDictDefault(cli.retr_info(code, self.redir_url(req)))
            if not self._chk_auth(info, session): return RedirectResponse(self.login_path, status_code=303)
            session['auth'] = cli.token['access_token']
            return self.login(info, state)

        @app.get(logout_path)
        def logout(session):
            session.pop('auth', None)
            return self.logout(session)

    def redir_url(self, req):
        "Absolute URL of the redirect route, built from the incoming request's scheme and host"
        return f"{req.url.scheme}://{req.url.netloc}{self.redir_path}"
    def login_link(self, req, scope=None, state=None):
        "Provider login link that sends the user back to `redir_url(req)`"
        return self.cli.login_link(self.redir_url(req), scope=scope, state=state)

    def login(self, info, state):
        "Called after a successful login with the user's `info` and returned `state`; the return value is the redirect route's response"
        raise NotImplementedError()
    def logout(self, session):
        "Response sent after `auth` is removed from `session`; by default redirects to the login path"
        return RedirectResponse(self.login_path, status_code=303)
    def chk_auth(self, info, ident, session):
        "Return truthy if the user `ident` (with profile `info`) may use the app; subclasses must implement this"
        raise NotImplementedError()
    def _chk_auth(self, info, session):
        # Responses without the provider's id field are rejected before `chk_auth` is consulted
        ident = info.get(self.cli.id_key)
        return ident and self.chk_auth(info, ident, session)

# Export -

In [None]:
#|hide
import nbdev
nbdev.nbdev_export()