mean/sum/stddev of multiple dimensions #216

Closed
dsyme opened this issue Oct 12, 2020 · 2 comments

dsyme (Collaborator) commented Oct 12, 2020

We should add overloads of mean, sum, stddev, etc. that take multiple dimensions.

From the PyTorch documentation for torch.mean:

"Returns the mean value of each row of the input tensor in the given dimension dim. If dim is a list of dimensions, reduce over all of them."
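Reducing over all of the listed dimensions means one reduction over the combined elements. For sum and mean, that gives the same answer as reducing one dimension at a time, because every slice has the same number of elements. For variance and standard deviation it does not. Here is a small plain-F# check with no DiffSharp involved. The 2x2 data and the population-variance helper are illustrative only:

```fsharp
// Population variance (divide by n) of a flat array of floats; illustrative helper only.
let variance (xs: float[]) =
    let mu = Array.average xs
    xs |> Array.averageBy (fun x -> (x - mu) ** 2.0)

// A 2x2 "tensor" stored as rows.
let rows = [| [| 1.0; 2.0 |]; [| 3.0; 10.0 |] |]

// Reducing over both dimensions at once: the variance of all four elements.
let direct = variance (Array.concat rows)

// Reducing one dimension at a time: the variance of each row, then the variance of those.
let chained = rows |> Array.map variance |> variance

// Mean, by contrast, composes, because every row has the same length.
let meanDirect = Array.average (Array.concat rows)
let meanChained = rows |> Array.map Array.average |> Array.average

printfn "variance: direct=%g chained=%g" direct chained
printfn "mean:     direct=%g chained=%g" meanDirect meanChained
```

When run, it prints 12.5 for the direct variance but 36 for the chained one, while both means come out as 4.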

dsyme (Collaborator, Author) commented Oct 12, 2020

Here's the code if you need this before it's added:

open DiffSharp

type Tensor with
    // Mean (like sum) composes, so reducing one dimension at a time gives the same result as
    // reducing over all of them at once. Dims are folded in descending order so that, with
    // keepDim=false, removing one dimension does not shift the indices of the ones still to come.
    member t.mean (dims: int[], ?keepDim: bool) =
        (t, Array.rev (Array.sort dims)) ||> Array.fold (fun input dim -> dsharp.mean(input, dim, ?keepDim=keepDim))

    // Variance does not compose that way (the variance of per-slice variances is not the variance
    // over all reduced elements), so compute it directly around the multi-dimension mean.
    // Assumes the single-dimension dsharp.variance uses the unbiased (n - 1) estimator.
    member t.variance (dims: int[], ?keepDim: bool) =
        let keepDim = defaultArg keepDim false
        // n = number of elements combined into each output value
        let n = dims |> Array.fold (fun acc d -> acc * t.shape.[d]) 1
        let diff = t - t.mean(dims, keepDim=true)
        (diff * diff).sum(dims, keepDim=keepDim) / float (n - 1)

    member t.sum (dims: int[], ?keepDim: bool) =
        (t, Array.rev (Array.sort dims)) ||> Array.fold (fun input dim -> dsharp.sum(input, dim, ?keepDim=keepDim))

    // Standard deviation is the square root of the multi-dimension variance, not a chain of stddevs.
    member t.stddev (dims: int[], ?keepDim: bool) =
        t.variance(dims, ?keepDim=keepDim).sqrt()

    member t.moments () =
        dsharp.mean(t), dsharp.stddev(t)

    member t.moments (dim: int, ?keepDim: bool) =
        t.mean(dim, ?keepDim=keepDim), t.stddev(dim, ?keepDim=keepDim)

    member t.moments (dims: int[], ?keepDim: bool) =
        t.mean(dims, ?keepDim=keepDim), t.stddev(dims, ?keepDim=keepDim)


type dsharp with 
    static member mean (input: Tensor, dims: int[], ?keepDim: bool) = input.mean(dims, ?keepDim=keepDim)

    static member variance (input: Tensor, dims: int[], ?keepDim: bool) = input.variance(dims, ?keepDim=keepDim)

    static member sum (input: Tensor, dims: int[], ?keepDim: bool) = input.sum(dims, ?keepDim=keepDim)

    static member stddev (input: Tensor, dims: int[], ?keepDim: bool) = input.stddev(dims, ?keepDim=keepDim)

    static member moments (input: Tensor) = input.moments()

    static member moments (input: Tensor, dim: int, ?keepDim: bool) = input.moments(dim, ?keepDim=keepDim)

    static member moments (input: Tensor, dims: int[], ?keepDim: bool) = input.moments(dims, ?keepDim=keepDim)
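The workaround folds the dimensions in descending order. A shape-only sketch shows why that matters with keepDim=false: reducing a dimension removes it, which shifts every later index down by one. reducedShape and reduceAll below are hypothetical helpers that only track shapes. They are not DiffSharp functions:

```fsharp
// Hypothetical helper: the shape left after reducing `dim` with keepDim=false.
// Fails if `dim` is out of range, as a real tensor reduction would.
let reducedShape (shape: int[]) (dim: int) =
    if dim < 0 || dim >= shape.Length then
        invalidArg "dim" (sprintf "dim %d out of range for shape %A" dim shape)
    shape |> Array.indexed |> Array.filter (fun (i, _) -> i <> dim) |> Array.map snd

// Reduce several dims one at a time, in the order given.
let reduceAll (shape: int[]) (dims: int[]) = Array.fold reducedShape shape dims

let shape = [| 2; 3; 4 |]

// Descending order (as in the workaround): dim 2 goes first, so dim 0 still names the same axis.
let descending = reduceAll shape (Array.rev (Array.sort [| 0; 2 |]))

// Ascending order: after removing dim 0 the shape is [|3; 4|], so "dim 2" no longer exists.
let ascending =
    try Ok (reduceAll shape [| 0; 2 |])
    with :? System.ArgumentException as e -> Error e.Message

printfn "descending: %A" descending
printfn "ascending:  %A" ascending
```

Take [|0; 2|] on a [2; 3; 4] shape. The descending fold ends at [|3|]. The ascending fold tries to remove dimension 2 from [|3; 4|] and fails. With keepDim=true no dimension is removed, so the order would not matter.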

dsyme (Collaborator, Author) commented Nov 20, 2020

Tracking this in #227

dsyme closed this as completed Nov 20, 2020