In [4]:
from models import text_cnn

In [5]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [6]:
# Toy sentiment corpus: every sentence is exactly three whitespace-separated words.
sentences = ["i love you", "he loves me", "she likes baseball", "i hate you", "sorry for that", "this is awful"]
labels = [1, 1, 1, 0, 0, 0]  # 1 is good, 0 is not good.

# Vocabulary: unique words mapped to contiguous integer ids.
# The unique words are sorted so the word -> id mapping is identical on every run.
# Iterating a plain `set` of strings follows str hash order, which is randomised per
# interpreter session (PYTHONHASHSEED), so each word's embedding row could otherwise change.
word_list = sorted(set(" ".join(sentences).split()))
word_dict = {w: i for i, w in enumerate(word_list)}

In [7]:
# Hyperparameters passed to text_cnn.TextCNN.
SEED = 0
# Fix the global torch RNG so the random weight initialisation is the same on every run.
torch.manual_seed(SEED)

num_filters = 3
filter_sizes = [2, 2, 2]  # one convolution per entry (the model repr shows 3 x Conv2d)
vocab_size = len(word_dict)
embedding_size = 2
sequence_length = 3  # every sentence in `sentences` has exactly three words
num_classes = 2  # labels are 1 (good) / 0 (not good)

In [8]:
# Instantiate the classifier; the hyperparameters are passed positionally.
model = text_cnn.TextCNN(
    num_filters,
    filter_sizes,
    vocab_size,
    embedding_size,
    sequence_length,
    num_classes,
)
model  # bare last expression: the notebook renders the module repr

TextCNN(
  (W): Embedding(16, 2)
  (Weight): Linear(in_features=9, out_features=2, bias=False)
  (filter_list): ModuleList(
    (0-2): 3 x Conv2d(1, 3, kernel_size=(2, 2), stride=(1, 1))
  )
)

In [10]:
def encode(sentence):
    """Map a whitespace-tokenised sentence to a list of vocabulary ids.

    Raises KeyError for any word missing from `word_dict` (the vocabulary has no
    unknown-word token).
    """
    return [word_dict[word] for word in sentence.split()]


# NOTE(review): with 5 epochs at this learning rate the recorded loss only moves from
# ~0.739 to ~0.732, so the prediction below is essentially untrained; raise num_epochs
# for a meaningful result.
num_epochs = 5
learning_rate = 0.001

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Build tensors straight from nested Python lists. Wrapping each row in an np.ndarray
# and using the legacy torch.LongTensor constructor is the slow path PyTorch warns about.
inputs = torch.tensor([encode(sen) for sen in sentences], dtype=torch.long)
# Class indices (not one-hot), which is the target format CrossEntropyLoss takes.
targets = torch.tensor(labels, dtype=torch.long)

# Training
for epoch in range(num_epochs):
    optimizer.zero_grad()
    output = model(inputs)

    # output : [batch_size, num_classes], targets : [batch_size] (LongTensor, not one-hot).
    # CrossEntropyLoss applies log-softmax itself, so `output` should be raw scores
    # -- TODO confirm TextCNN.forward does not already apply softmax.
    loss = criterion(output, targets)
    print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss.item()))

    loss.backward()
    optimizer.step()

test_text = 'sorry hate you'
tests = [encode(test_text)]  # a batch containing one sentence
test_batch = torch.tensor(tests, dtype=torch.long)

# Predict: no autograd graph is needed, so use no_grad instead of the discouraged `.data`.
# NOTE(review): the model stays in train mode; the repr shows no dropout/batch-norm modules,
# but call model.eval() first if TextCNN.forward uses dropout functionally.
with torch.no_grad():
    predict = model(test_batch).argmax(dim=1, keepdim=True)  # [1, 1] predicted class index

if predict.item() == 0:
    print(test_text,"is Bad Mean...")
else:
    print(test_text,"is Good Mean!!")

Epoch: 0001 cost = 0.739378
Epoch: 0002 cost = 0.737598
Epoch: 0003 cost = 0.735840
Epoch: 0004 cost = 0.734106
Epoch: 0005 cost = 0.732395
sorry hate you is Bad Mean...
