In [None]:
#| default_exp data.datablob

In [None]:
from airt.testing import activate_by_import

[INFO] airt.testing.activate_by_import: Testing environment activated.
[INFO] numexpr.utils: Note: NumExpr detected 16 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
[INFO] numexpr.utils: NumExpr defaulting to 8 threads.
[INFO] airt.keras.helpers: Using a single GPU #0 with memory_limit 1024 MB


In [None]:
#| export

import json
import shlex
from enum import Enum
from time import sleep
from typing import *
import uuid as uuid_pkg

import numpy as np
import boto3
from botocore.client import Config
from pydantic import BaseModel
from fastapi import APIRouter, Depends, HTTPException, status, Query, BackgroundTasks
from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm.exc import StaleDataError
from sqlmodel import Session, select

from airt.executor.subcommand import SimpleCLICommand
from airt.helpers import get_s3_bucket_name_and_folder_from_uri
from airt.logger import get_logger
from airt.patching import patch

import airt_service
import airt_service.sanitizer
from airt_service.airflow.executor import AirflowExecutor
from airt_service.auth import get_current_active_user
from airt_service.aws.utils import (
    create_s3_datablob_path,
    get_s3_bucket_and_path_from_uri,
    verify_aws_region,
)
from airt_service.azure.utils import verify_azure_region
from airt_service.batch_job import create_batch_job
from airt_service.data.clickhouse import create_db_uri_for_clickhouse_datablob
from airt_service.data.datasource import DataSource
from airt_service.data.utils import (
    create_db_uri_for_azure_blob_storage_datablob,
    create_db_uri_for_s3_datablob,
    create_db_uri_for_db_datablob,
    create_db_uri_for_local_datablob,
    delete_data_object_files_in_cloud,
)
from airt_service.db.models import (
    User,
    DataBlob,
    DataBlobRead,
    DataSourceRead,
    TagCreate,
    get_session,
    Tag,
)
from airt_service.errors import HTTPError, ERRORS
from airt_service.helpers import commit_or_rollback

[INFO] airt.executor.subcommand: Module loaded.


In [None]:
import os
import random
import string

import dask.dataframe as dd
import pytest
import requests
from azure.identity import DefaultAzureCredential
from azure.mgmt.storage import StorageManagementClient

from airt_service.aws.utils import get_queue_definition_arns, upload_to_s3_with_retry
from airt_service.background_task import execute_cli
from airt_service.db.models import (
    create_user_for_testing,
    get_db_params_from_env_vars,
    get_session_with_context,
)
from airt_service.data.s3 import s3_pull
from airt_service.helpers import set_env_variable_context
from airt.remote_path import RemotePath

In [None]:
# Create a user for the examples below; later cells select it by `test_username`
test_username = create_user_for_testing()
display(test_username)

'gwynhfyfgx'

In [None]:
#| exporti

# Module logger; the patched DataBlob methods below report failures through it
logger = get_logger(__name__)

In [None]:
# Quick check that the module logger emits records
logger.info("log a random string")

[INFO] __main__: log a random string


In [None]:
#| export

# Default router for all data sources.
# Authentication is not enforced router-wide (the `dependencies` entry is commented out);
# the routes in this module each declare `Depends(get_current_active_user)` themselves.
datablob_router = APIRouter(
    prefix="/datablob",
    tags=["datablob"],
    #     dependencies=[Depends(get_current_active_user)],
    responses={
        404: {"description": "Not found"},
        500: {
            "model": HTTPError,
            "description": ERRORS["INTERNAL_SERVER_ERROR"],
        },
    },
)

In [None]:
#| exporti


@patch
def remove_tag_from_previous_datablobs(self: DataBlob, tag_name: str, session: Session):
    """Detach `tag_name` from the datablobs sharing this datablob's type, uri and user

    Tags such as "latest" are meant to point at a single datablob, so before a new
    datablob receives such a tag it is stripped from the matching datablobs that
    already carry it. Each modified datablob is committed individually; a
    `StaleDataError` during that commit is rolled back and otherwise ignored.

    Args:
        tag_name: Name of the tag to detach from the matching datablobs
        session: Sqlmodel session used for the lookup and the commits
    """
    tag = Tag.get_by_name(name=tag_name, session=session)  # type: ignore
    query = select(DataBlob).where(
        DataBlob.type == self.type,
        DataBlob.uri == self.uri,
        DataBlob.user == self.user,
    )
    try:
        matching_datablobs = session.exec(query).all()
    except NoResultFound:
        return

    for previous in matching_datablobs:
        if tag not in previous.tags:
            continue
        try:
            previous.tags.remove(tag)
            session.add(previous)
            session.commit()
        except StaleDataError:
            session.rollback()

In [None]:
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    uri = "s3://bucket"

    test_tag = Tag.get_by_name(name="test", session=session)
    db_uri = create_db_uri_for_s3_datablob(
        uri=uri,
        access_key="access",
        secret_key="secret",
    )
    # An existing datablob that already carries the "test" tag
    datablob_with_tag = DataBlob(
        type="s3",
        uri=db_uri,
        source=uri,
        cloud_provider="aws",
        region="eu-west-1",
        total_steps=1,
        user=user,
        tags=[test_tag],
    )
    session.add(datablob_with_tag)
    session.commit()
    display(datablob_with_tag)

    with commit_or_rollback(session):
        # A new datablob with the same type/uri/user; handing the tag to it must
        # strip the tag from the previous datablob
        new_datablob_with_test_tag = DataBlob(
            type="s3",
            uri=db_uri,
            source=uri,
            cloud_provider="aws",
            region="eu-west-1",
            total_steps=1,
            user=user,
        )
        new_datablob_with_test_tag.remove_tag_from_previous_datablobs(
            tag_name=test_tag.name, session=session
        )

        datablob_without_test_tag = session.exec(
            select(DataBlob).where(DataBlob.uuid == datablob_with_tag.uuid)
        ).one()
        display(datablob_without_test_tag)
        session.add(datablob_without_test_tag)
        # Check the tags collection: a membership test on the model object itself
        # does not look at its tags
        assert test_tag not in datablob_without_test_tag.tags
        assert not datablob_without_test_tag.tags

        new_datablob_with_test_tag.tags.append(test_tag)
        session.add(new_datablob_with_test_tag)
        display(new_datablob_with_test_tag)

DataBlob(id=3, uuid=UUID('5f83b4d0-4f0c-46a6-84f6-a79523e77d55'), type='s3', uri='s3://****************************************@bucket', source='s3://bucket', total_steps=1, completed_steps=0, folder_size=None, cloud_provider=<CloudProvider.aws: 'aws'>, region='eu-west-1', error=None, disabled=False, path=None, created=datetime.datetime(2022, 11, 7, 9, 10, 5), user_id=6, pulled_on=None, tags=[Tag(id=3, created=datetime.datetime(2022, 11, 7, 9, 10, 5), name='test', uuid=UUID('e1e3ca56-290b-41be-b877-6d6bcd1b4078'))])

DataBlob(id=3, uuid=UUID('5f83b4d0-4f0c-46a6-84f6-a79523e77d55'), type='s3', uri='s3://****************************************@bucket', source='s3://bucket', total_steps=1, completed_steps=0, folder_size=None, cloud_provider=<CloudProvider.aws: 'aws'>, region='eu-west-1', error=None, disabled=False, path=None, created=datetime.datetime(2022, 11, 7, 9, 10, 5), user_id=6, pulled_on=None, tags=[])

DataBlob(id=4, uuid=UUID('71904d1d-8a0b-4f01-b93d-ae6281c09dc1'), type='s3', uri='s3://****************************************@bucket', source='s3://bucket', total_steps=1, completed_steps=0, folder_size=None, cloud_provider=<CloudProvider.aws: 'aws'>, region='eu-west-1', error=None, disabled=False, path=None, created=datetime.datetime(2022, 11, 7, 9, 10, 6), user_id=6, pulled_on=None, tags=[Tag()])

In [None]:
#| exporti

create_datablob_responses = {
    400: {"model": HTTPError, "description": "DataBlob error"},
}


@patch(cls_method=True)
def _create(
    cls: DataBlob,
    *,
    type: str,
    uri: str,
    source: str,
    cloud_provider: str,
    region: str,
    total_steps: int,
    user_tag: Optional[str] = None,
    user: User,
    session: Session
) -> DataBlob:
    """Function to create new datablob based on given params

    The new datablob always receives the "latest" tag, plus `user_tag` when given.
    Each of those tags is first removed from the previous datablobs with the same
    type, uri and user, so that it points only at the newly created datablob.

    Args:
        type: Datablob type
        uri: DB uri of datablob
        source: Datablob uri
        cloud_provider: Cloud provider to store datablob files
        region: Region of cloud provider
        total_steps: Total steps
        user_tag: Tag created by user to add to new datablob
        user: User object
        session: Sqlmodel session

    Returns:
        The created datablob object
    Raises:
        HTTPException if request has bad parameters. The session is rolled back
        before raising, so it stays usable for the caller (e.g. for a retry).
    """
    # De-duplicate so that user_tag="latest" does not attach the same tag twice
    tag_names = ["latest"] if user_tag in (None, "latest") else [user_tag, "latest"]
    try:
        datablob = DataBlob(
            type=type,
            uri=uri,
            source=source,
            cloud_provider=cloud_provider,
            region=region,
            total_steps=total_steps,
            user=user,
        )

        for tag_name in tag_names:
            datablob.remove_tag_from_previous_datablobs(  # type: ignore
                tag_name=tag_name, session=session
            )
            datablob.tags.append(Tag.get_by_name(name=tag_name, session=session))  # type: ignore

        session.add(datablob)
        session.commit()
        return datablob
    except Exception as e:
        logger.exception(e)
        # A failed flush/commit leaves the session in a failed transaction that must be
        # rolled back before the session can be used again (e.g. by a retry loop).
        session.rollback()
        error_message = (
            e._message() if callable(getattr(e, "_message", None)) else str(e)  # type: ignore
        )
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=error_message,
        )

In [None]:
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    s3_uri = "s3://bucket"

    with commit_or_rollback(session):
        masked_db_uri = create_db_uri_for_s3_datablob(
            uri=s3_uri, access_key="access", secret_key="secret"
        )
        created = DataBlob._create(
            type="s3",
            uri=masked_db_uri,
            source=s3_uri,
            cloud_provider="aws",
            region="eu-west-1",
            total_steps=1,
            user_tag="my_s3_datablob_tag",
            user=user,
            session=session,
        )
        session.add(created)
    display(created)

    # The datablob must be persisted and keep the user-facing source uri
    assert created.id is not None
    assert created.source == s3_uri

    # Re-read it from the database and check type and user tag
    fetched = session.exec(select(DataBlob).where(DataBlob.id == created.id)).one()
    assert fetched.type == "s3"
    user_tag = Tag.get_by_name("my_s3_datablob_tag", session)
    assert user_tag in fetched.tags

DataBlob(id=5, uuid=UUID('702ef2e8-fafd-452f-b5db-d6305b32a973'), type='s3', uri='s3://****************************************@bucket', source='s3://bucket', total_steps=1, completed_steps=0, folder_size=None, cloud_provider=<CloudProvider.aws: 'aws'>, region='eu-west-1', error=None, disabled=False, path=None, created=datetime.datetime(2022, 11, 7, 9, 10, 6), user_id=6, pulled_on=None, tags=[Tag(id=2, created=datetime.datetime(2022, 11, 7, 9, 9, 59), name='latest', uuid=UUID('487ec313-1f51-498e-ad17-c29a9a10aa25')), Tag(id=4, created=datetime.datetime(2022, 11, 7, 9, 10, 6), name='my_s3_datablob_tag', uuid=UUID('93076773-196f-4144-aa16-66d18889460e'))])

In [None]:
#| export


class S3Request(BaseModel):
    """Base Request object for from_s3 and to_s3 routes"""

    # S3 uri of the data location, e.g. s3://bucket/folder
    uri: str
    # AWS access key id used to access the bucket
    access_key: str
    # AWS secret access key used to access the bucket
    secret_key: str

In [None]:
#| export


class CloudProvider(str, Enum):
    # Cloud providers where datablob files can be stored. Being a `str` subclass,
    # members compare equal to their plain values (CloudProvider.aws == "aws").
    aws = "aws"
    azure = "azure"

In [None]:
#| export


class FromS3Request(S3Request):
    """Request object for the /datablob/from_s3 route

    Args:
        uri: S3 uri of the folder where parquet files are stored
        access_key: Access key for the s3 bucket
        secret_key: Secret key for the s3 bucket
        cloud_provider: Cloud provider to save files
        region: Region of the cloud provider; if None, the region of the source S3 bucket is used
        tag: Tag to add to the datablob
    """

    cloud_provider: CloudProvider = "aws"  # type: ignore
    region: Optional[str] = None
    tag: Optional[str] = None

    class Config:
        # Store the enum's plain string value ("aws"/"azure") instead of the member
        use_enum_values = True

In [None]:
#| exporti


@patch(cls_method=True)
def from_s3(
    cls: DataBlob,
    *,
    from_s3_request: FromS3Request,
    user: User,
    session: Session,
    background_tasks: BackgroundTasks,
    no_retries: int = 3,
) -> DataBlob:
    """Create a datablob from an S3 bucket and schedule the s3_pull job for it

    Args:
        from_s3_request: The from_s3 request
        user: User object
        session: Sqlmodel session
        background_tasks: BackgroundTasks object (currently unused; the pull is
            scheduled through the AirflowExecutor)
        no_retries: Number of attempts at creating the datablob before raising an
            exception; values below 1 are treated as 1

    Returns:
        A new datablob created from a S3

    Raises:
        HTTPException: status 500 if the datablob could not be created in any attempt;
            errors from the region verification are propagated as raised
    """

    uri = create_db_uri_for_s3_datablob(
        uri=from_s3_request.uri,
        access_key=from_s3_request.access_key,
        secret_key=from_s3_request.secret_key,
    )

    cloud_provider = from_s3_request.cloud_provider
    region = from_s3_request.region
    if region is None:
        s3_client = boto3.client(
            "s3",
            aws_access_key_id=from_s3_request.access_key,
            aws_secret_access_key=from_s3_request.secret_key,
            config=Config(signature_version="s3v4"),
        )
        bucket_name, _ = get_s3_bucket_name_and_folder_from_uri(from_s3_request.uri)
        # S3 reports a null LocationConstraint for buckets located in us-east-1
        region = (
            s3_client.get_bucket_location(Bucket=bucket_name)["LocationConstraint"]
            or "us-east-1"
        )

    if cloud_provider == "aws":
        verify_aws_region(region)
    else:
        verify_azure_region(region)
    source = from_s3_request.uri

    attempts = max(no_retries, 1)
    last_error: Optional[Exception] = None
    for attempt in range(attempts):
        try:
            datablob = DataBlob._create(  # type: ignore
                type="s3",
                uri=uri,
                source=source,
                cloud_provider=cloud_provider,
                region=region,
                total_steps=1,
                user_tag=from_s3_request.tag,
                user=user,
                session=session,
            )
            break
        except Exception as e:
            last_error = e
            # Random back-off before the next attempt; pointless after the last one
            if attempt < attempts - 1:
                sleep(np.random.uniform(1, 5))
    else:
        # Every attempt failed
        logger.exception("DataBlob.from_s3() failed", exc_info=last_error)
        raise HTTPException(
            status_code=500, detail=f"Unexpected exception: {last_error}"
        )

    # "{datablob_id}" is filled in by the executor from the datablob_id passed to execute()
    steps = [SimpleCLICommand(command="s3_pull {datablob_id}")]
    executor = AirflowExecutor.create_executor(
        steps, cloud_provider=datablob.cloud_provider, region=datablob.region
    )
    dag_file_path, run_id = executor.execute(
        description="s3 pull",
        tags="s3_pull",
        datablob_id=datablob.id,
    )

    return datablob

In [None]:
#| export


@datablob_router.post(
    "/from_s3", response_model=DataBlobRead, responses=create_datablob_responses  # type: ignore
)
def from_s3_route(
    *,
    from_s3_request: FromS3Request,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
    background_tasks: BackgroundTasks,
) -> DataBlob:
    """Create a datablob from csv/parquet files in s3 bucket"""
    # The authenticated user comes from the auth dependency; attach it to this session
    session_user = session.merge(user)
    created_datablob = DataBlob.from_s3(  # type: ignore
        from_s3_request=from_s3_request,
        user=session_user,
        session=session,
        background_tasks=background_tasks,
    )
    return created_datablob

In [None]:
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    uri = "s3://test-airt-service/account_312571_events"
    # region=None: from_s3 is expected to look up the bucket's region itself
    from_s3_request = FromS3Request(
        uri=uri,
        cloud_provider="aws",
        region=None,
        access_key=os.environ["AWS_ACCESS_KEY_ID"],
        secret_key=os.environ["AWS_SECRET_ACCESS_KEY"],
        tag="my_s3_datablob_tag",
    )
    b = BackgroundTasks()

    # NOTE(review): from_s3 schedules the pull through AirflowExecutor rather than a
    # background task, so JOB_EXECUTOR presumably has no effect here; kept for parity
    with set_env_variable_context(variable="JOB_EXECUTOR", value="fastapi"):
        actual = from_s3_route(
            from_s3_request=from_s3_request,
            user=user,
            session=session,
            background_tasks=b,
        )
    display(actual)

    assert actual.type == "s3"
    assert actual.source == uri, actual.source
    # The region was not given, so it must have been resolved from the bucket
    assert actual.region is not None

    datablob = session.exec(select(DataBlob).where(DataBlob.uuid == actual.uuid)).one()
    assert datablob.id == actual.id
    assert Tag.get_by_name("my_s3_datablob_tag", session) in datablob.tags

[{'dag_id': 's3_pull-6', 'run_id': 'airt-service__2022-11-07T09:10:22.391682', 'state': 'running', 'execution_date': '2022-11-07T09:10:23+00:00', 'start_date': '2022-11-07T09:10:24.079718+00:00', 'end_date': ''}]


DataBlob(id=6, uuid=UUID('02d0ec7f-2f3e-4d69-9021-5ed335a046a9'), type='s3', uri='s3://****************************************@test-airt-service/account_312571_events', source='s3://test-airt-service/account_312571_events', total_steps=1, completed_steps=0, folder_size=None, cloud_provider=<CloudProvider.aws: 'aws'>, region='eu-west-3', error=None, disabled=False, path=None, created=datetime.datetime(2022, 11, 7, 9, 10, 6), user_id=6, pulled_on=None, tags=[Tag(id=2, created=datetime.datetime(2022, 11, 7, 9, 9, 59), name='latest', uuid=UUID('487ec313-1f51-498e-ad17-c29a9a10aa25')), Tag(id=4, created=datetime.datetime(2022, 11, 7, 9, 10, 6), name='my_s3_datablob_tag', uuid=UUID('93076773-196f-4144-aa16-66d18889460e'))])

In [None]:
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    uri = "s3://test-airt-service/account_312571_events"
    from_s3_request = FromS3Request(
        uri=uri,
        access_key=os.environ["AWS_ACCESS_KEY_ID"],
        secret_key=os.environ["AWS_SECRET_ACCESS_KEY"],
        cloud_provider="aws",
        region="region-doesnt-exists",
        tag="my_s3_datablob_tag",
    )
    b = BackgroundTasks()

    # An unknown region must be rejected with a 400 before a datablob is created
    with pytest.raises(HTTPException) as e:
        with set_env_variable_context(variable="JOB_EXECUTOR", value="fastapi"):
            from_s3_route(
                from_s3_request=from_s3_request,
                user=user,
                session=session,
                background_tasks=b,
            )
    display(e)
    assert e.value.status_code == 400, e.value.status_code
    assert "region-doesnt-exists" in e.value.detail, e.value.detail

<ExceptionInfo HTTPException(status_code=400, detail='Unknown region - region-doesnt-exists; Available regions are ap-northeast-1, ap...l-1, eu-central-1, eu-north-1, eu-west-1, eu-west-2, eu-west-3, sa-east-1, us-east-1, us-east-2, us-west-1, us-west-2') tblen=4>

In [None]:
#| export


class AzureBlobStorageRequest(BaseModel):
    """Base Request object for from_azure_blob_storage and to_azure_blob_storage routes"""

    # Blob storage uri, e.g. https://<account>.blob.core.windows.net/<container>/<path>
    uri: str
    # Credential for the container, e.g. a storage account key
    credential: str

In [None]:
#| export


class FromAzureBlobStorageRequest(AzureBlobStorageRequest):
    """Request object for the from_azure_blob_storage route

    Args:
        uri: Azure blob storage uri of the folder where parquet files are stored
        credential: Credential for the blob storage container
        cloud_provider: Cloud provider to save files
        region: Region of the cloud provider
        tag: Tag to add to the datablob
    """

    cloud_provider: CloudProvider = "azure"  # type: ignore
    # Required: unlike the S3 request, the region is not derived from the source storage
    region: str
    tag: Optional[str] = None

    class Config:
        # Store the enum's plain string value ("aws"/"azure") instead of the member
        use_enum_values = True

In [None]:
#| exporti


@patch(cls_method=True)
def from_azure_blob_storage(
    cls: DataBlob,
    *,
    from_azure_blob_storage_request: FromAzureBlobStorageRequest,
    user: User,
    session: Session,
    background_tasks: BackgroundTasks,
    no_retries: int = 3,
) -> DataBlob:
    """Create a datablob from an Azure blob storage container and schedule its pull job

    The region is taken from the request as-is and validated against the chosen
    cloud provider before the datablob is created.

    Args:
        from_azure_blob_storage_request: The from_azure_blob_storage request
        user: User object
        session: Sqlmodel session
        background_tasks: BackgroundTasks object handed to create_batch_job
        no_retries: Currently unused

    Returns:
        A new datablob created from azure blob storage
    """
    request = from_azure_blob_storage_request

    db_uri = create_db_uri_for_azure_blob_storage_datablob(
        uri=request.uri,
        credential=request.credential,
    )

    provider = request.cloud_provider
    region = request.region
    # TODO: derive the region from the storage account instead of requiring it
    if provider == "aws":
        verify_aws_region(region)
    else:
        verify_azure_region(region)

    datablob = DataBlob._create(  # type: ignore
        type="azure_blob_storage",
        uri=db_uri,
        source=request.uri,
        cloud_provider=provider,
        region=region,
        total_steps=1,
        user_tag=request.tag,
        user=user,
        session=session,
    )

    create_batch_job(
        command=f"azure_blob_storage_pull {datablob.id}",
        task="csv_processing",
        cloud_provider=provider,
        region=region,
        background_tasks=background_tasks,
    )
    return datablob

In [None]:
#| export


@datablob_router.post(
    "/from_azure_blob_storage", response_model=DataBlobRead, responses=create_datablob_responses  # type: ignore
)
def from_azure_blob_storage_route(
    *,
    from_azure_blob_storage_request: FromAzureBlobStorageRequest,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
    background_tasks: BackgroundTasks,
) -> DataBlob:
    """Create a datablob from csv/parquet files in an Azure Blob Storage container"""
    # The authenticated user comes from the auth dependency; attach it to this session
    user = session.merge(user)
    return DataBlob.from_azure_blob_storage(  # type: ignore
        from_azure_blob_storage_request=from_azure_blob_storage_request,
        user=user,
        session=session,
        background_tasks=background_tasks,
    )

In [None]:
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    uri = "https://testairtservice.blob.core.windows.net/test-container/account_312571_events"

    # Fetch an access key of the test storage account via the Azure management API
    storage_client = StorageManagementClient(
        DefaultAzureCredential(), os.environ["AZURE_SUBSCRIPTION_ID"]
    )
    account_keys = storage_client.storage_accounts.list_keys(
        "test-airt-service", "testairtservice"
    ).keys
    credential = account_keys[0].value

    azure_request = FromAzureBlobStorageRequest(
        uri=uri,
        credential=credential,
        cloud_provider="azure",
        region="westeurope",
        tag="my_azure_blob_storage_datablob_tag",
    )
    b = BackgroundTasks()

    # Run with the FastAPI job executor so the pull lands in `b` as a background task
    with set_env_variable_context(variable="JOB_EXECUTOR", value="fastapi"):
        actual = from_azure_blob_storage_route(
            from_azure_blob_storage_request=azure_request,
            user=user,
            session=session,
            background_tasks=b,
        )
    display(actual)

    assert actual.type == "azure_blob_storage"
    assert actual.source == uri, actual.source

    datablob = session.exec(select(DataBlob).where(DataBlob.uuid == actual.uuid)).one()

    # The most recently queued background task must run the azure pull for this datablob
    bg_task = b.tasks[-1]
    display(f"{bg_task.func=}", f"{bg_task.args=}", f"{bg_task.kwargs=}")
    assert bg_task.func == execute_cli
    expected_command = f"azure_blob_storage_pull {actual.id}"
    assert bg_task.kwargs["command"] == expected_command

[INFO] azure.identity._credentials.environment: Environment is configured for ClientSecretCredential
[INFO] azure.identity._credentials.managed_identity: ManagedIdentityCredential will use IMDS
[INFO] azure.identity._credentials.chained: DefaultAzureCredential acquired a token from EnvironmentCredential
[INFO] airt_service.batch_job: create_batch_job(): command='azure_blob_storage_pull 7', task='csv_processing'
[INFO] airt_service.batch_job_components.base: Entering FastAPIBatchJobContext(task=csv_processing)
[INFO] airt_service.batch_job: batch_ctx=FastAPIBatchJobContext(task=csv_processing)
[INFO] airt_service.batch_job_components.fastapi: FastAPIBatchJobContext.create_job(self=FastAPIBatchJobContext(task=csv_processing), command='azure_blob_storage_pull 7', environment_vars={'AWS_ACCESS_KEY_ID': '********************', 'AWS_SECRET_ACCESS_KEY': '****************************************', 'AWS_DEFAULT_REGION': 'eu-west-1', 'AZURE_SUBSCRIPTION_ID': '************************************

DataBlob(id=7, uuid=UUID('e10db2f9-c20d-4bd6-8704-76fc0c848218'), type='azure_blob_storage', uri='https://****************************************@testairtservice.blob.core.windows.net/test-container/account_312571_events', source='https://testairtservice.blob.core.windows.net/test-container/account_312571_events', total_steps=1, completed_steps=0, folder_size=None, cloud_provider=<CloudProvider.azure: 'azure'>, region='westeurope', error=None, disabled=False, path=None, created=datetime.datetime(2022, 11, 7, 9, 10, 27), user_id=6, pulled_on=None, tags=[Tag(id=2, created=datetime.datetime(2022, 11, 7, 9, 9, 59), name='latest', uuid=UUID('487ec313-1f51-498e-ad17-c29a9a10aa25')), Tag(id=5, created=datetime.datetime(2022, 11, 7, 9, 10, 27), name='my_azure_blob_storage_datablob_tag', uuid=UUID('ab876872-4eaf-455e-95ea-db7027f82b0b'))])

'bg_task.func=<function execute_cli at 0x7fbd812eef70>'

'bg_task.args=()'

"bg_task.kwargs={'command': 'azure_blob_storage_pull 7'}"

In [None]:
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    uri = "https://testairtservice.blob.core.windows.net/test-container/account_312571_events"

    storage_client = StorageManagementClient(
        DefaultAzureCredential(), os.environ["AZURE_SUBSCRIPTION_ID"]
    )
    keys = storage_client.storage_accounts.list_keys(
        "test-airt-service", "testairtservice"
    )
    credential = keys.keys[0].value

    from_azure_blob_storage_request = FromAzureBlobStorageRequest(
        uri=uri,
        credential=credential,
        cloud_provider="azure",
        region="region-does-not-exists",
        tag="my_azure_blob_storage_datablob_tag",
    )
    b = BackgroundTasks()

    # An unknown Azure region must be rejected with a 400 before a datablob is created
    with pytest.raises(HTTPException) as e:
        with set_env_variable_context(variable="JOB_EXECUTOR", value="fastapi"):
            from_azure_blob_storage_route(
                from_azure_blob_storage_request=from_azure_blob_storage_request,
                user=user,
                session=session,
                background_tasks=b,
            )
    display(e)
    assert e.value.status_code == 400, e.value.status_code
    assert "region-does-not-exists" in e.value.detail, e.value.detail

[INFO] azure.identity._credentials.environment: Environment is configured for ClientSecretCredential
[INFO] azure.identity._credentials.managed_identity: ManagedIdentityCredential will use IMDS
[INFO] azure.identity._credentials.chained: DefaultAzureCredential acquired a token from EnvironmentCredential


<ExceptionInfo HTTPException(status_code=400, detail='Unknown region - region-does-not-exists; Available regions are australiacentral...dnorth, switzerlandwest, uaecentral, uaenorth, uksouth, ukwest, westcentralus, westeurope, westindia, westus, westus2') tblen=4>

In [None]:
#| export


class DBRequest(BaseModel):
    """Base request object for from_db and to_db routes"""

    # Remote database host name
    host: str
    # Database server port
    port: int
    # Credentials used to access the database
    username: str
    password: str
    # Database and table to read from / write to
    database: str
    table: str

In [None]:
#| export


class FromDBRequest(DBRequest):
    """Request object for the RDBMS datablob routes such as /datablob/from_mysql

    Args:
        host: Remote database host name
        port: DB port
        username: Username to access the db
        password: Password to access the db
        database: Database to use
        table: Table to import data from
        cloud_provider: Cloud provider to save files
        region: Region of the cloud provider
        tag: Tag to add to the datablob
    """

    cloud_provider: CloudProvider = "aws"  # type: ignore
    region: str = "eu-west-1"
    tag: Optional[str] = None

    class Config:
        # Store the enum's plain string value ("aws"/"azure") instead of the member
        use_enum_values = True

In [None]:
#| exporti


@patch(cls_method=True)
def from_rdbms(
    cls: DataBlob,
    *,
    from_db_request: FromDBRequest,
    database_server: str,
    user: User,
    session: Session,
    background_tasks: BackgroundTasks,
) -> DataBlob:
    """Create a datablob from a RDBMS table and schedule the db_pull job for it

    Args:
        from_db_request: The from_db request
        database_server: Database engine name; only "mysql" and "postgresql" can be pulled
        user: User object
        session: Sqlmodel session
        background_tasks: BackgroundTasks object

    Returns:
        A new datablob created from a RDBMS

    Raises:
        HTTPException: 501 if `database_server` is not supported; errors from the region
            verification and from `DataBlob._create` are propagated as raised
    """
    # Reject unsupported servers before anything is written. Creating the datablob first
    # would persist a datablob that is never pulled and move the "latest"/user tags onto it.
    if database_server not in ("mysql", "postgresql"):
        raise HTTPException(
            status_code=status.HTTP_501_NOT_IMPLEMENTED,
            detail=f"{ERRORS['PULL_NOT_AVAILABLE']} for database server {database_server}",
        )

    host = from_db_request.host
    port = from_db_request.port
    table = from_db_request.table
    database = from_db_request.database

    uri = create_db_uri_for_db_datablob(
        username=from_db_request.username,
        password=from_db_request.password,
        host=host,
        port=port,
        table=table,
        database=database,
        database_server=database_server,
    )

    # Credential-free description of the data location
    source = f"{database_server}://{host}:{port}/{database}/{table}"
    if from_db_request.cloud_provider == "aws":
        verify_aws_region(from_db_request.region)
    else:
        verify_azure_region(from_db_request.region)

    with commit_or_rollback(session):
        datablob = DataBlob._create(  # type: ignore
            type="db",
            uri=uri,
            source=source,
            cloud_provider=from_db_request.cloud_provider,
            region=from_db_request.region,
            total_steps=1,
            user_tag=from_db_request.tag,
            user=user,
            session=session,
        )

    create_batch_job(
        command=f"db_pull {datablob.id}",
        task="csv_processing",
        cloud_provider=from_db_request.cloud_provider,
        region=from_db_request.region,
        background_tasks=background_tasks,
    )
    return datablob

In [None]:
#| export


@datablob_router.post(
    "/from_mysql", response_model=DataBlobRead, responses=create_datablob_responses  # type: ignore
)
def from_mysql_route(
    *,
    from_db_request: FromDBRequest,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
    background_tasks: BackgroundTasks,
) -> DataBlob:
    """Create a datablob from a database"""
    # The authenticated user comes from the auth dependency; attach it to this session
    session_user = session.merge(user)
    # Delegate to the generic RDBMS implementation with the engine fixed to MySQL
    created_datablob = DataBlob.from_rdbms(  # type: ignore
        from_db_request=from_db_request,
        database_server="mysql",
        user=session_user,
        session=session,
        background_tasks=background_tasks,
    )
    return created_datablob

In [None]:
# Create a MySQL datablob through the route, then check the stored source and
# the background task that was scheduled to pull the table.
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    database_server = "mysql"
    host = "db.example.com"
    port = 3306
    database = "database_to_import"
    table = "events"

    from_db_request = FromDBRequest(
        host=host,
        port=port,
        username="username",
        password="password",
        database=database,
        table=table,
        tag="my_db_datablob_tag",
    )
    b = BackgroundTasks()

    # Test using FastAPIBatchJobContext with set_env_variable_context
    with set_env_variable_context(variable="JOB_EXECUTOR", value="fastapi"):
        actual = from_mysql_route(
            from_db_request=from_db_request,
            user=user,
            session=session,
            background_tasks=b,
        )
    display(actual)

    assert actual.type == "db"

    # The source is the connection string without the username/password
    display(f"{actual.source=}")
    assert (
        actual.source == f"{database_server}://{host}:{port}/{database}/{table}"
    ), actual.source

    # .one() raises NoResultFound if the datablob was not persisted
    datablob = session.exec(select(DataBlob).where(DataBlob.uuid == actual.uuid)).one()

    # With JOB_EXECUTOR=fastapi the pull command is queued on the BackgroundTasks object
    bg_task = b.tasks[-1]
    display(f"{bg_task.func=}", f"{bg_task.args=}", f"{bg_task.kwargs=}")
    assert bg_task.func == execute_cli
    assert bg_task.kwargs["command"] == f"db_pull {actual.id}"

[INFO] airt_service.batch_job: create_batch_job(): command='db_pull 8', task='csv_processing'
[INFO] airt_service.batch_job_components.base: Entering FastAPIBatchJobContext(task=csv_processing)
[INFO] airt_service.batch_job: batch_ctx=FastAPIBatchJobContext(task=csv_processing)
[INFO] airt_service.batch_job_components.fastapi: FastAPIBatchJobContext.create_job(self=FastAPIBatchJobContext(task=csv_processing), command='db_pull 8', environment_vars={'AWS_ACCESS_KEY_ID': '********************', 'AWS_SECRET_ACCESS_KEY': '****************************************', 'AWS_DEFAULT_REGION': 'eu-west-1', 'AZURE_SUBSCRIPTION_ID': '************************************', 'AZURE_TENANT_ID': '************************************', 'AZURE_CLIENT_ID': '************************************', 'AZURE_CLIENT_SECRET': '****************************************', 'AZURE_STORAGE_ACCOUNT_PREFIX': 'kumsairtsdev', 'AZURE_RESOURCE_GROUP': 'kumaran-airt-service-dev', 'STORAGE_BUCKET_PREFIX': 'kumaran-airt-service', 

DataBlob(id=8, uuid=UUID('67a32ba1-37aa-425b-b887-b0b02ab9e432'), type='db', uri='mysql://****************************************@db.example.com:3306/database_to_import/events', source='mysql://db.example.com:3306/database_to_import/events', total_steps=1, completed_steps=0, folder_size=None, cloud_provider=<CloudProvider.aws: 'aws'>, region='eu-west-1', error=None, disabled=False, path=None, created=datetime.datetime(2022, 11, 7, 9, 10, 29), user_id=6, pulled_on=None, tags=[Tag(id=2, created=datetime.datetime(2022, 11, 7, 9, 9, 59), name='latest', uuid=UUID('487ec313-1f51-498e-ad17-c29a9a10aa25')), Tag(id=6, created=datetime.datetime(2022, 11, 7, 9, 10, 29), name='my_db_datablob_tag', uuid=UUID('8995cbb2-c7e7-4e03-9dd9-2d6756aaf8ca'))])

"actual.source='mysql://db.example.com:3306/database_to_import/events'"

'bg_task.func=<function execute_cli at 0x7fbd812eef70>'

'bg_task.args=()'

"bg_task.kwargs={'command': 'db_pull 8'}"

In [None]:
# Test the failure scenario where the db password is a very long string: the
# generated uri no longer fits the datablob `uri` column, so creating the
# datablob must fail with an HTTPException instead of an unhandled DB error.

with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    database_server = "mysql"
    host = "db.example.com"
    port = 3306
    database = "database_to_import"
    table = "events"

    from_db_request = FromDBRequest(
        host=host,
        port=port,
        username="username",
        password="ZAP %1!s%2!s%3!s%4!s%5!s%6!s%7!s%8!s%9!s%10!s%11!s%12!s%13!s%14!s%15!s%16!s%17!s%18!s%19!s%20!s%21!n%22!n%23!n%24!n%25!n%26!n%27!n%28!n%29!n%30!n%31!n%32!n%33!n%34!n%35!n%36!n%37!n%38!n%39!n%40!n",
        database=database,
        table=table,
        tag="my_db_datablob_tag",
    )
    b = BackgroundTasks()

    # Test using FastAPIBatchJobContext with set_env_variable_context
    with pytest.raises(HTTPException) as e:
        with set_env_variable_context(variable="JOB_EXECUTOR", value="fastapi"):
            from_mysql_route(
                from_db_request=from_db_request,
                user=user,
                session=session,
                background_tasks=b,
            )
    display(e)
    # The database error must be reported to the client as a 400 Bad Request,
    # not as any other HTTP error.
    assert e.value.status_code == status.HTTP_400_BAD_REQUEST, e.value

[ERROR] __main__: DataError('(MySQLdb.DataError) (1406, "Data too long for column \'uri\' at row 1")')
Traceback (most recent call last):
  File "/root/.local/lib/python3.8/site-packages/sqlalchemy/engine/base.py", line 1900, in _execute_context
    self.dialect.do_execute(
  File "/root/.local/lib/python3.8/site-packages/sqlalchemy/engine/default.py", line 736, in do_execute
    cursor.execute(statement, parameters)
  File "/root/.local/lib/python3.8/site-packages/MySQLdb/cursors.py", line 206, in execute
    res = self._query(query)
  File "/root/.local/lib/python3.8/site-packages/MySQLdb/cursors.py", line 319, in _query
    db.query(q)
  File "/root/.local/lib/python3.8/site-packages/MySQLdb/connections.py", line 254, in query
    _mysql.connection.query(self, query)
MySQLdb.DataError: (1406, "Data too long for column 'uri' at row 1")

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<ipython-input-11-53ff729c4413>", li

<ExceptionInfo HTTPException(status_code=400, detail='(MySQLdb.DataError) (1406, "Data too long for column \'uri\' at row 1")') tblen=4>

In [None]:
#| export


class ClickHouseRequest(DBRequest):
    """Base request object for from_clickhouse and to_clickhouse routes"""

    # Protocol used to talk to ClickHouse (e.g. "native" or "http"); it becomes
    # part of the "clickhouse+<protocol>://" uri/source scheme
    protocol: str

In [None]:
#| export


class FromClickHouseRequest(ClickHouseRequest):
    """Request object for the /datablob/from_clickhouse route

    Args:
        host: Hostname where db is hosted
        port: DB port
        username: Username to access the db
        password: Password to access the db
        database: Database to use
        table: Table to import/export data
        protocol: Protocol to use (native/http)
        index_column: Column to use to partition rows and to use as index
        timestamp_column: Timestamp column
        filters: Additional column filters as a dictionary; passed JSON-encoded
            to the pull job
        cloud_provider: Cloud provider to save files
        region: Region of the cloud provider
        tag: Tag to add to the datablob
    """

    index_column: str
    timestamp_column: str
    filters: Optional[Dict[str, Any]] = None
    cloud_provider: CloudProvider = "aws"  # type: ignore
    region: str = "eu-west-1"
    tag: Optional[str] = None

    class Config:
        # Store enum fields as plain values (e.g. "aws") rather than enum members
        use_enum_values = True

In [None]:
#| exporti


@patch(cls_method=True)
def from_clickhouse(
    cls: DataBlob,
    *,
    from_clickhouse_request: FromClickHouseRequest,
    user: User,
    session: Session,
    background_tasks: BackgroundTasks,
) -> DataBlob:
    """Create a datablob from a clickhouse database

    A datablob of type "db" is stored with a connection uri built from the
    request. A `clickhouse_pull` batch job is then scheduled to pull the table.

    Args:
        from_clickhouse_request: The from_clickhouse request
        user: User object
        session: Sqlmodel session
        background_tasks: BackgroundTasks object

    Returns:
        A new datablob created from a clickhouse database
    """
    host = from_clickhouse_request.host
    port = from_clickhouse_request.port
    table = from_clickhouse_request.table
    database = from_clickhouse_request.database
    protocol = from_clickhouse_request.protocol

    uri = create_db_uri_for_clickhouse_datablob(
        username=from_clickhouse_request.username,
        password=from_clickhouse_request.password,
        host=host,
        port=port,
        table=table,
        database=database,
        protocol=protocol,
    )

    # Unlike `uri`, `source` does not include the username/password.
    source = f"clickhouse+{protocol}://{host}:{port}/{database}/{table}"

    if from_clickhouse_request.cloud_provider == "aws":
        verify_aws_region(from_clickhouse_request.region)
    else:
        verify_azure_region(from_clickhouse_request.region)

    with commit_or_rollback(session):
        datablob = DataBlob._create(  # type: ignore
            type="db",
            uri=uri,
            source=source,
            cloud_provider=from_clickhouse_request.cloud_provider,
            region=from_clickhouse_request.region,
            total_steps=1,
            user_tag=from_clickhouse_request.tag,
            user=user,
            session=session,
        )

    # Every user-supplied value is shell-quoted, as the filters JSON already was.
    # Column names containing spaces or quotes therefore stay single arguments
    # and cannot inject extra options into the pull command.
    command_parts = [
        "clickhouse_pull",
        str(datablob.id),
        shlex.quote(from_clickhouse_request.index_column),
        shlex.quote(from_clickhouse_request.timestamp_column),
    ]
    if from_clickhouse_request.filters:
        command_parts += [
            "--filters_json",
            shlex.quote(json.dumps(from_clickhouse_request.filters)),
        ]
    command = " ".join(command_parts)

    create_batch_job(
        command=command,
        task="csv_processing",
        cloud_provider=from_clickhouse_request.cloud_provider,
        region=from_clickhouse_request.region,
        background_tasks=background_tasks,
    )
    return datablob

In [None]:
#| export


@datablob_router.post(
    "/from_clickhouse",
    response_model=DataBlobRead,
    responses=create_datablob_responses,  # type: ignore
)
def from_clickhouse_route(
    *,
    from_clickhouse_request: FromClickHouseRequest,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
    background_tasks: BackgroundTasks,
) -> DataBlob:
    """Create a datablob from a ClickHouse database"""
    # Attach the authenticated user to this request's session before use.
    session_user = session.merge(user)
    created = DataBlob.from_clickhouse(  # type: ignore
        from_clickhouse_request=from_clickhouse_request,
        user=session_user,
        session=session,
        background_tasks=background_tasks,
    )
    return created

In [None]:
# Create a ClickHouse datablob through the route, then check the stored source and
# the scheduled `clickhouse_pull` command (including the shell-quoted filters).
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    filters = {"AccountId": 312571}

    host = "db.example.com"
    port = 3306
    database = "database_to_import"
    table = "events"
    protocol = "native"

    from_clickhouse_request = FromClickHouseRequest(
        host=host,
        port=port,
        username="username",
        password="password",
        database=database,
        table=table,
        protocol=protocol,
        index_column="PersonId",
        timestamp_column="OccurredTimeTicks",
        filters=filters,
        tag="my_clickhouse_datablob_tag",
    )
    b = BackgroundTasks()

    # Test using FastAPIBatchJobContext with set_env_variable_context
    with set_env_variable_context(variable="JOB_EXECUTOR", value="fastapi"):
        actual = from_clickhouse_route(
            from_clickhouse_request=from_clickhouse_request,
            user=user,
            session=session,
            background_tasks=b,
        )
    display(actual)

    assert actual.type == "db"

    # The source is the connection string without the username/password
    display(f"{actual.source=}")
    assert (
        actual.source == f"clickhouse+{protocol}://{host}:{port}/{database}/{table}"
    ), actual.source

    # .one() raises NoResultFound if the datablob was not persisted
    datablob = session.exec(select(DataBlob).where(DataBlob.uuid == actual.uuid)).one()

    bg_task = b.tasks[-1]
    display(f"{bg_task.func=}", f"{bg_task.args=}", f"{bg_task.kwargs=}")
    assert bg_task.func == execute_cli
    # The filters dict is JSON-encoded and single-quoted for the command line
    assert (
        bg_task.kwargs["command"]
        == f"clickhouse_pull {actual.id} PersonId OccurredTimeTicks --filters_json '{json.dumps(filters)}'"
    )

[INFO] airt_service.batch_job: create_batch_job(): command='clickhouse_pull 9 PersonId OccurredTimeTicks --filters_json \'{"AccountId": 312571}\'', task='csv_processing'
[INFO] airt_service.batch_job_components.base: Entering FastAPIBatchJobContext(task=csv_processing)
[INFO] airt_service.batch_job: batch_ctx=FastAPIBatchJobContext(task=csv_processing)
[INFO] airt_service.batch_job_components.fastapi: FastAPIBatchJobContext.create_job(self=FastAPIBatchJobContext(task=csv_processing), command='clickhouse_pull 9 PersonId OccurredTimeTicks --filters_json \'{"AccountId": 312571}\'', environment_vars={'AWS_ACCESS_KEY_ID': '********************', 'AWS_SECRET_ACCESS_KEY': '****************************************', 'AWS_DEFAULT_REGION': 'eu-west-1', 'AZURE_SUBSCRIPTION_ID': '************************************', 'AZURE_TENANT_ID': '************************************', 'AZURE_CLIENT_ID': '************************************', 'AZURE_CLIENT_SECRET': '****************************************

DataBlob(id=9, uuid=UUID('6548dfad-9588-43ed-a789-264d02171de0'), type='db', uri='clickhouse+native://****************************************@db.example.com:3306/database_to_import/events', source='clickhouse+native://db.example.com:3306/database_to_import/events', total_steps=1, completed_steps=0, folder_size=None, cloud_provider=<CloudProvider.aws: 'aws'>, region='eu-west-1', error=None, disabled=False, path=None, created=datetime.datetime(2022, 11, 7, 9, 10, 29), user_id=6, pulled_on=None, tags=[Tag(id=2, created=datetime.datetime(2022, 11, 7, 9, 9, 59), name='latest', uuid=UUID('487ec313-1f51-498e-ad17-c29a9a10aa25')), Tag(id=7, created=datetime.datetime(2022, 11, 7, 9, 10, 29), name='my_clickhouse_datablob_tag', uuid=UUID('c362c5dd-f437-456c-8135-2563b49bfa70'))])

"actual.source='clickhouse+native://db.example.com:3306/database_to_import/events'"

'bg_task.func=<function execute_cli at 0x7fbd812eef70>'

'bg_task.args=()'

'bg_task.kwargs={\'command\': \'clickhouse_pull 9 PersonId OccurredTimeTicks --filters_json \\\'{"AccountId": 312571}\\\'\'}'

In [None]:
#| export


class FromLocalRequest(BaseModel):
    """Request object for the /datablob/from_local/start route

    Args:
        path: Local path of the datablob; stored as the datablob's source
        cloud_provider: Cloud provider to save files (the from_local route
            currently rejects "azure")
        region: Region of the cloud provider
        tag: Tag to add to the datablob
    """

    path: str
    cloud_provider: CloudProvider = "aws"  # type: ignore
    region: str = "eu-west-1"
    tag: Optional[str] = None

    class Config:
        # Store enum fields as plain values (e.g. "aws") rather than enum members
        use_enum_values = True


class FromLocalResponse(BaseModel):
    """Response object for the /datablob/from_local/start route

    Args:
        uuid: Datablob uuid
        type: Type of the datablob
        presigned: Presigned S3 POST url and form fields (valid for 24 hours)
            used to upload the local CSV/Parquet files
    """

    uuid: uuid_pkg.UUID
    type: str
    presigned: Dict[str, Any]

In [None]:
#| exporti


@patch(cls_method=True)
def from_local(
    cls: DataBlob,
    *,
    path: str,
    cloud_provider: str,
    region: str,
    user_tag: Optional[str] = None,
    user: User,
    session: Session,
) -> FromLocalResponse:
    """Create a datablob from local file(s)

    The datablob is registered first. A presigned S3 POST is then generated so
    the client can upload the files into the datablob's S3 folder.

    Args:
        path: The relative or absolute path to a local file or to a directory containing the files.
        cloud_provider: Cloud provider to store the files in; "azure" is rejected
        region: AWS region of the destination bucket
        user_tag: A string to tag the datablob
        user: User object
        session: Sqlmodel session

    Returns:
        The new datablob's uuid and type, together with the presigned upload
        url and form fields (valid for 24 hours)

    Raises:
        HTTPException: if cloud_provider is "azure"
    """
    # Uploads go through a presigned S3 POST, so Azure cannot be served here.
    if cloud_provider == "azure":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERRORS["AZURE_NOT_SUPPORTED"],
        )
    verify_aws_region(region)

    # Commit the datablob first: its id is needed to derive the S3 path.
    with commit_or_rollback(session):
        datablob = DataBlob._create(  # type: ignore
            type="local",
            uri=None,
            source=path,
            cloud_provider=cloud_provider,
            region=region,
            total_steps=1,
            user_tag=user_tag,
            user=user,
            session=session,
        )

    bucket, upload_prefix = create_s3_datablob_path(
        user_id=user.id, datablob_id=datablob.id, region=region  # type: ignore
    )
    local_uri = create_db_uri_for_local_datablob(bucket=bucket, s3_path=upload_prefix)
    with commit_or_rollback(session):
        datablob.uri = local_uri
        session.add(datablob)

    one_day_in_seconds = 60 * 60 * 24
    s3_client = boto3.client(
        "s3", region_name=region, config=Config(signature_version="s3v4")
    )
    # "${filename}" is left for S3 to substitute with the uploaded file's name.
    presigned = s3_client.generate_presigned_post(
        Bucket=bucket.name,
        Key=upload_prefix + "/" + "${filename}",
        Fields=None,
        Conditions=[["starts-with", "$key", upload_prefix]],
        ExpiresIn=one_day_in_seconds,
    )

    response = FromLocalResponse(
        uuid=datablob.uuid, type=datablob.type, presigned=presigned
    )
    logger.info(f"DataBlob.from_local(): {response!r}")
    return response

In [None]:
#| export


@datablob_router.post(
    "/from_local/start",
    response_model=FromLocalResponse,
    responses=create_datablob_responses,  # type: ignore
)
def from_local_start_route(
    from_local_request: FromLocalRequest,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
) -> FromLocalResponse:
    """Get presigned s3 url to upload local CSV/Parquet files and create datablob from it"""
    # Attach the authenticated user to this request's session before use.
    session_user = session.merge(user)
    request = from_local_request
    response = DataBlob.from_local(  # type: ignore
        path=request.path,
        cloud_provider=request.cloud_provider,
        region=request.region,
        user_tag=request.tag,
        user=session_user,
        session=session,
    )
    return response

In [None]:
test_path = "tmp/test-folder/"
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    from_local_request = FromLocalRequest(path=test_path, tag="my_csv_datablob_tag")
    actual = from_local_start_route(
        from_local_request=from_local_request,
        user=user,
        session=session,
    )
    display(actual)
    assert actual.uuid
    assert actual.type == "local"
    assert isinstance(actual.presigned, dict)

    # The datablob must be persisted with the local path as its source.
    # Later cells also use `datablob_csv`.
    try:
        datablob_csv = session.exec(
            select(DataBlob).where(DataBlob.uuid == actual.uuid)
        ).one()
    except NoResultFound:
        raise AssertionError(f"DataBlob {actual.uuid} was not stored in the database")
    assert datablob_csv.source == test_path, datablob_csv.source

    # Convert the sample parquet dataset to CSV files and upload them using the
    # presigned POST returned above.
    with RemotePath.from_url(
        remote_url="s3://test-airt-service/account_312571_events",
        pull_on_enter=True,
        push_on_exit=False,
        exist_ok=True,
        parents=False,
        access_key=os.environ["AWS_ACCESS_KEY_ID"],
        secret_key=os.environ["AWS_SECRET_ACCESS_KEY"],
    ) as test_s3_path:
        ddf = dd.read_parquet(test_s3_path.as_path())
        ddf.to_csv(test_s3_path.as_path() / "csv" / "file-*.csv", index=False)
        display(list((test_s3_path.as_path() / "csv").glob("*")))
        # NOTE(review): unexplained fixed wait before uploading; confirm it is needed
        sleep(10)

        for csv_to_upload in sorted((test_s3_path.as_path() / "csv").glob("*.csv")):
            display(f"Uploading {csv_to_upload}")
            upload_to_s3_with_retry(
                csv_to_upload, actual.presigned["url"], actual.presigned["fields"]
            )

[INFO] botocore.credentials: Found credentials in environment variables.
[INFO] __main__: DataBlob.from_local(): FromLocalResponse(uuid=UUID('a17e6c65-4c91-4e24-a89f-dd1fc23a7e23'), type='local', presigned={'url': 'https://kumaran-airt-service-eu-west-1.s3.amazonaws.com/', 'fields': {'key': '****************************************', 'x-amz-algorithm': 'AWS4-HMAC-SHA256', 'x-amz-credential': '********************/20221107/eu-west-1/s3/aws4_request', 'x-amz-date': '20221107T091029Z', 'policy': '************************************************************************************************************************************************************************************************************************************************************', 'x-amz-signature': '****************************'}})


FromLocalResponse(uuid=UUID('a17e6c65-4c91-4e24-a89f-dd1fc23a7e23'), type='local', presigned={'url': 'https://kumaran-airt-service-eu-west-1.s3.amazonaws.com/', 'fields': {'key': '****************************************', 'x-amz-algorithm': 'AWS4-HMAC-SHA256', 'x-amz-credential': '********************/20221107/eu-west-1/s3/aws4_request', 'x-amz-date': '20221107T091029Z', 'policy': '************************************************************************************************************************************************************************************************************************************************************', 'x-amz-signature': '****************************'}})

[INFO] airt.remote_path: RemotePath.from_url(): creating remote path with the following url s3://test-airt-service/account_312571_events
[INFO] airt.remote_path: S3Path._create_cache_path(): created cache path: /tmp/s3test-airt-serviceaccount_312571_events_cached_b8j5dnb2
[INFO] airt.remote_path: S3Path.__init__(): created object for accessing s3://test-airt-service/account_312571_events locally in /tmp/s3test-airt-serviceaccount_312571_events_cached_b8j5dnb2
[INFO] airt.remote_path: S3Path.__enter__(): pulling data from s3://test-airt-service/account_312571_events to /tmp/s3test-airt-serviceaccount_312571_events_cached_b8j5dnb2


[Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_b8j5dnb2/csv/file-4.csv'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_b8j5dnb2/csv/file-3.csv'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_b8j5dnb2/csv/file-2.csv'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_b8j5dnb2/csv/file-0.csv'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_b8j5dnb2/csv/file-1.csv')]

'Uploading /tmp/s3test-airt-serviceaccount_312571_events_cached_b8j5dnb2/csv/file-0.csv'

'Uploading /tmp/s3test-airt-serviceaccount_312571_events_cached_b8j5dnb2/csv/file-1.csv'

'Uploading /tmp/s3test-airt-serviceaccount_312571_events_cached_b8j5dnb2/csv/file-2.csv'

'Uploading /tmp/s3test-airt-serviceaccount_312571_events_cached_b8j5dnb2/csv/file-3.csv'

'Uploading /tmp/s3test-airt-serviceaccount_312571_events_cached_b8j5dnb2/csv/file-4.csv'

[INFO] airt.remote_path: S3Path._clean_up(): removing local cache path /tmp/s3test-airt-serviceaccount_312571_events_cached_b8j5dnb2


In [None]:
#| exporti

# OpenAPI error responses documented for routes that look up a datablob by uuid
get_datablob_responses = {
    400: {"model": HTTPError, "description": ERRORS["INCORRECT_DATABLOB_ID"]},
    422: {"model": HTTPError, "description": "DataBlob error"},
}


@patch(cls_method=True)
def get(cls: DataBlob, uuid: str, user: User, session: Session) -> DataBlob:
    """Fetch one of the user's datablobs by uuid

    Args:
        uuid: Datablob uuid
        user: User object; only datablobs owned by this user are found
        session: Sqlmodel session

    Returns:
        The matching datablob

    Raises:
        HTTPException: 400 if no datablob with this uuid belongs to the user,
            400 if the datablob is disabled (deleted), and 422 carrying the
            stored error message if the datablob recorded an error
    """
    owned_by_user = select(DataBlob).where(DataBlob.uuid == uuid, DataBlob.user == user)
    try:
        found = session.exec(owned_by_user).one()
    except NoResultFound:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERRORS["INCORRECT_DATABLOB_ID"],
        )

    # A disabled datablob is reported as deleted before any stored error.
    if found.disabled:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERRORS["DATABLOB_IS_DELETED"],
        )
    if found.error is not None:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=found.error
        )
    return found

In [None]:
# DataBlob.get: the owner can fetch the datablob. An unknown uuid and another
# user's datablob are both rejected with the same "incorrect uuid" error.
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    expected = user.datablobs[0]

    actual = DataBlob.get(uuid=expected.uuid, user=user, session=session)
    display(actual)
    assert actual == expected

    # Nonexistent uuid
    with pytest.raises(HTTPException) as e:
        DataBlob.get(
            uuid="00000000-0000-0000-0000-000000000000", user=user, session=session
        )
    display(e)

    # Existing uuid, but owned by a different user
    user_kumaran = session.exec(select(User).where(User.username == "kumaran")).one()
    with pytest.raises(HTTPException) as e:
        DataBlob.get(uuid=expected.uuid, user=user_kumaran, session=session)
    display(e)

DataBlob(id=3, uuid=UUID('5f83b4d0-4f0c-46a6-84f6-a79523e77d55'), type='s3', uri='s3://****************************************@bucket', source='s3://bucket', total_steps=1, completed_steps=0, folder_size=None, cloud_provider=<CloudProvider.aws: 'aws'>, region='eu-west-1', error=None, disabled=False, path=None, created=datetime.datetime(2022, 11, 7, 9, 10, 5), user_id=6, pulled_on=None, tags=[])

<ExceptionInfo HTTPException(status_code=400, detail='The datablob uuid is incorrect. Please try again.') tblen=2>

<ExceptionInfo HTTPException(status_code=400, detail='The datablob uuid is incorrect. Please try again.') tblen=2>

In [None]:
# DataBlob.get must reject a datablob that has been disabled (soft-deleted).
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()

    # Insert a disabled datablob directly, bypassing the creation helpers
    datablob_disabled = DataBlob(
        type="s3",
        uri=create_db_uri_for_s3_datablob(
            uri="s3://", access_key="access", secret_key="secret"
        ),
        source="s3://",
        cloud_provider="aws",
        region="eu-west-1",
        total_steps=1,
        user=user,
        disabled=True,
    )
    session.add(datablob_disabled)
    session.commit()
    session.refresh(datablob_disabled)

    with pytest.raises(HTTPException) as e:
        DataBlob.get(uuid=datablob_disabled.uuid, user=user, session=session)
    display(e)

<ExceptionInfo HTTPException(status_code=400, detail='The datablob has already been deleted.') tblen=2>

In [None]:
#| exporti


@patch
def is_ready(self: DataBlob):
    """Make sure the datablob's files are available for processing

    The datablob counts as ready when all of its steps are completed.
    Otherwise, its S3 folder must contain at least one object. That folder is
    `self.path` when set, or else the default datablob path for this user and
    datablob.

    Raises:
        HTTPException: 400 if the steps are not completed and no files exist
            under the datablob's S3 folder
    """
    if self.completed_steps == self.total_steps:
        return

    if self.path:
        bucket, s3_path = get_s3_bucket_and_path_from_uri(self.path)  # type: ignore
    else:
        bucket, s3_path = create_s3_datablob_path(user_id=self.user.id, datablob_id=self.id, region=self.region)  # type: ignore

    # Only existence matters. Ask for at most one key instead of materialising
    # the full listing, which pages through every object under the prefix.
    first_object = bucket.objects.filter(Prefix=s3_path + "/").limit(1)
    if not any(True for _ in first_object):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERRORS["DATABLOB_CSV_FILES_NOT_AVAILABLE"],
        )

In [None]:
#| export


class FileType(str, Enum):
    """Supported file formats of datablob files processed by to_datasource"""

    csv = "csv"
    parquet = "parquet"


class ToDataSourceRequest(BaseModel):
    """Request object for the /datablob/{datablob_uuid}/to_datasource route

    Args:
        file_type: type of files in datablob; currently csv and parquet are supported
        deduplicate_data: If set to True (default value False), then duplicate rows are removed while processing
        index_column: Name of the column used to index and partition the data into partitions
        sort_by: Name of the column or list of columns used to sort data within the same index value
        blocksize: Size of partition
        kwargs: Keyword arguments which are passed to the **dask.dataframe.read_csv()** function,
            typically params for the underlying **pd.read_csv()** from Pandas.
    """

    file_type: FileType
    deduplicate_data: bool = False
    index_column: str
    sort_by: Union[str, List[str]]
    blocksize: str = "256MB"
    kwargs: Optional[Dict[str, Any]] = None

    class Config:
        # Store file_type as its plain string value ("csv"/"parquet")
        use_enum_values = True

In [None]:
#| exporti


@patch
def to_datasource(
    self: DataBlob,
    to_datasource_request: ToDataSourceRequest,
    user: User,
    session: Session,
    background_tasks: BackgroundTasks,
) -> DataSource:
    """Process the CSV/Parquet datablob files and return a datasource object

    Args:
        to_datasource_request: The to_datasource_request object
        user: User object
        session: Sqlmodel session
        background_tasks: BackgroundTasks object

    Returns:
        The newly created datasource; a batch job is scheduled to process the
        datablob files into it

    Raises:
        HTTPException: if the datablob files are not available (from
            `is_ready`), or if the file type is not supported
    """
    self.is_ready()  # type: ignore

    # Resolve the processing command before creating the datasource, so an
    # unsupported file type does not leave a dangling datasource behind.
    file_type = to_datasource_request.file_type
    if file_type == "csv":
        process_command = "process_csv"
    elif file_type == "parquet":
        process_command = "process_parquet"
    else:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Unsupported file_type: {file_type}",
        )

    datasource = DataSource._create(datablob=self, user=user, session=session)  # type: ignore

    sort_by = (
        [to_datasource_request.sort_by]
        if isinstance(to_datasource_request.sort_by, str)
        else to_datasource_request.sort_by
    )

    # Every user-supplied value, including blocksize, is shell-quoted so it
    # stays a single argument and cannot inject extra options into the command.
    command = (
        f"{process_command} {self.id} {datasource.id}"
        f" {shlex.quote(to_datasource_request.index_column)}"
        f" {shlex.quote(json.dumps(sort_by))}"
        f" --blocksize {shlex.quote(to_datasource_request.blocksize)}"
    )
    if to_datasource_request.kwargs is not None:
        command = (
            command
            + f" --kwargs_json {shlex.quote(json.dumps(to_datasource_request.kwargs))}"
        )
    if to_datasource_request.deduplicate_data:
        command = command + " --deduplicate_data"

    create_batch_job(
        command=command,
        task="csv_processing",
        cloud_provider=self.cloud_provider,
        region=self.region,
        background_tasks=background_tasks,
    )

    return datasource

In [None]:
#| export


@datablob_router.post(
    "/{datablob_uuid}/to_datasource",
    status_code=status.HTTP_202_ACCEPTED,
    response_model=DataSourceRead,
    responses=get_datablob_responses,  # type: ignore
)
def to_datasource_route(
    *,
    datablob_uuid: str,
    to_datasource_request: ToDataSourceRequest,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
    background_tasks: BackgroundTasks,
) -> DataSource:
    """Pull uploaded CSV/Parquet datablob, process it and store in s3 client storage bucket as parquet"""
    # Attach the authenticated user to this request's session before use.
    session_user = session.merge(user)
    source_blob = DataBlob.get(uuid=datablob_uuid, user=session_user, session=session)  # type: ignore
    new_datasource = source_blob.to_datasource(
        to_datasource_request, session_user, session, background_tasks
    )
    return new_datasource

In [None]:
# Test using FastAPIBatchJobContext with set_env_variable_context

with set_env_variable_context(variable="JOB_EXECUTOR", value="fastapi"):
    file_type = "csv"
    deduplicate_data = True
    index_column = "PersonId"
    sort_by = "OccurredTime"
    blocksize = "256MB"
    kwargs = dict(
        usecols=[0, 1, 2, 3, 4],
        parse_dates=["OccurredTime"],
    )

    with get_session_with_context() as session:
        user = session.exec(select(User).where(User.username == test_username)).one()
        to_datasource_request = ToDataSourceRequest(
            file_type=file_type,
            deduplicate_data=deduplicate_data,
            index_column=index_column,
            sort_by=sort_by,
            blocksize=blocksize,
            kwargs=kwargs,
        )
        b = BackgroundTasks()
        actual = to_datasource_route(
            datablob_uuid=datablob_csv.uuid,
            to_datasource_request=to_datasource_request,
            user=user,
            session=session,
            background_tasks=b,
        )
        display(actual)
        assert isinstance(actual, DataSource)

        # The last queued background task must run execute_cli with the processing command.
        bg_task = b.tasks[-1]
        display(f"{bg_task.func=}", f"{bg_task.args=}", f"{bg_task.kwargs=}")
        assert bg_task.func == execute_cli

        # Build the expectation with shlex.quote so it mirrors how the command is assembled.
        expected_command = (
            f"process_csv {datablob_csv.id} {actual.id} {shlex.quote(index_column)} "
            f"{shlex.quote(json.dumps([sort_by]))} --blocksize {shlex.quote(blocksize)} "
            f"--kwargs_json {shlex.quote(json.dumps(kwargs))} --deduplicate_data"
        )
        assert bg_task.kwargs["command"] == expected_command, bg_task.kwargs["command"]

[INFO] airt_service.batch_job: create_batch_job(): command='process_csv 10 5 PersonId \'["OccurredTime"]\' --blocksize 256MB --kwargs_json \'{"usecols": [0, 1, 2, 3, 4], "parse_dates": ["OccurredTime"]}\' --deduplicate_data', task='csv_processing'
[INFO] airt_service.batch_job_components.base: Entering FastAPIBatchJobContext(task=csv_processing)
[INFO] airt_service.batch_job: batch_ctx=FastAPIBatchJobContext(task=csv_processing)
[INFO] airt_service.batch_job_components.fastapi: FastAPIBatchJobContext.create_job(self=FastAPIBatchJobContext(task=csv_processing), command='process_csv 10 5 PersonId \'["OccurredTime"]\' --blocksize 256MB --kwargs_json \'{"usecols": [0, 1, 2, 3, 4], "parse_dates": ["OccurredTime"]}\' --deduplicate_data', environment_vars={'AWS_ACCESS_KEY_ID': '********************', 'AWS_SECRET_ACCESS_KEY': '****************************************', 'AWS_DEFAULT_REGION': 'eu-west-1', 'AZURE_SUBSCRIPTION_ID': '************************************', 'AZURE_TENANT_ID': '******

DataSource(id=5, uuid=UUID('4c06889f-a615-4bc5-affa-a5f91ca29c8e'), hash=None, total_steps=1, completed_steps=0, folder_size=None, no_of_rows=None, cloud_provider=<CloudProvider.aws: 'aws'>, region='eu-west-1', error=None, disabled=False, path=None, created=datetime.datetime(2022, 11, 7, 9, 10, 59), user_id=6, pulled_on=None, tags=[Tag(id=2, created=datetime.datetime(2022, 11, 7, 9, 9, 59), name='latest', uuid=UUID('487ec313-1f51-498e-ad17-c29a9a10aa25')), Tag(id=8, created=datetime.datetime(2022, 11, 7, 9, 10, 29), name='my_csv_datablob_tag', uuid=UUID('c6f8bc58-8142-429e-a29f-25b6120fff1d'))])

'bg_task.func=<function execute_cli at 0x7fbd812eef70>'

'bg_task.args=()'

'bg_task.kwargs={\'command\': \'process_csv 10 5 PersonId \\\'["OccurredTime"]\\\' --blocksize 256MB --kwargs_json \\\'{"usecols": [0, 1, 2, 3, 4], "parse_dates": ["OccurredTime"]}\\\' --deduplicate_data\'}'

In [None]:
test_path = "tmp/test-folder/"
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    from_local_request = FromLocalRequest(path=test_path, tag="my_parquet_datablob_tag")
    actual = from_local_start_route(
        from_local_request=from_local_request,
        user=user,
        session=session,
    )

    # The local-upload route must have created a DataBlob row for the returned uuid.
    try:
        datablob_parquet = session.exec(
            select(DataBlob).where(DataBlob.uuid == actual.uuid)
        ).one()
    except NoResultFound:
        raise AssertionError(
            f"DataBlob {actual.uuid} was not created by from_local_start_route"
        ) from None
    assert datablob_parquet.source == test_path, datablob_parquet.source

    # Upload the parquet test files from the shared test bucket via the presigned POST.
    with RemotePath.from_url(
        remote_url="s3://test-airt-service/account_312571_events",
        pull_on_enter=True,
        push_on_exit=False,
        exist_ok=True,
        parents=False,
        access_key=os.environ["AWS_ACCESS_KEY_ID"],
        secret_key=os.environ["AWS_SECRET_ACCESS_KEY"],
    ) as test_s3_path:
        display(list(test_s3_path.as_path().glob("*")))
        # NOTE(review): unexplained fixed wait — confirm whether it is still needed.
        sleep(10)

        for parquet_to_upload in sorted(test_s3_path.as_path().glob("*")):
            display(f"Uploading {parquet_to_upload}")
            upload_to_s3_with_retry(
                parquet_to_upload, actual.presigned["url"], actual.presigned["fields"]
            )

    # Test using FastAPIBatchJobContext with set_env_variable_context
    with set_env_variable_context(variable="JOB_EXECUTOR", value="fastapi"):
        file_type = "parquet"
        deduplicate_data = True
        index_column = "PersonId"
        sort_by = "OccurredTime"
        blocksize = "256MB"
        kwargs = dict(
            usecols=[0, 1, 2, 3, 4],
            parse_dates=["OccurredTime"],
        )
        to_datasource_request = ToDataSourceRequest(
            file_type=file_type,
            deduplicate_data=deduplicate_data,
            index_column=index_column,
            sort_by=sort_by,
            blocksize=blocksize,
            kwargs=kwargs,
        )
        b = BackgroundTasks()
        actual = to_datasource_route(
            datablob_uuid=datablob_parquet.uuid,
            to_datasource_request=to_datasource_request,
            user=user,
            session=session,
            background_tasks=b,
        )
        display(actual)
        assert isinstance(actual, DataSource)

        bg_task = b.tasks[-1]
        display(f"{bg_task.func=}", f"{bg_task.args=}", f"{bg_task.kwargs=}")
        assert bg_task.func == execute_cli

        # Build the expectation with shlex.quote so it mirrors how the command is assembled.
        expected_command = (
            f"process_parquet {datablob_parquet.id} {actual.id} {shlex.quote(index_column)} "
            f"{shlex.quote(json.dumps([sort_by]))} --blocksize {shlex.quote(blocksize)} "
            f"--kwargs_json {shlex.quote(json.dumps(kwargs))} --deduplicate_data"
        )
        assert bg_task.kwargs["command"] == expected_command, bg_task.kwargs["command"]

[INFO] __main__: DataBlob.from_local(): FromLocalResponse(uuid=UUID('5e6fb117-e9ad-450a-b1dd-5bf6e272f789'), type='local', presigned={'url': 'https://kumaran-airt-service-eu-west-1.s3.amazonaws.com/', 'fields': {'key': '****************************************', 'x-amz-algorithm': 'AWS4-HMAC-SHA256', 'x-amz-credential': '********************/20221107/eu-west-1/s3/aws4_request', 'x-amz-date': '20221107T091100Z', 'policy': '************************************************************************************************************************************************************************************************************************************************************', 'x-amz-signature': '****************************'}})
[INFO] airt.remote_path: RemotePath.from_url(): creating remote path with the following url s3://test-airt-service/account_312571_events
[INFO] airt.remote_path: S3Path._create_cache_path(): created cache path: /tmp/s3test-airt-serviceaccount_312571_events_cached_3o

[Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_3otoccd0/_metadata'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_3otoccd0/_common_metadata'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_3otoccd0/part.3.parquet'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_3otoccd0/part.0.parquet'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_3otoccd0/part.1.parquet'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_3otoccd0/part.4.parquet'),
 Path('/tmp/s3test-airt-serviceaccount_312571_events_cached_3otoccd0/part.2.parquet')]

'Uploading /tmp/s3test-airt-serviceaccount_312571_events_cached_3otoccd0/_common_metadata'

'Uploading /tmp/s3test-airt-serviceaccount_312571_events_cached_3otoccd0/_metadata'

'Uploading /tmp/s3test-airt-serviceaccount_312571_events_cached_3otoccd0/part.0.parquet'

'Uploading /tmp/s3test-airt-serviceaccount_312571_events_cached_3otoccd0/part.1.parquet'

'Uploading /tmp/s3test-airt-serviceaccount_312571_events_cached_3otoccd0/part.2.parquet'

'Uploading /tmp/s3test-airt-serviceaccount_312571_events_cached_3otoccd0/part.3.parquet'

'Uploading /tmp/s3test-airt-serviceaccount_312571_events_cached_3otoccd0/part.4.parquet'

[INFO] airt.remote_path: S3Path._clean_up(): removing local cache path /tmp/s3test-airt-serviceaccount_312571_events_cached_3otoccd0
[INFO] airt_service.batch_job: create_batch_job(): command='process_parquet 12 7 PersonId \'["OccurredTime"]\' --blocksize 256MB --kwargs_json \'{"usecols": [0, 1, 2, 3, 4], "parse_dates": ["OccurredTime"]}\' --deduplicate_data', task='csv_processing'
[INFO] airt_service.batch_job_components.base: Entering FastAPIBatchJobContext(task=csv_processing)
[INFO] airt_service.batch_job: batch_ctx=FastAPIBatchJobContext(task=csv_processing)
[INFO] airt_service.batch_job_components.fastapi: FastAPIBatchJobContext.create_job(self=FastAPIBatchJobContext(task=csv_processing), command='process_parquet 12 7 PersonId \'["OccurredTime"]\' --blocksize 256MB --kwargs_json \'{"usecols": [0, 1, 2, 3, 4], "parse_dates": ["OccurredTime"]}\' --deduplicate_data', environment_vars={'AWS_ACCESS_KEY_ID': '********************', 'AWS_SECRET_ACCESS_KEY': '****************************

DataSource(id=7, uuid=UUID('bb1928aa-9213-4af9-86c1-3a8f64bebb58'), hash=None, total_steps=1, completed_steps=0, folder_size=None, no_of_rows=None, cloud_provider=<CloudProvider.aws: 'aws'>, region='eu-west-1', error=None, disabled=False, path=None, created=datetime.datetime(2022, 11, 7, 9, 11, 28), user_id=6, pulled_on=None, tags=[Tag(id=2, created=datetime.datetime(2022, 11, 7, 9, 9, 59), name='latest', uuid=UUID('487ec313-1f51-498e-ad17-c29a9a10aa25')), Tag(id=9, created=datetime.datetime(2022, 11, 7, 9, 10, 59), name='my_parquet_datablob_tag', uuid=UUID('2d04b349-5f76-4988-bc98-47241abc0022'))])

'bg_task.func=<function execute_cli at 0x7fbd812eef70>'

'bg_task.args=()'

'bg_task.kwargs={\'command\': \'process_parquet 12 7 PersonId \\\'["OccurredTime"]\\\' --blocksize 256MB --kwargs_json \\\'{"usecols": [0, 1, 2, 3, 4], "parse_dates": ["OccurredTime"]}\\\' --deduplicate_data\'}'

In [None]:
#| export


@datablob_router.get(
    "/{datablob_uuid}",
    response_model=DataBlobRead,
    responses=get_datablob_responses,  # type: ignore
)
def get_details_of_datablob(
    datablob_uuid: str,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
) -> DataBlob:
    """Get details of the datablob"""
    # session.merge yields the copy of the authenticated user bound to this session,
    # which is then passed to DataBlob.get for the lookup.
    session_user = session.merge(user)
    return DataBlob.get(uuid=datablob_uuid, user=session_user, session=session)  # type: ignore

In [None]:
# Fetching an existing datablob through the route must return the same object
# that the ORM relationship `user.datablobs` exposes.
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    # first datablob of the test user, loaded through the relationship
    expected = user.datablobs[0]

    actual = get_details_of_datablob(
        datablob_uuid=expected.uuid, user=user, session=session
    )
    display(actual)
    assert actual == expected

DataBlob(id=3, uuid=UUID('5f83b4d0-4f0c-46a6-84f6-a79523e77d55'), type='s3', uri='s3://****************************************@bucket', source='s3://bucket', total_steps=1, completed_steps=0, folder_size=None, cloud_provider=<CloudProvider.aws: 'aws'>, region='eu-west-1', error=None, disabled=False, path=None, created=datetime.datetime(2022, 11, 7, 9, 10, 5), user_id=6, pulled_on=None, tags=[])

In [None]:
# DataBlob errored out while importing
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()

    datablob_errored = DataBlob(
        type="s3",
        uri=create_db_uri_for_s3_datablob(
            uri="wrong_uri",
            access_key="wrong_access_key",
            secret_key="wrong_secret_key",
        ),
        source="wrong_uri",
        cloud_provider="aws",
        region="eu-west-1",
        total_steps=1,
        user=user,
        error="test error",
    )
    session.add(datablob_errored)
    session.commit()
    session.refresh(datablob_errored)

    # Requesting a datablob whose `error` field is set must surface that stored
    # error to the caller as an HTTP 422 instead of returning the datablob.
    with pytest.raises(HTTPException) as e:
        get_details_of_datablob(
            datablob_uuid=datablob_errored.uuid, user=user, session=session
        )
    display(e)
    assert e.value.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY, e.value
    assert e.value.detail == "test error", e.value.detail

<ExceptionInfo HTTPException(status_code=422, detail='test error') tblen=3>

In [None]:
#| exporti


@patch
def delete(self: DataBlob, user: User, session: Session):
    """Soft-delete this datablob.

    The stored files are removed via `delete_data_object_files_in_cloud`, and the
    database row is kept but flagged with `disabled = True`.

    Args:
        user: User object (not used by this method)
        session: Sqlmodel session used to persist the change

    Returns:
        The same datablob instance, now disabled
    """
    blob = self
    delete_data_object_files_in_cloud(data_object=blob)
    blob.disabled = True
    with commit_or_rollback(session):
        session.add(blob)
    return blob

In [None]:
#| export


@datablob_router.delete(
    "/{datablob_uuid}",
    response_model=DataBlobRead,
    responses=get_datablob_responses,  # type: ignore
)
def delete_datablob(
    datablob_uuid: str,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
) -> DataBlob:
    """Delete datablob"""
    # Bind the authenticated user to this request's session before the lookup.
    session_user = session.merge(user)
    doomed = DataBlob.get(uuid=datablob_uuid, user=session_user, session=session)  # type: ignore
    # DataBlob.delete flags the row as disabled and returns it.
    return doomed.delete(user=session_user, session=session)

In [None]:
# Import a real S3 datablob, pull it, then delete it through the route.
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    from_s3_request = FromS3Request(
        uri="s3://test-airt-service/account_312571_events",
        access_key=os.environ["AWS_ACCESS_KEY_ID"],
        secret_key=os.environ["AWS_SECRET_ACCESS_KEY"],
    )
    with set_env_variable_context(variable="JOB_EXECUTOR", value="fastapi"):
        datablob = from_s3_route(
            from_s3_request=from_s3_request,
            user=user,
            session=session,
            background_tasks=BackgroundTasks(),
        )

    s3_pull(datablob_id=datablob.id)

with get_session_with_context() as session:
    # Re-read the row in a fresh session so the state shown is what is stored in the DB.
    datablob = session.exec(select(DataBlob).where(DataBlob.id == datablob.id)).one()
    display(datablob)
    assert not datablob.disabled
    actual = delete_datablob(datablob_uuid=datablob.uuid, user=user, session=session)
    display(actual)
    assert actual.disabled

[{'dag_id': 's3_pull-14', 'run_id': 'airt-service__2022-11-07T09:11:37.727383', 'state': 'running', 'execution_date': '2022-11-07T09:11:38+00:00', 'start_date': '2022-11-07T09:11:39.079337+00:00', 'end_date': ''}]
[INFO] airt.remote_path: RemotePath.from_url(): creating remote path with the following url s3://kumaran-airt-service-eu-west-3/6/datablob/14
[INFO] airt.remote_path: S3Path._create_cache_path(): created cache path: /tmp/s3kumaran-airt-service-eu-west-36datablob14_cached_m2ourj22
[INFO] airt.remote_path: S3Path.__init__(): created object for accessing s3://kumaran-airt-service-eu-west-3/6/datablob/14 locally in /tmp/s3kumaran-airt-service-eu-west-36datablob14_cached_m2ourj22
[INFO] airt.remote_path: RemotePath.from_url(): creating remote path with the following url s3://test-airt-service/account_312571_events
[INFO] airt.remote_path: S3Path._create_cache_path(): created cache path: /tmp/s3test-airt-serviceaccount_312571_events_cached_a27kyen9
[INFO] airt.remote_path: S3Path._

DataBlob(id=14, uuid=UUID('2ae92bb7-ae8d-4198-960f-58f0dbde4899'), type='s3', uri='s3://****************************************@test-airt-service/account_312571_events', source='s3://test-airt-service/account_312571_events', total_steps=1, completed_steps=1, folder_size=11219613, cloud_provider=<CloudProvider.aws: 'aws'>, region='eu-west-3', error=None, disabled=False, path='s3://kumaran-airt-service-eu-west-3/6/datablob/14', created=datetime.datetime(2022, 11, 7, 9, 11, 29), user_id=6, pulled_on=datetime.datetime(2022, 11, 7, 9, 11, 47), tags=[Tag(id=2, created=datetime.datetime(2022, 11, 7, 9, 9, 59), name='latest', uuid=UUID('487ec313-1f51-498e-ad17-c29a9a10aa25'))])

DataBlob(id=14, uuid=UUID('2ae92bb7-ae8d-4198-960f-58f0dbde4899'), type='s3', uri='s3://****************************************@test-airt-service/account_312571_events', source='s3://test-airt-service/account_312571_events', total_steps=1, completed_steps=1, folder_size=11219613, cloud_provider=<CloudProvider.aws: 'aws'>, region='eu-west-3', error=None, disabled=True, path='s3://kumaran-airt-service-eu-west-3/6/datablob/14', created=datetime.datetime(2022, 11, 7, 9, 11, 29), user_id=6, pulled_on=datetime.datetime(2022, 11, 7, 9, 11, 47), tags=[Tag(id=2, created=datetime.datetime(2022, 11, 7, 9, 9, 59), name='latest', uuid=UUID('487ec313-1f51-498e-ad17-c29a9a10aa25'))])

In [None]:
#| exporti


@patch(cls_method=True)
def get_all(
    cls: DataBlob,
    disabled: bool,
    completed: bool,
    offset: int,
    limit: int,
    user: User,
    session: Session,
) -> List[DataBlob]:
    """Get a page of datablobs created by the user

    Args:
        disabled: Whether to get disabled datablobs
        completed: Whether to include only datablobs which are successfully pulled from its source
        offset: Offset results by given integer
        limit: Limit results by given integer
        user: User object
        session: Sqlmodel session

    Returns:
        A list of datablob objects, ordered by ascending id
    """
    statement = (
        select(DataBlob)
        .where(DataBlob.user == user)
        .where(DataBlob.disabled == disabled)
    )
    if completed:
        # A datablob counts as completed once completed_steps has reached total_steps.
        statement = statement.where(DataBlob.completed_steps == DataBlob.total_steps)
    # OFFSET/LIMIT without ORDER BY has no guaranteed row order in SQL, so pages
    # could overlap or skip rows; order by primary key to keep paging deterministic.
    statement = statement.order_by(DataBlob.id).offset(offset).limit(limit)
    return session.exec(statement).all()

In [None]:
#| export


@datablob_router.get("/", response_model=List[DataBlobRead])
def get_all_datablobs(
    disabled: bool = False,
    completed: bool = False,
    offset: int = 0,
    # `le` is the upper-bound validator FastAPI enforces; `lte` is not a recognised
    # constraint and is silently swallowed as extra schema metadata.
    limit: int = Query(default=100, le=100),
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
) -> List[DataBlob]:
    """
    Get all datablobs created by user
    """
    user = session.merge(user)
    return DataBlob.get_all(  # type: ignore
        disabled=disabled,
        completed=completed,
        offset=offset,
        limit=limit,
        user=user,
        session=session,
    )

In [None]:
# `limit=1` must cap the result at a single datablob, and that datablob must be
# the same object as the first entry of the `user.datablobs` relationship.
# NOTE: the next cell reuses `user` and `session` from here.
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    actual = get_all_datablobs(
        disabled=False, completed=False, offset=0, limit=1, user=user, session=session
    )
    display(actual)

    assert len(actual) == 1
    assert isinstance(actual[0], DataBlob)
    assert actual[0] == user.datablobs[0], f"{actual[0]} != {user.datablobs[0]}"

[DataBlob(id=3, uuid=UUID('5f83b4d0-4f0c-46a6-84f6-a79523e77d55'), type='s3', uri='s3://****************************************@bucket', source='s3://bucket', total_steps=1, completed_steps=0, folder_size=None, cloud_provider=<CloudProvider.aws: 'aws'>, region='eu-west-1', error=None, disabled=False, path=None, created=datetime.datetime(2022, 11, 7, 9, 10, 5), user_id=6, pulled_on=None, tags=[])]

In [None]:
# The `disabled` and `completed` filters must be honoured by the listing route.
# A fresh session is opened here instead of reusing one from an earlier cell.
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()

    actual = get_all_datablobs(
        disabled=False, completed=False, offset=0, limit=10, user=user, session=session
    )
    display(f"{len(actual)=}")
    for datablob in actual:
        assert not datablob.disabled
    assert actual[0] == user.datablobs[0]

    actual = get_all_datablobs(
        disabled=True, completed=False, offset=0, limit=10, user=user, session=session
    )
    display(f"{len(actual)=}")
    for datablob in actual:
        assert datablob.disabled

    actual = get_all_datablobs(
        disabled=False, completed=True, offset=0, limit=10, user=user, session=session
    )
    display(f"{len(actual)=}")
    for datablob in actual:
        assert datablob.completed_steps == datablob.total_steps

'len(actual)=10'

'len(actual)=2'

'len(actual)=0'

In [None]:
#| exporti


@patch
def tag(self: DataBlob, tag_name: str, session: Session):
    """Attach the tag called `tag_name` to this datablob.

    Args:
        tag_name: Name of the tag to attach
        session: Sqlmodel session used for the lookup and the commit

    Returns:
        This datablob, with the tag appended to `tags`
    """
    found_tag = Tag.get_by_name(name=tag_name, session=session)  # type: ignore

    # Detach the tag from earlier datablobs first — presumably so a tag name
    # points at only one datablob at a time (confirm against
    # remove_tag_from_previous_datablobs).
    self.remove_tag_from_previous_datablobs(tag_name=found_tag.name, session=session)  # type: ignore
    self.tags.append(found_tag)

    with commit_or_rollback(session):
        session.add(self)
    return self

In [None]:
#| export


@datablob_router.post(
    "/{datablob_uuid}/tag",
    response_model=DataBlobRead,
    # Same documented error responses as the other single-datablob routes,
    # since this route also resolves the datablob through DataBlob.get.
    responses=get_datablob_responses,  # type: ignore
)
def tag_datablob(
    datablob_uuid: str,
    tag_to_create: TagCreate,
    user: User = Depends(get_current_active_user),
    session: Session = Depends(get_session),
) -> DataBlob:
    """Add tag to datablob"""
    # session.merge returns the copy of the authenticated user bound to this session.
    user = session.merge(user)
    datablob = DataBlob.get(uuid=datablob_uuid, user=user, session=session)  # type: ignore

    return datablob.tag(tag_name=tag_to_create.name, session=session)

In [None]:
# Tagging a datablob through the route must add a tag with the requested name.
with get_session_with_context() as session:
    user = session.exec(select(User).where(User.username == test_username)).one()
    datablob = user.datablobs[0]

    tag_name = "new_tag"
    tag_to_create = TagCreate(name=tag_name)
    actual = tag_datablob(
        datablob_uuid=datablob.uuid,
        tag_to_create=tag_to_create,
        user=user,
        session=session,
    )
    display(actual)
    # A generator keeps the loop variable out of the notebook namespace, so no
    # module-level `tag` name is rebound by this check.
    assert any(t.name == tag_name for t in actual.tags), actual.tags

DataBlob(id=3, uuid=UUID('5f83b4d0-4f0c-46a6-84f6-a79523e77d55'), type='s3', uri='s3://****************************************@bucket', source='s3://bucket', total_steps=1, completed_steps=0, folder_size=None, cloud_provider=<CloudProvider.aws: 'aws'>, region='eu-west-1', error=None, disabled=False, path=None, created=datetime.datetime(2022, 11, 7, 9, 10, 5), user_id=6, pulled_on=None, tags=[Tag(id=10, created=datetime.datetime(2022, 11, 7, 9, 12), name='new_tag', uuid=UUID('0c1b52f8-166d-4df2-a954-237d2b5eb067'))])