In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class W2VModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, stride, filter_size, padding):
        super(W2VModel, self).__init__()
        assert (
                len(stride) == len(filter_size) == len(padding)
        ), "Inconsistent length of strides, filter sizes and padding"

        self.model = nn.Sequential()
        for index, (stride, filter_size, padding) in enumerate(zip(stride, filter_size, padding)):
            self.model.add_module(
                "model_layer_{}".format(index),
                nn.Sequential(
                    nn.Conv1d(in_channels=input_dim, out_channels=hidden_dim,
                              kernel_size=filter_size, stride=stride, padding=padding),
                    nn.BatchNorm1d(hidden_dim),
                    nn.ReLU(),
                )
            )
            input_dim = hidden_dim

        self.model_2 = nn.Sequential()
        self.model_2.add_module(
            "fc_layer",
            nn.Sequential(
                nn.Linear(513*13*8, 500),
                nn.ReLU(),
                nn.Linear(500, 255),
                nn.ReLU()
            )
        )
    
    def forward(self, x):
        x = self.model(x)
        x = x.view(1,-1)
        x = self.model_2(x)
        x = F.log_softmax(x, dim=1)
        return x

In [14]:
import numpy as np

file_path = './data/702000.npy'
data = np.load(file_path)
print(data.shape)

(48, 1876)


In [15]:
model = W2VModel(
    input_dim = 48,
    hidden_dim=513,
    stride=[5, 4, 2, 2, 2],
    filter_size=[10, 8, 4, 2, 2],
    padding=[2, 2, 2, 2, 1])

In [16]:
sample_data = torch.randn(8, 48, 1876)
print(sample_data.shape)

output = model(sample_data)
print(output.size())

torch.Size([8, 48, 1876])
torch.Size([1, 255])


In [17]:
print(output)

tensor([[-5.6122, -5.5175, -5.5771, -5.5134, -5.6122, -5.6122, -5.6122, -5.6122,
         -5.5892, -5.4518, -5.6122, -5.4463, -5.4534, -5.4743, -5.4908, -5.6122,
         -5.6122, -5.4778, -5.6122, -5.5480, -5.2982, -5.6122, -5.6122, -5.6122,
         -5.2897, -5.6122, -5.3956, -5.6122, -5.4392, -5.3732, -5.6122, -5.6122,
         -5.3615, -5.6122, -5.6122, -5.6122, -5.3617, -5.6122, -5.4895, -5.6122,
         -5.6122, -5.2595, -5.6122, -5.6122, -5.4422, -5.4061, -5.6122, -5.6122,
         -5.5688, -5.4464, -5.5859, -5.6122, -5.4892, -5.5544, -5.6122, -5.5703,
         -5.6122, -5.6122, -5.4526, -5.6122, -5.4790, -5.6122, -5.6122, -5.6122,
         -5.6122, -5.6122, -5.6122, -5.5255, -5.6122, -5.6122, -5.3515, -5.6122,
         -5.6122, -5.6122, -5.5390, -5.6122, -5.6122, -5.3595, -5.6122, -5.6122,
         -5.6122, -5.6122, -5.6122, -5.6122, -5.5014, -5.3082, -5.3406, -5.4962,
         -5.3987, -5.5066, -5.6122, -5.4028, -5.4336, -5.5659, -5.5460, -5.5228,
         -5.6122, -5.5860, -