## Import Libraries 

In [1]:
import torch
from torch.nn.functional import pad
from PIL import Image, ImageOps # to read and manipulate photos
import numpy as np
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import math

## Define DoG and Gabor filters

### Gaussian kernel

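### Gaussian Distribution in 2D
$$
G_{\sigma}(x,y) = \frac{1}{2\pi\sigma^2}\, e^{-\frac{x^2 + y^2}{2\sigma^2}}
$$
In the code the sampled kernel is divided by the sum of its entries, so the constant prefactor cancels out.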

### DoG kernel

$$
DoG_{\sigma_1, \sigma_2}(x,y) = G_{\sigma_1}(x,y) - G_{\sigma_2}(x,y)
$$

In [18]:
class DoG:
    """Difference-of-Gaussians (DoG) kernel generator.

    Kernels are square ``kernel_size x kernel_size`` grids sampled at
    ``torch.linspace(-(kernel_size // 2), kernel_size // 2, kernel_size)``
    on both axes (integer-spaced and centred on 0 for odd ``kernel_size``).
    """

    def __init__(self, kernel_size):
        # Side length, in samples, of every kernel this object produces.
        self.kernel_size = kernel_size

    def gaussian_kernel_2D(self, sigma):
        """Return a 2-D Gaussian kernel whose entries sum to 1.

        Implements G(x, y) proportional to exp(-(x^2 + y^2) / (2 * sigma^2)).
        The analytic normalisation constant is omitted because the kernel is
        divided by its own sum, which cancels any constant prefactor.

        Args:
            sigma: standard deviation in grid units (Python number or 0-d tensor).

        Returns:
            A ``(kernel_size, kernel_size)`` float tensor summing to 1.
        """
        half = self.kernel_size // 2
        ax = torch.linspace(-half, half, self.kernel_size)
        # Explicit indexing avoids the torch deprecation warning; "ij" is the
        # previous default, and the kernel is radially symmetric anyway.
        xx, yy = torch.meshgrid(ax, ax, indexing="ij")
        # as_tensor accepts numbers and tensors alike (torch.tensor(tensor) warns).
        sigma = torch.as_tensor(sigma, dtype=ax.dtype)
        # The factor 2 in the denominator was previously missing, which made the
        # effective standard deviation sigma / sqrt(2) instead of sigma.
        kernel_2D = torch.exp(-(torch.square(xx) + torch.square(yy)) / (2 * torch.square(sigma)))
        return kernel_2D / torch.sum(kernel_2D)

    def get_kernel(self, sigma_1, sigma_2):
        """Return ``G(sigma_1) - G(sigma_2)``.

        Both Gaussians sum to 1, so the DoG kernel sums to (approximately) 0.
        If ``sigma_1 < sigma_2`` the kernel has a positive centre and a negative
        surround. Swapping the two sigmas flips the sign.
        """
        return self.gaussian_kernel_2D(sigma_1) - self.gaussian_kernel_2D(sigma_2)


### Gabor kernel

$$
g(x,y,\lambda,\theta,\sigma,\gamma) = \exp\left(-\frac{X^2 + \gamma^2 Y^2}{2\sigma^2}\right)\cos\left(\frac{2\pi X}{\lambda}\right)
$$
$$
X = x\cos\theta + y\sin\theta
$$
$$
Y = -x\sin\theta + y\cos\theta
$$

In [19]:
class Gabor:
    """Gabor kernel generator producing square ``kernel_size x kernel_size`` kernels."""

    def __init__(self, kernel_size):
        # Side length, in samples, of every kernel this object produces.
        self.kernel_size = kernel_size

    def get_kernel(self, sigma, gamma, lamda, theta):
        """Return a Gabor kernel divided by the sum of its entries.

        g(x, y) = exp(-(X^2 + gamma^2 * Y^2) / (2 * sigma^2)) * cos(2 * pi * X / lamda),
        with X = x*cos(theta) + y*sin(theta) and Y = -x*sin(theta) + y*cos(theta).

        Args:
            sigma: standard deviation of the Gaussian envelope, in grid units.
            gamma: spatial aspect ratio applied to the rotated Y axis.
            lamda: wavelength of the cosine carrier, in grid units (spelled
                without the "b" because ``lambda`` is a Python keyword).
            theta: orientation in radians; a Python number or a 0-d tensor.

        Returns:
            A ``(kernel_size, kernel_size)`` float tensor.
        """
        scale = self.kernel_size // 2
        v_range = torch.linspace(-scale, scale, self.kernel_size)
        x, y = torch.meshgrid(v_range, v_range, indexing="ij")
        # torch.cos/torch.sin reject plain Python floats, so convert theta first.
        theta = torch.as_tensor(theta, dtype=v_range.dtype)
        cos_t, sin_t = torch.cos(theta), torch.sin(theta)
        x_rotated = x * cos_t + y * sin_t
        y_rotated = -1 * x * sin_t + y * cos_t
        envelope = torch.exp(-(x_rotated**2 + (gamma**2 * y_rotated**2)) / (2 * sigma**2))
        carrier = torch.cos(2 * torch.pi * x_rotated / lamda)
        gabor_kernel = envelope * carrier
        # NOTE(review): the cosine carrier takes negative values, so this sum can
        # be close to zero or negative. Dividing by it may then amplify or
        # sign-flip the kernel. Consider an L1/L2 normalisation instead.
        return gabor_kernel / torch.sum(gabor_kernel)

## Convolution

In [20]:
class Conv2d:
    """Single-channel 2-D filtering with an explicit sliding window.

    The kernel is applied without flipping, so strictly this computes a
    cross-correlation. For point-symmetric kernels (k(-x, -y) == k(x, y)) it
    equals a true convolution.
    """

    def __init__(self, kernel_size:int, padding:str, stride:int):
        # Side length of the square kernel installed by set_kernel.
        self.kernel_size = kernel_size
        # "same" pads every side by kernel_size // 2; any other value means no padding.
        self.padding = padding
        # Step, in pixels, between consecutive windows (applied to both axes).
        self.stride = stride

    def set_kernel(self, kernel_obj, kernel_params):
        """Build the kernel with ``kernel_obj.get_kernel(**kernel_params)`` and store it on ``self.kernel``."""
        # NOTE(review): the kernel is assumed to be (kernel_size, kernel_size); this is not checked.
        self.kernel = kernel_obj.get_kernel(**kernel_params)

    @staticmethod
    def add_padding(image, pad_dims, value):
        """Constant-pad ``image`` with ``value``; ``pad_dims`` follows ``torch.nn.functional.pad`` ordering."""
        return pad(image, pad_dims, value=value)

    def handle_padding(self, image, padding_value):
        """Pad all four sides of a 2-D image by ``kernel_size // 2`` pixels."""
        pad_rows, pad_columns = self.kernel_size // 2, self.kernel_size // 2
        # F.pad lists the last dimension first: (left, right, top, bottom).
        pad_dims = (pad_columns, pad_columns, pad_rows, pad_rows)
        return self.add_padding(image, pad_dims, padding_value)

    def compute_output_shape(self, image, axis=0):
        """Return the number of window positions along ``axis`` of ``image``.

        ``axis`` defaults to 0 (rows), which matches the previous single-value
        behaviour. Pass ``axis=1`` for columns. The result is <= 0 when that
        dimension is smaller than the kernel.
        """
        return math.floor((image.shape[axis] - self.kernel_size) / self.stride) + 1

    def convolve(self, image, row_index, column_index):
        """Return the kernel-weighted sum of the window for output pixel (row_index, column_index)."""
        row_start = row_index * self.stride
        column_start = column_index * self.stride
        mat = image[row_start:row_start + self.kernel_size,
                    column_start:column_start + self.kernel_size]
        return torch.sum(torch.multiply(mat, self.kernel))

    def conv2d(self, image, padding_value=0):
        """Filter a 2-D image with the stored kernel.

        Args:
            image: 2-D tensor of shape ``(rows, columns)``; it need not be square.
            padding_value: constant used to fill the border when ``padding == "same"``.

        Returns:
            A float tensor of shape ``(out_rows, out_cols)``. Each dimension is
            ``floor((padded_dim - kernel_size) / stride) + 1``, clamped at 0.
        """
        if self.padding == "same":
            source_image = self.handle_padding(image, padding_value)
        else:
            source_image = image
        # Rows and columns are computed separately. Previously the row count was
        # reused for both axes, which cropped wide images and crashed on tall ones.
        out_rows = max(0, self.compute_output_shape(source_image, axis=0))
        out_cols = max(0, self.compute_output_shape(source_image, axis=1))
        result_img = torch.zeros((out_rows, out_cols))
        for row_idx in range(out_rows):
            for col_idx in range(out_cols):
                result_img[row_idx, col_idx] = self.convolve(source_image, row_idx, col_idx)
        return result_img

## Max Pool

In [21]:
class MaxPool2d:
    """2-D max pooling with square ``pool_size x pool_size`` windows over a 2-D tensor."""

    def __init__(self, pool_size, stride):
        # Side length of each pooling window.
        self.pool_size = pool_size
        # Step between consecutive windows, applied to both axes.
        self.stride = stride

    def compute_output_shape(self, image_shape):
        """Return ``(out_rows, out_cols)`` for an input of shape ``(rows, columns)``.

        Windows that would extend past the edge are dropped (floor division).
        Each size is clamped at 0, so an input smaller than one window yields
        an empty result. Previously the negative size was passed straight to
        ``torch.zeros`` and raised.
        """
        rows, columns = image_shape
        out_dim_0 = max(0, math.floor((rows - self.pool_size) / self.stride) + 1)
        out_dim_1 = max(0, math.floor((columns - self.pool_size) / self.stride) + 1)
        return out_dim_0, out_dim_1

    def pool_field(self, image, row, column):
        """Return the window whose top-left corner is ``(row * stride, column * stride)``."""
        row_start = row * self.stride
        column_start = column * self.stride
        return image[row_start:row_start + self.pool_size,
                     column_start:column_start + self.pool_size]

    def max_pool(self, image):
        """Return the maximum of every window as a float tensor of shape ``compute_output_shape(image.shape)``."""
        out_dim_0, out_dim_1 = self.compute_output_shape(image.shape)
        pools = torch.zeros((out_dim_0, out_dim_1))
        for row in range(out_dim_0):
            for column in range(out_dim_1):
                # compute_output_shape only admits windows that lie fully inside
                # the image, so every field is pool_size x pool_size and the
                # former shape check could never fail.
                pools[row, column] = torch.max(self.pool_field(image, row, column))
        return pools



## TTFS Encoding

In [22]:
class TTFS:
    """Time-to-first-spike (TTFS) encoder: brighter pixels spike earlier.

    A value p fires once, at integer time step ``round(-tau * ln(p))``.
    p == 1 fires at step 0, and smaller positive values fire later. Values
    whose spike time is non-finite (p == 0, p < 0) never fire. Negative spike
    times (from p noticeably > 1) also never fire.
    """

    @staticmethod
    def normalize(image):
        """Divide by 255; assumes 8-bit style intensities in [0, 255] (TODO confirm for callers)."""
        return image / 255.

    @staticmethod
    def process_spike_times(spike_times):
        """Round spike times to integer steps, mapping non-finite times to -1 ("never spikes").

        Non-finite times come from ln(0) (+inf) and from ln of a negative value
        (NaN). Previously only +/-inf was masked, so NaN reached ``.long()``,
        whose result is undefined. The input tensor is no longer modified in place.

        Returns:
            An int64 tensor with the same shape as ``spike_times``.
        """
        finite = torch.isfinite(spike_times)
        cleaned = torch.where(finite, spike_times, torch.full_like(spike_times, -1))
        return torch.round(cleaned).long()

    def __init__(self, tau, theta=1):
        # NOTE(review): theta is stored but not read by any method of this class.
        self.theta = theta
        # Time constant: scales -ln(p) into time steps.
        self.tau = tau

    def get_spike_times(self, image):
        """Return the raw (unrounded) spike times ``-tau * ln(image)``."""
        return -1 * self.tau * torch.log(image)

    def get_spike_train(self, time_window, spike_times):
        """Return a one-hot spike train of shape ``spike_times.shape + (time_window,)``.

        Entry ``[..., t]`` is 1.0 when the element's spike time equals ``t``,
        otherwise 0.0. Times outside ``[0, time_window)`` (including the -1
        "no spike" marker) produce an all-zero row.
        """
        steps = torch.arange(time_window, device=spike_times.device)
        fired = spike_times.unsqueeze(-1) == steps
        # Match the dtype the previous torch.zeros-based implementation returned.
        return fired.to(torch.get_default_dtype())

    def encode(self, image, time_window, normalize=False):
        """Encode ``image`` into a spike train spanning ``time_window`` steps.

        Args:
            image: tensor of intensities; ``normalize=True`` first divides by 255.
            time_window: number of time steps in the output train.
            normalize: whether to apply ``normalize`` before encoding.
        """
        if normalize:
            image = self.normalize(image)
        spike_times = self.get_spike_times(image)
        spike_times = self.process_spike_times(spike_times)
        return self.get_spike_train(time_window, spike_times)
    

## Plotting Functions

In [23]:
def plot_kernel_image(img, kernel, conv, max_pool):
    """Display the input, the kernel, the filtered image and the pooled image side by side.

    Each array is rendered with ``px.imshow(..., binary_string=True)`` and its
    first trace is placed in one of four columns of a single-row Plotly subplot
    grid. The figure is shown immediately.
    """
    panels = (
        ("Image", img),
        ("Kernel", kernel),
        ("Convolved Image", conv),
        ("Max-Pooled Image", max_pool),
    )
    titles = tuple(title for title, _ in panels)
    fig = make_subplots(rows=1, cols=len(panels), subplot_titles=titles)
    for column, (_, data) in enumerate(panels, start=1):
        trace = px.imshow(data, binary_string=True).data[0]
        fig.add_trace(trace, 1, column)
    fig.show()

### Read and preprocess the image

In [24]:
def get_image(size: tuple, path: str):
    """Load an image file as a grayscale tensor.

    Args:
        size: target size passed to ``PIL.Image.resize``, which interprets it
            as ``(width, height)``.
        path: path of the image file.

    Returns:
        A tensor built from ``np.asarray`` of the grayscale ("L" mode) image,
        i.e. shape ``(height, width)``.
    """
    # Use the image as a context manager so the underlying file handle is
    # closed. Previously Image.open's file was never closed explicitly.
    # grayscale() loads the pixel data, so the result outlives the with-block.
    with Image.open(path) as img:
        gray = ImageOps.grayscale(img)
    resized = gray.resize(size=size)
    return torch.tensor(np.asarray(resized))

## Putting it all together

In [35]:
# Pipeline configuration.
path = 'cat.jpg'
size = 224
# NOTE(review): k_type is not read in this cell; a DoG kernel is always built below.
k_type = 'DoG'
kernel_size = 15
params = {
    'sigma_1': 7,
    'sigma_2': 5
}
padding = "same"
strides = 2

# Load -> filter -> pool -> plot. The configured path/size/kernel_size are used
# here instead of re-typed literals, so editing the configuration takes effect.
image = get_image((size, size), path)
kernel = DoG(kernel_size)
conv = Conv2d(kernel_size, padding, strides)
conv.set_kernel(kernel, params)
conv_img = conv.conv2d(image)
max_pool_layer = MaxPool2d(pool_size=2, stride=2)
pooled_img = max_pool_layer.max_pool(conv_img)
# conv.kernel is exactly kernel.get_kernel(**params); reuse it rather than recomputing.
plot_kernel_image(image, conv.kernel, conv_img, pooled_img)


In [36]:
ttfs = TTFS(tau=20.)
# NOTE(review): pooled_img holds raw DoG responses and is encoded with
# normalize=False. Since spike time = -tau * ln(value), essentially only values
# in (0, 1] can spike inside the 50-step window. Confirm this scaling is intended.
spike_train = ttfs.encode(pooled_img, 50)
# Index triples (row, column, time step) of every emitted spike.
x, y, z = torch.where(spike_train == 1)
fig = go.Figure(data=[go.Scatter3d(x=x, y=y, z=z,
                                   mode='markers')])
fig.update_layout(
    width=1000,  # figure width in pixels
    height=800,  # figure height in pixels
    title_text="TTFS Encoding"
)
fig.show()