#### Imports

In [1]:
from pathlib import Path

import numpy as np
from tensorflow.keras.layers import Conv1D
from tensorflow.keras.models import Sequential
from tensorflow.random import set_seed
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.animation as animation
from IPython.display import HTML, Image
plt.style.use('ggplot')
# Enlarged fonts (and bold axis labels) applied globally on top of ggplot.
params = {
    'legend.fontsize': '18',
    'xtick.labelsize': '18',
    'ytick.labelsize': '18',
    'axes.titlesize': '20',
    'axes.labelsize': '20',
    'axes.labelweight': 'bold',
}
plt.rcParams.update(params)

#### Functions

In [2]:
def mk_fig():
    """
    Build the side-by-side canvas for the Conv1D animation.

    The left panel is titled 'Input sequence' and the right panel
    'Conv1D layer output'. Both panels use x-limits (-2, 22),
    y-limits (-1, 1) and the axis labels 'x' / 'y'.

    Returns
    -------
    fig, axes
        Figure and axes objects
    """
    fig, axes = plt.subplots(1, 2, figsize=(14, 4))
    panel_titles = ('Input sequence', 'Conv1D layer output')
    for ax, panel_title in zip(axes, panel_titles):
        ax.set_xlim(-2, 22)
        ax.set_ylim(-1, 1)
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.set_title(panel_title)

    return fig, axes

#### Data

In [3]:
# Input signal: n samples of sin on [0, 3] plus Gaussian noise (std 0.1),
# scaled by 0.8 to keep it roughly within the plots' (-1, 1) y-limits.
n = 21
np.random.seed(100)
x = 0.8 * (np.sin(np.linspace(0, 3, n)) + 0.1 * np.random.randn(n))

#### Model

In [4]:
# Conv1D hyper-parameters.
filters = 3
kernel_size = 3

# Seed TensorFlow before the layer is created so its initial weights are reproducible.
set_seed(100)
model = Sequential(
    [Conv1D(filters, kernel_size=kernel_size, input_shape=(n, 1), padding='same')],
    name="conv1d_model",
)
# Add batch and channel axes -> (1, n, 1); [0] drops the batch axis so
# x_out is (n, filters).
x_out = model.predict(x.reshape(1, n, 1))[0]
model.summary()

Model: "conv1d_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv1d (Conv1D)              (None, 21, 3)             12        
Total params: 12
Trainable params: 12
Non-trainable params: 0
_________________________________________________________________


#### Create and save animations

In [11]:
fig, axes = mk_fig()
# One colour per Conv1D filter: opaque edge colours (ec) and the same colours
# at alpha 0.2 for the kernel-window fill (fc). Indexed modulo their length so
# more than three filters cycles the palette instead of raising IndexError.
ec = [(0.89, 0.29, 0.2, 1), (0.2, 0.54, 0.74, 1), (0.60, 0.56, 0.84, 1)]
fc = [(0.89, 0.29, 0.2, 0.2), (0.2, 0.54, 0.74, 0.2), (0.60, 0.56, 0.84, 0.2)]
# With padding='same' and stride 1, kernel_size - 1 zeros are added in total:
# floor((kernel_size - 1) / 2) before the sequence and the remainder after it.
pad_left = (kernel_size - 1) // 2
pad_right = kernel_size - 1 - pad_left


def init():
    """Draw the static input sequence and the zero-padding at both ends."""
    axes[0].plot(x, '-k', marker='.', ms=13)
    # Hollow markers connect each end of the sequence to its padded zeros.
    axes[0].plot(list(range(-pad_left, 1)), [0] * pad_left + [x[0]],
                 'k', marker='.', markerfacecolor='w', ms=13, zorder=1, label='zero padding')
    axes[0].plot(list(range(n - 1, n + pad_right)), [x[-1]] + [0] * pad_right,
                 'k', marker='.', markerfacecolor='w', ms=13, zorder=1)


def animate(i):
    """
    Draw frame ``i``: filter ``i // n`` applied at input position ``i % n``.

    The kernel window is shown as ``kernel_size`` unit-wide rectangles on the
    input panel; the filter's output up to the current position is plotted on
    the output panel.
    """
    filt, pos = divmod(i, n)
    edge = ec[filt % len(ec)]
    face = fc[filt % len(fc)]
    axes[0].set_title(f'Input sequence. Passing filter {filt + 1}.')
    # Remove the previous frame's window; iterate over a copy because
    # remove() mutates axes[0].patches.
    for patch in list(axes[0].patches):
        patch.remove()
    for k in range(kernel_size):
        left = pos - pad_left + k - 0.5
        axes[0].add_patch(patches.Rectangle((left, x[pos] - 0.4), 1, 0.8,
                                            linewidth=1, edgecolor=edge, facecolor=face))
    # Lines from earlier filters are not cleared, so all filters' outputs accumulate.
    axes[1].plot(x_out[:pos + 1, filt], color=edge, marker='.', ms=13)


# Close the figure, presumably so the inline backend doesn't also render a static copy.
plt.close(fig)
ani = animation.FuncAnimation(fig,
                              animate,
                              init_func=init,
                              frames=n * filters)  # one frame per (filter, position)
gif_path = Path('../gif/cnn/cnn_1d.gif')
# Create the output directory so a fresh checkout can Restart & Run All.
gif_path.parent.mkdir(parents=True, exist_ok=True)
ani.save(str(gif_path), writer='imagemagick', fps=3, dpi=75)
# To preview inline instead of saving: HTML(ani.to_jshtml())

#### View animations

In [9]:
Image(url='../gif/cnn/cnn_1d.gif')