In [8]:
import json
import pandas as pd
import numpy as np
import sys

# Never truncate long text columns when a DataFrame is displayed.
pd.options.display.max_colwidth = None
# Put the local ./src-py directory on the import path for project modules.
sys.path.append('./src-py')

In [9]:
import sbert_training
from sklearn.metrics import precision_recall_fscore_support

In [10]:
from sentence_transformers import SentenceTransformer, InputExample, LoggingHandler, losses, models, util
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, BinaryClassificationEvaluator
from torch.utils.data import DataLoader
from sentence_transformers.evaluation import TripletEvaluator
from datetime import datetime
from sentence_transformers import util
from zipfile import ZipFile
from sentence_transformers.datasets import SentenceLabelDataset
from sentence_transformers.datasets import NoDuplicatesDataLoader

import logging

# Root logger: INFO level, timestamped messages, emitted through the
# LoggingHandler imported from sentence_transformers.
_log_settings = {
    'format': '%(asctime)s - %(message)s',
    'datefmt': '%Y-%m-%d %H:%M:%S',
    'level': logging.INFO,
    'handlers': [LoggingHandler()],
}
logging.basicConfig(**_log_settings)
logger = logging.getLogger(__name__)

In [11]:
# Read the train/dev CSV splits for Task A and Task B.
taska_training_df, taska_valid_df = (
    pd.read_csv(csv_path)
    for csv_path in ('../data/TaskA_train.csv', '../data/TaskA_dev.csv')
)
taskb_training_df, taskb_valid_df = (
    pd.read_csv(csv_path)
    for csv_path in ('../data/TaskB_train.csv', '../data/TaskB_dev.csv')
)

In [12]:
# Mapping labels (Novelty task)
# Drop rows whose Novelty is 0. `.copy()` makes the filtered frames independent
# objects, so the column assignments below do not raise SettingWithCopyWarning.
taska_training_df = taska_training_df[taska_training_df.Novelty != 0].copy()
taska_valid_df    = taska_valid_df[taska_valid_df.Novelty != 0].copy()

# Binary target: 1 when Novelty == 1, otherwise 0.
taska_training_df['label'] = (taska_training_df.Novelty == 1).astype(int)
taska_valid_df['label'] = (taska_valid_df.Novelty == 1).astype(int)

# Prefix every premise with its topic. Vectorised concatenation also works on an
# empty frame, where a row-wise apply(axis=1) would return a DataFrame instead of
# a Series. A missing topic/premise yields NaN here rather than a TypeError.
# NOTE: this overwrites `Premise`, so re-running the cell without reloading the
# data prefixes the topic a second time.
taska_training_df['Premise'] = taska_training_df['topic'] + ' : ' + taska_training_df['Premise']
taska_valid_df['Premise'] = taska_valid_df['topic'] + ' : ' + taska_valid_df['Premise']

In [13]:
def _format_input_txt(row):
    """Join topic, premise and conclusion of one row into a single marker-delimited string."""
    return '[CLS] {} [SEP] {} [SEP] {} [SEP]'.format(row['topic'], row['Premise'], row['Conclusion'])


# NOTE(review): when the label-mapping cell has already run, `Premise` starts with
# the topic, so the topic appears twice in `input_txt` - confirm this is intended.
taska_training_df['input_txt'] = taska_training_df.apply(_format_input_txt, axis=1)

In [16]:
# Show the first rows of the text columns next to the raw and mapped labels.
preview_columns = ['Premise', 'Conclusion', 'Novelty', 'label']
taska_training_df[preview_columns].head()

Unnamed: 0,Premise,Conclusion,Novelty,label
0,"TV viewing is harmful to children : The popularity of TV watching is among the reasons of this phenomenon. Violence, aggression, crimes and wars are broadcast through the daily news as well as in movies, showing dark pictures that encourage psychological tension, pessimism and negative emotions.",Depression is a well-known psychological problem of modern society that is partly caused by TV watching:,1,1
1,"TV viewing is harmful to children : The popularity of TV watching is among the reasons of this phenomenon. Violence, aggression, crimes and wars are broadcast through the daily news as well as in movies, showing dark pictures that encourage psychological tension, pessimism and negative emotions.",Children's TV viewing fosters negative emotions,-1,0
2,"TV viewing is harmful to children : The popularity of TV watching is among the reasons of this phenomenon. Violence, aggression, crimes and wars are broadcast through the daily news as well as in movies, showing dark pictures that encourage psychological tension, pessimism and negative emotions.",Popularity of TV is harmful to children,1,1
3,"TV viewing is harmful to children : The popularity of TV watching is among the reasons of this phenomenon. Violence, aggression, crimes and wars are broadcast through the daily news as well as in movies, showing dark pictures that encourage psychological tension, pessimism and negative emotions.",Violence on TV and in movies encourages psychological stress,1,1
4,"TV viewing is harmful to children : The popularity of TV watching is among the reasons of this phenomenon. Violence, aggression, crimes and wars are broadcast through the daily news as well as in movies, showing dark pictures that encourage psychological tension, pessimism and negative emotions.",US-India deal does not cap or limit Indian fissile material production.,-1,0


In [8]:
# Class balance of the binary label in the training split.
taska_training_df['label'].value_counts()

1    401
0    320
Name: label, dtype: int64

In [9]:
# Class balance of the binary label in the dev split.
taska_valid_df['label'].value_counts()

1    125
0     74
Name: label, dtype: int64

In [17]:
def get_training_examples(df, eval_df, loss):
    """Build the ``InputExample`` lists used for training and for evaluation.

    Parameters
    ----------
    df : pandas.DataFrame
        Training rows. For ``'ContrastiveLoss'`` and
        ``'MultipleNegativesRankingLoss'`` the columns ``Premise``,
        ``Conclusion`` and ``label`` are read. For any other ``loss`` value the
        columns ``anchor``, ``pos`` and ``neg`` are read.
    eval_df : pandas.DataFrame
        Evaluation rows; ``Premise``, ``Conclusion`` and ``label`` are read.
    loss : str
        ``'ContrastiveLoss'`` keeps every row as a (premise, conclusion) pair
        with its own label. ``'MultipleNegativesRankingLoss'`` keeps only rows
        whose label equals 1. Any other value produces
        (anchor, pos, neg) triplets.

    Returns
    -------
    tuple
        ``(train_examples, dev_samples)``, two lists of ``InputExample``.
    """
    # Log the example format actually being built; previously the message said
    # "Triplet" for every loss. Unknown loss names fall back to triplets.
    pair_losses = ('ContrastiveLoss', 'MultipleNegativesRankingLoss')
    dataset_kind = loss if loss in pair_losses else 'Triplet'
    logger.info("Read {} train dataset".format(dataset_kind))

    # The loss is fixed for the whole call, so pick the format once instead of
    # re-testing it on every row.
    if loss == 'ContrastiveLoss':
        train_examples = [
            InputExample(texts=[row['Premise'], row['Conclusion']], label=row['label'])
            for _, row in df.iterrows()
        ]
    elif loss == 'MultipleNegativesRankingLoss':
        # Only rows labelled 1 are kept as (premise, conclusion) pairs.
        train_examples = [
            InputExample(texts=[row['Premise'], row['Conclusion']], label=1)
            for _, row in df.iterrows()
            if row['label'] == 1
        ]
    else:
        train_examples = [
            InputExample(texts=[row['anchor'], row['pos'], row['neg']], label=0)
            for _, row in df.iterrows()
        ]

    dev_samples = [
        InputExample(texts=[row['Premise'], row['Conclusion']], label=row['label'])
        for _, row in eval_df.iterrows()
    ]

    return train_examples, dev_samples

            
def train_model(df, eval_df, output_path, model_name, num_epochs=3, train_batch_size=16, model_suffix='',
                data_file_suffix='', max_seq_length=256,
                special_tokens=[], loss='Triplet', sentence_transformer=False, evaluation_steps=5):
    """Fine-tune a mean-pooled sentence-embedding model and evaluate it on a dev set.

    Parameters
    ----------
    df, eval_df : pandas.DataFrame
        Forwarded to ``get_training_examples`` to build the training and dev
        examples.
    output_path : str
        Prefix of the output directory; ``model_name``, ``model_suffix`` and a
        timestamp are appended to it.
    model_name : str
        Model identifier loaded through ``models.Transformer``, or through
        ``SentenceTransformer`` when ``sentence_transformer`` is true.
    num_epochs : int
        Number of training epochs.
    train_batch_size : int
        Batch size of the training loader and of the evaluator.
    model_suffix : str
        Free-form tag embedded in the output directory name.
    data_file_suffix : str
        Not used by this function; kept so existing calls keep working.
    max_seq_length : int
        Value assigned to the loaded model's ``max_seq_length``.
    special_tokens : list of str
        Extra tokens added to the tokenizer. The default list is never mutated.
    loss : str
        ``'MultipleNegativesRankingLoss'`` or ``'ContrastiveLoss'``; any other
        value selects ``TripletLoss``.
    sentence_transformer : bool
        Load ``model_name`` as a full ``SentenceTransformer`` instead of a bare
        ``models.Transformer``.
    evaluation_steps : int
        Passed to ``fit`` as ``evaluation_steps`` and ``checkpoint_save_steps``.

    Returns
    -------
    tuple
        ``(model, evaluator)``: the fitted ``SentenceTransformer`` and the
        ``BinaryClassificationEvaluator`` built from ``eval_df``.
    """
    output_path = output_path + model_name + "-" + model_suffix + "-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    if sentence_transformer:
        word_embedding_model = SentenceTransformer(model_name)
        # A SentenceTransformer is a sequence of modules. The tokenizer, the
        # Hugging Face model and the word-embedding size belong to its first one.
        # NOTE(review): assumes that first module is a models.Transformer - confirm
        # for the checkpoints used with this flag.
        transformer_module = word_embedding_model[0]
    else:
        word_embedding_model = models.Transformer(model_name)
        transformer_module = word_embedding_model
    word_embedding_model.max_seq_length = max_seq_length

    if len(special_tokens) > 0:
        transformer_module.tokenizer.add_tokens(special_tokens, special_tokens=True)
        # resize_token_embeddings is a method of the wrapped Hugging Face model
        # (`auto_model`), not of the sentence-transformers wrapper. Calling it on
        # the wrapper raised AttributeError.
        transformer_module.auto_model.resize_token_embeddings(len(transformer_module.tokenizer))

    # Apply mean pooling to get one fixed sized sentence vector
    pooling_model = models.Pooling(transformer_module.get_word_embedding_dimension(),
                                   pooling_mode_mean_tokens=True,
                                   pooling_mode_cls_token=False,
                                   pooling_mode_max_tokens=False)

    model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

    train_examples, dev_samples = get_training_examples(df, eval_df, loss)

    print('Len of training: {}'.format(len(train_examples)))
    print('Len of Dev: {}'.format(len(dev_samples)))

    if loss == 'MultipleNegativesRankingLoss':
        # Special data loader that avoids duplicates within a batch
        train_dataloader = NoDuplicatesDataLoader(train_examples, batch_size=train_batch_size)
        train_loss = losses.MultipleNegativesRankingLoss(model)
    elif loss == 'ContrastiveLoss':
        train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=train_batch_size)
        train_loss = losses.ContrastiveLoss(model)
    else:
        train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=train_batch_size)
        train_loss = losses.TripletLoss(model)

    evaluator = BinaryClassificationEvaluator.from_input_examples(dev_samples, batch_size=train_batch_size, name='sts-dev')

    # Warm up over the first 10% of all optimisation steps.
    warmup_steps = int(len(train_dataloader) * num_epochs * 0.1)

    print('Evaluating before start learning.....')
    model.evaluate(evaluator)
    print('====== Start Training =======')

    # Train the model.
    # NOTE(review): no checkpoint_path is passed, so the two checkpoint_* arguments
    # presumably have no effect - confirm against the installed version.
    model.fit(train_objectives=[(train_dataloader, train_loss)],
              evaluator=evaluator,
              epochs=num_epochs,
              save_best_model=True,
              checkpoint_save_steps=evaluation_steps,
              optimizer_params={'lr':5e-06},
              checkpoint_save_total_limit=3,
              evaluation_steps=evaluation_steps,
              warmup_steps=warmup_steps,
              output_path=output_path)

    return model, evaluator

In [14]:
#train_model(taska_training_df, taska_valid_df, '../data/output/', 'sentence-transformers/nli-roberta-large', num_epochs=20, train_batch_size=16, model_suffix='', data_file_suffix='', max_seq_length=256, special_tokens=[], loss='ContrastiveLoss', sentence_transformer=False)

In [18]:
# Fine-tune nli-roberta-large on the current Task A frames with the ranking loss.
trained_model, evaluator = train_model(
    df=taska_training_df,
    eval_df=taska_valid_df,
    output_path='../data/output/',
    model_name='sentence-transformers/nli-roberta-large',
    num_epochs=15,
    train_batch_size=32,
    model_suffix='ranking-loss',
    max_seq_length=512,
    special_tokens=[],
    loss='MultipleNegativesRankingLoss',
    sentence_transformer=False,
    evaluation_steps=10,
)

2022-07-01 14:24:14 - Use pytorch device: cuda
2022-07-01 14:24:14 - Read Triplet train dataset
Len of training: 401
Len of Dev: 199
Evaluating before start learning.....
2022-07-01 14:24:14 - Binary Accuracy Evaluation of the model on sts-dev dataset:
2022-07-01 14:24:14 - Accuracy with Cosine-Similarity:           72.36	(Threshold: 0.6503)
2022-07-01 14:24:14 - F1 with Cosine-Similarity:                 80.69	(Threshold: 0.5985)
2022-07-01 14:24:14 - Precision with Cosine-Similarity:          70.91
2022-07-01 14:24:14 - Recall with Cosine-Similarity:             93.60
2022-07-01 14:24:14 - Average Precision with Cosine-Similarity:  82.54

2022-07-01 14:24:14 - Accuracy with Manhatten-Distance:           72.36	(Threshold: 646.3732)
2022-07-01 14:24:14 - F1 with Manhatten-Distance:                 80.56	(Threshold: 679.7071)
2022-07-01 14:24:14 - Precision with Manhatten-Distance:          71.17
2022-07-01 14:24:14 - Recall with Manhatten-Distance:             92.80
2022-07-01 14:24:14

Epoch:   0%|          | 0/15 [00:00<?, ?it/s]

Iteration:   0%|          | 0/12 [00:00<?, ?it/s]

2022-07-01 14:24:16 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 0 after 10 steps:
2022-07-01 14:24:16 - Accuracy with Cosine-Similarity:           72.36	(Threshold: 0.6591)
2022-07-01 14:24:16 - F1 with Cosine-Similarity:                 80.56	(Threshold: 0.6068)
2022-07-01 14:24:16 - Precision with Cosine-Similarity:          71.17
2022-07-01 14:24:16 - Recall with Cosine-Similarity:             92.80
2022-07-01 14:24:16 - Average Precision with Cosine-Similarity:  82.88

2022-07-01 14:24:16 - Accuracy with Manhatten-Distance:           72.86	(Threshold: 635.0920)
2022-07-01 14:24:16 - F1 with Manhatten-Distance:                 80.56	(Threshold: 672.8419)
2022-07-01 14:24:16 - Precision with Manhatten-Distance:          71.17
2022-07-01 14:24:16 - Recall with Manhatten-Distance:             92.80
2022-07-01 14:24:16 - Average Precision with Manhatten-Distance:  82.97

2022-07-01 14:24:16 - Accuracy with Euclidean-Distance:           72.86	(Threshold: 25.3105

Iteration:   0%|          | 0/12 [00:00<?, ?it/s]

2022-07-01 14:24:20 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 1 after 10 steps:
2022-07-01 14:24:20 - Accuracy with Cosine-Similarity:           72.86	(Threshold: 0.6574)
2022-07-01 14:24:20 - F1 with Cosine-Similarity:                 80.58	(Threshold: 0.6406)
2022-07-01 14:24:20 - Precision with Cosine-Similarity:          73.20
2022-07-01 14:24:20 - Recall with Cosine-Similarity:             89.60
2022-07-01 14:24:20 - Average Precision with Cosine-Similarity:  83.32

2022-07-01 14:24:20 - Accuracy with Manhatten-Distance:           74.87	(Threshold: 614.6072)
2022-07-01 14:24:20 - F1 with Manhatten-Distance:                 81.06	(Threshold: 614.6072)
2022-07-01 14:24:20 - Precision with Manhatten-Distance:          76.98
2022-07-01 14:24:20 - Recall with Manhatten-Distance:             85.60
2022-07-01 14:24:20 - Average Precision with Manhatten-Distance:  83.63

2022-07-01 14:24:20 - Accuracy with Euclidean-Distance:           72.86	(Threshold: 25.3056

Iteration:   0%|          | 0/12 [00:00<?, ?it/s]

2022-07-01 14:24:24 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 2 after 10 steps:
2022-07-01 14:24:25 - Accuracy with Cosine-Similarity:           72.36	(Threshold: 0.6960)
2022-07-01 14:24:25 - F1 with Cosine-Similarity:                 79.71	(Threshold: 0.6549)
2022-07-01 14:24:25 - Precision with Cosine-Similarity:          72.85
2022-07-01 14:24:25 - Recall with Cosine-Similarity:             88.00
2022-07-01 14:24:25 - Average Precision with Cosine-Similarity:  83.73

2022-07-01 14:24:25 - Accuracy with Manhatten-Distance:           75.38	(Threshold: 600.8503)
2022-07-01 14:24:25 - F1 with Manhatten-Distance:                 81.37	(Threshold: 600.8503)
2022-07-01 14:24:25 - Precision with Manhatten-Distance:          77.54
2022-07-01 14:24:25 - Recall with Manhatten-Distance:             85.60
2022-07-01 14:24:25 - Average Precision with Manhatten-Distance:  83.65

2022-07-01 14:24:25 - Accuracy with Euclidean-Distance:           72.86	(Threshold: 23.7834

Iteration:   0%|          | 0/12 [00:00<?, ?it/s]

2022-07-01 14:24:31 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 3 after 10 steps:
2022-07-01 14:24:31 - Accuracy with Cosine-Similarity:           71.86	(Threshold: 0.6867)
2022-07-01 14:24:31 - F1 with Cosine-Similarity:                 79.58	(Threshold: 0.6387)
2022-07-01 14:24:31 - Precision with Cosine-Similarity:          71.07
2022-07-01 14:24:31 - Recall with Cosine-Similarity:             90.40
2022-07-01 14:24:31 - Average Precision with Cosine-Similarity:  83.67

2022-07-01 14:24:31 - Accuracy with Manhatten-Distance:           74.37	(Threshold: 603.8773)
2022-07-01 14:24:31 - F1 with Manhatten-Distance:                 80.90	(Threshold: 611.5989)
2022-07-01 14:24:31 - Precision with Manhatten-Distance:          76.06
2022-07-01 14:24:31 - Recall with Manhatten-Distance:             86.40
2022-07-01 14:24:31 - Average Precision with Manhatten-Distance:  83.51

2022-07-01 14:24:31 - Accuracy with Euclidean-Distance:           71.36	(Threshold: 23.7285

Iteration:   0%|          | 0/12 [00:00<?, ?it/s]

2022-07-01 14:24:34 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 4 after 10 steps:
2022-07-01 14:24:34 - Accuracy with Cosine-Similarity:           70.85	(Threshold: 0.6375)
2022-07-01 14:24:34 - F1 with Cosine-Similarity:                 79.58	(Threshold: 0.6375)
2022-07-01 14:24:34 - Precision with Cosine-Similarity:          71.07
2022-07-01 14:24:34 - Recall with Cosine-Similarity:             90.40
2022-07-01 14:24:34 - Average Precision with Cosine-Similarity:  83.05

2022-07-01 14:24:34 - Accuracy with Manhatten-Distance:           74.37	(Threshold: 613.4158)
2022-07-01 14:24:34 - F1 with Manhatten-Distance:                 80.90	(Threshold: 613.4158)
2022-07-01 14:24:34 - Precision with Manhatten-Distance:          76.06
2022-07-01 14:24:34 - Recall with Manhatten-Distance:             86.40
2022-07-01 14:24:34 - Average Precision with Manhatten-Distance:  83.23

2022-07-01 14:24:34 - Accuracy with Euclidean-Distance:           70.35	(Threshold: 23.9833

Iteration:   0%|          | 0/12 [00:00<?, ?it/s]

2022-07-01 14:24:37 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 5 after 10 steps:
2022-07-01 14:24:37 - Accuracy with Cosine-Similarity:           70.85	(Threshold: 0.6529)
2022-07-01 14:24:37 - F1 with Cosine-Similarity:                 79.11	(Threshold: 0.4599)
2022-07-01 14:24:37 - Precision with Cosine-Similarity:          65.45
2022-07-01 14:24:37 - Recall with Cosine-Similarity:             100.00
2022-07-01 14:24:37 - Average Precision with Cosine-Similarity:  82.50

2022-07-01 14:24:37 - Accuracy with Manhatten-Distance:           73.37	(Threshold: 620.5010)
2022-07-01 14:24:37 - F1 with Manhatten-Distance:                 80.30	(Threshold: 622.2135)
2022-07-01 14:24:37 - Precision with Manhatten-Distance:          75.00
2022-07-01 14:24:37 - Recall with Manhatten-Distance:             86.40
2022-07-01 14:24:37 - Average Precision with Manhatten-Distance:  82.79

2022-07-01 14:24:37 - Accuracy with Euclidean-Distance:           70.85	(Threshold: 25.310

Iteration:   0%|          | 0/12 [00:00<?, ?it/s]

2022-07-01 14:24:40 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 6 after 10 steps:
2022-07-01 14:24:40 - Accuracy with Cosine-Similarity:           71.36	(Threshold: 0.6476)
2022-07-01 14:24:40 - F1 with Cosine-Similarity:                 79.11	(Threshold: 0.4501)
2022-07-01 14:24:40 - Precision with Cosine-Similarity:          65.45
2022-07-01 14:24:40 - Recall with Cosine-Similarity:             100.00
2022-07-01 14:24:40 - Average Precision with Cosine-Similarity:  82.18

2022-07-01 14:24:40 - Accuracy with Manhatten-Distance:           72.86	(Threshold: 629.0262)
2022-07-01 14:24:40 - F1 with Manhatten-Distance:                 80.29	(Threshold: 636.7933)
2022-07-01 14:24:40 - Precision with Manhatten-Distance:          73.83
2022-07-01 14:24:40 - Recall with Manhatten-Distance:             88.00
2022-07-01 14:24:40 - Average Precision with Manhatten-Distance:  82.48

2022-07-01 14:24:40 - Accuracy with Euclidean-Distance:           70.35	(Threshold: 25.313

Iteration:   0%|          | 0/12 [00:00<?, ?it/s]

2022-07-01 14:24:42 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 7 after 10 steps:
2022-07-01 14:24:43 - Accuracy with Cosine-Similarity:           70.85	(Threshold: 0.6400)
2022-07-01 14:24:43 - F1 with Cosine-Similarity:                 79.11	(Threshold: 0.4376)
2022-07-01 14:24:43 - Precision with Cosine-Similarity:          65.45
2022-07-01 14:24:43 - Recall with Cosine-Similarity:             100.00
2022-07-01 14:24:43 - Average Precision with Cosine-Similarity:  82.14

2022-07-01 14:24:43 - Accuracy with Manhatten-Distance:           72.36	(Threshold: 639.6288)
2022-07-01 14:24:43 - F1 with Manhatten-Distance:                 80.00	(Threshold: 639.6288)
2022-07-01 14:24:43 - Precision with Manhatten-Distance:          73.33
2022-07-01 14:24:43 - Recall with Manhatten-Distance:             88.00
2022-07-01 14:24:43 - Average Precision with Manhatten-Distance:  82.37

2022-07-01 14:24:43 - Accuracy with Euclidean-Distance:           69.85	(Threshold: 25.657

Iteration:   0%|          | 0/12 [00:00<?, ?it/s]

2022-07-01 14:24:45 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 8 after 10 steps:
2022-07-01 14:24:45 - Accuracy with Cosine-Similarity:           70.85	(Threshold: 0.6408)
2022-07-01 14:24:45 - F1 with Cosine-Similarity:                 79.11	(Threshold: 0.4285)
2022-07-01 14:24:45 - Precision with Cosine-Similarity:          65.45
2022-07-01 14:24:45 - Recall with Cosine-Similarity:             100.00
2022-07-01 14:24:45 - Average Precision with Cosine-Similarity:  82.04

2022-07-01 14:24:45 - Accuracy with Manhatten-Distance:           72.36	(Threshold: 643.1239)
2022-07-01 14:24:45 - F1 with Manhatten-Distance:                 80.00	(Threshold: 643.1239)
2022-07-01 14:24:45 - Precision with Manhatten-Distance:          73.33
2022-07-01 14:24:45 - Recall with Manhatten-Distance:             88.00
2022-07-01 14:24:45 - Average Precision with Manhatten-Distance:  82.21

2022-07-01 14:24:45 - Accuracy with Euclidean-Distance:           69.35	(Threshold: 25.782

Iteration:   0%|          | 0/12 [00:00<?, ?it/s]

2022-07-01 14:24:48 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 9 after 10 steps:
2022-07-01 14:24:48 - Accuracy with Cosine-Similarity:           70.35	(Threshold: 0.6397)
2022-07-01 14:24:48 - F1 with Cosine-Similarity:                 79.11	(Threshold: 0.4230)
2022-07-01 14:24:48 - Precision with Cosine-Similarity:          65.45
2022-07-01 14:24:48 - Recall with Cosine-Similarity:             100.00
2022-07-01 14:24:48 - Average Precision with Cosine-Similarity:  81.90

2022-07-01 14:24:48 - Accuracy with Manhatten-Distance:           72.36	(Threshold: 648.6514)
2022-07-01 14:24:48 - F1 with Manhatten-Distance:                 80.00	(Threshold: 648.6514)
2022-07-01 14:24:48 - Precision with Manhatten-Distance:          73.33
2022-07-01 14:24:48 - Recall with Manhatten-Distance:             88.00
2022-07-01 14:24:48 - Average Precision with Manhatten-Distance:  81.98

2022-07-01 14:24:48 - Accuracy with Euclidean-Distance:           69.35	(Threshold: 25.982

Iteration:   0%|          | 0/12 [00:00<?, ?it/s]

2022-07-01 14:24:51 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 10 after 10 steps:
2022-07-01 14:24:51 - Accuracy with Cosine-Similarity:           70.35	(Threshold: 0.6304)
2022-07-01 14:24:51 - F1 with Cosine-Similarity:                 79.11	(Threshold: 0.4175)
2022-07-01 14:24:51 - Precision with Cosine-Similarity:          65.45
2022-07-01 14:24:51 - Recall with Cosine-Similarity:             100.00
2022-07-01 14:24:51 - Average Precision with Cosine-Similarity:  81.80

2022-07-01 14:24:51 - Accuracy with Manhatten-Distance:           72.36	(Threshold: 650.5294)
2022-07-01 14:24:51 - F1 with Manhatten-Distance:                 80.00	(Threshold: 650.5294)
2022-07-01 14:24:51 - Precision with Manhatten-Distance:          73.33
2022-07-01 14:24:51 - Recall with Manhatten-Distance:             88.00
2022-07-01 14:24:51 - Average Precision with Manhatten-Distance:  82.01

2022-07-01 14:24:51 - Accuracy with Euclidean-Distance:           68.84	(Threshold: 26.01

Iteration:   0%|          | 0/12 [00:00<?, ?it/s]

2022-07-01 14:24:54 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 11 after 10 steps:
2022-07-01 14:24:54 - Accuracy with Cosine-Similarity:           69.85	(Threshold: 0.6311)
2022-07-01 14:24:54 - F1 with Cosine-Similarity:                 79.11	(Threshold: 0.4146)
2022-07-01 14:24:54 - Precision with Cosine-Similarity:          65.45
2022-07-01 14:24:54 - Recall with Cosine-Similarity:             100.00
2022-07-01 14:24:54 - Average Precision with Cosine-Similarity:  81.77

2022-07-01 14:24:54 - Accuracy with Manhatten-Distance:           71.36	(Threshold: 650.3481)
2022-07-01 14:24:54 - F1 with Manhatten-Distance:                 79.42	(Threshold: 652.4294)
2022-07-01 14:24:54 - Precision with Manhatten-Distance:          72.37
2022-07-01 14:24:54 - Recall with Manhatten-Distance:             88.00
2022-07-01 14:24:54 - Average Precision with Manhatten-Distance:  81.83

2022-07-01 14:24:54 - Accuracy with Euclidean-Distance:           68.34	(Threshold: 25.51

Iteration:   0%|          | 0/12 [00:00<?, ?it/s]

2022-07-01 14:24:56 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 12 after 10 steps:
2022-07-01 14:24:57 - Accuracy with Cosine-Similarity:           69.85	(Threshold: 0.6294)
2022-07-01 14:24:57 - F1 with Cosine-Similarity:                 79.11	(Threshold: 0.4138)
2022-07-01 14:24:57 - Precision with Cosine-Similarity:          65.45
2022-07-01 14:24:57 - Recall with Cosine-Similarity:             100.00
2022-07-01 14:24:57 - Average Precision with Cosine-Similarity:  81.79

2022-07-01 14:24:57 - Accuracy with Manhatten-Distance:           70.85	(Threshold: 652.7652)
2022-07-01 14:24:57 - F1 with Manhatten-Distance:                 79.22	(Threshold: 746.9213)
2022-07-01 14:24:57 - Precision with Manhatten-Distance:          66.67
2022-07-01 14:24:57 - Recall with Manhatten-Distance:             97.60
2022-07-01 14:24:57 - Average Precision with Manhatten-Distance:  81.87

2022-07-01 14:24:57 - Accuracy with Euclidean-Distance:           68.34	(Threshold: 25.55

Iteration:   0%|          | 0/12 [00:00<?, ?it/s]

2022-07-01 14:24:59 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 13 after 10 steps:
2022-07-01 14:24:59 - Accuracy with Cosine-Similarity:           69.85	(Threshold: 0.6289)
2022-07-01 14:24:59 - F1 with Cosine-Similarity:                 79.11	(Threshold: 0.4121)
2022-07-01 14:24:59 - Precision with Cosine-Similarity:          65.45
2022-07-01 14:24:59 - Recall with Cosine-Similarity:             100.00
2022-07-01 14:24:59 - Average Precision with Cosine-Similarity:  81.70

2022-07-01 14:24:59 - Accuracy with Manhatten-Distance:           70.85	(Threshold: 653.0719)
2022-07-01 14:24:59 - F1 with Manhatten-Distance:                 79.22	(Threshold: 747.4995)
2022-07-01 14:24:59 - Precision with Manhatten-Distance:          66.67
2022-07-01 14:24:59 - Recall with Manhatten-Distance:             97.60
2022-07-01 14:24:59 - Average Precision with Manhatten-Distance:  81.87

2022-07-01 14:24:59 - Accuracy with Euclidean-Distance:           68.34	(Threshold: 25.59

Iteration:   0%|          | 0/12 [00:00<?, ?it/s]

2022-07-01 14:25:02 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 14 after 10 steps:
2022-07-01 14:25:02 - Accuracy with Cosine-Similarity:           69.35	(Threshold: 0.6377)
2022-07-01 14:25:02 - F1 with Cosine-Similarity:                 79.11	(Threshold: 0.4116)
2022-07-01 14:25:02 - Precision with Cosine-Similarity:          65.45
2022-07-01 14:25:02 - Recall with Cosine-Similarity:             100.00
2022-07-01 14:25:02 - Average Precision with Cosine-Similarity:  81.70

2022-07-01 14:25:02 - Accuracy with Manhatten-Distance:           70.85	(Threshold: 653.2681)
2022-07-01 14:25:02 - F1 with Manhatten-Distance:                 79.22	(Threshold: 747.7288)
2022-07-01 14:25:02 - Precision with Manhatten-Distance:          66.67
2022-07-01 14:25:02 - Recall with Manhatten-Distance:             97.60
2022-07-01 14:25:02 - Average Precision with Manhatten-Distance:  81.85

2022-07-01 14:25:02 - Accuracy with Euclidean-Distance:           68.34	(Threshold: 25.60

### Using the auto-generated conclusions:

In [12]:
# Swap the training frame for the pickled version with extra conclusions and
# reload the untouched dev CSV.
extra_conclusions_path = '../data/TaskA_train_with_extra_conclusions.pkl'
dev_csv_path = '../data/TaskA_dev.csv'

taska_training_df = pd.read_pickle(extra_conclusions_path)
taska_valid_df = pd.read_csv(dev_csv_path)

In [13]:
# Mapping labels
# Drop the rows whose Validity is 0. The explicit .copy() makes each result an
# independent frame, so the column assignments below do not write into a slice
# of the loaded data (which is what triggers pandas' SettingWithCopyWarning).
taska_training_df = taska_training_df[taska_training_df.Validity != 0].copy()
taska_valid_df    = taska_valid_df[taska_valid_df.Validity != 0].copy()

# Binary target: 1 where Validity == 1, 0 for every other remaining value.
taska_training_df['label'] = (taska_training_df.Validity == 1).astype(int)
taska_valid_df['label'] = (taska_valid_df.Validity == 1).astype(int)

# Prefix every premise with its topic ("<topic> : <premise>"). Plain column
# arithmetic replaces the row-wise apply and also works on an empty frame; a
# missing topic or premise yields NaN instead of raising.
# NOTE: Premise is rewritten in place, so re-running this cell without
# reloading the data prepends the topic a second time.
taska_training_df['Premise'] = taska_training_df['topic'] + ' : ' + taska_training_df['Premise']
taska_valid_df['Premise'] = taska_valid_df['topic'] + ' : ' + taska_valid_df['Premise']

In [19]:
# Fine-tune on the extra-conclusions training frame and evaluate on the dev frame.
# The call is the cell's last expression, so its return value is displayed below
# the training log.
# evaluation_steps=10: the log below reports a dev evaluation "after 10 steps" in
# every epoch.
# NOTE(review): train_model is not defined in this part of the notebook; confirm
# there what sentence_transformer=False and special_tokens=[] select.
train_model(taska_training_df, taska_valid_df, '../data/output/', 
            'sentence-transformers/nli-roberta-large', 
            num_epochs=15, train_batch_size=32,
            model_suffix='ranking-loss-extra-data', max_seq_length=512, special_tokens=[], 
            loss='MultipleNegativesRankingLoss', sentence_transformer=False, evaluation_steps=10)

2022-07-01 14:27:28 - Use pytorch device: cuda
2022-07-01 14:27:28 - Read Triplet train dataset
Len of training: 401
Len of Dev: 199
Evaluating before start learning.....
2022-07-01 14:27:28 - Binary Accuracy Evaluation of the model on sts-dev dataset:
2022-07-01 14:27:28 - Accuracy with Cosine-Similarity:           72.36	(Threshold: 0.6503)
2022-07-01 14:27:28 - F1 with Cosine-Similarity:                 80.69	(Threshold: 0.5985)
2022-07-01 14:27:28 - Precision with Cosine-Similarity:          70.91
2022-07-01 14:27:28 - Recall with Cosine-Similarity:             93.60
2022-07-01 14:27:28 - Average Precision with Cosine-Similarity:  82.54

2022-07-01 14:27:28 - Accuracy with Manhatten-Distance:           72.36	(Threshold: 646.3732)
2022-07-01 14:27:28 - F1 with Manhatten-Distance:                 80.56	(Threshold: 679.7071)
2022-07-01 14:27:28 - Precision with Manhatten-Distance:          71.17
2022-07-01 14:27:28 - Recall with Manhatten-Distance:             92.80
2022-07-01 14:27:28

Epoch:   0%|          | 0/15 [00:00<?, ?it/s]

Iteration:   0%|          | 0/12 [00:00<?, ?it/s]

2022-07-01 14:27:30 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 0 after 10 steps:
2022-07-01 14:27:30 - Accuracy with Cosine-Similarity:           73.37	(Threshold: 0.6399)
2022-07-01 14:27:30 - F1 with Cosine-Similarity:                 80.59	(Threshold: 0.6399)
2022-07-01 14:27:30 - Precision with Cosine-Similarity:          74.32
2022-07-01 14:27:30 - Recall with Cosine-Similarity:             88.00
2022-07-01 14:27:30 - Average Precision with Cosine-Similarity:  83.01

2022-07-01 14:27:30 - Accuracy with Manhatten-Distance:           72.86	(Threshold: 634.2710)
2022-07-01 14:27:30 - F1 with Manhatten-Distance:                 80.57	(Threshold: 662.9831)
2022-07-01 14:27:30 - Precision with Manhatten-Distance:          72.15
2022-07-01 14:27:30 - Recall with Manhatten-Distance:             91.20
2022-07-01 14:27:30 - Average Precision with Manhatten-Distance:  82.94

2022-07-01 14:27:30 - Accuracy with Euclidean-Distance:           72.86	(Threshold: 25.3020

Iteration:   0%|          | 0/12 [00:00<?, ?it/s]

2022-07-01 14:27:34 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 1 after 10 steps:
2022-07-01 14:27:35 - Accuracy with Cosine-Similarity:           72.86	(Threshold: 0.6461)
2022-07-01 14:27:35 - F1 with Cosine-Similarity:                 80.29	(Threshold: 0.6383)
2022-07-01 14:27:35 - Precision with Cosine-Similarity:          73.83
2022-07-01 14:27:35 - Recall with Cosine-Similarity:             88.00
2022-07-01 14:27:35 - Average Precision with Cosine-Similarity:  83.42

2022-07-01 14:27:35 - Accuracy with Manhatten-Distance:           75.38	(Threshold: 621.4330)
2022-07-01 14:27:35 - F1 with Manhatten-Distance:                 81.51	(Threshold: 621.4330)
2022-07-01 14:27:35 - Precision with Manhatten-Distance:          77.14
2022-07-01 14:27:35 - Recall with Manhatten-Distance:             86.40
2022-07-01 14:27:35 - Average Precision with Manhatten-Distance:  83.58

2022-07-01 14:27:35 - Accuracy with Euclidean-Distance:           72.86	(Threshold: 24.2606

Iteration:   0%|          | 0/12 [00:00<?, ?it/s]

2022-07-01 14:27:42 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 2 after 10 steps:
2022-07-01 14:27:42 - Accuracy with Cosine-Similarity:           73.37	(Threshold: 0.6747)
2022-07-01 14:27:42 - F1 with Cosine-Similarity:                 80.14	(Threshold: 0.6455)
2022-07-01 14:27:42 - Precision with Cosine-Similarity:          73.03
2022-07-01 14:27:42 - Recall with Cosine-Similarity:             88.80
2022-07-01 14:27:42 - Average Precision with Cosine-Similarity:  83.85

2022-07-01 14:27:42 - Accuracy with Manhatten-Distance:           75.38	(Threshold: 616.5479)
2022-07-01 14:27:42 - F1 with Manhatten-Distance:                 81.51	(Threshold: 616.5479)
2022-07-01 14:27:42 - Precision with Manhatten-Distance:          77.14
2022-07-01 14:27:42 - Recall with Manhatten-Distance:             86.40
2022-07-01 14:27:42 - Average Precision with Manhatten-Distance:  83.82

2022-07-01 14:27:42 - Accuracy with Euclidean-Distance:           73.37	(Threshold: 24.1235

Iteration:   0%|          | 0/12 [00:00<?, ?it/s]

2022-07-01 14:27:48 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 3 after 10 steps:
2022-07-01 14:27:48 - Accuracy with Cosine-Similarity:           72.86	(Threshold: 0.6782)
2022-07-01 14:27:48 - F1 with Cosine-Similarity:                 79.43	(Threshold: 0.6391)
2022-07-01 14:27:48 - Precision with Cosine-Similarity:          71.34
2022-07-01 14:27:48 - Recall with Cosine-Similarity:             89.60
2022-07-01 14:27:48 - Average Precision with Cosine-Similarity:  83.58

2022-07-01 14:27:48 - Accuracy with Manhatten-Distance:           75.38	(Threshold: 613.3158)
2022-07-01 14:27:48 - F1 with Manhatten-Distance:                 81.51	(Threshold: 613.3158)
2022-07-01 14:27:48 - Precision with Manhatten-Distance:          77.14
2022-07-01 14:27:48 - Recall with Manhatten-Distance:             86.40
2022-07-01 14:27:48 - Average Precision with Manhatten-Distance:  83.66

2022-07-01 14:27:48 - Accuracy with Euclidean-Distance:           72.36	(Threshold: 23.9726

Iteration:   0%|          | 0/12 [00:00<?, ?it/s]

2022-07-01 14:27:51 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 4 after 10 steps:
2022-07-01 14:27:51 - Accuracy with Cosine-Similarity:           71.36	(Threshold: 0.6786)
2022-07-01 14:27:51 - F1 with Cosine-Similarity:                 79.11	(Threshold: 0.4746)
2022-07-01 14:27:51 - Precision with Cosine-Similarity:          65.45
2022-07-01 14:27:51 - Recall with Cosine-Similarity:             100.00
2022-07-01 14:27:51 - Average Precision with Cosine-Similarity:  83.13

2022-07-01 14:27:51 - Accuracy with Manhatten-Distance:           73.87	(Threshold: 617.1617)
2022-07-01 14:27:51 - F1 with Manhatten-Distance:                 80.60	(Threshold: 617.1617)
2022-07-01 14:27:51 - Precision with Manhatten-Distance:          75.52
2022-07-01 14:27:51 - Recall with Manhatten-Distance:             86.40
2022-07-01 14:27:51 - Average Precision with Manhatten-Distance:  83.24

2022-07-01 14:27:51 - Accuracy with Euclidean-Distance:           71.36	(Threshold: 24.543

Iteration:   0%|          | 0/12 [00:00<?, ?it/s]

2022-07-01 14:27:53 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 5 after 10 steps:
2022-07-01 14:27:54 - Accuracy with Cosine-Similarity:           70.85	(Threshold: 0.6585)
2022-07-01 14:27:54 - F1 with Cosine-Similarity:                 79.37	(Threshold: 0.4830)
2022-07-01 14:27:54 - Precision with Cosine-Similarity:          65.79
2022-07-01 14:27:54 - Recall with Cosine-Similarity:             100.00
2022-07-01 14:27:54 - Average Precision with Cosine-Similarity:  82.74

2022-07-01 14:27:54 - Accuracy with Manhatten-Distance:           72.86	(Threshold: 621.9987)
2022-07-01 14:27:54 - F1 with Manhatten-Distance:                 79.85	(Threshold: 631.2996)
2022-07-01 14:27:54 - Precision with Manhatten-Distance:          73.65
2022-07-01 14:27:54 - Recall with Manhatten-Distance:             87.20
2022-07-01 14:27:54 - Average Precision with Manhatten-Distance:  82.67

2022-07-01 14:27:54 - Accuracy with Euclidean-Distance:           70.85	(Threshold: 25.033

Iteration:   0%|          | 0/12 [00:00<?, ?it/s]

2022-07-01 14:27:56 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 6 after 10 steps:
2022-07-01 14:27:56 - Accuracy with Cosine-Similarity:           71.36	(Threshold: 0.6509)
2022-07-01 14:27:56 - F1 with Cosine-Similarity:                 79.37	(Threshold: 0.4705)
2022-07-01 14:27:56 - Precision with Cosine-Similarity:          65.79
2022-07-01 14:27:56 - Recall with Cosine-Similarity:             100.00
2022-07-01 14:27:56 - Average Precision with Cosine-Similarity:  82.26

2022-07-01 14:27:56 - Accuracy with Manhatten-Distance:           72.86	(Threshold: 635.5295)
2022-07-01 14:27:56 - F1 with Manhatten-Distance:                 80.15	(Threshold: 635.5295)
2022-07-01 14:27:56 - Precision with Manhatten-Distance:          74.15
2022-07-01 14:27:56 - Recall with Manhatten-Distance:             87.20
2022-07-01 14:27:56 - Average Precision with Manhatten-Distance:  82.45

2022-07-01 14:27:56 - Accuracy with Euclidean-Distance:           70.85	(Threshold: 25.255

Iteration:   0%|          | 0/12 [00:00<?, ?it/s]

2022-07-01 14:27:59 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 7 after 10 steps:
2022-07-01 14:27:59 - Accuracy with Cosine-Similarity:           70.85	(Threshold: 0.6468)
2022-07-01 14:27:59 - F1 with Cosine-Similarity:                 79.37	(Threshold: 0.4598)
2022-07-01 14:27:59 - Precision with Cosine-Similarity:          65.79
2022-07-01 14:27:59 - Recall with Cosine-Similarity:             100.00
2022-07-01 14:27:59 - Average Precision with Cosine-Similarity:  82.07

2022-07-01 14:27:59 - Accuracy with Manhatten-Distance:           71.36	(Threshold: 643.0668)
2022-07-01 14:27:59 - F1 with Manhatten-Distance:                 79.42	(Threshold: 643.0668)
2022-07-01 14:27:59 - Precision with Manhatten-Distance:          72.37
2022-07-01 14:27:59 - Recall with Manhatten-Distance:             88.00
2022-07-01 14:27:59 - Average Precision with Manhatten-Distance:  82.28

2022-07-01 14:27:59 - Accuracy with Euclidean-Distance:           69.35	(Threshold: 25.577

Iteration:   0%|          | 0/12 [00:00<?, ?it/s]

2022-07-01 14:28:02 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 8 after 10 steps:
2022-07-01 14:28:02 - Accuracy with Cosine-Similarity:           70.35	(Threshold: 0.6418)
2022-07-01 14:28:02 - F1 with Cosine-Similarity:                 79.37	(Threshold: 0.4509)
2022-07-01 14:28:02 - Precision with Cosine-Similarity:          65.79
2022-07-01 14:28:02 - Recall with Cosine-Similarity:             100.00
2022-07-01 14:28:02 - Average Precision with Cosine-Similarity:  81.91

2022-07-01 14:28:02 - Accuracy with Manhatten-Distance:           70.85	(Threshold: 649.6815)
2022-07-01 14:28:02 - F1 with Manhatten-Distance:                 79.14	(Threshold: 649.6815)
2022-07-01 14:28:02 - Precision with Manhatten-Distance:          71.90
2022-07-01 14:28:02 - Recall with Manhatten-Distance:             88.00
2022-07-01 14:28:02 - Average Precision with Manhatten-Distance:  82.18

2022-07-01 14:28:02 - Accuracy with Euclidean-Distance:           68.84	(Threshold: 25.519

Iteration:   0%|          | 0/12 [00:00<?, ?it/s]

2022-07-01 14:28:05 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 9 after 10 steps:
2022-07-01 14:28:05 - Accuracy with Cosine-Similarity:           69.85	(Threshold: 0.6397)
2022-07-01 14:28:05 - F1 with Cosine-Similarity:                 79.37	(Threshold: 0.4480)
2022-07-01 14:28:05 - Precision with Cosine-Similarity:          65.79
2022-07-01 14:28:05 - Recall with Cosine-Similarity:             100.00
2022-07-01 14:28:05 - Average Precision with Cosine-Similarity:  81.81

2022-07-01 14:28:05 - Accuracy with Manhatten-Distance:           71.36	(Threshold: 649.6365)
2022-07-01 14:28:05 - F1 with Manhatten-Distance:                 79.42	(Threshold: 649.6365)
2022-07-01 14:28:05 - Precision with Manhatten-Distance:          72.37
2022-07-01 14:28:05 - Recall with Manhatten-Distance:             88.00
2022-07-01 14:28:05 - Average Precision with Manhatten-Distance:  82.17

2022-07-01 14:28:05 - Accuracy with Euclidean-Distance:           68.34	(Threshold: 25.428

Iteration:   0%|          | 0/12 [00:00<?, ?it/s]

2022-07-01 14:28:07 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 10 after 10 steps:
2022-07-01 14:28:08 - Accuracy with Cosine-Similarity:           69.35	(Threshold: 0.6399)
2022-07-01 14:28:08 - F1 with Cosine-Similarity:                 79.37	(Threshold: 0.4442)
2022-07-01 14:28:08 - Precision with Cosine-Similarity:          65.79
2022-07-01 14:28:08 - Recall with Cosine-Similarity:             100.00
2022-07-01 14:28:08 - Average Precision with Cosine-Similarity:  81.66

2022-07-01 14:28:08 - Accuracy with Manhatten-Distance:           71.36	(Threshold: 654.9318)
2022-07-01 14:28:08 - F1 with Manhatten-Distance:                 79.42	(Threshold: 654.9318)
2022-07-01 14:28:08 - Precision with Manhatten-Distance:          72.37
2022-07-01 14:28:08 - Recall with Manhatten-Distance:             88.00
2022-07-01 14:28:08 - Average Precision with Manhatten-Distance:  82.20

2022-07-01 14:28:08 - Accuracy with Euclidean-Distance:           68.84	(Threshold: 25.45

Iteration:   0%|          | 0/12 [00:00<?, ?it/s]

2022-07-01 14:28:10 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 11 after 10 steps:
2022-07-01 14:28:10 - Accuracy with Cosine-Similarity:           68.84	(Threshold: 0.6449)
2022-07-01 14:28:10 - F1 with Cosine-Similarity:                 79.37	(Threshold: 0.4417)
2022-07-01 14:28:10 - Precision with Cosine-Similarity:          65.79
2022-07-01 14:28:10 - Recall with Cosine-Similarity:             100.00
2022-07-01 14:28:10 - Average Precision with Cosine-Similarity:  81.66

2022-07-01 14:28:10 - Accuracy with Manhatten-Distance:           71.36	(Threshold: 656.2858)
2022-07-01 14:28:10 - F1 with Manhatten-Distance:                 79.42	(Threshold: 656.2858)
2022-07-01 14:28:10 - Precision with Manhatten-Distance:          72.37
2022-07-01 14:28:10 - Recall with Manhatten-Distance:             88.00
2022-07-01 14:28:10 - Average Precision with Manhatten-Distance:  82.14

2022-07-01 14:28:10 - Accuracy with Euclidean-Distance:           68.84	(Threshold: 25.57

Iteration:   0%|          | 0/12 [00:00<?, ?it/s]

2022-07-01 14:28:13 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 12 after 10 steps:
2022-07-01 14:28:13 - Accuracy with Cosine-Similarity:           68.84	(Threshold: 0.6408)
2022-07-01 14:28:13 - F1 with Cosine-Similarity:                 79.37	(Threshold: 0.4405)
2022-07-01 14:28:13 - Precision with Cosine-Similarity:          65.79
2022-07-01 14:28:13 - Recall with Cosine-Similarity:             100.00
2022-07-01 14:28:13 - Average Precision with Cosine-Similarity:  81.68

2022-07-01 14:28:13 - Accuracy with Manhatten-Distance:           71.36	(Threshold: 655.5350)
2022-07-01 14:28:13 - F1 with Manhatten-Distance:                 79.42	(Threshold: 655.8931)
2022-07-01 14:28:13 - Precision with Manhatten-Distance:          72.37
2022-07-01 14:28:13 - Recall with Manhatten-Distance:             88.00
2022-07-01 14:28:13 - Average Precision with Manhatten-Distance:  82.17

2022-07-01 14:28:13 - Accuracy with Euclidean-Distance:           68.84	(Threshold: 25.57

Iteration:   0%|          | 0/12 [00:00<?, ?it/s]

2022-07-01 14:28:16 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 13 after 10 steps:
2022-07-01 14:28:16 - Accuracy with Cosine-Similarity:           68.34	(Threshold: 0.6423)
2022-07-01 14:28:16 - F1 with Cosine-Similarity:                 79.37	(Threshold: 0.4397)
2022-07-01 14:28:16 - Precision with Cosine-Similarity:          65.79
2022-07-01 14:28:16 - Recall with Cosine-Similarity:             100.00
2022-07-01 14:28:16 - Average Precision with Cosine-Similarity:  81.57

2022-07-01 14:28:16 - Accuracy with Manhatten-Distance:           70.85	(Threshold: 645.9094)
2022-07-01 14:28:16 - F1 with Manhatten-Distance:                 79.22	(Threshold: 753.9757)
2022-07-01 14:28:16 - Precision with Manhatten-Distance:          66.67
2022-07-01 14:28:16 - Recall with Manhatten-Distance:             97.60
2022-07-01 14:28:16 - Average Precision with Manhatten-Distance:  82.10

2022-07-01 14:28:16 - Accuracy with Euclidean-Distance:           68.84	(Threshold: 25.60

Iteration:   0%|          | 0/12 [00:00<?, ?it/s]

2022-07-01 14:28:19 - Binary Accuracy Evaluation of the model on sts-dev dataset in epoch 14 after 10 steps:
2022-07-01 14:28:19 - Accuracy with Cosine-Similarity:           68.34	(Threshold: 0.6422)
2022-07-01 14:28:19 - F1 with Cosine-Similarity:                 79.37	(Threshold: 0.4395)
2022-07-01 14:28:19 - Precision with Cosine-Similarity:          65.79
2022-07-01 14:28:19 - Recall with Cosine-Similarity:             100.00
2022-07-01 14:28:19 - Average Precision with Cosine-Similarity:  81.55

2022-07-01 14:28:19 - Accuracy with Manhatten-Distance:           70.85	(Threshold: 646.4954)
2022-07-01 14:28:19 - F1 with Manhatten-Distance:                 79.22	(Threshold: 754.0268)
2022-07-01 14:28:19 - Precision with Manhatten-Distance:          66.67
2022-07-01 14:28:19 - Recall with Manhatten-Distance:             97.60
2022-07-01 14:28:19 - Average Precision with Manhatten-Distance:  82.11

2022-07-01 14:28:19 - Accuracy with Euclidean-Distance:           68.84	(Threshold: 25.62

(SentenceTransformer(
   (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: RobertaModel 
   (1): Pooling({'word_embedding_dimension': 1024, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
 ),
 <sentence_transformers.evaluation.BinaryClassificationEvaluator.BinaryClassificationEvaluator at 0x7ff178ede910>)

#### Generating predictions for argument validity:

- Load the best model and use it to predict validity labels for the dev set (`eval_df`).
- The best model is `nli-roberta-large-ranking-loss-extra-data-2022-07-01_14-27-18`.

In [43]:
# Reload the raw dev set for prediction and drop the rows with Validity == 0.
# .copy() yields an independent frame: predict_labels later adds a 'pred_labels'
# column to it, which would otherwise be an assignment on a filtered slice
# (SettingWithCopyWarning).
# NOTE(review): unlike the training cell, Premise is not prefixed with
# "<topic> : " here — confirm that predicting on the raw premise is intended.
taska_valid_df = pd.read_csv('../data/TaskA_dev.csv')
taska_valid_df = taska_valid_df[taska_valid_df.Validity != 0].copy()

In [47]:
def predict_labels(df, model_path, clm1, clm2, threshold):
    """Predict a binary label for every row of ``df`` from embedding similarity.

    Loads the SentenceTransformer stored at ``model_path``, encodes the texts in
    columns ``clm1`` and ``clm2``, and compares the two embeddings of each row
    with cosine similarity. A row is labelled 1 when its similarity is strictly
    greater than ``threshold`` and -1 otherwise.

    Args:
        df: DataFrame holding the two text columns; it is modified in place.
        model_path: Path or name handed to ``SentenceTransformer``.
        clm1: Name of the first text column ('Premise' in this notebook).
        clm2: Name of the second text column ('Conclusion' in this notebook).
        threshold: Cosine-similarity cut-off for the positive label.

    Returns:
        The same ``df`` object, with a new ``pred_labels`` column.
    """
    best_model = SentenceTransformer(model_path)

    # Encode the columns of the frame that was passed in. The previous version
    # read the notebook-global taska_valid_df here and silently ignored ``df``.
    first_embeddings = best_model.encode(df[clm1].tolist())
    second_embeddings = best_model.encode(df[clm2].tolist())

    # One similarity per row: pair the i-th clm2 embedding with the i-th clm1 one.
    scores = [
        util.pytorch_cos_sim(second_vec, first_vec).item()
        for second_vec, first_vec in zip(second_embeddings, first_embeddings)
    ]
    labels = [1 if score > threshold else -1 for score in scores]

    df['pred_labels'] = labels

    return df

In [58]:
# Score the dev frame with the checkpoint saved by the 'ranking-loss-extra-data' run.
# NOTE(review): where the 0.6146 cut-off comes from is not visible in this
# section of the notebook — confirm before reusing it.
eval_df = predict_labels(
    taska_valid_df,
    model_path='../data/output/sentence-transformers/nli-roberta-large-ranking-loss-extra-data-2022-07-01_14-27-18/',
    clm1='Premise',
    clm2='Conclusion',
    threshold=0.6146,
)
# Alternative checkpoint, kept for reference:
#eval_df = predict_labels(taska_valid_df, '../data/output/sentence-transformers/nli-roberta-large-ranking-loss-2022-07-01_14-24-04', 'Premise', 'Conclusion', 0.6146)

2022-07-01 15:09:48 - Load pretrained SentenceTransformer: ../data/output/sentence-transformers/nli-roberta-large-ranking-loss-extra-data-2022-07-01_14-27-18/
2022-07-01 15:09:50 - Use pytorch device: cuda


Batches:   0%|          | 0/7 [00:00<?, ?it/s]

Batches:   0%|          | 0/7 [00:00<?, ?it/s]

In [59]:
# Binary precision / recall / F1 of the predicted labels against the gold
# Validity column; the fourth return value is discarded.
precision, recall, f1, _ = precision_recall_fscore_support(
    eval_df['Validity'].tolist(),
    eval_df['pred_labels'].tolist(),
    average='binary',
)

print(f'Precision: {precision}, Recall {recall}, F1: {f1}')

Precision: 0.7, Recall 0.896, F1: 0.7859649122807016
