In [1]:
import os
from pathlib import Path
import hashlib
import requests
from typing import Dict, List, Optional

import zipfile
import shutil

import pandas as pd
from tqdm.auto import tqdm

In [2]:
# Every artifact produced by this notebook lives under a single root directory
# (resolved to an absolute path and created on first run).
DATA_ROOT = Path("./sts_tooth_data").resolve()
DATA_ROOT.mkdir(parents=True, exist_ok=True)

# downloads: raw .zip.001, .zip.002, ... files
DOWNLOAD_DIR = DATA_ROOT.joinpath("downloads")
DOWNLOAD_DIR.mkdir(exist_ok=True)

# raw: extracted content from SD-Tooth.zip
RAW_DIR = DATA_ROOT.joinpath("raw")
RAW_DIR.mkdir(exist_ok=True)

# processed_2d: cleaned & reorganized 2D images
PROCESSED_2D_DIR = DATA_ROOT.joinpath("processed_2d")
PROCESSED_2D_DIR.mkdir(exist_ok=True)

print("DATA_ROOT:", DATA_ROOT)

In [4]:
# Download all 15 split zip parts

# %% [markdown]
# ## Step 1 (faster): Download all 15 split zip parts in parallel
# 
# This is an accelerated version of the download step.
# 
# Main changes compared to the original version:
# - We use a ThreadPoolExecutor to download several parts in parallel.
# - We use a single `requests.Session()` object (connection reuse).
# - Already existing files are still skipped.
# 
# Note:
# - If your network is unstable, you can reduce `MAX_WORKERS` (e.g. 2).
# - If your bandwidth is good, you can try increasing it to 6–8.

# %%
from concurrent.futures import ThreadPoolExecutor
import itertools

# Zenodo record ID for STS-Tooth.
# It is interpolated into the download URL built by `build_part_url`.
ZENODO_RECORD_ID = "10597292"

# Expected MD5 checksums from the Zenodo page.
# Keys are the local part file names (the same names `build_part_filename`
# produces for indices 1..15); values are hex digests. `verify_md5_all`
# compares them case-insensitively against the downloaded files.
EXPECTED_MD5: Dict[str, str] = {
    "SD-Tooth.zip.001": "9fff54c469d2f5706332d01fd362178e",
    "SD-Tooth.zip.002": "eb5be6735f8a20b469bef46a26784f3c",
    "SD-Tooth.zip.003": "1ee5ce98fbfadeb48b2c264ac425ff20",
    "SD-Tooth.zip.004": "12b7b965b1d7330c10850a40324368fd",
    "SD-Tooth.zip.005": "1348eaf36bed8da9aabd5748a70a3d51",
    "SD-Tooth.zip.006": "f0a2158d0cc915507ccefc53067b6a74",
    "SD-Tooth.zip.007": "d728cec1d10e05845d33d2a3480c1cf9",
    "SD-Tooth.zip.008": "af917049b72b512c4eae0d2801811b16",
    "SD-Tooth.zip.009": "699713928665b2c15537d665689c7bdd",
    "SD-Tooth.zip.010": "6ba5b3100fd516f031b68bc13868e0bb",
    "SD-Tooth.zip.011": "82e5df4d8dda37a45ea1725b6b21942b",
    "SD-Tooth.zip.012": "d7896f01462a4a1589f9f006e3c97190",
    "SD-Tooth.zip.013": "763728b43b93cf081cc28ca15d4490a1",
    "SD-Tooth.zip.014": "9cba0607e1d8066cd5aaac5c8091ee6a",
    "SD-Tooth.zip.015": "0be8577f58e8bc1937f8e162887c4654",
}


def build_part_filename(idx: int) -> str:
    """
    Return the local file name of split part number `idx`.

    The index is zero-padded to at least three digits:
        idx = 1  -> "SD-Tooth.zip.001"
        idx = 15 -> "SD-Tooth.zip.015"
    """
    padded_index = format(idx, "03d")
    return "SD-Tooth.zip." + padded_index


def build_part_url(idx: int) -> str:
    """
    Return the remote download URL of split part number `idx`.

    The URL follows the Zenodo file download pattern:
        https://zenodo.org/records/<record_id>/files/<filename>?download=1
    where <record_id> is `ZENODO_RECORD_ID` and <filename> comes from
    `build_part_filename(idx)`.
    """
    url_template = "https://zenodo.org/records/{}/files/{}?download=1"
    return url_template.format(ZENODO_RECORD_ID, build_part_filename(idx))


def md5sum(path: Path, chunk_size: int = 1024 * 1024) -> str:
    """
    Return the hexadecimal MD5 digest of the file at `path`.

    The file is read in blocks of `chunk_size` bytes (1 MiB by default) so
    that large archives never have to fit in memory at once.
    """
    digest = hashlib.md5()
    with path.open("rb") as stream:
        while True:
            block = stream.read(chunk_size)
            if not block:
                # An empty read marks end of file.
                break
            digest.update(block)
    return digest.hexdigest()


# Create a session to reuse HTTP connections across downloads.
# `download_file` issues its requests through this single module-level
# object, so all worker threads of the parallel download share it.
# NOTE(review): confirm that sharing one `requests.Session` across threads is
# acceptable for this use; a per-thread session is the conservative option.
_session = requests.Session()


def download_file(
    url: str,
    dst: Path,
    chunk_size: int = 4 * 1024 * 1024,
    timeout: float = 60.0,
) -> None:
    """
    Download a file from `url` to `dst`, showing a progress bar.

    If the destination file already exists, the download is skipped.

    The payload is first streamed into a sibling temporary file
    (`<dst name>.part`) and only renamed to `dst` once the transfer has
    finished. An interrupted or failed download therefore never leaves a
    truncated file at `dst` that a later run would mistake for a finished
    download and skip.

    Args:
        url: Remote URL.
        dst: Local file path.
        chunk_size: Number of bytes read per iteration (4 MB by default).
        timeout: Seconds to wait for the server (connect / between reads)
            before giving up, so a stalled connection cannot block a worker
            forever.

    Raises:
        Whatever the HTTP layer raises (bad status, connection error,
        timeout) or any I/O error while writing; in all of these cases the
        temporary file is removed and `dst` is left absent.
    """
    if dst.exists():
        print(f"[skip] file already exists: {dst.name}")
        return

    print(f"[download] {dst.name} <- {url}")
    # Write next to the final location so the closing rename stays on one
    # filesystem. Opening with "wb" also truncates any stale .part file.
    tmp_path = dst.with_name(dst.name + ".part")
    try:
        with _session.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            # A missing or zero Content-Length gives an open-ended bar.
            total = int(r.headers.get("Content-Length", 0)) or None
            with tmp_path.open("wb") as f, tqdm(
                total=total, unit="B", unit_scale=True, desc=dst.name
            ) as pbar:
                for chunk in r.iter_content(chunk_size=chunk_size):
                    if chunk:
                        f.write(chunk)
                        pbar.update(len(chunk))
        # The file is closed at this point; publish it under its final name.
        tmp_path.replace(dst)
    except BaseException:
        # Also covers KeyboardInterrupt: never keep a partial download.
        if tmp_path.exists():
            tmp_path.unlink()
        raise


def download_one_part(idx: int):
    """
    Download split part number `idx` into `DOWNLOAD_DIR`.

    This is the unit of work handed to the ThreadPoolExecutor: it derives the
    part's file name and URL, delegates to `download_file`, and returns the
    file name.
    """
    part_name = build_part_filename(idx)
    download_file(build_part_url(idx), DOWNLOAD_DIR / part_name)
    return part_name


# Number of parallel workers (you can tune this)
MAX_WORKERS = 6

print(f"Starting parallel download with {MAX_WORKERS} workers...")

# We wrap the executor.map call with tqdm so we see progress over the 15 parts.
# `executor.map` returns a lazy iterator; `list(...)` drains it so this cell
# only continues once every part index 1..15 has been processed. The returned
# file names are discarded.
# The part count appears twice here (`range(1, 16)` and `total=15`); keep
# both in sync if the number of parts ever changes.
with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
    list(
        tqdm(
            executor.map(download_one_part, range(1, 16)),
            total=15,
            desc="Downloading parts",
        )
    )

print("All parts have finished the download step (existing files were skipped).")


In [5]:
# Optional MD5 verification
def verify_md5_all() -> bool:
    """
    Check every entry of `EXPECTED_MD5` against the files in `DOWNLOAD_DIR`.

    One status line is printed per part: `[missing]`, `[OK]` or `[ERR]`.
    Digests are compared case-insensitively.

    Returns:
        True if all files are present and all MD5 checksums match,
        False otherwise.
    """
    problems = 0
    for name, expected in EXPECTED_MD5.items():
        path = DOWNLOAD_DIR / name
        if not path.exists():
            print(f"[missing] {name}")
            problems += 1
            continue

        real = md5sum(path)
        if real.lower() != expected.lower():
            print(f"[ERR] {name} md5 mismatch: got={real}, expected={expected}")
            problems += 1
        else:
            print(f"[OK] {name} md5 matches")
    return problems == 0


# You can comment out the next two lines if you don't need verification.
# The outcome is only printed: none of the later cells shown in this notebook
# read `all_good`, so a failed check does not stop the merge/unzip steps.
all_good = verify_md5_all()
print("MD5 verification result:", all_good)

In [6]:
# Merge parts and unzip

# Path to the merged zip file.
# It is written next to the split parts inside DOWNLOAD_DIR and later read by
# the extraction step.
MERGED_ZIP = DOWNLOAD_DIR / "SD-Tooth.zip"


def merge_parts_to_zip():
    """
    Merge all split zip parts into a single `SD-Tooth.zip` file.

    Each split part is appended in order:
        SD-Tooth.zip.001, SD-Tooth.zip.002, ..., SD-Tooth.zip.015

    If the merged zip already exists, this step is skipped.

    Because an existing `MERGED_ZIP` is treated as "already merged", the
    function must never leave a half-written archive behind. It therefore
    checks that every part is present before writing anything, concatenates
    into a temporary sibling file (`SD-Tooth.zip.tmp`), and renames that file
    to `MERGED_ZIP` only after all parts were copied.

    Raises:
        FileNotFoundError: if any of the 15 parts is missing. Nothing is
            written in that case.
    """
    if MERGED_ZIP.exists():
        print(f"[skip] merged zip already exists: {MERGED_ZIP}")
        return

    # Validate all inputs up front so a missing part cannot abort the merge
    # midway through.
    part_paths = [DOWNLOAD_DIR / build_part_filename(i) for i in range(1, 16)]
    for part_path in part_paths:
        if not part_path.exists():
            raise FileNotFoundError(f"Missing split part: {part_path}")

    print("[merge] merging parts into SD-Tooth.zip ...")
    tmp_zip = MERGED_ZIP.with_name(MERGED_ZIP.name + ".tmp")
    try:
        with tmp_zip.open("wb") as out_f:
            for part_path in part_paths:
                print(f"  -> merging {part_path.name}")
                with part_path.open("rb") as in_f:
                    shutil.copyfileobj(in_f, out_f)
        # Output is closed; publish it under the final name in one step.
        tmp_zip.replace(MERGED_ZIP)
    except BaseException:
        # Covers I/O errors and interrupts alike: drop the partial archive.
        if tmp_zip.exists():
            tmp_zip.unlink()
        raise
    print("[done] merge finished.")


# Run the merge now; it prints a skip message when SD-Tooth.zip already exists.
merge_parts_to_zip()


def unzip_merged_zip():
    """
    Unpack `MERGED_ZIP` into `RAW_DIR`.

    Extraction is skipped as soon as `RAW_DIR` holds at least one entry,
    on the assumption that a previous run already extracted the data.

    NOTE(review): the skip check only looks for "not empty", so a previous
    extraction that was interrupted partway will also be treated as done.
    """
    first_entry = next(RAW_DIR.iterdir(), None)
    if first_entry is not None:
        print("[skip] RAW_DIR is not empty, assuming data already extracted:", RAW_DIR)
        return

    print(f"[unzip] {MERGED_ZIP} -> {RAW_DIR}")
    archive = zipfile.ZipFile(MERGED_ZIP, "r")
    try:
        archive.extractall(RAW_DIR)
    finally:
        # Always release the archive handle, even if extraction fails.
        archive.close()
    print("[done] unzip finished.")


# Run the extraction now; it prints a skip message when RAW_DIR is not empty.
unzip_merged_zip()

# %% [markdown]
# ## Step 3: Locate the `STS-2D-Tooth` directory and scan all PNGs
# 
# The extracted folder may contain multiple top-level directories.  
# We need to:
# 
# 1. Search for the directory whose name contains `"STS-2D-Tooth"`.
# 2. Recursively scan all `.png` files under this directory.
# 3. For each PNG file, infer:
#    - `age_group`: `"adult"`, `"children"`, or `"unknown"` based on the path (`A-PXI` / `C-PXI`).
#    - `label_status`: `"labeled"`, `"unlabeled"`, or `"unknown"` based on path (`Labeled` / `Unlabeled`).
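#    - `is_mask`: Boolean, `True` if the file name ends with a mask suffix (`_mask` / `-mask` / `_label` / `-label`) or a path component is named `mask` / `masks` / `label` / `labels` (case-insensitive).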
#    - `pair_id`: A normalized ID used to match an image and its mask (e.g., `A001` vs `A001_mask`).
# 4. Save all this information into a `pandas.DataFrame` and write it to `sts2d_index.csv`.

In [12]:
#  Find STS-2D-Tooth root & scan PNGs
# %%  (fixed Cell 5) Find STS-2D-Tooth root & scan PNGs

def find_sts2d_root(raw_root: Path) -> Path:
    """
    Locate the `STS-2D-Tooth` directory somewhere below `raw_root`.

    Every directory under `raw_root` (searched recursively) whose name
    contains "STS-2D-Tooth" is a candidate; the one with the shortest full
    path string wins, with ties going to the one found first.

    Raises:
        RuntimeError: if no such directory exists.
    """
    matches = [
        entry
        for entry in raw_root.rglob("*")
        if entry.is_dir() and "STS-2D-Tooth" in entry.name
    ]

    if len(matches) == 0:
        raise RuntimeError(
            "Could not find a directory with name containing 'STS-2D-Tooth' under RAW_DIR."
        )

    # The shortest path is taken to be the outermost match.
    shortest = min(matches, key=lambda entry: len(str(entry)))
    print("[found STS-2D-Tooth directory]:", shortest)
    return shortest


# Root of the extracted STS-2D-Tooth folder; raises RuntimeError if RAW_DIR contains no matching directory.
STS2D_ROOT = find_sts2d_root(RAW_DIR)


def infer_age_group(path: Path) -> str:
    """
    Classify a file by age group from markers in its path string.

    Returns "adult" if the path contains "A-PXI", otherwise "children" if it
    contains "C-PXI", otherwise "unknown". Matching is case-sensitive and is
    done on the whole path string, not on individual components.
    """
    text = str(path)
    for marker, group in (("A-PXI", "adult"), ("C-PXI", "children")):
        if marker in text:
            return group
    return "unknown"


def infer_label_status(path: Path) -> str:
    """
    Classify a file as labeled / unlabeled from markers in its path string.

    Returns "labeled" if the path contains "Labeled", otherwise "unlabeled"
    if it contains "Unlabeled", otherwise "unknown".

    The check is case-sensitive on purpose: "Unlabeled" only contains a
    lower-case "labeled", so it does not trigger the "Labeled" test that
    runs first.
    """
    text = str(path)
    for marker, status in (("Labeled", "labeled"), ("Unlabeled", "unlabeled")):
        if marker in text:
            return status
    return "unknown"


def infer_is_mask(path: Path) -> bool:
    """
    Heuristically decide whether this PNG is a mask file.

    A path counts as a mask when EITHER of these holds (case-insensitive):
    - the file stem ends with "_mask", "-mask", "_label" or "-label"
    - some path component is exactly "mask", "masks", "label" or "labels"

    The second rule catches masks stored under a 'Mask' folder whose file
    names carry no 'mask' / 'label' suffix.
    """
    # 1) file name patterns
    name_suffixes = ("_mask", "-mask", "_label", "-label")
    if path.stem.lower().endswith(name_suffixes):
        return True

    # 2) directory names
    mask_dir_names = {"mask", "masks", "label", "labels"}
    lowered_parts = {part.lower() for part in path.parts}
    return not mask_dir_names.isdisjoint(lowered_parts)


def make_pair_id(path: Path) -> str:
    """
    Build the ID used to match an image with its mask.

    The ID is the lower-cased file stem with one trailing "_mask", "-mask",
    "_label" or "-label" removed, so `A001.png` and `A001_mask.png` both map
    to "a001".
    """
    stem = path.stem.lower()
    mask_suffixes = ("_mask", "-mask", "_label", "-label")
    # Strip at most one suffix; stems without a suffix pass through unchanged.
    return next(
        (stem[: -len(suffix)] for suffix in mask_suffixes if stem.endswith(suffix)),
        stem,
    )


# One dict per PNG file; turned into the index DataFrame below.
records = []

print("[scan] scanning all PNG files under STS-2D-Tooth ...")
# The glob is materialized with list(...) so tqdm knows the total up front.
for png_path in tqdm(list(STS2D_ROOT.rglob("*.png"))):
    rel_path = png_path.relative_to(STS2D_ROOT)
    # The infer_* helpers receive the full path, not the relative one, so
    # directory names above STS2D_ROOT also take part in their matching.
    age_group = infer_age_group(png_path)
    label_status = infer_label_status(png_path)
    is_mask = infer_is_mask(png_path)
    pair_id = make_pair_id(png_path)

    records.append(
        {
            # Stored with forward slashes so the CSV looks the same on every OS.
            "rel_path": str(rel_path).replace("\\", "/"),
            "age_group": age_group,
            "label_status": label_status,
            "is_mask": is_mask,
            "pair_id": pair_id,
        }
    )

# `df` is the in-memory index used by the following cells.
df = pd.DataFrame(records)
print("Total number of PNG files found:", len(df))

# Persist the index next to the data directories (overwritten on every run).
INDEX_CSV = DATA_ROOT / "sts2d_index.csv"
df.to_csv(INDEX_CSV, index=False)
print("Index file saved to:", INDEX_CSV)


# %% [markdown]
# ## Step 3.1: Quick sanity checks and statistics
# 
# - Check how many files belong to each `age_group`.
# - Check how many files are labeled vs unlabeled.
# - Check how many files are masks vs normal images.
# - Show the first few rows of `pair_id` statistics to see if images and masks are paired correctly.

In [13]:
# Basic statistics for sanity check: one value_counts table per index column.
print("\n=== age_group distribution ===")
print(df["age_group"].value_counts())

print("\n=== label_status distribution ===")
print(df["label_status"].value_counts())

print("\n=== is_mask distribution ===")
print(df["is_mask"].value_counts())

# Group by pair_id and is_mask to see how many images/masks each pair_id has.
# After unstack, each row is a pair_id and the boolean is_mask values become
# columns, which are then renamed to readable labels (missing combinations
# are filled with 0).
pair_stats = (
    df.groupby(["pair_id", "is_mask"])
    .size()
    .unstack(fill_value=0)
    .rename(columns={False: "num_image_like", True: "num_mask_like"})
)

print("\nExample of pairing statistics (first 10 rows):")
print(pair_stats.head(10))

# %% [markdown]
# ## Step 4: Organize 2D images into a clean directory structure
# 
# For easier training and data loading, we copy (or link) the images into:
# 
# ```text
# processed_2d/
#   adult/
#     labeled/
#       images/
#       masks/
#     unlabeled/
#       images/
#   children/
#     labeled/
#       images/
#       masks/
#     unlabeled/
#       images/
#   unknown/
#     ...
# ```
# 
# Notes:
# - `age_group` values other than `"adult"`/`"children"` are mapped to `"unknown"`.
# - `label_status` values other than `"labeled"`/`"unlabeled"` are mapped to `"unknown"`.
# - You can disable the copying by setting `ENABLE_COPY = False`.

In [14]:
# Copy files into processed_2d directory

# Whether to actually copy files.
# If the dataset is too large and you don't want duplication,
# you can set this to False and only use the CSV index.
# When False, the copy loop below is skipped and only a message is printed.
ENABLE_COPY = True


def safe_copy(src: Path, dst: Path):
    """
    Copy `src` to `dst` unless `dst` is already there.

    The parent directory of `dst` is created first (including any missing
    ancestors), whether or not a copy ends up being made. An existing `dst`
    is never overwritten.
    """
    target_dir = dst.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    if not dst.exists():
        shutil.copy2(src, dst)


if ENABLE_COPY:
    print("\n[organize] copying 2D images into processed_2d by age_group / label_status / is_mask ...")

    # Walk the index built earlier; every row describes one PNG under STS2D_ROOT.
    for _, row in tqdm(df.iterrows(), total=len(df)):
        rel_path = Path(row["rel_path"])
        age_group = row["age_group"]
        label_status = row["label_status"]
        is_mask = bool(row["is_mask"])

        src = STS2D_ROOT / rel_path

        # Handle unknown values: group them into 'unknown'
        if age_group not in ("adult", "children"):
            age_group = "unknown"
        if label_status not in ("labeled", "unlabeled"):
            label_status = "unknown"

        # Decide which subfolder to use: 'images' or 'masks'
        sub = "masks" if is_mask else "images"

        # Target path:
        # processed_2d / {age_group} / {label_status} / images|masks / <filename>
        # NOTE(review): only the file name is kept, so the source folder
        # structure is flattened. Two source files that share a name and land
        # in the same target folder collide, and `safe_copy` silently keeps
        # whichever was copied first. Confirm names are unique per folder.
        dst = PROCESSED_2D_DIR / age_group / label_status / sub / rel_path.name
        safe_copy(src, dst)

    print("[done] 2D images have been organized under:", PROCESSED_2D_DIR)
else:
    print("\n[skip] ENABLE_COPY is False; no files were copied.")

# %% [markdown]
# ## Step 5: Summary
# 
# - `INDEX_CSV` contains a full index of all 2D PNG files, with:
#   * `rel_path` (relative to `STS2D_ROOT`)
#   * `age_group` (`adult` / `children` / `unknown`)
#   * `label_status` (`labeled` / `unlabeled` / `unknown`)
#   * `is_mask` (True / False)
#   * `pair_id` (used to link images with masks)
# - `STS2D_ROOT` is the root of the original `STS-2D-Tooth` folder.
# - `PROCESSED_2D_DIR` contains the reorganized dataset, ready for training.


In [15]:
# Final summary printout: where the index, the original data and the
# reorganized copy live on disk.
print("\n=== Summary ===")
print("Index CSV file:", INDEX_CSV)
print("Original 2D root directory:", STS2D_ROOT)
print("Processed 2D directory:", PROCESSED_2D_DIR)
print("You can now load the CSV with pandas and join `STS2D_ROOT` + `rel_path` to get absolute paths for training.")