Skip to content

Commit

Permalink
Improve Tanimoto
Browse files Browse the repository at this point in the history
  • Loading branch information
AdrienC21 committed Sep 10, 2023
1 parent c061ca1 commit 3460cc8
Show file tree
Hide file tree
Showing 4 changed files with 272 additions and 13 deletions.
50 changes: 50 additions & 0 deletions scripts/create_tanimoto_smiles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""create_tanimot_smiles.py: Code to create the molecule used in GDSS for Tanimoto on QM9.
"""
import matplotlib.pyplot as plt
import rdkit
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem.rdchem import RWMol

if __name__ == "__main__":
# Create molecule
C = rdkit.Chem.rdchem.Atom("C")
O = rdkit.Chem.rdchem.Atom("O")
N = rdkit.Chem.rdchem.Atom("N")

mol = RWMol()
mol.AddAtom(C)
mol.AddAtom(C)
mol.AddAtom(C)
mol.AddAtom(C)
mol.AddAtom(C)
mol.AddBond(0, 1, rdkit.Chem.rdchem.BondType.DOUBLE)
mol.AddBond(1, 2, rdkit.Chem.rdchem.BondType.SINGLE)
mol.AddBond(2, 3, rdkit.Chem.rdchem.BondType.SINGLE)
mol.AddBond(3, 4, rdkit.Chem.rdchem.BondType.SINGLE)
mol.AddBond(4, 0, rdkit.Chem.rdchem.BondType.SINGLE)
mol.AddAtom(N)
mol.AddAtom(C)
mol.AddBond(3, 5, rdkit.Chem.rdchem.BondType.SINGLE)
mol.AddBond(3, 6, rdkit.Chem.rdchem.BondType.SINGLE)
mol.AddBond(5, 6, rdkit.Chem.rdchem.BondType.SINGLE)
mol.AddAtom(C)
mol.AddAtom(O)
mol.AddBond(6, 7, rdkit.Chem.rdchem.BondType.SINGLE)
mol.AddBond(7, 8, rdkit.Chem.rdchem.BondType.DOUBLE)

# Sanitize molecule
rdkit.Chem.SanitizeMol(mol)

# Convert to SMILES
smiles = Chem.MolToSmiles(mol)
print(smiles)

# Plot molecule
mol_img = Draw.MolToImage(mol, size=(300, 300))
plt.imshow(mol_img)
plt.suptitle(f"SMILES: {smiles}")
plt.show()
111 changes: 111 additions & 0 deletions scripts/run_frobenius_complexity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""run_frobenius_complexity.py: Code to assess the complexity of learning partial score functions using the Frobenius norm of the Jacobian of our models.
"""

import os
import sys

sys.path.insert(0, os.getcwd())

import torch
import torch.autograd as autograd

from ccsd.src.parsers.config import get_config
from ccsd.src.utils.loader import load_model_optimizer, load_model_params
from ccsd.src.utils.models_utils import get_nb_parameters


def frobenius_norm_jacobian(model: torch.nn.Module, t: torch.Tensor) -> float:
"""Calculate the Frobenius norm of the Jacobian matrix of a PyTorch model.
Args:
model (torch.nn.Module): The PyTorch model for which to compute the Jacobian.
t (torch.Tensor): Input tensor for which to compute the Jacobian.
Returns:
float: The Frobenius norm of the Jacobian matrix of the model for the input tensor.
"""
# Evaluation mode and clear gradients
model.eval()
model.zero_grad()
# Calculate the Jacobian matrix
jac = autograd.functional.jacobian(model, t)
# Compute the Frobenius norm
frob_norm = torch.norm(jac, "fro").item()
return frob_norm


if __name__ == "__main__":
for dataset in [
"QM9",
"ENZYMES_small",
"community_small",
"ego_small",
"grid_small",
]:
t = None # TO PROVIDE

print("\n----------------------")
print(f"{dataset}")
print("-----")

print("\nGraph")
cfg = f"{dataset.lower()}"
config = get_config(cfg, 42)
params_x, params_adj = load_model_params(config, is_cc=False)
try:
model_x, optimizer_x, scheduler_x = load_model_optimizer(
params_x, config.train, "cpu"
)
model_adj, optimizer_adj, scheduler_adj = load_model_optimizer(
params_adj, config.train, "cpu"
)

print(f"Complexity x: {frobenius_norm_jacobian(model_x, t)}")
print(f"Complexity adj: {frobenius_norm_jacobian(model_adj, t)}")
except Exception as e:
print("NaN")

print("\nCC")
cfg = f"{dataset.lower()}_CC"
config = get_config(cfg, 42)
params_x, params_adj, params_rank2 = load_model_params(config, is_cc=True)
try:
model_x, optimizer_x, scheduler_x = load_model_optimizer(
params_x, config.train, "cpu"
)
model_adj, optimizer_adj, scheduler_adj = load_model_optimizer(
params_adj, config.train, "cpu"
)
model_rank2, optimizer_rank2, scheduler_rank2 = load_model_optimizer(
params_rank2, config.train, "cpu"
)

print(f"Complexity x: {frobenius_norm_jacobian(model_x, t)}")
print(f"Complexity adj: {frobenius_norm_jacobian(model_adj, t)}")
print(f"Complexity rank2: {frobenius_norm_jacobian(model_rank2, t)}")
except Exception as e:
print("NaN")

print("\nCC Base Ablation study")
cfg = f"{dataset.lower()}_Base_CC"
config = get_config(cfg, 42)
params_x, params_adj, params_rank2 = load_model_params(config, is_cc=True)
try:
model_x, optimizer_x, scheduler_x = load_model_optimizer(
params_x, config.train, "cpu"
)
model_adj, optimizer_adj, scheduler_adj = load_model_optimizer(
params_adj, config.train, "cpu"
)
model_rank2, optimizer_rank2, scheduler_rank2 = load_model_optimizer(
params_rank2, config.train, "cpu"
)

print(f"Complexity x: {frobenius_norm_jacobian(model_x, t)}")
print(f"Complexity adj: {frobenius_norm_jacobian(model_adj, t)}")
print(f"Complexity rank2: {frobenius_norm_jacobian(model_rank2, t)}")
except Exception as e:
print("NaN")
54 changes: 41 additions & 13 deletions scripts/run_tanimoto.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,18 @@

def find_max_similarity_molecules_tanimoto(
generated_molecules: List[Chem.Mol],
training_molecules: List[Chem.Mol],
training_molecules: Union[List[Chem.Mol], str],
plot_result: bool = True,
folder: str = "./",
max_num: int = 16,
dataset: str = "QM9",
method: str = "CCSD",
) -> Tuple[List[Chem.Mol], float]:
"""Find the most similar molecules in a training set to a set of generated molecules using Tanimoto Similarity.
Args:
generated_molecules (List[Chem.Mol]): list of generated molecules
training_molecules (List[Chem.Mol]): list of training molecules
training_molecules (Union[List[Chem.Mol], str]): list of training molecules or single molecule SMILES string
plot_result (bool, optional): whether to plot the most similar molecules. Defaults to True.
folder (str, optional): directory where to create a analysis folder to save the results. Defaults to "./".
max_num (int, optional): maximum number of molecules to plot, if we plot. Defaults to 16.
Expand All @@ -45,7 +46,7 @@ def find_max_similarity_molecules_tanimoto(
Tuple[List[Chem.Mol], float]: list of most similar molecules in the training set and the maximum tanimoto similarity score
"""
# Calculate Morgan fingerprints for training molecules
if isinstance(training_molecules):
if isinstance(training_molecules, str):
training_molecules = [Chem.MolFromSmiles(training_molecules)]
training_fps = [
AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=1024)
Expand All @@ -69,7 +70,7 @@ def find_max_similarity_molecules_tanimoto(
if max_sim > max_similarity:
max_similarity = max_sim
max_similar_molecules = [
training_molecules[i]
[training_molecules[i], gen_mol]
for i, sim in enumerate(similarities)
if sim == max_sim
]
Expand All @@ -79,26 +80,41 @@ def find_max_similarity_molecules_tanimoto(
if not (os.path.exists(os.path.join(folder, "analysis"))):
os.makedirs(os.path.join(folder, "analysis"))
max_num = min(len(max_similar_molecules), max_num)
img_c = int(math.ceil(np.sqrt(max_num)))
img_c = int(math.ceil(np.sqrt(2 * max_num)))
figure = plt.figure()

for i in range(max_num):
mol = max_similar_molecules[i]
mol_train, mol_gen = max_similar_molecules[i]

assert isinstance(
mol, Chem.Mol
mol_train, Chem.Mol
), "elements should be molecules" # check if we have a molecule
assert isinstance(
mol_gen, Chem.Mol
), "elements should be molecules" # check if we have a molecule

ax = plt.subplot(img_c, img_c, 2 * i + 1)
mol_img = Draw.MolToImage(mol_train, size=(300, 300))
ax.imshow(mol_img)
title_str = f"Train: {Chem.MolToSmiles(mol_train)}"
ax.title.set_text(title_str)
ax.set_axis_off()

ax = plt.subplot(img_c, img_c, i + 1)
mol_img = Draw.MolToImage(mol, size=(300, 300))
ax = plt.subplot(img_c, img_c, 2 * i + 2)
mol_img = Draw.MolToImage(mol_gen, size=(300, 300))
ax.imshow(mol_img)
title_str = f"{Chem.MolToSmiles(mol)}"
title_str = f"Gen: {Chem.MolToSmiles(mol_gen)}"
ax.title.set_text(title_str)
ax.set_axis_off()
figure.suptitle(f"Dataset: {dataset}. Tanimoto Similarity: {max_similarity}")

figure.suptitle(
f"{method}. Dataset: {dataset}. Tanimoto Similarity: {round(max_similarity, 3)}"
)
plt.savefig(
os.path.join(
folder, "analysis", f"{dataset}_most_similar_molecules_tanimoto.png"
folder,
"analysis",
f"{dataset}_most_similar_molecules_tanimoto_{method}.png",
)
)

Expand Down Expand Up @@ -133,6 +149,12 @@ def find_max_similarity_molecules_tanimoto(
default="./",
help="Directory to save the results in an `analysis` folder",
)
parser.add_argument(
"--method",
type=str,
default="CCSD",
help="Name of the approach used to generate the molecules",
)
args = parser.parse_known_args()[0]

if args.single_mol is not None:
Expand All @@ -148,5 +170,11 @@ def find_max_similarity_molecules_tanimoto(
with open(os.path.join(args.folder, args.gen_mol_file), "rb") as f:
gen_molecules = pickle.load(f)
find_max_similarity_molecules_tanimoto(
gen_molecules, train_molecules, PLOT_RESULT, args.folder, MAX_NUM, args.dataset
gen_molecules,
train_molecules,
PLOT_RESULT,
args.folder,
MAX_NUM,
args.dataset,
args.method,
)
70 changes: 70 additions & 0 deletions scripts/run_tanimoto_benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""run_tanimoto_benchmark.py: Code to compare the average Tanimoto similarity between samples drawn from GDSS/CCSD and the training set.
"""

import os
import pickle
import sys

sys.path.insert(0, os.getcwd())

from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.DataStructs.cDataStructs import BulkTanimotoSimilarity
from tqdm import tqdm

from ccsd.src.utils.mol_utils import canonicalize_smiles, load_smiles

if __name__ == "__main__":
# Calculate Morgan fingerprints for QM9 training molecules
print("Loading train data...")
train_smiles, _ = load_smiles("QM9")
train_smiles = canonicalize_smiles(train_smiles)
train_molecules = [Chem.MolFromSmiles(m) for m in train_smiles]
training_fps = [
AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=1024)
for mol in train_molecules
]

# CCSD
# Load generated molecules
gen_mol_file = (
"samples/pkl/QM9/test/sample_qm9_CC_ccsd_qm9_CC-sample_Aug27-10-44-56_mols.pkl"
)
with open(os.path.join("./", gen_mol_file), "rb") as f:
gen_molecules = pickle.load(f)

# Calculate average Tanimoto similarity
avg_sim = 0
for mol_idx in tqdm(range(len(gen_molecules))):
gen_mol = gen_molecules[mol_idx]
# Calculate Morgan fingerprint for the generated molecule
gen_fp = AllChem.GetMorganFingerprintAsBitVect(gen_mol, 2, nBits=1024)

# Calculate Tanimoto similarity with all training molecules
similarities = BulkTanimotoSimilarity(gen_fp, training_fps)
avg_sim += max(similarities)
avg_sim /= len(gen_molecules)
print(f"CCSD: {round(avg_sim, 3)}")

# GDSS
# Load generated molecules
gen_mol_file = "samples/pkl/QM9/test/sample_qm9_retrained_gdss_qm9_retrained-sample_Aug27-14-01-35_mols.pkl"
with open(os.path.join("./", gen_mol_file), "rb") as f:
gen_molecules = pickle.load(f)

# Calculate average Tanimoto similarity
avg_sim2 = 0
for mol_idx in tqdm(range(len(gen_mol_file))):
gen_mol = gen_mol_file[mol_idx]
# Calculate Morgan fingerprint for the generated molecule
gen_fp = AllChem.GetMorganFingerprintAsBitVect(gen_mol, 2, nBits=1024)

# Calculate Tanimoto similarity with all training molecules
similarities = BulkTanimotoSimilarity(gen_fp, training_fps)
avg_sim2 += max(similarities)
avg_sim2 /= len(gen_mol_file)

print(f"GDSS: {round(avg_sim2, 3)}")

0 comments on commit 3460cc8

Please sign in to comment.