In [1]:
import csv
import os
import sys
import xml.etree.ElementTree as ET

In [2]:
class MyFile(object):
    """A single annotated XML record, flattened into CSV-ready rows.

    The constructor parses ``<data_dir>/<file_name>``. The root element's
    first child supplies the free text of the record and its second child
    holds the annotation tags (medications, diagnoses, PHI, smoking status).

    Parameters
    ----------
    file_name : str
        Name of the XML file inside ``data_dir``. Characters 0-2 become the
        patient id and characters 4-5 the encounter id (presumably names
        look like ``PPP-EE.xml`` -- verify against the data set).
    data_dir : str, optional
        Directory containing the XML files; defaults to ``./raw_complete``.
    """

    # Annotation tags that are never reported as diagnoses.
    _NON_DIAGNOSIS_TAGS = ('MEDICATION', 'SMOKER', 'FAMILY_HIST', 'PHI')
    # Attributes copied from every annotation tag into its row, in column order.
    _ROW_ATTRIBUTES = ('id', 'time', 'type1', 'type2')
    # Tokens counted (surrounded by single spaces) by determine_sex().
    _FEMALE_TOKENS = ('female', 'woman', 'her', 'she', 'yo f')
    _MALE_TOKENS = ('male', 'man', 'his', 'he', 'yo m', 'gentleman')

    def __init__(self, file_name, data_dir='./raw_complete'):
        tree = ET.parse(os.path.join(data_dir, file_name))
        root = tree.getroot()
        # Index the element directly instead of touching the private
        # ``_children`` list, which Python 3's Element does not expose.
        info = root[1]
        # Free text of the record; must be set before get_patient_info(),
        # which calls determine_sex().
        self.text = root[0].text

        # dict: first value seen per PHI TYPE, plus 'SMOKER' and 'SEX'.
        self.patient_info = self.get_patient_info(info)
        # dict: ids sliced out of the file name.
        self.file_info = {'patient_id': file_name[:3], 'encounter_id': file_name[4:6]}

        med_list = list()
        diagnosis_list = list()
        for each in info:
            if each.tag == 'MEDICATION':
                med_list.append(self.get_sub_tag_info(each))
            elif each.tag not in self._NON_DIAGNOSIS_TAGS:
                # Everything that is not a known non-diagnosis tag is
                # treated as a diagnosis.
                diagnosis_list.append(self.get_sub_tag_info(each))
        # list of rows (see get_sub_tag_info) for MEDICATION tags.
        self.meds = med_list
        # list of rows (see get_sub_tag_info) for diagnosis tags.
        self.diagnosis = diagnosis_list

    @staticmethod
    def _open_for_append(path):
        """Open ``path`` for appending in the mode ``csv.writer`` expects."""
        if sys.version_info[0] >= 3:
            # Text mode with newline='' stops the platform from translating
            # the writer's own '\r\n' terminators a second time.
            return open(path, 'a', newline='')
        # Python 2's csv module writes bytes and wants a binary-mode file.
        return open(path, 'ab')

    def clean_attribute(self, a_tag, a_key):
        """Return attribute ``a_key`` of ``a_tag`` normalised for CSV output.

        Runs of whitespace collapse to single spaces and commas become
        hyphens. A missing or empty attribute yields ``'NA'``.
        """
        word = a_tag.attrib.get(a_key, 'NA')
        word = ' '.join(word.split())  # removes excessive whitespace
        word = word.replace(',', '-')  # remove commas to prepare for csv
        return word if word else 'NA'

    def get_sub_tag_info(self, a_tag):
        """Flatten one annotation tag into a list of six strings.

        Returns ``[tag name, id, time, type1, type2, child text]`` where the
        last item joins the distinct child descriptions with ``'; '`` in
        first-seen order. A child is described by its ``text`` attribute,
        prefixed with ``'<indicator>: '`` only when it has an ``indicator``
        attribute.
        """
        my_list = [a_tag.tag]
        my_list.extend(self.clean_attribute(a_tag, key)
                       for key in self._ROW_ATTRIBUTES)

        child_list = list()
        for child in a_tag:
            text = self.clean_attribute(child, 'text')
            # clean_attribute never raises KeyError, so test for the
            # attribute explicitly rather than relying on an exception.
            if 'indicator' in child.attrib:
                text = self.clean_attribute(child, 'indicator') + ': ' + text
            # De-duplicate while keeping document order so that the output
            # is identical from run to run.
            if text not in child_list:
                child_list.append(text)

        my_list.append("; ".join(child_list))
        return my_list

    def get_patient_info(self, a_tag):
        """Collect patient-level facts from the annotation element.

        Returns a dict holding the ``text`` of the first PHI tag seen for
        each ``TYPE``, the ``status`` of the first SMOKER tag under the key
        ``'SMOKER'``, and the result of determine_sex() under ``'SEX'``.

        Raises
        ------
        KeyError
            If a PHI tag lacks ``TYPE``/``text`` or a SMOKER tag lacks
            ``status``.
        """
        info_dict = {}
        for each in a_tag:
            if each.tag == 'PHI':
                phi_type = each.attrib['TYPE']
                if phi_type not in info_dict:
                    info_dict[phi_type] = each.attrib['text']
            elif each.tag == 'SMOKER':
                if 'SMOKER' not in info_dict:
                    info_dict['SMOKER'] = each.attrib['status']
        info_dict['SEX'] = self.determine_sex()
        return info_dict

    def determine_sex(self):
        """Guess the patient's sex from gendered tokens in the record text.

        Counts occurrences of each token surrounded by single spaces in the
        lower-cased text and returns ``'female'``, ``'male'``, or ``'NA'``
        on a tie (including when there is no text at all).
        """
        # An empty text element parses to None; treat it as no evidence.
        text = (self.text or '').strip().lower()

        def occurrences(tokens):
            return sum(text.count(" " + token + " ") for token in tokens)

        female_count = occurrences(self._FEMALE_TOKENS)
        male_count = occurrences(self._MALE_TOKENS)

        if female_count > male_count:
            return 'female'
        elif male_count > female_count:
            return 'male'
        else:
            return 'NA'

    def write(self, what, where):
        """Append this record's rows to the CSV file ``where``.

        Parameters
        ----------
        what : str
            ``'meds'`` or ``'diagnosis'``: which row list to write.
        where : str
            Path of the CSV file; it is opened in append mode.

        Every row starts with patient id, encounter id and the patient's
        DATE, SEX, AGE and SMOKER values (``'NA'`` when unknown).

        Raises
        ------
        ValueError
            If ``what`` is neither ``'meds'`` nor ``'diagnosis'``.
        """
        if what == 'meds':
            the_list = self.meds
        elif what == 'diagnosis':
            the_list = self.diagnosis
        else:
            raise ValueError('"what" must be "meds" or "diagnosis"')

        leading_info = [self.file_info['patient_id'], self.file_info['encounter_id']]
        for each in ('DATE', 'SEX', 'AGE', 'SMOKER'):
            leading_info.append(self.patient_info.get(each, 'NA'))

        with self._open_for_append(where) as csvfile:
            spamwriter = csv.writer(csvfile, delimiter=',',
                                    quotechar='|', quoting=csv.QUOTE_MINIMAL)
            for each in the_list:
                spamwriter.writerow(leading_info + each)


In [3]:
def restart_file(file_name):
    """Create (or truncate) ``file_name`` and write the CSV header row.

    The header names the twelve columns of every output row: six
    patient/encounter columns followed by six per-tag columns. The file is
    written with ``,`` as delimiter and ``|`` as quote character.

    Parameters
    ----------
    file_name : str
        Path of the CSV file to reset. Any existing content is discarded.
    """
    header = ['patient_id', 'encounter_id', 'encounter_date', 'patient_sex', 'patient_age',
              'smoker_status', 'tag_name', 'tag_id', 'tag_time', 'tag_type1',
              'tag_type2', 'tag_child_text']

    # csv.writer emits str: on Python 3 that requires a text-mode file
    # (opened with newline='' so '\r\n' is not translated again), while
    # Python 2's csv module expects a binary-mode file.
    if sys.version_info[0] >= 3:
        csvfile = open(file_name, 'w', newline='')
    else:
        csvfile = open(file_name, 'wb')

    with csvfile:
        spamwriter = csv.writer(csvfile, delimiter=',',
                                quotechar='|', quoting=csv.QUOTE_MINIMAL)
        spamwriter.writerow(header)

In [4]:
# Collect the XML records to process. os.listdir() returns names in
# arbitrary order, so sort them to keep the row order of the generated
# CSV files the same on every run.
file_name_list = sorted(
    each_file for each_file in os.listdir("./raw_complete")
    if each_file.endswith(".xml")
)

In [5]:
# Reset both output files so each starts with only the header row.
for each_csv in ('meds.csv', 'diagnosis.csv'):
    restart_file(each_csv)

# Parse every record once and append its rows to the matching output file.
for file_name in file_name_list:
    next_file = MyFile(file_name)
    for what, where in (('meds', 'meds.csv'), ('diagnosis', 'diagnosis.csv')):
        next_file.write(what, where)