In [None]:
%%capture
import operator

import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import importlib as imp

from collections import namedtuple
from random import sample, shuffle
from functools import reduce
from itertools import accumulate
from math import floor, ceil, sqrt, log, pi
from matplotlib import pyplot as plt
from tensorflow.keras import layers, utils, losses, models as mds, optimizers

if imp.util.find_spec('aggdraw'): import aggdraw
if imp.util.find_spec('tensorflow_addons'): from tensorflow_addons import layers as tfa_layers
if imp.util.find_spec('tensorflow_models'): from official.vision.beta.ops import augment as visaugment
if imp.util.find_spec('tensorflow_probability'): from tensorflow_probability import distributions as tfd
if imp.util.find_spec('keras_tuner'): import keras_tuner as kt

In [None]:
# Side length, in pixels, that `preprocess` resizes and pads every image to
IMG_SIZE = 264
# Number of label classes (the dataset loaded below is 'oxford_flowers102')
N_CLASSES = 102

def preprocess(image, *args):
    """Resize-and-pad an image to IMG_SIZE x IMG_SIZE and divide its values by 255.

    Any extra positional values (for example the label of a supervised
    `(image, label)` pair) are passed through untouched after the image.
    """
    padded = tf.image.resize_with_pad(image, IMG_SIZE, IMG_SIZE)
    scaled = padded / 255
    return (scaled, *args)

# Load the train/validation splits as (image, label) pairs with TFDS
# auto-caching disabled, then run every example through `preprocess`.
train_ds, val_ds = (
    split_ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    for split_ds in tfds.load(
        'oxford_flowers102',
        split=['train', 'validation'],
        as_supervised=True,
        read_config=tfds.ReadConfig(try_autocache=False),
    )
)

## N-Ways

In [None]:
# Number of examples per batch in the split_label demo at the end of this cell
BATCH_SIZE = 5
# NOTE(review): re-declares N_CLASSES with the same value as the dataset cell above;
# split_label reads this global when it computes the slot size.
N_CLASSES = 102

def split_label(ways=2):
    """Build a dataset-map function that splits class labels into `ways` slots.

    The label range [0, N_CLASSES) is cut into `ways` contiguous slots of
    ceil(N_CLASSES / ways) classes each. For every slot the returned function
    emits one label tensor shaped like `y`: a label that falls inside the slot
    becomes its 1-based offset within the slot, every other label becomes 0.
    Values in each output therefore lie in 0..ceil(N_CLASSES / ways).

    Args:
        ways: Number of slots (and output label tensors); must be >= 1.

    Returns:
        A function `split_fn(x, y)` returning `(x, labels)`, where `labels` is
        a tuple of `ways` tensors with the shape and dtype of `y`.

    Raises:
        ValueError: If `ways` is smaller than 1.
    """
    if ways < 1:
        raise ValueError(f'ways must be a positive integer, got {ways!r}')

    def split_fn(x, y):
        # Ceiling division: with a floored slot size the top
        # N_CLASSES % ways labels would fall outside every slot and be
        # silently mapped to 0 in all outputs. Identical to the floor when
        # `ways` divides N_CLASSES evenly.
        slot_size = -(-N_CLASSES // ways)
        no_match = tf.zeros_like(y)

        def label_fn(slot):
            start = slot * slot_size
            end = start + slot_size
            in_slot = tf.logical_and(y >= start, y < end)
            # +1 keeps 0 free to mean "label is not in this slot".
            return tf.where(in_slot, y - start + 1, no_match)

        # `ways` is a Python int, so the slots are unrolled here instead of
        # going through tf.map_fn + tf.unstack.
        return (x, tuple(label_fn(slot) for slot in range(ways)))

    return split_fn

# Demo: batch the training set and split every label batch three ways.
tds = (
    train_ds
    .batch(BATCH_SIZE)
    .map(split_label(3))
)
itr = iter(tds)
# Show only the label tuple of the first batch.
next(itr)[1]