In [1]:
import csv
import os
import json
import re
import glob
import copy
import pickle
import pandas as pd
from tqdm import tqdm
from collections import defaultdict
from transformers import BasicTokenizer

import sys
sys.path.append(os.path.dirname(os.getcwd()))   # add ../ to the system path
from Tokenizers import NaiveTokenizer           # then we can import our tokenizer

# 1. Raw data --> subject (i.e., patient) oriented data

- Saving events of a subject into a single file;
- Only keeping subjects whose events have at least one patient instruction (PI);
- Sorting events of a subject according to time (the top most event occurs first).

In [2]:
def prepare_subject_data(
        raw_data_path='NOTEEVENTS.csv', 
        save_path='./subjects_with_PI', 
        save_info=True,
        only_keep_subjects_with_PI=True,
        sort=True,
    ):
    print('- Reading data from', raw_data_path)
    if save_info:
        print('- Saving subject data to', save_path)
        os.makedirs(save_path, exist_ok=True)

    subjects = set()
    subjects_with_PI = {}

    pattern = re.compile('discharge\s*instructions.{0,30}:') # `discharge instruction` is referred as PI
    with open(raw_data_path) as f:
        reader = csv.reader(f, delimiter = ',')
        header = next(reader)
        print('- Header of the raw data:', header)

        for line in tqdm(reader):
            assert len(line) == 11, line
            subject_id = line[1]
            category = line[6]
            description = line[7]

            subjects.add(subject_id)

            if category.lower() == 'discharge summary':
                text = line[-1].lower()
                if description.lower() == 'report':
                    if pattern.search(text) is not None:
                        # a subject may have several hospitial admissions
                        subjects_with_PI[subject_id] = subjects_with_PI.get(subject_id, 0) + 1
            
            if save_info:
                this_subject_save_path = os.path.join(save_path, subject_id+'.csv')

                if os.path.exists(this_subject_save_path):
                    with open(this_subject_save_path, 'a') as wf:
                        writer = csv.writer(wf)
                        writer.writerow(line)
                else:
                    with open(this_subject_save_path, 'w') as wf:
                        writer = csv.writer(wf)
                        writer.writerow(header)
                        writer.writerow(line)
        
        print(f'- There are {len(subjects)} unique subjects in the raw data')
        print(f'- Only {len(subjects_with_PI)} subjects have >=1 PI')

        if save_info and only_keep_subjects_with_PI:
            for id in subjects:
                if id not in subjects_with_PI:
                    os.remove(os.path.join(save_path, id+'.csv'))                
    
    if sort:
        sort_subject_data(save_path)

    return subjects_with_PI


def reformat_date(string):
    year, month, day = string.split('-')
    return '{:04d}-{:02d}-{:02d}'.format(int(year), int(month), int(day))


def sort_subject_data(src='./subjects_with_PI'):
    print(f'- Sorting subject data in {src}')
    print('- For all events of a subject, the higher rank of the event, the earlier it happens')
    assert os.path.exists(src)
    all_subjects = glob.glob(os.path.join(src, '*.csv'))

    for subject_path in tqdm(all_subjects):
        with open(subject_path, 'r') as f:
            reader = csv.reader(f, delimiter=',')
            header = next(reader)
            record_appearance = set()

            all_lines = []
            for line in reader:
                line[3] = reformat_date(line[3]) #chartdate
                if line[-1] in record_appearance:
                    continue
                else:
                    record_appearance.add(line[-1])
                all_lines.append(line)
            
            all_lines = sorted(all_lines, key=lambda x: x[3])
        
        with open(subject_path, 'w') as wf:
            writer = csv.writer(wf)
            writer.writerow(header)
            for line in all_lines:
                writer.writerow(line)

In [3]:
subjects_with_PI = prepare_subject_data(
    raw_data_path='NOTEEVENTS.csv', 
    save_path='./subjects_with_PI/', 
    save_info=True,
    only_keep_subjects_with_PI=True,
    sort=True,
)

- Reading data from NOTEEVENTS.csv
- Saving subject data to ./subjects_with_PI/
- Header of the raw data: ['ROW_ID', 'SUBJECT_ID', 'HADM_ID', 'CHARTDATE', 'CHARTTIME', 'STORETIME', 'CATEGORY', 'DESCRIPTION', 'CGID', 'ISERROR', 'TEXT']


2083180it [09:35, 3618.33it/s]


- There are 46146 unique subjects in the raw data
- Only 30554 subjects have >=1 PI
- Sorting subject data in ./subjects_with_PI/
- For all events of a subject, the higher rank of the event, the earlier it happens


100%|██████████| 30554/30554 [04:31<00:00, 112.67it/s]


# 2. Extract patient instructions from subject oriented data

- Each patient instruction is saved to a single file.
- Each file is named like `subjectID_admissionID.csv`, because a subject may have several hospital admissions.

In [4]:
def clean_string(string):
    string = string.replace('??????', ' ')
    string = string.replace(r'\n', ' ')
    string = string.replace(r'\t', ' ')
    string = re.sub('\.\s*\.', '.', string) # e.g., a full stop followed by any spaces -->  a full stop
    strig = string.strip()
    return string

def extract_patient_instructions(
        src_path='./subjects_with_PI', 
        trg_path='./patient_instructions', 
        subject_ids=[],
        prev_failed_subjectID_to_admissionIDs=None,
        first_key='discharge\s*instructions.{0,30}:',
        second_key='followup\s*instruction',
        record_na_subjects=None,
        record_dead_subjects=None,
        index_pi_to_the_last=False
    ):    
    assert os.path.exists(src_path)
    os.makedirs(trg_path, exist_ok=True)
    
    if len(subject_ids):
        all_subjects = [os.path.join(src_path, sid+'.csv') for sid in subject_ids]
    else:
        all_subjects = glob.glob(os.path.join(src_path, '*.csv'))

    pattern = re.compile(r'{}(.*?){}'.format(first_key, second_key))

    failed_subjectID = set()
    failed_subjectID_to_admissionIDs = defaultdict(set)

    for subject_path in tqdm(all_subjects):
        with open(subject_path, 'r') as f:
            reader = csv.reader(f, delimiter=',')
            for line in reader:
                subject_id = line[1]
                if line[6].lower() == 'discharge summary' and line[7].lower() == 'report':
                    text = line[-1].lower()

                    if re.search(r'discharge\s*instructions.{0,30}:', repr(text)) is None:
                        continue

                    admission_id = line[2]
                    if prev_failed_subjectID_to_admissionIDs is not None \
                        and admission_id not in prev_failed_subjectID_to_admissionIDs[subject_id]:
                        continue

                    key = f'{subject_id}_{admission_id}'

                    # find the text between first_key and second key
                    if index_pi_to_the_last:
                        repr_text = repr(text)
                        pos = re.search(r'discharge\s*instructions.{0,30}:', repr_text).span()[1]
                        result = [repr_text[pos:]]
                    else:
                        result = pattern.findall(repr(text))

                    if len(result):
                        assert len(result) == 1

                        if record_dead_subjects is not None:
                            tmp_flag = False
                            for ts in ['expired', 'deceased']:
                                if re.search(r'\b{}\b'.format(ts), result[0]) is not None:
                                    tmp_flag = True
                                    break
                            if tmp_flag:
                                record_dead_subjects.add(key)

                        if record_na_subjects is not None:
                            if re.search(r'n/a\b', result[0]) is not None:
                                record_na_subjects.add(key)

                        with open(os.path.join(trg_path, key), 'w') as wf:
                            pi = result[0]
                            pi = clean_string(pi)
                            if index_pi_to_the_last:
                                pi = pi[:-3] # ignore the last useless chars
                            wf.write(pi)
                    else:
                        failed_subjectID.add(subject_id)
                        failed_subjectID_to_admissionIDs[subject_id].add(admission_id)
    
    return failed_subjectID, failed_subjectID_to_admissionIDs


def show_saved_di(keys, save_path='./patient_instructions', k=3):
    assert os.path.exists(save_path)
    for i, key in enumerate(keys):
        if i >= k:
            break
        with open(os.path.join(save_path, key), 'r') as f:
            print(key, f.readlines())
            print('-'*50)
        

In [5]:
failed_subjectID, failed_subjectID_to_admissionIDs = [], None
record_na_subjects = set()
record_dead_subjects = set()

for second_key in [
        'followup\s*instruction', 
        '\[\*\*.{0,50}\*\*\][p\s]\s*instruction',
        'instructions:',
        '\s{5,}\[\*\*',
        '\s{5,}dr[\s\.]\s*\[\*\*',
        'reviewed\s*by',
        'completed\s*by',
        'dictated\s*by',
        ''
    ]:
    if second_key == '':
        index_pi_to_the_last = True
    else:
        index_pi_to_the_last = False

    failed_subjectID, failed_subjectID_to_admissionIDs = extract_patient_instructions(
        src_path='./subjects_with_PI/',
        trg_path='./patient_instructions/',
        subject_ids=failed_subjectID, 
        prev_failed_subjectID_to_admissionIDs=failed_subjectID_to_admissionIDs,
        second_key=second_key,
        record_na_subjects=record_na_subjects,
        record_dead_subjects=record_dead_subjects,
        index_pi_to_the_last=index_pi_to_the_last,
    )
    print('In this round, fail to extract %d PIs of %d subjects' % (
        sum([len(v) for k, v in failed_subjectID_to_admissionIDs.items()]), 
        len(failed_subjectID)))

print('Successfully extract %d patient instructions' % (len(os.listdir('./patient_instructions/'))))

100%|██████████| 30554/30554 [02:16<00:00, 223.94it/s]


In this round, fail to extract 2704 PIs of 2592 subjects


100%|██████████| 2592/2592 [00:09<00:00, 278.84it/s]


In this round, fail to extract 2357 PIs of 2259 subjects


100%|██████████| 2259/2259 [00:09<00:00, 241.58it/s]


In this round, fail to extract 2318 PIs of 2223 subjects


100%|██████████| 2223/2223 [00:07<00:00, 303.05it/s]


In this round, fail to extract 297 PIs of 297 subjects


100%|██████████| 297/297 [00:00<00:00, 309.89it/s]


In this round, fail to extract 196 PIs of 196 subjects


100%|██████████| 196/196 [00:00<00:00, 266.16it/s]


In this round, fail to extract 183 PIs of 183 subjects


100%|██████████| 183/183 [00:00<00:00, 290.84it/s]


In this round, fail to extract 101 PIs of 101 subjects


100%|██████████| 101/101 [00:00<00:00, 281.62it/s]


In this round, fail to extract 70 PIs of 70 subjects


100%|██████████| 70/70 [00:00<00:00, 224.20it/s]

In this round, fail to extract 0 PIs of 0 subjects
Successfully extract 39275 patient instructions





In [6]:
print('There are %d patient instructions not suitable to use.' % (len(record_na_subjects) + len(record_dead_subjects)))
print('Here are some examples:')
show_saved_di(record_na_subjects, save_path='./patient_instructions', k=3)
show_saved_di(record_dead_subjects, save_path='./patient_instructions', k=3)

There are 1149 patient instructions not suitable to use.
Here are some examples:
85723_121701 [' n/a  ']
--------------------------------------------------
25395_178023 [' n/a  ']
--------------------------------------------------
31916_146431 [' n/a  ']
--------------------------------------------------
9544_115263 [' (deceased)  ']
--------------------------------------------------
32361_194740 [' patient expired at 6/7.  ']
--------------------------------------------------
59260_176324 [' patient expired on [**2167-1-27**].  ']
--------------------------------------------------


# 3. Extract health records from subject oriented data

- Each health record is saved to a single file.
- Each file is named like `subjectID_admissionID`, because a subject may have several hospital admissions.

In [7]:
def extract_health_records(
        src_path='./subjects_with_PI',
        trg_path='./health_records',
        delete_first_two_lines=True,
    ):    
    assert os.path.exists(src_path)
    os.makedirs(trg_path, exist_ok=True)
    
    all_subjects = glob.glob(os.path.join(src_path, '*.csv'))

    success_count = 0
    for subject_path in tqdm(all_subjects):
        with open(subject_path, 'r') as f:
            reader = csv.reader(f, delimiter=',')
            for line in reader:
                subject_id = line[1]
                if line[6].lower() == 'discharge summary' and line[7].lower() == 'report':
                    text = line[-1].lower()

                    if re.search(r'discharge\s*instructions.{0,30}:', repr(text)) is None:
                        continue

                    admission_id = line[2]
                    key = f'{subject_id}_{admission_id}'

                    if delete_first_two_lines:
                        '''
                        the first two lines of health record look like below

                            Date of Birth:   [**2143-5-12**]     Sex:  F
                            Service:
                        
                        We delete these information by default.
                        '''
                        
                        tmp = re.search(r'\s*sex:.*\n', text)
                        tmp2 = re.search(r'[\n]service.*:', text)
                        tmp3 = re.search(r'\nallergies:', text)
                        tmp4 = re.search(r'\nchief\s*complaint.*:', text)
                        tmp5 = re.search(r'\nhistory\s*of.*:', text)
                        if tmp is not None:
                            text = text[tmp.span()[1]:]
                        elif tmp2 is not None:
                            text = text[tmp2.span()[0]:]
                        elif tmp3 is not None:
                            text = text[tmp3.span()[0]:]
                        elif tmp4 is not None:
                            text = text[tmp4.span()[0]:]
                        elif tmp5 is not None:
                            text = text[tmp5.span()[0]:]
                        else:
                            print(text)

                    text = repr(text)
                    pos = re.search(r'discharge\s*instructions.{0,30}:', text).span()[0]
                    health_record = text[:pos]
                    health_record = clean_string(health_record)
                    with open(os.path.join(trg_path, key), 'w') as wf:
                        wf.write(health_record)
                        success_count += 1
    
    print(f'Successfully extract {success_count} health records')


In [9]:
_ = extract_health_records(
    src_path='./subjects_with_PI',
    trg_path='./health_records'
)

100%|██████████| 30554/30554 [02:19<00:00, 218.62it/s]

Successfully extract 39608 health records





# 4. Post-process patient instructions and health records

- Replace desensitized data with special tokens (e.g., [** 2022/07 **] --> [date])

In [10]:
replaced_tokens_mapping = {
    '[date]': [
        ('re', r'\[\*\*[0-9\-/]*\*\*\]'), 
        ('in', ['month', 'year', 'date range']), 
        ('in', ['january', 'february', 'march', 'april', 'may', 'june', 'july', 'august', 'september', 'october', 'november', 'december'])],
    '[hospital]': ('in', ['hospital']),
    '[contact]': ('in', ['telephone', 'contact']),
    '[address]': ('in', ['address', 'location']),
    '[identifier]': ('in', ['identifier']),
    '[country]': ('in', ['country']),
    '[university]': ('in', ['university', 'college']),
    '[company]': ('in', ['company']),
    '[provider]': ('in', ['provider']),
    '[url]': ('in', ['url']),
    '[state]': ('in', ['state']),
    '[holiday]': ('in', ['holiday']),
    '[number]': ('in', ['medical record number', 'serial number', 'social security number', 
                        'job number', 'pager number', 'unit number', 'po box', 'md number']),
    '[name]': ('in', ['lastname', 'last name', 'firstname', 'first name', 'nameis', 'namepattern',
                        '(ptitle)', '(stitle)', '(ni)', '(md)', '(pre)', '(prefixes)', 'wardname']),
    '[age]': ('in', ['age over']),
    '[clip]': ('in', ['clip number', 'radiology']),
    '[info]': ('in', ['attending info'])
}

In [11]:
def check_replace(text, mapping):
    for k, vs in mapping.items():
        if not isinstance(vs, list):
            vs = [vs]
        for v in vs:
            assert len(v) == 2
            type_, string = v
            if type_ == 'in':
                if not isinstance(string, list):
                    string = [string]
                flag = False
                for item in string:
                    if item in text:
                        flag = True
                if flag:
                    # here, we add spaces before and after the speicial tokens
                    # to avoid errors in Transformers.BasicTokenizer
                    return ' ' + k + ' '
            elif type_ == 're':
                if re.search(string, text) is not None:
                    return ' ' + k + ' '
            else:
                raise ValueError()
    
    return None


def clean_tokenized_data(data):
    new_data = []
    for word in data:
        if not word.isalnum() and len(new_data) == 0:
            # delete the number at the beginning
            continue
        if len(new_data) > 0 and word == new_data[-1]:
            # delete repeated words
            continue
        new_data.append(word)
    
    return new_data


def replace_special_tokens(
        src_path='./patient_instructions', 
        trg_path='./processed_patient_instructions',
        mapping={},
        extra_never_split=[],
        all_file_paths=None,
        save_pickle=True,
        save_pickle_path='.',
        save_pickle_prefix='',
        print_info=True
    ):
    assert os.path.exists(src_path)
    os.makedirs(trg_path, exist_ok=True)
    pattern = re.compile(r'\[\*\*[^*]*\*\*\]')

    if len(mapping) or len(extra_never_split):
        never_split = list(mapping.keys()) + extra_never_split
        tokenizer = BasicTokenizer(never_split=never_split)
    else:
        tokenizer = BasicTokenizer()

    unpropcessed_items = set()
    record_special_tokens = defaultdict(list)
    word_counts = {}
    length_counts = {}
    length_to_key = defaultdict(list)

    if all_file_paths is None:
        all_file_paths = [os.path.join(src_path, item) for item in os.listdir(src_path)]

    for path in tqdm(all_file_paths):
        file = os.path.basename(path)
        with open(path, 'r') as f:
            data = f.read()
            this_results = pattern.findall(data)

            new_data = copy.deepcopy(data)
            for item in this_results:
                replace_to = check_replace(item, mapping)
                if replace_to is not None:
                    new_data = new_data.replace(item, replace_to)
                    record_special_tokens[replace_to].append(item)
                else:
                    unpropcessed_items.add(item)
            
            tokenized_data = clean_tokenized_data(tokenizer.tokenize(new_data))

            for w in tokenized_data:
                word_counts[w] = word_counts.get(w, 0) + 1

            length = len(tokenized_data)
            length_counts[length] = length_counts.get(length, 0) + 1
            length_to_key[length].append(file)
        
        with open(os.path.join(trg_path, file), 'w') as wf:
            wf.write(' '.join(tokenized_data))
    
    if save_pickle:
        os.makedirs(save_pickle_path, exist_ok=True)
        pickle.dump(length_counts, open(os.path.join(save_pickle_path, f'{save_pickle_prefix}length_counts.pkl'), 'wb'))
        pickle.dump(length_to_key, open(os.path.join(save_pickle_path, f'{save_pickle_prefix}length_to_key.pkl'), 'wb'))
        pickle.dump(word_counts, open(os.path.join(save_pickle_path, f'{save_pickle_prefix}word_counts.pkl'), 'wb'))
    
    if print_info:
        print('unprocessed items', len(unpropcessed_items), unpropcessed_items)
        print('-' * 50)
        for k in record_special_tokens.keys():
            print('frequency of', k, len(record_special_tokens[k]))

    return unpropcessed_items, record_special_tokens, word_counts, length_counts, length_to_key


def postprocess_patient_instructions(
        src_path='./patient_instructions/',
        trg_path='./processed_patient_instructions/',
        **kwargs,
    ):
    assert os.path.exists(src_path)

    return replace_special_tokens(
        src_path=src_path,
        trg_path=trg_path,
        **kwargs
    )


def postprocess_health_records(
        subjects_to_keys=None,
        src_path='./health_records', 
        trg_path='./processed_health_records',
        **kwargs
    ):

    assert os.path.exists(src_path)
    
    if subjects_to_keys is not None:
        all_file_paths = []
        for subject in subjects_to_keys:
            for key in subjects_to_keys[subject]:
                all_file_paths.append(os.path.join(src_path, key))
    else:
        all_file_paths = None

    return replace_special_tokens(
        src_path=src_path,
        trg_path=trg_path,
        all_file_paths=all_file_paths,
        **kwargs
    )

In [12]:
_, _, pi_word_counts, pi_length_counts, pi_length_to_key = postprocess_patient_instructions(
    src_path='./patient_instructions/',
    trg_path='./processed_patient_instructions/',
    mapping=replaced_tokens_mapping, 
    save_pickle_path='./info',
    save_pickle_prefix='pi_',
)

100%|██████████| 39275/39275 [02:10<00:00, 300.17it/s]


unprocessed items 0 set()
--------------------------------------------------
frequency of  [date]  14092
frequency of  [hospital]  9826
frequency of  [name]  32349
frequency of  [address]  1480
frequency of  [contact]  9190
frequency of  [number]  56
frequency of  [identifier]  153
frequency of  [company]  97
frequency of  [state]  120
frequency of  [country]  22
frequency of  [provider]  10
frequency of  [age]  11
frequency of  [holiday]  8
frequency of  [university]  22
frequency of  [url]  12


In [13]:
_, _, hr_word_counts, hr_length_counts, hr_length_to_key = postprocess_health_records(
    src_path='./health_records/',
    trg_path='./processed_health_records/',
    mapping=replaced_tokens_mapping,
    save_pickle_path='./info',
    save_pickle_prefix='hr_'
)

100%|██████████| 39275/39275 [18:59<00:00, 34.46it/s] 


unprocessed items 0 set()
--------------------------------------------------
frequency of  [name]  300240
frequency of  [age]  5014
frequency of  [hospital]  162902
frequency of  [date]  1012436
frequency of  [state]  2750
frequency of  [contact]  14373
frequency of  [address]  35402
frequency of  [company]  3518
frequency of  [number]  2160
frequency of  [clip]  945
frequency of  [identifier]  6803
frequency of  [info]  55
frequency of  [country]  3390
frequency of  [holiday]  242
frequency of  [university]  538
frequency of  [url]  2
frequency of  [provider]  23


# 5. Split the data into train/val/test

In [16]:
import os
from sklearn import model_selection
from collections import defaultdict
import plotly.offline as py
import plotly.graph_objs as go
import pandas as pd
import pickle
from tqdm import tqdm

def plot_length_distribution(subjects, subjects_to_keys, key_to_length):
    length_counts = {}
    minl, maxl = 1e4, 0
    for subject in subjects:
        for key in subjects_to_keys[subject]:
            l = key_to_length[key]
            minl = min(minl, l)
            maxl = max(maxl, l)
            length_counts[l] = length_counts.get(l, 0) + 1
    
    x = [str(i) for i in range(minl, maxl+1)]
    y = [length_counts.get(i, 0) for i in range(minl, maxl+1)]
    py.iplot([go.Bar(x=x, y=y, marker={'color': 'red', 'line': {'width': 0.1, 'color': 'red'}})])


def filter_addendum_in_health_records(
        src_path='./processed_health_records',
        keys=['nsu addendum', 'service : addendum', 'med addendum', 'nsurg addendum', 
        'neurosurgery addendum', 'surgery addendum', 'neonatology addendum', 'neonatology addendum', 
        'ccu addendum', 'acove addendum', 'this is an addendum', 'medicine addendum']
    ):

    assert os.path.exists(src_path)
    filter_keys = set()

    for key in tqdm(os.listdir(src_path)):
        with open(os.path.join(src_path, key), 'r') as f:
            data = f.read()
        
        flag = False
        for k in keys:
            if k in data:
                flag = True
                break
        
        if flag:
            filter_keys.add(key)
    
    return filter_keys


def splitting(
        ratios=[0.2, 0.5], # (1) the ratio of non-training data; (2) the ratio of testing data in non-training data
        length_range=[4, 1e6],
        src_paths={'discharge_instruction': './processed_patient_instructions', 'discharge_summary': './processed_health_records'},
        root_path='./splits',
        names=['train', 'val', 'test'],
        plot=True,
        pi_length_to_key=None,
        pi_length_to_key_path='./info/pi_length_to_key.pkl',
        ):

    for src_path in src_paths.values():
        assert os.path.exists(src_path)
    os.makedirs(root_path, exist_ok=True)
    print('save to', root_path)
    
    valid_subjects_to_keys = defaultdict(list)
    key_to_length = {}
    valid_keys = set()
    
    assert len(length_range) == 2
    min_length, max_length = length_range

    filter_keys = filter_addendum_in_health_records(src_path=src_paths['discharge_summary'])
    if pi_length_to_key is None:
        assert os.path.exists(pi_length_to_key_path)
        pi_length_to_key = pickle.load(open(pi_length_to_key_path, 'rb'))

    for l, keys in pi_length_to_key.items():
        if min_length < l < max_length:
            for k in keys:
                if k in filter_keys:
                    continue

                subject = k.split('_')[0]
                valid_subjects_to_keys[subject].append(k)
                key_to_length[k] = l
                valid_keys.add(k)
    
    assert len(ratios) == 2
    non_train_ratio, test_ratio = ratios

    subjects_train, subjects_non_train = model_selection.train_test_split(
        list(valid_subjects_to_keys.keys()), test_size=non_train_ratio, random_state=0)
    
    subjects_val, subjects_test = model_selection.train_test_split(
        subjects_non_train, test_size=test_ratio, random_state=0)
    
    for subjects, name in zip([subjects_train, subjects_val, subjects_test], names):
        csv_path = os.path.join(root_path, f'{name}.csv')

        print(f'- There are {len(subjects)} `{name}` subjects')
        count = 0
        for subject in tqdm(subjects):
            for key in valid_subjects_to_keys[subject]:
                this_dict = {'key': [key]}
                for k, src_path in src_paths.items():
                    with open(os.path.join(src_path, key), 'r') as f:
                        data = f.read()
                    this_dict[k] = [data]

                df = pd.DataFrame(this_dict)
                if count == 0:
                    df.to_csv(csv_path, header=True, mode='w')
                else:
                    df.to_csv(csv_path, header=False, mode='a')
                count += 1
        
        print(f'- There are {count} `{name}` discharge instructions')
        
        if plot:
            plot_length_distribution(subjects, valid_subjects_to_keys, key_to_length)

    return subjects_train, subjects_val, subjects_test, valid_subjects_to_keys, key_to_length, valid_keys

In [17]:
subjects_train, subjects_val, subjects_test, valid_subjects_to_keys, key_to_length, valid_keys = splitting(root_path='./splits')

save to ./splits


100%|██████████| 39275/39275 [00:19<00:00, 2027.49it/s]


- There are 22423 `train` subjects


100%|██████████| 22423/22423 [01:01<00:00, 363.12it/s]


- There are 28673 `train` discharge instructions


- There are 2803 `val` subjects


100%|██████████| 2803/2803 [00:08<00:00, 314.59it/s]

- There are 3557 `val` discharge instructions





- There are 2803 `test` subjects


100%|██████████| 2803/2803 [00:08<00:00, 339.61it/s]

- There are 3621 `test` discharge instructions





In [18]:
# `discharge instruction` corresponds to `patient instruction`
# `discharge summary` corresponds to `health record`
data = pd.read_csv('./splits/test.csv')
data.head()

Unnamed: 0.1,Unnamed: 0,key,discharge_instruction,discharge_summary
0,0,56112_131325,you had an exacerbation of your copd and was t...,service : medicine allergies : patient recorde...
1,0,56112_127962,patient admitted with dyspnea . cardiac arrest...,service : medicine allergies : patient recorde...
2,0,92615_102566,you were admitted with hand pain and low blood...,service : medicine allergies : patient recorde...
3,0,81847_115715,"dear ms . [name] , it was a pleasure taking ca...",service : medicine allergies : sulfa ( sulfona...
4,0,93472_174433,you were admitted to the hospital after you ha...,service : surgery allergies : no known allergi...


# 6. Get the vocab from the training data

In [19]:
never_split_tokens = list(replaced_tokens_mapping.keys())
special_tokens_mapping = {
    'pad_token': '[pad]',
    'bos_token': '[bos]',
    'eos_token': '[eos]',
    'unk_token': '[unk]',
    'mask_token': '[mask]',    # unused
    'sep_token': '[sep]',      # unused
    'cls_token': '[cls]',      # unused
}

train_data = pd.read_csv('./splits/train.csv')
joint_word_counts = {}
tokenizer = BasicTokenizer(never_split=never_split_tokens)

for i in tqdm(range(len(train_data))):
    di = tokenizer.tokenize(train_data.iloc[i]['discharge_instruction'])
    ds = tokenizer.tokenize(train_data.iloc[i]['discharge_summary'])

    for w in di:
        joint_word_counts[w] = joint_word_counts.get(w, 0) + 1
    
    for w in ds:
        joint_word_counts[w] = joint_word_counts.get(w, 0) + 1

os.makedirs('./info', exist_ok=True)
pickle.dump(joint_word_counts, open('./info/joint_train_word_counts.pkl', 'wb'))

100%|██████████| 28673/28673 [13:19<00:00, 35.85it/s]


In [20]:
sorted_words = sorted(joint_word_counts.items(), key=lambda x: -x[1])
threshold = 20
tmp = [item[1] for item in sorted_words if item[1] > threshold]

total_count = sum([item[1] for item in sorted_words])
valid_count = sum(tmp)

n_total_word = len(joint_word_counts.keys())
n_kept_word = len(tmp)

print(f'Only keep words appearing more than {threshold} times')
print(f'{valid_count} of {total_count} ({valid_count * 100.0 / total_count:.2f} %) words are kept')
print(f'# unique words before filtering: {n_total_word}')
print(f'# unique words after filtering: {n_kept_word}')
print(f'# special tokens: {len(special_tokens_mapping)}')

save_path = './vocab'
os.makedirs(save_path, exist_ok=True)


words = list(special_tokens_mapping.values()) + [item[0] for item in sorted_words if item[1] > threshold]

vocab_file = './tmp_vocab'
with open(vocab_file, 'w') as wf:
    wf.write('\n'.join(words))

tokenizer = NaiveTokenizer(
    vocab_file, 
    never_split=never_split_tokens,
    **special_tokens_mapping
)

print('vocab size:', tokenizer.vocab_size)
print('special tokens:', tokenizer.all_special_tokens)
print('never split tokens:', never_split_tokens)

print('save vocab information to', save_path)
tokenizer.save_pretrained(save_path)

os.remove(vocab_file)

Only keep words appearing more than 20 times
65926402 of 66221429 (99.55 %) words are kept
# unique words before filtering: 122725
# unique words after filtering: 19893
# special tokens: 7
vocab size: 19900
special tokens: ['[bos]', '[eos]', '[unk]', '[sep]', '[pad]', '[cls]', '[mask]']
never split tokens: ['[date]', '[hospital]', '[contact]', '[address]', '[identifier]', '[country]', '[university]', '[company]', '[provider]', '[url]', '[state]', '[holiday]', '[number]', '[name]', '[age]', '[clip]', '[info]']
save vocab information to ./vocab


# 7. Print some statistics

In [21]:
tokenizer = NaiveTokenizer.from_pretrained('./vocab/joint/threshold20')

for p in ['./splits/train.csv', './splits/val.csv', './splits/test.csv']:
    data = pd.read_csv(p)
    li = 0
    lo = 0
    for i in range(len(data)):
        input = tokenizer.tokenize(data.iloc[i]['discharge_summary'])
        output = tokenizer.tokenize(data.iloc[i]['discharge_instruction'])

        li += len(input)
        lo += len(output)
    
    print(p, 'average input length', li * 1.0 / len(data))
    print(p, 'average output length', lo * 1.0 / len(data))
    

./splits/train.csv average input length 2147.0526976598194
./splits/train.csv average output length 162.48690405608065
./splits/val.csv average input length 2144.9398369412424
./splits/val.csv average output length 164.47877424796175
./splits/test.csv average input length 2124.2695388014363
./splits/test.csv average output length 162.8224247445457
