In [None]:
import argparse

parser = argparse.ArgumentParser()

# Every command-line option as (flag, type, default, help), registered in the
# order listed. Optimiser settings first, then model / data-shape settings.
for _flag, _type, _default, _help in (
    ("--n_epochs", int, 1, "number of epochs of training"),
    ("--batch_size", int, 1024, "size of the batches"),
    ("--lr", float, 0.002, "adam's learning rate"),
    ("--b1", float, 0.5, "adam: decay of first-order momentum of gradient"),
    ("--b2", float, 0.99, "adam: decay of second-order momentum of gradient"),
    ("--in_len", int, 2**10, "length of the input fed to the neural net"),
    ("--in_channels", int, 32, "number of signal channels"),
    ("--out_channels", int, 6, "number of classes"),
    ("--chunk", int, 1000, "length of split chunks"),
):
    parser.add_argument(_flag, type=_type, default=_default, help=_help)
# Drop the loop temporaries so only `parser`, `opt` and `unknown` remain.
del _flag, _type, _default, _help

# parse_known_args() (rather than parse_args()) returns unrecognised argv
# entries in `unknown` instead of exiting -- presumably so the extra flags a
# Jupyter kernel is launched with do not abort the notebook.
opt, unknown = parser.parse_known_args()

In [1]:
import torch.nn as nn


class NNet(nn.Module):
    """1-D convolutional network producing per-channel scores in (0, 1).

    Layout (module names are unchanged, so existing ``state_dict`` files
    still load with the default sizes):

    * ``net.0``: 5-tap conv (padding 2) mapping ``in_channels`` -> ``hidden``.
    * ``net.1``: kernel-16 / stride-16 conv, which shortens the time axis by
      a factor of 16.
    * ``net.2``: LeakyReLU(0.1).
    * ``net.3``: 7-tap conv (padding 3), length preserving.
    * ``net.conv0`` ... ``net.conv{n_blocks - 1}``: downsampling blocks (see
      ``__block``), each halving the time axis.
    * ``net.final``: 1x1 conv to ``out_channels`` followed by a sigmoid.

    Overall the time axis shrinks by roughly ``16 * 2**n_blocks`` (1024 with
    the defaults), flooring at every strided/pooling stage.

    Args:
        in_channels: number of input signal channels. The default is read from
            ``opt.in_channels`` once, when the class is defined.
        out_channels: number of output channels. The default is read from
            ``opt.out_channels`` once, when the class is defined.
        hidden: width of every intermediate layer. Defaults to 32.
        n_blocks: number of downsampling blocks. Defaults to 6.
    """

    def __init__(
        self,
        in_channels=opt.in_channels,
        out_channels=opt.out_channels,
        hidden=32,
        n_blocks=6,
    ):
        super().__init__()
        self.hidden = hidden
        self.net = nn.Sequential(
            # Lift the input to `hidden` channels. (This layer used to output
            # `in_channels` channels, which only matched the next layer's
            # expected input width when in_channels == hidden.)
            nn.Conv1d(in_channels, self.hidden, 5, padding=2),
            # Non-overlapping windows: downsamples the time axis by 16.
            nn.Conv1d(self.hidden, self.hidden, 16, stride=16),
            nn.LeakyReLU(0.1),
            nn.Conv1d(self.hidden, self.hidden, 7, padding=3),
        )
        for i in range(n_blocks):
            # Names "conv0", "conv1", ... are kept for checkpoint compatibility.
            self.net.add_module(
                "conv{}".format(i), self.__block(self.hidden, self.hidden)
            )
        self.net.add_module(
            "final",
            nn.Sequential(nn.Conv1d(self.hidden, out_channels, 1), nn.Sigmoid()),
        )

    def __block(self, inchannels, outchannels):
        """Build one downsampling block.

        MaxPool(2) halves the time axis, then Dropout(0.1) is applied, then
        two length-preserving 5-tap convs each followed by LeakyReLU(0.1).

        Args:
            inchannels: channels entering the block.
            outchannels: channels leaving the block.

        Returns:
            An ``nn.Sequential`` implementing the block.
        """
        return nn.Sequential(
            nn.MaxPool1d(2, 2),
            # In-place dropout overwrites the max-pool output tensor.
            nn.Dropout(p=0.1, inplace=True),
            nn.Conv1d(inchannels, outchannels, 5, padding=2),
            nn.LeakyReLU(0.1),
            nn.Conv1d(outchannels, outchannels, 5, padding=2),
            nn.LeakyReLU(0.1),
        )

    def forward(self, x):
        """Run the network.

        Args:
            x: signal tensor accepted by ``nn.Conv1d`` with ``in_channels``
                channels, i.e. ``(batch, in_channels, length)`` or unbatched
                ``(in_channels, length)``. NOTE(review): ``length`` presumably
                relates to ``opt.in_len`` -- confirm with the caller.

        Returns:
            Tensor with ``out_channels`` channels, every value in (0, 1).
        """
        return self.net(x)


KeyboardInterrupt: 