# Full Text Search + Extraction, pt ii

In this notebook I will reconnect full-text search (FTS) with the GPT extraction function and evaluate the whole pipeline on both single- and multi-group studies.

To improve selection of the correct chunk, we'll use the following logic:

- Use a heuristic to narrow down the candidate chunks, falling back to searching the entire document
- Feed one section at a time to GPT, and if nothing is returned, try the next nearest section
- Evaluate using various output "cleaning rules"
- Try GPT-4 as well to see if it improves results
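The first two steps of this logic can be sketched as a small fallback loop. This is a minimal sketch under stated assumptions, not the actual `search_extract` implementation: `extract_with_fallback` and `extract_fn` are hypothetical names, sections are assumed to be plain strings already ordered by relevance, and an empty or `None` result is treated as "nothing returned". The returned position is analogous to, though not necessarily the same as, the `rank` column in the predictions.

```python
from typing import Callable, List, Optional, Sequence, Tuple


def extract_with_fallback(
    heuristic_sections: Sequence[str],
    ranked_sections: Sequence[str],
    extract_fn: Callable[[str], Optional[dict]],
) -> Tuple[Optional[dict], Optional[int]]:
    """Feed sections to `extract_fn` one at a time until something comes back.

    Heuristic sections are tried first; the full-document ranking is the
    fallback. Returns (result, rank), where rank is the position of the
    section that produced the result, or (None, None) if nothing did.
    """
    # Candidate order: heuristic hits first, then the rest of the document
    # ranking, skipping sections that were already tried.
    candidates: List[str] = []
    for section in list(heuristic_sections) + list(ranked_sections):
        if section not in candidates:
            candidates.append(section)

    for rank, section in enumerate(candidates):
        result = extract_fn(section)
        if result:  # None or an empty dict counts as "nothing returned"
            return result, rank
    return None, None


# Usage with a stand-in extractor (a real one would call GPT).
def fake_extract(text: str) -> Optional[dict]:
    return {"count": 20} if "20 participants" in text else None


result, rank = extract_with_fallback(
    heuristic_sections=["Methods: scanner parameters"],
    ranked_sections=["Results: activation maps",
                     "Participants: 20 participants were recruited"],
    extract_fn=fake_extract,
)
print(result, rank)  # {'count': 20} 2
```

Here the heuristic section yields nothing, so the loop falls through to the full-document ranking and stops at the first section that produces an answer.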

In [1]:
import pandas as pd
import numpy as np
from labelrepo.projects.participant_demographics import get_participant_demographics

# Load annotations
subgroups = get_participant_demographics(include_locations=True)

# Keep Jerome Dockes' annotations from the participant_demographics project
# (this includes the multi-group studies)
jerome_pd = subgroups[(subgroups.project_name == 'participant_demographics') &
                      (subgroups.annotator_name == 'Jerome_Dockes')]

# Keep only the demographic columns used for evaluation, sorted by pmcid
subset_cols = ['count', 'diagnosis', 'group_name', 'subgroup_name', 'male count',
               'female count', 'age mean', 'age minimum', 'age maximum',
               'age median', 'pmcid']
jerome_pd_subset = jerome_pd[subset_cols].sort_values('pmcid')
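To evaluate single- and multi-group studies separately, the annotations can be split by how many groups each study has. A minimal sketch, assuming (as `evaluate_predictions` below does when it counts rows per `pmcid`) that each annotation row is one participant group; `split_single_multi` is a hypothetical helper and the toy pmcids are made up.

```python
import pandas as pd


def split_single_multi(annotations: pd.DataFrame, id_col: str = "pmcid"):
    """Split study IDs into single-group and multi-group studies.

    Assumes one annotation row per participant group.
    """
    n_groups = annotations.groupby(id_col).size()
    single = n_groups[n_groups == 1].index.tolist()
    multi = n_groups[n_groups > 1].index.tolist()
    return single, multi


toy = pd.DataFrame({"pmcid": [101, 102, 102, 103], "count": [20, 15, 12, 30]})
single_ids, multi_ids = split_single_multi(toy)
print(single_ids, multi_ids)  # [101, 103] [102]
```

Applied to `jerome_pd_subset`, the two ID lists could be used to filter both predictions and annotations before scoring each subset.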

In [3]:
jerome_pd_subset.pmcid.unique().shape

(264,)

## Search and extract across all documents

In [2]:
import openai
import pickle

with open('/home/zorro/.keys/open_ai.key') as f:
    openai.api_key = f.read().strip()
with open('data/all_embeddings.pkl', 'rb') as f:
    all_embeddings = pickle.load(f)

In [3]:
from extract import search_extract
from templates import ZERO_SHOT_MULTI_GROUP
from evaluate import evaluate_predictions, clean_predictions

query = 'How many participants or subjects were recruited for this study?' 
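`search_extract` ranks chunks against `query`, but its internals aren't shown here. Assuming it works like a typical embedding search (the query is embedded with the same model as the chunks in `all_embeddings`, and chunks are ranked by cosine similarity), the ranking step looks roughly like this. `rank_chunks` is a hypothetical helper operating on plain vectors; it does not call the OpenAI API.

```python
import numpy as np


def rank_chunks(query_embedding, chunk_embeddings, top_k=3):
    """Return (indices, similarities) of the top_k chunks, most similar first."""
    q = np.asarray(query_embedding, dtype=float)
    chunks = np.asarray(chunk_embeddings, dtype=float)
    # L2-normalise so that a dot product equals cosine similarity.
    q = q / np.linalg.norm(q)
    chunks = chunks / np.linalg.norm(chunks, axis=1, keepdims=True)
    sims = chunks @ q
    order = np.argsort(-sims)[:top_k]
    return order.tolist(), sims[order].tolist()


# Toy 2-d "embeddings"; real ones have the embedding model's dimensionality.
idx, sims = rank_chunks([1.0, 0.0], [[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]], top_k=2)
print(idx)  # [1, 2]
```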

### No heuristic

In [4]:
# predictions_full_search = search_extract(all_embeddings, query, **ZERO_SHOT_MULTI_GROUP, num_workers=3)
# predictions_full_search.to_csv('data/predictions_full_search.csv', index=False)
predictions_full_search = pd.read_csv('data/predictions_full_search.csv')
# Clean predictions
clean_preds = clean_predictions(predictions_full_search)
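The cleaning rules applied by `clean_predictions` aren't shown in this notebook. As an illustration only, here are hypothetical rules of the kind one might try, motivated by implausible values such as the `age mean` of 222.0 in the first row of the predictions below: coerce numeric fields, drop groups without a count, null out ages outside 0 to 120, and drop the male/female split when it exceeds the total count. `example_clean` and its thresholds are assumptions, not the real implementation.

```python
import pandas as pd

AGE_COLS = ["age mean", "age minimum", "age maximum", "age median"]
NUMERIC_COLS = ["count", "male count", "female count"] + AGE_COLS


def example_clean(preds: pd.DataFrame) -> pd.DataFrame:
    """Hypothetical cleaning rules; not necessarily what clean_predictions does."""
    out = preds.copy()
    for col in NUMERIC_COLS:
        if col in out.columns:
            # Non-numeric answers such as "not reported" become NaN.
            out[col] = pd.to_numeric(out[col], errors="coerce")

    # Rule 1: a group without a participant count is not a usable prediction.
    out = out.dropna(subset=["count"]).copy()

    # Rule 2: ages outside a plausible human range are treated as missing.
    for col in AGE_COLS:
        if col in out.columns:
            out.loc[(out[col] < 0) | (out[col] > 120), col] = float("nan")

    # Rule 3: if male + female exceeds the total count, drop the sex split.
    if {"male count", "female count"} <= set(out.columns):
        sex_total = out["male count"].fillna(0) + out["female count"].fillna(0)
        out.loc[sex_total > out["count"], ["male count", "female count"]] = float("nan")

    return out.reset_index(drop=True)


toy = pd.DataFrame({
    "count": ["23", None, "10"],
    "male count": [None, 3, 4],
    "female count": [5, None, 8],
    "age mean": [222.0, 30.0, 20.8],
})
cleaned = example_clean(toy)
print(cleaned)
```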

In [13]:
predictions_full_search

| Unnamed: 0 | count | diagnosis | group_name | subgroup_name | male count | female count | age mean | age range | age minimum | age maximum | age median | rank | start_char | end_char | pmcid |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 23.0 | healthy | healthy | | | 5.0 | 222.0 | | | | | 0 | 10707 | 11145 | 4330553 |
| 1 | 16.0 | healthy | healthy | | | | | | | | | 0 | 28053 | 28467 | 3183226 |
| 2 | 15.0 | healthy | undergraduate students | | | 5.0 | | 20-22 years | | | | 0 | 9041 | 9577 | 6989437 |
| 3 | 10.0 | | | | 4.0 | | 20.8 | | | | | 0 | 11679 | 12501 | 6528067 |
| 4 | 20.0 | healthy | healthy | | 12.0 | 8.0 | 24.4 | 19-35 | 19.0 | 35.0 | | 0 | 7853 | 8437 | 3147157 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 252 | 6.0 | PCH | patients | | | | | | | | | 0 | 8676 | 10310 | 9461104 |
| 253 | 102.0 | | | | 36.0 | 66.0 | 21.5 | 20 to 23 | 20.0 | 23.0 | | 1 | 8012 | 10968 | 3775427 |
| 254 | 228.0 | | | | | | | | | | | 0 | 27161 | 31007 | 6492297 |
| 255 | 1.0 | autism | patients | children | 1.0 | 0.0 | 8.0 | 8 | 8.0 | 8.0 | 8.0 | 0 | 31900 | 35502 | 9407088 |


In [None]:
import pprint
from publang.evaluate import compare_by_pmcid, score_columns

def evaluate_predictions(predictions, annotations):
    # Subset annotations to pmcids that have predictions; report the rest as missing
    pred_pmcids = predictions.pmcid.unique()
    diff = set(annotations.pmcid.unique()) - set(pred_pmcids)
    annotations = annotations[annotations.pmcid.isin(pred_pmcids)]

    # Number of groups (rows) per pmcid, aligned on pmcid: pandas cannot compare
    # Series whose indexes differ (e.g. pmcids predicted but never annotated)
    n_groups, pred_n_groups = annotations.groupby('pmcid').size().align(
        predictions.groupby('pmcid').size(), join='inner')
    correct_n_groups = (n_groups == pred_n_groups)
    more_groups_pred = (n_groups < pred_n_groups)
    less_groups_pred = (n_groups > pred_n_groups)

    ix_corr_n_groups = correct_n_groups[correct_n_groups].index
    ix_more_groups = more_groups_pred[more_groups_pred].index
    ix_less_groups = less_groups_pred[less_groups_pred].index

    print(f"Exact match # of groups: {correct_n_groups.mean():.2f}\n",
          f"More groups predicted: {more_groups_pred.mean():.2f}\n",
          f"Less groups predicted: {less_groups_pred.mean():.2f}\n",
          f"Missing pmcids: {diff}\n")

    # Compare by columns (matched accuracy)
    print("Column wise comparison of predictions and annotations (error):\n")
    pprint.pprint(compare_by_pmcid(annotations, predictions))

    res_mean, res_sum, counts = score_columns(annotations, predictions, scoring='mpe')

    # Compare by columns (count of pmcids with overlap)
    print("\nPercentage response given by pmcid:\n")
    pprint.pprint(counts)

    # Compare by columns (summed mean percentage error on sum by pmcid)
    print("\nSummed Mean percentage error:\n")
    pprint.pprint(res_mean)

    # Compare by columns (summed mean percentage error on mean by pmcid)
    print("\nAveraged Mean percentage error:\n")
    pprint.pprint(res_sum)

    return ix_corr_n_groups, ix_more_groups, ix_less_groups
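The group-count part of `evaluate_predictions` can be checked in isolation on toy data. This standalone sketch (hypothetical `compare_group_counts`, made-up pmcids) reproduces the exact/more/fewer comparison and shows why the two per-study counts are aligned on `pmcid` first: pandas raises a `ValueError` when comparing Series whose indexes differ.

```python
import pandas as pd


def compare_group_counts(pred_pmcids, annot_pmcids):
    """Share of studies with exactly as many / more / fewer predicted groups.

    Each input is a pmcid column with one entry per group. Only studies that
    appear in both inputs are compared.
    """
    pred_n = pd.Series(pred_pmcids).value_counts()
    annot_n = pd.Series(annot_pmcids).value_counts()
    # Align on pmcid; comparing Series with different indexes raises ValueError.
    annot_n, pred_n = annot_n.align(pred_n, join="inner")
    return {
        "exact": float((annot_n == pred_n).mean()),
        "more": float((annot_n < pred_n).mean()),
        "fewer": float((annot_n > pred_n).mean()),
    }


# Study 3 is over-split (3 predicted groups vs 1 annotated);
# study 4 has no prediction and study 5 has no annotation, so both are skipped.
res = compare_group_counts([1, 1, 2, 3, 3, 3, 5], [1, 1, 2, 3, 4])
print(res)
```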

In [5]:
ix_corr_n_groups, ix_more_groups, ix_less_groups = evaluate_predictions(clean_preds, jerome_pd_subset)

Exact match # of groups: 0.84
 More groups predicted: 0.09
 Less groups predicted: 0.07
 Missing pmcids: set()

Column wise comparison of predictions and annotations (error):

{'age maximum': 0.31,
 'age mean': 0.32,
 'age median': 0.21,
 'age minimum': 0.32,
 'count': 0.31,
 'diagnosis': 0.45,
 'female count': 0.37,
 'group_name': 0.21,
 'male count': 0.33,
 'subgroup_name': 0.98}

Percentage response given by pmcid:

{'age maximum': 0.7,
 'age mean': 0.75,
 'age median': 0.04,
 'age minimum': 0.7,
 'count': 1.0,
 'female count': 0.73,
 'male count': 0.78}

Summed Mean percentage error:

{'age maximum': 0.0,
 'age mean': 0.14,
 'age median': 0.0,
 'age minimum': 0.0,
 'count': 0.13,
 'female count': 0.06,
 'male count': 0.1}

Averaged Mean percentage error:

{'age maximum': 0.06,
 'age mean': 0.18,
 'age median': 0.0,
 'age minimum': 0.08,
 'count': 0.17,
 'female count': 0.09,
 'male count': 0.14}
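`score_columns(..., scoring='mpe')` isn't defined in this notebook, so the exact meaning of the "Summed" and "Averaged" scores above isn't shown. One plausible reading, suggested by the comments in `evaluate_predictions` ("on sum by pmcid" / "on mean by pmcid"), is: aggregate each numeric column per study by sum or by mean, compute the absolute percentage error against the annotation, then average over studies. The sketch below implements that reading only; `percentage_error_by_pmcid` is hypothetical and may not match `publang`. Under this reading, a study annotated as two groups of 10 but predicted as one group of 18 scores 10% on the summed count (18 vs 20) but 80% on the averaged count (18 vs 10).

```python
import pandas as pd


def percentage_error_by_pmcid(annotations, predictions, column, how="sum"):
    """Hypothetical: aggregate `column` per pmcid, then mean |pred - true| / |true|."""
    def aggregate(df):
        grouped = df.groupby("pmcid")[column]
        # min_count=1 keeps studies with no values as NaN rather than 0.
        return grouped.sum(min_count=1) if how == "sum" else grouped.mean()

    true, pred = aggregate(annotations).align(aggregate(predictions), join="inner")
    # Only score studies where both sides gave a value and the truth is non-zero.
    valid = true.notna() & pred.notna() & (true != 0)
    return float(((pred[valid] - true[valid]).abs() / true[valid].abs()).mean())


toy_annotations = pd.DataFrame({"pmcid": [1, 1, 2], "count": [10, 10, 30]})
toy_predictions = pd.DataFrame({"pmcid": [1, 2], "count": [18, 33]})
summed = percentage_error_by_pmcid(toy_annotations, toy_predictions, "count", how="sum")
averaged = percentage_error_by_pmcid(toy_annotations, toy_predictions, "count", how="mean")
print(summed, averaged)
```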


### Heuristic - methods
Uses a heuristic to find the Methods section, then runs full-text search with the query.
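The section heuristic itself isn't shown in this notebook. A minimal sketch, assuming sections come as `(header, text)` pairs and that matching is done on header keywords; the keyword lists and the `header_pattern`/`select_sections` helpers are hypothetical. An empty selection signals that the search should fall back to the whole document, and the same helper covers the demographics variant below with a different keyword list.

```python
import re
from typing import List, Tuple


def header_pattern(*keywords: str) -> "re.Pattern":
    """Match a section header that starts with one of `keywords`,
    optionally preceded by numbering such as '2.' or '2.1'."""
    alternatives = "|".join(re.escape(k) for k in keywords)
    return re.compile(rf"^\s*(?:\d+(?:\.\d+)*\.?\s*)?(?:{alternatives})\b", re.IGNORECASE)


# Hypothetical keyword lists for the two heuristic strategies.
METHODS = header_pattern("materials and methods", "methods", "method")
DEMOGRAPHICS = header_pattern("participants", "subjects", "demographics")


def select_sections(sections: List[Tuple[str, str]], pattern: "re.Pattern") -> List[str]:
    """Return the text of every section whose header matches `pattern`."""
    return [text for header, text in sections if pattern.search(header)]


sections = [
    ("Introduction", "..."),
    ("2. Materials and Methods", "Twenty healthy adults took part ..."),
    ("2.1 Participants", "Participants were aged 19-35 ..."),
    ("3. Results", "..."),
]
# An empty selection means the heuristic failed: fall back to the whole document.
candidates = select_sections(sections, METHODS) or [text for _, text in sections]
print(candidates)
```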

In [7]:
predictions_methods_fts = search_extract(all_embeddings, query, heuristic_strategy='methods', **ZERO_SHOT_MULTI_GROUP, num_workers=3)
predictions_methods_fts.to_csv('data/predictions_methods_fts.csv', index=False)
# predictions_methods_fts = pd.read_csv('data/predictions_methods_fts.csv')
# Clean predictions
predictions_methods_fts_clean = clean_predictions(predictions_methods_fts)

100%|█████████████████████████████████████████| 153/153 [02:51<00:00,  1.12s/it]


In [8]:
ix_corr_n_groups, ix_more_groups, ix_less_groups = evaluate_predictions(predictions_methods_fts_clean, jerome_pd_subset)

Exact match # of groups: 0.82
 More groups predicted: 0.11
 Less groups predicted: 0.07
 Missing pmcids: {5416685, 4352055}

Column wise comparison of predictions and annotations (error):

{'age maximum': 0.29,
 'age mean': 0.31,
 'age median': 0.17,
 'age minimum': 0.29,
 'count': 0.32,
 'diagnosis': 0.45,
 'female count': 0.37,
 'group_name': 0.22,
 'male count': 0.33,
 'subgroup_name': 0.99}

Percentage response given by pmcid:

{'age maximum': 0.71,
 'age mean': 0.77,
 'age median': 0.04,
 'age minimum': 0.72,
 'count': 1.0,
 'female count': 0.73,
 'male count': 0.77}

Summed Mean percentage error:

{'age maximum': 0.0,
 'age mean': 0.14,
 'age median': 0.0,
 'age minimum': 0.0,
 'count': 0.15,
 'female count': 0.06,
 'male count': 0.1}

Averaged Mean percentage error:

{'age maximum': 0.07,
 'age mean': 0.18,
 'age median': 0.0,
 'age minimum': 0.08,
 'count': 0.17,
 'female count': 0.09,
 'male count': 0.13}


### Heuristic - demographics section

In [9]:
predictions_demographics_fts = search_extract(all_embeddings, query, heuristic_strategy='demographics', **ZERO_SHOT_MULTI_GROUP, num_workers=3)
predictions_demographics_fts.to_csv('data/predictions_demographics_fts.csv', index=False)
# predictions_demographics_fts = pd.read_csv('data/predictions_demographics_fts.csv')
# Clean predictions
predictions_demographics_fts_clean = clean_predictions(predictions_demographics_fts)

100%|█████████████████████████████████████████| 153/153 [02:53<00:00,  1.13s/it]


In [10]:
ix_corr_n_groups, ix_more_groups, ix_less_groups = evaluate_predictions(predictions_demographics_fts_clean, jerome_pd_subset)

Exact match # of groups: 0.81
 More groups predicted: 0.12
 Less groups predicted: 0.07
 Missing pmcids: {5460048, 5416685, 4352055}

Column wise comparison of predictions and annotations (error):

{'age maximum': 0.3,
 'age mean': 0.32,
 'age median': 0.2,
 'age minimum': 0.31,
 'count': 0.32,
 'diagnosis': 0.48,
 'female count': 0.38,
 'group_name': 0.23,
 'male count': 0.34,
 'subgroup_name': 0.99}

Percentage response given by pmcid:

{'age maximum': 0.68,
 'age mean': 0.75,
 'age median': 0.04,
 'age minimum': 0.69,
 'count': 1.0,
 'female count': 0.72,
 'male count': 0.78}

Summed Mean percentage error:

{'age maximum': 0.0,
 'age mean': 0.14,
 'age median': 0.0,
 'age minimum': 0.01,
 'count': 0.15,
 'female count': 0.07,
 'male count': 0.1}

Averaged Mean percentage error:

{'age maximum': 0.04,
 'age mean': 0.17,
 'age median': 0.0,
 'age minimum': 0.06,
 'count': 0.23,
 'female count': 0.08,
 'male count': 0.13}


### GPT-4 - No heuristic

In [4]:
# predictions_gpt4 = search_extract(all_embeddings, query, **ZERO_SHOT_MULTI_GROUP, num_workers=1, model_name='gpt-4')
# predictions_gpt4.to_csv('data/predictions_gpt4.csv', index=False)
predictions_gpt4 = pd.read_csv('data/predictions_gpt4.csv')

# Clean predictions
predictions_gpt4 = clean_predictions(predictions_gpt4)

100%|█████████████████████████████████████████| 153/153 [34:50<00:00, 13.66s/it]


In [6]:
ix_corr_n_groups, ix_more_groups, ix_less_groups = evaluate_predictions(predictions_gpt4, jerome_pd_subset)

Exact match # of groups: 0.72
 More groups predicted: 0.23
 Less groups predicted: 0.05
 Missing pmcids: {6492297, 8978988, 8785614, 4352055}

Column wise comparison of predictions and annotations (error):

{'age maximum': 0.18,
 'age mean': 0.17,
 'age median': 0.04,
 'age minimum': 0.19,
 'count': 0.21,
 'diagnosis': 0.32,
 'female count': 0.2,
 'group_name': 0.09,
 'male count': 0.2,
 'subgroup_name': 0.95}

Percentage response given by pmcid:

{'age maximum': 0.75,
 'age mean': 0.92,
 'age median': 1.0,
 'age minimum': 0.76,
 'count': 1.0,
 'female count': 0.78,
 'male count': 0.79}

Summed Mean percentage error:

{'age maximum': 0.0,
 'age mean': 0.13,
 'age median': 0.0,
 'age minimum': 0.0,
 'count': 0.12,
 'female count': 0.06,
 'male count': 0.09}

Averaged Mean percentage error:

{'age maximum': 0.16,
 'age mean': 0.29,
 'age median': 0.0,
 'age minimum': 0.17,
 'count': 0.28,
 'female count': 0.25,
 'male count': 0.23}


In [None]:
pd.read_csv('data/pr

## Conclusion

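The heuristics don't improve things much (exact match on the number of groups is 0.84 with no heuristic, 0.82 with `methods` and 0.81 with `demographics`), but they do increase the chance that we find nothing at all when the information is in the Results section (2 and 3 pmcids missing respectively, versus none without a heuristic).

We could modify the heuristic to fall back to the entire document when extraction comes back null, but since it doesn't otherwise improve predictions much, I'm not sure it's worthwhile.

- GPT-4 does marginally better overall: its column-wise error is lower (e.g. `count` 0.21 vs 0.31, `diagnosis` 0.32 vs 0.45), but it predicts extra groups more often (0.23 vs 0.09), its averaged mean percentage error is higher (`count` 0.28 vs 0.17), and it misses 4 documents entirely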

### TODOs
- Refactor code into package

# Manual results revision