In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

import numpy as np

import matplotlib.pyplot as plt

In [None]:
def get_input_layer(board) -> np.ndarray:
    """One-hot encode a board into three planes stacked on the last axis.

    Args:
        board: Array-like of cell values (NumPy array or nested lists).
            The values 1, 2 and 0 each get their own plane.

    Returns:
        Float array of shape ``board.shape + (3,)``. Channel 0 is 1.0 where
        the cell equals 1, channel 1 where it equals 2, and channel 2 where
        it equals 0; every other entry is 0.0.
    """
    # asarray keeps the comparisons element-wise for plain Python lists too;
    # a bare `some_list == 1` would collapse to a single bool.
    cells = np.asarray(board)
    # Channel order (1, 2, 0) matches the r, g, b planes used for display.
    return np.stack([(cells == value).astype(float) for value in (1, 2, 0)], axis=-1)

# Example position: a 6x7 grid whose cells hold 0, 1 or 2.
board = np.array([
    [0, 0, 0, 0, 0, 2, 0],
    [0, 0, 0, 0, 0, 1, 2],
    [0, 1, 0, 0, 0, 1, 1],
    [0, 2, 0, 0, 1, 2, 2],
    [0, 2, 0, 0, 2, 2, 1],
    [1, 2, 0, 0, 1, 1, 2]
])

# One-hot planes for (cell == 1, cell == 2, cell == 0), stacked on the last
# axis, so the array can be drawn as an RGB image: 1 -> red, 2 -> green,
# 0 -> blue.
planes = get_input_layer(board)
plt.imshow(planes)
# Last expression of the cell: echoes the raw board as the cell's output.
board

In [None]:
in_channels = 3
out_channels = 3
kernel_size = 2

# Convolution over the three board planes. Its weight tensor has shape
# (out_channels, in_channels, kernel_size, kernel_size). The values shown here
# are PyTorch's default random initialisation, so the picture varies between
# runs unless a seed has been set.
conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=1)

# Kernels feeding output channel 0, one per input channel:
# shape (in_channels, kernel_size, kernel_size).
weights = conv.weight.detach().numpy()[0]
# Min-max scale into [0, 1] so the values are valid RGB intensities.
normalised_weights = (weights - np.min(weights)) / (np.max(weights) - np.min(weights))

# Move the channel axis last: imshow expects (rows, cols, RGB).
plt.imshow(normalised_weights.transpose(1, 2, 0))

conv.weight.shape

In [None]:
# Hand-written replacement weights, indexed
# [output channel][input channel][kernel row][kernel col].
# Every kernel is the 2x2 diagonal pattern except green -> green and
# blue -> blue, which are all ones.
t_new = torch.tensor([
    # --- kernels producing the red output channel ---
    [[[1.0, 0.0],     # red -> red (1 to 1)
      [0.0, 1.0]],

     [[1.0, 0.0],     # green -> red (2 to 1)
      [0.0, 1.0]],

     [[1.0, 0.0],     # blue -> red (3 to 1)
      [0.0, 1.0]]],

    # --- kernels producing the green output channel ---
    [[[1.0, 0.0],     # red -> green (1 to 2)
      [0.0, 1.0]],

     [[1.0, 1.0],     # green -> green (2 to 2)
      [1.0, 1.0]],

     [[1.0, 0.0],     # blue -> green (3 to 2)
      [0.0, 1.0]]],

    # --- kernels producing the blue output channel ---
    [[[1.0, 0.0],     # red -> blue (1 to 3)
      [0.0, 1.0]],

     [[1.0, 0.0],     # green -> blue (2 to 3)
      [0.0, 1.0]],

     [[1.0, 1.0],     # blue -> blue (3 to 3)
      [1.0, 1.0]]],
], requires_grad=True)

# Write the values into the layer's existing parameters. Doing it under
# no_grad keeps the in-place writes out of the autograd graph.
with torch.no_grad():
    conv.weight.copy_(t_new)
    conv.bias.zero_()

# One tile per (output channel, input channel) pair. squeeze=False keeps `axs`
# two-dimensional even when one of the counts is 1, so axs[i, j] always works.
fig, axs = plt.subplots(out_channels, in_channels, squeeze=False)

label = ['r', 'g', 'b', 'k']

kernels = conv.weight.detach().numpy()
# Shared colour limits for all tiles. With per-tile autoscaling a constant
# kernel (e.g. all ones) has vmin == vmax and is drawn as one flat colour,
# making it look the same as an all-zero kernel.
vmin, vmax = kernels.min(), kernels.max()

# Rows are output channels, columns are input channels.
for i in range(out_channels):
    for j in range(in_channels):
        axs[i, j].imshow(kernels[i, j], cmap='gray', vmin=vmin, vmax=vmax)
        axs[i, j].set_title(f"({label[j]} -> {label[i]})")
        axs[i, j].set(xticks=[], yticks=[], xlabel='', ylabel='')

# Negative wspace pulls the small square tiles closer together horizontally.
plt.subplots_adjust(wspace=-0.5, hspace=0.5)
fig.text(0.5, 0.02, 'RGB input channels', ha='center')
fig.text(0.02, 0.5, 'RGB output channels', va='center', rotation='vertical')
plt.show()

In [None]:
# Channels-first float32 copy of the board planes: (H, W, C) -> (C, H, W).
tensor_planes = torch.tensor(planes.transpose(2, 0, 1), dtype=torch.float32)

# Run the layer without building an autograd graph. The batch axis is added
# (and removed again) explicitly because Conv2d's canonical input is
# (N, C, H, W); older PyTorch releases reject unbatched 3-D input.
with torch.no_grad():
    activations = conv(tensor_planes.unsqueeze(0))[0]

# Back to channels-last for display: (C, H', W') -> (H', W', C).
after = activations.numpy().transpose(1, 2, 0)

# Min-max scale into [0, 1] for imshow. A constant activation map has zero
# range; show it as all zeros instead of dividing by zero.
low = after.min()
span = after.max() - low
normalised_after = (after - low) / span if span > 0 else np.zeros_like(after)
plt.imshow(normalised_after)