In [1]:
import tensorflow as tf
from tensorflow.keras.applications import InceptionV3
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical
from tensorflow.image import resize
import numpy as np

# Load the CIFAR-10 train/test splits
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Upsample images to 75x75, the input_shape the InceptionV3 backbone is built with.
# NOTE: this materialises the whole resized dataset in memory at once; switch to a
# batched tf.data pipeline if memory becomes a problem.
x_train_resized = resize(x_train, [75, 75])
x_test_resized = resize(x_test, [75, 75])

# Scale pixels the way the ImageNet-pretrained InceptionV3 weights expect
# ([-1, 1] via preprocess_input) rather than to [0, 1]; feeding the frozen
# backbone a different input range degrades the features it extracts.
x_train_resized = tf.keras.applications.inception_v3.preprocess_input(x_train_resized)
x_test_resized = tf.keras.applications.inception_v3.preprocess_input(x_test_resized)

# One-hot encode the 10 class labels for categorical_crossentropy
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

# ImageNet-pretrained InceptionV3 backbone, built without its classification head
base_model = InceptionV3(include_top=False, weights='imagenet', input_shape=(75, 75, 3))

# Keep every backbone layer fixed so only the new head is trained
for backbone_layer in base_model.layers:
    backbone_layer.trainable = False

# New head: global average pooling -> 256-unit ReLU layer -> 10-way softmax
pooled_features = GlobalAveragePooling2D()(base_model.output)
hidden = Dense(units=256, activation='relu')(pooled_features)
predictions = Dense(units=10, activation='softmax')(hidden)

# Assemble backbone input and head output into the final model
model = Model(inputs=base_model.input, outputs=predictions)

# Compile the model.
# `learning_rate` is the supported keyword; the old `lr` alias is deprecated and
# rejected by newer Keras releases.
model.compile(
    optimizer=Adam(learning_rate=0.0001),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# Train the model, keeping the returned History so loss/accuracy curves can be inspected.
# NOTE(review): the test split doubles as validation data here, so the reported
# validation metrics are not a held-out estimate if they are used for tuning.
history = model.fit(
    x_train_resized,
    y_train,
    batch_size=64,
    epochs=10,
    validation_data=(x_test_resized, y_test),
)

# Sample prediction
sample_index = 10  # Choose any index from the test set
# Add a leading batch axis so the single image has the shape predict() expects
sample_image = np.asarray(x_test_resized[sample_index])[np.newaxis, ...]
prediction = model.predict(sample_image)
predicted_class = prediction.argmax()  # index of the highest-scoring class
print("Predicted class:", predicted_class)


Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/inception_v3/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5




Epoch 1/10

KeyboardInterrupt: 