In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

from pypad.games import ConnectX, TicTacToe

# Describing TicTacToe with a Convolutional Layer

In [None]:
# Build a TicTacToe position from a fixed move string, then draw and display it.
# NOTE(review): "0,1,5" is presumably a comma-separated move history understood
# by `initial_state`; confirm the encoding against pypad's TicTacToe.
tictactoe_moves = "0,1,5"

tictactoe = TicTacToe()
tictactoe_state = tictactoe.initial_state(tictactoe_moves)

tictactoe_state.plot()
# Bare last expression: Jupyter also renders the state's repr in the cell output.
tictactoe_state

# Describing ConnectX with a Convolutional Layer

In [None]:
# Replay a fixed ConnectX move sequence, then draw and display the resulting position.
# NOTE(review): the entries (values 1-7) look like column choices consumed by
# `initial_state`; confirm the indexing convention against pypad's ConnectX.
connectx_moves = [3, 3, 4, 5, 3, 4, 5, 5, 6, 1, 1, 1, 1, 1, 1, 2, 6, 6, 7, 7]

connectx = ConnectX()
connectx_state = connectx.initial_state(connectx_moves)

connectx_state.plot()
# Bare last expression: Jupyter also renders the state's repr in the cell output.
connectx_state

# Understanding the action of a convolutional layer using ConnectX

In [None]:
in_channels = 3   # e.g. three RGB channels
out_channels = 3  # e.g. three RGB channels
# Odd kernel sizes are usually preferred: with padding = (kernel_size - 1) // 2
# the spatial size is preserved. With this even kernel_size of 2 and padding=1,
# Conv2d's output is one pixel larger than its input in each spatial dimension.
kernel_size = 2
padding = 1

# Create a convolutional layer with (out_channels x in_channels x kernel_size x kernel_size) weights
conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=padding)

# Hand-crafted weights, indexed as t_new[output_channel, input_channel, row, col].
# Only the values are needed (they are copied into conv.weight below), so no
# requires_grad flag is set here.
t_new = torch.tensor([
    # ---- output channel 1: red ----
    [
        # red to red (1 to 1)
        [[0.0000, 0.0000],
         [0.0000, 0.0000]],

        # green to red (2 to 1)
        [[0.0001, 0.0000],
         [0.0000, 0.0001]],

        # blue to red (3 to 1)
        [[1.0000, 0.9000],
         [0.9000, 1.0000]],
    ],
    # ---- output channel 2: green ----
    [
        # red to green (1 to 2)
        [[0.0000, 0.5000],
         [0.0000, 0.5000]],

        # green to green (2 to 2)
        [[1.0000, 0.9000],
         [0.9000, 1.0000]],

        # blue to green (3 to 2)
        [[0.0000, 0.0000],
         [0.0000, 0.0000]],
    ],
    # ---- output channel 3: blue ----
    [
        # red to blue (1 to 3)
        [[0.0000, 0.0000],
         [0.0000, 0.0000]],

        # green to blue (2 to 3)
        [[0.0000, 0.0000],
         [0.0000, 0.0000]],

        # blue to blue (3 to 3)
        [[0.0000, 0.0000],
         [0.0000, 0.0000]],
    ],
])

# Fail loudly if the hyper-parameters above are edited without updating t_new;
# assigning through `.data` would silently install a mis-shaped weight instead.
assert t_new.shape == conv.weight.shape, (
    f"t_new has shape {tuple(t_new.shape)}, "
    f"but conv.weight expects {tuple(conv.weight.shape)}"
)

# Overwrite the parameters in place without recording the copy in autograd.
with torch.no_grad():
    conv.weight.copy_(t_new)
    conv.bias.zero_()

# squeeze=False keeps `axs` 2-D even when a channel count is 1.
fig, axs = plt.subplots(out_channels, in_channels, squeeze=False)

label = ['r', 'g', 'b', 'k']

# One subplot per (output, input) channel pair: rows are outputs, columns are inputs.
for i in range(out_channels):
    for j in range(in_channels):
        current_image = conv.weight[i, j].detach().numpy()

        # Fall back to the channel index once the colour names run out.
        in_label = label[j] if j < len(label) else str(j)
        out_label = label[i] if i < len(label) else str(i)

        axs[i, j].imshow(current_image, cmap='gray', vmin=0, vmax=1)
        axs[i, j].set_title(f"({in_label} -> {out_label})")
        axs[i, j].set(xticks=[], yticks=[], xlabel='', ylabel='')

fig.subplots_adjust(wspace=-0.5, hspace=0.5)
fig.text(0.5, 0.02, 'RGB input channels', ha='center')
fig.text(0.02, 0.5, 'RGB output channels', va='center', rotation='vertical')
plt.show()

In [None]:
# Push the ConnectX feature planes through the hand-crafted conv layer and view
# its output channels as a single RGB image.
planes = connectx_state.to_feature()
# Contiguous float32 copy: matches conv's weight dtype and avoids the legacy
# torch.Tensor constructor (and torch's rejection of negative-stride arrays).
tensor_planes = torch.as_tensor(np.ascontiguousarray(planes, dtype=np.float32))

with torch.no_grad():  # inference only; no autograd graph needed
    after = conv(tensor_planes).detach().numpy()

# The transpose below treats `after` as (channels, height, width); imshow needs
# 3 (RGB) or 4 (RGBA) channels to render it as a colour image.
assert after.ndim == 3 and after.shape[0] in (3, 4), (
    f"expected (3 or 4, H, W) conv output for an RGB(A) image, got {after.shape}"
)

# Min-max scale into [0, 1]. A constant output (e.g. all-zero planes) would make
# this 0/0 = NaN, so render it as black instead.
after_min = np.min(after)
value_range = np.max(after) - after_min
if value_range > 0:
    normalised_after = (after - after_min) / value_range
else:
    normalised_after = np.zeros_like(after)

fig_after, ax_after = plt.subplots()
ax_after.imshow(normalised_after.transpose(1, 2, 0))  # (C, H, W) -> (H, W, C)
ax_after.set_axis_off()
ax_after.set_title("ConnectX feature planes after the convolution")
plt.show()