In [1]:
import numpy as np 
import cv2 as cv  

import torch                        
import torch.nn.functional as F
import torch.nn as nn

## Create Gabor To Hypercolumns ONNX Model

In [7]:
def get_gabor_weights(kernel_size, K_dim):
    """
    Build a bank of Gabor kernels usable as 2D convolution weights that map a
    grayscale image to K_dim orientation channels (hypercolumns).

    One kernel is generated per orientation; the K_dim orientations are evenly
    spaced, starting at -pi/2 with a step of pi/K_dim.

    Args:
        kernel_size: side length of each (square) kernel.
        K_dim: number of orientations, i.e. number of output channels.

    Returns:
        float32 torch tensor with dimension
        [K_dim (hypercolumns), 1 (grayscale), kernel_size, kernel_size]
    """
    envelope_sigma = np.sqrt(kernel_size)  # standard deviation of the Gaussian envelope
    wavelength = kernel_size/2             # wavelength of the sinusoidal component
    aspect_ratio = 0.7                     # spatial aspect ratio (ellipticity of the Gabor support)

    def kernel_for(index):
        # Orientation of the normal to the parallel stripes of the Gabor function.
        orientation = (np.pi*index/K_dim) - (np.pi/2)
        return cv.getGaborKernel((kernel_size, kernel_size), envelope_sigma, orientation, wavelength, aspect_ratio)

    bank = np.asarray([kernel_for(index) for index in range(K_dim)])
    # Insert the single input-channel axis expected by conv2d weights.
    return torch.from_numpy(bank.astype('float32')).unsqueeze(1)

In [8]:
def get_phi_weights(K=12):
    """
    Return the phi() orientation-coupling function as vectorized 1x1 CNN weights.

    For every pair of orientations (theta_k, theta_k_r) the weight is
    exp(-|delta| / (pi/8)), where delta = theta_k - theta_k_r wrapped into
    [-pi/2, pi/2]. The K orientations start at -pi/2 with a step of pi/K.

    Args:
        K: number of orientations (hypercolumn dimension).

    Returns:
        float32 torch tensor with dimension [K, K, 1, 1]
        (Hypercolumn dim, Hypercolumn dim, 1, 1).
    """
    orientations_K = np.asarray([(np.pi*k/K) - (np.pi/2) for k in range(K)], dtype=np.float64)
    pi_8 = np.pi/8

    # delta[k, k_r] = theta_k - theta_k_r, computed for all pairs at once.
    delta = orientations_K[:, None] - orientations_K[None, :]

    # Wrap differences that fall outside [-pi/2, pi/2] back into that range.
    delta = np.where(delta < -np.pi/2,
                     delta + np.pi,
                     np.where(delta > np.pi/2, delta - np.pi, delta))

    weights = np.exp(-np.abs(delta)/pi_8)

    # Append two singleton axes so the matrix can be used as a 1x1 conv2d kernel.
    return torch.from_numpy(weights.astype('float32')).unsqueeze(2).unsqueeze(3)

In [18]:
class GaborToHypercolumns(nn.Module):
    """
    Convert a grayscale image into K orientation channels (hypercolumns).

    The image is filtered with a bank of K Gabor kernels (stride 5, no
    padding), rectified, normalized by its global maximum, mixed across
    orientations with the phi() 1x1 weights and finally scaled by c_mul.

    Args:
        kernel_size: side length of the Gabor kernels.
        K: number of orientations.
        c_mul: scalar multiplier applied to the final output.
        device: device on which the fixed weights are placed.
    """
    def __init__(self, kernel_size=21, K=12, c_mul = 3.01, device='cpu'):
        super().__init__()
        self.kernel_size = kernel_size
        self.c_mul = c_mul

        # Register the fixed filters as non-persistent buffers: they follow
        # module.to(device) but do not add entries to state_dict().
        self.register_buffer('gabor_weights',
                             get_gabor_weights(kernel_size, K).to(device),
                             persistent=False)
        self.register_buffer('phi_weights',
                             get_phi_weights(K).to(device),
                             persistent=False)


    def forward(self, input_grayscale):
        """
        input_grayscale : [H, W], e.g. [300, 500]
        output : [K, H', W'], e.g. [12, 56, 96] for a [300, 500] input with
                 the default kernel_size=21 and stride 5
        """
        input_grayscale = input_grayscale.unsqueeze(0).unsqueeze(0) #[H, W] -> [1,1,H,W]
        out = F.conv2d(input_grayscale, self.gabor_weights, stride=5)

        # Normalize gabor result by removing negative values and normalizing [0, out.max()] -> [0, 1]
        out = F.relu(out)
        # Guard the denominator: a blank image gives out.max() == 0, and 0/0
        # would otherwise fill the whole output with NaN.
        out = out / torch.clamp(out.max(), min=1e-12)

        # Phi
        out = F.conv2d(out, self.phi_weights, stride=1)
        out = out * self.c_mul

        return out.squeeze(0)

In [12]:
# Instantiate the Gabor-to-hypercolumns module with its default arguments
# (kernel_size=21, K=12, c_mul=3.01).
GTHModel = GaborToHypercolumns()

In [13]:
# Example [H, W] grayscale input used only to trace the model for export.
dummy_input = torch.zeros([300, 500])

# Trace GTHModel and write it to "GTHModel.onnx" (opset 11). verbose=True
# prints the traced graph; the graph input is named 'grayscale' and the
# output 'hypercolumns'.
torch.onnx.export(GTHModel, 
                  dummy_input, 
                  "GTHModel.onnx", 
                  verbose=True, 
                  export_params=True,
                  input_names=['grayscale'], 
                  output_names=['hypercolumns'],
                  opset_version=11)

graph(%grayscale : Float(300, 500, strides=[500, 1], requires_grad=0, device=cpu)):
  %1 : Float(1, 300, 500, strides=[150000, 500, 1], requires_grad=0, device=cpu) = onnx::Unsqueeze[axes=[0]](%grayscale) # <ipython-input-11-15bb3ef5337f>:15:0
  %2 : Float(1, 1, 300, 500, strides=[150000, 150000, 500, 1], requires_grad=0, device=cpu) = onnx::Unsqueeze[axes=[0]](%1) # <ipython-input-11-15bb3ef5337f>:15:0
  %3 : Float(12, 1, 21, 21, strides=[441, 441, 21, 1], requires_grad=0, device=cpu) = onnx::Constant[value=<Tensor>]() # <ipython-input-11-15bb3ef5337f>:16:0
  %4 : Float(1, 12, 56, 96, strides=[64512, 5376, 96, 1], requires_grad=0, device=cpu) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[21, 21], pads=[0, 0, 0, 0], strides=[5, 5]](%2, %3) # <ipython-input-11-15bb3ef5337f>:16:0
  %5 : Float(1, 12, 56, 96, strides=[64512, 5376, 96, 1], requires_grad=0, device=cpu) = onnx::Relu(%4) # c:\Users\chrys\AppData\Local\Programs\Python\Python37\lib\site-packages\torch\nn\functional.py:12

## Create V1SH Initial-State ONNX Model

In [19]:
def pGainFunction(input_raw, T_x = 1):
    """
    Piecewise-linear gain function with threshold T_x and saturation at 1.

        input <  T_x            -> 0
        T_x <= input <= T_x + 1 -> input - T_x
        input >  T_x + 1        -> 1

    The three cases are evaluated against the original input in a single
    clamp, so the result is correct for any T_x (sequentially masking an
    in-place copy re-classifies already-mapped values when T_x != 1).

    Args:
        input_raw: tensor of cell states; it is not modified.
        T_x: activation threshold.

    Returns:
        New tensor with the same shape as input_raw and values in [0, 1].
    """
    return torch.clamp(input_raw - T_x, min=0, max=1)

def mGainFunction(input_raw, L_y=1.2, g_1=0.21, g_2=2.5):
    """
    Two-slope piecewise-linear gain function.

        input <  0          -> 0
        0 <= input <= L_y   -> g_1 * input
        input >  L_y        -> g_1 * L_y + g_2 * (input - L_y)

    Every branch is selected from the unmodified input, so values already
    scaled by g_1 can never be re-classified into the upper branch (which
    happened with in-place masking whenever g_1 * input >= L_y).

    Args:
        input_raw: floating-point tensor of cell states; it is not modified.
        L_y: knee position where the slope changes from g_1 to g_2.
        g_1: slope below the knee.
        g_2: slope above the knee.

    Returns:
        New tensor with the same shape as input_raw.
    """
    below_knee = input_raw * g_1
    above_knee = g_1*L_y + g_2*(input_raw - L_y)

    non_negative = torch.where(input_raw <= L_y, below_knee, above_knee)
    return torch.where(input_raw < 0, torch.zeros_like(non_negative), non_negative)

In [20]:
class V1SHNetInit(nn.Module):
    """
    Produce constant initial states for the two cell populations.

    forward() returns two [K, h, w] tensors: Sp filled with a value derived
    from the background inputs and coupling constants, and Sm filled with
    backIm / decay.
    """
    def __init__(self, h, w, K=12, device='cpu'):
        super().__init__()
        # Geometry / placement
        self.device = device
        self.K = K
        self.h = h
        self.w = w

        # Coupling constants
        self.selfmp = -1.0
        self.fw1 = 0.8
        self.fw2 = 0.7

        # Background inputs and decay
        self.backIm = 1.0
        self.backIp = 0.85
        self.decay = 1.0

        # Initial Sp value, kept as a 0-dim tensor on `device`.
        resting_m = torch.tensor(self.backIm/ self.decay, device=device)
        neighbour_gain = 1.0+self.fw1*2.0 + self.fw2*2.0
        self.init_x = (self.backIp + self.selfmp*resting_m*neighbour_gain)/self.decay

    def forward(self, blank=0):
        # `blank` is not used by the computation.
        state_shape = [self.K, self.h, self.w]
        Sp = torch.full(state_shape, self.init_x, device=self.device)
        Sm = torch.full(state_shape, self.backIm/ self.decay, device=self.device)
        return Sp, Sm

        

In [21]:
# Initial-state generator for an h=56, w=96 map with the default K=12.
V1SHNetInitModel = V1SHNetInit(56, 96)

In [23]:
# Export the initial-state generator to "V1SHNetInitModel.onnx" (opset 11).
# The positional `0` is a placeholder for forward()'s `blank` argument, which
# the computation never reads; the two outputs are named 'ExcitatoryCell'
# (Sp) and 'InhibitoryCell' (Sm).
torch.onnx.export(V1SHNetInitModel, 
                  0,
                  f="V1SHNetInitModel.onnx", 
                  verbose=True, 
                  export_params=True,
                  input_names=['dummy'], 
                  output_names=['ExcitatoryCell', 'InhibitoryCell'],
                  opset_version=11)

graph():
  %ExcitatoryCell : Float(12, 56, 96, strides=[5376, 96, 1], requires_grad=0, device=cpu) = onnx::Constant[value=<Tensor>]()
  %InhibitoryCell : Float(12, 56, 96, strides=[5376, 96, 1], requires_grad=0, device=cpu) = onnx::Constant[value=<Tensor>]()
  return (%ExcitatoryCell, %InhibitoryCell)



## Create V1 Saliency Hypothesis ONNX Model

In [26]:
def synaptic_connection_J(d, theta1, theta2, angle, curvature):
    """
    Weight of the excitatory (J) connection between two oriented elements.

    The connection is non-zero when `angle` is below pi/2.69, or when it is
    below pi/1.1 and both |theta1| and |theta2| are below pi/5.9. Otherwise
    0 is returned.

    Args:
        d: distance between the two elements.
        theta1, theta2: orientations of the two elements (relative to the
            line connecting them, as supplied by the caller).
        angle: combined angle measure supplied by the caller.
        curvature: curvature measure supplied by the caller.
    """
    narrow_cone = angle < np.pi/2.69
    wide_cone = (
        angle < np.pi/1.10000
        and np.abs(theta1) < np.pi/5.9
        and np.abs(theta2) < np.pi/5.9
    )
    if not (narrow_cone or wide_cone):
        return 0

    # Distance-dependent term that enters the exponent below.
    distance_term = np.exp(-np.power(d, 2.0)/90.0)/2.7
    return 0.126*np.exp(-(curvature)**2-2*(curvature)**7-distance_term)

def inhibitory_connection_W(d, theta1, theta2, angle, curvature, delta_K, selfpm=1.0):
    """
    Weight of the inhibitory (W) connection between two oriented elements.

    A non-zero weight is produced only when `angle` >= pi/1.1, both |theta1|
    and |theta2| exceed pi/11.999, and the orientation difference (folded
    into [0, pi/2]) is below pi/3. Every other case returns 0, including the
    case where the first two conditions hold but the orientation difference
    is too large (which previously fell through and returned None).

    Args:
        d: distance between the two elements.
        theta1, theta2: orientations of the two elements (relative to the
            line connecting them, as supplied by the caller).
        angle: combined angle measure supplied by the caller.
        curvature: curvature measure supplied by the caller.
        delta_K: absolute difference between the two raw orientations.
        selfpm: overall scale factor applied to the weight.
    """
    def cpmdistance(distance):
        # Hard distance cutoff. A commented-out alternative in the original
        # code used exp(-distance/30.0)/1.25 below the cutoff instead.
        if(distance<9.999):
            return 1.0/1.75
        return 0.0

    if (angle >= np.pi/1.1) and (np.abs(theta1) > np.pi/11.999 and np.abs(theta2) > np.pi/11.999):
        # Fold the orientation difference into [0, pi/2].
        x = delta_K
        if(x > np.pi/2.0):
            x = np.pi-x

        if(x < np.pi/3.0):
            res = selfpm * np.exp(-np.power(x/(np.pi/4.0), 1.5)) * (1.0-1.00*np.exp(-0.4*np.power(curvature, 1.5)))
            res = res * cpmdistance(d/np.cos(angle/4.0))
            return res

    # No inhibitory connection for this configuration.
    return 0
        
def get_exc_inh_weights(d=10, K=12, selfpm=1.0):
    """
    Return the weights of both the excitatory (J) and inhibitory (W) kernels.

    For every spatial offset inside a disc of radius d (centre excluded) and
    every orientation pair (k1, k2), the connection weights are evaluated
    with synaptic_connection_J and inhibitory_connection_W.

    Args:
        d: radius of the connection neighbourhood, in grid cells.
        K: number of orientations.
        selfpm: scale factor forwarded to inhibitory_connection_W.

    Returns:
        (kernel_J, kernel_W): two numpy arrays of dimension
        [K, K, 2*d + 1, 2*d + 1], indexed as [k1, k2, row (j), column (i)].
    """
    def get_connect_angle(P1, P2):
        x1, y1 = P1
        x2, y2 = P2

        x = (x2-x1)
        y = -(y2-y1) # y-axis flipped on computer array

        # Vertical connection: avoid dividing by zero.
        if x == 0:
            return np.pi/2
        return np.arctan(y/x)

    def range_check(k):
        # Wrap an orientation difference back into [-pi/2, pi/2].
        if k < -np.pi/2:
            return k + np.pi
        if k > np.pi/2:
            return k - np.pi
        return k

    diameter = d*2 + 1
    kernel_J = np.zeros([K,K,diameter,diameter])
    kernel_W = np.zeros([K,K,diameter,diameter])

    orientations_K = [(np.pi*k/K) - (np.pi/2) for k in range(K)]

    mid_x = (diameter-1)/2
    mid_y = (diameter-1)/2

    for i in range(diameter):
        for j in range(diameter):

            # Get the euclidian distance
            distance = np.sqrt( (i-mid_x)**2 + (j-mid_y)**2 )
            if not (0<distance<=d):
                continue

            # Both the connection angle and the orientations relative to it
            # depend only on the offset (i, j), so compute them once here
            # instead of once per (k1, k2) pair.
            connect_angle = get_connect_angle((mid_x, mid_y), (i, j))
            relative_thetas = [range_check(orientation-connect_angle)
                               for orientation in orientations_K]

            for k1 in range(K):
                for k2 in range(K):
                    theta1 = relative_thetas[k1]
                    theta2 = relative_thetas[k2]

                    # Order the pair so that |theta1| <= |theta2|.
                    if np.abs(theta1) > np.abs(theta2):
                        theta1, theta2 = theta2, theta1

                    angle = 2*np.abs(theta1) + 2*np.sin(np.abs(theta1+theta2))
                    curvature = angle/distance

                    delta_K = np.abs(orientations_K[k1] - orientations_K[k2])

                    # (j,i) following image (h,w) convention
                    kernel_J[k1, k2, j, i] = synaptic_connection_J(distance, theta1, theta2, angle, curvature)
                    kernel_W[k1, k2, j, i] = inhibitory_connection_W(distance, theta1, theta2, angle, curvature, delta_K, selfpm)

    return kernel_J, kernel_W # [K, K, diameter, diameter]

In [27]:
def I_norm_conv(cell_out, norm_width=2, norm_force=0.5, epsilon=0.01): 
    """
    Apply a local normalization term to a [K, h, w] cell-state tensor.

    For every position, the states of all K channels inside a
    (2*norm_width + 1)-wide square window (circularly padded) are summed,
    divided by 4*norm_width**2, squared, scaled by norm_force*epsilon and
    subtracted from the input; the result is rectified with ReLU.

    Args:
        cell_out: tensor of dimension [K, h, w].
        norm_width: half-width of the square summation window.
        norm_force: strength of the normalization.
        epsilon: step-size factor applied to the normalization term.

    Returns:
        Tensor of dimension [K, h, w] on the same device and with the same
        dtype as cell_out.
    """
    K, _, _ = cell_out.shape
    cell_out = cell_out.unsqueeze(0)  # [1, K1, h, w]

    window = norm_width*2+1
    # Build the all-ones kernel with the input's dtype/device so the
    # convolution also works for non-CPU or non-float32 inputs.
    weight = torch.ones([K, K, window, window], dtype=cell_out.dtype, device=cell_out.device)

    padded_input = F.pad(cell_out, 
                        (norm_width,norm_width,norm_width,norm_width), 
                        mode='circular') #[1,K,h+norm_width*2,w+norm_width*2]
    
    output = F.conv2d(padded_input, weight)   # [1, K1, h+pad, w+pad] -> [1, K2, h, w]

    # Do norm
    output = output / (4.0*norm_width*norm_width) # Element wise division
    output = 1.0 * output * output
    output = -1 * norm_force * epsilon * output

    # Apply norm
    out = F.relu(cell_out + output).squeeze(0)

    return out

In [28]:
class V1SHNetComputeMachine(nn.Module):
    """
    One update step of the excitatory (Sp) / inhibitory (Sm) cell states.

    forward() takes the hypercolumn input I and the current states Sp and Sm
    and returns the states advanced by one step of size epsilon, their gain
    function outputs and the mean of each of those four tensors.

    Args:
        h, w: spatial size of the state maps.
        K: number of orientations.
        d: radius of the lateral connection kernels.
        epsilon: step size of the state update.
        device: device on which the connection kernels are placed.
    """
    def __init__(self, h, w, K=12, d=10, epsilon=0.05, device='cpu'):
        super().__init__()
        # Params
        self.device = device
        self.d = d
        self.K = K
        self.h = h
        self.w = w

        self.epsilon = epsilon
        self.selfpm = 1.0
        self.selfmp = -1.0
        self.selfpp = 0.8

        # Inhibition strength from orientation neighbours at +-1 and +-2.
        self.fw1 = 0.8
        self.fw2 = 0.7
        
        self.backIm = 1.0
        self.backIp = 0.85
        self.decay = 1.0

        self.angles = [(np.pi*k/K) - (np.pi/2) for k in range(K)]

        # Normalization constant shared by both connection kernels.
        self.angle_weight = 0.0
        m = -K/12
        while(m<=K/12):
            self.angle_weight  += np.exp(-np.abs(m)/(K)*8.0)
            m+=1

        J_weights, W_weights = get_exc_inh_weights(self.d, self.K, self.selfpm)
        # Cpp, X, J, excitatory
        # Cpm, Y, W, inhibitory
        # Registered as non-persistent buffers: they follow module.to(device)
        # but do not add entries to state_dict().
        self.register_buffer(
            'excitatory_weights',
            torch.from_numpy(J_weights.astype('float32')).to(device) / (self.angle_weight *1.45),
            persistent=False)
        self.register_buffer(
            'inhibitory_weights',
            torch.from_numpy(W_weights.astype('float32')).to(device) / (self.angle_weight *2.00),
            persistent=False)
        

    def forward(self, I, Sp, Sm):
        """
        Args:
            I: hypercolumn input (assumed [K, h, w], like Sp and Sm).
            Sp: excitatory cell state, [K, h, w].
            Sm: inhibitory cell state, [K, h, w].

        Returns:
            (Sp, Sm, outSp, outSm, Sp.mean(), Sm.mean(), outSp.mean(), outSm.mean())
            where outSp/outSm are the gain-function outputs of the updated states.
        """
        # Get current output from gain function
        outSp = pGainFunction(Sp)
        outSm = mGainFunction(Sm)

        # Input Norm
        Sp = I_norm_conv(Sp, epsilon=self.epsilon)

        # p cell update
        pForce =  I - Sp*self.decay + self.selfpp*outSp
        pForce = pForce + self.backIp

        #   psi inhibition -> excitatory
        #     same orientation
        pForce = pForce + self.selfmp*outSm
        #     neighbouring orientations (circular): +-1 weighted by fw1, +-2 by fw2
        for k in range(self.K):
            pForce[k] = pForce[k] + self.selfmp*self.fw1*outSm[(k+1)%self.K]
            pForce[k] = pForce[k] + self.selfmp*self.fw1*outSm[(k-1)%self.K]
            pForce[k] = pForce[k] + self.selfmp*self.fw2*outSm[(k+2)%self.K]
            pForce[k] = pForce[k] + self.selfmp*self.fw2*outSm[(k-2)%self.K]

        # Circularly padded excitatory output, shared by both lateral convolutions.
        outSpPad = F.pad(outSp.unsqueeze(0), #[1,K,H,W]
                        (self.d,self.d,self.d,self.d), 
                        mode='circular')

        #   excitatory connection
        pCon = F.conv2d(outSpPad, self.excitatory_weights).squeeze(0)
        
        pCh = pForce + pCon

        Sp = Sp + self.epsilon * pCh
        Sp = F.relu(Sp) #Remove negative value

        # m cell update (driven by outSp from before the Sp update above)
        mForce = self.selfpm*outSp - self.decay*Sm
        mForce = mForce + self.backIm

        mCon = F.conv2d(outSpPad, self.inhibitory_weights).squeeze(0)

        mCh = mForce + mCon

        Sm = Sm + self.epsilon * mCh
        Sm = F.relu(Sm) #Remove negative value

        outSp = pGainFunction(Sp)
        outSm = mGainFunction(Sm)

        return Sp, Sm, outSp, outSm, Sp.mean(), Sm.mean(), outSp.mean(), outSm.mean()

In [31]:
# Update-step module for an h=56, w=96 map with the default K=12 and d=10.
V1SHNetComputeMachineModel = V1SHNetComputeMachine(56, 96)

# Example [K, h, w] tensors used only to trace forward(I, Sp, Sm) for export.
dummy_input1 = torch.zeros([12, 56, 96])
dummy_input2 = torch.zeros([12, 56, 96])
dummy_input3 = torch.zeros([12, 56, 96])

In [32]:
# Export the update step to "V1SHComputeMachineModel.onnx" (opset 11) with
# constant folding enabled. The three inputs map to forward(I, Sp, Sm); the
# eight outputs follow forward()'s return order.
# NOTE(review): 'ExcitatoryCell' and 'InhibitoryCell' appear in both
# input_names and output_names; the printed graph shows the inputs as
# '%ExcitatoryCell.1' / '%InhibitoryCell.1' — confirm consumers of the ONNX
# file use those input names.
torch.onnx.export(V1SHNetComputeMachineModel, 
                  args=(dummy_input1, dummy_input2, dummy_input3),
                  f="V1SHComputeMachineModel.onnx", 
                  verbose=True, 
                  export_params=True,
                  do_constant_folding=True,
                  input_names=['ImageHypercolumnsInput', 'ExcitatoryCell', 'InhibitoryCell'], 
                  output_names=['ExcitatoryCell', 'InhibitoryCell', 'gExcitatoryCell', 'gInhibitoryCell', 'avgSp', 'avgSm', 'avgOutSp', 'avgOutSm'],
                  opset_version=11)

  """


graph(%ImageHypercolumnsInput : Float(12, 56, 96, strides=[5376, 96, 1], requires_grad=0, device=cpu),
      %ExcitatoryCell.1 : Float(12, 56, 96, strides=[5376, 96, 1], requires_grad=0, device=cpu),
      %InhibitoryCell.1 : Float(12, 56, 96, strides=[5376, 96, 1], requires_grad=0, device=cpu),
      %2915 : Long(1, strides=[1], requires_grad=0, device=cpu),
      %2916 : Long(1, strides=[1], requires_grad=0, device=cpu),
      %2917 : Long(1, strides=[1], requires_grad=0, device=cpu),
      %2918 : Long(1, strides=[1], requires_grad=0, device=cpu),
      %2919 : Long(1, strides=[1], requires_grad=0, device=cpu),
      %2920 : Long(1, strides=[1], requires_grad=0, device=cpu),
      %2921 : Long(1, strides=[1], requires_grad=0, device=cpu),
      %2922 : Long(1, strides=[1], requires_grad=0, device=cpu),
      %2923 : Long(1, strides=[1], requires_grad=0, device=cpu),
      %2924 : Long(1, strides=[1], requires_grad=0, device=cpu),
      %2925 : Long(1, strides=[1], requires_grad=0, d