In [1]:
'''
The prediction files are stored in a nested folder structure based on the dataset, 
GPT model version, and the shot setting. The example input structure is:

data/
├── NCBI_Disease/
│   ├── GPT3.5/
│   │   └── NCBI_Disease_gpt3.5_5s.json
│   └── GPT4/
│       └── NCBI_Disease_gpt4_5s.json
└── BC5CDR_Chemical/
    ├── GPT3.5/
    │   └── BC5CDR_Chemical_gpt3.5_5s.json
    └── GPT4/
        └── BC5CDR_Chemical_gpt4_5s.json
        
The example output structure is:
output/
├── NCBI_Disease_gpt3.5_5s_gold_span.txt
├── NCBI_Disease_gpt3.5_5s_pre_span.txt
└── ...

'''

'\nThe prediction files are stored in a nested folder structure based on the dataset, \nGPT model version, and the shot setting. The expected directory structure is:\n\ndata/\n├── NCBI_Disease/\n│   ├── GPT3.5/\n│   │   └── NCBI_Disease_gpt3.5_5s.json\n│   └── GPT4/\n│       └── NCBI_Disease_gpt4_5s.json\n└── BC5CDR_Chemical/\n    ├── GPT3.5/\n    │   └── BC5CDR_Chemical_gpt3.5_5s.json\n    └── GPT4/\n        └── BC5CDR_Chemical_gpt4_5s.json\n        \nThe expected output is\noutput/\n├── NCBI_Disease_gpt3.5_5s_gold_span.txt\n├── NCBI_Disease_gpt3.5_5s_pre_span.txt\n└── ...\n\n'

In [2]:
import json
import os

In [3]:
# Working directories used by the rest of the notebook:
#   base_converted_dir - per-sentence gold/pred .bio files
#   base_output_dir    - aggregated span lists per dataset/model/shot
base_converted_dir = './converted_bio_output'
base_output_dir = './output'

# exist_ok=True keeps this cell idempotent under "Restart & Run All" and avoids
# the check-then-create race of a separate os.path.exists() test.
os.makedirs(base_converted_dir, exist_ok=True)
os.makedirs(base_output_dir, exist_ok=True)

In [4]:
def evaluate(label_file, eval_score_file='', sep_tag_type='tab', eval_type='ner', exact=False, labels=None):
    """Load gold and predicted tag sequences and return their spans.

    Parameters
    ----------
    label_file : str or sequence of two paths
        A single path is handed to ``load_combined_bio``; a two-element
        sequence is treated as ``(gold_path, pred_path)`` and each file is
        read with ``load_bio``.
    eval_score_file : str
        Unused here; kept so existing call sites keep working.
    sep_tag_type : str
        ``'space'`` selects a single space as column separator, any other
        value selects a tab.
    eval_type : str
        Forwarded to ``load_spans`` (``'ner'`` or ``'classification'``).
    exact : bool
        Unused here; kept so existing call sites keep working.
    labels : str, optional
        Comma-separated label names. When given, only these labels are
        considered, in the given order. When omitted, every label that has a
        span in either the gold or the predicted sequence is considered, in
        alphabetical order.

    Returns
    -------
    tuple of (list, list) or None
        ``(gold_span, predict_span)`` for the last considered label that
        occurs in gold or prediction (two empty lists if none occurs).
        ``None`` if ``label_file`` has an unsupported shape or loading fails.
    """
    sep_tag = ' ' if sep_tag_type == 'space' else '\t'
    try:
        if isinstance(label_file, str):
            # NOTE(review): load_combined_bio is not defined in the visible
            # part of this notebook - confirm it exists before using this path.
            gold, predict = load_combined_bio(label_file, sep_tag)
        elif len(label_file) == 2:
            gold_label_file, pred_label_file = label_file
            gold = load_bio(gold_label_file, sep_tag)
            predict = load_bio(pred_label_file, sep_tag)
        else:
            return
    except Exception as e:
        # Still a best-effort None return, but say why instead of hiding it.
        print('Warning: could not load labels from {}: {!r}'.format(label_file, e))
        return

    gold_spans = load_spans(gold, eval_type)
    predict_spans = load_spans(predict, eval_type)

    if not labels:
        # Order by the span keys themselves (not by slicing raw tags) so that
        # DB-/HB- tags and classification labels are handled, and include
        # labels that were only predicted so false positives are not dropped.
        label_order = sorted(set(gold_spans) | set(predict_spans))
    else:
        # dict.fromkeys de-duplicates while keeping first-occurrence order.
        label_order = list(dict.fromkeys(
            item.strip() for item in labels.split(',') if item.strip()
        ))

    gold_span = []
    predict_span = []
    for label in label_order:
        if label not in gold_spans and label not in predict_spans:
            continue
        gold_span = gold_spans.get(label, [])
        predict_span = predict_spans.get(label, [])
    return gold_span, predict_span

def load_bio(result_file, sep_tag='\t'):
    """Read the tag column of a BIO file into a flat list.

    The tag is whatever follows the last ``sep_tag`` on a line, with newline
    characters stripped from its ends. Whitespace-only lines are recorded as
    an ``'O'`` tag rather than skipped.

    Disjoint-entity tags (e.g. ``B-DDisease_Disorder``) are not normalised
    for now; they are returned exactly as written in the file.
    """
    with open(result_file) as handle:
        tags = [
            'O' if not raw.strip() else raw.rsplit(sep_tag, 1)[-1].strip('\n')
            for raw in handle
        ]
    return tags

# load spans: support eval_type as 'ner' or 'classification'
def load_spans( labels, eval_type='ner' ):
    """Group a tag sequence into spans keyed by label.

    Parameters
    ----------
    labels : sequence of str
        One tag per position.
    eval_type : str
        ``'classification'``: every position ``i`` becomes the span
        ``(i, i + 1)`` under its own tag.
        Anything else (default ``'ner'``): a span starts at a ``B-<sem>``,
        ``DB-<sem>`` or ``HB-<sem>`` tag and extends over the directly
        following ``I-<sem>``, ``DI-<sem>`` or ``HI-<sem>`` tags respectively.

    Returns
    -------
    dict
        Maps each label (``<sem>`` without its prefix in NER mode) to a list
        of ``(start, end)`` tuples, ``end`` exclusive, in order of appearance.
    """
    # (begin-prefix, inside-prefix) pairs; a tag can match at most one of them.
    prefix_pairs = (('B-', 'I-'), ('DB-', 'DI-'), ('HB-', 'HI-'))
    total = len(labels)
    spans = {}
    for start, label in enumerate(labels):
        if eval_type == 'classification':
            spans.setdefault(label, []).append((start, start + 1))
            continue

        if isinstance(label, list):
            # Only a warning: the startswith() call below still fails for a list.
            print('Warning: label is list {}'.format(label))

        for begin, inside in prefix_pairs:
            if not label.startswith(begin):
                continue
            sem = label[len(begin):]
            # Advance past every continuation tag. Using the sequence length as
            # the bound keeps the end exclusive even when the entity runs to
            # the very last position (previously it came out one too short).
            end = start + 1
            while end < total and labels[end] == inside + sem:
                end += 1
            spans.setdefault(sem, []).append((start, end))
            break

    return spans

In [5]:
# For every dataset / shot setting / GPT version: write each sentence as a
# gold and a predicted .bio file, extract the spans of both with evaluate(),
# and dump the per-sentence span lists to ./output.

# Entity type used to expand the bare 'B' / 'I' tags of each dataset.
ENTITY_TYPE = {'NCBI_Disease': 'disease', 'BC5CDR_Chemical': 'chemical'}


def _to_bio_lines(tokens, tags, entity_type):
    """Return 'token<TAB>tag' lines, expanding bare B/I tags to B-/I-<entity_type>.

    Tokens and tags are paired with zip(), so the longer of the two is
    silently truncated. Tags other than 'B' and 'I' are written unchanged.
    """
    lines = []
    for token, tag in zip(tokens, tags):
        if entity_type and tag in ('B', 'I'):
            tag = tag + '-' + entity_type
        lines.append(token + '\t' + tag + '\n')
    return lines


for dataset in ['NCBI_Disease', 'BC5CDR_Chemical']:
    for shot in ['5s']:
        for gpt in ['3.5', '4']:
            run_name = f'{dataset}_gpt{gpt}_{shot}'
            # Context manager so the prediction file is closed right after parsing.
            with open(f'./data/{dataset}/GPT{gpt}/{run_name}.json', encoding='utf-8') as prediction_file:
                prediction_dict = json.load(prediction_file)

            # Portable replacement for `!mkdir`; does not complain on re-runs.
            run_dir = f'./converted_bio_output/{run_name}'
            os.makedirs(run_dir, exist_ok=True)

            gold_spans = []
            pre_spans = []

            for index, instance in enumerate(prediction_dict):
                f1 = f'{run_dir}/{index}_pre.bio'
                f2 = f'{run_dir}/{index}_gold.bio'
                tokens = instance['sentence']
                entity_type = ENTITY_TYPE.get(dataset)

                # Each file ends with a blank line, which load_bio reads back
                # as a trailing 'O' tag.
                with open(f1, 'w') as f_pre:
                    f_pre.writelines(_to_bio_lines(tokens, instance['pred'], entity_type))
                    f_pre.write('\n')
                with open(f2, 'w') as f_gold:
                    f_gold.writelines(_to_bio_lines(tokens, instance['gold'], entity_type))
                    f_gold.write('\n')

                label_file = (f2, f1)
                spans = evaluate(label_file)
                if spans is None:
                    # evaluate() returns None when it cannot load the files;
                    # fail with the offending paths instead of an unpack error.
                    raise RuntimeError(f'evaluate() could not process {label_file}')
                gold_span, pre_span = spans
                gold_spans.append(gold_span)
                pre_spans.append(pre_span)

            with open(f'./output/{run_name}_gold_span.txt', 'w') as gold_span_output:
                gold_span_output.write(str(gold_spans))
            with open(f'./output/{run_name}_pre_span.txt', 'w') as pre_span_output:
                pre_span_output.write(str(pre_spans))