In [1]:
import numpy as np
import torch as th
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F
from torchvision.io import read_image
from pathlib import Path
import cv2
import os, tqdm, sys
import re

# Notebook-wide matplotlib defaults: tight bounding box when saving figures and
# a 20 x 20 (inches, per matplotlib's figure.figsize convention) default figure size.
plt.rcParams.update({"savefig.bbox": 'tight', "figure.figsize": (20, 20)})

def show(imgs):
    """Plot one or more images side by side in a single row, with all ticks hidden.

    imgs: a single image or a list of images. Each item has ``.detach()`` called
    on it (so presumably a torch tensor) and is then converted with torchvision's
    ``to_pil_image`` before being drawn with ``imshow``.
    """
    batch = imgs if isinstance(imgs, list) else [imgs]
    _, grid = plt.subplots(ncols=len(batch), squeeze=False)
    for ax, tensor in zip(grid[0], batch):
        pil_img = F.to_pil_image(tensor.detach())
        ax.imshow(np.asarray(pil_img))
        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
        
def face_segment(segment_part, img):
    """Build a boolean mask selecting a named group of face-parsing labels.

    Parameters:
    - segment_part: name of the mask recipe, e.g. 'faceseg_face', 'faceseg_head',
      'faceseg_nohead', 'faceseg_bg', 'glasses', 'sobel_bg_mask',
      'canny_edge_bg_mask' (full list in the recipe table below).
    - img: label map. A PIL image is converted with ``np.array``; anything else
      is used as-is and must support element-wise ``==`` against ints.

    Label ids compared against (0..18): bg, skin, l_brow, r_brow, l_eye, r_eye,
    eye_g, l_ear, r_ear, ear_r, nose, mouth, u_lip, l_lip, neck, neck_l, cloth,
    hair, hat. "face" is the union of ids 1..13.

    Returns:
    - The boolean mask produced by the selected recipe.

    Raises:
    - NotImplementedError: if ``segment_part`` matches no recipe.
    """
    anno = np.array(img) if isinstance(img, Image.Image) else img

    label_names = ('bg', 'skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g',
                   'l_ear', 'r_ear', 'ear_r', 'nose', 'mouth', 'u_lip', 'l_lip',
                   'neck', 'neck_l', 'cloth', 'hair', 'hat')
    mask = {name: (anno == label_id) for label_id, name in enumerate(label_names)}
    # Labels 1..13 (skin through l_lip) make up the "face" region.
    face = np.logical_or.reduce(tuple(mask[name] for name in label_names[1:14]))

    def head():
        return face | mask['neck'] | mask['hair']

    def bg_without_face_or_hair():
        return mask['bg'] | mask['hat'] | mask['neck'] | mask['neck_l'] | mask['cloth']

    # Ordered (name, builder) pairs; scanned with == so lookup semantics match
    # a plain if/elif comparison chain.
    recipes = (
        ('faceseg_face', lambda: face),
        ('faceseg_head', head),
        ('faceseg_nohead', lambda: ~head()),
        ('faceseg_face&hair', lambda: ~mask['bg']),
        ('faceseg_bg_noface&nohair', bg_without_face_or_hair),
        ('faceseg_bg&ears_noface&nohair',
         lambda: bg_without_face_or_hair() | (mask['l_ear'] | mask['r_ear'] | mask['ear_r'])),
        ('faceseg_bg', lambda: mask['bg']),
        ('faceseg_bg&noface',
         lambda: mask['bg'] | mask['hair'] | mask['hat'] | mask['neck'] | mask['neck_l'] | mask['cloth']),
        ('faceseg_hair', lambda: mask['hair']),
        ('faceseg_faceskin', lambda: mask['skin']),
        ('faceseg_faceskin&nose', lambda: mask['skin'] | mask['nose']),
        ('faceseg_eyes&glasses&mouth&eyebrows',
         lambda: mask['l_eye'] | mask['r_eye'] | mask['eye_g'] | mask['l_brow'] | mask['r_brow'] | mask['mouth']),
        ('faceseg_faceskin&nose&mouth&eyebrows',
         lambda: (mask['skin'] | mask['nose'] | mask['mouth'] | mask['u_lip'] | mask['l_lip']
                  | mask['l_brow'] | mask['r_brow'] | mask['l_eye'] | mask['r_eye'])),
        ('faceseg_faceskin&nose&mouth&eyebrows&eyes&glasses',
         lambda: (mask['skin'] | mask['nose'] | mask['mouth'] | mask['u_lip'] | mask['l_lip']
                  | mask['l_brow'] | mask['r_brow'] | mask['l_eye'] | mask['r_eye'] | mask['eye_g'])),
        ('faceseg_face_noglasses', lambda: ~mask['eye_g'] & face),
        ('faceseg_face_noglasses_noeyes',
         lambda: ~(mask['l_eye'] | mask['r_eye']) & ~mask['eye_g'] & face),
        ('faceseg_eyes&glasses', lambda: mask['l_eye'] | mask['r_eye'] | mask['eye_g']),
        ('glasses', lambda: mask['eye_g']),
        ('faceseg_eyes', lambda: mask['l_eye'] | mask['r_eye']),
        # Edge-based background masks: everything outside face, neck and hair.
        ('sobel_bg_mask', lambda: ~head()),
        ('laplacian_bg_mask', lambda: ~head()),
        ('sobel_bin_bg_mask', lambda: ~head()),
        # Same as above, but the two ear labels are added back in.
        ('canny_edge_bg_mask', lambda: ~head() | (mask['l_ear'] | mask['r_ear'])),
    )

    for recipe_name, build in recipes:
        if segment_part == recipe_name:
            return build()
    raise NotImplementedError(f"Segment part: {segment_part} is not found!")

def get_shadow_diff(img1, img2, c_type='L', signed=False):
    """Per-pixel difference between two images after mode conversion.

    Both inputs are converted with ``.convert(c_type)`` (PIL-style), turned into
    arrays and divided by 255. Returns ``img2 - img1`` when ``signed`` is True,
    otherwise its absolute value.
    """
    before, after = (np.array(im.convert(c_type)) / 255.0 for im in (img1, img2))
    delta = after - before
    return delta if signed else np.abs(delta)

def create_image_grid(images, n_rows=1):
    """
    Tile equally-sized images into a single grid array.

    Parameters:
    - images: Sequence of np.ndarray, all with the same shape. The caller's
      list is NOT modified.
    - n_rows: Number of rows in the grid (must be >= 1).

    Returns:
    - np.ndarray: images placed row-major, ``n_cols = ceil(len(images) / n_rows)``
      per row; any unused cells at the end are filled with zeros.

    Raises:
    - ValueError: if ``images`` is empty or ``n_rows`` < 1.
    """
    # Work on a copy: padding must not leak blank tiles into the caller's list.
    tiles = list(images)
    if not tiles:
        raise ValueError("create_image_grid() needs at least one image")
    if n_rows < 1:
        raise ValueError(f"n_rows must be >= 1, got {n_rows}")

    n_images = len(tiles)
    n_cols = (n_images + n_rows - 1) // n_rows  # ceil division

    # Pad with zero tiles so every row has exactly n_cols images.
    tiles += [np.zeros_like(tiles[0]) for _ in range(n_rows * n_cols - n_images)]

    rows = [np.concatenate(tiles[r * n_cols:(r + 1) * n_cols], axis=1) for r in range(n_rows)]
    return np.concatenate(rows, axis=0)

# Setting up the paths

In [2]:
# Dataset split and locations. `set_` and `ckpt` are also read as globals by progress().
set_ = 'valid'
data_path = '/data/mint/DPM_Dataset/ffhq_256_with_anno/'
# NOTE: data_path already ends in '/', so these paths contain a '//' (harmless on POSIX).
image_path, mask_path, shadows_path = (
    f'{data_path}/ffhq_256/{set_}/',
    f'{data_path}/face_segment/{set_}/anno/',
    f'{data_path}/shadow_masks/{set_}/',
)
ckpt = 'ema_085000'


def progress(mode, model, n_frames, sampling_path, mothership=False):
    """Print sampling progress for one (mode, model) run and collect its frame dirs.

    Reads the module-level globals ``set_`` ('train' or 'valid') and ``ckpt``.

    Parameters:
    - mode: experiment name used as a directory component.
    - model: model/log directory name used as a directory component.
    - n_frames: frames requested per sample. A sample directory counts as done
      when it holds exactly ``2 * n_frames + 1`` entries.
    - sampling_path: root directory used when ``mothership`` is True.
    - mothership: if True, scan ``sampling_path``; otherwise scan the hard-coded
      Soften_Strengthen_Shadows tree for this mode/model/ckpt.

    Returns:
    - List of per-sample frame directories that exist on disk (done or not).

    Raises:
    - FileNotFoundError: if the progress root directory does not exist.
    - NotImplementedError: if ``set_`` is neither 'train' nor 'valid' (and the
      root is non-empty).
    """
    total = 10000 if set_ == 'valid' else 60000
    if mothership:
        progress_path = sampling_path
    else:
        split_dir = 'train_sub' if set_ == 'train' else set_
        progress_path = f'/data/mint/DPM_Dataset/Soften_Strengthen_Shadows/TPAMI/{mode}/{model}/{ckpt}/{split_dir}/'

    # Fail early with a message naming the path (os.listdir would raise anyway).
    if not os.path.exists(progress_path):
        raise FileNotFoundError(f"Progress path does not exist: {progress_path}")
    entries = sorted(os.listdir(progress_path))
    print("Available: ", entries)

    img_path = []
    for p in entries:
        if set_ == 'train':
            # Chunk dirs are split on '_' and field 1 is the chunk's start index
            # (presumably '<prefix>_<start>_<sep>_<end>' — confirm against the sampler).
            start = int(p.split('_')[1])
            tail = f'{progress_path}/{p}/shadow/reverse_sampling/'
        elif set_ == 'valid':
            assert p == 'shadow'
            tail = f'{progress_path}/shadow/reverse_sampling/'
            start = '60000'
        else:
            raise NotImplementedError(f"Set: {set_} is not found!")
        print(tail)

        done_count = 0
        empty_name = []
        for t in sorted(os.listdir(tail)):
            frame_dir = f'{tail}/{t}/dst={start}.jpg/Lerp_1000/n_frames={n_frames}/'
            if not os.path.exists(frame_dir):
                continue
            if len(os.listdir(frame_dir)) == n_frames * 2 + 1:
                done_count += 1
            else:
                empty_name.append(t)
            img_path.append(frame_dir)
        empty_count = len(empty_name)

        print(f'[#] Done: {p} => {done_count}/{total} => {done_count * 100/total:.2f}%')
        print(f'[#] Empty: {empty_count}/{total} => {empty_count * 100/total:.2f}%')
        print(f'[#] Empty: {empty_name}')
    return img_path
            
        
def get_img_path(set_, mode, n_frames, mothership):
    """Collect sampled frame directories for one relighting mode, keyed by source id.

    Parameters:
    - set_: 'train' or 'valid'; selects the sampling sub-directory
      ('train_sub' for train). NOTE: ``progress`` reads the module-level
      ``set_`` global, not this argument.
    - mode: 'FFHQ_shadow_face' or 'FFHQ_diffuse_face'.
    - n_frames: forwarded to ``progress``.
    - mothership: iterable of booleans; ``progress`` is called once per value
      and the resulting path lists are concatenated in order.

    Returns:
    - (img_path, sj_dict): the concatenated list of frame directories, and a
      dict mapping the source id captured from 'src=<digits>.jpg' to its
      directory (a later path with the same id overwrites an earlier one).

    Raises:
    - NotImplementedError: for an unknown ``mode``.
    """
    models = {
        'FFHQ_diffuse_face': 'log=Masked_Face_woclip+BgNoHead+shadow_256_cfg=Masked_Face_woclip+BgNoHead+shadow_256.yaml_tomin_steps=50',
        'FFHQ_shadow_face': 'log=difareli_canny=153to204bg_256_vll_cfg=difareli_canny=153to204bg_256_vll.yaml_tomax_steps=50',
    }
    if mode not in models:
        raise NotImplementedError(f"Mode: {mode} is not found!")
    model = models[mode]

    split_dir = 'train_sub' if set_ == 'train' else set_
    sampling_path = f'/data/mint/sampling/TPAMI/{mode}/{model}/ema_085000/{split_dir}'

    # Flatten in order without the quadratic cost of sum(lists, []).
    img_path = [
        path
        for ms in mothership
        for path in progress(mode=mode, model=model, n_frames=n_frames,
                             sampling_path=sampling_path, mothership=ms)
    ]

    sj_dict = {}
    for p in img_path:
        # Escape '.' so only a literal '.jpg' suffix matches.
        src = re.findall(r'src=([0-9]+)\.jpg', p)
        assert len(src) == 1, f"expected exactly one 'src=<id>.jpg' in {p}"
        dst = re.findall(r'dst=([0-9]+)\.jpg', p)
        assert len(dst) == 1, f"expected exactly one 'dst=<id>.jpg' in {p}"
        sj_dict[src[0]] = p

    return img_path, sj_dict

# Frames per sample each run was generated with (progress() expects 2 * n + 1 files per sample dir).
n_frames_shadow, n_frames_diffuse = 5, 3

print("[#] Adding Shadow...")
shadow_img, shadow_dict = get_img_path(
    set_=set_, mode='FFHQ_shadow_face', n_frames=n_frames_shadow, mothership=[False],
)
print("=" * 100)

print("[#] Adding Diffuse...")
diffuse, diffuse_dict = get_img_path(
    set_=set_, mode='FFHQ_diffuse_face', n_frames=n_frames_diffuse, mothership=[False],
)
print("=" * 100)

[#] Adding Shadow...
Available:  ['shadow']
/data/mint/DPM_Dataset/Soften_Strengthen_Shadows/TPAMI/FFHQ_shadow_face/log=difareli_canny=153to204bg_256_vll_cfg=difareli_canny=153to204bg_256_vll.yaml_tomax_steps=50/ema_085000/valid//shadow/reverse_sampling/
[#] Done: shadow => 10000/10000 => 100.00%
[#] Empty: 0/10000 => 0.00%
[#] Empty: []
[#] Adding Diffuse...
Available:  ['shadow']
/data/mint/DPM_Dataset/Soften_Strengthen_Shadows/TPAMI/FFHQ_diffuse_face/log=Masked_Face_woclip+BgNoHead+shadow_256_cfg=Masked_Face_woclip+BgNoHead+shadow_256.yaml_tomin_steps=50/ema_085000/valid//shadow/reverse_sampling/
[#] Done: shadow => 10000/10000 => 100.00%
[#] Empty: 0/10000 => 0.00%
[#] Empty: []


In [3]:
# Pair up source ids present in BOTH runs, keeping shadow_dict's ordering.
# NOTE(review): each run reported 10000/10000 done, yet the recorded output of this
# cell is 8456 shared ids — worth checking why the src sets do not fully overlap.
out_dict = {
    src_id: {'shadow': shadow_dir, 'diffuse': diffuse_dict[src_id]}
    for src_id, shadow_dir in shadow_dict.items()
    if src_id in diffuse_dict
}
print(len(out_dict))

8456
