In [1]:
import jax.numpy as jnp
import jax.random as jrandom
import numpy as np
from MiniTorch.core.baseclasses import ComputationNode
from MiniTorch.legacy_utils import _conv2d_forward_legacy_v1, _conv2d_forward_legacy_v2, _conv2d_backward_legacy_v1, _conv_initialize_legacy, get_kernel_size, get_stride, _conv2d_backward_legacy_v2
import time
from typing import Literal, List, Tuple, Dict, Any
import jax

In [2]:
# Fetch the 'mnist_784' dataset (version 1) from OpenML and split it into
# the feature table and the target column.
from sklearn.datasets import fetch_openml
mnist = fetch_openml(name='mnist_784', version=1)
X = mnist['data']
y = mnist['target']

In [None]:
class Conv2D(ComputationNode):
    """2-D convolution layer for NCHW inputs with weights laid out as OIHW.

    Execution paths:

    * ``use_legacy_v1`` / ``use_legacy_v2`` delegate forward and backward to the
      ``_conv2d_*_legacy_*`` helpers (v1 wins when both flags are set).
    * otherwise the forward pass calls ``jax.lax.conv_general_dilated`` and the
      backward pass is obtained from that same function with ``jax.vjp``.

    Parameters are kept in ``self.parameters`` under ``'W'`` and ``'b'``;
    ``backward`` stores ``'dL_dW'``, ``'dL_db'`` and ``'dL_dinput'`` in
    ``self.grad_cache``.
    """

    def __init__(self, input_channels : int, kernel_size : int | tuple = 3, no_of_filters = 1, stride = 1, pad = 0,
                 accumulate_grad_norm = False, accumulate_params = False, seed_key = None, bias = True,
                 initialization = "None", use_legacy_v1 : bool = False, use_legacy_v2 : bool = False):
        """Record the configuration and create ``W`` (and ``b`` when ``bias`` is true).

        Args:
            input_channels: number of channels of the incoming tensor.
            kernel_size: passed through ``get_kernel_size``.
            no_of_filters: number of output channels.
            stride: passed through ``get_stride``.
            pad: zero padding applied to both spatial axes.
            accumulate_grad_norm / accumulate_params: enable the bookkeeping calls in ``step``.
            seed_key: JAX PRNG key; when ``None`` a key is derived from the current time.
            bias: whether a bias vector is created and updated.
            initialization: ``"he"`` scales the weights, anything else leaves them unscaled.
            use_legacy_v1 / use_legacy_v2: select the legacy NumPy implementations.
        """
        super().__init__()
        # The key must be stored whether or not the caller supplied one;
        # initialize() below reads self.seed_key.
        if seed_key is None:
            seed_key = jrandom.PRNGKey(int(time.time()))
        self.seed_key = seed_key
        self.kernel_size = get_kernel_size(kernel_size)
        self.input_channels = input_channels
        self.no_of_filters = no_of_filters
        self.stride = get_stride(stride)
        self.pad = pad
        self.accumulate_grad_norm = accumulate_grad_norm
        self.accumulate_params = accumulate_params
        self.initialization = initialization
        self.parameters = {'W': None, 'b': None}
        self.bias = bias

        self.use_legacy_v1 = use_legacy_v1
        self.use_legacy_v2 = use_legacy_v2
        if use_legacy_v1 or use_legacy_v2:
            self.parameters['W'], self.parameters['b'] = _conv_initialize_legacy(self.kernel_size,self.no_of_filters,self.input_channels,self.initialization,self.bias)
        else:
            self.initialize(self.seed_key)

    def initialize(self, seed_key):
        """Sample ``W`` from a standard normal and, when bias is enabled, set ``b`` to zeros.

        With ``initialization == "he"`` the samples are multiplied by
        ``sqrt(2 / (no_of_filters * kh * kw))``; any other value leaves them unscaled.
        """
        kh, kw = self.kernel_size[0], self.kernel_size[1]
        weights = jrandom.normal(seed_key, (self.no_of_filters, self.input_channels, kh, kw))
        if self.initialization == "he":
            weights = weights * jnp.sqrt(2/(self.no_of_filters * kh * kw))
        self.parameters['W'] = weights
        if self.bias:
            self.parameters['b'] = jnp.zeros((self.no_of_filters,))

    def _jax_padding(self):
        """Return ``self.pad`` as hashable ``((lo, hi), (lo, hi))`` pairs for the two spatial axes."""
        if isinstance(self.pad, (tuple, list)):
            pad_h, pad_w = self.pad
        else:
            pad_h = pad_w = self.pad
        return ((int(pad_h), int(pad_h)), (int(pad_w), int(pad_w)))

    @staticmethod
    def _conv2d_forward(X : jax.Array, W : jax.Array, b : jax.Array, stride : tuple,
                        padding : Literal['VALID','SAME'] | Tuple[Tuple[int, int], ...] = 'VALID'):
        """Batched cross-correlation of ``X`` (NCHW) with ``W`` (OIHW), plus an optional bias.

        Args:
            X: input of shape ``(batch, in_channels, H, W)``.
            W: filters of shape ``(out_channels, in_channels, kh, kw)``.
            b: bias of shape ``(out_channels,)`` or ``None`` to skip the bias.
            stride: window strides for the two spatial axes.
            padding: ``'VALID'``, ``'SAME'`` or explicit ``(low, high)`` pairs per spatial axis.

        Returns:
            Array of shape ``(batch, out_channels, H_out, W_out)``.
        """
        convout = jax.lax.conv_general_dilated(X, W, window_strides=stride, padding=padding,
                                               dimension_numbers=('NCHW','OIHW','NCHW'))
        if b is not None:
            # One bias per output channel: align it with axis 1, not with the trailing width axis.
            convout = convout + b.reshape(1, -1, 1, 1)
        return convout

    def forward(self, x):
        """Run the convolution on ``x`` and cache both input and output on the instance."""
        self.input = x
        if self.use_legacy_v1 or self.use_legacy_v2:
            legacy_forward = _conv2d_forward_legacy_v1 if self.use_legacy_v1 else _conv2d_forward_legacy_v2
            x = np.pad(x,((0,0),(0,0),(self.pad,self.pad),(self.pad,self.pad)))
            self.output = legacy_forward(self.parameters['W'], x, self.stride, self.parameters['b'])
            return self.output
        W, b = self.parameters['W'], self.parameters['b']
        # The convolution primitive expects both operands in one dtype.
        x = jnp.asarray(x, dtype=W.dtype)
        with jax.checking_leaks():
            # stride/padding are Python values, so they are marked static (and must be hashable).
            output = jax.jit(Conv2D._conv2d_forward, static_argnames=('stride','padding'))(
                x, W, b, stride=tuple(self.stride), padding=self._jax_padding())
        self.output = output
        return self.output

    def backward(self, out_grad):
        """Compute and cache the gradients w.r.t. ``W``, ``b`` and the layer input.

        Args:
            out_grad: gradient of the loss w.r.t. this layer's output.

        Returns:
            The gradient w.r.t. the input passed to the last ``forward`` call.
        """
        if self.use_legacy_v1:
            dL_dW,dL_db,dL_dinput = _conv2d_backward_legacy_v1(out_grad,self.input,self.kernel_size,self.parameters['W'],self.parameters['b'],self.stride,self.pad)
        elif self.use_legacy_v2:
            dL_dW,dL_db,dL_dinput = _conv2d_backward_legacy_v2(out_grad,self.input,self.kernel_size,self.parameters['W'],self.parameters['b'],self.stride,self.pad)
        else:
            W, b = self.parameters['W'], self.parameters['b']
            stride, padding = tuple(self.stride), self._jax_padding()
            x = jnp.asarray(self.input, dtype=W.dtype)

            def conv_fn(x_, W_, b_):
                return Conv2D._conv2d_forward(x_, W_, b_, stride, padding)

            # Differentiate the exact forward function so forward and backward cannot drift apart.
            out, pullback = jax.vjp(conv_fn, x, W, b)
            dL_dinput, dL_dW, dL_db = pullback(jnp.asarray(out_grad, dtype=out.dtype))
        self.grad_cache['dL_dW'] = dL_dW
        self.grad_cache['dL_db'] = dL_db
        self.grad_cache['dL_dinput'] = dL_dinput
        return dL_dinput

    def weights_var_mean(self):
        """Return ``(variance, mean)`` of the weight tensor."""
        return self.parameters['W'].var(), self.parameters['W'].mean()

    def bias_var_mean(self):
        """Return ``(variance, mean)`` of the bias vector."""
        return self.parameters['b'].var(), self.parameters['b'].mean()

    def step(self, lr):
        """Apply one plain gradient-descent update using the gradients cached by ``backward``.

        Runs the optional gradient-norm / parameter accumulation hooks first.
        """
        if self.accumulate_grad_norm:
            self._accumulate_grad_norm('dL_dW')
            self._accumulate_grad_norm('dL_db')
        if self.accumulate_params:
            self._accumulate_parameters('W', self.weights_var_mean)
            self._accumulate_parameters('b', self.bias_var_mean)
        self.parameters['W'] -= lr * self.grad_cache['dL_dW']
        if self.bias:
            self.parameters['b'] -= lr * self.grad_cache['dL_db']

In [4]:
class Flatten(ComputationNode):
    """Reshape ``(batch, ...)`` inputs to ``(batch, features)`` and undo that in ``backward``."""

    def __init__(self):
        super().__init__()
        self.requires_grad = False
        # Shape of the most recent forward input; needed to restore it in backward().
        self.shape = None

    def forward(self,x):
        """Collapse every axis after the first into one and cache input and output."""
        self.shape = x.shape
        self.input = x
        # An explicit feature count (instead of -1) also works for an empty batch.
        features = int(np.prod(x.shape[1:]))
        self.output = np.reshape(x,(x.shape[0],features))
        return self.output

    def backward(self, output_grad):
        """Reshape ``output_grad`` back to the shape seen by the last ``forward`` call.

        Raises:
            RuntimeError: if ``forward`` has not been called yet.
        """
        if self.shape is None:
            raise RuntimeError("Flatten.backward() called before forward()")
        # Use the cached shape as-is so inputs of any rank round-trip, not only 4-D ones.
        dL_dinput = np.reshape(output_grad, self.shape)
        self.grad_cache['dL_dinput']  = dL_dinput
        return dL_dinput
    
class MaxPool2d(ComputationNode):
    """Max pooling over the two trailing axes of a ``(batch, channels, H, W)`` input.

    Only the loop-based ``legacy_v1`` implementation exists; it is selected
    with ``use_legacy_v1=True``.
    """

    def __init__(self, pool_size, pool_stride, use_legacy_v1 = False):
        super().__init__()
        self.pool_size = get_kernel_size(pool_size)
        self.stride = get_stride(pool_stride)
        self.use_legacy_v1 = use_legacy_v1

    @staticmethod
    def _maxpool2d_forward_legacy_v1(pool_size, stride, input):
        """Return the window maxima of ``input`` as a ``(batch, channels, out_h, out_w)`` array."""
        batch_size, input_channels, H, W = input.shape[0],input.shape[1], input.shape[2], input.shape[3]
        output_h = (H - pool_size[0])//stride[0] + 1
        output_w = (W - pool_size[1])//stride[1] + 1
        output = np.zeros((batch_size,input_channels,output_h,output_w))
        for b in range(batch_size):
            for c in range(input_channels):
                for i in range(output_h):
                    for j in range(output_w):
                        h_s = i * stride[0]
                        h_e = h_s + pool_size[0]
                        w_s = j * stride[1]
                        w_e = w_s + pool_size[1]
                        output[b,c,i,j] = np.max(input[b,c,h_s:h_e,w_s:w_e])
        return output

    @staticmethod
    def _maxpool2d_backward_legacy_v1(pool_size, input, out_grad, stride):
        """Route each output gradient to the arg-max position of its pooling window.

        Returns an array shaped like ``input``.
        """
        batch_size, input_channels = input.shape[0],input.shape[1]
        out_grad_h, out_grad_w = out_grad.shape[2], out_grad.shape[3]
        dL_dinput = np.zeros_like(input)
        for b in range(batch_size):
            for c in range(input_channels):
                for i in range(out_grad_h):
                    for j in range(out_grad_w):
                        h_s = i * stride[0]
                        h_e = h_s + pool_size[0]
                        w_s = j * stride[1]
                        w_e = w_s + pool_size[1]
                        window = input[b,c,h_s:h_e,w_s:w_e]
                        max_ids = np.unravel_index(np.argmax(window),window.shape)
                        # Accumulate: with stride < pool_size the same input element can be
                        # the maximum of several windows and must receive all their gradients.
                        dL_dinput[b,c,h_s + max_ids[0],w_s + max_ids[1]] += out_grad[b,c,i,j]
        return dL_dinput

    def forward(self,x):
        """Pool ``x`` and cache input and output.

        Raises:
            NotImplementedError: when ``use_legacy_v1`` is false (no other implementation exists).
        """
        self.input = x
        if not self.use_legacy_v1:
            raise NotImplementedError("MaxPool2d only implements the legacy_v1 path; pass use_legacy_v1=True")
        self.output = self._maxpool2d_forward_legacy_v1(self.pool_size,self.stride,x)
        return self.output

    def backward(self, output_grad):
        """Back-propagate ``output_grad`` through the pooling of the last ``forward`` input.

        Raises:
            NotImplementedError: when ``use_legacy_v1`` is false.
        """
        if not self.use_legacy_v1:
            raise NotImplementedError("MaxPool2d only implements the legacy_v1 path; pass use_legacy_v1=True")
        dL_dinput = self._maxpool2d_backward_legacy_v1(self.pool_size,self.input,output_grad,self.stride)
        # 'dL_dinput' is the key the other layers in this notebook use; the
        # original 'dL_input' spelling is kept as an alias for existing readers.
        self.grad_cache['dL_dinput'] = dL_dinput
        self.grad_cache['dL_input'] = dL_dinput
        return dL_dinput


In [5]:
# Seeded generator so the benchmark input is identical on every Restart & Run All.
x = np.random.default_rng(0).standard_normal((2, 3, 30, 30))

In [6]:
# Small conv -> max-pool -> flatten stack used for the timing cells below.
conv = Conv2D(
    input_channels=3,
    kernel_size=3,
    no_of_filters=50,
    stride=1,
    pad=0,
    accumulate_grad_norm=False,
    accumulate_params=False,
    seed_key=None,
    bias=True,
    initialization="None",
    use_legacy_v2=True,
)
max_pool = MaxPool2d(pool_size=2, pool_stride=1, use_legacy_v1=True)
flatten = Flatten()

In [7]:
# Time one forward pass through the stack. perf_counter() is the clock meant
# for measuring intervals; time.time() is wall-clock time and can jump.
st = time.perf_counter()
out = conv.forward(x)
out = max_pool.forward(out)
out = flatten.forward(out)
et = time.perf_counter()
et-st

0.36423730850219727

In [8]:
# Random upstream gradient with the same shape as the forward output.
grad = np.random.randn(*out.shape)

In [13]:
# Route conv through the legacy v2 implementation for the backward timing below.
conv.use_legacy_v1, conv.use_legacy_v2 = False, True

In [14]:
# Time one backward pass, walking the stack in reverse order.
# perf_counter() is the clock meant for measuring intervals.
st = time.perf_counter()
grad1 = flatten.backward(grad)
grad2 = max_pool.backward(grad1)
in_grad1= conv.backward(grad2)
et = time.perf_counter()
et-st

0.49361562728881836

In [15]:
conv.step(0.001)

In [None]:
# Scratch arrays; note that `y` rebinds the name used for the MNIST targets earlier.
s, y = np.random.randn(3, 3, 3), np.random.randn(3, 3)

()

In [59]:
# Second backward timing run (same stack, results kept under separate names).
# perf_counter() is the clock meant for measuring intervals.
st = time.perf_counter()
grad3 = flatten.backward(grad)
grad4 = max_pool.backward(grad3)
in_grad2= conv.backward(grad4)
et = time.perf_counter()
et-st

0.9833188056945801

In [None]:
def _conv2d_backward_legacy_v1(out_grad: np.ndarray, input: np.ndarray, 
                             kernel_size: Tuple[int], W: np.ndarray, 
                             b: np.ndarray, stride: Tuple[int], 
                             pad: int) -> np.ndarray:
    batch_size, out_channel, out_h, out_w = out_grad.shape
    in_channel = input.shape[1]
    
    # Initialize gradients
    dL_dinput = np.zeros_like(input)
    dL_dW = np.zeros_like(W)
    dL_db = np.zeros_like(b)
    
    # Pad input gradient if needed
    if pad > 0:
        dL_dinput_padded = np.zeros((batch_size, in_channel, 
                                   input.shape[2] + 2*pad, 
                                   input.shape[3] + 2*pad))
    else:
        dL_dinput_padded = dL_dinput

    # Compute bias gradient (vectorized over batch and spatial dimensions)
    dL_db = np.sum(out_grad, axis=(0, 2, 3))

    # Create views for vectorized operations
    kh, kw = kernel_size
    stride_h, stride_w = stride

    # Generate index arrays for all positions at once
    h_starts = np.arange(out_h) * stride_h
    w_starts = np.arange(out_w) * stride_w
    h_ends = h_starts + kh
    w_ends = w_starts + kw
    for c in range(out_channel):
        patches = np.lib.stride_tricks.as_strided(
            input,
            shape=(batch_size, out_h, out_w, in_channel, kh, kw),
            strides=(input.strides[0], 
                    stride_h * input.strides[2],
                    stride_w * input.strides[3],
                    input.strides[1], 
                    input.strides[2], 
                    input.strides[3])
        )
        dL_dW[c] = np.tensordot(out_grad[:, c], patches, axes=([0,1,2], [0,1,2]))

        grad_reshaped = out_grad[:, c, :, :, np.newaxis, np.newaxis]
        w_reshaped = W[c, :, :, :, np.newaxis, np.newaxis]
        temp = np.zeros_like(dL_dinput_padded)
        idx_h = np.arange(kh)[np.newaxis, np.newaxis, :, np.newaxis]
        idx_w = np.arange(kw)[np.newaxis, np.newaxis, np.newaxis, :]
        h_pos = h_starts[:, np.newaxis, np.newaxis] + idx_h
        w_pos = w_starts[np.newaxis, :, np.newaxis] + idx_w
        np.add.at(temp, 
                 (slice(None), slice(None), h_pos, w_pos),
                 grad_reshaped * w_reshaped)

        dL_dinput_padded += temp

    if pad > 0:
        dL_dinput = dL_dinput_padded[:, :, pad:-pad, pad:-pad]
    else:
        dL_dinput = dL_dinput_padded

    return dL_dW, dL_db, dL_dinput


In [30]:
out = conv.forward(x)

In [31]:
# Fresh random upstream gradient matching the conv output's shape.
grad = np.random.randn(*out.shape)

In [32]:
# Pull the dimensions and parameters needed to work through the backward pass by hand.
batch_size, out_channel, out_h, out_w = grad.shape
in_channel = x.shape[1]
kernel_size, stride = (3, 3), (1, 1)
W = conv.parameters['W']
b = conv.parameters['b']
# Zero-filled gradient buffers shaped like the input, weights and bias
dL_dinput, dL_dW, dL_db = (np.zeros_like(arr) for arr in (x, W, b))

In [33]:
# Bias gradient: reduce over batch and both spatial axes.
dL_db = grad.sum(axis=(0, 2, 3))
(kh, kw), (stride_h, stride_w) = kernel_size, stride
# First row / column index of every pooling-style window, and one-past-the-end indices.
h_starts = stride_h * np.arange(out_h)
w_starts = stride_w * np.arange(out_w)
h_ends, w_ends = h_starts + kh, w_starts + kw

In [53]:
# Byte strides and shape for a (batch, channel, out_h, out_w, kh, kw) window view of x.
batch_step, channel_step, row_step, col_step = x.strides
strides = (batch_step, channel_step, row_step * stride_h, col_step * stride_w, row_step, col_step)
shapes = (batch_size, in_channel, out_h, out_w, kh, kw)

In [55]:
# Window view of x: no data is copied, so overlapping windows alias the same memory.
# writeable=False guards against accidentally writing through those aliases.
input_strided = np.lib.stride_tricks.as_strided(
    x,
    shape = shapes,
    strides= strides,
    writeable=False,
)

In [20]:
grad.shape

(2, 50, 28, 28)

In [22]:
# Weight gradient: sum over batch (b) and output positions (h, w) of window * upstream gradient.
# input_strided is laid out (b, i, h, w, k, l); the input-channel axis i is kept in the result.
np.einsum('bihwkl,bohw->oikl', input_strided, grad).shape

(50, 3, 3, 3)

In [None]:
# Show the shape of the first filter's gradient slice (channel axis kept) rather than dumping the array.
grad[:, :1].shape

(2, 1, 28, 28)

In [94]:
# Shape of the window view built above: (batch, in_channel, out_h, out_w, kh, kw).
input_strided.shape

(2, 3, 28, 28, 3, 3)

In [23]:
w = conv.parameters['W']

In [None]:
grad.shape

(2, 50, 28, 28)

In [44]:
# Operands for scattering filter 0's contribution back onto the input:
# gradient as (batch, 1, out_h, out_w, 1, 1), weights as (1, in_channel, 1, 1, kh, kw).
grad_reshaped = grad[:, 0][:, None, :, :, None, None]
w_reshaped = w[0][None, :, None, None, :, :]
temp = np.zeros_like(x)
idx_h = np.arange(kh).reshape(1, 1, kh, 1)
idx_w = np.arange(kw).reshape(1, 1, 1, kw)


In [50]:
# Build row / column positions covered by every window (shapes shown for kh = kw = 3, out_h = out_w = 28).
idx_h = np.arange(kh)[np.newaxis, :, np.newaxis]  # (1, 3, 1)
idx_w = np.arange(kw)[np.newaxis, np.newaxis, :]  # (1, 1, 3)
h_pos = h_starts[:, np.newaxis, np.newaxis] + idx_h  # (28, 1, 1) + (1, 3, 1) -> (28, 3, 1)
w_pos = w_starts[np.newaxis, :, np.newaxis] + idx_w  # (1, 28, 1) + (1, 1, 3) -> (1, 28, 3)

In [45]:
# NOTE: idx_h / idx_w are assigned in two other cells with different ranks (4-D and 3-D),
# so the shapes of h_pos / w_pos, and therefore what is printed here, depend on
# which of those cells was executed last.
h_pos = h_starts[:, np.newaxis, np.newaxis] + idx_h
w_pos = w_starts[np.newaxis, :, np.newaxis] + idx_w
print(h_pos[0][1],w_pos[0][0][1])

[[1]
 [2]
 [3]] [1 2 3]


In [49]:
w_reshaped.shape, grad_reshaped.shape

((1, 3, 1, 1, 3, 3), (2, 1, 28, 28, 1, 1))

In [51]:
# Scatter-add filter 0's contribution into temp.
# The two index arrays must broadcast to (out_h, out_w, kh, kw) so that, together with the
# leading batch / channel slices, they line up with grad_reshaped * w_reshaped.
h_idx = h_starts[:, None, None, None] + np.arange(kh)[None, None, :, None]  # (out_h, 1, kh, 1)
w_idx = w_starts[None, :, None, None] + np.arange(kw)[None, None, None, :]  # (1, out_w, 1, kw)
np.add.at(temp,
            (slice(None), slice(None), h_idx, w_idx),
            grad_reshaped * w_reshaped)

In [56]:
W_rotated = np.rot90(w,2,axes=(2,3))

In [59]:
np.einsum('bohw,oikl->bihw', grad, W_rotated).shape

(2, 3, 28, 28)

In [60]:
x.shape

(2, 3, 30, 30)

In [61]:
kernel = np.array([[1,2],[3,4]])

In [62]:
np.rot90(kernel,2)

array([[4, 3],
       [2, 1]])

In [65]:
# Zero-pads both spatial axes of `grad` by one element per side and rebinds `grad`,
# so re-running this cell pads again and every later cell sees the padded array.
# NOTE(review): a full correlation with a 3x3 kernel would presumably need kh - 1 = 2 per side — confirm.
grad = np.pad(grad,((0,0),(0,0),(1,1),(1,1)))

In [67]:
np.einsum('bohw,oikl->bihw',grad,w).shape

(2, 3, 30, 30)

In [68]:
x.shape

(2, 3, 30, 30)