In [None]:
import torch
from torch import nn
import tools as tl

In [None]:
# Minibatch size handed to the data-loading helper below.
batch_size = 64
# The project helper returns three values, unpacked here as the training
# iterator, the test iterator and the vocabulary used by the later cells.
train_iter, test_iter, vocab = tl.load_data_imdb(batch_size)

In [None]:
class BiRNN(nn.Module):
    """Bidirectional LSTM sequence classifier.

    Token indices are embedded, passed through a (possibly multi-layer)
    bidirectional LSTM, and the LSTM outputs at the first and last time steps
    are concatenated and fed to a linear layer producing one logit per class.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_size: Dimension of each token embedding.
        num_hiddens: Hidden size of the LSTM, per direction.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of output logits. Defaults to 2, the value that
            used to be hard-coded in the decoder.
        **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, num_classes=2, **kwargs):
        super().__init__(**kwargs)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # batch_first is left at its default (False), so the LSTM consumes
        # time-major input; forward() transposes the batch accordingly.
        self.encoder = nn.LSTM(
            input_size=embed_size,
            hidden_size=num_hiddens,
            num_layers=num_layers,
            bidirectional=True,
        )
        # Each time step of a bidirectional LSTM yields 2 * num_hiddens
        # features and two time steps are concatenated, hence 4 * num_hiddens.
        self.decoder = nn.Linear(4 * num_hiddens, num_classes)

    def forward(self, inputs):
        """Compute class logits for a batch of token-index sequences.

        Args:
            inputs: Integer tensor of token indices laid out as
                (batch, time steps).

        Returns:
            Tensor of shape (batch, num_classes) with unnormalised scores.
        """
        # (batch, steps) -> (steps, batch) -> (steps, batch, embed_size)
        embeddings = self.embedding(inputs.T)
        # Re-compacts the LSTM weights so the faster cuDNN path can be used;
        # per the PyTorch docs this is a no-op off-GPU.
        self.encoder.flatten_parameters()
        # outputs: (steps, batch, 2 * num_hiddens); the final (h, c) state is unused.
        outputs, _ = self.encoder(embeddings)
        # Join the first and last time steps along the feature axis.
        encoding = torch.cat((outputs[0], outputs[-1]), dim=-1)
        return self.decoder(encoding)

In [None]:
# Model hyper-parameters.
embed_size = 100
num_hiddens = 100
num_layers = 2
# Devices to train on, as selected by the project helper.
devices = tl.try_all_gpus()
# One embedding row per vocabulary entry.
net = BiRNN(
    vocab_size=len(vocab),
    embed_size=embed_size,
    num_hiddens=num_hiddens,
    num_layers=num_layers,
)

In [None]:
def init_weights(m):
    """Xavier-initialise the weight matrices of ``nn.Linear`` and ``nn.LSTM`` modules.

    Meant to be passed to ``nn.Module.apply``, which calls it once per
    sub-module. Modules of any other type are left untouched, and so are
    bias vectors.

    Args:
        m: The module currently being visited.
    """
    # isinstance (rather than an exact type comparison) also covers subclasses.
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
    elif isinstance(m, nn.LSTM):
        # Public API instead of the private _flat_weights_names/_parameters;
        # recurse=False restricts the walk to the LSTM's own parameters.
        for name, param in m.named_parameters(recurse=False):
            if "weight" in name:
                nn.init.xavier_uniform_(param)


# nn.Module.apply runs init_weights on every sub-module of net and on net
# itself, then returns net, which is what this cell displays.
net.apply(init_weights)

In [None]:
# Load pretrained vectors through the project's TokenEmbedding helper.
# NOTE(review): 'glove.6B.100d' presumably denotes 100-dimensional GloVe
# vectors, which would match embed_size = 100 -- confirm against the helper.
glove_embedding = tl.TokenEmbedding('glove.6B.100d')
# Look up the vectors for the vocabulary's tokens (presumably one row per
# token, in vocabulary index order -- verify against TokenEmbedding).
embeds = glove_embedding[vocab.idx_to_token]
# Displayed as the cell output, as a sanity check on the lookup result.
embeds.shape

In [None]:
# Overwrite the embedding table with the pretrained vectors; the in-place copy
# runs under no_grad so autograd does not record it.
with torch.no_grad():
    net.embedding.weight.copy_(embeds)
# Freeze the embedding table so no gradients are computed for it.
net.embedding.weight.requires_grad = False

In [None]:
# Optimisation hyper-parameters.
lr = 0.01
num_epochs = 5
# reduction="none" keeps one loss value per sample instead of reducing them.
loss = nn.CrossEntropyLoss(reduction="none")
trainer = torch.optim.Adam(params=net.parameters(), lr=lr)
# Hand everything to the project's training loop.
tl.train_ch13(train_iter, test_iter, net, loss, trainer, num_epochs, devices)

In [None]:
# Spot-check the model on a short, clearly positive sentence; the helper's
# return value is shown as the cell output.
tl.predict_sentiment(net, vocab, 'this movie is so great')

In [None]:
# Spot-check the model on a short, clearly negative sentence; the helper's
# return value is shown as the cell output.
tl.predict_sentiment(net, vocab, 'this movie is so bad')