##### Copyright 2025 Google LLC.
Licensed under the Apache 2.0 License.

In [None]:
# @title Licensed under the Apache 2.0 License (the "License"); { display-mode: "form" }
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

This Colab notebook provides code and examples for working with the long-form case studies from the paper: Khasentino, Belyaeva, Liu, Yang, Furlotte et al., "A personal health large language model for sleep and fitness coaching," *Nature Medicine* (2025).

In [None]:
# Imports.
import collections
import dataclasses
import json
import re

from matplotlib import ticker
from matplotlib import transforms
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Default locations of the released JSONL files. They must be copied from the
# GitHub repository into the Colab runtime before the loading cell is run.
SLEEP_CASE_STUDIES_JSONL = '/content/sleep_case_studies.all.jsonl'
FITNESS_CASE_STUDIES_JSONL = '/content/fitness_case_studies.all.jsonl'

# Every value accepted for the `split` field of a case study.
SPLITS = ('train', 'validation', 'test', 'holdout', 'holdout_model')


# Maps each expected age-bucket string to a single representative age: the
# midpoint of every closed five-year bucket '[lo-hi]' from 30 to 79, plus fixed
# stand-in values for the open-ended '<30' and '80+' buckets.
BUCKETIZED_AGE_TO_MIDPOINT = {
    '80+': 82,
    '<30': 27,
    **{f'[{lo}-{lo + 4}]': lo + 2 for lo in range(30, 80, 5)},
}

# Maps both the capitalized and the lowercase spelling of each sex to its
# capitalized form.
SEX_TO_NORMALIZED_SEX = {
    spelling: canonical
    for canonical in ('Male', 'Female')
    for spelling in (canonical, canonical.lower())
}


@dataclasses.dataclass(frozen=True, kw_only=True)
class SleepCaseStudy:
  """One long-form sleep case study, as stored in the sleep JSONL file.

  The user's age bucket and sex are not separate fields; the `midpoint_age`
  and `sex` properties parse them on demand from the first line of `input`.
  """

  # Meta fields.
  case_study_id: str
  user_id: str
  vertical: str
  split: str
  # Input fields.
  input: str
  # Output fields.
  insight_output: str
  etiology_output: str
  recommendation_output: str

  def __post_init__(self):
    """Validates the meta fields and that the demographics can be parsed.

    Raises:
      ValueError: If `vertical` is not "sleep", `split` is not in `SPLITS`, or
        the age bucket / sex cannot be parsed from the first line of `input`.
    """
    if self.vertical != 'sleep':
      raise ValueError(f'Vertical must be "sleep", got {self.vertical}')
    if self.split not in SPLITS:
      raise ValueError(f'Split must be in {SPLITS}, got {self.split}')
    # Ensure these can be successfully computed; the values are discarded.
    _ = self.midpoint_age
    _ = self.sex

  def _first_input_line(self) -> str:
    """Returns the first line of `input`, which is parsed for age and sex."""
    return self.input.split('\n')[0]

  @property
  def midpoint_age(self) -> int:
    """Returns the midpoint of the bucketized age of the user.

    Raises:
      ValueError: If no age bucket can be found in the first line of `input`,
        or the bucket is not a key of `BUCKETIZED_AGE_TO_MIDPOINT`.
    """
    line = self._first_input_line()
    match = re.search(r'male, (.*) years old.$', line)
    if match is None:
      raise ValueError(f'Could not parse age from sleep input line: {line!r}')
    age_str = match.group(1)
    if age_str not in BUCKETIZED_AGE_TO_MIDPOINT:
      raise ValueError(f'Unknown age bucket {age_str!r} in line: {line!r}')
    return BUCKETIZED_AGE_TO_MIDPOINT[age_str]

  @property
  def sex(self) -> str:
    """Returns the normalized sex of the user.

    Raises:
      ValueError: If the first line of `input` does not start with
        "The user is female," or "The user is male,".
    """
    line = self._first_input_line()
    match = re.match(r'The user is (female|male),', line)
    if match is None:
      raise ValueError(f'Could not parse sex from sleep input line: {line!r}')
    return SEX_TO_NORMALIZED_SEX[match.group(1)]


@dataclasses.dataclass(frozen=True, kw_only=True)
class FitnessCaseStudy:
  """One long-form fitness case study, as stored in the fitness JSONL file.

  Demographics are parsed on demand from `demographics_input`, which is read
  line by line: line 0 "Age: ...", line 1 "Height: ...m", line 2
  "Weight: ...kg", line 3 "BMI: ..." and the last line "Gender: ...".
  """

  # Meta fields.
  case_study_id: str
  user_id: str
  vertical: str
  split: str
  # Input fields.
  demographics_input: str
  training_load_input: str
  sleep_input: str
  health_metrics_input: str
  muscle_soreness_input: str
  subjective_readiness_input: str
  # Output fields.
  demographics_output: str
  training_load_output: str
  sleep_output: str
  health_metrics_output: str
  readiness_assessment_output: str

  def __post_init__(self):
    """Validates the meta fields and that all demographics can be parsed.

    Raises:
      ValueError: If `vertical` is not "fitness", `split` is not in `SPLITS`,
        or any demographic value cannot be parsed from `demographics_input`.
    """
    if self.vertical != 'fitness':
      raise ValueError(f'Vertical must be "fitness", got {self.vertical}')
    if self.split not in SPLITS:
      raise ValueError(f'Split must be in {SPLITS}, got {self.split}')
    # Ensure these can be successfully computed; the values are discarded.
    _ = self.midpoint_age
    _ = self.height
    _ = self.weight
    _ = self.bmi
    _ = self.sex

  def _parse_demographics_line(
      self, index: int, pattern: str, field_name: str
  ) -> str:
    """Returns group 1 of `pattern` matched at the start of a demographics line.

    Args:
      index: Index of the line within `demographics_input` (negative allowed).
      pattern: Regex applied with `re.match` to that line.
      field_name: Human-readable field name used in error messages.

    Raises:
      ValueError: If the line does not exist or does not match `pattern`.
    """
    lines = self.demographics_input.split('\n')
    try:
      line = lines[index]
    except IndexError:
      raise ValueError(
          f'Missing {field_name} line (index {index}) in demographics input: '
          f'{self.demographics_input!r}'
      ) from None
    match = re.match(pattern, line)
    if match is None:
      raise ValueError(f'Could not parse {field_name} from line: {line!r}')
    return match.group(1)

  @property
  def midpoint_age(self) -> int:
    """Returns the midpoint of the bucketized age of the user.

    Raises:
      ValueError: If the age line is missing/malformed or the bucket is not a
        key of `BUCKETIZED_AGE_TO_MIDPOINT`.
    """
    age_str = self._parse_demographics_line(0, r'Age: (.*)$', 'age')
    if age_str not in BUCKETIZED_AGE_TO_MIDPOINT:
      raise ValueError(f'Unknown age bucket: {age_str!r}')
    return BUCKETIZED_AGE_TO_MIDPOINT[age_str]

  @property
  def height(self) -> float:
    """Returns the height of the user in meters (per the "m" suffix)."""
    return float(
        self._parse_demographics_line(1, r'Height: (.*)m$', 'height')
    )

  @property
  def weight(self) -> float:
    """Returns the weight of the user in kilograms (per the "kg" suffix)."""
    return float(
        self._parse_demographics_line(2, r'Weight: (.*)kg$', 'weight')
    )

  @property
  def bmi(self) -> float:
    """Returns the BMI of the user."""
    return float(self._parse_demographics_line(3, r'BMI: (.*)$', 'BMI'))

  @property
  def sex(self) -> str:
    """Returns the normalized sex of the user, read from the last line."""
    sex = self._parse_demographics_line(
        -1, r'Gender: (Male|Female)$', 'gender'
    )
    return SEX_TO_NORMALIZED_SEX[sex]


def _validate_sleep_case_studies(case_studies: list[SleepCaseStudy]) -> None:
  """Raises ValueError unless `case_studies` matches the released sleep set.

  The checks pin the exact release: 557 studies with 507 unique case study
  ids and 507 unique users, the per-split counts below, and a `holdout_model`
  split that covers exactly the `holdout` case studies with identical inputs.
  """
  if len(case_studies) != 557:
    raise ValueError(f'Expected 557 sleep studies: {len(case_studies)=}')

  num_unique_ids = len({cs.case_study_id for cs in case_studies})
  num_unique_users = len({cs.user_id for cs in case_studies})
  if num_unique_ids != 507 or num_unique_users != 507:
    raise ValueError(
        f'Unexpected total counts {num_unique_ids=}, {num_unique_users=}'
    )

  expected_split_counts = {
      'train': 319,
      'validation': 69,
      'test': 69,
      'holdout': 50,
      'holdout_model': 50,
  }
  split_counts = collections.Counter(cs.split for cs in case_studies)
  if split_counts != expected_split_counts:
    raise ValueError(f'Unexpected split counts: {split_counts}')

  # Index both holdout splits by case study id in a single pass; a later
  # study with a repeated id replaces the earlier one.
  holdout_studies = {}
  holdout_model_studies = {}
  for cs in case_studies:
    if cs.split == 'holdout':
      holdout_studies[cs.case_study_id] = cs
    elif cs.split == 'holdout_model':
      holdout_model_studies[cs.case_study_id] = cs

  if len(holdout_studies) != 50:
    raise ValueError(f'Expected 50 holdout studies: {len(holdout_studies)=}')
  if sorted(holdout_studies) != sorted(holdout_model_studies):
    raise ValueError('Holdout and holdout_model are on different case studies.')

  def input_fields(study: SleepCaseStudy) -> tuple[str, ...]:
    # Everything except the model/expert outputs must agree across the pair.
    return study.case_study_id, study.user_id, study.vertical, study.input

  for case_study_id in holdout_studies:
    holdout_study = holdout_studies[case_study_id]
    holdout_model_study = holdout_model_studies[case_study_id]
    if input_fields(holdout_study) == input_fields(holdout_model_study):
      continue
    raise ValueError(
        'Holdout and model studies have different inputs: '
        f'{holdout_study=}, {holdout_model_study=}'
    )


def _validate_fitness_case_studies(
    case_studies: list[FitnessCaseStudy],
) -> None:
  """Raises ValueError unless `case_studies` matches the released fitness set.

  The checks pin the exact release: 400 studies, 350 unique case study ids,
  58 unique users, the per-split counts below, and a `holdout_model` split
  that covers exactly the `holdout` case studies with identical inputs.
  """
  if len(case_studies) != 400:
    raise ValueError(f'Expected 400 fitness studies: {len(case_studies)=}')

  distinct_ids = {cs.case_study_id for cs in case_studies}
  distinct_users = {cs.user_id for cs in case_studies}
  if len(distinct_ids) != 350:
    raise ValueError(
        f'Unexpected number of unique ids: {len(distinct_ids)}'
    )
  if len(distinct_users) != 58:
    raise ValueError(
        f'Unexpected number of unique users: {len(distinct_users)}'
    )

  expected_split_counts = {
      'train': 210,
      'validation': 45,
      'test': 45,
      'holdout': 50,
      'holdout_model': 50,
  }
  split_counts = collections.Counter(cs.split for cs in case_studies)
  if split_counts != expected_split_counts:
    raise ValueError(f'Unexpected split counts: {split_counts}')

  # Index both holdout splits by case study id in a single pass; a later
  # study with a repeated id replaces the earlier one.
  holdout_studies = {}
  holdout_model_studies = {}
  for cs in case_studies:
    if cs.split == 'holdout':
      holdout_studies[cs.case_study_id] = cs
    elif cs.split == 'holdout_model':
      holdout_model_studies[cs.case_study_id] = cs

  if len(holdout_studies) != 50:
    raise ValueError(f'Expected 50 holdout studies: {len(holdout_studies)=}')
  if sorted(holdout_studies) != sorted(holdout_model_studies):
    raise ValueError('Holdout and holdout_model are on different case studies.')

  def input_fields(study: FitnessCaseStudy) -> tuple[str, ...]:
    # Everything except the model/expert outputs must agree across the pair.
    return (
        study.case_study_id,
        study.user_id,
        study.vertical,
        study.demographics_input,
        study.training_load_input,
        study.sleep_input,
        study.health_metrics_input,
        study.muscle_soreness_input,
        study.subjective_readiness_input,
    )

  for case_study_id in holdout_studies:
    holdout_study = holdout_studies[case_study_id]
    holdout_model_study = holdout_model_studies[case_study_id]
    if input_fields(holdout_study) == input_fields(holdout_model_study):
      continue
    raise ValueError(
        'Holdout and model studies have different inputs: '
        f'{holdout_study=}, {holdout_model_study=}'
    )


def load_case_studies(
    vertical: str, *, path: str | None = None
) -> list[SleepCaseStudy | FitnessCaseStudy]:
  """Loads all case studies for the given `vertical`.

  Each non-blank line of the JSONL file is decoded as one JSON object whose
  keys are passed as keyword arguments to the vertical's case study
  dataclass. The resulting list is then checked by the vertical's validation
  function.

  Args:
    vertical: Either "sleep" or "fitness".
    path: Optional path to the JSONL file. Defaults to
      `SLEEP_CASE_STUDIES_JSONL` or `FITNESS_CASE_STUDIES_JSONL`, depending on
      `vertical`.

  Returns:
    A list of all case studies of the given vertical, in file order.

  Raises:
    ValueError: If `vertical` is not in ("sleep", "fitness").
    ValueError: Any expected post-condition of the dataset is violated.
  """
  if vertical == 'sleep':
    default_path = SLEEP_CASE_STUDIES_JSONL
    case_study_cls = SleepCaseStudy
    validate = _validate_sleep_case_studies
  elif vertical == 'fitness':
    default_path = FITNESS_CASE_STUDIES_JSONL
    case_study_cls = FitnessCaseStudy
    validate = _validate_fitness_case_studies
  else:
    raise ValueError(f'Vertical must be in ("sleep", "fitness"): {vertical}')

  # JSON text is UTF-8 (RFC 8259), so don't rely on the platform's locale
  # encoding. Blank lines (e.g. a stray empty line) are skipped rather than
  # crashing json.loads.
  with open(
      default_path if path is None else path, encoding='utf-8'
  ) as f:
    case_studies = [
        case_study_cls(**json.loads(line)) for line in f if line.strip()
    ]
  # Validate after the file has been closed.
  validate(case_studies)
  return case_studies

In [None]:
# Load all case studies.
# Each call reads the vertical's JSONL file and runs its dataset validation;
# load_case_studies raises ValueError if the data does not match expectations.
g_sleep_studies = load_case_studies('sleep')
g_fitness_studies = load_case_studies('fitness')

# Code to plot case study figures.


In [None]:
# Set general plot style, shared by every figure below: seaborn's plain
# "white" theme with Liberation Sans as the font family, and no top or right
# axis spines.
sns.set_style('white', rc={'font.family': ['Liberation Sans']})
plt.rcParams.update({'axes.spines.top': False, 'axes.spines.right': False})


def plot_extended_data_fig_10(
    sleep_case_studies: list[SleepCaseStudy],
    fitness_case_studies: list[FitnessCaseStudy],
) -> None:
  """Plots Extended Data Fig. 10: age and gender of the case study users.

  The figure has a "Sleep" row (panel "a") and a "Fitness" row (panel "b").
  In each row, the left panel is a histogram of bucketized-age midpoints and
  the right panel a bar chart of gender counts. Each user is counted once,
  using the values from their last case study in input order.

  Args:
    sleep_case_studies: All sleep case studies, e.g. from
      `load_case_studies('sleep')`.
    fitness_case_studies: All fitness case studies, e.g. from
      `load_case_studies('fitness')`.

  Raises:
    ValueError: If there are not exactly 507 unique sleep users or exactly 58
      unique fitness users.
  """
  # Restrict to individuals: one entry per user_id.
  individual_sleep_ages = {
      cs.user_id: cs.midpoint_age for cs in sleep_case_studies
  }
  individual_sleep_genders = {cs.user_id: cs.sex for cs in sleep_case_studies}
  individual_fitness_ages = {
      cs.user_id: cs.midpoint_age for cs in fitness_case_studies
  }
  individual_fitness_genders = {
      cs.user_id: cs.sex for cs in fitness_case_studies
  }
  # Explicit checks (not `assert`, which is stripped under -O), done before
  # any figure is created so a failure leaves no empty figure behind.
  if len(individual_sleep_ages) != 507:
    raise ValueError(
        f'Expected 507 sleep users, got {len(individual_sleep_ages)}'
    )
  if len(individual_fitness_ages) != 58:
    raise ValueError(
        f'Expected 58 fitness users, got {len(individual_fitness_ages)}'
    )

  fig = plt.figure(figsize=(8, 10), dpi=300)
  subfigs = fig.subfigures(2, 1)
  ax_top = subfigs[0].subplots(1, 2, sharey=False)
  ax_bottom = subfigs[1].subplots(1, 2, sharey=False)

  large_font_size = 16
  small_font_size = 12
  color = '#4285f4'

  # Plot age sections.
  subfigs[0].suptitle('Sleep', fontsize=large_font_size, weight='bold')
  subfigs[1].suptitle('Fitness', fontsize=large_font_size, weight='bold')

  # 5-year bins [25, 30), [30, 35), ..., [80, 85); each bucket midpoint in
  # BUCKETIZED_AGE_TO_MIDPOINT falls inside exactly one bin.
  age_bins = list(range(25, 90, 5))
  for ax, ages, ymax in [
      (ax_top[0], list(individual_sleep_ages.values()), 125),
      (ax_bottom[0], list(individual_fitness_ages.values()), 25),
  ]:
    ax.hist(ages, bins=age_bins, color=color)
    ax.set_xlabel('Age', fontsize=small_font_size)
    ax.set_ylabel('Number of individuals', fontsize=small_font_size)
    ax.set_ylim((0, ymax))
    ax.tick_params(
        bottom=True,
        left=True,
        width=1.5,
        direction='inout',
        labelsize=small_font_size,
    )
    # Ticks sit at the centers of selected bins, labelled with bucket names.
    ax.set_xticks(
        ticks=[27.5, 42.5, 62.5, 82.5],
        labels=['<30', '40-44', '60-64', '  80+'],
        fontsize=small_font_size,
    )

  # Plot gender sections (bars in sorted label order).
  sleep_gender_counts = collections.Counter(individual_sleep_genders.values())
  sleep_labels, sleep_values = zip(*sorted(sleep_gender_counts.items()))
  fitness_gender_counts = collections.Counter(
      individual_fitness_genders.values()
  )
  fitness_labels, fitness_values = zip(*sorted(fitness_gender_counts.items()))
  for ax, labels, values, ymax in [
      (ax_top[1], sleep_labels, sleep_values, 270),
      (ax_bottom[1], fitness_labels, fitness_values, 50),
  ]:
    ax.bar(labels, values, color=color)
    ax.set_xlabel('Gender', fontsize=small_font_size)
    ax.set_ylabel('', fontsize=small_font_size)
    ax.set_ylim((0, ymax))
    ax.tick_params(
        bottom=False,
        left=True,
        width=1.5,
        direction='inout',
        labelsize=small_font_size,
    )
    ax.set_xticks(
        ticks=ax.get_xticks(),
        labels=labels,
        fontsize=small_font_size,
    )

  # Add a, b panel labels in figure coordinates, nudged 20pt left and 7pt up.
  trans = transforms.ScaledTranslation(-20 / 72, 7 / 72, fig.dpi_scale_trans)
  ax_top[0].text(
      0.05,
      0.95,
      'a',
      transform=fig.transFigure + trans,
      fontsize=24,
      va='bottom',
      weight='bold',
  )
  ax_bottom[0].text(
      0.05,
      0.45,
      'b',
      transform=fig.transFigure + trans,
      fontsize=24,
      va='bottom',
      weight='bold',
  )
  plt.show()


def plot_supplementary_fig_13(
    fitness_case_studies: list[FitnessCaseStudy],
) -> None:
  """Plots the supplementary figure 13.

  Draws three side-by-side histograms of the fitness users' BMI, height (m)
  and weight (kg) under a "Fitness" title. Each user is counted once, using
  the values from their last case study in input order.

  Args:
    fitness_case_studies: All fitness case studies, e.g. from
      `load_case_studies('fitness')`.

  Raises:
    ValueError: If the case studies do not cover exactly 58 unique users.
  """
  # Restrict to individuals: one entry per user_id.
  individual_bmi = {cs.user_id: cs.bmi for cs in fitness_case_studies}
  individual_height = {cs.user_id: cs.height for cs in fitness_case_studies}
  individual_weight = {cs.user_id: cs.weight for cs in fitness_case_studies}
  # Explicit check (not `assert`, which is stripped under -O), done before
  # any figure is created so a failure leaves no empty figure behind.
  if len(individual_bmi) != 58:
    raise ValueError(f'Expected 58 fitness users, got {len(individual_bmi)}')

  fig = plt.figure(figsize=(10, 5), dpi=300)
  subfigs = fig.subfigures(1, 1)
  ax_top = subfigs.subplots(1, 3, sharey=False)

  large_font_size = 16
  small_font_size = 12
  color = '#4285f4'

  fig.suptitle('Fitness', fontsize=large_font_size, weight='bold')
  for ax, data, xlabel, ylabel in [
      (
          ax_top[0],
          list(individual_bmi.values()),
          'Body mass index (BMI)',
          'Number of individuals',
      ),
      (ax_top[1], list(individual_height.values()), 'Height (m)', ''),
      (ax_top[2], list(individual_weight.values()), 'Weight (kg)', ''),
  ]:
    ax.hist(data, color=color)
    ax.set_xlabel(xlabel, fontsize=small_font_size)
    ax.set_ylabel(ylabel, fontsize=small_font_size)
    # Shared fixed y-range so the three panels are directly comparable.
    ax.set_ylim((0, 16))
    ax.xaxis.set_major_locator(ticker.MaxNLocator(nbins=5))
    ax.tick_params(
        bottom=False,
        left=True,
        width=1.5,
        direction='inout',
        labelsize=small_font_size,
    )

  plt.show()

In [None]:
# Render Extended Data Fig. 10 from the loaded sleep and fitness case studies.
plot_extended_data_fig_10(g_sleep_studies, g_fitness_studies)

In [None]:
# Render Supplementary Fig. 13 from the loaded fitness case studies.
plot_supplementary_fig_13(g_fitness_studies)