In [None]:
import os
import glob as glob

In [4]:
# set_determinism is otherwise only imported by the large import cell further
# down, so a fresh "Restart & Run All" would raise NameError on the last line.
from monai.utils import set_determinism

root_dir = "/scratch/scratch6/akansh12/Parse_data/train/train/"
train_images = sorted(glob.glob(os.path.join(root_dir, "*", 'image', "*.nii.gz")))
train_labels = sorted(glob.glob(os.path.join("./labels/", "*.nii.gz")))

# Images and labels are paired purely by sorted position, and zip() silently
# truncates to the shorter list, so a count mismatch would yield misaligned
# image/label pairs without any error. Fail loudly instead.
assert len(train_images) == len(train_labels), (
    f"found {len(train_images)} images but {len(train_labels)} labels"
)

data_dicts = [{"image": images_name, "label": label_name} for images_name, label_name in zip(train_images, train_labels)]
# The last 9 cases (in sorted order) are held out for validation.
train_files, val_files = data_dicts[:-9], data_dicts[-9:]
set_determinism(seed = 0)

### Test set

In [4]:
import torch
import monai
from monai.utils import first, set_determinism
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd, 
    EnsureTyped,
    EnsureType,
    Invertd,
    AddChanneld,
    ToTensord,
    RandFlipd,
    RandRotate90d,
    RandShiftIntensityd,
    RandRotated,
    Zoomd

)
# from monai.handlers.utils import from_engine
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceCELoss, DiceFocalLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob
import nibabel as nib
import numpy as np
from tqdm.notebook import tqdm

In [18]:
from monai.networks.nets import UNet
from monai.networks.layers import Norm
import torch
from collections import OrderedDict
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Architecture hyper-parameters; load_state_dict below (strict by default)
# fails if they do not match the checkpoint's tensors.
UNet_meatdata = dict(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH
    )

model = UNet(**UNet_meatdata).to(device)
path2weights = "/scratch/scratch6/akansh12/challenges/parse2022/temp/main_artery_seg_models/Unet_1000_no_hu_spacing_160_augmentations_loss_ce_main_artery_89_27.pth"
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
state_dict = torch.load(path2weights, map_location='cpu')

# Checkpoint tensors are assigned to model parameters by position, ignoring the
# checkpoint's own key names (presumably they differ only by a wrapper prefix --
# confirm). Renaming in one pass is O(n) and, unlike rebuilding the dict once per
# key, cannot merge two entries when a new name equals a not-yet-renamed key.
model_keys = list(model.state_dict().keys())
if len(model_keys) != len(state_dict):
    raise ValueError(
        f"checkpoint has {len(state_dict)} tensors but the model expects {len(model_keys)}"
    )
state_dict = OrderedDict(zip(model_keys, state_dict.values()))
model.load_state_dict(state_dict)

<All keys matched successfully>

In [3]:
root_dir = "/scratch/scratch6/akansh12/Parse_data/evaluation/"
test_files_path = sorted(glob.glob(os.path.join(root_dir, "*.nii.gz")))
# One dict per evaluation volume; "images" is the key every transform reads.
test_data = [dict(images=image_name) for image_name in test_files_path]

# Load -> channel-first -> LPS orientation -> clip/rescale intensities from
# [-1000, 1000] to [0, 1] -> CropForegroundd (default foreground selection)
# -> typed tensor. The inference post-processing inverts this exact Compose
# (Invertd is given transform=test_transforms), so the step order matters.
test_transforms = Compose([
    LoadImaged(keys=["images"]),
    EnsureChannelFirstd(keys=["images"]),
    Orientationd(keys=["images"], axcodes="LPS"),
    ScaleIntensityRanged(
        keys=["images"],
        clip=True,
        a_min=-1000,
        a_max=1000,
        b_min=0.0,
        b_max=1.0,
    ),
    CropForegroundd(keys=["images"], source_key="images"),
    EnsureTyped(keys=["images"]),
])

# Every volume is pre-transformed once into memory (cache_rate=1.0).
test_ds = CacheDataset(
    data=test_data,
    transform=test_transforms,
    num_workers=4,
    cache_rate=1.0,
)
test_loader = DataLoader(test_ds, shuffle=False, batch_size=1, num_workers=4)

Loading dataset: 100%|█| 30/30 [00:39<0


In [19]:
os.makedirs(name="./submit_orig_8927", exist_ok=True)

post_transforms = Compose(
    [
        # Replay test_transforms in reverse on "pred", using the transform
        # history recorded on "images". "pred" has not been argmaxed yet at
        # this point, so non-nearest (linear) interpolation is requested.
        Invertd(
            keys="pred",
            orig_keys="images",
            transform=test_transforms,
            meta_keys=None,
            orig_meta_keys=None,
            meta_key_postfix="meta_dict",
            to_tensor=True,
            nearest_interp=False,
        ),
        # Collapse the channel axis into a label map.
        AsDiscreted(keys="pred", argmax=True),
        # Writes <output_dir>/<case>/<case>_seg.nii.gz without resampling.
        SaveImaged(
            keys="pred",
            meta_keys="pred_meta_dict",
            output_postfix='seg',
            output_dir="./submit_orig_8927/",
            resample=False,
        ),
    ]
)

In [20]:
model.eval()
# Sliding-window settings (loop-invariant).
roi_size = (224, 224, 224)
sw_batch_size = 4
with torch.no_grad():
    # Use `batch`, not `test_data`: reusing that name overwrote the module-level
    # list of test-file dicts built when the dataset was created.
    for batch in test_loader:
        test_inputs = batch["images"]
        # Each window is sent to the model's device (sw_device) while the
        # stitched prediction is accumulated on the CPU (device), so this works
        # whether the model sits on CUDA or CPU.
        batch["pred"] = sliding_window_inference(
            test_inputs, roi_size, sw_batch_size, model,
            overlap=0.6, sw_device=device, device="cpu",
        )
        # post_transforms inverts, argmaxes and saves each case (SaveImaged);
        # the returned dicts are not needed afterwards.
        for item in decollate_batch(batch):
            post_transforms(item)

2022-07-26 22:06:56,795 INFO image_writer.py:190 - writing: submit_orig_8927/PA000013/PA000013_seg.nii.gz
2022-07-26 22:08:22,722 INFO image_writer.py:190 - writing: submit_orig_8927/PA000032/PA000032_seg.nii.gz
2022-07-26 22:09:51,811 INFO image_writer.py:190 - writing: submit_orig_8927/PA000044/PA000044_seg.nii.gz
2022-07-26 22:11:17,416 INFO image_writer.py:190 - writing: submit_orig_8927/PA000045/PA000045_seg.nii.gz
2022-07-26 22:12:40,406 INFO image_writer.py:190 - writing: submit_orig_8927/PA000051/PA000051_seg.nii.gz
2022-07-26 22:14:07,261 INFO image_writer.py:190 - writing: submit_orig_8927/PA000057/PA000057_seg.nii.gz
2022-07-26 22:15:32,106 INFO image_writer.py:190 - writing: submit_orig_8927/PA000059/PA000059_seg.nii.gz
2022-07-26 22:16:57,320 INFO image_writer.py:190 - writing: submit_orig_8927/PA000061/PA000061_seg.nii.gz
2022-07-26 22:19:06,462 INFO image_writer.py:190 - writing: submit_orig_8927/PA000069/PA000069_seg.nii.gz
2022-07-26 22:20:33,208 INFO image_writer.py:1

### Connected-component analysis (CCA): keep the largest component of each predicted mask

In [9]:
import SimpleITK as sitk
import itk
def GetLargestConnectedCompont(binarysitk_image):
    """
    Keep only the largest connected component of a binary mask.

    Components are labelled with ``sitk.ConnectedComponent`` and ranked by
    ``GetPhysicalSize`` (so voxel spacing is taken into account); on a tie the
    first label reported by the statistics filter wins.

    :param binarysitk_image: binary SimpleITK image (non-zero voxels are foreground)
    :return: SimpleITK image on the same grid (origin, spacing, direction) that is
        1 inside the largest component and 0 elsewhere. An input without any
        foreground yields an all-zero image.
    """
    cc = sitk.ConnectedComponent(binarysitk_image)
    stats = sitk.LabelIntensityStatisticsImageFilter()
    # NOTE: this sets SimpleITK's global default thread count, not just this filter's.
    stats.SetGlobalDefaultNumberOfThreads(8)
    stats.Execute(cc, binarysitk_image)

    maxlabel = 0  # 0 is the background label; stays 0 if no component is found
    maxsize = 0
    for label in stats.GetLabels():
        size = stats.GetPhysicalSize(label)
        if maxsize < size:
            maxlabel = label
            maxsize = size

    labelmaskimage = sitk.GetArrayFromImage(cc)
    if maxlabel == 0:
        # Without this guard, matching label 0 would select the background and
        # return an inverted mask for an empty prediction.
        outmask = np.zeros_like(labelmaskimage)
    else:
        outmask = (labelmaskimage == maxlabel).astype(labelmaskimage.dtype)

    outmask_sitk = sitk.GetImageFromArray(outmask)
    # Copy origin, spacing and direction so the mask overlays the input.
    outmask_sitk.CopyInformation(binarysitk_image)
    return outmask_sitk

In [10]:
import os
import glob
import shutil
import SimpleITK as sitk
import nibabel as nib
import numpy as np

# Must match SaveImaged's output_dir in the inference cell; otherwise this step
# post-processes predictions this notebook did not produce.
pred_dir = "./submit_orig_8927/"

# Flatten SaveImaged's <pred_dir>/<case>/<case>_seg.nii.gz layout into
# <dst>/<case>.nii.gz so each prediction can be found by the evaluation
# image's file name.
output = glob.glob(os.path.join(pred_dir, "*", "*.nii.gz"))
dst = "./submit_wo_seg/"
os.makedirs(dst, exist_ok=True)
for pred_path in output:
    case_id = os.path.basename(os.path.dirname(pred_path))
    shutil.copy(pred_path, os.path.join(dst, case_id + ".nii.gz"))

root_dir = "/scratch/scratch6/akansh12/Parse_data/evaluation/"
test_files_path = sorted(glob.glob(os.path.join(root_dir, "*.nii.gz")))
test_data = [{"images": image_name } for image_name in test_files_path]


os.makedirs("./submit", exist_ok=True)
for item in tqdm(test_data):
    case_file = os.path.basename(item['images'])
    input_image = sitk.ReadImage(item['images'])
    masked = nib.load(os.path.join(dst, case_file))
    # nibabel indexes voxels as (i, j, k); SimpleITK arrays are (k, j, i).
    mask_image = sitk.GetImageFromArray(np.swapaxes(masked.get_fdata(), 0, 2))
    # Put the mask on the input image's grid, including its direction: the
    # output below copies this geometry, so a missing direction would leave it
    # at identity regardless of the input's orientation.
    mask_image.SetOrigin(input_image.GetOrigin())
    mask_image.SetSpacing(input_image.GetSpacing())
    mask_image.SetDirection(input_image.GetDirection())
    # Voxels in [0, 0.5] -> 0, everything else -> 1.
    seg = sitk.BinaryThreshold(mask_image, lowerThreshold=0, upperThreshold=0.5, insideValue=0,
                               outsideValue=1)
    lc = GetLargestConnectedCompont(seg)

    sitk.WriteImage(lc, os.path.join("./submit", case_file))

  0%|          | 0/30 [00:00<?, ?it/s]

In [8]:
# Bundle the post-processed masks in ./submit into the submission archive.
# NOTE(review): `zip -r` updates an existing archive in place, so entries from
# earlier runs that are no longer in ./submit remain -- delete
# submit_main.zip first if a clean archive is required.
!zip  -r submit_main.zip ./submit

  adding: submit/ (stored 0%)
  adding: submit/PA000219.nii.gz (deflated 93%)
  adding: submit/PA000164.nii.gz (deflated 94%)
  adding: submit/PA000316.nii.gz (deflated 95%)
  adding: submit/PA000105.nii.gz (deflated 93%)
  adding: submit/PA000032.nii.gz (deflated 94%)
  adding: submit/PA000136.nii.gz (deflated 94%)
  adding: submit/PA000122.nii.gz (deflated 93%)
  adding: submit/PA000013.nii.gz (deflated 93%)
  adding: submit/PA000044.nii.gz (deflated 93%)
  adding: submit/PA000312.nii.gz (deflated 95%)
  adding: submit/PA000114.nii.gz (deflated 94%)
  adding: submit/PA000087.nii.gz (deflated 94%)
  adding: submit/PA000126.nii.gz (deflated 94%)
  adding: submit/PA000059.nii.gz (deflated 93%)
  adding: submit/PA000069.nii.gz (deflated 95%)
  adding: submit/PA000051.nii.gz (deflated 94%)
  adding: submit/PA000218.nii.gz (deflated 94%)
  adding: submit/PA000117.nii.gz (deflated 94%)
  adding: submit/PA000269.nii.gz (deflated 93%)
  adding: submit/PA000172.nii.gz (deflated 94%)
  adding: 

### Main Artery Segmentation Ensemble