In [22]:
class QueryPromptGenerator:
    """Intelligent Query Prompt Generator"""
    
    OPERATOR_MAP = {
        '>': 'greater than', '<': 'less than', '>=': 'greater than or equal to', '<=': 'less than or equal to',
        '==': 'equal to', '!=': 'not equal to', 'between': 'between', 'in': 'included in'
    }

    def __init__(self):
        self.params = {
            'patient_id': None,         # Patient ID (None means no restriction)
            'date_range': None,         # Date range (tuple or relative description)
            'conditions': [],           # List of query conditions
            'output_format': 'table',   # Output format
            'sort_rules': []            # Sorting rules (field, order)
        }
    
    def set_patient(self, patient_id):
        """Set the patient ID"""
        self.params['patient_id'] = patient_id
    
    def set_date_range(self, start=None, end=None, relative=None):
        """
        Set the date range
        :param start: Absolute start date (YYYY-MM-DD)
        :param end: Absolute end date
        :param relative: Relative time description (e.g., "last three months")
        """
        import datetime
        if relative:
            self.params['date_range'] = ('relative', f"{relative}")
        else:
            if start and type(start) != str:
                start = start.strftime("%Y-%m-%d")
            if end and type(end) != str:
                end = end.strftime("%Y-%m-%d")
            self.params['date_range'] = ('absolute', (start, end))
            
    
    def add_condition(self, field, operator, value):
        """
        Add a query condition
        :param field: Field name
        :param operator: Comparison operator (supports >, <, between, etc.)
        :param value: Comparison value (tuple required for 'between' operator)
        """
        self.params['conditions'].append({
            'field': field,
            'operator': operator,
            'value': value
        })
    
    def set_sorting(self, field, ascending=True):
        """Set sorting rules"""
        self.params['sort_rules'].append((
            field,
            'ascending' if ascending else 'descending'
        ))
    
    def _build_date_description(self):
        """Construct the date range description"""
        if not self.params['date_range']:
            return "no time restrictions"
        
        range_type, value = self.params['date_range']
        if range_type == 'relative':
            return f"within {value}"
        
        start, end = value
        if start and end:
            if start == end:
                return f"on {start}"
            return f"from {start} to {end}"
        if start:
            return f"after {start}"
        if end:
            return f"before {end}"
        return "no time restrictions"
    
    def _build_condition_description(self):
        """Construct the condition description"""
        if not self.params['conditions']:
            return "no specific filtering conditions"
        
        desc = []
        for cond in self.params['conditions']:
            op = self.OPERATOR_MAP.get(cond['operator'], cond['operator'])
            field = f"`{cond['field']}`"
            
            if cond['operator'] == 'between':
                val = f"{cond['value'][0]:.2f} and {cond['value'][1]:.2f}" 
            elif isinstance(cond['value'], list):
                val = ", ".join([str(v) for v in cond['value']])
            else:
                val = str(cond['value']) if type(cond['value']) != float else f"{cond['value']:.2f}"
            
            desc.append(f"{field} {op} {val}")
        return "and meet the following conditions: " + ", ".join(desc)
    
    def generate_prompt(self):
        """Generate a natural language query prompt"""
        # Patient description
        patient_desc = (
            f"patient `{self.params['patient_id']}`" 
            if self.params['patient_id'] else "all patients"
        )
        
        # Sorting description
        sort_desc = ""
        if self.params['sort_rules']:
            sort_rules = [
                f"sorted by `{field} {order}`" 
                for field, order in self.params['sort_rules']
            ]
            sort_desc = ", " + ", ".join(sort_rules)
        
        # Assemble the full prompt
        components = [
            f"Please retrieve data for {patient_desc}",
            self._build_date_description(),
            self._build_condition_description(),
            f"results should be displayed in `{self.params['output_format']}`{sort_desc}"
        ]
        
        # Filter out empty values and join the components
        return "; ".join([c for c in components if c != ""]) + "."

# Example Usage
if __name__ == "__main__":
    # Example 1: Complex query
    q1 = QueryPromptGenerator()
    q1.set_patient("P2024-001")
    q1.set_date_range(start="2023-01-01", end="2023-12-31")
    q1.add_condition("Systolic Blood Pressure", ">", "140 mmHg")
    q1.add_condition("Column", "in", ["Hypertension", "Diabetes"])
    # q1.set_sorting("Visit Date", ascending=False)
    q1.params['output_format'] = "statistical summary include mean, median"
    print("Example 1:\n" + q1.generate_prompt())

    # Example 2: Simple query
    q2 = QueryPromptGenerator()
    q2.set_date_range(relative="2023-10-01")
    q2.add_condition("Age", "between", (30, 50))
    print("\nExample 2:\n" + q2.generate_prompt())

Example 1:
Please retrieve data for patient `P2024-001`; from 2023-01-01 to 2023-12-31; and meet the following conditions: `Systolic Blood Pressure` greater than 140 mmHg, `Column` included in Hypertension, Diabetes; results should be displayed in `statistical summary include mean, median`.

Example 2:
Please retrieve data for all patients; within 2023-10-01; and meet the following conditions: `Age` between 30.00 and 50.00; results should be displayed in `table`.


In [23]:
import pandas as pd
import random
def filter_by_col(df, col_list):
    return df[col_list]

def filter_by_row(df, label, type, value):
    if type == '==':
        return df[df[label] == value]
    elif type == '!=':
        return df[df[label] != value]
    elif type == '>':
        return df[df[label] > value]
    elif type == '<':
        return df[df[label] < value]
    elif type == '>=':
        return df[df[label] >= value]
    elif type == '<=':
        return df[df[label] <= value]
    elif type == 'between':
        return df[(df[label] >= value[0]) & (df[label] <= value[1])]

def calculate_stats(df, type):
    res = None
    if type == 'mean':
        res = df.mean()
    elif type == 'median':
        res = df.median()
    elif type == 'std':
        res = df.std()
    elif type == 'min':
        res = df.min()
    elif type == 'max':
        res = df.max()
    elif type == 'count':
        res = df.count()
    else:
        print('Unknown type:', type)
    
    # change column name to '{label}_{type}'
    res = res.to_frame().T
    res.columns = [col + '_' + type for col in res.columns]
    return res



def write_to_csv(df, filename):
    df.to_csv(filename, index=False, float_format='%.2f')
    
def read_csv(filename, date_cols=[], str_col=[], id_col=[], target_colomn=[]) -> pd.DataFrame:
    df = pd.read_csv(filename)
    df = df[target_colomn]
    for col in date_cols:
        df[col] = pd.to_datetime(df[col])
    for col in str_col + id_col:
        df[col] = df[col].astype(str)
    return df



In [24]:
def filter_data(df, date_col, str_col, id_col, filter_by_id=True, filter_by_date=True, row_filter_num=1, column_num=1, staticmethod_num=1):
    prompt_generator = QueryPromptGenerator()

    def select_upper_and_lower_bound(nums):
        max_num = max(nums)
        min_num = min(nums)
        upper_bound = random.uniform(max_num, min_num)
        lower_bound = random.uniform(min_num, upper_bound)
        return upper_bound, lower_bound
    def select_upper_and_lower_bound_for_date(nums):

        unique_dates = list(set(nums))
        upper_bound = random.choice(unique_dates)
        unique_dates = [date for date in unique_dates if date < upper_bound]
        lower_bound = random.choice(unique_dates)
        
        return upper_bound, lower_bound
    def select_value(nums):
        uniqs = list(set(nums))
        return random.choice(uniqs)
    
    visit_label = list(set(df.columns) - set(id_col) - set(date_col))
    print(visit_label)
    
    if filter_by_id:
        id_value = select_value(df[id_col[0]])
        df = filter_by_row(df, id_col[0], '==', id_value)
        prompt_generator.set_patient(id_value)

    if filter_by_date:
        operator = random.choice(['>=', '<=', 'between'])        
        upper_bound, lower_bound = select_upper_and_lower_bound_for_date(df[date_col[0]])
        value = [lower_bound, upper_bound] if operator == 'between' else upper_bound
        filter_by_row(df, date_col, operator, value)
        prompt_generator.set_date_range(start=lower_bound, end=upper_bound)

    if row_filter_num > 0:
        columns = random.sample(visit_label, row_filter_num)
        for column in columns:
            if column in str_col or column in id_col:
                value = select_value(df[column])
                operator = random.choice(['==', '!='])
            else:
                operator = random.choice(['>', '<', '>=', '<=', 'between'])
                upper_bound, lower_bound = select_upper_and_lower_bound(df[column])
                value = [lower_bound, upper_bound] if operator == 'between' else upper_bound
            filter_by_row(df, column, operator, value)
            prompt_generator.add_condition(column, operator, value)

    if column_num > 0:
        keep_columns = random.sample(visit_label, column_num)
        keep_columns = id_col + date_col + keep_columns
        df = filter_by_col(df, keep_columns)
        prompt_generator.add_condition('Column', 'in', keep_columns)
    
    if staticmethod_num > 0:
        # drop columns id_col, date_col, str_col if exist
        df = df.drop(columns=id_col + date_col + str_col, errors='ignore')
        
        res = []
        all_stats_type = ['mean', 'median', 'std', 'min', 'max', 'count']
        stats_type = random.sample(all_stats_type, staticmethod_num)
        for type in stats_type:
            res.append(calculate_stats(df, type))
        df = pd.concat(res, axis=1)
        prompt_generator.params['output_format'] = "statistical summary include " + ', '.join(stats_type)
    return df, prompt_generator.generate_prompt(), [filter_by_id, filter_by_date, row_filter_num, column_num, staticmethod_num]

In [28]:
import pandas as pd
import os
file_path = "/home/wyy/workspace/EHRAgent/workspace/esrd/data"
file_name = "esrd_656_all.csv"
date_col = ['RecordTime']
str_col = ['Diab', 'Gender']
target_column = ['PatientID', 'RecordTime', 'Cl', 'CO2CP', 'WBC', 'Hb', 'Urea', 'Ca', 'K', 'Na', 'Scr', 'P', 'Albumin', 'HSCRP', 'Glucose', 'Appetite', 'Weight', 'SBP', 'DBP', 'Age', 'Gender', 'Diab', 'Height']
id_col = ['PatientID']
df = read_csv(filename=file_path + '/' + file_name, date_cols=date_col, str_col=str_col, id_col=id_col, target_colomn=target_column)

dataset_path = './generated_data/'
os.makedirs(dataset_path, exist_ok=True)
global_index = 1

df_filtered, prompt,metainfo = filter_data(df, date_col=date_col, str_col=str_col, id_col=['PatientID'], filter_by_id=True, filter_by_date=True, row_filter_num=2, column_num=2, staticmethod_num=0)
print(prompt)
print(metainfo)
write_to_csv(df_filtered, dataset_path + 'generated_data_' + str(global_index) + '.csv')


['Ca', 'Gender', 'SBP', 'Hb', 'Albumin', 'K', 'Scr', 'Appetite', 'Urea', 'Cl', 'Diab', 'DBP', 'HSCRP', 'CO2CP', 'Height', 'P', 'Weight', 'Na', 'WBC', 'Age', 'Glucose']
Please retrieve data for patient `371`; from 2008-12-02 to 2009-04-01; and meet the following conditions: `Appetite` between 1822.43 and 2648.12, `Urea` greater than 19.19, `Column` included in PatientID, RecordTime, Albumin, HSCRP; results should be displayed in `table`.
[True, True, 2, 2, 0]
