In [2]:
import numpy as np
import torch.nn as nn
import torch.optim as optim
# from torch.utils.data import Dataset, DataLoader
from monai.data import DataLoader, Dataset

import sys
# import itk

import SimpleITK as sitk
import time

import torch
# import segmentation_models_pytorch as smp
from torch.nn.functional import one_hot

from tqdm import tqdm
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchviz import make_dot

In [2]:
import torch.nn.functional as F
class UNet3D(nn.Module):
    """Three-level 3-D U-Net that outputs per-voxel class probabilities.

    Architecture:
        * three encoder stages, each two (Conv3d 3x3x3 -> BatchNorm3d -> ReLU)
          layers, separated by 2x max-pooling;
        * a bottleneck stage built like an encoder stage;
        * three decoder stages, each preceded by a 2x transposed-convolution
          upsampling whose output is concatenated (skip connection first)
          with the encoder features of the same level;
        * a 1x1x1 output convolution followed by a softmax over channels.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output classes; softmax is taken over this axis.
        level_channels: Feature widths of the three encoder levels. Only the
            first three entries are read.
        bottleneck_channel: Feature width of the bottleneck stage.

    Input shape:
        ``(N, in_channels, S1, S2, S3)``. Because three 2x poolings are undone
        by three 2x upsamplings whose outputs are concatenated with the encoder
        features, every spatial size must be divisible by 8; otherwise the
        ``torch.cat`` in ``forward`` fails on a size mismatch.

    Output shape:
        ``(N, out_channels, S2, S3, S1)`` -- probabilities summing to 1 over
        dim 1, with the first spatial axis moved to the last position.
        NOTE(review): the permutation presumably matches a (H, W, D) label
        layout used elsewhere -- confirm against the data pipeline.
    """

    def __init__(self, in_channels=1, out_channels=3, level_channels=(64, 128, 256), bottleneck_channel=512):
        super().__init__()

        self.maxpool3d = nn.MaxPool3d(kernel_size=2, stride=2)

        # Encoder: each stage keeps the spatial size and changes the width.
        self.encoder1 = self._make_encoder_block(in_channels, level_channels[0])
        self.encoder2 = self._make_encoder_block(level_channels[0], level_channels[1])
        self.encoder3 = self._make_encoder_block(level_channels[1], level_channels[2])

        # Bottleneck: same block type as the encoder stages.
        self.bottleneck = self._make_encoder_block(level_channels[2], bottleneck_channel)

        # Decoder: each transposed conv doubles the spatial size without
        # changing the width; the following decoder block consumes
        # [skip features, upsampled features] concatenated along channels.
        self.upconv3 = nn.ConvTranspose3d(bottleneck_channel, bottleneck_channel, kernel_size=2, stride=2)
        # Not used by forward(); kept so existing references to `model.act` still resolve.
        self.act = nn.ReLU(inplace=True)
        self.decoder3 = self._make_decoder_block(level_channels[2] + bottleneck_channel, level_channels[2])

        self.upconv2 = nn.ConvTranspose3d(level_channels[2], level_channels[2], kernel_size=2, stride=2)
        self.decoder2 = self._make_decoder_block(level_channels[1] + level_channels[2], level_channels[1])

        self.upconv1 = nn.ConvTranspose3d(level_channels[1], level_channels[1], kernel_size=2, stride=2)
        self.decoder1 = self._make_decoder_block(level_channels[0] + level_channels[1], level_channels[0])

        # Output head: 1x1x1 conv to class logits, then softmax over the
        # channel axis (dim=1). dim=0 would normalise across the batch.
        self.outlayer = nn.Conv3d(in_channels=level_channels[0], out_channels=out_channels, kernel_size=1)
        self.soft = nn.Softmax(dim=1)

    def forward(self, x):
        """Run the network on ``x`` of shape ``(N, C_in, S1, S2, S3)``.

        Returns class probabilities of shape ``(N, C_out, S2, S3, S1)``.
        """
        # Encoder: keep each stage's output for the skip connections.
        res1 = self.encoder1(x)
        res2 = self.encoder2(self.maxpool3d(res1))
        res3 = self.encoder3(self.maxpool3d(res2))

        x = self.bottleneck(self.maxpool3d(res3))

        # Decoder: upsample, prepend the same-level skip features, refine.
        x = self.decoder3(torch.cat((res3, self.upconv3(x)), dim=1))
        x = self.decoder2(torch.cat((res2, self.upconv2(x)), dim=1))
        x = self.decoder1(torch.cat((res1, self.upconv1(x)), dim=1))

        probs = self.soft(self.outlayer(x))
        # Move the first spatial axis to the end: (N, C, S1, S2, S3) -> (N, C, S2, S3, S1).
        return probs.permute(0, 1, 3, 4, 2)

    @staticmethod
    def _double_conv(in_channels, mid_channels, out_channels):
        """Two (Conv3d 3x3x3, padding 1 -> BatchNorm3d -> ReLU) layers: in -> mid -> out."""
        return nn.Sequential(
            nn.Conv3d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True)
        )

    def _make_encoder_block(self, in_channels, out_channels):
        """Encoder stage: widens to ``out_channels // 2`` first, then to ``out_channels``."""
        return self._double_conv(in_channels, out_channels // 2, out_channels)

    def _make_decoder_block(self, in_channels, out_channels):
        """Decoder stage: reduces straight to ``out_channels`` and keeps that width."""
        return self._double_conv(in_channels, out_channels, out_channels)


In [3]:
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.inferers import sliding_window_inference

# Create the model instance: MONAI's 3-D U-Net with five channel levels,
# a stride-2 transition between each consecutive pair of levels,
# residual units (num_res_units=2) and batch normalisation.
# NOTE(review): out_channels=1 differs from the 3-class UNet3D defined above;
# presumably meant for a single foreground mask -- confirm against the loss used.
model = UNet(
    spatial_dims=3,
    norm=Norm.BATCH,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2,) * 4,  # len(channels) - 1 transitions, each downsampling by 2
    num_res_units=2,
)