# ThirdAI integration with Databricks

This notebook demonstrates distributed inference on Databricks with ThirdAI's Universal Deep Transformer (UDT), using the amazon_polarity dataset.

The guide consists of the following sections:
* Import all necessary libraries.
* Prepare a trained ThirdAI UDT model for inference.
* Run model inference using a Pandas UDF.

**Note:**
* This notebook must run on a Databricks cluster.
* Add your ThirdAI license file to `/dbfs/mnt/` (the code below expects `/dbfs/mnt/license.serialized`). To obtain a license, contact: https://www.thirdai.com/try-bolt/
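
Since later cells read the license from DBFS, it can help to fail early when the file is missing. A minimal sketch, assuming the license lives at the path used in the license cell of this notebook; `require_file` is a hypothetical helper for illustration and is not part of ThirdAI:

```python
from pathlib import Path


def require_file(path: str) -> Path:
    """Return `path` as a Path, or raise if no regular file exists there.

    Hypothetical helper for illustration; not part of the ThirdAI API.
    """
    p = Path(path)
    if not p.is_file():
        raise FileNotFoundError(
            f"Expected a file at {p}; upload it before running the notebook."
        )
    return p


# On a Databricks cluster one might call (path assumed from the license cell):
# require_file("/dbfs/mnt/license.serialized")
```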

## Import all necessary libraries

In [None]:
!python3 -m pip install --upgrade pip
!python3 -m pip install thirdai==0.5.4
!python3 -m pip install datasets
!python3 -m pip install utils

In [None]:
import pandas as pd
from datasets import load_dataset
import numpy as np

import thirdai
from thirdai import bolt

from pyspark.sql import SparkSession
from pyspark.sql.types import StringType, ArrayType, FloatType
from pyspark.sql.functions import col, pandas_udf


## Prepare a trained ThirdAI UDT model for inference

Set the ThirdAI license path. To obtain a license, contact: https://www.thirdai.com/try-bolt/

In [None]:
thirdai.set_thirdai_license_path("/dbfs/mnt/license.serialized")

Load the Amazon Polarity dataset and write the train and test splits to tab-separated files.

In [None]:
def load_data(output_filename, split):
    # Download only the requested split of amazon_polarity.
    data = load_dataset('amazon_polarity', split=split)

    # Keep the review title (model input) and the sentiment label (target).
    df = data.to_pandas()[['title', 'label']]
    # Write tab-separated text; the UDT model below is created with delimiter='\t'.
    df.to_csv(output_filename, index=False, sep='\t')


train_filename = "/dbfs/mnt/amazon_polarity_train.csv"
test_filename = "/dbfs/mnt/amazon_polarity_test.csv"

load_data(train_filename, split='train')
load_data(test_filename, split='test')
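
The files written above are plain tab-separated text with a `title` and `label` header row. A minimal sketch of that layout using only the standard library; the sample rows and the `write_tsv` / `read_tsv` helpers are made up for illustration and are not part of the notebook's pipeline:

```python
import csv
from typing import Dict, List


def write_tsv(path: str, rows: List[Dict[str, str]]) -> None:
    """Write rows with a `title<TAB>label` header, one record per line."""
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=["title", "label"], delimiter="\t")
        writer.writeheader()
        writer.writerows(rows)


def read_tsv(path: str) -> List[Dict[str, str]]:
    """Read the file back; every value is returned as a string."""
    with open(path, newline="", encoding="utf-8") as f:
        return list(csv.DictReader(f, delimiter="\t"))


# Made-up sample rows, not taken from the dataset.
sample_rows = [
    {"title": "Great book", "label": "1"},
    {"title": "Broke in a week", "label": "0"},
]
```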

Train, evaluate, and save a Bolt UDT model.

In [None]:
model = bolt.UniversalDeepTransformer(
    data_types={
        "title": bolt.types.text(),
        "label": bolt.types.categorical()
    },
    n_target_classes=2,
    target="label",
    delimiter='\t',
)

In [None]:
train_config = (bolt.TrainConfig(epochs=1, learning_rate=0.01)
                    .with_metrics(["categorical_accuracy"]))

model.train(train_filename, train_config)

In [None]:
eval_config = (bolt.EvalConfig()
                   .with_metrics(["categorical_accuracy"]))

model.evaluate(test_filename, eval_config);

In [None]:
save_location = "/dbfs/mnt/sentiment_analysis.model"
model.save(save_location)

## Run model inference using Pandas UDF

In [None]:
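# Enable Arrow-based data transfer between Spark and pandas.
# On Spark 3.0+ this key is deprecated in favor of "spark.sql.execution.arrow.pyspark.enabled".
spark.conf.set("spark.sql.execution.arrow.enabled", "true")
# Cap each Arrow batch, and so each pandas UDF call, at 64 rows.
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "64")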

Load the test data into a Spark DataFrame.

In [None]:
test_datasets = "dbfs:/mnt/amazon_polarity_test.csv"
df = spark.read.option("header",True).option("sep", "\t").csv(test_datasets).repartition(6)
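
With `maxRecordsPerBatch` set to 64 and the data split into six partitions, each pandas UDF call receives at most 64 rows, so the number of calls grows with the data size. A rough estimate, assuming `repartition(6)` spreads rows evenly; the row count is an illustrative figure, not read from the dataset, and `batches_per_partition` is a hypothetical helper:

```python
import math
from typing import List


def batches_per_partition(n_rows: int, n_partitions: int, max_records: int) -> List[int]:
    """Estimate UDF calls per partition, assuming rows are spread evenly."""
    base, extra = divmod(n_rows, n_partitions)
    # The first `extra` partitions hold one additional row each.
    sizes = [base + (1 if i < extra else 0) for i in range(n_partitions)]
    return [math.ceil(size / max_records) for size in sizes]


# Illustrative row count (an assumption, not read from the dataset).
calls = batches_per_partition(n_rows=400_000, n_partitions=6, max_records=64)
print(calls, sum(calls))
```

Any setup work done inside the UDF body is repeated once per call, so a small batch size trades lower memory use per call for more calls.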

Define the function for model inference.

In [None]:
@pandas_udf("long")
def predict_batch_udf(titles: pd.Series) -> pd.Series:
    # Runs on the executors once per Arrow batch, so the license path is set
    # and the model is loaded from DBFS on every call.
    thirdai.set_thirdai_license_path("/dbfs/mnt/license.serialized")
    save_location = "/dbfs/mnt/sentiment_analysis.model"
    model = bolt.UniversalDeepTransformer.load(save_location, "classifier")
    preds = []
    for title in titles:
        # Score one title, then map the highest-scoring class back to its label.
        pred = model.predict({"title": title})
        class_name = model.class_name(pred.argmax())
        preds.append(int(class_name))
    return pd.Series(preds)
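
Because the UDF above is called once per batch, it reloads the model each time. Spark 3.0+ also supports an iterator form of pandas UDF (`Iterator[pd.Series] -> Iterator[pd.Series]`), where one-time setup such as loading a model can happen before the loop over batches. The sketch below shows only that control flow in plain Python, with lists standing in for `pd.Series`; `StubModel`, `predict_label`, `predict_batches`, and `counting_loader` are hypothetical stand-ins, not ThirdAI or Spark APIs:

```python
from typing import Callable, Iterable, Iterator, List


class StubModel:
    """Hypothetical stand-in for a loaded classifier (not the ThirdAI API)."""

    def predict_label(self, title: str) -> int:
        # Toy rule so the example runs without a trained model.
        return 1 if "great" in title.lower() else 0


def predict_batches(
    batches: Iterable[List[str]],
    load_model: Callable[[], StubModel],
) -> Iterator[List[int]]:
    model = load_model()  # expensive setup happens once, before the loop
    for batch in batches:
        # One output batch per input batch, with matching length.
        yield [model.predict_label(title) for title in batch]


loads: List[int] = []  # records each time the loader runs


def counting_loader() -> StubModel:
    loads.append(1)
    return StubModel()


results = list(
    predict_batches([["Great book", "Broke in a week"], ["great value"]], counting_loader)
)
print(results, len(loads))
```

In a real iterator UDF the loader would be the license and model-loading calls from the cell above, and each yielded value would be a `pd.Series` of the same length as its input batch.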
    

Run the model inference and get the predictions.

In [None]:
predictions_df = df.select(predict_batch_udf("title"))
predictions_df.head(10)