In [4]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import warnings
warnings.filterwarnings('ignore')
from sklearn.cross_validation import StratifiedShuffleSplit
from sklearn.grid_search import GridSearchCV
import numpy as np
from scipy import io
from sklearn.linear_model import LogisticRegression
import matplotlib.pyplot as plt
import seaborn as sns
from utils import model_selection
from collections import defaultdict
from sklearn.metrics import mutual_info_score, accuracy_score
from tqdm import tqdm, trange
from itertools import product
import pandas as pd

# Global plotting configuration: seaborn 'ticks' style with the 'notebook' context and enlarged fonts.
sns.set_style('ticks')
sns.set_context('notebook',font_scale=1.3)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
# Load the MATLAB recording file and unpack its metadata into plain Python/NumPy objects.
dat = io.loadmat('/data/neurons08.mat')

# Array-valued metadata: squeeze away the singleton dimensions that loadmat adds.
ori = dat['ori'].squeeze()
contrasts = dat['contrast'].squeeze()
times = dat['times'].squeeze()

# Scalar metadata: loadmat hands MATLAB scalars back as arrays (typically shape (1, 1)).
# Calling int() directly on an array with ndim > 0 is deprecated since NumPy 1.25 (and the
# warning is hidden by the blanket filter in the first cell), so pull out the single element
# with .item() before converting.
n_bins, n_trials, n_neurons = (
    int(np.asarray(dat[key]).item()) for key in ('nBins', 'nTrials', 'nNeurons')
)

# Stack the nested 'feat' entries into one dense array. e[:-1] drops the last entry of each
# element (original note: ":-1 to exclude last incomplete trials").
responses = np.stack([np.stack(e[:-1], axis=0) for e in dat['feat'].squeeze()], axis=0)

print('contrast X orientations X neurons X time X trials', responses.shape)


contrast X orientations X neurons X time X trials (2, 7, 20, 90, 85)


### a) compute mutual information for each single neuron

In [34]:
# Decode the stimulus orientation from each single neuron and record, per resample, the test
# accuracy and the mutual information between true and predicted labels.
orientation_classes = [0,2]                   # indices into the orientation axis of `responses`
time_idx = (times >= 50) & (times <= 250)     # boolean mask over `times` (units presumably ms - TODO confirm)
results = defaultdict(list)                   # column name -> list of values, one entry per (neuron, resample)
params = dict(
            C=10.**np.arange(-5.,1., .5),
            penalty=['l1','l2']
        )                                     # hyper-parameter grid handed to model_selection
i_contrast = 1                                # index into the contrast axis of `responses`

# Class labels 1., 2., ... with n_trials consecutive rows per orientation class. They do not
# depend on the neuron, so build them once, and derive the number of classes from
# `orientation_classes` instead of hard-coding two.
y = np.repeat(np.arange(1., len(orientation_classes) + 1.), n_trials)

for neuron in trange(1, n_neurons):           # index 0 is skipped; the population cell uses the same 1: subset
    # Rows are trials, columns are the selected time bins; orientation classes are stacked vertically.
    X = np.vstack([responses[i_contrast, i_ori, neuron, time_idx, :].T for i_ori in orientation_classes])

    # A fixed random_state makes the resampling reproducible across kernel restarts and scores
    # every neuron on the same train/test partitions.
    splits = StratifiedShuffleSplit(y, n_iter=5, test_size=.3, random_state=0)
    for resample, (train_idx, test_idx) in enumerate(splits):
        model = LogisticRegression(C=1.,penalty='l2')

        # Hyper-parameters are tuned on the training portion only; the held-out trials are used
        # exclusively for scoring.
        best = model_selection(model, X[train_idx], y[train_idx], params, cv=5,  scoring='accuracy', n_jobs=5)
        yhat = best.predict(X[test_idx])

        results['neuron'].append(neuron)
        results['accuracy'].append(accuracy_score(y[test_idx], yhat))
        results['resample'].append(resample)
        results['mode'].append('single')
        # mutual_info_score works in nats (natural log); dividing by ln(2) converts to bits.
        results['mutual information [bits]'].append(mutual_info_score(y[test_idx], yhat)/np.log(2.))
        
        
        

100%|██████████| 19/19 [00:27<00:00,  1.39s/it]


### b) compute mutual information for the entire population

In [35]:
# Decode the stimulus orientation from the joint activity of neurons 1: (labelled '2-20') and
# append the per-resample scores to the shared `results` table.
# Relies on orientation_classes, time_idx, params, i_contrast and results from the single-neuron cell.

# Make the cell idempotent: remove population rows left by an earlier execution of this cell,
# otherwise every re-run would append duplicates and distort the averages below.
keep = [mode != 'population' for mode in results['mode']]
for key in list(results):
    results[key] = [value for value, flag in zip(results[key], keep) if flag]

# The trial axis stays last after indexing, so reshape((-1, n_trials)) collapses all remaining
# axes (neurons and selected time bins) into one feature axis; .T then puts trials on the rows.
X = np.vstack([responses[i_contrast, i_ori, 1:, time_idx, :].reshape((-1, n_trials)).T for i_ori in orientation_classes])

# Class labels 1., 2., ... with n_trials consecutive rows per orientation class.
y = np.repeat(np.arange(1., len(orientation_classes) + 1.), n_trials)

# A fixed random_state makes the train/test partitions reproducible across kernel restarts.
for resample, (train_idx, test_idx) in enumerate(StratifiedShuffleSplit(y, n_iter=5, test_size=.3, random_state=0)):
    model = LogisticRegression(C=1.,penalty='l2')

    # Hyper-parameters are tuned on the training portion only; held-out trials are used for scoring.
    best = model_selection(model, X[train_idx], y[train_idx], params, cv=5,  scoring='accuracy', n_jobs=5)
    yhat = best.predict(X[test_idx])

    results['neuron'].append('2-20')
    results['accuracy'].append(accuracy_score(y[test_idx], yhat))
    results['resample'].append(resample)
    results['mode'].append('population')
    # mutual_info_score works in nats (natural log); dividing by ln(2) converts to bits.
    results['mutual information [bits]'].append(mutual_info_score(y[test_idx], yhat)/np.log(2.))
     

### c) compare the single-neuron mutual information, the population mutual information, and the sum of the single-neuron mutual informations

In [41]:
# Turn the collected per-resample records into a table, then average every column over the
# resamples separately for each (neuron, mode) pair. The bare last expression displays the table.
df = pd.DataFrame(results)
avg = (
    df
    .groupby(by=['neuron', 'mode'])
    .mean()
)
avg

Unnamed: 0_level_0,Unnamed: 1_level_0,accuracy,mutual information [bits],resample
neuron,mode,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
1,single,0.64,0.082934,2
2,single,0.5,0.00454,2
3,single,0.916,0.630726,2
4,single,0.508,0.001733,2
5,single,0.624,0.04965,2
6,single,0.48,0.018047,2
7,single,0.52,0.004185,2
8,single,0.924,0.666139,2
9,single,0.56,0.022233,2
10,single,0.588,0.081051,2


In [43]:
# Average the per-neuron means within each decoding mode. Grouping on the 'mode' index level
# (instead of reset_index() + groupby('mode')) keeps the mixed int/str 'neuron' labels out of the
# aggregated columns; with pandas >= 2.0 averaging that object column raises a TypeError.
avg.groupby(level='mode').mean()

Unnamed: 0_level_0,accuracy,mutual information [bits],resample
mode,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
population,0.992,0.95108,2
single,0.611789,0.127591,2


In [44]:
# Sum the per-neuron means within each decoding mode. Grouping on the 'mode' index level keeps
# the mixed int/str 'neuron' labels out of the summed columns (pandas >= 2.0 would otherwise
# add up / concatenate them into the output).
# Only the 'mutual information [bits]' column is meaningful as a sum (single-neuron total vs.
# population); the summed accuracy and resample columns carry no interpretation.
avg.groupby(level='mode').sum()

Unnamed: 0_level_0,accuracy,mutual information [bits],resample
mode,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
population,0.992,0.95108,2
single,11.624,2.424234,38
