In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# 1 Data Preprocessing and Building Dataset & DataLoader

In [2]:
DATASET_FOLDER = "./datasets/curation2/lidc_patches_all"  # Path to the dataset folder (pad=5)

In [3]:
meta_df = pd.read_csv(f"{DATASET_FOLDER}/all_patches_metadata.csv")

In [4]:
# Check cases where a single patient has multiple scans
meta_df.groupby('patient_id')['scan_id'].nunique().sort_values(ascending=False).head(10)

patient_id
LIDC-IDRI-0132    2
LIDC-IDRI-0355    2
LIDC-IDRI-0151    2
LIDC-IDRI-0315    2
LIDC-IDRI-0442    2
LIDC-IDRI-0484    2
LIDC-IDRI-0332    2
LIDC-IDRI-0365    2
LIDC-IDRI-0677    1
LIDC-IDRI-0670    1
Name: scan_id, dtype: int64

In [5]:
meta_df[meta_df['patient_id']=='LIDC-IDRI-0484']

Unnamed: 0,scan_id,patient_id,nodule_index,k_global,img_path,mask_path,area_mm2,nodule_bbox_xmin,nodule_bbox_ymin,nodule_bbox_xmax,nodule_bbox_ymax,ann_subtlety,ann_internalStructure,ann_calcification,ann_sphericity,ann_margin,ann_lobulation,ann_spiculation,ann_texture,ann_malignancy
2314,489,LIDC-IDRI-0484,0,34,./lidc_patches_all\LIDC-IDRI-0484\LIDC-IDRI-04...,./lidc_patches_all\LIDC-IDRI-0484\LIDC-IDRI-04...,112.352547,361,122,400,171,5,2,6,5,4,4,4,5,3
2315,489,LIDC-IDRI-0484,0,35,./lidc_patches_all\LIDC-IDRI-0484\LIDC-IDRI-04...,./lidc_patches_all\LIDC-IDRI-0484\LIDC-IDRI-04...,163.557468,361,122,400,171,5,2,6,5,4,4,4,5,3
2316,489,LIDC-IDRI-0484,0,36,./lidc_patches_all\LIDC-IDRI-0484\LIDC-IDRI-04...,./lidc_patches_all\LIDC-IDRI-0484\LIDC-IDRI-04...,194.876984,361,122,400,171,5,2,6,5,4,4,4,5,3
2317,489,LIDC-IDRI-0484,0,37,./lidc_patches_all\LIDC-IDRI-0484\LIDC-IDRI-04...,./lidc_patches_all\LIDC-IDRI-0484\LIDC-IDRI-04...,180.957199,361,122,400,171,5,2,6,5,4,4,4,5,3
2318,489,LIDC-IDRI-0484,0,38,./lidc_patches_all\LIDC-IDRI-0484\LIDC-IDRI-04...,./lidc_patches_all\LIDC-IDRI-0484\LIDC-IDRI-04...,86.998653,361,122,400,171,5,2,6,5,4,4,4,5,3
2319,489,LIDC-IDRI-0484,1,36,./lidc_patches_all\LIDC-IDRI-0484\LIDC-IDRI-04...,./lidc_patches_all\LIDC-IDRI-0484\LIDC-IDRI-04...,51.702057,261,112,284,135,2,1,6,5,3,5,5,5,4
2320,490,LIDC-IDRI-0484,0,39,./lidc_patches_all\LIDC-IDRI-0484\LIDC-IDRI-04...,./lidc_patches_all\LIDC-IDRI-0484\LIDC-IDRI-04...,127.874329,355,130,392,168,5,1,6,4,4,1,1,4,4
2321,490,LIDC-IDRI-0484,0,40,./lidc_patches_all\LIDC-IDRI-0484\LIDC-IDRI-04...,./lidc_patches_all\LIDC-IDRI-0484\LIDC-IDRI-04...,158.705688,355,130,392,168,5,1,6,4,4,1,1,4,4
2322,490,LIDC-IDRI-0484,0,41,./lidc_patches_all\LIDC-IDRI-0484\LIDC-IDRI-04...,./lidc_patches_all\LIDC-IDRI-0484\LIDC-IDRI-04...,183.471863,355,130,392,168,5,1,6,4,4,1,1,4,4
2323,490,LIDC-IDRI-0484,0,42,./lidc_patches_all\LIDC-IDRI-0484\LIDC-IDRI-04...,./lidc_patches_all\LIDC-IDRI-0484\LIDC-IDRI-04...,144.553589,355,130,392,168,5,1,6,4,4,1,1,4,4


## 1 Filter nodules with area_mm2 >= 50

In [6]:
meta_df_selected = meta_df[meta_df['area_mm2'] >= 50].copy()
meta_df_selected.shape

(2532, 20)

In [7]:
# Check patient count after filtering
print(f"Patient count before filtering: {meta_df['patient_id'].nunique()}, Patient count after filtering: {meta_df_selected['patient_id'].nunique()}")

Patient count before filtering: 875, Patient count after filtering: 440


## 2 Select specified columns

In [8]:
columns_selected = ['scan_id', 'patient_id', 'nodule_index', 'img_path','area_mm2',
                    'ann_subtlety', 'ann_internalStructure', 'ann_calcification', 'ann_sphericity',
                    'ann_margin', 'ann_lobulation', 'ann_spiculation', 'ann_texture',
                    'ann_malignancy']

In [9]:
meta_df_selected = meta_df_selected[columns_selected].copy()

## 3 Classify into 2 categories based on 'ann_malignancy'
Malignant if greater than 3, benign if less than or equal to 3

In [10]:
meta_df_selected['ann_malignancy'].unique()

array([4, 3, 2, 5, 1], dtype=int64)

In [11]:
# Malignant (1) if > 3, benign (0) if <= 3
meta_df_selected['malignancy_label'] = meta_df_selected['ann_malignancy'].apply(lambda x: 1 if x > 3 else 0)
meta_df_selected['malignancy_label'].value_counts()

malignancy_label
1    1667
0     865
Name: count, dtype: int64

## 4 Add column nodule_index_in_patient based on scan_id and nodule_index
Nodules from different scans of the same patient are treated as different nodules, re-indexed

In [12]:
for pid, sub_df in meta_df_selected.groupby('patient_id'):
    # Get all (scan_id, nodule_index) combinations for this patient
    multi_indices = sub_df.groupby(['scan_id', 'nodule_index']).count().index
    # Create new nodule number, re-indexed by patient_id
    for idx, multi_index in enumerate(multi_indices):
        scan_id, nodule_index = multi_index
        meta_df_selected.loc[(meta_df_selected['patient_id']==pid) & 
                             (meta_df_selected['scan_id']==scan_id) & 
                             (meta_df_selected['nodule_index']==nodule_index), 'nodule_index_in_patient'] = idx

In [13]:
# Convert nodule_index_in_patient column to integer type
meta_df_selected['nodule_index_in_patient'] = meta_df_selected['nodule_index_in_patient'].astype(int)

## 5 Get image shape based on img_path column, add column: img_shape

In [14]:
# Curate image path based on img_path column
def curate_img_path(img_path, DATASET_FOLDER):
    if img_path.startswith('./lidc_patches_all'):
        img_path = img_path.replace('./lidc_patches_all', DATASET_FOLDER)
    else:
        print(f"Unexpected img_path format: {img_path}")
        return None
    return img_path

# Get image shape based on img_path column
def get_img_shape(img_path):
    # img_path = curate_img_path(img_path, DATASET_FOLDER)
    if img_path:
    # print(f"Loading image from: {img_path}")
    # Read image
        img = plt.imread(img_path)
        return img.shape
    else:
        return None

img_path_sample = meta_df_selected['img_path'].iloc[100]
img_shape_sample = get_img_shape(curate_img_path(img_path_sample, DATASET_FOLDER))
img_shape_sample

(30, 28)

In [15]:
# Add column: img_shape
# Use progress bar
from tqdm import tqdm
tqdm.pandas()
meta_df_selected['img_path_curated'] = meta_df_selected['img_path'].apply(lambda x: curate_img_path(x, DATASET_FOLDER))
meta_df_selected['img_shape'] = meta_df_selected['img_path_curated'].progress_apply(get_img_shape)

100%|██████████| 2532/2532 [00:39<00:00, 64.28it/s]
100%|██████████| 2532/2532 [00:39<00:00, 64.28it/s]


In [16]:
meta_df_selected['img_shape_H'] = meta_df_selected['img_shape'].apply(lambda x: x[0])
meta_df_selected['img_shape_W'] = meta_df_selected['img_shape'].apply(lambda x: x[1])

In [17]:
meta_df_selected.columns

Index(['scan_id', 'patient_id', 'nodule_index', 'img_path', 'area_mm2',
       'ann_subtlety', 'ann_internalStructure', 'ann_calcification',
       'ann_sphericity', 'ann_margin', 'ann_lobulation', 'ann_spiculation',
       'ann_texture', 'ann_malignancy', 'malignancy_label',
       'nodule_index_in_patient', 'img_path_curated', 'img_shape',
       'img_shape_H', 'img_shape_W'],
      dtype='object')

## 6 Now start building Dataset and DataLoader

Current approach:
1. Filter out columns `['patient_id', 'nodule_index_in_patient', 'img_path_curated', 'area_mm2',
       'malignancy_label', 'img_shape_H', 'img_shape_W']` to form a new DataFrame,
       and save as a CSV file as the curated meta file.
2. Split Dataset by `patient_id`+`nodule_index_in_patient`, inject into HDF5 database (split field).<br>
   Split name format: `patient_id`+`_`+`nodule_index_in_patient`

```python
# Original dataloader definition list:
split_loaders = [
    ("train", train_loader),
    ("val", valid_loader),
    ("test", test_loader),
]

# Now each nodule of each patient is a split, pseudocode:
split_loaders = []
for pid in patient_ids:
    for nid in nodule_indices_for_patient[pid]:
        split_name = f"{pid}_{nid}"
        split_loader = DataLoader( ... )  # Create corresponding DataLoader
        split_loaders.append((split_name, split_loader))
```

### (1) Dataset Statistics

In [18]:
df_selected = meta_df_selected[['patient_id', 'nodule_index_in_patient', 'img_path_curated', 'area_mm2',
       'malignancy_label', 'img_shape_H', 'img_shape_W']].copy()
df_selected['pid_nid_combo'] = df_selected['patient_id'] + "_" + df_selected['nodule_index_in_patient'].astype(str)
df_selected = df_selected[['patient_id', 'pid_nid_combo', 'img_path_curated', 'area_mm2',
       'malignancy_label', 'img_shape_H', 'img_shape_W']].copy()
df_selected.head(6)

Unnamed: 0,patient_id,pid_nid_combo,img_path_curated,area_mm2,malignancy_label,img_shape_H,img_shape_W
0,LIDC-IDRI-0078,LIDC-IDRI-0078_0,./datasets/curation2/lidc_patches_all\LIDC-IDR...,156.325,1,44,54
1,LIDC-IDRI-0078,LIDC-IDRI-0078_0,./datasets/curation2/lidc_patches_all\LIDC-IDR...,184.21,1,44,54
2,LIDC-IDRI-0078,LIDC-IDRI-0078_0,./datasets/curation2/lidc_patches_all\LIDC-IDR...,191.3925,1,44,54
3,LIDC-IDRI-0078,LIDC-IDRI-0078_0,./datasets/curation2/lidc_patches_all\LIDC-IDR...,147.875,1,44,54
4,LIDC-IDRI-0078,LIDC-IDRI-0078_0,./datasets/curation2/lidc_patches_all\LIDC-IDR...,81.12,1,44,54
5,LIDC-IDRI-0078,LIDC-IDRI-0078_1,./datasets/curation2/lidc_patches_all\LIDC-IDR...,87.88,1,50,37


In [19]:
df_selected.pid_nid_combo.unique().shape

(678,)

In [20]:
print(f"Min image height: {df_selected['img_shape_H'].min()}, Min width: {df_selected['img_shape_W'].min()}, Total samples: {df_selected.shape[0]}, Number of patients: {df_selected['patient_id'].nunique()}")

pid_nid_pairs = df_selected.pid_nid_combo.unique()
print(f"Total nodules: {pid_nid_pairs.shape[0]}")

# Save this meta data table as csv file
df_selected.to_csv("curated_metadata.csv", index=False)

Min image height: 21, Min width: 22, Total samples: 2532, Number of patients: 440
Total nodules: 678


In [21]:
# Calculate average image height and width per nodule
nodule_image_shapes = df_selected.groupby('pid_nid_combo').agg({'img_shape_H':'mean', 'img_shape_W':'mean'}).reset_index()
nodule_image_shapes.rename(columns={'img_shape_H':'avg_img_shape_H', 'img_shape_W':'avg_img_shape_W'}, inplace=True)
# Print average image height and width with standard deviation
print(f"Average image height: {nodule_image_shapes['avg_img_shape_H'].mean()}, Average image width: {nodule_image_shapes['avg_img_shape_W'].mean()}")
print(f"Image height std: {nodule_image_shapes['avg_img_shape_H'].std()}, Image width std: {nodule_image_shapes['avg_img_shape_W'].std()}")

# Calculate average nodule area per nodule
nodule_area_stats = df_selected.groupby('pid_nid_combo').agg({'area_mm2':['mean']}).reset_index()
nodule_area_stats.columns = ['pid_nid_combo', 'avg_area_mm2']
# Print average nodule area with standard deviation
print(f"Average nodule area: {nodule_area_stats['avg_area_mm2'].mean()}, Nodule area std: {nodule_area_stats['avg_area_mm2'].std()}")

# Calculate malignancy label distribution per nodule
nodule_malignancy_stats = df_selected.groupby('pid_nid_combo').agg({'malignancy_label':'mean'}).reset_index()
nodule_malignancy_stats.rename(columns={'malignancy_label':'nodule_malignancy_label'}, inplace=True)
# Print nodule malignancy distribution statistics
print(f"Nodule malignancy distribution: {nodule_malignancy_stats['nodule_malignancy_label'].value_counts()}")

Average image height: 39.06047197640118, Average image width: 38.849557522123895
Image height std: 12.771628953601033, Image width std: 12.900062110808308
Average nodule area: 145.92181160085315, Nodule area std: 126.77192625900362
Nodule malignancy distribution: nodule_malignancy_label
1.0    397
0.0    281
Name: count, dtype: int64


### (2) CLIP Feature Construction and Storage

In [26]:
import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.transforms import functional as TF
import numpy as np
from tqdm import tqdm

In [None]:
# Create Dataset for each pid_nid_combo
class NoduleDataset(Dataset):
    """
    NoduleDataset
    - Input: DataFrame containing columns ['img_path_curated', 'malignancy_label']
    - Output:
        image: torch.FloatTensor, shape [C, H, W], value range [0, 1]
        label: torch.LongTensor, scalar (0/1)
    - Note:
        Regardless of whether transform is provided, images are first converted to [0,1] tensors;
        If transform is provided, it will receive a Tensor (not PIL.Image).
    """
    def __init__(self, df, transform=None):
        self.df = df.reset_index(drop=True)
        # If transform is provided, it should accept Tensor input (C,H,W in [0,1])
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = row['img_path_curated']
        # Read and convert to RGB
        image = Image.open(img_path).convert('RGB')
        # Always convert to [0,1] float32 Tensor first (shape [C,H,W])
        image = TF.to_tensor(image)

        # If additional transform exists, apply to Tensor
        if self.transform is not None:
            image = self.transform(image)

        label = int(row['malignancy_label'])
        label = torch.tensor(label, dtype=torch.long)

        return image, label

In [None]:
# Create torch Dataset and DataLoader split by pid_nid_combo
# Note: Since image sizes within each nodule may differ, to avoid batch stacking failure, default batch_size=1
# For larger batch_size, first apply uniform size transformation (e.g., transforms.Resize or custom pad/collate_fn)
split_loaders = []

# Optional: Define additional transforms (applied to Tensor), e.g., normalization; no additional transform here	extra_transform = None

for pid_nid, group_df in df_selected.groupby('pid_nid_combo'):
    dataset = NoduleDataset(group_df)
    loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        pin_memory=True if torch.cuda.is_available() else False,
    )
    split_loaders.append((pid_nid, loader))

# 2 Construct Concepts and Templates

Concepts have synonyms settings, I need to maintain and manage concepts and their corresponding synonyms.

In ConceptCLIP, all concepts and templates are combined and forward-propagated through the text encoder, then stored in the HDF5 database.

In [None]:
# CONCEPTS_slot_dict:
# I divided concepts into 10 major categories (based on domain knowledge), each category has various synonyms

# TEMPLATES_slot_dict:
# All templates are stored as prompt list in H5 database (attr → prompt_temp_for_concepts).

# ---------------------------------------------------------
# 1. Flattened Concept Dictionary
# ---------------------------------------------------------
CONCEPTS_dict = {
    # --- Margins/Shape ---
    
    # 1. Spiculation
    # Malignant sign: radial fine lines extending from nodule edge to surrounding lung parenchyma.
    "spiculation": [
        "spiculated",
        "spiculation",
        "radial spicules",
        "spiculated margins",
        "spikes extending from the surface"
    ],

    # 2. Lobulation
    # Malignant sign: wavy or scalloped margins due to uneven tumor growth rate.
    "lobulation": [
        "lobulated",
        "lobulation",
        "a scalloped contour",
        "lobulated margins",
        "wavy contour"
    ],

    # 3. Round/Sphericity
    # Benign tendency: commonly seen in hamartomas, tuberculomas, but also in metastases.
    "round_sphericity": [
        "round",
        "spherical",
        "high-sphericity",
        "a round shape",
        "highly spherical",
        "circular shape"
    ],

    # --- Density/Attenuation ---

    # 4. Pure Ground-Glass Nodule (Pure GGN)
    # Alveolar spaces filled with air but alveolar walls thickened, not obscuring vascular markings.
    "pure_GGN": [
        "pure ground-glass",
        "non-solid",
        "pure ground-glass appearance",
        "non-solid attenuation",
        "pure GGO"
    ],

    # 5. Part-solid/Mixed
    # Highest malignancy probability: contains both ground-glass and solid components.
    "part_solid": [
        "part-solid",
        "subsolid",
        "a solid focus within ground-glass",
        "mixed ground-glass density",
        "mixed attenuation"
    ],

    # 6. Solid Nodule
    # Completely obscures pulmonary vascular markings, high density.
    "solid": [
        "solid",
        "soft-tissue attenuation",
        "solid attenuation",
        "dense solid structure"
    ],

    # --- Internal Structure ---

    # 7. Benign Calcification
    # Usually presents as diffuse, central, laminar, or popcorn pattern.
    "benign_calc": [
        "diffuse calcification",
        "central calcification",
        "laminar calcification",
        "popcorn calcification",
        "benign calcification pattern"
    ],

    # 8. Eccentric/Punctate Calcification
    # Higher malignancy risk: eccentric distribution or scattered punctate.
    "eccentric_punctate_calc": [
        "eccentrically calcified",
        "eccentric calcification",
        "punctate calcification",
        "stippled calcification",
        "small scattered calcifications"
    ],

    # 9. Air Bronchogram/Cavitation
    # Air-filled bronchi or cavities visible within the nodule.
    "air_bronch_cav": [
        "cavitary",
        "an air bronchogram",
        "internal air lucency",
        "cavitation",
        "air-filled pockets"
    ],

    # --- Image Quality and Visibility ---

    # 10. Very Subtle
    # Low contrast, even hard for human eyes to detect, used to test model sensitivity to weak signals.
    "very_subtle": [
        "very subtle",
        "low-contrast",
        "faint-margin",
        "low contrast",
        "faint margins",
        "hard to see",
        "indistinct boundaries"
    ]
}

# ---------------------------------------------------------
# 2. Universal Templates
# ---------------------------------------------------------
# Design principle: compatible with both adjectives and noun phrases, or eliminate grammatical errors through ensemble
PROMPT_TEMPLATES = [
    # Format A: Simple concatenation (Context: {Description})
    "chest CT showing a pulmonary nodule: {}.",
    "axial chest CT of a pulmonary nodule, {}.",
    "lung window CT depicting a pulmonary nodule, {}.",
    
    # Format B: Descriptive (Context showing {Description})
    # This format accepts both nouns (spiculation) and adjectives (spiculated)
    "chest CT image of a pulmonary nodule showing {}.",
    "axial lung CT slice demonstrating {}.",
    "a pulmonary nodule characterized by {}.",
    
    # Format C: Feature emphasis (The nodule is/has {Description})
    "chest CT where the pulmonary nodule is {}.",
    "chest CT showing a pulmonary nodule with features of {}.",
    
    # Format D: Short medical description style
    "pulmonary nodule, {}.",
    "CT scan, lung nodule, {}."
]

# Deduplicate all concepts from CONCEPTS_slot_dict as a list, preserving order
concept_list = []

for concept_type, concepts in CONCEPTS_dict.items():
    for concept in concepts:
        if concept not in concept_list:
            concept_list.append(concept)

concept_prompt_template_list = PROMPT_TEMPLATES

In [32]:
label_list = ['benign', 'malignant']

label_prompt_template_list = [
    "a {} lung nodule in CT scan",
    "a lung nodule showing {} in CT scan",
    "a CT scan image of a {} lung nodule",
    "a CT slice with a {} lung nodule",
    "a Chest CT image showing a {} nodule",
    "{} nodule in Chest CT scan",
    "{} nodule in Chest CT"
]

DB_METADATA = {
    "label_texts": label_list,
    "prompt_temp_for_labels": label_prompt_template_list,
    "concept_texts": concept_list,
    "prompt_temp_for_concepts": concept_prompt_template_list,
}

# 3 Load CLIP Model

In [None]:
# Load CLIP model from HuggingFace
from transformers import AutoModel, AutoProcessor
from huggingface_hub import login, whoami
login(token="YOUR_TOKEN")  # NOTE: Replace with your own token; avoid hard-coding secrets in production.
print(whoami())

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: read).
Your token has been saved to C:\Users\chenk\.cache\huggingface\token
Login successful
{'type': 'user', 'id': '6693342048f8a55b7eba3d8e', 'name': 'takedachia', 'fullname': 'Jack Chan', 'email': 'chenkai1989129@gmail.com', 'emailVerified': True, 'canPay': False, 'periodEnd': None, 'isPro': False, 'avatarUrl': '/avatars/df59d16febde1a73fc98bf70f5476ba2.svg', 'orgs': [], 'auth': {'type': 'access_token', 'accessToken': {'displayName': 'ConceptCLIP', 'role': 'read', 'createdAt': '2025-08-20T17:55:27.543Z'}}}


In [34]:
# Load model and processor (first time will download to local cache)
model = AutoModel.from_pretrained('JerrryNie/ConceptCLIP', trust_remote_code=True)
processor = AutoProcessor.from_pretrained('JerrryNie/ConceptCLIP', trust_remote_code=True)

The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.


In [35]:
# Move model to GPU if available (else CPU)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
model = model.to(device).eval()

In [36]:
# Adjust processor image preprocessing settings
# BloodMNIST DataLoader already yields float tensors in [0,1], so we disable rescale and enable normalize
processor.image_processor.do_rescale = False
processor.image_processor.do_normalize = True

# 4 Use Text and Image Encoders to Infer, Extract, and Store Features

## (1) Deploy H5 database for storing outputs

In [38]:
import os, h5py, hashlib, json, time
from pathlib import Path
from typing import Optional, Sequence

OUT_PATH = "./conceptclip_features.h5"
D       = 1152
T_img   = 729
DTYPE_DYNAMIC = "float16"   # Dynamic parts (image_*) can be stored in lower precision to save space
DTYPE_TEXT    = "float32"   # Text features are recommended to remain in higher precision
DTYPE_STR     = h5py.string_dtype(encoding="utf-8")

def _serialize_for_attr(value):
    """Serialize a value for storing as an HDF5 file attribute."""
    return json.dumps(value, ensure_ascii=False)

def _read_json_attr(f: h5py.File, key: str, default=None):
    """Read a JSON-formatted attribute from the root of the H5 file and parse it."""
    if key not in f.attrs:
        return default
    raw = f.attrs[key]
    if isinstance(raw, bytes):
        raw = raw.decode("utf-8")
    if isinstance(raw, str):
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            return raw
    return raw

def _hash_texts(texts: Sequence[str]) -> str:
    """Compute a SHA256 hash for a list of texts."""
    h = hashlib.sha256()
    for t in texts:
        h.update(t.encode("utf-8")); h.update(b"\0")
    return h.hexdigest()

def init_file(path, d=D, t_img=T_img, metadata: Optional[dict] = None):
    """Initialize HDF5 file structure if missing; open in append mode."""
    if not os.path.exists(path):
        with h5py.File(path, "w") as f:
            # Create main datasets
            f.create_dataset("image_features",
                shape=(0, d), maxshape=(None, d),
                chunks=(64, d), dtype=DTYPE_DYNAMIC, compression="lzf")
            f.create_dataset("image_token_features",
                shape=(0, t_img, d), maxshape=(None, t_img, d),
                chunks=(4, t_img, d), dtype=DTYPE_DYNAMIC, compression="lzf")
            f.create_dataset("ids",
                shape=(0,), maxshape=(None,),
                chunks=(4096,), dtype="int64", compression="lzf")
            f.create_dataset("labels",
                shape=(0,), maxshape=(None,),
                chunks=(4096,), dtype="int64", compression="lzf")
            f.create_dataset("split",
                shape=(0,), maxshape=(None,),
                chunks=(4096,), dtype=DTYPE_STR, compression="lzf")
            
            # Create prompt-templates group
            f.create_group("templates")
            
            # Add global attributes (model parameters)
            f.attrs["D"] = d
            f.attrs["T_img"] = t_img
            f.attrs["created_at"] = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
            f.attrs["version"] = "1.0"
            
    f = h5py.File(path, "a")
    current_len = f["image_features"].shape[0] if "image_features" in f else 0
    if "image_token_features" not in f:
        f.create_dataset("image_token_features",
            shape=(current_len, t_img, d), maxshape=(None, t_img, d),
            chunks=(4, t_img, d), dtype=DTYPE_DYNAMIC, compression="lzf")
    if "ids" not in f:
        f.create_dataset("ids",
            shape=(current_len,), maxshape=(None,),
            chunks=(4096,), dtype="int64", compression="lzf")
    if "labels" not in f:
        ds_labels = f.create_dataset("labels",
            shape=(current_len,), maxshape=(None,),
            chunks=(4096,), dtype="int64", compression="lzf")
        if current_len > 0:
            ds_labels[:] = np.full((current_len,), -1, dtype="int64")
    if "split" not in f:
        ds_split = f.create_dataset("split",
            shape=(current_len,), maxshape=(None,),
            chunks=(4096,), dtype=DTYPE_STR, compression="lzf")
        if current_len > 0:
            ds_split[:] = np.array(["unspecified"] * current_len, dtype=object)
    if "templates" not in f:
        f.create_group("templates")
    if metadata:
        for key, value in metadata.items():
            try:
                f.attrs[key] = _serialize_for_attr(value)
            except TypeError:
                f.attrs[key] = _serialize_for_attr(str(value))
    return f

def _to_np(x, dtype):
    """Convert tensor to numpy array with specified dtype."""
    if torch.is_tensor(x):
        x = x.detach().cpu().numpy()
    return x.astype(dtype, copy=False)

def append_batch(f: h5py.File,
                 image_feats: torch.Tensor,            # [B, D]
                 image_token_feats: torch.Tensor,      # [B, T_img, D]
                 ids: np.ndarray,                     # [B] int64
                 split_names: Optional[Sequence[str]] = None,
                 labels: Optional[np.ndarray] = None):
    """Append a batch of data to the HDF5 file."""
    ds_img = f["image_features"]
    ds_tok = f["image_token_features"]
    ds_ids = f["ids"]
    ds_labels = f["labels"]
    ds_split = f["split"]
    B = image_feats.shape[0]

    n0 = ds_img.shape[0]
    ds_img.resize(n0 + B, axis=0)
    ds_tok.resize(n0 + B, axis=0)
    ds_ids.resize(n0 + B, axis=0)
    ds_labels.resize(n0 + B, axis=0)
    ds_split.resize(n0 + B, axis=0)

    ds_img[n0:n0+B, :]      = _to_np(image_feats, DTYPE_DYNAMIC)
    ds_tok[n0:n0+B, :, :]   = _to_np(image_token_feats, DTYPE_DYNAMIC)
    ds_ids[n0:n0+B]         = ids.astype("int64", copy=False)

    if labels is None:
        labels_arr = np.full((B,), -1, dtype="int64")
    else:
        labels_arr = np.asarray(labels, dtype="int64").reshape(-1)
        if labels_arr.shape[0] != B:
            raise ValueError(f"Label count {labels_arr.shape[0]} does not match batch size {B}.")
    ds_labels[n0:n0+B] = labels_arr

    if split_names is None:
        split_arr = np.array(["unspecified"] * B, dtype=object)
    else:
        split_arr = np.asarray(list(split_names), dtype=object).reshape(-1)
        if split_arr.shape[0] != B:
            raise ValueError(f"Split count {split_arr.shape[0]} does not match batch size {B}.")
    ds_split[n0:n0+B] = split_arr

    f.flush()

def write_template(f: h5py.File,
                   template_id: str,
                   texts: Sequence[str],
                   text_features: torch.Tensor,            # [K, D]
                   text_token_features: Optional[torch.Tensor]=None # [K, T_txt, D]
                   ):
    """Write text templates and their features."""
    g_root = f["templates"]
    if template_id in g_root:
        del g_root[template_id]             # Overwrite existing (change logic if verification needed)
    g = g_root.create_group(template_id)

    # Main data
    tf  = _to_np(text_features, DTYPE_TEXT)         # [K, D]
    g.create_dataset("text_features", data=tf, compression="lzf")

    if text_token_features is not None:
        ttf = _to_np(text_token_features, DTYPE_TEXT)     # [K, T_txt, D]
        g.create_dataset("text_token_features", data=ttf, compression="lzf")
        T_txt = ttf.shape[1]
    else:
        T_txt = -1

    # Save original prompts (variable-length UTF-8 strings)
    dt_str = h5py.string_dtype(encoding='utf-8')
    ds_txt = g.create_dataset("texts", shape=(len(texts),), dtype=dt_str, compression="lzf")
    ds_txt[:] = np.array(list(texts), dtype=object)

    # Metadata
    g.attrs["K"]          = tf.shape[0]
    g.attrs["D"]          = tf.shape[1]
    g.attrs["T_txt"]      = int(T_txt)
    g.attrs["texts_hash"] = _hash_texts(texts)
    g.attrs["created_at"] = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    f.flush()

def save_model_params(f: h5py.File, 
                      logit_scale: float = None,
                      logit_bias: float = None,
                      concept_logit_scale: float = None,
                      concept_logit_bias: float = None):
    """Persist key model parameters to root attributes of the HDF5 file."""
    if logit_scale is not None:
        f.attrs["logit_scale"] = float(logit_scale)
    if logit_bias is not None:
        f.attrs["logit_bias"] = float(logit_bias)
    if concept_logit_scale is not None:
        f.attrs["concept_logit_scale"] = float(concept_logit_scale)
    if concept_logit_bias is not None:
        f.attrs["concept_logit_bias"] = float(concept_logit_bias)
    f.flush()

def load_model_params(f: h5py.File) -> dict:
    """Load model parameters from file attributes."""
    params = {}
    for key in ["logit_scale", "logit_bias", "concept_logit_scale", "concept_logit_bias"]:
        if key in f.attrs:
            params[key] = f.attrs[key]
    return params

## (2) Compute and save concept text features to H5 (using the Text Encoder)

In [39]:
def _maybe_to_scalar(x):
    if x is None:
        return None
    if isinstance(x, (int, float)):
        return float(x)
    if torch.is_tensor(x):
        x = x.detach()
        if x.numel() == 1:
            return float(x.item())
        try:
            return float(x.squeeze().item())
        except Exception:
            return None
    if hasattr(x, "item"):
        try:
            return float(x.item())
        except Exception:
            return None
    return None

stored_templates = []
with init_file(OUT_PATH, metadata=DB_METADATA) as f:
    concept_texts_in_file = _read_json_attr(f, "concept_texts", concept_list)
    concept_template_list = _read_json_attr(f, "prompt_temp_for_concepts", concept_prompt_template_list)
    label_texts_in_file = _read_json_attr(f, "label_texts", label_list)
    label_template_list = _read_json_attr(f, "prompt_temp_for_labels", label_prompt_template_list)

    template_texts_map = {}
    if isinstance(concept_template_list, (list, tuple)) and concept_texts_in_file:
        for idx, template in enumerate(concept_template_list, start=1):
            formatted_texts = [template.format(text) for text in concept_texts_in_file]
            template_key = f"concept_prompts_t{idx:02d}"
            template_texts_map[template_key] = formatted_texts
    if isinstance(label_template_list, (list, tuple)) and label_texts_in_file:
        for idx, template in enumerate(label_template_list, start=1):
            formatted_label_texts = [template.format(text) for text in label_texts_in_file]
            template_key = f"label_prompts_t{idx:02d}"
            template_texts_map[template_key] = formatted_label_texts

    with torch.no_grad():
        for template_id, template_texts in template_texts_map.items():
            text_inputs = processor(text=template_texts, return_tensors="pt", padding=True, truncation=True).to(device)
            text_cls, text_tokens = model.encode_text(text_inputs["input_ids"], normalize=True)  # [K, D], [K, T_txt, D]
            text_tokens_proj = model.text_proj(text_tokens) if hasattr(model, "text_proj") else text_tokens
            write_template(f, template_id=template_id, texts=template_texts, text_features=text_cls, text_token_features=text_tokens_proj)
            stored_templates.append((template_id, text_cls.shape[0]))
    params_kwargs = {
        "logit_scale": _maybe_to_scalar(getattr(model, "logit_scale", None)),
        "logit_bias": _maybe_to_scalar(getattr(model, "logit_bias", None)),
        "concept_logit_scale": _maybe_to_scalar(getattr(model, "concept_logit_scale", None)),
        "concept_logit_bias": _maybe_to_scalar(getattr(model, "concept_logit_bias", None)),
    }
    save_model_params(f, **{k: v for k, v in params_kwargs.items() if v is not None})

print("Stored templates:")
for template_id, count in stored_templates:
    print(f"  - {template_id}: {count} prompts")

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Stored templates:
  - concept_prompts_t01: 52 prompts
  - concept_prompts_t02: 52 prompts
  - concept_prompts_t03: 52 prompts
  - concept_prompts_t04: 52 prompts
  - concept_prompts_t05: 52 prompts
  - concept_prompts_t06: 52 prompts
  - concept_prompts_t07: 52 prompts
  - concept_prompts_t08: 52 prompts
  - concept_prompts_t09: 52 prompts
  - concept_prompts_t10: 52 prompts
  - label_prompts_t01: 2 prompts
  - label_prompts_t02: 2 prompts
  - label_prompts_t03: 2 prompts
  - label_prompts_t04: 2 prompts
  - label_prompts_t05: 2 prompts
  - label_prompts_t06: 2 prompts
  - label_prompts_t07: 2 prompts


## (3) Compute and save image (token) features to H5 (using the Image Encoder)
Token refers to each 27×27 patch.

In [40]:
split_loaders[:5]

[('LIDC-IDRI-0001_0',
  <torch.utils.data.dataloader.DataLoader at 0x1a4d02e25c0>),
 ('LIDC-IDRI-0002_0',
  <torch.utils.data.dataloader.DataLoader at 0x1a4d02e3790>),
 ('LIDC-IDRI-0003_0',
  <torch.utils.data.dataloader.DataLoader at 0x1a4d02e12a0>),
 ('LIDC-IDRI-0003_1',
  <torch.utils.data.dataloader.DataLoader at 0x1a4d02e3640>),
 ('LIDC-IDRI-0003_2',
  <torch.utils.data.dataloader.DataLoader at 0x1a4d02e3dc0>)]

In [41]:
# Use Image Encoder to compute and save image (token) features to H5

# Define data splits and their loaders
# split_loaders = [
#     ("train", train_loader),
#     ("val", valid_loader),
#     ("test", test_loader),
# ]

def encode_and_store_image_features(split_loaders):
    """Traverse each data split and write image features to HDF5."""
    written_counts = {}
    with torch.no_grad():
        with init_file(OUT_PATH, metadata=DB_METADATA) as f:
            existing_counts = {}
            if "split_counts" in f.attrs:
                try:
                    existing_counts = json.loads(f.attrs["split_counts"])
                except Exception:
                    existing_counts = {}
            start_idx = f["image_features"].shape[0]
            for split_name, loader in split_loaders:
                expected = len(loader.dataset)
                recorded = int(existing_counts.get(split_name, 0) or 0)
                if recorded >= expected:
                    print(f"Skip {split_name}: already stored {recorded}/{expected} samples.")
                    continue

                processed_in_split = 0
                for imgs, labels in tqdm(loader, desc=f"{split_name} split"):
                    batch_total = imgs.shape[0]
                    if processed_in_split + batch_total <= recorded:
                        processed_in_split += batch_total
                        continue
                    if processed_in_split < recorded:
                        offset = recorded - processed_in_split
                        imgs = imgs[offset:]
                        labels = labels[offset:]
                        processed_in_split = recorded
                        batch_total = imgs.shape[0]

                    pixel_inputs = processor(images=imgs, return_tensors="pt", padding=True, truncation=True)
                    pixel_inputs = pixel_inputs["pixel_values"].to(device)
                    img_cls, img_tokens = model.encode_image(pixel_inputs, normalize=True)   # [B,1152], [B,729,1152]
                    img_tokens_proj = model.image_proj(img_tokens) if hasattr(model, "image_proj") else img_tokens  #  [B,729,1152]
                    batch_size = img_cls.shape[0]
                    batch_ids = np.arange(start_idx, start_idx + batch_size, dtype=np.int64)
                    labels_tensor = labels.view(-1) if hasattr(labels, "view") else labels
                    labels_np = np.asarray(labels_tensor, dtype=np.int64).reshape(-1)
                    append_batch(
                        f,
                        image_feats=img_cls,
                        image_token_feats=img_tokens_proj,
                        ids=batch_ids,
                        split_names=[split_name] * batch_size,
                        labels=labels_np,
                    )
                    start_idx += batch_size
                    processed_in_split += batch_size
                    written_counts[split_name] = written_counts.get(split_name, 0) + batch_size

            # Update split counts in file attributes
            for split_name, count in written_counts.items():
                existing_counts[split_name] = int(existing_counts.get(split_name, 0) or 0) + count
            f.attrs["split_counts"] = json.dumps(existing_counts)
            f.flush()

    total_new = sum(written_counts.values())
    if total_new == 0:
        print("No new samples were appended.")
    else:
        print(f"Appended {total_new} samples to {OUT_PATH}.")
        for split_name, count in written_counts.items():
            print(f"  - {split_name}: {count}")

encode_and_store_image_features(split_loaders)

Skip LIDC-IDRI-0001_0: already stored 5/5 samples.
Skip LIDC-IDRI-0002_0: already stored 5/5 samples.
Skip LIDC-IDRI-0003_0: already stored 5/5 samples.
Skip LIDC-IDRI-0003_1: already stored 5/5 samples.
Skip LIDC-IDRI-0003_2: already stored 2/2 samples.
Skip LIDC-IDRI-0003_3: already stored 4/4 samples.
Skip LIDC-IDRI-0006_0: already stored 1/1 samples.
Skip LIDC-IDRI-0007_0: already stored 5/5 samples.
Skip LIDC-IDRI-0007_1: already stored 4/4 samples.
Skip LIDC-IDRI-0011_0: already stored 3/3 samples.
Skip LIDC-IDRI-0012_0: already stored 4/4 samples.
Skip LIDC-IDRI-0013_0: already stored 1/1 samples.
Skip LIDC-IDRI-0013_1: already stored 5/5 samples.
Skip LIDC-IDRI-0014_0: already stored 5/5 samples.
Skip LIDC-IDRI-0015_0: already stored 5/5 samples.
Skip LIDC-IDRI-0016_0: already stored 4/4 samples.
Skip LIDC-IDRI-0016_1: already stored 2/2 samples.
Skip LIDC-IDRI-0016_2: already stored 5/5 samples.
Skip LIDC-IDRI-0018_0: already stored 3/3 samples.
Skip LIDC-IDRI-0018_1: already 

## (4) Validate H5 contents

In [42]:
def validate_h5(path: str = OUT_PATH, max_templates: int = 2):
    if not os.path.exists(path):
        print(f"File '{path}' does not exist yet.")
        return
    with h5py.File(path, "r") as f:
        num_samples = f["image_features"].shape[0]
        print(f"Total samples stored: {num_samples}")
        print(f"image_features dataset shape: {f['image_features'].shape}")
        print(f"image_token_features dataset shape: {f['image_token_features'].shape}")
        ids_preview = f["ids"][:5]
        labels_preview = f["labels"][:5]
        split_preview = f["split"][:5]
        print(f"Preview ids: {ids_preview}")
        print(f"Preview labels: {labels_preview}")
        print(f"Preview splits: {split_preview}")
        try:
            split_counts = json.loads(f.attrs.get("split_counts", "{}"))
        except Exception:
            split_counts = {}
        print(f"Recorded split counts: {split_counts}")
        metadata_keys = [
            "label_texts",
            "prompt_temp_for_labels",
            "concept_texts",
            "prompt_temp_for_concepts",
        ]
        for key in metadata_keys:
            value = _read_json_attr(f, key, None)
            if value is None:
                continue
            if isinstance(value, list):
                preview = value[:3] + (["..."] if len(value) > 3 else [])
                print(f"{key}: {preview}")
            else:
                print(f"{key}: {value}")
        templates = list(f["templates"].keys())
        print(f"Templates stored: {templates}")
        for template_id in templates[:max_templates]:
            g = f["templates"][template_id]
            texts_sample = g["texts"][:3] if "texts" in g else []
            print(f"  • Template '{template_id}' -> K={g.attrs.get('K')}, D={g.attrs.get('D')}, sample texts={texts_sample}")
        params = load_model_params(f)
        print(f"Stored model params: {params}")

validate_h5()

Total samples stored: 2532
image_features dataset shape: (2532, 1152)
image_token_features dataset shape: (2532, 729, 1152)
Preview ids: [0 1 2 3 4]
Preview labels: [1 1 1 1 1]
Preview splits: [b'LIDC-IDRI-0001_0' b'LIDC-IDRI-0001_0' b'LIDC-IDRI-0001_0'
 b'LIDC-IDRI-0001_0' b'LIDC-IDRI-0001_0']
Recorded split counts: {'LIDC-IDRI-0001_0': 5, 'LIDC-IDRI-0002_0': 5, 'LIDC-IDRI-0003_0': 5, 'LIDC-IDRI-0003_1': 5, 'LIDC-IDRI-0003_2': 2, 'LIDC-IDRI-0003_3': 4, 'LIDC-IDRI-0006_0': 1, 'LIDC-IDRI-0007_0': 5, 'LIDC-IDRI-0007_1': 4, 'LIDC-IDRI-0011_0': 3, 'LIDC-IDRI-0012_0': 4, 'LIDC-IDRI-0013_0': 1, 'LIDC-IDRI-0013_1': 5, 'LIDC-IDRI-0014_0': 5, 'LIDC-IDRI-0015_0': 5, 'LIDC-IDRI-0016_0': 4, 'LIDC-IDRI-0016_1': 2, 'LIDC-IDRI-0016_2': 5, 'LIDC-IDRI-0018_0': 3, 'LIDC-IDRI-0018_1': 5, 'LIDC-IDRI-0019_0': 5, 'LIDC-IDRI-0020_0': 5, 'LIDC-IDRI-0022_0': 5, 'LIDC-IDRI-0023_0': 5, 'LIDC-IDRI-0024_0': 4, 'LIDC-IDRI-0024_1': 2, 'LIDC-IDRI-0027_0': 1, 'LIDC-IDRI-0027_1': 1, 'LIDC-IDRI-0029_0': 5, 'LIDC-IDRI-00