In [1]:
import numpy as np
import matplotlib.pyplot as plt
import queue
import timeit
import time
import pandas as pd
import os
import os.path

In [2]:
#This code is a duplicate of Py_3DModel_Optimize, with the added features of tracking
#how many ions pass through the cell-cell gap junction and when.

#June 2: In fact, let's include counts of the number of ions that enter from the MARG and 
#from the GAP as well

#(***) adding a new array that keeps track of the cells: jumpJunct
#(***) at each timestep, count the number of ions in the cell that moved 
#across the gap junction to the cell to the right
#those ions are the ones that satisfy both conditions:
#          (**) they are going to the right neighbor (which[0] == -3), and
#          (**) they are currently in an INTRA block (0<tempIon[0]<N[0]+1 and 0<tempIon[2]<N[2]+1)
#(***) added a line that will count the number that move across
#the gap junction at each timestep, and append that number to jumpJunct


class Cell:
 
    def __init__(self,C, G, N, ionThresh, openThreshes, cellType, ionChans, gapJunct = 1, NiGAP = 0, NiMARG = 0,\
                moveFilter = 0.2, pokeHoleMARG = False):
        """Set up one cell of the chain: box geometry, move coefficients and the initial ions.

        The cell interior is split into N[0] x N[1] x N[2] INTRA boxes. Grid indices on
        each axis run 0..N[i]+1, where 0 and N[i]+1 are the extracellular boxes just
        outside the cell (MARG on the down-up and out-in axes, GAP on the left-right axis).

        Parameters
        ----------
        C : sequence of 3 numbers
            Cell extent along the (down-up, left-right, out-in) axes.
        G : sequence of 3 numbers
            Extracellular widths: G[0] for MARG above/below, G[1] for the GAP to the
            left/right, G[2] for MARG in front/behind.
        N : sequence of 3 ints
            Number of INTRA boxes along each axis.
        ionThresh : float
            Crowding threshold used in computeMoveScore's logistic term.
        openThreshes : sequence of 2
            Stored as GAP_Thresh and MARG_Thresh.
        cellType : str
            'leftend', 'middle', or 'rightend'.
        ionChans : sequence of 2 numbers
            Ion-channel counts for the GAP and MARGIN interfaces; turned into the
            coefficients 1 - exp(-count).
        gapJunct : float, default 1
            Coefficient for INTRA -> INTRA moves into a neighbouring cell.
        NiGAP, NiMARG : int, default 0
            Ions initialized per GAP / MARG region; 0 selects the defaults
            IONS_GAP (800) and IONS_MARGIN (42000).
        moveFilter : float, default 0.2
            In moveIons, an ion attempts a move only when its uniform draw is
            <= moveFilter.
        pokeHoleMARG : bool, default False
            For a 'leftend' cell, open the initial boundary hole on the bottom MARG
            face instead of the left GAP face.
        """

        #geometry, each ordered (down-up, left-right, out-in): C = cell extent,
        #G = extracellular widths, N = number of INTRA boxes
        self.C = C
        self.G = G
        self.N = N

        #warning to make sure parameters satisfy certain equalities. See NOTE above the instantiation of
        #self.DimGAP et al. below.
        #NOTE(review): these are exact float comparisons, so ratios that are equal only up to
        #rounding will still print the warning.
        if C[0]/N[0] != C[1]/N[1] or C[0]/N[0] != C[2]/N[2] or C[1]/N[1]!=C[2]/N[2] or G[0] != G[2]:
            print('Warning: The model assumes equal ratios of C[i]/N[i], i = 1,2,3, and that G[0] = G[2].'+\
                  'If these values do not match the model may not give out the correct result.')
        
        #in moveIons only ions whose uniform draw is <= moveFilter attempt a move
        self.moveFilter = moveFilter

        #default initial ion counts, used when NiMARG / NiGAP are left at 0
        IONS_MARGIN = 42000
        IONS_GAP = 800
        
        #dimensions of an INTRA box (one grid box of the cell interior)
        self.DimINTRA = [C[0]/N[0],C[1]/N[1],C[2]/N[2]]
        #dimensions of a GAP box
        self.DimGAP = [C[0]/N[0],G[1],C[2]/N[2]]
        #dimensions of a MARG box; one shape is used for every MARG box, which is why
        #the warning above assumes G[0] == G[2]
        self.DimMARG = [G[0],C[1]/N[1],C[2]/N[2]]

        #surface areas of the boxes 
        self.INTRA_SA = self.SurfArea(self.DimINTRA)
        self.GAP_SA = self.SurfArea(self.DimGAP)
        self.MARG_SA = self.SurfArea(self.DimMARG)

        #volumes of squares
        self.INTRA_VOL = self.Volu(self.DimINTRA)
        self.GAP_VOL = self.Volu(self.DimGAP)
        self.MARG_VOL =  self.Volu(self.DimMARG)

        #use INTRA because in both cases the interface is a face of an INTRA block
        #self.GAPchanCoeff = ionChans[0]/(self.INTRA_SA/6)
        #self.MARGINchanCoeff = ionChans[1]/(self.INTRA_SA/6)

        #actually, make it so that the channel coefficients start at 0 and approach
        #1 as the number of ion channels per region increases
        #Note: don't have to use ion channel density on a face since all faces involving
        #ion channels are faces of INTRA blocks. In other words, the number of ion channels
        #completely determines the ion channel density.
        self.GAPchanCoeff = 1-np.exp(-ionChans[0])
        self.MARGINchanCoeff = 1-np.exp(-ionChans[1])
        #coefficient applied to INTRA -> INTRA moves into a neighbouring cell (gap junction)
        self.CellCellchanCoeff = gapJunct

        #the order of the destination square in each list is: STAY, INTRA, MARG, GAP
        #computeMoveScore indexes these arrays with the negative codes -4 (STAY), -3 (INTRA),
        #-2 (MARG), -1 (GAP), which address positions 0..3 in exactly this order.
        #Each array holds the cross-section coefficients for moves out of that box type;
        #the STAY entry comes from computeStay. Moves INTRA -> MARG and INTRA -> GAP get 0.
        self.INTRA = np.array([self.computeStay(self.INTRA_VOL,self.INTRA_SA),self.DimINTRA[0]*self.DimINTRA[1],0,0])
        self.GAP = np.array([self.computeStay(self.GAP_VOL,self.GAP_SA),self.DimGAP[0]*self.DimGAP[2],\
                            self.DimGAP[0]*self.DimGAP[1], self.DimGAP[0]*self.DimGAP[1]])
        #NOTE(review): the last (GAP) entry below is built from DimGAP while the other
        #entries use DimMARG -- confirm this is intended.
        self.MARG = np.array([self.computeStay(self.MARG_VOL,self.MARG_SA),self.DimMARG[1]*self.DimMARG[2],\
                            self.DimMARG[0]*self.DimMARG[1],self.DimGAP[1]*self.DimGAP[2]])
        
        #0 means "use the default ion count"
        if NiGAP != 0:
            self.NiGAP = NiGAP
        else:
            self.NiGAP = IONS_GAP
        if NiMARG != 0:
            self.NiMARG = NiMARG
        else:
            self.NiMARG = IONS_MARGIN

        #CellType; whether the cell is 'leftend', 'middle', or 'rightend'
        self.cellType = cellType

        #this is the minimum number of ions in a grid square to block an ion
        #from moving in there
        self.ionThresh = ionThresh

        #NOTE(review): these thresholds are not read by the methods visible in this section
        self.GAP_Thresh = openThreshes[0]
        self.MARG_Thresh = openThreshes[1]
        
        #Initial ion positions in physical coordinates. Every region gets its ions on the
        #mid-plane of its extracellular layer, spread uniformly over the adjacent cell face.
        #initialize ions to the left of the cell (this is GAP)
        ionDownUp = np.random.uniform(self.G[0],self.G[0]+self.C[0],self.NiGAP)
        ionLeftRight = np.array([0.5*self.G[1] for i in range(self.NiGAP)])
        ionOutIn = np.random.uniform(self.G[2],self.G[2]+self.C[2],self.NiGAP)

        #initialize the ions above the cell (this is MARG)
        ionDownUp = np.concatenate((ionDownUp,np.array([self.C[0]+1.5*self.G[0] for i in range(self.NiMARG)])),axis = None)
        ionLeftRight = np.concatenate((ionLeftRight, np.random.uniform(self.G[1],self.G[1]+self.C[1],self.NiMARG)),axis = None)
        ionOutIn = np.concatenate((ionOutIn,np.random.uniform(self.G[2],self.G[2]+self.C[2],self.NiMARG)),axis = None)
        
        #initalize the ions below the cell (this is MARG)
        ionDownUp = np.concatenate((ionDownUp,np.array([0.5*self.G[0] for i in range(self.NiMARG)])),axis = None)
        ionLeftRight = np.concatenate((ionLeftRight, np.random.uniform(self.G[1],self.G[1]+self.C[1],self.NiMARG)),axis = None)
        ionOutIn = np.concatenate((ionOutIn,np.random.uniform(self.G[2],self.G[2]+self.C[2],self.NiMARG)),axis = None)

        #initialize the ions out in front of the cell (this is MARG)
        ionDownUp = np.concatenate((ionDownUp,np.random.uniform(self.G[0],self.G[0]+self.C[0],self.NiMARG)),axis = None)
        ionLeftRight = np.concatenate((ionLeftRight, np.random.uniform(self.G[1],self.G[1]+self.C[1],self.NiMARG)),axis = None)
        ionOutIn = np.concatenate((ionOutIn,np.array([self.C[2]+1.5*self.G[2] for i in range(self.NiMARG)])),axis = None)

        #initialize the ions behind the cell (this is MARG)
        ionDownUp = np.concatenate((ionDownUp,np.random.uniform(self.G[0],self.G[0]+self.C[0],self.NiMARG)),axis = None)
        ionLeftRight = np.concatenate((ionLeftRight, np.random.uniform(self.G[1],self.G[1]+self.C[1],self.NiMARG)),axis = None)
        ionOutIn = np.concatenate((ionOutIn,np.array([0.5*self.G[2] for i in range(self.NiMARG)])),axis = None)


        #October 2, 2024: Keep the bndOpen matrix as Ny+2 x Nx+2; just remember that when checking
        #bndOpen of the final column we will use ionMat[:,0] of the next cell
        #bndOpen[DU, LR, OI] == 1 marks a boundary box through which ions may cross into the
        #cell (see getAvail); everything starts closed.
        self.bndOpen = np.zeros((self.N[0]+2,self.N[1]+2,self.N[2]+2))
        #September 12, 2024: poke a hole in the boundary if it's 'leftend'
        if self.cellType == 'leftend':
            if pokeHoleMARG:
                #poke a hole on the margin instead
                self.bndOpen[0,1,np.ceil(N[2]/2).astype(int)]=1
            else:
                self.bndOpen[np.ceil(N[0]/2).astype(int),0,np.ceil(N[2]/2).astype(int)]=1
            
            
        self.bndOpen = self.bndOpen.astype(int)
    
        #only the rightmost cell owns the GAP on its right-hand side
        if cellType == 'rightend':
            #initialize the ions to the right of the cell
            ionDownUp = np.concatenate((ionDownUp,np.random.uniform(self.G[0],self.G[0]+self.C[0],self.NiGAP)),axis = None)
            ionLeftRight = np.concatenate((ionLeftRight,np.array([self.C[1]+1.5*self.G[1] for i in range(self.NiGAP)])),axis = None)
            ionOutIn = np.concatenate((ionOutIn,np.random.uniform(self.G[2],self.G[2]+self.C[2],self.NiGAP)),axis = None)
        else:
            pass
        
        #counts how many times we've moved the ions
        self.counter = 0
        #set equal to self.counter once we first meet the depolarization criterion
        self.depolar = 0
        #an array that will count how many ions traveled using the gap junction to the righthand cell
        self.jumpJunct = []
        #array that will count how many ions entered through the GAP
        self.enterGAPself = []
        #array that will count how many ions entered through the GAP into the cell on the left
        self.enterGAPleft = []
        #array that will count how many ions entered through MARG
        self.enterMARG = []

        #Convert physical positions to grid indices: 1..N[i] inside the cell, and below the
        #0 / N[i]+1 for the extracellular boxes on either side. LR index N[1]+1 is only
        #assigned in a 'rightend' cell.
        #note: self.DimINTRA = [C[0]/N[0], C[1]/N[1], C[2]/N[2]]
        self.ionDU = (np.maximum(np.minimum(np.ceil((ionDownUp-self.G[0])/self.DimINTRA[0]),self.N[0]),1)).astype(int)
        self.ionLR = (np.maximum(np.minimum(np.ceil((ionLeftRight-self.G[1])/self.DimINTRA[1]),self.N[1]),1)).astype(int)
        self.ionOI = (np.maximum(np.minimum(np.ceil((ionOutIn-self.G[2])/self.DimINTRA[2]),self.N[2]),1)).astype(int)


        self.ionDU[ionDownUp < self.G[0]] = 0
        self.ionDU[ionDownUp > self.C[0]+self.G[0]] = self.N[0]+1
        self.ionLR[ionLeftRight < self.G[1]] = 0
        if self.cellType == 'rightend':
            self.ionLR[ionLeftRight > self.C[1]+self.G[1]] = self.N[1]+1
        self.ionOI[ionOutIn < self.G[2]] = 0
        self.ionOI[ionOutIn > self.C[2]+self.G[2]] = self.N[2]+1
    
        #initialize ion-grid matrix like in getionCount()
        #October 2, 2024: we get rid of the last column 
        #(a non-rightend cell has LR columns 0..N[1]; a rightend cell also has column N[1]+1)
        if self.cellType == 'rightend':
            self.ionMat = np.zeros((self.N[0]+2,self.N[1]+2,self.N[2]+2))
        else:
            self.ionMat = np.zeros((self.N[0]+2,self.N[1]+1,self.N[2]+2))
        for k in range(len(self.ionDU)):
            self.ionMat[self.ionDU[k],self.ionLR[k],self.ionOI[k]] += 1
        self.ionMat = self.ionMat.astype(int)

    #compute the surface area of a box from its dimensions; helper function
    #for __init__ above
    def SurfArea(self,dims):
        """Return the surface area 2*(ab + ac + bc) of a box with side lengths dims = (a, b, c)."""
        a, b, c = dims[0], dims[1], dims[2]
        return 2*(a*b+a*c+b*c)

    #computes the volume of a box from its dimensions; helper function 
    #for __init__ above
    def Volu(self,dims):
        """Return the volume a*b*c of a box with side lengths dims = (a, b, c)."""
        a, b, c = dims[0], dims[1], dims[2]
        return a*b*c

    #this is for determining the STAY version of the cross-section coefficient;
    #in case we want to change it up from Volume/SurfaceArea (which is 3-dim divided by 2-dim)
    def computeStay(self,Vol,SurfArea):
        """Cross-section coefficient for the STAY move: cbrt(Vol) / sqrt(SurfArea).

        Both factors are lengths, which keeps the ratio dimensionless. An earlier
        version used Vol/SurfArea instead.
        """
        lengthFromVolume = np.cbrt(Vol)
        lengthFromArea = np.sqrt(SurfArea)
        return lengthFromVolume/lengthFromArea
    
    #determine the square the ion is in
    def boxType(self,currDU,currLR,currOI):
        if (currDU==0 or currDU==self.N[0]+1) and  currLR>=1 and currLR<=self.N[1] and currOI>=1 and currOI<=self.N[2]:
            return -2
        elif (currLR==0 or currLR==self.N[1]+1) and  currDU>=1 and currDU<=self.N[0] and currOI>=1 and currOI<=self.N[2]:
            return -1
        elif (currOI==0 or currOI==self.N[2]+1) and  currDU>=1 and currDU<=self.N[0] and currLR>=1 and currLR<=self.N[1]:
            return -2
        elif currDU>=1 and currLR>=1 and currOI>=1 and currDU<=self.N[0] and currLR<=self.N[1] and currOI<=self.N[2]:
            return -3
        else:
            print('Invalid square for ion: ('+ str(currI)+', '+str(currJ)+')')
            return 0
    
    #for each grid square, count the number of ions in that square
    #used for newPos function, to determine if a grid square has too many ions
    #to accept new members
    def getionCount(self):
        if self.cellType == 'rightend':
            self.ionMat = np.zeros((self.N[0]+2,self.N[1]+2,self.N[2]+2))
        else:
            self.ionMat = np.zeros((self.N[0]+2,self.N[1]+1,self.N[2]+2))
        for k in range(len(self.ionDU)):
            self.ionMat[self.ionDU[k],self.ionLR[k],self.ionOI[k]] += 1
        self.ionMat = self.ionMat.astype(int)
  
    def getAvail(self, tempIon, leftNbhr=0, rightNbhr = 0):
        """Build the eight candidate destinations for every ion in tempIon.

        Parameters
        ----------
        tempIon : array, shape (3, n)
            Grid indices (DU, LR, OI) of the ions that will attempt a move.
        leftNbhr : Cell or 0
            Left neighbour; its ``bndOpen`` is read when this cell is not 'leftend'.
        rightNbhr : Cell or 0
            Right neighbour; not read by this method.

        Returns
        -------
        tuple of 8 numpy arrays, each shape (4, n)
            (UP, DOWN, IN, OUT, RIGHT, LEFT, EXTRA, EXTRA2). Column k is ion k's
            candidate as [code, DU, LR, OI]. The code is one of:
            -1 (destination in this cell), -2 (destination in leftNbhr),
            -3 (destination in the right neighbour), -4 (move not allowed; the
            ion's own indices are kept).

        Each candidate is a sum of [code, DU, LR, OI] templates, each multiplied by a
        boolean mask; the masks are meant to be mutually exclusive.
        NOTE(review): a column that matches none of the masks comes out as all
        zeros (code 0) -- confirm such index combinations cannot occur.
        """

        #constant length-n columns used to build the [code, DU, LR, OI] templates
        colZeros = np.zeros(tempIon.shape[1])
        colOnes = np.ones(tempIon.shape[1])

        colN0s = np.array(tempIon.shape[1]*[self.N[0]])
        colN1s = np.array(tempIon.shape[1]*[self.N[1]])
        colN2s = np.array(tempIon.shape[1]*[self.N[2]])
        #move codes: -4 not allowed, -1 this cell, -3 right neighbour, -2 left neighbour
        codeI = np.array(tempIon.shape[1]*[-4])
        codeS = np.array(tempIon.shape[1]*[-1])
        codeR = np.array(tempIon.shape[1]*[-3])
        codeL = np.array(tempIon.shape[1]*[-2])
        
        #DOWN (DU - 1):
        #  - not allowed from DU = 0;
        #  - from DU = 1, INTRA boxes are blocked, while GAP/MARG boxes wrap around the
        #    cell edge into the DU = 0 row;
        #  - from DU = N[0]+1, the ion enters DU = N[0] only where bndOpen == 1;
        #  - every other box steps to DU - 1.
        DOWN =  [codeI,tempIon[0],tempIon[1],tempIon[2]]*(tempIon[0] == 0)+\
                (tempIon[0] == 1)*(\
                    [codeI,tempIon[0],tempIon[1],tempIon[2]]*((tempIon[1]>= 1) & (tempIon[1] <= self.N[1]) & (tempIon[2]>=1) & (tempIon[2]<=self.N[2]))+\
                    ((tempIon[1] ==0) | (tempIon[1] == self.N[1]+1))*(\
                        [codeS,colZeros,colOnes,tempIon[2]]*((self.cellType == 'leftend') & (tempIon[1] ==0))+\
                        #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
                        #DOUBLE TROUBLE
                        [codeS,colZeros,colOnes,tempIon[2]]*((self.cellType != 'leftend') & (tempIon[1] == 0))+\
                        #move this one below to the EXTRAs
                        #[codeL,colZeros,colN1s,tempIon[2]]*(self.cellType != 'leftend' and tempIon[1] == 0)+\
                        #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
                        [codeS,colZeros,colN1s,tempIon[2]]*((self.cellType == 'rightend') & (tempIon[1] == self.N[1]+1)))+\
                    [codeS,colZeros,tempIon[1],colOnes]*(tempIon[2]==0)+\
                    [codeS,colZeros,tempIon[1],colN2s]*(tempIon[2] == self.N[2]+1))+\
                (tempIon[0] == self.N[0]+1)*(\
                    [codeS,tempIon[0]-1,tempIon[1],tempIon[2]]*(self.bndOpen[tempIon[0],tempIon[1],tempIon[2]] ==1)+\
                    [codeI,tempIon[0],tempIon[1],tempIon[2]]*(~((self.bndOpen[tempIon[0],tempIon[1],tempIon[2]] ==1))))+\
                [codeS,tempIon[0]-1,tempIon[1],tempIon[2]]*(~(((tempIon[0] == 0) | (tempIon[0] == 1) | (tempIon[0] == self.N[0]+1))))


        #UP (DU + 1), the mirror image of DOWN:
        #  - not allowed from DU = N[0]+1;
        #  - from DU = N[0], INTRA boxes are blocked, while GAP/MARG boxes wrap into the
        #    DU = N[0]+1 row;
        #  - from DU = 0, the ion enters DU = 1 only where bndOpen == 1;
        #  - every other box steps to DU + 1.
        UP =    [codeI,tempIon[0],tempIon[1],tempIon[2]]*(tempIon[0] == self.N[0]+1)+\
                (tempIon[0] == self.N[0])*(\
                    [codeI,tempIon[0],tempIon[1],tempIon[2]]*((tempIon[1]>= 1) & (tempIon[1] <= self.N[1]) & (tempIon[2]>=1) & (tempIon[2]<=self.N[2]))+\
                    ((tempIon[1] ==0) | (tempIon[1] == self.N[1]+1))*(\
                        [codeS,colN0s+1,colOnes,tempIon[2]]*((self.cellType== 'leftend') & (tempIon[1] ==0))+\
                        #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
                        #DOUBLE TROUBLE
                        [codeS,colN0s+1,colOnes,tempIon[2]]*((self.cellType != 'leftend') & (tempIon[1] == 0))+\
                        #move this one below to the EXTRAs
                        #[codeL,colN0s+1,colN1s,tempIon[2]]*(self.cellType != 'leftend' and tempIon[1] == 0)+\
                        #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
                        [codeS,colN0s+1,colN1s,tempIon[2]]*((self.cellType == 'rightend') & (tempIon[1] == self.N[1]+1)))+\
                    [codeS,colN0s+1,tempIon[1],colOnes]*(tempIon[2] ==0)+\
                    [codeS,colN0s+1,tempIon[1],colN2s]*(tempIon[2] == self.N[2]+1))+\
                (tempIon[0] == 0)*(\
                    [codeS,tempIon[0]+1,tempIon[1],tempIon[2]]*(self.bndOpen[tempIon[0],tempIon[1],tempIon[2]] ==1)+\
                    [codeI,tempIon[0],tempIon[1],tempIon[2]]*(~(self.bndOpen[tempIon[0],tempIon[1],tempIon[2]] ==1)))+\
                [codeS,tempIon[0]+1,tempIon[1],tempIon[2]]*(~((tempIon[0] == self.N[0]+1) | (tempIon[0] == self.N[0]) | (tempIon[0] == 0))) 


        #IN (OI - 1), same pattern on the out-in axis:
        #  - not allowed from OI = 0;
        #  - from OI = 1, INTRA boxes are blocked, while GAP/MARG boxes wrap into the
        #    OI = 0 layer;
        #  - from OI = N[2]+1, the ion enters only where bndOpen == 1;
        #  - every other box steps to OI - 1.
        IN =    [codeI,tempIon[0],tempIon[1],tempIon[2]]*(tempIon[2] == 0)+\
                (tempIon[2] == 1)*(\
                    [codeI,tempIon[0],tempIon[1],tempIon[2]]*((tempIon[1]>= 1) & (tempIon[1] <= self.N[1]) & (tempIon[0]>=1) & (tempIon[0]<=self.N[0]))+\
                    ((tempIon[1] ==0) | (tempIon[1] == self.N[1]+1))*(\
                        [codeS,tempIon[0],colOnes,colZeros]*((self.cellType== 'leftend') & (tempIon[1] ==0))+\
                        #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
                        #DOUBLE TROUBLE
                        [codeS,tempIon[0],colOnes,colZeros]*((self.cellType != 'leftend') & (tempIon[1] == 0))+\
                        #move this one below to the EXTRAs
                        #[codeL,tempIon[0],colN1s,colZeros]*(self.cellType != 'leftend' and tempIon[1] == 0)+\
                        #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
                        [codeS,tempIon[0],colN1s,colZeros]*((self.cellType == 'rightend') & (tempIon[1] == self.N[1]+1)))+\
                    [codeS,colOnes,tempIon[1],colZeros]*(tempIon[0] ==0)+\
                    [codeS,colN0s,tempIon[1],colZeros]*(tempIon[0] == self.N[0]+1))+\
                (tempIon[2] == self.N[2]+1)*(\
                    [codeS,tempIon[0],tempIon[1],tempIon[2]-1]*(self.bndOpen[tempIon[0],tempIon[1],tempIon[2]] ==1)+\
                    [codeI,tempIon[0],tempIon[1],tempIon[2]]*(~(self.bndOpen[tempIon[0],tempIon[1],tempIon[2]] ==1)))+\
                [codeS,tempIon[0],tempIon[1],tempIon[2]-1]*(~((tempIon[2] == 0) | (tempIon[2] == 1) | (tempIon[2] == self.N[2]+1)))
          
        #OUT (OI + 1), the mirror image of IN:
        #  - not allowed from OI = N[2]+1;
        #  - from OI = N[2], INTRA boxes are blocked, while GAP/MARG boxes wrap into the
        #    OI = N[2]+1 layer;
        #  - from OI = 0, the ion enters only where bndOpen == 1;
        #  - every other box steps to OI + 1.
        OUT =   [codeI,tempIon[0],tempIon[1],tempIon[2]]*(tempIon[2] == self.N[2]+1)+\
                (tempIon[2] == self.N[2])*(\
                    [codeI,tempIon[0],tempIon[1],tempIon[2]]*((tempIon[1]>= 1) & (tempIon[1] <= self.N[1]) & (tempIon[0]>=1) & (tempIon[0]<=self.N[0]))+\
                    ((tempIon[1] ==0) | (tempIon[1] == self.N[1]+1))*(\
                        [codeS,tempIon[0],colOnes,colN2s+1]*((self.cellType== 'leftend') & (tempIon[1] ==0))+\
                        #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
                        #DOUBLE TROUBLE
                        [codeS,tempIon[0],colOnes,colN2s+1]*((self.cellType != 'leftend') & (tempIon[1] == 0))+\
                        #move this one below to the EXTRAs
                        #[codeL,tempIon[0],colN1s,colN2s+1]*(self.cellType != 'leftend' and tempIon[1] == 0)+\
                        #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
                        [codeS,tempIon[0],colN1s,colN2s+1]*((self.cellType == 'rightend') & (tempIon[1] == self.N[1]+1)))+\
                    [codeS,colOnes,tempIon[1],colN2s+1]*(tempIon[0] ==0)+\
                    [codeS,colN0s,tempIon[1],colN2s+1]*(tempIon[0] == self.N[0]+1))+\
                (tempIon[2] == 0)*(\
                    [codeS,tempIon[0],tempIon[1],tempIon[2]+1]*(self.bndOpen[tempIon[0],tempIon[1],tempIon[2]] ==1)+\
                    [codeI,tempIon[0],tempIon[1],tempIon[2]]*(~(self.bndOpen[tempIon[0],tempIon[1],tempIon[2]] ==1)))+\
                [codeS,tempIon[0],tempIon[1],tempIon[2]+1]*(~((tempIon[2] == self.N[2]+1) | (tempIon[2] == self.N[2]) | (tempIon[2] == 0)))


        #RIGHT (LR + 1):
        #  - not allowed from LR = N[1]+1.
        #  - From LR = N[1] in a 'rightend' cell: INTRA boxes are blocked, and MARG boxes
        #    wrap into this cell's LR = N[1]+1 column.
        #  - From LR = N[1] in any other cell: the ion moves into the right neighbour
        #    (code -3). INTRA goes to the neighbour's LR = 1; MARG boxes go to the
        #    neighbour's LR = 0 column.
        #  - From LR = 0, the ion enters only where bndOpen == 1.
        #  - Every other box steps to LR + 1.
        RIGHT = [codeI,tempIon[0],tempIon[1],tempIon[2]]*(tempIon[1] == self.N[1]+1)+\
                (tempIon[1] == self.N[1])*(\
                    (self.cellType == 'rightend')*(\
                        [codeI,tempIon[0],tempIon[1],tempIon[2]]*((tempIon[0]>=1) & (tempIon[0]<=self.N[0]) & (tempIon[2]>=1) & (tempIon[2]<=self.N[2]))+\
                        [codeS,colOnes,colN1s+1,tempIon[2]]*(tempIon[0] ==0)+\
                        [codeS,colN0s,colN1s+1,tempIon[2]]*(tempIon[0] == self.N[0]+1)+\
                        [codeS,tempIon[0],colN1s+1,colOnes]*(tempIon[2] ==0)+\
                        [codeS,tempIon[0],colN1s+1,colN2s]*(tempIon[2] == self.N[2]+1))+\
                    (self.cellType!='rightend')*(\
                        [codeR,tempIon[0],colOnes,tempIon[2]]*((tempIon[0]>=1) & (tempIon[0]<=self.N[0]) & (tempIon[2]>=1) & (tempIon[2]<=self.N[2]))+\
                        [codeR,colOnes,colZeros,tempIon[2]]*(tempIon[0] ==0)+\
                        [codeR,colN0s,colZeros,tempIon[2]]*(tempIon[0] == self.N[0]+1)+\
                        [codeR,tempIon[0],colZeros,colOnes]*(tempIon[2] ==0)+\
                        [codeR,tempIon[0],colZeros,colN2s]*(tempIon[2] == self.N[2]+1)))+\
                (tempIon[1] == 0)*(\
                    [codeS,tempIon[0],tempIon[1]+1,tempIon[2]]*(self.bndOpen[tempIon[0],tempIon[1],tempIon[2]] ==1)+\
                    [codeI,tempIon[0],tempIon[1],tempIon[2]]*(~(self.bndOpen[tempIon[0],tempIon[1],tempIon[2]] ==1)))+\
                [codeS,tempIon[0],tempIon[1]+1,tempIon[2]]*(~((tempIon[1] == self.N[1]+1) | (tempIon[1] == self.N[1]) | (tempIon[1] == 0)))     

        if  self.cellType == 'leftend':
            #LEFT (LR - 1) for the leftmost cell:
            #  - not allowed from LR = 0;
            #  - from LR = 1, INTRA boxes are blocked, while MARG boxes wrap into the LR = 0 column;
            #  - from LR = N[1]+1, the ion enters only where bndOpen == 1;
            #  - every other box steps to LR - 1.
            LEFT = (tempIon[1] == 0)*(\
                    #include tempIon[1]==tempIon[1] tautology to create an array
                    [codeI,tempIon[0],tempIon[1],tempIon[2]]*(tempIon[1]==tempIon[1]))+\
                (tempIon[1] == 1)*(\
                    [codeI,tempIon[0],tempIon[1],tempIon[2]]*((tempIon[0]>=1) & (tempIon[0]<=self.N[0]) & (tempIon[2]>=1) & (tempIon[2]<=self.N[2]))+\
                    [codeS,colOnes,colZeros,tempIon[2]]*(tempIon[0] ==0)+\
                    [codeS,colN0s,colZeros,tempIon[2]]*(tempIon[0] == self.N[0]+1)+\
                    [codeS,tempIon[0],colZeros,colOnes]*(tempIon[2] ==0)+\
                    [codeS,tempIon[0],colZeros,colN2s]*(tempIon[2] == self.N[2]+1))+\
                (tempIon[1] == self.N[1]+1)*(\
                    [codeS,tempIon[0],tempIon[1]-1,tempIon[2]]*(self.bndOpen[tempIon[0],tempIon[1],tempIon[2]] ==1)+\
                    [codeI,tempIon[0],tempIon[1],tempIon[2]]*(~(self.bndOpen[tempIon[0],tempIon[1],tempIon[2]] ==1)))+\
                [codeS,tempIon[0],tempIon[1]-1,tempIon[2]]*(~((tempIon[1] == 0) | (tempIon[1] == 1) | (tempIon[1] == self.N[1]+1)))  

        #have to separate the cases for when the cell is or isn't 'leftend', since it can't call leftNbhr when it is 'leftend'
        else:
            #LEFT (LR - 1) for cells with a left neighbour:
            #  - from LR = 0, the ion moves into the left neighbour's LR = N[1] (code -2),
            #    but only where leftNbhr.bndOpen at column N[1]+1 is 1;
            #  - from LR = 1, INTRA moves into the left neighbour's LR = N[1], while MARG
            #    boxes wrap into this cell's LR = 0 column;
            #  - from LR = N[1]+1, the ion enters only where bndOpen == 1;
            #  - every other box steps to LR - 1.
            LEFT =  (tempIon[1] == 0)*(\
                        [codeL,tempIon[0],colN1s,tempIon[2]]*(leftNbhr.bndOpen[tempIon[0],colN1s+1,tempIon[2]] == 1)+\
                        [codeI,tempIon[0],tempIon[1],tempIon[2]]*(leftNbhr.bndOpen[tempIon[0],colN1s+1,tempIon[2]] != 1))+\
                    (tempIon[1] == 1)*(\
                        [codeL,tempIon[0],colN1s,tempIon[2]]*((tempIon[0]>=1) & (tempIon[0]<=self.N[0]) & (tempIon[2]>=1) & (tempIon[2]<=self.N[2]))+\
                        [codeS,colOnes,colZeros,tempIon[2]]*(tempIon[0] ==0)+\
                        [codeS,colN0s,colZeros,tempIon[2]]*(tempIon[0] == self.N[0]+1)+\
                        [codeS,tempIon[0],colZeros,colOnes]*(tempIon[2] ==0)+\
                        [codeS,tempIon[0],colZeros,colN2s]*(tempIon[2] == self.N[2]+1))+\
                    (tempIon[1] == self.N[1]+1)*(\
                        [codeS,tempIon[0],tempIon[1]-1,tempIon[2]]*(self.bndOpen[tempIon[0],tempIon[1],tempIon[2]] ==1)+\
                        [codeI,tempIon[0],tempIon[1],tempIon[2]]*(~(self.bndOpen[tempIon[0],tempIon[1],tempIon[2]] ==1)))+\
                    [codeS,tempIon[0],tempIon[1]-1,tempIon[2]]*(~((tempIon[1] == 0) | (tempIon[1] == 1) | (tempIon[1] == self.N[1]+1)))  

        #EXTRA: extra wrap-around moves into neighbouring cells.
        #  - Non-leftend GAP boxes at LR = 0 with DU = 1 / DU = N[0] go to the left
        #    neighbour's MARG row (DU = 0 / N[0]+1) at LR = N[1].
        #  - Non-rightend MARG boxes at LR = N[1] go to the right neighbour at LR = 1.
        #  - Non-leftend MARG boxes at LR = 1 go to the left neighbour at LR = N[1].
        #  - Every other ion gets code -4.
        EXTRA = [codeL,colZeros,colN1s,tempIon[2]]*((tempIon[0] == 1) & (self.cellType != 'leftend') & (tempIon[1] == 0))+\
                [codeL,colN0s+1,colN1s,tempIon[2]]*((tempIon[0] == self.N[0]) & (self.cellType != 'leftend') & (tempIon[1] == 0))+\
                (tempIon[1] == self.N[1])*(\
                    (self.cellType!='rightend')*(\
                        [codeR,colZeros,colOnes,tempIon[2]]*(tempIon[0] ==0)+\
                        [codeR,colN0s+1,colOnes,tempIon[2]]*(tempIon[0] == self.N[0]+1)+\
                        [codeR,tempIon[0],colOnes,colZeros]*(tempIon[2] ==0)+\
                        [codeR,tempIon[0],colOnes,colN2s+1]*(tempIon[2] == self.N[2]+1)))+\
                (tempIon[1] == 1)*(\
                    (self.cellType != 'leftend')*(\
                        [codeL,colZeros,colN1s,tempIon[2]]*(tempIon[0] ==0)+\
                        [codeL,colN0s+1,colN1s,tempIon[2]]*(tempIon[0] == self.N[0]+1)+\
                        [codeL,tempIon[0],colN1s,colZeros]*(tempIon[2] ==0)+\
                        [codeL,tempIon[0],colN1s,colN2s+1]*(tempIon[2] == self.N[2]+1)))+\
                [codeI, tempIon[0],tempIon[1],tempIon[2]]*(~(((tempIon[0] == 1) & (self.cellType != 'leftend') & (tempIon[1] == 0)) | \
                                                              ((tempIon[0] == self.N[0]) & (self.cellType != 'leftend') & (tempIon[1] == 0)) |\
                                                              ((tempIon[1] == self.N[1]) & (self.cellType!='rightend') & \
                                                              ((tempIon[0]==0) | (tempIon[0]==self.N[0]+1) | (tempIon[2]==0) | (tempIon[2] == self.N[2]+1))) |\
                                                              ((tempIon[1] == 1) & (self.cellType != 'leftend') & \
                                                              ((tempIon[0]==0) | (tempIon[0]==self.N[0]+1) | (tempIon[2]==0) | (tempIon[2] == self.N[2]+1))))) 
                        
        #EXTRA2: the out-in counterpart of the first two EXTRA cases.
        #  - Non-leftend GAP boxes at LR = 0 with OI = 1 / OI = N[2] go to the left
        #    neighbour's MARG layer (OI = 0 / N[2]+1) at LR = N[1].
        #  - Every other ion gets code -4.
        EXTRA2 = [codeL,tempIon[0],colN1s,colZeros]*((tempIon[2] == 1) & (self.cellType != 'leftend') & (tempIon[1] == 0))+\
                [codeL,tempIon[0],colN1s,colN2s+1]*((tempIon[2] == self.N[2]) & (self.cellType != 'leftend') & (tempIon[1] == 0))+\
                [codeI, tempIon[0],tempIon[1],tempIon[2]]*(~(((tempIon[2] == 1) & (self.cellType != 'leftend') & (tempIon[1] == 0))\
                                                              | ((tempIon[2] == self.N[2]) & (self.cellType != 'leftend') & (tempIon[1] == 0))))
        
        return np.array(UP), np.array(DOWN), np.array(IN), np.array(OUT), np.array(RIGHT), np.array(LEFT),\
            np.array(EXTRA), np.array(EXTRA2)
                    
    #March 5, 2025: optimize; make it so that the inputs are as follows:
    def computeMoveScore(self, current, destination,lN,rN,colCats,stay):
        
        if self.cellType == 'leftend':
            ionCount = (destination[0,:] == -3)*(rN.ionMat[destination[1,:].astype(int),destination[2,:].astype(int),destination[3,:].astype(int)])+\
                   ((destination[0,:] != -3) & (destination[0,:]!=-2))*(self.ionMat[destination[1,:].astype(int),destination[2,:].astype(int),destination[3,:].astype(int)])
        elif self.cellType == 'rightend':
            #include the np.min for the case where the block is in RL coordinate 14, which in lN doesn't exist since it has only thirteen blocks RL-wise, not 14
             ionCount = (destination[0,:] == -2)*(lN.ionMat[destination[1,:].astype(int),np.fmin(destination[2,:].astype(int),12),destination[3,:].astype(int)])+\
                   ((destination[0,:] != -3) & (destination[0,:]!=-2))*(self.ionMat[destination[1,:].astype(int),destination[2,:].astype(int),destination[3,:].astype(int)])
        else:
            ionCount = (destination[0,:] == -2)*(lN.ionMat[destination[1,:].astype(int),destination[2,:].astype(int),destination[3,:].astype(int)])+\
                   (destination[0,:] == -3)*(rN.ionMat[destination[1,:].astype(int),destination[2,:].astype(int),destination[3,:].astype(int)])+\
                   ((destination[0,:] != -3) & (destination[0,:]!=-2))*(self.ionMat[destination[1,:].astype(int),destination[2,:].astype(int),destination[3,:].astype(int)])
        
        currSquare = colCats[0]*((current[1] == 0) | ((current[1] == self.N[1]+1) & (self.cellType == 'rightend')))+\
                    colCats[1]*((current[0] == 0) | (current [0] == self.N[0]+1) | (current[2] == 0) | (current[2] == self.N[2]+1))+\
                    colCats[2]*(~((current[0] == 0) | (current [0] == self.N[0]+1) | (current[1] == 0)\
                                | ((current[1] == self.N[1]+1) & (self.cellType=='rightend'))\
                                | (current[2] == 0) | (current[2] == self.N[2]+1)))

        destSquare = colCats[3]*((destination[1] == current[0]) & (destination[2]==current[1]) & (destination[3] == current[2]))+\
                    (~((destination[1] == current[0]) & (destination[2]==current[1]) & (destination[3] == current[2])))*(\
                        colCats[0]*((destination[2] == 0)  | ((destination[2] == self.N[1]+1) & (self.cellType == 'rightend')))+\
                        colCats[1]*((destination[1] == 0) | (destination [1] == self.N[0]+1) | (destination[3] == 0) | (destination[3] == self.N[2]+1))+\
                        colCats[2]*(~((destination[1] == 0) | (destination[1] == self.N[0]+1) | (destination[2] == 0) |\
                                  ((destination[2] == self.N[1]+1) & (self.cellType=='rightend'))\
                                | (destination[3] == 0) | (destination[3] == self.N[2]+1)\
                                | ((destination[1] == current[0]) & (destination[2]==current[1]) & (destination[3] == current[2])))))

        #May 12, 2025: implement checking to see if this is a ion channel-moderated
        #jump, and include multiplication by appropriate channel coefficient
        chanCoef = self.GAPchanCoeff*((currSquare == -1) & (destSquare == -3))+\
            self.MARGINchanCoeff*((currSquare == -2) & (destSquare == -3))+\
            self.CellCellchanCoeff*((currSquare == -3) & (destSquare == -3) & (destination[0] !=-1))+\
            +1*(~(((currSquare == -1) & (destSquare == -3)) |\
                    ((currSquare == -2) & (destSquare == -3)) |\
                    ((currSquare == -3) & (destSquare == -3) & (destination[0]!=-1))))
        
        #rewrite the volume of the destination square as a formula
        destVol = (destSquare == -2)*self.MARG_VOL+(destSquare == -1)*self.GAP_VOL+(destSquare == -2)*self.MARG_VOL+\
            (destSquare == -3)*self.INTRA_VOL+\
            (destSquare == -4)*((currSquare == -2)*self.MARG_VOL+(currSquare == -1)*self.GAP_VOL+(currSquare == -2)*self.MARG_VOL+\
            (currSquare == -3)*self.INTRA_VOL)

        csCoef = (currSquare == -2)*self.MARG[destSquare]+\
            (currSquare == -1)*self.GAP[destSquare]+\
            (currSquare == -3)*self.INTRA[destSquare]
        
        return (destination[0]!=-4)*chanCoef*csCoef/(1+np.exp((ionCount/np.sqrt(destVol))-self.ionThresh))
        #return csCoef/(1+np.exp(ionCount-self.ionThresh))

    def computeJump(self,tempIon, lN =0, rN = 0,show = False,timeRun = False):
        #get it so that adjs is a list of dictionaries of available blocks for each ion
        if timeRun:
            tgetavail0 = time.time()
        UP, DOWN, IN, OUT, RIGHT, LEFT, EXTRA, EXTRA2 = self.getAvail(tempIon,leftNbhr = lN,rightNbhr = rN)           
        if timeRun:
            tgetavail1 = time.time()
            print('Time to run getAvail is')
            print(tgetavail1-tgetavail0)
        
        colGAPs = np.array(tempIon.shape[1]*[-1])
        colMARGs = np.array(tempIon.shape[1]*[-2])
        colINTRAs = np.array(tempIon.shape[1]*[-3])
        colSTAYs = np.array(tempIon.shape[1]*[-4])
        colCats = [colGAPs, colMARGs, colINTRAs, colSTAYs]

        moveScores = np.zeros((9,tempIon.shape[1]))
        tempIonDest = np.vstack((np.array(tempIon.shape[1]*[-1]),tempIon))

        if timeRun:
            tcompMove0 = time.time()
        moveScores[0,:] = self.computeMoveScore(tempIon,UP,lN,rN,colCats,stay = False)
        moveScores[1,:] = self.computeMoveScore(tempIon,DOWN,lN,rN,colCats,stay = False)
        moveScores[2,:] = self.computeMoveScore(tempIon,OUT,lN,rN,colCats, stay = False)
        moveScores[3,:] = self.computeMoveScore(tempIon,IN,lN,rN,colCats, stay = False)
        moveScores[4,:] = self.computeMoveScore(tempIon,RIGHT,lN,rN,colCats, stay = False)
        moveScores[5,:] = self.computeMoveScore(tempIon,LEFT,lN,rN,colCats, stay = False)
        moveScores[6,:] = self.computeMoveScore(tempIon,EXTRA,lN,rN,colCats, stay = False)
        moveScores[7,:] = self.computeMoveScore(tempIon,EXTRA2,lN,rN,colCats, stay = False)
        moveScores[8,:] = self.computeMoveScore(tempIon,tempIonDest,lN,rN,colCats,stay = True)
        if timeRun:
            tcompMove1 = time.time()
            print('Time to compute moveScores is')
            print(tcompMove1-tcompMove0)
        
        cumProb = np.cumsum(moveScores/np.sum(moveScores, axis = 0),axis = 0)
        
        x = np.random.uniform(0,1,tempIon.shape[1])
        
        return UP*(x<=cumProb[0])+DOWN*((cumProb[0]<x) & (x<=cumProb[1]))+OUT*((cumProb[1]<x) & (x<=cumProb[2]))+\
            IN*((cumProb[2]<x) & (x<=cumProb[3]))+\
            RIGHT*((cumProb[3]<x) & (x<=cumProb[4]))+LEFT*((cumProb[4]<x)&(x<=cumProb[5]))+EXTRA*((cumProb[5]<x)&(x<=cumProb[6]))+\
            EXTRA2*((cumProb[6]<x)& (x<=cumProb[7]))+tempIonDest*(cumProb[7]<x)

    def newIon(self,timeRun = False):
        """Reset the placeholder arrays that collect every ion's position for the next timestep.

        ``newionDU``, ``newionLR`` and ``newionOI`` become empty float arrays. ``moveIons`` of
        this cell and of its neighbours then append to them, and ``updateIon`` commits them.

        timeRun: if True, print how long the reset took.
        """
        startedAt = time.time() if timeRun else None
        emptyDU, emptyLR, emptyOI = np.array([]), np.array([]), np.array([])
        self.newionDU = emptyDU
        self.newionLR = emptyLR
        self.newionOI = emptyOI
        if not timeRun:
            return
        elapsed = time.time() - startedAt
        print('Time to initalize newion arrays is')
        print(elapsed)
    
    def moveIons(self,LeftN,RightN,show = False,timeIt = False):
        """Decide this timestep's destination for every ion currently in this cell.

        Each ion draws a uniform random number. Ions whose draw is ``> self.moveFilter`` keep
        their position and are appended unchanged to ``self.newionDU/LR/OI``. The remaining
        ions are stacked into ``tempIon`` (rows DU, LR, OI) and handed to ``self.computeJump``,
        which receives only the neighbours that exist for ``self.cellType``. Its result
        ``which`` has one column per moving ion:
          row 0   destination code: -1 = stays in this cell, -2 = goes to ``LeftN``,
                  -3 = goes to ``RightN``
          rows 1-3 the new DU, LR, OI indices (in the destination cell).
        Each ion is appended to the placeholder arrays of the cell it lands in, so
        ``newIon`` must already have been called on every cell this timestep
        (``doUpdateStep`` does that first).

        One count is appended per call to each of these per-timestep counters:
          jumpJunct    - moving ions whose DU and OI indices are intracellular (1..N) and
                         that land in the right neighbour's LR=1 column.
          enterMARG    - ions moving from a margin box (DU index 0 or N[0]+1, or OI index 0
                         or N[2]+1) straight into the adjacent box of this cell (DU/OI index
                         1 or N), other coordinates unchanged, excluding the LR=0 column.
          enterGAPself - ions moving from this cell's LR=0 GAP column into LR=1 of this cell,
                         DU and OI unchanged.
          enterGAPleft - (all cells except the left end) ions moving from this cell's LR=0 GAP
                         column into the left neighbour's last column LR=LeftN.N[1].

        LeftN, RightN: neighbouring cells. ``doUpdateStep`` passes 0 for the missing neighbour
            of an end cell.
        show: forwarded to computeJump.
        timeIt: if True, print how long the filtering step took; also forwarded to computeJump.
        """
        #create array of random numbers of length ionDU (or ionOI, or ionLR, should all be equal)

        if timeIt:
            tfilter0 = time.time()
        
        willMove = np.random.uniform(size = len(self.ionDU))
        #do moveFilter before computeJump, so we aren't unnecessarily computing jumps
        #for like ~90% or more of ions
        #Ions whose draw exceeds moveFilter stay put: copy them straight into the placeholders.
        self.newionDU = np.concatenate((self.newionDU,self.ionDU[willMove>self.moveFilter]))
        self.newionLR = np.concatenate((self.newionLR,self.ionLR[willMove>self.moveFilter]))
        self.newionOI = np.concatenate((self.newionOI,self.ionOI[willMove>self.moveFilter]))
        #The complementary set (draw <= moveFilter) is bundled as a 3 x k array:
        #row 0 = DU, row 1 = LR, row 2 = OI index of each candidate mover.
        tempIon = np.vstack((self.ionDU[willMove<=self.moveFilter],self.ionLR[willMove<=self.moveFilter],
                                 self.ionOI[willMove<=self.moveFilter]))

        if timeIt:
            tfilter1 = time.time()
            print('Time to filter out the ions that will move')
            print(tfilter1-tfilter0)
              
        #End cells only pass the neighbour they actually have.
        if self.cellType == 'leftend':
            which = self.computeJump(tempIon,rN = RightN,show = show,timeRun = timeIt)
        elif self.cellType == 'rightend':
            which = self.computeJump(tempIon,lN=LeftN,show = show,timeRun = timeIt)
        else:
            which = self.computeJump(tempIon,lN = LeftN, rN = RightN,show = show,timeRun=timeIt)

        #this is where we count the number that go right across the junction this timestep
        self.jumpJunct.append(((which[0,:]==-3) & (0<tempIon[0]) & (tempIon[0]<self.N[0]+1)\
                               & (0<tempIon[2]) & (tempIon[2]<self.N[2]+1) & (which[2]==1)).sum())
        #Four terms, one per margin face: DU=0 -> DU=1, DU=N[0]+1 -> DU=N[0],
        #OI=0 -> OI=1, OI=N[2]+1 -> OI=N[2]; in each case the other two coordinates are unchanged.
        self.enterMARG.append(((tempIon[0]==0) & (which[0]==-1) & (which[1] ==1) & (tempIon[1]!=0) & (tempIon[2]!=0) & (tempIon[2]!=self.N[2]+1)\
                              &(tempIon[1] == which[2]) & (tempIon[2] == which[3])).sum()+\
                              ((tempIon[0]==self.N[0]+1) & (which[0]==-1) & (which[1] ==self.N[0]) & (tempIon[1]!=0) & (tempIon[2]!=0) & (tempIon[2]!=self.N[2]+1)\
                              &(tempIon[1] == which[2]) & (tempIon[2] == which[3])).sum()+\
                              ((tempIon[2]==0) & (which[0]==-1) & (which[3] ==1) & (tempIon[1]!=0) & (tempIon[0]!=0) & (tempIon[0]!=self.N[0]+1)\
                              &(tempIon[1] == which[2]) & (tempIon[0] == which[1])).sum()+\
                              ((tempIon[2]==self.N[2]+1) & (which[0]==-1) & (which[3] ==self.N[2]) & (tempIon[1]!=0) & (tempIon[0]!=0) & (tempIon[0]!=self.N[0]+1)\
                              &(tempIon[1] == which[2]) & (tempIon[0] == which[1])).sum())

        #GAP column (LR=0) of this cell -> first intracellular column (LR=1) of this cell.
        self.enterGAPself.append(((tempIon[1] == 0) & (which[0] == -1) & (which[2] == 1) & (tempIon[0]!=0)\
                                  &(tempIon[0]!=self.N[0]+1) & (tempIon[2]!=0) & (tempIon[2]!=self.N[2]+1)\
                                  &(tempIon[0] == which[1]) & (tempIon[2] == which[3])).sum())

        #GAP column (LR=0) of this cell -> last intracellular column of the left neighbour.
        if self.cellType != 'leftend':
            self.enterGAPleft.append(((tempIon[1] == 0) & (which[0] == -2) & (which[2] == LeftN.N[1]) & (tempIon[0]!=0)\
                                     &(tempIon[0]!=self.N[0]+1) & (tempIon[2]!=0) & (tempIon[2]!=self.N[2]+1)\
                                     &(tempIon[0] == which[1]) & (tempIon[2] == which[3])).sum())
        
        
        #Route each mover to the placeholders of the cell it lands in (-1 self, -2 left, -3 right).
        self.newionDU = np.concatenate((self.newionDU,which[1,which[0,:]==-1]))
        self.newionLR = np.concatenate((self.newionLR,which[2,which[0,:]==-1]))
        self.newionOI = np.concatenate((self.newionOI,which[3,which[0,:]==-1]))

        if self.cellType != 'leftend':
            LeftN.newionDU = np.concatenate((LeftN.newionDU,which[1,which[0,:]==-2]))
            LeftN.newionLR = np.concatenate((LeftN.newionLR,which[2,which[0,:]==-2]))
            LeftN.newionOI = np.concatenate((LeftN.newionOI,which[3,which[0,:]==-2]))

        if self.cellType !='rightend':
            RightN.newionDU = np.concatenate((RightN.newionDU,which[1,which[0,:]==-3]))
            RightN.newionLR = np.concatenate((RightN.newionLR,which[2,which[0,:]==-3]))
            RightN.newionOI = np.concatenate((RightN.newionOI,which[3,which[0,:]==-3]))

    #Do this once we have determined the move for each ion across all cells.
    #Update ionDU, ionLR and ionOI from newionDU, newionLR and newionOI.
    def updateIon(self,timeRun = False):
        """Commit the placeholder positions gathered this timestep as the cell's ion positions.

        ``newionDU``/``newionLR``/``newionOI`` are converted to integer numpy arrays and stored
        as ``ionDU``/``ionLR``/``ionOI``. The placeholders start as empty float arrays in
        newIon and are grown with np.concatenate, so the integer cast is needed for indexing.

        timeRun: if True, print how long the conversion took.
        """
        def asIntArray(positions):
            # Copy into a fresh integer array.
            return np.array(positions).astype(int)

        if timeRun:
            startedAt = time.time()
        self.ionDU = asIntArray(self.newionDU)
        self.ionLR = asIntArray(self.newionLR)
        self.ionOI = asIntArray(self.newionOI)
        if timeRun:
            finishedAt = time.time()
            print('Time to update ion arrays is')
            print(finishedAt - startedAt)

    def checkBound(self,rightNbhr = 0):
        """Latch boundary boxes open when the ion difference across them is small enough.

        A boundary box is marked open (its ``self.bndOpen`` entry set to 1) when its ion count
        minus the count of the adjacent intracellular box is ``<=`` the relevant threshold:
          * ``self.GAP_Thresh`` for the GAP columns LR=0 and LR=N[1]+1,
          * ``self.MARG_Thresh`` for the up/down margins (DU=0, DU=N[0]+1) and the
            out/in margins (OI=0, OI=N[2]+1).
        Only interior positions (indices 1..N) along the other two axes are checked. Entries
        are only ever set to 1, never cleared.

        Assumes every cell's ``ionMat`` is already up to date for this timestep.

        rightNbhr: the cell to the right. Its LR=0 column holds the ions of this cell's right
            GAP. For the 'rightend' cell it is unused and this cell's own LR=N[1]+1 column is
            read instead.
        """
        nDU, nLR, nOI = self.N[0], self.N[1], self.N[2]
        ions = self.ionMat
        opened = self.bndOpen
        # Interior (intracellular) index ranges along each axis.
        du = slice(1, nDU + 1)
        lr = slice(1, nLR + 1)
        oi = slice(1, nOI + 1)

        # GAP on the left: LR=0 against the first intracellular column.
        # The basic slices below are views, so masked assignment writes into bndOpen.
        leftGap = opened[du, 0, oi]
        leftGap[ions[du, 0, oi] - ions[du, 1, oi] <= self.GAP_Thresh] = 1

        # GAP on the right: for the right end it is our own LR=N[1]+1 column,
        # otherwise it lives in the right neighbour's LR=0 column.
        if self.cellType == 'rightend':
            rightGapIons = ions[du, nLR + 1, oi]
        else:
            rightGapIons = rightNbhr.ionMat[du, 0, oi]
        rightGap = opened[du, nLR + 1, oi]
        rightGap[rightGapIons - ions[du, nLR, oi] <= self.GAP_Thresh] = 1

        # MARGIN up/down faces.
        upFace = opened[0, lr, oi]
        upFace[ions[0, lr, oi] - ions[1, lr, oi] <= self.MARG_Thresh] = 1
        downFace = opened[nDU + 1, lr, oi]
        downFace[ions[nDU + 1, lr, oi] - ions[nDU, lr, oi] <= self.MARG_Thresh] = 1

        # MARGIN out/in faces.
        outFace = opened[du, lr, 0]
        outFace[ions[du, lr, 0] - ions[du, lr, 1] <= self.MARG_Thresh] = 1
        inFace = opened[du, lr, nOI + 1]
        inFace[ions[du, lr, nOI + 1] - ions[du, lr, nOI] <= self.MARG_Thresh] = 1

    def checkDepol(self):
        """Advance the cell's round counter and latch the round at which it depolarized.

        ``self.counter`` is incremented on every call. If ``self.depolar`` is still 0, the
        cell is not the 'leftend', and ``self.bndOpen`` sums to a positive value, then
        ``self.depolar`` is set to the (new) counter value. Once non-zero, this method never
        changes it again.
        """
        self.counter += 1
        alreadyDepolarized = self.depolar != 0
        if alreadyDepolarized or self.cellType == 'leftend':
            return
        if np.sum(self.bndOpen) > 0:
            self.depolar = self.counter

def doUpdateStep(cellModel,show,timeRun = False):
    """Advance a chain of cells by one timestep.

    cellModel: list of cells, left end first and right end last (at least two cells).
    show: forwarded to every moveIons call.
    timeRun: forwarded to newIon / moveIons / updateIon so they print per-stage timings.

    Each stage finishes for all cells before the next one starts, because moveIons writes
    into the neighbours' placeholder arrays:
      1. newIon      - reset every cell's placeholders
      2. moveIons    - decide moves; end cells get 0 in place of the missing neighbour
      3. updateIon   - commit the placeholders as the new ion positions
      4. getionCount - refresh the ion counts (defined elsewhere; presumably rebuilds ionMat)
      5. checkBound then checkDepol, cell by cell. The right end is checked without a
         neighbour. This stage runs on every call.
    """
    lastIdx = len(cellModel) - 1

    # Step One: fresh placeholders for every cell.
    for cell in cellModel:
        cell.newIon(timeRun = timeRun)

    # Step Two: move ions, storing results in the placeholders of the destination cell.
    cellModel[0].moveIons(0, cellModel[1], show = show, timeIt = timeRun)
    for idx in range(1, lastIdx):
        cellModel[idx].moveIons(cellModel[idx-1], cellModel[idx+1], show = show, timeIt = timeRun)
    cellModel[lastIdx].moveIons(cellModel[lastIdx-1], 0, show = show, timeIt = timeRun)

    # Step Three: commit placeholders.
    for cell in cellModel:
        cell.updateIon(timeRun = timeRun)

    # Step Four: ion counts.
    for cell in cellModel:
        cell.getionCount()

    # Step Five: boundary opening and depolarization, interleaved per cell.
    for idx, cell in enumerate(cellModel):
        if cell.cellType == 'rightend':
            cell.checkBound()
        else:
            cell.checkBound(rightNbhr = cellModel[idx+1])
        cell.checkDepol()
        
def showSpecs(cellNum,cellOI, bndTimeMat,ionCountMat):
    """Print one cell's boundary-opening times and ion counts, one OI layer at a time.

    cellNum: label printed in the header line.
    cellOI: number of intracellular OI layers. Layers 0 .. cellOI+1 (boundaries included)
        are printed.
    bndTimeMat, ionCountMat: 3-D arrays indexed [DU, LR, OI]. Each OI slice is printed
        transposed, so LR runs down the rows and DU across the columns.
    """
    sections = (('Here are the boundary opening times:', bndTimeMat),
                ('Here is the ion count:', ionCountMat))
    print('For Cell Number ' + str(cellNum))
    for heading, matrix in sections:
        print(heading)
        for layer in range(cellOI + 2):
            print(np.transpose(matrix[:, :, layer]))

#for C: 150 approximates 158 micrometers for length, 25 approximates 26.34 micrometers for width and height
#for N: choose values of N[0], N[1], N[2] so that C[i]/N[i] = C[j]/N[j] for any i,j
#for G: margin to the side of each cell and gap between cells
#numCells: number of cells (including end cells) in the model
#ionThresh: the higher this value is, the less ions "push" other ions away
#openThresh: [GAP_Thresh, MARG_Thresh] number needed to open boundary

def createModel(numCells, ionThresh, openThresh, chans = None, gapJunct = 1, C = None, N = None, G = None, moveFilter = 0.2,
               pokeHoleMARG = False):
    """Build a linear chain of Cell objects: one 'leftend', numCells-2 'middle', one 'rightend'.

    Every cell gets the same C, G, N, ionThresh, openThresh, ion channels (chans),
    gapJunct, moveFilter and pokeHoleMARG. Defaults, used when an argument is None:
    chans = [0, 100], C = [25, 150, 25], N = [2, 12, 2], G = [0.133, 0.015, 0.133].

    Note: both end cells are always created, so numCells < 2 still yields a 2-cell model.

    Returns the list of cells, left end first.
    """
    # Build fresh default lists on every call. List literals in the signature would be a
    # single object shared by every call, and Cell.__init__ keeps C, G and N by reference,
    # so an in-place change to one model's settings would leak into every later model.
    if chans is None:
        chans = [0, 100]
    if C is None:
        C = [25, 150, 25]
    if N is None:
        N = [2, 12, 2]
    if G is None:
        G = [0.133, 0.015, 0.133]

    def makeCell(cellType):
        """Construct one cell of the given type with this model's shared settings."""
        return Cell(C, G, N, ionThresh, openThresh, cellType, chans, gapJunct = gapJunct,
                    moveFilter = moveFilter, pokeHoleMARG = pokeHoleMARG)

    return ([makeCell('leftend')]
            + [makeCell('middle') for _ in range(numCells-2)]
            + [makeCell('rightend')])

#theModel: model of cells
#steps: number of times to run update step
#averWindow: over how many update steps we should include when averaging the number of ions per box, etc.
#showIons: 0 if you don't want function to show anything, 1 if you just want the step number, 2 if you 
#          want the cell maxes and cell mins, and 3 if you want to see the ion matrices and boundary matrices
#showWhen: At what step multiple to do showIons, as shown above
#extend: how many additional rounds to run the model after everything has depolarized
def runModel(theModel, averWindow, showIons = 1, showWhen = 50, N = None,extend = 0):
    """Run ``theModel`` until every cell after the left end has depolarized, then ``extend`` more rounds.

    theModel: list of cells, left end first (as built by createModel).
    averWindow: how many of the most recent rounds are averaged for the max/min printout.
    showIons: 0 = print nothing per round, 1 = round number, 2 = also each cell's
        running-average max/min intracellular ion count, 3 = also the boundary-time and
        ion matrices of cells 0 and 2.
    showWhen: print every ``showWhen``-th round.
    N: box counts [DU, LR, OI] used for the printout and the boundary-time scan. Defaults to
        [2, 12, 2] and should match the cells' own N.
    extend: number of additional rounds after the last depolarization.

    Returns (bndTime, BeginFill, AllFill):
      bndTime   - per cell, a copy of its initial bndOpen in which each boundary box that
                  opens later is stamped with the round at which it was first seen open.
      BeginFill - per cell, the round at which any intracellular box first held ions.
      AllFill   - per cell, the round at which every intracellular box first held ions.
    0 doubles as "not yet recorded", so an event already true in round 0 is re-stamped with
    the next round in which it still holds.

    NOTE(review): loops forever if some cell after the left end never depolarizes.
    """
    if N is None:
        N = [2, 12, 2]
    # Ring buffer of the last ``averWindow`` concatenated ion-count matrices.
    pastIonMats = [0 for _ in range(averWindow)]
    # Boundary opening times, starting from a copy of each cell's current bndOpen.
    bndTime = [cellie.bndOpen.copy() for cellie in theModel]
    BeginFill = np.zeros(len(theModel))
    AllFill = np.zeros(len(theModel))

    i = 0
    depolArray = [cellie.depolar for cellie in theModel]
    # The left end never depolarizes (see checkDepol), so it is excluded from the stop test.
    while 0 in depolArray[1:]:
        _runModelRound(theModel, i, averWindow, showIons, showWhen, N,
                       pastIonMats, bndTime, BeginFill, AllFill)
        i += 1
        depolArray = [cellie.depolar for cellie in theModel]
    for _ in range(extend):
        _runModelRound(theModel, i, averWindow, showIons, showWhen, N,
                       pastIonMats, bndTime, BeginFill, AllFill)
        i += 1
    print('This is the total number of rounds' +str(i))
    return bndTime, BeginFill, AllFill


def _runModelRound(theModel, i, averWindow, showIons, showWhen, N,
                   pastIonMats, bndTime, BeginFill, AllFill):
    """Advance the model one step and record round ``i`` in the bookkeeping passed in (mutated in place)."""
    doUpdateStep(theModel,False)

    # Keep the most recent ``averWindow`` ion-count matrices for the running average.
    fullIonMat = np.concatenate([cellie.ionMat for cellie in theModel],axis = 1)
    pastIonMats[i%averWindow] = fullIonMat

    # First round with any ions / with ions in every intracellular box (indices 1..N).
    for idx, cell in enumerate(theModel):
        intra = cell.ionMat[1:cell.N[0]+1, 1:cell.N[1]+1, 1:cell.N[2]+1]
        if BeginFill[idx] == 0.0 and np.any(intra > 0):
            BeginFill[idx] = i
        if AllFill[idx] == 0.0 and not np.any(intra == 0):
            AllFill[idx] = i

    if i%showWhen == 0 and showIons > 0:
        print('Round '+str(i))
        if showIons > 1:
            runAvg = sum(pastIonMats)/averWindow
            # NOTE(review): assumes cell ``how`` occupies columns starting at how*(N[1]+1) of
            # the concatenated matrix, i.e. each ionMat is N[1]+1 wide along LR -- confirm
            # against how ionMat is built.
            stride = N[1]+1
            cellBlocks = [runAvg[1:N[0]+1, how*stride+1:how*stride+N[1]+1, 1:N[2]+1]
                          for how in range(len(theModel))]
            cellMaxes = [block.max() for block in cellBlocks]
            cellMins = [block.min() for block in cellBlocks]
            print('Max Number of Ions in a Square for each Cell: '+str(cellMaxes))
            print('Min Number of Ions in a Square for each Cell: '+str(cellMins))
        if showIons > 2:
            showSpecs(0,N[2], bndTimeMat = bndTime[0],ionCountMat = theModel[0].ionMat)
            showSpecs(2,N[2], bndTimeMat = bndTime[2],ionCountMat = theModel[2].ionMat)
        if showIons > 1:
            print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')

    # Stamp boundary boxes that are open now but have no recorded time yet.
    window = (slice(0, N[0]+2), slice(0, N[1]+2), slice(0, N[2]+2))
    for cell, cellTime in zip(theModel, bndTime):
        newlyOpen = (cell.bndOpen[window] != 0) & (cellTime[window] == 0)
        cellTime[window][newlyOpen] = i

In [3]:
def runsForFixedIonChans(chans, runs = 5):
    """Sweep gap-junction coefficients for one ion-channel configuration and save the plots.

    For each coefficient in ``gjs`` the 5-cell model is built and run ``runs`` times. The
    following per-timestep counts are summed over the runs and then averaged:
      * ions crossing each gap junction (``jumpJunct``),
      * ions entering each cell through a GAP column (its own ``enterGAPself`` plus the
        right neighbour's ``enterGAPleft``),
      * ions entering through the MARG (``enterMARG``).
    These counts are plotted against time, with the averaged depolarization times drawn as
    vertical lines. A conduction velocity (cells per timestep) is estimated as the inverse
    slope of a linear fit of the middle cells' depolarization times, and is plotted against
    the coefficients at the end.

    chans: [ion channels in a GAP square, ion channels in a MARG square]. A GAP value of 0
        turns on pokeHoleMARG.
    runs: number of independent runs averaged per coefficient (must be >= 1).

    Figures are written under ``filtDir + figDir``, created if missing. They are not closed,
    so they still render inline.
    NOTE(review): one call opens 27 figures, which triggers matplotlib's "more than 20
    figures" warning.
    """
    if runs < 1:
        raise ValueError('runs must be at least 1')

    ionThresh = 0
    openThresh = [15,2,15]
    numCells = 5
    averWindow = 5

    n = [2,12,2]
    # Number of MARG / GAP boundary squares of one cell, used for the per-section plots.
    MARGregions = 2*n[0]*n[1]+2*n[2]*n[1]
    GAPregions = 2*n[0]*n[2]

    theFilter = 0.01
    filtDir = 'IonsIn/Model10e6ActualIons/Filter0pt01/'
    cellColors = ['b','g','r','c','m']

    # With no GAP ion channels, poke a hole in the MARG instead (forwarded to Cell).
    pHMARG = bool(chans[0] == 0)

    gjs = [0.00001,0.0001,0.001,0.01,0.1]
    condVels = []

    figDir = 'GAPIon_'+str(chans[0])+'MARGIon'+str(chans[1])
    outDir = filtDir+figDir
    # makedirs also creates filtDir itself if it is missing (os.mkdir would raise).
    os.makedirs(outDir, exist_ok = True)

    chanText = (';\n Ion Channels in GAP square is '+str(chans[0])+
                ';\n Ion Channels in MARG square is '+str(chans[1]))
    avgText = ';\n Average over '+str(runs)+' Runs'

    for gj in gjs:
        # Per-cell running sums of per-timestep counts; they start at length 10 and grow as needed.
        thruGJ = [np.zeros(10) for _ in range(numCells)]
        thruGAP = [np.zeros(10) for _ in range(numCells)]
        thruMARG = [np.zeros(10) for _ in range(numCells)]
        depol = np.zeros(numCells, dtype = int)

        for _ in range(runs):
            myModel = createModel(numCells, ionThresh, openThresh, gapJunct = gj, chans = chans,N = n,moveFilter = theFilter,pokeHoleMARG = pHMARG)
            runModel(myModel, averWindow, N = n,showIons = 0,extend = 50)

            depol = depol + np.array([cellie.depolar for cellie in myModel])

            for cellie in range(numCells):
                if cellie < len(myModel)-1:
                    thruGJ[cellie] = _padAdd(thruGJ[cellie], np.array(myModel[cellie].jumpJunct))
                    # Ions can reach this cell from its own LR=0 GAP column or from the
                    # right neighbour's LR=0 GAP column (counted there as enterGAPleft).
                    curreG = np.array(myModel[cellie].enterGAPself)+np.array(myModel[cellie+1].enterGAPleft)
                else:
                    curreG = np.array(myModel[cellie].enterGAPself)
                thruGAP[cellie] = _padAdd(thruGAP[cellie], curreG)
                thruMARG[cellie] = _padAdd(thruMARG[cellie], np.array(myModel[cellie].enterMARG))

        depol = depol/runs
        # Fit mean depolarization time against the 1-based cell number of the middle cells.
        # The slope is timesteps per cell, so its inverse is cells per timestep.
        # NOTE(review): the slope is rounded to 3 decimals before it is inverted.
        fitcoeffs = np.round(np.polyfit(range(2,len(depol)), depol[1:-1], 1),3)
        condVels.append(1/fitcoeffs[0])

        thruGJ = [ionNums/runs for ionNums in thruGJ]
        thruGAP = [ionNums/runs for ionNums in thruGAP]
        thruMARG = [ionNums/runs for ionNums in thruMARG]

        gjText = '\n When Gap Junction Coeff. is '+str(gj)+chanText
        prefix = outDir+'/GJ'+str(gj)
        nCells = len(myModel)
        nJunct = len(thruGJ)-1
        cellLabels = ['Entered Cell '+str(i+1) for i in range(nCells)]
        cellCols = cellColors[:nCells]

        # thruGJ[i] counts ions crossing from cell i+1 into cell i+2 (1-based).
        _plotIonTraces(thruGJ[:nJunct], ['Crossed into Cell '+str(i+2) for i in range(nJunct)],
                       cellColors[1:nJunct+1], depol, 50,
                       'Ions Entering Through Gap Junctions'+gjText+
                       ';\n Count Entering from Left Side Only\n Average over '+str(runs)+' Runs',
                       'Ions', prefix+'_ThruGJ.png')
        _plotIonTraces(thruMARG[:nCells], cellLabels, cellCols, depol, 50,
                       'Total Ions Entering Through MARG'+gjText+avgText,
                       'Total Ions', prefix+'_ThruMARGTot.png')
        _plotIonTraces(thruGAP[:nCells], cellLabels, cellCols, depol, 50,
                       'Total Ions Entering Through GAP'+gjText+avgText,
                       'Total Ions', prefix+'_ThruGAPTot.png')
        _plotIonTraces([trace/MARGregions for trace in thruMARG[:nCells]], cellLabels, cellCols, depol, 1,
                       'Average Per Section Ions Entering Through MARG'+gjText+avgText,
                       'Average Ions per MARG Section', prefix+'_ThruMARGPerSect.png')
        _plotIonTraces([trace/GAPregions for trace in thruGAP[:nCells]], cellLabels, cellCols, depol, 1,
                       'Average Per Section Ions Entering Through GAP'+gjText+avgText,
                       'Average Ions per GAP Section', prefix+'_ThruGAPPersect.png')

    velText = chanText+';\n Computed from Averaged Depolarization Times \n Averaged Over '+str(runs) +' Runs'
    velocityPlots = (
        (np.log10(gjs), 'Log_10 of Gap Junction Coefficient',
         'Conduction Velocity based on Gap Junction Coefficients (Log Scale)', '/GJConductVels_Log.png'),
        (gjs, 'Gap Junction Coefficient',
         'Conduction Velocity based on Gap Junction Coefficients', '/GJConductVels.png'),
    )
    for xs, xlabel, titleHead, fname in velocityPlots:
        fig = plt.figure()
        ax = fig.add_axes([0,0,1,1])
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Average Conduction Velocity (Cells per Timestep)')
        ax.set_title(titleHead+velText)
        ax.scatter(xs, condVels)
        fig.savefig(outDir+fname,bbox_inches='tight')


def _padAdd(total, counts):
    """Return total + counts for 1-D arrays, zero-padding the shorter one to the longer length."""
    summed = np.zeros(max(total.shape[0], counts.shape[0]))
    summed[:total.shape[0]] += total
    summed[:counts.shape[0]] += counts
    return summed


def _plotIonTraces(traces, labels, colors, depol, vlineTop, title, ylabel, path):
    """Plot each trace against its timestep index, with black lines at ``depol``, then save to ``path``.

    The figure is returned and left open so it still displays inline in the notebook.
    """
    fig = plt.figure()
    ax = fig.add_axes([0,0,1,1])
    ax.vlines(depol,0,vlineTop,color = 'k')
    for trace, label, color in zip(traces, labels, colors):
        ax.plot(range(len(trace)), trace, label = label, c = color)
    ax.set_title(title)
    ax.set_xlabel('Timesteps')
    ax.set_ylabel(ylabel)
    ax.legend()
    fig.savefig(path,bbox_inches='tight')
    return fig

In [4]:
# Time the ion-channel sweeps. Uncomment the configuration(s) to run;
# each argument is [ion channels in a GAP square, ion channels in a MARG square].
t0 = time.time()
#runsForFixedIonChans([24,0])
#runsForFixedIonChans([0,2])
#runsForFixedIonChans([12,1])
#runsForFixedIonChans([18,0.5])
#runsForFixedIonChans([23.625,0.03125])
t1 = time.time()
elapsedSeconds = t1 - t0
print('Total time taken was ' + str(elapsedSeconds) + ' seconds')

Total time taken was 0.0 seconds
