In [1]:
import pandas as pd
# Load the raw dataset from its relative path.
df = pd.read_csv('../../data/raw.csv')

# Show one example value from each of the two free-text columns.
# `.loc[row, column]` is an explicit label-based lookup on the default integer
# index created by `read_csv`; it replaces the chained `df[column][row]` form,
# whose integer key is ambiguous between label and position.
EXAMPLE_ROW = 2
print(f"Example of diagnoses:\n{df.loc[EXAMPLE_ROW, 'diagnoses']}\n\n")
print(f"Example of operations:\n{df.loc[EXAMPLE_ROW, 'operations']}\n")

Example of diagnoses:
155516. Cardiac conduit failure;090101. Common arterial trunk;110021. Cardiac arrest


Example of operations:
123610. Replacement of cardiac conduit;123452. Pacemaker system placement: biventricular



In [2]:
import os
from openai import OpenAI
from dotenv import load_dotenv
# Populate the process environment via python-dotenv so the API key does not
# have to be written in the notebook.
load_dotenv()

# Look the key up in the environment and pass it to the OpenAI client.
openai_api_key = os.environ.get('OPENAI_API_KEY')
client = OpenAI(api_key=openai_api_key)

In [3]:
from TabuLLM.embed import TextColumnTransformer
# Text-column embedder configured for OpenAI's `text-embedding-3-small` model,
# using the `client` created earlier in the notebook.
obj = TextColumnTransformer(
    model_type='openai',
    openai_args={'client': client, 'model': 'text-embedding-3-small'},
)
# The embedding call itself is left disabled in this cell. To try it on
# rows 0-5 of the `diagnoses` column, uncomment:
# X = obj.fit_transform(df.loc[:5, ['diagnoses']])
# print(X.shape)

  from tqdm.autonotebook import tqdm, trange


In [4]:
# Google settings come from the environment (either value is None if unset).
google_project_id = os.environ.get('VERTEXAI_PROJECT')
google_location = os.environ.get('VERTEXAI_LOCATION')

# Echo both values so a missing or wrong setting is visible before it is used.
print(f"Google project id: {google_project_id}, location: {google_location}")

Google project id: moonlit-helper-426810-a2, location: us-central1


In [5]:
# Text-column embedder configured for Google's `text-embedding-004` model,
# using the project id and location read from the environment.
obj = TextColumnTransformer(
    model_type='google',
    google_args={
        'project_id': google_project_id,
        'location': google_location,
        'model': 'text-embedding-004',
        'task': 'SEMANTIC_SIMILARITY',
        'batch_size': 250,
    },
)

# Smoke test on rows 0-5 only (`.loc` slices include the end label, so 6 rows).
X = obj.fit_transform(df.loc[:5, ['diagnoses']])
print(X.shape)

(6, 768)


In [6]:
# Text-column embedder backed by the sentence-transformers model
# `all-MiniLM-L6-v2`.
obj = TextColumnTransformer(
    model_type='st',
    st_args={'model': 'sentence-transformers/all-MiniLM-L6-v2'},
)

# Embed the `diagnoses` column for every row; this `X` (and this `obj`) are the
# ones the clustering and prompt-generation cells further down operate on.
X = obj.fit_transform(df[['diagnoses']])
print(X.shape)



(830, 384)


In [7]:
from TabuLLM.cluster import SphericalKMeans
# Fit a 10-cluster SphericalKMeans model on the embedding matrix `X`, then
# print the cluster ids assigned to the first five rows as a quick check.
# NOTE(review): no seed is passed here, so cluster ids may differ between runs,
# while later cells pair cluster ids with descriptions loaded from a saved CSV.
# Confirm whether SphericalKMeans accepts a seed argument and pin it if so.
cluster = SphericalKMeans(n_clusters=10, n_init=5)
cluster.fit(X)
print(cluster.predict(X[:5]))

[6 9 5 8 5]


In [8]:
# Transform the embeddings with the fitted cluster model and check the shape.
# Presumably one distance per (row, cluster) pair, given the variable name and
# the 10 configured clusters - confirm against the SphericalKMeans docs.
distances = cluster.transform(X)
print(distances.shape)

(830, 10)


In [10]:
from TabuLLM.explain import generate_prompt
# Build the LLM request in two parts: `prompt` (the instructions) and `payload`
# (the texts listed group by group). Rows are grouped by the cluster id
# predicted for their embedding.
# NOTE(review): `obj.prep_X` is presumably the same text preparation the
# transformer applies before embedding - confirm.
prompt, payload = generate_prompt(
    text_list=obj.prep_X(df[['diagnoses']]),
    cluster_labels=cluster.predict(X),
    prompt_observations='CPB procedures',
    prompt_texts='diagnoses',
)

In [12]:
# Show the instruction text that will be sent to the LLM.
print(prompt)

The following is a list of 830 CPB procedures. Text lines represent diagnoses. Cpb procedures have been grouped into 10 groups, according to their diagnoses. Please suggest group labels that are representative of their members, and also distinct from each other. Follow the provided template to return - for each group - the group number, a short desciption / group label, and a long description.


In [19]:
# Preview only the first five lines of the payload; the full text lists every
# row of every group and is too long to display.
print(*payload.splitlines()[:5], sep='\n')

Group 1:

diagnoses: 091600. Supravalvar aortic stenosis;070901. LV outflow tract obstruction
diagnoses: 155516. Cardiac conduit failure;091501. Aortic valvar stenosis - congenital;091591. Aortic regurgitation
diagnoses: 091591. Aortic regurgitation; 070901. LV outflow tract obstruction; 091501. Aortic valvar stenosis - congenital; 010501. Discordant VA connections (TGA); 040100. Superior caval vein (SVC) abnormality


In [20]:
from TabuLLM.explain import generate_response
# The LLM call is disabled by default, so re-running the notebook does not hit
# the API; the following cell loads saved explanations from CSV instead.
# Set the flag to True to request fresh explanations from the model.
GENERATE_EXPLANATIONS = False

# Keep the return value so it can be inspected or saved when the call is
# enabled (it was previously discarded). It stays None while disabled.
llm_response = None
if GENERATE_EXPLANATIONS:
    llm_response = generate_response(
        prompt_instructions=prompt,
        prompt_body=payload,
        model_type='openai',
        openai_client=client,
        openai_model='gpt-4o-mini',
    )

In [21]:
# Load the saved group descriptions (one row per `group_id`, with a short and a
# long description) and display them.
# NOTE(review): presumably the stored result of an earlier `generate_response`
# run for this clustering - confirm that the group ids still match.
explanations = pd.read_csv('../../data/explanations.csv')
explanations

Unnamed: 0,group_id,description_short,description_long
0,0,Closure of Ventricular Septal Defects (VSDs),Patients primarily undergoing surgical closure...
1,1,Pulmonary and Tricuspid Valve Surgeries,This group includes patients requiring repairs...
2,2,Tetralogy of Fallot (ToF) Repairs,Patients with Tetralogy of Fallot and related ...
3,3,Cardiac Conduit Replacements,Patients with complications related to cardiac...
4,4,Aortic Valve Surgeries,Patients undergoing procedures related to aort...
5,5,Atrioventricular Septal Defects (AVSDs),Patients primarily undergoing repairs of Atrio...
6,6,Transposition and Related Surgery,This group involves patients with transpositio...
7,7,Univentricular Heart and Cavopulmonary Connect...,Patients requiring complex surgeries for unive...
8,8,Atrial Septal Defect (ASD) Repairs,Focuses on surgical repairs of Atrial Septal D...
9,9,Heart Transplants and Assistance Devices,Patients requiring heart transplantation or me...


In [24]:
from TabuLLM.explain import one_vs_rest
# Run the one-vs-rest analysis on a two-column frame: the predicted cluster id
# of each row and its `aki_severity` outcome. The result (displayed below) has
# an odds ratio and a p-value per row - presumably one row per cluster.
ovr = one_vs_rest(
    pd.DataFrame(
        {
            'cluster': cluster.predict(X),
            'outcome': df['aki_severity'],
        }
    )
)
ovr

Unnamed: 0,Odds Ratio,P-value
0,1.16086,0.545812
1,2.311592,0.000725
2,1.415865,0.41933
3,0.371843,0.110473
4,0.870385,0.682033
5,0.539372,0.023937
6,1.252945,0.379913
7,0.114689,1e-06
8,1.344184,0.227912
9,1.137889,0.540602


In [28]:
# Attach each group's short description to its odds ratio and p-value.
# Descriptions are matched on `group_id` rather than on their row position in
# the CSV, so a reordered or filtered explanations file cannot silently label
# the wrong cluster. The outer join keeps rows present on only one side, as
# `pd.concat(..., axis=1)` did, and the index name is cleared so the displayed
# table is unchanged.
(
    explanations
    .set_index('group_id')[['description_short']]
    .rename(columns={'description_short': 'group'})
    .join(ovr, how='outer')
    .rename_axis(None)
)

Unnamed: 0,group,Odds Ratio,P-value
0,Closure of Ventricular Septal Defects (VSDs),1.16086,0.545812
1,Pulmonary and Tricuspid Valve Surgeries,2.311592,0.000725
2,Tetralogy of Fallot (ToF) Repairs,1.415865,0.41933
3,Cardiac Conduit Replacements,0.371843,0.110473
4,Aortic Valve Surgeries,0.870385,0.682033
5,Atrioventricular Septal Defects (AVSDs),0.539372,0.023937
6,Transposition and Related Surgery,1.252945,0.379913
7,Univentricular Heart and Cavopulmonary Connect...,0.114689,1e-06
8,Atrial Septal Defect (ASD) Repairs,1.344184,0.227912
9,Heart Transplants and Assistance Devices,1.137889,0.540602
