![Model](./sketch_model.png)

In [4]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import tensorflow.contrib.layers as tf_layers

  from ._conv import register_converters as _register_converters


[U-Net](https://www.youtube.com/watch?v=81AvQQnpG4Q): explains the concatenation (skip connections) used in the decoder.

In [47]:
def encoder(images,normalizer_fn=tf_layers.batch_norm,activation=tf.nn.leaky_relu,return_skips=False):
    """
    Downsample `images` through seven chained 4x4, stride-2 convolutions.

    images: n*h*w*c input tensor
    normalizer_fn: normalizer applied after every convolution
    activation: activation applied after every convolution
    return_skips: when True, also return the six intermediate activations
        [e1,...,e6] (highest resolution first) so that a U-Net style
        decoder can concatenate them with its own feature maps

    Each layer consumes the output of the previous one, so every layer
    halves the spatial size (stride 2 with conv2d's default SAME padding),
    e.g. 256 -> 128 -> 64 -> 32 -> 16 -> 8 -> 4 -> 2.

    returns `encoded` (the output of the last layer), or the tuple
    (`encoded`, [e1,...,e6]) when return_skips is True
    """
    # output channels of e1..e6 followed by the final `encoded` layer
    widths=(64,128,256,512,512,512,512)

    layers=[]
    net=images
    for width in widths:
        # feed the previous layer's output forward (not the raw images)
        net=tf_layers.conv2d(net,num_outputs=width,kernel_size=4,stride=2,normalizer_fn=normalizer_fn,activation_fn=activation)
        layers.append(net)

    encoded=layers[-1]
    if return_skips:
        return encoded,layers[:-1]
    return encoded
    

In [56]:
# images=tf.placeholder("float",[None,256,256,nc])
# init=tf.global_variables_initializer()
# enc=encoder(images)
# with tf.Session() as sess:
#     sess.run(init)
#     u=sess.run(enc,feed_dict={images:img})

In [33]:
def upsample(x,n_channels,kernel=4,stride=2,activation_fn=tf.nn.leaky_relu,normalizer_fn=tf_layers.batch_norm):
    """
    Enlarge `x` by a factor of `stride` with nearest-neighbour resizing and
    then apply a stride-1 convolution with `n_channels` outputs.

    x: feature map; its static shape must provide the height (index 1)
       and the width (index 2)
    """
    static_shape=x.get_shape()
    target_size=[static_shape[axis].value*stride for axis in (1,2)]
    resized=tf.image.resize_nearest_neighbor(x,target_size)

    return tf_layers.conv2d(
        resized,
        num_outputs=n_channels,
        kernel_size=kernel,
        stride=1,
        normalizer_fn=normalizer_fn,
        activation_fn=activation_fn,
    )
    

In [34]:
def decoder(encoded,out_channels,skips=None):
    """
    Upsample `encoded` seven times, concatenating an encoder activation
    (skip connection) on the channel axis before every step but the first.

    encoded: bottleneck tensor produced by the encoder
    out_channels: number of channels of the returned tensor
    skips: optional sequence of the six encoder activations (e1,...,e6),
        ordered from highest to lowest resolution. When omitted, the names
        e1..e6 are looked up in the enclosing (module) scope, which is what
        this function originally relied on.

    The two deepest decoder layers are passed through dropout with the
    library's default settings.

    returns the decoded tensor
    """
    if skips is None:
        # legacy path: raises NameError unless e1..e6 exist at module level
        skips=(e1,e2,e3,e4,e5,e6)
    s1,s2,s3,s4,s5,s6=skips

    d6=tf_layers.dropout(upsample(encoded,512))
    d5=tf_layers.dropout(upsample(tf.concat([d6,s6],3),512))
    d4=upsample(tf.concat([d5,s5],3),512)
    d3=upsample(tf.concat([d4,s4],3),256)
    d2=upsample(tf.concat([d3,s3],3),128)
    d1=upsample(tf.concat([d2,s2],3),64)
    decoded=upsample(tf.concat([d1,s1],3),out_channels)

    return decoded
    
    

In [38]:
def uNet(images,out_channels=5):
    """
    Run `images` through `encoder` and then `decoder`.

    images: input tensor handed to `encoder`
    out_channels: number of output channels, handed to `decoder`

    returns the tensor produced by `decoder`
    """
    # `encoder` returns one tensor, so it must not be unpacked into two names
    encoded=encoder(images)
    # the decoding function is named `decoder` and needs the channel count
    decoded=decoder(encoded,out_channels)

    return decoded

In [11]:
import os
import cv2

In [52]:
# Load two grayscale images and stack them into one (batch,256,256,nc) array.
nc=2  # number of channels after stacking (one per loaded image)
imgsv=cv2.imread('1.png',cv2.IMREAD_GRAYSCALE)
imgtv=cv2.imread('2.png',cv2.IMREAD_GRAYSCALE)
# cv2.imread returns None instead of raising when a file cannot be read;
# fail here with the offending path rather than later inside np.reshape
for _path,_view in (('1.png',imgsv),('2.png',imgtv)):
    if _view is None:
        raise IOError("could not read image file: %s" % _path)
img=np.dstack((imgsv,imgtv))
img=np.reshape(img,(-1,256,256,nc))

The 5-channel image consists of:  
- channel 1: a depth map  
- channels 2-4: a normal map  
- channel 5: a foreground mask (thresholded at 50%)

#### Losses  
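Depth loss $= \sum_{p} f_{p}\,\lvert d_{p} - \hat{d}_{p} \rvert$, where $d_{p}$ and $\hat{d}_{p}$ are the predicted and ground-truth depth at pixel $p$, and $f_{p} \in \{0, 1\}$ is the foreground mask.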

In [1]:
def depth_loss(pred,truth,mask):
    """
    Mean absolute difference between `pred` and `truth` over the pixels
    selected by `mask`, multiplied by the batch size.

    pred=nxhxwxc (c=1 for a depth map)
    truth=same shape as pred
    mask=nxhxwx1 boolean tensor, True where a pixel contributes
         (tf.boolean_mask expects a boolean mask)

    return loss scalar
    """
    # tf.subtract takes the two operands separately
    loss=tf.abs(tf.subtract(pred,truth))
    # drop the mask's channel axis so it indexes the n,h,w dims of `loss`
    loss=tf.boolean_mask(loss,tf.squeeze(mask,[3]))
    nloss=tf.reduce_mean(loss)

    batch=pred.get_shape()[0].value
    if batch is None:
        # batch size unknown at graph-construction time: read it at run time
        batch=tf.cast(tf.shape(pred)[0],nloss.dtype)

    return nloss*batch

In [1]:
def normal_loss(pred,truth,mask):
    """
    Masked loss for normal maps: `depth_loss` multiplied by the number of
    channels of `pred`.

    pred=nxhxwxc (c=3 for a normal map)
    truth=same shape as pred
    mask=nxhxwx1, forwarded unchanged to depth_loss

    return loss scalar
    """
    loss=depth_loss(pred,truth,mask)
    # .value yields the plain int channel count; the Dimension object itself
    # cannot be multiplied with a float tensor
    nloss=loss*pred.get_shape()[3].value
    return nloss

In [37]:
def mask_loss(pred,truth):
    """
    Binary cross-entropy between predicted and ground-truth masks, summed
    over all elements and divided by the number of elements per example.

    pred, truth: tensors with values in [-1,1]; every dimension after the
        batch dimension must be statically known

    return loss scalar
    """
    #[-1,1] -> [0,1]
    pred=pred*0.5+0.5
    truth=truth*0.5+0.5

    # clamp the log arguments so that log(0) never yields -inf
    eps=1e-6
    loss=tf.multiply(truth,tf.log(tf.maximum(eps,pred)))
    loss=loss+tf.multiply((1-truth),tf.log(tf.maximum(eps,1-pred)))
    loss=tf.reduce_sum(-loss)
    # as_list is a method: call it before slicing off the batch dimension
    per_example=float(np.prod(truth.get_shape().as_list()[1:]))
    nloss=loss/per_example
    return nloss

In [38]:
def total_loss(pred,truth):
    """
    Sum of the depth, normal and mask losses.

    pred=nxhxwxc with channel 0 = depth, channels 1-3 = normal map,
         channel 4 = foreground mask
    truth=same layout as pred

    return loss scalar
    """
    # slice with ranges (0:1, 4:5) so the channel axis is kept: depth_loss
    # squeezes axis 3 of the mask and therefore needs rank-4 inputs
    depth_pred=pred[:,:,:,0:1]
    depth_truth=truth[:,:,:,0:1]
    normal_pred=pred[:,:,:,1:4]
    normal_truth=truth[:,:,:,1:4]
    mask_pred=pred[:,:,:,4:5]
    mask_truth=truth[:,:,:,4:5]

    # tf.boolean_mask needs a boolean mask. NOTE(review): assumes the mask is
    # stored in [-1,1] (as mask_loss's rescaling suggests), so the 50%
    # threshold sits at 0 -- confirm against the data pipeline.
    foreground=tf.greater(mask_truth,0.0)

    dl=depth_loss(depth_pred,depth_truth,foreground)
    nl=normal_loss(normal_pred,normal_truth,foreground)
    ml=mask_loss(mask_pred,mask_truth)

    return (dl+ml+nl)
    