@@ -161,7 +161,7 @@ namespace binary_tensor
161161 TensorBase base_b = b.get_buffer ().change_device (this_cuda);
162162 const std::initializer_list<unsigned int > shape_a = base_a.shape ();
163163 const std::initializer_list<unsigned int > shape_b = base_b.shape ();
164- assert (shape_a.size () == 2 && shape_b.size () == 2 && shape_a.end ()[-1 ] == shape_b.end ()[-2 ]);
164+ assert (shape_a.size () == 2 &&shape_b.size () == 2 && shape_a.end ()[-1 ] == shape_b.end ()[-2 ]);
165165 std::vector<std::pair<Tensor, Derivation>> temp;
166166 if (is_derive)
167167 {
@@ -196,7 +196,7 @@ namespace binary_tensor
196196 static_cast <const uint1_t_x8*>(base_b.data ())
197197 );
198198
199- TensorBase value_buf ({ batch_size, shape_a.end ()[-2 ] , shape_b.end ()[-1 ] }, c_ptr, this_cuda);
199+ TensorBase value_buf ({ shape_a.end ()[-2 ] , shape_b.end ()[-1 ] }, c_ptr, this_cuda);
200200 cudaStat = cudaFree (c_ptr);
201201 return Tensor (std::move (value_buf), std::move (temp));
202202 }
@@ -212,7 +212,7 @@ namespace binary_tensor
212212 TensorBase base_b = b.get_buffer ().change_device (this_cuda);
213213 const std::initializer_list<unsigned int > shape_a = base_a.shape ();
214214 const std::initializer_list<unsigned int > shape_b = base_b.shape ();
215- assert (shape_a.size () == shape_b.size () && std::memcmp (shape_a.begin (), shape_b.begin (), std::min (shape_a.size (), shape_b.size ()) - 2 ) && shape_a.end ()[-1 ] == shape_b.end ()[-2 ]);
215+ assert (shape_a.size () == shape_b.size () && std::memcmp (shape_a.begin (), shape_b.begin (), std::min (shape_a.size (), shape_b.size ()) - 2 ) == 0 && shape_a.end ()[-1 ] == shape_b.end ()[-2 ]);
216216 std::vector<std::pair<Tensor, Derivation>> temp;
217217 if (is_derive)
218218 {
@@ -247,7 +247,10 @@ namespace binary_tensor
247247 static_cast <const uint1_t_x8*>(base_b.data ())
248248 );
249249
250- TensorBase value_buf ({ batch_size, shape_a.end ()[-2 ] , shape_b.end ()[-1 ] }, c_ptr, this_cuda);
250+ std::vector<unsigned int > out_dims = shape_a;
251+ out_dims[out_dims.size () - 1 ] = shape_b.end ()[-1 ];
252+
253+ TensorBase value_buf (out_dims, c_ptr, this_cuda);
251254 cudaStat = cudaFree (c_ptr);
252255 return Tensor (std::move (value_buf), std::move (temp));
253256 }
0 commit comments