In [12]:
import argparse
import os
from functools import partial

import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.parallel
import torch.utils.data.distributed
from trainer import run_training
from monai.inferers import sliding_window_inference
from monai.losses import DiceCELoss, FocalLoss
from monai.metrics import DiceMetric, MeanIoU
from monai.transforms import Activations, AsDiscrete, Compose
from monai.utils.enums import MetricReduction
from monai.visualize import matshow3d

from utils.myModel import MyModel, MyModel2d, MyModel3dunet, MyFlexibleUNet2d, MyFlexibleUNet2dLSTM, MyBasicUNetPlusPlus
from utils.my_loss import CustomWeightedDiceCELoss

from pathlib import Path
from easydict import EasyDict as edict
TEST_DIR = Path('/root/autodl-tmp/vesuvius-challenge-ink-detection/test')

In [24]:
args = edict(RandFlipd_prob=0.2, RandRotate90d_prob=0.2, RandScaleIntensityd_prob=0.1, RandShiftIntensityd_prob=0.1, a_max=65535.0, a_min=0.0, amp=True, b_max=255.0, b_min=0.0, batch_size=1, cache_rate=1.0, checkpoint=None, data_dir='/root/autodl-tmp/vesuvius-challenge-ink-detection', debug=False, dist_backend='nccl', dist_url='tcp://127.0.0.1:23456', distributed=False, dropout_path_rate=0.0, dropout_rate=0.0, eff='b3', feature_size=48, gpu=0, in_channels=65, infer_overlap=0.5, json_list='/root/autodl-tmp/data_split/data_split.json', logdir='./runs/512_funetlstm_b3_16_sgd_continue', loss_mode='custom', loss_weight=(2.0, 1.0), lrschedule='cosine_anneal', max_epochs=2000, mid=28, model_mode='2dfunetlstm', momentum=0.99, noamp=False, norm_name='instance', normal=False, num_channel=16, num_samples=4, optim_lr=0.0001, optim_name='sgd', out_channels=1, pretrained_dir='./pretrained_models/', pretrained_model_name='512_funetlstm_b3_16_sgd_1000.pt', rank=0, reg_weight=1e-05, resume_ckpt=True, roi_x=512, roi_y=512, roi_z=16, save_checkpoint=True, smooth_dr=1e-06, smooth_nr=0.0, space_x=1.5, space_y=1.5, space_z=1.0, spatial_dims=3, sw_batch_size=4, test_mode=False, threshold=0.4, use_checkpoint=False, use_normal_dataset=False, use_ssl_pretrained=False, val_every=10, warmup_epochs=50, workers=0, world_size=1)

In [9]:
class Sampler(torch.utils.data.Sampler):
    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, make_even=True):
        if num_replicas is None:
            if not torch.distributed.is_available():
                raise RuntimeError(
                    "Requires distributed package to be available")
            num_replicas = torch.distributed.get_world_size()
        if rank is None:
            if not torch.distributed.is_available():
                raise RuntimeError(
                    "Requires distributed package to be available")
            rank = torch.distributed.get_rank()
        self.shuffle = shuffle
        self.make_even = make_even
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(
            math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        indices = list(range(len(self.dataset)))
        self.valid_length = len(
            indices[self.rank: self.total_size: self.num_replicas])

    def __iter__(self):
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))
        if self.make_even:
            if len(indices) < self.total_size:
                if self.total_size - len(indices) < len(indices):
                    indices += indices[: (self.total_size - len(indices))]
                else:
                    extra_ids = np.random.randint(low=0, high=len(
                        indices), size=self.total_size - len(indices))
                    indices += [indices[ids] for ids in extra_ids]
            assert len(indices) == self.total_size
        indices = indices[self.rank: self.total_size: self.num_replicas]
        self.num_samples = len(indices)
        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch

In [29]:
from utils.utils import get_transforms
from monai.data import load_decathlon_datalist
from monai import data, transforms
from utils.my_transform import *

def get_loader(args):
    data_dir = args.data_dir
    datalist_json = os.path.join(data_dir, args.json_list)
    test_transform = transforms.Compose(
        [
            transforms.LoadImaged(
                keys=["image", "label"], reader="NumpyReader"),
            Copyd(keys=["label"],
                    num_channel=args.num_channel),
            transforms.AddChanneld(keys=["image", 'label']),
            # transforms.CropForegroundd(
            #     keys=["image", "label"], source_key="image"),
            # transforms.Orientationd(keys=["image"], axcodes="RAS"),
            # change_channeld(keys=["image", "label", 'inklabels']),
            # transforms.Spacingd(keys="image", pixdim=(args.space_x, args.space_y, args.space_z), mode="bilinear"),
            transforms.ScaleIntensityRanged(
                keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True
            ),
            transforms.ToTensord(keys=["image"]),
        ]
    )
    val_files = load_decathlon_datalist(
        datalist_json, True, "testing", base_dir=data_dir)
    val_ds = data.Dataset(data=val_files, transform=test_transform)
    val_sampler = Sampler(
        val_ds, shuffle=False) if args.distributed else None
    val_loader = data.DataLoader(
        val_ds, batch_size=8, shuffle=False, num_workers=args.workers, sampler=val_sampler, pin_memory=True
    )
    loader = val_loader
    return loader

In [21]:
from utils.utils import resample_3d, resample_2d

def test(model_infer, val_loader, loss_func, args):
    output_directory = Path("./outputs/") / args.exp_name
    output_directory.mkdir(parents=True, exist_ok=True)
    with torch.no_grad():
        dice_list_case = []
        for i, batch in enumerate(val_loader):
            val_inputs, val_labels = (batch["image"].cuda(), batch["label"].cuda())
            print(type(val_labels))
            print(val_labels.shape)
            _, d, h, w = val_labels.shape
            target_shape = (h, w)
            img_name = batch["image_meta_dict"]["filename_or_obj"][0].split("/")[-1]
            print("Inference on case {}".format(img_name))
            val_outputs = model_infer(val_inputs)
            print(val_outputs.shape)
            val_outputs = torch.softmax(val_outputs, 1).cpu()
            val_outputs = np.array(val_outputs)
            val_outputs = np.argmax(val_outputs, axis=1).astype(np.uint8)[0]
            val_labels = val_labels.cpu()
            val_labels = np.array(val_labels)[0, 0, :, :]
            if args.model_mode == "2dswin":
                val_outputs = resample_2d(val_outputs, target_shape)
            elif args.model_mode == "3dswin":
                val_outputs = resample_3d(val_outputs, target_shape)
            else:
                raise ValueError("model_mode should be ['3dswin', '2dswin', '3dunet', '2dunet']")
                
            dice_list_sub = []
            for i in [1]:
                organ_Dice = loss_func(val_outputs == i, val_labels == i)
                dice_list_sub.append(organ_Dice)
            mean_dice = np.mean(dice_list_sub)
            print("Mean Organ Dice: {}".format(mean_dice))
            dice_list_case.append(mean_dice)
            np.save(
                os.path.join(output_directory, img_name), val_outputs[:,:]
            )
        mean_loss = np.mean(dice_list_case)
        print("Overall Mean Dice: {}".format(mean_loss))
    return mean_loss

In [30]:
np.set_printoptions(formatter={"float": "{: 0.3f}".format}, suppress=True)

torch.cuda.set_device(0)
torch.backends.cudnn.benchmark = True
args.test_mode = True
loader = get_loader(args)

inf_size = [args.roi_x, args.roi_y, args.roi_z]

pretrained_dir = args.pretrained_dir
if args.model_mode == "3dswin":
    model = MyModel(img_size=(args.roi_x,args.roi_y,args.roi_y))
elif args.model_mode == "2dswin":
    model = MyModel2d(img_size=(args.roi_x,args.roi_y))
elif args.model_mode == "3dunet":
    model = MyModel3dunet()
elif args.model_mode == "2dfunet":
    model = MyFlexibleUNet2d(args)
elif args.model_mode == "2dfunetlstm":
    model = MyFlexibleUNet2dLSTM(args)
elif args.model_mode == "3dunet++":
    model = MyBasicUNetPlusPlus(args)
else:
    raise ValueError("model mode error")


model_dict = torch.load(os.path.join(pretrained_dir, args.pretrained_model_name))["state_dict"]
if args.model_mode in ["2dswin", "3dunet", "2dfunet", "2dfunetlstm", "3dunet++"]:
    model.load_state_dict(model_dict)
elif args.model_mode == "3dswin":
    model.load_swin_ckpt(model_dict)
else:
    raise ValueError("model mode error")

if args.loss_mode == 'focalLoss':
    loss = FocalLoss(weight=[10.0])
elif args.loss_mode == 'squared_dice':
    loss = DiceCELoss(squared_pred=True, smooth_nr=args.smooth_nr, smooth_dr=args.smooth_dr)
elif args.loss_mode == 'DiceCELoss':
    loss = DiceCELoss(include_background=True, sigmoid=True, ce_weight=torch.Tensor([ 10])) # Normally
elif args.loss_mode == 'custom':
    loss = CustomWeightedDiceCELoss(ink_weight=3.0, weight=args.loss_weight)
    

post_label = AsDiscrete(to_onehot=args.out_channels)
post_pred = AsDiscrete(argmax=True, to_onehot=args.out_channels)
dice_acc = DiceMetric(include_background=False, reduction=MetricReduction.MEAN, get_not_nans=True)
miou_acc = MeanIoU(include_background=False, reduction=MetricReduction.MEAN, get_not_nans=True)
# f_beta_acc = FBetaScore()

if args.model_mode in ["3dswin", "3dunet", "3dunet++"]:
    model_inferer = partial(
        sliding_window_inference,
        roi_size = (args.roi_x,args.roi_y,args.roi_z),
        sw_batch_size = 8,
        predictor = model,
        overlap = 0.5,
        progress = True,
        padding_mode = "reflect", 
        device = "cpu", 
        sw_device = "cuda"
    )
elif args.model_mode in ["2dswin", "2dfunet", "2dfunetlstm"]:
    model_inferer = partial(
        sliding_window_inference,
        roi_size = (args.roi_x,args.roi_y),
        sw_batch_size = 8,
        predictor = model,
        overlap = 0.5,
        progress = True,
        padding_mode = "reflect", 
        device = "cpu", 
        sw_device = "cuda"
    )     
else:
    raise ValueError("model mode error")
    
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Total parameters count", pytorch_total_params)

best_acc = 0
start_epoch = 0

model.cuda(0)

print("Lodaer test")
for i in loader:
    print(i['image'].shape)
    print(i['label'].shape)
    print(torch.unique(i['image']))
    print(torch.unique(i['label']))
    break
print("Pass Test")
print(args)
test(model_inferer, loader, loss, args)


Total parameters count 8160661
Lodaer test
2023-04-24 12:39:09,082 - > collate dict key "image" out of 4 keys
2023-04-24 12:39:09,088 - >> collate/stack a list of tensors
2023-04-24 12:39:09,094 - >> E: stack expects each tensor to be equal size, but got [1, 65, 2727, 6330] at entry 0 and [1, 65, 5454, 6330] at entry 1, shape [(1, 65, 2727, 6330), (1, 65, 5454, 6330)] in collate([tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 

RuntimeError: stack expects each tensor to be equal size, but got [1, 65, 2727, 6330] at entry 0 and [1, 65, 5454, 6330] at entry 1
Collate error on the key 'image' of dictionary data.

MONAI hint: if your transforms intentionally create images of different shapes, creating your `DataLoader` with `collate_fn=pad_list_data_collate` might solve this problem (check its documentation).