In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class Block(nn.Module):
    """Bottleneck residual block for 1-D signals: 1x1 conv -> grouped 3x3 conv -> 1x1 conv.

    Each conv is followed by BatchNorm; ReLU is applied after the first two and
    after the residual addition.

    Args:
        in_channels: channel count of the input tensor.
        hidden_channels: bottleneck width; must be divisible by ``groups``
            (required by the grouped nn.Conv1d).
        out_channels: channel count of the output tensor.
        stride: stride of the 3x3 conv (and of the projection shortcut).
            With kernel 3 / padding 1 the length becomes ceil(L / stride).
        groups: number of groups in the 3x3 conv.

    Input is (batch, in_channels, length), as consumed by nn.Conv1d.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, stride=2, groups=8):
        super().__init__()
        self.conv0 = nn.Conv1d(in_channels, hidden_channels, kernel_size=1, stride=1)
        self.bn0 = nn.BatchNorm1d(hidden_channels)

        self.conv1 = nn.Conv1d(hidden_channels, hidden_channels, kernel_size=3, stride=stride, padding=1, groups=groups)
        self.bn1 = nn.BatchNorm1d(hidden_channels)

        self.conv2 = nn.Conv1d(hidden_channels, out_channels, kernel_size=1, stride=1)
        self.bn2 = nn.BatchNorm1d(out_channels)

        # The residual must match the main path in BOTH channels and length.
        # A stride > 1 shortens the main path, so a projection is required even
        # when in_channels == out_channels; otherwise the addition below fails.
        # (A 1x1 conv with the same stride yields the same length as the 3x3/pad-1 conv.)
        self.shortcut = nn.Sequential()
        if in_channels != out_channels or stride != 1:
            self.shortcut = nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=stride)

    def forward(self, x):
        """Apply the bottleneck path, add the (possibly projected) input, then ReLU."""
        residual = x
        x = self.conv0(x)
        x = F.relu(self.bn0(x), inplace=True)
        x = self.conv1(x)
        x = F.relu(self.bn1(x), inplace=True)
        x = self.conv2(x)
        x = self.bn2(x)
        x += self.shortcut(residual)
        x = F.relu(x, inplace=True)
        return x

In [3]:
# Seed once, before anything random, so the demo input and all weight
# initialisations below are reproducible under Restart & Run All.
torch.manual_seed(0)
X = torch.randn((32, 32, 200))  # (batch, channels, length) as consumed by nn.Conv1d
block = Block(32, 32, 128)
X.shape

torch.Size([32, 32, 200])

In [4]:
block_out = block(X)
# Stride-2 block: channels 32 -> 128, length ceil-halved (200 -> 100).
assert block_out.shape == (X.shape[0], 128, (X.shape[-1] + 1) // 2)
block_out.shape

torch.Size([32, 128, 100])

In [5]:
class Stage(nn.Module):
    """Two residual Blocks applied in sequence.

    ``block0`` changes the width to ``out_channels`` and downsamples by
    ``stride``. ``block1`` keeps ``out_channels`` and the length unchanged
    (stride 1).

    Args:
        in_channels: channel count of the input tensor.
        hidden_channels: bottleneck width used by both blocks.
        out_channels: channel count of the output tensor.
        stride: downsampling stride of the first block (default 2, the
            original hard-coded value).
    """

    def __init__(self, in_channels, hidden_channels, out_channels, stride=2):
        super().__init__()

        self.block0 = Block(in_channels, hidden_channels, out_channels, stride=stride)
        self.block1 = Block(out_channels, hidden_channels, out_channels, stride=1)

    def forward(self, x):
        """Run ``block0`` then ``block1``."""
        x = self.block0(x)
        x = self.block1(x)
        return x


In [6]:
stage0 = Stage(in_channels=32, hidden_channels=32, out_channels=128)

In [7]:
stage0_out = stage0(X)
# One stage = one stride-2 block + one stride-1 block: 32 -> 128 channels, length ceil-halved.
assert stage0_out.shape == (X.shape[0], 128, (X.shape[-1] + 1) // 2)
stage0_out.shape

torch.Size([32, 128, 100])

In [8]:
class Body(nn.Module):
    """Backbone of four Stages, widening 32 -> 128 -> 256 -> 512 -> 1024 channels.

    Each Stage uses its default stride-2 first block, so the sequence length
    is ceil-halved four times (e.g. 200 -> 13).
    """

    def __init__(self):
        super().__init__()
        # (in_channels, hidden_channels, out_channels) for stage0 .. stage3.
        specs = (
            (32, 32, 128),
            (128, 64, 256),
            (256, 128, 512),
            (512, 256, 1024),
        )
        for idx, (c_in, c_hidden, c_out) in enumerate(specs):
            self.add_module(f"stage{idx}", Stage(c_in, c_hidden, c_out))

    def forward(self, x):
        """Feed ``x`` through stage0 .. stage3 in registration order."""
        for stage in self.children():
            x = stage(x)
        return x


In [9]:
# Full backbone: four Stages (32 -> 1024 channels), each ceil-halving the length.
body = Body()

In [10]:
res = body(X)
# Four stride-2 stages: length 200 -> 100 -> 50 -> 25 -> 13 (ceil division).
assert res.shape == (X.shape[0], 1024, 13)
res.shape

torch.Size([32, 1024, 13])

In [11]:
class Stem(nn.Module):
    """Input stem: one stride-2 3x3 conv + BatchNorm + ReLU.

    Args:
        in_channels: channel count of the raw input signal.
        out_channels: width produced by the stem (default 32, the value the
            first Stage of ``Body`` expects).

    Maps (batch, in_channels, L) to (batch, out_channels, ceil(L / 2)).
    """

    def __init__(self, in_channels, out_channels=32):
        super().__init__()
        self.conv0 = nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
        self.bn0 = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        """Apply conv -> BatchNorm -> ReLU."""
        return F.relu(self.bn0(self.conv0(x)), inplace=True)


In [12]:
stem = Stem(in_channels=6)

In [13]:
# Raw demo input: (batch, 6 input channels, length 400) for the Stem / ResNet.
X_pre = torch.randn(32, 6, 400)


In [14]:
stem_out = stem(X_pre)
# The stem maps 6 -> 32 channels and halves 400 -> 200, i.e. exactly the
# shape of the X fed to Block / Stage / Body above.
assert stem_out.shape == X.shape
stem_out.shape

torch.Size([32, 32, 200])

In [None]:
class ClassificationHead(nn.Module):
    """Global-average-pool a feature map and project it to logits.

    Args:
        in_channels: channel count of the incoming (batch, C, length) tensor.
        dropout: dropout probability applied after the hidden layer.
        hidden_features: width of the hidden Linear layer (default 256).
        out_features: number of output logits (default 1).

    Returns logits of shape (batch, out_features).
    """

    def __init__(self, in_channels, dropout=0.0, hidden_features=256, out_features=1):
        super().__init__()
        self.gap = nn.AdaptiveAvgPool1d(1)
        self.l0 = nn.Linear(in_channels, hidden_features)
        self.dropout = nn.Dropout(p=dropout)
        self.out = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        """Pool over length, then Linear -> ReLU -> Dropout -> Linear."""
        # Remove only the pooled length axis. A bare squeeze() would also drop
        # the batch axis when batch == 1 (and the channel axis when C == 1),
        # returning shape (1,) instead of (1, 1).
        x = self.gap(x).squeeze(-1)
        x = F.relu(self.l0(x), inplace=True)
        x = self.dropout(x)
        x = self.out(x)
        return x


In [16]:
class_head = ClassificationHead(in_channels=1024)


In [17]:
res.size()

torch.Size([32, 1024, 13])

In [18]:
logits = class_head(res)
assert logits.shape == (res.shape[0], 1)  # one logit per sample
logits.shape

torch.Size([32, 1024])


torch.Size([32, 1])

In [20]:
class ResNet(nn.Module):
    """End-to-end 1-D ResNet: Stem -> Body -> ClassificationHead.

    Args:
        in_channels: channel count of the raw input signal (default 6).
        dropout: dropout probability passed to the classification head.

    The head is built for 1024 channels, the width of Body's last stage.
    """

    def __init__(self, in_channels=6, dropout=0.0):
        super().__init__()
        self.stem = Stem(in_channels)
        self.body = Body()
        self.class_head = ClassificationHead(1024, dropout)

    def forward(self, X):
        """Return the classification head's logits for input ``X``."""
        features = self.body(self.stem(X))
        return self.class_head(features)

In [21]:
res_net = ResNet(in_channels=6, dropout=0.0)

In [23]:
net_logits = res_net(X_pre)
assert net_logits.shape == (X_pre.shape[0], 1)  # end-to-end: one logit per sample
net_logits.shape

torch.Size([32, 1024])


torch.Size([32, 1])