# Unit tests

This notebook demonstrates code-focused unit test cases for [nb-city-safety-common.ipynb](../src/notebooks/nb-city-safety-common.ipynb).

## Formatting the Notebook - Run only when developing the Notebook

- Reference: https://learn.microsoft.com/en-us/fabric/data-engineering/author-notebook-format-code#extend-fabric-notebooks
- **WARNING**: Formatting the notebook with `jupyter_black` removes any *cell magic commands* that are present; add them back after formatting.

```python
import jupyter_black
jupyter_black.load()
```

## Load function definitions

### Public libraries needed for testing

In [None]:
from unittest.mock import MagicMock, patch, call
import pytest
import ipytest

# Make the %%ipytest cell magic available and configure ipytest for this notebook.
# raise_on_error=True makes ipytest.run() raise an error when any test fails,
# instead of only printing the test report.
ipytest.autoconfig(raise_on_error=True)

### Source the function definitions that need to be tested

- In this example we are using `%run` magic to source the function definitions we are going to test. We can also perform `import` if they are part of the environment.
- [nb-city-safety-common.ipynb](../src/notebooks/nb-city-safety-common.ipynb) must be in the same Fabric workspace as this notebook.
- The external `common_execution_mode` parameter controls which cells to run in the notebook [nb-city-safety-common.ipynb](../src/notebooks/nb-city-safety-common.ipynb).
- See [run a notebook](https://learn.microsoft.com/fabric/data-engineering/author-execute-notebook#spark-session-configuration-magic-command) for details about `%run`.

In [None]:
%run nb-city-safety-common { "common_execution_mode": "normal" }

## Create unit tests

### Define local resources and test fixtures


In [None]:
from opentelemetry import trace, metrics
from pyspark.sql.functions import (
    lit,
    to_utc_timestamp,
    current_timestamp,
    unix_timestamp,
    avg,
    max,
    min,
    sum,
    count,
)

# The common notebook is sourced with %run into this same namespace (the tests below
# patch its functions via "__main__.<name>"). Import the names this notebook uses
# directly instead of relying on them leaking in from the sourced notebook.
import logging

from pyspark.sql.functions import col
from pyspark.sql.types import TimestampType

tracer = trace.get_tracer(__name__)
logger = logging.getLogger(__name__)
meter = metrics.get_meter_provider().get_meter(__name__)

# pyspark.testing's DataFrame/schema assertions are only used on Spark 3.5+.
# NOTE(review): `spark_major_version` comes from the sourced common notebook; this
# comparison assumes it is numeric such as 3.5 (an integer major version of 3
# would never satisfy it) -- confirm.
if spark_major_version >= 3.5:
    from pyspark.testing import assertDataFrameEqual, assertSchemaEqual

# Dummy values for globals that the functions under test presumably read; the
# expected log and exit messages in the tests below are built from these names.
onelake_table_path = "dummy_path"
table_name = "test_table"
delta_table_path = f"{onelake_table_path}/{table_name}"
wasbs_path = "dumppy_wasbs_path"
cities = ["city1", "city2"]
onelake_name = "dummy_onelake"
workspace_name = "dummy_workspace"
lakehouse_name = "dummy_lakehouse"
job_exec_instance = "dummy_job_exec_id"
user_name = "dummy_user_name"

### Environment shakeout tests

In [None]:
def test_env_source_connection():
    """Placeholder shakeout test for the source connection (not yet implemented)."""
    # Skip instead of `assert 1 == 1`: a trivially passing placeholder would be
    # reported as a verified check in the test results.
    pytest.skip("TODO: write unit testcase")


def test_env_target_connection():
    """Placeholder shakeout test for the target connection (not yet implemented)."""
    pytest.skip("TODO: write unit testcase")


def test_env_keyvault_connection():
    """Placeholder shakeout test for the Key Vault connection (not yet implemented)."""
    pytest.skip("TODO: write unit testcase")

### Code-based tests

In [None]:
# To run the tests as soon as this cell executes (instead of via ipytest.run() below),
# delete this explanatory comment and uncomment the two magics: cell magics (`%%...`)
# are only recognised at the very start of a cell.
# %%capture captured_output
# %%ipytest


def test_query_app_insights():
    """Placeholder test for query_app_insights (not yet implemented)."""
    # Skip instead of `assert 1 == 1` so the report does not count this as a passing check.
    pytest.skip("TODO: write unit testcase")


def test_store_unit_test_results():
    """Placeholder test for store_unit_test_results (not yet implemented)."""
    pytest.skip("TODO: write unit testcase")


def test_make_fabric_api_call():
    """Placeholder test for make_fabric_api_call (not yet implemented)."""
    pytest.skip("TODO: Write unit testcase")


@pytest.mark.parametrize(
    "table_exists, exp_load_mode", [(False, "overwrite"), (True, "append")]
)
@patch("__main__.DeltaTable")
def test_identify_table_load_mode(mock_delta, table_exists, exp_load_mode):
    """identify_table_load_mode picks "overwrite" for a new table and "append" for an existing one."""
    mock_span_obj = MagicMock()
    mock_delta.isDeltaTable.return_value = table_exists

    load_mode = identify_table_load_mode(table_name=table_name, span_obj=mock_span_obj)

    # The chosen mode should be returned to the caller (city_data_etl logs it) and
    # recorded on the span.
    assert load_mode == exp_load_mode
    mock_span_obj.set_attribute.assert_called_with("load_mode", exp_load_mode)


@pytest.mark.parametrize(
    "table_exists, expected_logs",
    [
        (False, ["The specified delta table doesn't exist. No need for deletion."]),
        (
            True,
            [
                f"Attempting to delete existing delta table with {delta_table_path = }....",
                "Deleted existing delta table: test_table.",
            ],
        ),
    ],
)
@patch("__main__.notebookutils")
def test_delete_delta_table(mock_notebookutils, caplog, table_exists, expected_logs):
    """delete_delta_table removes the table folder only when it exists, logging each step."""
    caplog.clear()
    caplog.set_level(logging.INFO)

    mock_notebookutils.fs.exists.return_value = table_exists
    result = delete_delta_table(table_name)

    if table_exists:
        mock_notebookutils.fs.rm.assert_called_with(dir=delta_table_path, recurse=True)
    else:
        # Nothing to delete: the filesystem must be left untouched.
        mock_notebookutils.fs.rm.assert_not_called()
    assert all(log in caplog.text for log in expected_logs)
    assert result is None
    # No log lines beyond the expected ones.
    assert len(caplog.records) == len(expected_logs)


@patch("__main__.notebookutils")
def test_delete_delta_table_exception(mock_notebookutils, caplog):
    """A failure while removing the table folder is logged and re-raised unchanged."""
    caplog.clear()
    caplog.set_level(logging.INFO)

    # The table exists, but removing its folder raises.
    fs = mock_notebookutils.fs
    fs.exists.return_value = True
    fs.rm.side_effect = Exception("Deletion failed")

    with pytest.raises(Exception) as exc_info:
        delete_delta_table(table_name)

    assert exc_info.type is Exception
    assert str(exc_info.value) == "Deletion failed"

    attempt_msg = (
        f"Attempting to delete existing delta table with {delta_table_path = }...."
    )
    failure_msg = "Deletion failed with the error:\n====Deletion failed\n====="
    for expected in (attempt_msg, failure_msg):
        assert expected in caplog.text
    # Exactly two log lines: the attempt and the failure.
    assert len(caplog.records) == 2


@patch("__main__.transform_data")
@patch("__main__.identify_table_load_mode")
@patch("__main__.spark")
@patch("__main__.tracer")
def test_city_data_etl(
    mock_tracer, mock_spark, mock_load_mode, mock_transform_data, caplog
):
    """city_data_etl runs extract/transform/load once per city, each inside its own tracer span."""
    caplog.clear()
    caplog.set_level(logging.INFO)

    # Every parquet read reports 10 records; every city is loaded in "overwrite" mode.
    mock_spark.read.parquet.return_value.count.return_value = 10
    mock_load_mode.return_value = "overwrite"

    # Inputs for this invocation (intentionally shadow the module-level fixtures).
    table_name = "mydummytable"
    cities = ["city1", "city2"]
    mock_span = MagicMock()

    record_count = mock_spark.read.parquet.return_value.count.return_value
    exp_log_output = []
    for city in cities:
        exp_log_output += [
            f"ETL started for {city = }.",
            "\t Data Extraction in progress.",
            f"\t Read {record_count} records for {city = }.",
            f"\t Data loading in inprogress using {mock_load_mode.return_value} mode.",
            f"ETL completed for {city = }.",
        ]

    city_data_etl(table_name, cities, mock_span)

    call_count = len(cities)

    assert mock_load_mode.call_count == call_count
    assert mock_transform_data.call_count == call_count
    assert mock_span.add_event.call_count == call_count * 2
    assert all(log in caplog.text for log in exp_log_output)

    # Tip: print(mock_tracer.mock_calls) lists the calls made, which helps when
    # writing the assertions below.
    span_cm = mock_tracer.start_as_current_span.return_value  # the `with` context manager
    city_span = span_cm.__enter__.return_value  # object returned on entering the context
    assert call.start_as_current_span().__enter__() in mock_tracer.mock_calls
    assert mock_tracer.start_as_current_span.call_count == call_count  # one span per city
    assert span_cm.__enter__.call_count == call_count  # context entered ...
    assert city_span.set_attribute.call_count == call_count
    assert city_span.add_event.call_count == call_count * 3
    assert city_span.set_status.call_count == call_count
    assert span_cm.__exit__.call_count == call_count  # ... and exited once per city


@pytest.mark.parametrize("cleanup_mode", [False, True])
@patch("__main__.delete_delta_table")
@patch("__main__.city_data_etl")
@patch("__main__.trace")
def test_etl_steps(mock_trace, mock_city_etl, mock_del_table, caplog, cleanup_mode):
    """etl_steps deletes the table only when cleanup is requested, then runs the city ETL."""
    caplog.clear()
    caplog.set_level(logging.INFO)
    mock_current_span = mock_trace.get_current_span.return_value = MagicMock()

    etl_steps(table_name, cities, cleanup=cleanup_mode)

    exp_log_output = [
        f"\n=====\nCity safety data is loaded into {table_name =} for {cities =}\n====="
    ]

    assert mock_current_span.set_attributes.call_count == 1
    assert mock_current_span.set_status.call_count == 1
    mock_city_etl.assert_called_once_with(table_name, cities, mock_current_span)
    if cleanup_mode:
        mock_del_table.assert_called_once()
        exp_log_output += [
            f"A new delta table '{table_name}' will be created with {delta_table_path = }"
        ]
    else:
        # Without a cleanup request the existing table must be left alone.
        mock_del_table.assert_not_called()
        exp_log_output += ["No request for cleanup. Proceeding to ETL steps."]

    assert all(log in caplog.text for log in exp_log_output)


@patch("__main__.spark")
@patch("__main__.meter")
def test_gather_city_level_metrics(mock_meter, mock_spark, caplog):
    """gather_city_level_metrics adds the aggregated record total to the given counter and logs it."""
    caplog.clear()
    caplog.set_level(logging.INFO)

    # `spark` is patched for this test, so this is a MagicMock rather than a real
    # DataFrame: the rows passed here are illustrative only and never evaluated.
    city_df = mock_spark.createDataFrame(
        data=[("Boston", 1000), ("Chicago", 3000)], schema=["city", "count"]
    )
    mock_spark.read.format.return_value.load.return_value = city_df

    # Stub the filter -> groupBy -> agg -> agg chain so that collect()[0][0] == 2.
    aggregated = (
        city_df.filter.return_value.groupBy.return_value.agg.return_value.agg.return_value
    )
    first_row = aggregated.collect.return_value.__getitem__.return_value
    first_row.__getitem__.return_value = 2

    counter = MagicMock()
    gather_city_level_metrics(table_name, counter)

    counter.add.assert_called_once_with(amount=1, attributes={"record_count_total": 2})
    assert "total:2" in caplog.text


@patch("__main__.notebookutils")
@patch("__main__.trace.get_current_span")
def test_verify_onelake_connection(mock_get_current_span, mock_notebookutils, caplog):
    """verify_onelake_connection logs a listing when the table path exists and exits the notebook otherwise."""
    caplog.clear()
    caplog.set_level(logging.INFO)
    mock_notebookutils.fs.ls.return_value = str(["filepath1", "filepath2"])
    span = mock_get_current_span.return_value

    # Case 1: the target table path exists.
    mock_notebookutils.fs.exists.return_value = True
    verify_onelake_connection()
    mock_get_current_span.assert_called_once()
    assert len(caplog.records) == 1
    assert (
        f"Target table path: {onelake_table_path} is valid and exists.\nListing source data contents to check connectivity\n['filepath1', 'filepath2']"
        in caplog.text
    )
    span.set_status.assert_called_once()

    # Case 2: the target table path doesn't exist.
    mock_notebookutils.fs.exists.return_value = False
    verify_onelake_connection()
    assert (
        "Error message: Encountered error while checking for Lakehouse table path specified."
        in caplog.text
    )
    span.record_exception.assert_called_once()
    span.set_status.assert_called()
    # The expected text (including its spelling) presumably mirrors the exit message
    # emitted by the function under test -- keep them in sync.
    mock_notebookutils.notebook.exit.assert_called_once_with(
        f"Specfied lakehouse table path {onelake_table_path} doesn't exist. Ensure onelake={onelake_name}, workspace={workspace_name} and lakehouse={lakehouse_name} exist."
    )

### Data-based tests

In [None]:
@pytest.mark.parametrize(
    "city, test_data",
    [
        (
            "Boston",
            [("event1", "2022-01-01T00:00:00"), ("event2", "2022-01-01T23:00:00")],
        ),
        (
            "Seattle",
            [("event1", "2022-01-01T00:00:00"), ("event2", "2022-01-01T23:00:00")],
        ),
        (
            "Chicago",
            [("event1", "2022-01-01T00:00:00"), ("event2", "2022-01-01T23:00:00")],
        ),
    ],
)
@patch("__main__.to_utc_timestamp")
def test_transform_data(mock_utcts, city, test_data):
    """transform_data adds dateTimeUTC, City, jobExecId, lastUpdateUTC and lastUpdateUser columns."""
    # Fixed timestamp-typed column returned by the patched to_utc_timestamp, so the
    # UTC columns are deterministic.
    date_output_value = lit("2022-12-31T23:00:00").cast(TimestampType())
    mock_utcts.return_value = date_output_value

    input_df = spark.createDataFrame(test_data, ["event_name", "dateTime"])
    input_df = input_df.withColumn("dateTime", col("dateTime").cast(TimestampType()))

    expected_df = input_df.withColumns(
        {
            "dateTimeUTC": date_output_value,
            "City": lit(city),
            "jobExecId": lit(job_exec_instance),
            "lastUpdateUTC": date_output_value,
            "lastUpdateUser": lit(user_name),
        }
    )

    actual_df = transform_data(city, input_df)

    # assertDataFrameEqual(actual, expected) is only available on Spark 3.5+.
    if spark_major_version >= 3.5:
        assertDataFrameEqual(actual_df, expected_df)
    else:
        # Older releases: compare schemas, then rows in BOTH directions -- a one-sided
        # exceptAll would not notice rows missing from transform_data's output.
        assert actual_df.schema == expected_df.schema
        assert actual_df.exceptAll(expected_df).count() == 0
        assert expected_df.exceptAll(actual_df).count() == 0

## Run the testcases and capture the output

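- `ipytest.autoconfig(raise_on_error=True)` was called at the beginning of this notebook, so `ipytest.run()` raises an error when any test case fails. Because the run below is wrapped in `%%capture`, the test report (including any failure) is captured in `captured_output` instead of failing the notebook; inspect `captured_output` (processed below) to detect failures.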

In [None]:
%%capture captured_output
# Run every test defined above; the printed test report is captured into
# `captured_output` rather than displayed.
ipytest.run()

## Process the test results

- These can be stored somewhere or sent for further processing.

In [None]:
# Hand the captured test report to store_unit_test_results (presumably defined in the
# sourced common notebook) for storage or further processing.
store_unit_test_results(captured_output)