This notebook is inspired by the book [Deep Generative Modeling - Jakub M. Tomczak](https://link.springer.com/book/10.1007/978-3-031-64087-2).

### **Causal Convolutions - a replacement for RNNs for long-range memory**

- **Causal Convolutions:** To harness parallel computation, one can use convolutional layers that are "masked" or designed so that the output at position $d$ depends only on inputs from positions less than (or, in some layers, up to) $d$.  
  - **Option A:** In the very first layer, the convolution kernel is masked so that it does not see the current $x_d$.  
  - **Option B:** In later layers, the network may use a kernel that can include the current value.
- **Dilation:** By using dilated convolutions (i.e., kernels whose taps skip over a fixed number of positions), the receptive field (the range of input positions that affect a given output) can be increased without a proportional increase in the number of layers.
- **Advantages:**  
  - **Parameter Sharing:** Convolution kernels are reused across positions, making the model efficient.
  - **Parallel Computation:** Unlike RNNs, convolutions can be computed in parallel.
- **Downside:**  
  - **Sampling Speed:** When generating new data (i.e., sampling), the autoregressive structure forces the model to produce one position at a time, because each new value depends on the values already sampled. Evaluating the likelihood of a given sequence takes a single parallel forward pass, but sampling requires a loop over positions, which makes it slower.
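
To make the dilation point concrete, here is a minimal sketch that counts how many input positions a stack of causal convolutions spans. It assumes every layer uses the same kernel size and stride 1; the helper name `receptive_field_size` is hypothetical and not part of the book's code.

```python
def receptive_field_size(kernel_size, dilations):
    """Span (in input positions) of a stack of stride-1 causal convolutions.

    Each layer with dilation d reaches (kernel_size - 1) * d positions
    further into the past than the layer below it.
    """
    return 1 + sum((kernel_size - 1) * d for d in dilations)


# Four layers with kernel size 2 and no dilation: the span grows linearly.
print(receptive_field_size(2, [1, 1, 1, 1]))  # 5

# The same four layers with doubling dilations: the span grows exponentially.
print(receptive_field_size(2, [1, 2, 4, 8]))  # 16
```

If the first layer uses Option A, the span has the same size but is shifted one step into the past, so it ends at position $d-1$ instead of $d$.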

&nbsp;

Let's see how to implement a 1D causal convolutional layer in PyTorch.

In [1]:
import torch
import torch.nn as nn

In [2]:
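class CausalConv1d(nn.Module):
    """
    A causal 1D convolution.

    With A=True (Option A), the output at position d depends only on inputs
    before d. With A=False (Option B), it may also depend on the input at d.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation, A=False, **kwargs):
        super().__init__()
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.A = A
        # Left padding keeps the output length equal to the input length;
        # Option A pads one extra step so the kernel never covers the current position.
        self.padding = (kernel_size - 1) * dilation + int(A)
        # Padding is applied manually in forward(), so the convolution itself uses none.
        self.conv1d = torch.nn.Conv1d(in_channels, out_channels,
                                      kernel_size, stride=1,
                                      padding=0,
                                      dilation=dilation,
                                      **kwargs)

    def forward(self, x):
        # x has shape (batch, channels, length); pad only on the left (the past).
        x = torch.nn.functional.pad(x, (self.padding, 0))
        conv1d_out = self.conv1d(x)
        if self.A:
            # The extra padding step yields one extra output; drop the last one.
            return conv1d_out[:, :, :-1]
        return conv1d_out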
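
One way to check the causal property is to perturb a single input position and see which outputs change. The sketch below restates the same padding logic as a standalone function so that it runs on its own. The names `causal_conv1d_fn` and `first_affected_position` are hypothetical helpers for this check, not part of the book's code.

```python
import torch
import torch.nn.functional as F


def causal_conv1d_fn(x, weight, dilation=1, exclude_current=False):
    """Functional version of the layer above (hypothetical helper, no bias)."""
    kernel_size = weight.shape[-1]
    pad = (kernel_size - 1) * dilation + int(exclude_current)
    x = F.pad(x, (pad, 0))  # pad on the left only
    out = F.conv1d(x, weight, dilation=dilation)
    # Option A produces one extra output step, which is dropped.
    return out[:, :, :-1] if exclude_current else out


def first_affected_position(exclude_current, t=5, length=12, kernel_size=3, dilation=2):
    """Perturb x[t] and return the earliest output index that changes."""
    torch.manual_seed(0)
    weight = torch.randn(1, 1, kernel_size)
    x = torch.randn(1, 1, length)
    x_perturbed = x.clone()
    x_perturbed[0, 0, t] += 1.0
    diff = (causal_conv1d_fn(x_perturbed, weight, dilation, exclude_current)
            - causal_conv1d_fn(x, weight, dilation, exclude_current))
    changed = (diff.abs() > 1e-6).nonzero()[:, -1]  # indices along the length axis
    return int(changed.min())


print(first_affected_position(exclude_current=False))  # 5: Option B sees x[5] at output 5
print(first_affected_position(exclude_current=True))   # 6: Option A sees x[5] only from output 6
```

In both cases no output before position `t` changes, which is what makes the layer causal. Option A additionally hides the current input from its own output.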
