In [1]:
# https://pytorch.org/tutorials/beginner/nlp/sequence_models_tutorial.html

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [11]:
training_data = [("The dog ate the apple".split(), ["DET", "NN", "V", "DET", "NN"]),
                 ("Everybody read that book".split(), ["NN", "V", "DET", "NN"])]
word2index = {}
ch2index = {}

for sent, tags in training_data:
    for word in sent:
        if word not in word2index:
            word2index[word] = len(word2index)
            
        for ch in word:
            if ch not in ch2index:
                ch2index[ch] = len(ch2index)

tag2index = {"DET": 0, "NN": 1, "V": 2}

In [5]:
class LSTMTagger(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, vocab_size, tag_size):
        super(LSTMTagger, self).__init__()
        
        self.word_embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)
        self.hidden2tag = nn.Linear(hidden_dim, tag_size)
        
    def forward(self, seq):
        embeds = self.word_embedding(seq)
        out, _ = self.lstm(embeds.view(len(seq), 1, -1))
        out = self.hidden2tag(out.view(len(seq), -1))
        out = F.log_softmax(out, dim = 1)
        
        return out
    
model = LSTMTagger(6, 16, len(word2index), len(tag2index))
loss_f = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr = 0.1)

In [63]:
class AugmentedTagger(nn.Module):
    def __init__(self, w_embedding_dim, c_embedding_dim, hidden_dim1, hidden_dim2, vocab_size, character_size, tag_size):
        super(AugmentedTagger, self).__init__()
        
        self.word_embedding = nn.Embedding(vocab_size, w_embedding_dim)
        self.ch_embedding = nn.Embedding(character_size, c_embedding_dim)
        self.lstm1 = nn.LSTM(c_embedding_dim, hidden_dim1)
        self.lstm2 = nn.LSTM(w_embedding_dim + hidden_dim1, hidden_dim2)
        self.hidden2tag = nn.Linear(hidden_dim2, tag_size)
        
    def forward(self, seq, ch_seq):
        w_embeds = self.word_embedding(seq)
        w_embeds = w_embeds.view(len(seq), -1)
        embeds = []
        
        for i in range(len(w_embeds)):
            c_embeds = self.ch_embedding(torch.tensor(ch_seq[i], dtype = torch.long))
            _, out = self.lstm1(c_embeds.view(len(ch_seq[i]), 1, -1))
            embeds.append(torch.cat((w_embeds[i], out[1][0][0]), 0))
        
        embeds = torch.cat(tuple(embeds), 0)
        out, _ = self.lstm2(embeds.view(len(seq), 1, -1))
        out = self.hidden2tag(out.view(len(seq), -1))
        out = F.log_softmax(out, dim = 1)
        
        return out
    
model = AugmentedTagger(6, 5, 10, 16, len(word2index), len(ch2index), len(tag2index))
loss_f = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr = 0.1)

In [66]:
with torch.no_grad():
    seq = torch.tensor([word2index[w] for w in training_data[0][0]], dtype = torch.long)
    ch_seq = [tuple([ch2index[ch] for ch in w]) for w in training_data[0][0]]
    tag_scores = model(seq, ch_seq)

    print(tag_scores)
    
for epoch in range(300):
    for sentence, tags in training_data:
        model.zero_grad()
        seq = torch.tensor([word2index[w] for w in sentence], dtype = torch.long)
        ch_seq = [tuple([ch2index[ch] for ch in w]) for w in sentence]
        predicts = model(seq, ch_seq)
        loss = loss_f(predicts, torch.tensor([tag2index[t] for t in tags], dtype = torch.long))
        loss.backward()
        optimizer.step()
        
with torch.no_grad():
    seq = torch.tensor([word2index[w] for w in training_data[0][0]], dtype = torch.long)
    ch_seq = [tuple([ch2index[ch] for ch in w]) for w in training_data[0][0]]
    tag_scores = model(seq, ch_seq)

    print(tag_scores)

tensor([[-1.0206, -1.1109, -1.1701],
        [-1.0233, -1.1051, -1.1731],
        [-1.0073, -1.1193, -1.1767],
        [-1.0184, -1.1153, -1.1679],
        [-1.0024, -1.1194, -1.1825]])
tensor([[-0.0322, -5.6385, -3.5694],
        [-6.8223, -0.0142, -4.3403],
        [-3.6461, -3.6338, -0.0539],
        [-0.0208, -7.0147, -3.9264],
        [-5.1633, -0.0177, -4.4388]])
