Skip to content

Commit

Permalink
Improve aggregation speeds by summing function
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Aug 24, 2023
1 parent 7318c58 commit a8f7f46
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/losses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ include("losses/weighted.jl")
Return sum of `loss` values over the iterables `outputs` and `targets`.
"""
function sum(loss::SupervisedLoss, outputs, targets)
sum(loss(ŷ, y) for (ŷ, y) in zip(outputs, targets))
sum(i -> loss(outputs[i], targets[i]), eachindex(outputs, targets))
end

"""
Expand All @@ -46,7 +46,7 @@ The `weights` determine the importance of each observation. The option
`normalize` divides the result by the sum of the weights.
"""
function sum(loss::SupervisedLoss, outputs, targets, weights; normalize=true)
s = sum(w * loss(ŷ, y) for (ŷ, y, w) in zip(outputs, targets, weights))
s = sum(i -> weights[i] * loss(outputs[i], targets[i]), eachindex(outputs, targets, weights))
n = normalize ? sum(weights) : one(first(weights))
s / n
end
Expand All @@ -57,7 +57,7 @@ end
Return mean of `loss` values over the iterables `outputs` and `targets`.
"""
function mean(loss::SupervisedLoss, outputs, targets)
mean(loss(ŷ, y) for (ŷ, y) in zip(outputs, targets))
mean(i -> loss(outputs[i], targets[i]), eachindex(outputs, targets))
end

"""
Expand All @@ -68,7 +68,7 @@ The `weights` determine the importance of each observation. The option
`normalize` divides the result by the sum of the weights.
"""
function mean(loss::SupervisedLoss, outputs, targets, weights; normalize=true)
m = mean(w * loss(ŷ, y) for (ŷ, y, w) in zip(outputs, targets, weights))
m = mean(i -> weights[i] * loss(outputs[i], targets[i]), eachindex(outputs, targets, weights))
n = normalize ? sum(weights) : one(first(weights))
m / n
end

0 comments on commit a8f7f46

Please sign in to comment.