In [None]:
import json
import pickle
from rdkit import Chem
from tqdm import tqdm
import os
from RelAgent.relagent.MolSubOnto import MolSubOntoInstance, search_index, get_all_instances_and_relationships, find_relationship
from RelAgent.relagent.template import EE_TEMPLATE, REL_TEMPLATE
from RelAgent.relagent.util import extract_json_from_text

def _load_json(path):
    """Load and return the JSON document at ``path``, closing the file handle."""
    with open(path) as f:
        return json.load(f)


def _load_pickle(path):
    """Load and return the pickled object at ``path``, closing the file handle.

    NOTE(review): pickle.load can execute arbitrary code; only use it on trusted local caches.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


def _dump_pickle(obj, path):
    """Pickle ``obj`` to ``path``, closing the file handle."""
    with open(path, 'wb') as f:
        pickle.dump(obj, f)


relation_test = _load_json("RelAgent/data/molground/molground_test.json")
all_mols = _load_pickle('RelAgent/data/molgenie/molgenie_all_mols.pkl')
all_nodes_dict = _load_pickle('RelAgent/data/molgenie/molgenie_all_nodes_dict.pkl')

# One record per test molecule; 'outputs' holds the sampled entity-extraction (EE) completions.
ee_outputs = _load_json("llm_outputs/BoN/output_ee_test_MolAgent_ee_Llama-3.1-8B-Instruct_sft_n16.json")
prompts = []

# NOTE(review): zip() pairs test records with EE outputs by position and silently truncates to
# the shorter list; this assumes both files list the molecules in the same order.
for test, ee_output in tqdm(zip(relation_test, ee_outputs), total=len(ee_outputs)):
    sample_id = ee_output['id']
    outputs = ee_output['outputs']
    smiles = test['smiles']
    caption = test['caption']

    # Parse each EE completion into JSON (None when extract_json_from_text yields None).
    ee_candidate = [extract_json_from_text(output) for output in outputs]

    # ================================================
    # Verifier 1: RDKit Localization
    # ================================================
    # Attach to every extracted entity the matches returned by search_index for its SMILES
    # within the molecule (presumably atom-index collections; None when nothing is found).
    entities_candidate = []
    for ee in ee_candidate:
        if ee is None:
            entities_candidate.append(None)
            continue
        entities_candidate.append([
            {
                'name': e['name'],
                'smiles': e['smiles'],
                'indices': search_index(smiles, e['smiles']),
            }
            for e in ee
        ])

    # ================================================
    # Verifier 2: MolOnto Relationship
    # ================================================
    # Build (or load the cached) MolOnto graph/ontology for this molecule.
    graph_path = f'data/molonto/{sample_id}.pickle'
    ontology_path = f'data/molonto/{sample_id}.ttl'
    if os.path.exists(graph_path):
        molonto = MolSubOntoInstance(smiles, all_mols, all_nodes_dict, load_graph=True, graph_path=graph_path, load_ontology=True, ontology_path=ontology_path)
    else:
        molonto = MolSubOntoInstance(smiles, all_mols, all_nodes_dict)
        molonto.save_graph(graph_path)
        molonto.save_ontology(ontology_path)

    # Instances/relationships extracted from the ontology are cached next to it.
    instances_path = f'data/molonto/{sample_id}_instances.pickle'
    relationships_path = f'data/molonto/{sample_id}_relationships.pickle'
    if os.path.exists(instances_path) and os.path.exists(relationships_path):
        instances_dict = _load_pickle(instances_path)
        relationships_dict = _load_pickle(relationships_path)
    else:
        instances_dict, relationships_dict = get_all_instances_and_relationships(molonto.ontology)
        _dump_pickle(instances_dict, instances_path)
        _dump_pickle(relationships_dict, relationships_path)

    # For each completion, turn its localized entities into a newline-separated relationship context.
    relationships_candidate = []
    for entities in entities_candidate:
        if entities is None:
            relationships_candidate.append(None)
            continue
        # Map each match (its sorted indices joined as "i,j,k") to the entity name;
        # a later entity with the same match overwrites an earlier one.
        entity_indices2name_map = {}
        for entity in entities:
            if entity['indices'] is None:
                continue
            for index in entity['indices']:
                index_str = ','.join(map(str, sorted(index)))
                entity_indices2name_map[index_str] = entity['name']
        # Query each unordered pair of distinct matches exactly once, in first-seen order.
        indexed_names = list(entity_indices2name_map.items())
        relationship_lines = []
        for pos, (index_a, name_a) in enumerate(indexed_names):
            for index_b, name_b in indexed_names[pos + 1:]:
                rel_label = find_relationship(relationships_dict, index_a, index_b)
                if rel_label:
                    relationship_lines.append(f"Relationship between {name_a} ({index_a}) and {name_b} ({index_b}): {rel_label}")
        relationships_candidate.append('\n'.join(relationship_lines).strip())

    # ================================================
    # Generate Prompt
    # ================================================
    for n in range(len(outputs)):
        _entities_str = json.dumps(entities_candidate[n], indent=4)
        _relationships_str = relationships_candidate[n]
        # The relation template is imported as REL_TEMPLATE.
        _rel_temp = REL_TEMPLATE.format(smiles=smiles, caption=caption, entities=_entities_str, relationships=_relationships_str)
        prompts.append({
            "id": f"{sample_id}_N{n+1}",
            "entities": entities_candidate[n],
            "relationships": relationships_candidate[n],
            "prompt": _rel_temp
        })

# Serialize once after the loop instead of rewriting the growing list on every iteration.
# NOTE(review): the next cell reads 'llm_prompts/BoN/prompt_rel_test_eeFromLora_BoN16.json';
# confirm the file is moved there (or align the two paths).
with open('llm_prompts/prompt_rel_test_eeFromLora_BoN16.json', 'w') as f:
    json.dump(prompts, f, indent=4)

In [None]:
# Map the BoN-16 relation prompts back to LLM outputs. The LLM was only run on the
# de-duplicated prompt set, so each candidate is matched to a de-duplicated prompt with
# identical (entities, relationships).
N_SAMPLES = 16  # BoN completions per test molecule

with open('llm_outputs/BoN/output_rel_test_eeFromLora_BoN16__no_duplicate_MolAgent_ee_rel_Llama-3.1-8B-Instruct_sft.json', 'r') as f:
    outputs_no_duplicate = json.load(f)
with open('llm_prompts/BoN/prompt_rel_test_eeFromLora_BoN16.json', 'r') as f:
    prompts_7k = json.load(f)
with open('llm_prompts/BoN/prompt_rel_test_eeFromLora_BoN16_no_duplicate.json', 'r') as f:
    prompts_no_duplicate = json.load(f)

# prompts_7k is sliced below as N_SAMPLES consecutive candidates per test molecule.
assert len(prompts_7k) == len(relation_test) * N_SAMPLES, "prompts_7k is not aligned with relation_test"

# LLM output (first completion) for each de-duplicated prompt string.
prompt2output = {}
for prompt_str, output in zip(prompts_no_duplicate, outputs_no_duplicate):
    prompt2output[prompt_str['prompt']] = output['outputs'][0]


def _candidate_key(entities, relationships):
    """Canonical hashable key for an (entities, relationships) pair of JSON-loaded values.

    NOTE(review): json.dumps distinguishes 1 / 1.0 / True, which `==` treats as equal;
    this is assumed irrelevant for JSON-loaded entity/relationship data.
    """
    return json.dumps([entities, relationships], sort_keys=True)


# First de-duplicated prompt for each (entities, relationships) pair. This is an O(1) lookup,
# replacing a linear scan of prompts_no_duplicate for every one of the candidates.
key2prompt = {}
for p in prompts_no_duplicate:
    key2prompt.setdefault(_candidate_key(p['entities'], p['relationships']), p['prompt'])

real_outputs = []
for t, test in enumerate(relation_test):
    prompts_range = prompts_7k[t * N_SAMPLES:(t + 1) * N_SAMPLES]
    entities = [candidate['entities'] for candidate in prompts_range]
    relationships = [candidate['relationships'] for candidate in prompts_range]
    outputs = []
    for ee, rel in zip(entities, relationships):
        matched_prompt = key2prompt.get(_candidate_key(ee, rel))
        # None when no de-duplicated prompt has identical entities and relationships.
        outputs.append(prompt2output[matched_prompt] if matched_prompt is not None else None)
    real_outputs.append({
        'id': test['id'],
        'entities': entities,
        'relationships': relationships,
        'outputs': outputs})

with open('llm_outputs/BoN/output_rel_test_eeFromLora_BoN16_MolAgent_ee_rel_Llama-3.1-8B-Instruct_sft.json', 'w') as f:
    json.dump(real_outputs, f, indent=4)
print(len(real_outputs))

443


In [None]:
# Qualitative comparison of one test molecule: prompt, candidate outputs and ground truth.
from tabulate import tabulate

# Hand-picked test indices illustrating typical outcomes; change the key to inspect another case.
EXAMPLE_CASES = {
    'perfect answer': 100,
    'EE loss': 0,
    'EE chain loss & single atom localization loss': 8,
    'single atom localization loss': 400,
    'complex loss': 256,
}
i = EXAMPLE_CASES['EE loss']

# Prompt of the first BoN candidate (prompts_7k holds 16 consecutive candidates per molecule).
p = prompts_7k[i*16]['prompt']
# Relation output of the first BoN candidate; its prompt embeds that candidate's MolOnto relationship context.
molonto = real_outputs[i]['outputs'][0]
# NOTE: this file is written by the policy-selection cell further down; on a fresh kernel run that cell first.
with open('llm_outputs/BoN/policy/output_rel_test_eeFromLora_BoN16_majority-voting_MolAgent_ee_rel_Llama-3.1-8B-Instruct_sft.json', 'r') as f:
    bon_outputs = json.load(f)
bon_o = bon_outputs[i]['output']
with open('llm_outputs/output_ee-from-llama8b_rel_test_MolAgent_ee_Llama-3.1-8B-Instruct_sft.json', 'r') as f:
    sft_outputs = json.load(f)
sft_o = sft_outputs[i]['output']
gt_o = relation_test[i]['ground_truth']

# Create a table comparing the prompt, outputs and ground truth
comparison_table = [
    ["Type", "Prompt", "MolOnto Context", "BoN Majority Voting Output", "SFT Output", "Ground Truth"],
    ["Content", p, molonto, bon_o, sft_o, json.dumps(gt_o, indent=4)]
]

# Display the table
print(tabulate(comparison_table, headers="firstrow", tablefmt="grid"))

+---------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------+----------------------------------------------+---------------------------------------------------+----------------------------------------------+
| Type    | Prompt                                                                                                                                                                                                                                                                           | MolOnto Context                              | BoN Majority Voting Output                   | SFT Output                                        | Ground Truth                                 |
| Content | You are a chemistry expert. 

In [None]:
# Select one output per test molecule under different Best-of-N selection policies
import json
import math
import random
from collections import Counter

POLICY_OUTPUT_PATH = 'llm_outputs/BoN/policy/output_rel_test_eeFromLora_BoN16_{policy}_MolAgent_ee_rel_Llama-3.1-8B-Instruct_sft.json'


def _apply_policy(policy, select):
    """Pick one output per entry of ``real_outputs`` with ``select`` and save the picks.

    ``select`` maps a real_outputs entry ({'id', 'entities', 'relationships', 'outputs'})
    to the chosen output (or None). The picks are written as a JSON list of
    ``{'id', 'output'}`` records to POLICY_OUTPUT_PATH for ``policy`` and returned.
    """
    selected = [{'id': output['id'], 'output': select(output)} for output in real_outputs]
    with open(POLICY_OUTPUT_PATH.format(policy=policy), 'w') as f:
        json.dump(selected, f, indent=4)
    return selected


def _first_output(output):
    """Baseline: the first sampled candidate."""
    return output['outputs'][0]


def _random_output(output):
    """Baseline: a uniformly random candidate (may be None)."""
    return random.choice(output['outputs'])


def _most_relations_retrieved(output):
    """Candidate whose MolOnto relationship context has the most lines (one relationship per line).

    An empty context counts as zero relationships. Candidates without an output are skipped,
    and the first candidate is the fallback.
    """
    best_output, best_count = None, 0
    for rel, o in zip(output['relationships'], output['outputs']):
        if rel is None or o is None:
            continue
        rel_count = len(rel.split('\n')) if rel else 0
        if rel_count > best_count:
            # Keep this candidate's own output (not the first candidate's).
            best_output, best_count = o, rel_count
    return best_output if best_output is not None else _first_output(output)


def _majority_vote(output):
    """Most frequent non-None candidate output (ties go to the first seen); None if all are None."""
    votes = Counter(o for o in output['outputs'] if o is not None)
    return max(votes, key=votes.get) if votes else None


def _first_non_empty(output):
    """First candidate that is neither None nor blank; falls back to the first candidate."""
    for o in output['outputs']:
        if o is not None and o.strip():
            return o
    return _first_output(output)


def _fewest_entity_indices(output):
    """Candidate whose entities have the fewest localized matches in total.

    An entity that could not be localized (indices None) counts as one. Candidates with
    no output or no parsed entities are skipped, and the first candidate is the fallback.
    """
    best_output, best_count = None, math.inf
    for entities, o in zip(output['entities'], output['outputs']):
        if o is None or entities is None:
            continue
        count = sum(1 if e['indices'] is None else len(e['indices']) for e in entities)
        if count < best_count:
            best_output, best_count = o, count
    return best_output if best_output is not None else _first_output(output)


def _parse_relationships(o):
    """Return the 'relationships' value of the first ```json block in ``o``.

    Returns None when the output is missing, has no such block, is not valid JSON, is not an
    object, or has no list/dict 'relationships' field (assumed schema; confirm).
    """
    if o is None:
        return None
    parts = o.split('```json')
    if len(parts) < 2:
        return None
    try:
        parsed = json.loads(parts[1].split('```')[0])
    except json.JSONDecodeError:
        return None
    if not isinstance(parsed, dict):
        return None
    relationships = parsed.get('relationships')
    return relationships if isinstance(relationships, (list, dict)) else None


def _most_relationship_pairs(output):
    """Candidate whose parsed JSON lists the most relationships; falls back to the first candidate."""
    best_output, best_count = None, 0
    for o in output['outputs']:
        relationships = _parse_relationships(o)
        if relationships is not None and len(relationships) > best_count:
            best_output, best_count = o, len(relationships)
    return best_output if best_output is not None else _first_output(output)


new_outputs = _apply_policy('first-output-only', _first_output)

random.seed(42)  # reproducible random baseline
new_outputs = _apply_policy('random-selection', _random_output)

new_outputs = _apply_policy('most-relation-retrieved', _most_relations_retrieved)
new_outputs = _apply_policy('majority-voting', _majority_vote)
new_outputs = _apply_policy('always-return-first-non-empty', _first_non_empty)
new_outputs = _apply_policy('least-non-empty-entity-indices', _fewest_entity_indices)
new_outputs = _apply_policy('longest-relationship-pairs', _most_relationship_pairs)

In [None]:
# Embed molecules and captions for the similarity-based ("dynamic score") selection
import json
import pickle
import os
import numpy as np
from rdkit import Chem
from rdkit.Chem import rdFingerprintGenerator
from sentence_transformers import SentenceTransformer

# Local all-MiniLM-L6-v2 checkpoint; set SENTENCE_MODEL_PATH to override on another machine.
# A raw string avoids the invalid '\m' escape sequence in a Windows path.
SENTENCE_MODEL_PATH = os.environ.get('SENTENCE_MODEL_PATH', r'F:\backup\model\all-MiniLM-L6-v2')
model = SentenceTransformer(SENTENCE_MODEL_PATH)

with open('llm_outputs/BoN/output_rel_test_eeFromLora_BoN16_MolAgent_ee_rel_Llama-3.1-8B-Instruct_sft.json', 'r') as f:
    real_outputs = json.load(f)
with open('data/relation_train.json', 'r') as f:
    relation_train = json.load(f)

mfpgen = rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=1024)


def _morgan_fingerprints(records):
    """Stack the Morgan fingerprints (radius 2, 1024 bits) of each record's 'smiles' into an array.

    Raises:
        ValueError: if RDKit cannot parse a SMILES string.
    """
    fingerprints = []
    for record in records:
        mol = Chem.MolFromSmiles(record['smiles'])
        if mol is None:
            # Fail with the offending SMILES rather than passing None on to GetFingerprint.
            raise ValueError(f"RDKit could not parse SMILES: {record['smiles']!r}")
        fingerprints.append(mfpgen.GetFingerprint(mol))
    return np.array(fingerprints)


# calculate the morgan fingerprints of the training and test smiles (relation_test is loaded in the first cell)
fingerprints_train = _morgan_fingerprints(relation_train)
print(f"Training set fingerprints shape: {fingerprints_train.shape}")
fingerprints_test = _morgan_fingerprints(relation_test)
print(f"Test set fingerprints shape: {fingerprints_test.shape}")

# calculate the embedding of the training and test captions
captions = [data['caption'] for data in relation_train]
caption_embeddings_train = model.encode(captions)
print(f"Training set caption embeddings shape: {caption_embeddings_train.shape}")
captions = [data['caption'] for data in relation_test]
caption_embeddings_test = model.encode(captions)
print(f"Test set caption embeddings shape: {caption_embeddings_test.shape}")

# calculate the embedding of test ground-truth.
# NOTE(review): these are test labels; they are not used by the selection cell below, and using
# them for selection would leak ground truth.
gt_texts = [json.dumps(data['ground_truth'], indent=4) for data in relation_test]
gt_embeddings = model.encode(gt_texts)
print(f"Test ground-truth embeddings shape: {gt_embeddings.shape}")

Training set fingerprints shape: (3487, 1024)
Test set fingerprints shape: (443, 1024)
Training set caption embeddings shape: (3487, 384)
Test set caption embeddings shape: (443, 384)
Test ground-truth embeddings shape: (443, 384)


In [None]:
# select real outputs based on dynamic score
from sklearn.metrics.pairwise import cosine_similarity

# For each test molecule, retrieve the k_mol most similar training molecules by fingerprint
# cosine similarity. np.argsort sorts ascending, so negate the scores to rank the most
# similar first (slicing the ascending order directly would keep the least similar neighbours).
sim_fingerprints_test = cosine_similarity(fingerprints_test, fingerprints_train)  # (n_test, n_train)
k_mol = 5
sim_fingerprints_test_topk = np.argsort(-sim_fingerprints_test, axis=1)[:, :k_mol]  # (n_test, k_mol)
# Caption embeddings of each test molecule's retrieved neighbours: (n_test, k_mol, dim)
filtered_captions_embeddings_test = caption_embeddings_train[sim_fingerprints_test_topk]

# Among those neighbours, rank by caption-embedding similarity to the test caption (most similar first).
k_cap = 1
sim_captions_test = np.array([
    cosine_similarity(test_caption.reshape(1, -1), neighbour_captions)
    for test_caption, neighbour_captions in zip(caption_embeddings_test, filtered_captions_embeddings_test)
])  # (n_test, 1, k_mol)
sim_captions_test_topk = np.argsort(-sim_captions_test, axis=2)[:, :, :k_cap]
sim_captions_test_topk = sim_captions_test_topk.reshape(sim_captions_test_topk.shape[0], k_cap)

# NOTE(review): sim_captions_test_topk[i][0] is a position among the k_mol retrieved training
# neighbours (0..k_mol-1), but it is used to index the BoN candidate outputs. Only one of the
# first k_mol candidates can be chosen, and the choice ignores candidate content. Confirm the
# intended scoring (e.g. comparing each candidate with the retrieved neighbour's ground truth).
filtered_contexts = [
    real_outputs[i]['outputs'][sim_captions_test_topk[i][0]]
    for i in range(len(sim_captions_test_topk))
]

# save filtered contexts
filtered_outputs = [
    {'id': real_outputs[i]['id'], 'output': filtered_contexts[i]}
    for i in range(len(filtered_contexts))
]

with open('llm_outputs/BoN/policy/output_rel_test_eeFromLora_BoN16_similarity-based_MolAgent_ee_rel_Llama-3.1-8B-Instruct_sft.json', 'w') as f:
    json.dump(filtered_outputs, f, indent=4)