In [1]:
%%capture
#!pip install transformers
#!pip install datasets
#!pip install conllu

In [2]:
import torch
import numpy as np
import seaborn as sns
import torch.nn as nn
from functools import partial
import matplotlib.pyplot as plt
from datasets import load_dataset
from collections import defaultdict
from sklearn.metrics import confusion_matrix, classification_report
from transformers import BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup

## Arc-standard

Recall that a **configuration** of the arc-standard parser is a triple of the form $( \sigma, \beta, A)$
where:

* $\sigma$ is the stack;
* $\beta$ is the input buffer;
* $A$ is a set of arcs constructed so far.

We write $\sigma_i$, $i \geq 1$, for the $i$-th token in the stack; we also write $\beta_i$, $i \geq 1$, for the $i$-th token in the buffer. 

The parser can perform three types of **actions** (transitions):

* **shift**, which removes $\beta_1$ from the buffer and pushes it into the stack;
* **left-arc**, which creates the arc $(\sigma_1 \rightarrow \sigma_2)$, and removes $\sigma_2$ from the stack;
* **right-arc**, which creates the arc $(\sigma_2 \rightarrow \sigma_1)$, and removes $\sigma_1$ from the stack.

Let $w = w_0 w_1 \cdots w_{n}$ be the input sentence, with $w_0$ the special symbol `<ROOT>`.
Stack and buffer are implemented as lists of integers, where `j` represents word $w_j$.  Top-most stack token is at the right-end of the list; first buffer token is at the left-end of the list. 
Set $A$ is implemented as an array `arcs` of size $n+1$ such that if arc $(w_i \rightarrow w_j)$ is in $A$ then `arcs[j]=i`, and if $w_j$ is still missing its head node in the tree under construction, then `arcs[j]=-1`. We always have `arcs[0]=-1`.  We use this representation also for complete dependency trees.


In [3]:
import orig #file orig.py in the same foldr contains original class and function definitions

In [4]:
class ArcStandard(orig.ArcStandard):
    """Arc-standard parser that stores a dependency label next to each head.

    `shift` (and any other method not defined here) is inherited from
    `orig.ArcStandard`.

    Attributes
    ----------
    sentence : the token sequence being parsed (position 0 is the root symbol)
    buffer   : list of token indices still to be read (first element = next token)
    stack    : list of token indices (last element = top of the stack)
    arcs     : list of (head, relation) pairs, one per token; (-1, -1) means
               "no head assigned yet"
    label_set: inverse of the `label_set` argument (integer code -> label name),
               used only for printing
    """

    def __init__(self, sentence, label_set):
        n_tokens = len(sentence)

        self.sentence = sentence
        self.buffer = list(range(n_tokens))
        self.stack = []
        # head - relation converted to int
        self.arcs = [(-1, -1)] * n_tokens

        # Inverted mapping, kept for future reference (printing).
        self.label_set = dict(zip(label_set.values(), label_set.keys()))

        # Initialise the stack: two shifts always, a third one only when the
        # sentence has more than two tokens.
        initial_shifts = 3 if n_tokens > 2 else 2
        for _ in range(initial_shifts):
            self.shift()

    def left_arc(self, deprel):
        """Attach the second stack token to the top one with relation `deprel`."""
        head = self.stack.pop()
        dependent = self.stack.pop()
        self.arcs[dependent] = (head, deprel)
        self.stack.append(head)
        # Keep two tokens on the stack whenever the buffer can still supply one.
        if self.buffer and len(self.stack) < 2:
            self.shift()

    def right_arc(self, deprel):
        """Attach the top stack token to the second one with relation `deprel`."""
        dependent = self.stack.pop()
        head = self.stack.pop()
        self.arcs[dependent] = (head, deprel)
        self.stack.append(head)
        # Keep two tokens on the stack whenever the buffer can still supply one.
        if self.buffer and len(self.stack) < 2:
            self.shift()

    def print_configuration(self):
        """Print stack and buffer contents, then the arcs with readable labels."""
        stack_items = [self.sentence[idx] for idx in self.stack]
        buffer_items = [self.sentence[idx] for idx in self.buffer]
        print("STACK: ", stack_items, "BUFFER: ", buffer_items)
        print([(head, self.label_set[relation]) for head, relation in self.arcs])

In [5]:
class Oracle(orig.Oracle):
    """Oracle that also knows the gold dependency label of every token.

    `is_left_arc_gold` / `is_right_arc_gold` are not defined here and are
    inherited from `orig.Oracle`.
    """

    def __init__(self, parser, gold_tree, gold_labels):
        self.parser = parser
        self.gold = gold_tree
        # Must be integers (label codes), indexed by token position.
        self.labels = gold_labels

    def get_gold_label(self):
        """Return the gold label of the token that the gold arc move would attach.

        For a gold left-arc this is the label of the second stack token; for a
        gold right-arc it is the label of the top stack token.

        Raises
        ------
        Exception
            When neither arc move is gold (a shift creates no labelled arc).
        """
        if self.is_left_arc_gold():
            stack = self.parser.stack
            return self.labels[stack[len(stack) - 2]]
        if self.is_right_arc_gold():
            stack = self.parser.stack
            return self.labels[stack[len(stack) - 1]]
        raise Exception("Action SHIFT does not produce a labeled dependency")
            

## Functions to simplify the processing

In [6]:
def create_deprel_dict(dataset):
    """Map every dependency label found in `dataset` to an integer code.

    `dataset['deprel']` must be an iterable of per-sentence label lists (for
    instance the column of a HuggingFace Dataset).

    * The label '_' (used for composite tokens) is ignored.
    * Labels are numbered 0..k-1 in alphabetical order; the notebook's stated
      rationale is that an MSE loss then penalises "close" classes less.
    * The extra key "None" is mapped to -1 (the code used for the <ROOT> node
      and for tokens without an assigned relation).

    Returns a dict {label: code}.
    """
    distinct_labels = {
        label
        for sentence_labels in dataset['deprel']
        for label in sentence_labels
        if label != '_'
    }
    mapping = {label: code for code, label in enumerate(sorted(distinct_labels))}
    mapping["None"] = -1
    return mapping

def create_gold(sentence, deprel_dict):
    """Build (sanitized_tokens, gold_tree, gold_labels) for one dataset sentence.

    `sentence` must expose the parallel columns 'tokens', 'head' and 'deprel'.
    Entries whose head is the string 'None' (composite tokens) are dropped.
    Heads are converted with int(); labels are converted through `deprel_dict`.

    All three returned lists start with the entry for the artificial root:
    '<ROOT>' for the tokens and -1 for both head and label.
    """
    sanitized, gold_tree, gold_labels = ['<ROOT>'], [-1], [-1]

    for position, token in enumerate(sentence['tokens']):
        head = sentence['head'][position]
        if head == 'None':
            # Composite token: it has no head, so it is not part of the tree.
            continue
        sanitized.append(token)
        gold_tree.append(int(head))
        gold_labels.append(deprel_dict[sentence['deprel'][position]])

    return sanitized, gold_tree, gold_labels

## Testing parser and oracle...

In [7]:
%%capture
train_dataset = load_dataset('universal_dependencies', 'it_isdt', split="train")

In [8]:
# Label -> integer code mapping built from the whole training split.
deprels = create_deprel_dict(train_dataset)
# One sentence used as a worked example in the following cells.
sentence = train_dataset[3]

# Tokens (with <ROOT> prepended), gold heads and gold label codes for that sentence.
tokens, gold_tree, gold_labels = create_gold(sentence, deprels)

print(tokens)
print(gold_tree)
print(gold_labels)

['<ROOT>', 'Inconsueto', 'allarme', 'a', 'la', 'Tate', 'Gallery', ':']
[-1, 2, 0, 5, 5, 2, 5, 2]
[-1, 4, 41, 8, 17, 31, 28, 40]


In [9]:
# Parser over the example sentence; its constructor already performs the initial shifts.
parser = ArcStandard(tokens, deprels)
# Oracle bound to that parser and to the gold heads / label codes.
oracle = Oracle(parser, gold_tree, gold_labels)

# Show the initial configuration (stack, buffer, arcs).
parser.print_configuration()

STACK:  ['<ROOT>', 'Inconsueto', 'allarme'] BUFFER:  ['a', 'la', 'Tate', 'Gallery', ':']
[(-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None')]


In [10]:
# Parse the example sentence by always applying the move the oracle marks as gold,
# printing the configuration after every step.
while not parser.is_tree_final():  # transition precedence implemented here
    if oracle.is_shift_gold():
        parser.shift()
    elif oracle.is_left_arc_gold():
        parser.left_arc(oracle.get_gold_label())
    elif oracle.is_right_arc_gold():
        parser.right_arc(oracle.get_gold_label())
    else:
        # Without this branch the loop would spin forever when no move is gold
        # (the configuration would never change).
        raise RuntimeError("No gold action available; is the gold tree projective?")

    parser.print_configuration()

STACK:  ['<ROOT>', 'allarme'] BUFFER:  ['a', 'la', 'Tate', 'Gallery', ':']
[(-1, 'None'), (2, 'amod'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None')]
STACK:  ['<ROOT>', 'allarme', 'a'] BUFFER:  ['la', 'Tate', 'Gallery', ':']
[(-1, 'None'), (2, 'amod'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None')]
STACK:  ['<ROOT>', 'allarme', 'a', 'la'] BUFFER:  ['Tate', 'Gallery', ':']
[(-1, 'None'), (2, 'amod'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None')]
STACK:  ['<ROOT>', 'allarme', 'a', 'la', 'Tate'] BUFFER:  ['Gallery', ':']
[(-1, 'None'), (2, 'amod'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None'), (-1, 'None')]
STACK:  ['<ROOT>', 'allarme', 'a', 'Tate'] BUFFER:  ['Gallery', ':']
[(-1, 'None'), (2, 'amod'), (-1, 'None'), (-1, 'None'), (5, 'det'), (-1, 'None'), (-1, 'None'), (-1, 'None')]
STACK:  ['<ROOT>', 'allarme', 'Tate'] BUFFER:  ['Gallery', ':']
[(-1, 'No

## Testing the whole dataset

In [11]:
# remember to call orig.is_projective() to check for projectivity!!
# Example: sentence 10 of the training split.
sentence = train_dataset[10]
deprels = create_deprel_dict(train_dataset)

# Only the gold heads are needed for the projectivity test; the cell displays its result.
tokens, gold_tree, gold_labels = create_gold(sentence, deprels)
orig.is_projective(gold_tree)

False

In [12]:
### NOTE: THIS CELL WAS COPIED WITH FEW MODIFICATIONS FROM THE ORIGINAL FILE

# Oracle test: parse every projective training sentence under oracle guidance
# and count how many (head, label) pairs agree with the gold annotation.

non_projective, correct, wrong = 0, 0, 0

deprels = create_deprel_dict(train_dataset)

for sample in train_dataset:
    tokens, gold_tree, gold_labels = create_gold(sample, deprels)

    # Non-projective trees are only counted, never parsed.
    if not orig.is_projective(gold_tree):
        non_projective += 1
        continue

    parser = ArcStandard(tokens, deprels)
    oracle = Oracle(parser, gold_tree, gold_labels)

    # Precedence in this cell: left-arc, then right-arc, then shift.
    while not parser.is_tree_final():
        if oracle.is_left_arc_gold():
            parser.left_arc(oracle.get_gold_label())
        elif oracle.is_right_arc_gold():
            parser.right_arc(oracle.get_gold_label())
        elif oracle.is_shift_gold():
            parser.shift()

    # Compare the (head, label) pair built by the parser with the gold pair, token by token.
    for j, gold_pair in enumerate(zip(gold_tree, gold_labels)):
        if parser.arcs[j] == gold_pair:
            correct += 1
        else:
            wrong += 1

print("non projective: ", non_projective)
print("correct: ", correct)
print("wrong: ", wrong)

non projective:  167
correct:  282533
wrong:  0


## Creating samples for the neural oracle

In [13]:
# Modified from the original
def process_sample(sample, emb_dictionary, deprels, get_gold_path = False):
    """Turn one dataset sentence into model inputs and (optionally) oracle targets.

    Parameters
    ----------
    sample         : one sentence, as an element of a Dataset object
    emb_dictionary : dict word -> integer code; must contain the key "<unk>" if
                     any word of the sentence is missing from it
    deprels        : dict dependency label -> integer code
    get_gold_path  : when True (training only) the oracle-guided parse is run
                     and recorded; otherwise gold_path / gold_moves stay empty

    Returns
    -------
    (enc_sentence, gold_path, gold_moves, gold_tree, gold_labels) where
      enc_sentence : list of integer word codes taken from emb_dictionary
      gold_path    : list of configurations [second stack token, top stack token,
                     first buffer token or -1 when the buffer is empty]
      gold_moves   : list of (move, label) pairs parallel to gold_path, with
                     move 100 = left-arc, 0 = right-arc, -100 = shift; a shift
                     is stored as (-100, -100)
      gold_tree    : gold heads as returned by create_gold
      gold_labels  : gold label codes as returned by create_gold

    NOTE(review): the integer codes come from emb_dictionary, not from a BERT
    tokenizer -- confirm that this encoding is what the BERT model should receive.
    """
    LEFT_ARC, RIGHT_ARC, SHIFT = 100, 0, -100

    tokens, gold_tree, gold_labels = create_gold(sample, deprels)

    # Unknown words fall back to the "<unk>" entry.
    enc_sentence = []
    for word in tokens:
        key = word if word in emb_dictionary else "<unk>"
        enc_sentence.append(emb_dictionary[key])

    # Parallel lists: one entry per parsing step.
    gold_path = []
    gold_moves = []

    if not get_gold_path:
        return enc_sentence, gold_path, gold_moves, gold_tree, gold_labels

    parser = ArcStandard(enc_sentence, deprels)
    oracle = Oracle(parser, gold_tree, gold_labels)

    while not parser.is_tree_final():
        # Record the configuration *before* applying the move.
        stack = parser.stack
        configuration = [
            stack[len(stack) - 2],
            stack[len(stack) - 1],
            parser.buffer[0] if parser.buffer else -1,
        ]
        gold_path.append(configuration)

        if oracle.is_left_arc_gold():
            label_code = oracle.get_gold_label()
            parser.left_arc(label_code)
            gold_moves.append((LEFT_ARC, label_code))
        elif oracle.is_right_arc_gold():
            label_code = oracle.get_gold_label()
            parser.right_arc(label_code)
            gold_moves.append((RIGHT_ARC, label_code))
        elif oracle.is_shift_gold():
            parser.shift()
            gold_moves.append((SHIFT, SHIFT))
        else:
            # No move is gold: report the stuck configuration and abort.
            print("**** AN ERROR OCCURRED ****")
            parser.print_configuration()
            raise Exception("No action identified as gold!! Please make sure your dataset doesn't contain nonprojective trees.")

    return enc_sentence, gold_path, gold_moves, gold_tree, gold_labels


In [14]:
### Code to test process_sample
"""
test_sentence = train_dataset[88]
emb_dictionary = orig.create_dict(train_dataset)
deprels = create_deprel_dict(train_dataset)

process_sample(test_sentence, emb_dictionary, deprels, get_gold_path = True)
"""

'\ntest_sentence = train_dataset[88]\nemb_dictionary = orig.create_dict(train_dataset)\ndeprels = create_deprel_dict(train_dataset)\n\nprocess_sample(test_sentence, emb_dictionary, deprels, get_gold_path = True)\n'

In [15]:
#modified from orig
def prepare_batch(batch_data, emb_dictionary, deprels, get_gold_path=False):
    """Collate a batch of sentences into five parallel lists.

    Every element of `batch_data` is passed to process_sample; the five fields
    it returns are then regrouped per field instead of per sentence.

    @return sentences : sentences encoded according to emb_dictionary
    @return paths     : per sentence, the configurations visited during the
                        oracle-guided parse | empty lists if get_gold_path is False
    @return moves     : per sentence, the (MOVE, LABEL) pairs performed during the
                        oracle-guided parse | empty lists if get_gold_path is False
    @return trees     : ground truth trees
    @return labels    : ground truth labels

    Note: the last two returned values may eventually be zipped together for
    faster loss calculation.
    """
    processed = [
        process_sample(sample, emb_dictionary, deprels, get_gold_path=get_gold_path)
        for sample in batch_data
    ]
    # Transpose: field k of every sentence goes into the k-th output list.
    sentences, paths, moves, trees, labels = (
        [item[field] for item in processed] for field in range(5)
    )
    return sentences, paths, moves, trees, labels

In [16]:
### code to test prepare_batch
"""
emb_dictionary = orig.create_dict(train_dataset)
deprels = create_deprel_dict(train_dataset)

_ = prepare_batch(iter(train_dataset), emb_dictionary, deprels, True)
"""

'\nemb_dictionary = orig.create_dict(train_dataset)\ndeprels = create_deprel_dict(train_dataset)\n\n_ = prepare_batch(iter(train_dataset), emb_dictionary, deprels, True)\n'

In [17]:
def find_projective_idx(dataset):
    """Return the indices of the sentences in `dataset` whose gold tree is projective.

    The result is meant to be passed to select() in order to filter out
    non-projective trees. Heads equal to the string 'None' are skipped and a
    -1 head is prepended for the root before calling orig.is_projective.
    """
    def heads_with_root(row):
        return [-1] + [int(head) for head in row["head"] if head != 'None']

    return [
        index
        for index in range(len(dataset))
        if orig.is_projective(heads_with_root(dataset[index]))
    ]

In [18]:
# Load the train / validation / test splits of the 'it_isdt' configuration.
tmp_train_dataset = load_dataset('universal_dependencies', 'it_isdt', split="train")
dev_dataset = load_dataset('universal_dependencies', 'it_isdt', split="validation")
test_dataset = load_dataset('universal_dependencies', 'it_isdt', split="test")

# Only the training split is filtered; dev and test keep their non-projective sentences.
train_dataset = tmp_train_dataset.select(find_projective_idx(tmp_train_dataset)) #to remove nonprojective trees

Reusing dataset universal_dependencies (/home/filippo/.cache/huggingface/datasets/universal_dependencies/it_isdt/2.7.0/065e728dfe9a8371434a6e87132c2386a6eacab1a076d3a12aa417b994e6ef7d)
Reusing dataset universal_dependencies (/home/filippo/.cache/huggingface/datasets/universal_dependencies/it_isdt/2.7.0/065e728dfe9a8371434a6e87132c2386a6eacab1a076d3a12aa417b994e6ef7d)
Reusing dataset universal_dependencies (/home/filippo/.cache/huggingface/datasets/universal_dependencies/it_isdt/2.7.0/065e728dfe9a8371434a6e87132c2386a6eacab1a076d3a12aa417b994e6ef7d)


In [19]:
# Word -> integer code dictionary (built by the helper in orig.py) and
# dependency label -> integer code dictionary, both from the filtered training split.
emb_dictionary = orig.create_dict(train_dataset)
deprels = create_deprel_dict(train_dataset)

In [20]:
BATCH_SIZE = 32

# prepare_batch is used as collate function, so every batch is the 5-tuple
# (sentences, paths, moves, trees, labels).
# Only the training loader asks for the gold path. The dev and test loaders use the
# default get_gold_path=False, so their `paths` and `moves` are lists of empty lists.
# NOTE(review): a loss against `moves` therefore cannot be computed on dev/test batches as configured here.
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=partial(prepare_batch, emb_dictionary = emb_dictionary, deprels = deprels, get_gold_path=True))
dev_dataloader = torch.utils.data.DataLoader(dev_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=partial(prepare_batch, emb_dictionary = emb_dictionary, deprels = deprels))
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=partial(prepare_batch, emb_dictionary = emb_dictionary, deprels = deprels))

In [21]:
# Fetch a single training batch to inspect what prepare_batch produces.
sentences, paths, moves, trees, labels = next(iter(train_dataloader)) #<--- abnormal termination caused by RAM saturation

DEBUG: length of path:  36 length of moves:  36
DEBUG: length of path:  20 length of moves:  20
DEBUG: length of path:  46 length of moves:  46
DEBUG: length of path:  28 length of moves:  28
DEBUG: length of path:  28 length of moves:  28
DEBUG: length of path:  70 length of moves:  70
DEBUG: length of path:  48 length of moves:  48
DEBUG: length of path:  98 length of moves:  98
DEBUG: length of path:  22 length of moves:  22
DEBUG: length of path:  4 length of moves:  4
DEBUG: length of path:  16 length of moves:  16
DEBUG: length of path:  10 length of moves:  10
DEBUG: length of path:  20 length of moves:  20
DEBUG: length of path:  76 length of moves:  76
DEBUG: length of path:  34 length of moves:  34
DEBUG: length of path:  36 length of moves:  36
DEBUG: length of path:  12 length of moves:  12
DEBUG: length of path:  42 length of moves:  42
DEBUG: length of path:  30 length of moves:  30
DEBUG: length of path:  62 length of moves:  62
DEBUG: length of path:  18 length of moves

## Ideas for the future
* Use mean square error to predict pairs (MOVE, LABEL) in the training set
* We need to rescale the data in an acceptable range. Probably should make the MOVE have greater variation, e.g LeftArc=100 RightArc=0 Shift=-100
* BERT layer must process each sentence separately to produce word embeddings, but it should also be finetuned alongside the fully connected layer

# Neural Oracle

Neural oracle implemented using BERT fine-tuned with a simple classifier.
The references are: [Towards Data Science](https://towardsdatascience.com/text-classification-with-bert-in-pytorch-887965e5820f) and [Text Classification | Sentiment Analysis with BERT using huggingface, PyTorch and Python Tutorial](https://www.youtube.com/watch?v=8N-nM3QW7O0).

The hyperparameters and the training tweaks are taken from the original BERT paper.

In [76]:
# Hyperparameters
bert_model = 'dbmdz/bert-base-italian-xxl-cased'  # checkpoint name given to from_pretrained (tokenizer and model)
dropout = 0.3  # dropout rate applied to BERT's pooled output inside BertOracle
n_classes = 2  # size of the model's output layer; presumably one value each for MOVE and LABEL -- TODO confirm
n_classes_test = 1  # only forwarded to eval_model for the test split in this notebook
learning_rate = 2e-5
epochs = 1
save_path = "path_were_to_save_model"  # placeholder: destination passed to torch.save
tokenizer_max_len = 100  # max_length used when tokenizing a batch (longer inputs are truncated)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # first CUDA device when available, else CPU

In [28]:
# load the BERT tokenizer with pretrained weights for the Italian language
tokenizer = BertTokenizer.from_pretrained(bert_model)

Downloading:   0%|          | 0.00/230k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/59.0 [00:00<?, ?B/s]

In [29]:
example_text = '''Con il termine interfluvio in geologia si indica la porzione di superﬁcie più elevata
                  che separa due valli ﬂuviali adiacenti, che può essere una cresta oppure un' area ampia,
                  comunque non coinvolta dal movimento delle acque'''


bert_input = tokenizer(example_text, padding='max_length', max_length = 15, 
                       truncation=True, return_tensors="pt")

print(bert_input['input_ids'])      # word pieces converted to vocabulary ids
print(bert_input['token_type_ids']) # segment ids (which input sequence a token belongs to); all 0 here because a single text is encoded
print(bert_input['attention_mask']) # identifies whether a token is a real word or just padding

# Map the ids of the first (and only) sequence back to text.
decoded_text = tokenizer.decode(bert_input.input_ids[0])

print(decoded_text)

tensor([[  102,   401,   162,  1909,   532, 27948,  2369,   139,  5101,  1783,
           223,  1731,   146, 16711,   103]])
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
[CLS] Con il termine interfluvio in geologia si indica la porzione [SEP]


## Prepare data to be fed into the BERT Oracle using the tokenizer

  ## Oracle model using BERT
  The oracle, as we intended, is just a classifier that, for each possible configuration provided, outputs the right pair of (MOVE, LABEL) as if they were simple classes.

In [80]:
class BertOracle(nn.Module):
  """BERT encoder followed by dropout and a single linear output layer.

  Parameters
  ----------
  dropout_rate : dropout probability applied to BERT's pooled output
  n_classes    : number of output units of the linear layer
  bert_model   : name/path of the pretrained checkpoint to load
  """

  def __init__(self, dropout_rate, n_classes, bert_model):
    super().__init__()

    # return_dict=False makes the encoder return a plain tuple.
    self.bert = BertModel.from_pretrained(bert_model, return_dict=False)
    self.dropout = nn.Dropout(dropout_rate)
    self.out = nn.Linear(self.bert.config.hidden_size, n_classes)

  def forward(self, input_tokens, attention_mask):
    """Return the linear layer's output computed from BERT's pooled output."""
    # The encoder yields exactly two values; only the second (pooled) one is used.
    sequence_output, pooled_output = self.bert(
        input_ids = input_tokens,
        attention_mask = attention_mask
    )
    return self.out(self.dropout(pooled_output))


In [None]:
model = BertOracle(dropout, n_classes, bert_model)
model.to(device)

## Training and evaluation functions
Linear scheduler: used to decay the learning rate to improve training performance.

In [32]:
# AdamW over every model parameter (BERT encoder and output layer alike).
optimizer = AdamW(model.parameters(), lr = learning_rate, correct_bias = False)

# One scheduler step is taken per training batch, so the schedule spans
# (number of training batches) * (number of epochs) steps, with no warm-up.
total_steps = len(train_dataloader) * epochs
scheduler  = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps = 0,
    num_training_steps = total_steps
)

# Mean squared error between the model outputs and the numeric (MOVE, LABEL) targets.
loss_fn = nn.MSELoss().to(device) 



In [101]:
def train_epoch(model, data_loader, loss_fn, optimizer, device, scheduler, n_examples):
  """Run one training epoch and return the mean of the per-batch losses.

  Parameters
  ----------
  model       : module called as model(input_tokens=..., attention_mask=...)
  data_loader : iterable of batches (sentences, paths, moves, trees, labels)
  loss_fn     : loss applied to (model output, moves tensor)
  optimizer   : optimizer stepped once per batch
  device      : device the input and target tensors are moved to
  scheduler   : learning-rate scheduler stepped once per batch
  n_examples  : unused; kept for interface compatibility

  Returns
  -------
  Mean batch loss, or NaN when the loader yields no batch.

  Relies on the notebook-level names `tokenizer` and `tokenizer_max_len`.
  """
  model = model.train()
  losses = []

  for batch in data_loader:
    optimizer.zero_grad()                                     # reset gradients accumulated by the previous batch
    sentences, paths, moves, trees, labels = batch            # unpack the batch (created above w/ prepare_batch)
    # NOTE(review): `moves` holds one list of (move, label) pairs per sentence and
    # those lists differ in length, so this conversion fails on real training
    # batches; the targets probably need to be flattened per configuration.
    moves = torch.FloatTensor(moves).to(device)

    # NOTE(review): str(sentences) turns the whole batch of integer codes into ONE
    # string, so the tokenizer produces a single sequence and the model a single
    # output row. The batching of the inputs needs to be redesigned.
    preprocessed_text = tokenizer(str(sentences), padding='max_length', max_length = tokenizer_max_len,
                       truncation=True, return_tensors="pt")  # tokenize input using BERT tokenizer
    # Inputs must live on the same device as the model.
    input_tokens = preprocessed_text["input_ids"].to(device)
    attention_mask = preprocessed_text["attention_mask"].to(device)

    outputs = model(
        input_tokens = input_tokens,
        attention_mask = attention_mask
    )

    loss = loss_fn(outputs, moves)

    losses.append(loss.item())

    # backpropagation routine
    loss.backward()
    # In-place gradient clipping (the non-underscore variant is deprecated).
    nn.utils.clip_grad_norm_(model.parameters(), max_norm = 1.0)
    optimizer.step()
    scheduler.step()

  # Avoid numpy's "mean of empty slice" warning for an empty loader.
  mean_loss = np.mean(losses) if losses else float("nan")

  return mean_loss

In [68]:
def eval_model(model, data_loader, loss_fn, optimzier, device, scheduler, n_examples):
  """Compute the mean per-batch loss of `model` on `data_loader` without training.

  Parameters
  ----------
  model       : module called as model(input_tokens=..., attention_mask=...)
  data_loader : iterable of batches (sentences, paths, moves, trees, labels)
  loss_fn     : loss applied to (model output, moves tensor)
  optimzier   : unused (the misspelled name is kept so existing calls still work)
  device      : device the input and target tensors are moved to
  scheduler   : unused; kept for interface compatibility
  n_examples  : unused; kept for interface compatibility

  Returns
  -------
  Mean batch loss, or NaN when the loader yields no batch.

  Relies on the notebook-level names `tokenizer` and `tokenizer_max_len`.
  """
  model = model.eval()                                         # switch dropout to inference behaviour

  losses = []

  # No gradients are computed here, so no optimizer call is needed.
  with torch.no_grad():
    for batch in data_loader:
      sentences, paths, moves, trees, labels = batch            # unpack the batch (created above w/ prepare_batch)
      # NOTE(review): `moves` is only filled when the loader was built with
      # get_gold_path=True; otherwise this tensor is empty and the loss call fails.
      moves = torch.FloatTensor(moves).to(device)

      # NOTE(review): str(sentences) encodes the whole batch as ONE string, so a
      # single sequence is fed to the model; the input batching needs a redesign.
      preprocessed_text = tokenizer(str(sentences), padding='max_length', max_length = tokenizer_max_len,
                        truncation=True, return_tensors="pt")  # tokenize input using BERT tokenizer
      # Inputs must live on the same device as the model.
      input_tokens = preprocessed_text["input_ids"].to(device)
      attention_mask = preprocessed_text["attention_mask"].to(device)

      outputs = model(
          input_tokens = input_tokens,
          attention_mask = attention_mask
      )

      loss = loss_fn(outputs, moves)

      losses.append(loss.item())

  # Avoid numpy's "mean of empty slice" warning for an empty loader.
  mean_loss = np.mean(losses) if losses else float("nan")

  return mean_loss

## Training loop

In [102]:
%time                      

history = defaultdict(list) # store training and validation losses and accuracy in a dictionary 
best_accuracy = 0           # save model with best accuracy only (can be used for early stopping)

for epoch in range(epochs):
  print(f'Epoch {epoch + 1} / {epochs}')
  print('-'*10)

  #Training step
  train_acc, train_loss = train_epoch(model, dev_dataloader, loss_fn, optimizer, device, scheduler, n_classes)

  print(f"Train loss {train_loss} Train accuracy {train_acc}")

  #Evaluation step
  val_acc, val_loss = eval_model(model, dev_dataloader, loss_fn, optimizer, device, scheduler, n_classes)

  print(f"Val loss {val_loss} Val accuracy {val_acc}")

  print()

  history['train_acc'].append(train_acc)
  history['train_loss'].append(train_loss)

  history['val_acc'].append(val_acc)
  history['val_loss'].append(val_loss)

  if val_acc > best_accuracy:
    torch.save(model, save_path)
    best_accuracy = val_acc

CPU times: user 2 µs, sys: 0 ns, total: 2 µs
Wall time: 5.48 µs
Epoch 1 / 1
----------
<class 'torch.Tensor'>
torch.Size([1, 2])
tensor([], size=(32, 0))


  return F.mse_loss(input, target, reduction=self.reduction)


RuntimeError: ignored

## Evaluation

In [None]:
def get_label_action_pairs(model, data_loader):
  """Collect the target moves and the model outputs over a whole data loader.

  Parameters
  ----------
  model       : module called as model(input_tokens=..., attention_mask=...)
  data_loader : iterable of batches (sentences, paths, moves, trees, labels)

  Returns
  -------
  (label_action_pairs, predictions) where `label_action_pairs` is a list with
  the rows of every batch's moves tensor and `predictions` is a CPU tensor
  stacking every output row (an empty tensor when nothing was predicted).

  Relies on the notebook-level names `tokenizer`, `tokenizer_max_len` and `device`.
  """
  model = model.eval()

  label_action_pairs = []
  predictions = []

  # Pure inference: no gradients and therefore no optimizer call.
  with torch.no_grad():
    for batch in data_loader:
      sentences, paths, moves, trees, labels = batch            # unpack the batch (created above w/ prepare_batch)
      moves = torch.FloatTensor(moves)

      # NOTE(review): str(sentences) encodes the whole batch as ONE string, so a
      # single sequence is fed to the model; the input batching needs a redesign.
      preprocessed_text = tokenizer(str(sentences), padding='max_length', max_length = tokenizer_max_len,
                        truncation=True, return_tensors="pt")  # tokenize input using BERT tokenizer
      # Inputs must live on the same device as the model.
      input_tokens = preprocessed_text["input_ids"].to(device)
      attention_mask = preprocessed_text["attention_mask"].to(device)

      outputs = model(
          input_tokens = input_tokens,
          attention_mask = attention_mask
      )

      # list.extend (not "extends") appends the rows one by one.
      label_action_pairs.extend(moves)
      predictions.extend(outputs.cpu())

  # torch.stack rejects an empty list, so fall back to an empty tensor.
  predictions = torch.stack(predictions).cpu() if predictions else torch.empty(0)
  return label_action_pairs, predictions

In [None]:
# eval_model takes (model, data_loader, loss_fn, optimizer, device, scheduler, n_examples)
# positionally and returns a single mean loss.
test_loss = eval_model(model, test_dataloader, loss_fn, optimizer, device, scheduler, n_classes_test)

In [None]:
y_real_moves_pairs, y_pred_moves_pairs = get_label_action_pairs(model, test_dataloader)