Skip to content
This repository has been archived by the owner on Jul 10, 2024. It is now read-only.

Commit

Permalink
fix base_pytorch_model coding style
Browse files Browse the repository at this point in the history
  • Loading branch information
ifndef012 committed Jul 18, 2020
1 parent 573a4e8 commit 2b4eecf
Showing 1 changed file with 6 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ def __del__(self):
distributed.destroy_process_group()

def train(self, train_loader):
self.model.train()
with torch.enable_grad():
self.model.train()
with torch.enable_grad():
for _, batch in enumerate(train_loader):
feature_idx, feature_value, label = batch
output = self.model(feature_idx, feature_value).squeeze()
Expand All @@ -85,7 +85,7 @@ def evaluate(self):
valid_loader = get_from_registry(self.input_type, input_fn_registry)(
filepath=self.params['input']['valid_data'],
**self.params['training'])()
self.model.eval()
self.model.eval()
with torch.no_grad():
for _, batch in enumerate(valid_loader):
feature_idx, feature_value, label = batch
Expand All @@ -104,7 +104,7 @@ def predict(self):
test_loader = get_from_registry(self.input_type, input_fn_registry)(
filepath=self.params['input']['test_data'],
**self.params['training'])()
self.model.eval()
self.model.eval()
with torch.no_grad():
for _, batch in enumerate(test_loader):
feature_idx, feature_value, _ = batch
Expand Down Expand Up @@ -143,8 +143,8 @@ def save_checkpoint(self):
'optimizer': self.optimizer.state_dict()
}, buffer)
write_file(buffer,
uri=os.path.join(
self.params['output']['save_model_dir'], 'ckpt.pkl'))
uri=os.path.join(self.params['output']['save_model_dir'],
'ckpt.pkl'))

def model_fn(self, params):
seed = params["training"]["seed"]
Expand Down

0 comments on commit 2b4eecf

Please sign in to comment.