In [1]:
from Libs import *

In [2]:
class SqueezeLayer(nn.Module):
    """Module form of ``torch.squeeze`` so the op can be placed in ``nn.Sequential``.

    Args:
        dim: Dimension index forwarded to the squeeze call (a tuple of indices
            also works on torch versions whose ``squeeze`` accepts one).
            Defaults to 1.
    """

    def __init__(self, dim=1):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Drop the configured dimension(s) of ``x`` where their size is 1."""
        squeezed = x.squeeze(dim=self.dim)
        return squeezed
    
class UnsqueezeLayer(nn.Module):
    """Module form of ``torch.unsqueeze`` that can insert several size-1 axes.

    Args:
        dim: Either a single int or an iterable of ints. Each index is passed
            to ``torch.unsqueeze`` in the given order, so later indices refer
            to the tensor *after* the earlier insertions. Defaults to ``(1,)``.
    """

    def __init__(self, dim=(1,)):
        super(UnsqueezeLayer, self).__init__()
        # Accept a bare int as well as an iterable: ``forward`` iterates over
        # ``self.dim``, which would raise TypeError on an int.
        self.dim = (dim,) if isinstance(dim, int) else dim

    def forward(self, x):
        """Return ``x`` with a size-1 axis inserted at each configured index."""
        for d in self.dim:
            x = torch.unsqueeze(x, dim=d)
        return x

In [3]:
class SEModule1d(nn.Module):
    """Squeeze-and-excitation channel gate for ``(batch, channels, length)`` input.

    The input is average-pooled over its last axis, passed through a two-layer
    bottleneck MLP ending in a sigmoid, and the resulting per-channel weights
    are multiplied back onto the input.

    Args:
        in_channels: Number of channels (size of axis 1) of the input.
        reduction: Divisor applied to ``in_channels`` to get the bottleneck
            width. Defaults to 16.
    """

    def __init__(self, in_channels, reduction=16):
        super(SEModule1d, self).__init__()
        # Plain floor division yields 0 whenever in_channels < reduction, which
        # would create a zero-width bottleneck whose output cannot depend on
        # the input. Keep at least one hidden unit.
        hidden_channels = max(1, in_channels // reduction)
        self.se = nn.Sequential(nn.AdaptiveAvgPool1d(1),
                                SqueezeLayer(dim=(2,)),
                                nn.Linear(in_channels, hidden_channels),
                                nn.ReLU(),
                                nn.Linear(hidden_channels, in_channels),
                                nn.Sigmoid(),
                                UnsqueezeLayer(dim=(2,)))

    def forward(self, input):
        """Scale ``input`` channel-wise by the learned gate (broadcast over length)."""
        return input*self.se(input)

class SEResBlock1d(nn.Module):
    """Residual block for 1-D feature maps with a squeeze-and-excitation gate.

    Main path: conv -> batch norm -> ReLU -> dropout -> conv -> batch norm ->
    SE gate. Its output is added to the skip path and passed through a ReLU.

    Args:
        in_channels: Channel count of the incoming feature map.
        upsample: When true the block doubles the channel count and the skip
            path becomes a 1x1 convolution followed by batch norm; otherwise
            the skip path is the identity. (Despite the name, the sequence
            length is not changed: both convolutions use "same" padding.)
        kernel_size: Kernel size of the two main-path convolutions.
        dropout: Drop probability of the dropout layer between them.
    """

    def __init__(self, in_channels, upsample=False, kernel_size=7, dropout=0.4):
        super().__init__()
        out_channels = in_channels * 2 if upsample else in_channels

        # Skip connection: project only when the channel count changes.
        if upsample:
            projection = nn.Conv1d(in_channels, out_channels, kernel_size=1)
            self.identity = nn.Sequential(projection, nn.BatchNorm1d(out_channels))
        else:
            self.identity = nn.Identity()

        main_path = [
            nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, padding="same"),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Conv1d(out_channels, out_channels, kernel_size=kernel_size, padding="same"),
            nn.BatchNorm1d(out_channels),
            SEModule1d(out_channels),
        ]
        self.convs = nn.Sequential(*main_path)
        self.act_fn = nn.ReLU()

    def forward(self, input):
        """Return ``ReLU(main_path(input) + skip(input))``."""
        residual = self.convs(input)
        shortcut = self.identity(input)
        return self.act_fn(residual + shortcut)

class SEResNet1d(nn.Module):
    """SE-ResNet feature extractor for 1-D signals.

    Maps a ``(batch, in_channels, length)`` tensor to a
    ``(batch, 2 * base_channels)`` feature vector: a stem, five
    ``SEResBlock1d`` blocks (the third doubles the channel count), global
    average pooling over the length axis, and removal of that pooled axis.

    Args:
        base_channels: Channel count produced by the stem. Defaults to 64.
        kernel_size: Kernel size used inside every residual block.
        downsample: If true the stem is a stride-2 convolution (kernel 15)
            with batch norm, ReLU and a stride-2 max pool, shortening the
            signal by roughly a factor of four. If false the stem is a single
            1x1 convolution and the length is kept.
        in_channels: Channel count of the input signal. Defaults to 1, the
            value that used to be hard-coded.
    """

    def __init__(self, base_channels=64, kernel_size=7, downsample=True, in_channels=1):
        super(SEResNet1d, self).__init__()
        if downsample:
            stem = [nn.Conv1d(in_channels, base_channels, kernel_size=15, padding=7, stride=2),
                    nn.BatchNorm1d(base_channels),
                    nn.ReLU(),
                    nn.MaxPool1d(kernel_size=3, padding=1, stride=2)]
        else:
            stem = [nn.Conv1d(in_channels, base_channels, kernel_size=1)]

        # The residual body and pooling head are the same for both stems, so
        # they are built once. Modules are created in the same order as
        # before, which keeps child indices (state_dict keys) unchanged.
        body = [SEResBlock1d(base_channels, kernel_size=kernel_size),
                SEResBlock1d(base_channels, kernel_size=kernel_size),
                SEResBlock1d(base_channels, upsample=True, kernel_size=kernel_size),
                SEResBlock1d(base_channels*2, kernel_size=kernel_size),
                SEResBlock1d(base_channels*2, kernel_size=kernel_size)]
        head = [nn.AdaptiveAvgPool1d(1),
                SqueezeLayer(dim=2)]

        self.seresnet = nn.Sequential(*stem, *body, *head)

    def forward(self, input):
        """Return the pooled ``(batch, 2 * base_channels)`` features of ``input``."""
        return self.seresnet(input)
    
# Smoke test: push a random batch of 32 single-channel signals of length 1024
# through the network and inspect the output shape.
dummy = torch.rand((32, 1, 1024))

model = SEResNet1d(downsample=True)
# Only the shape is of interest here, so skip building the autograd graph.
# The model is deliberately left in its default (training) mode.
with torch.no_grad():
    dummy_out = model(dummy)

# Default base_channels=64 is doubled once, giving 128 pooled features.
assert dummy_out.shape == (32, 128)
dummy_out.shape

torch.Size([32, 128])

In [4]:
class SEModule2d(nn.Module):
    """Squeeze-and-excitation channel gate for ``(batch, channels, H, W)`` input.

    The input is average-pooled over its two trailing axes, passed through a
    two-layer bottleneck MLP ending in a sigmoid, and the resulting
    per-channel weights are multiplied back onto the input.

    Args:
        in_channels: Number of channels (size of axis 1) of the input.
        reduction: Divisor applied to ``in_channels`` to get the bottleneck
            width. Defaults to 16.
    """

    def __init__(self, in_channels, reduction=16):
        super(SEModule2d, self).__init__()
        # Plain floor division yields 0 whenever in_channels < reduction, which
        # would create a zero-width bottleneck whose output cannot depend on
        # the input. Keep at least one hidden unit.
        hidden_channels = max(1, in_channels // reduction)
        self.se = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
                                SqueezeLayer(dim=(2, 3)),
                                nn.Linear(in_channels, hidden_channels),
                                nn.ReLU(),
                                nn.Linear(hidden_channels, in_channels),
                                nn.Sigmoid(),
                                UnsqueezeLayer(dim=(2, 3)))

    def forward(self, input):
        """Scale ``input`` channel-wise by the learned gate (broadcast over H and W)."""
        return input*self.se(input)
    
class SEResBlock2d(nn.Module):
    """Residual block for 2-D feature maps with a squeeze-and-excitation gate.

    Main path: conv -> batch norm -> ReLU -> dropout -> conv -> batch norm ->
    SE gate. Its output is added to the skip path and passed through a ReLU.

    Args:
        in_channels: Channel count of the incoming feature map.
        upsample: When true the block doubles the channel count and the skip
            path becomes a 1x1 convolution followed by batch norm; otherwise
            the skip path is the identity. (Despite the name, the spatial size
            is not changed: both convolutions use "same" padding.)
        kernel_size: Kernel size of the two main-path convolutions.
        dropout: Drop probability of the dropout layer between them.
    """

    def __init__(self, in_channels, upsample=False, kernel_size=7, dropout=0.4):
        super().__init__()
        out_channels = in_channels * 2 if upsample else in_channels

        # Skip connection: project only when the channel count changes.
        if upsample:
            projection = nn.Conv2d(in_channels, out_channels, kernel_size=1)
            self.identity = nn.Sequential(projection, nn.BatchNorm2d(out_channels))
        else:
            self.identity = nn.Identity()

        main_path = [
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding="same"),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, padding="same"),
            nn.BatchNorm2d(out_channels),
            SEModule2d(out_channels),
        ]
        self.convs = nn.Sequential(*main_path)
        self.act_fn = nn.ReLU()

    def forward(self, input):
        """Return ``ReLU(main_path(input) + skip(input))``."""
        residual = self.convs(input)
        shortcut = self.identity(input)
        return self.act_fn(residual + shortcut)

class SEResNet2d(nn.Module):
    """SE-ResNet feature extractor for 2-D inputs.

    Maps a ``(batch, in_channels, H, W)`` tensor to a
    ``(batch, 2 * base_channels)`` feature vector: a 1x1 convolution stem,
    five ``SEResBlock2d`` blocks (the third doubles the channel count), global
    average pooling over both spatial axes, and removal of the pooled axes.

    Args:
        base_channels: Channel count produced by the stem. Defaults to 64.
        kernel_size: Kernel size used inside every residual block.
        in_channels: Channel count of the input. Defaults to 1, the value that
            used to be hard-coded.
    """

    def __init__(self, base_channels=64, kernel_size=7, in_channels=1):
        super(SEResNet2d, self).__init__()
        self.seresnet = nn.Sequential(nn.Conv2d(in_channels, base_channels, kernel_size=1),
                                      SEResBlock2d(base_channels, kernel_size=kernel_size),
                                      SEResBlock2d(base_channels, kernel_size=kernel_size),
                                      SEResBlock2d(base_channels, upsample=True, kernel_size=kernel_size),
                                      SEResBlock2d(base_channels*2, kernel_size=kernel_size),
                                      SEResBlock2d(base_channels*2, kernel_size=kernel_size),
                                      nn.AdaptiveAvgPool2d(1),
                                      SqueezeLayer(dim=(2, 3)))

    def forward(self, input):
        """Return the pooled ``(batch, 2 * base_channels)`` features of ``input``."""
        return self.seresnet(input)

# Smoke test: push a random batch of 8 single-channel 256x128 maps through the
# network and inspect the output shape.
dummy = torch.rand((8, 1, 256, 128))

model = SEResNet2d()
# Only the shape is of interest here, so skip building the autograd graph
# (the blocks run at full spatial resolution, so stored activations are large).
# The model is deliberately left in its default (training) mode.
with torch.no_grad():
    dummy_out = model(dummy)

# Default base_channels=64 is doubled once, giving 128 pooled features.
assert dummy_out.shape == (8, 128)
dummy_out.shape

torch.Size([8, 128])

: 