Skip to content

Commit

Permalink
Ensure that TripletEmbeddingTrainer does not introduce NaNs when norm…
Browse files Browse the repository at this point in the history
…alizing input to the unit sphere
  • Loading branch information
JohnVinyard committed Aug 30, 2018
1 parent ab80e8c commit 9741550
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 2 deletions.
4 changes: 2 additions & 2 deletions zounds/learn/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch import nn
from torch.optim import Adam
import torch
from util import trainable_parameters
from util import trainable_parameters, batchwise_unit_norm


class TripletEmbeddingTrainer(Trainer):
Expand Down Expand Up @@ -75,7 +75,7 @@ def _apply_network_and_normalize(self, x):
by section 4.2 of https://arxiv.org/pdf/1711.02209.pdf
"""
x = self.network(x)
return x / torch.norm(x, dim=1).view(-1, 1)
return batchwise_unit_norm(x)

def _select_batch(self, training_set):
indices = np.random.randint(0, len(training_set), self.batch_size)
Expand Down
21 changes: 21 additions & 0 deletions zounds/learn/test_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import torch
import unittest2
from embedding import TripletEmbeddingTrainer
from torch import nn
import numpy as np


class TripletEmbeddingTrainerTests(unittest2.TestCase):
def test_normalization_does_not_cause_nans(self):
class Network(nn.Module):
def __init__(self):
super(Network, self).__init__()

def forward(self, x):
return x

network = Network()
trainer = TripletEmbeddingTrainer(network, 100, 32, slice(None))
x = torch.zeros(8, 3)
result = trainer._apply_network_and_normalize(x).data.numpy()
self.assertFalse(np.any(np.isnan(result)))
27 changes: 27 additions & 0 deletions zounds/learn/test_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import unittest2
import torch
from util import batchwise_unit_norm
import numpy as np


class BatchwiseUnitNormTests(unittest2.TestCase):
def test_all_elements_have_unit_norm(self):
t = torch.FloatTensor(100, 5).normal_(0, 1)
t = batchwise_unit_norm(t).data.numpy()
norms = np.linalg.norm(t, axis=1)
np.testing.assert_allclose(norms, 1, rtol=1e-6)

def test_maintains_correct_shape_2d(self):
t = torch.FloatTensor(100, 5).normal_(0, 1)
t = batchwise_unit_norm(t).data.numpy()
self.assertEqual((100, 5), t.shape)

def test_maintains_correct_shape_3d(self):
t = torch.FloatTensor(100, 5, 3).normal_(0, 1)
t = batchwise_unit_norm(t).data.numpy()
self.assertEqual((100, 5, 3), t.shape)

def test_does_not_introduce_nans(self):
t = torch.FloatTensor(100, 5, 3).zero_()
t = batchwise_unit_norm(t).data.numpy()
self.assertFalse(np.any(np.isnan(t)))
9 changes: 9 additions & 0 deletions zounds/learn/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,15 @@ def sample_norm(x):
return original / x.view(-1, 1, x.shape[-1])


def batchwise_unit_norm(x, epsilon=1e-8):
batch_size = x.shape[0]
flattened = x.view(batch_size, -1)
norm = torch.norm(flattened, dim=1, keepdim=True)
expanded = norm.view(batch_size, *((1,) * (x.dim() - 1)))
normed = x / (expanded + epsilon)
return normed


def feature_map_size(inp, kernel, stride=1, padding=0):
return ((inp - kernel + (2 * padding)) / stride) + 1

Expand Down

0 comments on commit 9741550

Please sign in to comment.