
Commit

make code more clean
hli2020 committed Jun 18, 2019
1 parent 5048686 commit a3e79f8
Showing 3 changed files with 7 additions and 58 deletions.
README.md (2 changes: 1 addition & 1 deletion)
@@ -63,7 +63,7 @@ file for details.

The current version contains some
legacy variable names in early trial experiments;
- we would fix them later.
+ we would remove them later and make the repo cleaner.

### Citation
Please cite in the following manner if you find it useful in your research:
core/model.py (25 changes: 3 additions & 22 deletions)
@@ -271,9 +271,6 @@ def __init__(self, opts):
self._make_layer(Bottleneck, 128, 4, stride=1),
self._make_layer(Bottleneck, 64, 3, stride=1)
)
- # ot_out = self.critic_que(repnet_out)
- # self.ot_pool_size = ot_out.size(2)

_embedding = repnet_out

if self.baseline_manner == 'sample_wise_similar':
@@ -299,7 +296,7 @@ def __init__(self, opts):
self.reshaper = self._make_layer(Bottleneck, out_size, 4, stride=1)
_out_downsample = self.reshaper(_embedding)

- # DEDUCTOR AND PROJECTOR
+ # CONCENTRATOR AND PROJECTOR
if self.dnet:
if self.mp_mean:
self.inplanes = _embedding.size(1)
@@ -332,6 +329,7 @@ def __init__(self, opts):
else:
self.projection = self._make_layer(Bottleneck, out_size, 4, stride=1)

+ # deprecated; kept for legacy
if self.use_discri_loss:
# 40 x 19 x 19 = 14440
input_c = _out_downsample.size(1)*_out_downsample.size(2)*_out_downsample.size(2)
@@ -477,7 +475,7 @@ def _make_layer(self, block, planes, blocks, stride=1):

return nn.Sequential(*layers)

- # decprecated in CTM
+ # decprecated in CTM; kept here for ablation study
def forward(self, support_x, support_y, query_x, query_y,
train=True, n_way=-1, curr_shot=-1):

@@ -553,23 +551,6 @@ def forward(self, support_x, support_y, query_x, query_y,
label = torch.eq(support_y_expand, query_y_expand).float()
loss[i] = F.mse_loss(score[i], label)

- # support_y_neat = support_ys[i][:, ::curr_shot] # b, n_way
- # target = torch.stack([
- #     torch.nonzero(torch.eq(support_y_neat[b], query_ys[i][b, j]))
- #     for b, query in enumerate(query_ys[i]) for j, _, in enumerate(query)
- # ])
- # target = target.view(-1, 1) # shape: N
- # one_hot_labels = \
- #     torch.zeros(target.size(0), self.opts.fsl.n_way[0]).to(self.opts.ctrl.device).scatter_(
- #         1, target, 1)
- # if not self.opts.model.sum_supp_sample:
- #     one_hot_labels = one_hot_labels.unsqueeze(2).expand(-1, -1, 5).contiguous().view(25, -1)
- # loss[i] = F.mse_loss(score[i], one_hot_labels.unsqueeze(0))

- # loss[i] = torch.pow(label - score[i], 2).sum() / batch_sz # in this case loss decrease (~100)
- # TODO (high): if changed to label.numel(), loss won't decrease (~.16)
- # loss = torch.pow(label - score, 2).sum() / label.numel()

loss = (loss.sum() / len(support_xfs)).unsqueeze(0)

return loss.unsqueeze(0) # output size: 1 x 1 (or the number of losses)
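For context on the loss visible in the last hunk: the legacy forward regresses predicted relation scores against a binary match matrix built from the support and query labels. Below is a minimal, self-contained sketch of that pattern; the function name, tensor shapes, and the broadcasting step are illustrative assumptions, and only the torch.eq / F.mse_loss pairing comes from the diff.

import torch
import torch.nn.functional as F

def relation_mse_loss(score, support_y, query_y):
    # score: (n_query, n_support) relation scores in [0, 1] (assumed shape)
    # support_y: (n_support,) class labels; query_y: (n_query,) class labels
    # Broadcast labels so every query is compared against every support sample.
    support_y_expand = support_y.unsqueeze(0).expand(query_y.size(0), -1)
    query_y_expand = query_y.unsqueeze(1).expand(-1, support_y.size(0))
    # Binary target: 1 where query and support share a class, else 0.
    label = torch.eq(support_y_expand, query_y_expand).float()
    return F.mse_loss(score, label)

# Toy usage: 5-way support set with one sample per class, 10 query samples.
support_y = torch.arange(5)
query_y = torch.randint(0, 5, (10,))
score = torch.rand(10, 5)
print(relation_mse_loss(score, support_y, query_y))

In the repo this loss is computed per episode and the per-episode losses are summed and normalized, as in the loss = (loss.sum() / len(support_xfs)) line shown above.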
main.py (38 changes: 3 additions & 35 deletions)
@@ -1,5 +1,4 @@
import time
import torch
import argparse
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR, ExponentialLR, StepLR
@@ -101,14 +100,12 @@ def main():
show_str = '[TRAIN FROM SCRATCH] LOG' if not opts.io.resume else '[RESUME] LOG'
opts.logger('{}\n'.format(show_str))

- # old_lr = optimizer.param_groups[0]['initial_lr']
total_ep = opts.train.nep
if opts.ctrl.start_epoch > 0 or opts.ctrl.start_iter > 0:
assert opts.io.resume
RESUME = True
else:
RESUME = False
- VERY_FIRST_TIME = True

for epoch in range(opts.ctrl.start_epoch, total_ep):

@@ -148,19 +145,7 @@ def main():
break

support_x, support_y, query_x, query_y = process_input(batch, opts, mode='train')

- # shape: gpu_num x loss_num
- if opts.fsl.ctm:
-     # New pipeline
-     loss, disc_weights = net.forward_CTM(support_x, support_y, query_x, query_y, True)
- else:
-     if opts.model.structure == 'original':
-         support_x, support_y, query_x, query_y = \
-             support_x.squeeze(0), support_y.squeeze(0), query_x.squeeze(0), query_y.squeeze(0)
-         loss = net(support_x, support_y, query_x, query_y)
-     else:
-         loss = net(support_x, support_y, query_x, query_y,
-                    n_way=opts.fsl.n_way[which_ind], curr_shot=curr_shot)
+ loss, _ = net.forward_CTM(support_x, support_y, query_x, query_y, True)
loss = loss.mean(0)
vis_loss = loss.data.cpu().numpy()

@@ -177,34 +162,17 @@ def main():
if opts.train.clip_grad:
# doesn't affect that much
torch.nn.utils.clip_grad_norm_(net.parameters(), 0.5)
- # grad =
optimizer.step()

iter_time = (time.time() - step_t)
left_time = compute_left_time(iter_time, epoch, total_ep, step, total_iter)
- info = {
-     'curr_ep': epoch,
-     'curr_iter': step,
-     'total_ep': total_ep,
-     'total_iter': total_iter,
-     'loss': vis_loss,
-     'left_time': left_time,
-     'lr': new_lr,
-     'iter_time': iter_time
- }

- # SHOW TRAIN LOSS
- if step % opts.io.iter_vis_loss == 0 or step == total_iter - 1 or VERY_FIRST_TIME:
-     VERY_FIRST_TIME = False
+ # loss
+ if step % opts.io.iter_vis_loss == 0 or step == total_iter - 1:
opts.logger(opts.io.loss_vis_str.format(epoch, total_ep, step, total_iter, total_loss.item()))
# time
if step % 1000*opts.io.iter_vis_loss == 0 or step == total_iter - 1:
opts.logger(opts.io.time_vis_str.format(left_time[0], left_time[1], left_time[2]))
- # # visdom
- # if opts.misc.vis.use and opts.misc.vis.method == 'visdom':
- #     # tb.add_scalar('loss', loss.item())
- #     vis.plot_loss(**info)
- #     vis.show_dynamic_info(**info)

# VALIDATION and SAVE BEST MODEL
if epoch > opts.test.do_after_ep and \
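Taken together, the main.py changes collapse the training step to a single forward_CTM call. A rough sketch of the resulting step follows, assuming the support/query tensors come from process_input(batch, opts, mode='train') as in the diff; the zero_grad/backward calls and the way total_loss is formed are assumptions added for completeness, since they fall outside the shown hunks.

import torch

def train_step(net, optimizer, opts, support_x, support_y, query_x, query_y,
               epoch, total_ep, step, total_iter):
    # Single pipeline: the old opts.fsl.ctm / opts.model.structure branching is gone.
    loss, _ = net.forward_CTM(support_x, support_y, query_x, query_y, True)
    loss = loss.mean(0)        # average over the gpu_num dimension, as in the diff
    total_loss = loss.sum()    # assumption: combine the remaining loss terms

    optimizer.zero_grad()      # assumption: standard PyTorch update, not shown in the hunks
    total_loss.backward()
    if opts.train.clip_grad:
        torch.nn.utils.clip_grad_norm_(net.parameters(), 0.5)
    optimizer.step()

    # Logging cadence kept from the diff: every iter_vis_loss steps and on the last step.
    if step % opts.io.iter_vis_loss == 0 or step == total_iter - 1:
        opts.logger(opts.io.loss_vis_str.format(epoch, total_ep, step, total_iter, total_loss.item()))
    return total_loss.item()

The function name and argument list are hypothetical packaging of the loop body; in the repository this logic lives inline inside the epoch/step loops of main().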
