In [1]:
import os
import sys
import json
import torch

sys.path.append("../")
from lib.utils import get_device, get_current_date
from lib.utils.constants import Subtask, Track, PreprocessTextLevel, DatasetType
from lib.utils.models import sequential_fully_connected
from lib.data.loading import load_train_dev_test_df
from lib.data.tokenizer import get_tokenizer
from lib.training.optimizer import get_optimizer, get_scheduler
from lib.training.loss import get_loss_fn
from lib.training.metric import get_metric

  warn(


In [2]:
CONFIG_FILE_PATH = os.path.relpath("../config.json")

config = {}
with open(CONFIG_FILE_PATH, "r") as config_file:
    config = json.load(config_file)

DEVICE = get_device()
print(f"Using device: {DEVICE}")

Using device: cuda


In [3]:
# config

In [4]:
task = None
if "task" in config:
    task = Subtask(config["task"])
else:
    raise ValueError("Task not specified in config")

track = None
if "track" in config:
    track = Track(config["track"])
else:
    print(f"Warning: Track not specified in config for subtask: {task}")

dataset_type = DatasetType.TransformerTruncationDataset
if "dataset_type" in config["data"]:
    dataset_type = DatasetType(config["data"]["dataset_type"])

dataset_type_settings = None
if "dataset_type_settings" in config["data"]:
    dataset_type_settings = config["data"]["dataset_type_settings"]

df_train, df_dev, df_test = load_train_dev_test_df(
    task=task,
    track=track,
    data_dir=f"../{config['data']['data_dir']}",
    label_column=config["data"]["label_column"],
    test_size=config["data"]["test_size"],
    preprocess_text_level=PreprocessTextLevel(
        config["data"]["preprocess_text_level"]
    ),
)

print(f"df_train.shape: {df_train.shape}")
print(f"df_dev.shape: {df_dev.shape}")
print(f"df_test.shape: {df_test.shape}")

Loading train data...
Train/dev split... (df_train.shape: (3649, 3))
Loading test data....././data/original_data/SubtaskC/SubtaskC_dev.jsonl
df_train.shape: (2919, 3)
df_dev.shape: (730, 3)
df_test.shape: (505, 3)


In [5]:
results_dir = os.path.relpath(
    f"../runs/{get_current_date()}-{task.value}-{config['model']}"
)
if not os.path.exists(results_dir):
    os.makedirs(results_dir)

print(f"Will save results to: {results_dir}")

with open(results_dir + "/config.json", "w") as f:
    json.dump(config, f, indent=4)

Will save results to: ../runs/29-12-2023_13:41:34-SubtaskC-longformer_bilstm


# Build the dataset

In [6]:
import numpy as np
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer


class TokenClassificationDataset(Dataset):
    def __init__(
        self,
        ids: np.ndarray,
        texts: np.ndarray,
        targets: np.ndarray | None,
        tokenizer: PreTrainedTokenizer | None,
        max_len: int,
        debug: bool = False,
    ):
        super().__init__()

        if tokenizer is None:
            raise ValueError("Tokenizer cannot be None")

        self.ids = ids
        self.texts = texts
        self.targets = targets
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.debug = debug

    def __len__(self):
        return len(self.ids)
    
    def __getitem__(self, index):
        item_id = self.ids[index]
        text = self.texts[index]
        target = -1 if self.targets is None else self.targets[index]
        targets_available = False if target == -1 else True

        text = text.replace("\n", " ")
        text = text.replace("\t", " ")
        text = text.replace("\r", " ")

        words = [w for w in text.split(" ") if w != ""]

        if self.debug:
            print(f"Text: {text}")
            print(f"Words: {words}")
            print(f"Machine text start position: {target}")
            print()

        targets = []
        corresponding_word = []
        tokens = []
        input_ids = []
        attention_mask = []

        for idx, word in enumerate(words):
            word_encoded = self.tokenizer.tokenize(word)  # No [CLS] or [SEP]
            sub_words = len(word_encoded)

            if targets_available:
                is_machine_text = 1 if idx >= target else 0
                targets.extend([is_machine_text] * sub_words)

            corresponding_word.extend([idx] * sub_words)
            tokens.extend(word_encoded)
            input_ids.extend(self.tokenizer.convert_tokens_to_ids(word_encoded))
            attention_mask.extend([1] * sub_words)

            if self.debug:
                print(
                    f"word[{idx}]:\n"
                    f"{'':-<5}> tokens: {word_encoded} (no. of subwords: {sub_words})\n"
                    f"{'':-<5}> corresponding_word: {corresponding_word[-sub_words:]}\n"
                    f"{'':-<5}> input_ids: {input_ids[-sub_words:]}\n"
                    f"{'':-<5}> is_machine_text: {is_machine_text}"
                )

        if self.debug:
            print()

            print(f"corresponding_word: {corresponding_word}")
            print(f"tokens: {tokens}")
            print(f"input_ids: {input_ids}")
            print(f"attention_mask: {attention_mask}")

            print()

            print(f"Machine text start word: {words[corresponding_word[targets.index(1)]]}")
            print(f"True machine text start word: {words[target]}")

            print()

        if len(input_ids) < self.max_len - 2:
            if targets_available:
                targets = (
                    [-100]
                    + targets
                    + [-100] * (self.max_len - len(input_ids) - 1)
                )

            corresponding_word = (
                [-100]
                + corresponding_word
                + [-100] * (self.max_len - len(input_ids) - 1)
            )
            tokens = (
                [self.tokenizer.bos_token]
                + tokens
                + [self.tokenizer.eos_token]
                + [self.tokenizer.pad_token] * (self.max_len - len(tokens) - 2)
            )
            input_ids = (
                [self.tokenizer.bos_token_id]
                + input_ids
                + [self.tokenizer.eos_token_id]
                + [self.tokenizer.pad_token_id] * (self.max_len - len(input_ids) - 2)
            )
            attention_mask = (
                [1]
                + attention_mask
                + [1]
                + [0] * (self.max_len - len(attention_mask) - 2)
            )
        else:
            if targets_available:
                targets = [-100] + targets[: self.max_len - 2] + [-100]

            corresponding_word = (
                [-100]
                + corresponding_word[: self.max_len - 2]
                + [-100]
            )
            tokens = (
                [self.tokenizer.bos_token]
                + tokens[: self.max_len - 2]
                + [self.tokenizer.eos_token]
            )
            input_ids = (
                [self.tokenizer.bos_token_id]
                + input_ids[: self.max_len - 2]
                + [self.tokenizer.eos_token_id]
            )
            attention_mask = (
                [1]
                + attention_mask[: self.max_len - 2]
                + [1]
            )

        encoded = {}
        encoded["id"] = item_id
        encoded["text"] = text
        encoded["true_target"] = torch.tensor(target)
        encoded["corresponding_word"] = torch.tensor(corresponding_word)
        encoded["input_ids"] = torch.tensor(input_ids)
        encoded["attention_mask"] = torch.tensor(attention_mask)
        if targets_available:
            encoded["target"] = torch.tensor(targets)

        if self.debug:
            print(f"Tokenized human position: {targets.index(1)}")
            print(f"Original human position: {target}")
            print(f"Full human text: {text}\n\n")
            print(f"Human truncated text: {[w for w in text.split(' ')[:target] if w != '']}\n\n")

            encoded["partial_human_review"] = " ".join(
                [w for w in text.split(' ')[:target] if w != '']
            )

        return encoded

In [7]:
from torch.utils.data import DataLoader

tokenizer = get_tokenizer(**config["tokenizer"])

train_dataset = TokenClassificationDataset(
    ids=df_train["id"].values,
    texts=df_train["text"].values,
    targets=df_train["label"].values,
    tokenizer=tokenizer,
    max_len=config["data"]["max_len"],
    debug=False,
)
dev_dataset = TokenClassificationDataset(
    ids=df_dev["id"].values,
    texts=df_dev["text"].values,
    targets=df_dev["label"].values,
    tokenizer=tokenizer,
    max_len=config["data"]["max_len"],
    debug=False,
)
test_dataset = TokenClassificationDataset(
    ids=df_test["id"].values,
    texts=df_test["text"].values,
    targets=df_test["label"].values,
    tokenizer=tokenizer,
    max_len=config["data"]["max_len"],
    debug=False,
)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=config["data"]["batch_size"],
    shuffle=True,
)
dev_dataloader = DataLoader(
    dev_dataset,
    batch_size=config["data"]["batch_size"],
    shuffle=False,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=config["data"]["batch_size"],
    shuffle=False,
)

In [7]:
# for i, batch in enumerate(train_dataloader):
#     print(f"Batch=[{i + 1}/{len(train_dataloader)}]")
#     # break

# for i, batch in enumerate(dev_dataloader):
#     print(f"Batch=[{i + 1}/{len(dev_dataloader)}]")
#     # break

# Create LongformerBiLSTM model for token classification

In [8]:
import torch.nn as nn
from transformers import LongformerModel


class LongformerBiLSTMForTokenClassification(nn.Module):
    def __init__(
        self,
        pretrained_model_name,
        out_size,
        device,
        dropout_p=0.3,
        last_layers_emb=4,
        hidden_dim=200,
        fc=[],
        finetune_last_layers_emb=False,
    ):
        super().__init__()

        self.out_size = out_size
        self.device = device
        self.last_layers_emb = last_layers_emb
        self.hidden_dim = hidden_dim
        self.finetune_last_layers_emb = finetune_last_layers_emb

        self.longformer = LongformerModel.from_pretrained(
            pretrained_model_name, return_dict=False, output_hidden_states=True,
        )

        embedding_dim = last_layers_emb * self.longformer.config.hidden_size
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            bidirectional=True,
            batch_first=True,
        )

        self.dropout = nn.Dropout(p=dropout_p)
#         self.classifier = nn.Linear(2 * hidden_dim, out_size)
        self.classifier = sequential_fully_connected(
            2 * hidden_dim, out_size, fc, dropout_p
        )

        self.freeze_transformer_layer()

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.longformer(
            input_ids=input_ids, attention_mask=attention_mask
        )
        hidden_states = outputs[2]
        # print(f"hidden_states: {hidden_states}")

        embeddings = hidden_states[-self.last_layers_emb :]
        # print(f"embeddings.shape: {embeddings.shape}")

        # embeddings: (batch_size, max_seq_len, last_layers_emb * hidden_size)
        embeddings = torch.cat(embeddings, dim=2)
        # print(f"embeddings.shape: {embeddings.shape}")

        lengths = attention_mask.sum(dim=1)
        # print(f"lengths.shape: {lengths.shape}")

        packed_embeddings = nn.utils.rnn.pack_padded_sequence(
            embeddings, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        # print(f"packed_embeddings.data.shape: {packed_embeddings.data.shape}")

        packed_output, (_, _) = self.lstm(packed_embeddings)
        # print(f"packed_output.data.shape: {packed_output.data.shape}")

        output, _ = nn.utils.rnn.pad_packed_sequence(
            packed_output, batch_first=True, total_length=embeddings.shape[1],
        )
        # print(f"output.shape: {output.shape}")

        output = self.dropout(output)
        logits = self.classifier(output)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss().to(self.device)
            loss = loss_fn(logits.view(-1, self.out_size), labels.view(-1))

        return loss, logits

    def freeze_transformer_layer(self):
        for param in self.longformer.parameters():
            param.requires_grad = False

    def unfreeze_transformer_layer(self):
        if self.finetune_last_layers_emb:
            # print(f"Will fine-tune last {self.last_layers_emb} layers")
            # Fine-tune only the last emb layers
#             for layer in self.longformer.encoder.layer[-self.last_layers_emb :]:
            for layer in self.longformer.encoder.layer[-1:]:
                for param in layer.parameters():
                    param.requires_grad = True
        else:
            # Do nothing
            # print(f"Transformer used as feature extractor only => no fine-tuning")
            pass

    def get_predictions_from_logits(self, logits, labels=None, corresponding_word=None):
        batch_size = logits.shape[0]

        # logits: (batch_size, max_seq_len, out_size)
        # labels: (batch_size, max_seq_len)
        # corresponding_word: (batch_size, max_seq_len)

        # print(f"logits.shape: {logits.shape}")
        # print(f"logits: {logits}")

        # preds: (batch_size, max_seq_len)
        preds = torch.argmax(logits, dim=-1)

        # print(f"preds.shape: {preds.shape}")
        # print(f"preds: {preds}")

        if labels is not None:
            # print(f"labels.shape: {labels.shape}")
            # print(f"labels: {labels}")

            # Keep only predictions where labels are not -100
            # clean_preds = preds[labels != -100].reshape(batch_size, -1)
            # clean_labels = labels[labels != -100].reshape(batch_size, -1)

            # print(f"clean_preds.shape: {clean_preds.shape}")
            # print(f"clean_preds: {clean_preds}")

            # print(f"clean_labels.shape: {clean_labels.shape}")
            # print(f"clean_labels: {clean_labels}")

            # Get the index of the first machine text word
            # predicted_positions = clean_preds.argmax(dim=-1)
            # true_positions = clean_labels.argmax(dim=-1)

            predicted_positions = []
            true_positions = []
            for p, l in zip(preds, labels):
                mask = l != -100

                clean_pred = p[mask]
                clean_label = l[mask]

                # print(f"clean_pred.shape: {clean_pred.shape}")
                # print(f"clean_pred: {clean_pred}")
                # print(f"clean_label.shape: {clean_label.shape}")
                # print(f"clean_label: {clean_label}")

                predicted_position = clean_pred.argmax(dim=-1)
                true_position = clean_label.argmax(dim=-1)

                # print(f"predicted_position: {predicted_position}")
                # print(f"true_position: {true_position}")

                predicted_positions.append(predicted_position.item())
                true_positions.append(true_position.item())

            # print(f"predicted_positions.shape: {predicted_positions.shape}")
            # print(f"predicted_positions: {predicted_positions}")

            # print(f"true_positions.shape: {true_positions.shape}")
            # print(f"true_positions: {true_positions}")

            # print(f"predicted_positions type: {type(predicted_positions)}")
            # print(f"true_positions type: {type(true_positions)}")

            return torch.Tensor(predicted_positions), torch.Tensor(true_positions)
        elif corresponding_word is not None:
            # print(f"corresponding_word.shape: {corresponding_word.shape}")
            # print(f"corresponding_word: {corresponding_word}")

            # Keep only predictions where corresponding_word are not -100
            # clean_preds = preds[corresponding_word != -100].reshape(
            #     batch_size, -1
            # ).detach().cpu().numpy()
            # clean_corresponding_word = corresponding_word[corresponding_word != -100].reshape(
            #     batch_size, -1
            # ).detach().cpu().numpy()

            # print(f"clean_preds.shape: {clean_preds.shape}")
            # print(f"clean_preds: {clean_preds}")

            # print(f"clean_corresponding_word.shape: {clean_corresponding_word.shape}")
            # print(f"clean_corresponding_word: {clean_corresponding_word}")

            predicted_positions = []
            for p, w in zip(preds, corresponding_word):
                mask = w != -100

                clean_pred = p[mask]
                clean_corresponding_word = w[mask]

                # print(f"clean_pred.shape: {clean_pred.shape}")
                # print(f"clean_pred: {clean_pred}")
                # print(f"clean_corresponding_word.shape: {clean_corresponding_word.shape}")
                # print(f"clean_corresponding_word: {clean_corresponding_word}")

                # Get the index of the first machine text word
                index = torch.where(clean_pred == 1)[0]
                value = index[0] if index.size else len(clean_pred) - 1
                position = clean_corresponding_word[value]

                # print(f"index: {index}")
                # print(f"value: {value}")
                # print(f"position: {position}")

                predicted_positions.append(position.item())
            #     # pred = pred.detach().cpu().numpy()

            #     index = np.where(pred == 1)[0]
            #     value = index[0] if index.size else len(pred) - 1
            #     position = clean_corresponding_word[idx][value]

            #     predicted_positions.append(position.item())

            print(f"predicted_positions: {predicted_positions}")

            return predicted_positions, None
        else:
            raise ValueError("Either labels or corresponding_word must be provided")

# Train model

In [9]:
import pandas as pd
from tqdm import tqdm
# from time import time
from collections import defaultdict


def train_epoch(
    model,
    dataloader,
    loss_fn,
    optimizer,
    device,
    scheduler,
    metric_fn,
    print_freq=10,
):
    model.train()

    losses = []

    all_predictions = []
    all_true = []
    all_ids = []

    for i, batch in enumerate(dataloader):
        ids = batch["id"]
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        targets = batch["target"].to(device)
        corresponding_word = batch["corresponding_word"].to(device)

        loss, logits = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=targets,
        )

        predictions, true_predictions = model.get_predictions_from_logits(
            logits=logits,
            labels=targets,
            corresponding_word=corresponding_word
        )

        # print(f"predictions: {predictions}")
        # print(f"true_predictions: {true_predictions}")

        losses.append(loss.item())

        all_predictions.extend(predictions.tolist())
        all_true.extend(true_predictions.tolist())
        all_ids.extend(ids)

        if i % print_freq == 0:
            print(
                f"Batch [{i + 1}/{len(dataloader)}]; "
                f"Loss: {loss.item():.5f}; "
                f"Mean absolute error: {metric_fn(true_predictions, predictions):.5f}"
            )

        loss.backward()

        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        optimizer.zero_grad()

    return np.mean(losses), (all_ids, all_true, all_predictions)


def validation_epoch(
    model,
    dataloader,
    loss_fn,
    device,
    metric_fn,
):
    model.eval()

    losses = []
    all_predictions = []
    all_true = []
    all_ids = []

    with torch.no_grad():
        for i, batch in enumerate(tqdm(dataloader)):
            ids = batch["id"]
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            targets = batch["target"].to(device)
            corresponding_word = batch["corresponding_word"].to(device)

            loss, logits = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=targets,
            )

            predictions, true_predictions = model.get_predictions_from_logits(
                logits=logits,
                labels=targets,
                corresponding_word=corresponding_word
            )

            losses.append(loss.item())

            all_predictions.extend(predictions.tolist())
            all_true.extend(true_predictions.tolist())
            all_ids.extend(ids)

    return np.mean(losses), (all_ids, all_true, all_predictions)


def training_loop(
    model,
    num_epochs,
    train_dataloader,
    dev_dataloader,
    loss_fn,
    optimizer_config,
    scheduler_config,
    device,
    metric_fn,
    is_better_metric_fn,
    num_epochs_before_finetune,
    results_dir,
):
    history = defaultdict(list)
    best_metric = None
    best_model_state = None

    optimizer = get_optimizer(model, optimizer_config, finetune=False)
    scheduler = None

    for epoch in range(1, num_epochs + 1):
        print(f"Epoch {epoch}/{num_epochs}")
        if epoch <= num_epochs_before_finetune:
            print("Freeze transformer")
        else:
            print("Finetune transformer")
        print("-" * 10)

        if epoch == num_epochs_before_finetune + 1:
            model.unfreeze_transformer_layer()
            optimizer = get_optimizer(model, optimizer_config, finetune=True)
            scheduler = get_scheduler(
                optimizer,
                num_training_steps=len(train_dataloader) * num_epochs,
                **scheduler_config,
            )

        train_loss, (train_ids, train_true, train_predict) = train_epoch(
            model,
            train_dataloader,
            loss_fn,
            optimizer,
            device,
            scheduler,
            metric_fn,
        )

        train_metric = metric_fn(train_true, train_predict)

        print(f"Train Loss: {train_loss:.5f}; Train Metric: {train_metric:.5f}")

        dev_loss, (dev_ids, dev_true, dev_predict) = validation_epoch(
            model,
            dev_dataloader,
            loss_fn,
            device,
            metric_fn,
        )

        dev_metric = metric_fn(dev_true, dev_predict)

        print(
            f"Validation Loss: {dev_loss:.5f}; "
            f"Validation Metric: {dev_metric:.5f}"
        )

        history["train_metric"].append(train_metric)
        history["train_loss"].append(train_loss)
        history["dev_metric"].append(dev_metric)
        history["dev_loss"].append(dev_loss)

        if best_metric is None or is_better_metric_fn(train_metric, best_metric):
            best_metric = train_metric
            best_model_state = model.state_dict()
            
            if results_dir is not None:
                torch.save(
                    best_model_state,
                    os.path.join(results_dir, "best_model.bin"),
                )

                df_train_predictions = pd.DataFrame(
                    {
                        "id": train_ids,
                        "true": train_true,
                        "predict": train_predict,
                    }
                )
                df_train_predictions.to_csv(
                    os.path.join(results_dir, "best_model_train_predict.csv"),
                    index=False
                )

                df_dev_predictions = pd.DataFrame(
                    {
                        "id": dev_ids,
                        "true": dev_true,
                        "predict": dev_predict,
                    }
                )
                df_dev_predictions.to_csv(
                    os.path.join(results_dir, "best_model_dev_predict.csv"),
                    index=False
                )

    df_history = pd.DataFrame(history)
    if results_dir is not None:
        df_history.to_csv(os.path.join(results_dir, "history.csv"), index=False)

        model.load_state_dict(torch.load(os.path.join(results_dir, "best_model.bin")))
    else:
        model.load_state_dict(best_model_state)

    return model, df_history

In [10]:
import torch

torch.cuda.empty_cache()

In [11]:
num_epochs = config["training"]["num_epochs"]
model = LongformerBiLSTMForTokenClassification(
    device=DEVICE, **config["model_config"]
).to(DEVICE)
loss_fn = get_loss_fn(config["training"]["loss"], DEVICE)
optimizer_config = config["training"]["optimizer"]
scheduler_config = config["training"]["scheduler"]
metric_fn, is_better_metric_fn = get_metric(config["training"]["metric"])
num_epochs_before_finetune = config["training"]["num_epochs_before_finetune"]

best_model, df_history = training_loop(
    model,
    num_epochs,
    train_dataloader,
    dev_dataloader,
    loss_fn,
    optimizer_config,
    scheduler_config,
    DEVICE,
    metric_fn,
    is_better_metric_fn,
    num_epochs_before_finetune,
    results_dir,
)

Some weights of the model checkpoint at allenai/longformer-base-4096 were not used when initializing LongformerModel: ['lm_head.layer_norm.weight', 'lm_head.decoder.weight', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.bias', 'lm_head.bias']
- This IS expected if you are initializing LongformerModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LongformerModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Epoch 1/15
Finetune transformer
----------
Batch [1/92]; Loss: 0.84567; Mean absolute error: 72.03125
Batch [11/92]; Loss: 0.46727; Mean absolute error: 103.31250
Batch [21/92]; Loss: 0.29498; Mean absolute error: 41.18750
Batch [31/92]; Loss: 0.19255; Mean absolute error: 44.28125
Batch [41/92]; Loss: 0.20362; Mean absolute error: 40.81250
Batch [51/92]; Loss: 0.14458; Mean absolute error: 27.03125
Batch [61/92]; Loss: 0.18996; Mean absolute error: 20.28125
Batch [71/92]; Loss: 0.12622; Mean absolute error: 14.37500
Batch [81/92]; Loss: 0.14130; Mean absolute error: 11.87500
Batch [91/92]; Loss: 0.12274; Mean absolute error: 18.12500
Train Loss: 0.26239; Train Metric: 44.49469


100%|██████████| 23/23 [00:34<00:00,  1.51s/it]


Validation Loss: 0.13114; Validation Metric: 19.40137
Epoch 2/15
Finetune transformer
----------
Batch [1/92]; Loss: 0.09140; Mean absolute error: 12.15625
Batch [11/92]; Loss: 0.08971; Mean absolute error: 10.75000
Batch [21/92]; Loss: 0.11847; Mean absolute error: 18.93750
Batch [31/92]; Loss: 0.09492; Mean absolute error: 16.78125
Batch [41/92]; Loss: 0.33100; Mean absolute error: 22.28125
Batch [51/92]; Loss: 0.08187; Mean absolute error: 10.62500
Batch [61/92]; Loss: 0.09596; Mean absolute error: 17.21875
Batch [71/92]; Loss: 0.08397; Mean absolute error: 14.43750
Batch [81/92]; Loss: 0.06899; Mean absolute error: 13.96875
Batch [91/92]; Loss: 0.17210; Mean absolute error: 35.15625
Train Loss: 0.10168; Train Metric: 16.34943


100%|██████████| 23/23 [00:34<00:00,  1.51s/it]


Validation Loss: 0.14549; Validation Metric: 22.44932
Epoch 3/15
Finetune transformer
----------
Batch [1/92]; Loss: 0.08393; Mean absolute error: 13.09375
Batch [11/92]; Loss: 0.14445; Mean absolute error: 16.62500
Batch [21/92]; Loss: 0.06803; Mean absolute error: 10.31250
Batch [31/92]; Loss: 0.06224; Mean absolute error: 13.46875
Batch [41/92]; Loss: 0.05307; Mean absolute error: 8.46875
Batch [51/92]; Loss: 0.16532; Mean absolute error: 20.12500
Batch [61/92]; Loss: 0.05636; Mean absolute error: 7.78125
Batch [71/92]; Loss: 0.23875; Mean absolute error: 29.78125
Batch [81/92]; Loss: 0.06318; Mean absolute error: 10.62500
Batch [91/92]; Loss: 0.11329; Mean absolute error: 30.12500
Train Loss: 0.08104; Train Metric: 13.32066


100%|██████████| 23/23 [00:34<00:00,  1.50s/it]


Validation Loss: 0.10241; Validation Metric: 14.45342
Epoch 4/15
Finetune transformer
----------
Batch [1/92]; Loss: 0.06812; Mean absolute error: 7.71875
Batch [11/92]; Loss: 0.06843; Mean absolute error: 8.62500
Batch [21/92]; Loss: 0.06917; Mean absolute error: 15.65625
Batch [31/92]; Loss: 0.08665; Mean absolute error: 10.09375
Batch [41/92]; Loss: 0.09537; Mean absolute error: 9.68750
Batch [51/92]; Loss: 0.07508; Mean absolute error: 10.75000
Batch [61/92]; Loss: 0.07319; Mean absolute error: 12.03125
Batch [71/92]; Loss: 0.03224; Mean absolute error: 4.50000
Batch [81/92]; Loss: 0.05710; Mean absolute error: 8.09375
Batch [91/92]; Loss: 0.05125; Mean absolute error: 8.78125
Train Loss: 0.07390; Train Metric: 11.86776


100%|██████████| 23/23 [00:34<00:00,  1.50s/it]


Validation Loss: 0.12724; Validation Metric: 16.44658
Epoch 5/15
Finetune transformer
----------
Batch [1/92]; Loss: 0.14245; Mean absolute error: 13.21875
Batch [11/92]; Loss: 0.04852; Mean absolute error: 8.18750
Batch [21/92]; Loss: 0.07465; Mean absolute error: 19.96875
Batch [31/92]; Loss: 0.04294; Mean absolute error: 5.90625
Batch [41/92]; Loss: 0.09929; Mean absolute error: 12.53125
Batch [51/92]; Loss: 0.02988; Mean absolute error: 3.56250
Batch [61/92]; Loss: 0.03409; Mean absolute error: 11.00000
Batch [71/92]; Loss: 0.07028; Mean absolute error: 10.00000
Batch [81/92]; Loss: 0.07778; Mean absolute error: 10.25000
Batch [91/92]; Loss: 0.05436; Mean absolute error: 8.03125
Train Loss: 0.06393; Train Metric: 10.11168


100%|██████████| 23/23 [00:34<00:00,  1.50s/it]


Validation Loss: 0.23503; Validation Metric: 32.68356
Epoch 6/15
Finetune transformer
----------
Batch [1/92]; Loss: 0.07740; Mean absolute error: 19.71875
Batch [11/92]; Loss: 0.08504; Mean absolute error: 24.87500
Batch [21/92]; Loss: 0.05042; Mean absolute error: 18.90625
Batch [31/92]; Loss: 0.05729; Mean absolute error: 14.25000
Batch [41/92]; Loss: 0.03995; Mean absolute error: 7.75000
Batch [51/92]; Loss: 0.04020; Mean absolute error: 5.31250
Batch [61/92]; Loss: 0.04621; Mean absolute error: 7.31250
Batch [71/92]; Loss: 0.04782; Mean absolute error: 10.90625
Batch [81/92]; Loss: 0.03971; Mean absolute error: 7.87500
Batch [91/92]; Loss: 0.04272; Mean absolute error: 4.65625
Train Loss: 0.05536; Train Metric: 9.36245


100%|██████████| 23/23 [00:34<00:00,  1.51s/it]


Validation Loss: 0.12892; Validation Metric: 15.88219
Epoch 7/15
Finetune transformer
----------
Batch [1/92]; Loss: 0.03400; Mean absolute error: 5.40625
Batch [11/92]; Loss: 0.05228; Mean absolute error: 10.18750
Batch [21/92]; Loss: 0.07395; Mean absolute error: 9.53125
Batch [31/92]; Loss: 0.03430; Mean absolute error: 3.65625
Batch [41/92]; Loss: 0.03255; Mean absolute error: 4.12500
Batch [51/92]; Loss: 0.04520; Mean absolute error: 5.28125
Batch [61/92]; Loss: 0.05377; Mean absolute error: 6.25000
Batch [71/92]; Loss: 0.03979; Mean absolute error: 6.12500
Batch [81/92]; Loss: 0.06767; Mean absolute error: 9.59375
Batch [91/92]; Loss: 0.06425; Mean absolute error: 14.28125
Train Loss: 0.04636; Train Metric: 7.74854


100%|██████████| 23/23 [00:34<00:00,  1.50s/it]


Validation Loss: 0.16404; Validation Metric: 23.12329
Epoch 8/15
Finetune transformer
----------
Batch [1/92]; Loss: 0.04803; Mean absolute error: 8.81250
Batch [11/92]; Loss: 0.03732; Mean absolute error: 7.15625
Batch [21/92]; Loss: 0.06730; Mean absolute error: 11.68750
Batch [31/92]; Loss: 0.03315; Mean absolute error: 5.53125
Batch [41/92]; Loss: 0.03275; Mean absolute error: 4.87500
Batch [51/92]; Loss: 0.02942; Mean absolute error: 4.59375
Batch [61/92]; Loss: 0.03921; Mean absolute error: 6.18750
Batch [71/92]; Loss: 0.02641; Mean absolute error: 4.00000
Batch [81/92]; Loss: 0.03521; Mean absolute error: 3.53125
Batch [91/92]; Loss: 0.03618; Mean absolute error: 6.12500
Train Loss: 0.04123; Train Metric: 6.73929


100%|██████████| 23/23 [00:34<00:00,  1.50s/it]


Validation Loss: 0.11267; Validation Metric: 12.80959
Epoch 9/15
Finetune transformer
----------
Batch [1/92]; Loss: 0.04946; Mean absolute error: 6.15625
Batch [11/92]; Loss: 0.02857; Mean absolute error: 4.34375
Batch [21/92]; Loss: 0.03764; Mean absolute error: 5.90625
Batch [31/92]; Loss: 0.04730; Mean absolute error: 6.50000
Batch [41/92]; Loss: 0.03463; Mean absolute error: 7.56250
Batch [51/92]; Loss: 0.03263; Mean absolute error: 4.53125
Batch [61/92]; Loss: 0.02952; Mean absolute error: 3.62500
Batch [71/92]; Loss: 0.03091; Mean absolute error: 4.65625
Batch [81/92]; Loss: 0.02987; Mean absolute error: 5.00000
Batch [91/92]; Loss: 0.02519; Mean absolute error: 5.25000
Train Loss: 0.03403; Train Metric: 5.64406


100%|██████████| 23/23 [00:34<00:00,  1.50s/it]


Validation Loss: 0.12558; Validation Metric: 14.64384
Epoch 10/15
Finetune transformer
----------
Batch [1/92]; Loss: 0.02589; Mean absolute error: 4.00000
Batch [11/92]; Loss: 0.02143; Mean absolute error: 2.90625
Batch [21/92]; Loss: 0.03899; Mean absolute error: 6.87500
Batch [31/92]; Loss: 0.03685; Mean absolute error: 19.03125
Batch [41/92]; Loss: 0.03808; Mean absolute error: 6.65625
Batch [51/92]; Loss: 0.02449; Mean absolute error: 6.75000
Batch [61/92]; Loss: 0.02226; Mean absolute error: 5.00000
Batch [71/92]; Loss: 0.02616; Mean absolute error: 3.71875
Batch [81/92]; Loss: 0.03138; Mean absolute error: 6.68750
Batch [91/92]; Loss: 0.02850; Mean absolute error: 3.75000
Train Loss: 0.03032; Train Metric: 5.38335


100%|██████████| 23/23 [00:34<00:00,  1.50s/it]


Validation Loss: 0.12239; Validation Metric: 14.40411
Epoch 11/15
Finetune transformer
----------
Batch [1/92]; Loss: 0.02805; Mean absolute error: 5.06250
Batch [11/92]; Loss: 0.02559; Mean absolute error: 2.59375
Batch [21/92]; Loss: 0.02320; Mean absolute error: 3.75000
Batch [31/92]; Loss: 0.02783; Mean absolute error: 4.78125
Batch [41/92]; Loss: 0.02631; Mean absolute error: 4.53125
Batch [51/92]; Loss: 0.03409; Mean absolute error: 4.28125
Batch [61/92]; Loss: 0.02341; Mean absolute error: 3.78125
Batch [71/92]; Loss: 0.02386; Mean absolute error: 4.03125
Batch [81/92]; Loss: 0.02560; Mean absolute error: 3.56250
Batch [91/92]; Loss: 0.03294; Mean absolute error: 5.96875
Train Loss: 0.02754; Train Metric: 4.96471


100%|██████████| 23/23 [00:34<00:00,  1.50s/it]


Validation Loss: 0.15649; Validation Metric: 15.35205
Epoch 12/15
Finetune transformer
----------
Batch [1/92]; Loss: 0.01696; Mean absolute error: 2.65625
Batch [11/92]; Loss: 0.01751; Mean absolute error: 3.31250
Batch [21/92]; Loss: 0.03636; Mean absolute error: 17.12500
Batch [31/92]; Loss: 0.01944; Mean absolute error: 2.81250
Batch [41/92]; Loss: 0.02025; Mean absolute error: 2.75000
Batch [51/92]; Loss: 0.02124; Mean absolute error: 3.43750
Batch [61/92]; Loss: 0.02607; Mean absolute error: 2.90625
Batch [71/92]; Loss: 0.02710; Mean absolute error: 3.46875
Batch [81/92]; Loss: 0.02494; Mean absolute error: 4.00000
Batch [91/92]; Loss: 0.03011; Mean absolute error: 9.06250
Train Loss: 0.02376; Train Metric: 4.36519


100%|██████████| 23/23 [00:34<00:00,  1.50s/it]


Validation Loss: 0.17035; Validation Metric: 16.47260
Epoch 13/15
Finetune transformer
----------
Batch [1/92]; Loss: 0.02246; Mean absolute error: 4.87500
Batch [11/92]; Loss: 0.02057; Mean absolute error: 3.46875
Batch [21/92]; Loss: 0.01682; Mean absolute error: 2.21875
Batch [31/92]; Loss: 0.02800; Mean absolute error: 2.81250
Batch [41/92]; Loss: 0.02249; Mean absolute error: 3.21875
Batch [51/92]; Loss: 0.01658; Mean absolute error: 2.43750
Batch [61/92]; Loss: 0.03899; Mean absolute error: 22.37500
Batch [71/92]; Loss: 0.02378; Mean absolute error: 3.56250
Batch [81/92]; Loss: 0.01800; Mean absolute error: 5.09375
Batch [91/92]; Loss: 0.02176; Mean absolute error: 3.00000
Train Loss: 0.02070; Train Metric: 3.81877


100%|██████████| 23/23 [00:34<00:00,  1.50s/it]


Validation Loss: 0.14790; Validation Metric: 13.18082
Epoch 14/15
Finetune transformer
----------
Batch [1/92]; Loss: 0.01466; Mean absolute error: 2.62500
Batch [11/92]; Loss: 0.01633; Mean absolute error: 3.15625
Batch [21/92]; Loss: 0.01943; Mean absolute error: 2.71875
Batch [31/92]; Loss: 0.01420; Mean absolute error: 3.03125
Batch [41/92]; Loss: 0.01880; Mean absolute error: 3.21875
Batch [51/92]; Loss: 0.01921; Mean absolute error: 2.68750
Batch [61/92]; Loss: 0.01932; Mean absolute error: 2.93750
Batch [71/92]; Loss: 0.01919; Mean absolute error: 2.18750
Batch [81/92]; Loss: 0.01550; Mean absolute error: 3.65625
Batch [91/92]; Loss: 0.01291; Mean absolute error: 2.09375
Train Loss: 0.01853; Train Metric: 3.48338


100%|██████████| 23/23 [00:34<00:00,  1.51s/it]


Validation Loss: 0.15630; Validation Metric: 13.96438
Epoch 15/15
Finetune transformer
----------
Batch [1/92]; Loss: 0.02044; Mean absolute error: 3.25000
Batch [11/92]; Loss: 0.01839; Mean absolute error: 4.78125
Batch [21/92]; Loss: 0.01729; Mean absolute error: 2.21875
Batch [31/92]; Loss: 0.02009; Mean absolute error: 2.81250
Batch [41/92]; Loss: 0.01580; Mean absolute error: 2.62500
Batch [51/92]; Loss: 0.01715; Mean absolute error: 2.56250
Batch [61/92]; Loss: 0.01480; Mean absolute error: 2.21875
Batch [71/92]; Loss: 0.01571; Mean absolute error: 1.53125
Batch [81/92]; Loss: 0.02114; Mean absolute error: 2.68750
Batch [91/92]; Loss: 0.01828; Mean absolute error: 3.43750
Train Loss: 0.01703; Train Metric: 3.04967


100%|██████████| 23/23 [00:34<00:00,  1.50s/it]


Validation Loss: 0.15478; Validation Metric: 13.62055


# Make predictions

In [12]:
import pandas as pd


def make_predictions(
    model,
    dataloader,
    device,
    results_dir,
    label_column,
    file_format="csv",
):
    model.eval()

    all_predictions = []
    all_true = []
    all_ids = []

    with torch.no_grad():
        for i, batch in enumerate(tqdm(dataloader)):
            ids = batch["id"]
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            targets = batch["target"].to(device)
            corresponding_word = batch["corresponding_word"].to(device)

            _, logits = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=targets,
            )

            predictions, true_predictions = model.get_predictions_from_logits(
                logits=logits,
                labels=targets,
                corresponding_word=corresponding_word
            )

            all_predictions.extend(predictions.tolist())
            all_true.extend(true_predictions.tolist())
            all_ids.extend(ids)

    df_predictions = pd.DataFrame(
        {
            "id": all_ids,
            "true": all_true,
            label_column: all_predictions,
        }
    )

    if results_dir is not None:
        if file_format == "csv":
            df_predictions.to_csv(
                os.path.join(results_dir, "submission.csv"),
                index=False,
            )
        elif file_format == "jsonl":
            df_predictions.to_json(
                os.path.join(results_dir, "submission.jsonl"),
                orient="records",
                lines=True,
            )
        else:
            raise ValueError(f"Unknown file format: {file_format}")
    else:
        print("Missing results_dir, not saving predictions to file!")

    return df_predictions

In [13]:
predictions = make_predictions(
    best_model,
    test_dataloader,
    DEVICE,
    results_dir,
    config["data"]["label_column"],
    file_format="csv",
)

100%|██████████| 16/16 [00:23<00:00,  1.49s/it]


In [15]:
!python ../scores_and_plots.py --results-dir "../runs/29-12-2023_13:41:34-SubtaskC-longformer_bilstm"

Results on validation
MAE: 13.62055
--------------------
Results on test
MAE: 11.25149
--------------------
