In [1]:
# Preliminaries 
from mdtb_neocortical import *
from sklearn.decomposition import DictionaryLearning
import ipywidgets as widgets       # interactive display
%config InlineBackend.figure_format = 'svg' # other available formats are: 'retina', 'png', 'jpeg', 'pdf'


## Quick toy example to test the algorithm
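Toy example: $Y = UV^\top + \varepsilon$, with $V \sim \mathcal{N}(0,1)$ (columns centred and scaled to unit norm), $U \sim 0.1\cdot\mathrm{Gamma}(1,\beta)$ with $\beta = 1$, and Gaussian noise $\varepsilon$ with per-entry SD $\mathrm{eps}/\sqrt{N}$.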

In [3]:
# Dictionary learning toy example: V ~ Normal (column-centred, unit-norm),
# U ~ 0.1 * Gamma(1, beta), Y = U V^T + noise
K = 5      # number of latent components
N = 20     # number of features (columns of Y)
P = 100    # number of samples (rows of Y)
eps = 10   # noise scale; per-entry noise SD is eps / sqrt(N)
beta = 1   # scale parameter of the Gamma distribution for U

# V: (N, K) dictionary with zero-mean, unit-norm columns
V = np.random.normal(0,1,(N,K))
V = V - V.mean(axis=0)
V = V / np.sqrt(np.sum(V**2,axis=0))

# U: (P, K) non-negative codes
U = np.random.gamma(1,beta,(P,K))*0.1

# Y: (P, N) data. It must contain the planted signal U V^T; a noise-only Y
# gives the dictionary learner no structure to recover.
noise = np.random.normal(0,eps/np.sqrt(N),(P,N))
Y = U @ V.T + noise

In [None]:
# Fit the same model `num` times with unseeded initialisation to check whether
# the recovered dictionary is stable across restarts.
num=10
Uhat = np.empty((num,P,K))   # codes per run: (num, P, K)
Vhat = np.empty((num,N,K))   # dictionaries per run: (num, N, K)
for i in range(num):
    # n_components must equal K so each result fits the preallocated arrays
    snn = DictionaryLearning(n_components=K, transform_algorithm='lasso_cd',
                             random_state=None, positive_code=True, fit_algorithm='cd')
    snn.fit(Y)
    Uhat[i] = snn.transform(Y)
    Vhat[i] = snn.components_.T

In [None]:
# Codes (left) and dictionary (right) for two of the runs, one run per row,
# to eyeball whether independent fits recover the same structure.
fig, axes = plt.subplots(2, 2, figsize=(12, 5))
for row, run in enumerate((1, 0)):
    axes[row, 0].imshow(Uhat[run].T, aspect='auto')
    axes[row, 0].set_title(f'Run {run}: codes Uhat (components x samples)')
    axes[row, 1].imshow(Vhat[run].T, aspect='auto')
    axes[row, 1].set_title(f'Run {run}: dictionary Vhat (components x features)')
fig.tight_layout()
plt.show()


In [None]:
# vmatch[i, j]: for every component of run j, take its largest inner product
# with any component of run i, then average over run j's components.
# If the atoms are unit-norm (presumably, for sklearn dictionaries — TODO
# confirm), 1.0 means every component of run j has an exact match in run i.
vmatch = np.array(
    [[(Vhat[row].T @ Vhat[col]).max(axis=0).mean() for col in range(num)]
     for row in range(num)],
    dtype=float,
).reshape(num, num)
vmatch

## Now try the same on one cortical hemisphere (subject s02)

In [None]:
# Flat and inflated surface sets; flatsurf[0] is the one paired with YL later
# (presumably the left hemisphere — TODO confirm ordering).
(flatsurf, inflsurf) = load_surf()

In [None]:
# Load (not plot) subject s02's task maps. taskmap[0] / taskmap[1] are used as
# the two hemispheres below; colname / colmap are presumably the condition
# names and a colour map — TODO confirm against load_wcon.
taskmap, colname, colmap = load_wcon('s02')

In [None]:
# For one hemisphere image: stack its data arrays into a 2-D matrix
# (presumably conditions x vertices — TODO confirm), subtract the mean across
# rows, then replace NaNs (e.g. vertices without data) with 0.
def _stack_and_center(img):
    stacked = np.vstack(img.agg_data())
    centered = stacked - stacked.mean(axis=0)
    centered[np.isnan(centered)] = 0
    return centered

YL = _stack_and_center(taskmap[0])
YR = _stack_and_center(taskmap[1])

In [None]:
# 10-component dictionary with non-negative codes, fit with rows of YL.T as
# samples; the fixed random_state makes this fit reproducible.
snn1 = DictionaryLearning(
    n_components=10,
    transform_algorithm='lasso_cd',
    random_state=33,
    positive_code=True,
    fit_algorithm='cd',
    transform_max_iter=2000,
)
snn1.fit(YL.T)

In [None]:
# On a fresh kernel `snn` is still the toy-example model (fit on Y, N=20
# features), which cannot transform YL.T. Fit a second hemisphere model with
# snn1's settings but unseeded initialisation, so it can be compared with snn1
# further below. Set random_state for a reproducible run.
hemi_params = dict(snn1.get_params(), random_state=None)
snn = DictionaryLearning(**hemi_params).fit(YL.T)
U = snn.transform(YL.T)
V = snn.components_

In [None]:
# Distribution of each row's total code weight (sum of U over components)
a = plt.hist(np.sum(U,axis=1),bins=50)
plt.xlabel('Sum of codes over components')
plt.ylabel('Count')
plt.title('Total code weight per row of YL.T')
plt.show()

In [None]:
# Winner-take-all component (left) and sqrt of total code weight (right) for
# the `snn` fit, drawn on flatsurf[0].
label = np.argmax(U,axis=1)
length = np.sum(U,axis=1)
# argmax of an all-zero code row is 0, which would paint uncoded rows (e.g.
# the NaN-zeroed vertices) as component 1. Give them label 0 instead
# (presumably shown as unlabelled by plotmap — TODO confirm).
parcel = np.where(length > 0, label + 1, 0)
fig = plt.figure(figsize=(12,5))
ax1 = plt.subplot(1,2,1)
ax2 = surf.plot.plotmap(parcel,flatsurf[0],overlay_type='label',cmap='Paired')
ax1.set_title('snn: winner-take-all component')
ax3 = plt.subplot(1,2,2)
ax4 = surf.plot.plotmap(np.sqrt(length),flatsurf[0],overlay_type='func')
ax3.set_title('snn: sqrt(total code weight)')
plt.show()

In [None]:
# Codes and dictionary of the seeded snn1 fit, plus the same two surface maps
# (winner-take-all component, sqrt of total code weight) on flatsurf[0].
U1 = snn1.transform(YL.T)
V1 = snn1.components_
label = np.argmax(U1,axis=1)
length = np.sum(U1,axis=1)
# argmax of an all-zero code row is 0, which would paint uncoded rows (e.g.
# the NaN-zeroed vertices) as component 1. Give them label 0 instead
# (presumably shown as unlabelled by plotmap — TODO confirm).
parcel = np.where(length > 0, label + 1, 0)
fig = plt.figure(figsize=(12,5))
ax1 = plt.subplot(1,2,1)
ax2 = surf.plot.plotmap(parcel,flatsurf[0],overlay_type='label',cmap='Paired')
ax1.set_title('snn1: winner-take-all component')
ax3 = plt.subplot(1,2,2)
ax4 = surf.plot.plotmap(np.sqrt(length),flatsurf[0],overlay_type='func')
ax3.set_title('snn1: sqrt(total code weight)')
plt.show()

In [None]:
# V and V1 are components_ arrays (one row per component) from two fits.
# Component order is arbitrary between fits, so an elementwise V - V1 mostly
# reflects permutation. Instead, for each component of V1, report its best
# inner product with any component of V (the same matching score as vmatch in
# the toy example; 1.0 = exact match if atoms are unit-norm — TODO confirm).
hemi_match = (V @ V1.T).max(axis=0)
hemi_match