Commit 7392070: release v0.1.3

KiddoZhu committed Jun 4, 2022
1 parent cb117f8 commit 7392070
Showing 6 changed files with 49 additions and 43 deletions.
8 changes: 4 additions & 4 deletions conda/torchdrug/meta.yaml
@@ -1,21 +1,21 @@
 package:
   name: torchdrug
-  version: 0.1.2
+  version: 0.1.3
 
 source:
   path: ../..
 
 requirements:
   host:
-    - python >=3.7,<3.9
+    - python >=3.7,<3.10
     - pip
   run:
-    - python >=3.7,<3.9
+    - python >=3.7,<3.10
     - pytorch >=1.8.0
     - pytorch-scatter >=2.0.8
     - decorator
     - numpy >=1.11
-    - rdkit
+    - rdkit >=2020.09
     - matplotlib
     - tqdm
     - networkx
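Note that the version string for a release is duplicated across the conda recipe above, setup.py, and torchdrug/__init__.py, so all three must be bumped together. A minimal, hypothetical pre-release consistency check (not part of this commit; assumes PyYAML is installed and the script runs from the repository root):

import re

import yaml  # assumption: PyYAML is available

with open("conda/torchdrug/meta.yaml") as f:
    conda_version = yaml.safe_load(f)["package"]["version"]
with open("torchdrug/__init__.py") as f:
    init_version = re.search(r'__version__ = "(.+)"', f.read()).group(1)

# Both copies must agree with the release being cut.
assert conda_version == init_version == "0.1.3", (conda_version, init_version)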
34 changes: 17 additions & 17 deletions setup.py
@@ -13,7 +13,7 @@
     long_description_content_type="text/markdown",
     url="https://torchdrug.ai/",
     author="TorchDrug Team",
-    version="0.1.2",
+    version="0.1.3",
     license="Apache-2.0",
     keywords=["deep-learning", "pytorch", "drug-discovery"],
     packages=setuptools.find_packages(),
@@ -24,23 +24,23 @@
             "layers/functional/extension/*.cpp",
             "layers/functional/extension/*.cu",
             "utils/extension/*.cpp",
-            "utils/template/*.html"
-        ]},
+            "utils/template/*.html",
+        ]
+    },
     test_suite="nose.collector",
-    install_requires=
-    [
-        "torch>=1.8.0",
-        "torch-scatter>=2.0.8",
-        "decorator",
-        "numpy>=1.11",
-        "rdkit-pypi",
-        "matplotlib",
-        "tqdm",
-        "networkx",
-        "ninja",
-        "jinja2",
-    ],
-    python_requires=">=3.7,<3.9",
+    install_requires=[
+        "torch>=1.8.0",
+        "torch-scatter>=2.0.8",
+        "decorator",
+        "numpy>=1.11",
+        "rdkit-pypi>=2020.9",
+        "matplotlib",
+        "tqdm",
+        "networkx",
+        "ninja",
+        "jinja2",
+    ],
+    python_requires=">=3.7,<3.10",
     classifiers=[
         "Development Status :: 4 - Beta",
         'Intended Audience :: Developers',
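The widened python_requires=">=3.7,<3.10" pin is enforced by pip at install time, so 0.1.3 resolves on Python 3.7-3.9. A rough runtime equivalent, purely illustrative:

import sys

# Illustrative mirror of python_requires=">=3.7,<3.10"; pip performs this
# check itself during installation.
if not ((3, 7) <= sys.version_info[:2] < (3, 10)):
    raise RuntimeError("torchdrug 0.1.3 supports Python 3.7-3.9 only")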
10 changes: 5 additions & 5 deletions test/layers/test_conv.py
@@ -38,7 +38,7 @@ def test_graph_conv(self):
         adjacency /= adjacency.sum(dim=0, keepdim=True).sqrt() * adjacency.sum(dim=1, keepdim=True).sqrt()
         x = adjacency.t() @ self.input
         truth = conv.activation(conv.linear(x))
-        self.assertTrue(torch.allclose(result, truth, rtol=1e-4, atol=1e-7), "Incorrect graph convolution")
+        self.assertTrue(torch.allclose(result, truth, rtol=1e-2, atol=1e-3), "Incorrect graph convolution")
 
         num_head = 2
         conv = layers.GraphAttentionConv(self.input_dim, self.output_dim, num_head=num_head).cuda()
@@ -55,15 +55,15 @@ def test_graph_conv(self):
             outputs.append(output)
         truth = torch.cat(outputs, dim=-1)
         truth = conv.activation(truth)
-        self.assertTrue(torch.allclose(result, truth), "Incorrect graph attention convolution")
+        self.assertTrue(torch.allclose(result, truth, rtol=1e-2, atol=1e-3), "Incorrect graph attention convolution")
 
         eps = 1
         conv = layers.GraphIsomorphismConv(self.input_dim, self.output_dim, eps=eps).cuda()
         result = conv(self.graph, self.input)
         adjacency = self.graph.adjacency.to_dense().sum(dim=-1)
         x = (1 + eps) * self.input + adjacency.t() @ self.input
         truth = conv.activation(conv.mlp(x))
-        self.assertTrue(torch.allclose(result, truth, atol=1e-4, rtol=1e-7), "Incorrect graph isomorphism convolution")
+        self.assertTrue(torch.allclose(result, truth, rtol=1e-2, atol=1e-2), "Incorrect graph isomorphism convolution")
 
         conv = layers.RelationalGraphConv(self.input_dim, self.output_dim, self.num_relation).cuda()
         result = conv(self.graph, self.input)
@@ -72,7 +72,7 @@
         x = torch.einsum("htr, hd -> trd", adjacency, self.input)
         x = conv.linear(x.flatten(1)) + conv.self_loop(self.input)
         truth = conv.activation(x)
-        self.assertTrue(torch.allclose(result, truth, atol=1e-4, rtol=1e-7), "Incorrect relational graph convolution")
+        self.assertTrue(torch.allclose(result, truth, rtol=1e-2, atol=1e-3), "Incorrect relational graph convolution")
 
         conv = layers.ChebyshevConv(self.input_dim, self.output_dim, k=2).cuda()
         result = conv(self.graph, self.input)
@@ -83,7 +83,7 @@
         bases = [self.input, laplacian.t() @ self.input, (2 * laplacian.t() @ laplacian.t() - identity) @ self.input]
         x = conv.linear(torch.cat(bases, dim=-1))
         truth = conv.activation(x)
-        self.assertTrue(torch.allclose(result, truth, atol=1e-4, rtol=1e-7), "Incorrect chebyshev graph convolution")
+        self.assertTrue(torch.allclose(result, truth, rtol=1e-2, atol=1e-3), "Incorrect chebyshev graph convolution")
 
 
 if __name__ == "__main__":
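All five assertions above move to looser tolerances. torch.allclose(input, other, rtol, atol) checks |input - other| <= atol + rtol * |other| elementwise, so rtol=1e-2, atol=1e-3 admits roughly percent-level relative error, presumably to absorb the accumulated floating-point variance of the CUDA kernels these tests run on. A small sketch of the effect:

import torch

a = torch.tensor([1.000, 100.0])
b = torch.tensor([1.005, 100.5])  # 0.5% relative error

# |a - b| <= atol + rtol * |b| must hold elementwise.
print(torch.allclose(a, b, rtol=1e-4, atol=1e-7))  # False: old tolerances reject it
print(torch.allclose(a, b, rtol=1e-2, atol=1e-3))  # True: new tolerances accept it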
14 changes: 7 additions & 7 deletions test/layers/test_pool.py
@@ -43,8 +43,8 @@ def test_pool(self):
         truth_adj = torch.einsum("bna, bnm, bmc -> bac", assignment, adjacency, assignment)
         index = torch.arange(self.output_node, device=truth.device)
         truth_adj[:, index, index] = 0
-        self.assertTrue(torch.allclose(result, truth), "Incorrect diffpool node feature")
-        self.assertTrue(torch.allclose(result_adj, truth_adj), "Incorrect diffpool adjacency")
+        self.assertTrue(torch.allclose(result, truth, rtol=1e-3, atol=1e-4), "Incorrect diffpool node feature")
+        self.assertTrue(torch.allclose(result_adj, truth_adj, rtol=1e-3, atol=1e-4), "Incorrect diffpool adjacency")
 
         graph = self.graph[0]
         rng_state = torch.get_rng_state()
@@ -60,8 +60,8 @@
         truth_adj = torch.einsum("na, nm, mc -> ac", assignment, adjacency, assignment)
         index = torch.arange(self.output_node, device=truth.device)
         truth_adj[index, index] = 0
-        self.assertTrue(torch.allclose(result, truth), "Incorrect diffpool node feature")
-        self.assertTrue(torch.allclose(result_adj, truth_adj), "Incorrect diffpool adjacency")
+        self.assertTrue(torch.allclose(result, truth, rtol=1e-3, atol=1e-4), "Incorrect diffpool node feature")
+        self.assertTrue(torch.allclose(result_adj, truth_adj, rtol=1e-3, atol=1e-4), "Incorrect diffpool adjacency")
 
         pool = layers.MinCutPool(self.input_dim, self.output_node, self.feature_layer, self.pool_layer).cuda()
         all_loss = torch.tensor(0, dtype=torch.float32, device="cuda")
@@ -89,10 +89,10 @@ def test_pool(self):
         x = x - torch.eye(self.output_node, device=x.device) / (self.output_node ** 0.5)
         regularization = x.flatten(-2).norm(dim=-1).mean()
         truth_metric = {"normalized cut loss": cut_loss, "orthogonal regularization": regularization}
-        self.assertTrue(torch.allclose(result, truth), "Incorrect min cut pool feature")
-        self.assertTrue(torch.allclose(result_adj, truth_adj), "Incorrect min cut pool adjcency")
+        self.assertTrue(torch.allclose(result, truth, rtol=1e-3, atol=1e-4), "Incorrect min cut pool feature")
+        self.assertTrue(torch.allclose(result_adj, truth_adj, rtol=1e-3, atol=1e-4), "Incorrect min cut pool adjcency")
         for key in result_metric:
-            self.assertTrue(torch.allclose(result_metric[key], truth_metric[key], atol=1e-4, rtol=1e-7),
+            self.assertTrue(torch.allclose(result_metric[key], truth_metric[key], rtol=1e-3, atol=1e-4),
                             "Incorrect min cut pool metric")


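The diffpool test above captures torch.get_rng_state() before the forward pass so the reference computation can replay exactly the same random draws as the layer under test. A minimal sketch of that pattern, independent of TorchDrug:

import torch

rng_state = torch.get_rng_state()     # snapshot the CPU RNG
x = torch.rand(3)

torch.set_rng_state(rng_state)        # rewind to the snapshot
assert torch.equal(x, torch.rand(3))  # the same numbers come out again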
2 changes: 1 addition & 1 deletion torchdrug/__init__.py
@@ -12,4 +12,4 @@
 handler.setFormatter(format)
 logger.addHandler(handler)
 
-__version__ = "0.1.2"
+__version__ = "0.1.3"
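After upgrading, the bump is visible at runtime; a quick sanity check:

import torchdrug

# The package reports the new release version.
assert torchdrug.__version__ == "0.1.3"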
24 changes: 15 additions & 9 deletions torchdrug/utils/decorator.py
@@ -1,5 +1,6 @@
 import inspect
 import warnings
+import functools
 
 from decorator import decorator
 
@@ -100,14 +101,19 @@ def deprecated_alias(**alias):
     Handle argument alias for a function and output deprecated warnings.
     """
 
-    def wrapper(func, *args, **kwargs):
-        for key, value in alias.items():
-            if key in kwargs:
-                if value in kwargs:
-                    raise TypeError("%s() got values for both `%s` and `%s`" % (func.__name__, value, key))
-                warnings.warn("%s(): argument `%s` is deprecated in favor of `%s`" % (func.__name__, key, value))
-                kwargs[value] = kwargs.pop(key)
+    def decorate(func):
 
-        return func(*args, **kwargs)
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            for key, value in alias.items():
+                if key in kwargs:
+                    if value in kwargs:
+                        raise TypeError("%s() got values for both `%s` and `%s`" % (func.__name__, value, key))
+                    warnings.warn("%s(): argument `%s` is deprecated in favor of `%s`" % (func.__name__, key, value))
+                    kwargs[value] = kwargs.pop(key)
 
-    return decorator(wrapper, kwsyntax=True)
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    return decorate
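The rewrite replaces the decorator-library helper with a plain closure wrapped in functools.wraps, which preserves the decorated function's name and docstring. A usage sketch (the function and argument names below are invented for illustration; deprecated_alias itself lives in torchdrug.utils.decorator per this diff):

from torchdrug.utils.decorator import deprecated_alias

# Hypothetical example: `node_feature` was renamed to `atom_feature`, but old
# call sites keep working and emit a DeprecationWarning-style message.
@deprecated_alias(node_feature="atom_feature")
def featurize(mol, atom_feature="default"):
    return atom_feature

featurize("C", node_feature="symbol")  # warns, forwards as atom_feature="symbol"
featurize("C", atom_feature="symbol")  # preferred spelling, no warning
# Passing both spellings at once raises TypeError.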
