In [4]:
import tensorflow as tf
from tensorflow.keras import layers
import tensorflow_datasets as tfds

# Load the data
data = tfds.load("tf_flowers", split="train[:90%]", as_supervised=True)
test_data = tfds.load("tf_flowers", split="train[90%:]", as_supervised=True)
train_data = data.take(2090)
valid_data = data.skip(2090)

# Define the model
model = tf.keras.Sequential([
    layers.Conv2D(16, 3, padding='same', activation='relu', input_shape=(224, 224, 3)),
    layers.MaxPooling2D(),
    layers.Dropout(0.2),
    layers.Conv2D(32, 3, padding='same', activation='relu'),
    layers.MaxPooling2D(),
    layers.Dropout(0.2),
    layers.Conv2D(64, 3, padding='same', activation='relu'),
    layers.MaxPooling2D(),
    layers.Dropout(0.2),
    layers.Flatten(),
    layers.Dense(512, activation='relu'),
    layers.Dense(5)
])

# Compile the model
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

# Normalize the images and resize them
def preprocess(image, label):
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize_with_pad(image, 224, 224)
    return image, label

train_data = train_data.map(preprocess).batch(32)
valid_data = valid_data.map(preprocess).batch(32)
test_data = test_data.map(preprocess).batch(32)

# Train the model
history = model.fit(train_data, epochs=10, validation_data=valid_data)

# Evaluate the model on the test data
test_loss, test_acc = model.evaluate(test_data)
print('Test accuracy:', test_acc)

# Make predictions on the test data
for image, label in test_data.take(5):
  predictions = model.predict(image)
  print('Predictions:', predictions)
  print('True values:', label)


Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Test accuracy: 0.5722070932388306
Predictions: [[-3.7097232  -6.342048    3.4361622   6.9431033   0.22206457]
 [ 6.417727    0.0830256  -1.942686   -1.6537341  -2.5386472 ]
 [ 4.034203   -3.230047   -1.5410528   4.816922   -1.9962591 ]
 [-0.959799    4.394601    3.6008377  -4.71567     0.02834343]
 [ 3.2604725   0.503977   -1.7204429   0.02881222 -1.3553493 ]
 [-1.7400517  -1.9387016   0.91821086 -2.7365224   4.7176075 ]
 [-0.9418631  -2.989067    3.3423307   1.8173085  -1.8696035 ]
 [ 1.2516993   0.08514363 -0.7050368   1.5012536   0.12511988]
 [ 5.326471   -6.848391   -0.4410803   8.024228   -3.193048  ]
 [ 4.3793464  -0.02115581 -0.78990537 -2.4819446  -0.5503192 ]
 [ 5.8100824   1.3933065  -0.82445747 -0.9866694  -5.401105  ]
 [ 6.020147   -5.9283333  -1.0690539   5.489772   -1.423341  ]
 [-4.8559637  -0.26184937  1.7658005   5.005494    0.12258168]
 [ 2.6601403  -6.243624