# In[ ]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import colors

def plot_tensor(tensor_, labels):
    """Display a 2-, 3- or 4-D tensor as annotated heatmap(s).

    A 2-D input is drawn as a single heatmap with its own colorbar. A 3-D
    input is drawn as one row of heatmaps (one per index of its first axis);
    a 4-D input as a grid of heatmaps indexed by its first two axes. All
    heatmaps in a grid share one color scale and one colorbar.

    Args:
        tensor_: A torch tensor (any device; autograd history is dropped) or
            any array-like ``np.asarray`` accepts, with 2 to 4 dimensions.
        labels: One entry per dimension of ``tensor_``. ``labels[-1]`` gives
            the x-axis tick labels and ``labels[-2]`` the y-axis tick labels,
            one per index of the matching axis. Entries for the leading grid
            axes are only counted, not displayed.
    """
    assert 2 <= tensor_.ndim <= 4
    assert len(labels) == tensor_.ndim

    data = _as_numpy(tensor_)

    if data.ndim == 2:
        _plot_matrix(data, labels)
    else:
        _plot_grid(data, labels)


def _as_numpy(tensor_):
    """Return ``tensor_`` as a host-memory NumPy array without autograd history."""
    # torch tensors: detach from the graph and move to the CPU if needed,
    # because NumPy/matplotlib cannot read device memory directly.
    if hasattr(tensor_, "detach"):
        tensor_ = tensor_.detach().cpu()
    return np.asarray(tensor_)


def _plot_matrix(data, labels):
    """Show a single 2-D array as a heatmap with tick labels and a colorbar."""
    plt.imshow(data)
    plt.xticks(range(data.shape[1]), labels[1])
    plt.yticks(range(data.shape[0]), labels[0])
    plt.colorbar()
    plt.show()


def _plot_grid(data, labels):
    """Show a 3-D or 4-D array as a grid of heatmaps sharing one color scale."""
    if data.ndim == 3:
        # A 3-D input is laid out as a single row of heatmaps.
        data = data[np.newaxis]

    rows, cols = data.shape[:2]
    fig, axs = plt.subplots(rows, cols, squeeze=False)

    plt.setp(axs, xticks=range(data.shape[-1]), xticklabels=labels[-1],
             yticks=range(data.shape[-2]), yticklabels=labels[-2])
    for ax in axs.flat:
        plt.setp(ax.get_xticklabels(), rotation=90)

    # One normalization for every panel so colors are comparable and a single
    # colorbar describes them all. NaN/inf are excluded from the range, the
    # same way imshow masks invalid pixels when autoscaling.
    finite = np.ma.masked_invalid(data)
    norm = colors.Normalize(vmin=finite.min(), vmax=finite.max())

    image = None
    for (i, j), ax in np.ndenumerate(axs):
        image = ax.imshow(data[i, j], norm=norm)
        ax.label_outer()

    # Reserve the right edge of the figure for the shared colorbar.
    fig.subplots_adjust(right=0.8)
    cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
    fig.colorbar(image, cax=cbar_ax)

    plt.show()

def log_grad(tensor, title, labels):
    """Print a heading followed by ``tensor``, then plot ``tensor``.

    Args:
        tensor: Tensor to print and plot (2 to 4 dimensions, as required by
            ``plot_tensor``).
        title: Heading printed on its own line before the tensor.
        labels: Per-dimension tick labels, forwarded to ``plot_tensor``.
    """
    # Same output as two separate print() calls: title, newline, tensor, newline.
    print(title, tensor, sep="\n")
    plot_tensor(tensor, labels)

def log_tensor(tensor, title, labels):
    """Print and plot ``tensor``; if it requires grad, also log its gradient.

    Args:
        tensor: A torch tensor with 2 to 4 dimensions (the range
            ``plot_tensor`` accepts).
        title: Heading printed before the gradient when it is logged during
            the backward pass. It is not printed for ``tensor`` itself.
        labels: Per-dimension tick labels, as expected by ``plot_tensor``;
            reused unchanged for the gradient, which has the same shape.
    """
    assert len(labels) == tensor.ndim

    print(tensor)
    plot_tensor(tensor, labels)

    if tensor.requires_grad:
        # The hook receives the gradient of ``tensor`` during backward. It
        # returns None (log_grad returns nothing), so the gradient is left
        # unchanged. Routing through log_grad rather than log_tensor prints
        # the title, and avoids registering yet another hook on the gradient
        # when the backward pass itself builds a graph (create_graph=True).
        tensor.register_hook(lambda grad: log_grad(grad, title, labels))