In [None]:
import h5py
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import os
import seaborn as sns

%matplotlib inline

In [None]:
# Route all figure text through LaTeX with a serif face. The preferred face is
# listed under "font.serif" because that is the list matplotlib consults when
# font.family is "serif"; a "font.sans-serif" entry is not used by this family.
plt.rcParams.update({
    "text.usetex": True,
    "font.family": "serif",
    "font.serif": ["Computer Modern Roman"]})

In [None]:
# Directory holding the fitted-results HDF5 files. Defaults to the original
# storage location; set NEUROCORR_FITS_PATH to read the files from elsewhere.
base_path = os.environ.get('NEUROCORR_FITS_PATH', '/storage/fits/neurocorr')

# Predicting Vowels (3 choices)

In [None]:
# Close any handle left bound to `results` (e.g. from re-running this cell)
# before rebinding the name, so the earlier HDF5 file is not left open.
_previous_results = globals().get('results')
if isinstance(_previous_results, h5py.File):
    _previous_results.close()
del _previous_results

# Vowel (3-way) results. The handle is kept open read-only because later cells
# read the v_s_* / v_r_* datasets directly from `results`.
results_path = os.path.join(base_path, "exp6_v_values_cv_25_1000_1000.h5")
results = h5py.File(results_path, 'r')

In [None]:
# Read the datasets the plotting cells need fully into memory as NumPy arrays,
# in the order v_sdkl, v_lfi, units, stims.
v_sdkl, v_lfi, units, stims = (
    results[dataset_name][:]
    for dataset_name in ('v_sdkl', 'v_lfi', 'units', 'stims')
)

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(6, 5))

# One x-position per dimlet dimension, starting at 2.
# NOTE(review): assumes axis 2 of `units` is the maximum number of units and
# that v_sdkl has one row per dimension along axis 0 — confirm against the file.
n_max_units = units.shape[2]
dims = 2 + np.arange(n_max_units - 1)
# Median (not mean) sDKL over axis 1; the shaded band below is the
# 25th-75th percentile range over the same axis.
sdkl_median = np.median(v_sdkl, axis=1)

ax.fill_between(
    x=dims,
    y1=np.percentile(v_sdkl, q=25, axis=1),
    y2=np.percentile(v_sdkl, q=75, axis=1),
    color='black',
    alpha=0.1)
ax.plot(
    dims,
    sdkl_median,
    linewidth=4,
    color='black')

ax.set_xlim([2, 25])
ax.set_ylim([0, 3.5])
ax.tick_params(labelsize=15)
ax.set_xlabel(r'\textbf{Dimlet Dimension}', fontsize=15)
ax.set_ylabel(r'\textbf{sDKL}', fontsize=15)
plt.show()

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(6, 5))

# One x-position per dimlet dimension, starting at 2.
# NOTE(review): assumes axis 2 of `units` is the maximum number of units and
# that v_lfi has one row per dimension along axis 0 — confirm against the file.
n_max_units = units.shape[2]
dims = 2 + np.arange(n_max_units - 1)
# Median linear Fisher information over axis 1; the shaded band below is the
# 25th-75th percentile range over the same axis.
lfi_median = np.median(v_lfi, axis=1)

ax.fill_between(
    x=dims,
    y1=np.percentile(v_lfi, q=25, axis=1),
    y2=np.percentile(v_lfi, q=75, axis=1),
    color='black',
    alpha=0.1)
ax.plot(
    dims,
    lfi_median,
    linewidth=4,
    color='black')

ax.set_xlim([2, 25])
ax.set_ylim([0, 0.5])
ax.tick_params(labelsize=15)
ax.set_xlabel(r'\textbf{Dimlet Dimension}', fontsize=15)
ax.set_ylabel(r'\textbf{Linear Fisher Information}', fontsize=15)
plt.show()

In [None]:
# Percentile of each observed statistic within its null distribution: the
# fraction of null draws (reduced over the last axis) that the observed value
# strictly exceeds. p_s_* use the v_s_* datasets (plotted as 'Shuffle') and
# p_r_* use the v_r_* datasets (plotted as 'Rotation').
# NOTE(review): assumes each null dataset has the observed array's shape plus
# a trailing draws axis — confirm against the file layout.
(p_s_val_lfi, p_s_val_sdkl,
 p_r_val_lfi, p_r_val_sdkl) = [
    (observed[..., None] > results[null_key][:]).mean(axis=-1)
    for observed, null_key in (
        (v_lfi, 'v_s_lfi'),
        (v_sdkl, 'v_s_sdkl'),
        (v_lfi, 'v_r_lfi'),
        (v_sdkl, 'v_r_sdkl'),
    )
]

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(8, 6))

# One x-position per dimlet dimension, starting at 2.
# NOTE(review): assumes axis 2 of `units` is the maximum number of units.
n_max_units = units.shape[2]
dims = 2 + np.arange(n_max_units - 1)

# For each null type: median percentile of the observed sDKL over axis 1,
# plus a narrow 40th-60th percentile band over the same axis.
for null_name, line_color, percentiles in (
        ('Shuffle', 'black', p_s_val_sdkl),
        ('Rotation', 'red', p_r_val_sdkl)):
    ax.plot(
        dims, np.median(percentiles, axis=1),
        linewidth=3,
        color=line_color,
        label=null_name)
    ax.fill_between(
        x=dims,
        y1=np.percentile(percentiles, q=40, axis=1),
        y2=np.percentile(percentiles, q=60, axis=1),
        color=line_color,
        alpha=0.1)

ax.set_xlim([2, 25])
ax.set_xlabel(r'\textbf{Dimlet Dimension}', fontsize=15)
# Percentiles lie in [0, 1]; pad slightly so curves at the bounds stay visible.
ax.set_ylim([-0.05, 1.05])
ax.set_yticks([0, 0.25, 0.50, 0.75, 1.0])
ax.tick_params(labelsize=15)

ax.legend(loc='best', prop={'size': 18})
ax.set_ylabel(r'\textbf{sDKL Percentile}', fontsize=15)
plt.show()

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(8, 6))

# One x-position per dimlet dimension, starting at 2.
# NOTE(review): assumes axis 2 of `units` is the maximum number of units.
n_max_units = units.shape[2]
dims = 2 + np.arange(n_max_units - 1)

# For each null type: median percentile of the observed LFI over axis 1,
# plus a narrow 40th-60th percentile band over the same axis.
for null_name, line_color, percentiles in (
        ('Shuffle', 'black', p_s_val_lfi),
        ('Rotation', 'red', p_r_val_lfi)):
    ax.plot(
        dims, np.median(percentiles, axis=1),
        linewidth=3,
        color=line_color,
        label=null_name)
    ax.fill_between(
        x=dims,
        y1=np.percentile(percentiles, q=40, axis=1),
        y2=np.percentile(percentiles, q=60, axis=1),
        color=line_color,
        alpha=0.1)

ax.set_xlim([2, 25])
ax.set_xlabel(r'\textbf{Dimlet Dimension}', fontsize=15)
# Percentiles lie in [0, 1]; pad slightly so curves at the bounds stay visible.
ax.set_ylim([-0.05, 1.05])
ax.set_yticks([0, 0.25, 0.50, 0.75, 1.0])
ax.tick_params(labelsize=15)

ax.legend(loc='best', prop={'size': 18})
ax.set_ylabel(r'\textbf{LFI Percentile}', fontsize=15)
plt.show()

# Predicting Consonants (16 choices)

In [None]:
# Close the handle still bound to `results` (the vowel file, or this file on a
# re-run) before rebinding the name, so the earlier HDF5 file is not left open.
_previous_results = globals().get('results')
if isinstance(_previous_results, h5py.File):
    _previous_results.close()
del _previous_results

# Consonant (16-way) results. The handle is kept open read-only because later
# cells read the v_s_* / v_r_* datasets directly from `results`.
results_path = os.path.join(base_path, "exp6_c_values_cv_25_1000_1000.h5")
results = h5py.File(results_path, 'r')

In [None]:
# Read the datasets the plotting cells need fully into memory as NumPy arrays,
# in the order v_sdkl, v_lfi, units, stims.
v_sdkl, v_lfi, units, stims = (
    results[dataset_name][:]
    for dataset_name in ('v_sdkl', 'v_lfi', 'units', 'stims')
)

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(6, 5))

# One x-position per dimlet dimension, starting at 2.
# NOTE(review): assumes axis 2 of `units` is the maximum number of units and
# that v_sdkl has one row per dimension along axis 0 — confirm against the file.
n_max_units = units.shape[2]
dims = 2 + np.arange(n_max_units - 1)
# Median (not mean) sDKL over axis 1; the shaded band below is the
# 25th-75th percentile range over the same axis.
sdkl_median = np.median(v_sdkl, axis=1)

ax.fill_between(
    x=dims,
    y1=np.percentile(v_sdkl, q=25, axis=1),
    y2=np.percentile(v_sdkl, q=75, axis=1),
    color='black',
    alpha=0.1)
ax.plot(
    dims,
    sdkl_median,
    linewidth=4,
    color='black')

ax.set_xlim([2, 25])
ax.set_ylim([0, 3.5])
ax.tick_params(labelsize=15)
ax.set_xlabel(r'\textbf{Dimlet Dimension}', fontsize=15)
ax.set_ylabel(r'\textbf{sDKL}', fontsize=15)
plt.show()

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(6, 5))

# One x-position per dimlet dimension, starting at 2.
# NOTE(review): assumes axis 2 of `units` is the maximum number of units and
# that v_lfi has one row per dimension along axis 0 — confirm against the file.
n_max_units = units.shape[2]
dims = 2 + np.arange(n_max_units - 1)
# Median linear Fisher information over axis 1; the shaded band below is the
# 25th-75th percentile range over the same axis.
lfi_median = np.median(v_lfi, axis=1)

ax.fill_between(
    x=dims,
    y1=np.percentile(v_lfi, q=25, axis=1),
    y2=np.percentile(v_lfi, q=75, axis=1),
    color='black',
    alpha=0.1)
ax.plot(
    dims,
    lfi_median,
    linewidth=4,
    color='black')

ax.set_xlim([2, 25])
ax.set_ylim([0, 0.1])
ax.tick_params(labelsize=15)
ax.set_xlabel(r'\textbf{Dimlet Dimension}', fontsize=15)
ax.set_ylabel(r'\textbf{Linear Fisher Information}', fontsize=15)
plt.show()

In [None]:
# Percentile of each observed statistic within its null distribution: the
# fraction of null draws (reduced over the last axis) that the observed value
# strictly exceeds. p_s_* use the v_s_* datasets (plotted as 'Shuffle') and
# p_r_* use the v_r_* datasets (plotted as 'Rotation').
# NOTE(review): assumes each null dataset has the observed array's shape plus
# a trailing draws axis — confirm against the file layout.
(p_s_val_lfi, p_s_val_sdkl,
 p_r_val_lfi, p_r_val_sdkl) = [
    (observed[..., None] > results[null_key][:]).mean(axis=-1)
    for observed, null_key in (
        (v_lfi, 'v_s_lfi'),
        (v_sdkl, 'v_s_sdkl'),
        (v_lfi, 'v_r_lfi'),
        (v_sdkl, 'v_r_sdkl'),
    )
]

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(8, 6))

# One x-position per dimlet dimension, starting at 2.
# NOTE(review): assumes axis 2 of `units` is the maximum number of units.
n_max_units = units.shape[2]
dims = 2 + np.arange(n_max_units - 1)

# For each null type: median percentile of the observed sDKL over axis 1,
# plus a narrow 40th-60th percentile band over the same axis.
for null_name, line_color, percentiles in (
        ('Shuffle', 'black', p_s_val_sdkl),
        ('Rotation', 'red', p_r_val_sdkl)):
    ax.plot(
        dims, np.median(percentiles, axis=1),
        linewidth=3,
        color=line_color,
        label=null_name)
    ax.fill_between(
        x=dims,
        y1=np.percentile(percentiles, q=40, axis=1),
        y2=np.percentile(percentiles, q=60, axis=1),
        color=line_color,
        alpha=0.1)

ax.set_xlim([2, 25])
ax.set_xlabel(r'\textbf{Dimlet Dimension}', fontsize=15)
# Percentiles lie in [0, 1]; pad slightly so curves at the bounds stay visible.
ax.set_ylim([-0.05, 1.05])
ax.set_yticks([0, 0.25, 0.50, 0.75, 1.0])
ax.tick_params(labelsize=15)

ax.legend(loc='best', prop={'size': 18})
ax.set_ylabel(r'\textbf{sDKL Percentile}', fontsize=15)
plt.show()

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(8, 6))

# One x-position per dimlet dimension, starting at 2.
# NOTE(review): assumes axis 2 of `units` is the maximum number of units.
n_max_units = units.shape[2]
dims = 2 + np.arange(n_max_units - 1)

# For each null type: median percentile of the observed LFI over axis 1,
# plus a narrow 40th-60th percentile band over the same axis.
for null_name, line_color, percentiles in (
        ('Shuffle', 'black', p_s_val_lfi),
        ('Rotation', 'red', p_r_val_lfi)):
    ax.plot(
        dims, np.median(percentiles, axis=1),
        linewidth=3,
        color=line_color,
        label=null_name)
    ax.fill_between(
        x=dims,
        y1=np.percentile(percentiles, q=40, axis=1),
        y2=np.percentile(percentiles, q=60, axis=1),
        color=line_color,
        alpha=0.1)

ax.set_xlim([2, 25])
ax.set_xlabel(r'\textbf{Dimlet Dimension}', fontsize=15)
# Percentiles lie in [0, 1]; pad slightly so curves at the bounds stay visible.
ax.set_ylim([-0.05, 1.05])
ax.set_yticks([0, 0.25, 0.50, 0.75, 1.0])
ax.tick_params(labelsize=15)

ax.legend(loc='best', prop={'size': 18})
ax.set_ylabel(r'\textbf{LFI Percentile}', fontsize=15)
plt.show()