In [3]:
import os
import re
import glob
from tqdm import tqdm
import codecs
from chardet import detect

# get file encoding type
def get_encoding_type(file):
    with open(file, 'rb') as f:
        rawdata = f.read()
    return detect(rawdata)['encoding']


def correctTxtEncoding(filename, encoding_to='UTF-8'):
    from_codec = get_encoding_type(filename)
    temp_filename = filename[:-4]+"temp.txt"
    try:
        with open(filename, 'r', encoding=from_codec) as fr:
            with open(temp_filename, 'w', encoding=encoding_to) as fw:
                for line in fr:
                    fw.write(line[:-1]+'\r\n')
        os.remove(filename) # remove old encoding file
        os.rename(temp_filename, filename) # rename new encoding
    except UnicodeDecodeError:
        print('Decode Error')
    except UnicodeEncodeError:
        print('Encode Error')

def clean_text(string):
    pattern = '(page|PAGE|Page)(\s+\|\s+)([0-9]+)(.*)$'
    output_cleaned = re.sub('\s$', '', string, flags=re.MULTILINE)
    p=re.compile(pattern,re.MULTILINE)
    output_cleaned = p.sub(" ",output_cleaned)
    return output_cleaned

def merge_texts(texts):
    merged_text = ''
    for text in tqdm(texts):
        correctTxtEncoding(text)
        with open(text,'r',encoding="utf8") as f:
            for line in f:
                merged_text += line[:-1]+'\n'
    merged_text_cleaned = clean_text(merged_text)
    return merged_text_cleaned

def get_all_txt_files():
    path = os.path.abspath(os.path.join("../", os.pardir))+'\Data\**\*.txt';
    files = glob.glob(path, recursive=True)
    return files

def get_files_in_data_folder(folder):
    path = os.path.abspath(os.path.join("../", os.pardir))+ '\Data\\'+ folder + '\*.txt';
    files = glob.glob(path, recursive=True)
    return files

def get_files_by_author(author):
    path = os.path.abspath(os.path.join("../", os.pardir))+ '\Data\**\*_' + author + '.txt';
    files = glob.glob(path, recursive=True)
    return files

def get_train_test_validation(txt_file, train=0.70, test=0.20, val=0.10):
    train_doc =[]
    test_doc = []
    val_doc =[]
    with open(txt_file,'r', encoding='UTF-8') as f:
        file_input=f.readlines()

    count = 0
    for cnt, line in enumerate(file_input):
            if cnt <= len(file_input)*train:
                train_doc.append(line)
            elif (cnt > len(file_input)*train and cnt < len(file_input)*(train+test)):
                test_doc.append(line)
            else:
                val_doc.append(line)

    ## Write to file
    f = open(txt_file[:-4]+'_train.txt', "w+", encoding='UTF-8')
    count = 0
    for line in train_doc:
        count=count+1
        f.write(str(line))
        f.write("\n")  
    f.close()
    print("Training lines:\t",count)
    
    ## Write to file
    f = open(txt_file[:-4]+'_test.txt', "w+", encoding='UTF-8')
    count = 0
    for line in train_doc:
        count=count+1
        f.write(str(line))
        f.write("\n")  
    f.close()
    print("Testing lines:\t",count)

    ## Write to file
    f = open(txt_file[:-4]+'_val.txt',"w+", encoding='UTF-8')
    count = 0
    for line in val_doc:
        count=count+1
        f.write(str(line))
        f.write("\n")

    f.close()
    print("Validation lines:\t",count)
    
def save_output_file(path_from_output, file_name, data):
    path = os.path.abspath(os.path.join("../", os.pardir))+ '\Data\Outputs\\'+path_from_output + file_name;
    with open(path, "w", encoding='UTF-8') as file:
        file.write(data)
    print("File saved at:\t", path)
    return path



# all_text_files = get_all_txt_files()
# harry_potter_texts = merge_texts(get_files_in_data_folder("Harry_Potter"))
# print("Cleaned Harry Potter Text:\n\n\n" + harry_potter_texts[:2000]+"....\n\n")
# stephen_king_texts = merge_texts(get_files_by_author("Stephen_King"))
# print("Cleaned Stephen Text:\n\n\n" + stephen_king_texts[:2000]+"....\n\n")
# horror_movie_transcripts = merge_texts(get_files_in_data_folder("Horror_Movie_Transcripts"))
# print("Cleaned Horror Movie Transcripts:\n\n\n" + horror_movie_transcripts[:2000]+"....\n\n")
# public_domain_texts = merge_texts(get_files_in_data_folder("Public_Domain_Horror_Novels"))
# print("Cleaned Public Domain Horror Novels:\n\n\n" + public_domain_texts[:2000]+"....\n\n")

In [4]:
get_train_test_validation(save_output_file('Stephen_King_Playground\\', 
                 'merged_Stephen_King.txt', 
                 merge_texts(get_files_by_author("Stephen_King"))))

100%|████████████████████████████████████████████████████████████████████████████████████| 9/9 [11:19<00:00, 75.50s/it]


File saved at:	 C:\Users\Carson\OneDrive\Desktop\Programming\Projects\Epic\Data\Outputs\Stephen_King_Playground\merged_Stephen_King.txt
Training lines:	 80009
Testing lines:	 80009
Validation lines:	 11429


In [None]:
from transformers import TextDataset,DataCollatorForLanguageModeling

def load_dataset(train_path,test_path,tokenizer):
    train_dataset = TextDataset(
          tokenizer=tokenizer,
          file_path=train_path,
          block_size=128)

    test_dataset = TextDataset(
          tokenizer=tokenizer,
          file_path=test_path,
          block_size=128)

    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm=False,
    )
    return train_dataset,test_dataset,data_collator

train_dataset,test_dataset,data_collator = load_dataset(train_path,test_path,tokenizer)