In [None]:
import numpy as np
import matplotlib.pyplot as plt
from keras.datasets import fashion_mnist
import wandb

# Fetch the Fashion-MNIST training and test splits (images plus labels).
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Divide every pixel by 255.0 so the images are stored as floats
# (presumably in [0, 1], assuming the raw pixels are 0-255 — TODO confirm).
train_images = train_images / 255.0
test_images = test_images / 255.0

# Human-readable class names; the position in the list is the label value,
# since the code below looks names up as class_names[label].
class_names = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]
# Initialize the Weights & Biases run that receives the sample images.
wandb.init(project="deep_learning", name='mnist_dataset')

# Select one image per class.
# np.unique returns the distinct labels in ascending order together with the
# index of each label's first occurrence, so the samples end up ordered by
# class index rather than by where each class happens to first appear in the
# training set. Classes absent from train_labels are simply not included.
present_labels, first_positions = np.unique(train_labels, return_index=True)
samples_per_class = {
    int(label): train_images[position]
    for label, position in zip(present_labels, first_positions)
}

# Log the selected samples under a single key, each captioned with its
# class name.
wandb.log(
    {
        "Fashion-MNIST Sample Images": [
            wandb.Image(img, caption=class_names[label])
            for label, img in samples_per_class.items()
        ]
    }
)

# Show the selected samples on a 2x5 grid of panels, one panel per class.
fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(10, 5))
fig.suptitle("Fashion-MNIST Sample Images", fontsize=14)

# zip stops at the shorter input, so panels are filled in the dict's
# iteration order and any surplus axes are left untouched.
for ax, label in zip(axes.flat, samples_per_class):
    image = samples_per_class[label]
    ax.imshow(image, cmap="gray")
    ax.set_title(class_names[label])
    ax.axis("off")

plt.tight_layout()
plt.show()

# Close the W&B run after the figure has been displayed.
wandb.finish()