<!-- <center><h1>🐋Happywhale EDA and Insights🐋</h1></center> -->
<div style="width:85%; margin:0 auto; position:relative;">
    <img src="https://drive.google.com/uc?id=17CBLV9IC75dAWtpYmQtdHk_KA2cRLi9Z" alt="my_img"/>
    <b>Brief summary</b><span>🐋</span>
    <div style="background-color:#caf0f8; border:1px solid #00b4d8; border-radius:5px; padding:5px 10px;">
        <span>
        We use fingerprints and facial recognition to identify people, but can we use similar approaches with animals? In this competition, you’ll develop a model to match individual whales and dolphins by unique—but often subtle—characteristics of their natural markings. You'll pay particular attention to dorsal fins and lateral body views in image sets from a multi-species dataset built by 28 research institutions. The best submissions will suggest photo-ID solutions that are fast and accurate.
        </span>
    </div>
    <br/>
    <b>Competition goal:</b><br/>
    <span>Identify whales and dolphins by unique characteristics</span>
    <br/><br/>
    <b>Evaluation:</b><br/>
    <span>Submissions are evaluated according to the Mean Average Precision @ 5 (MAP@5):</span>
    <div>$MAP@5 = \frac{1}{U}\displaystyle \sum_{u=1}^{U}\sum_{k=1}^{min(n, 5)}P(k)\times rel(k)$</div>
    <div>
        where <b><i>U</i></b> is the number of images, <b><i>P(k)</i></b> is the precision at cutoff <b><i>k</i></b>, <b><i>n</i></b> is the number of predictions per image, and <b><i>rel(k)</i></b> is an indicator function equaling 1 if the item at rank <b><i>k</i></b> is a relevant (correct) label, zero otherwise.
    </div>
    <span style="background-color:#E99BDB; border:1px solid #C92CAC; border-radius:5px; padding:5px 10px;">
     Let's understand this better below<span style="font-size:25px;">👇</span> (Huge thanks to <a href="https://www.kaggle.com/pestipeti/explanation-of-map5-scoring-metric">this notebook</a> for understanding this evaluation metric)
    </span>
</div>

In [None]:
import numpy as np 

def map_per_image(label, predictions):
    """Computes the precision score of one image.

    With a single true label per image the score is ``1 / rank`` of the first
    correct prediction among the first five, and 0.0 when none of them match.

    Parameters
    ----------
    label : string
            The true label of the image
    predictions : iterable
            Predicted labels, best guess first (order does matter, only the
            first 5 are scored). Lists, tuples, numpy arrays and other
            iterables are all accepted.
    Returns
    -------
    score : double
    """
    # Materialise as a list first: numpy arrays and pandas Series have no
    # ``.index(value)`` method, so slicing them directly would crash.
    top_five = list(predictions)[:5]
    try:
        return 1 / (top_five.index(label) + 1)
    except ValueError:
        # list.index raises ValueError when the label is not among the top 5.
        return 0.0
def map_per_set(labels, predictions):
    """Average the per-image score over a collection of images.

    Parameters
    ----------
    labels : list
             The true label of every image (exactly one per image).
    predictions : list of list
             For every image, its ranked predictions (best first, 5 scored).
             Paired with ``labels`` by position via ``zip``.
    Returns
    -------
    score : double
    """
    per_image_scores = []
    for true_label, ranked_guesses in zip(labels, predictions):
        per_image_scores.append(map_per_image(true_label, ranked_guesses))
    return np.mean(per_image_scores)

# Sanity checks: a correct label in first position always scores 1.0 ...
assert map_per_image('A',['A', 'B', 'C', 'D', 'E']) == 1.0
# ... and repeating the correct label further down the list changes nothing,
# because only the first occurrence is looked up.
assert map_per_image('A', ['A', 'A', 'A', 'A', 'A']) == 1.0
assert map_per_image('A', ['A', 'B', 'A', 'C', 'A']) == 1.0

<div style="width:85%; margin:0 auto;">
<div 
     style="background-color:#B2EC98; border:1px solid #55C123; border-radius:5px; padding:5px 10px; display:inline-block; font-size:25px;">
    1. Loading and understanding Data
</div>
    <div>
    <ul>
    <li>📌Import libraries</li>
    <li>📌Setting seed</li>
    <li>📌Defining global variables</li>
    <li>📌Loading and understanding data</li>
    </ul>
    </div>
</div>

In [None]:
import os 
import cv2
import numpy as np
import glob
import pandas as pd
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import seaborn as sns

def set_seed(SEED):
    """Seed the random number generators used by the notebook.

    Seeds Python's ``random`` module, NumPy's global generator and TensorFlow,
    exports ``PYTHONHASHSEED`` and clears the Keras backend session.

    Parameters
    ----------
    SEED : int
        The seed value applied to every generator.
    """
    import random  # local import keeps this cell self-contained

    # Python's own generator was previously left unseeded; third-party
    # libraries that draw from it (presumably the augmentation library does —
    # TODO confirm) would otherwise differ between runs.
    random.seed(SEED)
    np.random.seed(SEED)
    # NOTE(review): hash randomisation is fixed at interpreter start-up, so
    # this assignment presumably only affects child processes — confirm.
    os.environ["PYTHONHASHSEED"] = str(SEED)
    tf.random.set_seed(SEED)
    tf.keras.backend.clear_session()
    print(f'SEED {SEED} SET!')


class Config:
    """Notebook-wide settings, read everywhere as plain class attributes."""
    seed = 42  # passed to set_seed and used as random_state for train_test_split
    batch_size = 4  # examples per batch produced by load_dataset
    epochs = 5  # number of passes made by the training loop
    img_size = 448  # images are resized to img_size x img_size
    learning_rate = 0.001  # learning rate handed to the Adam optimizer
    n_fold:int = 5  # number of StratifiedKFold splits
    gpus:list = tf.config.list_physical_devices('GPU')  # evaluated once, when the class body runs

# Seed the generators before any stochastic step runs.
set_seed(Config.seed)

# Dataset locations, relative to the notebook's working directory.
root = '../input/happy-whale-and-dolphin'
train_images = f'{root}/train_images'  # directory prefixed to image names to build the 'path' column
test_images = f'{root}/test_images'
tabular_path = f'{root}/train.csv'  # CSV read into tabular_data

In [None]:
# Load the training metadata table and preview its first rows.
tabular_data = pd.read_csv(tabular_path)
tabular_data.head()

__From the sample above of the first 5 rows of the data__🐬:
> * We have the <code>image name and extension</code>, the <code>species</code> to which the dolphin or whale belongs, and the <code>individual ids</code> of the whales and dolphins.
> * The individual ids will be used as the labels for the similarity search to identify whales and dolphins based on their unique yet subtle markings.
<div style="background-color:#FFD6E5; border:1px solid #FFAFCC; border-radius:5px; padding:5px 10px; display:inline-block;">
Let's explore more<span style="font-size:20px;">😄</span>
</div>

In [None]:
from IPython.display import display, Markdown, HTML
# Colour palette for the species badges: each entry is a
# (border colour, background colour) pair, unpacked in that order below.
coolors = [
    ("#A2D2FF", "#EBF5FF"),
    ("#FFAFCC", "#FFD6E5"),
    ("#55C123", "#B2EC98"),
    ("#00b4d8", "#caf0f8"),
    ("#C92CAC", "#E99BDB"),
    ("#FFB35C", "#FFE2C2"),
    ("#D4A373","#F1E0D0")
]

# Banner announcing the findings.
display(HTML(" \
    <div style='background-color:#EBF5FF; border:1px solid #A2D2FF; \
        border-radius:5px; padding:5px 10px; display:inline-block;'> \
    We're back from the Orient Express and the Nile, and this is what we found \
             <span style='font-size:20px;'>🕵️</span> \
             </div>"))

# Headline numbers: row count and number of distinct species labels.
no_rows = len(tabular_data)
display(Markdown("* There are __{}__ rows in the dataset".format(no_rows)))
display(Markdown("In the dataset, there are {} species of whales and dolphins"
      .format(tabular_data.species.nunique())))
display(Markdown("#### 🐋The classes are:🐬"))

# One coloured badge per species label; the colour pair is drawn at random
# from the palette and the label gets an emoji prefix when its name contains
# 'whale' and/or 'dolphin'.
badge_template = "<div style='background-color:{}; border:1px solid {}; \
        border-radius:5px; padding:5px 10px; margin:1px; display:inline-block;'> \
                 {} </div>"
badges = []
for species_name in tabular_data.species.unique():
    border, bg = coolors[np.random.randint(len(coolors))]
    badge_label = species_name
    if 'whale' in badge_label.lower():
        badge_label = f"🐋{badge_label}"
    if 'dolphin' in badge_label.lower():
        badge_label = f"🐬{badge_label}"
    badges.append(badge_template.format(bg, border, badge_label))
all_classes = "<div style='display:flex; flex-wrap:wrap;'>" + "".join(badges) + "</div>"
display(HTML(all_classes))
display(Markdown("🖋️The awesome colors can be found on [Coolors](https://coolors.co/palettes/trending)"))

>* There is a typo on the species bottlenose_dolpin which needs to be corrected to bottlenose_dolphin
>* There is a typo on the species kiler_whale  which needs to be corrected to killer_whale
>* This reduces the total species to 28
>* Also looking at the unknown species, we find out that:
>
>
>>📓 <code>Belugas</code> are <code>toothed whales</code>, and are not part of the oceanic dolphin family. They are classified under the Monodontidae family, which only consists of two species: <code>belugas</code> and <code>narwhals</code>.
>
>
>>📓 <code>Globis</code> seems to have been confused and is supposed to be <code>globus</code>, a latin for whales named "pilot whales" because their birth pods were believed to be "piloted" by a leader. Let us go further into looking at the images of the species to confirm this.

In [None]:
# Correct the two misspelled species labels and make the beluga label explicit.
# The result is assigned back to the column: calling replace(..., inplace=True)
# on `tabular_data.species` is chained assignment, which pandas warns about
# and which does not update the DataFrame when Copy-on-Write is enabled.
tabular_data['species'] = tabular_data['species'].replace(
    ["bottlenose_dolpin", "kiler_whale", "beluga"],
    ["bottlenose_dolphin", "killer_whale", "beluga_whale"])

* Let's add another column for file paths to be able to reference file paths for the ~globis~ __*or rather globus*__ species 

In [None]:
# Full path of every training image: the image directory joined with the file name.
tabular_data['path'] = [f'{train_images}/{file_name}' for file_name in tabular_data['image']]

# Rows currently labelled with the suspicious 'globis' species.
is_globis = tabular_data['species'] == 'globis'
globis = tabular_data.loc[is_globis]
print("There are {} rows for the globis species".format(len(globis)))

<div style="background-color:#FFE2C2; border:1px solid #FFB35C; border-radius:5px; padding:5px 10px; display:inline-block;">
Now onto the images
</div>

In [None]:
def read_img(image):
    """Load an image file as an RGB array resized to the configured square.

    NOTE: a later cell defines another ``read_img(path, label)`` that replaces
    this one, so the image-grid cells only work when run before that cell.

    Parameters
    ----------
    image : str
        Path of the image file to read.
    Returns
    -------
    The pixels converted from BGR to RGB channel order and resized to
    (Config.img_size, Config.img_size).
    Raises
    ------
    FileNotFoundError
        If OpenCV could not read the file.
    """
    pixels = cv2.imread(image)
    if pixels is None:
        # cv2.imread reports a missing/unreadable file by returning None;
        # fail here with the offending path instead of inside cvtColor.
        raise FileNotFoundError(f"Could not read image: {image}")
    pixels = cv2.cvtColor(pixels, cv2.COLOR_BGR2RGB)
    return cv2.resize(pixels, (Config.img_size, Config.img_size))
    
# Tile the first 25 'globis' images into a 5x5 mosaic: each row is five images
# joined side by side, and the five rows are then stacked top to bottom.
grid_rows = []
for row in range(5):
    tiles = [read_img(globis.iloc[row * 5 + col].path) for col in range(5)]
    grid_rows.append(np.hstack(tiles))
vertical_stack = np.vstack(grid_rows)

plt.figure(figsize=(16, 16))
plt.imshow(vertical_stack)
plt.show()

> Cross referencing with sample images from pilot whales


In [None]:
# Rows labelled 'pilot_whale', for a visual comparison with the globis grid.
pilot_whale = tabular_data.loc[tabular_data['species'] == 'pilot_whale']

# First five pilot whale images joined side by side into one strip.
horizontal_stack = np.hstack(
    [read_img(pilot_whale.iloc[k].path) for k in range(5)]
)
plt.figure(figsize=(15, 40))
plt.imshow(horizontal_stack)
plt.show()

📌 And also, after looking through other amazing notebooks such as [Andrada Olteanu's](https://www.kaggle.com/andradaolteanu/whales-dolphins-effnet-embedding-cos-distance), it has been decided to rename <code>~globis~ globus</code>(now confirmed to be a typo👍) and <code>pilot_whale</code> to <code>short_finned_pilot_whale</code>
> Yeap! The detective was right!

In [None]:
# Fold both 'globis' and 'pilot_whale' into 'short_finned_pilot_whale'.
# The result is assigned back to the column: calling replace(..., inplace=True)
# on `tabular_data.species` is chained assignment, which pandas warns about
# and which does not update the DataFrame when Copy-on-Write is enabled.
tabular_data['species'] = tabular_data['species'].replace(
    ["globis", "pilot_whale"],
    ["short_finned_pilot_whale", "short_finned_pilot_whale"])

In [None]:
# Share of whale vs. dolphin samples in the data.

# Tag every sample: species names containing 'whale' are whales, all others
# are treated as dolphins.
tabular_data['class'] = [
    'whale' if 'whale' in species_name else 'dolphin'
    for species_name in tabular_data['species']
]
class_counts = tabular_data['class'].value_counts()  # samples per class
total_counts = len(tabular_data)  # all samples
whale_counts = class_counts['whale']
dolphin_counts = class_counts['dolphin']

summary_template = "The species are now {} with <code>whales</code> being \
                 __{} ({:.2f}%)__ and <code>dolphins</code> being __{} ({:.2f}%)__"
whale_share = whale_counts/total_counts * 100
dolphin_share = dolphin_counts/total_counts * 100
display(Markdown(summary_template.format(
    tabular_data.species.nunique(),
    whale_counts, whale_share,
    dolphin_counts, dolphin_share)))

In [None]:
# Bar chart of samples per species, each bar annotated with its percentage
# of the whole data set.
g = sns.catplot(x="species", kind="count", data=tabular_data,  height=9)
g.set_xticklabels(rotation=90)
# [species, percentage] pairs; not read again by the cells visible here.
species_and_values = [
    [species_name, count/total_counts*100]
    for species_name, count in tabular_data.species.value_counts().items()
]
for bar in g.ax.patches:
    bar_height = bar.get_height()
    # Label sits slightly left of the bar's x position and just above its top.
    g.ax.text(bar.get_x() - 0.25, bar_height + 3,
              '{:.2f} %'.format(bar_height/total_counts*100), size=10)
plt.title("Count of species of Dolphins and Whales as a percentage")
plt.show()

<div style="width:85%; margin:0 auto;">
<div 
     style="background-color:#F1E0D0; border:1px solid #D4A373; border-radius:5px; padding:5px 10px; display:inline-block; font-size:25px;">
    2. Label encoding and Data Loading pipeline
</div>
    <div>
    <ul>
    <li>📌Label encoding</li>
    <li>📌K-fold Cross validation</li>
    <li>📌Data loading</li>
    <li>📌Data augmentation</li>
    </ul>
    </div>
</div>

In [None]:
# Sklearn Imports
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import StratifiedKFold

In [None]:
# Integer-encode the individual ids: fit the encoder on the id column, then
# store the resulting integer codes in a new 'encoding' column.
encoder = LabelEncoder()
individual_ids = tabular_data['individual_id']
tabular_data['encoding'] = encoder.fit(individual_ids).transform(individual_ids)

In [None]:
# Preview the table with the new integer 'encoding' column.
tabular_data.head()

<div style="width:85%; margin:0 auto">
<span style="font-size:25px">📝K-fold cross validation</span>
<div style="background-color:#FFE2C2; border:1px solid #FFB35C; border-radius:5px; padding:5px 10px; display:inline-block;">
When we are using K-fold cross validation, it is often beneficial to have folds containing roughly the same percentage of observations from each of the different target classes (called <code>stratified k-fold</code>). For example, if our target vector contained gender and 80% of the observations were male, then each fold would contain 80% male and 20% female observations.
</div>
</div>

In [None]:
# Stratify the folds on the encoded individual id.
stratified_k_fold = StratifiedKFold(
                    n_splits=Config.n_fold
                    )
stratified_k_fold_splits = stratified_k_fold.split(
    X = tabular_data.drop(columns = 'encoding'),
    y = tabular_data['encoding']
)

# `split` yields *positional* row indices, so the fold numbers are written
# into a positional integer array rather than through label-based `.loc`.
# This stays correct for any DataFrame index and never creates a partially
# filled (NaN/float) column that has to be cast afterwards.
fold_of_row = np.zeros(len(tabular_data), dtype=int)
for fold, (train_index, valid_index) in enumerate(
                                stratified_k_fold_splits
                                ):
    fold_of_row[valid_index] = fold

tabular_data["kfold"] = fold_of_row

tabular_data.head()

In [None]:
import albumentations as A
from albumentations import transforms
AUTOTUNE = tf.data.AUTOTUNE

# Resolve the augmentation classes through the package's public top-level
# namespace (`A`) instead of an internal submodule: which submodule defines a
# given transform is an implementation detail that can differ between
# albumentations releases, whereas the top-level names are the documented API.
Normalize = A.Normalize
CoarseDropout = A.CoarseDropout
Flip = A.Flip

def read_img(path, label):
    """Read and decode one image file with TensorFlow and resize it.

    NOTE: this definition replaces the earlier OpenCV-based ``read_img(image)``
    from the exploration cells.

    Parameters
    ----------
    path : path of the image file.
    label : passed through unchanged.
    Returns
    -------
    (image, label) where image is the 3-channel float32 decode of the file
    resized to (Config.img_size, Config.img_size).
    """
    raw_bytes = tf.io.read_file(path)
    decoded = tf.io.decode_image(
        raw_bytes,
        channels=3,
        dtype=tf.dtypes.float32,
        expand_animations=False,  # animated files are not expanded into frames
    )
    target_size = (Config.img_size, Config.img_size)
    return tf.image.resize(decoded, target_size), label

def get_transform(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225),
                  max_pixel_value=255.0):
    """Build the augmentation pipeline applied to every image.

    The pipeline normalises the image, then applies coarse dropout and flip
    with their default probabilities.

    Parameters
    ----------
    mean : tuple of 3 floats
        Per-channel mean used by the normalisation step.
    std : tuple of 3 floats
        Per-channel standard deviation used by the normalisation step.
    max_pixel_value : float
        Pixel scale the normalisation step assumes for its input. The default
        keeps the pipeline's previous behaviour.
    Returns
    -------
    The composed albumentations transform; call it as ``t(image=...)["image"]``.

    NOTE(review): the TensorFlow ``read_img`` decodes with dtype=float32, which
    presumably yields pixels in [0, 1]. If so, max_pixel_value=1.0 is the
    matching scale, and the de-normalisation in the visualisation cell as well
    as the dropout ``fill_value=64`` would need adjusting with it — confirm.
    """
    return A.Compose([
        Normalize(mean=mean, std=std, max_pixel_value=max_pixel_value,
                  always_apply=True, p=1.0),
        CoarseDropout(max_holes=30, max_height=10, max_width=10, fill_value=64),
        Flip(),
        ])

def augment_image(image,label, transforms):
    """Apply an augmentation pipeline to one image.

    Parameters
    ----------
    image : object exposing ``.numpy()`` (e.g. an eager tensor).
    label : passed through unchanged.
    transforms : callable invoked as ``transforms(image=array)`` that returns
        a mapping with an ``"image"`` entry.
    Returns
    -------
    (augmented image, label)
    """
    as_array = image.numpy()
    augmented = transforms(image=as_array)
    return augmented["image"], label

In [None]:
# Read the footer below 
# def load_dataset(data, fold):
#     dataset = tabular_data.loc[
#       tabular_data.kfold==fold, ['path', 'encoding', 'kfold']
#     ].values
#     path = dataset[:, 0]
#     labels = dataset[:, 1].astype(np.float32)
#     dataset = tf.data.Dataset.from_tensor_slices((path, labels))
#     transform = get_transform()
#     dataset = dataset.map(
#                lambda x, y: read_img(x, y),
#                     num_parallel_calls = AUTOTUNE).map(
#                lambda x, y: tf.py_function(
#                     func = augment_image(x, y, transform),
#                     inp=[x, y],
#                     Tout=[tf.float32, tf.float32]
#                     ),
#                 num_parallel_calls = AUTOTUNE
#                ).batch(Config.batch_size)
#     return dataset

<div style="margin:0 auto; width:85%; position:relative; background-color:#FFE2C2; border:1px solid #FFB35C; border-radius:5px; padding:5px 10px; display:inline-block;">
    <img style="position:absolute; margin-top:80px; right:-10px;" src="https://drive.google.com/uc?id=1Wj9P2p9xq_h_Lk-PNw6W8uUFT2kXPydW" alt="my_img"/>
    <div>
        <span style="font-size:30px;">🧐</span>
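        I don't know why, but for some reason, <code>strings</code> and <code>TensorFlow dataset map</code> don't see eye to eye. Could someone investigate this code further? 'Cause only with the invitation of Holmes and Watson could we find the real culprit! (A likely suspect: the commented-out cell passes <code>func=augment_image(x, y, transform)</code> to <code>tf.py_function</code>, which calls the function immediately on symbolic tensors instead of handing over a callable.)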
    </div>
    <img src="https://i.imgur.com/wH8Ouo4.png" alt="my_img"/>
    <div style="font-size:20px;"><span style="font-size:25px;">✈️</span>Hence, I went on to use generators!</div>
</div>

In [None]:
from sklearn.model_selection import train_test_split

def load_dataset(dataset):
    """Wrap (path, label) rows in a batched, generator-backed tf.data.Dataset.

    Parameters
    ----------
    dataset : iterable of (path, label) pairs.
    Returns
    -------
    A dataset yielding batches of Config.batch_size (image, label) pairs, where
    each image has been read with ``read_img`` and passed through the
    augmentation pipeline from ``get_transform``.
    """
    augment = get_transform()
    image_spec = tf.TensorSpec(
        shape=(Config.img_size, Config.img_size, 3), dtype=tf.float32)
    label_spec = tf.TensorSpec(shape=(), dtype=tf.float32)

    def generator():
        # Runs eagerly, so the decoded tensor can be converted to numpy
        # before it is handed to the augmentation pipeline.
        for path, label in dataset:
            tensor, _ = read_img(path, label)
            yield augment(image=tensor.numpy())["image"], label

    unbatched = tf.data.Dataset.from_generator(
        generator,
        output_signature=(image_spec, label_spec),
    )
    return unbatched.batch(Config.batch_size)

def train_and_validation_data(data, fold=0, test_size=0.33):
    """Build the training and validation datasets from one fold.

    Parameters
    ----------
    data : DataFrame with 'kfold', 'path' and 'encoding' columns.
    fold : the fold whose rows are used.
    test_size : share of those rows held out for validation.
    Returns
    -------
    (train, valid) datasets as produced by ``load_dataset``.

    NOTE(review): both splits are drawn from the rows of the single selected
    fold; rows belonging to the other folds are not used here — confirm this
    is intended.
    """
    in_fold = data.kfold == fold
    rows = data.loc[in_fold, ['path', 'encoding']].values
    train_rows, valid_rows = train_test_split(
        rows, test_size=test_size, random_state=Config.seed)
    return load_dataset(train_rows), load_dataset(valid_rows)

### 📌 A bit of the visualization

In [None]:
# Same per-channel statistics as the normalisation step, needed to undo it.
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

# Training split of fold 1; the validation split is discarded.
data, _ = train_and_validation_data(tabular_data, 1)

plt.figure(figsize=(10,10))
for images, labels in data.take(1):
    # Show every image of the first batch with its encoded label as title.
    for position, (image, label) in enumerate(zip(images, labels), start=1):
        pixels, target = image.numpy(), label.numpy()
        plt.subplot(4, 4, position)
        # Invert the normalisation before plotting.
        restored = ((np.asarray(pixels) * np.array(std) * 255) +(np.array(mean) * 255))
        plt.imshow(restored)
        plt.title(f"{target}")
        plt.axis("off")
    plt.show()

<div style="width:85%; margin:0 auto;">
<div 
     style="background-color:#B2EC98; border:1px solid #55C123; border-radius:5px; padding:5px 10px; display:inline-block; font-size:25px;">
    3. Model implementation and Embeddings
</div>
    <div>
    <ul>
    <li>📌GeM pooling</li>
    <li>📌ArcFace Loss function</li>
    <li>📌Efficientnetb7 model</li>
    <li>📌Stitching everything together</li>
    </ul>
    </div>
</div>

### 📌GeM (Generalized-Mean) pooling
Generalized Mean Pooling (GeM) computes the generalized mean of each channel in a tensor. Formally:
$$e = \left[\left(\frac{1}{|\Omega|}\sum_{u\in\Omega}x_{cu}^p\right)^{\frac{1}{p}}\right]_{c=1,...,C}$$
where $p > 0$ is a parameter. Setting this exponent as $p > 1$ increases the contrast of the pooled feature map and focuses on the salient features of the image. GeM is a generalization of the average pooling commonly used in classification networks ($p = 1$) and of spatial max-pooling layer ($p = \infty$).
<div style="width:50%; margin:0 auto;"><img src="https://production-media.paperswithcode.com/methods/Screen_Shot_2020-06-09_at_3.10.49_PM_FPwQGTY.png"/></div>

In [None]:
class GeMPoolingLayer(tf.keras.layers.Layer):
    """Generalized-mean (GeM) pooling over axes 1 and 2 of the input.

    Each value is clamped from below at ``eps``, raised to the power ``p``,
    averaged over axes 1 and 2, and the mean is raised to ``1 / p``.

    Parameters
    ----------
    p : float
        Pooling exponent.
    train_p : bool
        When True the exponent is held in a ``tf.Variable`` instead of a
        plain Python number.
    **kwargs
        Standard ``tf.keras.layers.Layer`` arguments (e.g. ``name``).
    """
    def __init__(self, p=3., train_p=False, **kwargs):
        # Forward the standard Layer arguments; previously they were rejected.
        super(GeMPoolingLayer, self).__init__(**kwargs)
        self.train_p = train_p
        if train_p:
            self.p = tf.Variable(p, dtype=tf.float32)
        else:
            self.p = p
        self.eps = 1e-6

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config().copy()
        # Export the current exponent as a plain float, whichever form it has.
        current_p = float(self.p.numpy()) if isinstance(self.p, tf.Variable) else self.p
        config.update({
            'p': current_p,
            'train_p': self.train_p,
        })
        return config

    def call(self, inputs: tf.Tensor, **kwargs):
        # Clamp from below only. The former upper bound was the tensor's own
        # maximum, which never changed a value but cost a full reduction.
        inputs = tf.maximum(inputs, self.eps)
        inputs = tf.pow(inputs, self.p)
        inputs = tf.reduce_mean(inputs, axis=[1, 2], keepdims=False)
        inputs = tf.pow(inputs, 1./self.p)
        return inputs

### 📌ArcFace (Additive Angular Margin) Loss Function
> ArcFace is a machine learning model that takes two face images as input and outputs the distance between them to see how likely they are to be the same person. It can be used for face recognition and face search. __and in this case, similarity search in subtle features of whales and dolphins__


> ArcFace uses a similarity learning mechanism that allows distance metric learning to be solved in the classification task by introducing Angular Margin Loss to replace Softmax Loss.
The distance between faces is calculated using cosine distance, which is a method used by search engines and can be calculated by the inner product of two normalized vectors. If the two vectors are the same, <b>$\theta$ will be 0</b> and $cos(\theta) = 1$. If they are orthogonal, $\theta$ <b>will be</b> $\pi/2$ and $cos\theta$<b>=0</b>. Therefore, it can be used as a similarity measure.
<img src="https://learnopencv.com/wp-content/uploads/2020/07/arcface-1024x252.jpg"/>

In [None]:
import math 
class ArcMarginProduct(tf.keras.layers.Layer):
    """ArcFace head: scaled cosine logits with an additive angular margin.

    Called on ``[X, y]``. ``X`` is L2-normalised along axis 1 and multiplied
    with the column-normalised weight matrix to obtain cosines; for the class
    given by ``y`` the margin-adjusted value replaces the plain cosine, and
    the result is multiplied by ``s``.

    Parameters
    ----------
    n_classes : int
        Number of classes (columns of the weight matrix, one-hot depth).
    s : float
        Scale applied to the output.
    m : float
        Angular margin added to the angle of the labelled class.
    easy_margin : bool
        Chooses the fallback used where the margin is not applied (see ``call``).
    ls_eps : float
        Label-smoothing amount applied to the one-hot targets; 0 disables it.
    **kwargs
        Standard ``tf.keras.layers.Layer`` arguments.
    """
    def __init__(self, n_classes, s=30, m=0.50, easy_margin=False,
                 ls_eps=0.0, **kwargs):
        super(ArcMarginProduct, self).__init__(**kwargs)
        self.n_classes = n_classes
        self.s = s
        self.m = m
        self.ls_eps = ls_eps
        self.easy_margin = easy_margin
        # Constants of cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m).
        self.cos_m = tf.math.cos(m)
        self.sin_m = tf.math.sin(m)
        # Threshold and penalty used by the non-easy-margin fallback in call().
        self.th = tf.math.cos(math.pi - m)
        self.mm = tf.math.sin(math.pi - m) * m

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config().copy()
        config.update({
            'n_classes': self.n_classes,
            's': self.s,
            'm': self.m,
            'ls_eps': self.ls_eps,
            'easy_margin': self.easy_margin,
        })
        return config

    def build(self, input_shape):
        """Create the weight matrix of shape (last dim of X, n_classes)."""
        # input_shape holds one shape per input; index 0 belongs to X.
        super(ArcMarginProduct, self).build(input_shape[0])
        self.W = self.add_weight(
            name='W',
            shape=(int(input_shape[0][-1]), self.n_classes),
            initializer='glorot_uniform',
            dtype='float32',
            trainable=True,
            regularizer=None)

    def call(self, inputs):
        """Compute the margin-adjusted, scaled logits for ``[X, y]``."""
        X, y = inputs
        y = tf.cast(y, dtype=tf.int32)
        cosine = tf.matmul(
            tf.math.l2_normalize(X, axis=1),
            tf.math.l2_normalize(self.W, axis=0)
        )
        # Rounding can push cosine**2 marginally above 1, which would make the
        # radicand negative and the square root NaN; clamp it to [0, 1] first.
        sine = tf.math.sqrt(
            tf.clip_by_value(1.0 - tf.math.pow(cosine, 2), 0.0, 1.0)
        )
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        if self.easy_margin:
            # Apply the margin only where the cosine is positive.
            phi = tf.where(cosine > 0, phi, cosine)
        else:
            # Beyond the threshold fall back to a linear penalty instead.
            phi = tf.where(cosine > self.th, phi, cosine - self.mm)
        one_hot = tf.cast(
            tf.one_hot(y, depth=self.n_classes),
            dtype=cosine.dtype
        )
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.n_classes
        # Margin-adjusted value for the labelled class, plain cosine elsewhere.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output

### 📌EfficientNetB7 model implementation

In [None]:
EfficientNetB7 = tf.keras.applications.EfficientNetB7

# Backbone without its classification head.
baseModel = EfficientNetB7(include_top = False)

# Every layer stays trainable except the BatchNormalization layers,
# which are frozen.
for layer in baseModel.layers:
    layer.trainable = not isinstance(layer, tf.keras.layers.BatchNormalization)

In [None]:
# ArcFace head with one class per distinct encoded individual id.
n_individuals = tabular_data.encoding.nunique()
margin = ArcMarginProduct(
    n_classes=n_individuals,
    s=30,
    m=0.3,
    name='head/arcface',
    dtype='float32',
)

In [None]:
# Two inputs: the image, and its label (the ArcFace head needs the label to
# know which class receives the margin).
input_1 = tf.keras.layers.Input(shape = (Config.img_size, Config.img_size, 3))
input_2 = tf.keras.layers.Input(shape = ())
# Backbone feature map -> GeM pooling over the spatial axes.
x = baseModel(input_1)
x = GeMPoolingLayer()(x)
# 512-dimensional embedding, with dropout applied before the projection.
embed = tf.keras.layers.Dropout(0.2)(x)
embed = tf.keras.layers.Dense(512)(embed)
# Margin-adjusted logits, turned into class probabilities by the softmax.
x = margin([embed, input_2])
output = tf.keras.layers.Softmax(dtype='float32')(x)

model = tf.keras.models.Model(inputs = [input_1, input_2], outputs = [output])

In [None]:
# Smoke test: push one random image through the model.
dummy = tf.random.normal((1, Config.img_size, Config.img_size, 3))
# The label input is declared with shape=() per example, i.e. (batch,) overall,
# so the label is passed as a length-1 batch rather than as a bare scalar.
dummy_label = tf.constant([1.0])
model([dummy, dummy_label])

In [None]:
# Draw the layer graph of the assembled model.
tf.keras.utils.plot_model(model)

### 📌Model training

In [None]:
from tqdm.notebook import tqdm
train_dataset, val_dataset = train_and_validation_data(tabular_data, 0)

optimizer = tf.keras.optimizers.Adam(learning_rate = Config.learning_rate)
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
# Reduction is switched off so the loss object returns one value per example;
# the averaging happens exactly once, in compute_loss. With the default
# reduction the batch mean was divided by the batch size a second time.
loss = tf.keras.losses.SparseCategoricalCrossentropy(reduction='none')
sparse_categorical_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
sparse_top_k_accuracy = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5)

def compute_loss(labels, predictions):
    """Return the batch loss: per-example losses summed and divided by Config.batch_size."""
    per_example_loss = loss(labels, predictions)
    return tf.nn.compute_average_loss(per_example_loss, global_batch_size=Config.batch_size)

def train_step(inputs):
    """Run one optimisation step on a batch and update the running metrics.

    Parameters
    ----------
    inputs : (images, labels) batch as yielded by the training dataset.
    Returns
    -------
    The loss of this batch.
    """
    images, labels = inputs
    with tf.GradientTape() as tape:
        # The labels are a model input too: the ArcFace head uses them.
        predictions = model([images,labels], training=True)
        batch_loss = compute_loss(labels, predictions)

    gradients = tape.gradient(batch_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    sparse_categorical_accuracy.update_state(labels, predictions)
    sparse_top_k_accuracy.update_state(labels, predictions)
    return batch_loss

for epoch in range(Config.epochs):
    # Start every epoch with fresh metrics so the printed accuracies describe
    # this epoch only rather than a running total since the first one.
    sparse_categorical_accuracy.reset_state()
    sparse_top_k_accuracy.reset_state()
    # The accumulators are initialised once per epoch. They used to be reset
    # inside the batch loop, so the reported loss was the last batch's only.
    total_loss = 0.0
    num_batches = 0
    # TRAIN LOOP
    # No `total` is given: the generator-backed dataset has no known length
    # and a hard-coded batch count goes stale when batch size or fold changes.
    for batch in tqdm(train_dataset):
        total_loss += train_step(batch)
        num_batches += 1
    # max(..., 1) avoids a division by zero when the dataset yields no batch.
    train_loss = total_loss / max(num_batches, 1)
    if epoch % 2 == 0:
        # NOTE: tf.train.Checkpoint.save takes a file *prefix*; the ".h5"
        # here is just part of that prefix, not an HDF5 file.
        checkpoint.save("efficientnetb7.h5")
    template = ("Epoch {}, Loss: {}, Sparse Accuracy: {}, "
              " Sparse Top K Accuracy: {}")
    print (template.format(epoch+1, train_loss,
                         sparse_categorical_accuracy.result()*100, 
                         sparse_top_k_accuracy.result()*100))

In [None]:
!nvidia-smi