# Plotting Related Images

In [None]:
def visualize_item(ax, item, **imshow_kwargs):
    """
        Draw a single item as an image and hide the axis decorations.

        ax: Plot Axes to draw on.
        item: Image-like data accepted by `ax.imshow`.
        imshow_kwargs: Optional keyword arguments forwarded unchanged to `ax.imshow`
            (for example `cmap` when the item is a single-channel mask).
    """
    ax.imshow(item, **imshow_kwargs)
    ax.set_axis_off()

def show_related_images(*relatives, batch_to_rows=True, title='', size=1.5):
    """
        Plot several related image batches (e.g. images and their masks) in one grid.

        Arguments:
            relatives: One or more tensors that belong together. Each is either a batch
                of rank 4 or a single item of rank 3; rank-3 items are treated as a batch of one.
                The batch size is taken from the first relative.
            batch_to_rows: If true, every row is one batch item and every column one relative.
                If false, the layout is transposed.
            title: Figure title.
            size: Edge length (in inches) of each grid cell.
    """
    num_relatives = len(relatives)

    # Promote single items to a batch of one so that everything below can index [batch].
    relatives = [
        tf.expand_dims(relative, axis=0) if len(relative.shape) == 3 else relative
        for relative in relatives
    ]
    batch_size = relatives[0].shape[0]

    if batch_to_rows:
        num_rows, num_cols = batch_size, num_relatives
        row_label, col_label = 'Batch', 'Relatives'
    else:
        num_rows, num_cols = num_relatives, batch_size
        row_label, col_label = 'Relatives', 'Batch'

    # squeeze=False keeps `axes` two-dimensional even for a 1x1 grid; without it
    # plt.subplots returns a bare Axes object and the grid indexing below fails.
    fig, axes = plt.subplots(
        num_rows, num_cols,
        figsize=(num_cols*size, num_rows*size),
        squeeze=False
    )

    fig.supylabel(row_label)
    fig.supxlabel(col_label)

    for row in range(num_rows):
        for col in range(num_cols):
            # relatives is indexed [relative][batch]; which grid axis maps to which
            # index depends on the requested orientation.
            item = relatives[col][row] if batch_to_rows else relatives[row][col]
            visualize_item(axes[row, col], item)

    fig.suptitle(title)
    fig.tight_layout()
    

# Demo: a random batch of 3-channel images with matching single-channel integer masks.
batch_size, img_size, classes = 3, 8, 3
# NOTE(review): tf.random.normal produces values outside [0, 1]; imshow presumably
# clips them for display — confirm, or switch to a [0, 1] range if that matters.
image = tf.random.normal((batch_size, img_size, img_size, 3))
mask = tf.random.uniform((batch_size, img_size, img_size, 1), maxval=classes, dtype=tf.int32)

# batch_to_rows=False: one row per relative (image, mask), one column per batch item.
show_related_images(image, mask, batch_to_rows=False)

# Plotting Images and Bounding Boxes

* Supports plotting both the true (ground-truth) and the predicted bounding boxes.

## Boxes in YXHW format

In [None]:
def plot_boxes(ax, boxes, size, color='green', alpha=1.0, fill=False, no_edge=False, topLabel=True):
    """
        Draw numbered bounding boxes on the axes.

        ax: Plot Axes.
        boxes: A tensor of YXHW values for the boxes. Shape: (N_BOXES, 4).
            None or an empty collection draws nothing.
        size: Size of the image. It is used to translate the YXHW values to image size.
        color: A string to indicate the color in matplotlib format for the bounding boxes.
        alpha: A FP value for the opacity of the box patch.
        fill: If true, fill the boxes with color.
        no_edge: If true, we do not draw the box edges.
        topLabel: If true, add the box id at the top-left corner of the box.
            If false, it goes to the bottom-right corner.
    """
    if boxes is None or len(boxes) == 0:
        return

    # The string 'none' disables the edge. Passing None instead makes matplotlib
    # fall back to its default edge color, so the edge would still be drawn.
    edge_color = 'none' if no_edge else color

    # Label style is the same for every box, so build it once.
    text_bbox = dict(ec=color, fc=color, mutation_aspect=.05, boxstyle=patches.BoxStyle.Square(pad=.4))

    for box_id, box in enumerate(boxes):
        # Scale value by value: `boxes*size` would repeat a plain Python list
        # `size` times instead of scaling its numbers.
        y_min, x_min, height, width = (float(value)*size for value in box)

        rect = patches.Rectangle((x_min, y_min), width, height, ec=edge_color, fill=fill, fc=color, alpha=alpha)
        ax.add_patch(rect)

        if topLabel:
            x, y = x_min + 4, y_min
            ha, va = 'left', 'bottom'
        else:
            x, y = x_min + width - 4, y_min + height
            ha, va = 'right', 'bottom'

        ax.text(x, y, str(box_id + 1), ha=ha, va=va, size='xx-small', color='white', bbox=text_bbox)

def plot_examples(images, y_true, y_pred, cols=3, match=False):
    """
        It plots true and predicted bounding boxes (YXHW format) over the images.

        NOTE: a later cell of this file defines another `plot_examples` for the CYCXHW
        format; whichever definition ran last is the one bound to the name.

        Arguments:
            images: A tensor of images. Shape: (BATCH_SIZE, IMG_SIZE, IMG_SIZE, 3)
            y_true: A grid of box HW values for each feature point. Shape: (BATCH_SIZE, SIZE, SIZE, 2).
                May be None when only predictions are plotted.
            y_pred: A tensor of YXHW values for predicted boxes. Shape: (BATCH_SIZE, N_BOXES, 4).
                May be None when only true boxes are plotted.
            cols: The number of columns in the plot.
            match: If true, we truncate the predicted boxes to the number of true boxes.
                Otherwise at most MAX_BOXES predictions are drawn.

        Each subplot title shows the box counts as "<predicted> | <true>" (only the
        counts of the inputs that were given).
    """
    batch_size = y_pred.shape[0] if y_pred is not None else y_true.shape[0]

    # Plot configuration
    rows = (batch_size + cols - 1)//cols
    # squeeze=False keeps `axes` an array even for a single subplot, so ravel() always works.
    _, axes = plt.subplots(rows, cols, figsize=(8, rows*3), squeeze=False)
    axes = axes.ravel()
    true_color, pred_color = 'red', 'green'

    def slice_ys(item_id):
        item_y_true = None if y_true is None else hw_grid_to_yxhw(y_true[item_id])

        num_y_pred_boxes = tf.shape(item_y_true)[0] if match and item_y_true is not None else MAX_BOXES
        item_y_pred = None if y_pred is None else y_pred[item_id, :num_y_pred_boxes]

        return item_y_true, item_y_pred

    for item_id in range(batch_size):
        ax = axes[item_id]

        # Plot image
        ax.imshow(images[item_id])

        item_y_true, item_y_pred = slice_ys(item_id)

        # Plot predicted boxes first, then the true boxes on top of them.
        plot_boxes(ax, item_y_pred if item_y_pred is not None else [], IMG_SIZE, pred_color, topLabel=False)
        plot_boxes(ax, item_y_true if item_y_true is not None else [], IMG_SIZE, true_color)

        # Add box counts as plot title.
        box_counts = [
            str(int(tf.shape(item_boxes)[0]))
            for item_boxes in (item_y_pred, item_y_true)
            if item_boxes is not None
        ]
        ax.set_title(' | '.join(box_counts))

    # The grid may have more cells than examples; hide the empty ones.
    for unused_ax in axes[batch_size:]:
        unused_ax.set_axis_off()

    plt.tight_layout()

# itr = iter(train_prep_ds.batch(2))
# images, y_true = next(itr)

# plot_examples(images, y_true, None)

# r = patches.Rectangle([0, 0], 5, 5)

## Boxes in CYCXHW format

In [None]:
def plot_box_centers(ax, boxes, size, color='green', alpha=1.0, marker_size=None):
    """
        Scatter the box centers on the axes and write the 1-based box id on each marker.

        ax: Plot Axes.
        boxes: A tensor of CYCXHW values for the box centers. Shape: (N_BOXES, 4).
            Nothing is drawn when it is None.
        size: Size of the image. It is used to translate the CYCX values to image size.
        color: A string to indicate the color in matplotlib format for the markers.
        alpha: A FP value for the marker opacity.
        marker_size: Marker size handed to `ax.scatter` as `s`.
    """
    if boxes is None:
        return

    scaled = boxes*size
    # Column 0 holds the center Y, column 1 the center X.
    center_ys, center_xs = scaled[:, 0], scaled[:, 1]

    ax.scatter(center_xs, center_ys, color=color, s=marker_size, alpha=alpha)

    label_style = dict(ha='center', va='center', size='xx-small', color='white')
    for label, (center_x, center_y) in enumerate(zip(center_xs, center_ys), start=1):
        ax.text(center_x, center_y, str(label), **label_style)

def plot_examples(images, y_true, y_pred, cols=3, match=False):
    """
        It plots the centers of true and predicted bounding boxes (CYCXHW format) over the images.

        NOTE: an earlier cell of this file defines another `plot_examples` for the YXHW
        format; whichever definition ran last is the one bound to the name.

        Arguments:
            images: A tensor of images. Shape: (BATCH_SIZE, IMG_SIZE, IMG_SIZE, 3)
            y_true: A grid of box HW values for each feature point. Shape: (BATCH_SIZE, SIZE, SIZE, 2).
                May be None when only predictions are plotted.
            y_pred: A tensor of CYCXHW values for predicted boxes. Shape: (BATCH_SIZE, N_BOXES, 4).
                May be None when only true boxes are plotted.
            cols: The number of columns in the plot.
            match: If true, we truncate the predicted boxes to the number of true boxes.
                Otherwise at most MAX_BOXES predictions are drawn.

        Each subplot title shows the box counts as "<predicted> | <true>" (only the
        counts of the inputs that were given).
    """
    batch_size = y_pred.shape[0] if y_pred is not None else y_true.shape[0]

    # Plot configuration
    rows = (batch_size + cols - 1)//cols
    # squeeze=False keeps `axes` an array even for a single subplot, so ravel() always works.
    _, axes = plt.subplots(rows, cols, figsize=(8, rows*3), squeeze=False)
    axes = axes.ravel()
    true_color, pred_color = 'red', 'green'

    def slice_ys(item_id):
        item_y_true = None if y_true is None else hw_grid_to_cycxhw(y_true[item_id])

        num_y_pred_boxes = tf.shape(item_y_true)[0] if match and item_y_true is not None else MAX_BOXES
        item_y_pred = None if y_pred is None else y_pred[item_id, :num_y_pred_boxes]

        return item_y_true, item_y_pred

    for item_id in range(batch_size):
        ax = axes[item_id]

        # Plot image
        ax.imshow(images[item_id])

        item_y_true, item_y_pred = slice_ys(item_id)

        # Plot predicted centers first, then the (semi-transparent) true centers on top.
        plot_box_centers(ax, item_y_pred, IMG_SIZE, pred_color, marker_size=20)
        plot_box_centers(ax, item_y_true, IMG_SIZE, true_color, alpha=0.5)

        # Add box counts as plot title.
        box_counts = [
            str(int(tf.shape(item_boxes)[0]))
            for item_boxes in (item_y_pred, item_y_true)
            if item_boxes is not None
        ]
        ax.set_title(' | '.join(box_counts))

    # The grid may have more cells than examples; hide the empty ones.
    for unused_ax in axes[batch_size:]:
        unused_ax.set_axis_off()

    plt.tight_layout()

# itr = iter(train_prep_ds.batch(2))
# images, y_true = next(itr)

# y_pred = tf.random.uniform((2, 20, 4))
# plot_examples(images, y_true, y_pred)
# plot_examples(images, y_true, None)

# r = patches.Rectangle([0, 0], 5, 5)

# History Plots

In [None]:
def mark_max_min(ax, column):
    """
        Mark the smallest and the largest value of a column on the axes.

        ax: Plot Axes.
        column: A pandas Series; its index labels are used as X positions.
    """
    lowest = (column.idxmin(), column.min())
    highest = (column.idxmax(), column.max())
    marker_xs, marker_ys = zip(lowest, highest)

    # zorder=2 is passed so the markers are layered relative to the other artists.
    ax.scatter(list(marker_xs), list(marker_ys), zorder=2)

def plot_metrics(ax, df, name, title=None):
    """
        Plot a training metric, its validation counterpart and the learning rate.

        ax: Plot Axes for the metric curves.
        df: DataFrame with the columns 'Epochs', `name`, 'Validation <name>' and 'Learning Rate'.
        name: Column name of the training metric.
        title: Plot title. Falls back to `name` when empty.
    """
    val_name = 'Validation {}'.format(name)

    # Training curve first, then the validation curve; both get their extremes marked.
    for series_name in (name, val_name):
        sns.lineplot(ax=ax, x='Epochs', y=series_name, data=df, label=series_name)
        mark_max_min(ax, df[series_name])

    # The learning rate lives on a second Y-axis that shares the X-axis.
    lr_ax = ax.twinx()
    sns.lineplot(ax=lr_ax, x='Epochs', y='Learning Rate', data=df, linestyle='--')

    # Shade every unit interval [index, index + 1]; the opacity falls linearly from 1.
    num_spans = len(df) - 1
    for index in range(num_spans):
        ax.axvspan(index, index + 1, fc='#77d8c0', alpha=(1 - (1/num_spans)*index))

    # Label the plots
    ax.legend()
    ax.set_title(title or name)

def load_history_as_df(h):
    """
        Convert a training history dict into a DataFrame indexed by epoch.

        h: A dict with the keys 'loss', 'val_loss' and 'learning_rate'. The first
            learning rate is stored in row 0; the remaining ones are paired with
            the per-epoch losses in rows 1..N.

        Returns a DataFrame with the columns 'Loss', 'Validation Loss',
        'Learning Rate' and 'Epochs' (a copy of the index). Row 0 only carries a
        learning rate.
    """
    column_names = ['Loss', 'Validation Loss', 'Learning Rate']
    num_epochs = len(h['loss'])

    per_epoch_rows = list(zip(h['loss'], h['val_loss'], h['learning_rate'][1:]))
    history_df = pd.DataFrame(
        per_epoch_rows,
        columns=column_names,
        index=range(1, num_epochs + 1)
    )

    # Adds row 0, which only holds the first learning rate.
    history_df.at[0, 'Learning Rate'] = h['learning_rate'][0]

    # The new row was appended at the end; sort to bring it to the front.
    return history_df.sort_index().assign(Epochs=lambda frame: frame.index)

def plot_history(h):
    """
        Plot the loss curves of a training history.

        h: A history dict as accepted by `load_history_as_df`.

        The figure has two panels; only the first one is drawn on.
    """
    _, (loss_ax, _unused_ax) = plt.subplots(1, 2, figsize=(8, 4), facecolor='w', edgecolor='k')

    # Plot loss metrics
    plot_metrics(loss_ax, load_history_as_df(h), 'Loss', title='Losses')

    plt.tight_layout()

# `hist` is defined elsewhere (presumably the object returned by a Keras `fit` call — confirm).
# Plot its loss curves, then leave the raw history dict as the cell's final expression.
plot_history(hist.history)
hist.history

# Conv Weights

In [None]:
def rescale(x, targetMin, targetMax):
    """
        Linearly rescale `x` so that its minimum maps to `targetMin` and its maximum to `targetMax`.

        x: Tensor (or array-like) of values. It is never modified in place.
        targetMin: Lower bound of the output range.
        targetMax: Upper bound of the output range.

        Returns a tensor with the shape of `x`. When all values of `x` are equal,
        every output value is `targetMin` (instead of NaN from a 0/0 division).
    """
    # Work on a tensor so array inputs are not changed by in-place operators.
    x = tf.convert_to_tensor(x)

    # Min-max scaling is shift-invariant, so negative minima need no special handling.
    data_min = tf.math.reduce_min(x)
    data_range = tf.math.reduce_max(x) - data_min

    # For constant input the numerator is all zeros; dividing by 1 keeps it at zero.
    safe_range = tf.where(tf.equal(data_range, 0), tf.ones_like(data_range), data_range)

    return (targetMax - targetMin)*(x - data_min)/safe_range + targetMin

def rescale_and_plot_channels(ax, x, cols=4):
    """
        Show every channel of `x` as its own grayscale image, rescaled to [0, 1].

        ax: Indexable collection of plot Axes. Slot 0 is left untouched; channel `i`
            is drawn on `ax[i + 1]`.
        x: Tensor indexed as [row, column, channel].
        cols: Not used by this function.
    """
    for slot in range(1, x.shape[-1] + 1):
        # Slice instead of index so the channel axis is kept.
        channel = x[:, :, slot - 1:slot]
        ax[slot].imshow(rescale(channel, 0, 1), cmap='gray')

def plot_weights_2d(images, weights, cols=4):
    """
        For every item, plot the image followed by each channel of its weights.

        images: Indexable collection of images; `images[i]` is shown first in figure `i`.
        weights: Tensor indexed as [item, row, column, channel].
        cols: Number of columns of every figure's grid.

        One figure is created per item.
    """
    # One cell for the image plus one per channel, rounded up to full rows.
    rows = (weights.shape[-1] + cols)//cols

    for item_id in range(weights.shape[0]):
        axes = plt.subplots(rows, cols, figsize=(12, 2.5*rows))[1].ravel()

        axes[0].imshow(images[item_id])
        rescale_and_plot_channels(axes, weights[item_id])

# plot_weights_2d(images, outputs)

# Grid Plot

In [None]:
from matplotlib import pyplot as plt

def make_grid_plot(num_items, cols=3, size=3):
    """
        Create a figure with a grid of axes that is large enough for `num_items` plots.

        num_items: Number of plots the grid has to hold. Must be at least 1.
        cols: Maximum number of columns. Fewer are used when `num_items` is smaller.
        size: Edge length (in inches) of each grid cell.

        Returns the figure and a flat array of axes in row-major order. The grid may
        contain more axes than `num_items` when the last row is not full.

        Raises ValueError when `num_items` is smaller than 1.
    """
    if num_items < 1:
        raise ValueError('num_items must be a positive integer, got {}'.format(num_items))

    rows = (num_items + cols - 1)//cols
    cols = min(cols, num_items)

    # squeeze=False always yields a 2-D array of axes, so a single column with
    # several rows flattens correctly too.
    fig, axes = plt.subplots(rows, cols, figsize=(size*cols, size*rows), squeeze=False)

    fig.tight_layout()

    return fig, axes.flatten()