This notebook is a POC for the LSTM basic structure.

The goal is to define all the components of the LSTM and to define the forward pass.

## Structure

The classes below are defining the needed parameters for the forward pass.

In [85]:
from pydantic import BaseModel
import numpy as np

class ForwardParameters(BaseModel):
    Wf: np.ndarray # two-dimensional matrix containing weights for forget gate.
    bf: np.ndarray # one-dimensional vector containing biases for forget gate.
    Wi: np.ndarray # two-dimensional matrix containing weights for input gate.
    bi: np.ndarray # one-dimensional vector containing biases for input gate.
    Wc: np.ndarray # two-dimensional matrix containing weights for candidate cell state.
    bc: np.ndarray # one-dimensional vector containing biases for candidate cell state.
    Wo: np.ndarray # two-dimensional matrix containing weights for output gate.
    bo: np.ndarray # one-dimensional vector containing biases for output gate.
    Wy: np.ndarray # two-dimensional weight matrix relating the hidden-state to the output.
    by: np.ndarray # one-dimensional vector containing biases relating the hidden-state to the output.
    
    class Config:
        arbitrary_types_allowed = True
        
    


## Forward pass

In [86]:
def sigmoid(x: np.ndarray):
    """Compute the sigmoid activation function."""
    return 1 / (1 + np.exp(-x))

def softmax(x: np.ndarray):
    """Compute softmax values for each sets of values in x."""
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum(axis=0)

def lstm_forward(x: np.ndarray, ht_0: np.ndarray, parameters: ForwardParameters)-> tuple[
    list[np.ndarray], list[np.ndarray], list[np.ndarray]
    ]:
    """
    The forward pass of the LSTM-cell.
    
    Arguments:
    x: input data for all time steps, shape (n_x, m, T_x)
    ht_0: initial hidden state, shape (n_a, m)
    parameters: Parameters
    
    Returns:
    h: hidden states for all time steps, shape (n_a, m, T_x)
    c: cell states for all time steps, shape (n_a, m, T_x)
    y: predictions for all time steps, shape (n_y, m, T_x)
    """
    T_x = x.shape[0] # number of time steps
    m = x.shape[1] # number of examples
    
    # initialize hidden state and cell state
    ht = ht_0
    ct = np.zeros(m) # initialize cell state as zero vector
    
    ft, it, cct, ot = None, None, None, None
    h, c, y = [], [], []
    
    for t in range(T_x):
        xt = x[:,t]
        
        u_cct, v_cct= parameters.Wc
        cct = np.tanh(np.dot(u_cct, xt) + np.dot(v_cct, ht) + parameters.bc)
        print(f'cct: {cct}')
        
        u_ft, v_ft = parameters.Wf
        ft = sigmoid(np.dot(u_ft, xt) + np.dot(v_ft, ht) + parameters.bf)
        print(f'ft: {ft}')
        
        u_it, v_it = parameters.Wi
        it = sigmoid(np.dot(u_it, xt) + np.dot(v_it, ht) + parameters.bi)
        print(f'it: {it}')
        
        ct = np.dot(ft, ct) + np.dot(ft, cct)
        print(f'ct: {ct}')
        c.append(ct)
        
        u_ot, v_ot = parameters.Wo
        ot = sigmoid(np.dot(u_ot, xt) + np.dot(v_ot, ht) + parameters.bo)
        print(f'ot: {ot}')
        
        ht = np.dot(ot, np.tanh(ct))
        print(f'ht: {ht}\n')
        h.append(ht)
        
        yt_pred = softmax(np.dot(parameters.Wy, ht) + parameters.by)
        y.append(yt_pred)
    
    return h, c, y

### Forward pass example

In [87]:
# input = "I ate cake" vectorized as [1, 0, 0], [0, 1, 0], [0, 0, 1]
input = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
ht_0 = np.array([0,0,0])

# random parameters
parameters = ForwardParameters(
    Wf=np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]),
    bf=np.array([0.1, 0.2, 0.3]),
    Wi=np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]),
    bi=np.array([0.1, 0.2, 0.3]),
    Wc=np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]),
    bc=np.array([0.1, 0.2, 0.3]),
    Wo=np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]),
    bo=np.array([0.1, 0.2, 0.3]),
    Wy=np.array([0.1, 0.2, 0.3]),
    by=np.array([0.1, 0.2, 0.3])
)

forward_pass_result = lstm_forward(input, ht_0, parameters)

print(f'ht: {forward_pass_result[0]}')
print(f'ct: {forward_pass_result[1]}')
print(f'yt: {forward_pass_result[2]}')

cct: [0.19737532 0.29131261 0.37994896]
ft: [0.549834   0.57444252 0.59868766]
it: [0.549834   0.57444252 0.59868766]
ct: 0.5033367667404957
ot: [0.549834   0.57444252 0.59868766]
ht: [0.25552837 0.26696486 0.27823249]

cct: [0.60603656 0.66550634 0.71757764]
ft: [0.66877134 0.69053748 0.71149034]
it: [0.66877134 0.69053748 0.71149034]
ct: [1.71202372 1.72297942 1.73352576]
ot: [0.66877134 0.69053748 0.71149034]
ht: 1.9428700854726468

cct: [0.82655019 0.89985085 0.94313866]
ft: [0.76443463 0.81327541 0.85392486]
it: [0.76443463 0.81327541 0.85392486]
ct: 6.359327440546947
ot: [0.76443463 0.81327541 0.85392486]
ht: [0.76443005 0.81327054 0.85391974]

ht: [array([0.25552837, 0.26696486, 0.27823249]), 1.9428700854726468, array([0.76443005, 0.81327054, 0.85391974])]
ct: [0.5033367667404957, array([1.71202372, 1.72297942, 1.73352576]), 6.359327440546947]
yt: [array([0.30060961, 0.33222499, 0.3671654 ]), array([0.2413368 , 0.32391479, 0.43474841]), array([0.30060961, 0.33222499, 0.3671654 ]