In [1]:
import os
os.environ["CHAINER_TYPE_CHECK"] = "0"

import math
import random
import numpy as np
import chainer
from chainer import cuda, Chain, optimizers, Variable, serializers, Link
import chainer.functions as F
import chainer.links as L

import cv2

In [2]:
class EqualizedConv2d(chainer.Chain):
    """2-D convolution using the "equalized learning rate" trick.

    The kernel is drawn from a unit-variance normal distribution and the
    scaling constant ``sqrt(2 / (in_ch * ksize**2))`` is applied at call time
    instead of at initialisation. The constant multiplies the *input*; since
    convolution is linear this scales the weight term only and leaves the
    bias term unscaled.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        ksize: kernel size (used as ``ksize**2``, so an integer is expected).
        stride: convolution stride, forwarded to ``L.Convolution2D``.
        pad: padding, forwarded to ``L.Convolution2D``.
    """

    def __init__(self, in_ch, out_ch, ksize, stride, pad):
        super(EqualizedConv2d, self).__init__()

        # Runtime scale derived from the fan-in of the kernel.
        fan_in = in_ch * ksize ** 2
        self.inv_c = np.sqrt(2.0 / fan_in)

        with self.init_scope():
            self.c = L.Convolution2D(
                in_ch, out_ch, ksize, stride, pad,
                initialW=chainer.initializers.Normal(1.0),  # equalized learning rate
            )

    def __call__(self, x):
        """Scale ``x`` by the stored constant, then apply the convolution."""
        scaled_input = self.inv_c * x
        return self.c(scaled_input)

In [3]:
def feature_vector_normalization(x, eps=1e-8):
    """Normalise ``x`` along axis 1 to unit root-mean-square.

    Each element is divided by ``sqrt(mean(x**2 over axis 1) + eps)``.

    Args:
        x: a Chainer variable; ``x.data.shape`` is read, so it must expose
           ``.data``. Axis 1 is presumably the channel axis -- confirm
           against callers.
        eps: small constant added under the square root to avoid dividing by
           zero.

    Returns:
        A variable with the same shape as ``x``.
    """
    mean_square = F.mean(x * x, axis=1, keepdims=True)
    inv_rms = 1.0 / F.sqrt(mean_square + eps)
    # inv_rms has size 1 on axis 1; expand it explicitly before multiplying.
    return F.broadcast_to(inv_rms, x.data.shape) * x

In [4]:
class DownSampleBlock(chainer.Chain):
    """Encoder stage: 2x2 average pooling followed by two 3x3 convolutions.

    The first convolution keeps ``in_ch`` channels and the second maps
    ``in_ch`` to ``out_ch``. Each convolution is followed by feature-vector
    normalisation and a leaky ReLU.

    Args:
        in_ch: channels of the incoming feature map.
        out_ch: channels of the produced feature map.
    """

    def __init__(self, in_ch, out_ch):
        super(DownSampleBlock, self).__init__()
        with self.init_scope():
            self.c0 = EqualizedConv2d(in_ch, in_ch, 3, 1, 1)
            self.c1 = EqualizedConv2d(in_ch, out_ch, 3, 1, 1)

    def __call__(self, x):
        """Apply pooling -> conv -> conv and return the result."""
        # Kernel 2, stride 2, no padding.
        h = F.average_pooling_2d(x, 2, 2, 0)

        for conv in (self.c0, self.c1):
            h = F.leaky_relu(feature_vector_normalization(conv(h)))

        return h

In [5]:
class UpSampleBlock(chainer.Chain):
    """Decoder stage: two 3x3 convolutions followed by 2x unpooling.

    The first convolution maps ``in_ch`` to ``out_ch`` and the second keeps
    ``out_ch``. Each convolution is followed by feature-vector normalisation
    and a leaky ReLU.

    Args:
        in_ch: channels of the incoming feature map.
        out_ch: channels of the produced feature map.
    """

    def __init__(self, in_ch, out_ch):
        super(UpSampleBlock, self).__init__()
        with self.init_scope():
            self.c0 = EqualizedConv2d(in_ch, out_ch, 3, 1, 1)
            self.c1 = EqualizedConv2d(out_ch, out_ch, 3, 1, 1)

    def __call__(self, x):
        """Apply conv -> conv -> upsampling and return the result."""
        h = x
        for conv in (self.c0, self.c1):
            h = F.leaky_relu(feature_vector_normalization(conv(h)))

        # Request an output of exactly twice the current height and width.
        doubled = (h.shape[2] * 2, h.shape[3] * 2)
        return F.unpooling_2d(h, 2, 2, 0, outsize=doubled)

In [6]:
class UNet(chainer.Chain):
    """U-Net style encoder/decoder built from equalized convolutions.

    Layout (channel widths come from ``self.R``):
      * two stem convolutions at full resolution,
      * four ``DownSampleBlock`` stages plus one more as the "middle",
      * one unpooling and four ``UpSampleBlock`` stages, each fed with the
        concatenation of the previous decoder output and the encoder output
        of matching resolution,
      * two convolutions on the concatenation with the stem output, and a
        final convolution to ``out_ch`` followed by ReLU.

    NOTE(review): the input is halved five times and then doubled five
    times, so height and width presumably need to be divisible by 32 for the
    skip concatenations to line up -- confirm.

    Args:
        in_ch: channels of the input image.
        out_ch: channels of the output map.
    """

    def __init__(self, in_ch, out_ch):
        super(UNet, self).__init__()

        # Channel widths per resolution level, finest level first.
        self.R = (32, 64, 128, 256, 512)
        r0, r1, r2, r3, r4 = self.R

        with self.init_scope():
            # stem
            self.c0 = EqualizedConv2d(in_ch, r0, 3, 1, 1)
            self.c1 = EqualizedConv2d(r0, r0, 3, 1, 1)

            # encoder
            self.d1 = DownSampleBlock(r0, r1)
            self.d2 = DownSampleBlock(r1, r2)
            self.d3 = DownSampleBlock(r2, r3)
            self.d4 = DownSampleBlock(r3, r4)

            # bottleneck
            self.m0 = DownSampleBlock(r4, r4)

            # decoder; inputs are doubled by the skip concatenation
            self.u4 = UpSampleBlock(r4 * 2, r3)
            self.u3 = UpSampleBlock(r3 * 2, r2)
            self.u2 = UpSampleBlock(r2 * 2, r1)
            self.u1 = UpSampleBlock(r1 * 2, r0)

            # head
            self.c2 = EqualizedConv2d(r0 * 2, r0, 3, 1, 1)
            self.c3 = EqualizedConv2d(r0, r0, 3, 1, 1)

            self.out = EqualizedConv2d(r0, out_ch, 3, 1, 1)

    def __call__(self, x):
        """Run the network on ``x`` and return the ReLU-activated output map."""

        def activate(v):
            # Shared post-convolution step: normalise, then leaky ReLU.
            return F.leaky_relu(feature_vector_normalization(v))

        stem = activate(self.c0(x))
        skip0 = activate(self.c1(stem))

        # encoder path; every output is kept for a skip connection
        skip1 = self.d1(skip0)
        skip2 = self.d2(skip1)
        skip3 = self.d3(skip2)
        skip4 = self.d4(skip3)

        # bottleneck, then unpool to twice its height and width
        bottom = self.m0(skip4)
        doubled = (bottom.shape[2] * 2, bottom.shape[3] * 2)
        bottom = F.unpooling_2d(bottom, 2, 2, 0, outsize=doubled)

        # decoder path: concatenate along axis 1 with the matching skip
        up = bottom
        decoder_stages = (
            (self.u4, skip4),
            (self.u3, skip3),
            (self.u2, skip2),
            (self.u1, skip1),
        )
        for block, skip in decoder_stages:
            up = block(F.concat([up, skip], axis=1))

        # head
        merged = activate(self.c2(F.concat([up, skip0], axis=1)))
        merged = activate(self.c3(merged))

        return F.relu(self.out(merged))

    def dice_loss(self, y_true, y_pred):
        """Return a Dice-style overlap score between ``y_true`` and ``y_pred``.

        Computes ``2 * sum(y_true * y_pred) / (sum(y_true) + sum(y_pred))``
        with every sum divided by the batch size and ``1e-8`` added to the
        denominator. Despite the method name the value returned is the
        overlap score itself; the training cell in this notebook derives the
        loss as ``1.0 - dice``.

        Args:
            y_true: target variable; ``y_true.data.shape[0]`` is used as the
                batch size.
            y_pred: predicted variable, multiplied element-wise with
                ``y_true``.
        """
        n_samples = y_true.data.shape[0]

        overlap = F.sum(y_true * y_pred) / n_samples
        total = F.sum(y_true) / n_samples + F.sum(y_pred) / n_samples + 1e-8

        return 2.0 * overlap / total
        

In [7]:
def find_all_file(directory):
    """Recursively yield every directory and file path under ``directory``.

    For each directory visited by ``os.walk`` the directory path itself is
    yielded first, followed by the joined path of every file it contains.
    A non-existent ``directory`` yields nothing.
    """
    for current_dir, _subdirs, filenames in os.walk(directory):
        yield current_dir
        yield from (os.path.join(current_dir, name) for name in filenames)

In [8]:
# Build the training arrays from './base'.
#
# Every path that contains both '.jpg' and 'train' is treated as an input
# image; its label map is read from the same path with 'train' replaced by
# 'label'. Label pixels are interpreted as class ids in [0, out_dim).
imgs = []
labels = []

out_dim = 17  # number of label classes / one-hot planes

width = 256
height = 256

# Column vector of class ids, shaped to broadcast against a (H, W) label map.
class_ids = np.arange(out_dim).reshape(-1, 1, 1)

for file in find_all_file('./base'):
    if file.find('.jpg') == -1: continue
    if file.find('train') == -1: continue

    label_file = file.replace('train', 'label')

    # cv2.imread signals failure by returning None rather than raising, so
    # check explicitly and report which path is the problem.
    img = cv2.imread(file)
    if img is None:
        raise IOError("could not read training image: " + file)
    label = cv2.imread(label_file, cv2.IMREAD_GRAYSCALE)
    if label is None:
        raise IOError("could not read label image: " + label_file)

    img = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
    # HWC -> CHW, and map pixel values from [0, 255] to roughly [-1, 1).
    img = img.transpose(2,0,1).astype("f")/128.0-1.0

    # Label pixels are categorical ids, so they must be resized with
    # nearest-neighbour sampling: an averaging filter such as INTER_AREA
    # blends neighbouring ids into values that belong to neither class.
    label = cv2.resize(label, (width, height), interpolation=cv2.INTER_NEAREST)

    # One-hot encode: plane j is 1.0 where the label equals j, else 0.0.
    label_dim = (label[np.newaxis, :, :] == class_ids).astype("f")

    imgs.append(img)
    labels.append(label_dim)

imgs = np.asarray(imgs).astype('f')
labels = np.asarray(labels).astype('f')

In [9]:
# Instantiate the network: 3 input channels, one output channel per class.
unet = UNet(3, out_dim)
# NOTE(review): batch_size is not referenced by the visible training cell,
# which defines and uses its own batch_num -- confirm which one is intended.
batch_size=2

# Adam with Chainer's default hyper-parameters.
optimizer = optimizers.Adam()
# Alternative hyper-parameters kept for reference:
#optimizer = optimizers.Adam(alpha=0.001, beta1=0.0, beta2=0.99)
optimizer.setup(unet)


    

<chainer.optimizers.adam.Adam at 0x7f26cab08be0>

In [None]:
# Number of training samples = size of the leading axis of the image array.
train_num = imgs.shape[0]
print(train_num)

161


In [None]:
# Train the U-Net with a (1 - Dice) objective on CPU arrays.
batch_num = 4   # samples per mini-batch
xp = np         # array module used to build the batches
n_epoch = 100

for epoch in range(n_epoch):

    # Checkpoint taken *before* this epoch's updates, so "unet<k>.npz" holds
    # the weights after k completed epochs.
    serializers.save_npz("unet" + str(epoch) + ".npz", unet)

    # Visit the samples in a fresh random order every epoch.
    perm = np.random.permutation(train_num)
    for i in range(0, train_num, batch_num):

        # The last slice may be shorter than batch_num.
        indies = perm[i:i+batch_num]

        # image and label batch
        image_batch = Variable(xp.asarray(imgs[indies], dtype=np.float32))
        label_batch = Variable(xp.asarray(labels[indies], dtype='f'))

        fake_label = unet(image_batch)

        # dice_loss returns the overlap score; the loss is its complement.
        dice = unet.dice_loss(label_batch, fake_label)
        loss = 1.0 - dice

        unet.cleargrads()
        loss.backward()
        optimizer.update()

        print(i, loss.data, dice.data)

# The in-loop checkpoint runs before each epoch, so without this extra save
# the weights produced by the final epoch would never be written to disk.
serializers.save_npz("unet" + str(n_epoch) + ".npz", unet)

0 0.9598525762557983 0.040147412568330765
4 0.9535766839981079 0.0464232936501503
8 0.9371452927589417 0.06285470724105835
12 0.8921831250190735 0.1078168973326683
16 0.86331707239151 0.1366829127073288
20 0.8462271094322205 0.15377290546894073
24 0.8831335306167603 0.11686649918556213
28 0.8957775831222534 0.10422242432832718
32 0.805911123752594 0.194088876247406
36 0.820465624332428 0.17953436076641083
40 0.7619314193725586 0.2380686104297638
44 0.7761659622192383 0.2238340675830841
48 0.7730020880699158 0.22699792683124542
52 0.7813543081283569 0.21864570677280426
56 0.7008681297302246 0.299131840467453
60 0.7101012468338013 0.28989872336387634
64 0.8272091746330261 0.17279081046581268
68 0.7123072147369385 0.2876928150653839
72 0.6952790021896362 0.3047209680080414
76 0.645759642124176 0.354240357875824
80 0.7181050777435303 0.28189489245414734
84 0.6776475310325623 0.32235246896743774
88 0.670186460018158 0.32981353998184204
92 0.6739155650138855 0.3260844349861145
96 0.679705739