In [12]:
import pandas as pd
import json
from pathlib import Path
from tqdm import tqdm
import re
import ast
import os
from rel_verbaliser import get_rel2prompt

In [13]:
def parse_string_to_list_of_lists(s):
    """Parse a loosely formatted list-of-lists string into 3-element rows.

    Separators between inner lists (``]``, optional comma, ``[`` with any
    whitespace) are normalised to ``],[`` and the text is split there. Each
    chunk loses its surrounding brackets and is split on ``,``. Each field is
    whitespace-stripped, then stripped of ``"``, ``'``, backtick, ``[`` and
    ``]`` (in that order), and lower-cased. Only chunks that yield exactly
    three fields are kept.

    Args:
        s: raw text such as ``[["a", "rel", "b"], ["c", "rel", "d"]]``.

    Returns:
        List of ``[head, relation, tail]`` string lists; empty when nothing
        parses into three fields.
    """
    normalised = re.sub(r'\]\s*,?\s*\[', '],[', s)
    chunks = (chunk.strip('[]') for chunk in normalised.split('],['))

    def _clean(token):
        # Strip order matters: each character set is removed only from the
        # ends left over by the previous step.
        token = token.strip()
        for junk in ('"', "'", '`', '[', ']'):
            token = token.strip(junk)
        return token.lower()

    triples = []
    for chunk in chunks:
        fields = list(map(_clean, chunk.split(',')))
        if len(fields) == 3:
            triples.append(fields)
    return triples

In [14]:
def extract_mistral(text):
    """Parse triplets from a Mistral generation.

    Newlines are removed and the text after the first ``$TRIPLETS$ =`` marker
    is kept. If the marker is missing, the whole text is kept. The literal
    ``</s>`` string is deleted and the remainder is passed to
    ``parse_string_to_list_of_lists``.

    Args:
        text: raw generated string.

    Returns:
        List of 3-element lists of lower-cased strings; empty if nothing
        parseable was found.
    """
    flat = text.replace('\n', '')
    triplet_marker = "$TRIPLETS$ ="
    marker_pos = flat.find(triplet_marker)
    # str.find returns -1 when the marker is absent. Adding len(marker) to -1
    # would silently drop the first 11 characters, so start at 0 instead.
    start = 0 if marker_pos == -1 else marker_pos + len(triplet_marker)
    # Remove the token before stripping so whitespace around it is trimmed too.
    triplet_string = flat[start:].replace("</s>", "").strip()
    return parse_string_to_list_of_lists(triplet_string)

In [15]:
def extract_llama(text):
    """Parse triplets from a Llama 3.1 generation.

    Newlines are removed and the text after the first ``$TRIPLETS$ =`` marker
    is kept. If the marker is missing, the whole text is kept. The literal
    ``<|eot_id|>`` string is deleted and the remainder is passed to
    ``parse_string_to_list_of_lists``.

    Args:
        text: raw generated string.

    Returns:
        List of 3-element lists of lower-cased strings; empty if nothing
        parseable was found.
    """
    flat = text.replace('\n', '')
    triplet_marker = "$TRIPLETS$ ="
    marker_pos = flat.find(triplet_marker)
    # str.find returns -1 when the marker is absent. Adding len(marker) to -1
    # would silently drop the first 11 characters, so start at 0 instead.
    start = 0 if marker_pos == -1 else marker_pos + len(triplet_marker)
    # Remove the token before stripping so whitespace around it is trimmed too.
    triplet_string = flat[start:].replace("<|eot_id|>", "").strip()
    return parse_string_to_list_of_lists(triplet_string)

In [16]:
def extract_openchat(text):
    """Parse triplets from an OpenChat generation.

    Newlines are removed and the text after the first ``$TRIPLETS$ =`` marker
    is kept. If the marker is missing, the whole text is kept. The literal
    ``<|end_of_turn|>`` string is deleted and the remainder is passed to
    ``parse_string_to_list_of_lists``.

    Args:
        text: raw generated string.

    Returns:
        List of 3-element lists of lower-cased strings; empty if nothing
        parseable was found.
    """
    flat = text.replace('\n', '')
    triplet_marker = "$TRIPLETS$ ="
    marker_pos = flat.find(triplet_marker)
    # str.find returns -1 when the marker is absent. Adding len(marker) to -1
    # would silently drop the first 11 characters, so start at 0 instead.
    start = 0 if marker_pos == -1 else marker_pos + len(triplet_marker)
    # Remove the token before stripping so whitespace around it is trimmed too.
    triplet_string = flat[start:].replace("<|end_of_turn|>", "").strip()
    return parse_string_to_list_of_lists(triplet_string)

In [17]:
def extract_gemma(text):
    """Parse triplets from a Gemma 2 generation.

    Newlines are removed and the text after the first ``$TRIPLETS$ =`` marker
    is kept. If the marker is missing, the whole text is kept. The literal
    ``<end_of_turn>`` and ``<eos>`` strings are deleted, and the stripped
    remainder is passed to ``parse_string_to_list_of_lists``.

    Args:
        text: raw generated string.

    Returns:
        List of 3-element lists of lower-cased strings; empty if nothing
        parseable was found.
    """
    flat = text.replace('\n', '')
    triplet_marker = "$TRIPLETS$ ="
    marker_pos = flat.find(triplet_marker)
    # str.find returns -1 when the marker is absent. Adding len(marker) to -1
    # would silently drop the first 11 characters, so start at 0 instead.
    start = 0 if marker_pos == -1 else marker_pos + len(triplet_marker)
    triplet_string = flat[start:]
    # Remove each token independently. Removing only the adjacent pair left a
    # lone "<end_of_turn>" glued to the last entity.
    for token in ("<end_of_turn>", "<eos>"):
        triplet_string = triplet_string.replace(token, "")
    return parse_string_to_list_of_lists(triplet_string.strip())

In [21]:
# Experiment selection.
exp = '2stage'

# Gold data lives at <data_dir>/<dataset>/{rel2id.json,test.jsonl}.
data_dir = './Data_JRE'
# Raw generations are read from <llm_result_path>/JRE/<exp>/<model>/<dataset>/.
llm_result_path = './results'
# Parsed predictions are written under <out_path>/JRE/<exp>/<dataset>/<model>/.
# NOTE(review): differs from llm_result_path only by letter case; on a
# case-insensitive filesystem both names resolve to the same directory.
out_path = './Results'

datasets = [
    'crossRE',
    'tacred_new',
    'NYT10',
]

models = [
    "openchat/openchat_3.5",
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "mistralai/Mistral-Nemo-Instruct-2407",
    "google/gemma-2-9b-it",
]

In [19]:
def get_info(exp, file):
    """Recover experiment metadata from a result file's path.

    Args:
        exp: experiment name; ``'plm'`` and ``'structure_extract'`` have
            their own naming schemes, anything else uses the default one.
        file: path object exposing ``.parts`` (e.g. ``pathlib.Path``).

    Returns:
        Tuple ``(struct, prompt, seed, demo, k)``. Fields that cannot be
        derived for ``exp`` stay ``None``; ``k`` is an ``int`` when set.

    Layout inferred from the indexing below (confirm against the generator):
        - plm: seed = last character of the file stem.
        - structure_extract: ``.../<demo>/<x>-<seed>/<..>_<struct>-<prompt>-<k>.<ext>``
        - default: ``.../<demo>/<x>-<seed>/<prompt>-<..>_<..>-<k>.<ext>``
    """
    segments = file.parts
    fname = segments[-1]
    struct = prompt = demo = seed = k = None

    if exp == 'plm':
        # NOTE(review): only one character is taken, so a multi-digit seed
        # would be truncated; confirm the file naming.
        seed = fname.split('.')[0][-1]
        return struct, prompt, seed, demo, k

    tail = fname.split('_')[-1]
    if exp == 'structure_extract':
        tail_fields = tail.split('-')
        prompt = tail_fields[1]
        struct = tail_fields[0]
    else:
        prompt = fname.split('-')[0]
    k = int(tail.split('-')[-1].split('.')[0])
    seed = segments[-2].split('-')[-1]
    demo = segments[-3]
    return struct, prompt, seed, demo, k

In [20]:
def extract_gpt(text):
    """Parse triplets from a GPT generation using a cascade of regexes.

    Several phrasings are tried in order and the first pattern that matches
    wins. Its captured bracketed text is parsed with
    ``parse_string_to_list_of_lists``.

    Args:
        text: raw generated string. Non-``str`` input is printed and yields
            ``[]`` (best effort, so one bad record does not stop the run).

    Returns:
        List of 3-element lists of lower-cased strings; empty when no pattern
        matches.
    """
    # With re.DOTALL, the leading greedy ``.*`` binds each pattern to the LAST
    # keyword occurrence that can be followed by the captured group. The class
    # ``[\s\S,'\w\s]`` matches any character, so those captures run to the end
    # of the text.
    triplet_patterns = [
        r".*TRIPLETS\s*(\[\[.*?\]\])",  # TRIPLETS immediately followed by [[...]]
        r".*TRIPLETS\s*[:=]\s*(\[[\s\S,'\w\s]*\]?)",  # TRIPLETS = [...] / TRIPLETS: [...]
        r".*Triplets\s*[:=]\s*(\[[\s\S,'\w\s]*\]?)",  # Triplets = [...] / Triplets: [...]
        r".*TRIPLETS\$\s*[:=]?\s*(\[[\s\S,'\w\s]*\]?)",  # $TRIPLETS$ variations
        r".*list of triplets are\s*(\[[\s\S,'\w\s]*\]?)",  # descriptive phrasing
        r".*list of triplets is\s*(\[[\s\S,'\w\s]*\]?)",  # descriptive phrasing
    ]

    if not isinstance(text, str):
        # re.search raises TypeError on non-str input. Keep the best-effort
        # behaviour (log it, return no triplets) without a bare except that
        # would also swallow KeyboardInterrupt.
        print(text)
        return []

    for pattern in triplet_patterns:
        match = re.search(pattern, text, re.DOTALL)
        if match:
            return parse_string_to_list_of_lists(match.group(1))
    return []

In [27]:
# Short model name (last component of the model id) -> parser for that
# model's raw generations. Looked up once per result file, so an unsupported
# model fails loudly. The old if/elif chain had no else branch and silently
# reused the previous example's prediction (or raised NameError).
EXTRACTORS = {
    "openchat_3.5": extract_openchat,
    "Meta-Llama-3.1-8B-Instruct": extract_llama,
    "Mistral-Nemo-Instruct-2407": extract_mistral,
    "gemma-2-9b-it": extract_gemma,
    "gpt-4o": extract_gpt,
}

for data in datasets:
    print(data)

    # Gold triples keyed by sample id. Relation labels are verbalised through
    # rel2prompt and every field is lower-cased, as the parsers do for predictions.
    data_dict = {}
    with open(f'{data_dir}/{data}/rel2id.json', "r", encoding="utf-8") as f:
        rel2id = json.load(f)

    rel2prompt = get_rel2prompt(data, rel2id)
    prompt2rel = {v: k for k, v in rel2prompt.items()}  # not used in this cell

    with open(f'{data_dir}/{data}/test.jsonl', "r", encoding="utf-8") as f:
        for line in tqdm(f.read().splitlines()):
            tmp_dict = json.loads(line)
            triple_list = [
                [triple['em1Text'].lower(), rel2prompt[triple['label']].lower(), triple['em2Text'].lower()]
                for triple in tmp_dict['relationMentions']
            ]
            data_dict[tmp_dict['sample_id']] = {'triples': triple_list}

    for model in models:
        print(model)
        llm = model.split('/')[-1]
        files = list(Path(f'{llm_result_path}/JRE/{exp}/{model}/{data}').rglob('*.jsonl'))

        for file in files:
            records = []
            count = 0
            struct, prompt, seed, demo, k = get_info(exp, file)

            # Only the 'rel' and 'open' prompt variants are converted.
            if prompt not in ('rel', 'open'):
                continue
            output_file = f'{out_path}/JRE/{exp}/{data}/{model}/{demo}/seed-{seed}'
            # Skip files already converted on a previous run.
            if os.path.exists(f'{output_file}/{file.name}'):
                continue

            extract = EXTRACTORS.get(llm)
            if extract is None:
                raise ValueError(f'No triplet extractor registered for model {llm!r}')

            os.makedirs(output_file, exist_ok=True)

            # Later lines with a repeated id overwrite earlier ones.
            examples = {}
            with open(file, "r", encoding="utf-8") as f:
                for line in f.read().splitlines():
                    res_dict = json.loads(line)
                    examples[res_dict['id']] = res_dict

            for key, val in examples.items():
                label_true = data_dict[key]['triples']
                if val['label_pred']:
                    label_pred = extract(val['label_pred'])
                    # Counts non-empty generations that parse to no triplets;
                    # empty generations are recorded as [] without being counted.
                    if len(label_pred) == 0:
                        count += 1
                else:
                    label_pred = []
                records.append({'id': key, 'true_label': label_true, 'pred_label': label_pred})
            print(f'Missing triplets {count}/{len(examples)}; seed-{seed}')
            df = pd.DataFrame(records)
            df.to_json(f'{output_file}/{file.name}', lines=True, orient='records')

crossRE


100%|██████████| 1826/1826 [00:00<00:00, 30363.38it/s]


openchat/openchat_3.5
meta-llama/Meta-Llama-3.1-8B-Instruct
mistralai/Mistral-Nemo-Instruct-2407
google/gemma-2-9b-it
tacred_new


100%|██████████| 2307/2307 [00:00<00:00, 130827.44it/s]


openchat/openchat_3.5
meta-llama/Meta-Llama-3.1-8B-Instruct
mistralai/Mistral-Nemo-Instruct-2407
google/gemma-2-9b-it
NYT10


100%|██████████| 4006/4006 [00:00<00:00, 133504.81it/s]

openchat/openchat_3.5
meta-llama/Meta-Llama-3.1-8B-Instruct
mistralai/Mistral-Nemo-Instruct-2407
google/gemma-2-9b-it



