Notebook that runs the model in pure Python (NumPy only).

Template for the C implementation of the model.

In [1]:
from typing import Iterable, Dict

import numpy as np

In [2]:
# Maps a weight name (e.g. "conv0_kernel", "dense_bias") to its loaded array.
weights: Dict[str, np.ndarray] = dict()

In [3]:
# Three convolution layers (conv0..conv2); each has a kernel file and a bias file.
for i in range(3):
    base_name = f"weights/conv{i}_{{}}.npy"
    weights.update({
        f"conv{i}_kernel": np.load(base_name.format("kernel")),
        f"conv{i}_bias": np.load(base_name.format("bias")),
    })

# Dense output layer.
weights.update({
    "dense_kernel": np.load("weights/dense_kernel.npy"),
    "dense_bias": np.load("weights/dense_bias.npy"),
})

In [4]:
def apply_2x2_conv_kernels(x: np.ndarray, kernels: np.ndarray) -> np.ndarray:
    """Slide a bank of 2x2 kernels over ``x`` with stride 1 and no padding.

    No bias or activation is applied here.

    Args:
        x: Input indexed as ``[row, col, channel]``.
        kernels: Kernel bank that must broadcast against a
            ``(2, 2, channels, 1)`` window, i.e. laid out as
            ``[kernel_row, kernel_col, in_channel, out_channel]``.

    Returns:
        Float array of shape ``(rows - 1, cols - 1, kernels.shape[-1])``.
    """
    # A 2x2 window without padding shrinks each spatial dimension by one.
    n_rows, n_cols = x.shape[0] - 1, x.shape[1] - 1
    result = np.zeros((n_rows, n_cols, kernels.shape[-1]))

    for r, c in np.ndindex(n_rows, n_cols):
        patch = x[r:r + 2, c:c + 2, :]
        # The trailing axis lets the patch broadcast over every output channel;
        # summing over (kernel_row, kernel_col, in_channel) leaves one value
        # per output channel.
        result[r, c, :] = (patch[..., np.newaxis] * kernels).sum(axis=(0, 1, 2))
    return result

            
def apply_full_2x2_conv_layer(x: np.ndarray, kernel: np.ndarray,
                              bias: np.ndarray) -> np.ndarray:
    """Run one complete conv layer: 2x2 convolution, bias add, then ReLU.

    Args:
        x: Input passed unchanged to ``apply_2x2_conv_kernels``.
        kernel: Kernel bank for that convolution.
        bias: Added to the convolved output by NumPy broadcasting
            (presumably one value per output channel).

    Returns:
        The activated layer output; every entry is >= 0.
    """
    pre_activation = apply_2x2_conv_kernels(x, kernel) + bias

    # ReLU: clamp negative activations to zero.
    return np.maximum(pre_activation, 0)


def apply_conv_layer_stack(x: np.ndarray, kernels: Iterable[np.ndarray],
                           biases: Iterable[np.ndarray]) -> np.ndarray:
    """Feed ``x`` through a sequence of full 2x2 conv layers.

    Args:
        x: Input to the first layer.
        kernels: One kernel bank per layer, in application order.
        biases: One bias array per layer, paired with ``kernels`` by position.
            The two iterables are zipped, so extra entries in the longer one
            are ignored.

    Returns:
        Output of the last layer, or ``x`` itself if no layers are given.
    """
    activations = x
    for layer_kernel, layer_bias in zip(kernels, biases):
        activations = apply_full_2x2_conv_layer(activations, layer_kernel, layer_bias)
    return activations


def apply_max_pool(x: np.ndarray) -> np.ndarray:
    """Global max pooling: take the maximum over the first two axes.

    Args:
        x: Array with at least two axes; axes 0 and 1 are reduced away.

    Returns:
        Array holding the maximum over axes 0 and 1 for each remaining index
        (for a ``[row, col, channel]`` input, one value per channel).
    """
    spatial_axes = (0, 1)
    return x.max(axis=spatial_axes)


def softmax(x: np.ndarray) -> np.ndarray:
    """Return the softmax of ``x``, normalised over all of its elements.

    The largest entry is subtracted before exponentiating. The common factor
    cancels in the division, so the result is mathematically unchanged, but
    ``np.exp`` can no longer overflow to ``inf`` for large inputs or underflow
    to all zeros for very negative ones. Both cases would otherwise produce
    ``nan``.

    Args:
        x: Array (or array-like) of scores. Normalisation runs over every
            element, not along a single axis.

    Returns:
        Array with the same shape as ``x`` whose entries are non-negative and
        sum to 1. An empty input yields an empty array.
    """
    x = np.asarray(x)
    if x.size == 0:
        # Nothing to normalise, and np.max would raise on an empty array.
        return np.exp(x)

    e = np.exp(x - np.max(x))
    return e / e.sum()


def apply_dense_output_layer(x: np.ndarray, kernel: np.ndarray,
                             bias: np.ndarray) -> np.ndarray:
    """Apply the dense output layer followed by a softmax.

    Args:
        x: 1-D feature vector.
        kernel: Weight matrix; it is transposed before the product, so its
            first axis must match the length of ``x``.
        bias: Added to the product, one value per output.

    Returns:
        Softmax of ``kernel.T @ x + bias``.
    """
    # Treat x as a column vector so the product is a plain matrix multiply.
    column = np.expand_dims(x, axis=1)
    logits = np.matmul(kernel.T, column)[:, 0] + bias
    return softmax(logits)


def apply_full_model(x: np.ndarray, weights: Dict[str, np.ndarray]):
    """Run the whole model: three conv layers, global max pool, dense + softmax.

    Args:
        x: Model input, passed to the first conv layer.
        weights: Mapping that must contain ``conv{i}_kernel`` and
            ``conv{i}_bias`` for ``i`` in 0..2, plus ``dense_kernel`` and
            ``dense_bias``.

    Returns:
        Output of the dense softmax layer.
    """
    layer_ids = range(3)
    conv_kernels = [weights[f"conv{i}_kernel"] for i in layer_ids]
    conv_biases = [weights[f"conv{i}_bias"] for i in layer_ids]

    features = apply_conv_layer_stack(x, conv_kernels, conv_biases)
    pooled = apply_max_pool(features)
    return apply_dense_output_layer(
        pooled,
        kernel=weights["dense_kernel"],
        bias=weights["dense_bias"],
    )

In [5]:
# 5x5 single-channel test image; the trailing axis is the channel dimension.
test_input = np.expand_dims(
    np.array([
        [1, 1, 0, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 1, 1, 0],
        [0, 0, 1, 0, 1],
    ]),
    axis=-1,
)

In [6]:
# Spot check: raw conv0 output (no bias, no ReLU) at row 0, col 0, output channel 0.
apply_2x2_conv_kernels(test_input, weights["conv0_kernel"])[0,0,0]

2.346823751926422

In [7]:
# Spot check: raw conv0 output (no bias, no ReLU) at row 2, col 1, output channel 3.
apply_2x2_conv_kernels(test_input, weights["conv0_kernel"])[2,1,3]

-0.2727128267288208

In [8]:
# Same two positions as the raw spot checks, now after bias and ReLU.
res = apply_full_2x2_conv_layer(
    test_input,
    weights["conv0_kernel"],
    weights["conv0_bias"],
)
print(res[0, 0, 0], res[2, 1, 3], sep="\n")

1.867294818162918
0.09612968564033508


In [9]:
# Output of the full three-layer conv stack for the test image.
res = apply_conv_layer_stack(
    test_input,
    [weights[f"conv{i}_kernel"] for i in range(3)],
    [weights[f"conv{i}_bias"] for i in range(3)],
)
print("res[0,0,0]:", res[0, 0, 0])
# The label names the element actually printed (channel 10, not channel 0).
print("res[0,1,10]:", res[0, 1, 10])

# Global max pool over the spatial axes; displayed as the cell result.
res_pool = apply_max_pool(res)
res_pool

res[0,0,0]: 0.0
res[0,1,0]: 5.342033024360134


array([0.        , 6.94676578, 0.        , 2.1543044 , 5.58622718,
       4.95571283, 9.05192321, 4.97206007, 5.46789808, 5.18478296,
       6.98573438, 0.6666584 , 0.08449961, 0.66260766, 2.28731028,
       2.06870844])

In [10]:
# Single conv0 kernel entry, indexed [kernel_row, kernel_col, in_channel, out_channel].
# NOTE(review): presumably used to check the flattened layout for the C port — confirm.
weights["conv0_kernel"][0,1,0,2]

0.7610935

In [17]:
# Dump each conv layer's kernel and bias as flat arrays, one blank line between layers.
for i in range(3):
    print(
        f"conv{i}",
        repr(weights[f"conv{i}_kernel"].flatten()),
        repr(weights[f"conv{i}_bias"].flatten()),
        "",
        sep="\n",
    )

conv0
array([ 0.725426  ,  1.3791285 , -0.9170304 ,  1.0277331 ,  1.6213977 ,
        0.11924043,  0.7610935 ,  1.0245485 ,  1.5988867 , -0.22061627,
        1.1487229 , -0.9349341 ,  0.14820372, -1.5151786 ,  1.6517029 ,
       -1.2972614 ], dtype=float32)
array([-0.47952893,  0.00482019, -0.38126057,  0.3688425 ], dtype=float32)

conv1
array([ 7.27378011e-01,  9.45881903e-01, -7.18327522e-01,  1.03769243e+00,
        9.22698140e-01, -1.27145927e-02, -4.23858732e-01, -1.73859015e-01,
       -1.02740204e+00, -5.38816392e-01, -2.44815856e-01,  1.14176989e-01,
        3.23992759e-01,  2.21838737e+00, -1.10039902e+00,  4.98762399e-01,
       -5.02348661e-01,  2.69528985e-01, -6.66497946e-02,  1.25165731e-01,
       -6.57751858e-02,  3.90281111e-01, -5.45669757e-02, -3.91329706e-01,
       -1.40814837e-02,  2.38467261e-01,  8.55961382e-01, -4.63857085e-01,
        6.35963202e-01, -7.48987854e-01,  6.51141942e-01, -8.07598382e-02,
        7.61195242e-01,  5.70850372e-01,  2.89167285e-01, -5

In [18]:
# Dense kernel, transposed and then flattened in row-major order, so the
# weights feeding each output unit are contiguous.
weights["dense_kernel"].T.flatten()

array([-0.83960456,  0.19276161, -0.6516487 ,  0.58073187,  0.8809782 ,
       -0.19721672, -0.7706567 ,  0.24139214,  0.09156837,  1.0134032 ,
       -0.80731326,  0.22475286, -0.45592448, -0.8114214 , -0.6149736 ,
        0.7058    ,  0.49973786,  0.6961818 ,  0.6696165 , -1.5401654 ,
       -0.03712799,  0.14859608,  1.1001315 , -1.2117448 , -0.5655154 ,
       -0.22047481, -0.21379043,  0.5482302 ,  0.06175999, -1.3887627 ,
       -0.9614345 , -0.6491538 , -0.78515285, -0.17361808, -0.9029241 ,
        0.8398335 , -0.15483056,  0.07114238, -0.3013683 , -0.32277304,
        0.2331746 ,  0.5330044 ,  0.2117948 ,  1.2236348 ,  0.86173016,
       -0.37394926, -0.41088736, -0.5172179 ,  0.82984847,  0.09649578,
        0.35214788,  0.6599633 , -0.40643334, -0.7613639 , -0.22025885,
       -0.32842168,  0.5196468 ,  0.46605024,  0.41934952, -1.0067512 ,
        0.72975904, -0.5682072 , -0.19975613,  0.41365075,  0.37000197,
        0.8825513 ,  0.3661467 ,  1.074506  ,  0.32708085,  1.25

In [14]:
# List the names of all loaded weight arrays.
weights.keys()

dict_keys(['conv0_kernel', 'conv0_bias', 'conv1_kernel', 'conv1_bias', 'conv2_kernel', 'conv2_bias', 'dense_kernel', 'dense_bias'])

In [19]:
# Dense-layer bias as a flat array (one value per model output).
weights["dense_bias"].flatten()

array([ 0.33879644,  0.5005194 ,  0.47052455,  0.09496455,  0.0743339 ,
       -0.7524616 ,  0.25328198, -0.15340249, -0.8011218 ,  0.09212416],
      dtype=float32)

In [20]:
# End-to-end check: softmax output of the whole model for the test image.
apply_full_model(
    test_input,
    weights
)

array([0.22565381, 0.00345653, 0.14215502, 0.109535  , 0.1730379 ,
       0.00396627, 0.11912674, 0.21432252, 0.00213021, 0.00661601])