In [2]:
import tensorflow as tf
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical
import numpy as np


In [3]:
# Load the CIFAR-10 dataset
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


In [4]:
# Normalize pixel values to be between 0 and 1
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0


In [5]:
# One-hot encode the labels
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

In [6]:
#We'll use the ResNet architecture for this task
# build and train the model
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam

In [7]:
# Load the ResNet50 model pre-trained on ImageNet, without the top layer
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(32, 32, 3))

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5


In [8]:
# Add a global spatial average pooling layer
x = base_model.output
x = GlobalAveragePooling2D()(x)

# Add a fully-connected layer and a softmax layer
predictions = Dense(10, activation='softmax')(x)

In [11]:
# This is the model we will train
model = Model(inputs=base_model.input, outputs=predictions)

# Compile the model
model.compile(optimizer=Adam(), loss='categorical_crossentropy', metrics=['accuracy'])

In [15]:
# Train the model
history = model.fit(x_train, y_train, epochs=20 , batch_size=64, validation_data=(x_test, y_test))

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


In [16]:
# Evaluate the model on the test set
test_loss, test_accuracy = model.evaluate(x_test, y_test)

print(f'Test accuracy: {test_accuracy}')
print(f'Test loss: {test_loss}')

Test accuracy: 0.7804999947547913
Test loss: 1.0339032411575317


In [17]:
model.evaluate(x_test,y_test)



[1.0339032411575317, 0.7804999947547913]

In [18]:
y_train[:5]

array([[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)

In [24]:
y_train = y_train.reshape(-1,)

y_train[:5]

array([0., 0., 0., 0., 0.], dtype=float32)

In [20]:
y_test = y_test.reshape(-1,)

In [21]:
classes = ["airplane","automobile","bird","cat","deer","dog","frog","horse","ship","truck"]

In [29]:
y_pred = model.predict(x_test)
y_pred[:5]



array([[4.73682121e-06, 5.23449780e-05, 2.46185381e-02, 9.54552591e-01,
        1.35608390e-03, 1.39905338e-03, 1.54227270e-02, 9.55979573e-04,
        5.24717092e-04, 1.11323385e-03],
       [1.06381485e-05, 3.64212246e-05, 3.15405401e-07, 9.19440572e-06,
        2.47988936e-07, 1.68540009e-06, 8.45505426e-08, 3.58925725e-08,
        9.99932289e-01, 9.04613989e-06],
       [2.66886549e-03, 1.33300647e-01, 1.02840897e-06, 1.02039725e-02,
        1.62281084e-03, 2.06360346e-04, 4.71297317e-05, 2.52550944e-06,
        8.48907471e-01, 3.03919334e-03],
       [6.28707230e-01, 2.39610436e-05, 8.09461984e-04, 1.17494375e-04,
        7.39131906e-07, 2.91784750e-06, 6.83939015e-06, 4.49099971e-05,
        3.70272815e-01, 1.36431436e-05],
       [1.01420716e-09, 6.74288447e-09, 4.85661349e-06, 7.68199904e-09,
        9.29408241e-03, 2.21462781e-09, 9.90700603e-01, 9.73754126e-08,
        3.12345975e-07, 5.16548082e-10]], dtype=float32)

In [32]:
y_classes = [np.argmax(element) for element in y_pred]
y_classes[:5]

[3, 8, 8, 0, 6]

In [33]:
y_test[:5]

array([0., 0., 0., 1., 0.], dtype=float32)

In [35]:
classes[y_classes[3]]

'airplane'

In [36]:
classes[y_classes[3]]

'airplane'

In [37]:
classes[y_classes[6]]

'cat'