In [1]:
import logging
import os
import sys
import tempfile
from glob import glob

import nibabel as nib
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

import monai
from monai.data import ImageDataset, create_test_image_3d, decollate_batch, DataLoader
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.transforms import (
    LoadImaged,
    Activations,
    EnsureChannelFirst,
    AsDiscrete,
    Compose,
    RandRotate90,
    RandSpatialCrop,
    ScaleIntensity,
    Spacingd,
    EnsureChannelFirstd,
    EnsureTyped,
    Orientationd
)
from monai.transforms import Randomizable
from monai.data import CacheDataset
from monai.visualize import plot_2d_or_3d_image
from monai.transforms import Randomizable
from monai.data import CacheDataset

In [2]:
dataset_dir = "/data/imagecas"
img_dir = os.path.join(dataset_dir, "images")

In [3]:
images = glob(os.path.join(img_dir, "*.img.nii.gz"))
segs = []
for img_path in images:
    img_basename = os.path.basename(img_path)
    label_basename = img_basename.replace("img", "label")
    segs.append(os.path.join(img_dir, label_basename))

In [4]:
class ImageCASDataset(Randomizable, CacheDataset):
    resource = None
    md5 = None

    def __init__(
        self,
        root_dir,
        section,
        transform,
        download=False,
        seed=0,
        val_frac=0.2,
        test_frac=0.2,
        cache_num=sys.maxsize,
        cache_rate=1.0,
        num_workers=0,
    ):
        if not os.path.isdir(root_dir):
            raise ValueError("Root directory root_dir must be a directory.")
        self.section = section
        self.val_frac = val_frac
        self.test_frac = test_frac
        self.set_random_state(seed=seed)
        dataset_dir = os.path.join(root_dir, "imagecas")
        split_filename = "imageCAS_data_split.xlsx"
        tarfile_name = f"{dataset_dir}.tar"
        if download:
            raise ValueError("Download the dataset manually.")
            
        img_dir = os.path.join(dataset_dir, "images")
        img_list = glob(os.path.join(img_dir, "*.img.nii.gz"))
        self.datalist = []
        for img_path in img_list:
            img_basename = os.path.basename(img_path)
            label_basename = img_basename.replace("img", "label")
            self.datalist.append({"image": img_path, "label": os.path.join(img_dir, label_basename)})
        
        data = self._generate_data_list()
        super().__init__(
            data,
            transform,
            cache_num=cache_num,
            cache_rate=cache_rate,
            num_workers=num_workers,
        )

    def randomize(self, data=None):
        self.rann = self.R.random()

    def _generate_data_list(self):
        data = []
        for d in self.datalist:
            self.randomize()
            if self.section == "training":
                if self.rann < self.val_frac + self.test_frac:
                    continue
            elif self.section == "validation":
                if self.rann >= self.val_frac:
                    continue
            elif self.section == "test":
                if self.rann < self.val_frac or self.rann >= self.val_frac + self.test_frac:
                    continue
            else:
                raise ValueError(
                    f"Unsupported section: {self.section}, " "available options are ['training', 'validation', 'test']."
                )
            data.append(d)
        return data

In [6]:
train_transform = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        EnsureTyped(keys=["image", "label"]),
        Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")),
    ]
)
val_transform = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        EnsureTyped(keys=["image", "label"]),
        Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")),

    ]
)

In [12]:
train_ds = ImageCASDataset(
    root_dir="/data/", 
    section="training", 
    transform=train_transform,
    cache_rate=0.0
)
train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=4)
val_ds = ImageCASDataset(
    root_dir="/data/", 
    section="validation", 
    transform=val_transform,
    cache_rate=0.0
)
val_loader = DataLoader(val_ds, batch_size=4, shuffle=False, num_workers=4)

In [13]:
sample = next(iter(train_loader))
print(sample['image'].shape, sample['label'].shape)

> collate dict key "image" out of 2 keys
>> collate/stack a list of tensors
>> E: stack expects each tensor to be equal size, but got [1, 171, 171, 109] at entry 0 and [1, 164, 164, 138] at entry 1, shape [torch.Size([1, 171, 171, 109]), torch.Size([1, 164, 164, 138]), torch.Size([1, 193, 193, 104]), torch.Size([1, 182, 182, 138])] in collate([metatensor([[[[   72.0000,    68.0000,    60.0000,  ...,  -651.0000,
            -596.0000,  -672.0000],
          [   28.9294,    80.2235,    78.0118,  ...,  -730.9529,
            -739.1765,  -660.0236],
          [   18.1882,   118.6706,    70.9529,  ...,  -737.3530,
            -726.8823,  -706.3765],
          ...,
          [ -119.2118,  -110.9294,   -77.2118,  ...,    52.4118,
              38.3176,    27.4941],
          [  -97.0000,  -100.0588,   -97.8706,  ...,  -146.9176,
            -263.9765,  -463.6588],
          [ -108.0000,   -95.0000,  -109.0000,  ...,  -556.0000,
            -677.0000,  -833.0000]],

         [[   63.8706,    6

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/whikwon/.local/lib/python3.10/site-packages/monai/data/utils.py", line 493, in list_data_collate
    ret[key] = collate_meta_tensor(data_for_batch)
  File "/home/whikwon/.local/lib/python3.10/site-packages/monai/data/utils.py", line 454, in collate_meta_tensor
    collated = default_collate(batch)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 277, in default_collate
    return collate(batch, collate_fn_map=default_collate_fn_map)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 125, in collate
    return collate_fn_map[collate_type](batch, collate_fn_map=collate_fn_map)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 173, in collate_tensor_fn
    out = elem.new(storage).resize_(len(batch), *list(elem.size()))
  File "/home/whikwon/.local/lib/python3.10/site-packages/monai/data/meta_tensor.py", line 282, in __torch_function__
    ret = super().__torch_function__(func, types, args, kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/_tensor.py", line 1418, in __torch_function__
    ret = func(*args, **kwargs)
RuntimeError: Trying to resize storage that is not resizable

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
  File "/home/whikwon/.local/lib/python3.10/site-packages/monai/data/utils.py", line 508, in list_data_collate
    raise RuntimeError(re_str) from re
RuntimeError: Trying to resize storage that is not resizable


> collate dict key "image" out of 2 keys
>> collate/stack a list of tensors
>> E: stack expects each tensor to be equal size, but got [1, 164, 164, 138] at entry 0 and [1, 176, 176, 130] at entry 1, shape [torch.Size([1, 164, 164, 138]), torch.Size([1, 176, 176, 130]), torch.Size([1, 164, 164, 116]), torch.Size([1, 164, 164, 123])] in collate([metatensor([[[[   77.0000,    60.0000,    35.0000,  ...,  -520.0000,
            -569.0000,  -749.0000],
          [    9.8957,    49.7055,    42.6933,  ...,  -590.7607,
            -637.1043,  -777.4356],
          [  -32.8712,    59.5890,    45.7669,  ...,  -650.3558,
            -649.6626,  -740.6381],
          ...,
          [ -989.7423,  -991.8221,  -982.6687,  ..., -1007.8466,
            -988.2822, -1003.7178],
          [-1003.1411,  -976.2822,  -980.4356,  ..., -1011.2945,
            -983.4233, -1014.1534],
          [ -994.0000,  -985.0000,  -987.0000,  ..., -1017.0000,
            -992.0000, -1019.0000]],

         [[   93.5890,    6

In [53]:
# create a training data loader
# train_ds = ImageDataset(images, segs, transform=train_imtrans, seg_transform=train_segtrans)
# train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=8, pin_memory=torch.cuda.is_available())
# # create a validation data loader
# val_ds = ImageDataset(images[-20:], segs[-20:], transform=val_imtrans, seg_transform=val_segtrans)
# val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, pin_memory=torch.cuda.is_available())
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

In [54]:
cd ..

/workspaces/SegMamba


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [55]:
import torch 
from model_segmamba.segmamba import SegMamba

In [56]:
# create UNet, DiceLoss and Adam optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = monai.networks.nets.UNet(
#     spatial_dims=3,
#     in_channels=1,
#     out_channels=1,
#     channels=(16, 32, 64, 128, 256),
#     strides=(2, 2, 2, 2),
#     num_res_units=2,
# ).to(device)
model = SegMamba(in_chans=1,
                 out_chans=1,
                 depths=[2, 2, 2, 2],
                 feat_size=[48, 96, 192, 384]).to(device)
loss_function = monai.losses.DiceLoss(sigmoid=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-3)

In [57]:
sample = next(iter(train_loader))

In [18]:
# start a typical PyTorch training
val_interval = 2
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = list()
metric_values = list()
writer = SummaryWriter()
for epoch in range(5):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{5}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        epoch_len = len(train_ds) // train_loader.batch_size
        print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
        writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            val_images = None
            val_labels = None
            val_outputs = None
            for val_data in val_loader:
                val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                roi_size = (96, 96, 96)
                sw_batch_size = 4
                # resize가 안 들어가고, memory 이슈 때문에 sliding 방식으로 inference 하고 합쳐준다.
                val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
                # compute metric for current iteration
                dice_metric(y_pred=val_outputs, y=val_labels)
            # aggregate the final mean dice result
            metric = dice_metric.aggregate().item()
            # reset the status for next validation round
            dice_metric.reset()
            metric_values.append(metric)
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), "best_metric_model_segmentation3d_array.pth")
                print("saved new best metric model")
            print(
                "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                    epoch + 1, metric, best_metric, best_metric_epoch
                )
            )
            writer.add_scalar("val_mean_dice", metric, epoch + 1)
            # plot the last model output as GIF image in TensorBoard with the corresponding image and label
            plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
            plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
            plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")

print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
writer.close()

----------
epoch 1/5
1/5, train_loss: 0.9996
2/5, train_loss: 1.0000
3/5, train_loss: 0.9977
4/5, train_loss: 0.9952
5/5, train_loss: 0.9821
epoch 1 average loss: 0.9949
----------
epoch 2/5
1/5, train_loss: 0.9982
2/5, train_loss: 0.9868
3/5, train_loss: 0.9815
4/5, train_loss: 0.9914
5/5, train_loss: 0.9922
epoch 2 average loss: 0.9900


KeyboardInterrupt: 