In [1]:
import tensorflow as tf
from tensorflow.keras import layers, models, losses, optimizers
from tensorflow.keras.datasets import fashion_mnist
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
import plotly.graph_objects as go

# Load and preprocess the fashion_mnist dataset
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Normalize pixel values to be between 0 and 1
train_images, test_images = train_images / 255.0, test_images / 255.0

# Convert grayscale images to RGB
train_images_rgb = tf.repeat(train_images[..., tf.newaxis], 3, axis=-1)
test_images_rgb = tf.repeat(test_images[..., tf.newaxis], 3, axis=-1)

# One-hot encode the labels
num_classes = 10
train_labels = to_categorical(train_labels, num_classes)
test_labels = to_categorical(test_labels, num_classes)

# Data augmentation
datagen = ImageDataGenerator(
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)

# Load pre-trained ResNet50 model (excluding the top layers)
resnet_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))

# Freeze the layers of ResNet
for layer in resnet_model.layers:
    layer.trainable = False

# Build the neural network for style classification
model = models.Sequential([
    layers.experimental.preprocessing.Rescaling(scale=1./255, input_shape=(28, 28, 3)),
    layers.experimental.preprocessing.Resizing(224, 224),
    resnet_model,
    layers.GlobalAveragePooling2D(),
    layers.Dense(512, activation='relu'),
    layers.Dense(256, activation='relu'),
    layers.Dense(num_classes, activation='softmax')
])

# Compile the model
model.compile(optimizer=optimizers.Adam(),
              loss=losses.CategoricalCrossentropy(),
              metrics=['accuracy'])

# Train your model with data augmentation
epochs = 10
history = model.fit(
    datagen.flow(train_images_rgb, train_labels, batch_size=32),
    epochs=epochs,
    validation_data=(test_images_rgb, test_labels)
)

# Evaluate your model on the test set
eval_result = model.evaluate(test_images_rgb, test_labels)
print(f"Test Accuracy: {eval_result[1] * 100:.2f}%")

# Plot training history using Plotly
fig = go.Figure()

fig.add_trace(go.Scatter(x=list(range(1, epochs + 1)), y=history.history['accuracy'], mode='lines+markers', name='Training Accuracy'))
fig.add_trace(go.Scatter(x=list(range(1, epochs + 1)), y=history.history['val_accuracy'], mode='lines+markers', name='Validation Accuracy'))

fig.update_layout(title='Training and Validation Accuracy',
                  xaxis_title='Epoch',
                  yaxis_title='Accuracy',
                  template='plotly_dark')
fig.show()

# Plot sample predictions using Plotly
sample_indices = [0, 1, 2, 3, 4]
sample_images = test_images_rgb[sample_indices]
sample_labels = test_labels[sample_indices]

predictions = model.predict(sample_images)
predicted_labels = tf.argmax(predictions, axis=1)

class_names = [
    'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'
]

fig = go.Figure()

for i in range(len(sample_indices)):
    true_label = class_names[tf.argmax(sample_labels[i])]
    pred_label = class_names[predicted_labels[i]]

    fig.add_trace(go.Image(z=sample_images[i], name=f'True: {true_label}<br>Pred: {pred_label}'))

fig.update_layout(title='Sample Predictions',
                  template='plotly_dark')
fig.show()


Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Test Accuracy: 10.00%


InvalidArgumentError: ignored