In [1]:
import os
import sys
import json
import torch

sys.path.append("../")
from lib.utils import get_device, get_current_date
from lib.utils.constants import Subtask, Track, PreprocessTextLevel, DatasetType
from lib.utils.models import sequential_fully_connected
from lib.data.loading import load_train_dev_test_df
from lib.data.tokenizer import get_tokenizer
from lib.training.optimizer import get_optimizer, get_scheduler
from lib.training.loss import get_loss_fn
from lib.training.metric import get_metric

  from .autonotebook import tqdm as notebook_tqdm


In [31]:
CONFIG_FILE_PATH = os.path.relpath("../config.json")

config = {}
with open(CONFIG_FILE_PATH, "r") as config_file:
    config = json.load(config_file)

DEVICE = get_device()
print(f"Using device: {DEVICE}")

Using device: mps


In [3]:
# config

In [32]:
task = None
if "task" in config:
    task = Subtask(config["task"])
else:
    raise ValueError("Task not specified in config")

track = None
if "track" in config:
    track = Track(config["track"])
else:
    print(f"Warning: Track not specified in config for subtask: {task}")

dataset_type = DatasetType.TransformerTruncationDataset
if "dataset_type" in config["data"]:
    dataset_type = DatasetType(config["data"]["dataset_type"])

dataset_type_settings = None
if "dataset_type_settings" in config["data"]:
    dataset_type_settings = config["data"]["dataset_type_settings"]

df_train, df_dev, df_test = load_train_dev_test_df(
    task=task,
    track=track,
    data_dir=f"../{config['data']['data_dir']}",
    label_column=config["data"]["label_column"],
    test_size=config["data"]["test_size"],
    preprocess_text_level=PreprocessTextLevel(
        config["data"]["preprocess_text_level"]
    ),
)

print(f"df_train.shape: {df_train.shape}")
print(f"df_dev.shape: {df_dev.shape}")
print(f"df_test.shape: {df_test.shape}")

Loading train data...
Train/dev split... (df_train.shape: (3649, 3))
Loading test data... ---> .././data/original_data/SubtaskC/SubtaskC_dev.jsonl
df_train.shape: (2919, 3)
df_dev.shape: (730, 3)
df_test.shape: (505, 3)


In [33]:
DEBUG = False
if DEBUG:
    results_dir = os.path.relpath("../runs/SubtaskC/")
else:
    results_dir = os.path.relpath(
        f"../runs/{get_current_date()}-{task.value}-{config['model']}"
    )
if not os.path.exists(results_dir):
    os.makedirs(results_dir)

print(f"Will save results to: {results_dir}")

with open(results_dir + "/config.json", "w") as f:
    json.dump(config, f, indent=4)

Will save results to: ../runs/30-12-2023_20:25:57-SubtaskC-bilstm_for_token_classification


# Load Word2Vec embeddings

In [8]:
df_train.head()

Unnamed: 0,id,text,label
2784,4a33ddc4-b755-42f3-8526-f0b4f1b3626c,This paper is the first (I believe) to establi...,89
2988,f232cb20-24d5-439c-9634-791d2cf51873,This manuscript tries to tackle neural network...,55
2718,c001a50d-de63-49e8-acb1-4665391f9381,This paper proposes an extension of the MAC me...,150
2223,09299ba2-0870-4fdc-91f8-ceb29dddb958,The paper proposes two methods for what is cal...,133
1284,fa7bc655-b107-4258-9c26-42b9bdb89e09,SUMMARY \r\nThis paper discusses how data from...,43


In [41]:
# WORD2VEC_MODEL_NAME = "word2vec-google-news-300"
# WORD2VEC_MODEL_NAME = "glove-wiki-gigaword-300"
WORD2VEC_MODEL_NAME = "fasttext-wiki-news-subwords-300"  # ~ 50% (train), 60% (dev), 64% (test) word coverage

In [42]:
import gensim.downloader as gensim_api

word2vec_model = gensim_api.load(WORD2VEC_MODEL_NAME)

In [43]:
# Count how many words from the dataset are in the word2vec model
from collections import Counter


def count_words_in_word2vec_model(df, word2vec_model):
    word_counts = Counter()
    for text in df["text"]:
        text = text.replace("\n", " ")
        text = text.replace("\t", " ")
        text = text.replace("\r", " ")

        words = [w for w in text.split(" ") if w != ""]

        word_counts.update(words)

    word2vec_words = set(word2vec_model.index_to_key)
    dataset_words = set(word_counts.keys())

    print(f"Number of words in the word2vec model: {len(word2vec_words)}")
    print(f"Number of words in the dataset: {len(dataset_words)}")

    common_words = word2vec_words.intersection(dataset_words)
    percentage = len(common_words) / len(dataset_words) * 100

    print(
        f"Number of words in the dataset that are in the word2vec model: {len(common_words)} ({percentage:.2f}%)"
    )

    return common_words

In [44]:
_ = count_words_in_word2vec_model(df_train, word2vec_model)
_ = count_words_in_word2vec_model(df_dev, word2vec_model)
_ = count_words_in_word2vec_model(df_test, word2vec_model)

Number of words in the word2vec model: 999999
Number of words in the dataset: 22625
Number of words in the dataset that are in the word2vec model: 11185 (49.44%)
Number of words in the word2vec model: 999999
Number of words in the dataset: 12553
Number of words in the dataset that are in the word2vec model: 7447 (59.32%)
Number of words in the word2vec model: 999999
Number of words in the dataset: 7873
Number of words in the dataset that are in the word2vec model: 5038 (63.99%)


In [45]:
print(type(word2vec_model))

<class 'gensim.models.keyedvectors.KeyedVectors'>


# Build vocabulary

In [34]:
import pandas as pd
from tqdm import tqdm


def split_text_into_words(text):
    text = text.replace("\n", " ")
    text = text.replace("\t", " ")
    text = text.replace("\r", " ")

    words = [w for w in text.split(" ") if w != ""]

    return words


class Vocabulary:
    unknown_token = "<UNK>"
    padding_token = "<PAD>"

    unknown_token_idx = 0
    padding_token_idx = 1

    def __init__(self):
        self.start_idx = 2
        self.word2idx = {
            Vocabulary.unknown_token: Vocabulary.unknown_token_idx,
            Vocabulary.padding_token: Vocabulary.padding_token_idx,
        }
        self.idx2word = {
            Vocabulary.unknown_token_idx: Vocabulary.unknown_token,
            Vocabulary.padding_token_idx: Vocabulary.padding_token,
        }

    def build_vocabulary(self, df: pd.DataFrame):
        idx = self.start_idx
        for text in tqdm(df["text"], desc="Building vocabulary"):
            words = split_text_into_words(text)

            for word in words:
                if word in self.word2idx:
                    continue

                # word is not in word2idx
                self.word2idx[word] = idx
                self.idx2word[idx] = word

                idx += 1

In [35]:
vocabulary = Vocabulary()
vocabulary.build_vocabulary(df_train)

Building vocabulary: 100%|██████████| 2919/2919 [00:00<00:00, 34085.70it/s]


In [7]:
print(f"Vocabulary size: {len(vocabulary.word2idx)}")

Vocabulary size: 22627


In [17]:
def get_unknown_words(vocab, df):
    all_words = set()
    unknown_words = set()
    for text in tqdm(df["text"], desc="Counting unknown words"):
        words = split_text_into_words(text)

        for word in words:
            all_words.add(word)
            if word not in vocab.word2idx:
                unknown_words.add(word)

    return all_words, unknown_words

In [18]:
dev_all_words, dev_unknown_words = get_unknown_words(vocabulary, df_dev)
print(f"Number of unknown words in dev: {len(dev_unknown_words)}/{len(dev_all_words)} ({len(dev_unknown_words) / len(dev_all_words) * 100:.2f}%)")

test_all_words, test_unknown_words = get_unknown_words(vocabulary, df_test)
print(f"Number of unknown words in test: {len(test_unknown_words)}/{len(test_all_words)} ({len(test_unknown_words) / len(test_all_words) * 100:.2f}%)")

Counting unknown words: 100%|██████████| 730/730 [00:00<00:00, 23214.41it/s]


Number of unknown words in dev: 2206/12553 (17.57%)


Counting unknown words: 100%|██████████| 505/505 [00:00<00:00, 33063.13it/s]

Number of unknown words in test: 2229/7873 (28.31%)





# Build the dataset

In [36]:
import numpy as np
from torch.utils.data import Dataset


class TokenClassificationDataset(Dataset):
    def __init__(
        self,
        ids: np.ndarray,
        texts: np.ndarray,
        targets: np.ndarray | None,
        vocabulary: Vocabulary,
        max_len: int,
        debug: bool = False,
    ):
        super().__init__()

        self.ids = ids
        self.texts = texts
        self.targets = targets
        self.vocab = vocabulary
        self.max_len = max_len
        self.debug = debug

    def __len__(self):
        return len(self.ids)
    
    def __getitem__(self, index):
        item_id = self.ids[index]
        text = self.texts[index]
        target = -1 if self.targets is None else self.targets[index]
        targets_available = False if target == -1 else True

        words = split_text_into_words(text)

        if self.debug:
            print(f"Text: {text}")
            print(f"Words: {words}")
            print(f"Machine text start position: {target}")
            print()

        targets = []
        corresponding_word = []
        tokens = []
        input_ids = []
        attention_mask = []

        for idx, word in enumerate(words):
            # word_encoded = self.tokenizer.tokenize(word)  # No [CLS] or [SEP]
            # sub_words = len(word_encoded)

            if targets_available:
                is_machine_text = 1 if idx >= target else 0
                # targets.extend([is_machine_text] * sub_words)
                targets.append(is_machine_text)

            corresponding_word.append(idx)
            tokens.append(word)
            input_ids.append(
                self.vocab.word2idx.get(word, self.vocab.unknown_token_idx)
            )
            attention_mask.append(1)

            if self.debug:
                print(
                    f"word[{idx}]:\n"
                    f"{'':-<5}> tokens: {[word]}\n"
                    f"{'':-<5}> corresponding_word: {corresponding_word[-1]}\n"
                    f"{'':-<5}> input_ids: {input_ids[-1]}\n"
                    f"{'':-<5}> is_machine_text: {is_machine_text}"
                )

        if self.debug:
            print()

            print(f"corresponding_word: {corresponding_word}")
            print(f"tokens: {tokens}")
            print(f"input_ids: {input_ids}")
            print(f"attention_mask: {attention_mask}")

            print()

            print(f"Machine text start word: {words[targets.index(1)]}")
            print(f"True machine text start word: {words[target]}")

            print()

        if len(input_ids) < self.max_len:
            if targets_available:
                targets = (
                    targets
                    + [-100] * (self.max_len - len(input_ids))
                )

            corresponding_word = (
                # [-100]
                corresponding_word
                + [-100] * (self.max_len - len(input_ids))
            )
            tokens = (
                # [self.tokenizer.bos_token]
                tokens
                # + [self.tokenizer.eos_token]
                + [self.vocab.padding_token] * (self.max_len - len(tokens))
            )
            input_ids = (
                # [self.tokenizer.bos_token_id]
                input_ids
                # + [self.tokenizer.eos_token_id]
                + [self.vocab.padding_token_idx] * (self.max_len - len(input_ids))
            )
            attention_mask = (
                # [1]
                attention_mask
                + [0] * (self.max_len - len(attention_mask))
            )
        else:
            if targets_available:
                targets = targets[: self.max_len]

            corresponding_word = corresponding_word[: self.max_len]
            # corresponding_word = (
            #     [-100]
            #     + corresponding_word[: self.max_len - 2]
            #     + [-100]
            # )
            tokens = tokens[: self.max_len]
            # tokens = (
            #     [self.tokenizer.bos_token]
            #     + tokens[: self.max_len - 2]
            #     + [self.tokenizer.eos_token]
            # )
            input_ids = input_ids[: self.max_len]
            # input_ids = (
            #     [self.tokenizer.bos_token_id]
            #     + input_ids[: self.max_len - 2]
            #     + [self.tokenizer.eos_token_id]
            # )
            attention_mask = attention_mask[: self.max_len]
            # attention_mask = (
            #     [1]
            #     + attention_mask[: self.max_len - 2]
            #     + [1]
            # )

        encoded = {}
        encoded["id"] = item_id
        encoded["text"] = text
        encoded["true_target"] = torch.tensor(target)
        encoded["corresponding_word"] = torch.tensor(corresponding_word)
        encoded["input_ids"] = torch.tensor(input_ids)
        encoded["attention_mask"] = torch.tensor(attention_mask)
        if targets_available:
            encoded["target"] = torch.tensor(targets)

        if self.debug:
            print(f"Tokenized human position: {targets.index(1)}")
            print(f"Original human position: {target}")
            print(f"Full human text: {text}\n\n")
            print(f"Human truncated text: {[w for w in text.split(' ')[:target] if w != '']}\n\n")

            encoded["partial_human_review"] = " ".join(
                [w for w in text.split(' ')[:target] if w != '']
            )

        return encoded

In [37]:
from torch.utils.data import DataLoader

# tokenizer = get_tokenizer(**config["tokenizer"])

train_dataset = TokenClassificationDataset(
    ids=df_train["id"].values,
    texts=df_train["text"].values,
    targets=df_train["label"].values,
    vocabulary=vocabulary,
    max_len=config["data"]["max_len"],
    debug=False,
)
dev_dataset = TokenClassificationDataset(
    ids=df_dev["id"].values,
    texts=df_dev["text"].values,
    targets=df_dev["label"].values,
    vocabulary=vocabulary,
    max_len=config["data"]["max_len"],
    debug=False,
)
test_dataset = TokenClassificationDataset(
    ids=df_test["id"].values,
    texts=df_test["text"].values,
    targets=df_test["label"].values,
    vocabulary=vocabulary,
    max_len=config["data"]["max_len"],
    debug=False,
)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=config["data"]["batch_size"],
    shuffle=True,
)
dev_dataloader = DataLoader(
    dev_dataset,
    batch_size=config["data"]["batch_size"],
    shuffle=False,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=config["data"]["batch_size"],
    shuffle=False,
)

In [21]:
vocabulary.idx2word[3]

'paper'

In [10]:
# for i, batch in enumerate(train_dataloader):
#     print(f"Batch=[{i + 1}/{len(train_dataloader)}]")
#     print(f"batch['input_ids'].shape: {batch['input_ids'].shape}")
#     print(f"batch['attention_mask'].shape: {batch['attention_mask'].shape}")
#     print(f"batch['target'].shape: {batch['target'].shape}")
#     print(f"batch['target']: {batch['target']}")
#     print(f"batch['corresponding_word']: {batch['corresponding_word']}")
#     break

# for i, batch in enumerate(dev_dataloader):
#     print(f"Batch=[{i + 1}/{len(dev_dataloader)}]")
# #     # break

In [11]:
# vocabulary

# Create BiLSTM model for token classification

In [38]:
import torch.nn as nn
from transformers import LongformerModel


class BiLSTMForTokenClassification(nn.Module):
    def __init__(
        self,
        vocab_size,
        embedding_dim,
        out_size,
        device,
        dropout_p=0.3,
        n_layers=1,
        hidden_dim=32,
        fc=[],
    ):
        super().__init__()

        self.out_size = out_size
        self.device = device
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim

        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
        )

        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            n_layers,
            bidirectional=True,
            batch_first=True,
        )

        self.dropout = nn.Dropout(p=dropout_p)
        self.classifier = nn.Linear(2 * hidden_dim, out_size)
        # self.classifier = sequential_fully_connected(
        #     2 * hidden_dim, out_size, fc, dropout_p
        # )

    def forward(self, input_ids, attention_mask, labels=None):
        embeddings = self.embedding(input_ids)
        # print(f"embeddings.shape: {embeddings.shape}")

        lengths = attention_mask.sum(dim=1)
        # print(f"lengths.shape: {lengths.shape}")

        packed_embeddings = nn.utils.rnn.pack_padded_sequence(
            embeddings, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        # print(f"packed_embeddings.data.shape: {packed_embeddings.data.shape}")

        packed_output, (_, _) = self.lstm(packed_embeddings)
        # print(f"packed_output.data.shape: {packed_output.data.shape}")

        output, _ = nn.utils.rnn.pad_packed_sequence(
            packed_output, batch_first=True, total_length=embeddings.shape[1],
        )
        # print(f"output.shape: {output.shape}")

        output = self.dropout(output)
        logits = self.classifier(output)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss().to(self.device)
            loss = loss_fn(logits.view(-1, self.out_size), labels.view(-1))

        return loss, logits

    def freeze_transformer_layer(self):
        pass

    def unfreeze_transformer_layer(self):
        pass

    def get_predictions_from_logits(self, logits, labels=None, corresponding_word=None):
        # batch_size = logits.shape[0]

        # logits: (batch_size, max_seq_len, out_size)
        # labels: (batch_size, max_seq_len)
        # corresponding_word: (batch_size, max_seq_len)

        # print(f"logits.shape: {logits.shape}")
        # print(f"logits: {logits}")

        # preds: (batch_size, max_seq_len)
        preds = torch.argmax(logits, dim=-1)

        # print(f"preds.shape: {preds.shape}")
        # print(f"preds: {preds}")

        if labels is not None:
            # print(f"labels.shape: {labels.shape}")
            # print(f"labels: {labels}")

            # Keep only predictions where labels are not -100
            # clean_preds = preds[labels != -100].reshape(batch_size, -1)
            # clean_labels = labels[labels != -100].reshape(batch_size, -1)

            # print(f"clean_preds.shape: {clean_preds.shape}")
            # print(f"clean_preds: {clean_preds}")

            # print(f"clean_labels.shape: {clean_labels.shape}")
            # print(f"clean_labels: {clean_labels}")

            # Get the index of the first machine text word
            # predicted_positions = clean_preds.argmax(dim=-1)
            # true_positions = clean_labels.argmax(dim=-1)

            predicted_positions = []
            true_positions = []
            for p, l in zip(preds, labels):
                mask = l != -100

                clean_pred = p[mask]
                clean_label = l[mask]

                # print(f"clean_pred.shape: {clean_pred.shape}")
                # print(f"clean_pred: {clean_pred}")
                # print(f"clean_label.shape: {clean_label.shape}")
                # print(f"clean_label: {clean_label}")

                predicted_position = clean_pred.argmax(dim=-1)
                true_position = clean_label.argmax(dim=-1)

                # print(f"predicted_position: {predicted_position}")
                # print(f"true_position: {true_position}")

                predicted_positions.append(predicted_position.item())
                true_positions.append(true_position.item())

            # print(f"predicted_positions.shape: {predicted_positions.shape}")
            # print(f"predicted_positions: {predicted_positions}")

            # print(f"true_positions.shape: {true_positions.shape}")
            # print(f"true_positions: {true_positions}")

            # print(f"predicted_positions type: {type(predicted_positions)}")
            # print(f"true_positions type: {type(true_positions)}")

            return torch.Tensor(predicted_positions), torch.Tensor(true_positions)
        elif corresponding_word is not None:
            # print(f"corresponding_word.shape: {corresponding_word.shape}")
            # print(f"corresponding_word: {corresponding_word}")

            # Keep only predictions where corresponding_word are not -100
            # clean_preds = preds[corresponding_word != -100].reshape(
            #     batch_size, -1
            # ).detach().cpu().numpy()
            # clean_corresponding_word = corresponding_word[corresponding_word != -100].reshape(
            #     batch_size, -1
            # ).detach().cpu().numpy()

            # print(f"clean_preds.shape: {clean_preds.shape}")
            # print(f"clean_preds: {clean_preds}")

            # print(f"clean_corresponding_word.shape: {clean_corresponding_word.shape}")
            # print(f"clean_corresponding_word: {clean_corresponding_word}")

            predicted_positions = []
            for p, w in zip(preds, corresponding_word):
                mask = w != -100

                clean_pred = p[mask]
                clean_corresponding_word = w[mask]

                # print(f"clean_pred.shape: {clean_pred.shape}")
                # print(f"clean_pred: {clean_pred}")
                # print(f"clean_corresponding_word.shape: {clean_corresponding_word.shape}")
                # print(f"clean_corresponding_word: {clean_corresponding_word}")

                # Get the index of the first machine text word
                index = torch.where(clean_pred == 1)[0]
                value = index[0] if index.size else len(clean_pred) - 1
                position = clean_corresponding_word[value]

                # print(f"index: {index}")
                # print(f"value: {value}")
                # print(f"position: {position}")

                predicted_positions.append(position.item())
            #     # pred = pred.detach().cpu().numpy()

            #     index = np.where(pred == 1)[0]
            #     value = index[0] if index.size else len(pred) - 1
            #     position = clean_corresponding_word[idx][value]

            #     predicted_positions.append(position.item())

            # print(f"predicted_positions: {predicted_positions}")

            return predicted_positions, None
        else:
            raise ValueError("Either labels or corresponding_word must be provided")

# Train model

In [39]:
import pandas as pd
from tqdm import tqdm
# from time import time
from collections import defaultdict


def train_epoch(
    model,
    dataloader,
    loss_fn,
    optimizer,
    device,
    scheduler,
    metric_fn,
    print_freq=10,
):
    model.train()

    losses = []

    all_predictions = []
    all_true = []
    all_ids = []

    for i, batch in enumerate(dataloader):
        ids = batch["id"]
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        targets = batch["target"].to(device)
        corresponding_word = batch["corresponding_word"].to(device)

        loss, logits = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=targets,
        )

        predictions, true_predictions = model.get_predictions_from_logits(
            logits=logits,
            labels=targets,
            corresponding_word=corresponding_word
        )

        # print(f"predictions: {predictions}")
        # print(f"true_predictions: {true_predictions}")

        losses.append(loss.item())

        all_predictions.extend(predictions.tolist())
        all_true.extend(true_predictions.tolist())
        all_ids.extend(ids)

        if i % print_freq == 0:
            print(
                f"Batch [{i + 1}/{len(dataloader)}]; "
                f"Loss: {loss.item():.5f}; "
                f"Mean absolute error: {metric_fn(true_predictions, predictions):.5f}"
            )

        loss.backward()

        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        optimizer.zero_grad()

    return np.mean(losses), (all_ids, all_true, all_predictions)


def validation_epoch(
    model,
    dataloader,
    loss_fn,
    device,
    metric_fn,
):
    model.eval()

    losses = []
    all_predictions = []
    all_true = []
    all_ids = []

    with torch.no_grad():
        for i, batch in enumerate(tqdm(dataloader)):
            ids = batch["id"]
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            targets = batch["target"].to(device)
            corresponding_word = batch["corresponding_word"].to(device)

            loss, logits = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=targets,
            )

            predictions, true_predictions = model.get_predictions_from_logits(
                logits=logits,
                labels=targets,
                corresponding_word=corresponding_word
            )

            losses.append(loss.item())

            all_predictions.extend(predictions.tolist())
            all_true.extend(true_predictions.tolist())
            all_ids.extend(ids)

    return np.mean(losses), (all_ids, all_true, all_predictions)


def training_loop(
    model,
    num_epochs,
    train_dataloader,
    dev_dataloader,
    loss_fn,
    optimizer_config,
    scheduler_config,
    device,
    metric_fn,
    is_better_metric_fn,
    num_epochs_before_finetune,
    results_dir,
):
    history = defaultdict(list)
    best_metric = None
    best_model_state = None

    optimizer = get_optimizer(model, optimizer_config, finetune=False)
    scheduler = None

    for epoch in range(1, num_epochs + 1):
        print(f"Epoch {epoch}/{num_epochs}")
        if epoch <= num_epochs_before_finetune:
            print("Freeze transformer")
        else:
            print("Finetune transformer")
        print("-" * 10)

        if epoch == num_epochs_before_finetune + 1:
            model.unfreeze_transformer_layer()
            optimizer = get_optimizer(model, optimizer_config, finetune=True)
            scheduler = get_scheduler(
                optimizer,
                num_training_steps=len(train_dataloader) * num_epochs,
                **scheduler_config,
            )

        train_loss, (train_ids, train_true, train_predict) = train_epoch(
            model,
            train_dataloader,
            loss_fn,
            optimizer,
            device,
            scheduler,
            metric_fn,
        )

        train_metric = metric_fn(train_true, train_predict)

        print(f"Train Loss: {train_loss:.5f}; Train Metric: {train_metric:.5f}")

        dev_loss, (dev_ids, dev_true, dev_predict) = validation_epoch(
            model,
            dev_dataloader,
            loss_fn,
            device,
            metric_fn,
        )

        dev_metric = metric_fn(dev_true, dev_predict)

        print(
            f"Validation Loss: {dev_loss:.5f}; "
            f"Validation Metric: {dev_metric:.5f}"
        )

        history["train_metric"].append(train_metric)
        history["train_loss"].append(train_loss)
        history["dev_metric"].append(dev_metric)
        history["dev_loss"].append(dev_loss)

        if best_metric is None or is_better_metric_fn(train_metric, best_metric):
            best_metric = train_metric
            best_model_state = model.state_dict()
            
            if results_dir is not None:
                torch.save(
                    best_model_state,
                    os.path.join(results_dir, "best_model.bin"),
                )

                df_train_predictions = pd.DataFrame(
                    {
                        "id": train_ids,
                        "true": train_true,
                        "predict": train_predict,
                    }
                )
                df_train_predictions.to_csv(
                    os.path.join(results_dir, "best_model_train_predict.csv"),
                    index=False
                )

                df_dev_predictions = pd.DataFrame(
                    {
                        "id": dev_ids,
                        "true": dev_true,
                        "predict": dev_predict,
                    }
                )
                df_dev_predictions.to_csv(
                    os.path.join(results_dir, "best_model_dev_predict.csv"),
                    index=False
                )

    df_history = pd.DataFrame(history)
    if results_dir is not None:
        df_history.to_csv(os.path.join(results_dir, "history.csv"), index=False)

        model.load_state_dict(torch.load(os.path.join(results_dir, "best_model.bin")))
    else:
        model.load_state_dict(best_model_state)

    return model, df_history

In [40]:
import torch

torch.cuda.empty_cache()

In [41]:
config["model_config"]["vocab_size"] = len(vocabulary.word2idx)

In [42]:
with open(results_dir + "/config.json", "w") as f:
    json.dump(config, f, indent=4)

In [43]:
num_epochs = config["training"]["num_epochs"]
model = BiLSTMForTokenClassification(
    device=DEVICE, **config["model_config"]
).to(DEVICE)
loss_fn = get_loss_fn(config["training"]["loss"], DEVICE)
optimizer_config = config["training"]["optimizer"]
scheduler_config = config["training"]["scheduler"]
metric_fn, is_better_metric_fn = get_metric(config["training"]["metric"])
num_epochs_before_finetune = config["training"]["num_epochs_before_finetune"]

best_model, df_history = training_loop(
    model,
    num_epochs,
    train_dataloader,
    dev_dataloader,
    loss_fn,
    optimizer_config,
    scheduler_config,
    DEVICE,
    metric_fn,
    is_better_metric_fn,
    num_epochs_before_finetune,
    results_dir,
)

Epoch 1/10
Freeze transformer
----------
Batch [1/183]; Loss: 0.15373; Mean absolute error: 59.81250
Batch [11/183]; Loss: 0.16196; Mean absolute error: 71.25000
Batch [21/183]; Loss: 0.15641; Mean absolute error: 88.68750
Batch [31/183]; Loss: 0.14351; Mean absolute error: 77.31250
Batch [41/183]; Loss: 0.13299; Mean absolute error: 76.37500
Batch [51/183]; Loss: 0.10646; Mean absolute error: 55.68750
Batch [61/183]; Loss: 0.10958; Mean absolute error: 58.25000
Batch [71/183]; Loss: 0.14069; Mean absolute error: 88.68750
Batch [81/183]; Loss: 0.09015; Mean absolute error: 52.50000
Batch [91/183]; Loss: 0.21523; Mean absolute error: 125.93750
Batch [101/183]; Loss: 0.07662; Mean absolute error: 55.50000
Batch [111/183]; Loss: 0.07462; Mean absolute error: 44.56250
Batch [121/183]; Loss: 0.06212; Mean absolute error: 40.12500
Batch [131/183]; Loss: 0.07933; Mean absolute error: 63.81250
Batch [141/183]; Loss: 0.09035; Mean absolute error: 58.25000
Batch [151/183]; Loss: 0.07449; Mean ab

100%|██████████| 46/46 [00:47<00:00,  1.02s/it]


Validation Loss: 0.06681; Validation Metric: 44.52055
Epoch 2/10
Freeze transformer
----------
Batch [1/183]; Loss: 0.04602; Mean absolute error: 36.81250
Batch [11/183]; Loss: 0.04383; Mean absolute error: 51.00000
Batch [21/183]; Loss: 0.06834; Mean absolute error: 27.81250
Batch [31/183]; Loss: 0.04662; Mean absolute error: 28.25000
Batch [41/183]; Loss: 0.04322; Mean absolute error: 34.06250
Batch [51/183]; Loss: 0.07083; Mean absolute error: 45.18750
Batch [61/183]; Loss: 0.05911; Mean absolute error: 53.31250
Batch [71/183]; Loss: 0.07893; Mean absolute error: 59.93750
Batch [81/183]; Loss: 0.03519; Mean absolute error: 38.62500
Batch [91/183]; Loss: 0.04235; Mean absolute error: 47.50000
Batch [101/183]; Loss: 0.04889; Mean absolute error: 22.12500
Batch [111/183]; Loss: 0.05399; Mean absolute error: 24.18750
Batch [121/183]; Loss: 0.02676; Mean absolute error: 40.00000
Batch [131/183]; Loss: 0.02838; Mean absolute error: 23.43750
Batch [141/183]; Loss: 0.04928; Mean absolute er

100%|██████████| 46/46 [00:43<00:00,  1.05it/s]


Validation Loss: 0.05163; Validation Metric: 20.59452
Epoch 3/10
Freeze transformer
----------
Batch [1/183]; Loss: 0.04366; Mean absolute error: 19.12500
Batch [11/183]; Loss: 0.05490; Mean absolute error: 18.37500
Batch [21/183]; Loss: 0.03935; Mean absolute error: 20.68750
Batch [31/183]; Loss: 0.02577; Mean absolute error: 12.87500
Batch [41/183]; Loss: 0.05576; Mean absolute error: 20.62500
Batch [51/183]; Loss: 0.02477; Mean absolute error: 10.62500
Batch [61/183]; Loss: 0.03898; Mean absolute error: 25.37500
Batch [71/183]; Loss: 0.03936; Mean absolute error: 25.62500
Batch [81/183]; Loss: 0.04097; Mean absolute error: 14.75000
Batch [91/183]; Loss: 0.03692; Mean absolute error: 20.50000
Batch [101/183]; Loss: 0.02481; Mean absolute error: 6.68750
Batch [111/183]; Loss: 0.04031; Mean absolute error: 35.50000
Batch [121/183]; Loss: 0.02166; Mean absolute error: 12.37500
Batch [131/183]; Loss: 0.04032; Mean absolute error: 15.31250
Batch [141/183]; Loss: 0.03923; Mean absolute err

100%|██████████| 46/46 [00:44<00:00,  1.04it/s]


Validation Loss: 0.04963; Validation Metric: 18.84384
Epoch 4/10
Freeze transformer
----------
Batch [1/183]; Loss: 0.01945; Mean absolute error: 6.31250
Batch [11/183]; Loss: 0.06987; Mean absolute error: 24.50000
Batch [21/183]; Loss: 0.02924; Mean absolute error: 11.56250
Batch [31/183]; Loss: 0.02093; Mean absolute error: 18.25000
Batch [41/183]; Loss: 0.05690; Mean absolute error: 19.31250
Batch [51/183]; Loss: 0.01946; Mean absolute error: 12.50000
Batch [61/183]; Loss: 0.01902; Mean absolute error: 7.81250
Batch [71/183]; Loss: 0.03602; Mean absolute error: 23.00000
Batch [81/183]; Loss: 0.02870; Mean absolute error: 20.50000
Batch [91/183]; Loss: 0.03629; Mean absolute error: 12.25000
Batch [101/183]; Loss: 0.03286; Mean absolute error: 14.56250
Batch [111/183]; Loss: 0.02758; Mean absolute error: 9.31250
Batch [121/183]; Loss: 0.03881; Mean absolute error: 35.25000
Batch [131/183]; Loss: 0.02054; Mean absolute error: 9.75000
Batch [141/183]; Loss: 0.01888; Mean absolute error:

100%|██████████| 46/46 [00:45<00:00,  1.02it/s]


Validation Loss: 0.04624; Validation Metric: 17.76712
Epoch 5/10
Freeze transformer
----------
Batch [1/183]; Loss: 0.02916; Mean absolute error: 13.75000
Batch [11/183]; Loss: 0.03356; Mean absolute error: 31.43750
Batch [21/183]; Loss: 0.01953; Mean absolute error: 15.43750
Batch [31/183]; Loss: 0.01619; Mean absolute error: 6.87500
Batch [41/183]; Loss: 0.02369; Mean absolute error: 8.62500
Batch [51/183]; Loss: 0.02932; Mean absolute error: 19.93750
Batch [61/183]; Loss: 0.04251; Mean absolute error: 27.31250
Batch [71/183]; Loss: 0.01511; Mean absolute error: 7.62500
Batch [81/183]; Loss: 0.02230; Mean absolute error: 13.37500
Batch [91/183]; Loss: 0.01751; Mean absolute error: 11.37500
Batch [101/183]; Loss: 0.01205; Mean absolute error: 8.87500
Batch [111/183]; Loss: 0.01676; Mean absolute error: 5.50000
Batch [121/183]; Loss: 0.02062; Mean absolute error: 10.62500
Batch [131/183]; Loss: 0.02553; Mean absolute error: 17.68750
Batch [141/183]; Loss: 0.01728; Mean absolute error: 

100%|██████████| 46/46 [00:45<00:00,  1.00it/s]


Validation Loss: 0.04960; Validation Metric: 16.84658
Epoch 6/10
Finetune transformer
----------
Batch [1/183]; Loss: 0.02572; Mean absolute error: 8.37500
Batch [11/183]; Loss: 0.02046; Mean absolute error: 7.75000
Batch [21/183]; Loss: 0.01542; Mean absolute error: 5.00000
Batch [31/183]; Loss: 0.02257; Mean absolute error: 16.37500
Batch [41/183]; Loss: 0.02323; Mean absolute error: 8.87500
Batch [51/183]; Loss: 0.02098; Mean absolute error: 9.43750
Batch [61/183]; Loss: 0.02762; Mean absolute error: 13.87500
Batch [71/183]; Loss: 0.01383; Mean absolute error: 23.87500
Batch [81/183]; Loss: 0.01536; Mean absolute error: 8.50000
Batch [91/183]; Loss: 0.01899; Mean absolute error: 7.68750
Batch [101/183]; Loss: 0.02547; Mean absolute error: 11.50000
Batch [111/183]; Loss: 0.01708; Mean absolute error: 8.00000
Batch [121/183]; Loss: 0.01922; Mean absolute error: 7.56250
Batch [131/183]; Loss: 0.02196; Mean absolute error: 8.06250
Batch [141/183]; Loss: 0.03595; Mean absolute error: 9.1

100%|██████████| 46/46 [00:46<00:00,  1.01s/it]


Validation Loss: 0.04662; Validation Metric: 17.16986
Epoch 7/10
Finetune transformer
----------
Batch [1/183]; Loss: 0.01447; Mean absolute error: 7.50000
Batch [11/183]; Loss: 0.01448; Mean absolute error: 6.37500
Batch [21/183]; Loss: 0.02111; Mean absolute error: 8.18750
Batch [31/183]; Loss: 0.01382; Mean absolute error: 5.56250
Batch [41/183]; Loss: 0.01419; Mean absolute error: 5.37500
Batch [51/183]; Loss: 0.01474; Mean absolute error: 8.06250
Batch [61/183]; Loss: 0.01334; Mean absolute error: 5.25000
Batch [71/183]; Loss: 0.01809; Mean absolute error: 6.12500
Batch [81/183]; Loss: 0.02191; Mean absolute error: 12.06250
Batch [91/183]; Loss: 0.01524; Mean absolute error: 7.62500
Batch [101/183]; Loss: 0.01152; Mean absolute error: 5.00000
Batch [111/183]; Loss: 0.08918; Mean absolute error: 11.06250
Batch [121/183]; Loss: 0.01108; Mean absolute error: 5.50000
Batch [131/183]; Loss: 0.01181; Mean absolute error: 5.12500
Batch [141/183]; Loss: 0.01287; Mean absolute error: 5.875

100%|██████████| 46/46 [00:47<00:00,  1.03s/it]


Validation Loss: 0.04602; Validation Metric: 17.62877
Epoch 8/10
Finetune transformer
----------
Batch [1/183]; Loss: 0.01330; Mean absolute error: 7.37500
Batch [11/183]; Loss: 0.01692; Mean absolute error: 11.81250
Batch [21/183]; Loss: 0.01514; Mean absolute error: 8.31250
Batch [31/183]; Loss: 0.02421; Mean absolute error: 11.18750
Batch [41/183]; Loss: 0.01799; Mean absolute error: 7.06250
Batch [51/183]; Loss: 0.02732; Mean absolute error: 8.75000
Batch [61/183]; Loss: 0.01556; Mean absolute error: 8.81250
Batch [71/183]; Loss: 0.01025; Mean absolute error: 3.93750
Batch [81/183]; Loss: 0.01797; Mean absolute error: 15.31250
Batch [91/183]; Loss: 0.01838; Mean absolute error: 8.68750
Batch [101/183]; Loss: 0.01185; Mean absolute error: 7.12500
Batch [111/183]; Loss: 0.01189; Mean absolute error: 3.87500
Batch [121/183]; Loss: 0.01679; Mean absolute error: 6.75000
Batch [131/183]; Loss: 0.01715; Mean absolute error: 7.43750
Batch [141/183]; Loss: 0.01366; Mean absolute error: 8.25

100%|██████████| 46/46 [00:50<00:00,  1.09s/it]


Validation Loss: 0.04674; Validation Metric: 16.91096
Epoch 9/10
Finetune transformer
----------
Batch [1/183]; Loss: 0.02742; Mean absolute error: 8.06250
Batch [11/183]; Loss: 0.01666; Mean absolute error: 6.12500
Batch [21/183]; Loss: 0.01392; Mean absolute error: 7.12500
Batch [31/183]; Loss: 0.01057; Mean absolute error: 4.12500
Batch [41/183]; Loss: 0.01642; Mean absolute error: 6.43750
Batch [51/183]; Loss: 0.01552; Mean absolute error: 9.43750
Batch [61/183]; Loss: 0.01986; Mean absolute error: 9.50000
Batch [71/183]; Loss: 0.01220; Mean absolute error: 5.18750
Batch [81/183]; Loss: 0.01526; Mean absolute error: 8.68750
Batch [91/183]; Loss: 0.01278; Mean absolute error: 14.62500
Batch [101/183]; Loss: 0.01462; Mean absolute error: 7.00000
Batch [111/183]; Loss: 0.01602; Mean absolute error: 9.87500
Batch [121/183]; Loss: 0.01397; Mean absolute error: 4.81250
Batch [131/183]; Loss: 0.01096; Mean absolute error: 10.93750
Batch [141/183]; Loss: 0.01392; Mean absolute error: 5.937

100%|██████████| 46/46 [00:49<00:00,  1.07s/it]


Validation Loss: 0.04578; Validation Metric: 16.78493
Epoch 10/10
Finetune transformer
----------
Batch [1/183]; Loss: 0.02704; Mean absolute error: 24.31250
Batch [11/183]; Loss: 0.01934; Mean absolute error: 8.31250
Batch [21/183]; Loss: 0.01703; Mean absolute error: 14.31250
Batch [31/183]; Loss: 0.01351; Mean absolute error: 6.06250
Batch [41/183]; Loss: 0.01481; Mean absolute error: 9.18750
Batch [51/183]; Loss: 0.01268; Mean absolute error: 6.31250
Batch [61/183]; Loss: 0.02178; Mean absolute error: 7.31250
Batch [71/183]; Loss: 0.01337; Mean absolute error: 6.43750
Batch [81/183]; Loss: 0.01670; Mean absolute error: 5.50000
Batch [91/183]; Loss: 0.01604; Mean absolute error: 12.25000
Batch [101/183]; Loss: 0.01810; Mean absolute error: 7.00000
Batch [111/183]; Loss: 0.01416; Mean absolute error: 18.93750
Batch [121/183]; Loss: 0.01426; Mean absolute error: 9.75000
Batch [131/183]; Loss: 0.01816; Mean absolute error: 5.62500
Batch [141/183]; Loss: 0.02182; Mean absolute error: 6.

100%|██████████| 46/46 [00:57<00:00,  1.26s/it]

Validation Loss: 0.04623; Validation Metric: 16.70274





# Make predictions

In [44]:
import pandas as pd


def make_predictions(
    model,
    dataloader,
    device,
    results_dir,
    label_column,
    file_format="csv",
):
    model.eval()

    all_predictions = []
    all_true = []
    all_ids = []

    with torch.no_grad():
        for i, batch in enumerate(tqdm(dataloader)):
            ids = batch["id"]
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            targets = batch["target"].to(device)
            corresponding_word = batch["corresponding_word"].to(device)

            _, logits = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=targets,
            )

            predictions, true_predictions = model.get_predictions_from_logits(
                logits=logits,
                labels=targets,
                corresponding_word=corresponding_word
            )

            all_predictions.extend(predictions.tolist())
            all_true.extend(true_predictions.tolist())
            all_ids.extend(ids)

    df_predictions = pd.DataFrame(
        {
            "id": all_ids,
            "true": all_true,
            label_column: all_predictions,
        }
    )

    if results_dir is not None:
        if file_format == "csv":
            df_predictions.to_csv(
                os.path.join(results_dir, "submission.csv"),
                index=False,
            )
        elif file_format == "jsonl":
            df_predictions.to_json(
                os.path.join(results_dir, "submission.jsonl"),
                orient="records",
                lines=True,
            )
        else:
            raise ValueError(f"Unknown file format: {file_format}")
    else:
        print("Missing results_dir, not saving predictions to file!")

    return df_predictions

In [45]:
predictions = make_predictions(
    best_model,
    test_dataloader,
    DEVICE,
    results_dir,
    config["data"]["label_column"],
    file_format="csv",
)

100%|██████████| 32/32 [00:41<00:00,  1.30s/it]


In [2]:
!python ../scores_and_plots.py --results-dir "../runs/30-12-2023_20:25:57-SubtaskC-bilstm_for_token_classification"

Results on validation
MAE: 16.70274
--------------------
Results on test
MAE: 18.14851
--------------------
