In [None]:
# Run the sentence-transformers helper script that populates examples/datasets/
# (presumably downloads stsbenchmark, which the training cell reads) — TODO confirm.
!python sentence-transformers/examples/datasets/get_data.py

In [None]:
import sys
# Make the project-local modules in ../lib (resolved against the current working
# directory) importable. This must run before the local imports below.
sys.path.insert(0, "../lib")

# Project-local modules (from ../lib): BSC data readers, losses, and the
# shuffling dataset/model/shuffler classes.
from STSDataReaderBinary import STSDataReaderBinary
from STSDataReaderBinaryPositives import STSDataReaderBinaryPositives
from BSCLoss import BSCLoss, ComboBSCLoss
from BSCShuffler import ShuffledSentencesDataset, ShuffledSentenceTransformer
from BSCShuffler import BSCShuffler, ModelBSCShuffler, ModelExampleBasedShuffler
from torch.utils.data import DataLoader
import math
import os
from sentence_transformers import models, losses
from sentence_transformers import SentencesDataset, LoggingHandler, SentenceTransformer, evaluation
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, TripletEvaluator, SimilarityFunction
# NOTE: this wildcard import is what provides STSDataReader used in the training
# cell. It also rebinds any same-named earlier import, so statement order matters.
from sentence_transformers.readers import *
import pandas as pd
import logging
import csv

#### Just some code to print debug information to stdout
# Root logger at INFO level, emitting through sentence-transformers' LoggingHandler.
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])

In [None]:
# Train a BERT + mean-pooling sentence encoder on STS-benchmark with the BSC
# objective, keep the checkpoint selected by sts-dev, then reload it and
# score it on sts-test.
import shutil

# ---- Model: BERT token encoder followed by mean pooling over tokens ----------
word_embedding_model = models.Transformer('bert-base-uncased', max_seq_length=124)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),
                               pooling_mode_mean_tokens=True,
                               pooling_mode_cls_token=False,
                               pooling_mode_max_tokens=False)

model = ShuffledSentenceTransformer(modules=[word_embedding_model, pooling_model], device='cuda')

# or use the trained model
#model = ShuffledSentenceTransformer('checkpoints_/bsc_sts')

# to train tau
# model._first_module().auto_model.config.output_hidden_states = True

# ---- Data readers -------------------------------------------------------------
sts_data_dir = 'sentence-transformers/examples/datasets/stsbenchmark'

# Reader for the BSC objective. Scores are normalized, and `thr=0.6` presumably
# binarizes them into positive/negative pairs — confirm in STSDataReaderBinaryPositives.
sts_reader_pos = STSDataReaderBinaryPositives(sts_data_dir,
                                              s1_col_idx=5, s2_col_idx=6, score_col_idx=4,
                                              normalize_scores=True, thr=0.6,
                                              get_positives=False)
# Plain STS reader with normalized scores, used for the cosine objective and evaluation.
sts_reader = STSDataReader(sts_data_dir,
                           s1_col_idx=5, s2_col_idx=6, score_col_idx=4,
                           normalize_scores=True)

train_batch_size = 30
num_epochs = 5

# ---- BSC objective (the one actually trained below) ---------------------------
# shuffle=False: batch composition is left to the `shuffler` passed to fit()
# (presumably reorders the dataset between epochs — TODO confirm).
train_data_bsc = ShuffledSentencesDataset(sts_reader_pos.get_examples('sts-train.csv'), model)
train_dataloader_bsc = DataLoader(train_data_bsc, shuffle=False, batch_size=train_batch_size)
train_loss_bsc = BSCLoss(model=model, tau=0.1)

# ---- Cosine-similarity objective ----------------------------------------------
# Not passed to fit() below; kept so it can be swapped in or combined.
train_data = SentencesDataset(sts_reader.get_examples('sts-train.csv'), model)
train_dataloader = DataLoader(train_data, shuffle=True, batch_size=train_batch_size)
train_loss = losses.CosineSimilarityLoss(model=model)

# ---- Evaluation ---------------------------------------------------------------
evaluator = EmbeddingSimilarityEvaluator.from_input_examples(sts_reader.get_examples('sts-test.csv'), name='sts-test')
evaluator.device = 'cuda'
evaluator.main_similarity = SimilarityFunction.COSINE
evaluator_dev = EmbeddingSimilarityEvaluator.from_input_examples(sts_reader.get_examples('sts-dev.csv'), name='sts-dev')
evaluator_dev.device = 'cuda'
evaluator_dev.main_similarity = SimilarityFunction.COSINE
# Runs both evaluators; the main score is the last one (sts-dev), so model
# selection is driven by dev, not test.
seq_evaluator = evaluation.SequentialEvaluator([evaluator, evaluator_dev],
                                               main_score_function=lambda scores: scores[-1])

# Warm up for 10% of all optimizer steps. Derive this from the dataloader that
# is actually trained (the BSC one), not from the unused cosine dataset.
warmup_steps = math.ceil(len(train_dataloader_bsc) * num_epochs * 0.1)
model_save_path = 'checkpoints_/bsc-mse_sts'

# Start from an empty output directory. This is the portable equivalent of
# `rm -rf` and is a no-op if the directory does not exist.
shutil.rmtree(model_save_path, ignore_errors=True)

shuffler = ModelExampleBasedShuffler(group_size=7, allow_same=True)
#shuffler = ModelBSCShuffler(group_size=7, by_clusters=True, num_clusters=200,
#                            file_name=None, output_file_name=None, column_name=None, max_ind=None)#

model.fit(train_objectives=[(train_dataloader_bsc, train_loss_bsc)],
          evaluator=seq_evaluator,
          epochs=num_epochs,
          evaluation_steps=1000,
          warmup_steps=warmup_steps,
          output_path=model_save_path,
          optimizer_params={'alpha_lr': 0.005, 'lr': 2e-5, 'eps': 1e-6, 'correct_bias': True},
          shuffler=shuffler,
          shuffle_idxs=[0]
         )

# Reload the model fit() just wrote (presumably the best checkpoint by sts-dev)
# rather than an unrelated earlier checkpoint, and score it on sts-test.
model = ShuffledSentenceTransformer(model_save_path)
model.evaluate(evaluator)