# 1 Prepare the Dataset

Using the **BloodMNIST** dataset as an example.

Prepare the dataset and wrap the data into DataLoaders for batch feature extraction.

Prepare the following variables:
1. `label_list`: List of label texts, e.g. for BloodMNIST: `['basophil', 'eosinophil', ...]`
2. `label_prompt_template_list`: List of prompt templates for labels. Example:
    ```python
    label_prompt_template_list = [
        "a microscopic image of a {} cell",
        "a peripheral blood smear image of a {}",
        "a bloodcell of {}"
    ]
    ```
3. `concept_list`: List of concept texts (designed by experts), e.g. for BloodMNIST: `['Segmented nucleus','Band nucleus (band form)','Reniform / indented nucleus','Round nucleus', ...]`
4. `concept_prompt_template_list`: List of prompt templates for concepts. Example:
    ```python
    concept_prompt_template_list = [
        "a cell photo with sign of {}",
        "a photo of a cell with {}",
        "a cell image indicating {}",
    ]
    ```
5. `train_loader`, `test_loader`, `valid_loader`: DataLoaders (torch.utils.data.DataLoader) for training, test, and validation sets.

In [1]:
import torch
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms

import numpy as np

from tqdm import tqdm

import medmnist
from medmnist import BloodMNIST, INFO

In [None]:
# Select BloodMNIST dataset information
data_flag = 'bloodmnist'
info = INFO[data_flag]

In [None]:
# Label mapping dictionary (label_map)
label_map = {int(k): v for k, v in info['label'].items()}  # e.g. 0:'basophil', 1:'eosinophil', ...
label_map[3] = 'immature granulocytes'

# Create list of label texts
label_list = [label_map[idx] for idx in sorted(label_map.keys())]
label_list

['basophil',
 'eosinophil',
 'erythroblast',
 'immature granulocytes',
 'lymphocyte',
 'monocyte',
 'neutrophil',
 'platelet']

In [None]:
# Select 15 cellular morphological features as concepts (manually chosen)
cell_features = [
    {"en": "Segmented nucleus", "zh": "分叶细胞核"},
    {"en": "Band nucleus (band form)", "zh": "带状细胞核（未分叶）"},
    {"en": "Reniform / indented nucleus", "zh": "肾形/凹陷细胞核"},
    {"en": "Round nucleus", "zh": "圆形细胞核"},
    {"en": "Fine azurophilic granules", "zh": "细嗜天青颗粒"},
    {"en": "Eosinophilic granules", "zh": "嗜酸性颗粒（橙红）"},
    {"en": "Basophilic granules", "zh": "嗜碱性颗粒（深紫粗颗粒）"},
    {"en": "Basophilic cytoplasm", "zh": "嗜碱性胞质"},
    {"en": "Cytoplasmic vacuoles", "zh": "胞质空泡"},
    {"en": "High nuclear-to-cytoplasmic ratio", "zh": "高核浆比"},
    {"en": "Pale cytoplasm", "zh": "淡染胞质"},
    {"en": "Nucleated erythrocyte (erythroblast)", "zh": "有核红细胞（幼红细胞）"},
    {"en": "Platelet fragments / clumps", "zh": "血小板碎片/成团"},
    {"en": "Stain precipitate (artifact)", "zh": "染色沉淀（伪影）"},
    {"en": "Overlapping cell clumps (artifact)", "zh": "细胞重叠/成团（伪影）"},
]

# Generate concept text list
concept_list = [cell_feature["en"] for cell_feature in cell_features]

In [None]:
# Prepare prompt templates
concept_prompt_template_list = [
    "a cell photo with sign of {}",
    "a photo of a cell with {}",
    "a cell image indicating {}",
    "an image of a cell showing {}",
    "blood cell with {}",
    "a blood cell photo with sign of {}",
    "a photo of a blood cell with {}",
    "a blood cell image indicating {}",
    "an image of blood cell showing {}",
]

label_prompt_template_list = [
    "a microscopic image of a {} cell",
    "a peripheral blood smear image of a {}",
    "a bloodcell of {}",
]

DB_METADATA = {
    "label_texts": label_list,
    "prompt_temp_for_labels": label_prompt_template_list,
    "concept_texts": concept_list,
    "prompt_temp_for_concepts": concept_prompt_template_list,
}

In [None]:
# Acquire BloodMNIST subclass via DataClass
DataClass = getattr(medmnist, info['python_class'])

common_tf = transforms.Compose([
    # Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255]
    # to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
    transforms.ToTensor(),
])

ds_train = DataClass(split='train', size=224, as_rgb=True, mmap_mode='r', transform=common_tf, download=True)
ds_test  = DataClass(split='test',  size=224, as_rgb=True, mmap_mode='r', transform=common_tf, download=True)
ds_valid = DataClass(split='val',   size=224, as_rgb=True, mmap_mode='r', transform=common_tf, download=True)

print("num_train:", len(ds_train), "sample shape:", np.array(ds_train[0][0]).shape)  # (224, 224, 3)
print("num_test:", len(ds_test), "sample shape:", np.array(ds_test[0][0]).shape)    # (224, 224, 3)
print("num_valid:", len(ds_valid), "sample shape:", np.array(ds_valid[0][0]).shape) # (224, 224, 3)

Using downloaded and verified file: C:\Users\chenk\.medmnist\bloodmnist_224.npz
Using downloaded and verified file: C:\Users\chenk\.medmnist\bloodmnist_224.npz
Using downloaded and verified file: C:\Users\chenk\.medmnist\bloodmnist_224.npz
num_train: 11959 sample shape: (3, 224, 224)
num_test: 3421 sample shape: (3, 224, 224)
num_valid: 1712 sample shape: (3, 224, 224)


In [None]:
# Create DataLoaders
train_loader = DataLoader(ds_train, batch_size=64, shuffle=False)
test_loader  = DataLoader(ds_test, batch_size=64, shuffle=False)
valid_loader = DataLoader(ds_valid, batch_size=64, shuffle=False)

# 2 Load the CLIP Model

In [None]:
# Load CLIP model from HuggingFace
from transformers import AutoModel, AutoProcessor
from huggingface_hub import login, whoami
login(token="<YOUR_HF_TOKEN>")  # replace <YOUR_HF_TOKEN> with your actual token
print(whoami())

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: read).
Your token has been saved to C:\Users\chenk\.cache\huggingface\token
Login successful
{'type': 'user', 'id': '6693342048f8a55b7eba3d8e', 'name': 'takedachia', 'fullname': 'Jack Chan', 'email': 'chenkai1989129@gmail.com', 'emailVerified': True, 'canPay': False, 'periodEnd': None, 'isPro': False, 'avatarUrl': '/avatars/df59d16febde1a73fc98bf70f5476ba2.svg', 'orgs': [], 'auth': {'type': 'access_token', 'accessToken': {'displayName': 'ConceptCLIP', 'role': 'read', 'createdAt': '2025-08-20T17:55:27.543Z'}}}


In [None]:
# Load model and processor (first time will download to local cache)
model = AutoModel.from_pretrained('JerrryNie/ConceptCLIP', trust_remote_code=True)
processor = AutoProcessor.from_pretrained('JerrryNie/ConceptCLIP', trust_remote_code=True)

The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.


In [None]:
# Move model to GPU if available (else CPU)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()

In [None]:
# Adjust processor image preprocessing settings
# BloodMNIST DataLoader already yields float tensors in [0,1], so we disable rescale and enable normalize
processor.image_processor.do_rescale = False
processor.image_processor.do_normalize = True

# 3 Use Text and Image Encoders to Infer, Extract, and Store Features

## (1) Deploy H5 database for storing outputs

In [None]:
import os, h5py, hashlib, json, time
from pathlib import Path
from typing import Optional, Sequence

OUT_PATH = "./conceptclip_features.h5"
D       = 1152
T_img   = 729
DTYPE_DYNAMIC = "float16"   # Dynamic parts (image_*) can be stored in lower precision to save space
DTYPE_TEXT    = "float32"   # Text features are recommended to remain in higher precision
DTYPE_STR     = h5py.string_dtype(encoding="utf-8")

def _serialize_for_attr(value):
    """Serialize a value for storing as an HDF5 file attribute."""
    return json.dumps(value, ensure_ascii=False)

def _read_json_attr(f: h5py.File, key: str, default=None):
    """Read a JSON-formatted attribute from the root of the H5 file and parse it."""
    if key not in f.attrs:
        return default
    raw = f.attrs[key]
    if isinstance(raw, bytes):
        raw = raw.decode("utf-8")
    if isinstance(raw, str):
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            return raw
    return raw

def _hash_texts(texts: Sequence[str]) -> str:
    """Compute a SHA256 hash for a list of texts."""
    h = hashlib.sha256()
    for t in texts:
        h.update(t.encode("utf-8")); h.update(b"\0")
    return h.hexdigest()

def init_file(path, d=D, t_img=T_img, metadata: Optional[dict] = None):
    """Initialize HDF5 file structure if missing; open in append mode."""
    if not os.path.exists(path):
        with h5py.File(path, "w") as f:
            # Create main datasets
            f.create_dataset("image_features",
                shape=(0, d), maxshape=(None, d),
                chunks=(64, d), dtype=DTYPE_DYNAMIC, compression="lzf")
            f.create_dataset("image_token_features",
                shape=(0, t_img, d), maxshape=(None, t_img, d),
                chunks=(4, t_img, d), dtype=DTYPE_DYNAMIC, compression="lzf")
            f.create_dataset("ids",
                shape=(0,), maxshape=(None,),
                chunks=(4096,), dtype="int64", compression="lzf")
            f.create_dataset("labels",
                shape=(0,), maxshape=(None,),
                chunks=(4096,), dtype="int64", compression="lzf")
            f.create_dataset("split",
                shape=(0,), maxshape=(None,),
                chunks=(4096,), dtype=DTYPE_STR, compression="lzf")
            
            # Create prompt-templates group
            f.create_group("templates")
            
            # Add global attributes (model parameters)
            f.attrs["D"] = d
            f.attrs["T_img"] = t_img
            f.attrs["created_at"] = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
            f.attrs["version"] = "1.0"
            
    f = h5py.File(path, "a")
    current_len = f["image_features"].shape[0] if "image_features" in f else 0
    if "image_token_features" not in f:
        f.create_dataset("image_token_features",
            shape=(current_len, t_img, d), maxshape=(None, t_img, d),
            chunks=(4, t_img, d), dtype=DTYPE_DYNAMIC, compression="lzf")
    if "ids" not in f:
        f.create_dataset("ids",
            shape=(current_len,), maxshape=(None,),
            chunks=(4096,), dtype="int64", compression="lzf")
    if "labels" not in f:
        ds_labels = f.create_dataset("labels",
            shape=(current_len,), maxshape=(None,),
            chunks=(4096,), dtype="int64", compression="lzf")
        if current_len > 0:
            ds_labels[:] = np.full((current_len,), -1, dtype="int64")
    if "split" not in f:
        ds_split = f.create_dataset("split",
            shape=(current_len,), maxshape=(None,),
            chunks=(4096,), dtype=DTYPE_STR, compression="lzf")
        if current_len > 0:
            ds_split[:] = np.array(["unspecified"] * current_len, dtype=object)
    if "templates" not in f:
        f.create_group("templates")
    if metadata:
        for key, value in metadata.items():
            try:
                f.attrs[key] = _serialize_for_attr(value)
            except TypeError:
                f.attrs[key] = _serialize_for_attr(str(value))
    return f

def _to_np(x, dtype):
    """Convert tensor to numpy array with specified dtype."""
    if torch.is_tensor(x):
        x = x.detach().cpu().numpy()
    return x.astype(dtype, copy=False)

def append_batch(f: h5py.File,
                 image_feats: torch.Tensor,            # [B, D]
                 image_token_feats: torch.Tensor,      # [B, T_img, D]
                 ids: np.ndarray,                     # [B] int64
                 split_names: Optional[Sequence[str]] = None,
                 labels: Optional[np.ndarray] = None):
    """Append a batch of data to the HDF5 file."""
    ds_img = f["image_features"]
    ds_tok = f["image_token_features"]
    ds_ids = f["ids"]
    ds_labels = f["labels"]
    ds_split = f["split"]
    B = image_feats.shape[0]

    n0 = ds_img.shape[0]
    ds_img.resize(n0 + B, axis=0)
    ds_tok.resize(n0 + B, axis=0)
    ds_ids.resize(n0 + B, axis=0)
    ds_labels.resize(n0 + B, axis=0)
    ds_split.resize(n0 + B, axis=0)

    ds_img[n0:n0+B, :]      = _to_np(image_feats, DTYPE_DYNAMIC)
    ds_tok[n0:n0+B, :, :]   = _to_np(image_token_feats, DTYPE_DYNAMIC)
    ds_ids[n0:n0+B]         = ids.astype("int64", copy=False)

    if labels is None:
        labels_arr = np.full((B,), -1, dtype="int64")
    else:
        labels_arr = np.asarray(labels, dtype="int64").reshape(-1)
        if labels_arr.shape[0] != B:
            raise ValueError(f"Label count {labels_arr.shape[0]} does not match batch size {B}.")
    ds_labels[n0:n0+B] = labels_arr

    if split_names is None:
        split_arr = np.array(["unspecified"] * B, dtype=object)
    else:
        split_arr = np.asarray(list(split_names), dtype=object).reshape(-1)
        if split_arr.shape[0] != B:
            raise ValueError(f"Split count {split_arr.shape[0]} does not match batch size {B}.")
    ds_split[n0:n0+B] = split_arr

    f.flush()

def write_template(f: h5py.File,
                   template_id: str,
                   texts: Sequence[str],
                   text_features: torch.Tensor,            # [K, D]
                   text_token_features: Optional[torch.Tensor]=None # [K, T_txt, D]
                   ):
    """Write text templates and their features."""
    g_root = f["templates"]
    if template_id in g_root:
        del g_root[template_id]             # Overwrite existing (change logic if verification needed)
    g = g_root.create_group(template_id)

    # Main data
    tf  = _to_np(text_features, DTYPE_TEXT)         # [K, D]
    g.create_dataset("text_features", data=tf, compression="lzf")

    if text_token_features is not None:
        ttf = _to_np(text_token_features, DTYPE_TEXT)     # [K, T_txt, D]
        g.create_dataset("text_token_features", data=ttf, compression="lzf")
        T_txt = ttf.shape[1]
    else:
        T_txt = -1

    # Save original prompts (variable-length UTF-8 strings)
    dt_str = h5py.string_dtype(encoding='utf-8')
    ds_txt = g.create_dataset("texts", shape=(len(texts),), dtype=dt_str, compression="lzf")
    ds_txt[:] = np.array(list(texts), dtype=object)

    # Metadata
    g.attrs["K"]          = tf.shape[0]
    g.attrs["D"]          = tf.shape[1]
    g.attrs["T_txt"]      = int(T_txt)
    g.attrs["texts_hash"] = _hash_texts(texts)
    g.attrs["created_at"] = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    f.flush()

def save_model_params(f: h5py.File, 
                      logit_scale: float = None,
                      logit_bias: float = None,
                      concept_logit_scale: float = None,
                      concept_logit_bias: float = None):
    """Persist key model parameters to root attributes of the HDF5 file."""
    if logit_scale is not None:
        f.attrs["logit_scale"] = float(logit_scale)
    if logit_bias is not None:
        f.attrs["logit_bias"] = float(logit_bias)
    if concept_logit_scale is not None:
        f.attrs["concept_logit_scale"] = float(concept_logit_scale)
    if concept_logit_bias is not None:
        f.attrs["concept_logit_bias"] = float(concept_logit_bias)
    f.flush()

def load_model_params(f: h5py.File) -> dict:
    """Load model parameters from file attributes."""
    params = {}
    for key in ["logit_scale", "logit_bias", "concept_logit_scale", "concept_logit_bias"]:
        if key in f.attrs:
            params[key] = f.attrs[key]
    return params

## (2) Compute and save concept text features to H5 (using the Text Encoder)

In [26]:
def _maybe_to_scalar(x):
    if x is None:
        return None
    if isinstance(x, (int, float)):
        return float(x)
    if torch.is_tensor(x):
        x = x.detach()
        if x.numel() == 1:
            return float(x.item())
        try:
            return float(x.squeeze().item())
        except Exception:
            return None
    if hasattr(x, "item"):
        try:
            return float(x.item())
        except Exception:
            return None
    return None

stored_templates = []
with init_file(OUT_PATH, metadata=DB_METADATA) as f:
    concept_texts_in_file = _read_json_attr(f, "concept_texts", concept_list)
    concept_template_list = _read_json_attr(f, "prompt_temp_for_concepts", concept_prompt_template_list)
    label_texts_in_file = _read_json_attr(f, "label_texts", label_list)
    label_template_list = _read_json_attr(f, "prompt_temp_for_labels", label_prompt_template_list)

    template_texts_map = {}
    if isinstance(concept_template_list, (list, tuple)) and concept_texts_in_file:
        for idx, template in enumerate(concept_template_list, start=1):
            formatted_texts = [template.format(text) for text in concept_texts_in_file]
            template_key = f"concept_prompts_t{idx:02d}"
            template_texts_map[template_key] = formatted_texts
    if isinstance(label_template_list, (list, tuple)) and label_texts_in_file:
        for idx, template in enumerate(label_template_list, start=1):
            formatted_label_texts = [template.format(text) for text in label_texts_in_file]
            template_key = f"label_prompts_t{idx:02d}"
            template_texts_map[template_key] = formatted_label_texts

    with torch.no_grad():
        for template_id, template_texts in template_texts_map.items():
            text_inputs = processor(text=template_texts, return_tensors="pt", padding=True, truncation=True).to(device)
            text_cls, text_tokens = model.encode_text(text_inputs["input_ids"], normalize=True)  # [K, D], [K, T_txt, D]
            text_tokens_proj = model.text_proj(text_tokens) if hasattr(model, "text_proj") else text_tokens
            write_template(f, template_id=template_id, texts=template_texts, text_features=text_cls, text_token_features=text_tokens_proj)
            stored_templates.append((template_id, text_cls.shape[0]))
    params_kwargs = {
        "logit_scale": _maybe_to_scalar(getattr(model, "logit_scale", None)),
        "logit_bias": _maybe_to_scalar(getattr(model, "logit_bias", None)),
        "concept_logit_scale": _maybe_to_scalar(getattr(model, "concept_logit_scale", None)),
        "concept_logit_bias": _maybe_to_scalar(getattr(model, "concept_logit_bias", None)),
    }
    save_model_params(f, **{k: v for k, v in params_kwargs.items() if v is not None})

print("Stored templates:")
for template_id, count in stored_templates:
    print(f"  - {template_id}: {count} prompts")

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Stored templates:
  - concept_prompts_t01: 15 prompts
  - concept_prompts_t02: 15 prompts
  - concept_prompts_t03: 15 prompts
  - concept_prompts_t04: 15 prompts
  - concept_prompts_t05: 15 prompts
  - concept_prompts_t06: 15 prompts
  - concept_prompts_t07: 15 prompts
  - concept_prompts_t08: 15 prompts
  - concept_prompts_t09: 15 prompts
  - label_prompts_t01: 8 prompts
  - label_prompts_t02: 8 prompts
  - label_prompts_t03: 8 prompts


## (3) Compute and save image (token) features to H5 (using the Image Encoder)
Token refers to each 27×27 patch.

In [None]:
# Use Image Encoder to compute and save image (token) features to H5

# Define data splits and their loaders
split_loaders = [
    ("train", train_loader),
    ("val", valid_loader),
    ("test", test_loader),
]

def encode_and_store_image_features(split_loaders):
    """Traverse each data split and write image features to HDF5."""
    written_counts = {}
    with torch.no_grad():
        with init_file(OUT_PATH, metadata=DB_METADATA) as f:
            existing_counts = {}
            if "split_counts" in f.attrs:
                try:
                    existing_counts = json.loads(f.attrs["split_counts"])
                except Exception:
                    existing_counts = {}
            start_idx = f["image_features"].shape[0]
            for split_name, loader in split_loaders:
                expected = len(loader.dataset)
                recorded = int(existing_counts.get(split_name, 0) or 0)
                if recorded >= expected:
                    print(f"Skip {split_name}: already stored {recorded}/{expected} samples.")
                    continue

                processed_in_split = 0
                for imgs, labels in tqdm(loader, desc=f"{split_name} split"):
                    batch_total = imgs.shape[0]
                    if processed_in_split + batch_total <= recorded:
                        processed_in_split += batch_total
                        continue
                    if processed_in_split < recorded:
                        offset = recorded - processed_in_split
                        imgs = imgs[offset:]
                        labels = labels[offset:]
                        processed_in_split = recorded
                        batch_total = imgs.shape[0]

                    pixel_inputs = processor(images=imgs, return_tensors="pt", padding=True, truncation=True)
                    pixel_inputs = pixel_inputs["pixel_values"].to(device)
                    img_cls, img_tokens = model.encode_image(pixel_inputs, normalize=True)   # [B,1152], [B,729,1152]
                    img_tokens_proj = model.image_proj(img_tokens) if hasattr(model, "image_proj") else img_tokens  #  [B,729,1152]
                    batch_size = img_cls.shape[0]
                    batch_ids = np.arange(start_idx, start_idx + batch_size, dtype=np.int64)
                    labels_tensor = labels.view(-1) if hasattr(labels, "view") else labels
                    labels_np = np.asarray(labels_tensor, dtype=np.int64).reshape(-1)
                    append_batch(
                        f,
                        image_feats=img_cls,
                        image_token_feats=img_tokens_proj,
                        ids=batch_ids,
                        split_names=[split_name] * batch_size,
                        labels=labels_np,
                    )
                    start_idx += batch_size
                    processed_in_split += batch_size
                    written_counts[split_name] = written_counts.get(split_name, 0) + batch_size

            # Update split counts in file attributes
            for split_name, count in written_counts.items():
                existing_counts[split_name] = int(existing_counts.get(split_name, 0) or 0) + count
            f.attrs["split_counts"] = json.dumps(existing_counts)
            f.flush()

    total_new = sum(written_counts.values())
    if total_new == 0:
        print("No new samples were appended.")
    else:
        print(f"Appended {total_new} samples to {OUT_PATH}.")
        for split_name, count in written_counts.items():
            print(f"  - {split_name}: {count}")

encode_and_store_image_features(split_loaders)

Skip train: already stored 11959/11959 samples.
Skip val: already stored 1712/1712 samples.
Skip test: already stored 3421/3421 samples.
No new samples were appended.


## (4) Validate H5 contents

In [28]:
def validate_h5(path: str = OUT_PATH, max_templates: int = 2):
    if not os.path.exists(path):
        print(f"File '{path}' does not exist yet.")
        return
    with h5py.File(path, "r") as f:
        num_samples = f["image_features"].shape[0]
        print(f"Total samples stored: {num_samples}")
        print(f"image_features dataset shape: {f['image_features'].shape}")
        print(f"image_token_features dataset shape: {f['image_token_features'].shape}")
        ids_preview = f["ids"][:5]
        labels_preview = f["labels"][:5]
        split_preview = f["split"][:5]
        print(f"Preview ids: {ids_preview}")
        print(f"Preview labels: {labels_preview}")
        print(f"Preview splits: {split_preview}")
        try:
            split_counts = json.loads(f.attrs.get("split_counts", "{}"))
        except Exception:
            split_counts = {}
        print(f"Recorded split counts: {split_counts}")
        metadata_keys = [
            "label_texts",
            "prompt_temp_for_labels",
            "concept_texts",
            "prompt_temp_for_concepts",
        ]
        for key in metadata_keys:
            value = _read_json_attr(f, key, None)
            if value is None:
                continue
            if isinstance(value, list):
                preview = value[:3] + (["..."] if len(value) > 3 else [])
                print(f"{key}: {preview}")
            else:
                print(f"{key}: {value}")
        templates = list(f["templates"].keys())
        print(f"Templates stored: {templates}")
        for template_id in templates[:max_templates]:
            g = f["templates"][template_id]
            texts_sample = g["texts"][:3] if "texts" in g else []
            print(f"  • Template '{template_id}' -> K={g.attrs.get('K')}, D={g.attrs.get('D')}, sample texts={texts_sample}")
        params = load_model_params(f)
        print(f"Stored model params: {params}")

validate_h5()

Total samples stored: 17092
image_features dataset shape: (17092, 1152)
image_token_features dataset shape: (17092, 729, 1152)
Preview ids: [0 1 2 3 4]
Preview labels: [7 3 6 6 7]
Preview splits: [b'train' b'train' b'train' b'train' b'train']
Recorded split counts: {'train': 11959, 'val': 1712, 'test': 3421}
label_texts: ['basophil', 'eosinophil', 'erythroblast', '...']
prompt_temp_for_labels: ['a microscopic image of a {} cell', 'a peripheral blood smear image of a {}', 'a bloodcell of {}']
concept_texts: ['Segmented nucleus', 'Band nucleus (band form)', 'Reniform / indented nucleus', '...']
prompt_temp_for_concepts: ['a cell photo with sign of {}', 'a photo of a cell with {}', 'a cell image indicating {}', '...']
Templates stored: ['concept_prompts_t01', 'concept_prompts_t02', 'concept_prompts_t03', 'concept_prompts_t04', 'concept_prompts_t05', 'concept_prompts_t06', 'concept_prompts_t07', 'concept_prompts_t08', 'concept_prompts_t09', 'label_prompts_t01', 'label_prompts_t02', 'label_