diff --git a/src/training/language_specific.py b/src/training/language_specific.py index 1fe98aa34a..b6ee3ac211 100644 --- a/src/training/language_specific.py +++ b/src/training/language_specific.py @@ -916,6 +916,9 @@ def set_lang_specific_parameters(ctx, lang): TEXT2IMAGE_EXTRA_ARGS = [] EXPOSURES = [] + GENERATE_WORD_BIGRAMS = None + WORD_DAWG_SIZE = None + # Latin languages. if lang == "enm": TEXT2IMAGE_EXTRA_ARGS += ["--ligatures"] # Add ligatures when supported @@ -1364,18 +1367,40 @@ def set_lang_specific_parameters(ctx, lang): LANG_IS_RTL = False NORM_MODE = 1 - for var in [v for v in locals()]: - if var.isupper(): - value = locals()[var] - lowervar = var.lower() - if hasattr(ctx, lowervar) and getattr(ctx, lowervar) != value: - log.debug(f"{lowervar} = {value} (was {getattr(ctx, lowervar)})") - setattr(ctx, lowervar, value) - elif hasattr(ctx, lowervar): - log.debug(f"{lowervar} = {value} (set on cmdline)") + vars_to_transfer = { + 'ambigs_filter_denominator': AMBIGS_FILTER_DENOMINATOR, + 'bigram_dawg_factor': BIGRAM_DAWG_FACTOR, + 'exposures': EXPOSURES, + 'filter_arguments': FILTER_ARGUMENTS, + 'fonts': FONTS, + 'fragments_disabled': FRAGMENTS_DISABLED, + 'generate_word_bigrams': GENERATE_WORD_BIGRAMS, + 'lang_is_rtl': LANG_IS_RTL, + 'leading': LEADING, + 'mean_count': MEAN_COUNT, + 'mix_lang': MIX_LANG, + 'norm_mode': NORM_MODE, + 'number_dawg_factor': NUMBER_DAWG_FACTOR, + 'punc_dawg_factor': PUNC_DAWG_FACTOR, + 'run_shape_clustering': RUN_SHAPE_CLUSTERING, + 'text2image_extra_args': TEXT2IMAGE_EXTRA_ARGS, + 'text_corpus': TEXT_CORPUS, + 'training_data_arguments': TRAINING_DATA_ARGUMENTS, + 'word_dawg_factor': WORD_DAWG_FACTOR, + 'word_dawg_size': WORD_DAWG_SIZE, + 'wordlist2dawg_arguments': WORDLIST2DAWG_ARGUMENTS, + } + + for attr, value in vars_to_transfer.items(): + if hasattr(ctx, attr): + if getattr(ctx, attr) != value: + log.debug(f"{attr} = {value} (was {getattr(ctx, attr)})") + setattr(ctx, attr, value) else: - log.debug(f"{lowervar} = {value}") - setattr(ctx, lowervar, value) + log.debug(f"{attr} = {value} (set on cmdline)") + else: + log.debug(f"{attr} = {value}") + setattr(ctx, attr, value) return ctx diff --git a/src/training/tesstrain.py b/src/training/tesstrain.py index 6a0e929067..451ed792e0 100644 --- a/src/training/tesstrain.py +++ b/src/training/tesstrain.py @@ -14,7 +14,7 @@ # Tesseract. For a detailed description of the phases, see # https://github.com/tesseract-ocr/tesseract/wiki/TrainingTesseract # -import sys, os, subprocess, logging +import sys, os, logging sys.path.insert(0, os.path.dirname(__file__)) @@ -32,7 +32,7 @@ log = logging.getLogger() -def setup_logging(logfile): +def setup_logging_console(): log.setLevel(logging.DEBUG) console = logging.StreamHandler() console.setLevel(logging.INFO) @@ -42,6 +42,8 @@ def setup_logging(logfile): console.setFormatter(console_formatter) log.addHandler(console) + +def setup_logging_logfile(logfile): logfile = logging.FileHandler(logfile) logfile.setLevel(logging.DEBUG) logfile_formatter = logging.Formatter( @@ -52,8 +54,9 @@ def setup_logging(logfile): def main(): + setup_logging_console() ctx = parse_flags() - setup_logging(ctx.log_file) + setup_logging_logfile(ctx.log_file) if not ctx.linedata: log.error("--linedata_only is required since only LSTM is supported") sys.exit(1) diff --git a/src/training/tesstrain_utils.py b/src/training/tesstrain_utils.py index 8c006e6837..3b9f1a7316 100644 --- a/src/training/tesstrain_utils.py +++ b/src/training/tesstrain_utils.py @@ -49,13 +49,10 @@ def __init__(self): self.max_pages = 0 self.save_box_tiff = False - self.output_dir = "/tmp/tesstrain/tessdata" self.overwrite = False self.linedata = False self.run_shape_clustering = False self.extract_font_properties = True - self._workspace_dir = TemporaryDirectory(prefix="tesstrain") - self.workspace_dir = self._workspace_dir.name def err_exit(msg): @@ -88,8 +85,8 @@ def run_command(cmd, *args, env=None): else: try: proclog.error(proc.stdout.decode("utf-8", errors="replace")) - except Exception: - pass + except Exception as e: + proclog.error(e) err_exit(f"Program {cmd} failed with return code {proc.returncode}. Abort.") @@ -101,10 +98,10 @@ def check_file_readable(*filenames): filenames = [filenames] for filename in filenames: try: - with Path(filename).open() as f: + with Path(filename).open(): pass except FileNotFoundError: - err_exit(f"Expected file {filename} does not exist") + err_exit(f"Required/expected file '{filename}' does not exist") except PermissionError: err_exit(f"{filename} is not readable") except IOError as e: @@ -191,7 +188,6 @@ def check_file_readable(*filenames): nargs="+", help="A list of exposure levels to use (e.g. -1,0,1).", ) -parser.add_argument("--workspace_dir") # Does simple command-line parsing and initialization. @@ -200,7 +196,6 @@ def parse_flags(argv=None): log.debug(ctx) parser.parse_args(args=argv, namespace=ctx) log.debug(ctx) - log.info("Parsing") if not ctx.lang_code: err_exit("Need to specify a language --lang") @@ -215,12 +210,15 @@ def parse_flags(argv=None): ) else: ctx.tessdata_dir = tessdata_prefix + if not ctx.output_dir: + ctx.output_dir = mkdtemp(prefix=f"trained-{ctx.lang_code}-{ctx.timestamp}") + log.info(f"Output directory set to: {ctx.output_dir}") # Location where intermediate files will be created. ctx.training_dir = mkdtemp(prefix=f"{ctx.lang_code}-{ctx.timestamp}") # Location of log file for the whole run. ctx.log_file = Path(ctx.training_dir) / "tesstrain.log" - log.info(f"Log file {ctx.log_file}") + log.info(f"Log file location: {ctx.log_file}") def show_tmpdir_location(training_dir): # On successful exit we will delete this first; on failure we want to let the user @@ -356,7 +354,7 @@ def phase_I_generate_image(ctx, par_factor): # for tesseract to recognize during training. Take only the ngrams whose # combined weight accounts for 95% of all the bigrams in the language. lines = Path(ctx.bigram_freqs_file).read_text(encoding="utf-8").split("\n") - records = (line.split(" ") for line in splittable_lines) + records = (line.split(" ") for line in lines) p = 0.99 ngram_frac = p * sum(int(rec[1]) for rec in records if len(rec) >= 2)