In [1]:
import contextlib
import httpx
import json

import pyspark.sql.functions as F
import pyspark.sql.types as T

from requests.structures import CaseInsensitiveDict
from tqdm import tqdm
from urllib.parse import parse_qs, urlparse, urlunparse

from pyspark.sql import Row

from authlib.integrations.httpx_client import OAuth2Client
from authlib.oauth2.rfc7523 import ClientSecretJWT


In [2]:
def _build_oauth_client(auth: dict = None) -> httpx.Client:
    """Create an OAuth2 client that authenticates with ``client_secret_jwt``.

    Args:
        auth: Auth specification. Must hold non-empty ``token-url``,
            ``client-id`` and ``client-secret`` entries.

    Returns:
        An ``OAuth2Client`` on which ``fetch_token`` has already been called
        against the configured ``token-url``.

    Raises:
        ValueError: If ``auth`` is ``None`` or a required entry is missing/empty.
    """
    if auth is None:
        raise ValueError("Please provide a valid auth specification")

    # Keys are validated in this order so the first missing one is reported.
    spec = {}
    for key in ("token-url", "client-id", "client-secret"):
        value = auth.get(key)
        if not value:
            raise ValueError(f"Please provide a {key} in the auth specification")
        spec[key] = value

    oauth_client = OAuth2Client(
        client_id=spec["client-id"],
        client_secret=spec["client-secret"],
        token_endpoint_auth_method="client_secret_jwt",
    )
    oauth_client.register_client_auth_method(
        ClientSecretJWT(token_endpoint=spec["token-url"])
    )
    # Fetch the token eagerly so the returned client is ready to use.
    oauth_client.fetch_token(spec["token-url"])

    return oauth_client


In [3]:
def _build_basic_auth_client(auth: dict = None) -> httpx.Client:
    """Create an ``httpx.Client`` that uses HTTP basic authentication.

    Args:
        auth: Auth specification. ``client-id`` is used as the username and
            ``client-secret`` as the password; both must be non-empty.

    Returns:
        A client whose transport is built with ``retries=5``.

    Raises:
        ValueError: If ``auth`` is ``None`` or a required entry is missing/empty.
    """
    if auth is None:
        raise ValueError("Please provide a valid auth specification")

    # Keys are validated in this order so the first missing one is reported.
    credentials = {}
    for key in ("client-id", "client-secret"):
        value = auth.get(key)
        if not value:
            raise ValueError(f"Please provide a {key} in the auth specification")
        credentials[key] = value

    return httpx.Client(
        transport=httpx.HTTPTransport(retries=5),
        auth=httpx.BasicAuth(
            username=credentials["client-id"],
            password=credentials["client-secret"],
        ),
    )


In [4]:
def _build_no_auth_client() -> httpx.Client:
    """Create an unauthenticated ``httpx.Client``.

    The client's transport is built with ``retries=5``.
    """
    return httpx.Client(transport=httpx.HTTPTransport(retries=5))


In [5]:
def _build_client(auth: dict = None) -> httpx.Client:
    """Build the HTTP client selected by ``auth["type"]``.

    The type is stringified and upper-cased before matching, so the lookup is
    case-insensitive. Supported values are ``NO_AUTH``, ``BASIC`` and
    ``OAUTH2``.

    Args:
        auth: Auth specification; the full dict is forwarded to the BASIC and
            OAUTH2 builders.

    Raises:
        ValueError: If ``auth`` is ``None`` or the type is not supported.
    """
    if auth is None:
        raise ValueError("Please specify an authentication type")

    requested = str(auth.get("type")).upper()

    # Lambdas defer the builder call until the matching entry is selected.
    factories = {
        "NO_AUTH": lambda: _build_no_auth_client(),
        "BASIC": lambda: _build_basic_auth_client(auth),
        "OAUTH2": lambda: _build_oauth_client(auth),
    }

    factory = factories.get(requested)
    if factory is None:
        msg = f"Expected auth['type'] in ['NO_AUTH', 'BASIC', 'OAUTH2']. Got: {requested}"
        raise ValueError(msg)

    return factory()


In [6]:
def _split_params_from_url(url: str) -> tuple[str, dict]:
    parsed_url = urlparse(url)
    stripped_params = parse_qs(parsed_url.query)
    stripped_url = urlunparse(
        (
            parsed_url.scheme,
            parsed_url.netloc,
            parsed_url.path,
            "",
            "",
            parsed_url.fragment,
        )
    )

    return stripped_url, stripped_params

In [7]:
def _update_url_and_params(url: str, params: dict | None = None) -> tuple[str, dict]:
    """Strip the query string from ``url`` and merge it with ``params``.

    Args:
        url: URL that may already carry a query string.
        params: Extra query parameters. On key collisions these override the
            values parsed from ``url``. ``None`` (the default) or an empty
            mapping leaves the parsed parameters untouched.

    Returns:
        A ``(stripped_url, merged_params)`` pair.
    """
    stripped_url, stripped_params = _split_params_from_url(url)

    # ``dict.update(None)`` raises TypeError, so skip the merge when there is
    # nothing to add (this is what makes the ``None`` default usable).
    if params:
        stripped_params.update(params)

    return stripped_url, stripped_params

In [8]:
def _add_data_for_get_request(data: str | dict) -> dict:
    kwargs = {}

    if data is None:
        return kwargs

    if isinstance(data, dict):
        kwargs["data"] = json.dumps(data)
        return kwargs

    elif isinstance(data, str):
        kwargs["data"] = data
        return kwargs

    type_error_msg = f"Expected data to be either str or dict. Got {type(data)}"
    raise TypeError(type_error_msg)


In [9]:
def _add_data_for_post_request(data: str | dict) -> dict:
    kwargs = {}

    if data is None:
        return kwargs

    if isinstance(data, dict):
        kwargs["json"] = data
        return kwargs

    if isinstance(data, str):
        kwargs["data"] = data
        return kwargs

    type_error_msg = f"Expected data to be either str or dict. Got {type(data)}"
    raise TypeError(type_error_msg)


In [10]:
def _prepare_request(
    url: str,
    headers: dict,
    params: dict,
    options: dict,
    data: str | dict | None,
    method: str,
) -> dict:
    """Assemble the keyword arguments for an httpx request call.

    Args:
        url: Target URL.
        headers: Request headers.
        params: Query parameters.
        options: Extra keyword arguments merged on top of the base ones, so
            they take precedence on key collisions.
        data: Optional request body (``str`` or ``dict``).
        method: HTTP method name; matched case-insensitively.

    Returns:
        A dict holding ``url``, ``headers``, ``params``, ``method``, every
        entry of ``options`` and, when ``data`` is given, the body kwargs.

    Raises:
        TypeError: If ``data`` is neither ``None``, ``str`` nor ``dict``.
    """
    kwargs = {"url": url, "headers": headers, "params": params, "method": method}
    kwargs.update(options)

    # GET keeps its own body encoding. Every other method (POST, PUT, PATCH,
    # DELETE, ...) shares the POST encoding, so a supplied body is no longer
    # silently dropped for methods other than GET/POST.
    if method.upper() == "GET":
        kwargs.update(_add_data_for_get_request(data))
    else:
        kwargs.update(_add_data_for_post_request(data))

    return kwargs


In [11]:
def _batch_request(
    client: httpx.Client,
    url: str,
    headers: dict,
    params: dict,
    options: dict,
    data: str | dict,
    method: str,
) -> httpx.Response:
    """Send one request through ``client`` and return its response.

    The request kwargs are assembled by ``_prepare_request`` and passed
    unchanged to ``client.request``.
    """
    request_kwargs = _prepare_request(url, headers, params, options, data, method)
    response = client.request(**request_kwargs)
    return response


In [12]:
def _parse_api_response(res: httpx.Response) -> list[dict]:
    """Convert a response into a list of records.

    The body is decoded as JSON when possible. If decoding fails, or the
    decoded value is falsy (e.g. ``null``, ``{}``, ``[]``), the raw text is
    wrapped as ``{"payload": <text>}`` instead. A non-list result is wrapped
    in a single-element list.
    """
    try:
        parsed = res.json()
    except json.JSONDecodeError:
        parsed = None

    # Fall back to the raw body for undecodable or falsy JSON.
    if not parsed:
        parsed = {"payload": res.text}

    return parsed if isinstance(parsed, list) else [parsed]


In [13]:
def batch_request(
    url: str,
    headers: dict | None = None,
    auth: dict | None = None,
    params: dict | None = None,
    data: str | dict | None = None,
    method: str = "GET",
    options: dict | None = None,
) -> list[dict]:
    """Perform a single HTTP request and return the parsed response.

    Args:
        url: Target URL; a query string embedded in it is merged with
            ``params``.
        headers: Request headers. Defaults to none.
        auth: Auth specification with a ``type`` entry (``NO_AUTH``, ``BASIC``
            or ``OAUTH2``). Defaults to ``{"type": "NO_AUTH"}``.
        params: Query parameters. Defaults to none.
        data: Optional request body (``str`` or ``dict``).
        method: HTTP method name. Defaults to ``"GET"``.
        options: Extra keyword arguments for the httpx request call.

    Returns:
        The response as a list of records (see ``_parse_api_response``).

    Raises:
        ValueError: If the auth specification is invalid.
        TypeError: If ``data`` has an unsupported type.
    """
    headers = headers or {}
    params = params or {}
    # For options, see: https://www.python-httpx.org/api/
    options = options or {}
    auth = auth or {"type": "NO_AUTH"}

    stripped_url, stripped_params = _update_url_and_params(url=url, params=params)

    # Use the client as a context manager so its connection pool is released
    # once the request is done instead of leaking one client per call.
    with _build_client(auth) as client:
        res = _batch_request(
            client=client,
            url=stripped_url,
            headers=headers,
            params=stripped_params,
            options=options,
            data=data,
            method=method,
        )
        return _parse_api_response(res)


In [14]:
def _streaming_request(
    client: httpx.Client,
    url: str,
    headers: dict,
    params: dict,
    options: dict,
    data: str | dict,
    method: str,
    target_path: str,
) -> str:
    """Stream a response body to ``target_path`` while showing a progress bar.

    Args:
        client: Client used to issue the streaming request.
        url: Target URL.
        headers: Request headers.
        params: Query parameters.
        options: Extra keyword arguments for the httpx request call.
        data: Optional request body (``str`` or ``dict``).
        method: HTTP method name.
        target_path: File the body is written to; it is created or truncated.

    Returns:
        ``target_path``, unchanged.
    """
    kwargs = _prepare_request(
        url=url,
        headers=headers,
        params=params,
        options=options,
        data=data,
        method=method,
    )

    with open(target_path, "wb") as download_file, client.stream(**kwargs) as res:
        # httpx header lookups are already case-insensitive. A missing
        # Content-Length yields 0, which tqdm displays as an unknown total.
        total = int(res.headers.get("content-length", "0"))

        with tqdm(
            total=total, unit_scale=True, unit_divisor=1024, unit="B"
        ) as progress:
            bytes_seen = res.num_bytes_downloaded
            for chunk in res.iter_bytes():
                download_file.write(chunk)
                # Advance by the growth of the response's own byte counter,
                # not by len(chunk), so the bar tracks ``num_bytes_downloaded``.
                progress.update(res.num_bytes_downloaded - bytes_seen)
                bytes_seen = res.num_bytes_downloaded

    return target_path

100%|██████████| 100M/100M [00:14<00:00, 7.20MB/s] 


In [15]:
def streaming_request(
    url: str,
    headers: dict | None = None,
    auth: dict | None = None,
    params: dict | None = None,
    data: str | dict | None = None,
    method: str = "GET",
    options: dict | None = None,
    target_path: str | None = None,
) -> str:
    """Download a response body to a file using a streaming request.

    Args:
        url: Target URL; a query string embedded in it is merged with
            ``params``.
        headers: Request headers. Defaults to none.
        auth: Auth specification with a ``type`` entry (``NO_AUTH``, ``BASIC``
            or ``OAUTH2``). Defaults to ``{"type": "NO_AUTH"}``.
        params: Query parameters. Defaults to none.
        data: Optional request body (``str`` or ``dict``).
        method: HTTP method name. Defaults to ``"GET"``.
        options: Extra keyword arguments for the httpx request call.
        target_path: File to write the body to. Required despite the ``None``
            default, which only exists to keep the keyword-style signature.

    Returns:
        ``target_path``.

    Raises:
        TypeError: If ``target_path`` is ``None`` or ``data`` has an
            unsupported type.
        ValueError: If the auth specification is invalid.
    """
    # Fail before any client is built; previously ``None`` only surfaced as an
    # opaque TypeError from ``open()`` further down the call chain.
    if target_path is None:
        msg = "Please provide a target_path to write the response to"
        raise TypeError(msg)

    headers = headers or {}
    params = params or {}
    # For options, see: https://www.python-httpx.org/api/
    options = options or {}
    auth = auth or {"type": "NO_AUTH"}

    stripped_url, stripped_params = _update_url_and_params(url=url, params=params)

    # Use the client as a context manager so its connection pool is released
    # once the download is done instead of leaking one client per call.
    with _build_client(auth) as client:
        return _streaming_request(
            client=client,
            url=stripped_url,
            headers=headers,
            params=stripped_params,
            options=options,
            data=data,
            method=method,
            target_path=target_path,
        )

In [21]:
# Demo: GET with query parameters; ``method`` is left at its "GET" default.
r = batch_request(
    "https://reqres.in/api/users",
    params={"page": 2},
)
r

[{'page': 2,
  'per_page': 6,
  'total': 12,
  'total_pages': 2,
  'data': [{'id': 7,
    'email': 'michael.lawson@reqres.in',
    'first_name': 'Michael',
    'last_name': 'Lawson',
    'avatar': 'https://reqres.in/img/faces/7-image.jpg'},
   {'id': 8,
    'email': 'lindsay.ferguson@reqres.in',
    'first_name': 'Lindsay',
    'last_name': 'Ferguson',
    'avatar': 'https://reqres.in/img/faces/8-image.jpg'},
   {'id': 9,
    'email': 'tobias.funke@reqres.in',
    'first_name': 'Tobias',
    'last_name': 'Funke',
    'avatar': 'https://reqres.in/img/faces/9-image.jpg'},
   {'id': 10,
    'email': 'byron.fields@reqres.in',
    'first_name': 'Byron',
    'last_name': 'Fields',
    'avatar': 'https://reqres.in/img/faces/10-image.jpg'},
   {'id': 11,
    'email': 'george.edwards@reqres.in',
    'first_name': 'George',
    'last_name': 'Edwards',
    'avatar': 'https://reqres.in/img/faces/11-image.jpg'},
   {'id': 12,
    'email': 'rachel.howell@reqres.in',
    'first_name': 'Rachel',
    '

In [17]:
# Demo: POST a dict payload.
r = batch_request(
    "https://reqres.in/api/users",
    method="POST",
    data={"name": "Gigi Par", "job": "smasher"},
)
r


[{'name': 'Gigi Par',
  'job': 'smasher',
  'id': '378',
  'createdAt': '2023-07-19T15:15:36.604Z'}]

In [18]:
# Demo: POST to a second endpoint with the same dict payload.
r = batch_request(
    "https://reqbin.com/echo",
    method="POST",
    data={"name": "Gigi Par", "job": "smasher"},
)
r


[{'payload': '<!DOCTYPE html>\n<!--[if lt IE 7]> <html class="no-js ie6 oldie" lang="en-US"> <![endif]-->\n<!--[if IE 7]>    <html class="no-js ie7 oldie" lang="en-US"> <![endif]-->\n<!--[if IE 8]>    <html class="no-js ie8 oldie" lang="en-US"> <![endif]-->\n<!--[if gt IE 8]><!--> <html class="no-js" lang="en-US"> <!--<![endif]-->\n<head>\n<title>Attention Required! | Cloudflare</title>\n<meta charset="UTF-8" />\n<meta http-equiv="Content-Type" content="text/html; charset=UTF-8" />\n<meta http-equiv="X-UA-Compatible" content="IE=Edge" />\n<meta name="robots" content="noindex, nofollow" />\n<meta name="viewport" content="width=device-width,initial-scale=1" />\n<link rel="stylesheet" id="cf_styles-css" href="/cdn-cgi/styles/cf.errors.css" />\n<!--[if lt IE 9]><link rel="stylesheet" id=\'cf_styles-ie-css\' href="/cdn-cgi/styles/cf.errors.ie.css" /><![endif]-->\n<style>body{margin:0;padding:0}</style>\n\n\n<!--[if gte IE 10]><!-->\n<script>\n  if (!navigator.cookieEnabled) {\n    window.ad

In [19]:
# Demo: stream the response body to a local file; ``r`` holds the file path.
# ``method`` is left at its "GET" default.
r = streaming_request(
    "https://reqres.in/api/users",
    params={"page": 2},
    target_path="./response.json",
)
r

379B [00:00, 1.04MB/s]


'./response.json'

In [20]:
from pyspark.sql import SparkSession

# Reuse the active Spark session if there is one, otherwise start a new one.
spark = (
    SparkSession.builder
    .getOrCreate()
)

23/07/19 17:36:29 WARN Utils: Your hostname, laptop resolves to a loopback address: 127.0.1.1; using 192.168.2.5 instead (on interface wlan0)
23/07/19 17:36:29 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/07/19 17:36:30 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [25]:
# Fetch the records in this cell instead of relying on whichever earlier cell
# last assigned ``r``: in file order that is the streaming demo, which leaves
# a file path in ``r`` rather than a list of records, so a fresh
# "Restart & Run All" would otherwise fail here.
r = batch_request(url="https://reqres.in/api/users", params={"page": 2}, method="GET")

df = spark.createDataFrame(r).toPandas()
df


Unnamed: 0,data,page,per_page,support,total,total_pages
0,"[{'last_name': None, 'avatar': None, 'id': 7, ...",2,6,"{'url': 'https://reqres.in/#support-heading', ...",12,2
