In [1]:
import os 
os.environ["CUDA_VISIABLE_DEVICES"] = "1"
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:

class DownSampling(nn.Module):
    # 3x3x3 convolution,1 padding as default
    def __init__(self, inChans, outChans, stride=2, kernel_size=3, padding=1):
        super(DownSampling, self).__init__()
        
        self.conv1 = nn.Conv3d(in_channels=inChans, 
                     out_channels=outChans, 
                     kernel_size=kernel_size, 
                     stride=stride,
                     padding=padding,
                     bias=False)
    def forward(self, x):
        return self.conv1(x)
    
class EncoderBlock(nn.Module):
    def __init__(self, inChans, outChans, stride=1, activation="relu", normalizaiton="group_normalization"):
        super(EncoderBlock, self).__init__()
        
        if normalizaiton == "group_normalization":
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=outChans)
            self.norm2 = nn.GroupNorm(num_groups=8, num_channels=outChans)
        if activation == "relu":
            self.actv1 = nn.ReLU(inplace=True)
            self.actv2 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv3d(in_channels=outChans, out_channels=outChans, kernel_size=1, stride=stride)
        self.conv2 = nn.Conv3d(in_channels=outChans, out_channels=outChans, kernel_size=1, stride=stride)
        
        
    def forward(self, x):
        residual = x
        
        out = self.norm1(x)
        out = self.actv1(out)
        out = self.conv1(out)
        out = self.norm2(out)
        out = self.actv2(out)
        out = self.conv2(out)
        
        out += residual
        return out
    
class LinearUpSampling(nn.Module):
    def __init__(self, inChans, outChans, mode="trilinear", align_corners=True):
        super(LinearUpSampling, self).__init__()
        
        self.conv1 = nn.Conv3d(in_channels=inChans, out_channels=inChans, kernel_size=1)
        scale_factor = inChans / outChans
        self.up1 = nn.Upsample(scale_factor=scale_factor,mode=mode, align_corners=align_corners)
    
    def forward(self, x, skipx):
        out = self.conv1(x)
        out = self.up1(out)
        out = torch.cat((out, skipx), 1)
        
        return out
    
class DecoderBlock(nn.Module):
    def __init__(self, inChans, outChans, stride=1, activation="relu", normalizaiton="group_normalization"):
        super(DecoderBlock, self).__init__()
        
        if normalizaiton == "group_normalization":
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=outChans)
            self.norm2 = nn.GroupNorm(num_groups=8, num_channels=outChans)
        if activation == "relu":
            self.actv1 = nn.ReLU(inplace=True)
            self.actv2 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv3d(in_channels=outChans, out_channels=outChans, kernel_size=1, stride=stride)
        self.conv2 = nn.Conv3d(in_channels=outChans, out_channels=outChans, kernel_size=1, stride=stride)
        
        
    def forward(self, x):
        residual = x
        
        out = self.norm1(x)
        out = self.actv1(out)
        out = self.conv1(out)
        out = self.norm2(out)
        out = self.actv2(out)
        out = self.conv2(out)
        
        out += residual
        return out
    
class OutputTransition(nn.Module):
    def __init__(self, inChans, outChans):
        super(OutputTransition, self).__init__()
        
        self.conv1 = nn.Conv3d(in_channels=outChans, out_channels=outChans, kernel_size=1)
        self.actv1 = F.sigmoid
        
    def forward(self, x):
        return self.actv1(self.conv1(x))

class VAE(nn.Module):
    def __init__(self, inChans):
        super(VAE, self).__init__()
        pass
    def forward(self, x):
        pass
        
class NvNet(nn.Module):
    def __init__(self, inChans=4, activation="relu", normalizaiton="group_normalization", mode="bilinear"):
        super(NvNet, self).__init__()
        self.in_conv0 = DownSampling(inChans=inChans, outChans=32, kernel_size=1, stride=1)
        self.en_block0 = EncoderBlock(32, 32, activation=activation, normalizaiton=normalizaiton)
        self.en_down1 = DownSampling(32, 64)
        self.en_block1_0 = EncoderBlock(64, 64, activation=activation, normalizaiton=normalizaiton)
        self.en_block1_1 = EncoderBlock(64, 64, activation=activation, normalizaiton=normalizaiton)
        self.en_down2 = DownSampling(64, 128, kernel_size=1)
        self.en_block2_0 = EncoderBlock(128, 128, activation=activation, normalizaiton=normalizaiton)
        self.en_block2_1 = EncoderBlock(128, 128, activation=activation, normalizaiton=normalizaiton)
        self.en_down3 = DownSampling(128, 256)
        self.en_block3_0 = EncoderBlock(256, 256, activation=activation, normalizaiton=normalizaiton)
        self.en_block3_1 = EncoderBlock(256, 256, activation=activation, normalizaiton=normalizaiton)
        self.en_block3_2 = EncoderBlock(256, 256, activation=activation, normalizaiton=normalizaiton)
        self.en_block3_3 = EncoderBlock(256, 256, activation=activation, normalizaiton=normalizaiton)
        
        self.de_up2 =  LinearUpSampling(256, 128, mode=mode)
        self.de_block2 = DecoderBlock(128, 128, activation=activation, normalizaiton=normalizaiton)
        self.de_up1 =  LinearUpSampling(128, 64, mode=mode)
        self.de_block1 = DecoderBlock(64, 64, activation=activation, normalizaiton=normalizaiton)
        self.de_up0 =  LinearUpSampling(64, 32, mode=mode)
        self.de_block0 = DecoderBlock(32, 32, activation=activation, normalizaiton=normalizaiton)
        self.de_end = OutputTransition(32, 1)
        
    def forward(self, x):
        out_init = self.in_conv0(x)
        out_en0 = self.en_block0(out_init)
        out_en1 = self.en_block1_1(self.en_block1_0(self.en_down1(out_en0))) 
        out_en2 = self.en_block2_1(self.en_block2_0(self.en_down2(out_en1)))
        out_en3 = self.en_block3_3(
            self.en_block3_2(
                self.en_block3_1(
                    self.en_block3_0(
                        self.en_down3(out_en2)))))
        
        out_de2 = self.de_block2(self.de_up2(out_en3, out_en2))
        out_de1 = self.de_block1(self.de_up1(out_de2, out_en1))
        out_de0 = self.de_block0(self.de_up0(out_de1, out_en0))
        
        out_end = self.de_end(out_de0)


In [3]:
test_NvNet = NvNet()
print(test_NvNet)

NvNet(
  (in_conv0): DownSampling(
    (conv1): Conv3d(4, 32, kernel_size=(1, 1, 1), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
  )
  (en_block0): EncoderBlock(
    (norm1): GroupNorm(8, 32, eps=1e-05, affine=True)
    (norm2): GroupNorm(8, 32, eps=1e-05, affine=True)
    (actv1): ReLU(inplace)
    (actv2): ReLU(inplace)
    (conv1): Conv3d(32, 32, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    (conv2): Conv3d(32, 32, kernel_size=(1, 1, 1), stride=(1, 1, 1))
  )
  (en_down1): DownSampling(
    (conv1): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
  )
  (en_block1_0): EncoderBlock(
    (norm1): GroupNorm(8, 64, eps=1e-05, affine=True)
    (norm2): GroupNorm(8, 64, eps=1e-05, affine=True)
    (actv1): ReLU(inplace)
    (actv2): ReLU(inplace)
    (conv1): Conv3d(64, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    (conv2): Conv3d(64, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1))
  )
  (en_block1_1): EncoderBlock(
    (norm1): GroupNorm(8, 64, e

In [4]:
test_NvNet = test_NvNet.cuda(1)
x = torch.randn(1,4,160,192,128).cuda(1)
out = test_NvNet(x)
print(out)
print(out.size())



NotImplementedError: Got 5D input, but bilinear mode needs 4D input

In [None]:
y = torch.rand(5,3)
print(y)

In [None]:
if torch.cuda.is_available():
    device = torch.device("cuda")          # a CUDA device object
    y = torch.ones((3,5), device=device)  # directly create a tensor on GPU
    print(y)
    print(y.type())

In [None]:
input = torch.randn(1, 1, 32, 32)
out = net(input)
print(out)

In [None]:
net.zero_grad()
out.backward(torch.randn(1, 10))

In [None]:
import resnet

In [None]:
class tmp_block(nn.Module):
    def __init__(self, inChans, outChans,stride1=2, stride2=1,activation="relu", normalizaiton="group_normalization"):
        super(tmp_block, self).__init__()
        self.conv1 = nn.Conv3d(in_channels=inChans, out_channels=outChans, kernel_size=3, stride=stride1, padding=1)
        if normalizaiton == "group_normalization":
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=outChans)
            self.norm2 = nn.GroupNorm(num_groups=8, num_channels=outChans)
        if activation == "relu":
            self.actv1 = nn.ReLU(inplace=True)
            self.actv2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(in_channels=outChans, out_channels=outChans, kernel_size=1, stride=stride2)
        self.conv3 = nn.Conv3d(in_channels=outChans, out_channels=outChans, kernel_size=1, stride=stride2)
    def forward(self, x):
        out = self.conv1(x)
        residual = out
        
        out = self.norm1(out)
        out = self.actv1(out)
        out = self.conv2(out)
        out = self.norm2(out)
        out = self.actv2(out)
        out = self.conv3(out)
        
        out += residual
        
        return out

In [None]:
tmp_b = tmp_block(32, 64)
print(tmp_b)

In [None]:
input = torch.randn(1, 32, 160, 192, 128).cuda(1)
# print(input)
out = tmp_b(input)
print(out)

In [None]:
print(input.size())
print(out.size())