Assignment 3 : Sequence labelling with RNNs
In this assignement we will ask you to perform POS tagging.

You are asked to follow these steps:

- Download the corpora and split it in training and test sets, structuring a dataframe.
- Embed the words using GloVe embeddings
- Create a baseline model, using a simple neural architecture
- Experiment doing small modifications to the model
- Evaluate your best model
- Analyze the errors of your model
- Corpora: Ignore the numeric value in the third column, use only the words/symbols and its label. https://raw.githubusercontent.com/nltk/nltk_data/gh-pages/packages/corpora/dependency_treebank.zip

Splits: documents 1-100 are the train set, 101-150 validation set, 151-199 test set.

Baseline: two layers architecture: a Bidirectional LSTM and a Dense/Fully-Connected layer on top.

Modifications: experiment using a GRU instead of the LSTM, adding an additional LSTM layer, and using a CRF in addition to the LSTM. Each of this change must be done by itself (don't mix these modifications).

Training and Experiments: all the experiments must involve only the training and validation sets.

Evaluation: in the end, only the best model of your choice must be evaluated on the test set. The main metric must be F1-Macro computed between the various part of speech (without considering punctuation classes).

Error Analysis (optional) : analyze the errors done by your model, try to understand which may be the causes and think about how to improve it.

Report: You are asked to deliver a small report of about 4-5 lines in the .txt file that sums up your findings.

https://raw.githubusercontent.com/nltk/nltk_data/gh-pages/packages/corpora/dependency_treebank.zip

# Sequence labeling

In [128]:
import re
import random
import time
import string
from functools import partial

import pandas as pd
import numpy as np
import sklearn
import nltk
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
from tqdm import tqdm
from nltk.corpus import dependency_treebank

import utils

%load_ext autoreload
%autoreload 2
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
nltk.download("dependency_treebank")

[nltk_data] Downloading package dependency_treebank to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package dependency_treebank is already up-to-date!


True

In [134]:
RANDOM_SEED = 42


def fix_random(seed):
    """
    Fix all the possible sources of randomness
    """
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


fix_random(RANDOM_SEED)

In [135]:
file_prefix = "wsj_"
file_ext = ".dp"
train_files = [f"{file_prefix}{i:04d}{file_ext}" for i in range(1, 101)]
val_files = [f"{file_prefix}{i:04d}{file_ext}" for i in range(101, 151)]
test_files = [f"{file_prefix}{i:04d}{file_ext}" for i in range(151, 200)]
splits = (
    ["train"] * len(train_files) + ["val"] * len(val_files) + ["test"] * len(test_files)
)
whole_files = train_files + val_files + test_files

In [136]:
whole_files[0:-1:20]

['wsj_0001.dp',
 'wsj_0021.dp',
 'wsj_0041.dp',
 'wsj_0061.dp',
 'wsj_0081.dp',
 'wsj_0101.dp',
 'wsj_0121.dp',
 'wsj_0141.dp',
 'wsj_0161.dp',
 'wsj_0181.dp']

In [137]:
dependency_treebank.raw(fileids=whole_files[0])

'Pierre\tNNP\t2\nVinken\tNNP\t8\n,\t,\t2\n61\tCD\t5\nyears\tNNS\t6\nold\tJJ\t2\n,\t,\t2\nwill\tMD\t0\njoin\tVB\t8\nthe\tDT\t11\nboard\tNN\t9\nas\tIN\t9\na\tDT\t15\nnonexecutive\tJJ\t15\ndirector\tNN\t12\nNov.\tNNP\t9\n29\tCD\t16\n.\t.\t8\n\nMr.\tNNP\t2\nVinken\tNNP\t3\nis\tVBZ\t0\nchairman\tNN\t3\nof\tIN\t4\nElsevier\tNNP\t7\nN.V.\tNNP\t12\n,\t,\t12\nthe\tDT\t12\nDutch\tNNP\t12\npublishing\tVBG\t12\ngroup\tNN\t5\n.\t.\t3\n'

In [138]:
def parse_file(fileid, preprocessor=None):
    """
    Parse the given file identifier from the dependency treebank corpus
    and return a tuple (`tokens`, `tags`), where `tokens` is the list
    of tokens retrieved in the document and `tags` is the associated
    list of tags (`tokens` and `tags` have the same lenght)
    
    If you wish to preprocess tokens, you can pass a function to the
    `preprocessor` argument, which should take as input only the token
    to transform
    """
    file_str = dependency_treebank.raw(fileids=fileid)
    splitted_file_str = [
        x for x in re.split("\t|\n", file_str.strip()) if x.strip() != ""
    ]
    if preprocessor is None:
        preprocessor = lambda t: t
    tokens, tags = [], []
    for i in range(0, len(splitted_file_str), 3):
        token = preprocessor(splitted_file_str[i])
        tag = splitted_file_str[i + 1]
        tokens.append(token)
        tags.append(tag)
    return tokens, tags


def parse_files(fileids, preprocessor=None):
    """
    Parse a set of file identifiers from the dependency treebank corpus
    and return two lists, one which contains lists of tokens and one 
    containing lists of corresponding tags for each file identifier
    """
    tokens_list, tags_list = [], []
    for fileid in fileids:
        tokens, tags = parse_file(fileid, preprocessor=preprocessor)
        tokens_list.append(tokens)
        tags_list.append(tags)
    return tokens_list, tags_list

In [139]:
def preprocess_token(token):
    """
    Peform small modifications to the given token:
        - Transform to lowercase
        - Encode numbers as "<num>"
    """
    token = token.lower()
    token = "<num>" if re.match(utils.FLOAT_RE, token) else token
    return token

In [140]:
whole_tokens, whole_tags = parse_files(whole_files, preprocessor=preprocess_token)

In [141]:
def flatten(a):
    """
    Given a 2D list, returns its flattened version
    """
    return [i for s in a for i in s]


flattened_tags = flatten(whole_tags)
flattened_tokens = flatten(whole_tokens)

In [142]:
unique_tags = np.unique(flattened_tags)
len(unique_tags)

45

In [143]:
tags_fd = nltk.probability.FreqDist(flattened_tags)
tags_fd.most_common(10)

[('NN', 13166),
 ('IN', 9857),
 ('NNP', 9410),
 ('DT', 8165),
 ('NNS', 6047),
 ('JJ', 5834),
 (',', 4886),
 ('.', 3874),
 ('CD', 3546),
 ('VBD', 3043)]

In [144]:
unique_tokens = np.unique(flattened_tokens)
len(unique_tokens)

9964

In [145]:
tokens_fd = nltk.probability.FreqDist(flattened_tokens)
tokens_fd.most_common(10)

[(',', 4885),
 ('the', 4764),
 ('.', 3828),
 ('<num>', 2471),
 ('of', 2325),
 ('to', 2182),
 ('a', 1988),
 ('in', 1769),
 ('and', 1556),
 ("'s", 865)]

In [146]:
def build_vocabulary(tokens, padding_token="0"):
    """
    Given a list of tokens, builds the corresponding word vocabulary
    """
    words = sorted(set(tokens))
    vocabulary, inverse_vocabulary = dict(), dict()
    vocabulary[0] = str(padding_token)
    inverse_vocabulary[str(padding_token)] = 0
    for i, w in tqdm(enumerate(words)):
        vocabulary[i + 1] = w
        inverse_vocabulary[w] = i + 1
    return vocabulary, inverse_vocabulary, words

In [147]:
PADDING_TOKEN = "0"
PADDING_TOKEN in (flattened_tokens, flattened_tags)

False

In [148]:
index_to_word, word_to_index, word_listing = build_vocabulary(
    flattened_tokens, padding_token=PADDING_TOKEN
)

9964it [00:00, 635079.55it/s]


In [149]:
list(index_to_word.items())[:10]

[(0, '0'),
 (1, '!'),
 (2, '#'),
 (3, '$'),
 (4, '%'),
 (5, '&'),
 (6, "'"),
 (7, "''"),
 (8, "'30s"),
 (9, "'40s")]

In [150]:
index_to_tag, tag_to_index, tag_listing = build_vocabulary(
    flattened_tags, padding_token=PADDING_TOKEN
)

45it [00:00, 166001.48it/s]


In [151]:
list(index_to_tag.items())[:10]

[(0, '0'),
 (1, '#'),
 (2, '$'),
 (3, "''"),
 (4, ','),
 (5, '-LRB-'),
 (6, '-RRB-'),
 (7, '.'),
 (8, ':'),
 (9, 'CC')]

In [280]:
no_punct_tags = list(set(tag_listing) - set(string.punctuation))
no_punct_tags_indexes = [tag_to_index[t] for t in no_punct_tags]
print(no_punct_tags)

['DT', 'NNPS', 'WDT', 'IN', 'JJR', 'NN', 'PRP', 'VBD', 'WP$', 'LS', 'MD', 'NNS', 'VB', '``', 'RB', 'RP', 'RBS', 'CC', 'WP', 'VBG', 'CD', 'POS', 'FW', 'VBP', 'NNP', 'UH', 'VBZ', 'JJS', 'PRP$', '-LRB-', '-RRB-', 'JJ', 'SYM', 'EX', "''", 'RBR', 'WRB', 'PDT', 'VBN', 'TO']


In [152]:
def to_indexes(values, to_index):
    """
    Given a list of keys and a dictionary indexed by those keys,
    return the corresponding values in the dictionary
    """
    return [to_index[v] for v in values]


whole_indexed_tokens = map(
    lambda tokens: to_indexes(tokens, word_to_index), whole_tokens
)
whole_indexed_tags = map(lambda tags: to_indexes(tags, tag_to_index), whole_tags)

In [153]:
df = pd.DataFrame(
    {
        "tokens": whole_tokens,
        "indexed_tokens": whole_indexed_tokens,
        "tags": whole_tags,
        "indexed_tags": whole_indexed_tags,
        "split": splits,
        "fileid": whole_files,
    }
)
df.head()

Unnamed: 0,tokens,indexed_tokens,tags,indexed_tags,split,fileid
0,"[pierre, vinken, ,, <num>, years, old, ,, will...","[6562, 9557, 20, 40, 9919, 6130, 20, 9793, 474...","[NNP, NNP, ,, CD, NNS, JJ, ,, MD, VB, DT, NN, ...","[21, 21, 4, 10, 23, 15, 4, 19, 35, 11, 20, 14,...",train,wsj_0001.dp
1,"[rudolph, agnew, ,, <num>, years, old, and, fo...","[7695, 256, 20, 40, 9919, 6130, 394, 3607, 146...","[NNP, NNP, ,, CD, NNS, JJ, CC, JJ, NN, IN, NNP...","[21, 21, 4, 10, 23, 15, 9, 15, 20, 14, 21, 21,...",train,wsj_0002.dp
2,"[a, form, of, asbestos, once, used, to, make, ...","[45, 3601, 6095, 570, 6146, 9453, 9063, 5302, ...","[DT, NN, IN, NN, RB, VBN, TO, VB, NNP, NN, NNS...","[11, 20, 14, 20, 28, 38, 33, 35, 21, 20, 23, 4...",train,wsj_0003.dp
3,"[yields, on, money-market, mutual, funds, cont...","[9931, 6144, 5688, 5792, 3725, 1968, 9063, 821...","[NNS, IN, JJ, JJ, NNS, VBD, TO, VB, ,, IN, NNS...","[23, 14, 15, 15, 23, 36, 33, 35, 4, 14, 23, 14...",train,wsj_0004.dp
4,"[j.p., bolduc, ,, vice, chairman, of, w.r., gr...","[4701, 1024, 20, 9535, 1462, 6095, 9609, 3894,...","[NNP, NNP, ,, NN, NN, IN, NNP, NNP, CC, NNP, ,...","[21, 21, 4, 20, 20, 14, 21, 21, 9, 21, 4, 41, ...",train,wsj_0005.dp


In [154]:
class DependencyTreebankDataset(Dataset):
    """
    Dependency treebank dataset for POS tagging
    """

    def __init__(self, df):
        self.df = df.copy()
        self.df = self.df.reset_index(drop=True)

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index):
        assert isinstance(index, int)
        tokens = self.df.loc[index, "indexed_tokens"]
        tags = self.df.loc[index, "indexed_tags"]
        return tokens, tags

In [155]:
train_dataset = DependencyTreebankDataset(df[df["split"] == "train"])
val_dataset = DependencyTreebankDataset(df[df["split"] == "val"])
test_dataset = DependencyTreebankDataset(df[df["split"] == "test"])

In [156]:
print(train_dataset[0])

([6562, 9557, 20, 40, 9919, 6130, 20, 9793, 4749, 8966, 1013, 568, 45, 5988, 2554, 6028, 40, 27, 5755, 9557, 4676, 1462, 6095, 2953, 5808, 20, 8966, 2826, 7018, 3943, 27], [21, 21, 4, 10, 23, 15, 4, 19, 35, 11, 20, 14, 11, 15, 20, 21, 10, 7, 21, 21, 40, 20, 14, 21, 21, 4, 11, 21, 37, 20, 7])


In [157]:
def pad_batch(batch):
    """
    This function expects to receive a list of tuples (i.e. a batch),
    s.t. each tuple contains tokens and tags for one sentence in the batch
    and returns the same sequences padded with the padding token
    """
    (tokens, tags) = zip(*batch)
    tokens_lenghts = [len(x) for x in tokens]
    tags_lenghts = [len(y) for y in tags]
    padded_tokens = pad_sequence(
        [torch.tensor(t) for t in tokens],
        batch_first=True,
        padding_value=int(PADDING_TOKEN),
    )
    padded_tags = pad_sequence(
        [torch.tensor(t) for t in tags],
        batch_first=True,
        padding_value=int(PADDING_TOKEN),
    )
    return padded_tokens, padded_tags, tokens_lenghts, tags_lenghts


batch_size = 5
default_dataloader = partial(
    DataLoader,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=pad_batch,
    pin_memory=True,
)
train_dataloader = default_dataloader(train_dataset)
val_dataloader = default_dataloader(val_dataset)
test_dataloader = default_dataloader(test_dataset)

In [158]:
embedding_dimension = 50
embedding_model = utils.load_embedding_model("glove", embedding_dimension=embedding_dimension)

In [159]:
def check_oov_terms(embedding_model, word_listing):
    """
    Checks differences between pre-trained embedding model vocabulary
    and dataset specific vocabulary in order to highlight out-of-vocabulary terms
    """
    oov_terms = []
    for word in word_listing:
        if word not in embedding_model.vocab:
            oov_terms.append(word)
    return oov_terms

In [160]:
oov_terms = check_oov_terms(embedding_model, word_listing)
print(
    f"Total OOV terms: {len(oov_terms)} ({round(len(oov_terms) / len(word_listing), 2)}%)"
)

Total OOV terms: 508 (0.05%)


In [164]:
def build_embedding_matrix(
    embedding_model,
    embedding_dimension,
    word_to_index,
    oov_terms,
    method="normal",
    padding_token="0"
):
    """
    Builds the embedding matrix of a specific dataset given a pre-trained Gensim word embedding model
    """

    def uniform_embedding(embedding_dimension, interval=(-1, 1)):
        return interval[0] + np.random.sample(embedding_dimension) + interval[1]

    def normal_embedding(embedding_dimension):
        return np.random.normal(embedding_dimension)

    embedding_matrix = np.zeros((len(word_to_index), embedding_dimension))
    for word, index in word_to_index.items():
        if word == padding_token:
            word_vector = np.zeros((1, embedding_dimension))
        # Words that are no OOV are taken from the Gensim model
        elif word not in oov_terms:
            word_vector = embedding_model[word]
        # OOV words computed as random normal vectors
        elif method == "normal":
            word_vector = normal_embedding(embedding_dimension)
        # OOV words computed as uniform vectors in range [-1, 1]
        elif method == "uniform":
            word_vector = uniform_embedding(embedding_dimension)
        embedding_matrix[index, :] = word_vector
    return embedding_matrix

In [165]:
embedding_matrix = build_embedding_matrix(
    embedding_model,
    embedding_dimension,
    word_to_index,
    oov_terms,
    method="normal",
    padding_token=PADDING_TOKEN,
)
embedding_matrix[word_to_index["the"]]

array([ 4.18000013e-01,  2.49679998e-01, -4.12420005e-01,  1.21699996e-01,
        3.45270008e-01, -4.44569997e-02, -4.96879995e-01, -1.78619996e-01,
       -6.60229998e-04, -6.56599998e-01,  2.78430015e-01, -1.47670001e-01,
       -5.56770027e-01,  1.46579996e-01, -9.50950012e-03,  1.16579998e-02,
        1.02040000e-01, -1.27920002e-01, -8.44299972e-01, -1.21809997e-01,
       -1.68009996e-02, -3.32789987e-01, -1.55200005e-01, -2.31309995e-01,
       -1.91809997e-01, -1.88230002e+00, -7.67459989e-01,  9.90509987e-02,
       -4.21249986e-01, -1.95260003e-01,  4.00710011e+00, -1.85939997e-01,
       -5.22870004e-01, -3.16810012e-01,  5.92130003e-04,  7.44489999e-03,
        1.77780002e-01, -1.58969998e-01,  1.20409997e-02, -5.42230010e-02,
       -2.98709989e-01, -1.57490000e-01, -3.47579986e-01, -4.56370004e-02,
       -4.42510009e-01,  1.87849998e-01,  2.78489990e-03, -1.84110001e-01,
       -1.15139998e-01, -7.85809994e-01])

In [166]:
embedding_matrix[0]

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [167]:
class POSTaggingModel(nn.Module):
    def __init__(
        self,
        input_dimension,
        embedding_dimension,
        hidden_dimension,
        output_dimension,
        embedding_matrix=None,
        retrain_embeddings=True,
        gru=True,
        num_layers=1,
        bidirectional=True,
        dropout_rate=0.0,
    ):
        """
        Build a generic POS tagging model, with recurrent modules
        """
        super().__init__()

        # Embedding module
        self.embedding = nn.Embedding(
            input_dimension, embedding_dimension, padding_idx=0
        )
        if embedding_matrix is not None:
            assert (
                embedding_matrix.shape[0] == input_dimension
                and embedding_matrix.shape[1] == embedding_dimension
            )
            self.embedding.weight = nn.Parameter(torch.FloatTensor(embedding_matrix))
        self.embedding.weight.requires_grad = retrain_embeddings

        # Recurrent module
        recurrent_module = nn.GRU if gru else nn.LSTM
        self.recurrent_module = recurrent_module(
            embedding_dimension,
            hidden_dimension,
            batch_first=True,
            num_layers=num_layers,
            bidirectional=bidirectional,
            dropout=dropout_rate if num_layers > 1 else 0,
        )

        # Dense and dropout
        self.dense = nn.Linear(
            hidden_dimension * 2 if bidirectional else hidden_dimension,
            output_dimension,
        )
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, tokens, tokens_lenghts):
        embedded = self.dropout(self.embedding(tokens))
        packed = pack_padded_sequence(
            embedded, tokens_lenghts, batch_first=True, enforce_sorted=False
        )
        packed_outputs, (hidden, cell) = self.recurrent_module(packed)
        padded_outputs, outputs_lengths = pad_packed_sequence(
            packed_outputs, batch_first=True
        )
        predictions = self.dense(self.dropout(padded_outputs))
        return predictions

In [168]:
hidden_dimension = 128
num_layers = 1
bidirectional = True
dropout_rate = 0.0

baseline_model = POSTaggingModel(
    len(word_to_index),
    embedding_dimension,
    hidden_dimension,
    len(tag_to_index),
    embedding_matrix=embedding_matrix,
    retrain_embeddings=True,
    gru=False,
    num_layers=num_layers,
    bidirectional=bidirectional,
    dropout_rate=dropout_rate,
)

In [169]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


print(f"The model has {count_parameters(baseline_model):,} trainable parameters")

The model has 694,392 trainable parameters


In [170]:
def init_weights(m):
    for name, param in m.named_parameters():
        nn.init.normal_(param.data, mean=0, std=0.1)


baseline_model.apply(init_weights)

POSTaggingModel(
  (embedding): Embedding(9965, 50, padding_idx=0)
  (recurrent_module): LSTM(50, 128, batch_first=True, bidirectional=True)
  (dense): Linear(in_features=256, out_features=46, bias=True)
  (dropout): Dropout(p=0.0, inplace=False)
)

In [171]:
optimizer = optim.Adam(baseline_model.parameters())
criterion = nn.CrossEntropyLoss(ignore_index=0)

In [172]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
baseline_model = baseline_model.to(device)
criterion = criterion.to(device)
print(device)

cpu


In [205]:
def categorical_accuracy(predictions, ground_truth):
    """
    Returns accuracy per batch
    """
    max_predictions = predictions.argmax(dim=1, keepdim=True)
    non_pad_elements = torch.where(ground_truth != 0)[0]
    correct = (
        max_predictions[non_pad_elements].squeeze(1).eq(ground_truth[non_pad_elements])
    )
    return correct.sum() / torch.FloatTensor([ground_truth[non_pad_elements].shape[0]])

In [281]:
def f1_score(predictions, ground_truth, labels):
    """
    Returns F1-macro per batch
    """
    max_predictions = predictions.argmax(dim=1)
    non_pad_elements = torch.where(ground_truth != 0)[0]
    return sklearn.metrics.f1_score(
        ground_truth[non_pad_elements].cpu().detach().tolist(),
        max_predictions[non_pad_elements].cpu().detach().tolist(),
        labels=labels,
        average="macro",
    )

In [282]:
def train(model, dataloader, optimizer, criterion):
    """
    Train the given model with the given dataloader, optimizer and criterion
    """
    epoch_loss, epoch_acc, epoch_f1 = 0, 0, 0
    model.train()
    for tokens, tags, tokens_lenghts, tags_lenghts in tqdm(dataloader):
        optimizer.zero_grad()
        predictions = model(tokens, tokens_lenghts)
        predictions = predictions.view(-1, predictions.shape[-1])
        tags = tags.view(-1)
        loss = criterion(predictions, tags)
        acc = categorical_accuracy(predictions, tags)
        f1 = f1_score(predictions, tags, no_punct_tags_indexes)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        epoch_acc += acc.item()
        epoch_f1 += f1.item()

    return (
        epoch_loss / len(dataloader),
        epoch_acc / len(dataloader),
        epoch_f1 / len(dataloader),
    )

In [283]:
def evaluate(model, dataloader, criterion):
    """
    Evaluate the given model with the given dataloader, optimizer and criterion
    """
    epoch_loss, epoch_acc, epoch_f1 = 0, 0, 0
    model.eval()
    with torch.no_grad():
        for tokens, tags, tokens_lenghts, tags_lenghts in tqdm(dataloader):
            predictions = model(tokens, tokens_lenghts)
            predictions = predictions.view(-1, predictions.shape[-1])
            tags = tags.view(-1)
            loss = criterion(predictions, tags)
            acc = categorical_accuracy(predictions, tags)
            f1 = f1_score(predictions, tags, no_punct_tags_indexes)
            epoch_loss += loss.item()
            epoch_acc += acc.item()
            epoch_f1 += f1.item()

    return (
        epoch_loss / len(dataloader),
        epoch_acc / len(dataloader),
        epoch_f1 / len(dataloader),
    )

In [284]:
def train_val_test(model, model_name, epochs=10):
    """
    Perform training, validation and testing on the given model,
    for the specified number of epochs
    """
    best_val_loss = float("inf")
    for epoch in range(epochs):
        # Perform train and validation
        print(f"Epoch: {epoch + 1:02}")
        start_time = time.time()
        print(f"Training...")
        train_loss, train_acc, train_f1 = train(
            model, train_dataloader, optimizer, criterion
        )
        print(f"Evaluating...")
        val_loss, val_acc, val_f1 = evaluate(model, val_dataloader, criterion)
        end_time = time.time()

        # Print epoch stats
        print(f"Epoch Time: {end_time - start_time}s")
        print(
            f"Train loss: {train_loss:.3f} | Train accuracy: {train_acc * 100:.2f}% | Train F1: {train_f1 * 100:.2f}%"
        )
        print(
            f"Validation loss: {val_loss:.3f} | Validation accuracy: {val_acc * 100:.2f}% | Validation F1: {val_f1 * 100:.2f}%"
        )

        # Save the best model so far
        if val_loss < best_val_loss:
            print("Saving new checkpoint...")
            best_val_loss = val_loss
            torch.save(model.state_dict(), f"models/{model_name}.pt")

        print()

    # Test the model
    model.load_state_dict(torch.load(f"models/{model_name}.pt"))
    print("Testing...")
    test_loss, test_acc, test_f1 = evaluate(model, test_dataloader, criterion)
    print(
        f"Test loss: {test_loss:.3f} | Test accuracy: {test_acc * 100:.2f}% | Test F1: {test_f1 * 100:.2f}%"
    )

In [285]:
train_val_test(baseline_model, "baseline_model", epochs=10)

  0%|          | 0/20 [00:00<?, ?it/s]

Epoch: 01
Training...


100%|██████████| 20/20 [01:26<00:00,  4.33s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

Evaluating...


100%|██████████| 10/10 [00:05<00:00,  1.69it/s]
  0%|          | 0/20 [00:00<?, ?it/s]

Epoch Time: 92.6566071510315s
Train loss: 1.261 | Train accuracy: 68.97% | Train F1: 26.91%
Validation Loss: 1.184 | Validation accuracy: 71.30% | Validation F1: 32.50%
Saving new checkpoint...

Epoch: 02
Training...


100%|██████████| 20/20 [01:25<00:00,  4.28s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

Evaluating...


100%|██████████| 10/10 [00:06<00:00,  1.64it/s]
  0%|          | 0/20 [00:00<?, ?it/s]

Epoch Time: 91.85422563552856s
Train loss: 0.949 | Train accuracy: 79.24% | Train F1: 39.56%
Validation Loss: 0.959 | Validation accuracy: 78.06% | Validation F1: 41.83%
Saving new checkpoint...

Epoch: 03
Training...


100%|██████████| 20/20 [01:29<00:00,  4.49s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

Evaluating...


100%|██████████| 10/10 [00:05<00:00,  1.67it/s]
  0%|          | 0/20 [00:00<?, ?it/s]

Epoch Time: 95.89124608039856s
Train loss: 0.717 | Train accuracy: 85.20% | Train F1: 45.80%
Validation Loss: 0.789 | Validation accuracy: 81.43% | Validation F1: 45.53%
Saving new checkpoint...

Epoch: 04
Training...


100%|██████████| 20/20 [01:30<00:00,  4.53s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

Evaluating...


100%|██████████| 10/10 [00:05<00:00,  1.69it/s]
  0%|          | 0/20 [00:00<?, ?it/s]

Epoch Time: 96.48748970031738s
Train loss: 0.549 | Train accuracy: 88.88% | Train F1: 49.24%
Validation Loss: 0.671 | Validation accuracy: 83.50% | Validation F1: 47.76%
Saving new checkpoint...

Epoch: 05
Training...


100%|██████████| 20/20 [01:27<00:00,  4.36s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

Evaluating...


100%|██████████| 10/10 [00:05<00:00,  1.69it/s]
  0%|          | 0/20 [00:00<?, ?it/s]

Epoch Time: 93.24930310249329s
Train loss: 0.427 | Train accuracy: 91.57% | Train F1: 52.85%
Validation Loss: 0.592 | Validation accuracy: 85.18% | Validation F1: 51.88%
Saving new checkpoint...

Epoch: 06
Training...


100%|██████████| 20/20 [01:29<00:00,  4.46s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

Evaluating...


100%|██████████| 10/10 [00:06<00:00,  1.61it/s]
  0%|          | 0/20 [00:00<?, ?it/s]

Epoch Time: 95.58493089675903s
Train loss: 0.336 | Train accuracy: 93.45% | Train F1: 58.30%
Validation Loss: 0.536 | Validation accuracy: 86.19% | Validation F1: 55.01%
Saving new checkpoint...

Epoch: 07
Training...


100%|██████████| 20/20 [01:24<00:00,  4.21s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

Evaluating...


100%|██████████| 10/10 [00:05<00:00,  1.67it/s]
  0%|          | 0/20 [00:00<?, ?it/s]

Epoch Time: 90.28789329528809s
Train loss: 0.269 | Train accuracy: 94.92% | Train F1: 63.45%
Validation Loss: 0.496 | Validation accuracy: 87.14% | Validation F1: 59.28%
Saving new checkpoint...

Epoch: 08
Training...


100%|██████████| 20/20 [01:22<00:00,  4.14s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

Evaluating...


100%|██████████| 10/10 [00:06<00:00,  1.64it/s]
  0%|          | 0/20 [00:00<?, ?it/s]

Epoch Time: 88.86557984352112s
Train loss: 0.220 | Train accuracy: 95.84% | Train F1: 65.78%
Validation Loss: 0.466 | Validation accuracy: 87.67% | Validation F1: 61.24%
Saving new checkpoint...

Epoch: 09
Training...


100%|██████████| 20/20 [01:22<00:00,  4.14s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

Evaluating...


100%|██████████| 10/10 [00:05<00:00,  1.68it/s]
  0%|          | 0/20 [00:00<?, ?it/s]

Epoch Time: 88.8346483707428s
Train loss: 0.183 | Train accuracy: 96.62% | Train F1: 68.58%
Validation Loss: 0.446 | Validation accuracy: 88.09% | Validation F1: 65.14%
Saving new checkpoint...

Epoch: 10
Training...


100%|██████████| 20/20 [01:27<00:00,  4.39s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

Evaluating...


100%|██████████| 10/10 [00:05<00:00,  1.68it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch Time: 93.85730004310608s
Train loss: 0.155 | Train accuracy: 97.15% | Train F1: 70.27%
Validation Loss: 0.435 | Validation accuracy: 88.31% | Validation F1: 66.61%
Saving new checkpoint...

Testing...


100%|██████████| 10/10 [00:03<00:00,  3.09it/s]

Test loss: 0.377 | Test accuracy: 89.51% | Test F1: 63.90%



