Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Homogenize interface for initializing Simulation, schedules, etc #3015

Merged
merged 9 commits into from
Mar 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ using Oceananigans.Grids: AbstractGrid

using DocStringExtensions

import Oceananigans: fields, prognostic_fields
import Oceananigans.Models: initialize_model!
import Oceananigans: fields, prognostic_fields, initialize!
import Oceananigans.Advection: cell_advection_timescale

abstract type AbstractFreeSurface{E, G} end
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
using Oceananigans: fields, prognostic_fields, TimeStepCallsite, TendencyCallsite, UpdateStateCallsite

using Oceananigans: fields, prognostic_fields, TendencyCallsite, UpdateStateCallsite
using Oceananigans.Architectures: device_event
using Oceananigans: fields, prognostic_fields, TimeStepCallsite, TendencyCallsite, UpdateStateCallsite
using Oceananigans.Utils: work_layout
using Oceananigans.Fields: immersed_boundary_condition
using Oceananigans.Biogeochemistry: update_tendencies!
Expand Down Expand Up @@ -36,7 +34,9 @@ function calculate_tendencies!(model::HydrostaticFreeSurfaceModel, callbacks)
model.closure,
model.buoyancy)

[callback(model) for callback in callbacks if isa(callback.callsite, TendencyCallsite)]
for callback in callbacks
callback.callsite isa TendencyCallsite && callback(model)
end

update_tendencies!(model.biogeochemistry, model)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,3 +225,4 @@ validate_momentum_advection(momentum_advection::Union{VectorInvariant, Nothing},

initialize_model!(model::HydrostaticFreeSurfaceModel) = initialize_free_surface!(model.free_surface, model.grid, model.velocities)
initialize_free_surface!(free_surface, grid, velocities) = nothing

Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using KernelAbstractions: NoneEvent
using CUDA: @allowscalar

using Oceananigans: UpdateStateCallsite
using Oceananigans.Grids: Flat, Bounded
using Oceananigans.Coriolis: AbstractRotation
using Oceananigans.TurbulenceClosures: AbstractTurbulenceClosure
Expand Down Expand Up @@ -52,7 +53,9 @@ function update_state!(model::HydrostaticFreeSurfaceModel, grid::SingleColumnGri

fill_halo_regions!(model.diffusivity_fields, model.clock, fields(model))

[callback(model) for callback in callbacks if isa(callback.callsite, UpdateStateCallsite)]
for callback in callbacks
callback.callsite isa UpdateStateCallsite && callback(model)
end

return nothing
end
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using Oceananigans.Architectures
using Oceananigans.Architectures: device_event
using Oceananigans.BoundaryConditions

using Oceananigans: UpdateStateCallsite
using Oceananigans.Architectures: device_event
using Oceananigans.Biogeochemistry: update_biogeochemical_state!
using Oceananigans.TurbulenceClosures: calculate_diffusivities!
using Oceananigans.ImmersedBoundaries: mask_immersed_field!, mask_immersed_field_xy!, inactive_node
Expand Down Expand Up @@ -35,7 +37,9 @@ function update_state!(model::HydrostaticFreeSurfaceModel, grid, callbacks)
fill_halo_regions!(model.diffusivity_fields, model.clock, fields(model))
fill_halo_regions!(model.pressure.pHY′)

[callback(model) for callback in callbacks if isa(callback.callsite, UpdateStateCallsite)]
for callback in callbacks
callback.callsite isa UpdateStateCallsite && callback(model)
end

update_biogeochemical_state!(model.biogeochemistry, model)

Expand Down
3 changes: 2 additions & 1 deletion src/Models/Models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@ export
using Oceananigans: AbstractModel
using Oceananigans.Grids: halo_size, inflate_halo_size

import Oceananigans: initialize!
import Oceananigans.Architectures: device_event, architecture

device_event(model::AbstractModel) = device_event(model.architecture)
architecture(model::AbstractModel) = model.architecture

initialize_model!(model::AbstractModel) = nothing
initialize!(model::AbstractModel) = nothing

using Oceananigans: fields
import Oceananigans.TimeSteppers: reset!
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Oceananigans.Biogeochemistry: update_tendencies!
using Oceananigans: fields, TimeStepCallsite, TendencyCallsite, UpdateStateCallsite
using Oceananigans: fields, TendencyCallsite
using Oceananigans.Utils: work_layout

using Oceananigans.ImmersedBoundaries: use_only_active_cells, ActiveCellsIBG, active_linear_index_to_ntuple
Expand Down Expand Up @@ -35,7 +35,9 @@ function calculate_tendencies!(model::NonhydrostaticModel, callbacks)
model.clock,
fields(model))

[callback(model) for callback in callbacks if isa(callback.callsite, TendencyCallsite)]
for callback in callbacks
callback.callsite isa TendencyCallsite && callback(model)
end

update_tendencies!(model.biogeochemistry, model)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using Oceananigans: UpdateStateCallsite
using Oceananigans.Architectures
using Oceananigans.BoundaryConditions
using Oceananigans.Biogeochemistry: update_biogeochemical_state!
Expand Down Expand Up @@ -36,9 +37,7 @@ function update_state!(model::NonhydrostaticModel, callbacks=[])
fill_halo_regions!(model.pressures.pHY′)

for callback in callbacks
if callback.callsite isa UpdateStateCallsite
callback(model)
end
callback.callsite isa UpdateStateCallsite && callback(model)
end

update_biogeochemical_state!(model.biogeochemistry, model)
Expand Down
2 changes: 2 additions & 0 deletions src/Oceananigans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ Abstract supertype for output writers that write data to disk.
"""
abstract type AbstractOutputWriter end

# Callsites for Callbacks
struct TimeStepCallsite end
struct TendencyCallsite end
struct UpdateStateCallsite end
Expand All @@ -178,6 +179,7 @@ struct UpdateStateCallsite end

function run_diagnostic! end
function write_output! end
function initialize! end # for initializing models, simulations, etc
function location end
function instantiated_location end
function tupleit end
Expand Down
8 changes: 8 additions & 0 deletions src/Simulations/callback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ using Oceananigans.Utils: prettysummary
using Oceananigans.OutputWriters: WindowedTimeAverage, advance_time_average!
using Oceananigans: TimeStepCallsite, TendencyCallsite, UpdateStateCallsite

import Oceananigans: initialize!

struct Callback{P, F, S, CS}
func :: F
schedule :: S
Expand All @@ -12,6 +14,12 @@ end
@inline (callback::Callback)(sim) = callback.func(sim, callback.parameters)
@inline (callback::Callback{<:Nothing})(sim) = callback.func(sim)

# Fallback initialization: call the schedule, then the callback
function initialize!(callback::Callback, sim)
initialize!(callback.schedule, sim.model) && callback(sim)
return nothing
end

"""
Callback(func, schedule=IterationInterval(1); parameters=nothing)

Expand Down
37 changes: 22 additions & 15 deletions src/Simulations/run.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ using Oceananigans.TimeSteppers: QuasiAdamsBashforth2TimeStepper, RungeKutta3Tim

using Oceananigans: AbstractModel, run_diagnostic!, write_output!

using Oceananigans.Models: initialize_model!

import Oceananigans: initialize!
import Oceananigans.OutputWriters: checkpoint_path, set!
import Oceananigans.TimeSteppers: time_step!
import Oceananigans.Utils: aligned_time_step
Expand Down Expand Up @@ -103,14 +102,17 @@ function run!(sim; pickup=false)
return nothing
end

const ModelCallsite = Union{TendencyCallsite, UpdateStateCallsite}

""" Step `sim`ulation forward by one time step. """
function time_step!(sim::Simulation)

start_time_step = time_ns()
model_callbacks = Tuple(cb for cb in values(sim.callbacks) if cb isa ModelCallsite)

if !(sim.initialized) # execute initialization step
initialize_simulation!(sim)
initialize_model!(sim.model)
initialize!(sim)
initialize!(sim.model)

if sim.running # check that initialization didn't stop time-stepping
if sim.verbose
Expand All @@ -119,7 +121,7 @@ function time_step!(sim::Simulation)
end

Δt = aligned_time_step(sim, sim.Δt)
time_step!(sim.model, Δt; callbacks=[callback for callback in values(sim.callbacks) if !isa(callback.callsite, TimeStepCallsite)])
time_step!(sim.model, Δt, callbacks=model_callbacks)

if sim.verbose
elapsed_initial_step_time = prettytime(1e-9 * (time_ns() - start_time))
Expand All @@ -131,13 +133,21 @@ function time_step!(sim::Simulation)

else # business as usual...
Δt = aligned_time_step(sim, sim.Δt)
time_step!(sim.model, Δt; callbacks=[callback for callback in values(sim.callbacks) if !isa(callback.callsite, TimeStepCallsite)])
time_step!(sim.model, Δt, callbacks=model_callbacks)
end

# Callbacks and callback-like things
[diag.schedule(sim.model) && run_diagnostic!(diag, sim.model) for diag in values(sim.diagnostics)]
[callback.schedule(sim.model) && callback(sim) for callback in values(sim.callbacks) if isa(callback.callsite, TimeStepCallsite)]
[writer.schedule(sim.model) && write_output!(writer, sim.model) for writer in values(sim.output_writers)]
for diag in values(sim.diagnostics)
diag.schedule(sim.model) && run_diagnostic!(diag, sim.model)
end

for callback in values(sim.callbacks)
callback.callsite isa TimeStepCallsite && callback.schedule(sim.model) && callback(sim)
end

for writer in values(sim.output_writers)
writer.schedule(sim.model) && write_output!(writer, sim.model)
end

end_time_step = time_ns()

Expand All @@ -163,15 +173,15 @@ we_want_to_pickup(pickup::String) = true
we_want_to_pickup(pickup) = throw(ArgumentError("Cannot run! with pickup=$pickup"))

"""
initialize_simulation!(sim, pickup=false)
initialize!(sim::Simulation, pickup=false)

Initialize a simulation:

- Update the auxiliary state of the simulation (filling halo regions, computing auxiliary fields)
- Evaluate all diagnostics, callbacks, and output writers if sim.model.clock.iteration == 0
- Add diagnostics that "depend" on output writers
"""
function initialize_simulation!(sim)
function initialize!(sim::Simulation)
if sim.verbose
@info "Initializing simulation..."
start_time = time_ns()
Expand All @@ -196,10 +206,7 @@ function initialize_simulation!(sim)
end

for callback in values(sim.callbacks)
if isa(callback.callsite, TimeStepCallsite)
callback.schedule(model)
callback(sim)
end
callback.callsite isa TimeStepCallsite && initialize!(callback, sim)
end

for writer in values(sim.output_writers)
Expand Down
15 changes: 15 additions & 0 deletions src/Utils/schedules.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import Oceananigans: initialize!

"""
AbstractSchedule

Expand All @@ -10,6 +12,17 @@ abstract type AbstractSchedule end
# Default behavior is no alignment.
aligned_time_step(schedule, clock, Δt) = Δt

# Fallback initialization for schedule: call the schedule,
# then return `true`, indicating that the schedule "actuates" at
# initial call.
function initialize!(schedule::AbstractSchedule, model)
schedule(model)

# `return true` indicates that the schedule
# "actuates" at initial call.
return true
end

#####
##### TimeInterval
#####
Expand Down Expand Up @@ -155,6 +168,8 @@ function (st::SpecifiedTimes)(model)
return false
end

initialize!(st::SpecifiedTimes, model) = st(model)

align_time_step(schedule::SpecifiedTimes, clock, Δt) = min(Δt, next_appointment_time(schedule) - clock.time)

function specified_times_str(st)
Expand Down
9 changes: 9 additions & 0 deletions test/test_schedules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@ include("dependencies_for_runtests.jl")

using Oceananigans.Utils: TimeInterval, IterationInterval, WallTimeInterval, SpecifiedTimes
using Oceananigans.TimeSteppers: Clock
using Oceananigans: initialize!

@testset "Schedules" begin
@info "Testing schedules..."

# Some fake models
fake_model_at_iter_0 = (; clock=Clock(time=0.0, iteration=0))
fake_model_at_iter_3 = (; clock=Clock(time=1.0, iteration=3))
fake_model_at_iter_5 = (; clock=Clock(time=2.0, iteration=5))

Expand All @@ -20,12 +22,14 @@ using Oceananigans.TimeSteppers: Clock
@test ti.interval == 2.0
@test ti(fake_model_at_time_2)
@test !(ti(fake_model_at_time_3))
@test initialize!(ti, fake_model_at_iter_0)

# IterationInterval
ii = IterationInterval(3)

@test !(ii(fake_model_at_iter_5))
@test ii(fake_model_at_iter_3)
@test initialize!(ii, fake_model_at_iter_0)

# OrSchedule
ti_and_ii = AndSchedule(TimeInterval(2), IterationInterval(3))
Expand Down Expand Up @@ -53,6 +57,7 @@ using Oceananigans.TimeSteppers: Clock
st_vector = SpecifiedTimes([2, 5, 6])
@test st_list.times == st_vector.times
@test st.times == [2.0, 5.0, 6.0]
@test !(initialize!(st, fake_model_at_iter_0))

# Times are sorted
st = SpecifiedTimes(5, 2, 6)
Expand All @@ -62,4 +67,8 @@ using Oceananigans.TimeSteppers: Clock

@test !(st(fake_model_at_time_4))
@test st(fake_model_at_time_5)

# Specified times includes iteration 0
st = SpecifiedTimes(0, 2, 4)
@test initialize!(st, fake_model_at_iter_0)
end