Notebook to visualize the Game of Life data generator and model outputs

In [None]:
from datasets.gameoflife import GameOfLifeGraph
import matplotlib.pyplot as plt
import numpy as np

Visualize the standard (square-grid) Game of Life (GOL)

In [None]:
# A single generated 5x5 Game of Life sample (steps=1) used by the plots below.
dataset = GameOfLifeGraph(grid_size=5, steps=1, num_graphs=1).data

def plot_grid(values, grid_size, title=None):
    """Draw a square Game of Life grid with thin cell separators.

    Cells with value 1 are drawn black and cells with value 0 white (reversed gray colormap).

    Args:
        values: 2-D array-like of cell values, shape (grid_size, grid_size). Numpy arrays
            and CPU torch tensors are both accepted.
        grid_size: number of rows/columns; used to place the separator lines.
        title: optional figure title.
    """
    fig, ax = plt.subplots()
    # Pin the colour range to [0, 1]: with autoscaling, a uniform grid (e.g. all 1s) has
    # vmin == vmax and matplotlib maps every cell to the low end, i.e. renders it white.
    ax.imshow(np.asarray(values), cmap="gray_r", vmin=0, vmax=1)

    # Pixel centres sit on integer coordinates, so cell borders are at k - 0.5.
    for k in range(grid_size + 1):
        ax.axhline(k - 0.5, color='black', linewidth=0.5)
        ax.axvline(k - 0.5, color='black', linewidth=0.5)

    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])
    plt.show()

graph = dataset[0]
# Show the sample's input grid (x) followed by its target grid (y), both as 5x5 images.
for grid_values, grid_title in (
    (graph.x.reshape(5, 5), "Game of Life Grid Input"),
    (graph.y.numpy().reshape(5, 5), "Game of Life Grid Output"),
):
    plot_grid(grid_values, 5, grid_title)



Visualize Model Output


In [None]:
import torch 
# Path of the trained checkpoint; change it to point at your own run.
MODEL_PATH = "runs/games-of-life/gol_step_2_1694696664990373000.pt"
# The checkpoint is a whole pickled model object (it is called directly below), not a
# state_dict. Newer PyTorch defaults torch.load to weights_only=True and refuses such files,
# so opt out explicitly. Unpickling can execute arbitrary code: only load files you created.
# map_location="cpu" lets a GPU-saved model load on CPU-only machines; the inputs built below
# are CPU tensors.
model = torch.load(MODEL_PATH, map_location="cpu", weights_only=False)
model.set_hardmax(True)

def map_x_to_state(xs, num_states=2):
    """One-hot encode per-node cell states as model input.

    Args:
        xs: tensor with one integer-valued state per node, shape (N,) or (N, 1).
            Float values are truncated toward zero, as ``int()`` would.
        num_states: number of possible cell states (2 for dead/alive Game of Life).

    Returns:
        int64 tensor of shape (N, num_states) with a single 1 per row at column ``xs[i]``.
        An empty input yields shape (0, num_states).

    Raises:
        ValueError: if ``xs`` does not hold exactly one value per node, or a value lies
            outside [0, num_states).
    """
    xs = torch.as_tensor(xs)
    # Exactly one scalar per node: (N,) or (N, 1, ...) with all trailing dims of size 1.
    if xs.dim() == 0 or xs.numel() != xs.shape[0]:
        raise ValueError(f"expected one state value per node, got shape {tuple(xs.shape)}")
    states = xs.reshape(-1).long()
    # Reject out-of-range states instead of letting a negative index wrap around silently.
    if states.numel() and (states.min() < 0 or states.max() >= num_states):
        raise ValueError(f"state values must lie in [0, {num_states}), got {states.tolist()}")
    return torch.nn.functional.one_hot(states, num_classes=num_states)

# Run the trained model on the sampled input grid plotted above and show its prediction.
# The state-machine cell below turns logging on, which makes the model return a tuple
# (out, all_states, transitions). Turn it off here so this cell can be re-run afterwards.
# NOTE(review): assumes set_logging(False) restores the plain single-output mode.
model.set_logging(False)
input_values = map_x_to_state(graph.x)
with torch.no_grad():
    output_values = model(input_values, graph.edge_index)
output_values = torch.argmax(output_values, dim=-1)
# Distinct title so the prediction is not confused with the ground-truth "Output" plot.
plot_grid(output_values.numpy().reshape(5, 5), 5, "Model Predicted Grid Output")

Model Evaluation - Visualize ModelGraph


In [None]:
# Extract one state machine per starting state (s0 / s1) from the trained model: run it in
# logging mode over generated GOL graphs and record every distinct transition the nodes take.
from graphviz import Digraph
from itertools import product

STATE_NAMES = ['s0', 's1']
FSM_NUM_GRAPHS = 200
FSM_GRID_SIZE = 5
FSM_STEPS = 3
# Logged transitions followed per node: state index 0 -> 1 -> 2.
FSM_TRANSITIONS_PER_NODE = 2

# Generate the dataset once, so both machines are extracted from the same samples.
# It was previously regenerated, with fresh randomness, for every starting state.
fsm_dataset = GameOfLifeGraph(num_graphs=FSM_NUM_GRAPHS, grid_size=FSM_GRID_SIZE, steps=FSM_STEPS).data

# edges_by_start[s]: unique (src, dst, label) edges, in discovery order, for nodes whose
# first logged state is s. The sets give O(1) duplicate checks.
edges_by_start = {start_id: [] for start_id in range(len(STATE_NAMES))}
seen_by_start = {start_id: set() for start_id in range(len(STATE_NAMES))}

model.set_hardmax(True)
model.set_logging(True)  # with logging on, the model returns (out, all_states, transitions)
try:
    with torch.no_grad():
        for sample in fsm_dataset:
            _, all_states, transitions = model(map_x_to_state(sample.x), sample.edge_index)
            # One model state per row of sample.x (i.e. per node).
            for node in range(len(sample.x)):
                start_id = torch.argmax(all_states[0][node], dim=-1).item()
                if start_id not in edges_by_start:
                    continue
                current_state = start_id
                for step in range(1, FSM_TRANSITIONS_PER_NODE + 1):
                    next_state = torch.argmax(all_states[step][node], dim=-1).item()
                    edge = (STATE_NAMES[current_state], STATE_NAMES[next_state],
                            str(transitions[step - 1][node].tolist()))
                    if edge not in seen_by_start[start_id]:
                        seen_by_start[start_id].add(edge)
                        edges_by_start[start_id].append(edge)
                    current_state = next_state
finally:
    # Leave the model in non-logging mode so the prediction cell above can be re-run.
    # NOTE(review): assumes set_logging(False) restores the default single-output mode.
    model.set_logging(False)

for start_id, start_name in enumerate(STATE_NAMES):
    G = Digraph(graph_attr={'rankdir': 'LR'}, node_attr={'shape': 'circle', "width": "0.8"})
    for state_name in STATE_NAMES:
        G.node(state_name, penwidth='3px')
    for src, dst, label in edges_by_start[start_id]:
        G.edge(src, dst, label)

    print("starting state", start_name)
    # One file per starting state; a single fixed name overwrote the s0 machine with s1.
    G.render(f'task0_ref_fsm_{start_name}', format='svg')
    display(G)

Statistics on the GOL dataset distribution (share of black/white target cells per grid size)

In [None]:

import torch
import numpy as np
import networkx as nx
from torch_geometric.utils import from_networkx
from torch_geometric.data import Data
import matplotlib.pyplot as plt
from datasets.gameoflife import GameOfLifeGraph

# Function to generate statistics
def gather_statistics(grid_sizes, num_samples=100, steps=1, toroidal=False):
    """Measure the fraction of black (value 1) and white (value 0) target cells per grid size.

    Args:
        grid_sizes: iterable of grid side lengths to sample.
        num_samples: number of graphs generated per grid size.
        steps: number of Game of Life steps passed to the generator.
        toroidal: passed through to the generator.

    Returns:
        dict mapping grid_size -> {"black_ratio": float, "white_ratio": float}: the fraction
        of all generated target cells (``graph.y``) equal to 1 / 0. Both ratios are NaN
        when no cells were generated (e.g. ``num_samples=0``).
    """
    stats = {}

    for grid_size in grid_sizes:
        game = GameOfLifeGraph(grid_size=grid_size, num_graphs=num_samples, steps=steps, toroidal=toroidal)

        black_count = white_count = total_cells = 0
        for graph in game.data:
            black_count += torch.sum(graph.y == 1).item()
            white_count += torch.sum(graph.y == 0).item()
            # Count the cells actually generated instead of assuming grid_size**2 * num_samples.
            total_cells += graph.y.numel()

        if total_cells == 0:
            black_ratio = white_ratio = float("nan")
        else:
            black_ratio = black_count / total_cells
            white_ratio = white_count / total_cells
        stats[grid_size] = {
            "black_ratio": black_ratio,
            "white_ratio": white_ratio,
        }

    return stats

STEPS = 1

# Grid sizes to sample; each point is the percentage of black (1) / white (0) target cells.
grid_sizes = [3, 4, 5, 6, 8, 10, 15, 20]
statistics = gather_statistics(grid_sizes, num_samples=10, steps=STEPS, toroidal=False)

black_ratios = [100 * statistics[size]["black_ratio"] for size in grid_sizes]
white_ratios = [100 * statistics[size]["white_ratio"] for size in grid_sizes]

fig, ax = plt.subplots(figsize=(10, 6))
for series, series_label in ((black_ratios, "Black Ratio"), (white_ratios, "White Ratio")):
    ax.plot(grid_sizes, series, marker='o', label=series_label)

ax.set_xticks(range(3, 21))
ax.set_yticks(range(0, 81, 10))

ax.set_xlabel('Grid Size', fontsize=30)
ax.set_ylabel('Ratios in %', fontsize=30)
# ax.set_title(f'Black/White Ratios for Different Grid Sizes after {STEPS} Steps')
ax.legend(fontsize=30)
ax.grid(True)
plt.show()


# HexagonalGameOfLife

Visualization

In [None]:
# Visualization
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from datasets.gameoflife import HexagonalGameOfLifeGraph

def plot_hex_grid(values, grid_size, title=None):
    """Render a grid_size x grid_size hexagonal Game of Life board.

    Args:
        values: 2-D indexable of cell values, addressed as ``values[row, col]``;
            0 is drawn light grey (#D3D3D3), any other value black.
        grid_size: number of rows and columns.
        title: optional figure title.
    """
    fig, ax = plt.subplots()
    ax.set_aspect('equal')

    # Hand-tuned spacing between hexagon centres (the hexagon radius is 0.85).
    col_step = 3 / 2
    row_step = 4 * np.sqrt(3) / 5
    even_row_shift = 7 * np.sqrt(3) / 16

    for row in range(grid_size):
        centre_y = -row * row_step
        for col in range(grid_size):
            centre_x = col * col_step
            # Even rows are nudged left so neighbouring rows interlock.
            if row % 2 == 0:
                centre_x -= even_row_shift
            face = 'black' if values[row, col] != 0 else '#D3D3D3'
            ax.add_patch(patches.RegularPolygon(
                (centre_x, centre_y), numVertices=6, radius=0.85,
                facecolor=face, edgecolor='black', linewidth=0.5,
            ))

    ax.axis('off')
    ax.autoscale_view()
    ax.set_title(title)
    plt.show()

# Sample and plot one 5x5 hexagonal GOL example. The hex_* names avoid overwriting the
# square-grid `dataset`/`graph` that the model-prediction cell reads when it is re-run.
HEX_GRID_SIZE = 5
hex_dataset = HexagonalGameOfLifeGraph(num_graphs=1, grid_size=HEX_GRID_SIZE, steps=1).data

hex_graph = hex_dataset[0]
plot_hex_grid(hex_graph.x.reshape(HEX_GRID_SIZE, HEX_GRID_SIZE).numpy(), HEX_GRID_SIZE, "Game of Life Grid Input")
plot_hex_grid(hex_graph.y.numpy().reshape(HEX_GRID_SIZE, HEX_GRID_SIZE), HEX_GRID_SIZE, "Game of Life Grid Output")

Visualize the neighbourhood structure (graph edges) of the hexagonal GOL

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.patches as patches
import torch

def plot_hex_grid_with_edges(values, edge_index, grid_size, title=None):
    """Draw a hexagonal GOL board and overlay its graph edges (centre to centre, in red).

    Args:
        values: 2-D cell values addressed as ``values[row, col]``; 0 is drawn light grey,
            any other value black.
        edge_index: (2, E) integer tensor of (source, target) node ids. Node id ``n`` is
            placed at row ``n // grid_size``, column ``n % grid_size`` (row-major, the
            inverse of the ``reshape(grid_size, grid_size)`` callers apply to node values).
        grid_size: number of rows and columns.
        title: optional figure title.
    """
    fig, ax = plt.subplots()
    ax.set_aspect('equal')

    # Hand-tuned spacing between hexagon centres (the hexagon radius is 0.85); kept equal
    # to plot_hex_grid so both figures share the same layout.
    dx = 3/2
    dy = 4*np.sqrt(3)/5
    d_row = 7*np.sqrt(3)/16

    def get_pos(row, col):
        """Centre (x, y) of the hexagon at (row, col); even rows are shifted left."""
        x = col * dx
        y = -row * dy
        if row % 2 == 0:
            x -= d_row
        return (x, y)

    # Draw the hexagons.
    for row in range(grid_size):
        for col in range(grid_size):
            hexagon = patches.RegularPolygon(get_pos(row, col), numVertices=6, radius=0.85,
                                             facecolor='#D3D3D3' if values[row, col] == 0 else 'black',
                                             edgecolor='black', linewidth=0.5)
            ax.add_patch(hexagon)

    # edge_index may list a neighbour pair in both directions; draw each pair only once.
    drawn = set()
    for start, end in edge_index.t().tolist():
        pair = (min(start, end), max(start, end))
        if pair in drawn:
            continue
        drawn.add(pair)
        # divmod(node_id, grid_size) -> (row, col)
        start_x, start_y = get_pos(*divmod(start, grid_size))
        end_x, end_y = get_pos(*divmod(end, grid_size))
        ax.plot([start_x, end_x], [start_y, end_y], color='red', linewidth=0.5)

    ax.axis('off')
    ax.autoscale_view()
    ax.set_title(title)
    plt.show()

# Example usage:
dataset = HexagonalGameOfLifeGraph(num_graphs=1, grid_size=5, steps=1).data
graph = dataset[0]
plot_hex_grid_with_edges(graph.x.reshape(5, 5), graph.edge_index, 5, "Game of Life Grid with Edges")
