# In [9]:
import os
import pickle
import random
import numpy as np

# In [7]:
class DataGenerator():
    """Serve batched, fixed-length windows of pen-stroke offsets with their labels.

    On construction the parsed dataset is loaded from ``data_dir``. Only
    sequences with at least ``seq_length + 1`` points are kept: the extra
    point means the last input step of a window still has a target. Every
    kept sequence keeps its own label, one-hot encoded and zero-padded to a
    common length.

    Args:
        batch_size: number of sequences per batch. Must not exceed the number
            of sequences that survive the length filter. Start indices are
            sampled without replacement, so ``np.random.choice`` raises
            ``ValueError`` otherwise.
        seq_length: number of input steps per window.
        data_dir: directory holding ``dataset.npy``, ``labels.npy`` and
            ``translation.pkl``. Defaults to ``"data_parsed"``.

    Raises:
        ValueError: if no sequence is long enough for a single window.
    """

    def __init__(self, batch_size, seq_length, data_dir="data_parsed"):
        self.batch_size = batch_size
        self.seq_len = seq_length

        dataset, labels, translation = self.load_dataset(data_dir)

        # Keep only sequences long enough for one window (+1 point for the
        # output of the last input). Data and labels are filtered *together*
        # so that index k in self.dataset and self.labels is the same sample.
        kept = [(d, l) for d, l in zip(dataset, labels) if len(d) >= seq_length + 1]
        if not kept:
            raise ValueError(
                "no sequence has at least seq_length + 1 = %d points" % (seq_length + 1))

        self.dataset = [d for d, _ in kept]
        labels_new = [l for _, l in kept]
        self.translation = translation
        self.num_chars = len(self.translation)

        # The padding length is taken over *all* loaded labels (as before),
        # so the label tensor shape stays the same regardless of seq_length.
        max_len = max(len(l) for l in labels)
        eye = np.eye(self.num_chars, dtype=np.float32)

        # shape -> [len(self.dataset), max_len + 1, num_chars]
        # The "+ 1" guarantees each label ends with at least one all-zero
        # (blank) row.
        self.labels = np.array([
            np.concatenate(
                [eye[l],
                 np.zeros((max_len - len(l) + 1, self.num_chars), dtype=np.float32)],
                axis=0)
            for l in labels_new])
        self.max_len = max_len + 1

        # Each batch slot starts on a distinct random sequence (replace=False).
        self.indices = np.random.choice(len(self.dataset), size=(batch_size,), replace=False)
        # Integer read offset of each slot inside its current sequence.
        self.batches = np.zeros((batch_size,), dtype=np.int64)

    def next_batch(self):
        """Return the next window for every batch slot.

        Each slot walks through its sequence with window ``seq_len + 1`` and
        stride ``seq_len``, so consecutive windows share one point. When a
        slot's sequence has no complete window left, the slot restarts at
        offset 0 of a randomly chosen sequence.

        Returns:
            data: float32 array ``(batch_size, seq_len + 1, 3)`` whose rows
                are ``(dx, dy, end)`` as built by ``load_dataset``.
            seq: float32 array ``(batch_size, max_len, num_chars)`` holding
                the padded one-hot label of each slot's current sequence.
            reset_states: float32 array ``(batch_size, 1)``. It is 0 for
                slots that switched to a new sequence in this call, 1
                otherwise.
            needed: ``True`` if any slot switched sequence in this call.
        """
        window = self.seq_len + 1
        data = np.zeros((self.batch_size, window, 3), dtype=np.float32)
        seq = np.zeros((self.batch_size, self.max_len, self.num_chars), dtype=np.float32)
        reset_states = np.ones((self.batch_size, 1), dtype=np.float32)
        needed = False
        for i in range(self.batch_size):
            if self.batches[i] + window > len(self.dataset[self.indices[i]]):
                # Sequence exhausted: restart this slot on a random sequence.
                # np.random is used here (the constructor already uses it), so
                # seeding numpy alone makes the generator reproducible.
                self.indices[i] = np.random.randint(len(self.dataset))
                self.batches[i] = 0
                reset_states[i] = 0
                needed = True
            idx = self.indices[i]
            start = self.batches[i]
            data[i] = self.dataset[idx][start:start + window]
            seq[i] = self.labels[idx]
            self.batches[i] += self.seq_len
        return data, seq, reset_states, needed

    def load_dataset(self, data_dir="data_parsed"):
        """Load strokes, labels and the character translation table.

        Each stored sequence ``d`` has rows whose first two columns are
        ``x, y`` and whose third column is an end flag. It is converted into
        offsets between consecutive points, keeping the end flag of the later
        point, and a ``[0, 0, 1]`` row is prepended.

        Args:
            data_dir: directory containing ``dataset.npy``, ``labels.npy`` and
                ``translation.pkl``.

        Returns:
            A tuple ``(dataset, labels, translation)``:
            - ``dataset``: list of float arrays shaped ``(len(d), 3)``.
            - ``labels``: the raw loaded labels.
            - ``translation``: the unpickled translation object.
        """
        # Portable paths. The old "data_parsed\dataset.npy" literals only
        # worked on Windows and relied on invalid string escape sequences.
        dataset_path = os.path.join(data_dir, "dataset.npy")
        labels_path = os.path.join(data_dir, "labels.npy")
        translation_path = os.path.join(data_dir, "translation.pkl")

        # allow_pickle=True is needed to load ragged (object-dtype) arrays on
        # NumPy >= 1.16.3.
        # NOTE(review): like pickle.load below, this can execute arbitrary
        # code -- only load trusted, locally produced files.
        dataset = [np.asarray(d) for d in np.load(dataset_path, allow_pickle=True)]
        dataset_final = []
        for d in dataset:
            # Train on the difference between consecutive (x, y) points.
            offs = d[1:, :2] - d[:-1, :2]
            # End flag of the later point in each consecutive pair.
            ends = d[1:, 2]
            steps = np.concatenate([offs, ends[:, None]], axis=1)
            # Prepend [0 0 1] -> [[0 0 1][a b c][d e f]...]
            dataset_final.append(np.concatenate([[[0., 0., 1.]], steps], axis=0))

        labels = np.load(labels_path, allow_pickle=True)

        with open(translation_path, 'rb') as f:
            translation = pickle.load(f)

        return dataset_final, labels, translation