# Init

In [None]:
import jax
from jax import config
config.update("jax_enable_x64", True)
import jax.numpy as jnp
import numpy as np
import matplotlib

matplotlib.use('Agg')

import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.animation as animation
from matplotlib import cm

from scipy import ndimage

import time
from google.colab import files

import datetime

import pickle

import shutil

In [None]:
# Mount Google Drive so that results can be saved there directly, rather than
# being lost when the Colab kernel stops.

from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Image intensity scaling

In [None]:
def cdf_img(x, buffer=0.25):
  """
  Rank-based rescaling of x onto [-1, 1] that keeps each value's sign, so the
  full color palette is used.

  Negative entries are spread evenly by rank over [-1, -buffer]; non-negative
  entries (zero included) are spread evenly by rank over [buffer, 1]. The gap
  (-buffer, buffer) therefore visually separates the two signs.

  Args:
    x: array of any shape.
    buffer: half-width of the gap left around zero.

  Returns:
    Array with the same shape as x.
  """
  ranked = jnp.sort(jnp.ravel(x))
  n_neg = jnp.sum(ranked < 0)
  n_nonneg = ranked.shape[0] - n_neg

  neg_levels = jnp.linspace(-1, -buffer, n_neg)
  nonneg_levels = jnp.linspace(buffer, 1, n_nonneg)
  levels = jnp.concatenate([neg_levels, nonneg_levels])

  # Each value maps to the level assigned to its rank in the sorted array.
  return jnp.interp(x, ranked, levels)

In [None]:
def convergence_measure(v, max_val = 1e6):
  """
  Turn a trajectory (e.g. a training history) into a single number that looks
  good when rendered as an image pixel.

  Args:
    v: trajectory values, real or complex; only magnitudes are used. Assumed
      1-D (one entry per step) - the vmapped version below maps over a
      leading batch axis.
    max_val: cap on magnitudes. Non-finite entries are replaced by max_val
      before halving; finite magnitudes are capped at max_val after halving.

  Returns:
    Scalar. The trajectory counts as converged when the mean of its last 20
    processed magnitudes is below 1 (averaging smooths oscillations); then the
    result is -sum(v) (<= 0). Otherwise it is sum(1/v) (non-negative), which
    shrinks toward zero as the trajectory diverges faster.
  """
  v = jnp.abs(v)
  # Use a select, not an arithmetic mask: v*fin + max_val*(1-fin) yields nan
  # for every non-finite entry because inf*0 == nan and nan*0 == nan.
  v = jnp.where(jnp.isfinite(v), v, max_val)

  # Halve so that "converged" (< 1) means magnitude below 2.
  v = v / 2
  v = jnp.minimum(v, max_val)

  converged = jnp.mean(v[-20:]) < 1
  return jnp.where(converged, -jnp.sum(v), jnp.sum(1/v))


# Batched over the leading axis of v, then jit-compiled.
convergence_measure_vmap = jax.jit(jax.vmap(convergence_measure, in_axes=(0,), out_axes=0))

# Iteration of z**2 + c over a grid of hyperparameters (values of c)

In [None]:
def mandel_trajectory(c, num_steps):
  """
  Iterate z -> z**2 + c starting from z = 0 and record every iterate.

  Args:
    c: the (complex or real) parameter of the map.
    num_steps: number of iterations.

  Returns:
    Array of length num_steps holding z_1, ..., z_num_steps (z_0 = 0 is not
    included).
  """
  def step(z, _unused):
    z_next = z**2 + c
    # Carry the new value forward and also emit it into the history.
    return z_next, z_next

  _, history = jax.lax.scan(step, 0+0j, None, length=num_steps)
  return history

def mandel_trajectory_summary(c, num_steps):
  """Scalar pixel value for parameter c: run the orbit, then summarize it."""
  return convergence_measure(mandel_trajectory(c, num_steps))


# Vectorized over a 1-D batch of c values with num_steps shared by the whole
# batch; num_steps is marked static for jit, so each distinct value triggers
# its own compilation.
mandel_trajectory_summary_vmap = jax.jit(
    jax.vmap(mandel_trajectory_summary, in_axes=(0, None), out_axes=0),
    static_argnums=(1,),
)

def mandel_image(c0, c1, resolution, num_chunks = 64, num_steps = 100):
  """
  Evaluate the Mandelbrot convergence summary on a resolution x resolution
  grid of c values spanning the rectangle with corners c0 and c1.

  Args:
    c0: corner with the smallest real and imaginary parts of the region.
    c1: opposite corner.
    resolution: number of grid points along each axis.
    num_chunks: the flattened grid is evaluated in this many pieces to bound
      peak memory (each point's full orbit is materialized). It need not
      divide resolution**2, and it is clamped to [1, number of points].
    num_steps: number of z -> z**2 + c iterations per point.

  Returns:
    (resolution, resolution) array; row i corresponds to the i-th imaginary
    value and column j to the j-th real value (meshgrid 'xy' indexing).
  """
  real_axis = jnp.linspace(c0.real, c1.real, resolution)
  imag_axis = jnp.linspace(c0.imag, c1.imag, resolution)
  C_real, C_imag = jnp.meshgrid(real_axis, imag_axis)
  C = (C_real + 1j*C_imag).ravel()

  # array_split (unlike reshape) accepts chunk counts that don't divide the
  # number of points; clamping avoids empty chunks.
  num_chunks = max(1, min(num_chunks, C.shape[0]))
  img = [mandel_trajectory_summary_vmap(C_chunk, num_steps)
         for C_chunk in jnp.array_split(C, num_chunks)]
  img = jnp.concatenate(img, axis=0)

  return img.reshape((resolution, resolution))

# Compute a region of the Mandelbrot set

In [None]:
# Region of the complex plane to render: c0 is the lower-left corner and c1
# the upper-right corner.
c0, c1 = -2.1 - 1.35j, 0.6 + 1.35j
resolution = 16384  # = 4096*4 grid points along each axis

img = mandel_image(c0, c1, resolution)

# Draw the figure

In [None]:
# Render the image at print resolution and save it as both PDF and PNG.
import os

dpi = 1000
figsize = (5.2,4.5)

# Rank-rescale -img onto [-1, 1] while keeping its sign (see cdf_img).
img_norm = cdf_img(-img)

fig, ax1 = plt.subplots(figsize=figsize, dpi=dpi)
im = ax1.imshow(img_norm,
                extent=[c0.real, c1.real, c0.imag, c1.imag],
                origin='lower',  # row 0 of img is the smallest imaginary part
                vmin=-1, vmax=1,
                cmap='Spectral',
                aspect='auto',
                interpolation='nearest'
                )

ax1.set_xlabel(r'Real part of $c$')
ax1.set_ylabel(r'Imaginary part of $c$')

fig.tight_layout()

savename = '/content/drive/MyDrive/fractal/mandelbrot'
# savefig does not create missing directories.
os.makedirs(os.path.dirname(savename), exist_ok=True)
fig.savefig(savename + '.pdf')
fig.savefig(savename + '.png')
# Release the large figure; pyplot otherwise keeps every figure alive.
plt.close(fig)

In [None]:
# Render the same image at a medium resolution and save it as PDF and PNG.
import os

dpi = 300
figsize = (5.2,4.5)

# Rank-rescale -img onto [-1, 1] while keeping its sign (see cdf_img).
img_norm = cdf_img(-img)

fig, ax1 = plt.subplots(figsize=figsize, dpi=dpi)
im = ax1.imshow(img_norm,
                extent=[c0.real, c1.real, c0.imag, c1.imag],
                origin='lower',  # row 0 of img is the smallest imaginary part
                vmin=-1, vmax=1,
                cmap='Spectral',
                aspect='auto',
                interpolation='nearest'
                )

ax1.set_xlabel(r'Real part of $c$')
ax1.set_ylabel(r'Imaginary part of $c$')

fig.tight_layout()

savename = '/content/drive/MyDrive/fractal/mandelbrot_midres'
# savefig does not create missing directories.
os.makedirs(os.path.dirname(savename), exist_ok=True)
fig.savefig(savename + '.pdf')
fig.savefig(savename + '.png')
# Release the figure; pyplot otherwise keeps every figure alive.
plt.close(fig)