In [2]:
# Load IPython's autoreload extension; the `%autoreload` call before importing
# sbert_training relies on it to pick up edits to the local helper modules.
%load_ext autoreload

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
import json
import os
import random
import sys

import numpy as np
import pandas as pd
import torch

# Show full cell contents (long premises/conclusions) when displaying frames.
pd.set_option('display.max_colwidth', None)

# Make the local helper modules in ./src-py (e.g. sbert_training) importable.
# Guarded so re-running this cell does not keep appending duplicate entries.
if './src-py' not in sys.path:
    sys.path.append('./src-py')

# Seed the Python, NumPy and PyTorch RNGs so Restart & Run All is reproducible
# (some CUDA kernels may still be non-deterministic).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [4]:
# Without a mode argument, %autoreload performs a one-off reload of all
# already-imported modules, so the latest sbert_training code is used.
%autoreload
# Local helper module (found via the './src-py' entry added to sys.path);
# provides train_model, used by the training cells.
import sbert_training
# NOTE(review): not referenced in the cells shown here — remove if unused.
from sklearn.metrics import precision_recall_fscore_support

In [5]:
from sentence_transformers import SentenceTransformer, InputExample, LoggingHandler, losses, models, util
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, BinaryClassificationEvaluator
from torch.utils.data import DataLoader
from sentence_transformers.evaluation import TripletEvaluator
from datetime import datetime
from sentence_transformers import util
from zipfile import ZipFile
from sentence_transformers.datasets import SentenceLabelDataset
from sentence_transformers.datasets import NoDuplicatesDataLoader

import logging

# INFO-level, timestamped log lines routed through sentence-transformers'
# LoggingHandler, e.g. "2022-07-02 19:06:24 - Use pytorch device: cuda".
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    handlers=[LoggingHandler()],
)
# Logger for messages emitted from this notebook itself.
logger = logging.getLogger(__name__)

In [6]:
# Shared-task splits for Task A and Task B: training sets first...
taska_training_df, taskb_training_df = (
    pd.read_csv('../data/TaskA_train.csv'),
    pd.read_csv('../data/TaskB_train.csv'),
)

# ...then the dev sets used for validation.
taska_valid_df, taskb_valid_df = (
    pd.read_csv('../data/TaskA_dev.csv'),
    pd.read_csv('../data/TaskB_dev.csv'),
)

In [7]:
# Root output directory passed to sbert_training.train_model (relative to the
# notebook's working directory). It already ends with '/', so concatenating a
# sub-path that starts with '/' yields '//'; prefer os.path.join for sub-dirs.
output_path = "../../data-ceph/arguana/argmining22-sharedtask/models/"

In [8]:
# Map Task A validity annotations to binary labels and add the topic to each premise.
# - Rows with Validity == 0 are dropped (presumably undecided annotations — confirm
#   against the task guidelines); Validity == 1 -> label 1, anything else (e.g. -1) -> 0.
# - The topic is prepended as "<topic> : <premise>".
# `.assign` returns new frames, so no column is written into a filtered slice
# (avoids pandas' SettingWithCopyWarning / chained-assignment pitfalls).
# NOTE: not idempotent — re-running this cell without re-loading the CSVs
# prepends the topic to Premise a second time.
taska_training_df = taska_training_df[taska_training_df.Validity != 0].assign(
    label=lambda d: (d.Validity == 1).astype('int64'),
    Premise=lambda d: d['topic'] + ' : ' + d['Premise'],
)
taska_valid_df = taska_valid_df[taska_valid_df.Validity != 0].assign(
    label=lambda d: (d.Validity == 1).astype('int64'),
    Premise=lambda d: d['topic'] + ' : ' + d['Premise'],
)

# Vectorised concatenation yields NaN (instead of raising) on missing text; fail loudly.
assert taska_training_df['Premise'].notna().all(), 'missing topic/premise in training data'
assert taska_valid_df['Premise'].notna().all(), 'missing topic/premise in dev data'

In [9]:
# One BERT-style string per training pair: topic, premise and conclusion joined
# by literal [CLS]/[SEP] markers (built for the training split only).
# NOTE(review): Premise was already prefixed with "<topic> : " in the label-mapping
# cell, so the topic appears twice here (see the head() output). RoBERTa tokenizers
# normally use <s>/</s>, so these markers are likely tokenized as plain text —
# confirm whether train_model still uses this column.
taska_training_df['input_txt'] = taska_training_df[['topic', 'Premise', 'Conclusion']].apply(
    lambda fields: '[CLS] {} [SEP] {} [SEP] {} [SEP]'.format(*fields),
    axis=1,
)

In [10]:
# Peek at the first five prepared training rows.
taska_training_df.head(n=5)

Unnamed: 0,topic,Premise,Conclusion,Validity,Validity-Confidence,Novelty,Novelty-Confidence,label,input_txt
0,TV viewing is harmful to children,"TV viewing is harmful to children : The popularity of TV watching is among the reasons of this phenomenon. Violence, aggression, crimes and wars are broadcast through the daily news as well as in movies, showing dark pictures that encourage psychological tension, pessimism and negative emotions.",Depression is a well-known psychological problem of modern society that is partly caused by TV watching:,1,confident,1,confident,1,"[CLS] TV viewing is harmful to children [SEP] TV viewing is harmful to children : The popularity of TV watching is among the reasons of this phenomenon. Violence, aggression, crimes and wars are broadcast through the daily news as well as in movies, showing dark pictures that encourage psychological tension, pessimism and negative emotions. [SEP] Depression is a well-known psychological problem of modern society that is partly caused by TV watching: [SEP]"
1,TV viewing is harmful to children,"TV viewing is harmful to children : The popularity of TV watching is among the reasons of this phenomenon. Violence, aggression, crimes and wars are broadcast through the daily news as well as in movies, showing dark pictures that encourage psychological tension, pessimism and negative emotions.",Children's TV viewing fosters negative emotions,1,very confident,-1,majority,1,"[CLS] TV viewing is harmful to children [SEP] TV viewing is harmful to children : The popularity of TV watching is among the reasons of this phenomenon. Violence, aggression, crimes and wars are broadcast through the daily news as well as in movies, showing dark pictures that encourage psychological tension, pessimism and negative emotions. [SEP] Children's TV viewing fosters negative emotions [SEP]"
2,TV viewing is harmful to children,"TV viewing is harmful to children : The popularity of TV watching is among the reasons of this phenomenon. Violence, aggression, crimes and wars are broadcast through the daily news as well as in movies, showing dark pictures that encourage psychological tension, pessimism and negative emotions.",Popularity of TV is harmful to children,1,very confident,1,majority,1,"[CLS] TV viewing is harmful to children [SEP] TV viewing is harmful to children : The popularity of TV watching is among the reasons of this phenomenon. Violence, aggression, crimes and wars are broadcast through the daily news as well as in movies, showing dark pictures that encourage psychological tension, pessimism and negative emotions. [SEP] Popularity of TV is harmful to children [SEP]"
3,TV viewing is harmful to children,"TV viewing is harmful to children : The popularity of TV watching is among the reasons of this phenomenon. Violence, aggression, crimes and wars are broadcast through the daily news as well as in movies, showing dark pictures that encourage psychological tension, pessimism and negative emotions.",Violence on TV and in movies encourages psychological stress,1,very confident,1,majority,1,"[CLS] TV viewing is harmful to children [SEP] TV viewing is harmful to children : The popularity of TV watching is among the reasons of this phenomenon. Violence, aggression, crimes and wars are broadcast through the daily news as well as in movies, showing dark pictures that encourage psychological tension, pessimism and negative emotions. [SEP] Violence on TV and in movies encourages psychological stress [SEP]"
4,TV viewing is harmful to children,"TV viewing is harmful to children : The popularity of TV watching is among the reasons of this phenomenon. Violence, aggression, crimes and wars are broadcast through the daily news as well as in movies, showing dark pictures that encourage psychological tension, pessimism and negative emotions.",US-India deal does not cap or limit Indian fissile material production.,-1,very confident,-1,very confident,0,"[CLS] TV viewing is harmful to children [SEP] TV viewing is harmful to children : The popularity of TV watching is among the reasons of this phenomenon. Violence, aggression, crimes and wars are broadcast through the daily news as well as in movies, showing dark pictures that encourage psychological tension, pessimism and negative emotions. [SEP] US-India deal does not cap or limit Indian fissile material production. [SEP]"


In [11]:
# Class balance of the binary training labels.
taska_training_df['label'].value_counts()

1    401
0    320
Name: label, dtype: int64

In [12]:
# Class balance of the binary dev labels.
taska_valid_df['label'].value_counts()

1    125
0     74
Name: label, dtype: int64

In [14]:
# Earlier configuration kept for reference only (not executed): it calls an
# unqualified train_model with 20 epochs, batch size 16, max_seq_length 256 and
# a different output directory ('../data/output/').
#train_model(taska_training_df, taska_valid_df, '../data/output/', 'sentence-transformers/nli-roberta-large', num_epochs=20, train_batch_size=16, model_suffix='', data_file_suffix='', max_seq_length=256, special_tokens=[], loss='ContrastiveLoss', sentence_transformer=False)

In [13]:
# Fine-tune nli-roberta-large with ContrastiveLoss on the original Task A data,
# evaluating on the dev split every 10 steps; returned values are discarded.
# os.path.join avoids the '//' that `output_path + '/task-A/...'` produced
# (output_path already ends with '/'); the trailing '' keeps the trailing separator.
taska_sbert_dir = os.path.join(output_path, 'task-A', 'validity', 'sbert', '')
_, _ = sbert_training.train_model(taska_training_df, taska_valid_df,
                                  taska_sbert_dir,
                                  'sentence-transformers/nli-roberta-large',
                                  num_epochs=15, train_batch_size=32,
                                  model_suffix='', max_seq_length=512, special_tokens=[],
                                  loss='ContrastiveLoss', sentence_transformer=False, evaluation_steps=10)

2022-07-02 19:06:24 - Use pytorch device: cuda
2022-07-02 19:06:24 - Read Triplet train dataset
Len of training: 721
Len of Dev: 199
Evaluating before start learning.....
2022-07-02 19:06:24 - Binary Accuracy Evaluation of the model on sts-dev dataset:
2022-07-02 19:06:27 - Accuracy with Cosine-Similarity:           72.36	(Threshold: 0.6503)
2022-07-02 19:06:27 - F1 with Cosine-Similarity:                 80.69	(Threshold: 0.5985)
2022-07-02 19:06:27 - Precision with Cosine-Similarity:          70.91
2022-07-02 19:06:27 - Recall with Cosine-Similarity:             93.60
2022-07-02 19:06:27 - Average Precision with Cosine-Similarity:  82.54

2022-07-02 19:06:27 - Accuracy with Manhatten-Distance:           72.36	(Threshold: 646.3732)
2022-07-02 19:06:27 - F1 with Manhatten-Distance:                 80.56	(Threshold: 679.7071)
2022-07-02 19:06:27 - Precision with Manhatten-Distance:          71.17
2022-07-02 19:06:27 - Recall with Manhatten-Distance:             92.80
2022-07-02 19:06:27



Epoch:   0%|          | 0/15 [00:00<?, ?it/s]

Iteration:   0%|          | 0/23 [00:00<?, ?it/s]

2022-07-02 19:06:30 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 0 after 10 steps:
2022-07-02 19:06:30 - Accuracy with Cosine-Similarity:           72.36	(Threshold: 0.6616)
2022-07-02 19:06:30 - F1 with Cosine-Similarity:                 80.41	(Threshold: 0.6062)
2022-07-02 19:06:30 - Precision with Cosine-Similarity:          70.48
2022-07-02 19:06:30 - Recall with Cosine-Similarity:             93.60
2022-07-02 19:06:30 - Average Precision with Cosine-Similarity:  82.40

2022-07-02 19:06:30 - Accuracy with Manhatten-Distance:           72.36	(Threshold: 631.9717)
2022-07-02 19:06:30 - F1 with Manhatten-Distance:                 80.56	(Threshold: 672.8953)
2022-07-02 19:06:30 - Precision with Manhatten-Distance:          71.17
2022-07-02 19:06:30 - Recall with Manhatten-Distance:             92.80
2022-07-02 19:06:30 - Average Precision with Manhatten-Distance:  82.55

2022-07-02 19:06:30 - Accuracy with Euclidean-Distance:           72.36	(Threshold: 25.3380

Iteration:   0%|          | 0/23 [00:00<?, ?it/s]

2022-07-02 19:07:05 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 1 after 10 steps:
2022-07-02 19:07:05 - Accuracy with Cosine-Similarity:           73.37	(Threshold: 0.7339)
2022-07-02 19:07:05 - F1 with Cosine-Similarity:                 80.59	(Threshold: 0.7339)
2022-07-02 19:07:05 - Precision with Cosine-Similarity:          74.32
2022-07-02 19:07:05 - Recall with Cosine-Similarity:             88.00
2022-07-02 19:07:05 - Average Precision with Cosine-Similarity:  81.96

2022-07-02 19:07:05 - Accuracy with Manhatten-Distance:           72.36	(Threshold: 544.4036)
2022-07-02 19:07:05 - F1 with Manhatten-Distance:                 80.28	(Threshold: 571.1272)
2022-07-02 19:07:05 - Precision with Manhatten-Distance:          71.70
2022-07-02 19:07:05 - Recall with Manhatten-Distance:             91.20
2022-07-02 19:07:05 - Average Precision with Manhatten-Distance:  81.70

2022-07-02 19:07:05 - Accuracy with Euclidean-Distance:           73.37	(Threshold: 22.1106

Iteration:   0%|          | 0/23 [00:00<?, ?it/s]

2022-07-02 19:07:10 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 2 after 10 steps:
2022-07-02 19:07:10 - Accuracy with Cosine-Similarity:           73.37	(Threshold: 0.7817)
2022-07-02 19:07:10 - F1 with Cosine-Similarity:                 80.00	(Threshold: 0.6802)
2022-07-02 19:07:10 - Precision with Cosine-Similarity:          67.03
2022-07-02 19:07:10 - Recall with Cosine-Similarity:             99.20
2022-07-02 19:07:10 - Average Precision with Cosine-Similarity:  81.69

2022-07-02 19:07:10 - Accuracy with Manhatten-Distance:           74.37	(Threshold: 520.3919)
2022-07-02 19:07:10 - F1 with Manhatten-Distance:                 81.04	(Threshold: 520.3919)
2022-07-02 19:07:10 - Precision with Manhatten-Distance:          75.69
2022-07-02 19:07:10 - Recall with Manhatten-Distance:             87.20
2022-07-02 19:07:10 - Average Precision with Manhatten-Distance:  81.62

2022-07-02 19:07:10 - Accuracy with Euclidean-Distance:           73.37	(Threshold: 19.8955

Iteration:   0%|          | 0/23 [00:00<?, ?it/s]

2022-07-02 19:07:15 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 3 after 10 steps:
2022-07-02 19:07:15 - Accuracy with Cosine-Similarity:           73.87	(Threshold: 0.7634)
2022-07-02 19:07:15 - F1 with Cosine-Similarity:                 80.39	(Threshold: 0.6529)
2022-07-02 19:07:15 - Precision with Cosine-Similarity:          67.20
2022-07-02 19:07:15 - Recall with Cosine-Similarity:             100.00
2022-07-02 19:07:15 - Average Precision with Cosine-Similarity:  81.96

2022-07-02 19:07:15 - Accuracy with Manhatten-Distance:           74.87	(Threshold: 528.1063)
2022-07-02 19:07:15 - F1 with Manhatten-Distance:                 80.71	(Threshold: 555.5344)
2022-07-02 19:07:15 - Precision with Manhatten-Distance:          72.90
2022-07-02 19:07:15 - Recall with Manhatten-Distance:             90.40
2022-07-02 19:07:15 - Average Precision with Manhatten-Distance:  82.12

2022-07-02 19:07:15 - Accuracy with Euclidean-Distance:           73.87	(Threshold: 21.384

Iteration:   0%|          | 0/23 [00:00<?, ?it/s]

2022-07-02 19:07:20 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 4 after 10 steps:
2022-07-02 19:07:21 - Accuracy with Cosine-Similarity:           72.86	(Threshold: 0.7683)
2022-07-02 19:07:21 - F1 with Cosine-Similarity:                 80.39	(Threshold: 0.6650)
2022-07-02 19:07:21 - Precision with Cosine-Similarity:          67.20
2022-07-02 19:07:21 - Recall with Cosine-Similarity:             100.00
2022-07-02 19:07:21 - Average Precision with Cosine-Similarity:  81.89

2022-07-02 19:07:21 - Accuracy with Manhatten-Distance:           73.37	(Threshold: 522.3141)
2022-07-02 19:07:21 - F1 with Manhatten-Distance:                 80.28	(Threshold: 553.6518)
2022-07-02 19:07:21 - Precision with Manhatten-Distance:          71.70
2022-07-02 19:07:21 - Recall with Manhatten-Distance:             91.20
2022-07-02 19:07:21 - Average Precision with Manhatten-Distance:  81.76

2022-07-02 19:07:21 - Accuracy with Euclidean-Distance:           72.86	(Threshold: 21.021

Iteration:   0%|          | 0/23 [00:00<?, ?it/s]

2022-07-02 19:07:25 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 5 after 10 steps:
2022-07-02 19:07:25 - Accuracy with Cosine-Similarity:           71.86	(Threshold: 0.7578)
2022-07-02 19:07:25 - F1 with Cosine-Similarity:                 81.21	(Threshold: 0.7024)
2022-07-02 19:07:25 - Precision with Cosine-Similarity:          69.94
2022-07-02 19:07:25 - Recall with Cosine-Similarity:             96.80
2022-07-02 19:07:25 - Average Precision with Cosine-Similarity:  81.99

2022-07-02 19:07:25 - Accuracy with Manhatten-Distance:           73.37	(Threshold: 544.8882)
2022-07-02 19:07:25 - F1 with Manhatten-Distance:                 81.33	(Threshold: 598.2496)
2022-07-02 19:07:25 - Precision with Manhatten-Distance:          69.71
2022-07-02 19:07:25 - Recall with Manhatten-Distance:             97.60
2022-07-02 19:07:25 - Average Precision with Manhatten-Distance:  82.05

2022-07-02 19:07:25 - Accuracy with Euclidean-Distance:           71.86	(Threshold: 21.0494

Iteration:   0%|          | 0/23 [00:00<?, ?it/s]

2022-07-02 19:07:30 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 6 after 10 steps:
2022-07-02 19:07:31 - Accuracy with Cosine-Similarity:           71.86	(Threshold: 0.7076)
2022-07-02 19:07:31 - F1 with Cosine-Similarity:                 81.19	(Threshold: 0.6934)
2022-07-02 19:07:31 - Precision with Cosine-Similarity:          69.10
2022-07-02 19:07:31 - Recall with Cosine-Similarity:             98.40
2022-07-02 19:07:31 - Average Precision with Cosine-Similarity:  81.68

2022-07-02 19:07:31 - Accuracy with Manhatten-Distance:           72.36	(Threshold: 555.4594)
2022-07-02 19:07:31 - F1 with Manhatten-Distance:                 80.95	(Threshold: 586.5836)
2022-07-02 19:07:31 - Precision with Manhatten-Distance:          70.41
2022-07-02 19:07:31 - Recall with Manhatten-Distance:             95.20
2022-07-02 19:07:31 - Average Precision with Manhatten-Distance:  81.76

2022-07-02 19:07:31 - Accuracy with Euclidean-Distance:           71.36	(Threshold: 23.6716

Iteration:   0%|          | 0/23 [00:00<?, ?it/s]

2022-07-02 19:07:36 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 7 after 10 steps:
2022-07-02 19:07:36 - Accuracy with Cosine-Similarity:           72.36	(Threshold: 0.6921)
2022-07-02 19:07:36 - F1 with Cosine-Similarity:                 81.73	(Threshold: 0.6921)
2022-07-02 19:07:36 - Precision with Cosine-Similarity:          69.89
2022-07-02 19:07:36 - Recall with Cosine-Similarity:             98.40
2022-07-02 19:07:36 - Average Precision with Cosine-Similarity:  81.49

2022-07-02 19:07:36 - Accuracy with Manhatten-Distance:           72.36	(Threshold: 572.4706)
2022-07-02 19:07:36 - F1 with Manhatten-Distance:                 81.10	(Threshold: 581.8716)
2022-07-02 19:07:36 - Precision with Manhatten-Distance:          71.08
2022-07-02 19:07:36 - Recall with Manhatten-Distance:             94.40
2022-07-02 19:07:36 - Average Precision with Manhatten-Distance:  81.66

2022-07-02 19:07:36 - Accuracy with Euclidean-Distance:           72.36	(Threshold: 23.1658

Iteration:   0%|          | 0/23 [00:00<?, ?it/s]

2022-07-02 19:07:41 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 8 after 10 steps:
2022-07-02 19:07:41 - Accuracy with Cosine-Similarity:           72.86	(Threshold: 0.7224)
2022-07-02 19:07:41 - F1 with Cosine-Similarity:                 81.43	(Threshold: 0.6792)
2022-07-02 19:07:41 - Precision with Cosine-Similarity:          68.68
2022-07-02 19:07:41 - Recall with Cosine-Similarity:             100.00
2022-07-02 19:07:41 - Average Precision with Cosine-Similarity:  81.22

2022-07-02 19:07:41 - Accuracy with Manhatten-Distance:           72.86	(Threshold: 565.0312)
2022-07-02 19:07:41 - F1 with Manhatten-Distance:                 81.25	(Threshold: 566.7938)
2022-07-02 19:07:41 - Precision with Manhatten-Distance:          71.78
2022-07-02 19:07:41 - Recall with Manhatten-Distance:             93.60
2022-07-02 19:07:41 - Average Precision with Manhatten-Distance:  81.37

2022-07-02 19:07:41 - Accuracy with Euclidean-Distance:           73.37	(Threshold: 23.057

Iteration:   0%|          | 0/23 [00:00<?, ?it/s]

2022-07-02 19:07:46 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 9 after 10 steps:
2022-07-02 19:07:46 - Accuracy with Cosine-Similarity:           73.37	(Threshold: 0.7172)
2022-07-02 19:07:46 - F1 with Cosine-Similarity:                 81.79	(Threshold: 0.7141)
2022-07-02 19:07:46 - Precision with Cosine-Similarity:          71.69
2022-07-02 19:07:46 - Recall with Cosine-Similarity:             95.20
2022-07-02 19:07:46 - Average Precision with Cosine-Similarity:  81.07

2022-07-02 19:07:46 - Accuracy with Manhatten-Distance:           73.37	(Threshold: 574.0347)
2022-07-02 19:07:46 - F1 with Manhatten-Distance:                 81.66	(Threshold: 574.0347)
2022-07-02 19:07:46 - Precision with Manhatten-Distance:          71.95
2022-07-02 19:07:46 - Recall with Manhatten-Distance:             94.40
2022-07-02 19:07:46 - Average Precision with Manhatten-Distance:  81.18

2022-07-02 19:07:46 - Accuracy with Euclidean-Distance:           73.37	(Threshold: 23.2309

Iteration:   0%|          | 0/23 [00:00<?, ?it/s]

2022-07-02 19:07:51 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 10 after 10 steps:
2022-07-02 19:07:51 - Accuracy with Cosine-Similarity:           73.37	(Threshold: 0.7169)
2022-07-02 19:07:51 - F1 with Cosine-Similarity:                 81.79	(Threshold: 0.7097)
2022-07-02 19:07:51 - Precision with Cosine-Similarity:          71.69
2022-07-02 19:07:51 - Recall with Cosine-Similarity:             95.20
2022-07-02 19:07:51 - Average Precision with Cosine-Similarity:  81.16

2022-07-02 19:07:51 - Accuracy with Manhatten-Distance:           73.87	(Threshold: 577.9833)
2022-07-02 19:07:51 - F1 with Manhatten-Distance:                 82.07	(Threshold: 577.9833)
2022-07-02 19:07:51 - Precision with Manhatten-Distance:          72.12
2022-07-02 19:07:51 - Recall with Manhatten-Distance:             95.20
2022-07-02 19:07:51 - Average Precision with Manhatten-Distance:  81.43

2022-07-02 19:07:51 - Accuracy with Euclidean-Distance:           73.87	(Threshold: 23.361

Iteration:   0%|          | 0/23 [00:00<?, ?it/s]

2022-07-02 19:07:56 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 11 after 10 steps:
2022-07-02 19:07:56 - Accuracy with Cosine-Similarity:           73.87	(Threshold: 0.7192)
2022-07-02 19:07:56 - F1 with Cosine-Similarity:                 81.94	(Threshold: 0.7192)
2022-07-02 19:07:56 - Precision with Cosine-Similarity:          72.39
2022-07-02 19:07:56 - Recall with Cosine-Similarity:             94.40
2022-07-02 19:07:56 - Average Precision with Cosine-Similarity:  81.25

2022-07-02 19:07:56 - Accuracy with Manhatten-Distance:           73.87	(Threshold: 578.4261)
2022-07-02 19:07:56 - F1 with Manhatten-Distance:                 82.07	(Threshold: 578.4261)
2022-07-02 19:07:56 - Precision with Manhatten-Distance:          72.12
2022-07-02 19:07:56 - Recall with Manhatten-Distance:             95.20
2022-07-02 19:07:56 - Average Precision with Manhatten-Distance:  81.36

2022-07-02 19:07:56 - Accuracy with Euclidean-Distance:           73.87	(Threshold: 23.102

Iteration:   0%|          | 0/23 [00:00<?, ?it/s]

2022-07-02 19:08:01 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 12 after 10 steps:
2022-07-02 19:08:01 - Accuracy with Cosine-Similarity:           73.87	(Threshold: 0.7212)
2022-07-02 19:08:01 - F1 with Cosine-Similarity:                 81.94	(Threshold: 0.7212)
2022-07-02 19:08:01 - Precision with Cosine-Similarity:          72.39
2022-07-02 19:08:01 - Recall with Cosine-Similarity:             94.40
2022-07-02 19:08:01 - Average Precision with Cosine-Similarity:  81.21

2022-07-02 19:08:01 - Accuracy with Manhatten-Distance:           73.87	(Threshold: 575.0896)
2022-07-02 19:08:01 - F1 with Manhatten-Distance:                 82.07	(Threshold: 575.0896)
2022-07-02 19:08:01 - Precision with Manhatten-Distance:          72.12
2022-07-02 19:08:01 - Recall with Manhatten-Distance:             95.20
2022-07-02 19:08:01 - Average Precision with Manhatten-Distance:  81.42

2022-07-02 19:08:01 - Accuracy with Euclidean-Distance:           73.87	(Threshold: 23.073

Iteration:   0%|          | 0/23 [00:00<?, ?it/s]

2022-07-02 19:08:06 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 13 after 10 steps:
2022-07-02 19:08:06 - Accuracy with Cosine-Similarity:           73.87	(Threshold: 0.7220)
2022-07-02 19:08:06 - F1 with Cosine-Similarity:                 81.94	(Threshold: 0.7220)
2022-07-02 19:08:06 - Precision with Cosine-Similarity:          72.39
2022-07-02 19:08:06 - Recall with Cosine-Similarity:             94.40
2022-07-02 19:08:06 - Average Precision with Cosine-Similarity:  81.10

2022-07-02 19:08:06 - Accuracy with Manhatten-Distance:           73.87	(Threshold: 574.9849)
2022-07-02 19:08:06 - F1 with Manhatten-Distance:                 82.07	(Threshold: 574.9849)
2022-07-02 19:08:06 - Precision with Manhatten-Distance:          72.12
2022-07-02 19:08:06 - Recall with Manhatten-Distance:             95.20
2022-07-02 19:08:06 - Average Precision with Manhatten-Distance:  81.27

2022-07-02 19:08:06 - Accuracy with Euclidean-Distance:           73.87	(Threshold: 23.082

Iteration:   0%|          | 0/23 [00:00<?, ?it/s]

2022-07-02 19:08:11 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 14 after 10 steps:
2022-07-02 19:08:12 - Accuracy with Cosine-Similarity:           73.87	(Threshold: 0.7225)
2022-07-02 19:08:12 - F1 with Cosine-Similarity:                 81.94	(Threshold: 0.7225)
2022-07-02 19:08:12 - Precision with Cosine-Similarity:          72.39
2022-07-02 19:08:12 - Recall with Cosine-Similarity:             94.40
2022-07-02 19:08:12 - Average Precision with Cosine-Similarity:  81.06

2022-07-02 19:08:12 - Accuracy with Manhatten-Distance:           73.87	(Threshold: 575.3561)
2022-07-02 19:08:12 - F1 with Manhatten-Distance:                 82.07	(Threshold: 575.3561)
2022-07-02 19:08:12 - Precision with Manhatten-Distance:          72.12
2022-07-02 19:08:12 - Recall with Manhatten-Distance:             95.20
2022-07-02 19:08:12 - Average Precision with Manhatten-Distance:  81.29

2022-07-02 19:08:12 - Accuracy with Euclidean-Distance:           73.87	(Threshold: 23.044

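### Using the auto-generated conclusions

Retrain the same SBERT validity model, this time on the training set augmented with automatically generated conclusions (`TaskA_train_with_extra_conclusions.pkl`). The dev set is unchanged, so the scores are directly comparable with the run above.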

In [14]:
# Training data augmented with auto-generated conclusions, plus a fresh copy of the
# dev split, so the label mapping below starts from the raw (un-prefixed) dev data.
# NOTE: pd.read_pickle can execute arbitrary code — only load trusted files.
taska_training_df, taska_valid_df = (
    pd.read_pickle('../data/TaskA_train_with_extra_conclusions.pkl'),
    pd.read_csv('../data/TaskA_dev.csv'),
)

In [15]:
# Same preprocessing as for the original training data, applied to the augmented set:
# - Rows with Validity == 0 are dropped (presumably undecided annotations — confirm
#   against the task guidelines); Validity == 1 -> label 1, anything else (e.g. -1) -> 0.
# - The topic is prepended as "<topic> : <premise>".
# `.assign` returns new frames, so no column is written into a filtered slice
# (avoids pandas' SettingWithCopyWarning / chained-assignment pitfalls).
# NOTE: not idempotent — re-running this cell without re-loading the data
# prepends the topic to Premise a second time.
taska_training_df = taska_training_df[taska_training_df.Validity != 0].assign(
    label=lambda d: (d.Validity == 1).astype('int64'),
    Premise=lambda d: d['topic'] + ' : ' + d['Premise'],
)
taska_valid_df = taska_valid_df[taska_valid_df.Validity != 0].assign(
    label=lambda d: (d.Validity == 1).astype('int64'),
    Premise=lambda d: d['topic'] + ' : ' + d['Premise'],
)

# Vectorised concatenation yields NaN (instead of raising) on missing text; fail loudly.
assert taska_training_df['Premise'].notna().all(), 'missing topic/premise in training data'
assert taska_valid_df['Premise'].notna().all(), 'missing topic/premise in dev data'

In [16]:
# Same configuration as the baseline run, but trained on the augmented data and
# saved with the 'extra-conclusions' model suffix; returned values are discarded.
# os.path.join avoids the '//' that `output_path + '/task-A/...'` produced
# (output_path already ends with '/'); the trailing '' keeps the trailing separator.
taska_sbert_dir = os.path.join(output_path, 'task-A', 'validity', 'sbert', '')
_, _ = sbert_training.train_model(taska_training_df, taska_valid_df,
                                  taska_sbert_dir,
                                  'sentence-transformers/nli-roberta-large',
                                  num_epochs=15, train_batch_size=32,
                                  model_suffix='extra-conclusions', max_seq_length=512, special_tokens=[],
                                  loss='ContrastiveLoss', sentence_transformer=False, evaluation_steps=10)

2022-07-02 19:09:16 - Use pytorch device: cuda
2022-07-02 19:09:16 - Read Triplet train dataset
Len of training: 4471
Len of Dev: 199
Evaluating before start learning.....
2022-07-02 19:09:16 - Binary Accuracy Evaluation of the model on sts-dev dataset:
2022-07-02 19:09:17 - Accuracy with Cosine-Similarity:           72.36	(Threshold: 0.6503)
2022-07-02 19:09:17 - F1 with Cosine-Similarity:                 80.69	(Threshold: 0.5985)
2022-07-02 19:09:17 - Precision with Cosine-Similarity:          70.91
2022-07-02 19:09:17 - Recall with Cosine-Similarity:             93.60
2022-07-02 19:09:17 - Average Precision with Cosine-Similarity:  82.54

2022-07-02 19:09:17 - Accuracy with Manhatten-Distance:           72.36	(Threshold: 646.3732)
2022-07-02 19:09:17 - F1 with Manhatten-Distance:                 80.56	(Threshold: 679.7071)
2022-07-02 19:09:17 - Precision with Manhatten-Distance:          71.17
2022-07-02 19:09:17 - Recall with Manhatten-Distance:             92.80
2022-07-02 19:09:1



Epoch:   0%|          | 0/15 [00:00<?, ?it/s]

Iteration:   0%|          | 0/140 [00:00<?, ?it/s]

2022-07-02 19:09:19 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 0 after 10 steps:
2022-07-02 19:09:19 - Accuracy with Cosine-Similarity:           72.36	(Threshold: 0.6541)
2022-07-02 19:09:19 - F1 with Cosine-Similarity:                 80.41	(Threshold: 0.5983)
2022-07-02 19:09:19 - Precision with Cosine-Similarity:          70.48
2022-07-02 19:09:19 - Recall with Cosine-Similarity:             93.60
2022-07-02 19:09:19 - Average Precision with Cosine-Similarity:  82.52

2022-07-02 19:09:19 - Accuracy with Manhatten-Distance:           72.36	(Threshold: 644.4963)
2022-07-02 19:09:19 - F1 with Manhatten-Distance:                 80.56	(Threshold: 678.1983)
2022-07-02 19:09:19 - Precision with Manhatten-Distance:          71.17
2022-07-02 19:09:19 - Recall with Manhatten-Distance:             92.80
2022-07-02 19:09:19 - Average Precision with Manhatten-Distance:  82.46

2022-07-02 19:09:19 - Accuracy with Euclidean-Distance:           72.86	(Threshold: 26.2884

Iteration:   0%|          | 0/140 [00:00<?, ?it/s]

2022-07-02 19:10:17 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 1 after 10 steps:
2022-07-02 19:10:17 - Accuracy with Cosine-Similarity:           65.33	(Threshold: 0.9242)
2022-07-02 19:10:17 - F1 with Cosine-Similarity:                 77.96	(Threshold: 0.9007)
2022-07-02 19:10:17 - Precision with Cosine-Similarity:          64.89
2022-07-02 19:10:17 - Recall with Cosine-Similarity:             97.60
2022-07-02 19:10:17 - Average Precision with Cosine-Similarity:  75.03

2022-07-02 19:10:17 - Accuracy with Manhatten-Distance:           65.83	(Threshold: 307.0382)
2022-07-02 19:10:17 - F1 with Manhatten-Distance:                 77.96	(Threshold: 350.4466)
2022-07-02 19:10:17 - Precision with Manhatten-Distance:          64.89
2022-07-02 19:10:17 - Recall with Manhatten-Distance:             97.60
2022-07-02 19:10:17 - Average Precision with Manhatten-Distance:  75.00

2022-07-02 19:10:17 - Accuracy with Euclidean-Distance:           66.33	(Threshold: 11.9797

Iteration:   0%|          | 0/140 [00:00<?, ?it/s]

2022-07-02 19:10:48 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 2 after 10 steps:
2022-07-02 19:10:48 - Accuracy with Cosine-Similarity:           64.32	(Threshold: 0.8870)
2022-07-02 19:10:48 - F1 with Cosine-Similarity:                 77.74	(Threshold: 0.8453)
2022-07-02 19:10:48 - Precision with Cosine-Similarity:          63.92
2022-07-02 19:10:48 - Recall with Cosine-Similarity:             99.20
2022-07-02 19:10:48 - Average Precision with Cosine-Similarity:  71.16

2022-07-02 19:10:48 - Accuracy with Manhatten-Distance:           64.32	(Threshold: 365.0349)
2022-07-02 19:10:48 - F1 with Manhatten-Distance:                 77.88	(Threshold: 466.7104)
2022-07-02 19:10:48 - Precision with Manhatten-Distance:          63.78
2022-07-02 19:10:48 - Recall with Manhatten-Distance:             100.00
2022-07-02 19:10:48 - Average Precision with Manhatten-Distance:  71.31

2022-07-02 19:10:48 - Accuracy with Euclidean-Distance:           64.32	(Threshold: 14.777

Iteration:   0%|          | 0/140 [00:00<?, ?it/s]

2022-07-02 19:11:19 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 3 after 10 steps:
2022-07-02 19:11:19 - Accuracy with Cosine-Similarity:           63.82	(Threshold: 0.8115)
2022-07-02 19:11:19 - F1 with Cosine-Similarity:                 77.64	(Threshold: 0.7103)
2022-07-02 19:11:19 - Precision with Cosine-Similarity:          63.45
2022-07-02 19:11:19 - Recall with Cosine-Similarity:             100.00
2022-07-02 19:11:19 - Average Precision with Cosine-Similarity:  65.88

2022-07-02 19:11:19 - Accuracy with Manhatten-Distance:           64.32	(Threshold: 435.3516)
2022-07-02 19:11:19 - F1 with Manhatten-Distance:                 77.64	(Threshold: 574.8058)
2022-07-02 19:11:19 - Precision with Manhatten-Distance:          63.45
2022-07-02 19:11:19 - Recall with Manhatten-Distance:             100.00
2022-07-02 19:11:19 - Average Precision with Manhatten-Distance:  66.77

2022-07-02 19:11:19 - Accuracy with Euclidean-Distance:           63.82	(Threshold: 18.90

Iteration:   0%|          | 0/140 [00:00<?, ?it/s]

2022-07-02 19:11:50 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 4 after 10 steps:
2022-07-02 19:11:50 - Accuracy with Cosine-Similarity:           63.82	(Threshold: 0.7044)
2022-07-02 19:11:50 - F1 with Cosine-Similarity:                 77.50	(Threshold: 0.7044)
2022-07-02 19:11:50 - Precision with Cosine-Similarity:          63.59
2022-07-02 19:11:50 - Recall with Cosine-Similarity:             99.20
2022-07-02 19:11:50 - Average Precision with Cosine-Similarity:  63.30

2022-07-02 19:11:50 - Accuracy with Manhatten-Distance:           64.32	(Threshold: 532.2913)
2022-07-02 19:11:50 - F1 with Manhatten-Distance:                 77.74	(Threshold: 532.2913)
2022-07-02 19:11:50 - Precision with Manhatten-Distance:          63.92
2022-07-02 19:11:50 - Recall with Manhatten-Distance:             99.20
2022-07-02 19:11:50 - Average Precision with Manhatten-Distance:  64.61

2022-07-02 19:11:50 - Accuracy with Euclidean-Distance:           63.82	(Threshold: 20.3350

Iteration:   0%|          | 0/140 [00:00<?, ?it/s]

2022-07-02 19:12:20 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 5 after 10 steps:
2022-07-02 19:12:20 - Accuracy with Cosine-Similarity:           63.32	(Threshold: 0.6828)
2022-07-02 19:12:20 - F1 with Cosine-Similarity:                 77.40	(Threshold: 0.3714)
2022-07-02 19:12:20 - Precision with Cosine-Similarity:          63.13
2022-07-02 19:12:20 - Recall with Cosine-Similarity:             100.00
2022-07-02 19:12:20 - Average Precision with Cosine-Similarity:  62.08

2022-07-02 19:12:20 - Accuracy with Manhatten-Distance:           63.82	(Threshold: 614.5557)
2022-07-02 19:12:20 - F1 with Manhatten-Distance:                 77.64	(Threshold: 696.5840)
2022-07-02 19:12:20 - Precision with Manhatten-Distance:          63.45
2022-07-02 19:12:20 - Recall with Manhatten-Distance:             100.00
2022-07-02 19:12:20 - Average Precision with Manhatten-Distance:  63.80

2022-07-02 19:12:20 - Accuracy with Euclidean-Distance:           63.32	(Threshold: 24.55

Iteration:   0%|          | 0/140 [00:00<?, ?it/s]

2022-07-02 19:12:51 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 6 after 10 steps:
2022-07-02 19:12:51 - Accuracy with Cosine-Similarity:           62.81	(Threshold: 0.3332)
2022-07-02 19:12:51 - F1 with Cosine-Similarity:                 77.02	(Threshold: 0.3332)
2022-07-02 19:12:51 - Precision with Cosine-Similarity:          62.94
2022-07-02 19:12:51 - Recall with Cosine-Similarity:             99.20
2022-07-02 19:12:51 - Average Precision with Cosine-Similarity:  59.72

2022-07-02 19:12:51 - Accuracy with Manhatten-Distance:           63.82	(Threshold: 684.8413)
2022-07-02 19:12:51 - F1 with Manhatten-Distance:                 77.64	(Threshold: 684.8413)
2022-07-02 19:12:51 - Precision with Manhatten-Distance:          63.45
2022-07-02 19:12:51 - Recall with Manhatten-Distance:             100.00
2022-07-02 19:12:51 - Average Precision with Manhatten-Distance:  60.47

2022-07-02 19:12:51 - Accuracy with Euclidean-Distance:           62.81	(Threshold: 36.019

Iteration:   0%|          | 0/140 [00:00<?, ?it/s]

2022-07-02 19:13:21 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 7 after 10 steps:
2022-07-02 19:13:21 - Accuracy with Cosine-Similarity:           62.31	(Threshold: 0.1900)
2022-07-02 19:13:21 - F1 with Cosine-Similarity:                 76.78	(Threshold: 0.1468)
2022-07-02 19:13:21 - Precision with Cosine-Similarity:          62.63
2022-07-02 19:13:21 - Recall with Cosine-Similarity:             99.20
2022-07-02 19:13:21 - Average Precision with Cosine-Similarity:  59.94

2022-07-02 19:13:21 - Accuracy with Manhatten-Distance:           64.32	(Threshold: 710.5441)
2022-07-02 19:13:21 - F1 with Manhatten-Distance:                 77.88	(Threshold: 710.5441)
2022-07-02 19:13:21 - Precision with Manhatten-Distance:          63.78
2022-07-02 19:13:21 - Recall with Manhatten-Distance:             100.00
2022-07-02 19:13:21 - Average Precision with Manhatten-Distance:  60.81

2022-07-02 19:13:21 - Accuracy with Euclidean-Distance:           62.31	(Threshold: 39.689

Iteration:   0%|          | 0/140 [00:00<?, ?it/s]

2022-07-02 19:13:52 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 8 after 10 steps:
2022-07-02 19:13:52 - Accuracy with Cosine-Similarity:           62.81	(Threshold: 0.1412)
2022-07-02 19:13:52 - F1 with Cosine-Similarity:                 77.02	(Threshold: 0.1412)
2022-07-02 19:13:52 - Precision with Cosine-Similarity:          62.94
2022-07-02 19:13:52 - Recall with Cosine-Similarity:             99.20
2022-07-02 19:13:52 - Average Precision with Cosine-Similarity:  59.21

2022-07-02 19:13:52 - Accuracy with Manhatten-Distance:           63.32	(Threshold: 692.1079)
2022-07-02 19:13:52 - F1 with Manhatten-Distance:                 77.40	(Threshold: 715.6714)
2022-07-02 19:13:52 - Precision with Manhatten-Distance:          63.13
2022-07-02 19:13:52 - Recall with Manhatten-Distance:             100.00
2022-07-02 19:13:52 - Average Precision with Manhatten-Distance:  59.97

2022-07-02 19:13:52 - Accuracy with Euclidean-Distance:           62.81	(Threshold: 41.037

Iteration:   0%|          | 0/140 [00:00<?, ?it/s]

2022-07-02 19:14:22 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 9 after 10 steps:
2022-07-02 19:14:22 - Accuracy with Cosine-Similarity:           62.81	(Threshold: 0.0651)
2022-07-02 19:14:22 - F1 with Cosine-Similarity:                 77.02	(Threshold: 0.0651)
2022-07-02 19:14:22 - Precision with Cosine-Similarity:          62.94
2022-07-02 19:14:22 - Recall with Cosine-Similarity:             99.20
2022-07-02 19:14:22 - Average Precision with Cosine-Similarity:  58.16

2022-07-02 19:14:22 - Accuracy with Manhatten-Distance:           63.82	(Threshold: 715.0748)
2022-07-02 19:14:22 - F1 with Manhatten-Distance:                 77.64	(Threshold: 715.0748)
2022-07-02 19:14:22 - Precision with Manhatten-Distance:          63.45
2022-07-02 19:14:22 - Recall with Manhatten-Distance:             100.00
2022-07-02 19:14:22 - Average Precision with Manhatten-Distance:  59.38

2022-07-02 19:14:22 - Accuracy with Euclidean-Distance:           62.31	(Threshold: 36.539

Iteration:   0%|          | 0/140 [00:00<?, ?it/s]

2022-07-02 19:14:52 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 10 after 10 steps:
2022-07-02 19:14:53 - Accuracy with Cosine-Similarity:           62.81	(Threshold: 0.0221)
2022-07-02 19:14:53 - F1 with Cosine-Similarity:                 77.02	(Threshold: 0.0221)
2022-07-02 19:14:53 - Precision with Cosine-Similarity:          62.94
2022-07-02 19:14:53 - Recall with Cosine-Similarity:             99.20
2022-07-02 19:14:53 - Average Precision with Cosine-Similarity:  57.55

2022-07-02 19:14:53 - Accuracy with Manhatten-Distance:           63.82	(Threshold: 703.7850)
2022-07-02 19:14:53 - F1 with Manhatten-Distance:                 77.64	(Threshold: 703.7850)
2022-07-02 19:14:53 - Precision with Manhatten-Distance:          63.45
2022-07-02 19:14:53 - Recall with Manhatten-Distance:             100.00
2022-07-02 19:14:53 - Average Precision with Manhatten-Distance:  58.71

2022-07-02 19:14:53 - Accuracy with Euclidean-Distance:           62.81	(Threshold: 43.88

Iteration:   0%|          | 0/140 [00:00<?, ?it/s]

2022-07-02 19:15:23 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 11 after 10 steps:
2022-07-02 19:15:23 - Accuracy with Cosine-Similarity:           62.81	(Threshold: -0.0215)
2022-07-02 19:15:23 - F1 with Cosine-Similarity:                 77.02	(Threshold: -0.0215)
2022-07-02 19:15:23 - Precision with Cosine-Similarity:          62.94
2022-07-02 19:15:23 - Recall with Cosine-Similarity:             99.20
2022-07-02 19:15:23 - Average Precision with Cosine-Similarity:  57.48

2022-07-02 19:15:23 - Accuracy with Manhatten-Distance:           63.82	(Threshold: 702.4962)
2022-07-02 19:15:23 - F1 with Manhatten-Distance:                 77.64	(Threshold: 702.4962)
2022-07-02 19:15:23 - Precision with Manhatten-Distance:          63.45
2022-07-02 19:15:23 - Recall with Manhatten-Distance:             100.00
2022-07-02 19:15:23 - Average Precision with Manhatten-Distance:  58.50

2022-07-02 19:15:23 - Accuracy with Euclidean-Distance:           62.81	(Threshold: 44.

Iteration:   0%|          | 0/140 [00:00<?, ?it/s]

2022-07-02 19:15:53 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 12 after 10 steps:
2022-07-02 19:15:54 - Accuracy with Cosine-Similarity:           62.31	(Threshold: -0.0416)
2022-07-02 19:15:54 - F1 with Cosine-Similarity:                 76.78	(Threshold: -0.1222)
2022-07-02 19:15:54 - Precision with Cosine-Similarity:          62.63
2022-07-02 19:15:54 - Recall with Cosine-Similarity:             99.20
2022-07-02 19:15:54 - Average Precision with Cosine-Similarity:  56.34

2022-07-02 19:15:54 - Accuracy with Manhatten-Distance:           63.82	(Threshold: 685.2755)
2022-07-02 19:15:54 - F1 with Manhatten-Distance:                 77.64	(Threshold: 685.2755)
2022-07-02 19:15:54 - Precision with Manhatten-Distance:          63.45
2022-07-02 19:15:54 - Recall with Manhatten-Distance:             100.00
2022-07-02 19:15:54 - Average Precision with Manhatten-Distance:  57.16

2022-07-02 19:15:54 - Accuracy with Euclidean-Distance:           62.31	(Threshold: 45.

Iteration:   0%|          | 0/140 [00:00<?, ?it/s]

2022-07-02 19:16:24 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 13 after 10 steps:
2022-07-02 19:16:24 - Accuracy with Cosine-Similarity:           62.81	(Threshold: -0.0296)
2022-07-02 19:16:24 - F1 with Cosine-Similarity:                 77.02	(Threshold: -0.0296)
2022-07-02 19:16:24 - Precision with Cosine-Similarity:          62.94
2022-07-02 19:16:24 - Recall with Cosine-Similarity:             99.20
2022-07-02 19:16:24 - Average Precision with Cosine-Similarity:  56.93

2022-07-02 19:16:24 - Accuracy with Manhatten-Distance:           63.82	(Threshold: 693.0916)
2022-07-02 19:16:24 - F1 with Manhatten-Distance:                 77.64	(Threshold: 693.0916)
2022-07-02 19:16:24 - Precision with Manhatten-Distance:          63.45
2022-07-02 19:16:24 - Recall with Manhatten-Distance:             100.00
2022-07-02 19:16:24 - Average Precision with Manhatten-Distance:  58.26

2022-07-02 19:16:24 - Accuracy with Euclidean-Distance:           62.81	(Threshold: 45.

Iteration:   0%|          | 0/140 [00:00<?, ?it/s]

2022-07-02 19:16:54 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 14 after 10 steps:
2022-07-02 19:16:54 - Accuracy with Cosine-Similarity:           62.81	(Threshold: -0.0502)
2022-07-02 19:16:54 - F1 with Cosine-Similarity:                 77.02	(Threshold: -0.0502)
2022-07-02 19:16:54 - Precision with Cosine-Similarity:          62.94
2022-07-02 19:16:54 - Recall with Cosine-Similarity:             99.20
2022-07-02 19:16:54 - Average Precision with Cosine-Similarity:  56.47

2022-07-02 19:16:54 - Accuracy with Manhatten-Distance:           63.82	(Threshold: 688.8479)
2022-07-02 19:16:54 - F1 with Manhatten-Distance:                 77.64	(Threshold: 688.8479)
2022-07-02 19:16:54 - Precision with Manhatten-Distance:          63.45
2022-07-02 19:16:54 - Recall with Manhatten-Distance:             100.00
2022-07-02 19:16:54 - Average Precision with Manhatten-Distance:  57.63

2022-07-02 19:16:54 - Accuracy with Euclidean-Distance:           62.81	(Threshold: 45.

#### Generating predictions for argument validity:

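- Load the best model and use it to predict validity labels for the dev set (stored in `eval_df`).
- Best model noted here: nli-roberta-large-ranking-loss-extra-data-2022-07-01_14-27-18 (NOTE: the cell below actually loads `nli-roberta-large-extra-conclusions-2022-07-02_19-09-07`; confirm which checkpoint is the best one).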

In [31]:
# Checkpoint whose dev-set predictions are generated below; it lives under output_path.
# Earlier candidate: 'task-A/validity/sbert/sentence-transformers/nli-roberta-large--2022-07-02_19-06-14'.
model_path = ''.join((output_path, 'task-A/validity/sbert/sentence-transformers/nli-roberta-large-extra-conclusions-2022-07-02_19-09-07'))

In [32]:
# Re-read the raw dev split so the topic prefix below is applied exactly once,
# whatever earlier cells did to taska_valid_df.
taska_valid_df = (
    pd.read_csv('../data/TaskA_dev.csv')
    # Drop rows with Validity == 0 (same filter as the training preprocessing).
    .loc[lambda d: d['Validity'] != 0]
    # Prefix each premise with its topic. assign() builds a new frame, which avoids
    # the SettingWithCopyWarning raised by assigning a column into a filtered slice.
    .assign(Premise=lambda d: d['topic'] + ' : ' + d['Premise'])
)

In [33]:
def predict_labels(df, model_path, clm1, clm2, threshold, model=None):
    """Predict validity labels from the cosine similarity of two text columns.

    Both the `clm1` and `clm2` texts of every row in `df` are embedded with the
    encoder. A pair whose cosine similarity is strictly greater than `threshold`
    is labelled 1; every other pair is labelled -1.

    Args:
        df: DataFrame holding the text pairs. A `pred_labels` column is written
            into it in place, and the same object is returned.
        model_path: Path of the SentenceTransformer checkpoint to load. Ignored
            when `model` is given.
        clm1: Name of the first text column (e.g. 'Premise').
        clm2: Name of the second text column (e.g. 'Conclusion').
        threshold: Cosine-similarity cut-off.
        model: Optional already-loaded encoder. It must expose
            `encode(list_of_str)` returning one vector per text. Passing it avoids
            reloading the checkpoint on every call.

    Returns:
        `df`, with the added `pred_labels` column (values 1 / -1).
    """
    if len(df) == 0:
        # Nothing to score; skip loading/encoding entirely.
        df['pred_labels'] = []
        return df

    if model is None:
        model = SentenceTransformer(model_path)

    # Encode the columns of the frame that was passed in (not a global frame),
    # so the predictions line up row-for-row with `df`.
    first = np.asarray(model.encode(df[clm1].tolist()))
    second = np.asarray(model.encode(df[clm2].tolist()))

    # Row-wise cosine similarity. Norms are clamped to a tiny epsilon so an
    # all-zero embedding scores 0 rather than NaN.
    eps = 1e-12
    first_norm = np.maximum(np.linalg.norm(first, axis=1), eps)
    second_norm = np.maximum(np.linalg.norm(second, axis=1), eps)
    scores = np.sum(first * second, axis=1) / (first_norm * second_norm)

    df['pred_labels'] = np.where(scores > threshold, 1, -1)
    return df

In [34]:
# Score the dev pairs: cosine similarity above the threshold -> 1 (valid), otherwise -1.
# NOTE(review): the origin of 0.6146 is not shown in this section; confirm it is the
# threshold tuned for this checkpoint (tuning and scoring on the same dev split is optimistic).
# Earlier alternative: the same call with the checkpoint
# '../data/output/sentence-transformers/nli-roberta-large-ranking-loss-2022-07-01_14-24-04'.
eval_df = predict_labels(
    taska_valid_df,
    model_path,
    clm1='Premise',
    clm2='Conclusion',
    threshold=0.6146,
)

2022-07-02 19:21:34 - Load pretrained SentenceTransformer: ../../data-ceph/arguana/argmining22-sharedtask/models/task-A/validity/sbert/sentence-transformers/nli-roberta-large-extra-conclusions-2022-07-02_19-09-07
2022-07-02 19:21:36 - Use pytorch device: cuda


Batches:   0%|          | 0/7 [00:00<?, ?it/s]

Batches:   0%|          | 0/7 [00:00<?, ?it/s]

In [35]:
# Compare predicted labels to the gold Validity column; with average='binary',
# label 1 is the positive class.
precision, recall, f1, _ = precision_recall_fscore_support(
    y_true=eval_df['Validity'].tolist(),
    y_pred=eval_df['pred_labels'].tolist(),
    average='binary',
)

print('Precision: {}, Recall {}, F1: {}'.format(precision, recall, f1))

Precision: 0.70625, Recall 0.904, F1: 0.7929824561403509
