In [1]:
import pandas as pd
df = pd.read_csv('../../data/raw.csv')
print(f"Example of diagnoses:\n{df['diagnoses'][2]}\n\n")
print(f"Example of operations:\n{df['operations'][2]}\n")

Example of diagnoses:
155516. Cardiac conduit failure;090101. Common arterial trunk;110021. Cardiac arrest


Example of operations:
123610. Replacement of cardiac conduit;123452. Pacemaker system placement: biventricular



In [2]:
import os
from openai import OpenAI
from dotenv import load_dotenv
load_dotenv()
openai_api_key = os.getenv('OPENAI_API_KEY')
client = OpenAI(api_key=openai_api_key)

In [3]:
from TabuLLM.embed import TextColumnTransformer
obj = TextColumnTransformer(
  model_type = 'openai',
  openai_args = {
    'client': client, 
    'model': 'text-embedding-3-small'
  }
)
#X = obj.fit_transform(df.loc[:5, ['diagnoses']])
#print(X.shape)

  from tqdm.autonotebook import tqdm, trange


In [4]:
google_project_id = os.getenv('VERTEXAI_PROJECT')
google_location = os.getenv('VERTEXAI_LOCATION')
print(f"Google project id: {google_project_id}, location: {google_location}")

Google project id: moonlit-helper-426810-a2, location: us-central1


In [5]:
obj = TextColumnTransformer(
    model_type = 'google'
    , google_args = {
        'project_id': google_project_id
        , 'location': google_location
        , 'model': 'text-embedding-004'
        , 'task': 'SEMANTIC_SIMILARITY'
        , 'batch_size': 250
    }
)
#X = obj.fit_transform(df.loc[:5, ['diagnoses']])
#print(X.shape)

In [6]:
obj = TextColumnTransformer(
    model_type = 'st'
    , st_args = {
        'model': 'sentence-transformers/all-MiniLM-L6-v2'
    }
)
X = obj.fit_transform(df.loc[:, ['diagnoses']])
print(X.shape)



(830, 384)


In [7]:
from TabuLLM.cluster import SphericalKMeans
cluster = SphericalKMeans(n_clusters=10, n_init=5)
cluster.fit(X)
print(cluster.predict(X[:5]))

[1 3 5 4 0]


In [8]:
distances = cluster.transform(X)
print(distances.shape)

(830, 10)


In [9]:
from TabuLLM.explain import generate_prompt
prompt, payload = generate_prompt(
    text_list = obj.prep_X(df[['diagnoses']]),
    cluster_labels = cluster.predict(X),
    prompt_observations = 'CPB procedures',
    prompt_texts = 'diagnoses'
)

In [10]:
print(prompt)

The following is a list of 830 CPB procedures. Text lines represent diagnoses. Cpb procedures have been grouped into 10 groups, according to their diagnoses. Please suggest group labels that are representative of their members, and also distinct from each other. Follow the provided template to return - for each group - the group number, a short desciption / group label, and a long description.


In [11]:
print('\n'.join(payload.splitlines()[:5]))

Group 1:

diagnoses: 155516. Cardiac conduit failure;010133. Left heart obstruction at multiple sites (including Shone syndrome);093002. Aberrant origin R subclavian artery
diagnoses: 091591. Aortic regurgitation;110100. Supraventricular tachycardia
diagnoses: 155516. Cardiac conduit failure;111100. Pacemaker dysfunction / complication necessitating replacement;010117. Double outlet right ventricle with subaortic or doubly committed ventricular septal defect and pulmonary stenosis, Fallot type;070501. RV outflow tract obstruction;070901. LV outflow tract obstruction;110610. Acquired complete AV block


In [12]:
from TabuLLM.explain import generate_response
if False:
    generate_response(
        prompt_instructions = prompt
        , prompt_body = payload
        , model_type = 'openai'
        , openai_client = client
        , openai_model = 'gpt-4o-mini'
    )

In [13]:
explanations = pd.read_csv('../../data/explanations.csv')
explanations

Unnamed: 0,group_id,description_short,description_long
0,0,Closure of Ventricular Septal Defects (VSDs),Patients primarily undergoing surgical closure...
1,1,Pulmonary and Tricuspid Valve Surgeries,This group includes patients requiring repairs...
2,2,Tetralogy of Fallot (ToF) Repairs,Patients with Tetralogy of Fallot and related ...
3,3,Cardiac Conduit Replacements,Patients with complications related to cardiac...
4,4,Aortic Valve Surgeries,Patients undergoing procedures related to aort...
5,5,Atrioventricular Septal Defects (AVSDs),Patients primarily undergoing repairs of Atrio...
6,6,Transposition and Related Surgery,This group involves patients with transpositio...
7,7,Univentricular Heart and Cavopulmonary Connect...,Patients requiring complex surgeries for unive...
8,8,Atrial Septal Defect (ASD) Repairs,Focuses on surgical repairs of Atrial Septal D...
9,9,Heart Transplants and Assistance Devices,Patients requiring heart transplantation or me...


In [14]:
from TabuLLM.explain import one_vs_rest
ovr = one_vs_rest(
    pd.DataFrame({
        'cluster': cluster.predict(X)
        , 'outcome': df['aki_severity']
    })
)
ovr

Unnamed: 0,Odds Ratio,P-value
0,0.808065,0.3493238
1,1.10678,0.6872657
2,0.452492,0.1171836
3,1.295413,0.2013507
4,0.820426,0.7358009
5,0.704426,0.3149755
6,1.272727,0.3901943
7,5.5875,1.296613e-07
8,1.063765,0.818482
9,0.159012,1.106404e-05


In [15]:
pd.concat([explanations[['description_short']].rename(columns = {'description_short': 'group'}), ovr], axis=1)

Unnamed: 0,group,Odds Ratio,P-value
0,Closure of Ventricular Septal Defects (VSDs),0.808065,0.3493238
1,Pulmonary and Tricuspid Valve Surgeries,1.10678,0.6872657
2,Tetralogy of Fallot (ToF) Repairs,0.452492,0.1171836
3,Cardiac Conduit Replacements,1.295413,0.2013507
4,Aortic Valve Surgeries,0.820426,0.7358009
5,Atrioventricular Septal Defects (AVSDs),0.704426,0.3149755
6,Transposition and Related Surgery,1.272727,0.3901943
7,Univentricular Heart and Cavopulmonary Connect...,5.5875,1.296613e-07
8,Atrial Septal Defect (ASD) Repairs,1.063765,0.818482
9,Heart Transplants and Assistance Devices,0.159012,1.106404e-05


In [16]:
features_baseline = ['is_female', 'age', 'height', 'weight', 'optime']
features_text = ['diagnoses', 'operations']
X = df[features_baseline + features_text]
y = df['aki_severity']

In [17]:
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression

ct_baseline = ColumnTransformer([
    ('baseline', 'passthrough', features_baseline)
], remainder = 'drop')
pipeline_baseline = Pipeline([
    ('coltrans', ct_baseline)
    , ('logit', LogisticRegression(penalty = None))
])

In [18]:
from sklearn.model_selection import KFold, cross_val_score

kf = KFold(n_splits = 50, shuffle = True, random_state = 1234)

auc_baseline = cross_val_score(
    pipeline_baseline
    , X, y, cv = kf
    , scoring = 'roc_auc'
)
auc_baseline.mean()

0.6755453574203575

In [None]:
trans_embed = TextColumnTransformer(
    model_type = 'st'
)
trans_cluster = SphericalKMeans(n_clusters=10, n_init=5)
ct_text = Pipeline([
    ('embed', trans_embed)
    , ('cluster', trans_cluster)
])
ct_tabullm = ColumnTransformer([
    ('text', ct_text, features_text)
], remainder = 'passthrough')
pipeline_tabullm = Pipeline([
    ('coltrans', ct_tabullm)
    , ('logit', LogisticRegression(penalty = None))
])

auc_tabullm = cross_val_score(
    pipeline_tabullm
    , X, y, cv = kf
    , scoring = 'roc_auc'
)
auc_tabullm.mean()

In [20]:
X_embedding = trans_embed.fit_transform(df[features_text])
X_2 = pd.concat([X_embedding, df[features_baseline]], axis=1)



In [22]:
from sklearn.preprocessing import TargetEncoder, StandardScaler

ct_tabullm_2 = ColumnTransformer([
    ('cluster_text', SphericalKMeans(n_clusters=10, n_init=5), X_embedding.columns)
    , ('baseline', StandardScaler(), features_baseline)
], remainder = 'passthrough')

pipeline_tabullm_2 = Pipeline([
    ('coltrans', ct_tabullm_2)
    , ('logit', LogisticRegression())
])

auc_tabullm_2 = cross_val_score(
    pipeline_tabullm_2
    , X_2, y, cv = kf
    , scoring = 'roc_auc'
)
auc_tabullm_2.mean()

0.6814838911088912

In [23]:

pipeline_te = Pipeline([
    ('cluster', SphericalKMeans(n_clusters=10, n_init=5, return_hard_labels=True)),
    ('te', TargetEncoder(smooth = 'auto'))
])
ct_te = ColumnTransformer([
    ('baseline', StandardScaler(), features_baseline),
    ('text', pipeline_te, X_2.columns)
], remainder = 'drop')
pipeline_te = Pipeline([
    ('preprocess', ct_te)
    , ('logit', LogisticRegression())
])
auc_tabullm_3 = cross_val_score(
    pipeline_te
    , X_2, y, cv = kf
    , scoring = 'roc_auc'
)
auc_tabullm_3.mean()

0.6823733904983905