In [None]:
def _lstm_cell(x: torch.Tensor, u: nn.Linear, w: nn.Linear, ht: torch.Tensor, ct: torch.Tensor):
    """Returns the hidden states for each time step.
    Arguments
    ---------
    wx : torch.Tensor
        Linearly transformed input.
    """
    hiddens = []
    cell_state = []

    # Feed-forward affine transformations (all steps in parallel)
    wx = w(x)

    # Sampling dropout mask
    #drop_mask = self._sample_drop_mask(wx)

    # Loop over time axis
    for k in range(wx.shape[1]):
        gates = wx[:, k] + u(ht)
        it, ft, gt, ot = gates.chunk(4, dim=-1)
        it = torch.sigmoid(it)
        ft = torch.sigmoid(ft)
        gt = torch.tanh(gt)
        ot = torch.sigmoid(ot)

        ct = ft * ct + it * gt 
        ht = ot * torch.tanh(ct) #* drop_mask

        hiddens.append(ht)
        cell_state.append(ct)

    # Stacking states
    h = torch.stack(hiddens, dim=1)
    c = torch.stack(cell_state, dim=1)
    return h, c

In [2]:
import torch
import torch.nn as nn 
import torch.autograd as autograd 


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
hidden_size = 5
input_size = 5
batch_size = 1
ht = torch.randn(batch_size, hidden_size)
ct = torch.randn(batch_size, hidden_size)
x = torch.randn(batch_size, input_size)
w = nn.Linear(input_size, 4 * hidden_size, bias=True)
u = nn.Linear(hidden_size, 4 * hidden_size, bias=True)

In [553]:


class LSTM_Cell_Vanilla(autograd.Function):

    @staticmethod
    def forward(ctx, x, u, u_bias, w, w_bias, ht, ct):

        hiddens = []
        cell_state = []

        # Feed-forward affine transformations (all steps in parallel)
        wx = (x @ w.T) + w_bias

        # Sampling dropout mask
        #drop_mask = self._sample_drop_mask(wx)

        # Loop over time axis
        for k in range(wx.shape[1]):
            gates = wx[:, k] + (ht @ u.T) + u_bias 
            it, ft, gt, ot = gates.chunk(4, dim=-1)
            it = torch.sigmoid(it)
            ft = torch.sigmoid(ft)
            gt = torch.tanh(gt)
            ot = torch.sigmoid(ot)

            ct = ft * ct + it * gt 
            ht = ot * torch.tanh(ct) #* drop_mask

            hiddens.append(ht)
            cell_state.append(ct)

        # Stacking states
        h = torch.stack(hiddens, dim=1)
        c = torch.stack(cell_state, dim=1)
        return h, c

    @staticmethod
    def backward(ctx, grad_out_h, grad_out_c):
        return None, None, None, None, None 

class LSTM_Cell(autograd.Function):

    @staticmethod
    def forward(ctx, wx, u, u_bias, ht, ct):

        hiddens = []
        cell_state = []

        # Sampling dropout mask
        #drop_mask = self._sample_drop_mask(wx)

        # Loop over time axis
        for k in range(wx.shape[1]):
            ht, ct = _LSTM_Cell.apply(wx[:, k], u, u_bias, ht, ct)

            hiddens.append(ht)
            cell_state.append(ct)

        # Stacking states
        h = torch.stack(hiddens, dim=1)
        c = torch.stack(cell_state, dim=1)
        return h, c

    @staticmethod
    def backward(ctx, grad_out_h, grad_out_c):
        return None, None, None, None, None 

class _LSTM_Cell(autograd.Function):

    @staticmethod
    def forward(ctx, wx,  u, u_bias, ct, ht):
        # Loop over time axis
        hiddens = []
        cell_state = []
        save_it = []
        save_ft = []
        save_gt = []
        save_ot = []

        for k in range(wx.shape[1]):

            gates = wx[:, k] + (ht @ u.T) + u_bias 
            it, ft, gt, ot = gates.chunk(4, dim=1)

            it = torch.sigmoid(it)
            ft = torch.sigmoid(ft)
            gt = torch.tanh(gt)
            ot = torch.sigmoid(ot)

            save_it.append(it)
            save_ft.append(ft)
            save_gt.append(gt)
            save_ot.append(ot)

            ct = ft * ct + it * gt 
            ht =  ot * torch.tanh(ct)

            hiddens.append(ht)
            cell_state.append(ct)

        # Stacking states
        h = torch.stack(hiddens, dim=1)
        c = torch.stack(cell_state, dim=1)
        it = torch.stack(save_it, dim=1)
        ft = torch.stack(save_ft, dim=1)
        gt = torch.stack(save_gt, dim=1)
        ot = torch.stack(save_ot, dim=1)

        ctx.save_for_backward(it, ft, gt, ot, c, h, u, wx)

        return h, c

    @staticmethod
    def backward(ctx, grad_out_h, grad_out_c):
        it, ft, gt, ot, c, h, u, wx = ctx.saved_tensors

        
        dh_prev, dc_prev = 0, 0

        di = torch.zeros_like(it)
        df = torch.zeros_like(ft)
        dg = torch.zeros_like(gt)
        do = torch.zeros_like(ot)

        h_init = torch.zeros_like(h[:, 0])
        c_init = torch.zeros_like(c[:, 0])
        du = torch.zeros_like(u)
        for t in reversed(range(wx.shape[1])):

            dh = grad_out_h[:, t] + dh_prev
            dc = (1 - torch.tanh(c[:, t]) ** 2) * ot[:, t] * dh + dc_prev + grad_out_c[:, t]

            _di = dc  * gt[:, t] * ((1 - it[:, t]) * it[:, t])

            
            ct = c_init if t - 1 < 0 else c[:, t-1]

            _df = dc  * ct * ((1 - ft[:, t]) * ft[:, t])

            _dg = dc  *  it[:, t] * (1 - gt[:, t] ** 2)
            _do = dh * torch.tanh(c[:, t]) * ((1 - ot[:, t]) * ot[:, t])

            di[:, t] = _di
            df[:, t] = _df
            dg[:, t] = _dg
            do[:, t] = _do


            tmp = torch.cat((_di, _df, _dg, _do), 1)

            ht = h_init if t - 1 < 0 else h[:, t-1]
            
            du += tmp.T @ ht


            dh_prev = tmp @ u 
            dc_prev = dc * ft[:, t]
            
        dwx = torch.cat((di, df, dg, do), axis=2)

        return dwx, du, dwx, None, None




In [554]:
hidden_size = 2
input_size = 1
batch_size = 1
htt = torch.zeros(batch_size, hidden_size).double()
ctt = torch.zeros(batch_size, hidden_size).double()
x = torch.randn(batch_size, 2, input_size).double()
w = nn.Linear(input_size, 4 * hidden_size, bias=True).double()
u = nn.Linear(hidden_size, 4 * hidden_size, bias=True).double()

wx = w(x)
autograd.gradcheck(_LSTM_Cell.apply, [wx,  u.weight, u.bias,ctt, htt])

True

In [None]:
h, c = _lstm_cell(x, u, w, ht, ct)

NameError: name '_lstm_cell' is not defined

In [43]:
h_ad, c_ad = LSTM_Cell_Vanilla.apply(x, u.weight, u.bias, w.weight, w.bias, ht, ct)

In [44]:
def check(h, h_ad, c, c_ad):
    assert torch.allclose(h, h_ad), torch.allclose(c, c_ad)
    return True 

In [45]:
check(h, h_ad, c, c_ad)

True

In [46]:
h_ad, c_ad = LSTM_Cell.apply(x, u.weight, u.bias, w.weight, w.bias, ht, ct)

In [47]:
check(h, h_ad, c, c_ad)

True

In [462]:
class _LSTM_Cell(autograd.Function):

    @staticmethod
    def forward(ctx, wx,  u, u_bias, ht, ct):
        # Loop over time axis
        gates = wx + (ht @ u.T) + u_bias 
        it, ft, gt, ot = gates.chunk(4, dim=1)

        ctx.save_for_backward(it, ft, gt, ot, ct, ht, u)
        it = torch.sigmoid(it)
        ft = torch.sigmoid(ft)
        gt = torch.tanh(gt)
        ot = torch.sigmoid(ot)

        ct = ft * ct + it * gt 
        ht =  ot * torch.tanh(ct)
        ctx.ct = ct

        return ht, ct

    @staticmethod
    def backward(ctx, grad_out_h, grad_out_c):
        it, ft, gt, ot, ctt, htt, u = ctx.saved_tensors

        ui, uf, ug, uo = u.chunk(4, dim=0)
        ct = ctx.ct

        """
        do = dh * torch.tanh(ct) * ((1 - torch.sigmoid(ot)) * torch.sigmoid(ot))
        dg = dh * torch.sigmoid(ot) * (1 - torch.tanh(ct) ** 2) * torch.sigmoid(it) * (1 - torch.tanh(gt) ** 2)
        di = dh * torch.sigmoid(ot) * (1 - torch.tanh(ct) ** 2) * torch.tanh(gt) * ((1 - torch.sigmoid(it)) * torch.sigmoid(it))
        df = dh * torch.sigmoid(ot) * (1 - torch.tanh(ct) ** 2) * ctt * ((1 - torch.sigmoid(ft)) * torch.sigmoid(ft))
        """
        dh = grad_out_h
        dc = grad_out_c + dh * torch.sigmoid(ot) * (1 - torch.tanh(ct) ** 2)
        di = dc  * torch.tanh(gt) * ((1 - torch.sigmoid(it)) * torch.sigmoid(it))
        df = dc  * ctt * ((1 - torch.sigmoid(ft)) * torch.sigmoid(ft))
        dg = dc  *  torch.sigmoid(it) * (1 - torch.tanh(gt) ** 2)
        do = dh * torch.tanh(ct) * ((1 - torch.sigmoid(ot)) * torch.sigmoid(ot))

        
        dwx = torch.cat((di, df, dg, do), axis=1)
        du =  ( dwx.T @ htt ) 

        return dwx, du, dwx, dwx @ u, dc *  torch.sigmoid(ft)



In [463]:
hidden_size = 5
input_size = 5
batch_size = 8
htt = torch.randn(batch_size, hidden_size, requires_grad=True).double()
ctt = torch.randn(batch_size, hidden_size, requires_grad=True).double()
x = torch.randn(batch_size, 10, input_size).double()
w = nn.Linear(input_size, 4 * hidden_size, bias=True).double()
u = nn.Linear(hidden_size, 4 * hidden_size, bias=True).double()

wx = w(x)
autograd.gradcheck(_LSTM_Cell.apply, [wx[:, 0], u.weight, u.bias, htt, ctt])

True

In [130]:
class Linear(autograd.Function):

    @staticmethod
    def forward(ctx, x, w):
    
        ctx.save_for_backward(x, w)
        out = x @ w.T 

        return out

    @staticmethod
    def backward(ctx, dy):
        x, w = ctx.saved_tensors

    
        return  dy @ w, dy.T @ x 

In [58]:
hidden_size = 5
input_size = 5
batch_size = 1

x = torch.randn(batch_size,  input_size).double()
u = nn.Linear(hidden_size, hidden_size, bias=True).double()

autograd.gradcheck(Linear.apply, [x, u.weight])

True