In [None]:
import pandas as pd
pd.set_option('display.max_colwidth', None)

import re
import os
import csv

from tqdm.notebook import tqdm

In [None]:
# Input CSV of model outputs; the processed copy written at the end of the
# notebook derives its file name from this path.
DATA_PATH = 'outputs/dryad.csv'

# Only "\n" is treated as a row terminator when parsing.
data = pd.read_csv(filepath_or_buffer=DATA_PATH, lineterminator="\n")

print(f'Loaded dataset with {len(data)} rows')

In [None]:
def _is_quoted(topic):
    """Return True when the topic both starts and ends with a double quote."""
    return topic.startswith('"') and topic.endswith('"')


# One topic per line; surrounding whitespace (including the newline) is dropped.
with open('EDAM/edam_topics.txt', 'r') as f:
    raw_topics = [line.strip() for line in f]

# Topics that are wrapped in double quotes in the file, kept WITH their quotes.
quoted_topics = [topic for topic in raw_topics if _is_quoted(topic)]

# The full topic list with the wrapping quotes removed.
edam_topics = [topic[1:-1] if _is_quoted(topic) else topic for topic in raw_topics]

## Format Outputs

Split outputs on tab, and check for other separators that GPT may have used in error.

In [None]:
# Turn a literal backslash + "t" into a real tab character.
# regex=False is required: when the pattern is treated as a regular expression
# (the default before pandas 2.0), '\\t' matches an actual tab, so the literal
# two-character sequence would be left untouched.
data['Predictions'] = data['Predictions'].str.replace('\\t', '\t', regex=False)


In [None]:
def split_topics(topics, quoted=None):
    """Split one raw, tab-separated prediction string into a list of topics.

    Every piece is stripped of surrounding whitespace and of all double quotes.
    If a piece then contains one of the known quoted topics (compared
    case-insensitively), each occurrence is replaced by that topic's canonical
    quoted form. Only the first matching quoted topic is applied per piece.

    Parameters
    ----------
    topics : str
        Raw prediction text. Any non-string value (e.g. NaN for an empty CSV
        cell) yields an empty list instead of raising.
    quoted : iterable of str, optional
        Topics written with their surrounding double quotes. Defaults to the
        notebook-level ``quoted_topics`` list.

    Returns
    -------
    list of str
        One entry per tab-separated piece, in input order.
    """
    if not isinstance(topics, str):
        return []
    if quoted is None:
        quoted = quoted_topics

    cleaned_topics = []
    for raw_topic in topics.split('\t'):
        # Remove quotes up front so an already-quoted topic is not quoted twice.
        topic = raw_topic.replace('"', '').strip()
        for quoted_topic in quoted:
            bare_topic = quoted_topic.replace('"', '')
            if not bare_topic:
                continue
            # Detect and replace with the same case-insensitive pattern so a
            # differently-cased spelling is actually rewritten, not just found.
            pattern = re.compile(re.escape(bare_topic), re.IGNORECASE)
            if pattern.search(topic):
                # A callable replacement avoids backslash interpretation.
                topic = pattern.sub(lambda _match: quoted_topic, topic)
                break
        cleaned_topics.append(topic)
    return cleaned_topics

# Replace each raw prediction string with its list of tab-separated topics.
raw_predictions = data['Predictions']
data['Predictions'] = raw_predictions.apply(split_topics)

In [None]:
separators = ['    ', '   ', '  ', '\n', '<TAB>', 'TAB', '<tab>', '(tab)', '<Tab>', '[tab]', '▪️', '(Tab)', '\xa0\xa0\xa0\xa0', '\xa0', '\u2003', '、', '\x0b', '\x0c', ';', '.', '--', '-', '–', '_', '\\', '\\n', '/', '@', '|', '\r', '+', '<', '>']

# Join the separators with the regex OR operator |.
# Alternation picks the first alternative that matches at a position, so the
# longest separators must come first; otherwise '\\' always wins over '\\n'
# and a stray 'n' is left at the start of the next item.
sep_pattern = '|'.join(map(re.escape, sorted(separators, key=len, reverse=True)))

category_prefix = re.compile(r'Category \d+:')

split_predictions = []
for preds in tqdm(data['Predictions']):
    pieces = []
    for pred in preds:
        for piece in re.split(sep_pattern, pred):
            # Remove the "Category N:" label before stripping, so the space
            # that follows the label does not survive.
            pieces.append(category_prefix.sub('', piece).strip())
    split_predictions.append(pieces)

# Assign the column in one go. Writing through data['Predictions'][i] is
# chained assignment, which pandas warns about and which has no effect on the
# frame when copy-on-write is enabled.
data['Predictions'] = pd.Series(split_predictions, index=data.index, dtype=object)

In [None]:
# Sanity check: every entry of every prediction list must be a plain string,
# i.e. no list has ended up nested inside another one.
is_one_dimensional = not any(
    not isinstance(pred, str)
    for preds in data['Predictions']
    for pred in preds
)
print(is_one_dimensional)

In [None]:
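# ## Legacy (disabled): capture weirdly formatted outputs that used the wrong separators, one separator at a time. Superseded by the regex-based separator split above; kept for reference only.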

# data['Predictions'] = data['Predictions'].apply(lambda x: list(x)[0].split('    ') if 0 < len(x) <= 1 and '    ' in list(x)[0] else x)
# data['Predictions'] = data['Predictions'].apply(lambda x: list(x)[0].split('   ') if 0 < len(x) <= 1 and '   ' in list(x)[0] else x)
# data['Predictions'] = data['Predictions'].apply(lambda x: list(x)[0].split('  ') if 0 < len(x) <= 1 and '  ' in list(x)[0] else x)
# data['Predictions'] = data['Predictions'].apply(lambda x: list(x)[0].split('\n') if 0 < len(x) <= 1 and '\n' in list(x)[0] else x)
# data['Predictions'] = data['Predictions'].apply(lambda x: list(x)[0].split('<TAB>') if 0 < len(x) <= 1 and '<TAB>' in list(x)[0] else x)
# data['Predictions'] = data['Predictions'].apply(lambda x: list(x)[0].split('TAB') if 0 < len(x) <= 1 and 'TAB' in list(x)[0] else x)
# data['Predictions'] = data['Predictions'].apply(lambda x: list(x)[0].split('<tab>') if 0 < len(x) <= 1 and '<tab>' in list(x)[0] else x)
# data['Predictions'] = data['Predictions'].apply(lambda x: list(x)[0].split('(tab)') if 0 < len(x) <= 1 and '(tab)' in list(x)[0] else x)
# data['Predictions'] = data['Predictions'].apply(lambda x: list(x)[0].split('<Tab>') if 0 < len(x) <= 1 and '<Tab>' in list(x)[0] else x)
# data['Predictions'] = data['Predictions'].apply(lambda x: list(x)[0].split('[tab]') if 0 < len(x) <= 1 and '[tab]' in list(x)[0] else x)
# data['Predictions'] = data['Predictions'].apply(lambda x: list(x)[0].split('▪️') if 0 < len(x) <= 1 and '▪️' in list(x)[0] else x)
# data['Predictions'] = data['Predictions'].apply(lambda x: list(x)[0].split('(Tab)') if 0 < len(x) <= 1 and '<Tab>' in list(x)[0] else x)
# data['Predictions'] = data['Predictions'].apply(lambda x: list(x)[0].split('\xa0\xa0\xa0\xa0') if 0 < len(x) <= 1 and '\xa0\xa0\xa0\xa0' in list(x)[0] else x)
# data['Predictions'] = data['Predictions'].apply(lambda x: list(x)[0].split('\xa0') if 0 < len(x) <= 1 and '\xa0' in list(x)[0] else x)
# data['Predictions'] = data['Predictions'].apply(lambda x: list(x)[0].split('\u2003') if 0 < len(x) <= 1 and '\u2003' in list(x)[0] else x)
# data['Predictions'] = data['Predictions'].apply(lambda x: list(x)[0].split('、') if 0 < len(x) <= 1 and '、' in list(x)[0] else x)
# data['Predictions'] = data['Predictions'].apply(lambda x: list(x)[0].split('\x0b') if 0 < len(x) <= 1 and '\x0b' in list(x)[0] else x)
# data['Predictions'] = data['Predictions'].apply(lambda x: list(x)[0].split('\x0c') if 0 < len(x) <= 1 and '\x0c' in list(x)[0] else x)
# data['Predictions'] = data['Predictions'].apply(lambda x: list(x)[0].split(';') if 0 < len(x) <= 1 and ';' in list(x)[0] else x)
# data['Predictions'] = data['Predictions'].apply(lambda x: list(x)[0].split('.') if 0 < len(x) <= 1 and '.' in list(x)[0] else x)
# data['Predictions'] = data['Predictions'].apply(lambda x: list(x)[0].split('--') if 0 < len(x) <= 1 and '--' in list(x)[0] else x)
# data['Predictions'] = data['Predictions'].apply(lambda x: list(x)[0].split('-') if 0 < len(x) <= 1 and '-' in list(x)[0] else x)
# data['Predictions'] = data['Predictions'].apply(lambda x: list(x)[0].split('–') if 0 < len(x) <= 1 and '–' in list(x)[0] else x)
# data['Predictions'] = data['Predictions'].apply(lambda x: list(x)[0].split('_') if 0 < len(x) <= 1 and '_' in list(x)[0] else x)
# data['Predictions'] = data['Predictions'].apply(lambda x: list(x)[0].split('\\') if 0 < len(x) <= 1 and '\\' in list(x)[0] else x)
# data['Predictions'] = data['Predictions'].apply(lambda x: list(x)[0].split('\\n') if 0 < len(x) <= 1 and '\\n' in list(x)[0] else x)
# data['Predictions'] = data['Predictions'].apply(lambda x: list(x)[0].split('/') if 0 < len(x) <= 1 and '/' in list(x)[0] else x)
# data['Predictions'] = data['Predictions'].apply(lambda x: list(x)[0].split('@') if 0 < len(x) <= 1 and '@' in list(x)[0] else x)
# data['Predictions'] = data['Predictions'].apply(lambda x: list(x)[0].split('|') if 0 < len(x) <= 1 and '|' in list(x)[0] else x)
# data['Predictions'] = data['Predictions'].apply(lambda x: list(x)[0].split('\r') if 0 < len(x) <= 1 and '\r' in list(x)[0] else x)
# data['Predictions'] = data['Predictions'].apply(lambda x: list(x)[0].split('+') if 0 < len(x) <= 1 and '+' in list(x)[0] else x)
# data['Predictions'] = data['Predictions'].apply(lambda x: list(x)[0].split('<') if 0 < len(x) <= 1 and '<' in list(x)[0] else x)
# data['Predictions'] = data['Predictions'].apply(lambda x: list(x)[0].split('>') if 0 < len(x) <= 1 and '>' in list(x)[0] else x)
# data['Predictions'] = data['Predictions'].apply(lambda x: [pred.strip() for pred in csv.reader([list(x)[0]],skipinitialspace=True, delimiter=',', quotechar='"').__next__()])
# data['Predictions'] = data['Predictions'].apply(lambda x: [re.sub(r'Category \d+:', '', pred) for pred in x])
# # data['Predictions'] = data['Predictions'].apply(lambda x: list(x)[0].split(', ') if len(x) <= 1 and ', ' in list(x)[0] else x)

In [None]:
def process_predictions(predictions, quoted=None):
    """Normalise a list of prediction strings into a set of individual topics.

    For each prediction:

    1. If it contains the unquoted text of a known quoted topic, that text is
       wrapped in quotes again (only the first matching topic is applied).
    2. Quoted spans (``"..."``) are kept whole, quotes included, so the commas
       inside them are not treated as separators.
    3. Everything outside quoted spans is split on commas and stripped.
       Fragments that are empty after stripping are dropped.

    Parameters
    ----------
    predictions : iterable of str
        Prediction strings for one row.
    quoted : iterable of str, optional
        Topics written with their surrounding double quotes. Defaults to the
        notebook-level ``quoted_topics`` list.

    Returns
    -------
    set of str
        The distinct topics found; quoted topics keep their quotes.
    """
    if quoted is None:
        quoted = quoted_topics

    final_predictions = set()
    for prediction in predictions:
        for topic in quoted:
            bare_topic = topic.replace('\"', '')
            if bare_topic and bare_topic in prediction:
                prediction = prediction.replace(bare_topic, topic)
                break

        if '\"' in prediction:
            # Alternate between runs without quotes and "quoted" spans. A
            # doubled quote left by re-quoting an already quoted topic matches
            # neither alternative and is skipped.
            fragments = re.findall(r'[^"]+|"[^"]+"', prediction)
        else:
            fragments = [prediction]

        for fragment in fragments:
            if fragment.startswith('\"'):
                final_predictions.add(fragment)
                continue
            # Text outside quotes may still hold several comma-separated
            # topics, plus the delimiters that sat next to a quoted span.
            for part in fragment.split(','):
                part = part.strip()
                if part:
                    final_predictions.add(part)
    return final_predictions

# Collapse each row's list of prediction strings into a set of topics.
prediction_lists = data['Predictions']
data['Predictions'] = prediction_lists.apply(process_predictions)

In [None]:
# Rows whose prediction set has at most one element.
filtered_predictions = data.loc[data['Predictions'].apply(len) <= 1, 'Predictions']

# Keep only the unexpected ones: a lone prediction that has no quotes, contains
# a space and is not an EDAM topic. Empty sets contribute nothing.
unexpected_predictions = [
    (row_label, lone_prediction)
    for row_label, pred_set in filtered_predictions.items()
    for lone_prediction in list(pred_set)[:1]
    if '\"' not in lone_prediction
    and ' ' in lone_prediction
    and lone_prediction not in edam_topics
]

# Report how many were found, then each one with its original row index.
count = len(unexpected_predictions)
print(f"Number of unexpected predictions: {count}")
for original_index, prediction in unexpected_predictions:
    print(f"Original Index: {original_index}, Prediction: {prediction}")

## Hallucinations

Filter out topics not in the EDAM topics list. The filtered topics may be matched to a topic or synonym->topic in the next section.

In [None]:
# A hallucination is any prediction that, once periods and double quotes are
# removed, is not an exact entry of the EDAM topic list.
data['Hallucinations'] = data['Predictions'].apply(
    lambda preds: {
        cleaned
        for cleaned in (pred.replace('.', '').replace('\"', '') for pred in preds)
        if cleaned not in edam_topics
    }
)

In [None]:
# Keep only predictions that, once periods and double quotes are removed, are
# exact entries of the EDAM topic list.
data['Predictions'] = data['Predictions'].apply(
    lambda preds: {
        cleaned
        for cleaned in (pred.replace('.', '').replace('\"', '') for pred in preds)
        if cleaned in edam_topics
    }
)
# Remove anything that is also listed among the row's hallucinations.
data['Predictions'] = data.apply(lambda row: row['Predictions'] - row['Hallucinations'], axis=1)

## Synonym matching

Check for misspelled or misformatted topics and synonyms using Levenshtein distance.

In [None]:
# Load the EDAM table, keep the rows whose Class ID contains 'topic', then keep
# only those whose preferred label is in the (unquoted) topic list.
edam = (
    pd.read_csv('EDAM/EDAM.csv')
    .loc[lambda frame: frame['Class ID'].str.contains('topic')]
    .reset_index(drop=True)
    .loc[lambda frame: frame['Preferred Label'].isin([topic.replace('\"', '') for topic in edam_topics])]
    .reset_index(drop=True)
)

In [None]:
# Synonyms are stored as one '|'-delimited string; missing or empty cells
# become an empty list.
synonym_strings = edam['Synonyms'].fillna('')
edam['Synonyms'] = synonym_strings.apply(lambda raw: [] if raw == '' else raw.split('|'))

In [None]:
# Topics from the text list that have no row in the filtered EDAM table.
expected_labels = {topic.replace('\"', '') for topic in edam_topics}
missing_topics = expected_labels.difference(edam['Preferred Label'])
missing_topics

In [None]:
# Map every synonym to the preferred label of its row. When a synonym occurs
# in several rows, the last row wins.
synonym_dict = {
    synonym: preferred_label
    for preferred_label, synonyms in zip(edam['Preferred Label'], edam['Synonyms'])
    for synonym in synonyms
}

In [None]:
# Display the synonym -> preferred-label lookup for manual inspection.
synonym_dict

In [None]:
import Levenshtein

hallucinations = data['Hallucinations']


def _rank_by_distance(text, candidates):
    """Return ``(distance, candidate)`` pairs ordered by edit distance to ``text``.

    Each distance is computed once. The sort is stable, so candidates at the
    same distance keep the order in which they were supplied.
    """
    scored = [(Levenshtein.distance(text, candidate), candidate) for candidate in candidates]
    scored.sort(key=lambda pair: pair[0])
    return scored


def _match_hallucination(hallucination):
    """Return the EDAM topic a hallucination most plausibly stands for, or None.

    Strategies are tried in this order and the first hit wins:

    1. nearest topic with an edit distance of 1 or 2;
    2. nearest synonym with an edit distance of 0 or 1 (mapped to its topic);
    3. a topic equal, ignoring case, to one of the hallucination's
       whitespace-separated words, nearest topic first;
    4. the same word test against the synonyms (mapped to its topic).
    """
    ranked_topics = _rank_by_distance(hallucination, edam_topics)
    for distance, topic in ranked_topics:
        if distance > 2:
            # Sorted ascending: nothing further along can qualify.
            break
        if 0 < distance:
            return topic

    ranked_synonyms = _rank_by_distance(hallucination, synonym_dict.keys())
    for distance, synonym in ranked_synonyms:
        if distance > 1:
            break
        return synonym_dict[synonym]

    words = hallucination.lower().split()
    for _, topic in ranked_topics:
        if topic.lower() in words:
            return topic
    for _, synonym in ranked_synonyms:
        if synonym.lower() in words:
            return synonym_dict[synonym]
    return None


matched_topics = {}
# Hallucinations already known to have no match, so they are not re-ranked
# every time they recur in another row.
unmatched_hallucinations = set()

for hallucination_set in tqdm(hallucinations):
    for hallucination in hallucination_set:
        if hallucination in matched_topics or hallucination in unmatched_hallucinations:
            continue
        match = _match_hallucination(hallucination)
        if match is None:
            unmatched_hallucinations.add(hallucination)
        else:
            matched_topics[hallucination] = match

matched_topics

In [None]:
# Promote every hallucination that was matched to a real topic into the
# prediction set of the row it came from (the set is updated in place).
for index, row_hallucinations in data['Hallucinations'].items():
    for hallucination in list(row_hallucinations):
        if hallucination not in matched_topics:
            continue
        print(f"'{hallucination}' in row {index} matches topic '{matched_topics[hallucination]}'")
        data.at[index, 'Predictions'].add(matched_topics[hallucination])

In [None]:
# Show the rows that carry more than one hallucination.
data.loc[data['Hallucinations'].apply(len) > 1]

In [None]:
def _requote(pred):
    """Return the quoted form of ``pred`` if that form is a known quoted topic, else ``pred``."""
    candidate = f'"{pred}"'
    return candidate if candidate in quoted_topics else pred


# Topics that contain commas (e.g. "Data submission, annotation, and curation")
# get their surrounding quotes back; each row becomes a list.
data['Predictions'] = data['Predictions'].apply(lambda preds: list(map(_requote, preds)))

In [None]:
# Show predictions next to hallucinations for every row with at least one hallucination.
data.loc[data['Hallucinations'].apply(len) > 0, ['Predictions', 'Hallucinations']]

In [None]:
# Convert each row's list of predictions into a set, dropping duplicates.
data['Predictions'] = data['Predictions'].apply(set)

In [None]:
# Insert "_processed" between the file's stem and its extension.
# A plain .replace('.', '_processed.') would rewrite every dot in the name and
# would leave a name without an extension unchanged.
stem, extension = os.path.splitext(os.path.basename(DATA_PATH))
file_name = f'{stem}_processed{extension}'

data.to_csv(f'outputs/{file_name}', index=False)