# Morphogenesis with Attention/Transformer Cellular Automata

Let's use the code from `Render_MHA.ipynb` to train a particle system, whose dynamics are defined by a multi-head attention (MHA) block, to self-assemble into a target image.

We will compute the $\ell_2$ distance between the target image and the particle system's projection (a bitmap in which each particle is drawn as a Gaussian blob) and optimize the attention weights to minimize it.

I honestly don't know whether a single MHA block will be able to do this, or whether we will need multiple layers. My hunch is that an MHA followed by a non-linearity can do it, given a large enough latent state.

In [2]:
## Import box
import numpy as np 
import matplotlib.pyplot as plt

import torch 
from torch import nn

from tqdm import tqdm
import os
import glob 
import cv2

# Pick the compute device once; everything below moves tensors/modules here.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print("Device: ", DEVICE)

Device:  cuda


  from .autonotebook import tqdm as notebook_tqdm


## 0: Dynamical MHA Functions

In [3]:
def RGB_render_particles_to_bitmap(X, x_res, y_res, sigma=1.0, scale_sigma=False):
    """
    Render colored particles as Gaussian blobs on an RGB bitmap.

    Particle positions are rescaled so that the largest extent of the cloud
    spans [-1, 1] (aspect ratio preserved). Each particle then contributes an
    isotropic Gaussian, weighted by its RGB color, to a regular grid over
    [-1, 1] x [-1, 1].

    Args:
        X (torch.Tensor): Tensor of shape [N, d] with d >= 5. Columns 0-1 are
                          particle positions; columns 2-4 are RGB colors.
        x_res (int): Number of grid samples along x (output axis 0).
        y_res (int): Number of grid samples along y (output axis 1).
        sigma (float): Standard deviation of each Gaussian, measured in the
                       normalized [-1, 1] grid coordinates (not pixels).
        scale_sigma (bool): If True, multiply sigma by the position scale
                            factor (half the largest extent of the cloud).

    Returns:
        torch.Tensor: Bitmap of shape [x_res, y_res, 3] on X's device, in the
                      default floating dtype. Axis 0 indexes x and axis 1
                      indexes y, so plt.imshow draws x vertically. Any
                      channel whose peak exceeds 1 is rescaled to peak 1.

    Raises:
        ValueError: If X is not 2-D with at least 5 columns.
    """
    if X.dim() != 2 or X.shape[1] < 5:
        raise ValueError(f"X must have shape [N, d>=5], got {tuple(X.shape)}")

    out_dtype = torch.get_default_dtype()
    device = X.device
    if X.shape[0] == 0:
        # Nothing to draw; also avoids min/max over an empty dimension.
        return torch.zeros((x_res, y_res, 3), dtype=out_dtype, device=device)

    positions = X[:, 0:2]
    colors = X[:, 2:5]

    # Scale particle positions so the largest extent spans [-1, 1].
    min_pos = torch.min(positions, dim=0)[0]
    max_pos = torch.max(positions, dim=0)[0]
    scale_factor = torch.max(max_pos - min_pos) / 2
    # A single particle (or coincident particles) has zero extent; clamp so
    # the division below does not produce 0/0 = NaN.
    scale_factor = scale_factor.clamp(min=torch.finfo(scale_factor.dtype).eps)
    scaled_X = (positions - (min_pos + scale_factor)) / scale_factor

    # Optionally scale sigma
    if scale_sigma:
        sigma = sigma * scale_factor

    x = torch.linspace(-1, 1, steps=x_res, device=device)
    y = torch.linspace(-1, 1, steps=y_res, device=device)

    # The 2-D Gaussian is separable, exp(-(dx^2 + dy^2) / 2s^2) = gx * gy, so
    # evaluate 1-D profiles per particle and contract over particles once,
    # instead of looping over particles on the full grid.
    denom = 2 * sigma ** 2
    gx = torch.exp(-((x[None, :] - scaled_X[:, 0:1]) ** 2) / denom)  # [N, x_res]
    gy = torch.exp(-((y[None, :] - scaled_X[:, 1:2]) ** 2) / denom)  # [N, y_res]
    colored_gy = gy[:, :, None] * colors.to(gy.dtype)[:, None, :]   # [N, y_res, 3]
    rgb_bitmap = torch.tensordot(gx, colored_gy, dims=([0], [0]))   # [x_res, y_res, 3]
    rgb_bitmap = rgb_bitmap.to(out_dtype)

    # Per-channel normalization: channels whose peak exceeds 1 are scaled down
    # to peak 1; dimmer channels are left untouched (clamp(min=1) also rules
    # out division by zero for all-zero channels).
    max_val = rgb_bitmap.reshape(-1, 3).amax(dim=0)
    rgb_bitmap = rgb_bitmap / max_val.clamp(min=1)

    return rgb_bitmap

In [4]:
def visualize_state_matrix(X_, 
                           title='State Visualization: XY-RGB State Interpretation', 
                           savename='../figs/test.png', 
                           scale_sigma=False,
                           show=False,
                           resolution=None,
                           sigma=0.03):
    """
    Render a particle state matrix as an RGB image and save it to disk.

    Args:
        X_ (torch.Tensor): [N, d] state matrix (d >= 5); columns 0-1 are read
            as XY positions and columns 2-4 as RGB colors by
            RGB_render_particles_to_bitmap. Detached and moved to CPU first.
        title (str): Figure title.
        savename (str): Output image path; missing parent directories are created.
        scale_sigma (bool): Forwarded to RGB_render_particles_to_bitmap.
        show (bool): If True, display the figure before closing it.
        resolution (tuple[int, int] | None): (x_res, y_res) of the rendered
            bitmap. Defaults to the notebook-level globals x_res / y_res.
        sigma (float): Gaussian width forwarded to the renderer.
    """
    # os.makedirs('') raises FileNotFoundError, so only create a directory
    # when savename actually contains one.
    directory = os.path.dirname(savename)
    if directory:
        os.makedirs(directory, exist_ok=True)

    if resolution is None:
        # Backward-compatible fallback: earlier versions read these globals directly.
        resolution = (x_res, y_res)
    res_x, res_y = resolution

    X = X_.detach().cpu()
    bitmap = RGB_render_particles_to_bitmap(X, res_x, res_y, sigma=sigma, scale_sigma=scale_sigma)

    # Use a dedicated figure so we never draw into some other "current" figure.
    fig, ax = plt.subplots()
    ax.imshow(bitmap.numpy())
    ax.set_title(title)
    fig.savefig(savename)
    if show:
        plt.show()
    plt.close(fig)

    

In [5]:
def run_sim(state_matrix, 
            num_iters = 50, 
            dt = 0.02 / 100, 
            noise_factor = 0.005,
            fig_folder=None,
            frame_period = 2000,
            attn=None,
            progress=True):
    """
    Integrate the attention-driven particle dynamics in place and return the state.

    Each step does the following:
      * Feed the state matrix to the attention block as query, key and value.
      * Keep the first two output channels as a velocity (dx, dy) and add
        Gaussian noise to it.
      * Standardize the noisy velocity to zero mean and unit std over all
        entries.
      * Take an Euler step of size dt on the XY channels only.
    All other channels are left unchanged.

    Args:
        state_matrix (torch.Tensor): [N, batch, dim] state in sequence-first
            layout (matching nn.MultiheadAttention's default batch_first=False).
            Modified in place.
        num_iters (int): Number of Euler steps.
        dt (float): Step size.
        noise_factor (float): Standard deviation of the noise added to the raw velocity.
        fig_folder (str | None): If given, frames are saved to ../figs/<fig_folder>/.
        frame_period (int): Save a frame every `frame_period` steps.
        attn (nn.Module | None): Attention block to use; defaults to the
            notebook-level `multihead_attn`.
        progress (bool): If True, wrap the loop in a tqdm progress bar.

    Returns:
        torch.Tensor: The same `state_matrix` object, after the update.
    """
    if attn is None:
        attn = multihead_attn  # notebook-level global, as before

    steps = tqdm(range(num_iters)) if progress else range(num_iters)

    # Run without autograd. The state is updated in place, which invalidates
    # tensors the attention block saves for backward, so a recorded graph
    # could never be back-propagated. It would only grow in memory every step.
    with torch.no_grad():
        for i in steps:
            dx, _ = attn(state_matrix, state_matrix, state_matrix)
            dxy = dx[:, :, 0:2]

            # Additive Gaussian noise: zero mean, standard deviation noise_factor.
            dxy = dxy + torch.randn_like(dxy) * noise_factor

            # Standardize over all entries; the raw velocities have very small
            # variance. The clamp only matters if every entry is identical.
            dxy_std, dxy_mean = torch.std_mean(dxy)
            dxy = (dxy - dxy_mean) / dxy_std.clamp(min=torch.finfo(dxy.dtype).tiny)

            # Euler step on XY only; channels 2+ are left untouched.
            state_matrix[:, :, 0:2] += dxy * dt

            if fig_folder is not None and i % frame_period == 0:
                visualize_state_matrix(state_matrix[:, 0, :],
                                       scale_sigma=True,
                                       savename=f'../figs/{fig_folder}/frame_{str(i//frame_period).zfill(3)}.png')
    return state_matrix


## 1: Hyperparameters

In [6]:
## Setting up the MHA 
embed_dim = 40          # Width of each particle's state vector (query input width).
# The simulation feeds the same state tensor as query, key and value, so the
# key and value input widths must equal embed_dim; tie them instead of
# repeating the constant (changing one alone would break the forward pass).
key_dim = embed_dim     # Input width of the keys (kdim).
val_dim = embed_dim     # Input width of the values (vdim).
num_heads = 4           # Number of attention heads; PyTorch requires embed_dim % num_heads == 0.
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, kdim=key_dim, vdim=val_dim, bias=True).to(DEVICE)

## Setting up the state matrix 
N = 200             # Number of particles/sequence length
num_batch = 1

# Sequence-first layout [N, batch, embed_dim] (batch_first defaults to False).
state_matrix = torch.rand(N, num_batch, embed_dim).to(DEVICE) # uniform [0, 1)
print("State matrix shape (N, batch, particle_state_dim): ", state_matrix.shape)

## Simulation parameters for the state-update loop
num_iters = 10000
dt = 0.02 / 100
noise_factor = 0.005
fig_folder='anim04'

x_res = 500
y_res = 500

frame_period = 200  # Save one frame every 200 steps (2000 gives far fewer frames).

State matrix shape (N, batch, particle_state_dim):  torch.Size([200, 1, 40])


## 2: Running the Simulation/Training

In [7]:
# Reset the output directory: create it if needed, then delete stale frames.
# glob + os.remove replaces shelling out to `rm`. The shell command was not
# portable, printed an error when no PNGs existed yet, and broke on paths
# containing spaces.
frame_dir = f'../figs/{fig_folder}'
png_path = f'{frame_dir}/*.png'

os.makedirs(frame_dir, exist_ok=True)
for png_file in glob.glob(png_path):
    os.remove(png_file)

In [None]:

# Run the simulation; run_sim updates state_matrix in place and also returns it.
state_matrix = run_sim(
    state_matrix,
    fig_folder=fig_folder,
    frame_period=frame_period,
    num_iters=num_iters,
    noise_factor=noise_factor,
    dt=dt,
)