train errors in Synthon Completion part of Retrosynthesis tutorial #62

Closed
Drlittlelab opened this issue Jan 18, 2022 · 1 comment
Labels: bug (Something isn't working)

Comments

Drlittlelab commented Jan 18, 2022

When I reproduce the Synthon Completion part of the Retrosynthesis tutorial at https://torchdrug.ai/docs/tutorials/retrosynthesis.html and run the code on Windows, I get the following error:

Traceback (most recent call last):
  File "F:/workdir/pycharm/Retrosynthesis/main.py", line 109, in <module>
    synthon_solver.train(num_epoch=10)
  File "D:\soft\Anaconda3\envs\py37\lib\site-packages\torchdrug-0.1.0-py3.7.egg\torchdrug\core\engine.py", line 143, in train
    loss, metric = model(batch)
  File "D:\soft\Anaconda3\envs\py37\lib\site-packages\torch\nn\modules\module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "D:\soft\Anaconda3\envs\py37\lib\site-packages\torchdrug-0.1.0-py3.7.egg\torchdrug\tasks\retrosynthesis.py", line 596, in forward
    pred, target = self.predict_and_target(batch, all_loss, metric)
  File "D:\soft\Anaconda3\envs\py37\lib\site-packages\torchdrug-0.1.0-py3.7.egg\torchdrug\tasks\retrosynthesis.py", line 992, in predict_and_target
    graph1, node_in_target1, node_out_target1, bond_target1, stop_target1 = self.all_edge(reactant, synthon)
  File "D:\soft\Anaconda3\envs\py37\lib\site-packages\torch\autograd\grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "D:\soft\Anaconda3\envs\py37\lib\site-packages\torchdrug-0.1.0-py3.7.egg\torchdrug\tasks\retrosynthesis.py", line 561, in all_edge
    graph, feature_valid = self._update_molecule_feature(graph)
  File "D:\soft\Anaconda3\envs\py37\lib\site-packages\torchdrug-0.1.0-py3.7.egg\torchdrug\tasks\retrosynthesis.py", line 385, in _update_molecule_feature
    mols = graphs.to_molecule(ignore_error=True)
  File "D:\soft\Anaconda3\envs\py37\lib\site-packages\torchdrug-0.1.0-py3.7.egg\torchdrug\data\molecule.py", line 782, in to_molecule
    bond.SetStereoAtoms(*stereo_atoms[j])
RuntimeError: Pre-condition Violation
	bgnIdx not connected to begin atom of bond
	Violation occurred on line 247 in file Code\GraphMol\Bond.cpp
	Failed Expression: getOwningMol().getBondBetweenAtoms(getBeginAtomIdx(), bgnIdx) != nullptr
	RDKIT: 2021.03.5
	BOOST: 1_74

Running the same code on Linux gives a different error:

Traceback (most recent call last):
  File "synthon.py", line 49, in <module>
    synthon_solver.train(num_epoch=10)
  File "/cluster/projects/nn2855k/qifeng/soft/anaconda3/envs/torchdrug/lib/python3.7/site-packages/torchdrug/core/engine.py", line 143, in train
    loss, metric = model(batch)
  File "/cluster/projects/nn2855k/qifeng/soft/anaconda3/envs/torchdrug/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/cluster/projects/nn2855k/qifeng/soft/anaconda3/envs/torchdrug/lib/python3.7/site-packages/torchdrug/tasks/retrosynthesis.py", line 593, in forward
    pred, target = self.predict_and_target(batch, all_loss, metric)
  File "/cluster/projects/nn2855k/qifeng/soft/anaconda3/envs/torchdrug/lib/python3.7/site-packages/torchdrug/tasks/retrosynthesis.py", line 1000, in predict_and_target
    node_feature = graph.node_feature.float() + self.input_linear(synthon_feature)
RuntimeError: The size of tensor a (39) must match the size of tensor b (43) at non-singleton dimension 1

The code I am running is below:

from torchdrug import data, datasets, utils

reaction_dataset = datasets.USPTO50k("/home/work/torchdrug/data/",
                                     node_feature="center_identification",
                                     kekulize=True)
synthon_dataset = datasets.USPTO50k("/home/work/torchdrug/data/", as_synthon=True,
                                    node_feature="center_identification",
                                    kekulize=True)

from torchdrug.utils import plot

import torch

torch.manual_seed(1)
reaction_train, reaction_valid, reaction_test = reaction_dataset.split()
torch.manual_seed(1)
synthon_train, synthon_valid, synthon_test = synthon_dataset.split()


from torchdrug import core, models, tasks

reaction_model = models.RGCN(input_dim=reaction_dataset.node_feature_dim,
                    hidden_dims=[256, 256, 256, 256, 256, 256],
                    num_relation=reaction_dataset.num_bond_type,
                    concat_hidden=True)
reaction_task = tasks.CenterIdentification(reaction_model,
                                           feature=("graph", "atom", "bond"))
reaction_optimizer = torch.optim.Adam(reaction_task.parameters(), lr=1e-3)
reaction_solver = core.Engine(reaction_task, reaction_train, reaction_valid,
                              reaction_test, reaction_optimizer,
                              gpus=[0], batch_size=128)
# reaction_solver.train(num_epoch=50)
# reaction_solver.evaluate("valid")
# reaction_solver.save("g2gs_reaction_model.pth")
reaction_solver.load("g2gs_reaction_model.pth")

synthon_model = models.RGCN(input_dim=synthon_dataset.node_feature_dim,
                            hidden_dims=[256, 256, 256, 256, 256, 256],
                            num_relation=synthon_dataset.num_bond_type,
                            concat_hidden=True)
synthon_task = tasks.SynthonCompletion(synthon_model, feature=("graph",))

synthon_optimizer = torch.optim.Adam(synthon_task.parameters(), lr=1e-3)
synthon_solver = core.Engine(synthon_task, synthon_train, synthon_valid,
                             synthon_test, synthon_optimizer,
                             gpus=[0], batch_size=128)
synthon_solver.train(num_epoch=10)
synthon_solver.evaluate("valid")
synthon_solver.save("g2gs_synthon_model.pth")

What is the problem? Could anybody help me solve it?
Thanks.

KiddoZhu added the bug label on Jan 19, 2022
KiddoZhu (Contributor) commented Jan 30, 2022

Hi! The Windows error is a bug in older versions of TorchDrug (your traceback shows 0.1.0). Upgrading to the latest version (0.1.2.post1) should solve it.
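
To confirm which TorchDrug version an environment actually has before retrying, a small check along these lines may help. This is only a sketch: the helper names (release_tuple, is_older, check_torchdrug) are made up for illustration, the comparison looks only at the numeric release components and ignores suffixes such as ".post1", and importlib.metadata needs Python 3.8 or newer (on Python 3.7, `pip show torchdrug` reports the same information).

```python
from importlib import metadata


def release_tuple(version):
    """Return the leading numeric components, e.g. '0.1.2.post1' -> (0, 1, 2)."""
    parts = []
    for piece in version.split("."):
        if not piece.isdigit():
            break  # stop at the first non-numeric piece such as 'post1'
        parts.append(int(piece))
    return tuple(parts)


def is_older(installed, required):
    """True if the installed release is strictly older than the required one.

    Simplification: suffixes are ignored, so '0.1.2' and '0.1.2.post1'
    compare as equal here.
    """
    return release_tuple(installed) < release_tuple(required)


def check_torchdrug(required="0.1.2.post1"):
    """Describe the installed torchdrug version relative to `required`."""
    try:
        installed = metadata.version("torchdrug")
    except metadata.PackageNotFoundError:
        return "torchdrug is not installed in this environment"
    if is_older(installed, required):
        return f"torchdrug {installed} is older than {required}; consider upgrading"
    return f"torchdrug {installed} is not older than {required}"


if __name__ == "__main__":
    print(check_torchdrug())
```

If the reported version is older than 0.1.2.post1, upgrading the package in that environment is the next step.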

For the Linux problem, I found that the SynthonCompletion class is hard-coded for node_feature="synthon_completion". I just fixed it in a93356e and now it should work for arbitrary node features.

Note that you use node_feature="center_identification" for both datasets. Ideally, use node_feature="synthon_completion" for the synthon dataset to get the best performance.
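
The dataset setup suggested above could look roughly like the following. It is a sketch that reuses the datasets.USPTO50k arguments from the script in this issue and changes only the synthon dataset's node_feature; the names REACTION_KWARGS, SYNTHON_KWARGS and build_datasets are introduced here for illustration, and the data path is whatever directory you already use.

```python
# Keyword arguments for the two views of USPTO50k. The values are taken from
# this thread; only the synthon dataset's node_feature differs from the
# original script.
REACTION_KWARGS = {
    "node_feature": "center_identification",
    "kekulize": True,
}
SYNTHON_KWARGS = {
    "as_synthon": True,
    "node_feature": "synthon_completion",
    "kekulize": True,
}


def build_datasets(path):
    """Build the reaction and synthon datasets stored under `path`.

    Requires torchdrug to be installed; the import is done lazily so that
    this module can be imported without it.
    """
    from torchdrug import datasets

    reaction_dataset = datasets.USPTO50k(path, **REACTION_KWARGS)
    synthon_dataset = datasets.USPTO50k(path, **SYNTHON_KWARGS)
    return reaction_dataset, synthon_dataset
```

The rest of the script can stay as posted: the synthon model still takes input_dim from synthon_dataset.node_feature_dim, so it follows whichever node feature the synthon dataset was built with.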
