In [1]:
# load in the dataset of NSM fitted training data: latent vectors + OA labels
# fit the B-Score to this data

# load in the dataset of NSM fitted testing data: latent vectors + OA labels
# apply the B-Score to this data to get B-Scores

# plot the B-Scores by KL grade and other metrics, similar to what was done in
# the original B-Score paper

In [1]:
import ast
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression

In [14]:
# Directory that holds one demographics CSV per visit
folder_demographics = '/dataNAS/people/aagatti/projects/OAI_DESS/aging_trajectories/data/demographics'

# Visit label -> CSV filename inside folder_demographics
dict_demographic_filenames = {
    'baseline': '0_demographics_baseline.csv',
    '12_month': '1_demographics_12_month.csv',
    '24_month': '3_demographics_24_month.csv',
    '36_month': '5_demographics_36_month.csv',
    '48_month': '6_demographics_48_month.csv',
}

# Intended use: load the demographics for baseline, 12, 24, and 48 months and
# collect the IDs of subjects (and knees) that are healthy at every timepoint
# (kl = 0) or have OA at every timepoint (kl >= 2).
# Only the baseline visit is iterated over below.
for timepoint in ('baseline',):
    path_demographics = os.path.join(
        folder_demographics,
        dict_demographic_filenames[timepoint],
    )
    df_baseline = pd.read_csv(path_demographics)

    # Per-knee key "<integer subject id>_<side>" so each knee is its own entry
    subject_id_str = df_baseline['id'].astype(int).astype(str)
    knee_side_str = df_baseline['side'].astype(str)
    df_baseline['id_knee'] = subject_id_str + '_' + knee_side_str

In [20]:
# Merge key "<integer subject id>_<side>", stored under the column name that
# the merge below joins on
baseline_subject_str = df_baseline['id'].astype(int).astype(str)
baseline_side_str = df_baseline['side'].astype(str)
df_baseline['id_side'] = baseline_subject_str + '_' + baseline_side_str

# Pickled table of fitted latents
path_df = '/dataNAS/people/aagatti/projects/OAI_DESS/fit_nsm/results/647_nsm_femur_v0.0.1/2000/00m/latents.pkl'
df_latents = pd.read_pickle(path_df)

# Left join: every latent row is kept and baseline demographics are attached
# wherever 'id_side' matches
df_latents = pd.merge(df_latents, df_baseline, how='left', on='id_side')

In [27]:
# List the column names of df_latents
df_latents.columns

Index(['subject_id', 'side_x', 'id_side', 'latent', 'Unnamed: 0', 'id',
       'side_y', 'visit_str', 'visit_number', 'kl', 'oa',
       'sex_1_male_2_female', 'age_y', 'height_mm', 'weight_kg', 'bmi_kg_m_2',
       'koos_pain', 'womac_pain', 'osteophytes_fem_ant_lat_score',
       'osteophytes_fem_ant_med_score', 'osteophytes_fem_cent_lat_score',
       'osteophytes_fem_cent_med_score', 'osteophytes_fem_post_lat_score',
       'osteophytes_fem_post_med_score', 'cart_fem_ant_lat_thinning',
       'cart_fem_ant_med_thinning', 'cart_fem_cent_lat_thinning',
       'cart_fem_cent_med_thinning', 'cart_fem_post_lat_thinning',
       'cart_fem_post_med_thinning', 'cart_fem_ant_lat_full_thickness',
       'cart_fem_ant_med_full_thickness', 'cart_fem_cent_lat_full_thickness',
       'cart_fem_cent_med_full_thickness', 'cart_fem_post_lat_full_thickness',
       'cart_fem_post_med_full_thickness', 'id_knee'],
      dtype='object')

In [29]:
# Convert strings in the 'latent' column to actual lists
def parse_latents(latent_str):
    """Decode a latent vector that may be stored as a string literal.

    Parameters
    ----------
    latent_str : object
        One cell of the 'latent' column. A ``str`` is parsed as a Python
        literal (e.g. ``"[0.1, -0.2]"``); any other value is treated as
        already decoded.

    Returns
    -------
    object
        The parsed literal for string input, otherwise ``latent_str`` itself,
        unchanged.

    Raises
    ------
    ValueError, SyntaxError
        If a string is not a valid Python literal.
    """
    if not isinstance(latent_str, str):
        return latent_str
    # literal_eval accepts only literals (lists, numbers, ...), so unlike
    # eval() a crafted string in the data cannot execute code.
    return ast.literal_eval(latent_str)

# Decode any string-encoded entries of the 'latent' column
df_latents['latent'] = df_latents['latent'].apply(parse_latents)

# One record per KL grade: {'kl': grade, 'avg_latent': mean latent vector}
average_latents = []

# Grades 0 through 4
for kl_value in range(5):
    # Latents of the rows whose 'kl' equals the current grade
    is_current_grade = df_latents['kl'] == kl_value
    filtered_latents = df_latents.loc[is_current_grade, 'latent']

    if filtered_latents.empty:
        # No rows for this grade: fall back to a zero vector of length 512
        avg_latent = np.zeros((512,))
    else:
        # Stack the vectors along a new leading axis, then average over it
        latents_array = np.stack(filtered_latents.values)
        avg_latent = latents_array.mean(axis=0)

    # Record the grade together with its average latent
    average_latents.append(dict(kl=kl_value, avg_latent=avg_latent))

{'kl': 0,
 'avg_latent': array([ 0.04150551, -0.04535307,  0.09482024,  0.02795899,  0.04194402,
         0.21857441, -0.05105849,  0.09845905,  0.01532099,  0.06209497,
         0.11780177, -0.01521489,  0.08931537, -0.01747702, -0.13059008,
        -0.01561932, -0.1437892 ,  0.09183716,  0.01452013,  0.02980872,
         0.04787097, -0.07427579, -0.08674766,  0.25792269, -0.15473316,
         0.12042286,  0.09352114,  0.1244296 ,  0.01659941,  0.12566314,
        -0.11473437, -0.03677314, -0.02755915,  0.11301973,  0.04249253,
         0.07063909,  0.00182471,  0.13957646, -0.11157266,  0.01213868,
         0.08039055,  0.01679471,  0.0234755 ,  0.2111812 , -0.19175165,
         0.04032886, -0.07674061, -0.04406363, -0.06391805,  0.02234597,
         0.15672574,  0.01488933,  0.14069293, -0.06730838,  0.07705805,
         0.11968234, -0.03821859, -0.12704771,  0.02443055, -0.15629584,
        -0.03279193,  0.01963031, -0.06727861, -0.11395743,  0.09556802,
         0.04888296,  0.015