In [1]:
import pandas as pd
df = pd.read_csv('../../data/raw.csv')
print(f"Example of diagnoses:\n{df['diagnoses'][2]}\n\n")
print(f"Example of operations:\n{df['operations'][2]}\n")

Example of diagnoses:
155516. Cardiac conduit failure;090101. Common arterial trunk;110021. Cardiac arrest


Example of operations:
123610. Replacement of cardiac conduit;123452. Pacemaker system placement: biventricular



In [2]:
import os
from openai import OpenAI
from dotenv import load_dotenv
load_dotenv()
openai_api_key = os.getenv('OPENAI_API_KEY')
client = OpenAI(api_key=openai_api_key)

In [3]:
from TabuLLM.embed import TextColumnTransformer
obj = TextColumnTransformer(
  model_type = 'openai',
  openai_args = {
    'client': client, 
    'model': 'text-embedding-3-small'
  }
)
#X = obj.fit_transform(df.loc[:5, ['diagnoses']])
#print(X.shape)

  from tqdm.autonotebook import tqdm, trange


In [4]:
google_project_id = os.getenv('VERTEXAI_PROJECT')
google_location = os.getenv('VERTEXAI_LOCATION')
print(f"Google project id: {google_project_id}, location: {google_location}")

Google project id: moonlit-helper-426810-a2, location: us-central1


In [5]:
obj = TextColumnTransformer(
    model_type = 'google'
    , google_args = {
        'project_id': google_project_id
        , 'location': google_location
        , 'model': 'text-embedding-004'
        , 'task': 'SEMANTIC_SIMILARITY'
        , 'batch_size': 250
    }
)
#X = obj.fit_transform(df.loc[:5, ['diagnoses']])
#print(X.shape)

In [6]:
obj = TextColumnTransformer(
    model_type = 'st'
    , st_args = {
        'model': 'sentence-transformers/all-MiniLM-L6-v2'
    }
)
X = obj.fit_transform(df.loc[:, ['diagnoses']])
print(X.shape)



(830, 384)


In [7]:
from TabuLLM.cluster import SphericalKMeans
cluster = SphericalKMeans(n_clusters=10, n_init=5)
cluster.fit(X)
print(cluster.predict(X[:5]))

[7 0 6 6 8]


In [8]:
distances = cluster.transform(X)
print(distances.shape)

(830, 10)


In [9]:
from TabuLLM.explain import generate_prompt
prompt, payload = generate_prompt(
    text_list = obj.prep_X(df[['diagnoses']]),
    cluster_labels = cluster.predict(X),
    prompt_observations = 'CPB procedures',
    prompt_texts = 'diagnoses'
)

In [10]:
print(prompt)

The following is a list of 830 CPB procedures. Text lines represent diagnoses. Cpb procedures have been grouped into 10 groups, according to their diagnoses. Please suggest group labels that are representative of their members, and also distinct from each other. Follow the provided template to return - for each group - the group number, a short desciption / group label, and a long description.


In [11]:
print('\n'.join(payload.splitlines()[:5]))

Group 1:

diagnoses: 091591. Aortic regurgitation;091519. Congenital anomaly of aortic valve;071205. Doubly committed juxta-arterial ventricular septal defect (VSD) with anteriorly malaligned fibrous outlet septum and perimembranous extension
diagnoses: 071001. Perimembranous central ventricular septal defect (VSD);070501. RV outflow tract obstruction
diagnoses: 071001. Perimembranous central ventricular septal defect (VSD);071402. Communication between LV-RA (Gerbode defect);060191. Tricuspid regurgitation


In [12]:
from TabuLLM.explain import generate_response
if False:
    generate_response(
        prompt_instructions = prompt
        , prompt_body = payload
        , model_type = 'openai'
        , openai_client = client
        , openai_model = 'gpt-4o-mini'
    )

In [13]:
explanations = pd.read_csv('../../data/explanations.csv')
explanations

Unnamed: 0,group_id,description_short,description_long
0,0,Closure of Ventricular Septal Defects (VSDs),Patients primarily undergoing surgical closure...
1,1,Pulmonary and Tricuspid Valve Surgeries,This group includes patients requiring repairs...
2,2,Tetralogy of Fallot (ToF) Repairs,Patients with Tetralogy of Fallot and related ...
3,3,Cardiac Conduit Replacements,Patients with complications related to cardiac...
4,4,Aortic Valve Surgeries,Patients undergoing procedures related to aort...
5,5,Atrioventricular Septal Defects (AVSDs),Patients primarily undergoing repairs of Atrio...
6,6,Transposition and Related Surgery,This group involves patients with transpositio...
7,7,Univentricular Heart and Cavopulmonary Connect...,Patients requiring complex surgeries for unive...
8,8,Atrial Septal Defect (ASD) Repairs,Focuses on surgical repairs of Atrial Septal D...
9,9,Heart Transplants and Assistance Devices,Patients requiring heart transplantation or me...


In [14]:
from TabuLLM.explain import one_vs_rest
ovr = one_vs_rest(
    pd.DataFrame({
        'cluster': cluster.predict(X)
        , 'outcome': df['aki_severity']
    })
)
ovr

Unnamed: 0,Odds Ratio,P-value
0,1.268311,0.2403796
1,1.006618,1.0
2,0.703439,0.2470421
3,1.482042,0.3262249
4,5.5875,1.296613e-07
5,1.380923,0.2160364
6,0.610717,0.0790038
7,1.099935,0.7681655
8,0.765392,0.4282053
9,0.114689,1.282851e-06


In [15]:
pd.concat([explanations[['description_short']].rename(columns = {'description_short': 'group'}), ovr], axis=1)

Unnamed: 0,group,Odds Ratio,P-value
0,Closure of Ventricular Septal Defects (VSDs),1.268311,0.2403796
1,Pulmonary and Tricuspid Valve Surgeries,1.006618,1.0
2,Tetralogy of Fallot (ToF) Repairs,0.703439,0.2470421
3,Cardiac Conduit Replacements,1.482042,0.3262249
4,Aortic Valve Surgeries,5.5875,1.296613e-07
5,Atrioventricular Septal Defects (AVSDs),1.380923,0.2160364
6,Transposition and Related Surgery,0.610717,0.0790038
7,Univentricular Heart and Cavopulmonary Connect...,1.099935,0.7681655
8,Atrial Septal Defect (ASD) Repairs,0.765392,0.4282053
9,Heart Transplants and Assistance Devices,0.114689,1.282851e-06


In [16]:
features_baseline = ['is_female', 'age', 'height', 'weight', 'optime']
features_text = ['diagnoses', 'operations']
X = df[features_baseline + features_text]
y = df['aki_severity']

In [17]:
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression

ct_baseline = ColumnTransformer([
    ('baseline', 'passthrough', features_baseline)
], remainder = 'drop')
pipeline_baseline = Pipeline([
    ('coltrans', ct_baseline)
    , ('logit', LogisticRegression(penalty = None))
])

In [18]:
from sklearn.model_selection import KFold, cross_val_score

kf = KFold(n_splits = 5, shuffle = True, random_state = 1234)

auc_baseline = cross_val_score(
    pipeline_baseline
    , X, y, cv = kf
    , scoring = 'roc_auc'
)
auc_baseline.mean()

0.6809980657254358

In [19]:
trans_embed = TextColumnTransformer(
    model_type = 'st'
)
trans_cluster = SphericalKMeans(n_clusters=10, n_init=5)
ct_text = Pipeline([
    ('embed', trans_embed)
    , ('cluster', trans_cluster)
])
ct_tabullm = ColumnTransformer([
    ('text', ct_text, features_text)
], remainder = 'passthrough')
pipeline_tabullm = Pipeline([
    ('coltrans', ct_tabullm)
    , ('logit', LogisticRegression(penalty = None))
])

auc_tabullm = cross_val_score(
    pipeline_tabullm
    , X, y, cv = kf
    , scoring = 'roc_auc'
)
auc_tabullm.mean()

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

0.6733075114896849

In [20]:
ct_text.fit_transform(X[features_text], y)



array([[0.42094144, 0.5204149 , 0.49500299, ..., 0.53031623, 0.764132  ,
        0.5529034 ],
       [0.5890519 , 0.51804733, 0.4859802 , ..., 0.75263476, 0.67278415,
        0.54293156],
       [0.50199026, 0.5701773 , 0.49924445, ..., 0.43944168, 0.67252946,
        0.5246769 ],
       ...,
       [0.88317037, 0.44940776, 0.4038371 , ..., 0.39154717, 0.46148455,
        0.48783517],
       [0.5378833 , 0.5223813 , 0.56443805, ..., 0.9075183 , 0.68682873,
        0.6138319 ],
       [0.6352277 , 0.5095983 , 0.6694881 , ..., 0.6113167 , 0.7752981 ,
        0.7639202 ]], dtype=float32)