# Import dependencies and set current folder

In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import h5py
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt
import matplotlib.backends as back
import LocaNMF
import torch
import time
import hdf5storage
from matplotlib.backends.backend_pdf import PdfPages
cd /PATH/TO/DATA

# Load & Format Data

In [None]:
# Lookup table keyed by animal ID (session string per animal).
sessList = {'A15307': '2_12',
            'A15098': '1_28',
            'A15309': '1_22',
            'A15100': '1_20',
            'A15312': '1_34',
            'A15301': '2_28',
            'A15352': '3_32'}
# Suffix of each animal's SVD file: <datafolder><animal>_<tag>.mat
tagList = {'A15307': 'SVD_data',
           'A15098': 'SVD',
           'A15309': 'SVD_data',
           'A15100': 'SVD_data',
           'A15312': 'SVD_data',
           'A15301': 'SVD_data',
           'A15352': 'SVD_data'}
curAnimal = 'A15098'
datafolder = 'data/'
addTag = tagList[curAnimal]
addTagFinal = 'locaNMF'

# Read every dataset eagerly inside `with` blocks so both HDF5 files are
# closed afterwards instead of staying open for the life of the kernel.
with h5py.File(datafolder + curAnimal + '_' + addTag + '.mat', 'r') as matlabImport:
    Uc = matlabImport['coeff'][:]
    Vc = matlabImport['score'][:]
    valid = np.array(matlabImport['maskValid'][:], 'int')
    brainmaskOrig = np.array(matlabImport['mask'], 'bool')  # THIS for window boundaries
with h5py.File(datafolder + 'allenROI.mat', 'r') as allenImport:
    dorsalMapScaledOrig = np.array(allenImport['ROI'])

# maskValid holds 1-based (MATLAB) pixel indices; convert to 0-based.
valid = valid.squeeze() - 1
U = Uc.transpose((1, 0))
# Transposing 'score' and then transposing it back is a no-op, so V is
# simply the array as read from the file.
V = Vc

globalMask = brainmaskOrig.transpose((1, 0))
ROI = dorsalMapScaledOrig.transpose(1, 0)

# Scatter the SVD spatial components back onto the full image grid. A C-order
# reshape to (ROI.shape[1], ROI.shape[0], k) followed by swapping the first two
# axes is a column-major (MATLAB-style) reshape of the linear pixel index.
tmp = np.zeros((ROI.shape[0] * ROI.shape[1], U.shape[1]))
tmp[valid, :] = U
U = np.reshape(tmp, (ROI.shape[1], ROI.shape[0], U.shape[1]))
U = U.transpose((1, 0, 2))
print(U.shape)
print(V.shape)
print(globalMask.shape)
print(ROI.shape)

# Orthonormalize the temporal basis: V.T = q @ r. LocaNMF is fit on the
# triangular factor r.T together with the in-mask spatial components.
q, r = np.linalg.qr(V.T)
video_mats = (np.copy(U[globalMask]), r.T)



# Generate the region masks

In [None]:
# Pixel count for each atlas label 1..max(ROI); labels below 1 are excluded.
# np.histogram's last bin is closed on the right, so the edges must run up to
# max+1. With edges 1..max, the last two labels were merged into one bin, and
# the smallest region (which bounds the rank below) could be mis-measured.
nRegionLabels = int(np.max(ROI))
pixelCount = np.histogram(ROI, bins=np.arange(1, nRegionLabels + 2))
print(pixelCount[0])
# Rank cap: the smallest region's pixel count, but never more than 100.
rank_range = (1, min(int(np.min(pixelCount[0])), 100), 1)

print(rank_range)

device = 'cuda'
region_mats = LocaNMF.extract_region_metadata(globalMask,
                                            ROI,
                                            min_size=1)
sigma = .002  # Previous
# Map region_mats[1] in place through 1 - exp(-sigma * d^2): 0 stays 0 and
# large values approach 1.
region_mats[1][:] = 1 - np.exp(-1 * np.power(region_mats[1], 2) * sigma)

# Quick visual check of one region's transformed map.
curArea = 0
A = np.zeros(globalMask.shape, dtype=np.float32)
A[globalMask] = region_mats[1][curArea]

plt.imshow(A)
plt.show()

# Save a plot of the masks

In [None]:
# Plot each region's transformed map (region_mats[1][i]), 4 panels per row,
# and write the figure to <animal>_components.pdf.
nRegions = len(region_mats[1])
nCols = 4
# Exactly enough rows for all regions, with at least one row. The previous
# 1 + floor(n/4) rows paired with a floor(n/4)*4 height gave a zero-height
# figure for n < 4 and an empty extra row when n was divisible by 4.
nRows = max(1, int(np.ceil(nRegions / nCols)))
out_pdf = curAnimal + '_components.pdf'
fig, axs = plt.subplots(nRows, nCols, figsize=(16, 4 * nRows))
axs = np.ravel(axs)

for i, ax in enumerate(axs):
    A = np.zeros(globalMask.shape, dtype=np.float32)
    if i < nRegions:
        A[globalMask] = region_mats[1][i]
        ax.imshow(A)
    else:
        ax.set_axis_off()

# Save before showing so the PDF does not depend on what the backend does to
# the figure after show(). The context manager guarantees the PDF is closed.
with PdfPages(out_pdf) as pdf:
    pdf.savefig(fig)
plt.show()

# Convert the inputs into LocaNMF data structures on the compute device

In [None]:
# Wrap the region masks/maps and the low-rank video in LocaNMF containers on
# `device`, and run the per-region SVD initialization.
region_metadata = LocaNMF.RegionMetadata(region_mats[0].shape[0],
                                       region_mats[0].shape[1:],
                                       device=device)
region_metadata.set(torch.from_numpy(region_mats[0].astype(np.uint8)),
                    torch.from_numpy(region_mats[1]),
                    torch.from_numpy(region_mats[2].astype(np.int64)))
torch.cuda.synchronize()
print('v SVD Initialization')
t0 = time.time()
region_videos = LocaNMF.factor_region_videos(video_mats,
                                           region_mats[0],
                                           rank_range[1],
                                           device=device)
torch.cuda.synchronize()
print("\'-total : %f" % (time.time() - t0))

# Allocate the low-rank video container once. It is sized by the number of
# in-mask pixels plus the shape of the temporal factor.
low_rank_video = LocaNMF.LowRankVideo(
    (int(np.sum(globalMask)),) + video_mats[1].shape,
    device=device
)
low_rank_video.set(torch.from_numpy(video_mats[0].T),
                   torch.from_numpy(video_mats[1]))

# Run locaNMF

In [None]:
torch.cuda.synchronize()
print('v Rank Line Search')
t0 = time.time()
# 50 loc .96 good for A15301
res = LocaNMF.rank_linesearch(low_rank_video,
                              region_metadata,
                              region_videos,
                              maxiter_rank=150,
                              maxiter_lambda=200,
                              maxiter_hals=50,
                              lambda_step=1.5,
                              lambda_init=1e-3,
                              loc_thresh=75,
                              r2_thresh=.985,
                              rank_range=rank_range,
                              verbose=[True, False, False],
                              sample_prop=(1, 1),
                              device=device)

torch.cuda.synchronize()
print("\'-total : %f" % (time.time() - t0))

# Count components per region, with regions in ascending order.
# region_ranks starts with a 0 sentinel so that np.cumsum(region_ranks[:-1])
# gives each region's first component index, assuming res groups its
# components by region in ascending order (TODO confirm for LocaNMF output).
regionLabels, regionCounts = torch.unique(res.regions.data, sorted=True,
                                          return_counts=True)
region_idx = regionLabels.tolist()
region_ranks = [0] + regionCounts.tolist()

# Summary stats over the real per-region counts, computed once. The 0 sentinel
# is excluded; it used to force the min to 0 and bias the mean.
componentsPerRegion = region_ranks[1:]
if componentsPerRegion:
    print((np.min(componentsPerRegion),
           np.mean(componentsPerRegion),
           np.max(componentsPerRegion)))
print((np.sum(np.array(region_ranks) == np.max(region_ranks)), len(region_metadata)))
print((1 + int(np.amax(region_ranks) / 4)) * 4)

# Save some locaNMF output plots

In [None]:
dorsalMapScaledB = ROI
brainmaskOrig = globalMask
A = np.zeros(globalMask.shape, dtype=np.float32)
out_pdf = curAnimal + '_components5.pdf'

# One figure per region. Panel 0 shows the region overlaid on the atlas, and
# the remaining panels show that region's components. Every figure uses the
# same 4-column grid, sized for the largest region: 1 + floor(max/4) rows
# hold max + 1 panels.
nCols = 4
nRows = 1 + int(np.amax(region_ranks) / nCols)
# Shift the overlay above every atlas label so it stays distinguishable.
atlasOffset = np.amax(dorsalMapScaledB) + 1

with PdfPages(out_pdf) as pdf:
    for rdx, i in zip(region_idx, np.cumsum(region_ranks[:-1])):
        fig, axs = plt.subplots(nRows, nCols, figsize=(16, nRows * 4))
        axs = axs.reshape((int(np.prod(axs.shape)),))

        # 2 where the distance is zero, 1 where it is positive, 0 elsewhere.
        dist = res.distance.data[i].cpu()
        A[globalMask] = 2 * (dist == 0) + (dist > 0)
        B = dorsalMapScaledB + atlasOffset * A
        axs[0].imshow(B)
        axs[0].set_title("Region: {}".format(rdx + 1))

        for j, ax in enumerate(axs[1:]):
            if i + j < len(res) and res.regions.data[i + j].item() == rdx:
                A[globalMask] = res.spatial.data[i + j].cpu()
                ax.set_title("Component {}".format(i + j))
                ax.imshow(A)
            else:
                A[globalMask] = 0
                ax.set_axis_off()

        # Save before showing, then close, so figures do not pile up in memory
        # across regions.
        pdf.savefig(fig)
        plt.show()
        plt.close(fig)

# Prepare data for saving as a new .mat file

In [None]:
temporal = res.temporal
curRange = temporal.shape[0]
# Copy the temporal factors to host memory once, not once per component.
temporalNp = temporal.data.cpu().numpy()

# LocaNMF was fit on r.T, where V.T = q @ r, so q maps the fitted temporal
# factors back to the original time axis.
C = np.matmul(q, temporalNp.T)
print(C.shape)

# compData[:, i] = V.T @ temporal[i], for all components in one product
# (formerly an elementwise multiply + row sum per column). Kept in float64 as
# before.
# NOTE(review): since V.T = q @ r, this equals q @ r @ temporal.T, i.e. C with
# an extra r factor mixed in. Confirm which of C / compData downstream
# analysis expects.
compData = np.asarray(np.matmul(V.T, temporalNp.T), dtype=np.float64)
print(compData.shape)


# Save the results to a .mat file

In [None]:
import scipy.io as sio

# Output path: <datafolder><animal>_<SVD tag>_<final tag>_output5.mat
fName = datafolder + '_'.join([curAnimal, addTag, addTagFinal]) + '_output5.mat'

# Everything is stored under one MATLAB struct variable named 'a_dict'.
# GPU tensors are copied to host numpy arrays first.
a_dict = {
    'compData': compData,
    'compXY': res.spatial.data.cpu().numpy(),
    'compD': res.distance.data.cpu().numpy(),
    'compT': res.temporal.data.cpu().numpy(),
    'C': C,
    'region_ranks': region_ranks,
    'region_idx': region_idx,
}
sio.savemat(fName, {'a_dict': a_dict})