In [1]:
using JuMP, Gurobi

[1m[36mINFO: [39m[22m[36mPrecompiling module JuMP.
[39m[1m[36mINFO: [39m[22m[36mPrecompiling module Gurobi.
[39m

### Create th OCT in the paper with a toy dataset

In [160]:
# generate XOR-like data
function f(x1::Float64, x2::Float64)
    if (x1 <= 0.5 && x2 <= 0.5) || (x1 > 0.5 && x2 > 0.5)
        return 0
    end
    return 1     
end

K = 2  # number of classes
p = 2  # number of features
n = 10  # number of data points
x = rand(n, p)
y = diag([f(x1, x2) for x1 = x[:, 1], x2 = x[:, 2]])
Y = reshape(vcat(-y*2+1, y*2-1), (n, p));

In [162]:
Y

10×2 Array{Int64,2}:
 -1   1
  1  -1
 -1   1
  1  -1
 -1   1
 -1   1
  1  -1
  1  -1
 -1   1
 -1   1

In [163]:
# find epsilon, needed in equation (13)
epsilon = Float64[]

for j in 1:p
    min_j = typemax(Int32)
    for i in 1:n-1
        diff = abs(x[i, j] - x[i+1, j])
        if diff < min_j
            min_j = diff
        end
    end
    append!(epsilon, min_j)
end

In [164]:
# find left and right ancestors
function get_ancestors(t)
    current_node = t
    Al = Int[]
    Ar = Int[]
    
    while current_node != 1
        parent = convert(Int, floor(current_node/2))
        if current_node % 2 == 0
            append!(Al, parent)
        else
            append!(Ar, parent)
        end
        current_node = parent
    end 
    
    return Al, Ar
end

get_ancestors (generic function with 1 method)

In [174]:
# set up model
m = Model(solver=GurobiSolver())

# parameters 
D = 2  # depth
N_min = 1  # min number of data points in each leaf
alpha = 0  # complexity penalty weight

T = 2^(D+1)-1  # number of nodes in the tree
Tb = convert(Int, floor(T/2))  # number of branch nodes
Tl = T - Tb  # number of leaf nodes

# variables
@variable(m, a[1:p, 1:Tb], Bin)  # vector a for each branch node
@variable(m, b[1:Tb])  # split value b
@variable(m, d[1:Tb], Bin)  # d indicates if a branch node applies split
@variable(m, z[1:n, Tb+1:T], Bin)  # z indicates if xi is in leaf node t
@variable(m, l[Tb+1:T], Bin)  # l indicates if leaf node t contains any points
@variable(m, c[1:K, Tb+1:T], Bin)  # ???
@variable(m, Nt[Tb+1:T])  # total number of points in leaf node t
@variable(m, Nkt[1:K, Tb+1:T])  # total number of points of label k in leaf node t
@variable(m, L[Tb+1:T])  # loss at each leaf node

# constraints
@constraint(m, [t = 1:Tb], sum(a[:, t]) == d[t])  # equation (2)
@constraint(m, [t = 1:Tb], b[t] <= d[t])  # equation (3)
@constraint(m, [t = 1:Tb], b[t] >= 0)  # equation (3)
@constraint(m, [t = 2:Tb], d[t] <= d[convert(Int, floor(t/2))])  # equation (5)
@constraint(m, [i = 1:n, t = Tb + 1:T], z[i, t] <= l[t])  # equation (6)
@constraint(m, [t = Tb+1:T], sum(z[:, t]) >= N_min*l[t])  # equation (7)
@constraint(m, [i = 1:n], sum(z[i, :]) == 1)  # equation (8)

# equation (13), (14)
for t in Tb+1:T
    Al, Ar = get_ancestors(t)    
    for s in Ar, i in 1:n
        @constraint(m, x[i, :]'*a[:, s]-b[s]-z[i, t]+1 >= 0)  # equation (14)
    end
    for s in Al, i in 1:n
        @constraint(m, (x[i, :]+epsilon)'*a[:, s] <= b[s]+(1+maximum(epsilon))*(1-z[i, t]))  # equation (13)
    end
end

@constraint(m, [k = 1:K, t = Tb+1:T], Nkt[k, t] == 0.5*sum((1+Y[:, k])'*z[:, t]))  # equation (15)
@constraint(m, [t = Tb+1:T], Nt[t] == sum(z[:, t]))  # equation (16)
@constraint(m, [t = Tb+1:T], sum(c[:, t]) == l[t])
@constraint(m, [t = Tb+1:T], L[t] >= 0)  # equation (22)
@constraint(m, [k = 1:K, t = Tb+1:T], L[t] >= Nt[t]-Nkt[k, t]-n*(1-c[k, t]))  #equation (20)
@constraint(m, [k = 1:K, t = Tb+1:T], L[t] <= Nt[t]-Nkt[k, t]+n*c[k, t])

# set up objective
@objective(m, Min, sum(L)+alpha*sum(d));

In [175]:
# print the problem
#print(m)

In [177]:
status = solve(m)
println("Objective value: ", getobjectivevalue(m))
println(getvalue(b))

Optimize a model with 181 rows, 80 columns and 678 nonzeros
Variable types: 19 continuous, 61 integer (61 binary)
Coefficient statistics:
  Matrix range     [4e-02, 1e+01]
  Objective range  [1e+00, 1e+00]
  Bounds range     [1e+00, 1e+00]
  RHS range        [1e+00, 1e+01]

Loaded MIP start with objective 0
MIP start did not produce a new incumbent solution

Presolve removed 9 rows and 2 columns
Presolve time: 0.00s
Presolved: 172 rows, 78 columns, 689 nonzeros
Variable types: 3 continuous, 75 integer (59 binary)

Explored 0 nodes (0 simplex iterations) in 0.00 seconds
Thread count was 4 (of 4 available processors)

Solution count 1: 0 

Optimal solution found (tolerance 1.00e-04)
Best objective 0.000000000000e+00, best bound 0.000000000000e+00, gap 0.0000%
Objective value: 0.0
[0.0, 0.0, 0.0]
