Other similarity measures: cosine (Cos), overlap (Ovl), Jaccard (Jacc), Dice
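A minimal sketch of these four measures. It assumes they are the usual set-based coefficients computed over the token sets of the two sentences (t and h). The whitespace tokenizer and the function names are illustrative assumptions, not part of this notebook.

```python
import math


def token_set(sentence):
    # Hypothetical tokenizer: lower-case and split on whitespace
    return set(sentence.lower().split())


def set_similarities(s1, s2):
    """Set-based similarities between the token sets of two sentences."""
    a, b = token_set(s1), token_set(s2)
    if not a or not b:
        return {'cos': 0.0, 'ovl': 0.0, 'jacc': 0.0, 'dice': 0.0}
    inter = len(a & b)
    return {
        'cos': inter / math.sqrt(len(a) * len(b)),  # set cosine (Ochiai)
        'ovl': inter / min(len(a), len(b)),         # overlap coefficient
        'jacc': inter / len(a | b),                 # Jaccard index
        'dice': 2 * inter / (len(a) + len(b)),      # Dice coefficient
    }


print(set_similarities('o gato dorme', 'o gato come'))
```

Each measure divides the shared-token count by a different normalizer, so overlap is the most lenient when sentence lengths differ and Jaccard the strictest.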

Semantic Similarity

These models find semantically similar sentences within one language or across languages:

**distiluse-base-multilingual-cased-v2**: Multilingual knowledge-distilled version of the multilingual Universal Sentence Encoder (mUSE). While the original mUSE model supports only 16 languages, this distilled version supports 50+ languages.

**xlm-r-distilroberta-base-paraphrase-v1**: Multilingual version of distilroberta-base-paraphrase-v1, trained on parallel data for 50+ languages.

**xlm-r-bert-base-nli-stsb-mean-tokens**: Produces embeddings similar to those of the bert-base-nli-stsb-mean-tokens model. Trained on parallel data for 50+ languages.

**distilbert-multilingual-nli-stsb-quora-ranking**: Multilingual version of distilbert-base-nli-stsb-quora-ranking, fine-tuned with parallel data for 50+ languages.

**T-Systems-onsite/cross-en-de-roberta-sentence-transformer**: Multilingual model for English and German.

# Required imports and methods

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

%matplotlib inline

In [2]:
import numpy as np
import pandas as pd
import xml.etree.ElementTree as et

from scipy.stats import pearsonr

def parse_xml(xml_file):
    """Parse an ASSIN / ASSIN 2 XML file into a pandas DataFrame."""
    xroot = et.parse(xml_file).getroot()

    rows = []
    # Each child of the root is one sentence pair: id and similarity are
    # attributes, and the two sentences are the <t> and <h> sub-elements
    for node in xroot:
        rows.append({
            "id": node.attrib.get("id"),
            "t": node.find("t").text,
            "h": node.find("h").text,
            "similarity": node.attrib.get("similarity"),
        })

    df = pd.DataFrame(rows, columns=['id', 't', 'h', 'similarity'])
    # Cast only the numeric columns to float; t and h stay as text
    df[['id', 'similarity']] = df[['id', 'similarity']].astype(float)
    return df

def eval_similarity(pairs_gold, pairs_sys):
    '''
    Evaluate the system's semantic similarity scores against the gold scores.
    Prints Pearson correlation and mean squared error to stdout.
    '''
    gold_values = np.array(pairs_gold)
    sys_values = np.array(pairs_sys)
    pearson = pearsonr(gold_values, sys_values)[0]
    mse = ((gold_values - sys_values) ** 2).mean()

    print()
    print('Similarity evaluation')
    print('Pearson\t\tMean Squared Error')
    print('-------\t\t------------------')
    print('{:7.3f}\t\t{:18.2f}'.format(pearson, mse))
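The scores printed by eval_similarity and test_evaluation are Pearson correlation, mean squared error and Spearman correlation. A NumPy-only sketch of the three metrics on toy scores follows; the helper names are hypothetical, and the Spearman helper assumes there are no tied values (scipy.stats.spearmanr, used in this notebook, assigns average ranks to ties).

```python
import numpy as np


def pearson(x, y):
    # Dot product of the centred vectors divided by the product of their norms
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    xc, yc = x - x.mean(), y - y.mean()
    return float(xc @ yc / np.sqrt((xc @ xc) * (yc @ yc)))


def mse(gold, pred):
    # Mean of the squared differences between gold and predicted scores
    diff = np.asarray(gold, dtype=float) - np.asarray(pred, dtype=float)
    return float(np.mean(diff ** 2))


def ranks(v):
    # 0-based rank of each element; only valid when there are no ties
    return np.argsort(np.argsort(v))


def spearman_no_ties(x, y):
    # Spearman correlation = Pearson correlation of the ranks
    return pearson(ranks(x), ranks(y))


gold = [1.0, 2.0, 3.0, 4.0, 5.0]
pred = [1.5, 2.0, 2.5, 4.5, 5.0]
print(pearson(gold, pred), mse(gold, pred), spearman_no_ties(gold, pred))
```

Here the predictions preserve the gold ordering, so Spearman is exactly 1 while Pearson is below 1 and the MSE is non-zero.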

# Loading the data

In [3]:
!ls ../data/assin

assin-ptbr-dev.xml   assin-ptbr-train.xml  assin-ptpt-test.xml
assin-ptbr-test.xml  assin-ptpt-dev.xml    assin-ptpt-train.xml


In [4]:
!ls ../data/assin2

assin2-blind-test.xml  assin2-dev.xml  assin2-test.xml	assin2-train-only.xml


In [5]:
df_ptbr_train = parse_xml('../data/assin/assin-ptbr-train.xml')
df_ptbr_dev = parse_xml('../data/assin/assin-ptbr-dev.xml')
df_ptbr_test = parse_xml('../data/assin/assin-ptbr-test.xml')

df_ptpt_train = parse_xml('../data/assin/assin-ptpt-train.xml')
df_ptpt_dev = parse_xml('../data/assin/assin-ptpt-dev.xml')
df_ptpt_test = parse_xml('../data/assin/assin-ptpt-test.xml')

df_ptbr2_train = parse_xml('../data/assin2/assin2-train-only.xml')
df_ptbr2_dev = parse_xml('../data/assin2/assin2-dev.xml')
df_ptbr2_test = parse_xml('../data/assin2/assin2-test.xml')
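parse_xml expects each child of the XML root to carry `id` and `similarity` attributes and to contain `<t>` and `<h>` elements. The sketch below applies the same ElementTree logic to a tiny in-memory document. The root and pair element names and the two sentence pairs are made up for illustration, and et.fromstring stands in for et.parse because the input is a string.

```python
import xml.etree.ElementTree as et

import pandas as pd

# Hypothetical two-pair document in the layout parse_xml expects
XML = """<corpus>
  <pair id="1" similarity="4.5"><t>O gato dorme.</t><h>O gato está dormindo.</h></pair>
  <pair id="2" similarity="1.0"><t>Choveu ontem.</t><h>O mercado abriu cedo.</h></pair>
</corpus>"""

root = et.fromstring(XML)
rows = [{'id': node.attrib.get('id'),
         't': node.find('t').text,
         'h': node.find('h').text,
         'similarity': node.attrib.get('similarity')}
        for node in root]
df = pd.DataFrame(rows, columns=['id', 't', 'h', 'similarity'])
# XML attributes are read as strings, so cast the numeric columns explicitly
df[['id', 'similarity']] = df[['id', 'similarity']].astype(float)
print(df)
```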

## Generate embeddings and save them for scikit-learn

In [6]:
from sentence_transformers import models, SentenceTransformer

model = SentenceTransformer('sbert-test', device='cpu')

In [20]:
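# t_embeddings is not defined at top level in this notebook; it comes from an earlier, out-of-order run
t_embeddings.shape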

(100, 1024)

In [22]:
np.concatenate((t_embeddings, t_embeddings), axis=1).shape

(100, 2048)

In [23]:
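def save_embeddings(df, filename):
    """Encode the t and h sentences, concatenate them per pair and save to <filename>.npy."""
    # encode() expects a str or a list of str, so pass plain lists rather than Series
    t_embeddings = model.encode(df['t'].tolist(), show_progress_bar=True)
    h_embeddings = model.encode(df['h'].tolist(), show_progress_bar=True)
    # One row per pair: [t embedding | h embedding]
    embeddings = np.concatenate((t_embeddings, h_embeddings), axis=1)
    np.save(filename + '.npy', embeddings)
    return embeddings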
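save_embeddings writes one row per pair, with the t embedding followed by the h embedding (hence 2 × 1024 = 2048 columns). A minimal sketch of reading such a file back and splitting the two halves, with random arrays standing in for model output and a temporary directory as a placeholder path:

```python
import os
import tempfile

import numpy as np

dim = 1024  # per-sentence embedding size, as reported by t_embeddings.shape
rng = np.random.default_rng(0)
# Random stand-ins for model.encode(...) output on 5 pairs
t_emb = rng.normal(size=(5, dim)).astype(np.float32)
h_emb = rng.normal(size=(5, dim)).astype(np.float32)

# Same layout as save_embeddings: one row per pair, [t embedding | h embedding]
path = os.path.join(tempfile.mkdtemp(), 'emb_example.npy')
np.save(path, np.concatenate((t_emb, h_emb), axis=1))

features = np.load(path)                          # shape (n_pairs, 2 * dim)
t_back, h_back = features[:, :dim], features[:, dim:]
print(features.shape)
```

The loaded (n_pairs, 2048) matrix is the kind of feature matrix a scikit-learn regressor takes as X, with the gold similarity column as y.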

In [24]:
emb = save_embeddings(df_ptbr_train, 'emb_assin-pt-br_train')
print(emb.shape)
emb = save_embeddings(df_ptbr_dev, 'emb_assin-pt-br_dev')
print(emb.shape)
emb = save_embeddings(df_ptbr_test, 'emb_assin-pt-br_test')
print(emb.shape)

(2500, 2048)
(500, 2048)
(2000, 2048)


In [25]:
emb = save_embeddings(df_ptpt_train, 'emb_assin-pt-pt_train')
print(emb.shape)
emb = save_embeddings(df_ptpt_dev, 'emb_assin-pt-pt_dev')
print(emb.shape)
emb = save_embeddings(df_ptpt_test, 'emb_assin-pt-pt_test')
print(emb.shape)

(2500, 2048)
(500, 2048)
(2000, 2048)


In [26]:
emb = save_embeddings(df_ptbr2_train, 'emb_assin2-pt-br_train')
print(emb.shape)
emb = save_embeddings(df_ptbr2_dev, 'emb_assin2-pt-br_dev')
print(emb.shape)
emb = save_embeddings(df_ptbr2_test, 'emb_assin2-pt-br_test')
print(emb.shape)

(6500, 2048)
(500, 2048)
(2448, 2048)


In [29]:
emb[0].shape

(2048,)

## Concat

In [30]:
# Combine the three corpora (ASSIN pt-BR, ASSIN pt-PT, ASSIN 2) for training and validation
df_ptbr_train = pd.concat([df_ptbr_train, df_ptpt_train, df_ptbr2_train], ignore_index=True)
df_ptbr_dev = pd.concat([df_ptbr_dev, df_ptpt_dev, df_ptbr2_dev], ignore_index=True)

In [31]:
print(f'assin-ptbr-train: {df_ptbr_train.shape}')
print(f'assin-ptbr-dev: {df_ptbr_dev.shape}')
print()
print(f'assin-ptbr-test: {df_ptbr_test.shape}')
print(f'assin-ptpt-test: {df_ptpt_test.shape}')
print(f'assin-ptbr2-test: {df_ptbr2_test.shape}')

assin-ptbr-train: (11500, 4)
assin-ptbr-dev: (1500, 4)

assin-ptbr-test: (2000, 4)
assin-ptpt-test: (2000, 4)
assin-ptbr2-test: (2448, 4)


In [8]:
df_ptbr_train.head()

| | id | t | h | similarity |
|---|---|---|---|---|
| 0 | 1.0 | A gente faz o aporte financeiro, é como se a e... | Fernando Moraes afirma que não tem vínculo com... | 2.0 |
| 1 | 2.0 | Em 2013, a história de como Walt Disney conven... | P.L.Travers era completamente contra a adaptaç... | 2.25 |
| 2 | 3.0 | David Silva bateu escanteio, Kompany escalou a... | David Silva cobrou escanteio, o zagueiro se ap... | 3.75 |
| 3 | 4.0 | Para os ambientalistas, as metas anunciadas pe... | Dilma aproveitou seu discurso ontem, na Confer... | 2.75 |
| 4 | 5.0 | De acordo com a PM, por volta das 10h30 havia ... | O protesto encerrou por volta de 12h15 (horári... | 2.0 |


# Tests

## Fine-tuning Sentence-BERT

https://www.sbert.net/docs/training/overview.html

https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/sts/training_stsbenchmark_continue_training.py
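The fine-tuning cells below train with losses.CosineSimilarityLoss on labels `similarity / 5`. Assuming the library's default configuration, this loss is the mean squared error between the cosine similarity of the two sentence embeddings and the label. A NumPy sketch of that forward computation for one toy batch (no gradients; the function name is hypothetical):

```python
import numpy as np


def cosine_similarity_loss(u, v, labels):
    """Mean squared error between the row-wise cosine(u, v) and the target labels."""
    cos = np.sum(u * v, axis=1) / (np.linalg.norm(u, axis=1) * np.linalg.norm(v, axis=1))
    return float(np.mean((cos - labels) ** 2))


u = np.array([[1.0, 0.0], [1.0, 1.0]])   # embeddings of the t sentences
v = np.array([[1.0, 0.0], [1.0, -1.0]])  # embeddings of the h sentences
gold = np.array([5.0, 2.5])              # ASSIN-style similarity scores
print(cosine_similarity_loss(u, v, gold / 5))  # labels rescaled as in the training cells
```

The first pair already matches its label (cosine 1 vs. 1.0), so the whole loss comes from the second pair (cosine 0 vs. 0.5).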

In [34]:
from sklearn.metrics.pairwise import cosine_similarity
from scipy.stats import spearmanr

def test_evaluation(df):
    """Score every (t, h) pair with the current global `model` and compare with the gold scores."""
    t_embeddings = model.encode(df['t'].tolist())
    h_embeddings = model.encode(df['h'].tolist())

    # Cosine similarity of each pair, multiplied by 5 to approach the gold score scale
    similarities = [5.0 * cosine_similarity([t], [h])[0][0]
                    for t, h in zip(t_embeddings, h_embeddings)]
    pairs_gold = df['similarity'].tolist()
    eval_similarity(pairs_gold, similarities)

    print('Spearman correlation: {:7.3f}'.format(spearmanr(pairs_gold, similarities)[0]))
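test_evaluation calls sklearn's cosine_similarity once per pair. A vectorized NumPy sketch of the same score (the function name pairwise_scores is hypothetical) normalizes every embedding once and takes row-wise dot products:

```python
import numpy as np


def pairwise_scores(t_embeddings, h_embeddings, scale=5.0):
    """Row-wise cosine similarity between matching rows, multiplied by `scale`."""
    t = t_embeddings / np.linalg.norm(t_embeddings, axis=1, keepdims=True)
    h = h_embeddings / np.linalg.norm(h_embeddings, axis=1, keepdims=True)
    # After unit-normalizing, the dot product of row k with row k is their cosine
    return scale * np.sum(t * h, axis=1)


t_emb = np.array([[1.0, 0.0], [1.0, 1.0], [0.0, 2.0]])
h_emb = np.array([[2.0, 0.0], [1.0, -1.0], [0.0, -1.0]])
print(pairwise_scores(t_emb, h_emb))  # cosines 1, 0 and -1, each scaled by 5
```

Scaling by 5 maps cosines in [-1, 1] to [-5, 5], while the gold scores range from 1 to 5. Pearson and Spearman correlations are unaffected by this positive linear scaling, but the reported mean squared error is.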

In [10]:
#df_ptbr_train = pd.concat([df_ptbr_train, df_ptbr_dev])
df_ptbr_train.shape, df_ptbr_dev.shape

((11500, 4), (1500, 4))

In [11]:
"""
This example loads the pre-trained SentenceTransformer model 'bert-base-nli-mean-tokens' from the server.
It then fine-tunes this model for some epochs on the STS benchmark dataset.
Note: In this example, you must specify a SentenceTransformer model.
If you want to fine-tune a huggingface/transformers model like bert-base-uncased, see training_nli.py and training_stsbenchmark.py
"""
from torch.utils.data import DataLoader
import math
from sentence_transformers import SentenceTransformer,  SentencesDataset, LoggingHandler, losses, util, InputExample
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from datetime import datetime
import os
import csv
from sentence_transformers import models, SentenceTransformer



# Read the dataset
for i in range(2, 10+1):
    train_batch_size = 16
    num_epochs = 4 # tentar distilroberta large
    model_save_path = f'../data/sbert_finetuning_assin2_nli-train-dev-v1-{i}'# + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")



    # Load a pre-trained sentence transformer model
    #model = SentenceTransformer(model_name)


    model = SentenceTransformer('sbert-test', device='cpu')

    train_samples = []
    dev_samples = []
    #test_samples = []

    for i, row in df_ptbr_train.iterrows():
        inp_example = InputExample(texts=[row['t'], row['h']], label=row['similarity'] / 5)
        train_samples.append(inp_example)

    for i, row in df_ptbr_dev.iterrows():
        inp_example = InputExample(texts=[row['t'], row['h']], label=row['similarity'] / 5)
        dev_samples.append(inp_example)

    #for i, row in df_ptbr_test.iterrows():
    #    inp_example = InputExample(texts=[row['t'], row['h']], label=row['similarity'] / 5)
    #    test_samples.append(inp_example)



    train_dataset = SentencesDataset(train_samples, model)
    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size)
    train_loss = losses.CosineSimilarityLoss(model=model)


    evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name='sts-dev')


    # Configure the training. We skip evaluation in this example
    warmup_steps = math.ceil(len(train_dataset) * num_epochs / train_batch_size * 0.1) #10% of train data for warm-up


    # Train the model
    model.fit(train_objectives=[(train_dataloader, train_loss)],
              evaluator=evaluator,
              epochs=num_epochs,
              evaluation_steps=1000,
              warmup_steps=warmup_steps,
              output_path=model_save_path)
    
    print('####################################################################################')
    print(f'Test {i}: ASSIN pt-br:')
    test_evaluation(df_ptbr_test)
    
    print(f'Test {i}: ASSIN pt-pt:')
    test_evaluation(df_ptpt_test)
    
    print(f'Test {i}: ASSIN 2:')
    test_evaluation(df_ptbr2_test)
    print('####################################################################################')
    

####################################################################################
Test 499: ASSIN pt-br:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.812		              0.29
Spearman correlation:   0.800
Test 499: ASSIN pt-pt:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.823		              0.49
Spearman correlation:   0.823
Test 499: ASSIN 2:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.858		              0.43
Spearman correlation:   0.834
####################################################################################


####################################################################################
Test 499: ASSIN pt-br:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.814		              0.29
Spearman correlation:   0.801
Test 499: ASSIN pt-pt:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.827		              0.48
Spearman correlation:   0.826
Test 499: ASSIN 2:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.858		              0.43
Spearman correlation:   0.833
####################################################################################


####################################################################################
Test 499: ASSIN pt-br:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.809		              0.30
Spearman correlation:   0.796
Test 499: ASSIN pt-pt:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.824		              0.49
Spearman correlation:   0.823
Test 499: ASSIN 2:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.855		              0.43
Spearman correlation:   0.829
####################################################################################


####################################################################################
Test 499: ASSIN pt-br:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.811		              0.29
Spearman correlation:   0.798
Test 499: ASSIN pt-pt:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.826		              0.47
Spearman correlation:   0.826
Test 499: ASSIN 2:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.856		              0.42
Spearman correlation:   0.832
####################################################################################


####################################################################################
Test 499: ASSIN pt-br:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.812		              0.29
Spearman correlation:   0.799
Test 499: ASSIN pt-pt:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.825		              0.47
Spearman correlation:   0.824
Test 499: ASSIN 2:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.857		              0.43
Spearman correlation:   0.833
####################################################################################


####################################################################################
Test 499: ASSIN pt-br:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.812		              0.29
Spearman correlation:   0.799
Test 499: ASSIN pt-pt:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.826		              0.49
Spearman correlation:   0.825
Test 499: ASSIN 2:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.857		              0.43
Spearman correlation:   0.832
####################################################################################


####################################################################################
Test 499: ASSIN pt-br:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.811		              0.29
Spearman correlation:   0.798
Test 499: ASSIN pt-pt:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.825		              0.49
Spearman correlation:   0.825
Test 499: ASSIN 2:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.859		              0.43
Spearman correlation:   0.833
####################################################################################


####################################################################################
Test 499: ASSIN pt-br:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.812		              0.29
Spearman correlation:   0.798
Test 499: ASSIN pt-pt:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.825		              0.48
Spearman correlation:   0.826
Test 499: ASSIN 2:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.859		              0.42
Spearman correlation:   0.833
####################################################################################


####################################################################################
Test 499: ASSIN pt-br:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.811		              0.29
Spearman correlation:   0.797
Test 499: ASSIN pt-pt:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.825		              0.49
Spearman correlation:   0.824
Test 499: ASSIN 2:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.858		              0.42
Spearman correlation:   0.833
####################################################################################


Test 499: ASSIN pt-br:  

Similarity evaluation  
Pearson		Mean Squared Error  
-------		------------------
  0.814		              0.29
Spearman correlation:   0.801  
Test 499: ASSIN pt-pt:  

Similarity evaluation  
Pearson		Mean Squared Error  
-------		------------------
  0.825		              0.49  
Spearman correlation:   0.825  
Test 499: ASSIN 2:  
  
Similarity evaluation  
Pearson		Mean Squared Error  
-------		------------------
  0.859		              0.43  
Spearman correlation:   0.833

In [32]:
df_ptbr_train = pd.concat([df_ptbr_train, df_ptbr_dev])
df_ptbr_train.shape, df_ptbr_dev.shape

((13000, 4), (1500, 4))
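With train_batch_size = 16, the 11,500-pair (train only) and 13,000-pair (train + dev) training sets give the following steps per epoch, since a DataLoader without drop_last yields ceil(n / batch_size) batches, and the following warmup_steps under the formula used in the training cells. The helper names are hypothetical:

```python
import math


def steps_per_epoch(n_samples, batch_size=16):
    # With drop_last=False the last, smaller batch still counts as a step
    return math.ceil(n_samples / batch_size)


def warmup_steps(n_samples, num_epochs=4, batch_size=16):
    # Same expression as the training cells: 10% of (samples * epochs / batch size)
    return math.ceil(n_samples * num_epochs / batch_size * 0.1)


for n in (11500, 13000):  # train only vs. train + dev
    print(n, steps_per_epoch(n), warmup_steps(n))
```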

In [36]:
"""
This example loads the pre-trained SentenceTransformer model 'bert-base-nli-mean-tokens' from the server.
It then fine-tunes this model for some epochs on the STS benchmark dataset.
Note: In this example, you must specify a SentenceTransformer model.
If you want to fine-tune a huggingface/transformers model like bert-base-uncased, see training_nli.py and training_stsbenchmark.py
"""
from torch.utils.data import DataLoader
import math
from sentence_transformers import SentenceTransformer,  SentencesDataset, LoggingHandler, losses, util, InputExample
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from datetime import datetime
import os
import csv
from sentence_transformers import models, SentenceTransformer

train_samples = []
    #dev_samples = []
    #test_samples = []

for i, row in df_ptbr_train.iterrows():
    inp_example = InputExample(texts=[row['t'], row['h']], label=row['similarity'] / 5)
    train_samples.append(inp_example)

# Read the dataset
for i in range(2, 10+1):
    train_batch_size = 16
    num_epochs = 4 # tentar distilroberta large
    model_save_path = f'../data/sbert_finetuning_assin2_nli-v1-all-{i}'# + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")



    # Load a pre-trained sentence transformer model
    #model = SentenceTransformer(model_name)


    model = SentenceTransformer('sbert-test', device='cpu')

    #train_samples = []
    #dev_samples = []
    #test_samples = []

    #for i, row in df_ptbr_train.iterrows():
    #    inp_example = InputExample(texts=[row['t'], row['h']], label=row['similarity'] / 5)
    #    train_samples.append(inp_example)

    #for i, row in df_ptbr_dev.iterrows():
    #    inp_example = InputExample(texts=[row['t'], row['h']], label=row['similarity'] / 5)
    #    dev_samples.append(inp_example)

    #for i, row in df_ptbr_test.iterrows():
    #    inp_example = InputExample(texts=[row['t'], row['h']], label=row['similarity'] / 5)
    #    test_samples.append(inp_example)



    train_dataset = SentencesDataset(train_samples, model)
    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size)
    train_loss = losses.CosineSimilarityLoss(model=model)


    #evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name='sts-dev')


    # Configure the training. We skip evaluation in this example
    warmup_steps = math.ceil(len(train_dataset) * num_epochs / train_batch_size * 0.1) #10% of train data for warm-up


    # Train the model
    model.fit(train_objectives=[(train_dataloader, train_loss)],
              #evaluator=evaluator,
              epochs=num_epochs,
              #evaluation_steps=1000,
              warmup_steps=warmup_steps,
              output_path=model_save_path)
    
    print('####################################################################################')
    print(f'Test {i}: ASSIN pt-br:')
    test_evaluation(df_ptbr_test)
    
    print(f'Test {i}: ASSIN pt-pt:')
    test_evaluation(df_ptpt_test)
    
    print(f'Test {i}: ASSIN 2:')
    test_evaluation(df_ptbr2_test)
    print('####################################################################################')

####################################################################################
Test 2: ASSIN pt-br:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.815		              0.29
Spearman correlation:   0.804
Test 2: ASSIN pt-pt:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.823		              0.49
Spearman correlation:   0.823
Test 2: ASSIN 2:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.855		              0.43
Spearman correlation:   0.830
####################################################################################


####################################################################################
Test 3: ASSIN pt-br:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.818		              0.28
Spearman correlation:   0.806
Test 3: ASSIN pt-pt:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.826		              0.48
Spearman correlation:   0.826
Test 3: ASSIN 2:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.857		              0.43
Spearman correlation:   0.832
####################################################################################


####################################################################################
Test 4: ASSIN pt-br:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.815		              0.28
Spearman correlation:   0.803
Test 4: ASSIN pt-pt:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.828		              0.47
Spearman correlation:   0.827
Test 4: ASSIN 2:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.856		              0.43
Spearman correlation:   0.830
####################################################################################


####################################################################################
Test 5: ASSIN pt-br:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.815		              0.29
Spearman correlation:   0.802
Test 5: ASSIN pt-pt:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.827		              0.47
Spearman correlation:   0.827
Test 5: ASSIN 2:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.857		              0.43
Spearman correlation:   0.832
####################################################################################


####################################################################################
Test 6: ASSIN pt-br:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.802		              0.30
Spearman correlation:   0.789
Test 6: ASSIN pt-pt:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.819		              0.48
Spearman correlation:   0.819
Test 6: ASSIN 2:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.855		              0.44
Spearman correlation:   0.829
####################################################################################


####################################################################################
Test 7: ASSIN pt-br:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.815		              0.29
Spearman correlation:   0.803
Test 7: ASSIN pt-pt:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.827		              0.48
Spearman correlation:   0.828
Test 7: ASSIN 2:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.857		              0.43
Spearman correlation:   0.832
####################################################################################


####################################################################################
Test 8: ASSIN pt-br:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.815		              0.29
Spearman correlation:   0.803
Test 8: ASSIN pt-pt:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.827		              0.48
Spearman correlation:   0.827
Test 8: ASSIN 2:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.857		              0.43
Spearman correlation:   0.831
####################################################################################


####################################################################################
Test 9: ASSIN pt-br:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.816		              0.28
Spearman correlation:   0.804
Test 9: ASSIN pt-pt:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.830		              0.47
Spearman correlation:   0.828
Test 9: ASSIN 2:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.858		              0.42
Spearman correlation:   0.832
####################################################################################


####################################################################################
Test 10: ASSIN pt-br:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.815		              0.29
Spearman correlation:   0.802
Test 10: ASSIN pt-pt:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.824		              0.48
Spearman correlation:   0.825
Test 10: ASSIN 2:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.856		              0.43
Spearman correlation:   0.831
####################################################################################
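The scores vary slightly from run to run. For example, ASSIN pt-br Pearson ranges from 0.802 (Test 6) to 0.816 (Test 9). One simple way to summarize repeated runs is to report the mean and sample standard deviation for each dataset. The sketch below does this with Python's `statistics` module, using the Pearson values printed for Tests 4 to 10 above. The helper `summarize_runs` is hypothetical and is not part of the notebook.

```python
from statistics import mean, stdev


def summarize_runs(scores_by_dataset):
    """Hypothetical helper: map each dataset to (mean, sample std) over its runs."""
    return {
        name: (mean(values), stdev(values))
        for name, values in scores_by_dataset.items()
    }


# Pearson scores copied from the Test 4 to Test 10 outputs above
pearson_runs = {
    "ASSIN pt-br": [0.815, 0.815, 0.802, 0.815, 0.815, 0.816, 0.815],
    "ASSIN pt-pt": [0.828, 0.827, 0.819, 0.827, 0.827, 0.830, 0.824],
    "ASSIN 2": [0.856, 0.857, 0.855, 0.857, 0.857, 0.858, 0.856],
}

for name, (avg, spread) in summarize_runs(pearson_runs).items():
    print(f"{name}: Pearson {avg:.3f} ± {spread:.3f} over {len(pearson_runs[name])} runs")
```

The sample standard deviation (`stdev`, n − 1 denominator) is used because the runs are treated as a sample of possible training outcomes, not as the full population.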


In [35]:
# Assumed: the cell's loop header was not preserved in this export; the output below reports "Test 1", so i = 1
i = 1

print('####################################################################################')
print(f'Test {i}: ASSIN pt-br:')
test_evaluation(df_ptbr_test)

print(f'Test {i}: ASSIN pt-pt:')
test_evaluation(df_ptpt_test)

print(f'Test {i}: ASSIN 2:')
test_evaluation(df_ptbr2_test)
print('####################################################################################')

####################################################################################
Test 1: ASSIN pt-br:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.815		              0.29
Spearman correlation:   0.804
Test 1: ASSIN pt-pt:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.826		              0.48
Spearman correlation:   0.826
Test 1: ASSIN 2:

Similarity evaluation
Pearson		Mean Squared Error
-------		------------------
  0.858		              0.43
Spearman correlation:   0.833
####################################################################################
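Each evaluation block reports three numbers computed from the gold similarity scores and the model's predictions: Pearson correlation, Spearman correlation and mean squared error. The sketch below shows how these can be computed with NumPy alone. It assumes gold and predicted scores are plain numeric sequences. It is not the notebook's `test_evaluation`, and the names `average_ranks` and `similarity_metrics` are hypothetical. Spearman is computed as the Pearson correlation of the ranks, with tied values given the average of their ranks.

```python
import numpy as np


def average_ranks(values):
    """Rank values from 1..n; tied values share the mean of their ranks."""
    values = np.asarray(values, dtype=float)
    order = np.argsort(values, kind="mergesort")
    sorted_vals = values[order]
    ranks = np.empty(len(values), dtype=float)
    i, n = 0, len(values)
    while i < n:
        j = i
        # extend j over the run of values equal to sorted_vals[i]
        while j + 1 < n and sorted_vals[j + 1] == sorted_vals[i]:
            j += 1
        # 0-based positions i..j get the average of 1-based ranks i+1..j+1
        ranks[order[i:j + 1]] = (i + j) / 2.0 + 1.0
        i = j + 1
    return ranks


def similarity_metrics(gold, predicted):
    """Return (pearson, spearman, mse) between gold and predicted scores."""
    gold = np.asarray(gold, dtype=float)
    predicted = np.asarray(predicted, dtype=float)
    pearson = np.corrcoef(gold, predicted)[0, 1]
    spearman = np.corrcoef(average_ranks(gold), average_ranks(predicted))[0, 1]
    mse = np.mean((gold - predicted) ** 2)
    return pearson, spearman, mse


# Illustrative toy scores, not taken from ASSIN
gold = [4.0, 2.5, 1.0, 3.0]
pred = [3.6, 2.9, 1.4, 3.1]
pearson, spearman, mse = similarity_metrics(gold, pred)
print(f"Pearson {pearson:.3f}  Spearman {spearman:.3f}  MSE {mse:.2f}")
```

Pearson rewards a linear fit. Spearman only checks whether the ranking is preserved. MSE also penalizes predictions on the wrong scale, which is why a model can correlate well yet still show a higher MSE, as on ASSIN pt-pt above.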
