
Commit

Merge branch 'master' of https://github.com/anucvml/ddn
dylan-campbell committed Oct 10, 2019
2 parents 15012d8 + 65d8e7b commit 0e4d793
Showing 7 changed files with 322 additions and 37 deletions.
11 changes: 9 additions & 2 deletions README.md
@@ -28,8 +28,15 @@ export PYTHONPATH=${PYTHONPATH}:ddn
python tests/testBasicDeclNodes.py
```

Example applications for image and point cloud classification can be found under the `apps` directory. See
the `README` files therein for instructions on installation and how to run.
Tutorials should be opened in a Jupyter notebook:

```
cd tutorials
jupyter notebook
```

Reference (PyTorch) applications for image and point cloud classification can be found under the `apps`
directory. See the `README` files therein for installation and usage instructions.

## License

1 change: 1 addition & 0 deletions ddn/basic/composition.py
@@ -11,6 +11,7 @@ class ComposedNode(AbstractNode):
as such can be further composed with other nodes to form a chain."""

def __init__(self, nodeA, nodeB):
assert (nodeA.dim_y == nodeB.dim_x)
super().__init__(nodeA.dim_x, nodeB.dim_y)
self.nodeA = nodeA
self.nodeB = nodeB
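For context on the added assertion: `ComposedNode` chains `nodeB` after `nodeA`, so the output dimension of `nodeA` must equal the input dimension of `nodeB`. A hypothetical, self-contained sketch of that check (stand-in classes, not the ddn implementation):

```
class TinyNode:
    """Stand-in for AbstractNode: just records input/output dimensions."""
    def __init__(self, dim_x, dim_y):
        self.dim_x, self.dim_y = dim_x, dim_y

class TinyComposedNode(TinyNode):
    """Stand-in for ComposedNode: y = nodeB(nodeA(x))."""
    def __init__(self, nodeA, nodeB):
        # nodeA's output feeds nodeB's input, so the dimensions must agree
        assert nodeA.dim_y == nodeB.dim_x
        super().__init__(nodeA.dim_x, nodeB.dim_y)
        self.nodeA, self.nodeB = nodeA, nodeB

ok = TinyComposedNode(TinyNode(5, 3), TinyNode(3, 2))    # fine: 3 == 3
# TinyComposedNode(TinyNode(3, 2), TinyNode(5, 3))       # would trip the assertion: 2 != 5
```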
21 changes: 12 additions & 9 deletions ddn/basic/node.py
@@ -6,6 +6,7 @@
#

import autograd.numpy as np
import scipy as sci
import scipy.linalg  # load the linalg submodule explicitly; `import scipy` alone may not expose it
from autograd import grad, jacobian
import warnings

@@ -90,8 +91,7 @@ def gradient(self, x, y=None, ctx=None):
y, ctx = self.solve(x)
assert self._check_optimality_cond(x, y)

# TODO: replace with symmetric matrix solver
return -1.0 * np.linalg.solve(self.fYY(x, y), self.fXY(x, y))
return -1.0 * sci.linalg.solve(self.fYY(x, y), self.fXY(x, y), assume_a='pos')

def _check_optimality_cond(self, x, y, ctx=None):
"""Checks that the problem's first-order optimality condition is satisfied."""
@@ -159,8 +159,7 @@ def gradient(self, x, y=None, ctx=None):
B = self.fXY(x, y) - nu * self.hXY(x, y)
C = self.hX(x, y)
try:
# TODO: replace with symmetric solver
v = np.linalg.solve(H, np.concatenate((a.reshape((self.dim_y, 1)), B), axis=1))
v = sci.linalg.solve(H, np.concatenate((a.reshape((self.dim_y, 1)), B), axis=1), assume_a='pos')
except:
return np.full((self.dim_y, self.dim_x), np.nan).squeeze()
return (np.outer(v[:, 0], (v[:, 0].dot(B) - C) / v[:, 0].dot(a)) - v[:, 1:self.dim_x + 1]).squeeze()
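For readability of the stacked solve above: v = H^{-1}[a B] is computed with a single call, and the return line then assembles the gradient from its columns. Assuming H is symmetric (consistent with `assume_a='pos'`), and with H and a defined just above the visible hunk, the return expression corresponds to the following closed form (reconstructed from the code, not text from the commit):

```
v = H^{-1} \begin{bmatrix} a & B \end{bmatrix}, \qquad
\mathrm{D}y(x) = \frac{H^{-1} a \left( a^{\top} H^{-1} B - C \right)}{a^{\top} H^{-1} a} - H^{-1} B
```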
@@ -241,11 +240,15 @@ def gradient(self, x, y=None, ctx=None):
assert self._check_constraints(x, y)
assert self._check_optimality_cond(x, y, ctx)

# TODO: replace with symmetric matrix solver and avoid explicit inverse matrix computations
invH = np.linalg.inv(self.fYY(x, y))
invHAT = np.dot(invH, self.A.T)
w = np.dot(np.dot(invHAT, np.linalg.inv(np.dot(self.A, invHAT))), invHAT.T) - invH
return np.dot(w, self.fXY(x, y))
# TODO: write test case for LinEqConstDeclarativeNode
# use cholesky to solve H^{-1}A^T and H^{-1}B
C, L = sci.linalg.cho_factor(self.fYY(x, y))  # (C, L) is the (factor, lower-flag) pair returned by cho_factor
invHAT = sci.linalg.cho_solve((C, L), self.A.T)
invHB = sci.linalg.cho_solve((C, L), self.fXY(x, y))
# compute W = H^{-1}A^T (A H^{-1} A^T)^{-1} A
W = np.dot(invHAT, sci.linalg.solve(np.dot(self.A, invHAT), self.A))
# return H^{-1}A^T (A H^{-1} A^T)^{-1} A H^{-1} B - H^{-1} B
return np.dot(W, invHB) - invHB

def _check_constraints(self, x, y):
"""Check that the problem's constraints are satisfied."""
2 changes: 1 addition & 1 deletion ddn/basic/robust_nodes.py
@@ -15,7 +15,7 @@ class RobustAverage(NonUniqueDeclarativeNode):
minimize f(x, y) = \sum_{i=1}^{n} phi(y - x_i; alpha)
where phi(z; alpha) is one of the following robust penalties,
'quadratic': 1/2 z^2
'pseudo-huber': alpha^2 (\sqrt(1 + (z/alpha)^2 - 1)
'pseudo-huber': alpha^2 (\sqrt(1 + (z/alpha)^2) - 1)
'huber': 1/2 z^2 for |z| <= alpha and alpha |z| - 1/2 alpha^2 otherwise
'welsch': 1 - exp(-z^2 / 2 alpha^2)
'trunc-quad': 1/2 z^2 for |z| <= alpha and 1/2 alpha^2 otherwise
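The only change in this hunk is the missing closing parenthesis in the pseudo-Huber line. For reference, a standalone sketch (not repository code) of that penalty as written in the corrected docstring:

```
import numpy as np

def pseudo_huber(z, alpha=1.0):
    # phi(z; alpha) = alpha^2 (sqrt(1 + (z/alpha)^2) - 1)
    return alpha ** 2 * (np.sqrt(1.0 + (z / alpha) ** 2) - 1.0)

z = np.linspace(-3.0, 3.0, 7)
print(pseudo_huber(z))   # approximately 1/2 z^2 near zero, approximately alpha |z| for large |z|
```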
4 changes: 2 additions & 2 deletions ddn/pytorch/robustpool.py
@@ -266,7 +266,7 @@ def backward(ctx, grad_output):
grad_input = method.Dy(z, alpha) * grad_output.unsqueeze(-1).unsqueeze(-1)
# Unflatten:
grad_input = grad_input.reshape(input_size)
return grad_input, None, None, None, None
return grad_input, None, None

class RobustGlobalPool2d(torch.nn.Module):
def __init__(self, method, alpha=1.0):
@@ -304,4 +304,4 @@ def extra_repr(self):
input = (torch.randn(2, 3, 7, 7, dtype=torch.double, requires_grad=True), method, alpha_tensor)
test = gradcheck(robustPool, input, eps=1e-6, atol=1e-4, rtol=1e-3, raise_exception=True)
print("{}: {}".format(method.__name__, test))
"""
"""
315 changes: 292 additions & 23 deletions tutorials/01_simple_worked_example.ipynb

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions tutorials/TODO.md
@@ -0,0 +1,5 @@
- [ ] Diagram immediately after each optimisation problem
- [ ] Starting with a 1D example like the contour plot
- [ ] Linking the last two simple worked examples to Euclidean projection onto balls/spheres
- [ ] Adding a PyTorch tutorial
- [ ] Feedback from Itzik
