In [1]:
import torch
from torch import nn

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class BasicConvBlock(nn.Module):
    def __init__(
        self, in_channels, out_channels, kernel_size, stride=1, padding=0, has_relu=True
    ) -> None:
        super(BasicConvBlock, self).__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=False
        )
        self.batchNorm = nn.BatchNorm2d(out_channels)
        self.has_relu = has_relu
        if has_relu:
            self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.batchNorm(x)
        if self.has_relu:
            x = self.relu(x)
        return x


BasicConvBlock(256, 64, 1, 1, 1)

BasicConvBlock(
  (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), padding=(1, 1), bias=False)
  (batchNorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
)

In [3]:
class BottleneckBlock(nn.Module):
    def __init__(self, in_channels, factor=4) -> None:
        super(BottleneckBlock, self).__init__()
        hidden_channels = in_channels // factor
        self.layers = nn.Sequential(
            BasicConvBlock(in_channels, hidden_channels, kernel_size=1),
            BasicConvBlock(hidden_channels, hidden_channels, kernel_size=3, padding=1),
            BasicConvBlock(hidden_channels, in_channels, kernel_size=1),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x0 = x
        x1 = self.layers(x)
        x = x0 + x1
        return self.relu(x)


print(BottleneckBlock(256))
print(BottleneckBlock(256)(torch.zeros((3, 256, 32, 32))).shape)

BottleneckBlock(
  (layers): Sequential(
    (0): BasicConvBlock(
      (conv): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (batchNorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (1): BasicConvBlock(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (batchNorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (2): BasicConvBlock(
      (conv): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (batchNorm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (relu): ReLU(inplace=True)
)
torch.Size([3, 256, 32, 32])


In [4]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels) -> None:
        super(ResidualBlock, self).__init__()
        self.layers = nn.Sequential(
            BasicConvBlock(in_channels, out_channels, kernel_size=3, padding=1),
            BasicConvBlock(
                out_channels, out_channels, kernel_size=3, padding=1, has_relu=False
            ),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x0 = x
        x1 = self.layers(x)
        x = x0 + x1
        x = self.relu(x)
        return x


print(ResidualBlock(64, 64))
print(ResidualBlock(64, 64)(torch.zeros((5, 64, 32, 32))).shape)

ResidualBlock(
  (layers): Sequential(
    (0): BasicConvBlock(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (batchNorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (1): BasicConvBlock(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (batchNorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (relu): ReLU(inplace=True)
)
torch.Size([5, 64, 32, 32])


In [5]:
class ConvUpsampleBlock(nn.Module):
    def __init__(self, in_channels, out_channels) -> None:
        super(ConvUpsampleBlock, self).__init__()
        self.layer = BasicConvBlock(in_channels, out_channels, 1)
        self.upsampling = nn.Upsample(scale_factor=2, mode="bilinear")

    def forward(self, x):
        x = self.layer(x)
        x = self.upsampling(x)
        return x


print(ConvUpsampleBlock(64, 64))

ConvUpsampleBlock(
  (layer): BasicConvBlock(
    (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (batchNorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
  )
  (upsampling): Upsample(scale_factor=2.0, mode=bilinear)
)


In [6]:
class HRTransformBlock(nn.Module):
    def __init__(self, in_channels, in_resolution_idx, out_resolution_idx) -> None:
        super(HRTransformBlock, self).__init__()
        self.type = in_resolution_idx - out_resolution_idx
        self.layers = (
            nn.Identity()  # if in_resolution and out_resolution are same, only identity
            if self.type == 0
            else nn.Sequential(
                *[
                    BasicConvBlock(
                        in_hid_channels,
                        out_hid_channels,
                        kernel_size=3,
                        stride=2,
                        padding=1,
                    )
                    for in_hid_channels, out_hid_channels in [
                        (in_channels * int(2**i), in_channels * int(2 ** (i + 1)))
                        for i in range(abs(self.type))
                    ]  # if downsampling, out_channels is in_channels * 2
                ]
            )  # if out_resolution is smaller(=idx is lager), downsampling by conv
            if self.type < 0
            else nn.Sequential(
                *[
                    ConvUpsampleBlock(in_hid_channels, out_hid_channels)
                    for in_hid_channels, out_hid_channels in [
                        (in_channels // int(2**i), in_channels // int(2 ** (i + 1)))
                        for i in range(abs(self.type))
                    ]  # if upsampling, out_channels is in_channels // 2
                ]
            )  # if out_resolution is larger(=idx is smaller), upsampling by conv and bilinear
        )

    def forward(self, x):
        x = self.layers(x)
        return x


for tmp_model in [
    HRTransformBlock(64, 1, 1),
    HRTransformBlock(64, 1, 3),
    HRTransformBlock(64, 3, 1),
]:
    print(tmp_model)
    print(tmp_model(torch.zeros((3, 64, 32, 32))).shape)
    print()

HRTransformBlock(
  (layers): Identity()
)
torch.Size([3, 64, 32, 32])

HRTransformBlock(
  (layers): Sequential(
    (0): BasicConvBlock(
      (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (batchNorm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (1): BasicConvBlock(
      (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (batchNorm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
)
torch.Size([3, 256, 8, 8])

HRTransformBlock(
  (layers): Sequential(
    (0): ConvUpsampleBlock(
      (layer): BasicConvBlock(
        (conv): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (batchNorm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
      )
      (upsampl

In [7]:
class HRModularizedBlock(nn.Module):
    def __init__(self, in_channels) -> None:
        super(HRModularizedBlock, self).__init__()
        len_res_blocks = 4
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(in_channels, in_channels) for _ in range(len_res_blocks)]
        )

    def forward(self, x):
        x = self.residual_blocks(x)
        return x


print(HRModularizedBlock(64))
print(HRModularizedBlock(64)(torch.zeros((3, 64, 32, 32))).shape)

HRModularizedBlock(
  (residual_blocks): Sequential(
    (0): ResidualBlock(
      (layers): Sequential(
        (0): BasicConvBlock(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (batchNorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (1): BasicConvBlock(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (batchNorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (relu): ReLU(inplace=True)
    )
    (1): ResidualBlock(
      (layers): Sequential(
        (0): BasicConvBlock(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (batchNorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        

In [8]:
class HRStem(nn.Module):
    def __init__(self, in_channels, out_channels) -> None:
        super(HRStem, self).__init__()
        self.layers = nn.Sequential(
            BasicConvBlock(
                in_channels, out_channels, kernel_size=3, stride=2, padding=1
            ),
            BasicConvBlock(
                out_channels, out_channels, kernel_size=3, stride=2, padding=1
            ),
        )

    def forward(self, x):
        x = self.layers(x)
        return x


tmp_model = HRStem(3, 64)
tmp_x = torch.zeros((7, 3, 128, 128))
print(tmp_model)
print(" in:", tmp_x.shape)
print("out:", tmp_model(tmp_x).shape)

HRStem(
  (layers): Sequential(
    (0): BasicConvBlock(
      (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (batchNorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (1): BasicConvBlock(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (batchNorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
)
 in: torch.Size([7, 3, 128, 128])
out: torch.Size([7, 64, 32, 32])


In [9]:
class HRStage1(nn.Module):
    def __init__(self, in_channels) -> None:
        super(HRStage1, self).__init__()
        self.bottlenecks = nn.Sequential(
            *[BottleneckBlock(in_channels) for _ in range(4)]
        )
        self.stage1_from_1_to_1 = HRTransformBlock(
            in_channels, in_resolution_idx=1, out_resolution_idx=1
        )
        self.stage1_from_1_to_2 = HRTransformBlock(
            in_channels, in_resolution_idx=1, out_resolution_idx=2
        )

    def forward(self, x):
        x = self.bottlenecks(x)
        x0 = self.stage1_from_1_to_1(x)
        x1 = self.stage1_from_1_to_2(x)
        return x0, x1


tmp_model = HRStage1(64)
tmp_x = torch.zeros((3, 64, 32, 32))
print(tmp_model)
print(" in:", tmp_x.shape)
print("out:", end=" ")
for tmp_x in tmp_model(tmp_x):
    print(tmp_x.shape, end=" ")

HRStage1(
  (bottlenecks): Sequential(
    (0): BottleneckBlock(
      (layers): Sequential(
        (0): BasicConvBlock(
          (conv): Conv2d(64, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (batchNorm): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (1): BasicConvBlock(
          (conv): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (batchNorm): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): BasicConvBlock(
          (conv): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (batchNorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
      (relu): ReLU(inplace=True)
    )
    (1): BottleneckBlock(
      (layers): Sequential(
        (0): 

In [10]:
class HRSubStream(nn.Module):
    def __init__(self, in_channels, stage_idx, resolution_idx) -> None:
        super(HRSubStream, self).__init__()
        list_len_mod_blocks = [1, 4, 3]
        len_mod_blocks = list_len_mod_blocks[stage_idx - 2]
        self.modularized_blocks = nn.Sequential(
            *[HRModularizedBlock(in_channels) for _ in range(len_mod_blocks)]
        )
        self.fusion_layers = nn.ModuleDict(
            [
                (
                    f"stage{stage_idx}_from_{resolution_idx}_to_{out_resolution_idx}",
                    HRTransformBlock(in_channels, resolution_idx, out_resolution_idx),
                )
                for out_resolution_idx in range(1, stage_idx + 2)
            ]
        )

    def forward(self, x):
        x = self.modularized_blocks(x)
        list_x = [transform_block(x) for transform_block in self.fusion_layers.values()]
        return list_x


tmp_model = HRSubStream(128, 2, 2)
tmp_x = torch.zeros((3, 128, 16, 16))
print(tmp_model)
print(" in:", tmp_x.shape)
print("out:", end=" ")
for tmp_x in tmp_model(tmp_x):
    print(tmp_x.shape, end=" ")

HRSubStream(
  (modularized_blocks): Sequential(
    (0): HRModularizedBlock(
      (residual_blocks): Sequential(
        (0): ResidualBlock(
          (layers): Sequential(
            (0): BasicConvBlock(
              (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (batchNorm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (relu): ReLU(inplace=True)
            )
            (1): BasicConvBlock(
              (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (batchNorm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
          )
          (relu): ReLU(inplace=True)
        )
        (1): ResidualBlock(
          (layers): Sequential(
            (0): BasicConvBlock(
              (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (b

In [11]:
class HRNet(nn.Module):
    def __init__(self) -> None:
        super(HRNet, self).__init__()
        self.stem = HRStem(3, 64)
        self.stage1 = HRStage1(64)
        self.sub_streams = nn.ModuleDict(
            [
                (
                    f"stage{stage_idx}_resolution{resolution_idx}",
                    HRSubStream(
                        64 * int(2 ** (resolution_idx - 1)), stage_idx, resolution_idx
                    ),
                )
                for stage_idx in range(2, 5)
                for resolution_idx in range(1, stage_idx + 2)
                if resolution_idx < 5
            ]
        )
        self.human_pose_prediction_layer = nn.Conv2d(64, 16, kernel_size=1)

    def get_sub_stream(self, stage_idx, resolution_idx):
        return self.sub_streams[f"stage{stage_idx}_resolution{resolution_idx}"]

    def forward(self, x):
        x = self.stem(x)
        stage_in_x = list(self.stage1(x))
        for stage_idx in range(2, 5):  # stage2, stage3, stage4
            stage_out_x = []
            for in_idx, in_x in enumerate(
                stage_in_x
            ):  # in_resolution1, in_resolution2, ...
                for out_idx, out_x in enumerate(
                    self.get_sub_stream(stage_idx, in_idx + 1)(in_x)
                ):  # out_resolution1, out_resolution2, out_resoltuion3, ...
                    if out_idx < len(stage_out_x):
                        stage_out_x[out_idx] += out_x  # sum of out_resolutionXXX
                    else:
                        stage_out_x.append(out_x)  # just for initialize the list
            stage_in_x = stage_out_x
        x = self.human_pose_prediction_layer(stage_out_x[0])
        return x


model = HRNet()
tmp_x = torch.zeros((7, 3, 256, 256))
print(model)
print(" in:", tmp_x.shape)
print("out:", model(tmp_x).shape)

HRNet(
  (stem): HRStem(
    (layers): Sequential(
      (0): BasicConvBlock(
        (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (batchNorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
      )
      (1): BasicConvBlock(
        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (batchNorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
      )
    )
  )
  (stage1): HRStage1(
    (bottlenecks): Sequential(
      (0): BottleneckBlock(
        (layers): Sequential(
          (0): BasicConvBlock(
            (conv): Conv2d(64, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (batchNorm): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (relu): ReLU(inplace=True)
          )
          (1): BasicConvBlock(

### 후기

이번 HRNet은 네트워크 구조도를 보는 것만으로는 만들 수 없었습니다.  
HRNet 관련 GitHub를 몇 개 참고했지만 config와 method가 복잡하여 한 눈에 이해하기 어려웠습니다.  
그래서 직접 논문을 읽어 GitHub 코드와 비교하면서 만들었습니다.

HRNet을 list comprehension 문법으로 구성하여 한 번에 네트워크 블록을 이해할 수 있도록 작성했습니다.  
그리고 각 블록이 어떤 것인지 파악할 수 있도록 ModuleList가 아닌 Sequntial과 ModuleDict을 사용했습니다.  

도움이 되었으면 좋겠습니다.