Reaction-Diffusion Equations
We'll use RDE to create RGB images. They should result in cool patterns like the following:
https://en.wikipedia.org/wiki/Turing_pattern

$\frac{\partial c_i}{\partial t} = D_i (\frac{\partial^2 c_i}{\partial x^2} + \frac{\partial^2 c_i}{\partial y^2}) + f_i(c_0, c_1, c_2)$

where $c_i$ is $i^{th}$  color dimension, $D_i$ is corresponding diffusion rate, and $f_i$ describes affect of other colors on $c_i$


In [33]:
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML


D = torch.tensor([0.1, 0.1, 0.1], dtype=torch.float)
dt = 0.1

def init_state(size: int = 128, num_frames: int = 256, kind: str = "rand", **kwargs) -> torch.Tensor:
    frames = torch.zeros(num_frames, 3, size, size, dtype=torch.float)
    if kind == "center":
        frames[0, :, size // 2, size // 2] = 1.0  # Set a single pixel to white
    elif kind.startswith("box"):
        box_size = kwargs.get("box_size", int(kind[3:]))
        frames[
            0, 
            :, 
            size // 2 - box_size // 2 : size // 2 + box_size // 2, 
            size // 2 - box_size // 2 : size // 2 + box_size // 2
        ] = 1.0  # Set a single pixel to white
    elif kind == "rand":
        frames = torch.rand(num_frames, 3, size, size, dtype=torch.float)
        frames[:, 2, ...] = 0.0
        frames[1:, 0, ...] = 0.0
    elif kind == "perturb":
        frames = 1 + kwargs.get("epsilon", 1) * (torch.rand(num_frames, 3, size, size, dtype=torch.float) - 0.5)
        frames[:, 2, ...] = 0.0
        frames[1:, 0, ...] = 0.0
    elif kind == "rings":
        alpha = kwargs.get("alpha")
        frequency = kwargs.get("frequency")
        xs = torch.linspace(0, 1, steps=size)
        ys = torch.linspace(0, 1, steps=size)
        x, y = torch.meshgrid(xs, ys, indexing='xy')
        frames[0, 0, ...] = (1 + alpha * torch.cos(frequency * x)) * (1 + alpha * torch.cos(frequency * y))
        frames[0, 1, ...] = (1 - alpha * torch.cos(frequency * x)) * (1 - alpha * torch.cos(frequency * y))
    else:
        raise NotImplementedError(f"{kind=} is not implemented")
    return frames

class Interaction(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        o = torch.zeros_like(x)
        o[0, ...] = (x[0, ...] + x[1, ...])
        o[1, ...] = (x[1, ...] - x[0, ...])
        o[2, ...] = 0
        return o


def laplacian(x: torch.Tensor) -> torch.Tensor:
    # x: C, H, W
    # out: C, H, W
    laplacian = torch.zeros_like(x) - 4 * x
    laplacian[:, :-1, :] += x[:, 1:, :]  # x+1
    laplacian[:, :, :-1] += x[:, :, 1:]  # y+1
    laplacian[:, 1:, :] += x[:, :-1, :]  # x-1
    laplacian[:, :, 1:] += x[:, :, :-1]  # y-1
    return laplacian


def simulate(frames: torch.Tensor, interaction: Interaction, D: torch.Tensor, dt: float) -> torch.Tensor:
    T, C, H, W = frames.shape
    D = D.view(C, 1, 1)
    for t in range(T-1):
        frames[t + 1] = frames[t] + ((D * laplacian(frames[t, ...]) + interaction(frames[t])) * dt)
    return frames


In [29]:

class Schnakenberg(nn.Module):
    def __init__(self, a: float, b: float) -> None:
        super().__init__()
        self.a = a
        self.b = b

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        o = torch.zeros_like(x)
        o[0, ...] = self.a - x[0] + (torch.pow(x[0], 2) * x[1])
        o[1, ...] = self.b - (torch.pow(x[0], 2) * x[1])
        o[2, ...] = 0
        return o

class FitzHughNagumo(nn.Module):
    def __init__(self, epsilon: float, beta: float, gamma: float) -> None:
        super().__init__()
        self.epsilon = epsilon
        self.beta = beta
        self.gamma = gamma

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        o = torch.zeros_like(x)
        o[0, ...] = x[0] - x[1] + (torch.pow(x[0], 3) / 3)
        o[1, ...] = self.epsilon * (x[0] + self.beta - self.gamma * x[1])
        o[2, ...] = 0
        return o

class Brusselator(nn.Module):
    def __init__(self, a: float, b: float) -> None:
        super().__init__()
        self.a = a
        self.b = b

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        o = torch.zeros_like(x)
        o[0, ...] = 1 + torch.pow(x[0], 2) * x[1] - self.a * (x[1] + 1)
        o[1, ...] = x[0] * (self.b - (torch.pow(x[0], 2) * x[1]))
        o[2, ...] = 0
        return o


In [30]:

def animate(frames: torch.Tensor) -> FuncAnimation:
    num_frames = frames.shape[0]
    # Create a figure and axis
    fig, ax = plt.subplots()

    # Create an empty plot
    frame = ax.matshow(frames[0].detach().cpu().permute(1, 2, 0).numpy(), vmin=0, vmax=5.0)

    # Update function for the animation
    def update(frame_index):
        frame.set_data(frames[frame_index].detach().cpu().permute(1, 2, 0).numpy())

    # Create the animation
    animation = FuncAnimation(fig, update, frames=num_frames, interval=50)
    plt.tight_layout()
    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])
    # Display the animation
    plt.close()  # Prevents displaying the static plot below
    return HTML(animation.to_html5_video())

In [47]:
frames = init_state(128, 128, kind="rand", frequency=18, alpha=0.5)
frames = simulate(frames, interaction=FitzHughNagumo(epsilon=0.1, beta=0.1, gamma=0.1), D=torch.tensor([0.1, 0.1, 0.1]), dt=0.05)
print(frames.min(), frames.max())
animate(frames=frames)

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


tensor(nan) tensor(nan)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

In [44]:
frames = init_state(128, 128, kind="rings", frequency=18, alpha=0.5)
frames = simulate(frames, interaction=Brusselator(a=0.2, b=0.1), D=torch.tensor([0.1, 0.1, 0.1]), dt=0.05)
print(frames.min(), frames.max())
animate(frames=frames)

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


tensor(nan) tensor(nan)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

In [42]:
frames = init_state(128, 128, kind="rings", frequency=18, alpha=0.5)
frames = simulate(frames, interaction=Schnakenberg(a=0.2, b=0.1), D=torch.tensor([0.1, 0.1, 0.1]), dt=0.1)
print(frames.min(), frames.max())
animate(frames=frames)

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


tensor(0.) tensor(2.2500)


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i