In [2]:
import torch
import torch.nn as nn

from madigan.fleet.net.utils import calc_pad_to_conserve

In [None]:
class DilConvLayer(nn.Module):
    """
    1D causal convolution using a small dilated kernel.

    The input is left-padded by ``dilation * (kernel - 1)`` steps (replicating
    its first time step) *before* the convolution. Output position ``t``
    therefore only depends on input positions ``<= t``. With ``stride == 1``
    the sequence length is preserved.

    The layer produces two outputs:
      * a residual output ``x + conv_compress(latent)`` for the next layer;
      * a skip output ``dropout(skip_conv(latent))`` for the network output.

    Args:
        channels_in: number of channels of the input ``x``.
        channels_out: number of channels of the hidden (latent) representation.
        kernel: convolution kernel size.
        stride: convolution stride. NOTE: the residual addition needs the
            convolution output to be as long as the input, which holds for
            ``stride == 1``.
        dilation: convolution dilation.
        dropout: dropout probability applied to the skip output.
    """
    def __init__(self, channels_in, channels_out, kernel, stride, dilation,
                 dropout):
        super().__init__()
        self.c_in = channels_in
        self.c_out = channels_out
        self.kernel = kernel
        self.dilation = dilation
        self.stride = stride
        # Keep the raw probability; ``self.dropout`` holds the nn.Dropout module.
        self.dropout_p = dropout
        self.act_fn = nn.GELU()
        self.conv_project = nn.Conv1d(channels_in, channels_out, kernel,
                                      dilation=dilation,
                                      stride=stride, bias=False)
        # Left-only padding equal to the receptive field minus one step.
        # It is applied to the *input*, so no output position sees future
        # samples. Padding after the conv would leak x[dilation*(kernel-1)]
        # into the first positions.
        padding = (dilation * (kernel - 1), 0)
        self.causal_padding_layer = nn.ReplicationPad1d(padding)
        self.conv_embed = nn.Sequential(self.causal_padding_layer,
                                        self.conv_project,
                                        self.act_fn)
        # 1x1 convs collapse the hidden channels to a single channel. The
        # residual sum broadcasts this channel over x when channels_in > 1.
        self.conv_compress = nn.Conv1d(channels_out, 1, kernel_size=1,
                                       bias=False)
        self.skip_conv = nn.Conv1d(channels_out, 1, kernel_size=1,
                                   bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Apply the causal dilated convolution.

        Args:
            x: input tensor, assumed (batch, channels_in, length) — TODO confirm.

        Returns:
            (res, skip):
              * res is ``x`` plus the compressed latent (same shape as ``x``
                by broadcasting);
              * skip is the dropout-regularised single-channel skip output.
        """
        latent = self.conv_embed(x)
        res = x + self.conv_compress(latent)
        skip_connection = self.dropout(self.skip_conv(latent))
        return res, skip_connection
    
class SeriesNet(nn.Module):
    """
    Stack of DilConvLayer blocks whose skip outputs are summed, passed
    through a ReLU and mapped to the prediction by a linear layer over the
    last (time) axis.

    Args:
        input_length: length of the input sequence. The linear output layer
            maps this many time steps to ``output_length``. Every layer must
            therefore preserve the sequence length (stride 1).
        input_dim: number of input channels.
        output_length: number of predicted values per sample.
        channel_dims, kernels, dilations, strides, dropouts: per-layer
            hyper-parameters. All must have the same, non-zero length
            (the number of layers).
    """
    def __init__(self, input_length, input_dim, output_length, channel_dims, kernels,
                 dilations, strides, dropouts):
        super().__init__()
        assert(len(channel_dims) == len(kernels) == len(dilations)
               == len(strides) == len(dropouts)), \
            "layer hyperparams must be same length - n_layers"
        assert len(kernels) > 0, "SeriesNet needs at least one layer"
        self.output_act_fn = nn.ReLU()
        # nn.ModuleList (not a plain list) registers the layers as submodules.
        # Without it, their parameters would be missing from .parameters()
        # (and so never trained), and .to()/.eval()/state_dict and
        # self.apply(self.init) would not reach them.
        # Every layer takes input_dim channels: each residual output has the
        # input's channel count (the 1-channel compression is broadcast over x).
        self.layers = nn.ModuleList(
            DilConvLayer(input_dim, c_out, k, s, d, p)
            for c_out, k, s, d, p in zip(channel_dims, kernels, strides,
                                         dilations, dropouts))
        # NOTE(review): output_compression is not used in forward(); kept so
        # existing state_dict keys stay compatible.
        self.output_compression = nn.Conv1d(channel_dims[-1], out_channels=1,
                                            kernel_size=1, bias=False)
        self.output_layer = nn.Linear(input_length, output_length)
        self.apply(self.init)

    def init(self, m):
        """Initialise every Conv1d weight from N(0, 0.05^2)."""
        if isinstance(m, nn.Conv1d):
            nn.init.normal_(m.weight, mean=0., std=0.05)

    def forward(self, x):
        """
        Args:
            x: input tensor, assumed (batch, input_dim, input_length) — TODO confirm.

        Returns:
            Tensor of shape (batch, -1). This is (batch, output_length) because
            each layer's skip output has a single channel.
        """
        out = None
        for layer in self.layers:
            x, skip_out = layer(x)
            # Out-of-place add: don't mutate the first layer's skip tensor.
            out = skip_out if out is None else out + skip_out
        out = self.output_act_fn(out)
        out = self.output_layer(out)
        return out.view(out.shape[0], -1)

In [55]:
# Fix the seed so the random input and the model weights are reproducible.
torch.manual_seed(0)

input_len = 108
input_dim = 1
output_len = 1
n_layers = 7
x = torch.randn(1, input_dim, input_len)

# Dilations double each layer (1, 2, ..., 64), giving a receptive field
# that covers the 108-step input.
model = SeriesNet(input_len, input_dim, output_len,
                  channel_dims=[32] * n_layers,
                  kernels=[2] * n_layers,
                  dilations=[2 ** i for i in range(n_layers)],
                  strides=[1] * n_layers,
                  dropouts=[0.] * (n_layers - 2) + [0.8] * 2)
model(x).shape

torch.Size([1, 1])

2