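# Importing the necessary libraries and dependencies

This notebook trains a relation classifier on the DocRED annotated training split (`train_annotated.json`). It then reports per-relation precision, recall and F1 on the training data and on the DocRED dev split (`dev.json`).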

In [1]:
# Download the spaCy English model used for tokenization.
# - Runs through the kernel's own interpreter (sys.executable), so the model is
#   installed into the environment this notebook imports from. A bare
#   `!python` can resolve to a different interpreter on PATH.
# - Skips the download when the model package is already importable, which
#   keeps Restart & Run All cheap and usable offline.
# - Raises on failure instead of silently continuing to later cells.
import importlib
import importlib.util
import subprocess
import sys

SPACY_MODEL = "en_core_web_sm"

if importlib.util.find_spec(SPACY_MODEL) is None:
    download = subprocess.run(
        [sys.executable, "-m", "spacy", "download", SPACY_MODEL],
        capture_output=True,
        text=True,
    )
    print(download.stdout)
    if download.returncode != 0:
        raise RuntimeError(f"spaCy model download failed:\n{download.stderr}")
    # Make the freshly installed package visible to imports in this kernel.
    importlib.invalidate_caches()
else:
    print(f"{SPACY_MODEL} is already installed; skipping download.")

Collecting en-core-web-sm==3.8.0
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl (12.8 MB)
     ---------------------------------------- 0.0/12.8 MB ? eta -:--:--
     - -------------------------------------- 0.5/12.8 MB 5.6 MB/s eta 0:00:03
     ------- -------------------------------- 2.4/12.8 MB 7.5 MB/s eta 0:00:02
     ----------------- ---------------------- 5.5/12.8 MB 10.5 MB/s eta 0:00:01
     ------------------------- -------------- 8.1/12.8 MB 11.2 MB/s eta 0:00:01
     --------------------------------- ----- 11.0/12.8 MB 11.9 MB/s eta 0:00:01
     --------------------------------------- 12.8/12.8 MB 11.6 MB/s eta 0:00:00
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_sm')


In [2]:
import json
from sklearn.metrics import classification_report
from utils.preprocessing import DOCRED_Processing # imported from preprocessing.py
import utils.train as train # imported from train.py

# Loading the training file

In [3]:
# Load the annotated DocRED training split and convert it into the
# entity-pair table built by DOCRED_Processing (utils/preprocessing.py).
# encoding="utf-8" is explicit because JSON text is UTF-8 (RFC 8259). Without
# it, open() uses the platform's locale encoding (e.g. cp1252 on Windows).
with open("./dataset_DocRED/train_annotated.json", "r", encoding="utf-8") as f:
    train_json = json.load(f)

# Keep the raw JSON and the processed dataset under distinct names rather than
# rebinding one name to two different kinds of object.
train_Dataset = DOCRED_Processing(train_json)
train_Dataset.df.head()

Unnamed: 0,entity1,entity2,original_doc,sent_tokens,sentence,relation,entity1_span,entity2_span
0,"Zest Airways, Inc.",Pasay City,AirAsia Zest,"[Zest, Airways, ,, Inc., operated, as, AirAsia...","Zest Airways , Inc. operated as AirAsia Zest (...",headquarters location,"[0, 1, 3]","[31, 32]"
2,Zest Air,Philippines,AirAsia Zest,"[Less, than, a, year, after, AirAsia, and, Zes...",Less than a year after AirAsia and Zest Air '...,country,"[7, 8]",[28]
3,Pasay City,Philippines,AirAsia Zest,"[Zest, Airways, ,, Inc., operated, as, AirAsia...","Zest Airways , Inc. operated as AirAsia Zest (...",country,"[31, 32]",[38]
4,Pasay City,Metro Manila,AirAsia Zest,"[Zest, Airways, ,, Inc., operated, as, AirAsia...","Zest Airways , Inc. operated as AirAsia Zest (...",located in the administrative territorial entity,"[31, 32]","[34, 35]"
5,Philippines,Metro Manila,AirAsia Zest,"[Zest, Airways, ,, Inc., operated, as, AirAsia...","Zest Airways , Inc. operated as AirAsia Zest (...",contains administrative territorial entity,[38],"[34, 35]"


In [4]:
# X and Y preparation for training data
X_train = train_Dataset.df.drop(['sent_tokens', 'relation'], axis=1)
Y_train= train_Dataset.df['relation']

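# Loading the test file (DocRED dev split)

The labelled DocRED dev split (`dev.json`) is used as the held-out test set for evaluation.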

In [5]:
# Load the DocRED dev split, used here as the held-out test set. It is run
# through the same DOCRED_Processing pipeline as the training data.
# NOTE(review): presumably dev.json is used because DocRED's official test
# split ships without gold relation labels. Confirm.
# encoding="utf-8" is explicit because JSON text is UTF-8 (RFC 8259). Without
# it, open() uses the platform's locale encoding (e.g. cp1252 on Windows).
with open("./dataset_DocRED/dev.json", "r", encoding="utf-8") as f:
    test_json = json.load(f)

# Keep the raw JSON and the processed dataset under distinct names rather than
# rebinding one name to two different kinds of object.
test_Dataset = DOCRED_Processing(test_json)
test_Dataset.df.head()

Unnamed: 0,entity1,entity2,original_doc,sent_tokens,sentence,relation,entity1_span,entity2_span
0,Piraeus,Greece,Skai TV,"[Skai, TV, is, a, Greek, free, -, to, -, air, ...",Skai TV is a Greek free - to - air television...,country,[14],[36]
1,Skai Group,Greece,Skai TV,"[Skai, TV, is, a, Greek, free, -, to, -, air, ...",Skai TV is a Greek free - to - air television...,country,"[0, 22]",[54]
2,Athens,Greece,Skai TV,"[Skai, TV, is, a, Greek, free, -, to, -, air, ...",Skai TV is a Greek free - to - air television...,country,[30],[61]
3,Skai TV,Piraeus,Skai TV,"[Skai, TV, is, a, Greek, free, -, to, -, air, ...",Skai TV is a Greek free - to - air television ...,headquarters location,"[0, 1]",[14]
4,Skai TV,Skai Group,Skai TV,"[Skai, TV, is, a, Greek, free, -, to, -, air, ...",Skai TV is a Greek free - to - air television...,owned by,"[0, 1]","[0, 22]"


In [6]:
# X and Y preparation for testing data
X_test = test_Dataset.df.drop(['sent_tokens', 'relation'], axis=1)
Y_test = test_Dataset.df['relation']

# Train model

In [7]:
# Fit the relation classifier from utils/train.py (train_function) on the
# training features (X_train) and relation labels (Y_train).
# The returned model must provide .predict(X); both evaluation cells below
# call it on X_train and X_test.
# NOTE(review): if train_function is stochastic, seed it for reproducibility.
# Whether it is stochastic cannot be seen from this notebook.
model = train.train_function(X_train, Y_train)

# Printing classification reports

1. Train Dataset

In [8]:
# Per-relation precision / recall / F1 on the training data. This is a fit
# check, not an estimate of generalization.
# zero_division=0 reports 0.0 for relations the model never predicts. That is
# the same value scikit-learn's default ("warn") reports, but it avoids one
# UndefinedMetricWarning per affected metric.
Y_pred_train = model.predict(X_train)
print(classification_report(Y_train, Y_pred_train, zero_division=0))

                                                  precision    recall  f1-score   support

                         applies to jurisdiction       0.88      0.17      0.28       271
                                          author       0.92      0.65      0.77       300
                                  award received       0.99      0.67      0.80       161
                                   basin country       1.00      0.16      0.27        82
                                         capital       1.00      0.06      0.11        72
                                      capital of       1.00      0.02      0.03        62
                                     cast member       0.90      0.82      0.86       551
                                     chairperson       0.92      0.20      0.33        59
                                      characters       0.95      0.56      0.70       149
                                           child       0.64      0.34      0.45       355
         

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


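2. Test Dataset (DocRED dev split)

Scores drop sharply compared with the training report (e.g. `cast member` F1 0.86 → 0.32, and several relations score 0.00). Training-set performance therefore substantially overstates how well the model generalizes.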

In [9]:
# Per-relation precision / recall / F1 on the held-out DocRED dev split.
# zero_division=0 reports 0.0 for relations the model never predicts. That is
# the same value scikit-learn's default ("warn") reports, but it avoids one
# UndefinedMetricWarning per affected metric.
Y_pred_test = model.predict(X_test)
print(classification_report(Y_test, Y_pred_test, zero_division=0))

                                                  precision    recall  f1-score   support

                         applies to jurisdiction       0.00      0.00      0.00        67
                                          author       0.36      0.06      0.10        88
                                  award received       0.50      0.02      0.03        61
                                   basin country       0.00      0.00      0.00        30
                                         capital       0.00      0.00      0.00        25
                                      capital of       0.00      0.00      0.00        19
                                     cast member       0.69      0.21      0.32       168
                                     chairperson       0.00      0.00      0.00        20
                                      characters       0.00      0.00      0.00        71
                                           child       0.00      0.00      0.00        77
         

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
