Skip to content

Commit

Permalink
Renamed the Training Tuple Sampler to TupleSampler to avoid confusion…
Browse files Browse the repository at this point in the history
… with torch.utils.data.Sampler()
  • Loading branch information
Confusezius committed Sep 28, 2019
1 parent 109cca4 commit dbcd5ab
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def loss_select(loss, opt, to_optim):

"""================================================================================================="""
######### MAIN SAMPLER CLASS #################################
class Sampler():
class TupleSampler():
"""
Container for all sampling methods that can be used in conjunction with the respective loss functions.
Based on batch-wise sampling, i.e. given a batch of training data, sample useful data tuples that are
Expand Down Expand Up @@ -288,11 +288,11 @@ def __init__(self, margin=1, sampling_method='random'):
Args:
margin: float, Triplet Margin - Ensures that positives aren't placed arbitrarily close to the anchor.
Similarl, negatives should not be placed arbitrarily far away.
sampling_method: Method to use for sampling training triplets. Used for the Sampler-class.
sampling_method: Method to use for sampling training triplets. Used for the TupleSampler-class.
"""
super(TripletLoss, self).__init__()
self.margin = margin
self.sampler = Sampler(method=sampling_method)
self.sampler = TupleSampler(method=sampling_method)

def triplet_distance(self, anchor, positive, negative):
"""
Expand Down Expand Up @@ -335,7 +335,7 @@ def __init__(self, l2=0.02):
Nothing!
"""
super(NPairLoss, self).__init__()
self.sampler = Sampler(method='npair')
self.sampler = TupleSampler(method='npair')
self.l2 = l2

def npair_distance(self, anchor, positive, negatives):
Expand Down Expand Up @@ -409,7 +409,7 @@ def __init__(self, margin=0.2, nu=0, beta=1.2, n_classes=100, beta_constant=Fals
self.nu = nu

self.sampling_method = sampling_method
self.sampler = Sampler(method=sampling_method)
self.sampler = TupleSampler(method=sampling_method)


def forward(self, batch, labels):
Expand Down

0 comments on commit dbcd5ab

Please sign in to comment.