In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class Linear:
    """Fully connected (affine) layer implemented with NumPy: ``y = x @ w + b``.

    Attributes:
        w: weight matrix of shape (in_features, out_features).
        b: bias row vector of shape (1, out_features).
        dw: gradient of the loss w.r.t. ``w``; ``None`` until ``backward`` runs.
        db: gradient of the loss w.r.t. ``b``; ``None`` until ``backward`` runs.
    """

    def __init__(self, in_features, out_features):
        """Create the layer and initialise its parameters.

        Args:
            in_features: size of each input row.
            out_features: size of each output row.
        """
        self.w = None
        self.b = None
        self.dw = None
        self.db = None
        # Input cached by forward() so backward() can compute dw.
        self.input = None

        self.in_features = in_features
        self.out_features = out_features
        self._init_params()

    def _init_params(self):
        """Draw weights from a Glorot/Xavier normal distribution and zero the bias."""
        mu = 0
        # Glorot normal: std = sqrt(2 / (fan_in + fan_out)).
        sigma = np.sqrt(2 / (self.in_features + self.out_features))
        self.w = np.random.normal(mu, sigma, (self.in_features, self.out_features))
        self.b = np.zeros((1, self.out_features))

    def forward(self, input):
        """Compute ``input @ w + b`` and cache ``input`` for the backward pass.

        Args:
            input: array whose last dimension is ``in_features``
                (a 2-D ``(batch, in_features)`` array is what ``backward`` expects).

        Returns:
            Array with last dimension ``out_features``.
        """
        self.input = input
        return input @ self.w + self.b

    def backward(self, pre_grad):
        """Back-propagate the upstream gradient through the layer.

        Stores the parameter gradients in ``self.dw`` / ``self.db`` and returns
        the gradient with respect to the layer input.

        Args:
            pre_grad: gradient of the loss w.r.t. the layer output,
                shape (batch, out_features).

        Returns:
            Gradient of the loss w.r.t. the input, shape (batch, in_features).

        Raises:
            RuntimeError: if ``forward`` has not been called yet.
        """
        if self.input is None:
            raise RuntimeError("backward() called before forward()")
        self.dw = self.input.T @ pre_grad
        # Sum (not mean) over the batch so db is scaled consistently with dw,
        # and keep the leading axis so db has the same (1, out_features) shape as b.
        self.db = np.sum(pre_grad, axis=0, keepdims=True)
        return pre_grad @ self.w.T

    def __call__(self, *args, **kwargs):
        """Alias for ``forward`` so the layer can be called like a function."""
        return self.forward(*args, **kwargs)

    @property
    def get_params(self):
        """List ``[w, b]`` of the layer's parameters (accessed as a property)."""
        return [self.w, self.b]

    @property
    def get_grads(self):
        """List ``[dw, db]`` of the parameter gradients (accessed as a property)."""
        return [self.dw, self.db]

In [6]:
# Smoke test: run one random batch through the NumPy Linear layer.
batch_size = 128
input_size = 5
output_size = 10

# The inputs are drawn before the layer is built, so the order of random draws
# (inputs first, then weights) stays the same.
x = np.random.randn(batch_size, input_size)
my_Linear = Linear(in_features=input_size, out_features=output_size)
my_y = my_Linear(x)