In [1]:
import nibabel as nib
import numpy as np
from monai.networks.nets import UNet, SegResNet, SegResNetDS, SwinUNETR

import torch
import pandas as pd
import glob
import os

In [None]:
import numpy as np
import torch

# Fixed seed so the dummy volume and the randomly initialised weights are reproducible
# across "Restart Kernel -> Run All".
SEED = 0
rng = np.random.default_rng(SEED)
torch.manual_seed(SEED)

# Two dummy 192^3 volumes with values in {0, 1, 2}, cast to float32 because Conv3d
# layers need floating-point input.
tensor1 = torch.tensor(rng.choice([0, 1, 2], size=(192, 192, 192)), dtype=torch.float32)
tensor2 = torch.tensor(rng.choice([0, 1, 2], size=(192, 192, 192)), dtype=torch.float32)

# Stack the volumes as channels -> (2, D, H, W), then add the batch axis -> (1, 2, D, H, W).
final_tensor = torch.stack((tensor1, tensor2), dim=0)
final_tensor = final_tensor.unsqueeze(0)

print(final_tensor.shape)  # torch.Size([1, 2, 192, 192, 192])
assert final_tensor.shape == (1, 2, 192, 192, 192)

# MONAI SegResNet taking the 2 stacked channels and predicting 3 classes.
n_classes = 3
n_channels = 2
model = SegResNet(
    blocks_down=[1, 2, 2, 4],
    init_filters=16,
    blocks_up=[1, 1, 1],
    in_channels=n_channels,
    out_channels=n_classes,
    dropout_prob=0.2
)
# Disable dropout explicitly so the exported graph does not depend on the exporter's default mode.
model.eval()

# Trace the model on the dummy input and write the ONNX graph to disk.
export_output = torch.onnx.export(model, final_tensor, "model.onnx")


In [1]:
from torch import nn
from torch.utils.data import Dataset as _TorchDataset
from torch.nn import functional as F
from torch.multiprocessing import Manager
from torch.serialization import DEFAULT_PROTOCOL
from torch.utils.data import Dataset as _TorchDataset
from torch.utils.data import Subset
class UpConv(nn.Module):
    """1x1x1 SE-normalised projection followed by trilinear upsampling.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels after the 1x1x1 projection.
        reduction: channel-reduction factor forwarded to ``FastSmoothSeNormConv3d``.
        scale: spatial upsampling factor passed to ``F.interpolate``.
    """

    def __init__(self, in_channels, out_channels, reduction=2, scale=2):
        super().__init__()
        self.scale = scale
        self.conv = FastSmoothSeNormConv3d(
            in_channels,
            out_channels,
            reduction,
            kernel_size=1,
            stride=1,
            padding=0,
        )

    def forward(self, x):
        projected = self.conv(x)
        return F.interpolate(
            projected,
            scale_factor=self.scale,
            mode='trilinear',
            align_corners=False,
        )


# In[29]:


class FastSmoothSeNormConv3d(nn.Module):
    """Conv3d -> in-place ReLU -> FastSmoothSENorm.

    Args:
        in_channels: channels of the incoming volume.
        out_channels: channels produced by the convolution (and normalised afterwards).
        reduction: channel-reduction factor forwarded to ``FastSmoothSENorm``.
        **kwargs: forwarded verbatim to ``nn.Conv3d`` (kernel_size, stride, padding, ...).
    """

    def __init__(self, in_channels, out_channels, reduction=2, **kwargs):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, bias=True, **kwargs)
        self.norm = FastSmoothSENorm(out_channels, reduction)

    def forward(self, x):
        # The ReLU is applied in place on the fresh convolution output, then normalised.
        activated = F.relu(self.conv(x), inplace=True)
        return self.norm(activated)


# In[30]:


class RESseNormConv3d(nn.Module):
    """Residual block: ``conv1(x) + shortcut(x)``.

    ``conv1`` is a ``FastSmoothSeNormConv3d`` built with the caller's ``**kwargs``
    (kernel_size / stride / padding). The shortcut is the identity when the channel
    counts match; otherwise it is a 1x1x1 ``FastSmoothSeNormConv3d`` projection so the
    two branches can be added. The projection always uses stride 1, so callers must pick
    kwargs that keep the spatial size unchanged in ``conv1``.

    Args:
        in_channels: channels of the incoming volume.
        out_channels: channels of the output (both branches).
        reduction: channel-reduction factor forwarded to the SE-norm layers.
        **kwargs: forwarded to the main-branch ``nn.Conv3d``.
    """

    def __init__(self, in_channels, out_channels, reduction=2, **kwargs):
        super().__init__()
        self.conv1 = FastSmoothSeNormConv3d(in_channels, out_channels, reduction, **kwargs)

        if in_channels != out_channels:
            self.res_conv = FastSmoothSeNormConv3d(in_channels, out_channels, reduction, kernel_size=1, stride=1, padding=0)
        else:
            self.res_conv = None

    def forward(self, x):
        # Explicit None check: module truthiness would be False for e.g. an empty nn.Sequential.
        residual = x if self.res_conv is None else self.res_conv(x)
        out = self.conv1(x)
        # Out-of-place add: never mutates a tensor that conv1 may share with its input
        # (which, on the identity shortcut, is also `residual`).
        return out + residual


# In[31]:


class FastSmoothSENormDeepUNet_supervision_skip_no_drop(nn.Module):
    """3D U-Net with SE-normalised residual blocks and deep supervision (no dropout layers).

    The model presented in the paper. This model is one of the multiple models that we tried in our
    experiments, which is why it has such an awkward name.

    Encoder: five resolution levels of ``RESseNormConv3d`` blocks with n_filters, 2x, 4x, 8x and 16x
    channels, separated by 2x max-pooling. Decoder: ``ConvTranspose3d`` upsampling concatenated with
    the matching encoder skip connection, followed by ``FastSmoothSeNormConv3d`` blocks. The outputs of
    decoder levels 4, 3 and 2 are projected to ``n_filters`` channels and upsampled to full resolution
    (``UpConv``). They are then added to the output of the first level-1 decoder block, followed by a
    second block and a 1x1x1 classifier.

    Args:
        in_channels: number of input channels.
        n_cls: number of classes. ``n_cls == 2`` is treated as binary and uses a single output
            channel (``self.n_cls == 1``).
        n_filters: base number of feature maps at full resolution.
        reduction: channel-reduction factor of the SE branches.
        return_logits: if True, ``forward`` returns raw logits instead of probabilities.
    """

    def __init__(self, in_channels, n_cls, n_filters, reduction=2, return_logits=False):
        super(FastSmoothSENormDeepUNet_supervision_skip_no_drop, self).__init__()
        self.in_channels = in_channels
        # Binary segmentation collapses to a single-channel head (sigmoid output, see forward).
        self.n_cls = 1 if n_cls == 2 else n_cls
        self.n_filters = n_filters
        self.return_logits = return_logits

        self.block_1_1_left = RESseNormConv3d(in_channels, n_filters, reduction, kernel_size=7, stride=1, padding=3)
        self.block_1_2_left = RESseNormConv3d(n_filters, n_filters, reduction, kernel_size=3, stride=1, padding=1)

        self.pool_1 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.block_2_1_left = RESseNormConv3d(n_filters, 2 * n_filters, reduction, kernel_size=3, stride=1, padding=1)
        self.block_2_2_left = RESseNormConv3d(2 * n_filters, 2 * n_filters, reduction, kernel_size=3, stride=1, padding=1)
        self.block_2_3_left = RESseNormConv3d(2 * n_filters, 2 * n_filters, reduction, kernel_size=3, stride=1, padding=1)

        self.pool_2 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.block_3_1_left = RESseNormConv3d(2 * n_filters, 4 * n_filters, reduction, kernel_size=3, stride=1, padding=1)
        self.block_3_2_left = RESseNormConv3d(4 * n_filters, 4 * n_filters, reduction, kernel_size=3, stride=1, padding=1)
        self.block_3_3_left = RESseNormConv3d(4 * n_filters, 4 * n_filters, reduction, kernel_size=3, stride=1, padding=1)

        self.pool_3 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.block_4_1_left = RESseNormConv3d(4 * n_filters, 8 * n_filters, reduction, kernel_size=3, stride=1, padding=1)
        self.block_4_2_left = RESseNormConv3d(8 * n_filters, 8 * n_filters, reduction, kernel_size=3, stride=1, padding=1)
        self.block_4_3_left = RESseNormConv3d(8 * n_filters, 8 * n_filters, reduction, kernel_size=3, stride=1, padding=1)

        self.pool_4 = nn.MaxPool3d(kernel_size=2, stride=2)
        self.block_5_1_left = RESseNormConv3d(8 * n_filters, 16 * n_filters, reduction, kernel_size=3, stride=1, padding=1)
        self.block_5_2_left = RESseNormConv3d(16 * n_filters, 16 * n_filters, reduction, kernel_size=3, stride=1, padding=1)
        self.block_5_3_left = RESseNormConv3d(16 * n_filters, 16 * n_filters, reduction, kernel_size=3, stride=1, padding=1)

        # kernel 3 / stride 2 / padding 1 / output_padding 1 doubles each spatial dim exactly.
        self.upconv_4 = nn.ConvTranspose3d(16 * n_filters, 8 * n_filters, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.block_4_1_right = FastSmoothSeNormConv3d((8 + 8) * n_filters, 8 * n_filters, reduction, kernel_size=3, stride=1, padding=1)
        self.block_4_2_right = FastSmoothSeNormConv3d(8 * n_filters, 8 * n_filters, reduction, kernel_size=3, stride=1, padding=1)
        # Level 4 runs at 1/8 resolution, so its deep-supervision branch upsamples by 8.
        self.vision_4 = UpConv(8 * n_filters, n_filters, reduction, scale=8)

        self.upconv_3 = nn.ConvTranspose3d(8 * n_filters, 4 * n_filters, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.block_3_1_right = FastSmoothSeNormConv3d((4 + 4) * n_filters, 4 * n_filters, reduction, kernel_size=3, stride=1, padding=1)
        self.block_3_2_right = FastSmoothSeNormConv3d(4 * n_filters, 4 * n_filters, reduction, kernel_size=3, stride=1, padding=1)
        self.vision_3 = UpConv(4 * n_filters, n_filters, reduction, scale=4)

        self.upconv_2 = nn.ConvTranspose3d(4 * n_filters, 2 * n_filters, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.block_2_1_right = FastSmoothSeNormConv3d((2 + 2) * n_filters, 2 * n_filters, reduction, kernel_size=3, stride=1, padding=1)
        self.block_2_2_right = FastSmoothSeNormConv3d(2 * n_filters, 2 * n_filters, reduction, kernel_size=3, stride=1, padding=1)
        self.vision_2 = UpConv(2 * n_filters, n_filters, reduction, scale=2)

        self.upconv_1 = nn.ConvTranspose3d(2 * n_filters, 1 * n_filters, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.block_1_1_right = FastSmoothSeNormConv3d((1 + 1) * n_filters, n_filters, reduction, kernel_size=3, stride=1, padding=1)
        self.block_1_2_right = FastSmoothSeNormConv3d(n_filters, n_filters, reduction, kernel_size=3, stride=1, padding=1)

        self.conv1x1 = nn.Conv3d(1 * n_filters, self.n_cls, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Segment a batch of volumes.

        Args:
            x: input of shape (N, in_channels, D, H, W). D, H and W must be divisible by 16 so
                that the four 2x poolings are undone exactly and the skip connections line up.

        Returns:
            Tensor of shape (N, self.n_cls, D, H, W). It holds raw logits if ``return_logits`` is
            set. Otherwise it holds sigmoid probabilities for the single-channel (binary) head, or
            softmax probabilities over the class channel.
        """
        # Encoder; ds0..ds3 are kept as skip connections (full, 1/2, 1/4, 1/8 resolution).
        ds0 = self.block_1_2_left(self.block_1_1_left(x))
        ds1 = self.block_2_3_left(self.block_2_2_left(self.block_2_1_left(self.pool_1(ds0))))
        ds2 = self.block_3_3_left(self.block_3_2_left(self.block_3_1_left(self.pool_2(ds1))))
        ds3 = self.block_4_3_left(self.block_4_2_left(self.block_4_1_left(self.pool_3(ds2))))
        x = self.block_5_3_left(self.block_5_2_left(self.block_5_1_left(self.pool_4(ds3))))

        # Decoder with deep-supervision branches sv4/sv3/sv2 upsampled to full resolution.
        x = self.block_4_2_right(self.block_4_1_right(torch.cat([self.upconv_4(x), ds3], 1)))
        sv4 = self.vision_4(x)

        x = self.block_3_2_right(self.block_3_1_right(torch.cat([self.upconv_3(x), ds2], 1)))
        sv3 = self.vision_3(x)

        x = self.block_2_2_right(self.block_2_1_right(torch.cat([self.upconv_2(x), ds1], 1)))
        sv2 = self.vision_2(x)

        x = self.block_1_1_right(torch.cat([self.upconv_1(x), ds0], 1))
        x = x + sv4 + sv3 + sv2
        x = self.block_1_2_right(x)

        x = self.conv1x1(x)

        if self.return_logits:
            return x
        if self.n_cls == 1:
            # Softmax over a single channel is identically 1.0; a binary head needs a sigmoid.
            return torch.sigmoid(x)
        return F.softmax(x, dim=1)


# In[32]:


class FastSmoothSENorm(nn.Module):
    """Instance normalisation with input-dependent per-channel scale and shift.

    ``out = sigmoid(gamma(x)) * InstanceNorm3d(x) + tanh(beta(x))``, where ``gamma`` and
    ``beta`` are independent squeeze-and-excitation branches. Each branch applies a global
    average pool, a 1x1x1 conv, ReLU and a 1x1x1 conv, giving one value per sample and channel.

    Args:
        in_channels: number of channels of the normalised tensor.
        reduction: channel-reduction factor of the SE bottleneck.
    """

    class SEWeights(nn.Module):
        """Squeeze-and-excitation branch mapping (N, C, D, H, W) to (N, C, 1, 1, 1)."""

        def __init__(self, in_channels, reduction=2):
            super().__init__()
            self.conv1 = nn.Conv3d(in_channels, in_channels // reduction, kernel_size=1, stride=1, padding=0, bias=True)
            self.conv2 = nn.Conv3d(in_channels // reduction, in_channels, kernel_size=1, stride=1, padding=0, bias=True)

        def forward(self, x):
            """Return un-activated SE weights of shape (N, C, 1, 1, 1).

            Raises:
                ValueError: if ``x`` is not a 5D tensor.
            """
            if x.dim() != 5:
                raise ValueError(f"expected a 5D (N, C, D, H, W) tensor, got shape {tuple(x.shape)}")
            # Global average pool over D, H, W. Unlike x.view(b, c, -1), this also works on
            # non-contiguous inputs and needs no reshape round-trip.
            out = x.mean(dim=(2, 3, 4), keepdim=True)
            out = F.relu(self.conv1(out))
            return self.conv2(out)

    def __init__(self, in_channels, reduction=2):
        super(FastSmoothSENorm, self).__init__()
        self.norm = nn.InstanceNorm3d(in_channels, affine=False)
        self.gamma = self.SEWeights(in_channels, reduction)
        self.beta = self.SEWeights(in_channels, reduction)

    def forward(self, x):
        # Scale in (0, 1) and shift in (-1, 1), both computed from the un-normalised input.
        gamma = torch.sigmoid(self.gamma(x))
        beta = torch.tanh(self.beta(x))
        x = self.norm(x)
        return gamma * x + beta


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Custom SE-norm U-Net for the 2-channel dummy input; returns 3-class softmax probabilities.
model = FastSmoothSENormDeepUNet_supervision_skip_no_drop(
    in_channels=2,  # two stacked input volumes
    n_cls=3,        # three output classes
    n_filters=4,    # base width; the deepest encoder stage has 16 * 4 = 64 channels
)

  self.weight = Parameter(torch.empty(


In [3]:
# Display the module tree (submodules, channel counts, kernel sizes) as the cell output.
model

FastSmoothSENormDeepUNet_supervision_skip_no_drop(
  (block_1_1_left): RESseNormConv3d(
    (conv1): FastSmoothSeNormConv3d(
      (conv): Conv3d(2, 4, kernel_size=(7, 7, 7), stride=(1, 1, 1), padding=(3, 3, 3))
      (norm): FastSmoothSENorm(
        (norm): InstanceNorm3d(4, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (gamma): SEWeights(
          (conv1): Conv3d(4, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
          (conv2): Conv3d(2, 4, kernel_size=(1, 1, 1), stride=(1, 1, 1))
        )
        (beta): SEWeights(
          (conv1): Conv3d(4, 2, kernel_size=(1, 1, 1), stride=(1, 1, 1))
          (conv2): Conv3d(2, 4, kernel_size=(1, 1, 1), stride=(1, 1, 1))
        )
      )
    )
    (res_conv): FastSmoothSeNormConv3d(
      (conv): Conv3d(2, 4, kernel_size=(1, 1, 1), stride=(1, 1, 1))
      (norm): FastSmoothSENorm(
        (norm): InstanceNorm3d(4, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (gamma): SEWeights(
       

In [5]:
# Fixed seed so the dummy volume is reproducible across "Restart Kernel -> Run All".
rng = np.random.default_rng(0)

# Two dummy 192^3 volumes with values in {0, 1, 2}, cast to float32 because Conv3d
# layers need floating-point input.
tensor1 = torch.tensor(rng.choice([0, 1, 2], size=(192, 192, 192)), dtype=torch.float32)
tensor2 = torch.tensor(rng.choice([0, 1, 2], size=(192, 192, 192)), dtype=torch.float32)

# Stack the volumes as channels -> (2, D, H, W), then add the batch axis -> (1, 2, D, H, W).
final_tensor = torch.stack((tensor1, tensor2), dim=0).unsqueeze(0)

print(final_tensor.shape)  # torch.Size([1, 2, 192, 192, 192])
assert final_tensor.shape == (1, 2, 192, 192, 192)

torch.Size([1, 2, 192, 192, 192])


In [6]:
# Trace the custom U-Net on the (1, 2, 192, 192, 192) dummy input and write the ONNX graph.
export_output = torch.onnx.export(model, args=final_tensor, f="unet.onnx")

