Skip to content

Commit 2db3098

Browse files
changed matmul
1 parent 5d0ea37 commit 2db3098

File tree

3 files changed

+12
-12
lines changed

3 files changed

+12
-12
lines changed

.vscode/launch.json

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,16 @@
11
{
2-
// Use IntelliSense to learn about possible attributes.
3-
// Hover to view descriptions of existing attributes.
4-
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
52
"version": "0.2.0",
63
"configurations": [
74
{
85
"name": "CUDA C++: Launch",
96
"type": "cuda-gdb",
107
"request": "launch",
11-
"program": "${workspaceFolder}/build/TensorCore"
8+
"program": "${workspaceFolder}/build/test/binarytensor_test"
129
},
1310
{
1411
"name": "CUDA C++: Attach",
1512
"type": "cuda-gdb",
1613
"request": "attach"
17-
}
14+
},
1815
]
1916
}

src/binary_tensor/core/tensor_blas.cu

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ namespace binary_tensor
161161
TensorBase base_b = b.get_buffer().change_device(this_cuda);
162162
const std::initializer_list<unsigned int> shape_a = base_a.shape();
163163
const std::initializer_list<unsigned int> shape_b = base_b.shape();
164-
assert(shape_a.size() == 2 && shape_b.size() == 2 && shape_a.end()[-1] == shape_b.end()[-2]);
164+
assert(shape_a.size() == 2 &&shape_b.size() == 2 && shape_a.end()[-1] == shape_b.end()[-2]);
165165
std::vector<std::pair<Tensor, Derivation>> temp;
166166
if (is_derive)
167167
{
@@ -196,7 +196,7 @@ namespace binary_tensor
196196
static_cast<const uint1_t_x8*>(base_b.data())
197197
);
198198

199-
TensorBase value_buf({ batch_size, shape_a.end()[-2] , shape_b.end()[-1] }, c_ptr, this_cuda);
199+
TensorBase value_buf({ shape_a.end()[-2] , shape_b.end()[-1] }, c_ptr, this_cuda);
200200
cudaStat = cudaFree(c_ptr);
201201
return Tensor(std::move(value_buf), std::move(temp));
202202
}
@@ -212,7 +212,7 @@ namespace binary_tensor
212212
TensorBase base_b = b.get_buffer().change_device(this_cuda);
213213
const std::initializer_list<unsigned int> shape_a = base_a.shape();
214214
const std::initializer_list<unsigned int> shape_b = base_b.shape();
215-
assert(shape_a.size() == shape_b.size() && std::memcmp(shape_a.begin(), shape_b.begin(), std::min(shape_a.size(), shape_b.size()) - 2) && shape_a.end()[-1] == shape_b.end()[-2]);
215+
assert(shape_a.size() == shape_b.size() && std::memcmp(shape_a.begin(), shape_b.begin(), std::min(shape_a.size(), shape_b.size()) - 2) == 0 && shape_a.end()[-1] == shape_b.end()[-2]);
216216
std::vector<std::pair<Tensor, Derivation>> temp;
217217
if (is_derive)
218218
{
@@ -247,7 +247,10 @@ namespace binary_tensor
247247
static_cast<const uint1_t_x8*>(base_b.data())
248248
);
249249

250-
TensorBase value_buf({ batch_size, shape_a.end()[-2] , shape_b.end()[-1] }, c_ptr, this_cuda);
250+
std::vector<unsigned int> out_dims = shape_a;
251+
out_dims[out_dims.size() - 1] = shape_b.end()[-1];
252+
253+
TensorBase value_buf(out_dims, c_ptr, this_cuda);
251254
cudaStat = cudaFree(c_ptr);
252255
return Tensor(std::move(value_buf), std::move(temp));
253256
}

test/main.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using namespace binary_tensor::dtype;
66
int main(int argc, char const *argv[])
77
{
88
/* code */
9-
TensorArray<2, 2> a1 =
9+
TensorArray<2, 3> a1 =
1010
{{
1111
{{
1212
1
@@ -15,15 +15,15 @@ int main(int argc, char const *argv[])
1515
1
1616
}}
1717
}};
18-
TensorArray<2, 2> a2 =
18+
TensorArray<3, 2> a2 =
1919
{{
2020
{{
2121
1, 1
2222
}}
2323
}};
2424
Tensor a01 = Tensor(a1);
2525
Tensor a02 = Tensor(a2);
26-
auto b = a01 + a02;
26+
auto b = matmul(a01, a02);
2727
b.calc_grad(ones(b.get_buffer().shape()));
2828
std::cout << b << std::endl <<
2929
a01.get_grad() << std::endl <<

0 commit comments

Comments
 (0)