In [18]:
import os
import sys

# Directory one level above the notebook's working directory; the package under
# test (model_data_collector_preprocessor) is imported from there below.
parent_dir = os.path.abspath(os.pardir)

# Re-running this cell without restarting the kernel must not add the same
# entry to sys.path twice, so only append when it is not already present.
# NOTE(review): the saved output of this cell reports
# "No module named 'span_tree_utils'"; presumably trace_aggregator imports that
# module without the package prefix, which this path setup alone does not cover
# — confirm.
if not any(entry == parent_dir for entry in sys.path):
    sys.path.append(parent_dir)


from datetime import datetime
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, TimestampType
from model_data_collector_preprocessor.trace_aggregator import (
    process_spans_into_aggregated_traces,
    _get_aggregated_trace_log_spark_df_schema
)
import pytest

ModuleNotFoundError: No module named 'span_tree_utils'

In [None]:
def init_spark() -> SparkSession:
    """Return a SparkSession for this notebook.

    Uses ``SparkSession.builder.getOrCreate()``, so an already-running session
    is reused; otherwise a new one is created with the application name "test".
    """
    session_builder = SparkSession.builder.appName("test")
    return session_builder.getOrCreate()

# Schema for the expected aggregated-trace DataFrame. It is taken from the
# trace_aggregator module itself rather than redeclared here.
_trace_log_schema = _get_aggregated_trace_log_spark_df_schema()

# Schema of the preprocessed span rows passed to process_spans_into_aggregated_traces.
# Each triple is (column name, Spark type, nullable). Rows are matched to fields
# by position, so this order must agree with the row layout of _span_log_data.
# Only parent_id is nullable: the root span has no parent.
# TODO: session_id and user_id (both nullable strings) might not exist in the v1
# span schema, so they are not declared yet. Double check later.
_preprocessed_log_schema = StructType([
    StructField(column_name, column_type, is_nullable)
    for column_name, column_type, is_nullable in (
        ('attributes', StringType(), False),
        ('end_time', TimestampType(), False),
        ('events', StringType(), False),
        ('framework', StringType(), False),
        ('input', StringType(), False),
        ('links', StringType(), False),
        ('name', StringType(), False),
        ('output', StringType(), False),
        ('parent_id', StringType(), True),
        ('span_id', StringType(), False),
        ('span_type', StringType(), False),
        ('start_time', TimestampType(), False),
        ('status', StringType(), False),
        ('trace_id', StringType(), False),
    )
])

_span_log_data = [
    ["{}", datetime(2024, 2, 5, 0, 8, 0), "[]", "FLOW", "in", "[]", "name",  "out", None] +
    ["1", "llm", datetime(2024, 2, 5, 0, 1, 0), "OK", "01"],
    ["{}", datetime(2024, 2, 5, 0, 5, 0), "[]", "RAG", "in", "[]", "name",  "out", "1"] +
    ["2", "llm", datetime(2024, 2, 5, 0, 2, 0), "OK", "01"],
    ["{}", datetime(2024, 2, 5, 0, 4, 0), "[]", "INTERNAL", "in", "[]", "name",  "out", "2"] +
    ["3", "llm", datetime(2024, 2, 5, 0, 3, 0), "OK", "01"],
    ["{}", datetime(2024, 2, 5, 0, 7, 0), "[]", "LLM", "in", "[]", "name",  "out", "1"] +
    ["4", "llm", datetime(2024, 2, 5, 0, 6, 0), "OK", "01"]
]

_trace_log_data = [
        [datetime(2024, 2, 5, 0, 8, 0), "in", "out", "{\"TBD\": \"TBD\"}", None] +
        [datetime(2024, 2, 5, 0, 1, 0), "01", None],
]

In [None]:
# Build the input span DataFrame and the expected trace DataFrame, run the
# aggregator, and show all three (contents and schemas) for visual comparison.
spark = init_spark()
processed_spans_df = spark.createDataFrame(_span_log_data, _preprocessed_log_schema)
expected_trace_df = spark.createDataFrame(_trace_log_data, _trace_log_schema)

print("processed logs:")
processed_spans_df.show()
processed_spans_df.printSchema()

print("expected trace logs:")
expected_trace_df.show()
expected_trace_df.printSchema()

actual_trace_df = process_spans_into_aggregated_traces(processed_spans_df)

print("actual trace logs:")
actual_trace_df.show()
actual_trace_df.printSchema()

# Printing alone never fails, so a regression would pass silently. Every input
# span carries trace_id "01", so the aggregation should produce as many rows as
# the expected frame (one). Row contents are not compared yet because the
# expected row still holds a "TBD" placeholder.
actual_trace_count = actual_trace_df.count()
expected_trace_count = expected_trace_df.count()
assert actual_trace_count == expected_trace_count, (
    f"expected {expected_trace_count} aggregated trace row(s), got {actual_trace_count}"
)