# Build Manifest

In [1]:
import boto3, s3fs, gzip, io, os, tempfile, time
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm.auto import tqdm

BUCKET      = "echodata25"
ROOT_PREFIX = "results/echo-study-2/"
DST_KEY     = "results/echo-images/all_unmasked_png_paths_2.clean.txt.gz"

THREADS       = 64           # tune 32-128 depending on instance CPU/network
STUDY_PAGE    = 1_000        # list 1 000 study prefixes per call
PNG_PAGE      = 1_000
FLUSH_LINES   = 100_000      # write to gzip every N lines

s3 = boto3.client("s3")

# ────────────────────────── step 1 ─ list study folders ──────────────────────
study_prefixes = []

paginator = s3.get_paginator("list_objects_v2")
pages = paginator.paginate(
    Bucket=BUCKET,
    Prefix=ROOT_PREFIX,
    Delimiter="/",               # <── get common prefixes (folder names)
    PaginationConfig={'PageSize': STUDY_PAGE},
)

for page in tqdm(pages, desc="study folders"):
    study_prefixes += [p["Prefix"] for p in page.get("CommonPrefixes", [])]

print(f"found {len(study_prefixes):,} study dirs")

# ───────────────── step 2 ─ list PNGs under each study in parallel ───────────
tmpf = tempfile.NamedTemporaryFile("wb", delete=False)
gz   = gzip.GzipFile(fileobj=tmpf, mode="wb", compresslevel=6)
buf  = []
lock = os.fsync                         # we just need a callable placeholder

def list_pngs(study_pref):
    paginator = s3.get_paginator("list_objects_v2")
    pages = paginator.paginate(
        Bucket=BUCKET,
        Prefix=study_pref + "unmasked/png/",
        PaginationConfig={'PageSize': PNG_PAGE},
    )
    for page in pages:
        for obj in page.get("Contents", []):
            yield obj["Key"]

def worker(study_pref):
    lines = []
    for key in list_pngs(study_pref):
        lines.append(f"s3://{BUCKET}/{key}\n")
    return "".join(lines)

t0 = time.time()
count = 0
bar = tqdm(total=len(study_prefixes), desc="studies processed")

with ThreadPoolExecutor(max_workers=THREADS) as pool:
    futures = {pool.submit(worker, p): p for p in study_prefixes}
    for fut in as_completed(futures):
        data = fut.result()
        if data:
            gz.write(data.encode())
            count += data.count("\n")
        bar.update(1)
        if count // FLUSH_LINES != (count - data.count("\n")) // FLUSH_LINES:
            # show key throughput every FLUSH_LINES
            elapsed = time.time() - t0
            speed = count / elapsed
            bar.set_postfix({"png": f"{count/1e6:.2f} M",
                             "speed": f"{speed:,.0f}/s"})

gz.close(); tmpf.close(); bar.close()
elapsed = time.time() - t0
print(f"\nlocal gzip complete  •  {count:,} PNGs  •  {elapsed/60:.1f} min")

# ───────────────── step 3 ─ single multipart upload ─────────────────────────
print("uploading …")
s3.upload_file(tmpf.name, BUCKET, DST_KEY)
os.unlink(tmpf.name)
print(f"✓ manifest at s3://{BUCKET}/{DST_KEY}")


study folders: 0it [00:00, ?it/s]

found 79,598 study dirs


studies processed:   0%|          | 0/79598 [00:00<?, ?it/s]


local gzip complete  •  5,059,180 PNGs  •  12.4 min
uploading …
✓ manifest at s3://echodata25/results/echo-images/all_unmasked_png_paths_2.clean.txt.gz


# Deduplicate

In [56]:
import pandas as pd
import os
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)
pd.set_option('display.max_colwidth', None)

In [None]:
import pandas as pd
import s3fs

PREFIX = "s3://echodata25/results/es2_preds/"      # folder with all rank-csvs
fs = s3fs.S3FileSystem(anon=False)

# ①  find every preds_rank*.csv in that prefix
paths = fs.glob(PREFIX + "preds_rank*.csv")
print(f"found {len(paths)} files")

# ②  load each CSV into a list of DataFrames
dfs = [
    pd.read_csv(
        fs.open(p, "rb"),
        dtype={                                   # make sure probability cols stay float
            "quality": "float32", "salience": "float32",
            **{f"p_{v}": "float32" for v in [
                "a2c","a3c","a4c","a5c","plax","tee","exclude",
                "psax-av","psax-mv","psax-ap","psax-pm"]},
        },
    )
    for p in paths
]

# ③  concatenate and reset the index
es2_done = pd.concat(dfs, ignore_index=True)
print(es2_done.shape)

In [58]:
print(es2_done.shape)

(2450725, 16)


In [57]:
es2_done.head()

Unnamed: 0,png_uri,mp4_uri,pred_view,quality,salience,p_a2c,p_a3c,p_a4c,p_a5c,p_plax,p_tee,p_exclude,p_psax-av,p_psax-mv,p_psax-ap,p_psax-pm
0,s3://echodata25/results/echo-study-2/1.2.276.0.7230010.3.1.2.1667523124.1.1724119076.370729/unmasked/png/1.2.276.0.7230010.3.1.4.895693665.1.1724119344.1004303.png,s3://echodata25/results/echo-study-2/1.2.276.0.7230010.3.1.2.1667523124.1.1724119076.370729/1.2.276.0.7230010.3.1.4.895693665.1.1724119344.1004303.mp4,a4c,0.0981,0.460436,0.285156,0.009262,0.615723,1.1e-05,0.00753,0.001049,0.081482,0.0,0.0,0.0,0.0
1,s3://echodata25/results/echo-study-2/1.2.276.0.7230010.3.1.2.1667523124.1.1724119076.370729/unmasked/png/1.2.276.0.7230010.3.1.4.1667523124.1.1724119579.372021.png,s3://echodata25/results/echo-study-2/1.2.276.0.7230010.3.1.2.1667523124.1.1724119076.370729/1.2.276.0.7230010.3.1.4.1667523124.1.1724119579.372021.mp4,exclude,0.078298,0.615481,0.003279,1.1e-05,0.012772,3e-06,0.001054,0.137207,0.845703,4e-06,7e-06,3e-06,0.000182
2,s3://echodata25/results/echo-study-2/1.2.276.0.7230010.3.1.2.1667523124.1.1724092694.307824/unmasked/png/1.2.276.0.7230010.3.1.4.1667523124.1.1724094344.311721.png,s3://echodata25/results/echo-study-2/1.2.276.0.7230010.3.1.2.1667523124.1.1724092694.307824/1.2.276.0.7230010.3.1.4.1667523124.1.1724094344.311721.mp4,tee,0.040586,0.698504,6e-06,0.0,8e-06,1e-06,1.4e-05,0.980469,0.019348,0.0,0.0,0.0,0.0
3,s3://echodata25/results/echo-study-2/1.2.276.0.7230010.3.1.2.1667523124.1.1723914896.191703/unmasked/png/1.2.276.0.7230010.3.1.4.1667523124.1.1723914920.191729.png,s3://echodata25/results/echo-study-2/1.2.276.0.7230010.3.1.2.1667523124.1.1723914896.191703/1.2.276.0.7230010.3.1.4.1667523124.1.1723914920.191729.mp4,psax-pm,0.07811,0.70566,0.0,0.0,0.0,0.0,0.0,0.0,0.0,6e-05,0.000371,0.024765,0.974609
4,s3://echodata25/results/echo-study-2/1.2.276.0.7230010.3.1.2.1667523124.1.1724124498.384801/unmasked/png/1.2.276.0.7230010.3.1.4.1667523124.1.1724124531.384887.png,s3://echodata25/results/echo-study-2/1.2.276.0.7230010.3.1.2.1667523124.1.1724124498.384801/1.2.276.0.7230010.3.1.4.1667523124.1.1724124531.384887.mp4,exclude,0.127712,0.72806,0.014061,2e-05,0.000137,9e-06,4e-06,0.000183,0.985352,0.0,0.0,0.0,0.0


In [60]:
import boto3, gzip, io, os, tempfile, time
import pandas as pd
from tqdm.auto import tqdm

# ──────────────────────────── config ────────────────────────────────
BUCKET   = "echodata25"
SRC_KEY  = "results/echo-images/all_unmasked_png_paths_2.clean.txt.gz"
DST_KEY  = "results/echo-images/all_unmasked_png_paths_2.clean.dedup.txt.gz"
CHUNK    = 64 << 20            # 64 MiB read chunks for the progress pre-pass

# ───────────────────── load es2_done → python set ───────────────────
# print("loading es2_done …")
# es2_done = pd.read_parquet("es2_done.parquet", columns=["png_uri"])   # adjust loader if needed
done_set = set(es2_done["png_uri"].tolist())
print(f"→ {len(done_set):,} png_uri marked as complete")

# ───────────────────── establish S3 client / helpers ────────────────
s3 = boto3.client("s3")

def open_gz_from_s3(bucket: str, key: str) -> gzip.GzipFile:
    """Return a seekable gzip-file object backed by an in-memory buffer."""
    obj = s3.get_object(Bucket=bucket, Key=key)
    return gzip.GzipFile(fileobj=obj["Body"])

# ────────────── optional pre-pass to estimate total lines ───────────
print("scanning manifest once to size the progress bar …")
total_lines = 0
with open_gz_from_s3(BUCKET, SRC_KEY) as gz:
    while chunk := gz.read(CHUNK):
        total_lines += chunk.count(b"\n")
print(f"manifest contains ≈ {total_lines:,} lines")

# ────────────── stream, filter, write to tmp gzip locally ───────────
tmpf = tempfile.NamedTemporaryFile("wb", delete=False)
out_gz = gzip.GzipFile(fileobj=tmpf, mode="wb", compresslevel=6)

kept = removed = 0
t0   = time.time()

with open_gz_from_s3(BUCKET, SRC_KEY) as src_gz, \
     tqdm(total=total_lines, desc="filtering", unit="png", dynamic_ncols=True) as bar:
    for raw in src_gz:
        uri = raw.decode().rstrip("\n")
        if uri not in done_set:
            out_gz.write(raw)
            kept += 1
        else:
            removed += 1
        bar.update()

out_gz.close(); tmpf.close()
elapsed = time.time() - t0
print(f"kept {kept:,}  |  removed {removed:,}  •  {elapsed/60:.1f} min")

# ────────────────────── single multipart upload ────────────────────
print("uploading deduplicated manifest …")
s3.upload_file(tmpf.name, BUCKET, DST_KEY)
os.unlink(tmpf.name)
print(f"✓ uploaded to s3://{BUCKET}/{DST_KEY}")


→ 2,450,725 png_uri marked as complete
scanning manifest once to size the progress bar …
manifest contains ≈ 5,059,180 lines


filtering:   0%|          | 0/5059180 [00:00<?, ?png/s]

kept 2,608,455  |  removed 2,450,725  •  0.2 min
uploading deduplicated manifest …
✓ uploaded to s3://echodata25/results/echo-images/all_unmasked_png_paths_2.clean.dedup.txt.gz


# Classify

In [None]:
%%writefile batch_classify.py

In [64]:
%%writefile batch_classify.py
#!/usr/bin/env python3
# ============================================================================
#  batch_classify.py · 2025‑05‑14 (refreshed)
#  • multi‑GPU · FP16 · auto batch‑halve on OOM / IndexMath overflow
#  • strict checkpoint check ✔︎ + per‑rank skip logs + debug limit/dry‑run
#  • *NEW* crisp tqdm bars — one per rank, live ETA & throughput
# ============================================================================

"""Key UI changes
────────────────────────────────────────────────────────────────────────────────
* Each CUDA rank gets its own tqdm line (position=rank) so bars never clash.
* We pre‑count how many PNGs this rank will process. With a known total tqdm
  can compute ETA. The manifest scan adds <1 s even for multi‑million files.
* Progress bar shows imgs/s and ETA using a custom format string.
* Bar refreshes every second (mininterval).
"""

import argparse, csv, gzip, io, itertools, logging, os, re, time
from collections import Counter
from math import tanh
from typing import Iterable, Tuple, Optional

import boto3, s3fs
from botocore.config import Config as BotoCfg
from botocore.exceptions import ClientError, IncompleteReadError
from pathlib import Path
import cv2, numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
import torch.distributed as dist
from PIL import Image
from torchvision import models, transforms
from tqdm.auto import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed

# ───────────────────────────── constants ──────────────────────────────
VIEW = [
    "a2c","a3c","a4c","a5c","plax","tee","exclude",
    "psax-av","psax-mv","psax-ap","psax-pm"
]

PNG_ROW_RE = re.compile(
    r"""^s3://[^/]+/
        (?P<key>
            results/echo-study(?:-[12])?/        # echo-study/, echo-study-1/2/
            .+?/unmasked/png/
            .+\.png$
        )""",
    re.IGNORECASE | re.VERBOSE,
)

tf = transforms.Compose([
    transforms.Resize(224), transforms.CenterCrop(224), transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
])

# ───────────────────── placeholders lifted to module scope ────────────
ctr: Counter = Counter()

def record_skip(kind: str, msg: str):
    ...  # rebound in main()

# ───────────────────────────── helpers ────────────────────────────────

def quality_score(img_bgr: np.ndarray) -> float:
    gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    sharp = cv2.Laplacian(gray, cv2.CV_64F).var()
    bright = gray.mean() / 255.0
    return tanh(0.004 * sharp) * bright

def png_to_mp4(png_key: str) -> str:
    return png_key.replace("/unmasked/png/", "/")[:-4] + ".mp4"

def chunks(it: Iterable, n: int):
    while True:
        batch = list(itertools.islice(it, n))
        if not batch:
            break
        yield batch

def retry_open(fs: s3fs.S3FileSystem, path: str, tries: int = 3):
    delay = 1.0
    for attempt in range(tries):
        try:
            return fs.open(path, "rb")
        except ClientError as e:
            if attempt == tries - 1:
                raise
            if e.response["Error"]["Code"] in ("500", "503", "InternalError"):
                time.sleep(delay)
                delay *= 1.5
            else:
                raise

# ───────────────────────────── model ──────────────────────────────────
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        base = models.efficientnet_b2(weights=None)
        base.classifier = nn.Identity()
        self.b = base
        f = 1408
        self.vb = nn.Linear(f, 2)
        self.vo = nn.Linear(f, 7)
        self.vs = nn.Linear(f, 4)

    def forward(self, x):
        f = self.b(x)
        pb, po, ps = (F.softmax(h(f), 1) for h in (self.vb, self.vo, self.vs))
        out = x.new_zeros(x.size(0), 11)
        out[:, :7] = pb[:, :1] * po
        out[:, 7:] = pb[:, 1:] * ps
        return out

# ─────────────────────── manifest utilities ──────────────────────────

def open_body(s3, uri: str):
    """Download manifest once per rank → /tmp/rankX.manifest.gz, then open."""
    bucket, key = uri[5:].split("/", 1)
    local = Path(f"/tmp/manifest_rank{os.getenv('LOCAL_RANK')}.gz")
    if not local.exists():
        logging.info("rank%s downloading manifest → %s", os.getenv("LOCAL_RANK"), local)
        with s3fs.S3FileSystem(anon=False).open(f"s3://{bucket}/{key}", "rb") as src, local.open("wb") as dst:
            for chunk in iter(lambda: src.read(8 << 20), b""):
                dst.write(chunk)
    fh = local.open("rb")
    return gzip.GzipFile(fileobj=fh) if key.endswith(".gz") else fh


def iter_manifest(s3, uri: str, world: int, rank: int, limit: Optional[int]):
    seen = 0
    for idx, raw in enumerate(open_body(s3, uri)):
        if idx % world != rank:
            continue
        if limit is not None and seen >= limit:
            break
        line = raw.strip().decode()
        m = PNG_ROW_RE.match(line)
        if m:
            yield m.group("key")
            seen += 1
        else:
            ctr["regex"] += 1
            record_skip("REGEX", line)

# ───────────────────────────── main ──────────────────────────────────

def count_samples(s3, uri: str, world: int, rank: int, limit: Optional[int]) -> int:
    """Fast pass over manifest to know how many PNGs this rank owns."""
    n = 0
    for _ in iter_manifest(s3, uri, world, rank, limit):
        n += 1
    return n


def main(a):
    log_level = os.getenv("LOG_LEVEL", "INFO").upper()
    logging.basicConfig(
        format="%(asctime)s %(levelname)s │ %(message)s",
        datefmt="%H:%M:%S",
        level=getattr(logging, log_level, logging.INFO),
        force=True,
    )

    dist.init_process_group("nccl")
    rank = int(os.environ["LOCAL_RANK"])
    world = dist.get_world_size()
    dev_id = rank % torch.cuda.device_count()
    torch.cuda.set_device(dev_id)
    device = torch.device(f"cuda:{dev_id}")

    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.benchmark = True

    MAX_WORKERS = int(os.getenv("MAX_WORKERS", "256"))
    s3 = boto3.client("s3", config=BotoCfg(max_pool_connections=MAX_WORKERS))
    fs = s3fs.S3FileSystem(
        anon=False,
        default_block_size=8 << 20,
        default_fill_cache=False,
        config_kwargs={"max_pool_connections": MAX_WORKERS},
    )

    # ───── per-rank skip file ─────
    skip_path = f"/opt/ml/processing/output/skip_rank{rank}.txt.gz"
    skip_fh = gzip.open(skip_path, "wt")

    def _record(kind: str, msg: str):
        skip_fh.write(f"{kind}\t{msg}\n")

    globals()["record_skip"] = _record  # make visible globally

    # ───── dry-run? just count regex matches, exit early ─────
    if a.dry_run:
        for _ in iter_manifest(s3, a.manifest_s3, world, rank, a.limit):
            pass
        skip_fh.close()
        logging.info("DRY-RUN finished – regex %d", ctr["regex"])
        return

    # ───── progress bar prep  ─────
    total_imgs = count_samples(s3, a.manifest_s3, world, rank, a.limit)
    logging.info("rank%d will process ~%s imgs", rank, f"{total_imgs:,d}" if total_imgs else "?")

    bar = tqdm(
        total=total_imgs or None,
        desc=f"rank {rank}",
        unit="img",
        position=rank,
        dynamic_ncols=True,
        smoothing=0.1,
        mininterval=1.0,
        bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed} < {remaining}, {rate_fmt}]",
    )

    # ───── model ─────
    net = Net().to(device).eval()
    with fs.open(a.model_s3, "rb") as f:
        state = torch.load(io.BytesIO(f.read()), map_location="cpu")
    ck = net.load_state_dict(state, strict=False)
    if ck.missing_keys or ck.unexpected_keys:
        logging.warning("‼️ checkpoint mismatch – %d missing, %d unexpected", len(ck.missing_keys), len(ck.unexpected_keys))
    else:
        logging.info("✅ checkpoint keys match perfectly")
    net.half()

    # ───── output CSV ─────
    os.makedirs("/opt/ml/processing/output", exist_ok=True)
    csv_path = f"/opt/ml/processing/output/preds_rank{rank}.csv"
    header = ["png_uri", "mp4_uri", "pred_view", "quality", "salience"] + [f"p_{v}" for v in VIEW]

    key_iter = iter_manifest(s3, a.manifest_s3, world, rank, a.limit)
    processed = 0
    t0 = time.time()

    pool = ThreadPoolExecutor(max_workers=MAX_WORKERS)

    def load_one(k: str) -> Tuple[str, torch.Tensor, float, bool]:
        try:
            with retry_open(fs, f"{a.bucket}/{k}") as f:
                arr = np.frombuffer(f.read(), np.uint8)
            img = cv2.imdecode(arr, cv2.IMREAD_COLOR)
            if img is None:
                raise ValueError("cv2.imdecode returned None")
            q = quality_score(img)
            ten = tf(Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))).half()
            return k, ten, q, False
        except ClientError as ce:
            ctr["open"] += 1
            record_skip("OPEN", f"{k}\t{ce}")
        except Exception as exc:
            ctr["decode"] += 1
            record_skip("DECODE", f"{k}\t{exc}")
            logging.debug("DECODE‑fail %s — %s", k, exc)
        return k, None, None, True

    def safe_infer(b: torch.Tensor) -> np.ndarray:
        cur = b
        while True:
            try:
                with torch.cuda.amp.autocast(), torch.no_grad():
                    return net(cur).cpu().numpy()
            except (torch.cuda.OutOfMemoryError, RuntimeError) as exc:
                if ("indexmath" not in str(exc).lower() and not isinstance(exc, torch.cuda.OutOfMemoryError)):
                    raise
                torch.cuda.empty_cache()
                cur = cur[: max(128, cur.size(0) // 2)]

    with open(csv_path, "w", newline="") as fh:
        w = csv.writer(fh)
        w.writerow(header)

        for keys in chunks(key_iter, a.batch_size):
            futs = [pool.submit(load_one, k) for k in keys]
            tens, qs, oks = [], [], []
            for fut in as_completed(futs):
                k, ten, q, bad = fut.result()
                if bad:
                    continue
                tens.append(ten)
                qs.append(q)
                oks.append(k)

            if not oks:
                continue

            probs = safe_infer(torch.stack(tens).to(device, non_blocking=True))
            for k, q, p in zip(oks, qs, probs):
                sal = 0.7 * p.max() + 0.3 * q
                w.writerow([
                    f"s3://{a.bucket}/{k}",
                    f"s3://{a.bucket}/{png_to_mp4(k)}",
                    VIEW[int(p.argmax())],
                    round(q, 6),
                    round(sal, 6),
                    *map(lambda x: round(float(x), 6), p),
                ])

            processed += len(oks)
            bar.update(len(oks))

    bar.close()
    skip_fh.close()
    elapsed = time.time() - t0
    logging.info(
        "✓ finished — %.1f min | %s OK | drops %s",
        elapsed / 60,
        f"{processed:,d}",
        ", ".join(f"{k}:{v}" for k, v in ctr.items()) or "none",
    )

# ────────────────────────── CLI wiring ────────────────────────────────
if __name__ == "__main__":
    P = argparse.ArgumentParser()
    P.add_argument("--bucket", required=True)
    P.add_argument("--manifest_s3", required=True)
    P.add_argument("--model_s3", required=True)
    P.add_argument("--batch_size", type=int, default=2048)
    P.add_argument("--limit", type=int, default=None, help="debug‑only: per‑rank cap on #pngs")
    P.add_argument("--dry_run", action="store_true", help="only parse manifest + regex stats (no decoding/infer)")
    main(P.parse_args())


Overwriting batch_classify.py


In [65]:
# ! pip install -U sagemaker

In [66]:
%%bash
cat > requirements.txt <<'REQ'
opencv-python-headless==4.11.0.86
tqdm
s3fs
REQ


In [53]:
# ! aws s3 ls s3://echodata25/results/echo-images/ | grep -i unmasked_png_paths

In [67]:
from sagemaker.pytorch import PyTorchProcessor
from sagemaker.processing import ProcessingOutput
import sagemaker, time

import os

role = sagemaker.get_execution_role()          # works in Studio/Jupyter
# N_GPUS = int(os.environ.get("SM_NUM_GPUS", "1"))
env = {"MAX_WORKERS": "256"}   # 🡐 move into the constructor

role = sagemaker.get_execution_role()

proc = PyTorchProcessor(
    framework_version="2.1",
    py_version="py310",
    role=role,
    instance_type="ml.g5.48xlarge",      # 8 × A10 G
    instance_count=1,
    volume_size_in_gb=100,
    max_runtime_in_seconds=6*60*60,
    command=["torchrun", "--nproc_per_node", "8"],
    env=env,
)

proc.run(
    code="batch_classify.py",
    arguments=[
        "--bucket","echodata25",
        "--manifest_s3","s3://echodata25/results/echo-images/all_unmasked_png_paths_2.clean.dedup.txt.gz",
        "--model_s3","s3://echodata25/results/models/view_classifier/best_f1_84.pt",
        "--batch_size","2048",           # fits with FP16
    ],
    outputs=[ProcessingOutput(
        source="/opt/ml/processing/output",
        destination="s3://echodata25/results/es2_preds_dedup",
        output_name="preds")],
    job_name=f"view-classify-{int(time.time())}",
)


print("🚀 submitted — watch the Processing-Job logs for per-rank counters.")

[34m14:57:04 INFO │ rank0 downloading manifest → /tmp/manifest_rank0.gz[0m
[34m14:57:05 INFO │ rank3 downloading manifest → /tmp/manifest_rank3.gz[0m
[34m14:57:05 INFO │ rank4 downloading manifest → /tmp/manifest_rank4.gz[0m
[34m14:57:05 INFO │ rank2 downloading manifest → /tmp/manifest_rank2.gz[0m
[34m14:57:05 INFO │ rank5 downloading manifest → /tmp/manifest_rank5.gz[0m
[34m14:57:05 INFO │ rank1 downloading manifest → /tmp/manifest_rank1.gz[0m
[34m14:57:05 INFO │ rank6 downloading manifest → /tmp/manifest_rank6.gz[0m
[34m14:57:05 INFO │ rank7 downloading manifest → /tmp/manifest_rank7.gz[0m
[34m14:57:08 INFO │ rank0 will process ~306,255 imgs[0m
[34m#015rank 0:   0%|          | 0/306255 [00:00 < ?, ?img/s]14:57:08 INFO │ rank5 will process ~306,117 imgs[0m
[34m#015rank 5:   0%|          | 0/306117 [00:00 < ?, ?img/s]#033[A#033[A#033[A#033[A#033[A14:57:08 INFO │ rank2 will process ~306,296 imgs[0m
[34m#015rank 2:   0%|          | 0/306296 [00:00 < ?, ?img/s]#033

# Validate

In [68]:
import io, gzip, s3fs, pandas as pd

PREFIX = "s3://echodata25/results/es2_preds_dedup/"
fs      = s3fs.S3FileSystem(anon=False)

In [69]:
gz_paths = fs.glob(PREFIX + "skip_rank*.txt.gz")
print("found", len(gz_paths), "gzip files")

# ── Option 1 ─ let pandas decompress ───────────────────────────
dfs1 = [
    pd.read_csv(
        fs.open(p, "rb"),         # binary handle
        compression="gzip",       # <- tell pandas it's gzipped
        dtype={"reason":"category", "raw":"string"},
    )
    for p in gz_paths
]
skips = pd.concat(dfs1, ignore_index=True)

# ── Option 2 ─ manual gunzip, then pandas ─────────────────────
dfs2 = []
for p in gz_paths:
    with fs.open(p, "rb") as f, gzip.GzipFile(fileobj=f) as gz, \
         io.TextIOWrapper(gz, encoding="utf-8") as txt:
        dfs2.append(pd.read_csv(txt,
                                dtype={"reason":"category", "raw":"string"}))

skips_manual = pd.concat(dfs2, ignore_index=True)

assert skips.equals(skips_manual)      # both give the same result
print("total skipped:", len(skips))


found 8 gzip files
total skipped: 318329


In [70]:
# ── 1.  collect key lists ──────────────────────────────────────────────
gz_keys     = fs.glob(PREFIX + "skip_rank*.txt.gz")       #  DECODE failures
plain_keys  = fs.glob(PREFIX + "regex_skip_rank*.txt")    #  REGEX drops

print(f"{len(gz_keys)} gzip, {len(plain_keys)} plain-text skip files")

# ── 2.  helper readers ─────────────────────────────────────────────────
def read_gzip(key: str) -> pd.DataFrame:
    """'reason<TAB>raw' lines inside a .gz file."""
    with fs.open(key, "rb") as f, gzip.open(f, "rt", encoding="utf-8") as g:
        recs = [ln.rstrip("\n").split("\t", 1) for ln in g if ln.strip()]
    return pd.DataFrame(recs, columns=["reason", "raw"])

def read_plain(key: str) -> pd.DataFrame:
    with fs.open(key, "rt", encoding="utf-8") as f:
        recs = [ln.rstrip("\n").split("\t", 1) for ln in f if ln.strip()]
    return pd.DataFrame(recs, columns=["reason", "raw"])

# ── 3.  load & concat ──────────────────────────────────────────────────
frames = [read_gzip(k)  for k in gz_keys] + \
         [read_plain(k) for k in plain_keys]

skips  = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame(
         columns=["reason", "raw"])

print("total skips loaded:", len(skips))

# ── 4.  filter the “harmless” ones ─────────────────────────────────────
is_mask_vid  = skips.raw.str.endswith("mask_visualization.mp4")
is_blank_dir = skips.raw.str.endswith("/unmasked/png/") | \
               skips.raw.str.endswith("/unmasked/png")

interesting   = skips.loc[~is_mask_vid & ~is_blank_dir]

print("⚠️  interesting (non-benign) skips:", len(interesting))
interesting.head()

8 gzip, 0 plain-text skip files
total skips loaded: 318337
⚠️  interesting (non-benign) skips: 1


Unnamed: 0,reason,raw
142272,DECODE,results/echo-study-2/1.2.276.0.7230010.3.1.2.1714500150.1.1724725467.768670/unmasked/png/1.2.276.0.7230010.3.1.4.895693665.1.1724726207.2852821.png\tOpenCV(4.10.0) /io/opencv/modules/imgcodecs/src/loadsave.cpp:813: error: (-215:Assertion failed) !buf.empty() in function 'imdecode_'


In [71]:
skips.info(show_counts=True)        # dtypes, non-null counts
skips.sample(5, random_state=0)     # random spot-check

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 318337 entries, 0 to 318336
Data columns (total 2 columns):
 #   Column  Non-Null Count   Dtype 
---  ------  --------------   ----- 
 0   reason  318337 non-null  object
 1   raw     318337 non-null  object
dtypes: object(2)
memory usage: 4.9+ MB


Unnamed: 0,reason,raw
81833,REGEX,s3://echodata25/results/echo-study-2/1.2.276.0.7230010.3.1.2.1667523124.1.1725102138.2025732/unmasked/png/
317468,REGEX,s3://echodata25/results/echo-study-2/1.2.276.0.7230010.3.1.2.929392183.1.1724186898.688050/unmasked/png/mask_visualization.mp4
296167,REGEX,s3://echodata25/results/echo-study-2/1.2.276.0.7230010.3.1.2.895693665.1.1725110606.3993439/unmasked/png/mask_visualization.mp4
148533,REGEX,s3://echodata25/results/echo-study-2/1.2.276.0.7230010.3.1.2.895627313.1.1724150016.959167/unmasked/png/
267543,REGEX,s3://echodata25/results/echo-study-2/1.2.276.0.7230010.3.1.2.842097970.1.1725165717.2647134/unmasked/png/


In [72]:
reason_cnt = skips['reason'].value_counts().rename_axis('reason')
display(reason_cnt.to_frame('rows'))

print("\nPercentage distribution")
print((reason_cnt / len(skips) * 100).round(2).astype(str) + ' %')

Unnamed: 0_level_0,rows
reason,Unnamed: 1_level_1
REGEX,318336
DECODE,1



Percentage distribution
reason
REGEX     100.0 %
DECODE      0.0 %
Name: count, dtype: object


In [73]:
mask_vid   = skips['raw'].str.endswith("mask_visualization.mp4")
placeholer = skips['raw'].str.endswith("/unmasked/png/")

print("mask videos     :", mask_vid.sum())
print("empty directories:", placeholer.sum())
print("≈ benign regex skips:",
      mask_vid.sum() + placeholer.sum(), "of", len(skips.query('reason == "REGEX"')))

mask videos     : 159158
empty directories: 159178
≈ benign regex skips: 318336 of 318336


# Proper MP4 Patching

In [None]:
import boto3, gzip, io, os, tempfile, time
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm.auto import tqdm

BUCKET       = "echodata25"
ROOT_PREFIX  = "results/echo-study-2/"
DST_KEY      = "results/echo-images/es2_all_mp4_paths.txt.gz"

THREADS      = 64
STUDY_PAGE   = 1_000
MP4_PAGE     = 1_000
FLUSH_LINES  = 100_000

s3 = boto3.client("s3")

# ── list study folders ───────────────────────────────────────────────────────
study_pref = []
for page in s3.get_paginator("list_objects_v2").paginate(
        Bucket=BUCKET, Prefix=ROOT_PREFIX, Delimiter="/",
        PaginationConfig={'PageSize': STUDY_PAGE}):
    study_pref += [p["Prefix"] for p in page.get("CommonPrefixes", [])]

# ── worker: list *.mp4 under a study root in one pass ───────────────────────
def list_mp4s(study_pref):
    paginator = s3.get_paginator("list_objects_v2")
    pages = paginator.paginate(
        Bucket=BUCKET, Prefix=study_pref,
        PaginationConfig={'PageSize': MP4_PAGE})
    for page in pages:
        for obj in page.get("Contents", []):
            if obj["Key"].endswith(".mp4"):
                yield f"s3://{BUCKET}/{obj['Key']}\n"

# ── build compressed manifest ───────────────────────────────────────────────
tmp = tempfile.NamedTemporaryFile("wb", delete=False)
gz  = gzip.GzipFile(fileobj=tmp, mode="wb", compresslevel=6)

count, t0 = 0, time.time()
bar = tqdm(total=len(study_pref), desc="studies")

with ThreadPoolExecutor(max_workers=THREADS) as pool:
    futures = {pool.submit(list_mp4s, p): p for p in study_pref}
    for fut in as_completed(futures):
        for line in fut.result():
            gz.write(line.encode())
            count += 1
        bar.update(1)
        if count % FLUSH_LINES == 0:
            elapsed = time.time() - t0
            bar.set_postfix({"mp4": f"{count/1e6:.2f} M",
                             "speed": f"{count/elapsed:,.0f}/s"})

gz.close(); tmp.close(); bar.close()
s3.upload_file(tmp.name, BUCKET, DST_KEY)
os.unlink(tmp.name)


In [109]:
import boto3, gzip, io, pandas as pd, s3fs, re

# ─────────────────────────── A. load dataframe ──────────────────────────────
PREFIX = "s3://echodata25/results/es2_preds_dedup/"
fs     = s3fs.S3FileSystem(anon=False)

paths  = fs.glob(PREFIX + "preds_rank*.csv")
es2    = pd.concat(
           [pd.read_csv(fs.open(p, "rb"), dtype={
               "quality":"float32","salience":"float32",
               **{f"p_{v}":"float32" for v in
                 ["a2c","a3c","a4c","a5c","plax","tee","exclude",
                  "psax-av","psax-mv","psax-ap","psax-pm"]} }
            ) for p in paths],
           ignore_index=True,
         )

# drop the bad column
es2 = es2.drop(columns="mp4_uri", errors="ignore")

# ─────────────────────────── B. extract UID from PNG ────────────────────────
# UID = stem before .png (same stem will appear before .mp4)
es2["uid"] = es2["png_uri"].str.replace(r".*/([^/]+)\.png$", r"\1", regex=True)

needed = set(es2["uid"])

# ─────────────────────────── C. stream manifest and map UIDs ────────────────
MANIFEST_KEY = "results/echo-images/es2_all_mp4_paths.txt.gz"

s3   = boto3.client("s3")
buf  = io.BytesIO()
s3.download_fileobj("echodata25", MANIFEST_KEY, buf)
buf.seek(0)

uid2mp4 = {}
with gzip.open(buf, "rt") as g:
    for line in g:
        path = line.strip()
        uid  = path.rsplit("/", 1)[-1][:-4]     # drop ".mp4"
        if uid in needed:
            uid2mp4[uid] = f"s3://{path}"       # prepend scheme

# ─────────────────────────── D. attach correct mp4_uri ──────────────────────
es2["mp4_uri"] = es2["uid"].map(uid2mp4)

missing = es2["mp4_uri"].isna().sum()
if missing:
    print(f"warning: {missing:,} UIDs were not found in the manifest")

# ─────────────────────────── E. tidy up ─────────────────────────────────────
es2 = es2.drop(columns="uid")
es2 = es2_dedup

In [110]:
import boto3, gzip, io, pandas as pd, s3fs, re

# ─────────────────────────── A. load dataframe ──────────────────────────────
PREFIX = "s3://echodata25/results/es2_preds/"
fs     = s3fs.S3FileSystem(anon=False)

paths  = fs.glob(PREFIX + "preds_rank*.csv")
es2    = pd.concat(
           [pd.read_csv(fs.open(p, "rb"), dtype={
               "quality":"float32","salience":"float32",
               **{f"p_{v}":"float32" for v in
                 ["a2c","a3c","a4c","a5c","plax","tee","exclude",
                  "psax-av","psax-mv","psax-ap","psax-pm"]} }
            ) for p in paths],
           ignore_index=True,
         )

# drop the bad column
es2 = es2.drop(columns="mp4_uri", errors="ignore")

# ─────────────────────────── B. extract UID from PNG ────────────────────────
# UID = stem before .png (same stem will appear before .mp4)
es2["uid"] = es2["png_uri"].str.replace(r".*/([^/]+)\.png$", r"\1", regex=True)

needed = set(es2["uid"])

# ─────────────────────────── C. stream manifest and map UIDs ────────────────
MANIFEST_KEY = "results/echo-images/es2_all_mp4_paths.txt.gz"

s3   = boto3.client("s3")
buf  = io.BytesIO()
s3.download_fileobj("echodata25", MANIFEST_KEY, buf)
buf.seek(0)

uid2mp4 = {}
with gzip.open(buf, "rt") as g:
    for line in g:
        path = line.strip()
        uid  = path.rsplit("/", 1)[-1][:-4]     # drop ".mp4"
        if uid in needed:
            uid2mp4[uid] = f"s3://{path}"       # prepend scheme

# ─────────────────────────── D. attach correct mp4_uri ──────────────────────
es2["mp4_uri"] = es2["uid"].map(uid2mp4)

missing = es2["mp4_uri"].isna().sum()
if missing:
    print(f"warning: {missing:,} UIDs were not found in the manifest")

# ─────────────────────────── E. tidy up ─────────────────────────────────────
es2 = es2.drop(columns="uid")
es2 = es2_done

# Recombine

In [111]:
# import pandas as pd
# import s3fs

# PREFIX = "s3://echodata25/results/es2_preds_dedup/"      # folder with all rank-csvs
# fs = s3fs.S3FileSystem(anon=False)

# # ①  find every preds_rank*.csv in that prefix
# paths = fs.glob(PREFIX + "preds_rank*.csv")
# print(f"found {len(paths)} files")

# # ②  load each CSV into a list of DataFrames
# dfs = [
#     pd.read_csv(
#         fs.open(p, "rb"),
#         dtype={                                   # make sure probability cols stay float
#             "quality": "float32", "salience": "float32",
#             **{f"p_{v}": "float32" for v in [
#                 "a2c","a3c","a4c","a5c","plax","tee","exclude",
#                 "psax-av","psax-mv","psax-ap","psax-pm"]},
#         },
#     )
#     for p in paths
# ]

# # ③  concatenate and reset the index
# es2_dedup = pd.concat(dfs, ignore_index=True)
# print(es2_dedup.shape)

In [112]:
# import pandas as pd
# import s3fs

# PREFIX = "s3://echodata25/results/es2_preds/"      # folder with all rank-csvs
# fs = s3fs.S3FileSystem(anon=False)

# # ①  find every preds_rank*.csv in that prefix
# paths = fs.glob(PREFIX + "preds_rank*.csv")
# print(f"found {len(paths)} files")

# # ②  load each CSV into a list of DataFrames
# dfs = [
#     pd.read_csv(
#         fs.open(p, "rb"),
#         dtype={                                   # make sure probability cols stay float
#             "quality": "float32", "salience": "float32",
#             **{f"p_{v}": "float32" for v in [
#                 "a2c","a3c","a4c","a5c","plax","tee","exclude",
#                 "psax-av","psax-mv","psax-ap","psax-pm"]},
#         },
#     )
#     for p in paths
# ]

# # ③  concatenate and reset the index
# es2_done = pd.concat(dfs, ignore_index=True)
# print(es2_done.shape)

In [113]:
# es2_done.head()

In [114]:
# count of overlapping png_uri values
n_overlap = es2_done['png_uri'].isin(es2_dedup['png_uri']).sum()

# share of each frame in its own df
pct_done   = n_overlap / len(es2_done)   # fraction of es2_done present in es2_dedup
pct_dedup  = n_overlap / len(es2_dedup)  # fraction of es2_dedup present in es2_done

print(pct_done)
print(pct_dedup)

0.0
0.0


# Fix MP4 Paths

In [115]:
# import boto3, pandas as pd, re, time
# from concurrent.futures import ThreadPoolExecutor, as_completed
# from tqdm.auto import tqdm

# BUCKET       = "echodata25"
# ROOT_PREFIX  = "results/echo-study-2/"      # adjust if you have echo-study-1, etc.
# THREADS      = 64
# STUDY_PAGE   = 1_000
# MP4_PAGE     = 1_000

# s3 = boto3.client("s3")

# # ────────────────────────── 1. collect study prefixes ───────────────────────
# study_pref = []
# for page in s3.get_paginator("list_objects_v2").paginate(
#         Bucket=BUCKET, Prefix=ROOT_PREFIX, Delimiter="/",
#         PaginationConfig={'PageSize': STUDY_PAGE}):
#     study_pref += [p["Prefix"] for p in page.get("CommonPrefixes", [])]

# print(f"{len(study_pref):,} studies")

# # ────────────────────────── 2. pick best series per study ───────────────────
# def count_mp4s(prefix):
#     """return (study_id, best_series_uid | None)"""
#     # study_id ends right before the trailing slash
#     study_id = prefix.rstrip("/").split("/")[-1]

#     # list top-level children of the study (Delimiter="/")
#     series_counts = {}
#     for page in s3.get_paginator("list_objects_v2").paginate(
#             Bucket=BUCKET, Prefix=prefix, Delimiter="/"):
#         for cp in page.get("CommonPrefixes", []):
#             uid = cp["Prefix"].split("/")[-2]          # folder name
#             if uid in ("png", "unmasked"):             # skip aux dirs
#                 continue

#             # count *.mp4 in that uid folder
#             n = 0
#             for mp in s3.get_paginator("list_objects_v2").paginate(
#                     Bucket=BUCKET, Prefix=cp["Prefix"],
#                     PaginationConfig={'PageSize': MP4_PAGE}):
#                 n += sum(1 for obj in mp.get("Contents", [])
#                          if obj["Key"].endswith(".mp4"))
#             if n:
#                 series_counts[uid] = n

#     best = max(series_counts, key=series_counts.get) if series_counts else None
#     return study_id, best

# series_map = {}
# t0 = time.time()
# with ThreadPoolExecutor(max_workers=THREADS) as pool:
#     futures = {pool.submit(count_mp4s, p): p for p in study_pref}
#     for fut in tqdm(as_completed(futures), total=len(futures),
#                     desc="series scan"):
#         sid, best = fut.result()
#         if best:
#             series_map[sid] = best
# elapsed = time.time() - t0
# print(f"done in {elapsed/60:.1f} min  •  {len(series_map):,} studies mapped")

In [116]:
# import pandas as pd
# import re

# # pre-compile once
# _STUDY_RE = re.compile(r"echo-study(?:-\d+)?/([^/]+)/")   # echo-study/, echo-study-1/, echo-study-23/ …

# def patch(df: pd.DataFrame) -> pd.DataFrame:
#     """
#     • Extracts the de-identified study UID from either
#         echo-study/…/, echo-study-1/…/, echo-study-2/…/  etc.
#     • Looks up that UID in `series_map` and, when found, inserts the
#       *best* series UID one level above the MP4 filename.
#     • Returns a fresh dataframe with the corrected mp4_uri column.
#     """
#     df = df.copy()

#     # ── 1. pull study id from the S3 key ──────────────────────────────────
#     df["study_id"] = df["png_uri"].str.extract(_STUDY_RE, expand=False)

#     # ── 2. build the updated mp4_uri only where we have a replacement ────
#     uid_series = df["study_id"].map(series_map)        # NaN where no match
#     mask = uid_series.notna()

#     if mask.any():
#         # split once into head / filename
#         split = df.loc[mask, "mp4_uri"].str.rsplit("/", n=1, expand=True)
#         df.loc[mask, "mp4_uri"] = (
#             split[0] + "/" + uid_series[mask] + "/" + split[1]
#         )

#     return df.drop(columns="study_id")


In [117]:
es2_all = (pd.concat([es2_done, es2_dedup], ignore_index=True)
             .drop_duplicates(subset="png_uri", keep="first"))
es2_all = patch(es2_all)

In [118]:
es2_all.shape

(3679403, 16)

In [119]:
es2_all.head()

Unnamed: 0,png_uri,mp4_uri,pred_view,quality,salience,p_a2c,p_a3c,p_a4c,p_a5c,p_plax,p_tee,p_exclude,p_psax-av,p_psax-mv,p_psax-ap,p_psax-pm
0,s3://echodata25/results/echo-study-2/1.2.276.0.7230010.3.1.2.1667523124.1.1724119076.370729/unmasked/png/1.2.276.0.7230010.3.1.4.895693665.1.1724119344.1004303.png,s3://echodata25/results/echo-study-2/1.2.276.0.7230010.3.1.2.1667523124.1.1724119076.370729/1.2.276.0.7230010.3.1.3.1667523124.1.1724119076.370730/1.2.276.0.7230010.3.1.4.895693665.1.1724119344.1004303.mp4,a4c,0.0981,0.460436,0.285156,0.009262,0.615723,1.1e-05,0.00753,0.001049,0.081482,0.0,0.0,0.0,0.0
1,s3://echodata25/results/echo-study-2/1.2.276.0.7230010.3.1.2.1667523124.1.1724119076.370729/unmasked/png/1.2.276.0.7230010.3.1.4.1667523124.1.1724119579.372021.png,s3://echodata25/results/echo-study-2/1.2.276.0.7230010.3.1.2.1667523124.1.1724119076.370729/1.2.276.0.7230010.3.1.3.1667523124.1.1724119076.370730/1.2.276.0.7230010.3.1.4.1667523124.1.1724119579.372021.mp4,exclude,0.078298,0.615481,0.003279,1.1e-05,0.012772,3e-06,0.001054,0.137207,0.845703,4e-06,7e-06,3e-06,0.000182
2,s3://echodata25/results/echo-study-2/1.2.276.0.7230010.3.1.2.1667523124.1.1724092694.307824/unmasked/png/1.2.276.0.7230010.3.1.4.1667523124.1.1724094344.311721.png,s3://echodata25/results/echo-study-2/1.2.276.0.7230010.3.1.2.1667523124.1.1724092694.307824/1.2.276.0.7230010.3.1.3.1667523124.1.1724092694.307825/1.2.276.0.7230010.3.1.4.1667523124.1.1724094344.311721.mp4,tee,0.040586,0.698504,6e-06,0.0,8e-06,1e-06,1.4e-05,0.980469,0.019348,0.0,0.0,0.0,0.0
3,s3://echodata25/results/echo-study-2/1.2.276.0.7230010.3.1.2.1667523124.1.1723914896.191703/unmasked/png/1.2.276.0.7230010.3.1.4.1667523124.1.1723914920.191729.png,s3://echodata25/results/echo-study-2/1.2.276.0.7230010.3.1.2.1667523124.1.1723914896.191703/1.2.276.0.7230010.3.1.3.895693665.1.1723914897.823995/1.2.276.0.7230010.3.1.4.1667523124.1.1723914920.191729.mp4,psax-pm,0.07811,0.70566,0.0,0.0,0.0,0.0,0.0,0.0,0.0,6e-05,0.000371,0.024765,0.974609
4,s3://echodata25/results/echo-study-2/1.2.276.0.7230010.3.1.2.1667523124.1.1724124498.384801/unmasked/png/1.2.276.0.7230010.3.1.4.1667523124.1.1724124531.384887.png,s3://echodata25/results/echo-study-2/1.2.276.0.7230010.3.1.2.1667523124.1.1724124498.384801/1.2.276.0.7230010.3.1.3.1667523124.1.1724124498.384802/1.2.276.0.7230010.3.1.4.1667523124.1.1724124531.384887.mp4,exclude,0.127712,0.72806,0.014061,2e-05,0.000137,9e-06,4e-06,0.000183,0.985352,0.0,0.0,0.0,0.0


# Compute Marginal Salience

In [120]:
import numpy as np

# 1. Identify the probability columns (they all start with "p_")
prob_cols = [c for c in es2_all.columns if c.startswith("p_")]

# 2. Pull them into a NumPy array for speed   (shape: N × 11)
P = es2_all[prob_cols].to_numpy(dtype=np.float32)

# 3. Compute the ingredients of the margin score
p_max      = P.max(1)                                    # top-class prob p_c
mean_other = (P.sum(1) - p_max) / (P.shape[1] - 1)       # mean of other K-1 classes
margin     = p_max - mean_other                          # separation term

# 4. Combine with the quality proxy q
es2_all["marginal_salience"] = 0.7 * p_max * margin + 0.3 * es2_all["quality"].to_numpy()

# Optional: drop the old salience column or keep both
# es2_all.drop(columns="salience", inplace=True)


In [121]:
es2_all.head()

Unnamed: 0,png_uri,mp4_uri,pred_view,quality,salience,p_a2c,p_a3c,p_a4c,p_a5c,p_plax,p_tee,p_exclude,p_psax-av,p_psax-mv,p_psax-ap,p_psax-pm,marginal_salience
0,s3://echodata25/results/echo-study-2/1.2.276.0.7230010.3.1.2.1667523124.1.1724119076.370729/unmasked/png/1.2.276.0.7230010.3.1.4.895693665.1.1724119344.1004303.png,s3://echodata25/results/echo-study-2/1.2.276.0.7230010.3.1.2.1667523124.1.1724119076.370729/1.2.276.0.7230010.3.1.3.1667523124.1.1724119076.370730/1.2.276.0.7230010.3.1.4.895693665.1.1724119344.1004303.mp4,a4c,0.0981,0.460436,0.285156,0.009262,0.615723,1.1e-05,0.00753,0.001049,0.081482,0.0,0.0,0.0,0.0,0.278239
1,s3://echodata25/results/echo-study-2/1.2.276.0.7230010.3.1.2.1667523124.1.1724119076.370729/unmasked/png/1.2.276.0.7230010.3.1.4.1667523124.1.1724119579.372021.png,s3://echodata25/results/echo-study-2/1.2.276.0.7230010.3.1.2.1667523124.1.1724119076.370729/1.2.276.0.7230010.3.1.3.1667523124.1.1724119076.370730/1.2.276.0.7230010.3.1.4.1667523124.1.1724119579.372021.mp4,exclude,0.078298,0.615481,0.003279,1.1e-05,0.012772,3e-06,0.001054,0.137207,0.845703,4e-06,7e-06,3e-06,0.000182,0.514991
2,s3://echodata25/results/echo-study-2/1.2.276.0.7230010.3.1.2.1667523124.1.1724092694.307824/unmasked/png/1.2.276.0.7230010.3.1.4.1667523124.1.1724094344.311721.png,s3://echodata25/results/echo-study-2/1.2.276.0.7230010.3.1.2.1667523124.1.1724092694.307824/1.2.276.0.7230010.3.1.3.1667523124.1.1724092694.307825/1.2.276.0.7230010.3.1.4.1667523124.1.1724094344.311721.mp4,tee,0.040586,0.698504,6e-06,0.0,8e-06,1e-06,1.4e-05,0.980469,0.019348,0.0,0.0,0.0,0.0,0.68377
3,s3://echodata25/results/echo-study-2/1.2.276.0.7230010.3.1.2.1667523124.1.1723914896.191703/unmasked/png/1.2.276.0.7230010.3.1.4.1667523124.1.1723914920.191729.png,s3://echodata25/results/echo-study-2/1.2.276.0.7230010.3.1.2.1667523124.1.1723914896.191703/1.2.276.0.7230010.3.1.3.895693665.1.1723914897.823995/1.2.276.0.7230010.3.1.4.1667523124.1.1723914920.191729.mp4,psax-pm,0.07811,0.70566,0.0,0.0,0.0,0.0,0.0,0.0,0.0,6e-05,0.000371,0.024765,0.974609,0.686618
4,s3://echodata25/results/echo-study-2/1.2.276.0.7230010.3.1.2.1667523124.1.1724124498.384801/unmasked/png/1.2.276.0.7230010.3.1.4.1667523124.1.1724124531.384887.png,s3://echodata25/results/echo-study-2/1.2.276.0.7230010.3.1.2.1667523124.1.1724124498.384801/1.2.276.0.7230010.3.1.3.1667523124.1.1724124498.384802/1.2.276.0.7230010.3.1.4.1667523124.1.1724124531.384887.mp4,exclude,0.127712,0.72806,0.014061,2e-05,0.000137,9e-06,4e-06,0.000183,0.985352,0.0,0.0,0.0,0.0,0.716962


In [122]:
import pandas as pd
import numpy as np

# ---- 1 · load ----------------------------------------------------------------
# es2_all = pd.read_parquet("es2_all.parquet")   # you already have it in RAM

# ---- 2 · add helper columns --------------------------------------------------
es2_all["data_source"] = es2_all["mp4_uri"].str.extract(r"results/([^/]+)/")
es2_all["DeidentifiedStudyID"] = es2_all["mp4_uri"].str.extract(
    r"results/[^/]+/([^/]+)/"
)

# (skip if marginal_salience already computed)
prob_cols = [c for c in es2_all.columns if c.startswith("p_")]
P         = es2_all[prob_cols].to_numpy(dtype=np.float32)
p_max     = P.max(1)
margin    = p_max - (P.sum(1) - p_max) / (P.shape[1] - 1)
es2_all["marginal_salience"] = 0.7 * p_max * margin + 0.3 * es2_all["quality"].to_numpy()

# ── views we actively want ──────────────────────────────────────────────
CANONICAL = [
    "a2c", "a3c", "a4c", "a5c", "plax",
    "psax-av", "psax-mv", "psax-ap", "psax-pm",
]  # <-- “exclude” and “tee” removed

IGNORE = {"tee", "exclude"}

# ── per-study selector that skips IGNORE views ──────────────────────────
def select_clips(study_df: pd.DataFrame, k: int = 32) -> pd.Series:
    # drop clips we never want
    study_df = study_df.loc[~study_df["pred_view"].isin(IGNORE)]

    if study_df.empty:                       # edge-case: nothing left
        return pd.Series({"salient_videos": [], "salient_views": []})

    ranked = study_df.sort_values("marginal_salience", ascending=False)

    # Step 1 – best clip per wanted view present in this study
    top_per_view = (
        ranked
        .loc[ranked["pred_view"].isin(CANONICAL)]
        .groupby("pred_view", group_keys=False)
        .head(1)
    )

    # Step 2 – fill remaining slots with next-best overall
    remaining = k - len(top_per_view)
    if remaining > 0:
        mask  = ranked.index.isin(top_per_view.index)
        extra = ranked.loc[~mask].head(remaining)
        final = pd.concat([top_per_view, extra])
    else:
        final = ranked.head(k)               # very rare: > k distinct views left

    final = final.sort_values("marginal_salience", ascending=False)

    return pd.Series(
        {
            "salient_videos": final["mp4_uri"].tolist(),
            "salient_views":  final["pred_view"].tolist(),
        }
    )

# ── build study-level dataframe ─────────────────────────────────────────
salient_df = (
    es2_all
    .groupby(["data_source", "DeidentifiedStudyID"], sort=False)
    .apply(select_clips, k=32)
    .reset_index()
)

print(salient_df.head())


    data_source                                     DeidentifiedStudyID  \
0  echo-study-2  1.2.276.0.7230010.3.1.2.1667523124.1.1724119076.370729   
1  echo-study-2  1.2.276.0.7230010.3.1.2.1667523124.1.1724092694.307824   
2  echo-study-2  1.2.276.0.7230010.3.1.2.1667523124.1.1723914896.191703   
3  echo-study-2  1.2.276.0.7230010.3.1.2.1667523124.1.1724124498.384801   
4  echo-study-2  1.2.276.0.7230010.3.1.2.1667523124.1.1724099667.323188   

                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     

  .apply(select_clips, k=32)


In [123]:
print(salient_df.columns)
print(salient_df.shape)

Index(['data_source', 'DeidentifiedStudyID', 'salient_videos',
       'salient_views'],
      dtype='object')
(79526, 4)


In [124]:
# import matplotlib.pyplot as plt
# import numpy as np
# import pandas as pd

# # --------------------------------------------------------------------------
# # 1. Histogram: how many clips per study after selection
# # --------------------------------------------------------------------------
# salient_df["n_clips"] = salient_df["salient_videos"].str.len()     # 10-32

# fig, ax = plt.subplots(figsize=(6, 4))
# bins = np.arange(10, 33)                                           # 10 … 32
# counts, _, _ = ax.hist(salient_df["n_clips"], bins=bins,
#                        edgecolor="black")

# ax.set_xticks(bins)                                                # every int
# ax.set_yticks(np.arange(0, counts.max() + 1, 1))                   # every int
# ax.set_xlabel("# clips per study")
# ax.set_ylabel("count of studies")
# ax.set_title("Distribution of selected clips per study")
# plt.tight_layout()
# plt.show()


# # --------------------------------------------------------------------------
# # 2. Bar chart: which views survive most often
# # --------------------------------------------------------------------------
# view_counts = (
#     salient_df["salient_views"]
#       .explode()                      # flatten list → one row per clip
#       .value_counts()                 # frequency of each view
#       .sort_index()                   # A2C, A3C … order
# )

# fig, ax = plt.subplots(figsize=(8, 4))
# ax.bar(view_counts.index, view_counts.values, edgecolor="black")

# ax.set_xticks(range(len(view_counts)))
# ax.set_xticklabels(view_counts.index, rotation=45, ha="right")
# ax.set_yticks(np.arange(0, view_counts.max() + 1, 1))              # every int
# ax.set_ylabel("clips kept")
# ax.set_title("Kept-view frequency across all studies")
# plt.tight_layout()
# plt.show()


In [126]:
print("hello")

hello


In [125]:
salient_df.to_csv('es2_salient_vids.csv')

studies:   0%|          | 0/79598 [00:00<?, ?it/s]