<center><img src="https://storage.googleapis.com/arize-assets/arize-logo-white.jpg" width="200"/></center>

# Batch Ingestion for Multiclass Classification (Classification and AUC Metrics)

In this tutorial, we'll outline how to send predictions (scores + labels) and actuals from multiclass models to Arize in batch so that Arize can calculate classification metrics and AUC for the model. A multiclass classification model is a classification model with more than two classes, where each example is labeled with exactly one class. For more information on multiclass ingestion, please see our documentation <a href="https://docs.arize.com/arize/model-types/multiclass-classification">here</a>. For a list of all model types, please see our documentation <a href="https://docs.arize.com/arize/">here</a>.

## Install and Import Dependencies

In [None]:
!pip install -q arize
from arize.pandas.logger import Client, Schema
from arize.utils.types import ModelTypes, Environments

import pandas as pd
import datetime
import numpy as np

## Download and Display Data
For this tutorial, we will use a sample Parquet file containing 100 predictions. 

In [None]:
file_url = "https://storage.googleapis.com/arize-assets/documentation-sample-data/data-ingestion/multiclass-classification-assets/multiclass-sample-data.parquet"
df = pd.read_parquet(file_url)
df.head()

## Add Timestamps for Predictions
Generate sample timestamps for each prediction. More information on timestamps in Arize can be found <a href="https://docs.arize.com/arize/sending-data/model-schema-reference#6.-timestamp">here</a>.

In [None]:
current_time = datetime.datetime.now().timestamp()

earlier_time = (
    datetime.datetime.now() - datetime.timedelta(days=30)
).timestamp()

optional_prediction_timestamps = np.linspace(
    earlier_time, current_time, num=df.shape[0]
)

# Assign the array directly so values are set by position, not by index alignment
df["prediction_ts"] = optional_prediction_timestamps.astype(int)
df[["prediction_ts"]].head()
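
As a quick check on the timestamp cell above, here is a minimal sketch that builds the same kind of evenly spaced Unix timestamps for a small, fixed window and converts them back to datetimes. The fixed end date and the count of five are assumptions made only to keep the example reproducible. The tutorial itself uses `datetime.datetime.now()` and one timestamp per row.

```python
import datetime

import numpy as np
import pandas as pd

# Fixed end time so the example is reproducible (the tutorial uses datetime.now())
end = datetime.datetime(2024, 1, 31, tzinfo=datetime.timezone.utc)
start = end - datetime.timedelta(days=30)

# Five evenly spaced Unix timestamps (in seconds) across the 30-day window
timestamps = np.linspace(start.timestamp(), end.timestamp(), num=5).astype(int)

# Convert back to datetimes to confirm the values are readable and ordered
readable = pd.to_datetime(timestamps, unit="s", utc=True)
print(readable)
```

Converting back with `pd.to_datetime(..., unit="s")` is a convenient way to spot timestamps that were accidentally generated in the wrong unit or outside the intended window.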

## Restructure DataFrame
To send the probability/propensity for each class label in the prediction (i.e., the prediction scores), we need to fan out each inference into one prediction per class value, and restructure our DataFrame accordingly. We'll use <a href="https://docs.arize.com/arize/sending-data/model-schema-reference#9.-tags">tags</a> to identify which class value each prediction score belongs to, which we can then filter on in the Arize platform. The prediction label stays the same across all fanned-out predictions and represents what the model actually predicted for that record. The actual label also stays the same across all fanned-out predictions and is the record's true label. The example below shows how one record is fanned out into three predictions, one for each class value.

#### Example prediction

**Inference**

| prediction_id | prediction_classes | prediction_scores | class_pred | actual_class |
| --- | ----------- | ------| ----- | --------- | 
| pred_123 | [first_class, business_class, economy_class] | [0.75, 0.15, 0.10] | first_class | first_class |

**Predictions Sent to Arize**

| prediction_id | prediction_label | prediction_score | tag | actual_label |
| --- | ----------- | ------| ----- | --------- | 
| pred_123_first_class | first_class | 0.75 | first_class | first_class | 
| pred_123_business_class | first_class | 0.15 | business_class | first_class | 
| pred_123_economy_class | first_class | 0.10 | economy_class | first_class | 


In [None]:
def restructure_df(df: pd.DataFrame) -> pd.DataFrame:
    """Return the restructured DataFrame with predictions fanned-out per class value."""

    # Explode each class value in list into separate rows
    df_restructured = df.explode(
        column=["prediction_classes", "prediction_scores"]
    )

    # Rename prediction_classes column and prediction_scores column for clarity
    df_restructured.rename(
        columns={
            "prediction_classes": "tag",
            "prediction_scores": "class_score"
        },
        inplace=True
    )

    # Set new prediction_id with combination of existing ID and the tag value
    df_restructured["prediction_id"] = (
        df_restructured["prediction_id"] + "_" + df_restructured["tag"]
    )

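    # explode() leaves the fanned-out columns as object dtype and repeats the
    # original index, so cast scores to float and rebuild a unique index
    df_restructured["class_score"] = df_restructured["class_score"].astype(float)

    return df_restructured.reset_index(drop=True)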

In [None]:
# Restructure DataFrame and show sample
df = restructure_df(df)
df.head(6)
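
To see the fan-out in isolation, the following sketch applies the same explode, rename, and ID steps to the single `pred_123` record from the example tables. The one-row DataFrame is built by hand for illustration and is not part of the sample Parquet file. Exploding several columns at once assumes pandas 1.3 or later.

```python
import pandas as pd

# One inference, mirroring the example record above
single = pd.DataFrame(
    {
        "prediction_id": ["pred_123"],
        "prediction_classes": [["first_class", "business_class", "economy_class"]],
        "prediction_scores": [[0.75, 0.15, 0.10]],
        "class_pred": ["first_class"],
        "actual_class": ["first_class"],
    }
)

# Fan out: one row per (class, score) pair
fanned = single.explode(["prediction_classes", "prediction_scores"]).rename(
    columns={"prediction_classes": "tag", "prediction_scores": "class_score"}
)

# Make each fanned-out row's ID unique by appending its class value
fanned["prediction_id"] = fanned["prediction_id"] + "_" + fanned["tag"]
fanned = fanned.reset_index(drop=True)

print(fanned[["prediction_id", "class_pred", "class_score", "tag", "actual_class"]])
```

The result has three rows whose `class_pred` and `actual_class` are identical, while `tag` and `class_score` vary per class, matching the "Predictions Sent to Arize" table.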

## Create Arize Client
Sign up for or log in to your Arize account <a href="https://app.arize.com/auth/login">here</a>. Find your <a href="https://docs.arize.com/arize/sending-data/sdk-reference/python-sdk/arize.init#retrieving-space-and-api-keys">Space and API keys</a>, then paste them into the cell below.

In [None]:
SPACE_KEY = "SPACE_KEY"  # update value here with your Space Key
API_KEY = "API_KEY"  # update value here with your API key

# Fail fast if the placeholder values were not replaced
if SPACE_KEY == "SPACE_KEY" or API_KEY == "API_KEY":
    raise ValueError("❌ NEED TO CHANGE SPACE_KEY AND/OR API_KEY")

arize_client = Client(space_key=SPACE_KEY, api_key=API_KEY)
print("✅ Import and Setup Arize Client Done! Now we can start using Arize!")

## Define Schema
Create your <a href="https://docs.arize.com/arize/sending-data-to-arize/model-schema-reference">model schema</a>, which maps the columns of the restructured DataFrame to the fields Arize expects.

In [None]:
schema = Schema(
    prediction_id_column_name="prediction_id",
    timestamp_column_name="prediction_ts",
    prediction_label_column_name="class_pred",
    prediction_score_column_name="class_score",
    feature_column_names=["feature1", "feature2", "feature3", "feature4"],
    tag_column_names=["tag"],
    actual_label_column_name="actual_class"
)
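
Because the schema refers to columns by name, a typo or a skipped restructuring step only shows up when logging. The sketch below is one way to catch that earlier. `missing_schema_columns` is a hypothetical helper, not part of the Arize SDK, and the "before restructuring" column list is an assumption based on the example tables and the schema above.

```python
def missing_schema_columns(dataframe_columns, schema_columns):
    """Return schema column names absent from the DataFrame (hypothetical helper)."""
    available = set(dataframe_columns)
    return [name for name in schema_columns if name not in available]


# Column names referenced by the schema in this tutorial
schema_columns = [
    "prediction_id", "prediction_ts", "class_pred", "class_score",
    "feature1", "feature2", "feature3", "feature4", "tag", "actual_class",
]

# Assumed columns before restructuring: "tag" and "class_score" do not exist yet
columns_before = [
    "prediction_id", "prediction_classes", "prediction_scores", "class_pred",
    "actual_class", "feature1", "feature2", "feature3", "feature4", "prediction_ts",
]

print(missing_schema_columns(columns_before, schema_columns))
# → ['class_score', 'tag']
```

In the notebook, the same check could be run as `missing_schema_columns(df.columns, schema_columns)` after the restructuring cell, where it should return an empty list.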

## Log Data to Arize
Log the DataFrame using the <a href="https://docs.arize.com/arize/sending-data-to-arize/data-ingestion-methods/sdk-reference/python-sdk/arize.pandas">pandas API</a>. 

In [None]:
response = arize_client.log(
    dataframe=df,
    model_id="multiclass-classification-and-auc-metrics-batch-ingestion-tutorial",
    model_version="1.0",
    model_type=ModelTypes.SCORE_CATEGORICAL,
    environment=Environments.PRODUCTION,
    schema=schema
)

if response.status_code == 200:
    print("✅ You have successfully logged the production dataset to Arize")
else:
    print(
        f"Logging failed with response code {response.status_code}, {response.text}"
    )
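
To sanity-check the logged data locally, the sketch below computes a one-vs-rest AUC for a single class from fanned-out rows: it keeps the rows for one tag, treats a row as positive when its actual label equals that tag, and measures how often a positive row outscores a negative one. This is an illustration under those assumptions, not a description of how Arize computes AUC internally. `one_vs_rest_auc` and the four sample rows are hypothetical.

```python
def one_vs_rest_auc(scores, is_positive):
    """Chance that a random positive row outscores a random negative row (ties count half)."""
    positives = [s for s, p in zip(scores, is_positive) if p]
    negatives = [s for s, p in zip(scores, is_positive) if not p]
    if not positives or not negatives:
        return float("nan")  # AUC is undefined without both outcomes
    wins = sum(
        1.0 if pos > neg else 0.5 if pos == neg else 0.0
        for pos in positives
        for neg in negatives
    )
    return wins / (len(positives) * len(negatives))


# Hypothetical fanned-out rows for the tag "first_class": (class_score, actual_class)
rows = [
    (0.9, "first_class"),
    (0.8, "economy_class"),
    (0.3, "first_class"),
    (0.1, "business_class"),
]
scores = [score for score, _ in rows]
is_positive = [actual == "first_class" for _, actual in rows]

print(one_vs_rest_auc(scores, is_positive))
# → 0.75
```

Of the four positive/negative pairs, three have the positive row scored higher, which gives 0.75. The pairwise comparison is quadratic in the number of rows, so it suits small samples rather than full production datasets.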