# In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models, datasets, utils, Input
import numpy as np
import matplotlib.pyplot as plt

def _resize_and_scale(images, size=(224, 224), chunk_size=1000):
    """Return `images` resized to `size` and scaled from [0, 255] to [0, 1].

    The result is a float32 array of shape (N, *size, C). It is filled one
    chunk at a time, so peak memory is about one output array plus one chunk.
    Resizing everything at once and then dividing by 255 instead holds the
    full resized array and a second full-size copy simultaneously.

    Args:
        images: Image array of shape (N, H, W, C) with pixel values in
            [0, 255] (CIFAR-10 ships uint8 arrays).
        size: Target (height, width) passed to `tf.image.resize`.
        chunk_size: Number of images resized per `tf.image.resize` call.

    Returns:
        float32 array of shape (N, size[0], size[1], C) with values in [0, 1].
    """
    out = np.empty((len(images), *size, images.shape[-1]), dtype=np.float32)
    for start in range(0, len(images), chunk_size):
        stop = start + chunk_size
        # Per-image bilinear resize, so chunking gives the same values as
        # resizing the whole array in one call.
        out[start:stop] = tf.image.resize(images[start:stop], size).numpy() / 255.0
    return out


(x_train, y_train), (x_test, y_test) = datasets.cifar10.load_data()

# NOTE: even chunked, the CIFAR-10 training set at 224x224x3 float32 needs
# roughly 30 GB of RAM. If that does not fit, consider a tf.data pipeline that
# resizes per batch instead of materialising the whole array.
x_train = _resize_and_scale(x_train)
x_test = _resize_and_scale(x_test)

# One-hot encode the integer class labels for categorical cross-entropy.
y_train = utils.to_categorical(y_train, 10)
y_test = utils.to_categorical(y_test, 10)

def AlexNet(input_shape=(224, 224, 3), num_classes=10):
    """Build an AlexNet-style image classifier (Krizhevsky et al., 2012).

    The network has five convolutional layers with 96-256-384-384-256
    filters and three max-pooling stages. These feed two 4096-unit dense
    layers with 50% dropout and a softmax output layer. No local response
    normalisation is applied.

    Args:
        input_shape: Shape of a single input image, (height, width, channels).
        num_classes: Number of output classes (size of the softmax layer).

    Returns:
        An uncompiled Keras functional `Model` named "AlexNet".
    """
    inputs = Input(shape=input_shape)

    # Stage 1 (default 224x224 input): conv -> 54x54x96, pool -> 26x26.
    x = layers.Conv2D(96, (11, 11), strides=4, activation='relu')(inputs)
    x = layers.MaxPooling2D(pool_size=(3, 3), strides=2)(x)

    # Stage 2: 'same' padding keeps 26x26, pool -> 12x12.
    x = layers.Conv2D(256, (5, 5), padding='same', activation='relu')(x)
    x = layers.MaxPooling2D(pool_size=(3, 3), strides=2)(x)

    # Stage 3: 384, 384, 256 filters as in the reference architecture
    # (previously the last two were swapped to 256, 384); pool -> 5x5.
    x = layers.Conv2D(384, (3, 3), padding='same', activation='relu')(x)
    x = layers.Conv2D(384, (3, 3), padding='same', activation='relu')(x)
    x = layers.Conv2D(256, (3, 3), padding='same', activation='relu')(x)
    x = layers.MaxPooling2D(pool_size=(3, 3), strides=2)(x)

    # Classifier head: two wide dense layers, each followed by dropout.
    x = layers.Flatten()(x)
    x = layers.Dense(4096, activation='relu')(x)
    x = layers.Dropout(0.5)(x)
    x = layers.Dense(4096, activation='relu')(x)
    x = layers.Dropout(0.5)(x)

    outputs = layers.Dense(num_classes, activation='softmax')(x)
    model = models.Model(inputs=inputs, outputs=outputs, name="AlexNet")
    return model

alexnet = AlexNet()

# Adam with default settings; the one-hot labels from to_categorical pair
# with categorical cross-entropy.
alexnet.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# Train for a few epochs, holding out 10% of the training data to report
# validation accuracy after every epoch.
history = alexnet.fit(
    x_train,
    y_train,
    batch_size=128,
    epochs=5,
    validation_split=0.1,
)

test_loss, test_acc = alexnet.evaluate(x_test, y_test)
print("Test Accuracy:", test_acc)

# Overlay per-epoch training and validation accuracy on one figure.
for metric_key, curve_label in (('accuracy', 'Train Acc'),
                                ('val_accuracy', 'Val Acc')):
    plt.plot(history.history[metric_key], label=curve_label)
plt.legend()
plt.show()

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

# Show the model's prediction next to the true label for 5 random test images.
idx = np.random.choice(len(x_test), 5, replace=False)
samples = x_test[idx]
# One batched predict call instead of one call per image.
preds = np.argmax(alexnet.predict(samples), axis=1)
true_labels = np.argmax(y_test[idx], axis=1)

for img, true_label, pred in zip(samples, true_labels, preds):
    # x_test holds floats already scaled to [0, 1], which imshow renders
    # directly as RGB. Casting to uint8 would truncate every pixel to 0 (or 1)
    # and produce a black image.
    plt.imshow(img)
    plt.title(f"True: {class_names[true_label]}, Pred: {class_names[pred]}")
    plt.axis("off")
    plt.show()