# Thinking in tensors in PyTorch

Hands-on training by [Piotr Migdał](https://p.migdal.pl) (2019).

Version for [AI & NLP Workshop Day](https://nlpday.pl/), 31 May 2019, Warsaw, Poland: **Understanding LSTM and GRU networks in PyTorch**.



## NLP & AI: 4. LSTM GRU anatomy


[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/stared/thinking-in-tensors-writing-in-pytorch/blob/master/extra/4%20LSTM%20GRU%20anatomy.ipynb)
 

In [None]:
import torch
from torch import nn

## LSTM

See the [`nn.LSTM` documentation](https://pytorch.org/docs/stable/nn.html#lstm) for details.

In [None]:
lstm = nn.LSTM(5, 3)  # input_size=5, hidden_size=3

In [None]:
# Default layout (batch_first=False): (L, B, C)
# L = 8 (sequence length)
# B = 1 (batch size)
# C = 5 (input features, must equal input_size)
x = torch.randn(8, 1, 5)
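The (L, B, C) order above is PyTorch's default. A minimal sketch, assuming you may prefer batch-major tensors, comparing it with `batch_first=True` (the seed and variable names are only for illustration). Both modules share the same weights, so the only difference is the axis order of the input and output; the returned hidden and cell states keep their (num_layers, B, H) shape either way.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Default layout (batch_first=False): (L, B, C)
lstm_seq_first = nn.LSTM(5, 3)
# Same architecture, but expecting (B, L, C)
lstm_batch_first = nn.LSTM(5, 3, batch_first=True)
# Copy the weights so both modules compute the same function
lstm_batch_first.load_state_dict(lstm_seq_first.state_dict())

x_seq_first = torch.randn(8, 1, 5)           # (L, B, C)
x_batch_first = x_seq_first.transpose(0, 1)  # (B, L, C)

out_seq_first, (h_seq_first, _) = lstm_seq_first(x_seq_first)
out_batch_first, (h_batch_first, _) = lstm_batch_first(x_batch_first)

print(out_seq_first.shape)    # torch.Size([8, 1, 3])
print(out_batch_first.shape)  # torch.Size([1, 8, 3])
# batch_first does not apply to the returned states
print(h_batch_first.shape)    # torch.Size([1, 1, 3])
print(torch.allclose(out_seq_first, out_batch_first.transpose(0, 1), atol=1e-6))  # True
```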

In [None]:
output, (hidden, cell) = lstm(x)

In [None]:
output

In [None]:
hidden

In [None]:
cell

In [None]:
output[-1] == hidden  # element-wise: last time step of output vs. final hidden state
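`output` holds the hidden state at every time step, while `hidden` holds only the one from the final step (and `cell` has no per-step counterpart at all). A minimal sketch of the same check that returns a single boolean instead of an element-wise tensor, assuming a fresh seeded LSTM of the same size:

```python
import torch
from torch import nn

torch.manual_seed(0)
lstm = nn.LSTM(5, 3)
x = torch.randn(8, 1, 5)  # (L, B, C)

output, (hidden, cell) = lstm(x)

# output: (L, B, H), hidden state at every time step
# hidden: (num_layers, B, H), hidden state at the final time step only
print(output.shape, hidden.shape)  # torch.Size([8, 1, 3]) torch.Size([1, 1, 3])

# `output[-1] == hidden` broadcasts (1, 3) against (1, 1, 3);
# indexing hidden[0] aligns the shapes so allclose gives one True/False
same = torch.allclose(output[-1], hidden[0], atol=1e-6)
print(same)  # True
```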

## Step by step

In [None]:
# Feed the first half, then continue from its final (hidden, cell) state
output1, (hidden1, cell1) = lstm(x[:4])
output2, (hidden2, cell2) = lstm(x[4:], (hidden1, cell1))

In [None]:
output2
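Carrying `(hidden, cell)` across the split makes the two chunks behave like one pass over the whole sequence. A minimal sketch verifying this with a fresh seeded LSTM (the seed is only for reproducibility):

```python
import torch
from torch import nn

torch.manual_seed(0)
lstm = nn.LSTM(5, 3)
x = torch.randn(8, 1, 5)

output, (hidden, cell) = lstm(x)

output1, (hidden1, cell1) = lstm(x[:4])
output2, (hidden2, cell2) = lstm(x[4:], (hidden1, cell1))

# Concatenating the chunk outputs along time recovers the full-sequence output
print(torch.allclose(torch.cat([output1, output2]), output, atol=1e-6))  # True
# The state after the second chunk equals the state after the full run
print(torch.allclose(hidden2, hidden, atol=1e-6))  # True
print(torch.allclose(cell2, cell, atol=1e-6))      # True

# Without passing the state, the second chunk starts from zeros instead,
# so output2_fresh generally differs from output2
output2_fresh, _ = lstm(x[4:])
```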

## Iteration

In [None]:
lstm
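The module's repr only shows the sizes. A minimal sketch listing the actual parameters: PyTorch stacks the four gates (input, forget, cell candidate, output) along the first dimension, so every weight and bias has 4 × hidden_size = 12 rows.

```python
from torch import nn

lstm = nn.LSTM(5, 3)

for name, param in lstm.named_parameters():
    print(name, tuple(param.shape))
# weight_ih_l0 (12, 5)
# weight_hh_l0 (12, 3)
# bias_ih_l0 (12,)
# bias_hh_l0 (12,)

# 12*5 + 12*3 + 12 + 12
n_params = sum(p.numel() for p in lstm.parameters())
print(n_params)  # 120
```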

In [None]:
# Start from zero states of shape (num_layers, B, H) = (1, 1, 3)
hidden = torch.zeros(1, 1, 3)
cell = torch.zeros(1, 1, 3)
# Feed one time step at a time, carrying the state forward
for token in x:
    output, (hidden, cell) = lstm(token.unsqueeze(0), (hidden, cell))
    print(output)
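Each step of that loop applies the LSTM cell equations from the PyTorch docs. A minimal sketch reimplementing one step by hand with the module's own weights; `lstm_step` is a hypothetical helper written for illustration, assuming a single-layer, unidirectional `nn.LSTM` with biases.

```python
import torch
from torch import nn

torch.manual_seed(0)
lstm = nn.LSTM(5, 3)
x = torch.randn(8, 1, 5)

def lstm_step(x_t, h, c, lstm):
    """Hypothetical helper: one LSTM time step. x_t: (B, C); h, c: (B, H)."""
    gates = (x_t @ lstm.weight_ih_l0.T + lstm.bias_ih_l0
             + h @ lstm.weight_hh_l0.T + lstm.bias_hh_l0)  # (B, 4H)
    i, f, g, o = gates.chunk(4, dim=1)  # PyTorch order: input, forget, cell, output
    i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
    g = torch.tanh(g)
    c = f * c + i * g        # new cell state
    h = o * torch.tanh(c)    # new hidden state
    return h, c

h = torch.zeros(1, 3)
c = torch.zeros(1, 3)
steps = []
with torch.no_grad():
    for x_t in x:  # x_t: (B, C) = (1, 5)
        h, c = lstm_step(x_t, h, c, lstm)
        steps.append(h)
    manual_output = torch.stack(steps)  # (L, B, H)
    output, (hidden, cell) = lstm(x)

print(torch.allclose(manual_output, output, atol=1e-5))  # True
print(torch.allclose(c, cell[0], atol=1e-5))             # True
```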


## GRU

See the [`nn.GRU` documentation](https://pytorch.org/docs/stable/nn.html#gru) for details.

In [None]:
gru = nn.GRU(5, 3)

In [None]:
# Unlike LSTM, GRU has no separate cell state: it returns hidden instead of (hidden, cell)
output, hidden = gru(x)

In [None]:
output

In [None]:
hidden
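The GRU update follows the equations in the PyTorch docs: a reset gate, an update gate, and a candidate state, with the hidden state doubling as the only memory. A minimal sketch reproducing `nn.GRU` by hand; `gru_step` is a hypothetical helper, assuming a single-layer, unidirectional GRU with biases.

```python
import torch
from torch import nn

torch.manual_seed(0)
gru = nn.GRU(5, 3)
x = torch.randn(8, 1, 5)

def gru_step(x_t, h, gru):
    """Hypothetical helper: one GRU time step. x_t: (B, C); h: (B, H)."""
    gi = x_t @ gru.weight_ih_l0.T + gru.bias_ih_l0  # (B, 3H)
    gh = h @ gru.weight_hh_l0.T + gru.bias_hh_l0    # (B, 3H)
    i_r, i_z, i_n = gi.chunk(3, dim=1)  # PyTorch order: reset, update, new
    h_r, h_z, h_n = gh.chunk(3, dim=1)
    r = torch.sigmoid(i_r + h_r)
    z = torch.sigmoid(i_z + h_z)
    n = torch.tanh(i_n + r * h_n)  # reset gate scales the hidden term, bias included
    return (1 - z) * n + z * h

h = torch.zeros(1, 3)
steps = []
with torch.no_grad():
    for x_t in x:
        h = gru_step(x_t, h, gru)
        steps.append(h)
    manual_output = torch.stack(steps)  # (L, B, H)
    output, hidden = gru(x)

print(torch.allclose(manual_output, output, atol=1e-5))  # True
```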

## Bidirectional LSTM

See also: [Understanding Bidirectional RNN in PyTorch](https://towardsdatascience.com/understanding-bidirectional-rnn-in-pytorch-5bd25a5dd66) by Ceshine Lee

In [None]:
bilstm = nn.LSTM(5, 3, bidirectional=True)

In [None]:
output, (hidden, cell) = bilstm(x)

In [None]:
output.size()

In [None]:
hidden.size()

In [None]:
cell.size()

In [None]:
output

In [None]:
hidden
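In a bidirectional LSTM the last dimension of `output` concatenates the forward and backward directions, while `hidden` stacks one final state per direction. A minimal sketch (assuming a fresh seeded model) showing where each final state sits in `output`, and that the backward half is just a second LSTM, with the `*_reverse` weights, run over the flipped sequence:

```python
import torch
from torch import nn

torch.manual_seed(0)
bilstm = nn.LSTM(5, 3, bidirectional=True)
x = torch.randn(8, 1, 5)
output, (hidden, cell) = bilstm(x)

print(output.shape)  # torch.Size([8, 1, 6]): forward and backward features
print(hidden.shape)  # torch.Size([2, 1, 3]): one final state per direction

H = bilstm.hidden_size
forward_out, backward_out = output[..., :H], output[..., H:]

# Forward reads t = 0..7, so its final state is at the last time step
print(torch.allclose(forward_out[-1], hidden[0], atol=1e-6))  # True
# Backward reads t = 7..0, so its final state is at the first time step
print(torch.allclose(backward_out[0], hidden[1], atol=1e-6))  # True

# Rebuild the backward direction as a plain LSTM on the reversed sequence
backward_lstm = nn.LSTM(5, 3)
with torch.no_grad():
    backward_lstm.weight_ih_l0.copy_(bilstm.weight_ih_l0_reverse)
    backward_lstm.weight_hh_l0.copy_(bilstm.weight_hh_l0_reverse)
    backward_lstm.bias_ih_l0.copy_(bilstm.bias_ih_l0_reverse)
    backward_lstm.bias_hh_l0.copy_(bilstm.bias_hh_l0_reverse)
    rev_out, _ = backward_lstm(x.flip(0))
print(torch.allclose(rev_out.flip(0), backward_out, atol=1e-5))  # True
```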

## Multi-layer LSTM

In [None]:
multilstm = nn.LSTM(5, 3, num_layers=2)

In [None]:
output, (hidden, cell) = multilstm(x)

In [None]:
output.size()

In [None]:
hidden.size()

In [None]:
cell.size()

In [None]:
output

In [None]:
hidden
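With `num_layers=2`, `hidden` and `cell` hold one final state per layer, but `output` exposes only the top layer. A minimal sketch, assuming the default `dropout=0`, showing that the stacked module is equivalent to feeding one single-layer LSTM's output into another (`layer0` and `layer1` are illustrative names):

```python
import torch
from torch import nn

torch.manual_seed(0)
multilstm = nn.LSTM(5, 3, num_layers=2)
x = torch.randn(8, 1, 5)
output, (hidden, cell) = multilstm(x)

print(hidden.shape)  # torch.Size([2, 1, 3]): one final state per layer
# output comes from the top layer, so it matches that layer's final state
print(torch.allclose(output[-1], hidden[-1], atol=1e-6))  # True

# Rebuild the same network from two single-layer LSTMs
layer0 = nn.LSTM(5, 3)
layer1 = nn.LSTM(3, 3)  # consumes layer 0's hidden states (size 3)
with torch.no_grad():
    for name in ["weight_ih", "weight_hh", "bias_ih", "bias_hh"]:
        getattr(layer0, f"{name}_l0").copy_(getattr(multilstm, f"{name}_l0"))
        getattr(layer1, f"{name}_l0").copy_(getattr(multilstm, f"{name}_l1"))
    out0, (h0, c0) = layer0(x)
    out1, (h1, c1) = layer1(out0)

print(torch.allclose(out1, output, atol=1e-5))                 # True
print(torch.allclose(torch.cat([h0, h1]), hidden, atol=1e-5))  # True
```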