Skip to content

Commit

Permalink
Merge 976ecbf into a92889e
Browse files Browse the repository at this point in the history
  • Loading branch information
m-kowalska committed Jul 8, 2019
2 parents a92889e + 976ecbf commit d62a8a0
Show file tree
Hide file tree
Showing 13 changed files with 981 additions and 290 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,14 @@
import os
from os.path import expanduser
import numpy as np
from numpy.linalg import LinAlgError
import matplotlib.pyplot as plt
from figure_properties import *
import matplotlib.gridspec as gridspec
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
from matplotlib.ticker import FuncFormatter
import datetime
import time

from kcsd import SpectralStructure, KCSD1D
from kcsd import KCSD1D
import targeted_basis as tb

__abs_file__ = os.path.abspath(__file__)
Expand All @@ -24,8 +22,7 @@ def _html(r, g, b):


def stability_M(csd_profile, n_src, ele_lims, true_csd_xlims,
total_ele, ele_pos, pots,
method='cross-validation', Rs=None, lambdas=None):
total_ele, ele_pos, pots, R_init=0.23):
"""
Investigates stability of reconstruction for different number of basis
sources
Expand All @@ -46,15 +43,9 @@ def stability_M(csd_profile, n_src, ele_lims, true_csd_xlims,
Electrodes positions.
pots: numpy array
Values of potentials at ele_pos.
method: string
Determines the method of regularization.
Default: cross-validation.
Rs: numpy 1D array
Basis source parameter for crossvalidation.
Default: None.
lambdas: numpy 1D array
Regularization parameter for crossvalidation.
Default: None.
R_init: float
Initial value of R parameter - width of basis source
Default: 0.23.
Returns
-------
Expand All @@ -70,14 +61,19 @@ def stability_M(csd_profile, n_src, ele_lims, true_csd_xlims,
for i, value in enumerate(n_src):
pots = pots.reshape((len(ele_pos), 1))
obj = KCSD1D(ele_pos, pots, src_type='gauss', sigma=0.3, h=0.25,
gdx=0.01, n_src_init=n_src[i], ext_x=0, xmin=0, xmax=1)
if method == 'cross-validation':
obj.cross_validate(Rs=Rs, lambdas=lambdas)
elif method == 'L-curve':
obj.L_curve(Rs=Rs, lambdas=lambdas)
ss = SpectralStructure(obj)
eigenvectors[i], eigenvalues[i] = ss.evd()

gdx=0.01, n_src_init=n_src[i], ext_x=0, xmin=0, xmax=1,
R_init=R_init)
try:
eigenvalue, eigenvector = np.linalg.eigh(obj.k_pot +
obj.lambd *
np.identity
(obj.k_pot.shape[0]))
except LinAlgError:
raise LinAlgError('EVD is failing - try moving the electrodes'
'slightly')
idx = eigenvalue.argsort()[::-1]
eigenvalues[i] = eigenvalue[idx]
eigenvectors[i] = eigenvector[:, idx]
obj_all.append(obj)
return obj_all, eigenvalues, eigenvectors

Expand Down Expand Up @@ -112,8 +108,7 @@ def set_axis(ax, x, y, letter=None):


def generate_figure(csd_profile, R, MU, true_csd_xlims, total_ele, ele_lims,
save_path, method='cross-validation', Rs=None,
lambdas=None, noise=0):
save_path, noise=0, R_init=0.23):
"""
Generates figure for spectral structure decomposition.
Expand All @@ -135,18 +130,12 @@ def generate_figure(csd_profile, R, MU, true_csd_xlims, total_ele, ele_lims,
Electrodes limits.
save_path: string
Directory.
method: string
Determines the method of regularization.
Default: cross-validation.
Rs: numpy 1D array
Basis source parameter for crossvalidation.
Default: None.
lambdas: numpy 1D array
Regularization parameter for crossvalidation.
Default: None.
noise: float
Determines the level of noise in the data.
Default: 0.
R_init: float
Initial value of R parameter - width of basis source
Default: 0.23.
Returns
-------
Expand All @@ -162,16 +151,14 @@ def generate_figure(csd_profile, R, MU, true_csd_xlims, total_ele, ele_lims,
OBJ_M, eigenval_M, eigenvec_M = stability_M(csd_profile, n_src_M,
ele_lims, true_csd_xlims,
total_ele, ele_pos, pots,
method=method, Rs=Rs,
lambdas=lambdas)
R_init=R_init)

plt_cord = [(2, 0), (2, 2), (2, 4),
(3, 0), (3, 2), (3, 4),
(4, 0), (4, 2), (4, 4),
(5, 0), (5, 2), (5, 4)]


letters = ['C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'O']
letters = ['C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N']

BLACK = _html(0, 0, 0)
ORANGE = _html(230, 159, 0)
Expand Down Expand Up @@ -199,10 +186,8 @@ def generate_figure(csd_profile, R, MU, true_csd_xlims, total_ele, ele_lims,
linestyle=linestyles[indx], color=colors[indx],
marker=markers[indx], label='M='+str(n_src_M[i]),
markersize=10)
# ax.set_title(' ', fontsize=12)
ht, lh = ax.get_legend_handles_labels()
set_axis(ax, -0.05, 1.05, letter='A')
# ax.legend(loc='lower left')
ax.set_xlabel('Number of components')
ax.set_ylabel('Eigenvalues')
ax.set_yscale('log')
Expand All @@ -213,7 +198,6 @@ def generate_figure(csd_profile, R, MU, true_csd_xlims, total_ele, ele_lims,
ax = fig.add_subplot(gs[0, 3:])
ax.plot(n_src_M, eigenval_M[:, 0], marker='s', color='k', markersize=5,
linestyle=' ')
#ax.set_title(' ', fontsize=12)
set_axis(ax, -0.05, 1.05, letter='B')
ax.set_xlabel('Number of basis sources')
ax.set_xscale('log')
Expand All @@ -229,13 +213,9 @@ def generate_figure(csd_profile, R, MU, true_csd_xlims, total_ele, ele_lims,
eigenvec_M[j, :, i]),
linestyle=linestyles[idx], color=colors[idx],
label='M='+str(n_src_M[j]), lw=2)
#ax.set_title(r"$\tilde{K}*v_{{%(i)d}}$" % {'i': i+1})
ax.text(0.5, 1., r"$\tilde{K}*v_{{%(i)d}}$" % {'i': i+1},
horizontalalignment='center', transform=ax.transAxes, fontsize=20)
# ax.locator_params(axis='y', nbins=3)

# ax.set_xlabel('Depth (mm)', fontsize=12)
# ax.set_ylabel('CSD (mA/mm)', fontsize=12)
ax.text(0.5, 1., r"$\tilde{K}\cdot{v_{{%(i)d}}}$" % {'i': i+1},
horizontalalignment='center', transform=ax.transAxes,
fontsize=20)
set_axis(ax, -0.10, 1.1, letter=letters[i])
if i < 9:
ax.get_xaxis().set_visible(False)
Expand All @@ -245,23 +225,12 @@ def generate_figure(csd_profile, R, MU, true_csd_xlims, total_ele, ele_lims,
if i % 3 == 0:
ax.set_ylabel('CSD ($mA/mm$)')
ax.yaxis.set_label_coords(-0.18, 0.5)
# ax.yaxis.get_major_formatter().set_powerlimits((0, 1))
# ax.tick_params(direction='out', pad=10)
# ax.yaxis.get_major_formatter(FormatStrFormatter('%.2f'))
ax.ticklabel_format(style='sci', axis='y', scilimits=((0.0, 0.0)))
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
# ht, lh = ax.get_legend_handles_labels()

# ax = fig.add_subplot(gs[3, :])
# ax.legend(ht, lh, fancybox=False, shadow=False, ncol=len(src_idx),
# loc='upper center', frameon=False, bbox_to_anchor=(0.5, 0.0))
# ax.axis('off')

# plt.tight_layout()
fig.legend(ht, lh, loc='lower center', ncol=5, frameon=False)
fig.savefig(os.path.join(save_path, 'vectors_' + method +
'_noise_' + str(noise) + '.png'), dpi=300)
fig.savefig(os.path.join(save_path, 'vectors_' + '_noise_' +
str(noise) + 'R0_2' + '.png'), dpi=300)

plt.show()

Expand All @@ -281,6 +250,6 @@ def generate_figure(csd_profile, R, MU, true_csd_xlims, total_ele, ele_lims,
CSD_PROFILE = tb.csd_profile
R = 0.2
MU = 0.25
R_init = 0.2
generate_figure(CSD_PROFILE, R, MU, TRUE_CSD_XLIMS, TOTAL_ELE, ELE_LIMS,
SAVE_PATH, method='cross-validation',
Rs=np.arange(0.1, 0.5, 0.05), noise=None)
SAVE_PATH, noise=None, R_init=R_init)

0 comments on commit d62a8a0

Please sign in to comment.