<header>
   <p  style='font-size:36px;font-family:Arial; color:#F0F0F0; background-color: #00233c; padding-left: 20pt; padding-top: 20pt;padding-bottom: 10pt; padding-right: 20pt;'>
       Supervised Text Classification using ONNXEmbeddings and BYOM
 <br>       
       <img id="teradata-logo" src="https://storage.googleapis.com/clearscape_analytics_demo_data/DEMO_Logo/teradata.svg" alt="Teradata" style="width: 150px; height: auto; margin-top: 20pt;">
  <br>
    </p>
</header>


<p style = 'font-size:20px;font-family:Arial;'><b>Introduction</b></p>
<p style = 'font-size:16px;font-family:Arial;'>
Text classification is a crucial task in natural language processing (NLP) with applications across various domains, including healthcare, customer support, and log analysis. When dealing with manually labeled text data, training a supervised classification model can significantly enhance automation and decision-making. However, creating and classifying text manually is often expensive and time-consuming.</p>

<p style = 'font-size:16px;font-family:Arial;'>
In this blog post, we explore a practical approach to text classification using ONNXEmbeddings and the Bring Your Own Model (BYOM) framework. This method allows for efficient embedding generation and classification model deployment directly within a database, enabling real-time predictions and decision support.</p>

<p style = 'font-size:16px;font-family:Arial;'>
The specific use case is largely interchangeable; here we focus on a medical one: classifying patient conditions based on medical abstracts written by doctors. Given the high volume of abstracts reviewed daily in hospitals, an assistive technology that can accurately predict the condition category can be invaluable.</p>

<p style = 'font-size:16px;font-family:Arial;'>To achieve this, we will:</p>
<ul style = 'font-size:16px;font-family:Arial;'>
  <li>Generate text embeddings from medical abstracts</li>
  <li>Train a supervised classification model on these embeddings</li>
  <li>Deploy both the embedding generation and classification model using ONNX</li>
</ul>
<p style = 'font-size:16px;font-family:Arial;'>The approach is further illustrated in this diagram:
</p>

<img src=images/workflow.png style="border: 4px solid #404040; border-radius: 10px;"/>

<p style = 'font-size:16px;font-family:Arial;'>
This approach provides an alternative to simple zero-shot classification, which assigns each text to the predefined target description whose embedding is most similar. Instead, by leveraging a dataset of historically labeled data (whether binary, multi-class, or multiple binary classifications), we can typically achieve higher accuracy, because the classifier learns from many labeled examples per class rather than from a single description.</p>
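<p style = 'font-size:16px;font-family:Arial;'>For contrast, the zero-shot baseline can be sketched in a few lines: embed one description per class, then assign each text to the most similar description by cosine similarity. This is a minimal illustration, not code from this post; the toy 3-dimensional vectors and label names are made up (real gte-base-en-v1.5 vectors have 768 dimensions). The supervised approach replaces this fixed argmax over label vectors with a classifier trained on labeled embeddings.</p>

```python
import numpy as np

def cosine_similarity_matrix(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Pairwise cosine similarity between the rows of a and the rows of b."""
    a_unit = a / np.linalg.norm(a, axis=1, keepdims=True)
    b_unit = b / np.linalg.norm(b, axis=1, keepdims=True)
    return a_unit @ b_unit.T

def zero_shot_predict(text_embeddings: np.ndarray,
                      label_embeddings: np.ndarray,
                      label_names: list) -> list:
    """Assign each text to the label whose description embedding is most similar."""
    similarities = cosine_similarity_matrix(text_embeddings, label_embeddings)
    return [label_names[i] for i in similarities.argmax(axis=1)]

# Toy 3-dimensional vectors standing in for embeddings of label descriptions and texts
label_names = ["cardiovascular", "neoplasms"]
label_embeddings = np.array([[1.0, 0.0, 0.0],
                             [0.0, 1.0, 0.0]])
text_embeddings = np.array([[0.9, 0.1, 0.2],
                            [0.2, 0.8, 0.1]])
print(zero_shot_predict(text_embeddings, label_embeddings, label_names))
# ['cardiovascular', 'neoplasms']
```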


<p style = 'font-size:20px;font-family:Arial;'><b>Dataset Overview</b></p>
<p style = 'font-size:16px;font-family:Arial;'>
We use the <b>Medical Abstracts Text Classification Dataset</b>, originally compiled by Schopf, Braun, and Matthes (2023) in their paper <i>"Evaluating Unsupervised Text Classification: Zero-Shot and Similarity-Based Approaches."</i> This dataset contains medical abstracts categorized into five condition types:</p>
<ul style = 'font-size:16px;font-family:Arial;'>
  <li>Digestive system diseases</li>
  <li>Cardiovascular diseases</li>
  <li>Neoplasms</li>
  <li>Nervous system diseases</li>
  <li>General pathological conditions</li>
</ul>

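<p style = 'font-size:16px;font-family:Arial;'>The dataset consists of 11,550 training records and 2,888 test records. The labeled training records are used to fit the model, and the goal is to predict the classes of the test records. The test table also includes a <code>condition_label</code> column (see the schema below), which makes it possible to check those predictions.</p>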

<p style = 'font-size:16px;font-family:Arial;'>For those interested, the dataset is publicly available in this <a href="https://github.com/sebischair/Medical-Abstracts-TC-Corpus.git" target="_blank">GitHub repository</a> and is licensed under the Creative Commons Attribution-ShareAlike 3.0 Unported License.</p>

<p style = 'font-size:16px;font-family:Arial;'>In the following sections, we will walk through the training and deployment workflows, showcasing how ONNX embeddings and a BYOM approach can streamline text classification for real-world applications.</p>


<hr style="height:2px;border:none;">
<p style = 'font-size:20px;font-family:Arial;'><b>1. Connect to Vantage, Import Python Packages, and Load the Data</b></p>


In [None]:
!pip install -r requirements.txt --quiet

In [None]:
%%capture
!pip install teradataml --upgrade
!pip install huggingface_hub --upgrade
!pip install teradataml-plus --upgrade

<div class="alert alert-block alert-info">
<p style = 'font-size:16px;font-family:Arial;'><b>Note: </b><i>The above libraries have to be installed. Restart the kernel after executing these cells to load the newly installed libraries. The simplest way to restart the kernel is by typing <b>0 0</b></i> (zero zero) and pressing <i>Enter</i>.</p>
</div>
<p style = 'font-size:16px;font-family:Arial;'>Here, we import the required libraries, set environment variables and environment paths (if required).</p>

In [None]:
import getpass
import tdmlplus
import pandas as pd
import numpy as np
import json
from collections import OrderedDict
import time

import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.manifold import TSNE
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
import catboost
import onnx
from skl2onnx.common.data_types import FloatTensorType
import onnxruntime as rt
from teradataml import *
import plotly.express as px
import plotly.io as pio
import re
import plotly.graph_objects as go
import ipywidgets as widgets
from IPython.display import display, Markdown

In [None]:
# imports the function `display_dataframes_in_tabs`
%run utils/tab_widget.py
list_relevant_tables = [] # names of relevant tables are added to this list progressively for display

<hr style="height:2px;border:none;">
<b style = 'font-size:18px;font-family:Arial;'> 1.1 Connect to Vantage</b>
<p style = 'font-size:16px;font-family:Arial;'>You will be prompted for the password. Enter the password, press the Enter key, and then use the down arrow to go to the next cell.</p>

In [None]:
%run -i ../../UseCases/startup.ipynb
eng = create_context(host = 'host.docker.internal', username='demo_user', password = password)
print(eng)

In [None]:
%%capture
execute_sql('''SET query_band='DEMO=Telco_Customer_Churn_EFS.ipynb;' UPDATE FOR SESSION; ''')

In [None]:
# takes about 1 minute
%run utils/_dataload.ipynb

<p style = 'font-size:16px;font-family:Arial;'>Next, we check whether the database already provides the functions required to generate embeddings.</p>


In [None]:
DataFrame.from_query("select InfoKey, InfoData FROM DBC.DBCInfoV")

In [None]:
VCL = False # current system is VCE/VCore

In [None]:
if VCL:
    results = execute_sql("help database mldb").fetchall()
else:
    results = execute_sql("help user mldb").fetchall()

embeddings_functions = [x[0] for x in results if x[0].startswith("ONNXEmbeddings")]
if len(embeddings_functions) > 0:
    print("\n".join(embeddings_functions))
    print("---------------------\nONNXEmbeddings is installed")
else:
    print("ONNXEmbeddings is not installed. Please upgrade to BYOM version 6")

<hr style="height:2px;border:none;">
<p style = 'font-size:20px;font-family:Arial;'><b>2. Data Exploration</b></p>


<p style = 'font-size:16px;font-family:Arial;'>Before training our model, we first inspect the dataset to understand its structure. We use the function <code>display_dataframes_in_tabs()</code>, which presents the three key tables as interactive tabs in an IPython widget.</p>
<ul style = 'font-size:16px;font-family:Arial;'>
<li>medical_train and medical_test: These datasets have the same schema with three columns:</li>
<ul>
<li>row_id: Unique identifier for each record.</li>
<li>medical_abstract: The textual description of a patient's condition.</li>
<li>condition_label: A numerical label (1-5) representing the diagnosed condition.</li>
</ul>
<li>medical_labels: A lookup table mapping each condition label to its condition name.</li>
</ul>

In [None]:
list_relevant_tables+=["medical_train","medical_test","medical_labels"]

In [None]:
display_dataframes_in_tabs(list_relevant_tables)

In [None]:
DF_train_raw = DataFrame("medical_train")
DF_test_raw = DataFrame("medical_test")
DF_labels = DataFrame("medical_labels")

In [None]:
df_labels = DF_labels.to_pandas().set_index("condition_label")
condition_dict = df_labels.to_dict()["condition_name"]
condition_dict

<p  style = 'font-size:16px;font-family:Arial;'>To check the distribution of the classes we are going to predict, we aggregate the training data. We do this in the database: by keeping the heavy computation in Vantage and retrieving only the summarized counts, we reduce memory usage on the client and minimize data transfer.</p>
<p  style = 'font-size:16px;font-family:Arial;'>Although the bar chart below shows that the dataset is not perfectly balanced, each condition category still has a sufficient number of examples for training a reliable classification model.</p>

In [None]:
DF_agg = (DF_train_raw
     .groupby("condition_label")
     .agg([DF_train_raw['row_id'].count().alias("num_rows")]))
df_agg = DF_agg.to_pandas().sort_values("condition_label")
df_agg["condition"] = df_agg.condition_label.astype(str) +" - "+ df_agg.condition_label.map(condition_dict)
df_agg.plot(x = "condition", y = "num_rows", kind = "bar")

<hr style="height:2px;border:none;">
<p style = 'font-size:20px;font-family:Arial;'><b>3. Load HuggingFace Model</b></p>


<p style = 'font-size:16px;font-family:Arial;'>To generate embeddings, we need an ONNX model capable of transforming text into vector representations. We use the pretrained gte-base-en-v1.5 model from
<a href="https://huggingface.co/Teradata/gte-base-en-v1.5" target="_blank">Teradata's Hugging Face repository</a>. The model and its tokenizer are downloaded and stored in Vantage tables as BLOBs using the <code>save_byom</code> function.</p>

In [None]:
from huggingface_hub import hf_hub_download

model_name = "gte-base-en-v1.5"
number_dimensions_output = 768
model_file_name = "model.onnx" 

In [None]:
# Step 1: Download Model from Teradata HuggingFace Page
hf_hub_download(repo_id=f"Teradata/{model_name}", filename="tokenizer.json", local_dir="./")

In [None]:
# Using the command-line syntax, as it is more reliable than the Python function

In [None]:
!hf download Teradata/{model_name} onnx/{model_file_name} --local-dir ./

In [None]:
# Drop the model and tokenizer tables if they already exist
for table in ["embeddings_models", "embeddings_tokenizers"]:
    try:
        db_drop_table(table)
    except Exception:
        pass  # table does not exist yet

In [None]:
# Step 2: Load Models into Vantage
# a) Embedding model
save_byom(model_id = model_name, # must be unique in the models table
          model_file = f"onnx/{model_file_name}",
          table_name = 'embeddings_models')
# b) Tokenizer
save_byom(model_id = model_name, # must be unique in the tokenizers table
          model_file = 'tokenizer.json',
          table_name = 'embeddings_tokenizers')

In [None]:
display_dataframes_in_tabs(["embeddings_models","embeddings_tokenizers"])

<hr style="height:2px;border:none;">
<p style = 'font-size:20px;font-family:Arial;'><b>4. Generate Embeddings with ONNXEmbeddings</b></p>


<p style = 'font-size:16px;font-family:Arial;'>Now it's time to generate the embeddings using <code>ONNXEmbeddings</code>.</p>
<p style = 'font-size:16px;font-family:Arial;'>We run the <code>ONNXEmbeddings</code> function to generate embeddings for a small subset of records. The model is <b>loaded into the cache memory on each node</b>, and Teradata's <b>Massively Parallel Processing (MPP)</b> architecture ensures that embeddings are computed in parallel using <b>ONNX Runtime</b> on each node.</p>
<p style = 'font-size:16px;font-family:Arial;'>Having said that, generating embeddings for the entire training set can be time-consuming, especially on a system with limited resources. The <b>ClearScape Analytics Experience</b> provides only a <b>4-AMP system</b> with constrained RAM and CPU power. To ensure smooth execution, we test embedding generation on a small sample and use <b>pre-calculated embeddings</b> for the remainder of this blog post. Production systems typically have hundreds of AMPs and far more compute power.</p>
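<p style = 'font-size:16px;font-family:Arial;'>To make the scaling argument concrete, here is a conceptual Python analogy; it is not how Vantage is implemented. For tables with a primary index, Vantage distributes rows across AMPs by hashing that index; here <code>crc32</code> of <code>row_id</code> is an assumed stand-in. Each worker loads its model once and embeds only its own rows, so wall-clock time grows roughly with the number of rows per AMP. The helper names are hypothetical, <code>len()</code> is a placeholder for the ONNX model, and Python threads only illustrate the structure, not MPP performance.</p>

```python
import zlib
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor

N_AMPS = 4  # the ClearScape Analytics Experience size mentioned above

def amp_for(row_id: int, n_amps: int = N_AMPS) -> int:
    """Stand-in for hash distribution: map a row key to one of n_amps workers."""
    return zlib.crc32(str(row_id).encode("utf-8")) % n_amps

def partition(rows, n_amps: int = N_AMPS) -> dict:
    """Group (row_id, text) pairs by the worker that owns them."""
    parts = defaultdict(list)
    for row_id, text in rows:
        parts[amp_for(row_id, n_amps)].append((row_id, text))
    return dict(parts)

def embed_partition(part):
    """Each worker 'loads' its model once and embeds only its own rows.

    Hypothetical placeholder: len() stands in for a 768-dimensional ONNX embedding.
    """
    model = len
    return [(row_id, model(text)) for row_id, text in part]

def embed_all(rows, n_amps: int = N_AMPS):
    """Process all partitions independently, then collect results ordered by row_id."""
    parts = partition(rows, n_amps)
    with ThreadPoolExecutor(max_workers=n_amps) as pool:
        results = list(pool.map(embed_partition, parts.values()))
    return sorted(pair for part_result in results for pair in part_result)

rows = [(i, "abstract text " * i) for i in range(1, 11)]
print({amp: len(part) for amp, part in sorted(partition(rows).items())})
```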
<p style = 'font-size:16px;font-family:Arial;'>Also have a look at the most important input parameters of this <code>ONNXEmbeddings</code> function.</p>

<ul style = 'font-size:16px;font-family:Arial;'> 
<li><code><b>InputTable</b></code>: The source table containing the text to be embedded. <code>ONNXEmbeddings</code> expects the input column to be named <code>txt</code>, so we expose <code>medical_abstract</code> as <code>txt</code>: with <code>assign()</code> in the Python call below, or with a column alias in a SQL subquery.</li>
<li><code><b>ModelTable</b></code>: The table storing the ONNX model.</li>
<li><code><b>TokenizerTable</b></code>: The table storing the tokenizer JSON file.</li>
<li><code><b>Accumulate</b></code>: Specifies additional columns to retain in the output (<code>row_id</code>, <code>condition_label</code>, and <code>txt</code>).</li>
<li><code><b>OutputFormat</b></code>: Specifies the data format of the output embeddings (<code>FLOAT32(768)</code>, matching the model's output dimension).</li>
</ul>

<p style = 'font-size:16px;font-family:Arial;'>Since embedding generation is computationally expensive, we only process <b>10 records for testing</b> and rely on precomputed embeddings for further analysis.</p>

In [None]:
configure.byom_install_location = "mldb"

In [None]:
DF_sample10 = DataFrame("medical_train")
DF_sample10 = DF_sample10.assign(txt = DF_sample10.medical_abstract).top(10)

In [None]:
my_model = DataFrame.from_query(f"select * from {username}.embeddings_models where model_id = '{model_name}'")
my_tokenizer = DataFrame.from_query(f"select model as tokenizer from {username}.embeddings_tokenizers where model_id = '{model_name}'")

In [None]:
# Step 3: Test the ONNXEmbeddings function
# ONNXEmbeddings expects the text column to be named 'txt' (created above with assign()).
# If your column has a different name, rename it in a subquery/CTE or with assign().
DF_embeddings10 = ONNXEmbeddings(
    newdata = DF_sample10,
    modeldata = my_model,
    tokenizerdata = my_tokenizer, 
    accumulate = ['row_id','condition_label', 'txt'],
    model_output_tensor = "sentence_embedding",
    output_format = f'FLOAT32({number_dimensions_output})',
    enable_memory_check = False    
).result

In [None]:
DF_embeddings10

In [None]:
list_relevant_tables.append("medical_train_embedding")
DF_train_embeddings = DataFrame("medical_train_embedding")
DF_train_embeddings.shape

<hr style="height:2px;border:none;">

<p style = 'font-size:20px;font-family:Arial;'><b>5. Visualizing Embeddings with t-SNE</b></p>
<p style = 'font-size:16px;font-family:Arial;'>Once we have generated embeddings for the texts, the next step is to visualize them using <a href="https://en.wikipedia.org/wiki/T-distributed_stochastic_neighbor_embedding" target="_blank">t-SNE (t-Distributed Stochastic Neighbor Embedding)</a>. t-SNE is a dimensionality reduction technique that projects the high-dimensional embeddings into two dimensions for easier visualization, while preserving the local structure of the data. By plotting the embeddings, we aim to see how well the model has captured the underlying structure of the data, specifically how the classes labeled by human analysts are distributed in the embedding space.</p>
<p style = 'font-size:16px;font-family:Arial;'>In this case, after creating the embeddings in Teradata, we are looking to uncover any patterns or clusters in the data that align with the analyst-assigned labels. The coloring of the plot represents the target classes, and what we expect to see is some degree of separation between these classes. If we observe distinct clusters for each class, it suggests that the embeddings effectively capture the differences between the labels.</p>
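<p style = 'font-size:16px;font-family:Arial;'>The visual check can be complemented by a number. The sketch below uses the silhouette coefficient from scikit-learn, a technique swapped in here that this post does not use itself: it compares each point's distance to its own class with its distance to the nearest other class. <code>class_separation</code> is a hypothetical helper name, and the toy data are made up.</p>

```python
import numpy as np
from sklearn.metrics import silhouette_score

def class_separation(embeddings: np.ndarray, labels, metric: str = "cosine") -> float:
    """Mean silhouette coefficient of the labeled classes in embedding space.

    Near 1: tight, well-separated classes; near 0: overlapping classes;
    negative: many points lie closer to another class than to their own.
    """
    return float(silhouette_score(embeddings, labels, metric=metric))

# Toy example: two well-separated 2-D groups (made-up data)
rng = np.random.default_rng(0)
toy_x = np.vstack([rng.normal(0.0, 0.1, size=(20, 2)),
                   rng.normal(5.0, 0.1, size=(20, 2))])
toy_y = ["a"] * 20 + ["b"] * 20
print(round(class_separation(toy_x, toy_y, metric="euclidean"), 2))
```

<p style = 'font-size:16px;font-family:Arial;'>With the notebook's data, <code>class_separation(X, y)</code> would score the 768-dimensional training embeddings against the condition labels; per-sample values from <code>sklearn.metrics.silhouette_samples</code> could show which categories overlap most.</p>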

In [None]:
df_train_embeddings = DataFrame("medical_train_embedding").to_pandas()

df_train_embeddings = df_train_embeddings.sort_values("condition_label").reset_index()
df_train_embeddings["condition"] = df_train_embeddings.condition_label.astype(str) +" - "+ df_train_embeddings.condition_label.map(condition_dict)

X = df_train_embeddings[[f"emb_{i}" for i in range(number_dimensions_output)]].values
y = df_train_embeddings["condition"].values

In [None]:
# Project the 768-dimensional embeddings X onto 2 dimensions with t-SNE
tsne = TSNE(n_components=2, random_state=42, perplexity=30)
X_tsne = tsne.fit_transform(X)

df_tsne = pd.DataFrame(X_tsne, columns = ["tsne_1","tsne_2"])
df_tsne["condition"] = y

In [None]:
pio.renderers.default = "notebook_connected"

def plot_tsne_plotly(this_df):
    palette = sns.color_palette("tab10", n_colors=len(this_df["condition"].unique())).as_hex()
    fig = px.scatter(
        this_df, 
        x="tsne_1", 
        y="tsne_2", 
        color="condition", 
        title="t-SNE Visualization of Embeddings",
        labels={"tsne_1": "t-SNE Component 1", "tsne_2": "t-SNE Component 2"},
        opacity=0.3,
        color_discrete_sequence=palette,
        size_max=50 
    )
    
    fig.update_layout(
        legend_title_text="Condition",
        legend=dict(x=1.05, y=1, traceorder="normal"),
        width=1200,
        height=800,
        plot_bgcolor='white',  # white background
        xaxis=dict(showgrid=False, range=[-110, 110]),  # no grid, fixed axis range
        yaxis=dict(showgrid=False, range=[-110, 110])   # no grid, fixed axis range
    )
    
    fig.show()

In [None]:
plot_tsne_plotly(df_tsne)

<p style = 'font-size:16px;font-family:Arial;'>Take some time to understand the plot. In the interactive Plotly chart, you can click a condition in the legend to hide or show it, and double-click a condition to show only that one.</p>
<p style = 'font-size:16px;font-family:Arial;'>Looking at the t-SNE scatterplot, we can see that most labels form clear groups in different areas of the plot, showing that the embeddings capture meaningful differences between the texts. For example, the yellow label (2 - digestive system diseases) has at least four separate clusters spread out across the plot. This suggests that even though the texts share the same label, they differ substantially. A single topic vector for this category would therefore match poorly in a vector-distance (zero-shot) approach, so a more advanced supervised model is needed to handle these variations.</p>
<p style = 'font-size:16px;font-family:Arial;'>On the other hand, category "5 - general pathological conditions" doesn't form a clear cluster. This is expected, as this category is broad and covers a wide range of topics, which makes it harder to group the texts together. This shows why it's useful to train a supervised machine learning model: it can learn these differences and improve the classification by recognizing the complexity within each category.</p>
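<p style = 'font-size:16px;font-family:Arial;'>The point about a single topic vector can be illustrated with made-up 2-D data. When a class forms two distant clusters, its average vector (one prototype per class, roughly what matching against a single description does) lands in empty space. Even a simple supervised rule such as 1-nearest-neighbour, which uses every labeled example, avoids this. Both helpers are hypothetical illustrations; CatBoost, used below, is a different supervised model.</p>

```python
import numpy as np

# Toy 2-D data: class "A" forms two far-apart clusters; class "B" is a single cluster
a_points = np.array([[-5.0, 0.0], [-5.2, 0.1], [5.0, 0.0], [5.2, -0.1]])
b_points = np.array([[2.5, 2.0], [2.6, 2.2], [2.4, 1.9]])
X_toy = np.vstack([a_points, b_points])
y_toy = np.array(["A"] * len(a_points) + ["B"] * len(b_points))

def nearest_centroid(x, X, y):
    """One prototype (mean vector) per class, like a single topic vector."""
    labels = np.unique(y)
    centroids = np.array([X[y == lab].mean(axis=0) for lab in labels])
    return labels[np.argmin(np.linalg.norm(centroids - x, axis=1))]

def nearest_neighbor(x, X, y):
    """A simple supervised rule that can use every labeled example."""
    return y[np.argmin(np.linalg.norm(X - x, axis=1))]

query = np.array([4.8, 0.2])  # sits inside A's right-hand cluster
print(nearest_centroid(query, X_toy, y_toy))  # B (A's centroid is near the origin)
print(nearest_neighbor(query, X_toy, y_toy))  # A
```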


<img src=images/tsne_embeddings.gif style="border: 4px solid #404040; border-radius: 10px;"/>


<hr style="height:2px;border:none;">

<p style = 'font-size:20px;font-family:Arial;'><b>6. Train and Evaluate CatBoost Model for Classification</b></p>
<p style = 'font-size:16px;font-family:Arial;'>Next, we train a supervised machine learning model to classify the texts based on the embeddings. While many models could work, we choose CatBoost for its strong performance with minimal tuning required.</p>

In [None]:
key = "row_id"
target = "condition"

# Hold out 20% of the labeled training embeddings to evaluate the model
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, random_state=42)

model_catb = catboost.CatBoostClassifier(loss_function="MultiClass", iterations=100)

model_catb.fit(X_train, y_train)

In [None]:
# Predict the labels for the test set
y_pred = model_catb.predict(X_test)

# Generate the confusion matrix
cm = confusion_matrix(y_test, y_pred)

# Generate classification report
report = classification_report(y_test, y_pred)

print(report)

In [None]:
# Create a DataFrame for the confusion matrix
cm_df = pd.DataFrame(cm, index=model_catb.classes_, columns=model_catb.classes_)

# Plot the confusion matrix as a heatmap
plt.figure(figsize=(10, 7))
sns.heatmap(cm_df, annot=True, fmt='d', cmap='Blues')
plt.ylabel('Actual')
plt.xlabel('Predicted')
plt.title('Confusion Matrix Heatmap')
plt.show()

<p style = 'font-size:16px;font-family:Arial;'>The model shows decent performance, with an overall accuracy of 51%. However, the "general pathological conditions" category is the largest group and has the highest confusion: its texts often overlap with other diseases, which makes it hard to distinguish. Excluding this category, the scores for the other groups would likely be noticeably higher. Some categories, like "digestive system diseases" and "nervous system diseases," have lower precision and recall, which could be due to the difficulty of distinguishing these diseases from others, or to multiple conditions appearing in the same text.</p>
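<p style = 'font-size:16px;font-family:Arial;'>The effect of excluding one category can be estimated directly from the confusion matrix: drop that class's row and column and recompute accuracy. A minimal sketch with hypothetical helper names and a made-up 3×3 matrix. Dropping the column also discards texts that were predicted as the excluded class, so the result is an optimistic, rough indicator rather than the score of a retrained model.</p>

```python
import numpy as np

def accuracy(cm: np.ndarray) -> float:
    """Overall accuracy: correct predictions (the diagonal) over all predictions."""
    return float(np.trace(cm) / cm.sum())

def accuracy_without_class(cm: np.ndarray, class_index: int) -> float:
    """Accuracy after dropping one class's row (actual) and column (predicted)."""
    keep = [i for i in range(cm.shape[0]) if i != class_index]
    return accuracy(cm[np.ix_(keep, keep)])

# Made-up 3-class confusion matrix; class 2 acts as a broad "catch-all" category
toy_cm = np.array([[8, 1, 1],
                   [1, 7, 2],
                   [4, 4, 2]])
print(round(accuracy(toy_cm), 3), round(accuracy_without_class(toy_cm, 2), 3))
# 0.567 0.882
```

<p style = 'font-size:16px;font-family:Arial;'>On the notebook's results, <code>cm</code> could be passed together with the position in <code>cm_df.index</code> of the label starting with <code>"5 - "</code>.</p>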

<p style = 'font-size:16px;font-family:Arial;'>If you're not satisfied with the results and want to improve performance, there are several steps you can take:</p>
<ul style = 'font-size:16px;font-family:Arial;'>
<li>Increase the number of iterations during model training to allow CatBoost to learn more from the data.</li>
<li>Tune the hyperparameters of the model to find the optimal configuration for your dataset. See the <a href="https://catboost.ai/docs/en/concepts/parameter-tuning" target="_blank">CatBoost parameter tuning documentation</a> for details.</li>
<li>Consider adding structured data, such as patient information, disease history, or physiological measurements, to provide the model with more context and improve predictions.</li>
<li>Fine-tuning the embeddings model could also help capture more relevant features. Alternatively, you might want to try using an open-source model that has already been trained on medical data, as this could bring better results.</li>
</ul>

<p style = 'font-size:16px;font-family:Arial;'>The next step is to convert the model to the ONNX format, making it compatible for deployment within Vantage. We then test the model by running it in a local ONNX runtime to ensure it works as expected. Finally, we upload the model into a table for further use.</p>
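<p style = 'font-size:16px;font-family:Arial;'>"Works as expected" can be made precise by comparing the ONNX Runtime probabilities with the native CatBoost probabilities. A sketch, assuming (as the JSON parsed in the widget section suggests) that the ONNX <code>probabilities</code> output is a list with one {class: probability} dict per row, and that its keys match <code>model_catb.classes_</code>. If the export uses class indices instead, pass <code>range(len(model_catb.classes_))</code>. Helper names and the tolerance are assumptions.</p>

```python
import numpy as np

def zipmap_to_array(prob_dicts, classes) -> np.ndarray:
    """Convert one {class: probability} dict per row into a 2-D array ordered by `classes`."""
    return np.array([[row[c] for c in classes] for row in prob_dicts], dtype=np.float64)

def outputs_match(native_proba, onnx_prob_dicts, classes, atol: float = 1e-5) -> bool:
    """True if ONNX probabilities agree with the native predict_proba output within atol."""
    onnx_proba = zipmap_to_array(onnx_prob_dicts, classes)
    return onnx_proba.shape == np.shape(native_proba) and bool(
        np.allclose(onnx_proba, native_proba, atol=atol))
```

<p style = 'font-size:16px;font-family:Arial;'>In the notebook this could be called as <code>outputs_match(model_catb.predict_proba(X[:3]), probabilities, model_catb.classes_)</code>. The small tolerance allows for the float32 input used by ONNX Runtime.</p>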

In [None]:
onnxfile_catb = "catb-medical.onnx"
model_catb_id = "catb_medical"

model_catb.save_model(
    onnxfile_catb,
    format="onnx",
    export_parameters={
        'onnx_domain': 'ai.catboost',
        'onnx_model_version': 1,
        'onnx_doc_string': 'model for MultiClassification of medical abstracts',
        'onnx_graph_name': 'CatBoostModel_for_MultiClassification'
    })

sess = rt.InferenceSession(onnxfile_catb)

label, probabilities = sess.run(['label', 'probabilities'],
                                {'features': X[:3].astype(np.float32)})
print(probabilities)

model_table_name = "medical_models"
try:
    db_drop_table(model_table_name)
except Exception:
    pass  # table does not exist yet

save_byom(model_id = model_catb_id,
          model_file = onnxfile_catb,
          table_name = model_table_name)

display_dataframes_in_tabs([model_table_name])

In [None]:
DF_test_embeddings = DataFrame("medical_test_embedding")

<p style = 'font-size:16px;font-family:Arial;'>Next, we test the inference in-database using ONNXPredict. We use the ONNXPredict Python wrapper from <code>teradataml</code>; its <code>show_query()</code> method prints the generated SQL, confirming that the prediction runs as an in-database function.</p>
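<p style = 'font-size:16px;font-family:Arial;'>The <code>model_input_fields_map</code> argument tells ONNXPredict which input columns make up the model's <code>features</code> tensor. The hypothetical helpers below sketch, in plain Python, what the range notation corresponds to. The columns <code>emb_0</code> to <code>emb_767</code> written by ONNXEmbeddings are stacked, presumably in column order, into one float32 vector per row, just like <code>X[:3].astype(np.float32)</code> in the local ONNX Runtime test.</p>

```python
import numpy as np

N_DIM = 768  # number_dimensions_output in this notebook

def embedding_columns(n_dim: int = N_DIM) -> list:
    """Column names of the embedding output: emb_0 ... emb_{n_dim-1}."""
    return [f"emb_{i}" for i in range(n_dim)]

def input_fields_map(tensor_name: str = "features", n_dim: int = N_DIM) -> str:
    """Range notation used in ModelInputFieldsMap: all embedding columns feed one tensor."""
    return f"{tensor_name}=emb_0:emb_{n_dim - 1}"

def rows_to_tensor(rows, n_dim: int = N_DIM) -> np.ndarray:
    """Local analogue of the mapping: stack each row's emb_* values into a float32 batch."""
    cols = embedding_columns(n_dim)
    return np.array([[row[c] for c in cols] for row in rows], dtype=np.float32)

print(input_fields_map())  # features=emb_0:emb_767
```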

In [None]:
configure.byom_install_location = "mldb"

onnxpred_obj = ONNXPredict(
    newdata = DF_test_embeddings,
    modeldata = DataFrame(model_table_name),
    accumulate = ["row_id", "condition_label"],
    overwrite_cached_models = "*",
    model_input_fields_map = f"features=emb_0:emb_{number_dimensions_output-1}"
)

print(onnxpred_obj.show_query())

In [None]:
onnxpred_obj.result

<hr style="height:2px;border:none;">

<p style = 'font-size:20px;font-family:Arial;'><b>7. Deployment</b></p>
<p style = 'font-size:16px;font-family:Arial;'>The deployment scenario best suited to an MPP system like Vantage is batch processing, where new data is processed at regular intervals such as hourly, daily, weekly, or monthly. In this case, you could combine the SQL for ONNXEmbeddings and ONNXPredict into a single SQL query and deploy it as part of an ETL job. After processing, you can review the results, filter out uncertain predictions, and provide them to analysts for further review. With their feedback, you can retrain the models to improve performance.</p>
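<p style = 'font-size:16px;font-family:Arial;'>Filtering out uncertain predictions can be as simple as thresholding the probabilities returned by ONNXPredict. A minimal sketch, assuming the per-row probabilities have already been parsed into dicts; <code>flag_uncertain</code> is a hypothetical helper, and the thresholds <code>min_prob=0.6</code> and <code>min_margin=0.2</code> are arbitrary placeholders to be tuned with the analysts, not values from this post.</p>

```python
def flag_uncertain(predictions, min_prob: float = 0.6, min_margin: float = 0.2):
    """Split predictions into confident ones and ones to route to analysts.

    predictions: iterable of (row_id, {class_label: probability}) pairs.
    A prediction is uncertain if its top probability is below min_prob, or if
    the gap between the two most likely classes is below min_margin.
    """
    confident, review = [], []
    for row_id, probs in predictions:
        ranked = sorted(probs.values(), reverse=True)
        top = ranked[0]
        second = ranked[1] if len(ranked) > 1 else 0.0
        label = max(probs, key=probs.get)
        bucket = review if (top < min_prob or top - second < min_margin) else confident
        bucket.append((row_id, label, top))
    return confident, review
```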
<p style = 'font-size:16px;font-family:Arial;'>Another deployment scenario involves ad-hoc probability predictions, where medical personnel enter a text and want to see the probabilities in real time. For this, we can build an interactive widget that allows quick, on-the-spot inference, providing instant insights for medical diagnosis. This is what we do in the final section.</p>

<p style = 'font-size:20px;font-family:Arial;'><b>Widget</b></p>

<p  style = 'font-size:16px;font-family:Arial;'><b>Disclaimer</b>: This tool should not be used for actual diagnosis. It is solely for showcasing the capability of combining embeddings with labeled data and supervised machine learning.</p>
<p  style = 'font-size:16px;font-family:Arial;'>The widget computes its predictions with pure SQL, which makes it easy to integrate into a dashboard. Below is a preview of what the widget will look like.</p>

<hr style="height:1px;border:none;">


<img src=images/textclassificationwidget.gif style="border: 4px solid #404040; border-radius: 10px;"/>


<hr style="height:1px;border:none;">
<p style = 'font-size:16px;font-family:Arial;'>Below you will find the code. We go through it step by step:</p>
<p style = 'font-size:16px;font-family:Arial;'><code><b>get_query(text)</b></code></p>
<ul style = 'font-size:16px;font-family:Arial;'>
<li><b>Sanitize Input:</b> The function starts by sanitizing the user input <code>text</code> with a regular expression (<code>re.sub</code>) to reduce the risk of SQL injection. It deletes single quotes and backslashes and replaces double dashes and semicolons with spaces.</li>
<li><b>Generate SQL Query:</b> A SQL query is constructed using the sanitized input. The query runs the <code>ONNXEmbeddings</code> function to generate embeddings from the input text using a specified model. The result is stored in a CTE (Common Table Expression) called <code>embeddings_output</code>.</li>
<li><b>Run Prediction:</b> The <code>ONNXPredict</code> function is then applied to the embeddings output to generate predictions based on the trained model in the <code>medical_models</code> table. The prediction result is returned.</li>
</ul>
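<p style = 'font-size:16px;font-family:Arial;'>Deleting apostrophes also alters legitimate medical text such as "Crohn's disease". The sketch below contrasts the notebook's sanitizer with an alternative, assuming standard SQL string-literal rules, where a single quote inside a literal is written as two single quotes. Treat it as an untested alternative to validate on your system; bound query parameters, where the driver and the function call support them, are generally the more robust option. Both function names are hypothetical.</p>

```python
import re

def sanitize_like_notebook(text: str) -> str:
    """The approach used in get_query: delete or blank out risky characters."""
    return re.sub(r"(')|(--)|(;)|(\\)",
                  lambda m: {"'": "", "--": " ", ";": " ", "\\": ""}[m.group()],
                  text)

def quote_sql_literal(text: str) -> str:
    """Alternative: build a complete SQL string literal, doubling embedded single quotes."""
    return "'" + text.replace("'", "''") + "'"

print(sanitize_like_notebook("Crohn's disease"))  # Crohns disease
print(quote_sql_literal("Crohn's disease"))       # 'Crohn''s disease'
```

<p style = 'font-size:16px;font-family:Arial;'>With this variant, the input line of the query would read <code>CAST({quote_sql_literal(text)} AS VARCHAR(10000))</code> instead of wrapping <code>{text}</code> in quotes by hand.</p>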

In [None]:
def get_query(text):

    # to avoid sql injection, we sanitize the input
    text = re.sub(r"(')|(--)|(;)|(\\)", lambda match: {"'": "", "--": " ", ";": " ", "\\": ""}[match.group()], text)
    
    complete_query = f"""
    WITH embeddings_output AS (
    SELECT 
            *
    from mldb.ONNXEmbeddings(
            on (SELECT 1 as row_id, CAST('{text}' AS VARCHAR(10000)) as txt ) as InputTable
            on (select * from embeddings_models where model_id = '{model_name}') as ModelTable DIMENSION
            on (select model as tokenizer from embeddings_tokenizers where model_id = '{model_name}') as TokenizerTable DIMENSION
            using
                ModelOutputTensor('sentence_embedding')
                EnableMemoryCheck('false')
                Accumulate('row_id')
                OutputFormat('FLOAT32({number_dimensions_output})')
                OverwriteCachedModel('false')
        ) a )
    
    SELECT * FROM "mldb".ONNXPredict(
        ON embeddings_output AS InputTable
        PARTITION BY ANY 
        ON "medical_models" AS ModelTable
        DIMENSION
        USING
        Accumulate('row_id')
        OverwriteCachedModel('n')
        ModelInputFieldsMap('features=emb_0:emb_{number_dimensions_output-1}')
    ) as sqlmr
    
    """
    return complete_query
    

<p style = 'font-size:16px;font-family:Arial;'><code><b>get_predictions(query)</b></code></p>
<ul style = 'font-size:16px;font-family:Arial;'>
<li><b>Execute SQL Query: </b>This function takes the query generated by <code>get_query()</code> and executes it using <code>execute_sql(query)</code>. It fetches the single result row directly, avoiding the overhead of creating a DataFrame, which would also require collecting metadata.</li>
<li><b>Return Predictions: </b>The row is returned as-is; its second field contains a JSON string with the predicted probabilities for the different conditions.</li>
</ul>


In [None]:
def get_predictions(query):
    input_array = execute_sql(query).fetchall()[0]
    return input_array
    

<p style = 'font-size:16px;font-family:Arial;'><code><b>create_bar_chart(input_array)</b></code>:</p>
<ul style = 'font-size:16px;font-family:Arial;'>
<li><b>Parse JSON: </b>The JSON string containing the predicted probabilities is parsed using <code>json.loads()</code>.</li>    
<li><b>Extract and Sort Data: </b>The probabilities are extracted from the parsed data and sorted by the condition labels.</li>    
<li><b>Generate Bar Chart: </b>A bar chart is created using <code>plotly.graph_objects</code> to visually display the probabilities of different conditions. The X-axis shows the diseases, and the Y-axis shows the associated probabilities.</li>    
<li><b>Return Chart: </b>The bar chart figure is returned for display.</li>    
</ul>
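<p style = 'font-size:16px;font-family:Arial;'><code>create_bar_chart</code> sorts the classes by name, which gives a stable x-axis. A dashboard that lists the most likely condition first could sort by probability instead. A sketch of such a variant; only the <code>probabilities[0]["value"]</code> path is taken from the code above, while the helper name and the sample report in the test are made up.</p>

```python
import json

def ranked_probabilities(json_report: str) -> list:
    """Parse the ONNXPredict JSON output into (class, probability) pairs, most likely first.

    Assumes the layout read by create_bar_chart:
    {"probabilities": [{"value": {class: probability, ...}}], ...}
    """
    value_dict = json.loads(json_report)["probabilities"][0]["value"]
    return sorted(value_dict.items(), key=lambda item: item[1], reverse=True)
```

<p style = 'font-size:16px;font-family:Arial;'>In the widget this could be called as <code>ranked_probabilities(get_predictions(query)[1])</code>, mirroring how <code>create_bar_chart</code> reads the second field of the result row.</p>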

In [None]:
def create_bar_chart(input_array):
    # Parse the JSON string
    data = json.loads(input_array[1])
    
    # Extract the value sub dict
    value_dict = data['probabilities'][0]['value']
    sorted_value_dict = dict(sorted(value_dict.items(), key=lambda item: item[0]))
    
    # Create the bar chart
    fig = go.Figure(data=[go.Bar(x=list(sorted_value_dict.keys()), y=list(sorted_value_dict.values()))])
    
    # Set chart title and labels
    fig.update_layout(
        title="Probabilities of Different Conditions",
        xaxis_title="Disease",
        yaxis_title="Probability"
    )
    return fig

<p style = 'font-size:16px;font-family:Arial;'><code><b>create_probability_app()</b></code></p>
<ul style = 'font-size:16px;font-family:Arial;'>
<li><b>Create UI Elements: </b>The app UI is created using <code>ipywidgets</code>:</li>
<ul>
    <li>A header is displayed at the top.</li>
    <li>A text area is provided for inputting the medical abstract.</li>
    <li>A button is added that triggers the probability calculation.</li>
    <li>Tabs are created to display either the probability plot or the SQL query for transparency.</li>
</ul>
<li><b>Define Button Action: </b>When the "Get Probabilities" button is clicked, the text entered in the textarea is passed to the <code>get_query()</code> function to generate the SQL query, and then to <code>get_predictions()</code> to get the probabilities.</li>
<li><b>Display Results: </b>The SQL query and the bar chart with the predicted probabilities are displayed in the respective tabs. The SQL query is shown in a markdown format for transparency, and the bar chart visualizes the predicted probabilities.</li>
<li><b>Launch App: </b>Finally, the layout is displayed, showing the input area, button, and the two tabs for the results.</li>
</ul>

In [None]:
def create_probability_app():
    # Create the header
    header = widgets.HTML(value="<h1>Calculate Probabilities of Different Conditions Based on Medical Abstract</h1>")

    # Create the textarea for the medical abstract
    textarea = widgets.Textarea(
        value='',
        placeholder='Enter medical abstract here...',
        layout=widgets.Layout(width='95%', height='200px')
    )

    # Create the button to get probabilities
    button = widgets.Button(
        description='Get Probabilities',
        button_style='primary'
    )

    # Create the VBox for the left side
    left_vbox = widgets.VBox([textarea, button])

    # Create the tabs
    tab_contents = ['Probability Plot', 'Query']
    plot_output = widgets.Output()
    query_output = widgets.Output()
    tab_children = [plot_output, query_output]
    tabs = widgets.Tab()
    tabs.children = tab_children
    for i, title in enumerate(tab_contents):
        tabs.set_title(i, title)

    # Create the overall layout
    app_layout = widgets.VBox([
        header,
        widgets.AppLayout(
            header=None,
            left_sidebar=left_vbox,
            center=tabs,
            right_sidebar=None,
            footer=None
        )
    ])

    def on_get_probabilities(text_input):
        query_output.clear_output()
        plot_output.clear_output()
        if not text_input:
            return
        
        this_query = get_query(text_input)
        
        with query_output:
            query_md = Markdown(f"""```sql\n{this_query}\n```""")
            display(query_md)
        
        this_plot = create_bar_chart(get_predictions(this_query))
        
        with plot_output:
            display(this_plot)

    def button_clicked(b):
        on_get_probabilities(textarea.value)

    button.on_click(button_clicked)

    # Display the layout
    display(app_layout)

<p style = 'font-size:16px;font-family:Arial;'><b>Try it Out</b></p>
<p style = 'font-size:16px;font-family:Arial;'>You can try the application with the example text below. I asked Copilot to write this medical abstract based on five example abstracts with condition label 3, so let's see whether the model assigns it to that label:</p>
<p style = 'font-size:16px;font-family:Arial;'><code>Early initiation of physical therapy (PT) has been hypothesized to improve functional outcomes in stroke patients. This study aimed to evaluate the effects of early PT on recovery post-stroke. We conducted a retrospective analysis of 300 stroke patients who received PT within 48 hours of stroke onset (early PT group) and compared them to 300 patients who began PT after 48 hours (delayed PT group). Functional outcomes were assessed using the Modified Rankin Scale (mRS) at 3, 6, and 12 months post-stroke. The early PT group demonstrated significantly better mRS scores at all time points (p < 0.01). Additionally, the early PT group had a lower incidence of complications such as pneumonia and deep vein thrombosis (p < 0.05). These findings suggest that early initiation of PT is associated with improved functional outcomes and reduced complications in stroke patients. Further prospective studies are warranted to confirm these results and to explore the underlying mechanisms.</code></p>
<p style = 'font-size:16px;font-family:Arial;'>Also check the "Query" tab, which shows the exact SQL query executed behind the scenes for this input.</p>

In [None]:
# Call the function to create and display the app
create_probability_app()

<hr style="height:2px;border:none;">
<p style = 'font-size:20px;font-family:Arial;'><b>Conclusion</b></p>

<p style = 'font-size:16px;font-family:Arial;'>In this blog post, we've seen how manually created labels for texts can be a great starting point for a model that lets analysts focus only on the "difficult" cases, significantly improving their workflow. By generating embeddings with ONNXEmbeddings on a powerful MPP CPU system, you can achieve performance similar to a small GPU cluster, which makes embedding generation practical for large-scale analysis.</p>
<p style = 'font-size:16px;font-family:Arial;'>We also explored how t-SNE can visualize how close examples of the same or different classes are to each other, which gives valuable insight into the structure of the data. Whether you score an open-source supervised ML model via BYOM or use models that run natively in the database, this approach lets you plug the model of your choice into the pipeline.</p>
<p style = 'font-size:16px;font-family:Arial;'>Finally, by stacking ONNXEmbeddings and ONNXPredict together, you can serve both batch inference and ad-hoc queries, which gives you flexibility across deployment scenarios. With these tools, you can significantly streamline medical diagnosis workflows and improve the efficiency of your data processing pipeline.</p>
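<p style = 'font-size:16px;font-family:Arial;'>Letting analysts focus on the "difficult" cases can be made concrete with a confidence rule: accept the model's top label when it is clearly ahead, and route the abstract to a human otherwise. The sketch below is a hypothetical illustration, not part of this demo. The thresholds <code>min_prob</code> and <code>min_margin</code> are assumptions that would need tuning on held-out labeled data, and the input is the label-to-probability dictionary that <code>create_bar_chart()</code> plots.</p>

```python
from typing import Dict, Optional, Tuple


def triage(probs: Dict[str, float], min_prob: float = 0.7,
           min_margin: float = 0.2) -> Tuple[Optional[str], float]:
    """Return (label, top_prob) if the prediction is confident, else (None, top_prob).

    A prediction counts as confident when the top probability is at least
    min_prob AND it beats the runner-up by at least min_margin.
    Both default thresholds are illustrative assumptions.
    """
    ranked = sorted(probs.items(), key=lambda item: item[1], reverse=True)
    top_label, top_prob = ranked[0]
    runner_up = ranked[1][1] if len(ranked) > 1 else 0.0
    if top_prob >= min_prob and top_prob - runner_up >= min_margin:
        return top_label, top_prob
    return None, top_prob  # None means "send this abstract to an analyst"


print(triage({"1": 0.05, "2": 0.10, "3": 0.80, "4": 0.03, "5": 0.02}))  # ('3', 0.8)
print(triage({"1": 0.45, "2": 0.40, "3": 0.15}))                        # (None, 0.45)
```

Requiring both a minimum probability and a margin over the runner-up avoids auto-accepting cases where two conditions are almost equally likely.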
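<p style = 'font-size:16px;font-family:Arial;'>Serving both batch and ad-hoc requests from the same stacked query mostly comes down to swapping the input relation: a table of stored abstracts for batch scoring, or a one-row literal for a single abstract typed into the app. When ad-hoc text is pasted into the SQL string, any single quote in it (for example "patient's") must be escaped by doubling it, as standard SQL string literals require; if <code>get_query()</code> does not already do this, such input would break the query. The sketch below only builds that input subquery. The table name <code>medical_abstracts</code> and the columns <code>id</code> and <code>txt</code> are hypothetical, and the surrounding ONNXEmbeddings/ONNXPredict call is left as it is in <code>get_query()</code>. Where the driver and query shape allow it, a bind parameter is a sturdier alternative to string escaping.</p>

```python
from typing import Optional


def sql_string_literal(text: str) -> str:
    """Quote text as a SQL string literal, doubling embedded single quotes."""
    return "'" + text.replace("'", "''") + "'"


def input_subquery(text: Optional[str] = None,
                   table: str = "medical_abstracts") -> str:
    """Build the input relation for the stacked embedding + prediction query.

    - text given  -> one-row literal for an ad-hoc request
    - text absent -> all rows of a (hypothetical) table for batch scoring
    The column names id and txt are illustrative assumptions; table is
    assumed to be a trusted identifier, never user input.
    """
    if text is None:
        return f"SELECT id, txt FROM {table}"
    return f"SELECT 1 AS id, {sql_string_literal(text)} AS txt"


print(input_subquery())
# SELECT id, txt FROM medical_abstracts
print(input_subquery("The patient's mRS score improved."))
# SELECT 1 AS id, 'The patient''s mRS score improved.' AS txt
```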


<hr style="height:2px;border:none;">
<b style = 'font-size:20px;font-family:Arial;'>6. Cleanup</b>

In [None]:
%run utils/_dataremove.ipynb

In [None]:
remove_context()

<footer style="padding-bottom:35px; background:#f9f9f9; border-bottom:3px solid">
    <div style="float:left;margin-top:14px">ClearScape Analytics™</div>
    <div style="float:right;">
        <div style="float:left; margin-top:14px">
            Copyright © Teradata Corporation - 2025. All Rights Reserved
        </div>
    </div>
</footer>