# Pipeline of GraphsGPT with Hugging Face Transformers

### Configurations

In [None]:
import torch

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Batching: molecules per batch, and how many batches the encoding cell processes.
batch_size = 1024
max_batches = 4

# Inputs: checkpoint identifier handed to `from_pretrained`, and the text file
# read by the "Load SMILES" cell.
model_name_or_path = "DaizeDong/GraphsGPT-8W"
smiles_file = "../data/examples/zinc_example.txt"

### Load SMILES

In [None]:
# Read one SMILES string per line. Blank lines (e.g. a trailing empty line at
# the end of the file) are skipped so they do not end up as empty molecules.
with open(smiles_file, "r", encoding="utf-8") as f:
    smiles_list = [line.rstrip("\n") for line in f if line.strip()]

print(f"Total SMILES loaded: {len(smiles_list)}")
# Slice instead of indexing 0..9 so files with fewer than 10 entries still work.
for i, smiles in enumerate(smiles_list[:10]):
    print(f"Example SMILES {i}: {smiles}")

### Load Model & Tokenizer

In [None]:
from models.graphsgpt.modeling_graphsgpt import GraphsGPTForCausalLM
from data.tokenizer import GraphsGPTTokenizer

# Load the model weights and the matching tokenizer from the same checkpoint.
model = GraphsGPTForCausalLM.from_pretrained(model_name_or_path)
tokenizer = GraphsGPTTokenizer.from_pretrained(model_name_or_path)

# Quick sanity check: list the weight names and report the model size.
print(model.state_dict().keys())
print(f"Total parameters: {sum(x.numel() for x in model.parameters())}")

### Encode SMILES into Fingerprint Embeddings (Graph Words)

In [None]:
from utils.operations.operation_tensor import move_tensors_to_device
from utils.operations.operation_list import split_list_with_yield

batch_count = 0
num_encoded = 0  # SMILES actually passed to the encoder (last batch may be short)
fingerprints_lists = []

model.to(device)
model.eval()
with torch.no_grad():
    for batched_smiles in split_list_with_yield(smiles_list, batch_size):
        inputs = tokenizer.batch_encode(batched_smiles, return_tensors="pt")
        # Return value is unused, so this presumably moves the tensors in place -- TODO confirm.
        move_tensors_to_device(inputs, device)

        fingerprint_tokens = model.encode_to_fingerprints(**inputs)  # (batch_size, num_fingerprints, hidden_dim)
        fingerprints_lists.append(fingerprint_tokens)

        num_encoded += len(batched_smiles)
        batch_count += 1
        # Stop early so the demo only processes `max_batches` batches.
        if batch_count >= max_batches:
            break

# Report the real count rather than batch_count * batch_size, which over-counts
# whenever a batch holds fewer than `batch_size` molecules.
print(f"Encoded total {num_encoded} molecules")

### Recover Molecule Sequences through Generation

In [None]:
# Options that stay the same for every batch of fingerprints.
generation_options = dict(
    strict_generation=True,
    max_atoms=None,
    similarity_threshold=0.5,
    check_first_node=True,
    check_atom_valence=False,
    fix_aromatic_bond=False,
    use_cache=False,
    save_failed=True,  # save the generated partial result even if the full generation failed
    show_progress=True,
    verbose=True,
)

# Generate batch by batch and collect everything into one flat list.
all_results = []
for batch_fingerprints in fingerprints_lists:
    batch_results = model.generate_from_fingerprints(
        fingerprint_tokens=batch_fingerprints,
        bond_dict=tokenizer.bond_dict,
        **generation_options,
    )
    all_results += batch_results

print("Done.")
print(f"#### Generated {len(all_results)} molecules")

### Decode Sequences back to SMILES

In [None]:
from rdkit.Chem import Draw


def show_mol_png(mol, size=(512, 512)):
    """Render `mol` with `Draw.MolToImage` and display the resulting image.

    Args:
        mol: Molecule object passed straight to `Draw.MolToImage`.
        size: Image size forwarded to `Draw.MolToImage`; defaults to (512, 512).
    """
    img = Draw.MolToImage(mol, size=size)
    try:
        img.show()
    finally:
        # Always release the image, even when displaying it fails.
        img.close()


decoded_mols = []
decoded_smiles = []

for result in all_results:
    if result is None:
        # Keep a placeholder so positions still line up with `smiles_list`,
        # which the visualization loop below indexes in parallel.
        decoded_mols.append(None)
        decoded_smiles.append(None)
        continue
    mol, smiles = tokenizer.decode(result)
    decoded_mols.append(mol)
    decoded_smiles.append(smiles)

# visualize the first 10 results
# `zip` stops at the shortest input, so fewer than 10 molecules is not an error.
for i, (original, decoded, mol) in enumerate(zip(smiles_list[:10], decoded_smiles, decoded_mols)):
    print(f"Original SMILES {i}: {original}")
    print(f"Decoded SMILES {i}: {decoded}")
    # Entries without a molecule (None placeholders) cannot be drawn.
    if mol is not None:
        show_mol_png(mol)