In [2]:
from pathlib import Path
import json
from datetime import datetime

from typing import List, Optional

import pandas as pd
import ipywidgets as widgets
from IPython.display import Markdown, display

import sys
sys.path.append('..')

from scripts.validate_sharded_dataset import validate as run_validator

In [3]:
PROJECT_ROOT = Path.cwd().resolve()
DATASET_DIR = PROJECT_ROOT / "data/processed/hest_v1_multitech_smoke"
INTERMEDIATE_DIR = None  # set to Path(...) to override the default
OUTPUT_DIR = PROJECT_ROOT / "outputs/validation_notebooks"
MAX_SPOTS_PER_SAMPLE = 0  # 0 = evaluate all spots per sample
COORD_TOL = 1.5
DEFAULT_SEED = 17

print(f"Project root: {PROJECT_ROOT}")
print(f"Dataset dir: {DATASET_DIR}")

Project root: /cpfs01/projects-HDD/cfff-afe2df89e32e_HDD/jjh_19301050235/git_repo/Spatial-Clip/notebooks
Dataset dir: /cpfs01/projects-HDD/cfff-afe2df89e32e_HDD/jjh_19301050235/git_repo/Spatial-Clip/notebooks/data/processed/hest_v1_multitech_smoke


In [5]:
def load_sample_ids(dataset_dir: Path) -> List[str]:
    dataset_dir = Path(dataset_dir)
    manifest_path = dataset_dir / "manifest.json"
    if manifest_path.exists():
        manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
        sample_dirs = manifest.get("outputs", {}).get("sample_dirs") or []
        if not sample_dirs:
            sample_dirs = [p.name for p in dataset_dir.iterdir() if p.is_dir()]
    else:
        sample_dirs = [p.name for p in dataset_dir.iterdir() if p.is_dir()]
    return sorted(sample_dirs)


def run_validation_job(
    dataset_dir: Path,
    output_dir: Path,
    samples: Optional[List[str]],
    max_spots: int,
    coord_tol: float,
    intermediate_dir: Optional[Path],
    seed: int,
) -> Path:
    dataset_dir = Path(dataset_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
    report_path = output_dir / f"{dataset_dir.name}_{timestamp}.json"
    sample_args = list(samples) if samples else None
    run_validator(
        dataset_dir=dataset_dir,
        intermediate_dir=Path(intermediate_dir) if intermediate_dir else None,
        max_spots_per_sample=max_spots,
        coordinate_tolerance=coord_tol,
        samples=sample_args,
        seed=seed,
        output_path=report_path,
    )
    return report_path


def load_report(report_path: Path):
    report_path = Path(report_path)
    data = json.loads(report_path.read_text(encoding="utf-8"))
    per_sample = pd.DataFrame(data.get("per_sample", []))
    return data, per_sample


def display_summary(summary: dict, per_sample: pd.DataFrame) -> None:
    header = f"""### Summary for {summary.get('dataset_key')}\n\n
Spots evaluated: {summary.get('spots_evaluated')} / {summary.get('total_spots_in_adata')}\n\n
Coordinate mismatches: {summary.get('coordinate_mismatches')}\n\n
Gene mismatches: {summary.get('gene_failures')}\n\n"""
    display(Markdown(header))
    if per_sample.empty:
        display(Markdown("No per-sample entries found in the report."))
        return
    important_cols = [
        "sample_id",
        "technology",
        "spots_evaluated",
        "coordinate_mismatches",
        "missing_reference_coords",
        "gene_failures",
    ]
    available_cols = [c for c in important_cols if c in per_sample.columns]
    display(per_sample[available_cols].sort_values(by="coordinate_mismatches", ascending=False))

In [7]:
if DATASET_DIR.exists():
    available_samples = load_sample_ids(DATASET_DIR)
else:
    available_samples = []
    print(f"Warning: Dataset directory {DATASET_DIR} does not exist. No samples loaded.")

sample_picker = widgets.SelectMultiple(
    options=available_samples,
    description="Samples",
    layout=widgets.Layout(width="50%", height="260px"),
)
max_spots_input = widgets.IntText(value=MAX_SPOTS_PER_SAMPLE, description="Max spots")
run_button = widgets.Button(description="Run validation", button_style="success")
output_area = widgets.Output()

def on_run_clicked(_):
    output_area.clear_output()
    with output_area:
        selected = list(sample_picker.value)
        report_path = run_validation_job(
            dataset_dir=DATASET_DIR,
            output_dir=OUTPUT_DIR,
            samples=selected if selected else None,
            max_spots=max_spots_input.value,
            coord_tol=COORD_TOL,
            intermediate_dir=INTERMEDIATE_DIR,
            seed=DEFAULT_SEED,
        )
        summary, per_sample_df = load_report(report_path)
        display_summary(summary, per_sample_df)
        display(Markdown(f"Report saved to `{report_path}`"))

run_button.on_click(on_run_clicked)

display(
    Markdown("Select samples (leave empty to validate every sample in the dataset)."),
    sample_picker,
    max_spots_input,
    run_button,
    output_area,
)



Select samples (leave empty to validate every sample in the dataset).

SelectMultiple(description='Samples', layout=Layout(height='260px', width='50%'), options=(), value=())

IntText(value=0, description='Max spots')

Button(button_style='success', description='Run validation', style=ButtonStyle())

Output()