In [1]:
from utils.query_handler import QueryHandler

In [2]:
# NOTE(review): `handler` is not referenced by any of the cells shown here; presumably
# it is meant to execute or validate the generated SQL later — confirm, or drop this cell.
handler = QueryHandler()

In [3]:
from utils.dataset_creator import QueryTemplateGenerator, MimicSchema

In [4]:
# Build the MIMIC schema helper, generate its sample values, and report its status.
# Per the output below, "sample values stored" goes from False to True and a default
# sample file is found on disk (note: that output contains an absolute local path).
schema = MimicSchema()
schema.generate_sample_values()
schema.check_status()


        Current Status:
            max_workers set: 35
            sample_size set: 150
            sample values stored: False
        
Default Sample File found at /home/aaryan/Documents/Ashoka/Sem_8/Capstone_Thesis/NL2SQL_MIMIC/data/custom_dataset/sample_data/default.json

        Current Status:
            max_workers set: 35
            sample_size set: 150
            sample values stored: True
        


In [5]:
import random
# Query generator bound to `schema`. Later cells use its random_columns / random_filter
# helpers and its sort_dirs, sort_phrases, order_phrases, limit_values, limit_phrases tables.
template = QueryTemplateGenerator(schema)

In [7]:
def _get_entity_type(table: str) -> str:
    """Get a descriptive entity name for a table"""
    entity_mappings = {
        'patients': 'patients',
        'admissions': 'hospital admissions',
        'diagnoses_icd': 'patient diagnoses',
        'procedures_icd': 'patient procedures',
        'prescriptions': 'medications',
        'labevents': 'lab results',
        'chartevents': 'chart entries',
        'icustays': 'ICU stays',
        'transfers': 'patient transfers',
        'services': 'hospital services',
        'microbiologyevents': 'microbiology results',
        'outputevents': 'patient outputs',
        'inputevents': 'patient inputs',
        'emar': 'medication administration records',
        'pharmacy': 'pharmacy orders',
        'poe': 'provider orders',
        'd_icd_diagnoses': 'diagnosis codes',
        'd_icd_procedures': 'procedure codes',
        'd_labitems': 'laboratory tests',
        'd_items': 'charted items'
    }
    
    return entity_mappings.get(table, table.replace('_', ' '))

In [8]:
# `count` sizes the batch loop (currently commented out below);
# `results` accumulates generated {"question", "query"} pairs.
count, results = 1, []

In [17]:

# Sample the target table, then 1-4 of its columns for the SELECT list.
# NOTE(review): no random seed is set in the cells shown, so every run produces a
# different query; seed `random` in a config cell if reproducibility matters.
table = random.choice(schema.tables)
columns = template.random_columns(table, min_cols=1, max_cols=4)

table, columns

('microbiologyevents', ['ab_name', 'comments', 'storedate', 'subject_id'])

In [19]:
# Include a WHERE clause roughly 70% of the time; start from empty clause text.
where_clause, nl_filter = "", ""
use_where = random.random() > 0.3

use_where

True

In [36]:
# Build an optional WHERE clause (SQL) plus its natural-language paraphrase (NL)
# from 1-3 random filters on `table`.
#
# Fixes vs. the previous version of this cell:
# * where_clause / nl_filter are reset here, so re-running this cell on its own
#   never leaves a stale clause from an earlier run behind.
# * filter columns are drawn without replacement, so one column is never
#   constrained twice (which could yield redundant or contradictory conditions).
# * when AND and OR are mixed, the SQL is parenthesised left to right so it
#   evaluates in the order the NL sentence reads (plain SQL binds AND tighter
#   than OR).
where_clause = ""
nl_filter = ""

if use_where:
    # Weighted toward more conditions: 3 filters 60% of the time.
    num_conditions = random.choices([1, 2, 3], weights=[0.1, 0.3, 0.6])[0]
    candidate_cols = list(template.schema.columns[table])
    filter_cols = random.sample(candidate_cols, min(num_conditions, len(candidate_cols)))

    conditions = []
    nl_conditions = []
    for filter_col in filter_cols:
        filter_sql, filter_nl = template.random_filter(table, filter_col)
        if filter_sql and filter_nl:  # skip filters that came back empty
            conditions.append(filter_sql)
            nl_conditions.append(filter_nl)

    if conditions:
        sql_expr = conditions[0]
        nl_expr = nl_conditions[0]
        prev_op = None
        for cond_sql, cond_nl in zip(conditions[1:], nl_conditions[1:]):
            op = random.choice(["AND", "OR"])
            # Group everything so far whenever the connective changes, so the
            # SQL matches the left-to-right reading of the NL text.
            if prev_op is not None and op != prev_op:
                sql_expr = f"({sql_expr})"
            sql_expr = f"{sql_expr} {op} {cond_sql}"
            nl_expr = f"{nl_expr} {op.lower()} {cond_nl}"
            prev_op = op
        where_clause = f"WHERE {sql_expr}"
        nl_filter = f"where {nl_expr}"

where_clause, nl_filter

("WHERE spec_type_desc IN ('PERIPHERAL BLOOD LYMPHOCYTES', 'BRONCHIAL WASHINGS', 'FOOT CULTURE') AND test_seq NOT IN (2, 10, 9, 1, 13, 4, 11, 18, 16) AND quantity <> '10,000-100,000 CFU/mL'",
 'where spec type desc among PERIPHERAL BLOOD LYMPHOCYTES, BRONCHIAL WASHINGS, FOOT CULTURE and test seq not among 2, 10, 9, 1, 13, 4, 11, 18, 16 and quantity excluding 10,000-100,000 CFU/mL')

In [38]:
# ORDER BY is rarer than WHERE: included roughly 30% of the time.
order_clause, nl_order = "", ""
use_order = random.random() > 0.7

use_order

True

In [49]:
# Sort on one of the selected columns in a random direction, and pick an English
# phrasing for that direction plus a sentence template to embed it in.
if use_order:
    order_col = random.choice(columns)
    order_dir = random.choice(template.sort_dirs)
    nl_dir = random.choice(template.sort_phrases[order_dir])
    order_phrase = random.choice(template.order_phrases)

    readable_order_col = order_col.replace('_', ' ')
    order_clause = "ORDER BY {} {}".format(order_col, order_dir)
    nl_order = order_phrase.format(nl_dir=nl_dir, order_col=readable_order_col)

order_clause, nl_order

('ORDER BY ab_name DESC', 'ordered by ab name greatest first')

In [50]:

# Decide if we'll include LIMIT
use_limit = random.random() > 0.5
limit_clause = ""
nl_limit = ""

use_limit


True

In [65]:

# Cap the row count and phrase the cap in plain English.
if use_limit:
    limit = random.choice(template.limit_values)
    limit_phrase = random.choice(template.limit_phrases)
    limit_clause = "LIMIT {}".format(limit)
    nl_limit = limit_phrase.format(limit=limit)

limit_clause, nl_limit

('LIMIT 20', 'return a maximum of 20 items')

In [66]:

# Build the complete SQL query
sql_parts = [
    f"SELECT {', '.join(columns)} FROM {table}",
    where_clause,
    order_clause,
    limit_clause
]
sql_query = " ".join([part for part in sql_parts if part])

sql_query

"SELECT ab_name, comments, storedate, subject_id FROM microbiologyevents WHERE spec_type_desc IN ('PERIPHERAL BLOOD LYMPHOCYTES', 'BRONCHIAL WASHINGS', 'FOOT CULTURE') AND test_seq NOT IN (2, 10, 9, 1, 13, 4, 11, 18, 16) AND quantity <> '10,000-100,000 CFU/mL' ORDER BY ab_name DESC LIMIT 20"

In [67]:


# Build the NL question with more natural language
entity_type = _get_entity_type(table)

# Generate natural column descriptions
column_descriptions = []
for col in columns:
    # Make column names more readable
    readable_col = col.replace('_', ' ')
    column_descriptions.append(readable_col)

# Choose a natural query template based on context
query_templates = [
    f"Show me {', '.join(column_descriptions)} for {entity_type}",
    f"What are the {', '.join(column_descriptions)} of {entity_type}",
    f"List {', '.join(column_descriptions)} from {entity_type}",
    f"Get {', '.join(column_descriptions)} for {entity_type}"
]

# Choose a template but bias toward the first one for simplicity
template_weights = [0.4, 0.2, 0.2, 0.2]
base_question = random.choices(query_templates, weights=template_weights)[0]

# Build the complete NL question
nl_parts = [
    base_question,
    nl_filter,
    nl_order,
    nl_limit
]
nl_question = " ".join([part for part in nl_parts if part])

# Add a period at the end if it doesn't have one
if not nl_question.endswith('.'):
    nl_question += '.'

# Capitalize the first letter
nl_question = nl_question[0].upper() + nl_question[1:]

nl_question

'Show me ab name, comments, storedate, subject id for microbiology results where spec type desc among PERIPHERAL BLOOD LYMPHOCYTES, BRONCHIAL WASHINGS, FOOT CULTURE and test seq not among 2, 10, 9, 1, 13, 4, 11, 18, 16 and quantity excluding 10,000-100,000 CFU/mL ordered by ab name greatest first return a maximum of 20 items.'

In [None]:

# Record this question/query pair. Exact duplicates are skipped so re-running this
# cell (or re-running it without re-sampling above) does not add the same pair twice.
record = {"question": nl_question, "query": sql_query}
if record not in results:
    results.append(record)

len(results)
