In [None]:
import pandas as pd
import os
from nltk.tokenize import word_tokenize
import nltk
nltk.download('punkt')
!pip install transformers
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM #, BertTokenizerFast
import torch
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
import torch.optim as optim
from tqdm import tqdm
import itertools
from sklearn.metrics import classification_report
import numpy as np
from sklearn.model_selection import train_test_split
import sys

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.27.2-py3-none-any.whl (6.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.8/6.8 MB[0m [31m69.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m99.2 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.11.0
  Downloading huggingface_hub-0.13.3-py3-none-any.whl (199 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m199.8/199.8 KB[0m [31m23.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.13.3 tokenizers-0.13.2 transformers-4.27.2


In [None]:
# Mount Google Drive at /content/drive so later cells can read files stored there.
# This relies on the Colab-specific `google.colab` package.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# --- Training configuration -------------------------------------------------
LEARNING_RATE = 0.00005  # learning rate handed to the optimizer
EPOCHS = 10              # number of passes over the training split
BATCH_SIZE = 4           # batch size for both the train and validation loaders
SEED = 1                 # single seed shared by every source of randomness below
SAVE_PATH = 'model.pth'  # where the best-scoring model is written during training

# Seed the global RNGs as well as the split: DataLoader shuffling and any sampling
# done by the model draw from torch's RNG, so seeding only train_test_split would
# leave training runs non-reproducible.
np.random.seed(SEED)
torch.manual_seed(SEED)

# --- Data -------------------------------------------------------------------
# 90/10 train/validation split of the CSV stored on the mounted Drive.
data = pd.read_csv('drive/MyDrive/rebel_format_v2.csv')
df_train, df_val = train_test_split(data, test_size=0.1, random_state=SEED)
del data  # the full frame is not needed once it has been split

In [None]:
# Load the pretrained seq2seq checkpoint and its matching tokenizer by name
# (fetched from the Hugging Face Hub when not cached locally). `model` and
# `tokenizer` are module-level globals: the dataset class and the training /
# inference cells read them directly.
model_checkpoint = "Babelscape/rebel-large"
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.42k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/1.23k [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Downloading (…)in/added_tokens.json:   0%|          | 0.00/123 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/344 [00:00<?, ?B/s]

In [None]:
class DataSequence(torch.utils.data.Dataset):
    """Map-style dataset of tokenized (context, triplets) pairs.

    The whole frame is tokenized once up front with the module-level ``tokenizer``;
    every sequence is padded/truncated to ``max_length`` so samples can be stacked
    into batches by a ``DataLoader``.

    Args:
        df: DataFrame with a ``'context'`` column (model input text) and a
            ``'triplets'`` column (target text).
        max_length: Fixed token length used for both inputs and targets.
            Defaults to 128.
    """

    def __init__(self, df, max_length=128):
        contexts = df['context'].tolist()
        self.texts = tokenizer(contexts, padding='max_length', max_length=max_length,
                               truncation=True, return_tensors="pt")

        targets = df['triplets'].to_list()
        self.labels = tokenizer(targets, padding='max_length', max_length=max_length,
                                truncation=True, return_tensors="pt")

    def __len__(self):
        """Number of samples (rows of the tokenized target tensor)."""
        return len(self.labels['input_ids'])

    def get_batch_data(self, idx):
        """Return the tokenizer outputs of the input text at ``idx`` as a dict of tensors."""
        # The stored values are already tensors; wrapping them in torch.tensor(...)
        # triggers a copy-construct UserWarning on every sample. clone().detach()
        # yields the same independent copy without the warning.
        return {key: val[idx].clone().detach() for key, val in self.texts.items()}

    def get_batch_labels(self, idx):
        """Return the tokenizer outputs of the target text at ``idx`` as a dict of tensors."""
        return {key: val[idx].clone().detach() for key, val in self.labels.items()}

    def __getitem__(self, idx):
        """Return ``(inputs, targets)`` for sample ``idx``; both are dicts of tensors."""
        return self.get_batch_data(idx), self.get_batch_labels(idx)

In [None]:
def train_loop(model, df_train, df_val):
    """Fine-tune ``model`` on ``df_train`` and evaluate on ``df_val`` after every epoch.

    Each epoch:
      1. trains with AdamW on the tokenized (context, triplets) pairs;
      2. computes the validation loss and generates predictions for the validation set;
      3. scores predicted vs. gold triplets with ``re_score`` for relation, subject
         and object, and averages the three macro-F1 values;
      4. hands that average to ``check_best_performing``, which saves the model to
         ``SAVE_PATH`` when it improves on the best value so far.

    Args:
        model: Seq2seq model whose forward pass accepts ``labels`` and returns ``.loss``.
        df_train: Training DataFrame (must satisfy ``DataSequence``).
        df_val: Validation DataFrame (must satisfy ``DataSequence``).

    Returns:
        None. Progress and metrics are printed.
    """
    train_dataset = DataSequence(df_train)
    val_dataset = DataSequence(df_val)

    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    # Label positions that are padding must not contribute to the loss; the
    # Hugging Face seq2seq heads ignore label positions set to -100.
    pad_token_id = tokenizer.pad_token_id
    best_metric = 0

    for epoch_num in range(EPOCHS):

        # ---- training ------------------------------------------------------
        model.train()
        total_loss_train = 0

        for train_data, train_label in tqdm(train_dataloader):
            input_id = train_data['input_ids'].to(device)
            mask = train_data['attention_mask'].to(device)
            label_ids = train_label['input_ids'].to(device)
            loss_labels = label_ids
            if pad_token_id is not None:
                loss_labels = label_ids.masked_fill(label_ids == pad_token_id, -100)

            optimizer.zero_grad()  # clear gradients left over from the previous batch
            loss = model(input_id, mask, labels=loss_labels).loss
            total_loss_train += loss.item()

            loss.backward()   # compute gradients
            optimizer.step()  # update the weights from those gradients

        # ---- validation ----------------------------------------------------
        model.eval()
        total_loss_val = 0
        pred = []
        gt = []

        # No gradients are needed for evaluation; skipping them saves memory and time.
        with torch.no_grad():
            for val_data, val_label in val_dataloader:
                input_id = val_data['input_ids'].to(device)
                mask = val_data['attention_mask'].to(device)
                label_ids = val_label['input_ids'].to(device)
                loss_labels = label_ids
                if pad_token_id is not None:
                    loss_labels = label_ids.masked_fill(label_ids == pad_token_id, -100)

                loss = model(input_id, mask, labels=loss_labels).loss
                total_loss_val += loss.item()

                # Pass the attention mask so the encoder does not attend to padding.
                generated = model.generate(input_id, attention_mask=mask)
                decoded_preds = tokenizer.batch_decode(generated, skip_special_tokens=False)
                # Decode the *unmasked* ids: -100 is not a valid token id.
                decoded_gold = tokenizer.batch_decode(label_ids, skip_special_tokens=False)

                # NOTE(review): re_score pairs predictions and gold positionally, so this
                # assumes extract_triplets yields exactly one triplet per text — confirm
                # that the model never emits more than one <triplet> per sample.
                gt.extend(extract_triplets(decoded_gold, gold_extraction=True))
                pred.extend(extract_triplets(decoded_preds, gold_extraction=False))

        # Model selection metric: mean of the macro-F1 over the three triplet fields.
        macro_f1_per_field = []
        for field in ('relation', 'subject', 'object'):
            scores, _, _, _ = re_score(pred, gt, field)
            macro_f1_per_field.append(scores["ALL"]["Macro_f1"])
        combined_metric = sum(macro_f1_per_field) / 3

        best_metric = check_best_performing(model, best_metric, combined_metric, SAVE_PATH)

        # The model returns a per-batch mean loss, so the epoch average is taken over
        # the number of batches (not the number of samples).
        mean_loss_train = total_loss_train / max(len(train_dataloader), 1)
        mean_loss_val = total_loss_val / max(len(val_dataloader), 1)
        print(
            f'Epochs: {epoch_num + 1} | Loss: {mean_loss_train: .6f} | Val_Loss: {mean_loss_val: .6f}')


In [None]:
def _parse_linearized_triplets(text):
    """Parse one linearized string into stripped ``(subject, relation, object)`` tuples.

    Expected layout: ``<triplet> SUBJECT <subj> OBJECT <obj> RELATION``.

    Returns ``None`` when the text is malformed, i.e. a content token appears before
    any tag, or a triplet has to be emitted although no ``<triplet>`` or no ``<subj>``
    tag was seen in this text.
    """
    cleaned = ''.join(text).replace('<s>', '').replace('</s>', '').replace('<pad>', '')

    parsed = []
    subject = None   # stays None until a <triplet> tag is seen in THIS text
    object_ = None   # stays None until a <subj> tag is seen in THIS text
    relation = ''
    current = None   # which field subsequent plain tokens belong to

    for token in cleaned.split():
        if token == "<triplet>":
            current = 't'
            if relation != '':
                # A new triplet starts: flush the one collected so far.
                if subject is None or object_ is None:
                    return None
                parsed.append((subject.strip(), relation.strip(), object_.strip()))
                relation = ''
            subject = ''
        elif token == "<subj>":
            current = 's'
            if relation != '':
                # Another object for the same subject: flush the previous pair.
                if subject is None or object_ is None:
                    return None
                parsed.append((subject.strip(), relation.strip(), object_.strip()))
            object_ = ''
        elif token == "<obj>":
            current = 'o'
            relation = ''
        elif current == 't':
            subject += ' ' + token
        elif current == 's':
            object_ += ' ' + token
        elif current == 'o':
            relation += ' ' + token
        else:
            # Plain token before any tag: the text does not follow the layout.
            return None

    if subject is None or object_ is None:
        return None
    parsed.append((subject.strip(), relation.strip(), object_.strip()))
    return parsed


def extract_triplets(texts, gold_extraction):
    """Convert decoded model/label strings into ``(subject, relation, object)`` tuples.

    Parser state is reset for every text, so a malformed text can never inherit the
    subject/object of the previous one. All returned fields are whitespace-stripped.

    Args:
        texts: Iterable of decoded strings; ``<s>``, ``</s>`` and ``<pad>`` markers
            are removed before parsing.
        gold_extraction: ``True`` when ``texts`` are gold labels. A malformed gold
            label is treated as fatal: a message is printed and ``sys.exit()`` is
            called. For predictions (``False``) a malformed text contributes a single
            ``("Invalid", "Invalid", "Invalid")`` placeholder instead.

    Returns:
        Flat list of tuples, in input order. A text containing several triplets
        contributes several entries.

    Raises:
        SystemExit: if a gold label cannot be parsed.
    """
    triplets = []
    for text in texts:
        parsed = _parse_linearized_triplets(text)
        if parsed is None:
            if gold_extraction:
                print("Gold labels should always be extracted correctly. Exiting")
                sys.exit()
            parsed = [("Invalid", "Invalid", "Invalid")]
        triplets.extend(parsed)

    return triplets

In [None]:
def _precision_recall_f1(tp, fp, fn):
    """Return (precision, recall, f1) as percentages; all three are 0 when ``tp`` is 0."""
    if not tp:
        return 0, 0, 0
    precision = 100 * tp / (tp + fp)
    recall = 100 * tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


def re_score(predictions, ground_truths, type):
    """Score one field of predicted triplets against the gold triplets.

    Predictions and gold triplets are compared position by position, so both lists
    must be aligned (entry ``i`` of one corresponds to entry ``i`` of the other).

    Args:
        predictions (list): ``(subject, relation, object)`` tuples predicted by the model.
        ground_truths (list): gold ``(subject, relation, object)`` tuples.
        type (str): which field to score: ``'relation'``, ``'subject'`` or ``'object'``.
            For ``'relation'`` the label set is fixed to cause/enable/prevent/intend;
            for the other two it is the set of distinct gold values.

    Returns:
        tuple: ``(scores, precision, recall, f1)``. ``scores`` maps every label and
        ``"ALL"`` to a dict with ``tp``/``fp``/``fn``/``p``/``r``/``f1``;
        ``scores["ALL"]`` additionally holds ``Macro_p``/``Macro_r``/``Macro_f1``.
        ``precision``, ``recall`` and ``f1`` are the micro-averaged values.
        All rates are percentages (0-100).

    Raises:
        ValueError: if ``type`` is not one of the three supported fields.
    """
    field_index = {'subject': 0, 'relation': 1, 'object': 2}
    if type not in field_index:
        raise ValueError(
            "type must be 'relation', 'subject' or 'object', got {!r}".format(type))

    position = field_index[type]
    predictions = [pred[position] for pred in predictions]
    ground_truths = [gt[position] for gt in ground_truths]

    if type == 'relation':
        vocab = ['cause', 'enable', 'prevent', 'intend']
    else:
        # Open vocabulary: every distinct gold value is its own class.
        vocab = np.unique(ground_truths).tolist()

    scores = {rel: {"tp": 0, "fp": 0, "fn": 0} for rel in vocab + ["ALL"]}
    known = set(vocab)

    # Count GT relations and Predicted relations
    n_sents = len(ground_truths)
    n_rels = n_sents  # Since every 'sentence' has only 1 relation
    n_found = n_sents

    # Count TP, FP and FN per label in a single pass.
    # NOTE(review): a wrong prediction whose value is outside `vocab` is not counted
    # as a false positive for any label, so micro precision can be optimistic.
    for pred_label, gt_label in zip(predictions, ground_truths):
        if pred_label == gt_label:
            if gt_label in known:
                scores[gt_label]["tp"] += 1
        else:
            if pred_label in known:
                scores[pred_label]["fp"] += 1
            if gt_label in known:
                scores[gt_label]["fn"] += 1

    # Per-label Precision / Recall / F1
    for entity in scores.keys():
        entry = scores[entity]
        entry["p"], entry["r"], entry["f1"] = _precision_recall_f1(
            entry["tp"], entry["fp"], entry["fn"])

    # Micro-averaged scores over all labels
    tp = sum(scores[entity]["tp"] for entity in vocab)
    fp = sum(scores[entity]["fp"] for entity in vocab)
    fn = sum(scores[entity]["fn"] for entity in vocab)
    precision, recall, f1 = _precision_recall_f1(tp, fp, fn)

    scores["ALL"]["p"] = precision
    scores["ALL"]["r"] = recall
    scores["ALL"]["f1"] = f1
    scores["ALL"]["tp"] = tp
    scores["ALL"]["fp"] = fp
    scores["ALL"]["fn"] = fn

    # Macro-averaged scores; with no labels at all (empty input) report 0 rather
    # than the NaN that averaging an empty list would give.
    if vocab:
        scores["ALL"]["Macro_f1"] = np.mean([scores[ent_type]["f1"] for ent_type in vocab])
        scores["ALL"]["Macro_p"] = np.mean([scores[ent_type]["p"] for ent_type in vocab])
        scores["ALL"]["Macro_r"] = np.mean([scores[ent_type]["r"] for ent_type in vocab])
    else:
        scores["ALL"]["Macro_f1"] = 0
        scores["ALL"]["Macro_p"] = 0
        scores["ALL"]["Macro_r"] = 0

    if type == 'relation':
        print(
            "processed {} sentences with {} entities; found: {} relations; correct: {}.".format(n_sents, n_rels, n_found,
                                                                                                 tp))
        print(
            "\tALL\t TP: {};\tFP: {};\tFN: {}".format(
                scores["ALL"]["tp"],
                scores["ALL"]["fp"],
                scores["ALL"]["fn"]))
        print(
            "\t\t(m avg): precision: {:.2f};\trecall: {:.2f};\tf1: {:.2f} (micro)".format(
                precision,
                recall,
                f1))
        print(
            "\t\t(M avg): precision: {:.2f};\trecall: {:.2f};\tf1: {:.2f} (Macro)\n".format(
                scores["ALL"]["Macro_p"],
                scores["ALL"]["Macro_r"],
                scores["ALL"]["Macro_f1"]))

        for entity in vocab:
            print("\t{}: \tTP: {};\tFP: {};\tFN: {};\tprecision: {:.2f};\trecall: {:.2f};\tf1: {:.2f};\t{}".format(
                entity,
                scores[entity]["tp"],
                scores[entity]["fp"],
                scores[entity]["fn"],
                scores[entity]["p"],
                scores[entity]["r"],
                scores[entity]["f1"],
                scores[entity]["tp"] +
                scores[entity][
                    "fp"]))

    else:
        print(f"Macro F1 for {type}: {scores['ALL']['Macro_f1']:.4f}")

    return scores, precision, recall, f1


def calc_acc(predictions, gold):
    """Compute exact-match accuracy of predicted subjects and objects.

    Args:
        predictions: list of ``(subject, relation, object)`` tuples from the model.
        gold: list of gold ``(subject, relation, object)`` tuples, aligned with
            ``predictions`` by position.

    Returns:
        tuple: ``(subject_accuracy, object_accuracy)`` as fractions in [0, 1],
        relative to ``len(predictions)``. Both are ``0.0`` when there are no
        predictions (instead of raising ``ZeroDivisionError``).
    """
    num_ner = len(predictions)  # The total number of entities

    subj_hits = sum(1 for pred, gt in zip(predictions, gold) if pred[0] == gt[0])
    obj_hits = sum(1 for pred, gt in zip(predictions, gold) if pred[2] == gt[2])

    acc_subj_correct = subj_hits / num_ner if num_ner else 0.0
    acc_obj_correct = obj_hits / num_ner if num_ner else 0.0

    print(f"acc subject: {acc_subj_correct} acc object: {acc_obj_correct}")

    return acc_subj_correct, acc_obj_correct


def check_best_performing(model, best_metric, new_metric, PATH):
    """Save ``model`` to ``PATH`` when ``new_metric`` beats ``best_metric``.

    Args:
        model: Object passed to ``torch.save`` (the whole object is serialized).
        best_metric: Best score seen so far.
        new_metric: Score of the current model.
        PATH: Destination passed to ``torch.save``.

    Returns:
        The larger of the two scores: ``new_metric`` if it is strictly greater
        than ``best_metric`` (the model is saved in that case), else ``best_metric``.
    """
    # Guard clause: nothing to do unless the score strictly improved.
    if not new_metric > best_metric:
        return best_metric

    torch.save(model, PATH)
    print("New best model found, saving...")
    return new_metric


In [None]:
train_loop(model, df_train, df_val)

  item = {key: torch.tensor(val[idx]) for key, val in self.texts.items()}
  item = {key: torch.tensor(val[idx]) for key, val in self.labels.items()}
100%|██████████| 436/436 [04:12<00:00,  1.72it/s]


processed 194 sentences with 194 entities; found: 194 relations; correct: 0.
	ALL	 TP: 0;	FP: 0;	FN: 194
		(m avg): precision: 0.00;	recall: 0.00;	f1: 0.00 (micro)
		(M avg): precision: 0.00;	recall: 0.00;	f1: 0.00 (Macro)

	cause: 	TP: 0;	FP: 0;	FN: 29;	precision: 0.00;	recall: 0.00;	f1: 0.00;	0
	enable: 	TP: 0;	FP: 0;	FN: 44;	precision: 0.00;	recall: 0.00;	f1: 0.00;	0
	prevent: 	TP: 0;	FP: 0;	FN: 46;	precision: 0.00;	recall: 0.00;	f1: 0.00;	0
	intend: 	TP: 0;	FP: 0;	FN: 75;	precision: 0.00;	recall: 0.00;	f1: 0.00;	0
Macro F1 for subject: 0.0000
Macro F1 for object: 0.0000
Epochs: 1 | Loss:  0.209634 | Val_Loss:  0.014293


100%|██████████| 436/436 [04:08<00:00,  1.75it/s]


processed 194 sentences with 194 entities; found: 194 relations; correct: 182.
	ALL	 TP: 182;	FP: 12;	FN: 12
		(m avg): precision: 93.81;	recall: 93.81;	f1: 93.81 (micro)
		(M avg): precision: 94.21;	recall: 92.89;	f1: 93.38 (Macro)

	cause: 	TP: 24;	FP: 1;	FN: 5;	precision: 96.00;	recall: 82.76;	f1: 88.89;	25
	enable: 	TP: 42;	FP: 6;	FN: 2;	precision: 87.50;	recall: 95.45;	f1: 91.30;	48
	prevent: 	TP: 46;	FP: 0;	FN: 0;	precision: 100.00;	recall: 100.00;	f1: 100.00;	46
	intend: 	TP: 70;	FP: 5;	FN: 5;	precision: 93.33;	recall: 93.33;	f1: 93.33;	75
Macro F1 for subject: 50.2121
Macro F1 for object: 45.3106
New best model found, saving...
Epochs: 2 | Loss:  0.008307 | Val_Loss:  0.014807


100%|██████████| 436/436 [04:09<00:00,  1.75it/s]


processed 194 sentences with 194 entities; found: 194 relations; correct: 181.
	ALL	 TP: 181;	FP: 13;	FN: 13
		(m avg): precision: 93.30;	recall: 93.30;	f1: 93.30 (micro)
		(M avg): precision: 93.41;	recall: 93.43;	f1: 93.40 (Macro)

	cause: 	TP: 27;	FP: 3;	FN: 2;	precision: 90.00;	recall: 93.10;	f1: 91.53;	30
	enable: 	TP: 39;	FP: 3;	FN: 5;	precision: 92.86;	recall: 88.64;	f1: 90.70;	42
	prevent: 	TP: 46;	FP: 0;	FN: 0;	precision: 100.00;	recall: 100.00;	f1: 100.00;	46
	intend: 	TP: 69;	FP: 7;	FN: 6;	precision: 90.79;	recall: 92.00;	f1: 91.39;	76
Macro F1 for subject: 46.9136
Macro F1 for object: 47.3196
Epochs: 3 | Loss:  0.006097 | Val_Loss:  0.010562


100%|██████████| 436/436 [04:08<00:00,  1.75it/s]


processed 194 sentences with 194 entities; found: 194 relations; correct: 183.
	ALL	 TP: 183;	FP: 11;	FN: 11
		(m avg): precision: 94.33;	recall: 94.33;	f1: 94.33 (micro)
		(M avg): precision: 93.24;	recall: 95.86;	f1: 94.32 (Macro)

	cause: 	TP: 29;	FP: 5;	FN: 0;	precision: 85.29;	recall: 100.00;	f1: 92.06;	34
	enable: 	TP: 42;	FP: 3;	FN: 2;	precision: 93.33;	recall: 95.45;	f1: 94.38;	45
	prevent: 	TP: 46;	FP: 2;	FN: 0;	precision: 95.83;	recall: 100.00;	f1: 97.87;	48
	intend: 	TP: 66;	FP: 1;	FN: 9;	precision: 98.51;	recall: 88.00;	f1: 92.96;	67
Macro F1 for subject: 55.8553
Macro F1 for object: 46.8484
New best model found, saving...
Epochs: 4 | Loss:  0.005036 | Val_Loss:  0.010539


100%|██████████| 436/436 [04:08<00:00,  1.75it/s]


processed 194 sentences with 194 entities; found: 194 relations; correct: 183.
	ALL	 TP: 183;	FP: 11;	FN: 11
		(m avg): precision: 94.33;	recall: 94.33;	f1: 94.33 (micro)
		(M avg): precision: 93.35;	recall: 95.04;	f1: 94.04 (Macro)

	cause: 	TP: 27;	FP: 4;	FN: 2;	precision: 87.10;	recall: 93.10;	f1: 90.00;	31
	enable: 	TP: 43;	FP: 6;	FN: 1;	precision: 87.76;	recall: 97.73;	f1: 92.47;	49
	prevent: 	TP: 46;	FP: 0;	FN: 0;	precision: 100.00;	recall: 100.00;	f1: 100.00;	46
	intend: 	TP: 67;	FP: 1;	FN: 8;	precision: 98.53;	recall: 89.33;	f1: 93.71;	68
Macro F1 for subject: 53.8933
Macro F1 for object: 55.5762
New best model found, saving...
Epochs: 5 | Loss:  0.004483 | Val_Loss:  0.009315


100%|██████████| 436/436 [04:09<00:00,  1.75it/s]


processed 194 sentences with 194 entities; found: 194 relations; correct: 185.
	ALL	 TP: 185;	FP: 9;	FN: 9
		(m avg): precision: 95.36;	recall: 95.36;	f1: 95.36 (micro)
		(M avg): precision: 94.47;	recall: 96.77;	f1: 95.44 (Macro)

	cause: 	TP: 29;	FP: 4;	FN: 0;	precision: 87.88;	recall: 100.00;	f1: 93.55;	33
	enable: 	TP: 43;	FP: 4;	FN: 1;	precision: 91.49;	recall: 97.73;	f1: 94.51;	47
	prevent: 	TP: 46;	FP: 0;	FN: 0;	precision: 100.00;	recall: 100.00;	f1: 100.00;	46
	intend: 	TP: 67;	FP: 1;	FN: 8;	precision: 98.53;	recall: 89.33;	f1: 93.71;	68
Macro F1 for subject: 56.0814
Macro F1 for object: 53.3064
New best model found, saving...
Epochs: 6 | Loss:  0.004371 | Val_Loss:  0.013144


100%|██████████| 436/436 [04:09<00:00,  1.75it/s]


processed 194 sentences with 194 entities; found: 194 relations; correct: 185.
	ALL	 TP: 185;	FP: 9;	FN: 9
		(m avg): precision: 95.36;	recall: 95.36;	f1: 95.36 (micro)
		(M avg): precision: 94.42;	recall: 97.00;	f1: 95.45 (Macro)

	cause: 	TP: 29;	FP: 4;	FN: 0;	precision: 87.88;	recall: 100.00;	f1: 93.55;	33
	enable: 	TP: 44;	FP: 5;	FN: 0;	precision: 89.80;	recall: 100.00;	f1: 94.62;	49
	prevent: 	TP: 46;	FP: 0;	FN: 0;	precision: 100.00;	recall: 100.00;	f1: 100.00;	46
	intend: 	TP: 66;	FP: 0;	FN: 9;	precision: 100.00;	recall: 88.00;	f1: 93.62;	66
Macro F1 for subject: 55.4274
Macro F1 for object: 56.0730
New best model found, saving...
Epochs: 7 | Loss:  0.003959 | Val_Loss:  0.010302


100%|██████████| 436/436 [04:09<00:00,  1.75it/s]


processed 194 sentences with 194 entities; found: 194 relations; correct: 179.
	ALL	 TP: 179;	FP: 13;	FN: 15
		(m avg): precision: 93.23;	recall: 92.27;	f1: 92.75 (micro)
		(M avg): precision: 95.08;	recall: 91.35;	f1: 93.03 (Macro)

	cause: 	TP: 26;	FP: 0;	FN: 3;	precision: 100.00;	recall: 89.66;	f1: 94.55;	26
	enable: 	TP: 37;	FP: 3;	FN: 7;	precision: 92.50;	recall: 84.09;	f1: 88.10;	40
	prevent: 	TP: 44;	FP: 0;	FN: 2;	precision: 100.00;	recall: 95.65;	f1: 97.78;	44
	intend: 	TP: 72;	FP: 10;	FN: 3;	precision: 87.80;	recall: 96.00;	f1: 91.72;	82
Macro F1 for subject: 51.3052
Macro F1 for object: 48.0216
Epochs: 8 | Loss:  0.004028 | Val_Loss:  0.010076


100%|██████████| 436/436 [04:08<00:00,  1.75it/s]


processed 194 sentences with 194 entities; found: 194 relations; correct: 179.
	ALL	 TP: 179;	FP: 15;	FN: 15
		(m avg): precision: 92.27;	recall: 92.27;	f1: 92.27 (micro)
		(M avg): precision: 91.12;	recall: 92.65;	f1: 91.68 (Macro)

	cause: 	TP: 25;	FP: 4;	FN: 4;	precision: 86.21;	recall: 86.21;	f1: 86.21;	29
	enable: 	TP: 43;	FP: 5;	FN: 1;	precision: 89.58;	recall: 97.73;	f1: 93.48;	48
	prevent: 	TP: 46;	FP: 5;	FN: 0;	precision: 90.20;	recall: 100.00;	f1: 94.85;	51
	intend: 	TP: 65;	FP: 1;	FN: 10;	precision: 98.48;	recall: 86.67;	f1: 92.20;	66
Macro F1 for subject: 54.0725
Macro F1 for object: 51.5017
Epochs: 9 | Loss:  0.003957 | Val_Loss:  0.010751


100%|██████████| 436/436 [04:08<00:00,  1.75it/s]


processed 194 sentences with 194 entities; found: 194 relations; correct: 172.
	ALL	 TP: 172;	FP: 21;	FN: 22
		(m avg): precision: 89.12;	recall: 88.66;	f1: 88.89 (micro)
		(M avg): precision: 89.58;	recall: 90.13;	f1: 88.96 (Macro)

	cause: 	TP: 25;	FP: 4;	FN: 4;	precision: 86.21;	recall: 86.21;	f1: 86.21;	29
	enable: 	TP: 44;	FP: 17;	FN: 0;	precision: 72.13;	recall: 100.00;	f1: 83.81;	61
	prevent: 	TP: 44;	FP: 0;	FN: 2;	precision: 100.00;	recall: 95.65;	f1: 97.78;	44
	intend: 	TP: 59;	FP: 0;	FN: 16;	precision: 100.00;	recall: 78.67;	f1: 88.06;	59
Macro F1 for subject: 52.1675
Macro F1 for object: 54.1295
Epochs: 10 | Loss:  0.004413 | Val_Loss:  0.028370


In [None]:
# Qualitative check: run a previously saved model on a handful of hand-written sentences.
# Use the GPU when there is one, but fall back to CPU so the cell also runs on a
# CPU-only runtime (map_location remaps tensors that were saved from a GPU).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# NOTE(security): torch.load unpickles the whole model object and can execute
# arbitrary code — only load checkpoint files you created yourself.
# NOTE(review): this path differs from SAVE_PATH ('model.pth') used during training;
# the file presumably was copied to Drive by hand — confirm before a fresh run.
model = torch.load('drive/MyDrive/model_best1.pth', map_location=device).to(device)
model.eval()  # make sure dropout is disabled for inference

text = ["because the machine is old, it is unreliable", "many people have died in the storm", "now the preparation is complete, we can start again",
        "the restrictions made sure less people got infected", "I am running everyday because i want to run a marathon", "the elevator is fixed, so i can go up again",
        "There was a traffic jam, so i was late", "I broke my leg, so I can't run the marathon", "Since I failed the exam, I can't graduate",
        "I did some shopping, because i want to cook later", "I shouldn't have said that, I did not mean that", "I wanted to say that", "I am planning on doing that later",
        "I intend on doing that", "My car is fixed, so i can drive now", "I broke my leg, so I can't walk", "I broke my leg, so I can't join", "I am learning French, because I want to be fluent"]
encoding = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)

# forward pass
# NOTE(review): do_sample=True makes the generated text vary between runs.
outputs = model.generate(**encoding, do_sample=True)
# Special tokens (including the triplet markers) are dropped, leaving
# "subject  object  relation" as plain text.
decoded_output = tokenizer.batch_decode(outputs, skip_special_tokens=True)

for t, o in zip(text, decoded_output):
  print(f"{t} - {o}")



because the machine is old, it is unreliable -  old  unreliable  cause
many people have died in the storm -  storm  died  cause
now the preparation is complete, we can start again -  preparation  start  cause
the restrictions made sure less people got infected -  restrictions  infected  prevent
I am running everyday because i want to run a marathon -  want  running  cause
the elevator is fixed, so i can go up again -  fixed  go  cause
There was a traffic jam, so i was late -  jam  late  cause
I broke my leg, so I can't run the marathon -  broke  run  cause
Since I failed the exam, I can't graduate -  failed  graduate  cause
I did some shopping, because i want to cook later -  want  did  cause
I shouldn't have said that, I did not mean that -  said  said  cause
I wanted to say that -  wanted  said  cause
I am planning on doing that later -  planning  doing  cause
I intend on doing that -  doing  doing  intend
My car is fixed, so i can drive now -  fixed  drive  cause
I broke my leg, so 