In [None]:
import pandas as pd
import altair as alt
import matplotlib.pyplot as plt

# Default DPI applied to matplotlib figures created after this point.
# NOTE(review): 2000 is very high, and the charts in the visible cells are built
# with Altair rather than matplotlib — confirm this setting is still needed.
plt.rcParams.update({'figure.dpi': 2000})


In [None]:
# Load the two Mistral result tables: the run without the DB first, then the run with it.
mistral_df_without_db, mistral_df_with_db = (
    pd.read_csv(csv_name)
    for csv_name in (
        'combined_augmented_mistral_nodb.csv',
        'combined_augmented_mistral.csv',
    )
)

In [None]:
# Print every column of the Mistral no-DB table whose name contains "Augmented".
augmented_column_names = [
    column_name
    for column_name in mistral_df_without_db.columns
    if "Augmented" in column_name
]
for column_name in augmented_column_names:
    print(column_name)

In [None]:
# Augmented column names referenced later for the Llama 3.1 tables. They were
# previously written as bare string expressions, which evaluate to nothing useful,
# so they are kept here as a note instead:
#   'Augmented_Age (years)'
#   'Augmented_Body Temperature (°C)'
#   'Augmented_Breathing Rate (breaths/min)'

# Llama 3.1 results: one table from the run with the DB, one from the run without it.
df_with_db = pd.read_csv("./llama3_augmented_results/combined_augmented_all_llama3_1.csv")
df_without_db = pd.read_csv("./llama3_augmented_results/combined_augmented_all_nodb_llama3_1.csv")

In [None]:
# NOTE(review): `maxbins` is assigned here (and reassigned in later cells), but no
# chart in the visible cells passes it to a `bin=` argument — confirm whether
# binning was intended or whether this variable can be removed.
maxbins = 20

In [None]:
# gt_age_chart = (
#     alt.Chart(df_with_db)
#     .mark_bar()
#     .encode(
#         alt.X("age:Q"),
#         y="count()",
#     )
#     .properties(
#         width=400,
#         height=400,
#         title={
#             "text": "Ground Truth",
#         },
#     )
# )

# # Save the chart with higher resolution

# gt_age_chart

In [None]:
# Age panels: one bar chart per data source, counting rows per age value.
# No `bin` is set on the x encodings.
def _age_panel(frame, field, heading):
    """Return a 200x100 bar chart of row counts over `field` in `frame`, titled `heading`."""
    bars = alt.Chart(frame).mark_bar()
    bars = bars.encode(x=alt.X(field), y="count()")
    return bars.properties(width=200, height=100, title={"text": heading})


# Ground truth and both Llama 3 variants
gt_age_chart = _age_panel(df_with_db, "age:Q", "Ground Truth")
db_age_chart = _age_panel(df_with_db, "Augmented_Age (years):Q", "LLama3 With DB")
no_db_age_chart = _age_panel(
    df_without_db, "Augmented_Age (years):Q", "LLama3 Without DB"
)

# Mistral variants (built here, but not part of the stacked figure below)
mistral_db_age_chart = _age_panel(
    mistral_df_with_db, "Augmented_Age:Q", "Mistral With GraphRAG"
)
mistral_no_db_age_chart = _age_panel(
    mistral_df_without_db, "Augmented_Age:Q", "Mistral Without GraphRAG"
)

# Stack ground truth above the two Llama 3 panels on one shared x scale,
# then title the whole stack.
age_combined_chart = (
    alt.vconcat(gt_age_chart, db_age_chart, no_db_age_chart)
    .resolve_scale(x="shared")
    .properties(title={"text": "Age"})
)
age_combined_chart

In [None]:
# Side-by-side comparison: ground truth next to the two Mistral panels, sharing one x scale.
_mistral_age_panels = [gt_age_chart, mistral_db_age_chart, mistral_no_db_age_chart]
mistral_age_combined_chart = alt.hconcat(*_mistral_age_panels).resolve_scale(x="shared")

# Export to PNG with scale_factor=8.0.
# NOTE(review): the file name is spelled "mistarl"; it is kept unchanged here because
# other material may reference the existing path — confirm before renaming.
mistral_age_combined_chart.save('mistarl_overall_combined_chart.png', scale_factor=8.0)

In [None]:
# Display the ground-truth / Mistral age comparison as the cell output.
mistral_age_combined_chart

In [None]:
# Derive a Celsius `body_temperature` column from `temperature`, which this
# conversion treats as Fahrenheit.
def _fahrenheit_to_celsius(degrees_f):
    """Convert a single Fahrenheit reading to Celsius."""
    return (degrees_f - 32) * 5.0 / 9.0


df_with_db["body_temperature"] = df_with_db["temperature"].apply(_fahrenheit_to_celsius)

In [None]:
maxbins = 100


# Body-temperature panels built from the augmented columns as loaded
# (a later cell rescales the with-DB column and rebuilds these charts).
def _temp_panel(frame, field, heading):
    """Return a 200x100 bar chart of row counts over `field` in `frame`, titled `heading`."""
    bars = alt.Chart(frame).mark_bar()
    bars = bars.encode(x=alt.X(field), y="count()")
    return bars.properties(width=200, height=100, title={"text": heading})


gt_temp_chart = _temp_panel(df_with_db, "body_temperature:Q", "Ground Truth")
db_temp_chart = _temp_panel(
    df_with_db, "Augmented_Body Temperature (°C):Q", "LLama3 With DB"
)
no_db_temp_chart = _temp_panel(
    df_without_db,
    "Augmented_Body Temperature (degrees Celsius):Q",
    "LLama3 Without DB",
)

# Stack the three panels on one shared x scale and title the stack.
temp_combined_chart = (
    alt.vconcat(gt_temp_chart, db_temp_chart, no_db_temp_chart)
    .resolve_scale(x="shared")
    .properties(title={"text": "Body Temperature"})
)

temp_combined_chart

In [None]:
# Count augmented temperature readings above 45 in each Llama 3 table.
with_db_temperature_outliers_count = (
    df_with_db["Augmented_Body Temperature (°C)"].gt(45).sum()
)
without_db_temperature_outliers_count = (
    df_without_db["Augmented_Body Temperature (degrees Celsius)"].gt(45).sum()
)

# Count ground-truth `temperature` readings below 45.
# NOTE(review): this uses the raw `temperature` column (converted from Fahrenheit
# elsewhere in this notebook) and a "<" test, unlike the two counts above — confirm
# that it is meant to find ground-truth rows that look like Celsius values.
gt_outliers_count = df_with_db["temperature"].lt(45).sum()

print(gt_outliers_count, with_db_temperature_outliers_count, without_db_temperature_outliers_count)

In [None]:
# Show the (min, max) of the raw ground-truth `temperature` column.
df_with_db["temperature"].min(), df_with_db["temperature"].max()

In [None]:
# Rescale with-DB augmented temperatures: any reading above 45 is treated as
# Fahrenheit and converted to Celsius; all other readings are left untouched.
# The column is overwritten in place.
def _calibrate_temperature(reading):
    """Convert `reading` from Fahrenheit to Celsius when it exceeds 45, else return it as is."""
    if reading > 45:
        return (reading - 32) * 5.0 / 9.0
    return reading


_augmented_temp_col = "Augmented_Body Temperature (°C)"
df_with_db[_augmented_temp_col] = df_with_db[_augmented_temp_col].apply(
    _calibrate_temperature
)

# 40 before calibration, but not after

In [None]:
# Body-temperature panels rebuilt for the final figure; these reassign the
# `*_temp_chart` names created by the earlier temperature cell.
maxbins = 20


def _fixed_temp_panel(frame, field, heading):
    """Return a 200x100 bar chart of row counts over `field` in `frame`, titled `heading`."""
    bars = alt.Chart(frame).mark_bar()
    bars = bars.encode(x=alt.X(field), y="count()")
    return bars.properties(width=200, height=100, title={"text": heading})


gt_temp_chart = _fixed_temp_panel(df_with_db, "body_temperature:Q", "Ground Truth")
db_temp_chart = _fixed_temp_panel(
    df_with_db, "Augmented_Body Temperature (°C):Q", "LLama3 With GraphRAG"
)
no_db_temp_chart = _fixed_temp_panel(
    df_without_db,
    "Augmented_Body Temperature (degrees Celsius):Q",
    "LLama3 Without GraphRAG",
)

# Stack the three panels on one shared x scale and title the stack.
fixed_temp_combined_chart = (
    alt.vconcat(gt_temp_chart, db_temp_chart, no_db_temp_chart)
    .resolve_scale(x="shared")
    .properties(title={"text": "Body Temperature"})
)
fixed_temp_combined_chart

In [None]:
# Respiratory-rate panels: ground truth (`resprate`) versus the augmented
# breathing-rate column from each Llama 3 table.
def _resprate_panel(frame, field, heading):
    """Return a 200x100 bar chart of row counts over `field` in `frame`, titled `heading`."""
    bars = alt.Chart(frame).mark_bar()
    bars = bars.encode(x=alt.X(field), y="count()")
    return bars.properties(width=200, height=100, title={"text": heading})


gt_resprate_chart = _resprate_panel(df_with_db, "resprate:Q", "Ground Truth")
db_resprate_chart = _resprate_panel(
    df_with_db, "Augmented_Breathing Rate (breaths/min):Q", "LLama3 With GraphRAG"
)
no_db_resprate_chart = _resprate_panel(
    df_without_db,
    "Augmented_Breathing Rate (breaths/min):Q",
    "LLama3 Without GraphRAG",
)

# Stack the three panels on one shared x scale and title the stack.
resprate_combined_chart = (
    alt.vconcat(gt_resprate_chart, db_resprate_chart, no_db_resprate_chart)
    .resolve_scale(x="shared")
    .properties(title={"text": "Respiratory Rate"})
)
resprate_combined_chart

In [None]:
def to_oxygen_saturation(P_O2, P50=26.6, n=2.8):
    """
    Map an oxygen partial pressure (mmHg) to an oxygen saturation percentage
    using the Hill equation: S_O2 = P_O2^n / (P_O2^n + P50^n) * 100.

    Parameters:
    - P_O2: Partial pressure of oxygen in mmHg
    - P50: Partial pressure at which the result is 50% (default: 26.6 mmHg)
    - n: Hill coefficient controlling how steep the curve is (default: 2.8)

    Returns:
    - Oxygen saturation as a percentage (%)
    """
    pressure_term = P_O2**n
    half_saturation_term = P50**n
    saturated_fraction = pressure_term / (pressure_term + half_saturation_term)
    return saturated_fraction * 100

In [None]:
# Add an `Augmented_o2Sat` column to both Llama 3 tables by converting each
# augmented oxygen-level reading (mmHg) with `to_oxygen_saturation`.
for _llama_frame in (df_with_db, df_without_db):
    _llama_frame["Augmented_o2Sat"] = _llama_frame[
        "Augmented_Oxygen levels (mmHg)"
    ].apply(to_oxygen_saturation)

In [None]:
# Oxygen-saturation panels: ground truth (`o2sat`) versus the converted
# `Augmented_o2Sat` column from each Llama 3 table.
def _o2sat_panel(frame, field, heading):
    """Return a 200x100 bar chart of row counts over `field` in `frame`, titled `heading`."""
    bars = alt.Chart(frame).mark_bar()
    bars = bars.encode(x=alt.X(field), y="count()")
    return bars.properties(width=200, height=100, title={"text": heading})


gt_o2sat_chart = _o2sat_panel(df_with_db, "o2sat:Q", "Ground Truth")
db_o2sat_chart = _o2sat_panel(df_with_db, "Augmented_o2Sat:Q", "LLama3 With GraphRAG")
no_db_o2sat_chart = _o2sat_panel(
    df_without_db, "Augmented_o2Sat:Q", "LLama3 Without GraphRAG"
)

# Stack the three panels on one shared x scale and title the stack.
o2sat_combined_chart = (
    alt.vconcat(gt_o2sat_chart, db_o2sat_chart, no_db_o2sat_chart)
    .resolve_scale(x="shared")
    .properties(title={"text": "Oxygen Saturation"})
)
o2sat_combined_chart


In [None]:

# Place the four per-feature stacks side by side and export the result as a PNG
# (scale_factor=8.0 is passed to `save`).
_feature_stacks = (
    age_combined_chart,
    fixed_temp_combined_chart,
    o2sat_combined_chart,
    resprate_combined_chart,
)
overall_combined_chart = alt.hconcat(*_feature_stacks)

overall_combined_chart.save('overall_combined_chart.png', scale_factor=8.0)


In [None]:
# Display the combined four-feature figure as the cell output.
overall_combined_chart

In [None]:
from sklearn.metrics import root_mean_squared_error
from scipy.stats import ks_2samp
from scipy.spatial.distance import jensenshannon
import pandas as pd


def calculate_jsd(p, q):
    return jensenshannon(p, q, base=2)


def get_rmse(
    df_with_db,
    df_without_db,
    gt_col,
    db_col,
    without_db_col,
    mistral_df_with_db=None,
    mistral_df_without_db=None,
    mistral_db_col=None,
    mistral_without_db_col=None,
):
    """
    Compare augmented columns with a ground-truth column and report the scores.

    For each comparison the function computes the RMSE, the two-sample
    Kolmogorov-Smirnov statistic with its p-value, and `calculate_jsd`. All
    scores are printed and also returned in a dict.

    Parameters:
    - df_with_db: table providing the ground-truth column `gt_col` and the
      with-DB augmented column `db_col`
    - df_without_db: table providing the without-DB augmented column
      `without_db_col`; it is scored against `df_with_db[gt_col]`
      (presumably the two tables are row-aligned — TODO confirm)
    - gt_col, db_col, without_db_col: column names as described above
    - mistral_df_with_db, mistral_df_without_db: optional Mistral tables; the
      Mistral scores are computed only when both are given, each table being
      scored against its own `gt_col`
    - mistral_db_col, mistral_without_db_col: augmented column names in the
      respective Mistral tables

    Returns:
    - dict mapping a label to each score; the Llama 3.1 entries are always
      present, the Mistral entries only when both Mistral tables were passed
    """

    def _score(reference, candidate):
        """Return (RMSE, KS statistic, KS p-value, JSD) for one pair of columns."""
        rmse = root_mean_squared_error(reference, candidate)
        ks_stat, ks_p = ks_2samp(reference, candidate)
        return rmse, ks_stat, ks_p, calculate_jsd(reference, candidate)

    ground_truth = df_with_db[gt_col]
    rmse_db, ks_db, p_db, jsd_db = _score(ground_truth, df_with_db[db_col])
    rmse_plain, ks_plain, p_plain, jsd_plain = _score(
        ground_truth, df_without_db[without_db_col]
    )

    print(f"LLama3.1 RMSE with DB: {rmse_db}")
    print(f"LLama3.1 RMSE without DB: {rmse_plain}")
    print(
        f"LLama3.1 KS Statistic with DB: {ks_db}, P-value with DB: {p_db}"
    )
    print(
        f"LLama3.1 KS Statistic without DB: {ks_plain}, P-value without DB: {p_plain}"
    )
    print(f"LLama3.1 Jensen-Shannon Divergence with DB: {jsd_db}")
    print(f"LLama3.1 Jensen-Shannon Divergence without DB: {jsd_plain}")

    output = {
        "LLama3.1 RMSE with DB": rmse_db,
        "LLama3.1 RMSE without DB": rmse_plain,
        "LLama3.1 KS Statistic with DB": ks_db,
        "LLama3.1 KS Statistic without DB": ks_plain,
        "LLama3.1 P-value with DB": p_db,
        "LLama3.1 P-value without DB": p_plain,
        "LLama3.1 Jensen-Shannon Divergence with DB": jsd_db,
        "LLama3.1 Jensen-Shannon Divergence without DB": jsd_plain,
    }

    # Mistral scores are optional: stop here unless both Mistral tables were supplied.
    if mistral_df_with_db is None or mistral_df_without_db is None:
        return output

    m_rmse_db, m_ks_db, m_p_db, m_jsd_db = _score(
        mistral_df_with_db[gt_col], mistral_df_with_db[mistral_db_col]
    )
    m_rmse_plain, m_ks_plain, m_p_plain, m_jsd_plain = _score(
        mistral_df_without_db[gt_col], mistral_df_without_db[mistral_without_db_col]
    )

    print(f"Mistral RMSE with DB: {m_rmse_db}")
    print(f"Mistral RMSE without DB: {m_rmse_plain}")
    print(
        f"Mistral KS Statistic with DB: {m_ks_db}, P-value with DB: {m_p_db}"
    )
    print(
        f"Mistral KS Statistic without DB: {m_ks_plain}, P-value without DB: {m_p_plain}"
    )
    print(f"Mistral Jensen-Shannon Divergence with DB: {m_jsd_db}")
    print(f"Mistral Jensen-Shannon Divergence without DB: {m_jsd_plain}")

    output.update(
        {
            "Mistral RMSE with GraphRAG": m_rmse_db,
            "Mistral RMSE without GraphRAG": m_rmse_plain,
            "Mistral KS Statistic with GraphRAG": m_ks_db,
            "Mistral KS Statistic without GraphRAG": m_ks_plain,
            "Mistral P-value with GraphRAG": m_p_db,
            "Mistral P-value without GraphRAG": m_p_plain,
            "Mistral Jensen-Shannon Divergence with GraphRAG": m_jsd_db,
            "Mistral Jensen-Shannon Divergence without GraphRAG": m_jsd_plain,
        }
    )

    return output


# One entry per evaluated feature: the display name plus the keyword arguments
# telling `get_rmse` which columns to compare. Only "Age" has Mistral columns.
feature_comparisons = [
    (
        "Age",
        {
            "gt_col": "age",
            "db_col": "Augmented_Age (years)",
            "without_db_col": "Augmented_Age (years)",
            "mistral_df_with_db": mistral_df_with_db,
            "mistral_df_without_db": mistral_df_without_db,
            "mistral_db_col": "Augmented_Age",
            "mistral_without_db_col": "Augmented_Age",
        },
    ),
    (
        "Body Temperature",
        {
            "gt_col": "body_temperature",
            "db_col": "Augmented_Body Temperature (°C)",
            "without_db_col": "Augmented_Body Temperature (degrees Celsius)",
        },
    ),
    (
        "Respiratory Rate",
        {
            "gt_col": "resprate",
            "db_col": "Augmented_Breathing Rate (breaths/min)",
            "without_db_col": "Augmented_Breathing Rate (breaths/min)",
        },
    ),
    (
        "Oxygen Saturation",
        {
            "gt_col": "o2sat",
            # Score the converted `Augmented_o2Sat` column (built from the mmHg
            # readings with `to_oxygen_saturation` and used by the o2sat charts),
            # not the raw "Augmented_Oxygen levels (mmHg)" values: the raw column
            # is a partial pressure and is not comparable with `o2sat`
            # (presumably a percentage — confirm against the data dictionary).
            "db_col": "Augmented_o2Sat",
            "without_db_col": "Augmented_o2Sat",
        },
    ),
]

results = []
for feature_name, column_spec in feature_comparisons:
    print("=" * 40)
    print(f"Running feature: {feature_name}")
    results.append(
        {
            "Feature": feature_name,
            **get_rmse(df_with_db, df_without_db, **column_spec),
        }
    )

# Convert results to DataFrame (one column per feature after the transpose) and save to CSV
results_df = pd.DataFrame(results, index=None).transpose()
results_df.to_csv("augmentation_eval.csv")

In [None]:
# Display the evaluation table that was just written to augmentation_eval.csv.
results_df

In [None]:
# Inspect the Mistral with-DB table.
# NOTE(review): this displays the whole frame; `.head()` is usually enough.
mistral_df_with_db

In [None]:
# Inspect the ground-truth `age` column of the Mistral with-DB table.
mistral_df_with_db['age']

In [None]:
# Inspect the `Augmented_Age` column of the Mistral with-DB table.
mistral_df_with_db['Augmented_Age']

In [None]:
# from scipy.stats import ks_2samp

# # Perform the Kolmogorov-Smirnov test
# ks_statistic_with_db, p_value_with_db = ks_2samp(df_with_db['age'], df_with_db["Augmented_Age (years)"])
# ks_statistic_without_db, p_value_without_db = ks_2samp(df_with_db['age'], df_without_db["Augmented_Age (years)"])


# print(ks_statistic_with_db, p_value_with_db)
# print(ks_statistic_without_db, p_value_without_db)

# from sklearn.metrics import root_mean_squared_error

# def get_rmse(
#     df_with_db,
#     df_without_db,
#     gt_col,
#     db_col,
#     without_db_col,
# ):
#     with_db_mse = root_mean_squared_error(df_with_db[gt_col], df_with_db[db_col])
#     without_db_mse = root_mean_squared_error(
#         df_with_db[gt_col], df_without_db[without_db_col]
#     )

#     ks_statistic_with_db, p_value_with_db = ks_2samp(df_with_db[gt_col], df_with_db[db_col])
#     ks_statistic_without_db, p_value_without_db = ks_2samp(df_with_db[gt_col], df_without_db[without_db_col])

#     print("RMSE with DB:", with_db_mse)
#     print("RMSE without DB:", without_db_mse)
#     print(f"KS Statistic with DB: [{ks_statistic_with_db}], P-value: [{p_value_with_db}]")
#     print(f"KS Statistic without DB: [{ks_statistic_without_db}], P-value: [{p_value_without_db}]")

# print("="*50)
# print("Age")
# get_rmse(
#     df_with_db,
#     df_without_db,
#     "age",
#     "Augmented_Age (years)",
#     "Augmented_Age (years)",
# )
# print("="*50)
# print("Body Temperature")
# get_rmse(
#     df_with_db,
#     df_without_db,
#     "temperature",
#     "Augmented_Body Temperature (°C)",
#     "Augmented_Body Temperature (degrees Celsius)",
# )
# print("="*50)
# print("Body Temperature Fixed")
# get_rmse(
#     df_with_db,
#     df_without_db,
#     "temperature",
#     "Augmented_Body Temperature (°C) fixed",
#     "Augmented_Body Temperature (degrees Celsius)",
# )
# print("="*50)
# print("Respiratory Rate")
# get_rmse(
#     df_with_db,
#     df_without_db,
#     "resprate",
#     "Augmented_Breathing Rate (breaths/min)",
#     "Augmented_Breathing Rate (breaths/min)",
# )
# print("="*50)
# print("Oxygen Saturation")
# get_rmse(
#     df_with_db,
#     df_without_db,
#     "o2sat",
#     "Augmented_Oxygen levels (mmHg)",
#     "Augmented_Oxygen levels (mmHg)",
# )