In [2]:
import os
import sys
import re
import time
import random
import warnings
import collections
from tqdm import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns

sys.path.append('../src')
import cb_utils

sns.set(style="darkgrid")
pd.options.display.max_columns = 500

%load_ext autoreload
%autoreload 2

In [55]:
# Parameterised query (pyformat placeholders) returning the id of every member of one MCO.
# Plain string literals are used instead of f-strings: Python interpolates nothing here,
# and an f-prefix would mis-parse any literal `{` / `}` later added to the SQL text.
member_ids_query = "SELECT id FROM cb.members m where m.mco_id = %(mco_id)s"

# One row per claim of a single member, ordered by service date. Output columns, in order:
#   member_id, date_from, claim_id, prev_claim_date, ttlc (date_from - prev_claim_date),
#   is_rx, place_of_service, cpt, drug_name, paid_amount, icds_by_seq, icds_by_alpha
# `prev_claim_date` is computed with LAG over all of the member's claims in the CTE,
# before the join to diagnoses.
# NOTE(review): the outer WHERE filters on `cd.mco_id`, so claims with no matching
# cb.claims_diagnosis row are dropped even though the join is written as a LEFT JOIN.
# Confirm this is intended; if the filter is moved into the ON clause, `icds_by_seq`
# can become NULL and the code that joins it must handle None.
member_claims_query = """
WITH
    lagged_claims AS (
        SELECT
            member_id
          , c.id                                                            claim_id
          , date_from
          , LAG(date_from) OVER (PARTITION BY member_id ORDER BY date_from) prev_claim_date
          , is_rx
          , place_of_service
          , procedure_code                                                  cpt
          , LOWER(ndc.non_proprietary_name)                                 drug_name
          , paid_amount
        FROM
            cb.claims c
            LEFT JOIN ref.ndc_cder ndc ON ndc.id = c.rx_ndc_code_id
        WHERE
            c.mco_id = %(mco_id)s
        and c.member_id = %(member_id)s
    )
SELECT
    member_id
  , date_from
  , c.claim_id
  , prev_claim_date
  , date_from - prev_claim_date ttlc
  , is_rx
  , place_of_service
  , cpt
  , drug_name
  , paid_amount
  , ARRAY_AGG(cd.diag ORDER BY cd.diag_sequence) FILTER ( WHERE cd.diag IS NOT NULL) icds_by_seq
  , ARRAY_AGG(DISTINCT cd.diag ORDER BY cd.diag) FILTER ( WHERE cd.diag IS NOT NULL) icds_by_alpha
FROM
    lagged_claims c
    LEFT JOIN cb.claims_diagnosis cd ON c.claim_id = cd.claim_id
WHERE
    cd.mco_id = %(mco_id)s
and c.member_id = %(member_id)s
GROUP BY
    1, 2, 3, 4, 5, 6, 7, 8, 9, 10
ORDER BY
    1, 2
   """

In [57]:
def get_days_cat(time_to_last_claim):
    """Map the gap since a member's previous claim to a coarse `ttlc_*` token.

    The gap is bucketed by inclusive upper bound: 0, 1, 2, 5, 10, 15, 30, 60, 90 and
    180. The token is `ttlc_<bound>` for the first bound the gap does not exceed, and
    `ttlc_gt180` for anything larger.

    Args:
        time_to_last_claim: Non-negative number comparable with ints.

    Returns:
        str: The bucket token, e.g. `'ttlc_0'`, `'ttlc_30'` or `'ttlc_gt180'`.

    Raises:
        ValueError: If `time_to_last_claim` is negative.
    """
    if time_to_last_claim < 0:
        # The original `raise "<str>"` is itself a TypeError in Python 3 (exceptions
        # must derive from BaseException), which hid the real problem.
        raise ValueError(
            f"Got a negative time to last claim ({time_to_last_claim!r}). should never happen"
        )

    # Inclusive upper bounds, ascending; the first bound that fits names the bucket.
    upper_bounds = (0, 1, 2, 5, 10, 15, 30, 60, 90, 180)
    for upper_bound in upper_bounds:
        if time_to_last_claim <= upper_bound:
            return f'ttlc_{upper_bound}'

    return 'ttlc_gt180'

In [70]:
def fetch_and_build_member_data(cur, mco_id, member_id, save_path):
    """Write one member's claim history as a single line of space-separated tokens.

    Executes `member_claims_query` for the member and writes the result to
    `{save_path}/{mco_id}_{member_id}.txt`. The text starts with `xxbos`. Each
    returned claim row then contributes a `ttlc_*` gap token (skipped when the row
    has no previous claim date) followed by its diagnosis codes in `icds_by_seq`
    order. When the query returns no rows the file contains only `xxbos`.

    Args:
        cur: Database cursor; it is executed and then iterated for result rows.
        mco_id: Value bound to the query's `mco_id` parameter and used in the file name.
        member_id: Value bound to the query's `member_id` parameter and used in the file name.
        save_path: Directory the file is written to. It is not created here.
    """
    cur.execute(member_claims_query, {'mco_id': mco_id, 'member_id': member_id})

    tokens = ['xxbos']
    # Unpack all 12 result columns so a change in the query's shape fails loudly. Unused
    # columns get a leading underscore, which also keeps the row's member id from
    # shadowing the `member_id` parameter used for the file name below.
    for (_member_id, _date_from, _claim_id, prev_claim_date, ttlc, _is_rx,
         _place_of_service, _cpt, _drug_name, _paid_amount, icds_by_seq,
         _icds_by_alpha) in cur:
        if prev_claim_date is not None:
            tokens.append(get_days_cat(ttlc))
        tokens.append(' '.join(icds_by_seq))

    file_name = f'{save_path}/{mco_id}_{member_id}.txt'

    with open(file_name, 'w') as f:
        f.write(' '.join(tokens))

In [71]:
def build_language_model_data_for_mco(mco_id):
    """Generate the per-member token file for every member of one MCO.

    Fetches all member ids for `mco_id` with `member_ids_query`, then calls
    `fetch_and_build_member_data` once per member, writing into
    `../language_model_data/just_icds`. The output directory is not created here and
    must already exist.

    Args:
        mco_id: Value bound to the `mco_id` parameter of the member-id and claims queries.
    """
    save_path = '../language_model_data/just_icds'

    # assumes cb_utils.get_conn() returns a DB-API 2.0 connection — TODO confirm
    conn = cb_utils.get_conn()
    try:
        cur = conn.cursor()
        try:
            cur.execute(member_ids_query, {'mco_id': mco_id})
            # Materialise the ids first: the same cursor is re-executed for each member.
            member_ids = [row[0] for row in cur.fetchall()]

            for member_id in tqdm(member_ids):
                fetch_and_build_member_data(cur, mco_id, member_id, save_path)
        finally:
            cur.close()
    finally:
        # Previously neither the cursor nor the connection was ever closed, so each
        # call left a database connection open.
        conn.close()

In [72]:
# Build the per-member token files for MCO 2.
# Long-running: 28,919 members took about 45 minutes in the recorded run below.
build_language_model_data_for_mco(2)

100%|█████████████████████████████████████| 28919/28919 [44:41<00:00, 10.78it/s]


In [73]:
# Build the per-member token files for MCO 1.
# Long-running: 186,729 members took about 3 h 39 min in the recorded run below.
build_language_model_data_for_mco(1)

100%|█████████████████████████████████| 186729/186729 [3:38:59<00:00, 14.21it/s]


In [74]:
# Build the per-member token files for MCO 4.
# 3,883 members took about 10 minutes in the recorded run below.
build_language_model_data_for_mco(4)

100%|███████████████████████████████████████| 3883/3883 [10:12<00:00,  6.34it/s]


In [75]:
# Build the per-member token files for MCO 5.
# Long-running: 20,114 members took about 52 minutes in the recorded run below.
build_language_model_data_for_mco(5)

100%|█████████████████████████████████████| 20114/20114 [52:19<00:00,  6.41it/s]


In [76]:
# Build the per-member token files for MCO 6.
# 5,042 members took about 15 minutes in the recorded run below.
build_language_model_data_for_mco(6)

100%|███████████████████████████████████████| 5042/5042 [15:09<00:00,  5.54it/s]


### Modelling

In [77]:
from fastai.text.all import *

In [80]:
# Root of the generated corpus; build_language_model_data_for_mco writes into its
# `just_icds` sub-folder.
path = Path('../language_model_data')

In [81]:
# Collect the text files under `path`, restricted to the `just_icds` folder.
files = get_text_files(path, folders = ['just_icds'])

In [82]:
# Number of files found.
len(files)

244687

In [85]:
# Peek at the start of one generated file to sanity-check the token format.
txt = files[10].open().read()
txt[:75]

'xxbos ttlc_0 j40 j441 j40 j441 ttlc_90 j40 j441 ttlc_60 j40 j441 ttlc_30 j4'

In [96]:
# Split the first 2,000 files on single spaces into token lists (a small sample for
# inspecting numericalization). `read_text()` opens and closes each file, whereas
# `o.open().read()` left every handle for the garbage collector to close.
txts = L(o.read_text().split(' ') for o in files[:2000])

In [97]:
# Fit a numericalizer on the sampled token lists; this populates `num.vocab`.
num = Numericalize()
num.setup(txts)

In [98]:
# Show the vocabulary size and its first 20 entries.
coll_repr(num.vocab, 20)

"(#5944) ['xxunk','xxpad','xxbos','xxeos','xxfld','xxrep','xxwrep','xxup','xxmaj','ttlc_0','ttlc_1','i10','r5381','n186','e119','j449','r6889','r53','d631','n2581'...]"

In [107]:
# Apply the numericalizer to every sampled token list.
nums = txts.map(num)

In [108]:
# Round-trip check: map the first numericalized document back to its tokens.
' '.join(num.vocab[o] for o in nums[0])

'xxbos ttlc_0 i1311 n185 z905 r630 r634 d631 ttlc_0 i1311 n185 z905 r630 r634 d631 ttlc_0 i1311 n185 z905 r630 r634 d631 ttlc_0 i1311 n185 z905 r630 r634 d631 ttlc_0 i1311 n185 z905 r630 r634 d631 ttlc_0 i1311 n185 z905 r630 r634 d631 ttlc_0 i1311 n185 z905 r630 r634 d631 ttlc_0 i1311 n185 z905 r630 r634 d631 ttlc_0 i1311 n185 z905 r630 r634 d631 ttlc_0 i1311 n185 z905 r630 r634 d631 ttlc_0 i1311 n185 z905 r630 r634 d631 ttlc_0 i1311 n185 z905 r630 r634 d631 i1311 n185 z905 r630 r634 d631 ttlc_0 i1311 n185 z905 r630 r634 d631 ttlc_0 i1311 n185 z905 r630 r634 d631 ttlc_0 i1311 n185 z905 r630 r634 d631 ttlc_0 i1311 n185 z905 r630 r634 d631 ttlc_0 i1311 n185 z905 r630 r634 d631 ttlc_0 i1311 n185 z905 r630 r634 d631 ttlc_0 i1311 n185 z905 r630 r634 d631 ttlc_0 i1311 n185 z905 r630 r634 d631 ttlc_0 i1311 n185 z905 r630 r634 d631 ttlc_0 i1311 n185 z905 r630 r634 d631 ttlc_0 i1311 n185 z905 r630 r634 d631 ttlc_0 i1311 n185 z905 r630 r634 d631 ttlc_0 i1311 n185 z905 r630 r634 d631 ttlc_0 i1311

In [109]:
# Language-model data loader over the numericalized sample, using LMDataLoader's defaults.
dl = LMDataLoader(nums)

In [110]:
# Take the first batch from the loader: `x` and `y`.
x,y = first(dl)

In [112]:
# Both tensors have the same shape (64 x 72 in the recorded output).
x.shape, y.shape

(torch.Size([64, 72]), torch.Size([64, 72]))

In [113]:
# First 20 tokens of the first row of `x`, decoded.
' '.join(num.vocab[o] for o in x[0][:20])

'xxbos ttlc_0 i1311 n185 z905 r630 r634 d631 ttlc_0 i1311 n185 z905 r630 r634 d631 ttlc_0 i1311 n185 z905 r630'

In [114]:
# First 20 tokens of the first row of `y`, decoded; in the recorded output this is the
# `x` row shifted forward by one token.
' '.join(num.vocab[o] for o in y[0][:20])

'ttlc_0 i1311 n185 z905 r630 r634 d631 ttlc_0 i1311 n185 z905 r630 r634 d631 ttlc_0 i1311 n185 z905 r630 r634'

In [120]:
dl?

[0;31mType:[0m        LMDataLoader
[0;31mString form:[0m <fastai.text.data.LMDataLoader object at 0x1583917c0>
[0;31mLength:[0m      463
[0;31mFile:[0m        ~/.local/share/virtualenvs/data-analytics-1yVNxZKx/lib/python3.8/site-packages/fastai/text/data.py
[0;31mDocstring:[0m   A `DataLoader` suitable for language modeling


In [121]:
# Language-model DataLoaders over every file in `just_icds`, holding out 10% for validation.
# A fixed seed (instead of seed=None) makes the random train/validation split reproducible
# across kernel restarts, so metrics from different runs are comparable.
# NOTE(review): tok_tfm=None presumably selects fastai's default tokenizer — confirm it
# keeps tokens such as `ttlc_0` intact and does not add a second `xxbos` to each file.
dls = TextDataLoaders.from_folder(
    path / 'just_icds',
    valid_pct=.1,
    seed=42,
    is_lm=True,
    tok_tfm=None,
    seq_len=72,
    backwards=False,
    bs=64,
    val_bs=None,
    shuffle=True,
    device=None,
)

In [124]:
# AWD-LSTM language model built without pretrained weights, tracking accuracy and
# perplexity, then switched to fp16 via `to_fp16()`.
learn = language_model_learner(
    dls,
    AWD_LSTM,
    drop_mult=0.3,
    pretrained=False,
    metrics=[accuracy, Perplexity()],
).to_fp16()



In [125]:
learn.fit_one_cycle(1, 

[0;31mSignature:[0m
[0mlearn[0m[0;34m.[0m[0mfit_one_cycle[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mn_epoch[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mlr_max[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mdiv[0m[0;34m=[0m[0;36m25.0[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mdiv_final[0m[0;34m=[0m[0;36m100000.0[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mpct_start[0m[0;34m=[0m[0;36m0.25[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mwd[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmoms[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mcbs[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mreset_opt[0m[0;34m=[0m[0;32mFalse[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m Fit `self.model` for `n_epoch` using the 1cycle policy.
[0;31mFile:[0m      ~/.local/share/virtualenvs/data-analytics-1yVNxZKx/lib/python3

In [130]:
# Run fastai's learning-rate finder on the learner.
learn.lr_find()