### Load case reports

In [3]:
import pandas as pd
filtered_df = pd.read_csv('data/filtered_case_reports.csv')


### Generate case report pairs

In [5]:
from itertools import combinations

# Function to calculate the score based on shared keywords
def calculate_score(keywords1, keywords2):
    return len(set(keywords1) & set(keywords2))

# Generate pairs of abstracts with their scores
abstract_pairs = []
for (idx1, row1), (idx2, row2) in combinations(filtered_df.iterrows(), 2):
    score = calculate_score(row1['processed_keywords'], row2['processed_keywords'])
    abstract_pairs.append((row1['abstract'], row2['abstract'], score, row1['processed_keywords'], row2['processed_keywords']))


#count pairs by score
from collections import Counter
score_counter = Counter([score for _, _, score, _, _ in abstract_pairs])
score_counter.most_common(10)

[(0, 145727), (1, 833), (2, 46), (3, 4), (4, 1)]

In [6]:
len(abstract_pairs)

146611

### Balance similar and non-similar pairs

In [7]:
#balance pairs by score
from sklearn.utils import resample

# Split the data into majority and minority classes
majority_class = [pair for pair in abstract_pairs if pair[2] == 0]
minority_class = [pair for pair in abstract_pairs if pair[2] > 0]

# Undersample the majority class to match the size of the minority class
undersampled_majority_class = resample(majority_class, 
                                       replace=False,     # Don't duplicate samples
                                       n_samples= len(minority_class),  # Match minority class size
                                       random_state=42)

# Combine undersampled majority class with the minority class
balanced_abstract_pairs = undersampled_majority_class + minority_class

# count pairs by score
score_counter = Counter([score for _, _, score, _, _ in balanced_abstract_pairs])
score_counter.most_common(10)

[(0, 884), (1, 833), (2, 46), (3, 4), (4, 1)]

### Similarity score based on the number of common keywords

In [8]:
def normalize_score(score):
    if score == 0:
        return 0
    elif score == 1:
        return 0.65
    elif score == 2:
        return 0.85
    elif score == 3:
        return 0.95
    elif score == 4:
        return 1

# Apply the normalization function
balanced_abstract_pairs = [(abstract1, abstract2, normalize_score(score), keywords1, keywords2)
                  for abstract1, abstract2, score, keywords1, keywords2 in balanced_abstract_pairs]

In [9]:
# sort by score
balanced_abstract_pairs = sorted(balanced_abstract_pairs, key=lambda x: x[2], reverse=True)
# print pairs
for abstract1, abstract2, score, k1, k2 in balanced_abstract_pairs[:5]:
    print(f"Score: {score}")
    print(f"Abstract 1: {abstract1}")
    print(f"Abstract 2: {abstract2}")
    print(f"Keywords 1: {k1}")
    print(f"Keywords 2: {k2}")
    print("-" * 80)

Score: 1
Abstract 1: A Doença de Still do Adulto (DSA) é uma doença inflamatória sistémica, rara, de etiologia desconhecida, cujas principais manifestações incluem febre elevada, exantema maculopapular evanescente, artralgias/artrite e odinofagia persistente. O diagnóstico é clínico e implica a exclusão de outras patologias infeciosas, autoimunes e neoplásicas. Os autores apresentam o caso clínico de doente de 22 anos, previamente saudável, com quadro com cerca de 5 meses de evolução, caraterizado por poliartralgias simétricas de ritmo inflamatório, com incapacidade funcional marcada, associadas a febre vespertina, exantema evanescente, astenia, anorexia e perda ponderal. Ao exame físico apresentava mucosas descoradas, sinovite das articulações metacarpofalangicas e interfalangicas proximais, e múltiplas adenopatias infra e pericentimétricas, não dolorosas, móveis, de consistência duro-elástica, cervicais, axilares e inguinais. Analiticamente com elevação marcada da ferritina, com fraç

### Generate test and train split

In [None]:
from sklearn.model_selection import train_test_split
import numpy as np


filtered_balanced_abstract_pairs = [(abstract1, abstract2, score) for abstract1, abstract2, score, _, _ in balanced_abstract_pairs]
df_ = pd.DataFrame(filtered_balanced_abstract_pairs, columns=['abstract1', 'abstract2', 'score'])


bins = np.arange(-0.1, 1.1, 0.1)  # Creates bins from 0.0 to 1.0 with a width of 0.1
labels = [f"{bins[i]}-{bins[i+1]}" for i in range(len(bins) - 1)]


df_['score_bin'] = pd.cut(df_['score'], bins=bins, labels=labels, right=True)

train_df, test_df = train_test_split(
    df_,
    test_size=0.2,  # 20% for testing
    stratify=df_['score_bin'],  # Stratify based on the score bins
    random_state=42  # For reproducibility
)

train_data = list(zip(train_df['abstract1'], train_df['abstract2'], train_df['score']))
test_data = list(zip(test_df['abstract1'], test_df['abstract2'], test_df['score']))

#print counts of entries by score
print(train_df['score_bin'].value_counts())
print(test_df['score_bin'].value_counts())

score_bin
-0.1-0.0                                   707
0.6000000000000001-0.7000000000000001      666
0.8-0.9                                     37
0.9-1.0                                      4
0.0-0.1                                      0
0.1-0.20000000000000004                      0
0.20000000000000004-0.30000000000000004      0
0.30000000000000004-0.4                      0
0.4-0.5000000000000001                       0
0.5000000000000001-0.6000000000000001        0
0.7000000000000001-0.8                       0
1.0-1.1                                      0
Name: count, dtype: int64
score_bin
-0.1-0.0                                   177
0.6000000000000001-0.7000000000000001      167
0.8-0.9                                      9
0.9-1.0                                      1
0.0-0.1                                      0
0.1-0.20000000000000004                      0
0.20000000000000004-0.30000000000000004      0
0.30000000000000004-0.4                      0
0.4-0.50000000

In [10]:
#converto to hugginface dataset
from datasets import Dataset
from datasets import DatasetDict
train_dataset = Dataset.from_pandas(pd.DataFrame(train_data, columns=["abstract1", "abstract2", "score"]))
test_dataset = Dataset.from_pandas(pd.DataFrame(test_data, columns=["abstract1", "abstract2", "score"]))



### Train the Bi-encoder model

In [None]:
from sentence_transformers import SentenceTransformer, losses


#model1 = SentenceTransformer('distiluse-base-multilingual-cased-v1')# 
model1 = SentenceTransformer('neuralmind/bert-base-portuguese-cased')
#model1 = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2')
loss1= losses.CosineSimilarityLoss(model1)



In [None]:
from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    SentenceTransformerModelCardData,
)
from sentence_transformers.similarity_functions import SimilarityFunction
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator



# 5. (Optional) Specify training arguments
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir="models/sentence_transformers/bert-base-portuguese",
    # Optional training parameters:
    num_train_epochs=10,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    warmup_ratio=0.1,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    load_best_model_at_end=True,
    logging_steps=100,
    run_name="sentence_transformers_bert-base-portuguese",  # Will be used in W&B if `wandb` is installed
)


# Create the evaluator
evaluator = EmbeddingSimilarityEvaluator(
    test_dataset['abstract1'],  # Assuming these are the sentence pairs for evaluation
    test_dataset['abstract2'],
    test_dataset['score'],  # Assuming this contains the similarity scores
    main_similarity=SimilarityFunction.COSINE,
)

# 6. Create the trainer & start training
trainer = SentenceTransformerTrainer(
    model=model1,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    loss=loss1,
    evaluator=evaluator,
)



In [15]:
trainer.train()

  5%|▌         | 100/1930 [00:47<14:17,  2.13it/s]

{'loss': 0.1076, 'grad_norm': 1.9443687200546265, 'learning_rate': 2.5906735751295337e-05, 'epoch': 0.52}



  5%|▌         | 100/1930 [00:56<14:17,  2.13it/s]

{'eval_loss': 0.12303422391414642, 'eval_pearson_cosine': 0.3108885805349722, 'eval_spearman_cosine': 0.31321639860465283, 'eval_pearson_manhattan': 0.28447131285553723, 'eval_spearman_manhattan': 0.2908359018332696, 'eval_pearson_euclidean': 0.28732107522405664, 'eval_spearman_euclidean': 0.2924603646297584, 'eval_pearson_dot': 0.29768883204450025, 'eval_spearman_dot': 0.30908777855913117, 'eval_pearson_max': 0.3108885805349722, 'eval_spearman_max': 0.31321639860465283, 'eval_runtime': 9.4749, 'eval_samples_per_second': 40.739, 'eval_steps_per_second': 5.172, 'epoch': 0.52}


 10%|█         | 200/1930 [01:57<13:22,  2.16it/s]  

{'loss': 0.0941, 'grad_norm': 1.3062092065811157, 'learning_rate': 4.9798503166378815e-05, 'epoch': 1.04}


                                                  
 10%|█         | 200/1930 [02:07<13:22,  2.16it/s]

{'eval_loss': 0.10089336335659027, 'eval_pearson_cosine': 0.3858927360621347, 'eval_spearman_cosine': 0.38779356571774665, 'eval_pearson_manhattan': 0.35197940741812633, 'eval_spearman_manhattan': 0.34829504704325426, 'eval_pearson_euclidean': 0.35097371489113693, 'eval_spearman_euclidean': 0.3414750374834622, 'eval_pearson_dot': 0.3373687375389275, 'eval_spearman_dot': 0.3778662138207852, 'eval_pearson_max': 0.3858927360621347, 'eval_spearman_max': 0.38779356571774665, 'eval_runtime': 9.4677, 'eval_samples_per_second': 40.77, 'eval_steps_per_second': 5.175, 'epoch': 1.04}


 16%|█▌        | 300/1930 [04:56<12:40,  2.14it/s]   

{'loss': 0.0703, 'grad_norm': 0.874003529548645, 'learning_rate': 4.6919976971790444e-05, 'epoch': 1.55}


                                                  
 16%|█▌        | 300/1930 [05:05<12:40,  2.14it/s]

{'eval_loss': 0.10245029628276825, 'eval_pearson_cosine': 0.41057248612341346, 'eval_spearman_cosine': 0.4185107128361596, 'eval_pearson_manhattan': 0.38656833141645774, 'eval_spearman_manhattan': 0.38433220580224264, 'eval_pearson_euclidean': 0.39111010238511656, 'eval_spearman_euclidean': 0.38722235871173655, 'eval_pearson_dot': 0.4120286560481582, 'eval_spearman_dot': 0.4183042223950691, 'eval_pearson_max': 0.4120286560481582, 'eval_spearman_max': 0.4185107128361596, 'eval_runtime': 9.3977, 'eval_samples_per_second': 41.074, 'eval_steps_per_second': 5.214, 'epoch': 1.55}


 21%|██        | 400/1930 [07:56<11:54,  2.14it/s]   

{'loss': 0.0667, 'grad_norm': 0.822307825088501, 'learning_rate': 4.404145077720208e-05, 'epoch': 2.07}


                                                  
 21%|██        | 400/1930 [08:06<11:54,  2.14it/s]

{'eval_loss': 0.08549556881189346, 'eval_pearson_cosine': 0.5036241981480915, 'eval_spearman_cosine': 0.5053366741799017, 'eval_pearson_manhattan': 0.4760708087894603, 'eval_spearman_manhattan': 0.4749544053415992, 'eval_pearson_euclidean': 0.48163249538140673, 'eval_spearman_euclidean': 0.4792225877232408, 'eval_pearson_dot': 0.4848915451122877, 'eval_spearman_dot': 0.4922904488157849, 'eval_pearson_max': 0.5036241981480915, 'eval_spearman_max': 0.5053366741799017, 'eval_runtime': 9.3846, 'eval_samples_per_second': 41.131, 'eval_steps_per_second': 5.221, 'epoch': 2.07}


 26%|██▌       | 500/1930 [10:31<11:07,  2.14it/s]   

{'loss': 0.0365, 'grad_norm': 1.0046660900115967, 'learning_rate': 4.116292458261371e-05, 'epoch': 2.59}


                                                  
 26%|██▌       | 500/1930 [10:41<11:07,  2.14it/s]

{'eval_loss': 0.0820518359541893, 'eval_pearson_cosine': 0.5286153227860985, 'eval_spearman_cosine': 0.536008291166514, 'eval_pearson_manhattan': 0.5096270098253503, 'eval_spearman_manhattan': 0.5010570795459028, 'eval_pearson_euclidean': 0.5167805079144356, 'eval_spearman_euclidean': 0.511100931519012, 'eval_pearson_dot': 0.5093352400375827, 'eval_spearman_dot': 0.519878261235715, 'eval_pearson_max': 0.5286153227860985, 'eval_spearman_max': 0.536008291166514, 'eval_runtime': 9.4201, 'eval_samples_per_second': 40.976, 'eval_steps_per_second': 5.202, 'epoch': 2.59}


 31%|███       | 600/1930 [13:06<10:22,  2.14it/s]   

{'loss': 0.0362, 'grad_norm': 0.43170350790023804, 'learning_rate': 3.8284398388025336e-05, 'epoch': 3.11}


                                                  
 31%|███       | 600/1930 [13:16<10:22,  2.14it/s]

{'eval_loss': 0.07897860556840897, 'eval_pearson_cosine': 0.5511650945137673, 'eval_spearman_cosine': 0.5595341738906355, 'eval_pearson_manhattan': 0.5247815492473156, 'eval_spearman_manhattan': 0.537751631591725, 'eval_pearson_euclidean': 0.5282643599423433, 'eval_spearman_euclidean': 0.5395149434585612, 'eval_pearson_dot': 0.5367217899008688, 'eval_spearman_dot': 0.5473823838050246, 'eval_pearson_max': 0.5511650945137673, 'eval_spearman_max': 0.5595341738906355, 'eval_runtime': 9.4029, 'eval_samples_per_second': 41.051, 'eval_steps_per_second': 5.211, 'epoch': 3.11}


 36%|███▋      | 700/1930 [15:32<09:33,  2.14it/s]   

{'loss': 0.0194, 'grad_norm': 0.294084757566452, 'learning_rate': 3.5405872193436964e-05, 'epoch': 3.63}


                                                  
 36%|███▋      | 700/1930 [15:41<09:33,  2.14it/s]

{'eval_loss': 0.0748530775308609, 'eval_pearson_cosine': 0.5836421009622086, 'eval_spearman_cosine': 0.5937280161386577, 'eval_pearson_manhattan': 0.5668005931706691, 'eval_spearman_manhattan': 0.567611671007051, 'eval_pearson_euclidean': 0.5690694714239689, 'eval_spearman_euclidean': 0.5734170600056991, 'eval_pearson_dot': 0.581455657454224, 'eval_spearman_dot': 0.5908416673132829, 'eval_pearson_max': 0.5836421009622086, 'eval_spearman_max': 0.5937280161386577, 'eval_runtime': 9.3904, 'eval_samples_per_second': 41.106, 'eval_steps_per_second': 5.218, 'epoch': 3.63}


 41%|████▏     | 800/1930 [17:50<08:52,  2.12it/s]  

{'loss': 0.0203, 'grad_norm': 0.172941192984581, 'learning_rate': 3.252734599884859e-05, 'epoch': 4.15}


                                                  
 41%|████▏     | 800/1930 [17:59<08:52,  2.12it/s]

{'eval_loss': 0.07728032767772675, 'eval_pearson_cosine': 0.5676324980954546, 'eval_spearman_cosine': 0.5785549511190722, 'eval_pearson_manhattan': 0.5618666693013393, 'eval_spearman_manhattan': 0.5653965052734418, 'eval_pearson_euclidean': 0.5630320606570053, 'eval_spearman_euclidean': 0.5680811187628749, 'eval_pearson_dot': 0.5517200537205584, 'eval_spearman_dot': 0.563797244534242, 'eval_pearson_max': 0.5676324980954546, 'eval_spearman_max': 0.5785549511190722, 'eval_runtime': 9.4492, 'eval_samples_per_second': 40.85, 'eval_steps_per_second': 5.186, 'epoch': 4.15}


 47%|████▋     | 900/1930 [20:30<08:00,  2.14it/s]   

{'loss': 0.0107, 'grad_norm': 0.3483406901359558, 'learning_rate': 2.9648819804260218e-05, 'epoch': 4.66}


                                                  
 47%|████▋     | 900/1930 [20:40<08:00,  2.14it/s]

{'eval_loss': 0.07212235778570175, 'eval_pearson_cosine': 0.6009928457117034, 'eval_spearman_cosine': 0.607077854966531, 'eval_pearson_manhattan': 0.5837290921988479, 'eval_spearman_manhattan': 0.5858073186145268, 'eval_pearson_euclidean': 0.5872377645929159, 'eval_spearman_euclidean': 0.591572051464152, 'eval_pearson_dot': 0.5928039718403, 'eval_spearman_dot': 0.6005971221591121, 'eval_pearson_max': 0.6009928457117034, 'eval_spearman_max': 0.607077854966531, 'eval_runtime': 9.4031, 'eval_samples_per_second': 41.05, 'eval_steps_per_second': 5.211, 'epoch': 4.66}


 52%|█████▏    | 1000/1930 [23:08<07:12,  2.15it/s] 

{'loss': 0.0081, 'grad_norm': 0.16087380051612854, 'learning_rate': 2.677029360967185e-05, 'epoch': 5.18}


                                                   
 52%|█████▏    | 1000/1930 [23:18<07:12,  2.15it/s]

{'eval_loss': 0.07039910554885864, 'eval_pearson_cosine': 0.6101176120849288, 'eval_spearman_cosine': 0.620023034346234, 'eval_pearson_manhattan': 0.586540810934096, 'eval_spearman_manhattan': 0.5960748982848911, 'eval_pearson_euclidean': 0.5861968973152418, 'eval_spearman_euclidean': 0.5947674821242045, 'eval_pearson_dot': 0.6002496657912917, 'eval_spearman_dot': 0.615482384439561, 'eval_pearson_max': 0.6101176120849288, 'eval_spearman_max': 0.620023034346234, 'eval_runtime': 9.4048, 'eval_samples_per_second': 41.043, 'eval_steps_per_second': 5.21, 'epoch': 5.18}


 57%|█████▋    | 1100/1930 [25:23<06:27,  2.14it/s]  

{'loss': 0.0055, 'grad_norm': 0.14665736258029938, 'learning_rate': 2.3891767415083478e-05, 'epoch': 5.7}


                                                   
 57%|█████▋    | 1100/1930 [25:32<06:27,  2.14it/s]

{'eval_loss': 0.0670602098107338, 'eval_pearson_cosine': 0.6313472046351272, 'eval_spearman_cosine': 0.6351544921957666, 'eval_pearson_manhattan': 0.6075140370751756, 'eval_spearman_manhattan': 0.6126681943348207, 'eval_pearson_euclidean': 0.6082927712724578, 'eval_spearman_euclidean': 0.6136717592764934, 'eval_pearson_dot': 0.6223093048156046, 'eval_spearman_dot': 0.6271745936125321, 'eval_pearson_max': 0.6313472046351272, 'eval_spearman_max': 0.6351544921957666, 'eval_runtime': 9.3933, 'eval_samples_per_second': 41.093, 'eval_steps_per_second': 5.217, 'epoch': 5.7}


 62%|██████▏   | 1200/1930 [27:28<05:41,  2.14it/s]  

{'loss': 0.0042, 'grad_norm': 0.17940840125083923, 'learning_rate': 2.101324122049511e-05, 'epoch': 6.22}


                                                   
 62%|██████▏   | 1200/1930 [27:37<05:41,  2.14it/s]

{'eval_loss': 0.06740705668926239, 'eval_pearson_cosine': 0.6289035261855593, 'eval_spearman_cosine': 0.6314617964147693, 'eval_pearson_manhattan': 0.6045384828993763, 'eval_spearman_manhattan': 0.6092547421036917, 'eval_pearson_euclidean': 0.6066856889240914, 'eval_spearman_euclidean': 0.6114447057800119, 'eval_pearson_dot': 0.6205617209832117, 'eval_spearman_dot': 0.6234807090552476, 'eval_pearson_max': 0.6289035261855593, 'eval_spearman_max': 0.6314617964147693, 'eval_runtime': 9.4123, 'eval_samples_per_second': 41.01, 'eval_steps_per_second': 5.206, 'epoch': 6.22}


 67%|██████▋   | 1300/1930 [29:31<04:53,  2.14it/s]  

{'loss': 0.0032, 'grad_norm': 0.2075013369321823, 'learning_rate': 1.813471502590674e-05, 'epoch': 6.74}


                                                   
 67%|██████▋   | 1300/1930 [29:41<04:53,  2.14it/s]

{'eval_loss': 0.06656575947999954, 'eval_pearson_cosine': 0.6350856185148942, 'eval_spearman_cosine': 0.6418905741496808, 'eval_pearson_manhattan': 0.6078023249513697, 'eval_spearman_manhattan': 0.6145470552567749, 'eval_pearson_euclidean': 0.608861224306063, 'eval_spearman_euclidean': 0.617296338176233, 'eval_pearson_dot': 0.6246217533652095, 'eval_spearman_dot': 0.6331523551728274, 'eval_pearson_max': 0.6350856185148942, 'eval_spearman_max': 0.6418905741496808, 'eval_runtime': 9.5691, 'eval_samples_per_second': 40.338, 'eval_steps_per_second': 5.121, 'epoch': 6.74}


 73%|███████▎  | 1400/1930 [31:32<04:07,  2.14it/s]  

{'loss': 0.0025, 'grad_norm': 0.11912062764167786, 'learning_rate': 1.5256188831318365e-05, 'epoch': 7.25}


                                                   
 73%|███████▎  | 1400/1930 [31:42<04:07,  2.14it/s]

{'eval_loss': 0.06695668399333954, 'eval_pearson_cosine': 0.6321267524012041, 'eval_spearman_cosine': 0.6372834716485523, 'eval_pearson_manhattan': 0.6073592495725669, 'eval_spearman_manhattan': 0.6127777995085025, 'eval_pearson_euclidean': 0.6081248323439696, 'eval_spearman_euclidean': 0.6148079916518201, 'eval_pearson_dot': 0.6212300474599614, 'eval_spearman_dot': 0.6282825331122229, 'eval_pearson_max': 0.6321267524012041, 'eval_spearman_max': 0.6372834716485523, 'eval_runtime': 9.4383, 'eval_samples_per_second': 40.897, 'eval_steps_per_second': 5.192, 'epoch': 7.25}


 78%|███████▊  | 1500/1930 [33:32<03:20,  2.14it/s]  

{'loss': 0.0022, 'grad_norm': 0.11310374736785889, 'learning_rate': 1.2377662636729996e-05, 'epoch': 7.77}


                                                   
 78%|███████▊  | 1500/1930 [33:42<03:20,  2.14it/s]

{'eval_loss': 0.06701614707708359, 'eval_pearson_cosine': 0.6328405092105908, 'eval_spearman_cosine': 0.6383750060354795, 'eval_pearson_manhattan': 0.6102724482673654, 'eval_spearman_manhattan': 0.6175585822251942, 'eval_pearson_euclidean': 0.6121044366561429, 'eval_spearman_euclidean': 0.6213969031013762, 'eval_pearson_dot': 0.623146586325852, 'eval_spearman_dot': 0.6292569730348587, 'eval_pearson_max': 0.6328405092105908, 'eval_spearman_max': 0.6383750060354795, 'eval_runtime': 9.5033, 'eval_samples_per_second': 40.618, 'eval_steps_per_second': 5.156, 'epoch': 7.77}


 83%|████████▎ | 1600/1930 [35:41<02:34,  2.14it/s]  

{'loss': 0.002, 'grad_norm': 0.0963059738278389, 'learning_rate': 9.499136442141624e-06, 'epoch': 8.29}


                                                   
 83%|████████▎ | 1600/1930 [35:50<02:34,  2.14it/s]

{'eval_loss': 0.06658241152763367, 'eval_pearson_cosine': 0.6353737854569683, 'eval_spearman_cosine': 0.6412376982127386, 'eval_pearson_manhattan': 0.6109220901040936, 'eval_spearman_manhattan': 0.6178669507940996, 'eval_pearson_euclidean': 0.6128073325086839, 'eval_spearman_euclidean': 0.6207751731031592, 'eval_pearson_dot': 0.6244542299268359, 'eval_spearman_dot': 0.631202999817041, 'eval_pearson_max': 0.6353737854569683, 'eval_spearman_max': 0.6412376982127386, 'eval_runtime': 9.5044, 'eval_samples_per_second': 40.613, 'eval_steps_per_second': 5.155, 'epoch': 8.29}


 88%|████████▊ | 1700/1930 [38:03<01:47,  2.15it/s]  

{'loss': 0.0015, 'grad_norm': 0.1059274822473526, 'learning_rate': 6.620610247553253e-06, 'epoch': 8.81}


                                                   
 88%|████████▊ | 1700/1930 [38:12<01:47,  2.15it/s]

{'eval_loss': 0.06688278913497925, 'eval_pearson_cosine': 0.6328280697613957, 'eval_spearman_cosine': 0.6390455947391014, 'eval_pearson_manhattan': 0.6103181572744005, 'eval_spearman_manhattan': 0.6166986214590178, 'eval_pearson_euclidean': 0.6123258632962738, 'eval_spearman_euclidean': 0.6202333288714434, 'eval_pearson_dot': 0.6233826332228894, 'eval_spearman_dot': 0.6295143431010424, 'eval_pearson_max': 0.6328280697613957, 'eval_spearman_max': 0.6390455947391014, 'eval_runtime': 9.4193, 'eval_samples_per_second': 40.98, 'eval_steps_per_second': 5.202, 'epoch': 8.81}


 93%|█████████▎| 1800/1930 [40:07<01:00,  2.14it/s]  

{'loss': 0.0011, 'grad_norm': 0.1435842514038086, 'learning_rate': 3.7420840529648822e-06, 'epoch': 9.33}


                                                   
 93%|█████████▎| 1800/1930 [40:16<01:00,  2.14it/s]

{'eval_loss': 0.06656788289546967, 'eval_pearson_cosine': 0.6350912525110266, 'eval_spearman_cosine': 0.6402862016724461, 'eval_pearson_manhattan': 0.613454472669951, 'eval_spearman_manhattan': 0.619947665929624, 'eval_pearson_euclidean': 0.615415555571956, 'eval_spearman_euclidean': 0.6237018214446708, 'eval_pearson_dot': 0.6249343194235834, 'eval_spearman_dot': 0.6321010014244085, 'eval_pearson_max': 0.6350912525110266, 'eval_spearman_max': 0.6402862016724461, 'eval_runtime': 9.4452, 'eval_samples_per_second': 40.867, 'eval_steps_per_second': 5.188, 'epoch': 9.33}


 98%|█████████▊| 1900/1930 [43:10<00:13,  2.15it/s]  

{'loss': 0.0011, 'grad_norm': 0.07405678182840347, 'learning_rate': 8.635578583765112e-07, 'epoch': 9.84}


                                                   
 98%|█████████▊| 1900/1930 [43:19<00:13,  2.15it/s]

{'eval_loss': 0.06656013429164886, 'eval_pearson_cosine': 0.6352652863839424, 'eval_spearman_cosine': 0.6401309474893349, 'eval_pearson_manhattan': 0.6134625003830313, 'eval_spearman_manhattan': 0.6191032781328102, 'eval_pearson_euclidean': 0.6154400849864002, 'eval_spearman_euclidean': 0.6236359632383587, 'eval_pearson_dot': 0.6253966365594591, 'eval_spearman_dot': 0.6313444641952204, 'eval_pearson_max': 0.6352652863839424, 'eval_spearman_max': 0.6401309474893349, 'eval_runtime': 9.4628, 'eval_samples_per_second': 40.791, 'eval_steps_per_second': 5.178, 'epoch': 9.84}


100%|██████████| 1930/1930 [45:44<00:00,  1.42s/it]

{'train_runtime': 2744.0541, 'train_samples_per_second': 5.612, 'train_steps_per_second': 0.703, 'train_loss': 0.025580150809698773, 'epoch': 10.0}





TrainOutput(global_step=1930, training_loss=0.025580150809698773, metrics={'train_runtime': 2744.0541, 'train_samples_per_second': 5.612, 'train_steps_per_second': 0.703, 'total_flos': 0.0, 'train_loss': 0.025580150809698773, 'epoch': 10.0})

### Test bi-encoder model

In [16]:



# 7. Evaluate the model performance on the STS Benchmark test dataset
test_evaluator = EmbeddingSimilarityEvaluator(
    sentences1=test_dataset["abstract1"],
    sentences2=test_dataset["abstract2"],
    scores=test_dataset["score"],
    main_similarity=SimilarityFunction.COSINE,
    name="sts-test",
)
test_evaluator(model1)


{'sts-test_pearson_cosine': 0.6352652863839424,
 'sts-test_spearman_cosine': 0.6401309474893349,
 'sts-test_pearson_manhattan': 0.6134625003830313,
 'sts-test_spearman_manhattan': 0.6191032781328102,
 'sts-test_pearson_euclidean': 0.6154400849864002,
 'sts-test_spearman_euclidean': 0.6236359632383587,
 'sts-test_pearson_dot': 0.6253966365594591,
 'sts-test_spearman_dot': 0.6313444641952204,
 'sts-test_pearson_max': 0.6352652863839424,
 'sts-test_spearman_max': 0.6401309474893349}

In [None]:


# 8. Save the trained model
model1.save_pretrained("directory_path")