In [81]:
import random
from collections import Counter

import torch
import torch.nn.functional as F
import torchtext
from torch import nn
from torch.utils.data import DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.datasets import IMDB
from torchtext.vocab import Vocab


In [99]:
# Hyperparameters and run configuration shared by the cells below.
seed = 42            # fixed seed so repeated runs start from the same RNG state
batch_size = 128     # examples per mini-batch
embedding_dim = 100  # size of each token embedding vector
hidden_dim = 256     # size of the RNN hidden state
output_dim = 1       # one logit per example (binary classification)
num_epochs = 3       # full passes over the training data

# Seed the Python and torch RNGs before any data shuffling or weight
# initialisation happens in later cells.
# NOTE(review): assumes the data iterators draw their shuffle order from
# Python's `random` module - confirm for the installed torchtext version.
random.seed(seed)
torch.manual_seed(seed)

In [67]:
# Load IMDB and build the vocabulary and batch iterators.
# NOTE(review): `torchtext.data.Field` / `BucketIterator` belong to the legacy
# torchtext API; newer torchtext releases moved them to `torchtext.legacy` and
# later dropped them - pin the torchtext version this notebook was run with.
tokenizer = get_tokenizer('basic_english')

# TEXT tokenizes reviews with the tokenizer above; LABEL yields float labels,
# which is the dtype the BCE-with-logits criterion used later expects.
TEXT = torchtext.data.Field(tokenize=tokenizer)
LABEL = torchtext.data.LabelField(dtype=torch.float)
train_data, test_data = IMDB.splits(TEXT, LABEL)

# Vocabularies are built from the training split only.
TEXT.build_vocab(train_data, min_freq=1)
LABEL.build_vocab(train_data)

# NOTE(review): with the Field defaults, `batch.text` is presumably shaped
# (seq_len, batch) and padded per batch - confirm against the installed version.
train_loader, test_loader = torchtext.data.BucketIterator.splits(
    (train_data, test_data), batch_size=batch_size, shuffle=True)

# Number of embedding rows the model needs.
# `output_dim` is deliberately not re-assigned here: it is defined once in the
# configuration cell, and a second definition would silently shadow edits there.
vocab_size = len(TEXT.vocab)


In [85]:
class RNNModel(nn.Module):
    """Embedding -> single-layer RNN -> linear head for sequence classification.

    Args:
        vocab_size: number of rows in the embedding table.
        embedding_dim: size of each token embedding.
        hidden_dim: size of the RNN hidden state.
        output_dim: number of output logits per sequence.
        padding_idx: optional index of the padding token. When given, that
            embedding row is kept at zero and receives no gradient. Defaults
            to ``None`` (no special padding row), matching the original model.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim,
                 padding_idx=None):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim,
                                      padding_idx=padding_idx)
        # Default nn.RNN layout: inputs are (seq_len, batch, embedding_dim).
        self.rnn = nn.RNN(embedding_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, text, text_lengths=None):
        """Return one row of logits per sequence.

        Args:
            text: integer token ids laid out as (seq_len, batch).
            text_lengths: optional true (unpadded) length of every sequence in
                the batch, as a 1-D tensor or list. When provided, the
                prediction is made from the hidden state at each sequence's
                last real token instead of at the last (possibly padded) step.

        Returns:
            Tensor of shape (batch, output_dim) containing raw logits.
        """
        embedded = self.embedding(text)

        if text_lengths is None:
            # Original behaviour: classify from the final time step.
            output, _ = self.rnn(embedded)
            return self.fc(output[-1])

        # pack_padded_sequence requires the lengths on the CPU.
        lengths = torch.as_tensor(text_lengths).cpu()
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths, enforce_sorted=False)
        # For packed input, `hidden` holds the state at each sequence's own
        # last valid step, already restored to the original batch order.
        _, hidden = self.rnn(packed)
        return self.fc(hidden[-1])

In [86]:
# Instantiate the classifier from the notebook-level hyperparameters.
model = RNNModel(
    vocab_size=vocab_size,
    embedding_dim=embedding_dim,
    hidden_dim=hidden_dim,
    output_dim=output_dim,
)

# Binary cross-entropy on raw logits (the sigmoid is applied inside the loss).
criterion = nn.BCEWithLogitsLoss()

# Adam over all model parameters, with the optimizer's default settings.
optimizer = torch.optim.Adam(params=model.parameters())

In [87]:
def accuracy(predictions, labels):
    """Fraction of predictions that match the labels.

    Args:
        predictions: raw logits; a sigmoid is applied and the result is
            rounded to obtain the predicted class.
        labels: tensor of target values compared element-wise against the
            rounded predictions.

    Returns:
        Scalar tensor: number of matches divided by the length of the first
        dimension.
    """
    probabilities = torch.sigmoid(predictions)
    hits = torch.round(probabilities).eq(labels).float()
    return hits.sum() / len(hits)


In [102]:
# Train for `num_epochs` passes over `train_loader` and report, per epoch, the
# mean loss and accuracy over all training examples seen in that epoch.
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    epoch_acc = 0.0
    num_examples = 0
    for batch in train_loader:
        optimizer.zero_grad()
        text = batch.text.long()
        # The model returns (batch, 1); drop the trailing dim to match labels.
        predictions = model(text).squeeze(1)
        loss = criterion(predictions, batch.label)
        acc = accuracy(predictions, batch.label)
        loss.backward()
        optimizer.step()

        # `loss` and `acc` are per-batch means. Weight them by the batch size
        # so a smaller final batch does not skew the epoch-level averages.
        batch_examples = batch.label.size(0)
        epoch_loss += loss.item() * batch_examples
        epoch_acc += acc.item() * batch_examples
        num_examples += batch_examples

    # Guard against an empty loader so the report never divides by zero.
    denominator = max(num_examples, 1)
    print(f'Epoch {epoch + 1}:')
    print(f'\tLoss: {epoch_loss / denominator:.4f}')
    print(f'\tAccuracy: {epoch_acc / denominator:.4f}')


KeyboardInterrupt: 