## Graph Hybridization

In [None]:
import os

import torch
from tqdm import tqdm

### Configurations

In [None]:
# Checkpoint identifier passed to `from_pretrained`, and the SMILES file to read.
model_name_or_path = "DaizeDong/GraphsGPT-8W"
smiles_file = "../../data/examples/zinc_example.txt"

# Batch size for encoding/generation, cap on the number of reference
# molecules that are kept, and the root directory for all saved results.
batch_size = 1024
num_reference_moles = 1000
hybridization_save_dir = "./hybridization_results"

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

### Load SMILES

In [None]:
# Read one SMILES string per line. Line terminators ("\n" or "\r\n") are
# stripped, and blank lines are skipped so that e.g. a trailing empty line
# is not treated as a molecule.
with open(smiles_file, "r", encoding="utf-8") as f:
    smiles_list = [line.rstrip("\r\n") for line in f]
smiles_list = [smiles for smiles in smiles_list if smiles]

print(f"Total SMILES loaded: {len(smiles_list)}")
# Preview at most 10 entries; slicing avoids an IndexError on short files.
for i, smiles in enumerate(smiles_list[:10]):
    print(f"Example SMILES {i}: {smiles}")

### Load Model & Tokenizer

In [None]:
from models.graphsgpt.modeling_graphsgpt import GraphsGPTForCausalLM
from data.tokenizer import GraphsGPTTokenizer

# Load the model and the tokenizer from the same checkpoint identifier.
model = GraphsGPTForCausalLM.from_pretrained(model_name_or_path)
tokenizer = GraphsGPTTokenizer.from_pretrained(model_name_or_path)

# Show the names of the loaded state-dict entries and the total parameter count.
print(model.state_dict().keys())
print(f"Total paramerters: {sum(map(torch.Tensor.numel, model.parameters()))}")

### Generate Original Molecules

Results will be saved to the "original" folder. (An empty image means that generation failed for that molecule.)

You need to select a target molecule from them for hybridization.

In [None]:
from utils.operations.operation_list import split_list_with_yield
from utils.operations.operation_tensor import move_tensors_to_device
from utils.operations.operation_dict import reverse_dict
from utils.io import delete_file_or_dir, save_empty_png, save_mol_png

# Bond dictionary of the tokenizer and its reversed mapping. The reversed
# mapping is passed to the tokenizer when converting generated tokens back
# into molecules below.
bond_dict = tokenizer.bond_dict
inverse_bond_dict = reverse_dict(bond_dict)

# Accumulators for the fingerprint tokens and the generation results of the
# reference molecules (at most `num_reference_moles` samples are kept).
now_sample_num = 0
all_fingerprint_tokens = []
all_generated_results = []

model.to(device)
model.eval()
print("Generating original molecules...")
with torch.no_grad():
    for batched_smiles in split_list_with_yield(smiles_list, batch_size):
        # Stop before encoding/generating a batch from which nothing would be
        # kept (happens when the cap was reached exactly at a batch boundary).
        if now_sample_num >= num_reference_moles:
            print("Max sample num reached, stopping forwarding.")
            break

        inputs = tokenizer.batch_encode(batched_smiles, return_tensors="pt")
        # The return value is not used here, so this call is relied upon to
        # update `inputs` in place.
        move_tensors_to_device(inputs, device)

        fingerprint_tokens = model.encode_to_fingerprints(**inputs)  # (batch_size, num_fingerprints, hidden_dim)
        generated_results: list = model.generate_from_fingerprints(
            fingerprint_tokens=fingerprint_tokens,
            bond_dict=bond_dict,
            strict_generation=True,
            similarity_threshold=0.5,
            check_first_node=True,
            check_atom_valence=False,
            fix_aromatic_bond=True,
            use_cache=False,
            save_failed=False,
            show_progress=True,
            verbose=True,
        )

        # Keep only as many samples of this batch as still fit under the cap.
        this_sample_num = fingerprint_tokens.shape[0]
        append_sample_num = min(this_sample_num, num_reference_moles - now_sample_num)
        if append_sample_num > 0:
            now_sample_num += append_sample_num
            all_fingerprint_tokens.append(fingerprint_tokens[:append_sample_num])
            all_generated_results.extend(generated_results[:append_sample_num])
        if append_sample_num < this_sample_num:
            print("Max sample num reached, stopping forwarding.")
            break

# Fail with a clear message instead of an opaque `torch.cat` error when
# nothing was encoded (empty SMILES list or a non-positive sample cap).
if not all_fingerprint_tokens:
    raise ValueError(
        "No reference molecules were encoded; check `smiles_file` and `num_reference_moles`."
    )

all_fingerprint_tokens = torch.cat(all_fingerprint_tokens, dim=0)
# Read the fingerprint count from the concatenated tensor rather than from
# whichever batch happened to be processed last.
num_fingerprint_tokens = all_fingerprint_tokens.shape[1]
print(f"Number of samples is {all_fingerprint_tokens.shape[0]}")
print(f"Number of fingerprints for each sample is {num_fingerprint_tokens}")

# Save the generated original results to disk as reference molecules.
success_cnt = 0  # a molecule was rebuilt and drawn
invalid_cnt = 0  # a result existed but could not be converted to a molecule
fail_cnt = 0  # the generation result itself was None
save_smiles_list = []

# Start from a clean output directory on every run.
orignial_moles_save_dir = os.path.join(hybridization_save_dir, "original")
delete_file_or_dir(orignial_moles_save_dir)
os.makedirs(orignial_moles_save_dir)

# `enumerate` has no length, so pass `total` explicitly for a usable progress bar.
for i, result in tqdm(enumerate(all_generated_results), total=len(all_generated_results), desc="Saving molecule images"):
    save_img_path = os.path.join(orignial_moles_save_dir, f"{i}.png")

    if result is not None:
        input_ids = result["input_ids"]
        graph_position_ids_1 = result["graph_position_ids_1"]
        graph_position_ids_2 = result["graph_position_ids_2"]
        identifier_ids = result["identifier_ids"]

        mol = tokenizer._convert_token_tensors_to_molecule(input_ids, graph_position_ids_1, graph_position_ids_2, identifier_ids, inverse_bond_dict)

        if mol is None:
            save_smiles_list.append(None)
            save_empty_png(save_img_path)
            invalid_cnt += 1
        else:
            smiles = tokenizer._convert_molecule_to_standard_smiles(mol)
            save_smiles_list.append(smiles)
            save_mol_png(mol, save_img_path)
            success_cnt += 1
    else:
        save_smiles_list.append(None)
        save_empty_png(save_img_path)
        fail_cnt += 1

# Save statistics. The directory was just recreated, so write mode is used.
with open(os.path.join(orignial_moles_save_dir, "count.txt"), "w", encoding="utf-8") as f:
    f.write(f"Success count: {success_cnt}\n")
    f.write(f"Invalid count: {invalid_cnt}\n")
    f.write(f"Fail count: {fail_cnt}\n")

# One line per sample, so line k corresponds to image "k.png"; samples
# without a molecule are written as "None".
with open(os.path.join(orignial_moles_save_dir, "smiles.txt"), "w", encoding="utf-8") as f:
    for smiles in save_smiles_list:
        f.write(f"{smiles}\n")

print(f"Results saved to {orignial_moles_save_dir}")

### Select Target Molecule to Perform Hybridization

Select a target molecule from the original molecules for hybridization!

It is best to select a molecule that was generated successfully (i.e., its corresponding image is not empty).

You also need to set the indices of the fingerprint tokens to hybridize. You can check the clustering results for reference. A simple approach is to hybridize the tokens for which the target molecule's cluster shows obvious features (e.g., all molecules in the cluster share the same functional group).


In [None]:
# Index of the target molecule among the saved original molecules
# (image "<index>.png" in the "original" folder). Change it according to
# the generated results.
target_mole_index = 344

# Indices of the fingerprint tokens to take from the target molecule.
# To hybridize multiple tokens at a time, separate the indices with ","
hybrid_token_index_list = "0,1,2"

### Generate Hybrid Molecules

Results will be saved to a folder named "hybrid" followed by the selected token indices (e.g., "hybrid0_1_2"). (An empty image means that generation failed for that molecule.)

In [None]:
# Get the fingerprint tokens of the target molecule. Negative indices are
# accepted with the usual Python meaning; anything out of range fails early.
num_reference_samples = all_fingerprint_tokens.shape[0]
assert -num_reference_samples <= target_mole_index < num_reference_samples, (
    f"target_mole_index={target_mole_index} is out of range for {num_reference_samples} reference molecules"
)
target_fingerprint_tokens = all_fingerprint_tokens[target_mole_index].unsqueeze(0)  # (1, num_fingerprints, embed_size)

# Parse and validate the fingerprint token indices to hybridize.
hybridization_ids = [int(token_index) for token_index in hybrid_token_index_list.split(",")]
assert all(-num_fingerprint_tokens <= token_index < num_fingerprint_tokens for token_index in hybridization_ids), (
    f"hybrid token indices {hybridization_ids} must be within [{-num_fingerprint_tokens}, {num_fingerprint_tokens})"
)
hybridization_ids = torch.tensor(hybridization_ids, device=all_fingerprint_tokens.device, dtype=torch.int64)

# Start from the fingerprints of ALL saved reference molecules (not just the
# last processed batch) so that hybrid result i corresponds to original
# molecule i, then overwrite the selected tokens with those of the target.
all_hybrid_fingerprint_tokens = all_fingerprint_tokens.clone()
all_hybrid_fingerprint_tokens[:, hybridization_ids, :] = target_fingerprint_tokens[:, hybridization_ids, :]

# Generate hybrid molecules. Gradients are not needed for inference.
print(f"Generating for fingerprint token hybridization...")
with torch.no_grad():
    all_hybrid_generated_results: list = model.generate_from_fingerprints(
        fingerprint_tokens=all_hybrid_fingerprint_tokens,
        bond_dict=bond_dict,
        strict_generation=True,
        similarity_threshold=0.5,
        check_first_node=True,
        check_atom_valence=False,
        fix_aromatic_bond=True,
        use_cache=False,
        save_failed=False,
        show_progress=True,
        verbose=True,
    )
# NOTE(review): relies on `move_tensors_to_device` returning its (moved) input
# and passing `None` through for failed generations — confirm against the helper.
all_hybrid_generated_results = [move_tensors_to_device(result, "cpu") for result in all_hybrid_generated_results]

# Save the generated hybrid results to disk.
success_cnt = 0  # a molecule was rebuilt and drawn
invalid_cnt = 0  # a result existed but could not be converted to a molecule
fail_cnt = 0  # the generation result itself was None
save_smiles_list = []

# One output folder per choice of hybridized token indices, recreated on every run.
hybrid_moles_save_dir = os.path.join(hybridization_save_dir, f"hybrid{hybrid_token_index_list.replace(',', '_')}")
delete_file_or_dir(hybrid_moles_save_dir)
os.makedirs(hybrid_moles_save_dir)

# `enumerate` has no length, so pass `total` explicitly for a usable progress bar.
for i, result in tqdm(enumerate(all_hybrid_generated_results), total=len(all_hybrid_generated_results), desc="Saving molecule images"):
    save_img_path = os.path.join(hybrid_moles_save_dir, f"{i}.png")

    if result is not None:
        input_ids = result["input_ids"]
        graph_position_ids_1 = result["graph_position_ids_1"]
        graph_position_ids_2 = result["graph_position_ids_2"]
        identifier_ids = result["identifier_ids"]

        mol = tokenizer._convert_token_tensors_to_molecule(input_ids, graph_position_ids_1, graph_position_ids_2, identifier_ids, inverse_bond_dict)

        if mol is None:
            save_smiles_list.append(None)
            save_empty_png(save_img_path)
            invalid_cnt += 1
        else:
            smiles = tokenizer._convert_molecule_to_standard_smiles(mol)
            save_smiles_list.append(smiles)
            save_mol_png(mol, save_img_path)
            success_cnt += 1
    else:
        save_smiles_list.append(None)
        save_empty_png(save_img_path)
        fail_cnt += 1

# Save statistics. The directory was just recreated, so write mode is used.
with open(os.path.join(hybrid_moles_save_dir, "count.txt"), "w", encoding="utf-8") as f:
    f.write(f"Success count: {success_cnt}\n")
    f.write(f"Invalid count: {invalid_cnt}\n")
    f.write(f"Fail count: {fail_cnt}\n")

# One line per sample, so line k corresponds to image "k.png"; samples
# without a molecule are written as "None".
with open(os.path.join(hybrid_moles_save_dir, "smiles.txt"), "w", encoding="utf-8") as f:
    for smiles in save_smiles_list:
        f.write(f"{smiles}\n")

print(f"Results saved to {hybrid_moles_save_dir}")

All done.
You can check the saved files for further analysis.