-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
272 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
|
||
"""create_tanimot_smiles.py: Code to create the molecule used in GDSS for Tanimoto on QM9. | ||
""" | ||
import matplotlib.pyplot as plt | ||
import rdkit | ||
from rdkit import Chem | ||
from rdkit.Chem import Draw | ||
from rdkit.Chem.rdchem import RWMol | ||
|
||
if __name__ == "__main__": | ||
# Create molecule | ||
C = rdkit.Chem.rdchem.Atom("C") | ||
O = rdkit.Chem.rdchem.Atom("O") | ||
N = rdkit.Chem.rdchem.Atom("N") | ||
|
||
mol = RWMol() | ||
mol.AddAtom(C) | ||
mol.AddAtom(C) | ||
mol.AddAtom(C) | ||
mol.AddAtom(C) | ||
mol.AddAtom(C) | ||
mol.AddBond(0, 1, rdkit.Chem.rdchem.BondType.DOUBLE) | ||
mol.AddBond(1, 2, rdkit.Chem.rdchem.BondType.SINGLE) | ||
mol.AddBond(2, 3, rdkit.Chem.rdchem.BondType.SINGLE) | ||
mol.AddBond(3, 4, rdkit.Chem.rdchem.BondType.SINGLE) | ||
mol.AddBond(4, 0, rdkit.Chem.rdchem.BondType.SINGLE) | ||
mol.AddAtom(N) | ||
mol.AddAtom(C) | ||
mol.AddBond(3, 5, rdkit.Chem.rdchem.BondType.SINGLE) | ||
mol.AddBond(3, 6, rdkit.Chem.rdchem.BondType.SINGLE) | ||
mol.AddBond(5, 6, rdkit.Chem.rdchem.BondType.SINGLE) | ||
mol.AddAtom(C) | ||
mol.AddAtom(O) | ||
mol.AddBond(6, 7, rdkit.Chem.rdchem.BondType.SINGLE) | ||
mol.AddBond(7, 8, rdkit.Chem.rdchem.BondType.DOUBLE) | ||
|
||
# Sanitize molecule | ||
rdkit.Chem.SanitizeMol(mol) | ||
|
||
# Convert to SMILES | ||
smiles = Chem.MolToSmiles(mol) | ||
print(smiles) | ||
|
||
# Plot molecule | ||
mol_img = Draw.MolToImage(mol, size=(300, 300)) | ||
plt.imshow(mol_img) | ||
plt.suptitle(f"SMILES: {smiles}") | ||
plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
|
||
"""run_frobenius_complexity.py: Code to assess the complexity of learning partial score functions using the Frobenius norm of the Jacobian of our models. | ||
""" | ||
|
||
import os | ||
import sys | ||
|
||
sys.path.insert(0, os.getcwd()) | ||
|
||
import torch | ||
import torch.autograd as autograd | ||
|
||
from ccsd.src.parsers.config import get_config | ||
from ccsd.src.utils.loader import load_model_optimizer, load_model_params | ||
from ccsd.src.utils.models_utils import get_nb_parameters | ||
|
||
|
||
def frobenius_norm_jacobian(model: torch.nn.Module, t: torch.Tensor) -> float: | ||
"""Calculate the Frobenius norm of the Jacobian matrix of a PyTorch model. | ||
Args: | ||
model (torch.nn.Module): The PyTorch model for which to compute the Jacobian. | ||
t (torch.Tensor): Input tensor for which to compute the Jacobian. | ||
Returns: | ||
float: The Frobenius norm of the Jacobian matrix of the model for the input tensor. | ||
""" | ||
# Evaluation mode and clear gradients | ||
model.eval() | ||
model.zero_grad() | ||
# Calculate the Jacobian matrix | ||
jac = autograd.functional.jacobian(model, t) | ||
# Compute the Frobenius norm | ||
frob_norm = torch.norm(jac, "fro").item() | ||
return frob_norm | ||
|
||
|
||
if __name__ == "__main__": | ||
for dataset in [ | ||
"QM9", | ||
"ENZYMES_small", | ||
"community_small", | ||
"ego_small", | ||
"grid_small", | ||
]: | ||
t = None # TO PROVIDE | ||
|
||
print("\n----------------------") | ||
print(f"{dataset}") | ||
print("-----") | ||
|
||
print("\nGraph") | ||
cfg = f"{dataset.lower()}" | ||
config = get_config(cfg, 42) | ||
params_x, params_adj = load_model_params(config, is_cc=False) | ||
try: | ||
model_x, optimizer_x, scheduler_x = load_model_optimizer( | ||
params_x, config.train, "cpu" | ||
) | ||
model_adj, optimizer_adj, scheduler_adj = load_model_optimizer( | ||
params_adj, config.train, "cpu" | ||
) | ||
|
||
print(f"Complexity x: {frobenius_norm_jacobian(model_x, t)}") | ||
print(f"Complexity adj: {frobenius_norm_jacobian(model_adj, t)}") | ||
except Exception as e: | ||
print("NaN") | ||
|
||
print("\nCC") | ||
cfg = f"{dataset.lower()}_CC" | ||
config = get_config(cfg, 42) | ||
params_x, params_adj, params_rank2 = load_model_params(config, is_cc=True) | ||
try: | ||
model_x, optimizer_x, scheduler_x = load_model_optimizer( | ||
params_x, config.train, "cpu" | ||
) | ||
model_adj, optimizer_adj, scheduler_adj = load_model_optimizer( | ||
params_adj, config.train, "cpu" | ||
) | ||
model_rank2, optimizer_rank2, scheduler_rank2 = load_model_optimizer( | ||
params_rank2, config.train, "cpu" | ||
) | ||
|
||
print(f"Complexity x: {frobenius_norm_jacobian(model_x, t)}") | ||
print(f"Complexity adj: {frobenius_norm_jacobian(model_adj, t)}") | ||
print(f"Complexity rank2: {frobenius_norm_jacobian(model_rank2, t)}") | ||
except Exception as e: | ||
print("NaN") | ||
|
||
print("\nCC Base Ablation study") | ||
cfg = f"{dataset.lower()}_Base_CC" | ||
config = get_config(cfg, 42) | ||
params_x, params_adj, params_rank2 = load_model_params(config, is_cc=True) | ||
try: | ||
model_x, optimizer_x, scheduler_x = load_model_optimizer( | ||
params_x, config.train, "cpu" | ||
) | ||
model_adj, optimizer_adj, scheduler_adj = load_model_optimizer( | ||
params_adj, config.train, "cpu" | ||
) | ||
model_rank2, optimizer_rank2, scheduler_rank2 = load_model_optimizer( | ||
params_rank2, config.train, "cpu" | ||
) | ||
|
||
print(f"Complexity x: {frobenius_norm_jacobian(model_x, t)}") | ||
print(f"Complexity adj: {frobenius_norm_jacobian(model_adj, t)}") | ||
print(f"Complexity rank2: {frobenius_norm_jacobian(model_rank2, t)}") | ||
except Exception as e: | ||
print("NaN") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
|
||
"""run_tanimoto_benchmark.py: Code to compare the average Tanimoto similarity between samples drawn from GDSS/CCSD and the training set. | ||
""" | ||
|
||
import os | ||
import pickle | ||
import sys | ||
|
||
sys.path.insert(0, os.getcwd()) | ||
|
||
from rdkit import Chem | ||
from rdkit.Chem import AllChem | ||
from rdkit.DataStructs.cDataStructs import BulkTanimotoSimilarity | ||
from tqdm import tqdm | ||
|
||
from ccsd.src.utils.mol_utils import canonicalize_smiles, load_smiles | ||
|
||
if __name__ == "__main__": | ||
# Calculate Morgan fingerprints for QM9 training molecules | ||
print("Loading train data...") | ||
train_smiles, _ = load_smiles("QM9") | ||
train_smiles = canonicalize_smiles(train_smiles) | ||
train_molecules = [Chem.MolFromSmiles(m) for m in train_smiles] | ||
training_fps = [ | ||
AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=1024) | ||
for mol in train_molecules | ||
] | ||
|
||
# CCSD | ||
# Load generated molecules | ||
gen_mol_file = ( | ||
"samples/pkl/QM9/test/sample_qm9_CC_ccsd_qm9_CC-sample_Aug27-10-44-56_mols.pkl" | ||
) | ||
with open(os.path.join("./", gen_mol_file), "rb") as f: | ||
gen_molecules = pickle.load(f) | ||
|
||
# Calculate average Tanimoto similarity | ||
avg_sim = 0 | ||
for mol_idx in tqdm(range(len(gen_molecules))): | ||
gen_mol = gen_molecules[mol_idx] | ||
# Calculate Morgan fingerprint for the generated molecule | ||
gen_fp = AllChem.GetMorganFingerprintAsBitVect(gen_mol, 2, nBits=1024) | ||
|
||
# Calculate Tanimoto similarity with all training molecules | ||
similarities = BulkTanimotoSimilarity(gen_fp, training_fps) | ||
avg_sim += max(similarities) | ||
avg_sim /= len(gen_molecules) | ||
print(f"CCSD: {round(avg_sim, 3)}") | ||
|
||
# GDSS | ||
# Load generated molecules | ||
gen_mol_file = "samples/pkl/QM9/test/sample_qm9_retrained_gdss_qm9_retrained-sample_Aug27-14-01-35_mols.pkl" | ||
with open(os.path.join("./", gen_mol_file), "rb") as f: | ||
gen_molecules = pickle.load(f) | ||
|
||
# Calculate average Tanimoto similarity | ||
avg_sim2 = 0 | ||
for mol_idx in tqdm(range(len(gen_mol_file))): | ||
gen_mol = gen_mol_file[mol_idx] | ||
# Calculate Morgan fingerprint for the generated molecule | ||
gen_fp = AllChem.GetMorganFingerprintAsBitVect(gen_mol, 2, nBits=1024) | ||
|
||
# Calculate Tanimoto similarity with all training molecules | ||
similarities = BulkTanimotoSimilarity(gen_fp, training_fps) | ||
avg_sim2 += max(similarities) | ||
avg_sim2 /= len(gen_mol_file) | ||
|
||
print(f"GDSS: {round(avg_sim2, 3)}") |