Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 63 additions & 24 deletions src/main/python/mlsearch/api_requester.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@
import html
import random
import collections
import math

# import scholarly

ErrorType = collections.namedtuple("ErrorType", "reason status")


class APIRequest:
"""For handling the different Valid API requests."""

Expand Down Expand Up @@ -77,7 +79,9 @@ def youtube_developer_key(self, developer_key):
if isinstance(developer_key, list):
self._config.YOUTUBE_DEVELOPER_KEY = developer_key
elif isinstance(developer_key, str) and "," in developer_key:
self._config.YOUTUBE_DEVELOPER_KEY = developer_key.strip().split(",")
self._config.YOUTUBE_DEVELOPER_KEY = developer_key.strip().split(
","
)
elif developer_key and isinstance(developer_key, str):
self._config.YOUTUBE_DEVELOPER_KEY.append(developer_key)

Expand Down Expand Up @@ -108,7 +112,9 @@ def _validate_params(self):
"""Validate user input data."""

for item, typ in self.params_model.items():
if item in self.params.keys() and not typ == type(self.params[item]):
if item in self.params.keys() and not typ == type(
self.params[item]
):
raise TypeError(
f"Invalid type for {item}. {typ} is expected but "
f"{type(self.params[item])} is given."
Expand Down Expand Up @@ -140,21 +146,36 @@ def _unescape(self, text):

def _fetch_github(self) -> [Protocol]:
"""Fetch Github Repository"""
item_per_page = self._config.GITHUB_PER_PAGE
github = Github(self._config.GITHUB_ACC_TOKEN, per_page=item_per_page)

github = Github(self._config.GITHUB_ACC_TOKEN)
skip_page = math.floor(self.params["init_idx"] / item_per_page)
total_page = math.ceil(
(self.params["init_idx"] + self.params["count"]) / item_per_page
)
query = "+".join([self.params["query"], self._config.GITHUB_URL])
responses = github.search_repositories(query, "stars", "desc")
results = []

if not self._is_valid_pagination(responses.totalCount):
return

for response in responses[
self.params["init_idx"] : min(
self.params["init_idx"] + self.params["count"], responses.totalCount
)
]:
paginated_responses = list()
for i in range(skip_page + 1, total_page + 1):
paginated_responses.extend(responses.get_page(i))

first_slot_items = item_per_page - (
self.params["init_idx"] % item_per_page
)
end_slot_items = item_per_page - (
(total_page * item_per_page)
- (self.params["count"] + self.params["init_idx"])
)

start_idx = item_per_page - first_slot_items
end_idx = (len(paginated_responses) - item_per_page) + end_slot_items

for response in paginated_responses[start_idx:end_idx]:
data = {
"repository_url": self._unescape(
response.clone_url.replace(".git", "")
Expand Down Expand Up @@ -184,7 +205,9 @@ def _fetch_paperwithcode(self) -> [Protocol]:
url = f"{self._config.PWC_URL}{self.params['query']}"
query_result = requests.get(
url,
auth=HTTPBasicAuth(self._config.PWC_USER_NAME, self._config.PWC_PASSWORD),
auth=HTTPBasicAuth(
self._config.PWC_USER_NAME, self._config.PWC_PASSWORD
),
)

if query_result.status_code == 200:
Expand All @@ -202,7 +225,9 @@ def _fetch_paperwithcode(self) -> [Protocol]:
for item in content:
data = {
"title": self._unescape(item.get("paper_title", None)),
"description": self._unescape(item.get("paper_abstract", None)),
"description": self._unescape(
item.get("paper_abstract", None)
),
"paper_url": self._unescape(item.get("paper_url", None)),
"num_of_implementations": self._unescape(
item.get("number_of_implementations", None)
Expand All @@ -211,7 +236,9 @@ def _fetch_paperwithcode(self) -> [Protocol]:
"paper_conference": self._unescape(
item.get("paper_conference", None)
),
"repository_url": self._unescape(item.get("repository_url", None)),
"repository_url": self._unescape(
item.get("repository_url", None)
),
"repository_name": self._unescape(
item.get("repository_name", None)
),
Expand Down Expand Up @@ -248,14 +275,15 @@ def _fetch_youtube(self, y_next_page_token=None) -> [Protocol]:
user_query = input_query + self._config.YOUTUBE_QUERY_FILTER

sampled_dev_key = None
if len(self._config.YOUTUBE_DEVELOPER_KEY) > 0:
sampled_dev_key = random.choice(self._config.YOUTUBE_DEVELOPER_KEY)

if not sampled_dev_key:
if not len(self._config.YOUTUBE_DEVELOPER_KEY) > 0:
auth_error = ErrorType(
reason="Empty YouTube Developer Key.", status="400"
)
raise HttpError(auth_error, str.encode("YouTube Developer Key Required."))
raise HttpError(
auth_error, str.encode("YouTube Developer Key Required.")
)

sampled_dev_key = random.choice(self._config.YOUTUBE_DEVELOPER_KEY)

youtube = googleapiclient.discovery.build(
self._config.YOUTUBE_SERVICE_NAME,
Expand All @@ -277,15 +305,21 @@ def _fetch_youtube(self, y_next_page_token=None) -> [Protocol]:
if "items" in response and len(response["items"]) > 0:
for item in response["items"]:
# Skip if the video id is null
if not item.get("id", dict({"videoId": None})).get("videoId", None):
if not item.get("id", dict({"videoId": None})).get(
"videoId", None
):
continue

data = {
"video_id": self._unescape(
item.get("id", dict({"videoId": None})).get("videoId", None)
item.get("id", dict({"videoId": None})).get(
"videoId", None
)
),
"title": self._unescape(
item.get("snippet", dict({"title": None})).get("title", None)
item.get("snippet", dict({"title": None})).get(
"title", None
)
),
"description": self._unescape(
item.get("snippet", dict({"description": None})).get(
Expand All @@ -303,9 +337,9 @@ def _fetch_youtube(self, y_next_page_token=None) -> [Protocol]:
)
),
"live_broadcast_content": self._unescape(
item.get("snippet", dict({"liveBroadcastContent": None})).get(
"liveBroadcastContent", None
)
item.get(
"snippet", dict({"liveBroadcastContent": None})
).get("liveBroadcastContent", None)
),
"published_datetime": self._unescape(
item.get("snippet", dict({"publishedAt": None})).get(
Expand Down Expand Up @@ -350,12 +384,17 @@ def fetch_data(self) -> json:
self.data["content"] = "Access rate limitation reached."

if self.params.get("source", "") == "youtube":
if not self._config.YOUTUBE_ORDER in self._config.VALID_YOUTUBE_ORDER:
if (
not self._config.YOUTUBE_ORDER
in self._config.VALID_YOUTUBE_ORDER
):
self.data["response_code"] = 400
self.data["content"] = "Invalid Youtube Query Order."
return self.data
try:
self._fetch_youtube(self.params.get("y_next_page_token", None))
self._fetch_youtube(
self.params.get("y_next_page_token", None)
)
except HttpError as ex:
print(str(ex))
self.data["response_code"] = 400
Expand Down
6 changes: 3 additions & 3 deletions src/main/python/mlsearch/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@ class Config(object):
PWC_USER_NAME = os.environ.get("PWC_USER_NAME") or ""
PWC_PASSWORD = os.environ.get("PWC_PASSWORD") or ""
PWC_URL = (
os.environ.get("PWC_URL") or "https://paperswithcode.com/api/v0/search/?q="
os.environ.get("PWC_URL")
or "https://paperswithcode.com/api/v0/search/?q="
)

# Github configuration
GITHUB_ACC_TOKEN = os.environ.get("GITHUB_ACC_TOKEN") or None
GITHUB_URL = os.environ.get("GITHUB_URL") or "in:readme+in:description"

GITHUB_PER_PAGE = os.environ.get("ITEM_PER_PAGE") or 10
# AIP Source
VALID_API_SOURCE = ["paperwithcode", "github", "coursera", "youtube"]

Expand Down Expand Up @@ -43,4 +44,3 @@ class Config(object):
# "videoCount", # This is for channel only
"viewCount",
]