In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils as utils
import torch.nn.init as init
import torch.utils.data as data
import torchvision.utils as v_utils
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.autograd import Variable

def conv_block_3d(in_dim,out_dim,act_fn):
    model = nn.Sequential(
        nn.Conv3d(in_dim,out_dim, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm3d(out_dim),
        act_fn,
    )
    return model

def conv_trans_block_3d(in_dim,out_dim,act_fn):
    model = nn.Sequential(
        nn.ConvTranspose3d(in_dim,out_dim, kernel_size=3, stride=2, padding=1,output_padding=1),
        nn.BatchNorm3d(out_dim),
        act_fn,
    )
    return model

def maxpool_3d():
    pool = nn.MaxPool3d(kernel_size=2, stride=2, padding=0)
    return pool

def conv_block_2_3d(in_dim,out_dim,act_fn):
    model = nn.Sequential(
        conv_block_3d(in_dim,out_dim,act_fn),
        nn.Conv3d(out_dim,out_dim, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm3d(out_dim),
    )
    return model    

def conv_block_3_3d(in_dim,out_dim,act_fn):
    model = nn.Sequential(
        conv_block_3d(in_dim,out_dim,act_fn),
        conv_block_3d(out_dim,out_dim,act_fn),
        nn.Conv3d(out_dim,out_dim, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm3d(out_dim),
    )
    return model

class UnetGenerator_3d(nn.Module):

    def __init__(self,in_dim,out_dim,num_filter):
        super(UnetGenerator_3d,self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_filter = num_filter
        act_fn = nn.LeakyReLU(0.2, inplace=True)

        print("\n------\tInitiating U-Net\t------\t\n")
        
        self.down_1 = conv_block_2_3d(self.in_dim,self.num_filter,act_fn)
        self.pool_1 = maxpool_3d()
        self.down_2 = conv_block_2_3d(self.num_filter,self.num_filter*2,act_fn)
        self.pool_2 = maxpool_3d()
        self.down_3 = conv_block_2_3d(self.num_filter*2,self.num_filter*4,act_fn)
        self.pool_3 = maxpool_3d()
        
        self.bridge = conv_block_2_3d(self.num_filter*4,self.num_filter*8,act_fn)
        
        self.trans_1 = conv_trans_block_3d(self.num_filter*8,self.num_filter*8,act_fn)
        self.up_1 = conv_block_2_3d(self.num_filter*12,self.num_filter*4,act_fn)
        self.trans_2 = conv_trans_block_3d(self.num_filter*4,self.num_filter*4,act_fn)
        self.up_2 = conv_block_2_3d(self.num_filter*6,self.num_filter*2,act_fn)
        self.trans_3 = conv_trans_block_3d(self.num_filter*2,self.num_filter*2,act_fn)
        self.up_3 = conv_block_2_3d(self.num_filter*3,self.num_filter*1,act_fn)
        
        self.out = conv_block_3d(self.num_filter,out_dim,act_fn)

        print("------\tComplete\t------")

    def forward(self,x):
        down_1 = self.down_1(x)
        pool_1 = self.pool_1(down_1)
        down_2 = self.down_2(pool_1)
        pool_2 = self.pool_2(down_2)
        down_3 = self.down_3(pool_2)
        pool_3 = self.pool_3(down_3)
        
        bridge = self.bridge(pool_3)
        
        trans_1  = self.trans_1(bridge)
        concat_1 = torch.cat([trans_1,down_3],dim=1)
        up_1     = self.up_1(concat_1)
        trans_2  = self.trans_2(up_1)
        concat_2 = torch.cat([trans_2,down_2],dim=1)
        up_2     = self.up_2(concat_2)
        trans_3  = self.trans_3(up_2)
        concat_3 = torch.cat([trans_3,down_1],dim=1)
        up_3     = self.up_3(concat_3)
        
        out = self.out(up_3)
                        
        return out

def loss_function(output,label):
    batch_len, channel,x,y,z = output.size()
    total_loss = 0
    for i in range(batch_len):    
        for j in range(z):
            loss = 0
            output_z = output[i:i+1,:,:,:,j]
            label_z = label[i,:,:,:,j]
            
            softmax_output_z = nn.Softmax2d()(output_z)
            logsoftmax_output_z = torch.log(softmax_output_z)
            
            loss = nn.NLLLoss2d()(logsoftmax_output_z,label_z)
            total_loss += loss
            
    return total_loss+1e-6

In [2]:
in_dim = 3
out_dim = 2
num_filter = 16

model = UnetGenerator_3d(in_dim=in_dim,
                         out_dim=out_dim,
                         num_filter=num_filter)


------	Initiating U-Net	------	

------	Complete	------


In [3]:
batch_size = 2
channel =3
voxel_size = 16

data = torch.zeros(batch_size,
                    channel,
                    voxel_size,voxel_size,voxel_size)

data[:,:,2:5,3:6,2:9]=1.0

data.size(),torch.sum(data[data==1])

(torch.Size([2, 3, 16, 16, 16]), 378.0)

In [4]:
label = torch.zeros(batch_size,1,16,16,16).type_as(torch.LongTensor())
label[:,:,2:10,3:7,4:9]=1
label_var = Variable(label)

#print(label)
print(16**3,"\n",torch.sum(label_var),"\n",torch.sum(label)/(16**3))

4096 
 Variable containing:
 320
[torch.LongTensor of size 1]
 
 0.078125


In [5]:
#output = model(Variable(data))

output = torch.zeros(batch_size,2,16,16,16)
output[:,1,2:10,3:7,4:9]=1
output = Variable(output)

loss = loss_function(output,label_var)

print(output.data.type(),"\n",loss)

torch.FloatTensor 
 Variable containing:
 21.7058
[torch.FloatTensor of size 1]



In [6]:
output_idx = output[:,0,:,:,:] < output[:,1,:,:,:]
#print(output_idx)
torch.sum(output_idx)

Variable containing:
 64
[torch.ByteTensor of size 1]

In [7]:
import numpy as np
def dice_score(output,label_var):
    batch_size = output.size()[0]
    #print(batch_size)
    
    np_output = output.data.numpy()
    np_label = label.numpy()
    
    #print(np_output.shape,np_label.shape)
    
    output_idx = np_output[:,0,:,:,:] < np_output[:,1,:,:,:]
    output_idx = output_idx.astype(np.int64)
    #print(output_idx.sum(),np_label.sum())
    
    denominator = output_idx.sum()+np_label.sum()
    
    #print(denominator.dtype, output_idx.dtype,np_label.dtype)
    
    #print(output_idx.sum())
    output_idx[output_idx==0]=-1
    #print(output_idx.sum())
    
    output_idx = output_idx.reshape([batch_size,1,voxel_size,voxel_size,voxel_size])
    #print(output_idx.shape,np_label.shape)

    intersection = output_idx==np_label
    intersection = intersection.astype(np.int64)
    #print(intersection.shape)
    
    #print(intersection.sum(),denominator)
    
    return intersection.sum()*2/denominator

In [8]:
dice_score(output,label_var)

1.0

In [9]:
optimizer = optim.Adam(model.parameters(),lr=0.002)

In [10]:
for i in range(10000):
    optimizer.zero_grad()
    output = model(Variable(data))
    loss = loss_function(output,label_var)
    loss.backward()
    if i % 10 ==0:
        d_score = dice_score(output,label_var)*100
        print(d_score)
        if d_score ==100:
            break
    optimizer.step()

12.016175621
63.4920634921
59.8130841121
61.3026819923
62.1359223301
62.9921259843
65.5737704918
72.8929384966
78.239608802
86.7208672087
93.567251462
96.6767371601
97.8593272171
99.0712074303
99.6884735202
100.0
