In [1]:
import os
import random
import shutil
import tempfile
import time
import numpy as np 
import matplotlib.pyplot as plt

from monai.data import DataLoader, decollate_batch, Dataset, CacheDataset
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.networks.nets import UNet, SegResNet
from monai.transforms import (
    Activations,
    Compose,
    LoadImaged,
    MapTransform,
    NormalizeIntensityd,
    Orientationd,
    Spacingd,
    EnsureTyped,
    EnsureChannelFirstd,
    RandFlipd,
    RandScaleIntensityd,
    RandShiftIntensityd,
    RandSpatialCropd,
    AsDiscrete,
    Resized,

)

import torch
from pathlib import Path
from natsort import natsorted
from tqdm.notebook import tqdm

In [2]:
class cfg:
    """Paths and hyper-parameters shared by the evaluation cells of this notebook."""

    # Dataset layout: image volumes in <base_path>/image, masks in <base_path>/label.
    base_path = Path('/player/data/BraTS2020_converted')
    image_path = base_path / "image"
    label_path = base_path / "label"

    # Directory whose files are loaded as model checkpoints.
    save_dir = Path('/player/workspace/Python/brain-otock/Model')

    # Seed passed to seed_everything for reproducibility.
    seed = 2023

    # Network I/O channel counts (presumably 4 MRI modalities in, 3 tumor regions out).
    in_channels = 4
    out_channels = 3
    

In [3]:
# Build the test split: the last 30 image volumes, each paired with its mask.
# NOTE(review): `sorted` is lexicographic (e.g. "volume_100" < "volume_73"), not
# natural order. Presumably the training notebook held out the same 30 files this
# way; confirm before switching to natsorted, or train/test may overlap.
test_data_dicts = []

image_files = [cfg.image_path / name for name in sorted(os.listdir(cfg.image_path))[-30:]]

for image_file in image_files:
    # "volume_<id>_image.nii" -> "<id>"
    volume_id = image_file.name.split('_')[-2]
    label_file = cfg.label_path / f"volume_{volume_id}_mask.nii"

    # Keep only images whose mask file exists.
    if not os.path.exists(label_file):
        continue
    test_data_dicts.append({"image": str(image_file), "label": str(label_file)})

print(f"the length of match is {len(test_data_dicts)}")

the length of match is 30


In [4]:
# Sanity-check the test split: its size and one example entry.
print(f"the length of test dataset is {len(test_data_dicts)}")
# Fail with a clear message instead of a bare IndexError when nothing matched.
assert test_data_dicts, f"no image/label pairs found under {cfg.base_path}"
print(test_data_dicts[0])

the length of train dataset is 30
{'image': '/player/data/BraTS2020_converted/image/volume_73_image.nii', 'label': '/player/data/BraTS2020_converted/label/volume_73_mask.nii'}


In [5]:
class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
    """
    Convert a channel-last BraTS label into channel-first TC / WT / ET masks.

    Input Dimension: (W, H, D, C(Channel))
    Input Channel Description:
        0: 'Necrotic (NEC)' unique (0, 1)
        1: 'Edema (ED)' unique (0, 1)
        2: 'Tumour (ET)' unique (0, 1)

    Output Dimension: (C(Channel), W, H, D)
    Output Channel Description:
        0: TC (Tumor core)      = NEC OR ET
        1: WT (Whole tumor)     = NEC OR ED OR ET
        2: ET (Enhancing tumor) = input channel 2, unchanged

    Values that are not already torch tensors (e.g. NumPy arrays) are converted
    with ``torch.as_tensor`` first, because ``torch.logical_or`` accepts only tensors.
    """

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            label = d[key]
            # Convert only non-tensors; existing tensors (including MONAI MetaTensor)
            # are used as-is so any attached metadata is preserved.
            if not isinstance(label, torch.Tensor):
                label = torch.as_tensor(label)

            necrotic = label[..., 0]
            edema = label[..., 1]
            enhancing = label[..., 2]

            # TC: necrotic OR enhancing.
            tc = torch.logical_or(necrotic, enhancing)
            # WT: any tumor sub-region.
            wt = torch.logical_or(torch.logical_or(necrotic, edema), enhancing)

            # New leading channel axis -> (C, W, H, D).
            # NOTE(review): tc/wt are bool while `enhancing` keeps the input dtype;
            # torch.stack relies on type promotion to combine them.
            d[key] = torch.stack([tc, wt, enhancing], dim=0)

        return d


In [6]:
# Evaluation-time preprocessing. There is no random augmentation. Every volume is
# reoriented to RAS and resized to one fixed grid so whole volumes can be batched.
TEST_SPATIAL_SIZE = [128, 128, 80]
_IMAGE_AND_LABEL = ["image", "label"]

test_transform = Compose(
    [
        LoadImaged(keys=_IMAGE_AND_LABEL),
        # Only the image is moved channel-first here; the label gets its channel
        # axis from ConvertToMultiChannelBasedOnBratsClassesd below.
        EnsureChannelFirstd(keys="image"),
        EnsureTyped(keys=_IMAGE_AND_LABEL),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Orientationd(keys=_IMAGE_AND_LABEL, axcodes="RAS"),
        # Image intensities are interpolated; masks use nearest so no new values appear.
        # NOTE(review): for 3-D volumes torch interpolation normally needs
        # mode="trilinear"; confirm "bilinear" behaves as intended with this MONAI version.
        Resized(keys=["image"], spatial_size=TEST_SPATIAL_SIZE, mode="bilinear"),
        Resized(keys=["label"], spatial_size=TEST_SPATIAL_SIZE, mode="nearest"),
    ]
)

In [9]:
# Test dataset (plain Dataset, no caching): each volume is loaded and transformed on access.
test_dataset = Dataset(data=test_data_dicts, transform=test_transform)

# Evaluation loader. Shuffling is disabled so samples are always visited in the
# same order, which keeps evaluation runs reproducible and comparable.
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=3)

In [10]:
def seed_everything(seed):
    """Seed every RNG used in this notebook and put cuDNN in deterministic mode.

    Seeds Python's ``random``, NumPy's global RNG, and torch on the CPU and on all
    CUDA devices. It also sets ``PYTHONHASHSEED`` and configures cuDNN for
    reproducible kernels.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    # Only affects Python processes started after this call; the running kernel's
    # hash randomization was fixed at interpreter startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True makes cuDNN pick algorithms by timing, which can vary between
    # runs and defeats the deterministic flag above.
    torch.backends.cudnn.benchmark = False

# Seed all RNGs before the model is built and before test_loader is iterated.
seed_everything(seed=cfg.seed)

In [11]:
val_interval = 1
VAL_AMP = True  # run inference under CUDA autocast (mixed precision)

# Build the SegResNet architecture; trained weights are loaded from checkpoints later.
model = SegResNet(
    blocks_down=[1, 2, 2, 4],
    blocks_up=[1, 1, 1],
    init_filters=16,
    in_channels=cfg.in_channels,
    out_channels=cfg.out_channels,
    dropout_prob=0.2,
).cuda()

# include_background=True: every output channel is scored, none is dropped.
dice_metric = DiceMetric(include_background=True, reduction="mean")

# Per-sample post-processing: sigmoid probabilities -> binary masks at threshold 0.5.
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])


def inference(input):
    """Run `model` on a full batch in a single forward pass.

    No sliding window is used. This assumes the inputs are the fixed-size volumes
    produced by the test transform. When VAL_AMP is set, the forward pass runs
    under CUDA autocast.

    Args:
        input: batched image tensor already on the CUDA device.

    Returns:
        The raw model logits.
    """
    if VAL_AMP:
        # torch.autocast("cuda") is the non-deprecated equivalent of torch.cuda.amp.autocast().
        with torch.autocast(device_type="cuda"):
            return model(input)
    return model(input)


# NOTE(review): a GradScaler is only needed for AMP *training*; nothing in this
# evaluation notebook uses it.
scaler = torch.cuda.amp.GradScaler()
# cuDNN benchmark mode is deliberately not enabled here. Auto-tuned kernels can
# differ between runs, which makes the Dice scores below non-reproducible.


In [12]:
from tqdm import tqdm  # console-style progress bars (rebinds the tqdm.notebook import)


def evaluate_checkpoint(checkpoint_path):
    """Load a checkpoint into `model` and return its mean Dice over `test_loader`.

    Args:
        checkpoint_path: path to a state_dict file saved with ``torch.save``.

    Returns:
        The aggregated mean Dice score (float) across all test batches.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(checkpoint_path, map_location="cuda"))
    model.eval()
    # Drop any accumulated state left over from an earlier, possibly interrupted, run.
    dice_metric.reset()

    with torch.no_grad():
        for test_data in tqdm(test_loader, desc="Test", unit="batch"):
            test_inputs = test_data["image"].to('cuda')
            test_labels = test_data["label"].to('cuda')

            # Model inference
            test_outputs = inference(test_inputs)

            # Un-batch, then apply sigmoid + threshold to each sample
            test_outputs = [post_trans(i) for i in decollate_batch(test_outputs)]

            # Accumulate Dice for this batch
            dice_metric(y_pred=test_outputs, y=test_labels)

    metric = dice_metric.aggregate().item()
    dice_metric.reset()
    return metric


# Evaluate every saved checkpoint in a stable (sorted) order. Skip entries that
# are not checkpoint files, e.g. a stray .ipynb_checkpoints directory, which would
# make torch.load fail.
for load_model in sorted(os.listdir(cfg.save_dir)):
    full_path = os.path.join(cfg.save_dir, load_model)
    if not (os.path.isfile(full_path) and load_model.endswith((".pth", ".pt"))):
        continue

    metric = evaluate_checkpoint(full_path)
    print(f"Model: {load_model}, Dice Metric: {metric:.4f}")

Test: 100%|██████████| 8/8 [04:23<00:00, 33.00s/batch]


Model: best_metric_model_0.7519.pth, Dice Metric: 0.7574


Test: 100%|██████████| 8/8 [02:37<00:00, 19.71s/batch]


Model: best_metric_model_0.7534.pth, Dice Metric: 0.7788


Test: 100%|██████████| 8/8 [02:36<00:00, 19.53s/batch]


Model: best_metric_model_0.7536.pth, Dice Metric: 0.7712


Test: 100%|██████████| 8/8 [02:40<00:00, 20.11s/batch]


Model: best_metric_model_0.7612.pth, Dice Metric: 0.7808


Test: 100%|██████████| 8/8 [02:38<00:00, 19.76s/batch]


Model: best_metric_model_0.7619.pth, Dice Metric: 0.7774


Test: 100%|██████████| 8/8 [02:38<00:00, 19.81s/batch]


Model: best_metric_model_0.7621.pth, Dice Metric: 0.7747


Test: 100%|██████████| 8/8 [02:38<00:00, 19.79s/batch]


Model: best_metric_model_0.7640.pth, Dice Metric: 0.7532


Test: 100%|██████████| 8/8 [02:38<00:00, 19.81s/batch]


Model: best_metric_model_0.7641.pth, Dice Metric: 0.7834


Test: 100%|██████████| 8/8 [02:38<00:00, 19.82s/batch]


Model: best_metric_model_0.7645.pth, Dice Metric: 0.7733


Test: 100%|██████████| 8/8 [02:36<00:00, 19.54s/batch]


Model: best_metric_model_0.7672.pth, Dice Metric: 0.7733


Test: 100%|██████████| 8/8 [02:35<00:00, 19.40s/batch]


Model: best_metric_model_0.7674.pth, Dice Metric: 0.7744


Test: 100%|██████████| 8/8 [02:39<00:00, 19.95s/batch]


Model: best_metric_model_0.7713.pth, Dice Metric: 0.7769


Test: 100%|██████████| 8/8 [02:38<00:00, 19.75s/batch]


Model: best_metric_model_0.7747.pth, Dice Metric: 0.7900


Test: 100%|██████████| 8/8 [02:35<00:00, 19.50s/batch]

Model: best_metric_model_0.7790.pth, Dice Metric: 0.7793



