Skip to content

Commit

Permalink
code clean
Browse files Browse the repository at this point in the history
  • Loading branch information
hytseng0509 committed Jul 25, 2018
1 parent 938095c commit 148ced9
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 17 deletions.
30 changes: 16 additions & 14 deletions src/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,26 +385,28 @@ def _compute_kl(self, mu):
encoding_loss = torch.mean(mu_2)
return encoding_loss

def resume(self, model_dir):
def resume(self, model_dir, train=True):
checkpoint = torch.load(model_dir)
# weight
self.disA.load_state_dict(checkpoint['disA'])
self.disA2.load_state_dict(checkpoint['disA2'])
self.disB.load_state_dict(checkpoint['disB'])
self.disB2.load_state_dict(checkpoint['disB2'])
self.disContent.load_state_dict(checkpoint['disContent'])
if train:
self.disA.load_state_dict(checkpoint['disA'])
self.disA2.load_state_dict(checkpoint['disA2'])
self.disB.load_state_dict(checkpoint['disB'])
self.disB2.load_state_dict(checkpoint['disB2'])
self.disContent.load_state_dict(checkpoint['disContent'])
self.enc_c.load_state_dict(checkpoint['enc_c'])
self.enc_a.load_state_dict(checkpoint['enc_a'])
self.gen.load_state_dict(checkpoint['gen'])
# optimizer
self.disA_opt.load_state_dict(checkpoint['disA_opt'])
self.disA2_opt.load_state_dict(checkpoint['disA2_opt'])
self.disB_opt.load_state_dict(checkpoint['disB_opt'])
self.disB2_opt.load_state_dict(checkpoint['disB2_opt'])
self.disContent_opt.load_state_dict(checkpoint['disContent_opt'])
self.enc_c_opt.load_state_dict(checkpoint['enc_c_opt'])
self.enc_a_opt.load_state_dict(checkpoint['enc_a_opt'])
self.gen_opt.load_state_dict(checkpoint['gen_opt'])
if train:
self.disA_opt.load_state_dict(checkpoint['disA_opt'])
self.disA2_opt.load_state_dict(checkpoint['disA2_opt'])
self.disB_opt.load_state_dict(checkpoint['disB_opt'])
self.disB2_opt.load_state_dict(checkpoint['disB2_opt'])
self.disContent_opt.load_state_dict(checkpoint['disContent_opt'])
self.enc_c_opt.load_state_dict(checkpoint['enc_c_opt'])
self.enc_a_opt.load_state_dict(checkpoint['enc_a_opt'])
self.gen_opt.load_state_dict(checkpoint['gen_opt'])
return checkpoint['ep'], checkpoint['total_it']

def save(self, filename, ep, total_it):
Expand Down
6 changes: 3 additions & 3 deletions src/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def main():
print('\n--- load model ---')
model = DRIT(opts)
model.setgpu(opts.gpu)
model.resume(opts.resume)
model.resume(opts.resume, train=False)
model.eval()

# directory
Expand All @@ -46,9 +46,9 @@ def main():
img2 = img2.cuda()
with torch.no_grad():
if opts.a2b:
img = model.test_forward(img1, img2, opts.random_z, a2b=True, idx=idx2)
img = model.test_forward(img1, img2, opts.random_z, a2b=True)
else:
img = model.test_forward(img2, img1, opts.random_z, a2b=False, idx=idx2)
img = model.test_forward(img2, img1, opts.random_z, a2b=False)
imgs.append(img)
names.append('output_{}'.format(idx2))
save_imgs(imgs, names, os.path.join(result_dir, '{}'.format(idx1)))
Expand Down

0 comments on commit 148ced9

Please sign in to comment.