Skip to content

Commit

Permalink
Add device function for FiniteDifferenceSpace
Browse files Browse the repository at this point in the history
  • Loading branch information
sriharshakandala committed Mar 3, 2023
1 parent 95e8fbf commit 1349ea1
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 23 deletions.
14 changes: 9 additions & 5 deletions src/Fields/mapreduce.jl
Expand Up @@ -19,8 +19,7 @@ local_sum(
todata(field),
),
)
local_sum(field::Field) = local_sum(field, Device.device(field))
local_sum(field::Base.Broadcast.Broadcasted{<:FieldStyle}) =
local_sum(field::Union{Field, Base.Broadcast.Broadcasted{<:FieldStyle}}) =
local_sum(field, Device.device(axes(field)))
"""
sum([f=identity,]v::Field)
Expand Down Expand Up @@ -51,8 +50,7 @@ function Base.sum(
end
Base.sum(fn, field::Field, ::ClimaComms.CPU) =
Base.sum(Base.Broadcast.broadcasted(fn, field))
Base.sum(field::Field) = Base.sum(field, Device.device(field))
Base.sum(field::Base.Broadcast.Broadcasted{<:FieldStyle}) =
Base.sum(field::Union{Field, Base.Broadcast.Broadcasted{<:FieldStyle}}) =
Base.sum(field, Device.device(axes(field)))
Base.sum(fn, field::Field) = Base.sum(fn, field, Device.device(field))

Expand Down Expand Up @@ -106,6 +104,7 @@ If `v` is a distributed field, this uses a `ClimaComms.allreduce` operation.
"""
function Statistics.mean(
field::Union{Field, Base.Broadcast.Broadcasted{<:FieldStyle}},
::ClimaComms.CPU,
)
space = axes(field)
context = comm_context(space)
Expand All @@ -115,9 +114,14 @@ function Statistics.mean(
sum_v, area_v = data_combined[]
RecursiveApply.rdiv(sum_v, area_v)
end
Statistics.mean(fn, field::Field) =
Statistics.mean(fn, field::Field, ::ClimaComms.CPU) =
Statistics.mean(Base.Broadcast.broadcasted(fn, field))

Statistics.mean(field::Union{Field, Base.Broadcast.Broadcasted{<:FieldStyle}}) =
Statistics.mean(field, Device.device(axes(field)))
Statistics.mean(fn, field::Field) =
Statistics.mean(fn, field, Device.device(axes(field)))

"""
norm(v::Field, p=2; normalize=true)
Expand Down
36 changes: 18 additions & 18 deletions src/Fields/mapreduce_cuda.jl
@@ -1,31 +1,31 @@
local_sum(
field::Field{V},
::ClimaComms.CUDA,
) where {
Nij,
A <: AbstractArray,
V <:
Union{DataLayouts.IJFH{<:Any, Nij, A}, DataLayouts.VIJFH{<:Any, Nij, A}},
} = mapreduce_cuda(identity, +, field, weighting = true)

Base.sum(
field::Union{Field, Base.Broadcast.Broadcasted{<:FieldStyle}},
device::ClimaComms.CUDA,
) = local_sum(field, device) #TODO: distributed support to be added
::ClimaComms.CUDA,
) = mapreduce_cuda(identity, +, field, weighting = true) #TODO: distributed support to be added

Base.sum(fn, field::Field, ::ClimaComms.CUDA) =
mapreduce_cuda(fn, +, field, weighting = true)
#TODO: distributed support to be added
mapreduce_cuda(fn, +, field, weighting = true) #TODO: distributed support to be added

Base.maximum(fn, field::Field, ::ClimaComms.CUDA) =
mapreduce_cuda(fn, max, field)
mapreduce_cuda(fn, max, field) #TODO: distributed support to be added

Base.maximum(field::Field, ::ClimaComms.CUDA) =
mapreduce_cuda(identity, max, field)
mapreduce_cuda(identity, max, field) #TODO: distributed support to be added

Base.minimum(fn, field::Field, ::ClimaComms.CUDA) =
mapreduce_cuda(fn, min, field)
mapreduce_cuda(fn, min, field) #TODO: distributed support to be added

Base.minimum(field::Field, ::ClimaComms.CUDA) =
mapreduce_cuda(identity, min, field)
mapreduce_cuda(identity, min, field) #TODO: distributed support to be added

Statistics.mean(
field::Union{Field, Base.Broadcast.Broadcasted{<:FieldStyle}},
::ClimaComms.CUDA,
) = Base.sum(field) ./ Spaces.local_area(axes(field)) #TODO: distributed support to be added

Statistics.mean(fn, field::Field, ::ClimaComms.CUDA) =
Base.sum(fn, field) ./ Spaces.local_area(axes(field)) #TODO: distributed support to be added

function mapreduce_cuda(
f,
Expand Down Expand Up @@ -78,7 +78,7 @@ function mapreduce_cuda(
Val(shmemsize),
)
end
return tuple(Array(reduce_cuda)[1, :]...)
return Array(Array(reduce_cuda)[1, :])
end

function mapreduce_cuda_kernel!(
Expand Down
2 changes: 2 additions & 0 deletions src/Spaces/Spaces.jl
Expand Up @@ -69,5 +69,7 @@ If `space` is distributed, this uses a `ClimaComms.allreduce` operation.
area(space::Spaces.AbstractSpace) =
ClimaComms.allreduce(comm_context(space), local_area(space), +)

Device.device_array_type(space::AbstractSpace) =
Device.device_array_type(Device.device(space))

end # module
1 change: 1 addition & 0 deletions src/Spaces/finitedifference.jl
Expand Up @@ -125,6 +125,7 @@ end
FiniteDifferenceSpace{S}(mesh::Meshes.IntervalMesh) where {S <: Staggering} =
FiniteDifferenceSpace{S}(Topologies.IntervalTopology(mesh))

Device.device(space::FiniteDifferenceSpace) = ClimaComms.CPU()

const CenterFiniteDifferenceSpace = FiniteDifferenceSpace{CellCenter}
const FaceFiniteDifferenceSpace = FiniteDifferenceSpace{CellFace}
Expand Down

0 comments on commit 1349ea1

Please sign in to comment.