In [18]:
import argparse
import glob
import logging
import os
import sys
from typing import Any, ClassVar, Dict, List
import torch
from PIL import Image

from detectron2.config import CfgNode, get_cfg
from detectron2.data.detection_utils import read_image
from detectron2.engine.defaults import DefaultPredictor
from detectron2.structures.instances import Instances
from detectron2.utils.logger import setup_logger

from densepose import add_densepose_config
from detectron2.modeling import build_model
from detectron2.checkpoint import DetectionCheckpointer
from densepose.structures import DensePoseChartPredictorOutput, DensePoseEmbeddingPredictorOutput
from densepose.utils.logger import verbosity_to_level
from densepose.vis.base import CompoundVisualizer
from densepose.vis.bounding_box import ScoredBoundingBoxVisualizer
from densepose.vis.densepose_outputs_vertex import (
    DensePoseOutputsTextureVisualizer,
    DensePoseOutputsVertexVisualizer,
    get_texture_atlases,
)
from densepose.vis.densepose_results import (
    DensePoseResultsContourVisualizer,
    DensePoseResultsFineSegmentationVisualizer,
    DensePoseResultsUVisualizer,
    DensePoseResultsVVisualizer,
)
from densepose.vis.densepose_results_textures import (
    DensePoseResultsVisualizerWithTexture,
    get_texture_atlas,
)
from densepose.vis.extractor import (
    CompoundExtractor,
    DensePoseOutputsExtractor,
    DensePoseResultExtractor,
    create_extractor,
)

In [19]:
import cv2
import numpy as np

from densepose import add_densepose_config
from densepose.vis.densepose_results import (
    DensePoseResultsFineSegmentationVisualizer as Visualizer,
)

from detectron2.config import get_cfg

def setup_config(config_fpath: str, model_fpath: str, args, opts: List[str]):
    """Build a frozen DensePose config.

    Args:
        config_fpath: path of the YAML config file merged on top of the defaults.
        model_fpath: value stored in ``cfg.MODEL.WEIGHTS``.
        args: object exposing an ``opts`` attribute; it is always merged as a
            list of overrides.
        opts: extra overrides merged after ``args.opts``; skipped when empty
            or ``None``.

    Returns:
        The config object after ``freeze()`` has been called on it.
    """
    config = get_cfg()
    add_densepose_config(config)
    config.merge_from_file(config_fpath)

    # Overrides attached to ``args`` are applied unconditionally; the explicit
    # ``opts`` list is applied afterwards, so it is merged last.
    config.merge_from_list(args.opts)
    if opts:
        config.merge_from_list(opts)

    config.MODEL.WEIGHTS = model_fpath
    config.freeze()
    return config

def _get_input_file_list(input_spec: str):
        if os.path.isdir(input_spec):
            file_list = [
                os.path.join(input_spec, fname)
                for fname in os.listdir(input_spec)
                if os.path.isfile(os.path.join(input_spec, fname))
            ]
        elif os.path.isfile(input_spec):
            file_list = [input_spec]
        else:
            file_list = glob.glob(input_spec)
        return file_list

def _get_out_fname(entry_idx: int, fname_base: str):
    base, ext = os.path.splitext(fname_base)
    return base + ".{0:04d}".format(entry_idx) + ext

class ARGS:
    """Plain attribute holder standing in for parsed command-line arguments."""

    def __init__(self) -> None:
        # Checkpoint file and the YAML config that goes with it.
        self.model = 'densepose_rcnn_R_101_FPN_DL_WC1M_s1x.pkl'
        self.cfg = 'densepose_rcnn_R_101_FPN_DL_WC1M_s1x.yaml'
        # Input specification (file, directory or glob pattern).
        self.input = 'ref.png'
        # Extra config overrides; a fresh empty list per instance.
        self.opts = list()
    
args = ARGS()

# Build the DensePose config inline instead of calling setup_config():
# setup_config() freezes the config before returning, and MODEL.DEVICE still
# has to be assigned here.
cfg = get_cfg()
add_densepose_config(cfg)
# Take the paths from `args` so that editing ARGS is enough to switch models
# (previously the same two file names were hard-coded a second time here).
cfg.merge_from_file(args.cfg)
cfg.MODEL.WEIGHTS = args.model
# Fall back to CPU when no CUDA device is available.
cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

file_list = _get_input_file_list(args.input)

In [20]:
# Show the test-time resize settings (min size, then max size), one per line.
print(cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MAX_SIZE_TEST, sep="\n")

0
1333


In [21]:
import detectron2.data.transforms as T
class DefaultPredictor:
    """
    Create a simple end-to-end predictor with the given config that runs on
    single device for a single input image.

    NOTE: this notebook-local class intentionally shadows the ``DefaultPredictor``
    imported from ``detectron2.engine.defaults``.

    Compared to using the model directly, this class does the following additions:

    1. Load checkpoint from `cfg.MODEL.WEIGHTS`.
    2. Always take BGR image as the input and apply conversion defined by `cfg.INPUT.FORMAT`.
    3. Apply resizing defined by `cfg.INPUT.{MIN,MAX}_SIZE_TEST`.
    4. Take one input image and run the model on a batch containing only that image.
       The model output is returned as-is, i.e. index it with ``[0]`` to get the
       prediction for the image.

    This is meant for simple demo purposes, so it does the above steps automatically.
    This is not meant for benchmarks or running complicated inference logic.

    Attributes:
        cfg: a clone of the config passed to the constructor.
        model: the model built from ``cfg``, switched to eval mode.
        aug: the shortest-edge resize augmentation applied before inference.
        input_format (str): ``cfg.INPUT.FORMAT``; either "RGB" or "BGR".

    Examples:
    ::
        pred = DefaultPredictor(cfg)
        inputs = cv2.imread("input.jpg")
        outputs = pred(inputs)
        instances = outputs[0]["instances"]
    """

    def __init__(self, cfg):
        """
        Args:
            cfg: config providing ``MODEL.WEIGHTS``, ``MODEL.DEVICE``,
                ``INPUT.FORMAT`` and ``INPUT.{MIN,MAX}_SIZE_TEST``.
        """
        self.cfg = cfg.clone()  # cfg can be modified by model
        self.model = build_model(self.cfg)
        self.model.eval()

        checkpointer = DetectionCheckpointer(self.model)
        checkpointer.load(cfg.MODEL.WEIGHTS)

        # Read the resize bounds from the config instead of hard-coding
        # ([0, 0], 1333); for the config used in this notebook the values are
        # the same (MIN_SIZE_TEST == 0, MAX_SIZE_TEST == 1333).
        self.aug = T.ResizeShortestEdge(
            [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
        )

        self.input_format = cfg.INPUT.FORMAT
        assert self.input_format in ["RGB", "BGR"], self.input_format

    def __call__(self, original_image):
        """
        Args:
            original_image (np.ndarray): an image of shape (H, W, C) (in BGR order).

        Returns:
            predictions (list[dict]):
                the raw model output for a batch holding this single image;
                ``predictions[0]`` is the result for ``original_image``.
                See :doc:`/tutorials/models` for details about the format.
        """
        with torch.no_grad():  # https://github.com/sphinx-doc/sphinx/issues/4258
            # Apply pre-processing to image.
            if self.input_format == "RGB":
                # whether the model expects BGR inputs or RGB
                original_image = original_image[:, :, ::-1]
            height, width = original_image.shape[:2]
            image = self.aug.get_transform(original_image).apply_image(original_image)
            image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))
            # Tensor.to() is not in-place: keep the returned tensor, otherwise
            # the device transfer is silently discarded.
            image = image.to(self.cfg.MODEL.DEVICE)

            # Pass the original size along with the image so the model can report
            # results in original-image coordinates when resizing is active.
            # Without resizing these equal the tensor's own size.
            inputs = {"image": image, "height": height, "width": width}

            predictions = self.model([inputs])
            return predictions

# Instantiate the notebook-local DefaultPredictor defined in this cell (it shadows
# the class of the same name imported from detectron2.engine.defaults). Its
# constructor builds the model and loads the checkpoint at cfg.MODEL.WEIGHTS.
predictor = DefaultPredictor(cfg)

The checkpoint state_dict contains keys that are not used by the model:
  [35mpixel_mean[0m
  [35mpixel_std[0m


In [22]:
from torchvision import transforms
from densepose.vis.extractor import DensePoseResultExtractor

# Three visualizers: fine segmentation, U and V (presumably the DensePose
# I/U/V channels, going by the class names).
vis_I = DensePoseResultsFineSegmentationVisualizer()
vis_U = DensePoseResultsUVisualizer()
vis_V = DensePoseResultsVVisualizer()
vis = [vis_I, vis_U, vis_V]

# Matching extractors, kept in the same order as `vis` because the two lists
# are zipped together later. The fine-segmentation extractor is constructed
# directly rather than through create_extractor(vis_I).
ext_I = DensePoseResultExtractor()
ext_U, ext_V = (create_extractor(visualizer) for visualizer in (vis_U, vis_V))
ext = [ext_I, ext_U, ext_V]

# Dataset location; the images of a split live in `<dataset_dir>/<split>.tsv`.
split = 'new10val_images'  # alternative split: 'train_images'
dataset_dir = '/data1/lihaochen/TikTok_finetuning/TiktokDance'

def tsv_reader(tsv_file, sep='\t'):
    """Lazily iterate over the rows of a delimited text file.

    Args:
        tsv_file: path of the file to read.
        sep: column separator (tab by default).

    Yields:
        One list of strings per line, each field stripped of surrounding
        whitespace.
    """
    with open(tsv_file, 'r') as handle:
        for raw_line in handle:
            yield list(map(str.strip, raw_line.split(sep)))

import base64  # hoisted out of the loop; decodes the base64 image payload of a TSV row

tsv_fname_img = dataset_dir + f'/{split}.tsv'
tsv_imgs = tsv_reader(tsv_fname_img)

# Only the row whose key equals this value is processed; all others are skipped.
TARGET_IMAGE_KEY = 'TiktokDance_201_005_1x1_00073.jpg'

for row_idx, img_row in enumerate(tsv_imgs):
    # Row layout as used here: column 0 is the image key, column 1 the
    # base64-encoded image file.
    image_key = img_row[0]
    if image_key != TARGET_IMAGE_KEY:
        continue
    file_name = image_key

    image = cv2.imdecode(
        np.frombuffer(base64.b64decode(img_row[1]), np.uint8), cv2.IMREAD_COLOR
    )
    if image is None:
        raise ValueError(f"could not decode image payload for key {image_key!r}")
    # cv2.imdecode already returns the channels in BGR order, which is what the
    # predictor expects. The previous COLOR_RGB2BGR conversion swapped the
    # channels and effectively fed an RGB image to the model.
    img = image

    with torch.no_grad():
        outputs = predictor(img)
        # The predictor returns the model output for a batch of one image.
        output = outputs[0]["instances"]

        # Run every extractor first, then render each result with the
        # visualizer at the same position in `vis`.
        datas = [e(output) for e in ext]

        for idx, (data, v) in enumerate(zip(datas, vis)):
            # Draw on a black canvas of the input size rather than on the image.
            image_vis = v.visualize(np.zeros_like(img), data)
            # Output files are numbered from 1: <name>.0001.<ext>, <name>.0002.<ext>, ...
            out_fname = _get_out_fname(idx + 1, file_name)
            out_dir = os.path.dirname(out_fname)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            # cv2.imwrite signals failure through its return value.
            if not cv2.imwrite(out_fname, image_vis):
                raise OSError(f"failed to write visualization to {out_fname!r}")