In [None]:
# Run once

%load_ext autoreload
%autoreload 2

%matplotlib ipympl

import math
import repl
import torch
import matplotlib
import matplotlib.pyplot
import matplotlib.collections
from PIL import Image

# The hook currently installed by TorchHookContextManager, or None when no
# hook is active. Read by the_one_and_only_hook each time it is invoked.
hook_to_call = None

class TorchHookContextManager():
    """Context manager that makes `hook` the active hook for its `with` body.

    On entry the module-level `hook_to_call` is pointed at `hook`; on exit the
    hook that was active before entry is put back (None when there was none),
    so nested `with` blocks restore the outer hook instead of clearing it.

    `hook` is called as `hook(module, input, output)`.
    """
    def __init__(self, hook):
        self.hook = hook
        # Hooks displaced by __enter__, most recent last. A list (rather than a
        # single slot) keeps re-entering the same manager instance correct.
        self._previous = []
    def __enter__(self):
        global hook_to_call
        self._previous.append(hook_to_call)
        hook_to_call = self.hook
        # Return the manager so `with ... as mgr` binds something useful.
        return self
    def __exit__(self, type, value, traceback):
        global hook_to_call
        # Fall back to None if __exit__ is called without a matching __enter__.
        hook_to_call = self._previous.pop() if self._previous else None
        # Implicitly returns None: exceptions raised in the body propagate.

def the_one_and_only_hook(module, input, output):
    """Forward the call to the currently active hook, if there is one.

    Does nothing when the module-level `hook_to_call` is falsy. The active
    hook's return value is discarded; this function always returns None.
    """
    active_hook = hook_to_call
    if not active_hook:
        return
    active_hook(module, input, output)

# Register the dispatcher through torch's module-level forward-hook API; the
# actual work is delegated to whatever `hook_to_call` currently points at.
# NOTE(review): re-running this cell presumably registers another copy of the
# dispatcher (the returned handle is not kept) - hence "Run once" above.
torch.nn.modules.module.register_module_forward_hook(the_one_and_only_hook)

In [None]:
# Handle used by the later cells to issue commands via the_repl.run(...).
the_repl = repl.Repl()

# Grid size used for the hex overlay: NN_W columns by NN_H rows.
NN_W = 16
NN_H = 16
# Bounding box of the overlay in plot coordinates. hexplot() centres cell
# (x, y) at (x*2 + y, sqrt(3)*y), so the extreme centres are (0, 0) and
# ((NN_W-1)*2 + (NN_H-1), (NN_H-1)*sqrt(3)); a margin of 1 is added on each side.
gfx_pos_bb_left = -1
gfx_pos_bb_right = 1 + (NN_W-1) * 2 + (NN_H-1)
gfx_pos_bb_bottom = -1
gfx_pos_bb_top = 1 + (NN_H-1) * math.sqrt(3)
# Centre and extent of that bounding box.
gfx_pos_centre_x = (gfx_pos_bb_left + gfx_pos_bb_right) / 2
gfx_pos_centre_y = (gfx_pos_bb_bottom + gfx_pos_bb_top) / 2
gfx_width = gfx_pos_bb_right - gfx_pos_bb_left
gfx_height = gfx_pos_bb_top - gfx_pos_bb_bottom
# Multiplier applied to gfx_width/gfx_height to get the rendered image size
# (presumably pixels per plot unit - confirm against the repl 'render' command).
gfx_scale = 25

In [None]:
# Run once per model to investigate

import model
# Build a fresh Python-side model, then copy weights from the saved TorchScript
# file (loaded onto the CPU).
the_model = model.model_v1_params_v1()
saved_model = torch.jit.load('../test/net/mainline/1200000.pt', map_location=torch.device('cpu'))
# Sanity check: both models must have the same sequence of parameter sizes.
# This compares element counts only, not names or shapes.
assert([p.numel() for p in the_model.parameters()] == [p.numel() for p in saved_model.parameters()])
the_model.load_state_dict(saved_model.state_dict())
# NOTE(review): the model is never switched to eval() here - confirm whether
# train-mode behaviour is intended for the forward passes run later.

In [None]:
# Issue the repl 'load' command with the recentre option; the commented-out
# entry below is an optional extra argument that can be re-enabled by hand.
the_repl.run(['load', '--recentre-to-fit-within-nn-bounds',
             #'--rotate', '5'
             ])

In [None]:
# Issue the repl command 'cycle 140' (its meaning is defined by the repl module).
the_repl.run('cycle 140')

# Render to /tmp/output.png. The offset is the negated centre of the overlay
# bounding box, and the size is the bounding box extent times gfx_scale,
# rounded up to whole numbers and formatted as "<width>x<height>".
the_repl.run(['render',
          '--filename', '/tmp/output.png',
          '--offset', ','.join(map(str, [-gfx_pos_centre_x, -gfx_pos_centre_y])),
          '--scale', str(gfx_scale),
          '--size', 'x'.join([str(int(math.ceil(f))) for f in [gfx_width * gfx_scale, gfx_height * gfx_scale]])])

# Save the network inputs to /tmp/inputs.npz; a later cell loads this file.
the_repl.run(['nn.save_input', '/tmp/inputs.npz'])

In [None]:
# Load the image written by the 'render' command and flip it vertically.
img = Image.open('/tmp/output.png')
img = img.transpose(Image.Transpose.FLIP_TOP_BOTTOM)
fig = matplotlib.pyplot.figure()
ax = fig.subplots()
# Draw the flipped image with origin='lower' (row 0 at the bottom), stretched
# over the overlay bounding box so it shares coordinates with the hex patches.
ax.imshow(img, origin='lower', extent=(gfx_pos_bb_left, gfx_pos_bb_right, gfx_pos_bb_bottom, gfx_pos_bb_top))

def hexplot(ax):
    """Build a half-transparent collection of hexagons, one per grid cell.

    Cell (x, y) is centred at (x*2 + y, sqrt(3)*y) with radius sqrt(4/3).
    Patches are emitted with x as the outer loop and y as the inner loop, for
    x in range(NN_W) and y in range(NN_H).

    `ax` is accepted but not used; the caller adds the returned
    PatchCollection to its axes.
    """
    hex_radius = math.sqrt(4/3)
    hexagons = []
    for x in range(NN_W):
        for y in range(NN_H):
            centre = (x*2 + y, math.sqrt(3) * y)
            hexagons.append(matplotlib.patches.RegularPolygon(centre, 6, radius=hex_radius))
    collection = matplotlib.collections.PatchCollection(hexagons)
    collection.set_alpha(0.5)
    return collection

import numpy as np
# Load the inputs saved by the 'nn.save_input' repl command. The keys read
# later in this cell are 'spatial', 'spatiotemporal', 'temporal' and
# 'policy_softmax_temperature'.
inputs = np.load('/tmp/inputs.npz')
# Bare expression: this is not the last statement of the cell, so it has no
# visible effect.
inputs

def get_intermediate_layer(layer_id):
    """Run the model on the saved inputs and capture the input to one resblock.

    A temporary forward hook records the single positional input passed to
    `the_model.trunk.resblocks[layer_id]` during one forward pass over the
    arrays in `inputs`.

    Args:
        layer_id: index into `the_model.trunk.resblocks`.

    Returns:
        Element 0 along the first axis of the captured input, as a numpy array
        (presumably the first and only batch entry - confirm against the model).

    Raises:
        RuntimeError: if the selected module was never called during the pass.
    """
    # Resolve the target once rather than on every hook invocation.
    target_module = the_model.trunk.resblocks[layer_id]
    intermediate_tensor = None
    def hook(module, input, output):
        nonlocal intermediate_tensor
        if module is target_module:
            # The module is expected to receive exactly one positional input.
            (i,) = input
            intermediate_tensor = i.detach().numpy()
    # Only the captured activation is needed, so skip building an autograd graph.
    with TorchHookContextManager(hook), torch.no_grad():
        the_model(torch.from_numpy(inputs['spatial']),
                  torch.from_numpy(inputs['spatiotemporal']),
                  torch.from_numpy(inputs['temporal']),
                  torch.from_numpy(inputs['policy_softmax_temperature']))
    if intermediate_tensor is None:
        raise RuntimeError("Hook didn't execute for given layer")
    print(intermediate_tensor.shape)
    return intermediate_tensor[0]

# Choose what to visualise. update() indexes `data` as
# data[dimension_id, time_id, y, x]. The commented-out lines are alternative
# sources taken straight from the saved inputs.
#data = inputs['spatiotemporal'][0]
#data = np.expand_dims(inputs['spatial'][0], axis=1)
data = get_intermediate_layer(2)

# Overlay the hex collection on the rendered image and attach a colorbar to it.
hp = hexplot(ax)
ax.add_collection(hp)
fig.colorbar(hp, ax=ax, orientation='horizontal')

# Currently displayed slice of `data`; changed by the slider callbacks.
dimension_id = 0
time_id = 0
def update():
    """Recolour the hex overlay from data[dimension_id, time_id] and redraw.

    Reads the module-level `dimension_id` and `time_id` (no `global` needed,
    since they are only read here), pushes the selected values into `hp`,
    rescales the colour limits and prints them.
    """
    # Slider callbacks may deliver floats depending on the matplotlib/numpy
    # version; array indices must be integers.
    d = int(dimension_id)
    t = int(time_id)
    # Same traversal order as the patches built in hexplot(): x outer, y inner.
    hp.set_array([data[d, t, y, x] for x in range(NN_W) for y in range(NN_H)])
    hp.autoscale()
    print('{} {}'.format(hp.norm.vmin, hp.norm.vmax))
    # Special case: limits of exactly (-0.1, 0.1) are replaced by (0, 1).
    # NOTE(review): presumably these are the limits matplotlib substitutes when
    # every value is zero - confirm.
    if hp.norm.vmin == -0.1 and hp.norm.vmax == 0.1:
        hp.norm.vmin = 0
        hp.norm.vmax = 1
    fig.canvas.draw_idle()
update()

# Make a vertically oriented slider to control the dimension_id
ax_dimension = fig.add_axes([0.06, 0.25, 0.0225, 0.63])
dimension_slider = matplotlib.widgets.Slider(
    ax=ax_dimension,
    label="Dim",
    valmin=0,
    closedmin=True,
    # Half-open range [0, data.shape[0]): one step per index along axis 0.
    valmax=data.shape[0],
    closedmax=False,
    valstep=1,
    valinit=dimension_id,
    orientation="vertical"
)
def update_dimension(d):
    """Slider callback: select index `d` along axis 0 of `data` and redraw."""
    global dimension_id
    # The slider may hand over a float; store an int so it is a valid index.
    dimension_id = int(d)
    update()
dimension_slider.on_changed(update_dimension)

# Make a vertically oriented slider to control the time_id
ax_time = fig.add_axes([0.10, 0.25, 0.0225, 0.63])
time_slider = matplotlib.widgets.Slider(
    ax=ax_time,
    label="Time",
    valmin=0,
    closedmin=True,
    # Half-open range [0, data.shape[1]): one step per index along axis 1.
    valmax=data.shape[1],
    closedmax=False,
    valstep=1,
    valinit=time_id,
    orientation="vertical"
)
def update_time(t):
    """Slider callback: select index `t` along axis 1 of `data` and redraw."""
    global time_id
    # The slider may hand over a float; store an int so it is a valid index.
    time_id = int(t)
    update()
time_slider.on_changed(update_time)

matplotlib.pyplot.show()

In [None]:
# Close the repl and drop the name; the cell that creates `the_repl` must be
# re-run before any further the_repl.run(...) call.
the_repl.close()
del the_repl