## Firebase setup

In [1]:
import firebase_admin
from firebase_admin import credentials
from firebase_admin import firestore
import os

# Service-account JSON used to authenticate the Admin SDK. Prefer setting the
# FIREBASE_SERVICE_ACCOUNT_KEY_PATH environment variable over relying on the
# repo-local default file name (and keep the key file out of version control).
SERVICE_ACCOUNT_KEY_PATH = os.environ.get(
    "FIREBASE_SERVICE_ACCOUNT_KEY_PATH",
    "./cstam2-1f2ec-firebase-adminsdk-fbsvc-2ab61a7ed6.json",
)

# get_app() raises ValueError when no default app exists yet; probing it first
# lets this cell be re-run in the same kernel without re-initialising.
try:
    app = firebase_admin.get_app()
except ValueError:
    app = None

if app is None:
    try:
        cred = credentials.Certificate(SERVICE_ACCOUNT_KEY_PATH)
        app = firebase_admin.initialize_app(cred)
    except Exception as e:
        print(f"Error during Firebase Admin SDK initialization: {e}")
        # Nothing below can reach Firestore without an app, so stop here.
        raise
    print("Firebase Admin SDK initialized successfully.")
else:
    print("Firebase Admin SDK already initialized. Reusing existing app instance.")

db = firestore.client(app=app)


An error occurred: module 'importlib.metadata' has no attribute 'packages_distributions'
Firebase Admin SDK initialized successfully.




## Spark setup

In [2]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when, lit, to_timestamp, from_unixtime
from pyspark.sql.types import StructType, StructField, StringType, TimestampType, IntegerType, DoubleType, BooleanType

import time
import logging

# Configure logger
# Logger for the cleaning stage. The console handler is attached only when
# hasHandlers() finds none (that check also looks at ancestor loggers).
logger = logging.getLogger("DataCleaning")
logger.setLevel(logging.INFO)
formatter = logging.Formatter("[%(asctime)s] %(levelname)s - %(message)s")
handler = logging.StreamHandler()
handler.setFormatter(formatter)
if not logger.hasHandlers():
    logger.addHandler(handler)

# ---------- Spark session ----------
spark = SparkSession.builder.appName("health-streams-to-firebase").getOrCreate()
# Only ERROR and above from Spark itself; keeps notebook output readable.
spark.sparkContext.setLogLevel("ERROR")
# NOTE(review): food_df is not referenced in the visible cells -- confirm it is
# still needed.
food_df = None

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/11/14 20:18:13 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


## Import data and calculate daily metrics

### helper functions

In [3]:
import os
import time
import traceback

from pyspark.sql.types import (
    StructType,
    StructField,
    StringType,
    DoubleType,
)
from pyspark.sql.functions import (
    col,
    avg,
    expr,
    to_timestamp,when, sum as _sum,
)


def get_user_ids(db, collection_name="users", debug=False, group_limit=500):
    """
    Robustly enumerate user document IDs.

    Strategies are tried in order; the first one that yields any IDs wins:
      1) ``db.collection(collection_name).list_documents()``
      2) ``db.collection(collection_name).get()``
      3) ``db.collection_group(collection_name)`` -- searches every collection
         named ``collection_name`` anywhere in the database, reading at most
         ``group_limit`` documents. NOTE: these are IDs of nested documents, so
         callers that later address ``/<collection_name>/{id}`` may not find them.

    Args:
        db: Firestore client.
        collection_name: name of the collection holding user documents.
        debug: print which strategy was used and why earlier ones failed.
        group_limit: cap on documents read by the collection_group fallback;
            ``None`` means no cap.

    Returns:
        list[str] of user IDs (duplicates removed, first-seen order kept).
        An empty list when nothing is found or every strategy fails.
    """
    try:
        coll_ref = db.collection(collection_name)

        # 1) list_documents(): lightweight, yields DocumentReference objects.
        try:
            ids = [ref.id for ref in coll_ref.list_documents()]
            if ids:
                if debug:
                    print(f"‚úÖ get_user_ids: using list_documents(), found {len(ids)} users.")
                return ids
            if debug:
                print("‚ÑπÔ∏è list_documents() returned 0 refs; falling back to get().")
        except Exception as e:
            if debug:
                print("‚ö†Ô∏è list_documents() failed:", repr(e))

        # 2) get(): reads the documents themselves.
        try:
            ids = [snap.id for snap in coll_ref.get()]
            if ids:
                if debug:
                    print(f"‚úÖ get_user_ids: using get(), found {len(ids)} users.")
                return ids
            if debug:
                print("‚ÑπÔ∏è collection.get() returned 0 documents; falling back to collection_group().")
        except Exception as e:
            if debug:
                print("‚ö†Ô∏è collection.get() failed:", repr(e))

        # 3) collection_group(): nested '<collection_name>' collections anywhere.
        try:
            query = db.collection_group(collection_name)
            if group_limit is not None:
                query = query.limit(group_limit)
            # These are the documents' own IDs (not their parents'). The same ID
            # can live under several parents, so de-duplicate, keeping order.
            ids = list(dict.fromkeys(snap.reference.id for snap in query.stream()))
            if ids:
                if debug:
                    print(f"‚úÖ get_user_ids: using collection_group('{collection_name}'), found {len(ids)} documents.")
                return ids
            if debug:
                print(f"‚ÑπÔ∏è collection_group('{collection_name}') returned 0 results.")
        except Exception as e:
            if debug:
                print("‚ö†Ô∏è collection_group() failed:", repr(e))

    except Exception as e:
        if debug:
            print("‚ùå get_user_ids: unexpected error:", repr(e))

    if debug:
        print("‚ùå get_user_ids: no users found by any method.")
    return []

# -------------------------
# Example: read_all_users using get_user_ids
# -------------------------
def read_all_users(db, collection_name="users", batch_size=500, debug=False, throttle_s=0.1):
    """
    Read every user's metric subcollections and return all records in one list.

    User IDs come from ``get_user_ids``. Each user's records come from
    ``fetch_user_data(db, user_id, page_size=batch_size)``, called without
    ``debug`` so its per-document output does not flood the notebook.

    Args:
        db: Firestore client.
        collection_name: collection enumerated by ``get_user_ids``.
        batch_size: Firestore page size forwarded to ``fetch_user_data``.
        debug: print per-user progress and users without data.
        throttle_s: seconds to sleep after each successfully read user (simple
            rate limiting); 0 disables the pause.

    Returns:
        list[dict] of records as produced by ``fetch_user_data``. Best-effort:
        a user whose read fails is reported and skipped. An unexpected error
        outside the per-user loop is reported and yields ``[]``.
    """
    try:
        user_ids = get_user_ids(db, collection_name=collection_name, debug=debug)
        if debug:
            print(f"üìã Found {len(user_ids)} user doc ids.")

        all_records = []
        for user_id in user_ids:
            if debug:
                print(f"Processing user: {user_id}")
            try:
                user_records = fetch_user_data(db, user_id, page_size=batch_size)
                if not user_records and debug:
                    print(f"‚ÑπÔ∏è No subcollection docs for user {user_id}.")
                all_records.extend(user_records)
                if throttle_s:
                    time.sleep(throttle_s)
            except Exception as e:
                # Always surface this: a silently dropped user skews every
                # downstream aggregate.
                print(f"‚ö†Ô∏è Failed reading data for user {user_id}: {e}")
        return all_records

    except Exception as e:
        print("‚ùå read_all_users error:", repr(e))
        return []
# --- Pagination helper ---
def paginate_collection(collection_ref, page_size=500):
    """
    Yield the documents of a Firestore collection/query in pages.

    Each page is a list of at most ``page_size`` snapshots. The next page starts
    after the last snapshot of the previous one. Iteration stops at the first
    empty or short page, so no extra round-trip is made after a partial page.

    NOTE(review): no explicit ``order_by`` is applied; this relies on the
    client's default ordering being stable for ``start_after`` cursors.

    Args:
        collection_ref: object supporting ``limit(n)``, ``start_after(snapshot)``
            and ``stream()`` (e.g. a Firestore CollectionReference).
        page_size: documents per page; must be >= 1.

    Yields:
        list of document snapshots.

    Raises:
        ValueError: if ``page_size`` < 1 (raised on first iteration).

    Query errors are printed and end the iteration early (best-effort). Pages
    already yielded are kept.
    """
    if page_size < 1:
        raise ValueError(f"page_size must be >= 1, got {page_size!r}")
    try:
        query = collection_ref.limit(page_size)
        while True:
            batch = list(query.stream())
            if not batch:
                break
            yield batch
            if len(batch) < page_size:
                # A short page means the collection is exhausted.
                break
            query = collection_ref.start_after(batch[-1]).limit(page_size)
    except Exception as e:
        print("‚ö†Ô∏è Pagination failed:", e)

# --- Fetch a single user's subcollections with debugging ---
def fetch_user_data(db, user_id, page_size=500, debug=False,
                    collection_name="users",
                    metrics=("calories", "heart_rate", "steps")):
    """
    Read one user's metric subcollections and flatten them into records.

    For every name in ``metrics``, the subcollection
    ``/<collection_name>/{user_id}/<metric>`` is paged through with
    ``paginate_collection``. A document is kept when it has both:
      * a timestamp-like field: first of timestamp, time, created_at, date;
      * a value-like field: first of value, heart_rate, calories, steps, count.

    Args:
        db: Firestore client.
        user_id: document ID under ``collection_name``.
        page_size: documents per Firestore page.
        debug: print per-metric counts, one sample document per metric, and
            every skipped document.
        collection_name: top-level collection holding user documents.
        metrics: subcollection names to read; each is also the record's
            ``metric_type``.

    Returns:
        list[dict] with keys ``user_id``, ``metric_type``, ``timestamp``,
        ``value``.
        * ``timestamp`` is normalised to a string so it fits the StringType
          column in ``build_dataframe``: objects with ``isoformat()`` (e.g.
          date/datetime) use it, others use ``str()``, and ``None`` stays ``None``.
        * A ``None`` value becomes 0.0.
        * A non-numeric value skips only that document; previously it aborted
          the rest of the metric.
        Read errors for a metric are printed; records gathered for other
        metrics are still returned.
    """
    timestamp_keys = ("timestamp", "time", "created_at", "date")
    value_keys = ("value", "heart_rate", "calories", "steps", "count")

    user_ref = db.collection(collection_name).document(user_id)
    records = []
    for metric in metrics:
        try:
            docs = [
                doc
                for batch in paginate_collection(user_ref.collection(metric), page_size)
                for doc in batch
            ]
            if debug:
                print(f"üß© User {user_id} | {metric}: {len(docs)} docs")

            sample_shown = False
            bad_values = 0
            for doc in docs:
                data = doc.to_dict() or {}
                if debug and data and not sample_shown:
                    print(f"   ‚Ü≥ sample doc: {data}")
                    sample_shown = True

                timestamp_field = next((k for k in timestamp_keys if k in data), None)
                value_field = next((k for k in value_keys if k in data), None)
                if timestamp_field is None or value_field is None:
                    if debug:
                        print(f"   ‚ö†Ô∏è Missing timestamp/value in {metric} doc: {data}")
                    continue

                raw_value = data[value_field]
                try:
                    value = float(raw_value) if raw_value is not None else 0.0
                except (TypeError, ValueError):
                    bad_values += 1
                    if debug:
                        print(f"   Non-numeric {value_field}={raw_value!r} in {metric} doc; skipped")
                    continue

                ts = data[timestamp_field]
                if ts is not None and not isinstance(ts, str):
                    ts = ts.isoformat() if hasattr(ts, "isoformat") else str(ts)

                records.append({
                    "user_id": user_id,
                    "metric_type": metric,
                    "timestamp": ts,
                    "value": value,
                })
            if bad_values:
                print(f"Skipped {bad_values} {metric} docs with non-numeric values for user {user_id}")
        except Exception as e:
            print(f"‚ö†Ô∏è Failed to read {metric} for user {user_id}: {e}")
    return records



# --- Build DataFrame safely ---
def build_dataframe(spark, raw_data, timestamp_format=None):
    """
    Create the long-format metrics DataFrame and parse its ``timestamp`` column.

    Args:
        spark: active SparkSession.
        raw_data: list of record dicts (``user_id``, ``metric_type``,
            ``timestamp`` string, ``value`` float), as produced by
            ``fetch_user_data``. ``None`` is treated as empty.
        timestamp_format: optional Spark datetime pattern, e.g.
            ``"yyyy-MM-dd'T'HH:mm:ss"``. When ``None``, Spark's default
            string-to-timestamp conversion is used.

    Returns:
        DataFrame[user_id string, metric_type string, timestamp timestamp,
        value double]. Timestamps that cannot be parsed become NULL (with Spark
        ANSI mode off). If creation fails, an empty DataFrame with the same
        columns is returned.
    """
    schema = StructType([
        StructField("user_id", StringType(), True),
        StructField("metric_type", StringType(), True),
        StructField("timestamp", StringType(), True),
        StructField("value", DoubleType(), True),
    ])
    try:
        df = spark.createDataFrame(raw_data or [], schema=schema)
        print(f"‚úÖ DataFrame created with {df.count()} rows.")
    except Exception as e:
        print("‚ùå Error creating DataFrame:", e)
        df = spark.createDataFrame([], schema=schema)

    # This cast used to sit after an unconditional ``return`` and never ran,
    # so the 24h filter downstream compared raw strings with timestamps.
    try:
        parsed = (
            to_timestamp(col("timestamp"), timestamp_format)
            if timestamp_format
            else to_timestamp(col("timestamp"))
        )
        df = df.withColumn("timestamp", parsed)
    except Exception as e:
        print("‚ö†Ô∏è Failed to cast timestamp column:", repr(e))
    return df
# --- Core logic ---
def get_last24h_and_averages(spark, db, batch_size=300, debug=False, window_hours=24):
    """
    Fetch all user metrics and reduce the recent window to one value per metric.

    Args:
        spark: active SparkSession.
        db: Firestore client.
        batch_size: Firestore page size forwarded to ``read_all_users``.
        debug: verbose progress output while reading users.
        window_hours: length of the trailing window in whole hours, ending at
            Spark's ``current_timestamp()``.

    Returns:
        ``(df_recent, df_avg)``:
          * ``df_recent``: rows of the metrics frame whose ``timestamp`` lies
            inside the window.
          * ``df_avg``: one row per (user_id, metric_type), with ``daily_value``
            = mean for ``heart_rate`` and sum for every other metric, ordered
            by user_id, metric_type.
        When no data was fetched, both frames are empty. ``df_avg`` still has
        the columns ``user_id``, ``metric_type``, ``daily_value``.
    """
    raw_data = read_all_users(db, batch_size=batch_size, debug=debug)
    print(f"‚úÖ Total records fetched: {len(raw_data)}")

    df = build_dataframe(spark, raw_data)
    if df.isEmpty():
        print("‚ÑπÔ∏è No data available.")
        # Match the normal-path schema so consumers that look for
        # metric_type/daily_value (e.g. the activity pivot) behave the same.
        return df, df.select("user_id", "metric_type", col("value").alias("daily_value"))

    hours = int(window_hours)
    df_recent = df.filter(
        col("timestamp") >= expr(f"current_timestamp() - INTERVAL {hours} HOURS")
    )
    print(f"üïí Records after {hours}h filter: {df_recent.count()}")

    # metric_type is a grouping key, so it can be referenced outside the
    # aggregates to choose mean (heart rate) vs. sum (counts) per group.
    daily_value = (
        when(col("metric_type") == "heart_rate", avg("value"))
        .otherwise(_sum("value"))
        .alias("daily_value")
    )
    df_avg = (
        df_recent.groupBy("user_id", "metric_type")
        .agg(daily_value)
        .orderBy("user_id", "metric_type")
    )
    return df_recent, df_avg

### Execution

In [4]:
# Fetch every user's metrics from Firestore, keep the last 24 hours and reduce
# them to one value per (user_id, metric_type): heart_rate is averaged, the
# other metrics are summed. debug=True prints per-user progress.
# NOTE(review): the saved output reports 0 records after the 24h filter, so
# df_daily_activity is empty in this run.
df_recent, df_daily_activity = get_last24h_and_averages(spark, db, batch_size=200, debug=True)


‚úÖ get_user_ids: using list_documents(), found 34 users.
üìã Found 34 user doc ids.
Processing user: 1503960366
Processing user: 1644430081
Processing user: 1844505072
Processing user: 1927972279
Processing user: 2022484408
Processing user: 2026352035
Processing user: 2320127002
Processing user: 2347167796
Processing user: 2873212765
Processing user: 3372868164
Processing user: 3977333714
Processing user: 4020332650
Processing user: 4057192912
Processing user: 4319703577
Processing user: 4388161847
Processing user: 4445114986
Processing user: 4558609924
Processing user: 4702921684
Processing user: 5553957443
Processing user: 5577150313
Processing user: 6290855005
Processing user: 6775888955
Processing user: 6962181067
Processing user: 7007744171
Processing user: 7086361926
Processing user: 8053475328
Processing user: 8378563200
Processing user: 8583815059
Processing user: 8792009665
Processing user: 8877689391
Processing user: user_4020332650
‚ÑπÔ∏è No subcollection docs for user use

                                                                                

‚úÖ DataFrame created with 12108 rows.
üïí Records after 24h filter: 0


## Load the daily meals DataFrame

In [5]:
def load_daily_meals(db, debug=False, spark_session=None):
    """
    Load every user's ``daily_meals`` documents into one Spark DataFrame.

    Args:
        db: Firestore client.
        debug: print the user count and an empty-result notice.
        spark_session: SparkSession to use; defaults to the notebook-global
            ``spark``.

    Returns:
        DataFrame with one row per meal document.
        * Columns are the union of keys across all documents, in first-seen
          order.
        * Each column's type comes from its first non-NULL value: bool ->
          boolean, int/float -> double, anything else -> string (converted with
          ``str``). Values that cannot be converted become NULL.
        * ``user_id`` is always a string and defaults to the Firestore user ID
          when a document lacks it.
        * With no documents at all, the result is an empty DataFrame with a
          single ``user_id`` column.
        Per-user query failures are printed and that user is skipped.
    """
    session = spark_session if spark_session is not None else spark

    user_ids = get_user_ids(db)
    if debug:
        print(f"Found {len(user_ids)} users")

    rows = []
    for uid in user_ids:
        # Guard the query itself: stream()/iteration is what talks to Firestore
        # (previously only the list append was inside the try).
        try:
            for snapshot in db.collection(f"users/{uid}/daily_meals").stream():
                record = snapshot.to_dict() or {}
                record.setdefault("user_id", uid)
                rows.append(record)
        except Exception as e:
            print(f"Query failed for {uid}: {e}")

    if not rows:
        if debug:
            print("‚ö†Ô∏è No meals found")
        empty_schema = StructType([
            StructField("user_id", StringType(), True),
        ])
        return session.createDataFrame([], schema=empty_schema)

    # Union of keys across all documents (first-seen order). Each column's type
    # is decided by the first non-NULL value seen for it.
    first_values = {}
    for record in rows:
        for key, value in record.items():
            if first_values.get(key) is None:
                first_values[key] = value

    fields = []
    for key, value in first_values.items():
        if key == "user_id":
            dtype = StringType()
        elif isinstance(value, bool):  # bool is an int subclass; test it first
            dtype = BooleanType()
        elif isinstance(value, (int, float)):
            dtype = DoubleType()
        else:
            dtype = StringType()
        fields.append(StructField(key, dtype, True))

    def _coerce(value, dtype):
        """Convert ``value`` to what ``dtype`` accepts; unconvertible -> None.

        Spark's schema verification may reject e.g. an int in a DoubleType
        column or a datetime in a StringType column, so convert explicitly.
        """
        if value is None:
            return None
        if isinstance(dtype, StringType):
            return value if isinstance(value, str) else str(value)
        if isinstance(dtype, DoubleType):
            try:
                return float(value)
            except (TypeError, ValueError):
                return None
        if isinstance(dtype, BooleanType):
            return value if isinstance(value, bool) else None
        return value

    data = [tuple(_coerce(record.get(f.name), f.dataType) for f in fields) for record in rows]
    return session.createDataFrame(data, schema=StructType(fields))
df_daily_meals = load_daily_meals(db)

In [6]:
# Preview the first rows of the meals DataFrame.
# NOTE(review): the saved output contains per-user body measurements
# (height_m, weight_kg); clear it before sharing the notebook.
df_daily_meals.show()

+------------------+---------------+---------+--------------+----------+------------------+------------------+----------+
|          height_m|sport_available|meal_type|     food_name|      date|         weight_kg|          water_ml|   user_id|
+------------------+---------------+---------+--------------+----------+------------------+------------------+----------+
|1.7999999523162842|           true|    Lunch|      Broccoli|2025-11-28|              76.0| 607.1673583984375|4020332650|
|1.8200000524520874|           true|Breakfast|     Chocolate|2025-11-04| 62.79999923706055|441.80242919921875|4020332650|
|1.8200000524520874|           true|    Snack|        Salmon|2025-11-04| 62.79999923706055| 182.9405975341797|4020332650|
|1.7999999523162842|           true|    Snack|      Broccoli|2025-11-28|              76.0| 298.7261047363281|4020332650|
|1.8200000524520874|          false|    Lunch|        Grapes|2025-11-21|62.400001525878906| 604.0103759765625|4020332650|
|1.7599999904632568|    

## daily planner

### join helper functions

In [10]:
from pyspark.sql import functions as F
from pyspark.sql import DataFrame
from pyspark.sql.types import StringType, DoubleType, BooleanType
import logging

# Dedicated logger for the fatigue pipeline. The handler is attached only when
# the logger has none, so re-running this cell does not duplicate log lines.
# NOTE: this rebinds the global name `logger` used by earlier cells.
logger = logging.getLogger("fatigue-detector")
logger.setLevel(logging.INFO)
if len(logger.handlers) == 0:
    h = logging.StreamHandler()
    h.setFormatter(
        logging.Formatter(fmt="[%(levelname)s] %(asctime)s - %(message)s")
    )
    logger.addHandler(h)

# ------------------ Utilities ------------------

def _ensure_columns(df: DataFrame, cols_with_types: dict = None) -> DataFrame:
    """
    Ensure the DataFrame contains the listed columns. If a column is missing,
    create it with NULL and cast to the supplied type (or StringType by default).
    cols_with_types: { "col_name": pyspark.sql.types.DataType, ... }
    """
    if cols_with_types is None:
        cols_with_types = {}
    existing = set(df.columns)
    out = df
    for col_name, dtype in cols_with_types.items():
        if col_name not in existing:
            logger.info("Column missing, adding NULL column: %s", col_name)
            # add null column with cast to dtype if provided
            cast_type = dtype if dtype is not None else StringType()
            out = out.withColumn(col_name, F.lit(None).cast(cast_type))
    return out

# ------------------ Helpers ------------------

def _find_user_id_col(df: DataFrame):
    """Return the name of the user-id-like column, or None."""
    candidates = ["user_id", "userid", "UserID", "User_Id", "uid", "user"]
    cols_lower = {c.lower(): c for c in df.columns}
    for cand in candidates:
        if cand.lower() in cols_lower:
            return cols_lower[cand.lower()]
    return None

def _normalize_user_id_column(df: DataFrame) -> DataFrame:
    """
    Return ``df`` with a canonical ``user_id`` string column.

    The first user-id-like column (see ``_find_user_id_col``) is cast to string,
    trimmed and lowercased into ``user_id``. This replaces ``user_id`` if it
    already exists; a source column with another name is kept. If no such
    column exists, a NULL ``user_id`` is added (with a warning) so joins can
    still resolve the name.
    """
    source = _find_user_id_col(df)
    if source is not None:
        cleaned = F.lower(F.trim(F.col(source).cast(StringType())))
        return df.withColumn("user_id", cleaned)

    logger.warning(
        "No user-id-like column found in DataFrame. Adding NULL 'user_id' column. Columns: %s",
        df.columns,
    )
    return df.withColumn("user_id", F.lit(None).cast(StringType()))

def _pivot_activity_if_needed(activity_df: DataFrame) -> DataFrame:
    """
    Bring activity data into wide format: one row per user.

    * Already wide (has ``avg_heart_rate`` or ``total_steps``): returned after
      ``user_id`` normalisation.
    * Long (has ``metric_type`` and ``daily_value``): pivoted on
      ``metric_type`` using the first ``daily_value`` per user/metric. Then
      heart_rate -> avg_heart_rate, steps -> total_steps and
      calories -> total_calories_burnt are renamed.
    * Anything else: a warning is logged and the frame is returned as is.

    In the last two cases, avg_heart_rate / total_steps / total_calories_burnt
    are guaranteed to exist (added as NULL doubles when absent).
    """
    # Shape detection uses the columns as they were before normalisation.
    original_cols = set(activity_df.columns)
    canonical = {
        "avg_heart_rate": DoubleType(),
        "total_steps": DoubleType(),
        "total_calories_burnt": DoubleType(),
    }

    activity_df = _normalize_user_id_column(activity_df)

    if {"avg_heart_rate", "total_steps"} & original_cols:
        logger.info("Activity DF already wide-format.")
        return activity_df

    if not {"metric_type", "daily_value"} <= original_cols:
        logger.warning("Activity DF does not have expected columns for pivot (avg_heart_rate/metric_type/daily_value). Columns: %s", activity_df.columns)
        return _ensure_columns(activity_df, canonical)

    logger.info("Pivoting long-format activity DF (metric_type/daily_value).")
    wide = (
        activity_df
        .groupBy("user_id")
        .pivot("metric_type")
        .agg(F.first("daily_value"))
    )
    renames = (
        ("heart_rate", "avg_heart_rate"),
        ("steps", "total_steps"),
        ("calories", "total_calories_burnt"),
    )
    for source, target in renames:
        if source in wide.columns:
            wide = wide.withColumnRenamed(source, target)
    return _ensure_columns(wide, canonical)

def _safe_rename_nutrition_cols(nutrition_df: DataFrame) -> DataFrame:
    """
    Rename unit-suffixed nutrition columns to snake_case and ensure they exist.

    Both plain names (``"Fat (g)"`` -> ``fat_g``) and ``total_``-prefixed
    names (``"total_Fat (g)"`` -> ``total_fat_g``) are renamed. A rename is
    skipped when the target name already exists; renaming onto it would leave
    two same-named columns and make later references ambiguous. Missing source
    columns are ignored.

    Afterwards every canonical column (calories_kcal, protein_g, carbs_g, fat_g,
    fiber_g, sugars_g, sodium_mg, cholesterol_mg, water_ml) is present. Any that
    was missing is added as a NULL double.
    """
    rename_map = {
        "Calories (kcal)": "calories_kcal",
        "Protein (g)": "protein_g",
        "Carbohydrates (g)": "carbs_g",
        "Fat (g)": "fat_g",
        "Fiber (g)": "fiber_g",
        "Sugars (g)": "sugars_g",
        "Sodium (mg)": "sodium_mg",
        "Cholesterol (mg)": "cholesterol_mg",
        "Water_Intake (ml)": "water_ml"
    }
    df = nutrition_df
    for old, new in rename_map.items():
        for src, dst in ((old, new), (f"total_{old}", f"total_{new}")):
            if src not in df.columns:
                continue
            if dst in df.columns:
                logger.info("Column %s already exists; not renaming %s onto it", dst, src)
                continue
            df = df.withColumnRenamed(src, dst)

    # Canonical nutrition columns, all numeric.
    ensure_map = {name: DoubleType() for name in rename_map.values()}
    return _ensure_columns(df, ensure_map)

def _detect_calorie_and_water_cols(nutrition_df: DataFrame):
    """
    Return (calories_col, water_col) names if present; otherwise None.
    This will not raise if columns are missing.
    """
    calories_candidates = ["total_calories_kcal", "calories_kcal", "total_Calories (kcal)", "Calories (kcal)"]
    water_candidates = ["total_water_ml", "water_ml", "total_Water_Intake (ml)", "Water_Intake (ml)"]

    calories_col = next((c for c in calories_candidates if c in nutrition_df.columns), None)
    water_col = next((c for c in water_candidates if c in nutrition_df.columns), None)

    if calories_col is None:
        logger.info("No calories column detected. Candidates tried: %s", calories_candidates)
    if water_col is None:
        logger.info("No water column detected. Candidates tried: %s", water_candidates)

    return calories_col, water_col
def _compute_physiologic_factors(df: DataFrame) -> DataFrame:
    """
    Fill activity/nutrition gaps with defaults and derive fatigue factors.

    Expects the columns total_steps, avg_heart_rate, total_calories_burnt,
    calories_intake, water_ml, sleep_hours and weight (NULLs allowed).

    NULLs are replaced with: steps 0, heart rate 70, calories burnt 2000,
    intake 0, water 0, sleep 7.0. A ``weight`` that is NULL or not positive
    becomes 70.0. Before this, a zero weight made the hydration ratio divide by
    zero (NULL with ANSI mode off), so the score was NULL and the level chain
    fell through to "Energetic".

    Adds:
      * sleep_factor, activity_factor, heart_factor;
      * energy_ratio (intake / burnt) and energy_factor;
      * hydration_factor = min(water / (weight * 35), 1);
      * fatigue_score, a weighted sum: sleep .35, activity .20, heart .15,
        energy .20, hydration .10;
      * fatigue_level: <0.4 Exhausted, <0.65 Tired, <0.85 Normal, else
        Energetic.
    A higher score therefore means *less* fatigue.
    """
    df = df.withColumn("total_steps", F.coalesce(F.col("total_steps"), F.lit(0)))
    df = df.withColumn("avg_heart_rate", F.coalesce(F.col("avg_heart_rate"), F.lit(70)))
    df = df.withColumn("total_calories_burnt", F.coalesce(F.col("total_calories_burnt"), F.lit(2000)))
    df = df.withColumn("calories_intake", F.coalesce(F.col("calories_intake"), F.lit(0)))
    df = df.withColumn("water_ml", F.coalesce(F.col("water_ml"), F.lit(0)))
    df = df.withColumn("sleep_hours", F.coalesce(F.col("sleep_hours"), F.lit(7.0)))
    # NULL > 0 is NULL, so this also covers missing weights.
    df = df.withColumn(
        "weight",
        F.when(F.col("weight") > 0, F.col("weight")).otherwise(F.lit(70.0)),
    )

    df = df.withColumn(
        "sleep_factor",
        F.when(F.col("sleep_hours") < 5, 0.3)
         .when(F.col("sleep_hours") < 6, 0.5)
         .when(F.col("sleep_hours") < 7, 0.8)
         .when(F.col("sleep_hours") <= 8, 1.0)
         .otherwise(0.9)
    )

    df = df.withColumn(
        "activity_factor",
        F.when(F.col("total_steps") < 3000, 0.4)
         .when(F.col("total_steps") < 6000, 0.7)
         .when(F.col("total_steps") <= 10000, 1.0)
         .when(F.col("total_steps") <= 12000, 0.8)
         .otherwise(0.6)
    )

    df = df.withColumn(
        "heart_factor",
        F.when(F.col("avg_heart_rate") <= 60, 1.0)
         .when(F.col("avg_heart_rate") <= 75, 0.8)
         .when(F.col("avg_heart_rate") <= 85, 0.6)
         .otherwise(0.4)
    )

    df = df.withColumn(
        "energy_ratio",
        F.when(F.col("total_calories_burnt") > 0,
               F.col("calories_intake") / F.col("total_calories_burnt"))
         .otherwise(F.lit(0.0))
    )

    df = df.withColumn(
        "energy_factor",
        F.when(F.col("energy_ratio") < 0.7, 0.5)
         .when(F.col("energy_ratio") < 0.9, 0.8)
         .when(F.col("energy_ratio") <= 1.1, 1.0)
         .otherwise(0.9)
    )

    df = df.withColumn(
        "hydration_factor",
        F.least(F.col("water_ml") / (F.col("weight") * F.lit(35.0)), F.lit(1.0))
    )

    df = df.withColumn(
        "fatigue_score",
        0.35 * F.col("sleep_factor") +
        0.20 * F.col("activity_factor") +
        0.15 * F.col("heart_factor") +
        0.20 * F.col("energy_factor") +
        0.10 * F.col("hydration_factor")
    )

    df = df.withColumn(
        "fatigue_level",
        F.when(F.col("fatigue_score") < 0.4, "Exhausted")
         .when(F.col("fatigue_score") < 0.65, "Tired")
         .when(F.col("fatigue_score") < 0.85, "Normal")
         .otherwise("Energetic")
    )

    return df

### compute fatigue functions

In [11]:
def _first_present_column(df: DataFrame, candidates):
    """Return the first name in ``candidates`` that is a column of ``df``, else None."""
    return next((c for c in candidates if c in df.columns), None)


def _aggregate_nutrition_daily(nutrition_df: DataFrame, calories_col, water_col) -> DataFrame:
    """
    Collapse meal-level rows to one row per user_id, and per date if present.

    Calories and water are summed across meals; sleep, weight and height are
    averaged. Weight/height fall back to the unit-suffixed ``weight_kg`` /
    ``height_m`` columns. Any output whose source column is absent is NULL.

    Output columns: user_id, [date], calories_intake, water_ml, sleep_hours,
    weight, height.
    """
    def as_double(name):
        if name is None:
            return F.lit(None).cast(DoubleType())
        return F.col(name).cast(DoubleType())

    weight_col = _first_present_column(nutrition_df, ["weight", "weight_kg"])
    height_col = _first_present_column(nutrition_df, ["height", "height_m"])
    sleep_col = _first_present_column(nutrition_df, ["sleep_hours"])
    keys = ["user_id"] + (["date"] if "date" in nutrition_df.columns else [])
    return nutrition_df.groupBy(*keys).agg(
        F.sum(as_double(calories_col)).alias("calories_intake"),
        F.sum(as_double(water_col)).alias("water_ml"),
        F.avg(as_double(sleep_col)).alias("sleep_hours"),
        F.avg(as_double(weight_col)).alias("weight"),
        F.avg(as_double(height_col)).alias("height"),
    )


def detect_daily_fatigue_with_diagnostics(raw_activity_df: DataFrame, raw_nutrition_df: DataFrame) -> DataFrame:
    """
    Compute fatigue scores per user (and per day when nutrition has a ``date``).

    Keeps going when data or columns are missing. Absent metrics become NULL
    and are replaced by defaults in ``_compute_physiologic_factors``.

    Steps:
      1. Normalise ``user_id`` on both inputs and log basic diagnostics.
      2. Pivot long-format activity (metric_type/daily_value) to one row per
         user (``_pivot_activity_if_needed``).
      3. Aggregate meal-level nutrition to one row per user/date. Without this,
         every meal would be scored separately with that meal's water/calories.
      4. Inner-join on ``user_id``. If that yields no rows, fall back to a full
         outer join, so users present on only one side (e.g. no activity in the
         window) still get a score.

    Args:
        raw_activity_df: activity data, either long (user_id, metric_type,
            daily_value) or already wide (avg_heart_rate, total_steps, ...).
        raw_nutrition_df: meal-level nutrition data (e.g. ``df_daily_meals``).

    Returns:
        Output of ``_compute_physiologic_factors``, plus ``join_diagnostic``:
        "inner_ok" or "inner_empty_fallback_outer".
    """
    # Normalize IDs early
    activity_df = _normalize_user_id_column(raw_activity_df)
    nutrition_df = _normalize_user_id_column(raw_nutrition_df)

    logger.info("Activity columns: %s", activity_df.columns)
    logger.info("Nutrition columns: %s", nutrition_df.columns)

    try:
        a_count, n_count = activity_df.count(), nutrition_df.count()
        logger.info("Counts: activity=%d, nutrition=%d", a_count, n_count)
    except Exception as e:
        logger.warning("Could not count DataFrames: %s", e)

    try:
        a_sample = [r["user_id"] for r in activity_df.select("user_id").distinct().limit(10).collect()]
        n_sample = [r["user_id"] for r in nutrition_df.select("user_id").distinct().limit(10).collect()]
        logger.info("Sample user_ids activity=%s", a_sample)
        logger.info("Sample user_ids nutrition=%s", n_sample)
    except Exception as e:
        logger.warning("Could not sample user_ids: %s", e)

    activity_pivoted = _normalize_user_id_column(_pivot_activity_if_needed(activity_df))

    nutrition_safe = _safe_rename_nutrition_cols(nutrition_df)
    calories_col, water_col = _detect_calorie_and_water_cols(nutrition_safe)
    logger.info("Detected columns: calories=%s, water=%s", calories_col, water_col)

    nutrition_daily = _aggregate_nutrition_daily(nutrition_safe, calories_col, water_col)

    has_a_date = "date" in activity_pivoted.columns
    has_n_date = "date" in nutrition_daily.columns
    if has_a_date and has_n_date:
        date_expr = F.coalesce(F.col("a.date"), F.col("n.date")).alias("date")
    elif has_a_date:
        date_expr = F.col("a.date").alias("date")
    elif has_n_date:
        date_expr = F.col("n.date").alias("date")
    else:
        logger.info("No date columns found ‚Äî using current_date as fallback")
        date_expr = F.current_date().alias("date")

    def activity_metric(name):
        # Alias both branches so the output column name is always `name`.
        if name in activity_pivoted.columns:
            return F.col(f"a.{name}").alias(name)
        return F.lit(None).cast(DoubleType()).alias(name)

    sel = [
        # Unqualified on purpose: after a USING-column join this is the merged
        # key, which stays populated for rows from either side of an outer join.
        F.col("user_id"),
        date_expr,
        activity_metric("avg_heart_rate"),
        activity_metric("total_steps"),
        activity_metric("total_calories_burnt"),
        F.col("n.calories_intake"),
        F.col("n.sleep_hours"),
        F.col("n.weight"),
        F.col("n.height"),
        F.col("n.water_ml"),
    ]

    a = activity_pivoted.alias("a")
    n = nutrition_daily.alias("n")
    try:
        joined_inner = a.join(n, on="user_id", how="inner").select(*sel)
        joined_count = joined_inner.count()
        logger.info("Inner join produced %d rows", joined_count)
    except Exception as e:
        logger.warning("Join failed (%s). Falling back to left join.", e)
        joined_inner = a.join(n, on="user_id", how="left").select(*sel)
        joined_count = -1

    if joined_count == 0:
        logger.warning("Inner join empty ‚Äî diagnosing mismatch and using outer join fallback.")
        try:
            a_uids = [r["user_id"] for r in activity_pivoted.select("user_id").distinct().limit(20).collect()]
            n_uids = [r["user_id"] for r in nutrition_daily.select("user_id").distinct().limit(20).collect()]
            logger.info("Unique activity IDs (sample): %s", a_uids)
            logger.info("Unique nutrition IDs (sample): %s", n_uids)
        except Exception as e:
            logger.warning("Could not inspect user_id sets: %s", e)

        # A left join from an empty activity frame (e.g. nothing in the last
        # 24h) returns no rows at all. The outer join keeps nutrition-only users.
        joined_outer = a.join(n, on="user_id", how="full_outer").select(*sel)
        logger.info("Outer join row count: %d", joined_outer.count())
        result = _compute_physiologic_factors(joined_outer)
        return result.withColumn("join_diagnostic", F.lit("inner_empty_fallback_outer"))

    result = _compute_physiologic_factors(joined_inner)
    return result.withColumn("join_diagnostic", F.lit("inner_ok"))


In [12]:
# Score fatigue from the daily activity metrics and the meals data, then drop
# the diagnostic column, which only records which join path produced the rows.
result_df = detect_daily_fatigue_with_diagnostics(df_daily_activity, df_daily_meals)
fatigue_df = result_df.drop("join_diagnostic")

[INFO] 2025-11-14 20:26:21,929 - Activity columns: ['user_id', 'metric_type', 'daily_value']
[INFO] 2025-11-14 20:26:21,931 - Nutrition columns: ['height_m', 'sport_available', 'meal_type', 'food_name', 'date', 'weight_kg', 'water_ml', 'user_id']
[INFO] 2025-11-14 20:26:22,291 - Counts: activity=0, nutrition=2160
[INFO] 2025-11-14 20:26:22,595 - Sample user_ids activity=[]
[INFO] 2025-11-14 20:26:22,596 - Sample user_ids nutrition=['4020332650', '5553957443', '5577150313', 'user_4020332650', 'user_5553957443', 'user_5577150313']
[INFO] 2025-11-14 20:26:22,611 - Pivoting long-format activity DF (metric_type/daily_value).
[INFO] 2025-11-14 20:26:22,787 - Column missing, adding NULL column: avg_heart_rate
[INFO] 2025-11-14 20:26:22,796 - Column missing, adding NULL column: total_steps
[INFO] 2025-11-14 20:26:22,808 - Column missing, adding NULL column: total_calories_burnt
[INFO] 2025-11-14 20:26:22,831 - Column missing, adding NULL column: calories_kcal
[INFO] 2025-11-14 20:26:22,838 - C

## generate daily plan

### Load the target plan for the day

In [14]:
from google.cloud import firestore
from pyspark.sql import Row
from pyspark.sql.types import StructType, StructField, StringType, DoubleType, DateType
import datetime




def _to_date(raw):
    if raw is None:
        return None
    if isinstance(raw, datetime.date) and not isinstance(raw, datetime.datetime):
        return raw
    if isinstance(raw, datetime.datetime):
        return raw.date()
    if isinstance(raw, str):
        s = raw.strip()
        try:
            return datetime.date.fromisoformat(s)
        except Exception:
            try:
                return datetime.datetime.fromisoformat(s.replace("Z", "+00:00")).date()
            except Exception:
                import re
                m = re.search(r"(\d{4})[-/](\d{2})[-/](\d{2})", s)
                if m:
                    y, mo, d = m.groups()
                    try:
                        return datetime.date(int(y), int(mo), int(d))
                    except Exception:
                        return None
    return None


def load_monthly_plan_for_date(db, target_date, debug=False, spark=None):
    """
    Load every user's ``monthly_plan`` documents as rows stamped with ``target_date``.

    For each user id returned by ``get_user_ids(db)``, all documents under
    ``users/{uid}/monthly_plan`` are streamed. Each document becomes one row
    whose ``date`` field is overwritten with ``target_date`` as a string.
    A missing ``user_id`` field is filled in with the document's parent uid.

    NOTE(review): documents are NOT filtered by any date of their own. Every
    monthly_plan document of every user is stamped with ``target_date``. If the
    collection holds one document per day, this yields many rows per user.
    Confirm this is intended before feeding the result to the planner/writer.

    Args:
        db: Firestore client (``firestore.Client``).
        target_date: ``datetime.date``, ``datetime.datetime`` or a date string
            (used verbatim after ``strip()``, e.g. ``'YYYY-MM-DD'``).
        debug: print progress and query-failure messages when True.
        spark: SparkSession used to build the DataFrame. Defaults to the
            active/default session (``SparkSession.builder.getOrCreate()``).

    Returns:
        A Spark DataFrame with these column types:
          * ``user_id`` and ``date`` are strings.
          * Fields that are numeric (int/float, not bool) in the FIRST document
            become ``DoubleType``; every other field becomes ``StringType``.
          * Every row is coerced to that schema. Keys absent from the first
            document are dropped, and unconvertible numbers become null.
        When no documents are found, the result is an empty frame with string
        columns ``user_id`` and ``date``.
    """
    if spark is None:
        from pyspark.sql import SparkSession
        spark = SparkSession.builder.getOrCreate()

    if isinstance(target_date, datetime.datetime):
        target_date_str = target_date.date().isoformat()
    elif isinstance(target_date, datetime.date):
        target_date_str = target_date.isoformat()
    else:
        target_date_str = str(target_date).strip()

    user_ids = get_user_ids(db)
    if debug:
        print(f"Found {len(user_ids)} users")

    rows = []
    for uid in user_ids:
        # .stream() is lazy: the query executes while iterating, so the
        # materialization must sit inside the try for failures to be caught.
        try:
            docs = [doc.to_dict() for doc in db.collection(f"users/{uid}/monthly_plan").stream()]
        except Exception as e:
            if debug:
                print(f"Query failed for {uid}: {e}")
            continue

        for doc in docs:
            if doc is None:
                continue
            # Keep rows joinable on user_id even if the document omits it.
            doc.setdefault("user_id", uid)
            doc["date"] = target_date_str
            rows.append(doc)

    if not rows:
        if debug:
            print(f"⚠️ No monthly_plan found for {target_date_str}")
        # Same string-typed `date` column as the non-empty path below.
        empty_schema = StructType([
            StructField("user_id", StringType(), True),
            StructField("date", StringType(), True),
        ])
        return spark.createDataFrame([], schema=empty_schema)

    # Infer the schema dynamically from the first document.
    fields = []
    for k, v in rows[0].items():
        if k in ("user_id", "date"):
            fields.append(StructField(k, StringType(), True))  # keep as string for consistency
        elif isinstance(v, (int, float)) and not isinstance(v, bool):
            fields.append(StructField(k, DoubleType(), True))
        else:
            fields.append(StructField(k, StringType(), True))
    schema = StructType(fields)

    def _coerce(value, dtype):
        # Spark's schema verification rejects ints in DoubleType columns and
        # non-str values in StringType columns, so normalize explicitly.
        if value is None:
            return None
        if isinstance(dtype, DoubleType):
            try:
                return float(value)
            except (TypeError, ValueError):
                return None
        return value if isinstance(value, str) else str(value)

    data = [
        tuple(_coerce(row.get(field.name), field.dataType) for field in schema.fields)
        for row in rows
    ]
    return spark.createDataFrame(data, schema=schema)


In [15]:
# load_monthly_plan_for_date already returns a Spark DataFrame; list its columns.
df_plans = load_monthly_plan_for_date(db, target_date="2025-11-12", debug=True)
df_plans.columns


Found 34 users


['total_Calories (kcal)',
 'sleep_hours',
 'total_Sugars (g)',
 'total_Sodium (mg)',
 'total_Fat (g)',
 'target_steps',
 'total_Cholesterol (mg)',
 'date',
 'target_calories_burnt',
 'total_Carbohydrates (g)',
 'user_id',
 'total_Protein (g)',
 'total_Fiber (g)',
 'total_Water_Intake (ml)']

### load reference food and activities

In [16]:
# --- 1. Load datasets ---
# Read the reference food table once per kernel. `food_df` starts as None in
# the Spark setup cell. Compare with None explicitly: the truthiness of a Spark
# DataFrame is not a defined "is it loaded / is it empty" check.
FOOD_CSV_PATH = "./food_nutrition_dataset.csv"
if food_df is None:
    food_df = spark.read.csv(FOOD_CSV_PATH, header=True, inferSchema=True)

# --- 2. Define meal structure ---
# Meal -> food categories eligible for that meal.
meal_structure = {
    "Breakfast": ["Grains", "Dairy", "Fruits", "Beverages"],
    "Lunch": ["Meat", "Vegetables", "Grains"],
    "Snack": ["Snacks", "Fruits", "Beverages"],
    "Dinner": ["Meat", "Vegetables", "Grains", "Dairy"]
}

# Broadcast static food data for performance (collected to the driver as Rows).
food_broadcast = spark.sparkContext.broadcast(food_df.collect())


In [17]:
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Reference activity catalogue. `Category` uses the plan-state names produced
# by the target-adjustment logic (Tired / Recovery / Normal / Energetic), and
# activity selection filters on it.
activity_data = [
    dict(zip(("Activity", "Category", "Duration_min", "Calories_Burned", "Intensity"), spec))
    for spec in [
        # --- Recovery / Tired states ---
        ("Rest", "Tired", 30, 0, "Very Low"),
        ("Gentle Yoga", "Tired", 30, 80, "Low"),
        ("Light Walk", "Recovery", 20, 100, "Low"),
        ("Stretching Routine", "Recovery", 15, 50, "Low"),

        # --- Normal state ---
        ("Brisk Walk", "Normal", 30, 150, "Moderate"),
        ("Bodyweight Strength", "Normal", 30, 200, "Moderate"),
        ("Cycling (Easy)", "Normal", 30, 180, "Moderate"),

        # --- Energetic state ---
        ("HIIT Circuit", "Energetic", 20, 300, "High"),
        ("Running / Jogging", "Energetic", 30, 350, "High"),
        ("Strength Training (Weights)", "Energetic", 40, 400, "High"),
        ("Spin Class / Cardio", "Energetic", 30, 380, "High"),
    ]
]

activity_df = spark.createDataFrame(activity_data)
activity_df.show(truncate=False)


+---------------------------+---------------+---------+------------+---------+
|Activity                   |Calories_Burned|Category |Duration_min|Intensity|
+---------------------------+---------------+---------+------------+---------+
|Rest                       |0              |Tired    |30          |Very Low |
|Gentle Yoga                |80             |Tired    |30          |Low      |
|Light Walk                 |100            |Recovery |20          |Low      |
|Stretching Routine         |50             |Recovery |15          |Low      |
|Brisk Walk                 |150            |Normal   |30          |Moderate |
|Bodyweight Strength        |200            |Normal   |30          |Moderate |
|Cycling (Easy)             |180            |Normal   |30          |Moderate |
|HIIT Circuit               |300            |Energetic|20          |High     |
|Running / Jogging          |350            |Energetic|30          |High     |
|Strength Training (Weights)|400            |Energet

### Meal planner functions

In [18]:
import random

def adjust_daily_targets_pro(row, ideal_sleep_hours=8.0):
    """
    Adapt a user's daily targets to fatigue, sleep, hydration and energy balance.

    This is the non-explaining counterpart of
    ``adjust_daily_targets_explainable``; it uses the same thresholds and
    multipliers.

    Args:
        row: joined plan/fatigue record (pyspark ``Row`` or mapping), read by key.
            Measured inputs are ``fatigue_score`` (0.8 when missing),
            ``sleep_hours``, ``calories_intake``, ``total_calories_burnt`` and
            ``water_ml``.
            Targets are ``total_Calories (kcal)``, ``total_Protein (g)``,
            ``total_Carbohydrates (g)``, ``total_Fat (g)``,
            ``total_Water_Intake (ml)``, ``sleep_hours``, ``target_steps`` and
            ``target_calories_burnt``.
            Missing fields and None values count as 0.
        ideal_sleep_hours: reference night length (must be > 0). It is used for
            the sleep deficit; the "Energetic" state also requires sleeping more
            than ``ideal_sleep_hours - 1``.

    Returns:
        ``(targets, state)``: the adjusted targets dict, and one of
        "Recovery" | "Tired" | "Energetic" | "Normal".
    """
    def _get(key, default=0.0):
        # Spark Rows raise KeyError/ValueError for unknown fields; treat those
        # and None as "missing".
        try:
            value = row[key]
        except (KeyError, ValueError):
            return default
        return default if value is None else value

    # --- Measured inputs ---
    fatigue = _get("fatigue_score", 0.8)
    sleep = _get("sleep_hours")
    intake = _get("calories_intake")
    burnt = _get("total_calories_burnt")
    water = _get("water_ml")

    # --- Base daily targets ---
    targets = {
        key: _get(key)
        for key in (
            "total_Calories (kcal)",
            "total_Protein (g)",
            "total_Carbohydrates (g)",
            "total_Fat (g)",
            "total_Water_Intake (ml)",
            "sleep_hours",
            "target_steps",
            "target_calories_burnt",
        )
    }

    # --- Derived metrics ---
    energy_balance = (intake - burnt) / max(intake, 1)
    hydration_ratio = water / max(targets["total_Water_Intake (ml)"], 1)
    # The row carries a single sleep_hours value. Comparing it with itself
    # (as before) always gave 0, so measure the deficit against an ideal night.
    sleep_deficit = max(0.0, (ideal_sleep_hours - sleep) / ideal_sleep_hours)

    tired = fatigue < 0.7
    overreached = fatigue < 0.5 and sleep_deficit > 0.2
    energetic = fatigue > 0.9 and sleep > ideal_sleep_hours - 1

    # --- State: multiplicative factors + additive sleep bonus (hours) ---
    # Sleep is handled separately. A "+0.5" literal is just the float 0.5,
    # which previously HALVED the sleep target in the Tired state.
    if overreached:
        # Deep recovery mode
        state = "Recovery"
        factors = {
            "total_Calories (kcal)": 0.95,
            "total_Protein (g)": 1.15,
            "total_Fat (g)": 0.9,
            "target_steps": 0.75,
            "target_calories_burnt": 0.75,
            "total_Water_Intake (ml)": 1.2,
        }
        sleep_bonus = 1.0
    elif tired:
        # Light fatigue → promote repair
        state = "Tired"
        factors = {
            "total_Calories (kcal)": 0.9,
            "total_Protein (g)": 1.1,
            "total_Fat (g)": 0.85,
            "target_steps": 0.85,
            "target_calories_burnt": 0.85,
            "total_Water_Intake (ml)": 1.1,
        }
        sleep_bonus = 0.5
    elif energetic:
        # Strong recovery → allow progression
        state = "Energetic"
        factors = {
            "total_Calories (kcal)": 1.1,
            "total_Protein (g)": 1.05,
            "total_Carbohydrates (g)": 1.1,
            "target_steps": 1.15,
            "target_calories_burnt": 1.15,
            "total_Water_Intake (ml)": 1.05,
        }
        sleep_bonus = 0.0
    else:
        # Normal balance
        state = "Normal"
        factors = {}
        sleep_bonus = 0.0

    # --- Fine-tune hydration & energy ---
    if hydration_ratio < 0.9:
        targets["total_Water_Intake (ml)"] *= 1.15
    elif hydration_ratio > 1.2:
        targets["total_Water_Intake (ml)"] *= 0.95

    if energy_balance < -0.15:  # calorie deficit too large
        targets["total_Calories (kcal)"] *= 1.05
    elif energy_balance > 0.15:  # surplus too large
        targets["total_Calories (kcal)"] *= 0.95

    # --- Apply state adjustments ---
    for key, factor in factors.items():
        targets[key] *= factor
    targets["sleep_hours"] += sleep_bonus

    return targets, state


# --- 2. Meal Generation ---
# Meal -> food categories eligible for that meal (insertion order = meal order).
meal_structure = dict([
    ("Breakfast", ["Grains", "Dairy", "Fruits", "Beverages"]),
    ("Lunch", ["Meat", "Vegetables", "Grains"]),
    ("Snack", ["Snacks", "Fruits", "Beverages"]),
    ("Dinner", ["Meat", "Vegetables", "Grains", "Dairy"]),
])

def generate_plan_with_activity(row):
    """
    Generate a full daily plan: the meal rows plus one recommended activity.

    The meal rows come from ``generate_explainable_plan(row)``. One extra row is
    then appended with ``meal == "Activity"``. That activity is drawn at random
    from the ``activity_data`` entries whose ``Category`` equals the plan state
    computed by ``adjust_daily_targets_explainable(row)``; a 30-minute "Rest" is
    used when no entry matches.

    Args:
        row: joined plan/fatigue record (pyspark ``Row``).

    Returns:
        A list of 19-tuples in the explainable-plan column order. The list is
        empty when ``fatigue_score`` or ``sleep_hours`` is missing, matching
        ``generate_explainable_plan``.
    """
    if row.fatigue_score is None or row.sleep_hours is None:
        return []

    def _f(x):
        return float(x) if x is not None else None

    meals = generate_explainable_plan(row)  # existing meal plan function
    # The input rows have no plan_state column; derive the state the same way
    # the meal rows do.
    targets, state, _reasoning, energy_balance, hydration_ratio = adjust_daily_targets_explainable(row)

    # activity_data is the plain driver-side list; also accept a Broadcast of it.
    catalogue = getattr(activity_data, "value", activity_data)
    activities = [a for a in catalogue if a["Category"] == state]

    # Fallback to Rest if none found
    if activities:
        chosen_activity = random.choice(activities)
    else:
        chosen_activity = {
            "Activity": "Rest",
            "Category": state,
            "Duration_min": 30,
            "Calories_Burned": 0,
            "Intensity": "Very Low"
        }

    # Append activity as a special "meal/activity" row
    meals.append((
        row.user_id,
        row.date,
        "Activity",                      # Treat as a meal type for partitioning
        chosen_activity["Activity"],
        chosen_activity["Category"],
        float(chosen_activity["Calories_Burned"]),  # calories burned
        0.0, 0.0, 0.0,                    # no macronutrients
        state,
        f"Activity assigned based on state {state}, duration {chosen_activity.get('Duration_min', 30)} min, intensity {chosen_activity.get('Intensity', 'Low')}",
        _f(row.fatigue_score),
        _f(row.sleep_hours),
        _f(row.total_steps),
        _f(row.total_calories_burnt),
        _f(row.calories_intake),
        _f(targets.get("total_Water_Intake (ml)")),  # same target the meal rows carry
        _f(energy_balance),
        _f(hydration_ratio),
    ))

    return meals


In [24]:
import datetime
import random
from pyspark.sql import functions as F
from pyspark.sql.types import (
    StructType, StructField, StringType, DoubleType
)

# ----------------------------------------------------
# 1. Final explainable plan schema
# ----------------------------------------------------
schema = StructType([
    StructField(column, column_type)
    for column, column_type in (
        ("user_id", StringType()),
        ("date", StringType()),
        ("meal", StringType()),
        ("food_item", StringType()),
        ("category", StringType()),
        ("Calories", DoubleType()),
        ("Protein", DoubleType()),
        ("Carbs", DoubleType()),
        ("Fat", DoubleType()),
        ("plan_state", StringType()),           # Recovery / Tired / Energetic / Normal
        ("reasoning", StringType()),            # Human-readable explanation
        ("fatigue_score", DoubleType()),
        ("sleep_hours", DoubleType()),
        ("total_steps", DoubleType()),
        ("total_calories_burnt", DoubleType()),
        ("calories_intake", DoubleType()),
        ("total_Water_Intake (ml)", DoubleType()),  # adjusted water target
        ("energy_balance", DoubleType()),
        ("hydration_ratio", DoubleType()),
    )
])

# ----------------------------------------------------
# 2. Adaptive scaling logic with explanation
# ----------------------------------------------------
def adjust_daily_targets_explainable(row):
    """
    Adapt one user's daily targets to their recovery state and explain why.

    Args:
        row: joined plan/fatigue record (pyspark ``Row``).
            These fields are read as attributes: ``fatigue_score``,
            ``sleep_hours``, ``calories_intake``, ``total_calories_burnt``,
            ``water_ml`` and ``total_steps``.
            These fields are read by key: the ``total_*`` target columns,
            ``sleep_hours``, ``target_steps`` and ``target_calories_burnt``.
            A missing fatigue_score defaults to 0.8; other None values default
            to 0 or to the related measurement.

    Returns:
        A tuple ``(targets, state, reasoning, energy_balance, hydration_ratio)``:
          * targets: dict of adjusted targets keyed like the plan columns.
          * state: "Recovery" | "Tired" | "Energetic" | "Normal".
          * reasoning: "; "-joined human-readable explanations.
          * energy_balance: (intake - burnt) / intake, or None when
            calories_intake is missing.
          * hydration_ratio: water_ml / water target, or None when water_ml is
            missing.
    """
    # --- Safe getters: None -> default ---
    def g(x, default=0):
        return x if x is not None else default

    raw_intake = row.calories_intake
    raw_water = row.water_ml

    fatigue = g(row.fatigue_score, 0.8)
    sleep = g(row.sleep_hours, 0)
    intake = g(raw_intake, 0)
    burnt = g(row.total_calories_burnt, 0)
    water = g(raw_water, 0)
    steps = g(row.total_steps, 0)

    base_sleep = g(row["sleep_hours"], sleep)
    base_water = g(row["total_Water_Intake (ml)"], water)
    base_calories = g(row["total_Calories (kcal)"], intake)

    # --- Derived metrics ---
    # A missing measurement is "unknown", not zero. Treating a missing intake
    # as 0 kcal made energy_balance equal -burnt, which triggered a spurious
    # "deficit" adjustment (and likewise for water). Such ratios are now None
    # and the matching fine-tune step below is skipped.
    energy_balance = None if raw_intake is None else (intake - burnt) / max(intake, 1)
    hydration_ratio = None if raw_water is None else water / max(base_water, 1)

    # row.sleep_hours and row["sleep_hours"] read the same field, so the plan
    # value cannot serve as the reference. Measure the deficit against an
    # ideal 8 h night instead.
    sleep_deficit = max(0, (8 - sleep) / 8)

    # --- Fatigue logic ---
    tired = fatigue < 0.7
    overreached = fatigue < 0.5 and sleep_deficit > 0.2
    energetic = fatigue > 0.9 and sleep > 7

    # --- Initialize targets ---
    targets = {
        "total_Calories (kcal)": base_calories,
        "total_Protein (g)": g(row["total_Protein (g)"], 0),
        "total_Carbohydrates (g)": g(row["total_Carbohydrates (g)"], 0),
        "total_Fat (g)": g(row["total_Fat (g)"], 0),
        "total_Water_Intake (ml)": base_water,
        "sleep_hours": base_sleep,
        "target_steps": g(row["target_steps"], steps),
        "target_calories_burnt": g(row["target_calories_burnt"], burnt),
    }

    reasoning = []

    # --- State logic: multiplicative targets, additive sleep bonus ---
    if overreached:
        state = "Recovery"
        reasoning.append("Detected signs of overreaching: low fatigue score and sleep deficit.")
        targets["total_Calories (kcal)"] *= 0.95
        targets["total_Protein (g)"] *= 1.15
        targets["total_Fat (g)"] *= 0.9
        targets["target_steps"] *= 0.75
        targets["target_calories_burnt"] *= 0.75
        targets["sleep_hours"] += 1.0
        targets["total_Water_Intake (ml)"] *= 1.2

    elif tired:
        state = "Tired"
        reasoning.append("Mild fatigue detected — focusing on recovery and protein repair.")
        targets["total_Calories (kcal)"] *= 0.9
        targets["total_Protein (g)"] *= 1.1
        targets["total_Fat (g)"] *= 0.85
        targets["target_steps"] *= 0.85
        targets["target_calories_burnt"] *= 0.85
        targets["sleep_hours"] += 0.5
        targets["total_Water_Intake (ml)"] *= 1.1

    elif energetic:
        state = "Energetic"
        reasoning.append("High recovery — increasing load slightly for progression.")
        targets["total_Calories (kcal)"] *= 1.1
        targets["total_Protein (g)"] *= 1.05
        targets["total_Carbohydrates (g)"] *= 1.1
        targets["target_steps"] *= 1.15
        targets["target_calories_burnt"] *= 1.15
        targets["total_Water_Intake (ml)"] *= 1.05

    else:
        state = "Normal"
        reasoning.append("Balanced state — maintaining plan at normal levels.")

    # --- Fine-tune (only when the underlying measurement exists) ---
    if energy_balance is not None:
        if energy_balance < -0.15:
            targets["total_Calories (kcal)"] *= 1.05
            reasoning.append("Caloric deficit detected — slightly increasing intake.")
        elif energy_balance > 0.15:
            targets["total_Calories (kcal)"] *= 0.95
            reasoning.append("Caloric surplus detected — moderating intake.")

    if hydration_ratio is not None:
        if hydration_ratio < 0.9:
            targets["total_Water_Intake (ml)"] *= 1.15
            reasoning.append("Low hydration — increasing water intake.")
        elif hydration_ratio > 1.2:
            targets["total_Water_Intake (ml)"] *= 0.95
            reasoning.append("High hydration — lowering water target slightly.")

    return targets, state, "; ".join(reasoning), energy_balance, hydration_ratio


# ----------------------------------------------------
# 3. Adaptive meal plan generator
# ----------------------------------------------------
# Meal -> food categories eligible for that meal (keyword order = meal order).
meal_structure = dict(
    Breakfast=["Grains", "Dairy", "Fruits", "Beverages"],
    Lunch=["Meat", "Vegetables", "Grains"],
    Snack=["Snacks", "Fruits", "Beverages"],
    Dinner=["Meat", "Vegetables", "Grains", "Dairy"],
)
def generate_explainable_plan(row):
    """
    Build the explainable meal plan for one user-day.

    ``adjust_daily_targets_explainable(row)`` supplies the plan state, the
    reasoning text and the derived metrics. Then, for every meal in
    ``meal_structure``, up to three foods are sampled at random from
    ``food_data`` among those whose ``Category`` is allowed for that meal.

    Args:
        row: joined plan/fatigue record (pyspark ``Row``).

    Returns:
        A list of 19-tuples in the order of the module-level plan ``schema``:
        user_id, date, meal, food_item, category, Calories, Protein, Carbs,
        Fat, plan_state, reasoning, fatigue_score, sleep_hours, total_steps,
        total_calories_burnt, calories_intake, total_Water_Intake (ml),
        energy_balance, hydration_ratio. The list is empty when
        ``fatigue_score`` or ``sleep_hours`` is missing.

    NOTE(review): food choice is random and does not use the calorie/macro
    targets. Only the adjusted water target is carried into the output rows.
    """
    # Skip the entire day when the fatigue/sleep signal is missing; checked
    # before any other work is done.
    if row.fatigue_score is None or row.sleep_hours is None:
        return []

    def safe_float(x):
        return float(x) if x is not None else None

    targets, state, reasoning, energy_balance, hydration_ratio = adjust_daily_targets_explainable(row)

    # food_data is normally a Broadcast of food_df.collect(); accept a plain
    # list too. Spark Rows have no .get(), so normalize each record to a dict.
    foods = [
        f.asDict() if hasattr(f, "asDict") else dict(f)
        for f in getattr(food_data, "value", food_data)
    ]

    meals = []
    for meal, categories in meal_structure.items():
        # NOTE(review): with the current meal_structure, Lunch/Dinner contain
        # none of these categories, so this filter is a no-op.
        if state in ["Tired", "Recovery"] and meal in ["Lunch", "Dinner"]:
            categories = [c for c in categories if c not in ["Snacks", "Fried", "HighFat"]]

        allowed = [f for f in foods if f.get("Category") in categories]
        if not allowed:
            continue

        for f in random.sample(allowed, min(3, len(allowed))):
            meals.append((
                row.user_id,
                row.date,
                meal,
                f["Food_Item"],
                f["Category"],
                safe_float(f.get("Calories (kcal)")),
                safe_float(f.get("Protein (g)")),
                safe_float(f.get("Carbohydrates (g)")),
                safe_float(f.get("Fat (g)")),
                state,
                reasoning,
                safe_float(row.fatigue_score),
                safe_float(row.sleep_hours),
                safe_float(row.total_steps),
                safe_float(row.total_calories_burnt),
                safe_float(row.calories_intake),
                safe_float(targets.get("total_Water_Intake (ml)")),
                safe_float(energy_balance),
                safe_float(hydration_ratio)
            ))
    return meals



# ----------------------------------------------------
# 4. Merge daily targets and fatigue data
# ----------------------------------------------------
# Left-join the day's targets with the fatigue metrics on (user_id, date).
# NOTE(review): generate_explainable_plan drops rows whose fatigue_score is
# null, which is what this left join yields for unmatched keys. Check that
# user_id formats match on both sides (the cleaning logs show both
# '4020332650' and 'user_4020332650' style ids).
daily_df = df_plans.join(fatigue_df, how="left", on=["user_id", "date"])
# Reference foods, collected to the driver as Rows and shipped to executors.
food_data = spark.sparkContext.broadcast(food_df.collect())


In [25]:
# ----------------------------------------------------
# 5. Generate adaptive explainable plans
# ----------------------------------------------------
# Food selection inside generate_explainable_plan is random, so without
# caching every Spark action (show, count, the Firestore write) re-runs the
# flatMap and draws different foods. Caching keeps those results consistent
# as long as the cached partitions are not evicted.
all_plans = (
    daily_df.rdd.flatMap(generate_explainable_plan)
    .toDF(schema)
    .cache()
)
# ----------------------------------------------------
# 6. User summaries
# ----------------------------------------------------
summary_df = (
    all_plans.groupBy("user_id")
    .agg(
        F.sum("Calories").alias("TotalCalories"),
        F.sum("Protein").alias("TotalProtein"),
        F.sum("Carbs").alias("TotalCarbs"),
        F.sum("Fat").alias("TotalFat"),
        F.first("plan_state").alias("State"),
        F.first("reasoning").alias("Explanation")
    )
)
summary_df.show(10)

all_plans.show(10)
all_plans.count()

+-------+-------------+------------+----------+--------+-----+-----------+
|user_id|TotalCalories|TotalProtein|TotalCarbs|TotalFat|State|Explanation|
+-------+-------------+------------+----------+--------+-----+-----------+
+-------+-------------+------------+----------+--------+-----+-----------+

+-------+----+----+---------+--------+--------+-------+-----+---+----------+---------+-------------+-----------+-----------+--------------------+---------------+-----------------------+--------------+---------------+
|user_id|date|meal|food_item|category|Calories|Protein|Carbs|Fat|plan_state|reasoning|fatigue_score|sleep_hours|total_steps|total_calories_burnt|calories_intake|total_Water_Intake (ml)|energy_balance|hydration_ratio|
+-------+----+----+---------+--------+--------+-------+-----+---+----------+---------+-------------+-----------+-----------+--------------------+---------------+-----------------------+--------------+---------------+
+-------+----+----+---------+--------+--------+-

0

In [21]:
import logging
import time
import datetime
import decimal
import json
from google.cloud import firestore
from google.api_core.exceptions import GoogleAPICallError, RetryError
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, FloatType, BooleanType, DateType
import random

# ----------------------------
# 1. Logging Setup
# ----------------------------
fs_logger = logging.getLogger("firestore-writer")
fs_logger.setLevel(logging.INFO)

# Attach a console handler only when none is visible yet, so re-running this
# cell does not duplicate output. hasHandlers() also consults ancestor loggers.
if not fs_logger.hasHandlers():
    h = logging.StreamHandler()
    fs_logger.addHandler(h)
    h.setFormatter(logging.Formatter("[%(levelname)s] %(asctime)s - %(message)s"))


# ----------------------------
# 2. Helper Functions
# ----------------------------
def _serialize_value(v):
    if v is None:
        return None
    if isinstance(v, (datetime.datetime, datetime.date)):
        return v.isoformat()
    if isinstance(v, decimal.Decimal):
        return float(v)
    if isinstance(v, (bytes, bytearray)):
        try:
            return v.decode()
        except Exception:
            return str(v)
    try:
        json.dumps(v)
        return v
    except Exception:
        return str(v)


def _commit_with_retries(batch_obj, max_retries=3, base_backoff=0.5):
    attempt = 0
    while True:
        try:
            batch_obj.commit()
            return
        except (GoogleAPICallError, RetryError, IOError) as e:
            attempt += 1
            if attempt > max_retries:
                fs_logger.exception("Firestore commit failed after %d attempts", attempt - 1)
                raise
            backoff = base_backoff * (2 ** (attempt - 1))
            fs_logger.warning(
                "Transient error committing firestore batch (attempt %d). Backing off %.2fs. Error: %s",
                attempt, backoff, str(e)
            )
            time.sleep(backoff)

def _throttle(state, max_writes_per_minute):
    """
    Enforce a fixed-window per-minute write budget for one writer.

    ``state`` is the writer-owned dict with keys ``writes_this_minute`` and
    ``minute_window_start``. The window is reset once 60 s have passed. If the
    budget of the current window is already used up, this sleeps until the
    window ends and then starts a new one. It does not increment the counter;
    the caller does that after queuing a write.
    """
    now = time.time()
    if now - state["minute_window_start"] >= 60:
        state["writes_this_minute"] = 0
        state["minute_window_start"] = now

    if state["writes_this_minute"] >= max_writes_per_minute:
        sleep_for = max(0.0, 60 - (now - state["minute_window_start"]))
        fs_logger.warning(
            f"Rate limit reached. Sleeping {sleep_for:.1f}s to stay under {max_writes_per_minute}/min limit."
        )
        time.sleep(sleep_for)
        state["writes_this_minute"] = 0
        state["minute_window_start"] = time.time()


def make_firestore_writer(
    collection_name,
    firestore_client,
    batch_size=500,
    max_retries=3,
    max_writes_per_run=20000,     # Protect daily free-tier
    max_writes_per_minute=500     # Optional rate limit
):
    """
    Build a ``foreachBatch``-style callback that writes DataFrame rows to Firestore.

    Each row becomes a new auto-id document at
    ``users/{user_id}/{collection_name}/{auto-id}``. Rows without a
    ``user_id`` are skipped. The counters live in a per-writer state dict, so:
      * ``max_writes_per_run`` caps the documents this writer queues across
        ALL of its calls;
      * ``max_writes_per_minute`` throttles how fast writes are queued.

    Args:
        collection_name: sub-collection name under each user document.
        firestore_client: Firestore client providing ``batch()`` and
            ``collection()``.
        batch_size: number of writes per committed batch.
        max_retries: retries passed to ``_commit_with_retries``.
        max_writes_per_run: overall write cap for this writer.
        max_writes_per_minute: per-minute write budget.

    Returns:
        ``write_batch_to_firestore(batch_df, epoch_id=None)``.
    """
    # Mutable state shared by the closure; a dict avoids `nonlocal` rebinding.
    state = {
        "writes_this_run": 0,
        "writes_this_minute": 0,
        "minute_window_start": time.time(),
    }

    def write_batch_to_firestore(batch_df, epoch_id=None):
        rows = batch_df.count()
        if not rows:
            fs_logger.info("[epoch %s] empty, skipping collection=%s", str(epoch_id), collection_name)
            return

        fs_logger.info("[epoch %s] writing %s rows to Firestore collection '%s'",
                       str(epoch_id), rows, collection_name)

        docs_written = 0
        ops_in_current_batch = 0
        fs_batch = firestore_client.batch()

        for row in batch_df.toLocalIterator():
            if state["writes_this_run"] >= max_writes_per_run:
                fs_logger.warning(
                    f"[epoch {epoch_id}] Stopped early — reached max_writes_per_run={max_writes_per_run}"
                )
                break

            data = row.asDict(recursive=True)
            user_id = data.get("user_id")
            if not user_id:
                fs_logger.warning("[epoch %s] skipping row without user_id", str(epoch_id))
                continue

            # Only throttle rows that will actually be written.
            _throttle(state, max_writes_per_minute)

            payload = {k: _serialize_value(v) for k, v in data.items()}
            doc_ref = (
                firestore_client.collection("users")
                .document(str(user_id))
                .collection(collection_name)
                .document()
            )

            fs_batch.set(doc_ref, payload)
            ops_in_current_batch += 1
            state["writes_this_run"] += 1
            state["writes_this_minute"] += 1

            if ops_in_current_batch >= batch_size:
                _commit_with_retries(fs_batch, max_retries=max_retries)
                docs_written += ops_in_current_batch
                ops_in_current_batch = 0
                fs_batch = firestore_client.batch()

        if ops_in_current_batch > 0:
            _commit_with_retries(fs_batch, max_retries=max_retries)
            docs_written += ops_in_current_batch

        fs_logger.info("[epoch %s] wrote %d docs under users/*/%s/",
                       str(epoch_id), docs_written, collection_name)
    return write_batch_to_firestore



In [37]:
# Build a writer for users/{user_id}/daily_plan and push every generated plan row.
# NOTE(review): the epoch label "2025-11-13" differs from the plan date loaded
# earlier ("2025-11-12"). epoch_id only appears in the writer's log lines;
# confirm which date is intended.
writer = make_firestore_writer(collection_name="daily_plan", firestore_client=db)
writer(all_plans, epoch_id="2025-11-13")

[INFO] 2025-11-12 20:18:53,404 - [epoch None] writing 64800 rows to Firestore collection 'daily_plan'
[ERROR] 2025-11-12 20:25:27,986 - Firestore commit failed after 3 attempts
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/google/api_core/grpc_helpers.py", line 75, in error_remapped_callable
    return callable_(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/grpc/_interceptor.py", line 277, in __call__
    response, ignored_call = self._with_call(
  File "/usr/local/lib/python3.8/dist-packages/grpc/_interceptor.py", line 332, in _with_call
    return call.result(), call
  File "/usr/local/lib/python3.8/dist-packages/grpc/_channel.py", line 440, in result
    raise self
  File "/usr/local/lib/python3.8/dist-packages/grpc/_interceptor.py", line 315, in continuation
    response, call = self._thunk(new_method).with_call(
  File "/usr/local/lib/python3.8/dist-packages/grpc/_channel.py", line 1198, in with_call
    return _end_unary_respo

RetryError: Timeout of 60.0s exceeded, last exception: 429 Quota exceeded.