In [1]:
import numpy as np
import torch as th
import glob, tqdm, os
from PIL import Image
import cv2
import matplotlib.pyplot as plt
import pandas as pd

def face_segment(segment_part, img):
    """Return a boolean mask selecting a named combination of face-parsing labels.

    Args:
        segment_part: Name of the region to select, e.g. 'faceseg_face',
            'faceseg_head', 'faceseg_bg', 'sobel_bg_mask'. The ``rules``
            table below lists every accepted name.
        img: Per-pixel integer label map. A ``PIL.Image.Image`` is converted
            with ``np.array``; anything else is used as-is and must support
            ``== int`` element-wise (e.g. a NumPy array).

    Returns:
        Element-wise boolean mask (for NumPy/PIL input: a bool ndarray with
        the label map's shape). It is True where the pixel is in the region.

    Raises:
        NotImplementedError: if ``segment_part`` is not a known name.

    Notes:
        - 'faceseg_face&hair' is "everything except background", so it also
          covers neck, neck_l, cloth and hat.
        - 'faceseg_faceskin&nose&mouth&eyebrows' also includes both eyes.
    """
    labels = np.array(img) if isinstance(img, Image.Image) else img

    # Label id == position in this tuple. NOTE(review): this looks like the
    # 19-class CelebAMask-HQ face-parsing order — confirm against the model
    # that produced the annotations.
    part_names = (
        'bg', 'skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g',
        'l_ear', 'r_ear', 'ear_r', 'nose', 'mouth', 'u_lip', 'l_lip',
        'neck', 'neck_l', 'cloth', 'hair', 'hat',
    )
    m = {name: (labels == label_id) for label_id, name in enumerate(part_names)}

    # "face" = every facial part; it excludes bg, neck, neck_l, cloth, hair, hat.
    face_parts = (
        'skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g',
        'l_ear', 'r_ear', 'ear_r', 'nose', 'mouth', 'u_lip', 'l_lip',
    )
    m['face'] = np.logical_or.reduce(tuple(m[p] for p in face_parts))

    def union(*names):
        # Left-to-right `|`, identical to writing m[a] | m[b] | m[c] ...
        acc = m[names[0]]
        for name in names[1:]:
            acc = acc | m[name]
        return acc

    # (accepted names, builder) pairs; the first entry containing
    # `segment_part` wins.
    rules = (
        (('faceseg_face',), lambda: m['face']),
        (('faceseg_head',), lambda: union('face', 'neck', 'hair')),
        (('faceseg_nohead',), lambda: ~union('face', 'neck', 'hair')),
        (('faceseg_face&hair',), lambda: ~m['bg']),
        (('faceseg_bg_noface&nohair',),
         lambda: union('bg', 'hat', 'neck', 'neck_l', 'cloth')),
        (('faceseg_bg&ears_noface&nohair',),
         lambda: union('bg', 'hat', 'neck', 'neck_l', 'cloth')
         | union('l_ear', 'r_ear', 'ear_r')),
        (('faceseg_bg',), lambda: m['bg']),
        (('faceseg_bg&noface',),
         lambda: union('bg', 'hair', 'hat', 'neck', 'neck_l', 'cloth')),
        (('faceseg_hair',), lambda: m['hair']),
        (('faceseg_faceskin',), lambda: m['skin']),
        (('faceseg_faceskin&nose',), lambda: union('skin', 'nose')),
        (('faceseg_eyes&glasses&mouth&eyebrows',),
         lambda: union('l_eye', 'r_eye', 'eye_g', 'l_brow', 'r_brow', 'mouth')),
        (('faceseg_faceskin&nose&mouth&eyebrows',),
         lambda: union('skin', 'nose', 'mouth', 'u_lip', 'l_lip',
                       'l_brow', 'r_brow', 'l_eye', 'r_eye')),
        (('faceseg_faceskin&nose&mouth&eyebrows&eyes&glasses',),
         lambda: union('skin', 'nose', 'mouth', 'u_lip', 'l_lip',
                       'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g')),
        (('faceseg_face_noglasses',), lambda: ~m['eye_g'] & m['face']),
        (('faceseg_face_noglasses_noeyes',),
         lambda: ~union('l_eye', 'r_eye') & ~m['eye_g'] & m['face']),
        (('faceseg_eyes&glasses',), lambda: union('l_eye', 'r_eye', 'eye_g')),
        (('glasses',), lambda: m['eye_g']),
        (('faceseg_eyes',), lambda: union('l_eye', 'r_eye')),
        # Edge-map background masks: everything outside face, neck and hair.
        (('sobel_bg_mask', 'laplacian_bg_mask', 'sobel_bin_bg_mask'),
         lambda: ~union('face', 'neck', 'hair')),
        # Same background mask, but the two ears are added back in.
        (('canny_edge_bg_mask',),
         lambda: ~union('face', 'neck', 'hair') | union('l_ear', 'r_ear')),
    )

    for names, build in rules:
        if segment_part in names:
            return build()
    raise NotImplementedError(f"Segment part: {segment_part} is not found!")


# --- Configuration: which split to process and where inputs/outputs live. ---
set_ = 'valid'
ray_mask_path = f"/data/mint/DPM_Dataset/ffhq_256_with_anno/shadow_masks/{set_}/"
img_path = f"/data/mint/DPM_Dataset/ffhq_256_with_anno/ffhq_256/{set_}/"
face_segment_path = f'/data/mint/DPM_Dataset/ffhq_256_with_anno/face_segment/{set_}/anno/'

out_path = f"/data/mint/DPM_Dataset/ffhq_256_with_anno/ray_masks/images/{set_}/"
out_ovl_path = f"/data/mint/DPM_Dataset/ffhq_256_with_anno/ray_masks/overlays/{set_}/"
os.makedirs(out_path, exist_ok=True)
os.makedirs(out_ovl_path, exist_ok=True)

# Range used to min-max normalise each image's shadow value `c_val` to [0, 1].
# NOTE(review): presumably the dataset-wide min/max of c; the trailing comments
# keep an earlier range — confirm which split(s) these were computed on.
max_c = 8.481700287326827  # 7.383497233314015
min_c = -4.989461058405101  # -4.985533880236826
c_p = f'/data/mint/DPM_Dataset/ffhq_256_with_anno/params/{set_}/ffhq-{set_}-shadow-anno.txt'
c = pd.read_csv(c_p, sep=' ', header=None, names=['image_name', 'c_val'])
# Build the image_name -> c_val lookup once. Filtering the whole DataFrame per
# image was O(N) each, i.e. O(N^2) over the split. keep='first' mirrors the
# previous `.values[0]` choice when a name appears more than once.
c_lookup = (
    c.drop_duplicates(subset='image_name', keep='first')
    .set_index('image_name')['c_val']
    .to_dict()
)

for i in tqdm.tqdm(sorted(glob.glob(img_path + "*.jpg"))):
    img_name = os.path.basename(i)
    # The segmentation map, ray mask and both outputs all use the image's stem + ".png".
    out_name = os.path.splitext(img_name)[0] + ".png"

    # Context managers close the underlying files once the pixels have been read.
    with Image.open(i) as im:
        face = np.array(im)
    with Image.open(face_segment_path + 'anno_' + out_name) as seg_im:
        faceseg = face_segment('faceseg_face_noglasses_noeyes', seg_im)
    faceseg = faceseg[..., None]  # (H, W) -> (H, W, 1) so it broadcasts over channels

    if img_name not in c_lookup:
        raise KeyError(f"No shadow value for {img_name!r} in {c_p}")
    c_val = c_lookup[img_name]
    # Clip so a value outside [min_c, max_c] cannot wrap around in the uint8 cast below.
    c_val_norm = float(np.clip((c_val - min_c) / (max_c - min_c), 0.0, 1.0))

    with Image.open(ray_mask_path + out_name) as rm_im:
        rmask = np.array(rm_im)
    if rmask.ndim == 2:
        # A single-channel mask would broadcast against faceseg (H, W, 1) into
        # (H, W, W); replicate it to 3 channels so shapes line up with `face`.
        rmask = np.repeat(rmask[..., None], 3, axis=2)

    # Shadow pixels: ray-mask values strictly in (0, 128), restricted to the
    # face region minus eyes and glasses (see the face_segment call above).
    shadow_area = (rmask < 128) * (rmask > 0) * faceseg
    shadow_area_with_c = shadow_area * c_val_norm
    shadow_img = (shadow_area_with_c * 255).astype(np.uint8)

    # Colour the (unscaled) shadow region and blend it over the photo at 0.5 weight.
    overlay = cv2.applyColorMap((shadow_area * 255).astype(np.uint8), cv2.COLORMAP_WINTER) * shadow_area
    overlay = cv2.addWeighted(face, 1, overlay, 0.5, 0)

    # Side-by-side preview: [photo | c-scaled shadow mask | overlay].
    out_ovl = np.concatenate([face, shadow_img, overlay], axis=1)
    Image.fromarray(out_ovl).save(out_ovl_path + out_name)
    Image.fromarray(shadow_img).save(out_path + out_name)

100%|██████████| 10000/10000 [11:58<00:00, 13.93it/s]
