
ValueError #104

Closed
Anselmoo opened this issue Mar 30, 2020 · 18 comments

Comments

@Anselmoo

Traceback (most recent call last):
  File "capsulenet.py", line 310, in <module>
    routings=args.routings,
  File "capsulenet.py", line 60, in CapsNet
    )(primarycaps)
  File "/home/Anselmoo/.local/lib/python3.6/site-packages/keras/engine/base_layer.py", line 489, in __call__
    output = self.call(inputs, **kwargs)
  File "/home/Anselmoo/GitHub-Projects/CapsNet-Keras/capsule/capsulelayers.py", line 160, in call
    b += K.batch_dot(outputs, inputs_hat, [2, 3])
  File "/home/Anselmoo/.local/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 1499, in batch_dot
    'y.shape[%d] (%d != %d).' % (axes[0], axes[1], d1, d2))

ValueError: Can not do batch_dot on inputs with shapes (None, 10, 10, 1152, 16) and (None, 10, None, 1152, 16) with axes=[2, 3]. x.shape[2] != y.shape[3] (10 != 1152).

I was running capsulelayers.py with the default settings.
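
For reference, a minimal snippet that should reproduce the same shape-validation error on recent Keras versions; the placeholder shapes are copied from the traceback above, and the snippet is an illustration rather than code from this repository:

from keras import backend as K

x = K.placeholder(shape=(None, 10, 10, 1152, 16))
y = K.placeholder(shape=(None, 10, None, 1152, 16))
# Expected to raise: ValueError: Can not do batch_dot on inputs ... with axes=[2, 3]
K.batch_dot(x, y, axes=[2, 3])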

@gcfengxu

gcfengxu commented Apr 14, 2020

I encountered the same problem. It may be caused by the K.batch_dot() function, but I don't know how to solve it.
I have emailed Xifeng Guo via 163 Mail but have not received a reply yet.
If you find a solution, please let me know. Thanks.

@Anselmoo
Author

Thanks for the reply, I'll take a look at it over the next few days.

@gcfengxu

Hey, I found a solution in brjathu/deepcaps@e273cfd.
It really works.

@Anselmoo
Author

But that's a separate project, so there is no way of merging it? Or creating a pull request?

@data-hound

@gcfengxu Has this own_batch_dot been tested by training+testing on the MNIST dataset?

@gcfengxu

gcfengxu commented Apr 25, 2020

@Anselmoo Sorry, I don't quite understand what you mean. I just copied 'batchdot.py', added it to my project, imported the file, and then replaced all 'K.map_fn()' with 'own_batch_dot()'.
I haven't analyzed how the function works.

@gcfengxu

gcfengxu commented Apr 25, 2020

@data-hound Yes, I have tested it by training on MNIST. Although the test has not completely finished yet, it really is training.

@Anselmoo
Author

> Sorry, I don't quite understand what you mean. I just copied 'batchdot.py', added it to my project, imported the file, and then replaced all 'K.map_fn()' with 'own_batch_dot()'.
> I haven't analyzed how the function works.

@gcfengxu Maybe you could commit your version and open it as a pull request? I think that would help a lot. Thanks.

@gcfengxu

@Anselmoo @data-hound Sorry, I mixed up the function names. In my project I replaced K.batch_dot() with own_batch_dot(). Sorry for the wrong word.
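
(For readers who cannot open that commit: the gist of restoring the old batch_dot behaviour for equal-rank inputs can be sketched as below. This is a minimal illustration assuming TF 2.x, not the actual code from brjathu/deepcaps@e273cfd.)

import tensorflow as tf

def own_batch_dot(x, y, axes):
    # Emulate the old (keras 2.0.x) batch_dot for two tensors of equal rank >= 3:
    # treat all leading dimensions as batch dimensions and contract x's axes[0]
    # with y's axes[1] via a single matmul. The transpose flags mirror the
    # adjoint flags the old backend set.
    transpose_x = axes[0] != x.shape.ndims - 1  # contracted axis of x is not its last
    transpose_y = axes[1] == y.shape.ndims - 1  # contracted axis of y is its last
    return tf.matmul(x, y, transpose_a=transpose_x, transpose_b=transpose_y)

# Example: (2, 3, 4, 5) x (2, 3, 5, 6) with axes=(3, 2) -> (2, 3, 4, 6),
# matching the keras==2.0.7 behaviour described further down in this thread.
print(own_batch_dot(tf.ones((2, 3, 4, 5)), tf.ones((2, 3, 5, 6)), (3, 2)).shape)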

@Anselmoo
Author

Anselmoo commented Apr 26, 2020

@gcfengxu Thank you very much! I got it, and I would recommend the following steps:

  1. Fork this project
  2. Make a new branch
  3. Upload your own_batch_dot() and the other modified files via git or the web interface
  4. Commit these files
  5. Create a pull request for @XifengGuo so we can take a look

I think it would be great if we could work together and add further tests like MNIST. I think that's also the idea of open source, and it allows us to modify the capsule net further. The benchmark settings of this implementation look the best, and it would be sad if we could not use it anymore.

@gcfengxu

@Anselmoo Well, I am new to GitHub. I'll try your advice, thanks for your suggestions!

@Anselmoo
Author

@gcfengxu No worries, everybody has to start somewhere 🛫 that's why we are here.

@data-hound

@Anselmoo @gcfengxu
I think the own_batch_dot method as defined in the earlier versions is actually wrong, according to this issue:
keras-team/keras#13300

So, as a solution, in my implementation I used K.batch_dot and then reshaped the result further to conform to the expected matrix shapes.

@gcfengxu Are you running the model with or without eager execution? I have run into some problems with eager execution and with graph tensors being passed outside the graph. This occurs just before the end of the first training epoch. Let me know if using this method gets you past that.
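
(For reference, TF 2.x can be forced into graph mode with the snippet below; this is a general workaround to try, not a confirmed fix for this particular error.)

import tensorflow as tf

# Must be called before any model is built.
tf.compat.v1.disable_eager_execution()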

@gcfengxu

@data-hound My TensorFlow version is 2.0; I never ran into problems with eager execution. But K.batch_dot is still wrong in Keras 2.3.1.

@data-hound

data-hound commented Apr 30, 2020

Cool. So your training and testing tasks have completed successfully using this own_batch_dot method, then, @gcfengxu?

Also, as noted in the Keras issue I linked above, the new behaviour of K.batch_dot is intended and will remain. For example,

x_batch = K.ones(shape=(1152, 10, 1, 8)) 
y_batch = K.ones(shape=(1152, 10, 8, 16))
xy_batch_dot = K.batch_dot(x_batch, y_batch, axes=(3, 2))
K.int_shape(xy_batch_dot) 

the above code yields the shape (1152, 10, 1, 10, 16), not the (1152, 10, 1, 16) that was returned in previous versions. I have not checked the mathematical accuracy of the shapes from this operation, but this is the behaviour fchollet has confirmed to be correct.
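
For comparison, the old 4-D result can be reproduced with a plain tf.matmul, which treats all leading dimensions as batch dimensions (an illustrative snippet, not code from this repository):

import tensorflow as tf

x_batch = tf.ones(shape=(1152, 10, 1, 8))
y_batch = tf.ones(shape=(1152, 10, 8, 16))
# Contracts x's last axis with y's second-to-last axis, batching over (1152, 10).
xy = tf.matmul(x_batch, y_batch)
print(xy.shape)  # (1152, 10, 1, 16), the shape old K.batch_dot versions returned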

@uchar

uchar commented May 9, 2020

I'm using own_batch_dot; it runs, but I'm getting negative values for the loss function after 3 or 4 epochs!
@data-hound Can you post your full code?

@data-hound

@mrtucar It is known that the issue lies with the new versions of TensorFlow and Keras.

@uchar I haven't got past the first epoch, even with the own_batch_dot method. I am running on Kaggle. I will try on a fresh setup in a few days and let you know.

@XifengGuo
Owner

@uchar @Anselmoo @mrtucar @data-hound @gcfengxu
The problem is caused by a behavior change in keras.backend.batch_dot.
In keras==2.0.7: a.shape->(2, 3, 4, 5), b.shape->(2, 3, 5, 6), batch_dot(a, b, (3, 2)).shape->(2, 3, 4, 6).
But in newer versions: batch_dot(a, b, (3, 2)).shape->(2, 3, 4, 3, 6).
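
This is easy to check directly (illustrative snippet; the commented shape is what the newer versions return):

from keras import backend as K

a = K.ones((2, 3, 4, 5))
b = K.ones((2, 3, 5, 6))
print(K.int_shape(K.batch_dot(a, b, (3, 2))))  # (2, 3, 4, 3, 6) on newer Keras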

I propose replacing K.batch_dot with tf.matmul.
For details, please refer to the updated call method:

def call(self, inputs, training=None):
    # inputs.shape=[None, input_num_capsule, input_dim_capsule]
    # inputs_expand.shape=[None, 1, input_num_capsule, input_dim_capsule, 1]
    inputs_expand = tf.expand_dims(tf.expand_dims(inputs, 1), -1)

    # Replicate num_capsule dimension to prepare being multiplied by W.
    # inputs_tiled.shape=[None, num_capsule, input_num_capsule, input_dim_capsule, 1]
    inputs_tiled = tf.tile(inputs_expand, [1, self.num_capsule, 1, 1, 1])

    # Compute `inputs * W` by scanning inputs_tiled on dimension 0.
    # W.shape=[num_capsule, input_num_capsule, dim_capsule, input_dim_capsule]
    # x.shape=[num_capsule, input_num_capsule, input_dim_capsule, 1]
    # Regard the first two dimensions as `batch` dimensions, then
    # matmul(W, x): [..., dim_capsule, input_dim_capsule] x [..., input_dim_capsule, 1] -> [..., dim_capsule, 1].
    # inputs_hat.shape = [None, num_capsule, input_num_capsule, dim_capsule]
    inputs_hat = tf.squeeze(tf.map_fn(lambda x: tf.matmul(self.W, x), elems=inputs_tiled))

    # Begin: Routing algorithm ---------------------------------------------#
    # The prior for the coupling coefficients, initialized as zeros.
    # b.shape = [None, self.num_capsule, 1, self.input_num_capsule].
    b = tf.zeros(shape=[inputs.shape[0], self.num_capsule, 1, self.input_num_capsule])

    assert self.routings > 0, 'The routings should be > 0.'
    for i in range(self.routings):
        # c.shape=[batch_size, num_capsule, 1, input_num_capsule]
        c = tf.nn.softmax(b, axis=1)

        # c.shape = [batch_size, num_capsule, 1, input_num_capsule]
        # inputs_hat.shape=[None, num_capsule, input_num_capsule, dim_capsule]
        # Treat the first two dimensions as `batch` dimensions, then
        # matmul: [..., 1, input_num_capsule] x [..., input_num_capsule, dim_capsule] -> [..., 1, dim_capsule].
        # outputs.shape=[None, num_capsule, 1, dim_capsule]
        outputs = squash(tf.matmul(c, inputs_hat))  # [None, 10, 1, 16]

        if i < self.routings - 1:
            # outputs.shape = [None, num_capsule, 1, dim_capsule]
            # inputs_hat.shape=[None, num_capsule, input_num_capsule, dim_capsule]
            # Treat the first two dimensions as `batch` dimensions, then
            # matmul: [..., 1, dim_capsule] x [..., input_num_capsule, dim_capsule]^T -> [..., 1, input_num_capsule].
            # b.shape=[batch_size, num_capsule, 1, input_num_capsule]
            b += tf.matmul(outputs, inputs_hat, transpose_b=True)
    # End: Routing algorithm -----------------------------------------------#

    return tf.squeeze(outputs)
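
The squash function used above is the squashing nonlinearity from Sabour et al. (2017), "Dynamic Routing Between Capsules". A minimal sketch in the same TF style (the small epsilon guarding the division is an implementation detail, not part of the paper's formula):

def squash(vectors, axis=-1):
    # squash(s) = (||s||^2 / (1 + ||s||^2)) * (s / ||s||)
    s_squared_norm = tf.reduce_sum(tf.square(vectors), axis, keepdims=True)
    scale = s_squared_norm / (1 + s_squared_norm) / tf.sqrt(s_squared_norm + 1e-9)
    return scale * vectors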
