In [12]:
import argparse
import os
from functools import partial

import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.parallel
import torch.utils.data.distributed
from trainer import run_training
from monai.inferers import sliding_window_inference
from monai.losses import DiceCELoss, FocalLoss
from monai.metrics import DiceMetric, MeanIoU
from monai.transforms import Activations, AsDiscrete, Compose
from monai.utils.enums import MetricReduction
from monai.visualize import matshow3d

# from utils.myModel import MyModel, MyModel2d, MyModel3dunet, MyFlexibleUNet2d, MyFlexibleUNet2dLSTM, MyBasicUNetPlusPlus

from pathlib import Path
from easydict import EasyDict as edict
TEST_DIR = Path('/root/autodl-tmp/vesuvius-challenge-ink-detection/test')

In [24]:
# args = edict(RandFlipd_prob=0.2, RandRotate90d_prob=0.2, RandScaleIntensityd_prob=0.1, RandShiftIntensityd_prob=0.1, a_max=65535.0, a_min=0.0, amp=True, b_max=255.0, b_min=0.0, batch_size=1, cache_rate=1.0, checkpoint=None, data_dir='/root/autodl-tmp/vesuvius-challenge-ink-detection', debug=False, dist_backend='nccl', dist_url='tcp://127.0.0.1:23456', distributed=False, dropout_path_rate=0.0, dropout_rate=0.0, eff='b3', feature_size=48, gpu=0, in_channels=65, infer_overlap=0.5, json_list='/root/autodl-tmp/data_split/data_split.json', logdir='./runs/512_funetlstm_b3_16_sgd_continue', loss_mode='custom', loss_weight=(2.0, 1.0), lrschedule='cosine_anneal', max_epochs=2000, mid=28, model_mode='2dfunetlstm', momentum=0.99, noamp=False, norm_name='instance', normal=False, num_channel=16, num_samples=4, optim_lr=0.0001, optim_name='sgd', out_channels=1, pretrained_dir='./pretrained_models/', pretrained_model_name='512_funetlstm_b3_16_sgd_1000.pt', rank=0, reg_weight=1e-05, resume_ckpt=True, roi_x=512, roi_y=512, roi_z=16, save_checkpoint=True, smooth_dr=1e-06, smooth_nr=0.0, space_x=1.5, space_y=1.5, space_z=1.0, spatial_dims=3, sw_batch_size=4, test_mode=False, threshold=0.4, use_checkpoint=False, use_normal_dataset=False, use_ssl_pretrained=False, val_every=10, warmup_epochs=50, workers=0, world_size=1)
args = edict(RandFlipd_prob=0.2, RandRotate90d_prob=0.2, RandScaleIntensityd_prob=0.1, RandShiftIntensityd_prob=0.1, a_max=65535.0, a_min=0.0, amp=True, b_max=255.0, b_min=0.0, 
             batch_size=1, cache_rate=1.0, checkpoint=None, data_dir='/root/autodl-tmp/vesuvius-challenge-ink-detection', debug=False, dist_backend='nccl', dist_url='tcp://127.0.0.1:23456', 
             distributed=False, dropout_path_rate=0.0, dropout_rate=0.0, eff='b3', feature_size=48, gpu=0, in_channels=65, infer_overlap=0.5, json_list='/root/autodl-tmp/data_split/data_split.json', 
             logdir='./runs/512_funetlstm_b3_16_sgd_continue_1400', loss_mode='custom', loss_weight=(2.0, 1.0), lrschedule='cosine_anneal', max_epochs=1000, mid=26, model_mode='2dfunetlstm', 
             momentum=0.99, noamp=False, norm_name='instance', normal=False, num_channel=16, num_samples=4, optim_lr=0.0005, optim_name='adamw', out_channels=1, pretrained_dir='./pretrained_models/', 
             pretrained_model_name='512_funetlstm_b3_16_sgd_1400.pt', rank=0, reg_weight=1e-05, resume_ckpt=True, roi_x=512, roi_y=512, roi_z=16, save_checkpoint=True, smooth_dr=1e-06, smooth_nr=0.0, 
             space_x=1.5, space_y=1.5, space_z=1.0, spatial_dims=3, sw_batch_size=4, test_mode=False, threshold=0.4, use_checkpoint=False, use_normal_dataset=False, use_ssl_pretrained=False, 
             val_every=10, warmup_epochs=50, workers=0, world_size=1)


# Model
- MyModel
- MyModel2d
- MyModel3dunet
- MyFlexibleUNet2d
- MyFlexibleUNet2dLSTM 
- MyBasicUNetPlusPlus

In [None]:
import torch.nn as nn
from monai.networks.nets import SwinUNETR, UNet, FlexibleUNet, BasicUNetPlusPlus
from monai.networks.blocks.convolutions import Convolution

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = SwinUNETR(
            img_size=(96,96,96),
            in_channels=1,
            out_channels=14,
            feature_size=48,
            drop_rate=0.0,
            attn_drop_rate=0.0,
            dropout_path_rate=0.0,
            use_checkpoint=True,
        )
        # self.conv1 = Convolution(spatial_dims=3, in_channels=14, out_channels=1, kernel_size=1)
        self.conv2 = Convolution(spatial_dims=3, in_channels=1, out_channels=1, kernel_size=(1, 1, 64), strides=1, padding=0, act="sigmoid")

    
    def forward(self, x):
        if x[0].size() != (1, 64, 64, 64):
            print(x.size())
            raise ValueError("Input size is not correct")
        x_out = self.swinUNETR(x)
        # x_out = self.conv1(x_out)
        x_out = self.conv2(x_out)
        return x_out
    
    def load_swin_ckpt(self, model_dict, strict: bool = True):
        self.swinUNETR.load_state_dict(model_dict, strict)
        pass
    
class MyModel2d(nn.Module):
    def __init__(self,img_size=(192, 192)):
        super().__init__()
        self.swinUNETR = SwinUNETR(
                                img_size=img_size,
                                in_channels=65,
                                out_channels=1,
                                feature_size=12,
                                use_checkpoint=True, 
                                spatial_dims=2
                                )

    
    def forward(self, x):
        x_out = self.swinUNETR(x)
        return x_out
    
    def load_swin_ckpt(self, model_dict, strict: bool = True):
        self.swinUNETR.load_state_dict(model_dict, strict)
        pass
    
class MyModel3dunet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.unet = UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=1,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        )
        self.conv1 = Convolution(spatial_dims=3, in_channels=1, out_channels=1, kernel_size=(1, 1, 64), strides=1, padding=0, act="sigmoid")
    
    def forward(self, x):
        x_out = self.unet(x)
        x_out = self.conv1(x_out)
        return x_out
    

class MyFlexibleUNet2d(nn.Module):
    def __init__(self, args) -> None:
        super().__init__()
        self.flexibleUNet = FlexibleUNet(
                in_channels=args.num_channel,
                out_channels=1,
                backbone=f"efficientnet-{args.eff}",
                pretrained=True,
                spatial_dims=2,
                dropout=0.0,
            )
        self.sig = nn.Sigmoid()
        
    def forward(self, x):
        x_out = self.flexibleUNet(x)
        x_out = self.sig(x_out)
        return x_out
        
        

class ConvLSTM(nn.Module):
    def __init__(self, in_channels=320, out_channels=320, kernel_size=1, padding=0, batch_first=True):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)
        self.lstm = nn.LSTM(256, 256, batch_first=batch_first)

    def forward(self, x):

        last_feature = x[-1]

        # print("Before modification:")
        # print(x[-1][0, 0, 0, :2])  # print a small part of the tensor
        # print(last_feature.shape)
        batch_size, channels, height, width = last_feature.shape

        # Apply 2D convolution
        last_feature = self.conv(last_feature)

        # Reshape output for LSTM
        last_feature = last_feature.view(batch_size, -1, height * width)

        # Pass through LSTM
        lstm_out, _ = self.lstm(last_feature)

        # Reshape output back to original shape
        last_feature = lstm_out.view(batch_size, channels, height, width)

        x[-1] = last_feature

        # print("After modification:")
        # print(x[-1][0, 0, 0, :2])  # print a small part of the tensor

        return x


class MyFlexibleUNet2dLSTM(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.flexibleUNet = FlexibleUNet(
            in_channels=args.num_channel,
            out_channels=1,
            backbone="efficientnet-b0",
            pretrained=True,
            spatial_dims=2,
            dropout=0.0,
        )
        # Add ConvLSTM layer after the last convolution layer in the encoder
        assert args.roi_x == args.roi_y, "ROI x and y must be the same"
        self.conv_lstm = ConvLSTM()
        self.sig = nn.Sigmoid()

    def forward(self, x):
        x_out = self.flexibleUNet.encoder(x)
        x_out = self.conv_lstm(x_out)
        x_out = self.flexibleUNet.decoder(x_out)
        x_out = self.flexibleUNet.segmentation_head(x_out)
        x_out = self.sig(x_out)
        return x_out
    
class MyBasicUNetPlusPlus(nn.Module):
    def __init__(self, args) -> None:
        super().__init__()
        self.basicUNetPlusPlus = BasicUNetPlusPlus(in_channels=1, out_channels=1)
        
    def forward(self, x):
        x_out = self.basicUNetPlusPlus(x)
        return x_out

# utils

- def resample_3d
- def resample_2d
- class Sampler

In [None]:
def resample_3d(img, target_size):
    imx, imy, imz = img.shape
    tx, ty, tz = target_size
    zoom_ratio = (float(tx) / float(imx), float(ty) / float(imy), float(tz) / float(imz))
    img_resampled = ndimage.zoom(img, zoom_ratio, order=0, prefilter=False)
    return img_resampled

def resample_2d(img, target_size):
    imx, imy = img.shape
    tx, ty = target_size
    zoom_ratio = (float(tx) / float(imx), float(ty) / float(imy))
    img_resampled = ndimage.zoom(img, zoom_ratio, order=0, prefilter=False)
    return img_resampled

class Sampler(torch.utils.data.Sampler):
    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, make_even=True):
        if num_replicas is None:
            if not torch.distributed.is_available():
                raise RuntimeError(
                    "Requires distributed package to be available")
            num_replicas = torch.distributed.get_world_size()
        if rank is None:
            if not torch.distributed.is_available():
                raise RuntimeError(
                    "Requires distributed package to be available")
            rank = torch.distributed.get_rank()
        self.shuffle = shuffle
        self.make_even = make_even
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(
            math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        indices = list(range(len(self.dataset)))
        self.valid_length = len(
            indices[self.rank: self.total_size: self.num_replicas])

    def __iter__(self):
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))
        if self.make_even:
            if len(indices) < self.total_size:
                if self.total_size - len(indices) < len(indices):
                    indices += indices[: (self.total_size - len(indices))]
                else:
                    extra_ids = np.random.randint(low=0, high=len(
                        indices), size=self.total_size - len(indices))
                    indices += [indices[ids] for ids in extra_ids]
            assert len(indices) == self.total_size
        indices = indices[self.rank: self.total_size: self.num_replicas]
        self.num_samples = len(indices)
        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch

# Custom Transform

In [None]:
class Copy(Transform):
    def __init__(self, num_channel, add_channel=False):
        self.num_channel = num_channel
        self.add_channel = add_channel

    def __call__(self, data):
        if self.add_channel:
            data = data.repeat(1, self.num_channel, 1, 1)  # output = (batch_size=1, num_channel, H, W)
        else:
            data = data.repeat(self.num_channel, 1, 1)  # output = (batch_size=1, num_channel, H, W)
        return data
    
class Copyd(MapTransform):
    """
    Dictionary-based wrapper of :py:class:`monai.transforms.AddChannel`.
    """
    def __init__(self, keys: KeysCollection, num_channel, add_channel=False) -> None:
        """
        Args:
            keys: keys of the corresponding items to be transformed.
                See also: :py:class:`monai.transforms.compose.MapTransform`
            allow_missing_keys: don't raise exception if key is missing.
        """
        super().__init__(keys, )
        self.adder = Copy(num_channel, add_channel)

    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
        d = dict(data)
        for key in self.key_iterator(d):
            d[key] = self.adder(d[key])
        return d

In [29]:
from monai.data import load_decathlon_datalist
from monai import data, transforms

def get_loader(args):
    data_dir = args.data_dir
    datalist_json = os.path.join(data_dir, args.json_list)
    test_transform = transforms.Compose(
        [
            transforms.LoadImaged(
                keys=["image", "label"], reader="NumpyReader"),
            Copyd(keys=["label"],
                    num_channel=args.num_channel),
            transforms.AddChanneld(keys=["image", 'label']),
            # transforms.CropForegroundd(
            #     keys=["image", "label"], source_key="image"),
            # transforms.Orientationd(keys=["image"], axcodes="RAS"),
            # change_channeld(keys=["image", "label", 'inklabels']),
            # transforms.Spacingd(keys="image", pixdim=(args.space_x, args.space_y, args.space_z), mode="bilinear"),
            transforms.ScaleIntensityRanged(
                keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True
            ),
            transforms.ToTensord(keys=["image"]),
        ]
    )
    val_files = load_decathlon_datalist(
        datalist_json, True, "testing", base_dir=data_dir)
    val_ds = data.Dataset(data=val_files, transform=test_transform)
    val_sampler = Sampler(
        val_ds, shuffle=False) if args.distributed else None
    val_loader = data.DataLoader(
        val_ds, batch_size=8, shuffle=False, num_workers=args.workers, sampler=val_sampler, pin_memory=True
    )
    loader = val_loader
    return loader

# Test Function

In [21]:
# from utils.utils import resample_3d, resample_2d

def test(model_infer, val_loader, args):
    output_directory = Path("/kaggle/working/") / args.exp_name
    output_directory.mkdir(parents=True, exist_ok=True)
    with torch.no_grad():
        dice_list_case = []
        for i, batch in enumerate(val_loader):
            val_inputs, val_labels = (batch["image"].cuda(), batch["label"].cuda())
            print(type(val_labels))
            print(val_labels.shape)
            _, d, h, w = val_labels.shape
            target_shape = (h, w)
            img_name = batch["image_meta_dict"]["filename_or_obj"][0].split("/")[-1]
            print("Inference on case {}".format(img_name))
            val_outputs = model_infer(val_inputs)
            print(val_outputs.shape)
            val_outputs = torch.softmax(val_outputs, 1).cpu()
            val_outputs = np.array(val_outputs)
            val_outputs = np.argmax(val_outputs, axis=1).astype(np.uint8)[0]
            val_labels = val_labels.cpu()
            val_labels = np.array(val_labels)[0, 0, :, :]
            if args.model_mode in ["2dswin", "2dfunetlstm"]:
                val_outputs = resample_2d(val_outputs, target_shape)
            elif args.model_mode == "3dswin":
                val_outputs = resample_3d(val_outputs, target_shape)
            else:
                raise ValueError("model_mode should be ['3dswin', '2dswin', '3dunet', '2dunet']")

            np.save(
                os.path.join(output_directory, img_name), val_outputs[:,:]
            )
    return None

In [30]:
np.set_printoptions(formatter={"float": "{: 0.3f}".format}, suppress=True)

torch.cuda.set_device(0)
torch.backends.cudnn.benchmark = True
args.test_mode = True
loader = get_loader(args)

inf_size = [args.roi_x, args.roi_y, args.roi_z]

pretrained_dir = args.pretrained_dir
if args.model_mode == "3dswin":
    model = MyModel(img_size=(args.roi_x,args.roi_y,args.roi_y))
elif args.model_mode == "2dswin":
    model = MyModel2d(img_size=(args.roi_x,args.roi_y))
elif args.model_mode == "3dunet":
    model = MyModel3dunet()
elif args.model_mode == "2dfunet":
    model = MyFlexibleUNet2d(args)
elif args.model_mode == "2dfunetlstm":
    model = MyFlexibleUNet2dLSTM(args)
elif args.model_mode == "3dunet++":
    model = MyBasicUNetPlusPlus(args)
else:
    raise ValueError("model mode error")


model_dict = torch.load(os.path.join(pretrained_dir, args.pretrained_model_name))["state_dict"]
if args.model_mode in ["2dswin", "3dunet", "2dfunet", "2dfunetlstm", "3dunet++"]:
    model.load_state_dict(model_dict)
elif args.model_mode == "3dswin":
    model.load_swin_ckpt(model_dict)
else:
    raise ValueError("model mode error")

if args.model_mode in ["3dswin", "3dunet", "3dunet++"]:
    model_inferer = partial(
        sliding_window_inference,
        roi_size = (args.roi_x,args.roi_y,args.roi_z),
        sw_batch_size = 8,
        predictor = model,
        overlap = 0.5,
        progress = True,
        padding_mode = "reflect", 
        device = "cpu", 
        sw_device = "cuda"
    )
elif args.model_mode in ["2dswin", "2dfunet", "2dfunetlstm"]:
    model_inferer = partial(
        sliding_window_inference,
        roi_size = (args.roi_x,args.roi_y),
        sw_batch_size = 8,
        predictor = model,
        overlap = 0.5,
        progress = True,
        padding_mode = "reflect", 
        device = "cpu", 
        sw_device = "cuda"
    )     
else:
    raise ValueError("model mode error")

model.cuda(0)

print(args)
test(model_inferer, loader, args)


Total parameters count 8160661

Lodaer test

2023-04-24 12:39:09,082 - > collate dict key "image" out of 4 keys

2023-04-24 12:39:09,088 - >> collate/stack a list of tensors

2023-04-24 12:39:09,094 - >> E: stack expects each tensor to be equal size, but got [1, 65, 2727, 6330] at entry 0 and [1, 65, 5454, 6330] at entry 1, shape [(1, 65, 2727, 6330), (1, 65, 5454, 6330)] in collate([tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],

          [0., 0., 0.,  ..., 0., 0., 0.],

          [0., 0., 0.,  ..., 0., 0., 0.],

          ...,

          [0., 0., 0.,  ..., 0., 0., 0.],

          [0., 0., 0.,  ..., 0., 0., 0.],

          [0., 0., 0.,  ..., 0., 0., 0.]],



         [[0., 0., 0.,  ..., 0., 0., 0.],

          [0., 0., 0.,  ..., 0., 0., 0.],

          [0., 0., 0.,  ..., 0., 0., 0.],

          ...,

          [0., 0., 0.,  ..., 0., 0., 0.],

          [0., 0., 0.,  ..., 0., 0., 0.],

          [0., 0., 0.,  ..., 0., 0., 0.]],



         [[0., 0., 0.,  ..., 0., 0., 0.],

          [0., 0

RuntimeError: stack expects each tensor to be equal size, but got [1, 65, 2727, 6330] at entry 0 and [1, 65, 5454, 6330] at entry 1
Collate error on the key 'image' of dictionary data.

MONAI hint: if your transforms intentionally create images of different shapes, creating your `DataLoader` with `collate_fn=pad_list_data_collate` might solve this problem (check its documentation).