In [4]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import math 

device = 'cuda'

In [3]:
# ((W - K + 2p)/s) + 1

In [21]:
class ResNet_block(nn.Module):
    "A ResNet-like block with the GroupNorm normalization providing optional bottle-neck functionality"
    def __init__(self, ch, k_size=3, stride=1, p=1, num_groups=1):
        super(ResNet_block, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(ch, ch, kernel_size=k_size, stride=stride, padding=p), 
            nn.BatchNorm3d(ch),
            nn.ReLU(inplace=True), 
            nn.Conv3d(ch, ch, kernel_size=k_size, stride=stride, padding=p),  
            nn.BatchNorm3d(ch),
            nn.ReLU(inplace=True),
        )
    def forward(self, x):
        out = self.conv(x) + x
        return out

# ch = 1
# h = 113
# w = 113
# s = 137
# b = 2
# input = torch.randn(b, ch, s, h, w).to(torch.float32).to(device)
# net = ResNet_block(ch, 3).to(device)
# out = net(input)
# print(out.shape)

In [20]:
class conv_block(nn.Module):
    def __init__(self, ch_in, ch_out, k_size=3, stride=1, p=1, num_groups=1):
        super(conv_block, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(ch_in, ch_out, kernel_size=k_size, stride=stride, padding=p),  
            nn.BatchNorm3d(ch_out),
            nn.ReLU(inplace=True),
        )
    def forward(self, x):
        out = self.conv(x)
        return out

# ch_in = 128
# h = 113
# w = 113
# s = 137
# b = 2
# ch_out = 64 
# k_size = 3
# input = torch.randn(b, ch_in, s, h, w).to(torch.float32).to(device)
# net = conv_block(ch_in, ch_out, k_size).to(device)
# out = net(input)
# print(out.shape)

In [6]:
class up_conv(nn.Module):
    "Reduce the number of features by 2 using Conv with kernel size 1x1x1 and double the spatial dimension using 3D trilinear upsampling"
    def __init__(self, ch_in, ch_out, k_size=(1,2,2), stride=(1,2,2), p=(0,0,0)):
        super(up_conv, self).__init__()
        self.up = nn.Sequential(
            #nn.Conv3d(ch_in, ch_out, kernel_size=k_size),
            #nn.Upsample(scale_factor=scale, mode='trilinear', align_corners=align_corners),
            nn.ConvTranspose3d(ch_in, ch_out, k_size, stride=stride, padding=p),
            nn.BatchNorm3d(ch_out),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.up(x)


    
# ch_in = 128
# h = 128
# w = 128
# s = 137
# b = 2
# ch_out = 64 
# k_size = 2
# input = torch.randn(b, ch_in, s, h, w).to(torch.float32).to(device)
# net = up_conv(ch_in, ch_out).to(device)
# out = net(input)
# print(out.shape)

In [18]:
up = nn.ConvTranspose3d(8, 8, (1,2,2), stride=(2,2,2), padding=(0,0,0)).to(device)
scale = (137/68,2,2)
up = nn.Upsample(scale_factor=scale, mode='trilinear', align_corners=True)

b = 2
ch_in = 8
s = 68
h = 16
w = 16
input = torch.randn(b, ch_in, s, h, w).to(torch.float32).to(device)
out = up(input)
print(out.shape)

torch.Size([2, 8, 137, 32, 32])


In [17]:
137/68

2.014705882352941

In [33]:
class up_sample_conv(nn.Module):
    "Reduce the number of features by 2 using Conv with kernel size 1x1x1 and double the spatial dimension using 3D trilinear upsampling"
    def __init__(self, ch_in, ch_out, scale, k_size=3, stride=1, p=1, align_corners=True):
        super(up_sample_conv, self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=scale, mode='trilinear', align_corners=align_corners),
            nn.Conv3d(ch_in, ch_out, kernel_size=k_size, stride=stride, padding=p),
            nn.BatchNorm3d(ch_out),
            nn.ReLU(inplace=True)
        )
    def forward(self, x):
        return self.up(x)

scale = (137/68,2,2)

b = 2
ch_in = 8
ch_out = 8
s = 68
h = 16
w = 16
input = torch.randn(b, ch_in, s, h, w).to(torch.float32).to(device)
up_sample = up_sample_conv(ch_in, ch_out, scale).to(device)
out = up_sample(input)
print(out.shape)

torch.Size([2, 8, 137, 32, 32])


In [39]:
# 128 -> 64 -> 32 -> 16 -> 8
class Encoder(nn.Module):
    """ Encoder module """
    def __init__(self):
        super(Encoder, self).__init__()
        
        start_val = 4
        self.conv1 = conv_block(ch_in=1, ch_out=start_val, k_size=3, num_groups=1)
        self.res_block1 = ResNet_block(ch=start_val, k_size=3, num_groups=8)
        self.MaxPool1 = nn.MaxPool3d(3, stride=2, padding=1)
#         self.MaxPool1 = nn.MaxPool3d((3,3,3), stride=(2,2,2), padding=(0,1,1))

        self.conv2 = conv_block(ch_in=start_val, ch_out=start_val*2, k_size=3, num_groups=8)
        self.res_block2 = ResNet_block(ch=start_val*2, k_size=3, num_groups=16)
        self.MaxPool2 = nn.MaxPool3d(3, stride=2, padding=1)
#         self.MaxPool2 = nn.MaxPool3d((3,3,3), stride=(1,2,2), padding=(0,1,1))
        

        self.conv3 = conv_block(ch_in=start_val*2, ch_out=start_val*4, k_size=3, num_groups=16)
        self.res_block3 = ResNet_block(ch=start_val*4, k_size=3, num_groups=16)
        self.MaxPool3 = nn.MaxPool3d(3, stride=2, padding=1)
#         self.MaxPool3 = nn.MaxPool3d((3,3,3), stride=(1,2,2), padding=(0,1,1))

        self.conv4 = conv_block(ch_in=start_val*4, ch_out=start_val*8, k_size=3, num_groups=16)
        self.res_block4 = ResNet_block(ch=start_val*8, k_size=3, num_groups=16)
        self.MaxPool4 = nn.MaxPool3d(3, stride=2, padding=1)
#         self.MaxPool4 = nn.MaxPool3d((3,3,3), stride=(1,2,2), padding=(0,1,1))

        self.reset_parameters()
      
    def reset_parameters(self):
        for weight in self.parameters():
            stdv = 1.0 / math.sqrt(weight.size(0))
            torch.nn.init.uniform_(weight, -stdv, stdv)

    def forward(self, x):
        x1 = self.conv1(x)
        x1 = self.res_block1(x1)
        x1 = self.MaxPool1(x1) # torch.Size([1, 32, 26, 31, 26])
        
        x2 = self.conv2(x1)
        x2 = self.res_block2(x2)
        x2 = self.MaxPool2(x2) # torch.Size([1, 64, 8, 10, 8])

        x3 = self.conv3(x2)
        x3 = self.res_block3(x3)
        x3 = self.MaxPool3(x3) # torch.Size([1, 128, 2, 3, 2])
        
        x4 = self.conv4(x3)
        x4 = self.res_block4(x4) # torch.Size([1, 256, 2, 3, 2])
        x4 = self.MaxPool4(x4) # torch.Size([1, 256, 1, 1, 1])
#         print("x1 shape: ", x1.shape)
#         print("x2 shape: ", x2.shape)
#         print("x3 shape: ", x3.shape)
#         print("x4 shape: ", x4.shape) 
        return x4

ch_in = 1
h = 128
w = 128
s = 137
b = 2
ch_out = 64 
k_size = 2
input = torch.randn(b, ch_in, s, h, w).to(torch.float32).to(device)
net = Encoder().to(device)
out = net(input)
print(out.shape)

torch.Size([2, 32, 9, 8, 8])


In [24]:
32*9*8*8

18432

In [38]:
class Decoder(nn.Module):
    """ Decoder Module """
    def __init__(self, latent_dim, prev_dim = 32*9*8*8):
        super(Decoder, self).__init__()
        self.latent_dim = latent_dim
        
        #self.linear_up = nn.Linear(latent_dim, 256*9*8*8)
        
        self.linear_up = nn.Sequential(nn.Linear(latent_dim, prev_dim//8),
                                   nn.BatchNorm1d(prev_dim//8),
                                   nn.ReLU(),
                                   nn.Linear(prev_dim//8, prev_dim),
                                   nn.BatchNorm1d(prev_dim),
                                   nn.ReLU())

        #up_sample_conv(ch_in, ch_out, scale)
        #self.relu = nn.ReLU()
        
        self.upsize4 = up_sample_conv(ch_in=32, ch_out=32, scale=(18/9, 2, 2))
        self.res_block4 = ResNet_block(ch=32, k_size=3, num_groups=16)
        
        self.upsize3 = up_sample_conv(ch_in=32, ch_out=16, scale=(35/18, 2, 2))
        self.res_block3 = ResNet_block(ch=16, k_size=3, num_groups=16)        
        
        self.upsize2 = up_sample_conv(ch_in=16, ch_out=8, scale=(69/35, 2, 2))
        self.res_block2 = ResNet_block(ch=8, k_size=3, num_groups=16)   
        
        self.upsize1 = up_sample_conv(ch_in=8, ch_out=4, scale=(137/69, 2, 2))
        self.res_block1 = ResNet_block(ch=4, k_size=3, num_groups=1)
        
        self.out_conv = conv_block(ch_in=4, ch_out=1)


        self.reset_parameters()
      
    def reset_parameters(self):
        for weight in self.parameters():
            stdv = 1.0 / math.sqrt(weight.size(0))
            torch.nn.init.uniform_(weight, -stdv, stdv)

    def forward(self, x):
        x4_ = self.linear_up(x)
        #x4_ = self.relu(x4_)
        #print('x4 shape - ',x4_.shape)

        x4_ = x4_.view(-1, 32, 9, 8, 8)
        #print()
        # x4_ = x4_.view(-1, 256, 9, 8, 8)

        x4_ = self.upsize4(x4_) 
        x4_ = self.res_block4(x4_)
        print(x4_.shape)

        x3_ = self.upsize3(x4_) 
        x3_ = self.res_block3(x3_)
        print(x3_.shape)
        
        x2_ = self.upsize2(x3_) 
        x2_ = self.res_block2(x2_)
        print(x2_.shape)

        x1_ = self.upsize1(x2_)
        # print('last layer', x1_.shape) 
        x1_ = self.res_block1(x1_)
        print(x1_.shape)
        
        out = self.out_conv(x1_)
        
        #print("x1 shape: ", x1_.shape)
        #print("x2 shape: ", x2_.shape)
        #print("x3 shape: ", x3_.shape)
        #print("x4 shape: ", x4_.shape) 
        
        return out

ch_in = 1
h = 128
w = 128
s = 137
b = 2
ch_out = 64 
k_size = 2
input = torch.randn(b, 512).to(torch.float32).to(device)
net = Decoder(latent_dim=512, prev_dim = 32*9*8*8).to(device)
out = net(input)
print(out.shape)

torch.Size([2, 32, 18, 16, 16])
torch.Size([2, 16, 35, 32, 32])
torch.Size([2, 8, 69, 64, 64])
torch.Size([2, 4, 137, 128, 128])
torch.Size([2, 1, 137, 128, 128])


In [25]:
latent_dim = 512
prev_dim = 32*9*8*8
linear_up = nn.Sequential(nn.Linear(latent_dim, prev_dim//8),
                                   nn.BatchNorm1d(prev_dim//8),
                                   nn.ReLU(),
                                   nn.Linear(prev_dim//8, prev_dim),
                                   nn.BatchNorm1d(latent_dim),
                                   nn.ReLU()).to(device)

In [14]:
prev_dim//8

35072

In [15]:
prev_dim

280576

In [40]:
class VAE(nn.Module):
    def __init__(self, in_dim = 32*9*8*8, latent_dim=512):
        super(VAE, self).__init__()
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.latent_dim = latent_dim

        self.z_mean = nn.Sequential(nn.Linear(in_dim, in_dim//8),
                                   nn.BatchNorm1d(in_dim//8),
                                   nn.ReLU(),
                                   nn.Linear(in_dim//8, latent_dim),
                                   nn.BatchNorm1d(latent_dim),
                                   nn.ReLU())
        
        self.z_log_sigma = nn.Sequential(nn.Linear(in_dim, in_dim//8),
                                   nn.BatchNorm1d(in_dim//8),
                                   nn.ReLU(),
                                   nn.Linear(in_dim//8, latent_dim),
                                   nn.BatchNorm1d(latent_dim),
                                   nn.ReLU())
        
        # self.z_log_sigma = nn.Linear(256*9*8*8, latent_dim)
        
        #self.z_mean = nn.Linear(256*150, latent_dim)
        #self.z_log_sigma = nn.Linear(256*150, latent_dim)
        
        self.epsilon = torch.normal(size=(1, latent_dim), mean=0, std=1.0, device=self.device)
        self.encoder = Encoder()
        self.decoder = Decoder(latent_dim)

        self.reset_parameters()
      
    def reset_parameters(self):
        for weight in self.parameters():
            stdv = 1.0 / math.sqrt(weight.size(0))
            torch.nn.init.uniform_(weight, -stdv, stdv)

    def forward(self, x):
        x = self.encoder(x)
        #print('encoder shape -', x.shape)
        x = torch.flatten(x, start_dim=1)
        # print(x.shape)
        z_mean = self.z_mean(x)
        z_log_sigma = self.z_log_sigma(x)
        z = z_mean + z_log_sigma.exp()*self.epsilon
        # print(x.shape)
        y = self.decoder(z)
        return y, z_mean, z_log_sigma


model = VAE()
model = model.to(device)
# B, C, Seq, H, W
# input = torch.randn(20, 16, 10, 50, 100).to(torch.float32).to(device)
input = torch.randn(2, 1, 137, 128, 128).to(torch.float32).to(device)

out = model(input)
print(out[0].shape)
print(out[1].shape)
print(out[2].shape)

torch.Size([2, 32, 18, 16, 16])
torch.Size([2, 16, 35, 32, 32])
torch.Size([2, 8, 69, 64, 64])
torch.Size([2, 4, 137, 128, 128])
torch.Size([2, 1, 137, 128, 128])
torch.Size([2, 512])
torch.Size([2, 512])


In [None]:
# torch.Size([2, 32, 9, 8, 8])