In [None]:
# | default_exp data.datasource

In [None]:
from airt.testing import activate_by_import

[INFO] airt.testing.activate_by_import: Testing environment activated.
[INFO] numexpr.utils: Note: NumExpr detected 16 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
[INFO] numexpr.utils: NumExpr defaulting to 8 threads.
[INFO] airt.keras.helpers: Using a single GPU #0 with memory_limit 1024 MB


In [None]:
# | export

from pathlib import Path
from typing import *

import dask.dataframe as dd
import pandas as pd
from airt.logger import get_logger
from airt.patching import patch
from airt.remote_path import RemotePath
from checksumdir import dirhash
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.exc import NoResultFound
from sqlmodel import Session, select

import airt_service.sanitizer

# import airt_service
from airt_service.auth import get_current_active_user
from airt_service.constants import DS_HEAD_FILE_NAME, METADATA_FOLDER_PATH
from airt_service.data.utils import delete_data_object_files_in_cloud
from airt_service.db.models import (
    DataBlob,
    DataSource,
    DataSourceRead,
    Tag,
    TagCreate,
    User,
    get_session,
)
from airt_service.errors import ERRORS, HTTPError
from airt_service.helpers import commit_or_rollback, df_to_dict

In [None]:
import json
import shutil
from os import environ
from time import sleep

import pytest
import requests

from airt_service.aws.utils import create_s3_datasource_path, upload_to_s3_with_retry
from airt_service.data.csv import process_csv
from airt_service.data.datablob import FromLocalRequest, from_local_start_route
from airt_service.data.utils import create_db_uri_for_s3_datablob
from airt_service.db.models import create_user_for_testing, get_session_with_context
from airt_service.helpers import dict_to_df
from airt_service.users import (
    ActivateMFARequest,
    activate_mfa,
    disable_mfa,
    generate_mfa_url,
)

[INFO] airt.data.importers: Module loaded:
[INFO] airt.data.importers:  - using pandas     : 1.4.4
[INFO] airt.data.importers:  - using dask       : 2022.9.0
[INFO] airt.executor.subcommand: Module loaded.


In [None]:
# Throwaway user for this notebook; the test cells below look it up by this username
test_username = create_user_for_testing()
display(test_username)

'mauodnfciq'

In [None]:
# | exporti

# Module-level logger, named after this module
logger = get_logger(__name__)

In [None]:
# Smoke-check the module logger: the message should show up in the cell output
logger.info("log a random string")

[INFO] __main__: log a random string


In [None]:
# | export

# Default router for all data sources.
# Authentication is not a router-wide dependency: each route declares
# `user: User = Depends(get_current_active_user)` itself.
datasource_router = APIRouter(
    prefix="/datasource",
    tags=["datasource"],
    responses={
        404: dict(description="Not found"),
        500: dict(model=HTTPError, description=ERRORS["INTERNAL_SERVER_ERROR"]),
    },
)

In [None]:
# | exporti


@patch
def calculate_properties(self: DataSource, cache_path: Path):
    """Compute the hash and row count of the synced data and store a head sample

    Sets `self.hash` (an md5 directory hash computed with `dirhash`) and
    `self.no_of_rows`, then writes the first 10 rows as a parquet file into the
    metadata folder inside `cache_path`.

    Args:
        cache_path: Cache folder path containing the synced parquet files
    """
    # The hash is taken before the metadata folder below is created, so on a cache
    # that does not contain that folder yet the head file does not affect the hash.
    self.hash = dirhash(cache_path, hashfunc="md5")

    synced_ddf = dd.read_parquet(cache_path)
    self.no_of_rows = synced_ddf.shape[0].compute()

    metadata_dir = cache_path / METADATA_FOLDER_PATH
    metadata_dir.mkdir(exist_ok=True)

    synced_ddf.head(10).to_parquet(metadata_dir / DS_HEAD_FILE_NAME)

In [None]:
# Check calculate_properties() on a real parquet dataset pulled from S3: the head
# file written into the metadata folder, the hash length and the row count.
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    datasource = DataSource(
        cloud_provider="aws",
        region="eu-west-1",
        total_steps=1,
        user=user,
    )
    session.add(datasource)
    session.commit()
    session.refresh(datasource)

    # A freshly created datasource has none of the computed properties yet
    assert not datasource.no_of_rows
    assert not datasource.path
    assert not datasource.hash

    # Pull the test dataset into a local cache; nothing is pushed back on exit
    with RemotePath.from_url(
        remote_url=f"s3://test-airt-service/account_312571_events",
        pull_on_enter=True,
        push_on_exit=False,
        exist_ok=True,
        parents=False,
        access_key=environ["AWS_ACCESS_KEY_ID"],
        secret_key=environ["AWS_SECRET_ACCESS_KEY"],
    ) as test_s3_path:
        processed_destination_path = test_s3_path.as_path()
        # Reference head computed independently of calculate_properties()
        ddf = dd.read_parquet(processed_destination_path)
        expected_head = ddf.head(n=10)

        datasource.calculate_properties(cache_path=processed_destination_path)

        # calculate_properties() writes the head into the metadata folder of the cache;
        # it must be read before the context exits and the local cache is removed
        head_path = (
            processed_destination_path / METADATA_FOLDER_PATH / DS_HEAD_FILE_NAME
        )
        actual_head = pd.read_parquet(head_path)

    pd.testing.assert_frame_equal(actual_head, expected_head)
    # assert datasource.hash == "2f96b39df0f1f71a05d3ff5509c160e7", datasource.hash
    # Only the length is checked (presumably a 32-char md5 hex digest from dirhash)
    assert len(datasource.hash) == 32, len(datasource.hash)
    assert datasource.no_of_rows == 498961

[INFO] airt.remote_path: RemotePath.from_url(): creating remote path with the following url s3://test-airt-service/account_312571_events
[INFO] airt.remote_path: S3Path._create_cache_path(): created cache path: /tmp/s3test-airt-serviceaccount_312571_events_cached_7c4qpy_w
[INFO] airt.remote_path: S3Path.__init__(): created object for accessing s3://test-airt-service/account_312571_events locally in /tmp/s3test-airt-serviceaccount_312571_events_cached_7c4qpy_w
[INFO] airt.remote_path: S3Path.__enter__(): pulling data from s3://test-airt-service/account_312571_events to /tmp/s3test-airt-serviceaccount_312571_events_cached_7c4qpy_w
[INFO] airt.remote_path: S3Path._clean_up(): removing local cache path /tmp/s3test-airt-serviceaccount_312571_events_cached_7c4qpy_w


In [None]:
# | exporti


@patch
def remove_tag_from_previous_datasources(
    self: DataSource, tag_name: str, session: Session
):
    """Remove the tag_name associated with other/previous datasources

    Only datasources created from the same datablob and owned by the same user as
    this datasource are considered. Nothing is committed when none of them carries
    the tag.

    Args:
        tag_name: Tag name to remove from other datasources
        session: Sqlmodel session
    """
    tag_to_remove = Tag.get_by_name(name=tag_name, session=session)

    # `.all()` returns an empty list when nothing matches; unlike `.one()` it never
    # raises NoResultFound, so no exception handling is needed here.
    datasources = session.exec(
        select(DataSource).where(
            DataSource.datablob == self.datablob,
            DataSource.user == self.user,
        )
    ).all()

    tagged_datasources = [ds for ds in datasources if tag_to_remove in ds.tags]
    if not tagged_datasources:
        return

    for datasource in tagged_datasources:
        datasource.tags.remove(tag_to_remove)
        session.add(datasource)

    # Single commit so the tag is removed from all matching datasources atomically,
    # instead of committing (and possibly failing) part-way through the loop.
    session.commit()

In [None]:
# Check remove_tag_from_previous_datasources(): when a new datasource for the same
# datablob and user takes over a tag, the tag is removed from the older datasource.
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()

    test_tag = Tag.get_by_name(name="test", session=session)

    uri = "s3://bucket"
    db_uri = create_db_uri_for_s3_datablob(
        uri=uri,
        access_key="access",
        secret_key="secret",
    )
    datablob_with_tag = DataBlob(
        type="s3",
        uri=db_uri,
        source=uri,
        cloud_provider="aws",
        region="eu-west-1",
        total_steps=1,
        user=user,
        tags=[test_tag],
    )
    # Older datasource that currently carries the tag
    datasource_with_tag = DataSource(
        datablob=datablob_with_tag,
        cloud_provider=datablob_with_tag.cloud_provider,
        region=datablob_with_tag.region,
        total_steps=1,
        user=user,
        tags=datablob_with_tag.tags,
    )
    session.add(datasource_with_tag)
    session.commit()
    session.refresh(datasource_with_tag)
    display(datasource_with_tag)

    with commit_or_rollback(session):
        # Newer datasource for the same datablob and user that takes over the tag
        new_ds_with_test_tag = DataSource(
            datablob=datablob_with_tag,
            cloud_provider=datablob_with_tag.cloud_provider,
            region=datablob_with_tag.region,
            total_steps=1,
            user=user,
        )
        new_ds_with_test_tag.remove_tag_from_previous_datasources(
            tag_name=test_tag.name, session=session
        )

        datasource_without_test_tag = session.exec(
            select(DataSource).where(DataSource.uuid == datasource_with_tag.uuid)
        ).one()
        display(datasource_without_test_tag)
        session.add(datasource_without_test_tag)
        # Check the tags list itself: iterating a SQLModel/pydantic model yields
        # (field_name, value) pairs, so `tag not in <model>` was always true.
        assert test_tag not in datasource_without_test_tag.tags
        assert not datasource_without_test_tag.tags

        new_ds_with_test_tag.tags.append(test_tag)
        session.add(new_ds_with_test_tag)
        display(new_ds_with_test_tag)

DataSource(id=67, uuid=UUID('fe8af992-4c27-469a-9204-6a48aa5c3821'), hash=None, total_steps=1, completed_steps=0, folder_size=None, no_of_rows=None, cloud_provider=<CloudProvider.aws: 'aws'>, region='eu-west-1', error=None, disabled=False, path=None, created=datetime.datetime(2022, 10, 19, 10, 38, 16), user_id=92, pulled_on=None, tags=[Tag(name='test', id=1, created=datetime.datetime(2022, 10, 18, 12, 28, 54), uuid=UUID('e9a39788-684b-455d-b15d-32b485206855'))])

DataSource(id=67, uuid=UUID('fe8af992-4c27-469a-9204-6a48aa5c3821'), hash=None, total_steps=1, completed_steps=0, folder_size=None, no_of_rows=None, cloud_provider=<CloudProvider.aws: 'aws'>, region='eu-west-1', error=None, disabled=False, path=None, created=datetime.datetime(2022, 10, 19, 10, 38, 16), user_id=92, pulled_on=None, tags=[])

DataSource(id=68, uuid=UUID('06281d8d-22da-4d80-a9c8-995fa97da773'), hash=None, total_steps=1, completed_steps=0, folder_size=None, no_of_rows=None, cloud_provider=<CloudProvider.aws: 'aws'>, region='eu-west-1', error=None, disabled=False, path=None, created=datetime.datetime(2022, 10, 19, 10, 38, 16), user_id=92, pulled_on=None, tags=[Tag()])

In [None]:
# | exporti

# OpenAPI error responses shared by the routes that look up a single datasource
get_datasource_responses = {
    status.HTTP_400_BAD_REQUEST: {
        "model": HTTPError,
        "description": ERRORS["INCORRECT_DATASOURCE_ID"],
    },
    status.HTTP_422_UNPROCESSABLE_ENTITY: {
        "model": HTTPError,
        "description": "DataSource error",
    },
}


@patch(cls_method=True)
def get(cls: DataSource, uuid: str, user: User, session: Session) -> DataSource:
    """Look up a datasource by uuid, restricted to those owned by `user`

    Args:
        uuid: Datasource uuid
        user: User object; only datasources belonging to this user are matched
        session: Sqlmodel session

    Returns:
        The matching datasource, which is neither disabled nor in an error state

    Raises:
        HTTPException: 400 if no datasource with this uuid belongs to the user or if
            it has been deleted (disabled); 422 carrying the stored error message if
            the datasource has an error
    """
    owned_by_user = select(DataSource).where(
        DataSource.uuid == uuid, DataSource.user == user
    )
    try:
        datasource = session.exec(owned_by_user).one()
    except NoResultFound:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERRORS["INCORRECT_DATASOURCE_ID"],
        )

    # A deleted datasource is reported before any stored error
    if datasource.disabled:
        status_code, detail = (
            status.HTTP_400_BAD_REQUEST,
            ERRORS["DATASOURCE_IS_DELETED"],
        )
    elif datasource.error is not None:
        status_code, detail = status.HTTP_422_UNPROCESSABLE_ENTITY, datasource.error
    else:
        return datasource

    raise HTTPException(status_code=status_code, detail=detail)

In [None]:
# Check DataSource.get(): the owner can fetch their datasource, while an unknown uuid
# or a request made by a different user raises HTTPException
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    expected = user.datasources[0]

    actual = DataSource.get(uuid=expected.uuid, user=user, session=session)
    display(actual)
    assert actual == expected

    # Nonexistent uuid
    with pytest.raises(HTTPException) as e:
        DataSource.get(
            uuid="00000000-0000-0000-0000-000000000000", user=user, session=session
        )
    display(e)

    # Existing uuid requested by another user (requires a user named "kumaran" in the DB)
    user_kumaran = session.exec(select(User).where(User.username == "kumaran")).one()
    with pytest.raises(HTTPException) as e:
        DataSource.get(uuid=expected.uuid, user=user_kumaran, session=session)
    display(e)

DataSource(id=65, uuid=UUID('6aacca90-cf55-433a-b1b6-d850e1fbce8d'), hash=None, total_steps=1, completed_steps=0, folder_size=None, no_of_rows=None, cloud_provider=<CloudProvider.aws: 'aws'>, region='eu-west-1', error=None, disabled=False, path=None, created=datetime.datetime(2022, 10, 19, 10, 38, 11), user_id=92, pulled_on=None, tags=[])

<ExceptionInfo HTTPException(status_code=400, detail='Incorrect datasource id') tblen=2>

<ExceptionInfo HTTPException(status_code=400, detail='Incorrect datasource id') tblen=2>

In [None]:
# Check that DataSource.get() rejects a datasource flagged as disabled (deleted)
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()

    uri = "s3://"
    datablob = DataBlob(
        type="s3",
        uri=create_db_uri_for_s3_datablob(
            uri=uri, access_key="access", secret_key="secret"
        ),
        source=uri,
        cloud_provider="aws",
        region="eu-west-1",
        total_steps=1,
        user=user,
    )
    # Created directly with disabled=True to simulate a deleted datasource
    datasource_disabled = DataSource(
        datablob=datablob,
        cloud_provider=datablob.cloud_provider,
        region=datablob.region,
        total_steps=1,
        user=user,
        disabled=True,
    )
    session.add(datasource_disabled)
    session.commit()

    with pytest.raises(HTTPException) as e:
        DataSource.get(uuid=datasource_disabled.uuid, user=user, session=session)
    display(e)

<ExceptionInfo HTTPException(status_code=400, detail='Datasource is deleted') tblen=2>

In [None]:
# | export


@datasource_router.get(
    "/{datasource_uuid}",
    response_model=DataSourceRead,
    responses=get_datasource_responses,  # type: ignore
)
def get_details_of_datasource(
    datasource_uuid: str,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
) -> DataSource:
    """Get details of the datasource"""
    # Re-attach the dependency-provided user to this request's session; DataSource.get
    # raises HTTPException for unknown, deleted or errored datasources.
    session_user = session.merge(user)
    return DataSource.get(uuid=datasource_uuid, user=session_user, session=session)  # type: ignore

In [None]:
# Check that get_details_of_datasource() returns the user's first datasource for its uuid
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    expected = user.datasources[0]

    actual = get_details_of_datasource(
        datasource_uuid=expected.uuid, user=user, session=session
    )
    display(actual)
    assert actual == expected

DataSource(id=65, uuid=UUID('6aacca90-cf55-433a-b1b6-d850e1fbce8d'), hash=None, total_steps=1, completed_steps=0, folder_size=None, no_of_rows=None, cloud_provider=<CloudProvider.aws: 'aws'>, region='eu-west-1', error=None, disabled=False, path=None, created=datetime.datetime(2022, 10, 19, 10, 38, 11), user_id=92, pulled_on=None, tags=[])

In [None]:
# Datasource errored out while importing: get_details_of_datasource() must raise an
# HTTPException carrying the stored error message
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()

    uri = "s3://"
    datasource_errored = DataSource(
        datablob=DataBlob(
            type="s3",
            uri=create_db_uri_for_s3_datablob(
                uri=uri, access_key="access", secret_key="secret"
            ),
            cloud_provider="aws",
            region="eu-west-1",
            total_steps=1,
            source=uri,
            user=user,
        ),
        cloud_provider="aws",
        region="eu-west-1",
        total_steps=1,
        user=user,
        error="test error",
    )
    session.add(datasource_errored)
    session.commit()

    with pytest.raises(HTTPException) as e:
        get_details_of_datasource(
            datasource_uuid=datasource_errored.uuid, user=user, session=session
        )
    display(e)

<ExceptionInfo HTTPException(status_code=422, detail='test error') tblen=3>

In [None]:
# | exporti


@patch
def delete(self: DataSource, user: User, session: Session):
    """Soft-delete the datasource: remove its cloud files and mark it as disabled

    Args:
        user: User object (accepted for interface compatibility; not used here)
        session: Sqlmodel session

    Returns:
        This same datasource instance, with `disabled` set to True
    """
    # NOTE(review): the cloud files are removed before the disabled flag is committed;
    # if committing fails, the record is not disabled although its files are gone.
    delete_data_object_files_in_cloud(data_object=self)

    self.disabled = True
    with commit_or_rollback(session):
        session.add(self)
    return self

In [None]:
# | export


@datasource_router.delete(
    "/{datasource_uuid}",
    response_model=DataSourceRead,
    responses=get_datasource_responses,  # type: ignore
)
def delete_datasource(
    datasource_uuid: str,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
) -> DataSource:
    """Delete datasource"""
    # Re-attach the dependency-provided user to this request's session
    session_user = session.merge(user)
    datasource = DataSource.get(uuid=datasource_uuid, user=session_user, session=session)  # type: ignore

    # DataSource.delete removes the cloud files and flags the row as disabled
    return datasource.delete(user=session_user, session=session)

In [None]:
# Check delete_datasource(): create a datasource, copy test data into its S3 location,
# then delete it through the route function and verify it is flagged as disabled.
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()

    uri = "s3://"
    datasource = DataSource(
        datablob=DataBlob(
            type="s3",
            uri=create_db_uri_for_s3_datablob(
                uri=uri, access_key="access", secret_key="secret"
            ),
            source=uri,
            cloud_provider="aws",
            region="eu-west-1",
            total_steps=1,
            user=user,
        ),
        cloud_provider="aws",
        region="eu-west-1",
        total_steps=1,
        user=user,
    )
    session.add(datasource)
    session.commit()

    with RemotePath.from_url(
        remote_url="s3://test-airt-service/account_312571_events",
        pull_on_enter=True,
        push_on_exit=False,
        exist_ok=True,
        parents=False,
    ) as source_s3_path:
        bucket, s3_path = create_s3_datasource_path(
            user_id=user.id, datasource_id=datasource.id, region=datasource.region
        )
        # NOTE(review): unexplained fixed 10 s wait - TODO document its purpose or
        # replace it with an explicit readiness check
        sleep(10)
        # Copy the pulled data into the datasource's own S3 path (pushed on exit)
        with RemotePath.from_url(
            remote_url=f"s3://{bucket.name}/{s3_path}",
            pull_on_enter=False,
            push_on_exit=True,
            exist_ok=True,
            parents=True,
        ) as destination_s3_path:
            shutil.copytree(
                source_s3_path.as_path(),
                destination_s3_path.as_path(),
                dirs_exist_ok=True,
            )
# Delete in a fresh session; `user` comes from the closed session above and is
# re-attached inside delete_datasource() via session.merge
with get_session_with_context() as session:
    datasource = session.exec(
        select(DataSource).where(DataSource.id == datasource.id)
    ).one()
    display(datasource)
    actual = delete_datasource(
        datasource_uuid=datasource.uuid, user=user, session=session
    )
    display(actual)
    assert actual.disabled

[INFO] airt.remote_path: RemotePath.from_url(): creating remote path with the following url s3://test-airt-service/account_312571_events
[INFO] airt.remote_path: S3Path._create_cache_path(): created cache path: /tmp/s3test-airt-serviceaccount_312571_events_cached_9701xh6v
[INFO] airt.remote_path: S3Path.__init__(): created object for accessing s3://test-airt-service/account_312571_events locally in /tmp/s3test-airt-serviceaccount_312571_events_cached_9701xh6v
[INFO] airt.remote_path: S3Path.__enter__(): pulling data from s3://test-airt-service/account_312571_events to /tmp/s3test-airt-serviceaccount_312571_events_cached_9701xh6v
[INFO] botocore.credentials: Found credentials in environment variables.
[INFO] airt.remote_path: RemotePath.from_url(): creating remote path with the following url s3://kumaran-airt-service-eu-west-1/92/datasource/71
[INFO] airt.remote_path: S3Path._create_cache_path(): created cache path: /tmp/s3kumaran-airt-service-eu-west-192datasource71_cached_wwqowm62
[IN

DataSource(id=71, uuid=UUID('d163e287-9d79-43a4-8104-d6f8938305f0'), hash=None, total_steps=1, completed_steps=0, folder_size=None, no_of_rows=None, cloud_provider=<CloudProvider.aws: 'aws'>, region='eu-west-1', error=None, disabled=False, path=None, created=datetime.datetime(2022, 10, 19, 10, 38, 17), user_id=92, pulled_on=None, tags=[])

DataSource(id=71, uuid=UUID('d163e287-9d79-43a4-8104-d6f8938305f0'), hash=None, total_steps=1, completed_steps=0, folder_size=None, no_of_rows=None, cloud_provider=<CloudProvider.aws: 'aws'>, region='eu-west-1', error=None, disabled=True, path=None, created=datetime.datetime(2022, 10, 19, 10, 38, 17), user_id=92, pulled_on=None, tags=[])

In [None]:
# Create and pull datasource to use in following tests:
# upload a CSV built from the S3 test parquet dataset through the from-local flow,
# then run process_csv() on it to produce the synced datasource `datasource_synced`.
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()

    # Start a local upload; the response carries a presigned POST ("url" + "fields")
    from_local_request = FromLocalRequest(
        path="tmp/test-folder/", tag="my_csv_datasource_tag"
    )
    from_local_response = from_local_start_route(
        from_local_request=from_local_request,
        user=user,
        session=session,
    )

    with RemotePath.from_url(
        remote_url=f"s3://test-airt-service/account_312571_events",
        pull_on_enter=True,
        push_on_exit=False,
        exist_ok=True,
        parents=False,
        access_key=environ["AWS_ACCESS_KEY_ID"],
        secret_key=environ["AWS_SECRET_ACCESS_KEY"],
    ) as test_s3_path:
        # Convert the parquet test data to a CSV file inside the local cache
        df = pd.read_parquet(test_s3_path.as_path())
        display(df.head())
        df.to_csv(test_s3_path.as_path() / "file.csv", index=False)
        display(list(test_s3_path.as_path().glob("*")))
        !head -n 10 {test_s3_path.as_path()/"file.csv"}

        # Upload the CSV using the presigned POST from the from-local response
        upload_to_s3_with_retry(
            test_s3_path.as_path() / "file.csv",
            from_local_response.presigned["url"],
            from_local_response.presigned["fields"],
        )

    datablob_id = (
        session.exec(select(DataBlob).where(DataBlob.uuid == from_local_response.uuid))
        .one()
        .id
    )
    datasource = DataSource(
        datablob_id=datablob_id,
        cloud_provider="aws",
        region="eu-west-1",
        total_steps=1,
        user=user,
    )
    session.add(datasource)
    session.commit()

    datasource_id = (
        session.exec(select(DataSource).where(DataSource.uuid == datasource.uuid))
        .one()
        .id
    )
    # Index by PersonId, sort by OccurredTime, deduplicate; kwargs_json presumably
    # holds extra CSV-reader options (usecols / parse_dates) - TODO confirm in process_csv
    process_csv(
        datablob_id=datablob_id,
        datasource_id=datasource_id,
        deduplicate_data=True,
        index_column="PersonId",
        sort_by="OccurredTime",
        blocksize="256MB",
        kwargs_json=json.dumps(
            dict(
                usecols=[0, 1, 2, 3, 4],
                parse_dates=["OccurredTime"],
            )
        ),
    )

# Reload in a new session to see the properties written by process_csv()
with get_session_with_context() as session:
    datasource_synced = session.exec(
        select(DataSource).where(DataSource.id == datasource_id)
    ).one()
    display(datasource_synced)

[INFO] airt_service.data.datablob: DataBlob.from_local(): FromLocalResponse(uuid=UUID('e8cb121a-66a5-4a1f-8cc5-2bb80c35665f'), type='local', presigned={'url': 'https://kumaran-airt-service-eu-west-1.s3.amazonaws.com/', 'fields': {'key': '****************************************', 'AWSAccessKeyId': '********************', 'policy': '************************************************************************************************************************************************************************************************************************************************************', 'signature': '****************************'}})
[INFO] airt.remote_path: RemotePath.from_url(): creating remote path with the following url s3://test-airt-service/account_312571_events
[INFO] airt.remote_path: S3Path._create_cache_path(): created cache path: /tmp/s3test-airt-serviceaccount_312571_events_cached_rtn1ty_8
[INFO] airt.remote_path: S3Path.__init__(): created object for accessing s3://test-airt-ser

Unnamed: 0_level_0,AccountId,DefinitionId,OccurredTime,OccurredTimeTicks,PersonId
__null_dask_index__,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0,312571,loadTests2,2019-12-31 21:30:02,1577836802678,2
1,312571,loadTests3,2020-01-03 23:53:22,1578104602678,2
2,312571,loadTests1,2020-01-07 02:16:42,1578372402678,2
3,312571,loadTests2,2020-01-10 04:40:02,1578640202678,2
4,312571,loadTests3,2020-01-13 07:03:22,1578908002678,2


[Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_rtn1ty_8/_metadata'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_rtn1ty_8/_common_metadata'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_rtn1ty_8/file.csv'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_rtn1ty_8/part.3.parquet'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_rtn1ty_8/part.0.parquet'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_rtn1ty_8/part.1.parquet'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_rtn1ty_8/part.4.parquet'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_rtn1ty_8/part.2.parquet')]

AccountId,DefinitionId,OccurredTime,OccurredTimeTicks,PersonId
312571,loadTests2,2019-12-31 21:30:02,1577836802678,2
312571,loadTests3,2020-01-03 23:53:22,1578104602678,2
312571,loadTests1,2020-01-07 02:16:42,1578372402678,2
312571,loadTests2,2020-01-10 04:40:02,1578640202678,2
312571,loadTests3,2020-01-13 07:03:22,1578908002678,2
312571,loadTests1,2020-01-16 09:26:42,1579175802678,2
312571,loadTests2,2020-01-19 11:50:02,1579443602678,2
312571,loadTests3,2020-01-22 14:13:22,1579711402678,2
312571,loadTests1,2020-01-25 16:36:42,1579979202678,2
[INFO] airt.remote_path: S3Path._clean_up(): removing local cache path /tmp/s3test-airt-serviceaccount_312571_events_cached_rtn1ty_8
[INFO] airt_service.data.csv: process_csv(datablob_id=296, datasource_id=74): processing user uploaded csv file for datablob_id=296 and uploading parquet back to S3 for datasource_id=74
[INFO] airt_service.data.csv: process_csv(datablob_id=296, datasource_id=74): step 1/4: downloading user uploaded file from bucket s

Perhaps you already have a cluster running?
Hosting the HTTP server on port 45067 instead


[INFO] airt.dask_manager: Cluster started: <Client: 'tcp://127.0.0.1:45165' processes=8 threads=8, memory=22.89 GiB>
Cluster dashboard: http://127.0.0.1:45067/status
[INFO] airt.data.importers: import_csv(): step 1/5: importing data and storing it into partitioned Parquet files
[INFO] airt.data.importers:  - number of rows: 498,961
[INFO] airt.dask_manager: Starting stopping cluster...
[INFO] airt.dask_manager: Cluster stopped
[INFO] airt.dask_manager: Starting cluster...


Perhaps you already have a cluster running?
Hosting the HTTP server on port 40243 instead


[INFO] airt.dask_manager: Cluster started: <Client: 'tcp://127.0.0.1:33627' processes=4 threads=4, memory=22.89 GiB>
Cluster dashboard: http://127.0.0.1:40243/status
[INFO] airt.data.importers: import_csv(): step 2/5: indexing data by PersonId.
[INFO] airt.data.importers:  - number of rows: 498,961
[INFO] airt.dask_manager: Starting stopping cluster...
[INFO] airt.dask_manager: Cluster stopped
[INFO] airt.dask_manager: Starting cluster...


Perhaps you already have a cluster running?
Hosting the HTTP server on port 41435 instead


[INFO] airt.dask_manager: Cluster started: <Client: 'tcp://127.0.0.1:46535' processes=8 threads=8, memory=22.89 GiB>
Cluster dashboard: http://127.0.0.1:41435/status
[INFO] airt.data.importers: import_csv(): step 3/5: deduplicating and sorting data by PersonId and OccurredTime.
[INFO] airt.data.importers:  - number of rows: 498,961
[INFO] airt.data.importers: import_csv(): step 4/5: repartitioning data.
[INFO] airt.data.importers:  - number of rows: 498,961
[INFO] airt.data.importers: import_csv(): step 5/5: sorting data by PersonId and OccurredTime.
[INFO] airt.data.importers:  - number of rows: 498,961
[INFO] airt.data.importers: import_csv(): completed, the final data is stored in /tmp/s3kumaran-airt-service-eu-west-192datasource74_cached_jro6198u as Parquet files with:
[INFO] airt.data.importers:  - dtypes={'AccountId': dtype('int64'), 'DefinitionId': dtype('O'), 'OccurredTime': dtype('<M8[ns]'), 'OccurredTimeTicks': dtype('int64')}
[INFO] airt.data.importers:  - npartitions=1
[INF

DataSource(id=74, uuid=UUID('c3118331-0df9-4061-9efd-05922e545d5d'), hash='64ab63985d6651f495ddccd4d96d16cb', total_steps=1, completed_steps=1, folder_size=6619982, no_of_rows=498961, cloud_provider=<CloudProvider.aws: 'aws'>, region='eu-west-1', error=None, disabled=False, path='s3://kumaran-airt-service-eu-west-1/92/datasource/74', created=datetime.datetime(2022, 10, 19, 10, 38, 53), user_id=92, pulled_on=datetime.datetime(2022, 10, 19, 10, 39, 9), tags=[])

In [None]:
# | export


def _get_ds_head_and_dtypes(datasource_s3_path: str) -> Dict[str, Any]:
    """Download the stored head of a datasource and return it as a dict

    Args:
        datasource_s3_path: Input datasource S3 path

    Returns:
        The head rows together with their dtypes, as produced by `df_to_dict`
    """
    with RemotePath.from_url(
        remote_url=f"{datasource_s3_path}/{METADATA_FOLDER_PATH}",
        push_on_exit=False,
        exist_ok=True,
        parents=False,
    ) as metadata_remote_path:
        # Read while inside the context: the local cache is removed on exit
        head_file = metadata_remote_path.as_path() / DS_HEAD_FILE_NAME
        head_dict = df_to_dict(pd.read_parquet(head_file))

    return head_dict

In [None]:
# Check the head stored for the synced datasource: PersonId was used as the index
# during processing, so it is the index rather than one of the columns.
df_dict = _get_ds_head_and_dtypes(datasource_synced.path)

assert df_dict["data"]["columns"] == [
    "AccountId",
    "DefinitionId",
    "OccurredTime",
    "OccurredTimeTicks",
]
assert df_dict["data"]["index_names"] == ["PersonId"]
assert df_dict["dtypes"] == {
    "AccountId": "int64",
    "DefinitionId": "object",
    "OccurredTime": "datetime64[ns]",
    "OccurredTimeTicks": "int64",
}
# 10 rows, presumably written by calculate_properties(), which stores head(10)
assert len(df_dict["data"]["data"]) == 10
# Exactly two top-level keys: "data" and "dtypes"
assert len(df_dict) == 2

[INFO] airt.remote_path: RemotePath.from_url(): creating remote path with the following url s3://kumaran-airt-service-eu-west-1/92/datasource/74/.metadata_by_airt
[INFO] airt.remote_path: S3Path._create_cache_path(): created cache path: /tmp/s3kumaran-airt-service-eu-west-192datasource74metadata_by_airt_cached_o_ymhbfj
[INFO] airt.remote_path: S3Path.__init__(): created object for accessing s3://kumaran-airt-service-eu-west-1/92/datasource/74/.metadata_by_airt locally in /tmp/s3kumaran-airt-service-eu-west-192datasource74metadata_by_airt_cached_o_ymhbfj
[INFO] airt.remote_path: S3Path.__enter__(): pulling data from s3://kumaran-airt-service-eu-west-1/92/datasource/74/.metadata_by_airt to /tmp/s3kumaran-airt-service-eu-west-192datasource74metadata_by_airt_cached_o_ymhbfj
[INFO] airt.remote_path: S3Path._clean_up(): removing local cache path /tmp/s3kumaran-airt-service-eu-west-192datasource74metadata_by_airt_cached_o_ymhbfj


In [None]:
# | exporti


@patch
def is_ready(self: DataSource):
    """Raise HTTPException (412) unless all processing steps of the datasource are completed"""
    if self.completed_steps == self.total_steps:
        return
    raise HTTPException(
        status_code=status.HTTP_412_PRECONDITION_FAILED,
        detail=ERRORS["DATASOURCE_IS_NOT_PULLED"],
    )

In [None]:
# | export


@datasource_router.get(
    "/{datasource_uuid}/head",
    responses={
        **get_datasource_responses,  # type: ignore
        412: {"model": HTTPError, "description": ERRORS["DATASOURCE_IS_NOT_PULLED"]},
    },
)
def datasource_head_route(
    datasource_uuid: str,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
) -> Dict[str, Any]:
    """Get head of the datasource"""
    # The result is the `df_to_dict` output: top-level keys "data" and "dtypes" whose
    # values are dicts, not lists. The annotation matters because newer FastAPI versions
    # use the return type as the response model when `response_model` is not given.

    # Re-attach the dependency-provided user to this request's session
    user = session.merge(user)
    datasource = DataSource.get(uuid=datasource_uuid, user=user, session=session)  # type: ignore
    # Raises HTTPException 412 if the datasource has not finished all its steps
    datasource.is_ready()

    return _get_ds_head_and_dtypes(datasource_s3_path=datasource.path)

In [None]:
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    # re-fetch the synced datasource so it is attached to this session
    datasource_synced = session.exec(
        select(DataSource).where(DataSource.id == datasource_synced.id)
    ).one()

    actual = datasource_head_route(
        datasource_uuid=datasource_synced.uuid,
        user=user,
        session=session,
    )
    assert isinstance(actual, dict)

    actual_df = dict_to_df(actual)
    expected_dtypes = pd.Series(actual["dtypes"])

    assert actual_df.index.name == "PersonId"
    pd.testing.assert_series_equal(actual_df.dtypes, expected_dtypes)

    display(actual_df)

[INFO] airt.remote_path: RemotePath.from_url(): creating remote path with the following url s3://kumaran-airt-service-eu-west-1/92/datasource/74/.metadata_by_airt
[INFO] airt.remote_path: S3Path._create_cache_path(): created cache path: /tmp/s3kumaran-airt-service-eu-west-192datasource74metadata_by_airt_cached_0baie8pi
[INFO] airt.remote_path: S3Path.__init__(): created object for accessing s3://kumaran-airt-service-eu-west-1/92/datasource/74/.metadata_by_airt locally in /tmp/s3kumaran-airt-service-eu-west-192datasource74metadata_by_airt_cached_0baie8pi
[INFO] airt.remote_path: S3Path.__enter__(): pulling data from s3://kumaran-airt-service-eu-west-1/92/datasource/74/.metadata_by_airt to /tmp/s3kumaran-airt-service-eu-west-192datasource74metadata_by_airt_cached_0baie8pi
[INFO] airt.remote_path: S3Path._clean_up(): removing local cache path /tmp/s3kumaran-airt-service-eu-west-192datasource74metadata_by_airt_cached_0baie8pi


Unnamed: 0_level_0,AccountId,DefinitionId,OccurredTime,OccurredTimeTicks
PersonId,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
2,312571,loadTests2,2019-12-31 21:30:02,1577836802678
2,312571,loadTests3,2020-01-03 23:53:22,1578104602678
2,312571,loadTests1,2020-01-07 02:16:42,1578372402678
2,312571,loadTests2,2020-01-10 04:40:02,1578640202678
2,312571,loadTests3,2020-01-13 07:03:22,1578908002678
2,312571,loadTests1,2020-01-16 09:26:42,1579175802678
2,312571,loadTests2,2020-01-19 11:50:02,1579443602678
2,312571,loadTests3,2020-01-22 14:13:22,1579711402678
2,312571,loadTests1,2020-01-25 16:36:42,1579979202678
2,312571,loadTests2,2020-01-28 19:00:02,1580247002678


In [None]:
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    uri = "s3://"

    # a datablob/datasource pair whose pull never runs; completed_steps is not set
    # here, so it presumably keeps its default and differs from total_steps
    unpulled_datablob = DataBlob(
        type="s3",
        uri=create_db_uri_for_s3_datablob(
            uri=uri, access_key="access", secret_key="secret"
        ),
        source=uri,
        cloud_provider="aws",
        region="eu-west-1",
        total_steps=1,
        user=user,
    )
    datasource_without_pull = DataSource(
        datablob=unpulled_datablob,
        cloud_provider="aws",
        region="eu-west-1",
        total_steps=1,
        user=user,
    )
    session.add(datasource_without_pull)
    session.commit()

    # the head route must refuse a datasource that is not ready
    with pytest.raises(HTTPException):
        datasource_head_route(
            datasource_uuid=datasource_without_pull.uuid, user=user, session=session
        )

In [None]:
# | export


@datasource_router.get(
    "/{datasource_uuid}/dtypes",
    responses={
        **get_datasource_responses,  # type: ignore
        412: {"model": HTTPError, "description": ERRORS["DATASOURCE_IS_NOT_PULLED"]},
    },
)
def datasource_dtypes_route(
    datasource_uuid: str,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
) -> Dict[str, str]:
    """Get the columns of the datasource together with their dtypes"""
    user = session.merge(user)
    datasource = DataSource.get(uuid=datasource_uuid, user=user, session=session)  # type: ignore
    # refuse to read metadata for a datasource that has not finished pulling
    datasource.is_ready()

    # the dtypes live in the same head/dtypes metadata that backs the /head route
    head_and_dtypes = _get_ds_head_and_dtypes(datasource_s3_path=datasource.path)
    return head_and_dtypes["dtypes"]

In [None]:
with get_session_with_context() as session:
    session.commit()
    user = session.exec(select(User).where(User.username == test_username)).one()
    datasource_synced = session.exec(
        select(DataSource).where(DataSource.uuid == datasource_synced.uuid)
    ).one()

    expected = {
        "AccountId": "int64",
        "DefinitionId": "object",
        "OccurredTime": "datetime64[ns]",
        "OccurredTimeTicks": "int64",
    }
    actual = datasource_dtypes_route(
        datasource_uuid=datasource_synced.uuid, user=user, session=session
    )
    assert actual == expected, f"unexpected dtypes: {actual}"

    display(actual)

[INFO] airt.remote_path: RemotePath.from_url(): creating remote path with the following url s3://kumaran-airt-service-eu-west-1/92/datasource/74/.metadata_by_airt
[INFO] airt.remote_path: S3Path._create_cache_path(): created cache path: /tmp/s3kumaran-airt-service-eu-west-192datasource74metadata_by_airt_cached_yl7bu38p
[INFO] airt.remote_path: S3Path.__init__(): created object for accessing s3://kumaran-airt-service-eu-west-1/92/datasource/74/.metadata_by_airt locally in /tmp/s3kumaran-airt-service-eu-west-192datasource74metadata_by_airt_cached_yl7bu38p
[INFO] airt.remote_path: S3Path.__enter__(): pulling data from s3://kumaran-airt-service-eu-west-1/92/datasource/74/.metadata_by_airt to /tmp/s3kumaran-airt-service-eu-west-192datasource74metadata_by_airt_cached_yl7bu38p
[INFO] airt.remote_path: S3Path._clean_up(): removing local cache path /tmp/s3kumaran-airt-service-eu-west-192datasource74metadata_by_airt_cached_yl7bu38p


{'AccountId': 'int64',
 'DefinitionId': 'object',
 'OccurredTime': 'datetime64[ns]',
 'OccurredTimeTicks': 'int64'}

In [None]:
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    # attach the never-pulled datasource created earlier to this session
    datasource_without_pull = session.merge(datasource_without_pull)
    unpulled_uuid = datasource_without_pull.uuid

    with pytest.raises(HTTPException):
        datasource_dtypes_route(datasource_uuid=unpulled_uuid, user=user, session=session)

In [None]:
# | exporti


@patch(cls_method=True)
def get_all(
    cls: DataSource,
    disabled: bool,
    completed: bool,
    offset: int,
    limit: int,
    user: User,
    session: Session,
) -> List[DataSource]:
    """Get all datasources created by the user

    Args:
        disabled: Whether to get disabled datasources
        completed: Whether to include only datasources which are pulled from its source
        offset: Offset results by given integer
        limit: Limit results by given integer
        user: User object
        session: Sqlmodel session

    Returns:
        A list of datasource objects, ordered by id
    """
    statement = (
        select(DataSource)
        .where(DataSource.user == user)
        .where(DataSource.disabled == disabled)
    )
    if completed:
        # keep only datasources whose pull has run every step
        statement = statement.where(
            DataSource.completed_steps == DataSource.total_steps
        )
    # offset/limit pagination is only deterministic with a stable ORDER BY;
    # without one the database may return rows in any order, so consecutive
    # pages could overlap or skip datasources
    statement = statement.order_by(DataSource.id).offset(offset).limit(limit)
    return session.exec(statement).all()

In [None]:
# | export


@datasource_router.get("/", response_model=List[DataSourceRead])
def get_all_datasources(
    disabled: bool = False,
    completed: bool = False,
    offset: int = 0,
    # `le` is FastAPI's upper-bound keyword; the previous `lte` is not a
    # validation constraint, so the cap of 100 was never enforced
    limit: int = Query(default=100, le=100),
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
) -> List[DataSource]:
    """Get all datasources created by user

    At most 100 datasources are returned per request; use `offset` to page.
    """
    user = session.merge(user)
    return DataSource.get_all(
        disabled=disabled,
        completed=completed,
        offset=offset,
        limit=limit,
        user=user,
        session=session,
    )

In [None]:
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    actual = get_all_datasources(
        disabled=False, completed=False, offset=0, limit=1, user=user, session=session
    )
    display(actual)

    # with limit=1 exactly one datasource comes back: the user's first one
    assert len(actual) == 1
    (first_datasource,) = actual
    assert isinstance(first_datasource, DataSource)
    assert first_datasource == user.datasources[0]

[DataSource(id=65, uuid=UUID('6aacca90-cf55-433a-b1b6-d850e1fbce8d'), hash=None, total_steps=1, completed_steps=0, folder_size=None, no_of_rows=None, cloud_provider=<CloudProvider.aws: 'aws'>, region='eu-west-1', error=None, disabled=False, path=None, created=datetime.datetime(2022, 10, 19, 10, 38, 11), user_id=92, pulled_on=None, tags=[])]

In [None]:
# Each filter combination must return only datasources matching that filter.
# Open a fresh session instead of reusing `session`/`user` from the previous
# cell, whose `with` block has already exited (hidden state on re-run).
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()

    actual = get_all_datasources(
        disabled=False, completed=False, offset=0, limit=10, user=user, session=session
    )
    display(f"{len(actual)=}")
    assert all(not datasource.disabled for datasource in actual)
    assert actual[0] == user.datasources[0]

    actual = get_all_datasources(
        disabled=True, completed=False, offset=0, limit=10, user=user, session=session
    )
    display(f"{len(actual)=}")
    assert all(datasource.disabled for datasource in actual)

    actual = get_all_datasources(
        disabled=False, completed=True, offset=0, limit=10, user=user, session=session
    )
    display(f"{len(actual)=}")
    assert all(
        datasource.completed_steps == datasource.total_steps for datasource in actual
    )
    assert datasource.completed_steps == datasource.total_steps

'len(actual)=6'

'len(actual)=2'

'len(actual)=1'

In [None]:
# | exporti


@patch
def tag(self: DataSource, tag_name: str, session: Session):
    """Attach a tag to an existing datasource

    Args:
        tag_name: Name of the tag to attach
        session: Sqlmodel session

    Returns:
        The same datasource instance, after it has been persisted via
        `commit_or_rollback`
    """
    tag_to_attach = Tag.get_by_name(name=tag_name, session=session)

    # presumably detaches this tag name from other datasources first — see
    # remove_tag_from_previous_datasources for the exact semantics
    self.remove_tag_from_previous_datasources(tag_name=tag_to_attach.name, session=session)  # type: ignore
    self.tags.append(tag_to_attach)

    with commit_or_rollback(session):
        session.add(self)

    return self

In [None]:
# | export


@datasource_router.post("/{datasource_uuid}/tag", response_model=DataSourceRead)
def tag_datasource(
    datasource_uuid: str,
    tag_to_create: TagCreate,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
) -> DataSource:
    """Add a tag to the datasource"""
    merged_user = session.merge(user)
    target = DataSource.get(uuid=datasource_uuid, user=merged_user, session=session)  # type: ignore
    return target.tag(tag_name=tag_to_create.name, session=session)

In [None]:
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    datasource = user.datasources[0]

    tag_name = "new_tag"
    tag_to_create = TagCreate(name=tag_name)
    actual = tag_datasource(
        datasource_uuid=datasource.uuid,
        tag_to_create=tag_to_create,
        user=user,
        session=session,
    )
    display(actual)

    # the freshly added tag must be among the datasource's tags
    assert any(existing.name == tag_name for existing in actual.tags)

DataSource(id=65, uuid=UUID('6aacca90-cf55-433a-b1b6-d850e1fbce8d'), hash=None, total_steps=1, completed_steps=0, folder_size=None, no_of_rows=None, cloud_provider=<CloudProvider.aws: 'aws'>, region='eu-west-1', error=None, disabled=False, path=None, created=datetime.datetime(2022, 10, 19, 10, 38, 11), user_id=92, pulled_on=None, tags=[Tag(name='new_tag', id=9, created=datetime.datetime(2022, 10, 18, 12, 30, 43), uuid=UUID('1662ea89-5018-4728-a71a-ecc502a22fd7'))])

In [None]:
# | exporti


@patch(cls_method=True)
def _create(
    cls: DataSource,
    *,
    datablob: DataBlob,
    total_steps: int = 1,
    user: User,
    session: Session,
) -> DataSource:
    """Create new datasource based on given params

    The new datasource takes its cloud provider and region from `datablob`,
    and every tag currently attached to `datablob` is also applied to it.

    Args:
        datablob: Datablob object
        total_steps: Total steps
        user: User object
        session: Sqlmodel session

    Returns:
        The created datasource object
    """
    with commit_or_rollback(session):
        datasource = DataSource(
            datablob=datablob,
            cloud_provider=datablob.cloud_provider,
            region=datablob.region,
            total_steps=total_steps,
            user=user,
        )
        # Add the new row explicitly. Relying on the `user`/`datablob`
        # relationship backrefs to cascade it into the session is deprecated in
        # SQLAlchemy 1.4 and removed in 2.0.
        session.add(datasource)

    for tag in datablob.tags:
        datasource.tag(tag_name=tag.name, session=session)  # type: ignore

    return datasource

In [None]:
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()

    uri = "s3://"
    db_uri = create_db_uri_for_s3_datablob(
        uri=uri, access_key="access", secret_key="secret"
    )
    datablob = DataBlob(
        type="s3",
        uri=db_uri,
        source=uri,
        cloud_provider="aws",
        region="eu-west-1",
        total_steps=1,
        user=user,
    )
    actual = DataSource._create(
        datablob=datablob, total_steps=1, user=user, session=session
    )
    display(actual)
    assert actual.uuid is not None

    # read the row back and check the datablob's tags were carried over
    datasource = session.exec(
        select(DataSource).where(DataSource.uuid == actual.uuid)
    ).one()
    assert datasource.tags == datablob.tags

DataSource(id=82, uuid=UUID('4f51e904-3cea-4d42-8839-d1ecf300bce2'), hash=None, total_steps=1, completed_steps=0, folder_size=None, no_of_rows=None, cloud_provider=<CloudProvider.aws: 'aws'>, region='eu-west-1', error=None, disabled=False, path=None, created=datetime.datetime(2022, 10, 19, 10, 39, 38), user_id=92, pulled_on=None, tags=[])