In [1]:
import os

In [2]:
os.chdir("../../")

In [3]:
from external_dependencies.face_parsing.model import BiSeNet
import torch
import os
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import cv2

In [4]:
from tqdm import tqdm

In [5]:
# ['background', 'skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 
# 'ear_r', 'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat']
n_classes = 19

In [6]:
net = BiSeNet(n_classes=n_classes)

In [7]:
device="cuda:2"

In [8]:
net.to(device)

BiSeNet(
  (cp): ContextPath(
    (resnet): Resnet18(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, a

In [9]:
net.load_state_dict(torch.load("./external_dependencies/face_parsing/79999_iter.pth"))

<All keys matched successfully>

In [10]:
net.eval()

BiSeNet(
  (cp): ContextPath(
    (resnet): Resnet18(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, a

In [11]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

In [12]:
@torch.no_grad()
def vis_parsing_maps(im: torch.Tensor, inverse: bool = False, argmax: bool = True):
    # Colors for all 20 parts
    part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0],
                   [255, 0, 85], [255, 0, 170],
                   [0, 255, 0], [85, 255, 0], [170, 255, 0],
                   [0, 255, 85], [0, 255, 170],
                   [0, 0, 255], [85, 0, 255], [170, 0, 255],
                   [0, 85, 255], [0, 170, 255],
                   [255, 255, 0], [255, 255, 85], [255, 255, 170],
                   [255, 0, 255], [128, 128, 128], [255, 170, 255],
                   [0, 255, 255], [85, 255, 255], [170, 255, 255], 
                   [128, 128, 128]]
    
    if inverse == False:
        if argmax:
            im = torch.argmax(im, dim=1, keepdim=True)
        out = torch.zeros((im.size(0), 3, im.size(2), im.size(3)), device=im.device, dtype=torch.float32)

        for index in range(len(part_colors)):
            color = torch.from_numpy(np.array(part_colors[index])).to(out.device).to(out.dtype).view(1, 3, 1, 1).expand_as(out)
            out = torch.where(im == index, color, out)

        out = out # / 255.0 * 2 - 1
        return out
    else:
        out = torch.zeros((im.size(0), 1, im.size(2), im.size(3)), device=im.device, dtype=torch.int64)
        
        for index in range(len(part_colors)):
            color = torch.from_numpy(np.array(part_colors[index])).to(im.device).to(im.dtype).view(1, 3, 1, 1).expand_as(im) / 255.0 * 2 - 1
            out = torch.where(torch.all(im == color, dim=1, keepdim=True), torch.ones((im.size(0), 1, im.size(2), im.size(3)), device=out.device, dtype=torch.int64) * index, out)
        
        return out

In [13]:
seg_mapping = {
    0: 0, # Background
    1: 1, # Skin
    2: 2, # Brow (L)
    3: 2, # Brow (R)
    4: 3, # Eye (L)
    5: 3, # Eye (R)
    6: 4, # Glasses
    7: 5, # Ear (L)
    8: 5, # Ear (R)
    9: 6, # Ear-ring 
    10: 7, # Nose
    11: 8, # Mouth 
    12: 9, # Lip (U)
    13: 9, # Lip (D)
    14: 10, # Neck
    15: 11, # Neck-lace
    16: 12, # Cloth
    17: 13, # Hair
    18: 14, # Hat
}

def remap_seg(seg):
    for key, value in seg_mapping.items():
        seg[seg == key] = value
    return seg

def remap_cseg(seg):
    class_s = max(list(seg_mapping.values())) + 1
    out = np.zeros((class_s, seg.shape[1], seg.shape[2]), dtype=np.float32)
    for key, value in seg_mapping.items():
        out[value, :, :] += seg[key, :, :]
    return out

In [14]:
src_img_dir = "../../Dataset/seg_samples/blacks/img"
dst_img_dir = "../../Dataset/seg_samples/blacks/seg"
os.makedirs(dst_img_dir, exist_ok=True)
fn_s = os.listdir(src_img_dir)

In [16]:
with torch.no_grad():
    for fn in tqdm(fn_s):
        image = Image.open(os.path.join(src_img_dir, fn)).resize((512, 512), Image.BILINEAR)
        img = transform(image)
        img = img.unsqueeze(0)
        img = img.to(device)
        out = ((vis_parsing_maps(remap_seg(net(img)[0].argmax(1, keepdims=True)), argmax=False).squeeze(0).cpu().permute(1, 2, 0)).numpy()).astype(np.uint8)
        parsing = out
        Image.fromarray(parsing).save(os.path.join(dst_img_dir, fn.replace("jpg", "png")))

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 22/22 [00:06<00:00,  3.34it/s]


In [4]:
from typing import List, Optional

In [5]:
def convert_to_mask(img: np.ndarray, black_lst: Optional[List[int]] = None, white_lst: Optional[List[int]] = None) -> np.ndarray:
    assert black_lst is None or white_lst is None
    
    if black_lst is not None:
        mask = np.ones(img.shape, dtype=np.uint8) * 255
        for black_index in black_lst:
            mask[img == black_index] = 0
        return mask
    elif white_lst is not None:
        mask = np.zeros(img.shape, dtype=np.uint8)
        for white_index in white_lst:
            mask[img == white_index] = 255
        return mask
    else:
        raise NotImplementedError(f"One of `black_lst` or `white_lst` must not be `None`.")

In [7]:
for fn in tqdm(fn_s):
    image = Image.open(os.path.join(dst_img_dir, fn)).resize((512, 512), Image.BILINEAR)
    if convert_to_mask(np.array(image), white_lst=[17, 18]).sum() == 0:
        print(fn)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30000/30000 [00:52<00:00, 573.09it/s]


In [12]:
convert_to_mask(np.array(Image.open(os.path.join(dst_img_dir, fn_s[6])).resize((512, 512), Image.BILINEAR)), white_lst=[17, 18]).sum()

60180