In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

### Residual_block
Stage1에서는 ResNet50에서 사용하는 bottleneck Residual block(1×1 → 3×3 → 1×1 conv)을 그대로 사용하였습니다.


### Basic Residual_block
Stage2 이상부터는 3×3 conv 2개를 이어 붙인 Basic Residual block을 사용하였습니다. (단, 아래 Exchange block 구현에서는 가장 높은 해상도의 branch 0에 bottleneck Residual block을 사용합니다.)

In [2]:
class Residual_block(nn.Module):
  """ResNet-50 style bottleneck residual block (1x1 -> 3x3 -> 1x1 conv, each + BN).

  Channel count and spatial size are preserved (all convs use stride 1, the 3x3
  conv is padded by 1), so the block input is added back unchanged as the skip
  connection before the final ReLU.

  Args:
    in_channels: channels of both the input and the output.
    bottleneck_channels: width of the inner 1x1 / 3x3 convolutions.
  """
  def __init__(self, in_channels=256, bottleneck_channels=64):
    super().__init__()

    def conv_bn(c_in, c_out, k, pad):
      # One stride-1 convolution followed by BatchNorm over its output channels.
      return nn.Sequential(
          nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=k, stride=1, padding=pad),
          nn.BatchNorm2d(c_out),
      )

    self.layer1 = conv_bn(in_channels, bottleneck_channels, 1, 0)          # reduce
    self.layer2 = conv_bn(bottleneck_channels, bottleneck_channels, 3, 1)  # spatial 3x3
    self.layer3 = conv_bn(bottleneck_channels, in_channels, 1, 0)          # expand back

  def forward(self, x):
    shortcut = x
    out = F.relu(self.layer1(x))
    out = F.relu(self.layer2(out))
    # No ReLU before the addition: the residual sum is activated as a whole.
    return F.relu(self.layer3(out) + shortcut)


class Basic_Residual_block(nn.Module):
  """Basic residual block: two 3x3 conv + BN layers with an identity shortcut.

  Both convolutions keep the channel count and (stride 1, padding 1) the spatial
  size, so the input can be added directly to the second BN output.

  Args:
    in_channels: channels of both the input and the output.
  """
  def __init__(self, in_channels=256):
    super().__init__()
    same_shape_3x3 = dict(in_channels=in_channels, out_channels=in_channels,
                          kernel_size=3, stride=1, padding=1)
    self.conv1 = nn.Conv2d(**same_shape_3x3)
    self.conv2 = nn.Conv2d(**same_shape_3x3)
    self.batchNorm1 = nn.BatchNorm2d(in_channels)
    self.batchNorm2 = nn.BatchNorm2d(in_channels)

  def forward(self, x):
    residual = F.relu(self.batchNorm1(self.conv1(x)))
    residual = self.batchNorm2(self.conv2(residual))
    # Shortcut is added before the final activation.
    return F.relu(residual + x)




In [3]:
# Shape check: the bottleneck block must preserve (N, C, H, W).
data = torch.randn(1, 256, 512, 512)
print(data.shape)

residual = Residual_block()
# Inference-only demo: no_grad avoids keeping every intermediate activation of a
# 512x512 forward pass alive for a backward pass that never happens.
with torch.no_grad():
  result = residual(data)
print(result.shape)
assert result.shape == data.shape

torch.Size([1, 256, 512, 512])
torch.Size([1, 256, 512, 512])


## 입력 크기가 서로 다른 여러 모듈을 동적으로 관리하기 위한 nn.ModuleList() 사용 예시

In [4]:
'''Practice
nn.ModuleList() 사용방법 연습
'''

class Model(nn.Module):
  """Stack of nn.Linear layers built from consecutive entries of channels_info.

  Layer k maps channels_info[k] -> channels_info[k + 1]; with fewer than two
  entries no layers are created and forward returns its input unchanged.

  Args:
    channels_info: sequence of feature sizes [d0, d1, ..., dn].
  """
  def __init__(self, channels_info):
    super().__init__()
    self.layers = nn.ModuleList(
        nn.Linear(channels_info[k], channels_info[k + 1])
        for k in range(len(channels_info) - 1)
    )

  def forward(self, x):
    out = x
    for linear in self.layers:
      out = linear(out)
    return out


# Build one toy MLP per channel configuration (instead of copy-pasting the cell)
# and check the output width equals the last entry of each configuration.
demo_channel_configs = [[10, 20, 30, 40, 50], [30, 50, 80, 100, 200]]
for stage_idx, channels_info in enumerate(demo_channel_configs, start=1):
  model = Model(channels_info)
  x = torch.randn(4, channels_info[0])
  out = model(x)
  assert out.shape == (4, channels_info[-1])
  print(f'stage{stage_idx} Out shape : {out.shape}')


stage1 Out shape : torch.Size([4, 50])
stage2 Out shape : torch.Size([4, 200])


### 각 Stage는 Exchange block으로 구성되므로, Exchange block을 하나의 모듈로 선언하여 재사용할 수 있도록 한다.

In [5]:
class Exchange_block(nn.Module):
  """Multi-resolution exchange unit reused by every HRNet stage.

  Branch k receives a feature map with input_info[k] channels. Each branch is first
  refined independently by its own residual block. Then every output branch j is
  the ReLU of the sum over all branches i after branch i is:
    * mapped to branch j's channel count (identity when the counts already match,
      otherwise 1x1 conv + BatchNorm), and
    * resampled to branch j's actual spatial size (adaptive average pooling when
      i < j, nearest-neighbour upsampling when i > j).
  Resampling targets the real size of branch j rather than assuming an exact 2x
  ratio, so odd spatial sizes (e.g. 25 -> 13 -> 7 -> 4) fuse without shape errors.
  For exact 2x ratios the result equals avg_pool2d / scale_factor upsampling.

  Args:
    input_info: per-branch channel counts, branch 0 first. Higher-index branches
      are expected to be lower resolution (presumably ~1/2 per step — confirm).

  forward(inputs) takes one tensor per branch (len(inputs) must equal
  len(input_info)) and returns a list with one fused tensor per branch.
  """
  def __init__(self, input_info):
    super().__init__()
    self.num_branch = len(input_info)
    self.branchs = nn.ModuleList()

    # Branch 0 (highest resolution) uses the bottleneck Residual_block with its
    # default bottleneck width.
    # NOTE(review): bottleneck_channels is not scaled with input_info[0], so a narrow
    # branch 0 (e.g. 32 channels) gets a wider inner width — confirm this is intended.
    self.branchs.append(Residual_block(in_channels=input_info[0]))

    # Remaining (lower-resolution) branches use the two-conv basic block.
    for channel in input_info[1:]:
      self.branchs.append(Basic_Residual_block(in_channels=channel))

    # fues_layer[j][i] maps branch i's channels to branch j's channels.
    self.fues_layer = nn.ModuleList()
    for j in range(self.num_branch):
      fs = nn.ModuleList()
      for i in range(self.num_branch):
        if input_info[i] == input_info[j]:
          fs.append(nn.Identity())
        else:
          fs.append(nn.Sequential(
              nn.Conv2d(in_channels=input_info[i], out_channels=input_info[j], kernel_size=1),
              nn.BatchNorm2d(input_info[j])
          ))
      self.fues_layer.append(fs)

  def forward(self, inputs):
    if len(inputs) != self.num_branch:
      raise ValueError(
          f'Exchange_block expects {self.num_branch} branch inputs, got {len(inputs)}')

    # Refine each resolution independently with its own residual block.
    features = [branch(x) for branch, x in zip(self.branchs, inputs)]

    outputs = []
    for j in range(self.num_branch):
      target_size = features[j].shape[-2:]
      fused = 0
      for i in range(self.num_branch):
        x = self.fues_layer[j][i](features[i])  # match branch j's channel count
        if i < j:
          # Higher-resolution source: pool down to branch j's exact size.
          x = F.adaptive_avg_pool2d(x, target_size)
        elif i > j:
          # Lower-resolution source: upsample to branch j's exact size.
          x = F.interpolate(x, size=target_size, mode='nearest')
        fused = fused + x
      outputs.append(F.relu(fused))
    return outputs

In [6]:
input_info = [256, 512, 1024, 2048]
block = Exchange_block(input_info)

# Each successive branch halves the spatial resolution and doubles the channels.
spatial_sizes = [56, 28, 14, 7]
inputs = [torch.randn(2, c, s, s) for c, s in zip(input_info, spatial_sizes)]

# Shape demo only — no gradients needed.
with torch.no_grad():
  out = block(inputs)
for i, o in enumerate(out):
  print(f"Branch {i} fused shape: {o.shape}")
  assert o.shape == inputs[i].shape  # fusion preserves each branch's shape

Branch 0 fused shape: torch.Size([2, 256, 56, 56])
Branch 1 fused shape: torch.Size([2, 512, 28, 28])
Branch 2 fused shape: torch.Size([2, 1024, 14, 14])
Branch 3 fused shape: torch.Size([2, 2048, 7, 7])


### 최종적으로 HRNet을 구현

In [7]:
class HRNet(nn.Module):
  """Simplified HRNet backbone that returns the highest-resolution feature map.

  Pipeline:
    * stem: two stride-2 3x3 conv-BN-ReLU layers (input resolution / 4).
    * stage1: 1x1 expansion to 4 * base_channels, four bottleneck Residual_blocks,
      then a 3x3 conv back to base_channels.
    * trans1: four parallel branches from the stage1 output with
      base_channels * (1, 2, 4, 8) channels at strides (1, 2, 4, 8).
    * stage2 / stage3 / stage4: 1 / 4 / 3 Exchange_blocks over the four branches.
  forward returns branch 0 of the final stage.

  Args:
    in_channels: channels of the input image.
    base_channels: width of the highest-resolution branch.
  """
  def __init__(self, in_channels=3, base_channels=64):
    super().__init__()

    def conv_bn_relu(c_in, c_out, kernel_size, stride=1, padding=0):
      # Bias-free conv (BN follows) -> BatchNorm -> in-place ReLU, returned as a
      # list so several triples can be spliced into one flat nn.Sequential.
      return [
          nn.Conv2d(c_in, c_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
          nn.BatchNorm2d(c_out),
          nn.ReLU(inplace=True),
      ]

    self.stem = nn.Sequential(
        *conv_bn_relu(in_channels, base_channels, 3, stride=2, padding=1),
        *conv_bn_relu(base_channels, base_channels, 3, stride=2, padding=1),
    )

    c1 = base_channels
    c1_expanded = 4 * c1

    # c1 -> 4*c1 projection, 4 bottlenecks (4*c1 -> c1 -> c1 -> 4*c1), back to c1.
    self.stage1 = nn.Sequential(
        *conv_bn_relu(c1, c1_expanded, 1),
        *[Residual_block(in_channels=c1_expanded, bottleneck_channels=c1)
          for _ in range(4)],
        *conv_bn_relu(c1_expanded, c1, 3, padding=1),
    )

    widths = [c1, 2 * c1, 4 * c1, 8 * c1]

    # Branch 0 keeps resolution (1x1 conv); branch k applies k stride-2 3x3 convs,
    # doubling the width at each step.
    transitions = [nn.Sequential(*conv_bn_relu(c1, c1, 1))]
    for depth in range(1, len(widths)):
      layers = []
      for step in range(depth):
        layers += conv_bn_relu(widths[step], widths[step + 1], 3, stride=2, padding=1)
      transitions.append(nn.Sequential(*layers))
    self.trans1 = nn.ModuleList(transitions)

    def exchange_stage(num_blocks):
      return nn.ModuleList([Exchange_block(widths) for _ in range(num_blocks)])

    self.stage2 = exchange_stage(1)
    self.stage3 = exchange_stage(4)
    self.stage4 = exchange_stage(3)

  def forward(self, x):
    x = self.stage1(self.stem(x))
    feature_maps = [transition(x) for transition in self.trans1]

    # Stages 2, 3 and 4 applied back to back.
    for block in [*self.stage2, *self.stage3, *self.stage4]:
      feature_maps = block(feature_maps)

    high_res, _, _, _ = feature_maps
    return high_res


In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HRNet(in_channels=3, base_channels=32).to(device)
model.eval()

x = torch.randn(1, 3, 224, 224).to(device)

with torch.no_grad():
    out = model(x)
    # Re-run stem + stage1 + transition to inspect the four branch inputs.
    x1 = model.stem(x)
    x1 = model.stage1(x1)
    branches = [t(x1) for t in model.trans1]

# Expected for a 224x224 input and base_channels=32: 1/4, 1/8, 1/16, 1/32 resolution.
expected_branch_shapes = [(1, 32, 56, 56), (1, 64, 28, 28), (1, 128, 14, 14), (1, 256, 7, 7)]

print("Final output shape:", out.shape)
assert out.shape == expected_branch_shapes[0]
for i, b in enumerate(branches):
    print(f" Branch {i} shape:", b.shape)
    assert b.shape == expected_branch_shapes[i]

Final output shape: torch.Size([1, 32, 56, 56])
 Branch 0 shape: torch.Size([1, 32, 56, 56])
 Branch 1 shape: torch.Size([1, 64, 28, 28])
 Branch 2 shape: torch.Size([1, 128, 14, 14])
 Branch 3 shape: torch.Size([1, 256, 7, 7])
