-
Notifications
You must be signed in to change notification settings - Fork 364
/
tensor.rs
1078 lines (991 loc) · 31.3 KB
/
tensor.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
use super::{BoolTensor, Device, FloatElem, FloatTensor, FullPrecisionBackend, IntElem, IntTensor};
use crate::{backend::Backend, tensor::Shape, Data, Distribution, ElementConversion};
use alloc::vec::Vec;
use burn_common::reader::Reader;
use core::ops::Range;
/// Operations on float tensors.
pub trait TensorOps<B: Backend> {
/// Creates a new float tensor from the data structure.
///
/// # Arguments
///
/// * `data` - The data structure (element values together with a `D`-dimensional shape).
/// * `device` - The device to create the tensor on.
///
/// # Returns
///
/// The tensor with the given data, created on `device`.
fn from_data<const D: usize>(
data: Data<FloatElem<B>, D>,
device: &Device<B>,
) -> FloatTensor<B, D>;
/// Creates a new tensor with random values.
///
/// # Arguments
///
/// * `shape` - The shape of the tensor.
/// * `distribution` - The distribution to sample from.
/// * `device` - The device to create the tensor on.
///
/// # Returns
///
/// The tensor with the given shape and values sampled from `distribution`.
///
/// # Remarks
///
/// This signature takes no seed or RNG argument; how sampling is seeded is left to the
/// implementing backend.
fn random<const D: usize>(
shape: Shape<D>,
distribution: Distribution,
device: &Device<B>,
) -> FloatTensor<B, D>;
/// Creates a new tensor filled with zeros.
///
/// # Arguments
///
/// * `shape` - The shape of the tensor.
/// * `device` - The device to create the tensor on.
///
/// # Returns
///
/// A tensor of the given shape in which every element is zero.
///
/// # Remarks
///
/// The default implementation builds a zero-filled [`Data`] and hands it to
/// [`TensorOps::from_data`]; backends with a native zero-fill may override it.
fn zeros<const D: usize>(shape: Shape<D>, device: &Device<B>) -> FloatTensor<B, D> {
    let zeroed = Data::zeros(shape);
    Self::from_data(zeroed, device)
}
/// Creates a new tensor filled with ones.
///
/// # Arguments
///
/// * `shape` - The shape of the tensor.
/// * `device` - The device to create the tensor on.
///
/// # Returns
///
/// A tensor of the given shape in which every element is one.
///
/// # Remarks
///
/// The default implementation builds a one-filled [`Data`] and hands it to
/// [`TensorOps::from_data`]; backends with a native fill may override it.
fn ones<const D: usize>(shape: Shape<D>, device: &Device<B>) -> FloatTensor<B, D> {
    let filled = Data::ones(shape);
    Self::from_data(filled, device)
}
/// Creates a tensor in which every element equals `fill_value`.
///
/// # Arguments
///
/// * `shape` - The shape of the tensor.
/// * `fill_value` - The value with which to fill the tensor.
/// * `device` - The device to create the tensor on.
///
/// # Returns
///
/// The tensor filled with the given value.
///
/// # Remarks
///
/// The default implementation adds `fill_value` to a zero tensor. For IEEE-754 element
/// types this means a fill value of `-0.0` comes out as `+0.0`.
fn full<const D: usize>(
    shape: Shape<D>,
    fill_value: FloatElem<B>,
    device: &Device<B>,
) -> FloatTensor<B, D> {
    let base = Self::zeros(shape, device);
    Self::add_scalar(base, fill_value)
}
/// Gets the shape of the tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor (borrowed; it is not consumed).
///
/// # Returns
///
/// The `D`-dimensional shape of the tensor.
fn shape<const D: usize>(tensor: &FloatTensor<B, D>) -> Shape<D>;
/// Reads the tensor's contents into a data structure without consuming the tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to read (borrowed).
///
/// # Returns
///
/// A [`Reader`] yielding the data structure with the tensor's data.
///
/// # Remarks
///
/// The default implementation clones the tensor and delegates to
/// [`TensorOps::into_data`].
fn to_data<const D: usize>(tensor: &FloatTensor<B, D>) -> Reader<Data<FloatElem<B>, D>> {
    let owned = tensor.clone();
    Self::into_data(owned)
}
/// Converts the tensor to a data structure, consuming the tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor, taken by value.
///
/// # Returns
///
/// A [`Reader`] wrapping the data structure with the tensor's data.
/// NOTE(review): the `Reader` wrapper suggests the read may be deferred (e.g. an async
/// device readback) — confirm against `burn_common::reader::Reader`.
fn into_data<const D: usize>(tensor: FloatTensor<B, D>) -> Reader<Data<FloatElem<B>, D>>;
/// Gets the device the tensor lives on.
///
/// # Arguments
///
/// * `tensor` - The tensor (borrowed; it is not consumed).
///
/// # Returns
///
/// An owned handle to the device of the tensor.
fn device<const D: usize>(tensor: &FloatTensor<B, D>) -> Device<B>;
/// Moves the tensor to the given device.
///
/// # Arguments
///
/// * `tensor` - The tensor, taken by value.
/// * `device` - The device to move the tensor to.
///
/// # Returns
///
/// The tensor on the given device, with the same rank `D`.
fn to_device<const D: usize>(
tensor: FloatTensor<B, D>,
device: &Device<B>,
) -> FloatTensor<B, D>;
/// Creates a one-dimensional int tensor holding every value of `range`.
///
/// # Arguments
///
/// * `range` - The half-open range of values (`start..end`).
/// * `device` - The device to create the tensor on.
///
/// # Returns
///
/// An int tensor containing `start, start + 1, ..., end - 1`.
///
/// # Remarks
///
/// Equivalent to [`TensorOps::arange_step`] with a step of 1.
fn arange(range: Range<usize>, device: &Device<B>) -> IntTensor<B, 1> {
    const UNIT_STEP: usize = 1;
    Self::arange_step(range, UNIT_STEP, device)
}
/// Converts float tensor to int tensor.
///
/// # Arguments
///
/// * `tensor` - The float tensor, taken by value.
///
/// # Returns
///
/// An int tensor of the same rank holding the converted values.
/// NOTE(review): the rounding/truncation rule for fractional values is not specified by
/// this trait and is presumably backend-defined — confirm.
fn into_int<const D: usize>(tensor: FloatTensor<B, D>) -> IntTensor<B, D>;
/// Creates a one-dimensional int tensor with values from `range`, advancing by `step`.
///
/// # Arguments
///
/// * `range` - The half-open range of values (`start..end`).
/// * `step` - The distance between consecutive values.
/// * `device` - The device to create the tensor on.
///
/// # Returns
///
/// An int tensor containing `start, start + step, start + 2 * step, ...`, stopping before `end`.
///
/// # Panics
///
/// Panics if `step` is 0 (propagated from [`Iterator::step_by`]).
fn arange_step(range: Range<usize>, step: usize, device: &Device<B>) -> IntTensor<B, 1> {
    let mut values: Vec<IntElem<B>> = Vec::new();
    for index in range.step_by(step) {
        values.push((index as i64).elem());
    }
    let length = values.len();
    B::int_from_data(Data::new(values, Shape::new([length])), device)
}
/// Creates an empty tensor with the given shape.
///
/// # Arguments
///
/// * `shape` - The shape of the tensor.
/// * `device` - The device to create the tensor on.
///
/// # Returns
///
/// A tensor with the given shape.
/// NOTE(review): the initial element values are presumably unspecified (uninitialized or
/// backend-defined). The default `repeat` writes every element of such a tensor before
/// returning it; do not rely on the contents being zero.
fn empty<const D: usize>(shape: Shape<D>, device: &Device<B>) -> FloatTensor<B, D>;
/// Repeat the tensor along the given dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `dim` - The dimension to repeat.
/// * `times` - The number of times to repeat the dimension.
///
/// # Returns
///
/// The tensor with the given dimension repeated.
fn repeat<const D: usize>(
tensor: FloatTensor<B, D>,
dim: usize,
times: usize,
) -> FloatTensor<B, D> {
let mut shape = B::shape(&tensor);
if shape.dims[dim] != 1 {
panic!("Can only repeat dimension with dim=1");
}
shape.dims[dim] = times;
let mut i = 0;
let indices_select_all = [0; D].map(|_| {
let start = 0;
let end = shape.dims[i];
i += 1;
start..end
});
let mut tensor_output = B::empty(shape, &B::device(&tensor));
for i in 0..times {
let mut indices = indices_select_all.clone();
indices[dim] = i..i + 1;
tensor_output = B::slice_assign(tensor_output, indices, tensor.clone());
}
tensor_output
}
/// Adds two tensors together.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The result of adding the two tensors together.
/// NOTE(review): shape requirements (identical shapes vs. broadcasting) are not fixed by
/// this signature and are left to the backend — confirm before relying on broadcasting.
fn add<const D: usize>(lhs: FloatTensor<B, D>, rhs: FloatTensor<B, D>) -> FloatTensor<B, D>;
/// Adds a scalar to every element of a tensor.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// The result of adding the scalar to the tensor.
fn add_scalar<const D: usize>(lhs: FloatTensor<B, D>, rhs: FloatElem<B>) -> FloatTensor<B, D>;
/// Clamps a tensor from below so that no element is less than `min`.
///
/// # Arguments
///
/// * `tensor` - The tensor to clamp.
/// * `min` - The minimum value.
///
/// # Returns
///
/// The clamped tensor: elements for which `lower_elem(x, min)` holds are replaced by `min`,
/// all others are kept.
fn clamp_min<const D: usize>(
    tensor: FloatTensor<B, D>,
    min: FloatElem<B>,
) -> FloatTensor<B, D> {
    // Default implementation: flag every element strictly below `min`, then overwrite them.
    let below_min = Self::lower_elem(tensor.clone(), min);
    B::mask_fill(tensor, below_min, min)
}
/// Clamps a tensor from above so that no element is greater than `max`.
///
/// # Arguments
///
/// * `tensor` - The tensor to clamp.
/// * `max` - The maximum value.
///
/// # Returns
///
/// The clamped tensor: elements for which `greater_elem(x, max)` holds are replaced by
/// `max`, all others are kept.
fn clamp_max<const D: usize>(
    tensor: FloatTensor<B, D>,
    max: FloatElem<B>,
) -> FloatTensor<B, D> {
    // Default implementation: flag every element strictly above `max`, then overwrite them.
    let above_max = Self::greater_elem(tensor.clone(), max);
    B::mask_fill(tensor, above_max, max)
}
/// Clamps a tensor so that every element lies within `[min, max]`.
///
/// # Arguments
///
/// * `tensor` - The tensor to clamp.
/// * `min` - The minimum value.
/// * `max` - The maximum value.
///
/// # Returns
///
/// The clamped tensor.
///
/// # Remarks
///
/// The default implementation applies the upper bound first and the lower bound last, so
/// if `min > max` every (comparable) element ends up equal to `min`.
fn clamp<const D: usize>(
    tensor: FloatTensor<B, D>,
    min: FloatElem<B>,
    max: FloatElem<B>,
) -> FloatTensor<B, D> {
    // Default implementation
    let upper_bounded = Self::clamp_max(tensor, max);
    Self::clamp_min(upper_bounded, min)
}
/// Subtracts two tensors (`lhs - rhs`).
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The result of subtracting the two tensors.
/// NOTE(review): shape requirements (identical shapes vs. broadcasting) are left to the
/// backend by this signature — confirm.
fn sub<const D: usize>(lhs: FloatTensor<B, D>, rhs: FloatTensor<B, D>) -> FloatTensor<B, D>;
/// Subtracts a scalar from every element of a tensor (`lhs - rhs`).
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// The result of subtracting the scalar from the tensor.
fn sub_scalar<const D: usize>(lhs: FloatTensor<B, D>, rhs: FloatElem<B>) -> FloatTensor<B, D>;
/// Multiplies two tensors together element-wise.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The element-wise product of the two tensors.
/// NOTE(review): shape requirements (identical shapes vs. broadcasting) are left to the
/// backend by this signature — confirm.
fn mul<const D: usize>(lhs: FloatTensor<B, D>, rhs: FloatTensor<B, D>) -> FloatTensor<B, D>;
/// Multiplies every element of a tensor by a scalar.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// The result of multiplying the tensor by the scalar.
fn mul_scalar<const D: usize>(lhs: FloatTensor<B, D>, rhs: FloatElem<B>) -> FloatTensor<B, D>;
/// Divides two tensors element-wise (`lhs / rhs`).
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor (dividend).
/// * `rhs` - The right hand side tensor (divisor).
///
/// # Returns
///
/// The result of dividing the two tensors.
/// NOTE(review): shape requirements (identical shapes vs. broadcasting) are left to the
/// backend by this signature — confirm.
fn div<const D: usize>(lhs: FloatTensor<B, D>, rhs: FloatTensor<B, D>) -> FloatTensor<B, D>;
/// Divides every element of a tensor by a scalar (`lhs / rhs`).
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor (dividend).
/// * `rhs` - The right hand side scalar (divisor).
///
/// # Returns
///
/// The result of dividing the tensor by the scalar.
fn div_scalar<const D: usize>(lhs: FloatTensor<B, D>, rhs: FloatElem<B>) -> FloatTensor<B, D>;
/// Multiplies two tensors together using matrix multiplication.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The result of multiplying the two tensors together using matrix multiplication.
/// NOTE(review): presumably the product is taken over the last two dimensions, with any
/// leading dimensions acting as batch dimensions — confirm against backend implementations.
fn matmul<const D: usize>(lhs: FloatTensor<B, D>, rhs: FloatTensor<B, D>) -> FloatTensor<B, D>;
/// Negates a tensor element-wise.
///
/// # Arguments
///
/// * `tensor` - The tensor to negate.
///
/// # Returns
///
/// A tensor holding `-x` for every element `x` of `tensor`.
///
/// # Remarks
///
/// The default implementation multiplies by `-1` via [`TensorOps::mul_scalar`].
fn neg<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D> {
    let minus_one: FloatElem<B> = (-1.0_f32).elem();
    Self::mul_scalar(tensor, minus_one)
}
/// Calculates the reciprocal (`1 / x`) of every element.
///
/// # Arguments
///
/// * `tensor` - The tensor.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` holding the element-wise reciprocals.
fn recip<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D>;
/// Transposes a tensor by swapping its last two dimensions.
///
/// # Arguments
///
/// * `tensor` - The tensor to transpose.
///
/// # Returns
///
/// The tensor with dimensions `D - 2` and `D - 1` swapped.
///
/// # Panics
///
/// Panics if `D < 2`, since there are not two dimensions to swap.
fn transpose<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D> {
    // Without this check `D - 2` underflows: debug builds panic with an opaque
    // arithmetic-overflow message and release builds wrap to a huge dimension index.
    assert!(
        D >= 2,
        "transpose requires a tensor with at least 2 dimensions, got {}",
        D
    );
    Self::swap_dims(tensor, D - 2, D - 1)
}
/// Swaps two dimensions of a tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to swap the dimensions of.
/// * `dim1` - The first dimension to swap.
/// * `dim2` - The second dimension to swap.
///
/// # Returns
///
/// The tensor with the dimensions swapped (same rank `D`).
/// NOTE(review): behavior for `dim1`/`dim2 >= D` is not specified here and is presumably
/// a backend panic — confirm.
fn swap_dims<const D: usize>(
tensor: FloatTensor<B, D>,
dim1: usize,
dim2: usize,
) -> FloatTensor<B, D>;
/// Reshapes a tensor, possibly changing its rank from `D1` to `D2`.
///
/// # Arguments
///
/// * `tensor` - The tensor to reshape.
/// * `shape` - The new shape of the tensor.
///
/// # Returns
///
/// The tensor with the new shape.
/// NOTE(review): `shape` presumably must describe the same number of elements as
/// `tensor`; this is not checked by the trait itself — confirm backend behavior.
fn reshape<const D1: usize, const D2: usize>(
tensor: FloatTensor<B, D1>,
shape: Shape<D2>,
) -> FloatTensor<B, D2>;
/// Gathers elements from a tensor along a dimension.
///
/// # Arguments
///
/// * `dim` - The dimension to gather from.
/// * `tensor` - The tensor to gather from.
/// * `indices` - The indices to gather; an int tensor with the same rank as `tensor`.
///
/// # Returns
///
/// The gathered elements.
/// NOTE(review): presumably the output takes the shape of `indices` (as in PyTorch's
/// `gather`) — confirm.
fn gather<const D: usize>(
dim: usize,
tensor: FloatTensor<B, D>,
indices: IntTensor<B, D>,
) -> FloatTensor<B, D>;
/// Scatters elements into a tensor along a dimension.
///
/// # Arguments
///
/// * `dim` - The dimension to scatter into.
/// * `tensor` - The tensor to scatter into.
/// * `indices` - The indices to scatter into; an int tensor with the same rank as `tensor`.
/// * `value` - The tensor of values to scatter; same rank as `tensor`.
///
/// # Returns
///
/// The tensor with the scattered elements.
/// NOTE(review): whether scattered values overwrite or are summed into the existing
/// elements is not fixed by this signature — confirm against backend implementations.
fn scatter<const D: usize>(
dim: usize,
tensor: FloatTensor<B, D>,
indices: IntTensor<B, D>,
value: FloatTensor<B, D>,
) -> FloatTensor<B, D>;
/// Selects tensor elements along the given dimension corresponding to the given indices.
///
/// # Arguments
///
/// * `tensor` - The tensor to select from.
/// * `dim` - The dimension to select from.
/// * `indices` - A one-dimensional int tensor of the indices to select along `dim`.
///
/// # Returns
///
/// The selected elements, as a tensor of the same rank `D`.
fn select<const D: usize>(
tensor: FloatTensor<B, D>,
dim: usize,
indices: IntTensor<B, 1>,
) -> FloatTensor<B, D>;
/// Assigns the selected elements along the given dimension, corresponding to the given
/// indices, from the given value tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to select from.
/// * `dim` - The dimension to select from.
/// * `indices` - A one-dimensional int tensor of the indices to select along `dim`.
/// * `value` - The value to assign.
///
/// # Returns
///
/// The tensor with the selected elements assigned to the given value.
/// NOTE(review): whether `value` overwrites or is summed into the selected elements is
/// not fixed by this signature — confirm against backend implementations.
fn select_assign<const D: usize>(
tensor: FloatTensor<B, D>,
dim: usize,
indices: IntTensor<B, 1>,
value: FloatTensor<B, D>,
) -> FloatTensor<B, D>;
/// Selects tensor elements corresponding to the given ranges.
///
/// # Arguments
///
/// * `tensor` - The tensor to select from.
/// * `ranges` - `D2` half-open ranges for a tensor of rank `D1`.
///
/// # Returns
///
/// The selected elements in a new tensor of the same rank `D1`.
/// NOTE(review): presumably `D2 <= D1`, the ranges apply to the leading dimensions, and
/// the remaining dimensions are kept whole — confirm against backend implementations.
fn slice<const D1: usize, const D2: usize>(
tensor: FloatTensor<B, D1>,
ranges: [Range<usize>; D2],
) -> FloatTensor<B, D1>;
/// Assigns the elements selected by the given ranges from the given value tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to select from.
/// * `ranges` - `D2` half-open ranges for a tensor of rank `D1`.
/// * `value` - The value to assign; presumably its shape must match the selected region.
///
/// # Returns
///
/// The tensor with the selected elements assigned to the given value.
fn slice_assign<const D1: usize, const D2: usize>(
tensor: FloatTensor<B, D1>,
ranges: [Range<usize>; D2],
value: FloatTensor<B, D1>,
) -> FloatTensor<B, D1>;
/// Updates the given tensor with the value tensor where the mask is true.
///
/// # Arguments
///
/// * `tensor` - The tensor to select from.
/// * `mask` - The boolean mask to select with.
/// * `value` - The tensor supplying the elements to use where `mask` is true.
///
/// # Returns
///
/// A tensor taking its elements from `value` where `mask` is true and from `tensor`
/// elsewhere.
fn mask_where<const D: usize>(
tensor: FloatTensor<B, D>,
mask: BoolTensor<B, D>,
value: FloatTensor<B, D>,
) -> FloatTensor<B, D>;
/// Updates the given tensor with a scalar value where the mask is true.
///
/// # Arguments
///
/// * `tensor` - The tensor to select from.
/// * `mask` - The boolean mask to select with.
/// * `value` - The value to assign to the selected elements.
///
/// # Returns
///
/// A tensor equal to `value` where `mask` is true and to `tensor` elsewhere.
fn mask_fill<const D: usize>(
tensor: FloatTensor<B, D>,
mask: BoolTensor<B, D>,
value: FloatElem<B>,
) -> FloatTensor<B, D>;
/// Element-wise equality comparison of two tensors.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// A boolean tensor that is `true` wherever the `lhs` element equals the corresponding
/// `rhs` element.
fn equal<const D: usize>(lhs: FloatTensor<B, D>, rhs: FloatTensor<B, D>) -> BoolTensor<B, D>;
/// Element-wise equality comparison of a tensor and a scalar.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// A boolean tensor that is `true` wherever the `lhs` element equals `rhs`.
fn equal_elem<const D: usize>(lhs: FloatTensor<B, D>, rhs: FloatElem<B>) -> BoolTensor<B, D>;
/// Element-wise greater-than comparison of two tensors.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// A boolean tensor that is `true` wherever the `lhs` element is strictly greater than the
/// corresponding `rhs` element.
fn greater<const D: usize>(lhs: FloatTensor<B, D>, rhs: FloatTensor<B, D>) -> BoolTensor<B, D>;
/// Element-wise greater-than comparison of a tensor and a scalar.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// A boolean tensor that is `true` wherever the `lhs` element is strictly greater than `rhs`.
fn greater_elem<const D: usize>(lhs: FloatTensor<B, D>, rhs: FloatElem<B>) -> BoolTensor<B, D>;
/// Element-wise greater-than-or-equal comparison of two tensors.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// A boolean tensor that is `true` wherever the `lhs` element is greater than or equal to
/// the corresponding `rhs` element.
fn greater_equal<const D: usize>(
lhs: FloatTensor<B, D>,
rhs: FloatTensor<B, D>,
) -> BoolTensor<B, D>;
/// Element-wise greater-than-or-equal comparison of a tensor and a scalar.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// A boolean tensor that is `true` wherever the `lhs` element is greater than or equal to
/// `rhs`.
fn greater_equal_elem<const D: usize>(
lhs: FloatTensor<B, D>,
rhs: FloatElem<B>,
) -> BoolTensor<B, D>;
/// Element-wise less-than comparison of two tensors (named `lower` in this API).
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// A boolean tensor that is `true` wherever the `lhs` element is strictly less than the
/// corresponding `rhs` element.
fn lower<const D: usize>(lhs: FloatTensor<B, D>, rhs: FloatTensor<B, D>) -> BoolTensor<B, D>;
/// Element-wise less-than comparison of a tensor and a scalar (named `lower_elem` in this API).
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// A boolean tensor that is `true` wherever the `lhs` element is strictly less than `rhs`.
fn lower_elem<const D: usize>(lhs: FloatTensor<B, D>, rhs: FloatElem<B>) -> BoolTensor<B, D>;
/// Element-wise less-than-or-equal comparison of two tensors.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// A boolean tensor that is `true` wherever the `lhs` element is less than or equal to the
/// corresponding `rhs` element.
fn lower_equal<const D: usize>(
lhs: FloatTensor<B, D>,
rhs: FloatTensor<B, D>,
) -> BoolTensor<B, D>;
/// Element-wise less-than-or-equal comparison of a tensor and a scalar.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// A boolean tensor that is `true` wherever the `lhs` element is less than or equal to `rhs`.
fn lower_equal_elem<const D: usize>(
lhs: FloatTensor<B, D>,
rhs: FloatElem<B>,
) -> BoolTensor<B, D>;
/// Detaches a tensor from the computation graph.
///
/// # Arguments
///
/// * `tensor` - The tensor to detach.
///
/// # Returns
///
/// The detached tensor. The default implementation returns `tensor` unchanged, which is
/// what a backend without gradient tracking needs.
fn detach<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D> {
// Should only be overridden by autodiff backends.
tensor
}
/// Sets the `require_grad` flag of a tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `_require_grad` - Whether gradients should be tracked for `tensor`.
///
/// # Returns
///
/// The tensor. The default implementation ignores the flag and returns `tensor` unchanged.
fn set_require_grad<const D: usize>(
tensor: FloatTensor<B, D>,
_require_grad: bool,
) -> FloatTensor<B, D> {
// Should only be overridden by autodiff backends.
tensor
}
/// Returns the `require_grad` flag of a tensor.
///
/// # Arguments
///
/// * `_tensor` - The tensor (borrowed).
///
/// # Returns
///
/// Whether gradients are tracked for the tensor. The default implementation always
/// returns `false`.
fn is_require_grad<const D: usize>(_tensor: &FloatTensor<B, D>) -> bool {
// Should only be overridden by autodiff backends.
false
}
/// Sum of all elements in a tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to sum.
///
/// # Returns
///
/// A rank-1 tensor holding the sum of all elements in `tensor` (presumably of shape `[1]`).
fn sum<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, 1>;
/// Sum of all elements in a tensor along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to sum.
/// * `dim` - The dimension along which to sum.
///
/// # Returns
///
/// A tensor of the same rank `D` with the sum of all elements in `tensor` along `dim`.
/// NOTE(review): the reduced dimension is presumably kept with size 1 — confirm.
fn sum_dim<const D: usize>(tensor: FloatTensor<B, D>, dim: usize) -> FloatTensor<B, D>;
/// Mean of all elements in a tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to average.
///
/// # Returns
///
/// A rank-1 tensor holding the mean of all elements in `tensor`.
///
/// # Remarks
///
/// The default implementation divides `B::sum(tensor)` by the element count. For a tensor
/// with no elements this divides by zero, which presumably yields NaN for IEEE-754 element
/// types.
fn mean<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, 1> {
    let element_count = B::shape(&tensor).num_elements() as i64;
    let total = B::sum(tensor);
    B::div_scalar(total, element_count.elem())
}
/// Mean of all elements in a tensor along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to average.
/// * `dim` - The dimension along which to average.
///
/// # Returns
///
/// A tensor of the same rank `D` with the mean of all elements in `tensor` along `dim`.
/// NOTE(review): the reduced dimension is presumably kept with size 1 — confirm.
fn mean_dim<const D: usize>(tensor: FloatTensor<B, D>, dim: usize) -> FloatTensor<B, D>;
/// Converts a tensor to full precision.
///
/// # Arguments
///
/// * `tensor` - The tensor to convert (borrowed; it is not consumed).
///
/// # Returns
///
/// A new tensor of the `FullPrecisionBackend<B>` backend with the same values as `tensor`.
fn to_full_precision<const D: usize>(
tensor: &FloatTensor<B, D>,
) -> FloatTensor<FullPrecisionBackend<B>, D>;
/// Converts a tensor from full precision back to this backend.
///
/// # Arguments
///
/// * `tensor` - The `FullPrecisionBackend<B>` tensor to convert, taken by value.
///
/// # Returns
///
/// A tensor with the same values as `tensor` but with the precision of the backend
/// (values may be rounded if the backend's float type is narrower).
fn from_full_precision<const D: usize>(
tensor: FloatTensor<FullPrecisionBackend<B>, D>,
) -> FloatTensor<B, D>;
/// Returns a new tensor with exponential values (`e^x` for every element `x`).
///
/// # Arguments
///
/// * `tensor` - The tensor to exponentiate.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with exponential values.
fn exp<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D>;
/// Returns a new tensor with natural logarithm values (`ln(x)` for every element `x`).
///
/// # Arguments
///
/// * `tensor` - The tensor to take the logarithm of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with natural logarithm values.
fn log<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D>;
/// Returns a new tensor with the logarithm of `1 + x` for every element `x`
/// (presumably the natural logarithm, matching `log`).
///
/// # Arguments
///
/// * `tensor` - The tensor to take the logarithm of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with logarithm values of `1 + x`.
fn log1p<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D>;
/// Returns a new tensor with values raised to the power of `value`.
///
/// # Arguments
///
/// * `tensor` - The tensor to exponentiate.
/// * `value` - The exponent, always an `f32` regardless of the backend's float element type.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with values raised to the power of `value`.
fn powf<const D: usize>(tensor: FloatTensor<B, D>, value: f32) -> FloatTensor<B, D>;
/// Returns a new tensor with the square root of every element.
///
/// # Arguments
///
/// * `tensor` - The tensor to take the square root of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with square root values.
fn sqrt<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D>;
/// Returns a new tensor with the absolute value of every element.
///
/// # Arguments
///
/// * `tensor` - The tensor to take absolute value of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with absolute values.
fn abs<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D>;
/// Returns a new tensor with the cosine of every element.
///
/// # Arguments
///
/// * `tensor` - The tensor to take the cosine of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with cosine values.
fn cos<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D>;
/// Returns a new tensor with the sine of every element.
///
/// # Arguments
///
/// * `tensor` - The tensor to take the sine of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with sine values.
fn sin<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D>;
/// Returns a new tensor with the hyperbolic tangent of every element.
///
/// # Arguments
///
/// * `tensor` - The tensor to take the hyperbolic tangent of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with hyperbolic tangent values.
fn tanh<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D>;
/// Returns a new tensor with the error function applied to every element.
///
/// # Arguments
///
/// * `tensor` - The tensor to take the error function of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with error function values.
fn erf<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, D>;
/// Concatenates tensors along a dimension.
///
/// # Arguments
///
/// * `tensors` - The tensors to concatenate, all of rank `D`.
/// * `dim` - The dimension along which to concatenate.
///
/// # Returns
///
/// A tensor with the concatenated tensors along `dim`.
/// NOTE(review): the inputs presumably must agree in every dimension except `dim`, and an
/// empty `tensors` vector is presumably rejected by the backend — confirm.
fn cat<const D: usize>(tensors: Vec<FloatTensor<B, D>>, dim: usize) -> FloatTensor<B, D>;
/// Gets the indices of the maximum elements of a tensor along an axis.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the maximum elements of.
/// * `dim` - The dimension along which to get the maximum elements.
///
/// # Returns
///
/// An int tensor of the same rank `D` with the indices of the maximum elements of `tensor`
/// along `dim`.
/// NOTE(review): the reduced dimension is presumably kept with size 1, and tie-breaking
/// between equal maxima is backend-defined — confirm.
fn argmax<const D: usize>(tensor: FloatTensor<B, D>, dim: usize) -> IntTensor<B, D>;
/// Gets the indices of the minimum elements of a tensor along an axis.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the minimum elements of.
/// * `dim` - The dimension along which to get the minimum elements.
///
/// # Returns
///
/// An int tensor of the same rank `D` with the indices of the minimum elements of `tensor`
/// along `dim`.
/// NOTE(review): the reduced dimension is presumably kept with size 1, and tie-breaking
/// between equal minima is backend-defined — confirm.
fn argmin<const D: usize>(tensor: FloatTensor<B, D>, dim: usize) -> IntTensor<B, D>;
/// Gets the maximum element of a tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the maximum element of.
///
/// # Returns
///
/// A rank-1 tensor with the maximum element of `tensor`.
///
/// # Remarks
///
/// The default implementation flattens `tensor` into one dimension and reduces it with
/// `max_dim` along dimension 0.
fn max<const D: usize>(tensor: FloatTensor<B, D>) -> FloatTensor<B, 1> {
    let flat_len = B::shape(&tensor).num_elements();
    let flattened = B::reshape(tensor, Shape::new([flat_len]));
    B::max_dim(flattened, 0)
}
/// Gets the maximum elements of a tensor along an axis.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the maximum elements of.
/// * `dim` - The dimension along which to get the maximum elements.
///
/// # Returns
///
/// A tensor with the maximum elements of `tensor` along `dim`.