# Preprocessing in a CPU environment with `tf.data`

In [1]:
import tensorflow as tf
import numpy as np
import os
import time
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Input, RandomFlip, RandomRotation, RandomCrop
from tensorflow.keras.models import Model

2024-01-05 10:07:47.626565: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
data_dir = "data/train_small_npy"

# labels.npy stores a single pickled Python object (hence allow_pickle=True and
# .item()); its keys are used as paths relative to data_dir and its values as
# labels. Only load this file from a trusted source: unpickling can run code.
labels_dict = np.load(
    os.path.join(data_dir, "labels.npy"),
    allow_pickle=True,
).item()

# Keys and values are read from the same mapping, so the two lists stay aligned.
local_paths = [*labels_dict.keys()]
labels = [*labels_dict.values()]
global_paths = [os.path.join(data_dir, relative) for relative in local_paths]

In [3]:
# One dataset element per (path, label) pair; nothing is read from disk yet.
dataset = tf.data.Dataset.from_tensor_slices(tensors=(global_paths, labels))

In [4]:
def load_data(image_path, label):
    """Load the array stored at ``image_path`` and pair it with its label.

    Meant to be used with ``Dataset.map``. ``np.load`` is ordinary Python code,
    so it is wrapped in ``tf.numpy_function`` to be callable from the pipeline.

    Args:
        image_path: scalar string tensor holding a path readable by ``np.load``.
        label: label belonging to that path; passed through unchanged.

    Returns:
        A tuple ``(image, label)`` where ``image`` is the loaded array exposed
        as a ``tf.uint8`` tensor.
    """
    def load_np_file(path):
        # The path arrives as bytes, so decode it before handing it to NumPy.
        return np.load(path.decode("utf-8"))

    # NOTE(review): no static shape is attached to `image` here. Downstream
    # code may rely on that, so confirm before adding a set_shape() call.
    image = tf.numpy_function(load_np_file, [image_path], tf.uint8)
    return image, label

dataset = dataset.map(load_data, num_parallel_calls=8)

In [5]:
# Benchmark settings: number of passes over the data and elements per batch.
num_epochs = 10
batch_size = 48

# Augmentation graph: random flip -> random rotation -> random 500x500 crop.
input_layer = Input(shape=(1200, 2000, 3), dtype=tf.float32)
flipped = RandomFlip("horizontal_and_vertical")(input_layer)
# factor is a fraction of a full turn, so 0.25 allows up to a quarter turn in
# either direction; pixels uncovered by the rotation are filled with 0.
rotated = RandomRotation(
    factor=0.25,
    fill_mode="constant",
    fill_value=0,
)(flipped)
x = RandomCrop(500, 500)(rotated)

data_augmentation_model = Model(inputs=input_layer, outputs=x)

def custom_transform_data(image, label):
    """Augment one image and rescale it, leaving the label untouched.

    The first axis of ``image`` is moved to the end before the augmentation
    model runs and moved back to the front afterwards (presumably converting
    channels-first to channels-last and back; confirm against the stored
    arrays). The augmented values are cast to float32 and divided by 255.

    Args:
        image: rank-3 image tensor.
        label: label of the image; returned as is.

    Returns:
        A tuple ``(image, label)`` with the augmented, rescaled image.
    """
    last_axis_layout = tf.transpose(image, perm=[1, 2, 0])
    augmented = data_augmentation_model(last_axis_layout)
    rescaled = tf.cast(augmented, dtype=tf.float32) / 255.0
    first_axis_layout = tf.transpose(rescaled, perm=[2, 0, 1])
    return first_axis_layout, label

# NOTE(review): this map runs without num_parallel_calls, unlike the loading
# step; confirm that sequential augmentation is what the benchmark intends.
transformed_dataset = dataset.map(custom_transform_data)

# Time full passes over the pipeline. Batches are pulled and discarded, so only
# the loading and preprocessing cost is measured.
start_time = time.time()
for epoch in range(num_epochs):
    epoch_start = time.time()
    # Dedicated loop names keep the notebook-level `labels` list (built from
    # labels.npy) from being overwritten by the last batch of label arrays.
    for batch_images, batch_labels in transformed_dataset.batch(batch_size).as_numpy_iterator():
        pass
    epoch_end = time.time()
    epoch_time = epoch_end - epoch_start
    print(f"Epoch {epoch+1} done in {epoch_time} seconds.")
end_time = time.time()

total_time = end_time - start_time

print(f"Total time taken: {total_time} seconds")
print(f"Total time per epoch: {total_time/num_epochs} seconds")

Epoch 1 done in 323.1505460739136 seconds.
Epoch 2 done in 318.28870964050293 seconds.
Epoch 3 done in 327.91578364372253 seconds.
Epoch 4 done in 318.7570502758026 seconds.
Epoch 5 done in 308.46475768089294 seconds.
Epoch 6 done in 320.7336037158966 seconds.
Epoch 7 done in 302.6726689338684 seconds.
Epoch 8 done in 298.4840362071991 seconds.
Epoch 9 done in 307.5554802417755 seconds.
Epoch 10 done in 295.64889574050903 seconds.
Total time taken: 3121.67312002182 seconds
Total time per epoch: 312.16731200218203 seconds
