In [1]:
import re
from pathlib import Path

import pandas as pd

# Where the input CSVs live and where generated cohort files are written.
data_path = Path('/Users/sumanth/personal_project/data/')
out_path = Path('/Users/sumanth/personal_project/outputs/')

# Create the output directory (and any missing parents); a no-op if it exists.
out_path.mkdir(exist_ok=True, parents=True)

def read_csv_clean(fn):
    """Load a CSV from ``data_path`` with a tolerant parser configuration.

    Args:
        fn: File name (relative to ``data_path``) of the CSV to read.

    Returns:
        The parsed ``pandas.DataFrame``.

    Note:
        Lines the parser considers malformed are dropped without raising
        (``on_bad_lines='skip'``), so the frame may contain fewer rows than
        the file has lines.
    """
    csv_file = data_path / fn
    return pd.read_csv(csv_file, on_bad_lines='skip', engine='python')

# Source tables.
patients = read_csv_clean('patients.csv')
conditions = read_csv_clean('conditions.csv')

# Parse diagnosis start dates; values that cannot be parsed become NaT
# rather than raising.
parsed_start = pd.to_datetime(conditions['START'], errors='coerce')
conditions['START'] = parsed_start

# Inclusive cutoff: a first diagnosis on or after this date counts as recent.
AFTER = pd.Timestamp('2020-01-01')

# Keyword lists per disease category. A condition belongs to a category when
# its DESCRIPTION contains any one of the category's keywords.
CATEGORIES = {
    'cancer': [
        'malignant neoplasm',
        'carcinoma',
        'cancer',
        'sarcoma',
        'melanoma',
        'lymphoma',
        'leukemia',
        'myeloma',
        'metastasis',
        'in situ',
    ],
    'diabetes': [
        'diabetes mellitus type 2',
        'prediabetes',
        'hyperglycemia',
    ],
    'heart_disease': [
        'ischemic heart disease',
        'myocardial infarction',
        'heart failure',
        'coronary',
        'angina',
    ],
    'stroke': [
        'stroke',
        'cerebrovascular accident',
    ],
    'copd_asthma': [
        'chronic obstructive',
        'emphysema',
        'asthma',
    ],
    'ckd': [
        'chronic kidney disease',
        'end-stage renal disease',
    ],
    'hypertension': [
        'essential hypertension',
    ],
}

def count_unique_patients_strict(terms, *, conditions_df=None, after=None):
    """Count patients whose earliest matching diagnosis is on/after a cutoff.

    A condition row matches when its DESCRIPTION contains any of ``terms``
    as a case-insensitive *literal* substring. For every matching patient the
    earliest START among their matching rows is taken, and the patient is
    counted only if that earliest date is >= the cutoff ("strict" first
    diagnosis).

    Args:
        terms: Iterable of keyword strings, combined with OR. Empty strings
            are ignored.
        conditions_df: Table with PATIENT, DESCRIPTION and START columns.
            Defaults to the module-level ``conditions``.
        after: Inclusive cutoff compared against the earliest START.
            Defaults to the module-level ``AFTER``.

    Returns:
        Number of distinct patients as an ``int``; 0 when ``terms`` is empty
        or nothing matches.
    """
    source = conditions if conditions_df is None else conditions_df
    cutoff = AFTER if after is None else after

    # An empty pattern would match every row, so bail out before building it.
    keywords = [t for t in terms if t]
    if not keywords:
        return 0

    # Escape each term so characters such as '(' or '.' are matched literally
    # instead of being interpreted as regex syntax.
    pattern = '|'.join(re.escape(t) for t in keywords)
    matched = source['DESCRIPTION'].str.contains(pattern, case=False, na=False)
    cond_sub = source.loc[matched, ['PATIENT', 'START']]
    if cond_sub.empty:
        return 0

    # One row per patient after the groupby, so counting True values counts
    # distinct patients.
    first_dx = cond_sub.groupby('PATIENT')['START'].min()
    return int((first_dx >= cutoff).sum())

# Strict patient count for every disease category, in CATEGORIES order.
counts_strict = {}
for category, keywords in CATEGORIES.items():
    counts_strict[category] = count_unique_patients_strict(keywords)

# Report one indented line per category under a header.
print('Strict counts (first diagnosis >= 2020):')
for category, n_patients in counts_strict.items():
    print(f'  {category}: {n_patients}')

def build_cohort_strict(terms, label, *, conditions_df=None, patients_df=None,
                        after=None):
    """Build the cohort of patients first diagnosed on/after a cutoff.

    A condition row matches when its DESCRIPTION contains any of ``terms``
    (OR across terms) as a case-insensitive *literal* substring. For each
    matching patient the earliest START among their matching rows becomes
    FIRST_DIAGNOSIS; only patients whose FIRST_DIAGNOSIS is >= the cutoff are
    kept and inner-joined onto the patient table (``Id`` == ``PATIENT``).

    Args:
        terms: Iterable of keyword strings. Empty strings are ignored.
        label: Name used in the printed progress messages.
        conditions_df: Table with PATIENT, DESCRIPTION and START columns.
            Defaults to the module-level ``conditions``.
        patients_df: Table with an ``Id`` column. Defaults to the
            module-level ``patients``.
        after: Inclusive cutoff for FIRST_DIAGNOSIS. Defaults to the
            module-level ``AFTER``.

    Returns:
        The patient columns plus PATIENT and FIRST_DIAGNOSIS, one row per
        cohort member. A column-less empty DataFrame is returned when
        ``terms`` is empty or no condition row matches.
    """
    source = conditions if conditions_df is None else conditions_df
    people = patients if patients_df is None else patients_df
    cutoff = AFTER if after is None else after

    # An empty pattern would match every row, so treat it as "no matches".
    keywords = [t for t in terms if t]
    if not keywords:
        print(f'No matches for {label}')
        return pd.DataFrame()

    # Escape each term so characters such as '(' or '.' are matched literally
    # instead of being interpreted as regex syntax.
    pattern = '|'.join(re.escape(t) for t in keywords)
    matched = source['DESCRIPTION'].str.contains(pattern, case=False, na=False)
    cond_sub = source.loc[matched, ['PATIENT', 'START']]
    if cond_sub.empty:
        print(f'No matches for {label}')
        return pd.DataFrame()

    first_dx = (
        cond_sub.groupby('PATIENT')['START']
        .min()
        .reset_index(name='FIRST_DIAGNOSIS')
    )
    recent = first_dx[first_dx['FIRST_DIAGNOSIS'] >= cutoff]
    cohort = people.merge(recent, left_on='Id', right_on='PATIENT', how='inner')
    print(f'{label} cohort size (strict):', cohort.shape[0])
    return cohort

# Example cohorts, each written to out_path as a CSV without the index column.

# Cancer only.
cancer_cohort = build_cohort_strict(CATEGORIES['cancer'], 'cancer')
cancer_csv = out_path / 'cohort_cancer_after2020.csv'
cancer_cohort.to_csv(cancer_csv, index=False)

# Cancer and diabetes keywords pooled into one list. The match is an OR, so
# this is the union of the two categories, and the date filter applies to the
# earliest match across both keyword lists.
combo_terms = [*CATEGORIES['cancer'], *CATEGORIES['diabetes']]
combo_cohort = build_cohort_strict(combo_terms, 'cancer_plus_diabetes')
combo_csv = out_path / 'cohort_cancer_diabetes_after2020.csv'
combo_cohort.to_csv(combo_csv, index=False)


Strict counts (first diagnosis >= 2020):
  cancer: 275
  diabetes: 445
  heart_disease: 356
  stroke: 18
  copd_asthma: 70
  ckd: 231
  hypertension: 388
cancer cohort size (strict): 275
cancer_plus_diabetes cohort size (strict): 531
