In [1]:
import torch
import numpy as np
import scipy.io as sio
import os
import time
from scipy import stats

In [2]:
class EnsemblePursuitPyTorch():
    
    def calculate_cost_delta(self):
        cost_delta=torch.clamp(torch.matmul(self.current_v,self.X),min=0,max=None)**2/torch.matmul(self.current_v,self.current_v)-self.lambd
        return cost_delta
    
    def fit_one_assembly(self):
        '''
        Function for fitting one cell assembly and computing u and v of the currrent assembly (self.current_u,
        self.current_v).
        One neuron cell assemblies are excluded. 
        '''
        with torch.cuda.device(0) as device:
            #Fake i for initiating while loop. self.i stores the number of neurons in assemblies.
            self.i=1
            #If i is 1, e.g. only one neuron in fit cell assembly, will run fitting the assembly again. 
            #safety it to avoid infinite loops.
            safety_it=0
            n_of_neurons=100
            while self.i==1:
                top_neurons,_=self.corr_top_k(n_neurons=n_of_neurons)
                top_corr_neuron=self.select_top_k_corr_neuron(top_neurons,_,n_of_neurons)
                print('neuron',top_corr_neuron)
                #choose_neuron_idx=np.random.randint(0,self.sz[1],1)[0]
                choose_neuron_idx=top_corr_neuron
                #Array of keeping track of neurons in the cell assembly
                self.selected_neurons=torch.zeros([self.sz[1]]).cuda()
                self.selected_neurons[choose_neuron_idx]=1
                #Seed current_v
                self.current_v=self.X[:,choose_neuron_idx]
                #Fake cost to initiate while loop
                max_delta_cost=1000
                #reset i
                self.i=1
                while max_delta_cost>0:
                    cost_delta=self.calculate_cost_delta()
                    #invert the 0's and 1's in the array which stores which neurons have already 
                    #been selected into the assembly to use it as a mask
                    mask=self.selected_neurons.clone()
                    mask[self.selected_neurons==0]=1
                    mask[self.selected_neurons!=0]=0
                    masked_cost_delta=mask*cost_delta
                    values,sorted_neurons=masked_cost_delta.sort()
                    max_delta_neuron=sorted_neurons[-1]
                    max_delta_cost=values[-1]
                    if max_delta_cost>0:
                        self.current_v=(self.current_v+self.X[:,max_delta_neuron.item()])/2
                        self.selected_neurons[max_delta_neuron.item()]=1
                    self.i+=1
                safety_it+=1
                #Increase number of neurons to sample from if while loop hasn't been finding any assemblies.
                if safety_it>100:
                    n_of_neurons=500
                if safety_it>600:
                    n_of_neurons=1000
                if safety_it>1600:
                    raise ValueError('Assembly capacity too big, can\'t fit model')
            #Add final seed neuron to seed_neurons.        
            self.seed_neurons.append(top_corr_neuron)          
            #Calculate u based on final v fit for a cell assembly. 
            self.current_u=torch.clamp(torch.matmul(self.current_v,self.X),min=0,max=None)/torch.matmul(self.current_v,self.current_v)
            self.U=torch.cat((self.U,self.current_u.view(self.X.size(1),1)),1)
            self.V=torch.cat((self.V,self.current_v.view(1,self.X.size(0))),0)
    
    def corrcoef(self,x):
        '''
        Torch implementation of the full correlation matrix.
        '''
        # calculate covariance matrix of columns
        mean_x = torch.mean(x,0)
        xm = torch.sub(x,mean_x)
        c = x.mm(x.t())
        c = c / (x.size(1) - 1)

        # normalize covariance matrix
        d = torch.diag(c)
        stddev = torch.pow(d, 0.5)
        c = c.div(stddev.expand_as(c))
        c = c.div(stddev.expand_as(c).t())

        # clamp between -1 and 1
        c = torch.clamp(c, -1.0, 1.0)

        return c
    
    def corr_top_k(self,n_neurons=100):
        '''
        Finds n_neurons neurons that are on average most correlated to their 
        5 closest neighbors.
        '''
        #Compute full correlation matrix (works with one neuron per column,
        #so have to transpose.)
        corr=self.corrcoef(self.X.t())
        #Sorts each row of correlation matrix
        vals,ix=corr.sort(dim=1)
        #Discards the last entry corresponding to the diagonal 1 and then
        #selects 5 of the largest entries from sorted array.
        top_vals=vals[:,:-1][:,self.sz[1]-6:]
        #Averages the 5 top correlations.
        av=torch.mean(top_vals,dim=1)
        #Sorts the averages
        vals,top_neurons=torch.sort(av)
        #Selects top neurons
        top_neuron=top_neurons[self.sz[1]-(n_neurons+1):]
        top_val=vals[self.sz[1]-(n_neurons+1):]
        return top_neuron,top_val
          
    
    def select_top_k_corr_neuron(self,top_neuron,top_val,n_neurons=100):
        '''
        Randomly samples from k top correlated urons.
        '''
        #Randomly samples a neuron from the n_of_neurons top correlated.
        idx=torch.randint(0,n_neurons,size=(1,))
        print('top n', top_neuron[idx[0]].item(), top_val[idx[0]].item())
        return top_neuron[idx[0]].item()
    
    
    def fit_transform(self,X,lambd,n_ensembles=None):
        torch.manual_seed(7)
        with torch.cuda.device(0) as device:
            self.lambd=lambd
            #z-score data.
            self.X=stats.zscore(X+0.0000001,axis=0)
            #Creates cuda tensor from data
            self.X=torch.cuda.FloatTensor(self.X) 
            #Store dimensionalit of X for later use.
            self.sz=self.X.size()
            #Initializes U and V with zeros, later these will be discarded.
            self.U=torch.zeros((self.X.size(1),1)).cuda()
            self.V=torch.zeros([1,self.X.size(0)]).cuda()
            #List for storing the number of neurons in each fit assembly.
            self.nr_of_neurons=[]
            #List for storing the seed neurons for each assembly.
            self.seed_neurons=[]
            cost_lst=[]
            for iteration in range(0,n_ensembles):
                self.fit_one_assembly()
                self.nr_of_neurons.append(self.i)
                U_V=torch.mm(self.current_u.view(self.sz[1],1),self.current_v.view(1,self.sz[0]))
                U_V[U_V != U_V] = 0
                res=(self.X-U_V.t())
                self.X=res
                print('ensemble nr', iteration)
                #print('u',self.current_u)
                #print('v',self.current_v)
                #print('length v', torch.matmul(self.current_v,self.current_v))
                #print('norm',torch.norm(self.X))
                self.cost=torch.mean(torch.mul(res,res))
                print('cost',self.cost)
                cost_lst.append(self.cost.item())
            #After fitting arrays discard the zero initialization rows and columns from U and V.
            self.U=self.U[:,1:]
            self.V=self.V[1:,:]
            print(self.X.size())
            print(self.U.size())
            print(self.V.size())
            return torch.matmul(self.U,self.V).t().cpu(), self.nr_of_neurons, self.U.cpu(), self.V.cpu(), cost_lst, self.seed_neurons

In [3]:
X=sio.loadmat('/home/maria/Documents/EnsemblePursuit/data/natimg2800_M170717_MP034_2017-09-11.mat')['stim']['resp'][0][0]
X[X<0]=0
print(X.shape)
for col in range(0,X.shape[1]):
    if np.sum(X[:,col]==0)==X.shape[0]:
        print(col)

(5880, 10103)
8574


In [4]:
np.random.seed(7)
ep=EnsemblePursuitPyTorch()
s=time.time()
U_V,nr_of_neurons,U,V, cost_lst,seed_neurons=ep.fit_transform(X,500,1000)
e=time.time()
print(e-s)
print(nr_of_neurons)
print(cost_lst)

top n 1672 0.41533946990966797
neuron 1672
ensemble nr 0
cost tensor(0.9983, device='cuda:0')
top n 5168 0.48429423570632935
neuron 5168
ensemble nr 1
cost tensor(0.9964, device='cuda:0')
top n 196 0.4173440933227539
neuron 196
ensemble nr 2
cost tensor(0.9941, device='cuda:0')
top n 2233 0.4697026312351227
neuron 2233
ensemble nr 3
cost tensor(0.9927, device='cuda:0')
top n 1489 0.45894595980644226
neuron 1489
ensemble nr 4
cost tensor(0.9911, device='cuda:0')
top n 9367 0.432136207818985
neuron 9367
ensemble nr 5
cost tensor(0.9895, device='cuda:0')
top n 4115 0.47026801109313965
neuron 4115
ensemble nr 6
cost tensor(0.9878, device='cuda:0')
top n 2295 0.4510462284088135
neuron 2295
ensemble nr 7
cost tensor(0.9860, device='cuda:0')
top n 1408 0.4705081880092621
neuron 1408
ensemble nr 8
cost tensor(0.9850, device='cuda:0')
top n 223 0.43413883447647095
neuron 223
ensemble nr 9
cost tensor(0.9830, device='cuda:0')
top n 105 0.4322568476200104
neuron 105
ensemble nr 10
cost tensor(0.9

top n 6918 0.32041797041893005
neuron 6918
ensemble nr 88
cost tensor(0.9098, device='cuda:0')
top n 5262 0.32127898931503296
neuron 5262
ensemble nr 89
cost tensor(0.9090, device='cuda:0')
top n 6806 0.3410807251930237
neuron 6806
ensemble nr 90
cost tensor(0.9083, device='cuda:0')
top n 689 0.31337255239486694
neuron 689
ensemble nr 91
cost tensor(0.9079, device='cuda:0')
top n 3505 0.3953709304332733
neuron 3505
ensemble nr 92
cost tensor(0.9073, device='cuda:0')
top n 6025 0.3288411796092987
neuron 6025
ensemble nr 93
cost tensor(0.9067, device='cuda:0')
top n 6686 0.3506530821323395
neuron 6686
ensemble nr 94
cost tensor(0.9060, device='cuda:0')
top n 2205 0.3807002007961273
neuron 2205
ensemble nr 95
cost tensor(0.9054, device='cuda:0')
top n 3500 0.3139466643333435
neuron 3500
ensemble nr 96
cost tensor(0.9049, device='cuda:0')
top n 6545 0.339323490858078
neuron 6545
ensemble nr 97
cost tensor(0.9040, device='cuda:0')
top n 2628 0.32167965173721313
neuron 2628
ensemble nr 98
co

top n 8723 0.6201118230819702
neuron 8723
ensemble nr 175
cost tensor(0.8623, device='cuda:0')
top n 4403 0.5965811014175415
neuron 4403
ensemble nr 176
cost tensor(0.8619, device='cuda:0')
top n 2854 0.6504265666007996
neuron 2854
ensemble nr 177
cost tensor(0.8614, device='cuda:0')
top n 702 0.5932124257087708
neuron 702
ensemble nr 178
cost tensor(0.8610, device='cuda:0')
top n 2746 0.5982522368431091
neuron 2746
ensemble nr 179
cost tensor(0.8606, device='cuda:0')
top n 3429 0.593113899230957
neuron 3429
ensemble nr 180
cost tensor(0.8602, device='cuda:0')
top n 3958 0.6005157232284546
neuron 3958
ensemble nr 181
cost tensor(0.8598, device='cuda:0')
top n 3511 0.6038844585418701
neuron 3511
ensemble nr 182
cost tensor(0.8594, device='cuda:0')
top n 2328 0.6222836971282959
neuron 2328
ensemble nr 183
cost tensor(0.8590, device='cuda:0')
top n 695 0.6069912314414978
neuron 695
ensemble nr 184
cost tensor(0.8586, device='cuda:0')
top n 4022 0.6110048294067383
neuron 4022
ensemble nr 1

top n 1466 0.7325335144996643
neuron 1466
ensemble nr 262
cost tensor(0.8276, device='cuda:0')
top n 1842 0.7328417301177979
neuron 1842
ensemble nr 263
cost tensor(0.8274, device='cuda:0')
top n 9064 0.7407363653182983
neuron 9064
ensemble nr 264
cost tensor(0.8270, device='cuda:0')
top n 3597 0.7379263639450073
neuron 3597
ensemble nr 265
cost tensor(0.8267, device='cuda:0')
top n 1403 0.7943517565727234
neuron 1403
ensemble nr 266
cost tensor(0.8264, device='cuda:0')
top n 1761 0.7353013157844543
neuron 1761
ensemble nr 267
cost tensor(0.8260, device='cuda:0')
top n 3971 0.8350479006767273
neuron 3971
ensemble nr 268
cost tensor(0.8257, device='cuda:0')
top n 561 0.733891487121582
neuron 561
ensemble nr 269
cost tensor(0.8254, device='cuda:0')
top n 1355 0.7338634729385376
neuron 1355
ensemble nr 270
cost tensor(0.8251, device='cuda:0')
top n 8386 0.7317537665367126
neuron 8386
ensemble nr 271
cost tensor(0.8247, device='cuda:0')
top n 574 0.7586433291435242
neuron 574
ensemble nr 2

top n 182 0.8848922848701477
neuron 182
ensemble nr 350
cost tensor(0.7981, device='cuda:0')
top n 5215 1.0
neuron 5215
ensemble nr 351
cost tensor(0.7981, device='cuda:0')
top n 3434 0.8863757252693176
neuron 3434
ensemble nr 352
cost tensor(0.7977, device='cuda:0')
top n 6444 0.8981008529663086
neuron 6444
ensemble nr 353
cost tensor(0.7973, device='cuda:0')
top n 125 0.8969321250915527
neuron 125
ensemble nr 354
cost tensor(0.7970, device='cuda:0')
top n 304 0.9027355313301086
neuron 304
ensemble nr 355
cost tensor(0.7966, device='cuda:0')
top n 5215 1.0
neuron 5215
ensemble nr 356
cost tensor(0.7966, device='cuda:0')
top n 282 0.8965610861778259
neuron 282
ensemble nr 357
cost tensor(0.7964, device='cuda:0')
top n 499 0.8906934857368469
neuron 499
ensemble nr 358
cost tensor(0.7961, device='cuda:0')
top n 1739 0.8822957277297974
neuron 1739
ensemble nr 359
cost tensor(0.7957, device='cuda:0')
top n 1719 0.8826934695243835
neuron 1719
ensemble nr 360
cost tensor(0.7955, device='cuda

top n 10026 1.0
neuron 10026
ensemble nr 441
cost tensor(0.7732, device='cuda:0')
top n 10045 1.0
neuron 10045
ensemble nr 442
cost tensor(0.7729, device='cuda:0')
top n 10071 1.0
neuron 10071
ensemble nr 443
cost tensor(0.7727, device='cuda:0')
top n 10064 1.0
neuron 10064
ensemble nr 444
cost tensor(0.7724, device='cuda:0')
top n 10010 1.0
neuron 10010
ensemble nr 445
cost tensor(0.7724, device='cuda:0')
top n 10022 1.0
neuron 10022
ensemble nr 446
cost tensor(0.7724, device='cuda:0')
top n 10066 1.0
neuron 10066
ensemble nr 447
cost tensor(0.7721, device='cuda:0')
top n 10087 1.0
neuron 10087
ensemble nr 448
cost tensor(0.7719, device='cuda:0')
top n 10095 1.0
neuron 10095
ensemble nr 449
cost tensor(0.7717, device='cuda:0')
top n 10036 1.0
neuron 10036
ensemble nr 450
cost tensor(0.7717, device='cuda:0')
top n 10037 1.0
neuron 10037
ensemble nr 451
cost tensor(0.7715, device='cuda:0')
top n 10077 1.0
neuron 10077
ensemble nr 452
cost tensor(0.7713, device='cuda:0')
top n 10045 1.0


top n 10083 1.0
neuron 10083
ensemble nr 541
cost tensor(0.7610, device='cuda:0')
top n 10013 1.0
neuron 10013
ensemble nr 542
cost tensor(0.7608, device='cuda:0')
top n 10065 1.0
neuron 10065
ensemble nr 543
cost tensor(0.7605, device='cuda:0')
top n 10045 1.0
neuron 10045
ensemble nr 544
cost tensor(0.7605, device='cuda:0')
top n 10067 1.0
neuron 10067
ensemble nr 545
cost tensor(0.7603, device='cuda:0')
top n 10028 1.0
neuron 10028
ensemble nr 546
cost tensor(0.7600, device='cuda:0')
top n 10080 1.0
neuron 10080
ensemble nr 547
cost tensor(0.7600, device='cuda:0')
top n 10048 1.0
neuron 10048
ensemble nr 548
cost tensor(0.7600, device='cuda:0')
top n 10100 1.0
neuron 10100
ensemble nr 549
cost tensor(0.7600, device='cuda:0')
top n 10070 1.0
neuron 10070
ensemble nr 550
cost tensor(0.7600, device='cuda:0')
top n 10021 1.0
neuron 10021
ensemble nr 551
cost tensor(0.7600, device='cuda:0')
top n 10072 1.0
neuron 10072
ensemble nr 552
cost tensor(0.7598, device='cuda:0')
top n 10006 1.0


top n 10078 1.0
neuron 10078
ensemble nr 641
cost tensor(0.7553, device='cuda:0')
top n 10060 1.0
neuron 10060
ensemble nr 642
cost tensor(0.7553, device='cuda:0')
top n 10072 1.0
neuron 10072
ensemble nr 643
cost tensor(0.7553, device='cuda:0')
top n 10039 1.0
neuron 10039
ensemble nr 644
cost tensor(0.7553, device='cuda:0')
top n 10022 1.0
neuron 10022
ensemble nr 645
cost tensor(0.7553, device='cuda:0')
top n 10101 1.0
neuron 10101
ensemble nr 646
cost tensor(0.7553, device='cuda:0')
top n 10064 1.0
neuron 10064
ensemble nr 647
cost tensor(0.7553, device='cuda:0')
top n 10046 1.0
neuron 10046
ensemble nr 648
cost tensor(0.7553, device='cuda:0')
top n 10023 1.0
neuron 10023
ensemble nr 649
cost tensor(0.7553, device='cuda:0')
top n 10030 1.0
neuron 10030
ensemble nr 650
cost tensor(0.7553, device='cuda:0')
top n 10016 1.0
neuron 10016
ensemble nr 651
cost tensor(0.7550, device='cuda:0')
top n 10020 1.0
neuron 10020
ensemble nr 652
cost tensor(0.7550, device='cuda:0')
top n 10025 1.0


top n 10064 1.0
neuron 10064
ensemble nr 741
cost tensor(0.7534, device='cuda:0')
top n 10097 1.0
neuron 10097
ensemble nr 742
cost tensor(0.7534, device='cuda:0')
top n 10091 1.0
neuron 10091
ensemble nr 743
cost tensor(0.7534, device='cuda:0')
top n 10010 1.0
neuron 10010
ensemble nr 744
cost tensor(0.7534, device='cuda:0')
top n 10058 1.0
neuron 10058
ensemble nr 745
cost tensor(0.7534, device='cuda:0')
top n 10101 1.0
neuron 10101
ensemble nr 746
cost tensor(0.7534, device='cuda:0')
top n 10008 1.0
neuron 10008
ensemble nr 747
cost tensor(0.7534, device='cuda:0')
top n 10081 1.0
neuron 10081
ensemble nr 748
cost tensor(0.7534, device='cuda:0')
top n 10079 1.0
neuron 10079
ensemble nr 749
cost tensor(0.7534, device='cuda:0')
top n 10073 1.0
neuron 10073
ensemble nr 750
cost tensor(0.7534, device='cuda:0')
top n 10089 1.0
neuron 10089
ensemble nr 751
cost tensor(0.7534, device='cuda:0')
top n 10004 1.0
neuron 10004
ensemble nr 752
cost tensor(0.7534, device='cuda:0')
top n 10088 1.0


top n 10014 1.0
neuron 10014
ensemble nr 842
cost tensor(inf, device='cuda:0')
top n 10024 1.0
neuron 10024
ensemble nr 843
cost tensor(inf, device='cuda:0')
top n 10043 1.0
neuron 10043
ensemble nr 844
cost tensor(inf, device='cuda:0')
top n 10083 1.0
neuron 10083
ensemble nr 845
cost tensor(inf, device='cuda:0')
top n 10039 1.0
neuron 10039
ensemble nr 846
cost tensor(inf, device='cuda:0')
top n 10046 1.0
neuron 10046
ensemble nr 847
cost tensor(inf, device='cuda:0')
top n 10019 1.0
neuron 10019
ensemble nr 848
cost tensor(inf, device='cuda:0')
top n 10044 1.0
neuron 10044
ensemble nr 849
cost tensor(inf, device='cuda:0')
top n 10059 1.0
neuron 10059
ensemble nr 850
cost tensor(inf, device='cuda:0')
top n 10087 1.0
neuron 10087
ensemble nr 851
cost tensor(inf, device='cuda:0')
top n 10057 1.0
neuron 10057
ensemble nr 852
cost tensor(inf, device='cuda:0')
top n 10036 1.0
neuron 10036
ensemble nr 853
cost tensor(inf, device='cuda:0')
top n 10079 1.0
neuron 10079
ensemble nr 854
cost te

top n 10004 1.0
neuron 10004
ensemble nr 946
cost tensor(inf, device='cuda:0')
top n 10101 1.0
neuron 10101
ensemble nr 947
cost tensor(inf, device='cuda:0')
top n 10068 1.0
neuron 10068
ensemble nr 948
cost tensor(inf, device='cuda:0')
top n 10086 1.0
neuron 10086
ensemble nr 949
cost tensor(inf, device='cuda:0')
top n 10078 1.0
neuron 10078
ensemble nr 950
cost tensor(inf, device='cuda:0')
top n 10086 1.0
neuron 10086
ensemble nr 951
cost tensor(inf, device='cuda:0')
top n 10100 1.0
neuron 10100
ensemble nr 952
cost tensor(inf, device='cuda:0')
top n 10101 1.0
neuron 10101
ensemble nr 953
cost tensor(inf, device='cuda:0')
top n 10098 1.0
neuron 10098
ensemble nr 954
cost tensor(inf, device='cuda:0')
top n 10003 1.0
neuron 10003
ensemble nr 955
cost tensor(inf, device='cuda:0')
top n 10005 1.0
neuron 10005
ensemble nr 956
cost tensor(inf, device='cuda:0')
top n 10013 1.0
neuron 10013
ensemble nr 957
cost tensor(inf, device='cuda:0')
top n 10044 1.0
neuron 10044
ensemble nr 958
cost te

In [5]:
def test_train_split(data,stim):
    unique, counts = np.unique(stim.flatten(), return_counts=True)
    count_dict=dict(zip(unique, counts))

    keys_with_enough_data=[]
    for key in count_dict.keys():
        if count_dict[key]==2:
            keys_with_enough_data.append(key)

    filtered_stims=np.isin(stim.flatten(),keys_with_enough_data)

    #Arrange data so that responses with the same stimulus are adjacent
    z=stim.flatten()[np.where(filtered_stims)[0]]
    sortd=np.argsort(z)
    istim=np.sort(z)
    X=data[filtered_stims,:]
    out=X[sortd,:].copy()

    x_train=out[::2,:]
    y_train=istim[::2]
    x_test=out[1::2,:]
    y_test=istim[1::2]
    
    return x_train, x_test, y_train, y_test

def evaluate_model(x_train,x_test):
    corr_mat=np.zeros((x_train.shape[0],x_train.shape[0]))
    for j in range(0,x_train.shape[0]):
        for i in range(0,x_test.shape[0]):
            corr_mat[j,i]=np.corrcoef(x_train[j,:],x_test[i,:])[0,1]
    print(np.mean(np.argmax(corr_mat, axis=0) == np.arange(0,x_train.shape[0],1,int)))
    
stim=sio.loadmat('/home/maria/Documents/EnsemblePursuit/data/natimg2800_M170717_MP034_2017-09-11.mat')['stim']['istim'][0][0]
x_train, x_test, y_train, y_test=test_train_split(np.array(V.t()),stim)
evaluate_model(x_train,x_test)

  X -= avg[:, None]


KeyboardInterrupt: 

In [None]:
x_train, x_test, y_train, y_test=test_train_split(X[:,seed_neurons],stim)
evaluate_model(x_train,x_test)

In [None]:
print(len(seed_neurons))

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
plt.hist(nr_of_neurons)
plt.title('Histogram of assembly sizes, lambda 500, 1000 assemblies')