In [None]:
%matplotlib inline
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import torch

from IPython.display import clear_output

from growing_ca.core.trainer import CaTrainer
from growing_ca.core.utils_vis import to_rgb

In [None]:
def visualize_batch(vis0, vis1):
    """Draw a 2-row grid: ``vis0`` images on the top row, ``vis1`` on the bottom.

    Args:
        vis0: Indexable batch of images shown on the first row. Its leading
            dimension (``vis0.shape[0]``) sets the number of columns.
        vis1: Indexable batch shown on the second row, indexed at the same
            positions as ``vis0``.
    """
    print("batch (before/after):")
    n_cols = vis0.shape[0]
    plt.figure(figsize=[15, 5])
    # Row 0 holds the "before" images and row 1 the "after" images. Subplot
    # indices are 1-based and run left-to-right, top-to-bottom.
    for row, images in enumerate((vis0, vis1)):
        for col in range(n_cols):
            plt.subplot(2, n_cols, row * n_cols + col + 1)
            plt.imshow(images[col])
            plt.axis("off")
    plt.show()


def plot_loss(loss_log):
    """Plot the loss history on a log10 scale.

    Args:
        loss_log: Sequence of loss values. Each value is drawn as one point,
            and its position in the sequence is the x coordinate. Values should
            be positive; a zero maps to -inf under ``np.log10`` (NumPy emits a
            divide-by-zero warning).
    """
    log_losses = np.log10(loss_log)
    plt.figure(figsize=(10, 4))
    plt.title("Loss history (log10)")
    # Faint, small markers let a long, noisy history read as a trend.
    plt.plot(log_losses, ".", alpha=0.1)
    # Axis labels so the figure is interpretable on its own.
    plt.xlabel("Iteration")
    plt.ylabel("log10(loss)")
    plt.show()

In [None]:
# Training configuration
device = torch.device(
    torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
)

# Seed the global NumPy and PyTorch RNGs so reruns are repeatable. The training
# loop draws CA rollout lengths from np.random. Any torch-side randomness inside
# CaTrainer (presumably weight init / stochastic cell updates — not visible here)
# draws from torch's RNG. GPU kernels may still be nondeterministic.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Path to target image to train on
# For emojis, split the sprite sheet first: python -m growing_ca.core.emoji_utils data/emoji.png
TARGET_IMAGE = "data/emojis/emoji_0.png"  # Can use individual emoji or any custom image
EXPERIMENT_TYPE = "Regenerating"  # Options: "Growing", "Persistent", "Regenerating"
n_epoch = 8000  # Number of training epochs

# The checkpoint is named after the image's file stem ("emoji_0.png" -> "emoji_0.pth").
# Path.stem keeps names that contain extra dots intact and uses the platform's
# path parsing, instead of splitting the raw string on "/" and ".".
MODEL_PATH = f"models/{Path(TARGET_IMAGE).stem}.pth"
# Make sure the checkpoint directory exists. Whether save_model creates it is not visible here.
Path(MODEL_PATH).parent.mkdir(parents=True, exist_ok=True)

# Initialize the trainer with configuration
trainer = CaTrainer(
    target_image_path=TARGET_IMAGE,
    model_path=MODEL_PATH,
    experiment_type=EXPERIMENT_TYPE,
    device=device,
    channel_n=16,
    target_padding=16,
    lr=2e-3,
    lr_gamma=0.9999,
    betas=(0.5, 0.5),
    batch_size=8,
    pool_size=1024,
    cell_fire_rate=0.5,
)

print(f"Trainer initialized for image: {TARGET_IMAGE}")
print(f"Experiment type: {EXPERIMENT_TYPE}")
print(f"Device: {device}")

In [None]:
# Preview the image the CA is trained toward. It is passed through `to_rgb`
# before display.
target_img = trainer.load_image(TARGET_IMAGE)
fig, ax = plt.subplots(figsize=(4, 4))
ax.set_title(f"Target Image: {TARGET_IMAGE}")
ax.imshow(to_rgb(target_img))
ax.axis("off")
plt.show()

In [None]:
# Training loop with periodic visualization and checkpointing.
# Each iteration is one `trainer.train_step` call. The notebook's messages call
# an iteration an "epoch".
VIS_EVERY = 100  # visualize + save a checkpoint every this many iterations
MIN_CA_STEPS, MAX_CA_STEPS = 64, 96  # rollout length drawn from [64, 96) (high is exclusive)

loss_log = []

print(f"Starting training for {n_epoch} epochs...")

for i in range(n_epoch + 1):
    # Get batch. When the pattern pool is in use, `batch` is the handle that
    # the updated states are written back to below.
    x0, batch = trainer.get_batch()

    # Run the CA for a random number of steps, so the model does not overfit
    # to a single rollout length.
    steps = np.random.randint(MIN_CA_STEPS, MAX_CA_STEPS)
    x, loss = trainer.train_step(x0, trainer.pad_target, steps)

    # Update pool if using pattern pool
    if trainer.use_pattern_pool and batch is not None:
        batch.x[:] = x.detach().cpu().numpy()
        batch.commit()

    # Convert the loss to a Python float once and reuse it for logging and printing.
    loss_value = loss.item()
    loss_log.append(loss_value)

    # Also report/save on the final iteration. Otherwise, when n_epoch is not
    # a multiple of VIS_EVERY, the last (up to VIS_EVERY - 1) updates would
    # never be checkpointed.
    if i % VIS_EVERY == 0 or i == n_epoch:
        clear_output(wait=True)
        print(f"Epoch {i}/{n_epoch}, Loss: {loss_value:.6f}")

        # Visualize batch
        vis0, vis1 = trainer.visualize_current_state()
        visualize_batch(vis0, vis1)

        # Plot loss history
        plot_loss(loss_log)

        # Save model
        trainer.save_model()

print("Training complete!")