Skip to content

Commit

Permalink
fix synthon for USPTO50k
Browse files Browse the repository at this point in the history
  • Loading branch information
KiddoZhu committed Oct 14, 2021
1 parent 489677c commit e30e2b0
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 206 deletions.
202 changes: 0 additions & 202 deletions diff.txt

This file was deleted.

7 changes: 3 additions & 4 deletions torchdrug/datasets/uspto50k.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,9 @@ def _get_synthon(self, reactant, product):

if len(edge_added) > 0:
if len(edge_added) == 1: # add a single edge
edge = edge_added[0]
reverse_edge = edge.flip(0)
reverse_edge = edge_added.flip(1)
any = -torch.ones(2, 1, dtype=torch.long)
pattern = torch.cat([edge, reverse_edge])
pattern = torch.cat([edge_added, reverse_edge])
pattern = torch.cat([pattern, any], dim=-1)
index, num_match = product.match(pattern)
edge_mask = torch.ones(product.num_edge, dtype=torch.bool)
Expand All @@ -186,7 +185,7 @@ def _get_synthon(self, reactant, product):
_synthons = product.connected_components()[0]
assert len(_synthons) >= len(_reactants) # because a few samples contain multiple products

h, t = edge
h, t = edge_added[0]
reaction_center = torch.tensor([product.atom_map[h], product.atom_map[t]])
with _reactants.graph():
_reactants.reaction_center = reaction_center.expand(len(_reactants), -1)
Expand Down

0 comments on commit e30e2b0

Please sign in to comment.