In [12]:
from dataset import load_svhn, random_split_train_val
from layers import *
from trainer import *
from optim import *
from loss import CrossEntropy
from collections import OrderedDict
from layers import *
from operators import *
from model import NumberSortModule
import numpy as np
# Explicit imports stay after the star imports so they cannot be shadowed by them.
from typing import Iterator, Tuple

In [28]:
def one_hot(x: (np.ndarray, int), vocab_size: int) -> np.ndarray:
    """One-hot encode an integer index or an array of integer indices.

    Parameters
    ----------
    x : int, NumPy integer scalar, or np.ndarray of any integer dtype
        Index (or indices) to encode. Every value must satisfy
        ``0 <= value < vocab_size``.
    vocab_size : int
        Length of the one-hot axis.

    Returns
    -------
    np.ndarray
        float64 array. Shape ``(vocab_size,)`` for a scalar ``x``,
        otherwise ``(*x.shape, vocab_size)``.

    Raises
    ------
    AssertionError
        If the argument types are unsupported or an index lies outside
        ``[0, vocab_size)``.
    """
    assert isinstance(vocab_size, (int, np.integer)), "vocab_size is an integer value"
    assert isinstance(x, (int, np.integer, np.ndarray)), "unsupported type for one-hot encoding"

    if isinstance(x, (int, np.integer)):
        # A negative index would silently wrap around to the end of the
        # vector, so the lower bound is checked as well as the upper one.
        assert 0 <= x < vocab_size, "out of vocabulary"
        y = np.zeros(vocab_size)
        y[x] = 1
        return y

    # Any integer dtype can be used for indexing; requiring exactly int32
    # rejected e.g. the default output of np.random.randint.
    assert np.issubdtype(x.dtype, np.integer), "unsupported x.dtype for one-hot encoding"
    # np.all over an empty array is True, so empty inputs pass through.
    assert np.all((x >= 0) & (x < vocab_size)), "out of vocabulary"

    # Row i of the identity matrix is the one-hot vector for index i.
    y = np.eye(vocab_size)[x.ravel()]
    return y.reshape((*x.shape, vocab_size))


def split_test_train(X, y, test_size=0.25):
    """Split samples and labels into a leading train part and a trailing test part.

    The split is positional (no shuffling): the first
    ``int(len(X) * (1 - test_size))`` entries along the first axis go to
    the train part and the remainder to the test part.

    Parameters
    ----------
    X : sequence or np.ndarray
        Samples, indexed along the first axis.
    y : sequence or np.ndarray
        Labels with the same length as ``X``. May have any number of
        dimensions, including 1-D label vectors.
    test_size : float, default 0.25
        Fraction of the data assigned to the test part, in ``[0, 1]``.

    Returns
    -------
    tuple
        ``(train_X, train_y, test_X, test_y)``.

    Raises
    ------
    ValueError
        If ``X`` and ``y`` differ in length or ``test_size`` is outside
        ``[0, 1]``.
    """
    if len(X) != len(y):
        raise ValueError("X and y must contain the same number of samples")
    if not 0 <= test_size <= 1:
        raise ValueError("test_size must be within [0, 1]")

    train_size = int(len(X) * (1 - test_size))

    # Slice the first axis only, so inputs of any dimensionality are accepted
    # (the former `[:n, :]` form raised IndexError on 1-D arrays).
    train_X, test_X = X[:train_size], X[train_size:]
    train_y, test_y = y[:train_size], y[train_size:]

    return train_X, train_y, test_X, test_y


def generate_batch(batch_size: int = 32, seq_len: int = 10, max_num: int = 100) -> Iterator[Tuple[np.ndarray, np.ndarray]]:
    """Endlessly yield batches for the number-sorting task.

    Parameters
    ----------
    batch_size : int
        Number of sequences per batch.
    seq_len : int
        Length of every sequence.
    max_num : int
        Exclusive upper bound of the sampled integers; also used as the
        one-hot vocabulary size.

    Yields
    ------
    (inputs, labels) : tuple of np.ndarray
        ``inputs`` is the one-hot encoding of the unsorted sequences with
        shape ``(batch_size, seq_len, max_num)``; ``labels`` is an int32
        array of shape ``(batch_size, seq_len)`` holding each sequence
        sorted in ascending order (integer values, not one-hot).
    """
    # A single sample generator serves every row; building a fresh generator
    # per row only to take its first element was unnecessary work.
    sampler = generate_sample_pointer(seq_len, max_num)

    while True:
        X = np.empty((batch_size, seq_len), dtype=np.int32)
        y = np.empty((batch_size, seq_len), dtype=np.int32)

        for batch_num in range(batch_size):
            # Assignment casts the sampled integers to the int32 buffers.
            X[batch_num], y[batch_num] = next(sampler)

        yield one_hot(X, vocab_size=max_num), y


def generate_sample_pointer(seq_len: int = 10, max_num: int = 100) -> Iterator[Tuple[np.ndarray, np.ndarray]]:
    """Endlessly yield random integer sequences paired with their sorted version.

    Parameters
    ----------
    seq_len : int
        Number of integers per sequence.
    max_num : int
        Exclusive upper bound; values are drawn from ``[0, max_num)``.

    Yields
    ------
    (X, y) : tuple of np.ndarray
        ``X`` is the unsorted 1-D sequence and ``y`` holds the same values
        sorted in ascending order. The target is the sorted values
        themselves, not the argsort positions.
    """
    while True:
        X = np.random.randint(max_num, size=seq_len)
        y = np.sort(X, axis=-1)

        yield X, y


def create_dataset(num_samples: int, seq_len, max_num) -> Dataset:
    """Generate ``num_samples`` sequences in one batch and wrap them in a Dataset.

    A single batch of size ``num_samples`` is pulled from ``generate_batch``
    and divided by ``split_test_train`` with its default ``test_size``. The
    four resulting arrays are passed to ``Dataset`` in the order
    train inputs, train labels, test inputs, test labels.
    """
    batch_stream = generate_batch(num_samples, seq_len, max_num)
    inputs, targets = next(batch_stream)
    train_X, train_y, test_X, test_y = split_test_train(inputs, targets)
    return Dataset(train_X, train_y, test_X, test_y)


In [29]:
# Draw one small batch (batch_size=1, seq_len=5, max_num=10) to inspect what the generator yields.
x, y = next(generate_batch(1, 5, 10))

In [30]:
# Show the one-hot encoded input batch.
x

array([[[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]]])

In [31]:
# Decode the one-hot input back to the integer sequence by taking the argmax over the last axis.
x.argmax(axis=-1)

array([[7, 9, 1, 0, 8]])

In [32]:
# Show the labels: the same integers sorted ascending, kept as integer values rather than one-hot.
y

array([[0, 1, 7, 8, 9]], dtype=int32)

In [33]:
# One-hot encode the labels by selecting rows of the 10x10 identity matrix.
np.eye(10)[y]

array([[[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]]])

In [34]:
# Vocabulary size; passed below to both NumberSortModule and create_dataset (as max_num).
vocab_size = 10

In [35]:
# Build the sorting model.
# NOTE(review): the meaning of the positional arguments (10, 10, 1) is not visible in this
# notebook — confirm against model.NumberSortModule and consider passing them by keyword.
nsort = NumberSortModule(vocab_size, 10, 10, 1)

In [38]:
# 1000 generated sequences of length 10 with values in [0, vocab_size).
# create_dataset splits them positionally with split_test_train's default test_size.
dataset = create_dataset(1000, 10, vocab_size)

# Wire the model, loss, data and optimizer into the trainer.
trainer = Trainer(
    nsort,
    CrossEntropy(),
    dataset,
    Adam(nsort, learning_rate=0.001),
    # Alternative optimizer, currently disabled:
#     MomentumSGD(nsort, momentum=0.85, learning_rate=0.1),
    num_epochs=100,
    # Optional learning-rate decay, currently disabled:
#     learning_rate_decay=0.95
)

In [39]:
# Run training; fit() returns the loss, train-accuracy and validation-accuracy histories.
# NOTE(review): in the recorded output below the loss creeps upward (23.7008 -> 23.7289),
# accuracy stays near 0.10 and every prediction is all zeros, and the run ended with a
# KeyboardInterrupt. The model is not learning in this configuration.
loss_history, train_history, val_history = trainer.fit()

Loss: 23.700835, Train accuracy: 0.098667, val accuracy: 0.104400
unsorted:  [0 1 7 0 2 3 2 0 7 8]
predicted: [0 0 0 0 0 0 0 0 0 0]
Loss: 23.705324, Train accuracy: 0.098667, val accuracy: 0.104400
unsorted:  [1 0 1 3 2 0 5 8 3 9]
predicted: [0 0 0 0 0 0 0 0 0 0]
Loss: 23.709898, Train accuracy: 0.098667, val accuracy: 0.104400
unsorted:  [3 2 2 7 8 6 0 0 6 8]
predicted: [0 0 0 0 0 0 0 0 0 0]
Loss: 23.714537, Train accuracy: 0.098667, val accuracy: 0.104400
unsorted:  [6 2 1 9 8 6 5 5 8 3]
predicted: [0 0 0 0 0 0 0 0 0 0]
Loss: 23.719242, Train accuracy: 0.098667, val accuracy: 0.104400
unsorted:  [9 8 6 8 4 0 5 7 6 9]
predicted: [0 0 0 0 0 0 0 0 0 0]
Loss: 23.724015, Train accuracy: 0.098667, val accuracy: 0.104400
unsorted:  [2 1 8 7 3 8 0 4 1 9]
predicted: [0 0 0 0 0 0 0 0 0 0]
Loss: 23.728855, Train accuracy: 0.098667, val accuracy: 0.104400
unsorted:  [6 2 1 9 8 6 5 5 8 3]
predicted: [0 0 0 0 0 0 0 0 0 0]


KeyboardInterrupt: 