### Imports

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.gridspec as gridspec
from matplotlib.transforms import Bbox
from matplotlib.lines import Line2D
import matplotlib.colors as mcolors
from pathlib import Path
import yaml
from zoneinfo import ZoneInfo
from datetime import timedelta, datetime
import json
from tqdm import tqdm
from scipy.optimize import curve_fit
from scipy import stats
from scipy.ndimage import gaussian_filter1d

import utils
import json_file_utils
import stats_utils
import artifact_correct

%load_ext autoreload
%autoreload 2

## Setup

In [None]:
# Load the patient registry from YAML and index each patient record by its ID.
with open('patient_info.yaml') as stream:
    patient_info = yaml.safe_load(stream)
patient_dict = {}
for patient_record in patient_info['patients']:
    patient_dict[patient_record['id']] = patient_record

# Patient IDs, in the order they appear in the YAML file.
ids = [*patient_dict]

# Time zone used for all local-time analyses.
central_time = ZoneInfo('America/Chicago')

### Read in raw data from JSONs on Elias

In [None]:
changes_dfs = []
pt_raw_dfs = []
for pt_id in ids:
    # Get path to JSON files containing data for the current patient.
    pt_percept_dir = Path(patient_dict[pt_id]['data_path'])
    jsons = json_file_utils.get_json_filenames(pt_percept_dir)

    raw_data_list = []
    for filename in tqdm(jsons):
        # Read LFP and stim data from JSON file. Unreadable or malformed files are still
        # skipped (best effort), but each skip is reported so lost data is visible.
        try:
            raw = json_file_utils.chronic_lfp_from_json(filename)
        except (PermissionError, json.JSONDecodeError) as e:
            # tqdm.write prints without breaking the progress bar.
            tqdm.write(f'skipping {filename} ({type(e).__name__}: {e})')
            continue

        # Keep only files that actually yielded rows.
        if not raw.empty:
            raw_data_list.append(raw)

    # Nothing usable for this patient: say so explicitly instead of relying on
    # pd.concat raising ValueError for an empty list.
    if not raw_data_list:
        print(f'no chronic LFP power data from {pt_id}')
        continue

    # Concatenate individual files' data into a single dataframe per patient
    pt_raw_df = pd.concat(raw_data_list, ignore_index=True)
    pt_raw_df['pt_id'] = pt_id
    if pt_raw_df.size != 0:
        pt_raw_dfs.append(pt_raw_df)

# One dataframe holding every patient's raw rows, tagged with pt_id.
raw_df = pd.concat(pt_raw_dfs, ignore_index=True)

**VC/VS Only**

Most of our patients only have bilateral VC/VS implants. Their leads are connected down to a single IPG on the right side of their chest. The device measures from both leads in parallel, and an average power value from each hemisphere is calculated every 10 minutes, synced across hemispheres.

| Actual lead location | Labeled lead location in JSON | IPG side | Requires correction |
| --- | --- | --- | --- |
| Left VC/VS | Left VC/VS | Right chest | No |
| Right VC/VS | Right VC/VS | Right chest | No |

Note that the lead location is listed as "Other" instead of "VC/VS" for some patients.

**VC/VS + GPI**

Some OCD patients have comorbid Tourette syndrome (B001, B005, B008, B010), and in addition to their bilateral VC/VS leads receive two additional leads implanted bilaterally in the GPI. These patients then receive two IPGs, one in each half of the chest, and two leads connect to each IPG. The IPG in the right side of the chest is connected to both VC/VS leads, and the left IPG is likewise connected to the GPI leads. No correction is required for these patients, and the files containing GPI data are simply removed for our analysis here.

| Actual lead location | Labeled lead location in JSON | IPG side | Requires correction |
| --- | --- | --- | --- |
| Left VC/VS | Left VC/VS | Right chest | No |
| Right VC/VS | Right VC/VS | Right chest | No |
| Left GPI | Left GPI | Left chest | No |
| Right GPI | Right GPI | Left chest | No |

**VC/VS + OFC (B014 and B015)**

A handful of patients (B014, B015, and B017) received leads implanted bilaterally in the VC/VS and bilaterally in the OFC. These patients also receive two IPGs in the chest with the two right hemisphere leads connected to the right IPG and the two left hemisphere leads connected to the left IPG. The Medtronic Percept device collects up to two streams of data, one from each of its two connected leads, and it assigns one lead a hemisphere label of “right” and the other a label of “left.” Each lead may have its own implant location selected from a list of presets. For these patients, because both leads feeding into each device come from the same hemisphere, we must correct the data to properly determine where it came from. In B014 and B015, the left VC/VS lead is connected to the left IPG and given a label of "left VC/VS," and the left OFC lead is connected to the left IPG and given a label of "right Other" (OFC is not one of the preset lead location options). Likewise, the right VC/VS lead and right OFC lead are both fed to the right IPG with labels of "right VC/VS" and "left Other," respectively.

| Actual lead location | Labeled lead location in JSON | IPG side | Requires correction |
| --- | --- | --- | --- |
| Left VC/VS | Left VC/VS | Left chest | No |
| Right VC/VS | Right VC/VS | Right chest | No |
| Left OFC | Right Other | Left chest | Yes |
| Right OFC | Left Other | Right chest | Yes |

Note that B016 and B018 also have VC/VS and OFC implants, but we do not have any chronic LFP data from their devices.

**VC/VS + OFC (B017)**

This patient's system is configured slightly differently from the other VC/VS+OFC patients. The right IPG's labels are swapped, so the right VC/VS lead is labeled in the data collected by the device as "left VC/VS," and the right OFC lead is labeled as "right Other." This results in both VC/VS leads being labeled as left VC/VS and both OFC leads being labeled as right Other.

| Actual lead location | Labeled lead location in JSON | IPG side | Requires correction |
| --- | --- | --- | --- |
| Left VC/VS | Left VC/VS | Left chest | No |
| Right VC/VS | Left VC/VS | Right chest | Yes |
| Left OFC | Right Other | Left chest | Yes |
| Right OFC | Right Other | Right chest | No |

In [None]:
# This session report is filed under B017 but is for the wrong patient, so exclude it.
raw_df.query('pt_id != "B017" or source_file != "Report_Json_Session_Report_017_01R_20240516T152238.json"', inplace=True)

# B017's data is labeled differently from other OFC patients --> correct it to match the others.
# Rows are identified as right-chest data by the names of the JSON files in the patient's 'R' folder.
right_chest_dir_017 = Path(patient_dict['B017']['data_path']) / 'R'
right_chest_jsons_017 = [json_path.name for json_path in right_chest_dir_017.glob('[!.]*.json')]
is_right_chest_017 = (raw_df['pt_id'] == 'B017') & raw_df['source_file'].isin(right_chest_jsons_017)
right_chest_data_017 = raw_df.loc[is_right_chest_017].copy()

# Swap every left/right column pair. Assigning a DataFrame to a list of column names pairs
# the columns up by position (not by name), which is what performs the swap here.
hemisphere_cols = ['lfp_left', 'stim_left', 'lfp_right', 'stim_right', 'left_lead_location', 'right_lead_location']
swapped_cols = ['lfp_right', 'stim_right', 'lfp_left', 'stim_left', 'right_lead_location', 'left_lead_location']
right_chest_data_017[hemisphere_cols] = right_chest_data_017[swapped_cols].copy()

# Replace the original rows with their swapped versions.
raw_df.drop(index=right_chest_data_017.index, inplace=True)
raw_df = pd.concat([raw_df, right_chest_data_017], ignore_index=True)

# Correct mislabeled data for OFC patients.
# Each IPG in these patients records one VC/VS lead and one OFC lead from the same hemisphere,
# but the device labels the two streams "left" and "right". Every such row is therefore split
# into two rows: one holding only the VC/VS stream and one holding only the OFC stream, each
# placed in the column of the hemisphere it truly came from.
ofc_patients = ['B014', 'B015', 'B016', 'B017', 'B018']

# Right-chest IPG: "right VC/VS" is the true right VC/VS lead; "left OTHER" is really the right OFC lead.
ofc_pt_right_ipg_data_vcvs = raw_df.query('pt_id in @ofc_patients and left_lead_location == "OTHER" and right_lead_location == "VC/VS"').copy()
ofc_pt_right_ipg_data_ofc = ofc_pt_right_ipg_data_vcvs.copy()
# VC/VS copy: keep the right stream, blank the (OFC) left stream.
ofc_pt_right_ipg_data_vcvs[['lfp_left', 'stim_left']] = np.nan
ofc_pt_right_ipg_data_vcvs[['right_lead_location', 'left_lead_location']] = 'VC/VS'
# OFC copy: move the OFC stream from the left columns into the right columns, then blank the left.
ofc_pt_right_ipg_data_ofc[['lfp_right', 'stim_right']] = ofc_pt_right_ipg_data_ofc[['lfp_left', 'stim_left']].copy()
ofc_pt_right_ipg_data_ofc[['lfp_left', 'stim_left']] = np.nan
ofc_pt_right_ipg_data_ofc[['right_lead_location', 'left_lead_location']] = 'OFC'

# Left-chest IPG: "left VC/VS" is the true left VC/VS lead; "right OTHER" is really the left OFC lead.
ofc_pt_left_ipg_data_vcvs = raw_df.query('pt_id in @ofc_patients and left_lead_location == "VC/VS" and right_lead_location == "OTHER"').copy()
ofc_pt_left_ipg_data_ofc = ofc_pt_left_ipg_data_vcvs.copy()
# VC/VS copy: keep the left stream, blank the (OFC) right stream.
ofc_pt_left_ipg_data_vcvs[['lfp_right', 'stim_right']] = np.nan
ofc_pt_left_ipg_data_vcvs[['right_lead_location', 'left_lead_location']] = 'VC/VS'
# OFC copy: move the OFC stream from the right columns into the left columns, then blank the right.
# (Fix: these two statements previously targeted ofc_pt_right_ipg_data_ofc, which wiped the
# right-IPG OFC rows and left the left-IPG OFC rows carrying VC/VS/OTHER labels.)
ofc_pt_left_ipg_data_ofc[['lfp_left', 'stim_left']] = ofc_pt_left_ipg_data_ofc[['lfp_right', 'stim_right']].copy()
ofc_pt_left_ipg_data_ofc[['lfp_right', 'stim_right']] = np.nan
ofc_pt_left_ipg_data_ofc[['right_lead_location', 'left_lead_location']] = 'OFC'

# Finish up the OFC data: swap the original dual-stream rows for their split, relabeled copies.
for replaced_rows in (ofc_pt_right_ipg_data_vcvs, ofc_pt_left_ipg_data_vcvs):
    raw_df.drop(labels=replaced_rows.index, inplace=True)
corrected_frames = [ofc_pt_right_ipg_data_vcvs, ofc_pt_left_ipg_data_vcvs, ofc_pt_right_ipg_data_ofc, ofc_pt_left_ipg_data_ofc]
raw_df = pd.concat([raw_df, *corrected_frames], ignore_index=True)
# Rows with no LFP value in either hemisphere carry no information.
raw_df.dropna(subset=['lfp_left', 'lfp_right'], how='all', inplace=True, ignore_index=True)
raw_df.sort_values(by=['pt_id', 'left_lead_location', 'timestamp'], inplace=True, ignore_index=True)

# Relabel all remaining "OTHER" lead locations to VC/VS
for location_col in ('left_lead_location', 'right_lead_location'):
    raw_df.loc[raw_df[location_col] == "OTHER", location_col] = 'VC/VS'

# Both hemispheres of a row must now agree on location and lead model, so collapse each
# left/right pair into a single column (location first, then model).
for merged_col, left_col, right_col in (
    ('lead_location', 'left_lead_location', 'right_lead_location'),
    ('lead_model', 'left_lead_model', 'right_lead_model'),
):
    hemispheres_agree = raw_df[left_col] == raw_df[right_col]
    assert hemispheres_agree.all()
    raw_df[merged_col] = raw_df[left_col].where(hemispheres_agree, None)
    raw_df.drop(columns=[left_col, right_col], inplace=True)

# Get rid of all non-VC/VS data
raw_df.query('lead_location == "VC/VS"', inplace=True)

# Sort and format the raw data
raw_df['timestamp'] = pd.to_datetime(raw_df['timestamp'])
raw_df.sort_values(by=['pt_id', 'lead_location', 'timestamp'], inplace=True, ignore_index=True)

When new data is measured, the Percept device overwrites the oldest data, always storing exactly $N$ days worth of data, where $N$ is 35 or 60, depending on the model of the Percept device. This means that data files always contain the most recent $N$ days of data, not just the data collected since the last download. Thus, a single data sample or a collection of many consecutive samples may be duplicated several times throughout the dataset due to it appearing in multiple data files.

At clinical visits, we sync the device with network time, and it applies a transformation to the collected data in an attempt to make its timestamps more accurate. This means that the duplicated data streams may have slightly offset timestamps. To mitigate this issue, we only keep data from the most recent file for any particular timestamp.

In [None]:
# Remove any duplicate data points included in multiple files (sometimes even with time drift).
# Strategy: a row is dropped when its timestamp falls inside the [first, last] timestamp span of
# any file from the same patient that starts later than the row's own file, so only the copy
# from the latest-starting file covering that time survives.

# Step 1: Get first and last timestamp per pt_id and source_file
file_start_times = (
    raw_df.groupby(['pt_id', 'source_file'])['timestamp']
    .min()
    .reset_index()
    .rename(columns={'timestamp': 'file_start_time'})
)
file_end_times = (
    raw_df.groupby(['pt_id', 'source_file'])['timestamp']
    .max()
    .reset_index()
    .rename(columns={'timestamp': 'file_end_time'})
)

# Step 2: Sort files in temporal order per pt_id (ordered by each file's first timestamp)
# Result: {pt_id: [source_file, ...]}
file_order = (
    file_start_times.sort_values(['pt_id', 'file_start_time'])
    .groupby('pt_id')['source_file']
    .apply(list)
    .to_dict()
)

# Step 3: For each file, record the not-yet-seen files (per pt_id)
# Result: {(pt_id, source_file): set of files that start after it}. The last file of each
# patient maps to an empty set.
file_to_future = {
    (pt_id, current_file): set(files[i+1:])
    for pt_id, files in file_order.items()
    for i, current_file in enumerate(files)
}

# Get start and end times for each future file
file_bounds = file_start_times.merge(file_end_times, on=['pt_id', 'source_file']).rename(columns={'source_file': 'future_file'})
# Flatten the mapping into one record per (file, later file) pair so it can be merged with the bounds.
expanded = []
for (pt_id, source_file), future_files in file_to_future.items():
    for future_file in future_files:
        expanded.append({'pt_id': pt_id, 'source_file': source_file, 'future_file': future_file})
# NOTE(review): if no patient has more than one file, `expanded` is empty and this frame has no
# columns, so the merges below would fail — confirm this cannot happen with the real data.
future_map = pd.DataFrame(expanded)

# Attach each later file's time span, then collapse to one list of (start, end) tuples per file.
future_map = future_map.merge(file_bounds, on=['pt_id', 'future_file'], how='left')
future_map = future_map.groupby(['pt_id', 'source_file']).apply(lambda x: list(zip(x['file_start_time'], x['file_end_time'])), include_groups=False).reset_index(name='time_intervals')

# Files with no later file get NaN in 'time_intervals' after the left merge; the isinstance
# check below keeps all of their rows.
raw_df = raw_df.merge(future_map, on=['pt_id', 'source_file'], how='left')
# Row-wise test (inclusive on both ends): is this timestamp covered by any later file?
remove = raw_df.apply(lambda row: isinstance(row['time_intervals'], list) and any(start <= row['timestamp'] <= end for start, end in row['time_intervals']), axis=1)
raw_df = raw_df[~remove]
raw_df.drop(columns=['time_intervals'], inplace=True)
raw_df.reset_index(drop=True, inplace=True)

### Process data

In [None]:
# Process raw data: build `processed_df` (de-duplicated, gap-filled, outlier-corrected,
# interpolated and per-day z-scored) and then `df`, the trimmed copy used by the analyses below.

# Fill outliers: define which filling method(s) you want to use
# Keys become part of the generated column names (e.g. 'lfp_left_<key>_interpolate').
outlier_fill_methods = {
    'threshold': artifact_correct.threshold_outliers,
    'OvER': artifact_correct.fill_outliers_OvER
}

# Keep the untouched device readings under explicit '*_raw' names.
raw_df.rename(columns={'lfp_left': 'lfp_left_raw', 'lfp_right': 'lfp_right_raw'}, inplace=True)

# Put timestamps in datetime format and drop any remaining duplicate readings
raw_df['timestamp'] = pd.to_datetime(raw_df['timestamp'])
raw_df.drop_duplicates(subset=['pt_id', 'timestamp', 'lead_location'], inplace=True)
raw_df.sort_values(['pt_id', 'lead_location', 'timestamp'], inplace=True, ignore_index=True)

# Convert timestamp column into central time for analysis.
# tz_convert requires 'timestamp' to be timezone-aware already.
raw_df['CT_timestamp'] = raw_df['timestamp'].dt.tz_convert(central_time)

# Add columns for 10 minute bin reading falls into (round up to 10 minute interval)
raw_df['time_bin'] = raw_df['timestamp'].dt.ceil('10min')
raw_df['time_bin_time'] = raw_df['time_bin'].dt.time

# If there are still any duplicates, eliminate them.
# Here a duplicate is a row with the same patient, bin, location and both raw LFP values.
processed_df = raw_df.drop_duplicates(subset=['pt_id', 'time_bin', 'lead_location', 'lfp_left_raw', 'lfp_right_raw'], keep='first', ignore_index=True).copy()

# Add column for days since first VC/VS DBS activation
dbs_on_date_dict = {}
for pt_id in processed_df['pt_id'].unique():
    dbs_on_date_dict[pt_id] = utils.get_dbs_on_date(patient_dict[pt_id]['dbs_on_date'])
processed_df['dbs_start_date'] = processed_df['pt_id'].map(dbs_on_date_dict)
# Difference of calendar dates (Central Time), expressed as a whole number of days.
processed_df['days_since_dbs'] = (processed_df['CT_timestamp'].dt.date - processed_df['dbs_start_date']).apply(lambda td: td.days)

# Add empty new rows to fill in missing timestamps so that later steps can interpolate over them.
# NOTE(review): the exact rows produced are decided by utils.add_empty_rows — confirm there.
added_rows = processed_df.groupby(['pt_id', 'lead_location'], group_keys=False).apply(lambda g: utils.add_empty_rows(g, g.name[0], g.name[1], dbs_on_date=dbs_on_date_dict[g.name[0]]), include_groups=False)
if not added_rows.empty:
    processed_df = pd.concat((processed_df, added_rows), ignore_index=True)
processed_df.sort_values(by=['pt_id', 'lead_location', 'timestamp'], inplace=True, ignore_index=True)

# Fix overvoltages and fill in holes in data using the specified method(s).
# Each method's output columns are joined back onto processed_df by row index.
for name, func in outlier_fill_methods.items():
    outlier_corrected_cols = processed_df.groupby(['pt_id', 'lead_location'], group_keys=False)\
        .apply(lambda g: func(g, cols_to_fill=['lfp_left_raw', 'lfp_right_raw']), include_groups=False)
    processed_df = pd.merge(processed_df, outlier_corrected_cols, how='outer', left_index=True, right_index=True)
    for hem, other_hem in ['left', 'right'], ['right', 'left']:
        # Get rid of all rows where other hem is not NaN but this hem is NaN (these were likely never meant to have data from this hem).
        interp_df = processed_df.loc[processed_df[f'lfp_{hem}_raw'].notna() | processed_df[f'lfp_{other_hem}_raw'].isna()]
        # Interpolate only this hemisphere's corrected columns, skipping any 'num_overages' columns.
        filled_cols = interp_df.groupby(['pt_id', 'lead_location'], group_keys=False)\
            .apply(lambda g: artifact_correct.interpolate_holes(g, cols_to_fill=[col for col in outlier_corrected_cols.columns if ((hem in col) and ('num_overages' not in col))]), include_groups=False)
        processed_df = pd.merge(processed_df, filled_cols, how='outer', left_index=True, right_index=True)
# Names of the final corrected + interpolated columns, left hemisphere first.
# NOTE(review): assumes the helpers name their outputs 'lfp_<hem>_<method>_interpolate' — confirm in artifact_correct.
corr_col_names = [f'lfp_left_{name}_interpolate' for name in outlier_fill_methods.keys()] + \
                 [f'lfp_right_{name}_interpolate' for name in outlier_fill_methods.keys()]
processed_df.dropna(subset=corr_col_names, how='all', inplace=True) # Get rid of any rows that are still empty in both hems

# Mark rows that were changed with 'corrected' tag.
# A row counts as corrected when any method's output is present and differs from the raw value
# (this includes rows whose raw value is NaN but which were filled in).
processed_df['left_corrected'] = False
processed_df['right_corrected'] = False
for name in outlier_fill_methods.keys():
    processed_df.loc[processed_df[f'lfp_left_{name}_interpolate'].notna() & (processed_df['lfp_left_raw'] != processed_df[f'lfp_left_{name}_interpolate']), 'left_corrected'] = True
    processed_df.loc[processed_df[f'lfp_right_{name}_interpolate'].notna() & (processed_df['lfp_right_raw'] != processed_df[f'lfp_right_{name}_interpolate']), 'right_corrected'] = True

# Z score LFP data within each day.
# The index is reset first so the z-scored columns can be joined back by row position.
processed_df = processed_df.reset_index(drop=True)
groups = processed_df.groupby(['pt_id', 'lead_location', 'days_since_dbs'], group_keys=False)
zscored_data = groups.apply(lambda g: utils.zscore_group(g, cols_to_zscore=corr_col_names), include_groups=False)
zscored_cols = zscored_data.columns
processed_df = pd.merge(processed_df, zscored_data, how='outer', left_index=True, right_index=True) # There must be a better way to do this

# Work on a copy from here on so processed_df stays available unchanged.
df = processed_df.copy()

# Drop any remaining duplicate readings and sort DF.
df.drop_duplicates(['pt_id', 'lfp_left_raw', 'lfp_right_raw', 'stim_left', 'stim_right', 'lead_location', 'time_bin'], inplace=True, ignore_index=True)
df.sort_values(by=['pt_id', 'lead_location', 'timestamp'], inplace=True, ignore_index=True)

# Get rid of any uninterpretable data
df.query('pt_id != "B001" or days_since_dbs <= 100', inplace=True) # Cut off B001 after 100 days (opted out of research)
df.query('pt_id != "B015" or days_since_dbs >= 830', inplace=True) # Cut out B015's data before 830 days (almost 100% outliers)
df.reset_index(drop=True, inplace=True)

# Mark outliers.
# A reading is flagged when it is at or above (2**32 - 1) / 60.
# NOTE(review): presumably the device's overvoltage/saturation value — confirm the origin of this constant.
df['is_outlier_left'] = (df['lfp_left_raw'] >= ((2 ** 32) - 1) / 60) & (df['lfp_left_raw'].notna())
df['is_outlier_right'] = (df['lfp_right_raw'] >= ((2 ** 32) - 1) / 60) & (df['lfp_right_raw'].notna())

In [None]:
# Per-patient summary: number of samples and number of distinct days (overall, before and after DBS start).
samples_per = df['pt_id'].value_counts(sort=False).to_frame()
samples_per['unique days'] = df.groupby('pt_id')['days_since_dbs'].nunique()

# Day 0 itself is counted in neither the pre- nor the post-DBS column.
for day_count_col, day_filter in (('days pre-DBS', 'days_since_dbs < 0'), ('days post-DBS', 'days_since_dbs > 0')):
    samples_per[day_count_col] = df.query(day_filter).groupby('pt_id')['days_since_dbs'].nunique()
    # Patients with no days on that side of DBS start come back as NaN; report them as 0.
    samples_per[day_count_col] = samples_per[day_count_col].fillna(0).astype(int)
samples_per = samples_per.reset_index()

# Sum numeric columns into a new 'total' row
samples_per.loc['total'] = samples_per.sum(axis=0, numeric_only=True)
# Turn all numeric columns into ints
numeric_cols = [col for col in samples_per.columns if pd.api.types.is_numeric_dtype(samples_per[col])]
samples_per = samples_per.astype({col: int for col in numeric_cols})
samples_per

In [None]:
# Report the in-memory footprint of the analysis dataframe in mebibytes.
df_megabytes = df.memory_usage().sum() / (2 ** 20)
print(f'df takes up {df_megabytes:.2f}MB of memory')

## Overvoltage Analysis

### Quantify outlier frequency

In [None]:
# Percentage of recorded samples flagged as outliers, per patient, for one hemisphere.
hem_for_this_cell = 'right'
raw_col = f'lfp_{hem_for_this_cell}_raw'
outlier_col = f'is_outlier_{hem_for_this_cell}'
# Only rows that actually have a raw reading in this hemisphere count toward the percentage.
vcvs_df = df[df['lead_location'] == "VC/VS"].dropna(subset=raw_col)
vcvs_df.groupby('pt_id')[outlier_col].mean().mul(100)

### Time of peak per day of each outlier handling method

In [None]:
sigma = 6
# Smooth each day's z-scored left-hemisphere trace (per patient, per day) for both correction methods.
smoothed_col_for = {
    'lfp_left_OvER_interpolate_z_scored': 'lfp_left_OvER_interpolate_z_scored_gaussian_smoothed',
    'lfp_left_threshold_interpolate_z_scored': 'lfp_left_threshold_interpolate_z_scored_gaussian_smoothed',
}
for z_scored_col, smoothed_col in smoothed_col_for.items():
    per_day_trace = df.groupby(['pt_id', 'days_since_dbs'], group_keys=False)[z_scored_col]
    df[smoothed_col] = per_day_trace.transform(lambda x: utils.gaussian_smooth(x, sigma=sigma))

pts_w_outliers = ['B009', 'B012', 'B015', 'B020'] # we leave out B002 because they don't have very much left hemisphere data
df_outliers = df.query('pt_id in @pts_w_outliers and lead_location == "VC/VS"')


def _daily_peak(day_df, smoothed_col):
    """Return [time of day as a fraction of 24 h, peak value] for the first sample at the day's maximum."""
    peak_val = day_df[smoothed_col].max()
    peak_time = day_df.loc[day_df[smoothed_col] == peak_val, 'CT_timestamp'].iloc[0]
    seconds_into_day = peak_time.hour * 3600 + peak_time.minute * 60 + peak_time.second
    return [seconds_into_day / (24*3600), peak_val]


# One [fraction of day, peak value] entry per patient-day and per method.
peak_times_OvER, peak_times_threshold = [], []
for (pt_id, days_since_dbs), day_df in df_outliers.groupby(['pt_id', 'days_since_dbs']):
    peak_times_OvER.append(_daily_peak(day_df, 'lfp_left_OvER_interpolate_z_scored_gaussian_smoothed'))
    peak_times_threshold.append(_daily_peak(day_df, 'lfp_left_threshold_interpolate_z_scored_gaussian_smoothed'))
peak_times_OvER = np.vstack(peak_times_OvER)
peak_times_threshold = np.vstack(peak_times_threshold)

In [None]:
def cosinor(t, M, A, phi):
    """
    Evaluate a single-component cosinor model with a fixed 24-hour period.

    Parameters:
    - t (float or array-like): Time in hours.
    - M (float): Mean level (MESOR) the oscillation is centered on.
    - A (float): Amplitude of the oscillation.
    - phi (float): Phase offset in radians.

    Returns:
    - float or np.ndarray: M + A * cos(2*pi*t/24 + phi), same shape as t.
    """
    period_hours = 24
    angular_frequency = 2 * np.pi / period_hours
    oscillation = A * np.cos(angular_frequency * t + phi)
    return M + oscillation

In [None]:
# Sliding-window cosinor fit of the OvER-corrected left-hemisphere LFP.
# For every calendar day of every (patient, lead location) group, a 24 h cosinor is fitted to the
# samples inside a `window_days`-long window centered on that day's noon, and the fitted time of
# peak plus the fitted peak level (M + |A|) are stored.
omega = 2 * np.pi / 24  # angular frequency of a 24 h rhythm, in rad/hour
window_days = 5
half_window = pd.Timedelta(days=window_days / 2)

peak_times_OvER = []
for (pt_id, lead_location), group_df in df_outliers.groupby(['pt_id', 'lead_location']):
    group_df = group_df.sort_values('CT_timestamp').reset_index(drop=True)
    unique_days = group_df['CT_timestamp'].dt.floor('D').unique()
    p_val_amps = []
    for center_day in unique_days:
        # Window is centered on noon of the current day.
        window_start = center_day + pd.Timedelta(hours=12) - half_window
        window_end = center_day + pd.Timedelta(hours=12) + half_window

        window_df = group_df[(group_df['CT_timestamp'] >= window_start) & (group_df['CT_timestamp'] < window_end)]

        # Time of day in (fractional) hours; the model is periodic, so the date itself is not needed.
        t_hours = window_df['CT_timestamp'].dt.hour + window_df['CT_timestamp'].dt.minute / 60 + window_df['CT_timestamp'].dt.second / 3600
        y = window_df['lfp_left_OvER_interpolate'].values

        # Drop samples with no value.
        mask = ~np.isnan(y)
        t_hours = t_hours[mask]
        y = y[mask]

        # Too few samples for a meaningful three-parameter fit.
        if len(y) < 10:
            continue

        # Initial guesses: M=mean, A=half range, phi=0
        M0 = np.mean(y)
        A0 = (np.max(y) - np.min(y)) / 2
        phi0 = 0

        try:
            popt, pcov = curve_fit(cosinor, t_hours, y, p0=[M0, A0, phi0])
        except RuntimeError:
            # fitting failed
            continue

        # Two-tailed t-test of each parameter against zero, using the fit's standard errors.
        perr = np.sqrt(np.diag(pcov))
        n = len(y)
        p = len(popt)
        dof = max(0, n-p)
        t_stats = popt / perr
        p_values = 2 * (1 - stats.t.cdf(np.abs(t_stats), dof))  # two-tailed p-values
        p_val_amps.append(p_values[1])  # p-value for amplitude
        M_fit, A_fit, phi_fit = popt

        # A negative amplitude is the same curve as a positive one shifted by half a cycle.
        if A_fit < 0:
            A_fit = -A_fit
            phi_fit += np.pi

        # Time of peak: when cos is 1 => ωt + φ = 0 mod 2π => t_peak = -φ / ω
        t_peak = (-phi_fit % (2 * np.pi)) / omega  # in hours

        # Store peak time (as a fraction of the day), fitted peak level, and patient.
        # (The former `date`/`peak_timestamp` locals were unused, and `date` read a `day_df`
        # variable leaked from a different cell's loop, so they have been removed.)
        peak_times_OvER.append([t_peak * 3600 / (24 * 3600), M_fit + np.abs(A_fit), pt_id])

    print(pt_id, lead_location)
    if p_val_amps:
        print(f'Min p: {np.min(p_val_amps):.4g}')
        print(f'Mean p: {np.mean(p_val_amps):.4g}')
        print(f'Max p: {np.max(p_val_amps):.4g}')
    else:
        # np.min/np.max raise on an empty list; report the situation instead of crashing.
        print('No successful cosinor fits')
    print()

# Rows are [fraction of day, fitted peak level, pt_id]; mixing floats and strings yields a string array.
peak_times_OvER = np.vstack(peak_times_OvER)

In [None]:
# Sliding-window cosinor fit of the threshold-corrected left-hemisphere LFP.
# Same procedure as the OvER fit: for every calendar day of every (patient, lead location) group,
# a 24 h cosinor is fitted to the samples in a window centered on that day's noon.
# `half_window` and `df_outliers` must already be defined by earlier cells.
omega = 2 * np.pi / 24  # angular frequency of a 24 h rhythm, in rad/hour

peak_times_threshold = []
for (pt_id, lead_location), group_df in df_outliers.groupby(['pt_id', 'lead_location']):
    group_df = group_df.sort_values('CT_timestamp').reset_index(drop=True)
    unique_days = group_df['CT_timestamp'].dt.floor('D').unique()
    p_val_amps = []
    for center_day in unique_days:
        # Window is centered on noon of the current day.
        window_start = center_day + pd.Timedelta(hours=12) - half_window
        window_end = center_day + pd.Timedelta(hours=12) + half_window

        window_df = group_df[(group_df['CT_timestamp'] >= window_start) & (group_df['CT_timestamp'] < window_end)]

        # Time of day in (fractional) hours; the model is periodic, so the date itself is not needed.
        t_hours = window_df['CT_timestamp'].dt.hour + window_df['CT_timestamp'].dt.minute / 60 + window_df['CT_timestamp'].dt.second / 3600
        y = window_df['lfp_left_threshold_interpolate'].values

        # Drop samples with no value.
        mask = ~np.isnan(y)
        t_hours = t_hours[mask]
        y = y[mask]

        # Too few samples for a meaningful three-parameter fit.
        if len(y) < 10:
            continue

        # Initial guesses: M=mean, A=half range, phi=0
        M0 = np.mean(y)
        A0 = (np.max(y) - np.min(y)) / 2
        phi0 = 0

        try:
            popt, pcov = curve_fit(cosinor, t_hours, y, p0=[M0, A0, phi0])
        except RuntimeError:
            # fitting failed
            continue

        # Two-tailed t-test of each parameter against zero, using the fit's standard errors.
        perr = np.sqrt(np.diag(pcov))
        n = len(y)
        p = len(popt)
        dof = max(0, n-p)
        t_stats = popt / perr
        p_values = 2 * (1 - stats.t.cdf(np.abs(t_stats), dof))  # two-tailed p-values
        p_val_amps.append(p_values[1])  # p-value for amplitude
        M_fit, A_fit, phi_fit = popt

        # A negative amplitude is the same curve as a positive one shifted by half a cycle.
        if A_fit < 0:
            A_fit = -A_fit
            phi_fit += np.pi

        # Time of peak: when cos is 1 => ωt + φ = 0 mod 2π => t_peak = -φ / ω
        t_peak = (-phi_fit % (2 * np.pi)) / omega  # in hours

        # Store peak time (as a fraction of the day), fitted peak level, and patient.
        # (The former `date`/`peak_timestamp` locals were unused, and `date` read a `day_df`
        # variable leaked from a different cell's loop, so they have been removed.)
        peak_times_threshold.append([t_peak * 3600 / (24 * 3600), M_fit + np.abs(A_fit), pt_id])

    print(pt_id, lead_location)
    if p_val_amps:
        print(f'Min p: {np.min(p_val_amps):.4g}')
        print(f'Mean p: {np.mean(p_val_amps):.4g}')
        print(f'Max p: {np.max(p_val_amps):.4g}')
    else:
        # np.min/np.max raise on an empty list; report the situation instead of crashing.
        print('No successful cosinor fits')
    print()

# Rows are [fraction of day, fitted peak level, pt_id]; mixing floats and strings yields a string array.
peak_times_threshold = np.vstack(peak_times_threshold)

In [None]:
# Polar "clock" plot of fitted peak times for all patients together:
# angle = time of day (midnight at the top, clockwise), radius = fitted peak level (log scale).
fig, ax = plt.subplots(figsize=(4,4), subplot_kw=dict(projection='polar'))
for peak_array, series_label in ((peak_times_threshold, 'Threshold Peaks'), (peak_times_OvER, 'OvER Peaks')):
    clock_angle = peak_array[:, 0].astype(float) * 2 * np.pi  # fraction of day -> radians
    peak_level = peak_array[:, 1].astype(float)
    ax.scatter(clock_angle, peak_level, label=series_label, s=10, marker='o', alpha=0.5)
ax.set_theta_direction(-1)
ax.set_theta_zero_location('N')
three_hour_ticks = np.linspace(0, 2*np.pi, 8, endpoint=False)
three_hour_labels = ['', '03:00', '06:00', '09:00', '12:00', '15:00', '18:00', '21:00']
ax.set(xticks=three_hour_ticks, xticklabels=three_hour_labels,
       yscale='log', ylim=[1, ax.get_ylim()[1]])
ax.legend()
fig.tight_layout()
plt.show()

In [None]:
# One polar "clock" plot per patient: angle = time of day (midnight at the top, clockwise),
# radius = fitted peak level (log scale).
n_patients = len(pts_w_outliers)
fig, axs = plt.subplots(ncols=n_patients, figsize=(3 * n_patients, 4), subplot_kw=dict(projection='polar'))
three_hour_labels = ['', '03:00', '06:00', '09:00', '12:00', '15:00', '18:00', '21:00']
for ax, pt_id in zip(axs.flatten(), pts_w_outliers):
    # The peak arrays hold strings (they include pt_id), so select by patient and cast back to float.
    pt_peak_times_OvER = peak_times_OvER[peak_times_OvER[:, 2] == pt_id, :2].astype(float)
    pt_peak_times_threshold = peak_times_threshold[peak_times_threshold[:, 2] == pt_id, :2].astype(float)

    for pt_peaks, series_label in ((pt_peak_times_threshold, 'Threshold Peaks'), (pt_peak_times_OvER, 'OvER Peaks')):
        ax.scatter(pt_peaks[:, 0] * 2 * np.pi, pt_peaks[:, 1], label=series_label, s=10, marker='o', alpha=0.5)

    ax.set_theta_direction(-1)
    ax.set_theta_zero_location('N')
    ax.set(xticks=np.linspace(0, 2*np.pi, 8, endpoint=False), xticklabels=three_hour_labels)
    ax.set_title(pt_id)

    ax.set_yscale('log')
plt.legend()
fig.tight_layout()
plt.show()

## Oura

### Read in and sync Oura ring data

#### Using data on server

In [None]:
# Read in Oura activity and sleep data from server

# Integer codes used in the Oura JSON files -> human-readable, ordered categories.
activity_mapping = {
    0: 'nonwear',
    1: 'rest',
    2: 'inactive',
    3: 'low activity',
    4: 'medium activity',
    5: 'high activity'
}
activity_categories = [label for _, label in sorted(activity_mapping.items())]
activity_dtype = pd.CategoricalDtype(categories=activity_categories, ordered=True)
sleep_phase_mapping = {
    1: 'deep',
    2: 'light',
    3: 'REM',
    4: 'awake'
}
sleep_phase_categories = [label for _, label in sorted(sleep_phase_mapping.items())]
sleep_phase_dtype = pd.CategoricalDtype(categories=sleep_phase_categories, ordered=True)


def _oura_samples_to_df(start_iso, interval, values, value_col, pt_id):
    """Build a dataframe of evenly spaced samples that begin at an ISO-format timestamp string."""
    first_timestamp = pd.to_datetime(datetime.fromisoformat(start_iso))
    timestamps = first_timestamp + interval * np.arange(len(values))
    return pd.DataFrame({'timestamp': timestamps, value_col: values, 'pt_id': pt_id})


met_dfs, activity_class_dfs, sleep_dfs = [], [], []
for pt_id in tqdm(ids):
    # The patient's Oura folder sits next to their Percept data folder.
    pt_oura_dir = Path(patient_dict[pt_id]['data_path']) / '..' / 'oura'
    if not pt_oura_dir.exists():
        continue

    for date_dir in pt_oura_dir.iterdir():
        activity_path = date_dir / 'daily_activity.json'
        if activity_path.exists():
            with open(activity_path, 'r') as f:
                activity_data = json.load(f)[0]

            # MET scores come with their own start timestamp and sampling interval (in seconds).
            met_section = activity_data['met']
            met_dfs.append(_oura_samples_to_df(
                met_section['timestamp'], timedelta(seconds=met_section['interval']),
                met_section['items'], 'met_score', pt_id))

            # Activity classes are a string of single-digit codes, one per 5 minutes.
            activity_class_data = [int(code) for code in activity_data['class_5_min']]
            activity_class_dfs.append(_oura_samples_to_df(
                activity_data['timestamp'], timedelta(minutes=5),
                activity_class_data, 'activity_class', pt_id))

        sleep_path = date_dir / 'sleep.json'
        if sleep_path.exists():
            with open(sleep_path, 'r') as f:
                sleep_data = json.load(f)
            # A single file may hold several sleep periods; each starts at its own bedtime.
            for individual_sleep_data in sleep_data:
                sleep_phases = [int(code) for code in individual_sleep_data['sleep_phase_5_min']]
                sleep_dfs.append(_oura_samples_to_df(
                    individual_sleep_data['bedtime_start'], timedelta(minutes=5),
                    sleep_phases, 'sleep_phase', pt_id))

# Combine everything into one frame per data type, sorted by patient and time.
met_df = pd.concat(met_dfs, ignore_index=True)
met_df = met_df.sort_values(['pt_id', 'timestamp'])

activity_class_df = pd.concat(activity_class_dfs, ignore_index=True)
activity_class_df['activity_class'] = activity_class_df['activity_class'].map(activity_mapping).astype(activity_dtype)
activity_class_df = activity_class_df.sort_values(['pt_id', 'timestamp'])

sleep_df = pd.concat(sleep_dfs, ignore_index=True)
sleep_df['sleep_phase'] = sleep_df['sleep_phase'].map(sleep_phase_mapping).astype(sleep_phase_dtype)
sleep_df = sleep_df.sort_values(['pt_id', 'timestamp'])

In [None]:
# Set up timestamps in CT and UTC for all dataframes.
def _to_central_or_nat(x):
    """Convert a timezone-aware timestamp to America/Chicago; missing or naive values become NaT."""
    return x.tz_convert('America/Chicago') if (pd.notna(x) and x.tzinfo) else pd.NaT


def _tz_name_or_none(x):
    """Return the name of the timestamp's own time zone, or None for missing or naive values."""
    return x.tzname() if (pd.notna(x) and x.tzinfo) else None


CT_convert = np.vectorize(_to_central_or_nat)
get_orig_time_zone = np.vectorize(_tz_name_or_none)

# The same four columns are derived for every Oura dataframe:
#   raw_timestamp      - the timestamp exactly as read from the file
#   CT_timestamp       - the timestamp converted to Central Time
#   original_time_zone - the zone the data was collected in, so we know where the patient was located
#   timestamp          - overwritten with the UTC equivalent of CT_timestamp
for oura_df in (met_df, activity_class_df, sleep_df):
    oura_df['raw_timestamp'] = oura_df['timestamp'].copy()
    oura_df['CT_timestamp'] = CT_convert(oura_df['timestamp'])
    oura_df['original_time_zone'] = get_orig_time_zone(oura_df['timestamp'])
    oura_df['timestamp'] = oura_df['CT_timestamp'].dt.tz_convert('UTC')

In [None]:
def align_oura_data(pt_lfp_df, pt_oura_df, oura_val_colname, new_col_name, nonwear_fill_val=[0], get_orig_time_zone=True):
    """
    Aligns Oura data with LFP data for a given patient.

    Every LFP row is treated as the 10-minute interval ending at its timestamp and every
    Oura row as the 1-minute interval starting at its timestamp. For each LFP row, the Oura
    values whose interval strictly overlaps the LFP interval are collected into a list.

    Parameters:
    - pt_lfp_df (pd.DataFrame): DataFrame containing LFP data; needs a 'timestamp' column.
    - pt_oura_df (pd.DataFrame): DataFrame containing Oura data; needs a 'timestamp' column,
      `oura_val_colname`, and (if `get_orig_time_zone`) an 'original_time_zone' column.
      Rows do not have to be pre-sorted by timestamp.
    - oura_val_colname (str): Column name in Oura data to align with LFP data.
    - new_col_name (str): New column name for the aligned Oura data.
    - nonwear_fill_val (Object, optional): Value stored for LFP rows with no overlapping Oura
      sample. A list is copied for every row, so neither the caller's list nor the default is
      shared between (or mutated through) the returned cells.
    - get_orig_time_zone (bool, optional): Whether to include the original time zone in the output.

    Returns:
    - pd.DataFrame indexed like `pt_lfp_df`, with column `new_col_name` holding one list of Oura
      values per LFP row (or the fill value) and, if requested, column 'original_time_zone'
      holding the set of time-zone names of the overlapping Oura rows (empty set if none).
    """
    # np.searchsorted needs ascending bounds. A stable argsort is the identity permutation for
    # input that is already sorted, so pre-sorted callers get exactly the same result as before.
    raw_oura_starts = pt_oura_df['timestamp'].values
    order = np.argsort(raw_oura_starts, kind='stable')
    oura_starts = raw_oura_starts[order]
    oura_ends = oura_starts + pd.Timedelta(minutes=1)

    lfp_ends = pt_lfp_df['timestamp'].values
    lfp_starts = lfp_ends - pd.Timedelta(minutes=10)

    # Candidate window per LFP row: Oura intervals starting no later than the LFP end and
    # ending no earlier than the LFP start.
    upper = np.searchsorted(oura_starts, lfp_ends, side='right')
    lower = np.searchsorted(oura_ends, lfp_starts, side='left')

    # Pull the columns out once (in sorted order) instead of going through .iloc on every row.
    oura_values = pt_oura_df[oura_val_colname].values[order]
    if get_orig_time_zone:
        oura_time_zones = pt_oura_df['original_time_zone'].values[order]

    oura_vals, time_zones = [], []
    for lfp_start, lfp_end, lo, hi in zip(lfp_starts, lfp_ends, lower, upper):
        if lo < hi:
            # Intervals that merely touch the LFP interval at a boundary do not count as overlap.
            overlaps = (oura_ends[lo:hi] > lfp_start) & (oura_starts[lo:hi] < lfp_end)
            overlapping_indices = np.arange(lo, hi)[overlaps]
        else:
            overlapping_indices = np.array([], dtype=int)

        if len(overlapping_indices) > 0:
            row_value = oura_values[overlapping_indices].tolist()
        elif isinstance(nonwear_fill_val, list):
            row_value = list(nonwear_fill_val)  # fresh copy per row, never the shared object
        else:
            row_value = nonwear_fill_val
        # Each row is wrapped in a one-element list so the DataFrame gets a single column of lists.
        oura_vals.append([row_value])

        if get_orig_time_zone:
            time_zones.append(set(oura_time_zones[overlapping_indices].tolist()))

    return_df = pd.DataFrame(oura_vals, columns=[new_col_name], index=pt_lfp_df.index)
    if get_orig_time_zone:
        return_df['original_time_zone'] = time_zones
    return return_df

In [None]:
# Sync up the Oura data with the LFP data for activity and sleep by combining the two dataframes.
# Each LFP patient group is aligned against that patient's rows of the matching Oura frame.
lfp_groups = df.groupby(['pt_id'], group_keys=False)
met_vals_df = lfp_groups.apply(
    lambda g: align_oura_data(
        g,
        met_df.query('pt_id == @g.name'),
        'met_score',
        new_col_name='met_vals',
        nonwear_fill_val=[0],
    ),
    include_groups=False,
)
activity_vals_df = lfp_groups.apply(
    lambda g: align_oura_data(
        g,
        activity_class_df.query('pt_id == @g.name'),
        'activity_class',
        new_col_name='activity_classes',
        nonwear_fill_val=['nonwear'],
    ),
    include_groups=False,
)
sleep_phases_df = lfp_groups.apply(
    lambda g: align_oura_data(
        g,
        sleep_df.query('pt_id == @g.name'),
        'sleep_phase',
        new_col_name='sleep_phases',
        nonwear_fill_val=[],
    ),
    include_groups=False,
)

# Copy the aligned list columns onto the LFP frame (assignment aligns on the index).
for aligned_col, aligned_df in (
    ('met_vals', met_vals_df),
    ('activity_classes', activity_vals_df),
    ('sleep_phases', sleep_phases_df),
):
    df[aligned_col] = aligned_df[aligned_col]

# Union of the time-zone names seen in any of the three Oura sources for each LFP row.
df['original_time_zones'] = [
    met_tz.union(activity_tz, sleep_tz)
    for met_tz, activity_tz, sleep_tz in zip(
        met_vals_df['original_time_zone'],
        activity_vals_df['original_time_zone'],
        sleep_phases_df['original_time_zone'],
    )
]

In [None]:
# Summarise the aligned MET samples of every LFP row and flag whether the ring was worn.
# NOTE(review): the built-in max() only returns NaN when NaN is the first list element,
# so the NaN handling below depends on sample order - confirm this is intended.
df['max_met'] = df['met_vals'].map(max)
df['avg_met'] = df['met_vals'].map(np.mean)
df['wearing_ring'] = df['activity_classes'].map(lambda classes: 'nonwear' not in classes)

# Rows whose max MET came out as NaN are reset to 0 and marked as not wearing the ring.
nanmet_inds = np.where(df['max_met'].isna())[0]
nanmet_labels = df.index[nanmet_inds]
df.loc[nanmet_labels, 'max_met'] = pd.Series([0] * len(nanmet_inds), index=nanmet_labels)
df.loc[nanmet_labels, 'wearing_ring'] = pd.Series([False] * len(nanmet_inds), index=nanmet_labels)

# Fill in sleep data in more interpretable way
no_phases = df['sleep_phases'].apply(lambda phases: phases == [])
some_phases = df['sleep_phases'].apply(lambda phases: phases != [])
only_awake = df['sleep_phases'].apply(lambda phases: set(phases) == {'awake'})
any_awake = df['sleep_phases'].apply(lambda phases: 'awake' in phases)

unknown_mask = ~df['wearing_ring']
awake_mask = df['wearing_ring'] & (only_awake | no_phases)
asleep_mask = df['wearing_ring'] & (~any_awake) & some_phases

# Everything not matched by one of the masks below keeps the 'Mixed' label.
df['sleep_state'] = 'Mixed'
for state_mask, state_label in (
    (unknown_mask, 'Unknown'),
    (awake_mask, 'Awake'),
    (asleep_mask, 'Asleep'),
):
    df.loc[state_mask, 'sleep_state'] = state_label

In [None]:
# Keep only the LFP rows recorded while the ring was being worn (independent copy of `df`).
# NOTE(review): `oura_df` is assigned again in the "Validate Oura data" cell that follows
# (there additionally restricted to VC/VS leads and without .copy()), so on a top-to-bottom
# run this definition is overwritten - confirm which one later cells are meant to use.
oura_df = df.query('wearing_ring').copy()

#### Validate Oura data

In [None]:
def _count_days_with_samples(samples_df):
    """Number of daily CT_timestamp bins that contain at least one row of `samples_df`."""
    return samples_df.groupby(pd.Grouper(key='CT_timestamp', freq='D')).head(1).shape[0]


# Per-patient overview of how much ring-worn VC/VS data exists.
oura_df = df.query('wearing_ring == True and lead_location == "VC/VS"')
oura_samples_per = pd.DataFrame(
    oura_df['pt_id'].value_counts(sort=False), index=df['pt_id'].unique()
).rename(columns={'count': 'samples'}, errors='raise')

for pt_id, pt_oura_df in oura_df.groupby('pt_id'):
    dbs_on_date = utils.get_dbs_on_date(patient_dict[pt_id]['dbs_on_date'])
    sample_dates = pt_oura_df['CT_timestamp'].dt.date
    # The comparisons are strict, so samples on the DBS-on date itself count toward
    # 'unique_days' but toward neither the pre- nor the post-DBS column.
    day_subsets = {
        'unique_days': pt_oura_df,
        'days pre-DBS': pt_oura_df[sample_dates < dbs_on_date],
        'days post-DBS': pt_oura_df[sample_dates > dbs_on_date],
    }
    for count_col, subset_df in day_subsets.items():
        oura_samples_per.loc[pt_id, count_col] = _count_days_with_samples(subset_df)

# Patients without any ring data get zeros; the last row sums every column.
oura_samples_per = oura_samples_per.fillna(0).astype(int)
oura_samples_per.loc['total'] = oura_samples_per.sum()
oura_samples_per

In [None]:
# Show when patients were wearing ring.
print('Patient was wearing ring in highlighted regions')

ncols = 3
df_vcvs = df.query('lead_location == "VC/VS"')
nrows = np.ceil(df_vcvs['pt_id'].nunique() / ncols).astype(int)
fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(20,16), sharey=True)

vcvs_groups = df_vcvs.groupby('pt_id')
for ax, (pt_id, pt_df) in zip(axs.flatten(), vcvs_groups):
    ax.scatter(pt_df['CT_timestamp'], pt_df['lfp_left_OvER_interpolate_z_scored'], s=2)
    utils.transform_timestamp_to_days(pt_df, ax)
    ax.set(xlabel='Days since DBS', ylabel='Residual Variance', title=pt_id)

    # Positions (within this patient's rows) where the ring was worn, split wherever
    # consecutive positions are not adjacent so that every chunk is one unbroken run.
    ring_inds = np.where(pt_df['wearing_ring'] == True)[0]
    run_breaks = np.where(np.diff(ring_inds) != 1)[0] + 1
    continuous_chunks = np.split(ring_inds, run_breaks)
    if ring_inds.size == 0:
        continue

    ring_times = pt_df['CT_timestamp']
    for chunk in continuous_chunks:
        ax.axvspan(ring_times.iloc[chunk[0]], ring_times.iloc[chunk[-1]], color='limegreen', alpha=0.3, ec=None)

# Hide the grid cells that did not receive a patient.
for ax in axs.flatten()[len(vcvs_groups):]:
    ax.axis('off')
fig.tight_layout()
plt.show()

In [None]:
# Put all activity and sleep data into a separate dataframe that has not been reduced to just that intersecting with neural data.
def _home_clock_time(row):
    """Clock time of the row's 5-minute UTC bin, expressed in the row's original time zone."""
    return row['utc_times'].tz_convert(row['original_time_zone']).time()


# Bin both Oura frames onto a shared 5-minute UTC grid and attach the local clock time.
for oura_frame, home_time_col in ((sleep_df, 'home_time_sleep'), (met_df, 'home_time_met')):
    oura_frame['utc_times'] = oura_frame['timestamp'].dt.ceil('5min')
    oura_frame[home_time_col] = oura_frame.apply(_home_clock_time, axis=1)

activity_and_sleep_df = pd.merge(
    sleep_df[['pt_id', 'utc_times', 'sleep_phase', 'home_time_sleep']],
    met_df[['pt_id', 'utc_times', 'met_score', 'home_time_met']],
    on=['pt_id', 'utc_times'],
    how='outer',
).query('met_score > 0.1')  # Remove non-wear data.

# Bins without a sleep phase are labelled 'awake', and bins without a sleep-derived clock
# time fall back to the MET-derived one.
activity_and_sleep_df = activity_and_sleep_df.fillna({'sleep_phase': 'awake'})
missing_sleep_time = activity_and_sleep_df['home_time_sleep'].isna()
activity_and_sleep_df.loc[missing_sleep_time, 'home_time_sleep'] = activity_and_sleep_df.loc[missing_sleep_time, 'home_time_met']
activity_and_sleep_df['sleep_state'] = np.where(activity_and_sleep_df['sleep_phase'] == 'awake', 'Awake', 'Asleep')
activity_and_sleep_df.shape

## Figure Creation

### NER Paper/SfN Figures

In [None]:
# Shared figure styling: serif text everywhere.
plt.rcParams.update({'font.family': 'serif'})

# Colours for overvoltage samples, regular samples, and corrected overvoltage samples.
overvoltage_color, nonovervoltage_color, corrected_overvoltage_color = 'crimson', 'coral', 'dodgerblue'

# Hemisphere whose LFP columns the figures below read.
hemi = 'left'

We added many `if True:` statements so that blocks of plotting code can be folded in VS Code. They have no functional effect and are purely cosmetic. We made all our figures in Python, without a vector-graphics editor such as Adobe Illustrator, because we wanted to be able to reproduce them easily and quickly. Illustrator also could not handle the number of data points in our figures, so we had to use Python.


In [None]:
# Figure 1.

fig = plt.figure(figsize=(13, 8))
gs = gridspec.GridSpec(nrows=3, ncols=2, figure=fig, height_ratios=[2,2, 1], width_ratios=[4, 1])


def _draw_overvoltage_panel(ax, patient_id, title, marker_size):
    """
    Draw one patient's raw LFP power scatter plus the daily overvoltage percentage.

    Rows of `df` for `patient_id` with raw power above 15 in the `hemi` hemisphere are
    scattered on `ax` (log y-axis), coloured by the `is_outlier_{hemi}` flag. The daily mean
    of that flag, in percent, is drawn on a twin y-axis.

    Returns the filtered patient frame and the twin axis.
    """
    panel_df = df.query(f'pt_id == @patient_id and lfp_{hemi}_raw > 15')
    # Non-overvoltage points first, overvoltage points on top.
    for is_outlier, color in ((False, nonovervoltage_color), (True, overvoltage_color)):
        pts = panel_df.query(f'is_outlier_{hemi} == {is_outlier}')
        ax.scatter(pts['CT_timestamp'], pts[f'lfp_{hemi}_raw'], s=marker_size, c=color, label='_nolegend_')

    twin_ax = ax.twinx()
    daily_outlier_percentage = panel_df.groupby(pd.Grouper(key='CT_timestamp', freq='D'))[f'is_outlier_{hemi}'].mean() * 100
    twin_ax.plot(daily_outlier_percentage.index, daily_outlier_percentage.values, color='darkslateblue', lw=1)

    ax.set(xlabel='Days Since DBS', ylabel='LFP Power', yscale='log', title=title)
    ax.yaxis.label.set_color('k')
    ax.tick_params(axis='y', colors='k', which='both')
    twin_ax.set(ylabel='Daily Overvoltage Percentage (%)', ylim=[-3, 100])
    twin_ax.yaxis.label.set_color('darkslateblue')
    twin_ax.tick_params(axis='y', colors='darkslateblue', which='both')
    utils.transform_timestamp_to_days(panel_df, ax)
    return panel_df, twin_ax


# Fig A-B: Scatter plots for patients B009 and B004
if True:
    subplot_AB_gs = gs[:-1, :-1].subgridspec(2, 1, hspace=0.4)
    axA = fig.add_subplot(subplot_AB_gs[0, 0])
    axB = fig.add_subplot(subplot_AB_gs[1, 0], sharey=axA)

    # Fig A: Draw scatter plot for patient B009 (3387 leads)
    if True:
        s = 0.15
        id = 'B009'
        pt_df, axA_twin = _draw_overvoltage_panel(axA, id, f'{id} (3387 Leads)', s)

    # Fig A callout: zoom in on outliers
    if True:
        # Still patient B009: `pt_df` and `id` come from the Fig A block above.
        callout_df = pt_df.query('days_since_dbs > 1530')
        subplot_A_callout = gs[0, -1].subgridspec(1, 1)
        axA_callout = fig.add_subplot(subplot_A_callout[0, 0])
        # Uses `hemi` like the main panel (the callout previously hard-coded the left hemisphere).
        for is_outlier, color in ((True, overvoltage_color), (False, nonovervoltage_color)):
            callout_pts = callout_df.query(f'is_outlier_{hemi} == {is_outlier}')
            axA_callout.scatter(callout_pts['CT_timestamp'], callout_pts[f'lfp_{hemi}_raw'],
                                s=0.1, c=color, label='_nolegend_')
        utils.transform_timestamp_to_days(callout_df, axA_callout)
        axA_callout.ticklabel_format(style='scientific', axis='y', scilimits=(0, 0))
        axA_callout.set(xlabel='Days Since DBS', ylabel='LFP Power', title=f'{id}', ylim=[-1e8, axA_callout.get_ylim()[1]],
                        yticks=[0, 1e9, 2e9], yticklabels=['0', r'10$^\text{9}$', r'2$\times$10$^\text{9}$'])

        # Right-hand axis: 0..60 in steps of 5, spread linearly over the power range 0..2**32-1.
        axA_callout_right = axA_callout.twinx()
        yticks = np.linspace(0, 2**32-1, 61, endpoint=True)[::5]
        axA_callout_right.set_yticks(yticks)
        axA_callout_right.set_yticklabels(np.arange(0, 61, 5))
        axA_callout_right.set_ylim(axA_callout.get_ylim())
        axA_callout_right.set_ylabel('Number of Overvoltages per Interval')

        for spine in (*axA_callout.spines.values(), *axA_callout_right.spines.values()):
            spine.set_color('b')

        # Blue box on panel A marking the region shown in the callout
        # (positioned relative to the right edge of panel A's x-axis).
        y_start_callout_box = 1e4
        rect = mpatches.Rectangle((axA.get_xlim()[1]-65, y_start_callout_box), 46, 3e9, color='b', zorder=0, lw=0.5, fill=False)
        axA.add_patch(rect)

    # Fig B: Draw scatter plot for patient B004 (SenSight leads)
    if True:
        id = 'B004'
        pt_df, axB_twin = _draw_overvoltage_panel(axB, id, f'{id} (SenSight Leads)', s)

# Fig C: Box plots
if True:
    subplot_C_gs = gs[1:, -1:].subgridspec(2, 1, height_ratios=[4, 1])
    axC, axC_inset = fig.add_subplot(subplot_C_gs[0, 0]), fig.add_subplot(subplot_C_gs[1, 0])
    noise_amt, hem_spread = 0.03, 0.02

    # Overvoltage percentage of every VC/VS patient hemisphere that has raw data, grouped by
    # lead model. Computed once and drawn on both the main axis and the inset.
    lead_model_percents = []
    for lead_model, lead_df in df.query('lead_location == "VC/VS"').groupby('lead_model'):
        lead_outlier_percents_left, lead_outlier_percents_right = [], []
        for pt_id, pt_df in lead_df.groupby('pt_id'):
            for hemisphere, percents in (('left', lead_outlier_percents_left), ('right', lead_outlier_percents_right)):
                hemisphere_df = pt_df.dropna(subset=[f'lfp_{hemisphere}_raw'])
                if not hemisphere_df.empty:
                    percents.append(hemisphere_df[f'is_outlier_{hemisphere}'].mean() * 100)
        lead_model_percents.append((lead_outlier_percents_left, lead_outlier_percents_right))
    lead_model_vals = [np.concatenate([left, right]) for left, right in lead_model_percents]

    def _draw_lead_model_boxes(ax):
        """Draw one box per lead model on `ax`, with left/right hemisphere points overlaid."""
        for i, ((left_percents, right_percents), lead_outlier_percents) in enumerate(zip(lead_model_percents, lead_model_vals)):
            # The very large `whis` makes the whiskers span the full data range.
            ax.boxplot(lead_outlier_percents, positions=[i/2], whis=100000000, zorder=5)
            # NOTE(review): the jitter is multiplied by `i`, so the first lead model gets no
            # horizontal jitter, and both hemispheres reuse seed 42 - confirm this is intended.
            for percents, offset, color in ((left_percents, -hem_spread, 'darkgreen'),
                                            (right_percents, hem_spread, 'darkgoldenrod')):
                jitter = np.random.RandomState(42).uniform(-noise_amt, noise_amt, len(percents)) * i
                ax.scatter(np.array([i/2] * len(percents)) + offset + jitter,
                           percents, s=15, alpha=0.7, c=color, edgecolor='k', zorder=10)

    # Fig C: Draw box plots for overvoltage percentage by lead model
    if True:
        _draw_lead_model_boxes(axC)
        old_leads, new_leads, lead_location = 'LEAD_3387', 'LEAD_B33015', 'VC/VS'
        axC.set(xticks=[0,1/2],
                xticklabels=[f'3387 Lead\n(N={df.query("lead_model == @old_leads and lead_location == @lead_location")["pt_id"].nunique()})',
                             f'SenSight Lead\n(N={df.query("lead_model == @new_leads and lead_location == @lead_location")["pt_id"].nunique()})'],
                ylabel='Overvoltage Percentage (%)',
                xlim=[-.2, 0.7])
        # Blue box marking the near-zero band that the inset enlarges.
        rect = mpatches.Rectangle((axC.get_xlim()[0], -.2), (axC.get_xlim()[1] - axC.get_xlim()[0]), 0.8, color='b', zorder=0, lw=0.5, fill=False)
        axC.add_patch(rect)

        # Two-sided Mann-Whitney U test between the first two lead-model groups
        # (groupby order; assumes exactly the two models named in the tick labels).
        u_stat, p_value = stats.mannwhitneyu(lead_model_vals[0], lead_model_vals[1], alternative='two-sided')
        axC.text(0.65, 0.65, f'p={p_value:.3g}', transform=axC.transAxes, ha='center', va='bottom', fontsize=10)

    # Fig C inset: Draw inset box plots for overvoltage percentage by lead model close to 0
    if True:
        _draw_lead_model_boxes(axC_inset)
        axC_inset.set(ylim=[-0.05, 0.25], xticks=[0,1/2], xticklabels=['3387 Lead', 'SenSight Lead'], xlim=[-.2, 0.7])
        for spine in axC_inset.spines.values():
            spine.set_color('b')

# Fig D: Polar plots for patients with high outlier percentage
if True:
    high_outlier_ids = ['B009', 'B012', 'B015', 'B020']
    subplot_D_gs = gs[-1, :-1].subgridspec(1, len(high_outlier_ids)+1, width_ratios=[1] * len(high_outlier_ids) + [0.3])
    axs = [fig.add_subplot(subplot_D_gs[0, i], polar=True) for i in range(len(high_outlier_ids))]

    # Fig D: Draw polar plots for patients with high outlier percentage
    if True:
        num_wedges = 24*6
        seconds_per_day = 24 * 60 * 60
        # Every 10-minute clock time of a day, in wedge order. Per-time series are reindexed onto
        # this grid so that a time bin without data leaves its own wedge empty instead of
        # shifting every later value one wedge earlier.
        wedge_times = [(datetime.min + timedelta(minutes=10 * k)).time() for k in range(num_wedges)]
        angle_offset = 2 * np.pi / num_wedges # Offset by a wedge to correctly place the bars
        angles = np.linspace(0 + angle_offset, 2*np.pi + angle_offset, num_wedges, endpoint=False)

        for i, (ax, id) in enumerate(zip(axs, high_outlier_ids)):
            pt_df = df.query('pt_id == @id and (pt_id != "B015" or days_since_dbs >= 830)').copy()
            pt_df['times'] = pt_df['CT_timestamp'].dt.tz_convert('UTC').dt.ceil('10min').dt.tz_convert('US/Central').dt.time # have to sidestep to UTC to avoid issues with daylight savings time

            # NOTE(review): B009's clock times are shifted forward by one hour; the reason is
            # not visible in this cell - confirm. Only the time of day survives, so the date
            # used for the arithmetic does not matter.
            if id == "B009":
                pt_df['times'] = pt_df['times'].apply(lambda t: datetime.combine(datetime.today(), t) + timedelta(hours=1)).dt.time

            # Fraction of samples flagged as overvoltage per 10-minute clock time,
            # scaled so the largest wedge of each hemisphere has height 1.
            times_groups = pt_df.groupby('times')
            num_left_outliers = (times_groups['is_outlier_left'].sum() / times_groups['is_outlier_left'].count()).reindex(wedge_times, fill_value=0)
            num_right_outliers = (times_groups['is_outlier_right'].sum() / times_groups['is_outlier_right'].count()).reindex(wedge_times, fill_value=0)
            num_left_outliers /= num_left_outliers.max() if num_left_outliers.max() > 0 else 1
            num_right_outliers /= num_right_outliers.max() if num_right_outliers.max() > 0 else 1

            # Set midnight at the top and progress clockwise
            ax.set_theta_zero_location("N")
            ax.set_theta_direction(-1)

            # Draw the wedges
            for j, (angle, left_val, right_val) in enumerate(zip(angles, num_left_outliers.values, num_right_outliers.values)):
                ax.bar(angle, left_val, width=2*np.pi/num_wedges, alpha=1, color='C0', label=('Number of\nLeft Hemisphere\nOvervoltages' if j == 0 else '_nolegend_'))
                ax.bar(angle, right_val, width=2*np.pi/num_wedges, alpha=0.85, color='C8', label=('Number of\nRight Hemisphere\nOvervoltages' if j == 0 else '_nolegend_'))
            max_y = max(num_left_outliers.max(), num_right_outliers.max())

            # Draw line for when patient was awake
            # (per UTC bin: of the first three rows, keep the last one).
            pt_activity_sleep_df = activity_and_sleep_df.query('pt_id == @id').groupby('utc_times').head(3).groupby('utc_times').tail(1)
            awake_ratios = {}
            for time, group in pt_activity_sleep_df.groupby('home_time_sleep'):
                vc = group['sleep_state'].value_counts()
                awake_ratios[time] = vc['Awake'] / (vc['Awake'] + vc['Asleep']) if 'Awake' in vc and 'Asleep' in vc else (1 if 'Awake' in vc else 0)
            awake_ratios = pd.Series(awake_ratios)
            # Each point is placed at the angle of its own clock time rather than by position,
            # so a missing time bin does not rotate the rest of the curve.
            awake_angles = np.array([2 * np.pi * (t.hour * 3600 + t.minute * 60 + t.second) / seconds_per_day
                                     for t in awake_ratios.index])
            ax.plot(awake_angles, awake_ratios * max_y, color='purple', lw=1, label='Awake Ratio')

            # Clean up the plot
            ax.set(xticks=np.linspace(0, 2 * np.pi, 8, endpoint=False),
                xticklabels=['', '3:00', '6:00', '9:00', '12:00', '15:00', '18:00', '21:00'])
            if awake_ratios.notna().any():
                # One radial tick at the outer edge, labelled as an awake ratio of 100%.
                ax.set(yticks=[max_y], yticklabels=['100%'])
                outer_tick_label = ax.get_yticklabels()[0]
                outer_tick_label.set_horizontalalignment('right')
                outer_tick_label.set_verticalalignment('top')
            else:
                ax.set(yticks=[], yticklabels=[])
            ax.tick_params(axis='y', colors='purple')
            ax.set_rlabel_position(245)
            ax.set(title=id)

# Finish up figure canvas
fig.tight_layout()

# Display legends
if True:
    # Fig A legend
    if True:
        # Proxy markers stand in for the scatter points, which carry no legend labels.
        outlier_legend_dot = Line2D([0], [0], marker='o', color=overvoltage_color, label='Overvoltage', markersize=4, linestyle='None')
        nonoutlier_legend_dot = Line2D([0], [0], marker='o', color=nonovervoltage_color, label='Nonovervoltage', markersize=4, linestyle='None')
        handles = [outlier_legend_dot, nonoutlier_legend_dot, axA_twin.get_lines()[0]]
        labels = ['Overvoltage', 'Nonovervoltage', 'Overvoltage Percentage']

        # Anchor the legend at panel A's left edge, vertically midway between the bottom of
        # panel A's first x tick label and the top of panel B (all in figure coordinates).
        to_figure_coords = fig.transFigure.inverted()
        axA_extent = axA.get_window_extent().transformed(to_figure_coords)
        A_bottom = axA.get_xticklabels()[0].get_window_extent().transformed(to_figure_coords).y0
        B_top = axB.get_window_extent().transformed(to_figure_coords).y1
        A_left = axA_extent.x0
        fig.legend(
            handles=handles,
            labels=labels,
            bbox_to_anchor=(A_left, (A_bottom + B_top) / 2),
            loc='center left',
            fontsize=8,
        )
    # Fig D legend
    if True:
        # Placed to the right of the last polar axis.
        axs[-1].legend(loc='center left', bbox_to_anchor=(1.3, 0.5), fontsize=8)

def _draw_connector_lines(starts_disp, ends_disp):
    """Draw thin blue figure-level lines between paired points given in display coordinates."""
    to_figure_coords = fig.transFigure.inverted()
    for start, end in zip(starts_disp, ends_disp):
        (x_start, y_start), (x_end, y_end) = to_figure_coords.transform(start), to_figure_coords.transform(end)
        fig.add_artist(Line2D([x_start, x_end], [y_start, y_end], lw=0.5, color='b'))


# Draw lines connecting subfigure to inset subfigure for A.
if True:
    # From the right edge of the blue box on panel A to the left edge of the callout axis.
    box_right_x = axA.get_xlim()[1]-65+46
    start_disp = [
        axA.transData.transform((box_right_x, y_start_callout_box)),
        axA.transData.transform((box_right_x, y_start_callout_box+3e9)),
    ]
    callout_left_x = axA_callout.get_xlim()[0]
    end_disp = [
        axA_callout.transData.transform((callout_left_x, axA_callout.get_ylim()[0])),
        axA_callout.transData.transform((callout_left_x, axA_callout.get_ylim()[1])),
    ]
    _draw_connector_lines(start_disp, end_disp)

# Draw lines connecting box plot subfigure to inset subfigure for C.
if True:
    # From two fixed data points below panel C to the top corners of the inset axis.
    start_disp = [axC.transData.transform(point) for point in ((-.1, -.3), (0.6, -.3))]
    inset_top_y = axC_inset.get_ylim()[1]
    end_disp = [
        axC_inset.transData.transform((axC_inset.get_xlim()[0], inset_top_y)),
        axC_inset.transData.transform((axC_inset.get_xlim()[1], inset_top_y)),
    ]
    _draw_connector_lines(start_disp, end_disp)

# Add subfigure labels
if True:
    vert_offset, hor_offset = 0.01, -0.04
    fs=14
    to_figure_coords = fig.transFigure.inverted()
    # Each letter sits just above and to the left of its panel's top-left corner
    # (panel d is anchored on the first polar axis).
    for panel_letter, panel_ax in (('a', axA), ('b', axB), ('c', axC), ('d', axs[0])):
        panel_extent = panel_ax.get_window_extent().transformed(to_figure_coords)
        fig.text(panel_extent.x0+hor_offset, panel_extent.y1+vert_offset, panel_letter,
                 fontsize=fs, fontweight='bold', color='k', fontdict={'family': 'sans-serif'})

plt.show()

In [None]:
# Figure 2.
# Layout constants
violinplot_width_ratio = 0.15
inset_width_ratio = 0.4
wspace=0.02
pfontsize=9

fig = plt.figure(figsize=(12, 5))
gs = gridspec.GridSpec(nrows=2, ncols=3, figure=fig)

# Marker sizes for the full-range scatter (s1) and the inset scatter (s2)
s1, s2 = 0.15, 1

# Example patient, taken from the ring-worn rows in `oura_df`.
# NOTE(review): the corrected-signal columns are hard-coded to the left hemisphere while the
# raw column and outlier flag follow `hemi` - confirm before setting hemi to anything else.
id = 'B020'
pt_df = oura_df.query('pt_id == @id')
required_cols = [f'lfp_{hemi}_raw', 'lfp_left_OvER_interpolate', 'lfp_left_threshold_interpolate']
non_outlier_query = f'is_outlier_{hemi} == False and `lfp_left_OvER_interpolate` > 14'
outlier_query = f'is_outlier_{hemi} == True and `lfp_left_OvER_interpolate` > 14'
non_outlier_pts = pt_df.query(non_outlier_query).dropna(subset=required_cols, how='any')
outlier_pts = pt_df.query(outlier_query).dropna(subset=required_cols, how='any')

# Three-day window shown in the inset panels
inset_df = pt_df.query('days_since_dbs >= 1924 and days_since_dbs < 1927').dropna(subset=required_cols, how='any')
non_outlier_pts_inset = inset_df.query(non_outlier_query).dropna(subset=required_cols, how='any')
outlier_pts_inset = inset_df.query(outlier_query).dropna(subset=required_cols, how='any')

# The Gaussian smoothing in the inset panels uses sigma = window_size / 2 samples
window_size = 10

pts_w_outliers = ['B009', 'B012', 'B015', 'B020']

# Angular frequency of a 24-hour cycle (radians per hour) and the sliding-window length
omega = 2 * np.pi / 24
window_days = 5
half_window = pd.Timedelta(days=window_days / 2)

# Fig A (OvER)
if True:
    # Set up subplot A
    subplot_A_gs = gs[0, :2].subgridspec(1, 2, width_ratios=[1, inset_width_ratio], wspace=wspace)
    axA1 = fig.add_subplot(subplot_A_gs[0, 0])
    axA2 = fig.add_subplot(subplot_A_gs[0, 1], sharey=axA1)

    # AI: Scatter plot with OvER - regular samples, corrected overvoltages, then the raw
    # overvoltage values on top.
    for pts, y_col, color in (
        (non_outlier_pts, 'lfp_left_OvER_interpolate', nonovervoltage_color),
        (outlier_pts, 'lfp_left_OvER_interpolate', corrected_overvoltage_color),
        (outlier_pts, f'lfp_{hemi}_raw', overvoltage_color),
    ):
        axA1.scatter(pts['CT_timestamp'], pts[y_col], s=s1, c=color, label='_nolegend_')

    # AII: Show inset scatter plot for 3 days
    for pts, color in (
        (non_outlier_pts_inset, nonovervoltage_color),
        (outlier_pts_inset, corrected_overvoltage_color),
    ):
        axA2.scatter(pts['CT_timestamp'], pts['lfp_left_OvER_interpolate'], s=s2, c=color, label='_nolegend_')

    # Add smoothed line using Gaussian filter
    smoothed = gaussian_filter1d(inset_df['lfp_left_OvER_interpolate'], sigma=window_size/2)
    axA2.plot(inset_df['CT_timestamp'], smoothed, color='k', lw=0.5, label='_nolegend_')

    # The inset shares panel AI's y-axis, so its own y labels are hidden.
    axA1.set(ylabel='LFP Power', yscale='log', xlabel='Days Since DBS')
    axA2.tick_params(axis='y', labelleft=False)
    axA2.set(ylabel='', xlabel='Days Since DBS')
    utils.transform_timestamp_to_days(pt_df, axA1)
    utils.transform_timestamp_to_days(inset_df, axA2)

# Fig B (threshold and interpolate)
if True:
    # Set up subplot B (x-axes are shared with the matching Fig A panels, y-axis with axA1)
    subplot_B_gs = gs[1, :-1].subgridspec(1, 2, width_ratios=[1, inset_width_ratio], wspace=wspace)
    axB1 = fig.add_subplot(subplot_B_gs[0, 0], sharex=axA1, sharey=axA1)
    axB2 = fig.add_subplot(subplot_B_gs[0, 1], sharex=axA2, sharey=axB1)

    # BI: Scatter plot with threshold-and-interpolate correction - regular samples, corrected
    # overvoltages, then the raw overvoltage values on top.
    axB1.scatter(non_outlier_pts['CT_timestamp'], non_outlier_pts['lfp_left_threshold_interpolate'], s=s1, c=nonovervoltage_color, label='_nolegend_')
    axB1.scatter(outlier_pts['CT_timestamp'], outlier_pts['lfp_left_threshold_interpolate'], s=s1, c=corrected_overvoltage_color, label='_nolegend_')
    axB1.scatter(outlier_pts['CT_timestamp'], outlier_pts[f'lfp_{hemi}_raw'], s=s1, c=overvoltage_color, label='_nolegend_')

    # BII: Show inset scatter plot for 3 days
    axB2.scatter(non_outlier_pts_inset['CT_timestamp'], non_outlier_pts_inset['lfp_left_threshold_interpolate'], s=s2, c=nonovervoltage_color, label='_nolegend_')
    axB2.scatter(outlier_pts_inset['CT_timestamp'], outlier_pts_inset['lfp_left_threshold_interpolate'], s=s2, c=corrected_overvoltage_color, label='_nolegend_')

    # Add smoothed line using Gaussian filter (sigma = window_size / 2 samples)
    smoothed = gaussian_filter1d(inset_df['lfp_left_threshold_interpolate'], sigma=window_size/2)
    axB2.plot(inset_df['CT_timestamp'], smoothed, color='k', lw=0.5, label='_nolegend_')

    # The blank title keeps vertical space above the panel.
    axB1.set(ylabel='LFP Power', yscale='log', xlabel='Days Since DBS', title=' ')
    axB2.tick_params(axis='y', labelleft=False)
    axB2.set(ylabel='', xlabel='Days Since DBS')
    # Apply the days-since-DBS tick transform to this panel's own axes
    # (the first call previously targeted axA1 a second time).
    utils.transform_timestamp_to_days(pt_df, axB1)
    utils.transform_timestamp_to_days(inset_df, axB2)

def _fit_cosinor_peak(t_hours, y):
    """Fit `cosinor` to one window of samples and locate the fitted daily peak.

    Parameters
    ----------
    t_hours : array-like
        Time of day, in hours, of each sample.
    y : array-like
        Signal values. NaN samples (and their times) are dropped before fitting.

    Returns
    -------
    tuple or None
        ``(peak_fraction_of_day, peak_value, amplitude_p_value)``, or ``None``
        when there are fewer valid samples than fit parameters or when
        ``curve_fit`` does not converge (RuntimeError).
    """
    t_hours = np.asarray(t_hours, dtype=float)
    y = np.asarray(y, dtype=float)
    valid = ~np.isnan(y)
    t_hours, y = t_hours[valid], y[valid]

    # Three parameters are fitted (M, A, phi); with fewer samples np.max/curve_fit
    # raise errors that are not RuntimeError, so bail out early instead.
    n_params = 3
    if len(y) < n_params:
        return None

    # Initial guesses: M=mean, A=half range, phi=0
    p0 = [np.mean(y), (np.max(y) - np.min(y)) / 2, 0]
    try:
        popt, pcov = curve_fit(cosinor, t_hours, y, p0=p0)
    except RuntimeError:
        # fitting failed
        return None

    # Two-tailed t-test on each parameter; degrees of freedom come from the
    # number of samples actually used in *this* fit.
    perr = np.sqrt(np.diag(pcov))
    dof = max(0, len(y) - len(popt))
    t_stats = popt / perr
    p_values = 2 * (1 - stats.t.cdf(np.abs(t_stats), dof))

    M_fit, A_fit, phi_fit = popt
    # Normalise to a non-negative amplitude by shifting the phase half a cycle.
    if A_fit < 0:
        A_fit = -A_fit
        phi_fit += np.pi

    # Time of peak: when cos is 1 => wt + phi = 0 mod 2pi => t_peak = -phi / w
    # (assumes cosinor(t, M, A, phi) = M + A*cos(w*t + phi) with a 24 h period --
    # TODO confirm against the cosinor definition)
    omega = 2 * np.pi / 24
    t_peak = (-phi_fit % (2 * np.pi)) / omega  # in hours

    return t_peak / 24, M_fit + A_fit, p_values[1]


# Fig C (polar plots showing circadianness of each method)
if True:
    # Set up subplot C: a 2x2 grid of polar axes, one per patient with outliers
    subplot_C_gs = gs[:, -1].subgridspec(2, 2, wspace=0.7)
    subplot_C_axes = []
    for i, pt_id in enumerate(pts_w_outliers):
        ax = fig.add_subplot(subplot_C_gs[i // 2, i % 2], polar=True)
        pt_df = df.query('pt_id == @pt_id')

        peak_times_OvER, peak_times_thresh = [], []
        p_val_amps_OvER, p_val_amps_thresh = [], []
        unique_days = pt_df['CT_timestamp'].dt.floor('D').unique()
        for center_day in unique_days:
            # Window of +/- half_window around noon of the current day
            window_mid = center_day + pd.Timedelta(hours=12)
            in_window = (pt_df['CT_timestamp'] >= window_mid - half_window) & (pt_df['CT_timestamp'] < window_mid + half_window)
            window_df = pt_df[in_window]

            ts = window_df['CT_timestamp']
            t_hours = (ts.dt.hour + ts.dt.minute / 60 + ts.dt.second / 3600).values

            # Fit each correction method independently, so a failed fit for one
            # method does not discard the other method's result for this day.
            for column, peaks, p_vals in (
                ('lfp_left_OvER_interpolate', peak_times_OvER, p_val_amps_OvER),
                ('lfp_left_threshold_interpolate', peak_times_thresh, p_val_amps_thresh),
            ):
                fit = _fit_cosinor_peak(t_hours, window_df[column].values)
                if fit is None:
                    continue
                peak_fraction, peak_value, p_amp = fit
                p_vals.append(p_amp)  # p-value for amplitude
                peaks.append([peak_fraction, peak_value])  # Store peak time (fraction of day) and peak value

        # Shape (n_fits, 2) even when no fit succeeded, so the slicing below is safe.
        peak_times_OvER = np.asarray(peak_times_OvER, dtype=float).reshape(-1, 2)
        peak_times_thresh = np.asarray(peak_times_thresh, dtype=float).reshape(-1, 2)

        all_radii = np.concatenate([peak_times_OvER[:, 1], peak_times_thresh[:, 1]])
        r_max = 1.2 * all_radii.max() if all_radii.size else 1.0

        ax.scatter(peak_times_OvER[:, 0] * 2 * np.pi, peak_times_OvER[:, 1], s=3, c='C2', label='OvER Peaks', alpha=0.5)
        ax.scatter(peak_times_thresh[:, 0] * 2 * np.pi, peak_times_thresh[:, 1], s=3, c='C6', label='Threshold Peaks', alpha=0.5)
        ax.set(
            title=pt_id,
            xticks=np.linspace(0, 2 * np.pi, 8, endpoint=False),
            xticklabels=['', '3:00', '6:00', '9:00', '12:00', '15:00', '18:00', '21:00'],
            ylim=[0, r_max],
            yticks=np.linspace(0, r_max, 5),
            yticklabels=[''] * 5,
        )
        # Clock-style layout: midnight at the top, time increasing clockwise
        ax.set_theta_zero_location("N")
        ax.set_theta_direction(-1)
        subplot_C_axes.append(ax)

fig.tight_layout()

# Add titles
if True:
    # Axes positions are read after tight_layout so they reflect the final layout.
    # The title heights are computed here from the axes themselves rather than read
    # from names that are only assigned later in this cell, so the cell runs on a
    # fresh kernel.
    bb1 = axA1.get_position()
    bb2 = axA2.get_position()
    # Centre both titles horizontally over the span of the main + inset axes of row A
    x_center = (bb1.x0 + bb2.x1) / 2
    title_y_A = bb1.y1
    title_y_B = axB1.get_position().y1
    fig.text(x_center, title_y_A + 0.01, 'Overvoltage Event Removal (OvER)', ha='center', va='bottom', fontsize=12)
    fig.text(x_center, title_y_B + 0.01, 'Threshold and Interpolate', ha='center', va='bottom', fontsize=12)

# Legends
if True:
    def _legend_dot(color, label):
        """Return a marker-only Line2D proxy used as a legend entry."""
        return Line2D([0], [0], marker='o', color=color, label=label, markersize=4, linestyle='None')

    # Figure-level legend for the scatter colours used in panels A and B
    labels = ['Overvoltage', 'Corrected Overvoltage', 'Nonovervoltage']
    handles = [
        _legend_dot(overvoltage_color, labels[0]),
        _legend_dot(corrected_overvoltage_color, labels[1]),
        _legend_dot(nonovervoltage_color, labels[2]),
    ]

    # Extents of the main A and B axes in figure coordinates; the *_top / *_left
    # values are also read by the subfigure-label section below.
    A_extent = axA1.get_window_extent().transformed(fig.transFigure.inverted())
    B_extent = axB1.get_window_extent().transformed(fig.transFigure.inverted())
    A_top, A_left = A_extent.y1, A_extent.x0
    B_top, B_left = B_extent.y1, B_extent.x0
    leg = fig.legend(handles,
                     labels,
                     bbox_to_anchor=(B_left-.02, (A_top+.08)),
                     loc='center left',
                     fontsize=9)

    # Bounding box of all polar axes (figure coordinates). get_position() takes a
    # boolean `original` flag, not a transform, and already returns figure fractions.
    c_bbox = Bbox.union([c_ax.get_position() for c_ax in subplot_C_axes])

    # Attach the panel-C legend to the last polar axes explicitly rather than via
    # the loop variable left over from the plotting loop; it is positioned in
    # figure coordinates, centred above the whole panel.
    subplot_C_axes[-1].legend(
        handles=[
            _legend_dot('C2', 'OvER Peaks'),
            _legend_dot('C6', 'Threshold Peaks'),
        ],
        labels=['OvER Peaks', 'Threshold Peaks'],
        loc='upper center',
        bbox_to_anchor=(c_bbox.x0 + c_bbox.width / 2, c_bbox.y1 + 0.13),
        bbox_transform=fig.transFigure,
        fontsize=9
    )

# Add subfigure labels
if True:
    # Panel C is anchored to the union bounding box of its polar axes
    C_top, C_left = c_bbox.y1, c_bbox.x0
    vert_offset, hor_offset = 0.01, -0.07
    label_style = dict(fontsize=14, fontweight='bold', color='k', fontdict={'family': 'sans-serif'})
    # (x, y, letter) for each panel, in figure coordinates; panel C uses its own
    # horizontal nudge instead of hor_offset.
    a_label, b_label, c_label = (
        fig.text(label_x, label_y, letter, **label_style)
        for label_x, label_y, letter in (
            (A_left+hor_offset, A_top+vert_offset, 'a'),
            (B_left+hor_offset, B_top+vert_offset, 'b'),
            (C_left-.03, C_top+vert_offset, 'c'),
        )
    )

# Box out callout portion (pink) on AI and BI
if True:
    # Time span covered by the inset, shaded on both main axes
    callout_span = [inset_df['CT_timestamp'].min(), inset_df['CT_timestamp'].max()]
    for main_ax in (axA1, axB1):
        # fill_between can trigger autoscaling, so remember each axis's y-limits
        # and put them back afterwards (for both panels, not only A).
        y_lo, y_hi = main_ax.get_ylim()
        main_ax.fill_between(
            callout_span,
            y_lo, y_hi,
            color='violet', alpha=0.25, label='_nolegend_', edgecolor=None, zorder=0)
        main_ax.set_ylim(y_lo, y_hi)

    # Tint the inset axes with the same colour so they read as the zoomed-in region
    callout_rgba = mcolors.to_rgba('violet', alpha=0.25)
    for inset_ax in (axA2, axB2):
        inset_ax.set_facecolor(callout_rgba)

plt.show()

In [None]:
# Layout constants: `ncols` data columns plus one narrow column for the colorbar
ncols = 8
vmax = 8
c1, c2 = 'lightcoral', 'lightblue'

fig = plt.figure(figsize=(15, 7))
gs = gridspec.GridSpec(nrows=2, ncols=ncols+1, figure=fig, hspace=0.5, wspace=0.5, width_ratios=[1]*ncols+[0.1])

# Fig A: Violin plots
if True:
    subplot_A = gs[0, :2].subgridspec(1, 1)
    axA = fig.add_subplot(subplot_A[0, 0])

    # Shrink the axes: keep the left edge, raise the bottom edge, scale width and height
    box = axA.get_position()
    width_ratio, height_ratio = 0.6, 0.9
    new_bottom = box.y0+(box.y1-box.y0)*(1-height_ratio)
    axA.set_position([box.x0, new_bottom, box.width * width_ratio, box.height * height_ratio])

    # Split samples by whether either hemisphere is flagged as an outlier; only
    # rows with a non-null OvER-interpolated value for `hemi` are kept.
    has_lfp = oura_df[f'lfp_{hemi}_OvER_interpolate'].notna()
    neither_outlier = (~oura_df['is_outlier_left']) & (~oura_df['is_outlier_right'])
    either_outlier = (oura_df['is_outlier_left']) | (oura_df['is_outlier_right'])
    df1 = oura_df[neither_outlier & has_lfp]
    df2 = oura_df[either_outlier & has_lfp]

    x = df1['max_met'].values
    y = df2['max_met'].values

    # Both distributions are drawn at the same x position so they overlap
    violin_kwargs = dict(positions=[0], showextrema=False, side='both', points=1000)
    parts1 = axA.violinplot(x, **violin_kwargs)
    parts2 = axA.violinplot(y, **violin_kwargs)

    utils.make_violin_plot_pretty(parts1, nonovervoltage_color, np.median(x), axA, alpha=0.9)
    utils.make_violin_plot_pretty(parts2, corrected_overvoltage_color, np.median(y), axA, alpha=0.5)

    axA.set(xticks=[], ylabel='MET (multiples of RMR)', ylim=[0, 17])
    axA.tick_params('x', tick1On=False, tick2On=False)

    # Proxy patches so the legend shows the two violin colours
    fake_handles = [mpatches.Patch(color=patch_color)
                    for patch_color in (nonovervoltage_color, corrected_overvoltage_color)]
    axA.legend(fake_handles, ['Nonovervoltages', 'Corrected Overvoltages'], fontsize=10, bbox_to_anchor=(0.5, -0.01), loc='upper center')

    # P-value (two-sided), annotated just below the top of the axes
    p_val = stats_utils.ks_test_with_downsampling(x, y)['p_value']
    annotation_y = axA.get_ylim()[1] * 0.98
    axA.annotate(f'p = {p_val:.3g}', (0, annotation_y), fontsize=10, horizontalalignment='center', verticalalignment='top', color='k')

spearman_vals = []

# Only samples with raw LFP power above 20 are plotted (rows with missing raw
# power drop out too, since NaN > 20 is False).
# NOTE(review): the threshold of 20 is unexplained here -- confirm its origin.
groups = oura_df.query(f'lfp_{hemi}_raw > 20').groupby(['lead_model', 'pt_id'])

# Group by lead model. A patient is included only if it actually has a group in
# `groups`, so the get_group lookup below cannot raise KeyError for a patient
# whose raw values are all filtered out.
group_order = []
for lead_model in ['LEAD_3387', 'LEAD_B33015']:
    for pt_id in oura_df.query('lead_model == @lead_model')['pt_id'].unique():
        if (lead_model, pt_id) in groups.groups:
            group_order.append((lead_model, pt_id))

# Make plots B-C
pt_corr_dict = {}
for i, (lead_model, pt_id) in enumerate(group_order):
    pt_df = groups.get_group((lead_model, pt_id)).dropna(subset=f'lfp_{hemi}_OvER_interpolate')

    # Set up subplot: the first two grid cells belong to panel A, hence the +2 offset.
    # ax2 is the upper axes of the pair, ax1 the lower one.
    grid_idx = i + 2
    subplot_gs = gs[grid_idx//ncols, grid_idx%ncols].subgridspec(2, 1, hspace=1)
    ax1 = fig.add_subplot(subplot_gs[1])
    ax2 = fig.add_subplot(subplot_gs[0])

    # Label the y-axis only on the first patient panel and on panels in grid column 0
    y_label = 'LFP Power' if (i==0 or ((grid_idx%ncols)==0)) else ''

    # Bottom: Scatter plot for patient comparing MET with LFP power
    x, y = pt_df[['max_met', f'lfp_{hemi}_OvER_interpolate']].values.T
    ax1.scatter(x, y, c=np.array([nonovervoltage_color, corrected_overvoltage_color])[pt_df[f'is_outlier_{hemi}'].astype(int)], s=0.1)

    results = stats_utils.spearman_with_downsampling(x, y)
    rho = results['spearman_r']
    p_val = results['p_value']
    spearman_vals.append((p_val, rho))

    # Least-squares straight line through the (MET, LFP) points, drawn as a dashed guide
    m, b = np.polyfit(x, y, 1)
    x_fit = np.linspace(0.9, 16, 1000)
    y_fit = m * x_fit + b
    ax1.plot(x_fit, y_fit, lw=1, c='rebeccapurple', ls='--')

    ax1.set(xlabel='MET', ylabel=y_label, xlim=[0,16], xticks=[0,8,16], yscale='log')
    ax1.annotate(f'R={rho:.2f}\np={p_val:.2g}', (0.95, 0.03), xycoords='axes fraction', fontsize=8, ha='right', va='bottom', color='k')

    # Top: Scatter plot showing LFP power over time colored by MET
    ax2.scatter(pt_df['CT_timestamp'], pt_df[f'lfp_{hemi}_raw'], c=pt_df['max_met'], s=1, cmap='viridis', vmin=0, vmax=vmax)
    ax2.set(title=f'{pt_id}', xlabel='Days Since DBS', ylabel=y_label, yscale='log')
    utils.transform_timestamp_to_days(pt_df, ax2, rotation=45)

    pt_corr_dict[pt_id] = (rho, p_val)

# Add colorbar in final column for MET
if True:
    # The narrow last grid column, spanning both rows, hosts the colorbar
    cbar_ax = fig.add_subplot(gs[:, -1])

    # Stand-alone mappable using the same colormap and 0..vmax range as the
    # time-series scatters; it carries no data of its own.
    norm = plt.Normalize(vmin=0, vmax=vmax)
    sm = plt.cm.ScalarMappable(norm=norm, cmap='viridis')
    sm.set_array([])

    cbar = fig.colorbar(
        sm,
        cax=cbar_ax,
        orientation='vertical',
        extend='max',
    )
    cbar.set_label('MET (multiples of RMR)', fontsize=12)

# Add subfigure labels
if True:
    vert_offset, hor_offset = 0.04, -0.04

    # get_grid_positions returns (bottoms, tops, lefts, rights) in figure coordinates.
    # NOTE(review): the "top" values below are computed as 1 - (bottom of a row)
    # rather than taken from the returned tops; that only coincides with the real
    # row tops when the figure margins are symmetric -- confirm this is intended.
    grid_bottoms, _, grid_lefts, _ = gs.get_grid_positions(fig)
    A_left, A_top = grid_lefts[0], 1-grid_bottoms[1]
    B_left, B_top = grid_lefts[2], 1-grid_bottoms[1]
    C_left, C_top = grid_lefts[0], 1-grid_bottoms[0]

    # Bold panel letters
    for letter, left, top in (('a', A_left, A_top), ('b', B_left, B_top), ('c', C_left, C_top)):
        fig.text(left+hor_offset, top+vert_offset, letter, fontsize=14, fontweight='bold', color='k', fontdict={'family': 'sans-serif'})

    # Lead-model captions placed just to the right of the 'b' and 'c' letters
    for caption, left, top in (('3387 Lead', B_left, B_top), ('SenSight Lead', C_left, C_top)):
        fig.text(left+hor_offset+0.02, top+vert_offset, caption, fontsize=14, color='k', fontdict={'family': 'serif'})

plt.show()