In [None]:
%load_ext autoreload
%autoreload

#### Prepare Dataset

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds

import deeply.datasets as dd
from deeply.datasets.util import split as split_dataset
from deeply.generators import DatasetGenerator

In [None]:
# Load CIFAR-10 plus its metadata via the project wrapper (presumably forwarding
# to tfds.load). NOTE(review): no shuffle seed is passed, so the file order may
# differ between runs.
cifar10, info = dd.load(
    "cifar10",
    with_info     = True,
    shuffle_files = True,
    as_supervised = True,
)

In [None]:
# Show the dataset metadata returned by dd.load; its `features` spec provides input_shape and n_classes.
info

In [None]:
# Preview the first three training images in a single row (the default grid is
# 3x3, which would leave six empty panels). The trailing `;` suppresses the
# returned matplotlib Figure's repr, which would otherwise render the grid a
# second time below the inline plot.
tfds.show_examples(cifar10["train"].take(3), ds_info = info, rows = 1, cols = 3);

In [None]:
# Model / training configuration, taken from the dataset's own feature spec.
# To train at another resolution, e.g. (224, 224, 3), override input_shape here
# and add a matching tf.image.resize step in `mapper`.
input_shape = info.features["image"].shape
n_classes   = info.features["label"].num_classes

# NOTE(review): batch_size = 1 means one image per optimisation step; looks
# like a debugging value, so confirm before long runs.
batch_size = 1
epochs     = 10

In [None]:
def mapper(image, label):
    """Rescale pixel intensities by 1/255 and pass the label through unchanged.

    Assumes raw pixel values in [0, 255] (uint8 images from tfds), so the
    result lies in [0, 1]; TODO confirm for other sources. No resizing is done.

    Returns:
        The pair ``(image / 255, label)``.
    """
    scaled = image / 255
    return scaled, label

In [None]:
gen_kwargs = dict(batch_size = batch_size, mapper = mapper)

# Hold out 30% of the official training split for validation; the official
# test split is kept separate for final evaluation.
train, val = [DatasetGenerator(split, **gen_kwargs)
              for split in split_dataset(cifar10["train"], splits = (0.7, 0.3))]
test = DatasetGenerator(cifar10["test"], **gen_kwargs)

#### Build Model

In [None]:
from tensorflow.keras.losses  import SparseCategoricalCrossentropy
from tensorflow.keras.metrics import SparseCategoricalAccuracy

from deeply.model.dam import DAM

In [None]:
# Instantiate the DAM model for the configured input shape / class count.
# `batch_norm = False` presumably disables batch-normalisation layers; see
# deeply.model.dam.
dam = DAM(input_shape = input_shape, n_classes = n_classes, batch_norm = False)

In [None]:
# Integer class labels call for the "sparse" loss / metric variants.
# SparseCategoricalCrossentropy() defaults to from_logits=False, i.e. it expects
# probabilities. NOTE(review): confirm DAM's output layer ends in a softmax.
dam.compile(optimizer = "adam", loss = SparseCategoricalCrossentropy(),
            metrics = [SparseCategoricalAccuracy()])

In [None]:
# Batches per training / validation pass.
# NOTE(review): this is only right if len(DatasetGenerator) counts samples; if
# it already counts batches, the extra `// batch_size` undercounts (masked
# while batch_size == 1). Confirm against deeply.generators.
steps_per_epoch, validation_steps = (len(gen) // batch_size for gen in (train, val))

In [None]:
def norm(x):
    """Scale a class index from [0, n_classes - 1] onto [0, 1].

    Reads the module-level ``n_classes`` at call time.
    """
    return x / (n_classes - 1)


def meta_mapper(X, y):
    """Map (predictions, labels) to normalised class indices.

    ``X`` is reduced with ``tf.argmax`` over axis 1. It is presumably per-class
    model outputs of shape (batch, n_classes); TODO confirm against DAM.fit.
    The resulting indices and ``y`` are both passed through ``norm``.

    Returns:
        The pair ``(norm(argmax(X)), norm(y))``.
    """
    return norm(tf.argmax(X, axis = 1)), norm(y)

In [None]:
# Train DAM. `meta_mapper` / `meta_epochs` are DAM-specific fit arguments
# (semantics defined in deeply.model.dam, not visible here). The return value
# (presumably a Keras History) is captured so the cell does not end with an
# opaque object repr and the training curves can be inspected afterwards.
history = dam.fit(train, validation_data = val, verbose = 2, epochs = epochs,
        steps_per_epoch  = steps_per_epoch,
        validation_steps = validation_steps,
        meta_mapper      = meta_mapper,
        meta_epochs      = 50
)