In [1]:
import torch as pt
import torch.nn as nn

In [3]:
class Conv1d(nn.Module):
    def __init__(self, in_channels:int, out_channels:int, kernel_size:int=3, 
                 stride:int=1, padding:int=0, bias:bool=True) -> None:
        super(Conv1d, self).__init__()
        self.stride = stride
        self.pad = padding
        self.ker_size = kernel_size        
        self.out_ch = out_channels
        self.kernel = nn.Parameter(nn.init.xavier_normal_(pt.randn(out_channels, in_channels, kernel_size)))
        self.bias = None
        if bias:
            self.bias = nn.Parameter(pt.zeros(out_channels))

    def forward(self, x):    
        # x：(batch_size, in_channel, in_len)
        out_len = (x.shape[2] + 2*self.pad - self.ker_size)//self.stride + 1
        x = nn.functional.pad(x, (self.pad, self.pad), "constant", 0)
        
        # (batch_size, in_ch, 1, in_len) -> (batch_size, in_ch*ker_size, out_len)
        x_unfold = nn.functional.unfold(x[:,:,None,:], 
                                        kernel_size=(1,self.ker_size), stride=(1,self.stride))
        out = self.kernel.view(self.out_ch, -1).matmul(x_unfold).view(x.shape[0], self.out_ch, out_len)
                 
        if self.bias is not None:
            out += self.bias[None,:,None]
 
        return out

In [4]:
a = pt.randn(2,3,256)
m = Conv1d(3,8,3,1,1)
b=m(a)
b

tensor([[[-0.3449, -0.4838,  0.8674,  ...,  1.0192,  0.7717,  0.3107],
         [ 0.2144, -0.0657, -0.2448,  ..., -1.3025, -0.5207, -1.1039],
         [ 0.1322,  0.2191,  0.7026,  ..., -0.1720,  0.2691, -1.0479],
         ...,
         [-0.4344,  0.7108, -1.1424,  ...,  0.2083, -0.6631,  0.3330],
         [ 0.7642,  0.5890, -0.7912,  ...,  0.6954, -1.8800,  0.1798],
         [ 0.0536, -0.1230,  0.7764,  ...,  0.5833, -0.8045,  0.6417]],

        [[ 0.4509,  1.3623,  0.4830,  ...,  1.0819, -0.1903,  0.7669],
         [ 0.3703,  0.4546, -1.9401,  ...,  0.3224,  0.5016,  0.8625],
         [-0.4299, -1.2423, -0.6637,  ..., -0.9293, -0.6359, -0.4731],
         ...,
         [ 0.4017, -1.1916, -0.0102,  ...,  0.8748, -0.7221, -0.1131],
         [-1.8109,  1.1097,  0.3012,  ..., -0.2892, -0.9557,  0.7772],
         [ 0.5414,  0.3040,  1.5627,  ...,  0.0882,  1.8263, -0.2107]]],
       grad_fn=<AsStridedBackward0>)