# End-to-end training in CPU environment with tfdata

In [1]:
import tensorflow as tf
import numpy as np
import os
import time
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Input, RandomFlip, RandomRotation, RandomCrop
from tensorflow.keras.models import Model

2024-01-05 12:35:46.830866: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
data_dir = "data/train_small_npy"

labels_dict = np.load(os.path.join(data_dir, "labels.npy"), allow_pickle = True).item()
local_paths = list(labels_dict.keys())
global_paths = [os.path.join(data_dir, local_path) for local_path in local_paths]
labels = list(labels_dict.values())

dataset = tf.data.Dataset.from_tensor_slices((global_paths, labels))

def load_data(image_path, label):
    def load_np_file(path):
        return np.load(path.decode("utf-8"))

    image = image = tf.numpy_function(load_np_file, [image_path], tf.uint8)
    return image, label

dataset = dataset.map(load_data, num_parallel_calls=8)

In [4]:
class MLModel(tf.keras.Model):
    def __init__(self):
        super(MLModel, self).__init__()
        self.conv1 = tf.keras.layers.Conv2D(16, (3, 3), strides=(1, 1), padding='same', activation='relu', input_shape=(3, 500, 500))
        self.pool1 = tf.keras.layers.MaxPooling2D((2, 2), strides=(2, 2))
        self.conv2 = tf.keras.layers.Conv2D(32, (3, 3), strides=(1, 1), padding='same', activation='relu')
        self.pool2 = tf.keras.layers.MaxPooling2D((2, 2), strides=(2, 2))
        self.flatten = tf.keras.layers.Flatten()
        self.fc1 = tf.keras.layers.Dense(200, activation='relu')
        self.fc2 = tf.keras.layers.Dense(1)

    def call(self, x):
        x = self.conv1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.pool2(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.fc2(x)
        return x

mlmodel = MLModel()

criterion = tf.keras.losses.MeanSquaredError()
optimizer = tf.optimizers.Adam(learning_rate=0.001)

In [5]:
num_epochs = 10
batch_size = 48

input_layer = Input(shape=(1200, 2000, 3), dtype=tf.float32)
x = RandomFlip("horizontal_and_vertical")(input_layer)
x = RandomRotation(factor=0.25, fill_mode="constant", fill_value = 0)(x)
x = RandomCrop(500, 500)(x)

data_augmentation_model = Model(inputs=input_layer, outputs=x)

def custom_transform_data(image, label):
    image = tf.transpose(image, perm=[1, 2, 0])
    image = data_augmentation_model(image)
    image = tf.cast(image, dtype=tf.float32) / 255.0
    image = tf.transpose(image, perm=[2, 0, 1])
    return image, label

transformed_dataset = dataset.map(custom_transform_data)

start_time = time.time()
for epoch in range(num_epochs):
    epoch_start = time.time()
    for images, labels in transformed_dataset.batch(batch_size).as_numpy_iterator():
        with tf.GradientTape() as tape:
            outputs = mlmodel(images)
            labels = tf.convert_to_tensor(labels.reshape(-1, 1), dtype=tf.float32)
            loss = criterion(labels, outputs)
        
        gradients = tape.gradient(loss, mlmodel.trainable_variables)
        optimizer.apply_gradients(zip(gradients, mlmodel.trainable_variables))
    epoch_end = time.time()
    epoch_time = epoch_end - epoch_start
    print(f"Epoch {epoch+1} done in {epoch_time} seconds.")
end_time = time.time()

total_time = end_time - start_time

print(f"Total time taken: {total_time} seconds")
print(f"Total time per epoch: {total_time/num_epochs} seconds")

Epoch 1 done in 469.6996729373932 seconds.
Epoch 2 done in 455.69856905937195 seconds.
Epoch 3 done in 455.5163300037384 seconds.
Epoch 4 done in 484.2993869781494 seconds.
Epoch 5 done in 462.66717076301575 seconds.
Epoch 6 done in 572.1970403194427 seconds.
Epoch 7 done in 501.33690428733826 seconds.
Epoch 8 done in 507.70059871673584 seconds.
Epoch 9 done in 509.2733688354492 seconds.
Epoch 10 done in 524.6082739830017 seconds.
Total time taken: 4942.998977899551 seconds
Total time per epoch: 494.29989778995514 seconds
