In [None]:
%matplotlib inline
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import random
import util
from util import *
from PIL import Image
from IPython.display import display

In [None]:
# code to read weights
import re

# Parse a text dump of printed Torch tensors into a list of numpy arrays.
# Each tensor ends with a footer line "[torch.CudaTensor of size AxBx...]";
# the lines collected before that footer hold the tensor's values.
params = []
end_re = re.compile(r'^\[torch.CudaTensor of size (\w+)\]')
index_re = re.compile(r'^\((\d+),(\d+)')
# A header such as "0.01 *" in front of the values: the numbers that follow are
# printed in units of that factor, so they have to be multiplied back.
scale_re = re.compile(r'^\s*([-+0-9.eE]+)\s*\*')
with open('shapenet_intrinsics/train/model_weights_all.txt') as f:
    current = []  # text lines seen since the previous tensor footer
    for line in f:
        end = end_re.match(line)
        if not end:
            current.append(line)
            continue

        # "size 16" -> (16,), "size 16x3x3x3" -> (16, 3, 3, 3)
        end_dims = tuple(map(int, end.group(1).split('x')))
        if len(end_dims) == 1:
            # 1-D tensor: one value per line, optionally preceded by a scale header.
            scale = 1.0
            header = scale_re.match(current[0]) if current else None
            if header:
                scale = float(header.group(1))
                current = current[1:]
            array = np.array(tuple(map(float, current))) * scale
            assert(len(array) == end_dims[0])
        else:
            # Multi-dimensional tensor: the text is consumed in groups of five
            # lines - a "(i,j,...)" header with 1-based indices, three rows of
            # values, and one further line that is skipped.
            # NOTE(review): a scale header inside such a group is not handled;
            # confirm the dump never contains one.
            array = np.zeros(end_dims)
            i = 0
            while i < len(current):
                match = index_re.match(current[i])
                if match is None:
                    raise ValueError('expected a "(i,j,...)" slice header, got %r' % current[i])
                first_index = int(match.group(1)) - 1
                second_index = int(match.group(2)) - 1
                for j in range(3):
                    row_values = current[i + 1 + j].lstrip().split()
                    array[first_index][second_index][j] = np.array(tuple(map(float, row_values)))
                i += 5
        params.append(array)
        current = []

In [None]:
def load_dataset(size=(256, 256)):
    """
    Load the de-lit target images and the lit input images as numpy arrays.

    Reads the module-level `delighted_dirs` (mapping of mesh name to a directory
    prefix) and `lighted_dirs` (iterable of `(t, mesh, v, mesh_dir)` tuples).

    Args:
        size: (width, height) every image is resized to; defaults to (256, 256).

    Returns:
        A tuple `(delighted_data, lighted_data)`:
          * delighted_data: dict mapping mesh name to an array holding a single
            image (leading axis of length 1) with its first three channels.
          * lighted_data: list of `(mesh, array)` pairs, the array again holding
            a single image with its first three channels. Entries whose
            `*_Lit.tga` file does not exist are skipped.
    """
    # `os` is not imported by name in the notebook's import cell, so bring it
    # into scope here instead of relying on `from util import *` to leak it.
    import os

    # Image.ANTIALIAS was an alias of LANCZOS and no longer exists in recent
    # Pillow releases; LANCZOS is the same filter under its current name.
    resample = Image.LANCZOS

    delighted_data = {}
    for mesh, mesh_dir in delighted_dirs.items():
        # The context manager closes the underlying file once the resized copy exists.
        with Image.open(mesh_dir + '%s_HD_BC.tga' % mesh, 'r') as raw_im:
            mesh_im = raw_im.resize(size, resample)
        # zero_rgb is not defined in this notebook (presumably provided by util;
        # verify what it does to the pixel data).
        mesh_array = zero_rgb(np.array(mesh_im))
        mesh_array = mesh_array[:, :, :3]
        delighted_data[mesh] = np.asarray([mesh_array])

    lighted_data = []
    for t, mesh, v, mesh_dir in lighted_dirs:
        lit_file = '%s%s_%s_%s_Lit.tga' % (mesh_dir, t, mesh, v)
        if not os.path.exists(lit_file):
            continue
        with Image.open(lit_file, 'r') as raw_im:
            # resize according to shapenet
            mesh_im = raw_im.resize(size, resample)
        meshes = np.asarray([np.array(mesh_im)[:, :, :3]])
        lighted_data.append((mesh, meshes))

    return delighted_data, lighted_data

# helper that rescales a batch by its maximum (not called in the cells shown here)
def preproc(unclean_batch_x):
    """
    Divide every value of the batch by the batch maximum.

    For non-negative data this maps the values into the range 0-1.

    Args:
        unclean_batch_x: numpy array to rescale.

    Returns:
        A floating point array of the same shape. When the maximum is 0 the
        values are returned unchanged (as floats) instead of dividing by zero,
        so an all-zero batch stays all zero rather than turning into NaN.
    """
    peak = unclean_batch_x.max()
    if peak == 0:
        # Multiplying by 1.0 gives the same float dtype the division would produce.
        return unclean_batch_x * 1.0
    return unclean_batch_x / peak

def shuffle(Xtrain, ytrain):
    """
    Shuffle the rows of Xtrain and ytrain together so that pairs stay aligned.

    The two inputs are joined column-wise, the rows of the joined array are
    shuffled in place with np.random.shuffle, and the result is split again.

    Returns:
        (features, labels): the first Xtrain.shape[1] columns of the shuffled
        array, and the single column that follows them as a 1-D array.
    """
    n_features = Xtrain.shape[1]
    joined = np.column_stack((Xtrain, ytrain))
    np.random.shuffle(joined)
    features = joined[:, :n_features]
    labels = joined[:, n_features]
    return features, labels

# Directory argument handed to scan_lighted_delighted (an empty string here).
# scan_lighted_delighted is not defined in this notebook - presumably it comes
# from `from util import *`; verify.
date_dir = ''
delighted_dirs, lighted_dirs = scan_lighted_delighted(date_dir)

# load_dataset reads the module-level delighted_dirs / lighted_dirs assigned
# above, so it has to run after the scan.
delighted_data, lighted_data = load_dataset()

In [None]:
def get_kernel_size(factor):
    """
    Return the kernel size that goes with an upsampling factor.

    The result is twice the factor, reduced by `factor % 2` (so by one when an
    integer factor is odd).
    """
    parity = factor % 2
    return 2 * factor - parity


def upsample_filt(size):
    """
    Build a square 2D bilinear kernel with `size` rows and `size` columns.

    The weight at (row, col) is the product of two 1-D triangular profiles
    that peak at the kernel centre and fall off linearly with distance.
    """
    factor = (size + 1) // 2
    # odd sizes have a centre pixel; even sizes are centred between two pixels
    center = factor - 1 if size % 2 == 1 else factor - 0.5
    rows, cols = np.ogrid[:size, :size]
    row_profile = 1 - abs(rows - center) / factor
    col_profile = 1 - abs(cols - center) / factor
    return row_profile * col_profile

def maxpool2d(x, k=2):
    """
    Max-pool `x` with a k x k window and a stride of k, using 'SAME' padding.
    """
    # the same [1, k, k, 1] layout is used for the window and for the stride
    window = [1, k, k, 1]
    return tf.nn.max_pool(x, ksize=window, strides=list(window), padding='SAME')

# convolution block: spatial conv -> bias -> batch norm -> ReLU
def conv2d(x, W, b, stride):
    """
    Apply a 'SAME'-padded 2D convolution followed by bias, batch norm and ReLU.

    Args:
        x: input tensor.
        W: convolution filter.
        b: bias added to the convolution output.
        stride: stride used for both spatial axes.

    Returns:
        The activated output tensor.
    """
    strides = [1, stride, stride, 1]
    convolved = tf.nn.conv2d(x, W, strides=strides, padding='SAME')
    biased = tf.nn.bias_add(convolved, b)
    normalized = tf.contrib.layers.batch_norm(biased, center=True, scale=True)
    return tf.nn.relu(normalized)

def deconv2d(x, W, b, stride):
    """
    Convolution block followed by a 2x bilinear upsampling.

    The input goes through the same conv -> bias -> batch norm -> ReLU chain
    that conv2d applies, and the result is then resized to twice its height
    and width.

    Args:
        x: input tensor; its height and width must be statically known because
           the target size is computed from the static shape.
        W: convolution filter.
        b: bias added to the convolution output.
        stride: stride used for both spatial axes of the convolution.

    Returns:
        The upsampled tensor.
    """
    activated = conv2d(x, W, b, stride)

    # bilinear interpolation upsampling
    scale = 2
    height, width = activated.get_shape().as_list()[1 : 3]
    target_size = [height * scale, width * scale]
    return tf.image.resize_images(activated, target_size, method=tf.image.ResizeMethod.BILINEAR)

# Stage identifiers. They are passed as the `stage` argument of
# get_dimension_name / get_weight_name / get_bias_name and therefore end up
# inside the generated dictionary keys.
ENCODER = 'encoder'
MID = 'mid'
DECODER = 'decoder'

def get_dimension_name(stage, layer_num):
    """Return the key 'dims_<stage>_<layer_num>' used for a layer's dimensions."""
    return '_'.join(('dims', stage, str(layer_num)))

def get_weight_name(stage, layer_num):
    """Return the key 'w_<stage>_<layer_num>' used for a layer's weights."""
    return '_'.join(('w', stage, str(layer_num)))
    
def get_bias_name(stage, layer_num):
    """Return the key 'b_<stage>_<layer_num>' used for a layer's bias."""
    return '_'.join(('b', stage, str(layer_num)))

In [None]:
# Shape settings for every layer of the network
FILTER_SIZE = 3  # side length of each square convolution filter
INPUT_DEPTH = 3  # number of channels of the input

# number of output channels of each layer, listed per stage
ENCODER_DEPTHS = [16, 32, 64, 128, 256, 256]
MID_DEPTH = ENCODER_DEPTHS[-1]
DECODER_DEPTHS = [256, 128, 64, 32, 16, 16, 3]

# maps get_dimension_name(stage, layer) to that layer's settings
layer_dimensions = {}

# Encoder: every layer consumes the previous layer's output. Only the first
# layer uses stride 1, all later ones use stride 2.
prev_depth = INPUT_DEPTH
for i, output_depth in enumerate(ENCODER_DEPTHS):
    layer_dimensions[get_dimension_name(ENCODER, i)] = dict(
        input_depth=prev_depth,
        output_depth=output_depth,
        stride=2 if i > 0 else 1,
        filter_size=FILTER_SIZE,
    )
    prev_depth = output_depth

# Mid: four layers that keep the depth unchanged.
for i in range(4):
    layer_dimensions[get_dimension_name(MID, i)] = dict(
        input_depth=MID_DEPTH,
        output_depth=MID_DEPTH,
        stride=1,
        filter_size=FILTER_SIZE,
    )

# Decoder: layers 0 to 5 take the previous output concatenated with the output
# of encoder layer (5 - i), so their input depth is the sum of both depths.
# Layer 6 has no concatenation. Layers 5 and 6 are plain conv layers, the
# earlier ones are deconv layers.
prev_depth = MID_DEPTH
for i, output_depth in enumerate(DECODER_DEPTHS):
    skip_depth = ENCODER_DEPTHS[5 - i] if i < 6 else 0
    layer_dimensions[get_dimension_name(DECODER, i)] = dict(
        input_depth=prev_depth + skip_depth,
        output_depth=output_depth,
        stride=1,
        filter_size=FILTER_SIZE,
    )
    prev_depth = output_depth

In [None]:
class IntrinsicNetwork(object):
    """
    Encoder / mid / decoder convolutional network trained with a mean squared
    error loss between its output and an expected image.

    Creating an instance builds a fresh tf.Graph with its own
    InteractiveSession, assembles all layers described by `dimensions` and
    initializes the variables.
    """

    def __init__(self, input_dimensions, dimensions, learning_rate=1e-3):
        """
        Args:
            input_dimensions: shape of one sample without the batch axis. The
                expected output uses the same shape.
            dimensions: dict keyed by get_dimension_name(stage, layer_num);
                each value provides 'input_depth', 'output_depth', 'stride'
                and 'filter_size'.
            learning_rate: learning rate of the Adam optimizer.
        """
        graph = tf.Graph()
        self.sess = tf.InteractiveSession(graph=graph)

        self.dimensions = dimensions
        self.params = {}  # weight / bias variables keyed by get_weight_name / get_bias_name

        # in the paper each input sample is 256x256x3
        self.inp = tf.placeholder(tf.float32, shape=[None,] + list(input_dimensions))
        self.expected_output = tf.placeholder(tf.float32, shape=[None,] + list(input_dimensions))

        encoder_layers = self.get_encoder_layers(self.inp)
        mid_output = self.get_mid_output(encoder_layers[-1])
        self.output = self.get_decoder_output(mid_output, encoder_layers)

        self.loss = tf.reduce_mean(tf.squared_difference(self.output, self.expected_output))
        self.optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(self.loss)
        init = tf.global_variables_initializer()
        self.sess.run(init)

    def create_weights(self, stage, layer_num):
        """
        Create the filter and bias variables of one layer.

        Both variables are initialized from a random normal distribution and
        stored in self.params.

        Returns:
            (W, b, stride) for the requested layer.
        """
        dims = self.dimensions[get_dimension_name(stage, layer_num)]
        input_depth, output_depth, stride, filter_size = \
            (dims[x] for x in ['input_depth', 'output_depth', 'stride', 'filter_size'])
        W = tf.Variable(tf.random_normal([filter_size, filter_size, input_depth, output_depth]))
        b = tf.Variable(tf.random_normal([output_depth]))
        self.params[get_weight_name(stage, layer_num)] = W
        self.params[get_bias_name(stage, layer_num)] = b
        return W, b, stride

    def get_encoder_layers(self, input):
        """
        Chain encoder layers 0 to 5 and return the output of every one of them
        (the decoder concatenates with these outputs).
        """
        prev = input
        encoder_layers = []
        # encoder layers 0 to 5
        for layer_num in range(6):
            W, b, stride = self.create_weights(ENCODER, layer_num)
            prev = conv2d(prev, W, b, stride)
            encoder_layers.append(prev)
        return encoder_layers

    def get_mid_output(self, encoder_output):
        """Apply the four mid conv layers to the last encoder output."""
        prev = encoder_output
        for layer_num in range(4):
            W, b, stride = self.create_weights(MID, layer_num)
            prev = conv2d(prev, W, b, stride)
        return prev

    def get_decoder_output(self, mid_output, encoder_layers):
        """
        Build the seven decoder layers and return the network output.

        Layers 0 to 5 first concatenate their input with the output of encoder
        layer (5 - layer_num) along the channel axis. Layers 0 to 4 are deconv
        layers, layers 5 and 6 are conv layers.
        """
        prev = mid_output
        for layer_num in range(7):
            W, b, stride = self.create_weights(DECODER, layer_num)
            if layer_num < 6: # concatenate with encoder layer for layers 0 to 5
                encoder_layer_input = encoder_layers[5 - layer_num]
                prev = tf.concat([prev, encoder_layer_input], axis=3)
            if layer_num < 5: # perform deconv on layers 0 to 4
                prev = deconv2d(prev, W, b, stride)
            else: # conv on layers 5 and 6
                prev = conv2d(prev, W, b, stride)
        return prev

    def fit_batch(self, inputs, output):
        """Run one optimizer step on a batch and return its loss."""
        _, loss = self.sess.run((self.optimizer, self.loss), feed_dict={self.inp : inputs, self.expected_output : output})
        return loss

    def train(self, lighted_inputs, delighted_data, epochs, batch_size=1, display_step=5):
        """
        Train the network.

        Args:
            lighted_inputs: sequence of (mesh, images) pairs; `images` is an
                array whose first axis is the batch axis.
            delighted_data: mapping from mesh to the expected output array for
                that mesh (first axis is the batch axis as well).
            epochs: number of passes over the data.
            batch_size: number of samples combined into each optimizer step.
                Samples left over after the last full batch are not used.
            display_step: print the mean loss every `display_step` epochs.

        Returns:
            List with the mean batch loss of every epoch.

        Raises:
            ValueError: if fewer than `batch_size` samples are supplied, since
                not even one batch could be formed.
        """
        n_samples = len(lighted_inputs)
        total_iter = n_samples // batch_size
        if total_iter == 0:
            raise ValueError('need at least batch_size=%d samples to train, got %d'
                             % (batch_size, n_samples))

        mean_losses = []
        for epoch in range(epochs):
            # Shuffle a copy so the caller's list keeps its order; a new order
            # is drawn for every epoch.
            samples = list(lighted_inputs)
            random.shuffle(samples)

            total_loss = 0
            for i in range(total_iter):
                batch = samples[i * batch_size : (i + 1) * batch_size]
                # Join all samples of the batch along the batch axis, together
                # with the matching expected outputs in the same order.
                inputs = np.concatenate([images for _, images in batch], axis=0)
                targets = np.concatenate([delighted_data[mesh] for mesh, _ in batch], axis=0)
                loss = self.fit_batch(inputs, targets)
                total_loss += loss
            mean_loss = total_loss / total_iter
            mean_losses.append(mean_loss)
            if (epoch + 1) % display_step == 0:
                print('epoch %s: loss=%.4f' % (epoch + 1, mean_loss))
        return mean_losses

    def predict(self, inputs):
        """Return the network output for a batch of inputs."""
        return self.sess.run(self.output, feed_dict={self.inp : inputs})

In [None]:
# Build the network for 256x256 inputs with INPUT_DEPTH channels, using the
# per-layer settings collected in layer_dimensions.
inet = IntrinsicNetwork((256, 256, INPUT_DEPTH), layer_dimensions)
# Run a single training epoch with a batch size of one.
inet.train(lighted_data, delighted_data, epochs=1, batch_size=1)