diff --git a/utils/batch_viewer.py b/utils/batch_viewer.py
index 04f9666c..04885251 100644
--- a/utils/batch_viewer.py
+++ b/utils/batch_viewer.py
@@ -18,6 +18,7 @@ def view_data(
     args,
     neox_args,
     batch_fn: callable = None,
+    save_path: str = None,
 ):
     # fake MPU setup (needed to init dataloader without actual GPUs or parallelism)
     mpu.mock_model_parallel()
@@ -37,12 +38,14 @@ def view_data(
 
         if args.mode == "save":
             # save full batches for each step in the range (WARNING: this may consume lots of storage!)
-            np.save(f"./dump_data/batch{i}_bs{neox_args.train_micro_batch_size_per_gpu}", batch)
+            filename = f"batch{i}_bs{neox_args.train_micro_batch_size_per_gpu}"
+            np.save(os.path.join(save_path, filename), batch)
         elif args.mode == "custom":
             # dump user_defined statistic to a jsonl file (save_fn must return a dict)
             log = batch_fn(batch, i)
-            with open("./dump_data/stats.jsonl", "w+") as f:
+            filename = "stats.jsonl"
+            with open(os.path.join(save_path, filename), "w+") as f:
                 f.write(json.dumps(log) + "\n")
         else:
             raise ValueError(f'mode={mode} not acceptable--please pass either "save" or "custom" !')
@@ -74,6 +77,12 @@ def view_data(
         choices=["save", "custom"],
         help="Choose mode: 'save' to log all batches, and 'custom' to use user-defined statistic"
     )
+    parser.add_argument(
+        "--save_path",
+        type=str,
+        default="./dump_data",
+        help="Directory to save output files to"
+    )
     args = parser.parse_known_args()[0]
 
     # init neox args
@@ -86,10 +95,11 @@ def save_fn(batch: np.array, iteration: int):
         # define your own logic here
         return {"iteration": iteration, "text": None}
 
+    os.makedirs(args.save_path, exist_ok=True)
     view_data(
         args,
         neox_args,
         batch_fn=save_fn,
+        save_path=args.save_path,
     )
 
-