In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# 基于UNet3d进行改进

## 1. 网络结构

In [3]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True)
         )
    def forward(self, x):
        return self.conv(x)

### 1.1 UNet_BN

In [4]:
class UNet3d_bn(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNet3d_bn, self).__init__()
        self.encoder1 = DoubleConv(in_channels, 32)
        self.encoder2 = DoubleConv(32, 64)
        self.encoder3 = DoubleConv(64, 128)
        self.encoder4 = DoubleConv(128, 256)
        self.encoder5 = DoubleConv(256, 512) 

        self.decoder1 = DoubleConv(512, 256)
        self.conv_trans1 = nn.ConvTranspose3d(512, 256, kernel_size=2, stride=2)
        self.decoder2 = DoubleConv(256, 128)
        self.conv_trans2 = nn.ConvTranspose3d(256, 128, kernel_size=2, stride=2)
        self.decoder3 = DoubleConv(128, 64)
        self.conv_trans3 = nn.ConvTranspose3d(128, 64, kernel_size=2, stride=2)
        self.decoder4 = DoubleConv(64, 32)
        self.conv_trans4 = nn.ConvTranspose3d(64, 32, kernel_size=2, stride=2)
        self.out_conv = DoubleConv(32, out_channels)

        self.soft = nn.Softmax(dim=1)
        
        
    def forward(self, x):
        # 编码器部分
        t1 = self.encoder1(x)                                               # 32 x 128 x 128 x 128
        out = F.max_pool3d(t1, 2, 2)                                        # 32 x 64 x 64 x 64
                                    
        t2 = self.encoder2(out)                                             # 64 x 64 x 64 x 64
        out = F.max_pool3d(t2, 2, 2)                                        # 64 x 32 x 32 x 32
        
        t3 = self.encoder3(out)                                             # 128 x 32 x 32 x 32
        out = F.max_pool3d(t3, 2, 2)                                        # 128 x 16 x 16 x 16
        
        t4 = self.encoder4(out)                                             # 256 x 16 x 16 x 16
        out = F.max_pool3d(t4, 2, 2)                                        # 256 x 8 x 8 x 8
        
        out = self.encoder5(out)                                            # 512 x 8 x 8 x 8
        
        
        
        out = self.conv_trans1(out)                                         # 256 x 16 x 16 x 16
        out = self.decoder1(torch.cat([out, t4], dim=1))                    # 256 x 16 x 16 x 16
        
        out = self.conv_trans2(out)                                          # 128 x 32 x 32 x 32
        out = self.decoder2(torch.cat([out, t3], dim=1))                    # 128 x 32 x 32 x 32
        
        out = self.conv_trans3(out)                                         # 64 x 64 x 64 x 64
        out = self.decoder3(torch.cat([out, t2], dim=1))                    # 64 x 64 x 64 x 64                

        out = self.conv_trans4(out)                                         # 32 x 128 x 128 x 128
        out = self.decoder4(torch.cat([out, t1], dim=1))                    # 32 x 128 x 128 x 128

        out = self.out_conv(out)                                            # out_channels x 128 x 128
        
        out = self.soft(out)                                             # softmax
        return out

In [5]:
from torchsummary import summary

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet3d_bn(in_channels=4, out_channels=4)
input_tensor = torch.randn([1, 4, 128, 128, 128]).float()

model.to(device)
input_tensor = input_tensor.to(device)


out = model(input_tensor)
print(out.shape)
summary(model, (4, 128, 128, 128), 1)

torch.Size([1, 4, 128, 128, 128])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1     [1, 32, 128, 128, 128]           3,488
       BatchNorm3d-2     [1, 32, 128, 128, 128]              64
              ReLU-3     [1, 32, 128, 128, 128]               0
            Conv3d-4     [1, 32, 128, 128, 128]          27,680
       BatchNorm3d-5     [1, 32, 128, 128, 128]              64
              ReLU-6     [1, 32, 128, 128, 128]               0
        DoubleConv-7     [1, 32, 128, 128, 128]               0
            Conv3d-8        [1, 64, 64, 64, 64]          55,360
       BatchNorm3d-9        [1, 64, 64, 64, 64]             128
             ReLU-10        [1, 64, 64, 64, 64]               0
           Conv3d-11        [1, 64, 64, 64, 64]         110,656
      BatchNorm3d-12        [1, 64, 64, 64, 64]             128
             ReLU-13        [1, 64, 64, 64, 64]               0
     

In [6]:
# simple UNet3d_ln
class UNet3d_ln(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNet3d_ln, self).__init__()
        self.encoder1 = nn.Conv3d(in_channels, 32, kernel_size=3, padding=1)
        self.encoder2 = nn.Conv3d(32, 64, kernel_size=3, padding=1)
        self.encoder3 = nn.Conv3d(64, 128, kernel_size=3, padding=1)
        self.encoder4 = nn.Conv3d(128, 256, kernel_size=3, padding=1)
        self.encoder5 = nn.Conv3d(256, 512, kernel_size=3, padding=1)

        self.decoder1 = nn.Conv3d(512, 256, kernel_size=3, padding=1)
        self.conv_trans1 = nn.ConvTranspose3d(512, 256, kernel_size=2, stride=2)
        self.decoder2 = nn.Conv3d(256, 128, kernel_size=3, padding=1)
        self.conv_trans2 = nn.ConvTranspose3d(256, 128, kernel_size=2, stride=2)
        self.decoder3 = nn.Conv3d(128, 64, kernel_size=3, padding=1)
        self.conv_trans3 = nn.ConvTranspose3d(128, 64, kernel_size=2, stride=2)
        self.decoder4 = nn.Conv3d(64, 32, kernel_size=3, padding=1)
        self.conv_trans4 = nn.ConvTranspose3d(64, 32, kernel_size=2, stride=2)
        self.out_conv = nn.Conv3d(32, out_channels, kernel_size=3, padding=1)

        self.soft = nn.Softmax(dim=1)
        
        
    def forward(self, x):
        # 编码器
        out = self.encoder1(x)                                              # 32 x 128 x 128 x 128
        out = F.relu(F.layer_norm(out, out.shape[-3:]))                     
        t1 = out                                                            # 32 x 128 x 128 x 128
        
        out = F.max_pool3d(t1, 2, 2)                                        # 32 x 64 x 64 x 64
        out = self.encoder2(out)                                            # 64 x 64 x 64 x 64
        out = F.relu(F.layer_norm(out, out.shape[-3:]))
        t2 = out                                                            # 64 x 64 x 64 x 64

        out = F.max_pool3d(t2, 2, 2)                                        # 64 x 32 x 32 x 32
        out = self.encoder3(out)                                            # 128 x 32 x 32 x 32
        out = F.relu(F.layer_norm(out, out.shape[-3:]))
        t3 = out                                                            # 128 x 32 x 32 x 32

        out = F.max_pool3d(t3, 2, 2)                                        # 128 x 16 x 16 x 16
        out = self.encoder4(out)                                            # 256 x 16 x 16 x 16
        out = F.relu(F.layer_norm(out, out.shape[-3:]))                     
        t4 = out                                                            # 256 x 16 x 16 x 16
        
        out = F.max_pool3d(t4, 2, 2)                                        # 256 x 8 x 8 x 8
        out = self.encoder5(out)                                            # 512 x 8 x 8 x 8
        out = F.relu(F.layer_norm(out, out.shape[-3:]))
        
        # 解码器
        out = self.conv_trans1(out)                                         # 256 x 16 x 16 x 16
        out = self.decoder1(torch.cat([out, t4], dim=1))                    # 256 x 16 x 16 x 16
        out = F.relu(F.layer_norm(out, out.shape[-3:]))
        
        out = self.conv_trans2(out)                                         # 128 x 32 x 32 x 32
        out = self.decoder2(torch.cat([out, t3], dim=1))                    # 128 x 32 x 32 x 32
        out = F.relu(F.layer_norm(out, out.shape[-3:]))

        out = self.conv_trans3(out)                                         # 64 x 64 x 64 x 64
        out = self.decoder3(torch.cat([out, t2], dim=1))                    # 64 x 64 x 64 x 64
        out = F.relu(F.layer_norm(out, out.shape[-3:]))                 
        
        out = self.conv_trans4(out)                                         # 32 x 128 x 128 x 128
        out = self.decoder4(torch.cat([out, t1], dim=1))                    # 32 x 128 x 128 x 128
        out = F.relu(F.layer_norm(out, out.shape[-3:]))                     
        
        out = self.out_conv(out)                                            # out_channels x 128 x 128 x 128
        
        out = self.soft(out)                                                # softmax
        
        return out

In [7]:

# 改进
class UNet3d_ln_double(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNet3d_ln_double, self).__init__()
        self.encoder1 = nn.Conv3d(in_channels, 32, kernel_size=3, padding=1)
        self.encoder2 = nn.Conv3d(32, 64, kernel_size=3, padding=1)
        self.encoder3 = nn.Conv3d(64, 128, kernel_size=3, padding=1)
        self.encoder4 = nn.Conv3d(128, 256, kernel_size=3, padding=1)
        self.encoder5 = nn.Conv3d(256, 512, kernel_size=3, padding=1)

        self.conv_32    = nn.Conv3d(32, 32, kernel_size=3, padding=1)
        self.conv_64    = nn.Conv3d(64, 64, kernel_size=3, padding=1)
        self.conv_128    = nn.Conv3d(128, 128, kernel_size=3, padding=1)
        self.conv_256    = nn.Conv3d(256, 256, kernel_size=3, padding=1)
        self.conv_512    = nn.Conv3d(512, 512, kernel_size=3, padding=1)    
        
        self.decoder1 = nn.Conv3d(512, 256, kernel_size=3, padding=1)
        self.conv_trans1 = nn.ConvTranspose3d(512, 256, kernel_size=2, stride=2)
        self.decoder2 = nn.Conv3d(256, 128, kernel_size=3, padding=1)
        self.conv_trans2 = nn.ConvTranspose3d(256, 128, kernel_size=2, stride=2)
        self.decoder3 = nn.Conv3d(128, 64, kernel_size=3, padding=1)
        self.conv_trans3 = nn.ConvTranspose3d(128, 64, kernel_size=2, stride=2)
        self.decoder4 = nn.Conv3d(64, 32, kernel_size=3, padding=1)
        self.conv_trans4 = nn.ConvTranspose3d(64, 32, kernel_size=2, stride=2)
        self.out_conv = nn.Conv3d(32, out_channels, kernel_size=3, padding=1)

        self.soft = nn.Softmax(dim=1)
        
        
        
    def forward(self, x):
        # 编码器
        out = self.encoder1(x)                                              # 32 x 128 x 128 x 128
        out = F.relu(F.layer_norm(out, out.shape[-3:]))                     
        out = self.conv_32(out)
        out = F.relu(F.layer_norm(out, out.shape[-3:]))
        t1 = out                                                            # 32 x 128 x 128 x 128
        
        out = F.max_pool3d(t1, 2, 2)                                        # 32 x 64 x 64 x 64
        out = self.encoder2(out)                                            # 64 x 64 x 64 x 64
        out = F.relu(F.layer_norm(out, out.shape[-3:]))
        out = self.conv_64(out)
        out = F.relu(F.layer_norm(out, out.shape[-3:]))
        t2 = out                                                            # 64 x 64 x 64 x 64

        out = F.max_pool3d(t2, 2, 2)                                        # 64 x 32 x 32 x 32
        out = self.encoder3(out)                                            # 128 x 32 x 32 x 32
        out = F.relu(F.layer_norm(out, out.shape[-3:]))
        out = self.conv_128(out)
        out = F.relu(F.layer_norm(out, out.shape[-3:]))
        t3 = out                                                            # 128 x 32 x 32 x 32

        out = F.max_pool3d(t3, 2, 2)                                        # 128 x 16 x 16 x 16
        out = self.encoder4(out)                                            # 256 x 16 x 16 x 16
        out = F.relu(F.layer_norm(out, out.shape[-3:]))                     
        out = self.conv_256(out)
        out = F.relu(F.layer_norm(out, out.shape[-3:]))
        t4 = out                                                            # 256 x 16 x 16 x 16
        
        out = F.max_pool3d(t4, 2, 2)                                        # 256 x 8 x 8 x 8
        out = self.encoder5(out)                                            # 512 x 8 x 8 x 8
        out = F.relu(F.layer_norm(out, out.shape[-3:]))
        out = self.conv_512(out)
        out = F.relu(F.layer_norm(out, out.shape[-3:]))
        
        # 解码器
        out = self.conv_trans1(out)                                         # 256 x 16 x 16 x 16
        out = self.decoder1(torch.cat([out, t4], dim=1))                    # 256 x 16 x 16 x 16
        out = F.relu(F.layer_norm(out, out.shape[-3:]))
        out = self.conv_256(out)
        out = F.relu(F.layer_norm(out, out.shape[-3:]))
        
        out = self.conv_trans2(out)                                         # 128 x 32 x 32 x 32
        out = self.decoder2(torch.cat([out, t3], dim=1))                    # 128 x 32 x 32 x 32
        out = F.relu(F.layer_norm(out, out.shape[-3:]))
        out = self.conv_128(out)
        out = F.relu(F.layer_norm(out, out.shape[-3:]))

        out = self.conv_trans3(out)                                         # 64 x 64 x 64 x 64
        out = self.decoder3(torch.cat([out, t2], dim=1))                    # 64 x 64 x 64 x 64
        out = F.relu(F.layer_norm(out, out.shape[-3:]))                 
        out = self.conv_64(out)
        out = F.relu(F.layer_norm(out, out.shape[-3:]))
        
        out = self.conv_trans4(out)                                         # 32 x 128 x 128 x 128
        out = self.decoder4(torch.cat([out, t1], dim=1))                    # 32 x 128 x 128 x 128
        out = F.relu(F.layer_norm(out, out.shape[-3:]))                     
        out = self.conv_32(out)
        out = F.relu(F.layer_norm(out, out.shape[-3:]))
        
        out = self.out_conv(out)                                            # out_channels x 128 x 128 x 128
        
        out = self.soft(out)                                                # softmax
        
        return out

In [8]:
from torchsummary import summary

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet3d_ln(in_channels=4, out_channels=4)
input_tensor = torch.randn([1, 4, 128, 128, 128]).float()

model.to(device)
input_tensor = input_tensor.to(device)


out = model(input_tensor)
print(out.shape)
summary(model, (4, 128, 128, 128))

torch.Size([1, 4, 128, 128, 128])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1    [-1, 32, 128, 128, 128]           3,488
            Conv3d-2       [-1, 64, 64, 64, 64]          55,360
            Conv3d-3      [-1, 128, 32, 32, 32]         221,312
            Conv3d-4      [-1, 256, 16, 16, 16]         884,992
            Conv3d-5         [-1, 512, 8, 8, 8]       3,539,456
   ConvTranspose3d-6      [-1, 256, 16, 16, 16]       1,048,832
            Conv3d-7      [-1, 256, 16, 16, 16]       3,539,200
   ConvTranspose3d-8      [-1, 128, 32, 32, 32]         262,272
            Conv3d-9      [-1, 128, 32, 32, 32]         884,864
  ConvTranspose3d-10       [-1, 64, 64, 64, 64]          65,600
           Conv3d-11       [-1, 64, 64, 64, 64]         221,248
  ConvTranspose3d-12    [-1, 32, 128, 128, 128]          16,416
           Conv3d-13    [-1, 32, 128, 128, 128]          55,328
     

In [9]:
from torchsummary import summary

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet3d_ln_double(in_channels=4, out_channels=4)
input_tensor = torch.randn([1, 4, 128, 128, 128]).float()

model.to(device)
input_tensor = input_tensor.to(device)


out = model(input_tensor)
print(out.shape)
summary(model, (4, 128, 128, 128))

torch.Size([1, 4, 128, 128, 128])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1    [-1, 32, 128, 128, 128]           3,488
            Conv3d-2    [-1, 32, 128, 128, 128]          27,680
            Conv3d-3       [-1, 64, 64, 64, 64]          55,360
            Conv3d-4       [-1, 64, 64, 64, 64]         110,656
            Conv3d-5      [-1, 128, 32, 32, 32]         221,312
            Conv3d-6      [-1, 128, 32, 32, 32]         442,496
            Conv3d-7      [-1, 256, 16, 16, 16]         884,992
            Conv3d-8      [-1, 256, 16, 16, 16]       1,769,728
            Conv3d-9         [-1, 512, 8, 8, 8]       3,539,456
           Conv3d-10         [-1, 512, 8, 8, 8]       7,078,400
  ConvTranspose3d-11      [-1, 256, 16, 16, 16]       1,048,832
           Conv3d-12      [-1, 256, 16, 16, 16]       3,539,200
           Conv3d-13      [-1, 256, 16, 16, 16]       1,769,728
  Con

In [19]:
class _make_conv_layer(nn.Module):
    def __init__(self, in_channels:int, out_channels:int, use_bn=True, use_ln=False, use_dropout=False, dropout_rate=0, ln_spatial_shape:list=[]):
        super(_make_conv_layer, self).__init__()
        # 参数
        self.use_bn = use_bn
        self.use_ln = use_ln
        self.use_dropout = use_dropout
        self.dropout_rate = dropout_rate
        self.ln_spatial_shape = ln_spatial_shape

        # 卷积层
        if use_bn:
            self.conv3x3 = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm3d(out_channels),
                nn.ReLU(inplace=True),
                nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm3d(out_channels),
                nn.ReLU(inplace=True)
            )
        elif use_ln:
            self.conv3x3 = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.LayerNorm([out_channels, *ln_spatial_shape]),
                nn.ReLU(inplace=True),
                nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.LayerNorm([out_channels, *ln_spatial_shape]),
                nn.ReLU(inplace=True)
            )
        else:
            raise"Error: no normalization layer is used!"

        self.dropout = nn.Dropout3d(self.dropout_rate)

    def forward(self, x):
        out = self.conv3x3(x)
        if self.use_dropout:
            out = self.dropout(out)
        return out

class _make_upsample_layer(nn.Module):
    def __init__(self, in_channels:int, out_channels:int, use_bn=True, use_ln=False, use_dropout=False, dropout_rate=0, ln_spatial_shape:list=[]):
        super(_make_upsample_layer, self).__init__()    
        self.use_bn = use_bn
        self.use_ln = use_ln
        self.use_dropout = use_dropout
        self.dropout_rate = dropout_rate
        self.ln_spatial_shape = ln_spatial_shape

        if use_bn:
            self.up2times = nn.Sequential(
                nn.ConvTranspose3d(in_channels, in_channels, kernel_size=2, stride=2),
                nn.BatchNorm3d(in_channels),
                nn.ReLU(inplace=True),
                nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm3d(out_channels),
                nn.ReLU(inplace=True)
            )
        elif use_ln:
            self.up2times = nn.Sequential(
                nn.ConvTranspose3d(in_channels, in_channels, kernel_size=2, stride=2),
                nn.LayerNorm([in_channels, *(2*ln_spatial_shape)]),
                nn.ReLU(inplace=True),
                nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.LayerNorm([out_channels, *(2*ln_spatial_shape)]),
                nn.ReLU(inplace=True)
            )
        else:
            raise"Error: no normalization layer is used!"
        
        self.dropout = nn.Dropout3d(self.dropout_rate)

    def forward(self, x):
        out = self.up2times(x)
        if self.use_dropout:
            out = self.dropout(out)
        return out

class UNet3D(nn.Module):
    def __init__(self, in_channels:int, out_channels:int, dropout_rate:float=0, use_bn:bool=True, use_ln:bool=False, use_dropout:bool=False, ln_spatial_shape:list=[]):
        super(UNet3D, self).__init__()     
        self.dropout_rate = dropout_rate
        self.encoder_use_list = (use_bn, use_ln, True, 0.1)
        self.decoder_use_list = (use_bn, use_ln, False, 0.1)
        # 编码器
        self.encoder1 = _make_conv_layer(in_channels, 32, * self.encoder_use_list)
        self.encoder2 = _make_conv_layer(32, 64, *self.encoder_use_list)
        self.encoder3 = _make_conv_layer(64, 128, *self.encoder_use_list)
        self.encoder4 = _make_conv_layer(128, 256, *self.encoder_use_list)
        self.encoder5 = _make_conv_layer(256, 512, *self.encoder_use_list)

        # 解码器
        self.decoder1 = _make_conv_layer(512, 256, *self.decoder_use_list)
        self.up1      = _make_upsample_layer(512, 256, *self.decoder_use_list)
        self.decoder2 = _make_conv_layer(256, 128, *self.decoder_use_list)
        self.up2      = _make_upsample_layer(256, 128, *self.decoder_use_list)
        self.decoder3 = _make_conv_layer(128, 64, *self.decoder_use_list)
        self.up3      = _make_upsample_layer(128, 64, *self.decoder_use_list)
        self.decoder4 = _make_conv_layer(64, 32, *self.decoder_use_list)
        self.up4      = _make_upsample_layer(64, 32, *self.decoder_use_list)

        # 输出层
        self.output_conv = nn.Conv3d(32, out_channels, kernel_size=1)

        # 归一化层
        self.dropout = nn.Dropout3d(dropout_rate)
        
        self.soft = nn.Softmax(dim=1)

    def forward(self, x):
        # 编码器
        t1 = self.encoder1(x)                                                                   # [1, 32, 128, 128, 128]
        t2 = self.encoder2(F.max_pool3d(t1, 2, 2))                                              # [1, 64, 64, 64, 64] 
        t3 = self.encoder3(F.max_pool3d(t2, 2, 2))                                              # [1, 128, 32, 32, 32]
        t4 = self.encoder4(F.max_pool3d(t3, 2, 2))                                              # [1, 256, 16, 16, 16]
        out = self.encoder5(F.max_pool3d(t4, 2, 2))                                             # [1, 512, 8, 8, 8]

        # Dropout
        if self.dropout_rate > 0:
            out = self.dropout(out)                                                              # [1, 512, 8, 8, 8]
        # 解码器        
        out = self.decoder1(torch.cat([self.up1(out), t4], dim=1))                               # [1, 256, 16, 16, 16]
        out = self.decoder2(torch.cat([self.up2(out), t3], dim=1))                               # [1, 128, 32, 32, 32]                      
        out = self.decoder3(torch.cat([self.up3(out), t2], dim=1))                               # [1, 64, 64, 64, 64]
        out = self.decoder4(torch.cat([self.up4(out), t1], dim=1))                               # [1, 32, 128, 128, 128]

        # 输出层
        out = self.output_conv(out)
        out = self.soft(out)

        return out

In [20]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet3D(in_channels=4, out_channels=4)
input_tensor = torch.randn([1, 4, 128, 128, 128]).float()

model.to(device)
input_tensor = input_tensor.to(device)


# out = model(input_tensor)
# print(out.shape)
summary(model, (4, 128, 128, 128), 2)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1     [2, 32, 128, 128, 128]           3,488
       BatchNorm3d-2     [2, 32, 128, 128, 128]              64
              ReLU-3     [2, 32, 128, 128, 128]               0
            Conv3d-4     [2, 32, 128, 128, 128]          27,680
       BatchNorm3d-5     [2, 32, 128, 128, 128]              64
              ReLU-6     [2, 32, 128, 128, 128]               0
         Dropout3d-7     [2, 32, 128, 128, 128]               0
  _make_conv_layer-8     [2, 32, 128, 128, 128]               0
            Conv3d-9        [2, 64, 64, 64, 64]          55,360
      BatchNorm3d-10        [2, 64, 64, 64, 64]             128
             ReLU-11        [2, 64, 64, 64, 64]               0
           Conv3d-12        [2, 64, 64, 64, 64]         110,656
      BatchNorm3d-13        [2, 64, 64, 64, 64]             128
             ReLU-14        [2, 64, 64,

In [29]:
labels = {
    'BG': 0, 
    'NCR' : 1,
    'ED': 2,
    'ET':3
    }
[k for k in labels.keys()]

['BG', 'NCR', 'ED', 'ET']

In [33]:
class _make_conv_layer(nn.Module):
    def __init__(self, in_channels:int, out_channels:int, use_bn=True, use_ln=False, use_dropout=False, dropout_rate=0, ln_spatial_shape:list=[]):
        super(_make_conv_layer, self).__init__()
        # 参数
        self.use_bn = use_bn
        self.use_ln = use_ln
        self.use_dropout = use_dropout
        self.dropout_rate = dropout_rate
        self.ln_spatial_shape = ln_spatial_shape

        # 卷积层
        if use_bn:
            self.conv3x3 = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm3d(out_channels),
                nn.ReLU(inplace=True),
                nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm3d(out_channels),
                nn.ReLU(inplace=True)
            )
        elif use_ln:
            self.conv3x3 = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.LayerNorm([out_channels, *ln_spatial_shape]),
                nn.ReLU(inplace=True),
                nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.LayerNorm([out_channels, *ln_spatial_shape]),
                nn.ReLU(inplace=True)
            )
        else:
            raise"Error: no normalization layer is used!"

        self.dropout = nn.Dropout3d(self.dropout_rate)

    def forward(self, x):
        out = self.conv3x3(x)
        if self.use_dropout:
            out = self.dropout(out)
        return out

class _make_upsample_layer(nn.Module):
    def __init__(self, in_channels:int, out_channels:int, use_bn=True, use_ln=False, use_dropout=False, dropout_rate=0, ln_spatial_shape:list=[]):
        super(_make_upsample_layer, self).__init__()    
        self.use_bn = use_bn
        self.use_ln = use_ln
        self.use_dropout = use_dropout
        self.dropout_rate = dropout_rate
        self.ln_spatial_shape = ln_spatial_shape

        if use_bn:
            self.up2times = nn.Sequential(
                nn.ConvTranspose3d(in_channels, in_channels, kernel_size=2, stride=2),
                nn.BatchNorm3d(in_channels),
                nn.ReLU(inplace=True),
                nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm3d(out_channels),
                nn.ReLU(inplace=True)
            )
        elif use_ln:
            self.up2times = nn.Sequential(
                nn.ConvTranspose3d(in_channels, in_channels, kernel_size=2, stride=2),
                nn.LayerNorm([in_channels, *(2*ln_spatial_shape)]),
                nn.ReLU(inplace=True),
                nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.LayerNorm([out_channels, *(2*ln_spatial_shape)]),
                nn.ReLU(inplace=True)
            )
        else:
            raise"Error: no normalization layer is used!"
        
        self.dropout = nn.Dropout3d(self.dropout_rate)

    def forward(self, x):
        out = self.up2times(x)
        if self.use_dropout:
            out = self.dropout(out)
        return out

class UNet3D(nn.Module):
    def __init__(self, in_channels:int, out_channels:int, dropout_rate:float=0, use_bn:bool=True, use_ln:bool=False, use_dropout:bool=False, ln_spatial_shape:list=[]):
        super(UNet3D, self).__init__()     
        self.dropout_rate = dropout_rate
        self.encoder_use_list = (use_bn, use_ln, True, 0.1)
        self.decoder_use_list = (use_bn, use_ln, False, 0.1)
        # 编码器
        self.encoder1 = _make_conv_layer(in_channels, 32, * self.encoder_use_list)
        self.encoder2 = _make_conv_layer(32, 64, *self.encoder_use_list)
        self.encoder3 = _make_conv_layer(64, 128, *self.encoder_use_list)
        self.encoder4 = _make_conv_layer(128, 256, *self.encoder_use_list)
        self.encoder5 = _make_conv_layer(256, 512, *self.encoder_use_list)

        # 解码器
        self.decoder1 = _make_conv_layer(512, 256, *self.decoder_use_list)
        self.up1      = _make_upsample_layer(512, 256, *self.decoder_use_list)
        self.decoder2 = _make_conv_layer(256, 128, *self.decoder_use_list)
        self.up2      = _make_upsample_layer(256, 128, *self.decoder_use_list)
        self.decoder3 = _make_conv_layer(128, 64, *self.decoder_use_list)
        self.up3      = _make_upsample_layer(128, 64, *self.decoder_use_list)
        self.decoder4 = _make_conv_layer(64, 32, *self.decoder_use_list)
        self.up4      = _make_upsample_layer(64, 32, *self.decoder_use_list)

        # 输出层
        self.output_conv = nn.Conv3d(32, out_channels, kernel_size=1)

        # 归一化层
        self.dropout = nn.Dropout3d(dropout_rate)
        
        self.soft = nn.Softmax(dim=1)

    def forward(self, x):
        # 编码器
        t1 = self.encoder1(x)                                                                   # [1, 32, 128, 128, 128]
        t2 = self.encoder2(F.max_pool3d(t1, 2, 2))                                              # [1, 64, 64, 64, 64] 
        t3 = self.encoder3(F.max_pool3d(t2, 2, 2))                                              # [1, 128, 32, 32, 32]
        t4 = self.encoder4(F.max_pool3d(t3, 2, 2))                                              # [1, 256, 16, 16, 16]
        out = self.encoder5(F.max_pool3d(t4, 2, 2))                                             # [1, 512, 8, 8, 8]

        # Dropout
        if self.dropout_rate > 0:
            out = self.dropout(out)                                                              # [1, 512, 8, 8, 8]
        # 解码器        
        out = self.decoder1(torch.cat([self.up1(out), t4], dim=1))                               # [1, 256, 16, 16, 16]
        out = self.decoder2(torch.cat([self.up2(out), t3], dim=1))                               # [1, 128, 32, 32, 32]                      
        out = self.decoder3(torch.cat([self.up3(out), t2], dim=1))                               # [1, 64, 64, 64, 64]
        out = self.decoder4(torch.cat([self.up4(out), t1], dim=1))                               # [1, 32, 128, 128, 128]

        # 输出层
        out = self.output_conv(out)
        out = self.soft(out)

        return out

In [44]:
from torch.utils.tensorboard import SummaryWriter

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
writer = SummaryWriter('/mnt/d/AI_Research/WS-HUB/WS-segBratsWorkflow/Helium-327-SegBrats/results/2024-09-25/20-00-30/tensorBoard/UNet3D_braTS21_2024-09-25_20-00-30')

In [45]:
model = UNet3D(4, 4).to(device)

In [47]:
writer.add_scalars('DiceLoss/val',
                {'Mean':0.1,
                    'ET': 0.2,
                    'TC': 0.5,
                    'WT': 0.5},
                7)

In [46]:
writer.add_graph(model, torch.rand([1, 4, 128, 128, 128]).to(device))

writer.close()

In [48]:
/mnt/d/AI_Research/WS-HUB/WS-segBratsWorkflow/Helium-327-SegBrats/results/2024-09-26/15-05-39/tensorBoard

SyntaxError: leading zeros in decimal integer literals are not permitted; use an 0o prefix for octal integers (3327457341.py, line 1)

In [74]:
path = '/mnt/d/AI_Research/WS-HUB/WS-segBratsWorkflow/Helium-327-SegBrats/results/2024-09-26/17-23-28/checkpoints/UNet3D_braTS21_2024-09-26_17-23-29/best@epoch5_diceloss0.2831_dice0.3501_4.pth'
split_path = path.split('/')

In [80]:
import os
results_dir = ('/').join(path.split('/')[:-3])

results_dir= os.path.join(results_dir, 'tensorBoard')

results_dir

'/mnt/d/AI_Research/WS-HUB/WS-segBratsWorkflow/Helium-327-SegBrats/results/2024-09-26/17-23-28/tensorBoard'

In [73]:
import os

results_dir = ('/').join(path.split('/')[:-1])
logs_dir = os.path.join(results_dir, 'logs')
logs_file_name = [file for file in os.listdir(logs_dir) if file.endswith('.log')]
logs_path = os.path.join(logs_dir, logs_file_name[0])
logs_path

'/mnt/d/AI_Research/WS-HUB/WS-segBratsWorkflow/Helium-327-SegBrats/results/2024-09-26/15-05-39/logs/2024-09-26.log'