Skip to content

Commit

Permalink
Extend low code OAuthAuthenticator with token refresh capabilities (#…
Browse files Browse the repository at this point in the history
…26966)

* wip

* Automated Commit - Formatting Changes

* add documentation

* tests and fixes

* fix tests

* more documentation

* revert

* changes as discussed

* fix case

* add docstring

* add details to schema

* format

* fix bug

---------

Co-authored-by: flash1293 <flash1293@users.noreply.github.com>
  • Loading branch information
Joe Reuter and flash1293 committed Jun 7, 2023
1 parent 6defead commit b34fb00
Show file tree
Hide file tree
Showing 7 changed files with 195 additions and 70 deletions.
11 changes: 11 additions & 0 deletions airbyte-cdk/python/airbyte_cdk/sources/declarative/auth/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from airbyte_cdk.sources.declarative.interpolation.interpolated_mapping import InterpolatedMapping
from airbyte_cdk.sources.declarative.interpolation.interpolated_string import InterpolatedString
from airbyte_cdk.sources.streams.http.requests_native_auth.abstract_oauth import AbstractOauth2Authenticator
from airbyte_cdk.sources.streams.http.requests_native_auth.oauth import SingleUseRefreshTokenOauth2Authenticator


@dataclass
Expand Down Expand Up @@ -133,3 +134,13 @@ def access_token(self) -> str:
@access_token.setter
def access_token(self, value: str):
self._access_token = value


@dataclass
class DeclarativeSingleUseRefreshTokenOauth2Authenticator(SingleUseRefreshTokenOauth2Authenticator, DeclarativeAuthenticator):
"""
Declarative version of SingleUseRefreshTokenOauth2Authenticator which can be used in declarative connectors.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,46 @@ definitions:
type: string
examples:
- "%Y-%m-%d %H:%M:%S.%f+00:00"
refresh_token_updater:
title: Token Updater
description: When the token updater is defined, new refresh tokens, access tokens and the access token expiry date are written back from the authentication response to the config object. This is important if the refresh token can only used once.
properties:
refresh_token_name:
title: Refresh Token Property Name
description: The name of the property which contains the updated refresh token in the response from the token refresh endpoint.
type: string
default: "refresh_token"
examples:
- "refresh_token"
access_token_config_path:
title: Config Path To Access Token
description: Config path to the access token. Make sure the field actually exists in the config.
type: array
items:
type: string
default: ["credentials", "access_token"]
examples:
- ["credentials", "access_token"]
- ["access_token"]
refresh_token_config_path:
title: Config Path To Refresh Token
description: Config path to the access token. Make sure the field actually exists in the config.
type: array
items:
type: string
default: ["credentials", "refresh_token"]
examples:
- ["credentials", "refresh_token"]
- ["refresh_token"]
token_expiry_date_config_path:
title: Config Path To Expiry Date
description: Config path to the expiry date. Make sure actually exists in the config.
type: array
items:
type: string
default: ["credentials", "token_expiry_date"]
examples:
- ["credentials", "token_expiry_date"]
$parameters:
type: object
additionalProperties: true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,33 @@ class Config:
parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")


class RefreshTokenUpdater(BaseModel):
refresh_token_name: Optional[str] = Field(
"refresh_token",
description="The name of the property which contains the updated refresh token in the response from the token refresh endpoint.",
examples=["refresh_token"],
title="Refresh Token Property Name",
)
access_token_config_path: Optional[List[str]] = Field(
["credentials", "access_token"],
description="Config path to the access token. Make sure the field actually exists in the config.",
examples=[["credentials", "access_token"], ["access_token"]],
title="Config Path To Access Token",
)
refresh_token_config_path: Optional[List[str]] = Field(
["credentials", "refresh_token"],
description="Config path to the access token. Make sure the field actually exists in the config.",
examples=[["credentials", "refresh_token"], ["refresh_token"]],
title="Config Path To Refresh Token",
)
token_expiry_date_config_path: Optional[List[str]] = Field(
["credentials", "token_expiry_date"],
description="Config path to the expiry date. Make sure actually exists in the config.",
examples=[["credentials", "token_expiry_date"]],
title="Config Path To Expiry Date",
)


class OAuthAuthenticator(BaseModel):
type: Literal["OAuthAuthenticator"]
client_id: str = Field(
Expand Down Expand Up @@ -340,6 +367,11 @@ class OAuthAuthenticator(BaseModel):
examples=["%Y-%m-%d %H:%M:%S.%f+00:00"],
title="Token Expiry Date Format",
)
refresh_token_updater: Optional[RefreshTokenUpdater] = Field(
None,
description="When the token updater is defined, new refresh tokens, access tokens and the access token expiry date are written back from the authentication response to the config object. This is important if the refresh token can only used once.",
title="Token Updater",
)
parameters: Optional[Dict[str, Any]] = Field(None, alias="$parameters")


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
import re
from typing import Any, Callable, List, Literal, Mapping, Optional, Type, Union, get_args, get_origin, get_type_hints

import dpath
from airbyte_cdk.sources.declarative.auth import DeclarativeOauth2Authenticator
from airbyte_cdk.sources.declarative.auth.declarative_authenticator import NoAuth
from airbyte_cdk.sources.declarative.auth.oauth import DeclarativeSingleUseRefreshTokenOauth2Authenticator
from airbyte_cdk.sources.declarative.auth.token import (
ApiKeyAuthenticator,
BasicHttpAuthenticator,
Expand All @@ -24,6 +26,7 @@
from airbyte_cdk.sources.declarative.extractors import DpathExtractor, RecordFilter, RecordSelector
from airbyte_cdk.sources.declarative.incremental import DatetimeBasedCursor
from airbyte_cdk.sources.declarative.interpolation import InterpolatedString
from airbyte_cdk.sources.declarative.interpolation.interpolated_mapping import InterpolatedMapping
from airbyte_cdk.sources.declarative.models.declarative_component_schema import AddedFieldDefinition as AddedFieldDefinitionModel
from airbyte_cdk.sources.declarative.models.declarative_component_schema import AddFields as AddFieldsModel
from airbyte_cdk.sources.declarative.models.declarative_component_schema import ApiKeyAuthenticator as ApiKeyAuthenticatorModel
Expand Down Expand Up @@ -659,6 +662,23 @@ def create_no_pagination(model: NoPaginationModel, config: Config, **kwargs) ->

@staticmethod
def create_oauth_authenticator(model: OAuthAuthenticatorModel, config: Config, **kwargs) -> DeclarativeOauth2Authenticator:
if model.refresh_token_updater:
return DeclarativeSingleUseRefreshTokenOauth2Authenticator(
config,
InterpolatedString.create(model.token_refresh_endpoint, parameters=model.parameters).eval(config),
access_token_name=InterpolatedString.create(model.access_token_name, parameters=model.parameters).eval(config),
refresh_token_name=model.refresh_token_updater.refresh_token_name,
expires_in_name=InterpolatedString.create(model.expires_in_name, parameters=model.parameters).eval(config),
client_id=InterpolatedString.create(model.client_id, parameters=model.parameters).eval(config),
client_secret=InterpolatedString.create(model.client_secret, parameters=model.parameters).eval(config),
access_token_config_path=model.refresh_token_updater.access_token_config_path,
refresh_token_config_path=model.refresh_token_updater.refresh_token_config_path,
token_expiry_date_config_path=model.refresh_token_updater.token_expiry_date_config_path,
grant_type=InterpolatedString.create(model.grant_type, parameters=model.parameters).eval(config),
refresh_request_body=InterpolatedMapping(model.refresh_request_body or {}, parameters=model.parameters).eval(config),
scopes=model.scopes,
token_expiry_date_format=model.token_expiry_date_format,
)
return DeclarativeOauth2Authenticator(
access_token_name=model.access_token_name,
client_id=model.client_id,
Expand All @@ -685,8 +705,8 @@ def create_single_use_refresh_token_oauth_authenticator(
access_token_name=model.access_token_name,
refresh_token_name=model.refresh_token_name,
expires_in_name=model.expires_in_name,
client_id_config_path=model.client_id_config_path,
client_secret_config_path=model.client_secret_config_path,
client_id=dpath.util.get(config, model.client_id_config_path),
client_secret=dpath.util.get(config, model.client_secret_config_path),
access_token_config_path=model.access_token_config_path,
refresh_token_config_path=model.refresh_token_config_path,
token_expiry_date_config_path=model.token_expiry_date_config_path,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

from typing import Any, List, Mapping, Sequence, Tuple, Union
from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union

import dpath
import pendulum
Expand Down Expand Up @@ -109,11 +109,12 @@ def __init__(
refresh_token_name: str = "refresh_token",
refresh_request_body: Mapping[str, Any] = None,
grant_type: str = "refresh_token",
client_id_config_path: Sequence[str] = ("credentials", "client_id"),
client_secret_config_path: Sequence[str] = ("credentials", "client_secret"),
client_id: Optional[str] = None,
client_secret: Optional[str] = None,
access_token_config_path: Sequence[str] = ("credentials", "access_token"),
refresh_token_config_path: Sequence[str] = ("credentials", "refresh_token"),
token_expiry_date_config_path: Sequence[str] = ("credentials", "token_expiry_date"),
token_expiry_date_format: Optional[str] = None,
):
"""
Expand All @@ -126,20 +127,23 @@ def __init__(
refresh_token_name (str, optional): Name of the name of the refresh token field, used to parse the refresh token response. Defaults to "refresh_token".
refresh_request_body (Mapping[str, Any], optional): Custom key value pair that will be added to the refresh token request body. Defaults to None.
grant_type (str, optional): OAuth grant type. Defaults to "refresh_token".
client_id_config_path (Sequence[str]): Dpath to the client_id field in the connector configuration. Defaults to ("credentials", "client_id").
client_secret_config_path (Sequence[str]): Dpath to the client_secret field in the connector configuration. Defaults to ("credentials", "client_secret").
client_id (Optional[str]): The client id to authenticate. If not specified, defaults to credentials.client_id in the config object.
client_secret (Optional[str]): The client secret to authenticate. If not specified, defaults to credentials.client_secret in the config object.
access_token_config_path (Sequence[str]): Dpath to the access_token field in the connector configuration. Defaults to ("credentials", "access_token").
refresh_token_config_path (Sequence[str]): Dpath to the refresh_token field in the connector configuration. Defaults to ("credentials", "refresh_token").
token_expiry_date_config_path (Sequence[str]): Dpath to the token_expiry_date field in the connector configuration. Defaults to ("credentials", "token_expiry_date").
token_expiry_date_format (Optional[str]): Date format of the token expiry date field (set by expires_in_name). If not specified the token expiry date is interpreted as number of seconds until expiration.
"""
self._client_id_config_path = client_id_config_path
self._client_secret_config_path = client_secret_config_path
self._client_id = client_id if client_id is not None else dpath.util.get(connector_config, ("credentials", "client_id"))
self._client_secret = (
client_secret if client_secret is not None else dpath.util.get(connector_config, ("credentials", "client_secret"))
)
self._access_token_config_path = access_token_config_path
self._refresh_token_config_path = refresh_token_config_path
self._token_expiry_date_config_path = token_expiry_date_config_path
self._token_expiry_date_format = token_expiry_date_format
self._refresh_token_name = refresh_token_name
self._connector_config = connector_config
self._validate_connector_config()
super().__init__(
token_refresh_endpoint,
self.get_client_id(),
Expand All @@ -151,69 +155,49 @@ def __init__(
expires_in_name=expires_in_name,
refresh_request_body=refresh_request_body,
grant_type=grant_type,
token_expiry_date_format=token_expiry_date_format,
)

def _validate_connector_config(self):
"""Validates the defined getters for configuration values are returning values.
Raises:
ValueError: Raised if the defined getters are not returning a value.
"""
try:
assert self.access_token
except KeyError:
raise ValueError(
f"This authenticator expects a value under the {self._access_token_config_path} field path. Please check your configuration structure or change the access_token_config_path value at initialization of this authenticator."
)
for field_path, getter, parameter_name in [
(self._client_id_config_path, self.get_client_id, "client_id_config_path"),
(self._client_secret_config_path, self.get_client_secret, "client_secret_config_path"),
(self._refresh_token_config_path, self.get_refresh_token, "refresh_token_config_path"),
(self._token_expiry_date_config_path, self.get_token_expiry_date, "token_expiry_date_config_path"),
]:
try:
assert getter()
except KeyError:
raise ValueError(
f"This authenticator expects a value under the {field_path} field path. Please check your configuration structure or change the {parameter_name} value at initialization of this authenticator."
)

def get_refresh_token_name(self) -> str:
return self._refresh_token_name

def get_client_id(self) -> str:
return dpath.util.get(self._connector_config, self._client_id_config_path)
return self._client_id

def get_client_secret(self) -> str:
return dpath.util.get(self._connector_config, self._client_secret_config_path)
return self._client_secret

@property
def access_token(self) -> str:
return dpath.util.get(self._connector_config, self._access_token_config_path)
return dpath.util.get(self._connector_config, self._access_token_config_path, default="")

@access_token.setter
def access_token(self, new_access_token: str):
dpath.util.set(self._connector_config, self._access_token_config_path, new_access_token)
dpath.util.new(self._connector_config, self._access_token_config_path, new_access_token)

def get_refresh_token(self) -> str:
return dpath.util.get(self._connector_config, self._refresh_token_config_path)
return dpath.util.get(self._connector_config, self._refresh_token_config_path, default="")

def set_refresh_token(self, new_refresh_token: str):
dpath.util.set(self._connector_config, self._refresh_token_config_path, new_refresh_token)
dpath.util.new(self._connector_config, self._refresh_token_config_path, new_refresh_token)

def get_token_expiry_date(self) -> pendulum.DateTime:
return pendulum.parse(dpath.util.get(self._connector_config, self._token_expiry_date_config_path))
expiry_date = dpath.util.get(self._connector_config, self._token_expiry_date_config_path, default="")
return pendulum.now().subtract(days=1) if expiry_date == "" else pendulum.parse(expiry_date)

def set_token_expiry_date(self, new_token_expiry_date):
dpath.util.set(self._connector_config, self._token_expiry_date_config_path, str(new_token_expiry_date))
dpath.util.new(self._connector_config, self._token_expiry_date_config_path, str(new_token_expiry_date))

def token_has_expired(self) -> bool:
"""Returns True if the token is expired"""
return pendulum.now("UTC") > self.get_token_expiry_date()

@staticmethod
def get_new_token_expiry_date(access_token_expires_in: int):
return pendulum.now("UTC").add(seconds=access_token_expires_in)
def get_new_token_expiry_date(access_token_expires_in: str, token_expiry_date_format: str = None) -> pendulum.DateTime:
if token_expiry_date_format:
return pendulum.from_format(access_token_expires_in, token_expiry_date_format)
else:
return pendulum.now("UTC").add(seconds=int(access_token_expires_in))

def get_access_token(self) -> str:
"""Retrieve new access and refresh token if the access token has expired.
Expand All @@ -223,17 +207,17 @@ def get_access_token(self) -> str:
"""
if self.token_has_expired():
new_access_token, access_token_expires_in, new_refresh_token = self.refresh_access_token()
new_token_expiry_date = self.get_new_token_expiry_date(access_token_expires_in)
new_token_expiry_date = self.get_new_token_expiry_date(access_token_expires_in, self._token_expiry_date_format)
self.access_token = new_access_token
self.set_refresh_token(new_refresh_token)
self.set_token_expiry_date(new_token_expiry_date)
emit_configuration_as_airbyte_control_message(self._connector_config)
return self.access_token

def refresh_access_token(self) -> Tuple[str, int, str]:
def refresh_access_token(self) -> Tuple[str, str, str]:
response_json = self._get_refresh_access_token_response()
return (
response_json[self.get_access_token_name()],
int(response_json[self.get_expires_in_name()]),
response_json[self.get_expires_in_name()],
response_json[self.get_refresh_token_name()],
)

0 comments on commit b34fb00

Please sign in to comment.