In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import datajoint as dj
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
sns.set_style('ticks', rc={'image.cmap': 'bwr'})

import os
import sys
import inspect

# Put the directory two levels above the kernel's working directory on
# sys.path, presumably the repository root that contains the `cnn_sys_ident`
# package imported in the next cell (verify against the repo layout).
# os.getcwd() replaces the shell escape `!pwd`: it is plain Python, needs no
# `pwd` executable (so it also works on Windows), and always yields a str.
p = os.path.dirname(os.path.dirname(os.getcwd()))
# Guard against duplicate entries when the cell is re-run in the same kernel.
if p not in sys.path:
    sys.path.append(p)

In [None]:
from cnn_sys_ident.mesonet import MODELS
from cnn_sys_ident.mesonet.controls import FitTrialSubset

In [None]:
# Test correlation vs. fraction of training trials, for the CNN and the
# Hermite ("RotEqui") model. For each fraction, the fit with the lowest
# validation loss is taken from FitTrialSubset joined with the model relation.
FRAC_TOL = 0.01  # tolerance for matching the numeric `frac_trials` attribute
FIGURE_PATH = os.path.join('figures', 'frac_trials.eps')

frac_trials = np.array([0.125, 0.25, 0.5, 1])
# Pair each legend label with its model so the two cannot drift out of order.
model_specs = [('CNN', MODELS['CNNSparse']),
               ('RotEqui', MODELS['HermiteSparse'])]

fig, ax = plt.subplots(figsize=(2.5, 2.5))
for label, model in model_specs:
    test_corr = []
    for frac in frac_trials:
        # Two-sided window instead of float equality. A one-sided
        # `frac_trials < frac + tol` bound would also match every smaller
        # fraction, so the best-val_loss fit plotted at this x-position could
        # have been trained on fewer trials than the axis claims.
        rel = (FitTrialSubset() * model
               & 'frac_trials > {}'.format(frac - FRAC_TOL)
               & 'frac_trials < {}'.format(frac + FRAC_TOL))
        best = rel.fetch('test_corr', order_by='val_loss', limit=1)
        assert len(best), 'no {} fits with frac_trials ~ {}'.format(label, frac)
        test_corr.append(best[0])
    ax.semilogx(100 * frac_trials, test_corr, 'o-', label=label)

ax.legend()
ax.set_xlabel('% of trials')
ax.set_ylabel('Correlation')
ax.set_xlim([10, 110])
ax.set_ylim([0, 0.5])
# Fix the major ticks BEFORE despine: trim=True clips each spine to the
# first/last major tick, so it must see the ticks that are actually shown.
ax.set_xticks(100 * frac_trials)
sns.despine(ax=ax, trim=True, offset=5)
# '{:g}' keeps 12.5 but renders whole percentages without a decimal point.
ax.set_xticklabels(['{:g}'.format(t) for t in 100 * frac_trials])
ax.minorticks_off()
# Lay out last, once limits, ticks and spine offsets are final.
fig.tight_layout()
os.makedirs(os.path.dirname(FIGURE_PATH), exist_ok=True)
fig.savefig(FIGURE_PATH, format='eps')