In [1]:
import tensorflow as tf
import numpy as np
import cv2

In [8]:
# Path to the dataset
data_dir = './Data/Data'

# Creating a dataset from the images in the data directory
image_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    labels='inferred',
    label_mode='int',
    color_mode='grayscale',
    batch_size=32,
    image_size=(256, 256),
    shuffle=True,
    seed=None,
    validation_split=None,
    subset=None,
    interpolation='bilinear',
    follow_links=False
)

Found 1000 files belonging to 3 classes.


In [9]:
# Printing the class names
class_names = image_ds.class_names
print(class_names)

['test', 'train', 'valid']


In [32]:
# Method to detect the blurred images
def detect_blur(image):
  image = tf.cast(image * 255, tf.uint8)  
  return tf.py_function(func=_detect_blur, inp=[image], Tout=tf.bool)

def _detect_blur(image):
  image = image.numpy().copy()
  image = np.squeeze(image)
  laplacian = cv2.Laplacian(image, cv2.CV_64F)  
  laplacian_mean, laplacian_std = cv2.meanStdDev(laplacian)
  threshold = laplacian_mean + 3 * laplacian_std  
  print(threshold)
  print(np.max(laplacian))
  return np.max(laplacian) < threshold

@tf.autograph.experimental.do_not_convert
def preprocess_image(image, label):
    print(tf.cond(detect_blur(image)
                   , true_fn=lambda: (image, label)
                   , false_fn=lambda: (tf.zeros_like(image), tf.zeros_like(label))))
    return tf.cond(detect_blur(image)
                   , true_fn=lambda: (image, label)
                   , false_fn=lambda: (tf.zeros_like(image), tf.zeros_like(label)))

In [33]:
# Detecting and removing the blurred images
processed_ds = image_ds.map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
filtered_ds = processed_ds.filter(lambda image, label: tf.math.reduce_all(tf.math.not_equal(image, tf.zeros_like(image))))

(<tf.Tensor 'cond/Identity:0' shape=(None, 256, 256, 1) dtype=float32>, <tf.Tensor 'cond/Identity_1:0' shape=(None,) dtype=int32>)


In [15]:
import imageio as iio
 
# read an image
img = iio.imread("images.jpg") 
iio.imwrite("images.jpg", img)