Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🐛 Source GitHub: Fix MultipleToken rotation logic #34503

Merged
merged 20 commits into from Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions airbyte-integrations/connectors/source-github/metadata.yaml
Expand Up @@ -6,11 +6,11 @@ data:
hosts:
- ${api_url}
connectorBuildOptions:
baseImage: docker.io/airbyte/python-connector-base:1.1.0@sha256:bd98f6505c6764b1b5f99d3aedc23dfc9e9af631a62533f60eb32b1d3dbab20c
baseImage: docker.io/airbyte/python-connector-base:1.2.0@sha256:c22a9d97464b69d6ef01898edf3f8612dc11614f05a84984451dde195f337db9
connectorSubtype: api
connectorType: source
definitionId: ef69ef6e-aa7f-4af1-a01d-ef775033524e
dockerImageTag: 1.5.5
dockerImageTag: 1.5.6
dockerRepository: airbyte/source-github
documentationUrl: https://docs.airbyte.com/integrations/sources/github
githubIssueLabel: source-github
Expand Down
Expand Up @@ -123,14 +123,7 @@ def get_access_token(config: Mapping[str, Any]):
def _get_authenticator(self, config: Mapping[str, Any]):
_, token = self.get_access_token(config)
tokens = [t.strip() for t in token.split(constants.TOKEN_SEPARATOR)]
requests_per_hour = config.get("requests_per_hour")
if requests_per_hour:
return MultipleTokenAuthenticatorWithRateLimiter(
tokens=tokens,
auth_method="token",
requests_per_hour=requests_per_hour,
)
return MultipleTokenAuthenticator(tokens=tokens, auth_method="token")
return MultipleTokenAuthenticatorWithRateLimiter(tokens=tokens)

def _validate_and_transform_config(self, config: MutableMapping[str, Any]) -> MutableMapping[str, Any]:
config = self._ensure_default_values(config)
Expand Down
Expand Up @@ -130,13 +130,6 @@
"description": "List of GitHub repository branches to pull commits for, e.g. `airbytehq/airbyte/master`. If no branches are specified for a repository, the default branch will be pulled.",
"order": 4,
"pattern_descriptor": "org/repo/branch1 org/repo/branch2"
},
"requests_per_hour": {
"type": "integer",
"title": "Max requests per hour",
"description": "The GitHub API allows for a maximum of 5000 requests per hour (15000 for Github Enterprise). You can specify a lower value to limit your use of the API quota.",
"minimum": 1,
"order": 5
}
}
},
Expand Down
Expand Up @@ -25,7 +25,7 @@
get_query_pull_requests,
get_query_reviews,
)
from .utils import getter
from .utils import GitHubAPILimitException, getter


class GithubStreamABC(HttpStream, ABC):
Expand All @@ -38,6 +38,8 @@ class GithubStreamABC(HttpStream, ABC):
stream_base_params = {}

def __init__(self, api_url: str = "https://api.github.com", access_token_type: str = "", **kwargs):
if kwargs.get("authenticator"):
kwargs["authenticator"].max_time = self.max_time
super().__init__(**kwargs)

self.access_token_type = access_token_type
Expand Down Expand Up @@ -126,16 +128,25 @@ def backoff_time(self, response: requests.Response) -> Optional[float]:
# we again could have 5000 per another hour.

min_backoff_time = 60.0

retry_after = response.headers.get("Retry-After")
if retry_after is not None:
return max(float(retry_after), min_backoff_time)
backoff_time_in_seconds = max(float(retry_after), min_backoff_time)
return self.get_waiting_time(backoff_time_in_seconds)

reset_time = response.headers.get("X-RateLimit-Reset")
if reset_time:
return max(float(reset_time) - time.time(), min_backoff_time)
backoff_time_in_seconds = max(float(reset_time) - time.time(), min_backoff_time)
return self.get_waiting_time(backoff_time_in_seconds)

def get_waiting_time(self, backoff_time_in_seconds):
    """Return the backoff to apply, or rotate tokens when the wait is too long.

    When the computed backoff stays under ``max_time`` it is returned
    unchanged; otherwise the authenticator is asked to switch to its next
    token and a minimal 1-second backoff is returned so the retry happens
    almost immediately with the fresh token.
    """
    if backoff_time_in_seconds >= self.max_time:
        # Waiting out the full window is too expensive — the next request
        # will instead be issued with the next token in the rotation.
        self._session.auth.update_token()
        return 1
    return backoff_time_in_seconds

def check_graphql_rate_limited(self, response_json) -> bool:
@staticmethod
def check_graphql_rate_limited(response_json: dict) -> bool:
errors = response_json.get("errors")
if errors:
for error in errors:
Expand Down Expand Up @@ -203,6 +214,8 @@ def read_records(self, stream_slice: Mapping[str, Any] = None, **kwargs) -> Iter
raise e

self.logger.warning(error_msg)
except GitHubAPILimitException:
self.logger.warning("Limits for all provided tokens are reached, please try again later")


class GithubStream(GithubStreamABC):
Expand Down
132 changes: 98 additions & 34 deletions airbyte-integrations/connectors/source-github/source_github/utils.py
Expand Up @@ -2,14 +2,16 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

import logging
import time
from dataclasses import dataclass, field
from itertools import cycle
from types import SimpleNamespace
from typing import Any, List, Mapping

import pendulum
import requests
from airbyte_cdk.models import SyncMode
from airbyte_cdk.sources.streams import Stream
from airbyte_cdk.sources.streams.http.requests_native_auth import TokenAuthenticator
from airbyte_cdk.sources.streams.http.requests_native_auth.abstract_token import AbstractHeaderAuthenticator


Expand All @@ -32,6 +34,18 @@ def read_full_refresh(stream_instance: Stream):
yield record


class GitHubAPILimitException(Exception):
    """Raised when the rate limits of every provided token are exhausted and
    waiting for a reset would exceed the allowed maximum wait time."""


@dataclass
class Token:
    """Per-token rate-limit state, tracked separately for the REST ("core")
    and GraphQL GitHub APIs.

    ``count_*`` is the number of requests still allowed in the current
    window; ``reset_at_*`` is when that window resets.
    """

    count_rest: int = 5000
    count_graphql: int = 5000
    # default_factory ensures each Token gets its own "now" at instantiation
    # time; a plain `pendulum.now()` default would be evaluated once at class
    # definition and shared (stale) across every instance.
    reset_at_rest: pendulum.DateTime = field(default_factory=pendulum.now)
    reset_at_graphql: pendulum.DateTime = field(default_factory=pendulum.now)


class MultipleTokenAuthenticatorWithRateLimiter(AbstractHeaderAuthenticator):
"""
Each token in the cycle is checked against the rate limiter.
Expand All @@ -40,49 +54,99 @@ class MultipleTokenAuthenticatorWithRateLimiter(AbstractHeaderAuthenticator):
the first token becomes available again.
"""

DURATION = 3600 # seconds
DURATION = pendulum.duration(seconds=3600) # Duration at which the current rate limit window resets

def __init__(self, tokens: List[str], requests_per_hour: int, auth_method: str = "Bearer", auth_header: str = "Authorization"):
def __init__(self, tokens: List[str], auth_method: str = "token", auth_header: str = "Authorization"):
self._auth_method = auth_method
self._auth_header = auth_header
now = time.time()
self._requests_per_hour = requests_per_hour
self._tokens = {t: SimpleNamespace(count=self._requests_per_hour, update_at=now) for t in tokens}
self._tokens = {t: Token() for t in tokens}
self.check_all_tokens()
self._tokens_iter = cycle(self._tokens)
self._active_token = next(self._tokens_iter)
self._max_time = 60 * 10 # 10 minutes as default

@property
def auth_header(self) -> str:
    """Name of the HTTP header carrying the credential ("Authorization" by default)."""
    return self._auth_header

def get_auth_header(self) -> Mapping[str, Any]:
    """Build the header mapping to attach to outgoing HTTP requests."""
    header_name = self.auth_header
    return {header_name: self.token} if header_name else {}

def __call__(self, request):
    """Attach authentication headers to *request*, rotating through tokens
    until one with remaining quota for the requested API kind is found."""
    # GraphQL and REST quotas are tracked on separate Token attributes.
    if "graphql" in request.path_url:
        count_attr, reset_attr = "count_graphql", "reset_at_graphql"
    else:
        count_attr, reset_attr = "count_rest", "reset_at_rest"

    # process_token returns False after rotating to the next token, so keep
    # retrying until the active token has quota left (or it raises).
    while not self.process_token(self._tokens[self.current_active_token], count_attr, reset_attr):
        pass

    request.headers.update(self.get_auth_header())
    return request

@property
def current_active_token(self) -> str:
    """The raw token string currently selected by the round-robin cycle."""
    return self._active_token

def update_token(self) -> None:
    """Advance the round-robin cycle so the next request uses the next token."""
    self._active_token = next(self._tokens_iter)

@property
def token(self) -> str:
while True:
token = next(self._tokens_iter)
if self._check_token(token):
return f"{self._auth_method} {token}"

def _check_token(self, token: str):
token = self.current_active_token
return f"{self._auth_method} {token}"

@property
def max_time(self) -> int:
    """Maximum number of seconds the authenticator may wait for a limit reset."""
    return self._max_time

@max_time.setter
def max_time(self, value: int) -> None:
    self._max_time = value

def _check_token_limits(self, token: str):
"""check that token is not limited"""
self._refill()
if self._sleep():
self._refill()
if self._tokens[token].count > 0:
self._tokens[token].count -= 1
return True
headers = {"Accept": "application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28"}
rate_limit_info = (
requests.get(
"https://api.github.com/rate_limit", headers=headers, auth=TokenAuthenticator(token, auth_method=self._auth_method)
)
.json()
.get("resources")
)
token_info = self._tokens[token]
remaining_info_core = rate_limit_info.get("core")
token_info.count_rest, token_info.reset_at_rest = remaining_info_core.get("remaining"), pendulum.from_timestamp(
remaining_info_core.get("reset")
)

remaining_info_graphql = rate_limit_info.get("graphql")
token_info.count_graphql, token_info.reset_at_graphql = remaining_info_graphql.get("remaining"), pendulum.from_timestamp(
remaining_info_graphql.get("reset")
)

def _refill(self):
    """Reset the request counter of every token whose window has elapsed."""
    current_time = time.time()
    for _token, state in self._tokens.items():
        # A token regains its full per-hour allowance once DURATION has
        # passed since its counter was last refilled.
        if current_time - state.update_at >= self.DURATION:
            state.update_at = current_time
            state.count = self._requests_per_hour

def _sleep(self):
    """Block until the rate-limit window reopens, but only when every token is spent."""
    current_time = time.time()
    total_remaining = sum(state.count for state in self._tokens.values())
    if total_remaining == 0:
        # Wait out the remainder of the window measured from the token that
        # was refilled the longest ago.
        oldest_update = min(state.update_at for state in self._tokens.values())
        sleep_time = self.DURATION - (current_time - oldest_update)
        logging.warning("Sleeping for %.1f seconds to enforce the limit of %d requests per hour.", sleep_time, self._requests_per_hour)
        time.sleep(sleep_time)
def check_all_tokens(self):
    """Refresh the cached rate-limit counters for every configured token."""
    for candidate in self._tokens:
        self._check_token_limits(candidate)

def process_token(self, current_token, count_attr, reset_attr):
    """Consume one unit of quota from ``current_token`` for the given API kind.

    :param current_token: the Token state object for the active token.
    :param count_attr: attribute name of the remaining-request counter
        ("count_rest" or "count_graphql").
    :param reset_attr: attribute name of the matching reset timestamp.
    :return: True when the request may proceed with the active token; False
        when the caller should retry (after a token rotation or a wait).
    :raises GitHubAPILimitException: when every token is exhausted and the
        earliest reset is further away than ``max_time``.
    """
    if getattr(current_token, count_attr) > 0:
        setattr(current_token, count_attr, getattr(current_token, count_attr) - 1)
        return True
    elif all(getattr(x, count_attr) == 0 for x in self._tokens.values()):
        min_time_to_wait = min((getattr(x, reset_attr) - pendulum.now()).in_seconds() for x in self._tokens.values())
        if min_time_to_wait < self.max_time:
            # The reset timestamp may already be in the past (stale counters),
            # which would make the delta negative — never sleep a negative amount.
            time.sleep(max(min_time_to_wait, 0))
            self.check_all_tokens()
        else:
            raise GitHubAPILimitException(f"Rate limits for all tokens ({count_attr}) were reached")
    else:
        # Current token is spent but others still have quota: rotate and retry.
        self.update_token()
    return False
Expand Up @@ -2,4 +2,28 @@

import os

import pytest
import responses

os.environ["REQUEST_CACHE_PATH"] = "REQUEST_CACHE_PATH"


@pytest.fixture(name="rate_limit_mock_response")
def rate_limit_mock_response():
    # Canned payload for GitHub's GET /rate_limit endpoint: both the REST
    # ("core") and GraphQL quotas are fully available (5000 remaining) with a
    # reset timestamp far in the future, so tests exercising the token
    # authenticator never hit its sleep/rotate paths.
    rate_limit_response = {
        "resources": {
            "core": {
                "limit": 5000,
                "used": 0,
                "remaining": 5000,
                "reset": 4070908800
            },
            "graphql": {
                "limit": 5000,
                "used": 0,
                "remaining": 5000,
                "reset": 4070908800
            }
        }
    }
    # Register the mock with the `responses` library; activated by tests that
    # request this fixture (no value is returned — registration is the effect).
    responses.add(responses.GET, "https://api.github.com/rate_limit", json=rate_limit_response)