In [1]:
import numpy as np
import time
from ipycanvas import MultiCanvas, hold_canvas
from ipywidgets import Button, GridspecLayout, Output, IntSlider

import torch
from torch import nn
from MinimalSolution import MinNet

In [2]:
# 32x32 Game-of-Life field (all cells dead) seeded with a 5-cell glider
# occupying rows 1-3 / columns 1-3. Live cells are (row, col):
# (1,3), (2,3), (3,3), (3,2), (2,1).
glider = np.zeros((32, 32))
glider[[1, 2, 3, 3, 2], [3, 3, 3, 2, 1]] = 1

In [8]:
# 32x32 field holding a pulsar (48 live cells), built from two bands of
# 3-cell runs. np.ix_ forms the outer product of the row and column lists,
# so every (row, col) pair in each band is set.
pulsar = np.zeros((32, 32))
# 3-cell runs along axis 0: rows 7-9 and 13-15 in columns 6, 11, 13, 18.
pulsar[np.ix_([7, 8, 9, 13, 14, 15], [6, 11, 13, 18])] = 1
# 3-cell runs along axis 1: columns 8-10 and 14-16 in rows 5, 10, 12, 17.
pulsar[np.ix_([5, 10, 12, 17], [8, 9, 10, 14, 15, 16])] = 1

In [3]:
# Random 32x32 starting field: each cell is independently 0 or 1.
# A fixed seed makes the field reproducible under Restart & Run All.
# (The name shadows the stdlib `random` module; kept for compatibility.)
random = np.random.default_rng(seed=0).integers(0, 2, size=(32, 32))

In [4]:
class GOLCNNInspector:
    """Interactive viewer for stepping a Game-of-Life CNN through time.

    Draws the current field on a two-layer ipycanvas ``MultiCanvas`` and shows
    a Reset button plus an ``IntSlider`` that selects which precomputed step is
    displayed. Clicking a cell toggles it and recomputes every step from the
    edited field.

    Coordinate convention: the first array index of the field is drawn along
    the canvas x axis and the second along the y axis. ``render`` and
    ``handle_mouse_event`` use the same mapping.

    Parameters
    ----------
    golcnn : model object
        Must expose ``num_steps``, ``overcompleteness_factor`` and
        ``inspect_step(tensor, i)``. Element ``[0][0]`` of the returned value
        is used as the field after ``i`` steps.
    init_state_param : np.ndarray
        2-D starting field. It is copied, never mutated.
    c_width, c_height : int
        Canvas size in pixels.
    """

    def __init__(self, golcnn, init_state_param, c_width=512, c_height=512):
        self.init_state = init_state_param.copy()
        self.game_state = init_state_param.copy()
        # Axis 0 is drawn along x (see class docstring), hence "width".
        self.field_width, self.field_height = np.shape(self.init_state)

        self.model = golcnn

        self.step_list = []
        self.compute_step_list()

        # Output widget that collects prints/tracebacks from the mouse handler.
        self.out = Output()

        # Layer 0: static background; layer 1: live cells (redrawn by render()).
        self.canvas_width = c_width
        self.canvas_height = c_height
        self.multicanvas = MultiCanvas(2, width=self.canvas_width, height=self.canvas_height)

        self.multicanvas[0].fill_style = '#eceff4'
        self.multicanvas[0].fill_rect(0, 0, self.multicanvas.width, self.multicanvas.height)

        @self.out.capture()
        def handle_mouse_up(x, y):
            self.handle_mouse_event(x, y)

        self.multicanvas.on_mouse_up(handle_mouse_up)

        def reset_button_on_click(b):
            self.reset()

        def handle_slider_change(b):
            self.change_and_render(self.intslider.value)

        self.reset_button = Button(
            description='Reset',
            disabled=False,
            tooltip='Reset',
        )
        self.reset_button.on_click(reset_button_on_click)

        self.intslider = IntSlider(
            value=0,
            min=0,
            max=self.model.num_steps,
            step=1,
            description='Step:',
            disabled=False,
            continuous_update=False,
            orientation='horizontal',
            readout=True,
            readout_format='d'
        )
        self.intslider.observe(handle_slider_change, names='value')

        self.ui = GridspecLayout(1, 2)
        self.ui[0, 0] = self.reset_button
        self.ui[0, 1] = self.intslider

        self.render()

        display(self.multicanvas)
        display(self.ui)
        display(self.out)

    def render(self):
        """Redraw the live-cell layer (layer 1) from ``self.game_state``.

        Axis 0 of the field is scaled by the canvas width and axis 1 by the
        canvas height. This is the inverse of the mapping used in
        ``handle_mouse_event``, so clicks land on the drawn cells even for
        non-square canvases.
        """
        layer = self.multicanvas[1]
        with hold_canvas(layer):
            layer.clear()
            layer.fill_style = '#3b4252'

            m, n = np.shape(self.game_state)
            cell_w = layer.width / m   # pixels per cell along axis 0 (x)
            cell_h = layer.height / n  # pixels per cell along axis 1 (y)

            # argwhere yields (row, col) of every truthy (non-zero) cell.
            for r, c in np.argwhere(self.game_state).tolist():
                layer.fill_rect(r * cell_w, c * cell_h, cell_w, cell_h)

    def change_and_render(self, n):
        """Display precomputed step ``n`` (requires 0 <= n <= model.num_steps).

        A copy is stored so that later in-place edits (mouse clicks) never
        mutate an entry of ``step_list``.
        """
        self.game_state = self.step_list[n].copy()
        self.render()

    def reset(self):
        """Restore the initial field, rebuild the step list and go back to step 0."""
        self.game_state = self.init_state.copy()
        # step_list may have been built from an edited field. Rebuild it first,
        # because setting the slider fires change_and_render(0), which reads
        # step_list[0].
        self.compute_step_list()
        self.intslider.value = 0
        self.render()

    def handle_mouse_event(self, x, y):
        """Toggle the cell under canvas pixel (x, y) and recompute all steps.

        NOTE(review): the edited field becomes step 0 of the rebuilt step list,
        while the slider keeps its current value.
        """
        cell_w = self.canvas_width / self.field_width
        cell_h = self.canvas_height / self.field_height
        # Clamp so a click on the far edge (x == canvas_width) stays in range.
        field_x = min(max(int(np.floor(x / cell_w)), 0), self.field_width - 1)
        field_y = min(max(int(np.floor(y / cell_h)), 0), self.field_height - 1)
        self.game_state[field_x, field_y] = not self.game_state[field_x, field_y]
        self.compute_step_list()
        self.render()

    def compute_step_list(self):
        """Rebuild ``self.step_list`` from the current ``self.game_state``.

        ``step_list[0]`` is a copy of the current field.
        ``step_list[i]`` (1 <= i <= model.num_steps) is
        ``model.inspect_step(field, i)[0][0]`` rounded to the nearest integer.
        Each step is requested from the current field, not from the previous
        entry.
        """
        self.step_list = [self.game_state.copy()]
        # Shape (1, 1, H, W) float32. Built once because it is the same for
        # every step; np.array copies, so the tensor does not share game_state.
        base = torch.from_numpy(np.array([[self.game_state]])).type(torch.float32)
        with torch.no_grad():  # inference only; no autograd graph needed
            for i in range(1, self.model.num_steps + 1):
                # repeat() allocates a fresh tensor, so every call gets its own input.
                model_in = base.repeat(1, self.model.overcompleteness_factor, 1, 1)
                model_out = self.model.inspect_step(model_in, i)
                self.step_list.append(model_out[0][0].detach().numpy().round())

In [5]:
mynet = MinNet(2)

# map_location='cpu' maps tensors onto the CPU while loading, so a checkpoint
# saved from a GPU session also loads on a CPU-only machine.
# NOTE: this file is a pickled full nn.Module, and unpickling can run arbitrary
# code, so only load trusted files. On torch versions where weights_only
# defaults to True, loading a full module may additionally need
# weights_only=False.
testnet = torch.load('./models/op_m16_n2_model1.pt', map_location='cpu')
testnet

OPNet(
  (type1_layers): ModuleList(
    (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=circular)
    (1): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode=circular)
  )
  (type2_layers): ModuleList(
    (0): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
    (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
  )
  (relu1): ReLU()
  (relu2): ReLU()
  (final_conv_layer): Conv2d(16, 1, kernel_size=(1, 1), stride=(1, 1))
  (sigmoid3): Sigmoid()
)

In [9]:
# Interactive inspector: hand-built MinNet evolving the pulsar field.
inspector = GOLCNNInspector(golcnn=mynet, init_state_param=pulsar)

MultiCanvas(height=512, width=512)

GridspecLayout(children=(Button(description='Reset', layout=Layout(grid_area='widget001'), style=ButtonStyle()…

Output()

In [10]:
# Same pulsar field, evolved by the network loaded from disk for comparison.
inspector2 = GOLCNNInspector(golcnn=testnet, init_state_param=pulsar)

MultiCanvas(height=512, width=512)

GridspecLayout(children=(Button(description='Reset', layout=Layout(grid_area='widget001'), style=ButtonStyle()…

Output()