In [1]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Get CT NIfTI image embeddings using MedCLIP.
"""

from __future__ import annotations

import argparse
import json
import logging
import os
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Iterable, Iterator, List, Optional, Sequence, Tuple

import numpy as np
import pandas as pd
import nibabel as nib
import torch
from PIL import Image
from tqdm import tqdm
from torchvision import transforms
from torchvision.transforms import functional as TF

try:
    from medclip import (
        MedCLIPModel,
        MedCLIPVisionModelViT,
        constants as medclip_constants,
    )
except ImportError:
    print("please install medclip: pip install medclip")
    raise

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

"""
default paths and settings
"""
# Defaults for the command-line options of parse_args. Inside a Jupyter kernel
# parse_args ignores sys.argv, so a notebook run always uses exactly these.
# NOTE(review): absolute, machine-specific Windows paths — adjust before
# running on another machine.

# Root folder(s) to scan; each is walked as <center>/<patient>/<nifti file>.
DEFAULT_INPUT_DIRS = [r"D:\research\cd-ml\(process)nzhf\提取的单个nii文件"]
# Folder that receives the .npz / .csv / .json outputs.
DEFAULT_OUTPUT_DIR = r"D:\research\cd-ml\(process)nzhf\medclip_embeddings"
# Number of slices per MedCLIP forward pass.
DEFAULT_BATCH_SIZE = 64


In [3]:
"""
data class
"""
@dataclass()
class ImageItem:
    """A single 2-D slice taken from a NIfTI file, together with its provenance."""

    # "<nifti path>::slice_<4-digit index>", as built by nifti_to_images.
    identifier: str
    # The slice as a PIL image (nifti_to_images produces RGB images).
    image: Image.Image
    # Name of the patient folder the source file was found in.
    patient_name: str
    # Name of the center folder that contains the patient folder.
    center: str

In [4]:
"""
read and normalize nifti file
"""
def nifti_to_images(
    nifti_path: Path,
    center_name: str,
    patient_name: str,
) -> List[ImageItem]:
    """Convert every 2-D slice of a NIfTI file into an ``ImageItem``.

    Each slice is min-max normalized on its own to ``[0, 255]``, turned into
    an 8-bit grayscale image and then converted to RGB.

    Args:
        nifti_path: Path of the ``.nii`` / ``.nii.gz`` file.
        center_name: Center label copied onto every returned item.
        patient_name: Patient label copied onto every returned item.

    Returns:
        One item per slice along the last axis, in slice order. A 2-D file
        yields a single item; for data with more than three axes, every axis
        after the first two is flattened into the slice axis.

    Raises:
        ValueError: if the image data has fewer than two axes.
    """
    volume = nib.load(str(nifti_path))
    # Ask nibabel for float32 directly rather than materializing a float64
    # array and casting it afterwards (halves peak memory on large volumes).
    data = volume.get_fdata(dtype=np.float32)

    if data.ndim < 2:
        raise ValueError(
            f"expected at least 2-D image data in {nifti_path}, got shape {data.shape}"
        )
    if data.ndim == 2:
        data = data[..., np.newaxis]
    elif data.ndim > 3:
        # Keep the first two (in-plane) axes and flatten the rest, e.g.
        # (x, y, z, t) -> (x, y, z * t), so every frame taken below is 2-D.
        # Moving axis 0 to the end would instead yield 3-D frames such as
        # (y, z, 1), which Image.fromarray cannot convert.
        data = data.reshape(data.shape[0], data.shape[1], -1)

    # NOTE(review): frames are indexed (axis0, axis1) and Image.fromarray
    # treats axis 0 as rows, so slices may appear transposed compared with a
    # viewer. Confirm before changing — a transpose alters every embedding.
    images: List[ImageItem] = []
    for idx in range(data.shape[-1]):
        # astype copies, so the in-place normalization below never writes into `data`.
        frame = data[..., idx].astype(np.float32)

        finite = np.isfinite(frame)
        if not finite.all():
            # NaN/inf would propagate through min()/max() and corrupt the
            # whole slice; fill them with the smallest finite value instead.
            frame[~finite] = frame[finite].min() if finite.any() else 0.0

        frame -= frame.min()
        max_val = frame.max()
        if max_val > 0:  # a constant slice stays all-zero instead of dividing by 0
            frame /= max_val
        frame_uint8 = (frame * 255).clip(0, 255).astype(np.uint8)

        images.append(
            ImageItem(
                identifier=f"{nifti_path}::slice_{idx:04d}",
                image=Image.fromarray(frame_uint8).convert("RGB"),
                patient_name=patient_name,
                center=center_name,
            )
        )

    return images

In [5]:
"""file iter"""
NIFTI_EXTENSIONS = (".nii", ".nii.gz")


def iter_nifti_files(input_dirs: Sequence[Path]) -> Iterator[Tuple[Path, str, str]]:
    """Yield ``(file_path, center_name, patient_name)`` for every CT NIfTI file.

    The expected layout is ``<input dir>/<center>/<patient>/<file>``. A file
    is selected when its lower-cased name ends with one of
    ``NIFTI_EXTENSIONS`` and contains the substring ``"ct"``. Non-directories
    at the center/patient levels and non-files at the file level are skipped.

    Every level is visited in sorted name order, because ``Path.iterdir``
    gives no ordering guarantee. This keeps the row order of the saved
    embeddings reproducible across runs and platforms.

    Args:
        input_dirs: Root folders to scan (``str`` or ``Path``). Paths that are
            missing or are not directories are skipped with a warning.
    """

    def _sorted_children(folder):
        # Deterministic traversal order, independent of the filesystem.
        return sorted(folder.iterdir(), key=lambda entry: entry.name)

    for directory in input_dirs:
        dir_path = Path(directory)
        if not dir_path.is_dir():
            # A regular file would make iterdir() raise; a missing path has
            # nothing to scan. Either way, report it instead of failing silently.
            logging.warning("input dir missing or not a directory, skipped: %s", dir_path)
            continue

        for center_dir in _sorted_children(dir_path):
            if not center_dir.is_dir():
                continue
            center_name = center_dir.name

            for patient_dir in _sorted_children(center_dir):
                if not patient_dir.is_dir():
                    continue
                patient_name = patient_dir.name

                for file_path in _sorted_children(patient_dir):
                    if not file_path.is_file():
                        continue
                    lower = file_path.name.lower()
                    # NOTE(review): "ct" is a plain substring test, so names such
                    # as "correct.nii" also match — confirm against the naming scheme.
                    if lower.endswith(NIFTI_EXTENSIONS) and "ct" in lower:
                        yield file_path, center_name, patient_name


In [6]:
"""medclip model"""
def _resolve_pretrained_dir(model: MedCLIPModel, input_dir: Optional[str] = None) -> Tuple[Path, str]:
    """Pick the local weight folder and the pretrained-URL endpoint for ``model``.

    ViT weights are chosen when ``model.vision_model`` is a
    ``MedCLIPVisionModelViT``; otherwise the ResNet weights are chosen.

    Args:
        model: MedCLIP model whose vision backbone selects the weights.
        input_dir: Optional folder that overrides the default
            ``./pretrained/medclip-*`` location.

    Returns:
        ``(weight_dir, pretrained_url)``. ``weight_dir`` is not resolved or
        created here. ``pretrained_url`` is the endpoint that
        ``load_medclip_pretrained`` queries for the real download link.
    """
    uses_vit = isinstance(model.vision_model, MedCLIPVisionModelViT)
    pretrained_url = (
        medclip_constants.PRETRAINED_URL_MEDCLIP_VIT
        if uses_vit
        else medclip_constants.PRETRAINED_URL_MEDCLIP_RESNET
    )
    if input_dir is not None:
        return Path(input_dir), pretrained_url
    if uses_vit:
        return Path("./pretrained/medclip-vit"), pretrained_url
    return Path("./pretrained/medclip-resnet"), pretrained_url

def load_medclip_pretrained(model: MedCLIPModel, input_dir: Optional[str] = None) -> None:
    """Load MedCLIP pretrained weights into ``model``, downloading them if needed.

    If ``<weight_dir>/WEIGHTS_NAME`` is absent, the pretrained URL from
    ``_resolve_pretrained_dir`` is fetched. Its response text is used as the
    actual download link, and the downloaded zip is extracted into
    ``weight_dir``. Loading is non-strict: missing or unexpected keys only
    produce a warning.

    Args:
        model: Model that receives the weights (modified in place).
        input_dir: Optional folder overriding the default weight location.

    Raises:
        requests.HTTPError: if fetching the download link fails.
        FileNotFoundError: if the weight file is still absent after extraction.
    """
    import inspect
    from zipfile import ZipFile

    import requests
    import wget

    weight_dir, pretrained_url = _resolve_pretrained_dir(model, input_dir)
    weight_dir = weight_dir.resolve()
    weight_dir.mkdir(parents=True, exist_ok=True)

    weight_file = weight_dir / medclip_constants.WEIGHTS_NAME
    if not weight_file.exists():
        response = requests.get(pretrained_url, timeout=30)
        response.raise_for_status()
        # The endpoint returns the real download link as plain text.
        download_url = response.text.strip()
        zip_path = wget.download(download_url, str(weight_dir))
        with ZipFile(zip_path) as zipf:
            zipf.extractall(weight_dir)
        print(f"\n download pretrained model: {download_url}")
        if not weight_file.exists():
            # Fail here with a clear message instead of an opaque torch.load error.
            raise FileNotFoundError(
                f"{medclip_constants.WEIGHTS_NAME} not found in {weight_dir} after extracting {zip_path}"
            )

    # torch.load unpickles the file. When the installed torch supports it,
    # restrict unpickling to tensors/containers so a tampered download cannot
    # execute code.
    load_kwargs = {"map_location": "cpu"}
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = True
    state_dict = torch.load(weight_file, **load_kwargs)

    # This buffer key is dropped before loading; presumably it is
    # version-dependent in the text encoder — TODO confirm.
    state_dict.pop("text_model.model.embeddings.position_ids", None)
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing or unexpected:
        logging.warning("ignore missing=%s unexpected=%s", missing, unexpected)
    print(f"load model weight from: {weight_dir}")

def build_model(device: torch.device) -> Tuple[MedCLIPModel, Callable[[Image.Image], torch.Tensor]]:
    """Create a pretrained MedCLIP-ViT model and its matching preprocessor.

    Args:
        device: Device the model is moved to.

    Returns:
        ``(model, preprocess)``. ``model`` is in eval mode on ``device``.
        ``preprocess`` maps a PIL image to a normalized single-channel
        tensor of size ``IMG_SIZE`` x ``IMG_SIZE``.
    """
    model = MedCLIPModel(vision_cls=MedCLIPVisionModelViT)
    load_medclip_pretrained(model)
    model = model.to(device).eval()

    normalize = transforms.Normalize(
        mean=[medclip_constants.IMG_MEAN],
        std=[medclip_constants.IMG_STD],
    )
    side = medclip_constants.IMG_SIZE

    def preprocess(image: Image.Image) -> torch.Tensor:
        """Grayscale -> zero-pad to a centered square -> resize -> tensor -> normalize."""
        gray = image if image.mode == "L" else image.convert("L")
        w, h = gray.size
        # Pad (never crop) to a square at least `side` wide, keeping the slice centered.
        square = max(side, w, h)
        padded = Image.new("L", (square, square), color=0)
        padded.paste(gray, ((square - w) // 2, (square - h) // 2))
        resized = padded.resize((side, side), Image.BICUBIC)
        return normalize(TF.to_tensor(resized))

    return model, preprocess

def embed_images(
    model: MedCLIPModel,
    preprocess: Callable[[Image.Image], torch.Tensor],
    device: torch.device,
    items: Sequence[ImageItem],
) -> np.ndarray:
    """Embed a batch of slices with ``model.encode_image``.

    Args:
        model: Object exposing ``encode_image(pixel_values=...)``.
        preprocess: Maps one PIL image to a tensor; results are stacked into a batch.
        device: Device the stacked batch is moved to.
        items: Non-empty batch (``torch.stack`` rejects an empty list).

    Returns:
        ``float32`` array with one row per item, in input order.
        NOTE(review): rows are expected to be unit vectors because
        ``encode_image`` is assumed to L2-normalize; this function does not
        normalize itself — confirm against the MedCLIP version in use.
    """
    with torch.no_grad():
        stacked = torch.stack([preprocess(entry.image) for entry in items])
        features = model.encode_image(pixel_values=stacked.to(device))
    host_array = features.cpu().numpy()
    return host_array.astype(np.float32)


In [None]:
 """main process"""
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="use medclip to extract ct nifti image embedding")
    parser.add_argument(
        "--input-dirs",
        nargs="+",
        default=DEFAULT_INPUT_DIRS,
        help="need to traverse folder (can be multiple)",
    )
    parser.add_argument(
        "--output-dir",
        default=DEFAULT_OUTPUT_DIR,
        help="output folder, save embeddings and index",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=DEFAULT_BATCH_SIZE,
        help="the number of images to send to medclip",
    )
    parser.add_argument(
        "--device",
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="use device: cuda / cpu",
    )
    parser.add_argument(
        "--progress",
        action="store_true",
        help="show tqdm progress bar",
    )
    if argv is None and "ipykernel" in sys.modules:
        argv = []
    return parser.parse_args(argv)

def main() -> None:
    """Discover CT NIfTI files, embed every slice with MedCLIP and save the results.

    Files that fail to load are recorded and skipped rather than aborting the
    run. Files written to ``--output-dir``:

    * ``medclip_embeddings_ct_vit.npz``: ``embeddings`` plus row-aligned
      ``identifiers`` / ``patient_names`` / ``centers`` string arrays
      (loadable without ``allow_pickle``).
    * ``medclip_embeddings_ct_vit.csv``: ``patient_name``, ``center`` and one
      ``embedding_XXX`` column per dimension.
    * ``medclip_embeddings_ct_mapping.csv``: identifier -> patient / center.
    * ``medclip_embeddings_ct_meta.json``: run configuration and counts.
    * ``medclip_embedding_ct_errors.log``: only written when some files failed.
    """
    args = parse_args()
    input_dirs = [Path(p).resolve() for p in args.input_dirs]
    output_dir = Path(args.output_dir).resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
    )
    logging.info("input dirs: %s", input_dirs)
    logging.info("output dir: %s", output_dir)

    device = torch.device(args.device)
    model, preprocess = build_model(device)
    logging.info("loaded medclip model, use device: %s", device)

    batch_size = max(1, args.batch_size)
    nifti_entries = list(iter_nifti_files(input_dirs))
    if not nifti_entries:
        logging.warning("no ct nifti file found in specified dirs")
        return

    logging.info("found %d ct nifti files", len(nifti_entries))

    # Row-aligned accumulators: entry i of each list describes the same slice.
    embedding_blocks: List[np.ndarray] = []
    identifiers: List[str] = []
    patient_names: List[str] = []
    centers: List[str] = []
    errors: List[Tuple[str, str]] = []

    def flush(batch: Sequence[ImageItem]) -> None:
        """Embed one batch and append its rows to every accumulator."""
        embedding_blocks.append(embed_images(model, preprocess, device, batch))
        identifiers.extend(item.identifier for item in batch)
        patient_names.extend(item.patient_name for item in batch)
        centers.extend(item.center for item in batch)

    iterator: Iterable[Tuple[Path, str, str]] = (
        tqdm(nifti_entries, desc="read NIfTI", unit="file") if args.progress else nifti_entries
    )

    # Slices from consecutive files are pooled, so every forward pass except
    # possibly the last one carries exactly batch_size images.
    buffer: List[ImageItem] = []
    for nifti_path, center_name, patient_name in iterator:
        try:
            buffer.extend(nifti_to_images(nifti_path, center_name, patient_name))
        except Exception as exc:
            # Best effort: an unreadable file is recorded and skipped, not fatal.
            errors.append((str(nifti_path), repr(exc)))
            continue

        while len(buffer) >= batch_size:
            flush(buffer[:batch_size])
            del buffer[:batch_size]

    if buffer:
        flush(buffer)

    # With no embedded slice there is nothing to concatenate; the fallback
    # width 512 is presumably MedCLIP's embedding size — TODO confirm.
    all_embeddings = (
        np.concatenate(embedding_blocks, axis=0)
        if embedding_blocks
        else np.zeros((0, 512), dtype=np.float32)
    )

    npz_path = output_dir / "medclip_embeddings_ct_vit.npz"
    # Plain unicode arrays (not dtype=object), so np.load works with the
    # default allow_pickle=False.
    np.savez_compressed(
        npz_path,
        embeddings=all_embeddings,
        identifiers=np.array(identifiers, dtype=str),
        patient_names=np.array(patient_names, dtype=str),
        centers=np.array(centers, dtype=str),
    )
    logging.info("save embeddings to %s, shape: %s", npz_path, all_embeddings.shape)

    embedding_cols = [f"embedding_{i:03d}" for i in range(all_embeddings.shape[1])]
    dataframe = pd.DataFrame(all_embeddings, columns=embedding_cols, dtype=np.float32)
    dataframe.insert(0, "center", centers)
    dataframe.insert(0, "patient_name", patient_names)
    csv_path = output_dir / "medclip_embeddings_ct_vit.csv"
    dataframe.to_csv(csv_path, index=False)
    logging.info("save csv to %s, shape: %s", csv_path, dataframe.shape)

    mapping_df = pd.DataFrame(
        {
            "identifier": identifiers,
            "patient_name": patient_names,
            "center": centers,
        }
    )
    mapping_path = output_dir / "medclip_embeddings_ct_mapping.csv"
    mapping_df.to_csv(mapping_path, index=False)
    logging.info("save mapping to %s, shape: %s", mapping_path, mapping_df.shape)

    meta = {
        "input_dirs": [str(p) for p in input_dirs],
        "model": "MedCLIP-ViT",
        "batch_size": batch_size,
        "device": str(device),
        "num_files": len(nifti_entries),
        "num_images": len(identifiers),
        "num_errors": len(errors),
        "npz_path": str(npz_path),
        "csv_path": str(csv_path),
        "mapping_path": str(mapping_path),
    }
    meta_path = output_dir / "medclip_embeddings_ct_meta.json"
    with meta_path.open("w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2, ensure_ascii=False)
    logging.info("save meta to %s", meta_path)

    if errors:
        error_log_path = output_dir / "medclip_embedding_ct_errors.log"
        with error_log_path.open("w", encoding="utf-8") as f:
            for path, err in errors:
                f.write(f"{path}\t{err}\n")
        logging.warning("save %d errors to %s", len(errors), error_log_path)
    else:
        # error_log_path is only defined in the branch above; do not reference it here.
        logging.info("no errors while reading NIfTI files")


if __name__ == "__main__":
    main()
