In [1]:
# https://pytorch.org/tutorials/beginner/nlp/sequence_models_tutorial.html

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [4]:
# Toy POS-tagging corpus: each entry pairs a tokenised sentence with its tags.
training_data = [
    ("The dog ate the apple".split(), ["DET", "NN", "V", "DET", "NN"]),
    ("Everybody read that book".split(), ["NN", "V", "DET", "NN"]),
]

# Vocabulary: every distinct token receives the next free integer id, in order
# of first appearance. Lookup is case-sensitive, so "The" and "the" get
# separate ids.
word2index = {}
for tokens, _ in training_data:
    for token in tokens:
        # setdefault only inserts when the token is unseen, so len(word2index)
        # is always the next unused id at the moment of insertion.
        word2index.setdefault(token, len(word2index))

# Fixed tag inventory and the class index assigned to each tag.
tag2index = {"DET": 0, "NN": 1, "V": 2}

In [5]:
class LSTMTagger(nn.Module):
    """Sequence tagger: embedding lookup -> LSTM -> linear projection to tags.

    Args:
        embedding_dim: Width of each word-embedding vector (also the LSTM
            input size).
        hidden_dim: Size of the LSTM hidden state (also the input size of the
            final linear layer).
        vocab_size: Number of rows in the embedding table.
        tag_size: Number of output tag classes.
    """

    def __init__(self, embedding_dim, hidden_dim, vocab_size, tag_size):
        super().__init__()

        # Layers listed in the order data flows through them in forward().
        self.word_embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)
        self.hidden2tag = nn.Linear(hidden_dim, tag_size)

    def forward(self, seq):
        """Score every position of one sentence.

        Args:
            seq: Tensor of word indices for a single sentence.

        Returns:
            Tensor with ``len(seq)`` rows; each row holds the log-softmax of
            that token's tag scores (taken over dim 1).
        """
        n_tokens = len(seq)

        # Insert a batch axis of size one so the LSTM sees
        # (n_tokens, 1, features).
        lstm_in = self.word_embedding(seq).view(n_tokens, 1, -1)
        lstm_out, _ = self.lstm(lstm_in)

        # Collapse back to one row per token before projecting to tag scores.
        tag_space = self.hidden2tag(lstm_out.view(n_tokens, -1))
        return F.log_softmax(tag_space, dim=1)
    
# Hyperparameters as named constants so they are easy to locate and tune.
EMBEDDING_DIM = 6
HIDDEN_DIM = 16
LEARNING_RATE = 0.1
SEED = 1

# Seed torch's global RNG *before* constructing the model: layer parameters are
# randomly initialised at construction time, so without a fixed seed every
# "Restart & Run All" would start from different weights and print different
# scores.
torch.manual_seed(SEED)

model = LSTMTagger(EMBEDDING_DIM, HIDDEN_DIM, len(word2index), len(tag2index))
# NLLLoss expects log-probabilities, which is what LSTMTagger.forward returns
# (it ends in log_softmax).
loss_f = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE)

In [10]:
def _encode(items, to_index):
    """Map each item through ``to_index`` and pack the ids into a LongTensor."""
    return torch.tensor([to_index[item] for item in items], dtype=torch.long)


def _score(encoded_sentence):
    """Run the model on an encoded sentence without recording gradients."""
    with torch.no_grad():
        return model(encoded_sentence)


N_EPOCHS = 300

# Baseline: tag scores for the first training sentence before any training.
inputs = _encode(training_data[0][0], word2index)
tag_scores = _score(inputs)
print(tag_scores)

# The training set never changes, so convert it to tensors once instead of
# rebuilding the same tensors on every epoch.
encoded_training_data = [
    (_encode(sentence, word2index), _encode(tags, tag2index))
    for sentence, tags in training_data
]

for epoch in range(N_EPOCHS):
    for seq, targets in encoded_training_data:
        # Gradients accumulate across backward() calls; clear them per sample.
        model.zero_grad()
        predicts = model(seq)
        loss = loss_f(predicts, targets)
        loss.backward()
        optimizer.step()

# Same sentence after training, for a before/after comparison of the scores.
tag_scores = _score(inputs)
print(tag_scores)

tensor([[-1.0938, -1.0020, -1.2110],
        [-1.1015, -1.0036, -1.2004],
        [-1.0983, -0.9876, -1.2238],
        [-1.1521, -0.9553, -1.2062],
        [-1.1115, -1.0230, -1.1666]])
tensor([[-2.9054e-02, -4.0877e+00, -4.4347e+00],
        [-4.2514e+00, -2.0291e-02, -5.1426e+00],
        [-3.3132e+00, -6.4645e+00, -3.8697e-02],
        [-2.5116e-02, -4.6546e+00, -4.1809e+00],
        [-5.5187e+00, -6.1245e-03, -6.1684e+00]])
