In [1]:
## Import Simulation Function Library
import sys
import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

# Directory containing RetinaOptLib. On other machines set the
# RETINA_OPT_LIB_PATH environment variable instead of editing this default.
path = os.environ.get(
    'RETINA_OPT_LIB_PATH',
    'C:/Users/fritz/Google Drive/Graduate School/Research/Data Analysis/Code')
# Guard so re-running this cell does not append the same entry again and again.
if path not in sys.path:
    sys.path.append(path)
# NOTE(review): names used later (load_data, contrast, act_angle_plot, ...) are
# not defined in this notebook and presumably come from this star import —
# prefer importing them explicitly.
from RetinaOptLib import *


## TODO:
   - Implement a params data structure for RetinaOptLib
   - Normalize the linear decoder (cell STAs) using either the inf-norm or the L2 norm
   - Run the simulation for one dataset
   - Frequency analysis:
       - What frequency weighting mask are we using?
       - What do the frequency reconstructions look like relative to each other? Relative to the original? On average?
       - How does this compare with the frequency weighting mask?
       - Look at the whole visual scene and calculate its error as a function of bandwidth; it should be flat, then follow the CSF at higher frequencies
       
       
   

In [None]:
## Set Figure Display Parameters
%matplotlib notebook
matplotlib.rcParams['figure.figsize'] = (16, 9)

In [None]:
## Load the saved photo-sweep data set
# Presumably keyed "img_data_<i>" / "sim_data_<i>" per scene, as built by the
# photo-sweep branch of the simulation cell — confirm for older files.
DATA_SET_FILE = 'pic_sweep_data_set_19-09-01_23-33.dat'
data_set = load_data(DATA_SET_FILE)

# Earlier single-scene data files, kept for reference:
#img_data = load_data('img_data_19-08-31_14-02.dat')
#sim_data = load_data('sim_data_optimal_19-08-31_14-02.dat')
#img_data = load_data('waterfall_img_data_19-09-01_19-18.dat')
#sim_data = load_data('waterfall_sim_data_optimal_19-09-01_19-18.dat')
# stim_sweep_data = load_data('stim_sweep_data_19-08-31_17-05.dat')

In [None]:
# dict.mat is piece 2015-11-09-03,  Stixel 8, Eccentricity 20 deg

In [None]:
## save data
# save_data(img_data,'waterfall_img_data')
# save_data(sim_data,'waterfall_sim_data_optimal')
#save_data(stim_sweep_data,'waterfall_stim_sweep_data')
#save_data(data_set,'pic_sweep_data_set')

In [None]:
### Analysis Question Outline
## 1) How does the choice of error metric change optimal cellular activity during linear reconstruction?
    
#     a) Single Image Analysis:
#         - Reconstructed Image Comparison: Is there a perceivable difference in reconstructed Image?
#         - Activity: 
#             i) Is there a difference in which cells are selected in terms of angle?
#             ii) What are the total number of spikes required by each algo


#     b) Multi-Image Analysis:
#         - Activity:
#             i) Display the optimal activity angle relative to MSE for wMSE and SSIM reconstructions
#             ii) Histogram of total number of spikes required over all images for each metric


# 2) Distortion: How do the activities change when total spikes are limited? 

#         a) Error Metric Analysis: For each metric, plot the error of each reconstruction versus the number of allowed spikes, averaged over all images
#             i) Where do the metrics diverge? What do the reconstructions look like? 


#         b) Activity Analysis: How different is distorted activity from optimal activity, in terms of the angle between the activity vectors? 
#             i) For each image reconstruction, calculate the angle between the distorted image activity and the optimal activity
#             ii) average over all images
#             iii) plot numStims vs angle for each metric. Does one metric converge to optimal activity faster? 


# 3) Separability: Given a set of cell activities can we separate them into clusters that predict the error metric used?
       
#        a) For each metric, stack the cell activity vectors for that metric into the same matrix and perform PCA
       
#        b) Plot each data point/image with color specifying metric (maybe compare to random). See if linearly separable. 



In [2]:
## Simulation to Generate Data

# Load & preprocess the image. os.path.join keeps the path portable; the old
# "\pics\\nature_pic.jpg" literal relied on the invalid escape "\p".
img = load_raw_img(os.path.join(os.getcwd(), "pics", "nature_pic.jpg"))
ret_data = load_raw_data('dict2.mat')

## Which simulations to run
run_sim = True
run_conv_sim    = False
run_photo_sweep = False
### Simulation Parameters
stixel_dims = (20,20)  # Pixel Dimensions of Retinal STA
num_stims = -1    # Number of allowable stimulations (-1 presumably means unlimited — confirm)
electrode = False  # Solve for Electrode-Dictionary Reconstruction or Optimal Cellular Activity
max_act  = 100     # Maximum cell activity for time window

### Psychophysical Parameters
smps = 5.5              # Stimulus Monitor Pixel Size in micrometers
stixel_size = 8          # Stixel Size (number of monitor pixels per stimulus pixel)
eccentricity = 20       # Eccentricity from Fovea of Tissue Center
eye_diam = 24            # Diameter of Eye in millimeters
obj_vis_ang = (120,100)   # Size of visual field of view (degrees)
L  = 350                # L Luminance (cd/m2) of object
k = 3                   # Psychometric constant (minimum detection signal to noise ratio)
T = 100                 # Integration time of the eye (in msec)
Ng0 = 36000             # RGC Density at fovea (cells/deg^2)
eg  = 3                 # Subject-dependent Cell Density Constant (deg)
elec_vis_ang = get_elec_angs(stixel_dims, smps, stixel_size, eye_diam)
# Pupil diameter (mm), for now computed from luminance and field area;
# in future it could be measured directly via eye tracking.
pupil_diam  = 5 - 3 * np.tanh( .4*np.log(L*obj_vis_ang[0]*obj_vis_ang[1]/1600))

params = Parameters.empty()
# psychophysical parameters
params.L = L
params.vis_ang_horz = obj_vis_ang[0]
params.vis_ang_vert = obj_vis_ang[1]
params.pupil_diam = pupil_diam
params.e = eccentricity
# visual angle spanned by the electrode array
params.elec_ang_horz = elec_vis_ang[0]
params.elec_ang_vert = elec_vis_ang[1]
params.binocular = True
params.k = k
params.T = T/1000  # msec -> seconds
params.Ng0 = Ng0
params.eg = eg
# simulation parameters
params.A = ret_data.A
params.P = ret_data.P
params.e_map = ret_data.e_map
params.e_locs = ret_data.e_locs
params.num_cells = ret_data.A.shape[1]
params.num_stixels = ret_data.num_stixels
params.stixel_dims = stixel_dims
params.num_stims = num_stims
params.max_act = max_act
params.electrode = electrode

selec_dims = get_selection_dims(params, img.shape)
print('%i x %i degrees of visual angle is %i x %i pixels of the original image.'
      %(params.elec_ang_horz,params.elec_ang_vert, selec_dims[0],selec_dims[1]))


# Run the Simulation
if run_sim:
    img_data = preprocess_img(img, params)
    # NOTE(review): the last recorded run of this cell failed during this call
    # with "act_solver() missing 1 required positional argument: 'electrode'".
    # params.electrode is set above, so the missing argument is presumably an
    # internal RetinaOptLib call — confirm there.
    sim_data = metric_compar(img_data, params)

if run_conv_sim:
    img_data = preprocess_img(img, params)
    stim_sweep_data = num_stim_sweep(img_data, params)

if run_photo_sweep:
    pics_dir = './pics'
    # sorted() keeps the scene index -> file mapping reproducible across runs,
    # because os.listdir order is arbitrary.
    img_list = sorted(os.listdir(pics_dir))
    num_pics = len(img_list)
    data_set = {"num_pics": num_pics}
    for i, img_name in enumerate(img_list):
        img = load_raw_img(os.path.join(pics_dir, img_name))
        print('Running Simulation %i/%i'%(i+1, num_pics))
        img_key = "img_data_" + str(i)
        data_set[img_key] = preprocess_img(img, params)
        data_set["sim_data_" + str(i)] = metric_compar(data_set[img_key], params)
        


2 x 2 degrees of visual angle is 34 x 34 pixels of the original image.
Tiling Image ...
Tiled Image
Solving for Cellular Activities...
MSE Activity Reconsruction:


  1%|▋                                                                               | 14/1737 [00:02<05:13,  5.49it/s]

TypeError: act_solver() missing 1 required positional argument: 'electrode'

In [None]:
## single image comparison: original tile vs. each metric's optimal reconstruction
# All four panels share the color limits of the raw input image.
vmax = np.max(img)
vmin = np.min(img)
plt.rcParams["figure.figsize"] = [16,9]
imgNum = 980
cmode = 'RMS' #use RMS (standard deviation) contrast for comparison

# NOTE(review): `pixel_dims` was never defined in this notebook. Each column of
# img_data.img_set is assumed to be one tile of shape params.stixel_dims — confirm.
pixel_dims = params.stixel_dims

# Columns are reshaped column-major (order='F'), as elsewhere in this notebook.
origIm = np.reshape(img_data.img_set[:,imgNum],pixel_dims,order='F')
mseRec = np.reshape(sim_data.imgs_mse[:,imgNum],pixel_dims,order='F')
wmsRec = np.reshape(sim_data.imgs_wms[:,imgNum],pixel_dims,order='F')
ssmRec = np.reshape(sim_data.imgs_ssm[:,imgNum],pixel_dims,order='F')

panels = [
    ("Original Image", origIm, 'Original.jpg'),
    ("MSE Optimal Reconstruction", mseRec, 'MSERecons.jpg'),
    ("wMSE Optimal Reconstruction", wmsRec, 'wMSRecons.jpg'),
    ("SSIM Optimal Reconstruction", ssmRec, 'SSMRecons.jpg'),
]
for title, tile, fname in panels:
    plt.figure()
    plt.imshow(tile, cmap='bone', vmax=vmax, vmin=vmin)
    plt.axis('off')
    plt.title("%s, Contrast = %f" % (title, contrast(tile, cmode)))
    plt.savefig(fname, bbox_inches='tight')
plt.show()



In [None]:
## display entire visual scene tiled by reconstructions for each error metric
cmode='rms'
out_dir = './reconstructions'
os.makedirs(out_dir, exist_ok=True)  # savefig fails if the folder is missing

# Take the scene count from the data set rather than hard-coding it; fall back
# to the ./pics folder for data sets saved without a "num_pics" entry.
num_pics = data_set["num_pics"] if "num_pics" in data_set else len(os.listdir('./pics'))
for i in range(num_pics):
    img_data = data_set["img_data_"+str(i)]
    sim_data = data_set["sim_data_"+str(i)]
    # Color limits from the resampled original so all four panels share a scale.
    vmax = np.max(img_data.resampled_img)
    vmin = np.min(img_data.resampled_img)

    # NOTE(review): the original panel shows resampled_img but reports the
    # contrast of orig_img — confirm that is intended.
    scenes = [
        (img_data.resampled_img,
         'Original Visual Scene, Contrast = %f' % contrast(img_data.orig_img, cmode),
         'orig'),
        (sim_data.recons_mse,
         'MSE-Optimal Reconstructed Visual Scene, Contrast = %f' % contrast(sim_data.recons_mse, cmode),
         'rec_mse'),
        (sim_data.recons_wms,
         'wMSE-Optimal Reconstructed Visual Scene, Contrast = %f' % contrast(sim_data.recons_wms, cmode),
         'rec_wms'),
        (sim_data.recons_ssm,
         'SSIM-Optimal Reconstructed Visual Scene, Contrast = %f' % contrast(sim_data.recons_ssm, cmode),
         'rec_ssm'),
    ]
    for scene, title, tag in scenes:
        plt.figure()
        plt.imshow(scene, cmap='bone', vmax=vmax, vmin=vmin)
        plt.axis('off')
        plt.title(title)
        plt.savefig(os.path.join(out_dir, 'vis_scene_%s_%i.jpg' % (tag, i)), bbox_inches='tight')
    plt.show()

In [None]:
## scatter plot of contrasts

def get_conts_vec(imgs, cont_mode='rms'):
    """Contrast of every image in a column-stacked image set.

    Parameters
    ----------
    imgs : ndarray
        Image set with one image per column (num_pixels x num_imgs).
    cont_mode : str
        Contrast definition forwarded to ``contrast`` (default 'rms').

    Returns
    -------
    ndarray
        Float vector of length num_imgs; entry i is the contrast of imgs[:, i].
    """
    return np.array([contrast(imgs[:, col], cont_mode)
                     for col in range(imgs.shape[1])], dtype=float)

# Scene count from the data set itself, falling back to the ./pics folder for
# data sets saved without a "num_pics" entry.
num_pics = data_set["num_pics"] if "num_pics" in data_set else len(os.listdir('./pics'))

a_point = 1
s = 1

# Collect per-image contrasts for every scene, then concatenate once
# (avoids re-copying the growing array on every iteration).
conts_per_scene = {'mse': [], 'wms': [], 'ssm': []}
for i in range(num_pics):
    sim_data = data_set["sim_data_"+str(i)]
    conts_per_scene['mse'].append(get_conts_vec(sim_data.imgs_mse))
    conts_per_scene['wms'].append(get_conts_vec(sim_data.imgs_wms))
    conts_per_scene['ssm'].append(get_conts_vec(sim_data.imgs_ssm))
conts_mat_mse = np.concatenate(conts_per_scene['mse'])
conts_mat_wms = np.concatenate(conts_per_scene['wms'])
conts_mat_ssm = np.concatenate(conts_per_scene['ssm'])

plt.figure(figsize=(16,9))
plt.scatter(conts_mat_ssm,conts_mat_mse,label='MSE Reconstructions',s=s,alpha=a_point,marker='.')
plt.scatter(conts_mat_ssm,conts_mat_wms,label='wMSE Reconstructions',s=s,alpha=a_point,marker='.')
plt.scatter(conts_mat_ssm,conts_mat_ssm,label='SSIM Reconstructions',s=s,alpha=a_point,marker='.')
plt.title("MSE and wMSE Contrasts vs. SSIM Contrasts \n Data from %i Natural Scenes" % num_pics)
plt.xlabel('RMS Contrast of SSIM Image Reconstruction')
plt.ylabel('RMS Contrast of Image Reconstruction')
plt.legend()
plt.savefig('conts_vs_ssm.jpg',bbox_inches='tight')


# Correlation coefficients over ALL scenes (previously only the last scene's
# vectors were used here).
mse_v_wms_v_ssm_conts = np.vstack((conts_mat_mse,conts_mat_wms,conts_mat_ssm))
corr_coeffs = np.corrcoef(mse_v_wms_v_ssm_conts)
print(corr_coeffs)

# RMS difference between metrics' contrasts, pooled over all images
print(np.sqrt(np.mean((conts_mat_wms - conts_mat_ssm)**2)))
print(np.sqrt(np.mean((conts_mat_ssm - conts_mat_mse)**2)))



In [None]:
## single image cellular activity comparison
plt.figure(figsize= (16,3))
# (metric name, activity of image imgNum across cells, plot color)
acts_by_metric = [
    ('MSE', sim_data.acts_mse[:, imgNum], 'r'),
    ('wMSE', sim_data.acts_wms[:, imgNum], 'g'),
    ('SSIM', sim_data.acts_ssm[:, imgNum], 'b'),
]
cellNums = 1 + np.arange(acts_by_metric[0][1].size)
# Faint connecting lines first, then the labelled markers on top of them.
for _, acts, color in acts_by_metric:
    plt.plot(cellNums, acts, '-', c=color, alpha=.3)
for metric_name, acts, color in acts_by_metric:
    plt.plot(cellNums, acts, '.', c=color,
             label='%s-Optimal Activity, Total Spikes = %i' % (metric_name, np.sum(acts)))
plt.xlabel('Cell Number')
plt.xticks(np.arange(cellNums.min(), cellNums.max(), 5) + 1)
plt.ylabel('Number of Spikes in Optimal Reconstruction')
plt.title('Optimal Cellular Activity for Single Image')
plt.legend()
plt.savefig('SingleImageCellActComp.jpg',bbox_inches='tight')




In [None]:
# Radial plots comparing each metric's optimal activity against the MSE-optimal
# activity (act_angle_plot is a RetinaOptLib helper).
a = .05
r_lim = 1000
angle_plots = (
    (sim_data.acts_wms, 'wMSE', 'mseVSwmsAng.jpg'),
    (sim_data.acts_ssm, 'SSIM', 'mseVSssmAng.jpg'),
)
for acts_alt, metric_name, fname in angle_plots:
    plt.figure(figsize=(10,10))
    act_angle_plot(sim_data.acts_mse, acts_alt, metric_name, r_lim, a=a)
    plt.savefig(fname, bbox_inches='tight')







In [None]:
## count number of on parasol cells and off parasol cells spikes for each metric
# do the ratio of cell type selections change according to error metric? 

# #first identify which cells are on and which are off
# # On = cells 0 - 39, off = cells 40::
# off_cell_start = 40
# # then go through each image activity and calculate the number of on spikes and number of off spikes
# spikes_on_mse = np.zeros((imgData.numImgs,))
# spikes_off_mse = np.zeros((imgData.numImgs,))

# spikes_on_ssm = np.zeros((imgData.numImgs,))
# spikes_off_ssm = np.zeros((imgData.numImgs,))

# spikes_on_wms = np.zeros((imgData.numImgs,))
# spikes_off_wms = np.zeros((imgData.numImgs,))
# for i in np.arange(imgData.numImgs):
#     spikes_on_mse[i] = np.sum(simData.mseActs[0:off_cell_start-1,i])
#     spikes_off_mse[i]  =  np.sum(simData.mseActs[off_cell_start:,i])

#     spikes_on_wms[i] = np.sum(simData.wmsActs[0:off_cell_start-1,i])
#     spikes_off_wms[i]  =  np.sum(simData.wmsActs[off_cell_start:,i])
      
#     spikes_on_ssm[i] = np.sum(simData.ssmActs[0:off_cell_start-1,i])
#     spikes_off_ssm[i]  =  np.sum(simData.ssmActs[off_cell_start:,i])
    
# #  # for an image, calculate the ratio of on to off, do so for all images, for all error metric sets,
# plt.figure()
# plt.scatter(spikes_on_mse,spikes_off_mse,label="MSE",alpha=.5)
# plt.scatter(spikes_on_wms,spikes_off_wms,label="wMSE",alpha=.5)
# plt.scatter(spikes_on_ssm,spikes_off_ssm,label="SSIM",alpha=.5)
# plt.legend()

# ratio_mse = np.divide(spikes_on_mse, spikes_off_mse)
# ratio_wms = np.divide(spikes_on_wms, spikes_off_wms)
# ratio_ssm = np.divide(spikes_on_ssm, spikes_off_ssm)

# plt.figure()
# plt.scatter(ratio_mse,ratio_wms,label='MSE vs WMS Ratio',alpha=.1)
# plt.scatter(ratio_mse,ratio_ssm,label='MSE vs SSIM Ratio',alpha=.1)

# or, for each image, calculate ratio ssim/wmse on to mse on, and ratio ssim/wmse off to mse off and plot on scatter axis.
# on_ratio_mse_ssm = np.divide(spikes_on_ssm,spikes_on_mse)

# off_ratio_mse_ssm = np.divide(spikes_off_ssm,spikes_off_mse)
# plt.figure()
# plt.scatter(on_ratio_mse_ssm)
# #plt.scatter(on_ratio_mse_ssm,off_ratio_mse_ssm,alpha=.2,label='SSIM / ')

# # on_ratio_mse_wms = np.divide(spikes_on_wms,spikes_on_mse)
# # off_ratio_mse_wms = np.divide(spikes_off_wms,spikes_off_mse)
# # plt.scatter(on_ratio_mse_wms,off_ratio_mse_wms,alpha=.2)
# plt.show()



In [None]:
## spike divergence histograms


def plot_spike_divergence(sums_alt, sums_ref, metric_name, fname, bins=None):
    """Histogram the per-image percent difference in total spikes relative to MSE.

    Parameters
    ----------
    sums_alt : ndarray
        Total spikes per image for the compared metric (wMSE or SSIM).
    sums_ref : ndarray
        Total spikes per image for the MSE reconstruction (the reference).
    metric_name : str
        Metric name used in the title, e.g. 'SSIM'.
    fname : str
        Output path passed to plt.savefig.
    bins : array_like, optional
        Histogram bin edges. Defaults to 100 edges spanning 0-250 %, so values
        outside that range are not counted. That includes images where the
        metric needs fewer spikes than MSE.

    Returns
    -------
    ndarray
        Percent differences 100 * (sums_alt - sums_ref) / sums_ref.
    """
    if bins is None:
        bins = np.linspace(0, 250, 100)
    percs = 100*np.divide(sums_alt - sums_ref, sums_ref)
    mean_perc = np.mean(percs)
    plt.figure(figsize=(16,9))
    plt.hist(percs, bins=bins, label="Average = %f" % mean_perc, alpha=1, color="black")
    plt.axvline(x=mean_perc, color='black')
    # lowercase `fontsize`: mixed-case property names such as `fontSize` were
    # deprecated in matplotlib 3.3 and are rejected by later releases.
    plt.title('Histogram of %s-MSE Spike Divergence Across Image Set' % metric_name, fontsize=18)
    plt.xlabel('Percent Difference of Total Spikes', fontsize=18)
    plt.ylabel('Number of Images', fontsize=18)
    plt.legend()
    plt.savefig(fname, bbox_inches='tight')
    plt.show()
    return percs


# Total spikes per image: column sums, since acts are indexed [cell, image].
sums_mse = np.sum(sim_data.acts_mse,0)
sums_ssm = np.sum(sim_data.acts_ssm,0)
sums_wms = np.sum(sim_data.acts_wms,0)

percs_ssm = plot_spike_divergence(sums_ssm, sums_mse, 'SSIM', 'divergence_hist_mse_ssm.jpg')
percs_wms = plot_spike_divergence(sums_wms, sums_mse, 'wMSE', 'divergence_hist_mse_wms.jpg')




In [None]:
## total number of spikes for each image for each metric
# Scene count from the data set itself, falling back to the ./pics folder for
# data sets saved without a "num_pics" entry.
num_pics = data_set["num_pics"] if "num_pics" in data_set else len(os.listdir('./pics'))

# Per-image spike totals (column sums; acts are indexed [cell, image]) for
# every scene, concatenated once at the end.
spikes_per_scene = {'mse': [], 'wms': [], 'ssm': []}
for i in range(num_pics):
    sim_data = data_set["sim_data_"+str(i)]
    spikes_per_scene['mse'].append(np.sum(sim_data.acts_mse,0))
    spikes_per_scene['wms'].append(np.sum(sim_data.acts_wms,0))
    spikes_per_scene['ssm'].append(np.sum(sim_data.acts_ssm,0))
spikes_mat_mse = np.concatenate(spikes_per_scene['mse'])
spikes_mat_wms = np.concatenate(spikes_per_scene['wms'])
spikes_mat_ssm = np.concatenate(spikes_per_scene['ssm'])

a_point = 1
s = 1

plt.figure(figsize=(16,9))
plt.title("MSE and wMSE Spike Totals vs. SSIM Spike Totals \n Data from %i Natural Scenes" % num_pics)
plt.xlabel('Number of Spikes for SSIM Image Reconstruction')
plt.ylabel('Number of Spikes Required for Image Reconstruction')
plt.scatter(spikes_mat_ssm,spikes_mat_mse,s=s,alpha=a_point,marker='.',label='MSE Reconstructions')
plt.scatter(spikes_mat_ssm,spikes_mat_wms,label='wMSE Reconstructions',s=s,alpha=a_point,marker='.')
plt.scatter(spikes_mat_ssm,spikes_mat_ssm,label='SSIM Reconstructions',s=s,alpha=a_point,marker='.')
plt.legend()
plt.savefig('num_spikes_sorted_ssm.jpg',bbox_inches='tight')


# Correlation coefficients of the three metrics' spike totals over all images
mse_v_wms_v_ssm_acts = np.vstack((spikes_mat_mse,spikes_mat_wms,spikes_mat_ssm))
corr_coeffs = np.corrcoef(mse_v_wms_v_ssm_acts)
print(np.round(corr_coeffs,2))

# RMS difference in spike totals between metrics, pooled over all images
print(np.sqrt(np.mean((spikes_mat_wms - spikes_mat_ssm)**2)))
print(np.sqrt(np.mean((spikes_mat_ssm - spikes_mat_mse)**2)))


In [None]:
## plot contrast / number of spikes for each metric
from scipy import stats

a = 1
s = .05
plt.figure(figsize=(16,9))
plt.scatter(spikes_mat_mse,conts_mat_mse,label='MSE',alpha=a,s=s,c='#1f77b4')
plt.scatter(spikes_mat_wms,conts_mat_wms,label='wMSE',alpha=a,s=s,c='#ff7f0e')
plt.scatter(spikes_mat_ssm,conts_mat_ssm,label='SSIM',alpha=a,s=s,c='#2ca02c')

# Label and save BEFORE showing: with the inline backend plt.show() closes the
# figure, so title/labels set afterwards would land on a new, empty figure.
plt.title('Image RMS Contrast vs. Number of Spikes used in Reconstruction')
plt.xlabel('Total Number of Spikes in Optimal Reconstruction \n Data from 14 Natural Scenes')
plt.ylabel('RMS Image Contrast')
# Markers are tiny (s=.05); enlarge them in the legend only.
plt.legend(markerscale=20)
plt.savefig('Contrast_vs_num_spikes.jpg',bbox_inches='tight')
plt.show()

# Least-squares fit of MSE-reconstruction contrast against spike count
slope, intercept, r_value, p_value, std_err = stats.linregress(spikes_mat_mse,conts_mat_mse)

# Per-metric correlation between spike count and contrast
mse_corr = np.corrcoef(np.vstack((spikes_mat_mse,conts_mat_mse)))
wms_corr = np.corrcoef(np.vstack((spikes_mat_wms,conts_mat_wms)))
ssm_corr = np.corrcoef(np.vstack((spikes_mat_ssm,conts_mat_ssm)))
print(mse_corr)
print(wms_corr)
print(ssm_corr)

In [None]:
# ## Mask tiles by spike divergence for wMSE and SSIM and compare
# Every tile whose total-spike difference from MSE is below `threshold` is
# overwritten with the value 1, so only the strongly diverging tiles keep their
# original content.
# NOTE(review): the old header said "remove top percentile spike diverging
# images", but the code masks the LOW-divergence tiles — confirm intent.
import copy

diffs_mse_ssm = sums_ssm-sums_mse
diffs_mse_wms = sums_wms-sums_mse
threshold = 400

# indices of tiles whose divergence from MSE is below the threshold
diff_ssm = [i for i in range(img_data.num_imgs) if diffs_mse_ssm[i] < threshold]
diff_wms = [i for i in range(img_data.num_imgs) if diffs_mse_wms[i] < threshold]

# Tile side length in pixels; 34 matches the selection size printed by the
# simulation cell ("... is 34 x 34 pixels of the original image").
dx = 34
dy = dx


def mask_tiles(scene, tile_idxs):
    """Return a copy of `scene` with each listed tile overwritten with 1.

    Tile i spans rows img_data.xs[i]:+dx and columns img_data.ys[i]:+dy.
    """
    masked = copy.deepcopy(scene)
    for idx in tile_idxs:
        x = int(img_data.xs[idx])
        y = int(img_data.ys[idx])
        masked[x:x+dx, y:y+dy] = 1
    return masked


ssm_copy = mask_tiles(img_data.orig_img, diff_ssm)
# Each wMSE tile now uses its own y-coordinate. Previously ys was indexed with
# diff_ssm, which misplaced tiles and could raise IndexError.
wms_copy = mask_tiles(img_data.orig_img, diff_wms)

plt.figure()
plt.imshow(ssm_copy,cmap='bone',vmax=vmax,vmin=vmin)
plt.title('SSIM: tiles with spike divergence < %i masked' % threshold)
plt.show()

plt.figure()
plt.imshow(wms_copy,cmap='bone',vmax=vmax,vmin=vmin)
plt.title('wMSE: tiles with spike divergence < %i masked' % threshold)
plt.show()


In [None]:
# spike divergence versus reference image contrast scatter plot
# NOTE(review): sums_* come from whichever sim_data was current when the
# histogram cell ran, while img_data is whatever was loaded last — make sure
# both describe the same scene.
diffs_mse_ssm = sums_ssm-sums_mse
diffs_mse_wms = sums_wms-sums_mse

# RMS contrast of each reference (original) image tile
conts_ref = np.array([contrast(img_data.img_set[:, i], 'rms')
                      for i in range(img_data.num_imgs)], dtype=float)

plt.figure()
plt.scatter(conts_ref,diffs_mse_ssm,label='SSIM Reconstructions')
plt.scatter(conts_ref,diffs_mse_wms,label='wMSE Reconstructions')
plt.title('Spike Divergence from MSE vs. Reference Image Contrast')
plt.xlabel('RMS Contrast of Reference Image')
plt.ylabel('Total Spikes minus MSE Total Spikes')
plt.legend()
plt.show()

In [None]:
## relative error vs number of allowed stimulations
# NOTE(review): `psych_params` is not defined anywhere in this notebook (the
# parameters now live in `params`); confirm what plot_stim_comparison expects.
stim_plots = (
    ("MSE", "RelativeMSEConv.jpg"),
    ("wMS", "RelativewMSConv.jpg"),
    ("SSIM", "RelativeSSMConv.jpg"),
)
for metric_name, fname in stim_plots:
    plt.figure()
    plot_stim_comparison(img_data, stim_sweep_data, metric_name, psych_params, (20,20))
    plt.savefig(fname, bbox_inches='tight')


In [None]:
## single image comparison for numStims ~ 500
# Tidx selects an entry of the stimulation-count sweep (presumably the one near
# 500 stimulations — confirm against num_stim_sweep).
Tidx = 6
img_num = 980
# Fixed color limits shared by all four panels
tile_vmax, tile_vmin = .5, -.5

# NOTE(review): `pixel_dims` was never defined in this notebook. Each column of
# img_data.img_set is assumed to be one tile of shape params.stixel_dims — confirm.
pixel_dims = params.stixel_dims

origIm = np.reshape(img_data.img_set[:,img_num],pixel_dims,order='F')
mseRec = np.reshape(stim_sweep_data.img_set_mse[Tidx,:,img_num],pixel_dims,order='F')
wmsRec = np.reshape(stim_sweep_data.img_set_wms[Tidx,:,img_num],pixel_dims,order='F')
ssmRec = np.reshape(stim_sweep_data.img_set_ssm[Tidx,:,img_num],pixel_dims,order='F')

panels = [
    ("Original Image", origIm, 'OriginalDistorted.jpg'),
    ("MSE Distorted Reconstruction", mseRec, 'MSEReconsDistorted.jpg'),
    ("wMSE Distorted Reconstruction", wmsRec, 'wMSReconsDistorted.jpg'),
    ("SSIM Distorted Reconstruction", ssmRec, 'SSMReconsDistorted.jpg'),
]
for title, tile, fname in panels:
    plt.figure()
    plt.imshow(tile, cmap='bone', vmax=tile_vmax, vmin=tile_vmin)
    plt.axis('off')
    plt.title(title)
    plt.savefig(fname, bbox_inches='tight')
plt.show()



In [None]:
## display entire distorted visual scene tiled by reconstructions for each error metric
# (scene, title, output file) — reconstructions taken at sweep index Tidx
scene_panels = (
    (img_data.resampled_img, 'Original Visual Scene', 'OrigVisSceneDist.jpg'),
    (stim_sweep_data.rec_set_mse[Tidx], 'MSE-Distorted Reconstructed Visual Scene', 'MSEVisSceneDist.jpg'),
    (stim_sweep_data.rec_set_wms[Tidx], 'wMSE-Distorted Reconstructed Visual Scene', 'wMSVisSceneDist.jpg'),
    (stim_sweep_data.rec_set_ssm[Tidx], 'SSIM-Dist Reconstructed Visual Scene', 'SSMVisSceneDist.jpg'),
)
for panel_idx, (scene, title, fname) in enumerate(scene_panels):
    plt.figure()
    plt.imshow(scene, cmap='bone', vmax=vmax, vmin=vmin)
    plt.axis('off')
    plt.title(title)
    plt.savefig(fname, bbox_inches='tight')
    if panel_idx == 0:
        plt.show()

In [None]:
## plot the frequency weighting mask
# NOTE(review): psychParams, imgData and pixelDims are pre-refactor names that
# are not defined anywhere else in this notebook (the other cells use params /
# img_data). On a fresh kernel this cell raises NameError until it is ported.
# XO: presumably the horizontal visual angle of the scene in degrees — confirm.
XO = psychParams["XO"]

# Overrides the eccentricity in the shared dict (mutates psychParams).
psychParams["e"] = 20
N  =  int(imgData.origImg.shape[0]/imgData.sDims[0]) # number of selection blocks (number of samples of DC terms of each subImage)
# Half the block sampling rate; presumably the Nyquist frequency (cpd) of the
# block-level DC sampling, added to the frequency axis — confirm.
offset = (1/2) * (N / XO)
Wp = csf(psychParams,pixelDims,offset=offset) # weighting mask from csf(), shifted by `offset` (original comment was truncated)


# Frequency axis for the 20 mask bins (assumes a 20-pixel tile — TODO confirm):
# sfRes is the bin spacing in cycles per pixel, ppd is pixels per degree over
# the electrode field (elecXO presumably its horizontal angle in degrees).
sfRes = 1/20
ppd = 20/psychParams["elecXO"]
# frequency (cycles per degree) of each bin, shifted by offset
fs = (sfRes * ppd *(np.arange(20) )) + offset
fs = np.round(fs,1)
# label every 4th bin
ticks = np.arange(0,20,4)
labels = fs[np.arange(0,20,4)]
      
plt.figure(figsize=(10,10))
plt.imshow(Wp,cmap='bone',vmax=10,vmin=0)
plt.xticks(ticks,labels)
plt.yticks(ticks,labels)
plt.xlabel('Horizontal Spatial Frequency (cpd)')
plt.ylabel('Vertical Spatial Frequency (cpd)')
plt.colorbar()
plt.show()

In [None]:
## compare roughness of reconstructions to original

def roughness(img, dims=None, order='F'):
    """Gradient energy of an image: the 2-norm of its first-order finite differences.

    The computation uses ``np.gradient`` (first derivatives), not the Hessian.
    This matches the "Gradient Energy" labels on the plots that use it.

    Parameters
    ----------
    img : array_like
        Image as a vector or a 2-D matrix. ``np.gradient`` needs at least two
        samples along every axis.
    dims : tuple of int, optional
        If given, ``img`` is reshaped to ``dims`` first, so the gradient is
        taken along both image axes. Without it, a flattened image is
        differenced across column boundaries.
    order : {'F', 'C'}, optional
        Memory order for that reshape; 'F' matches the column-major reshapes
        used elsewhere in this notebook.

    Returns
    -------
    float
        sqrt(sum of squared gradient components over all axes).
    """
    img = np.asarray(img)
    if dims is not None:
        img = np.reshape(img, dims, order=order)
    grads = np.gradient(img)
    return np.linalg.norm(grads)

def get_rougs_vec(imgs, dims=None, order='F'):
    """Gradient energy of every column (image) of ``imgs``, returned as a list.

    ``dims`` / ``order`` are forwarded to ``roughness``. By default each
    column is treated as a flat vector.
    """
    num_imgs = imgs.shape[1]
    return [roughness(imgs[:, i], dims=dims, order=order) for i in range(num_imgs)]
    


# Scene count from the data set itself, falling back to the ./pics folder for
# data sets saved without a "num_pics" entry.
num_pics = data_set["num_pics"] if "num_pics" in data_set else len(os.listdir('./pics'))

a_point = 1
s = 1

# Gradient energy of every reconstruction (each image column treated as a flat
# vector by get_rougs_vec), collected per scene and concatenated once.
roughs_per_scene = {'mse': [], 'wms': [], 'ssm': []}
for i in range(num_pics):
    sim_data = data_set["sim_data_"+str(i)]
    roughs_per_scene['mse'].append(get_rougs_vec(sim_data.imgs_mse))
    roughs_per_scene['wms'].append(get_rougs_vec(sim_data.imgs_wms))
    roughs_per_scene['ssm'].append(get_rougs_vec(sim_data.imgs_ssm))
roughs_mat_mse = np.concatenate(roughs_per_scene['mse'])
roughs_mat_wms = np.concatenate(roughs_per_scene['wms'])
roughs_mat_ssm = np.concatenate(roughs_per_scene['ssm'])

plt.figure(figsize=(16,9))
plt.scatter(roughs_mat_ssm,roughs_mat_mse,label='MSE Reconstructions',s=s,alpha=a_point,marker='.')
plt.scatter(roughs_mat_ssm,roughs_mat_wms,label='wMSE Reconstructions',s=s,alpha=a_point,marker='.')
plt.scatter(roughs_mat_ssm,roughs_mat_ssm,label='SSIM Reconstructions',s=s,alpha=a_point,marker='.')
plt.title("MSE and wMSE Gradient Energy vs. SSIM Gradient Energy \n Data from %i Natural Scenes" % num_pics)
plt.xlabel('Gradient Energy of SSIM Image Reconstruction')
plt.ylabel('Gradient Energy of Image Reconstruction')
plt.legend()
plt.savefig('roughs_vs_ssm.jpg',bbox_inches='tight')


# Correlation coefficients over ALL scenes (previously only the last scene's
# vectors were used here).
mse_v_wms_v_ssm_roughs = np.vstack((roughs_mat_mse,roughs_mat_wms,roughs_mat_ssm))
corr_coeffs = np.corrcoef(mse_v_wms_v_ssm_roughs)
print(corr_coeffs)

# RMS difference in gradient energy between metrics, pooled over all images
print(np.sqrt(np.mean((roughs_mat_wms - roughs_mat_ssm)**2)))
print(np.sqrt(np.mean((roughs_mat_ssm - roughs_mat_mse)**2)))






In [None]:
## plot gradient energy / number of spikes for each metric
from scipy import stats

a = 1
s = .05
plt.figure(figsize=(16,9))
plt.scatter(spikes_mat_mse,roughs_mat_mse,label='MSE',alpha=a,s=s,c='#1f77b4')
plt.scatter(spikes_mat_wms,roughs_mat_wms,label='wMSE',alpha=a,s=s,c='#ff7f0e')
plt.scatter(spikes_mat_ssm,roughs_mat_ssm,label='SSIM',alpha=a,s=s,c='#2ca02c')

# Label and save BEFORE showing: with the inline backend plt.show() closes the
# figure, so title/labels set afterwards would land on a new, empty figure.
plt.title('Image Gradient Energy vs. Number of Spikes used in Reconstruction')
plt.xlabel('Total Number of Spikes in Optimal Reconstruction \n Data from 14 Natural Scenes')
plt.ylabel('Gradient Energy')
# Markers are tiny (s=.05); enlarge them in the legend only.
plt.legend(markerscale=20)
plt.savefig('roughs_vs_num_spikes.jpg',bbox_inches='tight')
plt.show()

# Per-metric correlation between spike count and gradient energy
mse_corr = np.corrcoef(np.vstack((spikes_mat_mse,roughs_mat_mse)))
wms_corr = np.corrcoef(np.vstack((spikes_mat_wms,roughs_mat_wms)))
ssm_corr = np.corrcoef(np.vstack((spikes_mat_ssm,roughs_mat_ssm)))
print(mse_corr)
print(wms_corr)
print(ssm_corr)