# import

In [None]:
import os
import json
import random
import numpy as np
import pandas as pd

ROOT = "/Users/lixiaokang/daily/data/train_samples"

META_DIR = os.path.join(ROOT, "meta")
IMG_DIR = os.path.join(ROOT, "images")
TAB_DIR = os.path.join(ROOT, "tabular")
TS_DIR = os.path.join(ROOT, "timeseries_merged")

SAMPLE_INDEX_PATH = os.path.join(META_DIR, "sample_index.json")

with open(SAMPLE_INDEX_PATH, "r") as f:
    sample_index = json.load(f)

df_index = pd.DataFrame(sample_index)

print("Total samples:", len(df_index))

df_demo = df_index.sample(n=10, random_state=42).reset_index(drop=True)

print("Selected sample_ids:")
print(df_demo["sample_id"].tolist())

TAB_FILES = [
    "admissions.csv",
    "demographics.csv",
    "icustays.csv",
]

tabular_tables = {}
for fname in TAB_FILES:
    path = os.path.join(TAB_DIR, fname)
    if os.path.exists(path):
        tabular_tables[fname.replace(".csv", "")] = pd.read_csv(path)

TS_FILES = [
    "labevents.csv",
    "chartevents.csv",
]

timeseries_tables = {}
for fname in TS_FILES:
    path = os.path.join(TS_DIR, fname)
    if os.path.exists(path):
        df = pd.read_csv(path)
        if "charttime" in df.columns:
            df["charttime"] = pd.to_datetime(df["charttime"])
        timeseries_tables[fname.replace(".csv", "")] = df

demo_samples = []

for _, row in df_demo.iterrows():
    sample_id = row["sample_id"]
    subject_id = row["subject_id"]
    cxr_time = pd.to_datetime(row["cxr_time"])

    sample = {
        "sample_id": sample_id,
        "subject_id": subject_id,
        "meta": row.to_dict(),
    }

    sample_id = row["sample_id"]
    img_path = os.path.join(IMG_DIR, f"{sample_id}.npy")

    if os.path.exists(img_path):
        sample["image"] = np.load(img_path)
    else:
        print(f"[WARN] image not found for {sample_id}")
        sample["image"] = None

    tab_data = {}
    for name, df in tabular_tables.items():
        if "subject_id" in df.columns:
            tab_data[name] = df[df["subject_id"] == subject_id]
    sample["tabular"] = tab_data

    ts_data = {}
    for name, df in timeseries_tables.items():
        if "subject_id" in df.columns and "charttime" in df.columns:
            ts_data[name] = df[
                (df["subject_id"] == subject_id) &
                (df["charttime"] <= cxr_time) &
                (df["charttime"] >= cxr_time - pd.Timedelta(hours=48))
            ]
    sample["timeseries"] = ts_data

    sample["label_any"] = row.get("label_any", None)
    sample["label_multi"] = row.get("label_multi", None)

    demo_samples.append(sample)

print(f"\nDemo dataset built: {len(demo_samples)} samples")


Total samples: 1446
Selected sample_ids:
['s_000413', 's_000316', 's_000554', 's_000065', 's_001380', 's_000967', 's_000175', 's_000836', 's_000651', 's_000231']

Demo dataset built: 10 samples


In [5]:
import os
import json
import numpy as np
import pandas as pd


# paths
ROOT = "E:/NUS/data/perdata/train_text_samples"

META_DIR = os.path.join(ROOT, "meta")
IMG_DIR = os.path.join(ROOT, "images")
TAB_DIR = os.path.join(ROOT, "tabular")
TS_DIR = os.path.join(ROOT, "timeseries_merged")
TEXT_PATH = os.path.join(ROOT, "text.csv")

SAMPLE_INDEX_PATH = os.path.join(META_DIR, "sample_index.json")

# load sample index
with open(SAMPLE_INDEX_PATH, "r") as f:
    sample_index = json.load(f)

df_index = pd.DataFrame(sample_index)
print("Total samples:", len(df_index))


# random select demo samples
df_demo = df_index.sample(n=10, random_state=42).reset_index(drop=True)
print("Selected sample_ids:", df_demo["sample_id"].tolist())

# load tabular tables (once)
# TAB_FILES = ["admissions.csv", "demographics.csv", "icustays.csv"]
tabular_tables = {}

TAB_FILES = [f for f in os.listdir(TAB_DIR) if f.lower().endswith(".csv")]

for fname in TAB_FILES:
    path = os.path.join(TAB_DIR, fname)
    try:
        tabular_tables[fname[:-4]] = pd.read_csv(path)
    except Exception as e:
        print(f"[WARN] failed to read {fname}: {e}")


# load & preprocess text table (once)
if os.path.exists(TEXT_PATH):
    df_text = pd.read_csv(TEXT_PATH)

    if "cxr_time" in df_text.columns:
        df_text["cxr_time"] = pd.to_datetime(df_text["cxr_time"], errors="coerce")

    df_text["text"] = df_text["text"].fillna("").astype(str)

    if "dicom_id" in df_text.columns:
        text_map_dicom = (
            df_text.groupby("dicom_id")["text"]
            .apply(lambda s: "\n".join([t.strip() for t in s if t.strip()]))
            .to_dict()
        )
    else:
        text_map_dicom = {}

else:
    print("[WARN] text.csv not found")
    df_text = None
    text_map_dicom = {}


# load timeseries tables (once)
timeseries_tables = {}

TS_FILES = [f for f in os.listdir(TS_DIR) if f.lower().endswith(".csv")]

for fname in TS_FILES:
    path = os.path.join(TS_DIR, fname)
    if os.path.exists(path):
        df = pd.read_csv(path)
        if "charttime" in df.columns:
            df["charttime"] = pd.to_datetime(df["charttime"], errors="coerce")
        timeseries_tables[fname[:-4]] = df  # remove ".csv"



# build demo samples (CXR-level)

demo_samples = []

for _, row in df_demo.iterrows():
    sample_id = row["sample_id"]
    subject_id = row["subject_id"]
    cxr_time = pd.to_datetime(row["cxr_time"], errors="coerce")
    dicom_id = row.get("dicom_id", None)

    sample = {
        "sample_id": sample_id,
        "subject_id": subject_id,
        "dicom_id": dicom_id,
        "cxr_time": cxr_time,
        "meta": row.to_dict(),
    }

    # image (1 CXR = 1 image)
    img_path = os.path.join(IMG_DIR, f"{sample_id}.npy")
    sample["image"] = np.load(img_path) if os.path.exists(img_path) else None

    # text
    text = ""

    if dicom_id is not None and dicom_id in text_map_dicom:
        text = text_map_dicom[dicom_id]

    elif df_text is not None:
        rows = df_text[
            (df_text["subject_id"] == subject_id) &
            (df_text["cxr_time"].notna()) &
            (abs(df_text["cxr_time"] - cxr_time) <= pd.Timedelta(minutes=5))
        ]
        if len(rows) > 0:
            text = "\n".join(
                [t.strip() for t in rows["text"].tolist() if t.strip()]
            )

    sample["text"] = text  

    # tabular (patient-level snapshot)
    tab_data = {}
    for name, df in tabular_tables.items():
        if "subject_id" in df.columns:
            tab_data[name] = df[df["subject_id"] == subject_id]
    sample["tabular"] = tab_data

    # timeseries (48h before CXR)
    ts_data = {}
    for name, df in timeseries_tables.items():
        if "subject_id" in df.columns and "charttime" in df.columns and pd.notna(cxr_time):
            ts_data[name] = df[
                (df["subject_id"] == subject_id) &
                (df["charttime"] <= cxr_time) &
                (df["charttime"] >= cxr_time - pd.Timedelta(hours=48))
            ]
    sample["timeseries"] = ts_data

    # labels
    sample["label_any"] = row.get("label_any", None)
    sample["label_multi"] = row.get("label_multi", None)

    demo_samples.append(sample)

print(f"\nDemo dataset built: {len(demo_samples)} samples")


Total samples: 1446
Selected sample_ids: ['s_000413', 's_000316', 's_000554', 's_000065', 's_001380', 's_000967', 's_000175', 's_000836', 's_000651', 's_000231']


  tabular_tables[fname[:-4]] = pd.read_csv(path)



Demo dataset built: 10 samples


In [13]:
import os
import numpy as np
import pandas as pd
from config import TS_CFG

def analyze_timeseries_tables(
    TS_DIR: str,
    tables=None,
    chunksize: int = 1_000_000,
    candidates=(128, 256, 512, 1024),
):
    """
    逐表按 chunk 扫描：
      - 总行数
      - subject_id 缺失率
      - 对每个候选 num_buckets，统计 bucket 行数分布（max/mean）
    不做时间列解析，速度快，内存稳。
    """
    if tables is None:
        tables = list(TS_CFG.keys())

    report = []

    for table in tables:
        path = os.path.join(TS_DIR, f"{table}.csv")
        if not os.path.exists(path):
            print(f"[MISS] {table}.csv")
            continue

        cfg = TS_CFG.get(table, {})
        # 统计时只读 subject_id 就够了（速度最快）
        usecols = cfg.get("usecols")
        if usecols is None or "subject_id" not in usecols:
            usecols = ["subject_id"]
        else:
            usecols = ["subject_id"]  # 统计阶段只取这列即可

        # 为了稳，dtype 用 object/float 都行；这里用 object 再转 numeric
        reader = pd.read_csv(
            path,
            usecols=usecols,
            chunksize=chunksize,
            low_memory=False,
        )

        size_mb = os.path.getsize(path) / (1024**2)

        total_rows = 0
        missing_sid = 0

        # 为每个 candidate 准备一个 bincount 数组
        # 注意：num_buckets 越大，数组越大，但 1024 也就 1024 个 int64，很小
        bucket_counts = {B: np.zeros(B, dtype=np.int64) for B in candidates}

        for chunk_id, df in enumerate(reader):
            total_rows += len(df)

            sid = pd.to_numeric(df["subject_id"], errors="coerce")
            missing_sid += sid.isna().sum()

            sid = sid.dropna().astype(np.int64).to_numpy()

            # 对每个候选桶数，统计该 chunk 的桶分布并累加
            for B in candidates:
                b = sid % B
                bc = np.bincount(b, minlength=B)
                bucket_counts[B] += bc

            if chunk_id % 10 == 0:
                print(f"[{table}] scanned chunks={chunk_id}, rows={total_rows:,}")

        valid_rows = total_rows - missing_sid
        miss_rate = (missing_sid / total_rows) if total_rows > 0 else 0.0

        row = {
            "table": table,
            "size_mb": round(size_mb, 2),
            "rows_total": int(total_rows),
            "rows_valid_subject": int(valid_rows),
            "subject_missing_rate": round(miss_rate, 6),
        }

        # 汇总每个候选桶数的 max/mean
        for B in candidates:
            counts = bucket_counts[B]
            row[f"b{B}_mean_rows"] = int(counts.mean())
            row[f"b{B}_max_rows"]  = int(counts.max())
            # 可选：95分位数看看“通常最坏多大”
            row[f"b{B}_p95_rows"]  = int(np.quantile(counts, 0.95))

        report.append(row)

        print(f"[DONE] {table}: rows={total_rows:,}, size={size_mb:.2f}MB, miss_sid={miss_rate:.4%}")

    return pd.DataFrame(report)


if __name__ == "__main__":
    TS_DIR = r"E:/NUS/data/perdata/train_text_samples/timeseries_merged"
    df = analyze_timeseries_tables(
        TS_DIR=TS_DIR,
        tables=list(TS_CFG.keys()),
        chunksize=1_000_000,
        candidates=(128, 256, 512, 1024),
    )
    print("\n===== SUMMARY =====")
    print(df.sort_values("rows_total", ascending=False).to_string(index=False))
    df.to_csv("timeseries_bucket_planning.csv", index=False)
    print("\nSaved: timeseries_bucket_planning.csv")


[chartevents] scanned chunks=0, rows=602,046
[DONE] chartevents: rows=602,046, size=135.94MB, miss_sid=0.0000%
[labevents] scanned chunks=0, rows=51,765
[DONE] labevents: rows=51,765, size=10.60MB, miss_sid=0.0000%
[inputevents] scanned chunks=0, rows=313,095
[DONE] inputevents: rows=313,095, size=116.46MB, miss_sid=0.0000%
[outputevents] scanned chunks=0, rows=9,228
[DONE] outputevents: rows=9,228, size=1.72MB, miss_sid=0.0000%
[procedureevents] scanned chunks=0, rows=18,214
[DONE] procedureevents: rows=18,214, size=5.89MB, miss_sid=0.0000%

===== SUMMARY =====
          table  size_mb  rows_total  rows_valid_subject  subject_missing_rate  b128_mean_rows  b128_max_rows  b128_p95_rows  b256_mean_rows  b256_max_rows  b256_p95_rows  b512_mean_rows  b512_max_rows  b512_p95_rows  b1024_mean_rows  b1024_max_rows  b1024_p95_rows
    chartevents   135.94      602046              602046                   0.0            4703          66014          34284            2351          66014          

In [8]:
import torch
from torch.utils.data import Dataset, DataLoader
from typing import List, Dict, Any, Optional, Callable, Tuple
import numpy as np


class MultiModalDataset(Dataset):
    """
    One item = one CXR event (image-level).

    Expected sample dict keys (from your demo_samples):
      - image: np.ndarray (H,W) or (1,H,W) or None
      - text: str (can be empty)
      - tabular: Dict[str, pd.DataFrame]
      - timeseries: Dict[str, pd.DataFrame]
      - label_any: int/float/bool or None
      - label_multi: list/np.ndarray or None
      - meta fields: sample_id, subject_id, dicom_id, cxr_time (optional)
    """

    def __init__(
        self,
        samples: List[Dict[str, Any]],
        image_transform: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,
        return_meta: bool = True,
        image_normalize_255: bool = True,  # if uint8 0~255, normalize to 0~1
    ):
        self.samples = samples
        self.image_transform = image_transform
        self.return_meta = return_meta
        self.image_normalize_255 = image_normalize_255

        # infer multi-label dimension (optional but helpful)
        self.num_labels = None
        for s in samples:
            lm = s.get("label_multi", None)
            if lm is not None:
                try:
                    self.num_labels = int(len(lm))
                    break
                except Exception:
                    pass

    def __len__(self):
        return len(self.samples)

    def _to_image_tensor(self, image: Any) -> Optional[torch.Tensor]:
        if image is None:
            return None
        if isinstance(image, torch.Tensor):
            img = image.float()
        else:
            img = torch.from_numpy(np.asarray(image)).float()

        # shape to (1,H,W)
        if img.ndim == 2:
            img = img.unsqueeze(0)
        elif img.ndim == 3 and img.shape[0] != 1:
            # if (H,W,C) or (C,H,W) unexpected, you can adjust here
            # for safety: if last dim is 1, convert to (1,H,W)
            if img.shape[-1] == 1:
                img = img.permute(2, 0, 1)
            else:
                # fallback: take first channel
                img = img[0:1, ...]
        elif img.ndim != 3:
            # unexpected
            return None

        # normalize if looks like 0~255
        if self.image_normalize_255:
            if img.max() > 1.5:
                img = img / 255.0

        if self.image_transform is not None:
            img = self.image_transform(img)
        return img

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        s = self.samples[idx]

        # image
        image = self._to_image_tensor(s.get("image", None))

        # text
        text = s.get("text", "")
        if text is None:
            text = ""

        # tabular & timeseries: keep as-is (DataFrame dicts)
        tabular = s.get("tabular", {}) or {}
        timeseries = s.get("timeseries", {}) or {}

        # labels -> tensor-friendly
        label_any = s.get("label_any", None)
        if label_any is None:
            label_any_t = None
        else:
            label_any_t = torch.tensor(float(label_any), dtype=torch.float32)

        label_multi = s.get("label_multi", None)
        if label_multi is None:
            label_multi_t = None
        else:
            label_multi_arr = np.asarray(label_multi, dtype=np.float32)
            label_multi_t = torch.from_numpy(label_multi_arr)

        item = {
            "image": image,            # torch.Tensor or None
            "text": text,              # str
            "tabular": tabular,        # Dict[str, DataFrame]
            "timeseries": timeseries,  # Dict[str, DataFrame]
            "label_any": label_any_t,  # torch.Tensor or None
            "label_multi": label_multi_t,  # torch.Tensor(C,) or None
        }

        if self.return_meta:
            item.update({
                "sample_id": s.get("sample_id"),
                "subject_id": s.get("subject_id"),
                "dicom_id": s.get("dicom_id"),
                "cxr_time": s.get("cxr_time"),
            })

        return item


def collate_multimodal(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Custom collate that:
      - stacks images into (B,1,H,W) with zero-fill if None, plus image_mask
      - keeps text as List[str]
      - keeps tabular/timeseries as List[Dict[str, DataFrame]]
      - stacks labels if present, else returns None
      - keeps meta as List
    """
    out: Dict[str, Any] = {}

    # ---- images ----
    images = [b["image"] for b in batch]
    has_img = [img is not None for img in images]
    image_mask = torch.tensor(has_img, dtype=torch.bool)

    if any(has_img):
        # find reference shape
        ref = next(img for img in images if img is not None)
        B, C, H, W = len(images), ref.shape[0], ref.shape[1], ref.shape[2]
        img_batch = torch.zeros((B, C, H, W), dtype=ref.dtype)
        for i, img in enumerate(images):
            if img is not None:
                # if shapes mismatch, you could add resize/pad here
                img_batch[i] = img
        out["image"] = img_batch
    else:
        out["image"] = None

    out["image_mask"] = image_mask  # (B,)

    # ---- text ----
    out["text"] = [b.get("text", "") or "" for b in batch]

    # ---- tabular/timeseries (keep list-of-dicts of DataFrames) ----
    out["tabular"] = [b.get("tabular", {}) or {} for b in batch]
    out["timeseries"] = [b.get("timeseries", {}) or {} for b in batch]

    # ---- labels ----
    la = [b.get("label_any", None) for b in batch]
    if all(x is not None for x in la):
        out["label_any"] = torch.stack(la, dim=0)  # (B,)
    else:
        out["label_any"] = None

    lm = [b.get("label_multi", None) for b in batch]
    if all(x is not None for x in lm):
        out["label_multi"] = torch.stack(lm, dim=0)  # (B,C)
    else:
        out["label_multi"] = None

    # ---- meta (optional keys) ----
    for k in ["sample_id", "subject_id", "dicom_id", "cxr_time"]:
        if k in batch[0]:
            out[k] = [b.get(k) for b in batch]

    return out


# ===== Usage example =====
dataset = MultiModalDataset(demo_samples, image_transform=None, return_meta=True)
loader = DataLoader(dataset, batch_size=2, shuffle=False, collate_fn=collate_multimodal)
batch = next(iter(loader))
print(batch["image"].shape if batch["image"] is not None else None)
print(batch["image_mask"])
print(len(batch["tabular"]), batch["tabular"][0].keys())


torch.Size([2, 1, 224, 224])
tensor([True, True])
2 dict_keys(['admissions', 'demographics', 'diagnoses_icd', 'icustays', 'prescriptions', 'procedures_icd', 'transfers'])


In [6]:
dataset = MultiModalDataset(
    demo_samples,
    image_transform=None,   
    return_meta=True,
)
sample = dataset[0]

print("sample_id:", sample["sample_id"])
print("dicom_id:", sample["dicom_id"])
print("image shape:", None if sample["image"] is None else sample["image"].shape)
print("text preview:", sample["text"][:300])
print("tabular keys:", sample["tabular"].keys())
print("timeseries keys:", sample["timeseries"].keys())
print("label_any:", sample["label_any"])
print("label_multi:", sample["label_multi"])


sample_id: s_000413
dicom_id: 4a985010-1425fdc1-f649f309-8b269e24-6d85c760
image shape: torch.Size([1, 224, 224])
text preview: 
tabular keys: dict_keys(['admissions', 'demographics', 'diagnoses_icd', 'icustays', 'prescriptions', 'procedures_icd', 'transfers'])
timeseries keys: dict_keys(['chartevents', 'labevents', 'outputevents'])
label_any: 1
label_multi: [1, 1, 0, 0, 0, 0, 0, 0]


In [23]:
num_total = len(dataset)

num_non_empty = 0
num_empty = 0

for s in dataset:
    txt = s["text"]
    if txt is not None and txt.strip() != "":
        num_non_empty += 1
    else:
        num_empty += 1

print(f"Total samples: {num_total}")
print(f"Samples with text: {num_non_empty}")
print(f"Samples without text: {num_empty}")
print(f"Text coverage: {num_non_empty / num_total:.2%}")



Total samples: 10
Samples with text: 6
Samples without text: 4
Text coverage: 60.00%


# image
* densenet121->emb,rqkmeans->codebook

In [None]:
import numpy as np
import os

img_path = "/Users/lixiaokang/daily/data/train_samples/images/s_000000.npy"

img = np.load(img_path, allow_pickle=True)

print("type:", type(img))
print("dtype:", img.dtype)

if isinstance(img, np.ndarray):
    print("shape:", img.shape)
    print("ndim:", img.ndim)
    print("size:", img.size)

    print("first elements:", img.flatten()[:10])
else:
    print("content:", img)



type: <class 'numpy.ndarray'>
dtype: uint8
shape: (224, 224)
ndim: 2
size: 50176
first elements: [0 0 0 0 0 0 0 0 0 0]


In [3]:
img.min(), img.max(), img.mean()


(np.uint8(0), np.uint8(255), np.float64(94.75358737244898))

In [4]:
sample = demo_samples[0]
sample.keys()

dict_keys(['sample_id', 'subject_id', 'meta', 'image', 'tabular', 'timeseries', 'label_any', 'label_multi'])

In [5]:
img = sample["image"]

print(type(img))
print(img.shape)
print(img.dtype)

print(img.min(), img.max(), img.mean())


<class 'numpy.ndarray'>
(224, 224)
uint8
0 255 118.86437739158163


## densenet121

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

class DenseNet121_Encoder(nn.Module):
    def __init__(self, out_dim=256, pretrained=True):
        super().__init__()

        # DenseNet-121 backbone
        self.backbone = models.densenet121(pretrained=pretrained)

        self.backbone.features.conv0 = nn.Conv2d(
            1, 64, kernel_size=7, stride=2, padding=3, bias=False
        )

        feat_dim = self.backbone.classifier.in_features

        self.backbone.classifier = nn.Identity()

        self.proj = nn.Linear(feat_dim, out_dim)

    def forward(self, x):
        """
        x: (B, 1, 224, 224)
        return:
            z: (B, 256)
        """
        feat = self.backbone(x)        # (B, 1024)
        z = self.proj(feat)             # (B, 256)
        z = F.normalize(z, dim=-1)     
        return z

class ImageSemanticModel(nn.Module):
    def __init__(self, emb_dim=256, num_labels=8):
        super().__init__()

        self.encoder = DenseNet121_Encoder(out_dim=emb_dim)

        self.cls_any = nn.Linear(emb_dim, 1)

        self.cls_multi = nn.Linear(emb_dim, num_labels)

    def forward(self, x):
        z = self.encoder(x)

        out_any = self.cls_any(z).squeeze(-1)       # (B,)
        out_multi = self.cls_multi(z)               # (B, num_labels)

        return {
            "embedding": z,
            "logit_any": out_any,
            "logit_multi": out_multi,
        }


In [7]:
class ImageDataset(torch.utils.data.Dataset):
    def __init__(self, demo_samples):
        self.samples = demo_samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        s = self.samples[idx]

        img = s["image"]                      # (224,224) uint8
        img = img.astype("float32") / 255.0  # normalize
        img = torch.from_numpy(img).unsqueeze(0)  # (1,224,224)

        label_any = torch.tensor(s["label_any"], dtype=torch.float32)
        label_multi = torch.tensor(s["label_multi"], dtype=torch.float32)

        return img, label_any, label_multi

def train_one_epoch(loader):
    model.train()
    total_loss = 0

    for img, y_any, y_multi in loader:
        img = img.to(device)
        y_any = y_any.to(device)
        y_multi = y_multi.to(device)

        out = model(img)

        loss_any = criterion_any(out["logit_any"], y_any)
        loss_multi = criterion_multi(out["logit_multi"], y_multi)

        loss = loss_any + loss_multi

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(loader)


In [8]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model = ImageSemanticModel(emb_dim=256, num_labels=8).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

criterion_any = nn.BCEWithLogitsLoss()
criterion_multi = nn.BCEWithLogitsLoss()

dataset = ImageDataset(demo_samples)
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    num_workers=0,
)

for epoch in range(3):
    loss = train_one_epoch(loader)
    print(f"Epoch {epoch}: loss = {loss:.4f}")




Epoch 0: loss = 1.3519
Epoch 1: loss = 1.2802
Epoch 2: loss = 1.2355


In [19]:
model.eval()
with torch.no_grad():
    img, _, _ = dataset[0]
    z = model.encoder(img.unsqueeze(0).to(device))
    print(z.shape)   # (1, 256)

torch.Size([1, 256])


In [9]:
model.eval()
Z = []

with torch.no_grad():
    for img, _, _ in loader:
        img = img.to(device)
        z = model.encoder(img)   # (B, 256)
        Z.append(z.cpu())

Z = torch.cat(Z, dim=0)   # (N, 256)

In [None]:
from rqvae import RQVAE
class ZDataset(torch.utils.data.Dataset):
    def __init__(self, Z):
        self.Z = Z

    def __len__(self):
        return self.Z.size(0)

    def __getitem__(self, idx):
        return self.Z[idx]

rq_model = RQVAE(
    in_dim=256,                     
    num_emb_list=[8, 8, 8],    
    e_dim=256,                      
    layers=[],                      
    dropout_prob=0.0,
    bn=False,
    loss_type="mse",
    quant_loss_weight=1.0,
    beta=0.25,
    kmeans_init=True,              
    kmeans_iters=50,
    sk_epsilons=[0.0, 0.0, 0.0],     
    sk_iters=50,
).to(device)

optimizer = torch.optim.Adam(rq_model.parameters(), lr=1e-3)

rq_dataset = ZDataset(Z)
rq_loader = torch.utils.data.DataLoader(
    rq_dataset,
    batch_size=256,
    shuffle=True,
    num_workers=0
)


In [None]:
out, rq_loss, indices = rq_model(Z, use_sk=False)

print(out.shape)       # (N, 256)
print(indices.shape)   # (N, 3)   


torch.Size([10, 256])
torch.Size([10, 3])


  return fit_method(estimator, *args, **kwargs)


# config

In [None]:
TAB_KEEP_COLS = {
    "demographics": ["subject_id", "gender", "anchor_age", "anchor_year_group"],

    "admissions": [
        "subject_id", "hadm_id", "admittime",
        "admission_type", "admission_location",
        "insurance", "language", "marital_status",
        # "ethnicity", "edregtime", "edouttime"  
    ],

    "icustays": ["subject_id", "hadm_id", "stay_id", "first_careunit", "intime"],

    "transfers": ["subject_id", "hadm_id", "eventtype", "careunit", "intime", "outtime"],

    "prescriptions": ["subject_id", "hadm_id", "starttime", "drug_type", "route", "drug"],

    "procedures_icd": ["subject_id", "hadm_id", "chartdate", "icd_code", "icd_version"],

    # leak error
    # "diagnoses_icd": ["subject_id", "hadm_id", "icd_code", "icd_version"],
}


# tabluar

In [10]:
import os
import json
import pandas as pd

ROOT = "E:/NUS/data/perdata/train_text_samples"
TAB_DIR = os.path.join(ROOT, "tabular")

# 输出文件（可选）
OUT_JSON = os.path.join(TAB_DIR, "_tabular_schema_columns.json")
OUT_TXT  = os.path.join(TAB_DIR, "_tabular_schema_columns.txt")

def scan_tabular_schema(tab_dir: str):
    csv_files = sorted([f for f in os.listdir(tab_dir) if f.lower().endswith(".csv")])

    schema = {}
    for fname in csv_files:
        path = os.path.join(tab_dir, fname)

        info = {
            "file": fname,
            "path": path,
            "nrows": None,          # 行数（需要读全表才能精确，这里用可选方式）
            "has_subject_id": False,
            "time_cols": [],
            "columns": [],
        }

        try:
            # 只读表头，不读数据
            header = pd.read_csv(path, nrows=0)
            cols = list(header.columns)
            info["columns"] = cols
            info["has_subject_id"] = ("subject_id" in header.columns)

            # 粗略识别可能的时间列名（你也可以扩展关键词）
            time_keywords = ("time", "date", "admit", "discharge", "start", "end")
            info["time_cols"] = [c for c in cols if any(k in c.lower() for k in time_keywords)]

        except Exception as e:
            info["error"] = str(e)

        schema[fname[:-4]] = info  # key 用表名（去掉 .csv）

    return schema

schema = scan_tabular_schema(TAB_DIR)

# 打印到控制台（方便你快速看）
for table, info in schema.items():
    print("=" * 80)
    print(f"Table: {table}")
    if "error" in info:
        print(f"[ERROR] {info['error']}")
        continue
    print(f"has_subject_id: {info['has_subject_id']}")
    print(f"time_cols: {info['time_cols']}")
    print(f"num_cols: {len(info['columns'])}")
    print("columns:")
    print(info["columns"])

# 保存成 json（方便你后面人工筛选/写配置）
with open(OUT_JSON, "w", encoding="utf-8") as f:
    json.dump(schema, f, ensure_ascii=False, indent=2)

# 同时保存成 txt（更适合肉眼扫）
with open(OUT_TXT, "w", encoding="utf-8") as f:
    for table, info in schema.items():
        f.write("=" * 80 + "\n")
        f.write(f"Table: {table}\n")
        if "error" in info:
            f.write(f"[ERROR] {info['error']}\n")
            continue
        f.write(f"has_subject_id: {info['has_subject_id']}\n")
        f.write(f"time_cols: {info['time_cols']}\n")
        f.write(f"num_cols: {len(info['columns'])}\n")
        f.write("columns:\n")
        f.write(", ".join(info["columns"]) + "\n\n")

print("\nSaved schema files:")
print(" -", OUT_JSON)
print(" -", OUT_TXT)


Table: admissions
has_subject_id: True
time_cols: ['admittime', 'dischtime', 'deathtime', 'discharge_location', 'edregtime', 'edouttime']
num_cols: 16
columns:
['subject_id', 'hadm_id', 'admittime', 'dischtime', 'deathtime', 'admission_type', 'admission_location', 'discharge_location', 'insurance', 'language', 'marital_status', 'ethnicity', 'edregtime', 'edouttime', 'hospital_expire_flag', 'sample_id']
Table: demographics
has_subject_id: True
time_cols: ['gender']
num_cols: 7
columns:
['subject_id', 'gender', 'anchor_age', 'anchor_year', 'anchor_year_group', 'dod', 'sample_id']
Table: diagnoses_icd
has_subject_id: True
time_cols: []
num_cols: 7
columns:
['subject_id', 'hadm_id', 'seq_num', 'icd_code', 'icd_version', 'long_title', 'sample_id']
Table: icustays
has_subject_id: True
time_cols: ['intime', 'outtime']
num_cols: 9
columns:
['subject_id', 'hadm_id', 'stay_id', 'first_careunit', 'last_careunit', 'intime', 'outtime', 'los', 'sample_id']
Table: prescriptions
has_subject_id: True
t

# timeseries

In [11]:
import os
import json
import pandas as pd

ROOT = "E:/NUS/data/perdata/train_text_samples"
TS_DIR = os.path.join(ROOT, "timeseries_merged")

# 输出文件（可选）
OUT_JSON = os.path.join(TS_DIR, "_timeseries_schema_columns.json")
OUT_TXT  = os.path.join(TS_DIR, "_timeseries_schema_columns.txt")


def scan_timeseries_schema(ts_dir: str):
    csv_files = sorted([f for f in os.listdir(ts_dir) if f.lower().endswith(".csv")])

    schema = {}
    for fname in csv_files:
        path = os.path.join(ts_dir, fname)

        info = {
            "file": fname,
            "path": path,
            "has_subject_id": False,
            "time_cols_exact": [],   # 常见时间列名命中
            "time_cols_keyword": [], # 关键词匹配出来的
            "columns": [],
        }

        try:
            # 只读表头
            header = pd.read_csv(path, nrows=0)
            cols = list(header.columns)
            info["columns"] = cols
            info["has_subject_id"] = ("subject_id" in header.columns)

            # 1) 常见时间列名（timeseries 通常就这些）
            exact_time_names = {
                "charttime", "storetime",
                "starttime", "stoptime",
                "intime", "outtime",
                "eventtime", "endtime",
                "chartdate", "eventdate",
                "admittime", "dischtime",
            }
            info["time_cols_exact"] = [c for c in cols if c.lower() in exact_time_names]

            # 2) 关键词匹配（兜底）
            time_keywords = ("time", "date")
            info["time_cols_keyword"] = [
                c for c in cols
                if any(k in c.lower() for k in time_keywords)
            ]

        except Exception as e:
            info["error"] = str(e)

        schema[fname[:-4]] = info  # key 用表名（去掉 .csv）

    return schema


schema = scan_timeseries_schema(TS_DIR)

# 打印到控制台
for table, info in schema.items():
    print("=" * 80)
    print(f"Table: {table}")
    if "error" in info:
        print(f"[ERROR] {info['error']}")
        continue
    print(f"has_subject_id: {info['has_subject_id']}")
    print(f"time_cols_exact: {info['time_cols_exact']}")
    print(f"time_cols_keyword: {info['time_cols_keyword']}")
    print(f"num_cols: {len(info['columns'])}")
    print("columns:")
    print(info["columns"])

# 保存成 json
with open(OUT_JSON, "w", encoding="utf-8") as f:
    json.dump(schema, f, ensure_ascii=False, indent=2)

# 保存成 txt
with open(OUT_TXT, "w", encoding="utf-8") as f:
    for table, info in schema.items():
        f.write("=" * 80 + "\n")
        f.write(f"Table: {table}\n")
        if "error" in info:
            f.write(f"[ERROR] {info['error']}\n")
            continue
        f.write(f"has_subject_id: {info['has_subject_id']}\n")
        f.write(f"time_cols_exact: {info['time_cols_exact']}\n")
        f.write(f"time_cols_keyword: {info['time_cols_keyword']}\n")
        f.write(f"num_cols: {len(info['columns'])}\n")
        f.write("columns:\n")
        f.write(", ".join(info["columns"]) + "\n\n")

print("\nSaved schema files:")
print(" -", OUT_JSON)
print(" -", OUT_TXT)


Table: chartevents
has_subject_id: True
time_cols_exact: ['charttime', 'storetime']
time_cols_keyword: ['charttime', 'storetime', 'deltacharttime', 'cxr_time']
num_cols: 22
columns:
Table: inputevents
has_subject_id: True
time_cols_exact: ['starttime', 'endtime', 'storetime']
time_cols_keyword: ['starttime', 'endtime', 'storetime', 'cxr_time']
num_cols: 37
columns:
['subject_id', 'hadm_id', 'stay_id', 'starttime', 'endtime', 'storetime', 'itemid', 'amount', 'amountuom', 'rate', 'rateuom', 'orderid', 'linkorderid', 'ordercategoryname', 'secondaryordercategoryname', 'ordercomponenttypedescription', 'ordercategorydescription', 'patientweight', 'totalamount', 'totalamountuom', 'isopenbag', 'continueinnextdept', 'cancelreason', 'statusdescription', 'originalamount', 'originalrate', 'label', 'abbreviation', 'linksto', 'category', 'unitname', 'param_type', 'lownormalvalue', 'highnormalvalue', 'sample_id', 'cxr_time', 'delta_hours']
Table: labevents
has_subject_id: True
time_cols_exact: ['char

In [45]:
import os
import glob
import pandas as pd

ROOT = r"E:/NUS/data/perdata/train_text_all_samples/timeseries_merged"

def count_rows_fast(csv_path: str):
    # 逐行计数，几乎不占内存；一般 MIMIC 这种 CSV 没“字段内换行”，可放心用
    n = 0
    with open(csv_path, "rb") as f:
        for _ in f:
            n += 1
    return max(0, n - 1)  # 减 header

def human_size(num_bytes: int):
    gb = num_bytes / (1024**3)
    mb = num_bytes / (1024**2)
    return (f"{gb:.2f} GB" if gb >= 1 else f"{mb:.2f} MB")

rows = []
for fp in sorted(glob.glob(os.path.join(ROOT, "*.csv"))):
    size_bytes = os.path.getsize(fp)
    # 只读 header 拿列数
    cols = pd.read_csv(fp, nrows=0, low_memory=False).shape[1]
    nrows = count_rows_fast(fp)

    rows.append({
        "file": os.path.basename(fp),
        "size": human_size(size_bytes),
        "rows": nrows,
        "cols": cols,
        "path": fp,
    })

df = pd.DataFrame(rows).sort_values(["rows", "file"], ascending=[False, True])
pd.set_option("display.width", 200)
pd.set_option("display.max_columns", 50)
print(df.to_string(index=False))



               file      size     rows  cols                                                                             path
    chartevents.csv   6.53 GB 30153901    22     E:/NUS/data/perdata/train_text_all_samples/timeseries_merged\chartevents.csv
    inputevents.csv   6.49 GB 17881065    37     E:/NUS/data/perdata/train_text_all_samples/timeseries_merged\inputevents.csv
      labevents.csv 601.64 MB  2916563    23       E:/NUS/data/perdata/train_text_all_samples/timeseries_merged\labevents.csv
procedureevents.csv 371.39 MB  1151392    37 E:/NUS/data/perdata/train_text_all_samples/timeseries_merged\procedureevents.csv
   outputevents.csv  78.75 MB   427810    20    E:/NUS/data/perdata/train_text_all_samples/timeseries_merged\outputevents.csv


In [46]:
from dataset import bucketize_timeseries_csv

bucketize_timeseries_csv(
    TS_DIR="E:/NUS/data/perdata/train_text_all_samples/timeseries_merged",
    OUT_DIR="E:/NUS/data/perdata/train_text_all_samples/timeseries_bucketed",
    num_buckets=128,
    chunksize=100_000,
)



[TS-BUCKET] table=chartevents file=chartevents.csv
  chunk 0: rows=100,000
  chunk 1: rows=100,000
  chunk 2: rows=100,000
  chunk 3: rows=100,000
  chunk 4: rows=100,000
  chunk 5: rows=100,000
  chunk 6: rows=100,000
  chunk 7: rows=100,000
  chunk 8: rows=100,000
  chunk 9: rows=100,000
  chunk 10: rows=100,000
  chunk 11: rows=100,000
  chunk 12: rows=100,000
  chunk 13: rows=100,000
  chunk 14: rows=100,000
  chunk 15: rows=100,000
  chunk 16: rows=100,000
  chunk 17: rows=100,000
  chunk 18: rows=100,000
  chunk 19: rows=100,000
  chunk 20: rows=100,000
  chunk 21: rows=100,000
  chunk 22: rows=100,000
  chunk 23: rows=100,000
  chunk 24: rows=100,000
  chunk 25: rows=100,000
  chunk 26: rows=100,000
  chunk 27: rows=100,000
  chunk 28: rows=100,000
  chunk 29: rows=100,000
  chunk 30: rows=100,000
  chunk 31: rows=100,000
  chunk 32: rows=100,000
  chunk 33: rows=100,000
  chunk 34: rows=100,000
  chunk 35: rows=100,000
  chunk 36: rows=100,000
  chunk 37: rows=100,000
  chunk 

In [None]:
from dataset import *

import os
import json
import numpy as np
import pandas as pd
from collections import OrderedDict

from config import TAB_CFG, TS_CFG


# -----------------------------
# tabular helpers
# -----------------------------
def read_csv_light(path, usecols=None, dtype=None, parse_dates=None, chunksize=None):
    return pd.read_csv(
        path,
        usecols=usecols,
        dtype=dtype,
        parse_dates=parse_dates,
        low_memory=False,
        chunksize=chunksize,
    )

def build_subject_index(df, key="subject_id"):
    return df.groupby(key, sort=False).indices

def load_tabular_all(TAB_DIR: str):
    tab_df = {}
    tab_idx = {}

    for name, cfg in TAB_CFG.items():
        path = os.path.join(TAB_DIR, f"{name}.csv")
        if not os.path.exists(path):
            print(f"[TAB][SKIP] {name}.csv not found")
            continue

        df = read_csv_light(
            path,
            usecols=cfg.get("usecols"),
            dtype=cfg.get("dtype"),
            parse_dates=cfg.get("parse_dates"),
            chunksize=None,
        )

        if "subject_id" in df.columns:
            df["subject_id"] = pd.to_numeric(df["subject_id"], errors="coerce")
            df = df[df["subject_id"].notna()].copy()
            df["subject_id"] = df["subject_id"].astype("int32")

        tab_df[name] = df
        tab_idx[name] = build_subject_index(df, "subject_id") if "subject_id" in df.columns else {}
        print(f"[TAB] {name}: rows={len(df):,} cols={len(df.columns)}")

    return tab_df, tab_idx

def query_tabular_indices(tab_idx, subject_id: int):
    """只返回每个表该 subject_id 对应的行号列表（更省）"""
    sid = int(subject_id)
    out = {}
    for name, idx_map in tab_idx.items():
        inds = idx_map.get(sid)
        out[name] = inds.tolist() if inds is not None else []
    return out




def iter_all_samples(
    df_index: pd.DataFrame,
    IMG_DIR: str,
    tab_df, tab_idx,
    text_map_dicom, text_map_subject,
    ts_reader: TimeseriesBucketReader,
    ts_tables,
    hours: int = 48,
):
    """
    逐样本yield：每次只在内存里保留一个sample，适合全量不炸。
    """
    for row in df_index.itertuples(index=False):
        sample_id = getattr(row, "sample_id")
        subject_id = int(getattr(row, "subject_id"))
        dicom_id = getattr(row, "dicom_id", None)
        cxr_time = pd.to_datetime(getattr(row, "cxr_time"), errors="coerce")

        sample = {
            "sample_id": sample_id,
            "subject_id": subject_id,
            "dicom_id": dicom_id,
            "cxr_time": cxr_time,

            "label_any": getattr(row, "label_any", None),
            "label_multi": getattr(row, "label_multi", None),

            # meta如果你要也可以放：但别放太大字段
            "meta": row._asdict(),
        }

        # -------- image：这里是“真的读出来” --------
        img_path = os.path.join(IMG_DIR, f"{sample_id}.npy")
        sample["image"] = np.load(img_path) if os.path.exists(img_path) else None

        # -------- text：真的读出来 --------
        sample["text"] = query_text(
            text_map_dicom=text_map_dicom,
            text_map_subject=text_map_subject,
            subject_id=subject_id,
            dicom_id=dicom_id,
            cxr_time=cxr_time,
            tol_minutes=5,
        )

        # -------- tabular：真的读出来（按subject过滤）--------
        sample["tabular"] = query_tabular(tab_df, tab_idx, subject_id)

        # -------- timeseries：真的读出来（48h window）--------
        ts_data = {}
        for tname in ts_tables:
            ts_data[tname] = ts_reader.query_window(
                table=tname,
                subject_id=subject_id,
                cxr_time=cxr_time,
                hours=hours,
            )
        sample["timeseries"] = ts_data

        yield sample


def process_one(sample: dict):
    """
    你下一步处理写在这里：比如提特征、做聚合、保存到磁盘、喂模型等。
    这里给个占位示例：统计各表行数。
    """
    out = {
        "sample_id": sample["sample_id"],
        "ts_lens": {k: len(v) for k, v in sample["timeseries"].items()},
        "text_len": len(sample["text"]),
        "has_image": sample["image"] is not None,
    }
    return out


if __name__ == "__main__":
    ROOT = r"E:/NUS/data/perdata/train_text_all_samples"
    META_DIR = os.path.join(ROOT, "meta")
    IMG_DIR = os.path.join(ROOT, "images")
    TAB_DIR = os.path.join(ROOT, "tabular")
    TEXT_PATH = os.path.join(ROOT, "text.csv")

    TS_BUCKET_ROOT = r"E:/NUS/data/perdata/train_text_all_samples/timeseries_bucketed"
    SAMPLE_INDEX_PATH = os.path.join(META_DIR, "sample_index.json")

    with open(SAMPLE_INDEX_PATH, "r") as f:
        sample_index = json.load(f)
    df_index = pd.DataFrame(sample_index)
    print("Total samples:", len(df_index))

    tab_df, tab_idx = load_tabular_all(TAB_DIR)
    _, text_map_dicom, text_map_subject = load_text_table(TEXT_PATH)

    ts_reader = TimeseriesBucketReader(
        bucket_root=TS_BUCKET_ROOT,
        num_buckets=128,   # 这里要和你 bucketize 一致
        cache_size=4,
    )
    ts_tables = list(TS_CFG.keys())

    # ===== 分开存四部分 =====
    metas = []        # 只存必要meta（小）
    imgs  = []        # numpy array / None
    texts = []        # str
    tabs  = []        # dict(table_name -> df)
    tss   = []        # dict(ts_table -> df)

    for i, sample in enumerate(iter_all_samples(
        df_index=df_index,
        IMG_DIR=IMG_DIR,
        tab_df=tab_df,
        tab_idx=tab_idx,
        text_map_dicom=text_map_dicom,
        text_map_subject=text_map_subject,
        ts_reader=ts_reader,
        ts_tables=ts_tables,
        hours=48,
    )):
        # meta：建议只留你要的字段，别把 row._asdict() 全塞进去也行
        metas.append({
            "sample_id": sample["sample_id"],
            "subject_id": sample["subject_id"],
            "dicom_id": sample.get("dicom_id"),
            "cxr_time": sample.get("cxr_time"),
            "label_any": sample.get("label_any"),
            "label_multi": sample.get("label_multi"),
        })

        imgs.append(sample["image"])
        texts.append(sample["text"])
        tabs.append(sample["tabular"])
        tss.append(sample["timeseries"])

        if (i + 1) % 1000 == 0:
            print(f"[PROGRESS] {i+1:,}/{len(df_index):,}")

    print("done")
    print("metas:", len(metas), "imgs:", len(imgs), "texts:", len(texts), "tabs:", len(tabs), "tss:", len(tss))

    # ===== 你可以马上抽查第0条 =====
    k = 0
    print("sample_id:", metas[k]["sample_id"])
    print("img shape:", None if imgs[k] is None else imgs[k].shape)
    print("text len:", len(texts[k]))
    print("tab keys:", list(tabs[k].keys()))
    print("ts keys:", list(tss[k].keys()))




Total samples: 109968
[TAB] demographics: rows=109,968 cols=4
[TAB] admissions: rows=109,968 cols=8
[TAB] icustays: rows=109,968 cols=6
[TAB] transfers: rows=664,791 cols=6


ParserError: Error tokenizing data. C error: out of memory

: 

In [35]:
sid = int(s["subject_id"])
print("sid:", sid)

# 1) 你 reader 用的 buckets 数
print("reader.num_buckets:", ts_reader.num_buckets)

# 2) bucket 文件是否存在（你算出来的 bucket）
bid = sid % ts_reader.num_buckets
path = os.path.join(ts_reader.bucket_root, "chartevents", f"bucket_{bid:03d}.csv")
print("expected bucket path:", path, "exists:", os.path.exists(path))

# 3) 如果存在，看看这个 bucket 里到底有没有 sid
if os.path.exists(path):
    df0 = pd.read_csv(path, usecols=["subject_id"], low_memory=False)
    col = pd.to_numeric(df0["subject_id"], errors="coerce")
    print("bucket rows:", len(col), "hit:", (col == sid).sum())


sid: 10000032
reader.num_buckets: 8
expected bucket path: E:/NUS/data/perdata/train_text_samples/timeseries_bucketed\chartevents\bucket_000.csv exists: True
bucket rows: 90619 hit: 0


In [36]:
import pandas as pd

sid = 10000032
src = r"E:/NUS/data/perdata/train_text_samples/timeseries_merged/chartevents.csv"

found = False
for chunk in pd.read_csv(src, usecols=["subject_id"], chunksize=200000, low_memory=False):
    col = pd.to_numeric(chunk["subject_id"], errors="coerce")
    if (col == sid).any():
        found = True
        break

print("found sid in SOURCE chartevents:", found)


found sid in SOURCE chartevents: False


In [None]:
import os
import json
import numpy as np
import pandas as pd
from collections import OrderedDict

from config import TAB_CFG, TS_CFG

from dataset import load_tabular_all, TimeseriesBucketReader, load_text_table, query_text, query_tabular


def iter_all_samples(
    df_index: pd.DataFrame,
    IMG_DIR: str,
    tab_df, tab_idx,
    text_map_dicom, text_map_subject,
    ts_reader: TimeseriesBucketReader,
    ts_tables,
    hours: int = 48,
):
    """
    逐样本yield：每次只在内存里保留一个sample，适合全量不炸。
    """
    for row in df_index.itertuples(index=False):
        sample_id = getattr(row, "sample_id")
        subject_id = int(getattr(row, "subject_id"))
        dicom_id = getattr(row, "dicom_id", None)
        cxr_time = pd.to_datetime(getattr(row, "cxr_time"), errors="coerce")

        sample = {
            "sample_id": sample_id,
            "subject_id": subject_id,
            "dicom_id": dicom_id,
            "cxr_time": cxr_time,

            "label_any": getattr(row, "label_any", None),
            "label_multi": getattr(row, "label_multi", None),

            # meta如果你要也可以放：但别放太大字段
            "meta": row._asdict(),
        }

        # -------- image：这里是“真的读出来” --------
        img_path = os.path.join(IMG_DIR, f"{sample_id}.npy")
        sample["image"] = np.load(img_path) if os.path.exists(img_path) else None

        # -------- text：真的读出来 --------
        sample["text"] = query_text(
            text_map_dicom=text_map_dicom,
            text_map_subject=text_map_subject,
            subject_id=subject_id,
            dicom_id=dicom_id,
            cxr_time=cxr_time,
            tol_minutes=5,
        )

        # -------- tabular：真的读出来（按subject过滤）--------
        sample["tabular"] = query_tabular(tab_df, tab_idx, subject_id)

        # -------- timeseries：真的读出来（48h window）--------
        ts_data = {}
        for tname in ts_tables:
            ts_data[tname] = ts_reader.query_window(
                table=tname,
                subject_id=subject_id,
                cxr_time=cxr_time,
                hours=hours,
            )
        sample["timeseries"] = ts_data

        yield sample


def process_one(sample: dict):
    """
    你下一步处理写在这里：比如提特征、做聚合、保存到磁盘、喂模型等。
    这里给个占位示例：统计各表行数。
    """
    out = {
        "sample_id": sample["sample_id"],
        "ts_lens": {k: len(v) for k, v in sample["timeseries"].items()},
        "text_len": len(sample["text"]),
        "has_image": sample["image"] is not None,
    }
    return out


if __name__ == "__main__":
    ROOT = r"E:/NUS/data/perdata/train_text_samples"
    META_DIR = os.path.join(ROOT, "meta")
    IMG_DIR = os.path.join(ROOT, "images")
    TAB_DIR = os.path.join(ROOT, "tabular")
    TEXT_PATH = os.path.join(ROOT, "text.csv")

    TS_BUCKET_ROOT = r"E:/NUS/data/perdata/timeseries_bucketed"
    SAMPLE_INDEX_PATH = os.path.join(META_DIR, "sample_index.json")

    with open(SAMPLE_INDEX_PATH, "r") as f:
        sample_index = json.load(f)
    df_index = pd.DataFrame(sample_index)
    print("Total samples:", len(df_index))

    # load once
    tab_df, tab_idx = load_tabular_all(TAB_DIR)
    _, text_map_dicom, text_map_subject = load_text_table(TEXT_PATH)

    # ✅ 必须和 bucketize 一致
    ts_reader = TimeseriesBucketReader(
        bucket_root=TS_BUCKET_ROOT,
        num_buckets=128,
        cache_size=6,
    )
    ts_tables = list(TS_CFG.keys())

    # 全量 streaming 处理（不存list）
    results = []
    for i, sample in enumerate(iter_all_samples(
        df_index=df_index,
        IMG_DIR=IMG_DIR,
        tab_df=tab_df,
        tab_idx=tab_idx,
        text_map_dicom=text_map_dicom,
        text_map_subject=text_map_subject,
        ts_reader=ts_reader,
        ts_tables=ts_tables,
        hours=48,
    )):
        results.append(sample)

        if (i + 1) % 1000 == 0:
            print(f"[PROGRESS] {i+1:,}/{len(df_index):,}")

    print("done, results:", len(results))



ImportError: cannot import name 'load_text_table' from 'dataset' (f:\study\NUS\capstone\GMM\data\dataset.py)

In [32]:
s = results[0]   # 或者你当前 sample

print("image:", None if s["image"] is None else (s["image"].shape, s["image"].dtype))
print("text first 200:", s["text"][:200])

print("tabular keys:", list(s["tabular"].keys()))
print("admissions head:")
print(s["tabular"]["admissions"].head())

print("timeseries keys:", list(s["timeseries"].keys()))
print("chartevents head:")
print(s["timeseries"]["chartevents"].head())



image: ((224, 224), dtype('uint8'))
text first 200: 
tabular keys: ['demographics', 'admissions', 'icustays', 'transfers', 'prescriptions', 'procedures_icd']
admissions head:
   subject_id   hadm_id           admittime admission_type admission_location  \
0    10000032  29079034 2180-07-23 12:35:00       EW EMER.     EMERGENCY ROOM   
1    10000032  29079034 2180-07-23 12:35:00       EW EMER.     EMERGENCY ROOM   
2    10000032  29079034 2180-07-23 12:35:00       EW EMER.     EMERGENCY ROOM   
3    10000032  29079034 2180-07-23 12:35:00       EW EMER.     EMERGENCY ROOM   
4    10000032  29079034 2180-07-23 12:35:00       EW EMER.     EMERGENCY ROOM   

  insurance language marital_status  
0  Medicaid  ENGLISH        WIDOWED  
1  Medicaid  ENGLISH        WIDOWED  
2  Medicaid  ENGLISH        WIDOWED  
3  Medicaid  ENGLISH        WIDOWED  
4  Medicaid  ENGLISH        WIDOWED  
timeseries keys: ['chartevents', 'labevents', 'inputevents', 'outputevents', 'procedureevents']
chartevents he

In [38]:
   # 拿第一个 sample 或你指定的 sample

print("sample_id:", s["sample_id"])
print("image:", None if s["image"] is None else (s["image"].shape, s["image"].dtype))
print("text_len:", len(s["text"]))

print("tabular rows:")
for k, df in s["tabular"].items():
    print(" ", k, len(df))

print("timeseries rows (window):")
for k, df in s["timeseries"].items():
    print(" ", k, len(df))



sample_id: s_000000
image: ((224, 224), dtype('uint8'))
text_len: 0
tabular rows:
  demographics 10
  admissions 10
  icustays 10
  transfers 60
  prescriptions 240
  procedures_icd 0
timeseries rows (window):
  chartevents 0
  labevents 0
  inputevents 0
  outputevents 0
  procedureevents 0


In [42]:
import pandas as pd

sample_id = "s_000000"
src = r"E:/NUS/data/perdata/train_text_samples/timeseries_merged/chartevents.csv"

hit = []
for chunk in pd.read_csv(src, chunksize=200000, low_memory=False):
    sub = chunk[chunk["sample_id"] == sample_id]
    if len(sub):
        hit.append(sub)

df = pd.concat(hit, ignore_index=True) if hit else pd.DataFrame()
print("rows:", len(df))
print(df.head())



rows: 0
Empty DataFrame
Columns: []
Index: []


In [43]:
import pandas as pd

# 读数据
outputevents = pd.read_csv(
    r"E:\NUS\data\perdata\train_text_all_samples\timeseries_merged\outputevents.csv"
)
demographics = pd.read_csv(
    r"E:\NUS\data\perdata\train_text_all_samples\tabular\demographics.csv"
)

# 取 subject_id 集合
oe_ids = set(outputevents['subject_id'].dropna().unique())
demo_ids = set(demographics['subject_id'].dropna().unique())

print("outputevents 数量:", len(oe_ids))
print("demographics 数量:", len(demo_ids))
print("只在 outputevents:", len(oe_ids - demo_ids))
print("只在 demographics:", len(demo_ids - oe_ids))
print("是否完全一致:", oe_ids == demo_ids)


outputevents 数量: 1778
demographics 数量: 2990
只在 outputevents: 0
只在 demographics: 1212
是否完全一致: False
