In [None]:
# Load IPython's autoreload extension and switch it to mode 2, so imported
# modules are re-imported before each cell execution (useful while the
# project package is being edited alongside this notebook).
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

import datajoint as dj
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

import seaborn as sns
sns.set_style('ticks')

import os
import sys
import inspect

# Put the directory two levels above the current working directory on
# sys.path (presumably the project root, so the `cnn_sys_ident` imports in the
# next cell resolve -- confirm against the repository layout).
# os.getcwd() is used instead of the `!pwd` shell escape: it needs no external
# `pwd` binary, works on every platform, and stays valid plain Python when the
# notebook is exported to a script.
p = os.path.dirname(os.path.dirname(os.getcwd()))
if p not in sys.path:
    # Guarded so that re-running the cell does not add duplicate entries.
    sys.path.append(p)

In [None]:
from cnn_sys_ident.mesonet.parameters import Fit, Model, Core, Readout, RegPath
from cnn_sys_ident.mesonet.data import MultiDataset
from cnn_sys_ident.mesonet import MODELS

# Main model

In [None]:
# Value used by the cells below to restrict models on `num_filters_2`.
num_filters = 16

# Restriction selecting the dataset by its hash; applied to every query below.
data_key = {'data_hash': 'cfcd208495d565ef66e7dff9f98764da'}

In [None]:
# HermiteSparse models joined with the dataset table, restricted to the
# selected data hash, the two flag settings below, and `num_filters` filters
# in `num_filters_2`.
model_rel = (
    MODELS['HermiteSparse'] * MultiDataset()
    & data_key
    & 'positive_feature_weights=False AND shared_biases=False'
    & {'num_filters_2': num_filters}
)

# Join with the fit results once and reuse the relation for count and fetch.
fits = Fit() * model_rel
print(len(fits))

# Fetch at most five fits, ordered by validation loss, so index 0 is the fit
# with the lowest validation loss.
val_loss, test_corr = fits.fetch(
    'val_loss', 'test_corr', order_by='val_loss', limit=5)

# NOTE(review): the printed loss/corr belong to the best fit only; the
# "+/-" term is the standard deviation of test_corr over the fetched fits.
print('Loss: {:.1f}, avg corr: {:.3f} +/- {:.4f}'.format(
    val_loss[0], test_corr[0], test_corr.std()))

# Control: dense, L2-regularized feature weights

In [None]:
# Control: HermiteDenseSeparate models joined with the dataset table, using
# the same data hash, flag settings, and `num_filters_2` restriction as the
# main model query.
model_rel = (
    MODELS['HermiteDenseSeparate'] * MultiDataset()
    & data_key
    & 'positive_feature_weights=False AND shared_biases=False'
    & {'num_filters_2': num_filters}
)

# Join with the fit results once and reuse the relation for count and fetch.
fits = Fit() * model_rel
print(len(fits))

# Fetch at most five fits, ordered by validation loss, so index 0 is the fit
# with the lowest validation loss.
val_loss, test_corr = fits.fetch(
    'val_loss', 'test_corr', order_by='val_loss', limit=5)

# NOTE(review): the printed loss/corr belong to the best fit only; the
# "+/-" term is the standard deviation of test_corr over the fetched fits.
print('Loss: {:.1f}, avg corr: {:.3f} +/- {:.4f}'.format(
    val_loss[0], test_corr[0], test_corr.std()))

# Control: positive feature weights

In [None]:
# Control: HermiteSparse models with positive_feature_weights=True (all other
# restrictions match the main model query).
model_rel = (
    MODELS['HermiteSparse'] * MultiDataset()
    & data_key
    & 'positive_feature_weights=True AND shared_biases=False'
    & {'num_filters_2': num_filters}
)

# Join with the fit results once and reuse the relation for count and fetch.
fits = Fit() * model_rel
print(len(fits))

# Fetch at most five fits, ordered by validation loss, so index 0 is the fit
# with the lowest validation loss.
val_loss, test_corr = fits.fetch(
    'val_loss', 'test_corr', order_by='val_loss', limit=5)

# NOTE(review): the printed loss/corr belong to the best fit only; the
# "+/-" term is the standard deviation of test_corr over the fetched fits.
print('Loss: {:.1f}, avg corr: {:.3f} +/- {:.4f}'.format(
    val_loss[0], test_corr[0], test_corr.std()))

# Baseline: regular CNNs

In [None]:
# Baseline CNN configurations. Each inner list supplies the values for the
# `num_filters_0`, `num_filters_1`, ... restriction, in that order.
cnn_filter_nums = [
    [32, 32, 32],
    [64, 64, 64],
    [128, 128, 128],
    [128, 128, 256],
]

# The model/dataset restriction does not depend on the filter configuration,
# so it is built once here rather than on every loop iteration.
model_rel = MODELS['CNNSparse'] * MultiDataset() & data_key \
    & 'positive_feature_weights=False'

for n in cnn_filter_nums:
    # Map position i in the configuration to the `num_filters_<i>` attribute.
    key = {'num_filters_{:d}'.format(i): num for i, num in enumerate(n)}

    # Fits for this configuration; reused for both the count and the fetch.
    fits = Fit() * model_rel & key
    print(len(fits))

    # Fetch at most five fits, ordered by validation loss, so index 0 is the
    # fit with the lowest validation loss.
    val_loss, test_corr = fits.fetch(
        'val_loss', 'test_corr', order_by='val_loss', limit=5)

    # NOTE(review): loss/corr are those of the best fit; the "+/-" term is the
    # standard deviation of test_corr over the fetched fits.
    print('Features: {}, Loss: {:.1f}, Avg. corr: {:.3f} +/- {:.4f}'.format(
        n, val_loss[0], test_corr[0], test_corr.std()))