In [1]:
import sys
from pathlib import Path
from tqdm import tqdm
import random
from shutil import copy
import pandas as pd

sys.path.append(str(Path("..").resolve()))
from src import *

# Prepare dataset for ControlNet training

##### ℹ️ This notebook requires the triplets `uv`, `caption` and `diffuse` for each dataset considered

In this notebook we move all the generated triplets to a single folder, compliant with `ImageFolder` dataset. See more at [Create an image dataset](https://huggingface.co/docs/datasets/image_dataset#imagefolder)

In [9]:
MAX_DATASET_SIZE = 10
TEST_SET_RATIO = 0.1
OUTPUT_PATH = Path("../dataset")

In [3]:
datasets: list[Dataset3D] = [ObjaverseDataset3D()]

In [10]:
train_path = OUTPUT_PATH / "train"
test_path = OUTPUT_PATH / "test"
for path in [train_path, test_path]:
    for folder in ["diffuse", "uv"]:
        (path / folder).mkdir(parents=True, exist_ok=True)

In [11]:
metadata = pd.DataFrame(columns=["uv", "diffuse", "caption"])
for dataset in datasets:
    valid_uids = dataset.statistics["valid"].index
    avail_uids = dataset.triplets
    uids = list(avail_uids.intersection(valid_uids))
    if MAX_DATASET_SIZE:
        uids = uids[:MAX_DATASET_SIZE]
    train_uids = random.choices(list(uids), k=int(len(uids) * 0.9))
    cprint(f"yellow:{dataset.__class__.__name__}", "has", len(avail_uids), "uids,", len(uids), "of them are valid.")
    uv_paths = {x.stem: x for x in (dataset.DATASET_PATH / "uv").glob("*") if x.suffix in dataset.IMG_EXT}
    diffuse_paths = {x.stem: x for x in (dataset.DATASET_PATH / "diffuse").glob("*") if x.suffix in dataset.IMG_EXT}
    captions = dataset.captions

    for uid in tqdm(uids):
        path = train_path if uid in train_uids else test_path
        copy(uv_paths[uid], path / "uv")
        copy(diffuse_paths[uid], path / "diffuse")
        metadata.loc[-1] = [
            uv_paths[uid].name,
            diffuse_paths[uid].name,
            captions[uid],
        ]
        metadata.index += 1
metadata.to_csv(OUTPUT_PATH / "metadata.csv", index=False)

[1m[33mObjaverseDataset3D[0m has [1m[34m13,700[0m uids, [1m[34m10[0m of them are valid.


100%|██████████| 10/10 [00:00<00:00, 180.02it/s]


In [29]:
size = sum(f.stat().st_size for f in train_path.glob("**/*") if f.is_file()) / (1 - TEST_SET_RATIO)
cprint("Dataset size:", size // 1024**2, "green:MiB")

Dataset size: [1m[34m22.0[0m [1m[32mMiB[0m
