In [None]:
#https://github.com/Daniil-Osokin/lightweight-human-pose-estimation.pytorch

In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader

In [2]:
def conv(in_channels, out_channels, kernel_size=3, padding=1, bn=True,
         dilation=1, stride=1, relu=True, bias=True):
    """Build a 2-D convolution optionally followed by batch norm and ReLU.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Kernel size passed to ``nn.Conv2d`` (default 3).
        padding: Padding passed to ``nn.Conv2d`` (default 1).
        bn: If truthy, append ``nn.BatchNorm2d(out_channels)`` after the convolution.
        dilation: Dilation passed to ``nn.Conv2d`` (default 1).
        stride: Stride passed to ``nn.Conv2d`` (default 1).
        relu: If truthy, append an in-place ``nn.ReLU`` as the last layer.
        bias: Whether the convolution has a bias term (default True).

    Returns:
        An ``nn.Sequential`` holding, in order, the convolution, the optional
        batch norm and the optional ReLU.
    """
    layers = [
        nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
    ]

    # Optional batch normalization over the convolution output.
    if bn:
        layers += [nn.BatchNorm2d(out_channels)]

    # Optional in-place ReLU activation.
    if relu:
        layers += [nn.ReLU(inplace=True)]

    # Unpack the list so every layer is passed to nn.Sequential as its own argument.
    return nn.Sequential(*layers)

In [3]:
def conv_dw(in_channels, out_channels, kernel_size=3, padding=1, stride=1, dilation=1):
    """Build a depthwise convolution followed by a 1x1 pointwise convolution.

    Both convolutions are bias-free and each is followed by batch norm and an
    in-place ReLU.

    Args:
        in_channels: Channels of the input; also the group count and the output
            width of the depthwise convolution.
        out_channels: Channels produced by the 1x1 pointwise convolution.
        kernel_size: Kernel size of the depthwise convolution (default 3).
        padding: Padding of the depthwise convolution (default 1).
        stride: Stride of the depthwise convolution (default 1).
        dilation: Dilation of the depthwise convolution (default 1).

    Returns:
        An ``nn.Sequential`` of six layers: depthwise conv, BN, ReLU,
        pointwise conv, BN, ReLU.
    """
    # groups == in_channels gives every input channel its own spatial filter.
    depthwise = nn.Conv2d(
        in_channels,
        in_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=in_channels,
        bias=False,
    )
    depthwise_norm = nn.BatchNorm2d(in_channels)
    depthwise_act = nn.ReLU(inplace=True)

    # The 1x1 convolution mixes channels and sets the output width.
    pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
    pointwise_norm = nn.BatchNorm2d(out_channels)
    pointwise_act = nn.ReLU(inplace=True)

    return nn.Sequential(
        depthwise, depthwise_norm, depthwise_act,
        pointwise, pointwise_norm, pointwise_act,
    )

In [4]:
def conv_dw_no_bn(in_channels, out_channels, kernel_size=3, padding=1, stride=1, dilation=1):
    """Build a depthwise + 1x1 pointwise convolution pair without batch norm.

    Both convolutions are bias-free and each is followed by an in-place ELU.

    Args:
        in_channels: Channels of the input; also the group count and the output
            width of the depthwise convolution.
        out_channels: Channels produced by the 1x1 pointwise convolution.
        kernel_size: Kernel size of the depthwise convolution (default 3).
        padding: Padding of the depthwise convolution (default 1).
        stride: Stride of the depthwise convolution (default 1).
        dilation: Dilation of the depthwise convolution (default 1).

    Returns:
        An ``nn.Sequential`` of four layers: depthwise conv, ELU, pointwise conv, ELU.
    """
    # One spatial filter per input channel (groups == in_channels).
    depthwise = nn.Conv2d(
        in_channels,
        in_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=in_channels,
        bias=False,
    )
    depthwise_act = nn.ELU(inplace=True)

    # Channel-mixing 1x1 convolution.
    pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False)
    pointwise_act = nn.ELU(inplace=True)

    return nn.Sequential(depthwise, depthwise_act, pointwise, pointwise_act)

In [7]:
class Cpm(nn.Module):
    """Project features to ``out_channels`` and refine them with a residual trunk.

    Layers:
        align: 1x1 ``conv`` (no batch norm) mapping ``in_channels`` to ``out_channels``.
        trunk: three stacked ``conv_dw_no_bn`` blocks at ``out_channels`` width.
        conv:  ``conv`` with its default kernel/padding and no batch norm,
               applied to the sum of the aligned features and the trunk output.

    NOTE(review): the name presumably refers to "Convolutional Pose Machines";
    confirm against the upstream repository.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        width = out_channels
        self.align = conv(in_channels, width, kernel_size=1, padding=0, bn=False)
        self.trunk = nn.Sequential(*[conv_dw_no_bn(width, width) for _ in range(3)])
        self.conv = conv(width, width, bn=False)

    def forward(self, x):
        """Return the refined feature map for input ``x``."""
        aligned = self.align(x)
        refined = self.trunk(aligned)
        # Residual connection around the trunk, then the output convolution.
        return self.conv(aligned + refined)

In [6]:
class InitialStage(nn.Module):
    """First prediction stage: a shared trunk feeding two output heads.

    Args:
        num_channels: Channel count of the input features and of the trunk.
        num_heatmaps: Output channels of the ``heatmaps`` head.
        num_pafs: Output channels of the ``pafs`` head
            (presumably part affinity fields; confirm against upstream).
    """

    def __init__(self, num_channels, num_heatmaps, num_pafs):
        super().__init__()
        # Three identical convolutions (default kernel/padding, no batch norm).
        self.trunk = nn.Sequential(
            *[conv(num_channels, num_channels, bn=False) for _ in range(3)]
        )

        def make_head(num_outputs):
            # 1x1 conv to 512 channels with ReLU, then a 1x1 conv with no
            # activation that produces the output maps.
            return nn.Sequential(
                conv(num_channels, 512, kernel_size=1, padding=0, bn=False),
                conv(512, num_outputs, kernel_size=1, padding=0, bn=False, relu=False),
            )

        self.heatmaps = make_head(num_heatmaps)
        self.pafs = make_head(num_pafs)

    def forward(self, x):
        """Return ``[heatmaps, pafs]`` computed from the shared trunk features."""
        shared = self.trunk(x)
        return [self.heatmaps(shared), self.pafs(shared)]

In [11]:
class RefinementStageBlock(nn.Module):
    """Residual block: a 1x1 projection plus a two-convolution trunk.

    ``initial`` maps ``in_channels`` to ``out_channels`` with a 1x1 ``conv``
    (no batch norm). ``trunk`` applies a default ``conv`` followed by a
    ``conv`` with dilation 2 and padding 2. The block output is the sum of
    the projected features and the trunk output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.initial = conv(in_channels, out_channels, kernel_size=1, padding=0, bn=False)
        plain = conv(out_channels, out_channels)
        dilated = conv(out_channels, out_channels, dilation=2, padding=2)
        self.trunk = nn.Sequential(plain, dilated)

    def forward(self, x):
        """Return ``initial(x) + trunk(initial(x))``."""
        projected = self.initial(x)
        return projected + self.trunk(projected)

In [14]:
class RefinementStage(nn.Module):
    """Refinement stage: five residual blocks feeding two 1x1-conv heads.

    Args:
        in_channels: Channel count of the stage input.
        out_channels: Width of the trunk and of each head's hidden layer.
        num_heatmaps: Output channels of the ``heatmaps`` head.
        num_pafs: Output channels of the ``pafs`` head
            (presumably part affinity fields; confirm against upstream).
    """

    def __init__(self, in_channels, out_channels, num_heatmaps, num_pafs):
        super().__init__()
        # The first block changes the channel count; the other four keep it.
        blocks = [RefinementStageBlock(in_channels, out_channels)]
        blocks += [RefinementStageBlock(out_channels, out_channels) for _ in range(4)]
        self.trunk = nn.Sequential(*blocks)

        def make_head(num_outputs):
            # 1x1 conv with ReLU, then a 1x1 conv with no activation.
            return nn.Sequential(
                conv(out_channels, out_channels, kernel_size=1, padding=0, bn=False),
                conv(out_channels, num_outputs, kernel_size=1, padding=0, bn=False, relu=False),
            )

        self.heatmaps = make_head(num_heatmaps)
        self.pafs = make_head(num_pafs)

    def forward(self, x):
        """Return ``[heatmaps, pafs]`` computed from the shared trunk features."""
        shared = self.trunk(x)
        return [self.heatmaps(shared), self.pafs(shared)]

In [15]:
class PoseEstimationWithMobileNet(nn.Module):
    """Backbone + Cpm + initial stage + optional refinement stages.

    Args:
        num_refinement_stages: Number of ``RefinementStage`` modules appended
            after the initial stage (default 1).
        num_channels: Feature width produced by ``Cpm`` and used by the stages
            (default 128).
        num_heatmaps: Channels of every heatmap output (default 19).
        num_pafs: Channels of every paf output (default 38).

    ``forward`` returns a flat list
    ``[heatmaps_0, pafs_0, heatmaps_1, pafs_1, ...]`` with one pair per stage,
    the initial stage first.
    """

    def __init__(self, num_refinement_stages=1, num_channels=128, num_heatmaps=19, num_pafs=38):
        super().__init__()

        # (in_channels, out_channels, extra keyword arguments) for each conv_dw
        # block that follows the stem convolution.
        depthwise_specs = [
            (32, 64, {}),
            (64, 128, {"stride": 2}),
            (128, 128, {}),
            (128, 256, {"stride": 2}),
            (256, 256, {}),
            (256, 512, {}),  # conv4_2
            (512, 512, {"dilation": 2, "padding": 2}),
            (512, 512, {}),
            (512, 512, {}),
            (512, 512, {}),
            (512, 512, {}),  # conv5_5
        ]
        backbone = [conv(3, 32, stride=2, bias=False)]
        backbone.extend(conv_dw(c_in, c_out, **extra) for c_in, c_out, extra in depthwise_specs)
        self.model = nn.Sequential(*backbone)

        self.cpm = Cpm(512, num_channels)
        self.initial_stage = InitialStage(num_channels, num_heatmaps, num_pafs)

        # Each refinement stage consumes the Cpm features concatenated with the
        # previous stage's heatmaps and pafs, hence the summed input width.
        refinement_in_channels = num_channels + num_heatmaps + num_pafs
        self.refinement_stages = nn.ModuleList()
        for _ in range(num_refinement_stages):
            stage = RefinementStage(refinement_in_channels, num_channels, num_heatmaps, num_pafs)
            self.refinement_stages.append(stage)

    def forward(self, x):
        """Run all stages and return their outputs as one flat list."""
        features = self.cpm(self.model(x))

        outputs = self.initial_stage(features)
        for stage in self.refinement_stages:
            # The two most recent entries are the previous stage's heatmaps and pafs.
            previous_heatmaps, previous_pafs = outputs[-2], outputs[-1]
            stage_input = torch.cat([features, previous_heatmaps, previous_pafs], dim=1)
            outputs.extend(stage(stage_input))

        return outputs

In [16]:
# Instantiate the network with the default arguments of PoseEstimationWithMobileNet.__init__.
model = PoseEstimationWithMobileNet()

In [17]:
from torchsummary import summary

# The captured output showed the layer table twice: once printed by summary()
# and once as the cell's displayed value. The trailing semicolon suppresses the
# second copy.
summary(model);

Layer (type:depth-idx)                             Param #
├─Sequential: 1-1                                  --
|    └─Sequential: 2-1                             --
|    |    └─Conv2d: 3-1                            864
|    |    └─BatchNorm2d: 3-2                       64
|    |    └─ReLU: 3-3                              --
|    └─Sequential: 2-2                             --
|    |    └─Conv2d: 3-4                            288
|    |    └─BatchNorm2d: 3-5                       64
|    |    └─ReLU: 3-6                              --
|    |    └─Conv2d: 3-7                            2,048
|    |    └─BatchNorm2d: 3-8                       128
|    |    └─ReLU: 3-9                              --
|    └─Sequential: 2-3                             --
|    |    └─Conv2d: 3-10                           576
|    |    └─BatchNorm2d: 3-11                      128
|    |    └─ReLU: 3-12                             --
|    |    └─Conv2d: 3-13                           8,192
|    |    └─

Layer (type:depth-idx)                             Param #
├─Sequential: 1-1                                  --
|    └─Sequential: 2-1                             --
|    |    └─Conv2d: 3-1                            864
|    |    └─BatchNorm2d: 3-2                       64
|    |    └─ReLU: 3-3                              --
|    └─Sequential: 2-2                             --
|    |    └─Conv2d: 3-4                            288
|    |    └─BatchNorm2d: 3-5                       64
|    |    └─ReLU: 3-6                              --
|    |    └─Conv2d: 3-7                            2,048
|    |    └─BatchNorm2d: 3-8                       128
|    |    └─ReLU: 3-9                              --
|    └─Sequential: 2-3                             --
|    |    └─Conv2d: 3-10                           576
|    |    └─BatchNorm2d: 3-11                      128
|    |    └─ReLU: 3-12                             --
|    |    └─Conv2d: 3-13                           8,192
|    |    └─

## 훈련

In [None]:
# Build the dataset from the image paths and the preprocessing applied to every image.
# NOTE(review): CustomData, image_list and transform are not defined in the visible
# part of this notebook; they must be created in an earlier cell.
# DataLoader comes from torch.utils.data and is imported in the notebook's import cell.
dataset = CustomData(image_list, transform)
dataloader = DataLoader(
    dataset=dataset,
    batch_size=25,
    shuffle=True,
    drop_last=False,
)

# Pull a single batch as a smoke test of the input pipeline.
dataiter = iter(dataloader)
batch = next(dataiter)

# Unpacking assumes each batch is an (images, labels) pair; confirm against CustomData.
images, labels = batch

In [None]:
# Settings used for training; the training cell passes them to wandb.init as run metadata.
LEARNING_RATE = 0.01
ARCHITECTURE = 'CNN'
DATASET = 'YOLO'
# Written as a literal so this cell runs on a fresh kernel. It previously read
# `epochs`, which is only assigned in a later cell and raised NameError here.
EPOCHS = 1000

In [None]:
# 5. Model training
# NOTE(review): `wandb`, `criterion`, `optimizer`, `X_train` and `y_train` are not
# defined or imported in the visible part of this notebook; they must exist
# before this cell runs.
# NOTE(review): if `model` is the PoseEstimationWithMobileNet instance created
# above, its forward() returns a list of tensors, so `criterion` must accept
# that list. Confirm.
epochs = 1000  # number of training epochs
losses = []  # per-epoch loss history

# Register the run and its metadata with Weights & Biases.
wandb.init(
    # Set the project where this run will be logged
    project="basic-intro",
    # Use the model's class name: formatting the module itself would embed its
    # entire multi-line layer listing in the run name.
    name=f"experiment_{type(model).__name__}",
    # Track hyperparameters and run metadata
    config={
        "learning_rate": LEARNING_RATE,
        "architecture": ARCHITECTURE,
        "dataset": DATASET,
        # Log the value that actually drives the loop below.
        "epochs": epochs,
    },
)

# Make sure layers such as BatchNorm are in training mode.
model.train()

try:
    for epoch in range(epochs):
        # 1) Forward pass
        y_pred = model(X_train)

        # 2) Compute the loss
        loss = criterion(y_pred, y_train)

        # 3) Reset accumulated gradients
        optimizer.zero_grad()

        # 4) Backward pass
        loss.backward()

        # 5) Update the weights
        optimizer.step()

        # Convert once to a Python float so neither the history list nor the
        # logger holds on to a tensor attached to the autograd graph.
        loss_value = loss.item()
        losses.append(loss_value)
        wandb.log({"loss": loss_value})

        if (epoch + 1) % 100 == 0:
            print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss_value:.4f}")
finally:
    # Mark the run as finished even if training stops with an exception.
    wandb.finish()

##  검증

- 위에서 정의한 모델은 클래스 확률이 아니라 heatmap과 PAF를 출력하므로 classification용 metric인 accuracy를 그대로 쓸 수 없음. 검증 데이터에 대한 loss를 기본 metric으로 사용함


In [None]:
# Rebuild the architecture defined in this notebook and restore the saved weights.
# The original cell instantiated `MyModel`, which is not defined anywhere visible.
# NOTE(review): 'model_weights.pth' is not written anywhere in the visible
# notebook; confirm that training saves it with torch.save(model.state_dict(), ...).
model = PoseEstimationWithMobileNet()
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
model.load_state_dict(torch.load('model_weights.pth', map_location='cpu'))
model.eval()  # switch to inference mode

# Validation only measures the loss. The previous version of this cell repeated
# the training loop (backward pass, optimizer step, wandb logging) after
# wandb.finish() had already closed the run.
with torch.no_grad():
    # TODO(review): evaluate on a held-out validation split. X_train / y_train
    # are kept only because they are the tensors the original cell used.
    y_pred = model(X_train)
    val_loss = criterion(y_pred, y_train).item()

print(f"Validation loss: {val_loss:.4f}")