# 오픈포즈 네트워크 구성 및 구현

In [1]:
import torch
import torch.nn as nn
from torch.nn import init
import torchvision

## OpenPoseNet 구현

In [3]:
class OpenPoseNet(nn.Module):
    def __init__(self):
        super(OpenPoseNet, self).__init__()

        self.model0 = OpenPose_Feature()

        # PAFs 
        self.model1_1 = make_OpenPose_block("block1_1")
        self.model2_1 = make_OpenPose_block("block2_1")
        self.model3_1 = make_OpenPose_block("block3_1")
        self.model4_1 = make_OpenPose_block("block4_1")
        self.model5_1 = make_OpenPose_block("block5_1")
        self.model6_1 = make_OpenPose_block("block6_1")

        # confidence heatmap
        self.model1_2 = make_OpenPose_block("block1_2")
        self.model2_2 = make_OpenPose_block("block2_2")
        self.model3_2 = make_OpenPose_block("block3_2")
        self.model4_2 = make_OpenPose_block("block4_2")
        self.model5_2 = make_OpenPose_block("block5_2")
        self.model6_2 = make_OpenPose_block("block6_2")

    def forward(self, x):
        out1 = self.model0(x)

        out1_1 = self.model1_1(out1)
        out1_2 = self.model1_2(out1)

        out2 = torch.cat([out1_1, out1_2, out1], 1)
        out2_1 = self.model2_1(out2)
        out2_2 = self.model2_2(out2)

        out3 = torch.cat([out2_1, out2_2, out1], 1)
        out3_1 = self.model3_1(out3)
        out3_2 = self.model3_2(out3)

        out4 = torch.cat([out3_1, out3_2, out1], 1)
        out4_1 = self.model4_1(out4)
        out4_2 = self.model4_2(out4)

        out5 = torch.cat([out4_1, out4_2, out1], 1)
        out5_1 = self.model5_1(out5)
        out5_2 = self.model5_2(out5)

        out6 = torch.cat([out5_1, out5_2, out1], 1)
        out6_1 = self.model6_1(out6)
        out6_2 = self.model6_2(out6)


        # 손실 계산을 위한 스테이지 결과 저장
        save_for_loss = []
        save_for_loss.append(out1_1)
        save_for_loss.append(out1_2)
        save_for_loss.append(out2_1)
        save_for_loss.append(out2_2)
        save_for_loss.append(out3_1)
        save_for_loss.append(out3_2)
        save_for_loss.append(out4_1)
        save_for_loss.append(out4_2)
        save_for_loss.append(out5_1)
        save_for_loss.append(out5_2)
        save_for_loss.append(out6_1)
        save_for_loss.append(out6_2)

        return (out6_1, out6_2), save_for_loss

# Feature 및 Stage 모듈 설명 및 구현

In [4]:
class OpenPose_Feature(nn.Module):
    def __init__(self):
        super(OpenPose_Feature, self).__init__()

        vgg19 = torchvision.models.vgg19(pretrained=True)
        model = {}
        model["block0"] = vgg19.features[0:23]

        model["block0"].add_module("23", torch.nn.Conv2d(512,256,kernel_size=3, stride=1, padding=1))
        model["block0"].add_module("24", torch.nn.ReLU(inplace=True))
        model["block0"].add_module("25", torch.nn.Conv2d(256,128, kernel_size=3, stride=1, padding=1))
        model["block0"].add_module("26", torch.nn.ReLU(inplace=True))

        self.model = model["block0"]
    
    def forward(self, x):
        outputs = self.model(x)
        return outputs

In [7]:
def make_OpenPose_block(block_name):
    blocks = {}

    # 스테이지 1
    blocks["block1_1"] = [{"conv5_1_CPM_L1" : [128, 128, 3, 1, 1]},
                          {"conv5_2_CPM_L1" : [128, 128, 3, 1, 1]},
                          {"conv5_3_CPM_L1" : [128, 128, 3, 1, 1]},
                          {"conv5_4_CPM_L1" : [128, 512, 1, 1, 0]},
                          {"conv5_5_CPM_L1" : [512, 38, 1, 1, 0]}]

    blocks["block1_2"] = [{"conv5_1_CPM_L2" : [128, 128, 3, 1, 1]},
                          {"conv5_2_CPM_L2" : [128, 128, 3, 1, 1]},
                          {"conv5_3_CPM_L2" : [128, 128, 3, 1, 1]},
                          {"conv5_4_CPM_L2" : [128, 512, 1, 1, 0]},
                          {"conv5_5_CPM_L2" : [512, 19, 1, 1, 0]}]

    # 스테이지 2~6
    for i in range(2, 7):
        blocks["block%d_1" % i] = [
            {"Mconv1_stage%d_L1" % i : [185,128,7,1,3]},
            {"Mconv2_stage%d_L1" % i : [128,128,7,1,3]},
            {"Mconv3_stage%d_L1" % i : [128,128,7,1,3]},
            {"Mconv4_stage%d_L1" % i : [128,128,7,1,3]},
            {"Mconv5_stage%d_L1" % i : [128,128,7,1,3]},
            {"Mconv6_stage%d_L1" % i : [128,128,1,1,0]},
            {"Mconv7_stage%d_L1" % i : [128,38,1,1,0]},
        ]

        blocks["block%d_2" % i] = [
            {"Mconv1_stage%d_L2" % i : [185,128,7,1,3]},
            {"Mconv2_stage%d_L2" % i : [128,128,7,1,3]},
            {"Mconv3_stage%d_L2" % i : [128,128,7,1,3]},
            {"Mconv4_stage%d_L2" % i : [128,128,7,1,3]},
            {"Mconv5_stage%d_L2" % i : [128,128,7,1,3]},
            {"Mconv6_stage%d_L2" % i : [128,128,1,1,0]},
            {"Mconv7_stage%d_L2" % i : [128,19,1,1,0]},
        ]

    cfg_dict = blocks[block_name]

    layers = []

    for i in range(len(cfg_dict)):
        for k,v in cfg_dict[i].items():
            if "pool" in k:
                layers += [nn.MaxPool2d(kernel_size=v[0], stride=v[1], padding=v[2])]
            else:
                conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1], kernel_size=v[2], stride=v[3], padding=v[4])
                layers += [conv2d, nn.ReLU(inplace=True)]

    net = nn.Sequential(*layers[:-1])

    def _initialize_weights_norm(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    init.constant_(m.bias, 0.0)

    net.apply(_initialize_weights_norm)
    return net


In [8]:
# 동작확인
net = OpenPoseNet()
net.train()

batch_size = 2
dummy_img = torch.rand(batch_size, 3, 368,368)
outputs = net(dummy_img)
print(outputs)

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


((tensor([[[[-1.4041e-05, -1.3203e-05, -1.4262e-07,  ..., -8.5245e-06,
           -1.9878e-05, -1.2839e-05],
          [-4.1812e-05, -4.1857e-05, -2.5971e-05,  ...,  8.9378e-06,
            6.6356e-06, -1.0518e-05],
          [-2.9651e-05, -3.9320e-05, -2.4211e-05,  ..., -5.2097e-06,
           -9.8132e-06,  7.2437e-06],
          ...,
          [-1.3525e-05,  8.0578e-06,  2.0922e-06,  ...,  2.5315e-05,
            3.2976e-05,  5.5913e-05],
          [-2.4721e-05, -7.9711e-06, -1.8720e-06,  ...,  2.5735e-05,
            4.9219e-05,  4.0261e-05],
          [-2.3380e-05,  1.4878e-06,  3.8355e-06,  ...,  4.9978e-06,
            2.9812e-05,  1.9463e-05]],

         [[ 7.1048e-06,  1.3395e-05,  3.0785e-05,  ..., -1.4079e-05,
            1.1854e-05,  1.2350e-07],
          [ 3.5255e-05,  2.0126e-05,  7.3669e-06,  ..., -7.3807e-06,
            1.5082e-05,  7.7132e-06],
          [ 5.4746e-05,  1.7761e-05,  1.7947e-05,  ..., -5.3311e-05,
           -2.0095e-05,  1.4274e-05],
          ...,
   