# Distributed Model Inference Using ThirdAI

This notebook demonstrates distributed model inference with ThirdAI's Universal Deep Transformer (UDT) on the amazon_polarity dataset.

The guide consists of the following sections:
* Import all necessary libraries.
* Prepare a trained ThirdAI UDT model for inference.
* Run model inference using a Pandas UDF.

**Note:**
* This notebook is meant to run on a Databricks cluster.
* Add your ThirdAI license file to `/dbfs/mnt/` (the code below expects `/dbfs/mnt/license.serialized`). To obtain a license, contact https://www.thirdai.com/try-bolt/

## Import all necessary libraries

In [None]:
!python3 -m pip install --upgrade pip
!python3 -m pip install thirdai
!python3 -m pip install datasets
!python3 -m pip install utils

In [None]:
import pandas as pd
from datasets import load_dataset
import numpy as np

import thirdai
from thirdai import bolt

# `spark` (the SparkSession) is provided by the Databricks notebook environment.
from pyspark.sql.functions import pandas_udf

## Prepare a trained ThirdAI UDT model for inference

Set the ThirdAI license path. To obtain a license, contact https://www.thirdai.com/try-bolt/

In [None]:
thirdai.set_thirdai_license_path("/dbfs/mnt/license.serialized")

Load the Amazon Polarity dataset and write the train and test splits to tab-separated files.

In [None]:
def load_data(output_filename, split):
    data = load_dataset('amazon_polarity')

    # Keep only the review title (input text) and the sentiment label.
    df = pd.DataFrame(data[split])
    df = df[['title', 'label']]
    # Tab-separated, matching the delimiter the UDT model is configured with.
    df.to_csv(output_filename, index=False, sep='\t')


train_filename = "/dbfs/mnt/amazon_polarity_train.csv"
test_filename = "/dbfs/mnt/amazon_polarity_test.csv"

load_data(train_filename, split='train')
load_data(test_filename, split='test')

Train, evaluate, and save a Bolt UDT model.

In [None]:
model = bolt.UniversalDeepTransformer(
    data_types={
        "title": bolt.types.text(),
        "label": bolt.types.categorical()
    },
    n_target_classes=2,
    target="label",
    delimiter='\t',
)

In [None]:
train_config = (bolt.TrainConfig(epochs=1, learning_rate=0.01)
                    .with_metrics(["categorical_accuracy"]))

model.train(train_filename, train_config)

Loading vectors from '/dbfs/mnt/amazon_polarity_train.csv'
Loaded 3600000 vectors from '/dbfs/mnt/amazon_polarity_train.csv' in 6 seconds.
train epoch 0:


train | epoch 0 | updates 1758 | {categorical_accuracy: 0.84207} | batches 1758 | time 205s | complete



In [None]:
eval_config = (bolt.EvalConfig()
                   .with_metrics(["categorical_accuracy"]))

model.evaluate(test_filename, eval_config);

Loading vectors from '/dbfs/mnt/amazon_polarity_test.csv'
Loaded 400000 vectors from '/dbfs/mnt/amazon_polarity_test.csv' in 0 seconds.
test:


predict | epoch 1 | updates 1758 | {categorical_accuracy: 0.855015} | batches 196 | time 7624ms



In [None]:
save_location = "/dbfs/mnt/sentiment_analysis.model"
model.save(save_location)

## Run model inference using Pandas UDF

In [None]:
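# Enable Arrow-based data transfer between Spark and pandas.
# (Spark 3.0+ key; `spark.sql.execution.arrow.enabled` is the deprecated older name.)
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
# Cap each Arrow batch, and therefore each pandas UDF call, at 64 rows.
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "64")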

Load the test data into a Spark DataFrame.

In [None]:
test_datasets = "dbfs:/mnt/amazon_polarity_test.csv"
df = spark.read.option("header",True).option("sep", "\t").csv(test_datasets).repartition(6)
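
With the Arrow batch size set to 64 and the test data split into six partitions, it can help to estimate how many times a pandas UDF will be invoked, since any setup done inside the UDF body is repeated on every call. The sketch below is plain arithmetic, not a Spark API: `split_evenly` and `batches_per_partition` are hypothetical helpers, and it assumes `repartition(6)` spreads the 400,000 test rows evenly, which Spark does not strictly guarantee.

```python
import math


def split_evenly(total_rows: int, num_partitions: int) -> list:
    """Assumed row count per partition under a perfectly even split."""
    base, extra = divmod(total_rows, num_partitions)
    return [base + 1 if i < extra else base for i in range(num_partitions)]


def batches_per_partition(rows_in_partition: int, max_records_per_batch: int) -> int:
    """Number of Arrow batches needed to cover one partition."""
    return math.ceil(rows_in_partition / max_records_per_batch)


# 400,000 test rows, 6 partitions, at most 64 rows per Arrow batch.
partition_sizes = split_evenly(400_000, 6)
total_batches = sum(batches_per_partition(n, 64) for n in partition_sizes)

print(partition_sizes)
print(total_batches)
```

Under these assumptions the count comes to a few thousand UDF calls, so a larger `maxRecordsPerBatch` or per-partition setup may be worth considering if each call does expensive work.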

Define the function for model inference.

In [None]:
@pandas_udf("long")
def predict_batch_udf(titles: pd.Series) -> pd.Series:
    # Runs on the executors once per Arrow batch, so the license is set and
    # the model is loaded on every call.
    thirdai.set_thirdai_license_path("/dbfs/mnt/license.serialized")
    save_location = "/dbfs/mnt/sentiment_analysis.model"
    model = bolt.UniversalDeepTransformer.load(save_location, "classifier")
    preds = []
    for title in titles:
        pred = model.predict({"title": title})
        # Map the index of the highest-scoring class back to its label.
        class_name = model.class_name(pred.argmax())
        preds.append(int(class_name))
    return pd.Series(preds)
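
Because the UDF above loads the model inside its body, the load is repeated for every batch of up to 64 rows. One possible alternative is Spark's iterator-style pandas UDF (`Iterator[pd.Series] -> Iterator[pd.Series]`, available in Spark 3.0+), where setup runs once and the same model serves every batch in the iterator. The sketch below shows only the load-once pattern in plain Python so it runs without Spark or ThirdAI: `predict_batches`, `_StubModel`, `_StubScores`, and `load_stub_model` are hypothetical names, and the stub merely mimics the `predict` / `class_name` / `argmax` calls used in this notebook.

```python
from typing import Any, Callable, Iterable, Iterator, List


def predict_batches(
    batches: Iterable[Iterable[str]], load_model: Callable[[], Any]
) -> Iterator[List[int]]:
    """Load the model once, then yield one list of labels per input batch."""
    model = load_model()  # paid once per iterator, not once per batch
    for titles in batches:
        yield [
            int(model.class_name(model.predict({"title": title}).argmax()))
            for title in titles
        ]


# Hypothetical stand-ins so the sketch runs without Spark or ThirdAI.
class _StubScores(list):
    def argmax(self) -> int:
        return self.index(max(self))


class _StubModel:
    def predict(self, sample: dict) -> _StubScores:
        positive = 1.0 if "great" in sample["title"].lower() else 0.0
        return _StubScores([1.0 - positive, positive])

    def class_name(self, class_id: int) -> str:
        return str(class_id)


load_calls = []


def load_stub_model() -> _StubModel:
    load_calls.append(1)  # record every load so the demo can count them
    return _StubModel()


demo_batches = [["Great book", "Broke in a week"], ["Great value"]]
demo_predictions = list(predict_batches(demo_batches, load_stub_model))
print(demo_predictions, "model loads:", len(load_calls))

# Possible Spark wiring (assumes pyspark, pandas, and thirdai are available;
# left as comments so the sketch above stays standalone):
#
# @pandas_udf("long")
# def predict_iter_udf(batches: Iterator[pd.Series]) -> Iterator[pd.Series]:
#     def load_model():
#         thirdai.set_thirdai_license_path("/dbfs/mnt/license.serialized")
#         return bolt.UniversalDeepTransformer.load(
#             "/dbfs/mnt/sentiment_analysis.model", "classifier")
#     for preds in predict_batches(batches, load_model):
#         yield pd.Series(preds)
```

In the demo, two batches are served by a single model load. Whether this pays off in practice depends on how expensive loading the saved UDT model is relative to prediction, which this notebook does not measure.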
    

Run model inference and inspect the first predictions.

In [None]:
predictions_df = df.select(predict_batch_udf("title"))
predictions_df.head(10)

Out[26]: [Row(predict_batch_udf(title)=0),
 Row(predict_batch_udf(title)=0),
 Row(predict_batch_udf(title)=1),
 Row(predict_batch_udf(title)=0),
 Row(predict_batch_udf(title)=1),
 Row(predict_batch_udf(title)=1),
 Row(predict_batch_udf(title)=0),
 Row(predict_batch_udf(title)=0),
 Row(predict_batch_udf(title)=1),
 Row(predict_batch_udf(title)=0)]