In [None]:
def _lstm_cell(x: torch.Tensor, u: nn.Linear, w: nn.Linear, ht: torch.Tensor, ct: torch.Tensor):
    """Returns the hidden states for each time step.
    Arguments
    ---------
    wx : torch.Tensor
        Linearly transformed input.
    """
    hiddens = []
    cell_state = []

    # Feed-forward affine transformations (all steps in parallel)
    wx = w(x)

    # Sampling dropout mask
    #drop_mask = self._sample_drop_mask(wx)

    # Loop over time axis
    for k in range(wx.shape[1]):
        gates = wx[:, k] + u(ht)
        it, ft, gt, ot = gates.chunk(4, dim=-1)
        it = torch.sigmoid(it)
        ft = torch.sigmoid(ft)
        gt = torch.tanh(gt)
        ot = torch.sigmoid(ot)

        ct = ft * ct + it * gt 
        ht = ot * torch.tanh(ct) #* drop_mask

        hiddens.append(ht)
        cell_state.append(ct)

    # Stacking states
    h = torch.stack(hiddens, dim=1)
    c = torch.stack(cell_state, dim=1)
    return h, c

In [1]:
import torch
import torch.nn as nn 
import torch.autograd as autograd 


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
hidden_size = 5
input_size = 5
batch_size = 1
ht = torch.randn(batch_size, hidden_size)
ct = torch.randn(batch_size, hidden_size)
x = torch.randn(batch_size, input_size)
w = nn.Linear(input_size, 4 * hidden_size, bias=True)
u = nn.Linear(hidden_size, 4 * hidden_size, bias=True)

In [18]:


class LSTM_Cell_Vanilla(autograd.Function):

    @staticmethod
    def forward(ctx, x, u, u_bias, w, w_bias, ht, ct):

        hiddens = []
        cell_state = []

        # Feed-forward affine transformations (all steps in parallel)
        wx = (x @ w.T) + w_bias

        # Sampling dropout mask
        #drop_mask = self._sample_drop_mask(wx)

        # Loop over time axis
        for k in range(wx.shape[1]):
            gates = wx[:, k] + (ht @ u.T) + u_bias 
            it, ft, gt, ot = gates.chunk(4, dim=-1)
            it = torch.sigmoid(it)
            ft = torch.sigmoid(ft)
            gt = torch.tanh(gt)
            ot = torch.sigmoid(ot)

            ct = ft * ct + it * gt 
            ht = ot * torch.tanh(ct) #* drop_mask

            hiddens.append(ht)
            cell_state.append(ct)

        # Stacking states
        h = torch.stack(hiddens, dim=1)
        c = torch.stack(cell_state, dim=1)
        return h, c

    @staticmethod
    def backward(ctx, grad_out_h, grad_out_c):
        return None, None, None, None, None 

class LSTM_Cell(autograd.Function):

    @staticmethod
    def forward(ctx, x, u, u_bias, w, w_bias, ht, ct):

        hiddens = []
        cell_state = []

        # Feed-forward affine transformations (all steps in parallel)
        wx = (x @ w.T) + w_bias

        # Sampling dropout mask
        #drop_mask = self._sample_drop_mask(wx)

        # Loop over time axis
        for k in range(wx.shape[1]):
            ht, ct = _LSTM_Cell.apply(wx[:, k], u, u_bias, w, w_bias, ht, ct)

            hiddens.append(ht)
            cell_state.append(ct)

        # Stacking states
        h = torch.stack(hiddens, dim=1)
        c = torch.stack(cell_state, dim=1)
        return h, c

    @staticmethod
    def backward(ctx, grad_out_h, grad_out_c):
        return None, None, None, None, None 

class _LSTM_Cell(autograd.Function):

    @staticmethod
    def forward(ctx, ht, wx, u, u_bias, ct):
        # Loop over time axis
        gates = wx + (ht @ u.T) + u_bias 
        it, ft, gt, ot = gates.chunk(4, dim=-1)

        ctx.save_for_backward(it, ft, gt, ot, ct, ht, u)
        it = torch.sigmoid(it)
        ft = torch.sigmoid(ft)
        gt = torch.tanh(gt)
        ot = torch.sigmoid(ot)

        ct = ft * ct + it * gt 
        ctx.ct = ct
        ht = ot * torch.tanh(ct) #* drop_mask

        return ht, ct

    @staticmethod
    def backward(ctx, grad_out_h, grad_out_c):
        it, ft, gt, ot, ctt, htt, u = ctx.saved_tensors

        ui, uf, ug, uo = u.chunk(4, dim=0)
        ct = ctx.ct

        dh = grad_out_h 
        """
        do = dh * torch.tanh(ct) * ((1 - torch.sigmoid(ot)) * torch.sigmoid(ot))
        dg = dh * torch.sigmoid(ot) * (1 - torch.tanh(ct) ** 2) * torch.sigmoid(it) * (1 - torch.tanh(gt) ** 2)
        di = dh * torch.sigmoid(ot) * (1 - torch.tanh(ct) ** 2) * torch.tanh(gt) * ((1 - torch.sigmoid(it)) * torch.sigmoid(it))
        df = dh * torch.sigmoid(ot) * (1 - torch.tanh(ct) ** 2) * ctt * ((1 - torch.sigmoid(ft)) * torch.sigmoid(ft))
        """
        di = grad_out_c * torch.tanh(gt) * ((1 - torch.sigmoid(it)) * torch.sigmoid(it))
        df = ctt *  torch.sigmoid(ft) 
        dg = torch.sigmoid(it) * (1 - torch.tanh(gt) ** 2)
        do = torch.tanh(ct) * ((1 - torch.sigmoid(ot)) * torch.sigmoid(ot))
        dh_prev = do @ uo + dg @ ug + di @ ui + df @ uf

        return dh_prev, None, None, None, None 



In [19]:
hidden_size = 5
input_size = 5
batch_size = 1
htt = torch.randn(batch_size, hidden_size).double()
ctt = torch.randn(batch_size, hidden_size).double()
x = torch.randn(batch_size, 10, input_size).double()
w = nn.Linear(input_size, 4 * hidden_size, bias=True).double()
u = nn.Linear(hidden_size, 4 * hidden_size, bias=True).double()

wx = w(x)
autograd.gradcheck(_LSTM_Cell.apply, [htt, wx[:, 0], u.weight, u.bias, ctt])

GradcheckError: Jacobian mismatch for output 0 with respect to input 0,
numerical:tensor([[-0.0050,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000, -0.0514,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.0039,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000, -0.0826,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000, -0.0454],
        [ 0.0428,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0095,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.0259,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000, -0.0126,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0610],
        [ 0.1484,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0035,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0425,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.1071,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0538],
        [ 0.0090,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000, -0.0285,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000, -0.0478,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000, -0.1224,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0339]], dtype=torch.float64)
analytical:tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]], dtype=torch.float64)


In [20]:
h, c = _lstm_cell(x, u, w, ht, ct)

NameError: name '_lstm_cell' is not defined

In [43]:
h_ad, c_ad = LSTM_Cell_Vanilla.apply(x, u.weight, u.bias, w.weight, w.bias, ht, ct)

In [44]:
def check(h, h_ad, c, c_ad):
    assert torch.allclose(h, h_ad), torch.allclose(c, c_ad)
    return True 

In [45]:
check(h, h_ad, c, c_ad)

True

In [46]:
h_ad, c_ad = LSTM_Cell.apply(x, u.weight, u.bias, w.weight, w.bias, ht, ct)

In [47]:
check(h, h_ad, c, c_ad)

True