In [104]:
import os, sys, io
import pandas as pd
from typing import List, Dict, Tuple, Union, Any, Optional
from datasets import Dataset, load_dataset, Image

DALLE3_PROMPT_BLIP2_FLAN = '/share/imagereward_work/prompt_reconstruction/data/blip2_flan.csv'
from config import *

In [26]:
ds = load_dataset("/share/img_datasets/pickapic_v2")

Resolving data files:   0%|          | 0/645 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/623 [00:00<?, ?it/s]

In [31]:
print(ds.keys())
for key1 in ds.keys():
    print(key1)
    print(ds[key1].column_names)
    print(type(ds[key1]))
    for data in ds[key1]:
        for column in ds[key1].column_names:
            print(column, data[column])
        break
    break

dict_keys(['train', 'validation', 'test', 'test_unique', 'validation_unique'])
train
['are_different', 'best_image_uid', 'caption', 'created_at', 'has_label', 'image_0_uid', 'image_0_url', 'image_1_uid', 'image_1_url', 'jpg_0', 'jpg_1', 'label_0', 'label_1', 'model_0', 'model_1', 'ranking_id', 'user_id', 'num_example_per_prompt', '__index_level_0__']
<class 'datasets.arrow_dataset.Dataset'>
are_different True
best_image_uid 751f5aba-c6da-4381-ac7e-cb2b51004581
caption stunningly beautiful space zombie, insanely detailed, photorealistic, 8k, created with midjourney
created_at 2023-03-25 18:37:24.454412
has_label True
image_0_uid 751f5aba-c6da-4381-ac7e-cb2b51004581
image_0_url https://text-to-image-human-preferences.s3.us-east-2.amazonaws.com/images/751f5aba-c6da-4381-ac7e-cb2b51004581.png
image_1_uid 6b915748-8bf9-4991-884d-9876fe3e2bd6
image_1_url https://text-to-image-human-preferences.s3.us-east-2.amazonaws.com/images/6b915748-8bf9-4991-884d-9876fe3e2bd6.png
jpg_0 b'\xff\xd8\xff\xe0

In [7]:
original_images_dir = "/share/home/wusiyuan/imagereward_work/prompt_reconstruction/data/tmp"
original_images = os.listdir(original_images_dir)
original_images = [os.path.join(original_images_dir, image) for image in original_images]

modified_images_dir = "/share/home/wusiyuan/imagereward_work/prompt_reconstruction/data/tmp"
modified_images = os.listdir(modified_images_dir)
modified_images = [os.path.join(modified_images_dir, image) for image in modified_images]

print(f"Total number of images: {len(original_images)}")
print(f"First 5 images: {original_images[:5]}")

# load the dataset
dataset = Dataset.from_dict({"jpg_0": original_images, "jpg_1": modified_images, "label_0": [1]*len(original_images), "caption": [""]*len(original_images)})
dataset = dataset.cast_column("jpg_0", Image()).cast_column("jpg_1", Image())

print(f"Total number of images in the dataset: {len(dataset)}")
print(f"First 5 images in the dataset: {dataset[:5]}")

print("first image in the dataset:")
print(dataset[0]["jpg_0"])
print(dataset[0]["jpg_1"])
print(dataset[0]["label_0"])
print(dataset[0]["caption"])

Total number of images: 9
First 5 images: ['/share/home/wusiyuan/imagereward_work/prompt_reconstruction/data/tmp/test-0001.png', '/share/home/wusiyuan/imagereward_work/prompt_reconstruction/data/tmp/test-0000.png', '/share/home/wusiyuan/imagereward_work/prompt_reconstruction/data/tmp/train-0030.png', '/share/home/wusiyuan/imagereward_work/prompt_reconstruction/data/tmp/test-0002.png', '/share/home/wusiyuan/imagereward_work/prompt_reconstruction/data/tmp/train-0032.png']
Total number of images in the dataset: 9
First 5 images in the dataset: {'jpg_0': [<PIL.PngImagePlugin.PngImageFile image mode=RGB size=1024x1024 at 0x7F9514774A00>, <PIL.PngImagePlugin.PngImageFile image mode=RGB size=1024x1024 at 0x7F95147AD480>, <PIL.PngImagePlugin.PngImageFile image mode=RGB size=1024x1024 at 0x7F95147AEA10>, <PIL.PngImagePlugin.PngImageFile image mode=RGB size=1024x1024 at 0x7F9514774FA0>, <PIL.PngImagePlugin.PngImageFile image mode=RGB size=1024x1024 at 0x7F957C5C78E0>], 'jpg_1': [<PIL.PngImagePlu

In [112]:
def align_file(file_list_a: List[str], file_list_b: List[str]) -> Tuple[List[str], List[str]]:
    """
    Align file list a and file list b.
    """
    root_a = os.path.dirname(file_list_a[0])
    root_b = os.path.dirname(file_list_b[0])

    file_list_a = sorted(file_list_a)
    file_list_b = sorted(file_list_b)

    file_list_a = [os.path.basename(file) for file in file_list_a]
    file_list_b = [os.path.basename(file) for file in file_list_b]

    file_list_a = [file for file in file_list_a if file in file_list_b]
    file_list_b = [file for file in file_list_b if file in file_list_a]

    file_list_a = [os.path.join(root_a, file) for file in file_list_a]
    file_list_b = [os.path.join(root_b, file) for file in file_list_b]

    return file_list_a, file_list_b

def load(
    original_images_dir: str = ORIGINAL_IMAGES_DIR,
    modified_images_root_dir: str = GENERATED_IMAGES_DIR,
    caption_csv_file: str = DALLE3_PROMPT_BLIP2_FLAN,
    modified_images_subdir: List[str] = [0,1,2,3,4,5,6],
    splits: List[str] = ["train", "validation", "test"],
):
    orig_imgs = os.listdir(original_images_dir)
    orig_imgs = [os.path.join(original_images_dir, img) for img in orig_imgs]
    orig_imgs = [img for img in orig_imgs if img.endswith(".jpg")]

    mod_imgs_dict = {}
    for subdir in modified_images_subdir:
        mod_imgs = os.listdir(os.path.join(modified_images_root_dir, str(subdir)))
        mod_imgs = [os.path.join(modified_images_root_dir, str(subdir), img) for img in mod_imgs]
        mod_imgs_dict[subdir] = mod_imgs
        mod_imgs_dict[subdir] = [img for img in mod_imgs_dict[subdir] if img.endswith(".jpg")]

    # read the caption csv file, 'image' col as key, 'prompt' col as value
    prompts = pd.read_csv(caption_csv_file)
    caption_dict = dict(zip(prompts["image"], prompts["prompt"]))
    caption_dict = {key.replace(".png", ".jpg"): value for key, value in caption_dict.items()}

    orig_imgs_train = []
    mod_imgs_train = []
    prompt_train = []
    orig_imgs_val = []
    mod_imgs_val = []
    prompt_val = []
    orig_imgs_test = []
    mod_imgs_test = []
    prompt_test = []

    for subdir in mod_imgs_dict:
        orig_imgs_now, mod_imgs_dict[subdir] = align_file(orig_imgs, mod_imgs_dict[subdir])
        assert len(orig_imgs_now) == len(mod_imgs_dict[subdir])

        # separate the dataset into train and val and test using the file name
        orig_imgs_now.sort()
        mod_imgs_dict[subdir].sort()
        for orig_img, mod_img in zip(orig_imgs_now, mod_imgs_dict[subdir]):
            assert os.path.basename(orig_img) == os.path.basename(mod_img)
            assert os.path.basename(orig_img) in caption_dict
            if "val" in orig_img:
                orig_imgs_val.append(orig_img)
                mod_imgs_val.append(mod_img)
                prompt_val.append(caption_dict[os.path.basename(orig_img)])
            elif "test" in orig_img:
                orig_imgs_test.append(orig_img)
                mod_imgs_test.append(mod_img)
                prompt_test.append(caption_dict[os.path.basename(orig_img)])
            else:
                orig_imgs_train.append(orig_img)
                mod_imgs_train.append(mod_img)
                prompt_train.append(caption_dict[os.path.basename(orig_img)])
        
    # load dataset from the file list
    # the dataset has keys ['train', 'validation', 'test']
    # and for each key, it has columns ['jpg_0', 'jpg_1', 'label_0', 'caption']
    dataset_train = None
    dataset_val = None
    dataset_test = None
    
    if 'train' in splits:
        assert len(orig_imgs_train) == len(mod_imgs_train) == len(prompt_train)
        dataset_train = Dataset.from_dict({
            "jpg_0": orig_imgs_train,
            "jpg_1": mod_imgs_train,
            "label_0": [1.0]*len(orig_imgs_train),
            "caption": prompt_train,
        })
        dataset_train = dataset_train.cast_column("jpg_0", Image()).cast_column("jpg_1", Image())
    
    if 'validation' in splits:
        assert len(orig_imgs_val) == len(mod_imgs_val) == len(prompt_val)
        dataset_val = Dataset.from_dict({
            "jpg_0": orig_imgs_val,
            "jpg_1": mod_imgs_val,
            "label_0": [1.0]*len(orig_imgs_val),
            "caption": prompt_val,
        })
        dataset_val = dataset_val.cast_column("jpg_0", Image()).cast_column("jpg_1", Image())

    if 'test' in splits:
        assert len(orig_imgs_test) == len(mod_imgs_test) == len(prompt_test)
        dataset_test = Dataset.from_dict({
            "jpg_0": orig_imgs_test,
            "jpg_1": mod_imgs_test,
            "label_0": [1.0]*len(orig_imgs_test),
            "caption": prompt_test,
        })
        dataset_test = dataset_test.cast_column("jpg_0", Image()).cast_column("jpg_1", Image())
        
    dataset = {"train": dataset_train, "validation": dataset_val, "test": dataset_test}
    return dataset
    

In [113]:
# dataset = load(modified_images_subdir=[0], splits=["test"])
dataset = load()

until subdir 0 train 4784 val 590 test 591
until subdir 1 train 9553 val 1177 test 1180
until subdir 2 train 14331 val 1767 test 1771
until subdir 3 train 19113 val 2356 test 2362
until subdir 4 train 23883 val 2942 test 2950
until subdir 5 train 28636 val 3527 test 3538
until subdir 6 train 33360 val 4112 test 4127


In [116]:
print(dataset.keys())
for key1 in dataset.keys():
    print(key1)
    print(dataset[key1].column_names)
    for data in dataset[key1]:
        for column in dataset[key1].column_names:
            print(column, data[column])
        break
    print()

dict_keys(['train', 'validation', 'test'])
train
['jpg_0', 'jpg_1', 'label_0', 'caption']
jpg_0 <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1024x1024 at 0x7F942E715060>
jpg_1 <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1024x1024 at 0x7F94362E7250>
label_0 1.0
caption ambulances driving on a city street at night, high quality fantasy stock photo, anamorphic illustration, unsplash, streaming on twitch, high octane cybernetics, high speed chase, rendered in lumion, inspired by Eugene Tertychnyi, 2030s, art for the game

validation
['jpg_0', 'jpg_1', 'label_0', 'caption']
jpg_0 <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1024x1024 at 0x7F94362E6C50>
jpg_1 <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1024x1024 at 0x7F94362E6FB0>
label_0 1.0
caption a book, ink pen, and candle on a table, rendered in cryengine, stunning 3d render, medieval photograph, pexels, flat matte painting, prelude to the esoteric, writing a letter, toon render keyshot, of a o

In [117]:
train_dataset = dataset["train"]
print(type(train_dataset))
for data in train_dataset:
    print(data["jpg_0"])
    print(data["jpg_1"])
    print(data["label_0"])
    print(data["caption"])
    break

<class 'datasets.arrow_dataset.Dataset'>
<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1024x1024 at 0x7F94365C6B00>
<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1024x1024 at 0x7F94363415A0>
1.0
ambulances driving on a city street at night, high quality fantasy stock photo, anamorphic illustration, unsplash, streaming on twitch, high octane cybernetics, high speed chase, rendered in lumion, inspired by Eugene Tertychnyi, 2030s, art for the game


In [118]:
print(train_dataset)

Dataset({
    features: ['jpg_0', 'jpg_1', 'label_0', 'caption'],
    num_rows: 33360
})
