In [73]:
import torch
import torch.nn as nn
import torch.nn.functional as F

### Swish活性化関数をスクラッチ実装

In [74]:
class Swish(nn.Module):
    """Swish activation implemented from scratch: ``f(z) = z * sigmoid(beta * z)``.

    With ``beta=1.0`` this is identical to ``torch.nn.SiLU``.

    Args:
        beta: Scale applied to the sigmoid's argument. ``beta=0`` gives the
            linear function ``z / 2``.
    """

    def __init__(self, beta=1.0):
        super().__init__()
        self.beta = beta

    def forward(self, z):
        """Apply Swish element-wise and return a tensor shaped like ``z``."""
        # beta belongs inside the sigmoid; multiplying the output by beta
        # would only rescale SiLU instead of changing the gate's sharpness.
        return z * torch.sigmoid(self.beta * z)

In [75]:
### Test: run the scratch Swish (default beta=1.0) on a random (1, 1, 3, 3) tensor
# NOTE: input_tensor is reused by the following nn.SiLU comparison cell.
input_tensor = torch.randn(1, 1, 3, 3)
swish = Swish()
out = swish(input_tensor)
out

tensor([[[[ 0.1858, -0.1523,  0.8542],
          [ 0.2065, -0.1659, -0.2387],
          [-0.0512, -0.2001, -0.0118]]]])

### PyTorchのSwish活性化関数(`nn.SiLU`)

In [76]:
# Reference: PyTorch's built-in SiLU applied to the same input_tensor as the Swish test above
nn.SiLU()(input_tensor)

tensor([[[[ 0.1858, -0.1523,  0.8542],
          [ 0.2065, -0.1659, -0.2387],
          [-0.0512, -0.2001, -0.0118]]]])

### ResidualBlock

In [77]:
class ResidualBlock(nn.Module):
    """Basic (post-activation) residual block.

    Main path: Conv3x3(stride) -> BN -> act -> Conv3x3 -> BN.
    Shortcut:  identity when ``in_ch == out_ch`` and ``stride == 1``,
               otherwise Conv1x1(stride) -> BN so the shapes match.
    Output:    ``act(main(x) + shortcut(x))``.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        stride: Stride of the first 3x3 convolution and of the projection shortcut.
        activation: ``'relu'`` (``nn.ReLU``) or ``'swish'`` (``nn.SiLU``).

    Raises:
        ValueError: If ``activation`` is neither ``'relu'`` nor ``'swish'``.
    """

    def __init__(self, in_ch, out_ch, stride=1, activation='relu'):
        super().__init__()
        if activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'swish':
            self.activation = nn.SiLU()
        else:
            raise ValueError('not support your activation. Choose from ["relu", "swish"]')

        # The activation modules are stateless, so one instance is shared
        # between the main path and the post-addition activation.
        self.main_conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            self.activation,
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
        )

        # An empty Sequential acts as the identity mapping.
        self.shortcut = nn.Sequential()

        # Project the input when channel count or spatial size changes.
        if in_ch != out_ch or stride != 1:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride, padding=0, bias=False),
                nn.BatchNorm2d(out_ch),
            )

    def forward(self, x):
        """Return ``act(main_conv(x) + shortcut(x))``."""
        out = self.main_conv(x) + self.shortcut(x)
        # The post-activation design applies the non-linearity after the sum.
        return self.activation(out)

In [78]:
# Test: a 3 -> 64 channel block with stride=1 keeps the 28x28 spatial size
input_tensor = torch.randn(8, 3, 28, 28)
residual = ResidualBlock(3, 64, stride=1, activation='swish')
out = residual(input_tensor) 
out.size() # [8, 64, 28, 28]

torch.Size([8, 64, 28, 28])

### pre-activation residual

In [79]:
class PreActivationResidualBlock(nn.Module):
    """Residual block with BN and activation placed before each 3x3 convolution.

    Main path: BN(in_ch) -> act -> Conv3x3(in_ch -> out_ch, stride)
               -> BN(out_ch) -> act -> Conv3x3(out_ch -> out_ch).
    Shortcut:  an empty ``nn.Sequential`` when ``in_ch == out_ch`` and
               ``stride == 1``; otherwise Conv1x1(stride) -> BN.
    The two paths are summed and returned without a further activation.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        stride: Stride of the first 3x3 convolution and of the projection shortcut.
        activation: ``'relu'`` (``nn.ReLU``) or ``'swish'`` (``nn.SiLU``).

    Raises:
        ValueError: If ``activation`` is neither ``'relu'`` nor ``'swish'``.
    """

    def __init__(self, in_ch, out_ch, stride=1, activation='relu'):
        super().__init__()

        # Pick the activation by name; fall through to an error if nothing matches.
        for act_name, act_factory in (('relu', nn.ReLU), ('swish', nn.SiLU)):
            if activation == act_name:
                self.activation = act_factory()
                break
        else:
            raise ValueError('not support your activation. Choose from ["relu", "swish"]')

        # Same activation instance is used at both positions of the main path.
        pre_act_layers = [
            nn.BatchNorm2d(in_ch),
            self.activation,
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            self.activation,
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False),
        ]
        self.main_conv = nn.Sequential(*pre_act_layers)

        shapes_match = in_ch == out_ch and stride == 1
        if shapes_match:
            # Empty Sequential passes the input through unchanged.
            self.shortcut = nn.Sequential()
        else:
            # 1x1 projection so the shortcut matches the main path's shape.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride, padding=0, bias=False),
                nn.BatchNorm2d(out_ch),
            )

    def forward(self, x):
        """Return ``main_conv(x) + shortcut(x)``."""
        summed = self.main_conv(x)
        summed.add_(self.shortcut(x))  # in-place accumulation onto the main-path output
        return summed

In [81]:
# Test: a 3 -> 64 channel pre-activation block with stride=1 keeps the 28x28 spatial size
input_tensor = torch.randn(8, 3, 28, 28)
preresidual = PreActivationResidualBlock(3, 64, stride=1, activation='relu')
out = preresidual(input_tensor) 
out.size() # [8, 64, 28, 28]

torch.Size([8, 64, 28, 28])

### Bottleneck構造

In [84]:
class BottleneckStracture(nn.Module):
    """Bottleneck residual block (1x1 reduce -> 3x3 -> 1x1 expand) with a shortcut.

    Main path: Conv1x1(in_ch -> out_ch) -> BN -> act
               -> Conv3x3(out_ch -> out_ch, stride) -> BN -> act
               -> Conv1x1(out_ch -> out_ch * 4) -> BN.
    Shortcut:  identity when ``in_ch == out_ch * 4`` and ``stride == 1``,
               otherwise Conv1x1(stride) -> BN so the shapes match.
    Output:    ``act(main(x) + shortcut(x))`` with ``out_ch * 4`` channels.

    The class name keeps its original spelling so existing callers keep working.

    Args:
        in_ch: Number of input channels.
        out_ch: Width of the bottleneck; the block outputs ``out_ch * 4`` channels.
        stride: Stride of the 3x3 convolution and of the projection shortcut.
        activation: ``'relu'`` (``nn.ReLU``) or ``'swish'`` (``nn.SiLU``).

    Raises:
        ValueError: If ``activation`` is neither ``'relu'`` nor ``'swish'``.
    """

    def __init__(self, in_ch, out_ch, stride=1, activation='relu'):
        super().__init__()

        if activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'swish':
            self.activation = nn.SiLU()
        else:
            raise ValueError('not support your activation. Choose from ["relu", "swish"]')

        expanded_ch = out_ch * 4

        # The 3x3 convolution runs on the narrow out_ch width; only the final
        # 1x1 convolution expands to out_ch * 4 channels.
        self.main_conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_ch),
            self.activation,
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            self.activation,
            nn.Conv2d(out_ch, expanded_ch, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(expanded_ch),
        )

        # An empty Sequential acts as the identity mapping.
        self.shortcut = nn.Sequential()

        # Project the input when channel count or spatial size changes.
        if in_ch != expanded_ch or stride != 1:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_ch, expanded_ch, kernel_size=1, stride=stride, padding=0, bias=False),
                nn.BatchNorm2d(expanded_ch),
            )

    def forward(self, x):
        """Return ``act(main_conv(x) + shortcut(x))``."""
        # Without adding the shortcut this would be a plain conv stack,
        # not a residual block.
        out = self.main_conv(x) + self.shortcut(x)
        return self.activation(out)

In [87]:
# Test: 256 input channels with out_ch=64 gives 64 * 4 = 256 output channels at stride=1
input_tensor = torch.randn(3, 256, 28, 28)
bottleneck = BottleneckStracture(256, 64, stride=1, activation='relu')
out = bottleneck(input_tensor)
out.shape

torch.Size([3, 256, 28, 28])