In [1]:
import io
import pickle
import shutil
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

%load_ext tensorboard

  from ._conv import register_converters as _register_converters


In [2]:
def unpickle(file):
    """Load and return the object stored in a pickle file.

    The file is opened in binary mode and unpickled with ``encoding='bytes'``,
    so 8-bit strings pickled by Python 2 are returned as ``bytes`` instead of
    being decoded. This is why the callers look up keys such as ``b"data"``.

    Security: ``pickle.load`` can execute arbitrary code embedded in the file.
    Only call this on files from a trusted source.

    Args:
        file: Path of the pickle file to read.

    Returns:
        The unpickled object.
    """
    with open(file, 'rb') as fo:
        # Named `contents` rather than `dict` so the built-in is not shadowed.
        contents = pickle.load(fo, encoding='bytes')
    return contents

In [5]:
# Load the first CIFAR-10 training batch and the file holding the label names.
# NOTE(review): unpickle() uses pickle.load, which can execute arbitrary code;
# only point it at batch files from a trusted source.
data = unpickle("data_batch_1")
class_name_data = unpickle("batches.meta")

# unpickle() loads with encoding='bytes', so the dict keys are bytes, not str.
data_key = "data".encode()
label_key = "labels".encode()
label_names_key = "label_names".encode()
images = data[data_key]
class_index = data[label_key]
class_names = class_name_data[label_names_key]

# Per the CIFAR-10 format, each row holds one 32x32 image as 1024 red values,
# then 1024 green, then 1024 blue, with each channel stored row by row.
# Reshape to (N, channel, row, col), then move channels last to get
# (N, height, width, 3) for imshow. A Fortran-order reshape straight to
# (N, 32, 32, 3) swaps rows and columns, which displays every image transposed.
img = np.reshape(images, (len(images), 3, 32, 32)).transpose(0, 2, 3, 1)

In [6]:
# Clear out prior logging data so TensorBoard only shows this run.
# shutil.rmtree replaces the shell magic `!rm -rf` so the cell also works
# without a POSIX shell (e.g. on Windows). ignore_errors=True mirrors `rm -f`
# by not failing when the directory does not exist yet.
PLOT_LOG_ROOT = "logs/plots"
shutil.rmtree(PLOT_LOG_ROOT, ignore_errors=True)

# One timestamped sub-directory per run, e.g. logs/plots/20240101-120000.
logdir = PLOT_LOG_ROOT + "/" + datetime.now().strftime("%Y%m%d-%H%M%S")
file_writer = tf.summary.create_file_writer(logdir)

def plot_to_image(figure):
    """Render a matplotlib figure to PNG and return it as a batch of one TF image.

    The supplied figure is always closed, even if rendering fails, and is
    inaccessible after this call. Closing it also prevents the notebook from
    displaying it inline.

    Args:
        figure: The matplotlib ``Figure`` to render.

    Returns:
        A tensor of shape ``(1, height, width, 4)`` holding one RGBA image
        (``channels=4`` below), in the batched form ``tf.summary.image`` takes.
    """
    try:
        with io.BytesIO() as buf:
            # Save *this* figure explicitly. plt.savefig would render pyplot's
            # "current" figure, which is not necessarily the one passed in.
            figure.savefig(buf, format='png')
            # getvalue() returns the whole buffer regardless of position, so no seek(0).
            png_bytes = buf.getvalue()
    finally:
        plt.close(figure)
    # Decode the PNG bytes into a (height, width, 4) image tensor.
    image = tf.image.decode_png(png_bytes, channels=4)
    # Add the leading batch dimension.
    return tf.expand_dims(image, 0)

def image_grid(n_rows=5, n_cols=5, images=None, labels=None, label_names=None):
    """Return a grid of labelled CIFAR-10 images as a matplotlib figure.

    By default this draws the first 5x5 = 25 images from the notebook globals
    ``img``, ``class_index`` and ``class_names``. Pass arguments to change the
    grid size or plot different data.

    Args:
        n_rows: Number of grid rows.
        n_cols: Number of grid columns.
        images: Sequence of images accepted by ``imshow``. Defaults to ``img``.
        labels: Class index for each image. Defaults to ``class_index``.
        label_names: Class names as ``bytes``, indexed by label. Defaults to
            ``class_names``.

    Returns:
        The ``Figure`` containing the grid. Cells beyond the number of
        available images are left blank instead of raising IndexError.
    """
    if images is None:
        images = img
    if labels is None:
        labels = class_index
    if label_names is None:
        label_names = class_names

    # squeeze=False keeps `axes` 2-D even for a 1x1 grid, so .flat always works.
    figure, axes = plt.subplots(n_rows, n_cols, figsize=(10, 10), squeeze=False)
    n_shown = min(n_rows * n_cols, len(images))
    for i, ax in enumerate(axes.flat):
        if i >= n_shown:
            ax.axis("off")
            continue
        ax.set_title(label_names[labels[i]].decode())
        ax.set_xticks([])
        ax.set_yticks([])
        ax.grid(False)
        # No cmap is passed because imshow ignores it for RGB(A) image data.
        ax.imshow(images[i])
    return figure

# Prepare the plot.
figure = image_grid()
# Convert to image and log.
with file_writer.as_default():
    tf.summary.image("Training data", plot_to_image(figure), step=0)

%tensorboard --logdir logs/plots