In [None]:
import sys

# Record which Python interpreter ran this notebook: as the cell's final
# expression, the version string is shown in the cell output.
sys.version

In [None]:
import asyncio
import json
from collections.abc import Iterator
from pathlib import Path
from httpx import ReadTimeout, TimeoutException
import os
import httpx
import logging

from auth import get_token
from params import (
    SridParams,
    TridParams,
    CategoryParams,
    lookup_factory,
    upload_config,
    get_remote_parameter,
    get_local_parameter,
)
from upload import do_upload, AttemptState, set_scopes, store_mdids
from data_cleaning import JsonDataset

In [None]:
class UploadError(Exception):
    """Signals that the upload could not be completed.

    Raised in this notebook when no requested indicator is found among the
    datasets, and when setting scopes on the uploaded data fails.
    """


class HeartbeatError(Exception):
    """Signals that the backend heartbeat endpoint did not answer successfully.

    Raised in this notebook when the heartbeat request returns a status code
    of 300 or above.
    """

In [None]:
log = logging.getLogger("upload_notebook")


def load_env_file(env_path: Path) -> None:
    """Copy ``KEY=VALUE`` pairs from a dotenv-style file into ``os.environ``.

    Blank lines, ``#`` comment lines and lines without a ``=`` separator (or
    with an empty key) are skipped: an empty variable name is rejected by
    ``os.environ``, and exporting a comment as a variable is never intended.
    Whitespace around the key and the value is trimmed. Only the first ``=``
    splits the line, so values may themselves contain ``=``. Existing
    environment variables with the same name are overwritten.

    Args:
        env_path: Location of the env file to read.

    Raises:
        FileNotFoundError: If ``env_path`` does not exist.
    """
    for raw_line in env_path.read_text().splitlines():
        entry = raw_line.strip()
        if not entry or entry.startswith("#"):
            continue
        key, separator, value = entry.partition("=")
        key = key.strip()
        if not separator or not key:
            log.debug(f"Ignoring unparseable line in {env_path}: {raw_line!r}")
            continue
        os.environ[key] = value.strip()


try:
    load_env_file(Path("../.env"))
except FileNotFoundError:
    # No env file
    pass

In [None]:
# Parameters
# The values below are defaults; keep them as plain assignments with trailing
# comments so they stay easy to override when the notebook is parameterised.
shared_data_dir = Path(
    "../tests/"
)  # The path to the shared_data_dir for this notebook (see FlowpyterOperator notes)
dagrun_data_dir = Path("../tests/")  # Base directory under which the JSON datasets, stored mdids and token cache for this run are located
static_dir = Path(
    "../tests"
)  # The path to the static_dir for this notebook (see FlowpyterOperator notes)
JSON_DATA_SUBDIR = Path("jsons")  # Subdirectory of dagrun_data_dir that is searched for dataset JSON files
INDICATORS_TO_UPLOAD = [
    "residents.residents",
    "residents.residents_perKm2",
    "residents.arrived",
    "residents.departed",
    "residents.delta_arrived",
    "residents.residents_diffwithref",
    "residents.abnormality",
    "residents.residents_pctchangewithref",
    "relocations.relocations",
    "relocations.relocations_diffwithref",
    "relocations.abnormality",
    "relocations.relocations_pctchangewithref",
    "presence.presence",
    "presence.presence_perKm2",
    "presence.trips_in",
    "presence.trips_out",
    "presence.abnormality",
    "presence.presence_diffwithref",
    "presence.presence_pctchangewithref",
    "movements.travellers",
    "movements.abnormality",
    "movements.travellers_diffwithref",
    "movements.travellers_pctchangewithref",
]  # indicators to upload in the form 'category.indicator'
CONFIG_STATIC_PATH = "config.json"  # The path to config.json within static_dir
CHUNK_SIZE = 20  # Number of parallel uploads to attempt at once
RETRY_COUNT = 3  # Number of times a chunk of parallel uploads will retry before failing
BASE_URL = "https://api.dev.haiti.mobility-dashboard.org/v1"  # The base URL for the backend api
MDIDS_DATA_PATH = Path("mdids")  # Subdirectory of dagrun_data_dir that the mdids of the uploads are stored in

In [None]:
# Temporary hack required until we update flowpyter-task to allow list params:
# a string value is decoded from JSON, anything else is kept untouched.
INDICATORS_TO_UPLOAD = (
    json.loads(INDICATORS_TO_UPLOAD)
    if isinstance(INDICATORS_TO_UPLOAD, str)
    else INDICATORS_TO_UPLOAD
)

In [None]:
# Airflow variables injected via env vars.
# The four credentials have no fallback and are None when the variable is unset.

# Admin client credentials
AUTH0_CLIENT_ID_ADMIN = os.environ.get("ADMIN_CLIENT")  # Admin client id from Auth0
AUTH0_CLIENT_SECRET_ADMIN = os.environ.get("ADMIN_SECRET")  # Admin secret from Auth0

# Updater client credentials
AUTH0_CLIENT_ID_UPDATER = os.environ.get("UPDATER_CLIENT")  # Updater client id from Auth0
AUTH0_CLIENT_SECRET_UPDATER = os.environ.get("UPDATER_SECRET")  # Updater secret from Auth0

# Token endpoint settings; each falls back to a default when unset
_DEFAULT_AUTH0_DOMAIN = "flowminder-dev.eu.auth0.com"
_DEFAULT_AUDIENCE = "https://flowkit-ui-backend.flowminder.org"
AUTH0_DOMAIN = os.environ.get("AUTH0_DOMAIN", _DEFAULT_AUTH0_DOMAIN)  # Auth0 domain to request tokens from
AUDIENCE = os.environ.get("AUDIENCE", _DEFAULT_AUDIENCE)  # Domain to request tokens for

In [None]:
# Postprocessing of parameters + af vars
# Re-wrap the directory parameters as Path objects (presumably because injected
# values can arrive as plain strings - TODO confirm), then derive the working
# locations used by the rest of the notebook.
shared_data_dir, dagrun_data_dir, static_dir = (
    Path(shared_data_dir),
    Path(dagrun_data_dir),
    Path(static_dir),
)
CONFIG_PATH = static_dir.joinpath(CONFIG_STATIC_PATH)
JSON_FOLDER = dagrun_data_dir.joinpath(JSON_DATA_SUBDIR)
MDIDS_PATH = dagrun_data_dir.joinpath(MDIDS_DATA_PATH)
CACHE_FOLDER = dagrun_data_dir.joinpath("token_cache")

In [None]:
def load_all_datasets_generator(
    json_folder: Path, filename_pattern: str = "*.json"
) -> Iterator[JsonDataset]:
    """Lazily load every dataset stored as a JSON file in a folder.

    Each file matching ``filename_pattern`` is parsed and its top-level object
    is passed as keyword arguments to ``JsonDataset``. Files that are not
    datasets are skipped: either the top-level JSON value is not an object, or
    it has neither a "category_id" nor an "indicator_id" key.

    Args:
        json_folder: Folder to search (non-recursively) for JSON files.
        filename_pattern: Glob pattern selecting the files to load.

    Yields:
        One ``JsonDataset`` per file that holds a valid dataset.

    Raises:
        TypeError: If a file has "category_id" and/or "indicator_id" keys but
            cannot be turned into a ``JsonDataset`` (i.e. it is malformed).
    """
    for fp in json_folder.glob(filename_pattern):
        data = json.loads(fp.read_text())
        if not isinstance(data, dict):
            # A top-level array/scalar cannot be unpacked into a dataset, and has
            # no keys to inspect - treat it like any other non-dataset file
            log.debug(f"{fp} is not a dataset, skipping")
            continue
        try:
            dataset = JsonDataset(**data)
        except TypeError:
            if all(key not in data for key in ("category_id", "indicator_id")):
                # Folder may contain json files that are not datasets - ignore them
                log.debug(f"{fp} is not a dataset, skipping")
                continue
            # A json object carrying "category_id" / "indicator_id" _is_ meant to be
            # a dataset, so if it's not a valid JsonDataset then it must be malformed
            log.error(f"{fp} contains a malformed dataset")
            raise
        # Yield outside the try block so that only construction errors are
        # interpreted as "not a dataset" / "malformed dataset"
        yield dataset


def filter_datasets_generator(
    datasets: Iterator[JsonDataset], indicators_to_upload: list[str]
) -> Iterator[JsonDataset]:
    """Pass through only the datasets whose indicator was requested.

    Once ``datasets`` is exhausted, a warning is logged listing any requested
    indicators that never appeared, and an error is raised if nothing at all
    was yielded.

    Args:
        datasets: Datasets to filter; each must expose ``indicator_id``.
        indicators_to_upload: Indicator ids that should be kept.

    Yields:
        Every dataset whose ``indicator_id`` is in ``indicators_to_upload``.

    Raises:
        UploadError: If no dataset matched any requested indicator.
    """
    wanted = set(indicators_to_upload)
    seen = set()
    for dataset in datasets:
        # Drop anything that was not asked for
        if dataset.indicator_id not in wanted:
            continue
        seen.add(dataset.indicator_id)
        yield dataset
    # Requested indicators with no data only merit a warning...
    not_found = wanted - seen
    if not_found:
        log.warning(f"No data found for the following indicators: {not_found}")
    # ...but yielding nothing whatsoever is an error
    if not seen:
        raise UploadError("No indicators present")


async def _check_heartbeat() -> None:
    """Confirm the backend answers its heartbeat endpoint, retrying once after a timeout."""
    heartbeat_url = f"{BASE_URL}/heartbeat"
    try:
        response = httpx.get(heartbeat_url, follow_redirects=True)
    except TimeoutException:
        # If we get a timeout, the last request caused the server to spin up.
        # Try again now it's awake (async sleep, so we stay preemptible).
        await asyncio.sleep(1)
        response = httpx.get(heartbeat_url, follow_redirects=True)
    if response.status_code >= 300:
        raise HeartbeatError("Heartbeat not found")


async def main():
    """Upload the selected indicator datasets and record their ids.

    Steps:
      1. Check the backend heartbeat (one retry after a timeout).
      2. Obtain admin and updater tokens via ``get_token``.
      3. Fetch the srid/trid/category parameters from the backend; if any of
         them comes back as an empty list, upload the config at CONFIG_PATH.
      4. Upload the datasets found in JSON_FOLDER, restricted to
         INDICATORS_TO_UPLOAD.
      5. Set the 'read:preview_data' scope on the uploads.
      6. Store the mdids of the uploads under MDIDS_PATH.

    Raises:
        HeartbeatError: If the heartbeat endpoint returns a status >= 300.
        UploadError: If no requested indicator has data, or if any
            scope-setting request returns a status >= 300.
    """
    await _check_heartbeat()

    CACHE_FOLDER.mkdir(exist_ok=True)
    admin_token = get_token(
        AUTH0_DOMAIN,
        AUTH0_CLIENT_ID_ADMIN,
        AUTH0_CLIENT_SECRET_ADMIN,
        AUDIENCE,
        CACHE_FOLDER,
    )
    updater_token = get_token(
        AUTH0_DOMAIN,
        AUTH0_CLIENT_ID_UPDATER,
        AUTH0_CLIENT_SECRET_UPDATER,
        AUDIENCE,
        CACHE_FOLDER,
    )

    remote_params = [
        get_remote_parameter(p.endpoint, admin_token, BASE_URL)
        for p in (SridParams, TridParams, CategoryParams)
    ]
    log.debug("Remote parameters: %s", remote_params)
    if any(r == [] for r in remote_params):
        log.warning(
            f"Config not found: loading modifiers from {CONFIG_PATH} and uploading config"
        )
        upload_config(CONFIG_PATH, admin_token, BASE_URL)

    payloads = filter_datasets_generator(
        load_all_datasets_generator(JSON_FOLDER), INDICATORS_TO_UPLOAD
    )

    # TODO maybe: Confirm or refetch srids/trids from server?
    # NOTE(review): the updater token is deliberately left as-is in do_upload's
    # `admin_token` keyword - presumably uploads run as the updater client; confirm.
    attempts = await do_upload(
        payloads,
        base_url=BASE_URL,
        admin_token=updater_token,
        chunk_size=CHUNK_SIZE,
        retry_count=RETRY_COUNT,
    )

    # The body of each upload response is read as the integer id of the upload
    mdids = (int(a.response.text) for a in attempts)
    log.info("Setting 'read:preview_data' for uploads")
    scope_responses = await set_scopes(
        mdids, "read:preview_data", BASE_URL, admin_token
    )
    failures = [r for r in scope_responses if r.status_code >= 300]
    if failures:
        # Report the raw body: an error response is not guaranteed to be JSON, and
        # a decoding failure here would hide the UploadError raised below
        for failure in failures:
            log.error(
                "Scope setting returned status %s: %s",
                failure.status_code,
                failure.text,
            )
        raise UploadError("Scope setting failed")

    MDIDS_PATH.mkdir(exist_ok=True)
    store_mdids(attempts, MDIDS_PATH)


# Top-level await: this relies on the notebook kernel (IPython) running the cell
# inside an event loop; in a plain Python script use asyncio.run(main()) instead.
await main()