In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# 基于UNet3d进行改进

## 1. 网络结构

In [3]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True)
         )
    def forward(self, x):
        return self.conv(x)

### 1.1 UNet_BN

In [4]:
class UNet3d_bn(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNet3d_bn, self).__init__()
        self.encoder1 = DoubleConv(in_channels, 32)
        self.encoder2 = DoubleConv(32, 64)
        self.encoder3 = DoubleConv(64, 128)
        self.encoder4 = DoubleConv(128, 256)
        self.encoder5 = DoubleConv(256, 512) 

        self.decoder1 = DoubleConv(512, 256)
        self.conv_trans1 = nn.ConvTranspose3d(512, 256, kernel_size=2, stride=2)
        self.decoder2 = DoubleConv(256, 128)
        self.conv_trans2 = nn.ConvTranspose3d(256, 128, kernel_size=2, stride=2)
        self.decoder3 = DoubleConv(128, 64)
        self.conv_trans3 = nn.ConvTranspose3d(128, 64, kernel_size=2, stride=2)
        self.decoder4 = DoubleConv(64, 32)
        self.conv_trans4 = nn.ConvTranspose3d(64, 32, kernel_size=2, stride=2)
        self.out_conv = DoubleConv(32, out_channels)

        self.soft = nn.Softmax(dim=1)
        
        
    def forward(self, x):
        # 编码器部分
        t1 = self.encoder1(x)                                               # 32 x 128 x 128 x 128
        out = F.max_pool3d(t1, 2, 2)                                        # 32 x 64 x 64 x 64
                                    
        t2 = self.encoder2(out)                                             # 64 x 64 x 64 x 64
        out = F.max_pool3d(t2, 2, 2)                                        # 64 x 32 x 32 x 32
        
        t3 = self.encoder3(out)                                             # 128 x 32 x 32 x 32
        out = F.max_pool3d(t3, 2, 2)                                        # 128 x 16 x 16 x 16
        
        t4 = self.encoder4(out)                                             # 256 x 16 x 16 x 16
        out = F.max_pool3d(t4, 2, 2)                                        # 256 x 8 x 8 x 8
        
        out = self.encoder5(out)                                            # 512 x 8 x 8 x 8
        
        
        
        out = self.conv_trans1(out)                                         # 256 x 16 x 16 x 16
        out = self.decoder1(torch.cat([out, t4], dim=1))                    # 256 x 16 x 16 x 16
        
        out = self.conv_trans2(out)                                          # 128 x 32 x 32 x 32
        out = self.decoder2(torch.cat([out, t3], dim=1))                    # 128 x 32 x 32 x 32
        
        out = self.conv_trans3(out)                                         # 64 x 64 x 64 x 64
        out = self.decoder3(torch.cat([out, t2], dim=1))                    # 64 x 64 x 64 x 64                

        out = self.conv_trans4(out)                                         # 32 x 128 x 128 x 128
        out = self.decoder4(torch.cat([out, t1], dim=1))                    # 32 x 128 x 128 x 128

        out = self.out_conv(out)                                            # out_channels x 128 x 128
        
        out = self.soft(out)                                             # softmax
        return out

In [5]:
from torchsummary import summary

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet3d_bn(in_channels=4, out_channels=4)
input_tensor = torch.randn([1, 4, 128, 128, 128]).float()

model.to(device)
input_tensor = input_tensor.to(device)


out = model(input_tensor)
print(out.shape)
summary(model, (4, 128, 128, 128))

torch.Size([1, 4, 128, 128, 128])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1    [-1, 32, 128, 128, 128]           3,488
       BatchNorm3d-2    [-1, 32, 128, 128, 128]              64
              ReLU-3    [-1, 32, 128, 128, 128]               0
            Conv3d-4    [-1, 32, 128, 128, 128]          27,680
       BatchNorm3d-5    [-1, 32, 128, 128, 128]              64
              ReLU-6    [-1, 32, 128, 128, 128]               0
        DoubleConv-7    [-1, 32, 128, 128, 128]               0
            Conv3d-8       [-1, 64, 64, 64, 64]          55,360
       BatchNorm3d-9       [-1, 64, 64, 64, 64]             128
             ReLU-10       [-1, 64, 64, 64, 64]               0
           Conv3d-11       [-1, 64, 64, 64, 64]         110,656
      BatchNorm3d-12       [-1, 64, 64, 64, 64]             128
             ReLU-13       [-1, 64, 64, 64, 64]               0
     

In [6]:
# simple UNet3d_ln
class UNet3d_ln(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNet3d_ln, self).__init__()
        self.encoder1 = nn.Conv3d(in_channels, 32, kernel_size=3, padding=1)
        self.encoder2 = nn.Conv3d(32, 64, kernel_size=3, padding=1)
        self.encoder3 = nn.Conv3d(64, 128, kernel_size=3, padding=1)
        self.encoder4 = nn.Conv3d(128, 256, kernel_size=3, padding=1)
        self.encoder5 = nn.Conv3d(256, 512, kernel_size=3, padding=1)

        self.decoder1 = nn.Conv3d(512, 256, kernel_size=3, padding=1)
        self.conv_trans1 = nn.ConvTranspose3d(512, 256, kernel_size=2, stride=2)
        self.decoder2 = nn.Conv3d(256, 128, kernel_size=3, padding=1)
        self.conv_trans2 = nn.ConvTranspose3d(256, 128, kernel_size=2, stride=2)
        self.decoder3 = nn.Conv3d(128, 64, kernel_size=3, padding=1)
        self.conv_trans3 = nn.ConvTranspose3d(128, 64, kernel_size=2, stride=2)
        self.decoder4 = nn.Conv3d(64, 32, kernel_size=3, padding=1)
        self.conv_trans4 = nn.ConvTranspose3d(64, 32, kernel_size=2, stride=2)
        self.out_conv = nn.Conv3d(32, out_channels, kernel_size=3, padding=1)

        self.soft = nn.Softmax(dim=1)
        
        
    def forward(self, x):
        # 编码器
        out = self.encoder1(x)                                              # 32 x 128 x 128 x 128
        out = F.relu(F.layer_norm(out, out.shape[-3:]))                     
        t1 = out                                                            # 32 x 128 x 128 x 128
        
        out = F.max_pool3d(t1, 2, 2)                                        # 32 x 64 x 64 x 64
        out = self.encoder2(out)                                            # 64 x 64 x 64 x 64
        out = F.relu(F.layer_norm(out, out.shape[-3:]))
        t2 = out                                                            # 64 x 64 x 64 x 64

        out = F.max_pool3d(t2, 2, 2)                                        # 64 x 32 x 32 x 32
        out = self.encoder3(out)                                            # 128 x 32 x 32 x 32
        out = F.relu(F.layer_norm(out, out.shape[-3:]))
        t3 = out                                                            # 128 x 32 x 32 x 32

        out = F.max_pool3d(t3, 2, 2)                                        # 128 x 16 x 16 x 16
        out = self.encoder4(out)                                            # 256 x 16 x 16 x 16
        out = F.relu(F.layer_norm(out, out.shape[-3:]))                     
        t4 = out                                                            # 256 x 16 x 16 x 16
        
        out = F.max_pool3d(t4, 2, 2)                                        # 256 x 8 x 8 x 8
        out = self.encoder5(out)                                            # 512 x 8 x 8 x 8
        out = F.relu(F.layer_norm(out, out.shape[-3:]))
        
        # 解码器
        out = self.conv_trans1(out)                                         # 256 x 16 x 16 x 16
        out = self.decoder1(torch.cat([out, t4], dim=1))                    # 256 x 16 x 16 x 16
        out = F.relu(F.layer_norm(out, out.shape[-3:]))
        
        out = self.conv_trans2(out)                                         # 128 x 32 x 32 x 32
        out = self.decoder2(torch.cat([out, t3], dim=1))                    # 128 x 32 x 32 x 32
        out = F.relu(F.layer_norm(out, out.shape[-3:]))

        out = self.conv_trans3(out)                                         # 64 x 64 x 64 x 64
        out = self.decoder3(torch.cat([out, t2], dim=1))                    # 64 x 64 x 64 x 64
        out = F.relu(F.layer_norm(out, out.shape[-3:]))                 
        
        out = self.conv_trans4(out)                                         # 32 x 128 x 128 x 128
        out = self.decoder4(torch.cat([out, t1], dim=1))                    # 32 x 128 x 128 x 128
        out = F.relu(F.layer_norm(out, out.shape[-3:]))                     
        
        out = self.out_conv(out)                                            # out_channels x 128 x 128 x 128
        
        out = self.soft(out)                                                # softmax
        
        return out

In [7]:

# 改进
class UNet3d_ln_double(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNet3d_ln_double, self).__init__()
        self.encoder1 = nn.Conv3d(in_channels, 32, kernel_size=3, padding=1)
        self.encoder2 = nn.Conv3d(32, 64, kernel_size=3, padding=1)
        self.encoder3 = nn.Conv3d(64, 128, kernel_size=3, padding=1)
        self.encoder4 = nn.Conv3d(128, 256, kernel_size=3, padding=1)
        self.encoder5 = nn.Conv3d(256, 512, kernel_size=3, padding=1)

        self.conv_32    = nn.Conv3d(32, 32, kernel_size=3, padding=1)
        self.conv_64    = nn.Conv3d(64, 64, kernel_size=3, padding=1)
        self.conv_128    = nn.Conv3d(128, 128, kernel_size=3, padding=1)
        self.conv_256    = nn.Conv3d(256, 256, kernel_size=3, padding=1)
        self.conv_512    = nn.Conv3d(512, 512, kernel_size=3, padding=1)    
        
        self.decoder1 = nn.Conv3d(512, 256, kernel_size=3, padding=1)
        self.conv_trans1 = nn.ConvTranspose3d(512, 256, kernel_size=2, stride=2)
        self.decoder2 = nn.Conv3d(256, 128, kernel_size=3, padding=1)
        self.conv_trans2 = nn.ConvTranspose3d(256, 128, kernel_size=2, stride=2)
        self.decoder3 = nn.Conv3d(128, 64, kernel_size=3, padding=1)
        self.conv_trans3 = nn.ConvTranspose3d(128, 64, kernel_size=2, stride=2)
        self.decoder4 = nn.Conv3d(64, 32, kernel_size=3, padding=1)
        self.conv_trans4 = nn.ConvTranspose3d(64, 32, kernel_size=2, stride=2)
        self.out_conv = nn.Conv3d(32, out_channels, kernel_size=3, padding=1)

        self.soft = nn.Softmax(dim=1)
        
        
        
    def forward(self, x):
        # 编码器
        out = self.encoder1(x)                                              # 32 x 128 x 128 x 128
        out = F.relu(F.layer_norm(out, out.shape[-3:]))                     
        out = self.conv_32(out)
        out = F.relu(F.layer_norm(out, out.shape[-3:]))
        t1 = out                                                            # 32 x 128 x 128 x 128
        
        out = F.max_pool3d(t1, 2, 2)                                        # 32 x 64 x 64 x 64
        out = self.encoder2(out)                                            # 64 x 64 x 64 x 64
        out = F.relu(F.layer_norm(out, out.shape[-3:]))
        out = self.conv_64(out)
        out = F.relu(F.layer_norm(out, out.shape[-3:]))
        t2 = out                                                            # 64 x 64 x 64 x 64

        out = F.max_pool3d(t2, 2, 2)                                        # 64 x 32 x 32 x 32
        out = self.encoder3(out)                                            # 128 x 32 x 32 x 32
        out = F.relu(F.layer_norm(out, out.shape[-3:]))
        out = self.conv_128(out)
        out = F.relu(F.layer_norm(out, out.shape[-3:]))
        t3 = out                                                            # 128 x 32 x 32 x 32

        out = F.max_pool3d(t3, 2, 2)                                        # 128 x 16 x 16 x 16
        out = self.encoder4(out)                                            # 256 x 16 x 16 x 16
        out = F.relu(F.layer_norm(out, out.shape[-3:]))                     
        out = self.conv_256(out)
        out = F.relu(F.layer_norm(out, out.shape[-3:]))
        t4 = out                                                            # 256 x 16 x 16 x 16
        
        out = F.max_pool3d(t4, 2, 2)                                        # 256 x 8 x 8 x 8
        out = self.encoder5(out)                                            # 512 x 8 x 8 x 8
        out = F.relu(F.layer_norm(out, out.shape[-3:]))
        out = self.conv_512(out)
        out = F.relu(F.layer_norm(out, out.shape[-3:]))
        
        # 解码器
        out = self.conv_trans1(out)                                         # 256 x 16 x 16 x 16
        out = self.decoder1(torch.cat([out, t4], dim=1))                    # 256 x 16 x 16 x 16
        out = F.relu(F.layer_norm(out, out.shape[-3:]))
        out = self.conv_256(out)
        out = F.relu(F.layer_norm(out, out.shape[-3:]))
        
        out = self.conv_trans2(out)                                         # 128 x 32 x 32 x 32
        out = self.decoder2(torch.cat([out, t3], dim=1))                    # 128 x 32 x 32 x 32
        out = F.relu(F.layer_norm(out, out.shape[-3:]))
        out = self.conv_128(out)
        out = F.relu(F.layer_norm(out, out.shape[-3:]))

        out = self.conv_trans3(out)                                         # 64 x 64 x 64 x 64
        out = self.decoder3(torch.cat([out, t2], dim=1))                    # 64 x 64 x 64 x 64
        out = F.relu(F.layer_norm(out, out.shape[-3:]))                 
        out = self.conv_64(out)
        out = F.relu(F.layer_norm(out, out.shape[-3:]))
        
        out = self.conv_trans4(out)                                         # 32 x 128 x 128 x 128
        out = self.decoder4(torch.cat([out, t1], dim=1))                    # 32 x 128 x 128 x 128
        out = F.relu(F.layer_norm(out, out.shape[-3:]))                     
        out = self.conv_32(out)
        out = F.relu(F.layer_norm(out, out.shape[-3:]))
        
        out = self.out_conv(out)                                            # out_channels x 128 x 128 x 128
        
        out = self.soft(out)                                                # softmax
        
        return out

In [8]:
from torchsummary import summary

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet3d_ln(in_channels=4, out_channels=4)
input_tensor = torch.randn([1, 4, 128, 128, 128]).float()

model.to(device)
input_tensor = input_tensor.to(device)


out = model(input_tensor)
print(out.shape)
summary(model, (4, 128, 128, 128))

torch.Size([1, 4, 128, 128, 128])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1    [-1, 32, 128, 128, 128]           3,488
            Conv3d-2       [-1, 64, 64, 64, 64]          55,360
            Conv3d-3      [-1, 128, 32, 32, 32]         221,312
            Conv3d-4      [-1, 256, 16, 16, 16]         884,992
            Conv3d-5         [-1, 512, 8, 8, 8]       3,539,456
   ConvTranspose3d-6      [-1, 256, 16, 16, 16]       1,048,832
            Conv3d-7      [-1, 256, 16, 16, 16]       3,539,200
   ConvTranspose3d-8      [-1, 128, 32, 32, 32]         262,272
            Conv3d-9      [-1, 128, 32, 32, 32]         884,864
  ConvTranspose3d-10       [-1, 64, 64, 64, 64]          65,600
           Conv3d-11       [-1, 64, 64, 64, 64]         221,248
  ConvTranspose3d-12    [-1, 32, 128, 128, 128]          16,416
           Conv3d-13    [-1, 32, 128, 128, 128]          55,328
     

In [9]:
from torchsummary import summary

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet3d_ln_double(in_channels=4, out_channels=4)
input_tensor = torch.randn([1, 4, 128, 128, 128]).float()

model.to(device)
input_tensor = input_tensor.to(device)


out = model(input_tensor)
print(out.shape)
summary(model, (4, 128, 128, 128))

torch.Size([1, 4, 128, 128, 128])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1    [-1, 32, 128, 128, 128]           3,488
            Conv3d-2    [-1, 32, 128, 128, 128]          27,680
            Conv3d-3       [-1, 64, 64, 64, 64]          55,360
            Conv3d-4       [-1, 64, 64, 64, 64]         110,656
            Conv3d-5      [-1, 128, 32, 32, 32]         221,312
            Conv3d-6      [-1, 128, 32, 32, 32]         442,496
            Conv3d-7      [-1, 256, 16, 16, 16]         884,992
            Conv3d-8      [-1, 256, 16, 16, 16]       1,769,728
            Conv3d-9         [-1, 512, 8, 8, 8]       3,539,456
           Conv3d-10         [-1, 512, 8, 8, 8]       7,078,400
  ConvTranspose3d-11      [-1, 256, 16, 16, 16]       1,048,832
           Conv3d-12      [-1, 256, 16, 16, 16]       3,539,200
           Conv3d-13      [-1, 256, 16, 16, 16]       1,769,728
  Con

In [11]:
out

In [9]:
a_list = (128 ,128, 128, [1, 1, 1])

b_list = (4, *a_list)

b_list


(4, 128, 128, 128, [1, 1, 1])

In [21]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np


plt.switch_backend('SVG')

# 创建一个简单的神经网络模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(2, 10)
        self.fc2 = nn.Linear(10, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 创建一个数据集
class Dataset(torch.utils.data.Dataset):
    def __init__(self):
        self.X = torch.rand(1000, 2)
        self.y = torch.rand(1000, 2)

    def __getitem__(self, index):
        return self.X[index], self.y[index]

    def __len__(self):
        return len(self.X)

# 创建一个数据加载器
dataset = Dataset()
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

# 创建一个模型和优化器
model = Net()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 定义动画函数
def animate(i):
    for batch_idx, (X, y) in enumerate(dataloader):
        optimizer.zero_grad()
        outputs = model(X)
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()
    ax.clear()
    ax.plot(range(i+1), [loss.item()], 'bo-')
    ax.set_title('Training Loss')
    ax.set_xlabel('Iteration')
    ax.set_ylabel('Loss')
    return ax,

# 创建一个新图形
fig, ax = plt.subplots()

# 创建动画对象
ani = animation.FuncAnimation(fig, animate, frames=100, interval=50, blit=True)

# 显示动画
plt.show()


  plt.show()


In [22]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

# 创建一个新的图形
fig, ax = plt.subplots()

# 初始化x和y值
x = np.arange(0, 2*np.pi, 0.01)
y = np.sin(x)

# 创建一个线条对象
line, = ax.plot(x, y)

# 定义动画函数
def animate(i):
    line.set_ydata(np.sin(x + i / 50))  # 更新y值
    return line,

# 创建动画对象
ani = animation.FuncAnimation(fig, animate, frames=200, interval=20, blit=True)

# 显示动画
plt.show()


  plt.show()
