In [None]:
def encode_text(char_to_number, text):
    """Translate every character of `text` into its integer code.

    Args:
        char_to_number: Mapping from a character to its integer code.
        text: Iterable of characters to encode.

    Returns:
        A list with one code per character, in the original order.

    Raises:
        KeyError: If a character is missing from `char_to_number`.
    """
    lookup = char_to_number.__getitem__
    return list(map(lookup, text))


def decode_text(number_to_char, encoded):
    """Turn a sequence of integer codes back into a string.

    Args:
        number_to_char: Mapping from an integer code to its character.
        encoded: Iterable of integer codes.

    Returns:
        The decoded string; an empty string when `encoded` is empty.

    Raises:
        KeyError: If a code is missing from `number_to_char`.
    """
    # str.join handles the empty case (a fold without an initial value
    # raises on empty input) and runs in linear time.
    return ''.join(number_to_char[code] for code in encoded)


class RandomCrop():
    """Transform that drops a random number of trailing items from a sample.

    The number of dropped items is drawn uniformly from
    ``0 .. crop_len - 1`` on every call and applied to both the ``'text'``
    and the ``'encoded'`` entries of the sample.
    """

    def __init__(self, crop_len):
        """Store the exclusive upper bound of the random crop amount."""
        self.crop_len = crop_len

    def __call__(self, sample):
        """Return a copy of `sample` with cropped 'text' and 'encoded'.

        Args:
            sample: Dict holding at least the keys 'text' and 'encoded'.

        Returns:
            A new dict with every key of `sample`, where 'text' and
            'encoded' have lost the same number of trailing items.
        """
        text = sample['text']
        encoded = sample['encoded']
        # Randomly choose how many trailing items to drop
        crop = np.random.randint(0, self.crop_len)
        # Use an explicit end index: `seq[:-0]` is `seq[:0]` and would wipe
        # the whole sequence when the drawn crop is 0. Clamp at 0 so a crop
        # longer than the sequence still yields an empty result.
        text = text[:max(len(text) - crop, 0)]
        encoded = encoded[:max(len(encoded) - crop, 0)]
        return {**sample,
                'text': text,
                'encoded': encoded}
        

def create_one_hot_matrix(encoded, alphabet_len):
    """Build a one-hot matrix from a sequence of integer codes.

    Args:
        encoded: Sequence of integer codes, one per row of the result.
        alphabet_len: Number of columns of the result.

    Returns:
        A numpy array of shape ``(len(encoded), alphabet_len)`` filled with
        zeros, except for a 1 at column ``encoded[i]`` of every row ``i``.
    """
    n_rows = len(encoded)
    one_hot = np.zeros((n_rows, alphabet_len))
    # Pair each row index with its code and switch that single cell on
    row_ids = np.arange(n_rows)
    one_hot[row_ids, encoded] = 1
    return one_hot


class OneHotEncoder():
    """Transform that adds one-hot encoded entries to a sample.

    All items of ``sample['encoded']`` except the last become the one-hot
    matrix ``'x_rnn'``; the last item becomes ``'y_rnn'``.
    """

    def __init__(self, alphabet_len):
        """Store the number of columns used for the one-hot matrix."""
        self.alphabet_len = alphabet_len

    def __call__(self, sample):
        """Return a copy of `sample` extended with the RNN entries.

        Args:
            sample: Dict holding at least the key 'encoded'.

        Returns:
            A new dict with every key of `sample` plus 'encoded_rnn' (the
            codes reshaped to a single row), 'x_rnn' (one-hot matrix of all
            codes but the last) and 'y_rnn' (the last code).
        """
        # Numeric codes as a numpy array
        codes = np.array(sample['encoded'])
        # One-hot matrix of everything but the final code
        inputs_onehot = create_one_hot_matrix(codes[:-1], self.alphabet_len)
        result = {**sample}
        result['encoded_rnn'] = codes.reshape(1, -1)
        result['x_rnn'] = inputs_onehot
        result['y_rnn'] = codes[-1]
        return result
        
                
class ToTensor():
    """Transform that converts the 'x_rnn' and 'x_conv' entries to tensors."""

    def __call__(self, sample):
        """Return a copy of `sample` with tensor-valued 'x_rnn' and 'x_conv'.

        Args:
            sample: Dict holding at least the keys 'x_rnn' and 'x_conv'.

        Returns:
            A new dict with every key of `sample`, where 'x_rnn' is a
            float32 tensor and 'x_conv' is a LongTensor.
        """
        # One-hot input becomes a single-precision float tensor
        rnn_input = torch.tensor(sample['x_rnn']).to(dtype=torch.float32)
        # Integer codes become a 64-bit integer tensor
        conv_input = torch.LongTensor(sample['x_conv'])
        converted = {**sample}
        converted['x_rnn'] = rnn_input
        converted['x_conv'] = conv_input
        return converted
        
    
class training_data_conv():
    """Transform that splits 'encoded' into an input and a target entry."""

    def __call__(self, sample):
        """Return a copy of `sample` extended with 'x_conv' and 'y_conv'.

        Args:
            sample: Dict holding at least the key 'encoded'.

        Returns:
            A new dict with every key of `sample` plus 'x_conv' (all items
            of 'encoded' but the last) and 'y_conv' (the last item).
        """
        codes = sample['encoded']
        extended = {**sample}
        extended['x_conv'] = codes[:-1]
        extended['y_conv'] = codes[-1]
        return extended
    
    
def find_speech(text):
    """Flag every position of `text` that lies inside single curly quotes.

    Args:
        text: Sequence of characters to scan.

    Returns:
        A numpy array with one entry per character: 1 from an opening '‘'
        (inclusive) up to the next closing '’' (exclusive), 0 elsewhere.
    """
    flags = np.zeros(len(text))
    inside = 0
    for position, char in enumerate(text):
        if char == '‘':
            inside = 1
        elif char == '’':
            inside = 0
        # The flag is stored after the update, so the opening quote counts
        # as speech while the closing quote does not
        flags[position] = inside
    return flags

In [None]:
class LewisCarrollDataset(Dataset):
    """Dataset of fixed-length character windows over a cleaned text file.

    The file is read once, cleaned (see `_clean_text`), and indexed by
    character. Item ``idx`` is the window ``text[idx:idx + n_char]`` together
    with its integer encoding and a speech flag.

    Attributes set by the constructor:
        text: The cleaned text.
        transform: Optional callable applied to every sample.
        alphabet: Sorted list of the distinct characters of the cleaned text.
        char_to_number / number_to_char: Mappings between characters and
            their position in `alphabet`.
        n_char: Window length in characters.
        speech: Per-character flags returned by `find_speech`.
    """

    def __init__(self, filepath, transform=None, n_char=50):
        """Load and clean the text, then build the character mappings.

        Args:
            filepath: Path of the text file to load (read as UTF-8).
            transform: Optional callable applied to each sample dict.
            n_char: Number of characters per sample window.
        """
        ### Load data
        # The context manager closes the file; the explicit encoding keeps
        # the curly quotes the cleaning relies on independent of the locale.
        with open(filepath, 'r', encoding='utf-8') as source:
            text = source.read()

        alphabet = list(set(text))
        alphabet.sort()
        print('Found letters:', alphabet)

        text = self._clean_text(text)

        in_speech = find_speech(text)

        ### Char to number
        alphabet = list(set(text))
        alphabet.sort()
        print('Found letters:', alphabet)
        char_to_number = {char: number for number, char in enumerate(alphabet)}
        number_to_char = {number: char for number, char in enumerate(alphabet)}

        ### Store data
        self.text = text
        self.transform = transform
        self.alphabet = alphabet
        self.char_to_number = char_to_number
        self.number_to_char = number_to_char
        self.n_char = n_char
        self.speech = in_speech

    @staticmethod
    def _clean_text(text):
        """Normalise the raw text into a lowercase, single-line string."""
        # Drop chapter headings ("CHAPTER <LETTERS>. <title>" plus line end)
        text = re.sub(r'CHAPTER\s[A-Z]*\.\s.*\n\s', '', text)
        # Join lines, then turn runs of 3+ whitespace characters into a
        # newline so that asterisks ending such a run can be removed with it
        text = re.sub(r'\n', ' ', text)
        text = re.sub(r'\s\s[\s]+', '\n', text)
        text = re.sub(r'\*\n', '', text)
        text = re.sub(r'\*', '', text)
        text = text.lower()

        # A closing curly quote between two letters is an apostrophe
        text = re.sub(r'([a-z])’([a-z])', r"\1'\2", text)
        # A double dash between two letters becomes a comma
        text = re.sub(r'([a-z])--([a-z])', r'\1, \2', text)
        # The remaining curly single quotes are kept on purpose: find_speech
        # uses them to locate direct speech
        text = re.sub(r'\n', ' ', text)
        text = re.sub(r'\[', '(', text)
        text = re.sub(r'\]', ')', text)
        text = re.sub(r'-', '', text)
        text = re.sub(r'[_:;“”]', '', text)
        text = re.sub(r'[0-9]', '', text)
        text = re.sub(r'[()]', ',', text)
        return text

    def __len__(self):
        """Return the number of samples, never less than zero."""
        # Clamp at 0: a negative value makes len() raise ValueError when the
        # text is shorter than n_char + 1
        return max(0, len(self.text) - 1 - self.n_char)

    def __getitem__(self, idx):
        """Return the sample starting at character `idx`.

        The sample is a dict with 'text' (the `n_char` characters starting
        at `idx`), 'encoded' (their integer codes) and 'speech' (the speech
        flag of the second-to-last character of the window). If a transform
        was given, its result is returned instead.
        """
        text = self.text[idx:idx+self.n_char]
        speech = self.speech[idx+self.n_char-2]
        # Encode with numbers
        encoded = encode_text(self.char_to_number, text)
        # Create sample
        sample = {'text': text, 'encoded': encoded, 'speech': speech}
        if self.transform:
            sample = self.transform(sample)
        return sample
        