Code for figure 1 and appendix figure showing similar PPG signals

In [None]:
pip install scikit-learn

In [4]:
import os
import pickle
#import torch
#import torch.nn as nn
#import torch.nn.functional as F
#from scipy.signal import periodogram
import numpy as np
import random
from tqdm import tqdm
import pandas as pd
from sklearn.model_selection import train_test_split

# Reproducibility: seed both global RNGs used by this notebook (NumPy's legacy
# global generator and Python's `random`) with the same value.
seed = 0
np.random.seed(seed)
random.seed(seed)
# torch seeding stays disabled together with the torch imports above:
#   torch.manual_seed(seed)
#   if torch.cuda.is_available():
#       torch.cuda.manual_seed(seed)
#   torch.backends.cudnn.deterministic = True

# Device label; no other cell in the visible notebook reads it.
device = "cpu"

In [5]:
def parse_ppg(ppg_str):
    """Extract the y-values from a pipe-delimited PPG string.

    The string is split on "|". The first field is always skipped (its
    meaning is not visible here; presumably a header or count, to be
    confirmed against the data dictionary). Every remaining field that is
    not blank must look like "<x>,<y>"; only the y part is kept.

    Parameters
    ----------
    ppg_str : str
        Raw value such as "h|0,1.5|1,2.5|".

    Returns
    -------
    numpy.ndarray
        1-D float array of the y values, in their original order.

    Raises
    ------
    ValueError
        If a non-blank field does not split into exactly two comma-separated
        items, or if the y item cannot be converted to float.
    """
    _first, *fields = ppg_str.split("|")
    samples = []
    for field in fields:
        if field.strip():
            # Unpacking enforces exactly one comma per field.
            _x, y_text = field.split(",")
            samples.append(float(y_text))
    return np.array(samples)

# Load the raw table and keep only the rows that have a PPG string.
# .copy() makes the filtered frame independent of the original, so adding the
# 'y_values' column below does not write into a slice (which can trigger
# pandas' SettingWithCopyWarning).
df = pd.read_csv("data.csv")
df = df[df['p4205_i0'].notna()].copy()
#df = df.sample(n=10000, random_state=seed)
# Parse each PPG string into a 1-D float array.
df['y_values'] = df['p4205_i0'].apply(parse_ppg)

# np.stack requires every parsed signal to have the same length.
X = np.stack(df['y_values'].values)
# Target column; a later cell buckets and labels it as age.
Y = df['p21003_i0'].values

# 80/20 train/test split
#X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=seed)

# 70/15/15 train/val/test: hold out 30%, then split that half/half.
X_train, X_temp, Y_train, Y_temp = train_test_split(X, Y, test_size=0.3, random_state=seed)
X_val, X_test, Y_val, Y_test = train_test_split(X_temp, Y_temp, test_size=0.5, random_state=seed)


In [None]:
import itertools
import math

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.lines import Line2D
from matplotlib.patches import ConnectionPatch

# Settings for the similar-signal search in this cell.
# NOTE: this rebinds the notebook-level `seed` (0 in the first cell) to 1.
seed = 1
rng = np.random.default_rng(seed)  # generator for the pairwise-MSE subsampling
n_seed_per_bucket = 1000           # max seed signals sampled from each age bucket
n_preview_groups = 100             # how many best-scoring groups are kept
chunk = 6500                       # rows compared per step inside best_match
mse_percentile_threshold = 1       # percentile used to derive target_mse

# Five-year age buckets, left-closed / right-open: [40, 45), ..., [65, 70).
bins = np.array([40, 45, 50, 55, 60, 65, 70])
labels = ["40-45", "45-50", "50-55", "55-60", "60-65", "65-70"]
df["age_bucket"] = pd.cut(df["p21003_i0"], bins=bins, labels=labels, right=False)
# Rows whose age falls outside [40, 70) get NaN from pd.cut and are dropped.
df = df[df["age_bucket"].notna()].copy()

# Per-bucket lookups, all built in one groupby pass:
#   full_X[b]   -> 2-D array of every signal in bucket b
#   full_age[b] -> ages aligned row-for-row with full_X[b]
#   seed_X[b]   -> up to n_seed_per_bucket randomly sampled signals from b
# observed=False is spelled out because "age_bucket" is categorical and relying
# on the implicit default is deprecated in recent pandas.
full_X, full_age, seed_X = {}, {}, {}
for _bucket, _bucket_df in df.groupby("age_bucket", observed=False):
    full_X[_bucket] = np.stack(_bucket_df["y_values"].values)
    full_age[_bucket] = _bucket_df["p21003_i0"].values
    _n_seed = min(len(_bucket_df), n_seed_per_bucket)
    seed_X[_bucket] = np.stack(_bucket_df.sample(_n_seed, random_state=seed)["y_values"].values)

def best_match(x_signal, X_bucket_signals, chunk_size=None):
    """Find the row of `X_bucket_signals` with the lowest MSE to `x_signal`.

    The candidates are scanned in slices of `chunk_size` rows so that the
    intermediate difference array stays bounded in size.

    Parameters
    ----------
    x_signal : numpy.ndarray
        Query signal. A 1-D array is reshaped to a single row; anything else
        is used as given and must broadcast against the candidate rows.
    X_bucket_signals : numpy.ndarray
        2-D array with one candidate signal per row.
    chunk_size : int, optional
        Rows compared per step. Defaults to the notebook-level `chunk`
        setting, which keeps existing two-argument calls unchanged.

    Returns
    -------
    (best_idx, best_mse)
        Row index of the closest candidate and its MSE. On ties the earliest
        row wins. Returns (-1, inf) when there are no candidate rows or no
        slice yields a value below infinity (e.g. every slice contains NaN).
    """
    if chunk_size is None:
        chunk_size = chunk  # notebook-level default

    query = x_signal.reshape(1, -1) if x_signal.ndim == 1 else x_signal
    best_idx, best_mse = -1, np.inf

    # range() never yields an empty slice, so no emptiness checks are needed.
    for start in range(0, X_bucket_signals.shape[0], chunk_size):
        block_mse = ((query - X_bucket_signals[start:start + chunk_size]) ** 2).mean(axis=1)
        local_best = block_mse.argmin()
        # Strict '<' keeps the earliest row on ties; a NaN minimum compares
        # False, so that slice is skipped.
        if block_mse[local_best] < best_mse:
            best_mse = block_mse[local_best]
            best_idx = start + local_best
    return best_idx, best_mse

# Reference distribution of cross-bucket MSEs: for every pair of age buckets,
# draw up to 200 seed signals from each side (without replacement) and compute
# the MSE of every cross pair. target_mse is a low percentile of the pooled
# values; it is not read again in the visible part of this notebook.
bucket_pairs = list(itertools.combinations(labels, 2))
pairs = []
for b1, b2 in bucket_pairs:
    pool_1, pool_2 = seed_X[b1], seed_X[b2]
    take_1 = min(len(pool_1), 200)
    take_2 = min(len(pool_2), 200)

    # Draw order (first bucket, then second) fixes how rng is consumed.
    rows_1 = pool_1[rng.choice(len(pool_1), take_1, replace=False)]
    rows_2 = pool_2[rng.choice(len(pool_2), take_2, replace=False)]

    # Broadcasting gives a (take_1, take_2, length) array of squared differences.
    squared_diff = (rows_1[:, None] - rows_2) ** 2
    pairs.append(squared_diff.mean(axis=2).ravel())

all_mse = np.concatenate(pairs)
target_mse = np.percentile(all_mse, mse_percentile_threshold)

# Build candidate groups: for every seed signal take its nearest neighbour
# (lowest MSE) in each age bucket, then score the group by the sum of pairwise
# MSEs between its members (lower score = members more alike).
groups = []
for sb_starting_bucket in labels:
    print(sb_starting_bucket)
    # Match the seed's own bucket first, then the others in label order.
    match_order = [sb_starting_bucket] + [l_other for l_other in labels if l_other != sb_starting_bucket]

    for x_seed_signal in seed_X[sb_starting_bucket]:
        g_indices = {}
        for b_target_bucket in match_order:
            idx, _ = best_match(x_seed_signal, full_X[b_target_bucket])
            if idx == -1:
                # No usable match in this bucket, so the group is incomplete.
                break
            g_indices[b_target_bucket] = idx

        # Skip incomplete groups: indexing g_indices by every label below
        # would otherwise raise KeyError.
        if len(g_indices) != len(labels):
            continue

        group_signals_list = [full_X[b_label][g_indices[b_label]] for b_label in labels]
        first_signal_len = group_signals_list[0].shape[0]
        if all(s.shape[0] == first_signal_len for s in group_signals_list):
            sigs_matrix = np.vstack(group_signals_list)
            # Pairwise differences between all members; the strict upper
            # triangle counts each unordered pair exactly once.
            d = sigs_matrix[:, None, :] - sigs_matrix[None, :, :]
            score = np.triu(((d**2).mean(axis=2)), 1).sum()
            groups.append((g_indices, score))

# Keep only the n_preview_groups groups with the lowest score.
groups = sorted(groups, key=lambda x: x[1])[:n_preview_groups]

# Select groups
# Group 1 is the best-scoring group. Group 2 is the candidate, among the next
# best (up to n_preview_groups // 2 positions of the ranking), whose mean
# signal has the largest MSE to group 1's mean signal.
group1_dict = groups[0][0]
mean_signal_g1 = np.vstack([full_X[b][group1_dict[b]] for b in labels]).mean(axis=0)

group2_dict = None
max_mse_from_g1 = -1
search_limit = min(len(groups), n_preview_groups // 2)

for current_candidate_dict, _ in groups[1:search_limit]:
    candidate_signals_list = [full_X[b_label][current_candidate_dict[b_label]] for b_label in labels]
    mean_signal_candidate = np.vstack(candidate_signals_list).mean(axis=0)
    mse_with_g1 = np.mean((mean_signal_g1 - mean_signal_candidate) ** 2)
    # Strict '>' keeps the earliest candidate on ties.
    if mse_with_g1 > max_mse_from_g1:
        max_mse_from_g1, group2_dict = mse_with_g1, current_candidate_dict

# Plotting
def _draw_raw_group(fig, gs, ax_main, group_dict, color, inset_row, inset_anchor_y, x_connect_point):
    """Overlay one group's signals on the main axis and add one inset per age bucket.

    For each bucket in `labels`, the group's signal is drawn on `ax_main` and
    again in its own inset at grid row `inset_row`; a dotted connector links
    the main curve at `x_connect_point` to the inset at axes-fraction height
    `inset_anchor_y`.
    """
    for j, b_label in enumerate(labels):
        idx_in_bucket = group_dict[b_label]
        signal = full_X[b_label][idx_in_bucket]
        age = full_age[b_label][idx_in_bucket]

        line, = ax_main.plot(signal, color=color, alpha=0.7, linewidth=2.5, zorder=5)

        ax_inset = fig.add_subplot(gs[inset_row, j])
        ax_inset.plot(signal, color=color, linewidth=1.5)
        ax_inset.set_title(f"Age {int(age)}", fontsize=20, pad=3)

        y_connect_point = line.get_ydata()[x_connect_point]
        con = ConnectionPatch(xyA=(x_connect_point, y_connect_point), xyB=(0.5, inset_anchor_y),
                              coordsA='data', coordsB='axes fraction',
                              axesA=ax_main, axesB=ax_inset,
                              color='gray', linestyle=':', linewidth=1, alpha=0.8, zorder=1)
        fig.add_artist(con)


# Layout: insets for group 1 (top row), shared main axis (middle), insets for
# group 2 (bottom row).
fig = plt.figure(figsize=(22, 16))
gs = fig.add_gridspec(3, len(labels), height_ratios=[1.5, 3, 1.5], width_ratios=[1]*len(labels),
                      hspace=0.6, wspace=0.4)

ax_main = fig.add_subplot(gs[1, :])
color1, color2 = 'royalblue', 'firebrick'
# `signal_length` was previously undefined (NameError on a fresh kernel);
# derive it from the stacked signals, whose second axis is the sample axis.
signal_length = full_X[labels[0]].shape[1]
x_connect_point = signal_length // 2  # connectors attach at the curve midpoint

# Plot Group 1 (connectors end near the bottom edge of the top insets)
_draw_raw_group(fig, gs, ax_main, group1_dict, color1, inset_row=0, inset_anchor_y=0.05,
                x_connect_point=x_connect_point)

# Plot Group 2 (connectors end near the top edge of the bottom insets)
if group2_dict:
    _draw_raw_group(fig, gs, ax_main, group2_dict, color2, inset_row=2, inset_anchor_y=0.95,
                    x_connect_point=x_connect_point)

ax_main.set_title("Comparison of Two Distinct PPG Signal Groups", fontsize=20, pad=25)
ax_main.set_xlabel("Time Index", fontsize=20)
ax_main.set_ylabel("PPG Amplitude", fontsize=20)
ax_main.tick_params(axis='both', which='major', labelsize=20)

# One legend entry per group instead of one per plotted curve.
legend_elements = [Line2D([0], [0], color=color1, lw=3, label='Group 1 Signals')]
if group2_dict:
    legend_elements.append(Line2D([0], [0], color=color2, lw=3, label='Group 2 Signals'))
ax_main.legend(handles=legend_elements, loc='upper right', fontsize=20)

fig.suptitle("PPG Signal Analysis: Intra-Group Similarity vs. Inter-Group Dissimilarity", fontsize=22, y=0.97)
plt.subplots_adjust(top=0.90, bottom=0.05, left=0.05, right=0.95)
plt.show()


In [None]:
import matplotlib.pyplot as plt
from matplotlib.patches import ConnectionPatch
from matplotlib.lines import Line2D
import numpy as np 

# Plotting
def _scale_by_peak(signal):
    """Return `signal` as an ndarray divided by its maximum.

    The result peaks at 1 but its minimum is not shifted to 0. Signals whose
    maximum is not positive are returned unscaled.
    """
    signal = np.asarray(signal)
    signal_max = np.max(signal)
    return signal / signal_max if signal_max > 0 else signal


def _draw_normalized_group(fig, gs, ax_main, group_dict, color, inset_row, inset_anchor_y, x_connect_point):
    """Overlay one group's peak-scaled signals on the main axis plus one inset per bucket.

    For each bucket in `labels`, the group's signal is scaled by its peak,
    drawn on `ax_main` and in its own inset at grid row `inset_row`; a dotted
    connector links the main curve to the inset at axes-fraction height
    `inset_anchor_y`.
    """
    for j, b_label in enumerate(labels):
        idx_in_bucket = group_dict[b_label]
        normalized_signal = _scale_by_peak(full_X[b_label][idx_in_bucket])
        age = full_age[b_label][idx_in_bucket]

        line, = ax_main.plot(normalized_signal, color=color, alpha=0.7, linewidth=2.5, zorder=5)

        ax_inset = fig.add_subplot(gs[inset_row, j])
        ax_inset.plot(normalized_signal, color=color, linewidth=1.5)
        ax_inset.set_title(f"Age {int(age)}", fontsize=20, pad=3)
        ax_inset.tick_params(axis='both', which='major', labelsize=22)
        ax_inset.set_ylim(-0.05, 1.05)

        # Clamp so a signal shorter than the requested index still gets a
        # connector instead of raising IndexError.
        anchor_x = min(x_connect_point, len(normalized_signal) - 1)
        y_connect_point = line.get_ydata()[anchor_x]
        con = ConnectionPatch(xyA=(anchor_x, y_connect_point), xyB=(0.5, inset_anchor_y),
                              coordsA='data', coordsB='axes fraction',
                              axesA=ax_main, axesB=ax_inset,
                              color='gray', linestyle=':', linewidth=1, alpha=0.8, zorder=1)
        fig.add_artist(con)


# Layout: insets for group 1 (top row), shared main axis (middle), insets for
# group 2 (bottom row).
fig = plt.figure(figsize=(22, 16))
gs = fig.add_gridspec(3, len(labels), height_ratios=[1.5, 3, 1.5], width_ratios=[1]*len(labels),
                      hspace=0.6, wspace=0.4)

ax_main = fig.add_subplot(gs[1, :])
color1, color2 = 'royalblue', 'firebrick'
x_connect_point = 75  # sample index on the main curve where connectors attach

# Plot Group 1 (connectors end near the bottom edge of the top insets)
_draw_normalized_group(fig, gs, ax_main, group1_dict, color1, inset_row=0, inset_anchor_y=0.05,
                       x_connect_point=x_connect_point)

# Plot Group 2 (connectors end near the top edge of the bottom insets)
if group2_dict:
    _draw_normalized_group(fig, gs, ax_main, group2_dict, color2, inset_row=2, inset_anchor_y=0.95,
                           x_connect_point=x_connect_point)

ax_main.set_title("Comparison of Two Distinct PPG Signal Groups", fontsize=28, pad=25)
ax_main.set_xlabel("Time point", fontsize=28)
ax_main.set_ylabel("Normalized PPG Amplitude (0-1)", fontsize=28)
ax_main.tick_params(axis='both', which='major', labelsize=25)
ax_main.set_ylim(-0.05, 1.05)

# One legend entry per group instead of one per plotted curve.
legend_elements = [Line2D([0], [0], color=color1, lw=3, label='Group 1 Signals')]
if group2_dict:
    legend_elements.append(Line2D([0], [0], color=color2, lw=3, label='Group 2 Signals'))
ax_main.legend(handles=legend_elements, loc='upper right', fontsize=25)

fig.suptitle("PPG Signal Analysis: Intra-Group Similarity vs. Inter-Group Dissimilarity", fontsize=30, y=0.97)
plt.subplots_adjust(top=0.90, bottom=0.05, left=0.05, right=0.95)
plt.savefig("figure1_updated_normalized.png", dpi=300, bbox_inches='tight')


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import itertools
import math
from matplotlib.lines import Line2D

# Appendix layout: number of panels and how many groups each one shows.
num_panels_desired = 3
groups_per_panel = 4
total_unique_groups_needed = num_panels_desired * groups_per_panel

# Signals identified as (bucket label, row index). Anything already shown in
# the main figure's groups is excluded from the appendix.
used_signals_for_appendix = set()
appendix_groups_to_plot_indices = []

for _main_name in ('group1_dict', 'group2_dict'):
    # Tolerate the main-figure groups being undefined or empty/None.
    _main_group = locals().get(_main_name)
    if _main_group:
        used_signals_for_appendix.update(
            (str(bucket_label), int(signal_idx)) for bucket_label, signal_idx in _main_group.items()
        )

# Walk the groups in ranked order and keep each one that shares no signal
# with anything selected so far, until enough groups are collected.
for candidate_group, _score in groups:
    candidate_ids = {
        (str(bucket_label), int(signal_idx)) for bucket_label, signal_idx in candidate_group.items()
    }
    if not candidate_ids.isdisjoint(used_signals_for_appendix):
        continue

    appendix_groups_to_plot_indices.append(candidate_group)
    used_signals_for_appendix |= candidate_ids
    if len(appendix_groups_to_plot_indices) >= total_unique_groups_needed:
        break

# Number of panels actually needed for the groups that were selected.
num_panels_actual = 0
if groups_per_panel > 0 and len(appendix_groups_to_plot_indices) > 0:
    num_panels_actual = math.ceil(len(appendix_groups_to_plot_indices) / groups_per_panel)

if num_panels_actual > 0:
    fig_appendix, appendix_axs = plt.subplots(num_panels_actual, 1, figsize=(18, 7 * num_panels_actual), sharex=True)
    if num_panels_actual == 1:
        # plt.subplots returns a bare Axes for a single panel; wrap it so the
        # indexing below works uniformly.
        appendix_axs = [appendix_axs]

    # Appendix numbering starts at 3, continuing after the main figure's
    # "Group 1" and "Group 2".
    global_group_number_counter = 3

    for i in range(num_panels_actual):
        ax_panel = appendix_axs[i]
        start_idx = i * groups_per_panel
        # The panel count is a ceiling of groups / groups_per_panel, so every
        # panel receives at least one group.
        panel_group_indices_list = appendix_groups_to_plot_indices[start_idx:start_idx + groups_per_panel]

        for current_g_indices in panel_group_indices_list:
            group_ages = [str(int(full_age[b_label][current_g_indices[b_label]])) for b_label in labels]
            legend_label_ages_full = ", ".join(group_ages)
            panel_legend_label = f"Group {global_group_number_counter} (Ages: {legend_label_ages_full})"
            global_group_number_counter += 1

            group_color = None
            for bucket_label in labels:
                signal = full_X[bucket_label][current_g_indices[bucket_label]]

                # Scale by the peak with the same guard as the main figure:
                # only a positive maximum is used (NaN compares False), so a
                # non-positive peak cannot flip or blow up the curve.
                signal_max = np.max(signal)
                if signal_max > 0:
                    signal = signal / signal_max

                if group_color is None:
                    # The first curve takes the next colour from the axes'
                    # cycle and carries the legend label for the whole group.
                    first_line, = ax_panel.plot(signal, alpha=0.8, linewidth=2, label=panel_legend_label)
                    group_color = first_line.get_color()
                else:
                    ax_panel.plot(signal, alpha=0.8, linewidth=2, color=group_color)

        ax_panel.set_title(f"Appendix Panel {i+1}: {len(panel_group_indices_list)} Distinct PPG Signal Groups", fontsize=22)
        # The curves are peak-normalized, so the axis label says so.
        ax_panel.set_ylabel("Normalized PPG Amplitude", fontsize=20)
        ax_panel.tick_params(axis='both', which='major', labelsize=18)
        ax_panel.grid(True, linestyle='--', alpha=0.6)
        ax_panel.legend(loc='upper right', fontsize=14)

    appendix_axs[-1].set_xlabel("Time Point", fontsize=20)
    fig_appendix.suptitle("Appendix: Additional Examples of Distinct Age-Persistent PPG Signal Groups", fontsize=28, y=0.99)
    plt.tight_layout(rect=[0, 0, 1, 0.97])
    plt.savefig("figure1_appendix_updated.png", dpi=300, bbox_inches='tight')

