In [2]:
from pydantic import BaseModel
from openai import OpenAI
import os
from dotenv import load_dotenv
# Load variables from a local .env file (if present) into the process
# environment, then build the OpenAI client from OPENAI_API_KEY
# (None when the variable is unset).
load_dotenv()
openai_api_key = os.environ.get('OPENAI_API_KEY')
client = OpenAI(api_key=openai_api_key)

In [18]:
# Structured-output schema for one calendar event extracted from free text.
# Documented with comments rather than a class docstring: pydantic uses a
# model's docstring as the JSON-schema "description", which would change the
# schema sent to the model.
class CalendarEvent(BaseModel):
    name: str
    # Plain text, not a parsed date (the output below shows values like 'Friday').
    date: str
    participants: list[str]

# Wrapper so a single structured-output call can return every event found in
# the input text; used as `response_format` in the call below.
class MultipleCalendarEvents(BaseModel):
    events: list[CalendarEvent]

# Extract all events from the example sentence, constraining the reply to the
# MultipleCalendarEvents schema.
completion = client.beta.chat.completions.parse(
    model="gpt-4o-2024-08-06",
    messages=[
        {"role": "system", "content": "Extract the information for all events."},
        {"role": "user", "content": "Alice and Bob are going to a science fair on Friday. Bob and Alex are going to a concert on Saturday."},
    ],
    response_format=MultipleCalendarEvents,
)

message = completion.choices[0].message
if message.parsed is None:
    # `parsed` stays empty when the model refuses; report the refusal text here
    # instead of failing later with an AttributeError on None.
    raise RuntimeError(f"No structured output returned (refusal: {message.refusal!r})")
# Despite the singular name, this holds a MultipleCalendarEvents (all events).
event = message.parsed

In [21]:
# Show the parsed result: a MultipleCalendarEvents holding every extracted
# event (see the output below), not a single CalendarEvent.
event

MultipleCalendarEvents(events=[CalendarEvent(name='Science Fair', date='Friday', participants=['Alice', 'Bob']), CalendarEvent(name='Concert', date='Saturday', participants=['Bob', 'Alex'])])

In [8]:
import numpy as np
import pandas as pd
from TabuLLM.embed import TextColumnTransformer
from TabuLLM.cluster import SphericalKMeans
# Load the surgeries table and embed its free-text 'diagnoses' column.
df = pd.read_csv('../../data/raw.csv')
text_embedder = TextColumnTransformer(model_type='st')
embeddings = text_embedder.fit_transform(df.loc[:, ['diagnoses']])

# Partition the embeddings into n_clusters groups.
# NOTE(review): no seed is passed here; if SphericalKMeans initialises randomly,
# labels can change between runs. Check whether it accepts a random_state and
# pin it for reproducibility.
n_clusters = 10
clusterer = SphericalKMeans(n_clusters=n_clusters)
cluster_labels = clusterer.fit_predict(embeddings)



In [9]:
from TabuLLM.explain import generate_prompt

def print_first_n_lines(text, n):
    """Print at most the first ``n`` newline-separated lines of ``text``.

    A preview helper so long prompts do not flood the notebook output.
    Nothing is printed when the selected slice is empty (e.g. ``n == 0``).
    """
    head = text.split('\n')[:n]
    if head:
        print('\n'.join(head))

# Build the cluster-labelling prompt from each surgery's diagnoses text and its
# cluster id, then preview only the first 20 lines.
prompt = generate_prompt(
    text_list=list(df['diagnoses']),
    cluster_labels=cluster_labels,
    prompt_observations='pediatric cardiopulmonary bypass surgeries',
    prompt_texts='planned procedures',
)
print_first_n_lines(prompt, 20)

The following is a list of 830 pediatric cardiopulmonary bypass surgeries. Text lines represent planned procedures. Pediatric cardiopulmonary bypass surgeries have been grouped into 10 groups, according to their planned procedures. Please suggest group labels that are representative of their members, and also distinct from each other:

=====

Group 1:

091591. Aortic regurgitation;091519. Congenital anomaly of aortic valve;071205. Doubly committed juxta-arterial ventricular septal defect (VSD) with anteriorly malaligned fibrous outlet septum and perimembranous extension
071001. Perimembranous central ventricular septal defect (VSD);070501. RV outflow tract obstruction
071001. Perimembranous central ventricular septal defect (VSD);071402. Communication between LV-RA (Gerbode defect);060191. Tricuspid regurgitation
071001. Perimembranous central ventricular septal defect (VSD)
071001. Perimembranous central ventricular septal defect (VSD)
071001. Perimembranous central ventricular septal

In [29]:
# Structured-output schema for one LLM-proposed cluster label. Field order is
# also the column order produced by MultipleGroupLabels.to_df(). Comments are
# used instead of a class docstring, which pydantic would copy into the JSON
# schema sent to the model.
class GroupLabel(BaseModel):
    # Presumably the group number as written in the prompt ("Group 1", ...);
    # the results below are numbered 1-10.
    number: int
    description_short: str
    description_long: str

# Container for all group labels returned by one structured-output call.
# Documented with comments rather than a class docstring, which pydantic would
# copy into the JSON schema sent to the model.
class MultipleGroupLabels(BaseModel):
    groups: list[GroupLabel]

    def to_df(self):
        """Return the labels as a DataFrame with one row per group.

        Columns follow the GroupLabel field order (number, description_short,
        description_long). Passing ``columns`` explicitly keeps that header
        even when ``groups`` is empty, where ``pd.DataFrame([])`` would
        otherwise have no columns at all.
        """
        return pd.DataFrame(
            [group.model_dump() for group in self.groups],
            columns=list(GroupLabel.model_fields),
        )

In [30]:
# Ask GPT-4o to label the clusters, constraining the reply to the
# MultipleGroupLabels schema.
completion = client.beta.chat.completions.parse(
    model="gpt-4o-2024-08-06",
    messages=[
        {"role": "user", "content": prompt},
    ],
    response_format=MultipleGroupLabels,
)

message = completion.choices[0].message
if message.parsed is None:
    # `parsed` stays empty when the model refuses; report the refusal text here
    # rather than failing later on `event.to_df()` with an AttributeError.
    raise RuntimeError(f"No structured output returned (refusal: {message.refusal!r})")
# Name `event` is reused from the calendar example; later cells read it.
event = message.parsed

In [32]:
# Tabulate the GPT-4o group labels: one row per group with columns
# number / description_short / description_long.
event.to_df()

Unnamed: 0,number,description_short,description_long
0,1,Perimembranous VSDs and Associated Defects,This group primarily includes surgeries for pe...
1,2,Subaortic Stenosis and Complex Valve Defects,Encompasses procedures addressing subaortic st...
2,3,Cardiomyopathy and Cardiac Function Disorders,Covers surgical interventions related to vario...
3,4,Hypoplastic and Double Inlet/Outlet Syndromes,Primarily focuses on conditions like hypoplast...
4,5,Atrial Septal Defects and Simple Valve Conditions,Involves correcting atrial septal defects (ASD...
5,6,Complex Valvar Prolapse and Stenosis,"Targets intricate issues with valvar prolapse,..."
6,7,Atrioventricular Septal Defects and Related Ma...,Centers on surgeries for complete or partial a...
7,8,Transposition and Vascular Anomalies,Surgeries focusing on transposition of the gre...
8,9,Tetralogy of Fallot and Pulmonary Obstructions,Addressing Tetralogy of Fallot (TOF) and relat...
9,10,Truncus Arteriosus and Venous Anomalies,"Focused on truncus arteriosus, venous return a..."


In [48]:
from TabuLLM.explain import one_vs_rest
# Re-load the raw table: `df` is reassigned to the Gemini labels by a later cell
# (In [45]), which ran before this one in the recorded execution order.
df = pd.read_csv('../../data/raw.csv')

# Pair each surgery's cluster id with its aki_severity outcome; one_vs_rest
# reports a statistic and p-value per cluster (odds ratios in the output below).
cluster_vs_outcome = pd.DataFrame({
    'cluster': cluster_labels,
    'outcome': df['aki_severity'],
})
fisher = one_vs_rest(cluster_vs_outcome)
# NOTE(review): 'Category' is the raw cluster id (0-9 below) while the LLM group
# labels are numbered from 1; presumably Category k corresponds to Group k+1.
# Confirm against generate_prompt before joining the two tables.
fisher

Unnamed: 0,Category,Test Type,Statistic,P-value
0,0,Odds Ratio,1.172762,0.431201
1,1,Odds Ratio,0.627154,0.1661018
2,2,Odds Ratio,7.019531,1.138133e-08
3,3,Odds Ratio,1.779492,0.03203989
4,4,Odds Ratio,0.118707,2.115097e-06
5,5,Odds Ratio,1.151392,0.6124628
6,6,Odds Ratio,0.793308,0.3691068
7,7,Odds Ratio,0.371843,0.1104734
8,8,Odds Ratio,1.10971,0.6973826
9,9,Odds Ratio,0.75798,0.3255197


In [47]:
# NOTE(review): the stored output below shows the Gemini group labels, i.e. `df`
# after the JSON-to-DataFrame cell further down (In [45]) reassigned it. Run top
# to bottom, `df` at this point is still the raw CSV from the previous cell.
df

Unnamed: 0,description_long,description_short,number
0,Perimembranous Ventricular Septal Defect (VSD)...,Perimembranous VSD,1
1,Left Ventricular Outflow Tract Obstruction and...,LV Outflow & Valve Issues,2
2,Cardiomyopathies,Cardiomyopathies,3
3,Hypoplastic Left Heart Syndrome and Aortic Arc...,HLHS & Aortic Arch,4
4,Atrial Septal Defects,ASD,5
5,Valve Dysfunctions and Stenosis,Valve Dysfunctions,6
6,Complex Congenital Heart Defects,Complex CHD,7
7,Transposition of the Great Arteries (TGA),TGA,8
8,Tetralogy of Fallot and Pulmonary Atresia,TOF & Pulmonary Atresia,9
9,Pulmonary Artery and Vein Anomalies,Pulmonary Artery & Vein,10


In [33]:
# JSON schema constraining Gemini's reply to
# {"groups": [{"number", "description_short", "description_long"}, ...]}.
# It mirrors the GroupLabel / MultipleGroupLabels pydantic models used for the
# OpenAI call, so keep the two in sync.
_group_label_fields = [
    ("number", "integer"),
    ("description_short", "string"),
    ("description_long", "string"),
]
response_schema = {
    "type": "object",
    "properties": {
        "groups": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    field: {"type": json_type}
                    for field, json_type in _group_label_fields
                },
                "required": [field for field, _ in _group_label_fields],
            },
        },
    },
    "required": ["groups"],
}


In [40]:
import vertexai
from vertexai.generative_models import GenerativeModel, GenerationConfig

# Vertex AI project and location come from the environment (None if unset).
google_project_id = os.getenv('VERTEXAI_PROJECT')
google_location = os.getenv('VERTEXAI_LOCATION')
vertexai.init(project=google_project_id, location=google_location)

# Ask Gemini for JSON output constrained to `response_schema`.
# NOTE(review): this pinned model version may since have been retired; check
# the Vertex AI model-versions page before re-running.
google_model = 'gemini-1.5-pro-001'
gemini_config = GenerationConfig(
    response_mime_type="application/json",
    response_schema=response_schema,
)
model = GenerativeModel(google_model)
response = model.generate_content(prompt, generation_config=gemini_config)

In [42]:
# Raw JSON string returned by Gemini under the response_schema constraint.
response.text

'{"groups": [{"description_long": "Perimembranous Ventricular Septal Defect (VSD) and Related Conditions", "description_short": "Perimembranous VSD", "number": 1}, {"description_long": "Left Ventricular Outflow Tract Obstruction and Valve Dysfunctions", "description_short": "LV Outflow & Valve Issues", "number": 2}, {"description_long": "Cardiomyopathies", "description_short": "Cardiomyopathies", "number": 3}, {"description_long": "Hypoplastic Left Heart Syndrome and Aortic Arch Abnormalities", "description_short": "HLHS & Aortic Arch", "number": 4}, {"description_long": "Atrial Septal Defects", "description_short": "ASD", "number": 5}, {"description_long": "Valve Dysfunctions and Stenosis", "description_short": "Valve Dysfunctions", "number": 6}, {"description_long": "Complex Congenital Heart Defects", "description_short": "Complex CHD", "number": 7}, {"description_long": "Transposition of the Great Arteries (TGA)", "description_short": "TGA", "number": 8}, {"description_long": "Tetra

In [45]:
import json

# Parse Gemini's JSON reply into Python objects, then tabulate the groups.
response_json = json.loads(response.text)
# NOTE(review): this reassigns `df`, replacing the raw surgeries table loaded
# earlier (the one-vs-rest cell re-reads the CSV because of it). A distinct
# name such as `gemini_labels` would remove that hidden-state coupling.
df = pd.DataFrame(response_json['groups'])
# Gemini emitted the keys in a different order (see the raw text above), so
# select the columns in schema order for display.
df[['number', 'description_short', 'description_long']]

Unnamed: 0,number,description_short,description_long
0,1,Perimembranous VSD,Perimembranous Ventricular Septal Defect (VSD)...
1,2,LV Outflow & Valve Issues,Left Ventricular Outflow Tract Obstruction and...
2,3,Cardiomyopathies,Cardiomyopathies
3,4,HLHS & Aortic Arch,Hypoplastic Left Heart Syndrome and Aortic Arc...
4,5,ASD,Atrial Septal Defects
5,6,Valve Dysfunctions,Valve Dysfunctions and Stenosis
6,7,Complex CHD,Complex Congenital Heart Defects
7,8,TGA,Transposition of the Great Arteries (TGA)
8,9,TOF & Pulmonary Atresia,Tetralogy of Fallot and Pulmonary Atresia
9,10,Pulmonary Artery & Vein,Pulmonary Artery and Vein Anomalies
