# Relation Classification
This notebook focuses on implementing a Relation Classifier model inspired by the architecture proposed in the paper [Enhancing evidence-based medicine with natural language argumentative analysis of clinical trials](https://pubmed.ncbi.nlm.nih.gov/34412851/). The task involves classifying the relation between two argument components within an abstract of a clinical trial and has been treated as a sequence classification problem, where the sequence consists of a pair of two components. In the original work, the model predicted relations as Support, Attack, or Partial-Attack.

In this experiment, the classification task has been simplified to a binary classification problem, where the model predicts whether there is a Relation or NoRelation between the two argument components. The fine-tuned model employed here is SciBERT, which was the best-performing model in the original paper. To address the issue of class imbalance in the dataset, normalized class weights were applied during training, as suggested by the authors.

In [None]:
train_data_dir =  "data/train/neoplasm_train"
val_data_dir =  "data/dev/neoplasm_dev"
test_data_dir =  "data/test/neoplasm_test"
neoplasm_test_data_dir = "data/test/neoplasm_test"
glaucoma_test_data_dir = "data/test/glaucoma_test"
mixed_test_data_dir = "data/test/mixed_test"
custom_train_data_dir = "data/custom_datasets/train"
custom_val_data_dir = "data/custom_datasets/val"
custom_test_data_dir = "data/custom_datasets/test"

In [None]:
from transformers import AutoTokenizer
import torch 
from torch.utils.data import DataLoader
import torch.nn as nn
from utils.utils_notebooks import display_special_tokens,display_dataset_item, create_dataframe_from_directory, setup_mappings, compute_norm_class_weights, tokenize_and_encode
from models.relation_classifier import RelationDataset, BERTSentClf
from utils.train import train, evaluate, predict


### 1. Corpus

The AbstRCT dataset has been preprocessed so that we can model the relation classification task by considering a jointly encoded pair of argument components as done in the original work. Thus, using this architecture one component can have relations with multiple other components, since each component combination is classified independently.

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

MODEL_CARD_SciBERT = "allenai/scibert_scivocab_uncased"

# Load BERT, BioBERT, and SciBERT tokenizers
scibert_tokenizer = AutoTokenizer.from_pretrained(MODEL_CARD_SciBERT)


In [6]:
LABELS = ['NoRelation', 'Relation']

label_to_idx, idx_to_label = setup_mappings(LABELS)
NUM_LABELS = len(LABELS)

print(f"Labels: {LABELS}")
print(f"Number of labels: {NUM_LABELS}")
print(f"Label to index: {label_to_idx}")
print(f"Index to Label: {idx_to_label}")

Labels: ['NoRelation', 'Relation']
Number of labels: 2
Label to index: {'NoRelation': 0, 'Relation': 1}
Index to Label: {0: 'NoRelation', 1: 'Relation'}


In [None]:
_ ,special_tokens_scibert = display_special_tokens(MODEL_CARD_SciBERT)


Special Tokens and their IDs:
unk_token: [UNK] (ID: 101)
sep_token: [SEP] (ID: 103)
pad_token: [PAD] (ID: 0)
cls_token: [CLS] (ID: 102)
mask_token: [MASK] (ID: 104)

Special Tokens and their IDs:
bos_token: <s> (ID: 0)
eos_token: </s> (ID: 2)
unk_token: <unk> (ID: 3)
sep_token: </s> (ID: 2)
pad_token: <pad> (ID: 1)
cls_token: <s> (ID: 0)
mask_token: <mask> (ID: 50264)


### 1.2 SciBERT - Corpus

In [30]:
MAX_LENGTH = 512
special_tokens_scibert = [special_tokens_scibert['cls_token'], special_tokens_scibert['sep_token'], special_tokens_scibert['pad_token']] #[CLS], [SEP], [PAD]

#original corpus
train_df_scibert = create_dataframe_from_directory(train_data_dir,label_to_idx,scibert_tokenizer,'relation_classification', MAX_LENGTH)
val_df_scibert = create_dataframe_from_directory(val_data_dir,label_to_idx, scibert_tokenizer,'relation_classification', MAX_LENGTH )
neoplasm_test_df_scibert = create_dataframe_from_directory(neoplasm_test_data_dir,label_to_idx,scibert_tokenizer,'relation_classification',MAX_LENGTH )
glaucoma_test_df_scibert = create_dataframe_from_directory(glaucoma_test_data_dir,label_to_idx,scibert_tokenizer,'relation_classification',MAX_LENGTH )
mixed_test_df_scibert = create_dataframe_from_directory(mixed_test_data_dir,label_to_idx,scibert_tokenizer,'relation_classification',MAX_LENGTH )

CLASS_WEIGHTS_NORMALIZED_SCIBERT = compute_norm_class_weights(train_df_scibert, label_column='Label', task_name='rel_class',special_tokens=special_tokens_scibert)

print(f"\nShape Training dataframe: {train_df_scibert.shape}")
print(f"Shape Validation dataframe : {val_df_scibert.shape}")
print(f"Shape Test dataframe: {neoplasm_test_df_scibert.shape}")
print(f"Shape Test dataframe: {glaucoma_test_df_scibert.shape}")
print(f"Shape Test dataframe: {mixed_test_df_scibert.shape}")
print(f"\n{CLASS_WEIGHTS_NORMALIZED_SCIBERT}")


train_df_scibert.head()


Shape Training dataframe: (14286, 8)
Shape Validation dataframe : (2030, 8)
Shape Test dataframe: (4380, 8)
Shape Test dataframe: (3332, 8)
Shape Test dataframe: (3332, 8)

:{'NoRelation': 0.09918801623967521, 'Relation': 0.9008119837603249}


Unnamed: 0,Arg1_Text,Arg2_Text,Arg1_ID,Arg2_ID,Label,Arg1_Type,Arg2_Type,File
0,Overall objective response (OR) rates were hig...,a similar trend was noted in patients with vis...,T1,T2,NoRelation,Premise,Premise,10735887.ann
1,Overall objective response (OR) rates were hig...,Median survival time was significantly longer ...,T1,T3,NoRelation,Premise,Premise,10735887.ann
2,Overall objective response (OR) rates were hig...,"Compared with MA, there were similar or greate...",T1,T4,NoRelation,Premise,Premise,10735887.ann
3,Overall objective response (OR) rates were hig...,Both drugs were well tolerated.,T1,T5,NoRelation,Premise,Claim,10735887.ann
4,Overall objective response (OR) rates were hig...,Grade 3 or 4 weight changes were more common w...,T1,T6,NoRelation,Premise,Premise,10735887.ann


In [33]:
#custom corpus
custom_train_df_scibert = create_dataframe_from_directory(train_data_dir,label_to_idx, scibert_tokenizer,'relation_classification', MAX_LENGTH)
custom_val_df_scibert = create_dataframe_from_directory(val_data_dir,label_to_idx, scibert_tokenizer,'relation_classification', MAX_LENGTH )
custom_test_df_scibert = create_dataframe_from_directory(custom_test_data_dir,label_to_idx,scibert_tokenizer,'relation_classification',MAX_LENGTH )

CLASS_WEIGHTS_NORMALIZED_CUSTOM_SCIBERT = compute_norm_class_weights(custom_train_df_scibert, label_column='Label', task_name='rel_class', special_tokens=special_tokens_scibert)

print(f"\nShape Training dataframe: {custom_train_df_scibert.shape}")
print(f"Shape Validation dataframe : {custom_val_df_scibert.shape}")
print(f"Shape Test dataframe: {custom_test_df_scibert.shape}")
print(f"\n{CLASS_WEIGHTS_NORMALIZED_CUSTOM_SCIBERT}")

custom_train_df_scibert.head()



Shape Training dataframe: (14286, 8)
Shape Validation dataframe : (2030, 8)
Shape Test dataframe: (4856, 8)

:{'NoRelation': 0.09918801623967521, 'Relation': 0.9008119837603249}


Unnamed: 0,Arg1_Text,Arg2_Text,Arg1_ID,Arg2_ID,Label,Arg1_Type,Arg2_Type,File
0,Overall objective response (OR) rates were hig...,a similar trend was noted in patients with vis...,T1,T2,NoRelation,Premise,Premise,10735887.ann
1,Overall objective response (OR) rates were hig...,Median survival time was significantly longer ...,T1,T3,NoRelation,Premise,Premise,10735887.ann
2,Overall objective response (OR) rates were hig...,"Compared with MA, there were similar or greate...",T1,T4,NoRelation,Premise,Premise,10735887.ann
3,Overall objective response (OR) rates were hig...,Both drugs were well tolerated.,T1,T5,NoRelation,Premise,Claim,10735887.ann
4,Overall objective response (OR) rates were hig...,Grade 3 or 4 weight changes were more common w...,T1,T6,NoRelation,Premise,Premise,10735887.ann


### 2. Dataset and Dataloader Creation

#### 2.1 SciBERT - Dataset Creation

In [40]:
BATCH_SIZE = 8

# original dataset SCIBERT
train_dataset_sci = RelationDataset(train_df_scibert,scibert_tokenizer, label_to_idx, MAX_LENGTH)
val_dataset_sci = RelationDataset(val_df_scibert,scibert_tokenizer, label_to_idx, MAX_LENGTH)
neoplasm_test_dataset_sci = RelationDataset(neoplasm_test_df_scibert,scibert_tokenizer, label_to_idx, MAX_LENGTH)
glaucoma_test_dataset_sci = RelationDataset(glaucoma_test_df_scibert,scibert_tokenizer, label_to_idx, MAX_LENGTH)
mixed_test_dataset_sci = RelationDataset(mixed_test_df_scibert,scibert_tokenizer, label_to_idx, MAX_LENGTH)

# original dataloader SCIBERT
train_dataloader_sci = DataLoader(train_dataset_sci, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader_sci = DataLoader(val_dataset_sci, batch_size=BATCH_SIZE, shuffle=False)
neoplasm_test_dataloader_sci = DataLoader(neoplasm_test_dataset_sci, batch_size=BATCH_SIZE, shuffle=False)
glaucoma_test_dataloader_sci = DataLoader(glaucoma_test_dataset_sci, batch_size=BATCH_SIZE, shuffle=False)
mixed_test_dataloader_sci = DataLoader(mixed_test_dataset_sci, batch_size=BATCH_SIZE, shuffle=False)

display_dataset_item(train_dataset_sci[0], idx_to_label=idx_to_label)


INPUT_IDS:
[102, 2103, 3201, 1278, 145, 234, 546, 1975, 267, 1001, 121, 568, 2338, 190,
199, 30107, 506, 121, 1052, 2338, 190, 1048, 145, 884, 205, 244, 1863, 171, 760,
205, 286, 1863, 546, 1814, 103, 106, 868, 5144, 241, 3742, 121, 568, 190, 17664,
10878, 145, 1041, 205, 305, 1863, '...']

Length: 512

ATTENTION_MASK:
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, '...']

Length: 512

TOKEN_TYPE_IDS:
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, '...']

Length: 512

LABEL:
NoRelation


In [41]:
# custom dataset SCIBERT
custom_train_dataset_sci = RelationDataset(custom_train_df_scibert,scibert_tokenizer, label_to_idx, MAX_LENGTH)
custom_val_dataset_sci = RelationDataset(custom_val_df_scibert,scibert_tokenizer, label_to_idx, MAX_LENGTH)
custom_test_dataset_sci = RelationDataset(custom_test_df_scibert,scibert_tokenizer, label_to_idx, MAX_LENGTH)

# custom dataloader SCIBERT
custom_train_dataloader_sci = DataLoader(custom_train_dataset_sci, batch_size=BATCH_SIZE, shuffle=True)
custom_val_dataloader_sci = DataLoader(custom_val_dataset_sci, batch_size=BATCH_SIZE, shuffle=False)
custom_test_dataloader_sci = DataLoader(custom_test_dataset_sci, batch_size=BATCH_SIZE, shuffle=False)

display_dataset_item(custom_train_dataset_sci[0], idx_to_label=idx_to_label)


INPUT_IDS:
[102, 2103, 3201, 1278, 145, 234, 546, 1975, 267, 1001, 121, 568, 2338, 190,
199, 30107, 506, 121, 1052, 2338, 190, 1048, 145, 884, 205, 244, 1863, 171, 760,
205, 286, 1863, 546, 1814, 103, 106, 868, 5144, 241, 3742, 121, 568, 190, 17664,
10878, 145, 1041, 205, 305, 1863, '...']

Length: 512

ATTENTION_MASK:
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, '...']

Length: 512

TOKEN_TYPE_IDS:
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, '...']

Length: 512

LABEL:
NoRelation


### 3. Model Definition

In [None]:
seeds = [42]
dropout=0.1
LR = 2e-5
EPOCHS = 5

# BERTSentClf
scibert_model_original = BERTSentClf(MODEL_CARD_SciBERT, num_labels=NUM_LABELS,dropout=dropout)
scibert_model_custom = BERTSentClf(MODEL_CARD_SciBERT, num_labels=NUM_LABELS,dropout=dropout)

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


#### 3.2 SciBERT - Training

In [None]:
# SCIBERT training on original dataset
output_folder = "/content/drive/MyDrive/data/scibert_rc_models"

scibert_model_original, results_sci_original = train(scibert_model_original,
                                                    "scibert_original",
                                                    train_dataloader_sci,
                                                    val_dataloader_sci,
                                                    learning_rate=LR,
                                                    num_epochs=EPOCHS,
                                                    task_name='rel_class',
                                                    seeds=seeds,
                                                    class_weights=CLASS_WEIGHTS_NORMALIZED_SCIBERT,
                                                    save_model=True,
                                                    models_folder=output_folder)


Training with seed 42...

Epoch 1/10




Training
	train_loss: 0.3916
	F1-Score Micro: 0.8622 | F1-Score Macro: 0.7162 | Weighted F1-Score: 0.879366 | Precision: 0.6811 | Recall: 0.8039 | Accuracy: 0.8622




Validating
	valid_loss: 0.3764
	F1-Score Micro: 0.8897 | F1-Score Macro: 0.7608 | Weighted F1-Score: 0.8808 | Precision: 0.8157 | Recall: 0.7283 | Accuracy: 0.8897
****************************************************************************************************

Epoch 2/10




Training
	train_loss: 0.2646
	F1-Score Micro: 0.9082 | F1-Score Macro: 0.7992 | Weighted F1-Score: 0.917812 | Precision: 0.7532 | Recall: 0.8885 | Accuracy: 0.9082




Validating
	valid_loss: 0.5403
	F1-Score Micro: 0.8897 | F1-Score Macro: 0.7463 | Weighted F1-Score: 0.8835 | Precision: 0.7796 | Recall: 0.7230 | Accuracy: 0.8897
****************************************************************************************************

Epoch 3/10




Training
	train_loss: 0.1760
	F1-Score Micro: 0.9374 | F1-Score Macro: 0.8551 | Weighted F1-Score: 0.942616 | Precision: 0.8077 | Recall: 0.9329 | Accuracy: 0.9374




Validating
	valid_loss: 0.5310
	F1-Score Micro: 0.8975 | F1-Score Macro: 0.7644 | Weighted F1-Score: 0.8918 | Precision: 0.8001 | Recall: 0.7393 | Accuracy: 0.8975
****************************************************************************************************

Epoch 4/10




Training
	train_loss: 0.1328
	F1-Score Micro: 0.9513 | F1-Score Macro: 0.8834 | Weighted F1-Score: 0.954718 | Precision: 0.8386 | Recall: 0.9503 | Accuracy: 0.9513




Validating
	valid_loss: 0.4692
	F1-Score Micro: 0.8645 | F1-Score Macro: 0.7446 | Weighted F1-Score: 0.8472 | Precision: 0.8498 | Recall: 0.7052 | Accuracy: 0.8645
****************************************************************************************************

Epoch 5/10




Training
	train_loss: 0.1217
	F1-Score Micro: 0.9547 | F1-Score Macro: 0.8907 | Weighted F1-Score: 0.957750 | Precision: 0.8468 | Recall: 0.9548 | Accuracy: 0.9547




Validating
	valid_loss: 0.7849
	F1-Score Micro: 0.8793 | F1-Score Macro: 0.7363 | Weighted F1-Score: 0.8700 | Precision: 0.7838 | Recall: 0.7078 | Accuracy: 0.8793
****************************************************************************************************

Epoch 6/10




Training
	train_loss: 0.0888
	F1-Score Micro: 0.9684 | F1-Score Macro: 0.9207 | Weighted F1-Score: 0.970020 | Precision: 0.8836 | Recall: 0.9696 | Accuracy: 0.9684




Validating
	valid_loss: 0.7815
	F1-Score Micro: 0.8823 | F1-Score Macro: 0.7377 | Weighted F1-Score: 0.8741 | Precision: 0.7795 | Recall: 0.7112 | Accuracy: 0.8823
****************************************************************************************************

Epoch 7/10




Training
	train_loss: 0.0755
	F1-Score Micro: 0.9770 | F1-Score Macro: 0.9402 | Weighted F1-Score: 0.977788 | Precision: 0.9117 | Recall: 0.9747 | Accuracy: 0.9770




Validating
	valid_loss: 0.7128
	F1-Score Micro: 0.8837 | F1-Score Macro: 0.7472 | Weighted F1-Score: 0.8746 | Precision: 0.7984 | Recall: 0.7168 | Accuracy: 0.8837
****************************************************************************************************

Epoch 8/10




Training
	train_loss: 0.0646
	F1-Score Micro: 0.9777 | F1-Score Macro: 0.9420 | Weighted F1-Score: 0.978499 | Precision: 0.9144 | Recall: 0.9751 | Accuracy: 0.9777




Validating
	valid_loss: 0.7382
	F1-Score Micro: 0.8788 | F1-Score Macro: 0.7512 | Weighted F1-Score: 0.8667 | Precision: 0.8237 | Recall: 0.7150 | Accuracy: 0.8788
****************************************************************************************************

Epoch 9/10




Training
	train_loss: 0.0620
	F1-Score Micro: 0.9805 | F1-Score Macro: 0.9488 | Weighted F1-Score: 0.981116 | Precision: 0.9242 | Recall: 0.9776 | Accuracy: 0.9805




Validating
	valid_loss: 1.0170
	F1-Score Micro: 0.8941 | F1-Score Macro: 0.7509 | Weighted F1-Score: 0.8892 | Precision: 0.7781 | Recall: 0.7306 | Accuracy: 0.8941
****************************************************************************************************

Epoch 10/10




Training
	train_loss: 0.0625
	F1-Score Micro: 0.9821 | F1-Score Macro: 0.9525 | Weighted F1-Score: 0.982543 | Precision: 0.9306 | Recall: 0.9775 | Accuracy: 0.9821




Validating
	valid_loss: 0.9064
	F1-Score Micro: 0.8941 | F1-Score Macro: 0.7474 | Weighted F1-Score: 0.8898 | Precision: 0.7700 | Recall: 0.7297 | Accuracy: 0.8941
****************************************************************************************************

Best Val F1-Score: 0.7644 at epoch 3
Saving model...
Model saved!

Training completed.


In [None]:
# Evaluate on original test sets

evaluate(scibert_model_original, neoplasm_test_dataloader_sci, verbose=True, name='Neoplasm');
evaluate(scibert_model_original, glaucoma_test_dataloader_sci, verbose=True, name='Glaucoma');
evaluate(scibert_model_original, mixed_test_dataloader_sci, verbose=True, name='Mixed');




Neoplasm Test Results
	F1-Score Micro: 0.8909
	F1-Score Macro: 0.7492
	Weighted F1-Score: 0.8805
	Precision: 0.8143
	Recall: 0.7137
	Accuracy: 0.8909
	Classification Report:
              precision    recall  f1-score   support

           0       0.91      0.97      0.94      3716
           1       0.72      0.46      0.56       664

    accuracy                           0.89      4380
   macro avg       0.81      0.71      0.75      4380
weighted avg       0.88      0.89      0.88      4380






Glaucoma Test Results
	F1-Score Micro: 0.8932
	F1-Score Macro: 0.7841
	Weighted F1-Score: 0.8826
	Precision: 0.8648
	Recall: 0.7431
	Accuracy: 0.8932
	Classification Report:
              precision    recall  f1-score   support

           0       0.90      0.98      0.94      2735
           1       0.83      0.51      0.63       597

    accuracy                           0.89      3332
   macro avg       0.86      0.74      0.78      3332
weighted avg       0.89      0.89      0.88      3332



                                                             


Mixed Test Results
	F1-Score Micro: 0.8869
	F1-Score Macro: 0.7544
	Weighted F1-Score: 0.8746
	Precision: 0.8357
	Recall: 0.7149
	Accuracy: 0.8869
	Classification Report:
              precision    recall  f1-score   support

           0       0.90      0.97      0.93      2776
           1       0.77      0.46      0.57       556

    accuracy                           0.89      3332
   macro avg       0.84      0.71      0.75      3332
weighted avg       0.88      0.89      0.87      3332





In [None]:
# SCIBERT training on custom dataset

scibert_model_custom, results_sci_custom = train(scibert_model_custom,
                                                "scibert_custom",
                                                custom_train_dataloader_sci,
                                                custom_val_dataloader_sci,
                                                learning_rate=LR,
                                                num_epochs=EPOCHS,
                                                task_name='rel_class',
                                                seeds=seeds,
                                                class_weights = CLASS_WEIGHTS_NORMALIZED_CUSTOM_SCIBERT,
                                                save_model=True,
                                                models_folder=output_folder)


Training with seed 42...

Epoch 1/5




Training
	train_loss: 0.3768
	F1-Score Micro: 0.8625 | F1-Score Macro: 0.7208 | Weighted F1-Score: 0.880226 | Precision: 0.6843 | Recall: 0.8156 | Accuracy: 0.8625




Validating
	valid_loss: 0.4153
	F1-Score Micro: 0.8966 | F1-Score Macro: 0.7677 | Weighted F1-Score: 0.8897 | Precision: 0.8116 | Recall: 0.7387 | Accuracy: 0.8966
****************************************************************************************************

Epoch 2/5




Training
	train_loss: 0.2615
	F1-Score Micro: 0.9080 | F1-Score Macro: 0.8001 | Weighted F1-Score: 0.917848 | Precision: 0.7533 | Recall: 0.8927 | Accuracy: 0.9080




Validating
	valid_loss: 0.4026
	F1-Score Micro: 0.8704 | F1-Score Macro: 0.7366 | Weighted F1-Score: 0.8570 | Precision: 0.8090 | Recall: 0.7021 | Accuracy: 0.8704
****************************************************************************************************

Epoch 3/5




Training
	train_loss: 0.1862
	F1-Score Micro: 0.9367 | F1-Score Macro: 0.8524 | Weighted F1-Score: 0.941829 | Precision: 0.8068 | Recall: 0.9256 | Accuracy: 0.9367




Validating
	valid_loss: 0.5257
	F1-Score Micro: 0.8926 | F1-Score Macro: 0.7604 | Weighted F1-Score: 0.8852 | Precision: 0.8053 | Recall: 0.7314 | Accuracy: 0.8926
****************************************************************************************************

Epoch 4/5




Training
	train_loss: 0.1357
	F1-Score Micro: 0.9519 | F1-Score Macro: 0.8840 | Weighted F1-Score: 0.955148 | Precision: 0.8408 | Recall: 0.9472 | Accuracy: 0.9519




Validating
	valid_loss: 0.5472
	F1-Score Micro: 0.9005 | F1-Score Macro: 0.7765 | Weighted F1-Score: 0.8939 | Precision: 0.8218 | Recall: 0.7466 | Accuracy: 0.9005
****************************************************************************************************

Epoch 5/5




Training
	train_loss: 0.1082
	F1-Score Micro: 0.9651 | F1-Score Macro: 0.9125 | Weighted F1-Score: 0.966907 | Precision: 0.8759 | Recall: 0.9609 | Accuracy: 0.9651




Validating
	valid_loss: 0.5465
	F1-Score Micro: 0.8833 | F1-Score Macro: 0.7519 | Weighted F1-Score: 0.8730 | Precision: 0.8121 | Recall: 0.7185 | Accuracy: 0.8833
****************************************************************************************************

Best Val F1-Score: 0.7765 at epoch 4
Saving model...
Model saved!

Training completed.


In [None]:
# Evaluate on custom test set

evaluate(scibert_model_custom, custom_test_dataloader_sci, verbose=True);




Test Results
	F1-Score Micro: 0.9425
	F1-Score Macro: 0.8678
	Weighted F1-Score: 0.9387
	Precision: 0.9280
	Recall: 0.8271
	Accuracy: 0.9425
	Classification Report:
              precision    recall  f1-score   support

           0       0.95      0.99      0.97      4160
           1       0.91      0.67      0.77       696

    accuracy                           0.94      4856
   macro avg       0.93      0.83      0.87      4856
weighted avg       0.94      0.94      0.94      4856



Looking at the results, the model seems to be performing well on the custom test set with an accuracy of 0.94 and a macro f1_score of 0.86. However, across all datasets, the performance disparity between the majority (class 0) and minority (class 1) classes underscores the difficulty of handling imbalanced data. Low recall for class 1 suggests that the model struggles to detect true relationships.

The domain-specific datasets (Neoplasm, Glaucoma, Mixed) exhibit consistent trends, with high precision for class 1 but low recall. This indicates the model’s cautious approach, favoring precision over recall.
Moreover, The custom dataset results show significant improvement in macro F1-Score and class 1 performance. This may be attributed to better data preprocessing, more representative examples, or improved class balancing techniques.

### Error Analysis

Let's have a look to some misclassified examples on the neoplasm test set which acieved the worst performance and try to understand why the model is failing to classify them correctly.

In [None]:
from utils.utils_relation_classifier import print_misclassified_examples

In [70]:
print_misclassified_examples(scibert_model_original, neoplasm_test_dataloader_sci, label_to_idx, idx_to_label, scibert_tokenizer, max_display=5);


Total misclassified examples: 478 out of 4380

Misclassified Examples:

Arg1_Text: octreotide, the long acting somatostatin analogue, improves survival of animals with pancreatic cancer. 	(Type: Claim)
Arg2_Text: octreotide is endowed with antiproliferative activity. 	(Type: Claim)
True Label: NoRelation
Predicted Label: Relation


Arg1_Text: octreotide, the long acting somatostatin analogue, improves survival of animals with pancreatic cancer. 	(Type: Claim)
Arg2_Text: octreotide therapy seems to confer a survival benefit and a better quality of life in advanced pancreatic tumour. 	(Type: Claim)
True Label: NoRelation
Predicted Label: Relation


Arg1_Text: The patients treated with octreotide showed a significant advantage in quality of life (restored appetite, improvement of digestion and intestine function, remission of abdominal pain and preservation of baseline body weight) with a mean > 80 of karnofsky performance score. 	(Type: Premise)
Arg2_Text: octreotide, the long acting so

Many of the misclassified examples involve semantically similar claims or premises, which might have led the model to infer a relationship where none exists. For instance, both Arg1 and Arg2 in the first example refer to the potential effects of octreotide but do not explicitly establish a causal or supportive relationship. This suggests the model may rely heavily on semantic similarity or overlapping vocabulary rather than identifying explicit relational cues.

Moreover since a relations is for the majority of the time between Premise and Claim (92% of the cases) a further impovement that can be done, is to use some additional information like the type of the relation between the two arguments to help the model to classify the relation between them. 