In [2]:
!pip install requests tqdm trimesh thingi10k numpy-stl numpy



In [None]:
import os
import requests
import zipfile
import shutil
import trimesh
import thingi10k
import numpy as np
from stl import mesh
from tqdm import tqdm
from urllib.parse import urlencode
import warnings

!pip install scipy



In [None]:
# datasets_manager_fixed.py
import os
import sys
import zipfile
import shutil
import requests
import tempfile
import logging
import time
import platform
import urllib.request
from urllib.parse import urlencode
from concurrent.futures import (
    ThreadPoolExecutor,
    ProcessPoolExecutor,
    as_completed,
    wait,
    FIRST_COMPLETED,
)
from tqdm.notebook import tqdm
import multiprocessing

# mesh libs
import trimesh
import numpy as np
from stl import mesh
import thingi10k
import scipy.io

# Module-wide logging setup: emit INFO and above, each line prefixed with its level name.
logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")


# -----------------------
# Worker functions (module-level for pickling)
# -----------------------
def _download_stream(url, filename, timeout=30, max_retries=3, chunk_size=8192):
    """
    Download via requests for http/https; fallback to urllib for ftp.
    Returns filename on success or raises.
    """
    if url.startswith("ftp://"):
        # use urllib.request.urlretrieve for ftp
        try:
            urllib.request.urlretrieve(url, filename)
            return filename
        except Exception as e:
            if os.path.exists(filename):
                try:
                    os.remove(filename)
                except Exception:
                    pass
            raise RuntimeError(f"FTP download failed: {url} -> {e}")
    # else HTTP(S)
    session = requests.Session()
    adapter = requests.adapters.HTTPAdapter(max_retries=max_retries)
    session.mount("https://", adapter)
    session.mount("http://", adapter)
    try:
        with session.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            total = int(r.headers.get("content-length", 0) or 0)
            with (
                open(filename, "wb") as f,
                tqdm(
                    total=total,
                    unit="B",
                    unit_scale=True,
                    leave=False,
                    desc=os.path.basename(filename),
                ) as bar,
            ):
                for chunk in r.iter_content(chunk_size=chunk_size):
                    if chunk:
                        f.write(chunk)
                        bar.update(len(chunk))
        return filename
    except Exception as e:
        if os.path.exists(filename):
            try:
                os.remove(filename)
            except Exception:
                pass
        raise RuntimeError(f"Download failed: {url} -> {e}")


def _convert_to_stl_worker(src_path, dst_path):
    """
    Convert a single mesh file to STL using trimesh.
    Returns (dst_path, None) on success or (dst_path, error_str) on failure.
    """
    try:
        if os.path.exists(dst_path):
            return (dst_path, None)
        # trimesh can sometimes infer format; try load_mesh, and if fails try load
        mesh_obj = None
        try:
            mesh_obj = trimesh.load_mesh(src_path, force="mesh")
        except Exception:
            try:
                mesh_obj = trimesh.load(src_path, force="mesh")
            except Exception as e:
                return (dst_path, f"trimesh load failed: {e}")

        if mesh_obj is None or mesh_obj.is_empty:
            return (dst_path, "Empty or invalid mesh")
        os.makedirs(os.path.dirname(dst_path), exist_ok=True)
        mesh_obj.export(dst_path)
        return (dst_path, None)
    except Exception as e:
        return (dst_path, str(e))


def _thingi10k_np_to_stl_worker(npz_path, _id, dst_path):
    """
    Convert Thingi10k stored numpy arrays to STL.
    """
    try:
        if os.path.exists(dst_path):
            return (dst_path, None)
        with np.load(npz_path) as data:
            vertices = np.asarray(data["vertices"], dtype=np.float64)
            facets = np.asarray(data["facets"], dtype=np.int32)
        if vertices.shape[0] < 3 or facets.shape[0] == 0:
            return (dst_path, "Insufficient geometry")
        mesh_data = vertices[facets]
        stl_mesh = mesh.Mesh(np.zeros(mesh_data.shape[0], dtype=mesh.Mesh.dtype))
        stl_mesh.vectors = mesh_data
        os.makedirs(os.path.dirname(dst_path), exist_ok=True)
        stl_mesh.save(dst_path)
        return (dst_path, None)
    except Exception as e:
        return (dst_path, str(e))


# -----------------------
# Manager
# -----------------------
class DatasetsManager:
    def __init__(self, root_dir="data", max_workers_cpu=None, max_workers_io=None):
        self.root_dir = root_dir
        # Note: user requested final structure modelnet40, shapenet, etc.
        self.paths = {
            "thingi10k": os.path.join(root_dir, "thingi10k"),
            "modelnet": os.path.join(
                root_dir, "modelnet40"
            ),  # << specifically to match your request
            "abc": os.path.join(root_dir, "abc_dataset"),
            "objectnet": os.path.join(root_dir, "objectnet3d"),
            "shapenet": os.path.join(root_dir, "shapenet"),
            "custom": os.path.join(root_dir, "custom_dataset"),
        }
        os.makedirs(root_dir, exist_ok=True)

        cpu_count = max(1, (os.cpu_count() or 2) - 1)
        self.max_workers_cpu = max_workers_cpu or min(max(1, cpu_count), 12)
        self.max_workers_io = max_workers_io or min(16, (os.cpu_count() or 4) * 2)

        print(f"DatasetsManager(root_dir='{self.root_dir}')")
        print(f"CPU workers: {self.max_workers_cpu}, IO workers: {self.max_workers_io}")

    # ---- Process context helpers ----
    def _get_process_context(self):
        try:
            if platform.system() != "Windows":
                return multiprocessing.get_context("fork")
        except Exception:
            pass
        return multiprocessing.get_context()

    def _use_process_pool(self):
        ctx = self._get_process_context()
        try:
            with ProcessPoolExecutor(max_workers=1, mp_context=ctx) as ex:
                pass
            return ctx
        except Exception as e:
            logging.warning(f"Process pool unavailable: {e}. Will fallback to threads.")
            return None

    # ---- parallel convert / download ----
    def _parallel_mesh_convert(self, jobs, desc="Converting meshes"):
        results = {}
        if not jobs:
            return results

        ctx = self._use_process_pool()
        if ctx is not None:
            with ProcessPoolExecutor(
                max_workers=self.max_workers_cpu, mp_context=ctx
            ) as exe:
                futures = {}
                for src, dst, jtype in jobs:
                    if jtype == "thingi10k":
                        fut = exe.submit(
                            _thingi10k_np_to_stl_worker, src, os.path.basename(src), dst
                        )
                    else:
                        fut = exe.submit(_convert_to_stl_worker, src, dst)
                    futures[fut] = dst
                for fut in tqdm(as_completed(futures), total=len(futures), desc=desc):
                    dst = futures[fut]
                    try:
                        dst_path, err = fut.result()
                        results[dst_path] = err
                    except Exception as e:
                        results[dst] = str(e)
                        print(str(e))
            return results

        # fallback to threads (less efficient for CPU-bound)
        logging.warning("Falling back to ThreadPoolExecutor for conversions.")
        with ThreadPoolExecutor(
            max_workers=max(2, min(self.max_workers_cpu, 8))
        ) as exe:
            futures = {}
            for src, dst, jtype in jobs:
                if jtype == "thingi10k":
                    fut = exe.submit(
                        _thingi10k_np_to_stl_worker, src, os.path.basename(src), dst
                    )
                else:
                    fut = exe.submit(_convert_to_stl_worker, src, dst)
                futures[fut] = dst
            for fut in tqdm(as_completed(futures), total=len(futures), desc=desc):
                dst = futures[fut]
                try:
                    dst_path, err = fut.result()
                    results[dst_path] = err
                except Exception as e:
                    results[dst] = str(e)
        return results

    def _parallel_downloads(self, download_tasks, desc="Downloading"):
        results = []
        if not download_tasks:
            return results
        with ThreadPoolExecutor(max_workers=self.max_workers_io) as exe:
            futures = {
                exe.submit(_download_stream, url, out): (url, out)
                for (url, out) in download_tasks
            }
            for fut in tqdm(as_completed(futures), total=len(futures), desc=desc):
                url, out = futures[fut]
                try:
                    res = fut.result()
                    results.append((out, None))
                except Exception as e:
                    results.append((out, str(e)))
        return results

    # ---- orchestrator ----
    def prepare_all_datasets(self):
        order = [
            self.prepare_thingi10k,
            self.prepare_modelnet40,
            self.prepare_abc_dataset,
            self.prepare_objectnet3d,
            self.prepare_shapenet,
            self.prepare_custom_dataset,
        ]
        for i, fn in enumerate(order, start=1):
            print(f"\n--- [{i}/{len(order)}] {fn.__name__} ---")
            fn()
        print("\nAll dataset preparations finished.")

    # ---------- Thingi10k ----------
    def prepare_thingi10k(self):
        out_models = os.path.join(self.paths["thingi10k"], "models")
        os.makedirs(out_models, exist_ok=True)
        try:
            thingi10k.init()
        except Exception as e:
            print(f"Thingi10k init failed: {e}")
            return
        jobs = []
        for entry in tqdm(thingi10k.dataset(), desc="Collect Thingi10k entries"):
            fid = entry.get("file_id")
            npz_path = entry.get("file_path")
            if not npz_path or not fid:
                continue
            dst = os.path.join(out_models, f"{fid}.stl")
            if os.path.exists(dst):
                continue
            jobs.append((npz_path, dst, "thingi10k"))
        if not jobs:
            print("No new Thingi10k jobs.")
            return
        print(
            f"Converting {len(jobs)} Thingi10k entries using up to {self.max_workers_cpu} workers..."
        )
        res = self._parallel_mesh_convert(jobs, desc="Thingi10k -> STL")
        failed = [p for p, e in res.items() if e]
        if failed:
            print(f"Thingi10k: {len(failed)} conversions failed.")
        print("Thingi10k done.")

    # ---------- ObjectNet3D ----------
    def prepare_objectnet3d(self):
        out_root = self.paths["objectnet"]
        os.makedirs(out_root, exist_ok=True)

        # If directory has contents assume already processed (keeps idempotency)
        if any(os.scandir(out_root)):
            print("ObjectNet3D directory not empty; skipping.")
            return

        temp_dir = os.path.join(self.root_dir, "ObjectNet3D_temp")
        os.makedirs(temp_dir, exist_ok=True)
        try:
            urls = {
                "annotations": "ftp://cs.stanford.edu/cs/cvgl/ObjectNet3D/ObjectNet3D_annotations.zip",
                "cads": "ftp://cs.stanford.edu/cs/cvgl/ObjectNet3D/ObjectNet3D_cads.zip",
                "images": "ftp://cs.stanford.edu/cs/cvgl/ObjectNet3D/ObjectNet3D_images.zip",
            }
            dl_tasks = []
            for name, url in urls.items():
                target_zip = os.path.join(temp_dir, os.path.basename(url))
                if not os.path.exists(target_zip):
                    dl_tasks.append((url, target_zip))
            if dl_tasks:
                print("Downloading ObjectNet3D archives (parallel; FTP supported)...")
                dl_res = self._parallel_downloads(
                    dl_tasks, desc="ObjectNet3D downloads"
                )
                for out, err in dl_res:
                    if err:
                        print(f"Download failed for {out}: {err}")

            # Extract any zips found
            for file in os.scandir(temp_dir):
                if file.name.endswith(".zip") and os.path.getsize(file.path) > 0:
                    try:
                        with zipfile.ZipFile(file.path, "r") as z:
                            z.extractall(temp_dir)
                    except Exception as e:
                        print(f"Extraction failed {file.path}: {e}")

            # Convert CAD .off -> .stl
            cad_off_dir = os.path.join(temp_dir, "ObjectNet3D", "CAD", "off")
            jobs = []
            if os.path.exists(cad_off_dir):
                for cat in os.scandir(cad_off_dir):
                    if not cat.is_dir():
                        continue
                    dst_cat_models = os.path.join(out_root, cat.name, "models")
                    os.makedirs(dst_cat_models, exist_ok=True)
                    for off_file in os.scandir(cat.path):
                        # earlier code required 6-char base name; maintain that check
                        if (
                            off_file.name.endswith(".off")
                            and len(off_file.name.split(".")[0]) == 2
                        ):
                            dst = os.path.join(
                                dst_cat_models, off_file.name.replace(".off", ".stl")
                            )
                            if os.path.exists(dst):
                                continue
                            jobs.append((off_file.path, dst, "general"))

            if jobs:
                print(f"Converting {len(jobs)} ObjectNet3D CAD files...")
                res = self._parallel_mesh_convert(jobs, desc="ObjectNet3D CAD -> STL")
                failed = sum(1 for e in res.values() if e)
                print(f"ObjectNet3D CAD conversions done. Failed: {failed}")
            else:
                print("No CAD conversion jobs found for ObjectNet3D.")

            # Link/copy images using annotations
            ann_dir = os.path.join(temp_dir, "ObjectNet3D", "Annotations")
            img_dir = os.path.join(temp_dir, "ObjectNet3D", "Images")
            if os.path.exists(ann_dir) and os.path.exists(img_dir):
                for mat_file in tqdm(
                    os.scandir(ann_dir), desc="ObjectNet3D annotations"
                ):
                    if not (
                        mat_file.name.endswith(".mat")
                        and len(mat_file.name.split(".")[0]) == 6
                    ):
                        continue
                    try:
                        mat = scipy.io.loadmat(mat_file.path)
                        record = mat["record"][0, 0]
                        img_filename = str(record["filename"][0])
                        category_name = str(record["objects"][0, 0]["class"][0])
                        src_img = os.path.join(img_dir, img_filename)
                        if not os.path.exists(src_img):
                            continue
                        dest_img_dir = os.path.join(out_root, category_name, "images")
                        os.makedirs(dest_img_dir, exist_ok=True)
                        shutil.copy(
                            src_img,
                            os.path.join(dest_img_dir, os.path.basename(src_img)),
                        )
                    except Exception as e:
                        print(f"Annotation processing error {mat_file.name}: {e}")
            else:
                print(
                    "ObjectNet3D annotations or images folder missing after extraction."
                )
        finally:
            # cleanup
            try:
                if os.path.exists(temp_dir):
                    pass
                    shutil.rmtree(temp_dir)
            except Exception:
                pass
        print("ObjectNet3D done.")

    # ---------- ModelNet40 ----------
    def prepare_modelnet40(self):
        url = "http://modelnet.cs.princeton.edu/ModelNet40.zip"
        zip_path = os.path.join(self.root_dir, "ModelNet40.zip")
        out_root = self.paths["modelnet"]
        os.makedirs(self.root_dir, exist_ok=True)

        # skip if already processed (detect any category/models/*.stl)
        if os.path.exists(out_root):
            found = False
            for cat in os.scandir(out_root):
                if not cat.is_dir():
                    continue
                models_dir = os.path.join(cat.path, "models")
                if os.path.exists(models_dir) and any(
                    f.name.endswith(".stl") for f in os.scandir(models_dir)
                ):
                    found = True
                    break
            if found:
                print("ModelNet40 seems processed; skipping.")
                return

        # # download and extract if not present
        if not os.path.exists(out_root):
            print("Downloading ModelNet40 zip...")
            dl = self._parallel_downloads(
                [(url, zip_path)], desc="Downloading ModelNet40"
            )
            if not os.path.exists(zip_path):
                print("ModelNet40 zip missing after download.")
                return
            print("Extracting ModelNet40...")
            with zipfile.ZipFile(zip_path, "r") as z:
                z.extractall(self.root_dir)
            os.remove(zip_path)

        jobs = []
        # The on-disk structure after extraction is ModelNet40/[category]/train/*.off and test/*.off
        top_src = os.path.join(self.root_dir, "ModelNet40")
        if not os.path.exists(top_src):
            print(
                f"Expected extracted folder {top_src} not found; aborting ModelNet40."
            )
            return

        for category_d in tqdm(
            os.scandir(top_src), desc="Collecting ModelNet categories"
        ):
            if not category_d.is_dir() or category_d.name.startswith("__"):
                continue
            category_name = category_d.name
            dest_models_dir = os.path.join(out_root, category_name, "models")
            os.makedirs(dest_models_dir, exist_ok=True)

            for split in ["train", "test"]:
                split_path = os.path.join(category_d.path, split)
                if not os.path.exists(split_path):
                    continue
                for off_file in os.scandir(split_path):
                    if off_file.name.endswith(".off"):
                        dst = os.path.join(
                            dest_models_dir, off_file.name.replace(".off", ".stl")
                        )
                        if os.path.exists(dst):
                            continue
                        jobs.append((off_file.path, dst, "general"))
                # Remove split directory to save space AFTER collecting jobs
                try:
                    shutil.rmtree(split_path)
                except Exception:
                    pass

        if jobs:
            print(f"Converting {len(jobs)} ModelNet files (into {out_root}) ...")
            # split into batches to avoid huge queue spikes
            batch_size = max(500, self.max_workers_cpu * 50)
            failed_total = 0
            for i in range(0, len(jobs), batch_size):
                batch = jobs[i : i + batch_size]
                res = self._parallel_mesh_convert(
                    batch, desc=f"ModelNet40 conversions batch {i // batch_size + 1}"
                )
                failed = sum(1 for e in res.values() if e)
                failed_total += failed
                logging.info(f"Batch {i // batch_size + 1}: {failed} failed")
            print(f"ModelNet40 conversions finished. Failed: {failed_total}")
        else:
            print("No ModelNet conversions needed.")
        try:
            shutil.rmtree(top_src)
        except Exception:
            pass
        print("ModelNet40 done.")

    # ---------- ABC Dataset ----------
    def prepare_abc_dataset(self):
        url = "https://archive.nyu.edu/retrieve/120666/abc_full_100k_v00.zip"
        zip_path = os.path.join(self.root_dir, "abc_full_100k_v00.zip")
        extracted_path = os.path.join(self.root_dir, "100k")
        models_out = os.path.join(self.paths["abc"], "models")
        os.makedirs(models_out, exist_ok=True)

        if os.path.exists(models_out) and len(os.listdir(models_out)) > 1000:
            print("ABC dataset seems processed; skipping.")
            return

        if not os.path.exists(extracted_path):
            print("Downloading ABC dataset (large)...")
            self._parallel_downloads([(url, zip_path)], desc="Downloading ABC")
            if not os.path.exists(zip_path):
                print("ABC zip not present after download.")
                return
            print("Extracting ABC (this will take a while)...")
            with zipfile.ZipFile(zip_path, "r") as z:
                z.extractall(self.root_dir)
            os.remove(zip_path)

        jobs = []
        for split in ["train/512", "test/512"]:
            src_dir = os.path.join(extracted_path, split)
            if not os.path.exists(src_dir):
                continue
            for file in os.scandir(src_dir):
                if file.name.endswith(".obj"):
                    dst = os.path.join(models_out, file.name.replace(".obj", ".stl"))
                    if os.path.exists(dst):
                        continue
                    jobs.append((file.path, dst, "general"))

        if jobs:
            print(f"Converting {len(jobs)} ABC models...")
            res = self._parallel_mesh_convert(jobs, desc="ABC conversions")
            failed = sum(1 for e in res.values() if e)
            print(f"ABC conversions complete. Failed: {failed}")
        else:
            print("No ABC conversions needed.")

        if os.path.exists(extracted_path):
            try:
                shutil.rmtree(extracted_path)
            except Exception as e:
                print(f"Could not remove ABC extracted folder: {e}")

    # ---------- ShapeNet (incremental processing to avoid temp bloat) ----------
    def prepare_shapenet(self):
        out_root = self.paths["shapenet"]
        os.makedirs(out_root, exist_ok=True)
        hf_token = os.environ.get("HF_TOKEN")
        if not hf_token:
            print("HF_TOKEN not set; skipping ShapeNet.")
            return
        try:
            from huggingface_hub import login, snapshot_download

            login(token=hf_token)
            repo_id = "ShapeNet/ShapeNetCore"
            cache_dir = snapshot_download(repo_id=repo_id, repo_type="dataset")
        except Exception as e:
            print(f"Failed to fetch ShapeNet from HF Hub: {e}")
            return

        ctx = self._use_process_pool()
        use_processes = ctx is not None

        # iterate zips one-by-one -> process their models incrementally
        for zip_filename in tqdm(
            sorted(os.listdir(cache_dir)), desc="Enumerate ShapeNet zips"
        ):
            if not zip_filename.endswith(".zip"):
                continue
            cat_id = zip_filename[:-4]
            models_dir = os.path.join(out_root, cat_id, "models")
            images_dir = os.path.join(out_root, cat_id, "images")
            os.makedirs(models_dir, exist_ok=True)
            os.makedirs(images_dir, exist_ok=True)
            zip_filepath = os.path.join(cache_dir, zip_filename)

            try:
                with zipfile.ZipFile(zip_filepath, "r") as z:
                    all_files = z.namelist()
                    model_ids = sorted(
                        list({f.split("/")[1] for f in all_files if f.count("/") >= 2})
                    )

                    # We'll process models in a controlled concurrency pattern:
                    # submit up to max_workers_cpu conversions, wait for at least one to free up, delete tmpdir immediately.
                    if use_processes:
                        pool = ProcessPoolExecutor(
                            max_workers=self.max_workers_cpu, mp_context=ctx
                        )
                    else:
                        pool = ThreadPoolExecutor(
                            max_workers=max(1, min(self.max_workers_cpu, 4))
                        )

                    futures = {}  # future -> tmp_dir for cleanup
                    try:
                        for mid in tqdm(
                            model_ids, desc=f"Models in {cat_id}", leave=False
                        ):
                            # locate the candidate model .obj path inside zip (prefer model_normalized.obj)
                            model_norm = f"{cat_id}/{mid}/models/model_normalized.obj"
                            candidate = (
                                model_norm
                                if model_norm in all_files
                                else next(
                                    (
                                        f
                                        for f in all_files
                                        if f.startswith(f"{cat_id}/{mid}/models/")
                                        and f.endswith(".obj")
                                    ),
                                    None,
                                )
                            )
                            if not candidate:
                                continue
                            dst = os.path.join(models_dir, f"{mid}.stl")
                            if os.path.exists(dst):
                                continue

                            # extract only the model file to a fresh tmp dir
                            tmp_dir = tempfile.mkdtemp(prefix="shapenet_")
                            extracted_ok = False
                            try:
                                z.extract(candidate, path=tmp_dir)
                                # ensure we can find the file (zip may or may not preserve dirs)
                                src_path = os.path.join(tmp_dir, candidate)
                                if not os.path.exists(src_path):
                                    # fallback search
                                    found_obj = None
                                    for root, _, files in os.walk(tmp_dir):
                                        for f in files:
                                            if f.endswith(".obj"):
                                                found_obj = os.path.join(root, f)
                                                break
                                        if found_obj:
                                            break
                                    if found_obj:
                                        src_path = found_obj
                                    else:
                                        raise RuntimeError(
                                            "Could not locate extracted OBJ after extract"
                                        )
                                extracted_ok = True
                            except Exception as ex:
                                shutil.rmtree(tmp_dir, ignore_errors=True)
                                print(
                                    f"Error extracting {candidate} from {zip_filename}: {ex}"
                                )
                                continue

                            # extract screenshots for this model from zip (only screenshots/)
                            # they are small; copy to images_dir
                            screenshot_prefix = f"{cat_id}/{mid}/screenshots/"
                            for f in all_files:
                                if f.startswith(screenshot_prefix) and (
                                    f.endswith(".png")
                                    or f.endswith(".jpg")
                                    or f.endswith(".jpeg")
                                ):
                                    outname = f"{mid}_{os.path.basename(f)}"
                                    outpath = os.path.join(images_dir, outname)
                                    if not os.path.exists(outpath):
                                        try:
                                            with (
                                                z.open(f) as src,
                                                open(outpath, "wb") as dstf,
                                            ):
                                                shutil.copyfileobj(src, dstf)
                                        except Exception:
                                            pass

                            # submit conversion job
                            if use_processes:
                                future = pool.submit(
                                    _convert_to_stl_worker, src_path, dst
                                )
                            else:
                                future = pool.submit(
                                    _convert_to_stl_worker, src_path, dst
                                )
                            futures[future] = tmp_dir

                            # if we've reached pool capacity, wait for at least one to complete, then cleanup that tmpdir
                            while len(futures) >= max(1, self.max_workers_cpu):
                                done, not_done = wait(
                                    futures.keys(), return_when=FIRST_COMPLETED
                                )
                                for fut in done:
                                    tmp = futures.pop(fut, None)
                                    try:
                                        dst_path, err = fut.result()
                                        if err:
                                            logging.debug(
                                                f"ShapeNet conversion error for {dst_path}: {err}"
                                            )
                                    except Exception as e:
                                        logging.debug(
                                            f"ShapeNet conversion raised: {e}"
                                        )
                                    # cleanup tmp dir
                                    if tmp:
                                        try:
                                            shutil.rmtree(tmp)
                                        except Exception:
                                            pass

                        # after submitting all models, wait for remaining futures and cleanup
                        for fut in as_completed(futures.keys()):
                            tmp = futures.pop(fut, None)
                            try:
                                dst_path, err = fut.result()
                                if err:
                                    logging.debug(
                                        f"ShapeNet conversion error for {dst_path}: {err}"
                                    )
                            except Exception as e:
                                logging.debug(f"ShapeNet conversion raised: {e}")
                            if tmp:
                                try:
                                    shutil.rmtree(tmp)
                                except Exception:
                                    pass

                    finally:
                        pool.shutdown(wait=True)

            except Exception as e:
                print(f"Error processing ShapeNet zip {zip_filename}: {e}")

        print("ShapeNet done.")

    # ---------- Custom dataset (Yandex.Disk) ----------
    def prepare_custom_dataset(self):
        """Download and extract the custom dataset archives from Yandex.Disk.

        Each public share link below is resolved to a direct download URL via
        the Yandex.Disk public-resources REST endpoint. The archives are then
        fetched with ``self._parallel_downloads`` into ``self.root_dir`` as
        ``yandex_<key>.zip`` and extracted into ``self.paths["custom"]``. An
        archive is deleted only after it has been extracted successfully.

        The whole step is skipped when the output folder already contains any
        entry. NOTE(review): a failed or partial extraction can leave files in
        the folder, so later calls will also skip; delete the folder to retry.

        Errors (link resolution, download, extraction) are printed and the
        remaining archives are still processed. Nothing is raised for them.
        """
        out_root = self.paths["custom"]
        os.makedirs(out_root, exist_ok=True)
        # Use the scandir iterator as a context manager so its directory
        # handle is released right away rather than whenever it is collected.
        with os.scandir(out_root) as entries:
            already_populated = any(entries)
        if already_populated:
            print("Custom dataset folder not empty; skipping.")
            return

        files_to_download = {
            "train_data": "https://disk.yandex.ru/d/RRXJu9ZtEmSXzQ",
            "test_data": "https://disk.yandex.ru/d/TmbB7BsGzg1dQQ",
        }
        # Endpoint that turns a public share link into a direct "href".
        api = "https://cloud-api.yandex.net/v1/disk/public/resources/download?"
        tasks = []
        for key, public_url in files_to_download.items():
            api_url = api + urlencode(dict(public_key=public_url))
            try:
                r = requests.get(api_url, timeout=10)
                r.raise_for_status()
                href = r.json().get("href")
                if href:
                    out_zip = os.path.join(self.root_dir, f"yandex_{key}.zip")
                    tasks.append((href, out_zip))
                else:
                    print(f"Could not resolve Yandex link {public_url}")
            except Exception as e:
                # Best-effort: report and move on to the next archive.
                print(f"Yandex API error for {public_url}: {e}")

        if tasks:
            print("Downloading custom dataset zips from Yandex.Disk (parallel)...")
            dl_res = self._parallel_downloads(tasks, desc="Yandex downloads")
            # Assumes _parallel_downloads yields (path, error) pairs with a
            # falsy error on success; TODO confirm against its definition.
            for zip_path, err in dl_res:
                if err:
                    print(f"Download failed for {zip_path}: {err}")
                    continue
                try:
                    # NOTE(review): archives come from a remote source;
                    # ZipFile.extractall strips absolute paths and ".."
                    # components, but the archive is otherwise trusted as-is.
                    with zipfile.ZipFile(zip_path, "r") as z:
                        z.extractall(out_root)
                    os.remove(zip_path)
                except Exception as e:
                    # The archive is kept on failure so it can be inspected.
                    print(f"Failed to extract {zip_path}: {e}")
        print("Custom dataset done.")


# -----------------------
# Script / notebook entry point
# -----------------------
if __name__ == "__main__":
    # Running from a shell is recommended for maximum stability. To build a
    # different dataset, replace the call below with one of:
    # prepare_custom_dataset(), prepare_all_datasets(), prepare_modelnet40()
    # or prepare_objectnet3d().
    manager = DatasetsManager(root_dir="data")
    manager.prepare_shapenet()


DatasetsManager(root_dir='data')
CPU workers: 12, IO workers: 16


Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


Fetching 58 files:   0%|          | 0/58 [00:00<?, ?it/s]

Enumerate ShapeNet zips:   0%|          | 0/58 [00:00<?, ?it/s]

Models in 02691156:   0%|          | 0/4045 [00:00<?, ?it/s]

Models in 02747177:   0%|          | 0/343 [00:00<?, ?it/s]

Models in 02773838:   0%|          | 0/83 [00:00<?, ?it/s]

Models in 02801938:   0%|          | 0/113 [00:00<?, ?it/s]

Models in 02808440:   0%|          | 0/856 [00:00<?, ?it/s]

Models in 02818832:   0%|          | 0/233 [00:00<?, ?it/s]

Models in 02828884:   0%|          | 0/1813 [00:00<?, ?it/s]

Models in 02843684:   0%|          | 0/73 [00:00<?, ?it/s]

Models in 02871439:   0%|          | 0/452 [00:00<?, ?it/s]

Models in 02876657:   0%|          | 0/498 [00:00<?, ?it/s]

Models in 02880940:   0%|          | 0/186 [00:00<?, ?it/s]

Models in 02924116:   0%|          | 0/939 [00:00<?, ?it/s]

Models in 02933112:   0%|          | 0/1571 [00:00<?, ?it/s]

Models in 02942699:   0%|          | 0/113 [00:00<?, ?it/s]

Models in 02946921:   0%|          | 0/108 [00:00<?, ?it/s]

Models in 02954340:   0%|          | 0/56 [00:00<?, ?it/s]

Models in 02958343:   0%|          | 0/3533 [00:00<?, ?it/s]

Models in 02992529:   0%|          | 0/831 [00:00<?, ?it/s]

Models in 03001627:   0%|          | 0/6778 [00:00<?, ?it/s]

Models in 03046257:   0%|          | 0/651 [00:00<?, ?it/s]

Models in 03085013:   0%|          | 0/65 [00:00<?, ?it/s]

Models in 03207941:   0%|          | 0/93 [00:00<?, ?it/s]

Models in 03211117:   0%|          | 0/1093 [00:00<?, ?it/s]

Models in 03261776:   0%|          | 0/73 [00:00<?, ?it/s]

Models in 03325088:   0%|          | 0/744 [00:00<?, ?it/s]

Models in 03337140:   0%|          | 0/298 [00:00<?, ?it/s]

Models in 03467517:   0%|          | 0/797 [00:00<?, ?it/s]

Models in 03513137:   0%|          | 0/162 [00:00<?, ?it/s]

Models in 03593526:   0%|          | 0/596 [00:00<?, ?it/s]

Models in 03624134:   0%|          | 0/424 [00:00<?, ?it/s]

Models in 03636649:   0%|          | 0/2318 [00:00<?, ?it/s]

Models in 03642806:   0%|          | 0/460 [00:00<?, ?it/s]

Models in 03691459:   0%|          | 0/1597 [00:00<?, ?it/s]

Models in 03710193:   0%|          | 0/94 [00:00<?, ?it/s]

Models in 03759954:   0%|          | 0/67 [00:00<?, ?it/s]

Models in 03761084:   0%|          | 0/152 [00:00<?, ?it/s]

Models in 03790512:   0%|          | 0/337 [00:00<?, ?it/s]

Models in 03797390:   0%|          | 0/214 [00:00<?, ?it/s]

Models in 03928116:   0%|          | 0/239 [00:00<?, ?it/s]

Models in 03938244:   0%|          | 0/96 [00:00<?, ?it/s]

Models in 03948459:   0%|          | 0/307 [00:00<?, ?it/s]

Models in 03991062:   0%|          | 0/602 [00:00<?, ?it/s]

Models in 04004475:   0%|          | 0/166 [00:00<?, ?it/s]

Models in 04074963:   0%|          | 0/66 [00:00<?, ?it/s]

Models in 04090263:   0%|          | 0/2373 [00:00<?, ?it/s]