Skip to content

Commit

Permalink
renamed delta_mult as beautify option; updated code to always lock
Browse files Browse the repository at this point in the history
  • Loading branch information
aharley committed Sep 25, 2023
1 parent 3be8252 commit 9f901be
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions nets/pips2.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ def __init__(self, stride=8):
self.delta_block = DeltaBlock(hidden_dim=self.hidden_dim, corr_levels=self.corr_levels, corr_radius=self.corr_radius)
self.norm = nn.GroupNorm(1, self.latent_dim)

def forward(self, trajs_e0, rgbs, iters=3, trajs_g=None, vis_g=None, valids=None, sw=None, feat_init=None, is_train=False, delta_mult=0.5):
def forward(self, trajs_e0, rgbs, iters=3, trajs_g=None, vis_g=None, valids=None, sw=None, feat_init=None, is_train=False, beautify=False):
total_loss = torch.tensor(0.0).cuda()

B,S,N,D = trajs_e0.shape
Expand Down Expand Up @@ -463,8 +463,8 @@ def forward(self, trajs_e0, rgbs, iters=3, trajs_g=None, vis_g=None, valids=None

coords_bak = coords.clone()

if not is_train:
coords[:,0] = coords_bak[:,0] # lock coord0 for target
# if not is_train:
coords[:,0] = coords_bak[:,0] # lock coord0 for target

coord_predictions1 = [] # for loss
coord_predictions2 = [] # for vis
Expand Down Expand Up @@ -510,18 +510,18 @@ def forward(self, trajs_e0, rgbs, iters=3, trajs_g=None, vis_g=None, valids=None

delta_coords_ = self.delta_block(fcorrs_, flows_) # B*N,S,2

if not is_train and itr >= iters*3/4:
# this beautifies the results a bit, but does not really help perf
if beautify:
# this smooths the results a bit, but does not really help perf
delta_coords_ = delta_coords_ * delta_mult

coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0,2,1,3)

if not is_train:
coords[:,0] = coords_bak[:,0] # lock coord0 for target

coord_predictions1.append(coords * self.stride)
coord_predictions2.append(coords * self.stride)

coords[:,0] = coords_bak[:,0] # lock coord0 for target


# pause at the end, to make the summs more interpretable
coord_predictions2.append(coords * self.stride)
coord_predictions2.append(coords * self.stride)
Expand All @@ -531,5 +531,6 @@ def forward(self, trajs_e0, rgbs, iters=3, trajs_g=None, vis_g=None, valids=None
else:
loss = None

coord_predictions1.append(coords * self.stride)
feats = (feats1, feats2, feats4)
return coord_predictions1, coord_predictions2, feats, loss

0 comments on commit 9f901be

Please sign in to comment.