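## Graph Interpolation

This notebook encodes molecules (SMILES) into GraphsGPT fingerprint tokens, linearly interpolates between the fingerprint tokens of two selected molecules, and decodes every interpolated point back into a molecule.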

In [None]:
import os
import torch
from tqdm import tqdm

### Configurations

In [None]:
# --- Model checkpoint and input data ---
model_name_or_path = "DaizeDong/GraphsGPT-8W"
smiles_file = "../../data/examples/zinc_example.txt"

# --- Sizes ---
# SMILES strings per encoding/generation batch.
batch_size = 1_024
# How many reference molecules to encode and decode (candidates for interpolation endpoints).
num_reference_moles = 1_000
# How many intermediate points to place strictly between the two selected molecules.
num_interpolated_features = 1_000

# --- Output ---
interpolation_save_dir = "./interpolation_results"

# Prefer the GPU when one is available.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

### Load SMILES

In [None]:
# Read one SMILES string per line.
# `strip()` also removes "\r" from Windows-style line endings, which `removesuffix("\n")` would keep.
# Blank lines are skipped, so they never reach the tokenizer as empty SMILES.
with open(smiles_file, "r", encoding="utf-8") as f:
    smiles_list = [line.strip() for line in f if line.strip()]

print(f"Total SMILES loaded: {len(smiles_list)}")
# Slice rather than index `range(10)`, so files with fewer than 10 molecules don't raise IndexError.
for i, smiles in enumerate(smiles_list[:10]):
    print(f"Example SMILES {i}: {smiles}")

### Load Model & Tokenizer

In [None]:
from models.graphsgpt.modeling_graphsgpt import GraphsGPTForCausalLM
from data.tokenizer import GraphsGPTTokenizer

# Load the pretrained GraphsGPT model and its matching tokenizer from the same checkpoint.
model = GraphsGPTForCausalLM.from_pretrained(model_name_or_path)
tokenizer = GraphsGPTTokenizer.from_pretrained(model_name_or_path)

# Quick sanity check of the loaded weights: parameter names and the total parameter count.
print(model.state_dict().keys())
print(f"Total parameters: {sum(x.numel() for x in model.parameters()):,}")

### Generate Reference Molecules

Results will be saved to the "references" folder inside `interpolation_save_dir`. (An empty image means generation failed for that molecule.)

Select two of these molecules to interpolate between.

In [None]:
from utils.operations.operation_list import split_list_with_yield
from utils.operations.operation_tensor import move_tensors_to_device
from utils.operations.operation_dict import reverse_dict
from utils.io import delete_file_or_dir, save_empty_png, save_mol_png

# Encode the input SMILES into fingerprint tokens, decode those tokens back into molecules, and keep
# the first `num_reference_moles` results as candidate endpoints for the interpolation below.

# read bond dict; the inverse mapping is passed to the tokenizer to turn generated tokens back into molecules
bond_dict = tokenizer.bond_dict
inverse_bond_dict = reverse_dict(bond_dict)

# generate fingerprint tokens
now_sample_num = 0
fingerprint_token_batches = []
all_generated_results = []

model.to(device)
model.eval()
print(f"Generating for reference...")
with torch.no_grad():
    for batched_smiles in split_list_with_yield(smiles_list, batch_size):
        # Check the quota *before* encoding, and trim the final batch, so no encoding/generation
        # work is spent on samples that would only be discarded afterwards.
        remaining_sample_num = num_reference_moles - now_sample_num
        if remaining_sample_num <= 0:
            print("Max sample num reached, stopping forwarding.")
            break
        batched_smiles = list(batched_smiles)[:remaining_sample_num]

        inputs = tokenizer.batch_encode(batched_smiles, return_tensors="pt")
        move_tensors_to_device(inputs, device)

        fingerprint_tokens = model.encode_to_fingerprints(**inputs)  # (batch_size, num_fingerprints, hidden_dim)
        generated_results: list = model.generate_from_fingerprints(
            fingerprint_tokens=fingerprint_tokens,
            bond_dict=bond_dict,
            strict_generation=True,
            max_atoms=None,
            similarity_threshold=0.5,
            check_first_node=True,
            check_atom_valence=False,
            fix_aromatic_bond=False,
            use_cache=False,
            save_failed=False,
            show_progress=True,
            verbose=True,
        )

        now_sample_num += fingerprint_tokens.shape[0]
        fingerprint_token_batches.append(fingerprint_tokens)
        all_generated_results.extend(generated_results)

if not fingerprint_token_batches:
    raise ValueError(f"No reference molecules were generated: '{smiles_file}' contains no SMILES or num_reference_moles <= 0.")

all_fingerprint_tokens = torch.cat(fingerprint_token_batches, dim=0)
# Read the fingerprint count from the concatenated tensor, not from the loop variable of the last batch.
num_fingerprint_tokens = all_fingerprint_tokens.shape[1]
print(f"Number of samples is {all_fingerprint_tokens.shape[0]}")
print(f"Number of fingerprints for each sample is {num_fingerprint_tokens}")

# save generated results to disk as reference molecules
references_dir = os.path.join(interpolation_save_dir, "references")
success_cnt = 0
invalid_cnt = 0
fail_cnt = 0
save_smiles_list = []

# Start from an empty folder so files from a previous run never mix with this one.
delete_file_or_dir(references_dir)
os.makedirs(references_dir)

for i, result in tqdm(enumerate(all_generated_results), total=len(all_generated_results), desc="Saving molecule images"):
    save_img_path = os.path.join(references_dir, f"{i}.png")

    if result is not None:
        input_ids = result["input_ids"]
        graph_position_ids_1 = result["graph_position_ids_1"]
        graph_position_ids_2 = result["graph_position_ids_2"]
        identifier_ids = result["identifier_ids"]

        mol = tokenizer._convert_token_tensors_to_molecule(input_ids, graph_position_ids_1, graph_position_ids_2, identifier_ids, inverse_bond_dict)

        if mol is None:
            # Generation produced tokens, but they don't form a valid molecule.
            save_smiles_list.append(None)
            save_empty_png(save_img_path)
            invalid_cnt += 1
        else:
            smiles = tokenizer._convert_molecule_to_standard_smiles(mol)
            save_smiles_list.append(smiles)
            save_mol_png(mol, save_img_path)
            success_cnt += 1
    else:
        # Generation itself failed for this sample.
        save_smiles_list.append(None)
        save_empty_png(save_img_path)
        fail_cnt += 1

# save statistics (the folder was just recreated, so write mode is the intended semantics)
with open(os.path.join(references_dir, "count.txt"), "w", encoding="utf-8") as f:
    f.write(f"Success count: {success_cnt}\n")
    f.write(f"Invalid count: {invalid_cnt}\n")
    f.write(f"Fail count: {fail_cnt}\n")

# One line per reference molecule, aligned with the image indices; "None" marks failures.
with open(os.path.join(references_dir, "smiles.txt"), "w", encoding="utf-8") as f:
    for smiles in save_smiles_list:
        f.write(f"{smiles}\n")

print(f"Results saved to {references_dir}")

### Select Molecules to Perform Interpolation

Set the indices of the two molecules to interpolate between.

Preferably pick molecules that were generated successfully (i.e. whose images in the "references" folder are not empty).

In [None]:
# You can change the index according to the results of molecules
mole_index_1 = 81
mole_index_2 = 89

### Generate Interpolation Molecules

Results will be saved to a folder named after the two selected indices (`{mole_index_1}--{mole_index_2}`).

To save disk space, an image is saved only when a molecule's SMILES differs from that of the previous successfully decoded molecule.

In [None]:
# Interpolate linearly between the fingerprint tokens of the two selected reference molecules,
# decode every point back into a molecule, and save the results.

# Validate the selected indices against the references that were actually generated.
num_references = all_fingerprint_tokens.shape[0]
for mole_index in (mole_index_1, mole_index_2):
    if not 0 <= mole_index < num_references:
        raise IndexError(f"Molecule index {mole_index} is out of range: only {num_references} reference molecules were generated.")

# get reference fingerprint tokens
fingerprint_tokens_1 = all_fingerprint_tokens[mole_index_1].unsqueeze(0)  # (1, num_fingerprints, embed_size)
fingerprint_tokens_2 = all_fingerprint_tokens[mole_index_2].unsqueeze(0)  # (1, num_fingerprints, embed_size)

# linear interpolation: endpoint 1, then `num_interpolated_features` evenly spaced points strictly
# between the endpoints (alpha = (i + 1) / (num_interpolated_features + 1)), then endpoint 2
interpolation_fingerprint_tokens = [fingerprint_tokens_1]
for i in range(num_interpolated_features):
    alpha = (i + 1) / (num_interpolated_features + 1)
    interpolation_fingerprint_tokens.append(torch.lerp(fingerprint_tokens_1, fingerprint_tokens_2, alpha))
interpolation_fingerprint_tokens.append(fingerprint_tokens_2)

interpolation_fingerprint_tokens = torch.cat(interpolation_fingerprint_tokens, dim=0)  # (num_interpolated_features + 2, num_fingerprints, embed_size)

# generate (in chunks of `batch_size`, like the reference generation, and without autograd tracking)
print(f"Generating for interpolation...")
interpolation_generated_results = []
with torch.no_grad():
    for token_batch in torch.split(interpolation_fingerprint_tokens, batch_size, dim=0):
        batch_results: list = model.generate_from_fingerprints(
            fingerprint_tokens=token_batch,
            bond_dict=bond_dict,
            strict_generation=True,
            max_atoms=None,
            similarity_threshold=0.5,
            check_first_node=True,
            check_atom_valence=False,
            fix_aromatic_bond=False,
            use_cache=False,
            save_failed=False,
            show_progress=True,
            verbose=True,
        )
        interpolation_generated_results.extend(batch_results)

# save results
pair_save_dir = os.path.join(interpolation_save_dir, f"{mole_index_1}--{mole_index_2}")
success_cnt = 0
invalid_cnt = 0
fail_cnt = 0
save_smiles_list = []
last_smiles = None

# Start from an empty folder so files from a previous run never mix with this one.
delete_file_or_dir(pair_save_dir)
os.makedirs(pair_save_dir)

for i, result in tqdm(enumerate(interpolation_generated_results), total=len(interpolation_generated_results), desc="Saving interpolation molecule images"):
    save_img_path = os.path.join(pair_save_dir, f"{i}.png")

    if result is None:
        # Generation itself failed for this interpolation point.
        save_smiles_list.append(None)
        fail_cnt += 1
        continue

    # Decode exactly as the reference cell does, so both result sets are directly comparable.
    mol = tokenizer._convert_token_tensors_to_molecule(
        result["input_ids"],
        result["graph_position_ids_1"],
        result["graph_position_ids_2"],
        result["identifier_ids"],
        inverse_bond_dict,
    )

    if mol is None:
        # Tokens were generated but do not form a valid molecule.
        save_smiles_list.append(None)
        invalid_cnt += 1
        continue

    smiles = tokenizer._convert_molecule_to_standard_smiles(mol)
    save_smiles_list.append(smiles)
    success_cnt += 1

    # Only save an image when the decoded molecule changes along the path, so consecutive
    # identical molecules produce a single image (the file name is the index of the first one).
    if smiles != last_smiles:
        save_mol_png(mol, save_img_path)
        last_smiles = smiles

# save statistics (the folder was just recreated, so write mode is the intended semantics)
with open(os.path.join(pair_save_dir, "count.txt"), "w", encoding="utf-8") as f:
    f.write(f"Success count: {success_cnt}\n")
    f.write(f"Invalid count: {invalid_cnt}\n")
    f.write(f"Fail count: {fail_cnt}\n")

# One line per interpolation point, in order from molecule 1 to molecule 2; "None" marks failures.
with open(os.path.join(pair_save_dir, "smiles.txt"), "w", encoding="utf-8") as f:
    for smiles in save_smiles_list:
        f.write(f"{smiles}\n")

print(f"Results saved to {pair_save_dir}")

All done.
You can inspect the saved images, `smiles.txt`, and `count.txt` files for further analysis.