Skip to content

Commit

Permalink
Simplify types used in math (#14928)
Browse files Browse the repository at this point in the history
  • Loading branch information
RaananW committed Apr 5, 2024
1 parent 56ae711 commit e363fd0
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions packages/dev/core/src/Maths/tensor.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { DeepImmutable, Flatten, FloatArray, Length } from "../types";
import type { DeepImmutable, Flatten, FloatArray, Length, Tuple } from "../types";
/**
* Computes the tensor dimension of a multi-dimensional array
*/
Expand All @@ -14,6 +14,8 @@ export type TensorValue = number[] | TensorValue[];
*/
export type ValueOfTensor<T = unknown> = T extends Tensor<infer V> ? V : TensorValue;

type TensorNumberArray<V extends TensorValue> = Length<Dimension<V>> extends 2 ? Tuple<number, 16> : V;

/**
* Describes a mathematical tensor.
* @see https://wikipedia.org/wiki/Tensor
Expand Down Expand Up @@ -67,7 +69,7 @@ export interface Tensor<V extends TensorValue = TensorValue> {
* Copy the current instance to an array
* @returns a new array with the instance coordinates.
*/
asArray(): Flatten<V>;
asArray(): TensorNumberArray<V>;

/**
* Sets the current instance coordinates with the given source coordinates
Expand All @@ -81,13 +83,13 @@ export interface Tensor<V extends TensorValue = TensorValue> {
* @returns the current updated instance
*/

copyFromFloats(...floats: Flatten<V>): this;
copyFromFloats(...floats: TensorNumberArray<V>): this;

/**
* Sets the instance coordinates with the given floats
* @returns the current updated instance
*/
set(...values: Flatten<V>): this;
set(...values: TensorNumberArray<V>): this;

/**
* Sets the instance coordinates to the given value
Expand Down Expand Up @@ -122,7 +124,7 @@ export interface Tensor<V extends TensorValue = TensorValue> {
* @param floats the floats to add
* @returns the current updated instance
*/
addInPlaceFromFloats(...floats: Flatten<V>): this;
addInPlaceFromFloats(...floats: TensorNumberArray<V>): this;

/**
* Returns a new instance set with the subtracted coordinates of other's coordinates from the current coordinates.
Expand Down Expand Up @@ -151,15 +153,15 @@ export interface Tensor<V extends TensorValue = TensorValue> {
* @param floats the coordinates to subtract
* @returns the resulting instance
*/
subtractFromFloats(...floats: Flatten<V>): this;
subtractFromFloats(...floats: TensorNumberArray<V>): this;

/**
* Subtracts the given floats from the current instance coordinates and set the given instance "result" with this result
* Note: Implementation uses array magic so types may be confusing.
* @param args the coordinates to subtract with the last element as the result
* @returns the result
*/
subtractFromFloatsToRef(...args: [...Flatten<V>, this]): this;
subtractFromFloatsToRef(...args: [...TensorNumberArray<V>, this]): this;

/**
* Returns a new instance set with the multiplication of the current instance and the given one coordinates
Expand Down Expand Up @@ -187,7 +189,7 @@ export interface Tensor<V extends TensorValue = TensorValue> {
* Gets a new instance set with the instance coordinates multiplied by the given floats
* @returns a new instance
*/
multiplyByFloats(...floats: Flatten<V>): this;
multiplyByFloats(...floats: TensorNumberArray<V>): this;

/**
* Returns a new instance set with the instance coordinates divided by the given one coordinates
Expand Down Expand Up @@ -223,7 +225,7 @@ export interface Tensor<V extends TensorValue = TensorValue> {
* @param floats defines the floats to compare against
* @returns this current updated instance
*/
minimizeInPlaceFromFloats(...floats: Flatten<V>): this;
minimizeInPlaceFromFloats(...floats: TensorNumberArray<V>): this;

/**
* Updates the current instance with the maximal coordinate values between its and the given instance ones.
Expand All @@ -237,7 +239,7 @@ export interface Tensor<V extends TensorValue = TensorValue> {
* @param floats defines the floats to compare against
* @returns this current updated instance
*/
maximizeInPlaceFromFloats(...floats: Flatten<V>): this;
maximizeInPlaceFromFloats(...floats: TensorNumberArray<V>): this;

/**
* Gets a new instance with current instance negated coordinates
Expand Down Expand Up @@ -308,7 +310,7 @@ export interface Tensor<V extends TensorValue = TensorValue> {
* @param floats defines the coordinates to compare against
* @returns true if both instances are equal
*/
equalsToFloats(...floats: Flatten<V>): boolean;
equalsToFloats(...floats: TensorNumberArray<V>): boolean;

/**
* Gets a new instance from current instance floored values
Expand Down

0 comments on commit e363fd0

Please sign in to comment.