In [1]:
%load_ext lab_black

In [2]:
import math
import os
import random
from collections import defaultdict
from datetime import datetime
from io import TextIOWrapper
from pathlib import Path
from typing import Callable, List, Sequence, Tuple

import albumentations as A
import cv2
import imageio
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import KFold
from torch import nn, optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataset import Subset
from torch.utils.tensorboard import SummaryWriter
from torchvision.models import mobilenet_v2
from torchvision.models.detection import KeypointRCNN, keypointrcnn_resnet50_fpn
from torchvision.ops import MultiScaleRoIAlign
from tqdm import tqdm

import utils

BASELINE = True
MODEL = "keypointrcnn_resnet50_fpn_finetune_step1"
DATA_DIR = Path("data/ori")
FOLD = 1
START_EPOCH = 1
NUM_EPOCHS = 200
N_TTA_TEST = 10
N_TTA_VALID = 1
SAM = False
LR = 1e-4

In [3]:
class KeypointDataset(Dataset):
    def __init__(
        self,
        image_dir: os.PathLike,
        label_path: os.PathLike,
        transforms: Sequence[Callable] = None,
    ) -> None:
        self.image_dir = Path(image_dir)
        self.df = pd.read_csv(label_path).to_numpy()
        self.transforms = transforms

    def __len__(self) -> int:
        return self.df.shape[0]

    def __getitem__(self, index: int):
        image_id = self.df[index, 0]
        labels = np.array([1])
        # int64가 아니면 안되는건가? 소숫점이 손실될텐데?
        keypoints = self.df[index, 1:].reshape(-1, 2).astype(np.int64)

        x1, y1 = min(keypoints[:, 0]), min(keypoints[:, 1])
        x2, y2 = max(keypoints[:, 0]), max(keypoints[:, 1])
        boxes = np.array([[x1, y1, x2, y2]], dtype=np.int64)

        image = cv2.imread(str(self.image_dir / image_id), cv2.COLOR_BGR2RGB)

        targets = {
            "image": image,
            "bboxes": boxes,
            "labels": labels,
            "keypoints": keypoints,
        }

        if self.transforms is not None:
            targets = self.transforms(**targets)

        image = targets["image"]
        image = image / 255.0

        targets = {
            "labels": torch.as_tensor(targets["labels"], dtype=torch.int64),
            "boxes": torch.as_tensor(targets["bboxes"], dtype=torch.float32),
            "keypoints": torch.as_tensor(
                np.concatenate([targets["keypoints"], np.ones((24, 1))], axis=1)[np.newaxis],
                dtype=torch.float32,
            ),
        }

        return image, targets

In [4]:
def collate_fn(batch: torch.Tensor) -> Tuple:
    return tuple(zip(*batch))


def load_dataset(fold):
    transform = A.Compose(
        [
            A.Resize(800, 1333),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ],
        bbox_params=A.BboxParams(format="pascal_voc", label_fields=["labels"]),
        keypoint_params=A.KeypointParams(format="xy"),
    )

    ds = KeypointDataset(DATA_DIR / "train_imgs", DATA_DIR / "train_df.csv", transform)
    kf = KFold(n_splits=5, shuffle=True, random_state=1351235)
    for i, (tidx, vidx) in enumerate(kf.split(ds), 1):
        if i == fold:
            if BASELINE:
                tidx = tidx[: len(tidx) // 10]
                vidx = vidx[: len(vidx) // 10]

            tds, vds = Subset(ds, tidx), Subset(ds, vidx)
            tdl = DataLoader(tds, batch_size=24, shuffle=True, num_workers=8, pin_memory=True, collate_fn=collate_fn)
            vdl = DataLoader(vds, batch_size=24, shuffle=False, num_workers=8, pin_memory=True, collate_fn=collate_fn)
            return tdl, vdl

    raise NotImplementedError("out of folds")

In [5]:
def get_model() -> nn.Module:
    if MODEL == "우주대마왕":
        backbone = mobilenet_v2(pretrained=True).features
        backbone.out_channels = 1280
        roi_pooler = MultiScaleRoIAlign(featmap_names=["0"], output_size=7, sampling_ratio=2)

        keypoint_roi_pooler = MultiScaleRoIAlign(featmap_names=["0"], output_size=14, sampling_ratio=2)

        model = KeypointRCNN(
            backbone,
            num_classes=2,
            num_keypoints=24,
            box_roi_pool=roi_pooler,
            keypoint_roi_pool=keypoint_roi_pooler,
        )
    elif MODEL == "keypointrcnn_resnet50_fpn_finetune_step1":
        model = keypointrcnn_resnet50_fpn(pretrained=True, progress=False)
        for p in model.parameters():
            p.requires_grad = False

        m = nn.ConvTranspose2d(512, 24, 4, 2, 1)
        with torch.no_grad():
            m.weight[:, :17] = model.roi_heads.keypoint_predictor.kps_score_lowres.weight
            m.bias[:17] = model.roi_heads.keypoint_predictor.kps_score_lowres.bias
            # m.weight = m.weight.contiguous()
            # m.bias = m.bias.contiguous()
        model.roi_heads.keypoint_predictor.kps_score_lowres = m
    else:
        raise NotImplementedError()

    return model.cuda()

In [6]:
model = get_model()

In [7]:
tdl, vdl = load_dataset(1)

In [13]:
xs, ys = next(tdl.__iter__())

In [14]:
xs_ = [x.cuda() for x in xs]

In [15]:
ys_ = [{k: v.cuda() for k, v in y.items()} for y in ys]

In [18]:
losses = model(xs_, ys_)

In [19]:
losses

{'loss_classifier': tensor(0.1179, device='cuda:0'),
 'loss_box_reg': tensor(0.0427, device='cuda:0'),
 'loss_keypoint': tensor(8.0383, device='cuda:0', grad_fn=<NllLossBackward>),
 'loss_objectness': tensor(0.2116, device='cuda:0'),
 'loss_rpn_box_reg': tensor(0.0155, device='cuda:0')}

In [20]:
losses.values()

dict_values([tensor(0.1179, device='cuda:0'), tensor(0.0427, device='cuda:0'), tensor(8.0383, device='cuda:0', grad_fn=<NllLossBackward>), tensor(0.2116, device='cuda:0'), tensor(0.0155, device='cuda:0')])

values의 sum 말고 중요한 요소만 집는 방법도 있을 것 같은데

In [21]:
model.eval()
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f80fd190a50>

In [22]:
losses = model(xs_, ys_)

In [24]:
losses[0]

{'boxes': tensor([], device='cuda:0', size=(0, 4)),
 'labels': tensor([], device='cuda:0', dtype=torch.int64),
 'scores': tensor([], device='cuda:0'),
 'keypoints': tensor([], device='cuda:0', size=(0, 24, 3)),
 'keypoints_scores': tensor([], device='cuda:0', size=(0, 24))}