# LSTM

![Alt text](assets/image.png)

## Forget Gate

![Alt text](assets/image-4.png)

$ f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) $

The forget gate decides how much of the previous cell state is retained.

Multiply the forget-gate weight $ W_f $ with the concatenation of the previous hidden state $ h_{t-1} $ and the current input $ x_t $, then add the forget-gate bias $ b_f $.

The sigmoid function squashes each entry of the result into a value between 0 and 1, giving the fraction of the corresponding cell-state entry to keep: values near 0 discard it, values near 1 keep it.
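
As a minimal sketch of the formula above, the forget gate can be written in plain Python without any framework. The function name `forget_gate` and the toy weights (hidden size 2, input size 1) are assumptions made for illustration only.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def forget_gate(W_f, b_f, h_prev, x_t):
    """f_t = sigmoid(W_f . [h_prev, x_t] + b_f), one value per cell-state entry."""
    concat = h_prev + x_t  # list concatenation stands in for [h_{t-1}, x_t]
    return [
        sigmoid(sum(w * v for w, v in zip(row, concat)) + b)
        for row, b in zip(W_f, b_f)
    ]

# Toy parameters (assumed): 2 hidden units, 1 input feature
W_f = [[0.5, -0.5, 1.0],
       [0.0, 0.0, 0.0]]
b_f = [0.0, 0.0]
f_t = forget_gate(W_f, b_f, h_prev=[0.2, 0.2], x_t=[1.0])
print(f_t)  # each entry lies strictly between 0 and 1
```

The second row of `W_f` is all zeros, so its gate value is exactly 0.5: half of that cell-state entry would be kept.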

## Input Gate

![Alt text](assets/image-5.png)

$ i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) $

The input gate decides which values to update in the cell state and by how much.

$ \tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C) $

$ \tilde{C}_t $ holds the candidate values that may be added to the cell state.
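
The two formulas in this section can be sketched side by side in plain Python. The helper names `affine` and `input_gate_and_candidate` and the toy weights are assumptions for illustration, not part of any library.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def affine(W, b, v):
    """Return W . v + b for a list-of-rows matrix W."""
    return [sum(w * x for w, x in zip(row, v)) + bias for row, bias in zip(W, b)]

def input_gate_and_candidate(W_i, b_i, W_C, b_C, h_prev, x_t):
    concat = h_prev + x_t  # [h_{t-1}, x_t]
    i_t = [sigmoid(z) for z in affine(W_i, b_i, concat)]        # how much to write
    c_tilde = [math.tanh(z) for z in affine(W_C, b_C, concat)]  # what to write
    return i_t, c_tilde

# Toy parameters (assumed): 2 hidden units, 1 input feature
W_i = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
b_i = [0.0, 0.0]
W_C = [[0.0, 0.0, 1.0], [0.0, 0.0, -1.0]]
b_C = [0.0, 0.0]
i_t, c_tilde = input_gate_and_candidate(
    W_i, b_i, W_C, b_C, h_prev=[0.1, -0.3], x_t=[2.0]
)
```

The gate `i_t` stays in (0, 1) because of the sigmoid, while the candidate `c_tilde` lies in (-1, 1) because of the tanh, so candidates can both raise and lower a cell-state entry.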

## Cell State

![Alt text](assets/image-1.png)

$ C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t $

The cell state is the long-term memory. The forget gate $ f_t $ scales down the previous cell state $ C_{t-1} $, and the input gate $ i_t $ scales the candidate values $ \tilde{C}_t $; their sum is the updated cell state. Here $ \odot $ denotes elementwise multiplication.
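
A small worked example may make the update concrete. The function name `update_cell_state` and the gate values below are hand-picked assumptions chosen to show the two extremes: keep everything, or overwrite everything.

```python
def update_cell_state(f_t, c_prev, i_t, c_tilde):
    """C_t = f_t * C_{t-1} + i_t * C~_t, applied entry by entry."""
    return [f * c + i * g for f, c, i, g in zip(f_t, c_prev, i_t, c_tilde)]

c_prev = [2.0, -1.0]
f_t = [1.0, 0.0]      # entry 0: keep fully, entry 1: forget fully
i_t = [0.0, 1.0]      # entry 0: write nothing, entry 1: write the candidate
c_tilde = [0.5, 0.5]

c_t = update_cell_state(f_t, c_prev, i_t, c_tilde)
print(c_t)  # [2.0, 0.5]
```

Entry 0 passes through unchanged, which is how information can survive across many steps; entry 1 is replaced by its candidate.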

## Output Gate

![Alt text](./assets/image-6.png)

$ o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) $

The output gate determines which values of the cell state are exposed in the hidden state.

$ h_t = o_t \odot \tanh(C_t) $

The hidden state $ h_t $ is the output of the step: the output gate multiplied elementwise ($ \odot $) with the $ \tanh $-squashed cell state. It also serves as the short-term memory passed to the next input in the sequence.
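
The output gate and hidden state can be sketched the same way in plain Python. The function name `output_and_hidden` and the toy values are assumptions for illustration.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def output_and_hidden(W_o, b_o, h_prev, x_t, c_t):
    concat = h_prev + x_t  # [h_{t-1}, x_t]
    o_t = [
        sigmoid(sum(w * v for w, v in zip(row, concat)) + b)
        for row, b in zip(W_o, b_o)
    ]
    # tanh bounds the cell state to (-1, 1) before the gate scales it
    h_t = [o * math.tanh(c) for o, c in zip(o_t, c_t)]
    return o_t, h_t

# Toy parameters (assumed): zero weights give o_t = 0.5 everywhere
W_o = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
b_o = [0.0, 0.0]
o_t, h_t = output_and_hidden(
    W_o, b_o, h_prev=[0.1, 0.2], x_t=[1.0], c_t=[0.0, 100.0]
)
```

Even with a very large cell-state entry (100.0 here), the corresponding hidden-state entry stays bounded, because both the tanh and the gate limit its magnitude to at most 1.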

The `hidden_size` parameter determines the size, or dimensions, of the hidden state.
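
To illustrate how `hidden_size` drives the parameter shapes, here is a small sketch. It assumes a single unidirectional layer without projection, where the weights of the four gates are stacked as in PyTorch's `weight_ih_l0` and `weight_hh_l0`. The helper names `lstm_layer_shapes` and `count_params` are hypothetical.

```python
def lstm_layer_shapes(input_size, hidden_size):
    """Parameter and state shapes for one LSTM layer (hypothetical helper)."""
    return {
        "weight_ih": (4 * hidden_size, input_size),   # 4 gates stacked row-wise
        "weight_hh": (4 * hidden_size, hidden_size),
        "bias_ih": (4 * hidden_size,),
        "bias_hh": (4 * hidden_size,),
        "h_t": (hidden_size,),  # hidden state, per sequence element
        "C_t": (hidden_size,),  # cell state has the same size
    }

def count_params(input_size, hidden_size):
    """Total number of learnable values under the assumptions above."""
    return 4 * hidden_size * (input_size + hidden_size + 2)

shapes = lstm_layer_shapes(input_size=10, hidden_size=20)
print(shapes["weight_hh"])   # (80, 20)
print(count_params(10, 20))  # 2560
```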

## Source

https://github.com/pytorch/pytorch/blob/9347a79f1c35c76535310111d1c50bada4e975a8/aten/src/ATen/native/RNN.cpp#L716

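Compute the pre-activations of all four gates in one pass: $ W_{hh} h_{t-1} + b_{hh} + W_{ih} x_t + b_{ih} $. Here $ W_{hh} $ and $ W_{ih} $ stack the hidden-to-hidden and input-to-hidden weights of the four gates, which is equivalent to multiplying a single weight matrix with the concatenation $ [h_{t-1}, x_t] $. When `pre_compute_input` is true, `input` already holds the input-to-hidden term.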
```cpp
const auto gates = params.linear_hh(hx).add_(
    pre_compute_input ? input : params.linear_ih(input));
```
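
The two linear maps in the snippet above give the same result as one matrix applied to the concatenated vector used in the gate formulas. A minimal check in plain Python, with toy matrices and biases omitted (both are assumptions for brevity):

```python
def matvec(W, v):
    return [sum(w * x for w, x in zip(row, v)) for row in W]

# Toy weights (assumed): 2 output rows, 2 hidden units, 1 input feature
W_hh = [[1.0, 2.0], [3.0, 4.0]]
W_ih = [[5.0], [6.0]]
h_prev = [1.0, -1.0]
x_t = [2.0]

# Split form: W_hh . h + W_ih . x
split = [a + b for a, b in zip(matvec(W_hh, h_prev), matvec(W_ih, x_t))]

# Fused form: [W_hh | W_ih] . [h, x]
W = [row_h + row_i for row_h, row_i in zip(W_hh, W_ih)]
fused = matvec(W, h_prev + x_t)

print(split, fused)  # [9.0, 11.0] [9.0, 11.0]
```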

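Split the result into four chunks along dimension 1 (input gate, forget gate, cell candidate, output gate) and apply each gate's activation in place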
```cpp
auto chunked_gates = gates.unsafe_chunk(4, 1);
auto ingate = chunked_gates[0].sigmoid_();
auto forgetgate = chunked_gates[1].sigmoid_();
auto cellgate = chunked_gates[2].tanh_();
auto outgate = chunked_gates[3].sigmoid_();
```

Update the cell state (long-term memory) and the hidden state (short-term memory)
```cpp
auto cy = (forgetgate * cx).add_(ingate * cellgate);
auto hy = outgate * cy.tanh();
return std::make_tuple(std::move(hy), std::move(cy));
```
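
For readers more comfortable with Python, the three C++ snippets can be mirrored as one function. This is a simplified sketch, not the PyTorch implementation: it handles a single unbatched sample with plain lists, and `lstm_cell_step` and `linear` are hypothetical names. The variable names and the chunk order (input, forget, cell, output) follow the quoted source.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def linear(W, b, v):
    return [sum(w * x for w, x in zip(row, v)) + bias for row, bias in zip(W, b)]

def lstm_cell_step(x, hx, cx, W_ih, b_ih, W_hh, b_hh):
    hidden = len(hx)
    # gates = linear_hh(hx) + linear_ih(input), length 4 * hidden
    gates = [a + b for a, b in zip(linear(W_hh, b_hh, hx), linear(W_ih, b_ih, x))]
    # split into 4 equal chunks, in the same order as the C++ code
    chunks = [gates[k * hidden:(k + 1) * hidden] for k in range(4)]
    ingate = [sigmoid(z) for z in chunks[0]]
    forgetgate = [sigmoid(z) for z in chunks[1]]
    cellgate = [math.tanh(z) for z in chunks[2]]
    outgate = [sigmoid(z) for z in chunks[3]]
    # cell state (long term), then hidden state (short term)
    cy = [f * c + i * g for f, c, i, g in zip(forgetgate, cx, ingate, cellgate)]
    hy = [o * math.tanh(c) for o, c in zip(outgate, cy)]
    return hy, cy

# Toy parameters (assumed): hidden size 1, input size 1, all weights zero
W_ih = [[0.0] for _ in range(4)]
W_hh = [[0.0] for _ in range(4)]
b_ih = [0.0] * 4
b_hh = [0.0] * 4
hy, cy = lstm_cell_step([1.0], [0.0], [2.0], W_ih, b_ih, W_hh, b_hh)
print(cy)  # [1.0]
```

With all-zero weights every sigmoid gate is 0.5 and the candidate is 0, so the cell state is simply halved from 2.0 to 1.0.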