In [16]:
import numpy as np
import pandas as pd
import gensim.downloader as api
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import re
import random
from string import punctuation
from tqdm import tqdm

In [18]:
np.random.seed(42)
random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [2]:
train_df = pd.read_csv("train.csv")
test_df = pd.read_csv("test.csv")
train_df.head()

Unnamed: 0,topic,the concept of the topic,candidate,candidate masked,label,wikipedia article name,wikipedia url
0,We should limit executive compensation,executive compensation,A say on pay - a non-binding vote of the gener...,A say on pay - a non-binding vote of the gener...,0,Executive pay,https://en.wikipedia.org/wiki/Executive_pay
1,We should limit executive compensation,executive compensation,"A February 2009 report, published by the Insti...","A February 2009 report, published by the Insti...",1,Executive pay,https://en.wikipedia.org/wiki/Executive_pay
2,We should limit executive compensation,executive compensation,The Financial Crisis has had a relatively smal...,The Financial Crisis has had a relatively smal...,0,Executive pay,https://en.wikipedia.org/wiki/Executive_pay
3,We should limit executive compensation,executive compensation,"1990-1992 Lineberger Cancer Center, SPA person...","1990-1992 Lineberger Cancer Center, SPA person...",0,Edison Liu,https://en.wikipedia.org/wiki/Edison_Liu
4,We should limit executive compensation,executive compensation,Countering the public uproar over excessive ex...,Countering the public uproar over excessive TO...,0,Jack Welch,https://en.wikipedia.org/wiki/Jack_Welch


In [3]:
word_re = re.compile(r"\b[a-z]{2,}\b")

def tokenize(text):
    processed_text = "".join(ch for ch in text.lower() if ch not in punctuation)
    processed_text = processed_text.replace("\n", " ")
    return word_re.findall(processed_text)

In [4]:
all_train = train_df['candidate'].apply(tokenize)
test = test_df['candidate'].apply(tokenize)
all_train

0       [say, on, pay, nonbinding, vote, of, the, gene...
1       [february, report, published, by, the, institu...
2       [the, financial, crisis, has, had, relatively,...
3       [lineberger, cancer, center, spa, personnel, c...
4       [countering, the, public, uproar, over, excess...
                              ...                        
4060    [new, zealand, prime, minister, john, key, sai...
4061    [the, international, monarchist, league, found...
4062    [on, march, three, shiite, groups, formed, the...
4063    [the, gulf, states, were, especially, inclined...
4064    [in, may, poll, by, angus, reid, found, that, ...
Name: candidate, Length: 4065, dtype: object

In [5]:
embed_size = 300
wv = api.load(f'glove-wiki-gigaword-{embed_size}')

In [6]:
vocab = pd.Series(wv.key_to_index) + 1
encoded_train = all_train.apply(lambda x: [vocab.get(w, 0) for w in x])
encoded_test = test.apply(lambda x: [vocab.get(w, 0) for w in x])
encoded_train

0       [204, 14, 643, 27374, 539, 4, 1, 217, 287, 5, ...
1       [618, 256, 736, 22, 1, 1065, 11, 528, 1380, 21...
2       [1, 396, 722, 32, 41, 2224, 358, 1051, 1262, 1...
3              [249139, 1648, 314, 7213, 2076, 3532, 447]
4       [22804, 1, 199, 14954, 75, 6341, 617, 643, 145...
                              ...                        
4060    [51, 1273, 336, 142, 280, 639, 17, 19, 1055, 1...
4061    [1, 147, 47401, 293, 1298, 7, 43, 32, 52, 192,...
4062    [14, 305, 88, 2582, 504, 1348, 1, 1127, 11, 23...
4063    [1, 1666, 113, 36, 859, 11730, 5, 1656, 30, 18...
4064    [7, 108, 1888, 22, 16293, 5700, 239, 13, 57, 7...
Name: candidate, Length: 4065, dtype: object

In [7]:
encoded_labels = train_df['label'].to_numpy()
encoded_labels_test = test_df['label'].to_numpy()
encoded_labels

array([0, 1, 0, ..., 1, 1, 0])

In [8]:
def pad_data(encoded_data, seq_length):
    data = encoded_data[:seq_length]
    padded_data = np.zeros(seq_length, dtype=int)
    padded_data[-len(data):] = data
    return padded_data

padded_train = np.stack(encoded_train.apply(pad_data, seq_length=81))
padded_test = np.stack(encoded_test.apply(pad_data, seq_length=81))
padded_train

array([[    0,     0,     0, ...,   224,     4,   253],
       [    0,     0,     0, ...,     6,  3568, 19812],
       [    0,     0,     0, ...,    14,   617,   643],
       ...,
       [    0,     0,     0, ...,   591,   875, 29127],
       [    0,     0,     0, ...,     4,    79, 29127],
       [    0,     0,     0, ...,  4828,   915,   101]])

In [9]:
train_ratio = 0.85
total = padded_train.shape[0]
train_cutoff = int(total * train_ratio)

train_x, train_y = torch.from_numpy(padded_train[:train_cutoff]), torch.from_numpy(encoded_labels[:train_cutoff])
valid_x, valid_y = torch.from_numpy(padded_train[train_cutoff:]), torch.from_numpy(encoded_labels[train_cutoff:])

test_x, test_y = torch.from_numpy(padded_test), torch.from_numpy(encoded_labels_test)

from torch.utils.data import TensorDataset, DataLoader

train_data = TensorDataset(train_x, train_y)
valid_data = TensorDataset(valid_x, valid_y)
test_data = TensorDataset(test_x, test_y)

BATCH_SIZE = 32
train_loader = DataLoader(train_data, batch_size = BATCH_SIZE, shuffle = True)
valid_loader = DataLoader(valid_data, batch_size = BATCH_SIZE, shuffle = True)
test_loader = DataLoader(test_data, batch_size = BATCH_SIZE, shuffle = True)

In [20]:
class SentimentLSTM(nn.Module):

    def __init__(self, n_vocab, n_embed, n_hidden, n_output, n_layers, attention_size, drop_p = 0.85):
        super().__init__()
        # params: "n_" means dimension

        self.n_vocab = n_vocab
        self.n_layers = n_layers
        self.n_hidden = n_hidden

        self.embedding = nn.Embedding.from_pretrained(wv_torch_with_pad)
        self.lstm = nn.LSTM(n_embed, n_hidden, n_layers, bidirectional=True, batch_first = True, dropout = drop_p)
        self.fc = nn.Linear(n_hidden*2, n_output)

        self.attention = nn.Linear(n_hidden * 2, attention_size)
        self.attention_combine = nn.Linear(attention_size, 1, bias=False)

        self.dropout = nn.Dropout(drop_p)
        self.sigmoid = nn.Sigmoid()
        self.__init_linear()


    def forward (self, input_words, hidden_state=None):
        if hidden_state is None:
            h, c = self.init_hidden(input_words.shape[0])
        else:
            h, c = hidden_state
        x = self.embedding(input_words)
        lstm_out, (h, c) = self.lstm(x, (h, c))

        # Attention mechanism
        attention_weights = torch.tanh(self.attention(lstm_out))  # Get attention weights
        attention_weights = F.softmax(self.attention_combine(attention_weights), dim=1)
        context_vector = torch.sum(attention_weights * lstm_out, dim=1)

        # Classification
        fc_out = self.fc(context_vector)
        sigmoid_out = self.sigmoid(fc_out)

        return sigmoid_out, (h, c)


    def __init_linear(self):
        self.fc.weight.data.normal_(0.0, 1/np.sqrt(self.n_hidden * 2))
        self.fc.bias.data.fill_(0)


    def init_hidden(self, batch_size):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        weights = next(self.parameters()).data
        h = (weights.new(self.n_layers * 2, batch_size, self.n_hidden).zero_().to(device),
             weights.new(self.n_layers * 2, batch_size, self.n_hidden).zero_().to(device))

        return h

In [21]:
# torch.cat
wv_torch = torch.from_numpy(wv.vectors)
pad = torch.zeros((1, wv_torch.shape[1]))
wv_torch_with_pad = torch.cat((pad, wv_torch), dim=0)

In [31]:
n_vocab = wv_torch_with_pad.shape[0]
n_embed = wv_torch_with_pad.shape[1]
n_hidden = 128
n_output = 1
n_layers = 2
attention_size = 100

net = SentimentLSTM(n_vocab, n_embed, n_hidden, n_output, n_layers, attention_size)

criterion = nn.BCELoss()
optimizer = optim.Adam(net.parameters(), lr = 0.001, weight_decay=1e-5)

In [32]:
print_every = 50
step = 0
n_epochs = 12
clip = 1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net.to(device)
net.train()

val_losses = []
val_accuracy = []
for epoch in tqdm(range(n_epochs)):
    for inputs, labels in tqdm(train_loader):
        step += 1
        inputs, labels = inputs.to(device), labels.to(device)

        output, t = net(inputs)
        loss = criterion(output.squeeze(), labels.float())
        loss.backward()
        nn.utils.clip_grad_norm_(net.parameters(), clip)
        optimizer.step()
        optimizer.zero_grad()

        if (step % print_every) == 0:
            ######################
            ##### VALIDATION #####
            ######################
            net.eval()
            valid_losses = []
            val_accuracy = []
            v_h = net.init_hidden(BATCH_SIZE)

            for v_inputs, v_labels in valid_loader:
                v_inputs, v_labels = inputs.to(device), labels.to(device)

                v_output, _ = net(v_inputs)
                v_loss = criterion(v_output.squeeze(), v_labels.float())
                valid_losses.append(v_loss.item())

                predictions = torch.round(torch.sigmoid(v_output.squeeze()))

            print("Epoch: {}/{}".format((epoch+1), n_epochs),
                  "Step: {}".format(step),
                  "Training Loss: {:.4f}".format(loss.item()),
                  "Validation Loss: {:.4f}".format(np.mean(valid_losses)))
            net.train()

  0%|                                                    | 0/12 [00:00<?, ?it/s]
  0%|                                                   | 0/108 [00:00<?, ?it/s][A
  1%|▍                                          | 1/108 [00:00<00:15,  6.88it/s][A
  3%|█▏                                         | 3/108 [00:00<00:09, 10.97it/s][A
  5%|█▉                                         | 5/108 [00:00<00:08, 12.47it/s][A
  6%|██▊                                        | 7/108 [00:00<00:07, 13.40it/s][A
  8%|███▌                                       | 9/108 [00:00<00:07, 13.90it/s][A
 10%|████▎                                     | 11/108 [00:00<00:06, 14.32it/s][A
 12%|█████                                     | 13/108 [00:00<00:06, 14.50it/s][A
 14%|█████▊                                    | 15/108 [00:01<00:06, 14.61it/s][A
 16%|██████▌                                   | 17/108 [00:01<00:06, 14.58it/s][A
 18%|███████▍                                  | 19/108 [00:01<00:06, 14.50it/s

Epoch: 1/12 Step: 50 Training Loss: 0.5553 Validation Loss: 0.5419



 49%|████████████████████▌                     | 53/108 [00:04<00:07,  7.01it/s][A
 51%|█████████████████████▍                    | 55/108 [00:04<00:06,  8.23it/s][A
 53%|██████████████████████▏                   | 57/108 [00:04<00:05,  9.46it/s][A
 55%|██████████████████████▉                   | 59/108 [00:04<00:04, 10.59it/s][A
 56%|███████████████████████▋                  | 61/108 [00:04<00:04, 11.59it/s][A
 58%|████████████████████████▌                 | 63/108 [00:05<00:03, 12.32it/s][A
 60%|█████████████████████████▎                | 65/108 [00:05<00:03, 13.08it/s][A
 62%|██████████████████████████                | 67/108 [00:05<00:03, 13.55it/s][A
 64%|██████████████████████████▊               | 69/108 [00:05<00:02, 13.96it/s][A
 66%|███████████████████████████▌              | 71/108 [00:05<00:02, 14.15it/s][A
 68%|████████████████████████████▍             | 73/108 [00:05<00:02, 14.28it/s][A
 69%|█████████████████████████████▏            | 75/108 [00:05<00:02, 14.41

Epoch: 1/12 Step: 100 Training Loss: 0.5241 Validation Loss: 0.5124



 97%|███████████████████████████████████████▊ | 105/108 [00:08<00:00,  8.62it/s][A
100%|█████████████████████████████████████████| 108/108 [00:08<00:00, 12.21it/s]
  8%|███▋                                        | 1/12 [00:08<01:37,  8.85s/it]
  0%|                                                   | 0/108 [00:00<?, ?it/s][A
  2%|▊                                          | 2/108 [00:00<00:06, 15.30it/s][A
  4%|█▌                                         | 4/108 [00:00<00:07, 14.82it/s][A
  6%|██▍                                        | 6/108 [00:00<00:07, 13.90it/s][A
  7%|███▏                                       | 8/108 [00:00<00:07, 14.16it/s][A
  9%|███▉                                      | 10/108 [00:00<00:06, 14.16it/s][A
 11%|████▋                                     | 12/108 [00:00<00:06, 14.47it/s][A
 13%|█████▍                                    | 14/108 [00:00<00:06, 14.65it/s][A
 15%|██████▏                                   | 16/108 [00:01<00:06, 14.77it/s]

Epoch: 2/12 Step: 150 Training Loss: 0.4744 Validation Loss: 0.4428



 43%|█████████████████▉                        | 46/108 [00:03<00:07,  8.68it/s][A
 44%|██████████████████▋                       | 48/108 [00:03<00:06,  9.85it/s][A
 46%|███████████████████▍                      | 50/108 [00:04<00:05, 10.89it/s][A
 48%|████████████████████▏                     | 52/108 [00:04<00:04, 11.82it/s][A
 50%|█████████████████████                     | 54/108 [00:04<00:04, 12.59it/s][A
 52%|█████████████████████▊                    | 56/108 [00:04<00:04, 12.93it/s][A
 54%|██████████████████████▌                   | 58/108 [00:04<00:03, 13.49it/s][A
 56%|███████████████████████▎                  | 60/108 [00:04<00:03, 13.96it/s][A
 57%|████████████████████████                  | 62/108 [00:04<00:03, 14.32it/s][A
 59%|████████████████████████▉                 | 64/108 [00:04<00:03, 14.63it/s][A
 61%|█████████████████████████▋                | 66/108 [00:05<00:02, 14.44it/s][A
 63%|██████████████████████████▍               | 68/108 [00:05<00:02, 14.21

Epoch: 2/12 Step: 200 Training Loss: 0.4897 Validation Loss: 0.4861



 89%|█████████████████████████████████████▎    | 96/108 [00:07<00:01,  8.43it/s][A
 91%|██████████████████████████████████████    | 98/108 [00:08<00:01,  9.70it/s][A
 93%|█████████████████████████████████████▉   | 100/108 [00:08<00:00, 10.80it/s][A
 94%|██████████████████████████████████████▋  | 102/108 [00:08<00:00, 11.78it/s][A
 96%|███████████████████████████████████████▍ | 104/108 [00:08<00:00, 12.68it/s][A
 98%|████████████████████████████████████████▏| 106/108 [00:08<00:00, 13.36it/s][A
100%|█████████████████████████████████████████| 108/108 [00:08<00:00, 12.44it/s]
 17%|███████▎                                    | 2/12 [00:17<01:27,  8.75s/it]
  0%|                                                   | 0/108 [00:00<?, ?it/s][A
  2%|▊                                          | 2/108 [00:00<00:06, 15.37it/s][A
  4%|█▌                                         | 4/108 [00:00<00:06, 15.28it/s][A
  6%|██▍                                        | 6/108 [00:00<00:06, 15.08it/s]

Epoch: 3/12 Step: 250 Training Loss: 0.4404 Validation Loss: 0.3599



 35%|██████████████▊                           | 38/108 [00:03<00:08,  8.60it/s][A
 37%|███████████████▌                          | 40/108 [00:03<00:06,  9.84it/s][A
 39%|████████████████▎                         | 42/108 [00:03<00:06, 10.93it/s][A
 41%|█████████████████                         | 44/108 [00:03<00:05, 11.84it/s][A
 43%|█████████████████▉                        | 46/108 [00:03<00:04, 12.64it/s][A
 44%|██████████████████▋                       | 48/108 [00:03<00:04, 13.32it/s][A
 46%|███████████████████▍                      | 50/108 [00:04<00:04, 13.44it/s][A
 48%|████████████████████▏                     | 52/108 [00:04<00:04, 13.73it/s][A
 50%|█████████████████████                     | 54/108 [00:04<00:03, 14.07it/s][A
 52%|█████████████████████▊                    | 56/108 [00:04<00:03, 14.37it/s][A
 54%|██████████████████████▌                   | 58/108 [00:04<00:03, 14.53it/s][A
 56%|███████████████████████▎                  | 60/108 [00:04<00:03, 14.73

Epoch: 3/12 Step: 300 Training Loss: 0.4536 Validation Loss: 0.4699



 81%|██████████████████████████████████▏       | 88/108 [00:07<00:02,  8.65it/s][A
 83%|███████████████████████████████████       | 90/108 [00:07<00:01,  9.91it/s][A
 85%|███████████████████████████████████▊      | 92/108 [00:07<00:01, 10.96it/s][A
 87%|████████████████████████████████████▌     | 94/108 [00:07<00:01, 11.84it/s][A
 89%|█████████████████████████████████████▎    | 96/108 [00:07<00:00, 12.50it/s][A
 91%|██████████████████████████████████████    | 98/108 [00:08<00:00, 12.92it/s][A
 93%|█████████████████████████████████████▉   | 100/108 [00:08<00:00, 13.19it/s][A
 94%|██████████████████████████████████████▋  | 102/108 [00:08<00:00, 13.46it/s][A
 96%|███████████████████████████████████████▍ | 104/108 [00:08<00:00, 13.73it/s][A
 98%|████████████████████████████████████████▏| 106/108 [00:08<00:00, 13.77it/s][A
100%|█████████████████████████████████████████| 108/108 [00:08<00:00, 12.37it/s]
 25%|███████████                                 | 3/12 [00:26<01:18,  8.74s/i

Epoch: 4/12 Step: 350 Training Loss: 0.3411 Validation Loss: 0.3167



 28%|███████████▋                              | 30/108 [00:02<00:09,  8.61it/s][A
 30%|████████████▍                             | 32/108 [00:02<00:07,  9.84it/s][A
 31%|█████████████▏                            | 34/108 [00:02<00:07, 10.54it/s][A
 33%|██████████████                            | 36/108 [00:03<00:06, 11.33it/s][A
 35%|██████████████▊                           | 38/108 [00:03<00:05, 11.87it/s][A
 37%|███████████████▌                          | 40/108 [00:03<00:05, 12.63it/s][A
 39%|████████████████▎                         | 42/108 [00:03<00:05, 13.19it/s][A
 41%|█████████████████                         | 44/108 [00:03<00:04, 13.62it/s][A
 43%|█████████████████▉                        | 46/108 [00:03<00:04, 13.98it/s][A
 44%|██████████████████▋                       | 48/108 [00:03<00:04, 14.17it/s][A
 46%|███████████████████▍                      | 50/108 [00:04<00:04, 14.43it/s][A
 48%|████████████████████▏                     | 52/108 [00:04<00:03, 14.52

Epoch: 4/12 Step: 400 Training Loss: 0.6976 Validation Loss: 0.6656



 74%|███████████████████████████████           | 80/108 [00:06<00:03,  8.46it/s][A
 76%|███████████████████████████████▉          | 82/108 [00:06<00:02,  9.79it/s][A
 78%|████████████████████████████████▋         | 84/108 [00:07<00:02, 10.96it/s][A
 80%|█████████████████████████████████▍        | 86/108 [00:07<00:01, 11.84it/s][A
 81%|██████████████████████████████████▏       | 88/108 [00:07<00:01, 12.43it/s][A
 83%|███████████████████████████████████       | 90/108 [00:07<00:01, 12.95it/s][A
 85%|███████████████████████████████████▊      | 92/108 [00:07<00:01, 13.49it/s][A
 87%|████████████████████████████████████▌     | 94/108 [00:07<00:01, 13.99it/s][A
 89%|█████████████████████████████████████▎    | 96/108 [00:07<00:00, 14.35it/s][A
 91%|██████████████████████████████████████    | 98/108 [00:08<00:00, 14.37it/s][A
 93%|█████████████████████████████████████▉   | 100/108 [00:08<00:00, 14.34it/s][A
 94%|██████████████████████████████████████▋  | 102/108 [00:08<00:00, 14.44

Epoch: 5/12 Step: 450 Training Loss: 0.3672 Validation Loss: 0.3503



 20%|████████▌                                 | 22/108 [00:02<00:09,  8.62it/s][A
 22%|█████████▎                                | 24/108 [00:02<00:08,  9.89it/s][A
 24%|██████████                                | 26/108 [00:02<00:07, 11.07it/s][A
 26%|██████████▉                               | 28/108 [00:02<00:06, 12.09it/s][A
 28%|███████████▋                              | 30/108 [00:02<00:06, 12.92it/s][A
 30%|████████████▍                             | 32/108 [00:02<00:05, 13.08it/s][A
 31%|█████████████▏                            | 34/108 [00:02<00:05, 13.66it/s][A
 33%|██████████████                            | 36/108 [00:03<00:05, 14.13it/s][A
 35%|██████████████▊                           | 38/108 [00:03<00:04, 14.48it/s][A
 37%|███████████████▌                          | 40/108 [00:03<00:04, 14.75it/s][A
 39%|████████████████▎                         | 42/108 [00:03<00:04, 14.88it/s][A
 41%|█████████████████                         | 44/108 [00:03<00:04, 14.97

Epoch: 5/12 Step: 500 Training Loss: 0.5322 Validation Loss: 0.5205



 67%|████████████████████████████              | 72/108 [00:06<00:04,  8.81it/s][A
 69%|████████████████████████████▊             | 74/108 [00:06<00:03, 10.05it/s][A
 70%|█████████████████████████████▌            | 76/108 [00:06<00:02, 11.21it/s][A
 72%|██████████████████████████████▎           | 78/108 [00:06<00:02, 12.18it/s][A
 74%|███████████████████████████████           | 80/108 [00:06<00:02, 12.92it/s][A
 76%|███████████████████████████████▉          | 82/108 [00:06<00:01, 13.49it/s][A
 78%|████████████████████████████████▋         | 84/108 [00:06<00:01, 13.91it/s][A
 80%|█████████████████████████████████▍        | 86/108 [00:07<00:01, 14.20it/s][A
 81%|██████████████████████████████████▏       | 88/108 [00:07<00:01, 14.50it/s][A
 83%|███████████████████████████████████       | 90/108 [00:07<00:01, 14.70it/s][A
 85%|███████████████████████████████████▊      | 92/108 [00:07<00:01, 14.89it/s][A
 87%|████████████████████████████████████▌     | 94/108 [00:07<00:00, 14.94

Epoch: 6/12 Step: 550 Training Loss: 0.2583 Validation Loss: 0.1950



 13%|█████▍                                    | 14/108 [00:01<00:11,  8.21it/s][A
 15%|██████▏                                   | 16/108 [00:01<00:10,  9.07it/s][A
 17%|███████                                   | 18/108 [00:01<00:08, 10.30it/s][A
 19%|███████▊                                  | 20/108 [00:02<00:07, 11.39it/s][A
 20%|████████▌                                 | 22/108 [00:02<00:06, 12.31it/s][A
 22%|█████████▎                                | 24/108 [00:02<00:06, 13.13it/s][A
 24%|██████████                                | 26/108 [00:02<00:05, 13.76it/s][A
 26%|██████████▉                               | 28/108 [00:02<00:05, 14.19it/s][A
 28%|███████████▋                              | 30/108 [00:02<00:05, 14.52it/s][A
 30%|████████████▍                             | 32/108 [00:02<00:05, 14.58it/s][A
 31%|█████████████▏                            | 34/108 [00:03<00:05, 14.26it/s][A
 33%|██████████████                            | 36/108 [00:03<00:05, 14.09

Epoch: 6/12 Step: 600 Training Loss: 0.2022 Validation Loss: 0.1488



 59%|████████████████████████▉                 | 64/108 [00:05<00:04,  8.83it/s][A
 61%|█████████████████████████▋                | 66/108 [00:05<00:04, 10.08it/s][A
 63%|██████████████████████████▍               | 68/108 [00:05<00:03, 11.19it/s][A
 65%|███████████████████████████▏              | 70/108 [00:06<00:03, 12.15it/s][A
 67%|████████████████████████████              | 72/108 [00:06<00:02, 12.88it/s][A
 69%|████████████████████████████▊             | 74/108 [00:06<00:02, 13.45it/s][A
 70%|█████████████████████████████▌            | 76/108 [00:06<00:02, 13.86it/s][A
 72%|██████████████████████████████▎           | 78/108 [00:06<00:02, 14.23it/s][A
 74%|███████████████████████████████           | 80/108 [00:06<00:01, 14.52it/s][A
 76%|███████████████████████████████▉          | 82/108 [00:06<00:01, 14.68it/s][A
 78%|████████████████████████████████▋         | 84/108 [00:06<00:01, 14.83it/s][A
 80%|█████████████████████████████████▍        | 86/108 [00:07<00:01, 14.93

Epoch: 7/12 Step: 650 Training Loss: 0.2310 Validation Loss: 0.1815



  6%|██▍                                        | 6/108 [00:01<00:14,  7.03it/s][A
  7%|███▏                                       | 8/108 [00:01<00:11,  8.92it/s][A
  9%|███▉                                      | 10/108 [00:01<00:09, 10.51it/s][A
 11%|████▋                                     | 12/108 [00:01<00:09, 10.49it/s][A
 13%|█████▍                                    | 14/108 [00:01<00:08, 11.60it/s][A
 15%|██████▏                                   | 16/108 [00:01<00:07, 12.50it/s][A
 17%|███████                                   | 18/108 [00:01<00:06, 13.27it/s][A
 19%|███████▊                                  | 20/108 [00:02<00:06, 13.83it/s][A
 20%|████████▌                                 | 22/108 [00:02<00:06, 14.26it/s][A
 22%|█████████▎                                | 24/108 [00:02<00:05, 14.53it/s][A
 24%|██████████                                | 26/108 [00:02<00:05, 14.69it/s][A
 26%|██████████▉                               | 28/108 [00:02<00:05, 14.85

Epoch: 7/12 Step: 700 Training Loss: 0.2690 Validation Loss: 0.2160



 52%|█████████████████████▊                    | 56/108 [00:05<00:05,  8.83it/s][A
 54%|██████████████████████▌                   | 58/108 [00:05<00:04, 10.12it/s][A
 56%|███████████████████████▎                  | 60/108 [00:05<00:04, 11.28it/s][A
 57%|████████████████████████                  | 62/108 [00:05<00:03, 12.22it/s][A
 59%|████████████████████████▉                 | 64/108 [00:05<00:03, 12.94it/s][A
 61%|█████████████████████████▋                | 66/108 [00:05<00:03, 13.55it/s][A
 63%|██████████████████████████▍               | 68/108 [00:05<00:02, 14.04it/s][A
 65%|███████████████████████████▏              | 70/108 [00:05<00:02, 14.39it/s][A
 67%|████████████████████████████              | 72/108 [00:06<00:02, 14.35it/s][A
 69%|████████████████████████████▊             | 74/108 [00:06<00:02, 14.59it/s][A
 70%|█████████████████████████████▌            | 76/108 [00:06<00:02, 14.78it/s][A
 72%|██████████████████████████████▎           | 78/108 [00:06<00:02, 14.94

Epoch: 7/12 Step: 750 Training Loss: 0.3320 Validation Loss: 0.2906



 98%|████████████████████████████████████████▏| 106/108 [00:09<00:00,  8.85it/s][A
100%|█████████████████████████████████████████| 108/108 [00:09<00:00, 11.69it/s]
 58%|█████████████████████████▋                  | 7/12 [01:01<00:44,  8.85s/it]
  0%|                                                   | 0/108 [00:00<?, ?it/s][A
  2%|▊                                          | 2/108 [00:00<00:06, 15.58it/s][A
  4%|█▌                                         | 4/108 [00:00<00:06, 15.46it/s][A
  6%|██▍                                        | 6/108 [00:00<00:06, 15.57it/s][A
  7%|███▏                                       | 8/108 [00:00<00:06, 15.46it/s][A
  9%|███▉                                      | 10/108 [00:00<00:06, 15.40it/s][A
 11%|████▋                                     | 12/108 [00:00<00:06, 15.37it/s][A
 13%|█████▍                                    | 14/108 [00:00<00:06, 15.29it/s][A
 15%|██████▏                                   | 16/108 [00:01<00:06, 15.32it/s]

Epoch: 8/12 Step: 800 Training Loss: 0.2786 Validation Loss: 0.2119



 44%|██████████████████▋                       | 48/108 [00:03<00:06,  8.69it/s][A
 46%|███████████████████▍                      | 50/108 [00:03<00:05,  9.94it/s][A
 48%|████████████████████▏                     | 52/108 [00:04<00:05, 11.08it/s][A
 50%|█████████████████████                     | 54/108 [00:04<00:04, 12.04it/s][A
 52%|█████████████████████▊                    | 56/108 [00:04<00:04, 12.78it/s][A
 54%|██████████████████████▌                   | 58/108 [00:04<00:03, 13.39it/s][A
 56%|███████████████████████▎                  | 60/108 [00:04<00:03, 13.82it/s][A
 57%|████████████████████████                  | 62/108 [00:04<00:03, 14.20it/s][A
 59%|████████████████████████▉                 | 64/108 [00:04<00:03, 14.49it/s][A
 61%|█████████████████████████▋                | 66/108 [00:05<00:02, 14.63it/s][A
 63%|██████████████████████████▍               | 68/108 [00:05<00:02, 14.75it/s][A
 65%|███████████████████████████▏              | 70/108 [00:05<00:02, 14.79

Epoch: 8/12 Step: 850 Training Loss: 0.2148 Validation Loss: 0.1574



 91%|██████████████████████████████████████    | 98/108 [00:07<00:01,  8.55it/s][A
 93%|█████████████████████████████████████▉   | 100/108 [00:08<00:00,  9.61it/s][A
 94%|██████████████████████████████████████▋  | 102/108 [00:08<00:00, 10.77it/s][A
 96%|███████████████████████████████████████▍ | 104/108 [00:08<00:00, 11.80it/s][A
 98%|████████████████████████████████████████▏| 106/108 [00:08<00:00, 12.64it/s][A
100%|█████████████████████████████████████████| 108/108 [00:08<00:00, 12.66it/s]
 67%|█████████████████████████████▎              | 8/12 [01:10<00:35,  8.75s/it]
  0%|                                                   | 0/108 [00:00<?, ?it/s][A
  2%|▊                                          | 2/108 [00:00<00:07, 14.31it/s][A
  4%|█▌                                         | 4/108 [00:00<00:07, 13.99it/s][A
  6%|██▍                                        | 6/108 [00:00<00:07, 14.22it/s][A
  7%|███▏                                       | 8/108 [00:00<00:07, 14.14it/s]

Epoch: 9/12 Step: 900 Training Loss: 0.2168 Validation Loss: 0.1615



 37%|███████████████▌                          | 40/108 [00:03<00:08,  7.98it/s][A
 39%|████████████████▎                         | 42/108 [00:03<00:07,  9.28it/s][A
 41%|█████████████████                         | 44/108 [00:03<00:06, 10.54it/s][A
 43%|█████████████████▉                        | 46/108 [00:04<00:05, 11.51it/s][A
 44%|██████████████████▋                       | 48/108 [00:04<00:04, 12.05it/s][A
 46%|███████████████████▍                      | 50/108 [00:04<00:04, 12.48it/s][A
 48%|████████████████████▏                     | 52/108 [00:04<00:04, 12.86it/s][A
 50%|█████████████████████                     | 54/108 [00:04<00:04, 13.37it/s][A
 52%|█████████████████████▊                    | 56/108 [00:04<00:03, 13.80it/s][A
 54%|██████████████████████▌                   | 58/108 [00:04<00:03, 14.12it/s][A
 56%|███████████████████████▎                  | 60/108 [00:05<00:03, 14.41it/s][A
 57%|████████████████████████                  | 62/108 [00:05<00:03, 14.65

Epoch: 9/12 Step: 950 Training Loss: 0.3120 Validation Loss: 0.2783



 83%|███████████████████████████████████       | 90/108 [00:07<00:02,  8.84it/s][A
 85%|███████████████████████████████████▊      | 92/108 [00:07<00:01, 10.09it/s][A
 87%|████████████████████████████████████▌     | 94/108 [00:07<00:01, 11.21it/s][A
 89%|█████████████████████████████████████▎    | 96/108 [00:08<00:00, 12.18it/s][A
 91%|██████████████████████████████████████    | 98/108 [00:08<00:00, 12.96it/s][A
 93%|█████████████████████████████████████▉   | 100/108 [00:08<00:00, 13.56it/s][A
 94%|██████████████████████████████████████▋  | 102/108 [00:08<00:00, 14.03it/s][A
 96%|███████████████████████████████████████▍ | 104/108 [00:08<00:00, 14.36it/s][A
 98%|████████████████████████████████████████▏| 106/108 [00:08<00:00, 14.60it/s][A
100%|█████████████████████████████████████████| 108/108 [00:08<00:00, 12.17it/s]
 75%|█████████████████████████████████           | 9/12 [01:18<00:26,  8.79s/it]
  0%|                                                   | 0/108 [00:00<?, ?it/s]

Epoch: 10/12 Step: 1000 Training Loss: 0.1380 Validation Loss: 0.0908



 30%|████████████▍                             | 32/108 [00:02<00:08,  8.78it/s][A
 31%|█████████████▏                            | 34/108 [00:02<00:07, 10.04it/s][A
 33%|██████████████                            | 36/108 [00:03<00:06, 11.16it/s][A
 35%|██████████████▊                           | 38/108 [00:03<00:05, 12.09it/s][A
 37%|███████████████▌                          | 40/108 [00:03<00:05, 12.80it/s][A
 39%|████████████████▎                         | 42/108 [00:03<00:04, 13.35it/s][A
 41%|█████████████████                         | 44/108 [00:03<00:04, 13.84it/s][A
 43%|█████████████████▉                        | 46/108 [00:03<00:04, 14.10it/s][A
 44%|██████████████████▋                       | 48/108 [00:03<00:04, 14.36it/s][A
 46%|███████████████████▍                      | 50/108 [00:03<00:04, 14.45it/s][A
 48%|████████████████████▏                     | 52/108 [00:04<00:04, 12.90it/s][A
 50%|█████████████████████                     | 54/108 [00:04<00:04, 13.34

Epoch: 10/12 Step: 1050 Training Loss: 0.2950 Validation Loss: 0.2820



 76%|███████████████████████████████▉          | 82/108 [00:06<00:02,  8.78it/s][A
 78%|████████████████████████████████▋         | 84/108 [00:06<00:02, 10.05it/s][A
 80%|█████████████████████████████████▍        | 86/108 [00:07<00:01, 11.14it/s][A
 81%|██████████████████████████████████▏       | 88/108 [00:07<00:01, 12.10it/s][A
 83%|███████████████████████████████████       | 90/108 [00:07<00:01, 12.87it/s][A
 85%|███████████████████████████████████▊      | 92/108 [00:07<00:01, 13.40it/s][A
 87%|████████████████████████████████████▌     | 94/108 [00:07<00:01, 13.99it/s][A
 89%|█████████████████████████████████████▎    | 96/108 [00:07<00:00, 14.38it/s][A
 91%|██████████████████████████████████████    | 98/108 [00:07<00:00, 14.62it/s][A
 93%|█████████████████████████████████████▉   | 100/108 [00:08<00:00, 14.86it/s][A
 94%|██████████████████████████████████████▋  | 102/108 [00:08<00:00, 14.97it/s][A
 96%|███████████████████████████████████████▍ | 104/108 [00:08<00:00, 15.02

Epoch: 11/12 Step: 1100 Training Loss: 0.0691 Validation Loss: 0.0202



 22%|█████████▎                                | 24/108 [00:02<00:09,  8.50it/s][A
 24%|██████████                                | 26/108 [00:02<00:08,  9.81it/s][A
 26%|██████████▉                               | 28/108 [00:02<00:07, 10.94it/s][A
 28%|███████████▋                              | 30/108 [00:02<00:06, 11.94it/s][A
 30%|████████████▍                             | 32/108 [00:02<00:05, 12.74it/s][A
 31%|█████████████▏                            | 34/108 [00:02<00:05, 13.31it/s][A
 33%|██████████████                            | 36/108 [00:03<00:05, 13.79it/s][A
 35%|██████████████▊                           | 38/108 [00:03<00:04, 14.16it/s][A
 37%|███████████████▌                          | 40/108 [00:03<00:04, 14.43it/s][A
 39%|████████████████▎                         | 42/108 [00:03<00:04, 14.57it/s][A
 41%|█████████████████                         | 44/108 [00:03<00:04, 14.70it/s][A
 43%|█████████████████▉                        | 46/108 [00:03<00:04, 14.87

Epoch: 11/12 Step: 1150 Training Loss: 0.1166 Validation Loss: 0.0767



 69%|████████████████████████████▊             | 74/108 [00:06<00:03,  8.79it/s][A
 70%|█████████████████████████████▌            | 76/108 [00:06<00:03, 10.07it/s][A
 72%|██████████████████████████████▎           | 78/108 [00:06<00:02, 11.18it/s][A
 74%|███████████████████████████████           | 80/108 [00:06<00:02, 12.06it/s][A
 76%|███████████████████████████████▉          | 82/108 [00:06<00:02, 12.82it/s][A
 78%|████████████████████████████████▋         | 84/108 [00:06<00:01, 13.40it/s][A
 80%|█████████████████████████████████▍        | 86/108 [00:07<00:01, 13.86it/s][A
 81%|██████████████████████████████████▏       | 88/108 [00:07<00:01, 14.17it/s][A
 83%|███████████████████████████████████       | 90/108 [00:07<00:01, 14.40it/s][A
 85%|███████████████████████████████████▊      | 92/108 [00:07<00:01, 14.53it/s][A
 87%|████████████████████████████████████▌     | 94/108 [00:07<00:00, 14.71it/s][A
 89%|█████████████████████████████████████▎    | 96/108 [00:07<00:00, 14.83

Epoch: 12/12 Step: 1200 Training Loss: 0.0153 Validation Loss: 0.0125



 15%|██████▏                                   | 16/108 [00:01<00:10,  8.61it/s][A
 17%|███████                                   | 18/108 [00:01<00:09,  9.90it/s][A
 19%|███████▊                                  | 20/108 [00:01<00:07, 11.07it/s][A
 20%|████████▌                                 | 22/108 [00:02<00:07, 11.94it/s][A
 22%|█████████▎                                | 24/108 [00:02<00:06, 12.74it/s][A
 24%|██████████                                | 26/108 [00:02<00:06, 13.32it/s][A
 26%|██████████▉                               | 28/108 [00:02<00:05, 13.83it/s][A
 28%|███████████▋                              | 30/108 [00:02<00:05, 14.16it/s][A
 30%|████████████▍                             | 32/108 [00:02<00:05, 14.39it/s][A
 31%|█████████████▏                            | 34/108 [00:02<00:05, 14.50it/s][A
 33%|██████████████                            | 36/108 [00:03<00:04, 14.72it/s][A
 35%|██████████████▊                           | 38/108 [00:03<00:04, 14.79

Epoch: 12/12 Step: 1250 Training Loss: 0.0047 Validation Loss: 0.0029



 61%|█████████████████████████▋                | 66/108 [00:05<00:04,  8.74it/s][A
 63%|██████████████████████████▍               | 68/108 [00:05<00:03, 10.02it/s][A
 65%|███████████████████████████▏              | 70/108 [00:05<00:03, 11.14it/s][A
 67%|████████████████████████████              | 72/108 [00:06<00:02, 12.12it/s][A
 69%|████████████████████████████▊             | 74/108 [00:06<00:02, 12.90it/s][A
 70%|█████████████████████████████▌            | 76/108 [00:06<00:02, 13.49it/s][A
 72%|██████████████████████████████▎           | 78/108 [00:06<00:02, 13.85it/s][A
 74%|███████████████████████████████           | 80/108 [00:06<00:01, 14.14it/s][A
 76%|███████████████████████████████▉          | 82/108 [00:06<00:01, 14.46it/s][A
 78%|████████████████████████████████▋         | 84/108 [00:06<00:01, 14.70it/s][A
 80%|█████████████████████████████████▍        | 86/108 [00:07<00:01, 14.89it/s][A
 81%|██████████████████████████████████▏       | 88/108 [00:07<00:01, 15.00

In [33]:
net.eval()
test_losses = []
num_correct = 0
test_h = net.init_hidden(BATCH_SIZE)
net.to(device)
for inputs, labels in test_loader:
    inputs, labels = inputs.to(device), labels.to(device)
    test_output, _ = net(inputs)
    loss = criterion(test_output, labels.float().unsqueeze(dim=1))
    test_losses.append(loss.item())

    preds = torch.round(test_output.squeeze())
    correct_tensor = preds.eq(labels.float().view_as(preds))
    correct = np.squeeze(correct_tensor.cpu().numpy())
    num_correct += np.sum(correct)

print("Test Loss: {:.4f}".format(np.mean(test_losses)))
print("Test Accuracy: {:.2f}".format(num_correct / len(test_loader.dataset)))

Test Loss: 1.2259
Test Accuracy: 0.70
