In [None]:
#| default_exp helpers

In [None]:
from airt.testing import activate_by_import

[INFO] airt.testing.activate_by_import: Testing environment activated.
[INFO] numexpr.utils: Note: NumExpr detected 16 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
[INFO] numexpr.utils: NumExpr defaulting to 8 threads.
[INFO] airt.keras.helpers: Using a single GPU #0 with memory_limit 1024 MB


In [None]:
#| export

import os
import random
import string
import re
from contextlib import contextmanager
from os import environ
from pathlib import Path
from typing import *

import pandas as pd
import requests
from fastcore.utils import *
from passlib.context import CryptContext
from sqlmodel import Session

from airt.logger import get_logger

In [None]:
import tempfile
from time import sleep

import pytest
import numpy as np
import dask.dataframe as dd
from sqlmodel import select
from _pytest.monkeypatch import MonkeyPatch

from airt.remote_path import RemotePath
from airt_service.data.datablob import FromLocalRequest, from_local_start_route
from airt_service.data.utils import create_db_uri_for_s3_datablob
from airt_service.db.models import (
    create_user_for_testing,
    get_session,
    get_session_with_context,
    User,
    DataBlob,
)

[INFO] airt.executor.subcommand: Module loaded.


In [None]:
# Test-only hook: wrap MonkeyPatch.setattr so that the notebook-level `logger` is
# re-created after every monkeypatched attribute.
# NOTE(review): presumably needed because tests monkeypatch something the logger depends
# on (e.g. logging configuration) — confirm against the tests that rely on it.
old_setattr = MonkeyPatch.setattr  # keep a reference to the unpatched method


@patch
def setattr(self: MonkeyPatch, *args, **kwargs):
    # `@patch` (presumably fastcore's, pulled in via `from fastcore.utils import *`) attaches
    # this function to the class named in the `self` annotation, i.e. MonkeyPatch.setattr.
    global logger
    old_setattr(self, *args, **kwargs)  # delegate to the original implementation first
    logger = get_logger(__name__)  # then rebind the module-level logger

In [None]:
# Create a throw-away user; the database-backed test cells below look it up by username.
test_username = create_user_for_testing()
display(test_username)

'evfogffpoi'

In [None]:
#| exporti

# Module-level logger, obtained from airt.logger.get_logger under this module's name.
logger = get_logger(__name__)

In [None]:
#| export

# Shared passlib hashing context used by get_password_hash / verify_password.
# Only bcrypt is configured; deprecated="auto" (per passlib) marks every scheme other than
# the default as deprecated, which has no effect while bcrypt is the only scheme.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

In [None]:
#| export


def get_password_hash(password: str) -> str:
    """Hash a plain-text password with the module-level passlib context.

    Args:
        password: The plain-text password to hash

    Returns:
        The resulting hash string
    """
    hashed = pwd_context.hash(password)
    return hashed


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a plain-text password against a previously computed hash.

    Args:
        plain_password: The plain-text password supplied by the caller
        hashed_password: The stored hash to compare against

    Returns:
        True when the plain password matches the hash, otherwise False
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match

In [None]:
test_password = "Welcome123"

actual_password_hash = get_password_hash(test_password)
display(actual_password_hash)

# The hash must verify against the password it was computed from...
assert verify_password(test_password, actual_password_hash)
# ...and must NOT verify against a different password; without this check a
# verify_password that always returned True would pass the test.
assert not verify_password("wrong-" + test_password, actual_password_hash)
# bcrypt generates a fresh salt per hash, so hashing the same password again differs.
assert get_password_hash(test_password) != actual_password_hash

'$2b$12$pxZRvHW0Znt08CFtNKaequw7NdcRgmsj2nX6Tq..6OAJqjitLIrOK'

In [None]:
#| export


def get_storage_path() -> Path:
    """Return the absolute root directory for local datasources, models and predictions.

    The location comes from the `STORAGE_PATH` environment variable and falls back to
    `./storage`, relative to the current working directory. The directory, including any
    missing parents, is created when it does not exist yet.

    Returns:
        The absolute root storage path
    """
    configured_location = os.environ.get("STORAGE_PATH", "./storage")
    root = Path(configured_location).absolute()
    root.mkdir(parents=True, exist_ok=True)
    return root

In [None]:
actual = get_storage_path()
display(actual)
# Either the container's configured STORAGE_PATH or the default ./storage fallback.
allowed_storage_roots = (
    Path("/tf/airt-service/storage"),
    Path("./storage").absolute(),
)
assert actual in allowed_storage_roots
assert actual.exists()

Path('/tf/airt-service/storage')

In [None]:
#| export


def get_datasource_path() -> Path:
    """Return the local directory used to store datasources.

    The directory is `<storage root>/datasource`, where the storage root comes from
    `get_storage_path`. It is created if missing.

    Returns:
        The datasource directory path
    """
    datasource_dir = get_storage_path().joinpath("datasource")
    # get_storage_path has already created the root, so no parent creation is needed here.
    datasource_dir.mkdir(parents=False, exist_ok=True)
    return datasource_dir

In [None]:
expected = get_storage_path().joinpath("datasource")
actual = get_datasource_path()
display(actual)
assert expected == actual
assert actual.exists()

Path('/tf/airt-service/storage/datasource')

In [None]:
#| export


def get_model_path() -> Path:
    """Return the local directory used to store models.

    The directory is `<storage root>/model`, where the storage root comes from
    `get_storage_path`. It is created if missing.

    Returns:
        The model directory path
    """
    model_dir = get_storage_path().joinpath("model")
    # get_storage_path has already created the root, so no parent creation is needed here.
    model_dir.mkdir(parents=False, exist_ok=True)
    return model_dir

In [None]:
expected = get_storage_path().joinpath("model")
actual = get_model_path()
display(actual)
assert expected == actual
assert actual.exists()

Path('/tf/airt-service/storage/model')

In [None]:
#| export


def get_prediction_path() -> Path:
    """Return the local directory used to store predictions.

    The directory is `<storage root>/prediction`, where the storage root comes from
    `get_storage_path`. It is created if missing.

    Returns:
        The prediction directory path
    """
    prediction_dir = get_storage_path().joinpath("prediction")
    # get_storage_path has already created the root, so no parent creation is needed here.
    prediction_dir.mkdir(parents=False, exist_ok=True)
    return prediction_dir

In [None]:
expected = get_storage_path().joinpath("prediction")
actual = get_prediction_path()
display(actual)
assert expected == actual
assert actual.exists()

Path('/tf/airt-service/storage/prediction')

In [None]:
#| export


def generate_random_string(length: int = 6) -> str:
    """Build a random string of uppercase ASCII letters and digits.

    Uses the `random` module, which is not cryptographically secure.

    Args:
        length: Number of characters to generate; defaults to 6.

    Returns:
        A string of `length` characters drawn from A-Z and 0-9
    """
    alphabet = string.ascii_uppercase + string.digits
    picked = [random.choice(alphabet) for _ in range(length)]  # nosec B311
    return "".join(picked)

In [None]:
requested_length = 10
actual = generate_random_string(length=requested_length)
display(actual)
assert len(actual) == requested_length

'XNC354M2SX'

In [None]:
#| export


@contextmanager
def set_env_variable_context(variable: str, value: str):
    """Temporarily set an environment variable for the duration of a `with` block.

    On exit the previous state is restored. The variable is reset to its old value, or
    removed if it was not set before. Restoration also happens when the `with` body
    raises, so a failing block cannot leak the temporary value into later code.

    Args:
        variable: Name of the environment variable
        value: Value to set while the context is active

    Yields:
        None
    """
    old_value = environ.get(variable)
    environ[variable] = value
    try:
        yield
    finally:
        if old_value is None:
            # The variable did not exist before. `pop` (not `del`) so that a body which
            # already removed it does not make the cleanup raise KeyError.
            environ.pop(variable, None)
        else:
            environ[variable] = old_value

In [None]:
env_var_name = "SET_AND_TEST_ENV_VARIABLE"

# Scenario 1: the variable is not set beforehand -> it is removed again on exit
assert env_var_name not in environ
with set_env_variable_context(variable=env_var_name, value="something"):
    assert environ[env_var_name] == "something"
assert env_var_name not in environ

# Scenario 2: nested contexts -> the inner context restores the outer value on exit,
# and the outer context removes the variable again
assert env_var_name not in environ
with set_env_variable_context(variable=env_var_name, value="something"):
    assert environ[env_var_name] == "something"
    with set_env_variable_context(variable=env_var_name, value="different_value"):
        assert environ[env_var_name] == "different_value"
    assert environ[env_var_name] == "something"
assert env_var_name not in environ

In [None]:
#| export


@contextmanager
def commit_or_rollback(session: Session):
    """Commit the session when the `with` block succeeds; roll it back otherwise.

    If the block, or the commit itself, raises, `session.rollback()` is called and the
    original exception is re-raised unchanged.

    Args:
        session: Current session object

    Raises:
        Whatever exception the `with` block or `session.commit()` raised.
    """
    try:
        yield
        session.commit()
    except BaseException:
        # BaseException rather than Exception, so that e.g. a KeyboardInterrupt in the
        # middle of a transaction still leaves the session clean. A bare `raise` re-raises
        # the original exception without adding this frame as a new raise site.
        session.rollback()
        raise

In [None]:
with get_session_with_context() as session:
    # Datablobs created below are owned by the test user created at the top of the notebook.
    user = session.exec(select(User).where(User.username == test_username)).one()

    uri = "s3://bucket"
    db_uri = create_db_uri_for_s3_datablob(
        uri=uri, access_key="access", secret_key="secret"
    )

    def test_commit_or_rollback(raise_exception: bool):
        """Add a DataBlob inside `commit_or_rollback`, optionally failing before the commit."""
        with commit_or_rollback(session):
            datablob = DataBlob(
                type="s3",
                uri=db_uri,
                source=uri,
                cloud_provider="aws",
                region="eu-west-1",
                total_steps=1,
                user=user,
            )
            session.add(datablob)
            # The id is still unset before the commit; the success case checks it afterwards.
            assert datablob.id is None
            if raise_exception:
                raise ValueError("I had one job and I failed")

        return datablob

    # Success path: after the commit the datablob has an id.
    datablob = test_commit_or_rollback(raise_exception=False)
    display(datablob)
    assert datablob.id

    # Failure path: the exception propagates out of the context manager.
    with pytest.raises(ValueError):
        test_commit_or_rollback(raise_exception=True)

    print("ok")

DataBlob(id=181, uuid=UUID('eb85cf42-4214-4890-aefe-d7ef35e44776'), type='s3', uri='s3://access:secret@bucket', source='s3://bucket', total_steps=1, completed_steps=0, folder_size=None, cloud_provider=<CloudProvider.aws: 'aws'>, region='eu-west-1', error=None, disabled=False, path=None, created=datetime.datetime(2022, 9, 13, 11, 23, 6), user_id=141, pulled_on=None, tags=[])

ok


In [None]:
#| export


def truncate(s: str, length: int = 255) -> str:
    """Cut a string down to at most `length` characters.

    Strings that are already short enough are returned unchanged.

    Args:
        s: The string to shorten
        length: Maximum number of characters to keep; defaults to 255

    Returns:
        The first `length` characters of `s`
    """
    keep = slice(None, length)
    return s[keep]

In [None]:
test_cases = [
    dict(
        s="Error; something went wrong",
        length=255,
        expected="Error; something went wrong",
    ),
    dict(s="Error; something went wrong", length=5, expected="Error"),
    dict(s="-" * 2000, length=255, expected="-" * 255),
]

for case in test_cases:
    call_kwargs = {"s": case["s"], "length": case["length"]}
    actual = truncate(**call_kwargs)
    display(f"{actual=}")
    assert case["expected"] == actual

"actual='Error; something went wrong'"

"actual='Error'"

"actual='---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------'"

In [None]:
#| export


def df_to_dict(df: pd.DataFrame) -> Dict[str, Any]:
    """Serialise a pandas DataFrame into a plain dict.

    The result has two entries:
    - "data": the frame in pandas' "tight" dict orientation
    - "dtypes": a mapping from column name to the string form of that column's dtype,
      which `dict_to_df` uses to restore the column types

    Args:
        df: Input dataframe

    Returns:
        A dict with the "data" and "dtypes" entries
    """
    return dict(
        data=df.to_dict(orient="tight"),
        dtypes=df.dtypes.astype(str).to_dict(),
    )


def dict_to_df(d: Dict[str, Any]) -> pd.DataFrame:
    """Rebuild a pandas DataFrame from a dict produced by `df_to_dict`.

    Args:
        d: Dict with a "data" entry in "tight" orientation and a "dtypes" entry
            mapping column names to dtype strings

    Returns:
        The reconstructed dataframe with the recorded column dtypes applied
    """
    rebuilt = pd.DataFrame.from_dict(d["data"], orient="tight")
    return rebuilt.astype(d["dtypes"])

In [None]:
# Seeded random test frame with columns A-D. This replaces the deprecated
# `pd.util.testing.makeDataFrame()` (its deprecation warning showed up in this cell's
# output, and the module is gone in pandas 2.0). The seed also makes the round trip reproducible.
rng = np.random.default_rng(42)
df = pd.DataFrame(rng.standard_normal(size=(30, 4)), columns=list("ABCD")).set_index(
    "A"
)
for c in df.columns:
    df[f"{c}_float32"] = df[c].astype("float32")
    df[f"{c}_int32"] = df[c].astype("int32")
    df[f"{c}_bool"] = df[c].astype("bool")
df["ts"] = np.datetime64("now")

ddf = dd.from_pandas(df, npartitions=4)
ddf_head = ddf.head(10)
s = df_to_dict(ddf_head)
actual = dict_to_df(s)

pd.testing.assert_frame_equal(ddf_head, actual)
display(actual)

  import pandas.util.testing


Unnamed: 0_level_0,B,C,D,B_float32,B_int32,B_bool,C_float32,C_int32,C_bool,D_float32,D_int32,D_bool,ts
A,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
-1.970345,0.79658,1.302249,-0.731424,0.79658,0,True,1.302249,1,True,-0.731424,0,True,2022-09-13 11:23:05
-1.855409,-1.538287,-0.411187,-0.074388,-1.538287,-1,True,-0.411187,0,True,-0.074388,0,True,2022-09-13 11:23:05
-1.490789,-2.690547,-0.313748,-0.956371,-2.690547,-2,True,-0.313748,0,True,-0.956371,0,True,2022-09-13 11:23:05
-1.452324,1.344901,0.639226,-1.147019,1.344901,1,True,0.639226,0,True,-1.147019,-1,True,2022-09-13 11:23:05
-1.342805,1.502447,-0.139105,-0.639073,1.502447,1,True,-0.139105,0,True,-0.639073,0,True,2022-09-13 11:23:05
-1.072208,-0.018484,1.092937,-0.416845,-0.018484,0,True,1.092937,1,True,-0.416845,0,True,2022-09-13 11:23:05
-1.036864,-0.088844,-0.151247,-0.223246,-0.088844,0,True,-0.151247,0,True,-0.223246,0,True,2022-09-13 11:23:05
-1.015528,-1.484464,0.259131,-1.018076,-1.484464,-1,True,0.259131,0,True,-1.018076,-1,True,2022-09-13 11:23:05


In [None]:
#| export


def _detect_sql_code_injection(s: str) -> bool:
    """Check if the given string contains SQL code injection

    This is a regular-expression heuristic. It flags a string that contains a
    single-quoted literal or a semicolon; a third alternative is meant to flag standalone
    SQL keywords such as SELECT or DROP (see the FIXME in the body).

    Args:
        s: String to validate

    Returns:
        True, if the given string contains SQL code injection
    """
    # https://larrysteinle.com/2011/02/20/use-regular-expressions-to-detect-sql-code-injection/
    # FIXME(review): `regex_text` is a normal (non-raw) string literal, so each "\b" in it is
    # the backspace character (\x08), not the regex word boundary. The keyword alternative
    # therefore only matches a keyword wrapped in literal backspace characters; effectively,
    # only the quoted-literal and ";" alternatives fire. The test cell for this function
    # depends on that ("SELECT a FROM test" is listed as safe). Switching to a raw string
    # (r"...") is a behavior change that must be made together with those tests and callers.
    # Note also that keyword matching is case-sensitive (no re.IGNORECASE is passed).
    regex_text = "('(''|[^'])*')|(;)|(\b(ALTER|CREATE|DELETE|DROP|EXEC(UTE){0,1}|INSERT( +INTO){0,1}|MERGE|SELECT|UPDATE|UNION( +ALL){0,1})\b)"
    return bool(re.search(regex_text, s))

In [None]:
# Inputs that must be flagged as SQL code injection
unsafe_inputs = ["'; SELECT * FROM test ;--", "10; DROP TABLE test /*"]

# Inputs that must pass the check
safe_inputs = [
    "index_col",
    "index-col",
    "indexCol",
    "revenue_10",
    "string with space",
    "str!@$",
    "$str",
    "%%test%%",
    "SELECT a FROM test",
]

for candidate in unsafe_inputs:
    assert _detect_sql_code_injection(candidate), candidate

for candidate in safe_inputs:
    assert not _detect_sql_code_injection(candidate), candidate

In [None]:
#| export


def validate_user_inputs(xs: List[str]):
    """Reject user-supplied strings that look like SQL code injection.

    Each string is checked with `_detect_sql_code_injection` in order. The first
    offending string stops validation.

    Args:
        xs: The strings to validate

    Raises:
        ValueError: For the first string in `xs` flagged as SQL code injection
    """
    offending = next((x for x in xs if _detect_sql_code_injection(x)), None)
    if offending is not None:
        raise ValueError(
            f"The input {offending} is invalid. SQL code injection detected."
        )

In [None]:
# Plain column-like identifiers must pass validation without raising.
valid_inputs = ["index_col", "index-col", "indexCol", "revenue_10"]
validate_user_inputs(xs=valid_inputs)

In [None]:
# A single injected value anywhere in the list must make validation fail.
invalid_inputs = ["index_col", "index-col", "indexCol", "'; SELECT * FROM test ;--"]

with pytest.raises(ValueError) as e:
    validate_user_inputs(xs=invalid_inputs)

expected_message = (
    "The input '; SELECT * FROM test ;-- is invalid. SQL code injection detected."
)
assert str(e.value) == expected_message