In [1]:
import os
import math
import pickle
from nltk import word_tokenize

In [25]:
class Helper:
    """Load or build a tokenized text/summary dataset together with its vocabulary.

    The data comes from one of two sources:

    * four previously pickled files: texts, summaries, word->index and index->word; or
    * two raw, line-aligned text files (article lines and summary lines).

    Raw files are tokenized, filtered by length and turned into a vocabulary. The
    result can then be pickled to ``save_path``.

    Attributes (as produced when building from the raw files):
        dataset_text: list of token lists, one per kept article line.
        dataset_summary: list of token lists, aligned index-by-index with dataset_text.
        word2idx: dict mapping token -> integer id.
        idx2word: sorted list of tokens; position == id.
    When loaded from pickles, these are whatever objects were pickled.
    """

    def __init__(
        self, *,
        #PKL ARGS
        pkl_text_path="", pkl_summary_path="", pkl_word2idx_path="", pkl_idx2word_path="",
        #BUILD PKL ARGS
        text_path="", summary_path="", max_chr_count=300, dataset_size=-1,
        #SAVE
        save_path="",
        #FLAGS
        build_from_dataset=False
    ):
        """Initialise from pickles when possible, otherwise from the raw dataset.

        Args:
            pkl_text_path, pkl_summary_path, pkl_word2idx_path, pkl_idx2word_path:
                paths to the four pickled artefacts. They are used only when all
                four are non-empty, all four files exist and ``build_from_dataset``
                is False.
            text_path, summary_path: raw, line-aligned article/summary files. They
                are used when the pickles are not.
            max_chr_count: maximum number of *tokens* (despite the name) allowed on
                either side of a pair; longer pairs are dropped.
            dataset_size: number of leading lines to consider; -1 means all lines.
            save_path: directory to pickle freshly built data into ("" = don't save).
            build_from_dataset: force a rebuild from the raw files even when valid
                pickles exist.

        Raises:
            ValueError: if neither a complete set of existing pickle files nor both
                raw paths were provided.
        """
        pkl_paths = (pkl_text_path, pkl_summary_path, pkl_word2idx_path, pkl_idx2word_path)
        pkl_available = all(path != "" and os.path.isfile(path) for path in pkl_paths)

        if pkl_available and not build_from_dataset:
            self._load_pkl(*pkl_paths)

        elif text_path != "" and summary_path != "":
            self.force_load_from_files(text_path, summary_path, max_chr_count, dataset_size)
            self.override_files_with_current_data(save_path)

        else:
            raise ValueError("Wrong argument combination provided.")

    def override_files_with_current_data(self, save_path):
        """Pickle the current dataset and vocabulary into ``save_path``.

        Writes dataset_text.pkl, dataset_summary.pkl, word2idx.pkl and idx2word.pkl,
        overwriting any existing files. The directory is created when missing, so a
        (potentially very expensive) dataset build is not lost to a missing folder.
        Does nothing when ``save_path`` is "".
        """
        if save_path == "":
            return

        print("Overriding files with current data.")
        os.makedirs(save_path, exist_ok=True)

        artefacts = (
            ("dataset_text.pkl", self.dataset_text),
            ("dataset_summary.pkl", self.dataset_summary),
            ("word2idx.pkl", self.word2idx),
            ("idx2word.pkl", self.idx2word),
        )
        for file_name, obj in artefacts:
            with open(os.path.join(save_path, file_name), "wb") as data_file:
                pickle.dump(obj, data_file)

        print("Data saved to files.")

    def force_load_from_files(self, text_path, summary_path, max_chr_count, dataset_size):
        """Rebuild the dataset and vocabulary from the raw files, ignoring any pickles."""
        print("Building data from dataset.")

        self.dataset_text, self.dataset_summary = self._load_dataset(
            text_path=text_path, summary_path=summary_path,
            max_chr_count=max_chr_count, dataset_size=dataset_size
        )
        self.word2idx, self.idx2word = self._create_vocabulary()

    def _load_pkl(self, pkl_text_path, pkl_summary_path, pkl_word2idx_path, pkl_idx2word_path):
        """Load the four pickled artefacts into the instance and report their sizes.

        SECURITY: pickle.load can execute arbitrary code, so only point these paths
        at files you produced yourself (e.g. via override_files_with_current_data).
        """
        with open(pkl_text_path, "rb") as data_file:
            self.dataset_text = pickle.load(data_file)
        with open(pkl_summary_path, "rb") as data_file:
            self.dataset_summary = pickle.load(data_file)
        with open(pkl_word2idx_path, "rb") as data_file:
            self.word2idx = pickle.load(data_file)
        with open(pkl_idx2word_path, "rb") as data_file:
            self.idx2word = pickle.load(data_file)

        print("PKL dataset loaded with vocab size ({}) and dataset len ({})".format(
            len(self.idx2word), len(self.dataset_text)
        ))

    def _load_dataset(self, text_path, summary_path, max_chr_count=300, dataset_size=-1):
        """Read, lower-case and tokenize the aligned text/summary files.

        Each file is lower-cased, stripped as a whole, and split on newlines. Line i
        of the text file pairs with line i of the summary file, and each side is
        tokenized with nltk's ``word_tokenize``. A pair is kept only if both sides
        have at most ``max_chr_count`` tokens (not characters).

        Args:
            text_path, summary_path: paths of the line-aligned raw files.
            max_chr_count: maximum token count per side.
            dataset_size: number of leading lines to process; -1 means all.

        Returns:
            (dataset_text, dataset_summary): two aligned lists of token lists.

        Raises:
            RuntimeError: if the two files have a different number of lines.
        """
        with open(text_path, "r") as text_file:
            text = text_file.read().lower().strip().split("\n")
        with open(summary_path, "r") as summary_file:
            summary = summary_file.read().lower().strip().split("\n")

        if len(text) != len(summary):
            raise RuntimeError("Dataset Inconsistency.")

        dataset_text = []
        dataset_summary = []

        size = min(len(text), dataset_size) if dataset_size != -1 else len(text)
        bar_width = 30

        for line_idx in range(size):
            # Derive the blank part from the filled part so the bar is always exactly
            # bar_width wide; flooring both halves independently could lose a cell.
            filled = int((line_idx / size) * bar_width)
            print("\r <{progress}{left}>".format(
                progress="=" * filled,
                left=" " * (bar_width - filled)
            ), end="")

            token_texts = word_tokenize(text[line_idx])
            token_summaries = word_tokenize(summary[line_idx])

            if len(token_texts) <= max_chr_count and len(token_summaries) <= max_chr_count:
                dataset_text.append(token_texts)
                dataset_summary.append(token_summaries)

        print("")

        return dataset_text, dataset_summary

    def _create_vocabulary(self):
        """Build a sorted vocabulary from every text and summary token.

        Returns:
            (word2idx, idx2word): a dict token -> id and the sorted token list
            (position == id). Both include the special tokens <SOS>, <EOS> and <UNK>.
        """
        vocabulary = {"<SOS>", "<EOS>", "<UNK>"}

        for words_text, words_summary in zip(self.dataset_text, self.dataset_summary):
            vocabulary.update(words_text)
            vocabulary.update(words_summary)

        idx2word = sorted(vocabulary)
        return {word: idx for idx, word in enumerate(idx2word)}, idx2word

    def next_batch(self, batch_size):
        """Yield padded batches of word ids in dataset order.

        Each sentence becomes ``[<SOS>] + ids + [<EOS>] * k`` with k >= 1. Every row
        in a batch is padded to (longest sentence in that batch + 2). Tokens missing
        from ``word2idx`` are mapped to <UNK> instead of raising KeyError.

        Args:
            batch_size: positive number of pairs per batch; the last batch may be
                smaller.

        Yields:
            ((texts, summaries), (text_sizes, summary_sizes)). The sizes are token
            counts *before* <SOS>/<EOS> were added.

        Raises:
            ValueError: if ``batch_size`` is not positive (raised on first iteration,
                since this is a generator).
        """
        if batch_size <= 0:
            raise ValueError("batch_size must be positive, got {}".format(batch_size))

        word2idx = self.word2idx
        sos = word2idx["<SOS>"]
        eos = word2idx["<EOS>"]

        def encode(sentence):
            # <UNK> is looked up lazily so vocabularies without it still work when
            # every token is known (the original behaviour).
            return [word2idx[word] if word in word2idx else word2idx["<UNK>"] for word in sentence]

        def pad(sentences, sizes):
            # Every row ends up max(sizes) + 2 long: one <SOS>, the ids, then >= 1 <EOS>.
            row_len = max(sizes) + 2
            return [[sos] + sentence + [eos] * (row_len - 1 - len(sentence)) for sentence in sentences]

        for start in range(0, len(self.dataset_text), batch_size):
            stop = start + batch_size
            sample_text = [encode(sentence) for sentence in self.dataset_text[start:stop]]
            sample_summary = [encode(sentence) for sentence in self.dataset_summary[start:stop]]

            sample_text_sizes = [len(sentence) for sentence in sample_text]
            sample_summary_sizes = [len(sentence) for sentence in sample_summary]

            yield (
                (pad(sample_text, sample_text_sizes), pad(sample_summary, sample_summary_sizes)),
                (sample_text_sizes, sample_summary_sizes),
            )

In [None]:
# Cached (pickled) artefacts and the raw training files they are built from.
PKL_DIR = "./pkl_dataset"
RAW_TRAIN_DIR = "../Dataset/train"

# Use the pickles in PKL_DIR when all four exist; otherwise rebuild from the raw
# article/title files and write the result back into PKL_DIR.
h = Helper(
    #PKL ARGS
    pkl_text_path=PKL_DIR + "/dataset_text.pkl",
    pkl_summary_path=PKL_DIR + "/dataset_summary.pkl",
    pkl_word2idx_path=PKL_DIR + "/word2idx.pkl",
    pkl_idx2word_path=PKL_DIR + "/idx2word.pkl",
    #BUILD PKL ARGS
    text_path=RAW_TRAIN_DIR + "/train.article.txt",
    summary_path=RAW_TRAIN_DIR + "/train.title.txt",
    max_chr_count=300,  # max tokens per article/title; longer pairs are dropped
    dataset_size=-1,    # -1 = use every line
    #SAVE
    save_path=PKL_DIR,
    #FLAGS
    build_from_dataset=False,
)

Building data from dataset.