forked from jump-dev/Gurobi.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
grb_callbacks.jl
156 lines (137 loc) · 4.68 KB
/
grb_callbacks.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# Gurobi callbacks
"""
    CallbackData

Value handed to a user callback registered with `set_callback_func!`.

- `cbdata`: the opaque callback-data pointer Gurobi passed to the C-level callback.
  The helpers in this file forward it as the first argument of their
  `@grb_ccall` calls (`cbget`, `cbcut`, `cblazy`, `cbsolution`).
  NOTE(review): presumably only valid during the callback invocation that
  produced it — do not keep it around after the callback returns.
- `model`: the `Model` the callback was registered on; used to look up the
  environment for error reporting and the number of variables.
"""
mutable struct CallbackData
cbdata::Ptr{Cvoid}
model::Model
end
# Heap-allocated holder for a registered user callback and its model.
# It is deliberately a *mutable* struct: mutable objects have a stable identity,
# so the address handed to Gurobi as `usrdata` and the reference stored in
# `model.callback` (which keeps it alive) are guaranteed to be the same object.
# An immutable tuple has no identity and may be boxed separately for the `ccall`
# and for the field assignment, leaving Gurobi with a pointer to an unrooted box.
# The model type is a parameter so this definition does not depend on load order.
mutable struct _CallbackUserData{M}
    callback::Function
    model::M
end

"""
    gurobi_callback_wrapper(ptr_model, cbdata, where, userdata) -> Cint

C-callable trampoline registered through `set_callback_func!`. Recovers the
`_CallbackUserData` whose address was registered as `userdata` and calls
`callback(CallbackData(cbdata, model), where)`. Always returns `0`.
`ptr_model` is unused; it is present only to match the C callback signature.
"""
function gurobi_callback_wrapper(ptr_model::Ptr{Cvoid}, cbdata::Ptr{Cvoid},
                                 where::Cint, userdata::Ptr{Cvoid})
    data = unsafe_pointer_to_objref(userdata)::_CallbackUserData{Model}
    # NOTE(review): an exception thrown by the user callback unwinds through
    # Gurobi's C frames; consider catching it and terminating the solve instead.
    data.callback(CallbackData(cbdata, data.model), where)
    return Cint(0)
end

# User callback function should be of the form:
#     callback(cbdata::CallbackData, where::Cint)
"""
    set_callback_func!(model::Model, callback::Function)

Register `callback` as the Gurobi callback for `model`, replacing the reference
previously stored in `model.callback`. Gurobi will invoke it as
`callback(cbdata::CallbackData, where::Cint)`.

Throws `GurobiError` if Gurobi rejects the registration. Returns `nothing`.
"""
function set_callback_func!(model::Model, callback::Function)
    grbcallback = @cfunction(gurobi_callback_wrapper, Cint,
                             (Ptr{Cvoid}, Ptr{Cvoid}, Cint, Ptr{Cvoid}))
    usrdata = _CallbackUserData{Model}(callback, model)
    ret = @grb_ccall(setcallbackfunc, Cint, (Ptr{Cvoid}, Ptr{Cvoid}, Any),
                     model.ptr_model, grbcallback, usrdata)
    if ret != 0
        throw(GurobiError(model.env, ret))
    end
    # Gurobi only keeps a raw pointer to `usrdata`; store the very same object on
    # the model so it is not garbage collected while it is registered.
    model.callback = usrdata
    return nothing
end

export CallbackData, set_callback_func!
# Generate `cbcut` and `cblazy`. Each forwards a single linear constraint
#     sum(val[k] * x[ind[k]])  sense  rhs
# to the Gurobi routine of the same name from inside a callback.
# - `ind`: 1-based variable indices. They are shifted to 0-based before the call.
# - `val`: coefficients; must have the same length as `ind`.
# - `sense`: passed through as a `Cchar`. Presumably one of Gurobi's sense
#   characters ('<', '>', '='); TODO confirm.
# Throws `DimensionMismatch` when `ind` and `val` differ in length, and
# `GurobiError` when Gurobi returns a nonzero status. Returns `nothing`.
for f in (:cbcut, :cblazy)
    @eval function ($f)(cbdata::CallbackData, ind::Vector{Cint}, val::Vector{Float64},
                        sense::Char, rhs::Float64)
        len = length(ind)
        # Gurobi reads exactly `len` coefficients from `val`, so a shorter `val`
        # would be an out-of-bounds read. Validate explicitly rather than with
        # `@assert`, which may be disabled.
        if length(val) != len
            throw(DimensionMismatch(
                "ind has length $len but val has length $(length(val))"))
        end
        ret = @grb_ccall($f, Cint,
                         (Ptr{Cvoid}, Cint, Ptr{Cint}, Ptr{Float64}, Cchar, Float64),
                         cbdata.cbdata, len, ind .- Cint(1), val, sense, rhs)
        if ret != 0
            throw(GurobiError(cbdata.model.env, ret))
        end
        return nothing
    end
end
export cbcut, cblazy
"""
    cbsolution(cbdata::CallbackData, sol::Vector{Float64}) -> Float64

Hand the candidate solution `sol` to Gurobi from inside a callback.
Returns the value Gurobi writes to its objective output argument
(NOTE(review): presumably the objective of the submitted solution; confirm the
value reported when it is rejected in the Gurobi docs).

`sol` must have at least `num_vars(cbdata.model)` entries; otherwise a
`DimensionMismatch` is thrown. A nonzero Gurobi status raises `GurobiError`.
"""
function cbsolution(cbdata::CallbackData, sol::Vector{Float64})
    nvar = num_vars(cbdata.model)
    # Gurobi reads one value per model variable from `sol`. Check explicitly
    # (not via `@assert`, which may be disabled) to rule out an out-of-bounds read.
    if length(sol) < nvar
        throw(DimensionMismatch(
            "sol has length $(length(sol)) but the model has $nvar variables"))
    end
    objP = Ref{Float64}()
    ret = @grb_ccall(
        cbsolution,
        Cint,
        (Ptr{Cvoid}, Ptr{Float64}, Ref{Float64}),
        cbdata.cbdata, sol, objP
    )
    if ret != 0
        throw(GurobiError(cbdata.model.env, ret))
    end
    return objP[]
end
"""
    cbget(T, cbdata::CallbackData, where::Cint, what::Integer) -> T

Query a single scalar of type `T` through Gurobi's `cbget` routine.
`where` is the callback location passed in by Gurobi, and `what` is the query
code (converted to `Cint`). `T` must match the C type Gurobi writes for `what`.
A nonzero Gurobi status raises `GurobiError`.
"""
function cbget(::Type{T}, cbdata::CallbackData, where::Cint, what::Integer) where T
    result = Ref{T}()
    status = @grb_ccall(cbget, Cint, (Ptr{Cvoid}, Cint, Cint, Ptr{T}),
                        cbdata.cbdata, where, convert(Cint, what), result)
    iszero(status) || throw(GurobiError(cbdata.model.env, status))
    return result[]
end
# Callback constants
# Callback location codes: presumably the values of the `where` argument Gurobi
# passes to the callback (see `gurobi_callback_wrapper`), mirroring the GRB_CB_*
# location macros in gurobi_c.h. They must match the C header exactly.
# NOTE(review): the command below also matches the GRB_CB_* query codes (those
# are kept in `cbconstants` instead); only the location codes were retained here.
# grep GRB_CB gurobi_c.h | awk '{ print "const " substr($2,5) " = " $3; }'
const CB_POLLING = 0
const CB_PRESOLVE = 1
const CB_SIMPLEX = 2
const CB_MIP = 3
const CB_MIPSOL = 4
const CB_MIPNODE = 5
const CB_MESSAGE = 6
const CB_BARRIER = 7
export CB_POLLING, CB_PRESOLVE, CB_SIMPLEX, CB_MIP,
CB_MIPSOL, CB_MIPNODE, CB_MESSAGE, CB_BARRIER
# grep GRB_CB gurobi_c.h | awk '{ print "(\"" tolower(substr($2,8)) "\"," $3 ")"; }'
# Scalar quantities queryable with `cbget`, one tuple per quantity:
#     (suffix of the generated `cbget_<suffix>` function, GRB_CB_* query code, result type)
# The result type (third column) is not produced by the generating command and
# was added by hand. `cbget` reads Gurobi's output buffer as exactly this type,
# so it must match the C type of the code: int -> Cint, double -> Float64.
const cbconstants = [
    ("pre_coldel", 1000, Cint),
    ("pre_rowdel", 1001, Cint),
    ("pre_senchg", 1002, Cint),
    ("pre_bndchg", 1003, Cint),
    ("pre_coechg", 1004, Cint),
    ("spx_itrcnt", 2000, Float64),
    ("spx_objval", 2001, Float64),
    ("spx_priminf", 2002, Float64),
    ("spx_dualinf", 2003, Float64),
    # GRB_CB_SPX_ISPERT is documented by Gurobi as an int. Reading it into a
    # Float64 buffer misinterprets the bytes written, so it must be Cint.
    ("spx_ispert", 2004, Cint),
    ("mip_objbst", 3000, Float64),
    ("mip_objbnd", 3001, Float64),
    ("mip_nodcnt", 3002, Float64),
    ("mip_solcnt", 3003, Cint),
    ("mip_cutcnt", 3004, Cint),
    ("mip_nodlft", 3005, Float64),
    ("mip_itrcnt", 3006, Float64),
    ### ("mipsol_sol", 4001) -- vector-valued; handled by cbget_mipsol_sol
    ("mipsol_obj", 4002, Float64),
    ("mipsol_objbst", 4003, Float64),
    ("mipsol_objbnd", 4004, Float64),
    ("mipsol_nodcnt", 4005, Float64),
    ("mipsol_solcnt", 4006, Cint),
    ("mipnode_status", 5001, Cint),
    ### ("mipnode_rel", 5002) -- vector-valued; handled by cbget_mipnode_rel
    ("mipnode_objbst", 5003, Float64),
    ("mipnode_objbnd", 5004, Float64),
    ("mipnode_nodcnt", 5005, Float64),
    ("mipnode_solcnt", 5006, Cint),
    ## ("mipnode_brvar", 5007) -- undocumented
    ## ("msg_string", 6001) -- not yet implemented:
    ### documentation is unclear on output type
    ("runtime", 6002, Float64),
    ("barrier_itrcnt", 7001, Cint),
    ("barrier_primobj", 7002, Float64),
    ("barrier_dualobj", 7003, Float64),
    ("barrier_priminf", 7004, Float64),
    ("barrier_dualinf", 7005, Float64),
    ("barrier_compl", 7006, Float64),
]

# Generate and export one accessor per table entry:
#     cbget_<suffix>(cbdata::CallbackData, where::Cint) -> result type
for (cname, what, T) in cbconstants
    fname = Symbol("cbget_", cname)
    @eval ($fname)(cbdata::CallbackData, where::Cint) = cbget($T, cbdata, where, $what)
    eval(Expr(:export, fname))
end
# Vector-valued callback queries. Each reads one Float64 per model variable:
#   cbget_mipsol_sol  -> query code 4001
#   cbget_mipnode_rel -> query code 5002
# Two methods are generated per name:
#   (cbdata, where, out) fills `out` in place and returns `nothing`;
#   (cbdata, where)      allocates and returns a new Vector{Float64}.
for (fname, what) in ((:cbget_mipsol_sol, 4001), (:cbget_mipnode_rel, 5002))
    @eval function ($fname)(cbdata::CallbackData, where::Cint, out::Vector{Float64})
        nvar = num_vars(cbdata.model)
        # Gurobi writes `nvar` doubles into `out`. Check explicitly (not via
        # `@assert`, which may be disabled) to rule out a buffer overrun.
        if length(out) < nvar
            throw(DimensionMismatch(
                "out has length $(length(out)) but the model has $nvar variables"))
        end
        ret = @grb_ccall(cbget, Cint, (Ptr{Cvoid}, Cint, Cint, Ptr{Float64}),
                         cbdata.cbdata, where, $what, out)
        if ret != 0
            throw(GurobiError(cbdata.model.env, ret))
        end
        return nothing
    end
    @eval function ($fname)(cbdata::CallbackData, where::Cint)
        out = Vector{Float64}(undef, num_vars(cbdata.model))
        ($fname)(cbdata, where, out)
        return out
    end
    eval(Expr(:export, fname))
end