# **SENet**
This notebook shows how to build the SENet (Squeeze-and-Excitation Network) architecture with PyTorch.

![image](https://hackmd.io/_uploads/SyC2RkB_6.png)

- [source paper](https://arxiv.org/abs/1709.01507)

## Import Packages

```python
# PyTorch packages
import torch
import torch.nn as nn
```

## SENet Architecture

![image](https://hackmd.io/_uploads/BkI60kr_6.png)

- [source paper](https://arxiv.org/abs/1709.01507)
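Following the paper's notation, one SE block can be written as three steps applied to a transformed feature map $\mathbf{U} \in \mathbb{R}^{H \times W \times C}$ with channels $\mathbf{u}_c$: squeeze (global average pooling), excitation (a two-layer bottleneck with ReLU $\delta$ and sigmoid $\sigma$, biases omitted as in the paper), and scale (channel-wise multiplication). Here $r$ is the reduction ratio, `reduction_ratio` in `SEBlock`. The last line counts the extra weights, evaluated for the $C = 32$, $r = 16$ setting used in this notebook's test (the `nn.Linear` biases add another $C/r + C = 34$ parameters).

$$
z_c = \frac{1}{H \times W} \sum_{i=1}^{H} \sum_{j=1}^{W} u_c(i, j)
$$

$$
\mathbf{s} = \sigma\left(\mathbf{W}_2\, \delta(\mathbf{W}_1 \mathbf{z})\right), \qquad \mathbf{W}_1 \in \mathbb{R}^{\frac{C}{r} \times C},\quad \mathbf{W}_2 \in \mathbb{R}^{C \times \frac{C}{r}}
$$

$$
\tilde{\mathbf{x}}_c = s_c \cdot \mathbf{u}_c
$$

$$
|\mathbf{W}_1| + |\mathbf{W}_2| = \frac{2C^2}{r}, \qquad C = 32,\ r = 16:\quad \frac{2 \cdot 32^2}{16} = 128
$$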

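```python
class SEBlock(nn.Module):
    def __init__(self, in_channels, out_channels, reduction_ratio=16):
        super().__init__()
        # Transformation F_tr: two 3x3 conv-BN-ReLU layers producing the feature map U
        self.basic_module = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding='same'),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding='same'),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        # Squeeze F_sq: global average pooling, (N, C, H, W) -> (N, C)
        self.squeeze = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        # Excitation F_ex: bottleneck MLP (C -> C/r -> C); the sigmoid yields per-channel weights in (0, 1).
        # max(1, ...) keeps the hidden layer non-empty when out_channels < reduction_ratio.
        hidden_channels = max(1, out_channels // reduction_ratio)
        self.excitation = nn.Sequential(
            nn.Linear(out_channels, hidden_channels),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_channels, out_channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        features = self.basic_module(x)       # U: (N, C, H, W)
        weights = self.squeeze(features)      # z: (N, C)
        weights = self.excitation(weights)    # s: (N, C)
        # Scale F_scale: reshape to (N, C, 1, 1) so each channel map is multiplied by one scalar
        weights = weights.unsqueeze(2).unsqueeze(3)
        return features * weights
```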
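To see what each stage does to tensor shapes, the sketch below reproduces the squeeze, excitation, and scale steps on a small random tensor outside the module. The sizes `(2, 32, 8, 8)`, `r = 16`, and the variable names are arbitrary choices for illustration. It also checks two assumptions the module relies on: that `AdaptiveAvgPool2d(1)` followed by `Flatten()` is just a spatial mean, and that reshaping the weights to `(N, C, 1, 1)` makes broadcasting scale each channel map by a single scalar.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
u = torch.randn(2, 32, 8, 8)  # illustrative feature map U with shape (N, C, H, W)

# Squeeze: global average pooling over H and W -> (N, C)
z_module = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten())(u)
z_manual = u.mean(dim=(2, 3))
print(torch.allclose(z_module, z_manual, atol=1e-6))  # True

# Excitation: FC (C -> C/r) -> ReLU -> FC (C/r -> C) -> sigmoid
r = 16
excitation = nn.Sequential(
    nn.Linear(32, 32 // r), nn.ReLU(), nn.Linear(32 // r, 32), nn.Sigmoid()
)
s = excitation(z_manual)
print(s.shape)  # torch.Size([2, 32])

# Scale: (N, C) -> (N, C, 1, 1), then broadcast over H and W
s4 = s.unsqueeze(2).unsqueeze(3)  # equivalent to s[:, :, None, None]
x_tilde = u * s4
print(x_tilde.shape)  # torch.Size([2, 32, 8, 8])

# Every pixel of channel 5 in sample 1 is scaled by the same weight s[1, 5]
print(torch.allclose(x_tilde[1, 5], u[1, 5] * s[1, 5]))  # True
```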


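```python
# Sanity check: spatial size is preserved and the channel count becomes out_channels
inputs = torch.randn(1, 32, 224, 224)
outputs = SEBlock(in_channels=32, out_channels=32)(inputs)
print(outputs.size())  # torch.Size([1, 32, 224, 224])
```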
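In the paper, SE blocks are also integrated into residual networks (SE-ResNet), where the recalibration is applied to the residual branch before the identity shortcut is added. The `SEBlock` above has no shortcut. What follows is a minimal sketch of such a residual unit, assuming a ResNet basic-block residual branch (two 3x3 convolutions). The class name `SEResidualBlock`, the 1x1 projection shortcut used when channel counts differ, and `bias=False` on the convolutions are illustrative choices, not taken from this notebook.

```python
import torch
import torch.nn as nn


class SEResidualBlock(nn.Module):
    """Hypothetical SE-ResNet basic block: SE rescales the residual branch,
    then the (possibly projected) identity is added and ReLU is applied."""

    def __init__(self, in_channels, out_channels, reduction_ratio=16):
        super().__init__()
        # Residual branch: conv-BN-ReLU-conv-BN (no ReLU before the addition)
        self.residual = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        hidden = max(1, out_channels // reduction_ratio)
        # Squeeze + excitation, producing per-channel weights of shape (N, C)
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(out_channels, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, out_channels),
            nn.Sigmoid(),
        )
        # Assumption of this sketch: 1x1 projection only when channels change
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        u = self.residual(x)
        s = self.se(u)[:, :, None, None]  # (N, C, 1, 1) for broadcasting
        return self.relu(u * s + self.shortcut(x))


block = SEResidualBlock(in_channels=32, out_channels=64)
y = block(torch.randn(1, 32, 56, 56))
print(y.size())  # torch.Size([1, 64, 56, 56])
```

With SE placed on the residual branch, only the residual is rescaled. The shortcut path passes through unchanged, apart from the optional projection.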