In [1]:
import sys
sys.path.append('/autofs/homes/005/fd881/repos/MedImaging-ModelDriftMonitoring/src')

In [2]:
import pandas as pd
import os
from model_drift import mgb_locations
import matplotlib.pyplot as plt
from datetime import timedelta
import seaborn as sns

In [3]:
# Function to generate windows
def generate_windows(start_date, end_date, window_size_days=14, stride_days=1):
    """Yield ``(window_start, window_end)`` pairs sweeping from ``start_date`` to ``end_date``.

    The window end starts at ``start_date`` and moves forward by
    ``stride_days`` on every step. The window start stays fixed until the span
    would exceed ``window_size_days``; from then on it moves forward by the
    same stride. The sequence therefore begins with growing "warm-up" windows
    that all share the same start (the very first one has zero length) and
    continues with sliding windows.

    Args:
        start_date: Start of the first window. Must support ``+ timedelta``
            and subtraction yielding a ``timedelta`` (e.g. ``date``,
            ``datetime``, ``pd.Timestamp``).
        end_date: Largest window end that is still yielded (inclusive).
        window_size_days: Span, in days, above which the window start begins
            to advance.
        stride_days: Number of days the window moves per step; must be > 0.

    Yields:
        Tuples ``(window_start, window_end)``.

    Raises:
        ValueError: If ``stride_days`` is not positive. Because this is a
            generator, the error surfaces on the first iteration.
    """
    # A zero or negative stride never moves current_end past end_date, which
    # would make the loop below yield forever.
    if stride_days <= 0:
        raise ValueError(f"stride_days must be positive, got {stride_days!r}")

    step = timedelta(days=stride_days)
    current_end = start_date
    while current_end <= end_date:
        yield start_date, current_end
        current_end += step
        if (current_end - start_date).days > window_size_days:
            start_date += step

In [4]:
# Label names, in the order used to name the expanded "activation.<name>" and
# "label.<name>" columns.
label_cols = [
    'Atelectasis',
    'Cardiomegaly',
    'Edema',
    'Lung Opacity',
    'Pleural Other',
    'Pleural Effusion',
    'Pneumonia',
    'Pneumothorax',
    'Support Devices',
]

In [5]:
# Load the label table and reduce StudyDate to plain calendar dates: the
# column is parsed as datetimes first, then only the date part is kept.
labels_df = pd.read_csv(mgb_locations.labels_csv).assign(
    StudyDate=lambda frame: pd.to_datetime(frame['StudyDate']).dt.date
)

In [6]:
# Load the predictions (one JSON object per line) from preds.jsonl in pred_folder.
pred_folder = '/autofs/cluster/qtim/projects/xray_drift/inferences/classification_final_allpoc_inference_woconsolidation/'
df_preds = pd.read_json(os.path.join(pred_folder, 'preds.jsonl'), lines=True)

# Spread the per-row 'activation' and 'label' sequences into one column per
# entry of label_cols ("activation.<name>", then "label.<name>") and append
# them to the frame. Assumes each sequence has exactly len(label_cols) items.
for list_col in ('activation', 'label'):
    expanded = pd.DataFrame(
        df_preds[list_col].values.tolist(),
        columns=[f"{list_col}.{c}" for c in label_cols],
    )
    df_preds = pd.concat([df_preds, expanded], axis=1)

In [7]:
df_dicom = pd.read_csv(mgb_locations.dicom_inventory_csv)


def make_index(row: pd.Series):
    """Return the key "<PatientID>_<AccessionNumber>_<SOPInstanceUID>" for one row."""
    parts = (row.PatientID, row.AccessionNumber, row.SOPInstanceUID)
    return "_".join(f"{part}" for part in parts)


# df_dicom only has anonymized dates, so the study dates are pulled from labels_df
study_dates = labels_df.loc[:, ['StudyInstanceUID', 'StudyDate']].copy()

# Replace the inventory's StudyDate with the one from the labels. merge() uses
# its default inner join, so rows without a matching StudyInstanceUID are dropped.
df_dicom = df_dicom.drop(columns=["StudyDate"]).merge(
    study_dates,
    left_on="StudyInstanceUID",
    right_on="StudyInstanceUID",
)

# Row-wise key used to join the inventory with the predictions.
df_dicom["index"] = df_dicom.apply(make_index, axis=1)


In [8]:
# Attach the DICOM inventory columns to the predictions via the shared "index" key.
df_preds = df_preds.merge(df_dicom, on="index")

# Get the original accession number from the crosswalk
crosswalk = pd.read_csv(mgb_locations.crosswalk_csv, dtype={"ANON_AccNumber": int})
crosswalk = crosswalk[["ANON_AccNumber", "ORIG_AccNumber"]]

# Get the other metadata from the reports table (every column is read as str)
report_columns = [
    "Accession Number",
    "Point of Care",
    "Patient Sex",
    "Patient Age",
    "Is Stat",
    "Exam Code",
]
reports = pd.read_csv(mgb_locations.reports_csv, dtype=str)
reports = reports[report_columns].copy()

# Two left joins, applied in order: anonymized -> original accession number,
# then original accession number -> report metadata. Left joins keep every
# prediction row; validate="many_to_one" makes pandas raise if a lookup table
# has duplicate keys.
for lookup, left_key, right_key in (
    (crosswalk, "AccessionNumber", "ANON_AccNumber"),
    (reports, "ORIG_AccNumber", "Accession Number"),
):
    df_preds = df_preds.merge(
        lookup,
        how="left",
        left_on=left_key,
        right_on=right_key,
        validate="many_to_one",
    )

# Normalise StudyDate to plain calendar dates.
df_preds['StudyDate'] = pd.to_datetime(df_preds['StudyDate']).dt.date

In [None]:
# Exclude lateral views (ViewPosition == 'LL'). Rows with a missing
# ViewPosition are kept, because NaN != 'LL' evaluates to True.
# The explicit .copy() makes the filtered frame independent of the original,
# so assigning new columns to df_preds afterwards does not trigger pandas'
# SettingWithCopyWarning.
df_preds = df_preds[df_preds['ViewPosition'] != 'LL'].copy()
df_preds['ViewPosition'].value_counts()


In [10]:
# Make a combined point-of-care label: every value containing the substring
# 'OP' (case-sensitive) is replaced by 'MGH IMG XR OPX'; all other values are
# kept as they are.
# 'Point of Care' arrives through left joins, so it can be NaN for rows
# without a matching report. `'OP' in x` raises TypeError on NaN; the
# vectorised .str.contains(..., na=False) treats missing values as "no match"
# and leaves them missing in the combined column.
point_of_care = df_preds['Point of Care']
has_op = point_of_care.str.contains('OP', regex=False, na=False)
df_preds['Point of Care_combined'] = point_of_care.mask(has_op, 'MGH IMG XR OPX')

In [None]:
# Site of interest and a display variant of its name with spaces replaced by
# '~' (used inside a mathtext title further down).
site = 'MGH IMG XR OP YAW6'
site_disp = '~'.join(site.split(' '))

# Number of rows per point of care
print(df_preds['Point of Care'].value_counts())

# The per-site filter is currently disabled, so df_er holds ALL sites:
#df_er = df_preds[df_preds['Point of Care'] == site].copy()
df_er = df_preds.copy()

# Index the working copy by StudyDate (as datetimes) for date-range selection.
df_er['StudyDate'] = pd.to_datetime(df_er['StudyDate'])
df_er = df_er.set_index('StudyDate')


In [None]:
exam_counts = {}

# Get the overall start and end dates
start_date = df_er.index.min()
end_date = df_er.index.max()

window_size_days = 30
# Loop through each time window and count exams.
# generate_windows first yields growing warm-up windows that all share the
# same start; because the dict is keyed by window_start, each later (longer)
# window overwrites the earlier entry for that start.
for window_start, window_end in generate_windows(start_date, end_date, window_size_days=window_size_days):
    # Half-open interval: window_end itself is excluded.
    window_data = df_er[(df_er.index >= window_start) & (df_er.index < window_end)]
    # Number of rows (exams) in the window, normalized by the window size
    count = window_data.shape[0] / window_size_days
    exam_counts[window_start] = {'ExamCount': count}

# Convert the exam counts dictionary to a DataFrame (one row per window start)
exam_counts_df = pd.DataFrame(exam_counts).T

# If needed, fill missing values with 0
exam_counts_df.fillna(0, inplace=True)

# Plot the per-day exam count against the window start date
plt.figure(figsize=(10, 6))  # Set the figure size for better readability

plt.plot(exam_counts_df.index, exam_counts_df['ExamCount'])

# Raw f-string: '\m' is not a valid escape sequence in a normal string literal
# and makes Python emit an invalid-escape warning; the resulting text is the same.
# NOTE(review): df_er is currently built from all sites (the per-site filter is
# commented out), so a title naming a single site may not match the data — confirm.
plt.title(rf'Exams per day at $\mathbf{{{site_disp}}}$')
plt.xlabel('Window Start Date')
plt.ylabel('Count per Day')
plt.xticks(rotation=45)
plt.tight_layout()


plt.show()


In [None]:
# Column whose value mix is tracked over time.
parameter = 'ViewPosition'

# Overall number of rows per value of that column
print(df_er[parameter].value_counts())

# Get the overall start and end dates
start_date = df_er.index.min()
end_date = df_er.index.max()

# Window length in days; also used to scale the counts below.
window_days = 30

# One entry per window, keyed by the window END date: raw per-value counts
# (normalize=False) divided by the window length.
proportion_data = {}
for window_start, window_end in generate_windows(start_date, end_date, window_size_days=window_days):
    in_window = (df_er.index >= window_start) & (df_er.index < window_end)
    window_data = df_er[in_window]
    proportions = window_data[parameter].value_counts(normalize=False)
    proportion_data[window_end] = proportions / window_days

# Rows = window end dates, columns = values of `parameter`; values absent from
# a window become 0.
machine_df = pd.DataFrame(proportion_data).T.fillna(0)


In [None]:
# Count the non-NaN entries in each column of machine_df.
non_nan_counts = machine_df.count()
print(non_nan_counts)

# Optional clean-up, currently disabled: drop the Fluorospot column, noted as
# having only two non-NaN rows.
#machine_df.drop(columns=['Fluorospot Compact FD'], inplace=True, errors='ignore')

In [None]:
# Show machine_df (rows: window end dates, columns: values of `parameter`) as
# the cell's output.
machine_df

In [None]:
# Stacked-area chart of the per-window share of each `parameter` value.
fig, ax = plt.subplots(figsize=(10, 6))

# Build one long colour list by concatenating several seaborn palettes, so
# there are enough colours for all columns.
palettes = ['deep', 'bright', 'pastel', 'muted', 'dark', 'colorblind']
colors = []
for pal in palettes:
    # Extend the color list with colors from each palette
    palette_colors = sns.color_palette(pal)
    if pal == 'deep':
        # Swap the first two colors
        palette_colors[0], palette_colors[1] = palette_colors[1], palette_colors[0]
    colors.extend(palette_colors)

# Normalise every row (window) so its values sum to 1.
machine_df_prop = machine_df.div(machine_df.sum(axis=1), axis=0).copy()

ax.stackplot(machine_df_prop.index, *machine_df_prop.T.values, labels=machine_df_prop.columns, colors=colors)
ax.set_ylim(0, 1)

#plt.title(f'{parameter} over Time at $\mathbf{{{site_disp}}}$')
ax.set_title(f'{parameter} over Time ')

# Dashed vertical marker at March 10, 2020
ax.axvline(pd.to_datetime('2020-03-10'), color='darkgray', linestyle='--', label='March 10, 2020')
ax.set_xlabel('Window End Date')
ax.set_ylabel('Proportion')
ax.legend()

# x-axis: start from November 2019 and place a tick every 2 months
start_date = pd.to_datetime('2019-11-01')
end_date = machine_df_prop.index.max()
date_range = pd.date_range(start=start_date, end=end_date, freq='2MS')
ax.set_xlim(start_date, end_date)
ax.set_xticks(date_range)
ax.set_xticklabels(date_range.strftime('%Y-%m'), rotation=45)

# Run tight_layout only after the final tick labels exist, so the rotated
# labels are taken into account before the figure is saved.
fig.tight_layout()

#plt.savefig(os.path.join(output_dir, 'view_position_over_time.png'))
fig.savefig('view_position_over_time.png')


plt.show()

In [17]:
# Performance table for Cardiomegaly; its 'date' and 'auroc' columns are
# plotted in the two-panel figure.
cardiomegaly_perf_csv = "/autofs/cluster/qtim/projects/xray_drift/drift_analyses/PLOTS/paper/drift_analysis_allpoc_emd_jackknife_helllinger_final_florence_PLOTS/performance_Cardiomegaly.csv"
df_cardiomegaly_perf = pd.read_csv(cardiomegaly_perf_csv)

In [None]:
# Two stacked panels sharing one time axis: value shares of `parameter` on
# top, Cardiomegaly AUROC below.
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 12), sharex=True)

# Dashed reference dates drawn on both panels, as (date, colour) pairs.
reference_lines = (
    ('2020-03-10', '#5088A1'),
    ('2020-01-01', 'darkblue'),
)

# --- Top panel: stacked proportions ---
ax1.stackplot(machine_df_prop.index, *machine_df_prop.T.values, labels=machine_df_prop.columns, colors=colors)
ax1.set_ylim(0, 1)
ax1.set_title(f'{parameter} over Time')
for day, shade in reference_lines:
    ax1.axvline(pd.to_datetime(day), color=shade, linestyle='--')
ax1.set_ylabel('Proportion')
ax1.legend(loc='upper right')#, bbox_to_anchor=(1, 0.5))

# --- Bottom panel: performance ---
perf_dates = pd.to_datetime(df_cardiomegaly_perf['date'])
ax2.plot(perf_dates, df_cardiomegaly_perf['auroc'], label='AUROC')
ax2.set_title('Cardiomegaly Performance over Time')
ax2.set_ylabel('AUROC')

# Set x-axis limits and ticks (every 2 months, starting November 2019)
start_date = pd.to_datetime('2019-11-01')
end_date = machine_df_prop.index.max()
date_range = pd.date_range(start=start_date, end=end_date, freq='2MS')
ax2.set_xlim(start_date, end_date)
ax2.set_xticks(date_range)
ax2.set_xticklabels(date_range.strftime('%Y-%m'), rotation=45, ha='right')

# Add the same vertical lines to the performance plot
for day, shade in reference_lines:
    ax2.axvline(pd.to_datetime(day), color=shade, linestyle='--')

# Legend built from the labelled artists only; the reference lines carry no
# label, so they are not listed.
handles, labels = ax2.get_legend_handles_labels()
ax2.legend(handles=handles, labels=labels, loc='upper right')

# Adjust layout and save figure
output_dir = '/autofs/cluster/qtim/projects/xray_drift/drift_analyses/PLOTS/paper/drift_analysis_allpoc_emd_jackknife_helllinger_final_florence_PLOTS/'
fig.tight_layout()
fig.savefig(os.path.join(output_dir, 'ViewPosition_time_performance.png'), dpi=600, bbox_inches='tight')
plt.show()