@@ -99,10 +99,11 @@ RAI_Tensor *RAI_TensorNew(void) {
9999 RAI_Tensor * ret = RedisModule_Calloc (1 , sizeof (* ret ));
100100 ret -> refCount = 1 ;
101101 ret -> len = LEN_UNKOWN ;
102+ return ret ;
102103}
103104
104105RAI_Tensor * RAI_TensorCreateWithDLDataType (DLDataType dtype , long long * dims , int ndims ,
105- int tensorAllocMode ) {
106+ bool empty ) {
106107
107108 size_t dtypeSize = Tensor_DataTypeSize (dtype );
108109 if (dtypeSize == 0 ) {
@@ -124,20 +125,14 @@ RAI_Tensor *RAI_TensorCreateWithDLDataType(DLDataType dtype, long long *dims, in
124125 }
125126
126127 DLDevice device = (DLDevice ){.device_type = kDLCPU , .device_id = 0 };
127- void * data = NULL ;
128- switch (tensorAllocMode ) {
129- case TENSORALLOC_ALLOC :
130- data = RedisModule_Alloc (len * dtypeSize );
131- break ;
132- case TENSORALLOC_CALLOC :
128+
129+ // If we return an empty tensor, we initialize the data with zeros to avoid security
130+ // issues. Otherwise, we only allocate without initializing (for better performance)
131+ void * data ;
132+ if (empty ) {
133133 data = RedisModule_Calloc (len , dtypeSize );
134- break ;
135- case TENSORALLOC_NONE :
136- /* shallow copy no alloc */
137- default :
138- /* assume TENSORALLOC_NONE
139- shallow copy no alloc */
140- break ;
134+ } else {
135+ data = RedisModule_Alloc (len * dtypeSize );
141136 }
142137
143138 ret -> tensor = (DLManagedTensor ){.dl_tensor = (DLTensor ){.device = device ,
@@ -214,27 +209,11 @@ RAI_Tensor *_TensorCreateWithDLDataTypeAndRString(DLDataType dtype, size_t dtype
214209 return ret ;
215210}
216211
217- RAI_Tensor * RAI_TensorCreate (const char * dataType , long long * dims , int ndims , int hasdata ) {
212+ // Important note: the tensor data must be initialized after the creation.
213+ RAI_Tensor * RAI_TensorCreate (const char * dataType , long long * dims , int ndims ) {
218214 DLDataType dtype = RAI_TensorDataTypeFromString (dataType );
219- return RAI_TensorCreateWithDLDataType (dtype , dims , ndims , TENSORALLOC_ALLOC );
220- }
221-
222- #if 0
223- void RAI_TensorMoveFrom (RAI_Tensor * dst , RAI_Tensor * src ) {
224- if (-- dst -> refCount <= 0 ){
225- RedisModule_Free (t -> tensor .shape );
226- if (t -> tensor .strides ) {
227- RedisModule_Free (t -> tensor .strides );
228- }
229- RedisModule_Free (t -> tensor .data );
230- RedisModule_Free (t );
231- }
232- dst -> tensor .ctx = src -> tensor .ctx ;
233- dst -> tensor .data = src -> tensor .data ;
234-
235- dst -> refCount = 1 ;
215+ return RAI_TensorCreateWithDLDataType (dtype , dims , ndims , false);
236216}
237- #endif
238217
239218RAI_Tensor * RAI_TensorCreateByConcatenatingTensors (RAI_Tensor * * ts , long long n ) {
240219
@@ -273,7 +252,7 @@ RAI_Tensor *RAI_TensorCreateByConcatenatingTensors(RAI_Tensor **ts, long long n)
273252
274253 DLDataType dtype = RAI_TensorDataType (ts [0 ]);
275254
276- RAI_Tensor * ret = RAI_TensorCreateWithDLDataType (dtype , dims , ndims , TENSORALLOC_ALLOC );
255+ RAI_Tensor * ret = RAI_TensorCreateWithDLDataType (dtype , dims , ndims , false );
277256
278257 for (long long i = 0 ; i < n ; i ++ ) {
279258 memcpy (RAI_TensorData (ret ) + batch_offsets [i ] * sample_size * dtype_size ,
@@ -300,7 +279,7 @@ RAI_Tensor *RAI_TensorCreateBySlicingTensor(RAI_Tensor *t, long long offset, lon
300279
301280 DLDataType dtype = RAI_TensorDataType (t );
302281
303- RAI_Tensor * ret = RAI_TensorCreateWithDLDataType (dtype , dims , ndims , TENSORALLOC_ALLOC );
282+ RAI_Tensor * ret = RAI_TensorCreateWithDLDataType (dtype , dims , ndims , false );
304283
305284 memcpy (RAI_TensorData (ret ), RAI_TensorData (t ) + offset * sample_size * dtype_size ,
306285 len * sample_size * dtype_size );
@@ -329,14 +308,14 @@ int RAI_TensorDeepCopy(RAI_Tensor *t, RAI_Tensor **dest) {
329308
330309 DLDataType dtype = RAI_TensorDataType (t );
331310
332- RAI_Tensor * ret = RAI_TensorCreateWithDLDataType (dtype , dims , ndims , TENSORALLOC_ALLOC );
311+ RAI_Tensor * ret = RAI_TensorCreateWithDLDataType (dtype , dims , ndims , false );
333312
334313 memcpy (RAI_TensorData (ret ), RAI_TensorData (t ), sample_size * dtype_size );
335314 * dest = ret ;
336315 return 0 ;
337316}
338317
339- // Beware: this will take ownership of dltensor
318+ // Beware: this will take ownership of dltensor.
340319RAI_Tensor * RAI_TensorCreateFromDLTensor (DLManagedTensor * dl_tensor ) {
341320
342321 RAI_Tensor * ret = RAI_TensorNew ();
@@ -419,19 +398,15 @@ int RAI_TensorSetValueFromLongLong(RAI_Tensor *t, long long i, long long val) {
419398 case 8 :
420399 ((int8_t * )data )[i ] = val ;
421400 break ;
422- break ;
423401 case 16 :
424402 ((int16_t * )data )[i ] = val ;
425403 break ;
426- break ;
427404 case 32 :
428405 ((int32_t * )data )[i ] = val ;
429406 break ;
430- break ;
431407 case 64 :
432408 ((int64_t * )data )[i ] = val ;
433409 break ;
434- break ;
435410 default :
436411 return 0 ;
437412 }
@@ -440,19 +415,15 @@ int RAI_TensorSetValueFromLongLong(RAI_Tensor *t, long long i, long long val) {
440415 case 8 :
441416 ((uint8_t * )data )[i ] = val ;
442417 break ;
443- break ;
444418 case 16 :
445419 ((uint16_t * )data )[i ] = val ;
446420 break ;
447- break ;
448421 case 32 :
449422 ((uint32_t * )data )[i ] = val ;
450423 break ;
451- break ;
452424 case 64 :
453425 ((uint64_t * )data )[i ] = val ;
454426 break ;
455- break ;
456427 default :
457428 return 0 ;
458429 }
@@ -642,7 +613,6 @@ int RAI_parseTensorSetArgs(RedisModuleString **argv, int argc, RAI_Tensor **t, i
642613
643614 const char * fmtstr ;
644615 int datafmt = TENSOR_NONE ;
645- int tensorAllocMode = TENSORALLOC_CALLOC ;
646616 size_t ndims = 0 ;
647617 long long len = 1 ;
648618 long long * dims = (long long * )array_new (long long , 1 );
@@ -656,7 +626,6 @@ int RAI_parseTensorSetArgs(RedisModuleString **argv, int argc, RAI_Tensor **t, i
656626 remaining_args = argc - 1 - argpos ;
657627 if (!strcasecmp (opt , "BLOB" )) {
658628 datafmt = TENSOR_BLOB ;
659- tensorAllocMode = TENSORALLOC_CALLOC ;
660629 // if we've found the dataformat there are no more dimensions
661630 // check right away if the arity is correct
662631 if (remaining_args != 1 && enforceArity == 1 ) {
@@ -669,7 +638,6 @@ int RAI_parseTensorSetArgs(RedisModuleString **argv, int argc, RAI_Tensor **t, i
669638 break ;
670639 } else if (!strcasecmp (opt , "VALUES" )) {
671640 datafmt = TENSOR_VALUES ;
672- tensorAllocMode = TENSORALLOC_CALLOC ;
673641 // if we've found the dataformat there are no more dimensions
674642 // check right away if the arity is correct
675643 if (remaining_args != len && enforceArity == 1 ) {
@@ -699,7 +667,8 @@ int RAI_parseTensorSetArgs(RedisModuleString **argv, int argc, RAI_Tensor **t, i
699667 RedisModuleString * rstr = argv [argpos ];
700668 * t = _TensorCreateWithDLDataTypeAndRString (datatype , datasize , dims , ndims , rstr , error );
701669 } else {
702- * t = RAI_TensorCreateWithDLDataType (datatype , dims , ndims , tensorAllocMode );
670+ bool is_empty = (datafmt == TENSOR_NONE );
671+ * t = RAI_TensorCreateWithDLDataType (datatype , dims , ndims , is_empty );
703672 }
704673 if (!(* t )) {
705674 array_free (dims );
0 commit comments