In [1]:
import torch as t
from torchtext.datasets import IMDB
from torchtext import data

In [2]:
# load train and test data

TEXT = data.Field(fix_length=500, lower=True)
LABEL = data.Field(sequential=False)


train, test = IMDB.splits(text_field=TEXT, label_field=LABEL)

TEXT.build_vocab(train)
LABEL.build_vocab(train)

In [3]:
train_iter, test_iter = data.BucketIterator.splits((train, test), batch_size=32)

In [4]:
import torch.nn.functional as F

In [5]:
class RNN(t.nn.Module):
    def __init__(self, vocab_size, hidden_num=32, num_classes=1):
        super(RNN, self).__init__()
        self.vocab_size = vocab_size
        self.hidden_num = hidden_num
        self.num_classes = num_classes
        
        self.embedd = t.nn.Embedding(vocab_size, embedding_dim=100)
        self.lstm1 = t.nn.LSTM(100, 64)
        self.lstm2 = t.nn.LSTM(64, hidden_num)
        self.fc = t.nn.Linear(hidden_num, 1)
        
    def forward(self, inputs):
        x = self.embedd(inputs)
        x, _ = self.lstm1(x)
        x, _ = self.lstm2(x)
        x = x[:, -1, :]
        x = F.tanh(x)
        x = self.fc(x)
        return t.sigmoid(x)

In [6]:
VOCAB_SIZE = len(TEXT.vocab.stoi)
EPOCHES = 5

rnn = RNN(VOCAB_SIZE)
rnn

RNN(
  (embedd): Embedding(251639, 100)
  (lstm1): LSTM(100, 64)
  (lstm2): LSTM(64, 32)
  (fc): Linear(in_features=32, out_features=1, bias=True)
)

In [7]:
loss_func = t.nn.BCELoss()

optimizer = t.optim.Adam(rnn.parameters(), lr=1e-5)

In [8]:
# batch.text 32x500
# batch.label 32
running_loss = 0
for idx, batch in enumerate(train_iter):
    y_pred = rnn(batch.text.T).squeeze(1)
    loss = loss_func(y_pred, batch.label.float()-1)
    
    optimizer.zero_grad()
    loss.backward()
    
    optimizer.step()
    
    running_loss += loss.item()
    
    if idx % 10 == 9:
        print(f"idx: {idx} loss: {running_loss / 10}")
        running_loss = 0



idx: 9 loss: 0.6923101365566253
idx: 19 loss: 0.6934561550617218
idx: 29 loss: 0.6934484541416168
idx: 39 loss: 0.6935383796691894
idx: 49 loss: 0.6915086388587952
idx: 59 loss: 0.692417049407959
idx: 69 loss: 0.6941934585571289
idx: 79 loss: 0.6925857484340667
idx: 89 loss: 0.6926884472370147
idx: 99 loss: 0.6925378799438476
idx: 109 loss: 0.692298811674118
idx: 119 loss: 0.693765377998352
idx: 129 loss: 0.6935660123825074
idx: 139 loss: 0.6920570611953736
idx: 149 loss: 0.6934839069843293
idx: 159 loss: 0.6929186522960663
idx: 169 loss: 0.691523265838623
idx: 179 loss: 0.6932915925979615
idx: 189 loss: 0.6942057430744171
idx: 199 loss: 0.6948730409145355
idx: 209 loss: 0.6941698670387269


KeyboardInterrupt: 

In [None]:
test_data = iter(test_iter)

In [None]:
test_batch = next(test_data)

In [None]:
y_hat = rnn(test_batch.text.T)

In [None]:
y_hat, test_batch.label