In [4]:
#imports for automation
from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.common.keys import Keys

#import for web scrapping
from bs4 import BeautifulSoup as bs
import requests
from os.path import basename
import base64

#import for image visualization
import matplotlib.pyplot as plt
import matplotlib.image as img

#import for image augmentation
import tensorflow.image as tfimg

#import for model generation
import tensorflow as tf
from keras.models import Sequential
import keras.callbacks
import keras.metrics as metrics


#extra imports
import traceback
import time
import os
import numpy as np

# Data Collection

This section scrapes images from Google Images and saves the results for each search topic into its own folder.

In [None]:
class data_accumulator:
    """Scrape Google Images results for one search topic into a local folder.

    A Chrome session is opened on construction; ``image_download`` runs the
    search, saves the result thumbnails and then shuts the browser down, so
    each instance downloads once.
    """

    def __init__(self, topic, folder, wait_seconds=10):
        """Open Chrome on the Google home page.

        Args:
            topic: Search query typed into Google.
            folder: Output directory, resolved against the working directory
                when ``image_download`` runs (created if missing).
            wait_seconds: Implicit wait applied to every element lookup, so
                lookups made right after a search/click tolerate slow loads.
        """
        self.topic = topic
        self.folder = folder
        self.driver = webdriver.Chrome()
        self.driver.implicitly_wait(wait_seconds)
        self.driver.get('https://www.google.com')

    def topic_image(self):
        """Search for ``self.topic``, open the Images tab and return the results grid.

        Returns:
            The element with class ``islrc`` that contains the result
            thumbnails, or None if any step fails (the traceback is printed).
        """
        try:
            search = self.driver.find_element(By.CLASS_NAME, 'gLFyf')
            search.send_keys(self.topic)
            search.send_keys(Keys.RETURN)

            self.driver.find_element(By.LINK_TEXT, 'Images').click()

            return self.driver.find_element(By.CLASS_NAME, 'islrc')
        except Exception:
            # Best effort: report and let the caller handle None. A bare
            # ``except:`` would also swallow KeyboardInterrupt.
            traceback.print_exc()
            return None

    @staticmethod
    def _save_image(src, path, timeout=10):
        """Write the image referenced by an ``<img src>`` value to ``path``.

        Inline ``data:image`` URIs are base64-decoded from the text after the
        first comma; any other ``src`` is fetched with ``requests``.

        Raises:
            requests.RequestException if the download fails or returns an
            HTTP error status; binascii.Error on a malformed base64 payload.
        """
        if "data:image" in src:
            payload = base64.b64decode(src.split(',', 1)[1])
        else:
            response = requests.get(src, timeout=timeout)
            # Don't save an HTML error page under an image file name.
            response.raise_for_status()
            payload = response.content
        with open(path, "wb") as f:
            f.write(payload)

    def image_download(self):
        """Save every result thumbnail for ``self.topic`` into ``self.folder``.

        ``<img>`` tags that have ``src``, ``height`` and ``width`` attributes
        are written as ``data{i}.png`` (inline ``data:image`` sources) or
        ``file{i}.png`` (downloaded URLs), where ``i`` numbers the tags from 1.
        A failure on one image is reported and skipped. The browser is always
        shut down before returning.

        NOTE(review): files are always named ``.png``, whatever encoding the
        source image actually uses.

        Returns:
            The number of images written.
        """
        out_dir = os.path.join(os.getcwd(), self.folder)
        # Write into out_dir by absolute path instead of os.chdir, so an
        # exception can no longer leave the working directory changed.
        os.makedirs(out_dir, exist_ok=True)
        saved = 0
        try:
            page = self.topic_image()
            if page is None:
                print("Image results not found; nothing downloaded.")
                return saved
            # outerHTML is the exact HTML content of the results container.
            soup = bs(page.get_attribute('outerHTML'), 'html.parser')
            images = soup.find_all('img', {"src": True, "height": True, "width": True})
            print(f"Total Images found: {len(images)}")

            for i, tag in enumerate(images, start=1):
                src = tag["src"]
                prefix = "data" if "data:image" in src else "file"
                try:
                    self._save_image(src, os.path.join(out_dir, f"{prefix}{i}.png"))
                    saved += 1
                except Exception:
                    traceback.print_exc()
        finally:
            # quit() ends the WebDriver session; close() only closes the window.
            self.driver.quit()
        return saved

# Dataset Augmentation (Noise Addition)
Add randomly transformed copies of the images (flips, colour jitter, zooms and rotations). This makes the model more robust and increases the number of samples available for training and testing.

In [None]:
# Load one scraped image to preview the augmentations below.
# NOTE(review): this file is written by the scraping driver cell at the end of
# the notebook, and only if the first "dogs" result was an inline data: image
# (otherwise the first file is file1.png). Run the driver first.
SAMPLE_IMAGE_PATH = os.path.join("dogs", "data1.png")
doge = img.imread(SAMPLE_IMAGE_PATH, format='PNG')
print(doge.shape)
plt.imshow(doge);

In [None]:
def visualize(img1, img2, titles=('Original image', 'Augmented image')):
    """Show two images side by side in a new 1x2 figure.

    Args:
        img1: Left-hand image (anything ``imshow`` accepts).
        img2: Right-hand image.
        titles: Pair of subplot titles, left then right.
    """
    _, axes = plt.subplots(1, 2)
    for ax, image, title in zip(axes, (img1, img2), titles):
        ax.set_title(title)
        ax.imshow(image)

# Preview a deterministic horizontal mirror of the sample image next to it.
flipped = tf.image.flip_left_right(image=doge)
visualize(img1=doge, img2=flipped)

In [None]:
def tile_samples(dataset, n_images, samples_per_image, image_size=32):
    """Tile repeated passes over the first ``n_images`` dataset elements into one array.

    Row block ``i`` holds image ``i``; column block ``j`` holds the ``j``-th
    pass over those images (with random ``map`` transforms upstream, each pass
    can differ). Assumes every element is an (image_size, image_size, 3) image
    and that the dataset has at least ``n_images`` elements.

    Returns:
        Array of shape (image_size * n_images, image_size * samples_per_image, 3).
    """
    output = np.zeros((image_size * n_images, image_size * samples_per_image, 3))
    # take() first so a dataset longer than n_images still yields exactly one
    # batch per pass, i.e. samples_per_image columns.
    passes = dataset.take(n_images).repeat(samples_per_image).batch(n_images)
    for col, images in enumerate(passes):
        output[:, col * image_size:(col + 1) * image_size] = np.vstack(images.numpy())
    return output


def plot_images(dataset, n_images, samples_per_image, image_size=32):
    """Plot the ``tile_samples`` grid: one row per image, one column per pass.

    Args:
        dataset: Dataset of (image_size, image_size, 3) images.
        n_images: Number of images (rows) to show.
        samples_per_image: Number of passes (columns) per image.
        image_size: Height/width of each image in pixels.
    """
    _, ax = plt.subplots()
    ax.imshow(tile_samples(dataset, n_images, samples_per_image, image_size))
    plt.show()

def flip(x: tf.Tensor) -> tf.Tensor:
    """Flip augmentation: a random horizontal flip, then a random vertical flip.

    The two flips are drawn independently, so any of the four combinations
    can occur.

    Args:
        x: Image to flip

    Returns:
        Augmented image
    """
    x = tf.image.random_flip_left_right(x)
    return tf.image.random_flip_up_down(x)

def color(x: tf.Tensor) -> tf.Tensor:
    """Color augmentation: random hue, saturation, brightness and contrast jitter.

    The adjustments run in that order, each with its own random factor.

    Args:
        x: Image (presumably RGB, which the hue/saturation ops require)

    Returns:
        Augmented image
    """
    jittered = tf.image.random_hue(x, max_delta=0.08)
    jittered = tf.image.random_saturation(jittered, lower=0.6, upper=1.6)
    jittered = tf.image.random_brightness(jittered, max_delta=0.05)
    return tf.image.random_contrast(jittered, lower=0.7, upper=1.3)

def rotate(x: tf.Tensor) -> tf.Tensor:
    """Rotation augmentation: rotate by a random number (0-3) of quarter turns.

    Args:
        x: Image

    Returns:
        Augmented image
    """
    quarter_turns = tf.random.uniform(shape=[], minval=0, maxval=4, dtype=tf.int32)
    return tf.image.rot90(x, k=quarter_turns)

def zoom(x: tf.Tensor, crop_size=None) -> tf.Tensor:
    """Zoom augmentation.

    Half of the time the image is returned unchanged. Otherwise a centred box
    covering 80%-99% of each side (one of 20 scales, picked at random) is
    cropped and resized to ``crop_size``.

    Args:
        x: Image, assumed to be a single (height, width, channels) tensor.
        crop_size: (height, width) of the zoomed output. Defaults to the
            input's own height and width, so both branches of the random
            choice produce the same shape.

    Returns:
        Augmented image
    """
    # 20 crop settings, ranging from a 1% to a 20% crop. The boxes are square
    # and centred, so TF's [y1, x1, y2, x2] coordinate order is symmetric here.
    scales = np.arange(0.8, 1.0, 0.01)
    low = 0.5 - 0.5 * scales
    high = 0.5 + 0.5 * scales
    boxes = np.stack([low, low, high, high], axis=1).astype(np.float32)

    size = tf.shape(x)[:2] if crop_size is None else crop_size

    def random_crop():
        # Crop only the randomly chosen box instead of all 20 and keeping one.
        idx = tf.random.uniform(shape=[], minval=0, maxval=len(scales), dtype=tf.int32)
        box = tf.gather(boxes, idx)[tf.newaxis]
        # TF2 names the per-box image index argument `box_indices` (TF1 used `box_ind`).
        crops = tf.image.crop_and_resize(x[tf.newaxis], boxes=box, box_indices=[0], crop_size=size)
        return crops[0]

    choice = tf.random.uniform(shape=[], minval=0., maxval=1., dtype=tf.float32)

    # Only apply cropping 50% of the time
    return tf.cond(choice < 0.5, lambda: x, random_crop)

# Add augmentations
augmentations = [flip, color, zoom, rotate]
AUGMENT_PROB = 0.25  # chance that each augmentation is applied to a given image


def apply_randomly(aug, prob=AUGMENT_PROB):
    """Return a map function that applies ``aug`` to an element with probability ``prob``.

    Taking ``aug`` as an argument means the returned function does not depend
    on a loop variable that is rebound later.
    """
    return lambda x: tf.cond(tf.random.uniform([], 0, 1) < prob, lambda: aug(x), lambda: x)


# NOTE(review): no earlier cell defines `dataset`, so this cell fails on a fresh
# kernel. It is assumed to be a tf.data.Dataset of float images in [0, 1].
for aug in augmentations:
    dataset = dataset.map(apply_randomly(aug), num_parallel_calls=tf.data.AUTOTUNE)
# Clip back into [0, 1]: brightness/contrast jitter can push values outside it.
dataset = dataset.map(lambda x: tf.clip_by_value(x, 0, 1))

plot_images(dataset, n_images=8, samples_per_image=10)


# Image Classifier
This section contains the image classifier model.

## Dataset Creation
Create a pipeline that loads the collected images for the model.

In [None]:
#importing libraries for the modelling
import tensorflow as tf
from tensorflow.data import Dataset
from keras.models import Sequential

In [None]:
from tensorflow.python.client import device_lib 
# List the devices TensorFlow can see (e.g. to confirm a GPU is available),
# using the public tf.config API rather than the private tensorflow.python.client module.
print(tf.config.list_physical_devices())

# Driver Code 

In [None]:
if __name__ == "__main__":
    # One scraper per class, run one after the other; each writes into a
    # folder named after its search topic.
    data_A = data_accumulator(topic="dogs", folder="dogs")
    data_A.image_download()

    data_B = data_accumulator(topic="cats", folder="cats")
    data_B.image_download()