https://towardsdatascience.com/building-a-lstm-by-hand-on-pytorch-59c02a4ec091

Pytorch Code


In [5]:
import math
import torch
import  torch.nn as nn

In [6]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [7]:
# torch.Tensor(2, 3) is not given any values here, so the numbers printed
# below are arbitrary (see the cell output).
print(torch.Tensor(2,3))
# The same tensor wrapped in nn.Parameter prints with requires_grad=True.
print(nn.Parameter(torch.Tensor(2,3)))

tensor([[1.1037e-05, 6.3016e-10, 6.6770e+22],
        [2.1006e+20, 5.1432e-11, 4.2330e+21]])
Parameter containing:
tensor([[1.1037e-05, 6.3016e-10, 6.6770e+22],
        [2.1006e+20, 5.1432e-11, 4.2330e+21]], requires_grad=True)


nn.Parameter :  a Tensor subclass; when a Parameter is assigned as a Module attribute it is automatically added to the module's list of parameters
                and will appear e.g. in the :meth:`~Module.parameters` iterator.
                A module's learnable state (weights, biases) is stored using this subclass.

In [11]:
class CustomLSTM(nn.Module):
    """Single-layer LSTM with every gate written out explicitly.

    Each gate (input ``i``, forget ``f``, candidate cell ``c``, output ``o``)
    owns three parameters:

    * ``u_*`` -- input projection, shape ``(input_sz, hidden_sz)``
    * ``v_*`` -- recurrent projection, shape ``(hidden_sz, hidden_sz)``
    * ``b_*`` -- bias, shape ``(hidden_sz,)``

    Args:
        input_sz: Length of the feature vector of one sequence element.
        hidden_sz: Length of the hidden and cell state vectors.
    """

    def __init__(self, input_sz: int, hidden_sz: int):
        super().__init__()
        self.input_size = input_sz
        self.hidden_size = hidden_sz

        # i_t
        self.u_i = nn.Parameter(torch.empty(input_sz, hidden_sz))
        self.v_i = nn.Parameter(torch.empty(hidden_sz, hidden_sz))
        self.b_i = nn.Parameter(torch.empty(hidden_sz))

        # f_t
        self.u_f = nn.Parameter(torch.empty(input_sz, hidden_sz))
        self.v_f = nn.Parameter(torch.empty(hidden_sz, hidden_sz))
        self.b_f = nn.Parameter(torch.empty(hidden_sz))

        # c_t
        self.u_c = nn.Parameter(torch.empty(input_sz, hidden_sz))
        self.v_c = nn.Parameter(torch.empty(hidden_sz, hidden_sz))
        self.b_c = nn.Parameter(torch.empty(hidden_sz))

        # o_t
        self.u_o = nn.Parameter(torch.empty(input_sz, hidden_sz))
        self.v_o = nn.Parameter(torch.empty(hidden_sz, hidden_sz))
        self.b_o = nn.Parameter(torch.empty(hidden_sz))

        self.init_weights()

    def init_weights(self):
        """Fill every parameter with U(-k, k), where k = 1 / sqrt(hidden_size)."""
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            nn.init.uniform_(weight, -stdv, stdv)

    def forward(self, x, init_states=None):
        """Run the LSTM over a batch of sequences.

        Args:
            x: Tensor of shape ``(batch_size, seq_size, input_size)``.
            init_states: Optional ``(h_0, c_0)`` tuple, each of shape
                ``(batch_size, hidden_size)``. Zeros are used when omitted.

        Returns:
            ``(hidden_seq, (h_t, c_t))`` where ``hidden_seq`` has shape
            ``(batch_size, seq_size, hidden_size)`` and ``h_t`` / ``c_t`` are
            the states after the last sequence element.
        """
        batch_size, seq_size, _ = x.size()

        if init_states is None:
            h_t = torch.zeros(batch_size, self.hidden_size, device=x.device)
            c_t = torch.zeros(batch_size, self.hidden_size, device=x.device)
        else:
            h_t, c_t = init_states

        hidden_seq = []
        for t in range(seq_size):
            x_t = x[:, t, :]
            # The projections are matrix products (@), not element-wise
            # products: (batch, input) @ (input, hidden) -> (batch, hidden).
            i_t = torch.sigmoid(x_t @ self.u_i + h_t @ self.v_i + self.b_i)
            f_t = torch.sigmoid(x_t @ self.u_f + h_t @ self.v_f + self.b_f)
            g_t = torch.tanh(x_t @ self.u_c + h_t @ self.v_c + self.b_c)
            o_t = torch.sigmoid(x_t @ self.u_o + h_t @ self.v_o + self.b_o)
            c_t = f_t * c_t + i_t * g_t
            h_t = o_t * torch.tanh(c_t)

            hidden_seq.append(h_t)

        # Stack the per-step states along a new sequence axis so the result is
        # (batch, sequence, hidden), ready to be returned.
        hidden_seq = torch.stack(hidden_seq, dim=1)
        return hidden_seq, (h_t, c_t)
    

In [12]:
# Build a tiny LSTM (2 input features, 2 hidden units) and ask for its
# parameter iterator; a falsy first argument means recurse=False.
lstm_obj = CustomLSTM(input_sz=2, hidden_sz=2)
lstm_obj.parameters(recurse=False)

<generator object Module.parameters at 0x000001FD4C4753C8>

In [None]:
input_shape => (batch_size, sequence_length, feature_length)

If the sequence is a sentence, e.g. "I have a cat", the sequence length is 4.
In practice we use a fixed sequence length, say 20, and pad shorter sequences or truncate longer ones.
Each word is one element of the sequence and is represented by a feature vector of length feature_length.

e.g. an embedding vector of length 300,
    or a one-hot encoding where every vocabulary entry is 0 or 1;
    in that case the feature length of a word is the size of the full vocabulary,
    i.e. the total number of words
    
batch_size: the number of sequences in a batch
    
    
    
The input weight matrix multiplies each element of the sequence and has
shape (feature_length, length_of_hidden_state).
The hidden state for each element in the sequence has shape (batch_size, length_of_hidden_state).
The output has shape (batch_size, sequence_length, length_of_hidden_state).
The weight matrix that multiplies the previous output (hidden state) must have shape (length_of_hidden_state, length_of_hidden_state).

    
The forward pass receives the states (h_t, c_t) as an optional argument;
they are set to zero if nothing is carried forward.
The LSTM equations are applied to each element of the sequence (each word or feature vector) in turn,
and the resulting (h_t, c_t) are passed on as the states for the next element of the sequence.

Finally: return the predictions (the hidden state at every step) and the tuple of last states.



In [15]:
class CustomLSTM(nn.Module):
    """Single-layer LSTM with the four gate projections fused into one matrix.

    ``W`` has shape ``(input_sz, 4 * hidden_sz)``, ``U`` has shape
    ``(hidden_sz, 4 * hidden_sz)`` and ``bias`` has shape ``(4 * hidden_sz,)``.
    Along the last axis the four blocks of width ``hidden_sz`` are, in order,
    the input, forget, candidate-cell and output gates.

    Args:
        input_sz: Length of the feature vector of one sequence element.
        hidden_sz: Length of the hidden and cell state vectors.
    """

    def __init__(self, input_sz, hidden_sz):
        super().__init__()
        self.input_sz = input_sz
        self.hidden_size = hidden_sz
        self.W = nn.Parameter(torch.empty(input_sz, hidden_sz * 4))
        self.U = nn.Parameter(torch.empty(hidden_sz, hidden_sz * 4))
        self.bias = nn.Parameter(torch.empty(hidden_sz * 4))
        self.init_weights()

    def init_weights(self):
        """Fill every parameter with U(-k, k), where k = 1 / sqrt(hidden_size)."""
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            nn.init.uniform_(weight, -stdv, stdv)

    def forward(self, x, init_states=None):
        """Run the LSTM; assumes x is of shape (batch, sequence, feature).

        Args:
            x: Input tensor of shape ``(batch, sequence, input_sz)``.
            init_states: Optional ``(h_0, c_0)`` tuple, each of shape
                ``(batch, hidden_size)``. Zeros are used when omitted.

        Returns:
            ``(hidden_seq, (h_t, c_t))`` where ``hidden_seq`` has shape
            ``(batch, sequence, hidden_size)`` and ``h_t`` / ``c_t`` are the
            states after the last sequence element.
        """
        bs, seq_sz, _ = x.size()
        if init_states is None:
            h_t = torch.zeros(bs, self.hidden_size, device=x.device)
            c_t = torch.zeros(bs, self.hidden_size, device=x.device)
        else:
            h_t, c_t = init_states

        hidden_seq = []
        for t in range(seq_sz):
            x_t = x[:, t, :]
            # Batch the four gate computations into a single matrix
            # multiplication; `@` is required here, `*` would be element-wise.
            gates = x_t @ self.W + h_t @ self.U + self.bias
            # Four equal blocks of width hidden_size: input, forget, cell, output.
            i_raw, f_raw, g_raw, o_raw = gates.chunk(4, dim=1)
            i_t = torch.sigmoid(i_raw)
            f_t = torch.sigmoid(f_raw)
            g_t = torch.tanh(g_raw)
            o_t = torch.sigmoid(o_raw)
            c_t = f_t * c_t + i_t * g_t
            h_t = o_t * torch.tanh(c_t)
            hidden_seq.append(h_t)
        # Stack along a new sequence axis -> (batch, sequence, feature).
        hidden_seq = torch.stack(hidden_seq, dim=1)
        return hidden_seq, (h_t, c_t)

Modified LSTM

In [16]:
class CustomLSTM(nn.Module):
    """Fused-gate LSTM with an optional peephole-style variant.

    ``W`` has shape ``(input_sz, 4 * hidden_sz)``, ``U`` has shape
    ``(hidden_sz, 4 * hidden_sz)`` and ``bias`` has shape ``(4 * hidden_sz,)``.
    Along the last axis the four blocks of width ``hidden_sz`` are, in order,
    the input, forget, candidate-cell and output gates.

    With ``peephole=False`` this is the standard LSTM recurrence. With
    ``peephole=True`` the gates are driven by the previous cell state instead
    of the previous hidden state, the candidate is a sigmoid of the input-only
    projection, and the hidden state is ``tanh(o_t * c_t)``.

    Args:
        input_sz: Length of the feature vector of one sequence element.
        hidden_sz: Length of the hidden and cell state vectors.
        peephole: Select the peephole-style recurrence described above.
    """

    def __init__(self, input_sz, hidden_sz, peephole=False):
        super().__init__()
        self.input_sz = input_sz
        self.hidden_size = hidden_sz
        self.peephole = peephole
        self.W = nn.Parameter(torch.empty(input_sz, hidden_sz * 4))
        self.U = nn.Parameter(torch.empty(hidden_sz, hidden_sz * 4))
        self.bias = nn.Parameter(torch.empty(hidden_sz * 4))
        self.init_weights()

    def init_weights(self):
        """Fill every parameter with U(-k, k), where k = 1 / sqrt(hidden_size)."""
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            nn.init.uniform_(weight, -stdv, stdv)

    def forward(self, x,
                init_states=None):
        """Run the LSTM; assumes x is of shape (batch, sequence, feature).

        Args:
            x: Input tensor of shape ``(batch, sequence, input_sz)``.
            init_states: Optional ``(h_0, c_0)`` tuple, each of shape
                ``(batch, hidden_size)``. Zeros are used when omitted.

        Returns:
            ``(hidden_seq, (h_t, c_t))`` where ``hidden_seq`` has shape
            ``(batch, sequence, hidden_size)`` and ``h_t`` / ``c_t`` are the
            states after the last sequence element.
        """
        bs, seq_sz, _ = x.size()
        if init_states is None:
            h_t = torch.zeros(bs, self.hidden_size, device=x.device)
            c_t = torch.zeros(bs, self.hidden_size, device=x.device)
        else:
            h_t, c_t = init_states

        HS = self.hidden_size
        hidden_seq = []
        for t in range(seq_sz):
            x_t = x[:, t, :]

            # Batch the gate computations into a single matrix multiplication.
            # The recurrent term uses the cell state in the peephole variant
            # and the hidden state otherwise.
            recurrent_input = c_t if self.peephole else h_t
            gates = x_t @ self.W + recurrent_input @ self.U + self.bias

            i_t = torch.sigmoid(gates[:, :HS])         # input
            f_t = torch.sigmoid(gates[:, HS:HS * 2])   # forget
            o_t = torch.sigmoid(gates[:, HS * 3:])     # output

            if self.peephole:
                # Candidate depends on the input only (no recurrent term).
                g_t = torch.sigmoid(x_t @ self.W + self.bias)[:, HS * 2:HS * 3]
                c_t = f_t * c_t + i_t * g_t
                h_t = torch.tanh(o_t * c_t)
            else:
                g_t = torch.tanh(gates[:, HS * 2:HS * 3])
                c_t = f_t * c_t + i_t * g_t
                h_t = o_t * torch.tanh(c_t)

            hidden_seq.append(h_t)

        # Stack along a new sequence axis -> (batch, sequence, feature).
        hidden_seq = torch.stack(hidden_seq, dim=1)

        return hidden_seq, (h_t, c_t)