In [None]:
import ast
import json
import os

import h5py
import numpy as np
import pandas as pd
import pydicom
from tqdm import tqdm

# --- Configuration: input/output locations and patch geometry ---
h5_path = "/kaggle/input/rsna-dataset/dataset.h5"
series_root = "/kaggle/input/rsna-intracranial-aneurysm-detection/series"
patch_size = (64,64,64)
stride = (64,64,64)   # stride == patch size => patches do not overlap
out_path = "/kaggle/working/patches.h5"

# --- Annotation tables ---
train_df = pd.read_csv("/kaggle/input/rsna-intracranial-aneurysm-detection/train.csv")
localizers = pd.read_csv("/kaggle/input/rsna-intracranial-aneurysm-detection/train_localizers.csv")
# The "coordinates" column holds Python-literal text (later code reads the "x" and "y" keys).
# ast.literal_eval parses literals only, so a malformed or malicious cell cannot
# execute code the way the built-in eval() would.
localizers["coords"] = localizers["coordinates"].apply(ast.literal_eval)

_sop2z_cache = {}
def build_sop2z(series_uid: str):
    """Map every SOPInstanceUID of a series to its slice index.

    Reads the header (pixel data is skipped) of each ``.dcm`` file found in
    ``series_root/<series_uid>``, orders the slices by ``InstanceNumber`` and
    numbers them consecutively from 0. A slice without ``InstanceNumber`` sorts
    as 0. A slice without ``SOPInstanceUID`` still consumes an index but is left
    out of the returned mapping. Results are memoised in ``_sop2z_cache``.

    Args:
        series_uid: name of the series sub-directory inside ``series_root``.

    Returns:
        dict mapping ``str(SOPInstanceUID)`` to the integer rank of the slice.

    Raises:
        OSError: if the series directory cannot be listed.
    """
    if series_uid in _sop2z_cache:
        return _sop2z_cache[series_uid]

    series_path = os.path.join(series_root, series_uid)
    # os.listdir() returns entries in arbitrary order. Sorting the names first,
    # combined with the stable sort below, makes the index of slices that share
    # (or lack) an InstanceNumber reproducible from run to run.
    files = sorted(name for name in os.listdir(series_path) if name.endswith(".dcm"))

    rows = []
    for name in files:
        ds = pydicom.dcmread(os.path.join(series_path, name), stop_before_pixels=True, force=True)
        rows.append((getattr(ds, "InstanceNumber", None), getattr(ds, "SOPInstanceUID", None)))

    # NOTE(review): this assumes the first axis of the stored volumes follows
    # InstanceNumber order — confirm against the code that produced the HDF5 file.
    rows.sort(key=lambda t: (t[0] if t[0] is not None else 0))

    mapping = {str(sop): z for z, (_, sop) in enumerate(rows) if sop is not None}
    _sop2z_cache[series_uid] = mapping
    return mapping



def crop_pad_cube(vol, start, size):
    """Cut a cube out of a 3-D volume and zero-pad it up to the requested size.

    Args:
        vol: 3-D array indexed (depth, height, width).
        start: (z, y, x) offset of the first voxel of the cube.
        size: (depth, height, width) of the cube to return.

    Returns:
        Array of shape ``size``. Where the requested window runs past the end
        of ``vol``, the missing voxels are appended as zeros at the far end of
        the corresponding axis.
    """
    _depth, _height, _width = vol.shape  # unpacking raises unless vol has exactly three axes
    want_d, want_h, want_w = size
    z0, y0, x0 = start

    cube = vol[z0:z0 + want_d, y0:y0 + want_h, x0:x0 + want_w]

    # Slicing silently truncates at the volume border; measure what is missing.
    got_d, got_h, got_w = cube.shape
    trailing = (
        (0, max(0, want_d - got_d)),
        (0, max(0, want_h - got_h)),
        (0, max(0, want_w - got_w)),
    )
    return np.pad(cube, trailing, mode="constant")




# Pull the whole "meta" table into memory; the input file is closed right after.
with h5py.File(h5_path, "r") as f:
    meta = f["meta"][:]

# Series identifiers may come back as bytes; normalise every entry to str.
uids = [
    str(raw) if not isinstance(raw, bytes) else raw.decode()
    for raw in meta["series_uid"]
]
idxs = list(map(int, meta["h5_index"]))

# Lookup table: series UID -> value of its "h5_index" field.
uid2idx = dict(zip(uids, idxs))

print(f"[INFO] {len(uids)} séries trouvées dans HDF5.")





In [None]:

def _axis_starts(length, size, step):
    """Start offsets along one axis so that patches of `size` taken every `step` reach the end of the axis.

    The last patch may run past `length`; the caller pads it. The previous
    grid, range(0, length - size + 1, step), dropped the tail of any axis whose
    length was not a multiple of the step, so annotations there were never labelled.
    """
    stop = min(length, length - size + step)
    return range(0, max(1, stop), step)


def _hits_any_center(origin, size, centers, margin):
    """True when at least one (z, y, x) center lies inside the patch grown by `margin` voxels on every side."""
    return any(
        all(o - margin <= c < o + s + margin for o, s, c in zip(origin, size, center))
        for center in centers
    )


pd_, ph, pw = patch_size
sd, sh, sw = stride

with h5py.File(out_path, "w") as out_f, h5py.File(h5_path, "r") as f:
    # One HDF5 group per kind of record; datasets inside are keyed by patch id.
    gX = out_f.create_group("X")
    gy = out_f.create_group("y")
    gmeta = out_f.create_group("meta")

    patch_id = 0
    margin = 5  # tolerance, in voxels, around each annotated center
    count_pos, count_neg = 0, 0

    pbar = tqdm(uids, desc="Découpage des séries")
    for uid in pbar:
        if uid not in uid2idx:
            continue

        h5_idx = uid2idx[uid]
        vol = f["X"][str(h5_idx)][()]
        # Drop a singleton channel axis, whether it is stored first or last.
        if vol.ndim == 4 and vol.shape[0] == 1:
            vol = vol[0]
        elif vol.ndim == 4 and vol.shape[-1] == 1:
            vol = vol[...,0]
        if vol.ndim != 3:
            continue

        D,H,W = vol.shape

        # --- collect annotation centers as (z, y, x) ---
        # NOTE(review): x/y are used exactly as stored in the localizer table and
        # z is the slice rank within the DICOM series. This assumes the stored
        # volumes keep the original slice order and in-plane grid — confirm.
        annots = localizers[localizers["SeriesInstanceUID"] == uid]
        centers = []
        if len(annots) > 0:
            sop2z = build_sop2z(uid)
            for sop_uid, coords in zip(annots["SOPInstanceUID"], annots["coords"]):
                sop = str(sop_uid)
                if sop not in sop2z:
                    continue
                cz = sop2z[sop]
                cx = int(round(coords["x"]))
                cy = int(round(coords["y"]))
                centers.append((cz,cy,cx))

        for z in _axis_starts(D, pd_, sd):
            for y in _axis_starts(H, ph, sh):
                for x in _axis_starts(W, pw, sw):
                    patch = crop_pad_cube(vol, (z,y,x), patch_size)

                    # Stored as uint8 (1 byte per voxel). The original author states
                    # values are already in 0–7; anything outside 0–255 would wrap
                    # silently here — TODO confirm the value range.
                    patch = patch.astype(np.uint8)
                    patch = np.expand_dims(patch, -1)  # (D,H,W,1)

                    # Label is 1 when a center falls inside the patch (+ margin).
                    label = 1 if _hits_any_center((z,y,x), patch_size, centers, margin) else 0

                    if label == 1:
                        count_pos += 1
                    else:
                        count_neg += 1

                    # Persist the patch, its label and a JSON metadata record.
                    gX.create_dataset(str(patch_id), data=patch, compression="gzip", compression_opts=4)
                    gy.create_dataset(str(patch_id), data=np.array(label, dtype=np.uint8))
                    # np.bytes_ exists in NumPy 1.x and 2.x; the np.string_ alias was removed in 2.0.
                    gmeta.create_dataset(str(patch_id),
                        data=np.bytes_(json.dumps({
                            "series_uid": uid,
                            "patch_id": patch_id,
                            "coords": [z,y,x],
                            "label": int(label)
                        })))
                    patch_id += 1

        # Refresh the running totals on the progress bar once per series. This has
        # to happen while the bar is still iterating: after the loop the bar is
        # closed and set_postfix() no longer displays anything.
        total = count_pos + count_neg
        pct = (count_pos / total * 100) if total > 0 else 0
        pbar.set_postfix({
            "patches": total,
            "pos": count_pos,
            "neg": count_neg,
            "%pos": f"{pct:.2f}%"
        })

print(f"✅ Fini. Sauvé {patch_id} patches dans {out_path}")
print(f"   Patches positifs : {count_pos}")
print(f"   Patches négatifs : {count_neg}")
