In [None]:
import os
import shutil
import tempfile

import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm

from monai.config import print_config
from monai.data import (
    DataLoader,
    CacheDataset,
    load_decathlon_datalist,
    decollate_batch,
)
from monai.inferers import sliding_window_inference
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from monai.networks.nets import UNETR
from monai.transforms import (
    AsDiscrete,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandFlipd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ScaleIntensityRanged,
    Spacingd,
    RandRotate90d,
)
from segment_anything import SamPredictor, sam_model_registry

# Print MONAI's configuration report (MONAI and dependency versions) so the
# environment this notebook ran in is captured in its output.
print_config()

In [None]:
# Root directory for processed data. Set `directory = None` to work in a
# freshly created temporary directory instead.
directory = "datasets/processed"
if directory is None:
    root_dir = tempfile.mkdtemp()
else:
    root_dir = directory
print(root_dir)

In [None]:
# Keys of the paired image / label volumes processed by every dictionary transform.
_KEYS = ("image", "label")


def _base_preprocessing():
    """Return fresh instances of the deterministic steps shared by train and val.

    Load -> channel-first -> RAS orientation -> resample to pixdim
    (1.5, 1.5, 2.0) (bilinear for the image, nearest for the label) ->
    clip intensities to [-175, 250] and rescale to [0, 1] -> crop both
    volumes to the foreground of the image.
    """
    return [
        LoadImaged(keys=_KEYS),
        EnsureChannelFirstd(keys=_KEYS),
        Orientationd(keys=_KEYS, axcodes="RAS"),
        Spacingd(keys=_KEYS, pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
        CropForegroundd(keys=_KEYS, source_key="image"),
    ]


# Training: shared preprocessing, then 4 random 96^3 patches per volume
# (balanced positive/negative centres) and light augmentation.
train_transforms = Compose(
    _base_preprocessing()
    + [
        RandCropByPosNegLabeld(
            keys=_KEYS,
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
        # One independent 10%-probability flip per spatial axis (0, 1, 2).
        *[RandFlipd(keys=_KEYS, spatial_axis=[axis], prob=0.10) for axis in (0, 1, 2)],
        RandRotate90d(keys=_KEYS, prob=0.10, max_k=3),
        RandShiftIntensityd(keys=["image"], offsets=0.10, prob=0.50),
    ]
)

# Validation: only the deterministic preprocessing (full volumes, no augmentation).
val_transforms = Compose(_base_preprocessing())

In [None]:
# Split file listing the training / validation cases.
data_dir = "datasets/"
split_json = "dataset_0.json"
datasets = os.path.join(data_dir, split_json)

# Read both sections of the split as segmentation datalists.
datalist = load_decathlon_datalist(datasets, is_segmentation=True, data_list_key="training")
val_files = load_decathlon_datalist(datasets, is_segmentation=True, data_list_key="validation")

# Training: cache (up to) 24 preprocessed cases, shuffle every epoch.
train_ds = CacheDataset(
    data=datalist, transform=train_transforms, cache_num=24, cache_rate=1.0, num_workers=8
)
train_loader = DataLoader(
    train_ds,
    batch_size=1,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
)

# Validation: cache (up to) 6 cases, keep a fixed order.
val_ds = CacheDataset(
    data=val_files, transform=val_transforms, cache_num=6, cache_rate=1.0, num_workers=4
)
val_loader = DataLoader(
    val_ds,
    batch_size=1,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

In [None]:
# Preview one validation case: the image and its label map on one slice along
# the last spatial axis, plus the voxel count of every label id.
slice_map = {
    "img0035.nii.gz": 170,
    "img0036.nii.gz": 230,
    "img0037.nii.gz": 204,
    "img0038.nii.gz": 204,
    "img0039.nii.gz": 204,
    "img0040.nii.gz": 180,
}
case_num = 0
# Look the case up once and reuse it instead of indexing val_ds repeatedly.
sample = val_ds[case_num]
img = sample["image"]
label = sample["label"]
img_name = os.path.split(img.meta["filename_or_obj"])[1]
img_shape = img.shape
label_shape = label.shape
print(f"image shape: {img_shape}, label shape: {label_shape}")

# Slice index along the last spatial axis. Cases missing from slice_map fall
# back to the middle slice instead of raising KeyError.
slice_idx = slice_map.get(img_name, img_shape[-1] // 2)

fig, (ax_img, ax_lbl) = plt.subplots(1, 2, figsize=(18, 6))
ax_img.set_title("image")
ax_img.imshow(img[0, :, :, slice_idx].detach().cpu(), cmap="gray")
ax_lbl.set_title("label")
ax_lbl.imshow(label[0, :, :, slice_idx].detach().cpu())

# Statistics about each label's area
print("Statistics about each label's area")
for i in range(14):  # label ids 0..13 are assumed by this notebook
    # int(...) prints a plain number rather than a tensor repr.
    print(f"label {i} has {int(torch.sum(label == i))} voxels")
plt.show()

In [None]:
# Zero-shot SAM (ViT-B) evaluation on the validation volumes.
#
# SAM predicts class-agnostic masks from prompts, so every organ visible in an
# axial slice is prompted with the bounding box of its ground-truth mask. The
# Dice reported below is therefore a box-prompted (oracle-prompt) score, not a
# fully automatic segmentation score.
SAM_CHECKPOINT = "checkpoints/sam_vit_b_01ec64.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

sam = sam_model_registry["vit_b"](checkpoint=SAM_CHECKPOINT)
sam.to(device=device)
predictor = SamPredictor(sam)


def _to_sam_rgb(slice_2d: np.ndarray) -> np.ndarray:
    """Map a 2-D slice scaled to [0, 1] to the H x W x 3 uint8 RGB image SAM expects."""
    # ScaleIntensityRanged mapped intensities to [0, 1]; a bare astype(uint8)
    # would truncate almost every pixel to 0, so rescale to [0, 255] first.
    gray = (np.clip(slice_2d, 0.0, 1.0) * 255.0).round().astype(np.uint8)
    return np.stack([gray, gray, gray], axis=-1)


def _bbox_xyxy(mask_2d: np.ndarray) -> np.ndarray:
    """Return the tight [x0, y0, x1, y1] box (x = column, y = row) of a non-empty boolean mask."""
    rows = np.flatnonzero(mask_2d.any(axis=1))
    cols = np.flatnonzero(mask_2d.any(axis=0))
    return np.array([cols[0], rows[0], cols[-1], rows[-1]])


case_scores = []  # one {label_id: dice} dict per validation case
vis_pred = None   # predicted label volume of case `case_num`, kept for plotting

# Local no-grad scope instead of a global torch.set_grad_enabled(False) that
# would leak into every later cell.
with torch.no_grad():
    for case_idx, batch in enumerate(tqdm(val_loader, desc="SAM inference")):
        # batch_size=1 and one channel: drop both leading axes -> (H, W, D).
        volume = batch["image"][0, 0].detach().cpu().numpy()
        gt = np.rint(batch["label"][0, 0].detach().cpu().numpy()).astype(np.int16)

        inter, pred_sum, gt_sum = {}, {}, {}
        pred_vol = np.zeros(gt.shape, dtype=np.int16)
        # Iterate over the last axis: the same slicing the preview cell plots.
        for z in range(volume.shape[-1]):
            gt_slice = gt[:, :, z]
            organs = np.unique(gt_slice)
            organs = organs[organs > 0]
            if organs.size == 0:
                continue  # no organ on this slice -> nothing to prompt
            predictor.set_image(_to_sam_rgb(volume[:, :, z]))
            pred_slice = pred_vol[:, :, z]  # view: writes go into pred_vol
            for organ in organs.tolist():
                gt_mask = gt_slice == organ
                masks, _, _ = predictor.predict(box=_bbox_xyxy(gt_mask), multimask_output=False)
                pred_mask = masks[0]
                pred_slice[pred_mask] = organ  # overlaps: later organ ids win (display only)
                inter[organ] = inter.get(organ, 0) + int(np.logical_and(pred_mask, gt_mask).sum())
                pred_sum[organ] = pred_sum.get(organ, 0) + int(pred_mask.sum())
                gt_sum[organ] = gt_sum.get(organ, 0) + int(gt_mask.sum())

        # Per-organ volumetric Dice for this case (gt_sum > 0 for every key).
        case_scores.append({o: 2.0 * inter[o] / (pred_sum[o] + gt_sum[o]) for o in gt_sum})
        if case_idx == case_num:
            vis_pred = pred_vol

# Calculate Dice scores
per_case_means = [np.mean(list(scores.values())) for scores in case_scores if scores]
mean_dice = float(np.mean(per_case_means)) if per_case_means else float("nan")
print(f"Mean Dice (box-prompted, {len(case_scores)} cases): {mean_dice:.4f}")
for organ in sorted({o for scores in case_scores for o in scores}):
    organ_vals = [scores[organ] for scores in case_scores if organ in scores]
    print(f"label {organ}: Dice {np.mean(organ_vals):.4f} over {len(organ_vals)} case(s)")

# Visualize results for the case previewed earlier (img, label, img_name,
# slice_map and case_num are defined in the preview cell).
if vis_pred is not None:
    vis_slice = slice_map.get(img_name, img.shape[-1] // 2)
    label_vmax = max(1, int(label.max()))  # shared colour scale for label and prediction
    fig, (ax_img, ax_lbl, ax_pred) = plt.subplots(1, 3, figsize=(18, 6))
    ax_img.set_title("Image")
    ax_img.imshow(img[0, :, :, vis_slice].detach().cpu(), cmap="gray")
    ax_lbl.set_title("Label")
    ax_lbl.imshow(label[0, :, :, vis_slice].detach().cpu(), vmin=0, vmax=label_vmax)
    ax_pred.set_title("Prediction")
    ax_pred.imshow(vis_pred[:, :, vis_slice], vmin=0, vmax=label_vmax)
    plt.show()