# Models for ICLR 2019 paper

### A rotation-equivariant convolutional neural network model of primary visual cortex
*Alexander S. Ecker, Fabian H. Sinz, Emmanouil Froudarakis, Paul G. Fahey, Santiago A. Cadena, Edgar Y. Walker, Erick Cobos, Jacob Reimer, Andreas S. Tolias, Matthias Bethge*

https://openreview.net/forum?id=H1fU8iAqKX

This notebook contains the code to build and train all models described in the paper. We start by building and loading a pre-trained model. Then we provide the code to build and train all other models.

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import tensorflow as tf, numpy as np, os, sys
# Add the directory two levels above the working directory to the import path
p = os.path.dirname(os.path.dirname(os.getcwd()))
if p not in sys.path:
    sys.path.append(p)

In [None]:
from cnn_sys_ident.architectures.models import BaseModel, CorePlusReadoutModel
from cnn_sys_ident.architectures.cores import StackedRotEquiHermiteConv2dCore
from cnn_sys_ident.architectures.readouts import SpatialXFeatureJointL1Readout
from cnn_sys_ident.architectures.training import Trainer
from data import Dataset

## Parameters used throughout

In [None]:
# Core
NUM_ROTATIONS = 8
UPSAMPLING = 2
SHARED_BIASES = False
FILTER_SIZE = [13, 5, 5]
NUM_FILTERS = [16, 16, 16]
STRIDE = [1, 1, 1]
RATE = [1, 1, 1]
PADDING = ['SAME', 'SAME', 'SAME']
ACTIVATION_FN = ['soft', 'soft', 'none']
REL_SMOOTH_WEIGHT = [1, 0.5, 0.5]
REL_SPARSE_WEIGHT = [0, 1, 1]

# Readout
POSITIVE_FEATURE_WEIGHTS = False
INIT_MASKS = 'rand'

# Training
VAL_STEPS = 50
LEARNING_RATE = 0.002
BATCH_SIZE = 256
PATIENCE = 5
LR_DECAY_STEPS = 2
LOG_DIR = 'analysis/iclr2019/checkpoints-repro'
LOG_DIR_PRETRAINED = 'analysis/iclr2019/checkpoints'

# 1. Building and loading a pre-trained model

In this section we build a model and load the pre-trained weights from a checkpoint.

## Rotation-equivariant model with 16 features used for analyses

This is the model with 16 features that we analyze in the paper. The code below just builds and loads the pre-trained model for inference. For training the model from scratch, refer to the code further below.

In [None]:
base = BaseModel(
    Dataset.load(),
    log_dir=LOG_DIR_PRETRAINED,
    log_hash='647bb1d1bd02979996e492b5422eb95f'
)
core = StackedRotEquiHermiteConv2dCore(
    base,
    base.inputs,
    num_rotations=NUM_ROTATIONS,
    upsampling=UPSAMPLING,
    shared_biases=SHARED_BIASES,
    filter_size=FILTER_SIZE,
    num_filters=NUM_FILTERS,
    stride=STRIDE,
    rate=RATE,
    padding=PADDING,
    activation_fn=ACTIVATION_FN,
)
readout = SpatialXFeatureJointL1Readout(
    base,
    core.output,
    positive_feature_weights=POSITIVE_FEATURE_WEIGHTS,
)
model = CorePlusReadoutModel(base, core, readout)
model.load()
trainer = Trainer(base, model)   # just for computing the performance
trainer.compute_test_corr()      #   (for training, see below)

# 2. Building and training models

In this section we provide the code for building and training all models and controls shown in the paper. If you just want to load a pre-trained model, follow the pattern above and replace the training code with `model.load()`.

## Fig. 2: Model comparison: number of features in last conv layer

All models have the same basic architecture: three layers with 16–16–N features (N = 8 ... 48), each at 8 orientations. There are three hyperparameters that we optimize by random search (32 models each): smoothness of convolutional filters (`conv_smooth_weight` $\in$ [0.001, 0.03]), group sparsity of convolutional filters (`conv_sparse_weight` $\in$ [0.001, 0.1]) and sparsity of the readout (`readout_sparsity` $\in$ [0.005, 0.03]). Below we specify the hyperparameters for the best model for each N.
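The random search described above can be sketched as follows. This is an illustration only: the sampling distribution is not stated here, so log-uniform sampling is an assumption, and `sample_hyperparameters` is a hypothetical helper that is not part of `cnn_sys_ident`.

```python
import math
import random

# Search ranges quoted in the text above.
RANGES = {
    'conv_smooth_weight': (0.001, 0.03),
    'conv_sparse_weight': (0.001, 0.1),
    'readout_sparsity': (0.005, 0.03),
}

def sample_hyperparameters(rng):
    """Draw one configuration (hypothetical helper; log-uniform is an assumption)."""
    return {
        name: math.exp(rng.uniform(math.log(lo), math.log(hi)))
        for name, (lo, hi) in RANGES.items()
    }

rng = random.Random(0)   # fixed seed so the draws are repeatable
configs = [sample_hyperparameters(rng) for _ in range(32)]   # 32 models per N
print(configs[0])
```

Each sampled configuration would then be trained once, and the one with the best validation performance kept for that N.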

In [None]:
N = [8, 12, 16, 20, 24, 28, 32, 40, 48]
conv_smooth_weight = {
    8:  0.00781004, 12: 0.00184694, 16: 0.0249692,
    20: 0.0257738,  24: 0.00146371, 28: 0.0186784,
    32: 0.026082,   40: 0.00232312, 48: 0.00129107}
conv_sparse_weight = {
    8:  0.0168574,  12: 0.0610123,  16: 0.0152482,
    20: 0.0691215,  24: 0.00999698, 28: 0.0187448,
    32: 0.0118641,  40: 0.0868334,  48: 0.0644271}
readout_sparsity = {
    8:  0.0156452,  12: 0.0153464,  16: 0.0170696,
    20: 0.0141163,  24: 0.0131784,  28: 0.0124147,
    32: 0.0161513,  40: 0.0115895,  48: 0.0163213}
log_hash = {   # determines the seed of the random number generator
    8:  '8d2912ce0669f4dcc4efa78b970e453c',
    12: '4d2e43901a1be496a5e66dc9fec1ed14',
    16: '647bb1d1bd02979996e492b5422eb95f',
    20: '6babf3b3be2cbd8da50e091966f22e46',
    24: '1e34d6f792b506630897ce84fe93a58c',
    28: 'a653720bdd962f95b213156f25c80f31',
    32: 'd23dd9d3a7149ecc72627115bb940e1e',
    40: 'ba65e73469fe90109f22e8204557b646',
    48: '37e70606daaa0b2ca13698fee329eec4'}

In [None]:
for num_features in N:
    base = BaseModel(
        Dataset.load(),
        log_dir=LOG_DIR,
        log_hash=log_hash[num_features]
    )
    core = StackedRotEquiHermiteConv2dCore(
        base,
        base.inputs,
        num_rotations=NUM_ROTATIONS,
        upsampling=UPSAMPLING,
        shared_biases=SHARED_BIASES,
        filter_size=FILTER_SIZE,
        num_filters=[16, 16, num_features],
        stride=STRIDE,
        rate=RATE,
        padding=PADDING,
        activation_fn=ACTIVATION_FN,
        rel_smooth_weight=REL_SMOOTH_WEIGHT,
        rel_sparse_weight=REL_SPARSE_WEIGHT,
        conv_smooth_weight=conv_smooth_weight[num_features],
        conv_sparse_weight=conv_sparse_weight[num_features],
    )
    readout = SpatialXFeatureJointL1Readout(
        base,
        core.output,
        positive_feature_weights=POSITIVE_FEATURE_WEIGHTS,
        init_masks=INIT_MASKS,
        readout_sparsity=readout_sparsity[num_features],
    )
    model = CorePlusReadoutModel(base, core, readout)
    trainer = Trainer(base, model)
    iter_num, val_loss, test_corr = trainer.fit(
        val_steps=VAL_STEPS,
        learning_rate=LEARNING_RATE,
        batch_size=BATCH_SIZE,
        patience=PATIENCE,
        lr_decay_steps=LR_DECAY_STEPS)
    
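    # Print explicitly: a bare expression inside a loop is not displayed by the notebook
    print(num_features, trainer.compute_test_corr())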

## Table 1: Performance of our proposed model and various baselines

### Rotation-equivariant CNN 3x (16x8)

Same as for N=16 above. Repeated here for completeness.
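The "3x (16x8)" in the heading can be read as three layers of 16 features, each at 8 orientations. As a small sanity check of that reading (an interpretation of the notation, not output from the library), the number of feature maps per layer works out as:

```python
NUM_ROTATIONS = 8              # orientations per feature (see parameters above)
NUM_FILTERS = [16, 16, 16]     # features per layer

# Each feature is present at every orientation, so a layer has
# num_filters * num_rotations feature maps in total.
feature_maps = [n * NUM_ROTATIONS for n in NUM_FILTERS]
print(feature_maps)   # [128, 128, 128]
```

Under this reading, the regular CNN with 128 features per layer listed further below has the same number of feature maps per layer as this model.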

In [None]:
num_features = 16
base = BaseModel(
    Dataset.load(),
    log_dir=LOG_DIR,
    log_hash=log_hash[num_features]
)
core = StackedRotEquiHermiteConv2dCore(
    base,
    base.inputs,
    num_rotations=NUM_ROTATIONS,
    upsampling=UPSAMPLING,
    shared_biases=SHARED_BIASES,
    filter_size=FILTER_SIZE,
    num_filters=NUM_FILTERS,
    stride=STRIDE,
    rate=RATE,
    padding=PADDING,
    activation_fn=ACTIVATION_FN,
    rel_smooth_weight=REL_SMOOTH_WEIGHT,
    rel_sparse_weight=REL_SPARSE_WEIGHT,
    conv_smooth_weight=conv_smooth_weight[num_features],
    conv_sparse_weight=conv_sparse_weight[num_features],
)
readout = SpatialXFeatureJointL1Readout(
    base,
    core.output,
    positive_feature_weights=POSITIVE_FEATURE_WEIGHTS,
    init_masks=INIT_MASKS,
    readout_sparsity=readout_sparsity[num_features],
)
model = CorePlusReadoutModel(base, core, readout)
trainer = Trainer(base, model)
iter_num, val_loss, test_corr = trainer.fit(
    val_steps=VAL_STEPS,
    learning_rate=LEARNING_RATE,
    batch_size=BATCH_SIZE,
    patience=PATIENCE,
    lr_decay_steps=LR_DECAY_STEPS)

trainer.compute_test_corr()

### Rotation-equivariant CNN, but with positive feature weights

In [None]:
base = BaseModel(
    Dataset.load(),
    log_dir=LOG_DIR,
    log_hash='a4de905100ac9b78c6a96e8d67f8adfe'
)
core = StackedRotEquiHermiteConv2dCore(
    base,
    base.inputs,
    num_rotations=NUM_ROTATIONS,
    upsampling=UPSAMPLING,
    shared_biases=SHARED_BIASES,
    filter_size=FILTER_SIZE,
    num_filters=NUM_FILTERS,
    stride=STRIDE,
    rate=RATE,
    padding=PADDING,
    activation_fn=ACTIVATION_FN,
    rel_smooth_weight=REL_SMOOTH_WEIGHT,
    rel_sparse_weight=REL_SPARSE_WEIGHT,
    conv_smooth_weight=0.00553383,
    conv_sparse_weight=0.0715125,
)
readout = SpatialXFeatureJointL1Readout(
    base,
    core.output,
    positive_feature_weights=True,
    init_masks=INIT_MASKS,
    readout_sparsity=0.0244531,
)
model = CorePlusReadoutModel(base, core, readout)
trainer = Trainer(base, model)
iter_num, val_loss, test_corr = trainer.fit(
    val_steps=VAL_STEPS,
    learning_rate=LEARNING_RATE,
    batch_size=BATCH_SIZE,
    patience=PATIENCE,
    lr_decay_steps=LR_DECAY_STEPS)

trainer.compute_test_corr()

### Rotation-equivariant CNN, but with non-sparse, L2-regularized feature weights

In [None]:
from cnn_sys_ident.architectures.readouts import SpatialSparseXFeatureDenseSeparateReadout

In [None]:
base = BaseModel(
    Dataset.load(),
    log_dir=LOG_DIR,
    log_hash='9ef7308edab3233c4d02d280ea37bc93'
)
core = StackedRotEquiHermiteConv2dCore(
    base,
    base.inputs,
    num_rotations=NUM_ROTATIONS,
    upsampling=UPSAMPLING,
    shared_biases=SHARED_BIASES,
    filter_size=FILTER_SIZE,
    num_filters=NUM_FILTERS,
    stride=STRIDE,
    rate=RATE,
    padding=PADDING,
    activation_fn=ACTIVATION_FN,
    rel_smooth_weight=REL_SMOOTH_WEIGHT,
    rel_sparse_weight=REL_SPARSE_WEIGHT,
    conv_smooth_weight=0.0141237,
    conv_sparse_weight=0.00280391,
)
readout = SpatialSparseXFeatureDenseSeparateReadout(
    base,
    core.output,
    positive_feature_weights=POSITIVE_FEATURE_WEIGHTS,
    init_masks=INIT_MASKS,
    mask_sparsity=0.0324413,
    feature_l2=0.315181,
)
model = CorePlusReadoutModel(base, core, readout)
trainer = Trainer(base, model)
iter_num, val_loss, test_corr = trainer.fit(
    val_steps=VAL_STEPS,
    learning_rate=LEARNING_RATE,
    batch_size=BATCH_SIZE,
    patience=PATIENCE,
    lr_decay_steps=LR_DECAY_STEPS)

trainer.compute_test_corr()

### Regular CNNs with cores of different sizes

_Note: When preparing the code for publication, we realized that there is some residual stochasticity in the model fitting procedure (despite fixing all random number generator seeds) that appears to affect these models more than others. You may therefore have to run the model fitting multiple times to reproduce the performance reported in the paper. In our experiments, we always ran 32 different initializations, which is why we are confident that the reported numbers are reasonably robust. In the code below, we increased the patience of the early stopping algorithm from 5 to 10, which leads to more reliable results._
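To illustrate why a larger patience can make early stopping less sensitive to noise, here is a generic sketch of a patience rule. It is not taken from the `Trainer` class, whose exact stopping logic may differ; `early_stopping_index` and the validation losses are made up for illustration.

```python
def early_stopping_index(val_losses, patience):
    """Return the index at which a generic patience rule would stop.

    Training stops once `patience` consecutive validation checks fail to
    improve on the best loss seen so far. Generic illustration only; not
    taken from the Trainer class.
    """
    best = float('inf')
    checks_without_improvement = 0
    for i, loss in enumerate(val_losses):
        if loss < best:
            best = loss
            checks_without_improvement = 0
        else:
            checks_without_improvement += 1
            if checks_without_improvement >= patience:
                return i
    return len(val_losses) - 1

# Made-up validation curve with a noisy plateau before a late improvement.
losses = [1.0, 0.8, 0.7, 0.71, 0.72, 0.70, 0.71, 0.73, 0.65, 0.64]
stop_short = early_stopping_index(losses, patience=5)
stop_long = early_stopping_index(losses, patience=10)
print(stop_short, stop_long)
```

With the shorter patience the rule stops on the plateau, while the longer patience survives it and reaches the later, lower losses.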

In [None]:
from cnn_sys_ident.architectures.cores import StackedConv2dCore

In [None]:
cnn_filter_nums = [
    [32, 32, 32],
    [64, 64, 64],
    [128, 128, 128],
    [128, 128, 256],
]
conv_smooth_weights = [0.0151716, 0.00218237, 0.0277236, 0.0015324]
conv_sparse_weights = [0.0219826, 0.0323365, 0.0650177, 0.007974]
readout_sparsities = [0.0193531, 0.0261594, 0.0151648, 0.0179]
log_hashes = [
    '96c4d0cc8869d2b5a4297f13f2cdd422',
    'b8c433730fc6d4753f6f910f697b7f4b',
    '3bedbbd474249974eb309aeda76ca426',
    'f4c477e777c48dac89e61feff11f4327',
]
# Loop variables use distinct names so they do not overwrite the dicts
# (conv_smooth_weight, conv_sparse_weight, readout_sparsity, log_hash) defined for Fig. 2
for num_filters, smooth_weight, sparse_weight, sparsity, model_hash in zip(
        cnn_filter_nums, conv_smooth_weights, conv_sparse_weights, readout_sparsities, log_hashes):
    base = BaseModel(
        Dataset.load(),
        log_dir=LOG_DIR,
        log_hash=model_hash
    )
    core = StackedConv2dCore(
        base,
        base.inputs,
        filter_size=FILTER_SIZE,
        num_filters=num_filters,
        stride=STRIDE,
        rate=RATE,
        padding=PADDING,
        activation_fn=ACTIVATION_FN,
        rel_smooth_weight=REL_SMOOTH_WEIGHT,
        rel_sparse_weight=REL_SPARSE_WEIGHT,
        conv_smooth_weight=smooth_weight,
        conv_sparse_weight=sparse_weight,
    )
    readout = SpatialXFeatureJointL1Readout(
        base,
        core.output,
        positive_feature_weights=POSITIVE_FEATURE_WEIGHTS,
        init_masks=INIT_MASKS,
        readout_sparsity=sparsity,
    )
    model = CorePlusReadoutModel(base, core, readout)
    trainer = Trainer(base, model)
    iter_num, val_loss, test_corr = trainer.fit(
        val_steps=VAL_STEPS,
        learning_rate=LEARNING_RATE,
        batch_size=BATCH_SIZE,
        patience=10,   # in the paper we used 5; more reliable results with 10
        lr_decay_steps=LR_DECAY_STEPS)

    print(num_filters)
    print(trainer.compute_test_corr())

## Control: Feature space generalizes to unseen neurons

To show that our network learns common features of V1 neurons, we excluded half of the neurons when fitting the network. We then fixed the rotation-equivariant convolutional core and trained only the readout (spatial mask and feature weights) for the other half of the neurons. 

In terms of implementation, we insert a `stop_gradient` between the convolutional core and the readout for the held-out half of the neurons. This is done inside the readout class (`SpatialXFeatureJointL1TransferReadout`).
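The effect of such a stop-gradient can be illustrated with a toy model in plain Python. This is a minimal sketch with hand-written gradients, not the code of the readout class. The one-feature model, the variable names and the numbers are all hypothetical.

```python
def gradients(w, a, x, y, transfer):
    """Gradients of a squared-error loss for a toy one-feature model.

    Core:    f = w * x          (shared feature)
    Readout: r[i] = a[i] * f    (one weight per neuron)
    `transfer[i]` marks neurons whose error must not reach the core,
    which is the effect a stop_gradient would have.
    """
    f = w * x
    errors = [a_i * f - y_i for a_i, y_i in zip(a, y)]
    # Readout weights receive gradient for every neuron.
    grad_a = [2 * e * f for e in errors]
    # The core only receives gradient from the non-transfer neurons.
    grad_w = sum(2 * e * a_i * x
                 for e, a_i, t in zip(errors, a, transfer) if not t)
    return grad_w, grad_a

w, x = 0.5, 2.0
a = [1.0, 2.0, 3.0, 4.0]
y = [0.0, 1.0, 5.0, 2.0]
transfer = [False, False, True, True]   # second half: readout-only training
grad_w, grad_a = gradients(w, a, x, y, transfer)
print(grad_w, grad_a)
```

All four readout weights are updated, but only the first two neurons contribute to the gradient of the core weight, so the core cannot adapt to the held-out neurons.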

In [None]:
from cnn_sys_ident.architectures.cores import StackedRotEquiHermiteConv2dCore
from cnn_sys_ident.architectures.readouts import SpatialXFeatureJointL1TransferReadout

In [None]:
base = BaseModel(
    Dataset.load(),
    log_dir=LOG_DIR,
    log_hash='b8f78ead705cb02d09c01f9701067ba2'
)
core = StackedRotEquiHermiteConv2dCore(
    base,
    base.inputs,
    num_rotations=NUM_ROTATIONS,
    upsampling=UPSAMPLING,
    shared_biases=SHARED_BIASES,
    filter_size=FILTER_SIZE,
    num_filters=NUM_FILTERS,
    stride=STRIDE,
    rate=RATE,
    padding=PADDING,
    activation_fn=ACTIVATION_FN,
    rel_smooth_weight=REL_SMOOTH_WEIGHT,
    rel_sparse_weight=REL_SPARSE_WEIGHT,
    conv_smooth_weight=0.0112711,
    conv_sparse_weight=0.0492937,
)
readout = SpatialXFeatureJointL1TransferReadout(
    base,
    core.output,
    k_transfer=2,
    positive_feature_weights=POSITIVE_FEATURE_WEIGHTS,
    init_masks=INIT_MASKS,
    readout_sparsity=0.020616,
)
model = CorePlusReadoutModel(base, core, readout)
trainer = Trainer(base, model)
iter_num, val_loss, test_corr = trainer.fit(
    val_steps=VAL_STEPS,
    learning_rate=LEARNING_RATE,
    batch_size=BATCH_SIZE,
    patience=PATIENCE,
    lr_decay_steps=LR_DECAY_STEPS)

trainer.compute_test_corr()