In [1]:
import objaverse
import multiprocessing


In [2]:
import glob
import gzip
import json
import multiprocessing
import os
import urllib.request
import warnings
from typing import Any, Dict, List, Optional, Tuple

from tqdm import tqdm

In [3]:
range_one, range_two = 5000, 10000
uids = objaverse.load_uids()

annotations = objaverse.load_annotations(uids[range_one:range_two])


In [4]:
processes = multiprocessing.cpu_count()

In [24]:
processes

20

In [35]:
set([annotations[id]["license"] for id in uids[range_one:range_two]])

{'by', 'by-nc', 'by-nc-sa', 'by-sa', 'cc0'}

In [14]:
# annotation["licensee"]

# [uid for uid, annotation in annotations.items() if annotation["license"] == "by"]

In [5]:
filtered_uids = [uid for uid, annotation in annotations.items() if annotation["license"] in ["by","by-sa","cc0"]]

In [37]:
len(filtered_uids)

4512

In [29]:
objects
# '/home/jiwi/.objaverse/hf-objaverse-v1/glbs/000-023/8476c4170df24cf5bbe6967222d1a42d.glb'

{'8476c4170df24cf5bbe6967222d1a42d': '/home/jiwi/.objaverse/hf-objaverse-v1/glbs/000-023/8476c4170df24cf5bbe6967222d1a42d.glb',
 '8ff7f1f2465347cd8b80c9b206c2781e': '/home/jiwi/.objaverse/hf-objaverse-v1/glbs/000-023/8ff7f1f2465347cd8b80c9b206c2781e.glb',
 'c786b97d08b94d02a1fa3b87d2e86cf1': '/home/jiwi/.objaverse/hf-objaverse-v1/glbs/000-023/c786b97d08b94d02a1fa3b87d2e86cf1.glb',
 'be2c02614d774f9da672dfdc44015219': '/home/jiwi/.objaverse/hf-objaverse-v1/glbs/000-023/be2c02614d774f9da672dfdc44015219.glb'}

In [6]:
BASE_PATH = os.path.join("/home/jiwi/Documents", "objaverse")

__version__ = "0.1.7"
_VERSIONED_PATH = os.path.join(BASE_PATH, "hf-objaverse-v1")

In [7]:
def _load_object_paths() -> Dict[str, str]:
    """Load the object paths from the dataset.

    The object paths specify the location of where the object is located
    in the Hugging Face repo.

    Returns:
        A dictionary mapping the uid to the object path.
    """
    object_paths_file = "object-paths.json.gz"
    local_path = os.path.join(_VERSIONED_PATH, object_paths_file)
    if not os.path.exists(local_path):
        hf_url = f"https://huggingface.co/datasets/allenai/objaverse/resolve/main/{object_paths_file}"
        # wget the file and put it in local_path
        os.makedirs(os.path.dirname(local_path), exist_ok=True)
        urllib.request.urlretrieve(hf_url, local_path)
    with gzip.open(local_path, "rb") as f:
        object_paths = json.load(f)
    return object_paths


def load_uids() -> List[str]:
    """Load the uids from the dataset.

    Returns:
        A list of uids.
    """
    return list(_load_object_paths().keys())


def _download_object(
    uid: str,
    object_path: str,
    total_downloads: float,
    start_file_count: int,
) -> Tuple[str, str]:
    """Download the object for the given uid.

    Args:
        uid: The uid of the object to load.
        object_path: The path to the object in the Hugging Face repo.

    Returns:
        The local path of where the object was downloaded.
    """
    # print(f"downloading {uid}")
    local_path = os.path.join(_VERSIONED_PATH, object_path)
    tmp_local_path = os.path.join(_VERSIONED_PATH, object_path + ".tmp")
    hf_url = (
        f"https://huggingface.co/datasets/allenai/objaverse/resolve/main/{object_path}"
    )
    # wget the file and put it in local_path
    os.makedirs(os.path.dirname(tmp_local_path), exist_ok=True)
    urllib.request.urlretrieve(hf_url, tmp_local_path)

    os.rename(tmp_local_path, local_path)

    files = glob.glob(os.path.join(_VERSIONED_PATH, "glbs", "*", "*.glb"))
    print(
        "Downloaded",
        len(files) - start_file_count,
        "/",
        total_downloads,
        "objects",
    )

    return uid, local_path


def load_objects(uids: List[str], download_processes: int = 1) -> Dict[str, str]:
    """Return the path to the object files for the given uids.

    If the object is not already downloaded, it will be downloaded.

    Args:
        uids: A list of uids.
        download_processes: The number of processes to use to download the objects.

    Returns:
        A dictionary mapping the object uid to the local path of where the object
        downloaded.
    """
    object_paths = _load_object_paths()
    out = {}
    if download_processes == 1:
        uids_to_download = []
        for uid in uids:
            if uid.endswith(".glb"):
                uid = uid[:-4]
            if uid not in object_paths:
                warnings.warn(f"Could not find object with uid {uid}. Skipping it.")
                continue
            object_path = object_paths[uid]
            local_path = os.path.join(_VERSIONED_PATH, object_path)
            if os.path.exists(local_path):
                out[uid] = local_path
                continue
            uids_to_download.append((uid, object_path))
        if len(uids_to_download) == 0:
            return out
        start_file_count = len(
            glob.glob(os.path.join(_VERSIONED_PATH, "glbs", "*", "*.glb"))
        )
        for uid, object_path in uids_to_download:
            uid, local_path = _download_object(
                uid, object_path, len(uids_to_download), start_file_count
            )
            out[uid] = local_path
    else:
        args = []
        for uid in uids:
            if uid.endswith(".glb"):
                uid = uid[:-4]
            if uid not in object_paths:
                warnings.warn(f"Could not find object with uid {uid}. Skipping it.")
                continue
            object_path = object_paths[uid]
            local_path = os.path.join(_VERSIONED_PATH, object_path)
            if not os.path.exists(local_path):
                args.append((uid, object_paths[uid]))
            else:
                out[uid] = local_path
        if len(args) == 0:
            return out
        print(
            f"starting download of {len(args)} objects with {download_processes} processes"
        )
        start_file_count = len(
            glob.glob(os.path.join(_VERSIONED_PATH, "glbs", "*", "*.glb"))
        )
        args_list = [(*arg, len(args), start_file_count) for arg in args]
        with multiprocessing.Pool(download_processes) as pool:
            r = pool.starmap(_download_object, args_list)
            for uid, local_path in r:
                out[uid] = local_path
    return out

In [8]:
objects = load_objects(
    uids=filtered_uids,
    download_processes=processes
)

Downloaded 263 / 4539 objects
Downloaded 264 / 4539 objects
Downloaded 265 / 4539 objects
Downloaded 266 / 4539 objects
Downloaded 267 / 4539 objects
Downloaded 268Downloaded /  2694539  /objects 
4539 objects
Downloaded 270 / 4539 objects
Downloaded 271 / 4539 objects
Downloaded 272 / 4539 objects
Downloaded 273 / 4539 objects
Downloaded 274 / 4539 objects
Downloaded 275 / 4539 objects
Downloaded 276 / 4539 objects
Downloaded 277 / 4539 objects
Downloaded 278 / 4539 objects
Downloaded 279 / 4539 objects
Downloaded 280 / 4539 objects
Downloaded 281 / 4539 objects
Downloaded 282 / 4539 objects
Downloaded 283 / 4539 objects
Downloaded 284 / 4539 objects
Downloaded 285 / 4539 objects
Downloaded 286 / 4539 objects
Downloaded 287 / 4539 objects
Downloaded 288 / 4539 objects
Downloaded 289 / 4539 objects
Downloaded 290 / 4539 objects
Downloaded 291 / 4539 objects
Downloaded 292 / 4539 objects
Downloaded 293 / 4539 objects
Downloaded 294 / Downloaded4539  295objects 
/ 4539 objects
Downloaded