### Paper 1 - eating behavior

#### Define the relevant directories used in this paper

Data is pulled from the standardized data folder; subsequently, it is stored and managed in the paper 1 folder. 

In [1]:
import os

# Define the source and output directories
source_directory = r"C:\Users\Felhasználó\Desktop\Projects\PNK_DB2\DB2_standard"
paper1_directory = r"C:\Users\Felhasználó\Desktop\Projects\PNK_DB2\paper1_emotional"

# Ensure the output directory exists
os.makedirs(paper1_directory, exist_ok=True)

#### Create a research question-specific SQL database subset 

Check those medical records where any/3+/all emotional values are available, and filter the database to contain only the specified patients and medical records. Save the data to 3 new SQL files - one with any, one with some, one with all values available. For research purposes, the last one is most likely to be used. The first two may be relevant if trying to increase the sample size for one or a few specific emotional values. 

In [5]:
import sqlite3
import pandas as pd
import os

# Use the above defined directories
db_path = os.path.join(source_directory, "pnk_db2_colclean.sqlite")
conn = sqlite3.connect(db_path)

# List all tables in the database
query_tables = "SELECT name FROM sqlite_master WHERE type='table';"
tables = pd.read_sql_query(query_tables, conn)
table_names = tables['name'].tolist()

# Define criteria for filtering for any/3+/all emotional values available
def create_filtered_database(criteria, output_filename):
    # Set up the appropriate query based on the criteria for the three scenarios
    if criteria == "any":
        # to select records where at least one emotional variable is not null
        query = """
        SELECT medical_record_id, patient_id
        FROM medical_records_colclean
        WHERE hunger IS NOT NULL
           OR satiety IS NOT NULL
           OR emotional_eating IS NOT NULL
           OR emotional_eating_value IS NOT NULL
           OR quantity_control IS NOT NULL
           OR impulse_control IS NOT NULL;
        """
    elif criteria == "3plus":
        # to select records where at least three emotional variables are not null
        query = """
        SELECT medical_record_id, patient_id
        FROM medical_records_colclean
        WHERE (CASE WHEN hunger IS NOT NULL THEN 1 ELSE 0 END +
               CASE WHEN satiety IS NOT NULL THEN 1 ELSE 0 END +
               CASE WHEN emotional_eating IS NOT NULL THEN 1 ELSE 0 END +
               CASE WHEN emotional_eating_value IS NOT NULL THEN 1 ELSE 0 END +
               CASE WHEN quantity_control IS NOT NULL THEN 1 ELSE 0 END +
               CASE WHEN impulse_control IS NOT NULL THEN 1 ELSE 0 END) >= 3;
        """
    elif criteria == "all":
        # to select records where all emotional variables are not null 
        query = """
        SELECT medical_record_id, patient_id
        FROM medical_records_colclean
        WHERE hunger IS NOT NULL
          AND satiety IS NOT NULL
          AND emotional_eating IS NOT NULL
          AND emotional_eating_value IS NOT NULL
          AND quantity_control IS NOT NULL
          AND impulse_control IS NOT NULL;
        """
    
    # Get the relevant records, and extract their medical record and patient IDs
    relevant_records = pd.read_sql_query(query, conn)
    relevant_medical_record_ids = tuple(relevant_records['medical_record_id'])
    relevant_patient_ids = tuple(relevant_records['patient_id'])
    
    # Create a new database in the output directory
    output_db_path = os.path.join(paper1_directory, output_filename)
    filtered_conn = sqlite3.connect(output_db_path)
    
    # Filter each table in the source SQl to only contain the records that comply the criteria of the given scenario; 
    # ie. they have records where any/3+/all emotional variables are available
    for table_name in table_names:
        if table_name.startswith("sqlite_"):
            # Skip any SQLite system tables
            continue 
        # In case of tables that may contain several medical records from the same patient, 
        # filter by medical_record_id
        if table_name == "medical_records_colclean" or table_name == "prescriptions_colclean":
            query = f"""
            SELECT * 
            FROM {table_name}
            WHERE medical_record_id IN {relevant_medical_record_ids}
            """
        else:
            # For all other tables, filter by patient_id only, as medical record ID is not available in those
            query_check_column = f"PRAGMA table_info({table_name});"
            columns = pd.read_sql_query(query_check_column, conn)
            if 'patient_id' not in columns['name'].values:
                continue  # Skip tables without patient_id
            query = f"""
            SELECT *
            FROM {table_name}
            WHERE patient_id IN {relevant_patient_ids}
            """
        
        # Execute the given query and save the result in a new SQLite database
        filtered_data = pd.read_sql_query(query, conn)
        filtered_data.to_sql(table_name, filtered_conn, index=False, if_exists="replace")
    
    filtered_conn.close()
    return len(relevant_records), len(set(relevant_records['patient_id']))

# Create and save the three databases for the three scenarios - any/3+/all emotional variables available
any_count, any_patients = create_filtered_database("any", "emotional_any_notna.sqlite")
three_plus_count, three_plus_patients = create_filtered_database("3plus", "emotional_3plus_notna.sqlite")
all_count, all_patients = create_filtered_database("all", "emotional_all_notna.sqlite")

# Close the connection
conn.close()

# Print summary
print(f"Any emotional data points are available in {any_count} records from {any_patients} patients")
print(f"At least 3 emotional data points are available in {three_plus_count} records from {three_plus_patients} patients")
print(f"All emotional data points are available in {all_count} records from {all_patients} patients")

Any emotional data points are available in 2482 records from 2437 patients
At least 3 emotional data points are available in 2169 records from 2132 patients
All emotional data points are available in 1853 records from 1826 patients


#### Clean and link measurements to prescriptions and medical records

Based on the patient ID and the date of a given measurements, look for prescriptions with the same patient ID that cover the range of time in which the measurement was taken. This way, measurements can be linked to important metadata, such as the prescription and medical record they belong to, the step of the programme they were taken in, etc. 

In summary, this is a key step in the research, without which data on any measurement's identity would be insufficient, and measurements from different prescriptions of the same individual could be mixed, for example. In a previous attempt, I tried identifying blocks of measurements as those that are taken within two months of each other, but I consider this a much more solid approach. 

It is important to note that some patients may take repeated measurements on the same occasion. These duplicates need to be removed, as they inflate the dataset. 

After removing the duplicates, measurements and prescriptions are linked in a two-step process. 

First, measurements are linked to all possible prescriptions that can belong to them based on the shared patient ID (this scenario where every option is linked to every option is called a Cartesian product). 

After, these possible links are filtered by date: a measurement belongs to a prescription if it is within its validity period, or is 5 days within its start or end dates. In the latter case, a measurement may be assigned to multiple prescriptions; if this happens, it is assigned to the one it is closer to in time. 

If a measurement is not succesfully linked to any prescription, it is lost. 

In [6]:
"""
Remove duplicate measurements before doing any data frame merging. 
Any measurement from the same patient on the same day (ignoring time) with the same weight should be considered a duplicate.
"""

import pandas as pd
import sqlite3
import os

# Connect to the database, load the measurements table
conn = sqlite3.connect(os.path.join(paper1_directory, "emotional_all_notna.sqlite"))
measurements = pd.read_sql_query("SELECT * FROM measurements_colclean", conn)

# Convert measurement_date to datetime, if not already in that format. 
# Add a temporary column with the measurement date only; time is ignored, 
# as repeated measurements are at least a few seconds or minutes apart. 
measurements['measurement_date'] = pd.to_datetime(measurements['measurement_date'])
measurements['measurement_date_date'] = measurements['measurement_date'].dt.date
# After, remove duplicates based on patient id, date, and weight. 
# Drop the temporary column. 
measurements_rowclean = measurements.drop_duplicates(subset=['patient_id', 'measurement_date_date', 'weight_kg'])
measurements_rowclean = measurements_rowclean.drop(columns=['measurement_date_date'])
# Save the cleaned measurements back to the database with the _rowclean name code
measurements_rowclean.to_sql("measurements_rowclean", conn, if_exists="replace", index=False)

print(f"Duplicate measurements removed. There are {len(measurements_rowclean)} measurements from {measurements_rowclean['patient_id'].nunique()} patients.")

Duplicate measurements removed. There are 35709 measurements from 1826 patients.


In [7]:
"""
Link metadata from the prescriptions table to measurements. 

The two dataframes are merged based on patient_id, creating a Cartesian product of the two tables, 
where every measurement from one patient is linked to every possible prescription from that patient.

This Cartesian product is then filtered based on the dates of both the measurement and the prescription, 
in order to, preferably, only consider a prescription being linked to a given measurement
if the measurement date is between the prescription's start and end dates. 

If a measurement is not within any prescription's validity period, 
there is a permissivity of 5 days, meaning that a measurement can be linked to a prescription if
it is within 5 days from the start or end date of the prescription.
If this allows a measurement to be linked to multiple prescriptions,
it is linked to the one it is closest to in date. 

If a measurement is not linked to any valid prescription, 
it is excluded from the outuput. 
"""

# Connect to the paper-specific database, load the prescriptions table, and make sure its date values are in datetime format
conn = sqlite3.connect(os.path.join(paper1_directory, "emotional_all_notna.sqlite"))
prescriptions = pd.read_sql_query("SELECT * FROM prescriptions_colclean", conn)
prescriptions['prescription_creation_date'] = pd.to_datetime(prescriptions['prescription_creation_date'])
prescriptions['prescription_validity_end_date'] = pd.to_datetime(prescriptions['prescription_validity_end_date'])

# Merge the measurements and prescriptions data frames on patient ID,
# creating the Cartesian product that needs further date-based filtering
merged = pd.merge(measurements_rowclean, prescriptions, on="patient_id", how="left", suffixes=('_meas', '_presc'))

# To execute date-based filtering: 
# First, define those measurements that are within the range of a prescription. 
# If any measurement can be assigned to a prescription based on this criteria, it will be. 
merged['measurement_in_prescription_range'] = (
    (merged['measurement_date'] >= merged['prescription_creation_date']) &
    (merged['measurement_date'] <= merged['prescription_validity_end_date'])
)
# If after this, a measurement is still not linked to any prescription due to not being in the range of any, 
# it will be linked to the prescription it is closest to, within a 5-day permissivity range. 
# For these out-of-range measurements, first, the distance from the start/end dates of any prescription is calculated. 
merged['days_before_prescription_start'] = (merged['prescription_creation_date'] - merged['measurement_date']).dt.days
merged['days_after_prescription_end'] = (merged['measurement_date'] - merged['prescription_validity_end_date']).dt.days
# After, near-range measurements are defined, 
# as measurements that are NOT within the range of any prescription, 
# AND they are at within 5 days before the start/after the end of any prescription. 
merged['measurement_near_prescription_range'] = (
    (~merged['measurement_in_prescription_range']) &
    (
        ((merged['days_before_prescription_start'] <= 5) & (merged['days_before_prescription_start'] > 0)) |
        ((merged['days_after_prescription_end'] <= 5) & (merged['days_after_prescription_end'] > 0))
    )
)
# After, a distance metric calculation determines how far a given measurement is from a prescription. 
# In-range measurements get a distance metric of 0,
# while out-of-range measurements get the minimum distance to any boundary they are close to. 
merged['measurement_distance_from_prescription_range'] = merged.apply(
    lambda row: 0 if row['measurement_in_prescription_range'] else min(max(row['days_before_prescription_start'], 0), max(row['days_after_prescription_end'], 0)),
    axis=1
)
# After defining the in-range and near-range logics, the database (currently containing Cartesian products) 
# is filtered to keep only in-or near-range measurements. 
# Any measurements not assigned to a prescription is lost. 
measurements_with_metadata = merged[merged['measurement_in_prescription_range'] | merged['measurement_near_prescription_range']].copy()
# In edge cases where multiple prescriptions are linked to a single measurement, only the closest match is kept. 
# This is done by sorting the data frame by patient id, measurement date and distance from range, 
# and if multiple measurement-prescription pairs from the same patient on the same date are found, 
# duplicates are removed and only the row with the smallest distance from range is kept. 
measurements_with_metadata = measurements_with_metadata.sort_values(['patient_id', 'measurement_date', 'measurement_distance_from_prescription_range'])
measurements_with_metadata = measurements_with_metadata.drop_duplicates(['patient_id', 'measurement_date'])

# After filtering the data frame, columns are reordered, and any irrelevant ones, like prescribed supplements, are dropped. 
column_order = [
    'patient_id',
    'medical_record_id',
    'prescription_id',
    'measurement_date',
    'prescription_creation_date',
    'prescription_validity_end_date',
    'prescription_validity_days',
    'method',
    'step',
    'weight_kg',
    'bmi',
    'bmr_kcal',
    'fat_%',
    'vat_%',
    'muscle_%',
    'water_%',
    'measurement_in_prescription_range',
    'days_before_prescription_start',
    'days_after_prescription_end',
    'measurement_near_prescription_range',
    'measurement_distance_from_prescription_range'
]
measurements_with_metadata = measurements_with_metadata[column_order]

# The measurements_with_metadata data frame is saved within the SQL database, and some summary info is printed. 
measurements_with_metadata.to_sql("measurements_with_metadata", conn, if_exists="replace", index=False)

print(f"Measurements are linked to their corresponding prescriptions and medical records. \n"
    f"There are a total of {measurements_with_metadata.shape[0]} measurements "
    f"from {measurements_with_metadata['medical_record_id'].nunique()} medical records "
    f"of {measurements_with_metadata['patient_id'].nunique()} patients.")

Measurements are linked to their corresponding prescriptions and medical records. 
There are a total of 20976 measurements from 1678 medical records of 1664 patients.


#### Add sex, genomics ID and baseline/final weight data to medical records

In an effort to create data frames containing the most possible information in one place, the medical records data frame is completed with the sex (originally stored in Patients) as well as the baseline and final weight data (measurements linked to medical records stored in measurements_with_metadata) of patients. Genomics sample IDs are also fetched for patients that have it available. 

Besides executing these merge operations, the code checks the time passed between baseline and final measurements in each medical record, along with whether the measurements are close to the beginning/end date of the medical record they belong to or not. This helps checking whether the length of the actual followup is similar to that of the medical record or not. 

Any medical record that has no associated measurements is lost here. 

In [8]:
"""
Complete Medical records by adding sex and baseline/final weight data to it. 

Sex and genomics sample IDs are fetched from the Patients table, based on the patient_id.

Baseline and final weight measurements are obtained from the measurements_with_metadata table created in the previous step. 
The logic is the following: 
Measurements are grouped by patient and medical record ID, and the first and last measurements of each group are assigned
to the medical records table as baseline and final measurements, respectively.
Delta weight is calculated as the difference between final and baseline weights, to obtain negative results. 

Measurement dates are added and it is checked if they are within the medical record creation and closing dates.

If a medical record has no measurements linked to it, it is dropped. 

Additionally, the 'days_between_measurements' column is added to calculate the number of days between the baseline and final measurements.

Finally, the columns are reordered to match the desired order.
"""

import pandas as pd
import sqlite3

# Connect to the database, load relevant tables
conn = sqlite3.connect(os.path.join(paper1_directory, "emotional_all_notna.sqlite"))
medical_records = pd.read_sql_query("SELECT * FROM medical_records_colclean", conn)
patients = pd.read_sql_query("SELECT * FROM patients_colclean", conn)
measurements_with_metadata = pd.read_sql_query("SELECT * FROM measurements_with_metadata", conn)

# The following functions complete the original medical records data frame with research-relevant variables.
# First, add the sex variable to medical_records_complete by merging patients' sex into medical_records_complete based on patient_id
"""
Adding sex data and genomics sample IDs to medical records
"""
medical_records_complete = pd.merge(
    medical_records,
    patients[['patient_id', 'sex', 'genomics_sample_id']],
    on='patient_id',
    how='left'
)

# After, add baseline and final measurements to medical_records_complete
# Treat measurements coming from a given medical record as units
# by grouping measurements_with_metadata by patient_id and medical_record_id 
"""
Adding weight data to medical records
"""
grouped_measurements = measurements_with_metadata.groupby(['patient_id', 'medical_record_id'])
# Extract the first (baseline) and last (final) measurement for each group
baseline = grouped_measurements.first().reset_index()
final = grouped_measurements.last().reset_index()
# Insert baseline and final measurements into medical_records_complete
medical_records_complete = pd.merge(
    medical_records_complete,
    baseline[['patient_id', 'medical_record_id', 'measurement_date', 'weight_kg', 'bmi']],
    on=['patient_id', 'medical_record_id'],
    how='left'
)
medical_records_complete = pd.merge(
    medical_records_complete,
    final[['patient_id', 'medical_record_id', 'measurement_date', 'weight_kg', 'bmi']],
    on=['patient_id', 'medical_record_id'],
    how='left',
    suffixes=('_baseline', '_final')
)
# Make sure all dates are in datetime format for further operations, 
# and calculate delta weight and delta BMI values (final - baseline, so the resulting weight loss value is negative)
medical_records_complete['medical_record_creation_date'] = pd.to_datetime(medical_records_complete['medical_record_creation_date'])
medical_records_complete['medical_record_closing_date'] = pd.to_datetime(medical_records_complete['medical_record_closing_date'])
medical_records_complete['measurement_date_baseline'] = pd.to_datetime(medical_records_complete['measurement_date_baseline'])
medical_records_complete['measurement_date_final'] = pd.to_datetime(medical_records_complete['measurement_date_final'])
medical_records_complete['delta_weight_kg'] = medical_records_complete['weight_kg_final'] - medical_records_complete['weight_kg_baseline']
medical_records_complete['delta_bmi'] = medical_records_complete['bmi_final'] - medical_records_complete['bmi_baseline']
# Check if the baseline and final measurements are close to the starting/closing date of the medical record they belong to or not (within a 10-day window). 
# In some cases, the first measurement is recorded weeks after opening the medical record, or the last one is taken long before closing it. 
# In other cases, the medical record's closing date is absent, if this happens, the last measurement will be considered as out of range. 
# This is supposed to help identify cases where the followup has some imperfections. 
window_days = 10
medical_records_complete['baseline_measurement_inrange'] = (
    (medical_records_complete['measurement_date_baseline'] >= 
     medical_records_complete['medical_record_creation_date'] - pd.Timedelta(days=window_days)) &
    (medical_records_complete['measurement_date_baseline'] <=
     medical_records_complete['measurement_date_baseline'] + pd.Timedelta(days=window_days))
)
medical_records_complete['final_measurement_inrange'] = (
    (medical_records_complete['measurement_date_final'] >= 
     medical_records_complete['medical_record_closing_date'] - pd.Timedelta(days=window_days)) &
    (medical_records_complete['measurement_date_final'] <= 
     medical_records_complete['medical_record_closing_date'] + pd.Timedelta(days=window_days))
)
# Add a column that calculates the days passed between baseline and final measurements
# This also helps identify cases where the medical record's duration and the actual followup time are very different
medical_records_complete['days_between_measurements'] = (
    (medical_records_complete['measurement_date_final'] - medical_records_complete['measurement_date_baseline']).dt.days
)

"""
Removing medical records with no associated measurements
"""
# As for some reason (unidentified as of 16Apr25) many medical records have no available measurements associated to them, 
# any such instances are dropped from the data frame. 
medical_records_complete = medical_records_complete.dropna(subset=['weight_kg_baseline', 'weight_kg_final'])

"""
Presenting and saving the output
"""
# Rename and reorder columns for better clarity and interpretability
medical_records_complete = medical_records_complete.rename(columns={
    'measurement_date_baseline': 'baseline_measurement_date',
    'measurement_date_final': 'final_measurement_date',
    'weight_kg_baseline': 'baseline_weight_kg',
    'weight_kg_final': 'final_weight_kg', 
    'bmi_baseline': 'baseline_bmi',
    'bmi_final': 'final_bmi'
})
desired_column_order = [
    'patient_id',
    'medical_record_id',
    'genomics_sample_id',
    'medical_record_creation_date',
    'medical_record_closing_date',
    'intervention_duration_days',
    'baseline_measurement_date',
    'final_measurement_date',
    'days_between_measurements',
    'baseline_measurement_inrange',
    'final_measurement_inrange',
    'birth_date',
    'age',
    'age_when_creating_record',
    'sex',
    'height_m',
    'baseline_weight_kg',
    'final_weight_kg',
    'delta_weight_kg',
    'baseline_bmi',
    'final_bmi',
    'delta_bmi',
    'wc_cm_confirm_time',
    'pnk_method',
    'orders_in_medical_record',
    'dietitian_visits',
    'physical_activity',
    'physical_activity_frequency',
    'physical_inactivity_cause',
    'weight_gain_cause',
    'smoking',
    'medications',
    'hunger',
    'satiety',
    'emotional_eating',
    'emotional_eating_value',
    'quantity_control',
    'impulse_control'
]
medical_records_complete = medical_records_complete[desired_column_order]
# Save the complete medical records to the SQL database, and print a summary statement
medical_records_complete.to_sql("medical_records_complete", conn, if_exists="replace", index=False)
print(f"Medical records table completed with sex and baseline/final weight data. \n" 
      f"There are {len(medical_records_complete)} records available from {medical_records_complete['patient_id'].nunique()} patients.")

Medical records table completed with sex and baseline/final weight data. 
There are 1678 records available from 1664 patients.


#### Create the base input for survival analysis

Here, data frames specifically prepared for survival analysis are created. Time-to-event (days) of achieving 3 different weight loss targets (5-10-15%) in 3 different time frames (40-60-80 days) is analyzed. Relevant demographic, anthropometric and eating behavior variables are added to each analyzed medical record. 

In [15]:
import pandas as pd
import os
import sqlite3
from datetime import timedelta
# Removed: import logging

"""
CONFIGURATION
"""
# Define directories and database paths - paper1_directory should be defined 
# in the first cell of this notebook chapter
input_db_path = os.path.join(paper1_directory, 'emotional_all_notna.sqlite')
input_measurements = "measurements_with_metadata"
input_medical_records = "medical_records_complete"
output_db_path = os.path.join(paper1_directory, 'survival_analysis.sqlite')

# Define analysis parameters
weight_loss_targets = [5, 10, 15]     # Weight loss target percentages
time_windows = [40, 60, 80]       # Time windows (centers) in days
window_span = 10                   # Permissible span around windows (+/- days)

# Define the variables stored in medical_records_complete that are relevant for the analysis. 
# These include basic metadata like patient and record ID,
# basic factors such as age and sex, 
# as well as the emotional and eating behavior variables pivotal to the research question. 
# The list can be amended on demand - 
# for example, right now it does not include medical record creating and closing dates. 
relevant_medical_values = ['patient_id', 'medical_record_id', 'sex', 'age',
                             'height_m', 'baseline_bmi', 'hunger', 'satiety', 'emotional_eating',
                             'emotional_eating_value', 'quantity_control', 'impulse_control']

"""
DATA LOADING & PREPARATION
"""

def load_measurements(connection):
    """
    Load measurements from the measurement_with_metadata table; 
    make sure key values are in the correct format. 
    """
    query = f"SELECT * FROM {input_measurements}"
    measurements = pd.read_sql_query(query, connection)
    measurements['measurement_date'] = pd.to_datetime(measurements['measurement_date'], errors='coerce')
    measurements['weight_kg'] = pd.to_numeric(measurements['weight_kg'], errors='coerce')
    return measurements

def load_medical_records(connection):
    """
    Load medical records from the medical_records_complete table;
    make sure date values are in datetime format. 
    The exact columns to be used are defined in the prepare_patient_data function.
    """
    query = f"SELECT * FROM {input_medical_records}"
    medical_records = pd.read_sql_query(query, connection)
    medical_records['medical_record_creation_date'] = pd.to_datetime(medical_records['medical_record_creation_date'], errors='coerce')
    return medical_records

def prepare_patient_data(measurements, medical_records):
    """
    Filter measurements to only include those from the earliest medical record for each patient.
    Merge measurements with relevant medical record data, including the pivotal eating behavior scores. 
    """
    # Filter measurements to only include those from the first treatment record of each patient
    earliest_records_with_data = measurements.sort_values('measurement_date')\
        .groupby('patient_id')['medical_record_id']\
        .first()\
        .reset_index()
    filtered_measurements = pd.merge(
        measurements,
        earliest_records_with_data,
        on=['patient_id', 'medical_record_id'],
        how='inner'
    )
    # Identify the baseline measurement in each record
    baseline_data = filtered_measurements.sort_values('measurement_date')\
                                       .groupby(['patient_id', 'medical_record_id'])\
                                       .first()\
                                       .reset_index()

    cols_to_select = [col for col in relevant_medical_values if col in medical_records.columns]
    medical_record_data = medical_records[cols_to_select]
    # Merge baseline measurements with relevant medical record data
    prepared_data = pd.merge(
        baseline_data,
        medical_record_data,
        on=['patient_id', 'medical_record_id'],
        how='left' # Keep all baseline data
    )
    return prepared_data, filtered_measurements

"""
CALCULATE WEIGHT LOSS OUTCOMES
"""

def _get_patient_baseline(patient_data, patient_id, medical_record_id):
    """
    Get the baseline data for each patient's corresponding medical record.
    """
    patient_baseline = patient_data[
        (patient_data['patient_id'] == patient_id) &
        (patient_data['medical_record_id'] == medical_record_id)
    ]
    if len(patient_baseline) == 0:
        print(f"WARN: No baseline data found for patient {patient_id}, record {medical_record_id}. Skipping.")
        return None
    return patient_baseline.iloc[0]

def _check_target_achievement(measurements_within_window, baseline_weight, weight_loss_target):
    """
    Check if the weight loss target was achieved in some of the given measurements.
    """
    # Set default to False/None
    target_achieved = False
    first_success_measurement = None
    # Calculate weight loss percentage for each measurement in the window, 
    # and check if it meets the target
    for _, row in measurements_within_window.iterrows():
        current_weight = row['weight_kg']
        if baseline_weight is not None and baseline_weight > 0:
            current_weight_loss = ((baseline_weight - current_weight) / baseline_weight) * 100
            if round(current_weight_loss, 2) >= weight_loss_target:
                target_achieved = True
                first_success_measurement = row
                break # Stop at the first success; if that is not identified, target_achieved remains False as by default
    return target_achieved, first_success_measurement

def _determine_final_measurement(target_achieved, first_success_row, measurements_around_cutoff,
                                measurements_within_window, baseline_date, window_center):
    """
    Determine the final measurement based on success or censoring (ie. completion without success) rules.
    """
    # Set final measurement to None by default
    final_measurement = None
    # Set target final date based on the given time window
    target_date = baseline_date + timedelta(days=window_center)
    # If weight loss target was achieved at any point of the followup time window,
    # use the first success measurement as the final measurement.  
    if target_achieved:
        final_measurement = first_success_row
    # In case of no success, the date closest to the target date is used as the final measurement. 
    elif not measurements_around_cutoff.empty:
        measurements_around_cutoff = measurements_around_cutoff.copy()
        measurements_around_cutoff['distance_to_center'] = abs(
            (measurements_around_cutoff['measurement_date'] - target_date).dt.days
        )
        closest_measurement_idx = measurements_around_cutoff['distance_to_center'].idxmin()
        final_measurement = measurements_around_cutoff.loc[closest_measurement_idx]
    # In case of no success nor completion (delayed dropout), use the last available measurement as the final measurement
    elif not measurements_within_window.empty:
        final_measurement = measurements_within_window.sort_values('measurement_date').iloc[-1]
    # Else: Instant dropout, final_measurement remains None, 
    # and is set to the baseline measurementin the calculate_outcome_metrics function.
    return final_measurement

def _calculate_outcome_metrics(baseline_row, final_measurement_row):
    """
    Calculate follow-up lenght and weight loss percentage based on baseline and final measurement.
    """
    # Identify the baseline measurement
    baseline_date = baseline_row['measurement_date']
    baseline_weight = baseline_row['weight_kg']
    # In patients that have at least one followup measurement, identify the end date and final weight, 
    # to calculate followup length and weight loss in kg and %
    if final_measurement_row is not None:
        end_date = final_measurement_row['measurement_date']
        final_weight = final_measurement_row['weight_kg']
        followup_period = (end_date - baseline_date).days
        weight_loss_kg = baseline_weight - final_weight
        weight_loss_pct = ((baseline_weight - final_weight) / baseline_weight) * 100
    # In patients that have no followup measurement (instant dropouts), 
    # the end date and final weight are set to the baseline values, 
    # and followup length and weight loss are set to 0. 
    else: 
        end_date = baseline_date
        final_weight = baseline_weight
        followup_period = 0
        weight_loss_kg = 0
        weight_loss_pct = 0
    return {
        'end_date': end_date,
        'final_weight': final_weight,
        'followup_period': followup_period,
        'weight_loss_kg': weight_loss_kg,
        'weight_loss_pct': round(weight_loss_pct, 2)
    }

"""
CORE ANALYSIS FUNCTION
"""

def calculate_weight_loss_outcome(patient_data, filtered_measurements, weight_loss_target, window_center, window_span):
    """
    Calculate weight loss outcomes for each patient in a survival analysis-ready format. 
    """
    # Initialize an empty list to store results, and group measurements by patient and medical record ID
    results = []
    grouped_measurements = filtered_measurements.groupby(['patient_id', 'medical_record_id'])
    # Iterate through each group within measurements. 
    for (patient_id, medical_record_id), group in grouped_measurements:
        # 1. Identify baseline measurement date and weight
        baseline_row = _get_patient_baseline(patient_data, patient_id, medical_record_id)
        if baseline_row is None: continue
        baseline_date = baseline_row['measurement_date']
        baseline_weight = baseline_row['weight_kg']
        # 2. Define observation time windows and group measurements within the defined window
        # Calculations are done for both the complete observation period, 
        # as well as the period strictry around the cutoff date, within the defined permissivity window. 
        min_window_date = baseline_date + timedelta(days=(window_center - window_span))
        max_window_date = baseline_date + timedelta(days=(window_center + window_span))
        measurements_within_window = group[
            (group['measurement_date'] > baseline_date) &
            (group['measurement_date'] <= max_window_date)
        ].sort_values('measurement_date')
        measurements_around_cutoff = group[
            (group['measurement_date'] >= min_window_date) &
            (group['measurement_date'] <= max_window_date)
        ]
        # 3. Check whether target weight loss was achieved in the defined time window
        target_achieved, first_success_row = _check_target_achievement(
            measurements_within_window, baseline_weight, weight_loss_target
        )
        # 4. Identify the last measurement date within the time window,
        # whether based on target achievment or followup completion
        final_measurement_row = _determine_final_measurement(
            target_achieved, first_success_row, measurements_around_cutoff,
            measurements_within_window, baseline_date, window_center
        )
        # 5. Check for dropout status - instant dropouts are those who have no second measurement, 
        # while delayed dropouts are those who have not reached target, 
        # and their final measurement is before the cutoff window. 
        is_instant_dropout = final_measurement_row is None
        is_delayed_dropout = (not target_achieved and
                              final_measurement_row is not None and
                              final_measurement_row['measurement_date'] < min_window_date)
        dropout = is_instant_dropout or is_delayed_dropout
        success = target_achieved
        # 6. Calculate metrics like final date and weight, followup length and weight lost. 
        outcome_metrics = _calculate_outcome_metrics(baseline_row, final_measurement_row)

        """ARE WE GOING TO MODIFY AVG CALCS?"""

        # # 7. 
        # # --- NEW: Calculate metrics based *always* on the last measurement within the window ---
        # actual_last_measurement_row = None
        # if not measurements_within_window.empty:
        #     actual_last_measurement_row = measurements_within_window.iloc[-1]

        # # Use the same helper, but pass the actual last measurement row
        # actual_end_metrics = _calculate_outcome_metrics(baseline_row, actual_last_measurement_row)
        # actual_wl_pct_at_window_end = actual_end_metrics['weight_loss_pct']
        # # --- End NEW ---



        # 8. Assemble the result - this is where the output tables' columns are defined. 
        # If additional variables are inserted at an earlier part of the code, 
        # they need to be mentioned here as well. 
        result = {
            'patient_id': patient_id,
            'medical_record_id': medical_record_id,
            'baseline_date': baseline_date,
            'end_date': outcome_metrics['end_date'],
            'followup_period': outcome_metrics['followup_period'],
            'baseline_weight': baseline_weight,
            'final_weight': outcome_metrics['final_weight'],
            'weight_loss_kg': outcome_metrics['weight_loss_kg'],
            'weight_loss_pct': outcome_metrics['weight_loss_pct'],
            # 
            f'{weight_loss_target}pct_achieved': success,
            'dropout': dropout,
            # Add baseline characteristics safely using .get()
            'sex': baseline_row.get('sex'),
            'age': baseline_row.get('age'),
            'height_m': baseline_row.get('height_m'),
            'baseline_bmi': baseline_row.get('bmi'),
            'hunger': baseline_row.get('hunger'),
            'satiety': baseline_row.get('satiety'),
            'emotional_eating': baseline_row.get('emotional_eating'),
            'emotional_eating_value': baseline_row.get('emotional_eating_value'),
            'quantity_control': baseline_row.get('quantity_control'),
            'impulse_control': baseline_row.get('impulse_control')
        }
        results.append(result)
    return pd.DataFrame(results)


"""
MAIN ORCHESTRATION FUNCTION
"""

def generate_survival_analysis_datasets(input_connection, output_connection, weight_loss_targets, time_windows, window_span=10):
    """
    The main function to orchestrate the survival analysis process, calling all previously defined functions in an organized manner. 
    Generate survival analysis datasets for multiple weight loss targets and observation time windows.
    Targets and timeframes are defined in the configuration section at the beginning of the code module.
    Save data to a separate SQLite database. 
    """
    # 1. Load and prepare input data
    measurements = load_measurements(input_connection)
    medical_records = load_medical_records(input_connection)
    patient_data, filtered_measurements = prepare_patient_data(measurements, medical_records)
    if patient_data.empty:
        print("ERROR: Prepared patient data is empty. Cannot proceed.")
        return {}, pd.DataFrame()
    # 2. Calculate weight loss outcomes for each target-timeframe combination. 
    # Targets and timeframes are defined in the config section of the script. 
    # Initialize a results dictionary and a list for summary statistics. 
    results = {}
    summary_list = []
    for window in sorted(time_windows):
        for target in sorted(weight_loss_targets):
            # Name each instance accordingly, where sa stands for survival analysis, 
            # and the numbers indicate the time window and target percentage.
            name = f"sa_{window}d_{target}p"
            print(f"--- Processing: {name} ---") # Minimal progress indication
            result_df = calculate_weight_loss_outcome(
                patient_data,
                filtered_measurements,
                target,
                window,
                window_span # Defined in config - the permissivity window around the followup cutoff time
            )
            results[name] = result_df
            # Add the calculated instances to the summary statistics list. 
            if not result_df.empty:
                summary_row = {
                    'analysis_name': name,
                    'weight_loss_target': target,
                    'time_window': window,
                    'total_patients': len(result_df),
                    'achieved_target': int(result_df[f'{target}pct_achieved'].sum()),
                    'dropout_count': int(result_df['dropout'].sum()),
                    'avg_weight_loss_pct': result_df['weight_loss_pct'].mean() if not result_df['weight_loss_pct'].isnull().all() else 0
                }
                summary_list.append(summary_row)
            else:
                 print(f"WARN: No results generated for {name}. Skipping summary entry.")
    # Turn the summary statistics list into a data frame
    summary = pd.DataFrame(summary_list)

    # 3. Save the analysis results (9 tables by default) to the SQLite database defined in the config section
    print(f"--- Saving results to output database: {output_db_path} ---")
    # Save individual tables
    for name, df in results.items():
        print(f"Saving table: {name} ({len(df)} rows)")
        df.to_sql(name, output_connection, if_exists='replace', index=False)
    # Save the summary stats table in the database as well
    print(f"Saving summary table: survival_analysis_summary ({len(summary)} rows)")
    summary.to_sql('survival_analysis_summary', output_connection, if_exists='replace', index=False)
    output_connection.commit() # Ensure changes are saved
    print("--- All results saved successfully ---")
    return results, summary

"""
EXECUTION BLOCK
"""

"""
This part of the code calls all the functions and executes the code. 
Currently it has a lot of debug messages and error handling, which might be an overkill, 
but overall, it should not affect transparency of the code.
"""

if __name__ == "__main__":
    print("========== Starting Survival Analysis Script ==========")
    # By default, connections are set to None, and will be established in the try block.
    input_conn = None
    output_conn = None
    try:
        # Connect to in-and output databases
        print(f"Connecting to input database: {input_db_path}")
        if not os.path.exists(input_db_path):
             raise FileNotFoundError(f"Input database not found at {input_db_path}")
        input_conn = sqlite3.connect(input_db_path)
        print(f"Connecting to output database: {output_db_path}")
        output_conn = sqlite3.connect(output_db_path)
        # Run the main analysis function
        results, summary = generate_survival_analysis_datasets(
            input_conn,
            output_conn,
            weight_loss_targets,
            time_windows,
            window_span
        )
        # Display summary if successful
        if not summary.empty:
            print("\n--- Survival Analysis Summary ---")
            print(summary.to_string()) # Use print for console display
            print("--- End Summary ---")
        else:
            print("WARN: Analysis completed, but the summary table is empty.")
        print(f"Analysis data successfully generated and saved to {output_db_path}")

    # Minimal error handling for critical failures
    except FileNotFoundError as e:
        print(f"ERROR: Database file not found - {e}")
    except sqlite3.Error as e:
        print(f"ERROR: SQLite database error - {e}")
    except ValueError as e:
        print(f"ERROR: Data processing error - {e}")
    except Exception as e:
        print(f"ERROR: An unexpected error occurred - {e}")
        # Consider adding traceback for debugging complex errors:
        # import traceback
        # print(traceback.format_exc())
    finally:
        # Ensure connections are closed
        print("Closing database connections...")
        if input_conn:
            input_conn.close()
        if output_conn:
            output_conn.close()
        print("========== Survival Analysis Script Finished ==========")


Connecting to input database: C:\Users\Felhasználó\Desktop\Projects\PNK_DB2\paper1_emotional\emotional_all_notna.sqlite
Connecting to output database: C:\Users\Felhasználó\Desktop\Projects\PNK_DB2\paper1_emotional\survival_analysis.sqlite
--- Processing: sa_40d_5p ---
--- Processing: sa_40d_10p ---
--- Processing: sa_40d_15p ---
--- Processing: sa_60d_5p ---
--- Processing: sa_60d_10p ---
--- Processing: sa_60d_15p ---
--- Processing: sa_80d_5p ---
--- Processing: sa_80d_10p ---
--- Processing: sa_80d_15p ---
--- Saving results to output database: C:\Users\Felhasználó\Desktop\Projects\PNK_DB2\paper1_emotional\survival_analysis.sqlite ---
Saving table: sa_40d_5p (1664 rows)
Saving table: sa_40d_10p (1664 rows)
Saving table: sa_40d_15p (1664 rows)
Saving table: sa_60d_5p (1664 rows)
Saving table: sa_60d_10p (1664 rows)
Saving table: sa_60d_15p (1664 rows)
Saving table: sa_80d_5p (1664 rows)
Saving table: sa_80d_10p (1664 rows)
Saving table: sa_80d_15p (1664 rows)
Saving summary table: su

#### Reconsidered input structure for SA 18-19 Apr v1 - this is still not perfect

In [None]:
import pandas as pd
import os
import sqlite3
from datetime import timedelta
import numpy as np # Import numpy for NaN

"""
CONFIGURATION
"""
# Define directories and database paths - paper1_directory should be defined
# in the first cell of this notebook chapter
input_db_path = os.path.join(paper1_directory, 'emotional_all_notna.sqlite')
input_measurements = "measurements_with_metadata"
input_medical_records = "medical_records_complete"
# Modified: Output database path remains the same, but will contain a different table structure
output_db_path = os.path.join(paper1_directory, 'shukishukishuuu.sqlite')
output_table_name = "shukishukishuuu" # Define the name for the single wide table

# Define analysis parameters
weight_loss_targets = [5, 10, 15]     # Weight loss target percentages
time_windows = [40, 60, 80]       # Time windows (centers) in days
window_span = 10                   # Permissible span around windows (+/- days)

# Define the variables stored in medical_records_complete that are relevant for the analysis.
# These include basic metadata like patient and record ID,
# basic factors such as age and sex,
# as well as the emotional and eating behavior variables pivotal to the research question.
# The list can be amended on demand -
# for example, right now it does not include medical record creating and closing dates.
relevant_medical_values = ['patient_id', 'medical_record_id', 'sex', 'age',
                             'height_m', 'baseline_bmi', 'hunger', 'satiety', 'emotional_eating',
                             'emotional_eating_value', 'quantity_control', 'impulse_control', 'genomics_sample_id']

"""
DATA LOADING & PREPARATION
"""

def load_measurements(connection):
    """
    Load measurements from the measurement_with_metadata table;
    make sure key values are in the correct format.
    """
    query = f"SELECT * FROM {input_measurements}"
    measurements = pd.read_sql_query(query, connection)
    measurements['measurement_date'] = pd.to_datetime(measurements['measurement_date'], errors='coerce')
    measurements['weight_kg'] = pd.to_numeric(measurements['weight_kg'], errors='coerce')
    # Ensure sorting for consistent 'first'/'last' operations later
    measurements = measurements.sort_values(['patient_id', 'medical_record_id', 'measurement_date'])
    return measurements

def load_medical_records(connection):
    """
    Load medical records from the medical_records_complete table;
    make sure date values are in datetime format.
    The exact columns to be used are defined in the prepare_patient_data function.
    """
    query = f"SELECT * FROM {input_medical_records}"
    medical_records = pd.read_sql_query(query, connection)
    # Ensure relevant date columns are datetime
    for col in ['medical_record_creation_date', 'baseline_measurement_date', 'final_measurement_date']:
         if col in medical_records.columns:
             medical_records[col] = pd.to_datetime(medical_records[col], errors='coerce')
    return medical_records

def prepare_patient_data(measurements, medical_records):
    """
    Filter measurements to only include those from the earliest medical record for each patient.
    Merge measurements with relevant medical record data, including the pivotal eating behavior scores.
    """
    # Filter measurements to only include those from the first treatment record of each patient
    # Ensure measurements are sorted before grouping
    measurements = measurements.sort_values(['patient_id', 'measurement_date'])
    earliest_records_with_data = measurements.groupby('patient_id')['medical_record_id'].first().reset_index()

    filtered_measurements = pd.merge(
        measurements,
        earliest_records_with_data,
        on=['patient_id', 'medical_record_id'],
        how='inner'
    )
    # Ensure filtered_measurements are sorted for baseline identification
    filtered_measurements = filtered_measurements.sort_values(['patient_id', 'medical_record_id', 'measurement_date'])

    # Identify the baseline measurement row for each record (the first measurement in the filtered set)
    baseline_data_rows = filtered_measurements.groupby(['patient_id', 'medical_record_id']).first().reset_index()

    # Select only relevant columns from medical_records to merge
    cols_to_select = [col for col in relevant_medical_values if col in medical_records.columns]
    medical_record_subset = medical_records[cols_to_select]

    # Merge baseline measurement info with the selected medical record data
    # Use baseline_data_rows which contains the actual first measurement details
    prepared_data = pd.merge(
        baseline_data_rows[['patient_id', 'medical_record_id', 'measurement_date', 'weight_kg']], # Get baseline date/weight from actual first measurement
        medical_record_subset,
        on=['patient_id', 'medical_record_id'],
        how='left' # Keep all baseline measurements
    )
    # Rename columns for clarity before returning
    prepared_data = prepared_data.rename(columns={'measurement_date': 'baseline_date', 'weight_kg': 'baseline_weight_kg'})

    # Add baseline_bmi from medical_records if available and not already present from measurement merge
    if 'baseline_bmi' in medical_record_subset.columns and 'baseline_bmi' not in prepared_data.columns:
         prepared_data = pd.merge(
              prepared_data,
              medical_record_subset[['patient_id', 'medical_record_id', 'baseline_bmi']],
              on=['patient_id', 'medical_record_id'],
              how='left'
         )

    return prepared_data, filtered_measurements


"""
MODIFIED: CALCULATE WEIGHT LOSS OUTCOMES FOR WIDE TABLE
"""

def _get_patient_baseline(patient_data, patient_id, medical_record_id):
    """
    Get the prepared baseline data for a patient's specific medical record.
    (Function remains largely the same, but operates on the prepared_data structure)
    """
    patient_baseline = patient_data[
        (patient_data['patient_id'] == patient_id) &
        (patient_data['medical_record_id'] == medical_record_id)
    ]
    if patient_baseline.empty:
        print(f"WARN: No baseline data found for patient {patient_id}, record {medical_record_id}. Skipping.")
        return None
    # Ensure we return a Series for consistent access
    return patient_baseline.iloc[0]

def _calculate_wl_metrics(baseline_weight, current_weight):
    """ Helper to calculate weight loss kg and % """
    if pd.isna(baseline_weight) or pd.isna(current_weight) or baseline_weight == 0:
        return np.nan, np.nan
    wl_kg = baseline_weight - current_weight
    wl_pct = (wl_kg / baseline_weight) * 100
    return wl_kg, round(wl_pct, 2)

# Removed _check_target_achievement, _determine_final_measurement, _calculate_outcome_metrics
# Their logic will be integrated into the main calculation function.

def calculate_wide_patient_outcomes(prepared_patient_data, filtered_measurements, weight_loss_targets, time_windows, window_span):
    """
    Calculate all required outcomes (fixed-time and time-to-event) for each patient
    and return a list of dictionaries, each representing a row in the wide table.
    """
    results_list = []
    # Group all measurements to only get the relevant (earliest) medical record per patient
    grouped_measurements = filtered_measurements.groupby(['patient_id', 'medical_record_id'])

    for (patient_id, medical_record_id), group in grouped_measurements:
        # 1. Get Baseline Info
        baseline_info = _get_patient_baseline(prepared_patient_data, patient_id, medical_record_id)
        if baseline_info is None:
            continue

        baseline_date = baseline_info['baseline_date']
        baseline_weight = baseline_info['baseline_weight_kg']

        # Initialize result dictionary with baseline info
        result = {
            'patient_ID': patient_id, 
            'medical_record_ID': medical_record_id,
            'baseline_date': baseline_date,
            'baseline_weight_kg': baseline_weight,
            # Add other baseline characteristics safely using .get() or direct access
            'sex': baseline_info.get('sex'),
            'age': baseline_info.get('age'),
            'height_m': baseline_info.get('height_m'),
            'baseline_bmi': baseline_info.get('baseline_bmi'), # Get BMI from prepared data
            'hunger': baseline_info.get('hunger'),
            'satiety': baseline_info.get('satiety'),
            'emotional_eating': baseline_info.get('emotional_eating'),
            'emotional_eating_value': baseline_info.get('emotional_eating_value'),
            'quantity_control': baseline_info.get('quantity_control'),
            'impulse_control': baseline_info.get('impulse_control')
        }

        # Get all measurements *after* baseline for this group
        followup_measurements = group[group['measurement_date'] > baseline_date].sort_values('measurement_date')

        # 2. Calculate Overall Follow-up Metrics
        if not followup_measurements.empty:
            last_measurement = followup_measurements.iloc[-1]
            result['last_aval_date'] = last_measurement['measurement_date']
            result['total_followup_days'] = (last_measurement['measurement_date'] - baseline_date).days
            result['last_aval_weight_kg'] = last_measurement['weight_kg']
            wl_kg, wl_pct = _calculate_wl_metrics(baseline_weight, last_measurement['weight_kg'])
            result['total_wl_kg'] = wl_kg
            result['total_wl_%'] = wl_pct
        else:
            # Handle instant dropouts (only baseline measurement exists)
            result['last_aval_date'] = baseline_date
            result['total_followup_days'] = 0
            result['last_aval_weight_kg'] = baseline_weight
            result['total_wl_kg'] = 0.0
            result['total_wl_%'] = 0.0

        # 3. Calculate Fixed-Timepoint Metrics (for each time window)
        for window_center in time_windows:
            min_window_date = baseline_date + timedelta(days=(window_center - window_span))
            max_window_date = baseline_date + timedelta(days=(window_center + window_span))
            target_date = baseline_date + timedelta(days=window_center)

            # Find measurements around the cutoff window
            measurements_around_cutoff = followup_measurements[
                (followup_measurements['measurement_date'] >= min_window_date) &
                (followup_measurements['measurement_date'] <= max_window_date)
            ]

            measurement_at_window = None
            is_dropout_at_window = True # Assume dropout unless proven otherwise

            if not measurements_around_cutoff.empty:
                # Find measurement closest to the window center
                measurements_around_cutoff = measurements_around_cutoff.copy()
                measurements_around_cutoff['distance_to_center'] = abs(
                    (measurements_around_cutoff['measurement_date'] - target_date).dt.days
                )
                closest_measurement_idx = measurements_around_cutoff['distance_to_center'].idxmin()
                measurement_at_window = measurements_around_cutoff.loc[closest_measurement_idx]
                is_dropout_at_window = False # Measurement found within/around window
            elif not followup_measurements.empty:
                 # No measurement in cutoff window, check if *any* followup exists before the window
                 last_followup_before_window = followup_measurements[followup_measurements['measurement_date'] < min_window_date]
                 if not last_followup_before_window.empty:
                      # Use the latest measurement before the window started
                      measurement_at_window = last_followup_before_window.iloc[-1]
                      # Still considered dropout *for this window* as they didn't reach it
                      is_dropout_at_window = True
                 else:
                      """!!!CHECK this logic, it might get tricky!!!"""
                      # Followup exists, but only *after* the window (unlikely but possible)
                      # Treat as dropout for this window, no relevant measurement
                      measurement_at_window = None
                      is_dropout_at_window = True
            else:
                 # Instant dropout (no followup measurements at all)
                 measurement_at_window = None
                 is_dropout_at_window = True


            # Populate results for this time window
            prefix = f"{window_center}d"
            if measurement_at_window is not None:
                result[f'{prefix}_weight_kg'] = measurement_at_window['weight_kg']
                wl_kg, wl_pct = _calculate_wl_metrics(baseline_weight, measurement_at_window['weight_kg'])
                result[f'wl_{prefix}_kg'] = wl_kg
                result[f'wl_{prefix}_%'] = wl_pct
                result[f'{prefix}_date'] = measurement_at_window['measurement_date']
                result[f'days_to_{prefix}_measurement'] = (measurement_at_window['measurement_date'] - baseline_date).days
            else:
                # No relevant measurement found for this window
                result[f'{prefix}_weight_kg'] = np.nan
                result[f'wl_{prefix}_kg'] = np.nan
                result[f'wl_{prefix}_%'] = np.nan
                result[f'{prefix}_date'] = pd.NaT
                result[f'days_to_{prefix}_measurement'] = np.nan

            result[f'{prefix}_dropout'] = is_dropout_at_window


        # 4. Calculate Time-to-Event Metrics (for each weight loss target)
        for target in weight_loss_targets:
            target_achieved = False
            first_success_measurement = None
            actual_wl_at_success = np.nan

            # Check all followup measurements for the first success
            for _, row in followup_measurements.iterrows():
                current_weight = row['weight_kg']
                if baseline_weight is not None and baseline_weight > 0:
                    current_weight_loss_pct = ((baseline_weight - current_weight) / baseline_weight) * 100
                    if round(current_weight_loss_pct, 2) >= target:
                        target_achieved = True
                        first_success_measurement = row
                        actual_wl_at_success = round(current_weight_loss_pct, 2)
                        break # Stop at the first success

            # Populate results for this target
            prefix = f"{target}%_wl"
            result[f'{prefix}_achieved'] = target_achieved
            if target_achieved and first_success_measurement is not None:
                result[f'{prefix}_%'] = actual_wl_at_success
                result[f'{prefix}_date'] = first_success_measurement['measurement_date']
                result[f'days_to_{prefix}'] = (first_success_measurement['measurement_date'] - baseline_date).days
            else:
                result[f'{prefix}_%'] = np.nan
                result[f'{prefix}_date'] = pd.NaT
                result[f'days_to_{prefix}'] = np.nan # Or perhaps total_followup_days if censored? Check analysis plan needs. NaN is safer.

        results_list.append(result)

    return pd.DataFrame(results_list)


"""
MODIFIED: MAIN ORCHESTRATION FUNCTION FOR WIDE TABLE
"""

def generate_wide_analysis_dataset(input_connection, output_connection, weight_loss_targets, time_windows, window_span=10):
    """
    Orchestrates the process to generate the single wide survival analysis dataset.
    Loads data, prepares patient baseline info, calculates all outcomes per patient,
    and saves the resulting wide DataFrame to the output database.
    """
    # 1. Load and prepare input data
    print("Loading measurements...")
    measurements = load_measurements(input_connection)
    print("Loading medical records...")
    medical_records = load_medical_records(input_connection)
    print("Preparing patient data...")
    prepared_data, filtered_measurements = prepare_patient_data(measurements, medical_records)

    if prepared_data.empty:
        print("ERROR: Prepared patient data is empty. Cannot proceed.")
        return pd.DataFrame() # Return empty DataFrame

    # 2. Calculate wide outcomes for all patients
    print("Calculating wide outcomes for all patients...")
    wide_results_df = calculate_wide_patient_outcomes(
        prepared_data,
        filtered_measurements,
        weight_loss_targets,
        time_windows,
        window_span
    )

    # 3. Save the single wide table
    if not wide_results_df.empty:
        print(f"--- Saving results to output database: {output_db_path} ---")
        print(f"Saving table: {output_table_name} ({len(wide_results_df)} rows)")
        wide_results_df.to_sql(output_table_name, output_connection, if_exists='replace', index=False)
        output_connection.commit() # Ensure changes are saved
        print("--- Wide table saved successfully ---")
    else:
        print("WARN: No results generated. Output table will be empty or not created.")

    # Removed the old summary logic based on multiple tables
    # A new summary could be generated from wide_results_df if needed

    return wide_results_df # Return the generated DataFrame

"""
EXECUTION BLOCK (Modified to call the new main function)
"""

if __name__ == "__main__":
    print("========== Generating Survival Analysis Input Dataset ==========")
    input_conn = None
    output_conn = None
    try:
        # Connect to in-and output databases
        print(f"Connecting to input database: {input_db_path}")
        if not os.path.exists(input_db_path):
             raise FileNotFoundError(f"Input database not found at {input_db_path}")
        input_conn = sqlite3.connect(input_db_path)

        print(f"Connecting to output database: {output_db_path}")
        output_conn = sqlite3.connect(output_db_path)

        # Run the new main analysis function
        wide_df = generate_wide_analysis_dataset(
            input_conn,
            output_conn,
            weight_loss_targets,
            time_windows,
            window_span
        )

        # Display basic info if successful
        if not wide_df.empty:
            print("\n--- Survival Analysis Input Table Generation Summary ---")
            print(f"Generated table '{output_table_name}' with {len(wide_df)} rows and {len(wide_df.columns)} columns.")
            # print(wide_df.head().to_string()) # Optionally print head
            print("--- End Summary ---")
        else:
            print("WARN: Analysis completed, but the resulting DataFrame is empty.")

        print(f"Analysis data saved to {output_db_path}")

    # Error handling remains the same
    except FileNotFoundError as e:
        print(f"ERROR: Database file not found - {e}")
    except sqlite3.Error as e:
        print(f"ERROR: SQLite database error - {e}")
    except ValueError as e:
        print(f"ERROR: Data processing error - {e}")
    except Exception as e:
        print(f"ERROR: An unexpected error occurred - {e}")
        # Consider adding traceback for debugging complex errors:
        # import traceback
        # print(traceback.format_exc())
    finally:
        # Ensure connections are closed
        print("Closing database connections...")
        if input_conn:
            input_conn.close()
        if output_conn:
            output_conn.close()
        print("========== Survival Analysis Input Data GenerationFinished ==========")

Connecting to input database: C:\Users\Felhasználó\Desktop\Projects\PNK_DB2\paper1_emotional\emotional_all_notna.sqlite
Connecting to output database: C:\Users\Felhasználó\Desktop\Projects\PNK_DB2\paper1_emotional\shukishukishuuu.sqlite
Loading measurements...
Loading medical records...
Preparing patient data...
Calculating wide outcomes for all patients...
--- Saving results to output database: C:\Users\Felhasználó\Desktop\Projects\PNK_DB2\paper1_emotional\shukishukishuuu.sqlite ---
Saving table: shukishukishuuu (1664 rows)
--- Wide table saved successfully ---

--- Wide Analysis Table Generation Summary ---
Generated table 'shukishukishuuu' with 1664 rows and 49 columns.
--- End Summary ---
Analysis data successfully generated and saved to C:\Users\Felhasználó\Desktop\Projects\PNK_DB2\paper1_emotional\shukishukishuuu.sqlite
Closing database connections...


#### Reconsidered input structure for SA 18-19 Apr v2 - revise this well, I think this is the one!

In [None]:
# This code should be placed in a new cell in your Jupyter Notebook.
# It incorporates the requested changes into the previous wide-format script.

import pandas as pd
import os
import sqlite3
from datetime import timedelta
import numpy as np # Import numpy for NaN

"""
CONFIGURATION
"""
# Define directories and database paths - paper1_directory should be defined
# in the first cell of this notebook chapter
input_db_path = os.path.join(paper1_directory, 'emotional_all_notna.sqlite')
input_measurements = "measurements_with_metadata"
input_medical_records = "medical_records_complete"
# Modified: Output database path remains the same, but will contain a different table structure
output_db_path = os.path.join(paper1_directory, 'survival_analysis_wide_v2.sqlite') # Changed filename for new version
output_table_name = "survival_analysis_wide_v2" # Changed table name for new version

# Define analysis parameters
weight_loss_targets = [5, 10, 15]     # Weight loss target percentages
time_windows = [40, 60, 80]       # Time windows (centers) in days
window_span = 10                   # Permissible span around windows (+/- days)

# Define the variables stored in medical_records_complete that are relevant for the analysis.
# These include basic metadata like patient and record ID,
# basic factors such as age and sex,
# as well as the emotional and eating behavior variables pivotal to the research question.
# The list can be amended on demand -
# for example, right now it does not include medical record creating and closing dates.
# --- MODIFIED: Added 'dietitian_visits' ---
relevant_medical_values = ['patient_id', 'medical_record_id', 'dietitian_visits', 'sex', 'age',
                             'height_m', 'baseline_bmi', 'hunger', 'satiety', 'emotional_eating',
                             'emotional_eating_value', 'quantity_control', 'impulse_control', 'weight_gain_cause', 'genomics_sample_id']

# --- NEW: Define the desired final column order ---
# This list determines the order of columns in the final output table.
# It includes baseline info, overall followup, adherence proxies,
# fixed-time metrics, time-to-event metrics, and finally confounders/predictors.
FINAL_COLUMN_ORDER = [
    # IDs
    'patient_ID', 'medical_record_ID',
    # Followup period info and adherence proxies
    'baseline_date', 'last_aval_date', 'total_followup_days', 'nr_visits', 'nr_total_measurements', 'avg_days_between_measurements',
    # Total weight change
    'baseline_weight_kg', 'last_aval_weight_kg', 'total_wl_kg', 'total_wl_%', 'baseline_bmi', 'final_bmi', 'bmi_reduction',
    # Fixed-Timepoint Analysis (Dynamically generated columns will be inserted here by logic below)
    # Time-to-Event Analysis (Dynamically generated columns will be inserted here by logic below)
    # Confounders / Predictors (from medical records)
    'sex', 'age', 'height_m', 'hunger', 'satiety',
    'emotional_eating', 'emotional_eating_value', 'quantity_control', 'impulse_control', 'weight_gain_cause', 'genomics_sample_id'
]

# --- NEW: Dynamically insert fixed-time and time-to-event columns into FINAL_COLUMN_ORDER ---
fixed_time_cols = []
for window in time_windows:
    prefix = f"{window}d"
    fixed_time_cols.extend([
        f'{prefix}_weight_kg', f'wl_{prefix}_kg', f'wl_{prefix}_%',
        f'{prefix}_date', f'days_to_{prefix}_measurement', f'{prefix}_dropout'
    ])

time_to_event_cols = []
for target in weight_loss_targets:
    prefix = f"{target}%_wl"
    time_to_event_cols.extend([
        f'{prefix}_achieved', f'{prefix}_%', f'{prefix}_date', f'days_to_{prefix}'
    ])

# Find the insertion point (after adherence proxies)
insert_point = FINAL_COLUMN_ORDER.index('total_wl_%') + 1
# Insert the dynamic columns
FINAL_COLUMN_ORDER[insert_point:insert_point] = fixed_time_cols + time_to_event_cols


"""
DATA LOADING & PREPARATION (No changes needed here, kept for context)
"""

def load_measurements(connection):
    """
    Load measurements from the measurement_with_metadata table;
    make sure key values are in the correct format.
    """
    query = f"SELECT * FROM {input_measurements}"
    measurements = pd.read_sql_query(query, connection)
    measurements['measurement_date'] = pd.to_datetime(measurements['measurement_date'], errors='coerce')
    measurements['weight_kg'] = pd.to_numeric(measurements['weight_kg'], errors='coerce')
    # Ensure sorting for consistent 'first'/'last' operations later
    measurements = measurements.sort_values(['patient_id', 'medical_record_id', 'measurement_date'])
    return measurements

def load_medical_records(connection):
    """
    Load medical records from the medical_records_complete table;
    make sure date values are in datetime format.
    The exact columns to be used are defined in the prepare_patient_data function.
    """
    query = f"SELECT * FROM {input_medical_records}"
    medical_records = pd.read_sql_query(query, connection)
    # Ensure relevant date columns are datetime
    for col in ['medical_record_creation_date', 'baseline_measurement_date', 'final_measurement_date']:
         if col in medical_records.columns:
             medical_records[col] = pd.to_datetime(medical_records[col], errors='coerce')
    # --- NEW: Ensure dietitian_visits is numeric ---
    if 'dietitian_visits' in medical_records.columns:
        medical_records['dietitian_visits'] = pd.to_numeric(medical_records['dietitian_visits'], errors='coerce')
    return medical_records

def prepare_patient_data(measurements, medical_records):
    """
    Filter measurements to only include those from the earliest medical record for each patient.
    Merge measurements with relevant medical record data, including the pivotal eating behavior scores.
    """
    # Filter measurements to only include those from the first treatment record of each patient
    # Ensure measurements are sorted before grouping
    measurements = measurements.sort_values(['patient_id', 'measurement_date'])
    earliest_records_with_data = measurements.groupby('patient_id')['medical_record_id'].first().reset_index()

    filtered_measurements = pd.merge(
        measurements,
        earliest_records_with_data,
        on=['patient_id', 'medical_record_id'],
        how='inner'
    )
    # Ensure filtered_measurements are sorted for baseline identification
    filtered_measurements = filtered_measurements.sort_values(['patient_id', 'medical_record_id', 'measurement_date'])

    # Identify the baseline measurement row for each record (the first measurement in the filtered set)
    baseline_data_rows = filtered_measurements.groupby(['patient_id', 'medical_record_id']).first().reset_index()

    # Select only relevant columns from medical_records to merge
    # --- MODIFIED: Now includes 'dietitian_visits' if added to relevant_medical_values ---
    cols_to_select = [col for col in relevant_medical_values if col in medical_records.columns]
    medical_record_subset = medical_records[cols_to_select]

    # Merge baseline measurement info with the selected medical record data
    # Use baseline_data_rows which contains the actual first measurement details
    prepared_data = pd.merge(
        baseline_data_rows[['patient_id', 'medical_record_id', 'measurement_date', 'weight_kg']], # Get baseline date/weight from actual first measurement
        medical_record_subset,
        on=['patient_id', 'medical_record_id'],
        how='left' # Keep all baseline measurements
    )
    # Rename columns for clarity before returning
    prepared_data = prepared_data.rename(columns={'measurement_date': 'baseline_date', 'weight_kg': 'baseline_weight_kg'})

    # Add baseline_bmi from medical_records if available and not already present from measurement merge
    if 'baseline_bmi' in medical_record_subset.columns and 'baseline_bmi' not in prepared_data.columns:
         prepared_data = pd.merge(
              prepared_data,
              medical_record_subset[['patient_id', 'medical_record_id', 'baseline_bmi']],
              on=['patient_id', 'medical_record_id'],
              how='left'
         )

    return prepared_data, filtered_measurements


"""
MODIFIED: CALCULATE WEIGHT LOSS OUTCOMES FOR WIDE TABLE
"""

def _get_patient_baseline(patient_data, patient_id, medical_record_id):
    """
    Get the prepared baseline data for a patient's specific medical record.
    (Function remains largely the same, but operates on the prepared_data structure)
    """
    patient_baseline = patient_data[
        (patient_data['patient_id'] == patient_id) &
        (patient_data['medical_record_id'] == medical_record_id)
    ]
    if patient_baseline.empty:
        print(f"WARN: No baseline data found for patient {patient_id}, record {medical_record_id}. Skipping.")
        return None
    # Ensure we return a Series for consistent access
    return patient_baseline.iloc[0]

def _calculate_wl_metrics(baseline_weight, current_weight):
    """ Helper to calculate weight loss kg and % """
    if pd.isna(baseline_weight) or pd.isna(current_weight) or baseline_weight == 0:
        return np.nan, np.nan
    wl_kg = baseline_weight - current_weight
    wl_pct = (wl_kg / baseline_weight) * 100
    return wl_kg, round(wl_pct, 2)

# Removed _check_target_achievement, _determine_final_measurement, _calculate_outcome_metrics
# Their logic will be integrated into the main calculation function.

def calculate_wide_patient_outcomes(prepared_patient_data, filtered_measurements, weight_loss_targets, time_windows, window_span):
    """
    Calculate all required outcomes (fixed-time and time-to-event) for each patient
    and return a list of dictionaries, each representing a row in the wide table.
    """
    results_list = []
    # Group all measurements for the relevant (earliest) medical record per patient
    grouped_measurements = filtered_measurements.groupby(['patient_id', 'medical_record_id'])

    for (patient_id, medical_record_id), group in grouped_measurements:
        # --- NEW: Calculate total number of measurements for this group (record) ---
        num_measurements = len(group)

        # 1. Get Baseline Info
        baseline_info = _get_patient_baseline(prepared_patient_data, patient_id, medical_record_id)
        if baseline_info is None:
            continue

        baseline_date = baseline_info['baseline_date']
        baseline_weight = baseline_info['baseline_weight_kg']

        # Initialize result dictionary with baseline info AND adherence proxies
        result = {
            # IDs
            'patient_ID': patient_id, # Match Excel header
            'medical_record_ID': medical_record_id, # Match Excel header
            # Baseline
            'baseline_date': baseline_date,
            'baseline_weight_kg': baseline_weight,
            # Adherence Proxies (Initialize early)
            'nr_visits': baseline_info.get('dietitian_visits'), # Get from merged baseline data
            'nr_total_measurements': num_measurements, # Use calculated value
            'avg_days_between_measurements': np.nan, # Initialize, calculated later
            # Confounders / Predictors (Initialize early)
            'sex': baseline_info.get('sex'),
            'age': baseline_info.get('age'),
            'height_m': baseline_info.get('height_m'),
            'baseline_bmi': baseline_info.get('baseline_bmi'), # Get BMI from prepared data
            'hunger': baseline_info.get('hunger'),
            'satiety': baseline_info.get('satiety'),
            'emotional_eating': baseline_info.get('emotional_eating'),
            'emotional_eating_value': baseline_info.get('emotional_eating_value'),
            'quantity_control': baseline_info.get('quantity_control'),
            'impulse_control': baseline_info.get('impulse_control'), 
            'weight_gain_cause': baseline_info.get('weight_gain_cause'),
            'genomics_sample_id': baseline_info.get('genomics_sample_id')
        }

        # Get all measurements *after* baseline for this group
        followup_measurements = group[group['measurement_date'] > baseline_date].sort_values('measurement_date')

        # 2. Calculate Overall Follow-up Metrics
        if not followup_measurements.empty:
            last_measurement = followup_measurements.iloc[-1]
            result['last_aval_date'] = last_measurement['measurement_date']
            # --- MODIFIED: Calculate inclusive total_followup_days (+1) ---
            result['total_followup_days'] = (last_measurement['measurement_date'] - baseline_date).days + 1
            result['last_aval_weight_kg'] = last_measurement['weight_kg']
            wl_kg, wl_pct = _calculate_wl_metrics(baseline_weight, last_measurement['weight_kg'])
            result['total_wl_kg'] = wl_kg
            result['total_wl_%'] = wl_pct
        else:
            # Handle instant dropouts (only baseline measurement exists)
            result['last_aval_date'] = baseline_date
            # --- MODIFIED: Set total_followup_days to 1 for instant dropouts ---
            result['total_followup_days'] = 1
            result['last_aval_weight_kg'] = baseline_weight
            result['total_wl_kg'] = 0.0
            result['total_wl_%'] = 0.0

        # --- NEW: Calculate avg_days_between_measurements ---
        # Requires total_followup_days and nr_total_measurements
        if result['nr_total_measurements'] is not None and result['nr_total_measurements'] > 1:
             # Use the calculated inclusive total_followup_days
             total_days = result['total_followup_days']
             # Calculate average based on number of intervals (N measurements = N-1 intervals)
             # Ensure total_days is treated as the span covering N points (so N-1 intervals)
             # If total_days is 1 (instant dropout), num_measurements is 1, this condition isn't met.
             # If total_days > 1, num_measurements must be >= 2.
             result['avg_days_between_measurements'] = round( (total_days -1) / (result['nr_total_measurements'] - 1) , 2) if (result['nr_total_measurements'] - 1) > 0 else np.nan
        else:
             # If 0 or 1 measurements, average days between is undefined
             result['avg_days_between_measurements'] = np.nan


        # 3. Calculate Fixed-Timepoint Metrics (for each time window)
        for window_center in time_windows:
            min_window_date = baseline_date + timedelta(days=(window_center - window_span))
            max_window_date = baseline_date + timedelta(days=(window_center + window_span))
            target_date = baseline_date + timedelta(days=window_center)

            # Find measurements strictly *within* the cutoff window span
            measurements_around_cutoff = followup_measurements[
                (followup_measurements['measurement_date'] >= min_window_date) &
                (followup_measurements['measurement_date'] <= max_window_date)
            ]

            measurement_for_window = None # The measurement to use for this window's stats
            is_dropout_at_window = True # Assume dropout unless a measurement is found *in* the window

            if not measurements_around_cutoff.empty:
                # Measurement exists within the window span. Find the closest one.
                is_dropout_at_window = False # Found measurement in window
                measurements_around_cutoff = measurements_around_cutoff.copy()
                measurements_around_cutoff['distance_to_center'] = abs(
                    (measurements_around_cutoff['measurement_date'] - target_date).dt.days
                )
                closest_measurement_idx = measurements_around_cutoff['distance_to_center'].idxmin()
                measurement_for_window = measurements_around_cutoff.loc[closest_measurement_idx]
            # else: If measurements_around_cutoff is empty, is_dropout_at_window remains True.
            # No need to check for measurements *before* the window for populating data,
            # as per the requirement to leave fields blank for dropouts.

            # --- MODIFIED: Populate results based *strictly* on dropout status for the window ---
            prefix = f"{window_center}d"
            result[f'{prefix}_dropout'] = is_dropout_at_window # Set dropout status first

            if is_dropout_at_window:
                # If dropout for this window, set all related metrics to NaN/NaT
                result[f'{prefix}_weight_kg'] = np.nan
                result[f'wl_{prefix}_kg'] = np.nan
                result[f'wl_{prefix}_%'] = np.nan
                result[f'{prefix}_date'] = pd.NaT
                result[f'days_to_{prefix}_measurement'] = np.nan
            else:
                # If NOT dropout, populate metrics using the found measurement_for_window
                # (This block only runs if measurement_for_window is not None)
                result[f'{prefix}_weight_kg'] = measurement_for_window['weight_kg']
                wl_kg, wl_pct = _calculate_wl_metrics(baseline_weight, measurement_for_window['weight_kg'])
                result[f'wl_{prefix}_kg'] = wl_kg
                result[f'wl_{prefix}_%'] = wl_pct
                result[f'{prefix}_date'] = measurement_for_window['measurement_date']
                # Calculate days from baseline to this specific measurement
                result[f'days_to_{prefix}_measurement'] = (measurement_for_window['measurement_date'] - baseline_date).days + 1 # Inclusive days


        # 4. Calculate Time-to-Event Metrics (for each weight loss target)
        # (No changes needed in this section based on discussion)
        for target in weight_loss_targets:
            target_achieved = False
            first_success_measurement = None
            actual_wl_at_success = np.nan

            # Check all followup measurements for the first success
            for _, row in followup_measurements.iterrows():
                current_weight = row['weight_kg']
                if baseline_weight is not None and baseline_weight > 0:
                    current_weight_loss_pct = ((baseline_weight - current_weight) / baseline_weight) * 100
                    if round(current_weight_loss_pct, 2) >= target:
                        target_achieved = True
                        first_success_measurement = row
                        actual_wl_at_success = round(current_weight_loss_pct, 2)
                        break # Stop at the first success

            # Populate results for this target
            prefix = f"{target}%_wl"
            result[f'{prefix}_achieved'] = target_achieved
            if target_achieved and first_success_measurement is not None:
                result[f'{prefix}_%'] = actual_wl_at_success
                result[f'{prefix}_date'] = first_success_measurement['measurement_date']
                # Calculate inclusive days to achieve target
                result[f'days_to_{prefix}'] = (first_success_measurement['measurement_date'] - baseline_date).days + 1
            else:
                result[f'{prefix}_%'] = np.nan
                result[f'{prefix}_date'] = pd.NaT
                result[f'days_to_{prefix}'] = np.nan

        results_list.append(result)

    return pd.DataFrame(results_list)


"""
MODIFIED: MAIN ORCHESTRATION FUNCTION FOR WIDE TABLE
"""

def generate_wide_analysis_dataset(input_connection, output_connection, weight_loss_targets, time_windows, window_span=10):
    """
    Orchestrates the process to generate the single wide survival analysis dataset.
    Loads data, prepares patient baseline info, calculates all outcomes per patient,
    reorders columns, and saves the resulting wide DataFrame to the output database.
    """
    # 1. Load and prepare input data
    print("Loading measurements...")
    measurements = load_measurements(input_connection)
    print("Loading medical records...")
    medical_records = load_medical_records(input_connection)
    print("Preparing patient data...")
    prepared_data, filtered_measurements = prepare_patient_data(measurements, medical_records)

    if prepared_data.empty:
        print("ERROR: Prepared patient data is empty. Cannot proceed.")
        return pd.DataFrame() # Return empty DataFrame

    # 2. Calculate wide outcomes for all patients
    print("Calculating wide outcomes for all patients...")
    wide_results_df = calculate_wide_patient_outcomes(
        prepared_data,
        filtered_measurements,
        weight_loss_targets,
        time_windows,
        window_span
    )

    # --- NEW: Reorder columns according to FINAL_COLUMN_ORDER defined in config ---
    if not wide_results_df.empty:
        print("Reordering columns...")
        # Ensure all columns in FINAL_COLUMN_ORDER exist in the DataFrame, handle potential missing ones
        final_columns_present = [col for col in FINAL_COLUMN_ORDER if col in wide_results_df.columns]
        missing_cols = [col for col in FINAL_COLUMN_ORDER if col not in wide_results_df.columns]
        if missing_cols:
             print(f"WARN: The following columns defined in FINAL_COLUMN_ORDER were not found in the generated data and will be skipped: {missing_cols}")
        # Add any columns present in DataFrame but not in FINAL_COLUMN_ORDER to the end, just in case
        extra_cols = [col for col in wide_results_df.columns if col not in final_columns_present]
        if extra_cols:
             print(f"WARN: The following columns were generated but not included in FINAL_COLUMN_ORDER; they will be added to the end: {extra_cols}")

        wide_results_df = wide_results_df[final_columns_present + extra_cols]


    # 3. Save the single wide table
    if not wide_results_df.empty:
        print(f"--- Saving results to output database: {output_db_path} ---")
        print(f"Saving table: {output_table_name} ({len(wide_results_df)} rows)")
        wide_results_df.to_sql(output_table_name, output_connection, if_exists='replace', index=False)
        output_connection.commit() # Ensure changes are saved
        print("--- Wide table saved successfully ---")
    else:
        print("WARN: No results generated. Output table will be empty or not created.")

    # Removed the old summary logic based on multiple tables
    # A new summary could be generated from wide_results_df if needed

    return wide_results_df # Return the generated DataFrame

"""
EXECUTION BLOCK (Modified to call the new main function and use new output names)
"""

if __name__ == "__main__":
    print("========== Generating Survival Analysis Input Dataset (Wide Format v2) ==========") # Updated title
    input_conn = None
    output_conn = None
    try:
        # Connect to in-and output databases
        print(f"Connecting to input database: {input_db_path}")
        if not os.path.exists(input_db_path):
             raise FileNotFoundError(f"Input database not found at {input_db_path}")
        input_conn = sqlite3.connect(input_db_path)

        print(f"Connecting to output database: {output_db_path}")
        output_conn = sqlite3.connect(output_db_path)

        # Run the new main analysis function
        wide_df = generate_wide_analysis_dataset(
            input_conn,
            output_conn,
            weight_loss_targets,
            time_windows,
            window_span
        )

        # Display basic info if successful
        if not wide_df.empty:
            print("\n--- Survival Analysis Input Table Generation Summary ---")
            print(f"Generated table '{output_table_name}' with {len(wide_df)} rows and {len(wide_df.columns)} columns.")
            # print(wide_df.head().to_string()) # Optionally print head
            print("--- End Summary ---")
        else:
            print("WARN: Analysis completed, but the resulting DataFrame is empty.")

        print(f"Analysis data saved to {output_db_path}")

    # Error handling remains the same
    except FileNotFoundError as e:
        print(f"ERROR: Database file not found - {e}")
    except sqlite3.Error as e:
        print(f"ERROR: SQLite database error - {e}")
    except ValueError as e:
        print(f"ERROR: Data processing error - {e}")
    except Exception as e:
        print(f"ERROR: An unexpected error occurred - {e}")
        # Consider adding traceback for debugging complex errors:
        # import traceback
        # print(traceback.format_exc())
    finally:
        # Ensure connections are closed
        print("Closing database connections...")
        if input_conn:
            input_conn.close()
        if output_conn:
            output_conn.close()
        print("========== Survival Analysis Input Data Generation Finished (Wide Format v2) ==========") # Updated title


Connecting to input database: C:\Users\Felhasználó\Desktop\Projects\PNK_DB2\paper1_emotional\emotional_all_notna.sqlite
Connecting to output database: C:\Users\Felhasználó\Desktop\Projects\PNK_DB2\paper1_emotional\survival_analysis_wide_v3.sqlite
Loading measurements...
Loading medical records...
Preparing patient data...
Calculating wide outcomes for all patients...
Reordering columns...
--- Saving results to output database: C:\Users\Felhasználó\Desktop\Projects\PNK_DB2\paper1_emotional\survival_analysis_wide_v3.sqlite ---
Saving table: survival_analysis_wide_v3 (1664 rows)
--- Wide table saved successfully ---

--- Survival Analysis Input Table Generation Summary ---
Generated table 'survival_analysis_wide_v3' with 1664 rows and 56 columns.
--- End Summary ---
Analysis data saved to C:\Users\Felhasználó\Desktop\Projects\PNK_DB2\paper1_emotional\survival_analysis_wide_v3.sqlite
Closing database connections...


#### Actually, this may be the one, Apr 20 - added final BMI too, and fetching BMI from measurements instead of medical records

In [16]:
####
# filepath: c:\Users\Felhasználó\Desktop\Projects\PNK_DB2\paper1_emotional\paper1_script.ipynb
# This code should be placed in a new cell in your Jupyter Notebook.
# It incorporates the requested changes for BMI handling into the wide-format script.

import pandas as pd
import os
import sqlite3
from datetime import timedelta
import numpy as np # Import numpy for NaN

"""
CONFIGURATION
"""
# Define directories and database paths - paper1_directory should be defined
# in the first cell of this notebook chapter
input_db_path = os.path.join(paper1_directory, 'emotional_all_notna.sqlite')
input_measurements = "measurements_with_metadata"
input_medical_records = "medical_records_complete"
# Output database path and table name
output_db_path = os.path.join(paper1_directory, 'survival_analysis.sqlite') # Changed filename for new version
output_table_name = "sa_input_table" # Changed table name for new version

# Define analysis parameters
weight_loss_targets = [5, 10, 15]     # Weight loss target percentages
time_windows = [40, 60, 80]       # Time windows (centers) in days
window_span = 10                   # Permissible span around windows (+/- days)

# Define the variables stored in medical_records_complete that are relevant for the analysis.
# --- MODIFIED: Removed 'baseline_bmi' as it will be sourced from measurements ---
relevant_medical_values = ['patient_id', 'medical_record_id', 'dietitian_visits', 'sex', 'age',
                             'height_m', 'hunger', 'satiety', 'emotional_eating', # Removed 'baseline_bmi'
                             'emotional_eating_value', 'quantity_control', 'impulse_control', 'weight_gain_cause', 'genomics_sample_id']

# --- MODIFIED: Renamed to output_column_order and updated for BMI ---
# This list determines the order of columns in the final output table.
output_column_order = [
    # IDs
    'patient_ID', 'medical_record_ID',
    # Followup period info and adherence proxies
    'baseline_date', 'last_aval_date', 'total_followup_days', 'nr_visits', 'nr_total_measurements', 'avg_days_between_measurements',
    # Total weight change & BMI change
    'baseline_weight_kg', 'last_aval_weight_kg', 'total_wl_kg', 'total_wl_%',
    'baseline_bmi', 'final_bmi', 'bmi_reduction', # Added BMI columns here
    # Fixed-Timepoint Analysis (Dynamically generated columns will be inserted here by logic below)
    # Time-to-Event Analysis (Dynamically generated columns will be inserted here by logic below)
    # Confounders / Predictors (from medical records)
    'sex', 'age', 'height_m', 'hunger', 'satiety', # Removed 'baseline_bmi' from here
    'emotional_eating', 'emotional_eating_value', 'quantity_control', 'impulse_control', 'weight_gain_cause', 'genomics_sample_id'
]

# --- MODIFIED: Dynamically insert fixed-time and time-to-event columns into output_column_order ---
fixed_time_cols = []
for window in time_windows:
    prefix = f"{window}d"
    # --- MODIFIED: Excluded BMI calculation for fixed timepoints ---
    fixed_time_cols.extend([
        f'{prefix}_weight_kg', f'wl_{prefix}_kg', f'wl_{prefix}_%',
        f'{prefix}_date', f'days_to_{prefix}_measurement', f'{prefix}_dropout'
        # Removed: f'{prefix}_bmi'
    ])

time_to_event_cols = []
for target in weight_loss_targets:
    prefix = f"{target}%_wl"
    time_to_event_cols.extend([
        f'{prefix}_achieved', f'{prefix}_%', f'{prefix}_date', f'days_to_{prefix}'
    ])

# Find the insertion point (after BMI reduction)
# --- MODIFIED: Insertion point updated ---
insert_point = output_column_order.index('bmi_reduction') + 1
# Insert the dynamic columns
output_column_order[insert_point:insert_point] = fixed_time_cols + time_to_event_cols


"""
DATA LOADING & PREPARATION
"""

def load_measurements(connection):
    """
    Load measurements from the measurement_with_metadata table;
    make sure key values are in the correct format.
    --- MODIFIED: Ensure BMI is numeric ---
    """
    query = f"SELECT * FROM {input_measurements}"
    measurements = pd.read_sql_query(query, connection)
    measurements['measurement_date'] = pd.to_datetime(measurements['measurement_date'], errors='coerce')
    measurements['weight_kg'] = pd.to_numeric(measurements['weight_kg'], errors='coerce')
    # --- NEW: Ensure BMI is numeric ---
    measurements['bmi'] = pd.to_numeric(measurements['bmi'], errors='coerce')
    # Ensure sorting for consistent 'first'/'last' operations later
    measurements = measurements.sort_values(['patient_id', 'medical_record_id', 'measurement_date'])
    return measurements

def load_medical_records(connection):
    """
    Load medical records from the medical_records_complete table;
    make sure date values are in datetime format.
    The exact columns to be used are defined in the prepare_patient_data function.
    """
    query = f"SELECT * FROM {input_medical_records}"
    medical_records = pd.read_sql_query(query, connection)
    # Ensure relevant date columns are datetime
    for col in ['medical_record_creation_date', 'baseline_measurement_date', 'final_measurement_date']:
         if col in medical_records.columns:
             medical_records[col] = pd.to_datetime(medical_records[col], errors='coerce')
    # Ensure dietitian_visits is numeric
    if 'dietitian_visits' in medical_records.columns:
        medical_records['dietitian_visits'] = pd.to_numeric(medical_records['dietitian_visits'], errors='coerce')
    return medical_records

def prepare_patient_data(measurements, medical_records):
    """
    Filter measurements to only include those from the earliest medical record for each patient.
    Merge measurements with relevant medical record data, including the pivotal eating behavior scores.
    --- MODIFIED: Fetch baseline BMI from measurements ---
    """
    # Filter measurements to only include those from the first treatment record of each patient
    measurements = measurements.sort_values(['patient_id', 'measurement_date'])
    earliest_records_with_data = measurements.groupby('patient_id')['medical_record_id'].first().reset_index()

    filtered_measurements = pd.merge(
        measurements,
        earliest_records_with_data,
        on=['patient_id', 'medical_record_id'],
        how='inner'
    )
    filtered_measurements = filtered_measurements.sort_values(['patient_id', 'medical_record_id', 'measurement_date'])

    # Identify the baseline measurement row for each record (the first measurement in the filtered set)
    # --- MODIFIED: Include 'bmi' ---
    baseline_data_rows = filtered_measurements.groupby(['patient_id', 'medical_record_id']).first().reset_index()

    # Select only relevant columns from medical_records to merge
    cols_to_select = [col for col in relevant_medical_values if col in medical_records.columns]
    medical_record_subset = medical_records[cols_to_select]

    # Merge baseline measurement info with the selected medical record data
    # --- MODIFIED: Include 'bmi' from baseline_data_rows ---
    prepared_data = pd.merge(
        baseline_data_rows[['patient_id', 'medical_record_id', 'measurement_date', 'weight_kg', 'bmi']], # Added 'bmi'
        medical_record_subset,
        on=['patient_id', 'medical_record_id'],
        how='left' # Keep all baseline measurements
    )
    # Rename columns for clarity before returning
    # --- MODIFIED: Rename 'bmi' to 'baseline_bmi' ---
    prepared_data = prepared_data.rename(columns={
        'measurement_date': 'baseline_date',
        'weight_kg': 'baseline_weight_kg',
        'bmi': 'baseline_bmi' # Rename BMI from measurement
    })

    # --- REMOVED: Separate merge for baseline_bmi from medical_records is no longer needed ---
    # if 'baseline_bmi' in medical_record_subset.columns and 'baseline_bmi' not in prepared_data.columns:
    #      ... (old merge logic removed) ...

    return prepared_data, filtered_measurements


"""
MODIFIED: CALCULATE WEIGHT LOSS OUTCOMES FOR WIDE TABLE
"""

def _get_patient_baseline(patient_data, patient_id, medical_record_id):
    """
    Get the prepared baseline data for a patient's specific medical record.
    """
    patient_baseline = patient_data[
        (patient_data['patient_id'] == patient_id) &
        (patient_data['medical_record_id'] == medical_record_id)
    ]
    if patient_baseline.empty:
        print(f"WARN: No baseline data found for patient {patient_id}, record {medical_record_id}. Skipping.")
        return None
    return patient_baseline.iloc[0]

def _calculate_wl_metrics(baseline_weight, current_weight):
    """ Helper to calculate weight loss kg and % """
    if pd.isna(baseline_weight) or pd.isna(current_weight) or baseline_weight == 0:
        return np.nan, np.nan
    wl_kg = baseline_weight - current_weight
    wl_pct = (wl_kg / baseline_weight) * 100
    return wl_kg, round(wl_pct, 2)

# --- NEW: Helper to calculate BMI reduction ---
def _calculate_bmi_reduction(baseline_bmi, final_bmi):
    """ Helper to calculate BMI reduction (final - baseline) """
    if pd.isna(baseline_bmi) or pd.isna(final_bmi):
        return np.nan
    return round(final_bmi - baseline_bmi, 2)


def calculate_wide_patient_outcomes(prepared_patient_data, filtered_measurements, weight_loss_targets, time_windows, window_span):
    """
    Calculate all required outcomes (fixed-time and time-to-event) for each patient
    and return a list of dictionaries, each representing a row in the wide table.
    --- MODIFIED: Incorporates BMI from measurements and BMI reduction ---
    """
    results_list = []
    grouped_measurements = filtered_measurements.groupby(['patient_id', 'medical_record_id'])

    for (patient_id, medical_record_id), group in grouped_measurements:
        num_measurements = len(group)

        # 1. Get Baseline Info
        baseline_info = _get_patient_baseline(prepared_patient_data, patient_id, medical_record_id)
        if baseline_info is None:
            continue

        baseline_date = baseline_info['baseline_date']
        baseline_weight = baseline_info['baseline_weight_kg']
        # --- MODIFIED: Get baseline_bmi from prepared_data (sourced from measurement) ---
        baseline_bmi = baseline_info.get('baseline_bmi') # Use .get() for safety

        # Initialize result dictionary
        result = {
            # IDs
            'patient_ID': patient_id,
            'medical_record_ID': medical_record_id,
            # Baseline
            'baseline_date': baseline_date,
            'baseline_weight_kg': baseline_weight,
            'baseline_bmi': baseline_bmi, # Add baseline BMI here
            # Adherence Proxies
            'nr_visits': baseline_info.get('dietitian_visits'),
            'nr_total_measurements': num_measurements,
            'avg_days_between_measurements': np.nan,
            # Confounders / Predictors
            'sex': baseline_info.get('sex'),
            'age': baseline_info.get('age'),
            'height_m': baseline_info.get('height_m'),
            'hunger': baseline_info.get('hunger'),
            'satiety': baseline_info.get('satiety'),
            'emotional_eating': baseline_info.get('emotional_eating'),
            'emotional_eating_value': baseline_info.get('emotional_eating_value'),
            'quantity_control': baseline_info.get('quantity_control'),
            'impulse_control': baseline_info.get('impulse_control'),
            'weight_gain_cause': baseline_info.get('weight_gain_cause'),
            'genomics_sample_id': baseline_info.get('genomics_sample_id')
        }

        followup_measurements = group[group['measurement_date'] > baseline_date].sort_values('measurement_date')

        # 2. Calculate Overall Follow-up Metrics
        if not followup_measurements.empty:
            last_measurement = followup_measurements.iloc[-1]
            result['last_aval_date'] = last_measurement['measurement_date']
            result['total_followup_days'] = (last_measurement['measurement_date'] - baseline_date).days + 1
            result['last_aval_weight_kg'] = last_measurement['weight_kg']
            wl_kg, wl_pct = _calculate_wl_metrics(baseline_weight, last_measurement['weight_kg'])
            result['total_wl_kg'] = wl_kg
            result['total_wl_%'] = wl_pct
            # --- NEW: Get final BMI and calculate reduction ---
            result['final_bmi'] = last_measurement.get('bmi') # Get BMI from last measurement
            result['bmi_reduction'] = _calculate_bmi_reduction(result['baseline_bmi'], result['final_bmi'])
        else:
            # Handle instant dropouts
            result['last_aval_date'] = baseline_date
            result['total_followup_days'] = 1
            result['last_aval_weight_kg'] = baseline_weight
            result['total_wl_kg'] = 0.0
            result['total_wl_%'] = 0.0
            # --- NEW: Set final BMI and reduction for instant dropouts ---
            result['final_bmi'] = result['baseline_bmi'] # Final BMI is baseline BMI
            result['bmi_reduction'] = 0.0 # No change

        # Calculate avg_days_between_measurements
        if result['nr_total_measurements'] is not None and result['nr_total_measurements'] > 1:
             total_days = result['total_followup_days']
             result['avg_days_between_measurements'] = round( (total_days -1) / (result['nr_total_measurements'] - 1) , 2) if (result['nr_total_measurements'] - 1) > 0 else np.nan
        else:
             result['avg_days_between_measurements'] = np.nan


        # 3. Calculate Fixed-Timepoint Metrics (for each time window)
        # --- MODIFIED: Excluded BMI calculation ---
        for window_center in time_windows:
            min_window_date = baseline_date + timedelta(days=(window_center - window_span))
            max_window_date = baseline_date + timedelta(days=(window_center + window_span))
            target_date = baseline_date + timedelta(days=window_center)

            measurements_around_cutoff = followup_measurements[
                (followup_measurements['measurement_date'] >= min_window_date) &
                (followup_measurements['measurement_date'] <= max_window_date)
            ]

            measurement_for_window = None
            is_dropout_at_window = True

            if not measurements_around_cutoff.empty:
                is_dropout_at_window = False
                measurements_around_cutoff = measurements_around_cutoff.copy()
                measurements_around_cutoff['distance_to_center'] = abs(
                    (measurements_around_cutoff['measurement_date'] - target_date).dt.days
                )
                closest_measurement_idx = measurements_around_cutoff['distance_to_center'].idxmin()
                measurement_for_window = measurements_around_cutoff.loc[closest_measurement_idx]

            prefix = f"{window_center}d"
            result[f'{prefix}_dropout'] = is_dropout_at_window

            if is_dropout_at_window:
                result[f'{prefix}_weight_kg'] = np.nan
                result[f'wl_{prefix}_kg'] = np.nan
                result[f'wl_{prefix}_%'] = np.nan
                result[f'{prefix}_date'] = pd.NaT
                result[f'days_to_{prefix}_measurement'] = np.nan
                # Removed: result[f'{prefix}_bmi'] = np.nan
            else:
                result[f'{prefix}_weight_kg'] = measurement_for_window['weight_kg']
                wl_kg, wl_pct = _calculate_wl_metrics(baseline_weight, measurement_for_window['weight_kg'])
                result[f'wl_{prefix}_kg'] = wl_kg
                result[f'wl_{prefix}_%'] = wl_pct
                result[f'{prefix}_date'] = measurement_for_window['measurement_date']
                result[f'days_to_{prefix}_measurement'] = (measurement_for_window['measurement_date'] - baseline_date).days + 1
                # Removed: result[f'{prefix}_bmi'] = measurement_for_window.get('bmi')


        # 4. Calculate Time-to-Event Metrics (for each weight loss target)
        # (No changes needed in this section)
        for target in weight_loss_targets:
            target_achieved = False
            first_success_measurement = None
            actual_wl_at_success = np.nan

            for _, row in followup_measurements.iterrows():
                current_weight = row['weight_kg']
                if baseline_weight is not None and baseline_weight > 0:
                    current_weight_loss_pct = ((baseline_weight - current_weight) / baseline_weight) * 100
                    if round(current_weight_loss_pct, 2) >= target:
                        target_achieved = True
                        first_success_measurement = row
                        actual_wl_at_success = round(current_weight_loss_pct, 2)
                        break

            prefix = f"{target}%_wl"
            result[f'{prefix}_achieved'] = target_achieved
            if target_achieved and first_success_measurement is not None:
                result[f'{prefix}_%'] = actual_wl_at_success
                result[f'{prefix}_date'] = first_success_measurement['measurement_date']
                result[f'days_to_{prefix}'] = (first_success_measurement['measurement_date'] - baseline_date).days + 1
            else:
                result[f'{prefix}_%'] = np.nan
                result[f'{prefix}_date'] = pd.NaT
                result[f'days_to_{prefix}'] = np.nan

        results_list.append(result)

    return pd.DataFrame(results_list)


"""
MODIFIED: MAIN ORCHESTRATION FUNCTION FOR WIDE TABLE
"""

def generate_wide_analysis_dataset(input_connection, output_connection, weight_loss_targets, time_windows, window_span=10):
    """
    Orchestrates the process to generate the single wide survival analysis dataset.
    Loads data, prepares patient baseline info, calculates all outcomes per patient,
    reorders columns, and saves the resulting wide DataFrame to the output database.
    --- MODIFIED: Uses output_column_order ---
    """
    # 1. Load and prepare input data
    print("Loading measurements...")
    measurements = load_measurements(input_connection)
    print("Loading medical records...")
    medical_records = load_medical_records(input_connection)
    print("Preparing patient data...")
    prepared_data, filtered_measurements = prepare_patient_data(measurements, medical_records)

    if prepared_data.empty:
        print("ERROR: Prepared patient data is empty. Cannot proceed.")
        return pd.DataFrame()

    # 2. Calculate wide outcomes for all patients
    print("Calculating wide outcomes for all patients...")
    wide_results_df = calculate_wide_patient_outcomes(
        prepared_data,
        filtered_measurements,
        weight_loss_targets,
        time_windows,
        window_span
    )

    # --- MODIFIED: Reorder columns according to output_column_order ---
    if not wide_results_df.empty:
        print("Reordering columns...")
        # Ensure all columns in output_column_order exist in the DataFrame
        final_columns_present = [col for col in output_column_order if col in wide_results_df.columns]
        missing_cols = [col for col in output_column_order if col not in wide_results_df.columns]
        if missing_cols:
             print(f"WARN: The following columns defined in output_column_order were not found and will be skipped: {missing_cols}")
        # Add any extra columns not in the defined order to the end
        extra_cols = [col for col in wide_results_df.columns if col not in final_columns_present]
        if extra_cols:
             print(f"WARN: The following columns were generated but not in output_column_order; adding to the end: {extra_cols}")

        wide_results_df = wide_results_df[final_columns_present + extra_cols]


    # 3. Save the single wide table
    if not wide_results_df.empty:
        print(f"--- Saving results to output database: {output_db_path} ---")
        print(f"Saving table: {output_table_name} ({len(wide_results_df)} rows)")
        wide_results_df.to_sql(output_table_name, output_connection, if_exists='replace', index=False)
        output_connection.commit()
        print("--- Wide table saved successfully ---")
    else:
        print("WARN: No results generated. Output table will be empty or not created.")

    return wide_results_df


"""
EXECUTION BLOCK (Modified to use new output names)
"""

if __name__ == "__main__":
    print("========== Generating Survival Analysis Input Dataset (Wide Format v3) ==========") # Updated title
    input_conn = None
    output_conn = None
    try:
        print(f"Connecting to input database: {input_db_path}")
        if not os.path.exists(input_db_path):
             raise FileNotFoundError(f"Input database not found at {input_db_path}")
        input_conn = sqlite3.connect(input_db_path)

        print(f"Connecting to output database: {output_db_path}")
        output_conn = sqlite3.connect(output_db_path)

        wide_df = generate_wide_analysis_dataset(
            input_conn,
            output_conn,
            weight_loss_targets,
            time_windows,
            window_span
        )

        if not wide_df.empty:
            print("\n--- Survival Analysis Input Table Generation Summary ---")
            print(f"Generated table '{output_table_name}' with {len(wide_df)} rows and {len(wide_df.columns)} columns.")
            # print(wide_df.head().to_string())
            print("--- End Summary ---")
        else:
            print("WARN: Analysis completed, but the resulting DataFrame is empty.")

        print(f"Analysis data saved to {output_db_path}")

    except FileNotFoundError as e:
        print(f"ERROR: Database file not found - {e}")
    except sqlite3.Error as e:
        print(f"ERROR: SQLite database error - {e}")
    except ValueError as e:
        print(f"ERROR: Data processing error - {e}")
    except Exception as e:
        print(f"ERROR: An unexpected error occurred - {e}")
        # import traceback
        # print(traceback.format_exc())
    finally:
        print("Closing database connections...")
        if input_conn:
            input_conn.close()
        if output_conn:
            output_conn.close()
        print("========== Survival Analysis Input Data Generation Finished (Wide Format v3) ==========") # Updated title
####

Connecting to input database: C:\Users\Felhasználó\Desktop\Projects\PNK_DB2\paper1_emotional\emotional_all_notna.sqlite
Connecting to output database: C:\Users\Felhasználó\Desktop\Projects\PNK_DB2\paper1_emotional\survival_analysis.sqlite
Loading measurements...
Loading medical records...
Preparing patient data...
Calculating wide outcomes for all patients...
Reordering columns...
--- Saving results to output database: C:\Users\Felhasználó\Desktop\Projects\PNK_DB2\paper1_emotional\survival_analysis.sqlite ---
Saving table: sa_input_table (1664 rows)
--- Wide table saved successfully ---

--- Survival Analysis Input Table Generation Summary ---
Generated table 'sa_input_table' with 1664 rows and 56 columns.
--- End Summary ---
Analysis data saved to C:\Users\Felhasználó\Desktop\Projects\PNK_DB2\paper1_emotional\survival_analysis.sqlite
Closing database connections...


### Start studying the survival analysis dataset

#### Generation of summary stats tables in the database - the code is not revised, the output is draft-level, but it is already insightful and seems correct. 

In [8]:
# This code calculates detailed population and outcome summary statistics,
# ensures variables are columns and stats are rows in the final output,
# fixes categorical counts (both population and outcome),
# and saves them to the specified table names in the SQLite database.

import pandas as pd
import sqlite3
import os
import numpy as np

"""
CONFIGURATION
"""
# Define database path, input table name, and output table names
# Ensure 'paper1_directory' is defined in your notebook environment before running this cell
if 'paper1_directory' not in locals() and 'paper1_directory' not in globals():
     raise NameError("'paper1_directory' is not defined. Please define it in a previous cell.")

db_path = os.path.join(paper1_directory, 'survival_analysis.sqlite') # DB with the final wide table
input_table_name = "sa_input_table" # The wide table created previously
output_pop_summary_table = "population_summary"
output_outcome_summary_table = "outcome_summary"

# Define dynamic parameters (should match those used to create sa_input_table)
weight_loss_targets = [5, 10, 15]
time_windows = [40, 60, 80]

"""
HELPER FUNCTIONS
"""

def format_mean_sd(series, decimals=1):
    """Calculates mean and SD, returns formatted string 'mean ± SD'."""
    numeric_series = pd.to_numeric(series, errors='coerce')
    if numeric_series.isnull().all() or numeric_series.empty:
        return np.nan
    mean = numeric_series.mean()
    std = numeric_series.std()
    if pd.isna(mean) or pd.isna(std):
         return np.nan
    return f"{mean:.{decimals}f} ± {std:.{decimals}f}"

def format_n_percent(series, condition_value, total_n, decimals=1):
    """Calculates N and % matching a condition (primarily for strings),
       returns formatted string 'N (X.X%)'. Case-insensitive for strings.
    """
    if total_n == 0:
        return "0 (NaN%)"

    # Primarily designed for string comparison now
    condition_str = str(condition_value).lower()
    try:
        # Convert series to string, strip whitespace, convert to lower case, and compare
        condition_mask = series.astype(str).str.strip().str.lower().eq(condition_str)
        n = condition_mask.sum()
    except Exception as e: # Broad exception catch if string methods fail
        print(f"  Warning: Could not perform string comparison for series '{series.name}' with value '{condition_value}'. Error: {e}. Falling back to direct comparison.")
        # Fallback for non-string types or errors during string conversion
        try:
            condition_mask = (series == condition_value)
            n = condition_mask.sum()
        except TypeError:
             print(f"  Error: Type error during fallback comparison for series '{series.name}' with value '{condition_value}'. Setting N to 0.")
             n = 0
        except Exception as e_fallback:
             print(f"  Error: Unexpected error during fallback comparison for series '{series.name}' with value '{condition_value}'. Error: {e_fallback}. Setting N to 0.")
             n = 0

    percent = (n / total_n) * 100 if total_n > 0 else 0
    return f"{int(n)} ({percent:.{decimals}f}%)" # Ensure n is int for formatting


def get_describe_stats(df, columns):
    """Runs describe() and extracts key stats for specified columns."""
    actual_cols = [col for col in columns if col in df.columns]
    if not actual_cols:
        return pd.DataFrame()
    numeric_subset = df[actual_cols].select_dtypes(include=np.number)
    if numeric_subset.empty:
        # print(f"  Warning: No numeric columns found among {actual_cols} for describe.") # Less verbose
        return pd.DataFrame()
    described = numeric_subset.describe(percentiles=[.25, .5, .75]).transpose()
    stats_to_keep = ['count', 'mean', 'std', 'min', '25%', '50%', '75%', 'max']
    described = described[[col for col in stats_to_keep if col in described.columns]]
    return described

"""
SUMMARY TABLE GENERATION FUNCTIONS
"""

def generate_population_summary(df):
    """Generates the population summary data (Variables as index, Stats as columns)."""
    # NOTE: This function returns the data BEFORE transposition.
    if df.empty:
        print("Input DataFrame is empty. Cannot generate population summary.")
        return pd.DataFrame()

    total_n = len(df)
    summary_data_by_var = {}
    print(f"Generating population summary data (N={total_n})...")

    def add_stats(var_name, stats_dict):
        if var_name not in summary_data_by_var:
            summary_data_by_var[var_name] = {}
        summary_data_by_var[var_name].update(stats_dict)

    add_stats('Total Population', {'N': total_n})

    # --- Numerical Summaries ---
    numeric_cols_pop = [
        'age', 'height_m', 'baseline_weight_kg', 'last_aval_weight_kg',
        'baseline_bmi', 'final_bmi',
        'emotional_eating_value', 'quantity_control', 'impulse_control'
    ]
    pop_described = get_describe_stats(df, numeric_cols_pop)
    for col in pop_described.index:
        stats_for_col = {}
        stats_for_col['N'] = int(pop_described.loc[col, 'count']) if 'count' in pop_described.columns else np.nan
        stats_for_col['Mean (SD)'] = format_mean_sd(df[col])
        if '50%' in pop_described.columns and '25%' in pop_described.columns and '75%' in pop_described.columns:
             median, iqr_25, iqr_75 = pop_described.loc[col, ['50%', '25%', '75%']]
             if not pd.isna([median, iqr_25, iqr_75]).any():
                  stats_for_col['Median (IQR)'] = f"{median:.1f} ({iqr_25:.1f}-{iqr_75:.1f})"
             else: stats_for_col['Median (IQR)'] = np.nan
        else: stats_for_col['Median (IQR)'] = np.nan
        if 'min' in pop_described.columns and 'max' in pop_described.columns:
             min_val, max_val = pop_described.loc[col, ['min', 'max']]
             if not pd.isna([min_val, max_val]).any():
                  stats_for_col['Min - Max'] = f"{min_val:.1f} - {max_val:.1f}"
             else: stats_for_col['Min - Max'] = np.nan
        else: stats_for_col['Min - Max'] = np.nan
        add_stats(col, stats_for_col)

    # --- Categorical/Boolean Summaries ---
    # Sex ('Female') - Use format_n_percent for string check
    if 'sex' in df.columns:
        add_stats('Sex: Female', {'N (%)': format_n_percent(df['sex'], 'Female', total_n)})
    else: print("  Warning: 'sex' column not found.")

    # Yes/No Questions - Use format_n_percent for string check
    for col_name, display_name in [('hunger', 'Hunger: Yes'), ('satiety', 'Satiety: Yes'), ('emotional_eating', 'Emotional Eating: Yes')]:
        if col_name in df.columns:
             add_stats(display_name, {'N (%)': format_n_percent(df[col_name], 'Yes', total_n)})
        else: print(f"  Warning: '{col_name}' column not found.")

    # Availability Checks (Not NULL) - Direct calculation
    for col_name, display_name in [('weight_gain_cause', 'Weight Gain Cause Available'), ('genomics_sample_id', 'Genomics Sample ID Available')]:
        if col_name in df.columns:
            n_not_null = df[col_name].notna().sum()
            percent_not_null = (n_not_null / total_n) * 100 if total_n > 0 else 0
            add_stats(display_name, {'N (%)': f"{n_not_null} ({percent_not_null:.1f}%)"})
        else: print(f"  Warning: '{col_name}' column not found.")

    # Convert dictionary to DataFrame (Variables as index, Stats as columns)
    summary_df = pd.DataFrame.from_dict(summary_data_by_var, orient='index')
    summary_df.index.name = 'Variable'

    # Reorder STAT columns (optional, done before transpose)
    desired_col_order = ['N', 'N (%)', 'Mean (SD)', 'Median (IQR)', 'Min - Max']
    existing_cols_ordered = [col for col in desired_col_order if col in summary_df.columns]
    remaining_cols = [col for col in summary_df.columns if col not in existing_cols_ordered]
    final_col_order = existing_cols_ordered + remaining_cols
    summary_df = summary_df[final_col_order]

    return summary_df


def generate_outcome_summary(df, weight_targets, time_windows_list):
    """Generates the outcome summary data (Variables as index, Stats as columns)."""
    # NOTE: This function returns the data BEFORE transposition.
    if df.empty:
        print("Input DataFrame is empty. Cannot generate outcome summary.")
        return pd.DataFrame()

    total_n = len(df)
    summary_data_by_var = {}
    print(f"Generating outcome summary data (N={total_n})...")

    def add_stats(var_name, stats_dict):
        if var_name not in summary_data_by_var:
            summary_data_by_var[var_name] = {}
        for key, value in stats_dict.items():
            summary_data_by_var[var_name][key] = value

    add_stats('Total Population', {'N': total_n})

    # --- Overall Adherence & Outcome Summaries ---
    numeric_cols_outcome = [
        'total_followup_days', 'nr_visits', 'nr_total_measurements', 'avg_days_between_measurements',
        'total_wl_kg', 'total_wl_%', 'bmi_reduction'
    ]
    outcome_described = get_describe_stats(df, numeric_cols_outcome)
    for col in outcome_described.index:
        stats_for_col = {}
        stats_for_col['N'] = int(outcome_described.loc[col, 'count']) if 'count' in outcome_described.columns else np.nan
        stats_for_col['Mean (SD)'] = format_mean_sd(df[col])
        if '50%' in outcome_described.columns and '25%' in outcome_described.columns and '75%' in outcome_described.columns:
             median, iqr_25, iqr_75 = outcome_described.loc[col, ['50%', '25%', '75%']]
             if not pd.isna([median, iqr_25, iqr_75]).any():
                  stats_for_col['Median (IQR)'] = f"{median:.1f} ({iqr_25:.1f}-{iqr_75:.1f})"
             else: stats_for_col['Median (IQR)'] = np.nan
        else: stats_for_col['Median (IQR)'] = np.nan
        if 'min' in outcome_described.columns and 'max' in outcome_described.columns:
             min_val, max_val = outcome_described.loc[col, ['min', 'max']]
             if not pd.isna([min_val, max_val]).any():
                  stats_for_col['Min - Max'] = f"{min_val:.1f} - {max_val:.1f}"
             else: stats_for_col['Min - Max'] = np.nan
        else: stats_for_col['Min - Max'] = np.nan
        add_stats(col, stats_for_col)

    # --- Specific Outcome Metrics ---
    # Instant Dropouts (total_followup_days == 1) - Direct calculation assuming numeric/bool
    if 'total_followup_days' in df.columns:
        try:
            # Attempt direct comparison (works for numbers, might work for bools if 1 used)
            instant_dropout_mask = (df['total_followup_days'] == 1)
            n_instant_dropout = instant_dropout_mask.sum()
            percent_instant = (n_instant_dropout / total_n) * 100 if total_n > 0 else 0
            add_stats('Instant Dropouts', {'N (%)': f"{n_instant_dropout} ({percent_instant:.1f}%)"})
        except Exception as e:
            print(f"  Warning: Could not calculate Instant Dropouts directly. Error: {e}. Trying format_n_percent.")
            # Fallback to string comparison if direct fails
            add_stats('Instant Dropouts', {'N (%)': format_n_percent(df['total_followup_days'], 1, total_n)})
    else: print("  Warning: 'total_followup_days' column not found for Instant Dropout calculation.")

    # --- Dynamic Time Window Metrics ---
    for window in time_windows_list:
        dropout_col = f'{window}d_dropout'
        wl_kg_col = f'wl_{window}d_kg'
        wl_pct_col = f'wl_{window}d_%'
        completer_var_name = f'Completers at {window}d'
        wl_kg_var_name = f'WL (kg) at {window}d [Completers]'
        wl_pct_var_name = f'WL (%) at {window}d [Completers]'

        if dropout_col in df.columns:
            # N (%) Completers (Not Dropout) - Direct calculation assuming boolean False
            try:
                completers_mask = (df[dropout_col] == False) # Explicitly check for boolean False
                n_completers = completers_mask.sum()
                percent_completers = (n_completers / total_n) * 100 if total_n > 0 else 0
                add_stats(completer_var_name, {'N (%)': f"{n_completers} ({percent_completers:.1f}%)"})

                # Mean (SD) Weight Loss for Completers
                completers_df = df.loc[completers_mask].copy()
                if not completers_df.empty:
                    if wl_kg_col in completers_df.columns:
                        add_stats(wl_kg_var_name, {
                            'Mean (SD)': format_mean_sd(completers_df[wl_kg_col]),
                            'N': len(completers_df[wl_kg_col].dropna())
                        })
                    else:
                        print(f"  Warning: '{wl_kg_col}' column not found.")
                        add_stats(wl_kg_var_name, {'Mean (SD)': np.nan, 'N': 0})
                    if wl_pct_col in completers_df.columns:
                        add_stats(wl_pct_var_name, {
                            'Mean (SD)': format_mean_sd(completers_df[wl_pct_col]),
                            'N': len(completers_df[wl_pct_col].dropna())
                        })
                    else:
                        print(f"  Warning: '{wl_pct_col}' column not found.")
                        add_stats(wl_pct_var_name, {'Mean (SD)': np.nan, 'N': 0})
                else:
                    # print(f"  Note: No completers found for {window}d window to calculate WL stats.") # Less verbose
                    add_stats(wl_kg_var_name, {'Mean (SD)': np.nan, 'N': 0})
                    add_stats(wl_pct_var_name, {'Mean (SD)': np.nan, 'N': 0})

            except Exception as e:
                 print(f"  Error calculating completer stats for {window}d. Column '{dropout_col}' might not be boolean. Error: {e}")
                 add_stats(completer_var_name, {'N (%)': 'Error'})
                 add_stats(wl_kg_var_name, {'Mean (SD)': 'Error', 'N': 'Error'})
                 add_stats(wl_pct_var_name, {'Mean (SD)': 'Error', 'N': 'Error'})

        else:
            print(f"  Warning: Dropout column '{dropout_col}' not found for window {window}d.")
            add_stats(completer_var_name, {'N (%)': np.nan})
            add_stats(wl_kg_var_name, {'Mean (SD)': np.nan, 'N': np.nan})
            add_stats(wl_pct_var_name, {'Mean (SD)': np.nan, 'N': np.nan})


    # --- Dynamic Weight Loss Target Metrics ---
    for target in weight_targets:
        achieved_col = f'{target}%_wl_achieved'
        days_col = f'days_to_{target}%_wl'
        achiever_var_name = f'Achieved {target}% WL'
        days_var_name = f'Days to {target}% WL [Achievers]'

        if achieved_col in df.columns:
            # N (%) Achievers - Direct calculation assuming boolean True
            try:
                achievers_mask = (df[achieved_col] == True) # Explicitly check for boolean True
                n_achievers = achievers_mask.sum()
                percent_achievers = (n_achievers / total_n) * 100 if total_n > 0 else 0
                add_stats(achiever_var_name, {'N (%)': f"{n_achievers} ({percent_achievers:.1f}%)"})

                # Mean (SD) Days to Achievement (for Achievers)
                if days_col in df.columns:
                    achievers_df = df.loc[achievers_mask].copy()
                    if not achievers_df.empty:
                        add_stats(days_var_name, {
                            'Mean (SD)': format_mean_sd(achievers_df[days_col]),
                            'N': len(achievers_df[days_col].dropna())
                        })
                    else:
                        # print(f"  Note: No achievers found for {target}% target to calculate days.") # Less verbose
                        add_stats(days_var_name, {'Mean (SD)': np.nan, 'N': 0})
                else:
                    print(f"  Warning: Days column '{days_col}' not found for target {target}%.")
                    add_stats(days_var_name, {'Mean (SD)': np.nan, 'N': np.nan})

            except Exception as e:
                 print(f"  Error calculating achiever stats for {target}%. Column '{achieved_col}' might not be boolean. Error: {e}")
                 add_stats(achiever_var_name, {'N (%)': 'Error'})
                 add_stats(days_var_name, {'Mean (SD)': 'Error', 'N': 'Error'})

        else:
            print(f"  Warning: Achievement column '{achieved_col}' not found for target {target}%.")
            add_stats(achiever_var_name, {'N (%)': np.nan})
            add_stats(days_var_name, {'Mean (SD)': np.nan, 'N': np.nan})


    # Convert dictionary to DataFrame (Variables as index, Stats as columns)
    summary_df = pd.DataFrame.from_dict(summary_data_by_var, orient='index')
    summary_df.index.name = 'Variable'

    # Reorder STAT columns (optional, done before transpose)
    desired_col_order = ['N', 'N (%)', 'Mean (SD)', 'Median (IQR)', 'Min - Max']
    existing_cols_ordered = [col for col in desired_col_order if col in summary_df.columns]
    remaining_cols = [col for col in summary_df.columns if col not in existing_cols_ordered]
    final_col_order = existing_cols_ordered + remaining_cols
    summary_df = summary_df[final_col_order]

    return summary_df


"""
MAIN ORCHESTRATION FUNCTION
"""

def create_and_save_summary_tables(db_path, input_table, pop_table_out, outcome_table_out, weight_targets, time_windows_list):
    """Loads data, generates both summary tables, transposes for final output, and saves them."""
    conn = None
    try:
        print(f"\nConnecting to database: {db_path}")
        if not os.path.exists(db_path):
            print(f"ERROR: Database file not found at {db_path}")
            return

        conn = sqlite3.connect(db_path)
        print(f"Loading input table: {input_table}")
        df = pd.read_sql_query(f"SELECT * FROM {input_table}", conn)

        if df.empty:
            print(f"Input table '{input_table}' is empty. Cannot generate summaries.")
            return

        # --- Generate Population Summary ---
        pop_summary_raw = generate_population_summary(df.copy())
        if not pop_summary_raw.empty:
            print(f"\n--- Population Summary ({pop_table_out}) ---")
            # Transpose for final output (Stats as index/rows, Variables as columns)
            pop_summary_final = pop_summary_raw #.transpose() OR MAYBE DON'T, Not as nice!
            pop_summary_final.index.name = 'Statistic' # Index is now stats
            print(pop_summary_final.to_string())
            print(f"\nSaving population summary to table: {pop_table_out}")
            # Save the transposed DataFrame, index=True saves 'Statistic' column
            pop_summary_final.to_sql(pop_table_out, conn, if_exists='replace', index=True)
        else:
            print("Population summary generation failed or resulted in an empty table.")

        # --- Generate Outcome Summary ---
        outcome_summary_raw = generate_outcome_summary(df.copy(), weight_targets, time_windows_list)
        if not outcome_summary_raw.empty:
            print(f"\n--- Outcome Summary ({outcome_table_out}) ---")
            # Transpose for final output (Stats as index/rows, Variables as columns)

            outcome_summary_final = outcome_summary_raw # .transpose() OR MAYBE DON'T, it is not as nice
            outcome_summary_final.index.name = 'Statistic' # Index is now stats
            print(outcome_summary_final.to_string())
            print(f"\nSaving outcome summary to table: {outcome_table_out}")
            # Save the transposed DataFrame, index=True saves 'Statistic' column
            outcome_summary_final.to_sql(outcome_table_out, conn, if_exists='replace', index=True)
        else:
            print("Outcome summary generation failed or resulted in an empty table.")

        conn.commit()
        print("\nSummary tables saved successfully.")

    except sqlite3.Error as e:
        print(f"ERROR: SQLite error - {e}")
        if conn: conn.rollback()
    except pd.errors.DatabaseError as e:
         print(f"ERROR: Pandas/Database error during SQL operation - {e}")
         if conn: conn.rollback()
    except KeyError as e:
        print(f"ERROR: A required column name was not found in the DataFrame: {e}")
    except Exception as e:
        print(f"ERROR: An unexpected error occurred - {e}")
        import traceback
        print(traceback.format_exc())
        if conn: conn.rollback()
    finally:
        if conn:
            conn.close()
            print("Database connection closed.")

"""
EXECUTION
"""
# Make sure 'paper1_directory' is defined before this point!
if 'paper1_directory' in locals() or 'paper1_directory' in globals():
    create_and_save_summary_tables(
        db_path,
        input_table_name,
        output_pop_summary_table,
        output_outcome_summary_table,
        weight_loss_targets,
        time_windows
    )
else:
    print("ERROR: 'paper1_directory' variable is not defined. Please define it before running this cell.")



Connecting to database: C:\Users\Felhasználó\Desktop\Projects\PNK_DB2\paper1_emotional\survival_analysis.sqlite
Loading input table: sa_input_table
Generating population summary data (N=1664)...

--- Population Summary (population_summary) ---
                                   N         N (%)    Mean (SD)      Median (IQR)     Min - Max
Statistic                                                                                      
Total Population              1664.0           NaN          NaN               NaN           NaN
age                           1620.0           NaN  47.5 ± 10.2  48.0 (41.0-54.0)   18.0 - 84.0
height_m                      1664.0           NaN    1.7 ± 0.1     1.6 (1.6-1.7)     1.4 - 2.1
baseline_weight_kg            1664.0           NaN  82.8 ± 12.1  82.0 (74.0-89.1)  57.5 - 131.2
last_aval_weight_kg           1664.0           NaN  76.3 ± 10.8  74.7 (68.4-82.4)  53.7 - 130.5
baseline_bmi                  1664.0           NaN   30.2 ± 3.0  30.1 (27.7-32.5)  

#### Next analysis