### MC959-MO810 -- Introduction to Self-Supervised (Representation) Learning (SSRL)

* Instructor: Marcelo S. Reis. <a href="mailto:msreis@unicamp.br">msreis@unicamp.br</a>

Campinas, October 9, 2024.


## Lecture 16: Jigsaw Hands On

*Based on ['Unsupervised Jigsaw Puzzle Solving on MNIST'](https://www.kaggle.com/code/kmader/unsupervised-jigsaw-puzzle-solving-on-mnist/notebook) by Scott Mader.*

In this lab, we will carry out an SSRL pipeline using the Jigsaw method.

<img src="http://www.ic.unicamp.br/~msreis/MC959/Jigsaw-01.png">


### Summary <a class="anchor" id="topo"></a>

* [Part 1: Solving main dependencies](#part_01).
* [Part 2: Loading the MNIST dataset](#part_02).
* [Part 3: Implementing the Jigsaw scramble and unscramble](#part_03).
* [Part 4: Generating permutations of the patches](#part_04).
* [Part 5: Implementing the Jigsaw model](#part_05).
* [Part 6: Self-supervised pretraining with the Jigsaw and MNIST](#part_06).
* [Part 7: Exploration of the learned representations](#part_07).
* [Part 8: Evaluating the downstream task (digit classification)](#part_08).
* [Part 9: Suggested exercises on this pipeline](#part_09).


### Part 1: Solving main dependencies <a class="anchor" id="part_01"></a>

Here we load the main libraries that will be used throughout this notebook.


In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import os

!pip install scikit-image
from skimage.io import imread as imread
from skimage.util import montage

from itertools import product  # Cartesian product between values
                               # in x and y (for efficient for loops).

from tqdm.notebook import tqdm as tqdm_notebook  # tqdm.tqdm_notebook is deprecated
from IPython.display import clear_output

plt.rcParams["figure.figsize"] = (8, 8)
plt.rcParams["figure.dpi"] = 125
plt.rcParams["font.size"] = 14
plt.rcParams['font.family'] = ['sans-serif']
plt.rcParams['font.sans-serif'] = ['DejaVu Sans']
plt.style.use('ggplot')
sns.set_style("whitegrid", {'axes.grid': False})
plt.rcParams['image.cmap'] = 'gray' # grayscale looks better

from itertools import cycle
prop_cycle = plt.rcParams['axes.prop_cycle']
colors = prop_cycle.by_key()['color']




In [2]:
# Path to the folder where the datasets are/should be downloaded.
#
DATASET_PATH = "./data/"  # Using the same data downloaded for Lecture 12

# MNIST data in CSV can be retrieved from here:
#
# https://www.kaggle.com/datasets/oddrationale/mnist-in-csv?select=mnist_test.csv

MY_SEED = 42

np.random.seed(MY_SEED)

[Back to summary.](#topo)


### Part 2: Loading the MNIST dataset <a class="anchor" id="part_02"></a>

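Loading the classic MNIST dataset for digit classification. Each pair $(x,y)$ consists of a $28 \times 28$-pixel grayscale image $x$ of a handwritten digit and its digit value $y$. The dataset has 60K and 10K observations for training and test, respectively. Here, the first 40K training images are used for the pretext task, the remaining 20K for its validation, and the 10K test images (with their labels) for the downstream task.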
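The CSV files store one image per row: the `label` column followed by 784 pixel columns, read row by row, so each image is recovered with a single reshape after scaling. A minimal sketch on made-up data (the three rows below are synthetic, and the row-major pixel layout is an assumption about the Kaggle CSV files):

```python
import numpy as np

# Hypothetical stand-in for three CSV rows: a label plus 784 pixel
# intensities in 0..255, stored row by row (28 rows of 28 pixels).
rng = np.random.default_rng(0)
labels = np.array([5, 0, 4])
pixels = rng.integers(0, 256, size=(3, 784))

# Scale to [0, 1] and turn every flat row into (height, width, channel).
images = (pixels / 255.0).reshape(-1, 28, 28, 1)

# Row-major layout: pixel (r, c) of an image is column r * 28 + c of its row.
r, c = 10, 17
assert images[1, r, c, 0] == pixels[1, r * 28 + c] / 255.0
print(images.shape, labels.shape)  # (3, 28, 28, 1) (3,)
```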


In [3]:

train   = pd.read_csv(DATASET_PATH + "mnist_train.csv", index_col =  False)
y_train = train["label"]
train   = train.drop(labels = ["label"], axis = 1)
train   = train / 255.0   # Scaling in the range [0,1].

test   = pd.read_csv(DATASET_PATH + "mnist_test.csv", index_col = False)
y_test = test["label"]
test   = test.drop(labels = ["label"], axis = 1)
test   = test / 255.0

train


In [None]:
# Reshape each flat row of 784 pixels into a 3-dimensional image
# (height = 28px, width = 28px, channel = 1).
#
X_all  = train.to_numpy().reshape(-1, 28, 28, 1)
X_down = test.to_numpy().reshape(-1, 28, 28, 1)

# del train  # They should be removed to save space.
# del test

X_val   = X_all[40000:]                  # Last 20K training images (pretext validation).
X_train = X_all[:40000]                  # First 40K training images (pretext training).
y_train = np.array(y_train[:40000])      # We'll keep these labels for a visualization later.
y_down  = np.array(y_test)

print('Pretext task training data  (no label)  : ' , X_train.shape)
print('Pretext task validation data (no label) : ', X_val.shape)
print('Downstream task data (with labels)      : ', X_down.shape, y_down.shape)

In [None]:

img_idx = np.random.choice(range(X_train.shape[0]), size=1)
plt.matshow(X_train[img_idx].squeeze())


[Back to summary.](#topo)

### Part 3: Implementing the Jigsaw scramble and unscramble <a class="anchor" id="part_03"></a>

Here we implement a function to break up the image into (typically nine) patches (```cut_jigsaw```), and another one to bring them back together into a recovered image (```jigsaw_to_image```).

*Disclaimer: This implementation is for didactic purposes, and it is not optimized to run efficiently on large datasets.*
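When there is no gap and no jitter, the tiling amounts to a reshape plus a transpose that enumerates tiles in row-major order (left to right, then top to bottom), the same order as the `product(range(x_chunks), range(y_chunks))` loop. A vectorized sketch for 2D images, assuming the leftover pixels are simply cropped; `cut_tiles` and `join_tiles` are hypothetical names, not part of the notebook:

```python
import numpy as np

def cut_tiles(image, x_wid, y_wid):
    """Hypothetical vectorized tiler: 2D input, no gap, no jitter."""
    x_chunks, y_chunks = image.shape[0] // x_wid, image.shape[1] // y_wid
    cropped = image[:x_chunks * x_wid, :y_chunks * y_wid]    # drop leftover pixels
    # (x_chunks, x_wid, y_chunks, y_wid) -> (x_chunks, y_chunks, x_wid, y_wid)
    tiles = cropped.reshape(x_chunks, x_wid, y_chunks, y_wid).transpose(0, 2, 1, 3)
    return tiles.reshape(-1, x_wid, y_wid)                   # row-major tile order

def join_tiles(tiles, x_chunks, y_chunks):
    """Inverse of cut_tiles for the cropped region."""
    _, x_wid, y_wid = tiles.shape
    grid = tiles.reshape(x_chunks, y_chunks, x_wid, y_wid).transpose(0, 2, 1, 3)
    return grid.reshape(x_chunks * x_wid, y_chunks * y_wid)

test_image = np.arange(20).reshape((4, 5))
tiles = cut_tiles(test_image, 2, 2)
print(tiles[1])                                  # [[2 3]
                                                 #  [7 8]]
print(cut_tiles(np.zeros((28, 28)), 9, 9).shape) # (9, 9, 9)
```

The tile order matches the first doctest of `cut_jigsaw`; the loop-based version is still needed for the gap and jitter options.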

In [None]:
import doctest
import copy
import functools

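# Decorator that runs the doctest examples in a function's docstring
# (unit tests within the functions) as soon as the function is defined.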
def autotest(func):
    globs = copy.copy(globals())
    globs.update({func.__name__: func})
    doctest.run_docstring_examples(func, globs, verbose=True, name=func.__name__)
    return func

@autotest
def cut_jigsaw(in_image, x_wid, y_wid, gap = False, jitter = False, jitter_dim = None):

    """Cuts the image into little pieces (tiles)
    in_image   : the image to cut apart (2D, or 3D with the channel last)
    x_wid      : the size of each piece in x
    y_wid      : the size of each piece in y
    gap        : if the leftover pixels should become gaps between tiles
    jitter     : if the tile positions should be randomly shifted
    jitter_dim : maximum shift (default is x_wid // 2 and y_wid // 2)
    return     : a 4D array of shape (tiles, x_wid, y_wid, d), or 3D if the input is 2D

    Examples

    >>> test_image = np.arange(20).reshape((4, 5)).astype(int)
    >>> test_image
    array([[ 0,  1,  2,  3,  4],
           [ 5,  6,  7,  8,  9],
           [10, 11, 12, 13, 14],
           [15, 16, 17, 18, 19]])
    >>> cut_jigsaw(test_image, 2, 2, False, False)
    array([[[ 0,  1],
            [ 5,  6]],
    <BLANKLINE>
           [[ 2,  3],
            [ 7,  8]],
    <BLANKLINE>
           [[10, 11],
            [15, 16]],
    <BLANKLINE>
           [[12, 13],
            [17, 18]]])
    >>> cut_jigsaw(test_image, 2, 2, True, False)
    array([[[ 0,  1],
            [ 5,  6]],
    <BLANKLINE>
           [[ 3,  4],
            [ 8,  9]],
    <BLANKLINE>
           [[10, 11],
            [15, 16]],
    <BLANKLINE>
           [[13, 14],
            [18, 19]]])
    >>> np.random.seed(0)
    >>> cut_jigsaw(test_image, 2, 2, True, True, 1)
    array([[[ 1,  2],
            [ 6,  7]],
    <BLANKLINE>
           [[ 7,  8],
            [12, 13]],
    <BLANKLINE>
           [[ 5,  6],
            [10, 11]],
    <BLANKLINE>
           [[ 7,  8],
            [12, 13]]])
    """

    if len(in_image.shape) == 2:
        in_image = np.expand_dims(in_image, -1)  # If the image is w x h, expand the shape to turn it w x h x 1.
        expand = True
    else:
        expand = False                           # It already has at least one channel.

    x_size, y_size, d_size = in_image.shape
    x_chunks = x_size // x_wid
    y_chunks = y_size // y_wid
    out_tiles = np.zeros((x_chunks * y_chunks, x_wid, y_wid, d_size), dtype = in_image.dtype)

    if gap:
        x_gap = x_size - x_chunks * x_wid    # Leftover pixels, spread as 1-pixel gaps between tiles.
        y_gap = y_size - y_chunks * y_wid
    else:
        x_gap, y_gap = 0, 0

    x_jitter = x_wid // 2 if jitter_dim is None else jitter_dim
    y_jitter = y_wid // 2 if jitter_dim is None else jitter_dim

    for idx, (i, j) in enumerate(product(range(x_chunks), range(y_chunks))):

        x_start = i * x_wid + min(x_gap, i)
        y_start = j * y_wid + min(y_gap, j)

        if jitter:
            x_range = max(x_start - x_jitter, 0), min(x_start + x_jitter + 1, x_size - x_wid)
            y_range = max(y_start - y_jitter, 0), min(y_start + y_jitter + 1, y_size - y_wid)

            x_start = np.random.choice(range(*x_range)) if x_range[1] > x_range[0] else x_start
            y_start = np.random.choice(range(*y_range)) if y_range[1] > y_range[0] else y_start

        out_tiles[idx, :, :, :] = in_image[x_start:x_start+x_wid, y_start:y_start+y_wid,:]

    return out_tiles[:, :, :, 0] if expand else out_tiles       # type: (...) -> np.ndarray


In [None]:
@autotest
def jigsaw_to_image(in_tiles, out_x, out_y, gap = False):

    """Reassembles little pieces (tiles) into an image
    in_tiles : the tiles to reassemble, in the order produced by cut_jigsaw
    out_x    : the size of the output image in x
    out_y    : the size of the output image in y
    gap      : if there is a gap between tiles (must match the value used in cut_jigsaw)
    return   : an image from the tiles (pixels not covered by any tile are zero)

    Examples

    >>> test_image = np.arange(20).reshape((4, 5)).astype(int)
    >>> test_image
    array([[ 0,  1,  2,  3,  4],
           [ 5,  6,  7,  8,  9],
           [10, 11, 12, 13, 14],
           [15, 16, 17, 18, 19]])
    >>> js_pieces = cut_jigsaw(test_image, 2, 2, False, False)
    >>> jigsaw_to_image(js_pieces, 4, 5)
    array([[ 0,  1,  2,  3,  0],
           [ 5,  6,  7,  8,  0],
           [10, 11, 12, 13,  0],
           [15, 16, 17, 18,  0]])
    >>> js_gap_pieces = cut_jigsaw(test_image, 2, 2, True, False)
    >>> jigsaw_to_image(js_gap_pieces, 4, 5, True)
    array([[ 0,  1,  0,  3,  4],
           [ 5,  6,  0,  8,  9],
           [10, 11,  0, 13, 14],
           [15, 16,  0, 18, 19]])
    >>> js_gap_pieces = cut_jigsaw(test_image, 2, 2, False, True)
    >>> jigsaw_to_image(js_gap_pieces, 4, 5, False)
    array([[ 1,  2,  6,  7,  0],
           [ 6,  7, 11, 12,  0],
           [ 6,  7,  7,  8,  0],
           [11, 12, 12, 13,  0]])
    """
    if len(in_tiles.shape) == 3:
        in_tiles = np.expand_dims(in_tiles, -1)
        expand = True
    else:
        expand = False

    tile_count, x_wid, y_wid, d_size = in_tiles.shape

    x_chunks = out_x // x_wid
    y_chunks = out_y // y_wid
    out_image = np.zeros((out_x, out_y, d_size), dtype = in_tiles.dtype)

    if gap:
        x_gap = out_x - x_chunks * x_wid
        y_gap = out_y - y_chunks * y_wid
    else:
        x_gap, y_gap = 0, 0

    for idx, (i, j) in enumerate(product(range(x_chunks), range(y_chunks))):
        x_start = i * x_wid + min(x_gap, i)
        y_start = j * y_wid + min(y_gap, j)
        out_image[x_start:x_start+x_wid, y_start:y_start+y_wid] = in_tiles[idx, :, :]

    return out_image[:, :, 0] if expand else out_image     # type: (...) -> np.ndarray




Let's put those two functions into action on some MNIST images.

In [None]:
TILE_X = 9
TILE_Y = 9

out_tiles = cut_jigsaw(np.array(X_train[0]), TILE_X, TILE_Y, gap=False)
out_tiles.shape


In [None]:
fig, m_axs = plt.subplots(4, 11, figsize=(50, 10))

for img_idx, c_axs in enumerate(m_axs, 1):

    c_axs[0].imshow(X_train[img_idx])
    c_axs[0].set_title('Input')

    out_tiles = cut_jigsaw(X_train[img_idx], TILE_X, TILE_Y, gap=False)

    for k, c_ax in zip(range(out_tiles.shape[0]), c_axs[1:]):
        c_ax.matshow(out_tiles[k, :, :])

    recon_img = jigsaw_to_image(out_tiles, X_train.shape[1], X_train.shape[2])
    c_axs[-1].imshow(recon_img[:, :])
    c_axs[-1].set_title('Reconstruction')


[Back to summary.](#topo)


### Part 4: Generating permutations of the patches <a class="anchor" id="part_04"></a>

For a $3 \times 3$ grid of patches, there are $9! = 362{,}880$ possible permutations. As we discussed in Lecture 15, this is a problem for the pretext task: there would be far too many classes to predict, which is undesirable from both computational and statistical points of view.

Therefore, we'll keep just a few of all possible permutations (e.g., 200 of them).
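Keeping 200 permutations turns the pretext task into a 200-way classification problem, so the kept permutations should be distinct (a duplicate would be two classes with identical inputs). A minimal sketch that keeps the identity first and then draws distinct random permutations; `sample_permutations` is a hypothetical helper, and drawing random permutations directly (instead of indexing into the full list of $9!$) is a design choice of this sketch:

```python
import math
import numpy as np

def sample_permutations(n_tiles, n_keep, seed=42):
    """Hypothetical helper: identity first, then n_keep - 1 distinct random permutations."""
    rng = np.random.default_rng(seed)
    identity = tuple(range(n_tiles))
    kept, seen = [identity], {identity}
    while len(kept) < n_keep:
        perm = tuple(int(t) for t in rng.permutation(n_tiles))
        if perm not in seen:            # reject duplicates (including the identity)
            seen.add(perm)
            kept.append(perm)
    return [list(p) for p in kept]

print(math.factorial(9))                # 362880
keep = sample_permutations(9, 200)
print(len(keep), keep[0])               # 200 [0, 1, 2, 3, 4, 5, 6, 7, 8]
```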


In [None]:

KEEP_RANDOM_PERM = 200

from itertools import permutations
all_perm = np.array(list(permutations(range(out_tiles.shape[0]), out_tiles.shape[0])))
print('Permutation count:' , len(all_perm))

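# The first permutation is always the identity (unscrambled tiles); the others
# are sampled without replacement, so that no permutation is kept twice.
#
keep_perm = all_perm[0:1, :].tolist() \
+ all_perm[np.random.choice(range(1, len(all_perm)), KEEP_RANDOM_PERM-1, replace=False), :].tolist()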

print('Permutations kept: ', len(keep_perm))

keep_perm

Let's visualize some permutations applied to an image: the first column shows the tiles without jitter, and the remaining columns show them with jitter.

In [None]:
JITTER_SIZE = 3

fig, m_axs = plt.subplots(5, 5, figsize=(15, 25))

an_image = X_train[666]

for i, c_axs in enumerate(m_axs.T):

    out_tiles = cut_jigsaw(an_image, TILE_X, TILE_Y, gap=False, jitter=i>0, jitter_dim=JITTER_SIZE)

    for j, (c_ax, c_perm) in enumerate(zip(c_axs, keep_perm)):
        scrambled_tiles = out_tiles[c_perm]
        recon_img = jigsaw_to_image(scrambled_tiles, X_train.shape[1], X_train.shape[2])
        c_ax.imshow(recon_img)
        c_ax.set_title('Permutation:#{}\nJitter:{}'.format(j, i>0))
        c_ax.axis('off')


Finally, we'll use the scramble function and the permutation scheme to pre-process the training and validation data, making them ready for self-supervised learning.
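Each pretext-task sample is a pair (scrambled tiles, permutation index): the input is the tile array reordered by one of the kept permutations, and the target is that permutation's index, so no human label is needed. A minimal sketch with a hypothetical helper `make_pretext_pair`, toy tiles, and a toy list of three permutations:

```python
import numpy as np

def make_pretext_pair(tiles, perms, rng):
    """Hypothetical helper: returns (scrambled tiles, target class index)."""
    perm_idx = int(rng.integers(len(perms)))
    return tiles[perms[perm_idx]], perm_idx

rng = np.random.default_rng(0)
tiles = np.arange(9).reshape(9, 1, 1) * np.ones((9, 2, 2))   # tile k is filled with k
perms = [[0, 1, 2, 3, 4, 5, 6, 7, 8],
         [8, 7, 6, 5, 4, 3, 2, 1, 0],
         [1, 0, 3, 2, 5, 4, 7, 6, 8]]

x, y = make_pretext_pair(tiles, perms, rng)
# Position i of the scrambled input holds the original tile perms[y][i].
print(y, x[:, 0, 0])
```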

In [None]:
TRAIN_TILE_COUNT = 2 ** 14  # 16,384
VALID_TILE_COUNT = 2 ** 11  #  2,048

def _generate_batch(in_idx, is_validation = False):

    if is_validation:
        img_ds = X_val
    else:
        img_ds = X_train

    img_idx = np.random.choice(range(img_ds.shape[0]))  # Takes a single random image

    out_tiles = cut_jigsaw(img_ds[img_idx], TILE_X, TILE_Y, gap = True,
                           jitter = JITTER_SIZE > 0, jitter_dim = JITTER_SIZE)

    perm_idx = np.random.choice(range(len(keep_perm)))
    c_perm = keep_perm[perm_idx]

    return out_tiles[c_perm], perm_idx


def make_tile_group(out_tiles_shape, tile_count, is_validation = False):

    c_tiles = np.zeros((tile_count,) + out_tiles_shape, dtype='float32')

    c_perms = np.zeros((tile_count,), dtype='int')

    for i in tqdm_notebook(range(tile_count)):
        c_tiles[i], c_perms[i] = _generate_batch(i, is_validation = is_validation)     # should be parallelized

    return c_tiles, c_perms


# Preparing training and validation data (shape of one set of tiles, cut
# with the same gap setting used in _generate_batch).
#
out_tiles_shape = cut_jigsaw(X_train[8], TILE_X, TILE_Y, gap=True).shape

train_tiles, train_perms = make_tile_group(out_tiles_shape, TRAIN_TILE_COUNT)

valid_tiles, valid_perms = make_tile_group(out_tiles_shape, VALID_TILE_COUNT, is_validation = True)


[Back to summary.](#topo)


### Part 5: Implementing the Jigsaw model <a class="anchor" id="part_05"></a>


<img src="http://www.ic.unicamp.br/~msreis/MC959/Jigsaw-03.png">

We'll implement an architecture that follows the same design as the one in the Jigsaw paper, although much smaller. First, we'll implement an encoder for the patches, built from batch normalization, convolutional layers, max pooling layers, and LeakyReLU activations:

<img src="https://pytorch.org/docs/stable/_images/LeakyReLU.png">
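LeakyReLU keeps non-negative inputs unchanged and multiplies negative ones by a small slope $\alpha$: $f(x) = x$ for $x \ge 0$ and $f(x) = \alpha x$ otherwise. The encoder passes a slope of 0.1 to `layers.LeakyReLU`. A NumPy sketch of the same function:

```python
import numpy as np

def leaky_relu(x, alpha=0.1):
    """LeakyReLU: identity for x >= 0, slope alpha for x < 0."""
    return np.where(x >= 0, x, alpha * x)

x = np.array([-2.0, -0.5, 0.0, 1.5])
print(leaky_relu(x))
```

Unlike a plain ReLU, negative pre-activations still pass a (small) gradient, which helps avoid units that never activate.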
                                                               
Next, we'll implement a "big model", which applies the same tile encoder to every patch, concatenates the encoded patches and, after a hidden layer, ends in a softmax that gives the probability that the scrambled set of patches corresponds to each of the kept permutations.
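The key design choice is weight sharing: one encoder processes each of the nine tiles, and only the concatenated features are combined. The NumPy sketch below mimics that data flow with a hypothetical linear "encoder" and random weights; it is not the Keras model, only an illustration of the shapes and of the softmax output:

```python
import numpy as np

rng = np.random.default_rng(0)
n_tiles, tile_px, latent, n_perms = 9, 9 * 9, 8, 200

W_enc  = rng.normal(size=(tile_px, latent))             # one encoder, shared by all tiles
W_head = rng.normal(size=(n_tiles * latent, n_perms))   # head over concatenated features

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)               # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def forward(tiles):
    """tiles: (batch, 9, 9, 9) -> probabilities over the kept permutations."""
    batch = tiles.shape[0]
    feats = tiles.reshape(batch, n_tiles, tile_px) @ W_enc   # same W_enc for every tile
    feats = feats.reshape(batch, n_tiles * latent)            # concatenate in tile order
    return softmax(feats @ W_head)

probs = forward(rng.normal(size=(4, n_tiles, 9, 9)))
print(probs.shape, probs.sum(axis=1))
```

Because the features are concatenated in tile order, reordering the tiles changes the input of the head, which is what lets the softmax tell permutations apart. The actual model also uses global average pooling, dropout, and a 64-unit hidden `Dense` layer.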
                                              

In [None]:
LATENT_SIZE = 8

# Keras is to TensorFlow what PyTorch Lightning is to PyTorch.
#
from keras import models, layers

# Instantiating a Keras model.
#
tile_encoder = models.Sequential(name = 'TileEncoder')

# We use None to make the model more usable later
# (i.e., to have input size flexibility).
#
# Exercise: Put input_shape=(100, 100, 1) and analyze
# the impact on the summary of the model.
#
tile_encoder.add(layers.BatchNormalization(input_shape=(None, None) + (train_tiles.shape[-1],)))

tile_encoder.add(layers.Conv2D(8, (3,3), padding='same', activation='linear'))

tile_encoder.add(layers.BatchNormalization())

tile_encoder.add(layers.MaxPool2D(2,2))

tile_encoder.add(layers.LeakyReLU(0.1))

tile_encoder.add(layers.Conv2D(16, (3,3), padding='same', activation='linear'))

tile_encoder.add(layers.BatchNormalization())

tile_encoder.add(layers.MaxPool2D(2,2))

tile_encoder.add(layers.LeakyReLU(0.1))

tile_encoder.add(layers.Conv2D(32, (2,2), padding='valid', activation='linear'))

tile_encoder.add(layers.BatchNormalization())

tile_encoder.add(layers.LeakyReLU(0.1))

tile_encoder.add(layers.Conv2D(LATENT_SIZE, (1,1), activation='linear'))

tile_encoder.add(layers.BatchNormalization())

tile_encoder.add(layers.LeakyReLU(0.1))

clear_output() # some annoying loading/warnings come up

tile_encoder.summary()


In [None]:
# Model for a single tile.
#
print('Model Input Shape:', train_tiles.shape[2:],
      '-> Model Output Shape:', tile_encoder.predict(np.zeros((1,) + train_tiles.shape[2:])).shape[1:])


Now, let's implement the model that concatenates the encoded patches and assigns to their scramble a probability of being each of the kept permutations.


In [None]:
!pip install pyparsing==3.0.9
!pip install pydot
from keras.utils import model_to_dot
from IPython.display import Image

BIG_LATENT_SIZE  = 64

# Input here includes all tiles.
#
big_in = layers.Input(train_tiles.shape[1:], name='All_Tile_Input')

feat_vec = []

for k in range(train_tiles.shape[1]):  # shape[1] is the number of tiles.

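    # Bind k as a default argument: otherwise every Lambda would read the
    # global k when the model runs (i.e., the same tile for all branches).
    lay_x     = layers.Lambda(lambda x, k=k: x[:, k], name = 'Select_{}_Tile'.format(k))(big_in)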
    feat_x    = tile_encoder(lay_x)
    feat_vec += [layers.GlobalAvgPool2D()(feat_x)]

feat_cat = layers.concatenate(feat_vec)

feat_dr = layers.Dropout(0.5)(feat_cat)

feat_latent = layers.Dense(BIG_LATENT_SIZE)(feat_dr)

feat_latent_dr = layers.Dropout(0.5)(feat_latent)

out_pred = layers.Dense(KEEP_RANDOM_PERM,
                        activation = 'softmax')(feat_latent_dr)

big_model = models.Model(inputs = [big_in], outputs = [out_pred])

big_model.compile(optimizer = 'adam',
                  loss = 'sparse_categorical_crossentropy',
                  metrics = ['sparse_categorical_accuracy', 'sparse_top_k_categorical_accuracy'])

# Plotting and saving image of the model's architecture.
#
dot_model = model_to_dot(big_model, show_shapes=True)
dot_model.set_rankdir('LR')
dot_model.write_png('Jigsaw_model.png')
Image(dot_model.create_png())

Here, we compute a reverse mapping for the permutations (an "unscramble function"), so we can show the prediction results.
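The reverse mapping is the inverse permutation: if `scrambled = tiles[perm]`, then indexing `scrambled` with the position where each tile ended up restores the original order. This is exactly what `np.argsort(perm)` returns, as the small check below shows against the dictionary-based construction (toy permutation):

```python
import numpy as np

perm = [2, 0, 1, 5, 3, 4, 8, 6, 7]        # toy permutation of 9 tiles
tiles = np.arange(9) * 10                 # stand-ins for the tiles

# Dictionary-based inverse: where did each tile index end up inside perm?
pos = {j: i for i, j in enumerate(perm)}
reversed_perm = [pos[j] for j in range(len(perm))]

scrambled = tiles[perm]                   # scramble
restored = scrambled[reversed_perm]       # unscramble

print(reversed_perm)                               # [1, 2, 0, 4, 5, 3, 7, 8, 6]
print(np.argsort(perm).tolist() == reversed_perm)  # True
```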

In [None]:
reversed_keep_perm = []

for c_dict in [{j: i for i, j in enumerate(c_perm)} for c_perm in keep_perm]:
    print(c_dict)
    reversed_keep_perm.append([c_dict[j] for j in range(out_tiles.shape[0])])


In [None]:
for i in range(3):
    print('forward', keep_perm[i], 'reversed', reversed_keep_perm[i])

In [None]:

def show_model_output(image_count=4, perm_count=3):

    fig, m_axs = plt.subplots(image_count, perm_count + 1,
                              figsize = (3 * (perm_count + 1), 3 * image_count))
    [c_ax.axis('off') for c_ax in m_axs.flatten()]

    for c_axs in m_axs:

        # Sample an image for input.
        #
        img_idx = np.random.choice(range(X_train.shape[0]))

        # Sample a permutation.
        #
        perm_idx = np.random.choice(range(len(keep_perm)))

        # Plot sampled image.
        #
        c_axs[0].imshow(X_train[img_idx, :, :, 0])
        c_axs[0].set_title('Input (perm. #{})'.format(perm_idx))

        # Generate tiles with sampled image.
        #
        out_tiles = cut_jigsaw(X_train[img_idx, :, :], TILE_X, TILE_Y, gap = True,
                               jitter = JITTER_SIZE>0, jitter_dim = JITTER_SIZE)

        # Scramble tiles with sampled permutation.
        #
        c_perm = keep_perm[perm_idx]
        scr_tiles = out_tiles[c_perm]

        # Get model prediction for the given scrambled tiles.
        #
        out_pred = big_model.predict(np.expand_dims(scr_tiles, 0))[0]

        # Undo the perm_count most probable permutations according to our model
        # (softmax over KEEP_RANDOM_PERM classes; e.g., 200) and plot the results.
        #
        for c_ax, k_idx in zip(c_axs[1:], np.argsort(-1 * out_pred)):
            pred_rev_perm = reversed_keep_perm[k_idx]
            recon_img = jigsaw_to_image(scr_tiles[pred_rev_perm], X_train.shape[1], X_train.shape[2])
            c_ax.imshow(recon_img[:, :, 0])
            c_ax.set_title('Pred: #{} ({:2.2%})'.format(k_idx, out_pred[k_idx]))

show_model_output()


[Back to summary.](#topo)

### Part 6: Self-supervised pretraining with the Jigsaw and MNIST <a class="anchor" id="part_06"></a>

Finally, we can proceed and carry out the pretext task!
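To read the training curves, it helps to know the chance-level values. With sparse categorical cross-entropy, the loss of a sample is $-\log p_y$, the negative log-probability assigned to its true permutation index $y$; a model that outputs the uniform distribution over the 200 kept permutations has loss $\ln 200 \approx 5.30$ and accuracy $1/200 = 0.5\%$. Keras' `sparse_top_k_categorical_accuracy` uses $k = 5$ by default, so its chance level is $5/200 = 2.5\%$. A NumPy sketch of these quantities (an illustration, not the Keras implementation):

```python
import numpy as np

def sparse_cce(probs, y):
    """Sparse categorical cross-entropy: mean of -log p[y] over the batch."""
    return float(np.mean(-np.log(probs[np.arange(len(y)), y])))

def top_k_accuracy(probs, y, k=5):
    """Fraction of samples whose true class is among the k most probable ones."""
    top_k = np.argsort(-probs, axis=1)[:, :k]
    return float(np.mean([y_i in row for y_i, row in zip(y, top_k)]))

n_classes = 200
y = np.array([3, 17, 199, 42])

# A model that knows nothing: uniform probabilities over the kept permutations.
uniform = np.full((len(y), n_classes), 1.0 / n_classes)
print(sparse_cce(uniform, y), np.log(n_classes))   # both approximately 5.298

# A confident, correct model: almost all the mass on the true class.
confident = np.full((len(y), n_classes), 1e-6)
confident[np.arange(len(y)), y] = 1.0 - 1e-6 * (n_classes - 1)
print(sparse_cce(confident, y), top_k_accuracy(confident, y))
```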

In [None]:
# Training the model.
#
fit_results = big_model.fit(train_tiles,
                            train_perms,
                            validation_data = (valid_tiles, valid_perms),
                            batch_size = 512,
                            epochs = 40)


In [None]:
%matplotlib inline

fig, (ax1, ax2) = plt.subplots(1, 2, figsize = (20, 10))
ax1.semilogy(fit_results.history['loss'], label='Training')
ax1.semilogy(fit_results.history['val_loss'], label='Validation')
ax1.legend()
ax1.set_title('Loss')
ax2.plot(fit_results.history['sparse_categorical_accuracy'], label='Training')
ax2.plot(fit_results.history['val_sparse_categorical_accuracy'], label='Validation')
ax2.legend()
ax2.set_title('Accuracy')
ax2.set_ylim(0, 1)
fig.show()

In [None]:
show_model_output(image_count=10, perm_count=4)

[Back to summary.](#topo)

### Part 7: Exploration of the learned representations <a class="anchor" id="part_07"></a>

Now, we can evaluate the representations we have just learned from unlabeled data. To this end, we can look at the filters and the activations, and also perform a non-linear dimensionality reduction using t-SNE.


#### Looking at the filters


In [None]:
# Assemble a dictionary to store the weights having
# the layer index and name as keys.
#
conv_weight_dict = {(idx, k.name): k.get_weights() for idx, k in enumerate(tile_encoder.layers)
                    if isinstance(k, layers.Conv2D)}
print(conv_weight_dict.keys())

fig, m_axs = plt.subplots(2, 2, figsize=(20, 20))

for c_ax, ((idx, lay_name), [W, b]) in zip(m_axs.flatten(), conv_weight_dict.items()):

    c_ax.set_title('{} #{}\n{}'.format(lay_name, idx, W.shape))

    flat_W = W.reshape((W.shape[0], W.shape[1], -1)).swapaxes(0, 2).swapaxes(1,2)

    if flat_W.shape[1] > 1 or flat_W.shape[2] > 1:
        pad_W = np.pad(flat_W, [(0, 0), (1, 1), (1, 1)], mode = 'constant', constant_values = np.nan)
        pad_W = montage(pad_W, fill = np.nan, grid_shape = (W.shape[2], W.shape[3]))

    else:
        pad_W = W[0, 0]

    c_ax.imshow(pad_W, vmin = -1, vmax = 1, cmap = 'RdBu')


#### Looking at the images that most activate each channel

In [None]:
gp_outputs = []

for k in tile_encoder.layers:

    # Look for activation layers only (in this case, the leaky ReLU).
    #
    if isinstance(k, layers.LeakyReLU):

        c_output = k.get_output_at(0)

        c_smooth = layers.AvgPool2D((2, 2))(c_output)

        c_gp = layers.GlobalMaxPool2D(name='GP_{}'.format(k.name))(c_smooth)

        gp_outputs += [c_gp]

activation_tile_encoder = models.Model(inputs = tile_encoder.inputs, outputs = gp_outputs)

activation_tile_encoder.summary()


In [None]:
activation_maps = dict(zip(activation_tile_encoder.output_names,
                           activation_tile_encoder.predict(X_train, batch_size = 128, verbose = True)))

for k, v in activation_maps.items():
    print(k, v.shape)

Here we show each intermediate layer (panel), with each neuron/depth-channel as a row and, in the columns, the top-n images that most activate it. Each row should roughly represent the kinds of images that the corresponding neuron is sensitive to.
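The selection itself is a sort per channel: for each column of the activation matrix (one column per channel), take the indices of the `keep_top_n` largest absolute values. A compact, vectorized sketch on a made-up activation matrix:

```python
import numpy as np

rng = np.random.default_rng(0)
activations = rng.normal(size=(1000, 8))      # made-up: 1000 images x 8 channels
keep_top_n = 5

# Column i holds the indices of the images that most activate channel i.
top_idx = np.argsort(-np.abs(activations), axis=0)[:keep_top_n]

print(top_idx.shape)                          # (5, 8)
```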

In [None]:
keep_top_n = 5

fig, m_axs = plt.subplots(1, len(activation_maps), figsize=(20, 20))

for c_ax, (k, v) in zip(m_axs.T, activation_maps.items()):

    c_ax.set_title(k)
    active_rows = []

    for i in range(v.shape[1]):

        # Gets the indices of the top n images and stores the respective images.
        #
        top_idx = np.argsort(-np.abs(v[:, i]))[:keep_top_n]
        active_rows += [X_train[top_idx, :, :, 0]]

    # Plot all rows for a given activation.
    #
    c_ax.imshow(montage(np.concatenate(active_rows, 0),
                        grid_shape = (v.shape[1], keep_top_n),
                        padding_width = 1))
    c_ax.axis('off')


#### Calculating t-SNE

Let's check whether t-SNE can separate the digits in the projection space. To this end, we add global average pooling to turn the output of the `tile_encoder` into a single feature vector per image. We can also use this feature vector later for image classification (downstream task).
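Global average pooling averages each channel over all spatial positions; for a $28 \times 28$ input the tile encoder produces a $6 \times 6 \times 8$ map, which becomes an 8-dimensional vector. The equivalent NumPy operation, on random stand-in feature maps:

```python
import numpy as np

rng = np.random.default_rng(0)
feature_maps = rng.normal(size=(32, 6, 6, 8))   # (batch, height, width, channels)

# Global average pooling: mean over the two spatial axes.
vectors = feature_maps.mean(axis=(1, 2))
print(vectors.shape)                            # (32, 8)
```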

In [None]:

print(X_train[0].shape, '->', tile_encoder.predict(X_train[0:1]).shape)


In [None]:

# img_in is an input layer with shape (28, 28, 1).
#
img_in = layers.Input(X_train.shape[1:])

# Generating a representation (with shape 6,6,8) using the (single) tile encoder.
#
full_feat_mat = tile_encoder(img_in)

# Reducing the representation to a 8-dimensional vector.
#
gap_out = layers.GlobalAvgPool2D()(full_feat_mat)

# Defining our final image encoder.
#
image_encoder = models.Model(inputs = [img_in], outputs = [gap_out], name = 'EncodeImage')
image_encoder.summary()


In [None]:

X_features = image_encoder.predict(X_train, batch_size = 128)

X_features


In [None]:
from sklearn.manifold import TSNE

# `max_iter` is called `n_iter` in scikit-learn versions older than 1.5.
#
tsne = TSNE(n_components = 2, perplexity = 40, verbose = 2, max_iter = 250, early_exaggeration = 1)

X_tsne = tsne.fit_transform(X_features)


In [None]:
fig, ax1 = plt.subplots(1, 1, figsize=(20, 20))
for k in np.unique(y_train):
    ax1.plot(X_tsne[y_train==k, 0], X_tsne[y_train==k, 1], '.', label='{}'.format(k))
ax1.legend()

[Back to summary.](#topo)

### Part 8: Evaluating the downstream task (digit classification) <a class="anchor" id="part_08"></a>

We'll now evaluate how our learned representation performs on a downstream task. The baseline is a classifier trained directly on the pixels. We'll explore two different models: logistic regression and random forest.
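Both classifiers in this part are wrapped in `make_pipeline(Normalizer(), ...)`. scikit-learn's `Normalizer` works per sample (per row), rescaling each vector to unit L2 norm by default, rather than standardizing each feature. A NumPy sketch of that transformation:

```python
import numpy as np

def l2_normalize_rows(X):
    """Scale each row to unit L2 norm; all-zero rows are left unchanged."""
    norms = np.linalg.norm(X, axis=1, keepdims=True)
    return X / np.where(norms == 0, 1.0, norms)

X = np.array([[3.0, 4.0], [0.0, 0.0], [1.0, 1.0]])
print(l2_normalize_rows(X))
```

One consequence is that only the direction of each pixel or feature vector matters to the classifier, not its overall magnitude.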


In [None]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, accuracy_score
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import Normalizer

X_features = image_encoder.predict(X_down, batch_size = 128)

train_imgs, test_imgs, train_feat, test_feat, train_y, test_y = train_test_split(
        X_down,
        X_features,
        y_down,
        random_state = MY_SEED,
        test_size    = 0.1,
        stratify     = y_down
)

print(train_imgs.shape, train_feat.shape, train_y.shape)
print(test_imgs.shape, test_feat.shape, test_y.shape)


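from sklearn.base import clone

def train_and_show_model(in_model, train_size, random_state, show_figure = True):

    # Draw a random training subset of size train_size (random_state makes
    # each run reproducible, and different seeds give different subsets).
    #
    rng = np.random.RandomState(random_state)
    idx = rng.choice(len(train_y), size = train_size, replace = False)

    # Fit pixel model (on a clone, so it does not share the estimator with the feature model).
    #
    baseline_model = make_pipeline(Normalizer(), clone(in_model))
    baseline_model.fit(train_imgs[idx].reshape((-1, 784)), train_y[idx])
    baseline_pred_y = baseline_model.predict(test_imgs.reshape((-1, 784)))

    # Fit feature model.
    #
    feat_model = make_pipeline(Normalizer(), clone(in_model))
    feat_model.fit(train_feat[idx], train_y[idx])
    pred_y = feat_model.predict(test_feat)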

    if show_figure:
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))

        sns.heatmap(confusion_matrix(test_y, baseline_pred_y), annot=True, fmt='d', ax=ax1)
        ax1.set_title('Pixel Accuracy: {:2.1%}'.format(accuracy_score(test_y, baseline_pred_y)))

        sns.heatmap(confusion_matrix(test_y, pred_y), annot=True, fmt='d', ax=ax2)
        ax2.set_title('Feature Accuracy: {:2.1%}'.format(accuracy_score(test_y, pred_y)))

    return {
        'train_size': train_size,
        'random_state': random_state,
        'pixel_accuracy': accuracy_score(test_y, baseline_pred_y),
        'feature_accuracy': accuracy_score(test_y, pred_y)
    }

print('[Done]')


#### Logistic regression (linear)

In [None]:

train_and_show_model(LogisticRegression(solver='lbfgs'), 9000, MY_SEED)


#### Random forest (non-linear)

In [None]:

train_and_show_model(RandomForestClassifier(n_estimators=100), 9000, MY_SEED)


#### Evaluating the models with few data

Here we evaluate the impact of the downstream training set size on the performance of the random forest, for both pixel-based and feature-based classification.
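The sweep uses `np.logspace(1, 3, 10).astype(int)` as training-set sizes: ten values evenly spaced on a log scale from 10 to 1000 (truncated to integers), which for 10 digit classes means roughly 1 to 100 samples per class on average. To inspect the grid:

```python
import numpy as np

sizes = np.logspace(1, 3, 10).astype(int)   # 10 training-set sizes, log-spaced from 10 to 1000
print(sizes.tolist())
print((sizes / 10).tolist())                # nominal samples per class (10 digit classes)
```

The log spacing puts most of the evaluation points in the few-label regime, where a good pretrained representation is expected to help the most.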


In [None]:
import dask
from dask import bag
import dask.diagnostics as diag
from multiprocessing.pool import ThreadPool
import gc; gc.enable(); gc.collect() # memory gets very tight

model_fn = lambda : RandomForestClassifier(n_estimators = 25, random_state = MY_SEED, n_jobs = 1)

# We vary the number of samples between 10 and 1000 (10 log-spaced values), and
# the seed value between 0 and 9 (one run per seed, in order to obtain a mean
# and a deviation).
#
parm_sweep = bag.from_sequence(
    product(np.logspace(1, 3, 10).astype(int),   # sample size
            np.arange(10))                       # seed
).map(lambda args: train_and_show_model(model_fn(),
                                        train_size   = args[0],
                                        random_state = args[1],
                                        show_figure  = False))
parm_sweep


In [None]:
with diag.ProgressBar(), dask.config.set(pool = ThreadPool(2)):
    parm_df = pd.DataFrame(parm_sweep.compute())

# clean-up the output
#
nice_parm_df = pd.melt(parm_df, id_vars = ['train_size', 'random_state']).\
    rename(columns={'value': 'Test Accuracy', 'variable': 'input'})

nice_parm_df['Test Accuracy'] *= 100

nice_parm_df['input'] = nice_parm_df['input'].map(lambda x: x.split('_accuracy')[0])

nice_parm_df['Samples Per Class'] = nice_parm_df['train_size'] / np.unique(y_train).shape[0]

nice_parm_df.head(3)


In [None]:
sns.catplot(x        = 'Samples Per Class',
            y        = 'Test Accuracy',
            hue      = 'input',
            kind     = 'point',
            errorbar = 'sd',
            data     = nice_parm_df)

In [None]:
sns.catplot(x='Samples Per Class',
            y='Test Accuracy',
            hue='input',
            kind='swarm',
            data=nice_parm_df)

[Back to summary.](#topo)


### Part 9: Suggested exercises on this pipeline <a class="anchor" id="part_09"></a>

- Try to explore hyperparameters of the pretext task;

- Run this pipeline on another dataset (e.g., STL-10);

- Try to implement the permutation-selection algorithm of the Jigsaw paper:

<img src="http://www.ic.unicamp.br/~msreis/MC959/Jigsaw-04.png">

- Try to port this pipeline to PyTorch.
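As a starting point for the permutation-selection exercise: the Jigsaw paper chooses permutations that are far apart in Hamming distance (the number of positions at which two permutations differ). The hypothetical helpers below only measure this quantity for any list of permutations, e.g. to compare random sampling with your implementation; they do not implement the paper's selection algorithm:

```python
import numpy as np

def hamming_matrix(perms):
    """Pairwise Hamming distances: number of positions where two permutations differ."""
    P = np.asarray(perms)
    return (P[:, None, :] != P[None, :, :]).sum(axis=2)

def summarize(perms):
    """(minimum, mean) Hamming distance over all distinct pairs."""
    D = hamming_matrix(perms)
    off_diag = D[~np.eye(len(D), dtype=bool)]
    return off_diag.min(), off_diag.mean()

rng = np.random.default_rng(0)
random_perms = [rng.permutation(9) for _ in range(100)]
print(summarize(random_perms))
```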


[Back to summary.](#topo)
