In [1]:
from PIL import Image
import os
import numpy as np
import tensorflow as tf

In [2]:
# Load every readable image in the instruction directory as an int32 array.
# The directory listing is sorted because os.listdir returns entries in an
# arbitrary order; without sorting, positional lookups such as img_list[2]
# would pick a different file on a different machine or filesystem.
path = './dataset/instruction'
img_files = sorted(os.listdir(path))

img_list = []
for f in img_files:
    img_path = os.path.join(path, f)
    try:
        # The context manager closes the underlying file once the pixels
        # have been copied into the numpy array.
        with Image.open(img_path) as im:
            img = np.array(im).astype(np.int32)
    except (OSError, ValueError):
        # Skip entries PIL cannot read as images (unrecognised formats,
        # truncated files, subdirectories). Unlike a bare `except:`, this no
        # longer hides KeyboardInterrupt or programming errors.
        continue
    img_list.append(img)

In [3]:
# Take the third successfully decoded instruction map as the running example
# and display it (the last expression of the cell is rendered).
SAMPLE_INDEX = 2
img = img_list[SAMPLE_INDEX]
img

array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 4, 0],
       [0, 0, 6, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0],
       [0, 3, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0],
       [0, 3, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 3, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0],
       [0, 3, 0, 0, 0, 6, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 6, 0],
       [0, 3, 4, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0],
       [0, 3, 0, 0, 0, 6, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 6, 0],
       [0, 4, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 6, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 4, 0],
       [0, 0, 6, 0, 0, 0, 0, 0, 6,

In [4]:
# Lift the 2-D map into a rank-4 int32 tensor of shape [1, rows, cols, 1]
# (batch of one, single channel), matching the (B, h, w, channels) layout
# that the loss cell unpacks from instr.get_shape().
n_rows, n_cols = img.shape[0], img.shape[1]
instr = tf.reshape(tf.convert_to_tensor(img, dtype=tf.int32), [1, n_rows, n_cols, 1])

In [5]:
# Binary "syntax" loss between neighbouring pixels of the instruction map.
# For neighbour direction i, each pixel (source) is paired with the pixel at
# offset (dy[i], dx[i]) (target). The (source, target) value pair indexes the
# matrix loaded from T<i+1>.txt; entries below 1.0 are treated as invalid
# pairs and cost `syntax_weight`, valid pairs cost 0.
params = {}
dx = [ -1, 0, 1, -1, 1, -1, 0, 1]
dy = [ -1, -1, -1, 0, 0, 1, 1, 1]
# Window start per offset. For offset -1 the target sits one pixel before the
# source, so the source window starts at 1 and the target window at 0 (and
# vice versa for +1). Offset 0 uses the same window for both.
src_from = { -1: 1, 0: 0, 1: 0 }
trg_from = { -1: 0, 0: 0, 1: 1 }
B, h, w, channels = instr.get_shape()
# Overlap length along an axis of size n for offset d is n - |d|. It is computed
# separately for rows and columns from the actual tensor shape. A hard-coded
# size only fits one image size and breaks non-square maps.
rng_size_y = {d: h.value - abs(d) for d in (-1, 0, 1)}
rng_size_x = {d: w.value - abs(d) for d in (-1, 0, 1)}
syntax_weight = params.get('syntax_weight', 10) # used for binary case
syntax_softmax = params.get('syntax_softmax', 1)
# How many of the 8 directions to evaluate. The default 1 reproduces the
# original run (first direction only). Use len(dx) for the full neighbourhood,
# which requires ./dataset/syntax/T1.txt .. T8.txt to exist.
num_directions = params.get('num_directions', 1)
# One per-pixel loss tensor per evaluated direction.
losses = []
for i in range(num_directions):
    matname = os.path.join('./dataset', 'syntax', 'T' + str(i+1) + '.txt')
    T = np.loadtxt(matname, delimiter = ',')
    # Source and target windows of equal size, shifted by (dy[i], dx[i]).
    t_src_slice = tf.slice(instr,
        [0, src_from[dy[i]], src_from[dx[i]], 0],
        [B.value, rng_size_y[dy[i]], rng_size_x[dx[i]], channels.value])
    t_trg_slice = tf.slice(instr,
        [0, trg_from[dy[i]], trg_from[dx[i]], 0],
        [B.value, rng_size_y[dy[i]], rng_size_x[dx[i]], channels.value])
    # binary case: the pixel values themselves are used as row/column indices into T
    t_src = tf.reshape(t_src_slice, [-1])
    t_trg = tf.reshape(t_trg_slice, [-1])
    t_indices = tf.stack([t_src, t_trg], axis = 1)
    # Look up T[src, trg] for every aligned pixel pair.
    t_count = tf.gather_nd(T, t_indices)
    t_bad = tf.ones_like(t_count, tf.float32) * syntax_weight
    t_good = tf.zeros_like(t_count, tf.float32)
    t_loss = tf.where(t_count < 1.0, t_bad, t_good)
    # Keep every direction's loss; previously `losses` was never filled and
    # only the last direction's tensor survived the loop.
    losses.append(t_loss)
    # Soft (non-binary) variant, currently unused and kept for reference.
    # It needs one-hot/softmax slices and an fn_smooth_l1 helper not defined here.
#     T[T >= 1] = 1
#     P = tf.reshape(1 - T.astype(np.float32), [1,1,1,17,17]) # penalty matrix whose one entries denote invalid pairs
#     P = tf.tile(P, [B.value, rng_size_y[dy[i]], rng_size_x[dx[i]], 1, 1])
#     # compute loss using Einstein summation notation
#     # note: we do not reduce as this is done automatically in the total loss computation
#     t_loss = tf.einsum('bhwi,bhwij,bhwj->bhw', t_src_slice, P, t_trg_slice)
#     t_loss = tf.reshape(fn_smooth_l1(t_loss), [-1, 1]) # to be able to use concat below

In [7]:
# Evaluate the pair lookups and the per-pixel loss for the last direction
# built in the loss cell, then print both arrays.
with tf.Session() as sess:
    count_value, loss_value = sess.run([t_count, t_loss])
    print([count_value, loss_value])

[array([3.3742e+06, 3.3742e+06, 3.3742e+06, 3.3742e+06, 3.3742e+06,
       3.3742e+06, 3.3742e+06, 3.3742e+06, 3.3742e+06, 3.3742e+06,
       3.3742e+06, 3.3742e+06, 3.3742e+06, 3.3742e+06, 3.3742e+06,
       3.3742e+06, 3.3742e+06, 3.3742e+06, 3.3742e+06, 3.3742e+06,
       3.3742e+06, 3.3742e+06, 3.3742e+06, 8.4513e+04, 3.3742e+06,
       3.3742e+06, 3.3742e+06, 3.3742e+06, 8.4513e+04, 3.3742e+06,
       3.3742e+06, 3.3742e+06, 3.3742e+06, 3.3742e+06, 3.3742e+06,
       3.3742e+06, 8.4513e+04, 3.3742e+06, 3.3742e+06, 9.5811e+04,
       3.3742e+06, 3.3742e+06, 3.3742e+06, 8.9999e+04, 3.3742e+06,
       9.5811e+04, 3.3742e+06, 3.3742e+06, 8.9999e+04, 3.3742e+06,
       3.3742e+06, 3.3742e+06, 9.5811e+04, 3.3742e+06, 3.3742e+06,
       3.3742e+06, 8.9999e+04, 7.5086e+04, 3.3742e+06, 8.1665e+04,
       8.4513e+04, 3.3742e+06, 3.3742e+06, 3.3742e+06, 3.3742e+06,
       8.1665e+04, 3.3742e+06, 3.3742e+06, 3.3742e+06, 3.3742e+06,
       3.3742e+06, 3.3742e+06, 8.1665e+04, 8.4513e+04, 3.3742