In [None]:
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import sys
import os
from scipy import linalg
import visionloader as vl
import torch
import torch.nn as nn
sys.path.insert(0, '/home/agogliet/gogliettino/projects/natural-scenes-reco/repos/imagenet-rgc-reco/')
from src.Dataset import Dataset
import src.models as models
import torch.optim as optim
import time
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import mean_squared_error as mse
from skimage.metrics import peak_signal_noise_ratio as psnr
import palettable
import matplotlib as mpl
from matplotlib.patches import Ellipse

## Load data

In [None]:
# Load the natural-scenes (NS) test arrays, the cell-id lookup tables and the
# white-noise (WN) Vision dataset for the yass-sorted version of this recording.
ns_data_dir = './tmp/2018-08-07-5/yass_data001-yass_data002/'

# NOTE: allow_pickle=True will unpickle arbitrary objects; only load trusted files.
cellids_dict = np.load(ns_data_dir + 'cellids_dict.npy', allow_pickle=True).item()
responses = np.load(ns_data_dir + 'test_X1.npy')  # indexed [image, NS cell] further down
ns_tensor = np.load(ns_data_dir + 'test_Y1.npy')  # indexed by image along axis 0 further down

# The WN dataset supplies the STA fits and the stimulus width/height used for plotting.
wn_vcd = vl.load_vision_data(
    '/Volumes/Analysis/2018-08-07-5/yass_data000/data000/',
    'data000',
    include_neurons=True,
    include_params=True,
    include_runtimemovie_params=True,
)

# Sorted id lists; positions in `ns_cellids` are used to index columns of `responses`.
wn_cellids = sorted(cellids_dict['celltypes'])
ns_cellids = sorted(cellids_dict['ns1_to_ns2'])

# Alternative (kilosort-sorted) inputs, kept for reference:
# cellids_dict = np.load('./tmp/2018-08-07-5/kilosort_data001-kilosort_data002/'\
#                        'cellids_dict.npy',allow_pickle=True).item()
# responses = np.load('./tmp/2018-08-07-5/kilosort_data001-kilosort_data002/test_X1.npy')
# ns_tensor = np.load('./tmp/2018-08-07-5/kilosort_data001-kilosort_data002/test_Y1.npy')
# wn_vcd = vl.load_vision_data('/Volumes/Analysis/2018-08-07-5/kilosort_data011/data011/',
#                              'data011',include_neurons=True,include_params=True,
#                               include_runtimemovie_params=True)
# wn_cellids = sorted(list(cellids_dict['celltypes'].keys()))
# ns_cellids = sorted(list(cellids_dict['ns1_to_ns2']))

## Make example plots of RGC activations.

In [None]:
# For a few example test images, draw the image next to four RGC mosaics
# (ON parasol, OFF parasol, ON midget, OFF midget). Each cell is drawn as its
# STA-fit ellipse, filled according to its entry in `responses` for that image.
NUM_SIGMAS_RF_FIT = 2          # ellipse width/height = this many STA-fit std devs
EXAMPLE_IMAGE_INDS = (0, 45)   # test images to render
width = int(wn_vcd.runtimemovie_params.width)
height = int(wn_vcd.runtimemovie_params.height)
cmap = palettable.colorbrewer.diverging.Spectral_11_r.mpl_colormap
norm = mpl.colors.Normalize(vmin=0,vmax=15)

# The panel assignment, STA fit and response column of a cell do not depend on
# the image, so resolve them once instead of once per figure.
rf_fits = []
for wn_cell in wn_cellids:
    celltype = cellids_dict['celltypes'][wn_cell]

    if "on" in celltype and "parasol" in celltype:
        ind = 1
    elif "off" in celltype and "parasol" in celltype:
        ind = 2
    elif "on" in celltype and 'midget' in celltype:
        ind = 3
    elif "off" in celltype and 'midget' in celltype:
        ind = 4
    else:
        # Any other cell type has no panel. Without this branch `ind` kept the
        # previous cell's value (or was undefined for the first cell), so the
        # cell was silently drawn in the wrong mosaic.
        continue

    sta_fit = wn_vcd.get_stafit_for_cell(wn_cell)

    # Map the WN cell to its NS counterpart and to that cell's position in
    # `ns_cellids`, which is used as the column index into `responses`.
    ns_cell = cellids_dict['wn_to_ns1'][wn_cell]
    rf_fits.append((ind, sta_fit, ns_cellids.index(ns_cell)))

for i in EXAMPLE_IMAGE_INDS:

    # Skip example indices beyond the number of available images.
    if i >= ns_tensor.shape[0]:
        continue

    fig,ax = plt.subplots(1,5,figsize=(50,10))

    # Panel 0: the image itself, restricted to 32:288 along axis 2
    # (presumably the region covered by the WN stimulus -- TODO confirm).
    ax[0].imshow(ns_tensor[i,:,32:288,:],cmap='gray')

    # Panels 1-4: stimulus coordinates with the y axis inverted.
    # add_artist does not autoscale, so the limits must be set explicitly.
    for mosaic_ax in ax[1:]:
        mosaic_ax.set(xlim=[0,(width-1)],ylim=[(height-1),0],aspect=1)

    for ind, sta_fit, ns_cell_ind in rf_fits:
        fit = Ellipse(xy = (sta_fit.center_x, sta_fit.center_y),
                      width = NUM_SIGMAS_RF_FIT * sta_fit.std_x,
                      height = NUM_SIGMAS_RF_FIT * sta_fit.std_y,
                      angle = sta_fit.rot * (180 / np.pi) * -1)

        # Fill colour encodes this cell's response to image i.
        n_spikes = responses[i,ns_cell_ind]
        fit.set_facecolor(cmap(norm(n_spikes)))
        fit.set_edgecolor('k')
        ax[ind].add_artist(fit)

    for panel_ax in ax:
        panel_ax.axis('off')

    fig.savefig('./tmp/2018-08-07-5/yass_data001-yass_data002/'\
                        'linear-recons/trained-model/figures/encoding_%s.pdf'%str(i))
    plt.show()
    plt.close(fig)  # release the figure once it has been saved and shown

In [None]:
# Render a standalone horizontal colorbar for the mosaic figures and save it
# as a PDF. Uses the same colormap (`cmap`) and the same 0-15 normalization.
fig, ax = plt.subplots(figsize=(3,.15))
plt.rc('font',size=18)

norm = mpl.colors.Normalize(vmin=0, vmax=15)
cb1 = mpl.colorbar.ColorbarBase(ax, cmap=cmap, norm=norm, orientation='horizontal')
cb1.set_label('firing rate (spikes)')

cbar_path = ('./tmp/2018-08-07-5/yass_data001-yass_data002/'
             'linear-recons/trained-model/figures/cbar.pdf')
plt.savefig(cbar_path)

## Load in training results

In [None]:
# Load the saved result dictionaries of the two trained models.
results_root = './tmp/2018-08-07-5/yass_data001-yass_data002/'

# NOTE: allow_pickle=True will unpickle arbitrary objects; only load trusted files.
trained_model = np.load(
    results_root + 'linear-recons/trained-model/trained_model_10_epochs.npy',
    allow_pickle=True,
).item()
trained_model_end_to_end = np.load(
    results_root + 'end-to-end/trained-model/trained_model_2_epochs.npy',
    allow_pickle=True,
).item()

# Below, Y_test is treated as the ground truth, X_test as the 'linear'
# reconstruction, Y_hat_test as 'linear + CNN' and Y_hat_test_end_to_end as
# the 'end-to-end CNN' output.
Y_test = trained_model['Y_test']
X_test = trained_model['X_test']
Y_hat_test = trained_model['Y_hat_test']
Y_hat_test_end_to_end = trained_model_end_to_end['Y_hat_test']

In [None]:
# Save one four-panel figure per test image. Panels, left to right:
# ground truth, linear, linear + CNN, end-to-end CNN.
recon_panels = (Y_test, X_test, Y_hat_test, Y_hat_test_end_to_end)

for i in range(Y_hat_test.shape[0]):

    # To render only a hand-picked subset, uncomment:
#     if i not in [3,26,45,92,124,138,143]:
#         continue

    fig,ax = plt.subplots(1,4,figsize=(20,5))
    fig.subplots_adjust(wspace=.05)

    for panel_ax, images in zip(ax, recon_panels):
        panel_ax.imshow(images[i,...].squeeze(),cmap='gray')
        panel_ax.axis('off')

    print(i)

    fig.savefig('./tmp/2018-08-07-5/yass_data001-yass_data002/'\
                        'linear-recons/trained-model/figures/recon_example_%s.pdf'%str(i))
    plt.show()
    # One figure is created per test image; close each one explicitly so they
    # do not accumulate in memory when the backend does not do it on show().
    plt.close(fig)

## Calculate MSE, SSIM, and PSNR

In [None]:
# Per-image SSIM, MSE and PSNR of each reconstruction against Y_test.
# Result layout: metric_dict[metric][model] -> 1-D array with one value per test image.
# NOTE(review): ssim/psnr are called without `data_range`, so skimage derives it
# from the image dtype -- confirm this matches how the images are scaled.
metric_fns = {'ssim': ssim, 'mse': mse, 'psnr': psnr}
reconstructions = {'linear': X_test,
                   'linear_cnn': Y_hat_test,
                   'cnn_end_to_end': Y_hat_test_end_to_end}

metric_dict = dict()

for metric, metric_fn in metric_fns.items():
    metric_dict[metric] = {
        model_name: np.asarray([metric_fn(Y_test[k,...].squeeze(),
                                          recon[k,...].squeeze())
                                for k in range(Y_test.shape[0])])
        for model_name, recon in reconstructions.items()
    }

In [None]:
# Median per-image MSE of the 'linear_cnn' reconstructions (Y_hat_test vs. Y_test).
np.median(metric_dict['mse']['linear_cnn'])

In [None]:
# Per-image MSE of the linear and linear + CNN reconstructions, in test-set order.
# The curves are labelled so the figure can be read on its own.
fig_mse, ax_mse = plt.subplots()
for model_name in ('linear', 'linear_cnn'):
    ax_mse.plot(metric_dict['mse'][model_name], label=model_name)
ax_mse.set_xlabel('test image index')
ax_mse.set_ylabel('mse')
ax_mse.legend()
plt.show()

## Make scatter plots comparing the linear, linear + CNN, and end-to-end CNN reconstructions

In [None]:
# Grid of scatter plots: one row per pair of models, one column per metric.
# Each point is one test image; the red diagonal marks equal performance.
comps = [['linear','linear_cnn'],['linear_cnn','cnn_end_to_end'],['linear','cnn_end_to_end']]
nrow = len(comps)        # one row per model comparison
ncol = len(metric_dict)  # one column per metric

# plt.subplots takes (nrows, ncols); squeeze=False keeps `ax` two-dimensional
# even if there is only one comparison or one metric.
fig,ax = plt.subplots(nrow,ncol,figsize=(15,15),squeeze=False)
fig.subplots_adjust(wspace=.3,hspace=.3)
plt.rc('font',size=18)

for i, (comp1, comp2) in enumerate(comps):

    for j, metric in enumerate(metric_dict):
        panel = ax[i,j]
        panel.scatter(metric_dict[metric][comp1],
                      metric_dict[metric][comp2],c='k')
        panel.set_xlabel('%s'%comp1)
        panel.set_ylabel('%s'%comp2)

        # Use the same range on both axes so the axes diagonal drawn below
        # coincides with the identity line y = x.
        ylim = panel.get_ylim()
        xlim = panel.get_xlim()
        lim = np.min((ylim[0],xlim[0])),np.max((ylim[1],xlim[1]))
        panel.set_xlim(lim)
        panel.set_ylim(lim)

        # Diagonal in axes coordinates (corner to corner).
        panel.plot([0, 1], [0, 1], transform=panel.transAxes,c='r')

        panel.set_title("%s"%(metric))

plt.savefig('model_analysis_scatter.pdf')
plt.show()

In [None]:
# Mean per-image PSNR of the 'cnn_end_to_end' reconstructions.
np.mean(metric_dict['psnr']['cnn_end_to_end'])

In [None]:
# Mean per-image PSNR of the 'linear_cnn' reconstructions.
np.mean(metric_dict['psnr']['linear_cnn'])

In [None]:
# Mean per-image PSNR of the 'linear' reconstructions.
np.mean(metric_dict['psnr']['linear'])

In [None]:
# Scratch check of the shared lower axis limit used in the scatter panels.
# A bare `np.min()` has no array argument and always raises TypeError, which
# stops a Restart-and-Run-All; evaluate it on the first panel's limits instead.
# NOTE(review): the intended argument is a guess -- confirm or delete this cell.
np.min((ax[0,0].get_ylim()[0], ax[0,0].get_xlim()[0]))

In [None]:
# Scratch inspection: y-axis limits of the top-left scatter panel.
ax[0,0].get_ylim()

In [None]:
# Scratch inspection: value left in `comp2` by the last iteration of the scatter loop.
comp2

In [None]:
# Scratch inspection: model names for which PSNR was computed.
metric_dict['psnr'].keys()

In [None]:
# Dump the raw per-image values of every metric (PSNR, then MSE, then SSIM).
for metric_name in ('psnr', 'mse', 'ssim'):
    print(metric_dict[metric_name])