# Importing Required Libraries

In [1]:
import os

# Preprocessing and Plotting Libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

# Tools for training model
import tensorflow as tf
from tensorflow.keras.applications import resnet50

#Tools for evaluation of model
from sklearn.metrics import classification_report, confusion_matrix

# Setting Up the Environment for Distributed Training on a Kaggle TPU

In [2]:
# Display only the TPU-related variables. Echoing the whole of os.environ into the
# saved notebook output can leak credentials/tokens injected by the platform.
tpu_env = {key: value for key, value in os.environ.items() if key.startswith('TPU')}
tpu_env

In [3]:
# Connect to a TPU when one is available; otherwise fall back to the default
# (single-device) strategy so that `strategy` is always defined for later cells.
# .get() instead of ['TPU_NAME']: a missing variable used to raise an uncaught
# KeyError. None lets TPUClusterResolver try to auto-detect the TPU.
tpu_address = os.environ.get('TPU_NAME')
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu_address)
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)

    # Non-experimental TPUStrategy (available since TF 2.3; this notebook already
    # relies on tf.data.AUTOTUNE, which needs TF >= 2.4).
    strategy = tf.distribute.TPUStrategy(tpu)

    print('Running on TPU: ', tpu.cluster_spec().as_dict()['worker'])

except ValueError:
    # The resolver raises ValueError e.g. when no TPU can be found.
    print('TPU failed to initialize.')
    strategy = tf.distribute.get_strategy()

print('Number of accelerators: ', strategy.num_replicas_in_sync)

# Loading the Data as a TensorFlow Dataset

In [4]:
# Root of the Kaggle input dataset. The glob patterns are built with os.path.join so
# they use the same separator (os.path.sep) that get_class() splits paths on.
DATA_ROOT = '../input/a-large-scale-fish-dataset'
# Files three levels below Fish_Dataset/Fish_Dataset (class dir / sub-dir / file);
# sub-dirs whose name ends in ' GT' are filtered out later in load_dataset.
train_path = os.path.join(DATA_ROOT, 'Fish_Dataset', 'Fish_Dataset', '*', '*', '*')
# Files two levels below NA_Fish_Dataset.
test_path = os.path.join(DATA_ROOT, 'NA_Fish_Dataset', '*', '*')

In [5]:
# Pipeline / model constants
IMG_SIZE = [224, 224]   # [height, width] passed to tf.image.resize
BATCH_SIZE = 128
VAL_SPLIT = 0.2         # fraction of training files held out for validation

# Class names. These presumably mirror the dataset's directory names (which is what
# get_class() returns), so keep spellings exactly as-is, e.g. 'Hourse Mackerel'.
CLASSES = [
    'Black Sea Sprat',
    'Gilt-Head Bream',
    'Hourse Mackerel',
    'Red Mullet',
    'Red Sea Bream',
    'Sea Bass',
    'Shrimp',
    'Striped Red Mullet',
    'Trout',
]

In [6]:
# Training And Validation Data
def get_class(file_path):
    """Return the name of the directory that directly contains `file_path`.

    Args:
        file_path: string tensor holding a file path.

    Returns:
        String tensor with the second-to-last path component (split on os.path.sep).
    """
    path_parts = tf.strings.split(file_path, sep=os.path.sep)
    return path_parts[-2]


def process_image(file_path):
    """Load one image file and return a ``(preprocessed_image, class_name)`` pair.

    The image is decoded to 3 channels, resized to IMG_SIZE and normalised with
    ResNet50's own ``preprocess_input``. That function expects float pixels in
    [0, 255], so it is applied last and no further scaling is done afterwards.

    Args:
        file_path: scalar string tensor with the path of an image file.

    Returns:
        img: float32 tensor of shape ``IMG_SIZE + [3]``.
        label: scalar string tensor, the parent directory name (see get_class).
    """
    label = get_class(file_path)
    img = tf.io.read_file(file_path)
    # decode_image is format-agnostic (JPEG/PNG/...). channels=3 fixes the channel
    # count (preprocess_input needs exactly 3). expand_animations=False guarantees a
    # rank-3 tensor, which tf.image.resize requires.
    img = tf.io.decode_image(img, channels=3, expand_animations=False)
    # Work in float32 and keep pixel values in [0, 255].
    img = tf.image.resize(tf.cast(img, tf.float32), IMG_SIZE)
    # ResNet50 normalisation (RGB->BGR, subtract ImageNet channel means). On a uint8
    # tensor it would cast the negative mean offsets to uint8. An extra `/ 255.0`
    # afterwards would squash the values the pretrained weights expect.
    img = tf.keras.applications.resnet50.preprocess_input(img)
    return img, label


def load_dataset(file_path, val_split=None, seed=42):
    """Split the images matching `file_path` into training and validation datasets.

    Files inside ground-truth mask directories (directory name whose last
    space-separated token is 'GT') are dropped *before* splitting. As a result,
    `val_split` is the exact fraction of real images held out.

    Args:
        file_path: glob pattern matching the image files.
        val_split: fraction of files used for validation; defaults to VAL_SPLIT.
        seed: seed of the single shuffle that decides the split. ``None`` gives a
            different split per call, but it stays fixed for the returned datasets.

    Returns:
        ``(train_dataset, valid_dataset)``: unbatched datasets of
        ``(image, label)`` pairs produced by process_image.

    Raises:
        ValueError: if no non-GT file matches `file_path`.
    """
    if val_split is None:
        val_split = VAL_SPLIT

    # Filter and shuffle the file list once, in Python. list_files(shuffle=True)
    # reshuffles on every iteration. Then take()/skip() would draw different files on
    # every pass, and validation images would leak into training.
    file_names = [
        name for name in sorted(tf.io.gfile.glob(file_path))
        if os.path.basename(os.path.dirname(name)).split(' ')[-1] != 'GT'
    ]
    if not file_names:
        raise ValueError(f'No image files matched {file_path!r}')
    file_names = np.random.default_rng(seed).permutation(file_names)

    split = int(len(file_names) * val_split)
    valid_dataset = tf.data.Dataset.from_tensor_slices(file_names[:split])
    train_dataset = tf.data.Dataset.from_tensor_slices(file_names[split:])

    train_dataset = train_dataset.map(process_image, num_parallel_calls=tf.data.AUTOTUNE)
    valid_dataset = valid_dataset.map(process_image, num_parallel_calls=tf.data.AUTOTUNE)
    return train_dataset, valid_dataset


def get_batched_dataset(dataset, batch_size=None, shuffle_buffer=2048):
    """Cache, shuffle, batch and prefetch an unbatched dataset.

    Args:
        dataset: unbatched ``tf.data.Dataset`` (e.g. of ``(image, label)`` pairs).
        batch_size: batch size; defaults to BATCH_SIZE (read at call time).
        shuffle_buffer: number of elements held in the shuffle buffer.

    Returns:
        The batched, prefetched ``tf.data.Dataset``. The last batch may be smaller.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE
    # cache() must come before shuffle(). A cache placed after shuffling replays the
    # first epoch's order forever, so every epoch would see identical batches. Caching
    # here also means each file is read and preprocessed only once.
    dataset = dataset.cache()
    dataset = dataset.shuffle(shuffle_buffer)
    dataset = dataset.batch(batch_size, drop_remainder=False)
    # prefetch() goes last so batch preparation overlaps with consumption.
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset


def get_train_val_data(file_path):
    """Return batched ``(train, validation)`` datasets for files matching `file_path`.

    Splits with load_dataset, then batches each half with get_batched_dataset
    (training half first).
    """
    train_split, valid_split = load_dataset(file_path)
    return get_batched_dataset(train_split), get_batched_dataset(valid_split)

In [7]:
class MyModel(tf.keras.Model):
    """Image classifier: a ResNet50 feature extractor followed by a dense head.

    Args:
        num_classes: number of output units; defaults to ``len(CLASSES)``.
        num_frozen_layers: number of leading backbone layers made non-trainable, so
            only the top of ResNet50 is fine-tuned.
        classifier_activation: activation of the output layer. The default
            ``'softmax'`` yields class probabilities (pair it with a loss using
            ``from_logits=False``). Pass ``None`` to emit raw logits instead.
    """

    def __init__(self, num_classes=None, num_frozen_layers=140,
                 classifier_activation='softmax'):
        super().__init__()

        if num_classes is None:
            num_classes = len(CLASSES)

        # ResNet50 with its default (ImageNet) weights. Drop its final layer and
        # expose the penultimate layer's output as the feature vector.
        backbone = resnet50.ResNet50()
        self.base_block = tf.keras.models.Model(
            inputs=backbone.input,
            outputs=backbone.layers[-2].output,
        )

        for layer in self.base_block.layers[:num_frozen_layers]:
            layer.trainable = False

        self._dense1 = tf.keras.layers.Dense(128, activation='relu')
        self._dense2 = tf.keras.layers.Dense(64, activation='relu')
        self._dense3 = tf.keras.layers.Dense(32, activation='relu')
        self._dense4 = tf.keras.layers.Dense(16, activation='relu')
        # The output layer must not use ReLU. ReLU clamps class scores at 0, so dead
        # outputs get no gradient, and the scores do not form a distribution.
        self._classifier = tf.keras.layers.Dense(
            num_classes, activation=classifier_activation)

    def call(self, inputs, training=None):
        """Run the backbone and the dense head on a batch of preprocessed images.

        `training` is forwarded to the backbone so that any normalisation layers
        inside it behave consistently with the training/inference mode.
        """
        features = self.base_block(inputs, training=training)

        x = self._dense1(features)
        x = self._dense2(x)
        x = self._dense3(x)
        x = self._dense4(x)

        return self._classifier(x)

In [8]:
# Build the batched training / validation pipelines from the training glob.
train_data, valid_data = get_train_val_data(file_path=train_path)

In [None]:
# A one-batch view of the training pipeline. Displaying its element_spec shows the
# batch shapes and dtypes without reading any files.
img = train_data.take(1)
img.element_spec

In [None]:
# Materialise the first (image_batch, label_batch) element of the training pipeline.
images = list(train_data.take(1))

In [None]:
# images[0] is the first (image_batch, label_batch) pair. The pipeline feeds
# ResNet50-preprocessed pixels: BGR channel order and values outside [0, 1], which
# imshow would clip. Rescale to [0, 1] and flip back to RGB, for display only.
first_image = images[0][0][0].numpy()
pixel_range = first_image.max() - first_image.min()
display_image = (first_image - first_image.min()) / (pixel_range or 1.0)
plt.imshow(display_image[..., ::-1])
plt.title(images[0][1][0].numpy().decode('utf-8'))
plt.axis('off');

In [None]:
# Label (class-name tensor) of the first image in that batch.
images[0][1][0]