In [2]:
import numpy as np
np.random.seed(42)

In [10]:
class Linear():
    def __init__(self, in_dim: int, out_dim: int):
        """
        in_dim : explanatory_variable_dim_
        out_dim : target_variable_dim_
        dw : parameter_gradient_
        db : bias_gradient_
        """
        self.in_dim = in_dim
        self.out_dim = out_dim
        
        self.weight = np.random.randn(out_dim, in_dim)
        self.bias = np.zeros(out_dim, dtype=float)

        self.dx , self.dw, self.db = None, None, None

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """ forward_propagation
        x : input_data_ (batch_size, in_dim)
        output : output_data_ (batch_size, out_dim)
        """
        self.x = x
        # Affine
        output = np.dot(self.x, self.weight.T) + self.bias
        self.param = {'w' : self.weight, 'b' : self.bias}
        return  output

    def backward(self, grad: np.ndarray) -> np.ndarray:
        """ back_propagation
        grad : previous_gradient_ (batch_size, out_dim)
        dx : gradient_ (batch_size, in_dim)
        """
        # transpose x_shape
        if self.x.ndim == 2:
            x_T = self.x.T
        if self.x.ndim == 3:
            x_T = np.transpose(self.x, (0, 2, 1))
        if self.x.ndim == 4:
            x_T = np.transpose(self.x, (0, 1, 3, 2))
        # calculate gradient
        dx = np.dot(grad, self.weight)
        dw = np.dot(x_T, grad)
        db = np.sum(self.bias)
        self.grad_param = {'w' : dw, 'b' : db}
        return dx

In [11]:
batch, channel, indim, outdim = np.arange(2,6)
# create_data
x = np.random.randn(batch, channel, indim)

# forward_propagation
affine = Linear(indim, outdim)
out = affine(x)
print('-------------  Affine_output  ------------\n',out)
print()
# parameters
print('----------------  params  ----------------\n', affine.param)
# demo_grad
grad = np.random.randn(batch, channel, outdim)
#back_propagation
dx = affine.backward(grad)
print('----------------  grad  -------------\n', dx)
print()
# grad_parameters
print('---------------params_grad-----------\n',affine.grad_param)

-------------  Affine_output  ------------
 [[[ 1.60390525 -1.35064836 -2.04226026 -0.35020045 -0.54220778]
  [ 0.6389963  -1.7398243  -0.95549327  0.29750484 -0.52366306]
  [ 1.768501    0.06397159 -1.39201434 -2.27508346 -0.03210308]]

 [[ 0.01949151 -3.12780704  1.65338964 -2.69720329 -0.30316664]
  [-1.22042728 -2.35410276  3.58903065  0.44679222 -0.02189353]
  [-2.5205208   2.92718227 -0.75422601  0.71272238  0.3994948 ]]]

----------------  params  ----------------
 {'w': array([[ 0.34175598,  1.87617084,  0.95042384, -0.57690366],
       [-0.89841467,  0.49191917, -1.32023321,  1.83145877],
       [ 1.17944012, -0.46917565, -1.71313453,  1.35387237],
       [-0.11453985,  1.23781631, -1.59442766, -0.59937502],
       [ 0.0052437 ,  0.04698059, -0.45006547,  0.62284993]]), 'b': array([0., 0., 0., 0., 0.])}
----------------  grad  -------------
 [[[-0.1502606  -1.45930443 -2.17330917  0.6529044 ]
  [ 2.45887279 -3.08795636 -1.42531317 -1.09642087]
  [ 1.82017447  2.59422179 -1.733

In [25]:
###### Check my module against pytorch module ######
import torch
import torch.nn as nn
# self_made module
x = np.random.randn(2,3,3,4)
affine = Linear(4,6)
# torch.nn.Linear
xt = torch.tensor(x).float()
linear = nn.Linear(4,6)
# overwrite nn.Linear params with self_made module params
weight = linear.weight.detach().numpy().copy()
affine.weight = weight
bias = linear.bias.detach().numpy().copy()
affine.bias = bias
# output with self_made module
y = affine(x).astype(dtype='float32')
# output with nn.Linear
yt = linear(xt).detach().numpy().copy()
print(y[0][0][0][0])
print(yt[0][0][0][0])
np.round(y, decimals=6) ==\
np.round(yt, decimals=6)

-0.60770345
-0.60770345


array([[[[ True,  True, False,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True]],

        [[ True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True, False]],

        [[ True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True]]],


       [[[ True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True]],

        [[ True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True]],

        [[ True,  True,  True,  True, False,  True],
         [ True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True]]]])