This repository has been archived by the owner on Feb 25, 2022. It is now read-only.

Fixing create_tfrecords.py when using custom tokenizer #144

Open · wants to merge 1 commit into master
15 changes: 12 additions & 3 deletions data/create_tfrecords.py
@@ -5,7 +5,7 @@
 import ftfy
 import tensorflow as tf
 from lm_dataformat import Reader
-from tokenizers import Tokenizer
+from tokenizers import Tokenizer, Encoding
 from transformers import GPT2TokenizerFast
 from tqdm import tqdm
 import logging
@@ -60,7 +60,10 @@ def get_tokenizer(args):
     if args.encoder_path is None:
         return GPT2TokenizerFast.from_pretrained('gpt2')
     else:
-        return Tokenizer.from_file(args.encoder_path)
+        try:
+            return Tokenizer.from_file(args.encoder_path)
+        except Exception as e:
+            raise ValueError("Unable to open tokenizer file from encoder_path argument")

 def split_list(l, n):
     # splits list/string into n size chunks
@@ -73,7 +76,10 @@ def archive_to_tokens(f, encoder, args):
     for doc in reader.stream_data(threaded=False):
         if args.ftfy: # fix text with ftfy if specified
            doc = ftfy.fix_text(doc, normalization='NFKC')
-        doc = encoder.encode(doc) + args.separator # read document from lmd and append separator token
+        ids = encoder.encode(doc)
+        if type(ids) is Encoding:
+            ids = ids.ids
+        doc = ids + args.separator # read document from lmd and append separator token
         yield split_list(doc, args.chunk_size) # split into n_ctx + 1 size chunks

 def write_files(files, files_per, output_dir, out_name, start_no, write_remainder=False, process_no=None):
@@ -82,6 +88,9 @@ def write_files(files, files_per, output_dir, out_name, start_no, write_remainder=False, process_no=None):
         return
     chunks = split_list(files, files_per)

+    if len(chunks) == 0:
+        raise ValueError("No chunks found; the chunk_size argument may be too high or the dataset too small")
+
     if len(chunks[-1]) != files_per and not write_remainder: # pop the last file if it's length != files per
         remainder = chunks.pop(-1)
     else:
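
For context on the `Encoding` handling in `archive_to_tokens`: `GPT2TokenizerFast.encode` returns a plain list of token ids, while `tokenizers.Tokenizer.encode` returns an `Encoding` object whose ids live in `.ids`, so the old `encoder.encode(doc) + args.separator` concatenation broke for custom tokenizers. A minimal sketch of the difference follows; it uses `Tokenizer.from_pretrained("gpt2")` purely for illustration, whereas the script loads the custom tokenizer with `Tokenizer.from_file(args.encoder_path)`:

```python
from tokenizers import Tokenizer
from transformers import GPT2TokenizerFast

text = "Hello world"

hf_tok = GPT2TokenizerFast.from_pretrained("gpt2")
ids = hf_tok.encode(text)
print(type(ids))   # <class 'list'> of ints -- "+ args.separator" already works here

custom_tok = Tokenizer.from_pretrained("gpt2")  # stand-in for Tokenizer.from_file(args.encoder_path)
enc = custom_tok.encode(text)
print(type(enc))   # <class 'tokenizers.Encoding'> -- cannot be concatenated with a list
print(enc.ids)     # the plain id list the patch extracts via ids.ids
```

After the patch, both code paths end up with a list of ids, so appending the separator works regardless of which tokenizer was loaded.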
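The try/except in `get_tokenizer` assumes `encoder_path` points at a tokenizer JSON file produced by the `tokenizers` library. A rough sketch of how such a file could be produced is below; the corpus path, vocabulary size, and special tokens are placeholders, not values taken from this repository:

```python
from tokenizers import Tokenizer, models, pre_tokenizers, trainers

# Train a small byte-level BPE tokenizer and save it as a single JSON file.
tokenizer = Tokenizer(models.BPE())
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel()

trainer = trainers.BpeTrainer(
    vocab_size=32000,                  # placeholder size
    special_tokens=["<|endoftext|>"],  # placeholder end-of-document token
)
tokenizer.train(files=["my_corpus.txt"], trainer=trainer)  # placeholder corpus file

tokenizer.save("my_tokenizer.json")  # pass this path via the encoder_path argument
```

`Tokenizer.from_file("my_tokenizer.json")` can then load the saved file; anything that fails to load now surfaces as the clearer `ValueError` instead of a raw parsing exception.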
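The new length check in `write_files` turns a bare `IndexError` on `chunks[-1]` into a readable error when `split_list` returns nothing. A sketch of that failure mode, assuming `split_list` is the usual slicing helper its comment describes (the implementation below is an illustration, not copied from the repository):

```python
def split_list(l, n):
    # splits list/string into n size chunks (assumed implementation, per the comment in the diff)
    return [l[i:i + n] for i in range(0, len(l), n)]

print(split_list([1, 2, 3, 4, 5], 2))  # [[1, 2], [3, 4], [5]]
print(split_list([], 2))               # [] -> chunks[-1] would raise IndexError without the new check
```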