In [9]:
import torch
import torch.nn as nn

In [4]:
class Softmax(object):
  """
  Hand-written softmax layer with explicit forward and backward passes.

  Both methods are static and work along the last dimension of their tensor
  arguments, so a single sample of shape (1, C) and a batch of shape (N, C)
  are handled the same way.
  """
  @staticmethod
  def forward(logits):
    """
    Computes the forward pass for a Softmax layer.
    Input:
    - logits: Unnormalized scores; a tensor whose last dimension indexes the
      classes, e.g. shape (N, C)
    Returns:
    - out: Probabilities, a tensor of the same shape as logits whose entries
      sum to 1 along the last dimension
    """
    # Subtracting the per-row maximum does not change the softmax value but
    # keeps exp() from overflowing on large logits.
    row_max = torch.max(logits, dim=-1, keepdim=True).values
    exp = torch.exp(logits - row_max)
    # keepdim=True keeps the normaliser at shape (..., 1) so the division
    # broadcasts row by row for any batch size.
    return exp / torch.sum(exp, dim=-1, keepdim=True)

  @staticmethod
  def backward(dout, cache, y=None):
    """
    Computes the backward pass for a Softmax layer.
    Input:
    - dout: Upstream derivatives dL/ds, broadcastable to the shape of cache
    - cache: Softmax output s produced by forward()
    - y: Unused. Kept (and made optional) so existing callers that pass the
      labels keep working; the result depends only on dout and cache.
    Returns:
    - dx: Gradient with respect to the logits, same shape as cache
    """
    s = cache
    # Softmax Jacobian is ds_j/dx_i = s_j * (delta_ij - s_i), so the
    # vector-Jacobian product is dx_i = s_i * (dout_i - sum_j dout_j * s_j).
    # This holds for any upstream gradient, not only for one-hot cross entropy.
    weighted_sum = torch.sum(dout * s, dim=-1, keepdim=True)
    return s * (dout - weighted_sum)

In [5]:
# Fix the RNG so the random input and weights are reproducible.
torch.manual_seed(0)
input = torch.rand(10, 1)  # network input, a (10x1) column vector
w = torch.rand(3, 10, requires_grad=True)  # weight matrix of shape (3x10), tracked by autograd
y = torch.tensor([0, 1, 0])  # ground-truth label, one-hot over the 3 classes

linear = w @ input  # (3x1) logits that feed the softmax

In [6]:
# Forward: lay the logits out as one row so the class axis is dim 1.
s = Softmax.forward(linear.reshape(1, -1))  # (1x3) class probabilities
# Derivative of the cross entropy -sum(y * log(s)) with respect to s.
dl_ds = torch.neg(y) / s
# Push the upstream gradient back through the softmax to reach the logits.
dl_dlinear = Softmax.backward(dl_ds, s, y)
# (1x3) times (10x1) broadcasts to the (10x3) outer product, i.e. the
# gradient of the loss with respect to the weights in transposed layout.
dl_dw = torch.mul(dl_dlinear, input)


In [7]:
dl_dw.t() # the gradient of the loss with respect to the weights, transposed to the (3x10) layout of w

tensor([[ 0.1269,  0.1964,  0.0226,  0.0337,  0.0786,  0.1621,  0.1253,  0.2291,
          0.1165,  0.1616],
        [-0.3325, -0.5148, -0.0593, -0.0885, -0.2060, -0.4249, -0.3284, -0.6007,
         -0.3053, -0.4237],
        [ 0.2057,  0.3184,  0.0367,  0.0547,  0.1274,  0.2628,  0.2031,  0.3715,
          0.1888,  0.2621]], grad_fn=<TBackward>)

Let's compare the result with the gradient computed by PyTorch's `backward()` function.

In [10]:
# Rebuild exactly the same input and weights as the manual computation
# (same seed, same order of torch.rand calls) so both gradients are comparable.
torch.manual_seed(0)

input  = torch.rand((10,1) ) # the input to the network
w = torch.rand((3,10) , requires_grad=True) # the weights tensor of shape (3x10)
y=torch.tensor([0,1,0]) # The ground truth output

linear = torch.mm(w,input) # (3x1) logits before the softmax
# dim is given explicitly: the implicit choice is deprecated and warns.
# dim=1 is the class axis of the (1x3) row built below.
softmax = nn.Softmax(dim=1)
soft = softmax(linear.reshape(1,-1)) # (1x3) probabilities after the softmax

loss = torch.sum( -y * torch.log(soft) ) # (scalar) the cross entropy loss between soft and y


In [11]:
loss.backward() # compute the gradient using the autograd engine

In [12]:
w.grad # gradient of the loss with respect to w, as computed by autograd

tensor([[ 0.1269,  0.1964,  0.0226,  0.0337,  0.0786,  0.1621,  0.1253,  0.2291,
          0.1165,  0.1616],
        [-0.3325, -0.5148, -0.0593, -0.0885, -0.2060, -0.4249, -0.3284, -0.6007,
         -0.3053, -0.4237],
        [ 0.2057,  0.3184,  0.0367,  0.0547,  0.1274,  0.2628,  0.2031,  0.3715,
          0.1888,  0.2621]])

In [13]:
# Compare element-wise. A signed sum of the differences would let errors of
# opposite sign cancel out, and would accept any large negative difference.
if torch.allclose(dl_dw.t().detach(), w.grad, atol=1e-5):
  print("gooooooooood result")
else:
  print('bad result')

gooooooooood result
