In [2]:
import numpy as np
np.random.seed(42)

In [12]:
class Linear():
    def __init__(self, in_dim: int, out_dim: int):
        """
        in_dim : explanatory_variable_dim_
        out_dim : target_variable_dim_
        dw : parameter_gradient_
        db : bias_gradient_
        """
        self.in_dim = in_dim
        self.out_dim = out_dim
        
        self.weight = np.random.randn(out_dim, in_dim)
        self.bias = np.zeros(out_dim, dtype=float)

        self.dx , self.dw, self.db = None, None, None

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """ forward_propagation
        x : input_data_ (batch_size, in_dim)
        output : output_data_ (batch_size, out_dim)
        """
        self.x = x
        # Affine
        output = np.dot(self.x, self.weight.T) + self.bias
        self.param = {'w' : self.weight, 'b' : self.bias}
        return  output

    def backward(self, grad: np.ndarray) -> np.ndarray:
        """ back_propagation
        grad : previous_gradient_ (batch_size, out_dim)
        dx : gradient_ (batch_size, in_dim)
        """
        # transpose x_shape
        if self.x.ndim == 2:
            x_T = self.x.T
        if self.x.ndim == 3:
            x_T = np.transpose(self.x, (0, 2, 1))
        if self.x.ndim == 4:
            x_T = np.transpose(self.x, (0, 1, 3, 2))
        # calculate gradient
        dx = np.dot(grad, self.weight)
        dw = np.dot(x_T, grad)
        db = np.sum(grad, axis=0)
        self.grad_param = {'w' : dw, 'b' : db}
        return dx

In [13]:
batch, channel, indim, outdim = np.arange(2,6)
# create_data
x = np.random.randn(batch, channel, indim)

# forward_propagation
affine = Linear(indim, outdim)
out = affine(x)
print('-------------  Affine_output  ------------\n',out)
print()
# parameters
print('----------------  params  ----------------\n', affine.param)
# demo_grad
grad = np.random.randn(batch, channel, outdim)
#back_propagation
dx = affine.backward(grad)
print('----------------  grad  -------------\n', dx)
print()
# grad_parameters
print('---------------params_grad-----------\n',affine.grad_param)

-------------  Affine_output  ------------
 [[[ 0.37592799  1.12673461  1.38469528  0.62574414 -0.47870263]
  [ 0.35669337  0.5067835   0.32429235  0.29012487  0.16871195]
  [ 1.19829281 -2.3424583   4.82897799 -1.04301817 -1.91963209]]

 [[ 1.80976351  0.40517077  3.79630358  0.30860244 -1.49672565]
  [ 1.18746135 -0.42686333 -0.7142571  -0.30638402 -0.07519201]
  [ 0.02754747  2.91399935 -0.9284874   1.41278849 -0.07131813]]]

----------------  params  ----------------
 {'w': array([[-0.65183611,  0.04739867, -0.86041337, -0.38455554],
       [ 1.00629281, -0.57689187,  0.83569211, -1.12970685],
       [ 0.52980418,  1.44156862, -2.4716445 , -0.79689526],
       [ 0.57707213, -0.20304539,  0.37114587, -0.60398519],
       [ 0.08658979, -0.15567724,  1.16778206,  0.25442084]]), 'b': array([0., 0., 0., 0., 0.])}
----------------  grad  -------------
 [[[-1.108328   -0.42288572  0.87060102  1.08565993]
  [ 2.14003873  2.67858639 -4.58267983 -2.42843937]
  [-3.0171748  -2.14981808  1.104

In [14]:
###### Check my module against pytorch module ######
import torch
import torch.nn as nn
# self_made module
x = np.random.randn(2,3,3,4)
affine = Linear(4,6)
# torch.nn.Linear
xt = torch.tensor(x).float()
linear = nn.Linear(4,6)
# overwrite nn.Linear params with self_made module params
weight = linear.weight.detach().numpy().copy()
affine.weight = weight
bias = linear.bias.detach().numpy().copy()
affine.bias = bias
# output with self_made module
y = affine(x).astype(dtype='float32')
# output with nn.Linear
yt = linear(xt).detach().numpy().copy()
print(y[0][0][0][0])
print(yt[0][0][0][0])
np.round(y, decimals=6) ==\
np.round(yt, decimals=6)

0.38795564
0.38795564


array([[[[ True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True]],

        [[ True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True]],

        [[ True,  True,  True, False,  True,  True],
         [ True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True, False]]],


       [[[ True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True]],

        [[ True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True]],

        [[ True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True]]]])