diff --git a/rl4lms/envs/text_generation/logging_utils.py b/rl4lms/envs/text_generation/logging_utils.py
index 0a893237..3711d629 100644
--- a/rl4lms/envs/text_generation/logging_utils.py
+++ b/rl4lms/envs/text_generation/logging_utils.py
@@ -12,6 +12,9 @@
 from rich.logging import RichHandler
 
 
+LOGGER = logging.getLogger(__name__)
+
+
 class Tracker:
     def __init__(self,
                  base_path_to_store_results: str,
@@ -76,9 +79,9 @@ def log_predictions(self, epoch: int,
         # randomly display few predictions for logging
         predictions_ = copy.deepcopy(predictions)
         random.shuffle(predictions_)
-        logging.info(f"Split {split_name} predictions")
+        LOGGER.info(f"Split {split_name} predictions")
         for pred in predictions_[:10]:
-            logging.info(pred)
+            LOGGER.info(pred)
 
         # for wandb logging, we create a table consisting of predictions
         # we can create one table per split per epoch
@@ -119,10 +122,10 @@ def log_metrics(self, epoch: int,
             wandb.log(metric_dict_)
 
         # logger
-        logging.info(f"{split_name} metrics: {metrics_dict_}")
+        LOGGER.info(f"{split_name} metrics: {metrics_dict_}")
 
     def log_rollout_infos(self, rollout_info: Dict[str, float]):
-        logging.info(f"Rollout Info: {rollout_info}")
+        LOGGER.info(f"Rollout Info: {rollout_info}")
         rollout_info_file = os.path.join(
             self._run_path, "rollout_info.jsonl")
         with jsonlines.open(rollout_info_file, mode="a") as writer:
@@ -133,7 +136,7 @@ def log_rollout_infos(self, rollout_info: Dict[str, float]):
             wandb.log(rollout_info)
 
     def log_training_infos(self, training_info: Dict[str, float]):
-        logging.info(f"Training Info: {training_info}")
+        LOGGER.info(f"Training Info: {training_info}")
         training_info_file = os.path.join(
             self._run_path, "training_info.jsonl")
         with jsonlines.open(training_info_file, mode="a") as writer:
@@ -156,7 +159,7 @@ def checkpoint_base_path(self):
         return os.path.join(self._run_path, "checkpoints")
 
     def log_info(self, msg: str):
-        logging.info(msg)
+        LOGGER.info(msg)
 
 
 if __name__ == "__main__":
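
For context on the pattern this diff adopts, below is a minimal, self-contained sketch (not part of the change itself) of the module-level logger it introduces: records carry the module's name and follow whatever handler/level configuration the application installs, instead of being routed unconditionally through the root logger via logging.info. The standalone setup under __main__ is hypothetical and only mirrors the RichHandler import already present in logging_utils.py; the demo message is purely illustrative.

import logging
from rich.logging import RichHandler

# Module-level logger, as introduced by the diff above.
LOGGER = logging.getLogger(__name__)

if __name__ == "__main__":
    # Hypothetical standalone configuration (not taken from the diff):
    # route records through Rich's logging handler.
    logging.basicConfig(level=logging.INFO,
                        handlers=[RichHandler()],
                        format="%(message)s")
    LOGGER.info("message emitted via the module-level LOGGER")  # illustrative only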