### Download the full dataset
Source: https://www.pinlandata.com/rp2k_dataset

In [None]:
# Download the archive into ./downloads (-c resumes a partial download,
# -x/-s use up to 10 connections/segments).
# NOTE(review): --check-certificate=false turns off TLS certificate verification
# for this download; drop the flag if the host's certificate validates for you.
!aria2c --file-allocation=none -c -x 10 -s 10 \
    -d "downloads" \
    https://blob-nips2020-rp2k-dataset.obs.cn-east-3.myhuaweicloud.com/rp2k_dataset.zip \
    --check-certificate=false

# Extract once and rename the archive's "test" split to "val".
# The whole step is guarded on the renamed folder so that re-running the cell
# neither unzips over existing files nor moves a fresh "test" folder *inside*
# an already existing "val" folder.
!test -d downloads/all/val || (cd downloads && unzip -q rp2k_dataset.zip && mv all/test all/val)

### Make a subset of the dataset for balanced and faster training

In [54]:
# Imported here as well because this cell runs before the notebook's main
# import cell; without them a fresh "Restart & Run All" fails on this cell.
import glob
import os
import shutil

SUBSET_IMAGES_PER_CLASS = 10  # only select 10 images per class


def make_subset(source_root, output_root, splits=("train", "val"), per_class=SUBSET_IMAGES_PER_CLASS):
    """Copy at most `per_class` images of every class into a mirrored folder tree.

    Reads  `<source_root>/<split>/<class>/<image>` and writes
    `<output_root>/<split>/<class>/<image>` for each split in `splits`.

    Class folders and images are visited in sorted order, so the same images
    are selected on every run (plain `glob` order is not guaranteed).
    Class folders without any entries are not created in the output tree.

    Returns the number of files copied.
    """
    copied = 0
    for split in splits:
        for class_dir in sorted(glob.glob(os.path.join(source_root, split, "*"))):
            target_dir = os.path.join(output_root, split, os.path.basename(class_dir))
            selected = sorted(glob.glob(os.path.join(class_dir, "*")))[:per_class]
            for image_path in selected:
                os.makedirs(target_dir, exist_ok=True)
                shutil.copy(image_path, os.path.join(target_dir, os.path.basename(image_path)))
                copied += 1
    return copied


output_dir = "./downloads/subset"
os.makedirs(output_dir, exist_ok=True)
subset_image_count = make_subset("./downloads/all", output_dir)

### Dataset preparation

In [60]:
import os
import glob
import shutil
import cv2
import numpy as np
import pandas as pd
from tqdm import tqdm
import hashlib
import torch
from dreamsim import dreamsim
from PIL import Image
import torch.nn.functional as F
from shared import LabelEncoder, resize_and_pad_image_cv2

In [67]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing when the model is moved to the device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# dreamsim() returns the model and its matching preprocessing callable.
# Checkpoints are stored under ./models (downloaded on first use, per the cell output).
model, dreamsim_preprocess = dreamsim(cache_dir="./models", dreamsim_type="dino_vitb16", device=device)

Downloading checkpoint


100%|██████████| 699M/699M [00:24<00:00, 29.4MB/s] 


Unzipping...


Using cache found in ./models/facebookresearch_dino_main
  WeightNorm.apply(module, name, dim)


In [62]:
# glob.glob and os.listdir return entries in arbitrary, filesystem-dependent
# order. Sorting both makes the processing order reproducible across runs and
# machines (and the class ids too, presumably assigned by LabelEncoder from the
# position in this list -- TODO confirm in shared.py).
images = sorted(glob.glob("./downloads/subset/*/*/*"))
le = LabelEncoder(sorted(os.listdir("./downloads/subset/train")))

In [None]:
output_dir = "./dataset/rp2k_nights_224"
# Start from an empty folder so re-running the cell never mixes in stale files.
shutil.rmtree(output_dir, ignore_errors=True)


def _class_label(image_path):
    """Return the name of the image's parent folder, which is its class label."""
    return os.path.basename(os.path.dirname(image_path))


def _load_image(image_path):
    """Load an image as an RGB array and run it through resize_and_pad_image_cv2.

    The RGBA -> RGB double conversion is kept from the original code
    (presumably to normalise palette/alpha images -- confirm before changing).
    """
    with Image.open(image_path) as pil_image:
        rgb = np.array(pil_image.convert("RGBA").convert("RGB"))
    return resize_and_pad_image_cv2(rgb)


def _embed(image):
    """Return the DreamSim embedding of an RGB array, computed without gradients."""
    with torch.no_grad():
        return model.embed(dreamsim_preprocess(Image.fromarray(image)).to(device))


_feature_cache = {}  # image path -> (md5 of the processed pixels, embedding)


def _features(image_path):
    """Return `(digest, embedding)` for an image, computing them at most once.

    Only the digest and the embedding are cached, not the pixel data, so each
    image is embedded a single time instead of once per reference image.
    """
    if image_path not in _feature_cache:
        image = _load_image(image_path)
        _feature_cache[image_path] = (hashlib.md5(image).hexdigest(), _embed(image))
    return _feature_cache[image_path]


def _save_image(image_path, relative_path):
    """Write the processed image to `<output_dir>/<relative_path>` as BGR for OpenCV.

    Raises OSError when cv2.imwrite reports failure, so that no CSV row is
    emitted for an image file that does not exist.
    """
    target_path = f"{output_dir}/{relative_path}"
    os.makedirs(os.path.dirname(target_path), exist_ok=True)
    if not cv2.imwrite(target_path, cv2.cvtColor(_load_image(image_path), cv2.COLOR_RGB2BGR)):
        raise OSError(f"cv2.imwrite failed for {target_path}")


# Index the images by their exact class folder once. A substring test on the
# path would also pick up other classes whose name merely contains this label.
# NOTE(review): as before, candidates come from every split of the class
# (train and val together); confirm that cross-split pairs are intended.
paths_by_class = {}
for image_path in images:
    paths_by_class.setdefault(_class_label(image_path), []).append(image_path)

rows = []
for ref_image_path in tqdm(images):
    try:
        ref_class_label = _class_label(ref_image_path)
        ref_class_id = le.class2id[ref_class_label]
        split = os.path.basename(os.path.dirname(os.path.dirname(ref_image_path)))

        ref_digest, ref_embedding = _features(ref_image_path)

        candidates = []
        for candidate_path in paths_by_class[ref_class_label]:
            # Exclude the reference itself by path; comparing the float
            # similarity against exactly 1.0 is unreliable.
            if candidate_path == ref_image_path:
                continue
            candidate_digest, candidate_embedding = _features(candidate_path)
            # Pixel-identical copies of the reference carry no ranking signal.
            if candidate_digest == ref_digest:
                continue
            similarity = F.cosine_similarity(ref_embedding, candidate_embedding).item()
            candidates.append((similarity, candidate_path, candidate_digest))

        # Two distinct candidates are needed: the most and the least similar one.
        if len(candidates) <= 1:
            continue

        candidates.sort(key=lambda candidate: candidate[0], reverse=True)
        _, right_source, right_digest = candidates[0]   # most similar to the reference
        _, left_source, left_digest = candidates[-1]    # least similar to the reference

        ref_path = f"ref/{ref_class_id}/{ref_class_id}_{ref_digest}.jpg"
        left_path = f"distort/{ref_class_id}/{ref_class_id}_{left_digest}.jpg"
        right_path = f"distort/{ref_class_id}/{ref_class_id}_{right_digest}.jpg"

        _save_image(ref_image_path, ref_path)
        _save_image(left_source, left_path)
        _save_image(right_source, right_path)

        rows.append({
            "id": ref_class_id,
            # The right image is the more similar one, so it receives the vote.
            "left_vote": 0,
            "right_vote": 1,
            "votes": 8,
            "ref_path": ref_path,
            "left_path": left_path,
            "right_path": right_path,
            "split": split,
            "is_imagenet": "FALSE",
            "prompt": "product"
        })
    except Exception as error:
        # Best effort: skip images that fail, but say which one and why.
        # tqdm.write keeps the progress bar intact.
        tqdm.write(f"skipped {ref_image_path}: {error!r}")

In [None]:
# Write the triplet table next to the exported images.
# The folder is created here as well: the generation cell deletes it up front
# and only recreates it when at least one triplet is saved, so with no rows
# to_csv would otherwise fail on the missing directory.
csv_path = "./dataset/rp2k_nights_224/data.csv"
os.makedirs(os.path.dirname(csv_path), exist_ok=True)
df = pd.DataFrame(rows)
df.to_csv(csv_path, index=False)
df