In [9]:
import jax.numpy as jnp
import jax.random as jrandom
import numpy as np
from MiniTorch.core.baseclasses import ComputationNode
import time
from typing import Literal, List, Tuple, Dict, Any
import jax

In [10]:
def get_kernel_size(kernel_size):
    """Normalise a kernel-size argument.

    A single ``int`` k is expanded to the pair ``(k, k)``; any other value is
    handed back exactly as it was received.
    """
    is_scalar = isinstance(kernel_size, int)
    return (kernel_size,) * 2 if is_scalar else kernel_size
def get_stride(stride):
    """Normalise a stride argument.

    A single ``int`` s is expanded to the pair ``(s, s)``; any other value is
    handed back exactly as it was received.
    """
    is_scalar = isinstance(stride, int)
    return (stride,) * 2 if is_scalar else stride

In [11]:
class Conv2D(ComputationNode):
    """2-D convolution layer for inputs laid out as (batch, channels, height, width).

    Learnable state lives in ``self.parameters``:

    * ``'W'`` - filters of shape ``(no_of_filters, input_channels, kernel_h, kernel_w)``.
    * ``'b'`` - a bias of shape ``(1,)`` shared by every filter, or ``None`` when the
      layer was built with ``bias=False``.

    NOTE(review): ``pad``, ``accumulate_grad_norm`` and ``accumulate_params`` are stored
    on the instance but are not read by any method visible in this class; ``forward``
    always uses 'VALID' padding.
    """

    def __init__(self, input_channels : int,kernel_size : int | tuple = 3, no_of_filters = 1, stride = 1, pad = None, accumulate_grad_norm = False, accumulate_params = False,seed_key = None, bias = True, initialization = "None"):
        """Store the configuration and draw the initial parameters.

        Args:
            input_channels: Number of channels of the input.
            kernel_size: Kernel height/width; an int k means (k, k).
            no_of_filters: Number of output filters.
            stride: Window step; an int s means (s, s).
            pad: Stored as ``self.pad`` (currently unused by ``forward``).
            accumulate_grad_norm: Stored flag.
            accumulate_params: Stored flag.
            seed_key: JAX PRNG key for weight initialisation. When ``None`` a key is
                derived from the current wall-clock time.
            bias: Whether to allocate a bias parameter.
            initialization: ``"he"`` enables the scaled initialisation; any other value
                draws from a standard normal.
        """
        super().__init__()
        # `is None` (not `== None`) so a key array is never compared element-wise, and
        # the key is stored in both cases: previously a caller-supplied key left
        # `self.seed_key` undefined and `initialize` raised AttributeError.
        if seed_key is None:
            seed_key = jrandom.PRNGKey(int(time.time()))
        self.seed_key = seed_key
        self.kernel_size = Conv2D.get_kernel_size(kernel_size)
        self.input_channels = input_channels
        self.no_of_filters = no_of_filters
        self.stride = Conv2D.get_stride(stride)
        self.pad = pad
        self.accumulate_grad_norm = accumulate_grad_norm
        self.accumulate_params = accumulate_params
        self.initialization = initialization
        self.parameters = {'W': None, 'b': None}
        self.bias = bias
        self.initialize(self.seed_key)

    @staticmethod
    def get_kernel_size(kernel_size):
        """Return the kernel size as a pair; an int k becomes ``(k, k)``.

        Non-int sequences are converted to a tuple.
        """
        if isinstance(kernel_size, int):
            return (kernel_size, kernel_size)
        return tuple(kernel_size)

    @staticmethod
    def get_stride(stride):
        """Return the stride as a pair; an int s becomes ``(s, s)``.

        Non-int sequences are converted to a tuple: ``forward`` passes the stride to
        ``jax.jit`` as a static argument, which must be hashable (a list is not).
        """
        if isinstance(stride, int):
            return (stride, stride)
        return tuple(stride)

    def initialize(self, seed_key):
        """Draw ``W`` from a normal distribution and, if enabled, zero-initialise ``b``.

        With ``initialization == "he"`` the weights are scaled by
        ``sqrt(2 / (no_of_filters * kernel_h * kernel_w))``.
        """
        kernel_h, kernel_w = self.kernel_size[0], self.kernel_size[1]
        weights = jrandom.normal(seed_key, (self.no_of_filters, self.input_channels, kernel_h, kernel_w))
        if self.initialization == "he":
            weights = weights * jnp.sqrt(2/(self.no_of_filters * kernel_h * kernel_w))
        self.parameters['W'] = weights
        if self.bias:
            self.parameters['b'] = jnp.zeros((1,))

    @staticmethod
    def _conv2d_forward_legacy_v1(W, x, stride, b = None):
        """Reference convolution using explicit Python loops (slow; NumPy only).

        Args:
            W: Filters of shape (filters, channels, kh, kw). A 3-D (filters, kh, kw)
                array is also accepted and applied identically to every channel.
            x: Input of shape (batch, channels, height, width).
            stride: (vertical, horizontal) window step.
            b: Optional bias with one element or one element per filter.

        Returns:
            NumPy array of shape (batch, filters, out_h, out_w), no padding applied.
        """
        # Work on NumPy copies/views: element-wise indexing of JAX arrays inside
        # Python loops is very slow.
        W = np.asarray(W)
        x = np.asarray(x)
        batch_size, input_channels, input_x, input_y = x.shape
        if W.ndim == 3:
            # Backward compatibility with the original 3-D kernel layout.
            W = np.broadcast_to(W[:, None], (W.shape[0], input_channels) + W.shape[1:])
        # W is 4-D; the previous 3-way unpack raised ValueError for layer weights.
        no_of_filters, _, kernel_size_x, kernel_size_y = W.shape
        output_x = (input_x - kernel_size_x)//stride[0] + 1
        output_y = (input_y - kernel_size_y)//stride[1] + 1
        out = np.zeros((batch_size, no_of_filters, output_x, output_y))
        for batch in range(batch_size):
            for filter_idx in range(no_of_filters):
                for i in range(output_x):
                    row = i*stride[0]
                    for j in range(output_y):
                        col = j*stride[1]
                        window = x[batch, :, row:row+kernel_size_x, col:col+kernel_size_y]
                        out[batch, filter_idx, i, j] = np.sum(window * W[filter_idx])
        if b is not None:
            # Align the bias with the filter axis; a size-1 bias broadcasts to all filters.
            out += np.asarray(b).reshape(1, -1, 1, 1)
        return out

    @staticmethod
    def _conv2d_forward_legacy_v2(W, x, stride, b = None):
        """Reference convolution using a strided sliding-window view and einsum (NumPy only).

        Args:
            W: Filters of shape (filters, channels, kh, kw). A 3-D (filters, kh, kw)
                array is also accepted and applied identically to every channel.
            x: Input of shape (batch, channels, height, width).
            stride: (vertical, horizontal) window step.
            b: Optional bias with one element or one element per filter.

        Returns:
            NumPy array of shape (batch, filters, out_h, out_w), no padding applied.
        """
        # `.strides` only exists on NumPy arrays, so convert first.
        x = np.asarray(x)
        W = np.asarray(W)
        batch_size, input_channels, input_x, input_y = x.shape
        if W.ndim == 3:
            # Backward compatibility with the original 3-D kernel layout.
            W = np.broadcast_to(W[:, None], (W.shape[0], input_channels) + W.shape[1:])
        no_of_filters, _, kernel_size_x, kernel_size_y = W.shape
        stride_x, stride_y = stride
        output_x = (input_x - kernel_size_x)//stride_x + 1
        output_y = (input_y - kernel_size_y)//stride_y + 1
        # Byte strides of the 6-D view: stepping one output row/column moves
        # `stride` input rows/columns, stepping inside a window moves one.
        strides = (
            x.strides[0],
            x.strides[1],
            x.strides[2] * stride_x,
            x.strides[3] * stride_y,
            x.strides[2],
            x.strides[3]
        )
        shape = (
            batch_size,
            input_channels,
            output_x,
            output_y,
            kernel_size_x,
            kernel_size_y
        )
        # Read-only view: windows overlap in memory, so writes would alias.
        x_strided_view = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides, writeable=False)
        # Contract over channel (c) and kernel axes (k, l) against the matching filter axes.
        conv_out = np.einsum('bchwkl,fckl->bfhw', x_strided_view, W, optimize=True)
        if b is not None:
            conv_out = conv_out + np.asarray(b).reshape(1, -1, 1, 1)
        return conv_out

    @staticmethod
    def _conv2d_forward(X : jax.Array, W : jax.Array,b : jax.Array, stride : tuple, padding: Literal['VALID','SAME'] = 'VALID'):
        """JAX convolution of a batch with a filter bank.

        Args:
            X: Input of shape (batch, channels, height, width).
            W: Filters of shape (filters, channels, kh, kw).
            b: Bias with one element or one element per filter, or ``None``.
            stride: (vertical, horizontal) window step.
            padding: Padding mode forwarded to ``jax.lax.conv_general_dilated``.

        Returns:
            Array of shape (batch, filters, out_h, out_w).
        """
        # conv_general_dilated already iterates over the batch (N) and output-channel
        # (O) axes, so the per-sample / per-filter vmap wrappers are unnecessary.
        convout = jax.lax.conv_general_dilated(
            jnp.asarray(X), jnp.asarray(W),
            window_strides=stride, padding=padding,
            dimension_numbers=('NCHW','OIHW','NCHW'))
        if b is not None:
            # Align the bias with the filter axis; a size-1 bias broadcasts to all filters.
            convout = convout + jnp.reshape(b, (1, -1, 1, 1))
        return convout

    def forward(self, x, use_legacy_v1 = False, use_legacy_v2 = False):
        """Run the convolution, caching ``self.input`` and ``self.output``.

        Args:
            x: Input of shape (batch, channels, height, width).
            use_legacy_v1: Use the loop-based NumPy reference implementation.
            use_legacy_v2: Use the strided-view NumPy reference implementation
                (ignored when ``use_legacy_v1`` is true).

        Returns:
            The convolution output, also stored as ``self.output``.
        """
        self.input = x
        W, b, stride = self.parameters['W'], self.parameters['b'], self.stride
        if use_legacy_v1:
            self.output = Conv2D._conv2d_forward_legacy_v1(W, x, stride, b)
            return self.output
        if use_legacy_v2:
            self.output = Conv2D._conv2d_forward_legacy_v2(W, x, stride, b)
            return self.output
        # stride/padding are static so they can be used as Python values while tracing.
        # jax.checking_leaks() is JAX's tracer-leak debugging context.
        with jax.checking_leaks():
            output = jax.jit(Conv2D._conv2d_forward, static_argnames=('stride','padding'))(x, W, b, stride)
        self.output = output
        return self.output

    def backward(self, out_grad):
        """Backward pass placeholder: not implemented yet, returns ``None``."""
        # TODO: compute gradients with respect to the input, W and b.
        pass

        

In [13]:
from MiniTorch.nets.layers import Conv2D
import numpy as np
import jax.numpy as jnp
from torch.nn import Conv2d
import time
import torch

In [24]:
# Unseeded random float64 input of shape (30, 3, 225, 225); presumably
# (batch, channels, height, width), the layout the Conv2D forward pass uses.
x = np.random.randn(30,3,225,225)


In [3]:
# Layer under test: 3 input channels, 50 filters with a 3x3 kernel, stride 1, bias enabled.
# seed_key=None, so the weights are not reproducible across runs.
conv = Conv2D(input_channels=3,kernel_size = 3, no_of_filters = 50, stride = 1, pad = None, accumulate_grad_norm = False, accumulate_params = False,seed_key = None, bias = True, initialization = "None")

In [4]:
# PyTorch reference layer with the matching configuration:
# 3 input channels, 50 output channels, 3x3 kernel.
conv1 = Conv2d(3,50,3)

In [49]:
# JAX dispatches computations asynchronously, so block on the result before
# stopping the clock; otherwise the measurement can end before the
# convolution has actually finished. perf_counter() is the monotonic,
# high-resolution clock intended for timing.
# NOTE: the first call also includes JIT compilation time.
st = time.perf_counter()
out = jax.block_until_ready(conv.forward(jnp.array(x)))
et = time.perf_counter()
et-st

0.10851025581359863

In [50]:
# Time the PyTorch reference layer on the same input, cast to float32.
# NOTE(review): a single wall-clock sample taken with autograd enabled and
# including the NumPy -> tensor conversion; treat the figure as indicative only.
st = time.time()
out2=conv1(torch.tensor(x, dtype=torch.float))
et = time.time()
et-st

0.11046099662780762

In [6]:
# Inspect the output dimensions.
# NOTE(review): the recorded value comes from execution In [6], which predates the
# cells that currently assign `out`, so it may be stale — re-run to confirm.
out.shape

(5, 50, 148, 148)

In [47]:
# Byte step along x's last axis multiplied by 20 * 20 * 3.
# NOTE(review): the constants 20, 20 and 3 are not explained by any visible cell.
x.strides[-1] * 20 * 20 * 3

9600

In [68]:
# (vertical, horizontal) window step used by the sliding-window cells that follow.
stride = (1,1)

In [69]:
# Byte strides for a 6-D sliding-window view of x, ordered as
# (batch, channel, window row, window column, row inside kernel, column inside kernel).
# Requires x to be a 4-D NumPy array at this point.
strides = (
    x.strides[0],  # Batch dimension stride (unchanged)
    x.strides[1],  # Channel dimension stride (unchanged)
    x.strides[2] * stride[0],  # Vertical stride (step size for sliding window)
    x.strides[3] * stride[1],  # Horizontal stride (step size for sliding window)
    x.strides[2],  # Kernel height stride (step size within the kernel)
    x.strides[3]   # Kernel width stride (step size within the kernel)
)

In [72]:
# Shape of the 6-D window view: presumably batch 2, 1 channel, the output rows/columns
# for a 4x4 input with a 2x2 kernel, then the 2x2 kernel window itself.
# NOTE(review): these literals do not match any x defined in the visible cells —
# confirm which array this example was written for.
shape = (2, 1, (4-2)//stride[0] + 1, (4-2)//stride[1] + 1, 2, 2)

In [74]:
# Rebinds x to the 6-D sliding-window view (no data is copied).
# NOTE(review): after this cell x is no longer the original 4-D array, so the cell
# that computes `strides` cannot be re-run without recreating x first.
x = np.lib.stride_tricks.as_strided(x, shape=shape,strides=strides)

In [78]:
# Contract every kernel window with every filter: sums over the channel axis (c) and
# the kernel axes (k, l), leaving (batch, filter, out_h, out_w).
# NOTE(review): `w` is not defined in any visible cell before this one; the subscripts
# 'fkl' require it to be 3-D (filter, kernel_h, kernel_w).
out = np.einsum('bchwkl,fkl->bfhw', x, w)

In [80]:
# Dimensions of the einsum result, ordered (batch, filter, out_h, out_w).
out.shape

(2, 10, 3, 3)

In [81]:
# Output (rows, columns) from the same expressions used when building `shape`.
(4-2)//stride[0] + 1, (4-2)//stride[1] + 1

(3, 3)

In [103]:
import jax

In [104]:
# Reproducible JAX input of shape (5, 3, 224, 224); rebinds x.
x = jax.random.normal(jax.random.PRNGKey(1),(5,3,224,224))

In [136]:
# Filter bank of shape (50, 3, 3, 3).
# NOTE(review): uses PRNGKey(1), the same key as the cell that samples x; JAX advises
# against reusing a key — consider jax.random.split to derive independent keys.
w = jax.random.normal(jax.random.PRNGKey(1),(50,3,3,3))

In [137]:
def _conv_over_one_batch(x :jax.Array, w :jax.Array, stride : tuple):
    if x.ndim == 3:
        x = x[None,...]
    conv_out = jax.lax.conv_general_dilated(x, w[None, ...], padding='VALID',
                                window_strides=stride,dimension_numbers=('NCHW','OIHW','NCHW'))
    return conv_out[0,0]

In [138]:
# The outer vmap maps over axis 0 of x (samples), the inner vmap over axis 0 of w
# (filters); stride is shared by every call (in_axes None). Results are stacked with
# the sample axis first and the filter axis second.
out = jax.vmap(jax.vmap(_conv_over_one_batch, in_axes=(None, 0, None)), in_axes=(0,None,None))(jax.numpy.array(x),w, stride)

In [119]:
# Rebinds x to a fresh, unseeded NumPy float64 array of shape (5, 3, 224, 224).
x = np.random.randn(5,3,224,224)

In [120]:
# Compare the earlier result with the layer's output for the current x.
# NOTE(review): `==` tests exact equality; for floating-point results a tolerance-based
# check (e.g. np.allclose) is usually what is wanted. The recorded result is False.
out == conv.forward(x)

False

In [7]:
# Per the recorded output, `conv` is an instance of the packaged
# MiniTorch.nets.layers.Conv2D rather than a class defined in this notebook's namespace.
type(conv)

MiniTorch.nets.layers.Conv2D

In [134]:
# Output dimensions of the layer for the current x.
# NOTE(review): the recorded (5, 50, 148, 148) would correspond to 150x150 inputs for a
# 3x3 kernel with stride 1 and no padding, which matches no x defined in the visible
# cells — re-run to confirm.
conv.forward(x).shape

(5, 50, 148, 148)

In [None]:
# Displays the full array; the recorded output is a large boolean array.
# NOTE(review): prefer out.shape or a small slice to keep the notebook readable.
out

Array([[[[False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         ...,
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False]],

        [[False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         ...,
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False]],

        [[False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         ...,
         [False, False, False, ..., False, False, False],
         [False, False, Fa