Skip to content

Commit

Permalink
fix resume bug
Browse files Browse the repository at this point in the history
  • Loading branch information
andi611 committed Oct 27, 2020
1 parent 9d83f75 commit d2aafd8
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
7 changes: 4 additions & 3 deletions downstream/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
##########
class Runner():
''' Handler for complete training and evaluation progress of downstream models '''
def __init__(self, args, runner_config, dataloader, upstream, downstream, expdir):
def __init__(self, args, config, dataloader, upstream, downstream, expdir):

self.device = torch.device('cuda') if (args.gpu and torch.cuda.is_available()) else torch.device('cpu')
if torch.cuda.is_available(): print('[Runner] - CUDA is available!')
Expand All @@ -34,7 +34,8 @@ def __init__(self, args, runner_config, dataloader, upstream, downstream, expdir
self.log = SummaryWriter(expdir)

self.args = args
self.config = runner_config
self.all_config = config
self.config = self.all_config['runner']
self.dataloader = dataloader
self.upstream_model = upstream.to(self.device)
self.downstream_model = downstream.to(self.device)
Expand Down Expand Up @@ -78,7 +79,7 @@ def save_model(self, name='states', save_best=None):
'Optimizer': self.optimizer.state_dict(),
'Global_step': self.global_step,
'Settings': {
'Config': self.config,
'Config': self.all_config,
'Paras': self.args,
},
}
Expand Down
2 changes: 1 addition & 1 deletion run_downstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def main():

# train
runner = Runner(args=args,
runner_config=config['runner'],
config=config,
dataloader= {'train':train_loader, 'dev':dev_loader, 'test':test_loader},
upstream=upstream_model,
downstream=downstream_model,
Expand Down

0 comments on commit d2aafd8

Please sign in to comment.