In [1]:
import numpy as np
import pandas as pd
from TabuLLM.embed import TextColumnTransformer
from TabuLLM.cluster import SphericalKMeans
df = pd.read_csv('../../data/raw.csv')
embeddings = TextColumnTransformer(
    model_type = 'st'
).fit_transform(df.loc[:, ['diagnoses']])
n_clusters = 10
cluster_labels = SphericalKMeans(n_clusters=n_clusters).fit_predict(embeddings)
assert np.array_equal(np.unique(cluster_labels), np.arange(n_clusters))

  from tqdm.autonotebook import tqdm, trange


In [2]:
from TabuLLM.explain import generate_prompt, generate_response
prompt_instruction, prompt_body = generate_prompt(
    text_list = list(df['diagnoses'])
    , cluster_labels = cluster_labels
    , prompt_observations = 'pediatric cardiopulmonary bypass surgeries'
    , prompt_texts = 'planned procedures'
)

In [3]:
from pydantic import BaseModel

class GroupLabel(BaseModel):
    number: int
    description_short: str
    description_long: str

class MultipleGroupLabels(BaseModel):
    groups: list[GroupLabel]

    # method to convert the response to a DataFrame
    def to_df(self):
        return pd.DataFrame([group.model_dump() for group in self.groups]).sort_values('number').reset_index(drop=True)

from openai import OpenAI
import os
from dotenv import load_dotenv
load_dotenv()
openai_api_key = os.getenv('OPENAI_API_KEY')
client = OpenAI(api_key=openai_api_key)

google_project_id = os.getenv('VERTEXAI_PROJECT')
google_location = os.getenv('VERTEXAI_LOCATION')
import vertexai
from vertexai.generative_models import GenerativeModel, GenerationConfig
vertexai.init(project=google_project_id, location=google_location)

In [4]:
groups = generate_response(
    prompt_instructions = prompt_instruction
    , prompt_body = prompt_body
    , openai_client = client
    , model_type = 'openai'
    #, model_type = 'google'
    , google_location = google_location
    , google_project_id = google_project_id
)
#groups = groups[['number', 'description_short', 'description_long']]
groups

Unnamed: 0,group_number,description_short,description_long
0,1,Tetralogy of Fallot Related Surgeries,This group encompasses various cardiac conditi...
1,2,Aortic and Valvular Disorders,This group consists of surgeries dealing mainl...
2,3,Complex Congenital Heart Anomalies,This group includes various intricate congenit...
3,4,Atrial Septal Defects,Focusing predominantly on atrial septal defect...
4,5,Pulmonary Atresia and Vessel Anomalies,This group discusses surgeries related to pulm...
5,6,Ventricular Septal Defects,Enclosing mainly surgical interventions for ve...
6,7,Cardiomyopathies and Heart Failures,"Focusing on various types of cardiomyopathies,..."
7,8,Transposition of the Great Arteries,This group focuses on surgeries associated wit...
8,9,Pulmonary Venous Anomalies,Comprising surgical interventions related to p...
9,10,Aortic Stenosis and Related Conditions,This group contains surgeries related to aorti...


In [5]:
from TabuLLM.explain import one_vs_rest
fisher = one_vs_rest(
    pd.DataFrame({
        'cluster': cluster_labels
        #, 'outcome': df['aki_severity']
        , 'outcome': df['cr_ratio_log']
    })
)
fisher

Unnamed: 0,T-Statistic,P-value
0,2.386364,0.01906543
1,-0.933135,0.3523181
2,-0.479211,0.6326005
3,-8.231543,4.110858e-13
4,0.102038,0.9190231
5,1.896191,0.05912939
6,3.404894,0.001496727
7,-2.290184,0.02502068
8,-1.205041,0.2326634
9,-0.126009,0.8999003


In [9]:
fisher.merge(groups, left_index=True, right_index=True)[['group_number', 'description_short', 'description_long', 'T-Statistic', 'P-value']]

Unnamed: 0,group_number,description_short,description_long,T-Statistic,P-value
0,1,Tetralogy of Fallot Related Surgeries,This group encompasses various cardiac conditi...,2.386364,0.01906543
1,2,Aortic and Valvular Disorders,This group consists of surgeries dealing mainl...,-0.933135,0.3523181
2,3,Complex Congenital Heart Anomalies,This group includes various intricate congenit...,-0.479211,0.6326005
3,4,Atrial Septal Defects,Focusing predominantly on atrial septal defect...,-8.231543,4.110858e-13
4,5,Pulmonary Atresia and Vessel Anomalies,This group discusses surgeries related to pulm...,0.102038,0.9190231
5,6,Ventricular Septal Defects,Enclosing mainly surgical interventions for ve...,1.896191,0.05912939
6,7,Cardiomyopathies and Heart Failures,"Focusing on various types of cardiomyopathies,...",3.404894,0.001496727
7,8,Transposition of the Great Arteries,This group focuses on surgeries associated wit...,-2.290184,0.02502068
8,9,Pulmonary Venous Anomalies,Comprising surgical interventions related to p...,-1.205041,0.2326634
9,10,Aortic Stenosis and Related Conditions,This group contains surgeries related to aorti...,-0.126009,0.8999003


In [None]:
df.columns