# Auto-Download Parquet Files

This notebook demonstrates a proof-of-concept for automatically downloading parquet files from a Hugging Face dataset repository. It includes configuration, helper functions, and a main execution flow.

In [None]:
# CONFIG (edit only this block)
# All user-tunable settings live in this dict; the rest of the notebook reads from it.
config = {
    "repo": "Lichess/standard-chess-games",  # Hugging Face repo id
    "year": "2025",                          # 4-digit year (string or int)
    "month": "7",                            # numeric month (e.g., "7" or "07")
    "max_parquets": 1,                       # int or None to download all available
    "output_dir": "data/raw/auto_download_parquets",
    "hf_token": None,                        # set to your HF token string if you need to access gated datasets
    "probe_max_attempts": 1000,              # for fallback probing
    # URL templates tried in order if the APIs gave no URLs. Each template is
    # filled with str.format() using the fields repo, year, month, idx and
    # (only for templates containing "{total") total.
    "probe_patterns": [
        # Pattern A: common "train-00000-of-00066.parquet" style
        "https://huggingface.co/datasets/{repo}/resolve/main/data/year={year}/month={month}/train-{idx:05d}-of-{total:05d}.parquet",
        # Pattern B: some datasets use plain shard names
        "https://huggingface.co/datasets/{repo}/resolve/main/data/year={year}/month={month}/train-{idx:05d}.parquet",
        # Pattern C: fall back to zero-padded 4-digit name (e.g. "0000.parquet", "0012.parquet").
        # A format spec is required here: a literal "000" prefix would produce
        # 5-digit names once idx reaches 10.
        "https://huggingface.co/datasets/{repo}/resolve/main/data/year={year}/month={month}/{idx:04d}.parquet",
    ],
}
# ---- end config ----

## Helper Functions

The following cells define helper functions for interacting with the Hugging Face API, filtering URLs, and downloading files.

In [None]:
import os, sys, time, math, json, typing, urllib.parse, re
from pathlib import Path
from typing import List, Optional
import requests

# Helper functions
def flatten_parquet_mapping(mapping: dict) -> List[str]:
    """Collect every URL from a two-level ``{subset: {split: [urls]}}`` mapping.

    Args:
        mapping: Parsed JSON body of the Hub ``/api/datasets/.../parquet``
            response. Anything that is not a dict yields an empty list.

    Returns:
        All list items found under ``mapping[subset][split]``, in iteration
        order. Subset values that are not dicts and split values that are not
        lists are ignored.
    """
    if not isinstance(mapping, dict):
        return []
    return [
        url
        for per_subset in mapping.values()
        if isinstance(per_subset, dict)
        for per_split in per_subset.values()
        if isinstance(per_split, list)
        for url in per_split
    ]

def get_urls_from_hub_api(repo: str) -> List[str]:
    """Fetch parquet URLs from ``https://huggingface.co/api/datasets/{repo}/parquet``.

    Best-effort: a non-200 status or any exception raised while requesting
    or parsing the response results in an empty list rather than an error.

    Args:
        repo: Hugging Face dataset repo id.

    Returns:
        The flattened list of URLs from the response, or ``[]`` on failure.
    """
    try:
        endpoint = f"https://huggingface.co/api/datasets/{repo}/parquet"
        response = requests.get(endpoint, headers=hf_headers, timeout=30)
        if response.status_code == 200:
            return flatten_parquet_mapping(response.json())
    except Exception:
        # Deliberately swallowed: the caller falls through to other sources.
        pass
    return []

def get_urls_from_dataset_viewer(repo: str) -> List[str]:
    """Fetch parquet URLs from ``https://datasets-server.huggingface.co/parquet``.

    The repo id is sent as the ``dataset`` query parameter. Best-effort: a
    non-200 status or any exception results in an empty list.

    Args:
        repo: Hugging Face dataset repo id.

    Returns:
        The truthy ``url`` values of the entries under ``parquet_files`` in
        the JSON response, or ``[]`` on failure.
    """
    try:
        response = requests.get(
            "https://datasets-server.huggingface.co/parquet",
            headers=hf_headers,
            params={"dataset": repo},
            timeout=30,
        )
        if response.status_code == 200:
            payload = response.json()
            links = []
            for item in payload.get("parquet_files", []):
                link = item.get("url")
                if link:
                    links.append(link)
            return links
    except Exception:
        # Deliberately swallowed: the caller falls through to other sources.
        pass
    return []

def filter_urls_for_month(urls: List[str], year: str, month_padded: str) -> List[str]:
    """Return the URLs belonging to one ``year=/month=`` partition, ordered by shard.

    Matching is done on the percent-decoded URL, so ``year%3D2025/month%3D07``
    is recognised. The month matches with or without leading zeros
    (``month=7`` and ``month=07``) but must not be followed by another digit,
    so month 1 does not pick up months 10, 11 or 12.

    Args:
        urls: Candidate parquet URLs (possibly percent-encoded).
        year: Year of the partition to keep (str or int).
        month_padded: Month of the partition to keep, e.g. ``"07"``. If it is
            not numeric it is matched literally.

    Returns:
        The matching URLs, unchanged, sorted by shard index. For
        ``train-N-of-M.parquet`` names the index is ``N``; otherwise it is the
        number immediately before ``.parquet``. URLs with no recognisable
        index sort last. Ties keep their input order.
    """
    try:
        month_expr = f"0*{int(month_padded)}"
    except (TypeError, ValueError):
        month_expr = re.escape(str(month_padded))
    partition_re = re.compile(
        rf"year={re.escape(str(year))}/month={month_expr}(?!\d)"
    )
    # The "N-of-M" form must be tried first: the trailing-digits pattern would
    # otherwise capture M (the shard total) and give every shard the same key.
    named_shard_re = re.compile(r"train-(\d{1,5})-of-(\d{1,5})\.parquet")
    plain_shard_re = re.compile(r"(\d{1,5})\.parquet$")

    def shard_index(decoded: str) -> int:
        m = named_shard_re.search(decoded) or plain_shard_re.search(decoded)
        return int(m.group(1)) if m else 10**9

    matching = []
    for u in urls:
        decoded = urllib.parse.unquote(u)
        if partition_re.search(decoded):
            matching.append((shard_index(decoded), u))
    # list.sort is stable, so equal indices keep their original order.
    matching.sort(key=lambda pair: pair[0])
    return [u for _, u in matching]

def probe_fallback_urls(repo: str, year: str, month: str, max_attempts: int, patterns: List[str]) -> List[str]:
    """Discover shard URLs by probing URL templates when the APIs gave nothing.

    Templates are tried in order. For each one, shard indices are probed from
    0 upward with ``try_head`` until the first URL that does not check out;
    the first template (and, where applicable, shard total) that yields at
    least one existing URL wins and its URLs are returned.

    Templates containing ``{total`` are additionally tried with every shard
    total from 1 to 200. For those, indices are limited to ``0..total-1``,
    since an index at or beyond the total is not probed.

    Args:
        repo: Hugging Face dataset repo id, substituted for ``{repo}``.
        year: Value substituted for ``{year}``.
        month: Value substituted for ``{month}``.
        max_attempts: Upper bound on the number of indices probed per
            template (and per total guess).
        patterns: ``str.format`` templates using ``repo``, ``year``,
            ``month``, ``idx`` and optionally ``total``.

    Returns:
        The consecutive run of existing URLs starting at index 0 for the
        first successful template, or ``[]`` if none matched.
    """
    for pattern in patterns:
        # Templates without a total are probed once; None marks that case.
        total_guesses = range(1, 201) if "{total" in pattern else (None,)
        for total_guess in total_guesses:
            fields = {"repo": repo, "year": year, "month": month}
            if total_guess is None:
                limit = max_attempts
            else:
                fields["total"] = total_guess
                limit = min(max_attempts, total_guess)
            found = []
            for idx in range(limit):
                url = pattern.format(idx=idx, **fields)
                if not try_head(url):
                    break
                found.append(url)
            if found:
                return found
    return []

def try_head(url: str, timeout: int = 20) -> bool:
    """Check whether ``url`` exists without downloading its body.

    Issues a streaming GET, reads a single byte on a 200 response, and closes
    the connection. The response is used as a context manager so the
    connection is released even if reading that byte fails.

    Args:
        url: URL to check.
        timeout: Timeout in seconds passed to ``requests.get``.

    Returns:
        True if the server answered 200 and one byte could be read; False for
        any other status or if any exception occurred.
    """
    try:
        with requests.get(url, headers=hf_headers, stream=True, timeout=timeout) as r:
            if r.status_code != 200:
                return False
            r.raw.read(1)
            return True
    except Exception:
        # Best-effort probe: any failure is treated as "not available".
        return False

def download_file(url: str, dest: Path, chunk_size: int = 1024*32) -> bool:
    """Stream ``url`` to ``dest``; return True on success, False on 404 or error.

    The body is written to a sibling ``<name>.part`` file and only moved onto
    ``dest`` once the transfer has completed. An interrupted download
    therefore never leaves a truncated file at ``dest`` that a later
    existence check could mistake for a finished one.

    Args:
        url: URL to download.
        dest: Final path of the downloaded file.
        chunk_size: Size in bytes of the chunks requested from
            ``iter_content``.

    Returns:
        True if the file was fully written to ``dest``. False if the server
        answered 404, or if any exception occurred (the error is printed).
    """
    dest = Path(dest)
    partial = dest.with_name(dest.name + ".part")
    try:
        # The context manager closes the response on every exit path.
        with requests.get(url, headers=hf_headers, stream=True, timeout=60) as r:
            if r.status_code == 404:
                return False
            r.raise_for_status()
            with open(partial, "wb") as f:
                for chunk in r.iter_content(chunk_size=chunk_size):
                    if chunk:
                        f.write(chunk)
        os.replace(partial, dest)
        return True
    except Exception as e:
        # Remove whatever was written so far; the temp file may not exist.
        try:
            partial.unlink()
        except OSError:
            pass
        print(f"  download error: {e}")
        return False

## Main Execution Flow

The following cell contains the main logic for querying parquet URLs, filtering them, and downloading the files.

In [None]:
# Main flow
#
# Steps: (1) ask the Hub API for parquet URLs, (2) fall back to the
# dataset-viewer endpoint, (3) fall back to probing URL patterns, then
# download up to `max_parquets` of the discovered files into `out_dir`.

# Derive the working values from `config`. The helper functions read
# `hf_headers` as a module-level global, so it must be set before any of
# them is called.
repo = config["repo"]
year = str(config["year"]).strip()
month_padded = f"{int(config['month']):02d}"  # "7" -> "07"
max_parquets = config.get("max_parquets")
out_dir = Path(config["output_dir"])
out_dir.mkdir(parents=True, exist_ok=True)
hf_headers = (
    {"Authorization": f"Bearer {config['hf_token']}"} if config.get("hf_token") else {}
)

print("1) Querying Hub API for parquet URLs...")
urls = get_urls_from_hub_api(repo)
if urls:
    print(f"  Hub API returned {len(urls)} total parquet URLs (unfiltered).")
else:
    print("  Hub API returned nothing (or failed).")

filtered = filter_urls_for_month(urls, year, month_padded)
if filtered:
    print(f"  Found {len(filtered)} parquet URLs for {year}/{month_padded} via Hub API.")
else:
    print("2) Trying dataset-viewer endpoint...")
    urls2 = get_urls_from_dataset_viewer(repo)
    if urls2:
        print(f"  dataset-viewer returned {len(urls2)} total parquet entries.")
        filtered = filter_urls_for_month(urls2, year, month_padded)
        if filtered:
            print(f"  Found {len(filtered)} parquet URLs for {year}/{month_padded} via dataset-viewer.")

# Track explicitly whether the URL list came from probing; the emptiness of
# the Hub API result does not tell us that (the dataset-viewer may have
# supplied the URLs, or the Hub API may have returned only other months).
used_probe = False
if not filtered:
    print("3) No parquet URLs found via API; falling back to incremental probing (may be slower).")
    patterns = config.get("probe_patterns", [])
    filtered = probe_fallback_urls(repo, year, month_padded, config["probe_max_attempts"], patterns)
    used_probe = True

if not filtered:
    print("ERROR: no parquet URLs discovered for that month/year by API or fallback probing. Aborting.")
    sys.exit(1)

if max_parquets is not None:
    filtered = filtered[:int(max_parquets)]

print(f"\nWill download {len(filtered)} file(s) into {out_dir.resolve()}\n")

downloaded_count = 0
skipped_count = 0
total_files = len(filtered)
for position, url in enumerate(filtered, start=1):
    decoded = urllib.parse.unquote(url)
    filename = Path(decoded).name
    dest = out_dir / filename
    if dest.exists():
        print(f"[{position}/{total_files}] Skipping (already exists): {filename}")
        skipped_count += 1
        continue
    print(f"[{position}/{total_files}] Downloading: {filename}")
    ok = download_file(url, dest)
    if not ok:
        print(f"  Failed to download (skipping): {url}")
        if used_probe:
            # Probed URLs are a consecutive run; a miss means the run ended.
            print("  Probe-based download hit missing file — stopping probe downloads.")
            break
        continue
    downloaded_count += 1
    time.sleep(0.5)  # brief pause between downloads

# Report fresh downloads and pre-existing files separately so the summary
# does not claim files were downloaded when they were only skipped.
print(
    f"\nDone. {downloaded_count} file(s) downloaded, "
    f"{skipped_count} already present in: {out_dir.resolve()}"
)