# Naive Multi-Domain Embedding Creation

Knowledge distillation serves a purpose in model compression, but it can be applied for more general reasons. For example, suppose we want a network to mimic not a single network, but the outputs of several specialist networks (each with more than satisfactory results on its individual task). Under normal circumstances, the naive approach would be to create a combined dataset for all $N$ tasks we are attempting to fold into a single model,

$$\mathcal{D}_\text{comb} = \bigcup\limits_{k=1}^{N}\mathcal{D}_k$$

and proceed with training on this fused dataset. Unfortunately, squeezing performance out of this type of dataset is quite difficult in practice due to inter-dataset properties such as dataset imbalance, overlapping domains, and varying levels of task difficulty.

Let us consider specifically the task of __creating embeddings__. In other words, we seek to create a set of vectors that gives us information about how close two images are. When considering this type of task on a single consistent dataset, the performance is generally acceptable. However, because each task can be fairly easy or fairly difficult to discriminate individually, optimizing them all at the same time has been shown to lead to partial overfitting/underfitting of component datasets. This is, of course, generally undesirable so there must be another way of approaching this problem...

## Recall@k Metric

Recall is a fairly common way to evaluate the ability of a network to return the results most relevant to its task. Because this network is an embedding network, finding the $k$ nearest neighbors of a query and checking that these neighbors are only of the same class does exactly that. This falls in line with recommender systems and can be summarized as follows:

$$ R@k = \frac{1}{N_Q k}\sum_{i=1}^{N_Q}\sum_{j=1}^{N_P}\delta\{\text{rank}(d_{ij})\le k\}\,\delta\{y_i=y_j\} $$

where 

$$
\delta\{\cdot\} = \begin{cases}
 1, & \text{if } \cdot \text{ holds}, \\
 0, & \text{otherwise.}
\end{cases}
$$

$$d_{ij} = ||e_{q_i}-e_{p_j}||_2$$

and

$\text{rank}(d_{ij})$ is the rank of the distance $d_{ij}$ within the vector $d_i$ after sorting the $N_P$ candidate distances in ascending order, with ties broken randomly.

It looks somewhat complex, but in practice all this is saying is that we want to ensure that the $k$ closest embeddings to each of our queries, $q_i$, belong to the same class as the query. This is a fairly decent way of measuring the average performance of a recommender system across all of its recommendations. Naturally, our values are bounded in $[0,1]$: $0$ means the recommender system does not produce any relevant results, while $1$ means it produces only relevant results. [In the context of embeddings, this is more about producing a clear separation between classes that is generalizable.]
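To make the formula concrete, here is a minimal brute-force sketch in NumPy on a toy set of 2-D embeddings. The function name `recall_at_k` and the toy arrays are illustrative assumptions and are not part of the notebook's own implementation; ties are resolved by a stable sort here rather than randomly, and squared distances are used since they give the same ranking as $d_{ij}$.

```python
import numpy as np

def recall_at_k(query_emb, query_labels, cand_emb, cand_labels, k):
    """Brute-force R@k: fraction of the k nearest candidates sharing the query label."""
    # Pairwise squared L2 distances, shape (N_Q, N_P)
    diffs = query_emb[:, None, :] - cand_emb[None, :, :]
    dists = np.sum(diffs ** 2, axis=-1)
    # Indices of the k closest candidates for each query
    nearest = np.argsort(dists, axis=1, kind="stable")[:, :k]
    # delta{y_i = y_j} restricted to the top-k candidates
    hits = cand_labels[nearest] == query_labels[:, None]
    # Averaging over N_Q * k entries matches the 1 / (N_Q k) factor
    return float(hits.mean())

# Toy data: two well-separated clusters, one per class
cand_emb = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]])
cand_labels = np.array([0, 0, 1, 1])
query_emb = np.array([[0.05, 0.1], [5.05, 4.9]])
query_labels = np.array([0, 1])

r_at_1 = recall_at_k(query_emb, query_labels, cand_emb, cand_labels, k=1)
r_at_3 = recall_at_k(query_emb, query_labels, cand_emb, cand_labels, k=3)
print(r_at_1, r_at_3)
```

With only two candidates per class, asking for $k=3$ forces one irrelevant neighbor into every result list, so the score drops below $1$ even though the classes are perfectly separated.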

In [None]:
%%capture
!pip install tensorflow-addons

In [None]:
import tensorflow as tf
import numpy as np
import tensorflow_datasets as tfds
import tensorflow_addons as tfa

In [None]:
# There is no Recall@k for the purpose of determining the best suggestions based
# off of some distance metric, so we can implement that here. Because our inputs
# can possibly be too large to fit into memory, this works in mini-batches,
# so it'll be quite slow.
# This class expects a dataset to be passed in that returns (image, label) pairs.
# It converts this into (embedding, label) pairs which are then used to calculate
# distances and rank the returns for the query.
class RecallAtK:
    def __init__(self, relModel, k, dataset):
        self.embedding = relModel
        self.toEmbed = dataset
        self.embedVecs = None
        self.kVal = k

        # perform embedding transformation
        self._calculate_embeddings()

    def _calculate_embeddings(self, batchSize = 1024):
        self.embedVecs = self.toEmbed.batch(batchSize).map(
                            lambda im, lab:(self.embedding(im), lab)).unbatch().cache()

        print("Pre-computing and caching embeddings... ", end='')
        for _ in self.embedVecs:
            pass
        print("Done.")

    '''
        rankQueries:
            Performs a query ranking over the original input dataset and calculates
            the recall appropriately.

        Input:
            queries : A TFDS containing only the list of queries to input.
            labels : The labels for the corresponding queries. If None, defaults
                    to what is present in the queries dataset.
            batchSize : Number of candidate embeddings to compute the L2
                        distance against simultaneously
            subsetDS : If True, the queries are a subset of the input DS, and
                         the algorithm must be modified to prevent the trivial
                         solution from appearing (since it is guaranteed to be
                         skewed towards 1 if the example is in both subsets)
    '''
    def rankQueries(self, queries, labels = None, batchSize = 4096, subsetDS = False):
        if labels is None:
            labels = iter(queries.map(lambda im, lab: lab))

            # try extracting them just in case
            try:
                for _ in labels:
                    break # only need to make sure it exists
            except:
                raise ValueError("Labels must be passed in with queries in some capacity.")
            finally:
                labels = iter(queries.map(lambda im, lab: lab)) # remake the iter
                queries = queries.map(lambda im, lab: im) # readjust the queries now

        # Adjust k if repeated
        if subsetDS:
            k = self.kVal + 1
        else:
            k = self.kVal

        # Proceed with the calculation now given the inputs
        from tqdm import tqdm
        import heapq
        recallQList = list()

        # prebatch queries to process one-by-one
        queries = queries.batch(1)
        for query in tqdm(queries, total = tf.data.experimental.cardinality(queries).numpy()):
            qEmbed = self.embedding(query)
            qLabel = next(labels).numpy()
            minHeap = list()
            
            for embedCandidate in self.embedVecs.batch(batchSize):
                eCVecs, eCLabs = embedCandidate
                dists = tf.math.top_k(-1 * tf.reduce_sum((eCVecs-qEmbed)**2, axis = 1), k=k)

                for distInd, distVal in zip(list(dists.indices.numpy()), list(dists.values.numpy())):
                    # print(distInd, distVal)
                    if len(minHeap) < k:
                        heapq.heappush(minHeap, (distVal, eCLabs.numpy()[distInd]))
                    else:
                        if distVal > minHeap[0][0]:
                            heapq.heapreplace(minHeap, (distVal, eCLabs.numpy()[distInd]))

            # Calculate recall given the final heap representing the closest k
            # embedding distances (ignoring the first if it's a subset)
            if subsetDS:
                minHeap = [(-1 * dist, lab) for (dist, lab) in minHeap]
                heapq.heapify(minHeap)
                minHeap = minHeap[1:]
            recallQList.append(float(sum(map(lambda tup: tup[1] == qLabel, minHeap))) / self.kVal)

        # Then return the averaged recall per query
        return sum(recallQList) / len(recallQList)
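The inner loop of `rankQueries` keeps only the best $k$ candidates while streaming over batches. As a standalone sketch of that idea, assuming plain Python tuples in place of tensors (the function name `streaming_k_nearest` and the toy batches are hypothetical), distances are stored negated so that the heap root is always the farthest neighbor currently kept:

```python
import heapq

def streaming_k_nearest(distance_batches, k):
    """Return the k smallest (distance, label) pairs seen across all batches."""
    heap = []  # holds (-distance, label); heap[0] is the farthest kept neighbour
    for batch in distance_batches:
        for dist, label in batch:
            if len(heap) < k:
                heapq.heappush(heap, (-dist, label))
            elif -dist > heap[0][0]:
                # Closer than the current farthest kept neighbour: swap it in
                heapq.heapreplace(heap, (-dist, label))
    # Undo the negation and sort from closest to farthest
    return sorted((-neg, label) for neg, label in heap)

batches = [
    [(0.9, "a"), (0.2, "b"), (0.7, "a")],
    [(0.1, "b"), (0.8, "c"), (0.3, "a")],
]
nearest = streaming_k_nearest(batches, k=3)
print(nearest)
```

Memory use stays at $k$ entries per query regardless of how many candidates are scanned, which is the reason for processing the cached embeddings batch by batch.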

## Image Embeddings & Knowledge Distillation

Knowledge distillation was already covered in another notebook, but recall that the purpose of knowledge distillation in classification networks was to have one model mimic the output of another model by copying its softmax output given some temperature parameter, $T$. In practice, this amounted to enforcing some type of loss ($L_2$, $L_1$, Huber) between the softmax outputs of the two networks.

Something similar exists on the embedding side of things. The technique of enforcing the explicit embeddings to be close together is simply knowledge distillation. Unlike in the classification scenario, there is no inherent loss of significance for any of the output neurons of the teacher, and thus no concept of temperature is necessary for optimization. The implementation is the same as in the previous knowledge distillation notebook but without the softmax, so it will not be looked at here.

The first embedding-intrinsic method of distillation is known as relational knowledge distillation (RKD). During training, instead of optimizing on single images, pairs of distinct images are chosen and the difference between the teacher's and the student's distance for each pair is minimized (hence, relational). In other words, we would like to preserve the distances between pairs of images and not the explicit embedding locations themselves. Naturally, this method of optimization no longer depends on the model's embedding dimension, nor will it overfit as severely as the direct method.

The second, more relevant, form of knowledge distillation for embeddings is known as [stochastic knowledge distillation](https://arxiv.org/pdf/2003.03701.pdf). Unlike before, we instead use the SNE objective to optimize on the __Gaussian-proportional__ separation between the teacher and student networks. Compared to RKD, this method is less dependent on the embedding space size, as it only cares about the relative separation amongst points within the same space. In practice, this results in several desirable properties compared to using RKD that will be covered in the next couple of sections...
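As a rough illustration of what an SNE-style objective looks like, the sketch below builds Gaussian neighbor distributions for a teacher and a student batch and compares them with a KL divergence. This is the textbook SNE construction under stated assumptions (a single fixed bandwidth `sigma`, hypothetical function names, random toy embeddings); it is not taken from the linked paper and may differ from its exact formulation.

```python
import numpy as np

def neighbor_probs(emb, sigma=1.0):
    """Row-wise SNE probabilities p_{j|i} from a Gaussian kernel on squared distances."""
    sq = np.sum((emb[:, None, :] - emb[None, :, :]) ** 2, axis=-1)
    logits = -sq / (2.0 * sigma ** 2)
    np.fill_diagonal(logits, -np.inf)            # a point is never its own neighbour
    logits -= logits.max(axis=1, keepdims=True)  # numerical stability
    p = np.exp(logits)
    return p / p.sum(axis=1, keepdims=True)

def sne_distill_loss(t_emb, s_emb, sigma=1.0, eps=1e-12):
    """Sum over rows of KL(P_teacher || P_student), ignoring the diagonal."""
    p = neighbor_probs(t_emb, sigma)
    q = neighbor_probs(s_emb, sigma)
    mask = ~np.eye(len(p), dtype=bool)
    return float(np.sum(p[mask] * (np.log(p[mask] + eps) - np.log(q[mask] + eps))))

rng = np.random.default_rng(0)
t_emb = rng.normal(size=(5, 8))   # teacher embeddings, 8-D
s_emb = rng.normal(size=(5, 3))   # student embeddings, 3-D

loss_same = sne_distill_loss(t_emb, t_emb)
loss_diff = sne_distill_loss(t_emb, s_emb)
print(loss_same, loss_diff)
```

Note that the teacher and student embeddings have different dimensions here: only the within-space neighbor probabilities are compared, never the raw coordinates.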

### Specialist Training : Triplet Learning

Before we begin with the creation of our overall embedding space, it would help to know what type of network we are looking for. In general, we would like to have a latent space where the classes are all distant from each other while the embedding vectors corresponding to each class are kept compact. Now, while training a network for classification can often result in a partitioning of the space, one of the biggest problems with the normal way of training is that it relies on the softmax function, a 1D pseudo-regularizer. This means that while the network will try to push the logits as far from one another as possible, the overall effect is mitigated by the regularization used.

To address this issue, there is a collection of networks that are designed to create these embedding spaces, with or without the class labels. They can perform the task in either a supervised or semi-supervised fashion, but we will only be looking at the former. Let us denote the problem as follows:
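As a minimal sketch of the supervised case, assume the common triplet margin objective $\max(0,\ d(a,p) - d(a,n) + m)$ for an anchor $a$, a same-class positive $p$, a different-class negative $n$ and a margin $m$. The function name and toy vectors below are illustrative, squared Euclidean distance is an assumption, and the exact distance used by the `tfa` triplet losses later in this notebook may differ.

```python
import numpy as np

def triplet_margin_loss(anchor, positive, negative, margin=0.2):
    """Mean of max(0, d(a,p) - d(a,n) + margin) with squared L2 distances."""
    d_pos = np.sum((anchor - positive) ** 2, axis=1)
    d_neg = np.sum((anchor - negative) ** 2, axis=1)
    return float(np.mean(np.maximum(d_pos - d_neg + margin, 0.0)))

# Two triplets of unit-norm 2-D embeddings
anchor = np.array([[1.0, 0.0], [0.0, 1.0]])
positive = np.array([[1.0, 0.0], [0.0, 1.0]])
negative = np.array([[0.0, 1.0], [0.0, 1.0]])

loss = triplet_margin_loss(anchor, positive, negative)
print(loss)
```

The first triplet already satisfies the margin and contributes nothing; in the second the negative sits exactly on the anchor, so it contributes the full margin. Only triplets that violate the margin produce a gradient, which is why the choice of which triplets to mine matters.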



### Relational Knowledge Distillation

We begin with a simple implementation for RKD. Suppose we have our teacher model, $t$, and an untrained student model, $s$. Our model $t$ has been trained on some dataset $\mathcal{D}_t = \{x_1, x_2, ... , x_N\}$ that is accessible to our student model. We seek to minimize the discrepancy between the pairwise separations of the output embeddings in the output spaces of $t$ and $s$ by performing the following optimization on a set of (ideally) $\frac{N(N-1)}{2}$ distinct pairs of images,  $D_{\text{train}}=\{(x_i, x_j)|(x_i\ne x_j) \wedge (i < j)\}$.

$$ W^* = \arg\min_W \mathcal{L}\left(d_t, d_s\right) $$

where

$$\mathcal{L} = L_1(d_t, d_s) = \left|d_t-d_s\right| $$ 

or

$$ \mathcal{L} = L_\delta(d_t, d_s) = \begin{cases}
 \frac{1}{2}(d_t - d_s)^2, & \text{for } |d_t - d_s| \le \delta, \\
 \frac{1}{2}\delta^2 + \delta \left(|d_t - d_s| - \delta\right), & \text{otherwise.}
\end{cases} $$

and 

$$ d_t = \frac{1}{\mu_t}||t(x_i)-t(x_j)||_2 $$
$$ d_s = \frac{1}{\mu_s}||s(x_i)-s(x_j)||_2 $$

$\mu_s$ and $\mu_t$ represent the mean pairwise distance within the batch for the student and the teacher respectively, so that the update magnitudes match.
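The effect of the $\frac{1}{\mu}$ factors can be checked numerically. The following NumPy sketch (function names and toy embeddings are hypothetical and independent of the Keras model defined later) computes the normalized pair distances and the Huber form above, and shows that a student whose embeddings are a rescaled copy of the teacher's incurs no loss:

```python
import numpy as np

def normalized_pair_distances(left, right):
    """Per-pair L2 distance divided by the batch mean distance (the 1/mu factor)."""
    d = np.linalg.norm(left - right, axis=1)
    return d / d.mean()

def huber(d_t, d_s, delta=1.0):
    """Element-wise Huber loss between teacher and student distances."""
    err = np.abs(d_t - d_s)
    quadratic = 0.5 * err ** 2
    linear = 0.5 * delta ** 2 + delta * (err - delta)
    return np.where(err <= delta, quadratic, linear)

rng = np.random.default_rng(0)
t_left, t_right = rng.normal(size=(4, 16)), rng.normal(size=(4, 16))

# A student that reproduces the teacher geometry at 3x the scale
s_left, s_right = 3.0 * t_left, 3.0 * t_right

d_t = normalized_pair_distances(t_left, t_right)
d_s = normalized_pair_distances(s_left, s_right)
scaled_loss = float(huber(d_t, d_s).sum())

# Both branches of the Huber definition on hand-picked values
small = float(huber(np.array([0.0]), np.array([0.5]))[0])  # quadratic branch
large = float(huber(np.array([0.0]), np.array([3.0]))[0])  # linear branch
print(scaled_loss, small, large)
```

Without the normalization the same student would be penalized for its overall scale even though every relative distance already matches the teacher.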

In [None]:
# Load in our bird data for this first application
dsTrain, dsTest = tfds.load("caltech_birds2011", split = ['train', 'test'], as_supervised=True)

# Decides input image size for the network
INPUT_SHAPE = (256, 256, 3)
CROP_SHAPE = (224, 224, 3)

# Normalize image values
dsTrain = dsTrain.map(lambda im, lab:(tf.image.resize(tf.cast(im, dtype=tf.float32)/255., INPUT_SHAPE[:-1]), lab))
dsTest = dsTest.map(lambda im, lab:(tf.image.resize(tf.cast(im, dtype=tf.float32)/255., INPUT_SHAPE[:-1]), lab))

In [None]:
# Creates the image preprocessing layers
rcLayer = tf.keras.layers.RandomCrop(height = CROP_SHAPE[0], width = CROP_SHAPE[1])
flipLayer = tf.keras.layers.RandomFlip(mode = "horizontal")
preprocFunctor = lambda inLayer:flipLayer(rcLayer(inLayer))

# In order to implement this, we need a large teacher and a smaller student,
# both initialized from ImageNet weights, to perform this calculation. For
# convenience, we choose ResNet50 as our teacher and MobileNetV3Large
# as our student
tf.keras.backend.clear_session() # erases old models in memory

def teacherNetworkFunc():
    initResNet =  tf.keras.applications.resnet50.ResNet50(input_shape = CROP_SHAPE, include_top = False, weights = 'imagenet')
    tInLayer = tf.keras.Input(INPUT_SHAPE)
    procIms = preprocFunctor(tInLayer)
    projLayer = tf.keras.layers.GlobalAveragePooling2D()(initResNet(procIms))
    projLayer = tf.keras.layers.Dense(128, activation = 'linear')(projLayer)
    projLayer = tf.keras.layers.Lambda(lambda inLayer : tf.linalg.normalize(inLayer, ord = 2, axis = 1)[0])(projLayer)
    teacherModel = tf.keras.Model(inputs = tInLayer, outputs = projLayer)

    return teacherModel 

teacherModel = teacherNetworkFunc()
teacherModel.summary()

# and the smaller student...
def studentNetworkFunc():
    initMobileNet = tf.keras.applications.MobileNetV3Large(input_shape = CROP_SHAPE, include_top = False, weights = 'imagenet', include_preprocessing = False)
    sInLayer = tf.keras.Input(INPUT_SHAPE)
    procIms = preprocFunctor(sInLayer)
    projLayer = tf.keras.layers.GlobalAveragePooling2D()(initMobileNet(procIms))
    projLayer = tf.keras.layers.Dense(128, activation = 'linear')(projLayer)
    projLayer = tf.keras.layers.Lambda(lambda inLayer : tf.linalg.normalize(inLayer, ord = 2, axis = 1)[0])(projLayer)
    studentModel = tf.keras.Model(inputs = sInLayer, outputs = projLayer)

    return studentModel

studentModel = studentNetworkFunc()
studentModel.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 256, 256, 3)]     0         
                                                                 
 random_crop (RandomCrop)    (None, 224, 224, 3)       0         
                                                                 
 random_flip (RandomFlip)    (None, 224, 224, 3)       0         
                                                                 
 resnet50 (Functional)       (None, 7, 7, 2048)        23587712  
                                                                 
 global_average_pooling2d (G  (None, 2048)             0         
 lobalAveragePooling2D)                                          
                                                                 
 dense (Dense)               (None, 128)               262272    
                                                             

In [None]:
# Sets up our pseudo-callback for adjusting the loss function as described above
class TripletLossSwappingCallback(tf.keras.callbacks.EarlyStopping):
    '''
        TripletLossSwappingCallback
            By subclassing the EarlyStopping callback, we can choose to further
            change a couple values associated with training before continuing
            with training.

        Input:
            lossFunc : The loss to swap to following the EarlyStopping flag
                        raising.
            
            [INHERITED KWARGS]
            monitor : The value to monitor from the model
            min_delta : The minimum value to count as an improvement
            patience : The number of steps to wait before raising the flag
    '''
    def __init__(self, lossFunc, **kwargs):
        super(TripletLossSwappingCallback, self).__init__(**kwargs)
        self.toSwap = lossFunc
        self.swapped = False

    def on_epoch_end(self, epoch, logs = None):
        # Process original callback values
        super(TripletLossSwappingCallback, self).on_epoch_end(epoch, logs)

        # Modify flag used to stop training to instead manipulate model loss
        # but only on the first flag
        if self.model.stop_training and not self.swapped:
            print("\n\nSwapping loss function and restarting training...\n\n")
            self.model.loss = self.toSwap
            self.swapped = True             # prevent further blocking
            self.model.stop_training = False # continue with new round of training
            super(TripletLossSwappingCallback, self).on_train_begin() # reset callback

In [None]:
# Set up model hyperparams
MARGIN_VAL = .2
LR_VAL = 1e-5
BATCH_SIZE = 128

lossFunc = tfa.losses.TripletSemiHardLoss(margin = MARGIN_VAL)
harderLossFunc = tfa.losses.TripletHardLoss(margin = MARGIN_VAL)
opt = tf.keras.optimizers.Adam(learning_rate = LR_VAL)
lossSwapper = TripletLossSwappingCallback(harderLossFunc, monitor = 'loss', min_delta = 1e-5, patience = 5)
teacherModel.compile(loss = lossFunc, optimizer = opt)

--------------------- OPTIONAL GDRIVE LOADING ---------------------------------

In [None]:
# You can optionally load in weights from gdrive if necessary
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# And use our drive path to load weights!
!cp ./drive/MyDrive/ResNet50NetworkWeights/CUB-HARD/* ./

In [None]:
teacherModel.load_weights("./teacherWeightsHARD")

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f2982e4c2d0>

-------------------------------------------------------

In [None]:
# Fine tune the larger teacher model using the input dataset via triplet loss
NUM_EPOCHS_PRETRAIN = 100

teacherModel.fit(dsTrain.shuffle(10000).batch(BATCH_SIZE), epochs = NUM_EPOCHS_PRETRAIN, callbacks = [lossSwapper])

Epoch 1/100
 6/47 [==>...........................] - ETA: 50s - loss: 1.3755e-04



Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100

Swapping loss function and restarting training...


Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100


<keras.callbacks.History at 0x7f3080ce9890>

------------------------- OPTIONAL GDRIVE SAVING -------------------------

In [None]:
# We also go ahead and save the model as it does take time to train it to a sufficiently
# acceptable level.
teacherModel.save_weights("./teacherWeightsHARD")

-------------------------------------------------------------

In [None]:
# Perform Recall@1 test-evaluation now over entire testing birds dataset...
rAK1 = RecallAtK(teacherModel, 1, dsTest)
res = rAK1.rankQueries(dsTest)
print("Average R@1 for teacher:{:.3f}".format(res))

# And on the student to measure initial performance
rAK1Stud = RecallAtK(studentModel, 1, dsTest)
resStud = rAK1Stud.rankQueries(dsTest)
print("Average R@1 for student:{:.3f}".format(resStud))

Average R@1 for teacher:0.689
Average R@1 for student:0.278


So, the recall for our specialist is fairly decent. It returns a relevant result nearly 69% of the time, while the ImageNet-initialized MobileNet does so only about 28% of the time. We do keep in mind that these networks have a different set of parameters, but in the end the goal is still to try and achieve a similar-performing embedding space. The specialist could be improved with further training, but this is sufficient to generate results for the purposes of implementing RKD.

[Note that while overfitting with triplet mining is generally not too large of an issue due to how the margin is user specified, should there be any values that are "much harder" than normal to classify, that could lead to an embedding space that may not properly generalize to a test set.]

#### TPU Model Declaration

(_At this point you may also run the appendix section to save the values for TPU usage. After doing so, restart the notebook so that the following cells can properly initialize the TPU/GPU strategy for training. This is necessary to get past the limited batch size that the GPU imposes on us, given constraints such as the ~20 GB memory limit, which also has to hold the model weights._)

In [None]:
# Imports all of our necessary tools for this portion
import tensorflow as tf
import tensorflow_datasets as tfds

In [None]:
# initialize TPU cluster for use
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
    print("Device:", tpu.master())
    strategy = tf.distribute.TPUStrategy(tpu)

    # functor in case we need to reset TPU training progress
    def resetTrainingMemory():
        hw_accelerator_handle = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.tpu.experimental.initialize_tpu_system(hw_accelerator_handle)

except ValueError:
    print("Not connected to a TPU runtime. Using default (CPU/GPU) strategy")
    strategy = tf.distribute.get_strategy()

    def resetTrainingMemory():
        tf.keras.backend.clear_session()

print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

Device: grpc://10.74.122.50:8470
Number of devices: 8


In [None]:
# Redeclare student down here for use without preproc layer.
# NOTE: the ResNet50 backbone is the one active below; the MobileNetV3Large
# variant is left commented out.
def studentNetworkFunc(input_shape):
    sInLayer = tf.keras.Input(input_shape)
    # backbone = tf.keras.applications.MobileNetV3Large(input_shape = input_shape, include_top = False, weights = 'imagenet', include_preprocessing = False)
    backbone = tf.keras.applications.resnet50.ResNet50(input_shape = input_shape, include_top = False, weights = 'imagenet')
    projLayer = tf.keras.layers.GlobalAveragePooling2D()(backbone(sInLayer))
    projLayer = tf.keras.layers.Dense(128, activation = 'linear')(projLayer)
    projLayer = tf.keras.layers.Lambda(lambda inLayer : tf.linalg.normalize(inLayer, ord = 2, axis = 1)[0])(projLayer)
    studentModel = tf.keras.Model(inputs = sInLayer, outputs = projLayer)

    return studentModel

In [None]:
# Now apply RKD to the student model in order to evaluate its performance.
class RKDModel(tf.keras.Model):
    '''
        RKDModel is a class used to perform a relational knowledge distillation
        on a dataset. Given a teacher and a student, enforce the student to 
        minimize the L1 or Huber loss between pairwise distances.

        Inputs:
            studentModel : The initializer for the model to train (generally smaller)
            delta : On the off chance that training must be performed with the Huber
                    loss, this allows you to specify the input for that.
    '''
    def __init__(self, studentModel, delta = tf.constant(1, dtype=tf.int64)):
        super(RKDModel, self).__init__(inputs = studentModel.inputs, outputs = studentModel.outputs)
        self.delta = delta

        # Set up metrics
        self.lossTracker = tf.keras.metrics.Mean(name='loss')
        self.valLossTracker = tf.keras.metrics.Mean(name='val_loss')

    '''
        compile
            Overloads the typical model compilation function. Takes in a loss
            function (or "Huber" / "L1") in order to compute the distance metrics.
            This step also sets up the model for usage.

        Inputs:
            loss : A tf.keras.Loss object or the above two specific relevant losses.
            optimizer : Any tf.keras.Optimizer object to train the student model.
                        If left undefined, it will default to 'adam'
            delta : A hyperparameter for the Huber loss. Only used when 'Huber'
                    is passed in as a loss parameter.
    '''
    def compile(self, loss, optimizer, **kwargs):
        if isinstance(loss, tf.keras.losses.Loss):
            self.loss = loss
        elif loss == "Huber":
            self.loss = tf.keras.losses.Huber(delta = self.delta, reduction = tf.keras.losses.Reduction.SUM)
        elif loss == "L1":
            self.loss = tf.keras.losses.MeanAbsoluteError(reduction = tf.keras.losses.Reduction.SUM)
        else: # We are unsure of the loss in this case, so throw an error
            raise ValueError("Loss is ill-defined.")
            
        return super(RKDModel, self).compile(loss = self.loss, optimizer = optimizer, **kwargs)

    '''
        This defines model metrics for automatic resetting.
    '''
    @property
    def metrics(self):
        return [self.lossTracker, self.valLossTracker]

    '''
        train_step
            Proceeds with a single pass of the algorithm and performs the
            backpropagation and gradient update for the student model.

            Note that the input is expected to be in the format of
                       ((image, teacher_embedding), (image, teacher_embedding))
            tuples. 
    '''
    def train_step(self, data):
        (imageL, tEmbedL), (imageR, tEmbedR) = data

        # Calculate inter-class diff vectors dt, ds
        with tf.GradientTape() as tape:
            sEmbedL = self(imageL, training = True)
            sEmbedR = self(imageR, training = True)
            sDiffs = tf.linalg.norm(sEmbedL - sEmbedR, ord = 2, axis = 1)
            sDiffsNormed = sDiffs / tf.reduce_mean(sDiffs)

            tDiffs = tf.linalg.norm(tEmbedL - tEmbedR, ord = 2, axis = 1)
            tDiffsNormed = tDiffs / tf.reduce_mean(tDiffs)
            # Compare the mean-normalized distances (d_t, d_s) as defined above
            totalLoss = self.loss(tDiffsNormed, sDiffsNormed)

        # Proceed with backpropagation
        grads = tape.gradient(totalLoss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        # update metrics
        self.lossTracker.update_state(totalLoss)

        return {"loss":self.lossTracker.result()}

    '''
        test_step
            Performs a forward pass on a test dataset. Does not perform any
            backpropagation.

            Note that the input is expected to be in the format of
                        ((image, teacher_embedding), (image, teacher_embedding))
            tuples. This only gets called for a validation or testing step.
    '''
    def test_step(self, data):
        (imageL, tEmbedL), (imageR, tEmbedR) = data

        # Perform forward pass with no tape
        sEmbedL = self(imageL)
        sEmbedR = self(imageR)
        sDiffs = tf.linalg.norm(sEmbedL - sEmbedR, ord = 2, axis = 1)
        sDiffsNormed = sDiffs / tf.reduce_mean(sDiffs)

        tDiffs = tf.linalg.norm(tEmbedL - tEmbedR, ord = 2, axis = 1)
        tDiffsNormed = tDiffs / tf.reduce_mean(tDiffs)
        # Compare the mean-normalized distances (d_t, d_s) as defined above
        totalLoss = self.loss(tDiffsNormed, sDiffsNormed)

        # update val metrics
        self.valLossTracker.update_state(totalLoss)

        return {"val_loss" : self.valLossTracker.result()}

    '''
        prepare_dataset
            Because the dataset is required to be in a particular format due to
            the smaller size of the GPU being used, we go ahead and prepare that
            beforehand. This takes in a dataset of just the input images and
            produces two datasets to use as the (x,y) value pairs. To be exact,
            the output would be two datasets with the following output:
                            (images, teacher_embed)
        Inputs
            dataIn : The dataset to process into the desired format
            teacherModel : The model to prepare the examples from
            batchSize : The number of elements to batch the results by
            zipped : Determines whether the returned dataset is fully prepared
                     with the zipping process or if the single processed dataset
                     should be returned. Useful for saving the dataset to a GCP
                     cluster using less space.

        Outputs
            Returns two datasets with the above specifications for use in the
            neural network. This does cache values as well so it can take some
            time for the cache to take effect.
    '''
    @staticmethod
    def prepare_dataset(dataIn, teacherModel, batchSize = 1024, zipped = True):
        # form base cached dataset
        baseDS = dataIn.batch(batchSize).map(lambda im, lab:(im, teacherModel(im))).unbatch().cache()

        # Executing both models will cause a significant performance degradation
        # so we process it early...
        print("Pre-caching teacher logits...")
        for _ in baseDS:
            pass
        print("Finished")

        # With elements cached, we can return our new datasets now
        if zipped:
            return RKDModel.repeat_and_zip_dataset(baseDS, batchSize)
        else:
            return baseDS


    '''
        repeat_and_zip_dataset
            Takes a dataset and turns it into an infinitely long dataset that
            returns values picked from (D^2) [aka all combinations]

        Inputs
            baseDS : The dataset to process into the desired format
            batchSize : The number of elements to batch the results by

        Outputs
            Returns the dataset holding an infinitely long selection of
            combinations found in the cartesian product of the dataset with itself.
    '''
    @staticmethod
    def repeat_and_zip_dataset(baseDS, batchSize):
        ds1 = baseDS.repeat().shuffle(10000)
        ds2 = baseDS.repeat().shuffle(10000)
        return tf.data.Dataset.zip((ds1, ds2)).filter(
                                                lambda l,r : tf.reduce_any(l[0] != r[0])
                                             ).batch(batchSize)

#### Training / Execution
With all that function prep out of the way, we can begin to declare hyperparameters and deal with the actual training. Go to the (Addendum - Loading Dataset for TPU Usage) section and load the dataset fully and come back here for the training operations.

In [None]:
# Cardinality actually doesn't work here so we need to iterate through the dataset
numTrainSamples = 0
numTestSamples = 0
for _ in dsTrain:
    numTrainSamples += 1
for _ in dsTest:
    numTestSamples += 1
print("Loaded {} training samples and {} test samples.".format(numTrainSamples, numTestSamples))

Loaded 5994 training samples and 5794 test samples.


In [None]:
# Hyperparams for the next few sections. Feel free to modify them
BATCH_SIZE_PER_REPLICA = 64
NUM_EPOCHS = 50
STEPS_PER_EXEC = 20 # number of steps to perform before updating metric
LEARNING_RATE = 1e-5
LOSS_TYPE = "Huber"  # this can either be "Huber" or "L1" (case-sensitive)

# Don't touch these hyperparams unless necessary, however
GLOBAL_BATCH_SIZE = int(BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync)
STEPS_PER_EPOCH = int(numTrainSamples + GLOBAL_BATCH_SIZE - 1) // GLOBAL_BATCH_SIZE
VAL_STEPS = int(numTestSamples // GLOBAL_BATCH_SIZE)
STEPS_PER_EXEC = min(STEPS_PER_EPOCH, STEPS_PER_EXEC) # steps / execution can't be higher than max step count
IMAGE_SIZE = (224, 224, 3)

# Prepare dataset using class function
preppedDS = RKDModel.repeat_and_zip_dataset(dsTrain, GLOBAL_BATCH_SIZE)
preppedValDS = RKDModel.repeat_and_zip_dataset(dsTest, GLOBAL_BATCH_SIZE)

# And distribute them
distTrainDS = strategy.experimental_distribute_dataset(preppedDS)
distValDS = strategy.experimental_distribute_dataset(preppedValDS)

# Prepare the student for learning
resetTrainingMemory()
with strategy.scope():
    studentModel = studentNetworkFunc(IMAGE_SIZE)
    rkdStudent = RKDModel(studentModel)
    rkdOpt = tf.keras.optimizers.Adam(learning_rate = LEARNING_RATE)
    rkdStudent.compile(loss = LOSS_TYPE, optimizer = rkdOpt, steps_per_execution = STEPS_PER_EXEC)
rkdStudent.summary()



Model: "rkd_model_3"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_7 (InputLayer)        [(None, 224, 224, 3)]     0         
                                                                 
 resnet50 (Functional)       (None, 7, 7, 2048)        23587712  
                                                                 
 global_average_pooling2d_3   (None, 2048)             0         
 (GlobalAveragePooling2D)                                        
                                                                 
 dense_3 (Dense)             (None, 128)               262272    
                                                                 
 lambda_3 (Lambda)           (None, 128)               0         
                                                                 
Total params: 23,849,988
Trainable params: 23,796,864
Non-trainable params: 53,124
______________________________________

In [None]:
# And finally train the network!
rkdHis = rkdStudent.fit(distTrainDS, epochs = NUM_EPOCHS, steps_per_epoch = STEPS_PER_EPOCH)

Epoch 1/50
...
Epoch 50/50


In [None]:
# Set up our weight saving via checkpointing
checkpoint = tf.train.Checkpoint(model=rkdStudent)
localOpts = tf.train.CheckpointOptions(experimental_io_device="/job:localhost")
checkpoint.write("studentWeights_rkd50", options=localOpts)

'studentWeights_rkd50'

#### Evaluation

We are now done with the TPU. You may reset the notebook and re-initialize it with a GPU instance so that the R@1 value can be computed. Keep in mind that it may take some time before the weights saved above finish downloading, so make sure you don't kill the instance before then.
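As a rough illustration of what the evaluation cells below measure, here is a minimal pure-Python sketch of R@k for the case where the queries come from the same set as the gallery. The function name `recall_at_k` and the toy vectors are hypothetical and not part of the notebook. Like the R@k definition given earlier, it averages the fraction of same-class neighbours among the top $k$, and it skips each query's own embedding by index.

```python
import math

def recall_at_k(embeddings, labels, k=1):
    """Average fraction of same-class items among each query's k nearest neighbours."""
    scores = []
    for i, (q, q_label) in enumerate(zip(embeddings, labels)):
        # Distance to every other item; the query itself is excluded by index.
        dists = [(math.dist(q, e), labels[j])
                 for j, e in enumerate(embeddings) if j != i]
        dists.sort(key=lambda pair: pair[0])
        hits = sum(1 for _, lab in dists[:k] if lab == q_label)
        scores.append(hits / k)
    return sum(scores) / len(scores)

# Hypothetical 2-D embeddings: the last "b" item sits closer to the "a" cluster.
toy_embeddings = [(0.0, 0.0), (0.1, 0.0), (1.0, 1.0), (1.1, 1.0), (0.2, 0.1)]
toy_labels = ["a", "a", "b", "b", "b"]
print(recall_at_k(toy_embeddings, toy_labels, k=1))
```

Four of the five toy queries find a same-class nearest neighbour, so the sketch reports an R@1 of 0.8. The `RecallAtK` class below follows the same idea but streams the gallery in batches and removes the query from the result by its zero distance.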

In [None]:
# reimports all of our typical libraries
import tensorflow as tf
import tensorflow_datasets as tfds
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# reload our original dataset
dsTrain, dsTest = tfds.load("caltech_birds2011", split = ['train', 'test'], as_supervised=True)

# Decides input image size for the network
INPUT_SHAPE = (256, 256, 3)
CROP_SHAPE = (224, 224, 3)

# Normalize image values
dsTrain = dsTrain.map(lambda im, lab:(tf.image.resize(tf.cast(im, dtype=tf.float32)/255., INPUT_SHAPE[:-1]), lab))
dsTest = dsTest.map(lambda im, lab:(tf.image.resize(tf.cast(im, dtype=tf.float32)/255., INPUT_SHAPE[:-1]), lab))

Downloading and preparing dataset 1.11 GiB (download: 1.11 GiB, generated: 1.11 GiB, total: 2.22 GiB) to ~/tensorflow_datasets/caltech_birds2011/0.1.1...



Dataset caltech_birds2011 downloaded and prepared to ~/tensorflow_datasets/caltech_birds2011/0.1.1. Subsequent calls will reuse this data.


In [None]:
#@title Re-inits student model and RKD metric class
class RecallAtK:
    def __init__(self, relModel, k, dataset):
        self.embedding = relModel
        self.toEmbed = dataset
        self.embedVecs = None
        self.kVal = k

        # perform embedding transformation
        self._calculate_embeddings()

    def _calculate_embeddings(self, batchSize = 1024):
        self.embedVecs = self.toEmbed.batch(batchSize).map(
                            lambda im, lab:(self.embedding(im), lab)).unbatch().cache()

        print("Pre-computing and caching embeddings... ", end='')
        for _ in self.embedVecs:
            pass
        print("Done.")

    '''
        rankQueries:
            Performs a query ranking over the original input dataset and calculates
            the recall appropriately.

        Input:
            queries : A TFDS containing only the list of queries to input.
            labels : The labels for the corresponding queries. If None, defaults
                    to what is present in the queries dataset.
            batchSize : Number of elements to calculate the L2 distance for
                        simultaneously
            subsetDS : If True, the queries are a subset of the input DS, and
                       the algorithm must be modified to prevent the trivial
                       solution from appearing (since it is guaranteed to be
                       skewed towards 1 if the example is in both subsets)
    '''
    def rankQueries(self, queries, labels = None, batchSize = 4096, subsetDS = False):
        if labels is None:
            labels = iter(queries.map(lambda im, lab: lab))

            # try extracting them just in case
            try:
                for _ in labels:
                    break # only need to make sure it exists
            except Exception:
                raise ValueError("Labels must be passed in with queries in some capacity.")
            finally:
                labels = iter(queries.map(lambda im, lab: lab)) # remake the iter
                queries = queries.map(lambda im, lab: im) # readjust the queries now

        # Adjust k if repeated
        if subsetDS:
            k = self.kVal + 1
        else:
            k = self.kVal

        # Proceed with the calculation now given the inputs
        from tqdm import tqdm
        import heapq
        recallQList = list()

        # prebatch queries to process one-by-one
        queries = queries.batch(1)
        for query in tqdm(queries, total = tf.data.experimental.cardinality(queries).numpy()):
            qEmbed = self.embedding(query)
            qLabel = next(labels).numpy()
            minHeap = list()
            
            for embedCandidate in self.embedVecs.batch(batchSize):
                eCVecs, eCLabs = embedCandidate
                dists = tf.math.top_k(-1 * tf.reduce_sum((eCVecs-qEmbed)**2, axis = 1), k=k)

                for distInd, distVal in zip(list(dists.indices.numpy()), list(dists.values.numpy())):
                    # print(distInd, distVal)
                    if len(minHeap) < k:
                        heapq.heappush(minHeap, (distVal, eCLabs.numpy()[distInd]))
                    else:
                        if distVal > minHeap[0][0]:
                            heapq.heapreplace(minHeap, (distVal, eCLabs.numpy()[distInd]))

            # Calculate recall given the final heap representing the closest k
            # embedding distances (ignoring the first if it's a subset)
            if subsetDS:
                minHeap = [(-1 * dist, lab) for (dist, lab) in minHeap]
                heapq.heapify(minHeap)
                minHeap = minHeap[1:]
            recallQList.append(float(sum(map(lambda tup: tup[1] == qLabel, minHeap))) / self.kVal)

        # Then return the averaged recall per query
        return sum(recallQList) / len(recallQList)

# Creates the image preprocessing layers
rcLayer = tf.keras.layers.RandomCrop(height = CROP_SHAPE[0], width = CROP_SHAPE[1])
flipLayer = tf.keras.layers.RandomFlip(mode = "horizontal")
preprocFunctor = lambda inLayer:flipLayer(rcLayer(inLayer))

def studentNetworkFunc(input_shape):
    initResNet =  tf.keras.applications.resnet50.ResNet50(input_shape = CROP_SHAPE, include_top = False, weights = 'imagenet')
    sInLayer = tf.keras.Input(INPUT_SHAPE)
    procIms = preprocFunctor(sInLayer)
    projLayer = tf.keras.layers.GlobalAveragePooling2D()(initResNet(procIms))
    projLayer = tf.keras.layers.Dense(128, activation = 'linear')(projLayer)
    projLayer = tf.keras.layers.Lambda(lambda inLayer : tf.linalg.normalize(inLayer, ord = 2, axis = 1)[0])(projLayer)
    studentModel = tf.keras.Model(inputs = sInLayer, outputs = projLayer)

    return studentModel

In [None]:
# import our gdrive checkpoint into the workspace
!cp ./drive/MyDrive/ResNet50NetworkWeights/XFER-Student-50-TPU/* ./

In [None]:
# Also time to re-initialize the student network
rkdStudent = studentNetworkFunc(input_shape = INPUT_SHAPE)

# And then load in our checkpoint via gdrive
checkpoint = tf.train.Checkpoint(rkdStudent)
checkpoint.restore('./studentWeights_rkd50')

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fbeae4aaa10>

In [None]:
# And we can evaluate the recall performance after training the student...
rAK1Stud = RecallAtK(rkdStudent, 1, dsTest)
resStud = rAK1Stud.rankQueries(dsTest, subsetDS = True)
print("\nAverage R@1 for student:{:.3f}".format(resStud))

Pre-computing and caching embeddings... Done.


100%|██████████| 5794/5794 [19:31<00:00,  4.95it/s]


Average R@1 for student:0.024





In [None]:
# We are done for now
from google.colab import runtime
runtime.unassign()

### Stochastic Knowledge Distillation

Stochastic knowledge distillation functions a bit differently. Much like in the t-SNE notebook, we are instead interested in treating the points in our lower-dimensional spaces as distributions. Assuming the teacher's distribution of points in its latent space sufficiently describes the input data, we would like the student to learn a similar distribution.

One thing to keep in mind is that this process requires a significant amount of memory, since it must calculate the $O(N^2)$ distance arrays for BOTH the high- and low-dimensional data. In general, that not only means that our batch sizes will be small, but it also greatly limits the amount of information that the network receives through each batch. Because of this, a TPU solution will be considered for this application. This solution will be heavily based on the previous parametric t-SNE TPU solution in the other notebook.
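To make the memory cost concrete, the following is a minimal pure-Python sketch of a t-SNE-style objective for a single batch. Everything in it is an assumption made for illustration: the helper names `pairwise_probs` and `kl_divergence`, the fixed Gaussian bandwidth `sigma`, and the use of the same kernel for both teacher and student are not taken from the notebook. It only shows that both $N \times N$ similarity tables must exist at once before the divergence can be evaluated.

```python
import math

def pairwise_probs(points, sigma=1.0):
    """Turn a batch of points into a joint distribution over the N^2 - N ordered pairs."""
    n = len(points)
    weights = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if i != j:
                sq_dist = sum((a - b) ** 2 for a, b in zip(points[i], points[j]))
                weights[i][j] = math.exp(-sq_dist / (2.0 * sigma ** 2))
    total = sum(map(sum, weights))
    return [[w / total for w in row] for row in weights]

def kl_divergence(p, q, eps=1e-12):
    """KL(P || Q) summed over all off-diagonal pairs."""
    return sum(pij * math.log((pij + eps) / (qij + eps))
               for p_row, q_row in zip(p, q)
               for pij, qij in zip(p_row, q_row) if pij > 0.0)

# Hypothetical batch: 3-D "teacher" embeddings and 2-D "student" embeddings.
teacher_batch = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 2.0, 0.0)]
student_batch = [(0.0, 0.0), (0.9, 0.1), (0.1, 1.8)]

p = pairwise_probs(teacher_batch)  # N x N table for the teacher
q = pairwise_probs(student_batch)  # N x N table for the student
print("KL(teacher || student) =", kl_divergence(p, q))
```

Doubling the batch size quadruples both tables, which is why the batch size ends up being the limiting factor.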

# FAQ



*   Unknown model compilation error when executing `fit()` using the model but not when using a single batch on `fit_on_batch()`.
    * This took far more effort than was necessary to fix... One of the biggest things to keep in mind is function return values. Not every function reports a problem by throwing an exception, since throwing exceptions from simple functions is often seen as an easy form of bloat. In this case, the documentation for `tf.data.experimental.cardinality()` states that the function can return `-1` if the dataset is infinite or `-2` if the number of elements cannot be determined prior to run time. Because the returned value was negative, a negative step count was passed to `fit()`, and it failed in a way that reported the issue as coming from the model itself rather than from the dataset processing.
    * In our case, reading the TFRecords does not actually tell us the number of elements due to the way the values are parsed at run time (_even when cached!_), so the elements must be counted individually. Note that the TPU node is unfriendly towards non-infinite datasets, so it will throw a non-critical exception once it parses the entire dataset.

* RKD objective is decreasing but the model collapses (R@1 values are deteriorating / not changing)
    * One negative aspect of knowledge distillation is that without some handholding (i.e. an inner-layer feature-wise loss), the model can update itself improperly in a way that is rather hard to notice in a dataset with $N^2-N$ pairwise distances (many of which are about $\sqrt{2}$ because of the spherical constraint). If this happens to be the case, it is far more serviceable to use a feature loss or simply append the original problem to the total loss function to give some overall direction.
    * [This](https://arxiv.org/pdf/1706.07567.pdf) paper seems to give a good overview of how metric learning should be performed. Upon reading it, it becomes quite clear that sampling is of utmost importance for proper network performance here. This won't be covered here, however, and, at best, we'll simply stratify the classes sampled.
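The first entry above comes down to never trusting a reported element count without checking its sign. A minimal sketch of that guard in plain Python follows. The helper names `resolve_num_samples` and `steps_per_epoch` are hypothetical, and the sentinel values mirror the `-1` (infinite) and `-2` (unknown) results described above. With TensorFlow, the `reported` argument would be something like `tf.data.experimental.cardinality(ds).numpy()` and `iterable` would be the dataset itself.

```python
INFINITE = -1  # sentinel for an endlessly repeating dataset
UNKNOWN = -2   # sentinel for a size that is only known after a full pass

def resolve_num_samples(reported, iterable):
    """Return a usable sample count, counting by iteration when the size is unknown."""
    if reported >= 0:
        return reported
    if reported == INFINITE:
        raise ValueError("Dataset is infinite; pass steps_per_epoch explicitly.")
    # UNKNOWN (or any other negative value): fall back to one full pass.
    return sum(1 for _ in iterable)

def steps_per_epoch(num_samples, global_batch_size):
    """Ceiling division so the final partial batch is still visited."""
    if num_samples <= 0:
        raise ValueError("num_samples must be positive, got {}".format(num_samples))
    return (num_samples + global_batch_size - 1) // global_batch_size

# Stand-in for a dataset whose size is unknown until it has been iterated.
num_train = resolve_num_samples(UNKNOWN, range(5994))
print(num_train, steps_per_epoch(num_train, 512))
```

Raising early on a non-positive count keeps the failure next to the dataset code instead of letting it surface later inside `fit()`.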

$ $
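The $\sqrt{2}$ figure in the second FAQ entry can be checked numerically. The sketch below is an illustration only: it draws random directions on the unit sphere rather than real embeddings, and the helper name `random_unit_vector` is hypothetical. In 128 dimensions, random unit vectors are nearly orthogonal, so their pairwise distances cluster around $\sqrt{2}$.

```python
import math
import random

def random_unit_vector(dim, rng):
    """Draw a direction uniformly on the unit sphere by normalising a Gaussian sample."""
    vec = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / norm for v in vec]

rng = random.Random(0)
dim = 128  # matches the 128-dimensional embedding produced by the student
points = [random_unit_vector(dim, rng) for _ in range(20)]

# All unordered pairwise Euclidean distances.
dists = [math.dist(points[i], points[j])
         for i in range(len(points)) for j in range(i + 1, len(points))]
mean_dist = sum(dists) / len(dists)
print("mean pairwise distance:", mean_dist, "sqrt(2):", math.sqrt(2))
```

A collapsed student whose outputs look like this can still match many of the teacher's pairwise distances, which is one reason a falling loss does not guarantee a useful embedding.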

# Addendum

These are things that had to be performed but are shifted out of view as they are less relevant to the central algorithm...

## Creating dataset for TPU usage

One big part of TPU training is that the dataset has to be accessible from either a public GCP bucket or a bucket configured to allow the TPU server access. The simplest way to work with a GCP bucket would be to make it visible to all, but that is generally considered unsafe for sensitive data. Instead, by initiating our training and using the server identity reported in the resulting error to grant access in the GCP access management UI, we can allow only the TPU cluster to use the files.

Because the images should be consistent between batches sent to both networks, we can no longer allow randomness to affect the respective images each network receives. Under a TPU environment it is possible to perform all the calculations on the spot, but for simplicity the dataset was pre-processed so that the randomness is introduced at dataset generation time.
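One way to picture moving the randomness into dataset generation is to draw every crop offset once from a seeded generator and store the result, so that both networks later read identical pixels. The sketch below does this in plain Python on nested lists. The helper names `crop_offsets` and `crop`, and the per-image seeding scheme, are assumptions made for illustration and are not what the notebook itself does.

```python
import random

def crop_offsets(index, in_hw=(256, 256), crop_hw=(224, 224), base_seed=1234):
    """Deterministic top-left corner for image `index`; the same index always gives the same crop."""
    rng = random.Random(base_seed * 1_000_003 + index)
    top = rng.randint(0, in_hw[0] - crop_hw[0])
    left = rng.randint(0, in_hw[1] - crop_hw[1])
    return top, left

def crop(image, top, left, crop_hw):
    """Crop a row-major nested-list image."""
    return [row[left:left + crop_hw[1]] for row in image[top:top + crop_hw[0]]]

# A tiny 6x6 "image" cropped to 4x4 so the result is easy to inspect.
image = [[r * 6 + c for c in range(6)] for r in range(6)]
top, left = crop_offsets(0, in_hw=(6, 6), crop_hw=(4, 4))
patch = crop(image, top, left, (4, 4))
print(top, left, patch[0])
```

Deriving the seed from the image index keeps crops varied across the dataset while remaining reproducible between runs and between the two networks.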

Below was the process used to save the preprocessed dataset to a GCP bucket for TPU training.

In [None]:
from google.colab import auth
auth.authenticate_user()

In [None]:
# Import our old protobuf funcs here now
def parseTensorToBytes(inTensor):
    bytesArr = tf.io.serialize_tensor(inTensor)
    return tf.train.BytesList(value=[bytesArr.numpy()])

def parseImageToBytes(inImage):
    bytesArr = tf.io.encode_jpeg(inImage)
    return tf.train.BytesList(value=[bytesArr.numpy()])

# creates the protobuf
def createExampleBuff(image, logits):
    featureDict = {"image": tf.train.Feature(bytes_list = parseImageToBytes(image)),
                "logits": tf.train.Feature(bytes_list = parseTensorToBytes(logits))}

    example = tf.train.Example(features = tf.train.Features(feature=featureDict))
    return example.SerializeToString()

In [None]:
# Adjust the input images to our desired size.
# Note: the seed (1, 2) is constant, so equally sized images are all cropped at the same offset.
dsTrain = dsTrain.map(lambda im, lab:(tf.image.stateless_random_crop(im, CROP_SHAPE, (1,2)), lab))
dsTest = dsTest.map(lambda im, lab:(tf.image.stateless_random_crop(im, CROP_SHAPE, (1,2)), lab))

# We modify our teacher to no longer perform random computations
teacherModelNoPreproc = tf.keras.Sequential([tf.keras.Input(CROP_SHAPE),
                                             *teacherModel.layers[3:]])

# Process the datasets using the teacher
preppedDS = RKDModel.prepare_dataset(dsTrain, teacherModelNoPreproc, batchSize = BATCH_SIZE, zipped = False)
preppedDSTest = RKDModel.prepare_dataset(dsTest, teacherModelNoPreproc, batchSize = BATCH_SIZE, zipped = False)

# Then return the images back into their integer format to save GCP space
intImgDS = preppedDS.map(lambda img, logits : (tf.image.convert_image_dtype(img, tf.uint8, saturate = True), logits))
intImgDSTest = preppedDSTest.map(lambda img, logits : (tf.image.convert_image_dtype(img, tf.uint8, saturate = True), logits))

Pre-caching teacher logits...


In [None]:
# some pre-fixed params
BUCKET_PATH = 'gs://cub-preprocessed-dataset/'
INITIAL_DIR = "cub_cropped_normalized"
FILENAME_PREFIX = "cub_256x256"
DATASET_TYPE = "train"
NUM_TPUS = 8
NUM_SHARDS = 10 * NUM_TPUS # you can set this if you feel the need
NUM_SAMPLES = tf.data.experimental.cardinality(dsTrain)

# Create our initial path
dsOutDirectory = BUCKET_PATH + INITIAL_DIR + "/" + DATASET_TYPE
tf.io.gfile.makedirs(dsOutDirectory)

# Begin sharding dataset into .tfrecord files...
# The format will be prefix_type_{%d}-of-{%d}.tfrecord
ims_per_shard = (NUM_SAMPLES + (NUM_SHARDS-1)) // NUM_SHARDS
batchedTrain = intImgDS.batch(ims_per_shard)

for shardNum, curShard in enumerate(batchedTrain):
    curFilename = dsOutDirectory + "/" + FILENAME_PREFIX + "_" + DATASET_TYPE + \
                    "_" + "{:02d}-of-{}.tfrecord".format(shardNum, NUM_SHARDS-1)
    with tf.io.TFRecordWriter(curFilename) as writer:
        for index, entry in enumerate(tf.data.Dataset.from_tensors(curShard).unbatch()):
            image, logits = entry
            proto = createExampleBuff(image, logits)
            writer.write(proto)

In [None]:
NUM_SAMPLES = tf.data.experimental.cardinality(dsTest)
DATASET_TYPE = "test"
dsOutDirectory = BUCKET_PATH + INITIAL_DIR + "/" + DATASET_TYPE

# Begin sharding dataset into .tfrecord files...
# The format will be prefix_type_{%d}-of-{%d}.tfrecord
ims_per_shard = (NUM_SAMPLES + (NUM_SHARDS-1)) // NUM_SHARDS
batchedTest = intImgDSTest.batch(ims_per_shard)

for shardNum, curShard in enumerate(batchedTest):
    curFilename = dsOutDirectory + "/" + FILENAME_PREFIX + "_" + DATASET_TYPE + \
                    "_" + "{:02d}-of-{}.tfrecord".format(shardNum, NUM_SHARDS-1)
    with tf.io.TFRecordWriter(curFilename) as writer:
        for index, entry in enumerate(tf.data.Dataset.from_tensors(curShard).unbatch()):
            image, logits = entry
            proto = createExampleBuff(image, logits)
            writer.write(proto)
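The sharding arithmetic used in the two cells above can be checked in isolation. The following is a small sketch with a hypothetical helper, `shard_plan`, that reproduces the ceiling division and the zero-indexed `NN-of-MM` file naming without touching TensorFlow or the bucket. The default prefix and the sample count are taken from the values shown earlier in this notebook.

```python
def shard_plan(num_samples, num_shards, prefix="cub_256x256", split="train"):
    """Return (filename, start, stop) triples covering every sample exactly once."""
    per_shard = (num_samples + num_shards - 1) // num_shards  # ceiling division
    plan = []
    for shard_num, start in enumerate(range(0, num_samples, per_shard)):
        stop = min(start + per_shard, num_samples)
        name = "{}_{}_{:02d}-of-{}.tfrecord".format(prefix, split, shard_num, num_shards - 1)
        plan.append((name, start, stop))
    return plan

plan = shard_plan(5994, 80)
print(len(plan), plan[0], plan[-1])
```

Note that the number of files actually written can be smaller than `num_shards` when the ceiling division overshoots (for example, 100 samples over 80 shards gives 2 per shard and only 50 files), so code that reads the shards back should glob for them rather than assume a fixed count.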

## Loading Dataset for TPU Usage

If loading from GCP, the exact opposite of the procedure above can be followed. We first glob all of the TFRecord files together into a `TFRecordDataset` and then parse each protobuf while converting to the necessary data types. Because our dataset must take on all possible combinations of inputs, we also finish off the zipping using the built-in helper function of our class before training...

In [None]:
from google.colab import auth
auth.authenticate_user()

In [None]:
# Imports our saved files for use (this time from our GCS bucket)
BUCKET_PATH = 'gs://cub-preprocessed-dataset/cub_cropped_normalized'
TRAIN_SAVE_PATH = 'train/*'
TEST_SAVE_PATH = 'test/*'

trainFilenames = tf.io.gfile.glob(BUCKET_PATH + "/" + TRAIN_SAVE_PATH)
testFilenames = tf.io.gfile.glob(BUCKET_PATH + "/" + TEST_SAVE_PATH)

In [None]:
# Read back the dataset
dsTrain = tf.data.TFRecordDataset(trainFilenames)
dsTest = tf.data.TFRecordDataset(testFilenames)

# And parse it via the schema
exampleSchema = {
    'image' : tf.io.FixedLenFeature(shape=[], dtype=tf.string),
    'logits' : tf.io.FixedLenFeature(shape=[], dtype=tf.string)
}

# Parsing scheme for the examples
def parseExampleToFeatures(proto):
    parsed = tf.io.parse_example(proto, features = exampleSchema)
    image = tf.io.decode_image(parsed['image'], channels = 3)
    logits = tf.io.parse_tensor(parsed['logits'], out_type = tf.dtypes.float32)
    
    return image, logits

# Finally we want to convert the examples from int back to float for 
# TPU compatibility
def normalizeImage(image, logits):
    image = tf.cast(image, tf.dtypes.float32)
    image /= 255.0

    # This needs to be here so TPU can recognize the shapes.
    image.set_shape(IMAGE_SIZE)
    logits.set_shape([128])
    return image, logits

# prepare our datasets with some simple parsing to get us back to the original
# dataset returned
dsTrain = dsTrain.map(parseExampleToFeatures).map(normalizeImage).cache()
dsTest = dsTest.map(parseExampleToFeatures).map(normalizeImage).cache()