In [19]:
# Report the current working directory of this Julia session.

pwd()

"/Users/xyliu/Desktop/ISE 616 final/ISE-616-final"

In [20]:
using JuMP
using Ipopt   
using LinearAlgebra
using MathOptInterface
const MOI = MathOptInterface



MathOptInterface

## Reduced one-hot (dummy) encoding functions

In [21]:
############################
# Encoding information type
############################

# NOTE: the docstring below must sit directly above the `struct` line.
# A blank line in between stops Julia from attaching it to the type.
"""
    ZGEncodingInfo

Stores the information needed to convert between the original
(integer-coded) categorical/group features and their reduced dummy
encodings.

Fields:
- k_z::Vector{Int}     : number of categories for each categorical component (length m)
- z_start::Vector{Int} : start index (column) of each component's block in the
                         encoded categorical matrix (length m)
- num_g::Int           : total number of groups
"""
struct ZGEncodingInfo
    k_z::Vector{Int}
    z_start::Vector{Int}
    num_g::Int
end



#############################################
# Original (Z, group) → Reduced (Z_enc, G_enc)
#############################################

"""
    encode_zg_reduced(Z::AbstractMatrix{<:Integer},
                      group::AbstractVector{<:Integer})

Encode categorical features Z and group indexes into a reduced dummy
representation:

- Each categorical component l with k_l levels is mapped to (k_l - 1)
  dummy variables. The last category k_l is the baseline (all zeros).
- The group variable with num_g levels is mapped to (num_g - 1)
  dummy variables. The last group num_g is the baseline (all zeros).

The level counts are inferred from the data: k_l is the maximum of column l
of `Z` and num_g is the maximum of `group`, so at least one sample is
required.

Input:
- Z     : N × m matrix, each column l stores an integer in {1, ..., k_l}
- group : length-N vector, each entry in {1, ..., num_g}

Output:
- Z_enc : N × sum_l (k_l - 1) `Int8` matrix (reduced dummies for Z)
- G_enc : N × (num_g - 1) `Int8` matrix (reduced dummies for group)
- info  : ZGEncodingInfo, stores the encoding scheme for later decoding

Errors:
- `AssertionError` if `length(group) != N`, or if any entry of `Z` or
  `group` is smaller than 1. Group entries are checked even when only a
  single group is present.
"""
function encode_zg_reduced(Z::AbstractMatrix{<:Integer},
                           group::AbstractVector{<:Integer})
    N, m = size(Z)
    @assert length(group) == N "group must have length N"

    # k_z[l] = number of categories for component l (assumed {1, ..., k_l}).
    # A view avoids copying every column just to take its maximum, and the
    # typed comprehension keeps the result a Vector{Int} even when m == 0.
    k_z = Int[maximum(view(Z, :, l)) for l in 1:m]

    # Column start indices in the reduced dummy matrix Z_enc.
    # Component l uses (k_z[l] - 1) columns (baseline = category k_z[l]).
    z_start = Vector{Int}(undef, m)
    col = 1
    for l in 1:m
        z_start[l] = col
        col += max(k_z[l] - 1, 0)  # number of dummies for component l
    end
    total_z = col - 1   # total number of columns in Z_enc

    # Allocate Z_enc: N × total_z
    Z_enc = zeros(Int8, N, total_z)

    for i in 1:N
        for l in 1:m
            val = Z[i, l]
            k_l = k_z[l]
            @assert 1 ≤ val ≤ k_l "Category out of range in column $l (row $i)"

            # A single-category component (k_l == 1) owns no column at all;
            # the baseline category k_l is encoded as all zeros.
            if k_l > 1 && val < k_l
                Z_enc[i, z_start[l] + val - 1] = 1   # categories 1..(k_l-1)
            end
        end
    end

    # Encode group using (num_g - 1) dummies; last group is baseline.
    num_g = Int(maximum(group))
    d_g   = max(num_g - 1, 0)
    G_enc = zeros(Int8, N, d_g)

    # Validate every group index, including the single-group case where
    # G_enc has no columns: otherwise a label such as 0 would pass silently
    # and later be decoded as a different group.
    for i in 1:N
        gi = group[i]
        @assert 1 ≤ gi ≤ num_g "Group index out of range at row $i"

        if gi < num_g
            # Non-baseline group in {1, ..., num_g - 1}
            G_enc[i, gi] = 1
        end
        # gi == num_g is the baseline group -> row stays all zeros
    end

    info = ZGEncodingInfo(k_z, z_start, num_g)
    return Z_enc, G_enc, info
end


encode_zg_reduced

## test one-hot encode

In [22]:
# Smoke test for `encode_zg_reduced` on a tiny hand-made data set.

# Categorical features: one row per sample, one column per component.
Z = [1 1;
     2 3;
     3 2;
     1 3]

# group: length-N vector, values in {1,2,3}
group = [1, 2, 1, 3]

# Run the encoder; `info` keeps what is needed to decode again.
Z_enc, G_onehot, info = encode_zg_reduced(Z, group)

println("Original Z:")
println(Z)
println()

# Both encoded matrices are reported the same way: title, contents, size.
for (title, size_label, A) in (
        ("Encoded Z_enc (reduced one-hot):", "size(Z_enc) = ", Z_enc),
        ("Group one-hot G_onehot:", "size(G_onehot) = ", G_onehot),
    )
    println(title)
    println(A)
    println(size_label, size(A))
    println()
end

println("Encoding info:")
println("  k_z     = ", info.k_z)
println("  z_start = ", info.z_start)
println("  num_g   = ", info.num_g)

Original Z:
[1 1; 2 3; 3 2; 1 3]

Encoded Z_enc (reduced one-hot):
Int8[1 0 1 0; 0 1 0 0; 0 0 0 1; 1 0 0 0]
size(Z_enc) = (4, 4)

Group one-hot G_onehot:
Int8[1 0; 0 1; 1 0; 0 0]
size(G_onehot) = (4, 2)

Encoding info:
  k_z     = [3, 3]
  z_start = [1, 3]
  num_g   = 3


In [23]:
#############################################
# Reduced (Z_enc, G_enc) → Original (Z, group)
#############################################

"""
    decode_zg_reduced(Z_enc::AbstractMatrix{<:Integer},
                      G_enc::AbstractMatrix{<:Integer},
                      info::ZGEncodingInfo)

Invert the reduced dummy encoding and return the integer-coded categorical
matrix and group vector.

Decoding rule, applied row by row:
- Component l with k_l levels reads its block of (k_l - 1) columns starting
  at `info.z_start[l]`. The position of the first entry equal to 1 is the
  category; a block without any 1 means the baseline category k_l.
- The group is read the same way from `G_enc`; a row without any 1 means
  the baseline group num_g.
- A component with a single level owns no columns and always decodes to 1;
  likewise a single group always decodes to 1.

Input:
- Z_enc : N × sum_l (k_l - 1) matrix (reduced dummies for Z)
- G_enc : N × (num_g - 1) matrix (reduced dummies for group)
- info  : ZGEncodingInfo produced earlier by `encode_zg_reduced`

Output:
- Z     : N × m matrix, each entry in {1, ..., k_l}
- group : length-N vector, each entry in {1, ..., num_g}
"""
function decode_zg_reduced(Z_enc::AbstractMatrix{<:Integer},
                           G_enc::AbstractMatrix{<:Integer},
                           info::ZGEncodingInfo)
    n_rows, n_cols = size(Z_enc)
    n_comp   = length(info.k_z)
    n_groups = info.num_g
    n_gcols  = max(n_groups - 1, 0)

    # The group block must match both the sample count and the encoding info.
    @assert size(G_enc, 1) == n_rows "Z_enc and G_enc must have same number of rows"
    @assert size(G_enc, 2) == n_gcols "G_enc columns must be num_g - 1"

    # --- categorical part -------------------------------------------------
    Z = zeros(Int, n_rows, n_comp)
    for (l, k_l) in enumerate(info.k_z)
        width     = max(k_l - 1, 0)
        first_col = info.z_start[l]
        last_col  = first_col + width - 1

        if width == 0
            # Single-level component: nothing to read, always category 1.
            Z[:, l] .= 1
            continue
        end

        @assert last_col ≤ n_cols "Z_enc has too few columns for component $l"

        block = @view Z_enc[:, first_col:last_col]  # N × width
        for i in 1:n_rows
            # First 1 in the row gives the category; none means baseline k_l.
            Z[i, l] = something(findfirst(==(1), block[i, :]), k_l)
        end
    end

    # --- group part -------------------------------------------------------
    group = if n_gcols == 0
        # Only one group exists, so every sample belongs to group 1.
        fill(1, n_rows)
    else
        Int[something(findfirst(==(1), G_enc[i, :]), n_groups) for i in 1:n_rows]
    end

    return Z, group
end


decode_zg_reduced

## test one-hot decode

In [24]:
# Round-trip check: decode the matrices produced by the encoding test cell.
# This rebinds the globals `Z` and `group` to the decoded values.
Z, group = decode_zg_reduced(Z_enc, G_onehot, info)
# Show the decoded pair as the cell result.
Z, group

([1 1; 2 3; 3 2; 1 3], [1, 2, 1, 3])

## DAG builder

In [25]:
###############################
# Graph structure for sample-level DAGs
# compatible with Theorem 2 (graph-based formulation)
###############################

# -----------------------------
# Arc kind (structure only)
# -----------------------------
# NOTE: each docstring sits directly above its definition. A blank line in
# between stops Julia from attaching the docstring.

"""
    ArcKind

Abstract type for different kinds of arcs in the DAG.
We distinguish between:
  - `CatArc` : categorical feature transitions
  - `TermArc`: terminal transitions to (m+1, 0) with a chosen group g
"""
abstract type ArcKind end

"""
    CatArc

Categorical arc corresponding to choosing category `c` for component `k`.

Fields:
- k      : which categorical component (1..m)
- c      : chosen category in {1, ..., k_z[k]}
- d_prev : previous accumulated categorical distance at state (k-1, d_prev)
- d      : new accumulated distance at state (k, d)
"""
struct CatArc <: ArcKind
    k::Int
    c::Int
    d_prev::Float64
    d::Float64
end

"""
    TermArc

Terminal arc from (m, d) to (m+1, 0) with a chosen destination group g.

Fields:
- d : accumulated categorical distance at state (m, d)
- g : destination group in {1, ..., num_g}
"""
struct TermArc <: ArcKind
    d::Float64
    g::Int
end

# -----------------------------
# Arc and per-sample DAG
# -----------------------------

"""
    Arc

Directed edge in the DAG for a single sample i.

Fields:
- src  : index of the source node in `nodes`
- dst  : index of the destination node in `nodes`
- kind : arc type (CatArc or TermArc), which encodes all information
         needed to build w^i(e; β, λ, r_i, ...) later

`kind` is declared with the abstract type `ArcKind`, so one `Vector{Arc}`
can hold categorical and terminal arcs side by side.
"""
struct Arc
    src::Int
    dst::Int
    kind::ArcKind
end

"""
    SampleDAG

Graph structure associated with a single sample i, compatible with
Theorem 2 (graph-based formulation).

Fields:
- sample_index : index i of the sample this DAG corresponds to
- nodes        : vector of DP states (k, d), plus the terminal node (m+1, 0.0)
                 each node is a Tuple{Int,Float64}
- arcs         : list of directed edges with their arc kind (no numeric weight)
- source       : index of the source node (corresponds to state (0, 0.0))
- sink         : index of the terminal node (corresponds to state (m+1, 0.0))

`source` and `sink` are positions in `nodes`, as are the `src`/`dst` fields
of every entry of `arcs`.
"""
struct SampleDAG
    sample_index::Int
    nodes::Vector{Tuple{Int,Float64}}
    arcs::Vector{Arc}
    source::Int
    sink::Int
end




# -----------------------------
# Categorical encoding info
# -----------------------------
"""
    CatEncodingInfo

Encoding information for categorical features (structure level only).
It carries just what the DAG builder needs: the category counts and the
number of groups.

Fields:
- k_z::Vector{Int} : number of categories per component (length m).
                     For component k, categories are {1, ..., k_z[k]}.
- num_g::Int       : total number of groups (used for terminal arcs).
"""
struct CatEncodingInfo
    k_z::Vector{Int}
    num_g::Int
end


# -----------------------------
# Build DAG structure for one sample i
# -----------------------------
"""
    build_sample_dag_structure(i, info, delta, z_i) -> SampleDAG

Build the layered DAG G^i = (V^i, A^i) of sample i used by the graph-based
formulation (Theorem 2).

Layer k (k = 0, ..., m) holds the states (k, d), where d is the categorical
distance accumulated over components 1..k relative to `z_i`. From every state
of layer k-1 there is one `CatArc` per category c of component k; it adds
`delta[k]` to the distance when c differs from `z_i[k]` and nothing
otherwise. States reached more than once are stored once. Every state of the
last layer is connected to the terminal node (m+1, 0.0) by one `TermArc` per
group.

Only the structure is built. Numeric arc weights w^i(e) are not computed
here; they are formed later from (β, λ, r_i, y_i, g_i, B_{g_i}, C_{g_i}, ...).

Arguments:
- i      : sample index (stored in the result)
- info   : CatEncodingInfo (k_z, num_g)
- delta  : length-m vector of δ_k in
           d(z, zᶦ) = Σ_k δ_k * 1[z_k ≠ z_kᶦ]
- z_i    : length-m vector of original categories for sample i,
           each z_i[k] ∈ {1, ..., k_z[k]}

Output:
- SampleDAG whose source is node 1, i.e. state (0, 0.0), and whose sink is
  the last node, i.e. state (m+1, 0.0)
"""
function build_sample_dag_structure(
    i::Int,
    info::CatEncodingInfo,
    delta::AbstractVector{<:Real},
    z_i::AbstractVector{<:Integer}
)::SampleDAG
    n_comp = length(info.k_z)

    @assert length(delta) == n_comp "delta must have length m"
    @assert length(z_i) == n_comp "z_i must have length m"

    # Node storage: `nodes` fixes the numbering, `index_of` maps a state back
    # to its position in `nodes`.
    root     = (0, 0.0)
    nodes    = Tuple{Int,Float64}[root]
    index_of = Dict{Tuple{Int,Float64},Int}(root => 1)
    arcs     = Arc[]

    # `frontier` holds the distinct states of the layer processed last.
    frontier = [root]

    for k in 1:n_comp
        step     = float(delta[k])
        observed = z_i[k]
        upcoming = Tuple{Int,Float64}[]

        for (k_prev, d_prev) in frontier
            @assert k_prev == k - 1
            tail = index_of[(k_prev, d_prev)]

            # One arc per category c ∈ {1, ..., k_z[k]}
            for c in 1:info.k_z[k]
                d     = d_prev + (c != observed ? step : 0.0)
                state = (k, d)

                head = get(index_of, state, 0)
                if head == 0
                    # First visit of (k, d): number it and schedule it for
                    # expansion in the next layer. States of layer k are only
                    # created here, so this also keeps `upcoming` duplicate-free.
                    push!(nodes, state)
                    head = length(nodes)
                    index_of[state] = head
                    push!(upcoming, state)
                end

                push!(arcs, Arc(tail, head, CatArc(k, c, d_prev, d)))
            end
        end

        frontier = upcoming
    end

    # `frontier` now holds the states (m, d) of the final DP layer.
    sink_state = (n_comp + 1, 0.0)
    push!(nodes, sink_state)
    sink = length(nodes)
    index_of[sink_state] = sink

    # Every final state reaches the sink once per destination group.
    for (k_last, d) in frontier
        @assert k_last == n_comp
        tail = index_of[(n_comp, d)]
        for g in 1:info.num_g
            push!(arcs, Arc(tail, sink, TermArc(d, g)))
        end
    end

    return SampleDAG(i, nodes, arcs, 1, sink)
end

build_sample_dag_structure

## test DAG builder

In [26]:
# Hand-checkable DAG: two binary components, unit mismatch costs, two groups.
i     = 2001        # sample index; only stored in the resulting DAG
num_g = 2
k_z   = fill(2, 2)  # each component has 2 categories
delta = ones(2)     # δ₁ = δ₂ = 1
z_i   = [1, 2]      # sample i: z₁ = 1, z₂ = 2

# NOTE: this rebinds the global `info`, which the encoding test cell set to a
# ZGEncodingInfo.
info  = CatEncodingInfo(k_z, num_g)
dag_i = build_sample_dag_structure(i, info, delta, z_i)
dag_i

SampleDAG(2001, [(0, 0.0), (1, 0.0), (1, 1.0), (2, 1.0), (2, 0.0), (2, 2.0), (3, 0.0)], Arc[Arc(1, 2, CatArc(1, 1, 0.0, 0.0)), Arc(1, 3, CatArc(1, 2, 0.0, 1.0)), Arc(2, 4, CatArc(2, 1, 0.0, 1.0)), Arc(2, 5, CatArc(2, 2, 0.0, 0.0)), Arc(3, 6, CatArc(2, 1, 1.0, 2.0)), Arc(3, 4, CatArc(2, 2, 1.0, 1.0)), Arc(4, 7, TermArc(1.0, 1)), Arc(4, 7, TermArc(1.0, 2)), Arc(5, 7, TermArc(0.0, 1)), Arc(5, 7, TermArc(0.0, 2)), Arc(6, 7, TermArc(2.0, 1)), Arc(6, 7, TermArc(2.0, 2))], 1, 7)

## modeling

In [27]:
using JuMP

"""
    build_group_dro_graph_model(
        X, Z, group, y,
        encinfo,
        delta,
        A_group,
        B_group, C_group,
        gamma_x,
        ε,
        optimizer
    ) -> (model, meta)

Build our group-dependent, graph-based DRO logistic regression model,
using the same reduced encoding convention as `ZGEncodingInfo` /
`encode_zg_reduced`. The model is only built, not solved.

Arguments
---------
- X        :: N × n_x matrix of continuous features.
- Z        :: N × m   matrix of original categorical features.
              Entry Z[i, k] ∈ {1, ..., k_z[k]} (no one-hot).
              m may be 0 (no categorical features).

- group    :: length-N vector of group indices g_i ∈ {1,...,num_g}.
- y        :: length-N vector of labels in {-1, +1}.

- encinfo  :: ZGEncodingInfo
              (k_z, z_start, num_g) describing reduced encoding blocks
              for the categorical features and groups.

- delta    :: length-m vector δ_k used in
              d_cat(z, z^i) = Σ_k δ_k * 1[z_k ≠ z_k^i].

- A_group  :: length-num_g vector A_g
              continuous-part metric weights in
              A_g Σ_j γ_j |x_j - x_j^i|.

- B_group  :: length-num_g vector B_g
              categorical-part metric weight.

- C_group  :: length-num_g vector C_g
              group-change penalty.

- gamma_x  :: length n_x vector γ_j for continuous features.
              Continuous dual constraint will be
                  |β_{xj}| ≤ λ * A_min * γ_j,
              where A_min = minimum(A_group).

- ε        :: Wasserstein radius ε.

- optimizer: optimizer constructor for JuMP, e.g.
              optimizer_with_attributes(Ipopt.Optimizer, "print_level" => 5).
              The model contains `@NLconstraint`s with `log`/`exp`, so the
              solver has to accept general nonlinear constraints.

Returns
-------
- model :: JuMP.Model
- meta  :: NamedTuple with fields
           `encinfo`, `dags`, `n_nodes` (inputs / per-sample graph data) and
           `β0`, `β_x`, `β_z`, `β_grp`, `λ`, `r`, `μ` (the JuMP variables).

Errors
------
`AssertionError` when the dimensions are inconsistent, when `y` contains a
value other than -1 or +1, when a group index lies outside 1:num_g, or when
an entry of `Z` lies outside 1:k_z[k].
"""
function build_group_dro_graph_model(
    X::AbstractMatrix{<:Real},
    Z::AbstractMatrix{<:Integer},
    group::AbstractVector{<:Integer},
    y::AbstractVector{<:Integer},
    encinfo::ZGEncodingInfo,
    delta::AbstractVector{<:Real},
    A_group::AbstractVector{<:Real},
    B_group::AbstractVector{<:Real},
    C_group::AbstractVector{<:Real},
    gamma_x::AbstractVector{<:Real},
    ε::Real,
    optimizer,
)
    # -------------------------
    # 0. Dimensions & sanity checks
    # -------------------------
    N, n_x = size(X)
    N_Z, m = size(Z)
    @assert N_Z == N "X and Z must have the same number of rows (samples)."
    @assert length(group) == N "group must have length N."
    @assert length(y) == N "y must have length N."
    @assert length(delta) == m "delta must have length m."
    @assert length(gamma_x) == n_x "gamma_x must have length n_x."

    k_z     = encinfo.k_z
    z_start = encinfo.z_start
    num_g   = encinfo.num_g

    @assert length(k_z) == m "encinfo.k_z must have length m."
    @assert length(z_start) == m "encinfo.z_start must have length m."
    @assert length(A_group) == num_g "A_group length must equal num_g."
    @assert length(B_group) == num_g "B_group length must equal num_g."
    @assert length(C_group) == num_g "C_group length must equal num_g."

    # Value checks. Without them a 0/1-coded label vector or an out-of-range
    # category would still produce a model, just not the intended one.
    @assert all(v -> v == 1 || v == -1, y) "y must contain only the labels -1 and +1."
    @assert all(g -> 1 ≤ g ≤ num_g, group) "group entries must lie in 1:num_g."
    @assert all(1 ≤ Z[i, k] ≤ k_z[k] for i in 1:N, k in 1:m) "Z[i, k] must lie in 1:k_z[k]."

    # Total length of β_z under reduced encoding:
    # last block starts at z_start[m], length (k_z[m] - 1)
    # so p_z = z_start[m] + (k_z[m] - 1) - 1.
    # With no categorical features (m == 0) there is no last block and β_z is
    # empty; indexing z_start[end] would fail in that case.
    p_z = m == 0 ? 0 : z_start[end] + max(k_z[end] - 1, 0) - 1

    # Small positive constant for log(exp(inner) - 1) domain
    # We will enforce: r[i] + λ * inner_coeff ≥ η
    η = 1e-6

    # -------------------------
    # 1. Build per-sample DAGs (discrete part only)
    # -------------------------
    info = CatEncodingInfo(k_z, num_g)

    dags      = Vector{SampleDAG}(undef, N)
    n_nodes   = Vector{Int}(undef, N)
    max_nodes = 0

    for i in 1:N
        # Z[i, :] is an AbstractVector{<:Integer}
        dags[i] = build_sample_dag_structure(i, info, delta, Z[i, :])
        n_nodes[i] = length(dags[i].nodes)
        max_nodes = max(max_nodes, n_nodes[i])
    end

    # -------------------------
    # 2. Create JuMP model & decision variables
    # -------------------------
    model = Model(optimizer)

    # Wasserstein dual variable λ ≥ 0
    @variable(model, λ >= 0.0)

    # Per-sample slack variables r_i ∈ ℝ (free),
    # domain of log(exp(...)-1) will be enforced via extra constraints
    @variable(model, r[1:N])

    # Logistic regression parameters
    @variable(model, β0)                # intercept
    @variable(model, β_x[1:n_x])        # continuous coefficients
    @variable(model, β_z[1:p_z])        # categorical coefficients (reduced)
    @variable(model, β_grp[1:(num_g-1)])# group coefficients (reduced)

    # Graph dual variables μ_i_v for each sample i and each node v.
    # The array is sized by the largest DAG; for a sample with fewer nodes
    # the trailing columns appear in no constraint.
    @variable(model, μ[1:N, 1:max_nodes])

    # -------------------------
    # 3. Objective: λ ε + (1/N) Σ r_i
    # -------------------------
    @objective(model, Min, λ * ε + (1.0 / N) * sum(r[i] for i in 1:N))

    # -------------------------
    # 4. Continuous part dual constraints with A_group
    #
    # Metric for x uses A_g:
    #   d_x(x^i, x; g) = A_g Σ_j γ_j |x_j - x_j^i|
    # Dual boundedness ⇒ for each j:
    #   |β_{xj}| ≤ λ * γ_j * min_g A_g
    # We enforce:
    #   -λ * A_min * γ_j ≤ β_x[j] ≤ λ * A_min * γ_j
    # -------------------------
    A_min = minimum(A_group)

    for j in 1:n_x
        @constraint(model,  β_x[j] <=  λ * A_min * gamma_x[j])
        @constraint(model, -β_x[j] <=  λ * A_min * gamma_x[j])
    end

    # -------------------------
    # 5. Outer logistic inequality:
    #    y^i (β_x^T x^i + β0) ≥ -μ_i(0,0) + μ_i(m+1,0)
    # -------------------------
    for i in 1:N
        dag = dags[i]
        s = dag.source
        t = dag.sink

        # Left-hand side: y_i * (β0 + β_x^T x^i)
        lhs = y[i] * (β0 + sum(β_x[j] * X[i, j] for j in 1:n_x))

        # Right-hand side: - μ_i(source) + μ_i(sink)
        rhs = - μ[i, s] + μ[i, t]

        @constraint(model, lhs >= rhs)
    end

    # -------------------------
    # 6. Edge constraints: μ_t(e) - μ_s(e) ≥ w^i(e)
    #
    # - For CatArc(k,c,...):
    #     w^i(e) = - y^i β_{z_k}^T z_k(c)
    #   Reduced encoding aligned with ZGEncodingInfo:
    #     if c < k_z[k], β_{z_k}^T z_k(c) = β_z[z_start[k] + c - 1]
    #     if c = k_z[k], baseline ⇒ 0.
    #
    # - For TermArc(d,g):
    #     w^i(e) =
    #       - y^i β_grp^T φ_g(g)
    #       - log( exp( r_i + λ (B_{g_i} d + C_{g_i} 1[g ≠ g_i]) ) - 1 )
    #
    #   with group reduced encoding:
    #     if g < num_g: β_grp^T φ_g(g) = β_grp[g]
    #     if g = num_g: baseline ⇒ 0.
    #
    #   Additionally, domain constraints:
    #     r_i + λ (B_{g_i} d + C_{g_i} 1[g ≠ g_i]) ≥ η
    #   ensure that log(exp(inner) - 1) is well-defined.
    # -------------------------
    for i in 1:N
        dag  = dags[i]
        y_i  = y[i]
        g_i  = group[i]

        for arc in dag.arcs
            src = arc.src
            dst = arc.dst

            if arc.kind isa CatArc
                kind = arc.kind::CatArc
                k = kind.k
                c = kind.c
                k_l = k_z[k]

                # Categorical part:
                # w_cat^i(e) = - y_i β_{z_k}^T z_k(c)
                if c < k_l
                    idx = z_start[k] + (c - 1)
                    w_expr = - y_i * β_z[idx]
                else
                    w_expr = 0.0
                end

                @constraint(model, μ[i, dst] - μ[i, src] >= w_expr)

            elseif arc.kind isa TermArc
                kind     = arc.kind::TermArc
                d_val    = kind.d
                g_choice = kind.g

                # Metric coefficient for categorical + group part:
                B_gi = float(B_group[g_i])
                C_gi = float(C_group[g_i])
                cross = (g_choice != g_i) ? C_gi : 0.0
                inner_coeff = B_gi * d_val + cross

                # --- Domain constraint: inner = r[i] + λ * inner_coeff ≥ η ---
                @NLconstraint(model, r[i] + λ * inner_coeff >= η)

                # --- w^i(e) constraint with nonlinear log/exp ---
                if g_choice < num_g
                    # group linear term: - y_i * β_grp[g_choice]
                    @NLconstraint(model,
                        μ[i, dst] - μ[i, src] >=
                        - y_i * β_grp[g_choice] -
                        log(exp(r[i] + λ * inner_coeff) - 1)
                    )
                else
                    # baseline group: no β_grp contribution
                    @NLconstraint(model,
                        μ[i, dst] - μ[i, src] >=
                        - log(exp(r[i] + λ * inner_coeff) - 1)
                    )
                end
            else
                error("Unknown arc kind in DAG.")
            end
        end
    end

    # -------------------------
    # 7. (Optional) Give Ipopt a safe starting point inside the domain
    #    (λ = r_i = 1 makes every r_i + λ * inner_coeff at least 1 when the
    #    metric weights are nonnegative)
    # -------------------------
    set_start_value(λ, 1.0)
    for i in 1:N
        set_start_value(r[i], 1.0)
    end

    # -------------------------
    # 8. Return model + metadata
    # -------------------------
    meta = (
        encinfo = encinfo,
        dags    = dags,
        n_nodes = n_nodes,

        β0      = β0,
        β_x     = β_x,
        β_z     = β_z,
        β_grp   = β_grp,
        λ       = λ,
        r       = r,
        μ       = μ,
    )

    return model, meta
end


build_group_dro_graph_model

## Toy example

In [28]:
# ========== Toy data ==========
# Sizes: samples, continuous features, categorical components, groups.
N, n_x, m, num_g = 2, 1, 2, 2

# Continuous features X: N × n_x (a single column holding 0.0 and 1.0).
X = reshape([0.0, 1.0], N, n_x)

# Categorical features Z: N × m, z_{i,k} ∈ {1,2}; row i is sample i.
Z = [1 1;
     2 2]

# Group indices g_i ∈ {1,2} and labels y_i ∈ {-1, +1}.
group = [1, 2]
y     = [1, -1]

# Hamming weights δ_k (all equal to one).
delta = ones(m)

# Metric weights per group.
A_group = [1.0, 2.0]         # continuous part weights A_g
B_group = ones(num_g)        # categorical part weights B_g
C_group = fill(0.5, num_g)   # group-change penalty C_g

# Continuous scaling γ_x, length n_x.
gamma_x = ones(n_x)

# Wasserstein radius (last statement, so it is the cell's displayed value).
ε = 0.1


0.1

In [29]:
# ========== Encoding info ==========
# Same three fields as the ZGEncodingInfo defined earlier in this file,
# repeated so that this cell can be run on its own.
struct ZGEncodingInfo
    k_z::Vector{Int}
    z_start::Vector{Int}
    num_g::Int
end

# Two binary components. Under the reduced encoding each one contributes
# (k_z[l] - 1) = 1 coefficient to β_z, so block 1 starts at index 1 and
# block 2 at index 2.
k_z     = [2, 2]
z_start = [1, 2]

encinfo = ZGEncodingInfo(k_z, z_start, num_g)


ZGEncodingInfo([2, 2], [1, 2], 2)

In [30]:
# ========== Build model ==========
# Assemble the toy DRO model from the data and encoding info defined in the
# cells above, then print it for inspection.
model, meta = build_group_dro_graph_model(
    X, Z, group, y,
    encinfo, delta,
    A_group, B_group, C_group,
    gamma_x, ε,
    optimizer_with_attributes(Ipopt.Optimizer),  # or just Ipopt.Optimizer
)

println("Model successfully built.")
println(model)


Model successfully built.
Min 0.1 λ + 0.5 r[1] + 0.5 r[2]
Subject to
 β0 + μ[1,1] - μ[1,7] ≥ 0
 -β0 - β_x[1] + μ[2,1] - μ[2,7] ≥ 0
 β_z[1] - μ[1,1] + μ[1,2] ≥ 0
 -μ[1,1] + μ[1,3] ≥ 0
 β_z[2] - μ[1,2] + μ[1,4] ≥ 0
 -μ[1,2] + μ[1,5] ≥ 0
 β_z[2] - μ[1,3] + μ[1,5] ≥ 0
 -μ[1,3] + μ[1,6] ≥ 0
 -β_z[1] - μ[2,1] + μ[2,2] ≥ 0
 -μ[2,1] + μ[2,3] ≥ 0
 -β_z[2] - μ[2,2] + μ[2,4] ≥ 0
 -μ[2,2] + μ[2,5] ≥ 0
 -β_z[2] - μ[2,3] + μ[2,5] ≥ 0
 -μ[2,3] + μ[2,6] ≥ 0
 -λ + β_x[1] ≤ 0
 -λ - β_x[1] ≤ 0
 λ ≥ 0
 (r[1] + λ * 0.0) - 1.0e-6 ≥ 0
 ((μ[1,7] - μ[1,4]) - (-1.0 * β_grp[1] - log(exp(r[1] + λ * 0.0) - 1.0))) - 0.0 ≥ 0
 (r[1] + λ * 0.5) - 1.0e-6 ≥ 0
 ((μ[1,7] - μ[1,4]) - -(log(exp(r[1] + λ * 0.5) - 1.0))) - 0.0 ≥ 0
 (r[1] + λ * 1.0) - 1.0e-6 ≥ 0
 ((μ[1,7] - μ[1,5]) - (-1.0 * β_grp[1] - log(exp(r[1] + λ * 1.0) - 1.0))) - 0.0 ≥ 0
 (r[1] + λ * 1.5) - 1.0e-6 ≥ 0
 ((μ[1,7] - μ[1,5]) - -(log(exp(r[1] + λ * 1.5) - 1.0))) - 0.0 ≥ 0
 (r[1] + λ * 2.0) - 1.0e-6 ≥ 0
 ((μ[1,7] - μ[1,6]) - (-1.0 * β_grp[1] - log(exp(r[1] + 

In [31]:
# Solve the toy model with the optimizer attached when it was built.
optimize!(model)

# Report how the solve ended and the status of the primal result.
println("Termination status: ", termination_status(model))
println("Primal status: ", primal_status(model))

This is Ipopt version 3.14.19, running with linear solver MUMPS 5.8.1.

Number of nonzeros in equality constraint Jacobian...:        0
Number of nonzeros in inequality constraint Jacobian.:      119
Number of nonzeros in Lagrangian Hessian.............:       36

Total number of variables............................:       22
                     variables with only lower bounds:        1
                variables with lower and upper bounds:        0
                     variables with only upper bounds:        0
Total number of equality constraints.................:        0
Total number of inequality constraints...............:       40
        inequality constraints with only lower bounds:       38
   inequality constraints with lower and upper bounds:        0
        inequality constraints with only upper bounds:        2

iter    objective    inf_pr   inf_du lg(mu)  ||d||  lg(rg) alpha_du alpha_pr  ls
   0  1.1000000e+00 0.00e+00 1.50e+00  -1.0 0.00e+00    -  0.00e+00 0.00e+00 

### larger example

In [None]:
using Random
using JuMP
using Ipopt

# -------------------------------------------------------------------
# If ZGEncodingInfo is already defined in your code, comment this out.
# (It has the same field names and types as the definition given earlier
# in this file.)
# -------------------------------------------------------------------
struct ZGEncodingInfo
    k_z::Vector{Int}     # number of categories for each categorical feature
    z_start::Vector{Int} # starting index of each block in β_z (1-based)
    num_g::Int           # number of groups
end

# Logistic sigmoid: σ(t) = 1 / (1 + e^(-t)).
function σ(t)
    return 1 / (1 + exp(-t))
end

"""
    linpred(β0, βx, βz, βgrp, x, z, g, encinfo)

Compute β0 + βxᵀ x + βzᵀ φ_z(z) + βgrpᵀ φ_g(g)
using the SAME reduced coding as in the DRO model.
"""
function linpred(
    β0::Real,
    βx::AbstractVector,
    βz::AbstractVector,
    βgrp::AbstractVector,
    x::AbstractVector,
    z::AbstractVector,
    g::Int,
    encinfo::ZGEncodingInfo,
)
    # Intercept plus the continuous contribution.
    score = β0 + dot(βx, x)

    # Categorical contribution: each component adds the coefficient of its
    # observed level, except for the last level, which is the reference and
    # adds nothing.
    for ℓ in 1:length(encinfo.k_z)
        n_levels = encinfo.k_z[ℓ]
        offset   = encinfo.z_start[ℓ]   # 1-based start of this block in βz
        level    = z[ℓ]                 # category ∈ {1,…,n_levels}

        level < n_levels || continue
        score += βz[offset + (level - 1)]
    end

    # Group contribution: the last group is the reference and adds nothing.
    g < encinfo.num_g && (score += βgrp[g])

    return score
end

"""
    generate_synthetic_instance(; N=500)

Generate a synthetic dataset for testing the DRO LR model.

Returns:
    X, Z, group, y, encinfo, β_true

β_true is a NamedTuple:
    (β0 = ..., βx = ..., βz = ..., βgrp = ...)
"""
function generate_synthetic_instance(; N::Int = 500)
    # Re-seed the global RNG on every call so that the same N always yields
    # the same data set. The default sample size is 500.
    Random.seed!(2025)

    # --- problem sizes ---
    n_x   = 2                 # numerical features
    k_z   = [3, 2]            # levels of the two categorical features
    num_g = 3                 # groups
    m     = length(k_z)

    p_z = sum(k_z .- 1)       # length of βz under the reduced coding
    p_g = num_g - 1           # length of βgrp

    # --- block layout of βz: block ℓ starts right after block ℓ-1 ---
    z_start = cumsum(vcat(1, k_z[1:end-1] .- 1))
    encinfo = ZGEncodingInfo(k_z, z_start, num_g)

    # --- ground-truth coefficients β* ---
    β_true = (β0 = -0.3,
              βx = [1.0, -0.7],
              βz = [0.8, -0.5, 0.6],   # length p_z = (3-1)+(2-1) = 3
              βgrp = [0.5, -0.4])      # length p_g = 2

    @assert length(β_true.βx)   == n_x
    @assert length(β_true.βz)   == p_z
    @assert length(β_true.βgrp) == p_g

    # --- draw features, then labels from the logistic model ---
    X     = randn(N, n_x)
    Z     = zeros(Int, N, m)
    group = zeros(Int, N)
    y     = zeros(Int, N)

    for i in 1:N
        # categorical levels, one uniform draw per component
        for ℓ in 1:m
            Z[i, ℓ] = rand(1:k_z[ℓ])
        end

        # group index, uniform over the groups
        group[i] = rand(1:num_g)

        # label: +1 with probability σ(linear predictor), otherwise -1
        score = linpred(β_true.β0, β_true.βx, β_true.βz, β_true.βgrp,
                        view(X, i, :), view(Z, i, :), group[i], encinfo)
        y[i] = rand() < σ(score) ? 1 : -1
    end

    return X, Z, group, y, encinfo, β_true
end

"""
    run_synthetic_experiment(; N::Int = 300)

Fit the group-DRO graph-based logistic regression model on a synthetic
instance and print the estimated parameters next to the generating ones.

Steps:
1. Generate `N` synthetic samples from a known β* with
   `generate_synthetic_instance`.
2. Build the DRO LR graph-based model with `build_group_dro_graph_model`.
3. Solve it with Ipopt.
4. Print true vs estimated parameters.

# Keywords
- `N`: number of synthetic samples to generate (default `300`).

# Returns
`(model, meta, β_true, (β0_hat, βx_hat, βz_hat, βgrp_hat))`.

# Throws
- `ErrorException` if the solver terminates without a primal solution, so
  that no estimate can be reported.

Assumes you already defined `build_group_dro_graph_model`.
Also assumes that function returns `meta` containing JuMP variables
`β0`, `β_x`, `β_z`, `β_grp`.
"""
function run_synthetic_experiment(; N::Int = 300)
    # 1. data
    X, Z, group, y, encinfo, β_true = generate_synthetic_instance(N = N)
    n_x   = size(X, 2)
    m     = size(Z, 2)
    num_g = encinfo.num_g

    # 2. metric parameters (simple choice)
    delta   = 1e-2 * ones(m)            # δ_ℓ
    A_group = ones(num_g)               # A_g
    B_group = ones(num_g)               # B_g
    C_group = ones(num_g)               # C_g
    gamma_x = 1e-2 * ones(n_x)          # γ_j
    ε       = 1e-4

    # 3. build DRO model
    optimizer = optimizer_with_attributes(Ipopt.Optimizer,
                                          "print_level" => 5)

    model, meta = build_group_dro_graph_model(
        X, Z, group, y,
        encinfo,
        delta,
        A_group,
        B_group,
        C_group,
        gamma_x,
        ε,
        optimizer,
    )

    println("Model built. Start optimization...")
    optimize!(model)
    status = termination_status(model)
    println("Termination status: ", status)

    # Querying `objective_value` / `value` without a primal solution throws an
    # opaque solver-level error; fail early with an actionable message instead.
    has_values(model) || error(
        "Solver stopped with status $(status) and returned no primal " *
        "solution; cannot report parameter estimates.",
    )
    println("Objective value:    ", objective_value(model))

    # 4. extract parameters (adjust if your meta uses other field names)
    β0_hat   = value(meta.β0)
    βx_hat   = value.(meta.β_x)
    βz_hat   = value.(meta.β_z)
    βgrp_hat = value.(meta.β_grp)

    println("\n=== True vs estimated parameters ===")
    println("β0   true = ", β_true.β0,   "   hat = ", β0_hat)
    println("βx   true = ", β_true.βx,   "   hat = ", βx_hat)
    println("βz   true = ", β_true.βz,   "   hat = ", βz_hat)
    println("βgrp true = ", β_true.βgrp, "   hat = ", βgrp_hat)

    return model, meta, β_true, (β0_hat, βx_hat, βz_hat, βgrp_hat)
end

# -------------------------------------------------------------------
# Example usage from the REPL / a notebook.
#
# When running from separate files, load the definitions first:
# include("your_dro_graph_code.jl")   # defines build_group_dro_graph_model
# include("this_synthetic_test.jl")   # this file
#
# Run the experiment. `β_hat` receives the tuple
# (β0_hat, βx_hat, βz_hat, βgrp_hat) returned by run_synthetic_experiment.
model, meta, β_true, β_hat = run_synthetic_experiment()
# -------------------------------------------------------------------


Model built. Start optimization...
This is Ipopt version 3.14.19, running with linear solver MUMPS 5.8.1.

Number of nonzeros in equality constraint Jacobian...:        0
Number of nonzeros in inequality constraint Jacobian.:    24908
Number of nonzeros in Lagrangian Hessian.............:     8100

Total number of variables............................:     2409
                     variables with only lower bounds:        1
                variables with lower and upper bounds:        0
                     variables with only upper bounds:        0
Total number of equality constraints.................:        0
Total number of inequality constraints...............:     7804
        inequality constraints with only lower bounds:     7800
   inequality constraints with lower and upper bounds:        0
        inequality constraints with only upper bounds:        4

iter    objective    inf_pr   inf_du lg(mu)  ||d||  lg(rg) alpha_du alpha_pr  ls
   0  1.0001000e+00 0.00e+00 1.78e+00  -1.

(A JuMP Model
├ solver: Ipopt
├ objective_sense: MIN_SENSE
│ └ objective_function_type: AffExpr
├ num_variables: 2409
├ num_constraints: 7805
│ ├ AffExpr in MOI.GreaterThan{Float64}: 2400
│ ├ AffExpr in MOI.LessThan{Float64}: 4
│ ├ VariableRef in MOI.GreaterThan{Float64}: 1
│ └ Nonlinear: 5400
└ Names registered in the model
  └ :r, :β0, :β_grp, :β_x, :β_z, :λ, :μ, (encinfo = ZGEncodingInfo([3, 2], [1, 3], 3), dags = SampleDAG[SampleDAG(1, [(0, 0.0), (1, 0.01), (1, 0.0), (2, 0.01), (2, 0.02), (2, 0.0), (3, 0.0)], Arc[Arc(1, 2, CatArc(1, 1, 0.0, 0.01)), Arc(1, 2, CatArc(1, 2, 0.0, 0.01)), Arc(1, 3, CatArc(1, 3, 0.0, 0.0)), Arc(2, 4, CatArc(2, 1, 0.01, 0.01)), Arc(2, 5, CatArc(2, 2, 0.01, 0.02)), Arc(3, 6, CatArc(2, 1, 0.0, 0.0)), Arc(3, 4, CatArc(2, 2, 0.0, 0.01)), Arc(4, 7, TermArc(0.01, 1)), Arc(4, 7, TermArc(0.01, 2)), Arc(4, 7, TermArc(0.01, 3)), Arc(5, 7, TermArc(0.02, 1)), Arc(5, 7, TermArc(0.02, 2)), Arc(5, 7, TermArc(0.02, 3)), Arc(6, 7, TermArc(0.0, 1)), Arc(6, 7, TermArc(0.0, 