From 3edbc6dc0b50d9a584adb48064216cf15766a7d1 Mon Sep 17 00:00:00 2001
From: caseyta
Date: Tue, 28 May 2024 15:56:19 -0700
Subject: [PATCH 1/6] MCQ query backend implemented

---
 cohd/cohd.py             |  59 +++--
 cohd/query_cohd_mysql.py | 532 +++++++++++++++++++++++++++++++--------
 requirements.txt         |   3 +-
 3 files changed, 464 insertions(+), 130 deletions(-)

diff --git a/cohd/cohd.py b/cohd/cohd.py
index 26d9158..0e04f2a 100644
--- a/cohd/cohd.py
+++ b/cohd/cohd.py
@@ -48,16 +48,44 @@ def api_cohd():
     return redirect("https://cohd.smart-api.info/", code=302)
 
 
+@app.route('/api/metadata/datasets')
+def api_metadata_datasets():
+    return query_cohd_mysql.query_db_datasets()
+
+
+@app.route('/api/metadata/domainCounts')
+def api_metadata_domainCounts():
+    dataset_id = query_cohd_mysql.get_arg_dataset_id(request.args)
+    return query_cohd_mysql.query_db_domain_counts(dataset_id)
+
+
+@app.route('/api/metadata/domainPairCounts')
+def api_metadata_domainPairCounts():
+    dataset_id = query_cohd_mysql.get_arg_dataset_id(request.args)
+    return query_cohd_mysql.query_db_domain_pair_counts(dataset_id)
+
+
+@app.route('/api/metadata/patientCount')
+def api_metadata_patientCount():
+    dataset_id = query_cohd_mysql.get_arg_dataset_id(request.args)
+    return query_cohd_mysql.query_db_patient_count(dataset_id)
+
+
 @app.route('/api/omop/findConceptIDs')
 @app.route('/api/v1/omop/findConceptIDs')
 def api_omop_reference():
-    return api_call('omop', 'findConceptIDs')
+    query = request.args.get('q')
+    dataset_id = query_cohd_mysql.get_arg_dataset_id(request.args)
+    domain_id = request.args.get('domain')
+    min_count = request.args.get('min_count')
+    return query_cohd_mysql.query_db_find_concept_ids(dataset_id, query, domain_id, min_count)
 
 
 @app.route('/api/omop/concepts')
 @app.route('/api/v1/omop/concepts')
 def api_omop_concepts():
-    return api_call('omop', 'concepts')
+    query = request.args.get('q')
+    return query_cohd_mysql.query_db_concepts(query)
 
 
 @app.route('/api/omop/conceptAncestors')
@@ -95,25 +123,6 @@ def api_omop_xrefFromOMOP():
     return api_call('omop', 'xrefFromOMOP')
 
 
-@app.route('/api/metadata/datasets')
-def api_metadata_datasets():
-    return api_call('metadata', 'datasets')
-
-
-@app.route('/api/metadata/domainCounts')
-def api_metadata_domainCounts():
-    return api_call('metadata', 'domainCounts')
-
-
-@app.route('/api/metadata/domainPairCounts')
-def api_metadata_domainPairCounts():
-    return api_call('metadata', 'domainPairCounts')
-
-
-@app.route('/api/metadata/patientCount')
-def api_metadata_patientCount():
-    return api_call('metadata', 'patientCount')
-
 
 @app.route('/api/frequencies/singleConceptFreq')
 @app.route('/api/v1/frequencies/singleConceptFreq')
@@ -160,6 +169,11 @@ def api_association_relativeFrequency():
     return api_call('association', 'relativeFrequency')
 
 
+@app.route('/api/association/mcq')
+def api_association_mcq():
+    return api_call('association', 'mcq')
+
+
 @app.route('/api/temporal/conceptAgeCounts')
 def api_temporal_conceptAgeCounts():
     return api_call('temporal', 'conceptAgeCounts')
@@ -306,7 +320,8 @@ def api_call(service=None, meta=None, query=None, version=None):
     elif service == 'association':
         if meta == 'chiSquare' or \
                 meta == 'obsExpRatio' or \
-                meta == 'relativeFrequency':
+                meta == 'relativeFrequency' or \
+                meta == 'mcq':
             result = query_cohd_mysql.query_db(service, meta, request.args)
         else:
             result = 'meta not recognized', 400
diff --git a/cohd/query_cohd_mysql.py b/cohd/query_cohd_mysql.py
index 7722ed6..1884318 100644
--- a/cohd/query_cohd_mysql.py
+++ b/cohd/query_cohd_mysql.py
@@ -3,6 +3,7 @@
 from scipy.stats import chisquare
 from numpy import argsort
 import logging
+import pandas as pd
 
 from .omop_xref import xref_to_omop_standard_concept, omop_map_to_standard, omop_map_from_standard, \
     xref_from_omop_standard_concept, xref_from_omop_local, xref_to_omop_local
@@ -87,133 +88,167 @@ def get_arg_boolean(args, param_name):
         return None
 
 
-def query_db(service, method, args):
+def query_db_finalize(conn, cursor, json_return):
+    logging.debug(cursor._executed)
+    logging.debug(json_return)
 
-    # Connect to MYSQL database
+    cursor.close()
+    conn.close()
+
+    json_return = {"results": json_return}
+    return jsonify(json_return)
+
+
+def query_db_datasets():
+    # The datasets in the COHD database
     conn = sql_connection()
     cur = conn.cursor()
+    sql = '''SELECT *
+        FROM cohd.dataset;'''
+    cur.execute(sql)
+    json_return = cur.fetchall()
+    return query_db_finalize(conn, cur, json_return)
 
-    json_return = []
-
-    query = args.get('q')
 
-    logging.debug(msg=f"Service: {service}; Method: {method}, Query: {query}")
+def query_db_domain_counts(dataset_id):
+    # The number of concepts in each domain
+    conn = sql_connection()
+    cur = conn.cursor()
+    sql = '''SELECT *
+        FROM cohd.domain_concept_counts
+        WHERE dataset_id=%(dataset_id)s;'''
+    params = {'dataset_id': dataset_id}
+    cur.execute(sql, params)
+    json_return = cur.fetchall()
+    return query_db_finalize(conn, cur, json_return)
 
-    if service == 'metadata':
-        # The datasets in the COHD database
-        # endpoint: /api/v1/query?service=metadata&meta=datasets
-        if method == 'datasets':
-            sql = '''SELECT *
-                FROM cohd.dataset;'''
-            cur.execute(sql)
-            json_return = cur.fetchall()
 
-        # The number of concepts in each domain
-        # endpoint: /api/v1/query?service=metadata&meta=domainCounts&dataset_id=1
-        elif method == 'domainCounts':
-            dataset_id = get_arg_dataset_id(args)
-            sql = '''SELECT *
-                FROM cohd.domain_concept_counts
-                WHERE dataset_id=%(dataset_id)s;'''
-            params = {'dataset_id': dataset_id}
-            cur.execute(sql, params)
-            json_return = cur.fetchall()
+def query_db_domain_pair_counts(dataset_id):
+    # The number of pairs of concepts in each pair of domains
+    conn = sql_connection()
+    cur = conn.cursor()
+    sql = '''SELECT *
+        FROM cohd.domain_pair_concept_counts
+        WHERE dataset_id=%(dataset_id)s;'''
+    params = {'dataset_id': dataset_id}
+    cur.execute(sql, params)
+    json_return = cur.fetchall()
+    return query_db_finalize(conn, cur, json_return)
 
-        # The number of pairs of concepts in each pair of domains
-        # endpoint: /api/v1/query?service=metadata&meta=domainPairCounts&dataset_id=1
-        elif method == 'domainPairCounts':
-            dataset_id = get_arg_dataset_id(args)
-            sql = '''SELECT *
-                FROM cohd.domain_pair_concept_counts
-                WHERE dataset_id=%(dataset_id)s;'''
-            params = {'dataset_id': dataset_id}
-            cur.execute(sql, params)
-            json_return = cur.fetchall()
 
-        # The number of patients in the dataset
-        # endpoint: /api/v1/query?service=metadata&meta=patientCount&dataset_id=1
-        elif method == 'patientCount':
-            dataset_id = get_arg_dataset_id(args)
-            sql = '''SELECT *
-                FROM cohd.patient_count
-                WHERE dataset_id=%(dataset_id)s;'''
-            params = {'dataset_id': dataset_id}
-            cur.execute(sql, params)
-            json_return = cur.fetchall()
+def query_db_patient_count(dataset_id):
+    # The number of patients in the dataset
+    conn = sql_connection()
+    cur = conn.cursor()
+    sql = '''SELECT *
+        FROM cohd.patient_count
+        WHERE dataset_id=%(dataset_id)s;'''
+    params = {'dataset_id': dataset_id}
+    cur.execute(sql, params)
+    json_return = cur.fetchall()
+    return query_db_finalize(conn, cur, json_return)
 
 
-    elif service == 'omop':
-        # Find concept_ids and concept_names that are similar to the query
-        # e.g. /api/v1/query?service=omop&meta=findConceptIDs&q=cancer
-        if method == 'findConceptIDs':
-            # Check query parameter
-            if query is None or query == [''] or query.isspace():
-                return 'q parameter is missing', 400
-            dataset_id = get_arg_dataset_id(args)
+def query_db_find_concept_ids(dataset_id, query, domain_id=None, min_count=None):
+    conn = sql_connection()
+    cur = conn.cursor()
+
+    # Check query parameter
+    if query is None or query == [''] or query.isspace():
+        return 'q parameter is missing', 400
+
+    # Find concept IDs given name (query)
+    sql = '''SELECT c.concept_id, concept_name, domain_id, vocabulary_id, concept_class_id, concept_code,
+            CAST(IFNULL(concept_count, 0) AS UNSIGNED) AS concept_count
+        FROM cohd.concept c
+        LEFT JOIN cohd.concept_counts cc ON (cc.dataset_id = %(dataset_id)s AND cc.concept_id = c.concept_id)
+        WHERE concept_name like %(like_query)s AND standard_concept IN ('S','C')
+            {domain_filter}
+            {count_filter}
+        ORDER BY cc.concept_count DESC
+        LIMIT 1000;'''
+    params = {
+        'like_query': '%' + query + '%',
+        'dataset_id': dataset_id,
+        'query': query
+    }
 
-            sql = '''SELECT c.concept_id, concept_name, domain_id, vocabulary_id, concept_class_id, concept_code,
-                    CAST(IFNULL(concept_count, 0) AS UNSIGNED) AS concept_count
-                FROM cohd.concept c
-                LEFT JOIN cohd.concept_counts cc ON (cc.dataset_id = %(dataset_id)s AND cc.concept_id = c.concept_id)
-                WHERE concept_name like %(like_query)s AND standard_concept IN ('S','C')
-                    {domain_filter}
-                    {count_filter}
-                ORDER BY cc.concept_count DESC
-                LIMIT 1000;'''
-            params = {
-                'like_query': '%' + query + '%',
-                'dataset_id': dataset_id,
-                'query': query
-            }
+    # Filter concepts by domain
+    if domain_id is None or domain_id == [''] or domain_id.isspace():
+        domain_filter = ''
+    else:
+        domain_filter = 'AND domain_id = %(domain_id)s'
+        params['domain_id'] = domain_id
 
-            # Filter concepts by domain
-            domain_id = args.get('domain')
-            if domain_id is None or domain_id == [''] or domain_id.isspace():
-                domain_filter = ''
+    # Filter concepts by minimum count
+    if min_count is None or min_count == ['']:
+        # Default to set min_count = 1
+        count_filter = 'AND cc.concept_count >= 1'
+    else:
+        if min_count.strip().isdigit():
+            min_count = int(min_count.strip())
+            if min_count > 0:
+                count_filter = 'AND cc.concept_count >= %(min_count)s'
+                params['min_count'] = min_count
             else:
-                domain_filter = 'AND domain_id = %(domain_id)s'
-                params['domain_id'] = domain_id
+                count_filter = ''
+        else:
+            return 'min_count parameter should be an integer', 400
 
-            # Filter concepts by minimum count
-            min_count = args.get('min_count')
-            if min_count is None or min_count == ['']:
-                # Default to set min_count = 1
-                count_filter = 'AND cc.concept_count >= 1'
-            else:
-                if min_count.strip().isdigit():
-                    min_count = int(min_count.strip())
-                    if min_count > 0:
-                        count_filter = 'AND cc.concept_count >= %(min_count)s'
-                        params['min_count'] = min_count
-                    else:
-                        count_filter = ''
-                else:
-                    return 'min_count parameter should be an integer', 400
+    sql = sql.format(domain_filter=domain_filter, count_filter=count_filter)
 
-            sql = sql.format(domain_filter=domain_filter, count_filter=count_filter)
+    cur.execute(sql, params)
+    json_return = cur.fetchall()
+    return query_db_finalize(conn, cur, json_return)
 
-            cur.execute(sql, params)
-            json_return = cur.fetchall()
 
-        # Looks up concepts for a list of concept_ids
-        # e.g. /api/v1/query?service=omop&meta=concepts&q=4196636,437643
-        elif method == 'concepts':
-            # Check query parameter
-            if query is None or query == [''] or query.isspace():
-                return 'q parameter is missing', 400
-            for concept_id in query.split(','):
-                if not concept_id.strip().isdigit():
-                    return 'Error in q: concept_ids should be integers', 400
+def query_db_concepts(query):
+    conn = sql_connection()
+    cur = conn.cursor()
 
-            # Convert query parameter to a list of concept ids
-            concept_ids = [int(x.strip()) for x in query.split(',')]
+    # Check query parameter
+    if query is None or query == [''] or query.isspace():
+        return 'q parameter is missing', 400
+    for concept_id in query.split(','):
+        if not concept_id.strip().isdigit():
+            return 'Error in q: concept_ids should be integers', 400
 
-            sql = '''SELECT concept_id, concept_name, domain_id, vocabulary_id, concept_class_id, concept_code
-                FROM cohd.concept
-                WHERE concept_id IN (%s);''' % ','.join(['%s' for _ in concept_ids])
+    # Convert query parameter to a list of concept ids
+    concept_ids = [int(x.strip()) for x in query.split(',')]
 
-            cur.execute(sql, concept_ids)
-            json_return = cur.fetchall()
+    sql = '''SELECT concept_id, concept_name, domain_id, vocabulary_id, concept_class_id, concept_code
+        FROM cohd.concept
+        WHERE concept_id IN (%s);''' % ','.join(['%s' for _ in concept_ids])
+
+    cur.execute(sql, concept_ids)
+    json_return = cur.fetchall()
+    return query_db_finalize(conn, cur, json_return)
+
+
+def query_db(service, method, args):
+    # Connect to MYSQL database
+    conn = sql_connection()
+    cur = conn.cursor()
+
+    json_return = []
+
+    query = args.get('q')
+
+    logging.debug(msg=f"Service: {service}; Method: {method}, Query: {query}")
+
+    if service == 'omop':
+        # e.g. /api/v1/query?service=omop&meta=findConceptIDs&q=cancer
+        if method == 'findConceptIDs':
+            dataset_id = get_arg_dataset_id(args)
+            domain_id = args.get('domain')
+            min_count = args.get('min_count')
+            json_return = query_db_find_concept_ids(dataset_id, query, domain_id, min_count)
+
+        # e.g. /api/v1/query?service=omop&meta=concepts&q=4196636,437643
+        elif method == 'concepts':
+            json_return = query_db_concepts(query)
 
         # Looks up ancestors of a given concept
         # e.g. /api/query?service=omop&meta=conceptAncestors&concept_id=313217
@@ -1096,6 +1131,21 @@ def query_db(service, method, args):
             for row in json_return:
                 row['confidence_interval'] = rel_freq_ci(row['concept_pair_count'], row['concept_2_count'],
                                                          confidence_level)
+        elif method == 'mcq':
+            # Get non-required parameters
+            dataset_id = get_arg_dataset_id(args)
+            domain_id = args.get('domain')
+
+            # concept_ids is required
+            concept_ids = args.get('concept_ids')
+            if concept_ids is None:
+                return 'No concept_ids selected', 400
+            if type(concept_ids) is str:
+                concept_ids = [int(x) for x in concept_ids.split(',')]
+            elif type(concept_ids) is not list:
+                concept_ids = [concept_ids]
+
+            json_return = query_trapi_mcq(concept_ids, dataset_id, domain_id, bypass=True)['results']
 
     logging.debug(cur._executed)
     logging.debug(json_return)
@@ -1702,7 +1752,275 @@ def query_trapi(concept_id_1, concept_id_2=None, dataset_id=None, domain_id=None
     conn.close()
 
     json_return = {"results": json_return}
-    return json_return
+    return json_return
+
+
+def get_pair_concept_count(cur=None, concept_id_list_1=[], concept_id_list_2=[], dataset_id=3, domain_id=None, top_n=999999):
+    sql = '''
+    SELECT * FROM
+    (
+        SELECT
+            cp.concept_id_1 as concept_id_1,
+            cp.concept_id_2 as concept_id_2,
+            cp.concept_count as concept_pair_count,
+            c1.concept_count as concept_count_1,
+            c2.concept_count as concept_count_2
+        FROM concept_pair_counts cp
+        INNER JOIN concept_counts c1 ON cp.concept_id_1 = c1.concept_id
+        INNER JOIN concept_counts c2 ON cp.concept_id_2 = c2.concept_id
+        INNER JOIN concept con on con.concept_id = cp.concept_id_2
+        WHERE cp.dataset_id = %(dataset_id)s
+            AND c1.dataset_id = %(dataset_id)s
+            AND c2.dataset_id = %(dataset_id)s
+            {concept_id_filter_1}
+            {domain_filter}
+        GROUP BY concept_id_1, concept_id_2
+        UNION
+        SELECT
+            cp.concept_id_2 as concept_id_1,
+            cp.concept_id_1 as concept_id_2,
+            cp.concept_count as concept_pair_count,
+            c1.concept_count as concept_count_1,
+            c2.concept_count as concept_count_2
+        FROM concept_pair_counts cp
+        INNER JOIN concept_counts c1 ON cp.concept_id_2 = c1.concept_id
+        INNER JOIN concept_counts c2 ON cp.concept_id_1 = c2.concept_id
+        INNER JOIN concept con on con.concept_id = cp.concept_id_1
+        WHERE cp.dataset_id = %(dataset_id)s
+            AND c1.dataset_id = %(dataset_id)s
+            AND c2.dataset_id = %(dataset_id)s
+            {concept_id_filter_2}
+            {domain_filter}
+        GROUP BY concept_id_1, concept_id_2
+    ) x
+    ORDER BY x.concept_pair_count DESC
+    LIMIT {top_n_filter};
+    '''
+    params = {
+        'dataset_id': dataset_id,
+    }
+    concept_id_1 = ','.join([str(c) for c in concept_id_list_1])
+    if len(concept_id_list_2) > 0:
+        concept_id_2 = ','.join([str(c) for c in concept_id_list_2])
+        concept_id_filter_1 = '''
+            AND (
+                (cp.concept_id_1 in ({concept_id_1}) AND cp.concept_id_2 in ({concept_id_2}) )
+            )
+            '''
+        concept_id_filter_2 = '''
+            AND (
+                (cp.concept_id_2 in ({concept_id_1}) AND cp.concept_id_1 in ({concept_id_2}) )
+            )
+            '''
+        concept_id_filter_1 = concept_id_filter_1.format(concept_id_1=concept_id_1, concept_id_2=concept_id_2)
+        concept_id_filter_2 = concept_id_filter_2.format(concept_id_1=concept_id_1, concept_id_2=concept_id_2)
+    else:
+        concept_id_filter_1 = '''
+            AND (
+                (cp.concept_id_1 in ({concept_id_1}) )
+            )
+            '''
+        concept_id_filter_2 = '''
+            AND (
+                (cp.concept_id_2 in ({concept_id_1}) )
+            )
+            '''
+        concept_id_filter_1 = concept_id_filter_1.format(concept_id_1=concept_id_1)
+        concept_id_filter_2 = concept_id_filter_2.format(concept_id_1=concept_id_1)
+    # Filter concepts by domain
+    if domain_id is not None and not domain_id == ['']:
+        domain_filter = 'AND con.domain_id = %(domain_id)s'
+        params['domain_id'] = domain_id
+    else:
+        domain_filter = ''
+
+    sql = sql.format(concept_id_filter_1=concept_id_filter_1, concept_id_filter_2=concept_id_filter_2,
+                     domain_filter=domain_filter, top_n_filter=top_n)
+    cur.execute(sql, params)
+    return cur.fetchall()
+
+
+def _get_weighted_statistics(cur=None, dataset_id=None, domain_id=None, concept_id_1=None, pair_count_df=None, json_key='jaccard_index'):
+    '''
+    Helper function.
+    Input 1: original association statistics required
+    Input 2: concept list in the query for weight calculation (currently only supports Jaccard-index-based weight calculation)
+    Returns the weighted json_key, e.g., a weighted jaccard_index.
+    '''
+    concept_list_1_w_df = pd.DataFrame({'concept_id_1': concept_id_1})
+    concept_list_1_w_df['w'] = 1
+
+    # Calculate the weights based on Jaccard index between input concepts
+    pair_count_q1 = pd.DataFrame(get_pair_concept_count(cur=cur, dataset_id=dataset_id, domain_id=domain_id,
+                                                        concept_id_list_1=concept_id_1, concept_id_list_2=concept_id_1))
+    if pair_count_q1.shape[0] > 0:
+        # Sum of Jaccard index
+        pair_count_q1['jaccard_index'] = pair_count_q1['concept_pair_count'] / (pair_count_q1['concept_count_1'] + pair_count_q1['concept_count_2'] - pair_count_q1['concept_pair_count'])
+        pair_count_q1 = pair_count_q1.groupby('concept_id_1')['jaccard_index'].agg('sum').reset_index()
+        concept_list_1_w_df = concept_list_1_w_df.merge(pair_count_q1).reset_index()
+        # 1 + sum(Jaccards)
+        concept_list_1_w_df['w'] = concept_list_1_w_df['w'] + concept_list_1_w_df['jaccard_index']
+        # Weight = 1/(1 + sum(Jaccards))
+        concept_list_1_w_df['w'] = 1/concept_list_1_w_df['w']
+        concept_list_1_w_df = concept_list_1_w_df[['concept_id_1','w']]
+
+    # Multiply the scores by the weights
+    pair_count_df = pair_count_df.merge(concept_list_1_w_df)
+    pair_count_df[json_key] = pair_count_df['w'] * pair_count_df[json_key]
+    pair_count_df = pair_count_df[~pair_count_df['concept_id_2'].isin(concept_id_1)]
+
+    # Group by concept_id_2. Sum the scores and combine concept_id_1 into a list
+    gb = pair_count_df.groupby('concept_id_2')
+    weighted_stats = gb['concept_id_1'].agg(list).reset_index()
+    weighted_stats = weighted_stats.merge(gb[json_key].agg('sum'), on='concept_id_2')
+    return weighted_stats
+
+
+def _get_ci_scores(r, low, high):
+    if r[low] > 0:
+        return r[low]
+    elif r[high] < 0:
+        return r[high]
+    else:
+        return 0
+
+
+@cache.memoize(timeout=86400, unless=_bypass_cache)
+def query_trapi_mcq(concept_ids, dataset_id=None, domain_id=None, concept_class_id=None,
+                    ln_ratio_sign=0, bypass=False):
+    """ Query for TRAPI Multicurie Query. Calculates weighted scores using methods similar to linkage disequilibrium to
+    downweight contributions from input concepts that are similar to each other
+
+    Parameters
+    ----------
+    concept_ids: list of OMOP concept IDs
+    dataset_id: (optional) String - COHD dataset ID
+    domain_id: (optional) String - OMOP domain ID
+    concept_class_id: (optional) String - OMOP concept class ID
+    ln_ratio_sign: (optional) Int - 1: positive ln_ratio only; -1: negative ln_ratio only; 0: any ln_ratio
+    bypass: (optional) Bool - True to bypass the cache
+
+    Returns
+    -------
+    Dict results
+    """
+    assert concept_ids is not None and type(concept_ids) is list, \
+        'query_cohd_mysql.py::query_trapi_mcq() - Bad input. concept_ids={concept_ids}'.format(
+            concept_ids=str(concept_ids)
+        )
+
+    # Connect to MYSQL database
+    conn = sql_connection()
+    cur = conn.cursor()
+
+    # Filter ln ratio
+    if ln_ratio_sign == 0:
+        ln_ratio_filter = ''
+    elif ln_ratio_sign > 0:
+        ln_ratio_filter = 'AND log(cp.concept_count * pc.count / (c1.concept_count * c2.concept_count + 0E0)) > 0'
+    elif ln_ratio_sign < 0:
+        ln_ratio_filter = 'AND log(cp.concept_count * pc.count / (c1.concept_count * c2.concept_count + 0E0)) < 0'
+
+    sql = '''SELECT *
+        FROM
+        ((SELECT
+            cp.dataset_id,
+            cp.concept_id_1,
+            cp.concept_id_2,
+            ln_ratio,
+            ln_ratio_ci_lo,
+            ln_ratio_ci_hi,
+            log_odds,
+            log_odds_ci_lo,
+            log_odds_ci_hi,
+            c.concept_name AS concept_2_name,
+            c.domain_id AS concept_2_domain,
+            c.concept_class_id AS concept_2_class_id
+        FROM cohd.concept_pair_counts cp
+        JOIN cohd.concept c ON cp.concept_id_2 = c.concept_id
+        WHERE cp.dataset_id = %(dataset_id)s
+            AND cp.concept_id_1 = %(concept_id_1)s
+            {domain_filter}
+            {concept_class_filter}
+            {ln_ratio_filter})
+        UNION
+        (SELECT
+            cp.dataset_id,
+            cp.concept_id_2 AS concept_id_1,
+            cp.concept_id_1 AS concept_id_2,
+            ln_ratio,
+            ln_ratio_ci_lo,
+            ln_ratio_ci_hi,
+            log_odds,
+            log_odds_ci_lo,
+            log_odds_ci_hi,
+            c.concept_name AS concept_2_name,
+            c.domain_id AS concept_2_domain,
+            c.concept_class_id AS concept_2_class_id
+        FROM cohd.concept_pair_counts cp
+        JOIN cohd.concept c ON cp.concept_id_1 = c.concept_id
+        WHERE cp.dataset_id = %(dataset_id)s
+            AND cp.concept_id_2 = %(concept_id_1)s
+            {domain_filter}
+            {concept_class_filter}
+            {ln_ratio_filter})) x
+        ORDER BY ABS(ln_ratio) DESC;'''
+    params = {
+        'dataset_id': dataset_id,
+    }
+
+    if domain_id is not None and not domain_id == ['']:
+        # Restrict the associated concept by domain
+        domain_filter = 'AND c.domain_id = %(domain_id)s'
+        params['domain_id'] = domain_id
+    else:
+        # Unrestricted domain
+        domain_filter = ''
+
+    # Filter concepts by concept_class
+    if concept_class_id is None or not concept_class_id or concept_class_id == [''] or \
+            concept_class_id.isspace():
+        concept_class_filter = ''
+    else:
+        concept_class_filter = 'AND concept_class_id = %(concept_class_id)s'
+        params['concept_class_id'] = concept_class_id
+
+    # Get the associations for each of the concepts in the list
+    pair_counts = list()
+    for concept_id_1 in concept_ids:
+        params['concept_id_1'] = concept_id_1
+        sqlp = sql.format(domain_filter=domain_filter, concept_class_filter=concept_class_filter,
+                          ln_ratio_filter=ln_ratio_filter)
+
+        cur.execute(sqlp, params)
+        pair_counts.extend(cur.fetchall())
+    pair_count = pd.DataFrame(pair_counts)
+
+    # Score ln_ratio and log_odds based on their confidence intervals
+    pair_count['ln_ratio_score'] = pair_count.apply(_get_ci_scores, axis=1, low='ln_ratio_ci_lo', high='ln_ratio_ci_hi')
+    # pair_count['log_odds_score'] = pair_count.apply(_get_ci_scores, axis=1, low='log_odds_ci_lo', high='log_odds_ci_hi')
+
+    # Adjust the scores by weights
+    concept_list_1 = list(set(pair_count['concept_id_1'].tolist()))
+    weighted_ln_ratio = _get_weighted_statistics(cur=cur, dataset_id=dataset_id, domain_id=domain_id,
+                                                 concept_id_1=concept_list_1, pair_count_df=pair_count,
+                                                 json_key='ln_ratio_score')
+    # weighted_log_odds = _get_weighted_statistics(cur=cur, dataset_id=dataset_id, domain_id=domain_id,
+    #                                              concept_id_1=concept_list_1, pair_count_df=pair_count,
+    #                                              json_key='log_odds_score')
+
+    # Extract concept 2 definitions
+    columns_c2 = ['dataset_id', 'concept_id_2', 'concept_2_name', 'concept_2_domain', 'concept_2_class_id']
+    concept_2_defs = pair_count[columns_c2].groupby('concept_id_2').agg(lambda x: x.iloc[0])
+
+    # Merge and sort results, and convert to dict for JSON results
+    results = concept_2_defs.merge(weighted_ln_ratio, on='concept_id_2')
+    # results = results.merge(weighted_log_odds[['concept_id_2', 'log_odds_score']], on='concept_id_2')
+    results = results.sort_values('ln_ratio_score', ascending=False)
+    json_return = {'results': results.to_dict('records')}
+
+    cur.close()
+    conn.close()
+
+    return json_return
 
 
 def health():
diff --git a/requirements.txt b/requirements.txt
index bc1f6fe..382a9cc 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -15,4 +15,5 @@ opentelemetry-sdk
 opentelemetry-instrumentation-flask
 opentelemetry-exporter-jaeger
 opentelemetry-instrumentation-requests
-opentelemetry-instrumentation-pymysql
\ No newline at end of file
+opentelemetry-instrumentation-pymysql
+pandas

From 5fcaf544ea1e38e0156d594b9667a6ae09a427b8 Mon Sep 17 00:00:00 2001
From: caseyta
Date: Wed, 29 May 2024 21:49:09 +0000
Subject: [PATCH 2/6] Add TRAPI query validation back in

---
 cohd/cohd_trapi_15.py                |  9 ++++-----
 cohd/trapi/reasoner_validator_ext.py | 24 ++++++++++++++++++++++++
 2 files changed, 28 insertions(+), 5 deletions(-)

diff --git a/cohd/cohd_trapi_15.py b/cohd/cohd_trapi_15.py
index 257fe18..a2b9304 100644
--- a/cohd/cohd_trapi_15.py
+++ b/cohd/cohd_trapi_15.py
@@ -11,7 +11,7 @@ from .cohd_utilities import omop_concept_curie
 from .cohd_trapi import *
 from .biolink_mapper import *
-from .trapi.reasoner_validator_ext import validate_trapi_14x as validate_trapi
+from .trapi.reasoner_validator_ext import validate_trapi_15x as validate_trapi
 from .translator import bm_toolkit, bm_version
 from .translator.ubergraph import Ubergraph
 
@@ -41,7 +41,7 @@ class CohdTrapi150(CohdTrapi):
     edge_types_negative = ['biolink:negatively_correlated_with']
     default_negative_predicate = edge_types_negative[0]
 
-    tool_version = f'{CohdTrapi._SERVICE_NAME} 6.4.3'
+    tool_version = f'{CohdTrapi._SERVICE_NAME} 6.5.0'
     schema_version = '1.5.0'
     biolink_version = bm_version
 
@@ -126,9 +126,8 @@ def _check_query_input(self):
             return self._valid_query, self._invalid_query_response
 
         # Use TRAPI Reasoner Validator to validate the query
-        try:
-            # For now, bypass the TRAPI validation because reasoner_validator doesn't work with TRAPI 1.5
-            # validate_trapi(self._json_data, "Query")
+        try:
+            validate_trapi(self._json_data, "Query")
             self.log('Query passed reasoner validator')
         except ValidationError as err:
             self._valid_query = False
diff --git a/cohd/trapi/reasoner_validator_ext.py b/cohd/trapi/reasoner_validator_ext.py
index 3949f78..b08f032 100644
--- a/cohd/trapi/reasoner_validator_ext.py
+++ b/cohd/trapi/reasoner_validator_ext.py
@@ -134,6 +134,30 @@ def validate_trapi_14x(instance, component):
     return validator.validate(instance, component)
 
 
+def validate_trapi_15x(instance, component):
+    """Validate instance against TRAPI 1.5 schema.
+
+    Parameters
+    ----------
+    instance
+        instance to validate
+    component : str
+        component to validate against
+
+    Raises
+    ------
+    `ValidationError `_
+        If the instance is invalid.
+
+    Examples
+    --------
+    >>> validate({"message": {}}, "Query")
+    """
+    # Validate against official TRAPI 1.5 release
+    validator = TRAPISchemaValidator(trapi_version='1.5.0')
+    return validator.validate(instance, component)
+
+
 def validate_trapi_response(trapi_version, bl_version, response):
     """ Uses the reasoner_validator's more advanced TRAPIResponseValidator to perform thorough validation

From e77561817812300af97d20ddbe5199fbea207091 Mon Sep 17 00:00:00 2001
From: caseyta
Date: Wed, 29 May 2024 23:46:56 +0000
Subject: [PATCH 3/6] Move checks for unsupported TRAPI queries to beginning of
 interpretation process

---
 cohd/cohd_trapi_15.py    | 111 ++++++++++++++++++---------------------
 cohd/query_cohd_mysql.py |  21 ++------
 2 files changed, 56 insertions(+), 76 deletions(-)

diff --git a/cohd/cohd_trapi_15.py b/cohd/cohd_trapi_15.py
index a2b9304..f74623b 100644
--- a/cohd/cohd_trapi_15.py
+++ b/cohd/cohd_trapi_15.py
@@ -146,8 +146,8 @@ def _check_query_input(self):
             return self._valid_query, self._invalid_query_response
 
         # Check the structure of the query graph. Should have 2 nodes and 1 edge (one-hop query)
-        nodes = query_graph.get('nodes')
-        edges = query_graph.get('edges')
+        nodes = list(query_graph.get('nodes').values())
+        edges = list(query_graph.get('edges').values())
         if nodes is None or len(nodes) != 2 or edges is None or len(edges) != 1:
             self._valid_query = False
             msg = 'Unsupported query. Only one-hop queries supported.'
@@ -156,6 +156,55 @@ def _check_query_input(self):
             self._invalid_query_response = response, 200
             return self._valid_query, self._invalid_query_response
 
+        # If client provided non-empty QNode constraints, respond with error code
+        if nodes[0].get('constraints') or nodes[1].get('constraints'):
+            self._valid_query = False
+            description = f'{CohdTrapi._SERVICE_NAME} does not support QNode constraints'
+            self.log(description, TrapiStatusCode.UNSUPPORTED_CONSTRAINT, logging.ERROR)
+            response = self._trapi_mini_response(TrapiStatusCode.UNSUPPORTED_CONSTRAINT, description)
+            self._invalid_query_response = response, 200
+            return self._valid_query, self._invalid_query_response
+        if edges[0].get("attribute_constraints"):
+            self._valid_query = False
+            description = f'{CohdTrapi._SERVICE_NAME} does not support QEdge attribute constraints'
+            self.log(description, TrapiStatusCode.UNSUPPORTED_ATTR_CONSTRAINT, logging.ERROR)
+            response = self._trapi_mini_response(TrapiStatusCode.UNSUPPORTED_ATTR_CONSTRAINT, description)
+            self._invalid_query_response = response, 200
+            return self._valid_query, self._invalid_query_response
+        if edges[0].get("qualifier_constraints"):
+            self._valid_query = False
+            description = f'{CohdTrapi._SERVICE_NAME} does not support QEdge qualifier constraints'
+            self.log(description, TrapiStatusCode.UNSUPPORTED_QUAL_CONSTRAINT, logging.ERROR)
+            response = self._trapi_mini_response(TrapiStatusCode.UNSUPPORTED_QUAL_CONSTRAINT, description)
+            self._invalid_query_response = response, 200
+            return self._valid_query, self._invalid_query_response
+
+        # If client specifies unsupported set_interpretation (ALL or MANY), respond with error code
+        if nodes[0].get('set_interpretation') in CohdTrapi150.unsupported_set_interpretation or \
+                nodes[1].get('set_interpretation') in CohdTrapi150.unsupported_set_interpretation:
+            self._valid_query = False
+            description = f'{CohdTrapi._SERVICE_NAME} only supports QNode set_interpretation of {CohdTrapi150.supported_set_interpretation}'
+            self.log(description, TrapiStatusCode.UNSUPPORTED_SET_INTERPRETATION, logging.ERROR)
+            response = self._trapi_mini_response(TrapiStatusCode.UNSUPPORTED_SET_INTERPRETATION, description)
+            self._invalid_query_response = response, 200
+            return self._valid_query, self._invalid_query_response
+
+        # Check to see if cohd doesn't recognize any properties
+        qnode_properties = {'ids','categories', 'set_interpretation', 'constraints'}
+        unrec_properties = (set(nodes[0].keys()) | (set(nodes[1].keys()))) - qnode_properties
+        if unrec_properties:
+            description = f'{CohdTrapi._SERVICE_NAME} does not recognize the following node properties: ' \
+                          f'{", ".join(unrec_properties)}. {CohdTrapi._SERVICE_NAME} will ignore these properties.'
+            self.log(description, level=logging.WARNING)
+
+        qedge_properties = {'knowledge_type', 'predicates', 'subject', 'object', 'attribute_constraints',
+                            'qualifier_constraints'}
+        unrec_properties = set(edges[0].keys()) - qedge_properties
+        if unrec_properties:
+            description = f'{CohdTrapi._SERVICE_NAME} does not recognize the following edge properties: ' \
+                          f'{", ".join(unrec_properties)}. {CohdTrapi._SERVICE_NAME} will ignore these properties.'
+            self.log(description, level=logging.WARNING)
+
         # Check the workflow. Should be at most a single lookup operation
         workflow = self._json_data.get('workflow')
         if workflow and type(workflow) is list:
@@ -227,7 +276,7 @@ def _interpret_query(self):
             True if input is valid, otherwise (False, message)
         """
         # Log that TRAPI 1.4 was called because there's no clear indication otherwise
-        logging.debug('Query issued against TRAPI 1.4')
+        logging.debug(f'Query issued against TRAPI {CohdTrapi150.schema_version}')
 
         try:
             self._json_data = self._request.get_json()
@@ -552,62 +601,6 @@ def _interpret_query(self):
             self.log(f'The following categories were not recognized in Biolink {bm_version}: {unrecognized_cats}',
                      level=logging.WARNING)
 
-        # If client provided non-empty QNode constraints, respond with error code
-        if concept_1_qnode.get('constraints') or concept_2_qnode.get('constraints'):
-            self._valid_query = False
-            description = f'{CohdTrapi._SERVICE_NAME} does not support QNode constraints'
-            self.log(description, TrapiStatusCode.UNSUPPORTED_CONSTRAINT, logging.ERROR)
-            response = self._trapi_mini_response(TrapiStatusCode.UNSUPPORTED_CONSTRAINT, description)
-            self._invalid_query_response = response, 200
-            return self._valid_query, self._invalid_query_response
-        if self._query_edge.get("attribute_constraints"):
-            self._valid_query = False
-            description = f'{CohdTrapi._SERVICE_NAME} does not support QEdge attribute constraints'
-            self.log(description, TrapiStatusCode.UNSUPPORTED_ATTR_CONSTRAINT, logging.ERROR)
-            response = self._trapi_mini_response(TrapiStatusCode.UNSUPPORTED_ATTR_CONSTRAINT, description)
-            self._invalid_query_response = response, 200
-            return self._valid_query, self._invalid_query_response
-        if self._query_edge.get("qualifier_constraints"):
-            self._valid_query = False
-            description = f'{CohdTrapi._SERVICE_NAME} does not support QEdge qualifier constraints'
-            self.log(description, TrapiStatusCode.UNSUPPORTED_QUAL_CONSTRAINT, logging.ERROR)
-            response = self._trapi_mini_response(TrapiStatusCode.UNSUPPORTED_QUAL_CONSTRAINT, description)
-            self._invalid_query_response = response, 200
-            return self._valid_query, self._invalid_query_response
-
-        # If client specifies unsupported set_interpretation (ALL or MANY), respond with error code
-        if concept_1_qnode.get('set_interpretation') in CohdTrapi150.unsupported_set_interpretation or \
-                concept_2_qnode.get('set_interpretation') in CohdTrapi150.unsupported_set_interpretation:
-            self._valid_query = False
-            description = f'{CohdTrapi._SERVICE_NAME} only supports QNode set_interpretation of {CohdTrapi150.supported_set_interpretation}'
-            self.log(description, TrapiStatusCode.UNSUPPORTED_SET_INTERPRETATION, logging.ERROR)
-            response = self._trapi_mini_response(TrapiStatusCode.UNSUPPORTED_SET_INTERPRETATION, description)
-            self._invalid_query_response = response, 200
-            return self._valid_query, self._invalid_query_response
-
-        # Check to see if cohd doesn't recognize any properties
-        qnode_properties = {'ids','categories', 'set_interpretation', 'constraints'}
-        qedge_properties = {'knowledge_type', 'predicates', 'subject', 'object', 'attribute_constraints',
-                            'qualifier_constraints'}
-        sep = ', '
-        unrec_properties = set(concept_1_qnode.keys()) - qnode_properties
-        if unrec_properties:
-            description = f'{CohdTrapi._SERVICE_NAME} does not recognize the following properties: ' \
-                          f'{sep.join(unrec_properties)}. {CohdTrapi._SERVICE_NAME} will ignore these properties.'
-            self.log(description, level=logging.WARNING)
-
-        unrec_properties = set(concept_2_qnode.keys()) - qnode_properties
-        if unrec_properties:
-            description = f'{CohdTrapi._SERVICE_NAME} does not recognize the following properties: ' \
-                          f'{sep.join(unrec_properties)}. {CohdTrapi._SERVICE_NAME} will ignore these properties.'
-            self.log(description, level=logging.WARNING)
-
-        unrec_properties = set(self._query_edge.keys()) - qedge_properties
-        if unrec_properties:
-            description = f'{CohdTrapi._SERVICE_NAME} does not recognize the following properties: ' \
-                          f'{sep.join(unrec_properties)}. {CohdTrapi._SERVICE_NAME} will ignore these properties.'
-            self.log(description, level=logging.WARNING)
-
         # Get concept_id_1. QNode IDs is a list.
         self._concept_1_omop_ids = list()
         found = False
diff --git a/cohd/query_cohd_mysql.py b/cohd/query_cohd_mysql.py
index 1884318..b1184a1 100644
--- a/cohd/query_cohd_mysql.py
+++ b/cohd/query_cohd_mysql.py
@@ -96,7 +96,7 @@ def query_db_finalize(conn, cursor, json_return):
     conn.close()
 
     json_return = {"results": json_return}
-    return jsonify(json_return)
+    return json_return
 
 
 def query_db_datasets():
@@ -239,20 +239,7 @@ def query_db(service, method, args):
     logging.debug(msg=f"Service: {service}; Method: {method}, Query: {query}")
 
     if service == 'omop':
-        # e.g. /api/v1/query?service=omop&meta=findConceptIDs&q=cancer
-        if method == 'findConceptIDs':
-            dataset_id = get_arg_dataset_id(args)
-            domain_id = args.get('domain')
-            min_count = args.get('min_count')
-            json_return = query_db_find_concept_ids(dataset_id, query, domain_id, min_count)
-
-        # e.g. /api/v1/query?service=omop&meta=concepts&q=4196636,437643
-        elif method == 'concepts':
-            json_return = query_db_concepts(query)
-
-        # Looks up ancestors of a given concept
-        # e.g. /api/query?service=omop&meta=conceptAncestors&concept_id=313217
-        elif method == 'conceptAncestors':
+        if method == 'conceptAncestors':
             # Get non-required parameters
             dataset_id = get_arg_dataset_id(args, DATASET_ID_DEFAULT_HIER)
 
@@ -2083,9 +2070,9 @@ def omop_concept_definitions(concept_ids):
     if not concept_ids:
         return concept_defs
 
-    response = query_db(service='omop', method='concepts', args={'q': ','.join(str(c) for c in concept_ids)})
+    q = ','.join(str(c) for c in concept_ids)
+    concept_results = query_db_concepts(q)
 
-    concept_results = response.get_json()
     if concept_results is None or 'results' not in concept_results:
         return concept_defs
 

From de8b940b28270c6bf11f30a587ad0ab7aac436e4 Mon Sep 17 00:00:00 2001
From: caseyta
Date: Thu, 30 May 2024 00:18:39 +0000
Subject: [PATCH 4/6] Move more TRAPI checks up front. Remove deprecated
 query_options

---
 cohd/cohd_trapi_15.py | 51 +++++++++++--------------------------------
 1 file changed, 13 insertions(+), 38 deletions(-)

diff --git a/cohd/cohd_trapi_15.py b/cohd/cohd_trapi_15.py
index f74623b..4f98038 100644
--- a/cohd/cohd_trapi_15.py
+++ b/cohd/cohd_trapi_15.py
@@ -75,7 +75,6 @@ def __init__(self, request):
         self._request = request
         self._max_results_per_input = CohdTrapi.default_max_results_per_input
         self._max_results = CohdTrapi.default_max_results
-        self._local_oxo = CohdTrapi.default_local_oxo
         self._kg_nodes = {}
         self._knowledge_graph = {
             'nodes': {},
@@ -155,7 +154,16 @@ def _check_query_input(self):
             response = self._trapi_mini_response(TrapiStatusCode.NO_RESULTS, msg)
             self._invalid_query_response = response, 200
             return self._valid_query, self._invalid_query_response
-
+
+        # Check if QNodes are null
+        if nodes[0] is None or nodes[1] is None:
+            self._valid_query = False
+            msg = 'Null QNode found in query graph'
+            self.log(msg, level=logging.ERROR)
+            response = self._trapi_mini_response(TrapiStatusCode.NO_RESULTS, msg)
+            self._invalid_query_response = response, 200
+            return self._valid_query, self._invalid_query_response
+
         # If client provided non-empty QNode constraints, respond with error code
@@ -302,7 +310,7 @@ def _interpret_query(self):
         }
         self._log_level = log_level_enum.get(log_level, CohdTrapi.default_log_level)
 
-        # Check that the query input has the correct structure
+        # Check that the query input has the correct structure and doesn't request unsupported TRAPI features
        input_check = self._check_query_input()
         if not input_check[0]:
             return input_check
@@ -354,16 +362,6 @@ def _interpret_query(self):
             self._confidence_interval = CohdTrapi.default_confidence_interval
             self._query_options['confidence_interval'] = CohdTrapi.default_confidence_interval
 
-        # Get the query_option for local_oxo
-        self._local_oxo = self._query_options.get('local_oxo')
-        if self._local_oxo is None or not isinstance(self._local_oxo, bool):
-            self._local_oxo = CohdTrapi.default_local_oxo
-
-        # Get the query_option for maximum mapping distance
-        self._mapping_distance = self._query_options.get('mapping_distance')
-        if self._mapping_distance is None or not isinstance(self._mapping_distance, Number):
-            self._mapping_distance = CohdTrapi.default_mapping_distance
-
         # Get query_option for including only Biolink nodes
         self._biolink_only = self._query_options.get('biolink_only')
         if self._biolink_only is None or not isinstance(self._biolink_only, bool):
@@ -377,18 +375,9 @@ def _interpret_query(self):
 
         # Get query information from query_graph
         self._query_graph = self._json_data['message']['query_graph']
-
-        # Check that the query_graph is supported by the COHD reasoner (1-hop query)
-        edges = self._query_graph['edges']
-        if len(edges) != 1:
-            self._valid_query = False
-            msg = f'{CohdTrapi._SERVICE_NAME} reasoner only supports 1-hop queries'
-            self.log(msg, level=logging.WARNING)
-            response = self._trapi_mini_response(TrapiStatusCode.NO_RESULTS, msg)
-            self._invalid_query_response = response, 200
-            return self._valid_query, self._invalid_query_response
-
+
         # Check if the edge type is supported by COHD Reasoner and how it should be processed
+        edges = self._query_graph['edges']
         self._query_edge_key = list(edges.keys())[0]  # Get first and only edge
         self._query_edge = edges[self._query_edge_key]
         self._query_edge_predicates = self._query_edge.get('predicates')
@@ -456,22 +445,8 @@ def _interpret_query(self):
         # Note: qnode_key refers to the key identifier for the qnode in the QueryGraph's nodes property, e.g., "n00"
         subject_qnode_key = self._query_edge['subject']
         subject_qnode = self._find_query_node(subject_qnode_key)
-        if subject_qnode is None:
-            self._valid_query = False
-            msg = f'QNode id "{subject_qnode_key}" not found in query graph'
-            self.log(msg, level=logging.ERROR)
-            response = self._trapi_mini_response(TrapiStatusCode.NO_RESULTS, msg)
-            self._invalid_query_response = response, 200
-            return self._valid_query, self._invalid_query_response
         object_qnode_key = self._query_edge['object']
         object_qnode = self._find_query_node(object_qnode_key)
-        if object_qnode is None:
-            self._valid_query = False
-            msg = f'QNode id "{object_qnode_key}" not found in query graph'
-            self.log(msg, level=logging.ERROR)
-            response = self._trapi_mini_response(TrapiStatusCode.NO_RESULTS, msg)
-            self._invalid_query_response = response, 200
-            return self._valid_query, self._invalid_query_response
 
         # In COHD queries, concept_id_1 must be specified by ID. Figure out which QNode to use for concept_1
         node_ids = set()

From 5e4df67a23e1fdd573653616f2b0c6cf79562e68 Mon Sep 17 00:00:00 2001
From: caseyta
Date: Sat, 1 Jun 2024 00:16:21 +0000
Subject: [PATCH 5/6] Beginning MCQ refactor

---
 cohd/cohd_trapi_15.py | 294 +++++++++++++++++++++++++++---------------
 1 file changed, 193 insertions(+), 101 deletions(-)

diff --git a/cohd/cohd_trapi_15.py b/cohd/cohd_trapi_15.py
index 4f98038..a5cfdb7 100644
--- a/cohd/cohd_trapi_15.py
+++ b/cohd/cohd_trapi_15.py
@@ -30,8 +30,9 @@ class CohdTrapi150(CohdTrapi):
                                  'biolink:has_real_world_evidence_of_association_with']
 
     # QNode set_interpretation values that COHD TRAPI does not support
-    supported_set_interpretation = ['BATCH']
+    supported_set_interpretation = ['BATCH', 'MANY']
     unsupported_set_interpretation = list(set(['BATCH', 'ALL', 'MANY']) - set(supported_set_interpretation))
+    DEFAULT_SET_INTERPRETATION = 'BATCH'
 
     # Biolink predicates that request positive associations only
     edge_types_positive = ['biolink:positively_correlated_with']
@@ -62,6 +63,8 @@ def __init__(self, request):
         self._concept_2_ancestor_dict = None
         # Boolean indicating if concept_1 (from API context) is the subject node (True) or object node (False)
         self._concept_1_is_subject_qnode = True
+        self._concept_1_set_interpretation = None
+        self._concept_2_set_interpretation = None
         self._query_options = None
         self._method = None
         self._concept_1_omop_ids = None
@@ -219,7 +222,7 @@ def _check_query_input(self):
             if len(workflow) > 1 or workflow[0]['id'] != CohdTrapi.supported_operation:
                 self._valid_query = False
                 msg = f'Unsupported workflow. Only a single "{CohdTrapi.supported_operation}" operation is supported'
-                self.log(msg, level=logging.WARNING)
+                self.log(msg, level=logging.ERROR)
                 response = self._trapi_mini_response(TrapiStatusCode.NO_RESULTS, msg)
                 self._invalid_query_response = response, 200
                 return self._valid_query, self._invalid_query_response
@@ -458,10 +461,11 @@ def _interpret_query(self):
             concept_1_qnode = subject_qnode
             self._concept_2_qnode_key = object_qnode_key
             concept_2_qnode = object_qnode
+            self._concept_1_set_interpretation = concept_1_qnode.get('set_interpretation', CohdTrapi150.DEFAULT_SET_INTERPRETATION)
 
-            # Check the length of the IDs list is below the batch size limit
-            ids = subject_qnode['ids']
-            if len(ids) > CohdTrapi.batch_size_limit:
+            # Check the length of the IDs list is below the batch size limit
+            ids = subject_qnode['ids']
+            if self._concept_1_set_interpretation == 'BATCH' and len(ids) > CohdTrapi.batch_size_limit:
                 # Warn the client and truncate the ids list
                 description = f"More IDs ({len(ids)}) in QNode '{subject_qnode_key}' than batch_size_limit allows "\
                               f"({CohdTrapi.batch_size_limit}). IDs list will be truncated."
                 ids = ids[:CohdTrapi.batch_size_limit]
                 subject_qnode['ids'] = ids
             node_ids = node_ids.union(ids)
         if 'ids' in object_qnode:
+            object_qnode_set_interpretation = object_qnode.get('set_interpretation', CohdTrapi150.DEFAULT_SET_INTERPRETATION)
             if 'ids' not in subject_qnode:
                 # Swap the subj/obj mapping to concept1/2 if only the obj node has IDs
                 self._concept_1_is_subject_qnode = False
                 self._concept_1_qnode_key = object_qnode_key
                 concept_1_qnode = object_qnode
                 self._concept_2_qnode_key = subject_qnode_key
                 concept_2_qnode = subject_qnode
+                self._concept_1_set_interpretation = object_qnode_set_interpretation
+            else:
+                # Both QNodes have IDs specified
+                self._concept_2_set_interpretation = object_qnode_set_interpretation
+                # COHD only supports set_interpretation MANY when IDs are given on one QNode
+                if self._concept_1_set_interpretation == 'MANY' or self._concept_2_set_interpretation == 'MANY':
+                    self._valid_query = False
+                    msg = 'For COHD MCQ, only a single QNode is allowed to have IDs'
+                    self.log(msg, level=logging.ERROR)
+                    response = self._trapi_mini_response(TrapiStatusCode.NO_RESULTS, msg)
+                    self._invalid_query_response = response, 200
+                    return self._valid_query, self._invalid_query_response
 
             # Check the length of the IDs list is below the batch size limit
             ids = object_qnode['ids']
-            if len(ids) > CohdTrapi.batch_size_limit:
+            if object_qnode_set_interpretation == 'BATCH' and len(ids) > CohdTrapi.batch_size_limit:
                 # Warn the client and truncate the ids list
                 description = f"More IDs ({len(ids)}) in QNode '{object_qnode_key}' than batch_size_limit allows " \
                               f"({CohdTrapi.batch_size_limit}). IDs list will be truncated."
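For reference, a minimal sketch of the two QNode shapes that the set_interpretation handling above distinguishes. Field names follow TRAPI 1.5; the CURIEs, categories, and set identifier are illustrative only and are not taken from this patch:

    # Hypothetical BATCH QNode: each CURIE in 'ids' is processed as an independent query input.
    batch_qnode = {
        'ids': ['DOID:9351', 'DOID:0050117'],
        'categories': ['biolink:DiseaseOrPhenotypicFeature'],
        'set_interpretation': 'BATCH',
    }

    # Hypothetical MCQ QNode: 'ids' holds the set's identifier and 'member_ids' holds the members
    # to be scored jointly. PATCH 6/6 below reads ids[0] as the set ID and rejects MANY queries
    # that lack member_ids.
    mcq_qnode = {
        'ids': ['uuid:illustrative-set-id'],
        'member_ids': ['DOID:9351', 'DOID:0050117'],
        'categories': ['biolink:DiseaseOrPhenotypicFeature'],
        'set_interpretation': 'MANY',
    }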
@@ -493,7 +510,7 @@ def _interpret_query(self):
 
         # COHD queries require at least 1 node with a specified ID
         if len(node_ids) == 0:
             self._valid_query = False
-            msg = '{CohdTrapi._SERVICE_NAME} TRAPI requires at least one node to have an ID'
+            msg = f'{CohdTrapi._SERVICE_NAME} TRAPI requires at least one node to have an ID'
             self.log(msg, level=logging.ERROR)
             response = self._trapi_mini_response(TrapiStatusCode.NO_RESULTS, msg)
             self._invalid_query_response = response, 200
@@ -584,37 +601,38 @@ def _interpret_query(self):
         # Get subclasses for all CURIEs using Automat-Ubergraph
         descendant_ids = list()
         ancestor_dict = dict()
-        descendant_results = Ubergraph.get_descendants(ids, self._concept_1_qnode_categories)
-        if descendant_results is not None:
-            # Add new descendant CURIEs to the end of IDs list
-            descendants, ancestor_dict = descendant_results
-            descendant_ids = list(set(descendants.keys()) - set(ids))
-            if len(descendant_ids) > 0:
-                if (len(ids) + len(descendant_ids)) > CohdTrapi.batch_size_limit:
-                    # Only add up to the batch_size_limit
-                    n_to_add = CohdTrapi.batch_size_limit - len(ids)
-                    descendant_ids_ignored = descendant_ids[n_to_add:]
-                    descendant_ids = descendant_ids[:n_to_add]
-                    description = f"More descendants from Automat-Ubergraph KP for QNode '{self._concept_1_qnode_key}'"\
-                                  f"than batch_size_limit allows. Ignored: {descendant_ids_ignored}."
-                    self.log(description, level=logging.WARNING)
-
-                ids.extend(descendant_ids)
-                ids_deduped = SriNodeNormalizer.remove_equivalents(ids)
-                if ids_deduped is not None:
-                    ids = ids_deduped
-                else:
-                    self.log(f'Issue encountered with SRI Node Norm when removing equivalents', level=logging.WARNING)
-                self.log(f"Adding descendants from Automat-Ubergraph to QNode '{self._concept_1_qnode_key}': {descendant_ids}.",
-                         level=logging.INFO)
-            else:
-                self.log(f"No descendants found from Automat-Ubergraph for QNode '{self._concept_1_qnode_key}'.",
-                         level=logging.INFO)
-        else:
-            # Add a warning that we didn't get descendants from Automat-Ubergraph
-            self.log(f"Issue with retrieving descendants from Automat-Ubergraph for QNode '{self._concept_1_qnode_key}'",
-                     level=logging.WARNING)
+        # Don't get subclasses for MCQ because it could make the query very complex
+        if self._concept_1_set_interpretation != 'MANY':
+            descendant_results = Ubergraph.get_descendants(ids, self._concept_1_qnode_categories)
+            if descendant_results is not None:
+                # Add new descendant CURIEs to the end of IDs list
+                descendants, ancestor_dict = descendant_results
+                descendant_ids = list(set(descendants.keys()) - set(ids))
+                if len(descendant_ids) > 0:
+                    if (len(ids) + len(descendant_ids)) > CohdTrapi.batch_size_limit:
+                        # Only add up to the batch_size_limit
+                        n_to_add = CohdTrapi.batch_size_limit - len(ids)
+                        descendant_ids_ignored = descendant_ids[n_to_add:]
+                        descendant_ids = descendant_ids[:n_to_add]
+                        description = f"More descendants from Automat-Ubergraph KP for QNode '{self._concept_1_qnode_key}' "\
+                                      f"than batch_size_limit allows. Ignored: {descendant_ids_ignored}."
+                        self.log(description, level=logging.WARNING)
+
+                    ids.extend(descendant_ids)
+                    ids_deduped = SriNodeNormalizer.remove_equivalents(ids)
+                    if ids_deduped is not None:
+                        ids = ids_deduped
+                    else:
+                        self.log(f'Issue encountered with SRI Node Norm when removing equivalents', level=logging.WARNING)
+                    self.log(f"Adding descendants from Automat-Ubergraph to QNode '{self._concept_1_qnode_key}': {descendant_ids}.",
+                             level=logging.INFO)
+                else:
+                    self.log(f"No descendants found from Automat-Ubergraph for QNode '{self._concept_1_qnode_key}'.",
+                             level=logging.INFO)
+            else:
+                # Add a warning that we didn't get descendants from Automat-Ubergraph
+                self.log(f"Issue with retrieving descendants from Automat-Ubergraph for QNode '{self._concept_1_qnode_key}'",
+                         level=logging.WARNING)
 
         # Update the ancestor dictionary for concept 1
         self._concept_1_ancestor_dict = ancestor_dict
@@ -878,84 +896,158 @@ def _interpret_query(self):
         else:
             return self._valid_query, self._invalid_query_response
 
-    def operate(self):
-        """ Performs the COHD query and reasoning.
-
-        Returns
-        -------
-        Response message with JSON data in Translator Reasoner API Standard
-        """
-        # Check if the query is valid
-        if self._valid_query:
-            self._cohd_results = []
-            self._initialize_trapi_response()
-
-            for i, concept_1_omop_id in enumerate(self._concept_1_omop_ids):
-                # Limit the amount of time the TRAPI query runs for
-                ellapsed_time = (datetime.now() - self._start_time).total_seconds()
-                if ellapsed_time > self._time_limit:
-                    skipped_curies = [self._kg_omop_curie_map[x] for x in self._concept_1_omop_ids[i:]]
-                    description = f'Maximum time limit {self._time_limit} sec reached before all input IDs processed. '\
-                                  f'Skipped IDs: {skipped_curies}'
-                    self.log(description, level=logging.WARNING)
-                    break
-
-                new_cohd_results = list()
-                if self._concept_2_omop_ids is None:
-                    # Node 2's IDs were not specified
-                    if self._domain_class_pairs:
-                        # Node 2's category was specified. Query associations between Node 1 and the requested
-                        # categories (domains)
-                        for domain_id, concept_class_id in self._domain_class_pairs:
-                            json_results = query_cohd_mysql.query_trapi(concept_id_1=concept_1_omop_id,
-                                                                        concept_id_2=None,
-                                                                        dataset_id=self._dataset_id,
-                                                                        domain_id=domain_id,
-                                                                        concept_class_id=concept_class_id,
-                                                                        ln_ratio_sign=self._association_direction,
-                                                                        confidence=self._confidence_interval,
-                                                                        bypass=self._bypass_cache)
-                            if json_results:
-                                new_cohd_results.extend(json_results['results'])
-                    else:
-                        # No category (domain) was specified for Node 2. Query the associations between Node 1 and all
-                        # domains
-                        json_results = query_cohd_mysql.query_trapi(concept_id_1=concept_1_omop_id, concept_id_2=None,
-                                                                    dataset_id=self._dataset_id, domain_id=None,
-                                                                    ln_ratio_sign=self._association_direction,
-                                                                    confidence=self._confidence_interval,
-                                                                    bypass=self._bypass_cache)
-                        if json_results:
-                            new_cohd_results.extend(json_results['results'])
-
-                else:
-                    # Concept 2's IDs were specified. Query Concept 1 against all IDs for Concept 2
-                    for concept_2_id in self._concept_2_omop_ids:
-                        json_results = query_cohd_mysql.query_trapi(concept_id_1=concept_1_omop_id,
-                                                                    concept_id_2=concept_2_id,
-                                                                    dataset_id=self._dataset_id, domain_id=None,
-                                                                    confidence=self._confidence_interval,
-                                                                    bypass=self._bypass_cache)
-                        if json_results:
-                            new_cohd_results.extend(json_results['results'])
-
-                # Results within each query call should be sorted, but still need to be sorted across query calls
-                new_cohd_results = sort_cohd_results(new_cohd_results)
-
-                # Convert results from COHD format to Translator Reasoner standard
-                results_limit_reached = self._add_results_to_trapi(new_cohd_results)
-
-                # Log warnings and stop when results limits reached
-                if results_limit_reached:
-                    curie = self._kg_omop_curie_map[concept_1_omop_id]
-                    self.log(f'Results limit ({self._max_results_per_input}) reached for {curie}. '
-                             'There may be additional associations.', level=logging.WARNING)
-                if len(self._results) >= self._max_results:
-                    if i < len(self._concept_1_omop_ids) - 1:
-                        skipped_ids = [self._kg_omop_curie_map[x] for x in self._concept_1_omop_ids[i+1:]]
-                        self.log(f'Total results limit ({self._max_results}) reached. Skipped {skipped_ids}',
-                                 level=logging.WARNING)
-                        break
+    def operate_batch(self):
+        for i, concept_1_omop_id in enumerate(self._concept_1_omop_ids):
+            # Limit the amount of time the TRAPI query runs for
+            elapsed_time = (datetime.now() - self._start_time).total_seconds()
+            if elapsed_time > self._time_limit:
+                skipped_curies = [self._kg_omop_curie_map[x] for x in self._concept_1_omop_ids[i:]]
+                description = f'Maximum time limit {self._time_limit} sec reached before all input IDs processed. '\
+                              f'Skipped IDs: {skipped_curies}'
+                self.log(description, level=logging.WARNING)
+                break
+
+            new_cohd_results = list()
+            if self._concept_2_omop_ids is None:
+                # Node 2's IDs were not specified
+                if self._domain_class_pairs:
+                    # Node 2's category was specified. Query associations between Node 1 and the requested
+                    # categories (domains)
+                    for domain_id, concept_class_id in self._domain_class_pairs:
+                        json_results = query_cohd_mysql.query_trapi(concept_id_1=concept_1_omop_id,
+                                                                    concept_id_2=None,
+                                                                    dataset_id=self._dataset_id,
+                                                                    domain_id=domain_id,
+                                                                    concept_class_id=concept_class_id,
+                                                                    ln_ratio_sign=self._association_direction,
+                                                                    confidence=self._confidence_interval,
+                                                                    bypass=self._bypass_cache)
+                        if json_results:
+                            new_cohd_results.extend(json_results['results'])
+                else:
+                    # No category (domain) was specified for Node 2. Query the associations between Node 1 and all
+                    # domains
+                    json_results = query_cohd_mysql.query_trapi(concept_id_1=concept_1_omop_id, concept_id_2=None,
+                                                                dataset_id=self._dataset_id, domain_id=None,
+                                                                ln_ratio_sign=self._association_direction,
+                                                                confidence=self._confidence_interval,
+                                                                bypass=self._bypass_cache)
+                    if json_results:
+                        new_cohd_results.extend(json_results['results'])
+
+            else:
+                # Concept 2's IDs were specified. Query Concept 1 against all IDs for Concept 2
+                for concept_2_id in self._concept_2_omop_ids:
+                    json_results = query_cohd_mysql.query_trapi(concept_id_1=concept_1_omop_id,
+                                                                concept_id_2=concept_2_id,
+                                                                dataset_id=self._dataset_id, domain_id=None,
+                                                                confidence=self._confidence_interval,
+                                                                bypass=self._bypass_cache)
+                    if json_results:
+                        new_cohd_results.extend(json_results['results'])
+
+            # Results within each query call should be sorted, but still need to be sorted across query calls
+            new_cohd_results = sort_cohd_results(new_cohd_results)
+
+            # Convert results from COHD format to Translator Reasoner standard
+            results_limit_reached = self._add_results_to_trapi(new_cohd_results)
+
+            # Log warnings and stop when results limits reached
+            if results_limit_reached:
+                curie = self._kg_omop_curie_map[concept_1_omop_id]
+                self.log(f'Results limit ({self._max_results_per_input}) reached for {curie}. '
+                         'There may be additional associations.', level=logging.WARNING)
+            if len(self._results) >= self._max_results:
+                if i < len(self._concept_1_omop_ids) - 1:
+                    skipped_ids = [self._kg_omop_curie_map[x] for x in self._concept_1_omop_ids[i+1:]]
+                    self.log(f'Total results limit ({self._max_results}) reached. Skipped {skipped_ids}',
+                             level=logging.WARNING)
+                break
+
+    def operate_mcq(self):
+        for i, concept_1_omop_id in enumerate(self._concept_1_omop_ids):
+            # Limit the amount of time the TRAPI query runs for
+            elapsed_time = (datetime.now() - self._start_time).total_seconds()
+            if elapsed_time > self._time_limit:
+                skipped_curies = [self._kg_omop_curie_map[x] for x in self._concept_1_omop_ids[i:]]
+                description = f'Maximum time limit {self._time_limit} sec reached before all input IDs processed. '\
+                              f'Skipped IDs: {skipped_curies}'
+                self.log(description, level=logging.WARNING)
+                break
+
+            new_cohd_results = list()
+            if self._concept_2_omop_ids is None:
+                # Node 2's IDs were not specified
+                if self._domain_class_pairs:
+                    # Node 2's category was specified. Query associations between Node 1 and the requested
+                    # categories (domains)
+                    for domain_id, concept_class_id in self._domain_class_pairs:
+                        json_results = query_cohd_mysql.query_trapi(concept_id_1=concept_1_omop_id,
+                                                                    concept_id_2=None,
+                                                                    dataset_id=self._dataset_id,
+                                                                    domain_id=domain_id,
+                                                                    concept_class_id=concept_class_id,
+                                                                    ln_ratio_sign=self._association_direction,
+                                                                    confidence=self._confidence_interval,
+                                                                    bypass=self._bypass_cache)
+                        if json_results:
+                            new_cohd_results.extend(json_results['results'])
+                else:
+                    # No category (domain) was specified for Node 2. Query the associations between Node 1 and all
+                    # domains
+                    json_results = query_cohd_mysql.query_trapi(concept_id_1=concept_1_omop_id, concept_id_2=None,
+                                                                dataset_id=self._dataset_id, domain_id=None,
+                                                                ln_ratio_sign=self._association_direction,
+                                                                confidence=self._confidence_interval,
+                                                                bypass=self._bypass_cache)
+                    if json_results:
+                        new_cohd_results.extend(json_results['results'])
+
+            else:
+                # Concept 2's IDs were specified. Query Concept 1 against all IDs for Concept 2
+                for concept_2_id in self._concept_2_omop_ids:
+                    json_results = query_cohd_mysql.query_trapi(concept_id_1=concept_1_omop_id,
+                                                                concept_id_2=concept_2_id,
+                                                                dataset_id=self._dataset_id, domain_id=None,
+                                                                confidence=self._confidence_interval,
+                                                                bypass=self._bypass_cache)
+                    if json_results:
+                        new_cohd_results.extend(json_results['results'])
+
+            # Results within each query call should be sorted, but still need to be sorted across query calls
+            new_cohd_results = sort_cohd_results(new_cohd_results)
+
+            # Convert results from COHD format to Translator Reasoner standard
+            results_limit_reached = self._add_results_to_trapi(new_cohd_results)
+
+            # Log warnings and stop when results limits reached
+            if results_limit_reached:
+                curie = self._kg_omop_curie_map[concept_1_omop_id]
+                self.log(f'Results limit ({self._max_results_per_input}) reached for {curie}. '
+                         'There may be additional associations.', level=logging.WARNING)
+            if len(self._results) >= self._max_results:
+                if i < len(self._concept_1_omop_ids) - 1:
+                    skipped_ids = [self._kg_omop_curie_map[x] for x in self._concept_1_omop_ids[i+1:]]
+                    self.log(f'Total results limit ({self._max_results}) reached. Skipped {skipped_ids}',
+                             level=logging.WARNING)
+                break
+
     def operate(self):
         """ Performs the COHD query and reasoning.

         Returns
         -------
         Response message with JSON data in Translator Reasoner API Standard
         """
         # Check if the query is valid
         if self._valid_query:
             self._cohd_results = []
             self._initialize_trapi_response()

+            if self._concept_1_set_interpretation == 'BATCH':
+                self.operate_batch()
+            elif self._concept_1_set_interpretation == 'MANY':
+                self.operate_mcq()

             return self._finalize_trapi_response()
         else:
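For orientation before the MCQ patch that follows: as _interpret_query below reads it, a multicurie query node names the set in ids (the first entry becomes the set ID), lists the members in member_ids, and supplies biolink:member_of edges in the input knowledge graph. A sketch of the expected shape; the CURIEs and edge keys are made up for illustration:

    mcq_query_node = {
        'ids': ['uuid:set001'],                             # the MCQ set ID (first entry is used)
        'member_ids': ['MONDO:0004979', 'MONDO:0005002'],   # concepts in the set
        'set_interpretation': 'MANY',
        'categories': ['biolink:DiseaseOrPhenotypicFeature']
    }
    # The input knowledge graph ties each member to the set with a biolink:member_of edge
    input_kg_edges = {
        'e0': {'subject': 'MONDO:0004979', 'predicate': 'biolink:member_of', 'object': 'uuid:set001'},
        'e1': {'subject': 'MONDO:0005002', 'predicate': 'biolink:member_of', 'object': 'uuid:set001'},
    }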

From 4267ea61e6dbe162ab7a9b046649674216d02112 Mon Sep 17 00:00:00 2001
From: caseyta
Date: Mon, 3 Jun 2024 10:58:16 +0000
Subject: [PATCH 6/6] MCQ TRAPI implementation

---
 cohd/cohd_trapi.py       |  25 ++-
 cohd/cohd_trapi_15.py    | 382 ++++++++++++++++++++++++++++++---------
 cohd/query_cohd_mysql.py | 137 ++++----------
 3 files changed, 359 insertions(+), 185 deletions(-)

diff --git a/cohd/cohd_trapi.py b/cohd/cohd_trapi.py
index 092391c..e38bf2a 100644
--- a/cohd/cohd_trapi.py
+++ b/cohd/cohd_trapi.py
@@ -27,6 +27,7 @@ class TrapiStatusCode(Enum):
     UNSUPPORTED_ATTR_CONSTRAINT = 'UnsupportedAttributeConstraint'
     UNSUPPORTED_QUAL_CONSTRAINT = 'UnsupportedQualifierConstraint'
     UNSUPPORTED_SET_INTERPRETATION = 'UnsupportedSetInterpretation'
+    MISSING_MEMBER_IDS = 'MissingMemberIDs'
 
 
 class CohdTrapi(ABC):
@@ -210,6 +211,28 @@ def criteria_confidence(cohd_result, confidence, threshold=CohdTrapi.default_ln_
     return True
 
 
+def criteria_mcq_score(cohd_result, threshold=CohdTrapi.default_ln_ratio_ci_thresohld):
+    """ Checks the score derived from the observed-expected frequency ratio confidence interval against the
+    threshold. Only applies to results that carry an ln_ratio_score; returns False when the score is missing.
+
+    Parameters
+    ----------
+    cohd_result
+    threshold
+
+    Returns
+    -------
+    True if significant
+    """
+    if 'ln_ratio_score' in cohd_result:
+        # obsExpFreq
+        return abs(cohd_result['ln_ratio_score']) >= threshold
+    else:
+        # Missing the score to filter on
+        return False
+
+
 mappings_domain_ontology = {
     '_DEFAULT': ['ICD9CM', 'RxNorm', 'UMLS', 'DOID', 'MONDO']
 }
@@ -318,7 +341,7 @@ def sort_cohd_results(cohd_results, sort_field='ln_ratio_ci', ascending=False):
     if cohd_results is None or len(cohd_results) == 0:
         return cohd_results
 
-    if sort_field in ['p-value', 'ln_ratio', 'relative_frequency']:
+    if sort_field in ['p-value', 'ln_ratio', 'relative_frequency', 'ln_ratio_score']:
         sort_values = [x[sort_field] for x in cohd_results]
     elif sort_field == 'ln_ratio_ci':
         sort_values = [score_cohd_result(x) for x in cohd_results]
diff --git a/cohd/cohd_trapi_15.py b/cohd/cohd_trapi_15.py
index a5cfdb7..2cb93cd 100644
--- a/cohd/cohd_trapi_15.py
+++ b/cohd/cohd_trapi_15.py
@@ -83,6 +83,7 @@ def __init__(self, request):
             'nodes': {},
             'edges': {}
         }
+        self._auxiliary_graphs = {}
         # Track in the KG which CURIEs are being used by which OMOP IDs (may be more than 1 OMOP ID)
         self._kg_curie_omop_use = defaultdict(list)
         # Track mappings from OMOP to Biolink used for this KG
@@ -201,7 +202,7 @@ def _check_query_input(self):
             return self._valid_query, self._invalid_query_response
 
         # Check to see if cohd doesn't recognize any properties
-        qnode_properties = {'ids','categories', 'set_interpretation', 'constraints'}
+        qnode_properties = {'ids','categories', 'set_interpretation', 'constraints', 'member_ids'}
         unrec_properties = (set(nodes[0].keys()) | (set(nodes[1].keys()))) - qnode_properties
         if unrec_properties:
             description = f'{CohdTrapi._SERVICE_NAME} does not recognize the following node properties: ' \
@@ -596,7 +597,28 @@ def _interpret_query(self):
         # Get concept_id_1. QNode IDs is a list.
         self._concept_1_omop_ids = list()
         found = False
-        ids = list(set(concept_1_qnode['ids']))  # remove duplicate CURIEs
+        if self._concept_1_set_interpretation == 'BATCH':
+            ids = list(set(concept_1_qnode['ids']))  # remove duplicate CURIEs
+        elif self._concept_1_set_interpretation == 'MANY':
+            member_ids = concept_1_qnode.get('member_ids')
+            if not member_ids:
+                # Missing required member_ids for MCQ
+                self._valid_query = False
+                description = 'set_interpretation: MANY but no member_ids'
+                response = self._trapi_mini_response(TrapiStatusCode.MISSING_MEMBER_IDS, description)
+                self._invalid_query_response = response, 200
+                return self._valid_query, self._invalid_query_response
+            ids = list(set(concept_1_qnode['member_ids']))  # remove duplicate CURIEs
+
+            # Get the MCQ set ID
+            self._mcq_set_id = concept_1_qnode['ids'][0]
+
+            # Copy over the knowledge graph from the input message
+            self._knowledge_graph = self._json_data['message']['knowledge_graph']
+
+            # Find the member_of edges
+            self._member_of_edges = {edge['subject']: edge_id
+                                     for edge_id, edge in self._knowledge_graph['edges'].items() if
+                                     edge['predicate'] == 'biolink:member_of' and edge['object'] == self._mcq_set_id}
 
         # Get subclasses for all CURIEs using Automat-Ubergraph
         descendant_ids = list()
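The member_of lookup built in _interpret_query can be exercised in isolation; with the toy edges from the sketch above, the comprehension maps each member CURIE to the edge that attaches it to the set:

    mcq_set_id = 'uuid:set001'
    member_of_edges = {edge['subject']: edge_id
                       for edge_id, edge in input_kg_edges.items()
                       if edge['predicate'] == 'biolink:member_of' and edge['object'] == mcq_set_id}
    # member_of_edges == {'MONDO:0004979': 'e0', 'MONDO:0005002': 'e1'}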
@@ -854,21 +876,26 @@
 
         # Criteria for returning results
         self._criteria = []
 
-        # Add a criterion for minimum co-occurrence
-        if self._min_cooccurrence > 0:
-            self._criteria.append(ResultCriteria(function=criteria_min_cooccurrence,
-                                                 kargs={'cooccurrence': self._min_cooccurrence}))
-
-        # Get query_option for threshold. Don't use filter if not specified (i.e., no default option for threshold)
-        self._threshold = self._query_options.get('threshold')
-        if self._threshold is not None and isinstance(self._threshold, Number):
-            self._criteria.append(ResultCriteria(function=criteria_threshold,
-                                                 kargs={'threshold': self._threshold}))
-
-        # If the method is obsExpRatio, add a criteria for confidence interval
-        if self._method.lower() == 'obsexpratio' and self._confidence_interval > 0:
-            self._criteria.append(ResultCriteria(function=criteria_confidence,
-                                                 kargs={'confidence': self._confidence_interval}))
+        if self._concept_1_set_interpretation == 'MANY':
+            self._threshold = self._query_options.get('threshold', CohdTrapi.default_ln_ratio_ci_thresohld)
+            self._criteria.append(ResultCriteria(function=criteria_mcq_score,
+                                                 kargs={'threshold': self._threshold}))
+        else:
+            # Add a criterion for minimum co-occurrence
+            if self._min_cooccurrence > 0:
+                self._criteria.append(ResultCriteria(function=criteria_min_cooccurrence,
+                                                     kargs={'cooccurrence': self._min_cooccurrence}))
+
+            # Get query_option for threshold. Don't use a filter if not specified (i.e., no default option for
+            # threshold)
+            self._threshold = self._query_options.get('threshold')
+            if self._threshold is not None and isinstance(self._threshold, Number):
+                self._criteria.append(ResultCriteria(function=criteria_threshold,
+                                                     kargs={'threshold': self._threshold}))
+
+            # If the method is obsExpRatio, add a criterion for the confidence interval
+            if self._method.lower() == 'obsexpratio' and self._confidence_interval > 0:
+                self._criteria.append(ResultCriteria(function=criteria_confidence,
+                                                     kargs={'confidence': self._confidence_interval}))
 
         if self._dataset_auto:
             # Automatically select the dataset based on which data types being queried
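The criteria list assembled here is later applied with an all() check (see _add_mcq_result in this patch). A reduced sketch of that pattern, using a stand-in ResultCriteria that mirrors the function/kargs interface the patch relies on:

    class ResultCriteria:
        """Stand-in for the ResultCriteria defined in cohd_trapi.py (illustrative only)."""
        def __init__(self, function, kargs):
            self.function = function
            self.kargs = kargs

        def check(self, cohd_result):
            return self.function(cohd_result, **self.kargs)

    criteria = [ResultCriteria(function=lambda r, threshold: abs(r.get('ln_ratio_score', 0)) >= threshold,
                               kargs={'threshold': 0.5})]
    result = {'concept_id_2': 313217, 'ln_ratio_score': 1.2}
    assert all(c.check(result) for c in criteria)  # result passes the score threshold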
@@ -965,72 +992,41 @@ def operate_batch(self):
                 break
 
     def operate_mcq(self):
-        for i, concept_1_omop_id in enumerate(self._concept_1_omop_ids):
-            # Limit the amount of time the TRAPI query runs for
-            elapsed_time = (datetime.now() - self._start_time).total_seconds()
-            if elapsed_time > self._time_limit:
-                skipped_curies = [self._kg_omop_curie_map[x] for x in self._concept_1_omop_ids[i:]]
-                description = f'Maximum time limit {self._time_limit} sec reached before all input IDs processed. '\
-                              f'Skipped IDs: {skipped_curies}'
-                self.log(description, level=logging.WARNING)
-                break
-
-            new_cohd_results = list()
-            if self._concept_2_omop_ids is None:
-                # Node 2's IDs were not specified
-                if self._domain_class_pairs:
-                    # Node 2's category was specified. Query associations between Node 1 and the requested
-                    # categories (domains)
-                    for domain_id, concept_class_id in self._domain_class_pairs:
-                        json_results = query_cohd_mysql.query_trapi(concept_id_1=concept_1_omop_id,
-                                                                    concept_id_2=None,
-                                                                    dataset_id=self._dataset_id,
-                                                                    domain_id=domain_id,
-                                                                    concept_class_id=concept_class_id,
-                                                                    ln_ratio_sign=self._association_direction,
-                                                                    confidence=self._confidence_interval,
-                                                                    bypass=self._bypass_cache)
-                        if json_results:
-                            new_cohd_results.extend(json_results['results'])
-                else:
-                    # No category (domain) was specified for Node 2. Query the associations between Node 1 and all
-                    # domains
-                    json_results = query_cohd_mysql.query_trapi(concept_id_1=concept_1_omop_id, concept_id_2=None,
-                                                                dataset_id=self._dataset_id, domain_id=None,
-                                                                ln_ratio_sign=self._association_direction,
-                                                                confidence=self._confidence_interval,
-                                                                bypass=self._bypass_cache)
-                    if json_results:
-                        new_cohd_results.extend(json_results['results'])
-            else:
-                # Concept 2's IDs were specified. Query Concept 1 against all IDs for Concept 2
-                for concept_2_id in self._concept_2_omop_ids:
-                    json_results = query_cohd_mysql.query_trapi(concept_id_1=concept_1_omop_id,
-                                                                concept_id_2=concept_2_id,
-                                                                dataset_id=self._dataset_id, domain_id=None,
-                                                                confidence=self._confidence_interval,
-                                                                bypass=self._bypass_cache)
-                    if json_results:
-                        new_cohd_results.extend(json_results['results'])
-
-            # Results within each query call should be sorted, but still need to be sorted across query calls
-            new_cohd_results = sort_cohd_results(new_cohd_results)
-
-            # Convert results from COHD format to Translator Reasoner standard
-            results_limit_reached = self._add_results_to_trapi(new_cohd_results)
-
-            # Log warnings and stop when results limits reached
-            if results_limit_reached:
-                curie = self._kg_omop_curie_map[concept_1_omop_id]
-                self.log(f'Results limit ({self._max_results_per_input}) reached for {curie}. '
-                         'There may be additional associations.', level=logging.WARNING)
-            if len(self._results) >= self._max_results:
-                if i < len(self._concept_1_omop_ids) - 1:
-                    skipped_ids = [self._kg_omop_curie_map[x] for x in self._concept_1_omop_ids[i+1:]]
-                    self.log(f'Total results limit ({self._max_results}) reached. Skipped {skipped_ids}',
-                             level=logging.WARNING)
-                break
+        set_results = list()
+        single_results = dict()
+        if self._domain_class_pairs:
+            # Node 2's category was specified. Query associations between Node 1 and the requested
+            # categories (domains)
+            for domain_id, concept_class_id in self._domain_class_pairs:
+                new_results = query_cohd_mysql.query_trapi_mcq(concept_ids=self._concept_1_omop_ids,
+                                                               dataset_id=self._dataset_id,
+                                                               domain_id=domain_id,
+                                                               concept_class_id=concept_class_id,
+                                                               ln_ratio_sign=self._association_direction,
+                                                               confidence=self._confidence_interval,
+                                                               bypass=self._bypass_cache)
+                new_set_results, new_single_results = new_results
+                if new_set_results:
+                    set_results.extend(new_set_results)
+                    single_results.update(new_single_results)
+        else:
+            # No category (domain) was specified for Node 2. Query the associations between Node 1 and all
+            # domains
+            new_results = query_cohd_mysql.query_trapi_mcq(concept_ids=self._concept_1_omop_ids,
+                                                           dataset_id=self._dataset_id, domain_id=None,
+                                                           ln_ratio_sign=self._association_direction,
+                                                           confidence=self._confidence_interval,
+                                                           bypass=self._bypass_cache)
+            new_set_results, new_single_results = new_results
+            if new_set_results:
+                set_results.extend(new_set_results)
+                single_results.update(new_single_results)
+
+        # Results within each query call should be sorted, but still need to be sorted across query calls
+        set_results = sort_cohd_results(set_results, sort_field='ln_ratio_score')
+
+        # Convert results from COHD format to Translator Reasoner standard
+        self._add_mcq_results_to_trapi(set_results, single_results)
 
     def operate(self):
         """ Performs the COHD query and reasoning.
@@ -1119,6 +1115,87 @@ def _add_cohd_result(self, cohd_result, criteria):
         score = score_cohd_result(cohd_result)
         self._add_result(node_1['primary_curie'], node_2['primary_curie'], kg_edge_id, score)
 
+    def _add_aux_graph(self, edges, attributes=None):
+        if attributes is None:
+            attributes = list()
+
+        ag_id = f'ag{(len(self._auxiliary_graphs)+1):06d}'
+        self._auxiliary_graphs[ag_id] = {
+            'edges': edges,
+            'attributes': attributes
+        }
+        return ag_id
+
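Each call to _add_aux_graph mints a sequential ID (ag000001, ag000002, ...) and stores an edge list. A trimmed module-level illustration of the resulting structure, with made-up edge IDs:

    auxiliary_graphs = {}

    def add_aux_graph(edges, attributes=None):
        # Sequential zero-padded IDs, matching the f'ag{...:06d}' scheme above
        ag_id = f'ag{(len(auxiliary_graphs) + 1):06d}'
        auxiliary_graphs[ag_id] = {'edges': edges, 'attributes': attributes or []}
        return ag_id

    ag = add_aux_graph(['ke000007', 'e0'])  # one pairwise association edge + one member_of edge
    # auxiliary_graphs == {'ag000001': {'edges': ['ke000007', 'e0'], 'attributes': []}}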
+    def _add_mcq_result(self, set_result, single_results, criteria):
+        """ Adds an MCQ COHD result. The set-level result is always added to the knowledge graph. If it passes all
+        criteria, it is also added to the results.
+
+        Parameters
+        ----------
+        set_result
+        single_results
+        criteria: List - [ResultCriteria]
+        """
+        assert set_result is not None and 'concept_id_2' in set_result, \
+            'Translator::KnowledgeGraph::_add_mcq_result() - Bad set_result'
+
+        assert single_results, 'Translator::KnowledgeGraph::_add_mcq_result() - Bad single_results'
+
+        # Check if result passes all filters before adding
+        if criteria is not None:
+            if not all([c.check(set_result) for c in criteria]):
+                return
+
+        # Get node for concept 2
+        concept_2_id = set_result['concept_id_2']
+        concept_2_name = set_result.get('concept_2_name')
+        concept_2_domain = set_result.get('concept_2_domain')
+        concept_2_class_id = set_result.get('concept_2_class_id')
+        node_2 = self._get_kg_node(concept_2_id, concept_2_name, concept_2_domain, concept_2_class_id,
+                                   query_node_categories=self._concept_2_qnode_categories)
+
+        if not node_2.get('query_category_compliant', False) or \
+                (self._biolink_only and not node_2.get('biolink_compliant', False)):
+            # Only include results when node_2 maps to biolink and matches the queried category
+            return
+
+        # Only allow one OMOP ID to use a CURIE. Will allow the first result using a given CURIE to go through. Since
+        # results are in descending order, will give priority to the OMOP ID with the strongest association
+        concept_2_curie = node_2['primary_curie']
+        if (self._kg_curie_omop_use[concept_2_curie] and
+                concept_2_id not in self._kg_curie_omop_use[concept_2_curie]):
+            return
+
+        # Add nodes and edge to knowledge graph
+        is_subject = self._query_edge['subject'] != self._concept_1_qnode_key
+        kg_node_2, kg_set_edge, kg_set_edge_id = self._add_kg_set_edge(node_2, is_subject, set_result)
+
+        # Add to results
+        score = set_result['ln_ratio_score']
+        self._add_result(self._mcq_set_id, concept_2_curie, kg_set_edge_id, score)
+
+        # Add single result edges and auxiliary graphs
+        support_graphs = list()
+        for sr in single_results:
+            concept_1_id = sr['concept_id_1']
+            node_1 = self._get_kg_node(concept_1_id, query_node_categories=self._concept_1_qnode_categories)
+            if is_subject:
+                subject_node = node_2
+                object_node = node_1
+            else:
+                subject_node = node_1
+                object_node = node_2
+            kg_node_1, kg_node_2, kg_edge, kg_edge_id = self._add_kg_edge(subject_node, object_node, sr)
+
+            member_of_edge_id = self._member_of_edges[node_1['primary_curie']]
+            ag_id = self._add_aux_graph([kg_edge_id, member_of_edge_id])
+            support_graphs.append(ag_id)
+
+        # Add support graphs to the set edge
+        kg_set_edge['attributes'].append({
+            "attribute_source": CohdTrapi._INFORES_ID,
+            "attribute_type_id": "biolink:support_graphs",
+            "values": support_graphs
+        })
+
     def _add_result(self, kg_node_1_id, kg_node_2_id, kg_edge_id, score):
         """ Adds a knowledge graph edge to the results list
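Putting the pieces together: every set member contributes one pairwise edge, each pairwise edge is bundled with that member's member_of edge into an auxiliary graph, and the set-level edge points at all of them through a biolink:support_graphs attribute. Schematically, with illustrative IDs (the answer CURIE and edge keys are made up):

    set_edge = {
        'subject': 'uuid:set001',                     # the MCQ set node
        'predicate': 'biolink:associated_with',
        'object': 'RXNORM:861007',                    # the answer concept
        'attributes': [{
            'attribute_type_id': 'biolink:support_graphs',
            'values': ['ag000001', 'ag000002']        # one aux graph per set member
        }]
    }
    # ag000001 -> {'edges': ['ke000001', 'e0']}  # member 1's pairwise edge + its member_of edge
    # ag000002 -> {'edges': ['ke000002', 'e1']}  # member 2's pairwise edge + its member_of edge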
@@ -1765,6 +1842,128 @@ def _add_kg_edge(self, node_1, node_2, cohd_result):
         self._knowledge_graph['edges'][ke_id] = kg_edge
 
         return kg_node_1, kg_node_2, kg_edge, ke_id
+
+    def _add_kg_set_edge(self, node_2, is_subject, set_result):
+        """ Adds the set edge to the knowledge graph
+
+        Parameters
+        ----------
+        node_2: Answer node
+        is_subject: True if the answer node should be the subject node
+        set_result: COHD set result - data gets added to edge
+
+        Returns
+        -------
+        kg_node_2, kg_edge, ke_id
+        """
+        # Add nodes to knowledge graph
+        kg_node_2 = self._add_internal_node_to_kg(node_2)
+
+        # Mint a new identifier
+        ke_id = self._get_new_kg_edge_id()
+
+        # Determine identifiers for sub and obj
+        if is_subject:
+            curie_subj = node_2['primary_curie']
+            curie_obj = self._mcq_set_id
+        else:
+            curie_obj = node_2['primary_curie']
+            curie_subj = self._mcq_set_id
+
+        # Add source retrieval
+        sources = [
+            {
+                'resource_id': 'infores:columbia-cdw-ehr-data',
+                'resource_role': 'supporting_data_source',
+            },
+            {
+                'resource_id': CohdTrapi._INFORES_ID,
+                'resource_role': 'primary_knowledge_source',
+                'upstream_resource_ids': ['infores:columbia-cdw-ehr-data']
+            },
+        ]
+
+        # Add properties from COHD results to the edge attributes
+        attributes = [
+            # Knowledge Level
+            {
+                'attribute_type_id': 'biolink:knowledge_level',
+                'value': 'statistical_association',
+                'attribute_source': CohdTrapi._INFORES_ID
+            },
+            # Agent Type
+            {
+                'attribute_type_id': 'biolink:agent_type',
+                'value': 'computational_model',
+                'attribute_source': CohdTrapi._INFORES_ID
+            },
+            # Observed-expected frequency ratio analysis
+            {
+                "attribute_source": CohdTrapi._INFORES_ID,
+                "attribute_type_id": "biolink:has_supporting_study_result",
+                "description": "A study result describing an observed-expected frequency analysis on a single "
+                               "pair of concepts",
+                "value": set_result['ln_ratio_score'],
+                "value_type_id": "biolink:ObservedExpectedFrequencyAnalysisResult",
+                'value_url': 'https://github.com/NCATSTranslator/Translator-All/wiki/COHD-KP',
+                "attributes": [
+                    {
+                        'attribute_type_id': 'biolink:ln_ratio',
+                        'original_attribute_name': 'ln_ratio',
+                        'value': set_result['ln_ratio_score'],
+                        'value_type_id': 'EDAM:data_1772',  # Score
+                        'attribute_source': CohdTrapi._INFORES_ID,
+                        'description': 'Observed-expected frequency ratio.'
+                    },
+                    {
+                        'attribute_type_id': 'biolink:supporting_data_set',  # Database ID
+                        'original_attribute_name': 'dataset_id',
+                        'value': f"COHD:dataset_{set_result['dataset_id']}",
+                        'value_type_id': 'EDAM:data_1048',  # Database ID
+                        'attribute_source': CohdTrapi._INFORES_ID,
+                        'description': f'Dataset ID within {CohdTrapi._SERVICE_NAME}'
+                    },
+                    # Knowledge Level
+                    {
+                        'attribute_type_id': 'biolink:knowledge_level',
+                        'value': 'statistical_association',
+                        'attribute_source': CohdTrapi._INFORES_ID
+                    },
+                    # Agent Type
+                    {
+                        'attribute_type_id': 'biolink:agent_type',
+                        'value': 'computational_model',
+                        'attribute_source': CohdTrapi._INFORES_ID
+                    }
+                ]
+            },
+        ]
+
+        # Determine which predicate to use
+        predicate = CohdTrapi.default_predicate
+        if self._kg_edge_predicate is not None:
+            predicate = self._kg_edge_predicate
+        else:
+            ln_ratio = set_result['ln_ratio_score']
+            if ln_ratio > 0:
+                predicate = self.default_positive_predicate
+            elif ln_ratio < 0:
+                predicate = self.default_negative_predicate
+            else:
+                predicate = self.default_predicate
+
+        # Set the knowledge graph edge properties
+        kg_edge = {
+            'predicate': predicate,
+            'subject': curie_subj,
+            'object': curie_obj,
+            'attributes': attributes,
+            'sources': sources
+        }
+
+        # Add the new edge
+        self._knowledge_graph['edges'][ke_id] = kg_edge
+
+        return kg_node_2, kg_edge, ke_id
 
     def _add_kg_edge_subclass_of(self, descendant_node_id, ancestor_node_id):
         """ Adds the biolink:subclass_of edge to the knowledge graph
@@ -1820,7 +2019,7 @@ def _initialize_trapi_response(self):
         }
 
     def _add_results_to_trapi(self, new_cohd_results):
-        """ Creates the response message with JSON data in Reasoner Std API format
+        """ Add results
 
         Returns
         -------
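As with the pairwise edges, the set edge's predicate is chosen from the sign of the score. A compact version of that decision; the actual default_* predicate values are defined elsewhere in CohdTrapi, so the Biolink CURIEs below are placeholders:

    def pick_predicate(ln_ratio_score,
                       positive='biolink:positively_correlated_with',   # placeholder CURIEs
                       negative='biolink:negatively_correlated_with',
                       default='biolink:correlated_with'):
        if ln_ratio_score > 0:
            return positive
        if ln_ratio_score < 0:
            return negative
        return default

    assert pick_predicate(1.8) == 'biolink:positively_correlated_with'
    assert pick_predicate(0.0) == 'biolink:correlated_with'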
@@ -1840,6 +2039,26 @@ def _add_results_to_trapi(self, new_cohd_results):
             self._add_cohd_result(result, self._criteria)
         return False
 
+    def _add_mcq_results_to_trapi(self, new_set_results, new_single_results):
+        """ Add set results and their corresponding single results
+
+        Returns
+        -------
+        boolean: True if results limit reached, otherwise False
+        """
+        n_prior_results = len(self._results)
+        for _, set_result in enumerate(new_set_results):
+            # Don't add more than the maximum number of results per input ID
+            if len(self._results) - n_prior_results >= self._max_results_per_input:
+                return True
+            # Don't add more than the maximum total number of results
+            if len(self._results) >= self._max_results:
+                return True
+
+            cid2 = set_result['concept_id_2']
+            self._add_mcq_result(set_result, new_single_results[cid2], self._criteria)
+        return False
+
     def _finalize_trapi_response(self, status: TrapiStatusCode = TrapiStatusCode.SUCCESS):
         """ Finalizes the TRAPI response
@@ -1857,7 +2076,8 @@ def _finalize_trapi_response(self, status: TrapiStatusCode = TrapiStatusCode.SUC
         self._response['message'] = {
             'results': self._results,
             'query_graph': self._query_graph,
-            'knowledge_graph': self._knowledge_graph
+            'knowledge_graph': self._knowledge_graph,
+            'auxiliary_graphs': self._auxiliary_graphs
         }
 
         if self._logs is not None and self._logs:
diff --git a/cohd/query_cohd_mysql.py b/cohd/query_cohd_mysql.py
index b1184a1..1ae263d 100644
--- a/cohd/query_cohd_mysql.py
+++ b/cohd/query_cohd_mysql.py
@@ -1132,7 +1132,11 @@ def query_db(service, method, args):
         elif type(concept_ids) is not list:
             concept_ids = [concept_ids]
 
-        json_return = query_trapi_mcq(concept_ids, dataset_id, domain_id, bypass=True)['results']
+        set_results, single_results = query_trapi_mcq(concept_ids, dataset_id, domain_id, bypass=True)
+        json_return = {
+            'set_results': set_results,
+            'single_results': single_results
+        }
 
         logging.debug(cur._executed)
         logging.debug(json_return)
@@ -1712,10 +1716,6 @@ def query_trapi(concept_id_1, concept_id_2=None, dataset_id=None, domain_id=None
 
     # Perform calculations for results
     for row in json_return:
-        cpc = row['concept_pair_count']
-        c1 = row['concept_1_count']
-        c2 = row['concept_2_count']
-
         # Confidence interval for obsExpRatio
         # The CI bounds may hit Inf, which causes issues with JSON serialization. Limit it to 999
         row['ln_ratio_ci'] = (clip(row['ln_ratio_ci_lo'], JSON_INFINITY_REPLACEMENT),
@@ -1833,6 +1833,7 @@ def _get_weighted_statistics(cur=None,dataset_id=None,domain_id = None,concept_i
     Input 2: concept_list in the query for weight calculation (currently only support jaccard index based weight
     calculation.)
     return weighted json_key. e.g. ws_jaccard_index.
     '''
+    pair_count_df = pair_count_df.copy()
     concept_list_1_w_df= pd.DataFrame({'concept_id_1':concept_id_1})
     concept_list_1_w_df['w'] = 1
@@ -1856,23 +1857,22 @@ def _get_weighted_statistics(cur=None,dataset_id=None,domain_id = None,concept_i
 
     # Group by concept_id_2. Sum the scores and combine concept_id_1 into a list
     gb = pair_count_df.groupby('concept_id_2')
-    weighted_stats = gb['concept_id_1'].agg(list).reset_index()
-    weighted_stats = weighted_stats.merge(gb[json_key].agg('sum'), on='concept_id_2')
-    return weighted_stats
+    weighted_stats = gb[json_key].agg('sum')
+    return weighted_stats.reset_index()
 
 
-def _get_ci_scores(r, low, high):
-    if r[low] > 0:
-        return r[low]
-    elif r[high] < 0:
-        return r[high]
+def _get_ci_scores(r, score_col):
+    if r[score_col][0] > 0:
+        return r[score_col][0]
+    elif r[score_col][1] < 0:
+        return r[score_col][1]
     else:
         return 0
 
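_get_ci_scores now reads a (lower, upper) confidence-interval tuple, as query_trapi stores in ln_ratio_ci, and returns the bound nearer zero, or 0 when the interval straddles zero, i.e., a conservative effect estimate. A small demonstration, assuming the function above is in scope (the interval values are made up):

    import pandas as pd

    df = pd.DataFrame({'ln_ratio_ci': [(0.4, 1.2), (-1.5, -0.3), (-0.2, 0.6)]})
    df['ln_ratio_score'] = df.apply(_get_ci_scores, axis=1, score_col='ln_ratio_ci')
    # Scores: 0.4 (CI entirely above zero), -0.3 (CI entirely below zero), 0 (CI spans zero)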
 
 @cache.memoize(timeout=86400, unless=_bypass_cache)
-def query_trapi_mcq(concept_ids, dataset_id=None, domain_id=None, concept_class_id=None,
-                    ln_ratio_sign=0, bypass=False):
+def query_trapi_mcq(concept_ids, dataset_id=None, domain_id=None, concept_class_id=None,
+                    ln_ratio_sign=0, confidence=DEFAULT_CONFIDENCE, bypass=False):
     """ Query for TRAPI Multicurie Query. Calculates weighted scores using methods similar to linkage
     disequilibrium to downweight contributions from input concepts that are similar to each other
@@ -1898,111 +1898,42 @@ def query_trapi_mcq(concept_ids, dataset_id=None, domain_id=None, concept_class_
     conn = sql_connection()
     cur = conn.cursor()
 
-    # Filter ln ratio
-    if ln_ratio_sign == 0:
-        ln_ratio_filter = ''
-    elif ln_ratio_sign > 0:
-        ln_ratio_filter = 'AND log(cp.concept_count * pc.count / (c1.concept_count * c2.concept_count + 0E0)) > 0'
-    elif ln_ratio_sign < 0:
-        ln_ratio_filter = 'AND log(cp.concept_count * pc.count / (c1.concept_count * c2.concept_count + 0E0)) < 0'
-
-    sql = '''SELECT *
-        FROM
-        ((SELECT
-            cp.dataset_id,
-            cp.concept_id_1,
-            cp.concept_id_2,
-            ln_ratio,
-            ln_ratio_ci_lo,
-            ln_ratio_ci_hi,
-            log_odds,
-            log_odds_ci_lo,
-            log_odds_ci_hi,
-            c.concept_name AS concept_2_name,
-            c.domain_id AS concept_2_domain,
-            c.concept_class_id AS concept_2_class_id
-        FROM cohd.concept_pair_counts cp
-        JOIN cohd.concept c ON cp.concept_id_2 = c.concept_id
-        WHERE cp.dataset_id = %(dataset_id)s
-            AND cp.concept_id_1 = %(concept_id_1)s
-            {domain_filter}
-            {concept_class_filter}
-            {ln_ratio_filter})
-        UNION
-        (SELECT
-            cp.dataset_id,
-            cp.concept_id_2 AS concept_id_1,
-            cp.concept_id_1 AS concept_id_2,
-            ln_ratio,
-            ln_ratio_ci_lo,
-            ln_ratio_ci_hi,
-            log_odds,
-            log_odds_ci_lo,
-            log_odds_ci_hi,
-            c.concept_name AS concept_2_name,
-            c.domain_id AS concept_2_domain,
-            c.concept_class_id AS concept_2_class_id
-        FROM cohd.concept_pair_counts cp
-        JOIN cohd.concept c ON cp.concept_id_1 = c.concept_id
-        WHERE cp.dataset_id = %(dataset_id)s
-            AND cp.concept_id_2 = %(concept_id_1)s
-            {domain_filter}
-            {concept_class_filter}
-            {ln_ratio_filter})) x
-        ORDER BY ABS(ln_ratio) DESC;'''
-    params = {
-        'dataset_id': dataset_id,
-    }
-
-    if domain_id is not None and not domain_id == ['']:
-        # Restrict the associated concept by domain
-        domain_filter = 'AND c.domain_id = %(domain_id)s'
-        params['domain_id'] = domain_id
-    else:
-        # Unrestricted domain
-        domain_filter = ''
-
-    # Filter concepts by concept_class
-    if concept_class_id is None or not concept_class_id or concept_class_id == [''] or \
-            concept_class_id.isspace():
-        concept_class_filter = ''
-    else:
-        concept_class_filter = 'AND concept_class_id = %(concept_class_id)s'
-        params['concept_class_id'] = concept_class_id
-
     # Get the associations for each of the concepts in the list
-    pair_counts = list()
+    associations = list()
     for concept_id_1 in concept_ids:
-        params['concept_id_1'] = concept_id_1
-        sqlp = sql.format(domain_filter=domain_filter, concept_class_filter=concept_class_filter,
-                          ln_ratio_filter=ln_ratio_filter)
-
-        cur.execute(sqlp, params)
-        pair_counts.extend(cur.fetchall())
-    pair_count = pd.DataFrame(pair_counts)
+        a = query_trapi(concept_id_1=concept_id_1, concept_id_2=None, dataset_id=dataset_id, domain_id=domain_id,
+                        concept_class_id=concept_class_id, ln_ratio_sign=ln_ratio_sign, confidence=confidence,
+                        bypass=bypass)
+        associations.extend(a['results'])
+    associations = pd.DataFrame(associations)
 
     # Scorify ln_ratio and log_odds
-    pair_count['ln_ratio_score'] = pair_count.apply(_get_ci_scores, axis=1, low='ln_ratio_ci_lo', high='ln_ratio_ci_hi')
-    # pair_count['log_odds_score'] = pair_count.apply(_get_ci_scores, axis=1, low='log_odds_ci_lo', high='log_odds_ci_hi')
+    associations['ln_ratio_score'] = associations.apply(_get_ci_scores, axis=1, score_col='ln_ratio_ci')
+    # associations['log_odds_score'] = associations.apply(_get_ci_scores, axis=1, score_col='log_odds_ci')
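The weighting step in the next hunk sums per-member scores for each candidate concept_id_2 after _get_weighted_statistics has scaled each member's contribution (the docstring mentions Jaccard-index-based weights; the exact weighting lives in the earlier hunk). A toy version of just the aggregation, with made-up weights:

    import pandas as pd

    df = pd.DataFrame({
        'concept_id_1': [101, 102, 101, 102],
        'concept_id_2': [201, 201, 202, 202],
        'ln_ratio_score': [0.8, 0.6, 0.5, -0.2],
        'w': [1.0, 0.5, 1.0, 0.5],   # correlated input concepts get downweighted
    })
    df['ln_ratio_score'] = df['ln_ratio_score'] * df['w']
    weighted = df.groupby('concept_id_2')['ln_ratio_score'].agg('sum').reset_index()
    # concept 201: 0.8*1.0 + 0.6*0.5 = 1.1; concept 202: 0.5*1.0 - 0.2*0.5 = 0.4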
 
     # Adjust the scores by weights
-    concept_list_1 = list(set(pair_count['concept_id_1'].tolist()))
+    concept_list_1 = list(set(associations['concept_id_1'].tolist()))
     weighted_ln_ratio = _get_weighted_statistics(cur=cur, dataset_id=dataset_id, domain_id=domain_id,
-                                                 concept_id_1=concept_list_1, pair_count_df=pair_count,
+                                                 concept_id_1=concept_list_1, pair_count_df=associations,
                                                  json_key = 'ln_ratio_score')
     # weighted_log_odds = _get_weighted_statistics(cur=cur, dataset_id=dataset_id, domain_id=domain_id,
-    #                                              concept_id_1=concept_list_1, pair_count_df=pair_count,
+    #                                              concept_id_1=concept_list_1, pair_count_df=associations,
     #                                              json_key = 'log_odds_score')
 
+    # Add list of single associations
+    single_associations = dict()
+    for i, row in weighted_ln_ratio.iterrows():
+        cid2 = int(row['concept_id_2'])
+        single_associations[cid2] = associations.loc[associations.concept_id_2 == cid2, :].to_dict('records')
+
     # Extract concept 2 definitions
     columns_c2 = ['dataset_id', 'concept_id_2', 'concept_2_name', 'concept_2_domain', 'concept_2_class_id']
-    concept_2_defs = pair_count[columns_c2].groupby('concept_id_2').agg(lambda x: x.iloc[0])
+    concept_2_defs = associations[columns_c2].groupby('concept_id_2').agg(lambda x: x.iloc[0])
 
     # Merge and sort results, and convert to dict for JSON
-    results = concept_2_defs.merge(weighted_ln_ratio, on='concept_id_2')
+    set_associations = concept_2_defs.merge(weighted_ln_ratio, on='concept_id_2')
     # results = results.merge(weighted_log_odds[['concept_id_2', 'log_odds_score']], on='concept_id_2')
-    results = results.sort_values('ln_ratio_score', ascending=False)
-    json_return = {'results': results.to_dict('records')}
+    set_associations = set_associations.sort_values('ln_ratio_score', ascending=False)
+    json_return = set_associations.to_dict('records'), single_associations
 
     cur.close()
     conn.close()
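With this patch, query_trapi_mcq returns a (set_results, single_results) pair instead of a single results list: set_results is a list of per-concept_id_2 records carrying the weighted ln_ratio_score, and single_results maps each concept_id_2 to the member-level association records backing it. A hedged usage sketch; the OMOP concept IDs and dataset ID below are illustrative:

    set_results, single_results = query_trapi_mcq(concept_ids=[313217, 316866], dataset_id=3)
    for set_result in set_results:
        cid2 = set_result['concept_id_2']
        # Weighted set-level score, plus the number of member-level associations behind it
        print(cid2, set_result['ln_ratio_score'], len(single_results[cid2]))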