In [12]:
%%capture
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import importlib as imp

from collections import namedtuple
from random import sample, shuffle
from functools import reduce
from itertools import accumulate
from math import floor, ceil, sqrt, log, pi
from tensorflow.keras import layers, utils, losses, models as mds, optimizers
from visualkeras import layered_view

if imp.util.find_spec('aggdraw'): import aggdraw
if imp.util.find_spec('tensorflow_addons'): from tensorflow_addons import layers as tfa_layers
if imp.util.find_spec('tensorflow_models'): from official.vision.beta.ops import augment as visaugment

In [11]:
# Dataset image size
IMG_SIZE = 264
N_CLASSES = 102

def preprocess(image, *args):
    image = tf.image.resize_with_pad(image, IMG_SIZE, IMG_SIZE)
    image /= 255
    return (image, *args)

train_ds, val_ds = tfds.load(
    'oxford_flowers102',
    split=['train', 'validation'],
    as_supervised=True,
    read_config=tfds.ReadConfig(try_autocache=False)
)

train_ds = train_ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
val_ds = val_ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)

## Image Augmentation

In [7]:
augmenter = visaugment.RandAugment()

def randaug_pp(image, label):
    image = tf.cast(image*255, tf.uint8)
    image = augmenter.distort(image)
    image = tf.cast(image, tf.float32)
    image /= 255
    
    return image, label

tds = train_ds.map(randaug_pp, num_parallel_calls=tf.data.AUTOTUNE)

ModuleNotFoundError: No module named 'official'

In [13]:
augs = tf.keras.Sequential([
  layers.RandomFlip("horizontal"),
  layers.RandomRotation(0.1),
])

BATCH_SIZE = 2

tds = train_ds.batch(BATCH_SIZE).map(lambda x,y: (augs(x, training=True), y))
itr = iter(tds)
next(itr)

(<tf.Tensor: shape=(2, 264, 264, 3), dtype=float32, numpy=
 array([[[[0.00784314, 0.01176471, 0.        ],
          [0.00784314, 0.01176471, 0.        ],
          [0.00784314, 0.01176471, 0.        ],
          ...,
          [0.00540311, 0.01091428, 0.        ],
          [0.0041903 , 0.01176471, 0.        ],
          [0.00400661, 0.01176471, 0.        ]],
 
         [[0.00784314, 0.01176471, 0.        ],
          [0.00784314, 0.01176471, 0.        ],
          [0.00784314, 0.01176471, 0.        ],
          ...,
          [0.00453919, 0.01189561, 0.        ],
          [0.00382943, 0.01165428, 0.        ],
          [0.00551471, 0.0172912 , 0.00199538]],
 
         [[0.00784314, 0.01176471, 0.        ],
          [0.00784314, 0.01176471, 0.        ],
          [0.00784314, 0.01176471, 0.        ],
          ...,
          [0.01036594, 0.01712209, 0.00904413],
          [0.02969323, 0.04504927, 0.02606002],
          [0.05706767, 0.07749655, 0.05362134]],
 
         ...,
 
       