In [None]:
using Random, NLopt, LinearAlgebra, Interpolations, Base.Threads, DataFrames
using Plots, Statistics, ProgressMeter, ForwardDiff, Distributions, Profile
using DataFrames, Measures, StatsBase, LaTeXStrings, Printf, ForwardDiff, Logging
using StaticArrays: SVector

In [None]:
"""
    get_child_value_interp(child_model::ConSavLabor)

Run `solve_model_work!` on `child_model` (mutating it), then return a gridded linear
interpolant of `child_model.sol_v[1, :, :]` over `(child_model.a_grid, child_model.k_grid)`,
extended linearly (`Line()`) outside the grid.
"""
function get_child_value_interp(child_model::ConSavLabor)
    solve_model_work!(child_model)  # presumably fills child_model.sol_v — confirm
    knots = (child_model.a_grid, child_model.k_grid)
    first_period_values = child_model.sol_v[1, :, :]
    gridded = interpolate(knots, first_period_values, Gridded(Linear()))
    return extrapolate(gridded, Line())
end


"""
    solve_model!(model::Parent_model, V_child_interp)

Solve the parent's life-cycle problem by backward induction. Policies and values are
written into `model.sol_c`, `sol_h`, `sol_i`, `sol_e`, `sol_t`, `sol_b` and `sol_v`, all
indexed `[t, i_a, i_k, i_hc]` over `model.a_grid`, `model.k_grid` and `model.hc_grid`.

Stages, solved from the last period backwards:
1. `t = T`: final period, parent only — choose hours `h`.
2. `t = T-1, …, T_terminal+1`: post-separation, parent only — choose `(c, h)`.
3. `t = T_terminal`: separation period — choose `(c_p, i_c, e_p, h_p, t_p, b)`; the
   transfer `b` is valued through `V_child_interp`.
4. `t = T_terminal-1, …, 8`: interactive periods — choose `(c_p, i_c, e_p, h_p, t_p)`.
5. `t = 7, …, 1`: parent-only periods — choose `(c_p, e_p, h_p, t_p)`.

Every grid point is solved with NLopt's `:LD_SLSQP`; NLopt return codes are not checked.
Returns `nothing`.
"""
function solve_model!(model::Parent_model, V_child_interp)
    T, T_terminal = model.T, model.T_terminal
    _solve_final_period!(model)
    for t in (T - 1):-1:(T_terminal + 1)
        _solve_post_separation_period!(model, t)
    end
    _solve_separation_period!(model, V_child_interp)
    for t in (T_terminal - 1):-1:8
        _solve_interactive_period!(model, t)
    end
    for t in 7:-1:1
        _solve_parent_only_period!(model, t)
    end
    return nothing
end

# Project a warm-start guess onto the box [lb, ub]. NLopt rejects starting points that
# violate the bound constraints, and guesses copied from period t+1 (whose problem can have
# different bounds, e.g. unconstrained final-period consumption vs. the c ≤ 30 cap) or the
# `0.5 * assets` transfer guess are not guaranteed to lie inside them.
_clamp_init(x0, lb, ub) = clamp.(x0, lb, ub)

# Copy the i_hc = 1 solution of period t into every human-capital slice. Periods after
# separation are solved only at i_hc = 1 because child human capital does not enter them;
# filling the other slices keeps the 3-D (a, k, log HC) interpolant built for the
# separation period from reading unsolved entries.
function _fill_hc_slices!(model::Parent_model, t::Integer)
    for sol in (model.sol_c, model.sol_h, model.sol_v, model.sol_i,
                model.sol_e, model.sol_t, model.sol_b)
        sol[t, :, :, 2:end] .= sol[t, :, :, 1]
    end
    return nothing
end

# Stage 1 (t = T): the parent only chooses hours; consumption is the budget residual.
function _solve_final_period!(model::Parent_model)
    T = model.T
    println("Solving final period $T ...")
    i_hc = 1  # Human capital fixed as not relevant
    for i_a in 1:model.Na, i_k in 1:model.Nk
        assets = model.a_grid[i_a]
        capital = model.k_grid[i_k]
        function obj_wrapper(h_vec::Vector, grad::Vector)
            f = obj_last_period(model, h_vec, assets, capital, T, grad)
            if length(grad) > 0
                grad[:] = -grad[:]  # NLopt minimizes, so negate the gradient too
            end
            return -f  # Minimize negative utility
        end
        opt = Opt(:LD_SLSQP, 1)
        lower_bounds!(opt, [0.0])
        upper_bounds!(opt, [1.0])
        ftol_rel!(opt, 1e-8)
        min_objective!(opt, obj_wrapper)
        (minf, h_opt, ret) = optimize(opt, [0.3])
        h = h_opt[1]
        model.sol_h[T, i_a, i_k, i_hc] = h
        model.sol_c[T, i_a, i_k, i_hc] = (1.0 + model.r) * assets + wage_func(model, capital, T) * h + model.y
        model.sol_v[T, i_a, i_k, i_hc] = -minf
        model.sol_i[T, i_a, i_k, i_hc] = 0.0  # No investment in parent-only periods
        model.sol_e[T, i_a, i_k, i_hc] = 0.0  # No education in parent-only periods
        model.sol_t[T, i_a, i_k, i_hc] = 0.0  # No time spent in parent-only periods
        model.sol_b[T, i_a, i_k, i_hc] = 0.0  # No altruistic transfer in parent-only periods
    end
    _fill_hc_slices!(model, T)
    return nothing
end

# Stage 2 (t = T-1, …, T_terminal+1): parent only, chooses (c, h) given V_{t+1}(a, k).
function _solve_post_separation_period!(model::Parent_model, t::Integer)
    println("Solving period $t ... (parent only)")
    # Post-separation values live in the i_hc = 1 slice; hand create_interp2 that
    # (t, a, k) slab so it interpolates over (a_grid, k_grid) only.
    interp = create_interp2(model, model.sol_v[:, :, :, 1], t + 1)
    lb, ub = [0.01, 0.0], [30.0, 1.0]
    i_hc = 1
    for i_a in 1:model.Na, i_k in 1:model.Nk
        assets = model.a_grid[i_a]
        capital = model.k_grid[i_k]
        function obj_wrapper(x::Vector, grad::Vector)
            f = obj_work_period(model, x, assets, capital, t, interp, grad)
            if length(grad) > 0
                grad[:] = -grad[:]  # Negate for minimization
            end
            return -f  # Minimize negative value function
        end
        # NOTE(review): `asset_constraint` is not defined in this section of the file.
        function constraint_wrapper(x::Vector, grad::Vector)
            return asset_constraint(x, grad, model, assets, capital, t)
        end
        opt = Opt(:LD_SLSQP, 2)
        lower_bounds!(opt, lb)
        upper_bounds!(opt, ub)
        ftol_rel!(opt, 1e-8)
        min_objective!(opt, obj_wrapper)
        inequality_constraint!(opt, constraint_wrapper, 0.0)
        init = _clamp_init([model.sol_c[t+1, i_a, i_k, i_hc], model.sol_h[t+1, i_a, i_k, i_hc]], lb, ub)
        (minf, x_opt, ret) = optimize(opt, init)
        model.sol_c[t, i_a, i_k, i_hc] = x_opt[1]
        model.sol_h[t, i_a, i_k, i_hc] = x_opt[2]
        model.sol_v[t, i_a, i_k, i_hc] = -minf
        model.sol_i[t, i_a, i_k, i_hc] = 0.0  # No investment in parent-only periods
        model.sol_e[t, i_a, i_k, i_hc] = 0.0  # No education in parent-only periods
        model.sol_t[t, i_a, i_k, i_hc] = 0.0  # No time spent in parent-only periods
        model.sol_b[t, i_a, i_k, i_hc] = 0.0  # No altruistic transfer in parent-only periods
    end
    _fill_hc_slices!(model, t)
    return nothing
end

# Stage 3 (t = T_terminal): full choice vector plus the altruistic transfer b.
function _solve_separation_period!(model::Parent_model, V_child_interp)
    t = model.T_terminal
    println("Solving separation period $t with altruistic transfer...")
    interp = create_interp(model, model.sol_v, t + 1)
    lb = [0.01, 1e-6, 0.01, 1e-6, 1e-6, 0.0]
    ub = [30.0, 1.0, 30.0, 1.0, 1.0, 50.0]
    for i_a in 1:model.Na, i_k in 1:model.Nk, i_hc in 1:model.Nhc
        assets = model.a_grid[i_a]
        capital = model.k_grid[i_k]
        log_HC = model.hc_grid[i_hc]
        function obj_wrapper(x::Vector, grad::Vector)
            c_p, i_c, e_p, h_p, t_p, b = x[1], x[2], x[3], x[4], x[5], x[6]
            f = obj_terminal_period(model, c_p, i_c, e_p, h_p, t_p, b, assets, log_HC,
                                    capital, t, grad, interp, V_child_interp)
            if length(grad) > 0
                grad[:] = -grad[:]  # NLopt minimizes, so negate gradient
            end
            return -f
        end
        opt = Opt(:LD_SLSQP, 6)
        lower_bounds!(opt, lb)
        upper_bounds!(opt, ub)
        min_objective!(opt, obj_wrapper)
        inequality_constraint!(opt, constraint_min_leisure, 1e-4)
        inequality_constraint!(opt, constraint_child_time, 1e-4)
        inequality_constraint!(opt, (x, grad) -> asset_constraint_terminal(x, grad, model, capital, t, assets), 1e-5)
        ftol_rel!(opt, 1e-8)
        maxeval!(opt, 5000)
        init = _clamp_init([3.0, 0.5, 2.0, 0.7, 0.1, 0.5 * assets], lb, ub)
        (minf, x_opt, ret) = optimize(opt, init)
        model.sol_c[t, i_a, i_k, i_hc] = x_opt[1]
        model.sol_i[t, i_a, i_k, i_hc] = x_opt[2]
        model.sol_e[t, i_a, i_k, i_hc] = x_opt[3]
        model.sol_h[t, i_a, i_k, i_hc] = x_opt[4]
        model.sol_t[t, i_a, i_k, i_hc] = x_opt[5]
        model.sol_b[t, i_a, i_k, i_hc] = x_opt[6]  # Store transfer
        model.sol_v[t, i_a, i_k, i_hc] = -minf
    end
    return nothing
end

# Stage 4 (t = T_terminal-1, …, 8): parent and child interact; no transfer.
function _solve_interactive_period!(model::Parent_model, t::Integer)
    interp = create_interp(model, model.sol_v, t + 1)
    lb = [0.01, 0.01, 0.01, 1e-6, 1e-6]
    ub = [30.0, 1.0, 30.0, 1.0, 1.0]
    for i_a in 1:model.Na, i_k in 1:model.Nk, i_hc in 1:model.Nhc
        assets = model.a_grid[i_a]
        capital = model.k_grid[i_k]
        log_HC = model.hc_grid[i_hc]
        function obj_wrapper(x::Vector, grad::Vector)
            c_p, i_c, e_p, h_p, t_p = x
            f = obj_interactive_part(model, c_p, i_c, e_p, h_p, t_p, assets, log_HC, capital, t, interp, grad)
            if length(grad) > 0
                grad[:] = -grad[:]
            end
            return -f
        end
        opt = Opt(:LD_SLSQP, 5)
        lower_bounds!(opt, lb)
        upper_bounds!(opt, ub)
        inequality_constraint!(opt, constraint_min_leisure, 1e-4)
        inequality_constraint!(opt, constraint_child_time, 1e-4)
        inequality_constraint!(opt, (x, grad) -> asset_constraint_full(x, grad, model, capital, t, assets), 1e-5)
        min_objective!(opt, obj_wrapper)
        init = _clamp_init([model.sol_c[t+1, i_a, i_k, i_hc], model.sol_i[t+1, i_a, i_k, i_hc],
                            model.sol_e[t+1, i_a, i_k, i_hc], model.sol_h[t+1, i_a, i_k, i_hc],
                            model.sol_t[t+1, i_a, i_k, i_hc]], lb, ub)
        ftol_rel!(opt, 1e-8)
        maxeval!(opt, 5000)
        (minf, x_opt, ret) = optimize(opt, init)
        model.sol_c[t, i_a, i_k, i_hc] = x_opt[1]
        model.sol_i[t, i_a, i_k, i_hc] = x_opt[2]
        model.sol_e[t, i_a, i_k, i_hc] = x_opt[3]
        model.sol_h[t, i_a, i_k, i_hc] = x_opt[4]
        model.sol_t[t, i_a, i_k, i_hc] = x_opt[5]
        model.sol_b[t, i_a, i_k, i_hc] = 0.0  # No altruistic transfer outside the separation period
        model.sol_v[t, i_a, i_k, i_hc] = -minf
    end
    return nothing
end

# Stage 5 (t = 7, …, 1): parent-only periods; child study time is not a choice.
function _solve_parent_only_period!(model::Parent_model, t::Integer)
    interp = create_interp(model, model.sol_v, t + 1)
    lb = [0.01, 0.01, 1e-6, 1e-6]
    ub = [30.0, 30.0, 1.0, 1.0]
    for i_a in 1:model.Na, i_k in 1:model.Nk, i_hc in 1:model.Nhc
        assets = model.a_grid[i_a]
        capital = model.k_grid[i_k]
        log_HC = model.hc_grid[i_hc]
        function obj_wrapper(x::Vector, grad::Vector)
            c_p, e_p, h_p, t_p = x
            f = obj_parent_only(model, c_p, e_p, h_p, t_p, assets, log_HC, capital, t, interp, grad)
            if length(grad) > 0
                grad[:] = -grad[:]
            end
            return -f
        end
        opt = Opt(:LD_SLSQP, 4)
        lower_bounds!(opt, lb)
        upper_bounds!(opt, ub)
        inequality_constraint!(opt, constraint_min_leisure, 1e-3)
        inequality_constraint!(opt, (x, grad) -> asset_constraint_parentonly(x, grad, model, capital, t, assets), 1e-5)
        min_objective!(opt, obj_wrapper)
        init = _clamp_init([model.sol_c[t+1, i_a, i_k, i_hc], model.sol_e[t+1, i_a, i_k, i_hc],
                            model.sol_h[t+1, i_a, i_k, i_hc], model.sol_t[t+1, i_a, i_k, i_hc]], lb, ub)
        ftol_rel!(opt, 1e-6)
        maxeval!(opt, 5000)
        (minf, x_opt, ret) = optimize(opt, init)
        model.sol_c[t, i_a, i_k, i_hc] = x_opt[1]
        model.sol_e[t, i_a, i_k, i_hc] = x_opt[2]
        model.sol_h[t, i_a, i_k, i_hc] = x_opt[3]
        model.sol_t[t, i_a, i_k, i_hc] = x_opt[4]
        model.sol_i[t, i_a, i_k, i_hc] = 0.0
        model.sol_b[t, i_a, i_k, i_hc] = 0.0  # No altruistic transfer in parent-only periods
        model.sol_v[t, i_a, i_k, i_hc] = -minf
    end
    return nothing
end

# ------------------------------------------------
# Supporting Functions
# ------------------------------------------------

"""
    obj_last_period(model, h_vec, assets, capital, t, grad)

Final-period utility when the parent works `h = h_vec[1]` hours and consumes everything:
`c = (1 + r) * assets + w * h + y`, with `w = wage_func(model, capital, t)`. The `(1 + r)`
factor matches the consumption that `solve_model!` stores for this period and the other
budget equations in this file.

If `grad` is non-empty, `grad[1]` is set to `dU/dh = w * c^(-rho) - phi * h^eta`.
"""
@inline function obj_last_period(model::Parent_model, h_vec::Vector, assets::Float64, capital::Float64, t::Int, grad::Vector)
    h = h_vec[1]
    w = wage_func(model, capital, t)
    c = (1.0 + model.r) * assets + w * h + model.y

    u = util(model, c, h)

    if length(grad) > 0
        du_dc = c^(-model.rho)            # marginal utility of consumption
        du_dh = -model.phi * h^model.eta  # marginal disutility of labor
        grad[1] = w * du_dc + du_dh       # chain rule: dc/dh = w
    end
    return u
end


"""
    obj_work_period(model, x, assets, capital, t, interp, grad)

Value of choosing `(c, h) = (x[1], x[2])` in a post-separation period `t`:
`util(c, h) + beta_vector[t] * interp(a', k')`, where
`a' = (1 + r) * assets + w * h - c + y` and `k' = capital + h`.
If `grad` is non-empty it receives `[dV/dc, dV/dh]`.
"""
@inline function obj_work_period(model::Parent_model, x::Vector, assets::Float64, capital::Float64, t::Int, interp, grad::Vector)
    c = x[1]
    h = x[2]
    β = model.beta_vector[t]  # time-varying discount factor
    wage = wage_func(model, capital, t)
    a_prime = (1.0 + model.r) * assets + wage * h - c + model.y
    k_prime = capital + h
    continuation = interp(a_prime, k_prime)
    value = util(model, c, h) + β * continuation

    if !isempty(grad)
        ∂a, ∂k = Interpolations.gradient(interp, a_prime, k_prime)
        grad[1] = c^(-model.rho) - β * ∂a
        grad[2] = -model.phi * h^model.eta + β * (wage * ∂a + ∂k)
    end
    return value
end

"""
    obj_terminal_period(model, c_p, i_c, e_p, h_p, t_p, b, assets, log_HC, capital, t,
                        grad, interp, V_child_interp)

Value of the separation-period choice `(c_p, i_c, e_p, h_p, t_p, b)`:

    util_total(...) + model.lambda * V_child_interp(b, HC') + beta_t * interp(a', k + h_p, log HC')

with `a' = (1 + r) * assets + w * h_p + y - c_p - e_p - b`, `log HC'` from
`HC_technology_full` and `HC' = exp(log HC')`. `interp` is the parent's next-period value
over `(a, k, log HC)`; `V_child_interp` is evaluated at `(b, HC')`, i.e. at the *level* of
human capital.

This objective is maximized (callers negate it for NLopt). Choices with
`a' < model.a_min` or `b < 0` return `-Inf`, and a non-empty `grad` is then zeroed.
Otherwise a non-empty `grad` receives the six partial derivatives in choice order.
"""
@inline function obj_terminal_period(
    model::Parent_model, c_p::Float64, i_c::Float64, e_p::Float64, h_p::Float64, t_p::Float64,
    b::Float64, assets::Float64, log_HC::Float64, capital::Float64, t::Int, grad::Vector,
    interp, V_child_interp
)
    β = model.beta_vector[t]
    w = wage_func(model, capital, t)
    a_next = (1.0 + model.r) * assets + w * h_p + model.y - c_p - e_p - b
    k_next = capital + h_p
    if a_next < model.a_min || b < 0
        # -Inf, not +Inf: the caller negates this value for a minimizer, so +Inf would
        # turn infeasible points into the most attractive ones.
        length(grad) > 0 && fill!(grad, 0.0)
        return -Inf
    end
    leisure_p = 1.0 - h_p - t_p
    leisure_c = 1.0 - t_p - i_c
    log_HC_next = HC_technology_full(model, t_p, e_p, log_HC, i_c, t)
    HC_next = exp(log_HC_next)
    util_now = util_total(model, c_p, h_p, t_p, i_c, log_HC, t)
    V_child_value = V_child_interp(b, HC_next)
    V_next = interp(a_next, k_next, log_HC_next)
    f = util_now + model.lambda * V_child_value + β * V_next

    if length(grad) > 0
        ∂V_∂a, ∂V_∂k, ∂V_∂logHC = Interpolations.gradient(interp, a_next, k_next, log_HC_next)
        ∂Vc_∂b, ∂Vc_∂HC = Interpolations.gradient(V_child_interp, b, HC_next)

        # Derivatives of the log-linear technology in HC_technology_full
        dlogHC_dt = model.sigma_1_vector[t] / t_p
        dlogHC_de = model.sigma_2_vector[t] / e_p
        dlogHC_di = model.sigma_4_vector[t] / i_c
        # The child's value is in HC' = exp(log HC'), so dHC'/dx = HC' * dlogHC'/dx
        child_HC = model.lambda * ∂Vc_∂HC * HC_next
        term_leisure_p = -model.phi_2_vector[t] / leisure_p
        term_leisure_c = -(1 - model.mu_vector[t]) * model.lambda_1_vector[t] / leisure_c

        grad[1] = model.phi_1_vector[t] / c_p - β * ∂V_∂a                                        # c_p
        grad[2] = term_leisure_c + child_HC * dlogHC_di + β * ∂V_∂logHC * dlogHC_di             # i_c
        grad[3] = child_HC * dlogHC_de + β * (-∂V_∂a + ∂V_∂logHC * dlogHC_de)                    # e_p
        grad[4] = term_leisure_p + β * (w * ∂V_∂a + ∂V_∂k)                                      # h_p
        grad[5] = term_leisure_p + term_leisure_c + child_HC * dlogHC_dt + β * ∂V_∂logHC * dlogHC_dt  # t_p
        grad[6] = model.lambda * ∂Vc_∂b - β * ∂V_∂a                                              # b
    end

    return f
end


"""
    obj_interactive_part(model, c_p, i_c, e_p, h_p, t_p, assets, log_HC, capital, t, interp, grad)

Value of an interactive-period choice `(c_p, i_c, e_p, h_p, t_p)`:
`util_total(...) + beta_vector[t] * interp(a', capital + h_p, log HC')`, where
`a' = (1 + r) * assets + w * h_p + y - c_p - e_p` and `log HC'` comes from
`HC_technology_full`. A non-empty `grad` is overwritten with the five partial
derivatives, in choice order.
"""
@inline function obj_interactive_part(
    model::Parent_model, c_p::Float64, i_c::Float64, e_p::Float64, h_p::Float64, t_p::Float64,
    assets::Float64, log_HC::Float64, capital::Float64, t::Int, interp, grad::Vector
)
    β = model.beta_vector[t]
    wage = wage_func(model, capital, t)
    a_prime = (1.0 + model.r) * assets + wage * h_p + model.y - c_p - e_p
    k_prime = capital + h_p
    logHC_prime = HC_technology_full(model, t_p, e_p, log_HC, i_c, t)
    value = util_total(model, c_p, h_p, t_p, i_c, log_HC, t) + β * interp(a_prime, k_prime, logHC_prime)

    isempty(grad) && return value

    Va, Vk, Vh = gradient(interp, a_prime, k_prime, logHC_prime)
    μ = model.mu_vector[t]
    # Marginal (negative) utility of giving up parent / child leisure
    dleis_p = -(model.phi_2_vector[t] / (1.0 - h_p - t_p))
    dleis_c = -((1 - μ) * model.lambda_1_vector[t] / (1.0 - t_p - i_c))
    dlogHC_de = model.sigma_2_vector[t] / e_p
    dlogHC_dt = model.sigma_1_vector[t] / t_p

    grad[1] = model.phi_1_vector[t] / c_p - β * Va
    grad[2] = dleis_c + β * Vh * (model.sigma_4_vector[t] / i_c)
    grad[3] = β * (Vh * dlogHC_de - Va)
    grad[4] = dleis_p + β * (wage * Va + Vk)
    grad[5] = dleis_p + dleis_c + β * Vh * dlogHC_dt
    return value
end

"""
    obj_parent_only(model, c_p, e_p, h_p, t_p, assets, log_HC, capital, t, interp, grad)

Value of a parent-only-period choice `(c_p, e_p, h_p, t_p)`:
`util_parent(...) + beta_vector[t] * interp(a', capital + h_p, log HC')`, with
`a' = (1 + r) * assets + w * h_p + y - c_p - e_p` and `log HC'` from
`HC_technology_parentonly`. A non-empty `grad` receives the four partial derivatives.
"""
@inline function obj_parent_only(
    model::Parent_model, c_p::Float64, e_p::Float64, h_p::Float64, t_p::Float64,
    assets::Float64, log_HC::Float64, capital::Float64, t::Int, interp, grad::Vector
)
    β = model.beta_vector[t]
    wage = wage_func(model, capital, t)
    a_prime = (1.0 + model.r) * assets + wage * h_p + model.y - c_p - e_p
    k_prime = capital + h_p
    logHC_prime = HC_technology_parentonly(model, t_p, e_p, log_HC, t)
    free_time = 1.0 - h_p - t_p
    flow = util_parent(model, c_p, h_p, t_p, log_HC, t)
    continuation = interp(a_prime, k_prime, logHC_prime)

    if !isempty(grad)
        Va, Vk, Vh = gradient(interp, a_prime, k_prime, logHC_prime)
        marg_leisure = model.phi_2_vector[t] / free_time
        grad[1] = model.phi_1_vector[t] / c_p - β * Va
        grad[2] = β * (-Va + Vh * (model.sigma_2_vector[t] / e_p))
        grad[3] = -marg_leisure + β * (wage * Va + Vk)
        grad[4] = -marg_leisure + β * Vh * (model.sigma_1_vector[t] / t_p)
    end
    return flow + β * continuation
end
# ------------------------------------------------
# Utility Functions
# ------------------------------------------------
"""
    util(model, c, h)

Flow utility for post-separation periods. Consumption utility is CRRA with curvature
`model.rho` (log when `rho == 1`). Labor disutility is
`phi * h^(1 + eta) / (1 + eta)`.
"""
@inline function util(model::Parent_model, c, h)
    ρ = model.rho
    η = model.eta
    u_c = ρ == 1.0 ? log(c) : c^(1.0 - ρ) / (1.0 - ρ)
    return u_c - model.phi * h^(1.0 + η) / (1.0 + η)
end

"""
    util_total(model, c, h_p, t_p, i_c, log_HC, t)

Joint flow utility in periods where parent and child interact.
- Parent: `phi_1 * log(c) + phi_2 * log(1 - h_p - t_p)`.
- Child: `mu * phi_3 * log_HC + (1 - mu) * (lambda_1 * log(1 - t_p - i_c) + lambda_2 * log_HC)`.

Returns `-Inf` when either leisure, `c`, or `i_c` is non-positive.
"""
@inline function util_total(model::Parent_model, c::Float64, h_p::Float64,
                            t_p::Float64, i_c::Float64, log_HC::Float64, t::Int)
    free_p = 1.0 - h_p - t_p
    free_c = 1.0 - t_p - i_c
    (free_p <= 0.0 || c <= 0.0 || i_c <= 0.0 || free_c <= 0.0) && return -Inf

    μ = model.mu_vector[t]
    parent_part = model.phi_1_vector[t] * log(c) + model.phi_2_vector[t] * log(free_p)
    child_part = μ * model.phi_3_vector[t] * log_HC +
                 (1 - μ) * (model.lambda_1_vector[t] * log(free_c) + model.lambda_2_vector[t] * log_HC)
    return parent_part + child_part
end

"""
    util_parent(model, c, h_p, t_p, log_HC, t)

Flow utility in parent-only periods:
`phi_1 * log(c) + phi_2 * log(1 - h_p - t_p) + phi_3 * log_HC`.
Returns `-Inf` if leisure or `c` is non-positive.
"""
@inline function util_parent(model::Parent_model, c::Float64, h_p::Float64, t_p::Float64, log_HC::Float64, t::Int)
    free_time = 1.0 - h_p - t_p
    (free_time <= 0.0 || c <= 0.0) && return -Inf
    consumption_term = model.phi_1_vector[t] * log(c)
    leisure_term = model.phi_2_vector[t] * log(free_time)
    hc_term = model.phi_3_vector[t] * log_HC
    return consumption_term + leisure_term + hc_term
end
# ------------------------------------------------
# Human Capital Functions
# ------------------------------------------------
"""
    HC_technology_full(model, t_p, e_p, log_HC, i_c, t)

Next-period log human capital in interactive periods:

    log R_t + sigma_1 * log(t_p) + sigma_2 * log(e_p) + sigma_3 * log_HC + sigma_4 * log(i_c)

If any of `t_p`, `e_p`, `i_c` is non-positive, returns the large negative sentinel `-1e8`.
The sentinel is finite (not `-Inf`) and avoids the `DomainError` that `log` would raise.
`util_total` treats `i_c <= 0` as infeasible as well.
"""
@inline function HC_technology_full(model::Parent_model, t_p, e_p, log_HC , i_c, t)
    if t_p <= 0.0 || e_p <= 0.0 || i_c <= 0.0
        return -1e8
    end
    return log(model.R_vector[t]) +
        model.sigma_1_vector[t] * log(t_p) +
        model.sigma_2_vector[t] * log(e_p) +
        model.sigma_3_vector[t] * log_HC +
        model.sigma_4_vector[t] * log(i_c)
end

"""
    HC_technology_parentonly(model, t_p, e_p, log_HC, t)

Next-period log human capital in parent-only periods:
`log R_t + sigma_1 * log(t_p) + sigma_2 * log(e_p) + sigma_3 * log_HC`.
Returns `-1e8` when `t_p` or `e_p` is non-positive, instead of calling `log` on it.
"""
@inline function HC_technology_parentonly(model::Parent_model, t_p::Float64, e_p::Float64, log_HC::Float64, t::Int)
    (t_p <= 0.0 || e_p <= 0.0) && return -1e8
    log_productivity = log(model.R_vector[t])
    time_term = model.sigma_1_vector[t] * log(t_p)
    spending_term = model.sigma_2_vector[t] * log(e_p)
    persistence_term = model.sigma_3_vector[t] * log_HC
    return log_productivity + time_term + spending_term + persistence_term
end


# ------------------------------------------------
# Budget constraints
# ------------------------------------------------
"""
    asset_constraint_terminal(x, grad, model, capital, t, assets)

NLopt inequality constraint `g(x) ≤ 0` for the separation period, with
`x = [c_p, i_c, e_p, h_p, t_p, b]`. It enforces `a_next ≥ model.a_min`, where
`a_next = (1 + r) * assets + w * h_p + y - c_p - e_p - b`. This is the same feasibility
threshold that `obj_terminal_period` checks. A non-empty `grad` receives `∂g/∂x`.
"""
@inline function asset_constraint_terminal(x::Vector, grad::Vector, model::Parent_model, capital::Float64, t::Int, assets::Float64)
    c_p, i_c, e_p, h_p, t_p, b = x[1], x[2], x[3], x[4], x[5], x[6]
    w = wage_func(model, capital, t)
    a_next = (1.0 + model.r) * assets + w * h_p + model.y - c_p - e_p - b
    if length(grad) > 0
        grad[1] = 1.0   # ∂g/∂c_p
        grad[2] = 0.0   # ∂g/∂i_c
        grad[3] = 1.0   # ∂g/∂e_p
        grad[4] = -w    # ∂g/∂h_p
        grad[5] = 0.0   # ∂g/∂t_p
        grad[6] = 1.0   # ∂g/∂b: the transfer lowers a_next, so it raises g
    end
    return model.a_min - a_next  # g(x) ≤ 0 ⇔ a_next ≥ a_min
end

"""
    asset_constraint_full(x, grad, model, capital, t, assets)

NLopt inequality constraint `g(x) ≤ 0` for interactive periods, with
`x = [c_p, i_c, e_p, h_p, t_p]`. It enforces `a_next ≥ model.a_min`, where
`a_next = (1 + r) * assets + w * h_p + y - c_p - e_p`. A non-empty `grad` receives `∂g/∂x`.
"""
@inline function asset_constraint_full(x::Vector, grad::Vector, model::Parent_model, capital::Float64, t::Int, assets::Float64)
    c_p, i_c, e_p, h_p, t_p = x
    w = wage_func(model, capital, t)
    a_next = (1.0 + model.r) * assets + w * h_p + model.y - c_p - e_p
    if length(grad) > 0
        grad[1] = 1.0   # ∂g/∂c_p
        grad[2] = 0.0   # ∂g/∂i_c
        grad[3] = 1.0   # ∂g/∂e_p
        grad[4] = -w    # ∂g/∂h_p
        grad[5] = 0.0   # ∂g/∂t_p
    end
    return model.a_min - a_next  # g(x) ≤ 0 ⇔ a_next ≥ a_min
end




"""
    asset_constraint_parentonly(x, grad, model, capital, t, assets)

NLopt inequality constraint `g(x) ≤ 0` for parent-only periods, with
`x = [c_p, e_p, h_p, t_p]`. It enforces `a_next ≥ model.a_min`, where
`a_next = (1 + r) * assets + w * h_p + y - c_p - e_p`. A non-empty `grad` receives `∂g/∂x`.
"""
@inline function asset_constraint_parentonly(x::Vector, grad::Vector, model::Parent_model, capital::Float64, t::Int, assets::Float64)
    c_p, e_p, h_p, t_p = x
    w = wage_func(model, capital, t)
    a_next = (1.0 + model.r) * assets + w * h_p + model.y - c_p - e_p
    if length(grad) > 0
        grad[1] = 1.0   # ∂g/∂c_p
        grad[2] = 1.0   # ∂g/∂e_p
        grad[3] = -w    # ∂g/∂h_p
        grad[4] = 0.0   # ∂g/∂t_p
    end
    return model.a_min - a_next  # g(x) ≤ 0 ⇔ a_next ≥ a_min
end
# ------------------------------------------------
# Time constraints
# ------------------------------------------------
"""
    constraint_min_leisure(x, grad)

NLopt inequality constraint `h_p + t_p - 1 ≤ 0`, which keeps the parent's leisure
non-negative. The positions of `(h_p, t_p)` depend on the decision-vector layout:

- `[c_p, e_p, h_p, t_p]` (4 entries, parent-only periods): indices 3, 4
- `[c_p, i_c, e_p, h_p, t_p]` (5 entries, interactive periods): indices 4, 5
- `[c_p, i_c, e_p, h_p, t_p, b]` (6 entries, separation period): indices 4, 5,
  because the transfer `b` comes last

Other lengths use the last two entries. A non-empty `grad` is overwritten with `∂g/∂x`.
"""
@inline function constraint_min_leisure(x::Vector, grad::Vector)
    n = length(x)
    i_h, i_t = n == 6 ? (4, 5) : (n - 1, n)
    if length(grad) > 0
        fill!(grad, 0.0)
        grad[i_h] = 1.0  # ∂g/∂h_p
        grad[i_t] = 1.0  # ∂g/∂t_p
    end
    return (x[i_h] + x[i_t]) - 1.0  # h_p + t_p <= 1.0
end



"""
    constraint_child_time(x, grad)

NLopt inequality constraint `i_c + t_p - 1 ≤ 0`, which limits the child's study time plus
the parent's time with the child. Supported decision-vector layouts:

- `[c_p, e_p, h_p, t_p]` (4 entries): no child study time, so `i_c = 0`
- `[c_p, i_c, e_p, h_p, t_p]` (5 entries): interactive periods
- `[c_p, i_c, e_p, h_p, t_p, b]` (6 entries): separation period with transfer `b`

Any other length raises an error. A non-empty `grad` is overwritten with `∂g/∂x`.
"""
@inline function constraint_child_time(x::Vector, grad::Vector)
    n = length(x)
    if n == 4  # Parent-only periods: [c_p, e_p, h_p, t_p]
        i_c = 0.0  # No child study time in parent-only periods
        t_p = x[4]
        if length(grad) > 0
            fill!(grad, 0.0)
            grad[4] = 1.0  # ∂g/∂t_p
        end
    elseif n == 5 || n == 6  # [c_p, i_c, e_p, h_p, t_p] (+ b in the separation period)
        i_c = x[2]
        t_p = x[5]
        if length(grad) > 0
            fill!(grad, 0.0)
            grad[2] = 1.0  # ∂g/∂i_c
            grad[5] = 1.0  # ∂g/∂t_p
        end
    else
        error("Unexpected vector length in constraint_child_time: $(length(x))")
    end
    return (i_c + t_p) - 1.0  # i_c + t_p <= 1.0
end
# ------------------------------------------------
# Interpolation Functions
# ------------------------------------------------

"""
    create_interp2(model, sol_v, t)
    create_interp2(model, sol_v, t; i_hc = 1)

Return a linear interpolant of the period-`t` value function over
`(model.a_grid, model.k_grid)`, extrapolated linearly (`Line()`) outside the grid.

`sol_v` may be a 3-D array indexed `[t, i_a, i_k]`, or the 4-D solution array indexed
`[t, i_a, i_k, i_hc]`. In the 4-D case, the `i_hc`-th human-capital slice is used
(default 1, the slice where the solver stores parent-only post-separation values).
"""
function create_interp2(model::Parent_model, sol_v::AbstractArray{<:Real, 3}, t::Integer)
    return LinearInterpolation((model.a_grid, model.k_grid), sol_v[t, :, :], extrapolation_bc=Line())
end

function create_interp2(model::Parent_model, sol_v::AbstractArray{<:Real, 4}, t::Integer; i_hc::Integer=1)
    return LinearInterpolation((model.a_grid, model.k_grid), sol_v[t, :, :, i_hc], extrapolation_bc=Line())
end

"""
    create_interp(model, sol_v, t)

Return a gridded trilinear interpolant of `sol_v[t, :, :, :]` over
`(model.a_grid, model.k_grid, model.hc_grid)`, extended linearly (`Line()`) beyond the grid.
"""
function create_interp(model::Parent_model, sol_v, t)
    knots = (model.a_grid, model.k_grid, model.hc_grid)
    values_t = sol_v[t, :, :, :]
    return extrapolate(interpolate(knots, values_t, Gridded(Linear())), Line())
end


# ------------------------------------------------
# Debug Functions
# ------------------------------------------------

"""
    result_type_name(ret)

Classify an NLopt return code as `"converged"`, `"maxeval"` or `"other"`.

`ret` may be a `Symbol` (e.g. `:XTOL_REACHED`), a `String`, or any value whose `Symbol`
conversion yields the NLopt code name (such as an enum). NLopt's successful stopping codes
`:SUCCESS`, `:STOPVAL_REACHED`, `:FTOL_REACHED` and `:XTOL_REACHED` all count as
converged. SLSQP in particular commonly reports plain `:SUCCESS`.
"""
function result_type_name(ret)
    code = Symbol(ret)
    if code in (:SUCCESS, :STOPVAL_REACHED, :FTOL_REACHED, :XTOL_REACHED)
        return "converged"
    elseif code === :MAXEVAL_REACHED
        return "maxeval"
    else
        return "other"
    end
end

"""
    print_period_stats(t, converge_count, maxeval_count, other_dict, itercounts, total)

Print one summary line for period `t`: the percentages of converged, max-evaluation and
other outcomes out of `total`, and the mean of `itercounts`. If `other_dict` (code => count)
has a positive total, also list each code with its count and percentage.
"""
function print_period_stats(t, converge_count, maxeval_count, other_dict, itercounts, total)
    avg_iter = round(mean(itercounts), digits=2)
    pct(x) = round(x / total * 100, digits=1)
    other_total = sum(values(other_dict))
    println("Period $t: Converged: $(pct(converge_count))%, Maxeval: $(pct(maxeval_count))%, Other: $(pct(other_total))%, Avg iters: $avg_iter")
    other_total > 0 || return
    println("    Other status codes:")
    for (code, count) in other_dict
        println("        $code : $count times ($(pct(count))%)")
    end
end
