In [1]:
import os
import sys
import re
import time
import random
import warnings
import collections
from tqdm import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns

sys.path.append('../src')
import cb_utils

# Plot styling and DataFrame display options for this notebook.
sns.set(style="darkgrid")
pd.options.display.max_columns = 500

# Re-import edited modules (e.g. cb_utils from ../src) before each cell executes,
# so changes to project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Sanity check: '../src' (cb_utils) and './data/icds_and_target' (output) are resolved relative to this directory.
os.getcwd()

'/Users/bp/workspace/cb/data-analytics/notebooks'

In [3]:
# Cohort for one MCO, one row per member. Rows are mapped to names positionally
# (see `cols` in build_language_model_data_for_mco), so the table's column order matters.
# Placeholders use the DB-API "pyformat" style (%(name)s), so these are plain strings,
# not f-strings: no Python interpolation happens here.
member_ids_query = "SELECT * FROM junk.language_model_test20220923 m where m.mco_id = %(mco_id)s"

# Encounter history of one member between pre_start and pre_end (inclusive):
#   * an encounter is the set of the member's qualifying claims sharing one date_from
#     (service types 10, 12, 13, 15, 16, 17, 18 and rx claims are excluded; NOT IN also
#     excludes claims whose service_type_id is NULL);
#   * days_since_last_encounter = date_from - previous encounter's date_from
#     (NULL for the first encounter);
#   * icds_by_seq = non-null diagnosis codes of the encounter's claims, ordered by
#     claim line and diagnosis sequence; an empty array when there are none.
member_claims_query = """
  WITH
      encounter_level   AS ( SELECT
                                 c.member_id
                               , c.date_from
                               , array_agg(distinct c.id) claim_ids
                             FROM
                                 cb.claims c
                             WHERE
                                   c.mco_id = %(mco_id)s
                               and c.member_id = %(member_id)s
                               AND c.service_type_id NOT IN (12, 13, 17, 18, 10, 15, 16)
                               AND NOT c.is_rx
                               AND c.date_from between %(pre_start)s and %(pre_end)s
                             GROUP BY 1,2
                             )
    , lagged_encounters AS ( SELECT
                                 el.*
                               , LAG(date_from) OVER (PARTITION BY member_id ORDER BY date_from) prev_claim_date
                             FROM
                                 encounter_level el )
  SELECT
      le.member_id
    , c.date_from
    , le.date_from - prev_claim_date days_since_last_encounter
    , COALESCE(
          ARRAY_AGG(cd.diag ORDER BY c.claim_line_id, cd.diag_sequence) FILTER ( WHERE cd.diag IS NOT NULL),
          '{}'
      ) icds_by_seq
  FROM
      lagged_encounters le
      JOIN cb.claims c ON c.id = any(le.claim_ids)
      -- keep the mco filter in the ON clause: in WHERE it would turn this LEFT JOIN into an inner join
      LEFT JOIN cb.claims_diagnosis cd ON cd.claim_id = c.id AND cd.mco_id = %(mco_id)s
  WHERE
       c.mco_id = %(mco_id)s
   and c.member_id = %(member_id)s
  GROUP BY
      1, 2, 3
  ORDER BY
      1, 2
;
   """

In [11]:
# Inclusive upper bounds (in days) for each time-to-last-claim bucket, checked in order.
# Gaps above the last bound map to 'ttlc_gt180'.
_TTLC_BUCKETS = (
    (0, 'ttlc_0'),
    (1, 'ttlc_1'),
    (2, 'ttlc_2'),
    (5, 'ttlc_5'),
    (10, 'ttlc_10'),
    (15, 'ttlc_15'),
    (30, 'ttlc_30'),
    (60, 'ttlc_60'),
    (90, 'ttlc_90'),
    (180, 'ttlc_180'),
)


def get_days_cat(time_to_last_claim):
    """Map the gap since the previous encounter to a categorical token.

    Args:
        time_to_last_claim: non-negative gap, in days, between an encounter and the
            previous one (``days_since_last_encounter`` in ``member_claims_query``).

    Returns:
        str: the first bucket whose bound is >= the gap, e.g. 3 -> 'ttlc_5',
        180 -> 'ttlc_180', 181 -> 'ttlc_gt180'.

    Raises:
        ValueError: if ``time_to_last_claim`` is negative.
    """
    if time_to_last_claim < 0:
        # Must raise an exception instance: raising a bare string is a TypeError in Python 3.
        raise ValueError(
            f"Got a negative time to last claim ({time_to_last_claim}). should never happen"
        )
    for upper_bound, token in _TTLC_BUCKETS:
        if time_to_last_claim <= upper_bound:
            return token
    return 'ttlc_gt180'

In [8]:
def fetch_and_build_member_data(cur, mco_id, member, is_validation, save_path,
                                pre_start='2020-01-01', pre_end='2020-12-31'):
    """Build one member's token document and write it to a text file.

    The document starts with 'xxbos'. Each encounter row returned by
    ``member_claims_query`` contributes a gap token from ``get_days_cat`` (skipped
    when the gap is NULL, i.e. the first encounter) followed by the encounter's
    distinct diagnosis codes in first-seen order.

    Args:
        cur: database cursor (used via ``execute`` and iteration). It is re-executed
            here, so any earlier result set on it must already have been fetched.
        mco_id: MCO identifier, passed to the query and used in the file name.
        member: mapping with at least 'member_id', 'impactable_spend_post', 'age'
            and 'gender'.
        is_validation: write under ``<save_path>/valid`` if true, else ``<save_path>/train``.
        save_path: root output directory; the split subdirectory is created if missing.
        pre_start, pre_end: inclusive date window for encounters (defaults keep the
            2020 window used so far).

    The file is named ``{mco_id}_{member_id}_{age}_{gender}_{target}.txt`` with
    ``target = member['impactable_spend_post']``.
    """
    member_id = member['member_id']
    target = member['impactable_spend_post']
    age = member['age']
    gender = member['gender']

    cur.execute(member_claims_query, {
        'mco_id': mco_id,
        'member_id': member_id,
        'pre_start': pre_start,
        'pre_end': pre_end,
    })

    tokens = ['xxbos']
    # Rows are (member_id, date_from, days_since_last_encounter, icds_by_seq); the first
    # two columns are not needed and are bound to throwaway names so member_id is not shadowed.
    for _row_member_id, _date_from, ttlc, icds_by_seq in cur:
        if ttlc is not None:
            tokens.append(get_days_cat(ttlc))
        # dict.fromkeys de-duplicates codes while keeping their order; the aggregate can be
        # NULL/empty when an encounter has no diagnosis codes, which must not crash.
        tokens.extend(dict.fromkeys(icds_by_seq or ()))

    dataset = 'valid' if is_validation else 'train'
    out_dir = os.path.join(save_path, dataset)
    os.makedirs(out_dir, exist_ok=True)
    file_name = os.path.join(out_dir, f'{mco_id}_{member_id}_{age}_{gender}_{target}.txt')

    with open(file_name, 'w', encoding='utf-8') as f:
        f.write(' '.join(tokens))

In [9]:
def build_language_model_data_for_mco(mco_id, save_path='./data/icds_and_target',
                                      valid_frac=0.2, seed=None):
    """Write one token document per member of ``mco_id`` into train/valid folders.

    Members come from ``member_ids_query``. Each is sent to the validation split with
    probability ``valid_frac`` and handed to ``fetch_and_build_member_data``.

    Args:
        mco_id: MCO whose members are exported.
        save_path: root output directory (``train``/``valid`` subfolders beneath it).
        valid_frac: probability that a member lands in the validation split.
        seed: if given, a private ``random.Random(seed)`` drives the split so it is
            reproducible; if None, the module-level ``random`` functions are used
            (so ``random.seed`` still controls the split).
    """
    # Names for the positional columns returned by `SELECT *` in member_ids_query.
    # NOTE(review): must match the column order of junk.language_model_test20220923.
    cols = ['member_id', 'mco_id', 'pre_start', 'pre_end', 'post_start', 'post_end',
            'impactable_spend_pre', 'impactable_spend_post', 'pre_impactable_spend_pct',
            'post_impactable_spend_pct', 'age', 'gender']
    rng = random if seed is None else random.Random(seed)

    conn = cb_utils.get_conn()
    try:
        cur = conn.cursor()
        cur.execute(member_ids_query, {'mco_id': mco_id})
        # Materialise the cohort first: the same cursor is re-executed once per member below.
        members = cur.fetchall()
        for row in tqdm(members):
            # random() >= 1 - valid_frac selects ~valid_frac of members for validation.
            is_validation = rng.random() >= 1 - valid_frac
            member = dict(zip(cols, row))
            fetch_and_build_member_data(cur, mco_id, member, is_validation, save_path)
    finally:
        # Always release the database connection, even if a member fails mid-way.
        conn.close()

In [14]:
# Every MCO the language-model data can be built for.
ALL_MCO_IDS = [1, 2, 4, 5, 6, 8, 9]
# MCOs built in this run; set to ALL_MCO_IDS to rebuild everything.
mco_ids = [9]
for mco_id in mco_ids:
    build_language_model_data_for_mco(mco_id)

0it [00:00, ?it/s]
