# Importing Required Libraries

In [None]:
import os

# Preprocessing and Plotting Libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

from sklearn.model_selection import train_test_split
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.vgg16 import preprocess_input

# Tools for training the model
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.applications import vgg16
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.callbacks import Callback, EarlyStopping

# Tools for evaluating the model
from sklearn.metrics import classification_report, confusion_matrix

# Setting Up the Environment for Distributed Training on a Kaggle TPU

The TPU initialization cell below is currently commented out.

In [None]:
# Show only TPU-related environment variables (e.g. TPU_NAME, which the TPU
# setup cell below reads). Displaying all of os.environ would write every
# variable -- potentially including tokens or credentials -- into the saved
# notebook output.
tpu_env = {key: value for key, value in os.environ.items() if 'TPU' in key.upper()}
tpu_env

In [None]:
# try:
#     tpu_address = os.environ['TPU_NAME']
#     tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu_address)
#     tf.config.experimental_connect_to_cluster(tpu)
#     tf.tpu.experimental.initialize_tpu_system(tpu)

#     strategy = tf.distribute.experimental.TPUStrategy(tpu)

#     print('Running on TPU: ', tpu.cluster_spec().as_dict()['worker'])
#     print('Number of accelerators: ', strategy.num_replicas_in_sync)

# except ValueError:
#     print('TPU failed to initialize.')    

# Loading the Data as a TensorFlow Dataset

In [None]:
# Glob matching every entry exactly three levels below Fish_Dataset/Fish_Dataset
# (the image files, including those in " GT" folders that load_dataset drops).
train_path = '../input/a-large-scale-fish-dataset/Fish_Dataset/Fish_Dataset' + '/*' * 3
# NOTE(review): unlike train_path this is a bare directory, not a glob; passed
# to load_dataset as-is it would presumably match only the directory itself --
# confirm how later cells consume it.
test_path = '../input/a-large-scale-fish-dataset/NA_Fish_Dataset'

In [None]:
# Input-pipeline constants.
IMG_SIZE = [224] * 2   # [height, width] passed to tf.image.resize
BATCH_SIZE = 128       # images per batch in get_batched_dataset
VAL_SPLIT = 0.2        # intended validation fraction (not used in the cells shown)

In [None]:

def get_class(file_path):
    """Return the name of the directory that directly contains ``file_path``.

    The path is split on ``os.path.sep`` and the second-to-last component is
    returned; this directory name is used as the image's label.

    Args:
        file_path: scalar string tensor holding a file path.

    Returns:
        Scalar string tensor with the parent directory's name.
    """
    path_parts = tf.strings.split(file_path, sep=os.path.sep)
    parent_dir = path_parts[-2]
    return parent_dir


def process_image(file_path):
    """Read, decode, resize and scale one image; return it with its label.

    Args:
        file_path: scalar string tensor with the path to an image file.

    Returns:
        ``(img, label)`` where ``img`` is a float32 tensor of shape
        ``(IMG_SIZE[0], IMG_SIZE[1], 3)`` with values scaled by 1/255, and
        ``label`` is the scalar string tensor returned by ``get_class``.
    """
    label = get_class(file_path)
    raw_bytes = tf.io.read_file(file_path)
    # channels=3 forces RGB output (grayscale/RGBA files are converted) and
    # gives the tensor a static channel dimension. Without it the shape is
    # (H, W, None), which convolution layers reject when building the model.
    # expand_animations=False keeps the result rank 3 (first frame of a GIF),
    # as tf.image.resize requires.
    img = tf.io.decode_image(raw_bytes, channels=3, expand_animations=False)
    img = tf.image.resize(img, IMG_SIZE)
    img = tf.cast(img, tf.float32) / 255.0
    return img, label


def load_dataset(file_path, shuffle=True, seed=None):
    """Build an unbatched dataset of ``(image, label)`` pairs from a glob.

    Files whose parent directory name ends in the word ``GT`` (presumably
    ground-truth masks -- confirm) are dropped before any decoding happens.

    Args:
        file_path: glob pattern passed to ``tf.data.Dataset.list_files``.
        shuffle: whether ``list_files`` shuffles the matched file names
            (default True, matching the previous behavior).
        seed: optional seed for that shuffle, for a reproducible file order.

    Returns:
        A ``tf.data.Dataset`` yielding the ``(img, label)`` tuples produced by
        ``process_image``.
    """
    image_dataset = tf.data.Dataset.list_files(file_path, shuffle=shuffle, seed=seed)
    # Keep only files whose parent directory's last space-separated word is not "GT".
    image_dataset = image_dataset.filter(
        lambda path: tf.strings.split(get_class(path), ' ')[-1] != 'GT'
    )
    # Decode images in parallel; tf.data still preserves element order by default.
    image_dataset = image_dataset.map(process_image, num_parallel_calls=tf.data.AUTOTUNE)
    return image_dataset


def get_batched_dataset(file_path, shuffle_buffer=2048, batch_size=None, cache_filename=''):
    """Build a cached, shuffled, batched and prefetched ``(image, label)`` pipeline.

    Args:
        file_path: glob pattern forwarded to ``load_dataset``.
        shuffle_buffer: size of the shuffle buffer (default 2048, as before).
        batch_size: images per batch; ``None`` (default) uses the module-level
            ``BATCH_SIZE`` read at call time, as before.
        cache_filename: forwarded to ``Dataset.cache``. The default ``''`` keeps
            every decoded float32 image in memory; pass a file path to cache on
            disk instead if that is too large.

    Returns:
        A ``tf.data.Dataset`` of batches; the final batch may be smaller than
        ``batch_size`` because ``drop_remainder=False``.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE

    dataset = load_dataset(file_path)
    # Cache the decoded images *before* shuffling. Previously cache() came last,
    # so the first epoch's shuffled batches were stored and replayed in the same
    # order every epoch, and the prefetch sat upstream of the cache.
    dataset = dataset.cache(cache_filename)
    dataset = dataset.shuffle(shuffle_buffer)  # reshuffled each epoch by default
    dataset = dataset.batch(batch_size, drop_remainder=False)
    # Prefetch last so input preparation overlaps with model execution.
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset

In [None]:
# Build the batched training pipeline. `file_path` is only a function
# parameter above; at module level the training glob is `train_path`.
dataset = get_batched_dataset(train_path)