In [1]:
import numpy as np
import os
import regex as re
# import requests

In [211]:
def one_hot_encode(x, vocab_size):
    """
    One-hot encode a batch of integer index sequences.

    x has shape (B, S): B sequences (the batch) of S indices each. The result
    has shape (B, S, vocab_size) and is all zeros except for a 1 at position
    x[b, s] along the last axis.
    """
    indices = np.array(x)
    n_rows, n_cols = indices.shape[0], indices.shape[1]
    encoded = np.zeros((n_rows, n_cols, vocab_size))

    # Set every (row, column, index) entry in a single assignment. The row
    # selector is a column vector so that it broadcasts against the columns.
    encoded[np.arange(n_rows)[:, None], np.arange(n_cols), indices] = 1
    return encoded

In [2]:
# Location of the tiny shakespeare dataset, relative to the notebook.
input_file_path = os.path.join(os.path.dirname("../data/"), 'input.txt')

# Optional one-time download of the dataset (kept disabled). To use it,
# uncomment the block below together with `import requests` in the imports cell.
# if not os.path.exists(input_file_path):
#     data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
#     with open(input_file_path, 'w', encoding='utf-8') as f:
#         f.write(requests.get(data_url).text)

# Read the whole file into memory as one string.
with open(input_file_path, mode='r', encoding='utf-8') as f:
    data = f.read()

# Total number of characters read.
n = len(data)

In [213]:
# Train/validation folds: the first 90% of `data` is used for training and
# the remaining 10% is held out for validation.
train_data, val_data = data[:int(n*0.9)], data[int(n*0.9):]

In [4]:
def cross_entropy_loss(y, t, eps=1e-12):
    """
    Summed cross-entropy between predictions and targets: -sum(t * log(y)).

    Args:
        y (array-like): Predicted probabilities.
        t (array-like): Target distribution with the same shape as y
            (e.g. one-hot vectors).
        eps (float): Lower bound applied to y before taking the logarithm so
            that a predicted probability of exactly 0 yields a large finite
            loss instead of inf/nan.

    Returns:
        float: The loss summed over every element (not averaged).

    Raises:
        ValueError: If y and t do not have the same shape.

    NOTE(review): y is taken to be the prediction and t the target (the usual
    y/t naming); the previous body used neither argument, so confirm this
    matches the model code that will call it.
    """
    y = np.asarray(y, dtype=float)
    t = np.asarray(t, dtype=float)
    if y.shape != t.shape:
        raise ValueError(
            "y and t must have the same shape, got {} and {}".format(y.shape, t.shape)
        )
    # Floor the probabilities: log(0) is -inf and 0 * -inf is nan.
    return -np.sum(t * np.log(np.maximum(y, eps)))

In [212]:
class DataLoader(object):
    """
    Word-level batch generator over a text corpus.

    The text is normalised and split into words, a word <-> index vocabulary is
    built (index 0 is reserved for '<UNK>'), and next_batch() serves one-hot
    encoded (input, target) windows in which the target is the input window
    shifted one word ahead.
    """

    def __init__(self, data, sequence_length, batch_size):
        """
        Initializes the DataLoader with the given data, sequence length, and batch size.

        Args:
            data (str): The input text data.
            sequence_length (int): The number of words in each sequence.
            batch_size (int): The number of sequences in each batch.

        Raises:
            ValueError: If sequence_length or batch_size is smaller than 1.
        """
        if sequence_length < 1 or batch_size < 1:
            raise ValueError("sequence_length and batch_size must be positive integers")

        self.text = data

        self.sequence_length = sequence_length
        self.batch_size = batch_size
        self.counter = 0  # index of the first word of the next batch

        # List of words from the preprocessed text
        self.word_list = self._text_preprocessing()

        # Unique words, plus the unknown token used for out-of-vocabulary words
        self.unique_words = set(self.word_list)
        self.unique_words.add('<UNK>')

        # Vocabulary with '<UNK>' first. The remaining words are sorted so the
        # word -> index mapping is identical on every run; iterating the set
        # directly depends on string hash randomisation and would reshuffle
        # the indices after each kernel restart.
        self.word_to_ix = {'<UNK>': 0}
        self.word_to_ix.update(
            {tok: idx for idx, tok in enumerate(sorted(self.unique_words - {'<UNK>'}), start=1)}
        )

        # Inverse mapping
        self.ix_to_word = {idx: tok for tok, idx in self.word_to_ix.items()}

        self.vocab_size = len(self.word_to_ix)

    def _text_preprocessing(self):
        """
        Normalises self.text and splits it into words.

        The text is lowercased, digits are removed, punctuation is removed, and
        every remaining run of non-word characters becomes a single separator.

        Returns:
            list: The words of the preprocessed text, without empty strings.
            The same list is also stored back on self.text (as before), so this
            method is meant to run once, from __init__.
        """
        text = self.text.lower()
        text = re.sub(r'\d+', '', text)  # Remove numbers
        text = re.sub(r"[^\w\s']+", '', text)  # Remove punctuation ( except apostrophes )
        # NOTE(review): this pass also turns the apostrophes kept above into
        # spaces, so "don't" is tokenised as "don" + "t". Left unchanged because
        # altering it changes the vocabulary; revisit if contractions should
        # stay whole.
        text = re.sub(r'\W+', ' ', text)  # Collapse whitespace / leftover non-word characters
        # split() without a separator drops the empty strings that split(' ')
        # produced for leading/trailing separators, which otherwise became a
        # bogus '' entry in the vocabulary.
        self.text = text.split()
        return self.text

    def encode(self, words):
        """
        Encodes the given words into their corresponding indices using the vocabulary.

        Args:
            words (list or str): A list of words, or a single string of
                whitespace-separated words.

        Returns:
            list: A list of indices; words missing from the vocabulary map to
            the '<UNK>' index.
        """
        if isinstance(words, str):
            # Iterating a plain string would encode it character by character.
            words = words.split()
        unknown_ix = self.word_to_ix['<UNK>']
        return [self.word_to_ix.get(word, unknown_ix) for word in words]

    def decode(self, indexes):
        """
        Decodes the given indices into their corresponding words using the vocabulary.

        Args:
            indexes (list): The indices to decode.

        Returns:
            list: A list of words; indices missing from the vocabulary map to '<UNK>'.
        """
        return [self.ix_to_word.get(idx, '<UNK>') for idx in indexes]

    def next_batch(self):
        """
        Builds the next batch and advances the counter by batch_size.

        Sequence i of the batch starts at word counter + i and spans
        sequence_length words; its target is the same window shifted one word
        ahead.

        Returns:
            tuple: (input, target), two one-hot arrays, each of shape
            (batch_size, sequence_length, vocab_size).
            None: If the remaining words cannot fill a complete batch. The
            counter is left unchanged in that case.
        """
        target_offset = 1  # the target is the input shifted by one word

        # The last window of the batch starts at counter + batch_size - 1, and
        # its target reaches sequence_length + target_offset words past that
        # start. If that runs off the end, the windows would be truncated to
        # different lengths and could not be stacked into a rectangular array.
        words_needed = self.counter + (self.batch_size - 1) + self.sequence_length + target_offset
        if words_needed > len(self.word_list):
            return None

        input_rows, target_rows = [], []
        for i in range(self.batch_size):
            start = self.counter + i
            stop = start + self.sequence_length
            input_rows.append(self.encode(self.word_list[start:stop]))
            target_rows.append(
                self.encode(self.word_list[start + target_offset:stop + target_offset])
            )

        # One-hot encoding
        batch_input = one_hot_encode(np.array(input_rows), self.vocab_size)
        batch_target = one_hot_encode(np.array(target_rows), self.vocab_size)

        # Increasing the counter by the batch size
        self.counter += self.batch_size
        return batch_input, batch_target

    def drop_counter(self):
        """
        Resets the counter to zero so that next_batch() starts over from the first word.
        """
        self.counter = 0

In [214]:
# Hyperparameters are powers of two: 32 words per sequence, 8 sequences per batch.
dl = DataLoader(
    data=train_data,
    sequence_length=32,
    batch_size=8,
)

In [215]:
# Fetch one batch from the loader: b is the one-hot input and c the one-hot
# target (the same windows shifted one word ahead). next_batch() can return
# None instead of a pair, in which case this unpacking raises a TypeError.
b, c = dl.next_batch()
# Cell output: the shape and dtype of both arrays.
b.shape, c.shape, b.dtype, c.dtype

((8, 32, 11411), (8, 32, 11411), dtype('float64'), dtype('float64'))