# Inference Tables to MLflow
This notebook shows how to load a Databricks inference table into an MLflow experiment, logging each request/response row as an MLflow trace so it can be evaluated further.

In [0]:
%pip install mlflow --upgrade
%restart_python

In [0]:
import mlflow
from mlflow import MlflowClient
import pandas as pd
import json
from mlflow.entities import SpanType
from datetime import datetime
from typing import Dict, Any, Optional
from pandas._libs.tslibs.timestamps import Timestamp

In [0]:
# Put the experiment under the *current* user's workspace folder rather than hard-coding
# one person's e-mail address, so the notebook runs unchanged for whoever executes it.
CURRENT_USER = spark.sql("SELECT current_user()").first()[0]
EXPERIMENT_NAME = f"/Workspace/Users/{CURRENT_USER}/experiments/inf_to_mlflow"

# Inference table to replay, given as a three-part name (presumably catalog.schema.table).
INFERENCE_TABLE = "shm.3w.well_agent_payload"

In [0]:
# set_experiment activates the experiment and returns it, so its id can be read from the
# return value instead of issuing a second lookup by name.
exp_id = mlflow.set_experiment(EXPERIMENT_NAME).experiment_id

## Extract Inference Table
We use Spark to read the inference table. When the table is rendered with `display`, MLflow shows the traces contained in the payloads directly in the output.

In [0]:
# Load the inference table with Spark and render it in the notebook.
df = spark.table(INFERENCE_TABLE)
display(df)

# Pandas copy on the driver, used for the row-by-row trace logging below.
# NOTE(review): toPandas() collects the whole table; filter or limit `df` first if it is large.
df_pd = df.toPandas()

## Transfer to MLflow
We wrap MLflow's low-level client tracing API (`MlflowClient.start_trace` / `end_trace`) to log each row as a trace, with as few modifications to the original payload as possible.

In [0]:
def _to_python_scalar(value: Any) -> Any:
    """Return ``value`` as a plain Python object suitable for a span attribute.

    Datetimes (``pd.Timestamp`` is a ``datetime`` subclass) become ISO-8601 strings, and
    numpy scalars (anything exposing ``.item()``) become the matching built-in type, so
    the attribute serializes as an ordinary JSON string or number.
    """
    if isinstance(value, datetime):
        return value.isoformat()
    if hasattr(value, "item"):
        return value.item()
    return value


def _extract_token_usage(response_data: Dict[str, Any]) -> Dict[str, Any]:
    """Return ``llm.token_usage.*`` span attributes read from the response's trace metadata.

    Looks up ``mlflow.trace.tokenUsage`` under
    ``databricks_output.trace.info.trace_metadata`` (a JSON string or a dict) and copies
    whichever of ``input_tokens`` / ``output_tokens`` / ``total_tokens`` it contains.
    Returns an empty dict when the response carries no such metadata.
    """
    try:
        raw_usage = response_data['databricks_output']['trace']['info']['trace_metadata']['mlflow.trace.tokenUsage']
    except (KeyError, TypeError):
        return {}
    usage = json.loads(raw_usage) if isinstance(raw_usage, str) else raw_usage
    if not isinstance(usage, dict):
        return {}
    return {
        f"llm.token_usage.{key}": usage[key]
        for key in ("input_tokens", "output_tokens", "total_tokens")
        if key in usage
    }


def log_inference_row_trace(row: pd.Series, client: Optional["MlflowClient"] = None) -> Optional[str]:
    """
    Create a single MLflow trace from one row of an inference table.

    The row's ``request`` JSON becomes the trace inputs and its ``response`` JSON the
    outputs. The trace starts at ``request_time`` and ends ``execution_duration_ms``
    later. Request metadata is stored as span attributes, and identifying columns are
    stored as trace tags.

    Args:
        row: One inference-table row (e.g. from ``DataFrame.iterrows()``) with the columns
            ``request``, ``response``, ``request_time``, ``execution_duration_ms``,
            ``status_code``, ``sampling_fraction``, ``databricks_request_id``,
            ``client_request_id``, ``served_entity_id`` and ``requester``.
        client: MLflow client used to start and end the trace. A new ``MlflowClient()``
            is created when omitted.

    Returns:
        The id of the logged trace.

    Raises:
        KeyError: if a required column is missing from ``row``.
        TypeError / json.JSONDecodeError: if ``request`` or ``response`` is not valid JSON text.
    """
    if client is None:
        client = MlflowClient()

    request_data = json.loads(row['request'])
    response_data = json.loads(row['response'])

    # Timestamp.value is integer nanoseconds since the epoch; unlike `.timestamp() * 1e9`
    # it does not round through a float at nanosecond scale.
    start_time_ns = pd.to_datetime(row['request_time']).value
    end_time_ns = start_time_ns + int(row['execution_duration_ms'] * 1_000_000)

    attributes = {
        "execution_duration_ms": _to_python_scalar(row['execution_duration_ms']),
        "status_code": _to_python_scalar(row['status_code']),
        "request_time": _to_python_scalar(row['request_time']),
        "sampling_fraction": _to_python_scalar(row['sampling_fraction']),
        # Token counts come from the trace metadata returned by the endpoint; the keys
        # are left out when the response carries no usage information.
        **_extract_token_usage(response_data),
    }

    span = client.start_trace(
        row['databricks_request_id'],
        span_type='CHAIN',
        inputs=request_data,
        attributes=attributes,
        tags={
            "source": "inference_table",
            "databricks_request_id": row['databricks_request_id'],
            "client_request_id": row['client_request_id'],
            "served_entity_id": row['served_entity_id'],
            "requester": row['requester'],
        },
        start_time_ns=start_time_ns,
    )

    client.end_trace(
        trace_id=span.trace_id,
        outputs=response_data,
        end_time_ns=end_time_ns,
    )

    return span.trace_id

In [0]:
# Replay every row of the inference table (df_pd from the extract cell) as an MLflow
# trace, keeping the returned trace ids; the cell displays how many were logged.
trace_ids = [log_inference_row_trace(row) for _, row in df_pd.iterrows()]
len(trace_ids)