In [1]:
from dataset import load_svhn, random_split_train_val
from layers import *
from trainer import *
from optim import *
from loss import *
from collections import OrderedDict
from layers import *
from operators import *
from model import NumberSortModule
import numpy as np

In [2]:
def one_hot(x: (np.ndarray, int), vocab_size: int) -> np.ndarray:
    assert isinstance(vocab_size, int), "vocab_size is an integer value"
    assert isinstance(x, (int, np.ndarray)), "unsupported type for one-hot encoding"
    if isinstance(x, int):
        assert x < vocab_size, "out of vocabulary"
        y = np.zeros(vocab_size)
        y[x] = 1
        return y
    
    assert x.dtype == np.int32, "unsupported x.dtype for one-hot encoding"
    
    y = np.eye(vocab_size)[x.ravel()]
    y = y.reshape((*x.shape, vocab_size))
    
    return y


def split_test_train(X, y, test_size=0.25):
    train_size = int(len(X) * (1 - test_size))

    train_X = X[:train_size, :]
    train_y = y[:train_size, :]

    test_X = X[train_size:, :]
    test_y = y[train_size:, :]

    return train_X, train_y, test_X, test_y


def generate_batch(batch_size: int = 32, seq_len: int = 10, max_num: int = 100) -> np.ndarray:
    while True:
        X = np.empty((batch_size, seq_len), dtype=np.int32)
        y = np.empty((batch_size, seq_len), dtype=np.int32)

        for batch_num in range(batch_size):
            sample, label = next(generate_sample_pointer(seq_len, max_num))
            X[batch_num] = sample
            y[batch_num] = label

        yield one_hot(X, vocab_size=max_num), one_hot(y, vocab_size=max_num)


def generate_sample_pointer(seq_len: int = 10, max_num: int = 100) -> np.ndarray:
    while True:
        X = np.random.randint(max_num, size=(seq_len))
#         y = np.empty((2, seq_len), dtype=np.int32)
#         y[0] = X
#         y[1] = np.arange(seq_len)
#         y = y[:, y[0].argsort()]
        y = np.sort(X, axis=-1)

        yield X, y # [1]


def create_dataset(num_samples: int, seq_len, max_num) -> Dataset:
    data, labels = next(generate_batch(num_samples, seq_len, max_num))
    return Dataset(*split_test_train(data, labels))


In [3]:
x, y = next(generate_batch(1, 5, 10))

In [4]:
x

array([[[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]]])

In [5]:
x.argmax(axis=-1)

array([[0, 5, 4, 2, 3]])

In [6]:
y

array([[[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]]])

In [7]:
y.argmax(axis=-1)

array([[0, 2, 3, 4, 5]])

In [8]:
vocab_size = 10

In [9]:
nsort = NumberSortModule(vocab_size, 10, 10, 1)

In [10]:
dataset = create_dataset(1000, 10, vocab_size)

trainer = Trainer(
    nsort,
    MSELoss(),
    dataset,
    Adam(nsort, learning_rate=0.001),
#     MomentumSGD(nsort, momentum=0.85, learning_rate=0.1),
    num_epochs=100,
#     learning_rate_decay=0.95
)

In [11]:
loss_history, train_history, val_history = trainer.fit()

Loss: 1.899502, Train accuracy: 0.101867, val accuracy: 0.100400
unsorted:  [0 8 8 0 8 8 8 3 0 5]
sorted:    [0 0 0 3 5 8 8 8 8 8]
predicted: [9 9 9 9 9 9 9 9 9 9]
Loss: 1.800659, Train accuracy: 0.101867, val accuracy: 0.100400
unsorted:  [8 2 1 2 4 4 9 4 3 0]
sorted:    [0 1 2 2 3 4 4 4 8 9]
predicted: [9 9 9 9 9 9 9 9 9 9]
Loss: 1.781310, Train accuracy: 0.101867, val accuracy: 0.100400
unsorted:  [8 8 0 1 0 2 1 6 9 7]
sorted:    [0 0 1 1 2 6 7 8 8 9]
predicted: [9 9 9 9 9 9 9 9 9 9]
Loss: 1.778379, Train accuracy: 0.101867, val accuracy: 0.100400
unsorted:  [5 3 7 0 4 9 8 6 6 6]
sorted:    [0 3 4 5 6 6 6 7 8 9]
predicted: [9 9 9 9 9 9 9 9 9 9]
Loss: 1.777834, Train accuracy: 0.101867, val accuracy: 0.100400
unsorted:  [6 3 7 6 3 9 6 9 1 7]
sorted:    [1 3 3 6 6 6 7 7 9 9]
predicted: [9 9 9 9 9 9 9 9 9 9]
Loss: 1.777605, Train accuracy: 0.101867, val accuracy: 0.100400
unsorted:  [4 4 3 4 0 7 8 6 1 0]
sorted:    [0 0 1 3 4 4 4 6 7 8]
predicted: [9 9 9 9 9 9 9 9 9 9]
Loss: 1.777419, 

Loss: 1.765982, Train accuracy: 0.112800, val accuracy: 0.116800
unsorted:  [2 2 2 6 3 4 0 2 6 1]
sorted:    [0 1 2 2 2 2 3 4 6 6]
predicted: [3 3 9 9 9 9 9 9 9 9]
Loss: 1.765610, Train accuracy: 0.135867, val accuracy: 0.138000
unsorted:  [1 5 0 7 6 7 5 5 2 7]
sorted:    [0 1 2 5 5 5 6 7 7 7]
predicted: [1 3 9 9 9 9 9 9 9 9]
Loss: 1.765233, Train accuracy: 0.135867, val accuracy: 0.138000
unsorted:  [8 1 6 2 6 8 2 1 6 4]
sorted:    [1 1 2 2 4 6 6 6 8 8]
predicted: [1 3 9 9 9 9 9 9 9 9]
Loss: 1.764850, Train accuracy: 0.135867, val accuracy: 0.138000
unsorted:  [6 2 8 9 9 1 8 9 0 2]
sorted:    [0 1 2 2 6 8 8 9 9 9]
predicted: [1 3 9 9 9 9 9 9 9 9]
Loss: 1.764461, Train accuracy: 0.135867, val accuracy: 0.138000
unsorted:  [2 8 2 6 2 0 0 9 0 7]
sorted:    [0 0 0 2 2 2 6 7 8 9]
predicted: [1 3 9 9 9 9 9 9 9 9]
Loss: 1.764066, Train accuracy: 0.135867, val accuracy: 0.138000
unsorted:  [6 5 6 2 3 7 8 3 8 3]
sorted:    [2 3 3 3 5 6 6 7 8 8]
predicted: [1 3 9 9 9 9 9 9 9 9]
Loss: 1.763666, 