In [1]:
import torch
import torch.nn as nn
import numpy as np
from sklearn.preprocessing import OneHotEncoder

In [2]:
# Create random values
raw_input = np.random.randint(0, 4, size=(1000, 1))

# One-hot
encoder = OneHotEncoder(categories='auto', sparse_output=False)
f_input = encoder.fit_transform(raw_input)

# Convert np matrix to torch matrix
f_input = torch.tensor(f_input, dtype=torch.float32).unsqueeze(0).permute(0, 2, 1)
print(f_input.shape)
print(f_input)

torch.Size([1, 4, 1000])
tensor([[[0., 0., 0.,  ..., 0., 1., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 1., 0.,  ..., 1., 0., 1.],
         [1., 0., 1.,  ..., 0., 0., 0.]]])


In [3]:
class SignalCNN(nn.Module):
    def __init__(self):
        super(SignalCNN, self).__init__()
        
        # layers
        self.conv1 = nn.Conv1d(in_channels=4, out_channels=16, kernel_size=25, stride=1, padding=12)
        self.conv2 = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv1d(in_channels=32, out_channels=1, kernel_size=1)
        
    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = self.conv3(x)
        return x

In [5]:
model = SignalCNN()
output = model(f_input)
print(output.shape)
print(output)

torch.Size([1, 1, 1000])
tensor([[[ 0.0412,  0.0807,  0.0815,  0.1027,  0.0515,  0.0474,  0.0667,
           0.1317,  0.1079,  0.0732,  0.0473,  0.0942,  0.0858,  0.0939,
           0.0726,  0.0695,  0.0825,  0.1781,  0.0838,  0.0932,  0.0803,
           0.0990,  0.0826,  0.0504,  0.0917,  0.0815,  0.0224,  0.0864,
           0.0545,  0.1332,  0.0541,  0.0765,  0.0259,  0.0672,  0.0855,
           0.0760,  0.1018,  0.1137,  0.1454,  0.1110,  0.0431,  0.1031,
           0.0872,  0.1136,  0.0283,  0.1016,  0.0927,  0.1122,  0.0873,
           0.0659,  0.0777,  0.0472,  0.1070,  0.0730,  0.0771,  0.1055,
           0.0543,  0.1024,  0.0766,  0.1052,  0.1001,  0.1615,  0.0326,
           0.1165,  0.1052,  0.0359,  0.1047,  0.1132,  0.0980,  0.0739,
           0.0388,  0.1138,  0.1112,  0.0975,  0.0872,  0.1443,  0.0407,
           0.1625,  0.0517,  0.1316,  0.0732,  0.1220,  0.0606,  0.1285,
           0.0672,  0.0697,  0.0894,  0.0684,  0.0864,  0.1051,  0.1142,
           0.0944,  0.0950