In [None]:
using JuMP
using Plots
using LaTeXStrings
pyplot()
include("../src/OptimalConstraintTree.jl")
const OCT = OptimalConstraintTree

In [None]:
# Load SAGE benchmark problem 3 (with lse = false) as ModelData, then box in x4.
# x4 presumably carries the objective value (the reference optimum printed at the
# end of this notebook has x4 ≈ -147.67), so restrict it to [-300, 0].
md = OCT.sagemark_to_ModelData(3; lse = false);
md.lbs[4], md.ubs[4] = -300.0, 0;

In [None]:
# Draw `n_samples` points from the model's variable domain; these are the
# training data for the constraint trees fitted below.
n_samples = 1000;
X = OCT.sample(md; n_samples = n_samples);

In [None]:
# Fit one tree per inequality and per equality constraint of `md` on the samples
# `X`, using the learner returned by `OCT.base_lnr(false)`.
ineq_trees, eq_trees = OCT.fit(md, X; lnr = OCT.base_lnr(false));

In [None]:
# Visualize the learners fitted for the first two inequality constraints.
for k in 1:2
    IAI.show_in_browser(ineq_trees[k].lnr)
end

In [None]:
# Evaluate the second inequality constraint at every sample, then split the sample
# indices into feasible (value ≥ 0) and infeasible (value < 0) sets for plotting.
# The constraint being evaluated is
# 1 - 0.01\frac{x_2}{x_3} - 0.01 x_2 - 0.005x_1x_3 \geq 0
offset = map(j -> md.ineq_fns[2](X[j, :]), 1:n_samples);
feas_idxs = findall(v -> v >= 0, offset);
infeas_idxs = findall(v -> v < 0, offset);

In [None]:
# Load the constraint helpers, then compute the upper/lower dictionaries returned
# by `trust_region_data` for the second inequality tree over features x1..x4.
include("../src/constraintify.jl");
upperDict, lowerDict = trust_region_data(ineq_trees[2].lnr, Symbol.("x", 1:4));

In [None]:
# For every split node of the second inequality tree, build a function that returns
# the x3-coordinate of that node's split hyperplane α⋅x = threshold at a point
# (x1, x2). This lets each hyperplane be drawn as a surface over the sample plane.
lnr = ineq_trees[2].lnr;
n_nodes = IAI.get_num_nodes(lnr)
all_leaves = [i for i = 1:n_nodes if IAI.is_leaf(lnr, i)]
# Every non-leaf node is a split node. Query is_leaf directly instead of scanning
# `all_leaves` for each node.
splits = [i for i = 1:n_nodes if !IAI.is_leaf(lnr, i)]
vks = [Symbol("x", i) for i = 1:4];
fns = Function[];
for i in splits
    threshold = IAI.get_split_threshold(lnr, i)
    # Element [1] holds the numeric-feature weights (IAI returns the numeric and
    # categoric weights as a pair). Features absent from the split get weight 0.
    weights = IAI.get_split_weights(lnr, i)[1]
    α = Float64[get(weights, v, 0.0) for v in vks]
    if α[3] == 0
        # The plane is parallel to the x3 axis, so x3 = f(x1, x2) is undefined.
        @warn "Split node $i has zero weight on x3; its surface will be non-finite."
    end
    # Solve α⋅x = threshold for x3. The α[4]*x4 term is dropped, so the surface is
    # exact only for splits that do not use x4. TODO(review): confirm for this tree.
    # `let` binds this iteration's threshold and weights into the closure explicitly.
    push!(fns, let t = threshold, a = α
        x -> (t - a[1] * x[1] - a[2] * x[2]) / a[3]
    end)
end

In [None]:
# 3-D view of the samples (red = infeasible, green = feasible), overlaid with the
# x3-surface of the first split hyperplane, evaluated at every sample's (x1, x2).
p_infeas = scatter(X[infeas_idxs, 1], X[infeas_idxs, 2], X[infeas_idxs, 3],
                   color = :red)
p_feas = scatter!(X[feas_idxs, 1], X[feas_idxs, 2], X[feas_idxs, 3],
                  color = :green)
s_1 = plot!(X[:, 1], X[:, 2], map(j -> fns[1](X[j, 1:2]), 1:n_samples),
            st = :surface)
plot(s_1;
     xlabel = L"$x_1$", ylabel = L"$x_2$", zlabel = L"$x_3$",
     camera = (70, -10), legend = :none, colorbar = false)

In [None]:
# Solve the surrogate problem. Start from the JuMP model built for `md`, add its
# linear constraints, then add the constraints encoded by the fitted trees.
m, x = OCT.jump_it(md);
OCT.add_linear_constraints!(m, x, md);
OCT.add_tree_constraints!(m, x, ineq_trees, eq_trees);
status = solve(m);
# `solve` (pre-0.19 JuMP API) returns a status Symbol. Values read after a
# non-optimal solve are meaningless, so surface the status instead of ignoring it.
status == :Optimal || @warn "Solver did not report an optimal solution" status
# Reference optimum for this benchmark. x4 equals the known optimal objective
# value. x1..x3 appear to be given in log-coordinates, hence the exp.
known_bound = -147 - 2/3
println("Solver status: ", status)
println("Solved minimum: ", sum(md.c .* getvalue(x)))
println("Known global bound: ", known_bound)
println("X values: ", getvalue(x))
println("Optimal X: ", vcat(exp.([5.01063529, 3.40119660, -0.48450710]), [known_bound]))