In [1]:
import jax.numpy as jnp
import jax.random as jrandom
import numpy as np
from MiniTorch.core.baseclasses import ComputationNode
from MiniTorch.legacy_utils import _conv2d_forward_legacy_v1, _conv2d_forward_legacy_v2, _conv2d_backward_legacy_v1, _conv_initialize_legacy, get_kernel_size, get_stride
import time
from typing import Literal, List, Tuple, Dict, Any
import jax

In [2]:
class Conv2D(ComputationNode):
    """2-D convolution layer operating on NCHW inputs.

    Three execution paths exist:

    * ``use_legacy_v1`` / ``use_legacy_v2``: weights come from
      ``_conv_initialize_legacy`` and forward/backward are delegated to the
      legacy helpers; the input is zero-padded by ``pad`` on both spatial axes.
    * default: weights are drawn with ``jax.random.normal`` and the forward
      pass runs through ``jax.lax.conv_general_dilated`` with ``'VALID'``
      padding.

    Known gaps (left unchanged here):

    * the default path does not use ``pad``;
    * ``backward`` only computes gradients on the legacy paths, otherwise it
      caches and returns ``None``;
    * the default-path bias is a single shared scalar of shape ``(1,)``.
    """

    def __init__(self, input_channels : int, kernel_size : int | tuple = 3, no_of_filters = 1, stride = 1, pad = None,
                 accumulate_grad_norm = False, accumulate_params = False, seed_key = None, bias = True,
                 initialization = "None", use_legacy_v1 : bool = False, use_legacy_v2 : bool = False):
        """Store the layer configuration and initialise the parameters.

        Args:
            input_channels: Number of channels of the input (axis 1 of the weights).
            kernel_size: Passed through ``get_kernel_size``; indexed as ``[0]``/``[1]``.
            no_of_filters: Number of output filters (axis 0 of the weights).
            stride: Passed through ``get_stride``.
            pad: Zero padding applied on the legacy paths; ``None`` means no padding.
            accumulate_grad_norm: Stored on the instance; not used in this class.
            accumulate_params: Stored on the instance; not used in this class.
            seed_key: JAX PRNG key for weight initialisation. When ``None`` a
                key is derived from the current wall-clock time.
            bias: Whether a bias parameter is created.
            initialization: ``"he"`` scales the normal draw; anything else
                uses an unscaled standard normal.
            use_legacy_v1: Use the v1 legacy forward implementation.
            use_legacy_v2: Use the v2 legacy forward implementation.
        """
        super().__init__()
        # Bind the key in both cases. Previously it was only set when the
        # caller passed None, so a user-supplied key was never stored and
        # `self.initialize(self.seed_key)` could not see it.
        # `is None` also avoids an elementwise `==` against a key array.
        if seed_key is None:
            seed_key = jrandom.PRNGKey(int(time.time()))
        self.seed_key = seed_key
        self.kernel_size = get_kernel_size(kernel_size)
        self.input_channels = input_channels
        self.no_of_filters = no_of_filters
        self.stride = get_stride(stride)
        self.pad = pad
        self.accumulate_grad_norm = accumulate_grad_norm
        self.accumulate_params = accumulate_params
        self.initialization = initialization
        self.parameters = {'W': None, 'b': None}
        self.bias = bias

        self.use_legacy_v1 = use_legacy_v1
        self.use_legacy_v2 = use_legacy_v2
        if use_legacy_v1 or use_legacy_v2:
            self.parameters['W'], self.parameters['b'] = _conv_initialize_legacy(self.kernel_size,self.no_of_filters,self.initialization,self.bias)
        else:
            self.initialize(self.seed_key)

    def initialize(self, seed_key):
        """Draw the weights (and zero bias, if enabled) for the default path.

        Weights have shape ``(no_of_filters, input_channels, kh, kw)``. With
        ``initialization == "he"`` the draw is scaled by
        ``sqrt(2 / (no_of_filters * kh * kw))``.
        """
        kernel_h, kernel_w = self.kernel_size[0], self.kernel_size[1]
        weight_shape = (self.no_of_filters, self.input_channels, kernel_h, kernel_w)
        weights = jrandom.normal(seed_key, weight_shape)
        if self.initialization == "he":
            weights = weights * jnp.sqrt(2/(self.no_of_filters * kernel_h * kernel_w))
        self.parameters['W'] = weights
        if self.bias:
            self.parameters['b'] = jnp.zeros((1,))

    def _legacy_pad_width(self):
        """Return the padding used by the legacy paths (``pad=None`` -> 0)."""
        return 0 if self.pad is None else self.pad

    @staticmethod
    def _conv2d_forward(X : jax.Array, W : jax.Array,b : jax.Array, stride : tuple, padding: Literal['VALID','SAME'] = 'VALID'):
        """Convolve every sample of ``X`` with every filter of ``W``.

        The outer ``vmap`` maps over axis 0 of ``X`` and the inner one over
        axis 0 of ``W``; each inner call convolves one sample with one filter.
        ``b`` is added to the result when it is not ``None``.
        """

        def conv_over_one_batch(X_vec, W_vec, stride, padding):
            # One sample (C, H, W) with one filter (C, kh, kw): add the
            # leading N / O axes the primitive expects, then strip them.
            if X_vec.ndim == 3:
                X_vec = X_vec[None,...]
            cvout = jax.lax.conv_general_dilated(X_vec,W_vec[None,...],window_strides=stride,padding=padding,
                                                    dimension_numbers=('NCHW','OIHW','NCHW'))[0,0]
            return cvout
        convout = jax.vmap(jax.vmap(conv_over_one_batch,in_axes=(None,0,None,None)), in_axes=(0,None,None,None))(X,W,stride,padding)
        # With bias=False the parameter stays None; adding it would raise.
        if b is not None:
            convout = convout + b
        return convout

    def forward(self, x):
        """Run the convolution on ``x`` and cache input and output.

        On the legacy paths the input is zero-padded on its last two axes
        before being handed to the legacy helper (v1 takes precedence when
        both flags are set).
        """
        self.input = x
        if self.use_legacy_v1 or self.use_legacy_v2:
            pad = self._legacy_pad_width()
            x = np.pad(x,((0,0),(0,0),(pad,pad),(pad,pad)))
            legacy_forward = _conv2d_forward_legacy_v1 if self.use_legacy_v1 else _conv2d_forward_legacy_v2
            self.output = legacy_forward(self.parameters['W'], x, self.stride, self.parameters['b'])
            return self.output
        W, b, stride = self.parameters['W'], self.parameters['b'], self.stride
        with jax.checking_leaks():
            output = jax.jit(Conv2D._conv2d_forward, static_argnames=('stride','padding'))(x, W, b, stride)
        self.output = output
        return self.output

    def backward(self, out_grad):
        """Compute and cache gradients for ``out_grad``; return ``dL_dinput``.

        Gradients are only computed on the legacy paths. On the default path
        all three cached gradients, and the return value, are ``None``.
        """
        dL_dW,dL_db,dL_dinput = None,None,None
        if self.use_legacy_v1 or self.use_legacy_v2:
            dL_dW,dL_db,dL_dinput = _conv2d_backward_legacy_v1(out_grad,self.input,self.kernel_size,self.parameters['W'],self.parameters['b'],self.stride,self._legacy_pad_width())
        self.grad_cache['dL_dW'] = dL_dW
        self.grad_cache['dL_db'] = dL_db
        self.grad_cache['dL_dinput'] = dL_dinput
        return dL_dinput

In [None]:
class Flatten(ComputationNode):
    """Reshape layer that collapses every axis after the first into one."""

    def __init__(self):
        super().__init__()
        self.requires_grad = False
        # Shape of the most recent forward input; needed to undo the reshape.
        self.shape = None

    def forward(self,x):
        """Return ``x`` reshaped to ``(x.shape[0], -1)`` and cache its shape."""
        self.shape = x.shape
        self.input = x
        self.output = np.reshape(x,(x.shape[0],-1))
        return self.output

    def backward(self, output_grad):
        """Reshape ``output_grad`` back to the shape seen by ``forward``.

        The stored shape is used as a whole, so inputs of any rank round-trip
        (the previous version indexed exactly four axes and failed otherwise).
        """
        dL_dinput = np.reshape(output_grad, tuple(self.shape))
        self.grad_cache['dL_dinput']  = dL_dinput
        return dL_dinput
    
class MaxPool2d(ComputationNode):
    """2-D max pooling over the last two axes of a 4-D input.

    Only the loop-based ``legacy_v1`` implementation exists; with
    ``use_legacy_v1=False`` both ``forward`` and ``backward`` return ``None``.
    """

    def __init__(self, pool_size, pool_stride, use_legacy_v1 = False):
        super().__init__()
        self.pool_size = get_kernel_size(pool_size)
        self.stride = get_stride(pool_stride)
        self.use_legacy_v1 = use_legacy_v1

    @staticmethod
    def _maxpool2d_forward_legacy_v1(pool_size, stride, input):
        """Return the per-window maximum of ``input``.

        Args:
            pool_size: ``(height, width)`` of the pooling window.
            stride: ``(vertical, horizontal)`` step between windows.
            input: 4-D array indexed as ``[batch, channel, row, col]``.

        Returns:
            Array of shape ``(batch, channel, out_h, out_w)`` where
            ``out_h = (H - pool_h) // stride_h + 1`` (same for width).
        """
        batch_size, input_channels, H, W = input.shape[0],input.shape[1], input.shape[2], input.shape[3]
        output_h = (H - pool_size[0])//stride[0] + 1
        output_w = (W - pool_size[1])//stride[1] + 1
        output = np.zeros((batch_size,input_channels,output_h,output_w))
        for b in range(batch_size):
            for c in range(input_channels):
                for i in range(output_h):
                    h_s = i * stride[0]
                    h_e = h_s + pool_size[0]
                    for j in range(output_w):
                        w_s = j * stride[1]
                        w_e = w_s + pool_size[1]
                        output[b,c,i,j] = np.max(input[b,c,h_s:h_e,w_s:w_e])
        return output

    @staticmethod
    def _maxpool2d_backward_legacy_v1(pool_size, input, out_grad, stride):
        """Route each output gradient to the arg-max position of its window.

        Contributions are accumulated: when windows overlap
        (stride < pool size) one input element can be the maximum of several
        windows and must receive the sum of their gradients. Plain assignment
        would keep only the last one.

        Returns:
            Array shaped like ``input`` holding the gradient w.r.t. the input.
        """
        batch_size, input_channels = input.shape[0],input.shape[1]
        out_grad_h, out_grad_w = out_grad.shape[2], out_grad.shape[3]
        dL_dinput = np.zeros_like(input)
        for b in range(batch_size):
            for c in range(input_channels):
                for i in range(out_grad_h):
                    h_s = i * stride[0]
                    h_e = h_s + pool_size[0]
                    for j in range(out_grad_w):
                        w_s = j * stride[1]
                        w_e = w_s + pool_size[1]
                        window = input[b,c,h_s:h_e,w_s:w_e]
                        max_ids = np.unravel_index(np.argmax(window),window.shape)
                        dL_dinput[b,c,h_s + max_ids[0],w_s + max_ids[1]] += out_grad[b,c,i,j]
        return dL_dinput

    def forward(self,x):
        """Pool ``x`` and cache input/output (``None`` unless legacy v1 is enabled)."""
        self.input = x
        output = None
        if self.use_legacy_v1:
            output = self._maxpool2d_forward_legacy_v1(self.pool_size,self.stride,x)
        self.output = output
        return output

    def backward(self, output_grad):
        """Compute, cache and return the gradient w.r.t. the cached input."""
        dL_dinput = None
        if self.use_legacy_v1:
            dL_dinput = self._maxpool2d_backward_legacy_v1(self.pool_size,self.input,output_grad,self.stride)
        # 'dL_dinput' matches the key the other layers in this file use;
        # the old 'dL_input' spelling is kept so existing readers still work.
        self.grad_cache['dL_dinput'] = dL_dinput
        self.grad_cache['dL_input'] = dL_dinput
        return dL_dinput


In [4]:
from MiniTorch.nets.layers import Conv2D
import numpy as np
import jax.numpy as jnp
from torch.nn import Conv2d
import time
import torch

In [5]:
# Unseeded random array of shape (2, 3, 30, 30); values differ on every run.
x = np.random.randn(2,3,30,30)
# Show the top-left 2x2 window of sample 0 / channel 1 ...
print(x[0,1,0:2,0:2])
# ... and the (row, col) position of its largest value inside that window.
np.unravel_index(np.argmax(x[0,1,0:2,0:2]), x[0,1,0:2,0:2].shape)

[[ 1.25112839  0.77884554]
 [-0.50159077  1.1798979 ]]


(0, 0)

In [22]:
# Collapse every axis after the first into one (same reshape Flatten.forward uses).
np.reshape(x,(x.shape[0],-1))

array([[-0.14413925, -1.32305014,  0.47968129, ...,  1.4906905 ,
         1.70746929,  0.24435162],
       [-0.10128278, -0.84299713,  1.016312  , ...,  0.00423085,
         0.678212  ,  0.09096756]])

In [6]:

# Conv layer on the legacy v1 path: 3 input channels, 50 filters, 3x3 kernel, stride 1, pad 0.
conv = Conv2D(input_channels=3,kernel_size = 3, no_of_filters = 50, stride = 1, pad = 0, accumulate_grad_norm = False, accumulate_params = False,seed_key = None, bias = True, initialization = "None",use_legacy_v1=True)
flatten = Flatten()

In [7]:
# Wall-clock time (seconds) of one forward pass through conv then flatten.
st = time.time()
out = conv.forward(x)
out = flatten.forward(out)
et = time.time()
et-st

0.33410120010375977

In [8]:
# Random stand-in for the upstream gradient, shaped like the flattened output.
grad = np.random.randn(*list(out.shape))

In [9]:
# Wall-clock time (seconds) of the backward pass, in reverse layer order.
# NOTE: `grad` is rebound here, so re-running this cell alone feeds the
# already-unflattened gradient back into flatten.backward.
st = time.time()
grad = flatten.backward(grad)
in_grad = conv.backward(grad)
et = time.time()
et-st

0.6525182723999023

In [10]:
# Shape of the value returned by conv.backward (the gradient w.r.t. the input).
in_grad.shape

(2, 3, 30, 30)

In [12]:
# Shape of the bias gradient cached by conv.backward.
conv.grad_cache['dL_db'].shape

(50,)

In [50]:
# Wall-clock time (seconds) of one forward call on `conv1` with x as a float32 torch tensor.
# NOTE(review): `conv1` is not defined in any visible cell — presumably a
# torch.nn.Conv2d created in a cell that was since removed; confirm and restore.
st = time.time()
out2=conv1(torch.tensor(x, dtype=torch.float))
et = time.time()
et-st

0.11046099662780762

In [10]:
# Scratch check: add a (3, 3) array to a (3, 3, 3) array summed over its first axis.
k = np.random.randn(3,3)
s = np.random.randn(3,3,3)
k + np.sum(s,axis = 0)

array([[-2.92018317,  1.32171577, -0.46376603],
       [-0.51912495,  1.42363607, -2.44862986],
       [ 0.9263156 , -0.90064167,  1.29908994]])

In [6]:
# Inspect the shape of the current `out`.
out.shape

(5, 50, 148, 148)

In [47]:
# Byte step of x's last axis multiplied by 20*20*3.
# NOTE(review): the 20*20*3 factor is unexplained — presumably an element count; confirm.
x.strides[-1] * 20 * 20 * 3

9600

In [68]:
# (vertical, horizontal) window step used by the cells below.
stride = (1,1)

In [69]:
# Byte strides for a 6-axis sliding-window view of x, built from x's own
# strides and the window step in `stride`.
strides = (
    x.strides[0],  # Batch dimension stride (unchanged)
    x.strides[1],  # Channel dimension stride (unchanged)
    x.strides[2] * stride[0],  # Vertical stride (step size for sliding window)
    x.strides[3] * stride[1],  # Horizontal stride (step size for sliding window)
    x.strides[2],  # Kernel height stride (step size within the kernel)
    x.strides[3]   # Kernel width stride (step size within the kernel)
)

In [72]:
# Target shape of the sliding-window view.
# NOTE(review): the literals 4 and 2 presumably stand for input size and kernel size
# of an `x` created in a cell not shown here — confirm against that cell.
shape = (2, 1, (4-2)//stride[0] + 1, (4-2)//stride[1] + 1, 2, 2)

In [74]:
# Reinterpret x's memory as the windowed view described by `shape` / `strides`.
# NOTE: this rebinds `x` to the view, so the cell is not idempotent — running it
# twice applies the view to the already-strided array.
x = np.lib.stride_tricks.as_strided(x, shape=shape,strides=strides)

In [78]:
# Contract the windowed view with w: sums over c, k and l, leaving (b, f, h, w).
# NOTE(review): no `w` with three axes is defined above this cell — the only visible
# assignment of `w` comes later in the file and has four axes; confirm where it comes from.
out = np.einsum('bchwkl,fkl->bfhw', x, w)

In [80]:
# Shape of the einsum result.
out.shape

(2, 10, 3, 3)

In [81]:
# The two window-count entries of `shape` above, evaluated on their own.
(4-2)//stride[0] + 1, (4-2)//stride[1] + 1

(3, 3)

In [103]:
import jax

In [104]:
# Reproducible (fixed key) normal draw of shape (5, 3, 224, 224); rebinds `x`.
x = jax.random.normal(jax.random.PRNGKey(1),(5,3,224,224))

In [136]:
# Normal draw of shape (50, 3, 3, 3) for the filters.
# NOTE(review): this reuses PRNGKey(1), the same key used for `x` above; JAX guidance
# is to split keys rather than reuse them — confirm whether that matters here.
w = jax.random.normal(jax.random.PRNGKey(1),(50,3,3,3))

In [137]:
def _conv_over_one_batch(x :jax.Array, w :jax.Array, stride : tuple):
    if x.ndim == 3:
        x = x[None,...]
    conv_out = jax.lax.conv_general_dilated(x, w[None, ...], padding='VALID',
                                window_strides=stride,dimension_numbers=('NCHW','OIHW','NCHW'))
    return conv_out[0,0]

In [138]:
# Inner vmap maps over axis 0 of w, outer vmap over axis 0 of x; `stride` is
# passed unmapped to every call.
out = jax.vmap(jax.vmap(_conv_over_one_batch, in_axes=(None, 0, None)), in_axes=(0,None,None))(jax.numpy.array(x),w, stride)

In [119]:
# Rebind `x` to a fresh unseeded NumPy draw of shape (5, 3, 224, 224).
x = np.random.randn(5,3,224,224)

In [120]:
# Compare the vmap result with the layer's forward output.
# NOTE(review): `==` on float arrays demands bit-exact equality, and the two sides
# presumably use different weights and inputs — np.allclose on matching inputs is
# the usual check; confirm intent.
out == conv.forward(x)

False

In [7]:
# Check which Conv2D class `conv` is an instance of.
type(conv)

MiniTorch.nets.layers.Conv2D

In [134]:
# Shape of the layer's forward output for the current `x` (runs a full forward pass).
conv.forward(x).shape

(5, 50, 148, 148)

In [None]:
# Display the whole `out` array; a summary such as out.shape or out.all() would keep the output small.
out

Array([[[[False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         ...,
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False]],

        [[False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         ...,
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False]],

        [[False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         [False, False, False, ..., False, False, False],
         ...,
         [False, False, False, ..., False, False, False],
         [False, False, Fa