In [None]:
import pandas as pd
import os
import numpy as np
import json
import seaborn as sns
import matplotlib.pyplot as plt
import copy
from IPython.display import display, Markdown

from utilities import data
from utilities.info import *

# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Echo the data directory so the reader sees where files are read from.
# NOTE(review): FILE_DIR is presumably provided by `from utilities.info import *` — confirm.
FILE_DIR

## Grab Data for Each Model

In [None]:
# Load Kiran's NLST predictions (all models + demographics) and the demographic
# column spec, then prepare scan-level data with prep_nlst_preds.
kiran_nodule = pd.read_csv(f"{FILE_DIR}/nlst_allmodels_demos.csv")

# The `with` block closes the file on exit, so no explicit close() is needed.
with open(f'{FILE_DIR}/nlst_democols.json') as json_data:
    kiran_demos_original = json.load(json_data)

kiran_data, kiran_demos, _ = data.prep_nlst_preds(kiran_nodule, kiran_demos_original, scanlevel=True, sybil=False, tijmen=False, bin_num=False)
# Deep-copy before mutating: prep_nlst_preds may hand back (parts of) the spec it
# was given, and later cells pass kiran_demos_original expecting it unmodified.
kiran_demos = copy.deepcopy(kiran_demos)
# Also compare the outcome label; the guard keeps 'label' from being added twice.
if 'label' not in kiran_demos['cat']['other']:
    kiran_demos['cat']['other'].append('label')
kiran_demos

In [None]:
# Load Sybil's NLST predictions (rows carry a `split` column: train/dev/test)
# and the demographic column spec used to compare those splits.
sybil_data = pd.read_csv(f"{FILE_DIR}/nlst_sybil_demos.csv")

# The context manager closes the file; an explicit close() is redundant.
with open(f'{FILE_DIR}/nlst_sybil_democols.json') as json_data:
    sybil_demos = json.load(json_data)

sybil_demos

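Split Kiran's scan-level data into the training and validation sets of Tijmen's linear layer: scans without a `Thijmen_mean` value form the training set, and scans with one form the validation set.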

In [None]:
# Scans with no `Thijmen_mean` value form the training set of Tijmen's linear
# layer; scans that have one form its validation set.
has_thijmen_pred = kiran_data['Thijmen_mean'].notna()
tijmen_train = kiran_data[~has_thijmen_pred]
print("train:", len(tijmen_train), "Scans")
tijmen_val = kiran_data[has_thijmen_pred]
print("val:", len(tijmen_val), "Scans")

In [None]:
# Partition Sybil's data on its `split` column and report each split's size.
sybil_splits = {}
for split_name in ['train', 'dev', 'test']:
    sybil_splits[split_name] = sybil_data.query(f'split == "{split_name}"')
    print(split_name, len(sybil_splits[split_name]), 'Scans')

## Model Training Sets

In [None]:
# Training data per model: Kiran's scan-level data, the scans used for Tijmen's
# linear layer, and Sybil's training split.
training_sets = dict(
    Kiran=kiran_data,
    Tijmen=tijmen_train,
    Sybil=sybil_splits['train'],
)

### Categorical columns

In [None]:
# Categorical prevalence differences between training sets (rows with a missing
# or zero `value` dropped), ranked both ways by the Kiran-vs-Sybil difference.
cat_demo_splits = (
    data.combine_diff_dfs(sybil_demos['cat'], data.diffs_category_prevalence, training_sets)
    .dropna(subset='value', axis=0)
    .query('value != 0')
)
for descending in (True, False):
    display(cat_demo_splits.sort_values(by='diff_Kiran_Sybil', ascending=not descending).head(10))

In [None]:
# Same ranking, restricted to the "demo" category.
demo_rows = cat_demo_splits.query('category == "demo"')
display(demo_rows.sort_values(by='diff_Kiran_Sybil', ascending=False).head(10))
demo_rows.sort_values(by='diff_Kiran_Sybil', ascending=True).head(10)

### Numerical columns

In [None]:
# Numerical-mean differences between training sets; positive and negative
# Kiran-vs-Sybil differences are shown separately.
num_demo_splits = data.combine_diff_dfs(sybil_demos['num'], data.diffs_numerical_means, training_sets)
by_diff_desc = num_demo_splits.sort_values(by='diff_Kiran_Sybil', ascending=False)
display(by_diff_desc[by_diff_desc['diff_Kiran_Sybil'] > 0])
by_diff_asc = num_demo_splits.sort_values(by='diff_Kiran_Sybil', ascending=True)
by_diff_asc[by_diff_asc['diff_Kiran_Sybil'] < 0]

## Model Validation Sets

In [None]:
# Validation/evaluation data per model: Kiran's full scan-level data, Tijmen's
# validation scans, and Sybil's dev and test splits.
val_sets = dict(
    Kiran=kiran_data,
    Tijmen=tijmen_val,
    SybilDev=sybil_splits['dev'],
    SybilTest=sybil_splits['test'],
)

### Categorical columns

In [None]:
# Categorical prevalence differences between validation sets, ranked both ways
# by the Kiran-vs-SybilTest difference.
cat_demo_val = (
    data.combine_diff_dfs(sybil_demos['cat'], data.diffs_category_prevalence, val_sets)
    .dropna(subset='value', axis=0)
    .query('value != 0')
)
for descending in (True, False):
    display(cat_demo_val.sort_values(by='diff_Kiran_SybilTest', ascending=not descending).head(10))

### Numerical columns

In [None]:
# Numerical-mean differences between validation sets, top 10 in each direction.
num_demo_val = data.combine_diff_dfs(sybil_demos['num'], data.diffs_numerical_means, val_sets)
for descending in (True, False):
    display(num_demo_val.sort_values(by='diff_Kiran_SybilTest', ascending=not descending).head(10))

## Sybil Train vs. Validation Sets

In [None]:
# Add Kiran's data as an "eval" entry so it is compared against the Sybil splits.
# NOTE: this mutates sybil_splits in place and stores a reference (not a copy)
# to kiran_data, so later in-place edits to kiran_data show up here too.
sybil_splits.update(eval=kiran_data)

In [None]:
# Categorical prevalence shift across Sybil splits (plus "eval"), ranked by the
# train-vs-test difference.
cat_demo_shift = (
    data.combine_diff_dfs(sybil_demos['cat'], data.diffs_category_prevalence, sybil_splits)
    .dropna(subset='value', axis=0)
    .query('value != 0')
)
sort_key = 'diff_train_test'
display(cat_demo_shift.sort_values(by=sort_key, ascending=False).head(10))
cat_demo_shift.sort_values(by=sort_key, ascending=True).head(10)

In [None]:
# Numerical-mean shift across Sybil splits (plus "eval"), ranked by train vs. test.
num_demo_shift = data.combine_diff_dfs(sybil_demos['num'], data.diffs_numerical_means, sybil_splits)
sort_key = 'diff_train_test'
display(num_demo_shift.sort_values(by=sort_key, ascending=False).head(10))
num_demo_shift.sort_values(by=sort_key, ascending=True).head(10)

Conclusion: there is little demographic shift between the Sybil splits (differences of roughly 1–2% overall). Family history is the exception.

## NLST vs. DLCST

In [None]:
# Load DLCST predictions for all models and summarize columns/dtypes.
dlcst_path = f"{FILE_DIR}/dlcst_allmodels_cal.csv"
dlcst_preds = pd.read_csv(dlcst_path, header=0)
dlcst_preds.info()

In [None]:
# Add aliases to the NLST data (presumably matching the column names used in
# DLCST_DEMOCOLS — verify) and cast Emphysema to int.
# NOTE: modifies kiran_data in place; every dict holding kiran_data sees this.
for alias, source_col in {'Sex': 'Gender', 'NoduleCountPerScan': 'NoduleCounts'}.items():
    kiran_data[alias] = kiran_data[source_col]
kiran_data['Emphysema'] = kiran_data['Emphysema'].astype(int)

In [None]:
# Screening cohorts to compare: NLST (Kiran's scan-level data) vs. DLCST.
# A commented-out "nlst_sybil" entry was removed: it pointed at sybil_demos,
# which is a column spec, not a DataFrame of scans.
screening_sets = {
    "nlst": kiran_data,
    "dlcst": dlcst_preds,
}

In [None]:
# Categorical prevalence differences between NLST and DLCST.
cat_demo_dlcst = data.combine_diff_dfs(DLCST_DEMOCOLS['cat'], data.diffs_category_prevalence, screening_sets)
sort_key = 'diff_nlst_dlcst'
display(cat_demo_dlcst.sort_values(by=sort_key, ascending=False).head(10))
cat_demo_dlcst.sort_values(by=sort_key, ascending=True).head(10)

In [None]:
# Numerical-mean differences between NLST and DLCST.
num_demo_dlcst = data.combine_diff_dfs(DLCST_DEMOCOLS['num'], data.diffs_numerical_means, screening_sets)
sort_key = 'diff_nlst_dlcst'
display(num_demo_dlcst.sort_values(by=sort_key, ascending=False).head(10))
num_demo_dlcst.sort_values(by=sort_key, ascending=True).head(10)

### Different validation sets

In [None]:
# Nodule-level views: the raw CSV (presumably one row per nodule) and the rows
# kept by prep_nlst_preds at nodule level with tijmen=True.
# NOTE(review): this passes the processed kiran_demos, while the scan-level cell
# below passes kiran_demos_original — confirm that is intended.
all_nodules = pd.read_csv(f"{FILE_DIR}/nlst_allmodels_demos.csv")
some_nodules, nlst_democols_nodules, _ = data.prep_nlst_preds(
    all_nodules,
    democols=kiran_demos,
    scanlevel=False,
    tijmen=True,
    sybil=False,
)
print(len(all_nodules), len(some_nodules))

In [None]:
# Scan-level views: all scans (tijmen=False) and the subset kept with tijmen=True.
# Only the prepared data (first return value) is kept.
all_scans, some_scans = (
    data.prep_nlst_preds(all_nodules, democols=kiran_demos_original, scanlevel=True, tijmen=use_tijmen, sybil=True)[0]
    for use_tijmen in (False, True)
)
print(len(all_scans), len(some_scans))

In [None]:
# Nodule- and scan-level validation sets, each in an "all" and a "some" variant.
valsets = dict(
    allnodules=all_nodules,
    somenodules=some_nodules,
    allscans=all_scans,
    somescans=some_scans,
)

In [None]:
# Categorical and numerical differences between the validation-set variants.
# NOTE: this overwrites cat_demo_shift / num_demo_shift from the Sybil-split
# comparison earlier in the notebook.
cat_demo_shift, num_demo_shift = (
    data.combine_diff_dfs(kiran_demos[col_kind], diff_fn, valsets)
    for col_kind, diff_fn in (
        ('cat', data.diffs_category_prevalence),
        ('num', data.diffs_numerical_means),
    )
)

#### Difference between nodule-level and scan-level sets

In [None]:
# Categorical differences: all nodules vs. all scans.
sort_key = 'diff_allnodules_allscans'
display(cat_demo_shift.sort_values(by=sort_key, ascending=False).head(10))
cat_demo_shift.sort_values(by=sort_key, ascending=True).head(10)

In [None]:
# Numerical differences: all nodules vs. all scans.
sort_key = 'diff_allnodules_allscans'
display(num_demo_shift.sort_values(by=sort_key, ascending=False).head(10))
num_demo_shift.sort_values(by=sort_key, ascending=True).head(10)

#### Difference between all scans and the scan subset used for Tijmen's combined model

In [None]:
# Categorical differences: all scans vs. the tijmen=True scan subset.
sort_key = 'diff_allscans_somescans'
display(cat_demo_shift.sort_values(by=sort_key, ascending=False).head(10))
cat_demo_shift.sort_values(by=sort_key, ascending=True).head(10)

In [None]:
# Numerical differences: all scans vs. the tijmen=True scan subset.
sort_key = 'diff_allscans_somescans'
display(num_demo_shift.sort_values(by=sort_key, ascending=False).head(10))
num_demo_shift.sort_values(by=sort_key, ascending=True).head(10)

## What about men vs. women?

### Training sets

In [None]:
# Split Sybil's training set by Gender code (1 -> "M", 2 -> "F").
gender_train_sets = {
    label: sybil_splits['train'].query(f'Gender == {code}')
    for label, code in (("M", 1), ("F", 2))
}

In [None]:
# Categorical prevalence differences (include_stat=True) between men and women
# in Sybil's training set; zero-valued rows dropped.
cat_demo_gender = (
    data.combine_diff_dfs(sybil_demos['cat'], data.diffs_category_prevalence, gender_train_sets, include_stat=True)
    .query('value != 0')
)
sort_key = 'diff_M_F'
display(cat_demo_gender.sort_values(by=sort_key, ascending=False))
cat_demo_gender.sort_values(by=sort_key, ascending=True)

In [None]:
# Same comparison, showing only the "lungcanc" category.
lungcanc_filter = 'category == "lungcanc"'
display(cat_demo_gender.sort_values(by='diff_M_F', ascending=False).query(lungcanc_filter))
cat_demo_gender.sort_values(by='diff_M_F', ascending=True).query(lungcanc_filter)

In [None]:
# Numerical-mean differences (include_stat=True) between men and women in Sybil's training set.
num_demo_gender = data.combine_diff_dfs(sybil_demos['num'], data.diffs_numerical_means, gender_train_sets, include_stat=True)
sort_key = 'diff_M_F'
display(num_demo_gender.sort_values(by=sort_key, ascending=False).head(10))
num_demo_gender.sort_values(by=sort_key, ascending=True).head(10)

### Evaluation sets (Kiran data)

In [None]:
# Split Kiran's evaluation data by Gender code (1 -> "M", 2 -> "F").
gender_eval_sets = {
    label: kiran_data.query(f'Gender == {code}')
    for label, code in (("M", 1), ("F", 2))
}

In [None]:
# Categorical prevalence differences between men and women in Kiran's data.
cat_gender_eval = (
    data.combine_diff_dfs(kiran_demos['cat'], data.diffs_category_prevalence, gender_eval_sets)
    .query('value != 0')
)
sort_key = 'diff_M_F'
display(cat_gender_eval.sort_values(by=sort_key, ascending=False).head(40))
cat_gender_eval.sort_values(by=sort_key, ascending=True).head(40)

In [None]:
# Restrict the men-vs-women comparison to the "nodule" category.
nodule_rows = cat_gender_eval.query('category == "nodule"')
display(nodule_rows.sort_values(by='diff_M_F', ascending=False).head(40))
nodule_rows.sort_values(by='diff_M_F', ascending=True).head(40)

In [None]:
# Restrict the men-vs-women comparison to the LC_stage attribute.
stage_rows = cat_gender_eval.query('attribute == "LC_stage"')
display(stage_rows.sort_values(by='diff_M_F', ascending=False).head(40))
stage_rows.sort_values(by='diff_M_F', ascending=True).head(40)

In [None]:
# Numerical-mean differences between men and women in Kiran's data.
num_gender_eval = data.combine_diff_dfs(kiran_demos['num'], data.diffs_numerical_means, gender_eval_sets)
sort_key = 'diff_M_F'
display(num_gender_eval.sort_values(by=sort_key, ascending=False).head(10))
num_gender_eval.sort_values(by=sort_key, ascending=True).head(10)

## What about BMI?

### Sybil training set

In [None]:
# Split Sybil's training set on the Overweight flag (1 -> "over", 0 -> "normal").
sybil_train_over, sybil_train_normal = (
    sybil_splits['train'].query(f'Overweight == {flag}') for flag in (1, 0)
)

overweight_train_sets = {"over": sybil_train_over, "normal": sybil_train_normal}

In [None]:
# Categorical prevalence differences, overweight vs. normal, in Sybil's training set.
cat_demo_overweight = (
    data.combine_diff_dfs(sybil_demos['cat'], data.diffs_category_prevalence, overweight_train_sets)
    .query('value != 0')
)
sort_key = 'diff_over_normal'
display(cat_demo_overweight.sort_values(by=sort_key, ascending=False).head(40))
cat_demo_overweight.sort_values(by=sort_key, ascending=True).head(40)

In [None]:
# Same comparison, showing only the "lungcanc" category.
lungcanc_filter = 'category == "lungcanc"'
display(cat_demo_overweight.sort_values(by='diff_over_normal', ascending=False).query(lungcanc_filter))
cat_demo_overweight.sort_values(by='diff_over_normal', ascending=True).query(lungcanc_filter)

In [None]:
# Numerical-mean differences, overweight vs. normal, in Sybil's training set.
num_demo_overweight = data.combine_diff_dfs(sybil_demos['num'], data.diffs_numerical_means, overweight_train_sets)
sort_key = 'diff_over_normal'
display(num_demo_overweight.sort_values(by=sort_key, ascending=False).head(10))
num_demo_overweight.sort_values(by=sort_key, ascending=True).head(10)

### Evaluation sets (Kiran data)

In [None]:
# Split Kiran's evaluation data on the Overweight flag (1 -> "over", 0 -> "normal").
kiran_data_over, kiran_data_normal = (
    kiran_data.query(f'Overweight == {flag}') for flag in (1, 0)
)

overweight_eval_sets = {"over": kiran_data_over, "normal": kiran_data_normal}

In [None]:
# Categorical prevalence differences, overweight vs. normal, in Kiran's data.
cat_eval_overweight = (
    data.combine_diff_dfs(kiran_demos['cat'], data.diffs_category_prevalence, overweight_eval_sets)
    .query('value != 0')
)
sort_key = 'diff_over_normal'
display(cat_eval_overweight.sort_values(by=sort_key, ascending=False).head(40))
cat_eval_overweight.sort_values(by=sort_key, ascending=True).head(40)

In [None]:
# Same comparison, showing only the "nodule" category.
nodule_filter = 'category == "nodule"'
display(cat_eval_overweight.sort_values(by='diff_over_normal', ascending=False).query(nodule_filter))
cat_eval_overweight.sort_values(by='diff_over_normal', ascending=True).query(nodule_filter)

In [None]:
# Numerical-mean differences, overweight vs. normal, in Kiran's data.
num_eval_overweight = data.combine_diff_dfs(kiran_demos['num'], data.diffs_numerical_means, overweight_eval_sets)
sort_key = 'diff_over_normal'
display(num_eval_overweight.sort_values(by=sort_key, ascending=False).head(10))
num_eval_overweight.sort_values(by=sort_key, ascending=True).head(10)

## What about race?

### Kiran data

In [None]:
# Split Kiran's data by race code (1 -> "white", 2 -> "black").
kiran_data_white, kiran_data_black = (
    kiran_data.query(f'race == {code}') for code in (1, 2)
)

race_kiran_sets = {"white": kiran_data_white, "black": kiran_data_black}

In [None]:
# Categorical prevalence differences (include_stat=True) between the "white" and
# "black" subsets of Kiran's data; zero-valued rows dropped.
# Uses the processed spec kiran_demos, which prep_nlst_preds returned alongside
# kiran_data. The other kiran_data-subset comparisons and the numerical race
# comparison use the same spec. The unprocessed kiran_demos_original was used here before.
cat_race_kiran = data.combine_diff_dfs(kiran_demos['cat'], data.diffs_category_prevalence, race_kiran_sets, include_stat=True).query('value != 0')
display(cat_race_kiran.sort_values(by='diff_white_black', ascending=False).head(40))
cat_race_kiran.sort_values(by='diff_white_black', ascending=True).head(40)

In [None]:
# Restrict the race comparison to the "nodule" category.
nodule_rows = cat_race_kiran.query('category == "nodule"')
display(nodule_rows.sort_values(by='diff_white_black', ascending=False).head(40))
nodule_rows.sort_values(by='diff_white_black', ascending=True).head(40)

In [None]:
# Restrict the race comparison to the LC_stage attribute.
stage_rows = cat_race_kiran.query('attribute == "LC_stage"')
display(stage_rows.sort_values(by='diff_white_black', ascending=False).head(40))
stage_rows.sort_values(by='diff_white_black', ascending=True).head(40)

In [None]:
# Numerical-mean differences (include_stat=True) between the race subsets of Kiran's data.
num_race_kiran = data.combine_diff_dfs(kiran_demos['num'], data.diffs_numerical_means, race_kiran_sets, include_stat=True)
sort_key = 'diff_white_black'
display(num_race_kiran.sort_values(by=sort_key, ascending=False).head(10))
num_race_kiran.sort_values(by=sort_key, ascending=True).head(10)