In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## Dilated ResNet 1D Encoder

In [23]:
class ResidualBlock(nn.Module):
    """Pre-activation 1D residual block that halves the sequence length.

    Main path:  BN -> ReLU -> Dropout -> Conv1d(in -> out, dilated, stride=`stride`)
                -> BN -> ReLU -> Dropout -> Conv1d(out -> out, dilated, stride=1)
                -> MaxPool1d(kernel_size=2, stride=2)
    Shortcut:   identity, or a 1x1 Conv1d projection (with stride `stride`) when the
                channel count or stride changes -> MaxPool1d(kernel_size=2, stride=2)
    The two paths are summed and then passed through BN -> ReLU.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel size. Must be odd so that the symmetric padding
            ``dilation * (kernel_size // 2)`` preserves the sequence length.
        stride: Stride of the first convolution (and of the shortcut projection).
        dilation: Dilation of both convolutions.
        dropout: Dropout probability applied after each ReLU, before each convolution.

    Shapes:
        input ``(batch, in_channels, L)`` ->
        output ``(batch, out_channels, ((L - 1) // stride + 1) // 2)``.

    Raises:
        ValueError: If ``kernel_size`` is even.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, dropout=0.0):
        super().__init__()
        if kernel_size % 2 == 0:
            # Symmetric padding cannot preserve length for even kernels, which would make the
            # main path and the shortcut disagree in length at the residual add.
            raise ValueError(f"kernel_size must be odd, got {kernel_size}")

        padding = dilation * (kernel_size // 2)

        self.bn1 = nn.BatchNorm1d(in_channels)
        self.relu1 = nn.ReLU()
        self.conv1 = nn.Conv1d(
            in_channels, out_channels, kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=False)
        self.drop1 = nn.Dropout(p=dropout)
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.relu2 = nn.ReLU()
        self.drop2 = nn.Dropout(p=dropout)
        # Stride is applied once (in conv1) so the shortcut only has to match a single stride.
        self.conv2 = nn.Conv1d(
            out_channels, out_channels, kernel_size,
            stride=1,
            padding=padding,
            dilation=dilation,
            bias=False)

        self.bn3 = nn.BatchNorm1d(out_channels)
        self.relu3 = nn.ReLU()
        self.downsample = nn.MaxPool1d(kernel_size=2, stride=2, padding=0)

        # Projection shortcut only when shapes change. nn.Identity has no parameters, so the
        # default (in == out, stride == 1) configuration keeps the same state_dict keys.
        if in_channels != out_channels or stride != 1:
            self.shortcut = nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        """Apply the block to ``x`` of shape ``(batch, in_channels, L)``."""
        identity = self.downsample(self.shortcut(x))

        out = self.bn1(x)
        out = self.relu1(out)
        out = self.drop1(out)
        out = self.conv1(out)

        out = self.bn2(out)
        out = self.relu2(out)
        out = self.drop2(out)
        out = self.conv2(out)

        out = self.downsample(out)

        # Out-of-place add: avoids mutating a tensor that autograd may have saved.
        out = out + identity
        out = self.bn3(out)
        out = self.relu3(out)

        return out

class DeilatedResidualNet(nn.Module):
    """Parallel multi-dilation block: one ResidualBlock per dilation rate, outputs concatenated.

    Branch ``i`` uses dilation ``expansion_factor ** i`` and the shared ``kernel_size``.
    ``out_channels`` is split across the ``n_layers`` branches as evenly as possible. When it
    is not divisible by ``n_layers``, the first ``out_channels % n_layers`` branches get one
    extra channel, so the concatenated output always has exactly ``out_channels`` channels.

    The class name keeps its original spelling for compatibility; ``DilatedResidualNet`` is a
    correctly spelled alias.

    NOTE(review): each branch maps ``in_channels`` -> its branch width through ResidualBlock.
    Running with ``in_channels`` different from the branch width requires ResidualBlock to
    support a channel-changing residual path.

    Args:
        in_channels: Number of input channels fed to every branch.
        out_channels: Total number of output channels after concatenation.
        kernel_size: Kernel size used by every branch.
        n_layers: Number of parallel branches (one per dilation rate).
        expansion_factor: Base of the geometric dilation schedule.

    Raises:
        ValueError: If ``n_layers < 1`` or ``out_channels < n_layers``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, n_layers, expansion_factor=4):
        super().__init__()
        if n_layers < 1:
            raise ValueError(f"n_layers must be >= 1, got {n_layers}")
        if out_channels < n_layers:
            raise ValueError(
                f"out_channels ({out_channels}) must be >= n_layers ({n_layers}) "
                "so that every branch has at least one channel")

        self.in_channels = in_channels
        self.kernel_size = kernel_size
        # Base per-branch width (kept as an attribute for backward compatibility).
        self.h_dim = out_channels // n_layers

        remainder = out_channels % n_layers
        branch_widths = [self.h_dim + (1 if i < remainder else 0) for i in range(n_layers)]
        dilation_rates = [expansion_factor ** i for i in range(n_layers)]

        self.blocks = nn.ModuleList([
            ResidualBlock(self.in_channels, width, self.kernel_size, dilation=dilation)
            for width, dilation in zip(branch_widths, dilation_rates)
        ])

    def forward(self, x):
        """Run every branch on ``x`` and concatenate the results along the channel axis."""
        outputs = [block(x) for block in self.blocks]
        return torch.cat(outputs, dim=1)


# Correctly spelled alias; the original (misspelled) name remains valid for existing callers.
DilatedResidualNet = DeilatedResidualNet

class DilatedResidualEncoder(nn.Module):
    """Multi-kernel dilated residual encoder mapping a sequence to a fixed-size vector.

    Pipeline:
      1. Stem ``Conv1d(in_channels -> planes * in_channels)`` with kernel ``stem_kernel_size``
         and padding ``stem_kernel_size // 2`` (length-preserving for odd kernels).
      2. One branch per entry of ``kernel_sizes``. Each branch is an ``nn.Sequential`` of
         ``ResidualBlock``s with that kernel size, one block per dilation in ``dilate_layers``.
      3. Branch outputs are concatenated along channels and globally average-pooled over the
         sequence axis.

    Output shape: ``(batch, len(kernel_sizes) * planes * in_channels)``, which is
    ``(batch, 768)`` with the defaults.

    Args:
        kernel_sizes: Kernel size of each parallel branch; defaults to ``[3, 5, 7, 9]``.
        in_channels: Number of input channels.
        planes: Channel multiplier; every branch works with ``planes * in_channels`` channels.
        dilate_layers: Dilation of each stacked block within a branch; defaults to ``[6, 3, 1]``.
        expansion_factor: Stored on the instance; not used by this class.
        stem_kernel_size: Kernel size of the stem convolution (previously hard-coded to 5).

    Raises:
        ValueError: If ``kernel_sizes`` is empty.
    """

    def __init__(self, kernel_sizes=None, in_channels=8, planes=24, dilate_layers=None,
                 expansion_factor=4, stem_kernel_size=5):
        super().__init__()

        # None sentinels + copies: list defaults in the signature would be shared (and
        # mutable through self.kernel_sizes / self.dilate_layers) across all instances.
        kernel_sizes = [3, 5, 7, 9] if kernel_sizes is None else list(kernel_sizes)
        dilate_layers = [6, 3, 1] if dilate_layers is None else list(dilate_layers)
        if not kernel_sizes:
            raise ValueError("kernel_sizes must contain at least one kernel size")

        self.in_channels = in_channels
        self.planes = planes
        self.kernel_sizes = kernel_sizes
        self.dilate_layers = dilate_layers
        self.expansion_factor = expansion_factor
        self.stem_kernel_size = stem_kernel_size

        out_channels = self.planes * self.in_channels
        self.conv1 = nn.Conv1d(
            self.in_channels, out_channels,
            kernel_size=stem_kernel_size, stride=1, padding=stem_kernel_size // 2,
        )

        self.blocks = nn.ModuleList([
            nn.Sequential(*[
                ResidualBlock(out_channels, out_channels, kernel_size, dilation=dilation)
                for dilation in self.dilate_layers
            ])
            for kernel_size in self.kernel_sizes
        ])

        self.pooling = nn.AdaptiveAvgPool1d(1)

    def forward(self, x):
        """Encode ``x`` of shape ``(batch, in_channels, seq_len)`` into ``(batch, features)``."""
        stem = self.conv1(x)
        branch_outputs = [branch(stem) for branch in self.blocks]
        features = torch.cat(branch_outputs, dim=1)
        # AdaptiveAvgPool1d(1) leaves a trailing length-1 axis; drop it.
        return self.pooling(features).squeeze(-1)

In [24]:

# Model instantiation example: build the encoder and run one inference-mode forward pass.
torch.manual_seed(0)  # reproducible random example input / weights

input_channels = 8  # For 8-channel sequence input
model = DilatedResidualEncoder(in_channels=input_channels)
model.eval()  # BatchNorm uses running stats, Dropout disabled

# Example input tensor: [batch_size, channels, seq_len]
example_input = torch.rand(8, input_channels, 2000)
with torch.no_grad():
    example_output = model(example_input)

print("Output shape:", example_output.shape)  # Expected shape: [batch, 768]
expected_features = len(model.kernel_sizes) * model.planes * input_channels
assert example_output.shape == (example_input.shape[0], expected_features)


Output shape: torch.Size([8, 768])
