-
Notifications
You must be signed in to change notification settings - Fork 103
/
pyjulia_helper.jl
162 lines (137 loc) · 4.68 KB
/
pyjulia_helper.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
module _PyJuliaHelper
import REPL
using PyCall
using PyCall: pyeval_, Py_eval_input, Py_file_input
using PyCall.MacroTools: isexpr, walk
"""
    fullnamestr(m)

Return the fully qualified name of module `m` as a single dot-separated string.

# Examples
```jldoctest
julia> fullnamestr(Base.Enums)
"Base.Enums"
```
"""
function fullnamestr(m)
    # `fullname` yields a tuple of Symbols, outermost module first.
    segments = map(string, fullname(m))
    return join(segments, '.')
end
"""
    isdefinedstr(parent, member)

Return whether `member` is defined in `parent`, as reported by `isdefined`.
`member` may be anything `Symbol` accepts (e.g. a `String` coming from Python).
"""
function isdefinedstr(parent, member)
    name = Symbol(member)
    return isdefined(parent, name)
end
"""
    completions(str, pos)

Ask `REPL.completions` for tab completions of `str` at cursor position `pos`.

Returns a 3-tuple `(texts, (lo, hi), should_complete)`:
- `texts`: the completion strings (`REPL.completion_text` of each candidate);
- `(lo, hi)`: the first and last element of the range returned by
  `REPL.completions` (presumably the span of `str` being completed);
- `should_complete`: the third value returned by `REPL.completions`, unchanged.
"""
function completions(str, pos)
    candidates, span, should_complete = REPL.completions(str, pos)
    texts = [REPL.completion_text(c) for c in candidates]
    bounds = (first(span), last(span))
    return (texts, bounds, should_complete)
end
# Takes an expression like `$foo + 1` and turns it into a pyfunction
# `(globals,locals) -> convert(PyAny, pyeval_("foo",globals,locals,PyAny)) + 1`
# so that Python code can call it and just pass the appropriate globals/locals
# dicts to perform the interpolation.
#
# Rewrites applied to `ex`, only at quote depth 1 (i.e. not inside a nested
# `quote`):
#   * `$name` with a bare Symbol -> Python eval of "name", converted to PyAny;
#   * `$(expr)` with anything else -> error suggesting py"..." instead;
#   * `py"code"` -> Python eval of "code"; the result is converted to PyObject
#     instead of PyAny when exactly one suffix string follows the code and it
#     contains 'o' (e.g. py"code"o).
# The rewrite happens at call time, inside the returned pyfunction, using the
# globals/locals passed by the caller; the rewritten expression is then
# evaluated in `Main`.
macro prepare_for_pyjulia_call(ex)
# f(x, quote_depth) should return a transformed expression x and whether to
# recurse into the new expression. quote_depth keeps track of how deep
# inside of nested quote objects we are (1 = not inside any quote).
function stoppable_walk(f, x, quote_depth=1)
(fx, recurse) = f(x, quote_depth)
# Children of a `quote` are one level deeper; children of a `$` are one level
# shallower, since interpolation escapes the enclosing quote.
if isexpr(fx,:quote)
quote_depth += 1
end
if isexpr(fx,:$)
quote_depth -= 1
end
# MacroTools.walk maps `inner` over fx's children; when `recurse` is false the
# children are kept as-is (identity).
walk(fx, (recurse ? (x -> stoppable_walk(f,x,quote_depth)) : identity), identity)
end
# Build the expression `convert(T, pyeval_(code, globals, locals, input_type))`.
# T is PyObject when the single option string contains 'o', otherwise PyAny.
# Code containing a newline is run with Py_file_input, otherwise Py_eval_input.
function make_pyeval(globals, locals, expr::Union{String,Symbol}, options...)
code = string(expr)
T = length(options) == 1 && 'o' in options[1] ? PyObject : PyAny
input_type = '\n' in code ? Py_file_input : Py_eval_input
:($convert($T, $pyeval_($code, $globals, $locals, $input_type)))
end
# Replace the Python interpolations in `ex` (rules listed above the macro) with
# pyeval_ calls bound to `globals`/`locals`. Each callback returns
# (replacement, recurse); `false` stops the walk from descending into the
# replacement or into unrelated macro calls.
function insert_pyevals(globals, locals, ex)
stoppable_walk(ex) do x, quote_depth
if quote_depth==1 && isexpr(x, :$)
if x.args[1] isa Symbol
make_pyeval(globals, locals, x.args[1]), false
else
error("""syntax error in: \$($(string(x.args[1])))
Use py"..." instead of \$(...) for interpolating Python expressions.""")
end
elseif quote_depth==1 && isexpr(x, :macrocall)
if x.args[1]==Symbol("@py_str")
# in Julia 0.7+, x.args[2] is a LineNumberNode, so filter it out
# in a way that's compatible with Julia 0.6:
make_pyeval(globals, locals, filter(s->(s isa String), x.args[2:end])...), false
else
x, false
end
else
x, true
end
end
end
# Expand to `pyfunction(f, PyObject, PyObject)`, where f receives Python's
# globals/locals. `pyfunction`, `insert_pyevals` and `PyObject` are spliced in
# as values (`$`), so they need not be in scope at the escaped call site.
esc(quote
$pyfunction(
(globals, locals)->Base.eval(Main, $insert_pyevals(globals, locals, $(QuoteNode(ex)))),
$PyObject, $PyObject
)
end)
end
module IOPiper

# Standard streams as they were when this module was loaded (`__init__` runs at
# load time). Presumably kept so callers can still reach the real streams after
# `pipe_std_outputs` has redirected stdout/stderr.
const orig_stdin = Ref{IO}()
const orig_stdout = Ref{IO}()
const orig_stderr = Ref{IO}()

function __init__()
    orig_stdin[] = stdin
    orig_stdout[] = stdout
    orig_stderr[] = stderr
    return nothing
end

"""
    num_utf8_trailing(d::AbstractVector{UInt8})

If `d` ends with an incomplete UTF-8-encoded character, return the number of trailing
bytes belonging to it. Otherwise, return `0`.

Trailing bytes that can never become a valid character by appending more data are
not counted. Examples are a run of continuation bytes longer than its lead byte
announces, or a buffer made only of continuation bytes. This way, callers that hold
back the counted bytes never hold back bytes that can never complete.

Adapted from IJulia.jl.
"""
function num_utf8_trailing(d::AbstractVector{UInt8})
    lo, hi = firstindex(d), lastindex(d)
    i = hi
    # Find the last non-continuation byte (continuation bytes are 0b10xxxxxx).
    while i >= lo && (d[i] & 0xc0) == 0x80
        i -= 1
    end
    i < lo && return 0
    c = d[i]
    # Number of bytes the lead byte at `i` announces.
    n = c <= 0x7f ? 1 : c < 0xe0 ? 2 : c < 0xf0 ? 3 : 4
    nend = hi + 1 - i  # bytes from `i` to the end
    # Only a sequence that is still *short* can be completed by more input.
    return nend < n ? nend : 0
end

"""
    pipe_stream(sender::IO, receiver, buf::IO = IOBuffer())

Read `sender` until EOF, calling `receiver(s::String)` with each non-empty chunk.

Chunks are split so a multi-byte UTF-8 character is never cut in half. An incomplete
trailing character is kept in `buf` and prepended to the next read. Whatever is still
in `buf` when `sender` reaches EOF is delivered as a final chunk.

An `InterruptException` raised while piping is swallowed and piping resumes. Any
other exception is rethrown.
"""
function pipe_stream(sender::IO, receiver, buf::IO = IOBuffer())
    while true
        try
            while !eof(sender)
                write(buf, read(sender, bytesavailable(sender)))
                # Taken from IJulia.send_stream:
                d = take!(buf)
                n = num_utf8_trailing(d)
                # Keep the incomplete tail for the next round (before resizing `d`).
                write(buf, d[end-n+1:end])
                resize!(d, length(d) - n)
                # `String(d)` takes ownership of `d`; it is not used afterwards.
                isempty(d) || receiver(String(d))
            end
            break
        catch e
            # Keep piping after an interrupt. Looping instead of recursing means
            # repeated interrupts cannot grow the stack.
            e isa InterruptException || rethrow()
        end
    end
    # Bytes held back at EOF would otherwise be lost; deliver them as they are.
    rest = take!(buf)
    isempty(rest) || receiver(String(rest))
    return nothing
end

# Read ends of the pipes installed by `pipe_std_outputs`.
const read_stdout = Ref{Base.PipeEndpoint}()
const read_stderr = Ref{Base.PipeEndpoint}()

"""
    pipe_std_outputs(out_receiver, err_receiver)

Redirect `stdout` and `stderr` into pipes and start one `@async` task per stream.
Each task forwards everything written to its stream, through `pipe_stream`, to
`out_receiver` or `err_receiver` respectively (each called with a `String`).

The tasks are stored in the module globals `readout_task` and `readerr_task`. The
stderr-reading task is returned.
"""
function pipe_std_outputs(out_receiver, err_receiver)
    global readout_task
    global readerr_task
    # NOTE(review): relies on `redirect_stdout()` / `redirect_stderr()` results
    # destructuring with the read end first; confirm for every supported Julia version.
    read_stdout[], = redirect_stdout()
    readout_task = @async pipe_stream(read_stdout[], out_receiver)
    read_stderr[], = redirect_stderr()
    readerr_task = @async pipe_stream(read_stderr[], err_receiver)
    return readerr_task
end

end # module
end # module