In [1]:
from torch import FloatTensor
from torch import LongTensor 

In [2]:
class Module(object):
    """Common interface for network building blocks.

    Subclasses are expected to override ``forward`` and ``backward``;
    ``param`` may be left as is by blocks that have nothing to train.
    """

    def forward(self, *input):
        """Compute the block's output; not implemented on the base class."""
        raise NotImplementedError()

    def backward(self, *gradwrtoutput):
        """Propagate gradients backwards; not implemented on the base class."""
        raise NotImplementedError()

    def param(self):
        """Return the block's parameters; the base class has none (fresh empty list)."""
        return list()

In [3]:
# Standard deviation of the zero-mean normal distribution from which
# Linear draws its initial weights and biases.
epsilon = 1e-6

In [45]:
class Linear (Module):
    def __init__(self, in_features, out_features):
        super(Module, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.w = FloatTensor(in_features, out_features).normal_(0, epsilon)
        self.b = FloatTensor(in_features, 1).normal_(0, epsilon) #TODO maybe put some specific parameters for init, 
                                                            #try xavier
        self.dl_dw = FloatTensor(self.w.size())
        self.dl_db = FloatTensor(self.b.size())  
    def forward(self, x_in):
        x_out = x_in.mm(self.w.t())+self.b.t()
        self.x_in = x_in
        return x_out
    def zero_grad(self): 
        self.dl_dw.zero_()
        self.dl_db.zero_()
    def backward(self, dl_dx_out): 
        dl_dw = 1/self.dl_dw.shape[0]*dl_dx_out.t().mm(self.x_in)
        dl_db = 1/self.dl_db.shape[0]*dl_dx_out.sum(0).view(-1,1)
        dl_dx_in = dl_dx_out.mm(self.w)
        self.dl_dw = dl_dw
        self.dl_db= dl_db
        return dl_dx_in
    def param(self): 
        return [(self.w, self.dl_dw), (self.b, self.dl_db)]

In [46]:
model = Linear(3, 5)

In [47]:
x_out = model.forward(FloatTensor(10, 5).normal_())
model.backward(FloatTensor(10, 3).normal_())
model.param()

[(
  1.00000e-06 *
    0.9213 -0.9813  0.5165 -0.8032  1.2092
    0.5633  0.8725 -0.0012 -0.6860 -0.5120
   -0.9405 -0.1476  0.1578 -0.4797 -0.8747
  [torch.FloatTensor of size 3x5], 
   0.0618  0.5465  1.0315 -1.2814  1.6803
  -0.4567 -0.0242  1.3338 -0.4402 -1.6675
  -0.2164 -1.5352 -2.6399 -1.5467  1.0349
  [torch.FloatTensor of size 3x5]), (
  1.00000e-06 *
   -1.8200
   -0.0613
   -0.6420
  [torch.FloatTensor of size 3x1], 
   1.0259
   0.2699
  -0.4375
  [torch.FloatTensor of size 3x1])]