In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from src.database import Database
from src.config import DATA_DIR

In [None]:
# Load every result for (test_name, unit) pairs recorded more than once, since
# a single data point cannot show a trend.
# - Rows whose `unit` is NULL never satisfy the row-value `IN` (a comparison
#   with NULL is not true), so they are excluded.
# - The LEFT JOIN keeps results that have no matching `report` row; any
#   report-side columns are NULL for those rows. NOTE(review): if
#   `collection_date` lives on `report`, such rows get a NULL date.
MULTI_POINT_RESULTS_SQL = """                                
    SELECT collection_date, test_name, result, reference_min, reference_max, unit, out_of_range 
    FROM test_result
    LEFT JOIN report ON test_result.report_id = report.id
    WHERE (test_name, unit) IN (
        SELECT test_name, unit
        FROM test_result
        GROUP BY test_name, unit
        HAVING COUNT(*) > 1
    )
    ORDER BY test_name, collection_date;
"""

db = Database(DATA_DIR / "lab_tracker.db")
conn = db.connect()
try:
    df = pd.read_sql_query(MULTI_POINT_RESULTS_SQL, conn)
finally:
    # The result set is fully materialised in `df`; release the DB handle
    # instead of leaking it for the lifetime of the kernel.
    conn.close()

In [None]:
# Derived columns used by the plots below. Re-running this cell is safe: every
# column is recomputed from the raw query columns.
# Non-numeric results (e.g. "<5") become NaN in `result_numeric`.
df["result_numeric"] = pd.to_numeric(df["result"], errors="coerce")
# Facet key: the same test reported in different units gets its own panel.
df["test_label"] = df["test_name"] + " (" + df["unit"] + ")"
df["collection_date"] = pd.to_datetime(df["collection_date"])

# Set to True to drop rows whose result is not numeric (e.g. CRP, or values
# such as "<5"). By default they are kept; a test whose results are all
# non-numeric then plots no points in its panel.
DROP_NON_NUMERIC_RESULTS = False
df_plot = df[df["result_numeric"].notna()] if DROP_NON_NUMERIC_RESULTS else df

In [None]:
# Reference range per facet.
# NOTE(review): `.first()` takes the first non-null value of each column
# independently, so if a test's reference range changed over time only the
# earliest bounds are drawn. Confirm that is acceptable.
ref_lookup = (
    df_plot.groupby("test_label")[["reference_min", "reference_max"]]
    .first()
    .to_dict("index")
)

# Every distinct collection date in the dataset, sorted, as pandas Timestamps.
# Undated rows (NaT) are dropped because they cannot sit on a time axis and
# would break `strftime` / `set_xlim`. `.tolist()` yields Timestamps on every
# pandas version, whereas `.unique()` returns raw numpy datetime64 values
# (no `.strftime`) on pandas < 2.0.
all_dates = (
    df_plot["collection_date"].dropna().drop_duplicates().sort_values().tolist()
)

# Shared x-axis settings, computed once instead of once per facet.
X_PADDING = pd.Timedelta(days=30)
if all_dates:
    x_limits = (all_dates[0] - X_PADDING, all_dates[-1] + X_PADDING)
    x_tick_labels = [d.strftime("%Y-%m-%d") for d in all_dates]

# One small line plot per (test, unit) pair
g = sns.FacetGrid(df_plot, col="test_label", col_wrap=5, sharey=False)
g.map(sns.lineplot, "collection_date", "result_numeric", marker="o")
g.set_titles("{col_name}")
g.set_axis_labels("Collection date", "Result")

# Overlay reference ranges and normalise the axes of every facet.
# `axes_dict` maps each facet value to its Axes, so no title parsing is needed.
for label, ax in g.axes_dict.items():
    refs = ref_lookup.get(label, {})
    ref_min = refs.get("reference_min")
    ref_max = refs.get("reference_max")

    # x in axes coordinates (1.0 = right edge), y in data coordinates: the
    # arrow markers stay pinned to the right edge even after x-limits change.
    right_edge = ax.get_yaxis_transform()

    if pd.notna(ref_min) and pd.notna(ref_max):
        # Both bounds known: shade the normal band.
        ax.axhspan(ref_min, ref_max, alpha=0.2, color="green")
    elif pd.notna(ref_max):
        # Only an upper bound: values should be BELOW this line.
        ax.axhline(ref_max, color="green", linestyle="--")
        ax.text(
            1.0, ref_max, " ↓",
            transform=right_edge, va="center", fontsize=12, color="green",
        )
    elif pd.notna(ref_min):
        # Only a lower bound: values should be ABOVE this line.
        ax.axhline(ref_min, color="green", linestyle="--")
        ax.text(
            1.0, ref_min, " ↑",
            transform=right_edge, va="center", fontsize=12, color="green",
        )

    # Add 20% head/foot room so points and reference lines aren't on the border.
    ymin, ymax = ax.get_ylim()
    padding = (ymax - ymin) * 0.2
    ax.set_ylim(ymin - padding, ymax + padding)

    # Tick every collection date in the dataset, not just this test's dates.
    if all_dates:
        ax.set_xlim(*x_limits)
        ax.set_xticks(all_dates)
        ax.set_xticklabels(x_tick_labels, rotation=45, ha="right", fontsize=8)

plt.tight_layout()