In [1]:
!curl -O https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 80.2M  100 80.2M    0     0  4398k      0  0:00:18  0:00:17  0:00:01 5510k2M   69 55.4M    0     0  4094k      0  0:00:20  0:00:13  0:00:07 5411k    0  4444k      0  0:00:18  0:00:18 --:--:-- 5495k


In [9]:
!tar -xf aclImdb_v1.tar.gz

In [None]:
from utils import create_dataloaders
from torch import optim
import torch
import torch.nn as nn

In [3]:
train_dir = "aclImdb/train"
train_dataloader, vocab = create_dataloaders(train_dir)

In [4]:
test_dir = "aclImdb/test"
test_dataloader, vocab = create_dataloaders(test_dir)

In [None]:
from transformer_encoder import TransformerEncoder
from positional_embedding import PositionalEmbedding

class Transformer(nn.Module):
    def __init__(self, embed_dim, dense_dim, num_heads, vocab_size, **kwargs):
        super(Transformer, self).__init__()

        self.embedding = PositionalEmbedding(vocab_size, embed_dim)

        self.encoder = TransformerEncoder(embed_dim, dense_dim, num_heads)

        self.global_max_pool = nn.AdaptiveMaxPool1d(1)
        self.out = nn.Linear(embed_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, text, mask=None):
        embedded = self.embedding(text)
        encoder_output = self.encoder(embedded, mask)
        output = encoder_output.max(dim=1)[0]
        output = output.squeeze(-1)
        output = self.out(output)
        output = self.sigmoid(output)
        return output

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

embed_dim = 128
num_heads = 2
dense_dim = 32

transformer = Transformer(embed_dim, dense_dim, num_heads, 20000).to(device)

rmsprop = optim.RMSprop(params=transformer.parameters(), lr=0.0001)
criterion = nn.BCELoss()

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
for epoch in range(10):
    transformer.train()
    correct_predictions = 0
    total_predictions = 0


    for batch in train_dataloader:
        text, label = batch

        rmsprop.zero_grad()

        output = transformer(text.to(device))

        loss = criterion(output[:, 0], label.to(device).float())

        correct_predictions += (output[:, 0] > 0.5).eq(label.to(device)).sum().item()
        total_predictions += len(label)


        loss.backward()
        rmsprop.step()


    print(f"Epoch: {epoch+1}, Loss: {loss.item()}, Accuracy: {correct_predictions / total_predictions * 100}")


Epoch: 1, Loss: 0.6537891030311584, Accuracy: 53.92
Epoch: 2, Loss: 0.5448966026306152, Accuracy: 63.68000000000001
Epoch: 3, Loss: 0.4857151508331299, Accuracy: 69.82000000000001
Epoch: 4, Loss: 0.455940842628479, Accuracy: 72.404
Epoch: 5, Loss: 0.45676377415657043, Accuracy: 75.02799999999999
Epoch: 6, Loss: 0.355640172958374, Accuracy: 77.02799999999999
Epoch: 7, Loss: 0.3498089909553528, Accuracy: 78.8
Epoch: 8, Loss: 0.5490145087242126, Accuracy: 80.67999999999999
Epoch: 9, Loss: 0.32198524475097656, Accuracy: 82.384
Epoch: 10, Loss: 0.07523089647293091, Accuracy: 83.828


In [9]:
correct_predictions = 0
total_predictions = 0
for batch in test_dataloader:
    text, label = batch

    output = transformer(text.to(device))

    loss = criterion(output[:, 0], label.to(device).float())

    correct_predictions += (output[:, 0] > 0.5).eq(label.to(device)).sum().item()
    total_predictions += len(label)

print(f"Loss: {loss.item()}, Accuracy: {correct_predictions / total_predictions * 100}")

Loss: 1.254289150238037, Accuracy: 54.928


In [None]:
for batch in train_dataloader:
    text, label = batch
    output = transformer(text.to(device))
    for i in range(0, 100):
        print(f"{label[i]} = {output[i,0]}")


1 = 0.3088705241680145
0 = 0.15901578962802887
1 = 0.5373284816741943
1 = 0.9435697793960571
0 = 0.705100953578949
0 = 0.1325273960828781
0 = 0.029428264126181602
1 = 0.9192430973052979
1 = 0.9317619204521179
0 = 0.43019816279411316
1 = 0.9812169075012207
0 = 0.06102919206023216
0 = 0.6255525350570679
0 = 0.34059834480285645
1 = 0.6655703186988831
1 = 0.9709249138832092
0 = 0.004734253976494074
1 = 0.616407573223114
1 = 0.3340715765953064
0 = 0.23141178488731384
0 = 0.4274246096611023
0 = 0.604925274848938
1 = 0.9750867486000061
1 = 0.852275013923645
0 = 0.6010562777519226
1 = 0.9529685974121094
1 = 0.7750461101531982
0 = 0.019227078184485435
1 = 0.9951541423797607
1 = 0.9381968975067139
1 = 0.934755802154541
0 = 0.09102805703878403
1 = 0.7634303569793701
0 = 0.49535441398620605
1 = 0.639639675617218
0 = 0.021112579852342606
0 = 0.050686560571193695
1 = 0.9517002701759338
1 = 0.5856476426124573
1 = 0.4772869348526001
1 = 0.9438878893852234
0 = 0.21540483832359314
1 = 0.9444018602371216