---
output: github_document
---
# Improve R's Performance using JuliaCall with a Mandelbrot Set Example

## What is the Mandelbrot Set?

The Mandelbrot set is the set of complex numbers $c$ for which the sequence
$\{0, f_{c}(0), f_{c}(f_{c}(0)), \ldots, f^{(n)}_{c}(0), \ldots\}$
remains bounded in absolute value, where $f_{c}(z)=z^{2}+c$.
For a more detailed introduction, see the
[Wikipedia page for Mandelbrot set](https://en.wikipedia.org/wiki/Mandelbrot_set).

Computing the Mandelbrot set is computationally heavy.
Here we use JuliaCall to speed up the computation by
rewriting the R function in Julia,
and we also look at some simple, useful tips for writing higher-performance
Julia code.
In this example, we will see that
JuliaCall can bring **more than 100 times speedup** of the calculation.

```{r, echo = FALSE}
knitr::opts_chunk$set(
collapse = TRUE,
comment = "#>"
)
```
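
As a small illustration of the definition above, the following sketch lists the first few terms of the sequence for two values of $c$. The helper name `orbit` and its argument `n` are assumptions made for this illustration only; they are not used elsewhere in this document.

```{r}
# Hypothetical helper: the first n terms of 0, f_c(0), f_c(f_c(0)), ...
orbit <- function(c, n = 6) {
  z <- complex(n)                 # n complex zeros; z[1] is the starting point 0
  for (k in seq_len(n - 1)) {
    z[k + 1] <- z[k] ^ 2 + c      # apply f_c(z) = z^2 + c
  }
  z
}

Mod(orbit(-1))  # alternates between 0 and 1, so the sequence stays bounded
Mod(orbit(1))   # 0, 1, 2, 5, 26, 677: the sequence grows without bound
```

So $c = -1$ belongs to the Mandelbrot set, while $c = 1$ does not.
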

## Pure R Implementation

The R implementation of the Mandelbrot set calculation is quite straightforward.
It can be proved that
if the absolute value of any term in the sequence
$\{0, f_{c}(0), f_{c}(f_{c}(0)), \ldots, f^{(n)}_{c}(0), \ldots\}$
is greater than 2, the sequence diverges.
The `mandelbrot` function returns the first iteration at which
the absolute value of the term exceeds 2,
or the maximum number of iterations if that never happens.
The `mandelbrotImage` function takes two sequences of $x$ and $y$ values and,
at every grid point, computes the value
of `mandelbrot` divided by the maximum number of iterations.

```{r}
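mandelbrot <- function(c, iterate_max = 500){
  z <- 0i
  for (i in 1:iterate_max) {
    z <- z ^ 2 + c
    # Once |z| > 2 the sequence is known to diverge
    if (abs(z) > 2.0) {
      return(i)
    }
  }
  # The sequence did not escape within iterate_max iterations
  iterate_max
}

mandelbrotImage <- function(xs, ys, iterate_max = 500){
  # Rows of the result index xs, columns index ys
  sapply(ys, function(y) {
    sapply(xs, function(x) mandelbrot(x + y * 1i, iterate_max = iterate_max))
  }) / iterate_max
}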
```
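
The escape criterion `abs(z) > 2` used in `mandelbrot` can be justified with a short argument. The following is only a sketch of the standard reasoning; the notation $z_{n}$ for the $n$-th term of the sequence (so that $z_{0} = 0$ and $z_{n+1} = f_{c}(z_{n})$) is introduced here just for this derivation.

Suppose some term satisfies $|z_{n}| > 2$. If $|c| \le 2$, then $|z_{n}| > |c|$. If $|c| > 2$, then $z_{1} = c$ already has $|z_{1}| \ge |c|$, and the estimate below shows that later terms only get larger. In both cases $|z_{n}| > 2$ and $|z_{n}| \ge |c|$, so by the triangle inequality

$$
|z_{n+1}| = |z_{n}^{2} + c| \ge |z_{n}|^{2} - |c| \ge |z_{n}|^{2} - |z_{n}| = |z_{n}| \, (|z_{n}| - 1).
$$

Because $|z_{n}| - 1 > 1$, each further step multiplies the absolute value by a factor that is larger than 1 and never shrinks, so $|z_{n}| \to \infty$.
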
## Julia Implementation in R using JuliaCall
To use Julia through JuliaCall,
we first need to do some setup:
```{r}
library(JuliaCall)
julia_setup()
```
We can then define Julia functions using `julia_command`.
Note that the syntax is quite similar to R
and easy to understand.
```{r}
julia_command("
function mandelbrot(c, iterate_max = 500)
    z = 0.0im
    for i in 1:iterate_max
        z = z ^ 2 + c
        # abs2(z) > 4 is the same test as abs(z) > 2, without the square root
        if abs2(z) > 4.0
            return(i)
        end
    end
    iterate_max
end")

julia_command("
function mandelbrotImage(xs, ys, iterate_max = 500)
    z = zeros(Float64, length(xs), length(ys))
    for i in 1:length(xs)
        for j in 1:length(ys)
            z[i, j] = mandelbrot(xs[i] + ys[j] * im, iterate_max) / iterate_max
        end
    end
    z
end")
```

## Performance Comparison Between R and Julia Implementations

We use the following settings to compare the performance of the
R and Julia implementations.

```{r}
iterate_max <- 1000L
centerx <- 0.37522 #0.3750001200618655
centery <- -0.22 #-0.2166393884377127
step <- 0.000002
size <- 125
xs <- seq(-step * size, step * size, step) + centerx
ys <- seq(-step * size, step * size, step) + centery
```
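
To get a feel for the size of this problem, the sketch below counts the grid points implied by the settings above. The helper `grid_axis` and the variables `axis_x`, `n_side`, `n_points` and `max_updates` are hypothetical names introduced only for this illustration; the helper repeats the `seq(...)` expression used for `xs` and `ys`.

```{r}
# Hypothetical helper: one axis of the grid, centered at `center`
grid_axis <- function(center, step, size) {
  seq(-step * size, step * size, step) + center
}

axis_x <- grid_axis(0.37522, 0.000002, 125)
n_side <- length(axis_x)      # 2 * size + 1 points per axis
n_points <- n_side ^ 2        # total number of grid points
# Upper bound on the number of z^2 + c updates with iterate_max = 1000
max_updates <- n_points * 1000
c(n_side = n_side, n_points = n_points, max_updates = max_updates)
```

With 251 points per axis the image has 63,001 grid points, so one full run performs at most about 63 million updates. Points that escape early need far fewer.
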
### Time for Pure R Implementation
```{r}
system.time(zR <- mandelbrotImage(xs, ys, iterate_max))
```
### Time for Julia Implementation using JuliaCall
```{r}
# Warm up first so that Julia's compilation time is not included in the timing
invisible(julia_call("mandelbrotImage", xs, ys, 2L))
system.time(zJL <- julia_call("mandelbrotImage", xs, ys, iterate_max))
```
We can see that
JuliaCall brings **a substantial speedup** of the calculation.
The speedup grows with the problem size
and can reach **100 times** or even more.
I won't show that result here because I don't want to wait
minutes for this R Markdown document to be knitted.
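
A single `system.time` call can be noisy, and the first call of a freshly defined Julia function also includes compilation. The sketch below shows one possible way to wrap the warm-up and a few repetitions into a helper. The name `time_after_warmup` and its `reps` argument are assumptions made for this illustration; the toy workload only stands in for a real call.

```{r}
# Hypothetical helper: run f once as a warm-up, then report the median
# elapsed time (in seconds) over `reps` timed runs
time_after_warmup <- function(f, reps = 3) {
  f()  # warm-up call, result discarded
  elapsed <- vapply(
    seq_len(reps),
    function(i) system.time(f())[["elapsed"]],
    numeric(1)
  )
  median(elapsed)
}

# Toy workload standing in for a call such as mandelbrotImage(xs, ys, iterate_max)
time_after_warmup(function() sum(sqrt(seq_len(1e5))))
```

In this document the same helper could wrap either implementation, for example `time_after_warmup(function() julia_call("mandelbrotImage", xs, ys, iterate_max))`.
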

## Tips for Julia Performance

### Write Small Functions

A general piece of advice for writing Julia (as well as R) is that
you should write small functions that each do one thing.
For example,
it is possible to combine the `mandelbrot` and `mandelbrotImage` functions
into one, but that is not good practice.
Function calls are also very cheap in Julia.

### Type Stability

If you want to write high-performance Julia code,
you should write type-stable functions.
Roughly, this means that each variable in a function should keep a single type.
For example, suppose we change the first line of the `mandelbrot` function like this:

```{r}
julia_command("
function mandelbrot1(c, iterate_max = 500)
    z = 0 ## instead of z = 0.0im in the original example
    for i in 1:iterate_max
        z = z ^ 2 + c
        if abs2(z) > 4.0
            return(i)
        end
    end
    iterate_max
end")

julia_command("
function mandelbrotImage1(xs, ys, iterate_max = 500)
    z = zeros(Float64, length(xs), length(ys))
    for i in 1:length(xs)
        for j in 1:length(ys)
            z[i, j] = mandelbrot1(xs[i] + ys[j] * im, iterate_max) / iterate_max
        end
    end
    z
end")
```
And we do the timing again:
```{r}
# Warm up first so that Julia's compilation time is not included in the timing
invisible(julia_call("mandelbrotImage1", xs, ys, 2L))
system.time(zJL <- julia_call("mandelbrotImage1", xs, ys, iterate_max))
```
We can see that the function becomes much slower,
because in the `mandelbrot1` function
`z` is an integer at the beginning
but becomes a complex number during the iteration.
We can use `@code_warntype`, `code_warntype`, and other tools
provided by Julia to diagnose this problem;
see <https://docs.julialang.org/en/stable/manual/performance-tips/>
for more information.
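
One common way to avoid this kind of instability is to initialize `z` from the argument itself with Julia's `zero(c)`, so that `z` has the same type as `c` from the start. The following is only a sketch under that assumption: the names `mandelbrot2` and `mandelbrot2_src` are hypothetical, no timing is claimed for this variant, and the Julia code is only sent to Julia when `julia_command` is available (that is, when JuliaCall has been attached and set up as earlier in this document).

```{r}
# Julia source for a hypothetical variant of `mandelbrot`
mandelbrot2_src <- "
function mandelbrot2(c, iterate_max = 500)
    z = zero(c)  # same type as c, so z does not change type inside the loop
    for i in 1:iterate_max
        z = z ^ 2 + c
        if abs2(z) > 4.0
            return i
        end
    end
    iterate_max
end"

# Define the function in Julia only if JuliaCall is attached
if (exists("julia_command")) {
  julia_command(mandelbrot2_src)
}
```

Compared with hard-coding `0.0im`, this choice also keeps the function usable for other numeric types of `c`.
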

## Conclusion

To end this article, let us have a look at our Mandelbrot set!

```{r}
image(xs, ys, zR, col = topo.colors(12))
image(xs, ys, zJL, col = topo.colors(12))
```