From 13f64106dd67ce682ba51654ae3255f8c6f34779 Mon Sep 17 00:00:00 2001
From: galloj
Date: Wed, 17 Mar 2021 20:36:22 +0100
Subject: [PATCH] Fixing create_tfrecords.py when using a custom tokenizer

---
 data/create_tfrecords.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/data/create_tfrecords.py b/data/create_tfrecords.py
index ea3be2c0..7d404592 100644
--- a/data/create_tfrecords.py
+++ b/data/create_tfrecords.py
@@ -5,7 +5,7 @@
 import ftfy
 import tensorflow as tf
 from lm_dataformat import Reader
-from tokenizers import Tokenizer
+from tokenizers import Tokenizer, Encoding
 from transformers import GPT2TokenizerFast
 from tqdm import tqdm
 import logging
@@ -60,7 +60,10 @@ def get_tokenizer(args):
     if args.encoder_path is None:
         return GPT2TokenizerFast.from_pretrained('gpt2')
     else:
-        return Tokenizer.from_file(args.encoder_path)
+        try:
+            return Tokenizer.from_file(args.encoder_path)
+        except Exception as e:
+            raise ValueError("Unable to load the tokenizer file given by the encoder_path argument") from e
 
 def split_list(l, n):
     # splits list/string into n size chunks
@@ -73,7 +76,10 @@ def archive_to_tokens(f, encoder, args):
     for doc in reader.stream_data(threaded=False):
         if args.ftfy: # fix text with ftfy if specified
             doc = ftfy.fix_text(doc, normalization='NFKC')
-        doc = encoder.encode(doc) + args.separator # read document from lmd and append separator token
+        ids = encoder.encode(doc)
+        if isinstance(ids, Encoding): # a Tokenizer loaded from file returns an Encoding rather than a list of ids
+            ids = ids.ids
+        doc = ids + args.separator # read document from lmd and append separator token
         yield split_list(doc, args.chunk_size) # split into n_ctx + 1 size chunks
 
 def write_files(files, files_per, output_dir, out_name, start_no, write_remainder=False, process_no=None):
@@ -82,6 +88,9 @@
         return
     chunks = split_list(files, files_per)
 
+    if len(chunks) == 0:
+        raise ValueError("No chunks found: the chunk_size argument may be too high, or the dataset may be too small")
+
     if len(chunks[-1]) != files_per and not write_remainder: # pop the last file if it's length != files per
         remainder = chunks.pop(-1)
     else:
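
Note (not part of the patch): the Encoding check exists because the two tokenizer classes returned by get_tokenizer() have different encode() return types. Below is a minimal sketch of that difference; it assumes the tokenizers and transformers packages are installed, and "my_tokenizer.json" is a hypothetical stand-in for the file passed via the encoder_path argument.

    from tokenizers import Tokenizer, Encoding
    from transformers import GPT2TokenizerFast

    # Default tokenizer: encode() returns a plain list of token ids,
    # so appending the separator works directly.
    gpt2 = GPT2TokenizerFast.from_pretrained("gpt2")
    ids = gpt2.encode("hello world")                    # -> list of ints

    # Custom tokenizer loaded from a file: encode() returns an Encoding
    # object, so the id list has to be pulled out of .ids first.
    custom = Tokenizer.from_file("my_tokenizer.json")   # hypothetical path
    enc = custom.encode("hello world")
    assert isinstance(enc, Encoding)
    ids = enc.ids                                       # -> list of ints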