Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 15 additions & 18 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -134,24 +134,21 @@ jobs:
name: Deploy docs
command: |
set -e;
if [ "${CIRCLE_BRANCH}" == "master" ]; then
git config --global user.email "circle@PythonOT.com";
git config --global user.name "Circle CI";
cd ~/PythonOT.github.io;
git checkout master
git remote -v
git fetch origin
git reset --hard origin/master
git clean -xdf
echo "Deploying dev docs for ${CIRCLE_BRANCH}.";
cp -a /tmp/build/html/* .;
touch .nojekyll;
git add -A;
git commit -m "CircleCI update of dev docs (${CIRCLE_BUILD_NUM}).";
git push origin master;
else
echo "No deployment (build: ${CIRCLE_BRANCH}).";
fi
git config --global user.email "circle@PythonOT.com";
git config --global user.name "Circle CI";
cd ~/PythonOT.github.io;
git checkout master
git remote -v
git fetch origin
git reset --hard origin/master
git clean -xdf
echo "Deploying dev docs for ${CIRCLE_BRANCH}.";
cp -a /tmp/build/html/* .;
touch .nojekyll;
git add -A;
git commit -m "CircleCI update of dev docs (${CIRCLE_BUILD_NUM}).";
git push origin master;



workflows:
Expand Down
2 changes: 1 addition & 1 deletion examples/plot_Intro_OT.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@
time_sinkhorn_reg[k] = time.time() - start

if k % 4 == 0 and k > 0: # we only plot a few
ax = pl.subplot(1, 5, k / 4)
ax = pl.subplot(1, 5, k // 4)
im = pl.imshow(ot_sinkhorn, vmin=0, vmax=max_ot)
pl.title('reg={0:.2g}'.format(reg_parameter[k]))
pl.xlabel('Cafés')
Expand Down
2 changes: 1 addition & 1 deletion ot/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1203,7 +1203,7 @@ def forward(ctx, val, grads, *inputs):
@staticmethod
def backward(ctx, grad_output):
# the gradients are grad
return (None, None) + ctx.grads
return (None, None) + tuple(g * grad_output for g in ctx.grads)

self.ValFunction = ValFunction

Expand Down
16 changes: 16 additions & 0 deletions test/test_ot.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,22 @@ def test_emd2_gradients():
assert b1.shape == b1.grad.shape
assert M1.shape == M1.grad.shape

# Testing for bug #309, checking for scaling of gradient
a2 = torch.tensor(a, requires_grad=True)
b2 = torch.tensor(a, requires_grad=True)
M2 = torch.tensor(M, requires_grad=True)

val = 10.0 * ot.emd2(a2, b2, M2)

val.backward()

assert np.allclose(10.0 * a1.grad.cpu().detach().numpy(),
a2.grad.cpu().detach().numpy())
assert np.allclose(10.0 * b1.grad.cpu().detach().numpy(),
b2.grad.cpu().detach().numpy())
assert np.allclose(10.0 * M1.grad.cpu().detach().numpy(),
M2.grad.cpu().detach().numpy())


def test_emd_emd2():
# test emd and emd2 for simple identity
Expand Down