In [1]:
import torch
import torch.nn as nn
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from util import SST2Dataset, load_embedding_matrix

## A5.a

In [2]:
def collate_fn(batch):
    """
    Create a batch of data given a list of N sequences and labels. Sequences are stacked into a single tensor
    of shape (N, max_sequence_length), where max_sequence_length is the maximum length of any sequence in the
    batch. Sequences shorter than this length should be filled up with 0's. Also returns a tensor of shape (N, 1)
    containing the label of each sequence.

    :param batch: A list of size N, where each element is a tuple containing a sequence tensor and a single item
    tensor containing the true label of the sequence.

    :return: A tuple containing two tensors. The first tensor has shape (N, max_sequence_length) and contains all
    sequences. Sequences shorter than max_sequence_length are padded with 0s at the end. The second tensor
    has shape (N, 1) and contains all labels.
    """
    sentences, labels = zip(*batch)

    # torch.nn.utils.rnn.pad_sequence will find the max length in sentences and pad padding_value to it
    # set padding_value=0 as it's the token id for the padding token
    padded_sentences = torch.nn.utils.rnn.pad_sequence(sentences, batch_first=True, padding_value=0)
    
    # stack all labels into one tensor
    labels = torch.stack(labels)
    return padded_sentences, labels

## A5.b & A5.c

In [3]:
class RNNBinaryClassificationModel(nn.Module):
    def __init__(self, embedding_matrix, model_type):
        super().__init__()

        vocab_size = embedding_matrix.shape[0]
        embedding_dim = embedding_matrix.shape[1]

        # Construct embedding layer and initialize with given embedding matrix. Do not modify this code.
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim, padding_idx=0)
        self.embedding.weight.data = embedding_matrix

        hidden_size = 64
        if model_type == 'rnn':
            rnn = nn.RNN
        elif model_type == 'lstm':
            rnn = nn.LSTM
        elif model_type == 'gru':
            rnn = nn.GRU
        else:
            raise ValueError("model type must be in ('rnn', 'lstm', 'gru')")
            
        self.rnn = rnn(
            input_size=embedding_dim,
            hidden_size=hidden_size
        )
        self.fc = nn.Linear(hidden_size, 1) # output_dim = 1 because of binary classification
        self._loss = nn.CrossEntropyLoss()

    def forward(self, inputs):
        """
        Takes in a batch of data of shape (N, max_sequence_length). Returns a tensor of shape (N, 1), where each
        element corresponds to the prediction for the corresponding sequence.
        :param inputs: Tensor of shape (N, max_sequence_length) containing N sequences to make predictions for.
        :return: Tensor of predictions for each sequence of shape (N, 1).
        """
        
        embedding = self.embedding(inputs)
        rnn_output, hdn = self.rnn(embedding)
        logits = self.fc(rnn_output)
        return logits

    def loss(self, logits, targets):
        """
        Computes the binary cross-entropy loss.
        :param logits: Raw predictions from the model of shape (N, 1)
        :param targets: True labels of shape (N, 1)
        :return: Binary cross entropy loss between logits and targets as a scalar tensor.
        """
        return self._loss(logits, targets)

    def accuracy(self, logits, targets):
        """
        Computes the accuracy, i.e number of correct predictions / N.
        :param logits: Raw predictions from the model of shape (N, 1)
        :param targets: True labels of shape (N, 1)
        :return: Accuracy as a scalar tensor.
        """
        
        pred = torch.argmax(logits, axis=1)
        acc = (pred == targets).double().mean()
        return acc

## A5.d

In [4]:
# Training parameters
TRAINING_BATCH_SIZE = 32
NUM_EPOCHS = 16
LEARNING_RATE = 1e-3

# Batch size for validation, this only affects performance.
VAL_BATCH_SIZE = 128

In [5]:
def train(model_type):
    # Load datasets
    train_dataset = SST2Dataset("./SST-2/train.tsv")
    val_dataset = SST2Dataset("./SST-2/dev.tsv", train_dataset.vocab, train_dataset.reverse_vocab)

    # Create data loaders for creating and iterating over batches
    train_loader = DataLoader(train_dataset, batch_size=TRAINING_BATCH_SIZE, collate_fn=collate_fn, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=VAL_BATCH_SIZE, collate_fn=collate_fn)

    # Print out some random examples from the data
    print("Data examples:")
    random_indices = torch.randperm(len(train_dataset))[:8].tolist()
    for index in random_indices:
        sequence_indices, label = train_dataset.sentences[index], train_dataset.labels[index]
        sentiment = "Positive" if label == 1 else "Negative"
        sequence = train_dataset.indices_to_tokens(sequence_indices)
        print(f"Sentiment: {sentiment}. Sentence: {sequence}")
    print()

    embedding_matrix = load_embedding_matrix(train_dataset.vocab)

    model = RNNBinaryClassificationModel(embedding_matrix, model_type)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    for epoch in range(NUM_EPOCHS):
        # Total loss across train data
        train_loss = 0.
        # Total number of correctly predicted training labels
        train_correct = 0
        # Total number of training sequences processed
        train_seqs = 0

        tqdm_train_loader = tqdm(train_loader)
        print(f"Epoch {epoch + 1}/{NUM_EPOCHS}")

        model.train()
        for batch_idx, batch in enumerate(tqdm_train_loader):
            sentences_batch, labels_batch = batch

            # Make predictions
            logits = model(sentences_batch)

            # Compute loss and number of correct predictions
            loss = model.loss(logits, labels_batch)
            correct = model.accuracy(logits, labels_batch).item() * len(logits)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accumulate metrics and update status
            train_loss += loss.item()
            train_correct += correct
            train_seqs += len(sentences_batch)
            tqdm_train_loader.set_description_str(
                f"[Loss]: {train_loss / (batch_idx + 1):.4f} [Acc]: {train_correct / train_seqs:.4f}")
        print()

        avg_train_loss = train_loss / len(tqdm_train_loader)
        train_accuracy = train_correct / train_seqs
        print(f"[Training Loss]: {avg_train_loss:.4f} [Training Accuracy]: {train_accuracy:.4f}")

        print("Validating")
        # Total loss across validation data
        val_loss = 0.
        # Total number of correctly predicted validation labels
        val_correct = 0
        # Total number of validation sequences processed
        val_seqs = 0

        tqdm_val_loader = tqdm(val_loader)

        model.eval()
        for batch_idx, batch in enumerate(tqdm_val_loader):
            sentences_batch, labels_batch = batch

            with torch.no_grad():
                # Make predictions
                logits = model(sentences_batch)

                # Compute loss and number of correct predictions and accumulate metrics and update status
                val_loss += model.loss(logits, labels_batch).item()
                val_correct += model.accuracy(logits, labels_batch).item() * len(logits)
                val_seqs += len(sentences_batch)
                tqdm_val_loader.set_description_str(
                    f"[Loss]: {val_loss / (batch_idx + 1):.4f} [Acc]: {val_correct / val_seqs:.4f}")
        print()

        avg_val_loss = val_loss / len(tqdm_val_loader)
        val_accuracy = val_correct / val_seqs
        print(f"[Validation Loss]: {avg_val_loss:.4f} [Validation Accuracy]: {val_accuracy:.4f}")

### A5.d RNN

In [6]:
train('rnn')

  indexed_sentences = [torch.tensor(self.tokens_to_indices(sentence)) for sentence in sentences]


Data examples:
Sentiment: Positive. Sentence: an accomplished actress
Sentiment: Positive. Sentence: a successful example of the lovable-loser protagonist
Sentiment: Positive. Sentence: a vivid , sometimes surreal , glimpse into the mysteries of human behavior .
Sentiment: Negative. Sentence: walked out of runteldat
Sentiment: Negative. Sentence: relentlessly saccharine
Sentiment: Positive. Sentence: this charming but slight tale has warmth , wit and interesting characters compassionately portrayed .
Sentiment: Positive. Sentence: last year 's kubrick-meets-spielberg exercise
Sentiment: Negative. Sentence: substitute plot for personality



  allow_unreachable=True)  # allow_unreachable flag
[Loss]: 2.7732 [Acc]: 0.3109:   0%|          | 10/2105 [00:00<00:22, 92.21it/s]

Epoch 1/16


[Loss]: 1.2783 [Acc]: 0.4337: 100%|██████████| 2105/2105 [00:18<00:00, 114.70it/s]
[Loss]: 1.6386 [Acc]: 0.3991: 100%|██████████| 7/7 [00:00<00:00, 157.73it/s]
[Loss]: 1.0320 [Acc]: 0.4833:   0%|          | 10/2105 [00:00<00:20, 99.99it/s]


[Training Loss]: 1.2783 [Training Accuracy]: 0.4337
Validating

[Validation Loss]: 1.6386 [Validation Accuracy]: 0.3991
Epoch 2/16


[Loss]: 1.1298 [Acc]: 0.4686: 100%|██████████| 2105/2105 [00:19<00:00, 110.27it/s]
[Loss]: 1.9774 [Acc]: 0.3200: 100%|██████████| 7/7 [00:00<00:00, 131.06it/s]
[Loss]: 1.0117 [Acc]: 0.5184:   1%|          | 12/2105 [00:00<00:18, 115.44it/s]


[Training Loss]: 1.1298 [Training Accuracy]: 0.4686
Validating

[Validation Loss]: 1.9774 [Validation Accuracy]: 0.3200
Epoch 3/16


[Loss]: 1.0583 [Acc]: 0.4924: 100%|██████████| 2105/2105 [00:18<00:00, 112.55it/s]
[Loss]: 2.0904 [Acc]: 0.2741: 100%|██████████| 7/7 [00:00<00:00, 131.60it/s]
[Loss]: 0.9900 [Acc]: 0.5078:   1%|          | 12/2105 [00:00<00:18, 113.06it/s]


[Training Loss]: 1.0583 [Training Accuracy]: 0.4924
Validating

[Validation Loss]: 2.0904 [Validation Accuracy]: 0.2741
Epoch 4/16


[Loss]: 1.0249 [Acc]: 0.5116: 100%|██████████| 2105/2105 [00:18<00:00, 110.98it/s]
[Loss]: 1.8061 [Acc]: 0.3383: 100%|██████████| 7/7 [00:00<00:00, 132.82it/s]
[Loss]: 1.0063 [Acc]: 0.5246:   1%|          | 11/2105 [00:00<00:20, 103.00it/s]


[Training Loss]: 1.0249 [Training Accuracy]: 0.5116
Validating

[Validation Loss]: 1.8061 [Validation Accuracy]: 0.3383
Epoch 5/16


[Loss]: 0.9830 [Acc]: 0.5312: 100%|██████████| 2105/2105 [00:18<00:00, 111.51it/s]
[Loss]: 2.2335 [Acc]: 0.2420: 100%|██████████| 7/7 [00:00<00:00, 157.05it/s]
[Loss]: 1.0523 [Acc]: 0.5139:   1%|          | 12/2105 [00:00<00:17, 117.65it/s]


[Training Loss]: 0.9830 [Training Accuracy]: 0.5312
Validating

[Validation Loss]: 2.2335 [Validation Accuracy]: 0.2420
Epoch 6/16


[Loss]: 0.9614 [Acc]: 0.5462: 100%|██████████| 2105/2105 [00:18<00:00, 112.09it/s]
[Loss]: 2.3162 [Acc]: 0.1606: 100%|██████████| 7/7 [00:00<00:00, 133.37it/s]
[Loss]: 0.9612 [Acc]: 0.5469:   0%|          | 10/2105 [00:00<00:21, 99.41it/s]


[Training Loss]: 0.9614 [Training Accuracy]: 0.5462
Validating

[Validation Loss]: 2.3162 [Validation Accuracy]: 0.1606
Epoch 7/16


[Loss]: 0.9289 [Acc]: 0.5552: 100%|██████████| 2105/2105 [00:18<00:00, 113.14it/s]
[Loss]: 1.8818 [Acc]: 0.2970: 100%|██████████| 7/7 [00:00<00:00, 133.68it/s]
[Loss]: 0.8537 [Acc]: 0.5848:   0%|          | 10/2105 [00:00<00:21, 96.86it/s]


[Training Loss]: 0.9289 [Training Accuracy]: 0.5552
Validating

[Validation Loss]: 1.8818 [Validation Accuracy]: 0.2970
Epoch 8/16


[Loss]: 0.9204 [Acc]: 0.5626: 100%|██████████| 2105/2105 [00:18<00:00, 111.47it/s]
[Loss]: 1.8665 [Acc]: 0.2890: 100%|██████████| 7/7 [00:00<00:00, 162.45it/s]
[Loss]: 0.8964 [Acc]: 0.5799:   1%|          | 12/2105 [00:00<00:17, 119.22it/s]


[Training Loss]: 0.9204 [Training Accuracy]: 0.5626
Validating

[Validation Loss]: 1.8665 [Validation Accuracy]: 0.2890
Epoch 9/16


[Loss]: 0.9053 [Acc]: 0.5724: 100%|██████████| 2105/2105 [00:18<00:00, 110.86it/s]
[Loss]: 1.9836 [Acc]: 0.1927: 100%|██████████| 7/7 [00:00<00:00, 158.76it/s]
[Loss]: 0.9247 [Acc]: 0.5386:   1%|          | 12/2105 [00:00<00:18, 112.73it/s]


[Training Loss]: 0.9053 [Training Accuracy]: 0.5724
Validating

[Validation Loss]: 1.9836 [Validation Accuracy]: 0.1927
Epoch 10/16


[Loss]: 0.8801 [Acc]: 0.5900: 100%|██████████| 2105/2105 [00:18<00:00, 111.24it/s]
[Loss]: 2.0062 [Acc]: 0.2351: 100%|██████████| 7/7 [00:00<00:00, 159.10it/s]
[Loss]: 0.7792 [Acc]: 0.6465:   1%|          | 11/2105 [00:00<00:19, 106.76it/s]


[Training Loss]: 0.8801 [Training Accuracy]: 0.5900
Validating

[Validation Loss]: 2.0062 [Validation Accuracy]: 0.2351
Epoch 11/16


[Loss]: 0.8761 [Acc]: 0.5908: 100%|██████████| 2105/2105 [00:18<00:00, 112.09it/s]
[Loss]: 1.9882 [Acc]: 0.2385: 100%|██████████| 7/7 [00:00<00:00, 160.05it/s]
[Loss]: 0.8780 [Acc]: 0.5799:   1%|          | 12/2105 [00:00<00:17, 118.80it/s]


[Training Loss]: 0.8761 [Training Accuracy]: 0.5908
Validating

[Validation Loss]: 1.9882 [Validation Accuracy]: 0.2385
Epoch 12/16


[Loss]: 0.8576 [Acc]: 0.6000: 100%|██████████| 2105/2105 [00:19<00:00, 110.68it/s]
[Loss]: 2.2026 [Acc]: 0.1915: 100%|██████████| 7/7 [00:00<00:00, 133.23it/s]
[Loss]: 0.8349 [Acc]: 0.6188:   1%|          | 11/2105 [00:00<00:19, 108.43it/s]


[Training Loss]: 0.8576 [Training Accuracy]: 0.6000
Validating

[Validation Loss]: 2.2026 [Validation Accuracy]: 0.1915
Epoch 13/16


[Loss]: 0.8294 [Acc]: 0.6235: 100%|██████████| 2105/2105 [00:19<00:00, 109.74it/s]
[Loss]: 2.5196 [Acc]: 0.1388: 100%|██████████| 7/7 [00:00<00:00, 146.70it/s]
[Loss]: 0.8298 [Acc]: 0.6445:   1%|          | 11/2105 [00:00<00:19, 109.65it/s]


[Training Loss]: 0.8294 [Training Accuracy]: 0.6235
Validating

[Validation Loss]: 2.5196 [Validation Accuracy]: 0.1388
Epoch 14/16


[Loss]: 0.8158 [Acc]: 0.6333: 100%|██████████| 2105/2105 [00:18<00:00, 111.49it/s]
[Loss]: 2.1827 [Acc]: 0.1984: 100%|██████████| 7/7 [00:00<00:00, 156.48it/s]
[Loss]: 0.7376 [Acc]: 0.6816:   1%|          | 12/2105 [00:00<00:18, 112.88it/s]


[Training Loss]: 0.8158 [Training Accuracy]: 0.6333
Validating

[Validation Loss]: 2.1827 [Validation Accuracy]: 0.1984
Epoch 15/16


[Loss]: 0.8013 [Acc]: 0.6452: 100%|██████████| 2105/2105 [00:18<00:00, 111.85it/s]
[Loss]: 2.0722 [Acc]: 0.2477: 100%|██████████| 7/7 [00:00<00:00, 131.72it/s]
[Loss]: 0.8778 [Acc]: 0.6317:   0%|          | 10/2105 [00:00<00:21, 98.68it/s]


[Training Loss]: 0.8013 [Training Accuracy]: 0.6452
Validating

[Validation Loss]: 2.0722 [Validation Accuracy]: 0.2477
Epoch 16/16


[Loss]: 0.7813 [Acc]: 0.6514: 100%|██████████| 2105/2105 [00:18<00:00, 112.36it/s]
[Loss]: 2.2148 [Acc]: 0.1984: 100%|██████████| 7/7 [00:00<00:00, 162.86it/s]


[Training Loss]: 0.7813 [Training Accuracy]: 0.6514
Validating

[Validation Loss]: 2.2148 [Validation Accuracy]: 0.1984





### A5.d LSTM

In [7]:
train('lstm')

Data examples:
Sentiment: Positive. Sentence: for people who make movies and watch them
Sentiment: Positive. Sentence: a crisp psychological drama ( and ) a fascinating little thriller that would have been perfect for an old `` twilight zone '' episode .
Sentiment: Positive. Sentence: captures that perverse element of the kafkaesque where identity , overnight , is robbed and replaced with a persecuted `` other
Sentiment: Positive. Sentence: 's a very very strong `` b + . ''
Sentiment: Positive. Sentence: did n't hate this one
Sentiment: Positive. Sentence: australia is a weirdly beautiful place
Sentiment: Negative. Sentence: hymn and a cruel story of youth culture
Sentiment: Positive. Sentence: somehow manages to bring together kevin pollak , former wrestler chyna and dolly parton



[Loss]: 3.1969 [Acc]: 0.3693:   0%|          | 6/2105 [00:00<00:37, 56.26it/s]

Epoch 1/16


[Loss]: 1.1920 [Acc]: 0.4450: 100%|██████████| 2105/2105 [00:37<00:00, 56.59it/s]
[Loss]: 1.8124 [Acc]: 0.3394: 100%|██████████| 7/7 [00:00<00:00, 66.44it/s]
[Loss]: 1.1424 [Acc]: 0.4750:   0%|          | 0/2105 [00:00<?, ?it/s]


[Training Loss]: 1.1920 [Training Accuracy]: 0.4450
Validating

[Validation Loss]: 1.8124 [Validation Accuracy]: 0.3394
Epoch 2/16


[Loss]: 1.0433 [Acc]: 0.4730: 100%|██████████| 2105/2105 [00:37<00:00, 56.31it/s]
[Loss]: 1.5409 [Acc]: 0.3509: 100%|██████████| 7/7 [00:00<00:00, 57.82it/s]
[Loss]: 0.9436 [Acc]: 0.5469:   0%|          | 0/2105 [00:00<?, ?it/s]


[Training Loss]: 1.0433 [Training Accuracy]: 0.4730
Validating

[Validation Loss]: 1.5409 [Validation Accuracy]: 0.3509
Epoch 3/16


[Loss]: 0.9828 [Acc]: 0.5197: 100%|██████████| 2105/2105 [00:37<00:00, 55.92it/s]
[Loss]: 1.5742 [Acc]: 0.3291: 100%|██████████| 7/7 [00:00<00:00, 67.71it/s]
[Loss]: 0.8485 [Acc]: 0.5312:   0%|          | 0/2105 [00:00<?, ?it/s]


[Training Loss]: 0.9828 [Training Accuracy]: 0.5197
Validating

[Validation Loss]: 1.5742 [Validation Accuracy]: 0.3291
Epoch 4/16


[Loss]: 0.9014 [Acc]: 0.5800: 100%|██████████| 2105/2105 [00:37<00:00, 55.73it/s]
[Loss]: 1.4555 [Acc]: 0.3876: 100%|██████████| 7/7 [00:00<00:00, 66.63it/s]
[Loss]: 0.8222 [Acc]: 0.5687:   0%|          | 0/2105 [00:00<?, ?it/s]


[Training Loss]: 0.9014 [Training Accuracy]: 0.5800
Validating

[Validation Loss]: 1.4555 [Validation Accuracy]: 0.3876
Epoch 5/16


[Loss]: 0.8349 [Acc]: 0.6174: 100%|██████████| 2105/2105 [00:37<00:00, 56.59it/s]
[Loss]: 1.5684 [Acc]: 0.3486: 100%|██████████| 7/7 [00:00<00:00, 57.31it/s]
[Loss]: 0.7414 [Acc]: 0.6771:   0%|          | 0/2105 [00:00<?, ?it/s]


[Training Loss]: 0.8349 [Training Accuracy]: 0.6174
Validating

[Validation Loss]: 1.5684 [Validation Accuracy]: 0.3486
Epoch 6/16


[Loss]: 0.7896 [Acc]: 0.6469: 100%|██████████| 2105/2105 [00:37<00:00, 56.55it/s]
[Loss]: 1.4706 [Acc]: 0.3911: 100%|██████████| 7/7 [00:00<00:00, 73.87it/s]
[Loss]: 0.7029 [Acc]: 0.6927:   0%|          | 0/2105 [00:00<?, ?it/s]


[Training Loss]: 0.7896 [Training Accuracy]: 0.6469
Validating

[Validation Loss]: 1.4706 [Validation Accuracy]: 0.3911
Epoch 7/16


[Loss]: 0.7498 [Acc]: 0.6655: 100%|██████████| 2105/2105 [00:37<00:00, 56.16it/s]
[Loss]: 1.7741 [Acc]: 0.3028: 100%|██████████| 7/7 [00:00<00:00, 58.21it/s]
[Loss]: 0.6810 [Acc]: 0.7500:   0%|          | 0/2105 [00:00<?, ?it/s]


[Training Loss]: 0.7498 [Training Accuracy]: 0.6655
Validating

[Validation Loss]: 1.7741 [Validation Accuracy]: 0.3028
Epoch 8/16


[Loss]: 0.7227 [Acc]: 0.6766: 100%|██████████| 2105/2105 [00:37<00:00, 56.49it/s]
[Loss]: 1.8631 [Acc]: 0.3050: 100%|██████████| 7/7 [00:00<00:00, 73.39it/s]
[Loss]: 0.8091 [Acc]: 0.6813:   0%|          | 0/2105 [00:00<?, ?it/s]


[Training Loss]: 0.7227 [Training Accuracy]: 0.6766
Validating

[Validation Loss]: 1.8631 [Validation Accuracy]: 0.3050
Epoch 9/16


[Loss]: 0.6997 [Acc]: 0.6899: 100%|██████████| 2105/2105 [00:37<00:00, 55.87it/s]
[Loss]: 1.9145 [Acc]: 0.2993: 100%|██████████| 7/7 [00:00<00:00, 67.85it/s]
[Loss]: 0.8334 [Acc]: 0.6750:   0%|          | 0/2105 [00:00<?, ?it/s]


[Training Loss]: 0.6997 [Training Accuracy]: 0.6899
Validating

[Validation Loss]: 1.9145 [Validation Accuracy]: 0.2993
Epoch 10/16


[Loss]: 0.6779 [Acc]: 0.6999: 100%|██████████| 2105/2105 [00:37<00:00, 56.02it/s]
[Loss]: 1.7478 [Acc]: 0.3555: 100%|██████████| 7/7 [00:00<00:00, 66.19it/s]
[Loss]: 0.6170 [Acc]: 0.7109:   0%|          | 0/2105 [00:00<?, ?it/s]


[Training Loss]: 0.6779 [Training Accuracy]: 0.6999
Validating

[Validation Loss]: 1.7478 [Validation Accuracy]: 0.3555
Epoch 11/16


[Loss]: 0.6619 [Acc]: 0.7038: 100%|██████████| 2105/2105 [00:37<00:00, 56.75it/s]
[Loss]: 1.7875 [Acc]: 0.3498: 100%|██████████| 7/7 [00:00<00:00, 72.78it/s]
[Loss]: 0.5151 [Acc]: 0.7552:   0%|          | 0/2105 [00:00<?, ?it/s]


[Training Loss]: 0.6619 [Training Accuracy]: 0.7038
Validating

[Validation Loss]: 1.7875 [Validation Accuracy]: 0.3498
Epoch 12/16


[Loss]: 0.6551 [Acc]: 0.7083: 100%|██████████| 2105/2105 [00:37<00:00, 55.58it/s]
[Loss]: 1.7439 [Acc]: 0.3463: 100%|██████████| 7/7 [00:00<00:00, 73.79it/s]
[Loss]: 0.7172 [Acc]: 0.6875:   0%|          | 0/2105 [00:00<?, ?it/s]


[Training Loss]: 0.6551 [Training Accuracy]: 0.7083
Validating

[Validation Loss]: 1.7439 [Validation Accuracy]: 0.3463
Epoch 13/16


[Loss]: 0.6494 [Acc]: 0.7117: 100%|██████████| 2105/2105 [00:37<00:00, 55.88it/s]
[Loss]: 2.0064 [Acc]: 0.2878: 100%|██████████| 7/7 [00:00<00:00, 73.08it/s]
[Loss]: 0.6622 [Acc]: 0.6687:   0%|          | 0/2105 [00:00<?, ?it/s]


[Training Loss]: 0.6494 [Training Accuracy]: 0.7117
Validating

[Validation Loss]: 2.0064 [Validation Accuracy]: 0.2878
Epoch 14/16


[Loss]: 0.6390 [Acc]: 0.7145: 100%|██████████| 2105/2105 [00:37<00:00, 55.77it/s]
[Loss]: 1.6577 [Acc]: 0.3761: 100%|██████████| 7/7 [00:00<00:00, 64.67it/s]
[Loss]: 0.8722 [Acc]: 0.6375:   0%|          | 0/2105 [00:00<?, ?it/s]


[Training Loss]: 0.6390 [Training Accuracy]: 0.7145
Validating

[Validation Loss]: 1.6577 [Validation Accuracy]: 0.3761
Epoch 15/16


[Loss]: 0.6318 [Acc]: 0.7178: 100%|██████████| 2105/2105 [00:37<00:00, 56.35it/s]
[Loss]: 1.8529 [Acc]: 0.3062: 100%|██████████| 7/7 [00:00<00:00, 57.67it/s]
[Loss]: 0.5068 [Acc]: 0.8125:   0%|          | 0/2105 [00:00<?, ?it/s]


[Training Loss]: 0.6318 [Training Accuracy]: 0.7178
Validating

[Validation Loss]: 1.8529 [Validation Accuracy]: 0.3062
Epoch 16/16


[Loss]: 0.6203 [Acc]: 0.7231: 100%|██████████| 2105/2105 [00:37<00:00, 55.98it/s]
[Loss]: 1.6981 [Acc]: 0.3578: 100%|██████████| 7/7 [00:00<00:00, 73.61it/s]


[Training Loss]: 0.6203 [Training Accuracy]: 0.7231
Validating

[Validation Loss]: 1.6981 [Validation Accuracy]: 0.3578





### A5.d GRU

In [8]:
train('gru')

Data examples:
Sentiment: Negative. Sentence: the film 's ending has a `` what was it all for
Sentiment: Positive. Sentence: smooth , shrewd , powerful act
Sentiment: Negative. Sentence: ( toback 's ) fondness for fancy split-screen , stuttering editing and pompous references to wittgenstein and kirkegaard ...
Sentiment: Positive. Sentence: a testament to de niro and director michael caton-jones
Sentiment: Negative. Sentence: die
Sentiment: Positive. Sentence: make us care about zelda 's ultimate fate
Sentiment: Negative. Sentence: you have left the theater
Sentiment: Negative. Sentence: feels impersonal , almost generic



[Loss]: 2.9010 [Acc]: 0.3608:   0%|          | 6/2105 [00:00<00:35, 59.17it/s]

Epoch 1/16


[Loss]: 1.1632 [Acc]: 0.4417: 100%|██████████| 2105/2105 [00:39<00:00, 53.74it/s]
[Loss]: 2.0271 [Acc]: 0.3314: 100%|██████████| 7/7 [00:00<00:00, 66.45it/s]
[Loss]: 1.0974 [Acc]: 0.4313:   0%|          | 0/2105 [00:00<?, ?it/s]


[Training Loss]: 1.1632 [Training Accuracy]: 0.4417
Validating

[Validation Loss]: 2.0271 [Validation Accuracy]: 0.3314
Epoch 2/16


[Loss]: 1.0428 [Acc]: 0.4678: 100%|██████████| 2105/2105 [00:39<00:00, 52.93it/s]
[Loss]: 1.8443 [Acc]: 0.2970: 100%|██████████| 7/7 [00:00<00:00, 83.93it/s]
[Loss]: 1.0719 [Acc]: 0.4821:   0%|          | 7/2105 [00:00<00:33, 61.93it/s]


[Training Loss]: 1.0428 [Training Accuracy]: 0.4678
Validating

[Validation Loss]: 1.8443 [Validation Accuracy]: 0.2970
Epoch 3/16


[Loss]: 0.9847 [Acc]: 0.4978: 100%|██████████| 2105/2105 [00:39<00:00, 53.33it/s]
[Loss]: 1.7081 [Acc]: 0.3830: 100%|██████████| 7/7 [00:00<00:00, 67.33it/s]
[Loss]: 0.9242 [Acc]: 0.5391:   0%|          | 0/2105 [00:00<?, ?it/s]


[Training Loss]: 0.9847 [Training Accuracy]: 0.4978
Validating

[Validation Loss]: 1.7081 [Validation Accuracy]: 0.3830
Epoch 4/16


[Loss]: 0.9359 [Acc]: 0.5425: 100%|██████████| 2105/2105 [00:39<00:00, 53.17it/s]
[Loss]: 1.5964 [Acc]: 0.3819: 100%|██████████| 7/7 [00:00<00:00, 83.60it/s]
[Loss]: 0.8862 [Acc]: 0.6953:   0%|          | 0/2105 [00:00<?, ?it/s]


[Training Loss]: 0.9359 [Training Accuracy]: 0.5425
Validating

[Validation Loss]: 1.5964 [Validation Accuracy]: 0.3819
Epoch 5/16


[Loss]: 0.8668 [Acc]: 0.6016: 100%|██████████| 2105/2105 [00:39<00:00, 53.63it/s]
[Loss]: 1.6050 [Acc]: 0.4128: 100%|██████████| 7/7 [00:00<00:00, 84.19it/s]
[Loss]: 0.7488 [Acc]: 0.6615:   0%|          | 6/2105 [00:00<00:35, 58.71it/s]


[Training Loss]: 0.8668 [Training Accuracy]: 0.6016
Validating

[Validation Loss]: 1.6050 [Validation Accuracy]: 0.4128
Epoch 6/16


[Loss]: 0.8227 [Acc]: 0.6307: 100%|██████████| 2105/2105 [00:39<00:00, 53.40it/s]
[Loss]: 1.5752 [Acc]: 0.4255: 100%|██████████| 7/7 [00:00<00:00, 73.71it/s]
[Loss]: 0.7130 [Acc]: 0.6875:   0%|          | 0/2105 [00:00<?, ?it/s]


[Training Loss]: 0.8227 [Training Accuracy]: 0.6307
Validating

[Validation Loss]: 1.5752 [Validation Accuracy]: 0.4255
Epoch 7/16


[Loss]: 0.7709 [Acc]: 0.6553: 100%|██████████| 2105/2105 [00:38<00:00, 54.19it/s]
[Loss]: 1.5890 [Acc]: 0.3865: 100%|██████████| 7/7 [00:00<00:00, 84.18it/s]
[Loss]: 0.8010 [Acc]: 0.6607:   0%|          | 7/2105 [00:00<00:33, 61.79it/s]


[Training Loss]: 0.7709 [Training Accuracy]: 0.6553
Validating

[Validation Loss]: 1.5890 [Validation Accuracy]: 0.3865
Epoch 8/16


[Loss]: 0.7427 [Acc]: 0.6700: 100%|██████████| 2105/2105 [00:39<00:00, 53.75it/s]
[Loss]: 1.6043 [Acc]: 0.3842: 100%|██████████| 7/7 [00:00<00:00, 80.52it/s]
[Loss]: 0.6752 [Acc]: 0.6927:   0%|          | 6/2105 [00:00<00:35, 58.80it/s]


[Training Loss]: 0.7427 [Training Accuracy]: 0.6700
Validating

[Validation Loss]: 1.6043 [Validation Accuracy]: 0.3842
Epoch 9/16


[Loss]: 0.7162 [Acc]: 0.6831: 100%|██████████| 2105/2105 [00:39<00:00, 53.94it/s]
[Loss]: 1.6269 [Acc]: 0.3693: 100%|██████████| 7/7 [00:00<00:00, 82.41it/s]
[Loss]: 0.7326 [Acc]: 0.6510:   0%|          | 6/2105 [00:00<00:35, 58.88it/s]


[Training Loss]: 0.7162 [Training Accuracy]: 0.6831
Validating

[Validation Loss]: 1.6269 [Validation Accuracy]: 0.3693
Epoch 10/16


[Loss]: 0.6981 [Acc]: 0.6924: 100%|██████████| 2105/2105 [00:39<00:00, 52.95it/s]
[Loss]: 1.7012 [Acc]: 0.3704: 100%|██████████| 7/7 [00:00<00:00, 83.18it/s]
[Loss]: 0.5829 [Acc]: 0.7708:   0%|          | 6/2105 [00:00<00:38, 54.67it/s]


[Training Loss]: 0.6981 [Training Accuracy]: 0.6924
Validating

[Validation Loss]: 1.7012 [Validation Accuracy]: 0.3704
Epoch 11/16


[Loss]: 0.6817 [Acc]: 0.6984: 100%|██████████| 2105/2105 [00:39<00:00, 52.78it/s]
[Loss]: 1.6315 [Acc]: 0.4014: 100%|██████████| 7/7 [00:00<00:00, 76.99it/s]
[Loss]: 0.6762 [Acc]: 0.7063:   0%|          | 0/2105 [00:00<?, ?it/s]


[Training Loss]: 0.6817 [Training Accuracy]: 0.6984
Validating

[Validation Loss]: 1.6315 [Validation Accuracy]: 0.4014
Epoch 12/16


[Loss]: 0.6708 [Acc]: 0.7018: 100%|██████████| 2105/2105 [00:39<00:00, 53.37it/s]
[Loss]: 1.6698 [Acc]: 0.3612: 100%|██████████| 7/7 [00:00<00:00, 65.14it/s]
[Loss]: 0.5053 [Acc]: 0.7500:   0%|          | 0/2105 [00:00<?, ?it/s]


[Training Loss]: 0.6708 [Training Accuracy]: 0.7018
Validating

[Validation Loss]: 1.6698 [Validation Accuracy]: 0.3612
Epoch 13/16


[Loss]: 0.6532 [Acc]: 0.7069: 100%|██████████| 2105/2105 [00:39<00:00, 53.23it/s]
[Loss]: 1.7220 [Acc]: 0.3888: 100%|██████████| 7/7 [00:00<00:00, 65.18it/s]
[Loss]: 0.6374 [Acc]: 0.6953:   0%|          | 0/2105 [00:00<?, ?it/s]


[Training Loss]: 0.6532 [Training Accuracy]: 0.7069
Validating

[Validation Loss]: 1.7220 [Validation Accuracy]: 0.3888
Epoch 14/16


[Loss]: 0.6518 [Acc]: 0.7114: 100%|██████████| 2105/2105 [00:39<00:00, 52.77it/s]
[Loss]: 1.8124 [Acc]: 0.3452: 100%|██████████| 7/7 [00:00<00:00, 84.18it/s]
[Loss]: 0.6787 [Acc]: 0.7135:   0%|          | 6/2105 [00:00<00:38, 54.81it/s]


[Training Loss]: 0.6518 [Training Accuracy]: 0.7114
Validating

[Validation Loss]: 1.8124 [Validation Accuracy]: 0.3452
Epoch 15/16


[Loss]: 0.6439 [Acc]: 0.7137: 100%|██████████| 2105/2105 [00:39<00:00, 52.75it/s]
[Loss]: 1.7607 [Acc]: 0.3739: 100%|██████████| 7/7 [00:00<00:00, 73.14it/s]
[Loss]: 0.6970 [Acc]: 0.7266:   0%|          | 0/2105 [00:00<?, ?it/s]


[Training Loss]: 0.6439 [Training Accuracy]: 0.7137
Validating

[Validation Loss]: 1.7607 [Validation Accuracy]: 0.3739
Epoch 16/16


[Loss]: 0.6345 [Acc]: 0.7183: 100%|██████████| 2105/2105 [00:39<00:00, 53.07it/s]
[Loss]: 1.7518 [Acc]: 0.3876: 100%|██████████| 7/7 [00:00<00:00, 85.53it/s]


[Training Loss]: 0.6345 [Training Accuracy]: 0.7183
Validating

[Validation Loss]: 1.7518 [Validation Accuracy]: 0.3876





## A5.e

An RNN or LSTM have the advantage of "remembering" the past inputs, to improve performance over prediction of a time-series data. If you use a neural network over like the past 500 characters, this may work but the network just treat the data as a bunch of data without any specific indication of time. The network can learn the time representation only through gradient descent. RNN or LSTM however have "time" as a mechanism built into the model. The model loops through the model sequentially and have a real "sense of time" even before the model is trained. The model also have "memory" of previous data points to help the prediction. The architecture is based on the progress of time and the gradient are propagated through time as well. This is a much more intuitive way to process time-series data.

However in this problem of text classification, we don't have to assume such time-series structure of input. And the downside of such assumption only allows us use past information at a given word while we can definitely leverage the information from the entire sentence to make classfication decision.

## A5.f

The reason why it makes sense to feed the final hidden state of an RNN into the classification layer is because we can view the final hidden state of an RNN as an "embedding" of the input sequence, and hence it contains information about this input sequence and we can use this embedding in our classification layer. 

However, like mentioned in the above question, RNN only tends to remember information that are closer to the current state, and hence the information provided by the final hidden state might, to some extent, lose information from early tokens in the sentence. This can be considered as a sort of information vanish.

## A5.g

While RNN goes through each token, we can save the hidden state, and feed it to the classification layer, and compute the loss using the logit of current token tag prediction and its actual tag. Then when the RNN goes to the next token, we repeat the same procedure, i.e., get the hidden state, feed it to the classification layer, compute loss. We repeat this process until the end of the sentence. This way we'll have tags predicted for all tokens.

## A5.h

Funniest review:
    
makes a joke out of car chases for an hour and then gives us half an hour of car chases 