In [None]:
#!/usr/bin/env python3
"""
Clean checkpoint folders for EMA-Predict experiments.

Keep only:
- model_iter{ITER}_best.pth
- model_iter{ITER}_last.pth

Remove other .pth files in:
folder_output/noise_*/alpha_*/iteration_*/checkpoints/

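Usage:
- Set ls_folder_output in __main__ and run the script.
- Check dry_run in __main__ first: True only logs the files that would be
  removed; False deletes them permanently.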
"""

from __future__ import annotations

import logging
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, List, Tuple


# Matches an iteration directory name such as "iteration_3"; group 1 is the index.
ITER_DIR_PATTERN = re.compile(r"^iteration_(\d+)$")
# str.format template (fields: iter, suffix) for the checkpoint filenames that
# are kept, e.g. "model_iter3_best.pth" and "model_iter3_last.pth".
KEEP_PATTERN_TEMPLATE = "model_iter{iter}_{suffix}.pth"


@dataclass(frozen=True)
class CleanConfig:
    """Configuration for cleaning checkpoint files.

    Instances are immutable (``frozen=True``), although the ``folder_outputs``
    list object itself can still be mutated in place.
    """
    # Root folders to scan; each is expected to contain
    # noise_*/alpha_*/iteration_*/checkpoints/ subdirectories.
    folder_outputs: List[Path]
    # When True, removals are only logged ("[DRY-RUN] Remove: ...") and no file
    # is deleted. Defaults to the safe setting.
    dry_run: bool = True
    # Passed to setup_logging in __main__: True -> INFO level (one log line per
    # file), False -> WARNING and above only.
    verbose: bool = True


def setup_logging(verbose: bool) -> None:
    """Configure root logging for the script.

    Args:
        verbose: If True, show INFO messages (including one line per removed
            file); otherwise only WARNING and above.

    ``logging.basicConfig`` does nothing when the root logger already has
    handlers (for example when this notebook cell is re-run, or under a test
    runner). The root level is therefore also set explicitly, so that
    ``verbose`` takes effect on every call.
    """
    level = logging.INFO if verbose else logging.WARNING
    logging.basicConfig(
        level=level,
        format="%(levelname)s - %(message)s",
    )
    # basicConfig may have been a no-op; apply the requested level regardless.
    logging.getLogger().setLevel(level)


def find_checkpoints_dirs(folder_output: Path) -> List[Path]:
    """
    Find all checkpoints directories under folder_output.

    Expected structure:
      folder_output/noise_*/alpha_*/iteration_*/checkpoints/

    Args:
        folder_output: Root folder to scan.

    Returns:
        The matching directories sorted by path, so that processing order and
        log output are deterministic (Path.glob yields entries in no guaranteed
        order). Returns an empty list, after logging a warning, if
        folder_output does not exist.
    """
    if not folder_output.exists():
        logging.warning("Folder does not exist: %s", folder_output)
        return []

    matches = folder_output.glob("noise_*/alpha_*/iteration_*/checkpoints")
    # A regular file named "checkpoints" also matches the pattern; skip it.
    return sorted(p for p in matches if p.is_dir())


def parse_iteration_from_parent(checkpoints_dir: Path) -> int | None:
    """
    Extract k from the parent directory name ``iteration_{k}``.

    Returns None when the parent's name does not have that form.
    """
    dir_name = checkpoints_dir.parent.name
    found = ITER_DIR_PATTERN.match(dir_name)
    return None if found is None else int(found.group(1))


def compute_keep_filenames(iter_idx: int) -> Tuple[str, str]:
    """Build the (best, last) checkpoint filenames that survive cleaning.

    Example: iter_idx=3 -> ("model_iter3_best.pth", "model_iter3_last.pth").
    """
    render = KEEP_PATTERN_TEMPLATE.format
    return render(iter=iter_idx, suffix="best"), render(iter=iter_idx, suffix="last")


def list_pth_files(checkpoints_dir: Path) -> List[Path]:
    """Return the regular ``*.pth`` files located directly in *checkpoints_dir*.

    Subdirectories are not searched, and directories whose names happen to end
    in ``.pth`` are skipped.
    """
    candidates = checkpoints_dir.glob("*.pth")
    return list(filter(Path.is_file, candidates))


def clean_checkpoints_dir(checkpoints_dir: Path, dry_run: bool) -> Tuple[int, int, List[str]]:
    """
    Clean a single checkpoints directory.

    Keeps model_iter{k}_best.pth and model_iter{k}_last.pth, where k is parsed
    from the parent ``iteration_{k}`` folder, and removes every other ``*.pth``
    file directly inside checkpoints_dir.

    Args:
        checkpoints_dir: A ``.../iteration_{k}/checkpoints`` directory.
        dry_run: If True, only log the files that would be removed.

    Returns:
      (kept_count, removed_count, warnings)
      removed_count counts files actually deleted (or, in dry-run mode, files
      that would be deleted). A failed deletion is not counted and is
      reported in warnings instead.
    """
    warnings: List[str] = []
    iter_idx = parse_iteration_from_parent(checkpoints_dir)

    if iter_idx is None:
        warnings.append(f"Cannot parse iteration from: {checkpoints_dir}")
        return 0, 0, warnings

    # Ordered (best, last) tuple -> deterministic warning order; a set iterates
    # in hash order, which varies between runs for str. The set is for lookups.
    keep_names = compute_keep_filenames(iter_idx)
    keep_set = set(keep_names)

    # Warn if expected keep files do not exist.
    # NOTE(review): the other *.pth files are still removed in that case, which
    # can leave this iteration with no checkpoint at all.
    for expected in keep_names:
        expected_path = checkpoints_dir / expected
        if not expected_path.exists():
            warnings.append(f"Missing expected file: {expected_path}")

    kept = 0
    removed = 0
    for fpath in list_pth_files(checkpoints_dir):
        if fpath.name in keep_set:
            kept += 1
            continue

        # Remove other .pth files
        if dry_run:
            logging.info("[DRY-RUN] Remove: %s", fpath)
            removed += 1
            continue

        try:
            fpath.unlink()
        except OSError as exc:
            # The file is still on disk, so it must not count as removed.
            warnings.append(f"Failed to remove {fpath}: {exc}")
        else:
            removed += 1
            logging.info("Removed: %s", fpath)

    return kept, removed, warnings


def clean_folder_output(folder_output: Path, dry_run: bool) -> None:
    """
    Clean every checkpoints directory found under one output root.

    Logs a per-root summary (kept/removed totals), followed by any warnings
    collected from the individual directories.

    Args:
        folder_output: Root folder to scan.
        dry_run: Forwarded to clean_checkpoints_dir; when True nothing is deleted.
    """
    dirs = find_checkpoints_dirs(folder_output)
    if not dirs:
        logging.warning("No checkpoints dirs found under: %s", folder_output)
        return

    logging.info("Scanning: %s (found %d checkpoints dirs)", folder_output, len(dirs))

    kept_total = removed_total = 0
    collected: List[str] = []
    for one_dir in dirs:
        n_kept, n_removed, dir_warns = clean_checkpoints_dir(one_dir, dry_run=dry_run)
        kept_total += n_kept
        removed_total += n_removed
        collected.extend(dir_warns)

    logging.info("Done: %s | kept=%d | removed=%d", folder_output, kept_total, removed_total)

    if not collected:
        return
    logging.warning("Warnings (%d):", len(collected))
    for text in collected:
        logging.warning(" - %s", text)


def clean_many_folder_outputs(folder_outputs: Iterable[Path], dry_run: bool) -> None:
    """Run clean_folder_output on each root in *folder_outputs*, in order.

    Roots are processed one after another; an exception raised while cleaning
    one root stops the loop.
    """
    for root in folder_outputs:
        clean_folder_output(root, dry_run=dry_run)


if __name__ == "__main__":
    # -------------------------
    # INPUT SETTINGS (EDIT HERE)
    # -------------------------
    # Root folders to clean; each should contain
    # noise_*/alpha_*/iteration_*/checkpoints/ subfolders.
    # Non-existent roots are skipped with a warning.
    ls_folder_output = [
        "/mnt/c/Users/truon/learning/ptit/research/trung/M_10_01_2025/code_v2/project/notebooks/cifar10_iter_ema_noise_validation",
        "/mnt/d/code_v2/project/notebooks/cifar10_iter_ema_noise_validation",
        "/mnt/d/Cifar_backup_report",
        # r"/mnt/c/Users/truon/learning/ptit/research/trung/M_10_01_2025/code_v2/project/notebooks",
    ]

    config = CleanConfig(
        folder_outputs=[Path(p).expanduser().resolve() for p in ls_folder_output],
        # WARNING: False permanently deletes files. Set True to only log
        # which files would be removed.
        dry_run=False,
        verbose=True,
    )

    setup_logging(config.verbose)

    if not config.folder_outputs:
        logging.error("ls_folder_output is empty. Please set it in __main__.")
    else:
        clean_many_folder_outputs(config.folder_outputs, dry_run=config.dry_run)

        if config.dry_run:
            logging.info("Dry-run finished. Set dry_run=False to apply deletions.")


INFO - Scanning: /mnt/c/Users/truon/learning/ptit/research/trung/M_10_01_2025/code_v2/project/notebooks/cifar10_iter_ema_noise_validation (found 168 checkpoints dirs)
INFO - Removed: /mnt/c/Users/truon/learning/ptit/research/trung/M_10_01_2025/code_v2/project/notebooks/cifar10_iter_ema_noise_validation/noise_0.6/alpha_0.2/iteration_0/checkpoints/model_iter0_epoch10.pth
INFO - Removed: /mnt/c/Users/truon/learning/ptit/research/trung/M_10_01_2025/code_v2/project/notebooks/cifar10_iter_ema_noise_validation/noise_0.6/alpha_0.2/iteration_0/checkpoints/model_iter0_epoch15.pth
INFO - Removed: /mnt/c/Users/truon/learning/ptit/research/trung/M_10_01_2025/code_v2/project/notebooks/cifar10_iter_ema_noise_validation/noise_0.6/alpha_0.2/iteration_0/checkpoints/model_iter0_epoch20.pth
INFO - Removed: /mnt/c/Users/truon/learning/ptit/research/trung/M_10_01_2025/code_v2/project/notebooks/cifar10_iter_ema_noise_validation/noise_0.6/alpha_0.2/iteration_0/checkpoints/model_iter0_epoch25.pth
INFO - Remove