From 2b4eecf696ee34b6521a0634c5f1d6ae90667fc8 Mon Sep 17 00:00:00 2001 From: Andrew Hsieh Date: Sun, 12 Jul 2020 21:48:19 +0800 Subject: [PATCH] fix base_pytorch_model coding style --- .../submarine/ml/pytorch/model/base_pytorch_model.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/base_pytorch_model.py b/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/base_pytorch_model.py index a9cfe0050c..f54c30b6dd 100644 --- a/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/base_pytorch_model.py +++ b/submarine-sdk/pysubmarine/submarine/ml/pytorch/model/base_pytorch_model.py @@ -68,8 +68,8 @@ def __del__(self): distributed.destroy_process_group() def train(self, train_loader): - self.model.train() - with torch.enable_grad(): + self.model.train() + with torch.enable_grad(): for _, batch in enumerate(train_loader): feature_idx, feature_value, label = batch output = self.model(feature_idx, feature_value).squeeze() @@ -85,7 +85,7 @@ def evaluate(self): valid_loader = get_from_registry(self.input_type, input_fn_registry)( filepath=self.params['input']['valid_data'], **self.params['training'])() - self.model.eval() + self.model.eval() with torch.no_grad(): for _, batch in enumerate(valid_loader): feature_idx, feature_value, label = batch @@ -104,7 +104,7 @@ def predict(self): test_loader = get_from_registry(self.input_type, input_fn_registry)( filepath=self.params['input']['test_data'], **self.params['training'])() - self.model.eval() + self.model.eval() with torch.no_grad(): for _, batch in enumerate(test_loader): feature_idx, feature_value, _ = batch @@ -143,8 +143,8 @@ def save_checkpoint(self): 'optimizer': self.optimizer.state_dict() }, buffer) write_file(buffer, - uri=os.path.join( - self.params['output']['save_model_dir'], 'ckpt.pkl')) + uri=os.path.join(self.params['output']['save_model_dir'], + 'ckpt.pkl')) def model_fn(self, params): seed = params["training"]["seed"]