# AI Explaining AI: The `Explain` Module in `TabuLLM`

In previous tutorials, we introduced the `embed` and `cluster` modules of `TabuLLM`, which let us apply a diverse collection of text embedding algorithms to the text columns in our tabular data and create clusters from the resulting embedding vectors. These clusters can be included as a categorical feature in a predictive model, but they can also be used to interpret or explain the embeddings.

More specifically, text-generating LLMs can be used to provide descriptive labels for the clusters, which are themselves based on the output of an embedding LLM. We refer to this as 'AI Explaining AI'. Furthermore, we can apply statistical tests, such as Fisher's exact test (for a binary outcome) or ANOVA (for a continuous outcome), to determine which clusters have a significantly different outcome distribution compared to the rest of the population. Together, these two steps (labeling the clusters and associating them with the outcome) provide a solid path to explainability.
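
To make the statistical step concrete, here is a minimal sketch of a one-vs-rest Fisher's exact test for a binary outcome, written directly against `scipy.stats.fisher_exact`. It is not `TabuLLM`'s implementation: the helper name `fisher_one_vs_rest` and the toy data are hypothetical and serve only to illustrate the idea.

```python
import numpy as np
from scipy.stats import fisher_exact

def fisher_one_vs_rest(cluster_labels, outcome):
    """For each cluster, test a binary outcome inside the cluster against all other rows."""
    cluster_labels = np.asarray(cluster_labels)
    outcome = np.asarray(outcome).astype(bool)
    results = {}
    for k in np.unique(cluster_labels):
        in_cluster = cluster_labels == k
        # 2x2 contingency table: rows = in cluster / rest, columns = event / no event
        table = [
            [int(np.sum(in_cluster & outcome)), int(np.sum(in_cluster & ~outcome))],
            [int(np.sum(~in_cluster & outcome)), int(np.sum(~in_cluster & ~outcome))],
        ]
        odds_ratio, p_value = fisher_exact(table)
        results[int(k)] = {'odds_ratio': odds_ratio, 'p_value': p_value}
    return results

# Toy data, made up for illustration: cluster 0 has a much higher event rate
toy_labels = np.array([0] * 10 + [1] * 10 + [2] * 10)
toy_outcome = np.array([1] * 9 + [0] * 1 + [0] * 10 + [1] * 1 + [0] * 9)
fisher_results = fisher_one_vs_rest(toy_labels, toy_outcome)
for k, res in fisher_results.items():
    print(k, res)
```

Because one test is run per cluster, the resulting p-values would normally call for a multiple-comparison adjustment before any cluster is declared significant.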

To support the above, the `explain` module of `TabuLLM` offers three functions:
1. `generate_prompt`, which assembles the full text of the prompt that solicits cluster labels from a text-completion LLM.
1. `submit_prompt`, a thin wrapper around various commercial and open-source LLMs.
1. `one_vs_test`, a wrapper for testing the mean outcome within each cluster against the rest.

Below, we discuss each using the AKI dataset introduced in a previous tutorial. Before proceeding, let's load the AKI data, use a small LLM to embed the text column, and perform spherical K-means to split the data into 10 clusters:

In [12]:
import numpy as np
import pandas as pd
from TabuLLM.embed import TextColumnTransformer
from TabuLLM.cluster import SphericalKMeans
df = pd.read_csv('../data/raw.csv')
embeddings = TextColumnTransformer(
    type = 'st'
    , embedding_model_st = 'sentence-transformers/all-MiniLM-L6-v2'
).fit_transform(df.loc[:, ['diagnoses']])
n_clusters = 10
cluster_labels = SphericalKMeans(n_clusters=n_clusters).fit_predict(embeddings)
#assert np.array_equal(np.unique(cluster_labels), np.arange(0, n_clusters + 0))
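
Before building the prompt, it can be worth confirming that every cluster actually received some observations, which is what the commented-out assertion above hints at. The following is a small sketch in plain Python; the helper `cluster_size_report` is hypothetical, and it assumes that labels are integers in the range `0` to `n_clusters - 1`.

```python
from collections import Counter

def cluster_size_report(labels, n_clusters):
    """Return (sizes, empty): rows per cluster id, and the ids that received no rows."""
    counts = Counter(int(label) for label in labels)
    sizes = {k: counts.get(k, 0) for k in range(n_clusters)}
    empty = [k for k, size in sizes.items() if size == 0]
    return sizes, empty

# Toy labels for illustration; in the tutorial these would be `cluster_labels`
toy_labels = [0, 2, 2, 0, 3, 2]
sizes, empty = cluster_size_report(toy_labels, n_clusters=4)
print(sizes)
print(empty)
```

Very small or empty clusters tend to yield labels that are hard to interpret, so a quick look at the sizes can inform the choice of `n_clusters`.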



## Generating the Prompt

The prompt consists of two parts. The first is the preamble, which provides the data context and states the request to the LLM. The second is the data: observations are grouped by their cluster labels, and for each observation, the prompt prints the value of the text field that was embedded and then clustered. There are two ways to generate the preamble: 1) provide phrases that describe the text field and the observation unit, and let the function generate the preamble automatically, or 2) provide the preamble directly. Let's make this clearer by continuing with our running example:

In [14]:
from TabuLLM.explain import generate_prompt

prompt = generate_prompt(
    text_list = list(df['diagnoses'])
    , cluster_labels = cluster_labels
    , prompt_observations = 'pediatric cardiopulmonary bypass surgeries'
    , prompt_texts = 'planned procedures'
)
print(prompt)

The following is a list of 830 pediatric cardiopulmonary bypass surgeries. Text lines represent planned procedures. Pediatric cardiopulmonary bypass surgeries have been grouped into 10 groups, according to their planned procedures. Please suggest group labels that are representative of their members, and also distinct from each other:

=====

Group 1:

101025. Dilated cardiomyopathy;110550. Non-sustained ventricular tachycardia
101025. Dilated cardiomyopathy
100611. Infective endocarditis of aortic valve;100611. Infective endocarditis of aortic valve
101025. Dilated cardiomyopathy;060291. Mitral regurgitation;102303. Family history of disorder with cardiac involvement
101025. Dilated cardiomyopathy;091591. Aortic regurgitation
101025. Dilated cardiomyopathy;110407. AV junctional (nodal) tachycardia
101025. Dilated cardiomyopathy
010109. Hypoplastic left heart syndrome
101020. Hypertrophic cardiomyopathy
159500. Complication after heart or lung transplant;101025. Dilated cardiomyopathy
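
To illustrate how a prompt of this shape can be put together, here is a minimal sketch that groups texts by cluster label and appends them to a preamble. It is not the actual code behind `generate_prompt`: the function `build_cluster_prompt` and its parameters are hypothetical, and the layout is inferred from the output above. In particular, the truncated output does not show whether the real function repeats the `=====` separator between groups, so that detail is a guess.

```python
def build_cluster_prompt(texts, labels, preamble, separator='====='):
    """Group texts by cluster label and append them, one group at a time, to a preamble."""
    if len(texts) != len(labels):
        raise ValueError('texts and labels must have the same length')
    groups = {}
    for text, label in zip(texts, labels):
        groups.setdefault(label, []).append(text)
    parts = [preamble.strip()]
    # Number groups from 1 in sorted label order, mirroring the 'Group 1:' header
    for group_number, label in enumerate(sorted(groups), start=1):
        body = '\n'.join(groups[label])
        parts.append(f'{separator}\n\nGroup {group_number}:\n\n{body}')
    return '\n\n'.join(parts)

# Toy inputs, made up for illustration
toy_prompt = build_cluster_prompt(
    texts=['Dilated cardiomyopathy', 'Mitral regurgitation', 'Hypertrophic cardiomyopathy'],
    labels=[0, 1, 0],
    preamble='Please suggest group labels:',
)
print(toy_prompt)
```

Keeping the preamble separate from the grouped data is what makes it possible to edit the wording of the request without touching the data section.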


We can now examine the prompt preamble, edit it as needed, and regenerate the full prompt by supplying our modified preamble:

In [16]:
preamble = '''
The following is a list of 830 pediatric cardiopulmonary bypass (CPB) surgeries. Text lines represent procedures performed on each patient. 
These CPB surgeries have been grouped into 10 groups, according to their planned procedures. 
Please suggest group labels that are representative of their members, and also distinct from each other:
'''
prompt2 = generate_prompt(
    text_list = list(df['diagnoses'])
    , cluster_labels = cluster_labels
    , preamble = preamble
)
print(prompt2)


The following is a list of 830 pediatric cardiopulmonary bypass (CPB) surgeries. Text lines represent procedures performed on each patient. 
These CPB surgeries have been grouped into 10 groups, according to their planned procedures. 
Please suggest group labels that are representative of their members, and also distinct from each other:


=====

Group 1:

101025. Dilated cardiomyopathy;110550. Non-sustained ventricular tachycardia
101025. Dilated cardiomyopathy
100611. Infective endocarditis of aortic valve;100611. Infective endocarditis of aortic valve
101025. Dilated cardiomyopathy;060291. Mitral regurgitation;102303. Family history of disorder with cardiac involvement
101025. Dilated cardiomyopathy;091591. Aortic regurgitation
101025. Dilated cardiomyopathy;110407. AV junctional (nodal) tachycardia
101025. Dilated cardiomyopathy
010109. Hypoplastic left heart syndrome
101020. Hypertrophic cardiomyopathy
159500. Complication after heart or lung transplant;101025. Dilated cardiomyop

## Submitting the Prompt to an LLM

Next, we will submit our prompt to an LLM to generate cluster labels. Before doing so, it's important to make sure the total size of the prompt does not exceed the specifications of our target LLM. Here are the links to model specifications for OpenAI and Google:
- Google: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-1.5-pro
- OpenAI: https://platform.openai.com/docs/models

We see that OpenAI specifies its *context window* in tokens, while Google's *maximum input tokens* is defined in characters. Let's count both the characters and the tokens in our prompt. The character count is straightforward to calculate, whereas the exact token count depends on the tokenizer used. For our purposes, a rough estimate is sufficient.

In [20]:
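import tiktoken

# Character count of the full prompt
n_characters = len(prompt2)
# Token count under the tokenizer that tiktoken associates with 'gpt-4-turbo'
encoder = tiktoken.encoding_for_model('gpt-4-turbo')
n_tokens = len(encoder.encode(prompt2))
print(f'Number of characters: {n_characters}, number of tokens: {n_tokens}')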

Number of characters: 89697, number of tokens: 25217
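
With both counts in hand, comparing them against a model's limits is simple arithmetic. The sketch below is one way to do it. The helper `prompt_budget` is hypothetical, and the two limits passed to it are assumed placeholders that should be replaced with the values published on the model pages linked above.

```python
def prompt_budget(n_characters, n_tokens, context_window, reserved_for_output):
    """Summarize how much of a model's context window a prompt would use."""
    available = context_window - reserved_for_output
    return {
        'chars_per_token': n_characters / n_tokens,
        'tokens_available': available,
        'fraction_used': n_tokens / available,
        'fits': n_tokens <= available,
    }

# Counts taken from the output above; the two limits are assumed placeholders
budget = prompt_budget(
    n_characters=89697,
    n_tokens=25217,
    context_window=128_000,     # assumed limit; confirm on the model page
    reserved_for_output=4_096,  # assumed room for the generated labels
)
print(budget)
```

Reserving part of the window for the response matters because the context window covers the prompt and the generated text together. The characters-per-token ratio also gives a quick way to estimate token counts for models whose tokenizer is not readily available.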
