## 1. Initialization

#### 1.1 Importing libraries

In [None]:
# Install the PyPI package that presumably provides the `visual_attention` module imported in the next cell.
!pip install visual-attention-tf

In [None]:
import numpy as np
from numpy import cov
from numpy import trace
from numpy import iscomplexobj
from numpy import asarray
from numpy.random import randint
from scipy.linalg import sqrtm
import tensorflow_addons as tfa
import cv2
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import PIL
import skimage
from skimage.exposure import is_low_contrast
from skimage.transform import rescale, SimilarityTransform, AffineTransform, rotate
# comment for kaggle/colab
##########################
# import mediapipe as mp
# import Augmentor
# from sympy import im
##########################
import tensorflow as tf
# import tensorflow_addons as tfa
# import tensorflow_gan as tfg
from tensorflow import keras
from keras.applications.inception_v3 import InceptionV3
from keras.applications.inception_v3 import preprocess_input
from keras import layers, losses
from keras.layers import Input, InputLayer,  Dense, Embedding, Conv2D, Dropout, Flatten, RandomTranslation, LeakyReLU, Reshape, Conv2DTranspose, GlobalMaxPooling2D
from keras.models import Sequential, Model
from keras import activations
from tensorflow.keras.applications import DenseNet121, resnet50, MobileNetV2, VGG19, InceptionV3
import tensorflow_hub as hub
import tensorflow_probability as tfp
from PIL import Image, ImageFont, ImageDraw, ImageColor
import os
import pathlib
import tarfile
import pathlib as pb
import pandas as pd
from collections import defaultdict
import enum
import random
import seaborn as sns
from tqdm import tqdm
import datetime
import shutil
# comment for kaggle/colab
# from rembg import remove
from sklearn.model_selection import train_test_split
from sklearn import preprocessing, dummy
import warnings
import time
from IPython.display import FileLinks, FileLink
from visual_attention import PixelAttention2D , ChannelAttention2D,EfficientChannelAttention2D


# Use seaborn's 'dark' style for the plots drawn later in the notebook.
sns.set_style('dark')
# Seed NumPy's global RNG so the np.random-based sampling below is repeatable.
# NOTE(review): only NumPy is seeded here; `random` and TensorFlow keep their own RNG state.
np.random.seed(42)
# Silence every Python warning for the rest of the session.
warnings.filterwarnings('ignore')


#### 1.2 Libraries version check

In [None]:
# Show the TensorFlow / NumPy / pandas versions and whether TensorFlow executes eagerly.
tf.__version__, tf.executing_eagerly(), np.__version__, pd.__version__


In [None]:
# Print the GPU status reported by the NVIDIA driver tool (fails on runtimes without it).
!nvidia-smi


#### 1.3 Declaring constants

In [None]:
# Set to False when running against a local checkout instead of Kaggle.
kaggle = True

# Root folder that holds one sub-folder per tea sickness class.
DATASET_DIRECTORY = (
    '../input/tea-dataset/tea sickness dataset - kaggle'
    if kaggle
    else '../Dataset/tea sickness dataset/'
)


# TF-Hub handle of the TF-GAN Inception evaluation model.
INCEPTION_TFHUB = 'https://tfhub.dev/tensorflow/tfgan/eval/inception/1'


# Directory and archive file name of each object-detection model, keyed by the
# short names used as `model_type` values.
OBJECT_DETECTION_MODEL_DICT = dict(
    ssd=dict(
        dir='./object_detction_models/mobilenet_ssd',
        name='openimages_v4_ssd_mobilenet_v2_1.tar.gz',
    ),
    rcnn=dict(
        dir='./object_detction_models/inception_resnet',
        name='faster_rcnn_openimages_v4_inception_resnet_v2_1.tar.gz',
    ),
)

# Checkpoint directory of each model, keyed by a short model name.
MODEL_CHECKPOINT_PATHS = dict(
    vgg='./checkpoints/vgg/',
    resnet='./checkpoints/resnet_50/',
    mobilenet='./checkpoints/mobilenet/',
    inception='./checkpoints/inception/',
    densenet='./checkpoints/densenet/',
    wgan='./checkpoints/wgan/',
    wgan_res='./checkpoints/wgan_res/',
)


class model_type(enum.Enum):
    """Selectable object-detection models; each value is a key of OBJECT_DETECTION_MODEL_DICT."""
    mobilenet_ssd = 'ssd'
    inception_resnet_rcnn = 'rcnn'


# Batch sizes and shuffle-buffer size.
# NOTE(review): the cells consuming these are not in this part of the notebook — confirm usage there.
BATCH_SIZE_DATASET = 16
BATCH_SIZE_TRAIN = 16
BATCH_SIZE_GAN_TRAIN = 4
BUFFER_SIZE = 1000
# Square target image size (presumably pixels), so the width/height order does not matter.
IMG_SIZE_TRAIN = (180, 180)


## 2. Reading file paths and initial analysis of images

#### 2.1 Display all the tea dataset sickness folders present

In [None]:
# Every entry directly under the dataset root is treated as one sickness class folder.
# NOTE(review): os.listdir also returns plain files and hidden entries if any exist there.
sickness_folder_lst = os.listdir(os.path.abspath(DATASET_DIRECTORY))
sickness_folder_lst


#### 2.2 Display the number of images in each sickness folder

In [None]:
# Maps each sickness folder name to the absolute paths of its images.
# Looking up an unknown key returns (and stores) the placeholder string 'NA'.
sickness_folder_dict = defaultdict(lambda: 'NA')

for sickenss in sickness_folder_lst:
    folder_path = os.path.join(DATASET_DIRECTORY, sickenss)
    # Only files matching '*.jpg' directly inside the folder are collected (no recursion).
    img_path_lst = list(pb.Path(folder_path).glob('*.jpg'))
    img_path_lst = [os.path.abspath(path) for path in img_path_lst]
    sickness_folder_dict.update({sickenss: img_path_lst})

# One row per sickness folder holding the number of image paths found for it.
image_count_df = pd.DataFrame(index=sickness_folder_dict.keys(), data=[len(
    v) for v in sickness_folder_dict.values()], columns=['image_count'])
image_count_df  # .to_clipboard()


#### 2.3 Data visualization

In [None]:
# Bar chart of the number of images per sickness folder.
image_count_df.plot.bar(color='y')
plt.show()


In [None]:
# def axis_plot(ax, _img_path, desc, fontsize=12):
#     """_summary_

#     Args:
#         ax (_type_): _description_
#         _img_path (_type_): _description_
#         desc (_type_): _description_
#         fontsize (int, optional): _description_. Defaults to 12.
#     """
#     img_array = cv2.imread(str(_img_path))

#     ax.imshow(img_array)
#     ax.tick_params(left=False, right=False, labelleft=False,
#                    labelbottom=False, bottom=False)

#     ax.locator_params(nbins=5)

#     ax.set_title(desc, fontsize=fontsize)


In [None]:
# fig, ax = plt.subplots(nrows=8, ncols=5, figsize=(20, 20))

# for idx, key_vals in enumerate(sickness_folder_dict.items()):
#     key, vals = key_vals
#     rand_idxs = np.random.choice(range(len(vals)), size=5)
#     vals = [vals[i] for i in rand_idxs]
#     for i_dx, val in enumerate(vals):

#         axis_plot(ax[idx][i_dx], val, desc=key)

# plt.tight_layout(pad=0.1, w_pad=0.1, h_pad=0.25)


## 3. Image pre-processing pipeline

In [None]:
# class HelperFunctions():

#     def __init__(self) -> None:
#         pass

#     def read_single_image(self,  _img_path='', sickness_name='', randomly=False):
#         if randomly:
#             _img_path = os.path.abspath(np.random.choice(
#                 sickness_folder_dict[sickness_name]))

#         else:
#             _img_path = os.path.abspath(_img_path)

#         return _img_path, cv2.imread(str(_img_path))

#     def resize_image(self, img_array,  new_dsize):
#         img_array = cv2.resize(img_array, dsize=new_dsize)
#         return img_array

#     def display_image(self, img_array, desc='', fontsize=14, figsize=(7, 7)):
#         """
#         Displays an image inside the notebook.
#         This is used by download_and_resize_image()
#         """
#         fig = plt.figure(figsize=figsize)
#         plt.grid(False)
#         plt.tick_params(left=False, right=False, labelleft=False,
#                         labelbottom=False, bottom=False)

#         plt.title(desc, fontsize=fontsize)
#         plt.imshow(img_array)

#     def add_denoising_nlmeans(self, img_array):
#         return cv2.fastNlMeansDenoisingColored(img_array, None, 10, 10, 7, 15)

#     def add_gaussian_blur(self, img_array):

#         return cv2.GaussianBlur(img_array, (5, 5), 0)

#     # def add_edge_detection(img_array):
#     #     return cv2.Canny(img_array, 30, 150)

#     def create_white_bg(self, shape, bg_fp='./white_bg.jpg'):
#         white_bg = np.full(shape, 255, dtype=np.uint8)
#         plt.imsave(bg_fp, white_bg)

#     def superimpose_white_bg(self, img, bg_fp='./white_bg.jpg', binary_mask_th=0.8):
#         change_background_mp = mp.solutions.selfie_segmentation

#         change_bg_segment = change_background_mp.SelfieSegmentation()
#         result = change_bg_segment.process(image=img)
#         binary_mask = result.segmentation_mask > binary_mask_th
#         binary_mask_3 = np.dstack((binary_mask, binary_mask, binary_mask))

#         output_image = np.where(binary_mask_3, img, 255)
#         bg_img = cv2.imread(bg_fp)
#         output_image = np.where(binary_mask_3, img, bg_img)
#         return output_image

#     def remove_bg(self, image):
#         # Fill the black background with white color
#         # rgb to hsv color space
#         hsv_img = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)

#         s_ch = hsv_img[:, :, 1]  # Get the saturation channel

#         # Apply threshold - pixels above 5 are going to be 255, other are zeros.
#         thesh = cv2.threshold(s_ch, 5, 255, cv2.THRESH_BINARY)[1]
#         # Apply opening morphological operation for removing artifacts.
#         thesh = cv2.morphologyEx(
#             thesh, cv2.MORPH_OPEN, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7)))

#         # Fill the background in thesh with the value 128 (pixel in the foreground stays 0.
#         cv2.floodFill(thesh, None, seedPoint=(0, 0),
#                       newVal=255, loDiff=1, upDiff=1)

#         # Set all the pixels where thesh=128 to red.
#         image[thesh == 128] = (255, 255, 255)

#         return image

#     def remove_bg2(self, image):
#         return remove(image)

#     def draw_bounding_box_on_image(self,
#                                    image,
#                                    ymin,
#                                    xmin,
#                                    ymax,
#                                    xmax,
#                                    font,
#                                    color='#FFFFFF',
#                                    thickness=1,
#                                    display_str_list=()):
#         """
#         Adds a bounding box to an image.

#         Args:
#             image -- the image object
#             ymin -- bounding box coordinate
#             xmin -- bounding box coordinate
#             ymax -- bounding box coordinate
#             xmax -- bounding box coordinate
#             color -- color for the bounding box edges
#             font -- font for class label
#             thickness -- edge thickness of the bounding box
#             display_str_list -- class labels for each object detected


#         Returns:
#             No return.  The function modifies the `image` argument 
#                         that gets passed into this function

#         """
#         draw = ImageDraw.Draw(image)
#         im_width, im_height = image.size

#         # scale the bounding box coordinates to the height and width of the image
#         (left, right, top, bottom) = (xmin * im_width, xmax * im_width,
#                                       ymin * im_height, ymax * im_height)

#         # define the four edges of the detection box
#         draw.line([(left, top), (left, bottom), (right, bottom), (right, top),
#                    (left, top)],
#                   width=thickness,
#                   fill=color)

#     def draw_boxes(self, image, boxes, class_names, scores):
#         """
#         Overlay labeled boxes on an image with formatted scores and label names.

#         Args:
#             image -- the image as a numpy array
#             boxes -- list of detection boxes
#             class_names -- list of classes for each detected object
#             scores -- numbers showing the model's confidence in detecting that object
#             max_boxes -- maximum detection boxes to overlay on the image (default is 10)
#             min_score -- minimum score required to display a bounding box

#         Returns:
#             image -- the image after detection boxes and classes are overlaid on the original image.
#         """
#         # colors = list(ImageColor.colormap.values())

#         try:
#             font = ImageFont.truetype("/usr/share/fonts/truetype/liberation/LiberationSansNarrow-Regular.ttf",
#                                       25)
#         except IOError:
#             #print("Font not found, using default font.")
#             font = ImageFont.load_default()

#             # only display detection boxes that have the minimum score or higher
#         idx = np.argmax(scores)

#         ymin, xmin, ymax, xmax = tuple(boxes[idx])

#         # "{}: {}%".format(class_names[i].decode("ascii"),int(100 * scores[i]))
#         display_str = f'Predicted class - {class_names[idx]}'

#         # color = colors[hash(class_names[idx]) % len(colors)]
#         image_pil = Image.fromarray(np.uint8(image)).convert("RGB")

#         # draw one bounding box and overlay the class labels onto the image
#         self.draw_bounding_box_on_image(image_pil,
#                                         ymin,
#                                         xmin,
#                                         ymax,
#                                         xmax,
#                                         # color,
#                                         font,
#                                         display_str_list=[display_str])
#         np.copyto(image, np.array(image_pil))

#         return image

#     def make_border(self, img_array, border_percent=0.03, display_shape=False):
#         border_type = cv2.BORDER_CONSTANT
#         img_shape = img_array.shape[:2]
#         # Initialize arguments for the filter
#         top = int(border_percent * img_array.shape[0])  # shape[0] = rows
#         bottom = top
#         left = int(border_percent * img_array.shape[1])  # shape[1] = cols
#         right = left

#         img_array = cv2.copyMakeBorder(
#             img_array, top, bottom, left, right, border_type, None, value=(255, 255, 255))
#         if display_shape:
#             print(img_shape)
#         img_array = cv2.resize(src=img_array, dsize=img_shape)

#         return img_array

#     def sharpen_image(self, img_array, sharpening_kernel):
#         # sharpening_kernel = np.array([[0, -1, 0],
#         #                       [-1, 5, -1],
#         #                       [0, -1, 0]])

#         img_array = cv2.filter2D(
#             src=img_array, ddepth=-1, kernel=sharpening_kernel)
#         return img_array

#     def increase_brightness_contrast(self, img_array):
#         alpha = 1.1  # Simple contrast control
#         beta = 1.2    # Simple brightness control
#         final_image = np.zeros(img_array.shape, img_array.dtype)
#         for y in range(img_array.shape[0]):
#             for x in range(img_array.shape[1]):
#                 for ch in range(img_array.shape[2]):
#                     final_image[y, x, ch] = np.clip(
#                         alpha*img_array[y, x, ch] + beta, 0, 255)

#         return final_image


In [None]:
# class LocalizeObjects(HelperFunctions):
#     def __init__(self, img_dsize=IMG_SIZE_TRAIN) -> None:

#         self._img_dsize = img_dsize
#         self.__model_ref = None
#         self._buffer_img_array = None
#         super().__init__()

#     def extract_model(self, file_path, dir):
#         model_file = tarfile.open(file_path)
#         model_file.extractall(dir)
#         model_file.close()

#     def load_model(self, model_type):
#         if isinstance(model_type, enum.Enum):
#             model_type = model_type.value

#         if model_type in OBJECT_DETECTION_MODEL_DICT:

#             model_dir = OBJECT_DETECTION_MODEL_DICT[model_type]['dir']
#             model_name = OBJECT_DETECTION_MODEL_DICT[model_type]['name']
#             file_path = os.path.join(model_dir, model_name)

#             if not pathlib.Path(f'{model_dir}/saved_model.pb').exists():
#                 self.extract_model(file_path, model_dir)

#             self._model_ref = hub.load(model_dir)
#             self._model_ref = self._model_ref.signatures['default']
#             print(f'Model loaded from -> {model_dir}')
#         else:
#             raise Exception(
#                 f'{model_type} is not mapped to any object detection model')

#     def read_resize_image(self, image_path, display=False):
#         try:
#             # image_path = str(os.path.abspath(image_path))
#             _, _img_array = self.read_single_image(_img_path=image_path)
#             _size_1 = _img_array.shape
#             _img_array = self.resize_image(
#                 _img_array, new_dsize=self._img_dsize)
#             self._buffer_img_array = _img_array

#             if display:
#                 print(f'Old size - {_size_1}, New size - {_img_array.shape}')
#                 self.display_image(img_array=_img_array, desc=f'Resized image')

#         except Exception as exp:
#             print(exp)

#     def predict_bbox(self):
#         _converted_img = tf.image.convert_image_dtype(
#             self._buffer_img_array, tf.float32)[tf.newaxis, ...]
#         result = self._model_ref(_converted_img)
#         result = {key: value.numpy() for key, value in result.items()}
#         #image_with_boxes = self.draw_boxes(img_array.numpy(), result["detection_boxes"],result["detection_class_entities"], result["detection_scores"])
#         return result

#     def draw_predicted_bbox(self, result_dict, is_diplay=False):
#         img_box = self.draw_boxes(self._buffer_img_array,
#                                   result_dict["detection_boxes"],
#                                   result_dict["detection_class_entities"],
#                                   result_dict["detection_scores"])
#         if is_diplay:
#             self.display_image(img_box)

#         return img_box

#     def get_object_dimensions(self, img_size, result_dict):
#         boxes = result_dict["detection_boxes"]
#         scores = result_dict["detection_scores"]
#         box = boxes[np.argmax(scores)]

#         ymin, xmin, ymax, xmax = tuple(box)
#         # idx = np.argmax(scores)
#         im_width, im_height = self._img_dsize

#         # scale the bounding box coordinates to the height and width of the image
#         (left, right, top, bottom) = (xmin * im_width, xmax * im_width,
#                                       ymin * im_height, ymax * im_height)

#         return (left, right, top, bottom)


In [None]:
# ## Create class objects for image pre-processing ##

# lo = LocalizeObjects()
# hf = HelperFunctions()


In [None]:

# lo.load_model(model_type=model_type.inception_resnet_rcnn)


### Experiments with pre-processing pipeline

In [None]:
# exp_path, orig_image = hf.read_single_image(
#     sickness_name='algal leaf', randomly=True)

# hf.display_image(orig_image, desc='original_image')


In [None]:
# print(exp_path)
# lo.read_resize_image(str(exp_path), display=True)


In [None]:
# res_dict = lo.predict_bbox()


In [None]:
# res_dict['detection_class_entities'][np.argmax(res_dict['detection_scores'])]


In [None]:
# %time
# pred_bbox_img = lo.draw_predicted_bbox(res_dict, is_diplay=True)


In [None]:
# left, right, top, bottom = tuple(round(v)
#                                  for v in lo.get_object_dimensions(0, res_dict))
# left, right, top, bottom


In [None]:
# pred_bbox_img_cropped = pred_bbox_img[top:bottom, left:right]
# pred_bbox_img_cropped = hf.resize_image(
#     pred_bbox_img_cropped, new_dsize=IMG_SIZE_TRAIN)
# hf.display_image(pred_bbox_img_cropped, desc='Cropped image')


In [None]:
# pred_bbox_img_cropped_border = hf.make_border(
#     pred_bbox_img_cropped, border_percent=0.01)
# hf.display_image(pred_bbox_img_cropped_border, desc='image with border')


In [None]:
# pred_bbox_img_cropped_border_rembg_1 = hf.remove_bg(
#     pred_bbox_img_cropped_border)
# hf.display_image(pred_bbox_img_cropped_border_rembg_1,
#                  desc='BG removed method-1')


In [None]:
# %time

# pred_bbox_img_cropped_border_rembg_2 = hf.remove_bg2(
#     pred_bbox_img_cropped_border)
# hf.display_image(pred_bbox_img_cropped_border_rembg_2,
#                  desc='BG removed method-2')


In [None]:
# sharpening_kernel = np.array([[0, -1, 0],
#                               [-1, 5, -1],
#                               [0, -1, 0]])

# image_sharp = hf.sharpen_image(
#     pred_bbox_img_cropped_border_rembg_2, sharpening_kernel)
# hf.display_image(image_sharp, desc='Shapened image')


In [None]:
# final_image = hf.increase_brightness_contrast(image_sharp)
# hf.display_image(final_image)


In [None]:
# def plot_image(ax, img_array, desc, fontsize=14):
#     ax.imshow(img_array)
#     ax.tick_params(left=False, right=False, labelleft=False,
#                    labelbottom=False, bottom=False)
#     ax.set_title(desc, fontsize=fontsize)
#     # ax.locator_params(nbins=4)


In [None]:
# fig, ax = plt.subplots(nrows=2, ncols=4, figsize=[10, 10])

# plot_image(ax[0][0], orig_image, desc='original_image')
# plot_image(ax[0][1], pred_bbox_img, desc='predicted_bbox')
# plot_image(ax[0][2], pred_bbox_img_cropped, desc='cropped_image')
# plot_image(ax[0][3], pred_bbox_img_cropped_border,
#            desc='cropped_image_with_border')
# plot_image(ax[1][0], pred_bbox_img_cropped_border_rembg_1,
#            desc='bg_removed_1')
# plot_image(ax[1][1], pred_bbox_img_cropped_border_rembg_2,
#            desc='bg_removed_2')


# plot_image(ax[1][2], image_sharp, desc='sharpened_image')
# plot_image(ax[1][3], final_image, desc='bright_contrast_image')
# plt.tight_layout(pad=0.1, w_pad=0.1, h_pad=0.25)


### Automated pre-processing pipeline run

In [None]:
# %time
# hf = HelperFunctions()
# lo = LocalizeObjects()
# lo.load_model(model_type=model_type.inception_resnet_rcnn)
# sharpening_kernel = np.array([[0, -1, 0],
#                               [-1, 5, -1],
#                               [0, -1, 0]])


In [None]:
######################
##### Legacy code ####
######################


# for disease, path_lst in sickness_folder_dict.items():

#     tmp_path = path_lst[0]
#     tmp_dir = os.path.abspath(os.path.dirname(tmp_path))
#     processed_dir = os.path.join(tmp_dir, 'processed')
#     if not os.path.exists(processed_dir):
#         os.mkdir(processed_dir)
#     else:
#         shutil.rmtree(processed_dir)
#         os.mkdir(processed_dir)

#     dt_lst = []
#     for path in tqdm(path_lst, desc=f'Pre-processing images of {disease}...'):

#         t1 = datetime.datetime.now()
#         file_name = os.path.basename(path)

#         lo.read_resize_image(path)
#         res_dict = lo.predict_bbox()
#         pred_bbox_img = lo.draw_predicted_bbox(res_dict, is_diplay=False)

#         _top, _bottom, _left, _right = tuple(
#             round(v) for v in lo.get_object_dimensions(0, res_dict))
#         # _top, _bottom, _left, _right = _top-1, _bottom+1, _left-1, _right+1

#         pred_bbox_img_cropped = pred_bbox_img[_left:_right, _top:_bottom]

#         pred_bbox_img_cropped = hf.resize_image(
#             pred_bbox_img_cropped, new_dsize=IMG_SIZE_TRAIN)

#         pred_bbox_img_cropped_border = hf.make_border(
#             pred_bbox_img_cropped, border_percent=0.01)

#         pred_bbox_img_cropped_border_rembg_2 = hf.remove_bg2(
#             pred_bbox_img_cropped_border)

#         image_sharp = hf.sharpen_image(
#             pred_bbox_img_cropped_border_rembg_2, sharpening_kernel)

#         final_image = hf.increase_brightness_contrast(image_sharp)

#         file_path = os.path.join(processed_dir, file_name)
#         cv2.imwrite(file_path, final_image)

#         dt = datetime.datetime.now() - t1
#         dt_lst.append(dt.seconds+(dt.microseconds*10e-6))

#     print(
#         f'INFO: Average processing time for each image of {disease} is {sum(dt_lst)/len(dt_lst)} seconds')


In [None]:

# def run_image_preprocessing():
#     hf = HelperFunctions()
#     lo = LocalizeObjects()
#     lo.load_model(model_type=model_type.inception_resnet_rcnn)
#     sharpening_kernel = np.array([[0, -1, 0],
#                                   [-1, 5, -1],
#                                   [0, -1, 0]])

#     for disease, path_lst in sickness_folder_dict.items():

#         tmp_path = path_lst[0]
#         tmp_dir = os.path.abspath(os.path.dirname(tmp_path))
#         processed_dir = os.path.join(tmp_dir, 'processed')
#         if not os.path.exists(processed_dir):
#             os.mkdir(processed_dir)
#         else:
#             shutil.rmtree(processed_dir)
#             os.mkdir(processed_dir)

#         dt_lst = []
#         for path in tqdm(path_lst, desc=f'Pre-processing images of {disease}...'):

#             t1 = datetime.datetime.now()
#             file_name = os.path.basename(path)

#             lo.read_resize_image(path)
#             res_dict = lo.predict_bbox()
#             pred_bbox_img = lo.draw_predicted_bbox(res_dict, is_diplay=False)

#             _top, _bottom, _left, _right = tuple(
#                 round(v) for v in lo.get_object_dimensions(0, res_dict))
#             # _top, _bottom, _left, _right = _top-1, _bottom+1, _left-1, _right+1

#             pred_bbox_img_cropped = pred_bbox_img[_left:_right, _top:_bottom]

#             pred_bbox_img_cropped = hf.resize_image(
#                 pred_bbox_img_cropped, new_dsize=IMG_SIZE_TRAIN)

#             pred_bbox_img_cropped_border = hf.make_border(
#                 pred_bbox_img_cropped, border_percent=0.01)

#             pred_bbox_img_cropped_border_rembg_2 = hf.remove_bg2(
#                 pred_bbox_img_cropped_border)

#             image_sharp = hf.sharpen_image(
#                 pred_bbox_img_cropped_border_rembg_2, sharpening_kernel)

#             final_image = hf.increase_brightness_contrast(image_sharp)

#             file_path = os.path.join(processed_dir, file_name)
#             cv2.imwrite(file_path, final_image)

#             dt = datetime.datetime.now() - t1
#             dt_lst.append(dt.seconds+(dt.microseconds*10e-6))

#         print(
#             f'INFO: Average processing time for each image of {disease} is {sum(dt_lst)/len(dt_lst)} seconds')


In [None]:
# %time
# run_image_preprocessing()


## 4. Preprocessed dataset loading pipeline

### 4.1 Load the preprocessed data

In [None]:
def get_preprocessed_fp_dict(old_sickness_folder_dict):
    """Map every raw image path to its counterpart inside a ``processed`` sub-folder.

    For each class, the ``processed`` folder is placed next to the class's
    first image, and every image keeps its original file name inside it.

    Args:
        old_sickness_folder_dict: Mapping of class name -> list of raw image
            paths (strings or path-like objects).

    Returns:
        A ``defaultdict`` (without a default factory, so it behaves like a
        plain dict) mapping each class name to a list of ``pathlib.Path``
        objects. A class with no images maps to an empty list.
    """
    sickness_folder_processed_dict = defaultdict()
    for key, values in old_sickness_folder_dict.items():
        # A class folder without images has no first element to derive the
        # directory from; keep the key with an empty list instead of crashing.
        if len(values) == 0:
            sickness_folder_processed_dict[key] = []
            continue

        # All images of one class are expected to sit in the same folder, so
        # the directory of the first one is used for the whole class.
        processed_dir = os.path.join(os.path.dirname(values[0]), 'processed')
        sickness_folder_processed_dict[key] = [
            pb.Path(os.path.join(processed_dir, os.path.basename(val)))
            for val in values
        ]

    return sickness_folder_processed_dict


In [None]:
# Derive the expected 'processed' image paths from the raw paths collected in section 2.2.
sickness_folder_processed_dict = get_preprocessed_fp_dict(sickness_folder_dict)


### 4.2 Split the image data paths into training and testing 

In [None]:

def get_train_test_fp_list(processed_sickness_folder_dict, train_split_ratio=0.8, test_split_ratio=0.2):
    """Split the image paths of every class into disjoint train and test lists.

    Each class is shuffled independently with NumPy's global RNG. The first
    ``ceil(n * train_split_ratio)`` paths go to the train split and the next
    ``floor(n * test_split_ratio)`` paths go to the test split, so no path is
    duplicated within a split and none appears in both.

    Args:
        processed_sickness_folder_dict: Mapping of class name -> iterable of
            image paths.
        train_split_ratio: Fraction of each class used for training.
        test_split_ratio: Fraction of each class used for testing. If the two
            ratios sum to more than 1, the test split is capped at the paths
            left over after the train split.

    Returns:
        Tuple ``(train_img_path_lst, test_img_path_lst)`` of flat lists that
        hold the original path objects.
    """
    train_img_path_lst = []
    test_img_path_lst = []

    # Iterate the argument, not a notebook global, so the function works on
    # whatever mapping the caller passes in.
    for values in processed_sickness_folder_dict.values():
        values_lst = list(values)
        number_of_images = len(values_lst)

        number_of_train_images = min(
            number_of_images,
            int(np.ceil(number_of_images * train_split_ratio)))
        number_of_test_images = min(
            number_of_images - number_of_train_images,
            int(np.floor(number_of_images * test_split_ratio)))

        # One permutation per class: slicing it samples without replacement
        # and keeps the two splits disjoint.
        shuffled_idx = np.random.permutation(number_of_images)
        train_idx = shuffled_idx[:number_of_train_images]
        test_idx = shuffled_idx[
            number_of_train_images:number_of_train_images + number_of_test_images]

        train_img_path_lst.extend(values_lst[i] for i in train_idx)
        test_img_path_lst.extend(values_lst[i] for i in test_idx)

    return (train_img_path_lst, test_img_path_lst)


In [None]:
# Build the train/test path lists from the processed paths and report their sizes.
train_img_path_lst, test_img_path_lst = get_train_test_fp_list(
    sickness_folder_processed_dict)

print(
    f'Train images # -> {len(train_img_path_lst)}\nTest images # -> {len(test_img_path_lst)}')


In [None]:
# ### TEST ###
# a = np.array(range(10), dtype=np.float32).reshape(-1,1)
# b = np.array(range(10), dtype=np.uint8).reshape(-1,1)

# tf.concat([a,b], axis = 1)


In [None]:
def get_img_path_and_labels_df(img_path_lst, parents_idx=1) -> pd.DataFrame:
    """Build a frame of absolute image paths, disease labels and one-hot label columns.

    Args:
        img_path_lst: Image paths (strings or path-like objects).
        parents_idx: Which ancestor folder of each path names the disease;
            0 is the containing folder, 1 the folder above it, and so on.

    Returns:
        DataFrame with a ``paths`` column (absolute path strings), a
        ``disease`` column (lower-cased folder name) and one indicator column
        per disease. All column names are lower-cased.
    """
    def _to_abs(raw_path):
        return str(os.path.abspath(raw_path))

    def _disease_of(abs_path):
        ancestors = list(pb.Path(abs_path).parents)
        return os.path.basename(ancestors[parents_idx])

    labelled = pd.DataFrame(img_path_lst, columns=['paths'])
    labelled['paths'] = labelled['paths'].apply(_to_abs)
    labelled['disease'] = labelled['paths'].apply(_disease_of)

    # The indicator columns are derived before lower-casing and then renamed
    # together with the rest of the header.
    one_hot = pd.get_dummies(labelled['disease'])
    labelled = pd.concat([labelled, one_hot], axis=1)
    labelled.columns = labelled.columns.str.lower()
    labelled['disease'] = labelled['disease'].str.lower()
    return labelled


In [None]:
# Processed paths look like <disease>/processed/<file>, so the disease folder is parents[1].
train_data_df = get_img_path_and_labels_df(train_img_path_lst, parents_idx=1)
train_data_df  # .to_clipboard()


In [None]:
# Distinct (lower-cased) disease labels present in the training frame.
train_data_df['disease'].unique()


In [None]:
# Snapshot of the training frame taken before `train_data_df` is reassigned further down.
actual_data_df = train_data_df.copy()



<br>
<br>
<br>
<br>

### 4.3 Geometrical Augmentation



In [None]:
def geometrical_augmentation(tmp_src_aug_dir, dst_aug_dir, number_of_image_to_augment):
    """Generate geometrically augmented images with an Augmentor pipeline.

    Args:
        tmp_src_aug_dir: Directory holding the source images to augment.
        dst_aug_dir: Directory that receives the augmented images.
        number_of_image_to_augment: Number of samples requested from the pipeline.

    Raises:
        ImportError: If the ``Augmentor`` package is not installed.
    """
    # The notebook-level `import Augmentor` is commented out for Kaggle/Colab,
    # so import it here: the function is then usable wherever the package is
    # installed and fails with a clear message where it is not.
    try:
        import Augmentor
    except ImportError as exc:
        raise ImportError(
            "geometrical_augmentation requires the 'Augmentor' package") from exc

    p = Augmentor.Pipeline(tmp_src_aug_dir, output_directory=dst_aug_dir)

    # Random flip with probability 0.5.
    p.flip_random(probability=0.5)

    # Add a shear operation to the pipeline
    p.shear(probability=0.2, max_shear_left=0.3, max_shear_right=0.3)

    # Add a rotate operation to the pipeline:
    p.rotate_random_90(probability=0.5)

    # Corner skew with probability 0.3.
    p.skew_corner(probability=0.3)

    # Run the pipeline and write the requested number of augmented samples.
    p.sample(number_of_image_to_augment)


In [None]:
def handle_augmentation(train_data_df, augment_split_ratio=0.6):
    """Create geometrically augmented images for every disease in the training frame.

    For each disease the training images are copied into a temporary
    ``tmp_aug`` folder next to them, augmented into a sibling ``output``
    folder (any previous ``output`` folder is deleted first), and the
    temporary folder is removed again.

    Args:
        train_data_df: DataFrame with ``disease`` and ``paths`` columns.
        augment_split_ratio: Fraction of each disease's image count to
            generate as augmented samples (rounded down).
    """
    aug_count = 0
    for disease in tqdm(train_data_df['disease'].unique(), desc='Augmenting images...'):
        df_path_disease_lst = train_data_df[train_data_df['disease']
                                            == disease]['paths'].values.tolist()
        number_of_image_to_augment = int(
            np.floor(len(df_path_disease_lst)*augment_split_ratio))
        print(
            f'For {disease}, number of images to augment is {number_of_image_to_augment}')

        # processed dir path of a sickness folder
        dir_path = os.path.dirname(df_path_disease_lst[0])

        # Temporary augmentation directory for the buffer data. Always start
        # from a fresh, existing directory: a leftover one is deleted AND
        # recreated, otherwise the copies below would have no target folder.
        tmp_src_aug_dir = os.path.join(dir_path, 'tmp_aug')
        if os.path.exists(tmp_src_aug_dir):
            shutil.rmtree(tmp_src_aug_dir)
        os.mkdir(tmp_src_aug_dir)

        try:
            # copying the train images of a disease into the temporary directory
            for img_path in df_path_disease_lst:
                tmp_path = os.path.join(
                    tmp_src_aug_dir, os.path.basename(img_path))
                shutil.copy(img_path, tmp_path)

            # augmentation destination directory
            dst_aug_dir = os.path.join(dir_path, 'output')

            # deleting the previous augmentation data from the destination folder
            if os.path.exists(dst_aug_dir):
                shutil.rmtree(dst_aug_dir)

            geometrical_augmentation(
                tmp_src_aug_dir, dst_aug_dir, number_of_image_to_augment)
        finally:
            # Remove the temporary source directory even when copying or
            # augmentation fails, so the next run starts clean.
            if os.path.exists(tmp_src_aug_dir):
                shutil.rmtree(tmp_src_aug_dir)

        aug_count += number_of_image_to_augment

    print(f'Number of images augmented is {aug_count}')


In [None]:
# # Run the image augmentation

# handle_augmentation(train_data_df)


### 4.4 Load the baseline model dataset

#### 4.4.1 Load all the augmented image paths

In [None]:
def get_augmented_image_list(train_data_df):
    """Collect the augmented ``*.jpg`` files of every disease.

    For each unique value of ``train_data_df['disease']`` the directory of the
    first path of that disease is taken, and all ``*.jpg`` files inside its
    ``output`` sub-folder are gathered.

    Args:
        train_data_df: DataFrame with ``disease`` and ``paths`` columns.

    Returns:
        list of ``pathlib.Path`` objects, one per augmented image.

    Raises:
        FileNotFoundError: if a disease has no ``*.jpg`` file in its ``output``
            folder (missing folder or empty folder).
    """
    aug_img_list_all_disease = []
    for disease in train_data_df['disease'].unique():
        df_path_disease_lst = train_data_df[train_data_df['disease']
                                            == disease]['paths'].values.tolist()
        dir_path = os.path.dirname(df_path_disease_lst[0])
        aug_dir_path = os.path.join(dir_path, 'output')
        aug_img_list = list(pb.Path(aug_dir_path).glob('*.jpg'))

        # glob() on a missing directory also yields nothing, so one check
        # covers both "no folder" and "folder without images".
        if not aug_img_list:
            raise FileNotFoundError(
                f'No augmented *.jpg images found in {aug_dir_path}')
        aug_img_list_all_disease.extend(aug_img_list)
    return aug_img_list_all_disease


In [None]:
# Gather the augmented *.jpg files of every disease present in train_data_df.
aug_img_list_all_disease = get_augmented_image_list(train_data_df)
# Cell output: total number of augmented images found.
len(aug_img_list_all_disease)


In [None]:
# Build the paths/labels frame for the augmented images.
# NOTE(review): parents_idx=2 presumably skips the extra `output` folder level
# when deriving the disease name from a path -- confirm against
# get_img_path_and_labels_df.
augmented_df = get_img_path_and_labels_df(
    aug_img_list_all_disease, parents_idx=2)
augmented_df


#### 4.4.2 Concatenate the real and geometrically augmented image paths

In [None]:
# Append the augmented rows to the real training rows.
# The cell overwrites train_data_df, so running it twice used to append the
# augmented rows twice. Rows are de-duplicated on `paths` to keep the cell
# idempotent, and ignore_index rebuilds a unique 0..n-1 index.
print(f'Before augmentation - {train_data_df.shape}')
train_data_df = pd.concat([train_data_df, augmented_df], axis=0).drop_duplicates(
    subset='paths', ignore_index=True)
print(f'After augmentation - {train_data_df.shape}')


In [None]:
# Display the training frame after the augmented rows were appended.
train_data_df


In [None]:
# Bar chart of the number of training rows per disease.
train_data_df['disease'].value_counts().plot(kind='bar', color='y')
plt.show()


#### 4.4.3 Getting X-image paths and y-label one-hot encoded vectors

In [None]:
paths_X_train = train_data_df['paths'].values.tolist()

# Preview five random entries. `high` is exclusive in np.random.randint, so
# len(...) (not len(...) - 1) is needed for the last element to be eligible.
[paths_X_train[i] for i in np.random.randint(
    low=0, high=len(paths_X_train), size=5)]


In [None]:
# Select the columns named after each disease value and take them as an array.
# NOTE(review): presumably these are the one-hot label columns -- confirm
# against get_img_path_and_labels_df.
lables_y_train = train_data_df[train_data_df['disease'].unique()].values

# Preview five random rows. `high` is exclusive in np.random.randint, so
# len(...) (not len(...) - 1) is needed for the last row to be eligible.
[lables_y_train[i] for i in np.random.randint(
    low=0, high=len(lables_y_train), size=5)]


In [None]:
# Build the paths/labels frame for the held-out test images.
# NOTE(review): parents_idx=1 here vs. 2 for the augmented images -- presumably
# because test images sit directly inside the disease folder; confirm.
test_data_df = get_img_path_and_labels_df(test_img_path_lst, parents_idx=1)
test_data_df


In [None]:
paths_X_test = test_data_df['paths'].values.tolist()


# Preview five random entries. `high` is exclusive in np.random.randint, so
# len(...) (not len(...) - 1) is needed for the last element to be eligible.
[paths_X_test[i] for i in np.random.randint(
    low=0, high=len(paths_X_test), size=5)]


In [None]:
# Select the columns named after each disease value and take them as an array.
# NOTE(review): presumably these are the one-hot label columns -- confirm
# against get_img_path_and_labels_df.
lables_y_test = test_data_df[test_data_df['disease'].unique()].values

# Preview five random rows. `high` is exclusive in np.random.randint, so
# len(...) (not len(...) - 1) is needed for the last row to be eligible.
[lables_y_test[i] for i in np.random.randint(
    low=0, high=len(lables_y_test), size=5)]


#### 4.4.4 Splitting the test set into test set and validation set

In [None]:
# Number of test images before the test/validation split below.
print(len(paths_X_test))


In [None]:
# Split the held-out data 60 % / 40 % into the final test set and a validation
# set, with a fixed seed.
# NOTE(review): the cell overwrites its own inputs (paths_X_test,
# lables_y_test), so running it a second time splits the already reduced test
# set again. Re-run the cells that define them first.
paths_X_test, paths_X_val, lables_y_test, lables_y_val = train_test_split(
    paths_X_test, lables_y_test, train_size=0.6, test_size=0.4, random_state=42)


In [None]:
# Cell output: size of the remaining test set and shape of the validation labels.
len(paths_X_test), lables_y_val.shape


### 4.5 Image dataset streaming

In [None]:
def map_func(image_path, lables):
    """Read one image from disk and return it with its labels as float32.

    Args:
        image_path: UTF-8 encoded bytes holding the image file path.
        lables: array-like object exposing ``astype`` (the label vector).

    Returns:
        Tuple ``(image, labels)``, both cast to ``np.float32``.
    """
    decoded_path = image_path.decode('utf-8')
    image = skimage.io.imread(decoded_path).astype(np.float32)
    return image, lables.astype(np.float32)


In [None]:
def gen_dataset(img_path, lables, buffer_size=BUFFER_SIZE, batch_size=BATCH_SIZE_DATASET):
    """Build a shuffled, batched and prefetched ``tf.data`` pipeline.

    Args:
        img_path: sequence of image file paths.
        lables: label array aligned with ``img_path`` along the first axis.
        buffer_size: size of the shuffle buffer.
        batch_size: number of elements per batch; the last, smaller batch is kept.

    Returns:
        ``tf.data.Dataset`` yielding ``(images, labels)`` float32 batches.
    """

    def _load(path, label):
        # map_func runs as plain Python/NumPy code, hence tf.numpy_function.
        return tf.numpy_function(map_func, [path, label], [tf.float32, tf.float32])

    # One dataset element per (path, label) pair.
    path_label_ds = tf.data.Dataset.from_tensor_slices((img_path, lables))

    # Decode the images in parallel; the degree of parallelism is auto-tuned.
    decoded_ds = path_label_ds.map(
        _load, num_parallel_calls=tf.data.experimental.AUTOTUNE)

    # Shuffle (reshuffled on every iteration), batch, then prefetch with an
    # auto-tuned buffer.
    shuffled_ds = decoded_ds.shuffle(
        buffer_size=buffer_size, reshuffle_each_iteration=True)
    batched_ds = shuffled_ds.batch(batch_size=batch_size, drop_remainder=False)

    return batched_ds.prefetch(tf.data.AUTOTUNE)


In [None]:
# train_dataset = gen_dataset(paths_X_train, lables_y_train)


In [None]:
# val_datset = gen_dataset(paths_X_val, lables_y_val)


In [None]:
# test_dataset = gen_dataset(paths_X_test, lables_y_test)


In [None]:
# sample_img_batch, sample_label_batch = next(iter(train_dataset))
# print(sample_img_batch.shape)  # (batch_size, 8*8, 2048)
# print(sample_label_batch.shape)  # (batch_size,max_len)


In [None]:

# def get_epoch_settings(paths_X_train, paths_X_val, batch_size=BATCH_SIZE_GAN_TRAIN):
#     num_train_sequences = len(paths_X_train)
#     num_val_sequences = len(paths_X_val)

#     if (num_train_sequences % batch_size) == 0:

#         steps_per_epoch = int(num_train_sequences/batch_size)
#         print(
#             f'data size is factor of batch size({batch_size}), thus step/epoch = {steps_per_epoch}')
#     else:

#         steps_per_epoch = (num_train_sequences//batch_size) + 1
#         print(
#             f'data size is not factor of batch size({batch_size}), thus step/epoch = {steps_per_epoch}')

#     if (num_val_sequences % batch_size) == 0:
#         validation_steps = int(num_val_sequences/batch_size)
#     else:
#         validation_steps = (num_val_sequences//batch_size) + 1

#     return steps_per_epoch, validation_steps


## 5. Baseline classification model

In [None]:
# steps_per_epoch, validation_steps = get_epoch_settings(
#     paths_X_train, paths_X_val, batch_size=BATCH_SIZE_TRAIN)
# steps_per_epoch, validation_steps


#### 5.1 Baseline model building

In [None]:
# ### Building model through TF sub-classing ####

# class DenseConnection(tf.keras.Model):
#     def __init__(self, kernels, activations, name) -> None:
#         super(DenseConnection, self).__init__()
#         self.dense = tf.keras.layers.Dense(
#             kernels, activation=activations, name=name)

#     def call(self, input):
#         x = self.dense(input)
#         return x


# class DenseConnectionWithDropout(DenseConnection):
#     def __init__(self, kernels, activations, name, dropout) -> None:
#         super().__init__(kernels, activations, name)
#         self.dropout = tf.keras.layers.Dropout(dropout)

#     def call(self, input):
#         x = self.dense(input)
#         x = self.dropout(x)
#         return x


# class ModelBuilder(tf.keras.Model):
#     def __init__(self, pretrained_model_ref, num_classes) -> None:
#         super(ModelBuilder, self).__init__()
#         self.pretrained_model_ref = pretrained_model_ref
#         self.global_pool = tf.keras.layers.GlobalMaxPooling2D()
#         self.flatten = tf.keras.layers.Flatten()
#         self.dense_1 = DenseConnection(64, activations='relu', name='dense_1')
#         self.dense_2 = DenseConnectionWithDropout(
#             64, activations='relu', dropout=0.3, name='dense_2')
#         self.dense_3 = DenseConnection(32, activations='relu', name='dense_3')
#         self.dense_4 = DenseConnectionWithDropout(
#             32, activations='relu', dropout=0.2, name='dense_4')
#         self.dense_5 = DenseConnectionWithDropout(
#             16, activations='relu', dropout=0.2, name='dense_5')
#         self.class_output = tf.keras.layers.Dense(
#             num_classes, activation='softmax', name='output')

#     def call(self, input):
#         for layer in self.pretrained_model_ref.layers:
#             layer.trainable = False

#         x = self.pretrained_model_ref(input)
#         x = self.global_pool(x)
#         x = self.flatten(x)
#         x = self.dense_1(x)
#         x = self.dense_2(x)
#         x = self.dense_3(x)
#         x = self.dense_4(x)
#         x = self.dense_5(x)
#         x = self.class_output(x)
#         return x


In [None]:
# class ModelHelper():
#     def __init__(self, model, LR):
#         self.model = model
#         self.adam = tf.keras.optimizers.Adam(learning_rate=LR)
#         self.sgd = tf.keras.optimizers.SGD(learning_rate=LR)
#         self.rms_prop = tf.keras.optimizers.RMSprop(learning_rate=LR)

#     def compile_model(self, optimizer_name):
#         optimzer = self.__getattribute__(optimizer_name)
#         self.model.compile(optimizer=optimzer,
#                            loss=tf.keras.losses.categorical_crossentropy,
#                            metrics=['categorical_accuracy'])

#         return self.model

#     def decrease_lr_on_plateau_callback(self, **kwargs):
#         return tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',
#                                                     factor=0.05,
#                                                     patience=10,
#                                                     min_delta=0.01,
#                                                     min_lr=1e-6,
#                                                     verbose=1)

#     def model_checkpoints_callback(self, filepath, monitor: str = "val_loss", mode: str = "auto", save_freq="epoch", save_best_only=True):
#         return tf.keras.callbacks.ModelCheckpoint(filepath,
#                                                   monitor=monitor,
#                                                   mode=mode,
#                                                   save_freq=save_freq,
#                                                   save_best_only=save_best_only)

#     def plot_model_train_info(self, model_history):
#         pd.DataFrame(model_history.history).plot(figsize=(8, 5))
#         plt.grid(True)
#         plt.title('Model train information', fontsize=14)
#         plt.xlabel('epoch #')
#         plt.show()

#     def get_model_info(self, model_history, col_info):
#         history_df = pd.DataFrame(model_history.history)
#         loss_min = history_df['loss'].min()
#         val_loss_min = history_df['val_loss'].min()
#         categorical_accuracy_max = history_df['categorical_accuracy'].max()
#         val_categorical_accuracy_max = history_df['val_categorical_accuracy'].max(
#         )

#         return pd.DataFrame([loss_min, val_loss_min, categorical_accuracy_max, val_categorical_accuracy_max],
#                             index=['loss_min', 'val_loss_min',
#                                    'categorical_accuracy_max', 'val_categorical_accuracy_max'],
#                             columns=[col_info])


In [None]:
# def dump_model(model, directory, info):
#     model.save_weights(f'{directory}/{info}.h5')


# def get_metrics(y_true, y_pred, name):
#     metrics_dict = {}

#     recall = tf.keras.metrics.Recall()
#     recall.update_state(y_true=y_true, y_pred=y_pred)
#     metrics_dict['recall'] = recall.result().numpy()

#     precision = tf.keras.metrics.Precision()
#     precision.update_state(y_true=y_true, y_pred=y_pred)
#     metrics_dict['precision'] = precision.result().numpy()

#     auc = tf.keras.metrics.AUC()
#     auc.update_state(y_true=y_true, y_pred=y_pred)
#     metrics_dict['auc'] = auc.result().numpy()

#     metrics = pd.DataFrame(data=metrics_dict.values(),
#                            index=metrics_dict.keys(), columns=[name])

#     return metrics


# def get_metrics_for_model(model, model_name):
#     t_o = time.perf_counter()
#     y_true = []
#     data = []

#     for d, v in test_dataset:
#         y_true.extend(v.numpy())
#         data.extend(d.numpy())

#     y_true = np.array(y_true)
#     data = np.array(data)
#     y_pred = model.predict(data)
#     metrics_df = get_metrics(y_true, y_pred, model_name)
#     print(time.perf_counter() - t_o)
#     return metrics_df


### 5.2 Densenet

In [None]:
# INPUT_SHAPE = (180, 180, 3)
# INPUT_SHAPE_BUILD = (None, 180, 180, 3)


In [None]:
# densenet = DenseNet121(
#     include_top=False, weights='imagenet', input_shape=INPUT_SHAPE)
# densenet_model = ModelBuilder(pretrained_model_ref=densenet, num_classes=8)


In [None]:
# model_helper = ModelHelper(model=densenet_model, LR=0.0001)
# densenet_model = model_helper.compile_model(optimizer_name='adam')
# densenet_model.build(input_shape=INPUT_SHAPE_BUILD)
# densenet_model.summary()


In [None]:
# callbacks_lst = [model_helper.decrease_lr_on_plateau_callback()]
# densenet_model_history = densenet_model.fit(train_dataset,
#                                             steps_per_epoch=steps_per_epoch,
#                                             epochs=100,
#                                             verbose=1,
#                                             callbacks=callbacks_lst,
#                                             validation_data=val_datset,
#                                             validation_steps=validation_steps,
#                                             class_weight=None,
#                                             workers=1,
#                                             initial_epoch=0)


In [None]:
# model_helper.plot_model_train_info(densenet_model_history)


In [None]:
# densenet_info = model_helper.get_model_info(densenet_model_history, 'densenet')
# densenet_info


In [None]:
# densenet_model.evaluate(test_dataset)


In [None]:
# densenet_metrics = get_metrics_for_model(
#     model=densenet_model, model_name='densenet')
# densenet_metrics


In [None]:
# #
# dump_model(densenet_model, 'G:/Learning/Degree Courses/MS AI ML/Research/Tea Sickeness Project/Code/output',
#            'densenet_model_with_aug')


### 5.3 ResNet-50

In [None]:
# resnet = resnet50.ResNet50(
#     include_top=False, weights='imagenet', input_shape=INPUT_SHAPE)
# resnet_model = ModelBuilder(pretrained_model_ref=resnet, num_classes=8)
# # densenet_model.summary()


In [None]:
# model_helper = ModelHelper(model=resnet_model, LR=0.0001)
# resnet_model = model_helper.compile_model(optimizer_name='adam')
# resnet_model.build(input_shape=INPUT_SHAPE_BUILD)
# resnet_model.summary()


In [None]:
# callbacks_lst = [model_helper.decrease_lr_on_plateau_callback()]
# resnet_model_history = resnet_model.fit(train_dataset,
#                                         steps_per_epoch=steps_per_epoch,
#                                         epochs=100,
#                                         verbose=1,
#                                         callbacks=callbacks_lst,
#                                         validation_data=val_datset,
#                                         validation_steps=validation_steps,
#                                         class_weight=None,
#                                         workers=1,
#                                         initial_epoch=0)


In [None]:
# model_helper.plot_model_train_info(resnet_model_history)


In [None]:
# resnet_info = model_helper.get_model_info(resnet_model_history, 'resnet')
# resnet_info


In [None]:
# resnet_model.evaluate(test_dataset)


In [None]:
# resnet_metrics = get_metrics_for_model(model=resnet_model, model_name='resnet')
# resnet_metrics


In [None]:
# #
# dump_model(resnet_model, 'G:/Learning/Degree Courses/MS AI ML/Research/Tea Sickeness Project/Code/output',
#            'resnet_model_with_aug')


### 5.4 MobileNet

In [None]:
# mobilenet = MobileNetV2(
#     include_top=False, weights='imagenet', input_shape=INPUT_SHAPE)
# mobilenet_model = ModelBuilder(pretrained_model_ref=mobilenet, num_classes=8)


In [None]:
# model_helper = ModelHelper(model=mobilenet_model, LR=0.0001)
# mobilenet_model = model_helper.compile_model(optimizer_name='adam')
# mobilenet_model.build(input_shape=INPUT_SHAPE_BUILD)
# mobilenet_model.summary()


In [None]:
# callbacks_lst = [model_helper.decrease_lr_on_plateau_callback()]

# mobilenet_model_history = mobilenet_model.fit(train_dataset,
#                                               steps_per_epoch=steps_per_epoch,
#                                               epochs=100,
#                                               verbose=1,
#                                               callbacks=callbacks_lst,
#                                               validation_data=val_datset,
#                                               validation_steps=validation_steps,
#                                               class_weight=None,
#                                               workers=1,
#                                               initial_epoch=0)


In [None]:
# model_helper.plot_model_train_info(mobilenet_model_history)


In [None]:
# mobilenet_info = model_helper.get_model_info(
#     mobilenet_model_history, 'mobilenet')
# mobilenet_info


In [None]:
# mobilenet_model.evaluate(test_dataset)


In [None]:
# mobilenet_metrics = get_metrics_for_model(
#     model=mobilenet_model, model_name='mobilenet')
# mobilenet_metrics


In [None]:
# #
# dump_model(mobilenet_model, 'G:/Learning/Degree Courses/MS AI ML/Research/Tea Sickeness Project/Code/output',
#            'mobilenet_model_with_aug')


### 5.6 Comparing models trained on real + geometrically augmented data

In [None]:

# all_model_info = pd.concat(
#     [densenet_info, resnet_info, mobilenet_info], axis=1)
# all_model_metrics = pd.concat(
#     [densenet_metrics, resnet_metrics, mobilenet_metrics], axis=1)
# evaulation_info = pd.concat([all_model_info, all_model_metrics], axis=0)
# evaulation_info


In [None]:
# evaulation_info.plot.bar(figsize=[7, 5], grid=True, fontsize=14)
# plt.show()


## 6 GANs

#### 6.1.1 Constants and hyperparameters for GAN

In [None]:
# global
num_channels = 3
num_classes = 8

# for conditional gan
latent_dim = 128

# for wgan & wgan-sn
# get_generator_model reshapes the noise to (batch, 8, 32), so this value
# has to stay 8 * 32 = 256 unless that reshape is changed as well.
noise_dim = 256

# Set the number of epochs for training.
wgan_epochs = 200

# NOTE(review): presumably switches for training the plain and the residual
# WGAN variants further down the notebook -- confirm where they are read.
train_wgan = True
train_wgan_res = True


In [None]:
# Input widths once the one-hot class labels are concatenated:
# generator input = latent vector + labels, discriminator input = image
# channels + one label channel per class.
generator_in_channels = latent_dim + num_classes
discriminator_in_channels = num_channels + num_classes
print(generator_in_channels, discriminator_in_channels)



#### 6.1.2 Save the model weights and outputs into zip file

In [None]:

def get_all_file_as_zip(model_history, gan_name='wgan', delete_prev=False):
    """Bundle the training history and checkpoints into a zip archive.

    Writes ``model_history.history`` as CSV into ``./output``, moves
    ``./checkpoints`` into ``./output``, zips ``./output`` into
    ``output_<gan_name>_<timestamp>.zip`` and changes the working directory to
    ``/kaggle/working``.

    Args:
        model_history: object with a ``history`` attribute accepted by
            ``pd.DataFrame`` (e.g. the return value of ``Model.fit``).
        gan_name: tag used in the CSV and archive file names.
        delete_prev: when True, remove an existing ``./output`` folder and
            ``./output.zip`` before writing.

    Returns:
        ``FileLinks('./')`` so that a notebook cell ending with this call
        renders the download links (the object used to be discarded).

    Raises:
        OSError / shutil.Error: if ``./checkpoints`` is missing, if
            ``./output/checkpoints`` already exists, or if
            ``/kaggle/working`` does not exist.
    """
    if delete_prev:
        if os.path.exists('./output'):
            shutil.rmtree('./output')

        if os.path.exists('./output.zip'):
            os.remove('./output.zip')

    # The CSV below is written into ./output regardless of delete_prev, so the
    # folder must exist in both cases (it used to be created only when
    # delete_prev was True).
    os.makedirs('./output', exist_ok=True)

    # NOTE(review): the +5:30 offset presumably converts a UTC clock to IST.
    timestamp = (datetime.datetime.now() + datetime.timedelta(hours=5,
                 minutes=30)).strftime('%d_%m_%y_%H-%M-%S')
    pd.DataFrame(model_history.history).to_csv(
        f'./output/{gan_name}_history_{timestamp}.csv')

    shutil.move('./checkpoints', './output', )

    shutil.make_archive(f'output_{gan_name}_{timestamp}', 'zip', './output')

    # NOTE(review): hard-coded Kaggle path; this raises outside Kaggle.
    os.chdir(r'/kaggle/working')
    return FileLinks('./')

### 6.2 Streamlining the data loading for gan training

In [None]:
# Use every processed image for GAN training: train ratio 1.0, test ratio 0.0,
# and the returned test list is discarded.
gan_train_img_path_lst, _ = get_train_test_fp_list(sickness_folder_processed_dict,
                                                   train_split_ratio=1.0,
                                                   test_split_ratio=0.0)

print(f'Train images # -> {len(gan_train_img_path_lst)}')


### 6.3 Load the GAN training dataset

In [None]:
# Paths/labels frame for the GAN training images (helper's default parents_idx).
gan_train_data_df = get_img_path_and_labels_df(gan_train_img_path_lst)


gan_train_data_df.head()


In [None]:
gan_paths_X_train = gan_train_data_df['paths'].values.tolist()

# Preview five random entries. `high` is exclusive in np.random.randint, so
# len(...) (not len(...) - 1) is needed for the last element to be eligible.
[gan_paths_X_train[i] for i in np.random.randint(
    low=0, high=len(gan_paths_X_train), size=5)]


In [None]:
# Select the columns named after each disease value and take them as an array.
# No float32 cast is needed here: map_func casts the labels when loading.
gan_lables_y_train = gan_train_data_df[gan_train_data_df['disease'].unique(
)].values

# Preview five random rows. `high` is exclusive in np.random.randint, so
# len(...) (not len(...) - 1) is needed for the last row to be eligible.
[gan_lables_y_train[i] for i in np.random.randint(
    low=0, high=len(gan_lables_y_train), size=5)]


In [None]:
# Streaming dataset of (image, label) batches for GAN training; the shuffle
# buffer keeps gen_dataset's default size.
gan_train_dataset = gen_dataset(
    gan_paths_X_train, gan_lables_y_train, batch_size=BATCH_SIZE_DATASET)


In [None]:
# Pull one batch to check the tensor shapes produced by the pipeline.
sample_gan_X_train_data, sample_gan_y_train_data = next(
    iter(gan_train_dataset))
sample_gan_X_train_data.shape, sample_gan_y_train_data.shape


### 6.5 Wasserstein GAN (WGAN) with Gradient Penalty (GP)


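Info: https://arxiv.org/abs/1701.07875 (Wasserstein GAN) and https://arxiv.org/abs/1704.00028 (gradient penalty, "Improved Training of Wasserstein GANs")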

#### 6.5.1 WGAN-Discriminator

In [None]:
def build_discriminator_wgan(discriminator_in_channels):
    """Build the WGAN discriminator as a Keras ``Sequential`` model.

    The network takes 180x180 inputs with ``discriminator_in_channels``
    channels, zero-pads them, applies four 5x5 stride-2 convolutions
    (64/128/256/512 filters, LeakyReLU(0.2)) with dropout after the second and
    third, and finishes with dense layers of 256, 128 and 1 units.

    Args:
        discriminator_in_channels: number of channels of the input tensor.

    Returns:
        The un-compiled ``Sequential`` model.
    """

    def _strided_conv(n_filters, **conv_kwargs):
        # 5x5 convolution that halves the resolution, LeakyReLU(0.2) activation.
        return Conv2D(filters=n_filters, kernel_size=(5, 5), strides=(2, 2),
                      use_bias=True, activation=layers.LeakyReLU(0.2), **conv_kwargs)

    critic_layers = [
        InputLayer(input_shape=(180, 180, discriminator_in_channels)),
        layers.ZeroPadding2D((2, 2)),
        # convolutional feature extractor
        _strided_conv(64, padding='same'),
        _strided_conv(128, padding='same'),
        layers.Dropout(0.3),
        _strided_conv(256, padding='same'),
        layers.Dropout(0.3),
        # last convolution keeps the Conv2D default padding
        _strided_conv(512),
        # fully connected head
        Flatten(),
        Dropout(0.2),
        Dense(256, activation='relu', name="full_con_1"),
        Dropout(0.1),
        Dense(128, activation='relu', name="full_con_2"),
        Dropout(0.1),
        Dense(1, activation=LeakyReLU(), name="output"),
    ]

    return Sequential(critic_layers, name="discrinator_wgan")

In [None]:

# Discriminator input = image channels + one label channel per class.
d_model = build_discriminator_wgan(discriminator_in_channels)
d_model.summary()


In [None]:
# Save and display the discriminator architecture diagram (d_model.png).
tf.keras.utils.plot_model(d_model, show_shapes=True,
                          to_file='d_model.png', dpi=70)


#### 6.5.2 WGAN - Attention Network

In [None]:
def generate_mobilenet_feature_extrator(slice_layer_number=53, show_summary=False):
    """Return a frozen MobileNetV2 truncated at an intermediate layer.

    Args:
        slice_layer_number: index into ``MobileNetV2.layers`` whose output
            becomes the output of the returned model.
        show_summary: print the summary of the truncated model when True.

    Returns:
        ``keras.Model`` named ``attn_model_feature_extractor`` mapping
        180x180x3 images to the selected layer's output.
    """
    backbone = MobileNetV2(include_top=False, weights='imagenet',
                           input_shape=(180, 180, 3))

    # Freeze every ImageNet layer of the backbone.
    for backbone_layer in backbone.layers:
        backbone_layer.trainable = False

    attn_model_feature_extractor = keras.Model(
        inputs=backbone.input,
        outputs=backbone.layers[slice_layer_number].output,
        name='attn_model_feature_extractor')

    if show_summary:
        attn_model_feature_extractor.summary()

    return attn_model_feature_extractor


def conv_layers(kernels):
    """Return a 3x3 ``Conv2D`` layer with ``kernels`` filters and Keras defaults otherwise."""
    return Conv2D(kernels, (3, 3))


class AttentionModel(Model):
    """Attention block producing context vectors from 180x180x3 images.

    Images go through a frozen, truncated MobileNetV2, two
    conv -> batch-norm -> pixel-attention stages and a small additive scoring
    head (``tanh(W1(.))`` followed by ``V``). The softmax-normalised scores
    weight the features, and the result is reshaped to ``(-1, 19, 32)``.

    All sub-layers, including the two ``PixelAttention2D`` layers, are created
    in ``__init__``. They used to be instantiated inside ``call``, which made
    them invisible to Keras' weight tracking (any weights they own were
    neither trained nor saved) and created fresh layers on every trace.

    Args:
        units: number of filters/units used by the convolutions, the pixel
            attention layers and the ``W1`` dense layer. The final reshape
            hard-codes 32 as last dimension.
    """

    def __init__(self, units):
        # Initialise the Keras machinery before assigning any attribute.
        super(AttentionModel, self).__init__()
        self.units = units

        # Frozen MobileNetV2 truncated at layer 53.
        self.pretrained_model = generate_mobilenet_feature_extrator(
            slice_layer_number=53)

        self.conv_layers_1 = conv_layers(self.units)
        self.bn_1 = keras.layers.BatchNormalization()
        # The conv output has `units` channels, which is what the original
        # `features.shape[-1]` evaluated to.
        self.pix_atten_1 = PixelAttention2D(self.units, name="pix-atten-1")

        self.conv_layers_2 = conv_layers(self.units)
        self.bn_2 = keras.layers.BatchNormalization()
        self.pix_atten_2 = PixelAttention2D(self.units, name="pix-atten-2")

        self.W1 = tf.keras.layers.Dense(self.units, name='1st_dense')

        # Dense layer having one neuron to hold the score of a context vector
        self.V = tf.keras.layers.Dense(1, name='score')

    def call(self, imgs):
        """Return the context vectors for a batch of images.

        Args:
            imgs: image batch accepted by the MobileNetV2 feature extractor
                (built for 180x180x3 inputs).

        Returns:
            Tensor reshaped to ``(-1, 19, 32)``.
        """
        feature_maps = self.pretrained_model(imgs)

        features = self.conv_layers_1(feature_maps)
        features = self.bn_1(features)
        features = self.pix_atten_1(features)

        features = self.conv_layers_2(features)
        features = self.bn_2(features)
        features = self.pix_atten_2(features)

        # Additive-attention style hidden representation.
        attention_hidden_layer = tf.keras.activations.tanh(self.W1(features))

        # One unnormalised score per feature position (last dimension is 1).
        score = self.V(attention_hidden_layer)

        # NOTE(review): features are still 4-D here, so axis=1 normalises over
        # the first spatial axis only -- confirm this is intended.
        attention_weights = tf.keras.activations.softmax(score, axis=1)

        # Weighted sum of the features using the transposed attention weights.
        context_vector = tf.matmul(
            attention_weights, features, transpose_a=True)

        # Hard-coded for 180x180 inputs and units=32.
        context_vector = tf.reshape(
            context_vector, [-1, 19, 32])

        return context_vector

    def summary(self, *args, **kwargs):
        """Print the summary of a functional wrapper built on a 180x180x3 input.

        Positional and keyword arguments are forwarded to
        ``keras.Model.summary`` so the override stays call-compatible with the
        base class.
        """
        x = keras.layers.Input(shape=(180, 180, 3))
        model = keras.Model(inputs=[x], outputs=self.call(x))
        return model.summary(*args, **kwargs)

In [None]:
# Instantiate the attention block with 32 units and print its layer summary.
am = AttentionModel(32)
am.summary()


#### 6.5.3 WGAN - Generator

In [None]:
def upsample_block(x,
                   filters,
                   activation,
                   kernel_size=(3, 3),
                   strides=(1, 1),
                   up_size=(2, 2),
                   padding="same",
                   use_bn=False,
                   use_sn=False,
                   use_bias=True,
                   use_dropout=False,
                   use_gn = False,
                   drop_value=0.3,
                   name=''
                   ):
    """Upsample ``x`` and apply a convolution with optional extras.

    Order of operations: ``UpSampling2D`` -> ``Conv2D`` (optionally wrapped in
    spectral normalisation) -> batch norm -> group norm -> activation ->
    dropout; every optional step is controlled by its flag.

    Args:
        x: input tensor.
        filters, kernel_size, strides, padding, use_bias: forwarded to ``Conv2D``.
        activation: callable applied after normalisation; skipped when falsy.
        up_size: upsampling factor for ``UpSampling2D``.
        use_bn / use_gn: add ``BatchNormalization`` / ``GroupNormalization``.
        use_sn: wrap the convolution in ``tfa.layers.SpectralNormalization``.
        use_dropout, drop_value: add ``Dropout(drop_value)`` at the end.
        name: prefix for the upsampling and convolution layer names.

    Returns:
        The transformed tensor.
    """
    upsampled = layers.UpSampling2D(up_size, name=name + '-upsample')(x)

    conv = layers.Conv2D(filters, kernel_size, strides=strides,
                         padding=padding, use_bias=use_bias, name=name + '-con2d')
    # if spectral normalization is enabled, wrap the convolution with it
    if use_sn:
        conv = tfa.layers.SpectralNormalization(conv)
    out = conv(upsampled)

    if use_bn:
        out = layers.BatchNormalization()(out)
    if use_gn:
        out = layers.GroupNormalization()(out)

    if activation:
        out = activation(out)
    if use_dropout:
        out = layers.Dropout(drop_value)(out)
    return out


In [None]:
def get_generator_model(add_residiual=False):
    """Build the conditional generator as a Keras functional model.

    Inputs are a noise vector of length ``noise_dim``, a label vector of
    length ``num_classes`` and a 180x180x3 image. The image is turned into
    attention vectors by ``AttentionModel``; these are multiplied with the
    reshaped noise, flattened, concatenated with the labels and projected to a
    15x15 feature map, which three ``upsample_block`` calls (factors 4, 3
    and 1) bring to 180x180 with 3 channels and a tanh activation.

    Args:
        add_residiual: when True, a Conv1D branch over the attention vectors
            is flattened and concatenated to the dense projection before the
            reshape (260 instead of 256 channels).

    Returns:
        ``keras.models.Model`` named ``generator`` with inputs
        ``[noise, labels, images]``.
    """
    noise = layers.Input(shape=(noise_dim,), name='noise')
    labels = layers.Input(shape=(num_classes,), name='lables')
    images = layers.Input(shape=(180, 180, 3), name='images')

    # AttentionModel reshapes its output to (batch, 19, 32).
    attention_vectors = AttentionModel(units=32)(images)

    if add_residiual:
        # Residual attention feed-fwd
        # Sequence length with the Keras defaults (valid padding, upsampling
        # size 2): 19 -> 17 (kernel 3) -> 34 (upsample) -> 30 (kernel 5).
        # NOTE(review): each Conv1D already applies relu, so the following
        # PReLU only sees non-negative values -- confirm this is intended.
        residual_atten_vectors = attention_vectors
        residual_atten_vectors = tf.keras.layers.Conv1D(
            filters=60, kernel_size=3, activation='relu', name='residual_conv1')(residual_atten_vectors)
        residual_atten_vectors = tf.keras.layers.PReLU(alpha_initializer= "zeros", name = "Preule_1")(residual_atten_vectors)
        residual_atten_vectors = tf.keras.layers.UpSampling1D()(residual_atten_vectors)
        residual_atten_vectors = tf.keras.layers.Conv1D(
            filters=30, kernel_size=5, activation='relu', name='residual_conv2')(residual_atten_vectors)
        residual_atten_vectors = tf.keras.layers.PReLU(alpha_initializer= "zeros", name  = "Preule_2")(residual_atten_vectors)
        residual_atten_vectors = tf.keras.layers.BatchNormalization(
            name='residual_bn')(residual_atten_vectors)

    batch_size_train = tf.shape(images)[0]

    # shape = (batch, 8, 32); requires noise_dim == 8 * 32 == 256
    random_noise_vectors = tf.reshape(noise, shape=(batch_size_train, 8, 32))

    # shape = (batch, 8, 19) -> random_noise_vectors x attention_vectors^T
    attention_vectors_with_noise = tf.matmul(
        random_noise_vectors, attention_vectors, transpose_b=True, name='noise_matmul_atten_vectors')

    # shape = (batch, 152), 152 = 8 * 19
    attention_vectors_with_noise = tf.reshape(
        attention_vectors_with_noise, shape=(batch_size_train, 152))

    # shape = (batch, 152 + num_classes)
    input_data = tf.concat([attention_vectors_with_noise, labels], axis=1)

    x = layers.Dense(15 * 15 * 256, use_bias=False)(input_data)  # 57600
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU(0.2)(x)

    if add_residiual:
        # concatenating the residual_atten_vectors
        residual_atten_vectors = Flatten(name='residual_flatten')(residual_atten_vectors)  # 900
        x = tf.concat([x, residual_atten_vectors],
                      axis=1, name='concat_residual')

        x = layers.Reshape((15, 15, 260), name='residual_reshape')(x)  # 58500 = 57600 + 900

    else:
        x = layers.Reshape((15, 15, 256))(x)  # 57600

    # Spatial size: 15 -> 60 (x4) -> 180 (x3) -> 180 (x1), ending with 3 channels.
    x = upsample_block(x, 128, layers.LeakyReLU(0.2), strides=(1, 1), up_size=(
        4, 4), use_bias=False, use_gn=True, padding="same", use_dropout=False, name='1')

    x = upsample_block(x, 64, layers.LeakyReLU(0.2), strides=(1, 1), use_bias=False,
                       use_gn=True, up_size=(3, 3), padding="same", use_dropout=False, name='2')

    x = upsample_block(x, 3, layers.Activation("tanh"), up_size=(
        1, 1), strides=(1, 1), use_bias=False, use_bn=True, name='3')

    g_model = keras.models.Model([noise, labels, images], x, name="generator")
    return g_model


In [None]:

# Generator without the residual attention branch.
g_model = get_generator_model(add_residiual=False)
g_model.summary()


In [None]:
# Save and display the generator architecture diagram (g_model.png).
tf.keras.utils.plot_model(g_model, show_shapes=True,
                          show_layer_names=True,  to_file='g_model.png', dpi=70)


#### 6.5.4 Build WGAN model 

In [None]:
class WGAN(keras.Model):
    """Conditional Wasserstein GAN with gradient penalty.

    Wraps a discriminator and a generator and implements a custom
    ``train_step``: ``discriminator_extra_steps`` discriminator updates
    (Wasserstein loss + weighted gradient penalty) followed by one generator
    update. The generator is called with ``[noise, one_hot_labels, images]``;
    the discriminator sees images concatenated with per-pixel label maps.

    Args:
        discriminator: model scoring ``(batch, H, W, 3 + num_classes)`` tensors.
        generator: model mapping ``[noise, labels, images]`` to images.
        latent_dim: length of the noise vector fed to the generator.
        discriminator_extra_steps: discriminator updates per generator update.
        gp_weight: multiplier of the gradient penalty term.
    """

    def __init__(self,discriminator,generator,latent_dim,discriminator_extra_steps=5,gp_weight=10.0,):

        super(WGAN, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.d_steps = discriminator_extra_steps
        self.gp_weight = gp_weight
        # NOTE(review): not used by any method of this class; kept because
        # other cells or saved checkpoints may rely on the attribute.
        self.attention_block = AttentionModel(units=32)
        self.preprocess_mobilenet = tf.keras.applications.mobilenet_v2.preprocess_input

    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
        """Store the optimizers and loss functions used by ``train_step``.

        ``d_loss_fn`` is called as ``d_loss_fn(real_data=..., pred_data=...)``
        and ``g_loss_fn`` as ``g_loss_fn(predictions_on_fake)``.
        """
        super(WGAN, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_loss_fn = d_loss_fn
        self.g_loss_fn = g_loss_fn

    def gradient_penalty(self, batch_size, real_images, fake_images):
        """Calculates the gradient penalty.

        The penalty is computed on samples interpolated between real and fake
        inputs and is added to the discriminator loss.

        Args:
            batch_size: number of samples in the batch (scalar tensor or int).
            real_images: real inputs as fed to the discriminator.
            fake_images: generated inputs as fed to the discriminator.

        Returns:
            Scalar tensor: mean of ``(||grad||_2 - 1)^2`` over the batch.
        """
        # Interpolation coefficient per sample. WGAN-GP samples it uniformly
        # from [0, 1] so the points lie between real and fake samples; a
        # normal draw (used before) also produces points outside that segment.
        alpha = tf.random.uniform(shape=[batch_size, 1, 1, 1], minval=0.0, maxval=1.0)
        diff = fake_images - real_images
        interpolated = real_images + alpha * diff

        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated)
            # 1. Get the discriminator output for this interpolated image.
            pred = self.discriminator(interpolated, training=True)

        # 2. Calculate the gradients w.r.t to this interpolated image.
        grads = gp_tape.gradient(pred, [interpolated])[0]
        # 3. Calculate the norm of the gradients.
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]))
        gp = tf.reduce_mean((norm - 1.0) ** 2)
        return gp

    def train_step(self, data, image_size=IMG_SIZE_TRAIN[0], num_classes=8):
        """Run one WGAN-GP training step on a batch.

        Args:
            data: tuple ``(real_images, one_hot_labels)``.
            image_size: height/width of the square training images.
            num_classes: length of each one-hot label vector.

        Returns:
            Dict with the generator loss (``g_loss``) and the last
            discriminator loss (``d_loss``).
        """

        # Unpack the data.
        real_images, one_hot_labels = data
        batch_size_train = tf.shape(real_images)[0]

        real_images = self.preprocess_mobilenet(real_images)

        # Broadcast every label vector over the image plane so it can be
        # concatenated with the images for the discriminator: channel c of
        # every pixel holds label c. The previous flat tf.repeat + reshape
        # produced the same shape but mixed class and spatial positions.
        image_one_hot_labels = tf.reshape(one_hot_labels, (-1, 1, 1, num_classes))
        image_one_hot_labels = tf.tile(image_one_hot_labels, [1, image_size, image_size, 1])

        ##################################
        ##### Discriminator Training #####
        ##################################

        for _ in range(self.d_steps):
            # shape = (batch, latent_dim)
            random_latent_vectors = tf.random.normal(shape=(batch_size_train, self.latent_dim))

            # Train the discriminator.
            with tf.GradientTape() as tape:

                # Decode the noise (guided by labels and real images) to fake images.
                generated_images_fake = self.generator([random_latent_vectors, one_hot_labels, real_images], training=True)

                # Append the label maps to both fake and real images.
                fake_image_and_labels = tf.concat([generated_images_fake, image_one_hot_labels], -1)

                real_image_and_labels = tf.concat([real_images, image_one_hot_labels], -1)

                predictions_on_fake = self.discriminator(fake_image_and_labels, training=True)
                predictions_on_real = self.discriminator(real_image_and_labels, training=True)

                d_loss_disc = self.d_loss_fn(real_data=predictions_on_real, pred_data=predictions_on_fake)
                # The penalty is taken on the label-augmented tensors, i.e.
                # exactly what the discriminator receives.
                gp = self.gradient_penalty(batch_size_train, real_image_and_labels, fake_image_and_labels)

                # Add the gradient penalty to the original discriminator loss
                d_loss = d_loss_disc + gp * self.gp_weight

            # Get gradients wrt disc loss
            d_grads = tape.gradient(d_loss, self.discriminator.trainable_weights)

            # Update discriminator weights
            self.d_optimizer.apply_gradients(zip(d_grads, self.discriminator.trainable_weights))

        ##############################
        ##### Generator Training #####
        ##############################

        # Sample random points in the latent space.
        # shape = (batch, latent_dim)
        random_latent_vectors = tf.random.normal(shape=(batch_size_train, self.latent_dim))

        # Train the generator (note that we should *not* update the weights
        # of the discriminator)!
        with tf.GradientTape() as tape:

            generated_images = self.generator([random_latent_vectors, one_hot_labels, real_images], training=True)

            fake_image_and_labels = tf.concat([generated_images, image_one_hot_labels], -1)

            predictions = self.discriminator(fake_image_and_labels, training=True)

            g_loss = self.g_loss_fn(predictions)

        # Update the generator weights
        gen_grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(gen_grads, self.generator.trainable_weights))

        return {"g_loss": g_loss,"d_loss": d_loss}


#### 6.5.5 WGAN - Model save checkpoint

In [None]:
class SaveCheckpoint(keras.callbacks.Callback):
    """Keras callback that periodically saves the trained model's weights.

    Whenever ``epoch % epoch_to_wait == 0`` (so epoch 0 included) a folder named
    ``epoch_<epoch>_<timestamp>`` is created under `filepath` and the weights
    of the model being trained are written into it as an ``.h5`` file.

    Args:
        filepath: Root directory under which the checkpoint folders are created.
        epoch_to_wait: Interval, in epochs, between two checkpoints.
    """

    def __init__(self, filepath, epoch_to_wait=5):
        # Let the Keras base class set up its own bookkeeping attributes.
        super().__init__()
        self._filepath = filepath
        self._epoch_to_wait = epoch_to_wait

    def on_epoch_end(self, epoch, logs=None):
        """Save the weights when `epoch` is a multiple of `epoch_to_wait`."""
        if epoch % self._epoch_to_wait != 0:
            return

        # Local clock shifted by +5h30; only used to label the folder.
        # NOTE(review): presumably approximates IST on a UTC host - confirm.
        timestamp = (datetime.datetime.now(
        ) + datetime.timedelta(hours=5, minutes=30)).strftime('%d_%m_%y_%H-%M-%S')
        folder_path = os.path.join(
            self._filepath, f'epoch_{epoch}_{timestamp}')
        # exist_ok avoids a crash if the same folder name is produced twice.
        os.makedirs(folder_path, exist_ok=True)
        # saving in h5 format
        self.model.save_weights(f'{folder_path}/wgan_ckp_epoch_{epoch}.h5')
        print(f'\tCheckpoint saved for epoch {epoch} at -> {folder_path}')


#### 6.5.6 WGAN - Image save callback

In [None]:
class ImageSaveCallback(keras.callbacks.Callback):
    """Keras callback that periodically dumps generator samples to disk.

    Whenever ``epoch % epoch_to_wait == 0`` one batch is drawn from a sample
    dataset, the generator of the model being trained is run on it, and each
    generated image is saved as
    ``<img_dump_dir>/<disease>/<disease>_<i>_<epoch>_<timestamp>.jpg``.

    The sample dataset is built in the constructor from the notebook-level
    names `gen_dataset`, `gan_paths_X_train` and `gan_lables_y_train`, which
    must therefore be defined before the callback is instantiated.

    Args:
        disease_lst: Class names; entry ``k`` is used for labels whose argmax is ``k``.
            NOTE(review): assumes the same order as the one-hot label columns - confirm.
        num_img: Batch size passed to `gen_dataset` for the sample dataset.
        latent_dim: Length of the random noise vector fed to the generator.
        img_dump_dir: Root directory for the saved images.
        epoch_to_wait: Interval, in epochs, between two dumps. Defaults to 5
            so that callers which omit it keep working.
    """

    def __init__(self, disease_lst, num_img, latent_dim, img_dump_dir, epoch_to_wait=5):
        # Let the Keras base class set up its own bookkeeping attributes.
        super().__init__()
        self.disease_lst = disease_lst
        self.num_img = num_img
        self.latent_dim = latent_dim
        self.img_dump_dir = img_dump_dir
        self.gan_train_dataset = gen_dataset(
            gan_paths_X_train, gan_lables_y_train, batch_size=num_img)
        self.epoch_to_wait = epoch_to_wait

    def on_epoch_end(self, epoch, logs=None):
        """Generate and save one batch of images on every `epoch_to_wait`-th epoch."""
        if epoch % self.epoch_to_wait != 0:
            return

        # Local clock shifted by +5h30; only used to label the files.
        timestamp = (datetime.datetime.now(
        ) + datetime.timedelta(hours=5, minutes=30)).strftime('%d_%m_%y_%H-%M-%S')

        # A fresh iterator is created each time, so this takes the first batch
        # the dataset yields.
        sample_gan_X_train_data, sample_gan_y_train_data = next(
            iter(self.gan_train_dataset))

        sample_gan_X_train_data = tf.keras.applications.mobilenet_v2.preprocess_input(
            sample_gan_X_train_data)

        batch_size_train = tf.shape(sample_gan_X_train_data)[0]

        random_latent_vectors = tf.random.normal(
            shape=(batch_size_train, self.latent_dim))

        generated_images = self.model.generator(
            [random_latent_vectors, sample_gan_y_train_data, sample_gan_X_train_data])
        # Undo the [-1, 1] scaling so the values are on a 0-255 scale.
        generated_images = (generated_images * 127.5) + 127.5

        for i, label in enumerate(sample_gan_y_train_data):
            disease = self.disease_lst[np.argmax(label)]
            img = keras.preprocessing.image.array_to_img(
                generated_images[i].numpy())

            disease_path = f'{self.img_dump_dir}/{disease}'
            os.makedirs(disease_path, exist_ok=True)

            img.save(
                f"{disease_path}/{disease}_{i}_{epoch}_{timestamp}.jpg")


#### 6.5.7 Training WGAN model

##### 6.5.7.1 Training - Without residual connection

In [None]:
# (learning_rate=0.0002, beta_1=0.5 are recommended)

# Generator optimizer: Adam with lr = 0.0002, beta_1 = 0.5, beta_2 = 0.9
generator_optimizer = keras.optimizers.Adam(
    learning_rate=0.0002, beta_1=0.5, beta_2=0.9)

# Discriminator optimizer: Adam with lr = 0.0002, beta_1 = 0.5, beta_2 = 0.9
# (same settings as the generator optimizer)
discriminator_optimizer = keras.optimizers.Adam(
    learning_rate=0.0002, beta_1=0.5, beta_2=0.9)


def discriminator_loss(real_data, pred_data):
    """Return mean(pred_data) - mean(real_data).

    Args:
        real_data: Discriminator outputs for the real samples.
        pred_data: Discriminator outputs for the generated samples.

    Returns:
        Scalar tensor: the mean score on generated samples minus the mean
        score on real samples.
    """
    mean_fake_score = tf.reduce_mean(pred_data)
    mean_real_score = tf.reduce_mean(real_data)
    return mean_fake_score - mean_real_score


# Define the loss functions for the generator.
def generator_loss(fake_img):
    """Return the negated mean of `fake_img`.

    Args:
        fake_img: Tensor to average; the training step visible in this
            notebook passes the discriminator's predictions on generated images.

    Returns:
        Scalar tensor equal to ``-mean(fake_img)``.
    """
    mean_score = tf.reduce_mean(fake_img)
    return -mean_score


### Callback configuration

In [None]:
# Instantiate the custom `ImageSaveCallback` Keras callback.
# Generated samples are written under ./output/wgan every 5 epochs.
cbk = ImageSaveCallback(num_img=BATCH_SIZE_GAN_TRAIN,
                        latent_dim=noise_dim,
                        disease_lst=train_data_df['disease'].unique().tolist(),
                        img_dump_dir='./output/wgan',
                        epoch_to_wait=5)

# fid = FID()
# inception_score = IS()
# Save the model weights every 5 epochs under MODEL_CHECKPOINT_PATHS['wgan'].
checkpoint = SaveCheckpoint(
    filepath=MODEL_CHECKPOINT_PATHS['wgan'],
    epoch_to_wait=5)

# Get the wgan model
# NOTE(review): discriminator_extra_steps is presumably the number of
# discriminator updates per generator update - confirm against the WGAN class.
wgan = WGAN(discriminator=d_model,
            generator=g_model,
            latent_dim=noise_dim,
            discriminator_extra_steps=5)


In [None]:
if train_wgan:

    # Compile the wgan model with the optimizers and loss functions defined above.
    wgan.compile(d_optimizer=discriminator_optimizer,
                 g_optimizer=generator_optimizer,
                 g_loss_fn=generator_loss,
                 d_loss_fn=discriminator_loss)

    # Start training; the checkpoint and image-dump callbacks run at each epoch end.
    # NOTE(review): if gan_train_dataset already yields batches, Keras documents
    # that `batch_size` should not be passed to `fit` - confirm and drop it.
    wgan_hist = wgan.fit(gan_train_dataset,
                         batch_size=BATCH_SIZE_GAN_TRAIN,
                         epochs=200,
                         callbacks=[checkpoint, cbk])


In [None]:
if train_wgan:
    # Plot the per-epoch metrics recorded by `wgan.fit` (one line per history key).
    pd.DataFrame(wgan_hist.history).plot(grid=True)


In [None]:
# `wgan_hist` only exists when the training cell actually ran, so require both flags.
if kaggle and train_wgan:
    get_all_file_as_zip(model_history=wgan_hist, gan_name='wgan')


##### 6.5.7.2 Training - With residual connection

In [None]:
# Build the generator variant with residual connections and print its layer summary.
# NOTE(review): the keyword is spelled `add_residiual`; it must match the
# signature of `get_generator_model`, which is defined elsewhere in the notebook.
g_model_with_residual = get_generator_model(add_residiual=True)
g_model_with_residual.summary()


In [None]:
# Render the residual generator's layer graph (with shapes and layer names)
# to g_model_with_residual.png.
tf.keras.utils.plot_model(g_model_with_residual, show_shapes=True,
                          show_layer_names=True,  to_file='g_model_with_residual.png', dpi=70)


In [None]:
# Instantiate the custom `ImageSaveCallback` custom callback.
cbk = ImageSaveCallback(num_img=BATCH_SIZE_GAN_TRAIN,
                        latent_dim=noise_dim,
                        disease_lst=train_data_df['disease'].unique().tolist(),
                        img_dump_dir='./output/wgan_res')

# Instantiate the custom `SaveCheckpoint` custom callback.
checkpoint = SaveCheckpoint(filepath=MODEL_CHECKPOINT_PATHS['wgan_res'],
                            epoch_to_wait=5)

# Get the wgan model
wgan_res = WGAN(discriminator=d_model,
                generator=g_model_with_residual,
                latent_dim=noise_dim,
                discriminator_extra_steps=5,
                gp_weight=20)


In [None]:
if train_wgan_res:
    # Use dedicated optimizers for this model. Reusing the instances already
    # applied to the plain WGAN would carry over their internal state, and
    # newer Keras optimizers refuse variables they were not built with.
    generator_optimizer_res = keras.optimizers.Adam(
        learning_rate=0.0002, beta_1=0.5, beta_2=0.9)
    discriminator_optimizer_res = keras.optimizers.Adam(
        learning_rate=0.0002, beta_1=0.5, beta_2=0.9)

    # Compile the wgan model
    wgan_res.compile(d_optimizer=discriminator_optimizer_res,
                     g_optimizer=generator_optimizer_res,
                     g_loss_fn=generator_loss,
                     d_loss_fn=discriminator_loss)

    # Start training
    wgan_res_hist = wgan_res.fit(
        gan_train_dataset, batch_size=BATCH_SIZE_GAN_TRAIN, epochs=150, callbacks=[checkpoint, cbk])


In [None]:
if train_wgan_res:
    # Plot the per-epoch metrics recorded by `wgan_res.fit`.
    pd.DataFrame(wgan_res_hist.history).plot(grid=True)


In [None]:
# `wgan_res_hist` only exists when the training cell actually ran, so require both flags.
if kaggle and train_wgan_res:
    get_all_file_as_zip(model_history=wgan_res_hist, gan_name='wgan_res')


## 7 Generate images for each class

*Generate 100 images for each class. The proportional ratio mixing mechanism then uses them to train a pretrained model on the following data combinations:*

    • real + synthetic images
    • real + geometrical augmented + synthetic images

### 7.1 Loading the data paths for sampling and shuffling

In [None]:
# Inspect `actual_data_df`, the frame the per-class samples below are drawn from.
actual_data_df


In [None]:
# Number of rows per disease class in `actual_data_df`.
actual_data_df.disease.value_counts()


In [None]:

def get_synthetic_img_generation_paths(df, disease, sample_size=100):
    """Return the rows of one disease plus `sample_size` resampled copies.

    Args:
        df: DataFrame with a 'disease' column.
        disease: Value of the 'disease' column to keep.
        sample_size: Number of extra rows drawn with replacement from the
            rows of that disease (fixed seed 42, so the draw is repeatable).

    Returns:
        DataFrame holding every row of `disease` followed by the extra rows;
        the original index labels are kept, so they can repeat.
    """
    is_target_disease = df['disease'] == disease
    class_rows = df.loc[is_target_disease]
    extra_rows = class_rows.sample(
        n=sample_size, replace=True, random_state=42)
    return pd.concat([class_rows, extra_rows])


In [None]:
# Build one frame per disease and concatenate them once. Growing a frame with
# `pd.concat` inside a loop is quadratic, and seeding it with an empty
# DataFrame relies on pandas ignoring empty frames when inferring dtypes.
sample_data_df = pd.concat(
    [get_synthetic_img_generation_paths(actual_data_df, disease)
     for disease in actual_data_df['disease'].unique()],
    axis=0)


In [None]:
# Rows per disease class after the resampled rows were added.
sample_data_df.disease.value_counts()  # .plot.bar(color = 'y')


In [None]:
# shuffle the data
# `sample(frac=1)` returns every row in random order. No `random_state` is
# passed, so the order depends on NumPy's global random state.

sample_data_df = sample_data_df.sample(frac=1)
sample_data_df


In [None]:
# Image paths and label matrix that drive the synthetic image generation.
# The label columns are taken in the order in which the diseases first appear
# in the shuffled frame.
# NOTE(review): confirm this column order matches the one-hot order the GAN
# was trained with; a different order would condition on the wrong class.
paths_X_train_syn_gen = sample_data_df['paths'].values.tolist()
lables_y_train_syn_gen = sample_data_df[sample_data_df['disease'].unique(
)].values


In [None]:
# Peek at the first five image paths.
paths_X_train_syn_gen[:5]


### 7.2 Generate synthetic images

In [None]:
class SyntheticImageGenerator():
    """Generate synthetic images with a trained GAN and write them to disk.

    For every (image path, label) pair the real image is read, passed to the
    model's generator together with random noise and the label, and the result
    is saved under ``<data_dump_dir>/<disease>/generated/``.

    Args:
        model_ref: Model exposing `load_weights` and a callable `generator`
            that accepts ``[noise, labels, real_images]``. When falsy, no
            weights are loaded.
        checkpoint_path: Weights file handed to `model_ref.load_weights`.
        data_dump_dir: Root directory for the generated images.
        latent_dim: Length of the noise vector. When None, `model_ref.latent_dim`
            is used if the model has it, otherwise 256 (the previous fixed value).
    """

    def __init__(self, model_ref, checkpoint_path, data_dump_dir, latent_dim=None):
        self._model_ref = model_ref
        self._checkpoint_path = checkpoint_path
        self._data_dump_dir = data_dump_dir
        # Currently never populated; kept for compatibility.
        self._image_path_dict = {}

        if latent_dim is None:
            latent_dim = getattr(model_ref, 'latent_dim', 256)
        self._latent_dim = latent_dim

        if self._model_ref:
            self.load_model_weights()
            print(
                f'INFO: Model loaded from from checkpoint -> {self._checkpoint_path}')

    def load_model_weights(self):
        """Load the checkpoint into the wrapped model.

        Raises:
            Exception: If loading fails; the original error is attached as
                the cause so the real reason stays visible in the traceback.
        """
        try:
            # NOTE(review): `built` is forced to True presumably so weights can
            # be loaded before the model has been called once - confirm.
            self._model_ref.built = True
            self._model_ref.load_weights(self._checkpoint_path)
        except Exception as err:
            raise Exception(
                f'Model load error!!!\nCheckpoint path: {self._checkpoint_path}') from err

    def generate_images(self, _noise, _real_data, _labels,):
        """Run the generator on ``[_noise, _labels, _real_data]`` and return its output."""
        _fake_data = self._model_ref.generator([_noise, _labels, _real_data])

        return _fake_data

    def map_func_for_generation(self, image_path, lables):
        """Read the image at `image_path` and return it with `lables`, both as float32.

        NOTE(review): the pixels are returned as read (no rescaling), whereas
        `ImageSaveCallback` applies `mobilenet_v2.preprocess_input` before
        calling the generator - confirm which input range the generator expects.
        """
        img_tensor = skimage.io.imread(image_path)  # .decode('utf-8'))
        img_tensor = img_tensor.astype(np.float32)
        lables = lables.astype(np.float32)

        return img_tensor, lables

    @staticmethod
    def _next_free_path(path):
        """Return `path`, appending '_' to its stem until no such file exists."""
        stem, ext = os.path.splitext(path)
        while os.path.exists(path):
            stem += '_'
            path = f'{stem}{ext}'
        return path

    def handle_image_generation_and_dump(self, img_path, label, disease_lst):
        """Generate one synthetic image from a real image and save it.

        Args:
            img_path: Path of the real image; its file name is reused for the output.
            label: 1-D one-hot label of the image.
            disease_lst: Class names indexed by the label's argmax.
        """
        img_name = os.path.basename(img_path)
        X, y = self.map_func_for_generation(image_path=img_path, lables=label)

        # Add the batch dimension expected by the generator.
        X = np.expand_dims(X, axis=0)
        y = y.reshape(1, -1)

        noise = tf.random.normal(shape=(1, self._latent_dim))
        fake_data = self.generate_images(_noise=noise, _real_data=X, _labels=y)

        for i, data in enumerate(fake_data):
            disease = disease_lst[np.argmax(y[i])]

            # Map the generator output from [-1, 1] to a 0-255 scale.
            img_data = (data.numpy() * 127.5) + 127.5
            img_data = tf.keras.preprocessing.image.array_to_img(img_data)

            disease_path = f'{self._data_dump_dir}/{disease}/generated'
            os.makedirs(disease_path, exist_ok=True)

            # The same source image can be sampled several times; never overwrite.
            out_path = self._next_free_path(f"{disease_path}/{img_name}")
            img_data.save(out_path)

    def handle_synthetic_image_generation(self, disease_lst, paths_X_train_syn_gen, lables_y_train_syn_gen):
        """Generate and save one synthetic image per (path, label) pair.

        Args:
            disease_lst: Class names indexed by each label's argmax.
            paths_X_train_syn_gen: Iterable of real image paths.
            lables_y_train_syn_gen: Iterable of one-hot labels, aligned with the paths.
        """
        t1 = time.perf_counter()

        # Give the progress bar a total when both inputs have a length.
        try:
            total = min(len(paths_X_train_syn_gen), len(lables_y_train_syn_gen))
        except TypeError:
            total = None

        for img_path, label in tqdm(zip(paths_X_train_syn_gen, lables_y_train_syn_gen),
                                    desc='Generating synthetic images...', total=total):

            self.handle_image_generation_and_dump(
                img_path=img_path, label=label, disease_lst=disease_lst)

        print(f'\nDumped images are in -> {self._data_dump_dir}')
        print(f'Elapsed time : {time.perf_counter() - t1}')


#### 7.2.1 Using wgan

In [None]:
# Wrap the plain WGAN and restore its weights from a saved HDF5 checkpoint.
# NOTE(review): hard-coded checkpoint path (epoch 98); it must exist on the
# machine running the notebook and is not produced by the 5-epoch interval
# configured above - confirm which run it comes from.
syn_img_gen = SyntheticImageGenerator(wgan,
                                      checkpoint_path='./checkpoints/vanilaa as h5 fromat/checkpoints/wgan/epoch_98_12_11_22_21-57-02/wgan_ckp_epoch_98.h5',
                                      data_dump_dir='./tmp/dump_vanilla')


In [None]:
# Generate one synthetic image per sampled path with the plain WGAN.
# `disease_lst` uses the same `unique()` order as the label columns built above.
syn_img_gen.handle_synthetic_image_generation(disease_lst=sample_data_df['disease'].unique().tolist(),
                                              paths_X_train_syn_gen=paths_X_train_syn_gen,
                                              lables_y_train_syn_gen=lables_y_train_syn_gen)


#### 7.2.2 Using wgan with residual

In [None]:
# Wrap the residual WGAN and restore its weights from a saved HDF5 checkpoint.
# NOTE(review): hard-coded checkpoint path (epoch 190), while the training cell
# above runs 150 epochs - confirm which run produced this file.
syn_img_gen = SyntheticImageGenerator(wgan_res,
                                      checkpoint_path='./checkpoints/wgan_res/checkpoints/wgan/epoch_190_13_11_22_01-48-35/wgan_ckp_epoch_190.h5',
                                      data_dump_dir='./tmp/dump_residual')


In [None]:
# Generate one synthetic image per sampled path with the residual WGAN.
syn_img_gen.handle_synthetic_image_generation(disease_lst=sample_data_df['disease'].unique().tolist(),
                                              paths_X_train_syn_gen=paths_X_train_syn_gen,
                                              lables_y_train_syn_gen=lables_y_train_syn_gen)


###  7.3 GAN Metrics 

In [None]:
import metrics_fid
import metrics_inception_score


In [None]:
def copy_real_images(real_img_path_list, dst_directory='./tmp/dump_real_image'):
    """Copy real images into per-folder subdirectories of `dst_directory`.

    An existing `dst_directory` is removed first. Each file is copied to
    ``<dst_directory>/<name of the source file's grandparent folder>/``; when
    a file of the same name is already there, underscores are appended to the
    stem until the name is free.

    Args:
        real_img_path_list: Iterable of source image paths.
        dst_directory: Root directory receiving the copies.
    """
    # os.walk yields at least one entry only when the directory can be walked.
    if next(os.walk(dst_directory), None) is not None:
        print(f'Found old files in {dst_directory}, all deleted')
        shutil.rmtree(dst_directory)

    for src_path in tqdm(real_img_path_list, desc='Copying files...'):
        file_name = os.path.basename(src_path)
        # parents[1] is the grandparent directory of the file.
        # NOTE(review): assumes the layout <disease>/<subfolder>/<image> - confirm.
        group_name = os.path.basename(pb.Path(src_path).parents[1])

        group_dir = f'{dst_directory}/{group_name}'
        if not os.path.exists(group_dir):
            os.makedirs(group_dir)

        target_path = os.path.join(group_dir, file_name)
        stem, ext = os.path.splitext(target_path)
        while os.path.exists(target_path):
            stem += '_'
            target_path = f'{stem}{ext}'

        shutil.copy(src=src_path, dst=target_path)


In [None]:
# Copy the sampled real images to the default ./tmp/dump_real_image directory.
copy_real_images(paths_X_train_syn_gen)


In [None]:
def load_img_as_np_array(directory, img_shape=(180, 180, 3)):
    """Load every ``*.jpg`` under `directory` into one channels-first uint8 array.

    Args:
        directory: Root directory, searched recursively for ``*.jpg`` files.
        img_shape: (height, width, channels) every image must have.

    Returns:
        uint8 array of shape ``(n, channels, height, width)``.
    """
    t1 = time.perf_counter()
    path_lst = list(pb.Path(directory).rglob('*.jpg'))
    buffer_arr = np.zeros([len(path_lst), *img_shape], dtype='uint8')
    for idx, path in enumerate(path_lst):
        # Assignment casts to the buffer's uint8 dtype; no intermediate cast needed.
        buffer_arr[idx] = skimage.io.imread(path)

    # [n,h,w,c] -> [n,c,h,w]: move channels to axis 1, keep height before width.
    # (The former [0, 3, 2, 1] also swapped height and width.)
    buffer_arr = buffer_arr.transpose([0, 3, 1, 2])
    print(f'Shape -> {buffer_arr.shape}')
    print(time.perf_counter() - t1)
    return buffer_arr


In [None]:
# Synthetic images of the plain WGAN as one array
synthetic_dataset_vanilla = load_img_as_np_array('./tmp/dump_vanilla')

# Synthetic images of the residual WGAN as one array
synthetic_dataset_residual = load_img_as_np_array('./tmp/dump_residual')

# Copied real images as one array (reference set for the metrics below)
real_dataset = load_img_as_np_array('./tmp/dump_real_image')


##### 7.3.1 FID

Reference : https://github.com/tsc2017/Frechet-Inception-Distance

In [None]:
# FID between the plain-WGAN images and the real images.
fid_score_vanilla = metrics_fid.get_fid(
    images1=synthetic_dataset_vanilla, images2=real_dataset)


In [None]:
# FID between the residual-WGAN images and the real images.
fid_score_residual = metrics_fid.get_fid(
    images1=synthetic_dataset_residual, images2=real_dataset)


##### 7.3.2 Inception Score

Reference : https://github.com/tsc2017/Inception-Score

In [None]:
# Inception Score of the plain-WGAN images.
metrics_inception_score.get_inception_score(images=synthetic_dataset_vanilla)


In [None]:
# Inception Score of the residual-WGAN images.
metrics_inception_score.get_inception_score(images=synthetic_dataset_residual)


## 8 Load images for final classification model training

### 8.1 Generate the synthetic images in the train directory

In [None]:
# Restore the plain WGAN again, this time dumping into the dataset directory
# so the images land in <dataset>/<disease>/generated.
# NOTE(review): this checkpoint path differs from the one used in section
# 7.2.1 for the same epoch-98 file - confirm which location is correct.
syn_img_gen_final = SyntheticImageGenerator(wgan,
                                            checkpoint_path='./checkpoints/wgan/checkpoints/wgan/epoch_98_12_11_22_21-57-02/wgan_ckp_epoch_98.h5',
                                            data_dump_dir='../Dataset/tea sickness dataset')


In [None]:
# Generate the synthetic images that the classification experiments will mix in.
syn_img_gen_final.handle_synthetic_image_generation(disease_lst=sample_data_df['disease'].unique().tolist(),
                                                    paths_X_train_syn_gen=paths_X_train_syn_gen,
                                                    lables_y_train_syn_gen=lables_y_train_syn_gen)


In [None]:
def get_synthetic_image_paths(directory, disease_list):
    """Collect the ``*.jpg`` files in each disease's ``generated`` folder.

    Args:
        directory: Root directory containing one folder per disease.
        disease_list: Disease folder names to scan, in order.

    Returns:
        List of `pathlib.Path` objects found directly inside
        ``<directory>/<disease>/generated``, grouped in `disease_list` order.
    """
    collected = []
    for disease in disease_list:
        generated_dir = pb.Path(f'{directory}/{disease}/generated')
        collected += generated_dir.glob('*.jpg')
    return collected


In [None]:
# Paths of all generated .jpg images, grouped by disease.
synthtic_images_lst = get_synthetic_image_paths(
    '../Dataset/tea sickness dataset', actual_data_df['disease'].unique().tolist())


### 8.2 Build a dataframe of the synthetic image paths and their one-hot-encoded (OHE) labels

In [None]:
# Build the dataframe for the synthetic images with the notebook's
# `get_img_path_and_labels_df` helper and display it.
synthetic_df = get_img_path_and_labels_df(synthtic_images_lst)
synthetic_df


In [None]:
def get_proportional_random_mixing(actual_data, synthetic_data, synthetic_percent=20, random_state=None):
    """Append a random share of the synthetic rows to the actual rows.

    Args:
        actual_data: DataFrame that is kept in full.
        synthetic_data: DataFrame the random sample is drawn from.
        synthetic_percent: Percentage (0-100) of `synthetic_data` rows to add;
            the row count is rounded up.
        random_state: Optional seed forwarded to `DataFrame.sample` to make
            the draw reproducible.

    Returns:
        DataFrame with all rows of `actual_data` followed by the sampled
        synthetic rows.

    Raises:
        ValueError: Raised by pandas when more rows are requested than
            `synthetic_data` holds (percentage above 100).
    """
    print(
        f'Shape of actual data -> {actual_data.shape}\nShape of synthetic data -> {synthetic_data.shape}')
    # Multiply before dividing: `rows * (percent / 100)` can land just above an
    # integer (100 * 0.07 == 7.000000000000001) and make ceil add one extra row.
    num_rows_synthetic = int(
        np.ceil(synthetic_data.shape[0] * synthetic_percent / 100))
    print(f'Synthetic sample size -> {num_rows_synthetic}')

    df_synthtic_sample = synthetic_data.sample(
        n=num_rows_synthetic, random_state=random_state)

    merged_df = pd.concat([actual_data, df_synthtic_sample], axis=0)
    print(f'After merging shape of actual data -> {merged_df.shape}')
    return merged_df


## 9 Final classification model training

In [None]:

def get_proportional_mixed_data(merged_df):
    """Turn a mixed (real + synthetic) dataframe into a training dataset.

    Args:
        merged_df: DataFrame with a 'paths' column, a 'disease' column and one
            column per disease name (used as the label matrix).

    Returns:
        Tuple ``(dataset, paths)``: the result of the notebook's `gen_dataset`
        for the paths and labels, and the list of image paths.
    """
    # Label columns are taken in order of first appearance in 'disease'.
    label_columns = merged_df['disease'].unique()
    image_paths = merged_df['paths'].values.tolist()
    label_matrix = merged_df[label_columns].values
    return gen_dataset(image_paths, label_matrix), image_paths


In [None]:
def get_resnet(num_classes=8, input_shape=None):
    """Build a classifier on top of an ImageNet-pretrained ResNet50 backbone.

    Args:
        num_classes: Number of output classes passed to `ModelBuilder`
            (8 by default, the previously hard-coded value).
        input_shape: Input shape of the backbone; when None the notebook-level
            `INPUT_SHAPE` constant is used.

    Returns:
        The `ModelBuilder` instance wrapping the headless ResNet50.
    """
    if input_shape is None:
        # Looked up at call time, so the constant may be defined in a later cell.
        input_shape = INPUT_SHAPE
    resnet = resnet50.ResNet50(
        include_top=False, weights='imagenet', input_shape=input_shape)
    return ModelBuilder(pretrained_model_ref=resnet, num_classes=num_classes)


In [None]:
# Shape of one input image (height, width, channels), used for the backbone.
INPUT_SHAPE = (180, 180, 3)
# Same shape with a leading None batch dimension, used for `model.build`.
INPUT_SHAPE_BUILD = (None, 180, 180, 3)


### 9.1 Train model with real + 20% synthetic data

In [None]:
# Real rows plus a random 20% of the synthetic rows.
merged_df = get_proportional_random_mixing(
    actual_data_df, synthetic_df, synthetic_percent=20)
merged_df.head()


In [None]:
# Fresh ResNet50-based classifier, compiled with the 'adam' optimizer option
# (LR 0.0002) through ModelHelper.
resnet_model_final = get_resnet()
model_helper = ModelHelper(model=resnet_model_final, LR=0.0002)

resnet_model_with_syn_20 = model_helper.compile_model(optimizer_name='adam')
# Build for batches of 180x180x3 images so the summary can be printed.
resnet_model_with_syn_20.build(input_shape=INPUT_SHAPE_BUILD)
resnet_model_with_syn_20.summary()


In [None]:
# Training dataset and path list for the 20% mix.
train_dataset, paths_X_train = get_proportional_mixed_data(merged_df)


In [None]:
# Steps per epoch / validation steps from the notebook's `get_epoch_settings` helper.
steps_per_epoch, validation_steps = get_epoch_settings(
    paths_X_train, paths_X_val, batch_size=BATCH_SIZE_TRAIN)
steps_per_epoch, validation_steps


In [None]:

# Train for 100 epochs on the 20% mix, validating on `val_datset`; the only
# callback is the one returned by `decrease_lr_on_plateau_callback()`.
resnet_model_with_syn_20_history = resnet_model_with_syn_20.fit(train_dataset,
                                                                steps_per_epoch=steps_per_epoch,
                                                                epochs=100,
                                                                verbose=1,
                                                                callbacks=[
                                                                    model_helper.decrease_lr_on_plateau_callback()],
                                                                validation_data=val_datset,
                                                                validation_steps=validation_steps,
                                                                class_weight=None,
                                                                workers=1,
                                                                initial_epoch=0)


In [None]:
# Plot the training history of the 20% model.
model_helper.plot_model_train_info(resnet_model_with_syn_20_history)


In [None]:
# Evaluate the 20% model on the test dataset.
resnet_model_with_syn_20.evaluate(test_dataset)


In [None]:
# Metrics for the 20% model via the notebook's `get_metrics_for_model` helper.
get_metrics_for_model(resnet_model_with_syn_20, 'resnet_model_with_syn_20')


In [None]:
# Persist the 20% model with the notebook's `dump_model` helper.
# NOTE(review): absolute, machine-specific output path; consider a configurable directory.
dump_model(resnet_model_with_syn_20,
           'G:/Learning/Degree Courses/MS AI ML/Research/Tea Sickeness Project/Code/output', 'resnet_model_with_syn_20')


### 9.2 Train model with real + 40% synthetic data

In [None]:
# Real rows plus a random 40% of the synthetic rows.
merged_df = get_proportional_random_mixing(
    actual_data_df, synthetic_df, synthetic_percent=40)
merged_df.head()


In [None]:
# Training dataset and path list for the 40% mix.
train_dataset, paths_X_train = get_proportional_mixed_data(merged_df)


In [None]:
# Steps per epoch / validation steps for the 40% mix.
steps_per_epoch, validation_steps = get_epoch_settings(
    paths_X_train, paths_X_val, batch_size=BATCH_SIZE_TRAIN)
steps_per_epoch, validation_steps


In [None]:
# Fresh ResNet50-based classifier for the 40% experiment (same setup as for 20%).
resnet_model_final = get_resnet()
model_helper = ModelHelper(model=resnet_model_final, LR=0.0002)

resnet_model_with_syn_40 = model_helper.compile_model(optimizer_name='adam')
resnet_model_with_syn_40.build(input_shape=INPUT_SHAPE_BUILD)
resnet_model_with_syn_40.summary()


In [None]:
# Train for 100 epochs on the 40% mix, validating on `val_datset`.
resnet_model_with_syn_40_history = resnet_model_with_syn_40.fit(train_dataset,
                                                                steps_per_epoch=steps_per_epoch,
                                                                epochs=100,
                                                                verbose=1,
                                                                callbacks=[
                                                                    model_helper.decrease_lr_on_plateau_callback()],
                                                                validation_data=val_datset,
                                                                validation_steps=validation_steps,
                                                                class_weight=None,
                                                                workers=1,
                                                                initial_epoch=0)


In [None]:
# Evaluate the 40% model on the test dataset.
resnet_model_with_syn_40.evaluate(test_dataset)


In [None]:
# Plot the training history of the 40% model.
model_helper.plot_model_train_info(resnet_model_with_syn_40_history)


In [None]:
# NOTE(review): repeats the evaluation cell two cells above; one of the two can be removed.
resnet_model_with_syn_40.evaluate(test_dataset)


In [None]:
# Metrics for the 40% model via the notebook's `get_metrics_for_model` helper.
get_metrics_for_model(resnet_model_with_syn_40, 'resnet_model_with_syn_40')


In [None]:
# Persist the 40% model with the notebook's `dump_model` helper.
# NOTE(review): absolute, machine-specific output path; consider a configurable directory.
dump_model(resnet_model_with_syn_40,
           'G:/Learning/Degree Courses/MS AI ML/Research/Tea Sickeness Project/Code/output', 'resnet_model_with_syn_40')


### 9.3 Train model with real + 100% synthetic data

In [None]:
# Real rows plus all synthetic rows (100%), the latter in random order.
merged_df = get_proportional_random_mixing(
    actual_data_df, synthetic_df, synthetic_percent=100)
merged_df.head()


In [None]:
# Fresh ResNet50-based classifier for the 100% experiment (same setup as before).
resnet_model_final = get_resnet()
model_helper = ModelHelper(model=resnet_model_final, LR=0.0002)

resnet_model_with_syn_100 = model_helper.compile_model(optimizer_name='adam')
resnet_model_with_syn_100.build(input_shape=INPUT_SHAPE_BUILD)
resnet_model_with_syn_100.summary()


In [None]:
# Training dataset and path list for the 100% mix.
train_dataset, paths_X_train = get_proportional_mixed_data(merged_df)


In [None]:
# Steps per epoch / validation steps for the 100% mix.
steps_per_epoch, validation_steps = get_epoch_settings(
    paths_X_train, paths_X_val, batch_size=BATCH_SIZE_TRAIN)
steps_per_epoch, validation_steps


In [None]:
# Train for 100 epochs on the 100% mix, validating on `val_datset`.
resnet_model_with_syn_100_history = resnet_model_with_syn_100.fit(train_dataset,
                                                                  steps_per_epoch=steps_per_epoch,
                                                                  epochs=100,
                                                                  verbose=1,
                                                                  callbacks=[
                                                                      model_helper.decrease_lr_on_plateau_callback()],
                                                                  validation_data=val_datset,
                                                                  validation_steps=validation_steps,
                                                                  class_weight=None,
                                                                  workers=1,
                                                                  initial_epoch=0)


In [None]:
# Evaluate the 100% model on the test dataset.
resnet_model_with_syn_100.evaluate(test_dataset)


In [None]:
# Plot the training history of the 100% model.
model_helper.plot_model_train_info(resnet_model_with_syn_100_history)


In [None]:
# Metrics for the 100% model via the notebook's `get_metrics_for_model` helper.
get_metrics_for_model(resnet_model_with_syn_100, 'resnet_model_with_syn_100')


In [None]:
# Persist the 100% model with the notebook's `dump_model` helper.
# NOTE(review): absolute, machine-specific output path; consider a configurable directory.
dump_model(resnet_model_with_syn_100,
           'G:/Learning/Degree Courses/MS AI ML/Research/Tea Sickeness Project/Code/output', 'resnet_model_with_syn_100')
