In [1]:
import torch
from sam import SamPredictor, sam_model_registry
from sam.utils.transforms import ResizeLongestSide

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Build the ViT-B variant of SAM from the local checkpoint file and keep
# direct handles to the two sub-modules that are referenced by name.
# NOTE(review): `sam` is not moved to `device` in this cell — confirm that
# happens before the encoders are used with tensors living on `device`.
sam = sam_model_registry["vit_b"](
    checkpoint="sam_vit_b_01ec64.pth",
)
prompt_encoder = sam.prompt_encoder
image_encoder = sam.image_encoder

In [2]:
import clip
# Load the CLIP ViT-B/32 model together with its image preprocessing transform.
clip_model, preprocess = clip.load("ViT-B/32", device=device)

# Tokenize the single text prompt and move the tokens to the model's device.
text = clip.tokenize(["brain"]).to(device)

# The prompt embedding is computed once and reused as a constant, so it is
# encoded without autograd: this avoids keeping a graph through the whole CLIP
# text encoder alive for as long as `text_features` exists.
with torch.no_grad():
    text_features = clip_model.encode_text(text)

# 前處理

In [3]:
import json,os
import numpy as np
from monai.transforms import (
    AsDiscrete,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
    RandRotate90d,
    EnsureTyped,
    SpatialCrop,
    AddChanneld,
    Transform,
    ResizeWithPadOrCropd,
    Lambda,
)
from monai.data import (
    ThreadDataLoader,
    CacheDataset,
    load_decathlon_datalist,
    decollate_batch,
    set_track_meta,
    DataLoader,
)

def get_evaluation_transform(spacing):
    """Build the deterministic MONAI preprocessing pipeline for image/label pairs.

    Args:
        spacing: Indexable with at least three entries; elements 0, 1 and 2 are
            forwarded to ``Spacingd`` as the target ``pixdim``.

    Returns:
        A ``Compose`` that applies, in order: loading, channel insertion,
        intensity rescaling of the image, foreground cropping, reorientation
        to RAS, resampling to ``spacing``, and conversion to tensors placed on
        the module-level ``device``.
    """
    pair_keys = ["image", "label"]
    target_pixdim = (spacing[0], spacing[1], spacing[2])

    # Read both files from disk; the channel axis is added by the next step.
    load = LoadImaged(keys=pair_keys, ensure_channel_first=None)
    add_channel = AddChanneld(keys=pair_keys)

    # Map image intensities from [-1024, 3071] onto [0, 255], clipping values
    # that fall outside the input range. Labels are left untouched.
    rescale = ScaleIntensityRanged(
        keys=["image"],
        a_min=-1024,
        a_max=3071,
        b_min=0.0,
        b_max=255.0,
        clip=True,
    )

    # Crop image and label together, using the image to find the foreground.
    crop = CropForegroundd(keys=pair_keys, source_key="image")
    reorient = Orientationd(keys=pair_keys, axcodes="RAS")

    # Interpolation modes are given per key: bilinear for the image, nearest
    # for the label.
    resample = Spacingd(
        keys=pair_keys,
        pixdim=target_pixdim,
        mode=("bilinear", "nearest"),
    )

    # `device` is a module-level name defined in an earlier cell.
    to_tensor = EnsureTyped(keys=pair_keys, device=device, track_meta=True)

    return Compose(
        [load, add_channel, rescale, crop, reorient, resample, to_tensor]
    )

def custom_load_decathlon_datalist(json_data_path, index, base_dir=None):
    """Load one section of a Decathlon-style JSON file with absolute paths.

    Args:
        json_data_path: Path of the dataset JSON file (read as UTF-8).
        index: Top-level key of the section to return, e.g. ``"training"``.
        base_dir: Directory prefixed to every ``image`` / ``label`` path.
            Defaults to the module-level ``data_dir``.

    Returns:
        The list stored under ``index``; the ``image`` and ``label`` values of
        each entry are rewritten in place to ``base_dir`` + relative path.

    Raises:
        KeyError: If ``index`` is not a key of the JSON document, or an entry
            has no ``image`` / ``label`` key.
    """
    root = data_dir if base_dir is None else base_dir

    def _prefix_root(rel_path):
        # "./imagesTr/x" -> "/imagesTr/x": drop only a real leading "." so a
        # path that does not start with "./" keeps its first character.
        if rel_path[:2] in ('./', '.\\'):
            rel_path = rel_path[1:]
        elif not rel_path.startswith(('/', '\\')):
            # No separator at all: add one so root and path do not fuse.
            rel_path = '/' + rel_path
        return root + rel_path

    with open(json_data_path, 'r', encoding='utf-8') as json_file:
        parsed_data = json.load(json_file)

    target = parsed_data[index]
    for entry in target:
        entry['image'] = _prefix_root(entry['image'])
        entry['label'] = _prefix_root(entry['label'])

    return target

# Number of images per task id; used to derive the split index when the
# evaluation subset is selected.
organ_num = {
    3: 131,
    6: 63,
    7: 281,
    8: 303,
    9: 41,
    10: 126,
}

# Target voxel spacing per task id, forwarded to `Spacingd` as `pixdim`.
# Task 2 has a spacing entry but no entry in `organ_num`.
spacings = {
    2: [1.25, 1.25, 1.37],
    3: [0.7676, 0.7676, 1],
    6: [0.79, 0.79, 1.24],
    7: [0.8, 0.8, 2.5],
    8: [0.8, 0.8, 1.5],
    9: [0.78, 0.78, 1.6],
    10: [0.78, 0.78, 3],
}

# Root directory that holds the `dataset_{task}.json` files and is prefixed to
# every image/label path read from them.
data_dir = "D:\\SAM\\data"

In [4]:
import cv2
# Interactive slice viewer: for every task, load a small evaluation subset and
# step through each volume slice by slice in two OpenCV windows.
# Press any key to advance to the next slice; press 'q' to stop completely.
task_list = [3, 6, 7, 8, 9, 10]

# Set once 'q' is pressed so that every enclosing loop stops, not just the
# innermost ones (otherwise the next task's dataset would still be loaded).
quit_requested = False

try:
    for task in task_list:
        datasets = os.path.join(data_dir, f"dataset_{task}.json")

        # List of {"image": path, "label": path} dicts for this task.
        datalist = custom_load_decathlon_datalist(datasets, "training")

        # Evaluation subset: the first 7 cases plus up to 7 cases starting at
        # the 4/5 split point of the task.
        split_index = organ_num[task] * 4 // 5
        segment1 = datalist[0:7]
        segment2 = datalist[split_index:min(split_index + 7, organ_num[task])]
        eval_list = segment1 + segment2
        print('eval_list', eval_list, len(eval_list))

        dataset = CacheDataset(
            data=eval_list,
            transform=get_evaluation_transform(spacings[task]),
            cache_num=24,
            cache_rate=1.0,
            num_workers=8,
        )

        for i in range(len(dataset)):
            # Index the dataset once per volume: every `dataset[i]` access
            # fetches the whole sample again.
            sample = dataset[i]
            image, label = sample['image'], sample['label']

            print(image.shape)
            _, D, H, W = image.shape
            print(type(image))

            # Walk along the first spatial axis.
            # NOTE(review): after Orientationd(axcodes="RAS") this axis is
            # presumably left-right, i.e. these would be sagittal slices —
            # confirm that this is the intended slicing direction.
            for d in range(D):
                image_slice = image[0, d, :, :]
                label_slice = label[0, d, :, :]
                image_slice_uint8 = image_slice.astype(np.uint8)

                cv2.imshow('Image Slice', image_slice_uint8)
                cv2.imshow('Label Slice', label_slice.cpu().numpy())

                # Block until a key is pressed; 'q' requests a full stop.
                key = cv2.waitKey(0) & 0xFF
                if key == ord('q'):
                    quit_requested = True
                    break

                # Disabled: save every slice as a .npy file.
                # if not os.path.exists(f"my_data/task{task}"):
                #     os.makedirs(f"my_data/task{task}")
                # np.save(f'my_data/task{task}/image_{i}_{d}.npy', image_slice.cpu().numpy())
                # np.save(f'my_data/task{task}/label_{i}_{d}.npy', label_slice.cpu().numpy())

            if quit_requested:
                break

        if quit_requested:
            break
finally:
    # Close the viewer windows even if loading or displaying raised.
    cv2.destroyAllWindows()



eval_list [{'image': 'D:\\SAM\\data/imagesTr/liver_14.nii.gz', 'label': 'D:\\SAM\\data/labelsTr/liver_14.nii.gz'}, {'image': 'D:\\SAM\\data/imagesTr/liver_69.nii.gz', 'label': 'D:\\SAM\\data/labelsTr/liver_69.nii.gz'}, {'image': 'D:\\SAM\\data/imagesTr/liver_77.nii.gz', 'label': 'D:\\SAM\\data/labelsTr/liver_77.nii.gz'}, {'image': 'D:\\SAM\\data/imagesTr/liver_120.nii.gz', 'label': 'D:\\SAM\\data/labelsTr/liver_120.nii.gz'}, {'image': 'D:\\SAM\\data/imagesTr/liver_18.nii.gz', 'label': 'D:\\SAM\\data/labelsTr/liver_18.nii.gz'}, {'image': 'D:\\SAM\\data/imagesTr/liver_65.nii.gz', 'label': 'D:\\SAM\\data/labelsTr/liver_65.nii.gz'}, {'image': 'D:\\SAM\\data/imagesTr/liver_30.nii.gz', 'label': 'D:\\SAM\\data/labelsTr/liver_30.nii.gz'}, {'image': 'D:\\SAM\\data/imagesTr/liver_37.nii.gz', 'label': 'D:\\SAM\\data/labelsTr/liver_37.nii.gz'}, {'image': 'D:\\SAM\\data/imagesTr/liver_29.nii.gz', 'label': 'D:\\SAM\\data/labelsTr/liver_29.nii.gz'}, {'image': 'D:\\SAM\\data/imagesTr/liver_54.nii.gz',

Loading dataset: 100%|██████████| 14/14 [01:05<00:00,  4.65s/it]


torch.Size([1, 457, 457, 588])
<class 'monai.data.meta_tensor.MetaTensor'>
torch.Size([1, 503, 503, 489])
<class 'monai.data.meta_tensor.MetaTensor'>
torch.Size([1, 521, 521, 466])
<class 'monai.data.meta_tensor.MetaTensor'>
torch.Size([1, 496, 496, 636])
<class 'monai.data.meta_tensor.MetaTensor'>
torch.Size([1, 615, 615, 676])
<class 'monai.data.meta_tensor.MetaTensor'>
torch.Size([1, 417, 417, 410])
<class 'monai.data.meta_tensor.MetaTensor'>
torch.Size([1, 667, 667, 200])
<class 'monai.data.meta_tensor.MetaTensor'>
