In [1]:
import torch
from torch.utils import data
import os

In [22]:
filename = "../pa4Data/train.txt"


def _swap_marker(raw_line):
    """Return the one-character sentinel line for a tag line, else the line itself.

    Lines starting with ``<start>`` become a "%" line and lines starting with
    ``<end>`` become a "`" line; every other line is returned unchanged.
    """
    if raw_line.startswith("<start>"):
        return "%\n"
    if raw_line.startswith("<end>"):
        return "`\n"
    return raw_line


# Read the whole training file into one string, with the tag lines replaced
# by their single-character sentinels.
with open(filename) as f:
    pieces = []
    for line in f:
        line = _swap_marker(line)
        pieces.append(line)
    text_blob = "".join(pieces)

# Vocabulary: every distinct character that occurs in the processed text.
chars = set(text_blob)

# Number the characters in sorted order so the mapping is easy to inspect.
char_index = dict(zip(sorted(chars), range(len(chars))))

In [23]:
class PA4Dataset(data.Dataset):
    """Map-style dataset of text chunks encoded as one-hot next-character pairs.

    The text file is read once at construction time. Lines starting with
    ``<start>`` are replaced by a line holding the single character "%", and
    lines starting with ``<end>`` by a line holding the single character "`".
    The resulting text is cut into consecutive chunks of ``chunk_size``
    characters; the final chunk may be shorter.

    Item ``k`` is a tuple ``(inputs, labels)`` of tensors, each of shape
    ``(len(chunk_k), len(character_index))``. Row ``j`` of ``inputs`` is the
    one-hot encoding of character ``j`` of the chunk, and row ``j`` of
    ``labels`` is the one-hot encoding of the character that follows it. For
    the last character of a chunk, the following character is the first
    character of the next chunk. The very last character of the very last
    chunk has no successor, so its input row and label row are both left as
    all zeros.

    This class derives from ``data.Dataset`` (not ``data.DataLoader``): it
    provides ``__len__``/``__getitem__`` and is meant to be wrapped by a loader.

    Args:
        filename: Path of the text file to read.
        character_index: Mapping from character to column index. Every
            character of the processed text must be a key, otherwise
            ``__getitem__`` raises ``KeyError``.
        chunk_size: Number of characters per chunk (must be >= 1).

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1.
    """

    def __init__(self, filename, character_index, chunk_size=100):
        self.character_index = character_index
        self.filename = filename

        self.text_chunks = PA4Dataset.__load_text_chunks(filename, chunk_size)
        self.num_chunks = len(self.text_chunks)

    def __len__(self):
        """Return the number of chunks."""
        return self.num_chunks

    def __getitem__(self, index):
        """Return the ``(inputs, labels)`` one-hot tensors for chunk ``index``.

        Negative indices count from the end, as for a list.

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self))``.
            KeyError: If the chunk contains a character that is missing from
                ``character_index``.
        """
        # Normalise negative indices first; otherwise ``-1`` would not be
        # recognised as the last chunk and would borrow its final label from
        # chunk 0.
        if index < 0:
            index += self.num_chunks
        if not 0 <= index < self.num_chunks:
            raise IndexError("chunk index out of range")

        cur_chunk = self.text_chunks[index]
        vocab_size = len(self.character_index)
        input_tensors = torch.zeros(len(cur_chunk), vocab_size)
        label_tensors = torch.zeros(len(cur_chunk), vocab_size)

        if index == self.num_chunks - 1:
            # Last chunk: its final character has no successor, so that row
            # stays all zeros in both tensors.
            sources = cur_chunk[:-1]
            targets = cur_chunk[1:]
        else:
            # The successor of the final character is the first character of
            # the next chunk.
            sources = cur_chunk
            targets = cur_chunk[1:] + self.text_chunks[index + 1][0]

        for row, (char, next_char) in enumerate(zip(sources, targets)):
            input_tensors[row, self.character_index[char]] = 1
            label_tensors[row, self.character_index[next_char]] = 1

        return (input_tensors, label_tensors)

    @staticmethod
    def __load_text_chunks(filename, chunk_size):
        """Read ``filename`` and return its text as a list of chunks.

        Lines starting with ``<start>`` / ``<end>`` are replaced by the
        sentinel lines "%" / "`" before chunking. Every returned string has
        exactly ``chunk_size`` characters, except the last one, which may be
        shorter. An empty file yields an empty list.

        Raises:
            ValueError: If ``chunk_size`` is smaller than 1 (the chunking
                would otherwise never advance).
        """
        if chunk_size < 1:
            raise ValueError("chunk_size must be a positive integer")

        pieces = []
        with open(filename) as f:
            for line in f:
                if line.startswith("<start>"):
                    line = "%\n"
                elif line.startswith("<end>"):
                    line = "`\n"
                pieces.append(line)
        text_blob = "".join(pieces)

        return [
            text_blob[start:start + chunk_size]
            for start in range(0, len(text_blob), chunk_size)
        ]

In [41]:
"""
Builds text_chunks (strings of at most chunk_size characters) and a parallel reset_flags list
"""
chunk_size = 100


def _split_with_reset_flags(blob, size):
    """Cut ``blob`` into chunks of at most ``size`` characters.

    A chunk is ended early, right after the character following the first "`"
    (normally the newline of the "`" sentinel line), whenever that "`" is
    found before the last position of the window. Such chunks get a reset
    flag of True; every other chunk gets False.

    Returns:
        (chunks, flags): two lists of equal length.

    Raises:
        ValueError: If ``size`` is smaller than 1.
    """
    if size < 1:
        raise ValueError("chunk size must be a positive integer")

    chunks, flags = [], []
    pos = 0
    while pos < len(blob):
        window = blob[pos:pos + size]
        end_at = window.find("`")
        if 0 <= end_at < size - 1:
            # The chunk ends early, after the "`" and the character following
            # it. The True flag signals a hidden-state reset at the end of
            # this chunk and tells the data loader to zero-pad the sequence.
            window = window[:end_at + 2]
            flags.append(True)
        else:
            # No end marker that fits inside the window. The original code
            # probed blob[pos + size - 2] here, which could never hold "`" on
            # this path and raised IndexError for a short final chunk.
            # NOTE(review): a "`" sitting exactly on the last position of the
            # window is still split from its newline and not flagged; confirm
            # whether that case should also request a reset.
            flags.append(False)
        chunks.append(window)
        pos += len(window)
    return chunks, flags


# Read the corpus, replacing <start>/<end> tag lines by their "%" / "`"
# sentinel lines. `filename` is defined in an earlier cell.
with open(filename) as f:
    pieces = []
    for line in f:
        if line.startswith("<start>"):
            line = "%\n"
        elif line.startswith("<end>"):
            line = "`\n"
        pieces.append(line)
    text_blob = "".join(pieces)

text_chunks, reset_flags = _split_with_reset_flags(text_blob, chunk_size)

In [60]:
# Scratch tensor of ones (5 rows, 10 columns) for the concatenation check below.
a = torch.ones((5, 10))

In [61]:
# Stack a block of zeros underneath `a` (concatenation along dim 0).
torch.cat((a, torch.zeros(5, 10)), dim=0)

tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])