# Predicting parts of speech with an LSTM

Let's preview the end result. We want to take a sentence and output the part of speech for each word in that sentence. Something like this:

**Code**

```python
new_sentence = "I is a teeth"

...

predictions = model(processed_sentence)

...
```

**Output**

```text
I     => Noun
is    => Verb
a     => Determiner
teeth => Noun
```

```python
def ps(s):
    """Process String: convert a string into a list of characters, with surrounding whitespace and all spaces removed."""
    line = s.strip().replace(" ", "")
    return list(line)
```

```python
import os

# Read questions and answers from the file, which alternates question and answer lines
FILEPATH = os.path.expanduser("~/tonyaradzwa.github.io/train_data/arithmetic__mixed.txt")

questions = []
answers = []

# Each question becomes a list of characters, e.g. ["1", "+", "3"].
# Each answer is left-padded with "0" to the length of its question, so every
# input character has exactly one target character (its "tag").
with open(FILEPATH, "r") as f:
    for _ in range(5000):
        line_q = f.readline().strip().replace(" ", "")
        line_a = f.readline().strip().replace(" ", "")

        line_q_list = list(line_q)
        line_a_list = list(line_a)

        # Left-pad the answer; nothing is added if the answer is already as long as the question
        line_a_list = ["0"] * (len(line_q_list) - len(line_a_list)) + line_a_list

        questions.append(line_q_list)
        answers.append(line_a_list)

# Pair each question with its answer to form the dataset
dataset = list(zip(questions, answers))
```
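To see what the padding step does, here is a minimal, self-contained sketch on a made-up question/answer pair (not taken from the data file). The helper name `pad_pair` is hypothetical.

```python
def pad_pair(question, answer, pad="0"):
    """Hypothetical helper: split a question/answer pair into characters and
    left-pad the answer with `pad` so that both lists have the same length."""
    q = list(question.strip().replace(" ", ""))
    a = list(answer.strip().replace(" ", ""))
    # Same effect as the padding step above; nothing is added if the answer is longer
    a = [pad] * (len(q) - len(a)) + a
    return q, a


# A made-up pair in the style of the data file, for illustration only
q, a = pad_pair("Calculate 7 + 5.", "12")
for word, tag in zip(q, a):
    print(f"{word} => {tag}")
```

Every question character gets exactly one target. The real answer digits end up aligned with the last characters of the question.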

```python
import torch

from fastprogress.fastprogress import progress_bar, master_bar

from random import shuffle
```

## Preparing data for use as NN input

We can't pass a list of plain text words and tags to a NN. We need to convert them to a more appropriate format.

We'll start by creating a unique index for each word and tag.

```python
word_to_index = {}
tag_to_index = {}

total_words = 0
total_tags = 0

# tag_list[i] is the tag whose index is i (the inverse of tag_to_index)
tag_list = []

for words, tags in dataset:

    total_words += len(words)

    # Assign the next free index to each word the first time it is seen
    for word in words:
        if word not in word_to_index:
            word_to_index[word] = len(word_to_index)

    total_tags += len(tags)

    for tag in tags:
        if tag not in tag_to_index:
            tag_to_index[tag] = len(tag_to_index)
            tag_list.append(tag)
```

```python
print("       Vocabulary Indices")
print("-------------------------------")

for word in sorted(word_to_index):
    print(f"{word:>14} => {word_to_index[word]:>2}")

print("\nTotal number of words:", total_words)
print("Number of unique words:", len(word_to_index))
```

```text
       Vocabulary Indices
-------------------------------
             ( =>  3
             ) =>  6
             * => 25
             + =>  2
             - =>  5
             . => 21
             / =>  7
             0 =>  8
             1 =>  0
             2 => 11
             3 => 24
             4 => 10
             5 =>  1
             6 =>  9
             7 =>  4
             8 => 20
             9 => 19
             ? => 30
             C => 12
             E => 22
             W => 26
             a => 13
             c => 15
             e => 18
             f => 32
             h => 27
             i => 28
             l => 14
             o => 31
             s => 29
             t => 17
             u => 16
             v => 23

Total number of words: 134274
Number of unique words: 33
```

```python
print("Tag Indices")
print("-----------")

for tag, index in tag_to_index.items():
    print(f"  {tag} => {index}")

print("\nTotal number of tags:", total_tags)
print("Number of unique tags:", len(tag_to_index))
```

```text
Tag Indices
-----------
  0 => 0
  5 => 1
  - => 2
  3 => 3
  / => 4
  4 => 5
  6 => 6
  7 => 7
  2 => 8
  1 => 9
  9 => 10
  8 => 11

Total number of tags: 134274
Number of unique tags: 12
```

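`tag_list` stores the tags in the order their indices were assigned, so `tag_list[i]` is the tag whose index is `i`. The sketch below shows one way to turn per-character scores of shape `(seq_len, 1, tag_size)`, the shape the model built later in this notebook returns, back into an answer string. It is a minimal sketch under stated assumptions: `decode_tags` and the toy scores are hypothetical, and stripping leading "0"s assumes a real answer never begins with a meaningful 0 other than the answer 0 itself.

```python
import torch


def decode_tags(tag_scores, tag_list, strip_padding=True):
    """Hypothetical helper: map scores of shape (seq_len, 1, tag_size) to a string of tags."""
    indices = tag_scores.squeeze(1).argmax(dim=1).tolist()  # best tag index per position
    text = "".join(tag_list[i] for i in indices)
    # Answers were left-padded with "0", so optionally drop leading "0"s (keeping at least one character)
    if strip_padding:
        text = text.lstrip("0") or "0"
    return text


# Toy example: 3 tags and hand-written scores for a 4-character input
toy_tag_list = ["0", "1", "2"]
toy_scores = torch.tensor([
    [[5.0, 0.0, 0.0]],
    [[5.0, 0.0, 0.0]],
    [[0.0, 5.0, 0.0]],
    [[0.0, 0.0, 5.0]],
])
print(decode_tags(toy_scores, toy_tag_list))  # 12
```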

## Letting the NN parameterize words

Once we have a unique identifier for each word, it is useful to start our NN with an [embedding](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html#torch.nn.Embedding) layer. This layer converts an index into a vector of values.

You can think of each value as indicating something about the word. For example, maybe the first value indicates how much a word conveys happiness vs sadness. Of course, the NN can learn any attributes; it is not limited to things like happy/sad, masculine/feminine, etc.

**Creating an embedding layer**. An embedding layer is created by telling it the size of the vocabulary (the number of unique words) and an embedding dimension (how many values to use to represent a word).

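**Embedding layer input and output**. An embedding layer takes a tensor of indices and returns a tensor with one embedding vector per index. A sequence of `seq_len` indices becomes a `(seq_len, embed_dim)` matrix.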
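Here is a small sketch with toy sizes (not the notebook's). It shows that the output appends an `embed_dim` axis to whatever shape of index tensor goes in, and that each index selects a row of the layer's `weight` matrix.

```python
import torch

torch.manual_seed(0)  # make the random initial embeddings reproducible

vocab_size, embed_dim = 5, 3  # toy sizes, not the notebook's
embed = torch.nn.Embedding(vocab_size, embed_dim)

idx_1d = torch.tensor([1, 4, 1])         # a sequence of 3 indices
idx_2d = torch.tensor([[0, 2], [3, 3]])  # any shape of indices works

out_1d = embed(idx_1d)
out_2d = embed(idx_2d)
print(out_1d.shape)  # torch.Size([3, 3])
print(out_2d.shape)  # torch.Size([2, 2, 3])

# The same index always maps to the same row of embed.weight
print(torch.equal(out_1d[0], out_1d[2]))        # True
print(torch.equal(out_1d[1], embed.weight[4]))  # True
```

The same property appears in the output below: repeated characters such as `1`, `+` and `7` produce identical rows.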

```python
def convert_to_index_tensor(words, mapping):
    """Map each word to its index and return the indices as a 1-D long tensor."""
    indices = [mapping[w] for w in words]
    return torch.tensor(indices, dtype=torch.long)
```

```python
vocab_size = len(word_to_index)
embed_dim = 6  # Hyperparameter
embed_layer = torch.nn.Embedding(vocab_size, embed_dim)
```

```python
# Convert a sample expression into character indices, then embed each index
indices = convert_to_index_tensor(ps("15 + (7 + -17)/12"), word_to_index)
embed_output = embed_layer(indices)
indices.shape, embed_output.shape, embed_output
```

```text
(torch.Size([13]),
 torch.Size([13, 6]),
 tensor([[ 0.1435, -0.6875, -2.2596, -1.1459, -0.5756, -0.9597],
         [-0.4545,  0.3066,  0.1789,  0.7608,  0.5407, -0.4892],
         [-0.0479,  0.9947, -0.1374, -0.8856,  0.4321, -1.7607],
         [ 0.8785, -3.2259,  1.8508, -1.1539,  0.3714, -0.5989],
         [ 0.8060, -0.6189,  0.7249, -2.1139, -0.3225, -1.1629],
         [-0.0479,  0.9947, -0.1374, -0.8856,  0.4321, -1.7607],
         [ 0.8696, -0.1932, -2.3105, -0.8037,  0.0034,  1.2192],
         [ 0.1435, -0.6875, -2.2596, -1.1459, -0.5756, -0.9597],
         [ 0.8060, -0.6189,  0.7249, -2.1139, -0.3225, -1.1629],
         [-0.6419,  0.0917,  0.2189,  0.3942, -0.2931,  0.2224],
         [-2.2640,  0.1160, -0.6837, -0.4518, -0.1985,  1.4327],
         [ 0.1435, -0.6875, -2.2596, -1.1459, -0.5756, -0.9597],
         [ 0.7777, -0.1471, -1.3939,  1.8194,  2.3570, -1.1997]],
        grad_fn=<EmbeddingBackward0>))
```

## Adding an LSTM layer

The [LSTM](https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html#torch.nn.LSTM) layer is in charge of processing embeddings such that the network can output the correct classification. Since this is a recurrent layer, it will take into account past words when it creates an output for the current word.

**Creating an LSTM layer**. To create an LSTM you need to tell it the size of its input (the size of an embedding) and the size of its internal cell state.

**LSTM layer input and output**. An LSTM takes a sequence of embeddings (and, optionally, an initial hidden and cell state). It outputs a hidden-state vector for each word, along with the final hidden and cell states.

If you read the linked LSTM documentation, you will see that by default (`batch_first=False`) it expects batched input in the format `(seq_len, batch, input_size)`.

As you can see above, our embedding layer outputs a tensor of shape `(seq_len, input_size)`, here `(13, 6)`. So, we need to add a batch dimension of size 1 in the middle.
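Here is a sketch that uses random stand-in embeddings with the sizes from this section. It shows the shapes before and after `unsqueeze(1)` and what the LSTM returns besides its per-position output.

```python
import torch

torch.manual_seed(0)

seq_len, embed_dim, hidden_dim, num_layers = 13, 6, 10, 5  # sizes matching this section
lstm = torch.nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers)  # batch_first=False by default

embeddings = torch.randn(seq_len, embed_dim)  # stand-in for the embedding layer's output
batched = embeddings.unsqueeze(1)             # (seq_len, 1, embed_dim): a batch of size 1

output, (h_n, c_n) = lstm(batched)
print(output.shape)  # torch.Size([13, 1, 10]): one hidden vector per position, from the last layer
print(h_n.shape)     # torch.Size([5, 1, 10]): final hidden state of each layer
print(c_n.shape)     # torch.Size([5, 1, 10]): final cell state of each layer

# For a one-directional LSTM, the last time step of `output` is the top layer's final hidden state
print(torch.allclose(output[-1], h_n[-1]))  # True
```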

```python
hidden_dim = 10  # Hyperparameter
num_layers = 5  # Hyperparameter
lstm_layer = torch.nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers)
```

```python
# With the default batch_first=False, the LSTM layer expects input of shape (L, N, E)
#   L is the length of the sequence
#   N is the batch size (we'll stick with 1 here)
#   E is the size of the embedding
lstm_output, _ = lstm_layer(embed_output.unsqueeze(1))
lstm_output.shape
```

```text
torch.Size([13, 1, 10])
```

## Classifying the LSTM output

We can now add a fully connected, [linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html#torch.nn.Linear) layer to our NN to learn the correct part of speech (classification).

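**Creating a linear layer**. We create a linear layer by specifying the size of each input vector (here `hidden_dim`, the size of the LSTM output) and the number of neurons in the layer (here one per tag).

**Linear layer input and output**. The input can have any number of leading dimensions as long as its last dimension is the input size, `(*, input_size)`. The output has the same leading dimensions with one value per neuron, `(*, output_size)`.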
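Here is a short sketch with a random stand-in for the LSTM output. It checks that the linear layer acts only on the last dimension and computes `x @ W^T + b` independently at every position.

```python
import torch

torch.manual_seed(0)

hidden_dim, tag_size = 10, 12
linear = torch.nn.Linear(hidden_dim, tag_size)  # weight: (12, 10), bias: (12,)

lstm_like_output = torch.randn(13, 1, hidden_dim)  # stand-in for the LSTM output
scores = linear(lstm_like_output)
print(scores.shape)  # torch.Size([13, 1, 12])

# Equivalent explicit computation, applied independently at every position
manual = lstm_like_output @ linear.weight.T + linear.bias
print(torch.allclose(scores, manual))  # True
```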

```python
tag_size = len(tag_to_index)
linear_layer = torch.nn.Linear(hidden_dim, tag_size)
```

```python
linear_output = linear_layer(lstm_output)
linear_output.shape, linear_output
```

```text
(torch.Size([13, 1, 12]),
 tensor([[[ 0.2848, -0.0681,  0.1973, -0.0034,  0.0800, -0.2546,  0.1426,
            0.1761, -0.2061, -0.0183,  0.1820,  0.0969]],

         [[ 0.2995, -0.0689,  0.1961, -0.0204,  0.0791, -0.2436,  0.1308,
            0.1837, -0.2046, -0.0298,  0.1894,  0.1030]],

         [[ 0.3081, -0.0702,  0.1933, -0.0280,  0.0784, -0.2374,  0.1235,
            0.1865, -0.2048, -0.0343,  0.1917,  0.1060]],

         [[ 0.3131, -0.0716,  0.1905, -0.0318,  0.0777, -0.2338,  0.1193,
            0.1884, -0.2052, -0.0365,  0.1923,  0.1078]],

         [[ 0.3159, -0.0729,  0.1881, -0.0339,  0.0771, -0.2314,  0.1169,
            0.1900, -0.2055, -0.0379,  0.1922,  0.1091]],

         [[ 0.3174, -0.0741,  0.1862, -0.0351,  0.0768, -0.2299,  0.1157,
            0.1913, -0.2058, -0.0389,  0.1920,  0.1099]],

         [[ 0.3181, -0.0751,  0.1847, -0.0359,  0.0765, -0.2289,  0.1151,
            0.1924, -0.2060, -0.0397,  0.1917,  0.1106]],

         [[ 0.3185, -0.0758,  0.1837
```

# Training an LSTM model

```python
# Hyperparameters
valid_percent = 0.2  # Training/validation split

embed_dim = 7  # Size of word embedding
hidden_dim = 8  # Size of LSTM internal state
num_layers = 5  # Number of LSTM layers

learning_rate = 0.1
num_epochs = 2
```

## Creating training and validation datasets

```python
N = len(dataset)
vocab_size = len(word_to_index)  # Number of unique input words
tag_size = len(tag_to_index)  # Number of unique output targets

# Shuffle the data so that we can split the dataset randomly
shuffle(dataset)

split_point = int(N * valid_percent)
valid_dataset = dataset[:split_point]
train_dataset = dataset[split_point:]

len(valid_dataset), len(train_dataset)
```

```text
(1000, 4000)
```
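The cell above shuffles `dataset` in place with an unseeded `shuffle`, so the split changes from run to run. If a repeatable split is wanted, one option is a seeded `random.Random` applied to a copy of the data. This is an assumption, not what this notebook does, and `split_dataset` is a hypothetical helper.

```python
import random


def split_dataset(data, valid_percent, seed=0):
    """Hypothetical helper: return (valid, train) lists from a shuffled copy of `data`."""
    shuffled = list(data)                  # copy, so the caller's list is left untouched
    random.Random(seed).shuffle(shuffled)  # seeded generator -> same split on every run
    split_point = int(len(shuffled) * valid_percent)
    return shuffled[:split_point], shuffled[split_point:]


toy_data = list(range(5000))  # stand-in for 5000 (question, answer) pairs
valid, train = split_dataset(toy_data, valid_percent=0.2)
print(len(valid), len(train))  # 1000 4000
```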

## Creating the Parts of Speech LSTM model

```python
class POS_LSTM(torch.nn.Module):
    """Part of Speech LSTM model: embedding -> LSTM -> linear layer, one sequence at a time."""

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers, tag_size):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, embed_dim)
        self.lstm = torch.nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers)
        self.linear = torch.nn.Linear(hidden_dim, tag_size)

    def forward(self, X):
        # X: (seq_len,) tensor of word indices
        X = self.embed(X)                 # (seq_len, embed_dim)
        X, _ = self.lstm(X.unsqueeze(1))  # (seq_len, 1, hidden_dim); batch of size 1
        return self.linear(X)             # (seq_len, 1, tag_size)
```

## Training

```python
def compute_accuracy(dataset):
    """A helper function for computing the per-word accuracy of the global `model` on the given dataset."""
    total_words = 0
    total_correct = 0

    model.eval()

    with torch.no_grad():
        for sentence, tags in dataset:
            sentence_indices = convert_to_index_tensor(sentence, word_to_index)
            # squeeze(1) drops only the batch dimension, so a 1-word input keeps its sequence dimension
            tag_scores = model(sentence_indices).squeeze(1)
            predictions = tag_scores.argmax(dim=1).tolist()
            total_words += len(sentence)
            total_correct += sum(t == tag_list[p] for t, p in zip(tags, predictions))

    return total_correct / total_words
```
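`compute_accuracy` maps each predicted index back to a tag string and compares strings. The same count can be taken in index space instead: convert the true tags with `convert_to_index_tensor(tags, tag_to_index)` and compare tensors. Below is a minimal sketch with a hypothetical helper name and toy scores.

```python
import torch


def count_correct(tag_scores, target_indices):
    """Hypothetical helper: number of positions whose highest-scoring tag index equals the target.

    tag_scores has shape (seq_len, 1, tag_size); target_indices has shape (seq_len,).
    """
    predictions = tag_scores.squeeze(1).argmax(dim=1)
    return int((predictions == target_indices).sum())


# Toy check with 4 positions and 3 tags; the last prediction is wrong
scores = torch.tensor([[[2.0, 0.0, 0.0]],
                       [[0.0, 2.0, 0.0]],
                       [[0.0, 0.0, 2.0]],
                       [[2.0, 0.0, 0.0]]])
targets = torch.tensor([0, 1, 2, 1])
print(count_correct(scores, targets))  # 3
```

Summing these counts over all sentences and dividing by the total number of words should give the same pooled accuracy that `compute_accuracy` reports.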

```python
model = POS_LSTM(vocab_size, embed_dim, hidden_dim, num_layers, tag_size)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

mb = master_bar(range(num_epochs))

accuracy = compute_accuracy(valid_dataset)
print(f"Validation accuracy before training : {accuracy * 100:.2f}%")

for epoch in mb:

    # Shuffle the data for each epoch (stochastic gradient descent)
    shuffle(train_dataset)

    model.train()

    for sentence, tags in progress_bar(train_dataset, parent=mb):
        model.zero_grad()

        sentence_indices = convert_to_index_tensor(sentence, word_to_index)
        tag_indices = convert_to_index_tensor(tags, tag_to_index)

        # Model output is (seq_len, 1, tag_size); drop the batch dimension for the loss
        tag_scores = model(sentence_indices).squeeze(1)

        loss = criterion(tag_scores, tag_indices)

        loss.backward()
        optimizer.step()

accuracy = compute_accuracy(valid_dataset)
print(f"Validation accuracy after training : {accuracy * 100:.2f}%")
```

```text
Validation accuracy before training : 1.49%
Validation accuracy after training : 90.55%
```

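Every answer is left-padded with "0", so the tag "0" probably accounts for a large share of all targets. In the mis-prediction listing below, every mistake shown is a prediction of "0". To put the 90.55% in context, it can help to measure how well "always predict the most common tag" does. Below is a minimal sketch with a hypothetical helper and made-up pairs.

```python
from collections import Counter


def majority_tag_baseline(dataset):
    """Hypothetical helper: accuracy of always predicting the most common tag in `dataset`."""
    counts = Counter(tag for _, tags in dataset for tag in tags)
    tag, count = counts.most_common(1)[0]
    return tag, count / sum(counts.values())


# Made-up (question, answer) pairs, already split into characters and padded
toy_dataset = [
    (list("1+1"), list("002")),
    (list("9-3"), list("006")),
    (list("10/2"), list("0005")),
]
print(majority_tag_baseline(toy_dataset))  # ('0', 0.7)
```

Calling `majority_tag_baseline(valid_dataset)` on the real validation set gives the number to compare the trained model against.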

## Examining results

Here we look at every word that the model misclassifies, across the entire dataset.
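Printing every mistake produces a very long listing. A compact alternative is to count mistakes by (true tag, predicted tag) pair with `collections.Counter`. Below is a sketch on made-up lists; `tally_errors` is a hypothetical helper.

```python
from collections import Counter


def tally_errors(true_tags, predicted_tags):
    """Hypothetical helper: count (true, predicted) pairs where the prediction is wrong."""
    return Counter((t, p) for t, p in zip(true_tags, predicted_tags) if t != p)


# Made-up true and predicted tags for a handful of words
true_tags = ["0", "0", "1", "/", "3", "1"]
predicted = ["0", "0", "0", "0", "3", "0"]
errors = tally_errors(true_tags, predicted)
print(errors.most_common())  # [(('1', '0'), 2), (('/', '0'), 1)]
```

In the notebook, the two lists would be `tags` and `[tag_list[p] for p in predictions]` for each sentence.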

```python
print("\nMis-predictions after training on entire dataset")
header = "Word".center(14) + " | True Tag | Prediction"
print(header)
print("-" * len(header))

model.eval()

with torch.no_grad():
    for sentence, tags in dataset:
        sentence_indices = convert_to_index_tensor(sentence, word_to_index)
        tag_scores = model(sentence_indices)
        # Drop the batch dimension, then pick the highest-scoring tag index per word
        predictions = tag_scores.squeeze(1).argmax(dim=1).tolist()
        for word, tag, pred in zip(sentence, tags, predictions):
            if tag != tag_list[pred]:
                print(f"{word:>14} |     {tag}    |    {tag_list[pred]}")
```


Mis-predictions after training on entire dataset
     Word      | True Tag | Prediction
--------------------------------------
             5 |     1    |    0
             6 |     /    |    0
             . |     4    |    0
             ) |     -    |    0
             ? |     3    |    0
             . |     4    |    0
             1 |     -    |    0
             5 |     4    |    0
             ) |     /    |    0
             ? |     5    |    0
             2 |     -    |    0
             . |     3    |    0
             2 |     4    |    0
             ) |     -    |    0
             . |     3    |    0
             ) |     -    |    0
             ? |     3    |    0
             2 |     1    |    0
             ) |     /    |    0
             ) |     5    |    0
             / |     -    |    0
             2 |     1    |    0
             4 |     /    |    0
             . |     2    |    0
             . |     1    |    0
             . |     3    |    0
             * |     -    |    0
             - |     1    |    0
             3 |     /    |    0
             ? |     3    |    0
             5 |     2    |    0
             ) |     /    |    0
             ) |     1    |    0
             ? |     9    |    0
             1 |     -    |    0
             8 |     2    |    0
             ) |     /    |    0
             . |     3    |    0
             ) |     -    |    0
             . |     4    |    0
             3 |     -    |    0
             ) |     3    |    0
             5 |     3    |    0
             ) |     /    |    0
             ? |     7    |    0
             4 |     1    |    0
             5 |     /    |    0
             ? |     8    |    0
             2 |     -    |    0
             ? |     5    |    0
             . |     3    |    0
             ) |     -    |    0
             / |     2    |    0
             4 |     /    |    0
             0 |     1    |    0
             ? |     5    |    0
             / |     1    |    0
             8 |     /    |    0
             . |     5    |    0
             4 |     4    |    0
             ) |     /    |    0
             ? |     9    |    0
             - |     -    |    0
             1 |     2    |    0
             8 |     /    |    0
             ) |     9    |    0
             / |     -    |    0
             5 |     5    |    0
             ) |     /    |    0
             ? |     3    |    0
             1 |     2    |    0
             ) |     /    |    0
             . |     7    |    0
             . |     2    |    0
             / |     1    |    0
             1 |     /    |    0
             0 |     5    |    0
             6 |     -    |    0
             0 |     7    |    0
             ) |     -    |    0
             ? |     6    |    0
             ) |     1    |    0
             3 |     1    |    0
             9 |     -    |    0
             . |     6    |    0
             . |     4    |    0
             . |     5    |    0
             6 |     3    |    0
             ) |     /    |    0
             ? |     2    |    0
             / |     -    |    0
             3 |     1    |    0
             ) |     /    |    0
             . |     2    |    0
             ? |     7    |    0
             . |     3    |    0
             3 |     -    |    0
             7 |     4    |    0
             5 |     /    |    0
             . |     9    |    0
             6 |     1    |    0
             6 |     /    |    0
             ) |     3    |    0
             ? |     4    |    0
             3 |     -    |    0
             0 |     1    |    0
             4 |     /    |    0
             ) |     1    |    0
             / |     3    |    0
             7 |     /    |    0
             ) |     4    |    0
             3 |     2    |    0
             6 |     /    |    0
             3 |     1    |    0
             . |     1    |    0
             5 |     2    |    0
             ) |     -    |    0
             . |     3    |    0
             1 |     -    |    0
             . |     6    |    0
             ) |     2    |    0
             ) |     /    |    0
             . |     3    |    0
             - |     -    |    0
             8 |     3    |    0
             ) |     /    |    0
             ? |     8    |    0
             - |     -    |    0
             7 |     2    |    0
             0 |     /    |    0
             ) |     1    |    0
             . |     5    |    0
             ) |     -    |    0
             ? |     4    |    0
             8 |     1    |    0
             ) |     /    |    0
             . |     7    |    0
             / |     -    |    0
             1 |     2    |    0
             8 |     /    |    0
             4 |     2    |    0
             ? |     3    |    0
             / |     4    |    0
             1 |     /    |    0
             . |     9    |    0
             8 |     -    |    0
             . |     1    |    0
             ) |     -    |    0
             . |     5    |    0
             / |     -    |    0
             2 |     1    |    0
             ) |     -    |    0
             / |     2    |    0
             3 |     /    |    0
             6 |     9    |    0
             ) |     -    |    0
             ) |     3    |    0
             . |     2    |    0
             1 |     1    |    0
             2 |     /    |    0
             ) |     3    |    0
             0 |     1    |    0
             ) |     /    |    0
             . |     5    |    0
             6 |     -    |    0
             ? |     2    |    0
             . |     2    |    0
             6 |     1    |    0
             ) |     /    |    0
             ? |     3    |    0
             ) |     -    |    0
             . |     7    |    0
             / |     1    |    0
             5 |     /    |    0
             ) |     3    |    0
             3 |     -    |    0
             - |     -    |    0
             2 |     2    |    0
             ) |     /    |    0
             ? |     9    |    0
             + |     2    |    0
             5 |     /    |    0
             ) |     2    |    0
             ? |     5    |    0
             / |     -    |    0
             3 |     1    |    0
             6 |     /    |    0
             . |     3    |    0
             ) |     1    |    0
             - |     2    |    0
             3 |     /    |    0
             ) |     5    |    0
             ? |     7    |    0
             / |     -    |    0
             5 |     2    |    0
             ) |     /    |    0
             ? |     7    |    0
             / |     -    |    0
             3 |     1    |    0
             0 |     /    |    0
             . |     5    |    0
             1 |     -    |    0
             2 |     1    |    0
             ) |     /    |    0
             ? |     3    |    0
             - |     -    |    0
             8 |     /    |    0
             ? |     9    |    0
             0 |     4    |    0
             ) |     /    |    0
             ? |     5    |    0
             2 |     -    |    0
             ) |     6    |    0
             4 |     5    |    0
             ) |     /    |    0
             . |     2    |    0
             - |     -    |    0
             8 |     2    |    0
             5 |     /    |    0
             ) |     1    |    0
             . |     7    |    0
             4 |     -    |    0
             0 |     2    |    0
             ) |     /    |    0
             . |     5    |    0
             2 |     2    |    0
             0 |     /    |    0
             ? |     9    |    0
             ? |     1    |    0
             ? |     5    |    0
             ) |     -    |    0
             - |     2    |    0
             8 |     /    |    0
             ) |     1    |    0
             . |     5    |    0
             ) |     -    |    0
             + |     -    |    0
             1 |     6    |    0
             / |     /    |    0
             3 |     1    |    0
             ? |     1    |    0
             ? |     6    |    0
             ? |     3    |    0
             ) |     -    |    0
             / |     2    |    0
             6 |     /    |    0
             0 |     1    |    0
             ? |     5    |    0
             - |     -    |    0
             2 |     1    |    0
             8 |     /    |    0
             ) |     4    |    0
             * |     -    |    0
             - |     1    |    0
             2 |     /    |    0
             . |     7    |    0
             4 |     -    |    0
             0 |     7    |    0
             ) |     /    |    0
             . |     5    |    0
             . |     4    |    0
             0 |     -    |    0
             ? |     3    |    0
             0 |     -    |    0
             ) |     1    |    0
             / |     /    |    0
             ? |     1    |    0
             5 |     1    |    0
             9 |     /    |    0
             0 |     6    |    0
             ? |     5    |    0
             * |     -    |    0
             1 |     1    |    0
             1 |     1    |    0
             1 |     3    |    0
             ) |     /    |    0
             . |     2    |    0
             8 |     -    |    0
             / |     2    |    0
             5 |     /    |    0
             ? |     5    |    0
             8 |     -    |    0
             ) |     2    |    0
             * |     /    |    0
             - |     1    |    0
             6 |     3    |    0
             ) |     -    |    0
             ) |     8    |    0
             5 |     1    |    0
             1 |     /    |    0
             ? |     4    |    0
             4 |     -    |    0
             ) |     1    |    0
             / |     /    |    0
             2 |     5    |    0
             0 |     -    |    0
             8 |     -    |    0
             0 |     6    |    0
             ) |     /    |    0
             . |     5    |    0
             ) |     -    |    0
             ? |     5    |    0
             8 |     -    |    0
             . |     7    |    0
             ? |     3    |    0
             2 |     1    |    0
             ) |     /    |    0
             ? |     2    |    0
             - |     -    |    0
             1 |     7    |    0
             9 |     -    |    0
             0 |     1    |    0
             . |     1    |    0
             ) |     3    |    0
             ) |     /    |    0
             . |     2    |    0
             ) |     -    |    0
             . |     6    |    0
             6 |     1    |    0
             ) |     /    |    0
             . |     4    |    0
             ? |     5    |    0
             2 |     -    |    0
             ? |     4    |    0
             5 |     1    |    0
             ) |     /    |    0
             * |     1    |    0
             - |     /    |    0
             3 |     6    |    0
             1 |     -    |    0
             7 |     1    |    0
             5 |     /    |    0
             ) |     9    |    0
             ) |     -    |    0
             . |     3    |    0
             . |     3    |    0
             8 |     1    |    0
             ) |     /    |    0
             . |     4    |    0
             ) |     6    |    0
             ? |     5    |    0
             8 |     -    |    0
             ) |     3    |    0
             ? |     2    |    0
             ) |     -    |    0
             ? |     2    |    0
             0 |     2    |    0
             ) |     /    |    0
             ? |     5    |    0
             8 |     1    |    0
             ) |     /    |    0
             . |     3    |    0
             1 |     -    |    0
             1 |     1    |    0
             0 |     /    |    0
             ? |     5    |    0
             ) |     3    |    0
             ) |     /    |    0
             . |     2    |    0
             5 |     -    |    0
             + |     1    |    0
             6 |     /    |    0
             ? |     8    |    0
             ? |     4    |    0
             . |     8    |    0
             . |     5    |    0
             ? |     3    |    0
             ) |     -    |    0
             ? |     1    |    0
             - |     3    |    0
             2 |     /    |    0
             ? |     7    |    0
             6 |     1    |    0
             ) |     /    |    0
             . |     2    |    0
             ) |     -    |    0
             ? |     5    |    0
             ) |     -    |    0
             . |     4    |    0
             1 |     -    |    0
             ? |     2    |    0
             ) |     -    |    0
             ? |     1    |    0
             ? |     2    |    0
             2 |     1    |    0
             ) |     /    |    0
             ) |     -    |    0
             ? |     4    |    0
             4 |     -    |    0
             ) |     2    |    0
             . |     4    |    0
             ) |     -    |    0
             . |     5    |    0
             0 |     -    |    0
             ? |     5    |    0
             ) |     -    |    0
             . |     2    |    0
             1 |     -    |    0
             3 |     1    |    0
             ) |     -    |    0
             ? |     1    |    0
             ? |     1    |    0
             - |     -    |    0
             - |     2    |    0
             8 |     /    |    0
             . |     5    |    0
             1 |     -    |    0
             8 |     2    |    0
             6 |     /    |    0
             ) |     2    |    0
             ) |     5    |    0
             ) |     -    |    0
             ) |     5    |    0
             4 |     7    |    0
             6 |     /    |    0
             ? |     2    |    0
             / |     2    |    0
             4 |     /    |    0
             ) |     1    |    0
             . |     5    |    0
             1 |     8    |    0
             ) |     -    |    0
             ) |     2    |    0
             / |     /    |    0
             6 |     1    |    0
             ? |     1    |    0
             ) |     1    |    0
             ) |     /    |    0
             . |     5    |    0
             . |     2    |    0
             ) |     2    |    0
             ) |     /    |    0
             ? |     9    |    0
             . |     8    |    0
             * |     -    |    0
             - |     1    |    0
             3 |     /    |    0
             . |     2    |    0
             - |     -    |    0
             1 |     2    |    0
             ) |     /    |    0
             ? |     3    |    0
             1 |     1    |    0
             7 |     /    |    0
             ? |     9    |    0
             . |     6    |    0
             2 |     -    |    0
             2 |     5    |    0
             ) |     /    |    0
             ? |     2    |    0
             . |     1    |    0
             5 |     -    |    0
             / |     2    |    0
             1 |     /    |    0
             . |     3    |    0
             . |     5    |    0
             + |     -    |    0
             1 |     5    |    0
             / |     /    |    0
             2 |     4    |    0
             ? |     3    |    0
             5 |     -    |    0
             5 |     2    |    0
             ) |     /    |    0
             ? |     7    |    0
             ? |     4    |    0
             ) |     -    |    0
             . |     4    |    0
             + |     2    |    0
             4 |     /    |    0
             ) |     1    |    0
             ? |     1    |    0
             ) |     2    |    0
             ) |     /    |    0
             . |     3    |    0
             ) |     4    |    0
             . |     5    |    0
             - |     -    |    0
             4 |     4    |    0
             ) |     /    |    0
             . |     7    |    0
             . |     3    |    0
             . |     6    |    0
             2 |     -    |    0
             . |     3    |    0
             ) |     -    |    0
             . |     1    |    0
             ) |     -    |    0
             . |     3    |    0
             ? |     1    |    0
             . |     1    |    0
             ? |     1    |    0
             6 |     -    |    0
             5 |     2    |    0
             ) |     /    |    0
             . |     5    |    0
             2 |     -    |    0
             * |     2    |    0
             - |     /    |    0
             1 |     1    |    0
             . |     7    |    0
             1 |     -    |    0
             0 |     1    |    0
             ) |     /    |    0
             . |     2    |    0
             - |     3    |    0
             ? |     3    |    0
             4 |     -    |    0
             . |     9    |    0
             ) |     -    |    0
             / |     1    |    0
             6 |     /    |    0
             0 |     4    |    0
             0 |     -    |    0
             ) |     6    |    0
             ) |     /    |    0
             ? |     5    |    0
             1 |     5    |    0
             ) |     /    |    0
             ? |     4    |    0
             ? |     4    |    0
             . |     8    |    0
             - |     -    |    0
             3 |     1    |    0
             3 |     /    |    0
             ) |     3    |    0
             2 |     -    |    0
             / |     1    |    0
             8 |     /    |    0
             ? |     3    |    0
             6 |     -    |    0
             5 |     2    |    0
             ) |     -    |    0
             . |     2    |    0
             ? |     5    |    0
             ? |     7    |    0
             ? |     5    |    0
             5 |     -    |    0
             ) |     3    |    0
             / |     -    |    0
             7 |     1    |    0
             5 |     /    |    0
             . |     5    |    0
             2 |     -    |    0
             0 |     8    |    0
             ) |     /    |    0
             ? |     5    |    0
             ) |     1    |    0
             ) |     /    |    0
             ? |     9    |    0
             . |     4    |    0
             1 |     2    |    0
             8 |     /    |    0
             ? |     3    |    0
             0 |     -    |    0
             ) |     2    |    0
             ) |     -    |    0
             . |     5    |    0
             ) |     -    |    0
             . |     5    |    0
             ? |     2    |    0
             / |     -    |    0
             3 |     6    |    0
             2 |     6    |    0
             1 |     1    |    0
             ) |     /    |    0
             2 |     1    |    0
             0 |     /    |    0
             ) |     2    |    0
             3 |     -    |    0
             / |     1    |    0
             6 |     /    |    0
             ) |     1    |    0
             ? |     6    |    0
             ? |     5    |    0
             0 |     5    |    0
             ) |     /    |    0
             ? |     3    |    0
             1 |     1    |    0
             3 |     /    |    0
             ? |     3    |    0
             ? |     7    |    0
             9 |     -    |    0
             5 |     3    |    0
             ) |     /    |    0
             ? |     2    |    0
             5 |     -    |    0
             ? |     1    |    0
             5 |     -    |    0
             ) |     2    |    0
             ) |     /    |    0
             . |     5    |    0
             3 |     -    |    0
             5 |     2    |    0
             ) |     /    |    0
             ) |     9    |    0
          

             2 |     1    |    0
             2 |     /    |    0
             . |     4    |    0
             5 |     -    |    0
             . |     4    |    0
             ) |     -    |    0
             ) |     4    |    0
             3 |     -    |    0
             ? |     3    |    0
             ) |     -    |    0
             ? |     5    |    0
             - |     -    |    0
             3 |     2    |    0
             ) |     /    |    0
             ) |     1    |    0
             . |     1    |    0
             ) |     1    |    0
             ) |     /    |    0
             ? |     2    |    0
             - |     -    |    0
             2 |     2    |    0
             4 |     -    |    0
             ) |     2    |    0
             * |     /    |    0
             4 |     1    |    0
             ? |     9    |    0
             ? |     1    |    0
             4 |     -    |    0
             6 |     1    |    0
             8 |     /    |    0
          

             1 |     -    |    0
             ) |     1    |    0
             ) |     /    |    0
             ? |     3    |    0
             4 |     -    |    0
             9 |     1    |    0
             ) |     /    |    0
             . |     2    |    0
             . |     5    |    0
             * |     2    |    0
             2 |     /    |    0
             ) |     1    |    0
             . |     5    |    0
             2 |     -    |    0
             1 |     2    |    0
             6 |     /    |    0
             0 |     1    |    0
             ? |     5    |    0
             7 |     3    |    0
             ) |     /    |    0
             . |     5    |    0
             1 |     1    |    0
             ) |     -    |    0
             . |     3    |    0
             + |     6    |    0
             8 |     /    |    0
             ? |     5    |    0
             4 |     -    |    0
             9 |     1    |    0
             5 |     /    |    0
          

             5 |     2    |    0
             ) |     /    |    0
             . |     3    |    0
             ? |     9    |    0
             4 |     -    |    0
             ) |     2    |    0
             ) |     /    |    0
             . |     7    |    0
             - |     -    |    0
             1 |     1    |    0
             ) |     /    |    0
             . |     5    |    0
             4 |     6    |    0
             ? |     3    |    0
             / |     -    |    0
             4 |     1    |    0
             ) |     /    |    0
             . |     3    |    0
             . |     4    |    0
             1 |     -    |    0
             1 |     9    |    0
             + |     -    |    0
             - |     1    |    0
             9 |     /    |    0
             . |     5    |    0
             . |     7    |    0
             2 |     1    |    0
             ) |     /    |    0
             ? |     3    |    0
             ) |     -    |    0
          

             ) |     -    |    0
             / |     2    |    0
             2 |     /    |    0
             6 |     1    |    0
             ? |     3    |    0
             6 |     -    |    0
             ? |     6    |    0
             ) |     -    |    0
             ? |     3    |    0
             ? |     2    |    0
             . |     1    |    0
             8 |     -    |    0
             6 |     7    |    0
             0 |     4    |    0
             ) |     -    |    0
             ? |     2    |    0
             ) |     -    |    0
             ? |     2    |    0
             . |     4    |    0
             ? |     2    |    0
             3 |     -    |    0
             ) |     4    |    0
             ) |     -    |    0
             ? |     2    |    0
             . |     3    |    0
             8 |     -    |    0
             / |     1    |    0
             2 |     /    |    0
             ) |     2    |    0
             . |     8    |    0
          

             - |     2    |    0
             4 |     /    |    0
             ) |     7    |    0
             . |     5    |    0
             / |     2    |    0
             5 |     /    |    0
             4 |     2    |    0
             6 |     9    |    0
             ) |     -    |    0
             . |     6    |    0
             6 |     -    |    0
             ? |     1    |    0
             ) |     1    |    0
             3 |     5    |    0
             ) |     -    |    0
             / |     1    |    0
             8 |     /    |    0
             . |     2    |    0
             ? |     3    |    0
             1 |     2    |    0
             0 |     /    |    0
             ? |     5    |    0
             2 |     2    |    0
             ) |     /    |    0
             ? |     3    |    0
             0 |     7    |    0
             ) |     /    |    0
             ? |     2    |    0
             ? |     4    |    0
             / |     4    |    0
          

             5 |     -    |    0
             6 |     1    |    0
             0 |     /    |    0
             ? |     4    |    0
             1 |     -    |    0
             ) |     8    |    0
             * |     -    |    0
             2 |     2    |    0
             / |     /    |    0
             3 |     9    |    0
             3 |     2    |    0
             0 |     /    |    0
             ) |     1    |    0
             . |     1    |    0
             ) |     -    |    0
             ? |     3    |    0
             - |     -    |    0
             4 |     1    |    0
             ) |     /    |    0
             ? |     2    |    0
             ) |     1    |    0
             ) |     /    |    0
             ? |     2    |    0
             . |     8    |    0
             1 |     -    |    0
             2 |     1    |    0
             ? |     1    |    0
             . |     2    |    0
             2 |     -    |    0
             ) |     1    |    0
          

             / |     -    |    0
             1 |     1    |    0
             0 |     /    |    0
             . |     2    |    0
             ) |     -    |    0
             + |     2    |    0
             0 |     /    |    0
             ) |     1    |    0
             ? |     1    |    0
             ) |     4    |    0
             ) |     1    |    0
             ) |     /    |    0
             ? |     5    |    0
             ) |     3    |    0
             ) |     /    |    0
             ? |     8    |    0
             7 |     2    |    0
             ) |     /    |    0
             ? |     7    |    0
             / |     -    |    0
             3 |     1    |    0
             * |     -    |    0
             - |     1    |    0
             3 |     /    |    0
             ) |     4    |    0
             / |     1    |    0
             6 |     /    |    0
             ) |     3    |    0
             0 |     3    |    0
             ) |     /    |    0
          

             5 |     /    |    0
             ) |     7    |    0
             4 |     -    |    0
             ) |     6    |    0
             ? |     2    |    0
             - |     -    |    0
             5 |     1    |    0
             ) |     /    |    0
             ? |     2    |    0
             . |     4    |    0
             ) |     -    |    0
             + |     2    |    0
             1 |     /    |    0
             ) |     1    |    0
             . |     1    |    0
             2 |     -    |    0
             0 |     1    |    0
             ) |     /    |    0
             . |     3    |    0
             ) |     -    |    0
             ? |     7    |    0
             . |     9    |    0
             / |     2    |    0
             1 |     /    |    0
             0 |     5    |    0
             ) |     3    |    0
             / |     /    |    0
             4 |     5    |    0
             / |     -    |    0
             8 |     2    |    0
          

             6 |     4    |    0
             ) |     /    |    0
             . |     5    |    0
             6 |     1    |    0
             ) |     /    |    0
             . |     8    |    0
             ) |     -    |    0
             . |     6    |    0
             2 |     4    |    0
             8 |     /    |    0
             . |     3    |    0
             1 |     1    |    0
             ) |     /    |    0
             ? |     2    |    0
             ) |     2    |    0
             ) |     /    |    0
             ? |     5    |    0
             7 |     -    |    0
             / |     2    |    0
             2 |     /    |    0
             1 |     1    |    0
             ? |     3    |    0
             / |     -    |    0
             2 |     1    |    0
             8 |     /    |    0
             . |     2    |    0
             ) |     -    |    0
             ? |     3    |    0
             0 |     -    |    0
             ? |     5    |    0
          

             4 |     1    |    0
             ) |     /    |    0
             ? |     3    |    0
             - |     -    |    0
             1 |     1    |    0
             2 |     /    |    0
             ) |     6    |    0
             ) |     -    |    0
             - |     1    |    0
             5 |     /    |    0
             . |     2    |    0
             7 |     -    |    0
             . |     7    |    0
             . |     5    |    0
             0 |     4    |    0
             / |     -    |    0
             5 |     2    |    0
             4 |     /    |    0
             ) |     1    |    0
             . |     1    |    0
             ) |     -    |    0
             ? |     5    |    0
             6 |     3    |    0
             ) |     /    |    0
             . |     4    |    0
             6 |     -    |    0
             ) |     5    |    0
             5 |     1    |    0
             + |     /    |    0
             7 |     7    |    0
          

             ) |     6    |    0
             3 |     1    |    0
             ) |     /    |    0
             . |     2    |    0
             . |     4    |    0
             ? |     5    |    0
             - |     -    |    0
             3 |     2    |    0
             ) |     /    |    0
             ? |     9    |    0
             ) |     1    |    0
             4 |     -    |    0
             ) |     2    |    0
             . |     5    |    0
             1 |     3    |    0
             ) |     /    |    0
             ? |     2    |    0
             ? |     2    |    0
             - |     5    |    0
             6 |     /    |    0
             . |     6    |    0
             . |     4    |    0
             - |     1    |    0
             4 |     /    |    0
             ) |     2    |    0
             1 |     8    |    0
             ) |     /    |    0
             ? |     3    |    0
             / |     -    |    0
             1 |     1    |    0
          

             - |     -    |    0
             1 |     2    |    0
             ) |     /    |    0
             ? |     7    |    0
             5 |     2    |    0
             0 |     /    |    0
             ) |     3    |    0
             ) |     8    |    0
             0 |     -    |    0
             0 |     2    |    0
             - |     -    |    0
             1 |     1    |    0
             ) |     /    |    0
             . |     4    |    0
             7 |     5    |    0
             ) |     /    |    0
             ? |     3    |    0
             ? |     2    |    0
             + |     -    |    0
             0 |     2    |    0
             + |     /    |    0
             0 |     1    |    0
             . |     9    |    0
             1 |     -    |    0
             2 |     3    |    0
             ) |     /    |    0
             . |     2    |    0
             3 |     4    |    0
             4 |     -    |    0
             ) |     6    |    0
          

             . |     8    |    0
             ? |     7    |    0
             ) |     -    |    0
             ? |     5    |    0
             ? |     1    |    0
             ? |     3    |    0
             0 |     -    |    0
             ) |     2    |    0
             ) |     /    |    0
             . |     9    |    0
             7 |     1    |    0
             ) |     /    |    0
             ? |     7    |    0
             0 |     -    |    0
             ? |     9    |    0
             . |     3    |    0
             ) |     -    |    0
             . |     3    |    0
             . |     8    |    0
             2 |     3    |    0
             ) |     /    |    0
             . |     2    |    0
             4 |     -    |    0
             ) |     2    |    0
             / |     /    |    0
             2 |     3    |    0
             . |     1    |    0
             . |     1    |    0
             . |     2    |    0
             / |     -    |    0
          

             ) |     /    |    0
             . |     7    |    0
             1 |     -    |    0
             2 |     1    |    0
             0 |     /    |    0
             . |     4    |    0
             ? |     5    |    0
             6 |     -    |    0
             6 |     1    |    0
             ) |     /    |    0
             . |     3    |    0
             + |     1    |    0
             0 |     /    |    0
             ) |     3    |    0
             ) |     1    |    0
             ) |     /    |    0
             . |     7    |    0
             ) |     -    |    0
             . |     7    |    0
             ) |     3    |    0
             ) |     -    |    0
             . |     1    |    0
             2 |     -    |    0
             . |     9    |    0
             - |     -    |    0
             2 |     1    |    0
             0 |     /    |    0
             ) |     6    |    0
             / |     -    |    0
             1 |     2    |    0
          

             ) |     -    |    0
             ? |     6    |    0
             ? |     5    |    0
             0 |     -    |    0
             ) |     1    |    0
             ) |     /    |    0
             . |     6    |    0
             / |     -    |    0
             2 |     1    |    0
             8 |     /    |    0
             . |     2    |    0
             ) |     -    |    0
             . |     5    |    0
             ) |     2    |    0
             ) |     /    |    0
             . |     5    |    0
             ) |     -    |    0
             . |     1    |    0
             / |     -    |    0
             4 |     2    |    0
             5 |     /    |    0
             ? |     9    |    0
             - |     -    |    0
             4 |     4    |    0
             5 |     /    |    0
             ) |     9    |    0
             1 |     4    |    0
             ) |     /    |    0
             ? |     7    |    0
             ) |     -    |    0
          

             1 |     3    |    0
             ) |     /    |    0
             ? |     4    |    0
             4 |     -    |    0
             7 |     4    |    0
             ) |     /    |    0
             . |     5    |    0
             2 |     -    |    0
             0 |     2    |    0
             ) |     /    |    0
             ) |     1    |    0
             ? |     3    |    0
             ? |     5    |    0
             ) |     -    |    0
             ? |     6    |    0
             1 |     -    |    0
             6 |     2    |    0
             ) |     /    |    0
             ) |     2    |    0
             ? |     9    |    0
             6 |     1    |    0
             ) |     /    |    0
             . |     7    |    0
             4 |     -    |    0
             ) |     1    |    0
             ) |     /    |    0
             ? |     4    |    0
             / |     -    |    0
             1 |     3    |    0
             0 |     /    |    0
          

             - |     2    |    0
             1 |     /    |    0
             2 |     1    |    0
             ) |     1    |    0
             ) |     -    |    0
             ? |     7    |    0
             * |     -    |    0
             2 |     3    |    0
             ) |     /    |    0
             ? |     5    |    0
             6 |     2    |    0
             8 |     /    |    0
             0 |     1    |    0
             . |     7    |    0
             4 |     2    |    0
             ) |     /    |    0
             ? |     3    |    0
             2 |     -    |    0
             ? |     1    |    0
             . |     4    |    0
             . |     6    |    0
             ) |     -    |    0
             ? |     6    |    0
             ) |     3    |    0
             - |     1    |    0
             1 |     /    |    0
             . |     2    |    0
             3 |     2    |    0
             0 |     /    |    0
             8 |     1    |    0
          

             5 |     -    |    0
             . |     1    |    0
             5 |     -    |    0
             . |     3    |    0
             9 |     -    |    0
             . |     6    |    0
             ) |     2    |    0
             ) |     /    |    0
             ? |     7    |    0
             2 |     2    |    0
             7 |     /    |    0
             ) |     7    |    0
             ) |     -    |    0
             ? |     6    |    0
             . |     1    |    0
             0 |     -    |    0
             0 |     1    |    0
             ) |     /    |    0
             ? |     9    |    0
             ) |     2    |    0
             / |     /    |    0
             6 |     1    |    0
             ? |     5    |    0
             ) |     -    |    0
             . |     2    |    0
             ) |     -    |    0
             * |     2    |    0
             - |     /    |    0
             4 |     2    |    0
             . |     1    |    0
          

             ? |     3    |    0
             ( |     3    |    0
             - |     /    |    0
             6 |     1    |    0
             ( |     -    |    0
             - |     2    |    0
             2 |     /    |    0
             ) |     7    |    0
             / |     -    |    0
             6 |     3    |    0
             8 |     /    |    0
             . |     5    |    0
             / |     2    |    0
             2 |     /    |    0
             ? |     5    |    0
             ) |     2    |    0
             ) |     /    |    0
             . |     3    |    0
             ) |     2    |    0
             ) |     /    |    0
             ? |     3    |    0
             2 |     -    |    0
             ? |     5    |    0
             . |     1    |    0
             4 |     1    |    0
             ) |     /    |    0
             ? |     4    |    0
             3 |     -    |    0
             / |     1    |    0
             8 |     /    |    0
          

             ? |     7    |    0
             / |     -    |    0
             1 |     1    |    0
             6 |     /    |    0
             . |     2    |    0
             ) |     2    |    0
             ) |     /    |    0
             ? |     3    |    0
             - |     -    |    0
             2 |     1    |    0
             0 |     /    |    0
             ) |     1    |    0
             1 |     2    |    0
             ) |     /    |    0
             ? |     7    |    0
             ) |     -    |    0
             . |     7    |    0
             . |     1    |    0
             9 |     -    |    0
             . |     2    |    0
             ) |     -    |    0
             . |     3    |    0
             7 |     1    |    0
             ) |     /    |    0
             ? |     5    |    0
             4 |     -    |    0
             8 |     2    |    0
             4 |     /    |    0
             0 |     1    |    0
             . |     1    |    0
          

             - |     -    |    0
             2 |     7    |    0
             ) |     /    |    0
             ? |     6    |    0
             7 |     2    |    0
             ) |     /    |    0
             . |     9    |    0
             2 |     1    |    0
             ) |     /    |    0
             ? |     8    |    0
             ) |     -    |    0
             / |     2    |    0
             5 |     /    |    0
             ? |     9    |    0
             ) |     -    |    0
             ? |     6    |    0
             3 |     -    |    0
             ) |     3    |    0
             ) |     /    |    0
             . |     7    |    0
             8 |     3    |    0
             ) |     /    |    0
             . |     2    |    0
             1 |     2    |    0
             ) |     /    |    0
             . |     9    |    0
             ? |     4    |    0
             2 |     -    |    0
             ) |     6    |    0
             5 |     5    |    0
          

             8 |     -    |    0
             . |     6    |    0
             ) |     1    |    0
             ) |     /    |    0
             ? |     4    |    0
             ) |     1    |    0
             ) |     /    |    0
             ? |     4    |    0
             ) |     -    |    0
             ? |     3    |    0
             8 |     -    |    0
             ) |     1    |    0
             ) |     /    |    0
             ? |     6    |    0
             3 |     3    |    0
             ) |     /    |    0
             . |     2    |    0
             ) |     -    |    0
             ? |     8    |    0
             ) |     -    |    0
             ? |     5    |    0
             1 |     -    |    0
             3 |     2    |    0
             2 |     /    |    0
             . |     3    |    0
             ) |     -    |    0
             . |     5    |    0
             ) |     3    |    0
             ) |     /    |    0
             ? |     7    |    0
          

             / |     -    |    0
             6 |     5    |    0
             0 |     /    |    0
             . |     3    |    0
             + |     -    |    0
             1 |     2    |    0
             ) |     /    |    0
             ) |     3    |    0
             ? |     2    |    0
             3 |     -    |    0
             ) |     3    |    0
             ) |     /    |    0
             ? |     5    |    0
             ) |     -    |    0
             ? |     2    |    0
             . |     4    |    0
             / |     -    |    0
             2 |     2    |    0
             ) |     /    |    0
             ? |     7    |    0
             ? |     4    |    0
             ? |     1    |    0
             / |     1    |    0
             4 |     /    |    0
             2 |     6    |    0
             ? |     1    |    0
             + |     -    |    0
             6 |     2    |    0
             1 |     -    |    0
             ? |     5    |    0
          

             6 |     -    |    0
             ) |     2    |    0
             ) |     /    |    0
             ? |     7    |    0
             5 |     -    |    0
             ) |     2    |    0
             3 |     -    |    0
             ? |     3    |    0
             0 |     -    |    0
             ) |     5    |    0
             4 |     1    |    0
             ) |     /    |    0
             . |     6    |    0
             - |     -    |    0
             - |     1    |    0
             2 |     /    |    0
             . |     4    |    0
             7 |     1    |    0
             3 |     /    |    0
             . |     9    |    0
             1 |     2    |    0
             ) |     /    |    0
             . |     5    |    0
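
Every row in the printout above ends in `0`, which suggests the model may be predicting the same tag everywhere. The helper below is a small, framework-free way to check for that kind of collapse. It is a hypothetical addition and not part of the original notebook. It assumes the middle column is the target tag and the last column is the model's prediction, and it reports character-level accuracy together with how much of the output a single predicted tag accounts for.

```python
from collections import Counter

def summarize_predictions(pairs):
    """Summarize (target, prediction) pairs collected from a tagging model.

    Returns character-level accuracy, the most common predicted tag,
    and the fraction of all predictions that tag accounts for.
    """
    if not pairs:
        raise ValueError("no predictions to summarize")
    correct = sum(1 for target, pred in pairs if target == pred)
    counts = Counter(pred for _, pred in pairs)
    top_tag, top_count = counts.most_common(1)[0]
    return {
        "accuracy": correct / len(pairs),
        "top_prediction": top_tag,
        "top_prediction_share": top_count / len(pairs),
    }

# A few (target, prediction) rows copied from the printout above,
# assuming the middle column is the target and the last is the prediction
rows = [("4", "0"), ("8", "0"), ("-", "0"), ("1", "0"), ("/", "0")]
print(summarize_predictions(rows))
# → {'accuracy': 0.0, 'top_prediction': '0', 'top_prediction_share': 1.0}
```

A `top_prediction_share` close to 1.0 means the network outputs one tag regardless of its input. Because the answers were left-padded with `"0"`, that padding tag is a likely candidate.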


## Using the model for inference

In [105]:
new_sentence = "3 + 3"

# Split the expression into single characters (spaces are dropped)
sentence = ps(new_sentence)

# Check that every character is in our vocabulary
for word in sentence:
    assert word in word_to_index, f"unknown character: {word!r}"

# Convert input to a tensor of indices
sentence = convert_to_index_tensor(sentence, word_to_index)

# Compute predictions in evaluation mode, without tracking gradients
model.eval()
with torch.no_grad():
    predictions = model(sentence)
predictions = predictions.squeeze().argmax(dim=1)

# Print results
for word, tag in zip(ps(new_sentence), predictions):
    print(word, "=>", tag_list[tag.item()])

3 => 0
+ => 0
3 => 0
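
The cell above relies on two things. First, `tag_list` must map an index back to its tag string. Second, every input character must already be in `word_to_index`. The sketch below shows both steps, plus the per-position argmax, in plain Python with no PyTorch. It assumes `tag_to_index` was built as in the indexing cell, with indices assigned 0, 1, 2, ... in first-seen order. The helpers `build_tag_list`, `encode`, and `greedy_decode` are hypothetical names. The toy vocabularies and the score matrix are made-up stand-ins for the real vocabularies and the model output. They use `toy_` names so that running this in the notebook does not overwrite the real variables.

```python
def build_tag_list(tag_to_index):
    """Invert tag_to_index so that tag_list[i] is the tag whose index is i."""
    tag_list = [None] * len(tag_to_index)
    for tag, idx in tag_to_index.items():
        tag_list[idx] = tag
    return tag_list

def encode(chars, word_to_index):
    """Map characters to indices, reporting every unknown symbol at once."""
    unknown = sorted({c for c in chars if c not in word_to_index})
    if unknown:
        raise KeyError(f"characters not in vocabulary: {unknown}")
    return [word_to_index[c] for c in chars]

def greedy_decode(scores, tag_list):
    """Pick the highest-scoring tag at each position (per-row argmax)."""
    return [tag_list[max(range(len(row)), key=row.__getitem__)] for row in scores]

# Toy vocabularies standing in for word_to_index / tag_to_index
toy_word_to_index = {"3": 0, "+": 1}
toy_tag_to_index = {"0": 0, "6": 1}
toy_tag_list = build_tag_list(toy_tag_to_index)

chars = list("3+3")
indices = encode(chars, toy_word_to_index)  # [0, 1, 0]
# Made-up per-position scores: one row per character, one column per tag
scores = [[0.1, 2.0], [1.5, 0.2], [0.3, 0.9]]
for ch, tag in zip(chars, greedy_decode(scores, toy_tag_list)):
    print(ch, "=>", tag)
```

Raising one `KeyError` that lists all unknown characters is a design choice here. Unlike the `assert` loop, it still runs when Python is started with `-O`, and it reports every offending symbol at once instead of stopping at the first one.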


Things to try:

- compare with a fully connected network
- compare with a CNN
- compare with a transformer