Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions figures/kCSD_properties/different_error_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ def grid(x, y, z):
x = x.flatten()
y = y.flatten()
z = z.flatten()
xi, yi = np.mgrid[min(x):max(x):np.complex(0, 100),
min(y):max(y):np.complex(0, 100)]
xi, yi = np.mgrid[min(x):max(x):complex(0, 100),
min(y):max(y):complex(0, 100)]
zi = griddata((x, y), z, (xi, yi), method='linear')
return xi, yi, zi

Expand Down Expand Up @@ -94,8 +94,8 @@ def do_kcsd(CSD_PROFILE, data, csd_seed, prefix, missing_ele):
rstate = np.random.RandomState(42) # just a random seed
rmv = rstate.choice(ele_pos.shape[0], remove_num, replace=False)
ele_pos = np.delete(ele_pos, rmv, 0)
# Potentials generated

# Potentials generated
pots = np.zeros(ele_pos.shape[0])
pots = data['pots']
h = 50.
Expand Down Expand Up @@ -132,7 +132,7 @@ def do_kcsd(CSD_PROFILE, data, csd_seed, prefix, missing_ele):
v_max = np.max(np.abs(pots))
levels_pot = np.linspace(-1 * v_max, v_max, 16)
im = ax.contourf(pot_X, pot_Y, pot_Z,
levels=levels_pot, cmap=cm.PRGn)
levels=levels_pot, cmap=cm.PRGn)
ax.scatter(ele_pos[:, 0], ele_pos[:, 1], 10, c='k')
ax.set_xlim([0, 1])
ax.set_ylim([0, 1])
Expand All @@ -148,7 +148,7 @@ def do_kcsd(CSD_PROFILE, data, csd_seed, prefix, missing_ele):
t_max = np.max(np.abs(est_csd_pre_cv[:, :, 0]))
levels_kcsd = np.linspace(-1 * t_max, t_max, 16, endpoint=True)
im = ax.contourf(k.estm_x, k.estm_y, est_csd_pre_cv[:, :, 0],
levels=levels_kcsd, cmap=cm.bwr)
levels=levels_kcsd, cmap=cm.bwr)
ax.set_xlim([0, 1])
ax.set_ylim([0, 1])
ax.set_xticks([0, 0.5, 1])
Expand All @@ -163,7 +163,7 @@ def do_kcsd(CSD_PROFILE, data, csd_seed, prefix, missing_ele):
t_max = np.max(np.abs(est_csd_post_cv[:, :, 0]))
levels_kcsd = np.linspace(-1 * t_max, t_max, 16, endpoint=True)
im = ax.contourf(k.estm_x, k.estm_y, est_csd_post_cv[:, :, 0],
levels=levels_kcsd, cmap=cm.bwr)
levels=levels_kcsd, cmap=cm.bwr)
ax.set_xlim([0, 1])
ax.set_ylim([0, 1])
ax.set_xticks([0, 0.5, 1])
Expand All @@ -180,7 +180,7 @@ def do_kcsd(CSD_PROFILE, data, csd_seed, prefix, missing_ele):
t_max = np.max(abs(error1))
levels_kcsd = np.linspace(0, t_max, 16, endpoint=True)
im = ax.contourf(k.estm_x, k.estm_y, error1,
levels=levels_kcsd, cmap=cm.Greys)
levels=levels_kcsd, cmap=cm.Greys)
ax.set_xlim([0, 1])
ax.set_ylim([0, 1])
ax.set_xticks([0, 0.5, 1])
Expand All @@ -197,7 +197,7 @@ def do_kcsd(CSD_PROFILE, data, csd_seed, prefix, missing_ele):
t_max = np.max(abs(error2))
levels_kcsd = np.linspace(0, t_max, 16, endpoint=True)
im = ax.contourf(k.estm_x, k.estm_y, error2,
levels=levels_kcsd, cmap=cm.Greys)
levels=levels_kcsd, cmap=cm.Greys)
ax.set_xlim([0, 1])
ax.set_ylim([0, 1])
ax.set_xticks([0, 0.5, 1])
Expand All @@ -206,15 +206,15 @@ def do_kcsd(CSD_PROFILE, data, csd_seed, prefix, missing_ele):
ax.set_title('Normalized difference')
ticks = np.linspace(0, t_max, 3, endpoint=True)
plt.colorbar(im, orientation='horizontal', format='%.2f', ticks=ticks, pad=0.25)

ax = plt.subplot(247)
error3 = calculate_rdm(true_csd, est_csd_post_cv[:, :, 0])
print(error3.shape)
ax.set_aspect('equal')
t_max = np.max(abs(error3))
levels_kcsd = np.linspace(0, t_max, 16, endpoint=True)
im = ax.contourf(k.estm_x, k.estm_y, error3,
levels=levels_kcsd, cmap=cm.Greys)
levels=levels_kcsd, cmap=cm.Greys)
ax.set_xlim([0, 1])
ax.set_ylim([0, 1])
ax.set_xticks([0, 0.5, 1])
Expand All @@ -231,7 +231,7 @@ def do_kcsd(CSD_PROFILE, data, csd_seed, prefix, missing_ele):
t_max = np.max(abs(error4))
levels_kcsd = np.linspace(0, t_max, 16, endpoint=True)
im = ax.contourf(k.estm_x, k.estm_y, error4,
levels=levels_kcsd, cmap=cm.Greys)
levels=levels_kcsd, cmap=cm.Greys)
ax.set_xlim([0, 1])
ax.set_ylim([0, 1])
ax.set_xticks([0, 0.5, 1])
Expand All @@ -249,7 +249,7 @@ def do_kcsd(CSD_PROFILE, data, csd_seed, prefix, missing_ele):

if __name__ == '__main__':
CSD_PROFILE = CSD.gauss_2d_large #CSD.gauss_2d_small #

prefix = '/home/mkowalska/Marta/kCSD-python/figures/kCSD_properties/small_srcs_all_ele'
for csd_seed in range(100):
data = np.load(prefix + '/' + str(csd_seed) + '.npz')
Expand Down
212 changes: 135 additions & 77 deletions figures/kCSD_properties/kCSD_with_reliability_map_2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,22 @@ def set_axis(ax, letter=None):
-------
ax: modyfied Axes object.
"""
ax.text(
-0.05,
1.05,
letter,
fontsize=20,
weight='bold',
transform=ax.transAxes)
ax.text(-0.05, 1.05, letter, fontsize=20, weight="bold", transform=ax.transAxes)
return ax


def make_reconstruction(KK, csd_profile, csd_seed, total_ele,
ele_lims=None, noise=0, nr_broken_ele=None,
Rs=None, lambdas=None, method='cross-validation'):
def make_reconstruction(
KK,
csd_profile,
csd_seed,
total_ele,
ele_lims=None,
noise=0,
nr_broken_ele=None,
Rs=None,
lambdas=None,
method="cross-validation",
):
"""
Main method, makes the whole kCSD reconstruction.

Expand Down Expand Up @@ -96,58 +99,66 @@ def make_reconstruction(KK, csd_profile, csd_seed, total_ele,
Potentials measured (calculated) on electrodes.
"""
csd_at, true_csd = KK.generate_csd(csd_profile, csd_seed)
ele_pos, pots = KK.electrode_config(csd_profile, csd_seed, total_ele,
ele_lims, KK.h, KK.sigma,
noise, nr_broken_ele)
k, est_csd = KK.do_kcsd(pots, ele_pos, method=method, Rs=Rs,
lambdas=lambdas)
ele_pos, pots = KK.electrode_config(
csd_profile, csd_seed, total_ele, ele_lims, KK.h, KK.sigma, noise, nr_broken_ele
)
k, est_csd = KK.do_kcsd(pots, ele_pos, method=method, Rs=Rs, lambdas=lambdas)
return k, csd_at, true_csd, ele_pos, pots


def make_subplot(ax, val_type, xs, ys, values, cax, title=None, ele_pos=None,
xlabel=False, ylabel=False, letter='', t_max=None,
mask=False, level=False):
if val_type == 'csd':
def make_subplot(
ax,
val_type,
xs,
ys,
values,
cax,
title=None,
ele_pos=None,
xlabel=False,
ylabel=False,
letter="",
t_max=None,
mask=False,
level=False,
):
if val_type == "csd":
cmap = cm.bwr
elif val_type == 'pot':
elif val_type == "pot":
cmap = cm.PRGn
else:
cmap = cm.Greys
ax.set_aspect('equal')
ax.set_aspect("equal")
if t_max is None:
t_max = np.max(np.abs(values))
if level is not False:
levels = level
else:
levels = np.linspace(-t_max, t_max, 32)
if val_type == 'pot':
if val_type == "pot":
X, Y, Z = grid(ele_pos[:, 0], ele_pos[:, 1], values)
im = ax.contourf(X, Y, Z, levels=levels, cmap=cmap, alpha=1)
else:
im = ax.contourf(xs, ys, values,
levels=levels, cmap=cmap, alpha=1,
extent=(0, 0.5, 0, 0.5))
im = ax.contourf(
xs, ys, values, levels=levels, cmap=cmap, alpha=1, extent=(0, 0.5, 0, 0.5)
)
if mask is not False:
CS = ax.contour(xs, ys, mask, cmap='Greys')
ax.clabel(CS, # label every second level
inline=1,
fmt='%1.2f',
fontsize=9)
if val_type == 'pot':
ax.scatter(ele_pos[:, 0], ele_pos[:, 1], 10, c='k')
CS = ax.contour(xs, ys, mask, cmap="Greys")
ax.clabel(CS, inline=1, fmt="%1.2f", fontsize=9) # label every second level
if val_type == "pot":
ax.scatter(ele_pos[:, 0], ele_pos[:, 1], 10, c="k")
ax.set_xlim([0, 1])
ax.set_ylim([0, 1])
if xlabel:
ax.set_xlabel('X (mm)')
ax.set_xlabel("X (mm)")
if ylabel:
ax.set_ylabel('Y (mm)')
ax.set_ylabel("Y (mm)")
if title is not None:
ax.set_title(title)
ax.set_xticks([0, 0.5, 1])
ax.set_yticks([0, 0.5, 1])
ticks = np.linspace(-t_max, t_max, 3, endpoint=True)
plt.colorbar(im, cax=cax, orientation='horizontal', format='%.2f',
ticks=ticks)
plt.colorbar(im, cax=cax, orientation="horizontal", format="%.2f", ticks=ticks)
set_axis(ax, letter=letter)
return ax, cax

Expand All @@ -171,44 +182,83 @@ def grid(x, y, z, resX=100, resY=100):
x = x.flatten()
y = y.flatten()
z = z.flatten()
xi, yi = np.mgrid[min(x):max(x):np.complex(0, resX),
min(y):max(y):np.complex(0, resY)]
zi = griddata((x, y), z, (xi, yi), method='linear')
xi, yi = np.mgrid[
min(x) : max(x) : complex(0, resX), min(y) : max(y) : complex(0, resY)
]
zi = griddata((x, y), z, (xi, yi), method="linear")
return xi, yi, zi


def generate_figure(k, true_csd, ele_pos, pots, mask=False):
csd_at = np.mgrid[0.:1.:100j,
0.:1.:100j]
csd_at = np.mgrid[0.0:1.0:100j, 0.0:1.0:100j]
csd_x, csd_y = csd_at
plt.figure(figsize=(17, 6))
gs = gridspec.GridSpec(2, 4, height_ratios=[1., 0.04], width_ratios=[1]*4)
# gs.update(top=.95, bottom=0.53)
gs = gridspec.GridSpec(2, 4, height_ratios=[1.0, 0.04], width_ratios=[1] * 4)
# gs.update(top=.95, bottom=0.53)
ax = plt.subplot(gs[0, 0])
cax = plt.subplot(gs[1, 0])
make_subplot(ax, 'csd', csd_x, csd_y, true_csd, cax=cax, ele_pos=ele_pos,
title='True CSD', xlabel=True, ylabel=True, letter='A',
t_max=np.max(abs(true_csd)))
make_subplot(
ax,
"csd",
csd_x,
csd_y,
true_csd,
cax=cax,
ele_pos=ele_pos,
title="True CSD",
xlabel=True,
ylabel=True,
letter="A",
t_max=np.max(abs(true_csd)),
)
ax = plt.subplot(gs[0, 1])
cax = plt.subplot(gs[1, 1])
make_subplot(ax, 'pot', ele_pos[:, 0], ele_pos[:, 1], pots, cax=cax,
ele_pos=ele_pos, title='Potentials', xlabel=True, letter='B',
t_max=np.max(abs(pots)))
make_subplot(
ax,
"pot",
ele_pos[:, 0],
ele_pos[:, 1],
pots,
cax=cax,
ele_pos=ele_pos,
title="Potentials",
xlabel=True,
letter="B",
t_max=np.max(abs(pots)),
)
ax = plt.subplot(gs[0, 2])
cax = plt.subplot(gs[1, 2])
make_subplot(ax, 'csd', k.estm_x, k.estm_y, k.values('CSD')[:, :, 0],
cax=cax, ele_pos=ele_pos, title='kCSD with Reliability Map',
xlabel=True, letter='C', t_max=np.max(abs(true_csd)),
mask=mask)
make_subplot(
ax,
"csd",
k.estm_x,
k.estm_y,
k.values("CSD")[:, :, 0],
cax=cax,
ele_pos=ele_pos,
title="kCSD with Reliability Map",
xlabel=True,
letter="C",
t_max=np.max(abs(true_csd)),
mask=mask,
)
ax = plt.subplot(gs[0, 3])
cax = plt.subplot(gs[1, 3])
make_subplot(ax, 'diff', csd_x, csd_y,
abs(true_csd-k.values('CSD')[:, :, 0]), cax=cax,
ele_pos=ele_pos, title='|True CSD - kCSD|', xlabel=True,
letter='D',
t_max=np.max(abs(true_csd-k.values('CSD')[:, :, 0])),
level=np.linspace(0, np.max(abs(true_csd-k.values('CSD')[:, :, 0])), 16))
plt.savefig('kCSD_with_reliability_map_2D.png', dpi=300)
make_subplot(
ax,
"diff",
csd_x,
csd_y,
abs(true_csd - k.values("CSD")[:, :, 0]),
cax=cax,
ele_pos=ele_pos,
title="|True CSD - kCSD|",
xlabel=True,
letter="D",
t_max=np.max(abs(true_csd - k.values("CSD")[:, :, 0])),
level=np.linspace(0, np.max(abs(true_csd - k.values("CSD")[:, :, 0])), 16),
)
plt.savefig("kCSD_with_reliability_map_2D.png", dpi=300)
plt.show()


Expand All @@ -222,32 +272,40 @@ def matrix_symmetrization(point_error):
r11 = np.rot90(arr_lr, k=1, axes=(1, 2))
r12 = np.rot90(arr_lr, k=2, axes=(1, 2))
r13 = np.rot90(arr_lr, k=3, axes=(1, 2))
symm_array = np.concatenate((point_error, r1, r2, r3, arr_lr, r11, r12,
r13))
symm_array = np.concatenate((point_error, r1, r2, r3, arr_lr, r11, r12, r13))
return symm_array


if __name__ == '__main__':
if __name__ == "__main__":
CSD_PROFILE = CSD.gauss_2d_large
CSD_SEED = 16
ELE_LIMS = [0.05, 0.95] # range of electrodes space
method = 'cross-validation'
method = "cross-validation"
Rs = np.arange(0.2, 0.5, 0.1)
lambdas = None
noise = 0

KK = ValidateKCSD2D(CSD_SEED, h=50., sigma=1., n_src_init=400,
est_xres=0.01, est_yres=0.01, ele_lims=ELE_LIMS)
k, csd_at, true_csd, ele_pos, pots = make_reconstruction(KK, CSD_PROFILE,
CSD_SEED,
total_ele=100,
noise=noise,
Rs=Rs,
lambdas=lambdas,
method=method)
error_l = np.load('error_maps_2D/point_error_large_100_all_ele.npy')
error_s = np.load('error_maps_2D/point_error_small_100_all_ele.npy')
KK = ValidateKCSD2D(
CSD_SEED,
h=50.0,
sigma=1.0,
n_src_init=400,
est_xres=0.01,
est_yres=0.01,
ele_lims=ELE_LIMS,
)
k, csd_at, true_csd, ele_pos, pots = make_reconstruction(
KK,
CSD_PROFILE,
CSD_SEED,
total_ele=100,
noise=noise,
Rs=Rs,
lambdas=lambdas,
method=method,
)
error_l = np.load("error_maps_2D/point_error_large_100_all_ele.npy")
error_s = np.load("error_maps_2D/point_error_small_100_all_ele.npy")
error_all = np.concatenate((error_l, error_s))
symm_array_all = matrix_symmetrization(error_all)
generate_figure(k, true_csd, ele_pos, pots,
mask=np.mean(symm_array_all, axis=0))
generate_figure(k, true_csd, ele_pos, pots, mask=np.mean(symm_array_all, axis=0))
Loading