In [None]:
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
torch.manual_seed(99)

cpu


<torch._C.Generator at 0x7f4ef0ea27f0>

# ArcEager Parser

> The arc-eager system defines an incremental left-to-right parsing order, where left dependents are added bottom–up and right dependents top–down, which is advanta-geous for postponing certain attachment decisions. However, a fundamental problem with this system is that it **does not guarantee that the output parse is a projective dependency tree**, only a projective dependency forest, that is, a sequence of adjacent, non-overlapping projective trees (Nivre 2008). This is different from the closely related arc-standard system (Nivre 2004), which constructs all dependencies bottom–up and can easily be constrained to only output trees. The failure to implement the tree constraint may lead to fragmented parses and lower parsing accuracy, especially with respect to the global structure of the sentence. Moreover, even if the loss in accuracy is not substantial, this may be problematic when using the parser in applications where downstream components may not function correctly if the parser output is not a well-formed tree.


In [29]:
# %% ArcStandardParser
class ArcEager:
    def __init__(self, sentence):
        self.sentence = sentence
        self.buffer = [i for i in range(len(self.sentence))]
        self.stack = []
        self.arcs = [-1 for _ in range(len(self.sentence))]

        # three shift moves to initialize the stack
        self.shift()
        self.shift()
        if len(self.sentence) > 2:
            self.shift()

    def shift(self):
        b1 = self.stack.pop(0)
        self.stack.append(b1)

    def left_arc(self):
        s1 = self.stack.pop()
        b1 = self.buffer[0]
        self.arcs[s1] = b1

    def right_arc(self):
        s1 = self.stack.pop()
        b1 = self.buffer.pop(0)
        self.arcs[b1] = s1
        self.stack.append(b1)

    def reduce(self):
        self.stack.pop()

    def is_tree_final(self):
        return len(self.stack) == 1 and len(self.buffer) == 0

    def print_configuration(self):
        s = [self.sentence[i] for i in self.stack]
        b = [self.sentence[i] for i in self.buffer]
        print(s, b)
        print(self.arcs)

1. Shift moves next to the stack.
2. Reduce pops the stack; allowed only if top has a head.
3. Right-Arc adds a dependency arc from top to next and moves next to the
   stack.
4. Left-Arc adds a dependency arc from next to top and pops the stack;
   allowed only if top has no head.

From [here](https://aclanthology.org/J14-2002.pdf)


In [30]:
class Oracle:
    def __init__(self, parser, gold_tree):
        self.parser = parser
        self.gold = gold_tree

    def is_left_arc_gold(self):
        # first element of the of the buffer is the gold head of the topmost element of the stack

        if len(self.parser.stack) == 0 or len(self.parser.buffer) == 0:
            return False

        o1 = self.parser.stack[-1]
        o2 = self.parser.buffer[0]  # [0]

        if self.gold[o2] == o1:
            return True

        return False

    def is_right_arc_gold(self):
        # if topmost stack element is gold head of the first element of the buffer
        if len(self.parser.stack) == 0 or len(self.parser.buffer) == 0:
            return False

        o1 = self.parser.stack[-1]
        o2 = self.parser.buffer[0]  # [0]

        if self.gold[o1] == o2:
            return True

        return False

    def is_reduce_gold(self):
        # if topmost stack element has got head
        if len(self.parser.stack) == 0:
            return False

        o1 = self.parser.stack[-1]

        if self.gold[o1] != -1:
            return True

        return False

    def is_shift_gold(self):
        if len(self.parser.buffer) == 0:
            return False

        # This dictates transition precedence of the parser
        if self.is_left_arc_gold() or self.is_right_arc_gold() or self.is_reduce_gold():
            return False

        return True

# Dataset preparation


In [9]:
# the function returns whether a tree is projective or not. It is currently
# implemented inefficiently by brute checking every pair of arcs.
def is_projective(tree):
    for i in range(len(tree)):
        if tree[i] == -1:
            continue
        left = min(i, tree[i])
        right = max(i, tree[i])

        for j in range(0, left):
            if tree[j] > left and tree[j] < right:
                return False
        for j in range(left + 1, right):
            if tree[j] < left or tree[j] > right:
                return False
        for j in range(right + 1, len(tree)):
            if tree[j] > left and tree[j] < right:
                return False

    return True


# the function creates a dictionary of word/index pairs: our embeddings vocabulary
# threshold is the minimum number of appearance for a token to be included in the embedding list
def create_dict(dataset, threshold=3):
    dic = {}  # dictionary of word counts
    for sample in dataset:
        for word in sample["tokens"]:
            if word in dic:
                dic[word] += 1
            else:
                dic[word] = 1

    map = {}  # dictionary of word/index pairs. This is our embedding list
    map["<pad>"] = 0
    map["<ROOT>"] = 1
    map["<unk>"] = 2  # used for words that do not appear in our list

    next_indx = 3
    for word in dic.keys():
        if dic[word] >= threshold:
            map[word] = next_indx
            next_indx += 1

    return map

In [20]:
from datasets import load_dataset

dataset = load_dataset("universal_dependencies", "en_lines", split="train")
print(len(dataset))
dataset = [
    sample
    for sample in dataset
    if is_projective([-1] + [int(head) for head in sample["head"]])
]
print(len(dataset), "are projective")
print(dataset[1].keys())

Found cached dataset universal_dependencies (/home/matteo/.cache/huggingface/datasets/universal_dependencies/en_lines/2.7.0/1ac001f0e8a0021f19388e810c94599f3ac13cc45d6b5b8c69f7847b2188bdf7)


3176
2922 are projective
dict_keys(['idx', 'text', 'tokens', 'lemmas', 'upos', 'xpos', 'feats', 'head', 'deprel', 'deps', 'misc'])


In [25]:
train_dataset, dev_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [0.6, 0.1, 0.3]
)
# train_dataset = load_dataset("universal_dependencies", "en_lines", split="train")
# dev_dataset = load_dataset("universal_dependencies", "en_lines", split="validation")
# test_dataset = load_dataset("universal_dependencies", "en_lines", split="test")
print(len(train_dataset))
print(len(dev_dataset))
print(len(test_dataset))

1754
292
876


In [26]:
# create the embedding dictionary
emb_dictionary = create_dict(train_dataset)

## Process sample


In [None]:
def process_sample(sample, get_gold_path=False):
    # put sentence and gold tree in our format
    sentence = ["<ROOT>"] + sample["tokens"]
    gold = [-1] + [
        int(i) for i in sample["head"]
    ]  # heads in the gold tree are strings, we convert them to int

    # embedding ids of sentence words
    enc_sentence = [
        emb_dictionary[word] if word in emb_dictionary else emb_dictionary["<unk>"]
        for word in sentence
    ]

    # gold_path and gold_moves are parallel arrays whose elements refer to parsing steps
    gold_path = (
        []
    )  # record two topmost stack tokens and first buffer token for current step
    gold_moves = (
        []
    )  # contains oracle (canonical) move for current step: 0 is left, 1 right, 2 shift, 3 reduce

    if get_gold_path:  # only for training
        parser = ArcEager(sentence)
        oracle = Oracle(parser, gold)

        while not parser.is_tree_final():
            # save configuration
            configuration = [
                parser.stack[len(parser.stack) - 2],
                parser.stack[len(parser.stack) - 1],
            ]
            if len(parser.buffer) == 0:
                configuration.append(-1)
            else:
                configuration.append(parser.buffer[0])
            gold_path.append(configuration)

            # save gold move
            if oracle.is_left_arc_gold():
                gold_moves.append(0)
                parser.left_arc()
            elif oracle.is_right_arc_gold():
                parser.right_arc()
                gold_moves.append(1)
            elif oracle.is_shift_gold():
                parser.shift()
                gold_moves.append(2)
            elif oracle.is_reduce_gold():
                gold_moves.append(3)
                parser.reduce()

    return enc_sentence, gold_path, gold_moves, gold

## DataLoaders


In [None]:
from functools import partial


def prepare_batch(batch_data, get_gold_path=False):
    data = [process_sample(s, get_gold_path=get_gold_path) for s in batch_data]
    # sentences, paths, moves, trees are parallel arrays, each element refers to a sentence
    sentences = [s[0] for s in data]
    paths = [s[1] for s in data]
    moves = [s[2] for s in data]
    trees = [s[3] for s in data]
    return sentences, paths, moves, trees


BATCH_SIZE = 32

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=partial(prepare_batch, get_gold_path=True),
)
dev_dataloader = torch.utils.data.DataLoader(
    dev_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=partial(prepare_batch)
)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=partial(prepare_batch),
)

# HighLevelStructure BiLSTM

1. WordEmbeddings - done before, DataLoaders done
2. BiLSTM Layer
3. Parser Configuration Representation
4. Action Classifier
5. Training
6. Prediction


## Embed sample

Generate all possible buffer-stack configurations with the gold output attached


In [None]:
    # put sentence and gold tree in our format
    sentence = ["<ROOT>"] + sample["tokens"]
    gold = [-1] + [
        int(i) for i in sample["head"]
    ]  # heads in the gold tree are strings, we convert them to int

    # embedding ids of sentence words
    enc_sentence = [
        emb_dictionary[word] if word in emb_dictionary else emb_dictionary["<unk>"]
        for word in sentence
    ]

    # gold_path and gold_moves are parallel arrays whose elements refer to parsing steps
    gold_path = (
        []
    )  # record two topmost stack tokens and first buffer token for current step
    gold_moves = (
        []
    )  # contains oracle (canonical) move for current step: 0 is left, 1 right, 2 shift

    if get_gold_path:  # only for training
        parser = ArcEager(sentence)
        oracle = Oracle(parser, gold)

        while not parser.is_tree_final():
            # save configuration
            configuration = [
                parser.stack[len(parser.stack) - 2],
                parser.stack[len(parser.stack) - 1],
            ]
            if len(parser.buffer) == 0:
                configuration.append(-1)
            else:
                configuration.append(parser.buffer[0])
            gold_path.append(configuration)

            # save gold move
            if oracle.is_reduce_gold():
                gold_moves.append(0)
                parser.reduce()
            if oracle.is_left_arc_gold():
                gold_moves.append(1)
                parser.left_arc()
            elif oracle.is_right_arc_gold():
                parser.right_arc()
                gold_moves.append(2)
            elif oracle.is_shift_gold():
                parser.shift()
                gold_moves.append(3)

    return enc_sentence, gold_path, gold_moves, gold

In [None]:
def prepare_batch(batch_data, get_gold_path=False):
    data = [process_sample(s, get_gold_path=get_gold_path) for s in batch_data]
    # sentences, paths, moves, trees are parallel arrays, each element refers to a sentence
    sentences = [s[0] for s in data]
    paths = [s[1] for s in data]
    moves = [s[2] for s in data]
    trees = [s[3] for s in data]
    return sentences, paths, moves, trees


BATCH_SIZE = 32

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=partial(prepare_batch, get_gold_path=True),
)
dev_dataloader = torch.utils.data.DataLoader(
    dev_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=partial(prepare_batch)
)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=partial(prepare_batch),
)

In [None]:
# %% CREATE NN MODEL

EMBEDDING_SIZE = 200
LSTM_SIZE = 200
LSTM_LAYERS = 1
MLP_SIZE = 200
DROPOUT = 0.2
EPOCHS = 15
LR = 0.001  # learning rate


class Net(nn.Module):
    def __init__(self, device):
        super(Net, self).__init__()
        self.device = device
        self.embeddings = nn.Embedding(
            len(emb_dictionary), EMBEDDING_SIZE, padding_idx=emb_dictionary["<pad>"]
        )

        # initialize bi-LSTM
        self.lstm = nn.LSTM(
            EMBEDDING_SIZE,
            LSTM_SIZE,
            num_layers=LSTM_LAYERS,
            bidirectional=True,
            dropout=DROPOUT,
        )

        # initialize feedforward
        self.w1 = torch.nn.Linear(6 * LSTM_SIZE, MLP_SIZE, bias=True)
        self.activation = torch.nn.Tanh()
        self.w2 = torch.nn.Linear(MLP_SIZE, 3, bias=True)
        self.softmax = torch.nn.Softmax(dim=-1)

        self.dropout = torch.nn.Dropout(DROPOUT)

    def forward(self, x, paths):
        # get the embeddings
        x = [self.dropout(self.embeddings(torch.tensor(i).to(self.device))) for i in x]

        # run the bi-lstm
        h = self.lstm_pass(x)

        # for each parser configuration that we need to score we arrange from the
        # output of the bi-lstm the correct input for the feedforward
        mlp_input = self.get_mlp_input(paths, h)

        # run the feedforward and get the scores for each possible action
        out = self.mlp(mlp_input)

        return out

    def lstm_pass(self, x):
        x = torch.nn.utils.rnn.pack_sequence(x, enforce_sorted=False)
        h, (h_0, c_0) = self.lstm(x)
        h, h_sizes = torch.nn.utils.rnn.pad_packed_sequence(
            h
        )  # size h: (length_sentences, batch, output_hidden_units)
        return h

    def get_mlp_input(self, configurations, h):
        mlp_input = []
        zero_tensor = torch.zeros(2 * LSTM_SIZE, requires_grad=False).to(self.device)
        for i in range(len(configurations)):  # for every sentence in the batch
            for j in configurations[i]:  # for each configuration of a sentence
                mlp_input.append(
                    torch.cat(
                        [
                            zero_tensor if j[0] == -1 else h[j[0]][i],
                            zero_tensor if j[1] == -1 else h[j[1]][i],
                            zero_tensor if j[2] == -1 else h[j[2]][i],
                        ]
                    )
                )
        mlp_input = torch.stack(mlp_input).to(self.device)
        return mlp_input

    def mlp(self, x):
        return self.softmax(
            self.w2(self.dropout(self.activation(self.w1(self.dropout(x)))))
        )

    # we use this function at inference time. We run the parser and at each step
    # we pick as next move the one with the highest score assigned by the model
    def infere(self, x):
        parsers = [ArcStandard(i) for i in x]

        x = [self.embeddings(torch.tensor(i).to(self.device)) for i in x]

        h = self.lstm_pass(x)

        while not self.parsed_all(parsers):
            # get the current configuration and score next moves
            configurations = self.get_configurations(parsers)
            mlp_input = self.get_mlp_input(configurations, h)
            mlp_out = self.mlp(mlp_input)
            # take the next parsing step
            self.parse_step(parsers, mlp_out)

        # return the predicted dependency tree
        return [parser.arcs for parser in parsers]

    def get_configurations(self, parsers):
        configurations = []

        for parser in parsers:
            if parser.is_tree_final():
                conf = [-1, -1, -1]
            else:
                conf = [
                    parser.stack[len(parser.stack) - 2],
                    parser.stack[len(parser.stack) - 1],
                ]
                if len(parser.buffer) == 0:
                    conf.append(-1)
                else:
                    conf.append(parser.buffer[0])
            configurations.append([conf])

        return configurations

    def parsed_all(self, parsers):
        for parser in parsers:
            if not parser.is_tree_final():
                return False
        return True

    # In this function we select and perform the next move according to the scores obtained.
    # We need to be careful and select correct moves, e.g. don't do a shift if the buffer
    # is empty or a left arc if σ2 is the ROOT. For clarity sake we didn't implement
    # these checks in the parser so we must do them here. This renders the function quite ugly
    def parse_step(self, parsers, moves):
        moves_argm = moves.argmax(-1)
        for i in range(len(parsers)):
            if parsers[i].is_tree_final():
                continue
            else:
                if moves_argm[i] == 0:
                    if parsers[i].stack[len(parsers[i].stack) - 2] != 0:
                        parsers[i].left_arc()
                    else:
                        if len(parsers[i].buffer) > 0:
                            parsers[i].shift()
                        else:
                            parsers[i].right_arc()
                elif moves_argm[i] == 1:
                    if (
                        parsers[i].stack[len(parsers[i].stack) - 2] == 0
                        and len(parsers[i].buffer) > 0
                    ):
                        parsers[i].shift()
                    else:
                        parsers[i].right_arc()
                elif moves_argm[i] == 2:
                    if len(parsers[i].buffer) > 0:
                        parsers[i].shift()
                    else:
                        if moves[i][0] > moves[i][1]:
                            if parsers[i].stack[len(parsers[i].stack) - 2] != 0:
                                parsers[i].left_arc()
                            else:
                                parsers[i].right_arc()
                        else:
                            parsers[i].right_arc()

## BiLSTM


In [5]:
# from https://pytorch.org/docs/stable/generated/torch.nn.LSTM
rnn = nn.LSTM(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
c0 = torch.randn(2, 3, 20)
output, (hn, cn) = rnn(input, (h0, c0))
print(output.shape)

torch.Size([5, 3, 20])


In [3]:
class BiLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super(BiLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.lstm = nn.LSTM(
            input_size, hidden_size, num_layers, batch_first=True, bidirectional=True
        )
        self.fc = nn.Linear(
            hidden_size * 2, output_size
        )  # Multiply hidden_size by 2 for bidirectional LSTM

    def forward(self, x):
        h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(
            x.device
        )  # Initialize hidden state for forward LSTM
        c0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).to(
            x.device
        )  # Initialize cell state for forward LSTM

        out, _ = self.lstm(x, (h0, c0))  # Apply LSTM

        # Extract the last hidden state from the bidirectional LSTM output
        out = self.fc(out[:, -1, :])  # Shape: (batch_size, output_size)
        return out

In [5]:
# Create an instance of the BiLSTM model
input_size = 10
hidden_size = 20
num_layers = 2
output_size = 5

model = BiLSTM(input_size, hidden_size, num_layers, output_size)

# Generate some sample input
batch_size = 3
sequence_length = 4
input_data = torch.randn(batch_size, sequence_length, input_size)

# Pass the input through the model
output = model(input_data)

print(output)
print(output.shape)  # Print the shape of the output

tensor([[ 0.0845,  0.1244,  0.0795, -0.0587, -0.0564],
        [ 0.0853,  0.0852,  0.0466, -0.0511, -0.0886],
        [ 0.0826,  0.0876,  0.0281, -0.0715, -0.0818]],
       grad_fn=<AddmmBackward0>)
torch.Size([3, 5])


# OLD


In [41]:
%%capture
from datasets import load_dataset

mrpc_dataset = load_dataset("glue", "mrpc")


In [42]:
print(mrpc_dataset["train"].features.keys(), end="\n\n")

for k, v in mrpc_dataset["train"].features.items():
    print(k, v)

dict_keys(['sentence1', 'sentence2', 'label', 'idx'])

sentence1 Value(dtype='string', id=None)
sentence2 Value(dtype='string', id=None)
label ClassLabel(names=['not_equivalent', 'equivalent'], id=None)
idx Value(dtype='int32', id=None)


In [43]:
print(mrpc_dataset["train"][0], end="\n\n")

{'sentence1': 'Amrozi accused his brother , whom he called " the witness " , of deliberately distorting his evidence .', 'sentence2': 'Referring to him as only " the witness " , Amrozi accused his brother of deliberately distorting his evidence .', 'label': 1, 'idx': 0}



In [44]:
for k, v in mrpc_dataset.items():
    print(k, len(v))

train 3668
validation 408
test 1725


Tokenize the entire dataset


In [64]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")


def tokenize(example):
    return tokenizer(example["sentence1"], example["sentence2"], truncation=True)

In [65]:
# train_dataset = mprpc_dataset["train"].map(tokenize, batched=True)

mrpc_dataset = load_dataset("glue", "mrpc")
mrpc_dataset = mrpc_dataset.map(
    tokenize, batched=True
)  # adds input_ids, attention_mask, token_type_ids

Found cached dataset glue (/home/matteo/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /home/matteo/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-7bc4744dbc6861ed.arrow
Loading cached processed dataset at /home/matteo/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-54ec9e7ee41eb863.arrow


Map:   0%|          | 0/1725 [00:00<?, ? examples/s]

Dynamic Padding


In [66]:
from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="pt")


samples = mrpc_dataset["train"][:8]
samples = {
    "input_ids": samples["input_ids"],  # requires tokenized words
    "attention_mask": samples["attention_mask"],
    "token_type_ids": samples["token_type_ids"],
    "label": samples["label"],
}
mrpc_dataset = data_collator(samples)  # Padding
mrpc_dataset.keys()

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


dict_keys(['input_ids', 'attention_mask', 'token_type_ids', 'labels'])

In [67]:
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=2
)
print(model)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

Traning


In [70]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="bert_mrpc",  # where to save checkpoints and model predictions
    per_device_train_batch_size=8,  # training batch size
    per_device_eval_batch_size=16,  # validation/test batch sizie
    num_train_epochs=3,  # number of epochs
    save_strategy="epoch",  # checkpoint saving frequency
    evaluation_strategy="epoch",  # how frequently to run validation, can be epoch or steps (in this case you need to specify eval_steps)
    metric_for_best_model="f1",  # metric used to pick checkpoints
    greater_is_better=True,  # whether the metric for checkpoint needs to be maximized or minimized
    learning_rate=3e-5,  # learning rate or peak learning rate if scheduler is used
    optim="adamw_torch",  # which optimizer to use
    lr_scheduler_type="linear",  # which scheduler to use
    warmup_ratio=0.1,  # % of steps for which to do warmup
    seed=33,  # setting seed for reproducibility
    load_best_model_at_end=True,
)  # after training, load the best checkpoint according to metric_for_best_model

In [71]:
import evaluate
import numpy as np


def compute_metrics_mrpc(eval_preds):
    metric = evaluate.load("glue", "mrpc")
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)


# %%
from transformers import Trainer

trainer = Trainer(
    model,
    training_args,
    train_dataset=mrpc_dataset["train"],
    eval_dataset=mrpc_dataset["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics_mrpc,
)

# %%
trainer.train()

# %% [markdown]
# ### Computing performance on the test set

# %%
test_predictions = trainer.predict(mrpc_dataset["test"])
print(test_predictions.metrics)

# %%
trainer.state.best_model_checkpoint  # folder where best model is save

KeyError: 'train'