# AI Explaining AI: The `explain` Module in `TabuLLM`

In previous tutorials, we introduced the `embed` and `cluster` modules of `TabuLLM`, which allowed us to apply a diverse collection of text embedding algorithms to text columns in our tabular data and to create clusters based on the embedding vectors. The resulting clusters can be included as a categorical feature in a predictive model, but they can also be used to interpret or explain the embeddings.

More specifically, text-generating LLMs can be used to provide descriptive labels for the clusters, which are themselves based on the output of an embedding LLM. We can refer to this as 'AI Explaining AI'. Furthermore, we can apply statistical tests such as Fisher's exact test (for binary outcomes) or ANOVA (for continuous outcomes) to determine which clusters have a significantly different distribution of outcome compared to the rest of the population. The combination of these two steps (labeling the clusters and associating them with outcome) provides a solid explainability path.

To support the above, the `explain` module of `TabuLLM` offers three functions:
1. `generate_prompt`, which assembles the full text of the prompt that solicits cluster labels from a text-completion LLM.
1. `generate_response`, a thin wrapper for submitting the prompt to a text-completion LLM.
1. `one_vs_rest`, a wrapper for testing the outcome within each cluster against the rest.

Below, we discuss each using the AKI dataset introduced in a previous tutorial. Before proceeding, let's load the AKI data, use a small LLM to embed the text column, and perform spherical K-means to split the data into 10 clusters:

In [1]:
import numpy as np
import pandas as pd
from TabuLLM.embed import TextColumnTransformer
from TabuLLM.cluster import SphericalKMeans
df = pd.read_csv('../data/raw.csv')
embeddings = TextColumnTransformer(
    type = 'st'
    , embedding_model_st = 'sentence-transformers/all-MiniLM-L6-v2'
).fit_transform(df.loc[:, ['diagnoses']])
n_clusters = 10
cluster_labels = SphericalKMeans(n_clusters=n_clusters).fit_predict(embeddings)
# Optional sanity check: every cluster index 0..n_clusters-1 is used
# assert np.array_equal(np.unique(cluster_labels), np.arange(n_clusters))



## Generating the Prompt

The prompt consists of two parts. The first is the preamble, which provides the data context and the request to the LLM. The second is the data, in which observations are grouped by cluster label; for each observation, we print the value of the text field that was used to generate the embeddings and, in turn, the clusters. There are two ways to generate the preamble: 1) provide phrases describing the text field and the observation unit, and let the function generate the preamble automatically, or 2) provide the preamble directly. Let's make this clearer by continuing with our running example:

In [2]:
from TabuLLM.explain import generate_prompt

# a helper function to avoid printing the entire prompt
def print_first_n_lines(text, n):
    lines = text.split('\n')
    for line in lines[:n]:
        print(line)

prompt = generate_prompt(
    text_list = list(df['diagnoses'])
    , cluster_labels = cluster_labels
    , prompt_observations = 'pediatric cardiopulmonary bypass surgeries'
    , prompt_texts = 'planned procedures'
)
print_first_n_lines(prompt, 20)

The following is a list of 830 pediatric cardiopulmonary bypass surgeries. Text lines represent planned procedures. Pediatric cardiopulmonary bypass surgeries have been grouped into 10 groups, according to their planned procedures. Please suggest group labels that are representative of their members, and also distinct from each other:

=====

Group 1:

155516. Cardiac conduit failure;010501. Discordant VA connections (TGA);091026. Left pulmonary arterial stenosis;070530. Subpulmonary stenosis
155516. Cardiac conduit failure;111100. Pacemaker dysfunction / complication necessitating replacement;010117. Double outlet right ventricle with subaortic or doubly committed ventricular septal defect and pulmonary stenosis, Fallot type;070501. RV outflow tract obstruction;070901. LV outflow tract obstruction;110610. Acquired complete AV block
070501. RV outflow tract obstruction;010117. Double outlet right ventricle with subaortic or doubly committed ventricular septal defect and pulmonary steno
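
To make the two-part structure concrete, here is a minimal sketch of how a prompt with this layout could be assembled. It is not the implementation of `generate_prompt`: the helper name `build_prompt_sketch`, the toy texts, and the choice to number groups from 1 in sorted label order are all assumptions made for illustration.

```python
from collections import defaultdict

def build_prompt_sketch(texts, labels, preamble):
    """Hypothetical helper: preamble first, then texts grouped by cluster label."""
    groups = defaultdict(list)
    for text, label in zip(texts, labels):
        groups[label].append(text)

    parts = [preamble.strip(), "", "====="]
    # Number groups from 1 in sorted label order (an assumption about the layout)
    for group_number, label in enumerate(sorted(groups), start=1):
        parts += ["", f"Group {group_number}:", ""]
        parts += groups[label]
    return "\n".join(parts)

# Toy inputs, not taken from the AKI data
toy_prompt = build_prompt_sketch(
    texts=["VSD repair", "ASD closure", "VSD patch"],
    labels=[1, 0, 1],
    preamble="The following is a list of 3 surgeries, grouped into 2 groups.",
)
print(toy_prompt)
```

Because every text line is copied into the prompt, its length grows roughly linearly with the number of observations.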

We can now examine the prompt preamble, edit it as needed, and regenerate the full prompt by supplying our modified preamble:

In [3]:
preamble = '''
The following is a list of 830 pediatric cardiopulmonary bypass (CPB) surgeries. Text lines represent procedures performed on each patient. 
These CPB surgeries have been grouped into 10 groups, according to their planned procedures. 
Please suggest group labels that are representative of their members, and also distinct from each other:
'''
prompt2 = generate_prompt(
    text_list = list(df['diagnoses'])
    , cluster_labels = cluster_labels
    , preamble = preamble
)
print_first_n_lines(prompt2, 20)


The following is a list of 830 pediatric cardiopulmonary bypass (CPB) surgeries. Text lines represent procedures performed on each patient. 
These CPB surgeries have been grouped into 10 groups, according to their planned procedures. 
Please suggest group labels that are representative of their members, and also distinct from each other:


=====

Group 1:

155516. Cardiac conduit failure;010501. Discordant VA connections (TGA);091026. Left pulmonary arterial stenosis;070530. Subpulmonary stenosis
155516. Cardiac conduit failure;111100. Pacemaker dysfunction / complication necessitating replacement;010117. Double outlet right ventricle with subaortic or doubly committed ventricular septal defect and pulmonary stenosis, Fallot type;070501. RV outflow tract obstruction;070901. LV outflow tract obstruction;110610. Acquired complete AV block
070501. RV outflow tract obstruction;010117. Double outlet right ventricle with subaortic or doubly committed ventricular septal defect and pulmonary 

## Submitting the Prompt to LLM

### Context Window Limits

Next, we will submit our prompt to an LLM to generate cluster labels. Before doing so, it's important to make sure the total size of the prompt does not exceed the specifications of our target LLM. Here are the links to model specifications for OpenAI and Google:
- Google: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-1.5-pro
- OpenAI: https://platform.openai.com/docs/models

We see that OpenAI specifies its *context window* in tokens, while Google's *maximum input tokens* is defined in characters. Let's count the number of characters and tokens in our prompt. The former is straightforward to calculate. The exact value of the latter depends on the tokenizer used, but an approximate estimate is sufficient for our purposes.

In [4]:
import tiktoken

n_characters = len(prompt2)
encoder = tiktoken.encoding_for_model('gpt-4-turbo')
n_tokens = len(encoder.encode(prompt2))
print(f'Number of characters: {n_characters}, number of tokens: {n_tokens}')

Number of characters: 89697, number of tokens: 25217


If we compare the above numbers against OpenAI's limits, we see that while the maximum context length of 8192 tokens for the `gpt-4` family is insufficient to handle our prompt, the newer generation of OpenAI models, including `gpt-4-turbo` and `gpt-4o`, has an adequate context length of 128k tokens. Likewise, upon examining the Gemini text-completion models from Google, we note that *Gemini 1.0 Pro* and newer models are capable of handling our prompt. In particular, the *Gemini 1.5 Pro* model has an impressive *maximum input tokens* parameter of more than 2 million!
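
The comparison above can be written as a small check. This is a sketch only: the helper `fits_context` and the 1,000-token budget reserved for the model's answer are assumptions, and the context windows are the values quoted in the text, which should be verified against the providers' current documentation.

```python
def fits_context(n_prompt_tokens, context_window, reserved_for_output=1000):
    """Return True if the prompt plus a reserved output budget fits the window."""
    return n_prompt_tokens + reserved_for_output <= context_window

n_tokens = 25217  # token count reported above for prompt2

# Context windows as quoted in the text (tokens)
context_windows = {"gpt-4": 8192, "gpt-4-turbo": 128_000, "gpt-4o": 128_000}

for model, window in context_windows.items():
    status = "fits" if fits_context(n_tokens, window) else "too long"
    print(f"{model}: {status}")
```

Reserving some room for the response matters because the context window typically covers both the prompt and the generated text.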

The broader point is that this prompt is likely to be lengthy in most applications, since it includes the entirety of the text data itself. With the rapid advances in LLMs and the growth of their context windows, however, increasingly large datasets can be handled. At the same time, modern LLMs capable of handling very long prompts are likely to be quite large, and thus would exceed the RAM and processing power of most users' local machines.

For this reason, we have currently limited the `explain` module of `TabuLLM` to commercial LLMs from OpenAI and Google. A potential feature on our roadmap is to include an option for *two-stage explanation* to circumvent the limits of some LLMs, especially the open-source ones.

### Cost Considerations

As with submitting embedding tasks to commercial LLMs, here we must also be aware of the costs. The following links contain pricing information from OpenAI and Google for their text-completion models:
- Google: https://cloud.google.com/vertex-ai/generative-ai/pricing
- OpenAI: https://openai.com/api/pricing/

For instance, if we use the *Gemini 1.5 Flash* model from Google, we would pay \$0.00001875 per 1,000 characters. For the above prompt, this would amount to 89.697 × \$0.00001875 ≈ \$0.0017, or about 0.17 cents, which is negligible. The more advanced *Gemini 1.5 Pro* model costs roughly two orders of magnitude more (\$0.00125 per 1,000 characters), or about 11 cents for our prompt.

Similarly, for OpenAI's *gpt-4o* model, the price is \$5.0 per 1 million tokens, which amounts to about 13 cents for our 25,217-token prompt. Unsurprisingly, OpenAI and Google charge comparable prices for their most advanced models.
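
The arithmetic above is easy to reproduce. The sketch below uses the character and token counts from our prompt and the prices quoted in the text, which are a snapshot and may have changed. The dictionary keys are informal labels, not necessarily valid API model identifiers, and only the input side is counted, under the assumption that the (much shorter) generated response is billed separately.

```python
n_characters = 89697  # character count of prompt2, from the output above
n_tokens = 25217      # token count of prompt2, from the output above

# Input prices quoted in the text, in USD
usd_per_1k_chars = {"gemini-1.5-flash": 0.00001875, "gemini-1.5-pro": 0.00125}
usd_per_1m_tokens = {"gpt-4o": 5.0}

# Character-priced models, then token-priced models
costs = {m: n_characters / 1_000 * p for m, p in usd_per_1k_chars.items()}
costs.update({m: n_tokens / 1_000_000 * p for m, p in usd_per_1m_tokens.items()})

for model, usd in costs.items():
    print(f"{model}: ${usd:.4f} ({usd * 100:.2f} cents)")
```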

### AI Explaining AI

Having discussed the context window limits and cost aspects, let's finally proceed with submitting our prompt and examining the results.

In [5]:
from TabuLLM.explain import generate_response

import os
from dotenv import load_dotenv
load_dotenv()
google_project_id = os.getenv('VERTEXAI_PROJECT')
google_location = os.getenv('VERTEXAI_LOCATION')

response = generate_response(
    prompt2
    , type = "google"
    , google_project_id = google_project_id
    , google_location = google_location
    , google_model = 'gemini-1.5-flash-001'
)
print(response)

Here are some suggested group labels, aiming for both representativeness and distinctiveness:

**Group 1:  Double Outlet Right Ventricle & Complex Congenital Heart Defects**

* **Reasoning:** This group features a high prevalence of Double Outlet Right Ventricle (DORV) variants and other complex congenital heart defects.  It includes various types of DORV, as well as associated issues like pulmonary stenosis, subaortic stenosis, and TGA. 

**Group 2: Cardiomyopathies & Heart Failure**

* **Reasoning:**  This group primarily consists of patients with various cardiomyopathies, including dilated, hypertrophic, and restrictive forms.  There's a clear focus on cardiomyopathy-related heart failure.

**Group 3: Valvular Abnormalities & Congenital Heart Disease**

* **Reasoning:**  The focus is on valvular problems, both congenital (e.g., bicuspid aortic valve) and acquired (e.g., mitral regurgitation).  There's a significant presence of congenital heart disease like Tetralogy of Fallot and ot

While validating the medical sensibility of the above labels requires expert opinion, we can see that at least we have obtained a coherent answer from the LLM.

We can also write a small utility function to extract the group names from the LLM response:


In [10]:
import re

def extract_group_names(text):
    # Match Markdown-bold headers such as "**Group 3: Some Label**" and
    # capture the label, ignoring any whitespace around it
    pattern = re.compile(r"\*\*Group \d+:\s*(.*?)\s*\*\*")
    return pattern.findall(text)

cluster_names = extract_group_names(response)
print(cluster_names)

['Double Outlet Right Ventricle & Complex Congenital Heart Defects', 'Cardiomyopathies & Heart Failure', 'Valvular Abnormalities & Congenital Heart Disease', 'Atrioventricular Septal Defects', 'Transposition of the Great Arteries', 'Hypoplastic Left Heart Syndrome & Complex Congenital Defects', 'Ventricular Septal Defects', 'Atrial Septal Defects', 'Tetralogy of Fallot & Pulmonary Atresia Variants', 'Pulmonary Vascular Anomalies & Congenital Heart Disease']
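
Since the format of an LLM response can vary from run to run, it may be worth verifying that exactly one label was found per group before using the names downstream. The sketch below is one way to do this; the helper names `parse_group_labels` and `check_group_labels` are hypothetical, and the pattern assumes the Markdown-bold `**Group N: label**` format seen in the response above.

```python
import re

def parse_group_labels(text):
    """Return {group_number: label} for headers like '**Group 3: Some Label**'."""
    pattern = re.compile(r"\*\*Group (\d+):\s*(.*?)\s*\*\*")
    return {int(num): label for num, label in pattern.findall(text)}

def check_group_labels(labels, n_clusters):
    """Raise if the parsed group numbers are not exactly 1..n_clusters."""
    expected = set(range(1, n_clusters + 1))
    if set(labels) != expected:
        missing = sorted(expected - set(labels))
        extra = sorted(set(labels) - expected)
        raise ValueError(f"missing groups: {missing}, unexpected groups: {extra}")

# Toy response text in the same format as the LLM output above
sample = "**Group 1:  Label A**\n\n* **Reasoning:** ...\n\n**Group 2: Label B**"
labels = parse_group_labels(sample)
check_group_labels(labels, n_clusters=2)
print(labels)
```

If the check fails, re-submitting the prompt or adjusting the pattern to the new response format are the obvious remedies.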


## Association of Clusters with Outcome

Now that we have descriptive labels for the clusters based on the embedding vectors, it is useful to see which clusters differ significantly from the rest of the data in terms of prevalence of outcome. In our case, the outcome is a binary variable that reflects the severity of postoperative acute kidney injury. For such binary classification problems, we can use Fisher's exact test to compare the odds of severe AKI in each cluster against the remaining clusters. This is done readily using the `one_vs_rest` function:

In [14]:
from TabuLLM.explain import one_vs_rest
fisher = one_vs_rest(
    pd.DataFrame({
        'cluster': cluster_labels
        , 'outcome': df['aki_severity']
    })
)

We can now combine the above with the cluster labels into a single dataframe:

In [16]:
fisher['Group Name'] = cluster_names
fisher = fisher[['Group Name', 'Statistic', 'P-value']]
fisher

| | Group Name | Statistic | P-value |
|---|---|---|---|
| 0 | Double Outlet Right Ventricle & Complex Conge... | 0.56851 | 0.103071 |
| 1 | Cardiomyopathies & Heart Failure | 6.46875 | 2.687169e-08 |
| 2 | Valvular Abnormalities & Congenital Heart Disease | 1.053571 | 0.8169446 |
| 3 | Atrioventricular Septal Defects | 1.009901 | 1.0 |
| 4 | Transposition of the Great Arteries | 0.371843 | 0.1104734 |
| 5 | Hypoplastic Left Heart Syndrome & Complex Cong... | 1.260627 | 0.2818915 |
| 6 | Ventricular Septal Defects | 1.197822 | 0.3735867 |
| 7 | Atrial Septal Defects | 0.153629 | 6.89021e-06 |
| 8 | Tetralogy of Fallot & Pulmonary Atresia Variants | 1.252945 | 0.3799128 |
| 9 | Pulmonary Vascular Anomalies & Congenital Hear... | 0.494785 | 0.03383647 |
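
For readers curious about what a one-vs-rest Fisher test computes, here is a dependency-free sketch on toy data. It is not the implementation of `one_vs_rest`: the function names are hypothetical, and it assumes that the `Statistic` column is the sample odds ratio of a 2x2 table (cluster vs. rest, outcome 1 vs. 0) and that the p-value is two-sided. In practice, `scipy.stats.fisher_exact` is the usual tool for the 2x2 test.

```python
from math import comb

def fisher_exact_two_sided(a, b, c, d):
    """Odds ratio and two-sided Fisher p-value for the table [[a, b], [c, d]]."""
    row1, row2, col1 = a + b, c + d, a + c
    total = comb(row1 + row2, col1)

    def prob(x):
        # Hypergeometric probability of x positives in the first row
        return comb(row1, x) * comb(row2, col1 - x) / total

    p_observed = prob(a)
    lo, hi = max(0, col1 - row2), min(row1, col1)
    probs = [prob(x) for x in range(lo, hi + 1)]
    # Sum over all tables at least as unlikely as the observed one;
    # the small tolerance guards against floating-point ties
    p_value = sum(p for p in probs if p <= p_observed * (1 + 1e-7))
    odds_ratio = (a * d) / (b * c) if b * c else float("inf")
    return odds_ratio, min(p_value, 1.0)

def one_vs_rest_sketch(clusters, outcomes):
    """For each cluster, test its binary outcome against all other rows."""
    pairs = list(zip(clusters, outcomes))
    results = {}
    for k in sorted(set(clusters)):
        a = sum(1 for c, y in pairs if c == k and y == 1)  # in cluster, positive
        b = sum(1 for c, y in pairs if c == k and y == 0)  # in cluster, negative
        c_rest = sum(1 for c, y in pairs if c != k and y == 1)
        d_rest = sum(1 for c, y in pairs if c != k and y == 0)
        results[k] = fisher_exact_two_sided(a, b, c_rest, d_rest)
    return results

# Toy data: cluster 0 has 7 of 8 positives, cluster 1 has 1 of 8
toy_clusters = [0] * 8 + [1] * 8
toy_outcomes = [1] * 7 + [0] + [1] + [0] * 7
for k, (odds_ratio, p_value) in one_vs_rest_sketch(toy_clusters, toy_outcomes).items():
    print(f"cluster {k}: odds ratio = {odds_ratio:.3f}, p-value = {p_value:.4f}")
```

An odds ratio above 1 means the outcome is more common in the cluster than in the rest of the data, and a value below 1 means it is less common. Note that running one test per cluster raises the usual multiple-comparison concern, so borderline p-values deserve caution.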


To summarize the explanation, we have concluded that:
1. Pediatric CPBs to repair 'Cardiomyopathies & Heart Failure' have ~6.5x the odds of being followed by severe AKI, compared to the rest of the operations.
1. On the other hand, operations to repair 'Atrial Septal Defects' have ~6.5x lower odds of leading to severe AKI (odds ratio of about 0.15), compared to the rest of the operations.

Setting aside the accuracy and medical plausibility of the above explanations, they help practitioners understand what the embeddings are doing and allow them to decide how much to trust the output.