In [None]:
# Colab only: mount Google Drive at /content/drive so the data files are reachable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
cd drive/MyDrive/Optim/modelfit/

/content/drive/MyDrive/Optim/modelfit


Load the receptive-field shift vectors (vector plots)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
# Read the saccade-onset vector data. Three arrays were written back-to-back
# into the same file, so they have to be read back in the same order.
with open('data2/vecdata_sacon.npy', 'rb') as vec_file:
    crf = np.load(vec_file)        # list of cRF centres [x, y]
    shift_vec = np.load(vec_file)  # list of shift vectors [x, y]
    sacamp = np.load(vec_file)     # list of saccade amplitudes


#len(d1)

In [None]:
import time
import os
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import pdb
import torch.nn.functional as F
import numpy as np
import torchvision.utils as vutils
from torch.utils.data import Dataset,DataLoader
import torch.optim as optim
from scipy.stats import gamma
import math
def gauss2d(rangexy,pos,sigma,InputAmp):
    """
    Separable 2-D Gaussian bump evaluated on a square grid.

    Args:
        rangexy  : 1-D tensor of grid coordinates, used for both the x and y axis.
        pos      : centre of the bump, [x, y].
        sigma    : width of the bump, [sigmax, sigmay].
        InputAmp : peak amplitude of the bump (number or broadcastable tensor).

    Returns:
        2-D tensor G of shape (len(rangexy), len(rangexy)); rows index y and
        columns index x, and the value at the centre equals InputAmp.
    """
    mu_x, mu_y = pos[0], pos[1]
    sigma_x, sigma_y = sigma[0], sigma[1]
    # 1-D Gaussian profile along each axis, kept as row vectors of shape (1, n)
    gx = torch.exp(-0.5*torch.pow((rangexy-mu_x)/sigma_x,2)).unsqueeze(0)
    gy = torch.exp(-0.5*torch.pow((rangexy-mu_y)/sigma_y,2)).unsqueeze(0)
    # outer product gy^T * gx -> G[row, col] = gy[row] * gx[col]
    G = torch.matmul(gy.T,gx)
    # The amplitude was previously accepted but never applied, which made the
    # stimulus amplitude and the attention amplitude silently ineffective.
    return InputAmp*G
class GaussStim(Dataset):
    """Dataset of 2-D gaussian probe stimuli, one for every position of the grid.

    Args:
      N : size of the network (the grid is N*N)
      InputAmp : input amplitude
      InputSig : sigma of the gaussian input (used for both axes)
    """
    def __init__(self,N,InputAmp,InputSig):
        self.N, self.InputAmp, self.InputSig = N, InputAmp, InputSig
        # grid coordinates shared by the x and y axis
        self.rangexy = torch.arange(N)

    def __len__(self):
        """Number of samples: one stimulus per grid position."""
        return self.N*self.N

    def __getitem__(self,idx):
        """Return the stimulus whose centre is the grid position encoded by the flat index idx."""
        # flat index -> (row, column) of the N*N grid, i.e. (y, x)
        row, col = np.unravel_index(idx,[self.N,self.N])
        centre = [self.rangexy[col],self.rangexy[row]]
        widths = [self.InputSig,self.InputSig]
        return gauss2d(self.rangexy,centre,widths,self.InputAmp)
class CircuitModel(nn.Module):
    """
    Recurrent 2-D network with separable excitatory / inhibitory convolution kernels.

    Trainable parameters (taken from ``args`` as tensors):
    sigE     :  sigma of excitation
    sigI     :  sigma of inhibition
    ampE     :  amplitude of excitation
    ampI     :  amplitude of inhibition
    cd       :  cd signal (weight of the asymmetric kernel component)
    attSig   :  width parameter of attention modulation
    attAmp   :  amplitude of attention modulation

    Fixed settings:
    cutoffFactor : multiple of sigma at which the kernel is cut off
    maxT     : total simulation time steps
    beta     : gain applied after the relu6 non-linearity
    dt       : time step
    tau      : time constant
    device   : which device (cpu or gpu)
    """

    def __init__(self,args):
        super(CircuitModel,self).__init__()
        self.N = args.N
        self.device = args.device
        # nn.Parameter shares storage with the tensor it wraps. Clone so the
        # model owns its parameters and an in-place optimizer step can never
        # modify the tensors stored on ``args`` (which are reused across models).
        self.sigE = nn.Parameter(args.sigE.detach().clone())
        self.sigI = nn.Parameter(args.sigI.detach().clone())
        self.ampE = nn.Parameter(args.ampE.detach().clone())
        self.ampI = nn.Parameter(args.ampI.detach().clone())
        self.cd = nn.Parameter(args.cd.detach().clone())
        self.attSig = nn.Parameter(args.attSig.detach().clone())
        self.attAmp = nn.Parameter(args.attAmp.detach().clone())

        self.xyrange = torch.arange(self.N)
        # the kernel function is cut off at this multiple of sigma (faster convolution)
        self.cutoffFactor = 5
        self.maxT = args.maxT
        self.beta = args.beta
        self.dt = args.dt
        self.tau = args.tau
        self.alpha = self.dt/self.tau  # Euler integration factor
        # grid coordinates centred on the middle of the field; uses the ``args``
        # parameter (the previous code read an unrelated global named ``arg``)
        self.rangexy = (torch.arange(args.N)-args.N//2).to(args.device)
        # Fixed target position.
        # NOTE(review): Target is expressed in 0..N index units while rangexy
        # is centred on zero; confirm this offset is intended.
        self.Target = [self.N*3//4,self.N*3//4]
        self.inputGammaAmp = self.gammaInput().to(self.device)  # gamma-shaped input time course

    def _gauss1d(self,kernalRange,sigma):
        """
        1-d gaussian function for the symmetric connection, sampled on the
        integers -kernalRange..kernalRange.
        """
        x = torch.arange(-kernalRange,kernalRange+1,1).to(self.device)
        # NOTE(review): dividing by sqrt(2) AND multiplying by 0.5 gives
        # exp(-x^2 / (4 sigma^2)); confirm this width convention is intended.
        argx = x/(sigma*math.sqrt(2))
        return torch.exp(-0.5*torch.pow(argx,2))

    def _gauss1d_drv(self,kernalRange,sigma):
        """
        Asymmetric kernel component: -x * gauss1d(x) / sigma^2 on the same grid
        as _gauss1d.
        """
        x = -torch.arange(-kernalRange,kernalRange+1,1).to(self.device)
        gauss1d = self._gauss1d(kernalRange,sigma)
        return x*gauss1d/(torch.pow(sigma,2))

    def _cutkernel(self,w,halflen):
        """
        Helper: keep at most halflen taps on each side of the kernel centre so
        the convolution kernel does not become too large.
        """
        l = w.size(-1)
        mid = l//2
        low_ind = max(0,mid-halflen)
        high_ind = min(mid+halflen+1,l)
        return w[low_ind:high_ind]

    def _computeKernelRange(self,sigma,cutoffFactor,fieldSize):
        """
        Helper: half-width (in grid points) of a connection kernel, i.e.
        |sigma|*cutoffFactor truncated to int and capped at fieldSize-1.
        """
        r1 = abs(sigma.item()*cutoffFactor)
        return min(int(r1),fieldSize-1)

    def init_hidden(self,batch_size):
        """Return the all-zero initial state of shape (batch_size, 1, N, N)."""
        return torch.zeros(batch_size, 1,self.N,self.N)

    def _center(self,rf):
        """
        Calculate the centre of mass of a 2-d rf.

        rf is first averaged over its leading (time) dimension; returns
        (x_mean, y_mean) where x runs along columns and y along rows.
        """
        rf = torch.mean(rf,dim=0)
        x = torch.arange(0,self.N).to(self.device)
        y = torch.arange(0,self.N).to(self.device)
        # explicit 'ij' indexing keeps the historical behaviour and silences
        # the torch.meshgrid deprecation warning
        grid_y, grid_x = torch.meshgrid(x, y, indexing='ij')
        M = rf.sum()
        x_mean = (grid_x*rf).sum()/M
        y_mean = (grid_y*rf).sum()/M
        return x_mean , y_mean

    def get_model_para(self):
        """
        Return the trainable parameters as a list of numpy arrays, ordered
        [sigE, sigI, ampE, ampI, attSig, attAmp, cd].
        """
        ordered = [self.sigE,self.sigI,self.ampE,self.ampI,
                   self.attSig,self.attAmp,self.cd]
        return [p.detach().to('cpu').numpy() for p in ordered]

    def gammaInput(self,shape=5,scale=10):
        """Gamma-pdf shaped input gain over the maxT time steps, normalised to a peak of 1."""
        t = np.linspace (0, self.maxT, self.maxT)
        amp = gamma.pdf(t, shape, 0,scale)
        amp = amp/np.max(amp)
        return torch.from_numpy(amp)

    def _Weight(self):
        """
        Create the 1-d convolution kernels.

        Returns (kernelExcX, kernelExcY, kernelInhX, kernelInhY, W_excX_), each
        with a leading singleton dimension. The x kernels are the symmetric
        gaussian plus cd times the asymmetric component; the y kernels are the
        plain gaussian.
        """
        # kernel half-widths from the cutoff factor and the current sigmas
        self.kernalRangeE = self._computeKernelRange(self.sigE,self.cutoffFactor,self.N)
        self.kernalRangeI = self._computeKernelRange(self.sigI,self.cutoffFactor,self.N)

        # excitatory kernels
        W_excX  = self._gauss1d(self.kernalRangeE,self.sigE) * self.ampE      # symmetric
        W_excX_ = self._gauss1d_drv(self.kernalRangeE,self.sigE )* self.ampE  # asymmetric
        W_excY  = self._gauss1d(self.kernalRangeE,self.sigE)
        kernelExcX = W_excX + self.cd*W_excX_
        kernelExcY = W_excY

        # inhibitory kernels
        W_inhX  = self._gauss1d(self.kernalRangeI,self.sigI) * self.ampI
        W_inhX_ = self._gauss1d_drv(self.kernalRangeI,self.sigI )* self.ampI
        W_inhY  = self._gauss1d(self.kernalRangeI,self.sigI)
        kernelInhX = W_inhX + self.cd*W_inhX_
        kernelInhY = W_inhY

        # make the conv kernels not too large (at most 40 taps per side)
        kernelExcX = self._cutkernel(kernelExcX,40)
        kernelExcY = self._cutkernel(kernelExcY,40)
        kernelInhX = self._cutkernel(kernelInhX,40)
        kernelInhY = self._cutkernel(kernelInhY,40)

        # extra leading dimension needed when doing the convolution
        return kernelExcX.unsqueeze(0),kernelExcY.unsqueeze(0),\
               kernelInhX.unsqueeze(0),kernelInhY.unsqueeze(0),\
               W_excX_.unsqueeze(0)

    def forward(self,x,crf,state=None):
        """
        Run the network dynamics for maxT steps.

        x     : batch of 2-d gaussian input bumps, shape (batch, N, N)
        crf   : (x, y) position of the recorded cell (cRF centre), used as grid indices
        state : optional initial state (batch, 1, N, N); zeros when None

        Returns (output, state, rf): the list of states at every step, the
        squeezed final state, and rf of shape (maxT, batch) holding the
        recorded cell's response to every probe at every step.
        """
        x = x.unsqueeze(1)  # add the channel dimension for conv2d
        # grid indices must be plain integers (crf may hold numpy scalars)
        crf_x,crf_y = int(crf[0]),int(crf[1])
        output = []
        # gaussian attention modulation around the target
        Mod = 1+gauss2d(self.rangexy,self.Target,[self.attSig,self.attSig],self.attAmp)
        self.Mod = Mod.expand_as(x).to(x.device)
        # convolution kernels along x and y
        self.WEx,self.WEy ,self.WIx,self.WIy,_ = self._Weight()
        # "same" padding for the odd-length kernels
        pdE = self.WEx.size(-1)//2
        pdI = self.WIx.size(-1)//2

        # conv2d weights: (1,1,1,k) for the x direction and (1,1,k,1) for y
        WEx = self.WEx.unsqueeze(0).unsqueeze(0)
        WEy = self.WEy.T.unsqueeze(0).unsqueeze(0)
        WIx = self.WIx.unsqueeze(0).unsqueeze(0)
        WIy = self.WIy.T.unsqueeze(0).unsqueeze(0)

        # response of the recorded cell, allocated next to the data to avoid
        # a device round trip at every time step
        rf = torch.zeros(self.maxT,x.size(0),device=x.device)
        if state is None:
            state = self.init_hidden(x.size(0)).to(x.device)
        # run neural dynamics
        for i in range(self.maxT):
            inputE = self.Mod * state  # attention modulates the excitatory drive
            recEx = F.conv2d(inputE,WEx,padding = (0,pdE))
            recE = F.conv2d(recEx,WEy,padding = (pdE,0))
            recIx = F.conv2d(state,WIx,padding = (0,pdI))
            recI = F.conv2d(recIx,WIy,padding = (pdI,0))
            # Euler step of the rate equation, input scaled by the gamma time course
            state = state + self.alpha*(-state+recE-recI+x*self.inputGammaAmp[i])
            state = self.beta*F.relu6(state)
            rf[i] = state[:,0,crf_y,crf_x]  # response of the recorded cell at step i
            output.append(state)
        return output,state.squeeze(),rf
# Smoke test: build one model with random parameters, feed it a single
# stimulus and compute a centre of mass. Nothing computed here is reused
# by the training code.
if __name__ == "__main__":
  # code for debug
  class args:
    '''
    sigE =8.0 #(deg)
    sigI = 12.0
    ampE = 10.0
    ampI = 5.0
    attSig = 20.0 #
    attAmp = 0.0 #
    cd=0.0 #
    '''
    # random initial values in [0, 20); these are class attributes, so they are
    # drawn once when the class body is executed
    sigE =20*torch.rand(1)
    sigI = 20*torch.rand(1)
    ampE = 20*torch.rand(1)
    ampI = 20*torch.rand(1)
    attSig = 20*torch.rand(1)
    attAmp = 20*torch.rand(1)
    cd=20*torch.rand(1)
    InputAmp = 7.0 #(deg)
    InputSig = 10.0
    tau = 30.0 # ms
    N = 50 # network size
    dt = 2.0 # time step
    beta = 0.4




    #time_step = 50
    #Target=[75,75]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batchsize = N
    epochs = 100
    maxT = 60 # total simulation time steps

  arg = args()
  # get dataset
  data = GaussStim(arg.N,arg.InputAmp,arg.InputSig)
  model = CircuitModel(arg)
  model = model.to(arg.device)
  # flat index 2080 on the 50*50 grid; unsqueeze adds the batch dimension
  img = data[2080].unsqueeze(0)
  img = img.to(arg.device)
  # record the cell at grid position (3, 3)
  tot_activity_,s,rf_ = model(img,[3,3])
  # centre of mass of the stimulus itself (its batch dim is averaged away)
  xc,yc = model._center(img)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [None]:
## training code
class args:
    """Model parameters and run settings used for training."""
    # initial values of the trainable parameters, uniform in [0, 20)
    sigE = 20 * torch.rand(1)
    sigI = 20 * torch.rand(1)
    ampE = 20 * torch.rand(1)
    ampI = 20 * torch.rand(1)
    attSig = 20 * torch.rand(1)
    attAmp = 20 * torch.rand(1)
    cd = 20 * torch.rand(1)

    # stimulus
    InputAmp = 7.0   # (deg)
    InputSig = 10.0

    # network and dynamics
    tau = 30.0       # ms
    downsample = 1
    N = 50 // downsample   # network size
    dt = 2.0         # time step
    beta = 0.4

    # run settings
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batchsize = N
    epochs = 12
    maxT = 60        # total simulation time steps
    tw = 50          # time window of the neural response (time steps)


# parameter object shared by the cells below
para = args()

# Get dataset. The batch size equals N and shuffle is off, so batch i contains
# the N probe positions of grid row i.
data = GaussStim(para.N,para.InputAmp,para.InputSig)
dataloader = DataLoader(data,para.batchsize,shuffle=False)

sac_list = np.unique(sacamp)
PATH = 'data2/tot_weights.pt'
# Reload the checkpoint and resume the training process
if os.path.exists(PATH):
    checkpoint = torch.load(PATH)
    modelW = checkpoint['modelW']
    cd_dict = checkpoint['cd_dict']
    # checkpoint['epoch'] is the index of the LAST fitted cell. Restarting at
    # that index refits the cell and appends duplicate entries, so resume at
    # the number of cells already stored instead.
    start_crfi = len(modelW)
else:
    start_crfi = 0
    modelW = []  # fitted parameters, one entry per cell
    # cd parameters grouped by saccade amplitude
    cd_dict = {k: [] for k in sac_list}

for crfi in range(start_crfi,len(crf)): # len(crf) is the cell number
    tot_loss = []
    sacamp_ = sacamp[crfi]
    # training target: prf_center = crf_center + shiftvec. Built as 0-dim
    # tensors so they match the shape of xc / yc in the MSE loss.
    Targetx = torch.tensor(float(crf[crfi][0] + shift_vec[crfi][0]),device=para.device)
    Targety = torch.tensor(float(crf[crfi][1] + shift_vec[crfi][1]),device=para.device)
    # crf center
    record_cell = [crf[crfi][0],crf[crfi][1]]
    # define the model
    # NOTE(review): the tensors on `args` are class attributes created once when
    # the class is defined, so args() does not draw new random initial values.
    para = args()
    model = CircuitModel(para)
    # move the model to gpu/cpu
    model = model.to(para.device)
    optimizer = optim.SGD(model.parameters(), lr=0.0001,weight_decay=0.01)
    criterion = nn.MSELoss()
    for ii in range(para.epochs):
        # time course of the prf, filled row by row: (N,T,N)
        prf = torch.zeros(para.batchsize,para.maxT,para.batchsize).to(para.device)
        for i ,stim in enumerate(dataloader):
            # prf size is N*N and the batch size is N, so every batch gives one row of the prf across time
            stim = stim.to(para.device)
            tot_activity_,s,prf_ = model(stim,record_cell)
            prf[i] = prf_ # i th row of prf of record_cell
        # reshape prf to (T,N,N); .contiguous() stores the result in contiguous memory
        prf = prf.transpose(1,0).contiguous()
        # center of prf over the first tw time steps
        xc,yc = model._center(prf[0:para.tw])

        shiftx = xc-record_cell[0]
        shifty = yc-record_cell[1]
        # projection of the model shift onto the measured shift vector
        proj = shiftx*shift_vec[crfi][0] + shifty*shift_vec[crfi][1]
        center_loss = criterion(xc,Targetx) + criterion(yc,Targety)
        loss = center_loss - 0.01*proj
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        tot_loss.append(loss.item())
        if ii % 10 == 0:
            print(
             (
                 f'cell number:{crfi};epoch: {ii + 1};center_error: {center_loss.item():.7f};projloss:{proj.item():.5f};'
             )
            )
    # save the model parameters of this cell and checkpoint the whole run
    w_ = model.get_model_para()
    cd_dict[sacamp_].append(w_[-1]) # the last entry of get_model_para() is cd
    modelW.append(w_)
    torch.save({
           'epoch': crfi,
           'modelW': modelW,
           'cd_dict':cd_dict
           }, PATH)





tensor(0.2000, device='cuda:0', grad_fn=<MinBackward1>)

Model sampling

In [None]:
# Define the functions to get the optimal bandwidth
from sklearn.neighbors import KernelDensity
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import LeaveOneOut
import random
from datetime import datetime
random.seed(datetime.now().timestamp())
def kde2(x,y):
  """
  Fit a 2-d gaussian kernel density estimate (bandwidth 1) to paired samples.

  x, y : 1-d sequences of equal length holding the two coordinates.
  Returns the fitted KernelDensity object.
  """
  # one row per sample, one column per coordinate
  xy = np.vstack([x,y]).T
  kde = KernelDensity(bandwidth=1, metric='euclidean',
                      kernel='gaussian',algorithm='ball_tree')
  kde.fit(xy)
  return kde
def get_best_bandwidth(x_train):
    """
    Pick the gaussian KDE bandwidth for a 1-d sample by leave-one-out
    cross-validation.

    x_train : 1-d array of samples.
    Returns the bandwidth with the best cross-validated score.
    """
    # KernelDensity needs a strictly positive bandwidth, so the first grid
    # point (exactly 0) is dropped; the remaining candidates are unchanged.
    candidates = np.linspace(0, 5, 1000)[1:]
    grid = GridSearchCV(
        estimator=KernelDensity(kernel='gaussian'),
        param_grid={'bandwidth': candidates},
        cv=LeaveOneOut(),
    )
    # the estimator expects a 2-d array of shape (n_samples, n_features)
    grid.fit(np.asarray(x_train)[:, np.newaxis])

    return grid.best_params_["bandwidth"]


def get_all_bandwidth(data):
    """
    Return the cross-validated KDE bandwidth of every column of data.

    data : fitted model parameters, shape (Nsamples, n_parameters).
    Returns a list with one bandwidth per column, in column order.
    """
    n_columns = data.shape[1]
    return [get_best_bandwidth(data[:,col]) for col in range(n_columns)]
def model_resp(weight,best_bandwidth,para,crf,shift_vec,cd_dict,N=100):
  """
  Model resampling function.

  Draws N parameter sets from kernel density estimates of the fitted
  parameters, simulates one model per set and returns the resulting pRF
  shift vectors.

  weight         : fitted parameters, shape (Nsamples, 7), columns ordered as
                   returned by CircuitModel.get_model_para()
  best_bandwidth : KDE bandwidth per parameter column (index 6 is used for cd)
  para           : parameter object; its sigE/sigI/ampE/ampI/attSig/attAmp/cd
                   attributes are overwritten for every resampled model
  crf            : array with at least N cRF centres; shuffled IN PLACE so the
                   caller's array stays aligned with the returned vectors
  shift_vec      : unused, kept for interface compatibility
  cd_dict        : fitted cd values keyed by saccade amplitude
  N              : number of models to resample

  Reads the module-level `dataloader`.
  Returns (tot_shift_vec, tot_sacamp); entry k of both refers to the same model.
  """
  # resampling the saccade amplitudes
  saclist = [5,10,15,20,25,30,35]
  tot_sacamp = np.random.choice(saclist,size=(N,))
  weight_ = weight[:,:,np.newaxis] # add new dimension required by KernelDensity function
  tot_shift_vec = []
  # matrix of resampled parameters, one row per parameter
  w_tot = np.zeros((7,N),dtype='float32')
  # fit a gaussian KDE to every non-cd parameter and draw N values from it
  for ii in range(6):
    kde_ = KernelDensity(kernel='gaussian',bandwidth=best_bandwidth[ii]).fit(weight_[:,ii])
    # sample() returns (N, 1): take the column. Taking row 0 broadcast a
    # single draw to all N models.
    w_tot[ii] = kde_.sample(N)[:,0]
  # fit one KDE of cd per saccade amplitude and write the draws at the
  # positions of the models that were assigned this amplitude, so cd stays
  # aligned with tot_sacamp
  for amp in saclist:
    idx = np.flatnonzero(tot_sacamp == amp)
    if idx.size == 0:
      continue
    tmp_cd = np.asarray(cd_dict[amp],dtype='float64').reshape(-1,1)
    kde_ = KernelDensity(kernel='gaussian',bandwidth=best_bandwidth[6]).fit(tmp_cd)
    w_tot[6,idx] = kde_.sample(idx.size)[:,0]
  w_tot = w_tot.T  # (N,7)

  np.random.shuffle(crf)

  for ii in range(len(w_tot)):
    # loop the sample size N
    w_ = w_tot[ii]

    record_cell = crf[ii] # crf center
    # load the model parameters
    para.sigE = torch.tensor(w_[0])
    para.sigI = torch.tensor(w_[1])
    para.ampE = torch.tensor(w_[2])
    para.ampI = torch.tensor(w_[3])
    para.attSig = torch.tensor(w_[4])
    para.attAmp = torch.tensor(w_[5])  # was misspelled `attSmp`, so attAmp was never set
    para.cd = torch.tensor(w_[6])
    model = CircuitModel(para)
    model = model.to(para.device)
    # inference only: no autograd graph is needed
    with torch.no_grad():
      prf = torch.zeros(para.batchsize,model.maxT,para.batchsize).to(para.device)
      # re-run the model to get the resampled rf
      for i ,stim in enumerate(dataloader):
        stim = stim.to(para.device)
        tot_activity_,s,prf_ = model(stim,record_cell)
        prf[i] = prf_
      prf = prf.transpose(1,0).contiguous()
      xc,yc = model._center(prf[0:para.tw])
    if ii%10 == 0:
      print(ii)
    # resampled prf shift vector = prf centre - crf centre
    prf_center = np.array([xc.item(),yc.item()],dtype='float32')
    tot_shift_vec.append(prf_center - record_cell)
  return tot_shift_vec,tot_sacamp # return the resampled vectors and saccade amplitudes




In [None]:
#from scipy.stats import norm
#from sklearn.neighbors import KernelDensity
#from sklearn.model_selection import GridSearchCV
#from sklearn.model_selection import LeaveOneOut
# Define a 2-d kernel density estimate function

PATH = 'data2/tot_weights.pt'
# Fail fast with a clear message: without the checkpoint the names below would
# be undefined (or stale kernel state) and the cell would break further down.
if not os.path.exists(PATH):
    raise FileNotFoundError(f'{PATH} not found; run the training cell first')
checkpoint = torch.load(PATH)
start_crfi = checkpoint['epoch']
modelW = checkpoint['modelW']
cd_dict = checkpoint['cd_dict']

# Legacy loader for the old .npy format, kept for reference:
# with open('data2/tot_weights.npy','rb') as f:
#    tot_weight = np.load(f)
#    tot_weight = tot_weight.squeeze()
#    cd_dict = np.load(f,allow_pickle=True)
#    cd_dict = cd_dict.item()
class args:
    """Model parameters and run settings used for resampling."""
    # initial values of the trainable parameters, uniform in [0, 20)
    sigE = 20 * torch.rand(1)
    sigI = 20 * torch.rand(1)
    ampE = 20 * torch.rand(1)
    ampI = 20 * torch.rand(1)
    attSig = 20 * torch.rand(1)
    attAmp = 20 * torch.rand(1)
    cd = 20 * torch.rand(1)

    # stimulus
    InputAmp = 7.0   # (deg)
    InputSig = 10.0

    # network and dynamics
    tau = 30.0       # ms
    downsample = 1
    N = 50 // downsample   # network size
    dt = 2.0         # time step
    beta = 0.4

    # run settings
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batchsize = N
    epochs = 12
    maxT = 60        # total simulation time steps
    tw = 50          # time window of the neural response (time steps)


# parameter object shared by the cells below
para = args()
data = GaussStim(para.N,para.InputAmp,para.InputSig)
dataloader = DataLoader(data,para.batchsize,shuffle=False)

# One row per fitted cell, one column per parameter. reshape (instead of a
# bare squeeze) keeps the cell axis even when only a single cell was fitted.
tot_weight = np.array(modelW)
tot_weight = tot_weight.reshape(len(modelW),-1)

#tot_weight = np.vstack((pfef_weight,plip_weight)) #shape (Nsample,6)

# Merge the two cd dicts
#for k,v in pfef_cd.items():
#         plip_cd[k].extend(v)
#cd_dict = plip_cd
cd_values = [*cd_dict.values()] # get all cd parameters from the dictionary
cd_flatten =  [ t for x in cd_values for t in x] # flatten the list
cd_array = np.array(cd_flatten)
#tot_weight=np.column_stack((tot_weight,cd_array)) # all model parameters including cd #shape:(Nsample,7)

# cross-validated KDE bandwidth for every parameter column
optimal_bandwidth = get_all_bandwidth(tot_weight)
#optimal_bandwidth = [2,1,1,0.3,1,2,0.2]

# 2-d kernel density estimate of the measured crf centers
kde_crf = kde2(crf[:,0],crf[:,1])

lip_vec,cd_amp_lip = [],[]

# resample new crf centers by 2-d Kernel Density Estimation
Nsample = 180
new_crf = kde_crf.sample(Nsample)
new_crf = new_crf.astype('int')
# The centers are used as grid indices by the model, and KDE draws can fall
# outside the N*N grid, so clamp them to valid indices.
new_crf = np.clip(new_crf,0,para.N-1)

# call the resample function (shuffles new_crf in place)
resp_vec , cd_amp = model_resp(tot_weight,optimal_bandwidth,para,new_crf,shift_vec,cd_dict,Nsample)



In [None]:
# Sanity check: one (x, y) row per resampled crf center.
new_crf.shape

In [None]:
# Store the resampled shift vectors followed by their saccade amplitudes in a
# single file; read them back with two np.load calls in the same order.
# NOTE(review): this writes under 'data/' while the other files live in 'data2/' - confirm the folder.
with open('data/resample_vectors_01.npy','wb') as f:
    np.save(f,resp_vec)
    np.save(f,cd_amp)


In [None]:
"""
class args:
    sigE =20*torch.rand(1)
    sigI = 20*torch.rand(1)
    ampE = 20*torch.rand(1)
    ampI = 20*torch.rand(1)
    attSig = 20*torch.rand(1)
    attAmp = 20*torch.rand(1)
    cd=20*torch.rand(1)
    InputAmp = 7.0 #(deg)
    InputSig = 10.0
    tau = 30.0 # ms
    downsample = 1
    N = 50//downsample # network size
    dt = 2.0 # time step
    beta = 0.4
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batchsize = N
    epochs = 12
    maxT = 60 # total simulation time steps
    tw = 50 # time window of the neural response (time steps)

para = args() #
data = GaussStim(para.N,para.InputAmp,para.InputSig)
dataloader = DataLoader(data,para.batchsize,shuffle=False)
with open('data2/plip_para_sacon0.npy','rb') as f:
  tot_weight = np.load(f)
  tot_weight = tot_weight.squeeze()
  cd_dict = np.load(f,allow_pickle=True)
  cd_dict = cd_dict.item()
cd_values = [*cd_dict.values()] # get all cd parameters from the dictionary
cd_flatten =  [ t for x in cd_values for t in x] # flatten the list
cd_array = np.array(cd_flatten)
#tot_weight=np.column_stack((tot_weight,cd_array))
#best_bandwidth = get_all_bandwidth(tot_weight)
best_bandwidth=[1,1,1,1,1,1,1,1]
kde_crf = kde2(crf[:,0],crf[:,1])
weight = tot_weight
N=100


saclist = [5,10,15,20,25,30,35]
tot_sacamp = np.random.choice(saclist,size=(N,))
weight_ = weight[:,:,np.newaxis] # add new dimension required by KernelDensity function
kde_tot = []
tot_shift_vec = []
  # Initial the matrix to save the resampled parameters
w_tot=np.zeros((7,N),dtype='float32')

for ii in range(6):
    kde_ = KernelDensity(kernel='gaussian',bandwidth=best_bandwidth[ii]).fit(weight_[:,ii])
    ww_ = kde_.sample(N)
    w_tot[ii] = ww_[:,0]
    #w_tot[ii] = ww_[:,0]
    ## sample without fitting
    #tmp_weight = weight_[:,ii,0]
    #ww_ = np.random.choice(tmp_weight,N)
    #ww_ = ww_.astype('float32')
    #w_tot[ii] = ww_
tot_cd = []
for ii in range(len(saclist)):
    # loop the saccade amplitudes [5,10,15,20,25,30,35]
    tmp_cd = cd_dict[saclist[ii]] # get the cd parameters with corresponding saccade amplitudes saclist[ii]
    kde_ = KernelDensity(kernel='gaussian',bandwidth=best_bandwidth[6]).fit(tmp_cd)
    tmp_sacamp = [i for i in tot_sacamp if i == saclist[ii]] # Determine the sample size for each sac.amplitudes are saclist[ii]
    resample_cd = kde_.sample(len(tmp_sacamp))
    tot_cd.extend(resample_cd)
w_tot[6] = tot_cd
w_tot = w_tot.T  # (N,7)
  ## resample cd
  #for ii in range(N):
  #  w_tot[6,ii] = np.random.choice(cd_dict[tot_sacamp[ii]],size=1)
  #  #w_tot[6,ii] = cd_dict[tot_sacamp[ii]]


np.random.shuffle(crf)

for ii in range(len(w_tot)):
    # loop the sample size N
    w_ = w_tot[ii]

    record_cell = crf[ii] # crf center
    # load the model parameters
    para.sigE=torch.tensor(w_[0])
    para.sigI=torch.tensor(w_[1])
    para.ampE=torch.tensor(w_[2])
    para.ampI=torch.tensor(w_[3])
    para.attSig=torch.tensor(w_[4])
    para.attSmp=torch.tensor(w_[5])

    para.cd = torch.tensor(w_[6])
    model = CircuitModel(para)
    model = model.to(para.device)
    prf = torch.zeros(para.batchsize,model.maxT,para.batchsize).to(para.device)
    # re-run the model to get the resampled rf
    for i ,stim in enumerate(dataloader):
           #print(i)
           stim = stim.to(para.device)
           #
           tot_activity_,s,prf_ = model(stim,record_cell)
           prf[i] = prf_
    prf = prf.transpose(1,0).contiguous()
    xc,yc= model._center(prf[0:para.tw])
    if ii%10 == 0:
       print(ii)
    # resampled prf shift vectors

    prf_center = torch.tensor([xc,yc]).detach().cpu().numpy()
    shift_vec = prf_center - record_cell
    tot_shift_vec.append(shift_vec)
    #return tot_shift_vec,tot_sacamp

"""