<a href="https://colab.research.google.com/github/ShantanuKadam3115/MachineLearningBasics/blob/ML_implementations/CNN_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [20]:
import numpy as np

def conv_forward_naive(x, w, b, conv_param):
  """Compute the forward pass of a convolution layer using explicit loops.

  Every filter is slid over the zero-padded input. Each output element is the
  sum of the element-wise product between the filter and the input window it
  covers (taken across all channels), plus that filter's bias.

  Args:
    x: Input array of shape (N, C, H, W).
    w: Filter weights of shape (F, C, HH, WW).
    b: Biases, indexed per filter as b[f].
    conv_param: Dict with the keys
      'stride': step between adjacent windows, vertically and horizontally.
      'pad': number of zeros added on each side of the two spatial axes.

  Returns:
    A tuple (out, cache):
      out: float64 array of shape (N, F, H_out, W_out), where
        H_out = 1 + (H + 2 * pad - HH) // stride
        W_out = 1 + (W + 2 * pad - WW) // stride
      cache: (x, w, b, conv_param, x_pad), the values the backward pass needs.

  Raises:
    ValueError: If the channel count of x and w differ.
  """
  N, C, H, W = x.shape
  F, C_filter, HH, WW = w.shape
  stride, pad = conv_param['stride'], conv_param['pad']

  # Raised explicitly rather than asserted: asserts are stripped under
  # `python -O`, and a mismatch such as C=3 vs C_filter=1 would then broadcast
  # silently in the product below and yield wrong results.
  if C != C_filter:
    raise ValueError(
        f"channel mismatch: x has {C} channels but w expects {C_filter}")

  H_out = 1 + (H + 2 * pad - HH) // stride
  W_out = 1 + (W + 2 * pad - WW) // stride

  out = np.zeros((N, F, H_out, W_out))

  # Zero-pad only the two spatial axes; batch and channel axes are untouched.
  x_pad = np.pad(x, ((0, 0), (0, 0), (pad, pad), (pad, pad)), mode='constant')

  for n in range(N):
    for f in range(F):
      for i in range(H_out):
        # Row bounds depend only on i, so compute them once per output row.
        vert_start = i * stride
        vert_end = vert_start + HH
        for j in range(W_out):
          horiz_start = j * stride
          horiz_end = horiz_start + WW

          # Window over ALL channels: shape (C, HH, WW), same as w[f].
          x_slice = x_pad[n, :, vert_start:vert_end, horiz_start:horiz_end]

          out[n, f, i, j] = np.sum(x_slice * w[f]) + b[f]

  cache = (x, w, b, conv_param, x_pad)
  return out, cache


def conv_backward_naive(dout, cache):
  """Compute the backward pass of a convolution layer using explicit loops.

  Args:
    dout: Upstream gradient of shape (N, F, H_out, W_out).
    cache: Tuple (x, w, b, conv_param, x_pad) where x has shape (N, C, H, W),
      w has shape (F, C, HH, WW), conv_param holds 'stride' and 'pad', and
      x_pad is x zero-padded by 'pad' on both spatial axes.

  Returns:
    A tuple (dx, dw, db):
      dx: Gradient with respect to x, same shape as x.
      dw: Gradient with respect to w, same shape as w.
      db: Gradient with respect to b, same shape as b.
    Each gradient keeps the dtype of the array it belongs to when that dtype
    is floating or complex; integer/bool arrays get float64 gradients.
  """
  x, w, b, conv_param, x_pad = cache
  _, _, H, W = x.shape
  _, _, HH, WW = w.shape
  stride, pad = conv_param['stride'], conv_param['pad']

  N, F, H_out, W_out = dout.shape

  def grad_buffer(ref):
    # A plain zeros_like would inherit an integer dtype from `ref`; the
    # in-place accumulation below would then raise a casting error (dw, dx)
    # or silently truncate (db). Fall back to float64 in that case.
    ref = np.asarray(ref)
    dtype = ref.dtype if np.issubdtype(ref.dtype, np.inexact) else np.float64
    return np.zeros_like(ref, dtype=dtype)

  dx_pad = grad_buffer(x_pad)
  dw = grad_buffer(w)
  db = grad_buffer(b)

  for n in range(N):
    for f in range(F):
      # The bias is added to every output position of filter f.
      db[f] += np.sum(dout[n, f])

      for i in range(H_out):
        vert_start = i * stride
        vert_end = vert_start + HH
        for j in range(W_out):
          horiz_start = j * stride
          horiz_end = horiz_start + WW
          grad = dout[n, f, i, j]

          # out[n, f, i, j] = sum(x_slice * w[f]) + b[f], so the window feeds
          # dw scaled by grad and the filter feeds dx scaled by grad.
          x_slice = x_pad[n, :, vert_start:vert_end, horiz_start:horiz_end]
          dw[f] += x_slice * grad
          dx_pad[n, :, vert_start:vert_end, horiz_start:horiz_end] += w[f] * grad

  # Drop the padding border; for pad == 0 this slice covers the whole array.
  dx = dx_pad[:, :, pad:pad + H, pad:pad + W]

  return dx, dw, db













# --- TEST DRIVE ---
# Smoke test: runs the forward and backward passes on small random inputs and
# prints the resulting shapes and values for manual inspection.
print("--- Testing Conv Forward ---")
# 2 images, 3 channels, 4x4 size
x_test = np.random.randn(2, 3, 4, 4)
# 3 filters, 3 channels (must match input), 2x2 size
w_test = np.random.randn(3, 3, 2, 2)
b_test = np.random.randn(3)
params = {'stride': 2, 'pad': 1}

out_test, cached_out = conv_forward_naive(x_test, w_test, b_test, params)
print(f"Input Shape: {x_test.shape}")
print(f"Filter Shape: {w_test.shape}")
print(f"Output Shape: {out_test.shape}")
# cached_out is the tuple (x, w, b, conv_param, x_pad) returned by the forward pass
print("cache(x): ", cached_out[0].shape, cached_out[0])
print("cache(w): ", cached_out[1].shape, cached_out[1])
print("cache(bias): ", cached_out[2].shape, cached_out[2])
print("cache(conv_param): ", cached_out[3])
print("cache(x_pad)", cached_out[4].shape)
# Expected: (2, 3, 3, 3)
# Formula: (4 - 2 + 2)/2 + 1 = 3

print("--- Testing Conv Backward ---")

# Fake Upstream Gradient (dout) matching our previous output shape (2, 3, 3, 3)
dout = np.random.randn(2, 3, 3, 3)

# Run Backward with the full cache tuple produced by the forward pass
dx, dw, db = conv_backward_naive(dout, cached_out)

print(f"dx Shape: {dx.shape} {dx}(Matches Input?)")
print(f"dw Shape: {dw.shape} {dw} (Matches Filter?)")
print(f"db Shape: {db.shape} {db}(Matches Bias?)")





--- Testing Conv Forward ---
Input Shape: (2, 3, 4, 4)
Filter Shape: (3, 3, 2, 2)
Output Shape: (2, 3, 3, 3)
cache(x):  (2, 3, 4, 4) [[[[ 1.16368884  0.41169372  0.94464824  0.80050399]
   [-0.38085631  0.72762413 -0.94177854  0.40369113]
   [-1.09754751  0.89304773  1.35124243 -0.33018604]
   [-1.37172582  0.11528923 -2.26490699  1.63316512]]

  [[ 0.00473061 -0.8601774   2.02261175  0.9765975 ]
   [ 0.64674807  3.2288364   0.84769162  1.12663373]
   [ 0.08089699  0.94709516  0.76408622  2.32093901]
   [-0.07884441  1.58841169 -0.09196849 -0.70451106]]

  [[-0.85730202 -0.86808899 -2.56244643  0.11088766]
   [-0.4925918   0.16334519 -0.09392074  1.19221864]
   [-0.44808313 -1.12201987 -0.25031506 -0.95588898]
   [ 0.75309053  0.60048932  0.17252865 -0.73126502]]]


 [[[-2.17611192 -0.85321164  0.63952048 -0.02598896]
   [-0.03145361  1.33167929  0.94831439 -0.30029912]
   [-0.46641187 -0.20575488  0.06315843  0.22207767]
   [-0.62997962  0.09570993  3.10421617  1.43003137]]

  [[-0.30