In [None]:
using Pkg
# Activate the environment that has ProgressiveHedging installed (the project located
# in the parent directory), then print its package status so the versions in use are
# visible in the notebook output.
Pkg.activate("..")
Pkg.status()

In [None]:
using Distributed

# Number of worker processes the notebook should run with.
const WORKERS = 1 # Change to > 1 to use parallel

if nworkers() < WORKERS
    # When no worker has been added yet, the master process is reported as the only
    # "worker" (nprocs() == nworkers()), so a full set of WORKERS processes is needed.
    # Otherwise only the missing ones are added.
    diff = (nprocs() == nworkers() ? WORKERS : WORKERS - nworkers())
    println("Adding $diff worker processes.")
    Distributed.addprocs(diff)
    # Make sure these workers also have an environment with PH installed.
    @everywhere using Pkg
    # Activate synchronously: `@everywhere` waits for every worker to finish and
    # rethrows any remote failure, so the package loads below can never run before
    # a worker has switched to the right environment.
    @everywhere workers() Pkg.activate("..")
end

# Load the packages on every process (master and workers).
@everywhere using ProgressiveHedging
@everywhere const PH = ProgressiveHedging

@everywhere using Ipopt
@everywhere using JuMP

### Model Construction Function Definition

In [None]:
# Build the JuMP subproblem for a single scenario of the three-stage example.
#
# Arguments
#   scenario_id    -- scenario being built; it selects the right-hand sides of the
#                     second- and third-stage constraints.
#   additional_arg -- example positional argument; it is only printed.
#   key_word_arg   -- example keyword argument; it is only printed.
#   kwargs...      -- any further keyword arguments are accepted and ignored.
#
# Returns a JuMPSubproblem built from the model, the scenario id and a dictionary
# mapping each stage id to the variables created for that stage.
@everywhere function create_model(scenario_id::PH.ScenarioID,
        additional_arg::String;
        key_word_arg::String="Default key word argument", kwargs...
        )

    model = JuMP.Model(()->Ipopt.Optimizer())
    JuMP.set_optimizer_attribute(model, "print_level", 0)
    JuMP.set_optimizer_attribute(model, "tol", 1e-12)

    # Demonstrate that the extra arguments reach the model constructor.
    println(additional_arg)
    println(key_word_arg)

    # Objective coefficients: linear first- and second-stage costs, quadratic third stage.
    stage1_cost = [1.0, 10.0, 0.01]
    stage2_cost = 7.0
    stage3_cost = 16.0

    # Constraint coefficients.
    α = β = γ = δ = ϵ = 1.0

    total_cost = zero(JuMP.GenericQuadExpr{Float64,JuMP.VariableRef})

    # First stage: three non-negative variables, the last one capped at 1.
    JuMP.@variable(model, x[1:3] >= 0.0)
    JuMP.@constraint(model, x[3] <= 1.0)
    JuMP.add_to_expression!(total_cost, sum(stage1_cost.*x))

    # Second stage: only the right-hand side depends on the scenario.
    stage2_rhs = scenario_id < PH.scid(2) ? 8.0 : 4.0
    JuMP.@variable(model, y >= 0.0)
    JuMP.@constraint(model, α*sum(x) + β*y >= stage2_rhs)
    JuMP.add_to_expression!(total_cost, stage2_cost*y)

    # Third stage: again only the right-hand side is scenario specific.
    stage3_rhs = if scenario_id == PH.scid(0)
        9.0
    elseif scenario_id == PH.scid(1)
        16.0
    elseif scenario_id == PH.scid(2)
        5.0
    else
        18.0
    end
    JuMP.@variable(model, z[1:2])
    JuMP.@constraint(model, ϵ*sum(x) + γ*y + δ*sum(z) == stage3_rhs)
    JuMP.add_to_expression!(total_cost, stage3_cost*sum(z[i]^2 for i in 1:2))

    JuMP.@objective(model, Min, total_cost)

    # Variables grouped by the stage in which they are decided.
    vdict = Dict{PH.StageID, Vector{JuMP.VariableRef}}(
        PH.stid(1) => x,
        PH.stid(2) => JuMP.VariableRef[y],
        PH.stid(3) => copy(z),
    )

    return JuMPSubproblem(model, scenario_id, vdict)
end
;

### Build Scenario Tree
Note that leaf nodes must be added with `add_leaf` rather than `add_node`. Failing to do so results in undefined behavior.

In [None]:
"""
    build_scen_tree()

Build the scenario tree for the three-stage example: two second-stage nodes hang off
the root and each of them carries two leaves, giving four scenarios with probabilities
0.375, 0.125, 0.375 and 0.125.
"""
function build_scen_tree()
    tree = PH.ScenarioTree()

    # Leaf probabilities, listed branch by branch (two leaves per second-stage node).
    leaf_probs = [0.5*0.75, 0.5*0.25, 0.5*0.75, 0.5*0.25]

    for branch_probs in Iterators.partition(leaf_probs, 2)
        parent = PH.add_node(tree, tree.root)
        for p in branch_probs
            # Leaves must be added with add_leaf, not add_node.
            PH.add_leaf(tree, parent, p)
        end
    end

    return tree
end
;

### Extensive Form Solve

In [None]:
# Solve the extensive form of the problem with Ipopt and time the call.
# The string is the extra positional argument that `create_model` prints;
# `opt_args` passes print_level=0 for the optimizer.
ef_model = @time PH.solve_extensive(
    build_scen_tree(),
    create_model,
    () -> Ipopt.Optimizer(),
    "Unused example string";
    opt_args=(print_level=0,),
)
println(ef_model)

In [None]:
# Report the solver status, the optimal objective and the value of every variable
# of the extensive-form model.
println(JuMP.termination_status(ef_model))
println(JuMP.objective_value(ef_model))
foreach(JuMP.all_variables(ef_model)) do var
    println("$var = $(JuMP.value(var))")
end

### PH Solve

In [None]:
# Solve the same problem with progressive hedging using a scalar penalty of 25.0.
# The string is the extra positional argument that `create_model` prints, and
# `key_word_arg` overrides the default keyword value it prints.
st = build_scen_tree()
n, err, rerr, obj, soln, phd = PH.solve(
    st,
    create_model,
    PH.ScalarPenaltyParameter(25.0),
    "Passed to constructor.";
    atol=1e-8,
    rtol=1e-12,
    max_iter=500,
    report=0,
    key_word_arg="We changed this key word arg!",
)
println("Number of iterations: ", n)
println("L^2 error: ", err)
println(obj)

In [None]:
# Fetch the augmented objective value from the PH data (`phd`) returned by the PH
# solve and compare it with the objective `obj` reported by that solve.
# NOTE(review): presumably the augmented objective adds the PH penalty/multiplier
# terms, so the difference should be close to zero at convergence -- confirm against
# the ProgressiveHedging documentation.
aobj = PH.retrieve_aug_obj_value(phd)
println("Augmented Objective: ", aobj)
println("Difference: ", aobj - obj)