In [1]:
import codecs
import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer
from data_loading_utils import read_lines_from_file_as_data_chunks
import time  # Import the time module
import threading
from concurrent.futures import ThreadPoolExecutor





# Dataset creation

In [None]:
class WPDataset(Dataset):
    """
    Next-token prediction dataset built from plain-text files, to be used as
    an input to a PyTorch DataLoader.

    Every line handed over by the file reader is stripped, cut into pieces of
    at most 512 characters, tokenized with the supplied tokenizer and converted
    to token ids. Each datapoint is a pair
    (list of `samples_length` token ids, id of the token that follows them).
    Words the tokenizer does not know are handled by the tokenizer itself.

    Files are loaded concurrently, so the relative order of samples coming
    from different files is not deterministic; a sequence and its label always
    share the same index.

    Args:
        filenames: iterable of paths of the text files to load.
        tokenizer: object providing `pad_token_id`, `tokenize(text)` and
            `convert_tokens_to_ids(tokens)` (e.g. a BertTokenizer).
        samples_length: number of token ids in every input sequence (>= 1).
        chunk_size: how much is read from a file at a time; forwarded as-is to
            `read_lines_from_file_as_data_chunks` (unit defined by that helper).
        artificial_padding: if True, emit growing prefixes padded with the PAD
            id; if False, emit a plain sliding window of full-length sequences.

    Raises:
        ValueError: if `samples_length` is smaller than 1.
    """

    # Guards the paired extension of `sequences` and `labels` while several
    # reader threads are active. Kept on the class rather than the instance so
    # the lock is never part of an instance's __dict__ (locks cannot be pickled).
    _append_lock = threading.Lock()

    def __init__(self, filenames, tokenizer, samples_length=5, chunk_size=1000000, artificial_padding=True):
        if samples_length < 1:
            # A non-positive length would make the padded windowing loop never advance.
            raise ValueError(f"samples_length must be >= 1, got {samples_length}")

        self.sequences = []  # X: lists of token ids, each of length samples_length
        self.labels = []  # Y: id of the token following the matching sequence
        self.tokenizer = tokenizer
        self.samples_length = samples_length
        self.artificial_padding = artificial_padding
        self.pad_token_id = tokenizer.pad_token_id  # id used to right-pad short prefixes

        # One task per file; result() blocks until that file is fully processed.
        with ThreadPoolExecutor(max_workers=4) as executor:
            futures = [executor.submit(self.read_file, filename, chunk_size) for filename in filenames]
            for future in futures:
                future.result()  # Ensure all files are processed

    def read_file(self, filename, chunk_size):
        """
        Read one file and feed its content to `process_lines`, reporting the
        elapsed wall-clock time.

        Errors are reported on stdout and not re-raised (best effort): a file
        that fails contributes whatever was processed before the error.
        """
        print("Read in ", filename)
        start_time = time.time()
        try:
            read_lines_from_file_as_data_chunks(filename, chunk_size, self.process_lines)
        except FileNotFoundError:
            print(f"File not found: {filename}")
        except Exception as e:
            print(f"An error occurred: {e}")
        end_time = time.time()  # End the timer
        print(f"Time taken to read {filename}: {end_time - start_time:.2f} seconds")

    def process_lines(self, data, eof, file_name):
        """
        Callback invoked by the file reader.

        Args:
            data: text to process (ignored when `eof` is True).
            eof: True once the reader has reached the end of the file.
            file_name: name of the file being read; used for the final message.
        """
        if eof:
            print(f"Finished reading file: {file_name}")
            return

        text = data.strip()  # Remove leading/trailing whitespace
        # Keep each tokenizer input short by cutting the text into pieces first.
        for chunk in self.split_into_chunks(text):
            line_tokens = self.tokenizer.tokenize(chunk)  # data is already lower case
            line_tokens_ids = self.tokenizer.convert_tokens_to_ids(line_tokens)
            self.create_sequences(line_tokens_ids)

    def split_into_chunks(self, line, max_length=512):
        """
        Split `line` into consecutive slices of at most `max_length` characters.

        Note that the cut is made on characters, not tokens, and does not look
        for word boundaries, so a word can be divided between two slices.
        """
        return [line[i:i + max_length] for i in range(0, len(line), max_length)]

    def create_sequences(self, token_ids):
        """
        Create sequences and labels from tokenized text and add them to the
        dataset.

        The new samples are built first and then added to `sequences` and
        `labels` in one locked step, so concurrent reader threads cannot
        interleave between the two lists and misalign them.
        """
        new_sequences, new_labels = self._build_samples(token_ids)
        if not new_sequences:
            return
        with self._append_lock:
            self.sequences.extend(new_sequences)
            self.labels.extend(new_labels)

    def _build_samples(self, token_ids):
        """Return (sequences, labels) derived from `token_ids` without touching the dataset."""
        n = self.samples_length
        total = len(token_ids)
        sequences, labels = [], []

        if self.artificial_padding:
            # Non-overlapping windows starting at 0, n, 2n, ... Inside a window
            # every prefix of 1..n tokens is right-padded to length n and
            # labelled with the token that follows the prefix.
            for start in range(0, total, n):
                # The last window is shorter: a prefix needs a following token.
                longest_prefix = min(n, total - start - 1)
                for size in range(1, longest_prefix + 1):
                    prefix = token_ids[start:start + size]
                    sequences.append(prefix + [self.pad_token_id] * (n - size))
                    labels.append(token_ids[start + size])
        else:
            # Sliding window: every run of n tokens predicts the next token.
            for end in range(n, total):
                sequences.append(token_ids[end - n:end])
                labels.append(token_ids[end])

        return sequences, labels

    def __len__(self):
        """Number of (sequence, label) samples."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return sample `idx` as a (sequence tensor, label tensor) pair."""
        return torch.tensor(self.sequences[idx]), torch.tensor(self.labels[idx])

# Example usage

In [None]:
# Cleaned text corpora to load (relative paths).
filenames = [
    'data/clean_data/news_summarization.txt',
    'data/clean_data/twitter.txt',
]

# Tokenizer that turns text into token ids (pretrained BERT, uncased, as an example).
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Build the dataset with the default sample length, chunk size and padding mode.
dataset = WPDataset(filenames=filenames, tokenizer=tokenizer)

In [None]:
# Total number of (sequence, label) samples that were generated.
len(dataset)

In [13]:
# Inspect the second input sequence (a list of token ids).
dataset.sequences[1]

[2821, 2129, 0, 0, 0]

In [57]:
# Inspect the label (next token) of the second sample.
# NOTE(review): the saved output of this cell is a token string, while
# WPDataset.create_sequences stores ids produced by convert_tokens_to_ids;
# the output looks stale (execution counts are out of order) - re-run to confirm.
dataset.labels[1]

'##ed'

# Training loop 

In [None]:
# Serve the dataset in its stored order, 32 samples per batch.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=32,
    shuffle=False,
)

In [None]:
# Iterate through the DataLoader.
# Only the first few batches are printed: dumping every batch of the full
# dataset would flood the cell output. `i` is the batch index.
MAX_PREVIEW_BATCHES = 3  # remove the cap once a real training step replaces the prints
for i, batch in enumerate(dataloader):
    if i >= MAX_PREVIEW_BATCHES:
        break
    sequences, labels = batch
    print(sequences.shape, labels.shape)
    print(sequences)
    print(labels)
    print('')
    # Your training loop here

# Testing loading configs/experiments:
Total data: 5.862,7 MB

news_summarization.txt: 264MB 
twitter.txt: 551,9 MB

USING news_summarization.txt ONLY 

1. chunk_size=1000000, artificial_padding = False
   time = 528.20 seconds, memory = 6.44 GB
2. chunk_size=1000000, artificial_padding = True
   time = 571.33 seconds, memory = 6.93 GB

3. chunk_size=2000000, artificial_padding = True
   time = 561.67 seconds, memory = 6.95 GB
4. chunk_size=500000, artificial_padding = True
   time = 562.89 seconds, memory = 6.95 GB
   
5. Thread, chunk_size=1000000, artificial_padding = True
   time = 546.93 seconds, memory = 6.86 GB

USING news_summarization.txt AND twitter.txt  

6. Thread, chunk_size=1000000, artificial_padding = True
   time = 1102.74 seconds, memory = 13.61 GB

7. No thread, chunk_size=1000000, artificial_padding = True
   time = 515.84 + 1158.22 seconds, memory = 20.01 GB