In [None]:
import pandas as pd
import os
import numpy as np
import json
import seaborn as sns
import matplotlib.pyplot as plt
import copy
from IPython.display import display, Markdown

from utilities import data
from utilities.info import *

%matplotlib inline

## Grab Data for Each Model

In [None]:
# Load the table used for the Venk21 data from FILE_DIR.
venk21_nodule = pd.read_csv(f"{FILE_DIR}/nlst_allmodels_demos.csv")

# Load the accompanying column description (JSON). The `with` block closes the
# file when it exits, so no explicit close() call is needed.
with open(f'{FILE_DIR}/nlst_democols.json') as json_data:
    venk21_demos_original = json.load(json_data)

# Run the project preprocessing helper; its third return value is not used here.
venk21_data, venk21_demos, _ = data.prep_nlst_preds(venk21_nodule, venk21_demos_original, scanlevel=True, sybil=False, bin_num=False)
# Add 'label' to the 'other' group of categorical columns
# (presumably so the comparisons below also report it -- TODO confirm).
venk21_demos['cat']['other'].append('label')

In [None]:
# Load the Sybil table from FILE_DIR; later cells split it on its `split` column.
sybil_data = pd.read_csv(f"{FILE_DIR}/nlst_sybil_demos.csv")

# Load the column description for the Sybil table; later cells read its 'cat'
# and 'num' entries. The `with` block closes the file when it exits, so no
# explicit close() call is needed.
with open(f'{FILE_DIR}/nlst_sybil_democols.json') as json_data:
    sybil_demos = json.load(json_data)

In [None]:
# One sub-frame of `sybil_data` per value of its `split` column, keyed by split name.
sybil_splits = {
    split_name: sybil_data.query(f'split == "{split_name}"')
    for split_name in ('train', 'dev', 'test')
}

## Model Training Sets

In [None]:
# Named frames compared in this section: the Venk21 data and the Sybil 'train' split.
training_sets = dict(
    Venk21=venk21_data,
    Sybil=sybil_splits['train'],
)

### Categorical columns

In [None]:
# Compare the categorical columns across `training_sets` with the project helper,
# then drop rows whose `value` is missing or equal to 0.
cat_demo_splits = (
    data.combine_diff_dfs(sybil_demos['cat'], data.diffs_category_prevalence, training_sets)
    .dropna(subset='value')
    .query('value != 0')
)
# Ten largest, then ten smallest, entries of the Venk21-vs-Sybil difference column.
display(cat_demo_splits.sort_values('diff_Venk21_Sybil', ascending=False).head(10))
display(cat_demo_splits.sort_values('diff_Venk21_Sybil').head(10))

### Numerical columns

In [None]:
# Compare the numerical columns across `training_sets` with the project helper.
num_demo_splits = data.combine_diff_dfs(
    sybil_demos['num'],
    data.diffs_numerical_means,
    training_sets,
)
# Rows with a positive difference, largest first.
display(
    num_demo_splits
    .sort_values('diff_Venk21_Sybil', ascending=False)
    .query('diff_Venk21_Sybil > 0')
)
# Rows with a negative difference, most negative first (shown as the cell output).
(
    num_demo_splits
    .sort_values('diff_Venk21_Sybil')
    .query('diff_Venk21_Sybil < 0')
)

## Model Validation Sets

In [None]:
# Named frames compared in this section: the Venk21 data and the Sybil 'dev' and 'test' splits.
val_sets = dict(
    Venk21=venk21_data,
    SybilDev=sybil_splits['dev'],
    SybilTest=sybil_splits['test'],
)

### Categorical columns

In [None]:
# Compare the categorical columns across `val_sets`, then drop rows whose
# `value` is missing or equal to 0.
cat_demo_val = (
    data.combine_diff_dfs(sybil_demos['cat'], data.diffs_category_prevalence, val_sets)
    .dropna(subset='value')
    .query('value != 0')
)
# Ten largest, then ten smallest, entries of the Venk21-vs-SybilTest difference column.
display(cat_demo_val.sort_values('diff_Venk21_SybilTest', ascending=False).head(10))
display(cat_demo_val.sort_values('diff_Venk21_SybilTest').head(10))

### Numerical columns

In [None]:
# Compare the numerical columns across `val_sets` with the project helper.
num_demo_val = data.combine_diff_dfs(
    sybil_demos['num'],
    data.diffs_numerical_means,
    val_sets,
)
# Ten largest, then ten smallest, entries of the Venk21-vs-SybilTest difference column.
display(num_demo_val.sort_values('diff_Venk21_SybilTest', ascending=False).head(10))
display(num_demo_val.sort_values('diff_Venk21_SybilTest').head(10))

## Sybil Train vs. Validation Sets

In [None]:
# Register the Venk21 data under the key "eval" next to the Sybil splits.
# NOTE: this mutates `sybil_splits` in place, so every later cell sees the extra key.
sybil_splits.update(eval=venk21_data)

In [None]:
# Compare the categorical columns across every entry of `sybil_splits`, then
# drop rows whose `value` is missing or equal to 0.
cat_demo_shift = (
    data.combine_diff_dfs(sybil_demos['cat'], data.diffs_category_prevalence, sybil_splits)
    .dropna(subset='value')
    .query('value != 0')
)
# Ten largest train-vs-test differences, then the ten smallest as the cell output.
display(cat_demo_shift.sort_values('diff_train_test', ascending=False).head(10))
cat_demo_shift.sort_values('diff_train_test').head(10)

In [None]:
# Compare the numerical columns across every entry of `sybil_splits`.
num_demo_shift = data.combine_diff_dfs(
    sybil_demos['num'],
    data.diffs_numerical_means,
    sybil_splits,
)
# Ten largest train-vs-test differences, then the ten smallest as the cell output.
display(num_demo_shift.sort_values('diff_train_test', ascending=False).head(10))
num_demo_shift.sort_values('diff_train_test').head(10)

Conclusion: there is not much demographic shift between the sets (differences of 1-2% overall are small), with the exception of family history.

## What about men vs. women?

### Training sets

In [None]:
# Split the Sybil 'train' frame on its `Gender` code: 1 is stored under "M", 2 under "F".
gender_train_sets = dict(
    M=sybil_splits['train'].query('Gender == 1'),
    F=sybil_splits['train'].query('Gender == 2'),
)

In [None]:
# Compare the categorical columns between the M and F training subsets
# (with include_stat=True), keeping rows whose `value` is not 0.
cat_demo_gender = (
    data.combine_diff_dfs(
        sybil_demos['cat'],
        data.diffs_category_prevalence,
        gender_train_sets,
        include_stat=True,
    )
    .query('value != 0')
)
# Full table sorted by the M-vs-F difference: descending first, ascending as the cell output.
display(cat_demo_gender.sort_values('diff_M_F', ascending=False))
cat_demo_gender.sort_values('diff_M_F')

In [None]:
# Same M-vs-F table, restricted to rows whose `category` is "lungcanc".
display(
    cat_demo_gender
    .sort_values('diff_M_F', ascending=False)
    .query('category == "lungcanc"')
)
(
    cat_demo_gender
    .sort_values('diff_M_F')
    .query('category == "lungcanc"')
)

In [None]:
# Compare the numerical columns between the M and F training subsets (with include_stat=True).
num_demo_gender = data.combine_diff_dfs(
    sybil_demos['num'],
    data.diffs_numerical_means,
    gender_train_sets,
    include_stat=True,
)
# Ten largest M-vs-F differences, then the ten smallest as the cell output.
display(num_demo_gender.sort_values('diff_M_F', ascending=False).head(10))
num_demo_gender.sort_values('diff_M_F').head(10)

### Evaluation sets (Venk21 data)

In [None]:
# Split the Venk21 data on its `Gender` code: 1 is stored under "M", 2 under "F".
gender_eval_sets = dict(
    M=venk21_data.query('Gender == 1'),
    F=venk21_data.query('Gender == 2'),
)

In [None]:
# Compare the Venk21 categorical columns between the M and F evaluation subsets,
# keeping rows whose `value` is not 0.
cat_gender_eval = (
    data.combine_diff_dfs(venk21_demos['cat'], data.diffs_category_prevalence, gender_eval_sets)
    .query('value != 0')
)
# Forty largest M-vs-F differences, then the forty smallest as the cell output.
display(cat_gender_eval.sort_values('diff_M_F', ascending=False).head(40))
cat_gender_eval.sort_values('diff_M_F').head(40)

In [None]:
# M-vs-F table restricted to rows whose `category` is "nodule" (up to 40 rows each way).
display(
    cat_gender_eval
    .query('category == "nodule"')
    .sort_values('diff_M_F', ascending=False)
    .head(40)
)
(
    cat_gender_eval
    .query('category == "nodule"')
    .sort_values('diff_M_F')
    .head(40)
)

In [None]:
# M-vs-F table restricted to rows whose `attribute` is "LC_stage" (up to 40 rows each way).
display(
    cat_gender_eval
    .query('attribute == "LC_stage"')
    .sort_values('diff_M_F', ascending=False)
    .head(40)
)
(
    cat_gender_eval
    .query('attribute == "LC_stage"')
    .sort_values('diff_M_F')
    .head(40)
)

In [None]:
# Compare the Venk21 numerical columns between the M and F evaluation subsets.
num_gender_eval = data.combine_diff_dfs(
    venk21_demos['num'],
    data.diffs_numerical_means,
    gender_eval_sets,
)
# Ten largest M-vs-F differences, then the ten smallest as the cell output.
display(num_gender_eval.sort_values('diff_M_F', ascending=False).head(10))
num_gender_eval.sort_values('diff_M_F').head(10)

## What about BMI?

### Sybil training set

In [None]:
# Split the Sybil 'train' frame on its `Overweight` flag. Rows whose flag is
# neither 1 nor 0 (e.g. missing) end up in neither subset.
sybil_train_over = sybil_splits['train'].query('Overweight == 1')
sybil_train_normal = sybil_splits['train'].query('Overweight == 0')

overweight_train_sets = dict(over=sybil_train_over, normal=sybil_train_normal)

In [None]:
# Compare the categorical columns between the "over" and "normal" training subsets,
# keeping rows whose `value` is not 0.
cat_demo_overweight = (
    data.combine_diff_dfs(sybil_demos['cat'], data.diffs_category_prevalence, overweight_train_sets)
    .query('value != 0')
)
# Forty largest over-vs-normal differences, then the forty smallest as the cell output.
display(cat_demo_overweight.sort_values('diff_over_normal', ascending=False).head(40))
cat_demo_overweight.sort_values('diff_over_normal').head(40)

In [None]:
# Over-vs-normal table restricted to rows whose `category` is "lungcanc".
display(
    cat_demo_overweight
    .sort_values('diff_over_normal', ascending=False)
    .query('category == "lungcanc"')
)
(
    cat_demo_overweight
    .sort_values('diff_over_normal')
    .query('category == "lungcanc"')
)

In [None]:
# Compare the numerical columns between the "over" and "normal" training subsets.
num_demo_overweight = data.combine_diff_dfs(
    sybil_demos['num'],
    data.diffs_numerical_means,
    overweight_train_sets,
)
# Ten largest over-vs-normal differences, then the ten smallest as the cell output.
display(num_demo_overweight.sort_values('diff_over_normal', ascending=False).head(10))
num_demo_overweight.sort_values('diff_over_normal').head(10)

### Evaluation set (Venk21 data)

In [None]:
# Split the Venk21 data on its `Overweight` flag. Rows whose flag is neither
# 1 nor 0 (e.g. missing) end up in neither subset.
venk21_data_over = venk21_data.query('Overweight == 1')
venk21_data_normal = venk21_data.query('Overweight == 0')

overweight_eval_sets = dict(over=venk21_data_over, normal=venk21_data_normal)

In [None]:
# Compare the Venk21 categorical columns between the "over" and "normal"
# evaluation subsets, keeping rows whose `value` is not 0.
cat_eval_overweight = (
    data.combine_diff_dfs(venk21_demos['cat'], data.diffs_category_prevalence, overweight_eval_sets)
    .query('value != 0')
)
# Forty largest over-vs-normal differences, then the forty smallest as the cell output.
display(cat_eval_overweight.sort_values('diff_over_normal', ascending=False).head(40))
cat_eval_overweight.sort_values('diff_over_normal').head(40)

In [None]:
# Over-vs-normal evaluation table restricted to rows whose `category` is "nodule".
display(
    cat_eval_overweight
    .sort_values('diff_over_normal', ascending=False)
    .query('category == "nodule"')
)
(
    cat_eval_overweight
    .sort_values('diff_over_normal')
    .query('category == "nodule"')
)

In [None]:
# Compare the Venk21 numerical columns between the "over" and "normal" evaluation subsets.
num_eval_overweight = data.combine_diff_dfs(
    venk21_demos['num'],
    data.diffs_numerical_means,
    overweight_eval_sets,
)
# Ten largest over-vs-normal differences, then the ten smallest as the cell output.
display(num_eval_overweight.sort_values('diff_over_normal', ascending=False).head(10))
num_eval_overweight.sort_values('diff_over_normal').head(10)

## What about race?

### Venk21 data

In [None]:
# Split the Venk21 data on its `race` code: 1 is stored under "white", 2 under
# "black". Rows with any other code are left out of this comparison.
venk21_data_white = venk21_data.query('race == 1')
venk21_data_black = venk21_data.query('race == 2')

race_venk21_sets = dict(white=venk21_data_white, black=venk21_data_black)

In [None]:
# Compare the Venk21 categorical columns between the "white" and "black" subsets
# (with include_stat=True), keeping rows whose `value` is not 0.
cat_race_venk21 = (
    data.combine_diff_dfs(
        venk21_demos['cat'],
        data.diffs_category_prevalence,
        race_venk21_sets,
        include_stat=True,
    )
    .query('value != 0')
)
# Forty largest white-vs-black differences, then the forty smallest as the cell output.
display(cat_race_venk21.sort_values('diff_white_black', ascending=False).head(40))
cat_race_venk21.sort_values('diff_white_black').head(40)

In [None]:
# White-vs-black table restricted to rows whose `category` is "nodule" (up to 40 rows each way).
display(
    cat_race_venk21
    .query('category == "nodule"')
    .sort_values('diff_white_black', ascending=False)
    .head(40)
)
(
    cat_race_venk21
    .query('category == "nodule"')
    .sort_values('diff_white_black')
    .head(40)
)

In [None]:
# White-vs-black table restricted to rows whose `attribute` is "LC_stage" (up to 40 rows each way).
display(
    cat_race_venk21
    .query('attribute == "LC_stage"')
    .sort_values('diff_white_black', ascending=False)
    .head(40)
)
(
    cat_race_venk21
    .query('attribute == "LC_stage"')
    .sort_values('diff_white_black')
    .head(40)
)

In [None]:
# Compare the Venk21 numerical columns between the "white" and "black" subsets
# (with include_stat=True).
num_race_venk21 = data.combine_diff_dfs(
    venk21_demos['num'],
    data.diffs_numerical_means,
    race_venk21_sets,
    include_stat=True,
)
# Ten largest white-vs-black differences, then the ten smallest as the cell output.
display(num_race_venk21.sort_values('diff_white_black', ascending=False).head(10))
num_race_venk21.sort_values('diff_white_black').head(10)