In [1]:
from xml.etree import ElementTree as ET

def parse_clinical_results(root):
    """
    Parse the clinical results section from XML and return structured data.
    """
    clinical_results = {}
    
    # Find clinical_results section
    results_section = root.find('clinical_results')
    if results_section is None:
        return clinical_results
    
    # Parse participant flow
    participant_flow = {}
    flow_section = results_section.find('participant_flow')
    if flow_section is not None:
        # Parse groups
        groups = []
        group_list = flow_section.find('group_list')
        if group_list is not None:
            for group in group_list.findall('group'):
                group_data = {
                    'group_id': group.get('group_id', ''),
                    'title': group.find('title').text if group.find('title') is not None else '',
                    'description': group.find('description').text if group.find('description') is not None else ''
                }
                groups.append(group_data)
        
        # Parse periods and milestones
        periods = []
        period_list = flow_section.find('period_list')
        if period_list is not None:
            for period in period_list.findall('period'):
                period_data = {
                    'title': period.find('title').text if period.find('title') is not None else '',
                    'milestones': []
                }
                
                milestone_list = period.find('milestone_list')
                if milestone_list is not None:
                    for milestone in milestone_list.findall('milestone'):
                        milestone_data = {
                            'title': milestone.find('title').text if milestone.find('title') is not None else '',
                            'participants': []
                        }
                        
                        participants_list = milestone.find('participants_list')
                        if participants_list is not None:
                            for participants in participants_list.findall('participants'):
                                participant_data = {
                                    'group_id': participants.get('group_id', ''),
                                    'count': participants.get('count', '')
                                }
                                milestone_data['participants'].append(participant_data)
                        
                        period_data['milestones'].append(milestone_data)
                
                # Parse drop/withdraw reasons
                drop_reasons = []
                drop_list = period.find('drop_withdraw_reason_list')
                if drop_list is not None:
                    for reason in drop_list.findall('drop_withdraw_reason'):
                        reason_data = {
                            'title': reason.find('title').text if reason.find('title') is not None else '',
                            'participants': []
                        }
                        
                        participants_list = reason.find('participants_list')
                        if participants_list is not None:
                            for participants in participants_list.findall('participants'):
                                participant_data = {
                                    'group_id': participants.get('group_id', ''),
                                    'count': participants.get('count', '')
                                }
                                reason_data['participants'].append(participant_data)
                        
                        drop_reasons.append(reason_data)
                
                period_data['drop_withdraw_reasons'] = drop_reasons
                periods.append(period_data)
        
        participant_flow = {
            'groups': groups,
            'periods': periods
        }
    
    # Parse baseline characteristics
    baseline = {}
    baseline_section = results_section.find('baseline')
    if baseline_section is not None:
        # Parse baseline groups
        baseline_groups = []
        group_list = baseline_section.find('group_list')
        if group_list is not None:
            for group in group_list.findall('group'):
                group_data = {
                    'group_id': group.get('group_id', ''),
                    'title': group.find('title').text if group.find('title') is not None else '',
                    'description': group.find('description').text if group.find('description') is not None else ''
                }
                baseline_groups.append(group_data)
        
        # Parse analyzed participants
        analyzed_list = []
        analyzed_section = baseline_section.find('analyzed_list')
        if analyzed_section is not None:
            for analyzed in analyzed_section.findall('analyzed'):
                analyzed_data = {
                    'units': analyzed.find('units').text if analyzed.find('units') is not None else '',
                    'scope': analyzed.find('scope').text if analyzed.find('scope') is not None else '',
                    'counts': []
                }
                
                count_list = analyzed.find('count_list')
                if count_list is not None:
                    for count in count_list.findall('count'):
                        count_data = {
                            'group_id': count.get('group_id', ''),
                            'value': count.get('value', '')
                        }
                        analyzed_data['counts'].append(count_data)
                
                analyzed_list.append(analyzed_data)
        
        # Parse baseline measures
        measures = []
        measure_list = baseline_section.find('measure_list')
        if measure_list is not None:
            for measure in measure_list.findall('measure'):
                measure_data = {
                    'title': measure.find('title').text if measure.find('title') is not None else '',
                    'description': measure.find('description').text if measure.find('description') is not None else '',
                    'units': measure.find('units').text if measure.find('units') is not None else '',
                    'param': measure.find('param').text if measure.find('param') is not None else '',
                    'classes': []
                }
                
                class_list = measure.find('class_list')
                if class_list is not None:
                    for class_elem in class_list.findall('class'):
                        class_data = {
                            'title': class_elem.find('title').text if class_elem.find('title') is not None else '',
                            'categories': []
                        }
                        
                        category_list = class_elem.find('category_list')
                        if category_list is not None:
                            for category in category_list.findall('category'):
                                category_data = {
                                    'title': category.find('title').text if category.find('title') is not None else '',
                                    'measurements': []
                                }
                                
                                measurement_list = category.find('measurement_list')
                                if measurement_list is not None:
                                    for measurement in measurement_list.findall('measurement'):
                                        measurement_data = {
                                            'group_id': measurement.get('group_id', ''),
                                            'value': measurement.get('value', ''),
                                            'spread': measurement.get('spread', ''),
                                            'lower_limit': measurement.get('lower_limit', ''),
                                            'upper_limit': measurement.get('upper_limit', '')
                                        }
                                        category_data['measurements'].append(measurement_data)
                                
                                class_data['categories'].append(category_data)
                        
                        measure_data['classes'].append(class_data)
                
                measures.append(measure_data)
        
        baseline = {
            'groups': baseline_groups,
            'analyzed_list': analyzed_list,
            'measures': measures
        }
    
    # Parse outcomes
    outcomes = []
    outcome_list = results_section.find('outcome_list')
    if outcome_list is not None:
        for outcome in outcome_list.findall('outcome'):
            outcome_data = {
                'type': outcome.find('type').text if outcome.find('type') is not None else '',
                'title': outcome.find('title').text if outcome.find('title') is not None else '',
                'description': outcome.find('description').text if outcome.find('description') is not None else '',
                'time_frame': outcome.find('time_frame').text if outcome.find('time_frame') is not None else '',
                'population': outcome.find('population').text if outcome.find('population') is not None else '',
                'groups': [],
                'measures': []
            }
            
            # Parse outcome groups
            group_list = outcome.find('group_list')
            if group_list is not None:
                for group in group_list.findall('group'):
                    group_data = {
                        'group_id': group.get('group_id', ''),
                        'title': group.find('title').text if group.find('title') is not None else '',
                        'description': group.find('description').text if group.find('description') is not None else ''
                    }
                    outcome_data['groups'].append(group_data)
            
            # Parse outcome measures
            measure_elem = outcome.find('measure')
            if measure_elem is not None:
                measure_data = {
                    'title': measure_elem.find('title').text if measure_elem.find('title') is not None else '',
                    'description': measure_elem.find('description').text if measure_elem.find('description') is not None else '',
                    'population': measure_elem.find('population').text if measure_elem.find('population') is not None else '',
                    'units': measure_elem.find('units').text if measure_elem.find('units') is not None else '',
                    'param': measure_elem.find('param').text if measure_elem.find('param') is not None else '',
                    'dispersion': measure_elem.find('dispersion').text if measure_elem.find('dispersion') is not None else '',
                    'analyzed_list': [],
                    'classes': []
                }
                
                # Parse analyzed participants for this measure
                analyzed_list = measure_elem.find('analyzed_list')
                if analyzed_list is not None:
                    for analyzed in analyzed_list.findall('analyzed'):
                        analyzed_data = {
                            'units': analyzed.find('units').text if analyzed.find('units') is not None else '',
                            'scope': analyzed.find('scope').text if analyzed.find('scope') is not None else '',
                            'counts': []
                        }
                        
                        count_list = analyzed.find('count_list')
                        if count_list is not None:
                            for count in count_list.findall('count'):
                                count_data = {
                                    'group_id': count.get('group_id', ''),
                                    'value': count.get('value', '')
                                }
                                analyzed_data['counts'].append(count_data)
                        
                        measure_data['analyzed_list'].append(analyzed_data)
                
                # Parse measure classes
                class_list = measure_elem.find('class_list')
                if class_list is not None:
                    for class_elem in class_list.findall('class'):
                        class_data = {
                            'title': class_elem.find('title').text if class_elem.find('title') is not None else '',
                            'categories': []
                        }
                        
                        category_list = class_elem.find('category_list')
                        if category_list is not None:
                            for category in category_list.findall('category'):
                                category_data = {
                                    'title': category.find('title').text if category.find('title') is not None else '',
                                    'measurements': []
                                }
                                
                                measurement_list = category.find('measurement_list')
                                if measurement_list is not None:
                                    for measurement in measurement_list.findall('measurement'):
                                        measurement_data = {
                                            'group_id': measurement.get('group_id', ''),
                                            'value': measurement.get('value', ''),
                                            'spread': measurement.get('spread', ''),
                                            'lower_limit': measurement.get('lower_limit', ''),
                                            'upper_limit': measurement.get('upper_limit', '')
                                        }
                                        category_data['measurements'].append(measurement_data)
                                
                                class_data['categories'].append(category_data)
                        
                        measure_data['classes'].append(class_data)
                
                outcome_data['measures'].append(measure_data)
            
            outcomes.append(outcome_data)
    
    # Parse reported events (adverse events)
    reported_events = {}
    events_section = results_section.find('reported_events')
    if events_section is not None:
        reported_events = {
            'time_frame': events_section.find('time_frame').text if events_section.find('time_frame') is not None else '',
            'desc': events_section.find('desc').text if events_section.find('desc') is not None else '',
            'groups': [],
            'serious_events': [],
            'other_events': []
        }
        
        # Parse event groups
        group_list = events_section.find('group_list')
        if group_list is not None:
            for group in group_list.findall('group'):
                group_data = {
                    'group_id': group.get('group_id', ''),
                    'title': group.find('title').text if group.find('title') is not None else '',
                    'description': group.find('description').text if group.find('description') is not None else ''
                }
                reported_events['groups'].append(group_data)
        
        # Parse serious events
        serious_events = events_section.find('serious_events')
        if serious_events is not None:
            serious_data = {
                'default_vocab': serious_events.get('default_vocab', ''),
                'default_assessment': serious_events.get('default_assessment', ''),
                'categories': []
            }
            
            category_list = serious_events.find('category_list')
            if category_list is not None:
                for category in category_list.findall('category'):
                    category_data = {
                        'title': category.find('title').text if category.find('title') is not None else '',
                        'events': []
                    }
                    
                    event_list = category.find('event_list')
                    if event_list is not None:
                        for event in event_list.findall('event'):
                            event_data = {
                                'sub_title': event.find('sub_title').text if event.find('sub_title') is not None else '',
                                'assessment': event.get('assessment', ''),
                                'counts': []
                            }
                            
                            counts_elem = event.find('counts')
                            if counts_elem is not None:
                                for count_attr in ['subjects_affected', 'subjects_at_risk', 'events']:
                                    if counts_elem.get(count_attr):
                                        count_data = {
                                            'group_id': counts_elem.get('group_id', ''),
                                            'type': count_attr,
                                            'value': counts_elem.get(count_attr, '')
                                        }
                                        event_data['counts'].append(count_data)
                            
                            category_data['events'].append(event_data)
                    
                    serious_data['categories'].append(category_data)
            
            reported_events['serious_events'] = serious_data
        
        # Parse other events (similar structure to serious events)
        other_events = events_section.find('other_events')
        if other_events is not None:
            other_data = {
                'frequency_threshold': other_events.get('frequency_threshold', ''),
                'default_vocab': other_events.get('default_vocab', ''),
                'default_assessment': other_events.get('default_assessment', ''),
                'categories': []
            }
            
            category_list = other_events.find('category_list')
            if category_list is not None:
                for category in category_list.findall('category'):
                    category_data = {
                        'title': category.find('title').text if category.find('title') is not None else '',
                        'events': []
                    }
                    
                    event_list = category.find('event_list')
                    if event_list is not None:
                        for event in event_list.findall('event'):
                            event_data = {
                                'sub_title': event.find('sub_title').text if event.find('sub_title') is not None else '',
                                'assessment': event.get('assessment', ''),
                                'counts': []
                            }
                            
                            counts_elem = event.find('counts')
                            if counts_elem is not None:
                                for count_attr in ['subjects_affected', 'subjects_at_risk', 'events']:
                                    if counts_elem.get(count_attr):
                                        count_data = {
                                            'group_id': counts_elem.get('group_id', ''),
                                            'type': count_attr,
                                            'value': counts_elem.get(count_attr, '')
                                        }
                                        event_data['counts'].append(count_data)
                            
                            category_data['events'].append(event_data)
                    
                    other_data['categories'].append(category_data)
            
            reported_events['other_events'] = other_data
    
    # Parse agreements and point of contact
    agreements = {}
    agreements_section = results_section.find('certain_agreements')
    if agreements_section is not None:
        agreements = {
            'pi_employee': agreements_section.find('pi_employee').text if agreements_section.find('pi_employee') is not None else '',
            'restrictive_agreement': agreements_section.find('restrictive_agreement').text if agreements_section.find('restrictive_agreement') is not None else ''
        }
    
    point_of_contact = {}
    contact_section = results_section.find('point_of_contact')
    if contact_section is not None:
        point_of_contact = {
            'name_or_title': contact_section.find('name_or_title').text if contact_section.find('name_or_title') is not None else '',
            'organization': contact_section.find('organization').text if contact_section.find('organization') is not None else '',
            'phone': contact_section.find('phone').text if contact_section.find('phone') is not None else '',
            'email': contact_section.find('email').text if contact_section.find('email') is not None else ''
        }
    
    # Assemble clinical results
    clinical_results = {
        'participant_flow': participant_flow,
        'baseline': baseline,
        'outcomes': outcomes,
        'reported_events': reported_events,
        'certain_agreements': agreements,
        'point_of_contact': point_of_contact
    }
    
    return clinical_results


def xmlfile2results(xml_file):
    """
    Parse clinical trial XML file and return a dictionary with extracted data.
    """
    tree = ET.parse(xml_file)
    root = tree.getroot()
    
    # Basic study identifiers
    nctid = root.find('id_info/nct_id').text if root.find('id_info/nct_id') is not None else ''
    org_study_id = root.find('id_info/org_study_id').text if root.find('id_info/org_study_id') is not None else ''
    url = root.find('required_header/url').text if root.find('required_header/url') is not None else ''
    
    # Titles - handle both brief_title and official_title
    brief_title = root.find('brief_title').text if root.find('brief_title') is not None else ''
    official_title = root.find('official_title').text if root.find('official_title') is not None else ''
    
    # Sponsors and collaborators
    lead_sponsor = ''
    collaborators = []
    sponsors = root.find('sponsors')
    if sponsors is not None:
        lead_sponsor_elem = sponsors.find('lead_sponsor/agency')
        if lead_sponsor_elem is not None:
            lead_sponsor = lead_sponsor_elem.text
        
        for collab in sponsors.findall('collaborator/agency'):
            if collab is not None:
                collaborators.append(collab.text)
    
    # Study descriptions
    brief_summary = ''
    brief_summary_elem = root.find('brief_summary/textblock')
    if brief_summary_elem is not None:
        brief_summary = brief_summary_elem.text.strip() if brief_summary_elem.text else ''
    
    detailed_description = ''
    detailed_description_elem = root.find('detailed_description/textblock')
    if detailed_description_elem is not None:
        detailed_description = detailed_description_elem.text.strip() if detailed_description_elem.text else ''
    
    # Study type and phase
    study_type = root.find('study_type').text if root.find('study_type') is not None else ''
    phase = root.find('phase').text if root.find('phase') is not None else ''
    
    # Status and dates
    overall_status = root.find('overall_status').text if root.find('overall_status') is not None else ''
    why_stopped = root.find('why_stopped').text if root.find('why_stopped') is not None else ''

    # create an inferred label from the overall_status and why_stopped
    
    
    # Handle dates with potential type attributes
    start_date = ''
    start_date_type = ''
    start_date_elem = root.find('start_date')
    if start_date_elem is not None:
        start_date = start_date_elem.text if start_date_elem.text else ''
        start_date_type = start_date_elem.get('type', '')
    
    completion_date = ''
    completion_date_type = ''
    # Check both completion_date and primary_completion_date
    completion_date_elem = root.find('completion_date')
    if completion_date_elem is None:
        completion_date_elem = root.find('primary_completion_date')
    if completion_date_elem is not None:
        completion_date = completion_date_elem.text if completion_date_elem.text else ''
        completion_date_type = completion_date_elem.get('type', '')
    
    study_first_posted = ''
    study_first_posted_elem = root.find('study_first_posted')
    if study_first_posted_elem is not None:
        study_first_posted = study_first_posted_elem.text if study_first_posted_elem.text else ''
    
    # Interventions - handle all types, not just drugs
    interventions = []
    for intervention in root.findall('intervention'):
        intervention_data = {}
        intervention_type_elem = intervention.find('intervention_type')
        intervention_name_elem = intervention.find('intervention_name')
        intervention_desc_elem = intervention.find('description')
        
        if intervention_type_elem is not None:
            intervention_data['type'] = intervention_type_elem.text
        if intervention_name_elem is not None:
            intervention_data['name'] = intervention_name_elem.text
        if intervention_desc_elem is not None:
            intervention_data['description'] = intervention_desc_elem.text

    # add intervention Mesh terms
    # Get only intervention mesh terms
    intervention_section = root.find('intervention_browse')
    if intervention_section is not None:
        intervention_mesh_terms = [term.text.strip() for term in intervention_section.findall('mesh_term')]
    else:
        intervention_mesh_terms = []
    
    # Extract drug interventions separately for backward compatibility
    drug_interventions = [i['name'] for i in interventions if i.get('type') == 'Drug' and 'name' in i]
    
    # Arm groups
    arm_groups = []
    for ag in root.findall('arm_group'):
        arm_data = {}
        label_elem = ag.find('arm_group_label')
        desc_elem = ag.find('description')
        type_elem = ag.find('arm_group_type')
        
        if label_elem is not None:
            arm_data['label'] = label_elem.text
        if desc_elem is not None:
            arm_data['description'] = desc_elem.text
        if type_elem is not None:
            arm_data['type'] = type_elem.text
        
        if arm_data:
            arm_groups.append(arm_data)
    
    # Study design information
    study_design_info = {}
    sdi = root.find('study_design_info')
    if sdi is not None:
        for child in sdi:
            if child.text:
                study_design_info[child.tag] = child.text

    # Clinical Results
    
    # Primary outcome
    primary_outcomes = []
    for po in root.findall('primary_outcome'):
        outcome_data = {}
        measure_elem = po.find('measure')
        time_frame_elem = po.find('time_frame')
        description_elem = po.find('description')
        
        if measure_elem is not None:
            outcome_data['measure'] = measure_elem.text
        if time_frame_elem is not None:
            outcome_data['time_frame'] = time_frame_elem.text
        if description_elem is not None:
            outcome_data['description'] = description_elem.text
        
        if outcome_data:
            primary_outcomes.append(outcome_data)
    
    # Secondary outcomes
    secondary_outcomes = []
    for so in root.findall('secondary_outcome'):
        outcome_data = {}
        measure_elem = so.find('measure')
        time_frame_elem = so.find('time_frame')
        description_elem = so.find('description')
        
        if measure_elem is not None:
            outcome_data['measure'] = measure_elem.text
        if time_frame_elem is not None:
            outcome_data['time_frame'] = time_frame_elem.text
        if description_elem is not None:
            outcome_data['description'] = description_elem.text
        
        if outcome_data:
            secondary_outcomes.append(outcome_data)
    
    # Conditions/indications
    conditions = [condition.text.strip() for condition in root.findall('condition')]
    
    # MeSH terms for conditions
    conditions_mesh_terms = []
    conditions_section = root.find('condition_browse')
    if conditions_section is not None:
        conditions_mesh_terms = [term.text.strip() for term in conditions_section.findall('mesh_term')]
    
    # Enrollment
    enrollment = ''
    enrollment_type = ''
    enrollment_elem = root.find('enrollment')
    if enrollment_elem is not None:
        enrollment = enrollment_elem.text if enrollment_elem.text else ''
        enrollment_type = enrollment_elem.get('type', '')
    
    # Eligibility criteria
    criteria = ''
    criteria_elem = root.find('eligibility/criteria/textblock')
    if criteria_elem is not None:
        criteria = criteria_elem.text.strip() if criteria_elem.text else ''
    
    # Gender, age constraints
    gender = ''
    minimum_age = ''
    maximum_age = ''
    healthy_volunteers = ''
    
    eligibility = root.find('eligibility')
    if eligibility is not None:
        gender_elem = eligibility.find('gender')
        min_age_elem = eligibility.find('minimum_age')
        max_age_elem = eligibility.find('maximum_age')
        healthy_elem = eligibility.find('healthy_volunteers')
        
        if gender_elem is not None:
            gender = gender_elem.text
        if min_age_elem is not None:
            minimum_age = min_age_elem.text
        if max_age_elem is not None:
            maximum_age = max_age_elem.text
        if healthy_elem is not None:
            healthy_volunteers = healthy_elem.text
    
    # Number of groups
    number_of_groups = root.find('number_of_groups').text if root.find('number_of_groups') is not None else ''
    
    # Locations
    locations = []
    for loc in root.findall('location'):
        location_data = {}
        facility = loc.find('facility')
        if facility is not None:
            name_elem = facility.find('name')
            if name_elem is not None:
                location_data['facility_name'] = name_elem.text
            
            address = facility.find('address')
            if address is not None:
                city_elem = address.find('city')
                state_elem = address.find('state')
                zip_elem = address.find('zip')
                country_elem = address.find('country')
                
                if city_elem is not None:
                    location_data['city'] = city_elem.text
                if state_elem is not None:
                    location_data['state'] = state_elem.text
                if zip_elem is not None:
                    location_data['zip'] = zip_elem.text
                if country_elem is not None:
                    location_data['country'] = country_elem.text
        
        # Location status
        status_elem = loc.find('status')
        if status_elem is not None:
            location_data['status'] = status_elem.text
        
        # Contact information
        contact = loc.find('contact')
        if contact is not None:
            contact_name = contact.find('last_name')
            contact_email = contact.find('email')
            contact_phone = contact.find('phone')
            
            if contact_name is not None:
                location_data['contact_name'] = contact_name.text
            if contact_email is not None:
                location_data['contact_email'] = contact_email.text
            if contact_phone is not None:
                location_data['contact_phone'] = contact_phone.text
        
        if location_data:
            locations.append(location_data)
    
    # Overall contacts
    overall_contact = {}
    contact = root.find('overall_contact')
    if contact is not None:
        name_elem = contact.find('last_name')
        email_elem = contact.find('email')
        phone_elem = contact.find('phone')
        
        if name_elem is not None:
            overall_contact['name'] = name_elem.text
        if email_elem is not None:
            overall_contact['email'] = email_elem.text
        if phone_elem is not None:
            overall_contact['phone'] = phone_elem.text
    
    # Oversight info
    oversight_info = {}
    oversight = root.find('oversight_info')
    if oversight is not None:
        for child in oversight:
            if child.text:
                oversight_info[child.tag] = child.text
    
    # Keywords
    keywords = []
    for keyword in root.findall('keyword'):
        if keyword.text:
            keywords.append(keyword.text)

    # Clinical Results
    clinical_results = parse_clinical_results(root)
    
    # Assemble the complete data dictionary
    data = {
        'nctid': nctid,
        'org_study_id': org_study_id,
        'url': url,
        'brief_title': brief_title,
        'official_title': official_title,
        'lead_sponsor': lead_sponsor,
        'collaborators': collaborators,
        'brief_summary': brief_summary,
        'detailed_description': detailed_description,
        'study_type': study_type,
        'phase': phase,
        'overall_status': overall_status,
        'why_stopped': why_stopped,
        'start_date': start_date,
        'start_date_type': start_date_type,
        'completion_date': completion_date,
        'completion_date_type': completion_date_type,
        'study_first_posted': study_first_posted,
        'interventions': interventions,
        'intervention_mesh_terms': intervention_mesh_terms,
        'drug_interventions': drug_interventions,  # for backward compatibility
        'arm_groups': arm_groups,
        'study_design_info': study_design_info,
        'primary_outcomes': primary_outcomes,
        'secondary_outcomes': secondary_outcomes,
        'conditions': conditions,
        'conditions_mesh_terms': conditions_mesh_terms,
        'enrollment': enrollment,
        'enrollment_type': enrollment_type,
        'criteria': criteria,
        'gender': gender,
        'minimum_age': minimum_age,
        'maximum_age': maximum_age,
        'healthy_volunteers': healthy_volunteers,
        'number_of_groups': number_of_groups,
        'locations': locations,
        'overall_contact': overall_contact,
        'oversight_info': oversight_info,
        'keywords': keywords,
        'clinical_results': clinical_results
    }
    
    return data


In [6]:
import os
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
import pickle

random_seed = 42
chunk_size = 100
feat_cats = ['brief_summary', 'detailed_description']
pdir = "data_processed/"
os.makedirs(pdir, exist_ok=True)

# Lists for final labels
overall_statuses = []
clinical_results = []
phases = []

# Temporary list for text features
X = []

# Loop through all folders
folders = os.listdir('raw_data/')
for idx, folder in enumerate(folders[:200]):
    base = os.path.join('raw_data', folder)
    if not os.path.isdir(base):
        continue
    
    for file in os.listdir(base):
        xml_path = os.path.join(base, file)
        try:
            data = xmlfile2results(xml_path)
        except Exception as e:
            print(f"Skipping {xml_path}: {e}")
            continue
        
        # Skip entries with missing phase
        phase = data.get('phase', '').strip()
        if not phase or phase.upper() == 'N/A':
            continue
        
        # Collect text features safely
        textfeats = ""
        for feat in feat_cats:
            textfeats += data.get(feat, '')
        textfeats += "".join(data.get('interventions', []))
        for po in data.get('primary_outcomes', []):
            textfeats += po.get('measure', '')
            textfeats += "".join(po.get('description', []))
        
        X.append(textfeats)
        overall_statuses.append(data.get('overall_status'))
        clinical_results.append(data.get('clinical_results'))
        phases.append(phase)
    
    # Save chunk if reached chunk size
    # if (idx + 1) % chunk_size == 0:
    #     np.save(os.path.join(pdir, f"X_{idx+1}.npy"), np.array(X_chunk,dtype=object), allow_pickle=True)
    #     X_chunk.clear()

# Save any remaining data
# if (idx + 1) % chunk_size == 0:
#     arr = np.asarray(X_chunk, dtype=object)   # <-- FORCE object array
#     np.save(os.path.join(pdir, f"X_{idx+1}.npy"), arr, allow_pickle=True)
#     X_chunk.clear()

# # Save any remaining data
# if X_chunk:
#     arr = np.asarray(X_chunk, dtype=object)   # <-- FORCE object array
#     np.save(os.path.join(pdir, "X_final.npy"), arr, allow_pickle=True)
#     X_chunk.clear()


# Create binary target y
y = np.array([
    1 if (str(overall_statuses[i]).lower() == 'completed' 
          and clinical_results[i] is not None
          and phases[i].lower() in ['phase 2/phase 3','phase 3','phase 4'])
    else 0
    for i in range(len(overall_statuses))
])

# Save labels
# np.save(os.path.join(pdir, "y.npy"), y)
# np.save(os.path.join(pdir, "phases.npy"), np.array(phases), allow_pickle=True)




In [9]:
# def load_and_concatenate_chunks(folder):
#     import glob
#     files = sorted(glob.glob(os.path.join(folder, "X_*.npy")))
#     arrays = [np.load(f, allow_pickle=True) for f in files]
#     return np.concatenate(arrays, axis=0)
# # train/test split   
# X_all = load_and_concatenate_chunks(pdir)  # see next step
# y = npload()
# print(len(X_all), len(y))
X_train, X_test, y_train, y_test = train_test_split(
    X, y, stratify=y, random_state=random_seed, shuffle=True, test_size=0.1
)


In [None]:
'''
TF-IDF: Total Frequency - Inverse Document Frequency
    Extracts Important words in documents, higher scores: greater relevance

Official Title
Brief Summary
https://www.geeksforgeeks.org/machine-learning/understanding-tf-idf-term-frequency-inverse-document-frequency/

'''
tfidf = TfidfVectorizer()
result=tfidf.fit_transform(X_train)
# print('result: ', result)
# print('\nidf values:')
# for ele1, ele2 in zip(tfidf.get_feature_names_out(), tfidf.idf_):
#     print(ele1, ':', ele2)

tfidf_means = np.asarray(result.mean(axis=0)).ravel()

K = 5000    # choose your number of features
top_idx = np.argsort(tfidf_means)[-K:]   # indices of top K tf-idf features
selected_terms = [tfidf.get_feature_names_out()[i] for i in top_idx]
vectorizer_reduced = TfidfVectorizer(vocabulary=selected_terms)
X_train_reduced = vectorizer_reduced.fit_transform(X_train)
# X_reduced
model = LogisticRegression(max_iter=5000)
model.fit(X_train_reduced, y_train)
X_test_reduced = vectorizer_reduced.transform(X_test)
y_preds = model.predict(X_test_reduced)
from sklearn.metrics import accuracy_score,precision_score, recall_score, f1_score
print(accuracy_score(y_test,y_preds), precision_score(y_test,y_preds), recall_score(y_test,y_preds), f1_score(y_test,y_preds))
print(y_preds)
print(y_test)



idf values:
00 : 6.283912222070448
000 : 4.681289936171408
0000 : 8.91920306889023
0000000000001848 : 11.170494867496725
000001 : 11.170494867496725
0000010028 : 11.170494867496725
000002 : 11.170494867496725
00001 : 10.071882578828616
000029 : 11.170494867496725
000036 : 11.170494867496725
0000364 : 11.170494867496725
00007 : 11.170494867496725
000089 : 11.170494867496725
0001 : 8.125972429773302
000102 : 10.765029759388561
000184 : 10.765029759388561
0001st : 11.170494867496725
0002 : 9.666417470720452
00022 : 11.170494867496725
00023 : 11.170494867496725
0003 : 10.25420413562257
000306 : 11.170494867496725
0003454 : 11.170494867496725
0004 : 10.071882578828616
000434 : 11.170494867496725
00045 : 11.170494867496725
00048923 : 11.170494867496725
0005 : 10.47734768693678
0005004439 : 10.765029759388561
00052 : 11.170494867496725
000596 : 11.170494867496725
0005a : 11.170494867496725
0005b : 11.170494867496725
0005c : 11.170494867496725
0006 : 9.784200506376834
0007 : 11.17049486749672

In [None]:
vectorizer = TfidfVectorizer(max_features=None)
X_train_vec = vectorizer.fit_transform(X_train)
clf = LogisticRegression(max_iter=5000)
clf.fit(X_train_vec, y_train)

coefs = clf.coef_.ravel()          # shape: (n_features,)
abs_coefs = np.abs(coefs)
top_idx = np.argsort(abs_coefs)[-K:]
selected_terms = vectorizer.get_feature_names_out()[top_idx]

vectorizer_reduced = TfidfVectorizer(vocabulary=selected_terms)
X_train_reduced = vectorizer_reduced.fit_transform(X_train)
clf_final = LogisticRegression(max_iter=5000)
clf_final.fit(X_train_reduced, y_train)
X_test_reduced = vectorizer_reduced.fit_transform(X_test)

preds = clf_final.predict(X_test_reduced)
print(accuracy_score(y_test,preds), precision_score(y_test,preds), recall_score(y_test,preds), f1_score(y_test,preds))
print(preds)
print(y_test)


0.802184673152415 0.5912208504801097 0.33359133126934987 0.4265215239980208
[0 1 0 ... 0 0 0]
[1 1 0 ... 0 0 0]


In [None]:
from sklearn.decomposition import TruncatedSVD

svd = TruncatedSVD(n_components=300)
X_svd = svd.fit_transform(X_tfidf)


np.int64(87)