In [1]:
import math

import numba
from numba import vectorize, boolean, int32, complex128, complex64

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

In [None]:
@vectorize([int32(complex128, int32)])
def mandelbrot_test(c: complex, max_iters: int = 1000) -> int:
    """Return the escape iteration of ``c`` under the map ``z -> z**2 + c``.

    ``@vectorize`` compiles this into a NumPy ufunc with the single loop
    ``(complex128, int32) -> int32``. Inputs of other dtypes (e.g. a
    complex64 grid) are cast by NumPy's ufunc casting rules before the loop
    runs.

    Parameters
    ----------
    c : complex
        Point of the complex plane to test.
    max_iters : int
        Iteration budget. NOTE: the compiled ufunc takes both inputs, so the
        Python default of 1000 is not used when calling ``mandelbrot_test`` —
        always pass ``max_iters`` explicitly.

    Returns
    -------
    int
        The index ``n`` of the first iteration after which ``|z| > 2``, or
        ``max_iters`` if the orbit stayed bounded for the whole budget
        (``c`` is then treated as a member of the set).
    """
    # Start complex so numba never has to unify an int with a complex.
    z = 0j
    for n in range(max_iters):
        z = z * z + c
        # |z| > 2  <=>  |z|^2 > 4. Comparing the squared modulus avoids the
        # hypot/sqrt that abs() performs on every iteration.
        if z.real * z.real + z.imag * z.imag > 4.0:
            return n
    return max_iters

def get_complex_grid(
    real_min: float,
    real_max: float,
    n_real: int,
    imag_min: float,
    imag_max: float,
    n_imag: int,
    dtype=np.complex64,
) -> np.ndarray:
    """Build a rectangular grid of evenly spaced points in the complex plane.

    Parameters
    ----------
    real_min, real_max : float
        Inclusive bounds of the real axis.
    n_real : int
        Number of samples along the real axis (columns).
    imag_min, imag_max : float
        Inclusive bounds of the imaginary axis.
    n_imag : int
        Number of samples along the imaginary axis (rows).
    dtype : complex NumPy dtype, optional
        Dtype of the returned grid. Defaults to ``np.complex64``, which was
        the previous hard-coded choice. Pass ``np.complex128`` for deep
        zooms where float32 resolution is too coarse.

    Returns
    -------
    np.ndarray
        Array of shape ``(n_imag, n_real)`` with
        ``grid[i, j] == real[j] + 1j * imag[i]``. Here ``real`` and ``imag``
        are the ``np.linspace`` samples of each axis, so row index follows
        the imaginary axis in increasing order.

    Raises
    ------
    ValueError
        If ``dtype`` is not a complex floating-point dtype.
    """
    if not np.issubdtype(dtype, np.complexfloating):
        raise ValueError(f"dtype must be a complex floating dtype, got {dtype!r}")

    real = np.linspace(real_min, real_max, n_real)
    imag = np.linspace(imag_min, imag_max, n_imag)

    # Write the real and imaginary planes of a preallocated array directly.
    # The alternative, ``(real + imag * 1j).astype(dtype)``, first
    # materialises a full complex128 temporary (1.6 GB at 10_000 x 10_000)
    # and then copies it. Each component is rounded to ``dtype`` the same
    # way in both approaches.
    grid = np.empty((n_imag, n_real), dtype=dtype)
    grid.real = real[np.newaxis, :]
    grid.imag = imag[:, np.newaxis]
    return grid
    
def plot_mandelbrot(set_data, xmin, xmax, ymin, ymax, cmap='hot'):
    """Draw a 2-D array of escape counts as a heat map over the complex plane.

    Parameters
    ----------
    set_data : 2-D array
        Values to plot. Row ``i`` is expected to correspond to the ``i``-th
        imaginary value in increasing order and column ``j`` to the ``j``-th
        real value, which is the layout ``get_complex_grid`` produces.
    xmin, xmax : float
        Extent of the real (horizontal) axis.
    ymin, ymax : float
        Extent of the imaginary (vertical) axis.
    cmap : str, optional
        Matplotlib colormap name (default ``'hot'``).

    Returns
    -------
    matplotlib.figure.Figure
        The new 10x10-inch figure containing the image and a colorbar.
    """
    fig, ax = plt.subplots(figsize=(10, 10))
    # origin='lower' draws row 0 (the smallest imaginary value) at the bottom,
    # consistent with extent=[xmin, xmax, ymin, ymax]. The default
    # origin='upper' would put row 0 at the top, labelled ymax, mirroring the
    # image vertically. That is invisible for a view symmetric about the real
    # axis but wrong for any asymmetric zoom.
    img = ax.imshow(
        set_data,
        cmap=cmap,
        origin='lower',
        extent=[xmin, xmax, ymin, ymax],
    )
    fig.colorbar(img, ax=ax)

    ax.set_title('Mandelbrot Set')
    ax.set_xlabel('Re')
    ax.set_ylabel('Im')

    return fig

def get_mandelbrot_fig(real_min, real_max, n_real, imag_min, imag_max, n_imag, max_iters):
    """Compute escape counts over a rectangle of the complex plane and plot them.

    An ``n_imag x n_real`` grid spanning ``[real_min, real_max]`` by
    ``[imag_min, imag_max]`` is built with ``get_complex_grid``. Every point
    is run through ``mandelbrot_test`` with the given ``max_iters``, and the
    resulting array is handed to ``plot_mandelbrot`` with the same bounds as
    the plot extent.

    Returns
    -------
    matplotlib.figure.Figure
        Whatever ``plot_mandelbrot`` returns for the computed data.
    """
    escape_counts = mandelbrot_test(
        get_complex_grid(real_min, real_max, n_real, imag_min, imag_max, n_imag),
        max_iters,
    )
    return plot_mandelbrot(
        escape_counts,
        xmin=real_min,
        xmax=real_max,
        ymin=imag_min,
        ymax=imag_max,
    )

In [None]:
"""
Create some figures.
"""

real_min = -2
real_max = 1
n_real = 10_000
imag_min = -1.5
imag_max = 1.5
n_imag = 10_000
max_iters = int32(10_000)

fig = get_mandelbrot_fig(real_min, real_max, n_real, imag_min, imag_max, n_imag, max_iters)