In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
!pip install hazm

Collecting hazm
  Downloading hazm-0.7.0-py3-none-any.whl (316 kB)
[K     |████████████████████████████████| 316 kB 11.8 MB/s 
[?25hCollecting nltk==3.3
  Downloading nltk-3.3.0.zip (1.4 MB)
[K     |████████████████████████████████| 1.4 MB 35.4 MB/s 
[?25hCollecting libwapiti>=0.2.1
  Downloading libwapiti-0.2.1.tar.gz (233 kB)
[K     |████████████████████████████████| 233 kB 46.7 MB/s 
Building wheels for collected packages: nltk, libwapiti
  Building wheel for nltk (setup.py) ... [?25l[?25hdone
  Created wheel for nltk: filename=nltk-3.3-py3-none-any.whl size=1394488 sha256=1071f05970345c427ef48a38f689c715399b4aee8578948dd1088c7b0915c4d1
  Stored in directory: /root/.cache/pip/wheels/9b/fd/0c/d92302c876e5de87ebd7fc0979d82edb93e2d8d768bf71fac4
  Building wheel for libwapiti (setup.py) ... [?25l[?25hdone
  Created wheel for libwapiti: filename=libwapiti-0.2.1-cp37-cp37m-linux_x86_64.whl size=153804 sha256=d7529de832fd56960068b565b60d41e5c4e2d8dba4d391d12f2e74f9f6b36143
  Store

In [5]:
!pip install transformers

Collecting transformers
  Downloading transformers-4.9.1-py3-none-any.whl (2.6 MB)
[K     |████████████████████████████████| 2.6 MB 11.8 MB/s 
Collecting tokenizers<0.11,>=0.10.1
  Downloading tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)
[K     |████████████████████████████████| 3.3 MB 45.1 MB/s 
Collecting pyyaml>=5.1
  Downloading PyYAML-5.4.1-cp37-cp37m-manylinux1_x86_64.whl (636 kB)
[K     |████████████████████████████████| 636 kB 39.8 MB/s 
Collecting sacremoses
  Downloading sacremoses-0.0.45-py3-none-any.whl (895 kB)
[K     |████████████████████████████████| 895 kB 30.8 MB/s 
Collecting huggingface-hub==0.0.12
  Downloading huggingface_hub-0.0.12-py3-none-any.whl (37 kB)
Installing collected packages: tokenizers, sacremoses, pyyaml, huggingface-hub, transformers
  Attempting uninstall: pyyaml
    Found existing installation: PyYAML 3.13
    Uninstalling PyYAML-3.13:
      Successfully uninstalled 

In [6]:
import re
import pandas as pd
import requests
import spacy
from spacy import displacy
nlp = spacy.load('en_core_web_sm')

from spacy.matcher import Matcher 
from spacy.tokens import Span 

import networkx as nx

import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
from tqdm.notebook import tqdm
from torch.utils.data import TensorDataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification


from __future__ import unicode_literals
from hazm import *

In [7]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")  
device

device(type='cuda')

In [8]:
train_path = './drive/MyDrive/data_sets/PERLEX/train.txt'
test_path = './drive/MyDrive/data_sets/PERLEX/test.txt'

# train_path = './drive/MyDrive/dataset/perlex/train.txt'
# test_path = './drive/MyDrive/dataset/perlex/test.txt'

In [9]:
def get_e1(text):
  text = re.sub('<e2>', '', text)
  text = re.sub('</e2>', '', text)
  e = re.findall("<e1>(.*?)</e1>", text)
  pre_process = set()
  for ent in e:
    if ent != ' ' and len(ent.strip())>1:
      pre_process.add(ent.strip())
  return pre_process

In [10]:
def get_e2(text):
  text = re.sub('<e1>', '', text)
  text = re.sub('</e1>', '', text)
  e = re.findall("<e2>(.*?)</e2>", text)
  pre_process = set()
  for ent in e:
    if ent != ' ' and len(ent.strip())>1:
      pre_process.add(ent.strip())
  return pre_process

In [11]:
SEMEVAL_RELATION_LABELS = ['Other', 'Message-Topic(e1,e2)', 'Message-Topic(e2,e1)',
                   'Product-Producer(e1,e2)', 'Product-Producer(e2,e1)',
                   'Instrument-Agency(e1,e2)', 'Instrument-Agency(e2,e1)',
                   'Entity-Destination(e1,e2)', 'Entity-Destination(e2,e1)',
                   'Cause-Effect(e1,e2)', 'Cause-Effect(e2,e1)',
                   'Component-Whole(e1,e2)', 'Component-Whole(e2,e1)',
                   'Entity-Origin(e1,e2)', 'Entity-Origin(e2,e1)',
                   'Member-Collection(e1,e2)', 'Member-Collection(e2,e1)',
                   'Content-Container(e1,e2)', 'Content-Container(e2,e1)']

In [12]:
indx2label = dict(enumerate(SEMEVAL_RELATION_LABELS))
indx2label

{0: 'Other',
 1: 'Message-Topic(e1,e2)',
 2: 'Message-Topic(e2,e1)',
 3: 'Product-Producer(e1,e2)',
 4: 'Product-Producer(e2,e1)',
 5: 'Instrument-Agency(e1,e2)',
 6: 'Instrument-Agency(e2,e1)',
 7: 'Entity-Destination(e1,e2)',
 8: 'Entity-Destination(e2,e1)',
 9: 'Cause-Effect(e1,e2)',
 10: 'Cause-Effect(e2,e1)',
 11: 'Component-Whole(e1,e2)',
 12: 'Component-Whole(e2,e1)',
 13: 'Entity-Origin(e1,e2)',
 14: 'Entity-Origin(e2,e1)',
 15: 'Member-Collection(e1,e2)',
 16: 'Member-Collection(e2,e1)',
 17: 'Content-Container(e1,e2)',
 18: 'Content-Container(e2,e1)'}

In [13]:
indx2label

{0: 'Other',
 1: 'Message-Topic(e1,e2)',
 2: 'Message-Topic(e2,e1)',
 3: 'Product-Producer(e1,e2)',
 4: 'Product-Producer(e2,e1)',
 5: 'Instrument-Agency(e1,e2)',
 6: 'Instrument-Agency(e2,e1)',
 7: 'Entity-Destination(e1,e2)',
 8: 'Entity-Destination(e2,e1)',
 9: 'Cause-Effect(e1,e2)',
 10: 'Cause-Effect(e2,e1)',
 11: 'Component-Whole(e1,e2)',
 12: 'Component-Whole(e2,e1)',
 13: 'Entity-Origin(e1,e2)',
 14: 'Entity-Origin(e2,e1)',
 15: 'Member-Collection(e1,e2)',
 16: 'Member-Collection(e2,e1)',
 17: 'Content-Container(e1,e2)',
 18: 'Content-Container(e2,e1)'}

In [14]:
label2index = {v:k for k,v in indx2label.items()}
label2index

{'Cause-Effect(e1,e2)': 9,
 'Cause-Effect(e2,e1)': 10,
 'Component-Whole(e1,e2)': 11,
 'Component-Whole(e2,e1)': 12,
 'Content-Container(e1,e2)': 17,
 'Content-Container(e2,e1)': 18,
 'Entity-Destination(e1,e2)': 7,
 'Entity-Destination(e2,e1)': 8,
 'Entity-Origin(e1,e2)': 13,
 'Entity-Origin(e2,e1)': 14,
 'Instrument-Agency(e1,e2)': 5,
 'Instrument-Agency(e2,e1)': 6,
 'Member-Collection(e1,e2)': 15,
 'Member-Collection(e2,e1)': 16,
 'Message-Topic(e1,e2)': 1,
 'Message-Topic(e2,e1)': 2,
 'Other': 0,
 'Product-Producer(e1,e2)': 3,
 'Product-Producer(e2,e1)': 4}

In [24]:
normilizer = Normalizer(persian_numbers=False, token_based=True)

In [25]:
def process_sentence(sentence):
  sentence = sentence.replace('<e1>','[E1]').replace('</e1>','[E1]').replace('<e2>','[E2]').replace('</e2>','[E2]')
  sentence = normilizer.normalize(sentence)
  # stemmer = Stemmer()
  # stemmer = stemmer.stem(sentence)
  # words = word_tokenize(sentence)
  return sentence

In [26]:
def label_spliter(label):
  # component-whole(e2,e1)
  # return [whole , component ]
  order = re.findall("\((.*?)\)", label)[0].split(',')
  label_names = label.split("(")[0].split("-")
  if order[0] == "e1":
    return label_names
  else:
    return list(reversed(label_names))

def make_dataframe_row(sentence,label):
  
  if label == "Other":
    labels =[label, label]
  else:
    labels = label_spliter(label=label)

  result = []
  e1s = get_e1(sentence)  
  e2s = get_e2(sentence)
  clean_sentence = process_sentence(sentence)
  for e1 in e1s:
    for e2 in e2s:

      result.append( {
          "e1":e1,
          "e2":e2,
          "e1_label":labels[0],
          "e2_label":labels[1],
          "label":label,
          "nlabel": label2index[label],
          "sentence": clean_sentence
      })

  return result
    
def make_dataframe(path):
    f = open(path, 'r')
    data = [x.rstrip() for x in f] 
    data_set_rows = []
    for i in range(0, len(data)-4, 4):
      item = data[i].split('\t')
      sentence = re.sub('[!@#$،.]', '', item[1])

      label = data[i+1]
      rows = make_dataframe_row(sentence , label)
      data_set_rows += rows

    return pd.DataFrame(data_set_rows)

In [27]:
df = make_dataframe(train_path)
df_test = make_dataframe(test_path)
df_test

Unnamed: 0,e1,e2,e1_label,e2_label,label,nlabel,sentence
0,حسابرسی‌ها,ضایعات,Message,Topic,"Message-Topic(e1,e2)",1,«معمول‌ترین [E1] حسابرسی‌ها [E1] مربوط به [E2]...
1,شرکت,صندلی‌های,Producer,Product,"Product-Producer(e2,e1)",4,«این [E1] شرکت [E2] [E1] صندلی‌های [E2] پلاستی...
2,استاد,چوب,Agency,Instrument,"Instrument-Agency(e2,e1)",6,«[E1] استاد [E1] مدرسه با یک [E2] چوب [E2] درس...
3,بدن,آب‌انبار,Entity,Destination,"Entity-Destination(e1,e2)",7,«مظنون [E1] بدن [E1] مرده را به یک [E2] آب‌انب...
4,آنفولانزای,ویروس,Effect,Cause,"Cause-Effect(e2,e1)",10,«[E1] آنفولانزای [E1] مرغی یک بیماری عفونی پرن...
...,...,...,...,...,...,...,...
2745,بقایای,طوفان,Entity,Origin,"Entity-Origin(e1,e2)",13,«هوا دیروز بادی و سرد بود و هنوز [E1] بقایای [...
2746,پادشاه,جارو می کشد,Agency,Instrument,"Instrument-Agency(e2,e1)",6,«پس از جاگذاری تمام بتها که خود ساعت‌ها طول می...
2747,مصالح,صنایع,Product,Producer,"Product-Producer(e1,e2)",3,«وزیر تولید کند [E1] مصالح [E1] توسط [E2] صنای...
2748,چتر,قاب,Whole,Component,"Component-Whole(e2,e1)",12,«[E2] قاب [E1] [E2] چتر [E1] دارای یک گیره متح...


In [28]:
X_train = df.sentence.tolist()
X_val =  df_test.sentence.tolist()
y_train = df.nlabel.tolist()
y_val =  df_test.nlabel.tolist()

In [30]:
# model_name = 'bert-base-multilingual-cased'
# model_name = 'xlm-roberta-base'
model_name = 'HooshvareLab/bert-fa-zwnj-base'

tokenizer = AutoTokenizer.from_pretrained(model_name)
model  = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=len(label2index))

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=292.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=565.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=426422.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1108824.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=134.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=473451616.0, style=ProgressStyle(descri…




Some weights of the model checkpoint at HooshvareLab/bert-fa-zwnj-base were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at HooshvareLab/b

In [31]:
special_tokens_dict = {'additional_special_tokens': ['[E1]','[E2]']}
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
model.resize_token_embeddings(len(tokenizer))

Embedding(42002, 768)

In [44]:
df['len_sentence'] = df['sentence'].apply(lambda t: len(tokenizer.tokenize(t)))
print(df['len_sentence'].max())
print(df['len_sentence'].min())

119
10


In [45]:
df_test['len_sentence'] = df_test['sentence'].apply(lambda t: len(tokenizer.tokenize(t)))
print(df_test['len_sentence'].max())
print(df_test['len_sentence'].min())

85
10


In [46]:
# Finding the share of comments with a specific length
def data_gl_than(data, less_than=100.0, greater_than=0.0, col='len_sentence'):
    data_length = data[col].values

    data_glt = sum([1 for length in data_length if greater_than < length <= less_than])

    data_glt_rate = (data_glt / len(data_length)) * 100

    print(f'Texts with word length of greater than {greater_than} and less than {less_than} includes {data_glt_rate:.2f}% of the whole!')

In [47]:
data_gl_than(df, 75, 10), data_gl_than(df_test, 75, 10)

Texts with word length of greater than 10 and less than 75 includes 99.63% of the whole!
Texts with word length of greater than 10 and less than 75 includes 99.71% of the whole!


(None, None)

In [48]:
df = df.drop(columns='len_sentence')
df_test = df_test.drop(columns='len_sentence')

In [None]:
model.to(device)

In [None]:
model

In [37]:
encoded_data_train = tokenizer.batch_encode_plus(
    X_train, 
    add_special_tokens=True, 
    return_attention_mask=True, 
    pad_to_max_length=True, 
    max_length=75, 
    return_tensors='pt'
)

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


In [38]:
encoded_data_val = tokenizer.batch_encode_plus(
    X_val, 
    add_special_tokens=True, 
    return_attention_mask=True, 
    pad_to_max_length=True, 
    max_length=75, 
    return_tensors='pt'
)




In [39]:
input_ids_train = encoded_data_train['input_ids']
attention_masks_train = encoded_data_train['attention_mask']
labels_train = torch.tensor(y_train)

input_ids_val = encoded_data_val['input_ids']
attention_masks_val = encoded_data_val['attention_mask']
labels_val = torch.tensor(y_val)

dataset_train = TensorDataset(input_ids_train, attention_masks_train, labels_train)
dataset_val = TensorDataset(input_ids_val, attention_masks_val, labels_val)

In [41]:
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
batch_size = 16
dataloader_train = DataLoader(dataset_train, 
                              sampler=RandomSampler(dataset_train), 
                              batch_size=batch_size)
dataloader_validation = DataLoader(dataset_val, 
                                   sampler=SequentialSampler(dataset_val), 
                                   batch_size=batch_size)

In [49]:
from transformers import AdamW, get_linear_schedule_with_warmup
optimizer = AdamW(model.parameters(),
                  lr=3e-5, 
                  eps=1e-8)
epochs = 3
scheduler = get_linear_schedule_with_warmup(optimizer, 
                                            num_warmup_steps=0,
                                            num_training_steps=len(dataloader_train)*epochs)

In [53]:
from sklearn.metrics import f1_score
def f1_score_func(preds, labels):
    preds_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return f1_score(labels_flat, preds_flat, average='weighted')

def accuracy_per_class(preds, labels):    
    preds_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    tt = 0
    for label in np.unique(labels_flat):
        y_preds = preds_flat[labels_flat==label]
        y_true = labels_flat[labels_flat==label]
        print(f'Class: {indx2label[label]}')
        print("Acc with percent:", len(y_preds[y_preds==label])/len(y_true))
        print(f'Accuracy: {len(y_preds[y_preds==label])}/{len(y_true)}\n')
        tt += len(y_preds[y_preds==label])/len(y_true)
    return tt/ len(label2index)


In [51]:
def evaluate(dataloader_val):

    model.eval()
    
    loss_val_total = 0
    predictions, true_vals = [], []
    
    for batch in dataloader_val:
        
        batch = tuple(b.to(device) for b in batch)
        
        inputs = {'input_ids':      batch[0],
                  'attention_mask': batch[1],
                  'labels':         batch[2],
                 }

        with torch.no_grad():        
            outputs = model(**inputs)
            
        loss = outputs[0]
        logits = outputs[1]
        loss_val_total += loss.item()

        logits = logits.detach().cpu().numpy()
        label_ids = inputs['labels'].cpu().numpy()
        predictions.append(logits)
        true_vals.append(label_ids)
    
    loss_val_avg = loss_val_total/len(dataloader_val) 
    
    predictions = np.concatenate(predictions, axis=0)
    true_vals = np.concatenate(true_vals, axis=0)
            
    return loss_val_avg, predictions, true_vals
    

In [52]:
import random
import numpy as np

seed_val = 17
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)
training_stats = []


for epoch in tqdm(range(1, epochs+1)):
    
    model.train()
    
    loss_train_total = 0

    progress_bar = tqdm(dataloader_train, desc='Epoch {:1d}'.format(epoch), leave=False, disable=False)
    for batch in progress_bar:

        model.zero_grad()
        
        batch = tuple(b.to(device) for b in batch)
        
        inputs = {'input_ids':      batch[0],
                  'attention_mask': batch[1],
                  'labels':         batch[2],
                 }  
        
        outputs = model(**inputs, output_hidden_states=True)
        
        loss = outputs[0]
        loss_train_total += loss.item()
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        scheduler.step()
        
        progress_bar.set_postfix({'training_loss': '{:.3f}'.format(loss.item()/len(batch))})
         
        
    # torch.save(model.state_dict(), f'data_volume/finetuned_BERT_epoch_{epoch}.model')
        
    tqdm.write(f'\nEpoch {epoch}')
    

    loss_train_avg = loss_train_total/len(dataloader_train)            
    tqdm.write(f'Training loss: {loss_train_avg}')
    

    val_loss, predictions, true_vals = evaluate(dataloader_validation)
    val_f1 = f1_score_func(predictions, true_vals)
    tt_accuracy = accuracy_per_class(predictions, true_vals)
    
    tqdm.write(f'accuracy: {tt_accuracy}')
    tqdm.write(f'Validation loss: {val_loss}')
    tqdm.write(f'F1 Score (Weighted): {val_f1}')

    training_stats.append(
        {
            'epoch': epoch,
            'Training Loss': loss_train_avg,
            'Valid. Loss': val_loss,
            'Valid. Accur.': tt_accuracy
        }
    )


HBox(children=(FloatProgress(value=0.0, max=3.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='Epoch 1', max=507.0, style=ProgressStyle(description_widt…


Epoch 1
Training loss: 1.5712746308870333
Class: Other
Acc with percent: 0.3123644251626898
Accuracy: 144/461

Class: Message-Topic(e1,e2)
Acc with percent: 0.9036697247706422
Accuracy: 197/218

Class: Message-Topic(e2,e1)
Acc with percent: 0.8627450980392157
Accuracy: 44/51

Class: Product-Producer(e1,e2)
Acc with percent: 0.8181818181818182
Accuracy: 90/110

Class: Product-Producer(e2,e1)
Acc with percent: 0.88
Accuracy: 110/125

Class: Instrument-Agency(e1,e2)
Acc with percent: 0.0
Accuracy: 0/22

Class: Instrument-Agency(e2,e1)
Acc with percent: 0.5333333333333333
Accuracy: 72/135

Class: Entity-Destination(e1,e2)
Acc with percent: 0.8698630136986302
Accuracy: 254/292

Class: Cause-Effect(e1,e2)
Acc with percent: 0.8731343283582089
Accuracy: 117/134

Class: Cause-Effect(e2,e1)
Acc with percent: 0.9025641025641026
Accuracy: 176/195

Class: Component-Whole(e1,e2)
Acc with percent: 0.7048192771084337
Accuracy: 117/166

Class: Component-Whole(e2,e1)
Acc with percent: 0.56493506493506

HBox(children=(FloatProgress(value=0.0, description='Epoch 2', max=507.0, style=ProgressStyle(description_widt…


Epoch 2
Training loss: 0.7327737126506763
Class: Other
Acc with percent: 0.5184381778741866
Accuracy: 239/461

Class: Message-Topic(e1,e2)
Acc with percent: 0.8394495412844036
Accuracy: 183/218

Class: Message-Topic(e2,e1)
Acc with percent: 0.7843137254901961
Accuracy: 40/51

Class: Product-Producer(e1,e2)
Acc with percent: 0.7818181818181819
Accuracy: 86/110

Class: Product-Producer(e2,e1)
Acc with percent: 0.776
Accuracy: 97/125

Class: Instrument-Agency(e1,e2)
Acc with percent: 0.4090909090909091
Accuracy: 9/22

Class: Instrument-Agency(e2,e1)
Acc with percent: 0.6518518518518519
Accuracy: 88/135

Class: Entity-Destination(e1,e2)
Acc with percent: 0.8664383561643836
Accuracy: 253/292

Class: Cause-Effect(e1,e2)
Acc with percent: 0.9104477611940298
Accuracy: 122/134

Class: Cause-Effect(e2,e1)
Acc with percent: 0.8102564102564103
Accuracy: 158/195

Class: Component-Whole(e1,e2)
Acc with percent: 0.7650602409638554
Accuracy: 127/166

Class: Component-Whole(e2,e1)
Acc with percent: 0

HBox(children=(FloatProgress(value=0.0, description='Epoch 3', max=507.0, style=ProgressStyle(description_widt…


Epoch 3
Training loss: 0.40734691546102014
Class: Other
Acc with percent: 0.46203904555314534
Accuracy: 213/461

Class: Message-Topic(e1,e2)
Acc with percent: 0.8853211009174312
Accuracy: 193/218

Class: Message-Topic(e2,e1)
Acc with percent: 0.8627450980392157
Accuracy: 44/51

Class: Product-Producer(e1,e2)
Acc with percent: 0.7818181818181819
Accuracy: 86/110

Class: Product-Producer(e2,e1)
Acc with percent: 0.776
Accuracy: 97/125

Class: Instrument-Agency(e1,e2)
Acc with percent: 0.45454545454545453
Accuracy: 10/22

Class: Instrument-Agency(e2,e1)
Acc with percent: 0.7925925925925926
Accuracy: 107/135

Class: Entity-Destination(e1,e2)
Acc with percent: 0.8458904109589042
Accuracy: 247/292

Class: Cause-Effect(e1,e2)
Acc with percent: 0.8880597014925373
Accuracy: 119/134

Class: Cause-Effect(e2,e1)
Acc with percent: 0.9025641025641026
Accuracy: 176/195

Class: Component-Whole(e1,e2)
Acc with percent: 0.7650602409638554
Accuracy: 127/166

Class: Component-Whole(e2,e1)
Acc with perce

In [54]:
_, predictions, true_vals = evaluate(dataloader_validation)

In [55]:
totatl_acc = accuracy_per_class(predictions, true_vals)

Class: Other
Acc with percent: 0.46203904555314534
Accuracy: 213/461

Class: Message-Topic(e1,e2)
Acc with percent: 0.8853211009174312
Accuracy: 193/218

Class: Message-Topic(e2,e1)
Acc with percent: 0.8627450980392157
Accuracy: 44/51

Class: Product-Producer(e1,e2)
Acc with percent: 0.7818181818181819
Accuracy: 86/110

Class: Product-Producer(e2,e1)
Acc with percent: 0.776
Accuracy: 97/125

Class: Instrument-Agency(e1,e2)
Acc with percent: 0.45454545454545453
Accuracy: 10/22

Class: Instrument-Agency(e2,e1)
Acc with percent: 0.7925925925925926
Accuracy: 107/135

Class: Entity-Destination(e1,e2)
Acc with percent: 0.8458904109589042
Accuracy: 247/292

Class: Cause-Effect(e1,e2)
Acc with percent: 0.8880597014925373
Accuracy: 119/134

Class: Cause-Effect(e2,e1)
Acc with percent: 0.9025641025641026
Accuracy: 176/195

Class: Component-Whole(e1,e2)
Acc with percent: 0.7650602409638554
Accuracy: 127/166

Class: Component-Whole(e2,e1)
Acc with percent: 0.6948051948051948
Accuracy: 107/154

Cla

In [56]:
totatl_acc

0.7135061159582567

In [58]:
outputs.keys()

odict_keys(['loss', 'logits', 'hidden_states'])