# Setup

In [None]:
import os
import numpy as np
from tensorflow.keras.models import load_model
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from mpl_toolkits.axes_grid1 import make_axes_locatable

## Load Data

In [None]:
# Read the test image from disk; its first three axes are interpreted below
# as (channels, height, width).
X_test = np.load('private_test.npy')

# Keep the individual dimensions around under readable names for later cells.
shape = X_test.shape
channels, height, width = shape[0], shape[1], shape[2]

print(f"Loaded image with channels: {channels}, height: {height}, width: {width}")

# Apply Model

In [None]:
# Location of the saved Keras model, relative to the notebook's working directory.
CNN_MODEL_PATH = '../models/cnn'
cnn = load_model(CNN_MODEL_PATH)

In [4]:
def get_neighborhood(image, x, y, size=7):
    """Return the ``size`` x ``size`` window of ``image`` starting at ``(x, y)``.

    Parameters
    ----------
    image : array-like with three axes
        ``x`` indexes the first axis, ``y`` the second; the third axis is
        returned in full.
    x, y : int
        Start indices of the window along the first and second axis. These
        are the window's corner, not its centre.
    size : int, optional
        Edge length of the window. Defaults to 7, the value that was
        previously hard-coded.

    Returns
    -------
    The slice ``image[x:x + size, y:y + size, :]``. For NumPy arrays this is
    a view, and slices that run past the array bounds are silently clipped,
    so the result can be smaller than ``size`` x ``size`` unless the image
    has been padded beforehand.
    """
    return image[x:x + size, y:y + size, :]

Apply the CNN to the input row by row, saving the predictions for each row to a separate `.npy` file.

**Warning:** this cell is very slow (about 16 hours) and may exceed your available RAM.

In [None]:
# Directory that receives one prediction file per processed row.
PRED_DIR = 'cnn_pred'
os.makedirs(PRED_DIR, exist_ok=True)  # np.save does not create missing directories

# get_neighborhood() is called below with the loop variables in (x, y) order,
# so `x` (bounded by `width`) indexes the first spatial axis of the padded image
# and `y` (bounded by `height`) the second. That only stays in bounds for square
# images, so fail immediately instead of part-way through a very long run.
if height != width:
    raise ValueError(
        f"This cell assumes a square image, got height={height}, width={width}"
    )

# One batch holds the 7x7 neighbourhoods of a full line of pixels.
# The channel count comes from the loaded image instead of a hard-coded 10.
X_batch = np.zeros((width, 7, 7, channels))

# Move the channel axis last, then replicate the border by 3 pixels on each
# spatial side so that a 7x7 window exists for every original pixel.
X_test = np.moveaxis(X_test, 0, -1)
X_test_padded = np.pad(X_test, ((3, 3), (3, 3), (0, 0)), mode='edge')
# Drop the reference to the unpadded array. As a consequence this cell cannot be
# re-run without first re-running the cell that loads X_test.
del X_test

for y in range(height):
    row_path = os.path.join(PRED_DIR, f'private_test_cnn_{y}.npy')
    # Resume support: rows finished by an earlier (interrupted) run are skipped.
    # Delete the files in PRED_DIR to force a full recomputation.
    if os.path.exists(row_path):
        continue

    for x in range(width):
        X_new = get_neighborhood(X_test_padded, x, y)
        # Standardise each patch with its own mean/std over all pixels and channels.
        mean = np.mean(X_new)
        std = np.std(X_new)
        # NOTE(review): a constant patch has std == 0 and yields NaN inputs here.
        # Left unchanged because the training-time normalisation is not visible
        # from this notebook -- confirm before adding a guard.
        X_new = (X_new - mean) / std
        X_batch[x] = X_new

    pred = cnn.predict(X_batch).flatten()

    # Write to a temporary name first and rename afterwards, so an interrupted
    # save never leaves a truncated file that the resume check would accept.
    tmp_path = os.path.join(PRED_DIR, f'.tmp_private_test_cnn_{y}.npy')
    np.save(tmp_path, pred)
    os.replace(tmp_path, row_path)

In [None]:
# Count the entries currently present in the prediction directory.
pred_dir_entries = os.listdir('cnn_pred/')
num_files = len(pred_dir_entries)
num_files

Stack the per-row prediction files into a single NumPy array containing all predictions.

In [None]:
# The per-row prediction files live in this directory (the same one that is
# listed with os.listdir earlier in the notebook), not in the working directory.
PRED_DIR = 'cnn_pred'

# Derive the number of rows from the files on disk instead of hard-coding 16384.
row_files = [
    name for name in os.listdir(PRED_DIR)
    if name.startswith('private_test_cnn_') and name.endswith('.npy')
]
n_rows = len(row_files)

# Load the rows in index order; a gap in the numbering raises FileNotFoundError.
rows = np.stack([
    np.load(os.path.join(PRED_DIR, f'private_test_cnn_{index}.npy'))
    for index in range(n_rows)
])

# The original reshape to (1, 16384, 16384) implicitly required a complete,
# square grid of predictions; keep that check for arbitrary sizes so that an
# unfinished inference run is not silently saved as the final result.
if rows.ndim != 2 or rows.shape[0] != rows.shape[1]:
    raise ValueError(
        f"Expected a square grid of predictions, got shape {rows.shape}; "
        "is the inference run complete?"
    )

# Add a leading axis of length 1 and swap the two spatial axes, exactly as the
# original reshape + swapaxes did for the 16384 x 16384 case.
predictions = rows[np.newaxis].swapaxes(1, 2)
np.save('../private_test_cnn.npy', predictions)

# Visualize Predictions

In [None]:
fig, ax = plt.subplots(figsize=(20, 20))

# Two-colour gradient from white to dark red.
gradient_stops = ["white", "#942738"]
cmap = mcolors.LinearSegmentedColormap.from_list("mycmap", gradient_stops)

# Read the stacked predictions back from disk and show the first (only) plane,
# with the colour scale clipped to the range 0..40.
prediction_map = np.load('../private_test_cnn.npy')[0]
im = ax.imshow(prediction_map, cmap=cmap, vmin=0, vmax=40)

# Attach a narrow colour bar to the right-hand side of the image axes.
cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.2)
fig.colorbar(im, cax=cax)

ax.axis('off')
fig.savefig('../Private_Test_CNN.png', bbox_inches='tight')
plt.show()