In [None]:
import ast
import csv
import getpass
import html
import os
from collections import Counter, deque

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import text2term

In [None]:
# Load the staging test set; the existing EDAM annotations are kept under
# 'Old EDAM Topics' so they can be compared with the re-mapped ones later.
dataset = (
    pd.read_csv('datasets/staging_test_set.csv')
    .rename(columns={'EDAM Topics': 'Old EDAM Topics'})
)

In [None]:
# Preview the first rows of the loaded dataset
dataset.head(n=5)

In [None]:
def convert_list(string):
    """Parse a string holding a Python literal (e.g. "['A', 'B']") into its value.

    Parameters
    ----------
    string : the raw cell value, normally a ``str``.

    Returns
    -------
    The evaluated literal, or None when the value cannot be parsed (malformed
    text, or a non-string cell such as a pandas NaN float).
    """
    try:
        return ast.literal_eval(string)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # Only parse failures are treated as "no value". A bare except would
        # also swallow KeyboardInterrupt/SystemExit.
        return None

# Turn each stringified list of MeSH terms into a real Python list
# (unparseable cells become None).
dataset = dataset.assign(**{'MeSH Terms': dataset['MeSH Terms'].apply(convert_list)})

In [None]:
# Drop rows whose MeSH Terms could not be parsed (None / NaN)
dataset = dataset.dropna(subset=['MeSH Terms'])

In [None]:
def has_slash_or_comma(lst):
    """Return True if any string in ``lst`` contains a '/' or a ','."""
    for term in lst:
        if '/' in term or ',' in term:
            return True
    return False

def has_forward_slash_or_comma(df, column_name):
    """Per-row boolean Series: does the list in ``df[column_name]`` hold a term with '/' or ','?"""
    column = df[column_name]
    return column.apply(has_slash_or_comma)

# Number of rows with at least one compound MeSH term ('/' qualifier or ',')
has_forward_slash_or_comma(df=dataset, column_name='MeSH Terms').sum()

In [None]:
def split_strings(lst):
    """Normalize a list of raw MeSH term strings into a sorted list of unique terms.

    Each string first has every '*' removed, then:
      - if it contains '/', it is split on '/' and each piece is stripped;
      - otherwise, if it contains ',', it is parsed as one CSV record (so
        quoted commas survive) and each field is stripped;
      - otherwise it is kept unchanged.

    Parameters
    ----------
    lst : iterable of str

    Returns
    -------
    list of str
        The unique resulting terms, sorted.
    """
    pieces = []
    for term in lst:
        cleaned = term.replace('*', '')
        if '/' in cleaned:
            pieces.extend(part.strip() for part in cleaned.split('/'))
        elif ',' in cleaned:
            # skipinitialspace is a csv.reader option. Passing it to map() raised
            # TypeError. Parse the '*'-stripped text, not the raw term.
            fields = next(csv.reader([cleaned], skipinitialspace=True))
            pieces.extend(field.strip() for field in fields)
        else:
            pieces.append(cleaned)
    return sorted(set(pieces))

# Split compound MeSH entries into individual, de-duplicated terms
dataset = dataset.assign(**{'Filtered MeSH Terms': dataset['MeSH Terms'].apply(split_strings)})

In [None]:
# Inspect the normalized MeSH term lists
dataset.loc[:, 'Filtered MeSH Terms']

In [None]:
def flatten_lists(df, column_name):
    """Concatenate the per-row lists in ``df[column_name]`` into one flat list (row order kept)."""
    flat = []
    for row_items in df[column_name]:
        flat.extend(row_items)
    return flat

# Every MeSH term occurrence across the dataset (duplicates retained for counting)
all_mesh_terms = flatten_lists(df=dataset, column_name='Filtered MeSH Terms')

In [None]:
# Occurrence count per MeSH term; the Counter's keys are exactly the distinct terms
mesh_term_freqs = Counter(all_mesh_terms)
unique_mesh_terms = set(mesh_term_freqs)

In [None]:
# How many MeSH terms share the lowest observed frequency?
min_frequency = min(mesh_term_freqs.values())
strings_with_min_frequency = [
    term
    for term, count in mesh_term_freqs.items()
    if count == min_frequency
]

print('Minimum frequency:', min_frequency)
print('Terms with min frequency:', len(strings_with_min_frequency), '/', len(unique_mesh_terms))

In [None]:
## text2term scores

# Cache the EDAM ontology locally for text2term.
# The BioPortal API key is read from the BIOPORTAL_API_KEY environment variable,
# with a hidden prompt as fallback. It is never stored in the notebook.
# NOTE(review): the key that was previously hard-coded here should be revoked.
BIOPORTAL_API_KEY = os.environ.get("BIOPORTAL_API_KEY") or getpass.getpass("BioPortal API key: ")
EDAM_DOWNLOAD_URL = (
    "https://data.bioontology.org/ontologies/EDAM/submissions/44/download"
    f"?apikey={BIOPORTAL_API_KEY}"
)
edam_ontology = text2term.cache_ontology(EDAM_DOWNLOAD_URL, "EDAM")

In [None]:
# Map every distinct MeSH term to EDAM once, using the cached ontology
terms_to_map = list(unique_mesh_terms)
mapped_terms = text2term.map_terms(terms_to_map, "EDAM", use_cache=True)

In [None]:
# Keep only mappings whose IRI contains 'topic' (EDAM topics).
# na=False treats a missing IRI as a non-match. Without it the mask would hold
# NaN and boolean indexing would fail. regex=False does a plain substring match.
is_topic = mapped_terms['Mapped Term IRI'].str.contains('topic', regex=False, na=False)
mapped_terms = mapped_terms[is_topic]
mapped_terms

In [None]:
# Spot-check: which MeSH terms were mapped to the 'Animal study' topic?
mapped_terms.loc[mapped_terms['Mapped Term Label'] == 'Animal study']

In [None]:
# Distribution of text2term mapping scores, with the mean and median marked.
# The reference lines are labelled so the figure can be read on its own.
scores = mapped_terms['Mapping Score'].values
mean_score = np.mean(scores)
median_score = np.median(scores)

fig, ax = plt.subplots()
ax.hist(scores, bins='auto', edgecolor='black', alpha=0.7)
ax.axvline(mean_score, color='red', label=f'Mean ({mean_score:.2f})')
ax.axvline(median_score, color='orange', label=f'Median ({median_score:.2f})')
ax.set(title='text2term mapping scores (EDAM topics)', xlabel='Mapping Score', ylabel='Frequency')
ax.legend()
plt.show()

In [None]:
# Minimum text2term mapping score (exclusive) for a MeSH -> EDAM mapping to be kept.
# The mean mapping score was considered as an alternative threshold.
threshold = 0.7

def map_mesh_to_edam(mesh_terms):
    """Map one row's MeSH terms to EDAM labels with text2term.

    Parameters
    ----------
    mesh_terms : list of str

    Returns
    -------
    list of str or None
        Unique 'Mapped Term Label' values whose 'Mapping Score' is strictly
        greater than the module-level ``threshold``, in order of first
        appearance. None if text2term raises, so the row can be dropped later.
    """
    try:
        mapping = text2term.map_terms(mesh_terms, "EDAM", use_cache=True)
    except Exception:
        # A bare except here also swallowed KeyboardInterrupt, which made the
        # long per-row apply() below impossible to interrupt.
        return None

    over_threshold = mapping[mapping['Mapping Score'] > threshold]
    return over_threshold['Mapped Term Label'].unique().tolist()

dataset['New EDAM Topics'] = dataset['Filtered MeSH Terms'].apply(map_mesh_to_edam)

In [None]:
# Drop rows where text2term failed (map_mesh_to_edam returned None)
dataset = dataset.dropna(subset=['New EDAM Topics'])

In [None]:
# Keep only EDAM labels listed in a user-supplied topics file (one topic per line).
with open(input("Enter EDAM topics file:"), 'r', encoding='utf-8') as edam_file:
    # Blank lines are skipped so '' is never treated as a valid topic
    full_edam_topics = [line.strip() for line in edam_file if line.strip()]

# Set lookup avoids rescanning the whole topic list for every label of every row
edam_topic_set = set(full_edam_topics)
dataset['New EDAM Topics'] = dataset['New EDAM Topics'].apply(
    lambda topics: [topic for topic in topics if topic in edam_topic_set]
)

In [None]:
# Compare old and new ground truth: model outputs are read from outputs.csv
# and joined to the re-mapped dataset by abstract text in the cells below.
gpt_output = pd.read_csv(filepath_or_buffer='outputs.csv')

In [None]:
def get_new_edam(abstract):
    """Return 'New EDAM Topics' for the first ``dataset`` row whose Abstract equals ``abstract``.

    Returns None when no row matches.
    """
    hits = dataset.loc[dataset['Abstract'] == abstract, 'New EDAM Topics']
    return None if hits.empty else hits.iloc[0]

# Sanity check on the first output row: new mapping vs. original ground truth
first_abstract = gpt_output.iloc[0]['Abstract']
print(get_new_edam(first_abstract))
print(gpt_output['Ground Truth'].iloc[0])

In [None]:
# Keep the original labels under an explicit name. Rebinding the result
# (instead of inplace=True) avoids hidden in-place mutation of the frame.
gpt_output = gpt_output.rename(columns={'Ground Truth': 'Old Ground Truth'})

In [None]:
# Look up the re-mapped EDAM topics for each abstract
new_ground_truth = gpt_output['Abstract'].apply(get_new_edam)
gpt_output['New Ground Truth'] = new_ground_truth

In [None]:
# Drop outputs whose abstract had no re-mapped EDAM topics
gpt_output = gpt_output.dropna(subset=['New Ground Truth'])

In [None]:
# Get MeSH Terms

def get_mesh_terms(abstract):
    """Return 'MeSH Terms' for the first ``dataset`` row whose Abstract equals ``abstract``.

    Returns None when no row matches.
    """
    hits = dataset.loc[dataset['Abstract'] == abstract, 'MeSH Terms']
    return None if hits.empty else hits.iloc[0]

In [None]:
# Attach each abstract's MeSH terms, then drop outputs without a match
gpt_output = (
    gpt_output
    .assign(**{'MeSH Terms': gpt_output['Abstract'].apply(get_mesh_terms)})
    .dropna(subset=['MeSH Terms'])
)

In [None]:
# Fix the column order of the exported comparison table and save it
export_columns = [
    'Model',
    'Abstract',
    'MeSH Terms',
    'Old Ground Truth',
    'New Ground Truth',
    'Predictions',
    'Prioritized Predictions',
]
gpt_output = gpt_output[export_columns]
gpt_output.to_csv(input("Enter file name: "), index=False)

In [None]:
# Compare old terms with new terms on a few randomly chosen rows
from IPython.display import HTML, display

SAMPLE_SIZE = 5
SAMPLE_SEED = 0  # fixed so a re-run shows the same rows

# min() avoids a ValueError when fewer than SAMPLE_SIZE rows remain
sample_rows = gpt_output.sample(n=min(SAMPLE_SIZE, len(gpt_output)), random_state=SAMPLE_SEED)
for _, row in sample_rows.iterrows():
    # Escape the abstract so characters like '<' or '&' are shown, not parsed as HTML
    display('Abstract:', HTML(f"<p style='overflow-x: auto'>{html.escape(str(row['Abstract']))}</p>"))
    print('Old:', row['Old Ground Truth'])
    print('New:', ', '.join(row['New Ground Truth']))
    print('GPT:', row['Predictions'], '\n')

In [None]:
# Get PMIDs

def get_pmids(abstract):
    """Return the PMID of the first ``dataset`` row whose Abstract equals ``abstract``.

    Returns None when no row matches.
    """
    hits = dataset.loc[dataset['Abstract'] == abstract, 'PMID']
    return None if hits.empty else hits.iloc[0]

In [None]:
# PMIDs of the compared abstracts, dropping abstracts with no dataset match
pmids = gpt_output['Abstract'].apply(get_pmids).dropna()

In [None]:
# Persist the PMID list (values only, no index column)
pmids.to_csv(path_or_buf='pmids.csv', index=False)

## Testing for discrepancy in MeSH terms

There seems to be a discrepancy between the MeSH terms in the XML record and those in the MEDLINE text record returned by Entrez for the same article.

In [None]:
from Bio import Entrez, Medline

# Contact e-mail sent with Entrez requests
Entrez.email = "zqazi@scripps.edu"

# Fetch one example article as XML to inspect its MeSH headings
EXAMPLE_PMID = 21406103
handle = Entrez.efetch(db="pubmed", id=EXAMPLE_PMID, retmode="xml")

In [None]:
# Parse the XML response. The handle is closed even if parsing fails.
try:
    article_data = Entrez.read(handle)
finally:
    handle.close()


In [None]:
# Collect the MeSH descriptor names from every article in the XML response
articles = article_data["PubmedArticle"] if "PubmedArticle" in article_data else []
mesh_terms = [
    str(heading["DescriptorName"])
    for article in articles
    if "MeshHeadingList" in article["MedlineCitation"]
    for heading in article["MedlineCitation"]["MeshHeadingList"]
]

mesh_terms

In [None]:
# Fetch the same article as MEDLINE-formatted text, for comparison with the XML
handle = Entrez.efetch(db="pubmed", id=21406103, retmode="text", rettype='medline')
article_data = Medline.parse(handle)

In [None]:
# Show the MeSH headings (MH) and abstract (AB) of the MEDLINE record.
# Defaults prevent a NameError when the response contains no records.
mesh = '?'
abstract = '?'
try:
    for record in article_data:
        mesh = record.get('MH', '?')
        abstract = record.get('AB', '?')
finally:
    # article_data is produced by Medline.parse(handle), so close the handle
    # only after iterating over it.
    handle.close()

print(mesh)
print(abstract)

In [None]:
# Get fixed mesh terms

def get_fixed_mesh_terms(pmid):
    """Fetch an article's MeSH headings from its MEDLINE text record.

    Parameters
    ----------
    pmid : PubMed identifier passed to ``Entrez.efetch``.

    Returns
    -------
    The 'MH' field of the last record parsed by ``Medline.parse``, or None when
    no record is returned or the record has no 'MH' field.
    """
    # Previously unbound (UnboundLocalError) when efetch returned no records
    mesh_terms = None
    handle = Entrez.efetch(db="pubmed", id=pmid, rettype='medline', retmode="text")
    try:
        for record in Medline.parse(handle):
            mesh_terms = record.get('MH', None)
    finally:
        # This runs once per PMID, so close each handle instead of leaking it
        handle.close()

    return mesh_terms


In [None]:
# Re-fetch MeSH terms from MEDLINE text for every PMID (one Entrez request per row)
output = dataset.PMID.apply(get_fixed_mesh_terms)

In [None]:
# Replace the dataset's MeSH terms with the re-fetched ones (aligned on index)
dataset = dataset.assign(**{'MeSH Terms': output})

In [None]:
# Reload the untouched staging test set (original column names, no rows dropped)
other_data = pd.read_csv(filepath_or_buffer='datasets/staging_test_set.csv')

In [None]:
# `output` is indexed like the filtered `dataset`; aligning it to other_data's
# index leaves NaN for rows that were dropped earlier.
other_data['MeSH Terms'] = output.reindex(other_data.index)