## Evaluating multi-group prompts

In [1]:
import pandas as pd
import numpy as np
from labelrepo.projects.participant_demographics import get_participant_demographics
from labelrepo.database import get_database_connection

# Participant-demographics annotations, requested with location info included.
subgroups = get_participant_demographics(include_locations=True)

# Id, publication year, title and text of every row in the `document` table.
document_query = "select pmcid, publication_year, title, text from document"
docs_info = pd.read_sql(document_query, get_database_connection())

In [2]:
# Keep only the participant_demographics annotations made by Jerome_Dockes
is_pd_project = subgroups.project_name == 'participant_demographics'
is_jerome = subgroups.annotator_name == 'Jerome_Dockes'
jerome_pd = subgroups[is_pd_project & is_jerome]

# Per-pmcid tally of non-null values in every column; the resulting 'count'
# column is the number of annotated rows with a non-null `count` for that pmcid
counts = jerome_pd.groupby('pmcid').count().reset_index()

# Documents annotated with more than one group
multi_group_pmcids = counts.loc[counts['count'] > 1, 'pmcid']
multi_group = jerome_pd[jerome_pd.pmcid.isin(multi_group_pmcids)]

# Document records (title, text, ...) for those multi-group pmcids
multi_group_docs = docs_info[docs_info.pmcid.isin(multi_group.pmcid)]

## Embed documents

In [105]:
import openai

# Read the key through a context manager so the file handle is closed
# immediately instead of being left open until garbage collection.
# NOTE(review): this is an absolute, user-specific path — consider reading the
# key from an environment variable so the notebook runs on other machines.
with open('/home/zorro/.keys/open_ai.key') as key_file:
    openai.api_key = key_file.read().strip()

In [185]:
from embed import embed_pmc_articles, query_embeddings
import pickle

# One {'pmcid': ..., 'text': ...} record per multi-group document
texts = multi_group_docs[['pmcid', 'text']].to_dict(orient='records')

# The embedding call is left commented out; the cached pickle is loaded instead.
# embeddings = embed_pmc_articles(texts)
# with open('data/multi_group_embeddings.pkl', 'wb') as cache_file:
#     pickle.dump(embeddings, cache_file)

# Context manager closes the cache file once it has been read.
# NOTE: pickle.load executes arbitrary code — only load files you created yourself.
with open('data/multi_group_embeddings.pkl', 'rb') as cache_file:
    embeddings = pickle.load(cache_file)

In [5]:
# Convert the unpickled embeddings object into a DataFrame (rebinds the same name)
embeddings = pd.DataFrame(embeddings)

## Extract group demographics across all docs 

In [287]:
# The extraction run is left commented out; cached predictions are loaded instead.
# NOTE(review): the import below previously named `extract_from_match` while every
# call uses `extract_on_match` — confirm the function name exported by `extract`.
# from extract import extract_on_match
# sections, predictions = extract_on_match(embeddings, multi_group, **ZERO_SHOT_MULTI_GROUP, num_workers=3)
# with open('data/full_text_gpt4_multi_group.pkl', 'wb') as cache_file:
#     pickle.dump(predictions, cache_file)

# Context manager closes the cache file once it has been read.
with open('data/full_text_gpt4_multi_group.pkl', 'rb') as cache_file:
    predictions = pickle.load(cache_file)

In [282]:
# Clean up predictions

def _clean_predictions(predictions):
    # Clean known issues with GPT predictions

    predicitons = predictions.copy()
    
    predictions = predictions.fillna(value=np.nan)
    predictions['group_name'] = predictions['group_name'].fillna('healthy')

    # If group name is healthy, blank out diagnosis
    predictions.loc[predictions.group_name == 'healthy', 'diagnosis'] = np.nan
    predictions = predictions.replace(0.0, np.nan)

    # Drop rows where count is NA
    predictions = predictions[~pd.isna(predictions['count'])]

    # Set group_name to healthy if no diagnosis

    predictions.loc[(predictions['group_name'] != 'healthy') & (pd.isna(predictions['diagnosis'])), 'group_name'] = 'healthy'

    return predictions

In [283]:
clean_preds = _clean_predictions(predictions)

# Evaluation

In [119]:
# Annotation columns retained for evaluation
subset_cols = [
    'count', 'diagnosis', 'group_name', 'subgroup_name',
    'male count', 'female count',
    'age mean', 'age minimum', 'age maximum', 'age median',
    'pmcid',
]

# Keep only annotated documents that also appear in the predictions, ordered by pmcid
has_prediction = multi_group.pmcid.isin(predictions.pmcid)
sub_multi_group = multi_group.loc[has_prediction, subset_cols].sort_values('pmcid')

In [207]:
from evaluate import evaluate_predictions

In [294]:
ix_corr_n_groups, ix_more_groups, ix_less_groups = evaluate_predictions(predictions, sub_multi_group)

Exact match # of groups: 0.86
 More groups predicted: 0.09
 Less groups predicted: 0.06

{'age maximum': 0.63,
 'age mean': 0.75,
 'age median': 0.77,
 'age minimum': 0.63,
 'count': 0.8,
 'diagnosis': 0.08,
 'female count': 0.65,
 'group_name': 0.65,
 'male count': 0.62,
 'subgroup_name': 0.03}


In [293]:
# With "cleaned up" predictions

_ = evaluate_predictions(clean_preds, sub_multi_group)

Exact match # of groups: 0.88
 More groups predicted: 0.06
 Less groups predicted: 0.06

{'age maximum': 0.71,
 'age mean': 0.74,
 'age median': 0.85,
 'age minimum': 0.71,
 'count': 0.8,
 'diagnosis': 0.52,
 'female count': 0.75,
 'group_name': 0.79,
 'male count': 0.67,
 'subgroup_name': 0.03}


### Try on single group

In [211]:
# Load the cached single-group embeddings; the context manager closes the file
# once it has been read.
with open('data/single_group_embeddings.pkl', 'rb') as cache_file:
    single_group_embeddings = pickle.load(cache_file)
single_group_embeddings = pd.DataFrame(single_group_embeddings)

# Documents whose annotation tally in `counts` is exactly one group
single_group_pmcids = counts[counts['count'] == 1].pmcid
single_group = jerome_pd[jerome_pd.pmcid.isin(single_group_pmcids)]

In [214]:
# Run the extraction on the single-group documents using the multi-group zero-shot settings.
# NOTE(review): `extract_on_match` and `ZERO_SHOT_MULTI_GROUP` are not imported or defined in
# any cell visible in this notebook (the only related import is commented out) — confirm where
# they come from so this cell runs on a fresh kernel.
single_sections, single_preds = extract_on_match(single_group_embeddings, single_group, **ZERO_SHOT_MULTI_GROUP, num_workers=3)

100%|███████████████████████████████████████████| 69/69 [00:47<00:00,  1.46it/s]


In [220]:
# Keep only annotated single-group documents that also appear in the predictions, ordered by pmcid
has_single_pred = single_group.pmcid.isin(single_preds.pmcid)
sub_single_group = single_group.loc[has_single_pred, subset_cols].sort_values('pmcid')

In [215]:
# pickle.dump(single_preds, open('data/full_text_gpt4_single_group_2.pkl', 'wb'))

In [290]:
clean_single_preds = _clean_predictions(single_preds)

In [225]:
ix_corr_n_groups_single, ix_more_groups_single, _= evaluate_predictions(single_preds, sub_single_group)

Exact match # of groups: 0.86
 More groups predicted: 0.14
 Less groups predicted: 0.00

{'age maximum': 0.78,
 'age mean': 0.72,
 'age median': 0.81,
 'age minimum': 0.78,
 'count': 0.74,
 'diagnosis': 0.06,
 'female count': 0.59,
 'group_name': 0.51,
 'male count': 0.64,
 'subgroup_name': 0.0}


In [292]:
# With "cleaned up" predictions
_= evaluate_predictions(clean_single_preds, sub_single_group)

Exact match # of groups: 0.88
 More groups predicted: 0.12
 Less groups predicted: 0.00

{'age maximum': 0.83,
 'age mean': 0.72,
 'age median': 0.87,
 'age minimum': 0.83,
 'count': 0.74,
 'diagnosis': 0.68,
 'female count': 0.64,
 'group_name': 0.78,
 'male count': 0.67,
 'subgroup_name': 0.0}


### Error types:

- Often, multiple rows get created for a single-group study: one each time a new "revised" sample size is mentioned (e.g. after exclusions)
- Twice, the human annotation had less detail (i.e. the total N instead of subgroup characteristics, which were combined into 1 group)

### Compare labels predicted vs given, on subset w/ matching group numbers

In [123]:
# Restrict predictions and annotations to the studies listed in ix_corr_n_groups
# (i.e. those with the same # of groups)
pred_in_corr = predictions.pmcid.isin(ix_corr_n_groups)
annot_in_corr = sub_multi_group.pmcid.isin(ix_corr_n_groups)
subset_pred_corrgroups = predictions[pred_in_corr]
subset_annot_corrgroups = sub_multi_group[annot_in_corr]

In [126]:
# Diagnoses for 'patients' groups, annotation (diagnosis_x) next to prediction (diagnosis_y),
# matched on pmcid and count
diag_cols = ['count', 'pmcid', 'diagnosis']
annot_diag = subset_annot_corrgroups.loc[subset_annot_corrgroups.group_name == 'patients', diag_cols]
pred_diag = subset_pred_corrgroups.loc[subset_pred_corrgroups.group_name == 'patients', diag_cols]
annot_diag.merge(pred_diag, on=['pmcid', 'count'])

Unnamed: 0,count,pmcid,diagnosis_x,diagnosis_y
0,15,2648877,Autism Spectrum Disorder,Autism Spectrum Disorder
1,10,3742334,chronic marijuana use,chronic marijuana (MJ) users
2,17,3877773,Tourette syndrome,TS
3,16,3984441,dementia with Lewy bodies,DLB
4,30,4190683,Attention Deficit Hyperactivity Disorder,ADHD
5,18,4265725,Autism spectrum conditions,ASC
6,32,4473263,chronic left-hemisphere stroke,chronic left-hemisphere stroke
7,11,4589842,Depression,depressed patients
8,42,5394595,bronchial asthma without acute attacks,bronchial asthma
9,34,5413198,Alcohol dependence,ADP


In [127]:
# Diagnoses for 'healthy' groups, annotation (diagnosis_x) next to prediction (diagnosis_y),
# matched on pmcid and count
diag_cols = ['count', 'pmcid', 'diagnosis']
annot_diag = subset_annot_corrgroups.loc[subset_annot_corrgroups.group_name == 'healthy', diag_cols]
pred_diag = subset_pred_corrgroups.loc[subset_pred_corrgroups.group_name == 'healthy', diag_cols]
annot_diag.merge(pred_diag, on=['pmcid', 'count'])

Unnamed: 0,count,pmcid,diagnosis_x,diagnosis_y
0,19,2561002,,
1,15,2561002,,
2,18,2648877,,non-autistic control
3,24,3502502,,healthy
4,44,3511796,,healthy
5,18,3742334,,nonusing (NU) comparison subjects
6,15,3877773,,HC
7,17,3984441,,controls
8,11,4589842,,healthy volunteers
9,35,5292583,,


### Other groups

In [182]:
# Predicted rows whose group_name is neither 'healthy' nor 'patients'
is_standard_group = subset_pred_corrgroups.group_name.isin(['healthy', 'patients'])
subset_pred_corrgroups.loc[~is_standard_group, ['count', 'pmcid', 'diagnosis', 'group_name']]


Unnamed: 0,count,pmcid,diagnosis,group_name
12,23.0,8933759,healthy adults,controls
48,20.0,7485713,,controls
53,20.0,6820536,fibromyalgia,FM
54,20.0,6820536,,HC
62,24.0,3965851,,participants
63,18.0,3965851,,participants
81,260.0,3991323,,adolescent
82,28.0,3991323,,adult
83,164.0,4174863,,incarcerated
84,46.0,4174863,,non-incarcerated


### Error types
- Will often fail to set name to `healthy` for controls.
   - Can probably come up with a whitelist of terms that can be translated to `healthy`:
      - ['control', 'volunteers', 'subjects', 'participants', 'adults'] + if diagnosis == `n/a` 

## Look at example of failures

### When model predicted more groups

In [145]:
sub_multi_group[sub_multi_group.pmcid.isin(ix_more_groups)].sort_values('pmcid')

Unnamed: 0,count,diagnosis,group_name,subgroup_name,male count,female count,age mean,age minimum,age maximum,age median,pmcid
130,16,autism spectrum disorder,patients,_,16.0,,,8.0,18.0,,5279905
131,16,,healthy,_,,,,,,,5279905
136,19,unipolar major depression,patients,validation study,,,,18.0,65.0,,5416685
135,20,unipolar major depression,patients,primary study,,,,18.0,65.0,,5416685
138,19,,healthy,validation study,,,,18.0,65.0,,5416685
137,19,,healthy,primary study,,,,18.0,65.0,,5416685
90,15,,healthy,_,,,,60.0,70.0,,5665859
89,72,ischemic stroke,patients,_,,,,60.0,70.0,,5665859
20,39,,healthy,tsd,22.0,17.0,33.5,,,,7260173
21,15,,healthy,control,7.0,8.0,34.5,,,,7260173


In [146]:
predictions[predictions.pmcid.isin(ix_more_groups)].sort_values('pmcid')

Unnamed: 0,count,diagnosis,group_name,subgroup_name,male count,female count,age mean,age range,age minimum,age maximum,age median,pmcid
132,16.0,TD,healthy,,16.0,,,,,,,5279905
131,16.0,ASD,patients,,16.0,,13.0,8-18,8.0,18.0,13.0,5279905
130,30.0,healthy,healthy,children,15.0,15.0,9.0,6-13,6.0,13.0,9.0,5279905
129,20.0,ASD,patients,children,10.0,10.0,8.0,5-12,5.0,12.0,8.0,5279905
22,20.0,,healthy,validation sample,,,,,,,,5416685
18,,,,,,,,,,,,5416685
19,20.0,unipolar major depression,patients,primary study,,,,,,,,5416685
20,19.0,unipolar major depression,patients,validation study,,,,,,,,5416685
21,19.0,,healthy,primary sample,,,,,,,,5416685
74,15.0,healthy subjects,healthy,,,,,,,,,5665859


In [164]:
# Show the `content` of the first `sections` row for pmcid 9230060.
# NOTE(review): `sections` is only assigned in a commented-out line in the cells visible
# here — confirm it is defined when running on a fresh kernel.
sections[sections.pmcid == 9230060].iloc[0].content

'\n## 2. Materials and Methods \n  \nThis prospective study was approved by the local institutional review board of Kaohsiung Veterans General Hospital (protocol number: VGHKS93-CT2-09). This study enrolled 20 right-handed male offenders who were further separated into 3 sub-groups by their records of court verdict: six offenders had committed affective violence (VA group), seven offenders had committed predatory violence (VP group), and seven had committed non-violent crime (NV group). The VA offenders were defined as the subjects who impulsively or affectively committed violent crime. The VP offenders were defined as the subjects who purposely planned to commit violent crime with detailed documentation in the court records. The NV offenders were defined as the subjects who committed non-violent crime based on the documents in court record. In addition, 20 age-matched right-handed male non-criminal healthy controls (HC group) were enrolled for comparisons. All subjects underwent psych

### Error types
- Sometimes there is repeated information (for example, in Methods and Results). It is not clear which section to "trust". In addition, there is confusion about the total N.
- It seems to create a group if given the introduction, but it often lacks a count, which means we can probably drop those observations.
- Final N and exclusion criteria always seem to confuse the models. In the case of the multi-group prompt, instead of choosing one, it will often invent another group -- how to handle conflicting information?
- Also: it OFTEN invents group_names other than `patients` or `healthy`. I think we can assume that if `diagnosis == NaN` or `healthy` then group_name can be corrected to `healthy`, and if there is a diagnosis then group_name should be `patients`. However, there will be rare times when something meaning healthy is given as the diagnosis, and the group name is something like 'control'.


Ideas for finding the correct section:
- If can find "Methods" in body--- use that and search within that, or sections that follow.
- Also, should save subsection prior to splitting up-- save all headers to be able to select methods section
- If not, then search over entire document.
- Iterate until some information is generated, and iterate if there's no hits


### When model predicted fewer groups

In [147]:
sub_multi_group[sub_multi_group.pmcid.isin(ix_less_groups)].sort_values('pmcid')

Unnamed: 0,count,diagnosis,group_name,subgroup_name,male count,female count,age mean,age minimum,age maximum,age median,pmcid
57,20,,healthy,dataset 1,,,,,,,4352055
58,19,,healthy,dataset 2,,,,,,,4352055
205,19,,healthy,work with children,9.0,10.0,,,,,7395771
206,19,,healthy,no work with children,9.0,10.0,,,,,7395771
101,88,attention/deficit hyperactivity disorder,patients,_,,,,,,,8782893
102,79,,healthy,_,,,,,,,8782893
163,573,,healthy,abide 1,,,,,,,9407088
164,593,,healthy,abide 2,,,,,,,9407088
161,539,autism spectrum disorder,patients,abide 1,,,,,,,9407088
162,521,autism spectrum disorder,patients,abide 2,,,,,,,9407088


In [148]:
predictions[predictions.pmcid.isin(ix_less_groups)].sort_values('pmcid')

Unnamed: 0,count,diagnosis,group_name,subgroup_name,male count,female count,age mean,age range,age minimum,age maximum,age median,pmcid
122,20.0,,participants,,,,,,,,,4352055
64,38.0,healthy,adults,childless,19.0,19.0,24.08,3.33,19.0,29.0,24.0,7395771
9,167.0,ADHD,children and adolescents,,,,,,,,,8782893
3,539.0,ASD patients,patients,ASD patients,300.0,239.0,25.0,18-40,18.0,40.0,26.0,9407088
4,573.0,normal controls,healthy,normal controls,300.0,273.0,28.0,20-45,20.0,45.0,29.0,9407088


In [173]:
# Show the `content` of the first `sections` row for pmcid 9407088.
# NOTE(review): `sections` is only assigned in a commented-out line in the cells visible
# here — confirm it is defined when running on a fresh kernel.
sections[sections.pmcid == 9407088].iloc[0].content

'\n## 3. Experiments \n  \n### 3.1. Data \n  \nThe data used in this article come from the Autism Brain Imaging Data Exchange (ABIDE) project [ ]. This project aims to accelerate the understanding of the deep brain mechanism of autism spectrum disorder (ASD), which integrates the brain structure and functional imaging data from many laboratories around the world. The abide I and abide II datasets were used. The abide I data were collected from 17 centers, including 1112 subjects, including 539 ASD patients and 573 normal controls; the abide II dataset was collected from 19 centers, including 1114 subjects, including 521 ASD patients and 593 normal controls. \n\nTo construct a 3D custom fMRI template using fMRI time-series data, we first need to preprocess the data. The basic idea of fMRI data preprocessing is to eliminate the timing error of the interlayer scanning and the head movement error caused by the subject’s head movement in the scanning process, and then carry out the time-lay

### Error types
- Combining two groups into 1 (not necessarily wrong)-- happened 2x, usually combining control group w/ patient group
- Missed a dataset (abide)

Overall, not much we can do about this