In [None]:
import numpy as np
from tqdm import tqdm
from glob import glob
from jax import jit, lax
import jax.numpy as jnp
import cv2
import os

from PIL import Image, ImageDraw, ImageFont
import matplotlib.font_manager as fm
# Resolve the on-disk location of the 'cmr10' font through matplotlib's font
# manager; the rendering cell hands this path to ImageFont.truetype for captions.
path = fm.findfont(fm.FontProperties(family='cmr10'))


In [None]:
def mandelbrot(z, max_iter, c):
    """
    Iterate z -> z**2 + c and report how long the orbit stays bounded.

    Args:
    z: Starting value of the iteration
    max_iter: Upper bound on the number of iterations
    c: Constant added at every step

    Returns:
    The number of iterations performed divided by max_iter (1.0 when the
    orbit stayed inside |z| < 2 for all max_iter steps)
    """
    def still_running(state):
        point, step = state
        # Keep going while the orbit is inside the radius-2 disc and steps remain
        return jnp.logical_and(jnp.abs(point) < 2, step < max_iter)

    def advance(state):
        point, step = state
        return point**2 + c, step + 1

    # Only the step counter is needed; the final orbit value is discarded
    _, steps_taken = lax.while_loop(still_running, advance, (z, 0))
    return steps_taken / max_iter


# Element-wise scaling by 255; the rendering cell applies it to the normalized
# iteration fraction before casting to uint8 grey levels.
colorize = jnp.vectorize(lambda x: 255*x)

In [None]:
def create_complex_grid(width, height, x_min=-2, x_max=2, y_min=-1, y_max=1):
    """
    Build a grid of complex sample points covering a rectangle of the complex plane.

    Args:
    width, height: Number of samples along the real and imaginary axes
    x_min, x_max: Range of the real part, left to right
    y_min, y_max: Range of the imaginary part, bottom to top

    Returns:
    A 2D JAX array of complex numbers with shape (height, width)
    """
    real_axis = jnp.linspace(x_min, x_max, width)
    # Run from y_max down to y_min so that row 0 is the top of the image
    imag_axis = jnp.linspace(y_max, y_min, height)
    # Broadcast a row of real parts against a column of imaginary parts
    return real_axis[jnp.newaxis, :] + 1j * imag_axis[:, jnp.newaxis]

# Example usage: sample the default region of the complex plane at 1920x1080.
screen_width = 1920
screen_height = 1080
complex_grid = create_complex_grid(width=screen_width, height=screen_height)

# Now you can use this with jax.vmap to compute the Mandelbrot set
from jax import vmap

def compute_mandelbrot(grid, max_iter, c=None):
    """
    Evaluate the escape-time iteration at every point of a complex grid.

    Args:
    grid: A 2D JAX array of complex numbers
    max_iter: The maximum number of iterations
    c: Julia-set constant; when None the Mandelbrot set is computed instead

    Returns:
    A 2D JAX array shaped like grid holding the value mandelbrot() returns for
    each point (the iteration count divided by max_iter)
    """
    if c is None:
        # Mandelbrot set: start from 0 and use the grid point as the constant
        def per_point(point):
            return mandelbrot(0j, max_iter, point)
    else:
        # Julia set: start from the grid point and keep c fixed
        def per_point(point):
            return mandelbrot(point, max_iter, c)

    # Two nested vmaps lift the scalar routine over both grid axes
    return vmap(vmap(per_point))(grid)

# Example: Julia set for c = -0.76 + 0.19i with at most 100 iterations per point.
# (colorize is already defined in an earlier cell, so it is not redefined here.)
result = compute_mandelbrot(complex_grid, 100, c=complex(-0.76, 0.19))

In [None]:
def create_custom_linspace(var, duration):
    """
    Build a piecewise-linear path through a sequence of 2D waypoints.

    Consecutive waypoints are joined by straight segments, each sampled with
    jnp.linspace (both end points included, so a shared waypoint appears at the
    end of one segment and again at the start of the next).

    Args:
    var: A sequence of at least two [x, y] pairs, the waypoints in visiting order
    duration: Total number of samples wanted (truncated to an integer)

    Returns:
    A JAX array of shape (duration, 2); column 0 holds x and column 1 holds y

    Raises:
    ValueError: If fewer than two waypoints are given or duration is negative
    """
    n_segments = len(var) - 1
    if n_segments < 1:
        raise ValueError("create_custom_linspace needs at least two waypoints")

    total = int(duration)
    if total < 0:
        raise ValueError("duration must be non-negative")

    # Spread the samples as evenly as possible: when total is not a multiple of
    # the segment count, the first `extra` segments get one more sample each so
    # that exactly `total` rows come back.
    base, extra = divmod(total, n_segments)

    reals, imags = [], []
    for idx, (start, stop) in enumerate(zip(var[:-1], var[1:])):
        count = base + (1 if idx < extra else 0)
        reals.append(jnp.linspace(start[0], stop[0], count))
        imags.append(jnp.linspace(start[1], stop[1], count))

    # One concatenation per coordinate instead of re-concatenating per segment
    return jnp.vstack((jnp.concatenate(reals), jnp.concatenate(imags))).T


In [None]:
# Clip length in seconds and frame rate; their product is the number of frames.
duration = 30
fps = 60

# Waypoints [real, imaginary] that the Julia constant c passes through.
# The last waypoint repeats the first, so the path is closed.
var = [
    [0.72, 0.31],
    [0.26, 0.74],
    [-0.32, 0.72],
    [-0.68, 0.4],
    [-0.73, -0.29],
    [-0.79, -0.03],
    [0.36, -0.36],
    [-0.16, -0.77],
    [0.72, 0.31],
]
total_frames = duration * fps
constants = create_custom_linspace(var, total_frames)

# **Main Rendering Functionality**

In [None]:


# Iteration cap per pixel (named so that the builtin `iter` is not shadowed).
max_iterations = 50
screen_width, screen_height = 1920, 1080

# cv2.imwrite does not create missing directories and only signals failure
# through its return value, so make sure the target directory exists up front.
output_dir = "/content/Output"
os.makedirs(output_dir, exist_ok=True)

complex_grid = create_complex_grid(screen_width, screen_height)
font_size = 60
font = ImageFont.truetype(path, font_size)

# Convert the constants to plain Python floats once instead of slicing the
# JAX array on every iteration.
julia_constants = np.asarray(constants).tolist()

for frame_idx, (c_re, c_im) in enumerate(tqdm(julia_constants)):
    result = compute_mandelbrot(complex_grid, max_iterations, complex(c_re, c_im))

    # Scale to 0-255 grey levels, then replicate into three channels for PIL
    gray = np.array(colorize(result)).astype(np.uint8)
    pil_image = Image.fromarray(cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB))

    # Caption in the top-left corner; the sign is printed explicitly so that a
    # negative imaginary part reads "- 0.29i" rather than "+ -0.29i"
    sign = "+" if c_im >= 0 else "-"
    caption = f"c = {c_re:.2f} {sign} {abs(c_im):.2f}i"
    ImageDraw.Draw(pil_image).text((10, 10), caption, font=font, fill=(255, 255, 255))

    # OpenCV expects BGR channel order when writing
    image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)

    frame_path = os.path.join(output_dir, f"Mandelbrot_frame_{frame_idx}.png")
    if not cv2.imwrite(frame_path, image):
        raise IOError(f"Could not write frame to {frame_path}")

100%|██████████| 1800/1800 [08:22<00:00,  3.58it/s]


In [None]:
# The glob pattern gets its own name so that `path` (the font file used by the
# rendering cell) is not overwritten; otherwise re-running that cell would fail.
frame_pattern = "/content/Output/*.png"

images = glob(frame_pattern)

# Map frame number -> file path. The number is the text after the last "_" in
# the file name, up to the first "."; glob order is not numeric, so the keys
# are sorted separately.
series_image = {
    int(os.path.basename(image_path).split("_")[-1].split(".")[0]): image_path
    for image_path in images
}
sorted_series = sorted(series_image)

In [None]:
vid_name = "/content/drive/MyDrive/Colab_projects/mandelbrot.mp4"
writer = cv2.VideoWriter(vid_name, cv2.VideoWriter_fourcc(*'mp4v'), fps, (screen_width, screen_height))

# VideoWriter does not raise when the target cannot be opened; check explicitly
if not writer.isOpened():
    raise IOError(f"Could not open video writer for {vid_name}")

try:
    for idx in tqdm(sorted_series):
        frame = cv2.imread(series_image[idx])
        # imread returns None for a missing or unreadable file
        if frame is None:
            raise IOError(f"Could not read frame {series_image[idx]}")
        writer.write(frame)
finally:
    # Always finalise the video file, even if a frame failed to load
    writer.release()

100%|██████████| 1800/1800 [01:03<00:00, 28.47it/s]
