## Evaluating multi-group prompts

In [1]:
import pandas as pd
import numpy as np
from labelrepo.projects.participant_demographics import get_participant_demographics
from labelrepo.database import get_database_connection

# Participant-demographics annotations (including their character locations),
# plus the full text and metadata of every document in the database.
subgroups = get_participant_demographics(include_locations=True)
docs_info = pd.read_sql(
    sql="select pmcid, publication_year, title, text from document",
    con=get_database_connection(),
)

In [2]:
# Load multi group as well: documents for which Jerome_Dockes annotated more
# than one participant group in the participant_demographics project.
jerome_pd = subgroups[
    (subgroups.project_name == 'participant_demographics')
    & (subgroups.annotator_name == 'Jerome_Dockes')
]

# Number of annotated group rows per document. `.size()` counts rows, so a group
# whose 'count' field is missing still counts towards the total (`.count()` on
# the 'count' column would skip it).
counts = jerome_pd.groupby('pmcid').size().reset_index(name='count')
multi_group_pmcids = counts[counts['count'] > 1].pmcid
multi_group = jerome_pd[jerome_pd.pmcid.isin(multi_group_pmcids)]

# Get multi group docs
multi_group_docs = docs_info[docs_info.pmcid.isin(multi_group.pmcid)]

## Embed documents

In [3]:
import openai
# Prefer the OPENAI_API_KEY environment variable; otherwise read the key file
# from the current user's home directory (no hard-coded absolute path, and the
# file is closed after reading).
import os
from pathlib import Path

openai.api_key = (
    os.environ.get('OPENAI_API_KEY')
    or (Path.home() / '.keys' / 'open_ai.key').read_text()
).strip()

In [3]:
from embed import embed_pmc_articles, query_embeddings
# One {'pmcid': ..., 'text': ...} record per multi-group document: the input
# consumed by `embed_pmc_articles(texts)`. That embedding call is not re-run
# here; its result is cached in data/multi_group_embeddings.pkl.
texts = multi_group_docs.loc[:, ['pmcid', 'text']].to_dict('records')

In [4]:
import pickle
# Embeddings were computed once with `embed_pmc_articles(texts)` and pickled to
# this file; load that cache rather than calling the embedding API again.
# NOTE: pickle.load can execute arbitrary code — only load this trusted local cache.
with open('data/multi_group_embeddings.pkl', 'rb') as embeddings_file:
    embeddings = pickle.load(embeddings_file)

In [5]:
# Tabulate the embedded chunks; later cells read the pmcid, section_name,
# start_char, end_char and content columns of this frame.
embeddings = pd.DataFrame(data=embeddings)

## Try various prompts on an example document

In [6]:
from templates import ZERO_SHOT_MULTI_GROUP
from extract import extract_from_text

In [7]:
# Ground-truth group annotations for a single example document
example_id = 8883821
target_cols = [
    'group_name', 'diagnosis', 'count', 'male count', 'age mean',
    'female count', 'age minimum', 'age maximum',
]
target = multi_group.loc[multi_group.pmcid == example_id, target_cols]
target

Unnamed: 0,group_name,diagnosis,count,male count,age mean,female count,age minimum,age maximum
250,patients,schizophrenia,28,20.0,,8.0,21.0,54.0
251,patients,autism spectrum disorder,20,16.0,,4.0,19.0,43.0
252,healthy,,30,22.0,,8.0,19.0,54.0


In [8]:
# Chunk at position 6 of the example document; judging from the displayed text
# it holds the "Materials and Methods / Participants" paragraph.
example_chunks = embeddings.loc[embeddings['pmcid'] == example_id]
section = example_chunks.iloc[6]['content']
section

'\n## Materials and Methods \n  \n### Participants \n  \nSeventy-eight participants took part, including 28 participants (age range 21–54 years) diagnosed with schizophrenia (Sz) using the Structured Clinical Interview for DSM-IV ( ), 20 adults with autism spectrum disorder (ASD) (age range 19–43 years), confirmed by the Autism Diagnostic Observation Schedule, Second Edition, and 30 neurotypical controls (age range 19–54 years) ( ). All Sz participants were on a stable dose of antipsychotic medication. All participants had at least 20/22 corrected visual acuity on a Logarithmic Visual Acuity Chart. On average, Sz participants were older [  F  (1, 56) = 7.24,   p   = 0.009] and had lower IQ scores [  F  (1, 56) = 6.54,   p   = 0.013] than controls. All ASD participants and a subset of 19 Sz and 17 controls participated in our previous EEG/fMRI study of visual sensory dysfunction as reported in  , which did not include data from the present paradigm. Participants were recruited from the 

In [9]:
#extract_from_text(section, **ZERO_SHOT_MULTI_GROUP)

## Extract group demographics across all docs 

In [10]:
from extract import extract_from_multiple
def extract_on_match(embeddings_df, annotations_df, messages, parameters, model_name="gpt-3.5-turbo", num_workers=1):
    """Run group-demographics extraction on the chunk that holds an annotation.

    For each PMCID, selects the first 'Body' chunk of ``embeddings_df`` whose
    character span fully contains one of that document's annotated spans, then
    sends the selected chunks to ``extract_from_multiple``. Restricting the
    input to a chunk known to contain the answer focuses the evaluation on
    zero-shot extraction rather than retrieval.

    Args:
        embeddings_df: chunk table with ``pmcid``, ``section_name``,
            ``start_char``, ``end_char`` and ``content`` columns.
        annotations_df: annotation table with ``pmcid``, ``start_char`` and
            ``end_char`` columns. The two span columns of the first row per
            PMCID are zipped together (presumably list-valued — TODO confirm).
        messages, parameters: prompt template forwarded to ``extract_from_multiple``.
        model_name: model name forwarded to ``extract_from_multiple``.
        num_workers: worker count forwarded to ``extract_from_multiple``.

    Returns:
        (sections, pred_groups_df): the selected chunks (at most one per PMCID),
        and one row per predicted group with the chunk's ``pmcid`` added.
        Both are empty when no chunk contains an annotation.
    """
    body_df = embeddings_df[embeddings_df.section_name == 'Body']

    selected = []
    for pmcid, chunks in body_df.groupby('pmcid'):
        doc_annots = annotations_df[annotations_df.pmcid == pmcid]
        if doc_annots.empty:
            # No annotation to anchor a chunk on for this document
            continue
        annot = doc_annots.iloc[0]

        for start, end in zip(annot['start_char'], annot['end_char']):
            match = chunks[(chunks.start_char <= start) & (chunks.end_char >= end)]
            if not match.empty:
                # Overlapping chunks may all contain the span; keep only the first
                selected.append(match.iloc[[0]])
                break

    if not selected:
        # pd.concat([]) would raise; also avoids an empty API round-trip
        return body_df.iloc[:0], pd.DataFrame()
    sections = pd.concat(selected)

    results = extract_from_multiple(sections.content.to_list(), messages, parameters,
                                    model_name=model_name, num_workers=num_workers, return_type='list')

    # Flatten per-chunk results into one row per group, tagged with the chunk's pmcid
    pred_groups = []
    for pmcid, result in zip(sections['pmcid'], results):
        for group in result['groups']:
            pred_groups.append({**group, 'pmcid': pmcid})
    pred_groups_df = pd.DataFrame(pred_groups)

    return sections, pred_groups_df

In [11]:
# sections, predictions = extract_on_match(embeddings, multi_group, **ZERO_SHOT_MULTI_GROUP, num_workers=3)

In [12]:
# Predicted groups were produced once by `extract_on_match` (GPT-4, according to
# the file name) and pickled; load that cache rather than re-querying the API.
# NOTE: pickle.load can execute arbitrary code — only load this trusted local cache.
with open('data/full_text_gpt4_multi_group.pkl', 'rb') as predictions_file:
    predictions = pickle.load(predictions_file)

In [15]:
# Clean up predictions: normalise missing values to NaN, label groups without a
# name as 'healthy', and treat 0.0 as missing.
# NOTE(review): the last step also turns genuine zeros (e.g. a 'male count' of 0)
# into NaN — confirm that is intended.
predictions = (
    predictions
    .fillna(value=np.nan)
    .assign(group_name=lambda frame: frame['group_name'].fillna('healthy'))
    .replace(0.0, np.nan)
)

#### Evaluation

In [16]:
# Restrict annotations to studies that received predictions (i.e. studies with a
# Body chunk containing an annotation), keeping only the demographic columns.
subset_cols = [
    'count', 'diagnosis', 'group_name', 'subgroup_name', 'male count',
    'female count', 'age mean', 'age minimum', 'age maximum',
    'age median', 'pmcid',
]
has_prediction = multi_group['pmcid'].isin(predictions['pmcid'])
sub_multi_group = multi_group.loc[has_prediction, subset_cols].sort_values('pmcid')

In [17]:
# How accurately does it predict the # of groups for each study?
n_groups = sub_multi_group.groupby('pmcid').size()
# Align on the annotated studies explicitly: comparing Series with different
# indexes raises, and a study with no predicted group should count as 0 groups.
pred_n_groups = predictions.groupby('pmcid').size().reindex(n_groups.index, fill_value=0)
correct_n_groups = n_groups.eq(pred_n_groups)
ix_corr_n_groups = correct_n_groups[correct_n_groups].index
correct_n_groups.mean()

0.8695652173913043

#### Within the studies with the correct number of groups

In [18]:
# Predictions for studies whose predicted number of groups matches the annotations
subset_pred_corrgroups = predictions.loc[predictions['pmcid'].isin(ix_corr_n_groups)]

In [29]:
# (rows, columns) of the predictions restricted to studies with the correct number of groups
subset_pred_corrgroups.shape

(132, 12)

In [19]:
# Annotations for the same studies (correct predicted number of groups)
subset_annot_corrgroups = sub_multi_group.loc[sub_multi_group['pmcid'].isin(ix_corr_n_groups)]

### Overlap between group values within each PMCID

In [20]:
from collections import defaultdict

def compare(a, b):
    """ NaN-safe equality for scalars: True if both values are missing, else ``a == b`` """
    both_missing = pd.isna(a) & pd.isna(b)
    if both_missing:
        return both_missing
    return a == b

def isin(a, li):
    """ NaN-safe membership test that returns the matching element.

    Returns the first element ``b`` of ``li`` with ``compare(a, b)`` true, or
    ``False`` when there is none. Caveat: a falsy match (0, None, '', False) is
    indistinguishable from "no match" for callers that test truthiness.
    """
    return next((b for b in li if compare(a, b)), False)

def compare_by_pmcid(df1, df2):
    """ Compute the fraction of matched values for each column, matching within each pmcid.

    For every PMCID in ``df1`` and every column other than ``pmcid``, each value
    of ``df1`` is greedily paired with an equal (NaN-safe, via ``compare``),
    not-yet-used value of ``df2`` for the same PMCID.

    Returns:
        dict mapping column name -> number of matched ``df1`` values divided by
        ``len(df1)``, rounded to 2 decimals.
    """
    res = defaultdict(float)
    for pmcid, df in df1.groupby('pmcid'):
        # Same for every column, so compute once per document
        df2_pmcid = df2[df2.pmcid == pmcid]
        for col in df:
            if col == 'pmcid':
                continue
            candidates = df2_pmcid[col].to_list()
            score = 0
            for v in df[col]:
                # Locate the match by position: a matched value that is falsy
                # (0, None, '', False) must still count as a match.
                ix = next((i for i, b in enumerate(candidates) if compare(v, b)), None)
                if ix is not None:
                    score += 1
                    del candidates[ix]
            res[col] += score

    return {k: np.round(v / len(df1), 2) for k, v in res.items()}

In [21]:
# Number of annotated groups in the studies that received predictions
sub_multi_group.shape[0]

153

In [22]:
# Fraction of annotated values matched by a prediction (matching within each
# PMCID), only for studies where the predicted number of groups is correct
compare_by_pmcid(df1=subset_annot_corrgroups, df2=subset_pred_corrgroups)

{'count': 0.86,
 'diagnosis': 0.27,
 'group_name': 0.76,
 'subgroup_name': 0.03,
 'male count': 0.73,
 'female count': 0.82,
 'age mean': 0.8,
 'age minimum': 0.72,
 'age maximum': 0.73,
 'age median': 0.86}

In [23]:
# Same fraction, over every study that received predictions (group count right or wrong)
compare_by_pmcid(df1=sub_multi_group, df2=predictions)

{'count': 0.78,
 'diagnosis': 0.27,
 'group_name': 0.69,
 'subgroup_name': 0.03,
 'male count': 0.66,
 'female count': 0.75,
 'age mean': 0.75,
 'age minimum': 0.68,
 'age maximum': 0.69,
 'age median': 0.8}

### Compare predicted vs. annotated diagnosis labels, on the subset with a matching number of groups

In [40]:
# Pair annotated and predicted 'patients' groups by (pmcid, count) to compare diagnosis labels
diag_cols = ['count', 'pmcid', 'diagnosis']
annot_diag = subset_annot_corrgroups.loc[subset_annot_corrgroups['group_name'] == 'patients', diag_cols]
pred_diag = subset_pred_corrgroups.loc[subset_pred_corrgroups['group_name'] == 'patients', diag_cols]
annot_diag.merge(pred_diag, on=['pmcid', 'count'])

Unnamed: 0,count,pmcid,diagnosis_x,diagnosis_y
0,15,2648877,Autism Spectrum Disorder,Autism Spectrum Disorder
1,10,3742334,chronic marijuana use,chronic marijuana (MJ) users
2,17,3877773,Tourette syndrome,TS
3,16,3984441,dementia with Lewy bodies,DLB
4,30,4190683,Attention Deficit Hyperactivity Disorder,ADHD
5,18,4265725,Autism spectrum conditions,ASC
6,32,4473263,chronic left-hemisphere stroke,chronic left-hemisphere stroke
7,11,4589842,Depression,depressed patients
8,16,5279905,autism spectrum disorder,ASD
9,42,5394595,bronchial asthma without acute attacks,bronchial asthma


In [42]:
# Pair annotated and predicted 'healthy' groups by (pmcid, count) to compare diagnosis labels
diag_cols = ['count', 'pmcid', 'diagnosis']
annot_diag = subset_annot_corrgroups.loc[subset_annot_corrgroups['group_name'] == 'healthy', diag_cols]
pred_diag = subset_pred_corrgroups.loc[subset_pred_corrgroups['group_name'] == 'healthy', diag_cols]
annot_diag.merge(pred_diag, on=['pmcid', 'count'])

Unnamed: 0,count,pmcid,diagnosis_x,diagnosis_y
0,19,2561002,,
1,15,2561002,,
2,18,2648877,,non-autistic control
3,24,3502502,,healthy
4,44,3511796,,healthy
5,18,3742334,,nonusing (NU) comparison subjects
6,15,3877773,,HC
7,17,3984441,,controls
8,11,4589842,,healthy volunteers
9,16,5279905,,TD


In [34]:
# (rows, columns) of the annotated 'patients' groups
sub_multi_group.loc[sub_multi_group['group_name'] == 'patients'].shape

(58, 11)

In [37]:
# Which group labels occur besides 'patients'? (the saved output shows only 'healthy')
sub_multi_group.loc[sub_multi_group['group_name'] != 'patients', 'group_name'].unique()

array(['healthy'], dtype=object)