In [4]:
# Copyright 2020 - 2022 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
from functools import partial

import nibabel as nib
import numpy as np
import torch
import rootutils
rootutils.setup_root(search_from="/work/hpc/spine-segmentation/notebooks/infer.ipynb", indicator="setup.py", pythonpath=True)

from src.data.spider_datamodule import get_loader

from monai.inferers import sliding_window_inference
from monai.networks.nets import SwinUNETR

parser = argparse.ArgumentParser(description="Swin UNETR segmentation pipeline")
parser.add_argument("--data_dir", default="./dataset/", type=str, help="dataset directory")
parser.add_argument("--exp_name", default="test_nii_2", type=str, help="experiment name")
parser.add_argument("--json_list", default="./jsons/brats21_folds.json", type=str, help="dataset json file")
parser.add_argument("--pretrained_model_name", default="model_final.pt", type=str, help="pretrained model name")
parser.add_argument("--feature_size", default=48, type=int, help="feature size")
parser.add_argument("--infer_overlap", default=0.6, type=float, help="sliding window inference overlap")
parser.add_argument("--in_channels", default=4, type=int, help="number of input channels")
parser.add_argument("--out_channels", default=3, type=int, help="number of output channels")
parser.add_argument("--a_min", default=-175.0, type=float, help="a_min in ScaleIntensityRanged")
parser.add_argument("--a_max", default=250.0, type=float, help="a_max in ScaleIntensityRanged")
parser.add_argument("--b_min", default=0.0, type=float, help="b_min in ScaleIntensityRanged")
parser.add_argument("--b_max", default=1.0, type=float, help="b_max in ScaleIntensityRanged")
parser.add_argument("--space_x", default=1.5, type=float, help="spacing in x direction")
parser.add_argument("--space_y", default=1.5, type=float, help="spacing in y direction")
parser.add_argument("--space_z", default=2.0, type=float, help="spacing in z direction")
parser.add_argument("--roi_x", default=96, type=int, help="roi size in x direction")
parser.add_argument("--roi_y", default=96, type=int, help="roi size in y direction")
parser.add_argument("--roi_z", default=96, type=int, help="roi size in z direction")
parser.add_argument("--dropout_rate", default=0.0, type=float, help="dropout rate")
parser.add_argument("--distributed", action="store_true", help="start distributed training")
parser.add_argument("--workers", default=8, type=int, help="number of workers")
parser.add_argument("--RandFlipd_prob", default=0.2, type=float, help="RandFlipd aug probability")
parser.add_argument("--RandRotate90d_prob", default=0.2, type=float, help="RandRotate90d aug probability")
parser.add_argument("--RandScaleIntensityd_prob", default=0.1, type=float, help="RandScaleIntensityd aug probability")
parser.add_argument("--RandShiftIntensityd_prob", default=0.1, type=float, help="RandShiftIntensityd aug probability")
parser.add_argument("--spatial_dims", default=3, type=int, help="spatial dimension of input data")
parser.add_argument("--use_checkpoint", action="store_true", help="use gradient checkpointing to save memory")
parser.add_argument(
    "--pretrained_dir",
    default="./runs/brats/",
    type=str,
    help="pretrained checkpoint directory",
)


def main():
    args = parser.parse_args()
    args.test_mode = True
    output_directory = "./outputs/" + args.exp_name
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)
    test_loader = get_loader(args)
    pretrained_dir = args.pretrained_dir
    model_name = args.pretrained_model_name
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pretrained_pth = os.path.join(pretrained_dir, model_name)
    model = SwinUNETR(
        img_size=128,
        in_channels=args.in_channels,
        out_channels=args.out_channels,
        feature_size=args.feature_size,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        dropout_path_rate=0.0,
        use_checkpoint=args.use_checkpoint,
    )
    model_dict = torch.load(pretrained_pth)["state_dict"]
    model.load_state_dict(model_dict)
    model.eval()
    model.to(device)

    model_inferer_test = partial(
        sliding_window_inference,
        roi_size=[args.roi_x, args.roi_y, args.roi_z],
        sw_batch_size=1,
        predictor=model,
        overlap=args.infer_overlap,
    )

    with torch.no_grad():
        for i, batch in enumerate(test_loader):
            image = batch["image"].cuda()
            print(batch.keys())
            affine = batch["image_meta_dict"]["original_affine"][0].numpy()
            num = batch["image_meta_dict"]["filename_or_obj"][0].split("/")[-1].split("_")[1]
            img_name = "BraTS2021_" + num + ".nii.gz"
            print("Inference on case {}".format(img_name))
            prob = torch.sigmoid(model_inferer_test(image))
            seg = prob[0].detach().cpu().numpy()
            seg = (seg > 0.5).astype(np.int8)
            seg_out = np.zeros((seg.shape[1], seg.shape[2], seg.shape[3]))
            seg_out[seg[1] == 1] = 2
            seg_out[seg[0] == 1] = 1
            seg_out[seg[2] == 1] = 4
            nib.save(nib.Nifti1Image(seg_out.astype(np.uint8), affine), os.path.join(output_directory, img_name))
        print("Finished inference!")


if __name__ == "__main__":
    main()

ModuleNotFoundError: No module named 'src.utils.data_utils'

In [12]:
from monai import transforms
from monai.data.utils import list_data_collate
from src.data.components.spider_dataset import SpiderDataset, SpiderTransformedDataset
from torch.utils.data import DataLoader, Dataset

transform_val = transforms.Compose(
    [
        transforms.LoadImaged(keys=["image", "label"]),
        transforms.ConvertToMultiChannelBasedOnBratsClassesd(keys=["label"]),
        transforms.NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        transforms.ToTensord(keys=["image", "label"]),
    ]
)

data_dir = "/data/hpc/spine/dataset"
json_path = "/data/hpc/spine/jsons/brats21_folds_test.json"
dataset = SpiderDataset(data_dir=data_dir, json_path=json_path)

# print(dataset[0])

dataloader = DataLoader(dataset=SpiderTransformedDataset(dataset, transform_val), 
                            batch_size=1, ##self.hparams.batch_size,
                            num_workers=4,
                            pin_memory=False,
                            shuffle=False,
                            collate_fn = list_data_collate)

batch = next(iter(dataloader))
print(batch)

{'image': ['/data/hpc/spine/dataset/TrainingData/BraTS2021_01146/BraTS2021_01146_flair.nii.gz', '/data/hpc/spine/dataset/TrainingData/BraTS2021_01146/BraTS2021_01146_t1ce.nii.gz', '/data/hpc/spine/dataset/TrainingData/BraTS2021_01146/BraTS2021_01146_t1.nii.gz', '/data/hpc/spine/dataset/TrainingData/BraTS2021_01146/BraTS2021_01146_t2.nii.gz'], 'label': '/data/hpc/spine/dataset/TrainingData/BraTS2021_01146/BraTS2021_01146_seg.nii.gz'}
{'image': metatensor([[[[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]],

          [[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]]

# Checkpoint

In [1]:
import torch
import rootutils
rootutils.setup_root(search_from="/work/hpc/spine-segmentation/notebooks/infer.ipynb", indicator="setup.py", pythonpath=True)

cp = torch.load("/work/hpc/spine-segmentation/logs/train/runs/2024-08-05_18-28-22/checkpoints/epoch_271.ckpt")
print(cp)
print(cp["state_dict"]["criterion.class_weight"])

cp["state_dict"]["criterion.class_weight"] = torch.Tensor([1])

torch.save(cp, "/work/hpc/spine-segmentation/logs/train/runs/2024-08-05_18-28-22/checkpoints/epoch_271_v2.ckpt")

Using `TRANSFORMERS_CACHE` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.


{'epoch': 271, 'global_step': 27472, 'pytorch-lightning_version': '2.2.1', 'state_dict': OrderedDict([('net.model.0.conv.0.conv.weight', tensor([[[[[-0.1225, -0.0842, -0.0639],
           [ 0.0017, -0.0607,  0.0699],
           [ 0.0577,  0.1214,  0.1349]],

          [[-0.0244, -0.0877, -0.0796],
           [-0.0219, -0.1818, -0.0941],
           [-0.1491,  0.1811, -0.1348]],

          [[-0.0140,  0.0715, -0.0974],
           [-0.0808,  0.1216,  0.1736],
           [ 0.0701,  0.1207,  0.0307]]]],



        [[[[-0.1209,  0.0849,  0.0183],
           [-0.0314, -0.0661, -0.1775],
           [ 0.0375, -0.1959,  0.0802]],

          [[-0.1698, -0.1662,  0.1522],
           [-0.0099,  0.0234, -0.0121],
           [ 0.1764,  0.1582, -0.1094]],

          [[ 0.1075,  0.0656,  0.0349],
           [-0.1363,  0.1911, -0.1055],
           [-0.0515,  0.1204,  0.1011]]]],



        [[[[-0.0166, -0.0135,  0.0241],
           [-0.1057, -0.1105, -0.0416],
           [ 0.0411, -0.0390, -0.1430]],

 

In [20]:
import torch
import rootutils
rootutils.setup_root(search_from="/work/hpc/spine-segmentation/notebooks/infer.ipynb", indicator="setup.py", pythonpath=True)

cp = torch.load("/work/hpc/spine-segmentation/logs/train/runs/attention-unet-v1/checkpoints/epoch_107_v2.ckpt")
print(cp)


{'epoch': 0, 'global_step': 202, 'pytorch-lightning_version': '2.2.1', 'state_dict': OrderedDict([('net.swinViT.patch_embed.proj.weight', tensor([[[[[ 4.1620e-02, -1.0403e-01],
           [ 2.7865e-01,  2.7520e-01]],

          [[ 1.9443e-01,  1.6664e-01],
           [-3.0434e-01,  3.3083e-01]]]],



        [[[[-2.3491e-01, -2.4341e-01],
           [-6.1268e-02, -1.4723e-01]],

          [[ 9.7526e-02, -2.6410e-01],
           [-5.3158e-02,  1.3275e-01]]]],



        [[[[-2.8319e-01, -3.4155e-01],
           [-3.2951e-01,  1.7043e-01]],

          [[ 4.5029e-02,  3.3398e-01],
           [-6.7721e-02,  2.6913e-01]]]],



        [[[[ 1.8056e-01,  1.0918e-01],
           [ 1.0931e-01,  5.5638e-02]],

          [[-1.1958e-01,  6.1515e-02],
           [-7.5508e-02, -1.8411e-01]]]],



        [[[[ 1.0075e-01,  2.0924e-01],
           [ 2.2939e-01,  3.0182e-01]],

          [[-3.5235e-01,  1.7376e-01],
           [ 1.8674e-01, -1.3287e-01]]]],



        [[[[ 1.3757e-01, -5.3111e-02],
   

In [2]:
import wandb

api = wandb.Api()
runs = api.runs(path="tiendung050803/Spine")
for i in runs:
  print("run name = ",i.name," id: ", i.id)

run name =  restful-yogurt-112  id:  ztuc8lee
run name =  196-36 v2  id:  h7ztw2pj
run name =  192-36  id:  tdjll85z
run name =  3-classes  id:  vswbxb4v
run name =  15-classes  id:  5uraune8


: 