In [3]:
import numpy as np
import torch as th
import torchvision
import PIL
import cv2
import imageio
import blobfile as bf
import os
import tqdm
import json
from pytorch_lightning.utilities.seed import seed_everything

def _list_image_files_recursively(data_dir):
    """Recursively collect image/array file paths under ``data_dir``.

    Entries of each directory are visited in sorted order, so the result is
    deterministic for a given tree. An entry is kept when its name contains a
    dot and its extension (case-insensitive) is one of jpg, jpeg, png, gif or
    npy. Any other entry for which ``bf.isdir`` is true is descended into.
    Returned paths are built with ``bf.join`` starting from ``data_dir``.
    """
    allowed = ["jpg", "jpeg", "png", "gif", "npy"]
    found = []
    for name in sorted(bf.listdir(data_dir)):
        candidate = bf.join(data_dir, name)
        is_image = "." in name and name.split(".")[-1].lower() in allowed
        if is_image:
            found.append(candidate)
            continue
        # Only non-matching entries are checked for being a directory.
        if bf.isdir(candidate):
            found.extend(_list_image_files_recursively(candidate))
    return found


# Random source/destination pairs

In [6]:
# Reproducibly draw `n_pairs` disjoint (src, dst) image pairs from the
# validation folder and save their file names to gen_pair_<seed>.json.
seed = 47
seed_everything(seed)  # the np.random.choice draw below relies on this seeding NumPy's global RNG
path = '/data/mint/DPM_Dataset/ffhq_256_with_anno/ffhq_256/valid/'
img_path = _list_image_files_recursively(path)
n_pairs = 20
# 2 * n_pairs distinct indices (replace=False); entries 2k and 2k+1 form pair k+1.
pairs = np.random.choice(a=np.arange(0, len(img_path)), size=n_pairs*2, replace=False)

gen_pairs = {"seed": seed, "pair": {}}
for idx in range(n_pairs):
    i = 2 * idx  # position of this pair's src index inside `pairs`
    src_name, dst_name = (img_path[pairs[j]].split('/')[-1] for j in (i, i + 1))
    gen_pairs["pair"][f"pair_{idx+1}"] = {"src": src_name, "dst": dst_name}

with open(f'gen_pair_{seed}.json', 'w') as fp:
    json.dump(gen_pairs, fp, indent=4)

Global seed set to 47


# Manually selected pairs to JSON

In [27]:
# Build manual_pair.json from two folders whose images are matched by position
# in their (sorted) listings: the i-th src image pairs with the i-th dst image.
# Forward pairs (src -> dst) are numbered pair_1..pair_N and also recorded under
# "pairwise"; the reversed pairs (dst -> src) follow as pair_{N+1}..pair_{2N}.
src_path = './light_right/'
dst_path = './light_left/'
src_img = _list_image_files_recursively(src_path)
dst_img = _list_image_files_recursively(dst_path)
assert len(src_img) == len(dst_img), (
    f"{src_path} has {len(src_img)} images but {dst_path} has {len(dst_img)}"
)
sel_pairs = {"src_path":src_path, "dst_path":dst_path, "pair":{}, "pairwise":{"src":[], "dst":[]}}

# (src basename, dst basename) for every positional pair.
names = [(s.split('/')[-1], d.split('/')[-1]) for s, d in zip(src_img, dst_img)]

for idx, (src_name, dst_name) in enumerate(names):
    sel_pairs["pair"][f"pair_{idx+1}"] = {"src": src_name, "dst": dst_name}
    sel_pairs["pairwise"]["src"].append(src_name)
    sel_pairs["pairwise"]["dst"].append(dst_name)

# Reverse direction; numbering continues after the forward pairs. Derived from
# len() rather than the leaked loop variable, which would be undefined (or a
# stale value from an earlier cell) when the folders are empty.
cur = len(names)
for idx, (src_name, dst_name) in enumerate(names):
    sel_pairs["pair"][f"pair_{cur+idx+1}"] = {"src": dst_name, "dst": src_name}

with open('manual_pair.json', 'w') as fp:
    json.dump(sel_pairs, fp, indent=4)