Notebook to visualize WireWorld generator and model outputs

In [None]:
from datasets.wireworld import WireworldGraph
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

import numpy as np

In [None]:
# Side length of the square grid; the sample is later drawn as GRID_SIZE x GRID_SIZE cells.
GRID_SIZE = 10
# Take the first entry of the WireworldGraph dataset's `data`. Its `x` and `y`
# tensors are plotted below as the "input" and "output" grids.
graph = WireworldGraph(grid_size=GRID_SIZE).data[0]

def plot_wireworld(values, grid_size, title=None):
    """Draw a Wireworld grid as a colour-coded image and show it.

    Args:
        values: Array of state codes holding ``grid_size * grid_size`` entries
            (anything that ``reshape``s to a square grid). Codes are drawn as
            0 = empty (black), 1 = electron head (blue),
            2 = electron tail (red), 3 = conductor (yellow).
        grid_size: Side length of the square grid.
        title: Optional figure title; nothing is drawn when it is falsy.
    """
    # One colour per state code, listed in code order.
    state_colors = ['black', 'blue', 'red', 'yellow']
    cmap = ListedColormap(state_colors)

    # Prepare data for imshow
    data = values.reshape(grid_size, grid_size)

    fig, ax = plt.subplots()
    # Pin the colour range to the full set of state codes. Without explicit
    # limits imshow rescales to the data's own min/max, so a grid that lacks
    # one of the states would be painted with the wrong colours.
    ax.imshow(
        data,
        cmap=cmap,
        vmin=-0.5,
        vmax=len(state_colors) - 0.5,
        interpolation='nearest',
    )

    # Draw grid lines on the cell boundaries (cells are centred on integers).
    for boundary in range(grid_size + 1):
        ax.axhline(boundary - 0.5, color='black', linewidth=0.5)
        ax.axvline(boundary - 0.5, color='black', linewidth=0.5)

    # Set title and remove ticks
    if title:
        ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])
    plt.show()


# Render the sample's `x` and `y` tensors with the same helper, one figure each.
for cell_states, figure_title in (
    (graph.x, "Wireworld Input"),
    (graph.y, "Wireworld Output"),
):
    plot_wireworld(cell_states.numpy(), GRID_SIZE, figure_title)

WireWorld visualization for the paper

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors

# Integer state codes used by the simulator in this cell. Each code is also the
# row index of its colour in PALETTE.
# NOTE(review): this numbering (1 = wire, 2 = head, 3 = tail) differs from the
# order assumed by plot_wireworld's colormap (1 = head, 2 = tail, 3 = conductor),
# so grids from the two cells are not interchangeable.
EMPTY = 0  # Empty cell
WIRE = 1  # Conductor cell
HEAD = 2  # Electron head
TAIL = 3  # Electron tail

# RGB colour for each state code (row i is the colour of state i).
PALETTE = np.array([
    [0, 0, 0],     # Empty: black
    [255, 255, 0],   # Wire: yellow
    [0, 0, 255],   # Head: blue
    [255, 0, 0],   # Tail: red
], dtype=np.uint8)

def wireworld_step(grid):
    """Return the next Wireworld generation for a 2-D grid of state codes.

    The input grid is left untouched; a modified copy is returned.

    Transition rules, applied to every cell at once:
      * HEAD becomes TAIL.
      * TAIL becomes WIRE.
      * WIRE becomes HEAD when exactly one or two of its (up to eight)
        neighbours are HEAD; otherwise it stays WIRE.
      * Any other cell is unchanged.
    Cells beyond the border are treated as absent (no wrap-around).
    """
    rows, cols = grid.shape
    next_grid = grid.copy()

    for row, col in np.ndindex(rows, cols):
        state = grid[row, col]
        if state == HEAD:
            next_grid[row, col] = TAIL
        elif state == TAIL:
            next_grid[row, col] = WIRE
        elif state == WIRE:
            # 3x3 window clipped at the borders. The centre is a WIRE cell,
            # so it never contributes to the HEAD count.
            window = grid[max(row - 1, 0):row + 2, max(col - 1, 0):col + 2]
            heads_nearby = np.count_nonzero(window == HEAD)
            if 1 <= heads_nearby <= 2:
                next_grid[row, col] = HEAD

    return next_grid

# Initial grid for the figure: 12 rows x 10 columns.
# Two horizontal wires (rows 3 and 8) each start with an electron head in
# column 0. Around columns 5-6 each wire is interrupted by a one-cell gap that
# is bridged by extra WIRE cells in the rows directly above and below; the gap
# sits at column 6 in the upper wire and at column 5 in the lower wire.
# NOTE(review): these look like Wireworld diodes in the conducting (upper) and
# blocking (lower) orientation — confirm.
world = np.array([
    [EMPTY, EMPTY, EMPTY, EMPTY,  EMPTY, EMPTY, EMPTY, EMPTY, EMPTY, EMPTY],
    [EMPTY, EMPTY, EMPTY, EMPTY,  EMPTY, EMPTY, EMPTY, EMPTY, EMPTY, EMPTY],
    [EMPTY, EMPTY, EMPTY, EMPTY,  EMPTY, WIRE, WIRE, EMPTY, EMPTY, EMPTY],
    [HEAD,  WIRE,  WIRE,  WIRE,   WIRE,   WIRE, EMPTY, WIRE, WIRE,  WIRE],
    [EMPTY, EMPTY, EMPTY, EMPTY,  EMPTY, WIRE, WIRE, EMPTY, EMPTY, EMPTY],
    [EMPTY, EMPTY, EMPTY, EMPTY,  EMPTY, EMPTY, EMPTY, EMPTY, EMPTY, EMPTY],
    [EMPTY, EMPTY, EMPTY, EMPTY,  EMPTY, EMPTY, EMPTY, EMPTY, EMPTY, EMPTY],
    [EMPTY, EMPTY, EMPTY, EMPTY,  EMPTY, WIRE, WIRE,  EMPTY, EMPTY, EMPTY],
    [HEAD,  WIRE,  WIRE,  WIRE,   WIRE,  EMPTY, WIRE, WIRE, WIRE,  WIRE],
    [EMPTY, EMPTY, EMPTY, EMPTY,  EMPTY, WIRE, WIRE, EMPTY, EMPTY, EMPTY],
    [EMPTY, EMPTY, EMPTY, EMPTY,  EMPTY, EMPTY, EMPTY, EMPTY, EMPTY, EMPTY],
    [EMPTY, EMPTY, EMPTY, EMPTY,  EMPTY, EMPTY, EMPTY, EMPTY, EMPTY, EMPTY],
], dtype=np.uint8)

# Four panels side by side: one snapshot of `world` every three generations.
fig, axes = plt.subplots(1, 4, figsize=(15, 5))

count = 0  # generation number shown in the panel title
for ax in axes:
    # Indexing PALETTE with the grid of state codes yields an RGB image.
    ax.imshow(PALETTE[world], interpolation='none')
    ax.set_title(f"Step {count}", fontsize=24)
    ax.axis('off')

    # Advance three generations before taking the next snapshot.
    for _generation in range(3):
        world = wireworld_step(world)
    count += 3

plt.tight_layout()
plt.show()

Model Evaluation


In [None]:
import torch 
# Please adjust the model path to point at your own trained checkpoint.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code — only load checkpoints from a trusted source.
# NOTE(review): newer PyTorch releases default torch.load to weights_only=True,
# which rejects fully pickled model objects; weights_only=False may be needed
# here — confirm against the installed version.
model = torch.load("runs/wireworld/wireworld_1694529621082373000.pt")
# set_hardmax is provided by the loaded model class (not defined in this
# notebook); presumably it switches the model to hard (argmax-style) outputs
# for evaluation — confirm against the model code.
model.set_hardmax(True)

def map_x_to_state(xs, num_states=4):
    """One-hot encode a tensor of integer cell-state codes.

    Args:
        xs: Tensor with one state code per entry along its first dimension,
            e.g. shape ``(N,)`` or ``(N, 1)``. Non-integer values are
            truncated toward zero, as ``int()`` would.
        num_states: Number of distinct states, i.e. the width of the one-hot
            vectors. Defaults to 4.

    Returns:
        An int64 tensor of shape ``(N, num_states)`` with a single 1 per row,
        placed on the same device as ``xs``. An empty input yields an empty
        ``(0, num_states)`` tensor.

    Raises:
        TypeError: If ``xs`` is a 0-dimensional tensor.
        ValueError: If an entry of ``xs`` holds more than one value.
        IndexError: If a state code lies outside ``[0, num_states)``.
    """
    xs = torch.as_tensor(xs)
    if xs.dim() == 0:
        raise TypeError("xs must have at least one dimension")
    if xs.numel() != xs.shape[0]:
        raise ValueError(
            f"expected one state code per entry, got tensor of shape {tuple(xs.shape)}"
        )

    codes = xs.reshape(-1).to(torch.long)
    # Reject out-of-range codes explicitly: plain list indexing would let a
    # negative code wrap around and silently encode the wrong state.
    if codes.numel() > 0 and (int(codes.min()) < 0 or int(codes.max()) >= num_states):
        raise IndexError(f"state codes must lie in [0, {num_states})")

    return torch.nn.functional.one_hot(codes, num_classes=num_states)

# One-hot encode the sample's input states and run the model on the graph.
input_values = map_x_to_state(graph.x)
# Inference only: skip building an autograd graph.
with torch.no_grad():
    output_values = model(input_values, graph.edge_index)
# Pick the highest-scoring state per cell.
output_values = torch.argmax(output_values, dim=-1)
# plot_wireworld reshapes to (GRID_SIZE, GRID_SIZE) itself, so the flat
# prediction vector can be passed directly.
plot_wireworld(output_values.numpy(), GRID_SIZE, "Wireworld Model Output")