Skip to content
This repository was archived by the owner on Feb 18, 2020. It is now read-only.

Commit dd0f049

Browse files
CArray::inner
1 parent 731d72b commit dd0f049

File tree

3 files changed

+98
-6
lines changed

3 files changed

+98
-6
lines changed

kernel/linalg.c

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ CArray_Matmul(CArray * ap1, CArray * ap2, CArray * out, MemoryPointer * ptr)
136136
CArray * result = NULL, * target1, * target2;
137137
int nd1, nd2, nd, typenum;
138138
int i, j, l, matchDim, is1, is2, axis, os;
139-
int * dimensions;
139+
int * dimensions = NULL;
140140
CArray_DotFunc *dot;
141141
CArrayIterator * it1, * it2;
142142
char * op;
@@ -240,8 +240,6 @@ CArray_Matmul(CArray * ap1, CArray * ap2, CArray * out, MemoryPointer * ptr)
240240
CArrayIterator_FREE(it2);
241241

242242
efree(dimensions);
243-
// Remove appended dimension
244-
result->ndim = ap1->ndim;
245243

246244
if (ptr != NULL) {
247245
add_to_buffer(ptr, result, sizeof(CArray));
@@ -777,4 +775,70 @@ CArray_Svd(CArray * a, int full_matrices, int compute_uv, MemoryPointer * out)
777775
efree(data);
778776
}
779777
return rtn;
778+
}
779+
780+
CArray *
781+
CArray_InnerProduct(CArray *op1, CArray *op2, MemoryPointer *out)
782+
{
783+
CArray *ap1 = NULL;
784+
CArray *ap2 = NULL;
785+
int typenum;
786+
CArrayDescriptor *typec = NULL;
787+
CArray* ap2t = NULL;
788+
int dims[CARRAY_MAXDIMS];
789+
CArray_Dims newaxes = {dims, 0};
790+
int i;
791+
CArray* ret = NULL;
792+
793+
typenum = CArray_ObjectType(op1, 0);
794+
typenum = CArray_ObjectType(op2, typenum);
795+
typec = CArray_DescrFromType(typenum);
796+
if (typec == NULL) {
797+
throw_typeerror_exception("Cannot find a common data type.");
798+
goto fail;
799+
}
800+
801+
CArrayDescriptor_INCREF(typec);
802+
ap1 = CArray_FromAny(op1, typec, 0, 0, CARRAY_ARRAY_ALIGNED);
803+
if (ap1 == NULL) {
804+
CArrayDescriptor_DECREF(typec);
805+
goto fail;
806+
}
807+
ap2 = (CArray *)CArray_FromAny(op2, typec, 0, 0, CARRAY_ARRAY_ALIGNED);
808+
if (ap2 == NULL) {
809+
goto fail;
810+
}
811+
812+
newaxes.len = CArray_NDIM(ap2);
813+
if ((CArray_NDIM(ap1) >= 1) && (newaxes.len >= 2)) {
814+
for (i = 0; i < newaxes.len - 2; i++) {
815+
dims[i] = (int)i;
816+
}
817+
dims[newaxes.len - 2] = newaxes.len - 1;
818+
dims[newaxes.len - 1] = newaxes.len - 2;
819+
820+
ap2t = CArray_Transpose(ap2, &newaxes, NULL);
821+
if (ap2t == NULL) {
822+
goto fail;
823+
}
824+
}
825+
else {
826+
ap2t = ap2;
827+
CArray_INCREF(ap2);
828+
}
829+
830+
ret = CArray_Matmul(ap1, ap2t, NULL, out);
831+
832+
if (ret == NULL) {
833+
goto fail;
834+
}
835+
836+
CArray_DECREF(op2);
837+
CArrayDescriptor_DECREF(CArray_DESCR(op2));
838+
839+
CArrayDescriptor_DECREF(typec);
840+
CArrayDescriptor_FREE(typec);
841+
return ret;
842+
fail:
843+
return NULL;
780844
}

kernel/linalg.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,5 @@ CArray * CArray_Norm(CArray * a, int norm, MemoryPointer * out);
1313
CArray * CArray_Det(CArray * a, MemoryPointer * out);
1414
CArray * CArray_Vdot(CArray * target_a, CArray * target_b, MemoryPointer * out);
1515
CArray ** CArray_Svd(CArray * a, int full_matrices, int compute_uv, MemoryPointer * out);
16+
CArray * CArray_InnerProduct(CArray *a, CArray *b, MemoryPointer *out);
1617
#endif

phpsci.c

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -385,12 +385,21 @@ PHP_METHOD(CArray, reshape)
385385
Z_PARAM_ZVAL(target)
386386
Z_PARAM_ZVAL(new_shape_zval)
387387
ZEND_PARSE_PARAMETERS_END();
388+
389+
if(ZEND_NUM_ARGS() == 1) {
390+
throw_valueerror_exception("Expected 2 arguments");
391+
return;
392+
}
393+
388394
ZVAL_TO_MEMORYPOINTER(target, &ptr, NULL);
389395
carray = CArray_FromMemoryPointer(&ptr);
390396
new_shape = ZVAL_TO_TUPLE(new_shape_zval, &ndim);
391397
newcarray = CArray_Newshape(carray, new_shape, zend_hash_num_elements(Z_ARRVAL_P(new_shape_zval)), CARRAY_CORDER, &ptr);
392398
FREE_TUPLE(new_shape);
393399

400+
if (newcarray == NULL) {
401+
return;
402+
}
394403
RETURN_MEMORYPOINTER(return_value, &ptr);
395404
}
396405

@@ -1124,10 +1133,28 @@ PHP_METHOD(CArray, vdot)
11241133
}
11251134
PHP_METHOD(CArray, inner)
11261135
{
1127-
zval * a;
1128-
ZEND_PARSE_PARAMETERS_START(1, 1)
1129-
Z_PARAM_ZVAL(a)
1136+
MemoryPointer rtn_ptr, a_ptr, b_ptr;
1137+
zval *a, *b;
1138+
CArray *a_ca, *b_ca, *rtn_ca;
1139+
ZEND_PARSE_PARAMETERS_START(2, 2)
1140+
Z_PARAM_ZVAL(a)
1141+
Z_PARAM_ZVAL(b)
11301142
ZEND_PARSE_PARAMETERS_END();
1143+
ZVAL_TO_MEMORYPOINTER(a, &a_ptr, NULL);
1144+
ZVAL_TO_MEMORYPOINTER(b, &b_ptr, NULL);
1145+
1146+
a_ca = CArray_FromMemoryPointer(&a_ptr);
1147+
b_ca = CArray_FromMemoryPointer(&b_ptr);
1148+
1149+
rtn_ca = CArray_InnerProduct(a_ca, b_ca, &rtn_ptr);
1150+
1151+
FREE_FROM_MEMORYPOINTER(&a_ptr);
1152+
FREE_FROM_MEMORYPOINTER(&b_ptr);
1153+
if (rtn_ca == NULL) {
1154+
return;
1155+
}
1156+
1157+
RETURN_MEMORYPOINTER(return_value, &rtn_ptr);
11311158
}
11321159
PHP_METHOD(CArray, outer)
11331160
{

0 commit comments

Comments
 (0)