Commit bbcd9df

Adapt to new version of pytorch: unsqueeze after reduce
ducha-aiki committed Aug 6, 2017
1 parent 53d4120 commit bbcd9df
Showing 3 changed files with 53 additions and 39 deletions.
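
For context: this adapts to the PyTorch 0.2 change where reductions such as torch.sum, torch.std, and torch.min no longer keep the reduced dimension as size 1, so their results must be unsqueezed back before expand_as. A minimal sketch of the pattern, not part of the commit (shapes and variable names are illustrative):

```python
import torch

x = torch.randn(4, 128)  # batch of 4 descriptors

# Reductions now drop the reduced dimension: shape is (4,), not (4, 1).
norm = torch.sqrt(torch.sum(x * x, dim=1) + 1e-10)

# expand_as needs a matching number of dims, so restore the lost one first.
x_normalized = x / norm.unsqueeze(-1).expand_as(x)

# Equivalent alternative: keep the dimension during the reduction itself.
norm_kept = torch.sqrt(torch.sum(x * x, dim=1, keepdim=True) + 1e-10)
assert torch.allclose(x_normalized, x / norm_kept.expand_as(x))
```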

HardNet.py: 27 additions & 26 deletions

@@ -255,55 +255,56 @@ def __len__(self):
         else:
             return self.matches.size(0)
 
-class TNet(nn.Module):
-    """TFeat model definition
+class HardNet(nn.Module):
+    """HardNet model definition
     """
     def __init__(self):
-        super(TNet, self).__init__()
+        super(HardNet, self).__init__()
 
         self.features = nn.Sequential(
-            nn.Conv2d(1, 32, kernel_size=3, padding=1),
+            nn.Conv2d(1, 32, kernel_size=3, padding=1, bias = False),
             nn.BatchNorm2d(32, affine=False),
             nn.ReLU(),
-            nn.Conv2d(32, 32, kernel_size=3, padding=1),
+            nn.Conv2d(32, 32, kernel_size=3, padding=1, bias = False),
             nn.BatchNorm2d(32, affine=False),
             nn.ReLU(),
-            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
+            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1, bias = False),
             nn.BatchNorm2d(64, affine=False),
             nn.ReLU(),
-            nn.Conv2d(64, 64, kernel_size=3, padding=1),
+            nn.Conv2d(64, 64, kernel_size=3, padding=1, bias = False),
             nn.BatchNorm2d(64, affine=False),
             nn.ReLU(),
-            nn.Conv2d(64, 128, kernel_size=3, stride=2,padding=1),
+            nn.Conv2d(64, 128, kernel_size=3, stride=2,padding=1, bias = False),
             nn.BatchNorm2d(128, affine=False),
             nn.ReLU(),
-            nn.Conv2d(128, 128, kernel_size=3, padding=1),
+            nn.Conv2d(128, 128, kernel_size=3, padding=1, bias = False),
             nn.BatchNorm2d(128, affine=False),
             nn.ReLU(),
             nn.Dropout(0.1),
-            nn.Conv2d(128, 128, kernel_size=8),
+            nn.Conv2d(128, 128, kernel_size=8, bias = False),
             nn.BatchNorm2d(128, affine=False),
 
         )
         self.features.apply(weights_init)
-
-    def forward(self, input):
-        flat = input.view(input.size(0), -1)
+        return
+    def input_norm(self,x):
+        flat = x.view(x.size(0), -1)
         mp = torch.sum(flat, dim=1) / (32. * 32.)
         sp = torch.std(flat, dim=1) + 1e-7
-        x_features = self.features(
-            (input - mp.unsqueeze(-1).unsqueeze(-1).expand_as(input)) / sp.unsqueeze(-1).unsqueeze(1).expand_as(input))
+        return (x - mp.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand_as(x)) / sp.unsqueeze(-1).unsqueeze(-1).unsqueeze(1).expand_as(x)
+
+    def forward(self, input):
+        x_features = self.features(self.input_norm(input))
         x = x_features.view(x_features.size(0), -1)
         return L2Norm()(x)
 
 def weights_init(m):
     if isinstance(m, nn.Conv2d):
         nn.init.orthogonal(m.weight.data, gain=0.7)
-        nn.init.constant(m.bias.data, 0.01)
-    if isinstance(m, nn.Linear):
-        nn.init.orthogonal(m.weight.data, gain=0.01)
-        nn.init.constant(m.bias.data, 0.)
-
+        try:
+            nn.init.constant(m.bias.data, 0.01)
+        except:
+            pass
     return
 
 def create_loaders(load_random_triplets = False):
 
     test_dataset_names = copy.copy(dataset_names)
@@ -366,10 +367,10 @@ def train(train_loader, model, optimizer, epoch, logger, load_triplets = False)
         out_a, out_p = model(data_a), model(data_p)
         #hardnet loss
         if args.batch_reduce == 'L2Net':
-            loss = loss_L2Net(out_a, out_p, anchor_swap =args.anchorswap, margin = args.margin, loss_type = args.loss)
+            loss = loss_L2Net(out_a, out_p, column_row_swap = True, anchor_swap =args.anchorswap, margin = args.margin, loss_type = args.loss)
         else:
-            loss = 3 * loss_HardNet(out_a, out_p,
-                            margin=args.margin,
+            loss = loss_HardNet(out_a, out_p,
+                            margin=args.margin, column_row_swap = True,
                             anchor_swap=args.anchorswap,
                             anchor_ave=args.anchorave,
                             batch_reduce = args.batch_reduce,
@@ -410,7 +411,7 @@ def test(test_loader, model, epoch, logger, logger_test_name):

         out_a, out_p = model(data_a), model(data_p)
         dists = torch.sqrt(torch.sum((out_a - out_p) ** 2, 1)) # euclidean distance
-        distances.append(dists.data.cpu().numpy())
+        distances.append(dists.data.cpu().numpy().reshape(-1,1))
         ll = label.data.cpu().numpy().reshape(-1, 1)
         labels.append(ll)

@@ -528,7 +529,7 @@ def main(train_loader, test_loaders, model, logger, file_logger):
     if not os.path.isdir(DESCS_DIR):
         os.makedirs(DESCS_DIR)
     logger, file_logger = None, None
-    model = TNet()
+    model = HardNet()
     if(args.enable_logging):
         from Loggers import Logger, FileLogger
         logger = Logger(LOG_DIR)

Losses.py: 24 additions & 11 deletions

@@ -5,8 +5,8 @@
 def distance_matrix_vector(anchor, positive):
     """Given batch of anchor descriptors and positive descriptors calculate distance matrix"""
 
-    d1_sq = torch.sum(anchor * anchor, dim=1)
-    d2_sq = torch.sum(positive * positive, dim=1)
+    d1_sq = torch.sum(anchor * anchor, dim=1).unsqueeze(-1)
+    d2_sq = torch.sum(positive * positive, dim=1).unsqueeze(-1)
 
     eps = 1e-6
     return torch.sqrt((d1_sq.repeat(1, anchor.size(0)) + torch.t(d2_sq.repeat(1, positive.size(0)))
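
For reference, distance_matrix_vector computes all pairwise Euclidean distances via the identity ||a - p||^2 = ||a||^2 + ||p||^2 - 2*a.p; with the new reduction semantics the squared norms come back 1-D, hence the unsqueeze(-1) that turns them into (N, 1) columns before repeat. A standalone sketch of the same computation, assuming modern PyTorch broadcasting (the helper name is hypothetical):

```python
import torch

def pairwise_distances(anchor, positive, eps=1e-6):
    # ||a - p||^2 = ||a||^2 + ||p||^2 - 2 * a.p, for all (anchor, positive) pairs
    d1_sq = torch.sum(anchor * anchor, dim=1).unsqueeze(-1)    # (N, 1) column
    d2_sq = torch.sum(positive * positive, dim=1).unsqueeze(-1)
    cross = torch.mm(anchor, positive.t())                     # (N, N) dot products
    return torch.sqrt(d1_sq + d2_sq.t() - 2.0 * cross + eps)   # broadcast add

a, p = torch.randn(8, 128), torch.randn(8, 128)
print(pairwise_distances(a, p).shape)  # torch.Size([8, 8])
```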
@@ -52,7 +52,7 @@ def loss_random_sampling(anchor, positive, negative, anchor_swap = False, margin
     loss = torch.mean(loss)
     return loss
 
-def loss_L2Net(anchor, positive, anchor_swap = False, margin = 1.0, loss_type = "triplet_margin"):
+def loss_L2Net(anchor, positive, column_row_swap = False,anchor_swap = False, margin = 1.0, loss_type = "triplet_margin"):
     """L2Net losses: using whole batch as negatives, not only hardest.
     """
 
@@ -73,22 +73,22 @@ def loss_L2Net(anchor, positive, anchor_swap = False, margin = 1.0, loss_type =
         exp_pos = torch.exp(2.0 - pos1);
         exp_den = torch.sum(torch.exp(2.0 - dist_matrix),1) + eps;
         loss = -torch.log( exp_pos / exp_den )
-        if anchor_swap:
+        if column_row_swap:
             exp_den1 = torch.sum(torch.exp(2.0 - dist_matrix),0) + eps;
             loss += -torch.log( exp_pos / exp_den1 )
     else:
         print ('Only softmax loss works with L2Net sampling')
         sys.exit(1)
     loss = torch.mean(loss)
     return loss
-def loss_HardNet(anchor, positive, anchor_swap = False, anchor_ave = False, margin = 1.0, batch_reduce = 'min', loss_type = "triplet_margin"):
+def loss_HardNet(anchor, positive, column_row_swap = False, anchor_swap = False, anchor_ave = False, margin = 1.0, batch_reduce = 'min', loss_type = "triplet_margin"):
     """HardNet margin loss - calculates loss based on distance matrix based on positive distance and closest negative distance.
     """
 
     assert anchor.size() == positive.size(), "Input sizes between positive and negative must be equal."
     assert anchor.dim() == 2, "Inputd must be a 2D matrix."
     eps = 1e-8
-    dist_matrix = distance_matrix_vector(anchor, positive)
+    dist_matrix = distance_matrix_vector(anchor, positive) +eps
     eye = torch.autograd.Variable(torch.eye(dist_matrix.size(1))).cuda()
 
     # steps to filter out same patches that occur in distance matrix as negatives
@@ -99,22 +99,35 @@ def loss_HardNet(anchor, positive, anchor_swap = False, anchor_ave = False, marg
     dist_without_min_on_diag = dist_without_min_on_diag+mask
     if batch_reduce == 'min':
         min_neg = torch.min(dist_without_min_on_diag,1)[0]
-        if anchor_swap:
-            min_neg2 = torch.t(torch.min(dist_without_min_on_diag,0)[0])
+        if column_row_swap:
+            min_neg2 = torch.min(dist_without_min_on_diag,0)[0]
             min_neg = torch.min(min_neg,min_neg2)
-        min_neg = torch.t(min_neg).squeeze(0)
+        if False:
+            dist_matrix_a = distance_matrix_vector(anchor, anchor)+ eps
+            dist_matrix_p = distance_matrix_vector(positive,positive)+eps
+            dist_without_min_on_diag_a = dist_matrix_a+eye*10
+            dist_without_min_on_diag_p = dist_matrix_p+eye*10
+            min_neg_a = torch.min(dist_without_min_on_diag_a,1)[0]
+            min_neg_p = torch.t(torch.min(dist_without_min_on_diag_p,0)[0])
+            min_neg_3 = torch.min(min_neg_p,min_neg_a)
+            min_neg = torch.min(min_neg,min_neg_3)
+            print (min_neg_a)
+            print (min_neg_p)
+            print (min_neg_3)
+            print (min_neg)
+        min_neg = min_neg
         pos = pos1
     elif batch_reduce == 'average':
         pos = pos1.repeat(anchor.size(0)).view(-1,1).squeeze(0)
         min_neg = dist_without_min_on_diag.view(-1,1)
-        if anchor_swap:
+        if column_row_swap:
             min_neg2 = torch.t(dist_without_min_on_diag).contiguous().view(-1,1)
             min_neg = torch.min(min_neg,min_neg2)
         min_neg = min_neg.squeeze(0)
     elif batch_reduce == 'random':
         idxs = torch.autograd.Variable(torch.randperm(anchor.size()[0]).long()).cuda()
         min_neg = dist_without_min_on_diag.gather(1,idxs.view(-1,1))
-        if anchor_swap:
+        if column_row_swap:
             min_neg2 = torch.t(dist_without_min_on_diag).gather(1,idxs.view(-1,1))
             min_neg = torch.min(min_neg,min_neg2)
         min_neg = torch.t(min_neg).squeeze(0)
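
The column_row_swap branches mine the hardest negative along both rows and columns of the distance matrix, i.e. for each anchor against all positives and for each positive against all anchors. A condensed sketch of the batch_reduce == 'min' path, assuming modern PyTorch (no autograd.Variable; the diagonal offset of 10 mirrors the masking in the code above):

```python
import torch

def hardest_negatives(dist_matrix, column_row_swap=True):
    # Push the diagonal (the true positive pairs) out of reach of min().
    n = dist_matrix.size(0)
    masked = dist_matrix + torch.eye(n) * 10.0

    min_neg = torch.min(masked, 1)[0]           # hardest negative in each row
    if column_row_swap:
        min_neg2 = torch.min(masked, 0)[0]      # hardest negative in each column
        min_neg = torch.min(min_neg, min_neg2)  # keep the harder of the two
    return min_neg

d = torch.rand(8, 8) + 1.0  # fake distance matrix
print(hardest_negatives(d))  # one hardest-negative distance per anchor
```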

Utils.py: 2 additions & 2 deletions

@@ -16,7 +16,7 @@ def __init__(self):
         self.eps = 1e-10
     def forward(self, x):
         norm = torch.sqrt(torch.sum(x * x, dim = 1) + self.eps)
-        x= x / norm.expand_as(x)
+        x= x / norm.unsqueeze(-1).expand_as(x)
         return x
 
 class L1Norm(nn.Module):
@@ -33,4 +33,4 @@ def str2bool(v):
     if v.lower() in ('yes', 'true', 't', 'y', '1'):
         return True
     elif v.lower() in ('no', 'false', 'f', 'n', '0'):
-        return False
\ No newline at end of file
+        return False
