In [1]:
# add parent dir to sys path for import of modules
import os
import sys
parentdir = os.path.abspath(os.path.join(os.path.abspath(''), os.pardir))
sys.path.insert(0, parentdir) 

# TF Data Preparation

In [2]:
import argparse
from PetReader import pet_reader
import tensorflow as tf
import tensorflow_addons  as tfa
import transformers
from petreader.labels import *

INFO:PetReader:Load RelationsExtraction dataset ...




 _______ _     _ _______       _____  _______ _______      ______  _______ _______ _______ _______ _______ _______
    |    |_____| |______      |_____] |______    |         |     \ |_____|    |    |_____| |______ |______    |   
    |    |     | |______      |       |______    |         |_____/ |     |    |    |     | ______| |______    |   
                                                                                                                  
Discover more at: [https://pdi.fbk.eu/pet-dataset/]



  0%|          | 0/1 [00:00<?, ?it/s]

INFO:PetReader:Load TokenClassification dataset ...




 _______ _     _ _______       _____  _______ _______      ______  _______ _______ _______ _______ _______ _______
    |    |_____| |______      |_____] |______    |         |     \ |_____|    |    |_____| |______ |______    |   
    |    |     | |______      |       |______    |         |_____/ |     |    |    |     | ______| |______    |   
                                                                                                                  
Discover more at: [https://pdi.fbk.eu/pet-dataset/]



  0%|          | 0/1 [00:00<?, ?it/s]

In [3]:
import argparse

## Create Dataset

In [98]:
XOR_LABEL = 2
AND_LABEL = 3

In [4]:
tokenizer = transformers.DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")

In [5]:
print(pet_reader.token_dataset.GetSampleDictWithNerLabels(316))

{'document name': 'doc-5.4', 'sentence-ID': 12, 'tokens': ['If', 'the', 'treasurer', 'accepts', 'the', 'expenses', 'for', 'processing', ',', 'the', 'report', 'moves', 'to', 'an', 'automatic', 'activity', 'that', 'links', 'to', 'a', 'payment', 'system', '.'], 'ner-tags': ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']}


In [6]:
WEIGHTS_BERT_TOKENS = 0
WEIGHTS_GATEWAY_LABELS = 1
WEIGHTS_OTHER = .1

sample_numbers = pet_reader.token_dataset.GetRandomizedSampleNumbers()
sample_dicts = [pet_reader.token_dataset.GetSampleDictWithNerLabels(sample_number) for sample_number in sample_numbers]
sample_sentences = [sample_dict['tokens'] for sample_dict in sample_dicts]

# 1) transform tokens tags into IDs classification
dataset_tokens = tokenizer(sample_sentences, is_split_into_words=True, padding=True, return_tensors='tf')
max_sentence_length = dataset_tokens['input_ids'].shape[1]

# 2) transform NER token tags into labels for classification
dataset_labels = []
dataset_sample_weights = []
dataset_word_ids = []
for i, sample_number in enumerate(sample_numbers):
    sample_dict = pet_reader.token_dataset.GetSampleDictWithNerLabels(sample_number)
    # transformer_tokens = tokenizer.convert_ids_to_tokens(tokens['input_ids'][i])
    # tokenize again every single sample to get access to .word_ids()
    tokenization = tokenizer(sample_dict['tokens'], is_split_into_words=True, 
                             padding='max_length', max_length=max_sentence_length, return_tensors='tf')
    sample_tokens = tokenizer.convert_ids_to_tokens(tokenization['input_ids'][0])
    
    sample_labels = []
    sample_sample_weights = []
    # word index necessary, because one token in PET could be splitted into multiple tokens with tokenizer
    # multiple tokens have all the same word_id -> allows retrieval of the same one NER label from PET tokens
    for token, word_index in zip(sample_tokens, tokenization.word_ids()):
        # set special class for special tokens
        if token in ['[CLS]', '[SEP]', '[PAD]']:
            sample_labels.append(0)
            sample_sample_weights.append(WEIGHTS_BERT_TOKENS)
        else:
            
            token_tag = sample_dict['ner-tags'][word_index]
            # XOR
            if token_tag.endswith(XOR_GATEWAY):
                sample_labels.append(XOR_LABEL)  # 2
                sample_sample_weights.append(1)
            # AND
            elif token_tag.endswith(AND_GATEWAY):
                sample_labels.append(AND_LABEL)  # 3
                sample_sample_weights.append(1)
            else:
                sample_labels.append(1)
                sample_sample_weights.append(WEIGHTS_OTHER)
#             # Other
#             elif token_tag.endswith("O"):
#                 sample_labels.append(1)
#                 sample_sample_weights.append(WEIGHTS_OTHER)
#             # Activity
#             elif token_tag.endswith(ACTIVITY):
#                 sample_labels.append(4)
#                 sample_sample_weights.append(WEIGHTS_OTHER)
#             # Activity Data
#             elif token_tag.endswith(ACTIVITY_DATA):
#                 sample_labels.append(5)
#                 sample_sample_weights.append(WEIGHTS_OTHER)
#             # Actor
#             elif token_tag.endswith(ACTOR):
#                 sample_labels.append(6)
#                 sample_sample_weights.append(WEIGHTS_OTHER)
#             # Further Specifications
#             elif token_tag.endswith(FURTHER_SPECIFICATION):
#                 sample_labels.append(7)
#                 sample_sample_weights.append(WEIGHTS_OTHER)
#             # Condition Specifications
#             elif token_tag.endswith(CONDITION_SPECIFICATION):
#                 sample_labels.append(8)
#                 sample_sample_weights.append(WEIGHTS_OTHER)
    dataset_sample_weights.append(sample_sample_weights)            
    dataset_labels.append(sample_labels)
    dataset_word_ids.append(tokenization.word_ids())

dataset_labels = tf.constant(dataset_labels)
dataset_sample_weights = tf.constant(dataset_sample_weights)
dataset_tokens = dataset_tokens
print(dataset_labels.shape)
print(dataset_sample_weights.shape)
print(dataset_tokens['input_ids'].shape)

(417, 64)
(417, 64)
(417, 64)


In [7]:
dataset_sample_weights

<tf.Tensor: shape=(417, 64), dtype=float32, numpy=
array([[0. , 1. , 0.1, ..., 0. , 0. , 0. ],
       [0. , 0.1, 0.1, ..., 0. , 0. , 0. ],
       [0. , 0.1, 0.1, ..., 0. , 0. , 0. ],
       ...,
       [0. , 1. , 0.1, ..., 0. , 0. , 0. ],
       [0. , 0.1, 0.1, ..., 0. , 0. , 0. ],
       [0. , 0.1, 0.1, ..., 0. , 0. , 0. ]], dtype=float32)>

In [8]:
# split up
val_share = 1-0.1
val_instances = round(dataset_tokens['input_ids'].shape[0] * val_share)

def create_splitted_datasets(tokens, labels, sample_weights, split_index, batch_size=None):
    print(labels)
    tokens1, tokens2 = {k: v[:split_index] for k, v in tokens.items()}, {k: v[split_index:] for k, v in tokens.items()}
    labels1, labels2 = labels[:split_index], labels[split_index:]
    sample_weights1, sample_weights2 = sample_weights[:split_index], sample_weights[split_index:]
    
    create_dataset = lambda t, l, w: tf.data.Dataset.from_tensor_slices(({'input_ids': t['input_ids'], 'attention_mask': t['attention_mask']}, 
                                                                         l, w))
    train_dataset = create_dataset(tokens1, labels1, sample_weights1)
    dev_dataset = create_dataset(tokens2, labels2, sample_weights2)
    
    if batch_size:
        train_dataset = train_dataset.batch(batch_size)
        dev_dataset = dev_dataset.batch(batch_size)
    
    return train_dataset, dev_dataset

train_dataset, dev_dataset = create_splitted_datasets(dataset_tokens, dataset_labels, dataset_sample_weights, 
                                                      val_instances, batch_size=8)
print(f"{len(train_dataset)} / {len(dev_dataset)}")

tf.Tensor(
[[0 2 1 ... 0 0 0]
 [0 1 1 ... 0 0 0]
 [0 1 1 ... 0 0 0]
 ...
 [0 2 1 ... 0 0 0]
 [0 1 1 ... 0 0 0]
 [0 1 1 ... 0 0 0]], shape=(417, 64), dtype=int32)
47 / 6


In [9]:
for i, s in enumerate(dev_dataset):
    if i == 1:
        print(len(sample_sentences[val_instances + i]), sample_sentences[val_instances + i])
        t, l, w = s
        print(t)
        print(l)
        print(w)
        print()

16 ['The', 'process', 'is', 'triggered', 'by', 'the', 'demand', 'of', 'a', 'functional', 'department', 'to', 'fill', 'a', 'post', '.']
{'input_ids': <tf.Tensor: shape=(8, 64), dtype=int32, numpy=
array([[  101,  1996,  2175,  4618,  2025,  8757,  1996,  5796,  2361,
         2055,  1996,  5080,  3431,  1010,  1996,  3040,  2951,  1010,
         1996,  8316,  4175,  2012,  4487, 25855, 16671,  2075,  1010,
         1998,  1996,  8316,  4175,  2012,  8272,  1012,   102,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0],
       [  101,  2016,  2059, 14148,  1996,  1999,  6767,  6610,  5167,
         1998,  3964,  1996,  7978,  3465,  2415,  2012,  1996,  9353,
         4168, 12943,  1998,  1996,  3141,  3465,  2415, 10489,  2005,
         2169,  2597,  2006,  1037,  3584,  2433,  1006,  1036,  1036,
       

In [10]:
print(dev_dataset)

<BatchDataset element_spec=({'input_ids': TensorSpec(shape=(None, 64), dtype=tf.int32, name=None), 'attention_mask': TensorSpec(shape=(None, 64), dtype=tf.int32, name=None)}, TensorSpec(shape=(None, 64), dtype=tf.int32, name=None), TensorSpec(shape=(None, 64), dtype=tf.float32, name=None))>


## Create Model

In [11]:
num_labels = 4

In [12]:
# user num_labels=4 when classifying only in XOR, AND and OTHER; or num_labels=9 if in all classes
token_cls_model = transformers.TFAutoModelForTokenClassification.from_pretrained("distilbert-base-uncased", num_labels=num_labels)

Some layers from the model checkpoint at distilbert-base-uncased were not used when initializing TFDistilBertForTokenClassification: ['activation_13', 'vocab_transform', 'vocab_layer_norm', 'vocab_projector']
- This IS expected if you are initializing TFDistilBertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFDistilBertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some layers of TFDistilBertForTokenClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier', 'dropout_19']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inferenc

In [13]:
type(token_cls_model)
token_cls_model.summary()
token_cls_model.layers[2].activation

Model: "tf_distil_bert_for_token_classification"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 distilbert (TFDistilBertMai  multiple                 66362880  
 nLayer)                                                         
                                                                 
 dropout_19 (Dropout)        multiple                  0         
                                                                 
 classifier (Dense)          multiple                  3076      
                                                                 
Total params: 66,365,956
Trainable params: 66,365,956
Non-trainable params: 0
_________________________________________________________________


<function keras.activations.linear(x)>

### a) including Model in native Keras

In [114]:
import argparse
import os
import datetime
import re
from keras import backend as K

parser = argparse.ArgumentParser()
# Standard params
parser.add_argument("--batch_size", default=8, type=int, help="Batch size.")
parser.add_argument("--epochs", default=1, type=int, help="Number of epochs.")
parser.add_argument("--seed", default=42, type=int, help="Random seed.")
# Architecture params
parser.add_argument("--huggingface_model_name", default="distilbert-base-uncased", type=str, help="Model checkpoint")

args = parser.parse_args([] if "__file__" not in globals() else None)


class GatewayTokenClassifier(tf.keras.Model):

    def __init__(self, args: argparse.Namespace, model, train_dataset: tf.data.Dataset) -> None:

        # A) OPTIMIZER
        optimizer, lr_schedule = transformers.create_optimizer(
            init_lr=2e-5,
            num_train_steps=(len(train_dataset) // args.batch_size) * args.epochs,
            weight_decay_rate=0.01,
            num_warmup_steps=0,
        )

        # B) ARCHITECTURE
        inputs = {
            "input_ids": tf.keras.layers.Input(shape=[None], dtype=tf.int32),
            "attention_mask": tf.keras.layers.Input(shape=[None], dtype=tf.int32)
        }
        bert_output = model(inputs).logits  # includes one linear dense layer
        predictions = tf.keras.layers.Dense(num_labels, activation=tf.nn.softmax)(bert_output)
        super().__init__(inputs=inputs, outputs=predictions)

        # C) COMPILE
        
        # create custom metrics
        def filter_y_for_target_label(y_true, y_pred, target_label):
            y_true = tf.cast(y_true, tf.int32)
            y_pred = tf.cast(y_pred, tf.int32)
            y_true_filtered = tf.where(tf.equal(y_true, target_label), 1, tf.zeros_like(y_true))
            y_pred_filtered = tf.where(tf.equal(y_pred, target_label), 1, tf.zeros_like(y_pred))
            return y_true_filtered, y_pred_filtered
        
        def TPs(y_true, y_pred):
            """ assume y_true and y_pred with 0 and 1 values for binary classification """
            return K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
        
        def f1(precision, recall):
            return 2 * ( tf.math.multiply(precision, recall ) / ( precision + recall + K.epsilon() ) )
        
        def xor_precision(y_true, y_pred):
            y_pred = tf.math.argmax(y_pred, axis=2)
            y_true_filtered, y_pred_filtered = filter_y_for_target_label(y_true, y_pred, XOR_LABEL)
            true_positives = TPs(y_true_filtered, y_pred_filtered)
            predicted_positives = K.sum(K.round(K.clip(y_pred_filtered, 0, 1)))
            if predicted_positives == 0:
                return tf.constant(0.0, dtype=tf.float32)
            else:
                return tf.cast(true_positives / (predicted_positives), tf.float32)

        def xor_recall(y_true, y_pred):
            y_pred = tf.math.argmax(y_pred, axis=2)
            y_true_filtered, y_pred_filtered = filter_y_for_target_label(y_true, y_pred, XOR_LABEL)
            true_positives = TPs(y_true_filtered, y_pred_filtered)
            real_positives = K.sum(K.round(K.clip(y_true_filtered, 0, 1)))
            if real_positives == 0:
                return tf.constant(0.0, dtype=tf.float32)
            else:
                return tf.cast(true_positives / (real_positives), tf.float32)

        def xor_f1(y_true, y_pred):
            precision = xor_precision(y_true, y_pred)
            recall = xor_recall(y_true, y_pred)
            return tf.cast(f1(precision, recall), tf.float32)
        
        def and_precision(y_true, y_pred):
            y_pred = tf.math.argmax(y_pred, axis=2)
            y_true_filtered, y_pred_filtered = filter_y_for_target_label(y_true, y_pred, AND_LABEL)
            true_positives = TPs(y_true_filtered, y_pred_filtered)
            predicted_positives = K.sum(K.round(K.clip(y_pred_filtered, 0, 1)))
            if predicted_positives == 0:
                return tf.constant(0.0, dtype=tf.float32)
            else:
                return tf.cast(true_positives / (predicted_positives), tf.float32)

        def and_recall(y_true, y_pred):
            y_pred = tf.math.argmax(y_pred, axis=2)
            y_true_filtered, y_pred_filtered = filter_y_for_target_label(y_true, y_pred, AND_LABEL)
            true_positives = TPs(y_true_filtered, y_pred_filtered)
            real_positives = K.sum(K.round(K.clip(y_true_filtered, 0, 1)))
            if real_positives == 0:
                return tf.constant(0.0, dtype=tf.float32)
            else:
                return tf.cast(true_positives / (real_positives), tf.float32)

        def and_f1(y_true, y_pred):
            precision = and_precision(y_true, y_pred)
            recall = and_recall(y_true, y_pred)
            return tf.cast(f1(precision, recall), tf.float32)
        
        # actual compile
        
        self.compile(optimizer=optimizer,
                     #loss=custom_loss,
                     loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                     # general accuracy of all labels (except 0 class for padding tokens)
                     weighted_metrics=[tf.metrics.SparseCategoricalAccuracy(name="Overall Accuracy")],
                     # metrics for classes of interest
                     metrics=[xor_precision, xor_recall, xor_f1, and_recall, and_precision, and_f1])
        #token_cls_model.summary()
        #self.summary()


# Create logdir name
args.logdir = os.path.join("data/logs", "{}-{}-{}".format(
    os.path.basename(globals().get("__file__", "notebook")),
    datetime.datetime.now().strftime("%Y-%m-%d_%H%M%S"),
    ",".join(("{}={}".format(re.sub("(.)[^_]*_?", r"\1", k), v) for k, v in sorted(vars(args).items())))
))


# Create the model and train it
model = GatewayTokenClassifier(args, token_cls_model, train_dataset)
model.fit(
    train_dataset, batch_size=args.batch_size, epochs=3, validation_data=dev_dataset,
    callbacks=[tf.keras.callbacks.TensorBoard(args.logdir, histogram_freq=1, update_freq=100, profile_batch=0)]
)

Epoch 1/3

KeyboardInterrupt: 

In [21]:
# : Predict the tags on the test set.
print("create dev set predictions")
predictions = model.predict(dev_dataset)

create dev set predictions


In [26]:
import numpy as np
print(type(predictions))

if False:  # in case without additional Dense softmax head
    predictions = predictions.logits
    print("predictions shape:", tf.shape(logits))
elif True:  # in case WITH additional Dense softmax head
    print("Predictions shape:", tf.shape(predictions))
print()

for i, sample in enumerate(predictions):
    print(f"sample {i}".center(50, '-'))
    sample_word_ids = dataset_word_ids[val_instances +i]
    print(sample_sentences[val_instances+i])
    important_token_indices = [i for i, word_id in enumerate(sample_word_ids) if word_id != None]
    for j, token in enumerate(sample):
        if j in important_token_indices:
            print(j, token, end=' ')
            token_label = np.argmax(token)
            print(token_label, end=' ')
            if token_label in [2,3]:
                print("!", end=' ')
            print(sample_sentences[val_instances+i][sample_word_ids[j]], end=' ')
            print()
    print()
    

<class 'numpy.ndarray'>
Predictions shape: tf.Tensor([42 64  4], shape=(3,), dtype=int32)

---------------------sample 0---------------------
['If', 'the', 'treasurer', 'accepts', 'the', 'expenses', 'for', 'processing', ',', 'the', 'report', 'moves', 'to', 'an', 'automatic', 'activity', 'that', 'links', 'to', 'a', 'payment', 'system', '.']
1 [0.17884886 0.4014421  0.16047417 0.25923488] 1 If 
2 [0.16937277 0.42546183 0.16456768 0.24059774] 1 the 
3 [0.18682188 0.39350954 0.18596828 0.23370034] 1 treasurer 
4 [0.1485111  0.46505818 0.16611001 0.22032063] 1 accepts 
5 [0.14525631 0.4644485  0.17157966 0.21871552] 1 the 
6 [0.16051927 0.4250386  0.15868732 0.2557548 ] 1 expenses 
7 [0.16613026 0.41736063 0.17639527 0.24011391] 1 for 
8 [0.14881557 0.46527088 0.15136002 0.23455349] 1 processing 
9 [0.28957155 0.20086193 0.22887923 0.28068724] 0 , 
10 [0.15354498 0.45084158 0.15365063 0.24196286] 1 the 
11 [0.17816067 0.40415916 0.16935892 0.24832128] 1 report 
12 [0.1327977  0.5075081  0.1

#### (Debug Metrics)

In [101]:
y_true1 = tf.cast(tf.constant([[1, 1, 0],
                               [0, 2, 1]]), tf.int32)
y_pred1 = tf.cast(tf.constant([[1, 1, 0],
                               [2, 2, 2]]), tf.int32)

def own_precision(y_true, y_pred, target_label):
    from keras import backend as K
    
    # transform into 0 (rest) and 1 (target_label)
    y_true_filtered = tf.where(tf.equal(y_true, target_label), 1, tf.zeros_like(y_true))
    y_pred_filtered = tf.where(tf.equal(y_pred, target_label), 1, tf.zeros_like(y_pred))
    
    true_positives = K.sum(K.round(K.clip(y_true_filtered * y_pred_filtered, 0, 1))) 
    predicted_positives = K.sum(K.round(K.clip(y_pred_filtered, 0, 1)))

    if predicted_positives == 0:
        return tf.constant(0.0, dtype=tf.float32)
    else:
        precision = true_positives / (predicted_positives)  #  + K.epsilon()
        return tf.cast(precision, tf.float32)

print("precision", own_precision(y_true1, y_pred1, 2))
    
def own_recall(y_true, y_pred, target_label):
    from keras import backend as K
    
    # transform into 0 (rest) and 1 (target_label)
    y_true_filtered = tf.where(tf.equal(y_true, target_label), 1, tf.zeros_like(y_true))
    y_pred_filtered = tf.where(tf.equal(y_pred, target_label), 1, tf.zeros_like(y_pred))
    
    true_positives = K.sum(K.round(K.clip(y_true_filtered * y_pred_filtered, 0, 1))) 
    real_positives = K.sum(K.round(K.clip(y_true_filtered, 0, 1)))

    if real_positives == 0:
        return tf.constant(0.0, dtype=tf.float32)
    else:
        recall = true_positives / (real_positives)  #  + K.epsilon()
        return tf.cast(recall, tf.float32)

print("recall", own_recall(y_true1, y_pred1, 2))

def own_f1(y_true, y_pred, target_label):
    from keras import backend as K
    precision = own_precision(y_true, y_pred, target_label)
    recall = own_recall(y_true, y_pred, target_label)
    f1 = tf.constant(2*(tf.math.multiply(precision, recall)/(precision+recall+K.epsilon())))
    return tf.cast(f1, tf.float32)

print("f1", own_f1(y_true1, y_pred1, 2))

precision tf.Tensor(0.33333334, shape=(), dtype=float32)
recall tf.Tensor(1.0, shape=(), dtype=float32)
f1 tf.Tensor(0.49999997, shape=(), dtype=float32)


In [200]:
# results of 4 label cls

<class 'transformers.modeling_tf_outputs.TFTokenClassifierOutput'>
Logits shape: tf.Tensor([42 64  4], shape=(3,), dtype=int32)

---------------------sample 0---------------------
['If', 'the', 'treasurer', 'accepts', 'the', 'expenses', 'for', 'processing', ',', 'the', 'report', 'moves', 'to', 'an', 'automatic', 'activity', 'that', 'links', 'to', 'a', 'payment', 'system', '.']
1 [-0.68042624 -0.06963277 -0.05562851 -0.45490682] 2 ! If 
2 [-0.67450744 -0.07063795 -0.07084563 -0.4472993 ] 1 the 
3 [-0.68007725 -0.06729331 -0.0786977  -0.443125  ] 1 treasurer 
4 [-0.66857904 -0.07492305 -0.05826456 -0.44343913] 2 ! accepts 
5 [-0.67852056 -0.08177662 -0.08233006 -0.43491125] 1 the 
6 [-0.68491924 -0.07899152 -0.06896686 -0.45671514] 2 ! expenses 
7 [-0.6832283  -0.07021421 -0.06331892 -0.4526383 ] 2 ! for 
8 [-0.6835841  -0.06967415 -0.06367798 -0.44890505] 2 ! processing 
9 [-0.67150086 -0.06618639 -0.07021592 -0.44545662] 1 , 
10 [-0.6836521  -0.06786406 -0.05845393 -0.44633073] 2 ! the

### b) with transformers Trainer (NOT WORKING WITH TF)

In [13]:
batch_size = 8
args = transformers.TrainingArguments(
    output_dir='./results',
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    weight_decay=0.01
)

def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    # Remove ignored index (special tokens)
    true_predictions = [
        [p for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [l for (p, l) in zip(prediction, label) if l != -100] 
        for prediction, label in zip(predictions, labels)
    ]

    results = metric.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }

trainer = transformers.TFTrainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics
)

trainer.train()



ImportError: Method `n_gpu` requires PyTorch.