Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'dev-1.2' into dev-1.2-merge
- Loading branch information
Showing
99 changed files
with
3,993 additions
and
3,258 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,3 +11,5 @@ MANIFEST | |
*.c | ||
.asv/ | ||
autograd.egg-info/ | ||
asv.conf.json | ||
benchmarks/asv.conf.js |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,8 @@ | ||
from __future__ import absolute_import | ||
from .core import primitive, make_vjp, getval | ||
from . import container_types | ||
from .container_types import make_tuple, make_list, make_dict | ||
from .convenience_wrappers import (grad, multigrad, multigrad_dict, elementwise_grad, | ||
value_and_grad, grad_and_aux, hessian_tensor_product, | ||
hessian_vector_product, hessian, jacobian, | ||
tensor_jacobian_product, vector_jacobian_product, | ||
grad_named, checkpoint, make_hvp, value_and_multigrad, | ||
make_jvp, make_ggnvp) | ||
from .differential_operators import ( | ||
make_vjp, grad, multigrad_dict, elementwise_grad, value_and_grad, | ||
grad_and_aux, hessian_tensor_product, hessian_vector_product, hessian, | ||
jacobian, tensor_jacobian_product, vector_jacobian_product, grad_named, | ||
checkpoint, make_hvp, make_jvp, make_ggnvp, make_jvp, | ||
deriv, holomorphic_grad) | ||
from .builtins import isinstance, type, tuple, list, dict |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
import itertools | ||
from future.utils import with_metaclass | ||
from .util import subvals | ||
from .extend import (Box, primitive, notrace_primitive, VSpace, vspace, | ||
SparseObject, defvjp, defvjp_argnum, defjvp, defjvp_argnum) | ||
|
||
isinstance_ = isinstance | ||
isinstance = notrace_primitive(isinstance) | ||
|
||
type_ = type | ||
type = notrace_primitive(type) | ||
|
||
tuple_, list_, dict_ = tuple, list, dict | ||
|
||
@primitive | ||
def container_take(A, idx): | ||
return A[idx] | ||
def grad_container_take(ans, A, idx): | ||
return lambda g: container_untake(g, idx, vspace(A)) | ||
defvjp(container_take, grad_container_take) | ||
defjvp(container_take, 'same') | ||
|
||
class SequenceBox(Box): | ||
__slots__ = [] | ||
__getitem__ = container_take | ||
def __len__(self): return len(self._value) | ||
def __add__(self, other): return sequence_extend_right(self, *other) | ||
def __radd__(self, other): return sequence_extend_left(self, *other) | ||
def __contains__(self, elt): return elt in self._value | ||
def index(self, elt): return self._value.index(elt) | ||
SequenceBox.register(tuple_) | ||
SequenceBox.register(list_) | ||
|
||
class DictBox(Box): | ||
__slots__ = [] | ||
__getitem__= container_take | ||
def __len__(self): return len(self._value) | ||
def __iter__(self): return self._value.__iter__() | ||
def __contains__(self, elt): return elt in self._value | ||
def items(self): return list(self.iteritems()) | ||
def keys(self): return list(self.iterkeys()) | ||
def values(self): return list(self.itervalues()) | ||
def iteritems(self): return ((k, self[k]) for k in self) | ||
def iterkeys(self): return iter(self) | ||
def itervalues(self): return (self[k] for k in self) | ||
def get(self, k, d=None): return self[k] if k in self else d | ||
DictBox.register(dict_) | ||
|
||
@primitive | ||
def container_untake(x, idx, vs): | ||
if isinstance(idx, slice): | ||
accum = lambda result: [elt_vs._mut_add(a, b) | ||
for elt_vs, a, b in zip(vs.shape[idx], result, x)] | ||
else: | ||
accum = lambda result: vs.shape[idx]._mut_add(result, x) | ||
def mut_add(A): | ||
return vs._subval(A, idx, accum(A[idx])) | ||
return SparseObject(vs, mut_add) | ||
defvjp(container_untake, lambda ans, x, idx, _: | ||
lambda g: container_take(g, idx)) | ||
defjvp(container_untake, 'same') | ||
|
||
@primitive | ||
def sequence_extend_right(seq, *elts): | ||
return seq + type(seq)(elts) | ||
def grad_sequence_extend_right(argnum, ans, args, kwargs): | ||
seq, elts = args[0], args[1:] | ||
return lambda g: g[:len(seq)] if argnum == 0 else g[len(seq) + argnum - 1] | ||
defvjp_argnum(sequence_extend_right, grad_sequence_extend_right) | ||
|
||
@primitive | ||
def sequence_extend_left(seq, *elts): | ||
return type(seq)(elts) + seq | ||
def grad_sequence_extend_left(argnum, ans, args, kwargs): | ||
seq, elts = args[0], args[1:] | ||
return lambda g: g[len(elts):] if argnum == 0 else g[argnum - 1] | ||
defvjp_argnum(sequence_extend_left, grad_sequence_extend_left) | ||
|
||
@primitive | ||
def make_sequence(seq_type, *args): | ||
return seq_type(args) | ||
defvjp_argnum(make_sequence, lambda argnum, *args: lambda g: g[argnum - 1]) | ||
|
||
def fwd_grad_make_sequence(argnum, g, ans, seq_type, *args, **kwargs): | ||
return container_untake(g, argnum-1, vspace(ans)) | ||
|
||
defjvp_argnum(make_sequence, fwd_grad_make_sequence) | ||
|
||
|
||
class TupleMeta(type_): | ||
def __instancecheck__(self, instance): | ||
return isinstance(instance, tuple_) | ||
class tuple(with_metaclass(TupleMeta, tuple_)): | ||
def __new__(cls, xs): | ||
return make_sequence(tuple_, *xs) | ||
|
||
class ListMeta(type_): | ||
def __instancecheck__(self, instance): | ||
return isinstance(instance, list_) | ||
class list(with_metaclass(ListMeta, list_)): | ||
def __new__(cls, xs): | ||
return make_sequence(list_, *xs) | ||
|
||
class DictMeta(type_): | ||
def __instancecheck__(self, instance): | ||
return isinstance(instance, dict_) | ||
class dict(with_metaclass(DictMeta, dict_)): | ||
def __new__(cls, args, **kwargs): | ||
keys, vals = zip(*itertools.chain(args, kwargs.items())) | ||
return _make_dict(keys, list(vals)) | ||
|
||
@primitive | ||
def _make_dict(keys, vals): | ||
return dict_(zip(keys, vals)) | ||
defvjp(_make_dict, | ||
lambda ans, keys, vals: lambda g: | ||
list(g[key] for key in keys), argnums=(1,)) | ||
|
||
class ContainerVSpace(VSpace): | ||
def __init__(self, value): | ||
self.shape = value | ||
self.shape = self._map(vspace) | ||
|
||
@property | ||
def size(self): return sum(self._values(self._map(lambda vs: vs.size))) | ||
def zeros(self): return self._map(lambda vs: vs.zeros()) | ||
def ones(self): return self._map(lambda vs: vs.ones()) | ||
def randn(self): return self._map(lambda vs: vs.randn()) | ||
def standard_basis(self): | ||
zero = self.zeros() | ||
for i, vs in self._kv_pairs(self.shape): | ||
for x in vs.standard_basis(): | ||
yield self._subval(zero, i, x) | ||
def _add(self, xs, ys): | ||
return self._map(lambda vs, x, y: vs._add(x, y), xs, ys) | ||
def _mut_add(self, xs, ys): | ||
return self._map(lambda vs, x, y: vs._mut_add(x, y), xs, ys) | ||
def _scalar_mul(self, xs, a): | ||
return self._map(lambda vs, x: vs._scalar_mul(x, a), xs) | ||
def _inner_prod(self, xs, ys): | ||
return sum(self._values(self._map(lambda vs, x, y: vs._inner_prod(x, y), xs, ys))) | ||
def _covector(self, xs): | ||
return self._map(lambda vs, x: vs._covector(x), xs) | ||
|
||
class SequenceVSpace(ContainerVSpace): | ||
def _values(self, x): return x | ||
def _kv_pairs(self, x): return enumerate(x) | ||
def _map(self, f, *args): | ||
return self.seq_type(map(f, self.shape, *args)) | ||
def _subval(self, xs, idx, x): | ||
return self.seq_type(subvals(xs, [(idx, x)])) | ||
|
||
class ListVSpace(SequenceVSpace): seq_type = list_ | ||
class TupleVSpace(SequenceVSpace): seq_type = tuple_ | ||
class DictVSpace(ContainerVSpace): | ||
def _values(self, x): return x.values() | ||
def _kv_pairs(self, x): return x.items() | ||
def _map(self, f, *args):return {k: f(vs, *[x[k] for x in args]) | ||
for k, vs in self.shape.items()} | ||
def _subval(self, xs, idx, x): | ||
d = dict(xs.items()) | ||
d[idx] = x | ||
return d | ||
|
||
ListVSpace.register(list_) | ||
TupleVSpace.register(tuple_) | ||
DictVSpace.register(dict_) |
Oops, something went wrong.