Skip to content

Commit

Permalink
better naming and init docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
Bihaqo committed Dec 29, 2018
1 parent 05a02e0 commit 621c468
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions t3f/neural.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Utils for simplifying building neural networks with TT-layers"""

from keras.engine.topology import Layer
from keras.layers import Activation
import t3f
Expand All @@ -8,7 +10,7 @@
inits = ['glorot', 'he', 'lecun']


class TTDense(Layer):
class KerasTTDense(Layer):
counter = 0

def __init__(self, row_dims, column_dims, tt_rank=2, init='glorot',
Expand Down Expand Up @@ -41,7 +43,7 @@ def __init__(self, row_dims, column_dims, tt_rank=2, init='glorot',
self.bias = bias
self.bias_init = bias_init
self.init = init
super(TTDense, self).__init__(**kwargs)
super(Keras, self).__init__(**kwargs)

def build(self, input_shape):
if self.init == 'glorot':
Expand All @@ -56,15 +58,15 @@ def build(self, input_shape):
else:
raise ValueError('Unknown init "%s", only %s are supported'
% (self.init, inits))
name = 'tt_dense_matrix_{}'.format(TTDense.counter)
name = 'tt_dense_matrix_{}'.format(KerasTTDense.counter)
self.W = t3f.get_variable(name, initializer=initializer)
self.b = None
if self.bias:
b_name = 'tt_dense_b_{}'.format(TTDense.counter)
b_name = 'tt_dense_b_{}'.format(KerasTTDense.counter)
b_init = tf.constant_initializer(self.bias_init)
self.b = tf.get_variable(b_name, shape=self.output_dim,
initializer=b_init)
TTDense.counter += 1
KerasTTDense.counter += 1
self.trainable_weights = list(self.W.tt_cores)
if self.b is not None:
self.trainable_weights.append(self.b)
Expand Down

0 comments on commit 621c468

Please sign in to comment.