### CRCME Feature Extraction Demo

This notebook demonstrates how to:
1. Load the CRCME pre-trained model (CT + joint/FU weights)
2. Encode a single CT image to extract global features
3. Optionally, batch encode multiple images from a folder

#### 1. Install and Import Dependencies

In [4]:
# !pip install torch torchvision SimpleITK numpy matplotlib

import torch
import os
import sys
current_dir = os.getcwd()  # 替代 __file__
root_dir = os.path.abspath(os.path.join(current_dir, ".."))
sys.path.append(root_dir)
import SimpleITK as sitk
import numpy as np
from lib.model_MOE import ViT_ct, ViT_fu, FusionModel, FusionPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

Using device: cuda


#### 2. Initialize Models and Load Pretrained Weights

In [8]:
# Initialize models
model_a = ViT_ct(
    image_size = 256,
    frames = 32,
    image_patch_size = 16,
    frame_patch_size = 16,
    dim = 1024,
    depth = 24,
    heads = 16,
    emb_dropout = 0.1)

model_b = ViT_fu(
    image_size = 256,
    frames = 32,
    image_patch_size = 16,
    frame_patch_size = 16,
    dim = 1024,
    depth = 24,
    heads = 16)

fusion_model = FusionModel(input_dim_a=1024, input_dim_b=1024, classes=2)  # TODO: modify num_classes
pipeline = FusionPipeline(model_a, model_b, fusion_model, num_classes=2)

# Paths to pretrained weights
pretrained_ct = 'checkpoints/checkpoint-ct.pth'        # TODO: modify path
pretrained_fu = 'checkpoints/checkpoint-joint.pth'     # TODO: modify path

# Helper function to load weights
def load_weights(model, path):
    if os.path.isfile(path):
        print(f"Loading pretrained weights from: {path}")
        checkpoint = torch.load(path, map_location='cpu')
        pretrained_dict = checkpoint.get('model', checkpoint)
        cleaned_dict = {k.replace("module.", ""): v for k, v in pretrained_dict.items()}
        compatible_dict = {k: v for k, v in cleaned_dict.items() if k in model.state_dict()}
        model.load_state_dict({**model.state_dict(), **compatible_dict})
        print("Pretrained weights loaded successfully.")
    else:
        raise FileNotFoundError(f"Pretrained model not found at: {path}")

# Load weights
load_weights(model_a, pretrained_ct)
load_weights(model_b, pretrained_fu)

# Move pipeline to device
pipeline.to(device)
pipeline.eval()


Loading pretrained weights from: /cache/yangjing/main_files/CRCFound2/argo2/mymodel/checkpoint-999.pth


  checkpoint = torch.load(path, map_location='cpu')


Pretrained weights loaded successfully.
Loading pretrained weights from: /cache/yangjing/main_files/CRCFound2/CRCFound1/logs_1w/mix_no_ada_2k/patch32_frame32_large_256-None/logs/checkpoint-300.pth
Pretrained weights loaded successfully.


FusionPipeline(
  (model_a): ViT_ct(
    (patch_embed): PatchEmbed(
      (proj): Conv3d(1, 1024, kernel_size=(16, 16, 16), stride=(16, 16, 16))
    )
    (dropout): Dropout(p=0.1, inplace=False)
    (blocks): ModuleList(
      (0-23): 24 x Block(
        (norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=1024, out_features=3072, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=1024, out_features=1024, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (norm)

#### 3. Encode a Single Image

In [10]:
# Load a CT image
img_path = 'path/to/ct_images/'  # TODO: modify path
img = sitk.ReadImage(img_path)
img_array = sitk.GetArrayFromImage(img).astype(np.float32)

# Convert to tensor and add batch & channel dims
img_tensor = torch.from_numpy(img_array).unsqueeze(0).unsqueeze(0)

# Extract image features
with torch.inference_mode():
    _,image_embeddings,_ = pipeline(
        img_tensor.to(device),
        flag=1   # return global feature token
    )
print("Extracted feature shape:", image_embeddings.shape)


Extracted feature shape: (1, 1024)


#### 4. (Optional) Batch Encode a Folder of CT Images

In [None]:
import glob

img_folder = 'path/to/ct_images/'  # TODO: modify folder path
img_paths = sorted(glob.glob(os.path.join(img_folder, '*.nii')))  # adjust extension if needed

features_list = []

for p in img_paths:
    img = sitk.ReadImage(p)
    img_array = sitk.GetArrayFromImage(img).astype(np.float32)
    img_tensor = torch.from_numpy(img_array).unsqueeze(0).unsqueeze(0)
    
    with torch.inference_mode():
        feat = pipeline(img_tensor.to(device), return_features=True)
    features_list.append(feat.cpu().numpy())

print(f"Extracted features for {len(features_list)} images, each shape: {features_list[0].shape}")
