Skip to content

Commit

Permalink
Simplify parameter types for RTE shortwave solver (#428)
Browse files Browse the repository at this point in the history
  • Loading branch information
sriharshakandala committed Jan 19, 2024
1 parent ac026f1 commit 9f53b9d
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 72 deletions.
73 changes: 34 additions & 39 deletions src/rte/shortwave1scalar.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
function rte_sw_noscat_solve!(
device::ClimaComms.AbstractCPUDevice,
flux_sw::FluxSW{FT},
flux_sw::FluxSW,
op::OneScalar,
bcs_sw::SwBCs{FT},
bcs_sw::SwBCs,
max_threads,
as::GrayAtmosphericState{FT},
) where {FT <: AbstractFloat}
as::GrayAtmosphericState,
)
(; nlay, ncol) = as
nlev = nlay + 1
n_gpt, igpt = 1, 1
FT = eltype(op.angle_disc.gauss_Ds)
solar_frac = FT(1)
@inbounds begin
ClimaComms.@threaded device for gcol in 1:ncol
Expand All @@ -22,12 +23,12 @@ end

function rte_sw_noscat_solve!(
device::ClimaComms.CUDADevice,
flux_sw::FluxSW{FT},
flux_sw::FluxSW,
op::OneScalar,
bcs_sw::SwBCs{FT},
bcs_sw::SwBCs,
max_threads,
as::GrayAtmosphericState{FT},
) where {FT <: AbstractFloat}
as::GrayAtmosphericState,
)
(; nlay, ncol) = as
nlev = nlay + 1
tx = min(ncol, max_threads)
Expand All @@ -37,17 +38,11 @@ function rte_sw_noscat_solve!(
return nothing
end

function rte_sw_noscat_solve_CUDA!(
flux_sw::FluxSW{FT},
op::OneScalar,
bcs_sw::SwBCs{FT},
nlay,
ncol,
as::GrayAtmosphericState{FT},
) where {FT <: AbstractFloat}
function rte_sw_noscat_solve_CUDA!(flux_sw::FluxSW, op::OneScalar, bcs_sw::SwBCs, nlay, ncol, as::GrayAtmosphericState)
gcol = threadIdx().x + (blockIdx().x - 1) * blockDim().x # global id
nlev = nlay + 1
n_gpt, igpt = 1, 1
FT = eltype(op.angle_disc.gauss_Ds)
solar_frac = FT(1)
# setting references for flux_sw
if gcol ncol
Expand All @@ -62,15 +57,15 @@ end

function rte_sw_noscat_solve!(
device::ClimaComms.AbstractCPUDevice,
flux::FluxSW{FT},
flux_sw::FluxSW{FT},
flux::FluxSW,
flux_sw::FluxSW,
op::OneScalar,
bcs_sw::SwBCs{FT},
bcs_sw::SwBCs,
max_threads,
as::AtmosphericState{FT},
as::AtmosphericState,
lookup_sw::LookUpSW,
lookup_sw_cld::Union{LookUpCld, PadeCld, Nothing} = nothing,
) where {FT <: AbstractFloat}
)
(; nlay, ncol) = as
nlev = nlay + 1
n_gpt = length(lookup_sw.solar_src_scaled)
Expand All @@ -91,15 +86,15 @@ end

function rte_sw_noscat_solve!(
device::ClimaComms.CUDADevice,
flux::FluxSW{FT},
flux_sw::FluxSW{FT},
flux::FluxSW,
flux_sw::FluxSW,
op::OneScalar,
bcs_sw::SwBCs{FT},
bcs_sw::SwBCs,
max_threads,
as::AtmosphericState{FT},
as::AtmosphericState,
lookup_sw::LookUpSW,
lookup_sw_cld::Union{LookUpCld, PadeCld, Nothing} = nothing,
) where {FT <: AbstractFloat}
)
(; nlay, ncol) = as
nlev = nlay + 1
# setting references for flux_sw
Expand All @@ -111,16 +106,16 @@ function rte_sw_noscat_solve!(
end

function rte_sw_noscat_solve_CUDA!(
flux::FluxSW{FT},
flux_sw::FluxSW{FT},
flux::FluxSW,
flux_sw::FluxSW,
op::OneScalar,
bcs_sw::SwBCs{FT},
bcs_sw::SwBCs,
nlay,
ncol,
as::AtmosphericState{FT},
as::AtmosphericState,
lookup_sw::LookUpSW,
lookup_sw_cld::Union{LookUpCld, PadeCld, Nothing} = nothing,
) where {FT <: AbstractFloat}
)
gcol = threadIdx().x + (blockIdx().x - 1) * blockDim().x # global id
nlev = nlay + 1
n_gpt = length(lookup_sw.solar_src_scaled)
Expand All @@ -139,27 +134,27 @@ end

"""
rte_sw_noscat!(
flux::FluxSW{FT},
flux::FluxSW,
op::OneScalar,
bcs_sw::SwBCs{FT},
solar_frac::FT,
bcs_sw::SwBCs,
solar_frac::AbstractFloat,
gcol,
nlev,
) where {FT<:AbstractFloat}
)
No-scattering solver for the shortwave problem.
(Extinction-only i.e. solar direct beam)
"""
function rte_sw_noscat!(
flux::FluxSW{FT},
flux::FluxSW,
op::OneScalar,
bcs_sw::SwBCs{FT},
bcs_sw::SwBCs,
igpt::Int,
n_gpt::Int,
solar_frac::FT,
solar_frac::AbstractFloat,
gcol::Int,
nlev::Int,
) where {FT <: AbstractFloat}
)
(; toa_flux, cos_zenith) = bcs_sw
τ = op.τ
(; flux_dn_dir, flux_net) = flux
Expand Down
69 changes: 36 additions & 33 deletions src/rte/shortwave2stream.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
function rte_sw_2stream_solve!(
device::ClimaComms.AbstractCPUDevice,
flux_sw::FluxSW{FT},
flux_sw::FluxSW,
op::TwoStream,
bcs_sw::SwBCs{FT},
src_sw::SourceSW2Str{FT},
bcs_sw::SwBCs,
src_sw::SourceSW2Str,
max_threads,
as::GrayAtmosphericState{FT},
) where {FT <: AbstractFloat}
as::GrayAtmosphericState,
)
(; nlay, ncol) = as
nlev = nlay + 1
n_gpt, igpt, ibnd = 1, 1, UInt8(1)
FT = eltype(bcs_sw.cos_zenith)
solar_frac = FT(1)
@inbounds begin
ClimaComms.@threaded device for gcol in 1:ncol
Expand All @@ -24,13 +25,13 @@ end

function rte_sw_2stream_solve!(
device::ClimaComms.CUDADevice,
flux_sw::FluxSW{FT},
flux_sw::FluxSW,
op::TwoStream,
bcs_sw::SwBCs{FT},
src_sw::SourceSW2Str{FT},
bcs_sw::SwBCs,
src_sw::SourceSW2Str,
max_threads,
as::GrayAtmosphericState{FT},
) where {FT <: AbstractFloat}
as::GrayAtmosphericState,
)
(; nlay, ncol) = as
nlev = nlay + 1
tx = min(ncol, max_threads)
Expand All @@ -41,17 +42,18 @@ function rte_sw_2stream_solve!(
end

function rte_sw_2stream_solve_CUDA!(
flux_sw::FluxSW{FT},
flux_sw::FluxSW,
op::TwoStream,
bcs_sw::SwBCs{FT},
src_sw::SourceSW2Str{FT},
bcs_sw::SwBCs,
src_sw::SourceSW2Str,
nlay,
ncol,
as::GrayAtmosphericState{FT},
) where {FT <: AbstractFloat}
as::GrayAtmosphericState,
)
gcol = threadIdx().x + (blockIdx().x - 1) * blockDim().x # global id
nlev = nlay + 1
n_gpt, igpt, ibnd = 1, 1, UInt8(1)
FT = eltype(bcs_sw.cos_zenith)
solar_frac = FT(1)
if gcol ncol
@inbounds begin
Expand All @@ -66,16 +68,16 @@ end

function rte_sw_2stream_solve!(
device::ClimaComms.AbstractCPUDevice,
flux::FluxSW{FT},
flux_sw::FluxSW{FT},
flux::FluxSW,
flux_sw::FluxSW,
op::TwoStream,
bcs_sw::SwBCs{FT},
src_sw::SourceSW2Str{FT},
bcs_sw::SwBCs,
src_sw::SourceSW2Str,
max_threads,
as::AtmosphericState{FT},
as::AtmosphericState,
lookup_sw::LookUpSW,
lookup_sw_cld::Union{LookUpCld, PadeCld, Nothing} = nothing,
) where {FT <: AbstractFloat}
)
(; nlay, ncol) = as
nlev = nlay + 1
n_gpt = length(lookup_sw.solar_src_scaled)
Expand Down Expand Up @@ -104,16 +106,16 @@ end

function rte_sw_2stream_solve!(
device::ClimaComms.CUDADevice,
flux::FluxSW{FT},
flux_sw::FluxSW{FT},
flux::FluxSW,
flux_sw::FluxSW,
op::TwoStream,
bcs_sw::SwBCs{FT},
src_sw::SourceSW2Str{FT},
bcs_sw::SwBCs,
src_sw::SourceSW2Str,
max_threads,
as::AtmosphericState{FT},
as::AtmosphericState,
lookup_sw::LookUpSW,
lookup_sw_cld::Union{LookUpCld, PadeCld, Nothing} = nothing,
) where {FT <: AbstractFloat}
)
(; nlay, ncol) = as
nlev = nlay + 1
n_gpt = length(lookup_sw.solar_src_scaled)
Expand All @@ -125,17 +127,17 @@ function rte_sw_2stream_solve!(
end

function rte_sw_2stream_solve_CUDA!(
flux::FluxSW{FT},
flux_sw::FluxSW{FT},
flux::FluxSW,
flux_sw::FluxSW,
op::TwoStream,
bcs_sw::SwBCs{FT},
src_sw::SourceSW2Str{FT},
bcs_sw::SwBCs,
src_sw::SourceSW2Str,
nlay,
ncol,
as::AtmosphericState{FT},
as::AtmosphericState,
lookup_sw::LookUpSW,
lookup_sw_cld::Union{LookUpCld, PadeCld, Nothing} = nothing,
) where {FT <: AbstractFloat}
)
gcol = threadIdx().x + (blockIdx().x - 1) * blockDim().x # global id
nlev = nlay + 1
n_gpt = length(lookup_sw.major_gpt2bnd)
Expand All @@ -145,6 +147,7 @@ function rte_sw_2stream_solve_CUDA!(
flux_net_sw = flux_sw.flux_net
flux_up = flux.flux_up
flux_dn = flux.flux_dn
FT = eltype(flux_up)
@inbounds for ilev in 1:nlev
flux_up_sw[ilev, gcol] = FT(0)
flux_dn_sw[ilev, gcol] = FT(0)
Expand Down

0 comments on commit 9f53b9d

Please sign in to comment.