In [9]:
%load_ext lab_black

In [1]:
import argparse
import json
import math
import multiprocessing
import random
import sys
import os
from copy import deepcopy
from dataclasses import dataclass
from pathlib import Path
from pprint import pformat
from typing import List, Tuple

import albumentations as A
import cv2
import imageio
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torch.optim as optim
import yaml
from albumentations.pytorch import ToTensorV2
from easydict import EasyDict
from PIL import Image
from sklearn.model_selection import KFold, StratifiedKFold
from torch import nn, optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, Dataset, Subset
from tqdm import tqdm

In [2]:
def seed_everything(seed, deterministic=False):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = deterministic
        torch.backends.cudnn.benchmark = not deterministic

In [3]:
error_list = [317, 869, 873, 877, 911, 1559, 1560, 1562, 1566, 1575]
error_list += [1577, 1578, 1582, 1606, 1607, 1622, 1623, 1624, 1625]
error_list += [1629, 3968, 4115, 4116, 4117, 4118, 4119, 4120, 4121]
error_list += [4122, 4123, 4124, 4125, 4126, 4127, 4128, 4129, 4130]
error_list += [4131, 4132, 4133, 4134, 4135, 4136, 4137, 4138, 4139]
error_list += [4140, 4141, 4142, 4143, 4144, 4145, 4146, 4147, 4148]
error_list += [4149, 4150, 4151, 4152, 4153, 4154, 4155, 4156, 4157]
error_list += [4158, 4159, 4160, 4161, 4162, 4163, 4164, 4165, 4166]
error_list += [4167, 4168, 4169, 4170, 4171, 4172, 4173, 4174, 4175]
error_list += [4176, 4177, 4178, 4179, 4180, 4181, 4182, 4183, 4184]
error_list += [4185, 4186, 4187, 4188, 4189, 4190, 4191, 4192, 4193, 4194]

# 20210323 추가
error_list += [1516, 1597, 2221, 2808, 2821, 3081, 3084, 3085, 3090, 3093, 3283, 3284]

## HRNet

In [4]:
import torch
import torch.nn as nn


BN_MOMENTUM = 0.1


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class HighResolutionModule(nn.Module):
    def __init__(self, num_branches, blocks, num_blocks, num_inchannels, num_channels, multi_scale_output=True):
        super(HighResolutionModule, self).__init__()
        self._check_branches(num_branches, blocks, num_blocks, num_inchannels, num_channels)

        self.num_inchannels = num_inchannels
        self.num_branches = num_branches

        self.multi_scale_output = multi_scale_output

        self.branches = self._make_branches(num_branches, blocks, num_blocks, num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(True)

    def _check_branches(self, num_branches, blocks, num_blocks, num_inchannels, num_channels):
        if num_branches != len(num_blocks):
            error_msg = "NUM_BRANCHES({}) <> NUM_BLOCKS({})".format(num_branches, len(num_blocks))
            raise ValueError(error_msg)

        if num_branches != len(num_channels):
            error_msg = "NUM_BRANCHES({}) <> NUM_CHANNELS({})".format(num_branches, len(num_channels))
            raise ValueError(error_msg)

        if num_branches != len(num_inchannels):
            error_msg = "NUM_BRANCHES({}) <> NUM_INCHANNELS({})".format(num_branches, len(num_inchannels))
            raise ValueError(error_msg)

    def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1):
        downsample = None
        if stride != 1 or self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.num_inchannels[branch_index],
                    num_channels[branch_index] * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(num_channels[branch_index] * block.expansion, momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index], stride, downsample))
        self.num_inchannels[branch_index] = num_channels[branch_index] * block.expansion
        for i in range(1, num_blocks[branch_index]):
            layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index]))

        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        branches = []

        for i in range(num_branches):
            branches.append(self._make_one_branch(i, block, num_blocks, num_channels))

        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        num_inchannels = self.num_inchannels
        fuse_layers = []
        for i in range(num_branches if self.multi_scale_output else 1):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    fuse_layer.append(
                        nn.Sequential(
                            nn.Conv2d(num_inchannels[j], num_inchannels[i], 1, 1, 0, bias=False),
                            nn.BatchNorm2d(num_inchannels[i]),
                            nn.Upsample(scale_factor=2 ** (j - i), mode="nearest"),
                        )
                    )
                elif j == i:
                    fuse_layer.append(None)
                else:
                    conv3x3s = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            num_outchannels_conv3x3 = num_inchannels[i]
                            conv3x3s.append(
                                nn.Sequential(
                                    nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False),
                                    nn.BatchNorm2d(num_outchannels_conv3x3),
                                )
                            )
                        else:
                            num_outchannels_conv3x3 = num_inchannels[j]
                            conv3x3s.append(
                                nn.Sequential(
                                    nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False),
                                    nn.BatchNorm2d(num_outchannels_conv3x3),
                                    nn.ReLU(True),
                                )
                            )
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        return self.num_inchannels

    def forward(self, x):
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        for i in range(self.num_branches):
            x[i] = self.branches[i](x[i])

        x_fuse = []

        for i in range(len(self.fuse_layers)):
            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
            for j in range(1, self.num_branches):
                if i == j:
                    y = y + x[j]
                else:
                    y = y + self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))

        return x_fuse


class PoseHighResolutionNet(nn.Module):
    def __init__(self, width=32, num_keypoints=17):
        assert width in [32, 48], f"PoseHighResolutionNet width must be in [32, 48] not {width}"
        self.width = width

        block = BasicBlock
        num_modules = [1, 4, 3]
        num_branches = [2, 3, 4]
        num_inchannels = [
            [2 ** i * width * block.expansion for i in range(2)],
            [2 ** i * width * block.expansion for i in range(3)],
            [2 ** i * width * block.expansion for i in range(4)],
        ]
        self.pre_stage_channels = [256]

        self.inplanes = 64
        super(PoseHighResolutionNet, self).__init__()

        # stem net
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(Bottleneck, 64, 4)

        self.transition1 = self._make_transition_layer(num_inchannels[0])
        self.stage2 = self._make_stage(block, num_modules[0], num_branches[0], num_inchannels[0])
        self.transition2 = self._make_transition_layer(num_inchannels[1])
        self.stage3 = self._make_stage(block, num_modules[1], num_branches[1], num_inchannels[1])
        self.transition3 = self._make_transition_layer(num_inchannels[2])
        self.stage4 = self._make_stage(block, num_modules[2], num_branches[2], num_inchannels[2], multi_scale_output=False)

        self.final_layer = nn.Conv2d(self.pre_stage_channels[0], num_keypoints, 1)

        self.init_weights()
        self.num_branches = num_branches

        self.finetune_step = 3

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.layer1(x)

        x_list = []
        for i in range(self.num_branches[0]):
            if self.transition1[i] is not None:
                x_list.append(self.transition1[i](x))
            else:
                x_list.append(x)
        y_list = self.stage2(x_list)
        x = y_list[-1]

        x_list = []
        for i in range(self.num_branches[1]):
            if self.transition2[i] is not None:
                x_list.append(self.transition2[i](x))
            else:
                x_list.append(y_list[i])
        y_list = self.stage3(x_list)
        x = y_list[-1]

        x_list = []
        for i in range(self.num_branches[2]):
            if self.transition3[i] is not None:
                x_list.append(self.transition3[i](x))
            else:
                x_list.append(y_list[i])
        y_list = self.stage4(x_list)
        x = y_list[0]

        x = self.final_layer(x)

        return x

    def _make_transition_layer(self, num_channels_cur_layer):
        num_channels_pre_layer = self.pre_stage_channels
        num_branches_pre = len(num_channels_pre_layer)
        num_branches_cur = len(num_channels_cur_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(
                        nn.Sequential(
                            nn.Conv2d(num_channels_pre_layer[i], num_channels_cur_layer[i], 3, 1, 1, bias=False),
                            nn.BatchNorm2d(num_channels_cur_layer[i]),
                            nn.ReLU(inplace=True),
                        )
                    )
                else:
                    transition_layers.append(None)
            else:
                conv3x3s = []
                for j in range(i + 1 - num_branches_pre):
                    inchannels = num_channels_pre_layer[-1]
                    outchannels = num_channels_cur_layer[i] if j == i - num_branches_pre else inchannels
                    conv3x3s.append(
                        nn.Sequential(
                            nn.Conv2d(inchannels, outchannels, 3, 2, 1, bias=False),
                            nn.BatchNorm2d(outchannels),
                            nn.ReLU(inplace=True),
                        )
                    )
                transition_layers.append(nn.Sequential(*conv3x3s))

        return nn.ModuleList(transition_layers)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def _make_stage(self, block, num_module, num_branch, num_inchannels, multi_scale_output=True):
        modules = []
        for i in range(num_module):
            # multi_scale_output is only used last module
            if not multi_scale_output and i == num_module - 1:
                reset_multi_scale_output = False
            else:
                reset_multi_scale_output = True

            modules.append(
                HighResolutionModule(
                    num_branch,
                    block,
                    [4 for _ in range(num_branch)],
                    num_inchannels,
                    [2 ** i * self.width for i in range(num_branch)],
                    reset_multi_scale_output,
                )
            )
            num_inchannels = modules[-1].get_num_inchannels()

        self.pre_stage_channels = num_inchannels
        return nn.Sequential(*modules)

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.normal_(m.weight, std=0.001)
                for name, _ in m.named_parameters():
                    if name in ["bias"]:
                        nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.ConvTranspose2d):
                nn.init.normal_(m.weight, std=0.001)
                for name, _ in m.named_parameters():
                    if name in ["bias"]:
                        nn.init.constant_(m.bias, 0)

    def freeze_step1(self):
        for p in self.parameters():
            p.requires_grad_(False)
        self.final_layer.requires_grad_(True)
        self.finetune_step = 1

    def freeze_step2(self):
        for p in self.parameters():
            p.requires_grad_(True)
        self.final_layer.requires_grad_(False)
        self.finetune_step = 2

    def freeze_step3(self):
        for p in self.parameters():
            p.requires_grad_(True)
        self.finetune_step = 3

## 데이터셋

In [5]:
class HorizontalFlipEx(A.HorizontalFlip):
    swap_columns = [(1, 2), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12), (13, 14), (15, 16), (18, 19), (22, 23)]

    def apply_to_keypoints(self, keypoints, **params):
        keypoints = super().apply_to_keypoints(keypoints, **params)

        # left/right 키포인트들은 서로 swap해주기
        for a, b in self.swap_columns:
            temp1 = deepcopy(keypoints[a])
            temp2 = deepcopy(keypoints[b])
            keypoints[a] = temp2
            keypoints[b] = temp1

        return keypoints


class VerticalFlipEx(A.VerticalFlip):
    swap_columns = [(1, 2), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12), (13, 14), (15, 16), (18, 19), (22, 23)]

    def apply_to_keypoints(self, keypoints, **params):
        keypoints = super().apply_to_keypoints(keypoints, **params)

        # left/right 키포인트들은 서로 swap해주기
        for a, b in self.swap_columns:
            temp1 = deepcopy(keypoints[a])
            temp2 = deepcopy(keypoints[b])
            keypoints[a] = temp2
            keypoints[b] = temp1

        return keypoints

In [6]:
def reratio_box(roi: Tuple[float, float, float, float], ratio_limit=2.0):
    # 극단적인 비율의 이미지는 조정해줄 필요가 있음
    w, h = roi[2] - roi[0], roi[3] - roi[1]
    dhl, dhr, dwl, dwr = 0, 0, 0, 0
    if w / h > ratio_limit:
        # w/(h+x) = l, x = w/l - h
        dh = w / ratio_limit - h
        dhl, dhr = math.floor(dh / 2), math.ceil(dh / 2)
    elif h / w > ratio_limit:
        # h/(w+x) = l, x = h/l - w
        dw = h / ratio_limit - w
        dwl, dwr = math.floor(dw / 2), math.ceil(dw / 2)
    return dhl, dhr, dwl, dwr

In [7]:
def keypoint2box(keypoint, padding=0):
    return np.array(
        [
            keypoint[:, 0].min() - padding,
            keypoint[:, 1].min() - padding,
            keypoint[:, 0].max() + padding,
            keypoint[:, 1].max() + padding,
        ]
    )

In [8]:
@torch.no_grad()
def keypoints2heatmaps(
    k: torch.Tensor,
    h=768 // 4,
    w=576 // 4,
    smooth=False,
    smooth_size=3,
    smooth_values=[0.1, 0.4, 0.8],
):
    k = k.type(torch.int64)
    c = torch.zeros(k.size(0), h, w, dtype=torch.float32)
    for i, (x, y) in enumerate(k):
        if smooth:
            for d, s in zip(range(smooth_size, 0, -1), smooth_values):
                c[i, max(y - d, 0) : min(y + d, h), max(x - d, 0) : min(x + d, w)] = s
        c[i, y, x] = 1.0
    return c

In [10]:
class KeypointDataset(Dataset):
    def __init__(self, config, files, keypoints, augmentation):
        super().__init__()
        self.C = config
        self.files = files
        self.keypoints = keypoints

        T = []
        # T.append(A.Crop(0, 28, 1920, 1080 - 28))  # 1920x1080 --> 1920x1024
        # T.append(A.Resize(512, 1024))
        if augmentation:
            # 중간에 기구로 잘리는 경우를 가장
            T_ = []
            T_.append(A.Cutout(num_holes=16, max_h_size=100, max_w_size=100, fill_value=0, p=1))
            T_.append(A.Cutout(num_holes=16, max_h_size=100, max_w_size=100, fill_value=255, p=1))
            T_.append(A.Cutout(num_holes=16, max_h_size=100, max_w_size=100, fill_value=128, p=1))
            T_.append(A.Cutout(num_holes=16, max_h_size=100, max_w_size=100, fill_value=192, p=1))
            T_.append(A.Cutout(num_holes=16, max_h_size=100, max_w_size=100, fill_value=64, p=1))
            T_.append(A.Cutout(num_holes=5, max_h_size=1920, max_w_size=50, fill_value=0, p=1))
            T_.append(A.Cutout(num_holes=5, max_h_size=1920, max_w_size=50, fill_value=255, p=1))
            T_.append(A.Cutout(num_holes=5, max_h_size=1920, max_w_size=50, fill_value=128, p=1))
            T_.append(A.Cutout(num_holes=5, max_h_size=1920, max_w_size=50, fill_value=192, p=1))
            T_.append(A.Cutout(num_holes=5, max_h_size=1920, max_w_size=50, fill_value=64, p=1))
            T_.append(A.Cutout(num_holes=5, max_h_size=30, max_w_size=1080, fill_value=0, p=1))
            T_.append(A.Cutout(num_holes=5, max_h_size=30, max_w_size=1080, fill_value=255, p=1))
            T_.append(A.Cutout(num_holes=5, max_h_size=30, max_w_size=1080, fill_value=128, p=1))
            T_.append(A.Cutout(num_holes=5, max_h_size=30, max_w_size=1080, fill_value=192, p=1))
            T_.append(A.Cutout(num_holes=5, max_h_size=30, max_w_size=1080, fill_value=64, p=1))
            T.append(A.OneOf(T_))

            # geomatric augmentations
            # T.append(A.ShiftScaleRotate(border_mode=cv2.BORDER_CONSTANT))
            T.append(A.ShiftScaleRotate())
            T.append(HorizontalFlipEx())
            T.append(VerticalFlipEx())
            T.append(A.RandomRotate90())

            T_ = []
            T_.append(A.RandomBrightnessContrast(p=1))
            T_.append(A.RandomGamma(p=1))
            T_.append(A.RandomBrightness(p=1))
            T_.append(A.RandomContrast(p=1))
            T.append(A.OneOf(T_))

            T_ = []
            T_.append(A.MotionBlur(p=1))
            T_.append(A.GaussNoise(p=1))
            T.append(A.OneOf(T_))
        if self.C.dataset.normalize:
            if self.C.dataset.mean is not None and self.C.dataset.std is not None:
                T.append(A.Normalize(self.C.dataset.mean, self.C.dataset.std))
            else:
                T.append(A.Normalize())
        else:
            T.append(A.Normalize((0, 0, 0), (1, 1, 1)))
        T.append(ToTensorV2())

        self.transform = A.Compose(
            transforms=T,
            bbox_params=A.BboxParams(format="pascal_voc", label_fields=["labels"]),
            keypoint_params=A.KeypointParams(format="xy", remove_invisible=False),
        )

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        # rotation/scale 등으로 keypoint가 화면 밖으로 나가면 exception 발생.
        # 그럼 데이터 다시 만들어줌
        while True:
            try:
                file = str(self.files[idx])
                image = imageio.imread(file)

                keypoint = self.keypoints[idx]
                box = keypoint2box(keypoint, self.C.dataset.padding)
                box = np.expand_dims(box, 0)
                labels = np.array([0], dtype=np.int64)
                a = self.transform(image=image, labels=labels, bboxes=box, keypoints=keypoint)

                image = a["image"]
                # if not self.C.dataset.normalize:
                #     image = image.type(torch.float) / 255.0
                bbox = list(map(int, a["bboxes"][0]))
                keypoint = torch.tensor(a["keypoints"], dtype=torch.float32)
                image, keypoint, heatmap, ratio, offset = self._resize_image(image, bbox, keypoint)

                return file, image, heatmap, ratio, offset
            except IndexError:
                pass

    def _resize_image(self, image, bbox, keypoint):
        """
        bbox크기만큼 이미지를 자르고, keypoint에 offset/ratio를 준다.
        """
        dhl, dhr, dwl, dwr = reratio_box(bbox, ratio_limit=self.C.dataset.ratio_limit)
        h, w = image.shape[1:3]
        bbox[0] = max(bbox[0] - dwl, 0)
        bbox[1] = max(bbox[1] - dhl, 0)
        bbox[2] = min(bbox[2] + dwr, w)
        bbox[3] = min(bbox[3] + dhr, h)

        image = image[:, bbox[1] : bbox[3], bbox[0] : bbox[2]]
        CD = self.C.dataset

        # HRNet의 입력 이미지 크기로 resize
        ratio = (CD.input_width / image.shape[2], CD.input_height / image.shape[1])
        ratio = torch.tensor(ratio, dtype=torch.float32)
        image = F.interpolate(image.unsqueeze(0), (CD.input_height, CD.input_width))[0]

        # bbox만큼 빼줌
        keypoint[:, 0] -= bbox[0]
        keypoint[:, 1] -= bbox[1]

        # 이미지를 resize해준 비율만큼 곱해줌
        keypoint[:, 0] *= ratio[0]
        keypoint[:, 1] *= ratio[1]
        # TODO: 잘못된 keypoint가 있으면 고쳐줌

        # HRNet은 1/4로 resize된 출력이 나오므로 4로 나눠줌
        keypoint /= 4

        # keypoint를 heatmap으로 변환
        heatmap = keypoints2heatmaps(
            keypoint,
            CD.input_height // 4,
            CD.input_width // 4,
            smooth=self.C.dataset.smooth_heatmap.do,
            smooth_size=self.C.dataset.smooth_heatmap.size,
            smooth_values=self.C.dataset.smooth_heatmap.values,
        )

        offset = torch.tensor([bbox[0], bbox[1]], dtype=torch.float)

        return image, keypoint, heatmap, ratio, offset

In [11]:
def get_pose_datasets(C, fold):
    total_imgs = np.array(sorted(list(C.dataset.train_dir.glob("*.jpg"))))
    df = pd.read_csv(C.dataset.target_file)
    total_keypoints = df.to_numpy()[:, 1:].astype(np.float32)
    total_keypoints = np.stack([total_keypoints[:, 0::2], total_keypoints[:, 1::2]], axis=2)

    # 오류가 있는 데이터는 학습에서 제외
    total_imgs_, total_keypoints_ = [], []
    for i in range(len(total_imgs)):
        if i not in error_list:
            total_imgs_.append(total_imgs[i])
            total_keypoints_.append(total_keypoints[i])
    total_imgs = np.array(total_imgs_)
    total_keypoints = np.array(total_keypoints_)

    # KFold
    if C.dataset.group_kfold:
        skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=C.seed)
        # 파일 이름 앞 17자리를 group으로 이미지를 분류 (파일이 너무 잘 섞여도 안됨)
        groups = []
        last_group = 0
        last_stem = total_imgs[0].name[:17]
        for f in total_imgs:
            stem = f.name[:17]
            if stem == last_stem:
                groups.append(last_group)
            else:
                last_group += 1
                last_stem = stem
                groups.append(last_group)
        indices = list(skf.split(total_imgs, groups))
    else:
        kf = KFold(n_splits=5, shuffle=True, random_state=C.seed)
        indices = list(kf.split(total_imgs))
    train_idx, valid_idx = indices[fold - 1]

    # 데이터셋 생성
    ds_train = KeypointDataset(
        C,
        total_imgs[train_idx],
        total_keypoints[train_idx],
        augmentation=True,
    )
    ds_valid = KeypointDataset(
        C,
        total_imgs[valid_idx],
        total_keypoints[valid_idx],
        augmentation=False,
    )
    dl_train = DataLoader(
        ds_train,
        batch_size=C.dataset.batch_size,
        num_workers=C.dataset.num_cpus,
        shuffle=True,
        pin_memory=True,
    )
    dl_valid = DataLoader(
        ds_valid,
        batch_size=C.dataset.batch_size,
        num_workers=C.dataset.num_cpus,
        shuffle=False,
        pin_memory=True,
    )

    return dl_train, dl_valid

## 학습 설정

In [12]:
class AverageMeter(object):
    """
    AverageMeter, referenced to https://dacon.io/competitions/official/235626/codeshare/1684
    """

    def __init__(self):
        self.sum = 0
        self.cnt = 0
        self.avg = 0

    def update(self, val, n=1):
        if n > 0:
            self.sum += val * n
            self.cnt += n
            self.avg = self.sum / self.cnt

    def get(self):
        return self.avg

    def __call__(self):
        return self.avg

In [13]:
class KeypointLoss(nn.Module):
    def forward(self, x, y):
        x = x.flatten(2).flatten(0, 1)
        y = y.flatten(2).flatten(0, 1).argmax(1)
        loss1 = F.cross_entropy(x, y)
        return loss1


class KeypointRMSE(nn.Module):
    @torch.no_grad()
    def forward(self, pred_heatmaps: torch.Tensor, real_heatmaps: torch.Tensor, ratios: torch.Tensor):
        W = pred_heatmaps.size(3)
        pred_positions = pred_heatmaps.flatten(2).argmax(2)
        real_positions = real_heatmaps.flatten(2).argmax(2)
        pred_positions = torch.stack((pred_positions // W, pred_positions % W), 2).type(torch.float32)
        real_positions = torch.stack((real_positions // W, real_positions % W), 2).type(torch.float32)
        # print(pred_positions.shape, real_positions.shape, ratios.shape)
        pred_positions *= 4 / ratios.unsqueeze(1)  # position: (B, 24, 2), ratio: (B, 2)
        real_positions *= 4 / ratios.unsqueeze(1)
        loss = (pred_positions - real_positions).square().mean().sqrt()

        return loss

In [14]:
class TrainOutput:
    def __init__(self):
        self.loss = AverageMeter()
        self.rmse = AverageMeter()

    def freeze(self):
        self.loss = self.loss()
        self.rmse = self.rmse()
        return self

In [22]:
class PoseTrainer:
    _tqdm_ = dict(ncols=100, leave=False, file=sys.stdout)

    def __init__(self, config, fold, checkpoint=None):
        self.C = config
        self.fold = fold

        # Create Network
        if self.C.pose_model == "HRNet-W32":
            width = 32
        elif self.C.pose_model == "HRNet-W48":
            width = 48
        else:
            raise NotImplementedError()

        self.pose_model = PoseHighResolutionNet(width)
        self.pose_model.load_state_dict(torch.load(f"networks/models/pose_hrnet_w{width}_384x288.pth"))

        final_layer = nn.Conv2d(width, 24, 1)
        with torch.no_grad():
            final_layer.weight[:17] = self.pose_model.final_layer.weight
            final_layer.bias[:17] = self.pose_model.final_layer.bias

            if self.C.model_additional_weight:
                # neck(17)은 left/right sholder(5, 6)과 nose(0)의 평균
                # left/right palm(18, 19)는 left/right wrist(9, 10)을 복사
                # spine2(20)은 left/right sholder(5, 6)과 left/right hip(11, 12)의 중앙
                # spine1(21)은 left/right hip(11, 12)을 각각 1/3 + left/right sholder(5, 6)을 각각 1/6
                # instep(22, 23)은 angle(15, 16)를 복사
                final_layer.weight[17] = self.pose_model.final_layer.weight[[0, 5, 6]].clone().mean(0)
                final_layer.bias[17] = self.pose_model.final_layer.bias[[0, 5, 6]].clone().mean(0)
                final_layer.weight[18] = self.pose_model.final_layer.weight[9].clone()
                final_layer.bias[18] = self.pose_model.final_layer.bias[9].clone()
                final_layer.weight[19] = self.pose_model.final_layer.weight[10].clone()
                final_layer.bias[19] = self.pose_model.final_layer.bias[10].clone()
                final_layer.weight[20] = self.pose_model.final_layer.weight[[5, 6, 11, 12]].clone().mean(0)
                final_layer.bias[20] = self.pose_model.final_layer.bias[[5, 6, 11, 12]].clone().mean(0)
                final_layer.weight[21] = torch.cat(
                    (
                        self.pose_model.final_layer.weight[[11, 12]].clone() * 1 / 3,
                        self.pose_model.final_layer.weight[[5, 6]].clone() * 6 / 1,
                    )
                ).mean(0)
                final_layer.bias[21] = torch.cat(
                    (
                        self.pose_model.final_layer.bias[[11, 12]].clone() * 1 / 3,
                        self.pose_model.final_layer.bias[[5, 6]].clone() * 6 / 1,
                    )
                ).mean(0)
                final_layer.weight[22] = self.pose_model.final_layer.weight[15].clone()
                final_layer.bias[22] = self.pose_model.final_layer.bias[15].clone()
                final_layer.weight[23] = self.pose_model.final_layer.weight[16].clone()
                final_layer.bias[23] = self.pose_model.final_layer.bias[16].clone()

            self.pose_model.final_layer = final_layer
        self.pose_model.cuda()

        # Criterion
        self.criterion = KeypointLoss().cuda()
        self.criterion_rmse = KeypointRMSE().cuda()

        # Optimizer
        self.optimizer = optim.AdamW(self.pose_model.parameters(), lr=self.C.train.lr)

        self.epoch = 1
        self.best_loss = math.inf
        self.best_rmse = math.inf
        self.earlystop_cnt = 0

        # Dataset
        self.dl_train, self.dl_valid = get_pose_datasets(self.C, self.fold)

        # Load Checkpoint
        if checkpoint is not None and Path(checkpoint).exists():
            self.load(checkpoint)

        # Scheduler
        self.scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer, **self.C.train.scheduler.params)

    def save(self, path):
        torch.save(
            {
                "model": self.pose_model.state_dict(),
                "optimizer": self.optimizer.state_dict(),
                "epoch": self.epoch,
                "best_loss": self.best_loss,
                "best_rmse": self.best_rmse,
                "earlystop_cnt": self.earlystop_cnt,
            },
            path,
        )

    def load(self, path):
        print("Load pretrained", path)
        ckpt = torch.load(path)
        self.pose_model.load_state_dict(ckpt["model"])
        self.optimizer.load_state_dict(ckpt["optimizer"])
        self.epoch = ckpt["epoch"] + 1
        self.best_loss = ckpt["best_loss"]
        self.best_rmse = ckpt["best_rmse"]
        self.earlystop_cnt = ckpt["earlystop_cnt"]

    def train_loop(self):
        self.pose_model.train()

        O = TrainOutput()
        with tqdm(total=len(self.dl_train.dataset), desc=f"Train {self.epoch:03d}", **self._tqdm_) as t:
            for files, imgs, target_heatmaps, ratios, offsets in self.dl_train:
                imgs_, target_heatmaps_ = imgs.cuda(non_blocking=True), target_heatmaps.cuda(non_blocking=True)

                # augmentation
                if self.C.train.plus_augment.do:
                    with torch.no_grad():
                        c = self.C.train.plus_augment
                        if c.downsample.do and random.random() <= c.downsample.p:
                            h, w = imgs_.shape[2:]
                            ratios[:, 0] = c.downsample.width / w * ratios[:, 0]
                            ratios[:, 1] = c.downsample.height / h * ratios[:, 1]
                            imgs_ = F.interpolate(imgs_, (c.downsample.height, c.downsample.width))
                            target_heatmaps_ = F.interpolate(
                                target_heatmaps_, (c.downsample.height // 4, c.downsample.width // 4)
                            )

                        if c.rotate.do and random.random() <= c.rotate.p:
                            k = 3 if random.random() < 0.5 else 1
                            ratios[:, 0], ratios[:, 1] = ratios[:, 1], ratios[:, 0]
                            imgs_ = torch.rot90(imgs_, k, dims=(2, 3))
                            target_heatmaps_ = torch.rot90(target_heatmaps_, k, dims=(2, 3))

                pred_heatmaps_ = self.pose_model(imgs_)
                loss = self.criterion(pred_heatmaps_, target_heatmaps_)
                rmse = self.criterion_rmse(pred_heatmaps_, target_heatmaps_, ratios.cuda(non_blocking=True))

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                O.loss.update(loss.item(), len(files))
                O.rmse.update(rmse.item(), len(files))
                t.set_postfix_str(f"loss: {loss.item():.6f}, rmse: {rmse.item():.6f}", refresh=False)
                t.update(len(imgs))

        return O.freeze()

    @torch.no_grad()
    def valid_loop(self):
        self.pose_model.eval()

        O = TrainOutput()
        with tqdm(total=len(self.dl_valid.dataset), desc=f"Valid {self.epoch:03d}", **self._tqdm_) as t:
            for files, imgs, target_heatmaps, ratios, offsets in self.dl_valid:
                imgs_, target_heatmaps_ = imgs.cuda(non_blocking=True), target_heatmaps.cuda(non_blocking=True)
                pred_heatmaps_ = self.pose_model(imgs_)
                loss = self.criterion(pred_heatmaps_, target_heatmaps_)
                rmse = self.criterion_rmse(pred_heatmaps_, target_heatmaps_, ratios.cuda(non_blocking=True))

                O.loss.update(loss.item(), len(files))
                O.rmse.update(rmse.item(), len(files))
                t.set_postfix_str(f"loss: {loss.item():.6f}, rmse: {rmse.item():.6f}", refresh=False)
                t.update(len(imgs))

        return O.freeze()

    @torch.no_grad()
    def callback(self, to: TrainOutput, vo: TrainOutput):
        print(
            f"Epoch: {self.epoch:03d}/{self.C.train.max_epochs},",
            f"loss: {to.loss:.6f};{vo.loss:.6f},",
            f"rmse {to.rmse:.6f};{vo.rmse:.6f}",
        )

        self.scheduler.step(vo.loss)

        if self.best_loss > vo.loss or self.best_rmse > vo.rmse:
            if self.best_loss > vo.loss:
                self.best_loss = vo.loss
            else:
                self.best_rmse = vo.rmse

            self.earlystop_cnt = 0
            self.save(self.C.result_dir / f"fold{self.fold}.pth")
        else:
            self.earlystop_cnt += 1

    def fit(self):
        for self.epoch in range(self.epoch, self.C.train.max_epochs + 1):
            if self.C.train.finetune.do:
                if self.epoch <= self.C.train.finetune.step1_epochs:
                    if self.pose_model.finetune_step != 1:
                        print("Finetune step 1")
                        self.pose_model.freeze_step1()
                elif self.epoch <= self.C.train.finetune.step2_epochs:
                    if self.pose_model.finetune_step != 2:
                        print("Finetune step 2")
                        self.pose_model.freeze_step2()
                else:
                    if self.pose_model.finetune_step != 3:
                        print("Finetune step 3")
                        self.pose_model.freeze_step3()

            to = self.train_loop()
            vo = self.valid_loop()
            self.callback(to, vo)

            if self.earlystop_cnt > 10:
                print(f"Stop training at epoch", self.epoch)
                break

In [26]:
__hrnet_train_config__ = """
pose_model: HRNet-W48
model_additional_weight: true
comment: null
result_dir: results/submit/hrnet
data_dir: data/ori
debug: false
seed: 20210309

train:
  max_epochs: 200
  SAM: true
  folds: 
    - 1
    - 2
    - 3
    - 4
    - 5
  checkpoints: 
    - null
    - null
    - null
    - null
    - null
  loss_type: ce # ce, bce, mse, mae, awing, sigmae, kldiv
  
  finetune:
    do: true
    step1_epochs: 3
    step2_epochs: 6
    
  plus_augment:
    do: true
    downsample:
      do: true
      p: 0.2
      width: 256
      height: 256
    rotate:
      do: true
      p: 0.4
      left: true
      right: true
  
  lr: 0.0001
  scheduler:
    type: ReduceLROnPlateau
    params:
      factor: 0.5
      patience: 3
      verbose: true
  
dataset:
  train_dir: data/ori/train_imgs
  target_file: data/ori/train_df.csv
  test_dir: results/effdet-train/example/efficientdet-d7_multi5_flip_median_test_imgs
  
  scale_invariance: false
  normalize: true
  mean: [0.411, 0.420, 0.416]
  std: [0.307, 0.303, 0.292]
  smooth_heatmap: 
    do: true
    size: 3
    values: [0.1, 0.2, 0.5]
  input_width: 512
  input_height: 512
  ratio_limit: 2.0
  
  batch_size: 15
  num_cpus: 6
  padding: 20
  
  group_kfold: false
"""

In [27]:
def main():
    C = EasyDict(yaml.load(__hrnet_train_config__, yaml.FullLoader))

    for fold, checkpoint in zip(C.train.folds, C.train.checkpoints):
        C = EasyDict(yaml.load(__hrnet_train_config__, yaml.FullLoader))
        Path(C.result_dir).mkdir(parents=True, exist_ok=True)

        if C.dataset.num_cpus < 0:
            C.dataset.num_cpus = multiprocessing.cpu_count()

        C.result_dir = Path(C.result_dir)
        C.dataset.train_dir = Path(C.dataset.train_dir)
        seed_everything(C.seed, deterministic=False)

        trainer = PoseTrainer(C, fold, checkpoint)
        trainer.fit()

In [28]:
main()

Finetune step 1
Epoch: 001/5, loss: 9.158868;9.307481, rmse 81.590326;46.756194                                     
Epoch: 002/5, loss: 8.995666;9.287218, rmse 84.094788;50.721443                                     
Epoch: 003/5, loss: 9.009095;9.200552, rmse 85.877548;48.972834                                     
Finetune step 2
Epoch: 004/5, loss: 7.072980;5.931625, rmse 60.359309;40.384485                                     
Epoch: 005/5, loss: 5.325669;4.976274, rmse 44.425532;34.272797                                     


5epoch만 돌려봤습니다.

실제로 돌렸을 때 나온 로그는 아래와 같습니다.

```
[2021-04-03 01:50:42  INFO] Fold 1 , checkpoint None
[2021-04-03 01:50:45  INFO] Finetune step 1
[2021-04-03 01:52:06  INFO] Epoch: 001/200, loss: 9.111576;9.281687, rmse 81.787804;47.151906
[2021-04-03 01:53:28  INFO] Epoch: 002/200, loss: 9.004133;9.202540, rmse 83.841997;48.828235
[2021-04-03 01:54:49  INFO] Epoch: 003/200, loss: 9.009309;9.121378, rmse 84.239054;51.023133
[2021-04-03 01:54:49  INFO] Finetune step 2
[2021-04-03 01:57:26  INFO] Epoch: 004/200, loss: 6.473514;5.378668, rmse 57.224715;36.844983
[2021-04-03 01:59:59  INFO] Epoch: 005/200, loss: 4.943404;4.682473, rmse 41.109905;28.716766
[2021-04-03 02:02:34  INFO] Epoch: 006/200, loss: 4.474174;4.335109, rmse 32.775994;22.685376
[2021-04-03 02:02:35  INFO] Finetune step 3
[2021-04-03 02:05:12  INFO] Epoch: 007/200, loss: 4.064342;3.849286, rmse 27.056086;18.153155
[2021-04-03 02:07:47  INFO] Epoch: 008/200, loss: 3.817954;3.699427, rmse 24.673332;15.462396
[2021-04-03 02:10:19  INFO] Epoch: 009/200, loss: 3.656661;3.626052, rmse 22.458779;14.047837
[2021-04-03 02:12:54  INFO] Epoch: 010/200, loss: 3.597842;3.542171, rmse 20.562469;12.546319
[2021-04-03 02:15:30  INFO] Epoch: 011/200, loss: 3.534556;3.490196, rmse 19.450384;11.922026
[2021-04-03 02:18:05  INFO] Epoch: 012/200, loss: 3.474406;3.451678, rmse 18.150288;11.867161
[2021-04-03 02:20:39  INFO] Epoch: 013/200, loss: 3.417534;3.413972, rmse 17.463345;10.929338
[2021-04-03 02:23:15  INFO] Epoch: 014/200, loss: 3.407891;3.377012, rmse 16.144659;10.885777
[2021-04-03 02:25:51  INFO] Epoch: 015/200, loss: 3.368363;3.362036, rmse 15.968075;10.498726
[2021-04-03 02:28:24  INFO] Epoch: 016/200, loss: 3.315080;3.338231, rmse 15.394482;10.415142
[2021-04-03 02:30:58  INFO] Epoch: 017/200, loss: 3.294093;3.311528, rmse 14.878963;10.197726
[2021-04-03 02:33:33  INFO] Epoch: 018/200, loss: 3.275022;3.283736, rmse 14.254614;9.979561
[2021-04-03 02:36:05  INFO] Epoch: 019/200, loss: 3.243217;3.267931, rmse 14.406654;10.213960
[2021-04-03 02:38:41  INFO] Epoch: 020/200, loss: 3.243431;3.249061, rmse 13.852117;9.934496
[2021-04-03 02:41:16  INFO] Epoch: 021/200, loss: 3.221962;3.235899, rmse 13.751989;9.875229
[2021-04-03 02:43:53  INFO] Epoch: 022/200, loss: 3.215765;3.223300, rmse 13.173394;9.770934
[2021-04-03 02:46:30  INFO] Epoch: 023/200, loss: 3.196588;3.191802, rmse 13.132091;9.603020
[2021-04-03 02:49:02  INFO] Epoch: 024/200, loss: 3.147130;3.198061, rmse 13.433540;9.733900
[2021-04-03 02:51:38  INFO] Epoch: 025/200, loss: 3.143757;3.177124, rmse 13.000288;9.355708
[2021-04-03 02:54:14  INFO] Epoch: 026/200, loss: 3.143037;3.175199, rmse 12.615048;9.502426
[2021-04-03 02:56:47  INFO] Epoch: 027/200, loss: 3.126487;3.167524, rmse 12.853057;9.241065
[2021-04-03 02:59:22  INFO] Epoch: 028/200, loss: 3.084297;3.151067, rmse 12.150601;9.099870
[2021-04-03 03:01:56  INFO] Epoch: 029/200, loss: 3.097590;3.147912, rmse 12.353869;9.074561
[2021-04-03 03:04:30  INFO] Epoch: 030/200, loss: 3.066993;3.131535, rmse 12.083944;9.089086
[2021-04-03 03:07:05  INFO] Epoch: 031/200, loss: 3.068503;3.122327, rmse 12.273081;9.092817
[2021-04-03 03:09:35  INFO] Epoch: 032/200, loss: 3.033859;3.116148, rmse 12.088444;8.912008
[2021-04-03 03:12:10  INFO] Epoch: 033/200, loss: 3.053121;3.129031, rmse 11.630105;8.997702
[2021-04-03 03:14:43  INFO] Epoch: 034/200, loss: 3.030079;3.107317, rmse 11.397275;9.422136
[2021-04-03 03:17:22  INFO] Epoch: 035/200, loss: 3.066026;3.091151, rmse 11.552298;8.612168
[2021-04-03 03:19:54  INFO] Epoch: 036/200, loss: 2.995404;3.072383, rmse 12.019486;9.076405
[2021-04-03 03:22:29  INFO] Epoch: 037/200, loss: 3.013175;3.068880, rmse 11.397027;9.057779
[2021-04-03 03:25:05  INFO] Epoch: 038/200, loss: 3.007169;3.072455, rmse 10.686960;8.785975
[2021-04-03 03:27:40  INFO] Epoch: 039/200, loss: 2.997308;3.050728, rmse 11.668658;8.779771
[2021-04-03 03:30:13  INFO] Epoch: 040/200, loss: 2.968577;3.074040, rmse 11.048011;8.814108
[2021-04-03 03:32:47  INFO] Epoch: 041/200, loss: 2.980202;3.035634, rmse 11.401218;8.910084
[2021-04-03 03:35:19  INFO] Epoch: 042/200, loss: 2.949863;3.032272, rmse 11.401427;8.884144
[2021-04-03 03:37:52  INFO] Epoch: 043/200, loss: 2.957391;3.027880, rmse 11.089871;8.268735
[2021-04-03 03:40:25  INFO] Epoch: 044/200, loss: 2.938866;3.029270, rmse 10.891933;8.695430
[2021-04-03 03:43:00  INFO] Epoch: 045/200, loss: 2.932901;3.029656, rmse 10.767609;8.818323
[2021-04-03 03:45:35  INFO] Epoch: 046/200, loss: 2.941626;2.985382, rmse 10.630344;8.607774
[2021-04-03 03:48:13  INFO] Epoch: 047/200, loss: 2.928358;2.993131, rmse 10.385914;8.409465
[2021-04-03 03:50:47  INFO] Epoch: 048/200, loss: 2.910414;2.989746, rmse 10.296080;8.713046
[2021-04-03 03:53:22  INFO] Epoch: 049/200, loss: 2.918711;2.990455, rmse 10.196509;8.435738
[2021-04-03 03:55:57  INFO] Epoch: 050/200, loss: 2.906180;2.997988, rmse 10.475929;8.856266
[2021-04-03 03:58:30  INFO] Epoch: 051/200, loss: 2.884906;2.980430, rmse 10.497673;8.537921
[2021-04-03 04:01:06  INFO] Epoch: 052/200, loss: 2.881751;2.960218, rmse 10.286178;8.477791
[2021-04-03 04:03:38  INFO] Epoch: 053/200, loss: 2.849795;2.967345, rmse 10.253557;8.388203
[2021-04-03 04:06:13  INFO] Epoch: 054/200, loss: 2.862772;3.007725, rmse 10.060478;8.661884
[2021-04-03 04:08:47  INFO] Epoch: 055/200, loss: 2.853514;2.926684, rmse 10.045300;8.067089
[2021-04-03 04:11:22  INFO] Epoch: 056/200, loss: 2.862386;2.951853, rmse 9.921272;8.132290
[2021-04-03 04:13:58  INFO] Epoch: 057/200, loss: 2.845991;2.932593, rmse 9.744282;8.449003
[2021-04-03 04:16:28  INFO] Epoch: 058/200, loss: 2.821307;2.940892, rmse 10.138666;8.318795
[2021-04-03 04:19:00  INFO] Epoch: 059/200, loss: 2.819373;2.939199, rmse 9.835361;8.282974
[2021-04-03 04:21:35  INFO] Epoch: 060/200, loss: 2.831455;2.909624, rmse 9.819294;8.265219
[2021-04-03 04:24:08  INFO] Epoch: 061/200, loss: 2.808254;2.926813, rmse 9.614713;8.049121
[2021-04-03 10:13:57  INFO] Epoch: 062/200, loss: 2.839149;2.925009, rmse 9.590933;8.190913
[2021-04-03 10:16:36  INFO] Epoch: 063/200, loss: 2.820567;2.907754, rmse 9.898347;8.177191
[2021-04-03 10:19:20  INFO] Epoch: 064/200, loss: 2.850035;2.893030, rmse 9.180195;7.893622
[2021-04-03 10:22:02  INFO] Epoch: 065/200, loss: 2.813749;2.896620, rmse 9.565798;7.961446
[2021-04-03 10:24:41  INFO] Epoch: 066/200, loss: 2.772795;2.921744, rmse 9.346566;8.290775
[2021-04-03 10:27:20  INFO] Epoch: 067/200, loss: 2.779566;2.909194, rmse 9.299446;7.869686
[2021-04-03 10:30:03  INFO] Epoch: 068/200, loss: 2.771226;2.887583, rmse 8.838923;7.770881
[2021-04-03 10:32:45  INFO] Epoch: 069/200, loss: 2.753618;2.888313, rmse 9.089010;7.948496
[2021-04-03 10:35:22  INFO] Epoch: 070/200, loss: 2.706607;2.888361, rmse 8.714147;7.950290
[2021-04-03 10:38:01  INFO] Epoch: 071/200, loss: 2.721139;2.907770, rmse 8.587005;8.030171
[2021-04-03 10:40:42  INFO] Epoch: 072/200, loss: 2.725513;2.888864, rmse 8.689077;8.068762
[2021-04-03 10:43:21  INFO] Epoch: 073/200, loss: 2.713907;2.890966, rmse 8.590899;8.160121
[2021-04-03 10:46:00  INFO] Epoch: 074/200, loss: 2.694574;2.888750, rmse 8.479927;8.128879
[2021-04-03 10:48:42  INFO] Epoch: 075/200, loss: 2.721445;2.884164, rmse 8.446556;7.932857
[2021-04-03 10:51:23  INFO] Epoch: 076/200, loss: 2.707748;2.882547, rmse 8.364045;7.859095
[2021-04-03 10:54:02  INFO] Epoch: 077/200, loss: 2.677123;2.890484, rmse 8.232603;7.877231
[2021-04-03 10:56:41  INFO] Epoch: 078/200, loss: 2.679331;2.880562, rmse 8.361204;7.804435
[2021-04-03 10:59:21  INFO] Epoch: 079/200, loss: 2.679801;2.885669, rmse 8.131620;7.765136
[2021-04-03 11:01:59  INFO] Epoch: 080/200, loss: 2.660931;2.881426, rmse 8.267631;7.762618
[2021-04-03 11:04:41  INFO] Epoch: 081/200, loss: 2.677028;2.882562, rmse 8.227345;8.021825
[2021-04-03 11:07:20  INFO] Epoch: 082/200, loss: 2.666284;2.877246, rmse 8.059798;7.742466
[2021-04-03 11:10:04  INFO] Epoch: 083/200, loss: 2.670393;2.883696, rmse 8.084868;7.772282
[2021-04-03 11:12:44  INFO] Epoch: 084/200, loss: 2.666446;2.871469, rmse 8.117327;7.752082
[2021-04-03 11:15:22  INFO] Epoch: 085/200, loss: 2.638374;2.890071, rmse 8.350603;7.738865
[2021-04-03 11:18:03  INFO] Epoch: 086/200, loss: 2.643916;2.887219, rmse 8.085594;7.742009
[2021-04-03 11:20:43  INFO] Epoch: 087/200, loss: 2.647046;2.885289, rmse 8.043022;7.688209
[2021-04-03 11:23:22  INFO] Epoch: 088/200, loss: 2.637412;2.884242, rmse 8.196427;7.693994
[2021-04-03 11:26:01  INFO] Epoch: 089/200, loss: 2.611991;2.871550, rmse 7.818042;7.714978
[2021-04-03 11:28:40  INFO] Epoch: 090/200, loss: 2.633182;2.891691, rmse 8.161644;7.736809
[2021-04-03 11:31:18  INFO] Epoch: 091/200, loss: 2.606546;2.880741, rmse 7.874489;7.649662
[2021-04-03 11:33:58  INFO] Epoch: 092/200, loss: 2.616876;2.867935, rmse 7.981494;7.721557
[2021-04-03 11:36:35  INFO] Epoch: 093/200, loss: 2.583245;2.872302, rmse 8.077883;7.671826
[2021-04-03 11:39:13  INFO] Epoch: 094/200, loss: 2.611010;2.906274, rmse 8.023121;7.635853
[2021-04-03 11:41:52  INFO] Epoch: 095/200, loss: 2.598272;2.883841, rmse 7.902870;7.704837
[2021-04-03 11:44:35  INFO] Epoch: 096/200, loss: 2.626818;2.871694, rmse 7.715516;7.743892
[2021-04-03 11:47:11  INFO] Epoch: 097/200, loss: 2.572718;2.874291, rmse 8.090728;7.732247
[2021-04-03 11:49:51  INFO] Epoch: 098/200, loss: 2.588042;2.876649, rmse 7.723771;7.708593
[2021-04-03 11:52:30  INFO] Epoch: 099/200, loss: 2.597060;2.891991, rmse 7.452216;7.669222
[2021-04-03 11:55:09  INFO] Epoch: 100/200, loss: 2.584952;2.888988, rmse 7.783365;7.791896
[2021-04-03 11:57:47  INFO] Epoch: 101/200, loss: 2.564278;2.911695, rmse 7.660972;7.747816
[2021-04-03 12:00:25  INFO] Epoch: 102/200, loss: 2.574284;2.876146, rmse 7.829700;7.727427
[2021-04-03 12:03:01  INFO] Epoch: 103/200, loss: 2.553566;2.880476, rmse 7.702065;7.794234
[2021-04-03 12:05:38  INFO] Epoch: 104/200, loss: 2.563121;2.875806, rmse 7.629126;7.695067
[2021-04-03 12:08:17  INFO] Epoch: 105/200, loss: 2.550955;2.906066, rmse 7.737550;7.976671
[2021-04-03 12:08:17  INFO] Stop training at epoch 105
```