```
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```

# PH-LLM figures

TODO(jtcosentino): Flesh out README, overall documentation, and function doc
strings.

This notebook reproduces a subset of the main and extended data figures from the
PH-LLM manuscript. Specifically:

-   Fig. 2.{c,d}: Long-form case study evaluation and performance.
-   Extended Data Fig. 3: Pairwise Gwet's AC2 measuring inter-rater reliability
    between primary and secondary raters.
-   Extended Data Fig. 4: Contingency tables showing pairwise rating agreement
    between raters.
-   Extended Data Fig. 5: Sleep and fitness case study human evaluation results
    by principle.
-   Extended Data Fig. 6: Contingency tables showing pairwise rating agreement
    between our best AutoRaters, their corresponding expert raters, and other
    experts.
-   Extended Data Fig. 7: Automatic evaluation of coaching recommendations
    across PH-LLM and baseline models.
-   Extended Data Fig. 8: Effect of fine-tuning data scale on model performance
    in coaching recommendations.

This notebook assumes that it will be run on
[Google Colab](https://colab.research.google.com/) and requires that the
following files be uploaded to the `/content/` directory:

-   `fitness_autoeval_external_ratings.tsv`: Fitness AutoEval ratings for
    external model comparisons (i.e., comparing PH-LLM, GPT 4 Turbo, Claude 3
    Opus, etc.).
-   `fitness_autoeval_subsample_ratings.tsv`: Fitness AutoEval ratings for
    subsampled model comparisons (i.e., comparing Gemini Ultra, PH-LLM trained
    with 25% of the training dataset, PH-LLM trained with 50% of the training
    dataset, and PH-LLM).
-   `fitness_human_expert_ratings.tsv`: Fitness human expert ratings.
-   `sleep_autoeval_external_ratings.tsv`: Sleep AutoEval ratings for external
    model comparisons (i.e., comparing PH-LLM, GPT 4 Turbo, Claude 3 Opus,
    etc.).
-   `sleep_autoeval_subsample_ratings.tsv`: Sleep AutoEval ratings for
    subsampled model comparisons (i.e., comparing Gemini Ultra, PH-LLM trained
    with 25% of the training dataset, PH-LLM trained with 50% of the training
    dataset, and PH-LLM).
-   `sleep_human_expert_ratings.tsv`: Sleep human expert ratings.
-   `pro_preds_bootstrap.tsv`: Bootstrapping results of model predictions for
    Patient Reported Outcomes (PRO).
-   `pro_prevalence.tsv`: Prevalence of binary targets in Patient Reported
    Outcomes (PRO).

Output figures are also written to the same `/content/` directory as PDFs:

-   `figure_2_cd.pdf`
-   `figure_3_cd.pdf`
-   `extended_data_figure_3.pdf`
-   `extended_data_figure_4.pdf`
-   `extended_data_figure_5.pdf`
-   `extended_data_figure_6.pdf`
-   `extended_data_figure_7.pdf`
-   `extended_data_figure_8.pdf`
-   `extended_data_figure_9.pdf`

## Install dependencies and prepare environment

In [None]:
!pip install irrCAC

# Google Colab comes with NumPy preinstalled. However, installing irrCAC also
# installs a newer version of NumPy that is required for irrCAC. Thus we need to
# restart the runtime for the changes to take effect.
# After the session is restarted, you can run the rest of the notebook.
import os
# Send signal 9 to the current process so that the runtime is terminated (and
# therefore restarted with the newly installed packages, as described above).
os.kill(os.getpid(), 9)

In [None]:
import abc
import collections
import concurrent.futures
import dataclasses
import enum
import itertools
import os
import string
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence

import immutabledict
import irrCAC.raw
from matplotlib import container
from matplotlib import figure
from matplotlib import text
from matplotlib import transforms
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats
import seaborn as sns
import sklearn.metrics
import statsmodels.stats.multitest

In [None]:
# Set global plotting styles: use seaborn's 'white' style and hide the top and
# right axes spines on every figure created after this cell runs.
sns.set_style('white')
plt.rc('axes.spines', top=False, right=False)

### Global constants

In [None]:
class RaterType(enum.Enum):
  """The role of a rater: either a primary or a secondary rater."""

  # The member values are the lowercase string forms of the member names.
  PRIMARY = 'primary'
  SECONDARY = 'secondary'


class Vertical(enum.Enum):
  """A case study vertical (fitness or sleep)."""

  # The member values are the lowercase string forms of the member names.
  FITNESS = 'fitness'
  SLEEP = 'sleep'


class RatingsSource(enum.Enum):
  """Where a set of ratings comes from (human experts or an AutoEval run)."""

  # The member values are the lowercase string forms of the member names.
  HUMAN_EXPERT = 'human_expert'
  AUTOEVAL_SUBSAMPLE = 'autoeval_subsample'
  AUTOEVAL_EXTERNAL = 'autoeval_external'


# A mapping of vertical to the base directory that holds its ratings TSV files.
# Both verticals currently point at the same directory.
_VERTICAL_TO_BASE_DIR: immutabledict.immutabledict[Vertical, str] = immutabledict.immutabledict({
    Vertical.FITNESS: '/content/',
    Vertical.SLEEP: '/content/',
})

# Ratings dataframe columns used when plotting.
COL_RATING: str = 'rating'
COL_RATING_HUMAN_EXPERT: str = 'rating_human_expert'
COL_RATING_AUTOEVAL_FITNESS: str = 'rating_autoeval_fitness_primary_c_high_var'
COL_RATING_AUTOEVAL_SLEEP: str = 'rating_autoeval_sleep_primary_c_high_var'
COL_CASE_STUDY_ID: str = 'case_study_id'
COL_CONVERSATION_SOURCE: str = 'conversation_source'
COL_SECTION_TAG: str = 'tag'
COL_PRINCIPLE: str = 'principle'
COL_RATER: str = 'rater'

# Columns derived from the dataframe and used within plotting functions.
DERIVED_COL_KEY: str = 'key'
DERIVED_COL_KEY_NO_RATER: str = 'key_no_rater'

# Conversation source keys (i.e., the source of a given response).
KEY_HUMAN_EXPERT: str = 'human_expert'
KEY_GEMINI_ULTRA: str = 'gemini_ultra'
KEY_PHLLM: str = 'phllm'
KEY_GPT_4_TURBO: str = 'gpt_4_turbo'
KEY_CLAUDE_3_OPUS: str = 'claude_3_opus'
KEY_SUBSAMPLE_25: str = 'phllm_subsample_25pct'
KEY_SUBSAMPLE_50: str = 'phllm_subsample_50pct'

# Color hex codes used when plotting.
COLOR_BLUE: str = '#4285f4'
COLOR_BLUE_LIGHT: str = '#D3E3FD'
COLOR_BLUE_DARK: str = '#076EFF'
COLOR_BLUE_MID: str = '#4FABFF'
COLOR_GRAY: str = '#34495E'
COLOR_GRAY_DARK: str = '#202124'
COLOR_GREEN: str = '#1E8E3E'
COLOR_ORANGE: str = '#FFB482'
COLOR_PINK_LIGHT: str = '#FFC7E8'
COLOR_PINK_DARK: str = '#FBA9D6'
COLOR_PURPLE_DARK: str = '#C58AF9'
COLOR_YELLOW: str = '#FBBC04'

# A mapping of conversation source keys to figure labels.
CONVERSATION_SOURCE_KEY_TO_LABEL: immutabledict.immutabledict[str, str] = (
    immutabledict.immutabledict({
        KEY_HUMAN_EXPERT: 'Human Expert',
        KEY_GEMINI_ULTRA: 'Gemini Ultra',
        KEY_PHLLM: 'PH-LLM',
        KEY_GPT_4_TURBO: 'GPT-4 Turbo',
        KEY_CLAUDE_3_OPUS: 'Claude 3 Opus',
        KEY_SUBSAMPLE_25: 'Subsample 25%',
        KEY_SUBSAMPLE_50: 'Subsample 50%',
    })
)

# A mapping of AutoEval model IDs to figure labels. The IDs are the AutoEval
# rating column names; both verticals share the same label.
AUTOEVAL_MODEL_TO_LABEL: immutabledict.immutabledict[str, str] = (
    immutabledict.immutabledict({
        COL_RATING_AUTOEVAL_FITNESS: 'Best AutoRater (C)',
        COL_RATING_AUTOEVAL_SLEEP: 'Best AutoRater (C)',
    })
)

# A mapping of conversation source keys to source colors.
CONVERSATION_SOURCE_PALETTE: immutabledict.immutabledict[str, str] = (
    immutabledict.immutabledict({
        KEY_HUMAN_EXPERT: COLOR_BLUE_LIGHT,
        KEY_GEMINI_ULTRA: COLOR_BLUE_MID,
        KEY_PHLLM: COLOR_BLUE_DARK,
        KEY_GPT_4_TURBO: COLOR_ORANGE,
        KEY_CLAUDE_3_OPUS: COLOR_PINK_LIGHT,
        KEY_SUBSAMPLE_25: COLOR_PINK_DARK,
        KEY_SUBSAMPLE_50: COLOR_PURPLE_DARK,
    })
)


# A tuple defining the order of conversation sources in the main figure.
CONVERSATION_SOURCE_ORDER_MAIN: tuple[str, ...] = (
    KEY_GEMINI_ULTRA,
    KEY_PHLLM,
    KEY_HUMAN_EXPERT,
)

# A tuple defining the order of conversation sources in the external figure:
# the main-figure sources followed by the external baseline models.
CONVERSATION_SOURCE_ORDER_EXTERNAL: tuple[str, ...] = (
    *CONVERSATION_SOURCE_ORDER_MAIN,
    KEY_GPT_4_TURBO,
    KEY_CLAUDE_3_OPUS,
)

# A tuple defining the order of conversation sources in the subsample figure.
CONVERSATION_SOURCE_ORDER_SUBSAMPLE: tuple[str, ...] = (
    KEY_GEMINI_ULTRA,
    KEY_SUBSAMPLE_25,
    KEY_SUBSAMPLE_50,
    KEY_PHLLM,
)

# A tuple defining the order of fitness section tags in figures.
FITNESS_SECTION_TAG_ORDER: tuple[str, ...] = (
    'training_load',
    'sleep',
    'health_metrics',
    'assessment',
)

# A tuple defining the order of sleep section tags in figures.
SLEEP_SECTION_TAG_ORDER: tuple[str, ...] = (
    'insights',
    'etiology',
    'recommendations',
)

# A mapping of section tags (both fitness and sleep) to figure labels.
SECTION_TAG_TO_LABEL: immutabledict.immutabledict[str, str] = (
    immutabledict.immutabledict({
        'training_load': 'Training Load',
        'sleep': 'Sleep',
        'health_metrics': 'Health Metrics',
        'assessment': 'Assessment',
        'insights': 'Insights',
        'etiology': 'Etiology',
        'recommendations': 'Recommendations',
    })
)

# A tuple defining the order of principles in figures.
PRINCIPLE_ORDER: tuple[str, ...] = (
    'important_user_data',
    'no_unimportant_user_data',
    'no_incorrect_user_data',
    'important_interpretations',
    'no_unimportant_interpretations',
    'no_incorrect_important_interpretations',
    'no_incorrect_unimportant_interpretations',
    'no_assumptions',
    'important_domain_knowledge',
    'no_unimportant_domain_knowledge',
    'no_incorrect_domain_knowledge',
    'no_hallucinations',
    'non_harmful',
    'readable',
    'overall_quality',
)

# A mapping of principles to figure labels.
PRINCIPLE_TO_LABEL: immutabledict.immutabledict[str, str] = (
    immutabledict.immutabledict({
        'important_user_data': 'Important User Data',
        'no_unimportant_user_data': 'No Unimportant User Data',
        'no_incorrect_user_data': 'No Incorrect User Data',
        'important_interpretations': 'Important Interpretations',
        'no_unimportant_interpretations': 'No Unimportant Interpretations',
        'no_incorrect_important_interpretations': (
            'No Incorrect Important Interpretations'
        ),
        'no_incorrect_unimportant_interpretations': (
            'No Incorrect Unimportant Interpretations'
        ),
        'no_assumptions': 'No Assumptions',
        'important_domain_knowledge': 'Important Domain Knowledge',
        'no_unimportant_domain_knowledge': 'No Unimportant Domain Knowledge',
        'no_incorrect_domain_knowledge': 'No Incorrect Domain Knowledge',
        'no_hallucinations': 'No Hallucinations',
        'non_harmful': 'Non Harmful',
        'readable': 'Readable',
        'overall_quality': 'Overall Quality',
    })
)

# A mapping of metrics to figure labels.
METRIC_TO_LABEL_BINARY: immutabledict.immutabledict[str, str] = immutabledict.immutabledict({
    'gwet_ac2': "Gwet's AC2",
})

# General labels used in plotting (e.g., axes).
LABEL_SECTION: str = 'Section'
LABEL_PRINCIPLE: str = 'Principle'
LABEL_AVG_RATING: str = 'Average Rating'


# Constants for Patient Reported Outcomes (PRO).

# Path to bootstrapping results of PRO predictions and prevalence of labels.
PRO_BOOTSTRAP_PATH: str = '/content/pro_preds_bootstrap.tsv'
PRO_PREVALENCE_PATH: str = '/content/pro_prevalence.tsv'

# Mapping from binary targets in PRO to their names shown in figure.
PRO_BINARY_TARGETS_TO_NAME: dict[str, str] = {
    'alertness_score_binary': 'Alert',
    'tiredness_score_binary': 'Tiredness',
    'satisfied_score_binary': 'Satisfied',
    'refreshing_score_binary': 'Refreshed',
    'sleepy_during_daytime_score_binary': 'Sleepy during daytime',
    'restless_score_binary': 'Very restless',
    'trouble_falling_asleep_score_binary': 'Trouble falling asleep',
    'enough_sleep_score_binary': 'Enough sleep',
    'trouble_staying_asleep_score_binary': 'Trouble staying asleep',
    'trouble_concentrating_score_sleepimpairment_binary': (
        'SI due to trouble concentrating'
    ),
    'trouble_sleeping_score_binary': 'Trouble sleeping',
    'irritability_score_sleepimpairment_binary': 'SI due to irritability',
    'trouble_productivity_score_binary': 'Trouble being productive',
    'problems_score_binary': 'Having problems',
    'quality_score_binary': 'Quality',
    'trouble_staying_awake_score_binary': 'Trouble staying awake',
}
# The figure names (not the raw target keys) of the binary targets, in the
# insertion order of the mapping above.
PRO_BINARY_TARGETS: list[str] = list(PRO_BINARY_TARGETS_TO_NAME.values())

# Mapping from model suffix to model name.
PRO_SUFFIX_TO_MODEL: dict[str, str] = {
    '_llm': 'PH-LLM w/ Adapter',
    '_llm_few_shot': 'PH-LLM Few-shot',
    '_llm_zero_shot': 'PH-LLM Zero-shot',
    '_logreg_2914_data': 'LogReg',
    '_cnn': 'CNN',
}

# Mapping from model name to color.
PRO_MODEL_TO_COLOR: dict[str, str] = {
    'PH-LLM w/ Adapter': COLOR_BLUE_DARK,
    'PH-LLM Few-shot': COLOR_BLUE_MID,
    'PH-LLM Zero-shot': COLOR_BLUE_LIGHT,
    'Prevalence': COLOR_GRAY,
    'LogReg': COLOR_ORANGE,
    'CNN': COLOR_GREEN,
}

# Mapping from metric to their configs for plotting.
# xlabel: The xlabel of subplot.
# non_significant_indices: Target indices w/o significantly different
#   performance.
# significant_loc: The location of its significant mark denoted by (y, h) where
#   y is the coordinate at which to start the annotation bracket, and h is the
#   height of the annotation bracket.
# legend_loc: The location of legend.
PRO_METRIC_TO_CONFIGS: dict[str, dict[str, Any]] = {
    'auc': {
        'xlabel': 'AUROC',
        'non_significant_indices': [12, 13],
        'significant_loc': [0.8, 0.02],
        'legend_loc': 'lower left',
    },
    'auprc': {
        'xlabel': 'AUPRC',
        'non_significant_indices': [12, 13, 15],
        'significant_loc': [0.3, 0.01],
        'legend_loc': 'lower center',
    },
}

# Column names in dataframes.
PRO_COL_MODEL: str = 'Model'
PRO_COL_TARGET: str = 'Target'

### Data loading utilities

In [None]:
def tsv_to_df(filepath: str) -> pd.DataFrame:
  """Reads a tab-separated values (TSV) file into a dataframe.

  Args:
    filepath: The path of the TSV file to read.

  Returns:
    A dataframe holding the parsed contents of the file.
  """
  with open(filepath, 'r') as tsv_file:
    return pd.read_csv(tsv_file, sep='\t')


def load_ratings_df(
    vertical: Vertical,
    rating_source: RatingsSource,
) -> pd.DataFrame:
  """Loads the ratings TSV for a vertical and rating source as a dataframe.

  The file is looked up in the vertical's base directory and is expected to be
  named `<vertical>_<rating_source>_ratings.tsv`.

  Args:
    vertical: The vertical whose ratings should be loaded.
    rating_source: The source of the ratings to load.

  Returns:
    The ratings dataframe parsed from the TSV file.
  """
  vertical_name = vertical.value.lower()
  ratings_filename = f'{vertical_name}_{rating_source.value}_ratings.tsv'
  ratings_path = os.path.join(_VERTICAL_TO_BASE_DIR[vertical], ratings_filename)
  return tsv_to_df(ratings_path)


def extract_pro_target_and_model(
    raw_name: str,
    prefixes: list[str] | None = None,
    suffixes: list[str] | None = None,
) -> tuple[str, str]:
  """Splits a raw PRO prediction name into a target name and a model name.

  A raw name is expected to be a target prefix immediately followed by a model
  suffix. All prefixes are tried in order, so that a prefix which matches but
  leaves an unrecognized suffix does not hide a later prefix that splits the
  name into a recognized (prefix, suffix) pair.

  Args:
    raw_name: The raw prediction name to split.
    prefixes: The candidate target prefixes. Defaults to the keys of
      `PRO_BINARY_TARGETS_TO_NAME`.
    suffixes: The recognized model suffixes. Defaults to the keys of
      `PRO_SUFFIX_TO_MODEL`.

  Returns:
    A `(target, model)` tuple. When a prefix and a recognized suffix are found,
    these are the figure name of the target and the model name. When a prefix
    matches but no recognized suffix follows any matching prefix, the first
    matching raw prefix and an empty model name are returned.

  Raises:
    ValueError: If `raw_name` does not start with any of the prefixes.
  """
  if prefixes is None:
    prefixes = list(PRO_BINARY_TARGETS_TO_NAME.keys())
  if suffixes is None:
    suffixes = list(PRO_SUFFIX_TO_MODEL.keys())

  first_matching_prefix = None
  for prefix in prefixes:
    if not raw_name.startswith(prefix):
      continue
    suffix = raw_name[len(prefix) :]
    if suffix in suffixes:
      return PRO_BINARY_TARGETS_TO_NAME[prefix], PRO_SUFFIX_TO_MODEL[suffix]
    # Remember the first match as the fallback, but keep looking in case a
    # later prefix yields a recognized suffix.
    if first_matching_prefix is None:
      first_matching_prefix = prefix

  if first_matching_prefix is not None:
    return first_matching_prefix, ''
  raise ValueError(f'{raw_name} does not start from target name.')


def load_pro_df(
    models_of_interest: list[str],
    prediction_path: str = PRO_BOOTSTRAP_PATH,
    prevalence_path: str = PRO_PREVALENCE_PATH,
) -> tuple[pd.DataFrame, pd.DataFrame]:
  """Loads PRO bootstrapping results and label prevalence.

  Args:
    models_of_interest: The model names to keep. Their order also defines the
      sort order of the models in the returned predictions dataframe.
    prediction_path: The path of the TSV file with bootstrapping results. The
      file must contain a `prediction_name` column.
    prevalence_path: The path of the TSV file with the prevalence of targets.

  Returns:
    A `(df_pro, df_pro_prevalence)` tuple. `df_pro` holds the bootstrapping
    results restricted to `models_of_interest`, with added target and model
    name columns, sorted by model and then by target. `df_pro_prevalence` holds
    the contents of `prevalence_path` as loaded.
  """
  # Load bootstrapping results and prevalence info.
  df_pro = tsv_to_df(prediction_path)
  df_pro_prevalence = tsv_to_df(prevalence_path)

  # Add readable model and target name. Building plain lists (rather than a
  # row-wise `apply` returning `pd.Series`) also works for an empty dataframe.
  targets_and_models = [
      extract_pro_target_and_model(name) for name in df_pro['prediction_name']
  ]
  df_pro[PRO_COL_TARGET] = [target for target, _ in targets_and_models]
  df_pro[PRO_COL_MODEL] = [model for _, model in targets_and_models]

  # Filter in results from models of interest. Copy so that the column
  # assignments below operate on an independent dataframe instead of a slice.
  df_pro = df_pro[df_pro[PRO_COL_MODEL].isin(models_of_interest)].copy()

  # Order rows by model then target (to align with barplot).
  target_category_type = pd.CategoricalDtype(
      categories=PRO_BINARY_TARGETS, ordered=True
  )
  model_category_type = pd.CategoricalDtype(
      categories=models_of_interest, ordered=True
  )
  df_pro[PRO_COL_TARGET] = df_pro[PRO_COL_TARGET].astype(target_category_type)
  df_pro[PRO_COL_MODEL] = df_pro[PRO_COL_MODEL].astype(model_category_type)
  df_pro = df_pro.sort_values(by=[PRO_COL_MODEL, PRO_COL_TARGET]).reset_index(
      drop=True
  )
  return df_pro, df_pro_prevalence

### Significance test utilities

In [None]:
def _is_pair_significant(
    grouped_df: pd.DataFrame,
    source_col: str,
    source_a: str,
    source_b: str,
    rating_column: str,
) -> pd.Series:
  result = grouped_df.apply(
      lambda x: (
          sum(x[source_col] == source_a) + sum(x[source_col] == source_b),
          scipy.stats.ranksums(
              x[x[source_col] == source_a][rating_column],
              x[x[source_col] == source_b][rating_column],
          ),
      ),
      include_groups=False,
  )
  return result


def _split_pair_test(pair_test: pd.Series) -> tuple[
    pd.Series,
    pd.Series,
    pd.Series,
]:
  result = (
      pair_test.apply(lambda x: x[0]),
      pair_test.apply(lambda x: x[1].statistic),
      pair_test.apply(lambda x: x[1].pvalue),
  )
  return result


def _unroll_pair_tests(
    pair_tests: dict[tuple[str, str], pd.Series],
) -> pd.DataFrame:
  """Flattens per-pair test results into a single wide dataframe.

  For each `(source_a, source_b)` pair, the p-values are FDR-corrected across
  that pair's groups, and five columns are produced: `n`, `statistic`,
  `p-value` (corrected), `significance` (corrected p-value < 0.05), and
  `effect size` (statistic divided by the square root of `n`).

  Args:
    pair_tests: A mapping of `(source_a, source_b)` to a series of
      `(n, test_result)` tuples indexed by group.

  Returns:
    A dataframe indexed by group with five columns per pair, each named
    `<source_a> vs <source_b> <quantity>`.
  """
  columns = []
  column_names = []
  for (source_a, source_b), pair_test in pair_tests.items():
    num_ratings, statistic, raw_pvalue = _split_pair_test(pair_test)
    adjusted = statsmodels.stats.multitest.fdrcorrection(raw_pvalue.values)[1]
    corrected_pvalue = pd.Series(adjusted, index=raw_pvalue.index)
    label = f'{source_a} vs {source_b}'
    named_columns = (
        (f'{label} n', num_ratings),
        (f'{label} statistic', statistic),
        (f'{label} p-value', corrected_pvalue),
        (f'{label} significance', corrected_pvalue < 0.05),
        (f'{label} effect size', statistic / np.sqrt(num_ratings)),
    )
    for column_name, column in named_columns:
      column_names.append(column_name)
      columns.append(column)
  unrolled = pd.concat(columns, axis=1)
  unrolled.columns = column_names
  return unrolled


def significance_test(
    df,
    groupby_column,
    rating_column: str = 'rating',
) -> pd.DataFrame:
  """Performs significance testing for each group using Wilcoxon rank-sum test.

  Every ordered pair of distinct conversation sources found in `df` is tested
  within each group, i.e., both `(a, b)` and `(b, a)` are evaluated. The
  p-values are adjusted with FDR correction for multiple hypothesis testing.

  Args:
    df: The dataframe to perform significance testing on. It must contain the
      conversation source column.
    groupby_column: The column to group by.
    rating_column: The column to use for ratings.

  Returns:
    A dataframe with statistic, p-value, and significance result for each group
      and test.
  """
  groups = df.groupby(groupby_column)
  sources = sorted(df[COL_CONVERSATION_SOURCE].unique())

  # Test each unordered pair of sources in both directions, keeping the two
  # directions of a pair adjacent in the result.
  pair_tests: dict[tuple[str, str], pd.Series] = {}
  for first, second in itertools.combinations(sources, 2):
    for source_a, source_b in ((first, second), (second, first)):
      pair_tests[(source_a, source_b)] = _is_pair_significant(
          grouped_df=groups,
          source_col=COL_CONVERSATION_SOURCE,
          source_a=source_a,
          source_b=source_b,
          rating_column=rating_column,
      )

  # Concatenate the results together into a dataframe.
  return _unroll_pair_tests(pair_tests)


def get_sig_sections(sig_df: pd.DataFrame) -> dict[str, list[tuple[str, str]]]:
  """Returns a dictionary mapping sections to their statsig diff pairs.

  Only the (Gemini Ultra, PH-LLM) and (human expert, PH-LLM) pairs are
  considered, and a pair is skipped when its significance column is absent.

  Args:
    sig_df: A dataframe indexed by section with `<a> vs <b> significance`
      columns.

  Returns:
    A mapping of section to the considered pairs that are significantly
    different in that section. Sections without such pairs are not added.
  """
  candidate_pairs = (
      (KEY_GEMINI_ULTRA, KEY_PHLLM),
      (KEY_HUMAN_EXPERT, KEY_PHLLM),
  )
  sig_pairs = collections.defaultdict(list)
  for section in sig_df.index:
    section_rows = sig_df[sig_df.index == section]
    for pair in candidate_pairs:
      column = f'{pair[0]} vs {pair[1]} significance'
      if column in sig_df.columns and section_rows[column].values[0]:
        sig_pairs[section].append(pair)
  return sig_pairs


def get_non_sig_sections(
    sig_df: pd.DataFrame,
) -> dict[str, list[tuple[str, str]]]:
  """Returns a dictionary mapping sections to their non statsig diff pairs.

  All ordered pairs of distinct known conversation sources are considered, and
  a pair is skipped when its significance column is absent.

  Args:
    sig_df: A dataframe indexed by section with `<a> vs <b> significance`
      columns.

  Returns:
    A mapping of every section in `sig_df` to the pairs that are not
    significantly different in that section (possibly an empty list).
  """
  # Keep only the ordered source pairs that have a significance column.
  available_pairs = []
  for source_a, source_b in itertools.permutations(
      CONVERSATION_SOURCE_KEY_TO_LABEL, 2
  ):
    column = f'{source_a} vs {source_b} significance'
    if column in sig_df.columns:
      available_pairs.append(((source_a, source_b), column))

  non_sig_pairs = collections.defaultdict(list)
  for section in sig_df.index:
    section_rows = sig_df[sig_df.index == section]
    non_sig_pairs[section] = [
        pair
        for pair, column in available_pairs
        if not section_rows[column].values[0]
    ]
  return non_sig_pairs

### Bootstrapping utilities

In [None]:
# A function that computes a numeric outcome from label and prediction arrays.
# The first argument is the label array and the second the prediction array.
BootstrappableFn = Callable[[np.ndarray, np.ndarray], float]

# Constants denoting the expected case and control values for binary encodings.
BINARY_LABEL_CONTROL: int = 0
BINARY_LABEL_CASE: int = 1

# The maximum number of threads/workers.
_MAX_PARALLEL_WORKERS: int = 10

# Represents a numpy array of indices for a single bootstrap sample.
IndexSample = np.ndarray


class Metric(abc.ABC):
  """An abstract, callable wrapper around a named metric function.

  Calling an instance validates the inputs via `_validate` and then delegates
  to the wrapped function. Subclasses must implement `_validate`.

  Attributes:
    name: The metric's name.
  """

  def __init__(self, name: str, fn: BootstrappableFn) -> None:
    """Initializes the metric.

    Args:
      name: The metric's name.
      fn: The function invoked when the `Metric` instance is called. It is
        passed a `y_true` label array and a `y_pred` model prediction array,
        in that order, and computes an outcome from them.
    """
    self._name: str = name
    self._fn: BootstrappableFn = fn

  @property
  def name(self) -> str:
    """The `Metric`'s name."""
    return self._name

  @abc.abstractmethod
  def _validate(self, y_true: np.ndarray, y_pred: np.ndarray) -> None:
    """Checks that `y_true` and `y_pred` agree on their first dimension.

    Note: Each prediction subarray `y_pred[i, ...]` at index `i` should
    correspond to the `y_true[i]` label.

    Args:
      y_true: The ground truth label targets.
      y_pred: The target predictions.

    Raises:
      ValueError: If the first dimension of `y_true` and `y_pred` do not match.
    """
    if y_true.shape[0] == y_pred.shape[0]:
      return
    raise ValueError(
        '`y_true` and `y_pred` first dimension mismatch: '
        f'{y_true.shape[0]} != {y_pred.shape[0]}'
    )

  def __call__(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Validates the inputs and then invokes the `Metric`'s function.

    Args:
      y_true: The ground truth label values.
      y_pred: The target predictions.

    Returns:
      The result of calling the wrapped function with `(y_true, y_pred)`.
    """
    self._validate(y_true, y_pred)
    metric_fn = self._fn
    return metric_fn(y_true, y_pred)

  def __str__(self) -> str:
    return self.name


class ContinuousMetric(Metric):
  """A callable wrapper around a named metric function for continuous labels.

  Attributes:
    name: The metric's name.
  """

  # Note: This is a useful delegation since _validate is an @abc.abstractmethod.
  def _validate(  # pylint: disable=useless-super-delegation
      self,
      y_true: np.ndarray,
      y_pred: np.ndarray,
  ) -> None:
    """Applies the base class validation to `y_true` and `y_pred`.

    Args:
      y_true: The ground truth label values.
      y_pred: The target predictions.

    Raises:
      ValueError: If the first dimension of `y_true` and `y_pred` do not match.
    """
    # No checks beyond the shared first-dimension check are performed.
    return super()._validate(y_true, y_pred)


class BinaryMetric(Metric):
  """A callable wrapper around a named metric function for binary labels.

  In addition to the base class validation, this class requires that the
  provided `y_true` labels are binary targets in `{0, 1}` and that `y_true`
  contains at least one element in each class, i.e., not all samples are from
  the same class.

  Attributes:
    name: The metric's name.
  """

  def _validate(self, y_true: np.ndarray, y_pred: np.ndarray) -> None:
    """Validates the `y_true` labels and `y_pred` predictions.

    Args:
      y_true: The ground truth label values.
      y_pred: The target predictions.

    Raises:
      ValueError: If the first dimension of `y_true` and `y_pred` do not match.
      ValueError: If `y_true` labels are nonbinary, i.e., not all values are in
        `{BINARY_LABEL_CONTROL, BINARY_LABEL_CASE}` or if `y_true` does not
        contain at least one element from each class.
    """
    super()._validate(y_true, y_pred)
    if is_valid_binary_label(y_true):
      return
    raise ValueError(
        '`y_true` labels must be in `{BINARY_LABEL_CONTROL, '
        'BINARY_LABEL_CASE}` and have at least one element from '
        f'each class; found: {y_true}'
    )


def is_binary(metric: Metric) -> bool:
  """Whether `metric` is a metric computed with binary `y_true` labels.

  Args:
    metric: The metric to check.

  Returns:
    `True` if `metric` is an instance of `BinaryMetric`; otherwise, `False`.
  """
  return isinstance(metric, BinaryMetric)


def is_valid_binary_label(array: np.ndarray) -> bool:
  """Whether `array` is a "valid" binary label array for bootstrapping.

  A valid binary label array holds only the values `BINARY_LABEL_CONTROL` and
  `BINARY_LABEL_CASE`, and holds each of the two values at least once.

  Args:
    array: A numpy array.

  Returns:
    Whether `array` is a "valid" binary label array.
  """
  case_mask = array == BINARY_LABEL_CASE
  control_mask = array == BINARY_LABEL_CONTROL
  has_case = np.any(case_mask)
  has_control = np.any(control_mask)
  only_binary_values = np.all(np.logical_or(case_mask, control_mask))
  return has_case and has_control and only_binary_values


@dataclasses.dataclass(eq=False, order=False, frozen=True)
class NamedArray:
  """A numpy array paired with a non-empty name.

  Attributes:
    name: The array name.
    values: A numpy array.

  Raises:
    ValueError: If `name` is empty.
  """

  name: str
  values: np.ndarray

  def __post_init__(self):
    if self.name:
      return
    raise ValueError('`name` must be specified.')

  def __len__(self) -> int:
    # The length of a named array is the length of the wrapped array.
    return len(self.values)

  def __str__(self) -> str:
    class_name = type(self).__name__
    return f'{class_name}({self.name})'


@dataclasses.dataclass(eq=False, order=False, frozen=True)
class Label(NamedArray):
  """Represents a named numpy array of ground truth label targets.

  This class adds no fields or behavior to `NamedArray`; it only gives label
  arrays a distinct type.

  Attributes:
    name: The label name.
    values: A numpy array containing ground truth label targets.
  """


@dataclasses.dataclass(eq=False, order=False, frozen=True)
class Prediction(NamedArray):
  """A named numpy array of target predictions made by a named model.

  Attributes:
    model_name: The name of the model that generated the predictions.
    name: The name of the predictions (e.g., the prediction column).
    values: A numpy array containing model predictions.

  Raises:
    ValueError: If `name` or `model_name` is empty.
  """

  model_name: str

  def __post_init__(self):
    # Validate the inherited fields first, then the model name.
    super().__post_init__()
    if self.model_name:
      return
    raise ValueError('`model_name` must be specified.')

  def __str__(self) -> str:
    class_name = type(self).__name__
    return f'{class_name}({self.model_name}.{self.name})'


@dataclasses.dataclass(eq=False, order=False, frozen=True)
class SampleMean:
  """An estimate of the population mean computed from a sample.

  Attributes:
    mean: The mean of a given sample.
    stddev: The standard deviation of the sample mean.
    num_samples: The number of samples used to calculate `mean` and `stddev`.

  Raises:
    ValueError: If `num_samples` is not >= `1`.
    ValueError: If `stddev` is not `0` when `num_samples` is `1`.
  """

  mean: float
  stddev: float
  num_samples: int

  def __post_init__(self):
    # At least one sample is required.
    if self.num_samples < 1:
      raise ValueError(f'`num_samples` must be >= `1`: {self.num_samples}')

    # A single sample cannot have any spread.
    is_single_sample = self.num_samples == 1
    if is_single_sample and self.stddev != 0.0:
      raise ValueError(
          f'`stddev` must be `0` if `num_samples` is `1`: {self.stddev:0.4f}'
      )

  def __str__(self) -> str:
    details = f'SD={self.stddev:0.4f}, n={self.num_samples}'
    return f'{self.mean:0.4f} ({details})'


@dataclasses.dataclass(eq=False, order=False, frozen=True)
class ConfidenceInterval(SampleMean):
  """A sample mean together with its confidence interval (CI).

  Attributes:
    mean: The mean of a given sample.
    stddev: The standard deviation of the sample mean.
    num_samples: The number of samples used to calculate `mean` and `stddev`.
    level: The confidence level at which the CI is calculated (e.g., 95).
    ci_lower: The lower limit of the `level` confidence interval.
    ci_upper: The upper limit of the `level` confidence interval.

  Raises:
    ValueError: If `num_samples` is not >= `1`.
    ValueError: If `stddev` is not `0` when `num_samples` is `1`.
    ValueError: If `level` is not in range (0, 100].
    ValueError: If `ci_lower` or `ci_upper` does not match `mean` when
      `num_samples` is `1`.
  """

  level: float
  ci_lower: float
  ci_upper: float

  def __post_init__(self):
    # Run the sample mean checks before the interval-specific ones.
    super().__post_init__()

    # The confidence level is a percentage in (0, 100].
    if not 0 < self.level <= 100:
      raise ValueError(f'`level` must be in range (0, 100]: {self.level:0.2f}')

    # With a single sample the interval must collapse onto the mean.
    if self.num_samples == 1 and (
        (self.ci_lower != self.mean) or (self.ci_upper != self.mean)
    ):
      raise ValueError(
          '`ci_lower` and `ci_upper` must match `mean` if `num_samples` is '
          f'1: mean={self.mean:0.4f}, ci_lower={self.ci_lower:0.4f}, '
          f'ci_upper={self.ci_upper:0.4f}'
      )

  def __str__(self) -> str:
    summary = f'{self.mean:0.4f} (SD={self.stddev:0.4f}, n={self.num_samples}, '
    interval = (
        f'{self.level:0>6.2f}% CI=[{self.ci_lower:0.4f}, '
        f'{self.ci_upper:0.4f}])'
    )
    return summary + interval


@dataclasses.dataclass(eq=False, order=False, frozen=True)
class Result:
  """A bootstrapped metric result for an individual model.

  Attributes:
    model_name: The model's name.
    prediction_name: The model's prediction name (e.g., the model head's name or
      the label name used in training).
    metric_name: The metric's name.
    ci: A confidence interval describing the distribution of metric samples.

  Raises:
    ValueError: If `model_name`, `prediction_name`, or `metric_name` is empty.
  """

  model_name: str
  prediction_name: str
  metric_name: str
  ci: ConfidenceInterval

  def __post_init__(self):
    # Ensure model, prediction, and metric names are specified (in that order).
    for field_name in ('model_name', 'prediction_name', 'metric_name'):
      if not getattr(self, field_name):
        raise ValueError(f'`{field_name}` must be specified.')

  def __str__(self) -> str:
    source = f'{self.model_name}.{self.prediction_name}'
    return f'{source}: {self.metric_name}: {self.ci}'


@dataclasses.dataclass(eq=False, order=False, frozen=True)
class PairedResult:
  """A paired bootstrapped metric result for two models.

  Attributes:
    model_name_a: The first model's name.
    prediction_name_a: The first model's prediction name (e.g., the model head's
      name or the label name used in training).
    model_name_b: The second model's name.
    prediction_name_b: The second model's prediction name (e.g., the model
      head's name or the label name used in training).
    metric_name: The metric's name.
    ci: A confidence interval describing the distribution of differences between
      the first and second models' metric samples.

  Raises:
    ValueError: If any of the model, prediction, or metric names is empty.
  """

  model_name_a: str
  prediction_name_a: str
  model_name_b: str
  prediction_name_b: str
  metric_name: str
  ci: ConfidenceInterval

  def __post_init__(self):
    # Ensure model, prediction, and metric names are specified; the first
    # missing name (in declaration order) is the one reported.
    required_fields = (
        'model_name_a',
        'prediction_name_a',
        'model_name_b',
        'prediction_name_b',
        'metric_name',
    )
    for field_name in required_fields:
      if not getattr(self, field_name):
        raise ValueError(f'`{field_name}` must be specified.')

  def __str__(self) -> str:
    first = f'{self.model_name_a}.{self.prediction_name_a}'
    second = f'{self.model_name_b}.{self.prediction_name_b}'
    return f'({first} - {second}): {self.metric_name}: {self.ci}'


def _compute_confidence_interval(
    samples: np.ndarray,
    ci_level: float,
) -> ConfidenceInterval:
  """Summarizes samples as a mean, standard deviation, and percentile interval.

  Args:
    samples: A bootstrapped array of observed sample values.
    ci_level: The confidence level of the desired interval, as a percentage
      (e.g., 95).

  Returns:
    A `ConfidenceInterval` containing the mean, standard deviation, sample
    count, and the two-sided `ci_level`% percentile interval of `samples`.
  """
  # Split the excluded probability mass evenly between the two tails.
  tail = (100 - ci_level) / 2
  lower_bound, upper_bound = np.percentile(
      a=samples, q=[tail, 100 - tail], axis=0
  )
  return ConfidenceInterval(
      mean=np.mean(samples, axis=0),
      stddev=np.std(samples, axis=0),
      num_samples=len(samples),
      level=ci_level,
      ci_lower=lower_bound,
      ci_upper=upper_bound,
  )


def _generate_sample_indices(
    label: Label,
    is_binary: bool,
    num_bootstrap: int,
    seed: int,
) -> List[IndexSample]:
  """Draws `num_bootstrap` index arrays by resampling with replacement.

  Args:
    label: The ground truth label targets.
    is_binary: Whether each index sample must select labels that pass
      `is_valid_binary_label`; samples that do not are redrawn.
    num_bootstrap: The number of bootstrap indices to generate.
    seed: The seed of the random generator used to draw the indices.

  Returns:
    A list of `num_bootstrap` index arrays, each as long as `label`.
  """
  generator = np.random.default_rng(seed)
  n = len(label)
  accepted = []
  # NOTE(review): when `is_binary` is set and no resample can pass
  # `is_valid_binary_label`, this loop never terminates; confirm callers
  # validate the labels first.
  while len(accepted) < num_bootstrap:
    candidate = generator.integers(0, high=n, size=n)
    resampled_labels = label.values[candidate]
    # Rejected draws do not count towards `num_bootstrap`.
    if not is_binary or is_valid_binary_label(resampled_labels):
      accepted.append(candidate)
  return accepted


def _compute_metric_samples(
    metric: Metric,
    label: Label,
    predictions: Sequence[Prediction],
    sample_indices: Sequence[np.ndarray],
) -> Dict[str, np.ndarray]:
  """Evaluates `metric` once per bootstrap sample for each `Prediction`.

  Note: Label and prediction values are assumed to be ordered so that the value
  at index `i` in a given `Prediction` corresponds to the label value at index
  `i` in `label`; both arrays are indexed with the same `sample_indices`.

  Args:
    metric: An instance of a bootstrappable `Metric`; used to compute samples.
    label: The ground truth label targets.
    predictions: A list of target predictions from a set of models.
    sample_indices: The bootstrap sample indices. If empty, the metric is
      computed once on the entire dataset for each prediction.

  Returns:
    A mapping of model names to the corresponding metric samples array.
  """
  if not sample_indices:
    # No resampling requested: a single full-dataset value per model.
    return {
        prediction.model_name: np.asarray(
            [metric(label.values, prediction.values)]
        )
        for prediction in predictions
    }

  collected = {prediction.model_name: [] for prediction in predictions}
  for indices in sample_indices:
    resampled_truth = label.values[indices]
    for prediction in predictions:
      collected[prediction.model_name].append(
          metric(resampled_truth, prediction.values[indices])
      )
  return {name: np.asarray(values) for name, values in collected.items()}


def _compute_all_metric_samples(
    metrics: Sequence[Metric],
    contains_binary_metric: bool,
    label: Label,
    predictions: Sequence[Prediction],
    num_bootstrap: int,
    seed: int,
) -> Dict[str, Dict[str, np.ndarray]]:
  """Generates `num_bootstrap` samples for each `Prediction` and `Metric`.

  The same bootstrap indices are shared by every metric, and each metric is
  evaluated in its own thread pool task.

  Args:
    metrics: A sequence of a bootstrappable `Metric` instances.
    contains_binary_metric: Whether the set of metrics contains a binary metric.
    label: The ground truth label targets.
    predictions: A list of target predictions from a set of models.
    num_bootstrap: The number of bootstrap iterations.
    seed: The random seed; set prior to generating bootstrap indices.

  Returns:
    A mapping of metric names to model-sample dictionaries; empty if `metrics`
    is empty.
  """
  # `ThreadPoolExecutor` raises a `ValueError` for `max_workers=0`, so there is
  # nothing to schedule (and nothing to return) without metrics.
  if len(metrics) == 0:
    return {}

  sample_indices = _generate_sample_indices(
      label,
      contains_binary_metric,
      num_bootstrap,
      seed,
  )
  num_workers = min(_MAX_PARALLEL_WORKERS, len(metrics))
  with concurrent.futures.ThreadPoolExecutor(
      max_workers=num_workers
  ) as executor:
    futures = [
        executor.submit(
            _compute_metric_samples,
            metric=metric,
            label=label,
            predictions=predictions,
            sample_indices=sample_indices,
        )
        for metric in metrics
    ]
    # Collect in submission order so results line up with `metrics`.
    per_metric_samples = [future.result() for future in futures]

  return {
      metric.name: samples
      for metric, samples in zip(metrics, per_metric_samples)
  }


def _process_metric_samples(
    metric: Metric,
    predictions: Sequence[Prediction],
    model_names_to_metric_samples: Dict[str, np.ndarray],
    ci_level: float,
) -> List[Result]:
  """Builds one `Result` per prediction from its metric samples.

  Args:
    metric: The metric that produced the samples; only its name is used.
    predictions: The predictions to summarize, in output order.
    model_names_to_metric_samples: The metric samples keyed on model name.
    ci_level: The confidence level/width of the desired confidence interval.

  Returns:
    A list of `Result`s, one per prediction, each holding the confidence
    interval computed over that prediction's model samples.
  """
  return [
      Result(
          prediction.model_name,
          prediction.name,
          metric.name,
          _compute_confidence_interval(
              model_names_to_metric_samples[prediction.model_name], ci_level
          ),
      )
      for prediction in predictions
  ]


def _bootstrap(
    metrics: Sequence[Metric],
    contains_binary_metric: bool,
    label: Label,
    predictions: Sequence[Prediction],
    num_bootstrap: int,
    ci_level: float,
    seed: int,
) -> Dict[str, List[Result]]:
  """Performs bootstrapping for all models using the given metrics.

  Args:
    metrics: A sequence of a bootstrappable `Metric` instances.
    contains_binary_metric: Whether the set of metrics contains a binary metric.
    label: The ground truth label targets.
    predictions: A list of target predictions from a set of models.
    num_bootstrap: The number of bootstrap iterations.
    ci_level: The confidence level/width of the desired confidence interval.
    seed: The random seed; set prior to generating bootstrap indices.

  Returns:
    A dictionary mapping metric names to a list of `Result`s, one per
    prediction, each summarizing that model's metric samples. Empty if
    `metrics` is empty.
  """
  # `ThreadPoolExecutor` raises a `ValueError` for `max_workers=0`, so an empty
  # metric set is answered directly instead of being scheduled.
  if len(metrics) == 0:
    return {}

  metric_to_model_to_samples = _compute_all_metric_samples(
      metrics,
      contains_binary_metric,
      label,
      predictions,
      num_bootstrap,
      seed,
  )

  num_workers = min(_MAX_PARALLEL_WORKERS, len(metrics))
  with concurrent.futures.ThreadPoolExecutor(
      max_workers=num_workers
  ) as executor:
    futures = [
        executor.submit(
            _process_metric_samples,
            metric=metric,
            predictions=predictions,
            model_names_to_metric_samples=metric_to_model_to_samples[
                metric.name
            ],
            ci_level=ci_level,
        )
        for metric in metrics
    ]
    # Collect in submission order so results line up with `metrics`.
    per_metric_results = [future.result() for future in futures]

  return {
      metric.name: results
      for metric, results in zip(metrics, per_metric_results)
  }


def validate_and_mask(
    label: Label,
    predictions: Sequence[Prediction],
    mask: Optional[np.ndarray] = None,
) -> tuple[Label, list[Prediction]]:
  """Validates bootstrap argument shape and applies the mask if needed.

  Args:
    label: The ground truth label targets.
    predictions: The predictions; each must be as long as `label`.
    mask: An optional mask, as long as `label`, used to index the label and
      prediction values.

  Returns:
    The label and predictions, unchanged when `mask` is `None`; otherwise new
    `Label` and `Prediction` objects holding only the masked values and keeping
    their original names.

  Raises:
    ValueError: If a prediction's or the mask's length differs from the
      label's.
  """
  if any(len(label) != len(prediction) for prediction in predictions):
    raise ValueError('Label and prediction dimensions do not match.')
  if mask is None:
    return label, predictions  # pytype: disable=bad-return-type
  if len(mask) != len(label):
    raise ValueError('Label and prediction dimensions do not match mask.')

  masked_label = Label(label.name, label.values[mask])
  masked_predictions = [
      # Keep each prediction's own name: masking must not relabel predictions
      # with the label's name.
      Prediction(name=p.name, values=p.values[mask], model_name=p.model_name)
      for p in predictions
  ]
  return masked_label, masked_predictions


class PerformanceMetricsParallel:
  """A named collection of invocable, bootstrapable `Metric`s.

  Applies the given `Metric` functions to ground truth labels and predictions,
  with or without bootstrapping.

  Raises:
    ValueError: if no metrics are provided, if an item in `metrics` is not of
      type `Metric`, or if metric names are not unique.
  """

  def __init__(
      self,
      name: str,
      metrics: Optional[list[Metric]] = None,
  ) -> None:
    if metrics is None:
      raise ValueError('No metric is provided.')

    if not all(isinstance(metric, Metric) for metric in metrics):
      raise ValueError('Invalid metric value: must be of class `Metric`.')

    # Results are keyed on metric name, so names must not collide.
    unique_names = {metric.name for metric in metrics}
    if len(unique_names) != len(metrics):
      raise ValueError(f'Metric names must be unique: {metrics}')

    self.name = name
    self.metrics = metrics
    self.contains_binary = any(is_binary(m) for m in metrics)

  def compute(
      self,
      label: Label,
      predictions: Sequence[Prediction],
      mask: Optional[np.ndarray] = None,
      n_bootstrap: int = 0,
      conf_interval: float = 95,
      seed: int = 42,
  ) -> dict[str, list[Result]]:
    """Evaluates all metrics using the given labels and predictions.

    Args:
      label: The ground truth label targets.
      predictions: A list of target predictions from a set of models.
      mask: An optional mask; applied to `label` and every prediction.
      n_bootstrap: An integer denoting the number of bootstrap iterations for
        each evaluation metric.
      conf_interval: A float denoting the width of confidence interval.
      seed: An int denoting the seed for the PRNG.

    Returns:
      A dictionary keyed on metric name whose values are lists of `Result`s.

    Raises:
      ValueError: If the dimensions of `label`, a prediction, or `mask` do not
        match.
    """
    masked_label, masked_predictions = validate_and_mask(
        label, predictions, mask
    )
    return _bootstrap(
        self.metrics,
        contains_binary_metric=self.contains_binary,
        label=masked_label,
        predictions=masked_predictions,
        num_bootstrap=n_bootstrap,
        ci_level=conf_interval,
        seed=seed,
    )


def result_to_record(result: Result) -> dict[str, Any]:
  """Flattens a `Result` and its confidence interval into a single dict."""
  interval = result.ci
  identifiers = {
      'model_name': result.model_name,
      'prediction_name': result.prediction_name,
      'metric_name': result.metric_name,
  }
  statistics = {
      'mean': interval.mean,
      'stddev': interval.stddev,
      'num_samples': interval.num_samples,
      'ci_level': interval.level,
      'ci_lower': interval.ci_lower,
      'ci_upper': interval.ci_upper,
  }
  return {**identifiers, **statistics}


def results_map_to_df(
    metrics_to_results: Mapping[str, Sequence[Result]],
) -> pd.DataFrame:
  """Flattens every `Result` in the mapping into one dataframe row each."""
  records = [
      result_to_record(result)
      for metric_results in metrics_to_results.values()
      for result in metric_results
  ]
  return pd.DataFrame.from_records(records)


def build_bootstrap_inputs(
    df: pd.DataFrame,
    label_column: str,
    prediction_columns: list[str],
    id_col: str,
) -> tuple[Label, list[Prediction]]:
  """Builds a bootstrapping `Label` and `Prediction` list from dataframe columns.

  Args:
    df: The dataframe holding the ID, label, and prediction columns.
    label_column: The column holding the label values; also used as each
      prediction's name.
    prediction_columns: The columns holding prediction values; each column name
      becomes the prediction's model name.
    id_col: The ID column; only checked for presence and uniqueness.

  Returns:
    The label and one prediction per entry in `prediction_columns`.

  Raises:
    ValueError: If the label or ID column is also a prediction column, if
      prediction columns repeat, or if an expected column is missing.
  """
  label_is_prediction = label_column in prediction_columns
  id_is_prediction = id_col in prediction_columns
  has_repeats = len(prediction_columns) != len(set(prediction_columns))
  if label_is_prediction or id_is_prediction or has_repeats:
    raise ValueError('Label and ID columns must be unique.')

  expected_columns = {id_col, label_column, *prediction_columns}
  column_diff = expected_columns - set(df.columns)
  if column_diff:
    raise ValueError(f'Missing expected dataframe columns: {column_diff}.')

  label = Label(label_column, df[label_column].to_numpy())
  preds = [
      Prediction(label_column, df[pred_col].to_numpy(), pred_col)
      for pred_col in prediction_columns
  ]
  return label, preds


def compute(
    metrics: list[Metric],
    label: Label,
    preds: Sequence[Prediction],
    n_bootstrap: int,
) -> pd.DataFrame:
  """Bootstraps `preds` against `label` and returns one row per `Result`."""
  evaluator = PerformanceMetricsParallel(name='', metrics=metrics)
  return results_map_to_df(
      evaluator.compute(label, preds, n_bootstrap=n_bootstrap)
  )


def bootstrap_rater_pairs_binary(
    df: pd.DataFrame,
    n_bootstrap: int = 1000,
) -> pd.DataFrame:
  """Bootstraps `BS_METRICS_BINARY` for every pair of raters with shared items.

  Ratings are keyed on everything except the rater and pivoted so that each
  rater becomes a column. For each rater pair, rows where either rating is
  missing are dropped, the first rater's ratings are used as the label, and
  the second rater's ratings are used as the prediction.

  Args:
    df: The ratings dataframe; not modified.
    n_bootstrap: The number of bootstrap iterations per rater pair.

  Returns:
    The concatenated bootstrap result rows of all rater pairs.
  """
  keyed_df = df.copy()
  keyed_df[DERIVED_COL_KEY_NO_RATER] = keyed_df.apply(
      _key_no_rater_column, axis=1
  )
  ratings_by_rater = keyed_df.pivot_table(
      index=DERIVED_COL_KEY_NO_RATER,
      columns=COL_RATER,
      values=COL_RATING,
  ).reset_index()

  bs_dfs = []
  # The pairs come back as a set, whose iteration order can differ between
  # interpreter runs; sort so that the output row order is reproducible.
  for rater_a, rater_b in sorted(get_rater_pairs(keyed_df)):
    pair_df = (
        ratings_by_rater[[DERIVED_COL_KEY_NO_RATER, rater_a, rater_b]]
        .dropna()
        .reset_index(drop=True)
    )
    label, preds = build_bootstrap_inputs(
        df=pair_df,
        label_column=rater_a,
        prediction_columns=[rater_b],
        id_col=DERIVED_COL_KEY_NO_RATER,
    )
    bs_dfs.append(
        compute(
            metrics=BS_METRICS_BINARY,
            label=label,
            preds=preds,
            n_bootstrap=n_bootstrap,
        )
    )
  return pd.concat(bs_dfs).reset_index(drop=True)

In [None]:
def gwet_ac2(ratings_a: np.ndarray, ratings_b: np.ndarray) -> float:
  """Calculates Gwet's agreement coefficient using quadratic weights.

  Args:
    ratings_a: A numpy array of ratings from rater A.
    ratings_b: A numpy array of ratings from rater B.

  Returns:
    The coefficient value of the Gwet estimate computed by `irrCAC`.
  """
  paired_ratings = pd.DataFrame({'rater A': ratings_a, 'rater B': ratings_b})
  gwet_estimate = irrCAC.raw.CAC(paired_ratings, weights='quadratic').gwet()
  return gwet_estimate['est']['coefficient_value']


def sample_count(ratings_a: np.ndarray, ratings_b: np.ndarray) -> float:
  """Returns the number of ratings in `ratings_a`; `ratings_b` is ignored."""
  del ratings_b  # Unused.
  num_ratings = len(ratings_a)
  return num_ratings


# Metrics bootstrapped for each rater pair: the agreement coefficient and the
# number of ratings it was computed from.
BS_METRICS_BINARY: list[Metric] = [
    ContinuousMetric('gwet_ac2', gwet_ac2),
    ContinuousMetric('sample_count', sample_count),
]

### Plotting utilities

In [None]:
def _get_label_position(
    labels: list[text.Text],
    target_label: str,
) -> int:
  """Returns the index of the first label whose text equals the target label.

  Raises:
    ValueError: If no label has the target text.
  """
  label_texts = [label.get_text() for label in labels]
  if target_label not in label_texts:
    raise ValueError(f'Label {target_label} not found in labels: {labels}')
  return label_texts.index(target_label)


def get_label_position_x(
    target_label: str,
    ax: plt.Axes,
) -> int:
  """Returns the index of the x tick label with the given text."""
  x_tick_labels = ax.xaxis.get_ticklabels()
  return _get_label_position(x_tick_labels, target_label)


def get_label_position_y(
    target_label: str,
    ax: plt.Axes,
) -> int:
  """Returns the index of the y tick label with the given text."""
  y_tick_labels = ax.yaxis.get_ticklabels()
  return _get_label_position(y_tick_labels, target_label)


def get_bar_x_coord(
    label: str,
    label_position: int,
    ax: plt.Axes,
) -> int:
  """Returns the horizontal midpoint of one bar in a grouped bar plot.

  Args:
    label: The legend text identifying the bar container.
    label_position: The index of the bar within that container.
    ax: The axis holding the bars and the legend.

  Returns:
    The x coordinate of the bar's left edge plus half of its width.

  Raises:
    ValueError: If no legend entry has the given text.
  """
  legend_labels = [entry.get_text() for entry in ax.get_legend().get_texts()]
  bar_groups = [
      group for group in ax.containers
      if isinstance(group, container.BarContainer)
  ]
  # Legend entries and bar containers are paired up by position.
  assert len(legend_labels) == len(bar_groups)
  if label not in legend_labels:
    raise ValueError(f'Label {label} not found in axis: {ax}')
  patch = bar_groups[legend_labels.index(label)].patches[label_position]
  return patch.get_x() + (patch.get_width() / 2)


def get_bar_mid_x_coords(
    pair: tuple[str, str],
    label_position: int,
    ax: plt.Axes,
) -> tuple[int, int]:
  """Returns the sorted horizontal midpoints of the two bars named in `pair`.

  Args:
    pair: The legend texts identifying the two bar containers.
    label_position: The index of the bar within each container.
    ax: The axis holding the bars and the legend.

  Returns:
    The two bars' x midpoints in ascending order.
  """
  legend_labels = [entry.get_text() for entry in ax.get_legend().get_texts()]
  bar_groups = [
      group for group in ax.containers
      if isinstance(group, container.BarContainer)
  ]
  # Legend entries and bar containers are paired up by position.
  assert len(legend_labels) == len(bar_groups)
  patches = [
      group.patches[label_position]
      for legend_label, group in zip(legend_labels, bar_groups)
      if legend_label in pair
  ]
  x_coords = [patch.get_x() + (patch.get_width() / 2) for patch in patches]
  assert len(x_coords) == 2, (pair, label_position, x_coords)
  return tuple(sorted(x_coords))


def get_bar_mid_y_coords(
    pair: tuple[str, str],
    label_position: int,
    ax: plt.Axes,
) -> tuple[int, int]:
  """Returns the sorted vertical midpoints of the two bars named in `pair`.

  Args:
    pair: The legend texts identifying the two bar containers.
    label_position: The index of the bar within each container.
    ax: The axis holding the bars and the legend.

  Returns:
    The two bars' y midpoints in ascending order.
  """
  legend_labels = [entry.get_text() for entry in ax.get_legend().get_texts()]
  bar_groups = [
      group for group in ax.containers
      if isinstance(group, container.BarContainer)
  ]
  # Legend entries and bar containers are paired up by position.
  assert len(legend_labels) == len(bar_groups)
  patches = [
      group.patches[label_position]
      for legend_label, group in zip(legend_labels, bar_groups)
      if legend_label in pair
  ]
  y_coords = [patch.get_y() + (patch.get_height() / 2) for patch in patches]
  assert len(y_coords) == 2, y_coords
  return tuple(sorted(y_coords))


def add_significance_x(
    ax: plt.Axes,
    x1: float,
    x2: float,
    y: float,
    h: float,
    text: str = '*',
) -> None:
  """Draws a labeled bracket spanning two bar midpoints along the x axis.

  Args:
    ax: The axis to annotate.
    x1: The x midpoint coordinate of the first bar.
    x2: The x midpoint coordinate of the second bar.
    y: The y coordinate at which to start the annotation bracket.
    h: The height of the annotation bracket.
    text: The text in the annotation.
  """
  bracket_top = y + h
  bracket_xs = [x1, x1, x2, x2]
  bracket_ys = [y, bracket_top, bracket_top, y]
  ax.plot(bracket_xs, bracket_ys, lw=2, c='k')

  # The text sits centered on top of the bracket.
  text_style = {
      'ha': 'center',
      'va': 'bottom',
      'color': 'k',
      'fontdict': {'family': 'monospace'},
  }
  ax.text((x1 + x2) * 0.5, bracket_top, text, **text_style)


def add_significance_y(
    ax: plt.Axes,
    y1: float,
    y2: float,
    x: float,
    d: float,
    text: str = '*',
) -> None:
  """Draws a labeled bracket spanning two bar midpoints along the y axis.

  Args:
    ax: The axis to annotate.
    y1: The y midpoint coordinate of the first bar.
    y2: The y midpoint coordinate of the second bar.
    x: The x coordinate at which to start the annotation bracket.
    d: The depth of the annotation bracket.
    text: The text in the annotation.
  """
  bracket_edge = x + d
  bracket_xs = [x, bracket_edge, bracket_edge, x]
  bracket_ys = [y1, y1, y2, y2]
  ax.plot(bracket_xs, bracket_ys, lw=2, c='k')

  # Accounting for lineheight, which puts the * a bit too high.
  text_y = ((y1 + y2) * 0.5) - 0.1
  text_style = {
      'ha': 'center',
      'va': 'center',
      'color': 'k',
      'fontdict': {'family': 'monospace'},
  }
  ax.text(x + (2 * d), text_y, text, **text_style)


def _set_axis_attrs_section(
    ax: plt.Axes,
    title: str,
    xlabel: str,
    ylabel: str,
    fontsize: int,
) -> None:
  """Sets the title, labels, limits, and ticks of a "Section" plot.

  The y axis is fixed to the range (1, 5.5) with integer ticks from 1 to 5, and
  x tick labels are replaced by their `SECTION_TAG_TO_LABEL` display names when
  one exists.
  """
  ax.set_title(title, weight='bold', fontsize=fontsize)
  ax.set_xlabel(xlabel, weight='bold', fontsize=fontsize)
  ax.set_ylabel(ylabel, fontsize=fontsize)

  ax.set_ylim((1, 5.5))
  ax.tick_params(bottom=False, left=True, width=1.5, direction='inout')
  ax.set_yticks(ticks=list(range(1, 6)))
  ax.set_yticklabels(labels=list(range(1, 6)), fontsize=fontsize)

  # Tags without a display name keep their original text.
  display_labels = []
  for tick_label in ax.get_xticklabels():
    tag = tick_label.get_text()
    display_labels.append(SECTION_TAG_TO_LABEL.get(tag, tag))
  ax.set_xticklabels(labels=display_labels, fontsize=fontsize)


def _set_axis_attrs_principle(
    ax: plt.Axes,
    title: str,
    xlabel: str,
    ylabel: str,
    title_fontsize: int,
    other_fontsize: int,
    label_rotation_x: int = 0,
    label_rotation_y: int = 0,
) -> None:
  """Sets the title, labels, limits, and ticks of a "Principle" plot.

  Args:
    ax: The axis to update.
    title: The axis title.
    xlabel: The x axis label; x ticks are only restyled if it is non-empty.
    ylabel: The y axis label; y tick labels are only replaced by their
      `PRINCIPLE_TO_LABEL` display names if it is non-empty.
    title_fontsize: The font size of the title and axis labels.
    other_fontsize: The font size of the tick labels.
    label_rotation_x: The rotation of the x tick labels.
    label_rotation_y: The rotation of the y tick labels.
  """
  ax.set_title(title, weight='bold', fontsize=title_fontsize)
  ax.set_xlabel(xlabel, fontsize=title_fontsize)
  ax.set_ylabel(ylabel, weight='bold', fontsize=title_fontsize)
  ax.set_xlim((1, 5.5))
  ax.set_ylim((-1.15, 15.15))
  ax.tick_params(bottom=True, left=False, width=1.5, direction='inout')

  if xlabel:
    ax.set_xticks(ticks=list(range(1, 6)))
    ax.set_xticklabels(
        labels=list(range(1, 6)),
        fontsize=other_fontsize,
        rotation=label_rotation_x,
    )

  if ylabel:
    # Principles without a display name keep their original text.
    display_labels = []
    for tick_label in ax.get_yticklabels():
      principle = tick_label.get_text()
      display_labels.append(PRINCIPLE_TO_LABEL.get(principle, principle))
    ax.set_yticklabels(
        labels=display_labels,
        fontsize=other_fontsize,
        rotation=label_rotation_y,
    )


def _label_axes(fig: figure.Figure) -> None:
  """Adds bold alpha sublabels (e.g., a, b, ...) above each axis in a figure."""
  for position, ax in enumerate(list(fig.get_axes())):
    letter = f'{string.ascii_lowercase[position]}'
    # Nudge the label by a fixed physical offset (in inches) regardless of the
    # axis size.
    offset = transforms.ScaledTranslation(
        -20 / 72, 7 / 72, fig.dpi_scale_trans
    )
    ax.text(
        -0.04,
        1.0,
        letter,
        transform=ax.transAxes + offset,
        fontsize='18',
        va='bottom',
        weight='bold',
    )


def _plot_by_section_to_ax(
    ax: plt.Axes,
    df: pd.DataFrame,
    order: list[str] | None,
    hue_order: list[str] | None,
    rating_col: str = COL_RATING,
    add_stat_sig_annot: bool = False,
    pairwise_statsig: bool = True,
) -> None:
  """Plots mean ratings by section and, optionally, significance annotations.

  Args:
    ax: The axis to draw on.
    df: The ratings to plot; read through the section tag, conversation source,
      and `rating_col` columns.
    order: The order of the sections on the x axis, passed to seaborn.
    hue_order: The order of the conversation sources within a section, passed
      to seaborn. In pairwise mode, pairs that include a source missing from
      `hue_order` are not annotated; `None` disables that filter.
    rating_col: The column holding the rating to plot.
    add_stat_sig_annot: Whether to annotate the results of `significance_test`.
    pairwise_statsig: If true, a bracket is drawn between each pair returned by
      `get_sig_sections`. If false, for each section returned by
      `get_non_sig_sections` the source with the highest mean rating is starred
      along with every source paired with it.
  """
  sns.barplot(
      ax=ax,
      x=COL_SECTION_TAG,
      y=rating_col,
      hue=COL_CONVERSATION_SOURCE,
      data=df,
      errorbar=('ci', 95),
      n_boot=1_000,
      palette=dict(CONVERSATION_SOURCE_PALETTE),
      order=order,
      hue_order=hue_order,
      # Uncomment to add spacing between bars, but this requires fixing the
      # legend (which also has the line width applied).
      # linewidth=5,
      # edgecolor='white',
  )

  if not add_stat_sig_annot:
    return

  def _add_star(source: str, label_position: int) -> None:
    """Draws a "*" above the bar of `source` in the given section."""
    ax.text(
        get_bar_x_coord(source, label_position, ax),
        5.1 + 0.1,
        '*',
        ha='center',
        va='bottom',
        color='k',
        fontdict={'family': 'monospace'},
    )

  sig_df = significance_test(
      df=df,
      groupby_column=COL_SECTION_TAG,
      rating_column=rating_col,
  )

  if pairwise_statsig:
    for sig_section, pairs in get_sig_sections(sig_df).items():
      label_position = get_label_position_x(sig_section, ax)
      for pair in pairs:
        # Without an explicit `hue_order` there is nothing to filter against;
        # testing membership in `None` would raise a `TypeError`.
        if hue_order is not None and (
            pair[0] not in hue_order or pair[1] not in hue_order
        ):
          continue
        x_coords = get_bar_mid_x_coords(pair, label_position, ax)
        add_significance_x(ax, x_coords[0], x_coords[1], 5.1, 0.1, '*')
    return

  # Sorting by descending mean makes the first row of each section group the
  # highest-rated source of that section.
  mean_by_section = (
      df.groupby([COL_SECTION_TAG, COL_CONVERSATION_SOURCE])[rating_col]
      .mean()
      .reset_index()
      .sort_values(by=rating_col, ascending=False)
  )
  highest_per_section = (
      mean_by_section.groupby(COL_SECTION_TAG).first().reset_index()
  )
  section_to_highest_type = dict(
      zip(
          highest_per_section[COL_SECTION_TAG],
          highest_per_section[COL_CONVERSATION_SOURCE],
      )
  )
  # For each pair of non-stat-sig-diff sources in a section, add a "*" to
  # each source that is non-stat-sig-diff from the *highest* source within
  # that section.
  for non_sig_section, pairs in get_non_sig_sections(sig_df).items():
    label_position = get_label_position_x(non_sig_section, ax)
    highest_section = section_to_highest_type[non_sig_section]
    _add_star(highest_section, label_position)
    # Track starred sources so that each bar is starred at most once.
    starred_labels = {highest_section}
    for pair in pairs:
      if highest_section not in pair:
        continue
      for pair_element in pair:
        if pair_element in starred_labels:
          continue
        _add_star(pair_element, label_position)
        starred_labels.add(pair_element)


def plot_case_study_main_fig(
    sleep_df: pd.DataFrame,
    fitness_df: pd.DataFrame,
    sleep_title: str,
    fitness_title: str,
    savefig_filepath: str,
    hue_order: list[str],
    savefig_format: str = 'pdf',
    sleep_rating_col: str = COL_RATING,
    sleep_rating_label: str = LABEL_AVG_RATING,
    fitness_rating_col: str = COL_RATING,
    fitness_rating_label: str = LABEL_AVG_RATING,
    other_font_size: int = 16,
    label_subplots: bool = False,
    pairwise_statsig: bool = True,  # If false, we star the best and any _not_ statsig from best.
    drop_overall: bool = True,
) -> None:
  """Plots sleep and fitness ratings by section side-by-side and saves them.

  Args:
    sleep_df: The sleep ratings.
    fitness_df: The fitness ratings.
    sleep_title: The title of the sleep (left) subplot.
    fitness_title: The title of the fitness (right) subplot.
    savefig_filepath: The path the figure is saved to.
    hue_order: The conversation sources to plot, in plotting order; rows from
      other sources are dropped.
    savefig_format: The file format the figure is saved in.
    sleep_rating_col: The sleep rating column to plot.
    sleep_rating_label: The y axis label of the sleep subplot.
    fitness_rating_col: The fitness rating column to plot.
    fitness_rating_label: The y axis label of the fitness subplot.
    other_font_size: The font size of titles, axis labels, and tick labels.
    label_subplots: Whether to add alpha sublabels to the subplots.
    pairwise_statsig: Whether to annotate significantly different pairs; if
      false, the best source and those not significantly different from it are
      starred instead.
    drop_overall: Whether to drop rows of the 'overall' section.
  """

  def _select_rows(ratings_df: pd.DataFrame) -> pd.DataFrame:
    """Keeps the rows of the plotted sections and conversation sources."""
    if drop_overall:
      ratings_df = ratings_df[
          ratings_df[COL_SECTION_TAG] != 'overall'
      ].reset_index(drop=True)
    return ratings_df[
        ratings_df[COL_CONVERSATION_SOURCE].isin(hue_order)
    ].reset_index(drop=True)

  sleep_df = _select_rows(sleep_df)
  fitness_df = _select_rows(fitness_df)

  fig, axes = plt.subplots(1, 2, figsize=(16, 6))
  # One entry per subplot: axis, data, section order, rating column, title, and
  # y axis label.
  panels = (
      (
          axes[0],
          sleep_df,
          SLEEP_SECTION_TAG_ORDER,
          sleep_rating_col,
          sleep_title,
          sleep_rating_label,
      ),
      (
          axes[1],
          fitness_df,
          FITNESS_SECTION_TAG_ORDER,
          fitness_rating_col,
          fitness_title,
          fitness_rating_label,
      ),
  )
  for panel_ax, panel_df, section_order, rating_col, _, _ in panels:
    _plot_by_section_to_ax(
        ax=panel_ax,
        df=panel_df,
        order=section_order,
        hue_order=hue_order,
        rating_col=rating_col,
        add_stat_sig_annot=True,
        pairwise_statsig=pairwise_statsig,
    )

  # Add a legend to the first subplot.
  sleep_legend = axes[0].get_legend()
  legend_handles = sleep_legend.legend_handles
  legend_labels = []
  for legend_text in axes[0].get_legend().get_texts():
    source_key = legend_text.get_text()
    legend_labels.append(
        CONVERSATION_SOURCE_KEY_TO_LABEL.get(source_key, source_key)
    )
  axes[0].legend(
      loc='lower left',
      frameon=True,
      facecolor='white',
      edgecolor='white',
      framealpha=1,
      handlelength=0.7,
      borderpad=0.3,
      bbox_to_anchor=(0.065, 0),
      handles=legend_handles,
      labels=legend_labels,
  )
  axes[1].get_legend().remove()

  # Update x and y axis attributes.
  for panel_ax, _, _, _, panel_title, rating_label in panels:
    _set_axis_attrs_section(
        panel_ax,
        title=panel_title,
        xlabel=LABEL_SECTION,
        ylabel=rating_label,
        fontsize=other_font_size,
    )

  if label_subplots:
    _label_axes(fig)

  fig.tight_layout()
  fig.savefig(savefig_filepath, format=savefig_format, dpi=300)
  plt.show()


def _add_stat_sig_marker_principle(
    df: pd.DataFrame,
    rating_col: str,
    ax: plt.Axes,
) -> None:
  """Draws a "*" bracket between each significant pair of bars per principle.

  Args:
    df: The ratings passed to `significance_test`, grouped by principle.
    rating_col: The rating column to test.
    ax: The horizontal bar plot axis to annotate.
  """
  significance_df = significance_test(df, COL_PRINCIPLE, rating_col)
  sections_to_pairs = get_sig_sections(significance_df)
  for principle, pairs in sections_to_pairs.items():
    position = get_label_position_y(principle, ax)
    for pair in pairs:
      lower_y, upper_y = get_bar_mid_y_coords(pair, position, ax)
      add_significance_y(ax, lower_y, upper_y, 5.1, 0.1, '*')


def plot_by_principle_all(
    sleep_df: pd.DataFrame,
    sleep_title: str,
    fitness_df: pd.DataFrame,
    fitness_title: str,
    savefig_filepath: str,
    order: list[str] | None,
    hue_order: list[str] | None,
    sleep_rating_col: str = COL_RATING,
    fitness_rating_col: str = COL_RATING,
    title_font_size: int = 16,
    other_font_size: int = 8,
    add_stat_sig_annot: bool = False,
    label_subplots: bool = False,
    savefig_format: str = 'pdf',
) -> None:
  """Plots all ratings grouped by conversation source and principle.

  Args:
    sleep_df: The sleep ratings.
    sleep_title: The title of the sleep (left) subplot.
    fitness_df: The fitness ratings.
    fitness_title: The title of the fitness (right) subplot.
    savefig_filepath: The path the figure is saved to.
    order: The principles to plot, listed top to bottom; if `None`, the order
      is left to seaborn.
    hue_order: The order of the conversation sources, passed to seaborn.
    sleep_rating_col: The sleep rating column to plot.
    fitness_rating_col: The fitness rating column to plot.
    title_font_size: The font size of titles and axis labels.
    other_font_size: The font size of tick labels.
    add_stat_sig_annot: Whether to annotate significantly different pairs.
    label_subplots: Whether to add alpha sublabels to the subplots.
    savefig_format: The file format the figure is saved in.
  """
  fig, axes = plt.subplots(1, 2, figsize=(16, 10), sharey=True)

  # Reverse the order since seaborn plots the last in list on top. `None` has
  # no order to reverse and is passed through unchanged.
  plot_order = order[::-1] if order is not None else None

  # One entry per subplot: axis, data, and rating column.
  panels = (
      (axes[0], sleep_df, sleep_rating_col),
      (axes[1], fitness_df, fitness_rating_col),
  )
  for panel_ax, panel_df, rating_col in panels:
    sns.barplot(
        y=COL_PRINCIPLE,
        x=rating_col,
        hue=COL_CONVERSATION_SOURCE,
        data=panel_df,
        errorbar=('ci', 95),
        n_boot=1_000,
        palette=dict(CONVERSATION_SOURCE_PALETTE),
        order=plot_order,
        hue_order=hue_order,
        orient='h',
        ax=panel_ax,
    )
  if add_stat_sig_annot:
    for panel_ax, panel_df, rating_col in panels:
      _add_stat_sig_marker_principle(panel_df, rating_col, panel_ax)

  # Add a legend to the first subplot.
  legend_handles = axes[0].get_legend().legend_handles
  legend_labels = [
      CONVERSATION_SOURCE_KEY_TO_LABEL.get(t.get_text(), t.get_text())
      for t in axes[0].get_legend().get_texts()
  ]
  axes[0].legend(
      loc='upper left',
      frameon=True,
      facecolor='white',
      edgecolor='white',
      framealpha=1,
      handlelength=0.7,
      borderpad=0.3,
      bbox_to_anchor=(0, 1 - 0.046),
      reverse=True,
      handles=legend_handles,
      labels=legend_labels,
  )
  axes[1].get_legend().remove()

  # Update x and y axis attributes. The y axis is shared, so only the first
  # subplot gets a y label and relabeled y ticks.
  _set_axis_attrs_principle(
      axes[0],
      sleep_title,
      xlabel=LABEL_AVG_RATING,
      ylabel=LABEL_PRINCIPLE,
      title_fontsize=title_font_size,
      other_fontsize=other_font_size,
      label_rotation_y=45,
  )
  _set_axis_attrs_principle(
      axes[1],
      fitness_title,
      xlabel=LABEL_AVG_RATING,
      ylabel='',
      title_fontsize=title_font_size,
      other_fontsize=other_font_size,
  )

  if label_subplots:
    _label_axes(fig)

  plt.tight_layout()
  fig.savefig(savefig_filepath, format=savefig_format, dpi=300)
  plt.show()


def _rater_id_to_label(rater_id: str) -> str:
  """Returns the rater ID with underscores as spaces and words capitalized."""
  words = rater_id.split('_')
  return ' '.join(map(str.capitalize, words))


def _rater_id_to_type(rater_id: str) -> RaterType:
  """Returns the rater type whose value is contained in the given rater ID.

  Raises:
    ValueError: If the ID contains neither the primary nor the secondary rater
      type's value.
  """
  # Primary takes precedence if the ID contains both values.
  for rater_type in (RaterType.PRIMARY, RaterType.SECONDARY):
    if rater_type.value in rater_id:
      return rater_type
  raise ValueError(f'Unexpected {rater_id=}')


def _key_column(df_row: dict[str, Any]) -> str:
  """Returns the '::'-joined key identifying the given expert rating row.

  The key is built from the row's case study ID, rater, section tag, principle,
  and conversation source, in that order.
  """
  key_columns = (
      COL_CASE_STUDY_ID,
      COL_RATER,
      COL_SECTION_TAG,
      COL_PRINCIPLE,
      COL_CONVERSATION_SOURCE,
  )
  return '::'.join(f'{df_row[column]}' for column in key_columns)


def _key_no_rater_column(df_row: dict[str, Any]) -> str:
  """Returns the '::'-joined key of the given expert rating row sans rater.

  The key is built from the row's case study ID, section tag, principle, and
  conversation source, in that order.
  """
  key_columns = (
      COL_CASE_STUDY_ID,
      COL_SECTION_TAG,
      COL_PRINCIPLE,
      COL_CONVERSATION_SOURCE,
  )
  return '::'.join(f'{df_row[column]}' for column in key_columns)


def get_rater_pairs(df: pd.DataFrame) -> set[tuple[str, str]]:
  """Returns the set of rater pairs for which we have replicated ratings.

  Raters are grouped by the rater-less key column; every two-rater combination
  within a group is a pair, with the pair's members in sorted order.
  """
  # NOTE(review): the rater column is read via the literal `rater` attribute
  # rather than `COL_RATER`; confirm that the two agree.
  raters_per_item = df.groupby(by=DERIVED_COL_KEY_NO_RATER).rater.agg(
      lambda x: tuple(sorted(x))
  )
  rater_pairs = set()
  for rater_group in set(raters_per_item.values):
    # Groups with fewer than two raters yield no combinations.
    rater_pairs.update(itertools.combinations(rater_group, 2))
  return rater_pairs


def plot_rating_contingency(
    df: pd.DataFrame,
    col_a: str,
    label_a: str,
    col_b: str,
    label_b: str,
    label_values: list[int] | None = None,
    plot_percents: bool = False,
    plot_cbar: bool = False,
    ax: plt.Axes | None = None,
    color: str = COLOR_GRAY_DARK,
) -> plt.Axes:
  """Plots a contingency table for samples where both ratings are present.

  Args:
    df: A dataframe containing the target columns.
    col_a: The column containing values for the vertical axis.
    label_a: `col_a`'s label value.
    col_b: The column containing values for the horizontal axis.
    label_b: `col_b`'s label value.
    label_values: The possible label values, in display order. Defaults to the
      ratings 1 through 5.
    plot_percents: Whether to annotate each cell with its percentage of all
      paired ratings in addition to its count.
    plot_cbar: Whether to include the cbar to the right of the contingency
      table.
    ax: An optional axis on which the contingency table is plotted; if not
      specified, a new axis is created.
    color: Base color of the light sequential palette used for the heatmap.

  Returns:
    The axis on which the contingency table was plotted.
  """
  if ax is None:
    ax = plt.axes()
  if label_values is None:
    label_values = [1, 2, 3, 4, 5]
  else:
    label_values = list(label_values)

  # Only rows where both raters provided a rating contribute to the table.
  paired_ratings = df[[col_a, col_b]].dropna()
  cf_matrix = sklearn.metrics.confusion_matrix(
      paired_ratings[col_a].values,
      paired_ratings[col_b].values,
      labels=label_values,
  )
  cell_counts = cf_matrix.flatten()
  annotations = [f'{count:g}' for count in cell_counts]
  if plot_percents:
    # An empty table would otherwise divide by zero and annotate cells "nan%".
    total = cell_counts.sum()
    if total:
      cell_fractions = cell_counts / total
    else:
      cell_fractions = np.zeros(len(cell_counts))
    annotations = [
        f'{count}\n({fraction:0.2%})'
        for count, fraction in zip(annotations, cell_fractions)
    ]
  # Match the annotation grid to the matrix rather than assuming five labels.
  annotations = np.asarray(annotations).reshape(cf_matrix.shape)
  sns.heatmap(
      cf_matrix,
      square=True,
      cmap=sns.light_palette(color, as_cmap=True),
      fmt='',
      annot=annotations,
      xticklabels=label_values,
      yticklabels=label_values,
      cbar=plot_cbar,
      ax=ax,
  )
  ax.set_yticklabels(labels=ax.get_yticklabels(), va='center')
  ax.set_ylabel(label_a)
  ax.set_xlabel(label_b)
  return ax


def _process_vertical_pairs(
    df: pd.DataFrame,
    rater_types_a: list[RaterType],
    rater_types_b: list[RaterType],
) -> tuple[list[str], list[tuple[str, str]], list[list[tuple[str, str]]]]:
  """Returns a tuple of colors, rater pairs, and grouped rater pairs.

  `rater_types_a` and `rater_types_b` are zipped into rater-type pairings. Each
  pairing selects the replicated rater pairs in `df` whose two rater types
  match it (in either order) and is assigned one plot color.

  Args:
    df: Ratings dataframe that already has the `DERIVED_COL_KEY_NO_RATER`
      column required by `get_rater_pairs`.
    rater_types_a: First rater type of each pairing.
    rater_types_b: Second rater type of each pairing.

  Returns:
    A tuple of (one color per pairing, all selected rater pairs flattened
    across pairings, the selected rater pairs grouped per pairing).

  Raises:
    ValueError: If a pairing is not (primary, primary), (primary, secondary),
      or (secondary, secondary), or if a rater ID has an unexpected type.
  """
  pairing_to_color = {
      (RaterType.PRIMARY, RaterType.PRIMARY): COLOR_BLUE,
      (RaterType.PRIMARY, RaterType.SECONDARY): COLOR_GREEN,
      (RaterType.SECONDARY, RaterType.SECONDARY): COLOR_YELLOW,
  }
  pairings = list(zip(rater_types_a, rater_types_b))
  colors = []
  for pairing in pairings:
    try:
      colors.append(pairing_to_color[pairing])
    except KeyError:
      raise ValueError('Unknown rater type pairing.') from None

  # The replicated rater pairs and their rater types depend only on `df`, so
  # compute them once instead of re-running the groupby for every pairing.
  typed_rater_pairs = [
      (pair, tuple(sorted(_rater_id_to_type(rater).value for rater in pair)))
      for pair in get_rater_pairs(df)
  ]

  all_filtered_pairs = []
  filtered_pair_groups = []
  for rater_type_a, rater_type_b in pairings:
    target_rater_pair_type = tuple(
        sorted([rater_type_a.value, rater_type_b.value])
    )
    filtered_pairs = [
        pair
        for pair, pair_type in typed_rater_pairs
        if pair_type == target_rater_pair_type
    ]
    all_filtered_pairs.extend(filtered_pairs)
    filtered_pair_groups.append(filtered_pairs)
  return colors, all_filtered_pairs, filtered_pair_groups


def plot_rater_contingency_all_both_verticals(
    sleep_df: pd.DataFrame,
    fitness_df: pd.DataFrame,
    rater_types_a: list[RaterType],
    rater_types_b: list[RaterType],
    filename: str | None = None,
) -> None:
  """Plots contingency tables for every replicated rater pair, per vertical.

  Sleep pairs fill the two left-hand columns of the grid and fitness pairs the
  two right-hand columns, separated by a narrow hidden spacer column. Within a
  vertical, pairs are laid out row by row, grouped by the requested rater-type
  pairings (in the order given) and sorted by rater label within each group.

  Args:
    sleep_df: Sleep ratings containing the columns read by
      `_key_no_rater_column`, plus `COL_RATER` and `COL_RATING`.
    fitness_df: Fitness ratings with the same columns.
    rater_types_a: First rater type of each pairing to plot.
    rater_types_b: Second rater type of each pairing; must have the same length
      as `rater_types_a`.
    filename: If given, the figure is also saved to this path as a PDF.

  Raises:
    ValueError: If the rater type lists differ in length or contain a pairing
      that `_process_vertical_pairs` does not support.
  """
  if len(rater_types_a) != len(rater_types_b):
    raise ValueError(
        'rater_types_a and rater_types_b must have the same length.'
    )

  sleep_df = sleep_df.copy()
  fitness_df = fitness_df.copy()

  sleep_df[DERIVED_COL_KEY_NO_RATER] = sleep_df.apply(
      _key_no_rater_column, axis=1
  )
  fitness_df[DERIVED_COL_KEY_NO_RATER] = fitness_df.apply(
      _key_no_rater_column, axis=1
  )

  sleep_colors, sleep_all_filtered_pairs, sleep_filtered_pair_groups = (
      _process_vertical_pairs(sleep_df, rater_types_a, rater_types_b)
  )
  (
      readiness_colors,
      readiness_all_filtered_pairs,
      readiness_filtered_pair_groups,
  ) = _process_vertical_pairs(fitness_df, rater_types_a, rater_types_b)

  # Sanity checks on the expected data: the fixed grid below has 12 slots per
  # vertical, and exactly one fitness slot (the last) is hidden as unused.
  assert len(sleep_all_filtered_pairs) == 12
  assert len(readiness_all_filtered_pairs) == 11

  sleep_df_pivot = sleep_df.pivot(
      index=DERIVED_COL_KEY_NO_RATER,
      columns=COL_RATER,
      values=COL_RATING,
  )
  fitness_df_pivot = fitness_df.pivot(
      index=DERIVED_COL_KEY_NO_RATER,
      columns=COL_RATER,
      values=COL_RATING,
  )
  num_cols = 4
  num_rows = 6
  cols_per_vertical = num_cols // 2
  # With more than one row and column, `plt.subplots` always returns a 2-D
  # array of axes, so no shape normalization is needed.
  fig, axes = plt.subplots(
      num_rows,
      num_cols + 1,
      figsize=(4 * num_cols, 3.5 * num_rows),
      dpi=300,
      gridspec_kw={'width_ratios': [1, 1, 0.1, 1, 1]},
  )

  def _pair_sort_key(pair: tuple[str, str]) -> tuple[str, str]:
    """Orders pairs by their smaller label, using the larger as tie break."""
    labels = sorted(_rater_id_to_label(rater) for rater in pair)
    return labels[0], labels[1]

  for offset, df_pivot, colors, filtered_pair_groups in [
      (0, sleep_df_pivot, sleep_colors, sleep_filtered_pair_groups),
      (
          cols_per_vertical + 1,  # Skip the vertical's columns and the spacer.
          fitness_df_pivot,
          readiness_colors,
          readiness_filtered_pair_groups,
      ),
  ]:
    # Running slot index so that each group continues where the last stopped.
    slot = 0
    for color, filtered_pair_group in zip(colors, filtered_pair_groups):
      for rater_a, rater_b in sorted(filtered_pair_group, key=_pair_sort_key):
        i, j = divmod(slot, cols_per_vertical)
        j += offset
        slot += 1
        label_a = _rater_id_to_label(rater_a)
        label_b = _rater_id_to_label(rater_b)
        # Nature of labels should allow us to alpha order so we are consistent
        # with axes across pairs.
        if label_a < label_b:
          rater_a, rater_b = rater_b, rater_a
          label_a, label_b = label_b, label_a
        rater_pair_df = (
            df_pivot[[rater_a, rater_b]].dropna().reset_index(drop=True)
        )
        plot_rating_contingency(
            rater_pair_df,
            col_a=rater_a,
            label_a=label_a,
            col_b=rater_b,
            label_b=label_b,
            ax=axes[i][j],
            color=color,
        )

  # Hide unused axes: the spacer column and the last fitness slot.
  for i in range(num_rows):
    axes[i][2].set_visible(False)
  axes[num_rows - 1][num_cols].set_visible(False)

  # Add a, b, etc. labels to the first subplot of each vertical.
  labeled_axes = [axes[0][0], axes[0][3]]
  for i, ax in enumerate(labeled_axes):
    ax_label = f'{string.ascii_lowercase[i]}'
    trans = transforms.ScaledTranslation(-20 / 72, 7 / 72, fig.dpi_scale_trans)
    ax.text(
        -0.04,
        1.0,
        ax_label,
        transform=ax.transAxes + trans,
        fontsize='18',
        va='bottom',
        weight='bold',
    )

  fig.tight_layout()
  # Save before showing, matching the other plotting functions in this file.
  if filename:
    fig.savefig(filename, format='pdf', dpi=300)
  plt.show()


def _plot_stats_for_metric_binary(
    df: pd.DataFrame,
    metric_name: str,
    ax: plt.Axes,
    plot_cmap: bool = False,
    label_rotation_x: int = 90,
    label_rotation_y: int = 0,
    title_prefix: str | None = None,
) -> None:
  """Plots a rater-by-rater heatmap of one metric with annotated cells.

  Each cell shows the metric's mean, its confidence interval, and the sample
  count for the corresponding rater pair. Every pair is mirrored so the heatmap
  is filled on both sides of the diagonal.

  Args:
    df: Metrics dataframe. The code reads the columns `model_name`,
      `prediction_name`, `metric_name`, `mean`, `ci_lower`, and `ci_upper`, and
      expects exactly one `metric_name == 'sample_count'` row per plotted pair
      (in either orientation).
    metric_name: The metric to plot; must be a key of `METRIC_TO_LABEL_BINARY`.
    ax: The axis on which the heatmap is drawn.
    plot_cmap: Whether to draw the heatmap's color bar.
    label_rotation_x: Rotation of the x tick labels.
    label_rotation_y: Rotation of the y tick labels.
    title_prefix: Optional text prepended to the metric's title.
  """
  # Map every rater ID found in either rater column to its display label.
  rater_to_label = {
      r: _rater_id_to_label(r)
      for r in list(df.model_name.unique()) + list(df.prediction_name.unique())
  }
  label_to_rater = {l: r for r, l in rater_to_label.items()}
  # Drop the vertical prefix from the labels, in both lookup directions.
  # NOTE(review): IDs that differ only in this prefix would collide after
  # stripping; presumably each dataframe holds a single vertical -- confirm.
  rater_to_label = {
      k: v.replace('Sleep ', '').replace('Fitness ', '')
      for k, v in rater_to_label.items()
  }
  label_to_rater = {
      k.replace('Sleep ', '').replace('Fitness ', ''): v
      for k, v in label_to_rater.items()
  }
  df = df.copy()
  # NOTE(review): replacing IDs with labels and then labels with IDs looks like
  # a round trip that leaves `df` keyed by rater ID (which the sample-count
  # lookups below rely on) -- confirm whether it is needed.
  df = df.replace(rater_to_label)
  df = df.replace(label_to_rater)
  # Rows for the requested metric, with rater IDs replaced by display labels.
  sub_df = (
      df[df.metric_name == metric_name]
      .reset_index(drop=True)
      .rename(columns={'model_name': 'outer', 'prediction_name': 'inner'})
      .replace(rater_to_label)
  )
  # Mirror each (outer, inner) row as (inner, outer) so both orientations of a
  # pair are present.
  reflection_df = sub_df.copy()
  reflection_df['inner'], reflection_df['outer'] = (
      reflection_df.outer,
      reflection_df.inner,
  )
  sub_df = pd.concat([sub_df, reflection_df]).reset_index(drop=True)
  sub_df_pivot = (
      sub_df.pivot(
          index='outer',
          columns='inner',
          values='mean',
      )
      .sort_index()
      .sort_index(axis=1)
  )
  # The pivot is already sorted on both axes above; this repeats that sort.
  sub_df_pivot = sub_df_pivot.sort_index().sort_index(axis=1)
  # Build the per-cell annotation text in the same order as the pivot.
  annot_labels = []
  for outer in sub_df_pivot.index:
    dim_labels = []
    for inner in sub_df_pivot.columns:
      records = sub_df[
          (sub_df.inner == inner) & (sub_df.outer == outer)
      ].to_dict('records')
      if len(records) == 0:
        # No metric for this pair: leave the cell unannotated.
        label = ''
      else:
        record = records[0]
        # Sample counts are looked up in `df` by rater ID, not by label.
        sample_records = df[
            (df.model_name == label_to_rater[outer])
            & (df.prediction_name == label_to_rater[inner])
            & (df.metric_name == 'sample_count')
        ].to_dict('records')
        if len(sample_records) == 0:
          # Handle reflection.
          sample_records = df[
              (df.model_name == label_to_rater[inner])
              & (df.prediction_name == label_to_rater[outer])
              & (df.metric_name == 'sample_count')
          ].to_dict('records')
        assert len(sample_records) == 1, sample_records
        sample_record = sample_records[0]
        # Every '0.' in the interval text is shortened to '.' (e.g. '0.85'
        # becomes '.85').
        ci = f'[{record["ci_lower"]:.2f}–{record["ci_upper"]:.2f}]'.replace(
            '0.', '.'
        )
        label = f'{record["mean"]:0.3f}\n{ci}\nn={sample_record["mean"]:0.0f}'
      dim_labels.append(label)
    annot_labels.append(dim_labels)
  sns.heatmap(
      data=sub_df_pivot,
      vmin=0,
      vmax=1,
      annot=annot_labels,
      fmt='',
      ax=ax,
      cmap=sns.cubehelix_palette(as_cmap=True, reverse=True),
      cbar=plot_cmap,
  )
  # Draw thick divider lines after the rows/columns whose label contains
  # 'Primary'. Presumably this separates primary from secondary raters and
  # relies on the sorted index listing primary raters first -- confirm.
  num_core = len([x for x in sub_df_pivot.index if 'Primary' in x])
  ax.axhline(y=num_core, color='black', linewidth=5)
  ax.axvline(x=num_core, color='black', linewidth=5)
  ax.tick_params(axis='x', labelrotation=label_rotation_x)
  ax.tick_params(axis='y', labelrotation=label_rotation_y)
  title = (
      METRIC_TO_LABEL_BINARY[metric_name]
      if title_prefix is None
      else f'{title_prefix} {METRIC_TO_LABEL_BINARY[metric_name]}'
  )
  ax.set_title(title)
  ax.set_xlabel('')
  ax.set_ylabel('')


def plot_rater_pair_stats_binary_only_gwet(
    sleep_df: pd.DataFrame,
    readiness_df: pd.DataFrame,
    filename: str | None = None,
) -> None:
  """Plots side-by-side Gwet's AC2 heatmaps for the sleep and fitness raters.

  The figure has a thin top row, merged into a single horizontal color bar,
  above two heatmaps: sleep on the left and fitness on the right.

  Args:
    sleep_df: Rater pair metrics for the sleep vertical.
    readiness_df: Rater pair metrics for the fitness vertical.
    filename: If given, the figure is saved to this path as a PDF.
  """
  fig, axes = plt.subplots(
      nrows=2,
      ncols=2,
      figsize=(25, 15),
      dpi=300,
      sharey=False,
      sharex=False,
      gridspec_kw={'height_ratios': [1, 20]},
  )
  # Replace the two top-row axes with one axis spanning the full width.
  gridspec = axes[0, 0].get_gridspec()
  for top_ax in axes[0, :]:
    top_ax.remove()
  legend_ax = fig.add_subplot(gridspec[0, :])

  heatmap_panels = (
      (sleep_df, axes[1][0], 'Sleep:'),
      (readiness_df, axes[1][1], 'Fitness:'),
  )
  for metrics_df, heatmap_ax, panel_title_prefix in heatmap_panels:
    _plot_stats_for_metric_binary(
        metrics_df,
        'gwet_ac2',
        heatmap_ax,
        plot_cmap=False,
        label_rotation_x=45,
        label_rotation_y=45,
        title_prefix=panel_title_prefix,
    )

  # Both heatmaps share the same color scale, so one color bar serves both.
  fitness_heatmap_mesh = axes[1][1].get_children()[0]
  fig.colorbar(fitness_heatmap_mesh, cax=legend_ax, orientation='horizontal')
  legend_ax.xaxis.set_ticks_position('top')
  plt.subplots_adjust(wspace=0.075, hspace=0.075)

  # Add a, b, etc. labels to every axis except the last one (the color bar).
  axes_to_label = fig.get_axes()[:-1]
  for panel_idx, panel_ax in enumerate(axes_to_label):
    label_shift = transforms.ScaledTranslation(
        -20 / 72, 7 / 72, fig.dpi_scale_trans
    )
    panel_ax.text(
        -0.04,
        1.0,
        string.ascii_lowercase[panel_idx],
        transform=panel_ax.transAxes + label_shift,
        fontsize='18',
        va='bottom',
        weight='bold',
    )

  plt.tight_layout()
  if filename:
    fig.savefig(filename, format='pdf', dpi=300)


def _prep_df(
    base_df: pd.DataFrame,
    autoeval_df: pd.DataFrame,
    autoeval_rating_col: str,
    autoeval_rater: str,
) -> Any:
  """Merges expert and AutoEval ratings and selects the rater pairs to plot.

  AutoEval rows are restricted to those originally rated by `autoeval_rater`,
  relabelled so that their rater is `autoeval_rating_col`, and appended to the
  expert ratings. Only rater pairs involving `autoeval_rater` and/or the
  AutoEval rater are kept.

  Args:
    base_df: Human expert ratings.
    autoeval_df: AutoEval ratings; the column named `autoeval_rating_col` holds
      the AutoEval rating.
    autoeval_rating_col: Name of the AutoEval rating column, also used as the
      AutoEval model's rater ID.
    autoeval_rater: ID of the expert rater the AutoEval model was trained on.

  Returns:
    A `(rater_pairs, colors, df_pivot)` tuple. `rater_pairs` lists pairs with
    the expert rater first, then pairs with the AutoEval rater, then the
    expert-versus-AutoEval pair; `colors` holds one color per pair; `df_pivot`
    has one row per rater-agnostic key and one rating column per rater.
  """

  def _add_key_columns(frame: pd.DataFrame) -> None:
    """Appends the derived key columns to `frame` in place."""
    frame[DERIVED_COL_KEY] = frame.apply(_key_column, axis=1)
    frame[DERIVED_COL_KEY_NO_RATER] = frame.apply(_key_no_rater_column, axis=1)

  def _anchor_last(pair: tuple[str, str], anchor: str) -> tuple[str, str]:
    """Flips `pair` if needed so that `anchor` is not its first element."""
    return (pair[1], pair[0]) if pair[0] == anchor else pair

  def _first_rater_label(pair: tuple[str, str]) -> str:
    """Sort key: the display label of the pair's first rater."""
    return AUTOEVAL_MODEL_TO_LABEL.get(pair[0], _rater_id_to_label(pair[0]))

  # Expert ratings only need the derived keys.
  expert_df = base_df.copy()
  _add_key_columns(expert_df)

  # Keep the AutoEval rows for the rater the model was trained on, overwrite
  # the rater with the AutoEval model, add the derived keys, and move the
  # AutoEval rating into the shared rating column.
  model_df = autoeval_df.copy()
  model_df = model_df[model_df.rater == autoeval_rater].reset_index(drop=True)
  model_df[COL_RATER] = autoeval_rating_col
  _add_key_columns(model_df)
  model_df = model_df.drop(columns=[COL_RATING_HUMAN_EXPERT])
  model_df = model_df.rename(columns={autoeval_rating_col: COL_RATING})
  df_merged = pd.concat([expert_df, model_df])

  # Bucket the pairs by whether they contain the training rater, the AutoEval
  # rater, or both. Ordering within a pair is fixed for easier comparison.
  expert_pairs = []
  model_pairs = []
  mixed_pairs = []
  for pair in get_rater_pairs(df_merged):
    has_expert = autoeval_rater in pair
    has_model = autoeval_rating_col in pair
    if has_expert and has_model:
      mixed_pairs.append((autoeval_rating_col, autoeval_rater))
    elif has_expert:
      expert_pairs.append(_anchor_last(pair, autoeval_rater))
    elif has_model:
      model_pairs.append(_anchor_last(pair, autoeval_rating_col))

  rater_pairs = []
  colors = []
  for bucket, bucket_color in (
      (expert_pairs, COLOR_BLUE),
      (model_pairs, COLOR_GREEN),
      (mixed_pairs, COLOR_YELLOW),
  ):
    rater_pairs = rater_pairs + sorted(bucket, key=_first_rater_label)
    colors = colors + [bucket_color] * len(bucket)

  df_pivot = df_merged.pivot(
      index=DERIVED_COL_KEY_NO_RATER,
      columns=COL_RATER,
      values=COL_RATING,
  )
  return rater_pairs, colors, df_pivot


def plot_rater_contingency_autorater_rater_overlap_both(
    sleep_df: pd.DataFrame,
    sleep_autoeval_df: pd.DataFrame,
    sleep_autoeval_rating_col: str,
    sleep_autoeval_rater_id: str,
    fitness_df: pd.DataFrame,
    fitness_autoeval_df: pd.DataFrame,
    fitness_autoeval_rating_col: str,
    fitness_autoeval_rater_id: str,
    filename: str | None = None,
) -> None:
  """Plots contingency tables for the given autoraters in both verticals.

  This subsets to only overlaps with the training autorater (i.e., we take
  all train rater v.s. other rater pairs and report concord). This is for easier
  comparison with the original rater.

  Sleep tables occupy the first two rows of the grid and fitness tables the
  remaining three. Axes that receive no table are hidden.

  Args:
    sleep_df: Human expert ratings for the sleep vertical.
    sleep_autoeval_df: AutoEval ratings for the sleep vertical.
    sleep_autoeval_rating_col: Column of `sleep_autoeval_df` holding the
      AutoEval rating; also used as the AutoEval model's rater ID.
    sleep_autoeval_rater_id: ID of the expert rater the sleep AutoEval model
      was trained on.
    fitness_df: Human expert ratings for the fitness vertical.
    fitness_autoeval_df: AutoEval ratings for the fitness vertical.
    fitness_autoeval_rating_col: Column of `fitness_autoeval_df` holding the
      AutoEval rating; also used as the AutoEval model's rater ID.
    fitness_autoeval_rater_id: ID of the expert rater the fitness AutoEval
      model was trained on.
    filename: If given, the figure is also saved to this path as a PDF.
  """
  sleep_df = sleep_df.copy()
  sleep_autoeval_df = sleep_autoeval_df.copy()
  fitness_df = fitness_df.copy()
  fitness_autoeval_df = fitness_autoeval_df.copy()

  sleep_rater_pairs, sleep_colors, sleep_df_pivot = _prep_df(
      sleep_df,
      sleep_autoeval_df,
      sleep_autoeval_rating_col,
      sleep_autoeval_rater_id,
  )
  fitness_rater_pairs, fitness_colors, fitness_df_pivot = _prep_df(
      fitness_df,
      fitness_autoeval_df,
      fitness_autoeval_rating_col,
      fitness_autoeval_rater_id,
  )

  num_rows = 5
  num_cols = 4
  # With more than one row and column, `plt.subplots` always returns a 2-D
  # array of axes, so no shape normalization is needed.
  fig, axes = plt.subplots(
      num_rows,
      num_cols,
      figsize=(4 * num_cols, 4 * num_rows),
      dpi=300,
  )

  used_axes = []
  for row_offset, df_pivot, colors, rater_pairs in [
      (0, sleep_df_pivot, sleep_colors, sleep_rater_pairs),
      (
          2,
          fitness_df_pivot,
          fitness_colors,
          fitness_rater_pairs,
      ),
  ]:
    is_sleep = row_offset == 0
    num_pairs = len(rater_pairs)
    # Row-major grid positions within this vertical's band of rows. Sleep
    # reserves one extra position because one of its slots is skipped below.
    row_indices, col_indices = np.unravel_index(
        range(num_pairs + 1 if is_sleep else num_pairs),
        (2 if is_sleep else 3, num_cols),
        order='C',
    )
    for n, (rater_a, rater_b) in enumerate(rater_pairs):
      # Sleep leaves the fourth slot of its first row empty, so every pair
      # after the third is shifted one position further along.
      ax_n = n + 1 if is_sleep and n > 2 else n
      label_a = AUTOEVAL_MODEL_TO_LABEL.get(
          rater_a, _rater_id_to_label(rater_a)
      )
      label_b = AUTOEVAL_MODEL_TO_LABEL.get(
          rater_b, _rater_id_to_label(rater_b)
      )
      rater_pair_df = (
          df_pivot[[rater_a, rater_b]].dropna().reset_index(drop=True)
      )
      ax = axes[row_indices[ax_n] + row_offset][col_indices[ax_n]]
      used_axes.append(ax)
      plot_rating_contingency(
          rater_pair_df,
          col_a=rater_a,
          label_a=label_a,
          col_b=rater_b,
          label_b=label_b,
          ax=ax,
          color=colors[n],
      )

  # Hide every axis that did not receive a contingency table.
  for i in range(num_rows):
    for j in range(num_cols):
      ax = axes[i, j]
      if ax not in used_axes:
        ax.set_visible(False)

  # Add a, b labels to the first subplot of each vertical.
  labeled_axes = [axes[0][0], axes[2][0]]
  for i, ax in enumerate(labeled_axes):
    ax_label = f'{string.ascii_lowercase[i]}'
    trans = transforms.ScaledTranslation(-20 / 72, 7 / 72, fig.dpi_scale_trans)
    ax.text(
        -0.04,
        1.0,
        ax_label,
        transform=ax.transAxes + trans,
        fontsize='18',
        va='bottom',
        weight='bold',
    )
  fig.tight_layout()
  # Save before showing, matching the other plotting functions in this file.
  if filename:
    fig.savefig(filename, format='pdf', dpi=300)
  plt.show()


def _plot_pro_subfig(
    ax: plt.Axes,
    df: pd.DataFrame,
    df_prevalence: pd.DataFrame,
    metric: str,
    models: list[str],
    add_significance: bool,
    targets: list[str] | None = None,
    palette: dict[str, str] | None = None,
    fontsize: str = '18',
) -> None:
  """Draws one horizontal bar chart of a PRO metric, grouped by target.

  For the `'auc'` metric a dashed vertical line is drawn at 0.5. For any other
  metric the rows of `df_prevalence` are appended and plotted as extra bars.

  Args:
    ax: The axis to draw on.
    df: Metric rows; the `mean`, `ci_lower`, and `ci_upper` columns provide the
      bar lengths and error bars.
    df_prevalence: Prevalence rows, used only when `metric` is not `'auc'`.
    metric: Metric name; must be a key of `PRO_METRIC_TO_CONFIGS`.
    models: Model names, in the order their bars are drawn within a target.
    add_significance: Whether to add "*" markers for the targets not listed in
      the metric config's `non_significant_indices`.
    targets: Targets to plot, in order; defaults to `PRO_BINARY_TARGETS`.
    palette: Mapping from model to color; defaults to `PRO_MODEL_TO_COLOR`.
    fontsize: Font size for the legend, x label, and tick labels.
  """
  targets = PRO_BINARY_TARGETS if targets is None else targets
  palette = PRO_MODEL_TO_COLOR if palette is None else palette
  num_bars = len(models) * len(targets)
  metric_config = PRO_METRIC_TO_CONFIGS[metric]

  # Add the random-guess reference.
  if metric == 'auc':
    # Random guess has 0.5 success rate for AUROC.
    plt.sca(ax)
    plt.axvline(x=0.5, color='black', linestyle='--')
  else:
    # Add prevalence for AUPRC.
    df = pd.concat([df, df_prevalence], ignore_index=True)
    prevalence_models = list(df_prevalence[PRO_COL_MODEL].unique())
    # Build a new list so the caller's `models` is left untouched.
    models = models + prevalence_models
    num_bars += len(prevalence_models) * len(targets)

  # Main bar plot.
  sns.barplot(
      y=PRO_COL_TARGET,
      x='mean',
      hue=PRO_COL_MODEL,
      data=df,
      n_boot=100,
      palette=palette,
      order=targets,
      hue_order=models,
      orient='h',
      ax=ax,
  )

  # Error bars, centered vertically on each bar.
  # NOTE(review): this pairs the rows of `df` with the first `num_bars`
  # patches positionally, so it presumably relies on `df` being ordered the
  # same way seaborn draws the bars -- confirm.
  means = df['mean']
  lower_errors = (means - df['ci_lower']).values.flatten()
  upper_errors = (df['ci_upper'] - means).values.flatten()
  bar_centers = []
  for patch in ax.patches[:num_bars]:
    bar_centers.append(patch.get_xy()[1] + patch.get_height() / 2.0)
  _ = ax.errorbar(
      x=means,
      y=bar_centers,
      xerr=(lower_errors, upper_errors),
      fmt='o',
      color='black',
  )

  # Significance markers.
  if add_significance:
    non_significant = metric_config['non_significant_indices']
    # NOTE(review): for non-'auc' metrics `models` has the prevalence entries
    # appended by this point, so `models[-1]` is the last prevalence model
    # rather than the last caller-supplied model -- confirm this is intended.
    compared_models = [models[0], models[-1]]
    for target_idx in range(len(targets)):
      if target_idx in non_significant:
        continue
      y_coords = get_bar_mid_y_coords(compared_models, target_idx, ax)
      add_significance_y(
          ax, y_coords[0], y_coords[1], *metric_config['significant_loc'], '*'
      )

  # Legend, title, axis labels, and tick label sizes.
  ax.legend(
      loc=metric_config['legend_loc'],
      facecolor='white',
      edgecolor='white',
      framealpha=1,
      handlelength=0.7,
      borderpad=0.3,
      fontsize=fontsize,
  )
  ax.set_title('')
  ax.set_xlabel(metric_config['xlabel'], fontsize=fontsize)
  ax.set_ylabel('')
  for axis_name in ('x', 'y'):
    ax.tick_params(axis=axis_name, labelsize=fontsize)


def plot_pro_fig(
    models: list[str],
    df_pro: pd.DataFrame,
    df_pro_prevalence: pd.DataFrame,
    savefig_filepath: str,
    add_significance: bool = False,
    figsize: list[int] = [16, 10],
    subfig_idx_offset: int = 0,
    subfig_idx_pos: list[list[float]] = [[0.1, 0.9], [0.6, 0.9]],
    fontsize: str = '12',
) -> None:
  """Plots the prediction panels of the PRO main figure (Fig. 3 c,d).

  One subplot is drawn per metric in `PRO_METRIC_TO_CONFIGS`, side by side, and
  the figure is saved as a PDF before being shown.

  Args:
    models: Model names, in the order their bars are drawn.
    df_pro: PRO metrics; rows are selected per metric via `metric_name`.
    df_pro_prevalence: Prevalence rows forwarded to each subplot.
    savefig_filepath: Path of the PDF to write.
    add_significance: Whether to add "*" significance markers.
    figsize: Figure size in inches.
    subfig_idx_offset: Offset into the alphabet for the subplot letters (e.g. 2
      labels the subplots "c" and "d").
    subfig_idx_pos: Figure-coordinate `[x, y]` position of each subplot letter.
    fontsize: Font size forwarded to each subplot.
  """
  fig, axes = plt.subplots(1, 2, figsize=figsize)
  for subplot_idx, metric in enumerate(PRO_METRIC_TO_CONFIGS):
    metric_rows = df_pro[df_pro['metric_name'] == metric]
    _plot_pro_subfig(
        ax=axes[subplot_idx],
        df=metric_rows.reset_index().copy(),
        df_prevalence=df_pro_prevalence,
        metric=metric,
        models=models,
        add_significance=add_significance,
        fontsize=fontsize,
    )

  # Add a,b,c,d like labels to subfigures, skipping any heatmap color bar.
  axes_to_label = []
  for candidate_ax in fig.get_axes():
    if candidate_ax.get_label() != '<colorbar>':
      axes_to_label.append(candidate_ax)
  for label_idx, labeled_ax in enumerate(axes_to_label):
    letter = string.ascii_lowercase[label_idx + subfig_idx_offset]
    label_shift = transforms.ScaledTranslation(
        -20 / 72, 7 / 72, fig.dpi_scale_trans
    )
    letter_position = subfig_idx_pos[label_idx]
    labeled_ax.text(
        letter_position[0],
        letter_position[1],
        letter,
        transform=fig.transFigure + label_shift,
        fontsize='16',
        va='bottom',
        weight='bold',
    )

  plt.tight_layout()
  fig.savefig(savefig_filepath, format='pdf', dpi=300)
  plt.show()

## Load human expert and AutoEval ratings

In [None]:
# Each ratings source is loaded for the sleep vertical first, then fitness.

# Human expert ratings for internal models.
g_sleep_df, g_fitness_df = (
    load_ratings_df(vertical, RatingsSource.HUMAN_EXPERT)
    for vertical in (Vertical.SLEEP, Vertical.FITNESS)
)

# AutoEval ratings for internal and external models.
g_sleep_external_df, g_fitness_external_df = (
    load_ratings_df(vertical, RatingsSource.AUTOEVAL_EXTERNAL)
    for vertical in (Vertical.SLEEP, Vertical.FITNESS)
)

# AutoEval ratings for subsampled PH-LLM models.
g_sleep_subsample_df, g_fitness_subsample_df = (
    load_ratings_df(vertical, RatingsSource.AUTOEVAL_SUBSAMPLE)
    for vertical in (Vertical.SLEEP, Vertical.FITNESS)
)

## Plot figures

### Fig. 2

**Long-form case study evaluation and performance.** [...] c,d, Mean ratings
given by experts for the case study subsections across the sleep (c) and fitness
(d) domains. Error bars represent 95% confidence intervals. “∗” indicates a
statistically significant difference between two response types after multiple
hypothesis testing correction. REM: Rapid Eye Movement, HRV RMSSD: Heart Rate
Variability Root Mean Square of Successive Differences, ACWR: Acute:Chronic
Workload Ratio.

In [None]:
# Note: Since we do not include Fig. 2's case study examples as the first two
# subfigures here, the subplots in the figure below are annotated "a" and "b"
# (they correspond to panels c and d of the manuscript figure, hence the
# output filename).
plot_case_study_main_fig(
    sleep_df=g_sleep_df,
    fitness_df=g_fitness_df,
    sleep_title=Vertical.SLEEP.value.capitalize(),
    fitness_title=Vertical.FITNESS.value.capitalize(),
    sleep_rating_col=COL_RATING,
    fitness_rating_col=COL_RATING,
    hue_order=CONVERSATION_SOURCE_ORDER_MAIN,
    label_subplots=True,
    savefig_filepath='figure_2_cd.pdf',
)

### Fig. 3
**Prediction of patient-reported outcomes by PH-LLM.** [...] c, Area under the receiver operating characteristic curve (AUROC) for the performance of PH-LLM with adapter,
zero-shot, and few-shot prompting approaches when predicting binary outcomes derived from survey responses. The
dotted vertical line denotes the AUROC of the random predictor. Outcomes for which the performance of PH-LLM
with adapter approach is significantly better than both zero- and few-shot are annotated with “*”. d, Area under the
precision-recall curve (AUPRC) for the performance of PH-LLM with adapter, zero-shot, and few-shot prompting
approaches when predicting binary outcomes derived from survey responses. Outcome-specific prevalence bars are
added to show the AUPRC of the random predictor. Survey response names are mapped to their corresponding questions
in Supplementary Tables 39 and 40. “SI”, sleep impairment. Error bars represent 95% confidence intervals.

In [None]:
# Load PRO data. The order of `models` is also the order in which each model's
# bar is drawn within a target.
models = ['PH-LLM w/ Adapter', 'PH-LLM Few-shot', 'PH-LLM Zero-shot']
df_pro, df_pro_prevalence = load_pro_df(models)
# Plot one subplot per PRO metric and save the figure as a PDF.
plot_pro_fig(
    models,
    df_pro,
    df_pro_prevalence,
    savefig_filepath='figure_3_cd.pdf',
    add_significance=True,
    subfig_idx_offset=2,  # This changes indices from a,b to c,d.
    fontsize='12',
)

### Extended Data Fig. 3

**Pairwise Gwet's AC2 measuring inter-rater reliability between primary and
secondary raters.** Metrics were computed using all ratings for each principle
and section across case studies rated by more than one rater in the sleep (a)
and fitness (b) domains. The number of overlapping ratings is denoted by n. Mean
metrics and 95% confidence intervals derived from 1,000 bootstrapping iterations
are reported for each pair.

In [None]:
# Compute the rater pair metrics for each vertical, then plot only the Gwet's
# AC2 heatmaps (sleep on the left, fitness on the right).
g_sleep_bootstrap_metrics_df = bootstrap_rater_pairs_binary(g_sleep_df)
g_fitness_bootstrap_metrics_df = bootstrap_rater_pairs_binary(g_fitness_df)
plot_rater_pair_stats_binary_only_gwet(
    g_sleep_bootstrap_metrics_df,
    g_fitness_bootstrap_metrics_df,
    filename='extended_data_figure_3.pdf',
)

### Extended Data Fig. 4

**Contingency tables showing pairwise rating agreement between raters.** Counts
are aggregated across all case studies, sections, and principles for each case
study for which multiple ratings are available in the sleep (a) and fitness (b)
domains. Blue, primary versus primary raters. Green, primary versus secondary
raters. Yellow, secondary versus secondary raters.

In [None]:
# The two rater type lists are zipped into the pairings (primary, primary),
# (primary, secondary), and (secondary, secondary), which are plotted in that
# order in blue, green, and yellow respectively.
plot_rater_contingency_all_both_verticals(
    g_sleep_df,
    g_fitness_df,
    filename='extended_data_figure_4.pdf',
    rater_types_a=[RaterType.PRIMARY, RaterType.PRIMARY, RaterType.SECONDARY],
    rater_types_b=[RaterType.PRIMARY, RaterType.SECONDARY, RaterType.SECONDARY],
)

### Extended Data Fig. 5

**Sleep and fitness case study human evaluation results by principle.** Mean
ratings given by experts for different case study evaluation principles across
all sections in the sleep (a) and fitness (b) domains. The principles are
ordered according to the rubric presented in Supplementary Table 9. “∗”
indicates a statistically significant difference between two response types
after multiple hypothesis testing correction. Error bars represent 95%
confidence intervals.

In [None]:
# Plot the human expert ratings by principle for both verticals, using the same
# human expert dataframes as Fig. 2.
plot_by_principle_all(
    sleep_df=g_sleep_df,
    fitness_df=g_fitness_df,
    sleep_title=Vertical.SLEEP.value.capitalize(),
    fitness_title=Vertical.FITNESS.value.capitalize(),
    order=PRINCIPLE_ORDER,
    # Note: We reverse the order since vertical bar plots are bottom-to-top.
    hue_order=CONVERSATION_SOURCE_ORDER_MAIN[::-1],
    other_font_size=10,
    add_stat_sig_annot=True,
    label_subplots=True,
    savefig_filepath='extended_data_figure_5.pdf',
)

### Extended Data Fig. 6

**Contingency tables showing pairwise rating agreement between our best
AutoRaters, their corresponding expert raters, and other experts.** Counts are
aggregated across all case studies, sections, and principles for each case study
for which at least one rating from the AutoEval training rater is available in
the sleep (a) and fitness (b) domains. Blue, the primary expert rater versus
other raters. Green, the AutoEval model trained on primary expert ratings versus
other raters. Yellow, the primary expert rater versus the corresponding AutoEval
model.

In [None]:
# For each vertical, the AutoEval ratings are restricted to rows rated by the
# given `*_autoeval_rater_id`, and the AutoEval rating column name doubles as
# the AutoEval model's rater ID in the resulting contingency tables.
plot_rater_contingency_autorater_rater_overlap_both(
    sleep_df=g_sleep_df,
    sleep_autoeval_df=g_sleep_external_df,
    sleep_autoeval_rating_col=COL_RATING_AUTOEVAL_SLEEP,
    sleep_autoeval_rater_id='sleep_primary_c',
    fitness_df=g_fitness_df,
    fitness_autoeval_df=g_fitness_external_df,
    fitness_autoeval_rating_col=COL_RATING_AUTOEVAL_FITNESS,
    fitness_autoeval_rater_id='fitness_primary_c',
    filename='extended_data_figure_6.pdf',
)

### Extended Data Fig. 7

**Automatic evaluation of coaching recommendations across PH-LLM and baseline
models.** Mean ratings were generated using our best AutoEval models in the
sleep (a) and fitness (b) domains. Within each section, a “*” denotes the
highest-rated model and all models not statistically significantly different
from that model after multiple hypothesis testing correction. Error bars
represent 95% confidence intervals.

In [None]:
# Extended Data Fig. 7 cell: calls plot_case_study_main_fig on the
# external-model comparison frames (g_sleep_external_df /
# g_fitness_external_df), pointing the rating columns at the AutoEval rating
# columns and ordering hues by CONVERSATION_SOURCE_ORDER_EXTERNAL.
plot_case_study_main_fig(
    sleep_df=g_sleep_external_df,
    fitness_df=g_fitness_external_df,
    sleep_title=Vertical.SLEEP.value.capitalize(),
    fitness_title=Vertical.FITNESS.value.capitalize(),
    sleep_rating_col=COL_RATING_AUTOEVAL_SLEEP,
    fitness_rating_col=COL_RATING_AUTOEVAL_FITNESS,
    hue_order=CONVERSATION_SOURCE_ORDER_EXTERNAL,
    label_subplots=True,
    # NOTE(review): presumably selects the "*" best-model annotation described
    # in the caption instead of pairwise significance brackets -- confirm in
    # plot_case_study_main_fig.
    pairwise_statsig=False,
    savefig_filepath='extended_data_figure_7.pdf',
)

### Extended Data Fig. 8

**Effect of fine-tuning data scale on model performance in coaching
recommendations.** Ratings are obtained via the best AutoEval models for the
holdout case study subsections in the sleep (a) and fitness (b) domains.
“PH-LLM” denotes standard performance while “Subsampled 25%” and “Subsampled
50%” denote responses from models trained on 25% and 50% of the training
dataset, respectively. “Gemini Ultra 1.0” denotes untuned baseline performance
(i.e., trained on 0% of the training dataset). Within each section, a “*”
denotes the highest-rated model and all models not statistically significantly
different from that model after multiple hypothesis testing correction. Error
bars represent 95% confidence intervals.

In [None]:
# Extended Data Fig. 8 cell: same call as the Extended Data Fig. 7 cell, but on
# the subsample frames (g_sleep_subsample_df / g_fitness_subsample_df) with
# hues ordered by CONVERSATION_SOURCE_ORDER_SUBSAMPLE.
plot_case_study_main_fig(
    sleep_df=g_sleep_subsample_df,
    fitness_df=g_fitness_subsample_df,
    sleep_title=Vertical.SLEEP.value.capitalize(),
    fitness_title=Vertical.FITNESS.value.capitalize(),
    sleep_rating_col=COL_RATING_AUTOEVAL_SLEEP,
    fitness_rating_col=COL_RATING_AUTOEVAL_FITNESS,
    hue_order=CONVERSATION_SOURCE_ORDER_SUBSAMPLE,
    label_subplots=True,
    # NOTE(review): presumably selects the "*" best-model annotation described
    # in the caption instead of pairwise significance brackets -- confirm in
    # plot_case_study_main_fig.
    pairwise_statsig=False,
    savefig_filepath='extended_data_figure_8.pdf',
)

### Extended Data Fig. 9

**Performance of PH-LLM and traditional ML models on patient-reported outcomes
prediction.** We compared the ability of PH-LLM (with and without a multimodal
adapter), logistic regression, and a convolutional neural network (CNN) to infer
subjective patient-reported outcomes. (a) Area under the receiver operating
characteristic curve (AUROC). (b) Area under the precision-recall curve (AUPRC).
Error bars represent 95% confidence intervals. The CNN underperforms logistic
regression, likely due to the limited size of the dataset.

In [None]:
# Load PRO data.
# The same `models` list is passed to both the loader and the plotting call.
models = ['PH-LLM w/ Adapter', 'PH-LLM Few-shot', 'PH-LLM Zero-shot', 'LogReg', 'CNN']
# load_pro_df returns two objects, unpacked here.
# NOTE(review): presumably per-model PRO metrics and outcome prevalence, going
# by the variable names -- confirm against load_pro_df.
df_pro, df_pro_prevalence = load_pro_df(models)
# Plot.
plot_pro_fig(
    models,
    df_pro,
    df_pro_prevalence,
    savefig_filepath='extended_data_figure_9.pdf',
    figsize=[22, 22],
    # NOTE(review): presumably the positions of the two subfigure labels
    # (panels a and b) -- confirm units/coordinates in plot_pro_fig.
    subfig_idx_pos=[[0.05, 0.92], [0.55, 0.92]],
    # NOTE(review): font size is passed as a string, not a number -- confirm
    # plot_pro_fig expects this.
    fontsize='16',
)