In [None]:
import os
import sys

# Make the repository's `src` package importable. Point MODEL_DRIFT_REPO at a different
# checkout instead of editing this notebook.
REPO_ROOT = os.environ.get(
    "MODEL_DRIFT_REPO",
    '/autofs/homes/005/fd881/repos/MedImaging-ModelDriftMonitoring/',
)
# Guard so re-running this cell does not append the same entry again.
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)


In [None]:
import warnings
warnings.filterwarnings("ignore")


In [None]:
from pathlib import Path
import pandas as pd

#import click
#import plotnine
from pycrumbs import tracked
import multiprocessing

from src.model_drift.data import mgb_data
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sn
from tqdm import tqdm
from scipy.stats import multivariate_normal
from src.model_drift import helpers, mgb_locations
import matplotlib.dates as mdates
from scipy.linalg import det, inv
import scipy
from datetime import timedelta



In [None]:
import os
# Display the kernel's current working directory (diagnostic only: the value is shown, not stored).
os.getcwd()

In [None]:
def split_on_date(df, splits, col=None):
    """Partition ``df`` chronologically at each date in ``splits``.

    Yields ``len(splits) + 1`` frames: the rows strictly before the earliest split date,
    then the rows in ``[split_i, split_{i+1})`` for each consecutive pair of split dates,
    and finally the rows on or after the last split date. ``splits`` may be given in any
    order; the dates are sorted first.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame to partition.
    splits : list-like
        Boundary dates, anything ``pd.to_datetime`` accepts.
    col : str, optional
        Column holding the dates. When ``None`` the frame's index is compared instead.

    Notes
    -----
    Rows whose date is missing (NaT) compare False against every boundary, so they appear
    in none of the yielded frames.
    """
    boundaries = pd.to_datetime(splits).sort_values()
    remaining = df
    for boundary in boundaries:
        dates = remaining.index if col is None else remaining[col]
        earlier = dates < boundary
        later = dates >= boundary
        yield remaining[earlier]
        remaining = remaining[later]
    yield remaining

In [None]:
def make_index(row: pd.Series):
    """Return the per-image key ``"<PatientID>_<AccessionNumber>_<SOPInstanceUID>"`` for one row."""
    parts = (row.PatientID, row.AccessionNumber, row.SOPInstanceUID)
    return "_".join(format(part) for part in parts)

# DICOM inventory. Its StudyDate column holds anonymized dates, so it is dropped here and the
# real StudyDate is taken from the labels file instead.
meta_df = pd.read_csv(
    mgb_locations.dicom_inventory_csv,
    index_col=0,
).drop(columns=["StudyDate"])
labels_df = pd.read_csv(
    mgb_locations.labels_csv,
    index_col=0,
)  # need real dates from this file
meta_df = meta_df.merge(
    labels_df,
    how="left",
    on=["StudyInstanceUID", "PatientID", "AccessionNumber"],
)

# Some metadata is from the RIS and is in the reports CSV
reports = pd.read_csv(mgb_locations.reports_csv, dtype=str)
reports = reports[
    [
        "Accession Number",
        "Point of Care",
        "Patient Sex",
        "Patient Age",
        "Is Stat",
        "Exam Code",
    ]
].copy()

# Crosswalk from the anonymized accession numbers (meta_df.AccessionNumber) to the original ones
# used in the reports. ORIG_AccNumber is read as str to match `reports` (read entirely as str):
# if pandas inferred it as numeric, the reports merge below would fail, because pandas refuses
# to merge numeric keys with string keys. Reading it as str also keeps any leading zeros.
crosswalk = pd.read_csv(
    mgb_locations.crosswalk_csv,
    dtype={"ANON_AccNumber": int, "ORIG_AccNumber": str},
)
crosswalk = crosswalk[["ANON_AccNumber", "ORIG_AccNumber"]]

meta_df = meta_df.merge(
    crosswalk,
    how="left",
    left_on="AccessionNumber",
    right_on="ANON_AccNumber",
    validate="many_to_one",
)
meta_df = meta_df.merge(
    reports,
    how="left",
    left_on="ORIG_AccNumber",
    right_on="Accession Number",
    validate="many_to_one",
)

meta_df["StudyDate"] = pd.to_datetime(meta_df["StudyDate"], format='%m/%d/%Y')

# Per-image join key "<PatientID>_<AccessionNumber>_<SOPInstanceUID>" (the same string make_index
# builds), matched against the "index" field of the VAE predictions. Built column-wise rather than
# with a slow row-wise apply.
meta_df["index"] = (
    meta_df["PatientID"].astype(str)
    + "_"
    + meta_df["AccessionNumber"].astype(str)
    + "_"
    + meta_df["SOPInstanceUID"].astype(str)
)

In [None]:
# Width of each embedding vector in the 'mu' column: 128 for the VAE latent means;
# set to 512 when using the ResNet features instead.
num_feat: int = 128

In [None]:
# Embedding source. The active file holds the VAE latent means; the commented path holds ResNet
# features. Whichever is used must match num_feat.
#vae_pred_file = '/autofs/cluster/qtim/projects/xray_drift/models/mgb/resnet_features/preds.jsonl' #resnet_features
vae_pred_file = '/autofs/cluster/qtim/projects/xray_drift/inferences/mgb_with_chexpert_model_vae_take2/preds.jsonl' #vae_features

vae_df = helpers.jsonl_files2dataframe([vae_pred_file], desc="reading VAE results", refresh_rate=.1)

# Expand each 'mu' vector into columns mu.000 ... mu.{num_feat-1}. pd.concat(axis=1) aligns on
# index labels, so the expanded frame reuses vae_df's own index. This keeps rows paired even if
# that index is not the default 0..n-1 RangeIndex.
mu_columns = pd.DataFrame(
    vae_df['mu'].tolist(),
    columns=[f"mu.{c:0>3}" for c in range(num_feat)],
    index=vae_df.index,
)
vae_df = pd.concat([vae_df, mu_columns], axis=1)

# Keep one prediction per image key.
vae_df = vae_df.drop_duplicates(subset="index")


In [None]:
merged_df = meta_df.merge(vae_df, on="index", how="left")

# Chronological train / validation / test partitions by StudyDate.
date_splits = [mgb_data.TRAIN_DATE_END, mgb_data.VAL_DATE_END]
train_df, val_df, test_df = split_on_date(merged_df, date_splits, col="StudyDate")

# Date-indexed copy of the full dataset for the drift windows.
merged_df = merged_df.set_index("StudyDate")
merged_df_use = merged_df.copy()


In [None]:
def kl_mvn(to, fr, eps=1e-6):
    """Calculate `KL(to||fr)`, where `to` and `fr` are pairs of means and covariance matrices.

    Parameters
    ----------
    to, fr : tuple
        Each a pair ``(m, S)``: ``m`` is a length-k mean vector and ``S`` a (k, k) covariance
        matrix. The caller's arrays are never modified.
    eps : float, default 1e-6
        Jitter added to the diagonal of both covariances before factorising. It keeps
        rank-deficient sample covariances (e.g. fewer samples than dimensions) factorisable.

    Returns
    -------
    float
        ``0.5 * [tr(S_fr^-1 S_to) + d^T S_fr^-1 d - k + ln det S_fr - ln det S_to]``
        with ``d = m_fr - m_to``.

    Raises
    ------
    numpy.linalg.LinAlgError
        If the regularised ``S_fr`` is not positive definite (from ``cho_factor``).
    ValueError
        If ``S_fr`` contains NaN/inf (``cho_factor`` checks finiteness).
    """
    m_to, S_to = to
    m_fr, S_fr = fr
    m_to = np.asarray(m_to, dtype=float)
    m_fr = np.asarray(m_fr, dtype=float)

    # Regularise copies. An in-place `S += ...` would modify the caller's matrices, so a
    # reference covariance reused across many calls would gain `eps` on its diagonal each time.
    S_to = np.asarray(S_to, dtype=float) + np.eye(np.shape(S_to)[0]) * eps
    S_fr = np.asarray(S_fr, dtype=float) + np.eye(np.shape(S_fr)[0]) * eps

    d = m_fr - m_to

    c, lower = scipy.linalg.cho_factor(S_fr)

    def solve(B):
        # S_fr^{-1} @ B via the Cholesky factor, avoiding an explicit inverse.
        return scipy.linalg.cho_solve((c, lower), B)

    term1 = np.trace(solve(S_to))
    # ln det S_fr from the Cholesky diagonal (S_fr = U^T U => ln det = 2 * sum ln diag(U)).
    term2 = 2.0 * np.sum(np.log(np.diag(c))) - np.linalg.slogdet(S_to)[1]
    term3 = d.T @ solve(d)
    return (term1 + term2 + term3 - len(d)) / 2.


In [None]:

# Keep only rows whose 'mu' holds an embedding list. After the left merge, images without a
# VAE prediction carry a non-list value there (presumably NaN).
val_df = val_df[val_df['mu'].apply(lambda x: isinstance(x, list))]
list_rows = merged_df_use['mu'].apply(lambda x: isinstance(x, list))
merged_df_use = merged_df_use[list_rows]

# Reference distribution: a Gaussian fitted to the validation-period embeddings.
reference_set = np.nan_to_num(
    np.array(val_df['mu'].tolist()).reshape((len(val_df), num_feat)),
    nan=0.0,
    posinf=0.0,
    neginf=0.0,
)
mean_vector_ref = reference_set.mean(axis=0)
covariance_matrix_ref = np.cov(reference_set, rowvar=False)

def generate_windows(start_date, end_date, window_size_days=30, stride_days=1):
    """Yield ``(window_start, window_end)`` pairs of sliding date windows.

    Windows are ``window_size_days`` long. The first starts at ``start_date``, each next one
    starts ``stride_days`` later, and iteration stops once a full window would end after
    ``end_date``. Works with any datetime-like that supports ``+ timedelta`` (``datetime``,
    ``pd.Timestamp``).

    Raises
    ------
    ValueError
        If ``stride_days`` is not positive; otherwise the start would never advance and the
        generator would never finish. Because this is a generator, it raises on the first
        ``next()``.
    """
    if stride_days <= 0:
        raise ValueError(f"stride_days must be positive, got {stride_days}")
    window = timedelta(days=window_size_days)
    step = timedelta(days=stride_days)
    current_start = start_date
    while current_start + window <= end_date:
        yield current_start, current_start + window
        current_start += step

# Convert index to datetime if not already
merged_df_use.index = pd.to_datetime(merged_df_use.index)

# A window needs at least 2 samples. np.cov uses ddof=1, so a single sample gives an all-NaN
# covariance, and the Cholesky factorisation in kl_mvn then raises.
# NOTE(review): with fewer samples than num_feat the window covariance is rank-deficient and only
# kl_mvn's small diagonal jitter keeps it factorisable, so KL values of sparse windows are
# dominated by that jitter; consider a larger minimum.
min_window_samples = 2

# Sort once by date and build the embedding matrix once, so each window below is a cheap
# contiguous slice. Rows without a date are excluded; the window comparisons exclude NaT anyway.
dated_df = merged_df_use[merged_df_use.index.notna()].sort_index(kind="stable")
dated_features = np.nan_to_num(
    np.array(dated_df['mu'].tolist()).reshape((len(dated_df), num_feat)),
    nan=0.0,
    posinf=0.0,
    neginf=0.0,
)

# Get the overall start and end dates
start_date = merged_df_use.index.min()
end_date = merged_df_use.index.max()

kl_results = []

for window_start, window_end in tqdm(generate_windows(start_date, end_date)):
    # Half-open window [window_start, window_end) located by binary search on the sorted dates.
    lo = dated_df.index.searchsorted(window_start, side="left")
    hi = dated_df.index.searchsorted(window_end, side="left")
    n_samples = hi - lo
    if n_samples < min_window_samples:
        continue

    comparison_set = dated_features[lo:hi]

    # Fit a Gaussian to the comparison set
    mean_vector_comp = comparison_set.mean(axis=0)
    covariance_matrix_comp = np.cov(comparison_set, rowvar=False)

    # KL(reference || window). A copy of the reference covariance is passed so it cannot be
    # altered across iterations by any diagonal regularisation done inside kl_mvn.
    kl_divergence = kl_mvn(
        (mean_vector_ref, covariance_matrix_ref.copy()),
        (mean_vector_comp, covariance_matrix_comp),
    )

    kl_results.append({'start_date': window_start, 'kl_divergence': kl_divergence, 'n_samples': n_samples})

# Explicit columns so the plotting cell still finds them when no window qualified.
kl_df = pd.DataFrame(kl_results, columns=['start_date', 'kl_divergence', 'n_samples'])


In [None]:
fig, ax = plt.subplots(figsize=(10, 6))
# One point per date window, placed at the window's start date.
ax.plot(kl_df['start_date'], kl_df['kl_divergence'])

# Monthly major ticks labelled as ISO dates, rotated to avoid overlap.
ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
ax.xaxis.set_major_locator(mdates.MonthLocator())
fig.autofmt_xdate()

ax.set_title("KL Divergence Over Time")
ax.set_xlabel("Date")
ax.set_ylabel("KL Divergence")
ax.grid(True)
plt.show()

In [None]:
# Keep only rows whose 'mu' holds an embedding list. This repeats the filtering in the
# date-window cell so this cell also runs on its own; the filter is idempotent.
list_rows = val_df['mu'].apply(lambda x: isinstance(x, list))
val_df = val_df[list_rows]
list_rows = merged_df_use['mu'].apply(lambda x: isinstance(x, list))
merged_df_use = merged_df_use[list_rows]

# Step 1: Define the reference set (validation embeddings) and fit the reference Gaussian.
# num_feat (not a literal 512) so this works for both the VAE and the ResNet features.
reference_set = np.array(val_df['mu'].tolist())
reference_set = reference_set.reshape((len(val_df), num_feat))
reference_set = np.nan_to_num(reference_set, nan=0.0, posinf=0.0, neginf=0.0)

mean_vector_ref = np.mean(reference_set, axis=0)
covariance_matrix_ref = np.cov(reference_set, rowvar=False)

# Step 2: Sliding windows over a fixed number of studies rather than a fixed time span.
window_size = 1000
stride = 500  # stride < window_size: consecutive windows share window_size - stride studies

# Order dated rows chronologically so each count-based window covers a contiguous period.
# NOTE(review): merged_df_use's row order comes from the input CSVs and is not guaranteed to be
# chronological, hence the sort; rows without a date cannot be placed in time and are dropped.
ordered_df = merged_df_use[merged_df_use.index.notna()].sort_index(kind="stable")
ordered_features = np.nan_to_num(
    np.array(ordered_df['mu'].tolist()).reshape((len(ordered_df), num_feat)),
    nan=0.0,
    posinf=0.0,
    neginf=0.0,
)

# Number of complete windows (non-positive when there are fewer rows than window_size, in which
# case range() yields nothing).
num_windows = (len(ordered_df) - window_size) // stride + 1
kl_list = []

for i in tqdm(range(num_windows)):
    window_start = i * stride
    comparison_set = ordered_features[window_start:window_start + window_size]

    # Fit a Gaussian to the comparison set
    mean_vector_comp = np.mean(comparison_set, axis=0)
    covariance_matrix_comp = np.cov(comparison_set, rowvar=False)

    # KL(reference || window). A copy of the reference covariance is passed so it cannot be
    # altered across iterations by any diagonal regularisation done inside kl_mvn.
    kl_divergence = kl_mvn(
        (mean_vector_ref, covariance_matrix_ref.copy()),
        (mean_vector_comp, covariance_matrix_comp),
    )

    kl_list.append(kl_divergence)

In [None]:
fig, ax = plt.subplots(figsize=(10, 6))
# One point per count-based window; use ax.bar(range(len(kl_list)), kl_list) for a bar plot.
ax.plot(kl_list)
ax.set_title("KL Divergence over Different Windows")
ax.set_xlabel("Window Number")
ax.set_ylabel("KL Divergence")
ax.grid(True)
plt.show()