In [1]:
# https://pytorch.org/tutorials/beginner/nlp/deep_learning_tutorial.html

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
# Training corpus: each entry is (list of whitespace-separated tokens, language label).
data = [
    (sentence.split(), language)
    for sentence, language in (
        ("me gusta comer en la cafeteria", "SPANISH"),
        ("Give it to me", "ENGLISH"),
        ("No creo que sea una buena idea", "SPANISH"),
        ("No it is not a good idea to get lost at sea", "ENGLISH"),
    )
]

# Held-out sentences in the same (tokens, label) format.
test_data = [
    (sentence.split(), language)
    for sentence, language in (
        ("Yo creo que si", "SPANISH"),
        ("it is lost on me", "ENGLISH"),
    )
]

In [3]:
# Vocabulary: every distinct token in the training data gets the next free
# integer id, in order of first appearance. Only `data` is scanned here, so
# tokens that occur solely in `test_data` have no entry.
word2index = {}

for seq, _ in data:
    for word in seq:
        # setdefault inserts only when `word` is new; at that moment
        # len(word2index) is exactly the next unused id.
        word2index.setdefault(word, len(word2index))

In [9]:
class BoW(nn.Module):
    """Bag-of-words classifier: a single linear layer followed by log-softmax."""

    def __init__(self, num_labels, vocab_size):
        """Create the linear layer.

        Args:
            num_labels: number of output classes (output width of the layer).
            vocab_size: length of the bag-of-words input vector (input width).
        """
        super().__init__()
        self.linear = nn.Linear(in_features=vocab_size, out_features=num_labels)

    def forward(self, bow_vec):
        """Return log-probabilities over the labels for `bow_vec`.

        The log-softmax is taken over dim 1, so the input needs a leading
        batch dimension, i.e. shape (batch, vocab_size).
        """
        scores = self.linear(bow_vec)
        return F.log_softmax(scores, dim=1)

def make_bow_vec(seq, word2index):
    """Turn a token sequence into a bag-of-words count vector.

    Args:
        seq: iterable of tokens.
        word2index: mapping from token to its position in the vector.

    Returns:
        A tensor of shape (1, len(word2index)) in which entry i holds the
        number of times the token with index i occurs in `seq`. Tokens that
        are missing from `word2index` are ignored.
    """
    counts = torch.zeros(len(word2index))

    # Look up the position of every known token; unknown tokens are filtered out.
    known_positions = (word2index[token] for token in seq if token in word2index)
    for position in known_positions:
        counts[position] += 1

    # Add a leading batch dimension of size 1.
    return counts.view(1, -1)

def make_target(label, label2index):
    """Wrap the class index of `label` in a one-element LongTensor.

    Args:
        label: key to look up in `label2index`.
        label2index: mapping from label to integer class index.

    Returns:
        A LongTensor of shape (1,) holding `label2index[label]`.

    Raises:
        KeyError: if `label` is not a key of `label2index`.
    """
    class_index = label2index[label]
    return torch.LongTensor([class_index])

In [13]:
# Fix the RNG seed before the model is built: nn.Linear initialises its
# weights randomly, so without a seed the numbers printed below (and the
# starting point of any later training) change on every kernel restart.
torch.manual_seed(1)

# Two output labels; the input width equals the vocabulary size.
clf = BoW(2, len(word2index))

# Sanity check on the still-untrained model: log-probabilities for the first
# training sentence. no_grad() because nothing here needs gradients.
with torch.no_grad():
    sample = data[0]
    bow_vec = make_bow_vec(sample[0], word2index)
    print(clf(bow_vec))

tensor([[-0.4786, -0.9666]])
tensor([0])


In [18]:
# Hyperparameters, named so they are easy to find and tune.
NUM_EPOCHS = 100
LEARNING_RATE = 0.1

# Label -> class index. Built once here rather than on every training step.
label2index = {'SPANISH': 0, 'ENGLISH': 1}

# NLLLoss takes log-probabilities, which is what BoW.forward returns
# (it ends in log_softmax).
loss_f = nn.NLLLoss()
optimizer = optim.SGD(clf.parameters(), lr=LEARNING_RATE)

# NOTE: this updates `clf` in place, so re-running the cell continues training
# from the current weights instead of starting over.
for epoch in range(NUM_EPOCHS):
    for sample, label in data:
        # Gradients accumulate across backward() calls; clear them per step.
        clf.zero_grad()

        bow_vec = make_bow_vec(sample, word2index)
        target = make_target(label, label2index)

        log_probs = clf(bow_vec)
        loss = loss_f(log_probs, target)
        loss.backward()
        optimizer.step()
        
# Print the model's log-probabilities for each held-out sentence.
# Gradient tracking is switched off because no parameters are updated here.
with torch.no_grad():
    for sample, label in test_data:
        log_probs = clf(make_bow_vec(sample, word2index))
        print(log_probs)

tensor([[-0.1664, -1.8752]])
tensor([[-2.6157, -0.0759]])
