-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Ensure that TripletEmbeddingTrainer does not introduce NaNs when norm…
…alizing input to the unit sphere
- Loading branch information
1 parent
ab80e8c
commit 9741550
Showing
4 changed files
with
59 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import torch | ||
import unittest2 | ||
from embedding import TripletEmbeddingTrainer | ||
from torch import nn | ||
import numpy as np | ||
|
||
|
||
class TripletEmbeddingTrainerTests(unittest2.TestCase): | ||
def test_normalization_does_not_cause_nans(self): | ||
class Network(nn.Module): | ||
def __init__(self): | ||
super(Network, self).__init__() | ||
|
||
def forward(self, x): | ||
return x | ||
|
||
network = Network() | ||
trainer = TripletEmbeddingTrainer(network, 100, 32, slice(None)) | ||
x = torch.zeros(8, 3) | ||
result = trainer._apply_network_and_normalize(x).data.numpy() | ||
self.assertFalse(np.any(np.isnan(result))) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import unittest2 | ||
import torch | ||
from util import batchwise_unit_norm | ||
import numpy as np | ||
|
||
|
||
class BatchwiseUnitNormTests(unittest2.TestCase): | ||
def test_all_elements_have_unit_norm(self): | ||
t = torch.FloatTensor(100, 5).normal_(0, 1) | ||
t = batchwise_unit_norm(t).data.numpy() | ||
norms = np.linalg.norm(t, axis=1) | ||
np.testing.assert_allclose(norms, 1, rtol=1e-6) | ||
|
||
def test_maintains_correct_shape_2d(self): | ||
t = torch.FloatTensor(100, 5).normal_(0, 1) | ||
t = batchwise_unit_norm(t).data.numpy() | ||
self.assertEqual((100, 5), t.shape) | ||
|
||
def test_maintains_correct_shape_3d(self): | ||
t = torch.FloatTensor(100, 5, 3).normal_(0, 1) | ||
t = batchwise_unit_norm(t).data.numpy() | ||
self.assertEqual((100, 5, 3), t.shape) | ||
|
||
def test_does_not_introduce_nans(self): | ||
t = torch.FloatTensor(100, 5, 3).zero_() | ||
t = batchwise_unit_norm(t).data.numpy() | ||
self.assertFalse(np.any(np.isnan(t))) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters