Skip to content

Commit

Permalink
KerasTTDense -> KerasDense and fixes w.r.t. new arg names
Browse files Browse the repository at this point in the history
  • Loading branch information
Bihaqo committed Dec 29, 2018
1 parent 0f85023 commit 560c6c3
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions t3f/neural.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import tensorflow as tf


class KerasTTDense(Layer):
class KerasDense(Layer):
counter = 0

def __init__(self, input_dims, output_dims, tt_rank=2,
Expand Down Expand Up @@ -42,36 +42,36 @@ def __init__(self, input_dims, output_dims, tt_rank=2,
self.use_bias = use_bias
self.kernel_initializer = kernel_initializer
self.bias_initializer = bias_initializer
super(KerasTTDense, self).__init__(**kwargs)
super(KerasDense, self).__init__(**kwargs)

def build(self, input_shape):
if self.init == 'glorot':
if self.kernel_initializer == 'glorot':
initializer = t3f.glorot_initializer(self.tt_shape,
tt_rank=self.tt_rank)
elif self.init == 'he':
elif self.kernel_initializer == 'he':
initializer = t3f.he_initializer(self.tt_shape,
tt_rank=self.tt_rank)
elif self.init == 'lecun':
elif self.kernel_initializer == 'lecun':
initializer = t3f.lecun_initializer(self.tt_shape,
tt_rank=self.tt_rank)
else:
raise ValueError('Unknown kernel_initializer "%s", only "glorot",'
'"he", and "lecun" are supported' % self.init)
name = 'tt_dense_matrix_{}'.format(KerasTTDense.counter)
'"he", and "lecun" are supported' % self.kernel_initializer)
name = 'tt_dense_matrix_{}'.format(KerasDense.counter)
self.W = t3f.get_variable(name, initializer=initializer)
self.b = None
if self.bias:
b_name = 'tt_dense_b_{}'.format(KerasTTDense.counter)
b_init = tf.constant_initializer(self.bias_init)
if self.use_bias:
b_name = 'tt_dense_b_{}'.format(KerasDense.counter)
b_init = tf.constant_initializer(self.bias_initializer)
self.b = tf.get_variable(b_name, shape=self.output_dim,
initializer=b_init)
KerasTTDense.counter += 1
KerasDense.counter += 1
self.trainable_weights = list(self.W.tt_cores)
if self.b is not None:
self.trainable_weights.append(self.b)

def call(self, x):
if self.bias:
if self.use_bias:
h = t3f.matmul(x, self.W) + self.b
else:
h = t3f.matmul(x, self.W)
Expand Down

0 comments on commit 560c6c3

Please sign in to comment.