In [8]:
import tensorflow as tf
from tensorflow import keras
from keras.models import Sequential
from keras.layers import Conv2D
from keras import layers
from keras.preprocessing.image import ImageDataGenerator
import numpy as np

In [2]:
# Define the CNN
model = Sequential()
model.add(Conv2D(100, (3, 3), strides=(2, 2), padding='same', input_shape=(720, 1280, 3)))
model.add(Conv2D(200, (3, 3), strides=(2, 2), padding='same'))
model.add(Conv2D(400, (3, 3), strides=(2, 2), padding='same'))

# Print the number of parameters in the CNN
print("Total number of parameters:", model.count_params())

Total number of parameters: 903400


In [3]:
# Create a random input image
input_image = np.random.rand(1, 720, 1280, 3).astype(np.float32)

# Make a prediction with the CNN and measure the memory usage
model.predict(input_image)
memory_usage = model.output_shape[1] * model.output_shape[2] * model.output_shape[3] * 4 + \
               model.input_shape[1] * model.input_shape[2] * model.input_shape[3] * 4 + \
               model.count_params() * 4
print("Minimum total RAM needed:", memory_usage, "bytes")

Minimum total RAM needed: 37712800 bytes


In [4]:
# Make a prediction with the CNN and measure the memory usage with 8-bit floats
model.predict(input_image.astype(np.float32) / 256)
memory_usage = model.output_shape[1] * model.output_shape[2] * model.output_shape[3] + \
               model.input_shape[1] * model.input_shape[2] * model.input_shape[3] + \
               model.count_params()
print("Minimum total RAM needed with 8-bit floats:", memory_usage, "bytes")

Minimum total RAM needed with 8-bit floats: 9428200 bytes


In [5]:
# Create 20 random input images
input_images = np.random.rand(20, 720, 1280, 3).astype(np.float32)

# Make a prediction with the CNN and measure the memory usage with 20 input images
model.predict(input_images)
memory_usage = 20 * (model.output_shape[1] * model.output_shape[2] * model.output_shape[3] * 4 + \
                     model.input_shape[1] * model.input_shape[2] * model.input_shape[3] * 4) + \
               model.count_params() * 4
print("Minimum total RAM needed with 20 input images:", memory_usage, "bytes")

Minimum total RAM needed with 20 input images: 685597600 bytes


In [9]:
# Load the Fashion MNIST dataset
(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()

# Normalize pixel values to be between 0 and 1
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# Reshape input data to be 4D (batch_size, height, width, channels)
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)

In [10]:
# Define image augmentation
datagen = ImageDataGenerator(
    rotation_range=10,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.1,
    zoom_range=0.1,
    horizontal_flip=True,
    vertical_flip=False,
    fill_mode='nearest')

# Define the model
model = keras.Sequential(
    [
        layers.Conv2D(64, kernel_size=3, activation="relu", padding="same", input_shape=(28, 28, 1)),
        layers.BatchNormalization(),
        layers.Conv2D(64, kernel_size=3, activation="relu", padding="same"),
        layers.BatchNormalization(),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Dropout(0.25),
        layers.Flatten(),
        layers.Dense(128, activation="relu"),
        layers.BatchNormalization(),
        layers.Dropout(0.5),
        layers.Dense(10, activation="softmax"),
    ]
)

In [None]:
# Define the learning rate scheduler
def scheduler(epoch, lr):
    if epoch < 10:
        return lr
    else:
        return lr * tf.math.exp(-0.1)

# Compile the model
model.compile(loss="sparse_categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

# Train the model with batch size of 64 for 20 epochs using the image generator
history = model.fit(datagen.flow(x_train, y_train, batch_size=64), epochs=20, validation_data=(x_test, y_test), callbacks=[keras.callbacks.LearningRateScheduler(scheduler)])

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
139/938 [===>..........................] - ETA: 6:40 - loss: 0.3938 - accuracy: 0.8540