In [1]:
# Show the current working directory of this notebook session.

pwd()

"/Users/xyliu/Desktop/ISE 616 final/ISE-616-final"

In [None]:
using JuMP
using Ipopt   
using LinearAlgebra
using MathOptInterface
const MOI = MathOptInterface



## Reduced dummy (one-hot) encoding functions

In [21]:
############################
# Encoding information type
############################

"""
    ZGEncodingInfo

Layout metadata needed to convert between the original integer-coded
categorical/group features and their reduced dummy encodings.

# Fields
- `k_z::Vector{Int}`: number of categories of each categorical component (length `m`).
- `z_start::Vector{Int}`: first column of each component's block in the encoded
  matrix `Z_enc` (length `m`); component `l` occupies `max(k_z[l] - 1, 0)` columns.
- `num_g::Int`: total number of groups.
"""
struct ZGEncodingInfo
    k_z::Vector{Int}
    z_start::Vector{Int}
    num_g::Int
end



#############################################
# Original (Z, group) → Reduced (Z_enc, G_enc)
#############################################

"""
    encode_zg_reduced(Z::AbstractMatrix{<:Integer},
                      group::AbstractVector{<:Integer})

Encode the categorical features `Z` and the group labels into reduced dummy
(baseline-dropped one-hot) matrices.

- A categorical component `l` with `k_l` levels becomes `k_l - 1` dummy
  columns; the last level `k_l` is the baseline (all zeros).
- The group variable with `num_g` levels becomes `num_g - 1` dummy columns;
  the last group `num_g` is the baseline (all zeros).

The level counts are inferred from the data: `k_l` is the largest value in
column `l` of `Z`, and `num_g` is the largest value in `group`.

# Arguments
- `Z`: `N × m` matrix; column `l` holds integers in `{1, ..., k_l}`.
- `group`: length-`N` vector with entries in `{1, ..., num_g}`.

# Returns
- `Z_enc`: `N × Σ_l (k_l - 1)` `Int8` matrix of reduced dummies for `Z`.
- `G_enc`: `N × (num_g - 1)` `Int8` matrix of reduced dummies for `group`.
- `info`: `ZGEncodingInfo` describing the layout, for later decoding.

# Throws
- `AssertionError` if `length(group) != N`, or if any category or group label
  is smaller than 1.
"""
function encode_zg_reduced(Z::AbstractMatrix{<:Integer},
                           group::AbstractVector{<:Integer})
    N, m = size(Z)
    @assert length(group) == N "group must have length N"

    # k_z[l] = number of categories for component l (assumed {1, ..., k_l}).
    # `eachcol` yields views, so no column is copied just to take its maximum.
    k_z = Int[maximum(col) for col in eachcol(Z)]

    # Component l contributes d_z[l] = max(k_z[l] - 1, 0) dummy columns
    # (baseline = category k_z[l]); z_start[l] is the first column of its block.
    d_z     = max.(k_z .- 1, 0)
    z_start = cumsum(d_z) .- d_z .+ 1
    total_z = sum(d_z)

    Z_enc = zeros(Int8, N, total_z)
    for l in 1:m
        k_l    = k_z[l]
        offset = z_start[l] - 1
        for i in 1:N
            val = Z[i, l]
            @assert 1 ≤ val ≤ k_l "Category out of range in column $l (row $i)"
            # Categories 1..(k_l - 1) set one dummy; the baseline k_l stays all
            # zeros. A single-level component (k_l == 1) owns no column at all.
            if val < k_l
                Z_enc[i, offset + val] = 1
            end
        end
    end

    # Encode group using (num_g - 1) dummies; last group is baseline.
    num_g = Int(maximum(group))
    d_g   = max(num_g - 1, 0)
    G_enc = zeros(Int8, N, d_g)

    for i in 1:N
        gi = group[i]
        # Validate every label, including the single-group case where G_enc
        # has no columns; otherwise labels below 1 would pass unnoticed.
        @assert 1 ≤ gi ≤ num_g "Group index out of range at row $i"
        if gi < num_g
            G_enc[i, gi] = 1
        end
    end

    info = ZGEncodingInfo(k_z, z_start, num_g)
    return Z_enc, G_enc, info
end


encode_zg_reduced

## Test: one-hot encoding

In [22]:
# Toy data: N = 4 samples, m = 2 categorical components (row i = sample i).
Z = [1 1; 2 3; 3 2; 1 3]
# Group label of each sample, values in {1, 2, 3}.
group = [1, 2, 1, 3]

# Run the encoder. Note that `G_onehot` receives the *reduced* group dummies
# returned by `encode_zg_reduced` (baseline group dropped), despite its name.
Z_enc, G_onehot, info = encode_zg_reduced(Z, group)

# Report the inputs and every piece of the encoding.
print("Original Z:", '\n', Z, '\n', '\n')

print("Encoded Z_enc (reduced one-hot):", '\n', Z_enc, '\n')
print("size(Z_enc) = ", size(Z_enc), '\n', '\n')

print("Group one-hot G_onehot:", '\n', G_onehot, '\n')
print("size(G_onehot) = ", size(G_onehot), '\n', '\n')

print("Encoding info:", '\n')
print("  k_z     = ", info.k_z, '\n')
print("  z_start = ", info.z_start, '\n')
print("  num_g   = ", info.num_g, '\n')

Original Z:
[1 1; 2 3; 3 2; 1 3]

Encoded Z_enc (reduced one-hot):
Int8[1 0 1 0; 0 1 0 0; 0 0 0 1; 1 0 0 0]
size(Z_enc) = (4, 4)

Group one-hot G_onehot:
Int8[1 0; 0 1; 1 0; 0 0]
size(G_onehot) = (4, 2)

Encoding info:
  k_z     = [3, 3]
  z_start = [1, 3]
  num_g   = 3


In [24]:
#############################################
# Reduced (Z_enc, G_enc) → Original (Z, group)
#############################################

"""
    decode_zg_reduced(Z_enc::AbstractMatrix{<:Integer},
                      G_enc::AbstractMatrix{<:Integer},
                      info::ZGEncodingInfo)

Decode reduced dummy encodings back to integer-coded categorical features and
group labels.

For each categorical component `l` with `k_l = info.k_z[l]` levels, the block of
`max(k_l - 1, 0)` columns starting at `info.z_start[l]` is inspected row by row:
- no entry equal to 1 → baseline category `k_l`;
- otherwise the position of the first 1 is the category in `{1, ..., k_l - 1}`.
A component with `k_l ≤ 1` owns no column and always decodes to category 1.

The group is decoded the same way from `G_enc`, with baseline `info.num_g`; when
`info.num_g ≤ 1` every row decodes to group 1.

# Arguments
- `Z_enc`: `N × Σ_l (k_l - 1)` matrix of reduced dummies for `Z`.
- `G_enc`: `N × (num_g - 1)` matrix of reduced dummies for `group`.
- `info`: `ZGEncodingInfo` describing the layout (as returned by
  `encode_zg_reduced`).

# Returns
- `Z`: `N × m` `Matrix{Int}` with entries in `{1, ..., k_l}`.
- `group`: length-`N` `Vector{Int}` with entries in `{1, ..., num_g}`.

# Throws
- `AssertionError` if `G_enc` does not have `N` rows and `max(num_g - 1, 0)`
  columns, or if `Z_enc` has too few columns for some component.
"""
function decode_zg_reduced(Z_enc::AbstractMatrix{<:Integer},
                           G_enc::AbstractMatrix{<:Integer},
                           info::ZGEncodingInfo)
    N, total_z = size(Z_enc)
    m = length(info.k_z)

    # Check group encoding consistency
    num_g = info.num_g
    d_g   = max(num_g - 1, 0)
    @assert size(G_enc, 1) == N "Z_enc and G_enc must have same number of rows"
    @assert size(G_enc, 2) == d_g "G_enc columns must be num_g - 1"

    # Recover categorical matrix Z (N × m)
    Z = zeros(Int, N, m)
    for l in 1:m
        k_l = info.k_z[l]
        d_l = max(k_l - 1, 0)
        s   = info.z_start[l]
        e   = s + d_l - 1

        if d_l == 0
            # Only one category exists for this component; always category 1.
            Z[:, l] .= 1
            continue
        end

        @assert e ≤ total_z "Z_enc has too few columns for component $l"

        # Views avoid allocating a fresh row copy for every sample.
        block = view(Z_enc, :, s:e)  # N × d_l
        for i in 1:N
            # Position of the first 1 in the row, or the baseline k_l if none.
            Z[i, l] = something(findfirst(==(1), view(block, i, :)), k_l)
        end
    end

    # Recover group vector (length N) from reduced dummies
    group = if d_g == 0
        # Only one group; always group 1
        fill(1, N)
    else
        Int[something(findfirst(==(1), view(G_enc, i, :)), num_g) for i in 1:N]
    end

    return Z, group
end


decode_zg_reduced

## Test: one-hot decoding

In [25]:
# Round-trip check: decode the encoded test data; the result should match the
# original `Z` and `group` (both globals are rebound here).
Z, group = decode_zg_reduced(Z_enc, G_onehot, info)
Z, group

([1 1; 2 3; 3 2; 1 3], [1, 2, 1, 3])

## DAG builder

In [None]:
###############################
# Graph structure for sample-level DAGs
# compatible with Theorem 2 (graph-based formulation)
###############################

# -----------------------------
# Arc kind (structure only)
# -----------------------------

# Supertype of the `kind` field stored in every `Arc`; the concrete subtypes
# (`CatArc`, `TermArc`) carry the per-arc structural data.
"""
    ArcKind

Abstract type for different kinds of arcs in the DAG.
We distinguish between:
  - `CatArc` : categorical feature transitions
  - `TermArc`: terminal transitions to (m+1, 0) with a chosen group g
"""
abstract type ArcKind end

"""
    CatArc

Categorical arc: the transition from state `(k-1, d_prev)` to state `(k, d)`
obtained by choosing category `c` for component `k`.

# Fields
- `k::Int`: categorical component the arc belongs to (`1..m`).
- `c::Int`: chosen category in `{1, ..., k_z[k]}`.
- `d_prev::Float64`: accumulated categorical distance at the source state
  `(k-1, d_prev)`.
- `d::Float64`: accumulated categorical distance at the destination state
  `(k, d)`.
"""
struct CatArc <: ArcKind
    k::Int
    c::Int
    d_prev::Float64
    d::Float64
end

"""
    TermArc

Terminal arc from (m, d) to (m+1, 0) with a chosen destination group g.

Fields:
- d : accumulated categorical distance at state (m, d)
- g : destination group in {1, ..., num_g}
"""
struct TermArc <: ArcKind
    d::Float64  # accumulated distance of the source state (m, d)
    g::Int      # group selected on this terminal arc
end

# -----------------------------
# Arc and per-sample DAG
# -----------------------------

"""
    Arc

Directed edge in the DAG for a single sample i.

Fields:
- src  : index of the source node in `nodes`
- dst  : index of the destination node in `nodes`
- kind : arc type (CatArc or TermArc), which encodes all information
         needed to build w^i(e; β, λ, r_i, ...) later
"""
struct Arc
    src::Int  # position of the tail node in `SampleDAG.nodes`
    dst::Int  # position of the head node in `SampleDAG.nodes`
    # NOTE(review): abstractly typed field, so the concrete arc kind is only
    # known at run time; code that branches on it should dispatch on `kind`.
    kind::ArcKind
end






"""
    SampleDAG

Graph structure associated with a single sample i, compatible with
Theorem 2 (graph-based formulation).

Fields:
- sample_index : index i of the sample this DAG corresponds to
- nodes        : vector of DP states (k, d), plus the terminal node (m+1, 0.0)
                 each node is a Tuple{Int,Float64}
- arcs         : list of directed edges with their arc kind (no numeric weight)
- source       : index of the source node (corresponds to state (0, 0.0))
- sink         : index of the terminal node (corresponds to state (m+1, 0.0))
"""
struct SampleDAG
    sample_index::Int                  # sample this graph was built for
    nodes::Vector{Tuple{Int,Float64}}  # states (k, d); `Arc.src`/`Arc.dst` index into this
    arcs::Vector{Arc}                  # structural edges only, no weights
    source::Int                        # index into `nodes` of state (0, 0.0)
    sink::Int                          # index into `nodes` of state (m+1, 0.0)
end




# -----------------------------
# Categorical encoding info
# -----------------------------
"""
    CatEncodingInfo

Encoding information for categorical features (structure level only).

Fields:
- k_z::Vector{Int} : number of categories per component (length m).
                     For component k, categories are {1, ..., k_z[k]}.
- num_g::Int       : total number of groups (used for terminal arcs).
"""
struct CatEncodingInfo
    k_z::Vector{Int}  # categories per component; its length defines m
    num_g::Int        # number of terminal arcs emitted per final-layer state
end


# -----------------------------
# Build DAG structure for one sample i
# -----------------------------
"""
    build_sample_dag_structure(i, info, delta, z_i) -> SampleDAG

Construct the structure of the layered DAG G^i = (V^i, A^i) for sample `i`,
compatible with Theorem 2 (graph-based formulation).

Only structure is produced:
  - nodes: the states (k, d) for k = 0, ..., m, plus the terminal node (m+1, 0.0);
  - arcs : a `CatArc` per (state, category) pair and a `TermArc` per
           (final-layer state, group) pair.

No numeric weights w^i(e) are computed here; they are meant to be built later
as expressions of (β, λ, r_i, y_i, g_i, B_{g_i}, C_{g_i}, ...).

Arguments:
- i      : sample index (stored in the result)
- info   : CatEncodingInfo (k_z, num_g)
- delta  : length-m vector of δ_k in
           d(z, zᶦ) = Σ_k δ_k * 1[z_k ≠ z_kᶦ]
- z_i    : length-m vector of original categories for sample i,
           each z_i[k] ∈ {1, ..., k_z[k]}

Output:
- SampleDAG whose source is node 1 (state (0, 0.0)) and whose sink is the
  last node (state (m+1, 0.0)).

Two states of the same layer are merged only when their accumulated
distances are equal as Float64 dictionary keys.
"""
function build_sample_dag_structure(
    i::Int,
    info::CatEncodingInfo,
    delta::AbstractVector{<:Real},
    z_i::AbstractVector{<:Integer}
)::SampleDAG
    levels   = info.k_z
    n_groups = info.num_g
    m        = length(levels)

    # Input lengths must agree with the number of components.
    @assert length(delta) == m "delta must have length m"
    @assert length(z_i) == m "z_i must have length m"

    # Node 1 is always the source state (0, 0.0).
    origin     = (0, 0.0)
    nodes      = Tuple{Int,Float64}[origin]
    node_index = Dict{Tuple{Int,Float64},Int}(origin => 1)
    source_idx = 1
    arcs       = Arc[]

    # `frontier` holds the distinct states of layer k-1 in discovery order.
    frontier = [origin]
    for k in 1:m
        n_levels = levels[k]
        step     = float(delta[k])
        observed = z_i[k]
        upcoming = Tuple{Int,Float64}[]

        for (k_prev, d_prev) in frontier
            @assert k_prev == k - 1
            # The tail node is the same for every category leaving this state.
            src = node_index[(k - 1, d_prev)]

            for c in 1:n_levels
                # Choosing a category different from the observed one costs δ_k.
                d = d_prev + (c != observed ? step : 0.0)
                state = (k, d)

                # Layer-k states are only ever registered inside this k
                # iteration, so "unknown to node_index" is the same as
                # "not yet in the next layer".
                if !haskey(node_index, state)
                    push!(nodes, state)
                    node_index[state] = length(nodes)
                    push!(upcoming, state)
                end

                push!(arcs, Arc(src, node_index[state], CatArc(k, c, d_prev, d)))
            end
        end

        frontier = upcoming
    end

    # Terminal node (m+1, 0.0) is appended last and becomes the sink.
    sink_state = (m + 1, 0.0)
    push!(nodes, sink_state)
    sink_idx = length(nodes)
    node_index[sink_state] = sink_idx

    # Every final-layer state (m, d) reaches the sink once per group.
    for (k_state, d) in frontier
        @assert k_state == m
        src = node_index[(m, d)]
        for g in 1:n_groups
            push!(arcs, Arc(src, sink_idx, TermArc(d, g)))
        end
    end

    return SampleDAG(i, nodes, arcs, source_idx, sink_idx)
end

build_sample_dag_structure

## test DAG builder

In [None]:
# Smoke test of the DAG builder on a tiny instance (m = 2 components).
# `i` is an arbitrary sample index; the builder only stores it in the result.
i = 2001

k_z   = [2, 2]      # each component has 2 categories
delta = [1.0, 1.0]  # δ₁ = δ₂ = 1
z_i   = [1, 2]      # sample i: z₁ = 1, z₂ = 2
num_g = 2

# NOTE(review): this rebinds the notebook-global `info`, which the encoding test
# cell earlier bound to a `ZGEncodingInfo`; re-running the decode test after
# this cell would pass the wrong type.
info = CatEncodingInfo(k_z, num_g)

dag_i = build_sample_dag_structure(i, info, delta, z_i)
dag_i

SampleDAG(2001, [(0, 0.0), (1, 0.0), (1, 1.0), (2, 1.0), (2, 0.0), (2, 2.0), (3, 0.0)], Arc[Arc(1, 2, CatArc(1, 1, 0.0, 0.0)), Arc(1, 3, CatArc(1, 2, 0.0, 1.0)), Arc(2, 4, CatArc(2, 1, 0.0, 1.0)), Arc(2, 5, CatArc(2, 2, 0.0, 0.0)), Arc(3, 6, CatArc(2, 1, 1.0, 2.0)), Arc(3, 4, CatArc(2, 2, 1.0, 1.0)), Arc(4, 7, TermArc(1.0, 1)), Arc(4, 7, TermArc(1.0, 2)), Arc(5, 7, TermArc(0.0, 1)), Arc(5, 7, TermArc(0.0, 2)), Arc(6, 7, TermArc(2.0, 1)), Arc(6, 7, TermArc(2.0, 2))], 1, 7)

## modeling

In [36]:
using JuMP

# NOTE(review): no definition follows this docstring in the cell, so it evaluates
# to a plain string (see the cell output) and documents nothing yet.
"""
    build_group_dro_graph_model(
        X, Z, group, y,
        encinfo,
        delta,
        B_group, C_group,
        gamma_x,
        ε,
        optimizer
    ) -> (model, meta)

Build our group-dependent graph-based DRO logistic regression model,
using the same reduced encoding convention as `ZGEncodingInfo` /
`encode_zg_reduced`.
"""






"    build_group_dro_graph_model(\n        X, Z, group, y,\n        encinfo,\n        delta,\n        B_group, C_group,\n        gamma_x,\n        ε,\n        optimizer\n    ) -> (model, meta)\n\nBuild our group-dependent graph-based DRO logistic regression model,\nusing the same reduced encoding convention as `ZGEncodingInfo` /\n`encode_zg_reduced`.\n"

In [30]:
# Scratch cell: evaluates an integer comparison (result shown in the cell output).
2 <= 3


true